Spaces:
Sleeping
Sleeping
import time | |
import spacy | |
import json | |
import gradio as gr | |
from spacy.tokens import Doc, Span | |
from spacy import displacy | |
import matplotlib.pyplot as plt | |
from matplotlib.colors import to_hex | |
import tempfile | |
from inference.model_inference import Inference | |
from configs import * | |
# HTML/Markdown blurb passed as ``description=`` to the Gradio interface below.
# The text is user-facing: keep wording and markup exactly as-is.
DESC_MD = """
<font size="3">
This space is a demo for <a href="https://arxiv.org/abs/2406.14654"> Major Entity Identification (MEI) </a>. MEI takes entities as additional input and aims to detect the mentions that refer only to these entities. <br/>
<br/>
Place the text in the text box with a single phrase of a selected entity in double curly braces(example: a single instance of {{Ron}} if you want to track Ron). Note that you can select one phrase for each entity and multiple entities can be selected. Check out the example below for clarity. <br/>
<br/>
Static: Uses an instance of: MEIRa-S model <br/>
Hybrid: Uses an instance of: MEIRa-H model <br/>
<br/>
The demo provides a json file with clusters and an HTML file with visualizations. The visualizations are color-coded based on the clusters. <br/>
</font>
"""
def get_MEIRa_clusters(doc_name, text, model_type):
    """Run the selected MEIRa model on ``text`` and return its raw output.

    Args:
        doc_name: Document name, forwarded to ``perform_coreference``.
        text: Input text, forwarded to ``perform_coreference``.
        model_type: Key into ``MODELS`` selecting which model to build.

    Returns:
        Whatever ``Inference.perform_coreference`` returns, unmodified.

    Raises:
        gr.Error: If ``model_type`` is not a key of ``MODELS`` (for example
            ``None``), so the UI shows a readable message instead of a bare
            ``KeyError``.
    """
    try:
        model_str = MODELS[model_type]
    except KeyError as err:
        raise gr.Error(
            f"Unknown model type {model_type!r}; please select one of the "
            "available options."
        ) from err

    # NOTE(review): a fresh Inference object is constructed on every call.
    # If construction is expensive and the object is safe to reuse, consider
    # caching one instance per model string -- confirm with the Inference API.
    model = Inference(model_str)
    return model.perform_coreference(text, doc_name)
def coref_visualizer(doc_name, text, model_type):
    """Run coreference on ``text`` and write HTML/JSON results to temp files.

    The model output is rendered with displaCy's ``span`` style, one colour
    per representative entity name, and saved as an HTML file. The raw model
    output is saved next to it as JSON.

    Args:
        doc_name: Document name, forwarded to ``get_MEIRa_clusters``.
        text: Text to analyse, forwarded to ``get_MEIRa_clusters``.
        model_type: Model selector, forwarded to ``get_MEIRa_clusters``.

    Returns:
        A 4-tuple ``(html_path, json_path, html_button, json_button)``: the
        two temp-file paths followed by two visible ``gr.DownloadButton``
        objects whose values are those paths.
    """
    coref_output = get_MEIRa_clusters(doc_name, text, model_type)
    tokens = coref_output["tokenized_doc"]
    clusters = coref_output["clusters"]
    labels = coref_output["representative_names"]

    # One colour per label from "tab20" resampled to len(labels) entries.
    # ``plt.cm.get_cmap`` was deprecated in Matplotlib 3.7 and removed in 3.9;
    # ``plt.get_cmap`` accepts the same (name, lut) arguments. The colormap is
    # resolved once rather than once per label, and not at all when there are
    # no labels (matching the original, which never built one in that case).
    color_palette = {}
    if labels:
        cmap = plt.get_cmap("tab20", len(labels))
        color_palette = {
            label: to_hex(cmap(i)) for i, label in enumerate(labels)
        }

    # A blank pipeline is enough: only its vocab is needed to build the Doc
    # from the tokens the model already produced.
    nlp = spacy.blank("en")
    doc = Doc(nlp.vocab, words=tokens)
    print("Tokens:", tokens, flush=True)
    print(color_palette)

    spans = []
    # The last cluster is deliberately skipped.
    # NOTE(review): presumably it holds mentions outside the selected
    # entities and has no entry in ``labels`` -- confirm against Inference.
    for cluster_ind, cluster in enumerate(clusters[:-1]):
        label = labels[cluster_ind]
        # Mention bounds are treated as inclusive, hence ``end + 1`` for
        # spaCy's exclusive span end.
        spans.extend(
            Span(doc, start, end + 1, label=label)
            for (start, end), _mention in cluster
        )
    doc.spans["coref_spans"] = spans

    print("Rendering the visualization...")
    html = displacy.render(
        doc,
        style="span",
        options={
            "spans_key": "coref_spans",
            "colors": color_palette,
        },
        jupyter=False,
    )

    # delete=False: the files must outlive this function so the returned
    # paths stay valid for the download components.
    with tempfile.NamedTemporaryFile(suffix=".html", delete=False) as tmp_html_file:
        html_file = tmp_html_file.name
        tmp_html_file.write(html.encode("utf-8"))

    with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as tmp_json_file:
        json_file = tmp_json_file.name
        tmp_json_file.write(json.dumps(coref_output).encode("utf-8"))

    print("HTML file:", html_file)
    print("JSON file:", json_file)

    return (
        html_file,
        json_file,
        gr.DownloadButton(value=html_file, visible=True),
        gr.DownloadButton(value=json_file, visible=True),
    )
def download_html():
    """Return a newly constructed ``gr.DownloadButton`` with ``visible=False``."""
    hidden_button = gr.DownloadButton(visible=False)
    return hidden_button
def download_json():
    """Return a newly constructed ``gr.DownloadButton`` with ``visible=False``."""
    hidden_button = gr.DownloadButton(visible=False)
    return hidden_button
# Long example text offered in the interface's examples list.
# Read with an explicit encoding so decoding does not depend on the locale.
with open("example_harry.txt", "r", encoding="utf-8") as f:
    example_harry = f.read()

# Choices offered by the model-selection radio button.
options = ["static", "hybrid"]
# Assemble the UI: hidden file/download components plus the main interface,
# all created inside one Blocks context, then start the app.
demo = gr.Blocks()
with demo:
    # Output components start hidden; coref_visualizer returns the download
    # buttons with visible=True.
    html_file = gr.File(visible=False)
    json_file = gr.File(visible=False)
    html_button = gr.DownloadButton("Download HTML", visible=False)
    json_button = gr.DownloadButton("Download JSON", visible=False)
    html_button.click()
    json_button.click()

    # Inputs map positionally to coref_visualizer(doc_name, text, model_type).
    input_widgets = [
        gr.Textbox(lines=1, placeholder="Enter document name:"),
        gr.Textbox(lines=10, placeholder="Enter text for coreference resolution:"),
        gr.Radio(choices=options, label="Select an Option"),
    ]
    # Outputs map positionally to the 4-tuple returned by coref_visualizer.
    output_widgets = [html_file, json_file, html_button, json_button]
    # Each example row is [document name, text, model option].
    example_rows = [
        [
            "example",
            "{{Harry}} went to Hogwarts to meet Hemoine and {{Ron}} . He also met Ron's mother at the railway station.",
            "static",
        ],
        ["example_large", example_harry, "static"],
    ]

    iface = gr.Interface(
        fn=coref_visualizer,
        inputs=input_widgets,
        outputs=output_widgets,
        title="MEI Visualizer",
        description=DESC_MD,
        examples=example_rows,
    )

demo.launch(debug=True)