import torch
from transformers import AutoTokenizer
from palm_rlhf_pytorch import PaLM
import gradio as gr

# Lazily-populated cache so the model and tokenizer are loaded exactly once,
# not on every generation request (loading is slow and hits the network).
_CACHE = {}


def _load_model_and_tokenizer(device):
    """Load (once) and return the PaLM model and GPT-NeoX tokenizer.

    Subsequent calls return the cached instances.
    """
    if "model" not in _CACHE:
        _CACHE["model"] = (
            torch.hub.load("conceptofmind/PaLM", "palm_1b_8k_v0", map_location=device)
            .to(device)
            .eval()
        )
        _CACHE["tokenizer"] = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
    return _CACHE["model"], _CACHE["tokenizer"]


def generate(prompt, seq_len=256, temperature=0.8, filter_thres=0.9, model=None):
    """Generate a text continuation of `prompt` with the PaLM 1B model.

    Args:
        prompt: Input text to continue.
        seq_len: Maximum number of tokens to generate.
        temperature: Sampling temperature (higher = more random).
        filter_thres: Top-k filtering threshold passed to `model.generate`.
        model: Optional pre-loaded model; when None (the default) the
            cached hub model is used. Kept for backward compatibility —
            the original signature accepted a `model` argument.

    Returns:
        The generated text as a single string.
    """
    device = torch.device("cpu")
    cached_model, tokenizer = _load_model_and_tokenizer(device)
    if model is None:
        model = cached_model

    encoded_text = tokenizer(prompt, return_tensors="pt")

    output_tensor = model.generate(
        seq_len=seq_len,
        prompt=encoded_text["input_ids"].to(device),
        temperature=temperature,
        filter_thres=filter_thres,
        pad_value=0,  # pad token id must be an int, not the float 0.0
        eos_token=tokenizer.eos_token_id,
        return_seq_without_prompt=False,
        use_tqdm=True,
    )

    decoded_output = tokenizer.batch_decode(output_tensor, skip_special_tokens=True)
    # batch_decode returns a list; the "text" output component expects a string.
    return decoded_output[0] if decoded_output else ""


# Expose the generation knobs in the UI; the original single "text" input
# could not satisfy generate()'s required parameters and crashed on submit.
iface = gr.Interface(
    fn=generate,
    title="PaLM",
    description="Open-source PaLM demo.",
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Slider(minimum=1, maximum=1024, value=256, step=1, label="Sequence length"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.8, label="Temperature"),
        gr.Slider(minimum=0.0, maximum=1.0, value=0.9, label="Filter threshold"),
    ],
    outputs="text",
)

iface.launch()