fix: 修复长对话截断方向错误,保留最新 token 而非最早
- add_task 中 prompt 超长时改为保留末尾 token(prompt_ids[-max_prompt_len:]) 而非开头 token,确保多轮对话时模型能看到最近的提问上下文
This commit is contained in:
parent
a6f5ff3b37
commit
c4401512f2
|
|
@ -148,7 +148,7 @@ class InferenceEngine:
|
||||||
tokenizer: AutoTokenizer,
|
tokenizer: AutoTokenizer,
|
||||||
max_batch_size: int = 1,
|
max_batch_size: int = 1,
|
||||||
max_seq_len: Optional[int] = None,
|
max_seq_len: Optional[int] = None,
|
||||||
max_prompt_len: int = 512,
|
max_prompt_len: int = 2048,
|
||||||
cache_capacity: int = 1000,
|
cache_capacity: int = 1000,
|
||||||
):
|
):
|
||||||
"""Initializes the engine and starts the scheduler background thread.
|
"""Initializes the engine and starts the scheduler background thread.
|
||||||
|
|
|
||||||
|
|
@ -480,7 +480,7 @@ class InferenceScheduler:
|
||||||
prompt_ids = self.tokenizer.encode(prompt)
|
prompt_ids = self.tokenizer.encode(prompt)
|
||||||
|
|
||||||
if len(prompt_ids) > self.max_prompt_len:
|
if len(prompt_ids) > self.max_prompt_len:
|
||||||
prompt_ids = prompt_ids[: self.max_prompt_len]
|
prompt_ids = prompt_ids[-self.max_prompt_len :]
|
||||||
|
|
||||||
task = Task(
|
task = Task(
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue