jbochi committed on
Commit
0fb075f
1 Parent(s): 0644eb8

Use app.py from https://github.com/synkathairo/flan-t5-large-gradio

Browse files
Files changed (1) hide show
  1. app.py +45 -1
app.py CHANGED
@@ -1,3 +1,47 @@
 
1
  import gradio as gr
2
 
3
- gr.Interface.load("models/jbochi/madlad400-3b-mt").launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from transformers import T5Tokenizer, T5ForConditionalGeneration
import gradio as gr

# Hugging Face model id served by this demo (MADLAD-400 3B machine translation).
MODEL_NAME = "jbochi/madlad400-3b-mt"


# Default cap on generated tokens, editable in the UI via a gr.Number field.
default_max_length = 200

print(f"Using `{MODEL_NAME}`.")

# Load tokenizer and model once at startup; both are shared by every request.
tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)
print("T5Tokenizer loaded from pretrained.")

# device_map="auto" lets accelerate place the weights (GPU if available).
model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME, device_map="auto")
print("T5ForConditionalGeneration loaded from pretrained.")
18
def inference(max_length, input_text, history=None):
    """Translate `input_text` with the MADLAD-400 model and update chat history.

    Args:
        max_length: maximum length passed to `model.generate`. Comes from a
            `gr.Number` component, which yields a float, so it is cast to int
            here before generation.
        input_text: source text; per the MADLAD-400 convention the target
            language is selected with a prefix token such as "<2es>".
        history: running list of (input, result) pairs. Defaults to None and
            a fresh list is created per call — the original `history=[]`
            mutable default was shared across all calls (and all users of the
            shared Gradio app), leaking conversation state.

    Returns:
        A (history, history) pair: one copy feeds the Chatbot display, the
        other the gr.State component.
    """
    if history is None:
        history = []
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids
    # bos_token_id=0 mirrors the original app's generation settings.
    outputs = model.generate(
        input_ids, max_length=int(max_length), bos_token_id=0
    )
    result = tokenizer.decode(outputs[0], skip_special_tokens=True)
    history.append((input_text, result))
    return history, history
26
# Build the Gradio UI: a header row, a chatbot transcript, and a text input
# wired to `inference`, then start the web server.
with gr.Blocks() as demo:
    with gr.Row():
        # Positional {0} reuses MODEL_NAME for title, link target, and link text.
        banner = (
            "<h1>Demo of {0}</h1><p>See more at Hugging Face: "
            "<a href='https://huggingface.co/{0}'>{0}</a>.</p>"
        ).format(MODEL_NAME)
        gr.Markdown(banner)
        max_length = gr.Number(
            value=default_max_length, label="maximum length of response"
        )

    chatbot = gr.Chatbot(label=MODEL_NAME)
    # Per-session conversation history threaded through each submit.
    state = gr.State([])

    with gr.Row():
        txt = gr.Textbox(
            show_label=False, placeholder="<2es> text to translate"
        ).style(container=False)

    # Pressing Enter in the textbox runs inference and refreshes the chat.
    txt.submit(fn=inference, inputs=[max_length, txt, state], outputs=[chatbot, state])

demo.launch()