import gradio as gr
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer

from utils import get_local_dir, pad_to_length


def load_checkpoint(checkpoint_path):
    """Load a full Hugging Face checkpoint (model + tokenizer) from a directory."""
    model = AutoModelForCausalLM.from_pretrained(checkpoint_path)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
    return model, tokenizer


# Paths to the DPO policy checkpoints selectable in the UI.
checkpoint_paths = {
    'full_policy': '/home/sbhatt54/direct-preference-optimization/.cache/sbhatt54/anthropic_dpo_phythia28/LATEST/policy.pt',
    'reference': '/home/sbhatt54/direct-preference-optimization/.cache/sbhatt54/anthropic_dpo_pythia28_2023-08-06_12-12-25_294354/LATEST/policy.pt',
    'all_but_two_last': '/home/sbhatt54/direct-preference-optimization/.cache/sbhatt54/all_but_two_last/LATEST/policy.pt',
    'all_but_three_last': '/home/sbhatt54/direct-preference-optimization/.cache/sbhatt54/all_but_three_last_2023-08-19_06-44-44_597545/LATEST/policy.pt',
    'all_but_last_basic': '/home/sbhatt54/direct-preference-optimization/.cache/sbhatt54/all_but_last_basic_2023-08-19_06-44-55_606332/LATEST/policy.pt',
    'all_but_last': '/home/sbhatt54/direct-preference-optimization/.cache/sbhatt54/all_but_last_2023-08-19_06-45-07_722235/LATEST/policy.pt',
}
options = ['reference', 'full_policy', 'all_but_two_last', 'all_but_three_last', 'all_but_last_basic', 'all_but_last']

# Base Pythia-2.8B model and tokenizer; the selected checkpoint's state dict is loaded on top of this model.
policy_dtype = torch.float32
tokenizer = transformers.AutoTokenizer.from_pretrained('EleutherAI/pythia-2.8b', cache_dir=get_local_dir('.cache'))
model = transformers.AutoModelForCausalLM.from_pretrained(
    'EleutherAI/pythia-2.8b',
    cache_dir=get_local_dir('.cache'),
    low_cpu_mem_usage=True,
    torch_dtype=policy_dtype)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id


def load_selected_checkpoint(checkpoint_name):
    """Load the weights of the selected policy checkpoint into the base model."""
    selected_path = checkpoint_paths[checkpoint_name]
    policy_state_dict = torch.load(selected_path, map_location='cpu')
    step, metrics = policy_state_dict['step_idx'], policy_state_dict['metrics']
    model.load_state_dict(policy_state_dict['state'])
    return model


def generate_response(prompt, checkpoint_name):
    """Generate a sampled completion for the prompt using the selected checkpoint."""
    model = load_selected_checkpoint(checkpoint_name)
    # Wrap the prompt in the Anthropic HH dialogue format used during training.
    prompt = '\n\nHuman: ' + prompt + '\n\nAssistant:'
    inputs = tokenizer(prompt, add_special_tokens=False)
    # Convert token-id lists to batched LongTensors of shape (1, seq_len).
    inputs = {k: torch.LongTensor(v).unsqueeze(0) for k, v in inputs.items()}
    policy_output = model.generate(
        inputs['input_ids'],
        attention_mask=inputs['attention_mask'],
        max_length=512,
        do_sample=True,
        pad_token_id=tokenizer.pad_token_id)
    policy_output = pad_to_length(policy_output, 512, tokenizer.pad_token_id)
    policy_output_decoded = tokenizer.batch_decode(policy_output, skip_special_tokens=True)
    return policy_output_decoded[0]


iface = gr.Interface(
    fn=generate_response,
    inputs=[gr.Textbox(label="Prompt"), gr.Dropdown(choices=options, label="Select Checkpoint")],
    outputs="text",
)
iface.launch(share=True)