ai-forever committed on
Commit a334812
1 Parent(s): e18e4a5

Create app.py

Files changed (1)
app.py +54 -0
app.py ADDED
@@ -0,0 +1,54 @@
+ import torch
+ import gradio as gr
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer
+ from googletrans import Translator
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ fp16 = device != "cpu"  # half precision is possible on GPU (not applied here)
+
+ translator = Translator()
+ tokenizer = GPT2Tokenizer.from_pretrained("sberbank-ai/mGPT-armenian")
+ model = GPT2LMHeadModel.from_pretrained("sberbank-ai/mGPT-armenian")
+ model.to(device)
+ model.eval()
+
+ description = "Multilingual generation with mGPT"
+ title = "Generate your own example"
+
+ # One row per example; the interface has a single text input.
+ examples = [
+     ["English: The vase with flowers is on the table.\nFinnish translation:"],
+     ["In May we celebrate "],
+ ]
+
+ article = (
+     "<p style='text-align: center'>"
+     "<a href='https://github.com/ai-forever/mgpt'>GitHub</a> "
+     "</p>"
+ )
+
+ def transl(text, src="en", dest="hy"):
+     """Translate text with googletrans (English <-> Armenian by default)."""
+     return translator.translate(text, src=src, dest=dest).text
+
+ def generate(prompt: str):
+     input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
+     # do_sample=True is needed for top_p / top_k / temperature to take effect.
+     out = model.generate(
+         input_ids,
+         do_sample=True,
+         min_length=300,
+         max_length=700,
+         top_p=0.95,
+         top_k=0,
+         temperature=0.9,
+         no_repeat_ngram_size=5,
+     )
+     generated_text = tokenizer.decode(out[0])
+     # Append an English translation of the (Armenian) generation.
+     return generated_text + "\n\n" + transl(generated_text, src="hy", dest="en")
+
+ interface = gr.Interface(
+     fn=generate,
+     inputs="text",
+     outputs="text",
+     title=title,
+     description=description,
+     examples=examples,
+     thumbnail="https://habrastorage.org/r/w1560/getpro/habr/upload_files/26a/fa1/3e1/26afa13e1d1a56f54c7b0356761af7b8.png",
+     theme="peach",
+     article=article,
+ )
+
+ interface.launch(enable_queue=True)