# MEIRa / app.py
# Last commit: 0d75905 "large_example" by KawshikManikantan
import functools
import json
import tempfile
import time

import gradio as gr
import matplotlib.pyplot as plt
import spacy
from matplotlib.colors import to_hex
from spacy import displacy
from spacy.tokens import Doc, Span

from inference.model_inference import Inference
from configs import *
# HTML/Markdown description rendered at the top of the Gradio interface
# (passed as `description=` to gr.Interface).
DESC_MD = """
<font size="3">
This space is a demo for <a href="https://arxiv.org/abs/2406.14654"> Major Entity Identification (MEI) </a>. MEI takes entities as additional input and aims to detect the mentions that refer only to these entities. <br/>
<br/>
Place the text in the text box and wrap a single phrase for each selected entity in double curly braces (for example, mark a single instance as {{Ron}} if you want to track Ron). Mark exactly one phrase per entity; multiple entities can be marked. See the examples below. <br/>
<br/>
Static: uses an instance of the MEIRa-S model <br/>
Hybrid: uses an instance of the MEIRa-H model <br/>
<br/>
The demo provides a JSON file with the clusters and an HTML file with visualizations, color-coded by cluster. <br/>
</font>
"""
@functools.lru_cache(maxsize=None)
def _load_model(model_str):
    """Build the ``Inference`` wrapper for ``model_str`` once and reuse it afterwards."""
    # The key space is bounded by the entries of MODELS, so an unbounded cache is safe.
    # NOTE(review): assumes Inference.perform_coreference keeps no per-document state
    # between calls, so one instance can serve every request -- confirm.
    return Inference(model_str)


def get_MEIRa_clusters(doc_name, text, model_type):
    """Run MEIRa major-entity coreference over ``text``.

    Args:
        doc_name: Name for the document, forwarded to ``perform_coreference``.
        text: Input text; target entities are marked with ``{{...}}``.
        model_type: Key into ``MODELS`` (the UI offers "static" and "hybrid").

    Returns:
        Whatever ``Inference.perform_coreference`` returns. ``coref_visualizer``
        reads the keys "tokenized_doc", "clusters" and "representative_names" from it.

    Raises:
        gr.Error: If ``model_type`` is not a key of ``MODELS`` (e.g. no option selected).
    """
    if model_type not in MODELS:
        raise gr.Error(
            f"Unknown model type {model_type!r}; choose one of {list(MODELS)}."
        )
    # Reuse a cached model instead of rebuilding it on every request.
    model = _load_model(MODELS[model_type])
    return model.perform_coreference(text, doc_name)
def _build_color_palette(labels):
    """Map each entity label to a distinct hex colour sampled from ``tab20``."""
    if not labels:
        return {}
    # pyplot.get_cmap(name, lut) resamples the colormap into len(labels) entries.
    # The previous plt.cm.get_cmap spelling (matplotlib.cm.get_cmap) was removed
    # in matplotlib 3.9.
    cmap = plt.get_cmap("tab20", len(labels))
    return {label: to_hex(cmap(i)) for i, label in enumerate(labels)}


def _write_temp_file(suffix, content):
    """Write ``content`` as UTF-8 to a persistent temp file and return its path."""
    # delete=False: the file must outlive this function so Gradio can serve it.
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp_file:
        tmp_file.write(content.encode("utf-8"))
        return tmp_file.name


def coref_visualizer(doc_name, text, model_type):
    """Run MEIRa on ``text`` and render the clusters as a displaCy span view.

    Args:
        doc_name: Name for the document, forwarded to the model.
        text: Input text with target entities marked in ``{{...}}``.
        model_type: Model key ("static" or "hybrid" in the UI).

    Returns:
        A 4-tuple ``(html_path, json_path, html_button, json_button)``:
        paths to temp files holding the rendered HTML and the raw model output
        as JSON, plus two visible ``gr.DownloadButton`` updates pointing at them.
    """
    coref_output = get_MEIRa_clusters(doc_name, text, model_type)
    tokens = coref_output["tokenized_doc"]
    clusters = coref_output["clusters"]
    labels = coref_output["representative_names"]

    # One tab20 colour per entity label, keyed by label as displaCy expects.
    color_palette = _build_color_palette(labels)

    nlp = spacy.blank("en")
    doc = Doc(nlp.vocab, words=tokens)
    print("Tokens:", tokens, flush=True)
    print(color_palette)

    spans = []
    # NOTE(review): the last cluster is skipped -- presumably it holds mentions not
    # attributed to any selected entity; confirm against perform_coreference.
    for cluster_ind, cluster in enumerate(clusters[:-1]):
        label = labels[cluster_ind]
        for (start, end), _mention in cluster:
            # `end` is treated as an inclusive token index, hence the +1.
            spans.append(Span(doc, start, end + 1, label=label))
    doc.spans["coref_spans"] = spans

    print("Rendering the visualization...")
    html = displacy.render(
        doc,
        style="span",
        options={
            "spans_key": "coref_spans",
            "colors": color_palette,
        },
        jupyter=False,
    )

    html_file = _write_temp_file(".html", html)
    json_file = _write_temp_file(".json", json.dumps(coref_output))
    print("HTML file:", html_file)
    print("JSON file:", json_file)
    return (
        html_file,
        json_file,
        gr.DownloadButton(value=html_file, visible=True),
        gr.DownloadButton(value=json_file, visible=True),
    )
def download_html():
    """Return a ``DownloadButton`` update with ``visible=False`` (hides the HTML button)."""
    hidden_button = gr.DownloadButton(visible=False)
    return hidden_button
def download_json():
    """Return a ``DownloadButton`` update with ``visible=False`` (hides the JSON button)."""
    hidden_button = gr.DownloadButton(visible=False)
    return hidden_button
# Long example document shown in the examples table; read with an explicit
# encoding so behaviour does not depend on the platform default.
with open("example_harry.txt", "r", encoding="utf-8") as f:
    example_harry = f.read()

options = ["static", "hybrid"]

with gr.Blocks() as demo:
    # Hidden output slots plus download buttons; coref_visualizer fills the files
    # and returns updates that make the buttons visible.
    html_file = gr.File(visible=False)
    json_file = gr.File(visible=False)
    html_button = gr.DownloadButton("Download HTML", visible=False)
    json_button = gr.DownloadButton("Download JSON", visible=False)
    # Hide each button again after it is clicked. This follows Gradio's
    # DownloadButton pattern; the original bare .click() registered no handler.
    html_button.click(download_html, None, html_button)
    json_button.click(download_json, None, json_button)
    gr.Interface(
        fn=coref_visualizer,
        inputs=[
            gr.Textbox(lines=1, placeholder="Enter document name:"),
            gr.Textbox(lines=10, placeholder="Enter text for coreference resolution:"),
            # Default selection so a submit without choosing never sends None.
            gr.Radio(choices=options, value=options[0], label="Select an Option"),
        ],
        outputs=[
            html_file,
            json_file,
            html_button,
            json_button,
        ],
        title="MEI Visualizer",
        description=DESC_MD,
        examples=[
            [
                "example",
                "{{Harry}} went to Hogwarts to meet Hermione and {{Ron}} . He also met Ron's mother at the railway station.",
                "static",
            ],
            ["example_large", example_harry, "static"],
        ],
    )

if __name__ == "__main__":
    demo.launch(debug=True)