refactor: generate_ar 改用流式输出并去除冗余注释

This commit is contained in:
ViperEkura 2026-05-17 10:23:42 +08:00
parent e1638a7ade
commit ad9f4d9cf6
1 changed file with 4 additions and 6 deletions

View File

@@ -11,7 +11,6 @@ PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
def generate_text():
# Load model from pretrained
model = AutoModel.from_pretrained(PARAMETER_ROOT)
tokenizer = AutoTokenizer.from_pretrained(PARAMETER_ROOT)
model.to(device="cuda", dtype=torch.bfloat16)
@@ -22,16 +21,15 @@ def generate_text():
model=model,
tokenizer=tokenizer,
)
response = engine.generate(
for token in engine.generate(
prompt=query,
stream=False,
stream=True,
max_tokens=2048,
temperature=0.8,
top_p=0.95,
top_k=50,
)
print(response)
):
print(token, end="", flush=True)
if __name__ == "__main__":