# hf_gradio.py
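"""Gradio demo for comparing DPO (direct preference optimization) policy
checkpoints fine-tuned from EleutherAI/pythia-2.8b; the checkpoint names
suggest training on Anthropic HH-style preference data. A dropdown selects a
checkpoint, whose weights are loaded into a shared base model before sampling
a completion for the user's prompt."""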
from transformers import AutoModelForCausalLM, AutoTokenizer
import transformers
from utils import get_local_dir, pad_to_length
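# NOTE: get_local_dir and pad_to_length come from the utils module of the
# direct-preference-optimization repo. If that module is unavailable, minimal
# stand-ins with the assumed signatures (a sketch, not the repo's exact
# implementation) would look like:
#
#   import os
#
#   def get_local_dir(path):
#       # Ensure a local cache directory exists and return its path.
#       os.makedirs(path, exist_ok=True)
#       return path
#
#   def pad_to_length(tensor, length, pad_value, dim=-1):
#       # Right-pad `tensor` along `dim` with `pad_value` up to `length`.
#       if tensor.size(dim) >= length:
#           return tensor
#       pad_size = list(tensor.shape)
#       pad_size[dim] = length - tensor.size(dim)
#       pad = pad_value * torch.ones(*pad_size, dtype=tensor.dtype, device=tensor.device)
#       return torch.cat([tensor, pad], dim=dim)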
import gradio as gr
import torch
def load_checkpoint(checkpoint_path):
    # Load a full Hugging Face model directory. Currently unused: the demo
    # instead loads raw policy state dicts via load_selected_checkpoint below.
    model = AutoModelForCausalLM.from_pretrained(checkpoint_path)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
    return model, tokenizer
# Paths to saved DPO policy checkpoints. These are machine-specific and must
# be adapted to your own environment.
checkpoint_paths = {
    'full_policy': '/home/sbhatt54/direct-preference-optimization/.cache/sbhatt54/anthropic_dpo_phythia28/LATEST/policy.pt',
    'reference': '/home/sbhatt54/direct-preference-optimization/.cache/sbhatt54/anthropic_dpo_pythia28_2023-08-06_12-12-25_294354/LATEST/policy.pt',
    'all_but_two_last': '/home/sbhatt54/direct-preference-optimization/.cache/sbhatt54/all_but_two_last/LATEST/policy.pt',
    'all_but_three_last': '/home/sbhatt54/direct-preference-optimization/.cache/sbhatt54/all_but_three_last_2023-08-19_06-44-44_597545/LATEST/policy.pt',
    'all_but_last_basic': '/home/sbhatt54/direct-preference-optimization/.cache/sbhatt54/all_but_last_basic_2023-08-19_06-44-55_606332/LATEST/policy.pt',
    'all_but_last': '/home/sbhatt54/direct-preference-optimization/.cache/sbhatt54/all_but_last_2023-08-19_06-45-07_722235/LATEST/policy.pt',
}
options = ['reference', 'full_policy', 'all_but_two_last', 'all_but_three_last', 'all_but_last_basic', 'all_but_last']
policy_dtype = torch.float32
# Load the base tokenizer and model once; the selected checkpoint's weights
# are swapped into this shared model on demand.
tokenizer = transformers.AutoTokenizer.from_pretrained(
    'EleutherAI/pythia-2.8b', cache_dir=get_local_dir('.cache'))
model = transformers.AutoModelForCausalLM.from_pretrained(
    'EleutherAI/pythia-2.8b', cache_dir=get_local_dir('.cache'),
    low_cpu_mem_usage=True, torch_dtype=policy_dtype)
# Pythia's tokenizer has no pad token by default; fall back to EOS.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id
def load_selected_checkpoint(selection):
    # Load the raw training checkpoint for the selected option and copy its
    # weights into the shared base model.
    selected_path = checkpoint_paths[selection]
    policy_state_dict = torch.load(selected_path, map_location='cpu')
    # step_idx and metrics are stored alongside the weights but unused here.
    step, metrics = policy_state_dict['step_idx'], policy_state_dict['metrics']
    model.load_state_dict(policy_state_dict['state'])
    return model
def generate_response(prompt, selection):
    model = load_selected_checkpoint(selection)
    # Wrap the raw prompt in the Anthropic HH "Human:/Assistant:" dialogue
    # format that the checkpoints were trained on.
    prompt = '\n\nHuman: ' + prompt + '\n\nAssistant:'
    inputs = tokenizer(prompt, add_special_tokens=False)
    # Convert the tokenizer's Python lists into batched (1, seq_len) tensors.
    for key, value in inputs.items():
        inputs[key] = torch.LongTensor(value).unsqueeze(0)
    policy_output = model.generate(
        inputs['input_ids'], attention_mask=inputs['attention_mask'],
        max_length=512, do_sample=True, pad_token_id=tokenizer.pad_token_id)
    policy_output = pad_to_length(policy_output, 512, tokenizer.pad_token_id)
    policy_output_decoded = tokenizer.batch_decode(policy_output, skip_special_tokens=True)
    # batch_decode returns a list of strings; return the single completion so
    # Gradio's text output receives a plain string.
    return policy_output_decoded[0]
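# Example (assumed usage, not part of the original demo): generate a single
# completion from the 'reference' checkpoint without launching the UI:
#
#   print(generate_response("How do I brew green tea?", "reference"))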
iface = gr.Interface(
    fn=generate_response,
    inputs=[gr.Textbox(label="Prompt"),
            gr.Dropdown(choices=options, label="Select Checkpoint")],
    outputs="text",
)
iface.launch(share=True)