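"""Distributed HumanEval+ completion generation with a Mistral code model.

Spawns one process per GPU via torch.multiprocessing, wraps the model in
DistributedDataParallel, and writes per-rank samples to JSONL for evalplus.
"""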
import os

import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from transformers import AutoModelForCausalLM, AutoTokenizer
from evalplus.data import get_human_eval_plus, write_jsonl
from tqdm import tqdm

def setup(rank, world_size):
    """Initialize the default process group for this rank."""
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    # gloo suffices here: generation below calls .module directly, so no
    # gradient collectives run; prefer "nccl" for actual multi-GPU training.
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

def generate_one_completion(ddp_model, tokenizer, prompt: str):
    """Sample a single completion for one HumanEval+ prompt."""
    tokenizer.pad_token = tokenizer.eos_token
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=4096)
    device = next(ddp_model.module.parameters()).device

    # Generate via the wrapped module; .generate() is not exposed on the DDP wrapper.
    generate_ids = ddp_model.module.generate(
        inputs.input_ids.to(device),
        max_new_tokens=384, do_sample=True, top_p=0.75, top_k=40,
        temperature=0.1, pad_token_id=tokenizer.eos_token_id,
    )
    completion = tokenizer.batch_decode(
        generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    # Strip the echoed prompt and cut at the first blank-line gap, a simple
    # heuristic stop sequence for function-level completions.
    completion = completion.replace(prompt, "").split("\n\n\n")[0]

    print("-------------------")
    print(completion)
    return completion

def run(rank, world_size):
    setup(rank, world_size)

    model_path = "Nondzu/Mistral-7B-codealpaca-lora"
    # AutoModelForCausalLM resolves the Mistral architecture from the checkpoint
    # config (the original LlamaForCausalLM triggers a model-type mismatch warning);
    # load_in_8bit places the quantized weights on GPU via bitsandbytes.
    model = AutoModelForCausalLM.from_pretrained(model_path, load_in_8bit=True)
    ddp_model = DDP(model, device_ids=[rank])
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    problems = get_human_eval_plus()
    num_samples_per_task = 1

    samples = [
        dict(
            task_id=task_id,
            completion=generate_one_completion(ddp_model, tokenizer, problems[task_id]["prompt"]),
        )
        for task_id in tqdm(problems)
        for _ in range(num_samples_per_task)
    ]
    # Each rank writes its own shard; merge the per-rank files before scoring.
    write_jsonl(f"samples-Nondzu-Mistral-7B-codealpaca-lora-rank{rank}.jsonl", samples)

    cleanup()
    
def main():
    # One process per GPU; raise world_size to the number of available GPUs.
    world_size = 1
    mp.spawn(run, args=(world_size,), nprocs=world_size, join=True)

if __name__ == "__main__":
    main()
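
# Assumed workflow (script name is hypothetical; CLI per the evalplus README):
#   python generate_humaneval.py
#   evalplus.evaluate --dataset humaneval \
#       --samples samples-Nondzu-Mistral-7B-codealpaca-lora-rank0.jsonl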