import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from transformers import AutoTokenizer, LlamaForCausalLM
from evalplus.data import get_human_eval_plus, write_jsonl
from tqdm import tqdm
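

# Sample completions for the HumanEval+ problem set with the
# Nondzu/Mistral-7B-codealpaca-lora model under DistributedDataParallel,
# writing one JSONL file of samples per rank for evalplus scoring.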
def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("gloo", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()
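

# Generate a single completion for one prompt using the DDP-wrapped model;
# the echoed prompt is stripped from the decoded output before returning.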
def generate_one_completion(ddp_model, tokenizer, prompt: str):
    tokenizer.pad_token = tokenizer.eos_token
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=4096)
    # Generate a completion for the prompt
    generate_ids = ddp_model.module.generate(
        inputs.input_ids.to("cuda"),
        max_new_tokens=384,
        do_sample=True,
        top_p=0.75,
        top_k=40,
        temperature=0.1,
        pad_token_id=tokenizer.eos_token_id,
    )
    completion = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    # Drop the echoed prompt and keep only the first generated block
    completion = completion.replace(prompt, "").split("\n\n\n")[0]
    print("-------------------")
    print(completion)
    return completion
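

# Per-process worker: load the model in 8-bit, wrap it in DDP, generate one
# completion per HumanEval+ task, and write the samples for this rank to JSONL.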
def run(rank, world_size):
    setup(rank, world_size)
    model_path = "Nondzu/Mistral-7B-codealpaca-lora"
    # Load the base model in 8-bit and wrap it for distributed inference
    model = LlamaForCausalLM.from_pretrained(model_path, load_in_8bit=True)
    ddp_model = DDP(model, device_ids=[rank])
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    problems = get_human_eval_plus()
    num_samples_per_task = 1
    samples = [
        dict(task_id=task_id, completion=generate_one_completion(ddp_model, tokenizer, problems[task_id]["prompt"]))
        for task_id in tqdm(problems)
        for _ in range(num_samples_per_task)
    ]
    write_jsonl(f"samples-Nondzu-Mistral-7B-codealpaca-lora-rank{rank}.jsonl", samples)
    cleanup()
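

# Spawn one worker process per rank (world_size is 1 here, i.e. a single GPU).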
def main():
    world_size = 1
    mp.spawn(run, args=(world_size,), nprocs=world_size, join=True)


if __name__ == "__main__":
    main()