import uvicorn
import threading
from collections import Counter
from typing import Dict, List, Optional

import pandas as pd
# import datasets
from pprint import pprint
import gradio as gr
from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification
from fastapi import FastAPI
from pydantic import BaseModel

# Define the FastAPI app
app = FastAPI()

# Module-level caches so the model and dataset are only loaded once per process
model_cache: Optional[object] = None
dataset_cache: Optional[object] = None


def load_model():
    """Load the fine-tuned NER model and tokenizer at startup."""
    tokenizer = AutoTokenizer.from_pretrained("LampOfSocrates/bert-cased-plodcw-sourav")
    model = AutoModelForTokenClassification.from_pretrained("LampOfSocrates/bert-cased-plodcw-sourav")

    # Print the label mapping the model was trained with
    id2label = model.config.id2label
    print(f"Can recognise the following labels {id2label}")

    # Wrap the model and tokenizer in a Hugging Face NER pipeline
    # ner_pipeline = pipeline("ner", model="dbmdz/bert-large-cased-finetuned-conll03-english")
    model = pipeline("ner", model=model, tokenizer=tokenizer)
    return model


def load_plod_cw_dataset():
    from datasets import load_dataset
    dataset = load_dataset("surrey-nlp/PLOD-CW")
    return dataset


def get_cached_data():
    global dataset_cache
    if dataset_cache is None:
        dataset_cache = load_plod_cw_dataset()
    return dataset_cache


def get_cached_model():
    global model_cache
    if model_cache is None:
        model_cache = load_model()
    return model_cache


# Cache the model when the server starts
model = get_cached_model()
# plod_cw = get_cached_data()


class Entity(BaseModel):
    entity: str
    score: float
    start: int
    end: int
    word: str


class NERResponse(BaseModel):
    entities: List[Entity]


class NERRequest(BaseModel):
    text: str


@app.get("/hello")
def read_root():
    """Useful for testing connections."""
    return {"message": "Hello, World!"}


@app.post("/ner", response_model=NERResponse)
def get_entities(request: NERRequest):
    """Invoked during API testing."""
    print(request)
    model = get_cached_model()
    # Use the NER pipeline to detect entities
    entities = model(request.text)
    print(entities[0].keys())
    # Convert the raw pipeline output to the response model
    response_entities = [Entity(**entity) for entity in entities]
    print(response_entities[0])
    return NERResponse(entities=response_entities)
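# A minimal client-side sketch (not part of the original app) for exercising the
# /ner endpoint above once the FastAPI routes are being served, e.g. with
# `uvicorn app:app --port 8000`; the module name, URL and port are assumptions.
#
#   import requests
#   resp = requests.post("http://localhost:8000/ner", json={"text": "RV stands for residual volume ."})
#   for ent in resp.json()["entities"]:
#       print(ent["word"], ent["entity"], round(ent["score"], 3))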
def get_color_for_label(label: str) -> str:
    # Define a mapping of labels to colors
    color_mapping = {
        "I-LF": "red",
        "B-LF": "pink",
        "B-AC": "blue",
        "B-O": "green",
        # Add more labels and colors as needed
    }
    return color_mapping.get(label, "black")  # Default to black if label not found


# Define the Gradio interface function
def ner_demo(text):
    """Invoked when the page renders a prediction."""
    model = get_cached_model()
    entities = model(text)
    print("Entities detected {}".format(Counter([entity["entity"] for entity in entities])))

    all_html = ""
    last_index = 0
    for entity in entities:
        start, end, label = entity["start"], entity["end"], entity["entity"]
        color = get_color_for_label(label)
        entity_text = text[start:end]
        # Wrap the entity in a coloured <span> so the HTML output renders it highlighted
        colored_entity = f'<span style="color: {color}">{entity_text}</span>'
        # Append the text before the entity, then the coloured entity itself
        all_html += text[last_index:start]
        all_html += colored_entity
        # Move the cursor past this entity
        last_index = end
    # Append the remaining text after the last entity
    all_html += text[last_index:]
    return all_html


bo_color = get_color_for_label("B-O")
bac_color = get_color_for_label("B-AC")
ilf_color = get_color_for_label("I-LF")
blf_color = get_color_for_label("B-LF")

PROJECT_INTRO = f"""This is a HF Spaces hosted Gradio App built by NLP Group 27.

The model has been trained on the surrey-nlp/PLOD-CW dataset. The following entities are recognized:
<span style="color: {bo_color}">B-O</span> <span style="color: {bac_color}">B-AC</span> <span style="color: {ilf_color}">I-LF</span> <span style="color: {blf_color}">B-LF</span>, and the rest.
"""
def echo(text, request: gr.Request):
    """Echo details of the client request back as HTML."""
    res = "<br>"
    if request:
        res += f"Request headers dictionary: {request.headers}<br><br>"
        res += f"IP address: {request.client.host}<br><br>"
        res += f"Query parameters: {dict(request.query_params)}<br><br>"
    res += "<br>"
    return res


def sample_data(text):
    text = "The red dots represents LCI , the bright yellow rectangle represents RV , and the black triangle represents the /TLCnLCI"
    # dat = get_cached_data()
    # df = dat['test']['tokens'].sample(5)
    data = {
        "Text": [text],
        "Length": [len(text)],
    }
    df = pd.DataFrame(data)
    return df


# Create the Gradio interface (an earlier single-component version; it is
# superseded by the Blocks layout below, which rebinds `demo` before launch)
demo = gr.Interface(
    fn=ner_demo,
    inputs=gr.Textbox(lines=10, placeholder="Enter text here..."),
    outputs="html",
    # outputs=gr.JSON(),
    title="Named Entity Recognition on PLOD-CW",
    description=f"{PROJECT_INTRO}\n\nEnter text to extract named entities using a NER model.",
)

with gr.Blocks() as demo:
    gr.Markdown("# Named Entity Recognition on PLOD-CW")
    gr.Markdown(PROJECT_INTRO)
    gr.Markdown("### Enter text to extract named entities using a NER model.")
    text_input = gr.Textbox(lines=10, placeholder="Enter text here...", label="Input Text")
    html_output = gr.HTML(label="HTML Output")
    with gr.Row():
        submit_button = gr.Button("Submit")
        echo_button = gr.Button("Echo Client")
        sample_button = gr.Button("Sample PLOD_CW")
    sample_output = gr.Dataframe(label="Sample Table")
    echo_output = gr.HTML(label="HTML Output")

    submit_button.click(ner_demo, inputs=text_input, outputs=html_output)
    echo_button.click(echo, inputs=text_input, outputs=echo_output)
    sample_button.click(sample_data, inputs=text_input, outputs=sample_output)

# Launch the Gradio app
demo.launch(server_name="0.0.0.0", server_port=7860)