chat-gradio / app.py
artek0chumak's picture
Fixes
fd09410
raw history blame
No virus
826 Bytes
import sys
sys.path.insert(0, './petals/')
import torch
import transformers
import gradio as gr
from src.client.remote_model import DistributedBloomForCausalLM
# Petals-hosted BLOOM checkpoint; layers are served by remote peers in the swarm.
MODEL_NAME = "bigscience/bloom-petals"
# Tokenizer loads locally from the Hub; the model client connects to the
# distributed Petals swarm (network access required at import time).
tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME)
def inference(text, seq_length=1):
    """Generate a continuation of `text` with the distributed BLOOM model.

    Args:
        text: Prompt string to continue.
        seq_length: Maximum number of new tokens to generate (slider value).

    Returns:
        The decoded string (prompt + generated tokens) for the first batch item.
    """
    input_ids = tokenizer([text], return_tensors="pt").input_ids
    # Pure inference: disable autograd so generation doesn't build a graph
    # and waste memory on activations.
    with torch.no_grad():
        output = model.generate(input_ids, max_new_tokens=seq_length)
    return tokenizer.batch_decode(output)[0]
# Gradio UI: one prompt textbox and a slider controlling max_new_tokens.
# NOTE: the original mixed the modern API (gr.Textbox) with the legacy
# gr.inputs.Slider namespace (removed in Gradio 3.x); both inputs now use
# the modern components, and the deprecated `default=` keyword becomes `value=`.
iface = gr.Interface(
    fn=inference,
    inputs=[
        gr.Textbox(lines=10, label="Input text"),
        gr.Slider(
            minimum=0,
            maximum=1000,
            step=1,
            value=42,
            label="Sequence length for generation",
        ),
    ],
    outputs="text",
)
iface.launch()