File size: 2,964 Bytes
62240fd
795587f
7916f53
4cb376e
8295f3b
d59f693
 
0c23ff5
1e7b155
6a541d9
3d70b45
59a84ff
608e3f9
d59f693
7f319ed
 
d59f693
 
7f319ed
 
 
 
186dc98
abad191
4338e7c
 
 
420f35b
 
1e7b155
608e3f9
bbdd0e4
 
 
460a31b
bbdd0e4
 
 
 
 
d59f693
bbdd0e4
 
 
 
1e7b155
bbdd0e4
 
 
2c7b4a1
 
bbdd0e4
 
1e7b155
bbdd0e4
 
 
7f319ed
bbdd0e4
 
 
19cd698
dd3bb53
cd4dfb7
2d7bba1
19cd698
 
cd4dfb7
236e944
 
 
1db1d29
189ddb9
 
19cd698
cd4dfb7
19cd698
1e7b155
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import gradio as gr
import numpy as np
import os
from huggingface_hub import hf_hub_download
from camel_tools.data import CATALOGUE
from camel_tools.tagger.default import DefaultTagger
from camel_tools.disambig.bert import BERTUnfactoredDisambiguator

def predict_label(text):
    """Tag named entities in ``text`` with the staged span/entity ensemble.

    The input is whitespace-tokenised and scored by two span models:
    ``span_model`` on the raw tokens and ``msa_span_model`` additionally fed
    the POS tags produced by ``tagger``. Their span scores are concatenated,
    pooled, and passed to ``entity_model`` for entity classification.

    Relies on module-level names (the three models, ``tagger`` and the
    ``src.predict`` helpers) that are created in the ``__main__`` block.

    Args:
        text: Raw sentence from the Gradio input box (may be empty or None).

    Returns:
        The first element returned by ``pool_ent_scores`` (shown in the
        output text box), or ``""`` when ``text`` contains no tokens.
    """
    # Empty / whitespace-only / None input would hand zero-length sequences
    # to the models; short-circuit instead of surfacing a model error.
    tokens = text.split() if text else []
    if not tokens:
        return ""
    token_lens = [len(tokens)]

    span_scores = extract_spannet_scores(span_model, tokens, token_lens)
    pos_tags = tagger.tag(tokens)
    msa_span_scores = extract_spannet_scores(
        msa_span_model, tokens, token_lens, pos=pos_tags
    )

    # NOTE(review): the per-model pooled scores were never used downstream.
    # The calls are kept only in case pool_span_scores mutates its input in
    # place (its definition is not visible here) -- TODO confirm it is pure
    # and drop these two calls.
    pool_span_scores(span_scores, token_lens)
    pool_span_scores(msa_span_scores, token_lens)

    # Ensemble = both models' span scores concatenated, then pooled jointly.
    ensemble_span_scores = [*span_scores, *msa_span_scores]
    ensemble_pooled_scores = pool_span_scores(ensemble_span_scores, token_lens)

    ent_scores = extract_ent_scores(entity_model, tokens, ensemble_pooled_scores)
    # The second element (predicted tags) is not needed for display.
    combined_sequences, _ = pool_ent_scores(ent_scores, token_lens)

    return combined_sequences


if __name__ == '__main__':
    # Access token for the model-code repository, read from the environment
    # variable 'key' (None when it is not set).
    hf_token = os.environ.get('key')

    # Fetch the model-code modules from the Hub into ./src; they are imported
    # below as the ``src`` package.
    source_modules = (
        'network.py',
        'layers.py',
        'utils.py',
        'representation.py',
        'predict.py',
        'validate.py',
    )
    for module_file in source_modules:
        hf_hub_download(
            'nehalelkaref/stagedNER',
            filename=module_file,
            local_dir='src',
            token=hf_token,
        )

    # Download every CAMeL Tools data package, re-fetching even if present.
    CATALOGUE.download_package(
        "all",
        recursive=True,
        force=True,
        print_status=True,
    )

    # Late imports: these modules only exist after the download loop above.
    # They bind the module-level names that predict_label relies on.
    from src.predict import (
        extract_spannet_scores,
        extract_ent_scores,
        pool_span_scores,
        pool_ent_scores,
    )
    from src.network import SpanNet, EntNet
    from src.validate import entities_from_token_classes

    # POS tagger built on the pretrained MSA BERT disambiguator.
    disambig = BERTUnfactoredDisambiguator.pretrained('msa')
    tagger = DefaultTagger(disambig, 'pos')

    # Model checkpoints are not downloaded by this script; presumably they
    # ship with the Space alongside this file -- TODO confirm.
    span_path = 'models/span.model'
    msa_span_path = 'new_models/msa.best.model'
    entity_path = 'models/entity.msa.model'

    span_model = SpanNet.load_model(span_path)
    msa_span_model = SpanNet.load_model(msa_span_path)
    entity_model = EntNet.load_model(entity_path)

    sample_sentences = [
        "النشرة الإخبارية الصادرة عن الأونروا رقم 113 (1986/1/8).",
        "صورة لمدينة أريحا القديمة :تل السلطان",
        "صورة اطفال مخيم للاجئين الفلسطينيين ي لبنان",
    ]

    with gr.Blocks(theme='finlaymacklon/smooth_slate') as iface:
        example_input = gr.Textbox(label="Input Example", lines=3)
        prediction = gr.Text(label="Predicted Entities")

        gr.Interface(
            fn=predict_label,
            inputs=example_input,
            outputs=prediction,
            theme="smooth_slate",
            title="Flat Entity Classification for Levant Arabic",
        )
        gr.Examples(examples=sample_sentences, inputs=example_input)

    iface.launch(show_api=False)