refactor: generate_ar 改用流式输出并去除冗余注释

This commit is contained in:
ViperEkura 2026-05-17 10:23:42 +08:00
parent e1638a7ade
commit ad9f4d9cf6
1 changed files with 4 additions and 6 deletions

View File

@ -11,7 +11,6 @@ PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
def generate_text(): def generate_text():
# Load model from pretrained
model = AutoModel.from_pretrained(PARAMETER_ROOT) model = AutoModel.from_pretrained(PARAMETER_ROOT)
tokenizer = AutoTokenizer.from_pretrained(PARAMETER_ROOT) tokenizer = AutoTokenizer.from_pretrained(PARAMETER_ROOT)
model.to(device="cuda", dtype=torch.bfloat16) model.to(device="cuda", dtype=torch.bfloat16)
@ -22,16 +21,15 @@ def generate_text():
model=model, model=model,
tokenizer=tokenizer, tokenizer=tokenizer,
) )
response = engine.generate( for token in engine.generate(
prompt=query, prompt=query,
stream=False, stream=True,
max_tokens=2048, max_tokens=2048,
temperature=0.8, temperature=0.8,
top_p=0.95, top_p=0.95,
top_k=50, top_k=50,
) ):
print(token, end="", flush=True)
print(response)
if __name__ == "__main__": if __name__ == "__main__":