PaulNdrei commited on
Commit
7112824
1 Parent(s): b5d1d44

Fix runtime error

Browse files
Files changed (1) hide show
  1. app.py +15 -41
app.py CHANGED
@@ -2,14 +2,9 @@ import os
2
  from dotenv import load_dotenv
3
  import gradio as gr
4
  from AinaTheme import theme
5
- from huggingface_hub import snapshot_download
6
- import nltk
7
  from api_endpoint import invoke_translate_endpoint
8
-
9
  from translate import translate_text
10
 
11
- nltk.download('punkt')
12
-
13
  load_dotenv()
14
 
15
  MODELS_PATH = "./models"
@@ -18,55 +13,36 @@ MAX_INPUT_CHARACTERS = int(os.environ.get("MAX_INPUT_CHARACTERS", default=1000))
18
  API_ENDPOINT_ENABLED = os.environ.get("API_ENDPOINT_ENABLED", default=True) == "True"
19
 
20
 
21
- def download_model(repo_id, revision="main"):
22
- return snapshot_download(repo_id=repo_id, revision=revision, local_dir=os.path.join(MODELS_PATH, repo_id), cache_dir=HF_CACHE_DIR)
23
-
24
- model_dir_ca_es = download_model("projecte-aina/aina-translator-ca-es", revision="main")
25
- model_dir_es_ca = download_model("projecte-aina/aina-translator-es-ca", revision="main")
26
-
27
- model_dir_ca_en = download_model("projecte-aina/aina-translator-ca-en", revision="main")
28
- model_dir_en_ca = download_model("projecte-aina/aina-translator-en-ca", revision="main")
29
-
30
- model_dir_ca_fr = download_model("projecte-aina/aina-translator-ca-fr", revision="main")
31
- model_dir_fr_ca = download_model("projecte-aina/aina-translator-fr-ca", revision="main")
32
-
33
- model_dir_ca_de = download_model("projecte-aina/aina-translator-ca-de", revision="main")
34
-
35
- model_dir_ca_it = download_model("projecte-aina/aina-translator-ca-it", revision="main")
36
-
37
- model_dir_ca_pt = download_model("projecte-aina/aina-translator-ca-pt", revision="main")
38
- model_dir_pt_ca = download_model("projecte-aina/aina-translator-pt-ca", revision="main")
39
-
40
  directions = {
41
  "Catalan": {
42
  "target": {
43
- "Spanish": {"src": "ca", "tgt":"es", "model": (f"{os.path.join(MODELS_PATH, model_dir_ca_es)}/spm.model", f"{os.path.join(MODELS_PATH, model_dir_ca_es)}")},
44
- "English": {"src": "ca", "tgt":"en", "model": (f"{os.path.join(MODELS_PATH, model_dir_ca_en)}/spm.model", f"{os.path.join(MODELS_PATH, model_dir_ca_en)}")},
45
- "French": {"src": "ca", "tgt":"fr", "model": (f"{os.path.join(MODELS_PATH, model_dir_ca_fr)}/spm.model", f"{os.path.join(MODELS_PATH, model_dir_ca_fr)}")},
46
- "German": {"src": "ca", "tgt":"de", "model": (f"{os.path.join(MODELS_PATH, model_dir_ca_de)}/spm.model", f"{os.path.join(MODELS_PATH, model_dir_ca_de)}")},
47
- "Italian": {"src": "ca", "tgt":"it", "model": (f"{os.path.join(MODELS_PATH, model_dir_ca_it)}/spm.model", f"{os.path.join(MODELS_PATH, model_dir_ca_it)}")},
48
- "Portuguese": {"src": "ca", "tgt":"pt", "model": (f"{os.path.join(MODELS_PATH, model_dir_ca_pt)}/spm.model", f"{os.path.join(MODELS_PATH, model_dir_ca_pt)}")}
49
 
50
  }
51
  },
52
  "Spanish": {
53
  "target": {
54
- "Catalan": {"src": "es", "tgt":"ca", "model": (f"{os.path.join(MODELS_PATH, model_dir_es_ca)}/spm.model", f"{os.path.join(MODELS_PATH, model_dir_es_ca)}")},
55
  }
56
  },
57
  "English": {
58
  "target": {
59
- "Catalan": {"src": "en", "tgt":"ca", "model": (f"{os.path.join(MODELS_PATH, model_dir_en_ca)}/spm.model", f"{os.path.join(MODELS_PATH, model_dir_en_ca)}")},
60
  }
61
  },
62
  "French": {
63
  "target": {
64
- "Catalan": {"src": "fr", "tgt":"ca", "model": (f"{os.path.join(MODELS_PATH, model_dir_fr_ca)}/spm.model", f"{os.path.join(MODELS_PATH, model_dir_fr_ca)}")},
65
  }
66
  },
67
  "Portuguese": {
68
  "target": {
69
- "Catalan": {"src": "pt", "tgt":"ca", "model": (f"{os.path.join(MODELS_PATH, model_dir_pt_ca)}/spm.model", f"{os.path.join(MODELS_PATH, model_dir_pt_ca)}")},
70
  }
71
  }
72
  }
@@ -106,13 +82,11 @@ def translate_input(input, source_language, target_language):
106
 
107
  target_language_model = get_target_languege_model(source_language, target_language)
108
 
109
- if (API_ENDPOINT_ENABLED and 'src' in target_language_model):
110
- translation = invoke_translate_endpoint(input, target_language_model)
111
- if translation is not None:
112
- return translation
113
- translation = translate(input, source_language, target_language_model.get('model'))
114
-
115
- return translation
116
 
117
 
118
  def clear():
 
2
  from dotenv import load_dotenv
3
  import gradio as gr
4
  from AinaTheme import theme
 
 
5
  from api_endpoint import invoke_translate_endpoint
 
6
  from translate import translate_text
7
 
 
 
8
  load_dotenv()
9
 
10
  MODELS_PATH = "./models"
 
13
  API_ENDPOINT_ENABLED = os.environ.get("API_ENDPOINT_ENABLED", default=True) == "True"
14
 
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  directions = {
17
  "Catalan": {
18
  "target": {
19
+ "Spanish": {"src": "ca", "tgt":"es"},
20
+ "English": {"src": "ca", "tgt":"en"},
21
+ "French": {"src": "ca", "tgt":"fr"},
22
+ "German": {"src": "ca", "tgt":"de"},
23
+ "Italian": {"src": "ca", "tgt":"it"},
24
+ "Portuguese": {"src": "ca", "tgt":"pt"}
25
 
26
  }
27
  },
28
  "Spanish": {
29
  "target": {
30
+ "Catalan": {"src": "es", "tgt":"ca"},
31
  }
32
  },
33
  "English": {
34
  "target": {
35
+ "Catalan": {"src": "en", "tgt":"ca"},
36
  }
37
  },
38
  "French": {
39
  "target": {
40
+ "Catalan": {"src": "fr", "tgt":"ca"},
41
  }
42
  },
43
  "Portuguese": {
44
  "target": {
45
+ "Catalan": {"src": "pt", "tgt":"ca"},
46
  }
47
  }
48
  }
 
82
 
83
  target_language_model = get_target_languege_model(source_language, target_language)
84
 
85
+ translation = invoke_translate_endpoint(input, target_language_model)
86
+ if translation is not None:
87
+ return translation
88
+
89
+ return None
 
 
90
 
91
 
92
  def clear():