# PaLM_models / app.py
# Author: Enrico Shippole — "Add initial gradio setup" (commit b54a00d)
# NOTE(review): this header was HuggingFace file-viewer residue ("raw /
# history blame / 1.03 kB"); commented out so the file is valid Python.
import torch
from transformers import AutoTokenizer
from palm_rlhf_pytorch import PaLM
import gradio as gr
def generate(prompt, seq_len, temperature, filter_thres, model):
device = torch.device("cpu")
model = torch.hub.load("conceptofmind/PaLM", "palm_1b_8k_v0", map_location=device).to(device).eval()
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
encoded_text = tokenizer(prompt, return_tensors="pt")
output_tensor = model.generate(
seq_len=seq_len,
prompt=encoded_text["input_ids"].to(device),
temperature=temperature,
filter_thres=filter_thres,
pad_value=0.0,
eos_token=tokenizer.eos_token_id,
return_seq_without_prompt=False,
use_tqdm=True,
)
decoded_output = tokenizer.batch_decode(output_tensor, skip_special_tokens=True)
return decoded_output
iface = gr.Interface(
fn=generate,
title="PaLM",
description="Open-source PaLM demo.",
inputs="text",
outputs="text"
)
iface.launch()