mrm8488 committed on
Commit
242e710
1 Parent(s): 7fdee89

Load models before using them

Browse files
Files changed (1) hide show
  1. app.py +13 -5
app.py CHANGED
@@ -3,19 +3,27 @@ import torch
3
  from transformers import RobertaTokenizerFast, BertTokenizerFast, EncoderDecoderModel
4
 
5
 
 
 
 
 
6
  models_paths = dict()
 
7
  models_paths["fr"] = "mrm8488/camembert2camembert_shared-finetuned-french-summarization"
8
  models_paths["de"] = "mrm8488/bert2bert_shared-german-finetuned-summarization"
9
  models_paths["tu"] = "mrm8488/bert2bert_shared-turkish-summarization"
10
  models_paths["es"] = "Narrativa/bsc_roberta2roberta_shared-spanish-finetuned-mlsum-summarization"
11
 
12
-
13
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
14
 
 
 
 
 
15
 
16
  def summarize(lang, text):
17
- tokenizer = RobertaTokenizerFast.from_pretrained(models_paths[lang]) if lang == "fr" or lang == "es" else BertTokenizerFast.from_pretrained(models_paths[lang])
18
- model = EncoderDecoderModel.from_pretrained(models_paths[lang]).to(device)
19
  inputs = tokenizer([text], padding="max_length",
20
  truncation=True, max_length=512, return_tensors="pt")
21
  input_ids = inputs.input_ids.to(device)
@@ -31,8 +39,8 @@ title = "Multilingual Summarization model (MLSUM)"
31
 
32
  description = "Gradio Demo for Summarization models trained on MLSUM dataset by Manuel Romero"
33
 
34
- article = "<p style='text-align: center'><a href='https://hf.com/mrm8488' target='_blank'>More models</a></p>"
35
 
36
 
37
- gr.Interface(fn=summarize, inputs=[gr.inputs.Radio(["fr", "de", "tu", "es"]), gr.inputs.Textbox(
38
  lines=7, label="Input Text")], outputs="text", theme=theme, title=title, description=description, article=article, enable_queue=True).launch(inline=False)
3
  from transformers import RobertaTokenizerFast, BertTokenizerFast, EncoderDecoderModel
4
 
5
 
6
+ LANGUAGES = ["fr", "de", "tu", "es"]
7
+
8
+ models = dict()
9
+ tokenizers = dict()
10
  models_paths = dict()
11
+
12
  models_paths["fr"] = "mrm8488/camembert2camembert_shared-finetuned-french-summarization"
13
  models_paths["de"] = "mrm8488/bert2bert_shared-german-finetuned-summarization"
14
  models_paths["tu"] = "mrm8488/bert2bert_shared-turkish-summarization"
15
  models_paths["es"] = "Narrativa/bsc_roberta2roberta_shared-spanish-finetuned-mlsum-summarization"
16
 
 
17
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
18
 
19
+ for lang in LANGUAGES:
20
+ tokenizers[lang] = RobertaTokenizerFast.from_pretrained(models_paths[lang]) if lang in ["fr", "es"] else BertTokenizerFast.from_pretrained(models_paths[lang])
21
+ models[lang] = EncoderDecoderModel.from_pretrained(models_paths[lang]).to(device)
22
+
23
 
24
  def summarize(lang, text):
25
+ tokenizer = tokenizers[lang]
26
+ model = models[lang]
27
  inputs = tokenizer([text], padding="max_length",
28
  truncation=True, max_length=512, return_tensors="pt")
29
  input_ids = inputs.input_ids.to(device)
39
 
40
  description = "Gradio Demo for Summarization models trained on MLSUM dataset by Manuel Romero"
41
 
42
+ article = "<p style='text-align: center'><a href='https://hf.co/mrm8488' target='_blank'>More models</a></p>"
43
 
44
 
45
+ gr.Interface(fn=summarize, inputs=[gr.inputs.Radio(LANGUAGES), gr.inputs.Textbox(
46
  lines=7, label="Input Text")], outputs="text", theme=theme, title=title, description=description, article=article, enable_queue=True).launch(inline=False)