From ad9f4d9cf60f35cf742509b8096c7b541252c5be Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Sun, 17 May 2026 10:23:42 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20generate=5Far=20=E6=94=B9=E7=94=A8?= =?UTF-8?q?=E6=B5=81=E5=BC=8F=E8=BE=93=E5=87=BA=E5=B9=B6=E5=8E=BB=E9=99=A4?= =?UTF-8?q?=E5=86=97=E4=BD=99=E6=B3=A8=E9=87=8A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- scripts/demo/generate_ar.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/scripts/demo/generate_ar.py b/scripts/demo/generate_ar.py index 80ec033..ccbe49a 100644 --- a/scripts/demo/generate_ar.py +++ b/scripts/demo/generate_ar.py @@ -11,7 +11,6 @@ PARAMETER_ROOT = Path(PROJECT_ROOT, "params") def generate_text(): - # Load model from pretrained model = AutoModel.from_pretrained(PARAMETER_ROOT) tokenizer = AutoTokenizer.from_pretrained(PARAMETER_ROOT) model.to(device="cuda", dtype=torch.bfloat16) @@ -22,16 +21,15 @@ def generate_text(): model=model, tokenizer=tokenizer, ) - response = engine.generate( + for token in engine.generate( prompt=query, - stream=False, + stream=True, max_tokens=2048, temperature=0.8, top_p=0.95, top_k=50, - ) - - print(response) + ): + print(token, end="", flush=True) if __name__ == "__main__":