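# Usage sketch (assumptions flagged): fire.Fire(main) below exposes every
# parameter of main() as a command-line flag. Llama.build expects a
# torch.distributed environment, so a torchrun launch is assumed here; the
# script name and paths are placeholders, not values taken from this file.
#
#   torchrun --nproc_per_node 1 <this_script.py> \
#       --ckpt_dir <path/to/checkpoints> \
#       --tokenizer_path <path/to/tokenizer.model> \
#       --max_seq_len 128 --max_batch_size 4
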
import fire

from llama import Llama


def main(
    ckpt_dir: str,
    tokenizer_path: str,
    temperature: float = 0.6,
    top_p: float = 0.9,
    max_seq_len: int = 128,
    max_gen_len: int = 64,
    max_batch_size: int = 4,
):
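    # Build the generator once: load the model weights from ckpt_dir and the
    # tokenizer from tokenizer_path; max_seq_len and max_batch_size bound the
    # sequence length and batch size the model will accept.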
    generator = Llama.build(
        ckpt_dir=ckpt_dir,
        tokenizer_path=tokenizer_path,
        max_seq_len=max_seq_len,
        max_batch_size=max_batch_size,
    )

    prompts = [
        # For these prompts, the expected completion is a natural continuation
        # of the text.
        "I believe the meaning of life is",
        "Simply put, the theory of relativity states that ",
        """A brief message congratulating the team on the launch:

        Hi everyone,

        I just """,
        # Few-shot prompt: a handful of translation examples precede the item
        # the model is asked to complete.
        """Translate English to French:

        sea otter => loutre de mer
        peppermint => menthe poivrée
        plush girafe => girafe peluche
        cheese =>""",
    ]
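    # Generate a completion for every prompt in one batched call; sampling is
    # controlled by temperature and nucleus (top-p) filtering.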
    results = generator.text_completion(
        prompts,
        max_gen_len=max_gen_len,
        temperature=temperature,
        top_p=top_p,
    )
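    # Print each prompt followed by its generated continuation, separated by a
    # divider line.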
    for prompt, result in zip(prompts, results):
        print(prompt)
        print(f"> {result['generation']}")
        print("\n==================================\n")


if __name__ == "__main__":
    fire.Fire(main)