diff --git a/astrai/inference/scheduler.py b/astrai/inference/scheduler.py index d83e2ed..23a1fce 100644 --- a/astrai/inference/scheduler.py +++ b/astrai/inference/scheduler.py @@ -147,11 +147,13 @@ class InferenceScheduler: if len(prompt_ids) > self.max_prompt_len: prompt_ids = prompt_ids[-self.max_prompt_len :] - if len(prompt_ids) + max_tokens > self.max_seq_len: + if len(prompt_ids) >= self.max_seq_len: if stream_callback: stream_callback(STOP) return task_id + max_tokens = min(max_tokens, self.max_seq_len - len(prompt_ids)) + task = Task( task_id=task_id, prompt_ids=prompt_ids,