refactor: generate_ar 改用流式输出并去除冗余注释
This commit is contained in:
parent
e1638a7ade
commit
ad9f4d9cf6
|
|
@ -11,7 +11,6 @@ PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
|
||||||
|
|
||||||
|
|
||||||
def generate_text():
|
def generate_text():
|
||||||
# Load model from pretrained
|
|
||||||
model = AutoModel.from_pretrained(PARAMETER_ROOT)
|
model = AutoModel.from_pretrained(PARAMETER_ROOT)
|
||||||
tokenizer = AutoTokenizer.from_pretrained(PARAMETER_ROOT)
|
tokenizer = AutoTokenizer.from_pretrained(PARAMETER_ROOT)
|
||||||
model.to(device="cuda", dtype=torch.bfloat16)
|
model.to(device="cuda", dtype=torch.bfloat16)
|
||||||
|
|
@ -22,16 +21,15 @@ def generate_text():
|
||||||
model=model,
|
model=model,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
)
|
)
|
||||||
response = engine.generate(
|
for token in engine.generate(
|
||||||
prompt=query,
|
prompt=query,
|
||||||
stream=False,
|
stream=True,
|
||||||
max_tokens=2048,
|
max_tokens=2048,
|
||||||
temperature=0.8,
|
temperature=0.8,
|
||||||
top_p=0.95,
|
top_p=0.95,
|
||||||
top_k=50,
|
top_k=50,
|
||||||
)
|
):
|
||||||
|
print(token, end="", flush=True)
|
||||||
print(response)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue