# Hugging Face Space: PaLM (palm_rlhf_pytorch) text-generation demo served via Gradio.
import torch
from transformers import AutoTokenizer
from palm_rlhf_pytorch import PaLM
import gradio as gr
def generate(prompt, seq_len=128, temperature=0.8, filter_thres=0.9, model=None):
    """Generate a text continuation of ``prompt`` with a PaLM model on CPU.

    Args:
        prompt: Input text to continue.
        seq_len: Total sequence length to generate (default chosen for the
            demo UI, which only supplies ``prompt`` — TODO confirm a better
            default against the checkpoint's training context).
        temperature: Sampling temperature passed to ``PaLM.generate``.
        filter_thres: Top-k/top-p filtering threshold passed to ``PaLM.generate``.
        model: Optional pre-loaded ``PaLM`` instance. When ``None`` (the
            default), a 410M-parameter model is built and its weights loaded
            from ``/palm_410m_8k_v0.pt`` on every call. Previously this
            parameter was silently shadowed and always ignored.

    Returns:
        The decoded generated text (prompt included, since
        ``return_seq_without_prompt=False``).
    """
    device = torch.device("cpu")
    if model is None:
        # Build the demo configuration and load local weights. NOTE(review):
        # reloading per call is slow — callers can pass a cached model instead.
        model = PaLM(
            num_tokens=50304,
            dim=1024,
            depth=24,
            dim_head=128,
            heads=8,
            flash_attn=False,
            qk_rmsnorm=False,
        ).to(device)
        model.load('/palm_410m_8k_v0.pt')
    # GPT-NeoX tokenizer matches the 50304-token vocabulary configured above.
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
    encoded_text = tokenizer(prompt, return_tensors="pt")
    output_tensor = model.generate(
        seq_len=seq_len,
        prompt=encoded_text["input_ids"].to(device),
        temperature=temperature,
        filter_thres=filter_thres,
        pad_value=0.0,
        eos_token=tokenizer.eos_token_id,
        return_seq_without_prompt=False,
        use_tqdm=True,
    )
    decoded_output = tokenizer.batch_decode(output_tensor, skip_special_tokens=True)
    return decoded_output[0]
# Wire the generator into a minimal Gradio UI: one text box in, one text box out.
# (The original line carried a trailing "|" scrape artifact — a syntax error.)
iface = gr.Interface(fn=generate, inputs="text", outputs="text")

if __name__ == "__main__":
    # Launch only when executed as a script, so importing this module
    # (e.g. for testing) stays free of server-startup side effects.
    iface.launch()