|
import gradio as gr |
|
import numpy as np |
|
import os |
|
from huggingface_hub import hf_hub_download |
|
from camel_tools.data import CATALOGUE |
|
from camel_tools.tagger.default import DefaultTagger |
|
from camel_tools.disambig.bert import BERTUnfactoredDisambiguator |
|
|
|
def predict_label(text):
    """Run the staged NER pipeline on one whitespace-tokenised input.

    The text is split on whitespace and scored by both span models (the
    second one additionally receives the POS tags produced by ``tagger``).
    The two sets of span scores are concatenated, pooled, and handed to
    the entity model, whose scores are pooled into the final result.

    Relies on module-level objects created in the script's ``__main__``
    section: ``span_model``, ``msa_span_model``, ``entity_model``,
    ``tagger`` and the helpers imported from ``src.predict``.

    Args:
        text: Raw input sentence. ``None``, empty, or whitespace-only
            input short-circuits to an empty result instead of pushing a
            zero-length sequence through the models.

    Returns:
        The first value returned by ``pool_ent_scores`` (the combined
        sequences), or ``''`` when there are no tokens to label.
    """
    tokens = text.split() if text else []
    if not tokens:
        # Nothing to tag: skip the models entirely.
        return ''

    lengths = [len(tokens)]

    span_scores = extract_spannet_scores(span_model, tokens, lengths)
    # NOTE(review): the per-model pooled scores are never read; the call is
    # kept in case pool_span_scores mutates its input -- confirm and drop.
    pool_span_scores(span_scores, lengths)

    pos_tags = tagger.tag(tokens)
    msa_span_scores = extract_spannet_scores(
        msa_span_model, tokens, lengths, pos=pos_tags)
    # NOTE(review): same as above -- result unused, call kept deliberately.
    pool_span_scores(msa_span_scores, lengths)

    # Ensemble: scores of the first span model followed by the second's.
    ensemble_span_scores = [*span_scores, *msa_span_scores]
    ensemble_pooled_scores = pool_span_scores(ensemble_span_scores, lengths)

    ent_scores = extract_ent_scores(entity_model, tokens,
                                    ensemble_pooled_scores)
    # The predicted tag sequence (second element) is not needed by the UI.
    combined_sequences, _ = pool_ent_scores(ent_scores, lengths)

    return combined_sequences
|
|
|
|
|
if __name__ == '__main__':
    # Hub access token, read from the environment variable named 'key'
    # (None when the variable is not set).
    hub_token = os.environ.get('key')

    # Pull the pipeline's source modules into ./src so that the
    # `from src... import` statements below can resolve them.
    source_files = ['network.py', 'layers.py', 'utils.py',
                    'representation.py', 'predict.py', 'validate.py']
    for source_file in source_files:
        hf_hub_download('nehalelkaref/stagedNER',
                        filename=source_file,
                        local_dir='src',
                        token=hub_token)

    # Download the "all" CAMeL Tools catalogue package (recursive, forced).
    CATALOGUE.download_package("all",
                               recursive=True,
                               force=True,
                               print_status=True)

    # These imports must come after the download loop above, which is what
    # places the modules under ./src.
    from src.predict import (
        extract_spannet_scores,
        extract_ent_scores,
        pool_span_scores,
        pool_ent_scores,
    )
    from src.network import SpanNet, EntNet
    from src.validate import entities_from_token_classes

    # POS tagger built on the pretrained 'msa' BERT disambiguator;
    # predict_label reads `tagger` as a module-level name.
    disambiguator = BERTUnfactoredDisambiguator.pretrained('msa')
    tagger = DefaultTagger(disambiguator, 'pos')

    # Models read by predict_label as module-level names.
    span_model = SpanNet.load_model('models/span.model')
    msa_span_model = SpanNet.load_model('new_models/msa.best.model')
    entity_model = EntNet.load_model('models/entity.msa.model')

    # Sample sentences offered to the user beneath the interface.
    sample_sentences = [
        "النشرة الإخبارية الصادرة عن الأونروا رقم 113 (1986/1/8).",
        "صورة لمدينة أريحا القديمة :تل السلطان",
        "صورة اطفال مخيم للاجئين الفلسطينيين ي لبنان",
    ]

    with gr.Blocks(theme='finlaymacklon/smooth_slate') as iface:
        example_input = gr.Textbox(label="Input Example", lines=3)
        prediction = gr.Text(label="Predicted Entities")

        gr.Interface(
            fn=predict_label,
            inputs=example_input,
            outputs=prediction,
            theme="smooth_slate",
            title="Flat Entity Classification for Levant Arabic",
        )
        gr.Examples(examples=sample_sentences, inputs=example_input)

    iface.launch(show_api=False)
|
|