File size: 695 Bytes
be6fdd2
 
 
 
 
 
 
 
 
 
 
4711ea3
be6fdd2
46c38c7
fe673de
2cdacf8
be6fdd2
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import torch
from transformers import AutoTokenizer
from petals import AutoDistributedModelForCausalLM

# Hugging Face Hub identifier shared by the tokenizer and the model below.
model_name = "petals-team/StableBeluga2"

# Tokenizer is loaded with the slow (non-"fast") implementation and with
# add_bos_token=False, so prompts are encoded without an automatic BOS token.
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    use_fast=False,
    add_bos_token=False,
)

# Loaded through Petals' distributed model class; this runs at import time.
# NOTE(review): presumably this connects to a Petals swarm - confirm against
# the Petals documentation before relying on startup-time behaviour.
model = AutoDistributedModelForCausalLM.from_pretrained(model_name)

import gradio as gr

def generate(input):
    """Return a sampled continuation of the prompt, without the prompt itself.

    The prompt is tokenized with the module-level ``tokenizer`` and passed to
    the module-level ``model``, which samples up to 80 new tokens
    (``do_sample=True``, ``temperature=0.9``).

    Args:
        input: Prompt text. The name shadows the ``input`` builtin but is kept
            unchanged so existing callers keep working.

    Returns:
        The generated text with the prompt and special tokens removed.
    """
    prompt_ids = tokenizer(input, return_tensors="pt")["input_ids"]
    outputs = model.generate(
        prompt_ids,
        max_new_tokens=80,
        do_sample=True,
        temperature=0.9,
    )

    # Let the tokenizer drop special tokens (e.g. the end-of-sequence marker)
    # instead of string-replacing a hard-coded "</s>".
    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Fast path: the decoded text reproduces the prompt verbatim, so a
    # character slice removes exactly the prompt.
    if decoded.startswith(input):
        return decoded[len(input):]

    # Decoding does not always round-trip the prompt character for character;
    # slicing by len(input) would then cut at the wrong place. Fall back to
    # decoding only the tokens that follow the prompt.
    # Assumes outputs[0] starts with the prompt tokens, as the original
    # character slice did - TODO confirm for the model in use.
    new_token_ids = outputs[0][prompt_ids.shape[-1]:]
    return tokenizer.decode(new_token_ids, skip_special_tokens=True)

# Expose `generate` through a Gradio interface with one text input and one
# text output, then launch it when this script runs.
iface = gr.Interface(
    fn=generate,
    inputs="text",
    outputs="text",
)
iface.launch()