import torch
import gradio as gr
from transformers import AutoTokenizer
from palm_rlhf_pytorch import PaLM

# Build the model and tokenizer ONCE at import time. The original rebuilt a
# 24-layer PaLM and re-read the checkpoint from disk on every request, which
# is extremely slow and churns memory.
DEVICE = torch.device("cpu")

MODEL = PaLM(
    num_tokens=50304,
    dim=1024,
    depth=24,
    dim_head=128,
    heads=8,
    flash_attn=False,
    qk_rmsnorm=False,
).to(DEVICE)
# NOTE(review): hard-coded absolute checkpoint path — confirm it exists in
# the deployment environment.
MODEL.load('/palm_410m_8k_v0.pt')

TOKENIZER = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")


def generate(prompt, seq_len=256, temperature=0.8, filter_thres=0.9, model=None):
    """Generate a text continuation of ``prompt`` with the PaLM model.

    Args:
        prompt: Input text to condition generation on.
        seq_len: Number of tokens to sample (default 256).
        temperature: Sampling temperature (default 0.8).
        filter_thres: Top-k filtering threshold passed to ``PaLM.generate``
            (default 0.9).
        model: Optional PaLM instance to use instead of the module-level
            ``MODEL``. The original signature accepted this argument but
            silently ignored it; it is now honored.

    Returns:
        The decoded generation (prompt included, special tokens stripped)
        as a single string.
    """
    if model is None:
        model = MODEL

    encoded = TOKENIZER(prompt, return_tensors="pt")

    output_ids = model.generate(
        seq_len=seq_len,
        prompt=encoded["input_ids"].to(DEVICE),
        temperature=temperature,
        filter_thres=filter_thres,
        pad_value=0,  # pad token id is an int, not the float 0.0
        eos_token=TOKENIZER.eos_token_id,
        return_seq_without_prompt=False,
        use_tqdm=True,
    )

    decoded = TOKENIZER.batch_decode(output_ids, skip_special_tokens=True)
    return decoded[0]


# BUG FIX: the original used inputs="text", which passes a single argument
# to fn — but generate() required five positional arguments, so every
# request raised TypeError. Expose the sampling knobs as real components;
# the new defaults on generate() keep one-argument calls working too.
iface = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Slider(16, 2048, value=256, step=1, label="Sequence length"),
        gr.Slider(0.1, 2.0, value=0.8, label="Temperature"),
        gr.Slider(0.0, 1.0, value=0.9, label="Filter threshold"),
    ],
    outputs="text",
)

if __name__ == "__main__":
    iface.launch()