nehalelkaref's picture
Update app.py
c70f0bf verified
raw
history blame contribute delete
No virus
2.96 kB
import gradio as gr
import numpy as np
import os
from huggingface_hub import hf_hub_download
from camel_tools.data import CATALOGUE
from camel_tools.tagger.default import DefaultTagger
from camel_tools.disambig.bert import BERTUnfactoredDisambiguator
def predict_label(text):
    """Tag named entities in *text* with the staged span/entity ensemble.

    The input is split on whitespace into tokens. Two span models score
    candidate spans (the MSA model additionally receives POS tags from
    ``tagger``); their raw span scores are concatenated, pooled together,
    and passed to ``entity_model``. Relies on the module-level models,
    ``tagger`` and ``src.predict`` helpers bound in the ``__main__`` block.

    Args:
        text: Raw input sentence from the Gradio textbox.

    Returns:
        The first value returned by ``pool_ent_scores`` (displayed in the
        "Predicted Entities" box), or ``""`` when *text* is empty, ``None``
        or whitespace-only.
    """
    # Guard: blank input would otherwise push zero-length sequences
    # (lengths == [0]) through every model.
    tokens = text.split() if text else []
    if not tokens:
        return ""
    lengths = [len(tokens)]

    # Raw span scores from the general span model.
    span_scores = extract_spannet_scores(span_model, tokens, lengths)

    # Raw span scores from the MSA span model, conditioned on POS tags.
    pos_tags = tagger.tag(tokens)
    msa_span_scores = extract_spannet_scores(msa_span_model, tokens, lengths, pos=pos_tags)

    # Ensemble: concatenate both models' raw scores and pool them once.
    # Only this pooled ensemble is consumed downstream, so the per-model
    # pooling is not computed.
    ensemble_span_scores = [*span_scores, *msa_span_scores]
    ensemble_pooled_scores = pool_span_scores(ensemble_span_scores, lengths)

    ent_scores = extract_ent_scores(entity_model, tokens, ensemble_pooled_scores)
    # The second value (predicted entity tags) is not needed for display.
    combined_sequences, _ent_pred_tags = pool_ent_scores(ent_scores, lengths)
    return combined_sequences
if __name__ == '__main__':
    # Optional Hugging Face access token, read from the `key` environment
    # variable (None when unset).
    space_key = os.environ.get('key')

    # Fetch the model source modules from the Hub repo into ./src so they
    # can be imported below.
    filenames = ['network.py', 'layers.py', 'utils.py',
                 'representation.py', 'predict.py', 'validate.py']
    for fname in filenames:
        hf_hub_download('nehalelkaref/stagedNER',
                        filename=fname,
                        local_dir='src',
                        token=space_key)

    # Download every CAMeL Tools data package; force=True re-downloads them
    # even if they are already installed.
    CATALOGUE.download_package("all",
                               recursive=True,
                               force=True,
                               print_status=True)

    # These modules only exist after the downloads above, so they are
    # imported here. They bind the module-level names predict_label uses.
    from src.predict import extract_spannet_scores, extract_ent_scores, pool_span_scores, pool_ent_scores
    from src.network import SpanNet, EntNet
    from src.validate import entities_from_token_classes  # NOTE(review): unused in this file

    # POS tagger (fed by an MSA BERT disambiguator) used by the MSA span model.
    disambig = BERTUnfactoredDisambiguator.pretrained('msa')
    tagger = DefaultTagger(disambig, 'pos')

    span_path = 'models/span.model'
    msa_span_path = 'new_models/msa.best.model'
    entity_path = 'models/entity.msa.model'
    span_model = SpanNet.load_model(span_path)
    msa_span_model = SpanNet.load_model(msa_span_path)
    entity_model = EntNet.load_model(entity_path)

    with gr.Blocks(theme='finlaymacklon/smooth_slate') as iface:
        example_input = gr.Textbox(label="Input Example", lines=3)
        prediction = gr.Text(label="Predicted Entities")
        # The outer Blocks theme governs rendering. A theme argument here is
        # redundant, and the unqualified name "smooth_slate" is not a built-in
        # Gradio theme, so it is not passed.
        gr.Interface(fn=predict_label, inputs=example_input,
                     outputs=prediction,
                     title="Flat Entity Classification for Levant Arabic")
        gr.Examples(
            examples=["النشرة الإخبارية الصادرة عن الأونروا رقم 113 (1986/1/8).",
                      "صورة لمدينة أريحا القديمة :تل السلطان",
                      "صورة اطفال مخيم للاجئين الفلسطينيين ي لبنان"],
            inputs=example_input)
    iface.launch(show_api=False)