# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
from typing import List

import json

import fire
from tqdm import tqdm

from llama import Llama
def read_json(file_path):
    with open(file_path, 'r', encoding='utf-8') as file:
        data = json.load(file)
    return data


def write_json(file_path, data):
    with open(file_path, 'w', encoding='utf-8') as file:
        json.dump(data, file, ensure_ascii=False, indent=4)


def main(
    ckpt_dir: str,
    tokenizer_path: str,
    temperature: float = 0.6,
    top_p: float = 0.9,
    max_seq_len: int = 128,
    max_gen_len: int = 64,
    max_batch_size: int = 4,
    json_path: str = None,
):
"""
Examples to run with the pre-trained models (no fine-tuning). Prompts are
usually in the form of an incomplete text prefix that the model can then try to complete.
The context window of llama3 models is 8192 tokens, so `max_seq_len` needs to be <= 8192.
`max_gen_len` is needed because pre-trained models usually do not stop completions naturally.
"""
    generator = Llama.build(
        ckpt_dir=ckpt_dir,
        tokenizer_path=tokenizer_path,
        max_seq_len=max_seq_len,
        max_batch_size=max_batch_size,
    )
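    # The file at `json_path` is assumed to be a flat JSON list of prompt strings
    # (e.g. ["I believe the meaning of life is", ...]); each slice of it is passed
    # directly to generator.text_completion below.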
    data = read_json(json_path)

    ans = []
    begin, end = 0, len(data)
    cnt = 0  # number of batches processed so far
    for batch_idx in tqdm(range(begin, end, max_batch_size)):
        up = min(batch_idx + max_batch_size, end)
        batch = data[batch_idx:up]
        print(f"batch {batch_idx} to {up}")

        # Collect the prompt strings for this batch.
        text_batch = []
        for prompt in batch:
            text_batch.append(prompt)

        res = generator.text_completion(
            text_batch,
            max_gen_len=max_gen_len,
            temperature=temperature,
            top_p=top_p,
        )
        # Flatten the per-batch results into one list of completions.
        ans.extend(res)

        cnt += 1
        if cnt % 10 == 0:
            print(f"batch {cnt} done")

    # write_json expects (file_path, data).
    write_json("ans.json", ans)
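    # Each entry of ans.json is a completion dict returned by text_completion;
    # the generated text is under its 'generation' key (as in the reference
    # example kept below).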
    # prompts: List[str] = [
    #     # For these prompts, the expected answer is the natural continuation of the prompt
    #     "I believe the meaning of life is",
    #     "Simply put, the theory of relativity states that ",
    #     """A brief message congratulating the team on the launch:
    #     Hi everyone,
    #     I just """,
    #     # Few shot prompt (providing a few examples before asking model to complete more);
    #     """Translate English to French:
    #     sea otter => loutre de mer
    #     peppermint => menthe poivrée
    #     plush girafe => girafe peluche
    #     cheese =>""",
    # ]
    # results = generator.text_completion(
    #     prompts,
    #     max_gen_len=max_gen_len,
    #     temperature=temperature,
    #     top_p=top_p,
    # )
    # for prompt, result in zip(prompts, results):
    #     print(prompt)
    #     print(f"> {result['generation']}")
    #     print("\n==================================\n")
if __name__ == "__main__":
fire.Fire(main)