nehalelkaref commited on
Commit
bbdd0e4
1 Parent(s): 236e944

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -23
app.py CHANGED
@@ -29,36 +29,36 @@ def predict_label(text):
29
 
30
  if __name__ == '__main__':
31
 
32
- # space_key = os.environ.get('key')
33
- # filenames = ['network.py', 'layers.py', 'utils.py',
34
- # 'representation.py', 'predict.py', 'validate.py']
35
 
36
- # for file in filenames:
37
- # hf_hub_download('nehalelkaref/stagedNER',
38
- # filename=file,
39
- # local_dir='src',
40
- # token=space_key)
41
 
42
- # CATALOGUE.download_package("all",
43
- # recursive=True,
44
- # force=True,
45
- # print_status=True)
46
 
47
- # from src.predict import extract_spannet_scores,extract_ent_scores,pool_span_scores,pool_ent_scores
48
- # from src.network import SpanNet, EntNet
49
- # from src.validate import entities_from_token_classes
50
 
51
 
52
- # diasmbig = BERTUnfactoredDisambiguator.pretrained('msa')
53
- # tagger = DefaultTagger(diasmbig, 'pos')
54
 
55
- # span_path = 'models/span.model'
56
- # msa_span_path = 'new_models/msa.best.model'
57
- # entity_path= 'models/entity.msa.model'
58
 
59
- # span_model = SpanNet.load_model(span_path)
60
- # msa_span_model = SpanNet.load_model(msa_span_path)
61
- # entity_model = EntNet.load_model(entity_path)
62
 
63
 
64
 
 
29
 
30
  if __name__ == '__main__':
31
 
32
+ space_key = os.environ.get('key')
33
+ filenames = ['network.py', 'layers.py', 'utils.py',
34
+ 'representation.py', 'predict.py', 'validate.py']
35
 
36
+ for file in filenames:
37
+ hf_hub_download('nehalelkaref/stagedNER',
38
+ filename=file,
39
+ local_dir='src',
40
+ token=space_key)
41
 
42
+ CATALOGUE.download_package("all",
43
+ recursive=True,
44
+ force=True,
45
+ print_status=True)
46
 
47
+ from src.predict import extract_spannet_scores,extract_ent_scores,pool_span_scores,pool_ent_scores
48
+ from src.network import SpanNet, EntNet
49
+ from src.validate import entities_from_token_classes
50
 
51
 
52
+ diasmbig = BERTUnfactoredDisambiguator.pretrained('msa')
53
+ tagger = DefaultTagger(diasmbig, 'pos')
54
 
55
+ span_path = 'models/span.model'
56
+ msa_span_path = 'new_models/msa.best.model'
57
+ entity_path= 'models/entity.msa.model'
58
 
59
+ span_model = SpanNet.load_model(span_path)
60
+ msa_span_model = SpanNet.load_model(msa_span_path)
61
+ entity_model = EntNet.load_model(entity_path)
62
 
63
 
64