File size: 1,466 Bytes
0fb075f
0644eb8
 
0fb075f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from transformers import T5Tokenizer, T5ForConditionalGeneration
import gradio as gr

# Hugging Face Hub identifier of the translation checkpoint served by this demo.
MODEL_NAME = "jbochi/madlad400-3b-mt"


# Initial value of the "maximum length of response" field in the UI.
default_max_length = 200

print(f"Using `{MODEL_NAME}`.")

# Load the tokenizer and model weights for MODEL_NAME, logging each step.
tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)
print("T5Tokenizer loaded from pretrained.")

# device_map="auto" delegates device placement of the weights to transformers.
model = T5ForConditionalGeneration.from_pretrained(
    MODEL_NAME,
    device_map="auto",
)
print("T5ForConditionalGeneration loaded from pretrained.")


def inference(max_length, input_text, history=None):
    """Translate ``input_text`` and append the exchange to the chat history.

    The textbox placeholder suggests the text should begin with a
    target-language tag such as ``<2es>``; the text is passed to the model
    unchanged.

    Args:
        max_length: Maximum length of the generated sequence. Gradio's
            ``gr.Number`` may deliver a float or ``None``; the value is
            coerced to ``int`` and ``None`` falls back to
            ``default_max_length``.
        input_text: Text to translate. Empty or whitespace-only input is
            ignored and the history is returned unchanged.
        history: List of ``(input_text, translation)`` tuples shown in the
            chatbot. ``None`` (the default) starts a new, empty history. The
            caller's list is never mutated; a new list is returned.

    Returns:
        ``(history, history)`` -- the updated list, once for the ``Chatbot``
        display and once to store back into the ``State``.
    """
    # Copy so neither a shared default nor the caller's list is mutated.
    history = [] if history is None else list(history)

    if not input_text or not input_text.strip():
        # Nothing to translate: leave the conversation unchanged.
        return history, history

    if max_length is None:
        max_length = default_max_length

    # The model is loaded with device_map="auto", so it may not live on the
    # CPU; move the inputs to the model's device before generating.
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(model.device)
    # NOTE(review): purpose of bos_token_id=0 is not evident here; kept unchanged.
    outputs = model.generate(input_ids, max_length=int(max_length), bos_token_id=0)
    result = tokenizer.decode(outputs[0], skip_special_tokens=True)

    history.append((input_text, result))
    return history, history


with gr.Blocks() as demo:
    # Header row: title with a link to the model page, plus the
    # user-adjustable generation length.
    with gr.Row():
        gr.Markdown(
            "<h1>Demo of {}</h1><p>See more at Hugging Face: <a href='https://huggingface.co/{}'>{}</a>.</p>".format(
                MODEL_NAME, MODEL_NAME, MODEL_NAME
            )
        )
        # precision=0 makes the component deliver an int, which is what
        # model.generate(max_length=...) expects, instead of a float.
        max_length = gr.Number(
            value=default_max_length,
            label="maximum length of response",
            precision=0,
        )

    chatbot = gr.Chatbot(label=MODEL_NAME)
    # List of (input, translation) pairs passed to and returned by inference().
    state = gr.State([])

    with gr.Row():
        # Component.style() is deprecated in Gradio 3.x and removed in 4.x;
        # the container option is passed to the constructor instead.
        txt = gr.Textbox(
            show_label=False,
            placeholder="<2es> text to translate",
            container=False,
        )

    # Submitting the textbox runs a translation and refreshes chat + state.
    txt.submit(fn=inference, inputs=[max_length, txt, state], outputs=[chatbot, state])

demo.launch()