TymaaHammouda committed on
Commit
b5d6f8a
·
verified ·
1 Parent(s): e5a79d2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -2
app.py CHANGED
@@ -30,13 +30,31 @@ print("Version ---- 2")
30
  app = FastAPI()
31
 
32
 
33
- app = FastAPI()
34
-
35
  pretrained_path = "aubmindlab/bert-base-arabertv2" # must match training
36
  tokenizer = AutoTokenizer.from_pretrained(pretrained_path)
37
  encoder = AutoModel.from_pretrained(pretrained_path).eval()
38
 
39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  checkpoint_path = snapshot_download(repo_id="SinaLab/Nested", allow_patterns="checkpoints/")
41
 
42
  args_path = hf_hub_download(
@@ -188,6 +206,7 @@ def NER(sentence, mode):
188
  return json_short
189
 
190
  BASE_DIR = os.path.expanduser("~/.sinatools")
 
191
 
192
  # Paths expected by sinatools
193
  RELATION_MODEL_DIR = os.path.join(BASE_DIR, "relation_model")
@@ -205,6 +224,21 @@ if not os.path.exists(RELATION_MODEL_DIR) or not os.listdir(RELATION_MODEL_DIR):
205
  )
206
 
207
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
  from sinatools.relations.relation_extractor import relation_extraction
209
  from sinatools.relations.event_relation_extractor import event_argument_relation_extraction
210
 
 
30
  app = FastAPI()
31
 
32
 
 
 
33
  pretrained_path = "aubmindlab/bert-base-arabertv2" # must match training
34
  tokenizer = AutoTokenizer.from_pretrained(pretrained_path)
35
  encoder = AutoModel.from_pretrained(pretrained_path).eval()
36
 
37
 
38
def download_file_from_hf(repo_id, filename):
    """Download one file from a Hugging Face repo into the local sinatools folder.

    The file is placed under ``~/.sinatools/Wj27012000.tar`` (user-expanded),
    which is created first if it does not already exist.

    Args:
        repo_id: Identifier of the Hugging Face repository to pull from.
        filename: Name of the file inside that repository.

    Returns:
        Whatever ``hf_hub_download`` returns for the downloaded file.
    """
    # Despite the ".tar" suffix, this path is created and used as a directory.
    destination = os.path.expanduser("~/.sinatools/Wj27012000.tar")
    os.makedirs(destination, exist_ok=True)

    return hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        local_dir=destination,
        local_dir_use_symlinks=False,
    )
50
+
51
+ download_file_from_hf("SinaLab/Nested-v1","args.json")
52
+ download_file_from_hf("SinaLab/Nested-v1","tag_vocab.pkl")
53
+
54
+ snapshot_download(repo_id="SinaLab/Nested", allow_patterns="checkpoints/")
55
+
56
+
57
+
58
  checkpoint_path = snapshot_download(repo_id="SinaLab/Nested", allow_patterns="checkpoints/")
59
 
60
  args_path = hf_hub_download(
 
206
  return json_short
207
 
208
  BASE_DIR = os.path.expanduser("~/.sinatools")
209
+ NER_DIR = os.path.join(BASE_DIR, "Wj27012000.tar")
210
 
211
  # Paths expected by sinatools
212
  RELATION_MODEL_DIR = os.path.join(BASE_DIR, "relation_model")
 
224
  )
225
 
226
 
227
+ if not os.path.exists(NER_DIR):
228
+ os.makedirs(NER_DIR, exist_ok=True)
229
+
230
+ nested_repo_path = snapshot_download(
231
+ repo_id="SinaLab/Nested"
232
+ )
233
+
234
+ # Copy tag_vocab.pkl to expected location
235
+ src_vocab = os.path.join(nested_repo_path, "Nested", "utils", "tag_vocab.pkl")
236
+ dst_vocab = os.path.join(NER_DIR, "tag_vocab.pkl")
237
+
238
+ if os.path.exists(src_vocab):
239
+ shutil.copy(src_vocab, dst_vocab)
240
+
241
+
242
  from sinatools.relations.relation_extractor import relation_extraction
243
  from sinatools.relations.event_relation_extractor import event_argument_relation_extraction
244