NavedSid committed on
Commit
709a310
1 Parent(s): 69a5344

Added translation controls and code

Browse files
Files changed (3) hide show
  1. app.py +46 -47
  2. requirements.txt +5 -2
  3. translate.py +47 -0
app.py CHANGED
@@ -1,47 +1,46 @@
1
- from transformers import AutoModelForCausalLM, AutoTokenizer
2
- import gradio as gr
3
- import torch
4
-
5
-
6
- title = "Translation Chatbot"
7
- #description = "A State-of-the-Art Large-scale Pretrained Response generation model (DialoGPT)"
8
- examples = [["How are you?"]]
9
-
10
-
11
- tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-large")
12
- model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-large")
13
-
14
-
15
- def predict(input, history=[]):
16
- # tokenize the new input sentence
17
- new_user_input_ids = tokenizer.encode(
18
- input + tokenizer.eos_token, return_tensors="pt"
19
- )
20
-
21
- # append the new user input tokens to the chat history
22
- bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
23
-
24
- # generate a response
25
- history = model.generate(
26
- bot_input_ids, max_length=4000, pad_token_id=tokenizer.eos_token_id
27
- ).tolist()
28
-
29
- # convert the tokens to text, and then split the responses into lines
30
- response = tokenizer.decode(history[0]).split("<|endoftext|>")
31
- # print('decoded_response-->>'+str(response))
32
- response = [
33
- (response[i], response[i + 1]) for i in range(0, len(response) - 1, 2)
34
- ] # convert to tuples of list
35
- # print('response-->>'+str(response))
36
- return response, history
37
-
38
-
39
- gr.Interface(
40
- fn=predict,
41
- title=title,
42
- #description=description,
43
- examples=examples,
44
- inputs=["text", "state"],
45
- outputs=["chatbot", "state"],
46
- theme="earneleh/paris",
47
- ).launch()
 
1
+ import gradio as gr
2
+ import torch
3
+
4
+ from translate import Translator
5
+
6
+ # https://medium.com/analytics-vidhya/make-a-translate-app-with-huggingface-transformers-ce9203f84c79
7
+ # https://huggingface.co/docs/transformers/en/model_doc/mbart
8
# UI copy displayed by the Gradio interface.
title = "Translation Chatbot"
description = "A simple implementation of translating one language to another"

# One example row per Interface input: (text, source lang code, target lang code).
examples = [["UN Chief Says There Is No Military Solution in Syria","en_XX","ja_XX"]]

# Single shared Translator instance; the mBART model is loaded at import time.
translator_obj = Translator()
13
+
14
def translate_sentence(sentence, src_lang="ja_XX", tgt_lang="en_XX"):
    """Translate *sentence* via the shared module-level Translator.

    Bug fix: the original body called an undefined name ``pipe`` (a leftover
    from a pipeline-based prototype), so every invocation raised NameError.
    It now delegates to ``translator_obj``. The old hard-coded ``<-ja2zh->``
    tag implied Japanese input, so ``src_lang`` defaults to ``ja_XX``; Chinese
    is not in Translator.supported_langs, so the target defaults to English.

    Args:
        sentence: text to translate.
        src_lang: mBART-50 source language code (default ``ja_XX``).
        tgt_lang: mBART-50 target language code (default ``en_XX``).

    Returns:
        The translated string.
    """
    return translator_obj.translate(sentence, src_lang, tgt_lang)
16
+
17
def predict(input,
            history=None,
            original_language="en_XX",
            translated_language="fr_XX"):
    """Translate ``input`` and append the (input, response) pair to the chat history.

    Fixes:
      * ``history=[]`` was a mutable default argument, silently shared across
        calls; it is now ``None`` with a fresh list created per call.  This
        also tolerates Gradio passing ``None`` as the initial "state" value.
      * The old default ``translated_language="ro_RO"`` is not in
        ``Translator.supported_langs``, so calling with defaults always
        raised; ``fr_XX`` matches the UI dropdown's default instead.

    Args:
        input: text to translate.
        history: accumulated list of (user_text, bot_text) tuples, or None.
        original_language: mBART-50 code of the input language.
        translated_language: mBART-50 code to translate into.

    Returns:
        (history, history) — the same list twice, for the Chatbot output
        and the "state" output respectively.
    """
    if history is None:
        history = []
    response = translator_obj.translate(input, original_language, translated_language)
    history.append((input, response))
    return history, history
24
+
25
if __name__ == "__main__":
    # Dropdown choices map human-readable names to mBART-50 language codes;
    # the value passed to predict() is the code, not the display name.
    source_languages = [
        ("English", "en_XX"), ("French", "fr_XX"), ("German", "de_DE"),
        ("Japanese", "ja_XX"), ("Russian", "ru_RU"),
    ]
    target_languages = [
        ("French", "fr_XX"), ("German", "de_DE"), ("Japanese", "ja_XX"),
        ("Russian", "ru_RU"), ("English", "en_XX"),
    ]

    demo = gr.Interface(
        fn=predict,
        title=title,
        description=description,
        examples=examples,
        inputs=[
            gr.Textbox(),
            "state",
            gr.Dropdown(
                source_languages, value="en_XX", label="Input Language",
                info="Choose the language the input text is in.",
            ),
            gr.Dropdown(
                target_languages, value="fr_XX", label="Language to translate to",
                info="Choose the language to convert the text to.",
            ),
        ],
        outputs=[gr.Chatbot(), "state"],
        theme="earneleh/paris",
    )
    demo.launch()
 
requirements.txt CHANGED
@@ -1,2 +1,5 @@
1
- transformers
2
- torch
 
 
 
 
1
+ transformers
2
+ torch
3
+ gradio
4
+ sentencepiece
5
+ protobuf
translate.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
2
+
3
+
4
class Translator:
    """mBART-50 many-to-many translation helper.

    Loads ``facebook/mbart-large-50-many-to-many-mmt`` (model + fast
    tokenizer) once at construction time and exposes a single ``translate``
    method.

    Fix: the previous docstring claimed support for English, Gujarati,
    Hindi, Bengali, Malayalam, Marathi, Tamil and Telugu, but
    ``supported_langs`` has never contained any Indic codes.  The actual
    supported codes are: en_XX, fr_XX, de_DE, ru_RU, ja_XX.

    Full list of codes the underlying model understands:
    https://huggingface.co/facebook/mbart-large-50-one-to-many-mmt
    mBART documentation:
    https://huggingface.co/docs/transformers/en/model_doc/mbart
    """

    def __init__(self):
        # First construction downloads the checkpoint (network + disk I/O).
        self.model = MBartForConditionalGeneration.from_pretrained('facebook/mbart-large-50-many-to-many-mmt')
        self.tokenizer = MBart50TokenizerFast.from_pretrained('facebook/mbart-large-50-many-to-many-mmt')
        # Subset of mBART-50's language codes that this app currently offers.
        # https://dl-translate.readthedocs.io/en/latest/available_languages/
        self.supported_langs = ['en_XX', 'fr_XX', 'de_DE', 'ru_RU', 'ja_XX']

    def translate(self, input_text, src_lang, tgt_lang):
        """Translate ``input_text`` from ``src_lang`` to ``tgt_lang``.

        Args:
            input_text: text to translate.
            src_lang: mBART-50 code of the input language.
            tgt_lang: mBART-50 code of the output language.

        Returns:
            The first decoded translation string.

        Raises:
            RuntimeError: if either language code is unsupported, or if
                decoding produced no output.
        """
        if src_lang not in self.supported_langs:
            raise RuntimeError('Unsupported source language.')
        if tgt_lang not in self.supported_langs:
            raise RuntimeError('Unsupported target language.')

        # The tokenizer must know the source language before encoding.
        self.tokenizer.src_lang = src_lang
        encoded_text = self.tokenizer(input_text, return_tensors='pt')
        # Forcing the target language code as the first generated token is
        # how mBART-50 selects the output language.
        generated_tokens = self.model.generate(
            **encoded_text,
            forced_bos_token_id=self.tokenizer.lang_code_to_id[tgt_lang],
        )
        output_text_arr = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)

        if output_text_arr:
            return output_text_arr[0]
        raise RuntimeError('Failed to generate output. Output Text Array is empty.')