diff --git a/scripts/demo/generate_ar.py b/scripts/demo/generate_ar.py index 80ec033..ccbe49a 100644 --- a/scripts/demo/generate_ar.py +++ b/scripts/demo/generate_ar.py @@ -11,7 +11,6 @@ PARAMETER_ROOT = Path(PROJECT_ROOT, "params") def generate_text(): - # Load model from pretrained model = AutoModel.from_pretrained(PARAMETER_ROOT) tokenizer = AutoTokenizer.from_pretrained(PARAMETER_ROOT) model.to(device="cuda", dtype=torch.bfloat16) @@ -22,16 +21,15 @@ def generate_text(): model=model, tokenizer=tokenizer, ) - response = engine.generate( + for token in engine.generate( prompt=query, - stream=False, + stream=True, max_tokens=2048, temperature=0.8, top_p=0.95, top_k=50, - ) - - print(response) + ): + print(token, end="", flush=True) if __name__ == "__main__":