"""Gradio demo for flat entity classification on Levant Arabic text.

When run as a script, this file downloads the model source code from the
Hugging Face Hub, downloads the CAMeL Tools data packages, loads two span
models and one entity model, and serves a small web UI around
``predict_label``.
"""
import os

import gradio as gr
import numpy as np
from huggingface_hub import hf_hub_download
from camel_tools.data import CATALOGUE
from camel_tools.tagger.default import DefaultTagger
from camel_tools.disambig.bert import BERTUnfactoredDisambiguator


def predict_label(text):
    """Run the staged span/entity pipeline on one input string.

    The text is split on whitespace, scored by the two span models (the
    second one also receives POS tags from ``tagger``), and the combined
    span scores are pooled and handed to the entity model.

    Relies on the module-level ``span_model``, ``msa_span_model``,
    ``entity_model`` and ``tagger`` objects and on the helper functions
    imported from ``src.predict``; all of them are created in the
    ``__main__`` section below.

    Args:
        text: Raw input string from the UI.

    Returns:
        The first value returned by ``pool_ent_scores`` (the combined
        sequences), or an empty string when ``text`` contains no tokens.
    """
    ip = text.split()
    if not ip:
        # NOTE(review): the models are presumably not meant to receive an
        # empty token list; return an empty prediction instead of calling them.
        return ""
    ip_len = [len(ip)]

    # Scores from the first span model.
    span_scores = extract_spannet_scores(span_model, ip, ip_len)
    # Pooled per-model scores are computed as in the original code although
    # only the ensemble pooling below feeds the entity model.
    span_pooled_scores = pool_span_scores(span_scores, ip_len)

    # Scores from the second span model, which additionally gets POS tags.
    pos_tags = tagger.tag(ip)
    msa_span_scores = extract_spannet_scores(
        msa_span_model, ip, ip_len, pos=pos_tags)
    msa_pooled_scores = pool_span_scores(msa_span_scores, ip_len)

    # Concatenate the scores of both span models and pool them together.
    ensemble_span_scores = [*span_scores, *msa_span_scores]
    ensemble_pooled_scores = pool_span_scores(ensemble_span_scores, ip_len)

    ent_scores = extract_ent_scores(entity_model, ip, ensemble_pooled_scores)
    combined_sequences, _ = pool_ent_scores(ent_scores, ip_len)
    return combined_sequences


if __name__ == '__main__':
    # Access token for the Hub repository, read from the environment.
    space_key = os.environ.get('key')
    filenames = ['network.py', 'layers.py', 'utils.py',
                 'representation.py', 'predict.py', 'validate.py']
    for filename in filenames:
        hf_hub_download('nehalelkaref/stagedNER',
                        filename=filename,
                        local_dir='src',
                        token=space_key)

    CATALOGUE.download_package("all",
                               recursive=True,
                               force=True,
                               print_status=True)

    # These imports must stay below the download loop: the ``src`` modules
    # only exist locally once the files above have been fetched.
    from src.predict import extract_spannet_scores, extract_ent_scores, pool_span_scores, pool_ent_scores
    from src.network import SpanNet, EntNet
    from src.validate import entities_from_token_classes

    disambiguator = BERTUnfactoredDisambiguator.pretrained('msa')
    tagger = DefaultTagger(disambiguator, 'pos')

    span_path = 'models/span.model'
    msa_span_path = 'new_models/msa.best.model'
    entity_path = 'models/entity.msa.model'

    span_model = SpanNet.load_model(span_path)
    msa_span_model = SpanNet.load_model(msa_span_path)
    entity_model = EntNet.load_model(entity_path)

    with gr.Blocks(theme='finlaymacklon/smooth_slate') as iface:
        example_input = gr.Textbox(label="Input Example", lines=3)
        prediction = gr.Text(label="Predicted Entities")
        gr.Interface(fn=predict_label,
                     inputs=example_input,
                     outputs=prediction,
                     theme="smooth_slate",
                     title="Flat Entity Classification for Levant Arabic")
        gr.Examples(
            examples=["النشرة الإخبارية الصادرة عن الأونروا رقم 113 (1986/1/8).",
                      "صورة لمدينة أريحا القديمة :تل السلطان",
                      "صورة اطفال مخيم للاجئين الفلسطينيين ي لبنان"],
            inputs=example_input)

    iface.launch(show_api=False)