paascorb commited on
Commit
7ec67ba
1 Parent(s): eea5506

Añadiendo las traducciones

Browse files
Files changed (1) hide show
  1. app.py +58 -4
app.py CHANGED
@@ -7,15 +7,69 @@ os.system('pip install --upgrade pip')
7
  os.system('pip install tensorflow')
8
 
9
  from transformers import pipeline
 
 
 
 
 
 
 
10
 
11
  docs = None
12
 
 
 
 
 
 
13
 
14
  def request_pathname(files):
15
  if files is None:
16
  return [[]]
17
  return [[file.name, file.name.split('/')[-1]] for file in files]
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  def validate_dataset(dataset):
21
  global docs
@@ -27,19 +81,19 @@ def validate_dataset(dataset):
27
  return "⚠️Esperando documentos..."
28
 
29
  def do_ask(question, button, dataset):
30
-
31
  global docs
32
  docs_ready = dataset.iloc[-1, 0] != ""
33
  if button == "✨Listo✨" and docs_ready:
34
  for _, row in dataset.iterrows():
35
  path = row['filepath']
36
  text = Path(f'{path}').read_text()
 
37
  question_answerer = pipeline("question-answering", model='distilbert-base-cased-distilled-squad')
38
  QA_input = {
39
- 'question': question,
40
- 'context': text
41
  }
42
- return question_answerer(QA_input)['answer']
43
  else:
44
  return ""
45
 
 
7
  os.system('pip install tensorflow')
8
 
9
  from transformers import pipeline
10
+ from transformers import MarianMTModel, MarianTokenizer
11
+ from nltk.tokenize import sent_tokenize
12
+ from nltk.tokenize import LineTokenizer
13
+ import math
14
+ import torch
15
+ import nltk
16
+ nltk.download('punkt')
17
 
18
  docs = None
19
 
20
+ if torch.cuda.is_available():
21
+ dev = "cuda"
22
+ else:
23
+ dev = "cpu"
24
+ device = torch.device(dev)
25
 
26
  def request_pathname(files):
27
  if files is None:
28
  return [[]]
29
  return [[file.name, file.name.split('/')[-1]] for file in files]
30
 
31
+ def traducir_parrafos(parrafos, tokenizer, model, tam_bloque=8, ):
32
+ parrafos_traducidos = []
33
+ for parrafo in parrafos:
34
+ frases = sent_tokenize(parrafo)
35
+ batches = math.ceil(len(frases) / tam_bloque)
36
+ traducido = []
37
+ for i in range(batches):
38
+
39
+ bloque_enviado = frases[i*tam_bloque:(i+1)*tam_bloque]
40
+ model_inputs = tokenizer(bloque_enviado, return_tensors="pt",
41
+ padding=True, truncation=True,
42
+ max_length=500).to(device)
43
+ with torch.no_grad():
44
+ bloque_traducido = model.generate(**model_inputs)
45
+ traducido += bloque_traducido
46
+ traducido = [tokenizer.decode(t, skip_special_tokens=True) for t in traducido]
47
+ parrafos_traducidos += [" ".join(traducido)]
48
+ return parrafos_traducidos
49
+
50
+ def traducir_es_en(texto):
51
+ mname = "Helsinki-NLP/opus-mt-es-en"
52
+ tokenizer = MarianTokenizer.from_pretrained(mname)
53
+ model = MarianMTModel.from_pretrained(mname)
54
+ model.to(device)
55
+
56
+ lt = LineTokenizer()
57
+ batch_size = 8
58
+ parrafos = lt.tokenize(text_long)
59
+ par_tra = traducir_parrafos(parrafos, tokenizer, model)
60
+ return "\n".join(par_tra)
61
+
62
+ def traducir_en_es(texto):
63
+ mname = "Helsinki-NLP/opus-mt-en-es"
64
+ tokenizer = MarianTokenizer.from_pretrained(mname)
65
+ model = MarianMTModel.from_pretrained(mname)
66
+ model.to(device)
67
+
68
+ lt = LineTokenizer()
69
+ batch_size = 8
70
+ parrafos = lt.tokenize(text_long)
71
+ par_tra = traducir_parrafos(parrafos, tokenizer, model)
72
+ return "\n".join(par_tra)
73
 
74
  def validate_dataset(dataset):
75
  global docs
 
81
  return "⚠️Esperando documentos..."
82
 
83
  def do_ask(question, button, dataset):
 
84
  global docs
85
  docs_ready = dataset.iloc[-1, 0] != ""
86
  if button == "✨Listo✨" and docs_ready:
87
  for _, row in dataset.iterrows():
88
  path = row['filepath']
89
  text = Path(f'{path}').read_text()
90
+ text_en = traducir_es_en(text)
91
  question_answerer = pipeline("question-answering", model='distilbert-base-cased-distilled-squad')
92
  QA_input = {
93
+ 'question': traducir_es_en(question),
94
+ 'context': text_en
95
  }
96
+ return traducir_en_es(question_answerer(QA_input)['answer'])
97
  else:
98
  return ""
99