import torch
import gradio as gr
from transformers import AutoTokenizer
from palm_rlhf_pytorch import PaLM

device = torch.device("cpu")

# Build the 410M-parameter PaLM variant and load the pretrained weights once at
# startup, rather than rebuilding the model on every request.
model = PaLM(
    num_tokens=50304,
    dim=1024,
    depth=24,
    dim_head=128,
    heads=8,
    flash_attn=False,
    qk_rmsnorm=False,
).to(device).eval()

checkpoint = torch.load("./palm_410m_8k_v0.pt", map_location=device)
model.load_state_dict(checkpoint)

# Tokenizer for encoding prompts and decoding generated token ids; its padded
# vocabulary matches num_tokens=50304 above.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")


def generate(prompt, seq_len, temperature, filter_thres):
    encoded_text = tokenizer(prompt, return_tensors="pt")

    # PaLM.generate samples autoregressively up to seq_len tokens, applying
    # temperature scaling and logit filtering controlled by filter_thres.
    output_tensor = model.generate(
        seq_len=int(seq_len),
        prompt=encoded_text["input_ids"].to(device),
        temperature=temperature,
        filter_thres=filter_thres,
        pad_value=0.0,
        eos_token=tokenizer.eos_token_id,
        return_seq_without_prompt=False,
        use_tqdm=True,
    )

    # batch_decode returns a list of strings; there is a single prompt, so
    # return the first (and only) completion.
    decoded_output = tokenizer.batch_decode(output_tensor, skip_special_tokens=True)
    return decoded_output[0]


# The sliders belong in `inputs` alongside the prompt textbox, with `outputs`
# as a plain text component. Gradio 3+ uses `value=` for a component's
# initial setting.
iface = gr.Interface(
    fn=generate,
    title="PaLM",
    description="Open-source PaLM demo.",
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Slider(minimum=1, maximum=512, step=1, value=128, label="Sequence Length"),
        gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.9, label="Temperature"),
        gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.9, label="Filter Threshold"),
    ],
    outputs="text",
)

iface.launch()
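
# A minimal sketch of exercising the pipeline without the web UI: call
# generate() directly and print the completion. The prompt string and sampling
# values below are illustrative placeholders, not settings prescribed by the
# checkpoint or by palm-rlhf-pytorch. Uncomment to run (and comment out
# iface.launch() above, since it blocks).
#
# if __name__ == "__main__":
#     sample = generate(
#         "The quick brown fox",  # hypothetical test prompt
#         seq_len=128,
#         temperature=0.9,
#         filter_thres=0.9,
#     )
#     print(sample)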