refactor: generate_ar 改用流式输出并去除冗余注释
This commit is contained in:
parent
e1638a7ade
commit
ad9f4d9cf6
|
|
@@ -11,7 +11,6 @@ PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
|
|||
|
||||
|
||||
def generate_text():
|
||||
- # Load model from pretrained
|
||||
model = AutoModel.from_pretrained(PARAMETER_ROOT)
|
||||
tokenizer = AutoTokenizer.from_pretrained(PARAMETER_ROOT)
|
||||
model.to(device="cuda", dtype=torch.bfloat16)
|
||||
|
|
@@ -22,16 +21,15 @@ def generate_text():
|
|||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
- response = engine.generate(
|
||||
+ for token in engine.generate(
|
||||
prompt=query,
|
||||
- stream=False,
|
||||
+ stream=True,
|
||||
max_tokens=2048,
|
||||
temperature=0.8,
|
||||
top_p=0.95,
|
||||
top_k=50,
|
||||
- )
|
||||
|
||||
- print(response)
|
||||
+ ):
|
||||
+ print(token, end="", flush=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
Loading…
Reference in New Issue