"""Gradio demo for Major Entity Identification (MEI).

The app sends the user's text to a MEIRa model, renders the predicted
mention clusters with spaCy's displaCy span visualizer, and offers the
resulting HTML and the raw JSON output as downloadable files.
"""

import json
import tempfile
import time

import gradio as gr
import matplotlib.pyplot as plt
import spacy
from matplotlib.colors import to_hex
from spacy import displacy
from spacy.tokens import Doc, Span

from inference.model_inference import Inference

# Kept last so that any names it exports shadow earlier imports exactly as
# before.  ``MODELS`` (used below) is expected to come from here.
from configs import *  # noqa: F401,F403

DESC_MD = """ This space is a demo for Major Entity Identification (MEI) . MEI takes entities as additional input and aims to detect the mentions that refer only to these entities.

Place the text in the text box with a single phrase of a selected entity in double curly braces(example: a single instance of {{Ron}} if you want to track Ron). Note that you can select one phrase for each entity and multiple entities can be selected. Check out the example below for clarity.

Static: Uses an instance of: MEIRa-S model
Hybrid: Uses an instance of: MEIRa-H model

The demo provides a json file with clusters and an HTML file with visualizations. The visualizations are color-coded based on the clusters.
"""


def get_MEIRa_clusters(doc_name, text, model_type):
    """Run the selected MEIRa model on ``text`` and return its raw output.

    Args:
        doc_name: Name of the document, forwarded to the model.
        text: Input text forwarded to the model.
        model_type: Key into ``MODELS`` selecting which model to load.

    Returns:
        Whatever ``Inference.perform_coreference`` returns; the caller in
        this file reads the keys ``"tokenized_doc"``, ``"clusters"`` and
        ``"representative_names"`` from it.

    Raises:
        KeyError: If ``model_type`` is not a key of ``MODELS`` (assuming
            ``MODELS`` is a mapping -- confirm against ``configs``).
    """
    model_str = MODELS[model_type]
    # NOTE(review): a new Inference object is built on every request.  If
    # construction loads model weights this is expensive -- consider caching
    # per model_str once it is confirmed that Inference is safe to reuse.
    model = Inference(model_str)
    return model.perform_coreference(text, doc_name)


def _build_color_palette(labels):
    """Map each label to a hex colour sampled from the ``tab20`` colormap."""
    num_labels = len(labels)
    if num_labels == 0:
        return {}
    # ``plt.cm.get_cmap`` was removed in Matplotlib 3.9; ``plt.get_cmap``
    # accepts the same (name, lut) arguments.  Build the colormap once
    # instead of once per label.
    cmap = plt.get_cmap("tab20", num_labels)
    return {label: to_hex(cmap(i)) for i, label in enumerate(labels)}


def _write_temp_file(content, suffix):
    """Write ``content`` as UTF-8 to a persistent temp file; return its path."""
    # delete=False: the file has to outlive this function so that Gradio can
    # serve it for download afterwards.
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp_file:
        tmp_file.write(content.encode("utf-8"))
        return tmp_file.name


def coref_visualizer(doc_name, text, model_type):
    """Run MEI on ``text`` and produce downloadable HTML and JSON results.

    Args:
        doc_name: Name of the document, forwarded to the model.
        text: Text to analyse.
        model_type: Key into ``MODELS`` selecting the model variant.

    Returns:
        A 4-tuple ``(html_path, json_path, html_button, json_button)``: the
        paths of the two temporary files, followed by visible
        ``gr.DownloadButton`` updates pointing at those files.
    """
    coref_output = get_MEIRa_clusters(doc_name, text, model_type)
    tokens = coref_output["tokenized_doc"]
    clusters = coref_output["clusters"]
    labels = coref_output["representative_names"]

    color_palette = _build_color_palette(labels)

    nlp = spacy.blank("en")
    doc = Doc(nlp.vocab, words=tokens)

    print("Tokens:", tokens, flush=True)
    print(color_palette)

    spans = []
    # The last cluster is deliberately skipped (presumably the catch-all
    # "others" cluster with no representative name -- TODO confirm).
    for cluster_ind, cluster in enumerate(clusters[:-1]):
        label = labels[cluster_ind]
        for (start, end), _mention in cluster:
            # ``end + 1``: the incoming end index is treated as inclusive,
            # while spaCy's Span end is exclusive.
            spans.append(Span(doc, start, end + 1, label=label))

    doc.spans["coref_spans"] = spans

    print("Rendering the visualization...")

    html = displacy.render(
        doc,
        style="span",
        options={
            "spans_key": "coref_spans",
            "colors": color_palette,
        },
        jupyter=False,
    )

    html_file = _write_temp_file(html, ".html")
    json_file = _write_temp_file(json.dumps(coref_output), ".json")

    print("HTML file:", html_file)
    print("JSON file:", json_file)

    return (
        html_file,
        json_file,
        gr.DownloadButton(value=html_file, visible=True),
        gr.DownloadButton(value=json_file, visible=True),
    )


def download_html():
    """Return an update that hides a download button."""
    return gr.DownloadButton(visible=False)


def download_json():
    """Return an update that hides a download button."""
    return gr.DownloadButton(visible=False)


# Explicit encoding so that loading the example does not depend on the
# platform's default locale.
with open("example_harry.txt", "r", encoding="utf-8") as f:
    example_harry = f.read()

options = ["static", "hybrid"]

with gr.Blocks() as demo:
    # Hidden holders for the generated files; the download buttons are
    # created hidden and made visible by the values coref_visualizer returns.
    html_file = gr.File(visible=False)
    json_file = gr.File(visible=False)
    html_button = gr.DownloadButton("Download HTML", visible=False)
    json_button = gr.DownloadButton("Download JSON", visible=False)

    # Registered without a handler, exactly as in the original wiring.
    html_button.click()
    json_button.click()

    iface = gr.Interface(
        fn=coref_visualizer,
        inputs=[
            gr.Textbox(lines=1, placeholder="Enter document name:"),
            gr.Textbox(lines=10, placeholder="Enter text for coreference resolution:"),
            gr.Radio(choices=options, label="Select an Option"),
        ],
        outputs=[
            html_file,
            json_file,
            html_button,
            json_button,
        ],
        title="MEI Visualizer",
        description=DESC_MD,
        examples=[
            [
                "example",
                "{{Harry}} went to Hogwarts to meet Hemoine and {{Ron}} . He also met Ron's mother at the railway station.",
                "static",
            ],
            ["example_large", example_harry, "static"],
        ],
    )

demo.launch(debug=True)