jbochi committed
Commit 85b9553
1 Parent(s): 38479f3

Improve demo

Files changed (1)
app.py +48 -34
app.py CHANGED
@@ -1,50 +1,64 @@
+import time
+
from transformers import T5ForConditionalGeneration, T5Tokenizer, GenerationConfig
import gradio as gr

MODEL_NAME = "jbochi/madlad400-3b-mt"

-
-default_max_length = 200
-
-print("Using `{}`.".format(MODEL_NAME))
-
+print(f"Loading {MODEL_NAME} tokenizer...")
tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)
-print("T5Tokenizer loaded from pretrained.")
-
+print(f"Loading {MODEL_NAME} model...")
model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME, device_map="auto")
-print("T5ForConditionalGeneration loaded from pretrained.")


-def inference(max_length, input_text, history=[]):
-    input_ids = tokenizer(input_text, return_tensors="pt").input_ids
+def inference(input_text, target_language, max_length):
+    global model, tokenizer
+    start_time = time.time()
+    input_ids = tokenizer(
+        f"<2{target_language}> {input_text}", return_tensors="pt"
+    ).input_ids
    outputs = model.generate(
-        input_ids=input_ids,
-        generation_config=GenerationConfig(max_length=max_length, decoder_start_token_id=2),
+        input_ids=input_ids.to(model.device),
+        generation_config=GenerationConfig(max_length=max_length),
    )
    result = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    history.append((input_text, result))
-    return history, history
-
+    end_time = time.time()
+    result = {
+        'result': result,
+        'inference_time': end_time - start_time,
+        'input_token_ids': input_ids[0].tolist(),
+        'output_token_ids': outputs[0].tolist(),
+    }
+    return result

-with gr.Blocks() as demo:
-    with gr.Row():
-        gr.Markdown(
-            "<h1>Demo of {}</h1><p>See more at Hugging Face: <a href='https://huggingface.co/{}'>{}</a>.</p>".format(
-                MODEL_NAME, MODEL_NAME, MODEL_NAME
-            )
-        )
-        max_length = gr.Number(
-            value=default_max_length, label="maximum length of response"
-        )

-    chatbot = gr.Chatbot(label=MODEL_NAME)
-    state = gr.State([])
-
-    with gr.Row():
-        txt = gr.Textbox(
-            show_label=False, placeholder="<2es> text to translate"
-        )
+def run():
+    tokens = [tokenizer.decode(i) for i in range(500)]
+    lang_codes = [token[2:-1] for token in tokens if token.startswith("<2")]
+    inputs = [
+        gr.components.Textbox(lines=5, label="Input text"),
+        gr.components.Dropdown(lang_codes, value="en", label="Target Language"),
+        gr.components.Slider(
+            minimum=5,
+            maximum=500,
+            value=200,
+            label="Max length",
+        ),
+    ]
+    outputs = gr.components.JSON()
+    title = f"{MODEL_NAME} demo"
+    demo_status = "Demo is running on CPU"
+    description = (
+        f"Details: https://huggingface.co/{MODEL_NAME}. {demo_status}"
+    )
+    gr.Interface(
+        inference,
+        inputs,
+        outputs,
+        title=title,
+        description=description,
+    ).launch()

-    txt.submit(fn=inference, inputs=[max_length, txt, state], outputs=[chatbot, state])

-demo.launch()
+if __name__ == "__main__":
+    run()
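
For reference, a minimal usage sketch of the reworked inference function. The example text and the "es" target code are illustrative assumptions, not part of the commit. Because run() is now guarded by if __name__ == "__main__":, importing app.py loads the tokenizer and model without launching the Gradio interface, so the function can be called directly:

# Illustrative sketch, not part of the commit: exercise the new inference()
# signature directly. Importing app triggers the module-level model load.
from app import inference

result = inference("I love pizza!", target_language="es", max_length=200)
print(result["result"])                 # translated text
print(result["inference_time"])         # wall-clock seconds spent in generation
print(len(result["output_token_ids"]))  # number of generated token ids

The dict returned here is what the gr.components.JSON() output component renders in the demo, replacing the previous chatbot-style history.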