Compare commits
4 Commits
3da428e0e4
...
951df8155c
| Author | SHA1 | Date |
|---|---|---|
|
|
951df8155c | |
|
|
a58fab8d6e | |
|
|
a3c8296135 | |
|
|
c95ace41aa |
|
|
@ -170,15 +170,13 @@ class PagedCache:
|
||||||
written += chunk
|
written += chunk
|
||||||
|
|
||||||
def gather(self, layer_id: int, page_table: Tensor) -> Tuple[Tensor, Tensor]:
|
def gather(self, layer_id: int, page_table: Tensor) -> Tuple[Tensor, Tensor]:
|
||||||
k_parts, v_parts = [], []
|
# page_table: [batch, max_pages] with -1 padding for tasks with fewer pages.
|
||||||
for pi in range(page_table.size(1)):
|
# clamp(min=0) maps -1 to page 0 (irrelevant data) — truncated by CacheView total_len.
|
||||||
phys_pages = page_table[:, pi]
|
safe = page_table.clamp(min=0)
|
||||||
if not (phys_pages >= 0).any():
|
k = self.k_cache[layer_id, safe]
|
||||||
break
|
v = self.v_cache[layer_id, safe]
|
||||||
k_parts.append(self.k_cache[layer_id, phys_pages])
|
k = k.flatten(1, 2)
|
||||||
v_parts.append(self.v_cache[layer_id, phys_pages])
|
v = v.flatten(1, 2)
|
||||||
k = torch.cat(k_parts, dim=1)
|
|
||||||
v = torch.cat(v_parts, dim=1)
|
|
||||||
return k, v
|
return k, v
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -147,6 +147,13 @@ class InferenceScheduler:
|
||||||
if len(prompt_ids) > self.max_prompt_len:
|
if len(prompt_ids) > self.max_prompt_len:
|
||||||
prompt_ids = prompt_ids[-self.max_prompt_len :]
|
prompt_ids = prompt_ids[-self.max_prompt_len :]
|
||||||
|
|
||||||
|
if len(prompt_ids) >= self.max_seq_len:
|
||||||
|
if stream_callback:
|
||||||
|
stream_callback(STOP)
|
||||||
|
return task_id
|
||||||
|
|
||||||
|
max_tokens = min(max_tokens, self.max_seq_len - len(prompt_ids))
|
||||||
|
|
||||||
task = Task(
|
task = Task(
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
prompt_ids=prompt_ids,
|
prompt_ids=prompt_ids,
|
||||||
|
|
@ -189,7 +196,10 @@ class InferenceScheduler:
|
||||||
def _remove_finished_tasks(self) -> None:
|
def _remove_finished_tasks(self) -> None:
|
||||||
finished = []
|
finished = []
|
||||||
for task in self.active_tasks:
|
for task in self.active_tasks:
|
||||||
if task.is_finished(self.tokenizer.stop_ids):
|
if task.status == TaskStatus.ABORTED:
|
||||||
|
task.finish_time = time.time()
|
||||||
|
finished.append(task)
|
||||||
|
elif task.is_finished(self.tokenizer.stop_ids):
|
||||||
task.status = TaskStatus.FINISHED
|
task.status = TaskStatus.FINISHED
|
||||||
task.finish_time = time.time()
|
task.finish_time = time.time()
|
||||||
finished.append(task)
|
finished.append(task)
|
||||||
|
|
@ -203,7 +213,9 @@ class InferenceScheduler:
|
||||||
task._pages_freed = True
|
task._pages_freed = True
|
||||||
|
|
||||||
self.active_tasks = [
|
self.active_tasks = [
|
||||||
t for t in self.active_tasks if t.status != TaskStatus.FINISHED
|
t
|
||||||
|
for t in self.active_tasks
|
||||||
|
if t.status not in (TaskStatus.FINISHED, TaskStatus.ABORTED)
|
||||||
]
|
]
|
||||||
|
|
||||||
def _refill_active_batch(self) -> None:
|
def _refill_active_batch(self) -> None:
|
||||||
|
|
@ -254,7 +266,9 @@ class InferenceScheduler:
|
||||||
|
|
||||||
seq_len = prompt_len - start_pos
|
seq_len = prompt_len - start_pos
|
||||||
input_ids = torch.empty(batch_sz, seq_len, dtype=torch.long, device=self.device)
|
input_ids = torch.empty(batch_sz, seq_len, dtype=torch.long, device=self.device)
|
||||||
input_mask = torch.ones(batch_sz, seq_len, dtype=torch.bool, device=self.device)
|
input_mask = torch.ones(
|
||||||
|
batch_sz, prompt_len, dtype=torch.bool, device=self.device
|
||||||
|
)
|
||||||
|
|
||||||
for i, t in enumerate(tasks):
|
for i, t in enumerate(tasks):
|
||||||
input_ids[i] = torch.tensor(
|
input_ids[i] = torch.tensor(
|
||||||
|
|
@ -280,10 +294,21 @@ class InferenceScheduler:
|
||||||
return
|
return
|
||||||
|
|
||||||
tasks = sorted(tasks, key=lambda t: t.task_id)
|
tasks = sorted(tasks, key=lambda t: t.task_id)
|
||||||
batch_sz = len(tasks)
|
|
||||||
|
|
||||||
|
valid: List[Task] = []
|
||||||
for t in tasks:
|
for t in tasks:
|
||||||
self._maybe_alloc_page(t, start_pos)
|
if self._maybe_alloc_page(t, start_pos):
|
||||||
|
valid.append(t)
|
||||||
|
else:
|
||||||
|
t.status = TaskStatus.ABORTED
|
||||||
|
if t.stream_callback:
|
||||||
|
t.stream_callback(STOP)
|
||||||
|
|
||||||
|
if not valid:
|
||||||
|
return
|
||||||
|
|
||||||
|
tasks = valid
|
||||||
|
batch_sz = len(tasks)
|
||||||
|
|
||||||
input_ids = torch.tensor(
|
input_ids = torch.tensor(
|
||||||
[t.output_ids[-1] if t.output_ids else t.prompt_ids[-1] for t in tasks],
|
[t.output_ids[-1] if t.output_ids else t.prompt_ids[-1] for t in tasks],
|
||||||
|
|
@ -334,14 +359,15 @@ class InferenceScheduler:
|
||||||
rows = [t.page_table + [-1] * (max_pages - t.n_pages) for t in tasks]
|
rows = [t.page_table + [-1] * (max_pages - t.n_pages) for t in tasks]
|
||||||
return torch.tensor(rows, dtype=torch.long, device=self.device)
|
return torch.tensor(rows, dtype=torch.long, device=self.device)
|
||||||
|
|
||||||
def _maybe_alloc_page(self, task: Task, pos: int) -> None:
|
def _maybe_alloc_page(self, task: Task, pos: int) -> bool:
|
||||||
needed = self._n_pages_for(pos + 1)
|
needed = self._n_pages_for(pos + 1)
|
||||||
while task.n_pages < needed:
|
while task.n_pages < needed:
|
||||||
p = self.page_cache.alloc()
|
p = self.page_cache.alloc()
|
||||||
if p < 0:
|
if p < 0:
|
||||||
break
|
return False
|
||||||
task.page_table.append(p)
|
task.page_table.append(p)
|
||||||
task.n_pages += 1
|
task.n_pages += 1
|
||||||
|
return True
|
||||||
|
|
||||||
def _run_generation_loop(self) -> None:
|
def _run_generation_loop(self) -> None:
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -29,7 +29,9 @@ def process_attention_mask(
|
||||||
|
|
||||||
if seq_mask is None:
|
if seq_mask is None:
|
||||||
if start_pos != 0:
|
if start_pos != 0:
|
||||||
seq_mask = torch.ones((1, seq_len), dtype=torch.bool, device=device)
|
seq_mask = torch.ones(
|
||||||
|
(1, start_pos + seq_len), dtype=torch.bool, device=device
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue