"""Gradio web demo for distributed BLOOM text generation via Petals.

Loads the bigscience/bloom-petals model through the Petals distributed
client and serves a simple text-in / text-out generation UI.
"""
import sys

# Make the vendored Petals checkout importable before the project imports.
sys.path.insert(0, './petals/')

import torch
import transformers
import gradio as gr

from src.client.remote_model import DistributedBloomForCausalLM

MODEL_NAME = "bigscience/bloom-petals"

# NOTE: both calls hit the network (tokenizer download / swarm connection).
tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME)


def inference(text, seq_length=1):
    """Continue *text* with up to *seq_length* newly generated tokens.

    Args:
        text: Prompt string to condition generation on.
        seq_length: Number of new tokens to generate (Gradio sliders may
            pass this as a float, so it is cast to int).

    Returns:
        The decoded generation, including the original prompt.
    """
    input_ids = tokenizer([text], return_tensors="pt").input_ids
    # inference_mode: no autograd bookkeeping is needed for generation.
    with torch.inference_mode():
        output = model.generate(input_ids, max_new_tokens=int(seq_length))
    return tokenizer.batch_decode(output)[0]


iface = gr.Interface(
    fn=inference,
    inputs=[
        gr.Textbox(lines=10, label="Input text"),
        # gr.inputs.Slider / default= is the deprecated Gradio 2.x API
        # (removed in 3.x); gr.Slider / value= matches the gr.Textbox
        # style already used above. Minimum is 1: requesting 0 new
        # tokens is not a meaningful generation call.
        gr.Slider(
            minimum=1,
            maximum=1000,
            step=1,
            value=42,
            label="Sequence length for generation",
        ),
    ],
    outputs="text",
)

if __name__ == "__main__":
    iface.launch()