# Gradio demo: compare responses from several DPO policy checkpoints
# (Pythia-2.8B base model). Residual "Spaces / No application file" text
# from the Hugging Face Spaces viewer removed so the file parses.
from transformers import AutoModelForCausalLM, AutoTokenizer | |
import transformers | |
from utils import get_local_dir, pad_to_length | |
import gradio as gr | |
import torch | |
def load_checkpoint(checkpoint_path):
    """Load a causal language model and its tokenizer from *checkpoint_path*.

    Args:
        checkpoint_path: local directory or hub identifier understood by the
            Hugging Face auto classes.

    Returns:
        A ``(model, tokenizer)`` pair.
    """
    lm = AutoModelForCausalLM.from_pretrained(checkpoint_path)
    tok = AutoTokenizer.from_pretrained(checkpoint_path)
    return lm, tok
# NOTE(review): duplicate `import gradio as gr` / `import torch` removed —
# both are already imported at the top of the file.

# Saved policy checkpoints from the DPO training runs. Keys are the names
# shown in the UI dropdown; values are paths to serialized state dicts
# (each containing 'state', 'step_idx' and 'metrics' entries).
checkpoint_paths = {
    'full_policy': '/home/sbhatt54/direct-preference-optimization/.cache/sbhatt54/anthropic_dpo_phythia28/LATEST/policy.pt',
    'reference': '/home/sbhatt54/direct-preference-optimization/.cache/sbhatt54/anthropic_dpo_pythia28_2023-08-06_12-12-25_294354/LATEST/policy.pt',
    'all_but_two_last': '/home/sbhatt54/direct-preference-optimization/.cache/sbhatt54/all_but_two_last/LATEST/policy.pt',
    'all_but_three_last': '/home/sbhatt54/direct-preference-optimization/.cache/sbhatt54/all_but_three_last_2023-08-19_06-44-44_597545/LATEST/policy.pt',
    'all_but_last_basic': '/home/sbhatt54/direct-preference-optimization/.cache/sbhatt54/all_but_last_basic_2023-08-19_06-44-55_606332/LATEST/policy.pt',
    'all_but_last': '/home/sbhatt54/direct-preference-optimization/.cache/sbhatt54/all_but_last_2023-08-19_06-45-07_722235/LATEST/policy.pt',
}
# Ordering of the dropdown choices in the Gradio UI.
options = ['reference', 'full_policy', 'all_but_two_last', 'all_but_three_last', 'all_but_last_basic', 'all_but_last']

# Inference dtype for the base model (direct attribute access instead of
# the original getattr(torch, 'float32') indirection — same object).
policy_dtype = torch.float32

# Base model/tokenizer; the selected checkpoint's weights are loaded into
# `model` on demand by load_selected_checkpoint().
tokenizer = transformers.AutoTokenizer.from_pretrained(
    'EleutherAI/pythia-2.8b', cache_dir=get_local_dir('.cache'))
model = transformers.AutoModelForCausalLM.from_pretrained(
    'EleutherAI/pythia-2.8b', cache_dir=get_local_dir('.cache'),
    low_cpu_mem_usage=True, torch_dtype=policy_dtype)

# Pythia's tokenizer ships without a pad token; reuse EOS so generation
# and padding below have a valid pad_token_id.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id
def load_selected_checkpoint(options):
    """Load the checkpoint named by *options* into the shared global model.

    Args:
        options: key into the module-level ``checkpoint_paths`` dict,
            selecting which saved policy to load.

    Returns:
        The module-level ``model`` with the selected weights loaded
        (mutated in place).

    Raises:
        KeyError: if *options* is not a known checkpoint name.
    """
    selected_path = checkpoint_paths[options]
    # Load on CPU; the file also carries training metadata ('step_idx',
    # 'metrics') which is not needed for inference, so the original dead
    # unpacking of those fields has been removed.
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted paths.
    policy_state_dict = torch.load(selected_path, map_location='cpu')
    model.load_state_dict(policy_state_dict['state'])
    return model
def generate_response(prompt, options):
    """Generate a sampled model response to *prompt* using the checkpoint
    selected in the UI.

    Args:
        prompt: raw user text (without dialogue markers).
        options: checkpoint name, passed through to load_selected_checkpoint.

    Returns:
        List of decoded output strings (batch of one), including the prompt
        prefix since skip_special_tokens only strips special tokens.
    """
    model = load_selected_checkpoint(options)
    # Wrap the raw prompt in the Human/Assistant dialogue format used by
    # the Anthropic-HH training data.
    prompt = '\n\nHuman: ' + prompt + '\n\nAssistant:'
    # Tokenize straight to batched PyTorch tensors (batch size 1) instead
    # of the original manual LongTensor/unsqueeze loop, which also shadowed
    # the builtin `input`.
    inputs = tokenizer(prompt, add_special_tokens=False, return_tensors='pt')
    policy_output = model.generate(
        inputs['input_ids'],
        attention_mask=inputs['attention_mask'],
        max_length=512,
        do_sample=True,
        pad_token_id=tokenizer.pad_token_id,
    )
    # Pad to a fixed width before decoding, matching the training eval path.
    policy_output = pad_to_length(policy_output, 512, tokenizer.pad_token_id)
    return tokenizer.batch_decode(policy_output, skip_special_tokens=True)
# Build and launch the UI. Top-level gr.Textbox/gr.Dropdown replace the
# deprecated gr.inputs.* namespace (deprecated in Gradio 3.x, removed in
# 4.x); behavior is otherwise identical.
iface = gr.Interface(
    fn=generate_response,
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Dropdown(choices=options, label="Select Checkpoint"),
    ],
    outputs="text",
)
# share=True exposes a public tunnel URL in addition to the local server.
iface.launch(share=True)