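"""Gradio demo for comparing DPO-trained Pythia-2.8B policy checkpoints.

Loads the EleutherAI/pythia-2.8b base model once, swaps in the weights of a
selected fine-tuned policy checkpoint, and serves a simple prompt/response UI.
"""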
from transformers import AutoModelForCausalLM, AutoTokenizer
import transformers
from utils import get_local_dir, pad_to_length

import gradio as gr
import torch

def load_checkpoint(checkpoint_path):
    """Load a model and tokenizer from a full HF checkpoint directory (currently unused in this script)."""
    model = AutoModelForCausalLM.from_pretrained(checkpoint_path)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
    return model, tokenizer



# Mapping from checkpoint names to saved DPO policy state dicts on disk.
checkpoint_paths = {
    'full_policy': '/home/sbhatt54/direct-preference-optimization/.cache/sbhatt54/anthropic_dpo_phythia28/LATEST/policy.pt',
    'reference': '/home/sbhatt54/direct-preference-optimization/.cache/sbhatt54/anthropic_dpo_pythia28_2023-08-06_12-12-25_294354/LATEST/policy.pt',
    'all_but_two_last': '/home/sbhatt54/direct-preference-optimization/.cache/sbhatt54/all_but_two_last/LATEST/policy.pt',
    'all_but_three_last': '/home/sbhatt54/direct-preference-optimization/.cache/sbhatt54/all_but_three_last_2023-08-19_06-44-44_597545/LATEST/policy.pt',
    'all_but_last_basic': '/home/sbhatt54/direct-preference-optimization/.cache/sbhatt54/all_but_last_basic_2023-08-19_06-44-55_606332/LATEST/policy.pt',
    'all_but_last': '/home/sbhatt54/direct-preference-optimization/.cache/sbhatt54/all_but_last_2023-08-19_06-45-07_722235/LATEST/policy.pt',
}

# Checkpoint names offered in the UI dropdown.
options = ['reference', 'full_policy', 'all_but_two_last', 'all_but_three_last', 'all_but_last_basic', 'all_but_last']

policy_dtype = torch.float32
tokenizer = transformers.AutoTokenizer.from_pretrained('EleutherAI/pythia-2.8b', cache_dir=get_local_dir('.cache'))

# Shared base model; the selected checkpoint's weights are loaded into it on demand.
model = transformers.AutoModelForCausalLM.from_pretrained(
    'EleutherAI/pythia-2.8b', cache_dir=get_local_dir('.cache'), low_cpu_mem_usage=True, torch_dtype=policy_dtype)


# The Pythia tokenizer has no pad token by default; fall back to EOS for padding.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

def load_selected_checkpoint(checkpoint_name):
    """Load the selected policy's weights into the shared base model in place."""
    selected_path = checkpoint_paths[checkpoint_name]

    policy_state_dict = torch.load(selected_path, map_location='cpu')
    step, metrics = policy_state_dict['step_idx'], policy_state_dict['metrics']  # training metadata, unused here
    model.load_state_dict(policy_state_dict['state'])
    return model

def generate_response(prompt, checkpoint_name):
    """Generate a sampled completion for the prompt using the selected checkpoint."""
    model = load_selected_checkpoint(checkpoint_name)
    prompt = '\n\nHuman: ' + prompt + '\n\nAssistant:'
    inputs = tokenizer(prompt, return_tensors='pt', add_special_tokens=False)

    with torch.no_grad():
        policy_output = model.generate(
            inputs['input_ids'], attention_mask=inputs['attention_mask'],
            max_length=512, do_sample=True, pad_token_id=tokenizer.pad_token_id)

    policy_output = pad_to_length(policy_output, 512, tokenizer.pad_token_id)
    policy_output_decoded = tokenizer.batch_decode(policy_output, skip_special_tokens=True)

    # batch_decode returns a list; the text output component expects a single string.
    return policy_output_decoded[0]

iface = gr.Interface(
    fn=generate_response,
    inputs=[gr.Textbox(label="Prompt"), gr.Dropdown(choices=options, label="Select Checkpoint")],
    outputs="text",
)
iface.launch(share=True)
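
# Note: the checkpoint paths above are machine-specific, and get_local_dir /
# pad_to_length come from the DPO repository's utils module, so this script
# assumes it is run from within that repo.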