fivelittlec commited on
Commit
f19cf37
1 Parent(s): 105cf97

feat(app):adapt to madlad-400

Browse files
Files changed (1) hide show
  1. app.py +83 -4
app.py CHANGED
@@ -1,7 +1,86 @@
 
1
  import gradio as gr
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
  import gradio as gr
3
+ import torch
4
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
5
 
6
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
7
 
8
+ tokenizer_3b_mt = AutoTokenizer.from_pretrained("google/madlad400-3b-mt", use_fast=True)
9
+ language_codes = [token for token in tokenizer_3b_mt.get_vocab().keys() if token.startswith("<2")]
10
+ remove_codes = ['<2>', '<2en_xx_simple>', '<2translate>', '<2back_translated>', '<2zxx_xx_dtynoise>',
11
+ '<2transliterate>']
12
+ language_codes = [token for token in language_codes if token not in remove_codes]
13
+
14
+ model_choices = [
15
+ "google/madlad400-3b-mt",
16
+ "google/madlad400-7b-mt",
17
+ "google/madlad400-10b-mt",
18
+ "google/madlad400-7b-mt-bt"
19
+ ]
20
+
21
+ model_resources = {}
22
+
23
+
24
+ def load_tokenizer_model(model_name):
25
+ """
26
+ Load tokenizer and model for a chosen model name.
27
+ """
28
+ if model_name not in model_resources:
29
+ # Load tokenizer and model for first time
30
+ tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
31
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.float16)
32
+ model.to_bettertransformer()
33
+ model.to(device)
34
+ model_resources[model_name] = (tokenizer, model)
35
+ return model_resources[model_name]
36
+
37
+
38
+ @spaces.GPU
39
+ def translate(text, target_language, model_name):
40
+ """
41
+ Translate the input text from English to another language.
42
+ """
43
+ # Load tokenizer and model if not already loaded
44
+ tokenizer, model = load_tokenizer_model(model_name)
45
+
46
+ text = target_language + text
47
+ input_ids = tokenizer(text, return_tensors="pt").input_ids.to(device)
48
+
49
+ outputs = model.generate(input_ids=input_ids, max_new_tokens=128000)
50
+ text_translated = tokenizer.batch_decode(outputs, skip_special_tokens=True)
51
+
52
+ return text_translated[0]
53
+
54
+
55
+ title = "MADLAD-400 Translation"
56
+ description = """
57
+ Translation from English to over 400 languages based on [research](https://arxiv.org/pdf/2309.04662) by Google DeepMind and Google Research. Initial inference will be slow as models load.
58
+ """
59
+
60
+ input_text = gr.Textbox(
61
+ label="Text",
62
+ placeholder="Enter text here"
63
+ )
64
+ target_language = gr.Dropdown(
65
+ choices=language_codes,
66
+ value="<2haw>",
67
+ label="Target language"
68
+ )
69
+ model_choice = gr.Dropdown(
70
+ choices=model_choices,
71
+ value="google/madlad400-3b-mt",
72
+ label="Model"
73
+ )
74
+ output_text = gr.Textbox(label="Translation")
75
+
76
+ demo = gr.Interface(
77
+ fn=translate,
78
+ inputs=[input_text, target_language, model_choice],
79
+ outputs=output_text,
80
+ title=title,
81
+ description=description
82
+ )
83
+
84
+ demo.queue()
85
+
86
+ demo.launch()