diff --git a/scripts/demo/generate_batch.py b/scripts/demo/generate_batch.py index a074976..39c4975 100644 --- a/scripts/demo/generate_batch.py +++ b/scripts/demo/generate_batch.py @@ -24,12 +24,23 @@ def batch_generate(): "请问什么是显卡", ] + prompts = [ + tokenizer.apply_chat_template( + [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": q}, + ], + tokenize=False, + ) + for q in inputs + ] + engine = InferenceEngine( model=model, tokenizer=tokenizer, ) responses = engine.generate( - prompt=inputs, + prompt=prompts, stream=False, max_tokens=2048, temperature=0.8,