Compare commits

...

2 Commits

Author SHA1 Message Date
ViperEkura 523eacf5fe release: v1.3.4
- refactor: 分页 KV cache(PagedCache+CacheView)替换固定 slot,删除 PrefixCache
- refactor: 推理引擎控制逻辑重写,修复连续批处理核心缺陷、线程安全问题
- refactor: KV 缓存槽位下沉到注意力层,移除 _remap_kv / _writeback_kv
- refactor: 统一采样路径为 SamplingPipeline batch tensor,删除 apply_sampling_strategies
- refactor: 设计模式优化 inference 模块导入结构(cache/sampling 独立)
- feat: 推理引擎前缀缓存(KV cache 复用)
- feat: OpenAI 兼容 chat completion API(流式+非流式+usage)
- feat: Anthropic 兼容 /v1/messages API,移除旧版 /generate 端点
- feat: GRPO CLI 接入 + on-policy,OpenAI API top_k 参数化
- feat: Checkpoint 支持 extra 通用扩展数据
- feat: Docker Compose 一键部署(GPU/CPU 双模式)
- feat: GRPO 训练参数补充,批处理训练参数表
- fix: 调度器延迟优化 — 移除 5ms 睡眠,修复 refill 任务丢失
- fix: CLI 参数缺失/重复、device_ids 越界、generate 参数名不一致
- fix: 长对话截断方向错误,保留最新 token 而非最早
- fix: remove_task 未释放 KV cache slot 导致第二轮对话死锁
- fix: KV cache 槽位索引错位、版本校验缺失、注意力掩码
- fix: scheduler 越界 bug,SchedulerCallback 回调阶段修正
- perf: _Result 改用 Condition.wait_for 消除非流式 CPU 空转
- perf: decode 每步张量预分配;input_ids 改用一次构建代替逐元素赋值
- refactor: 移除 device_ids 参数,统一 CUDA_VISIBLE_DEVICES
- docs: 更新文档以匹配分页 KV cache 等代码重构
- docs: 修正多处文档错误、补充训练参数说明
2026-05-10 15:59:18 +08:00
ViperEkura cffedaad5e perf: 消除非流式推理 CPU 空转并减少 decode GPU 张量冗余分配
- engine.py: _Result 改用 threading.Condition.wait_for 替代
  Event busy-wait,非流式模式线程被内核挂起而非 1760 万次空转
- scheduler.py: _execute_decode 将 temperature/top_k/top_p 张量
  移至循环外预先分配,避免每步重复 torch.tensor();input_ids
  改用 torch.empty 避免不必要的 zero 初始化(两处均为完全覆盖)
- _execute_prefill: input_ids 同改为 torch.empty
2026-05-10 15:32:11 +08:00
3 changed files with 32 additions and 18 deletions

View File

@ -1,4 +1,4 @@
__version__ = "1.3.3" __version__ = "1.3.4"
__author__ = "ViperEkura" __author__ = "ViperEkura"
from astrai.config import ( from astrai.config import (

View File

@ -97,7 +97,8 @@ class _Result:
"""Thread-safe token accumulator for streaming and non-streaming modes. """Thread-safe token accumulator for streaming and non-streaming modes.
Supports multiple concurrent generation tasks with per-index result tracking. Supports multiple concurrent generation tasks with per-index result tracking.
Uses a threading.Event for efficient waiting on completion. Uses a threading.Condition for efficient completion notification
and a threading.Event for streaming wakeup.
""" """
def __init__(self, count: int = 1): def __init__(self, count: int = 1):
@ -106,7 +107,7 @@ class _Result:
Args: Args:
count: Number of concurrent generation tasks to track. count: Number of concurrent generation tasks to track.
""" """
self._lock = threading.Lock() self._cond = threading.Condition()
self._event = threading.Event() self._event = threading.Event()
self.tokens: List[str] = [] self.tokens: List[str] = []
self.results: List[str] = [""] * count self.results: List[str] = [""] * count
@ -124,7 +125,7 @@ class _Result:
token: The decoded token string, or STOP sentinel. token: The decoded token string, or STOP sentinel.
idx: Index of the generation task this token belongs to. idx: Index of the generation task this token belongs to.
""" """
with self._lock: with self._cond:
self.tokens.append(token) self.tokens.append(token)
if token is not STOP: if token is not STOP:
self.results[idx] += token self.results[idx] += token
@ -132,7 +133,8 @@ class _Result:
if not self._done[idx]: if not self._done[idx]:
self._done[idx] = True self._done[idx] = True
self._completed += 1 self._completed += 1
self._event.set() self._cond.notify_all()
self._event.set()
def pop_all(self) -> List[str]: def pop_all(self) -> List[str]:
"""Returns and clears all accumulated tokens. """Returns and clears all accumulated tokens.
@ -140,7 +142,7 @@ class _Result:
Returns: Returns:
List of token strings since the last call. List of token strings since the last call.
""" """
with self._lock: with self._cond:
out = self.tokens.copy() out = self.tokens.copy()
self.tokens.clear() self.tokens.clear()
if not out: if not out:
@ -158,13 +160,22 @@ class _Result:
""" """
return self._event.wait(timeout=timeout) return self._event.wait(timeout=timeout)
def wait_completion(self) -> None:
"""Blocks until all tasks complete (non-streaming).
Uses a Condition to sleep efficiently instead of busy-waiting.
The calling thread is parked until every tracked task has signalled STOP.
"""
with self._cond:
self._cond.wait_for(lambda: self._completed >= self._total)
def get_results(self) -> List[str]: def get_results(self) -> List[str]:
"""Returns all accumulated results for non-streaming mode. """Returns all accumulated results for non-streaming mode.
Returns: Returns:
List of complete generated strings, one per task index. List of complete generated strings, one per task index.
""" """
with self._lock: with self._cond:
return self.results.copy() return self.results.copy()
@ -425,8 +436,7 @@ class InferenceEngine:
) )
task_ids.append(task_id) task_ids.append(task_id)
while result._completed < result._total: result.wait_completion()
result.wait(timeout=1.0)
for task_id in task_ids: for task_id in task_ids:
self.scheduler.remove_task(task_id) self.scheduler.remove_task(task_id)

View File

@ -253,7 +253,7 @@ class InferenceScheduler:
batch_sz = len(tasks) batch_sz = len(tasks)
seq_len = prompt_len - start_pos seq_len = prompt_len - start_pos
input_ids = torch.zeros(batch_sz, seq_len, dtype=torch.long, device=self.device) input_ids = torch.empty(batch_sz, seq_len, dtype=torch.long, device=self.device)
input_mask = torch.ones(batch_sz, seq_len, dtype=torch.bool, device=self.device) input_mask = torch.ones(batch_sz, seq_len, dtype=torch.bool, device=self.device)
for i, t in enumerate(tasks): for i, t in enumerate(tasks):
@ -285,15 +285,21 @@ class InferenceScheduler:
for t in tasks: for t in tasks:
self._maybe_alloc_page(t, start_pos) self._maybe_alloc_page(t, start_pos)
input_ids = torch.zeros(batch_sz, dtype=torch.long, device=self.device) input_ids = torch.tensor(
for i, t in enumerate(tasks): [t.output_ids[-1] if t.output_ids else t.prompt_ids[-1] for t in tasks],
input_ids[i] = t.output_ids[-1] if t.output_ids else t.prompt_ids[-1] dtype=torch.long,
device=self.device,
)
active_mask = torch.ones((batch_sz, 1), dtype=torch.bool, device=self.device) active_mask = torch.ones((batch_sz, 1), dtype=torch.bool, device=self.device)
page_tables = self._make_page_table_tensor(tasks) page_tables = self._make_page_table_tensor(tasks)
total_len = start_pos + 1 total_len = start_pos + 1
temperatures = torch.tensor([t.temperature for t in tasks], device=self.device)
top_ks = torch.tensor([t.top_k for t in tasks], device=self.device)
top_ps = torch.tensor([t.top_p for t in tasks], device=self.device)
with torch.inference_mode(): with torch.inference_mode():
outputs = self.model( outputs = self.model(
input_ids.unsqueeze(1), input_ids.unsqueeze(1),
@ -305,11 +311,9 @@ class InferenceScheduler:
next_tokens = sample( next_tokens = sample(
logits, logits,
temperature=torch.tensor( temperature=temperatures,
[t.temperature for t in tasks], device=logits.device top_k=top_ks,
), top_p=top_ps,
top_k=torch.tensor([t.top_k for t in tasks], device=logits.device),
top_p=torch.tensor([t.top_p for t in tasks], device=logits.device),
).tolist() ).tolist()
for t, ntok in zip(tasks, next_tokens): for t, ntok in zip(tasks, next_tokens):