Spaces:
Sleeping
Sleeping
File size: 4,891 Bytes
98e2ea5 6084ae2 98e2ea5 b0192e5 5dc94a0 b0192e5 98e2ea5 6084ae2 98e2ea5 6084ae2 98e2ea5 6084ae2 98e2ea5 0d75905 98e2ea5 b0192e5 98e2ea5 b0192e5 0d75905 b0192e5 98e2ea5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 |
import time
import spacy
import json
import gradio as gr
from spacy.tokens import Doc, Span
from spacy import displacy
import matplotlib.pyplot as plt
from matplotlib.colors import to_hex
import tempfile
from inference.model_inference import Inference
from configs import *
# HTML-flavoured description text for the demo page; it is passed as the
# ``description`` argument of the ``gr.Interface`` built further down.
# It explains the ``{{entity}}`` markup expected in the input text and the
# two model choices ("static" / "hybrid") offered in the UI.
DESC_MD = """
<font size="3">
This space is a demo for <a href="https://arxiv.org/abs/2406.14654"> Major Entity Identification (MEI) </a>. MEI takes entities as additional input and aims to detect the mentions that refer only to these entities. <br/>
<br/>
Place the text in the text box with a single phrase of a selected entity in double curly braces(example: a single instance of {{Ron}} if you want to track Ron). Note that you can select one phrase for each entity and multiple entities can be selected. Check out the example below for clarity. <br/>
<br/>
Static: Uses an instance of: MEIRa-S model <br/>
Hybrid: Uses an instance of: MEIRa-H model <br/>
<br/>
The demo provides a json file with clusters and an HTML file with visualizations. The visualizations are color-coded based on the clusters. <br/>
</font>
"""
def get_MEIRa_clusters(doc_name, text, model_type):
    """Run coreference on ``text`` with the model selected by ``model_type``.

    ``model_type`` is used as a key into ``MODELS``; the value found there is
    handed to ``Inference``. A new ``Inference`` object is constructed on
    every call.

    Args:
        doc_name: Document name forwarded to ``perform_coreference``.
        text: Input text forwarded to ``perform_coreference``.
        model_type: Key into ``MODELS``.

    Returns:
        Whatever ``Inference.perform_coreference(text, doc_name)`` returns.
    """
    inference_engine = Inference(MODELS[model_type])
    return inference_engine.perform_coreference(text, doc_name)
def coref_visualizer(doc_name, text, model_type):
    """Run the model on ``text`` and write HTML / JSON result files.

    The coreference output is rendered with displaCy's span visualizer,
    one colour per entry of ``representative_names``, and both the rendered
    HTML and the raw output (as JSON) are written to temporary files.

    Args:
        doc_name: Document name forwarded to ``get_MEIRa_clusters``.
        text: Input text forwarded to ``get_MEIRa_clusters``.
        model_type: Model selector forwarded to ``get_MEIRa_clusters``.

    Returns:
        A 4-tuple ``(html_path, json_path, html_button, json_button)``:
        the two temp-file paths followed by two visible
        ``gr.DownloadButton`` objects pointing at those files.
    """
    coref_output = get_MEIRa_clusters(doc_name, text, model_type)
    tokens = coref_output["tokenized_doc"]
    clusters = coref_output["clusters"]
    labels = coref_output["representative_names"]

    # One colour per label, sampled from the "tab20" colormap resampled to
    # len(labels) entries. ``plt.get_cmap`` is used rather than
    # ``plt.cm.get_cmap``: ``matplotlib.cm.get_cmap`` was deprecated in
    # matplotlib 3.7 and removed in 3.9. The colormap is built once instead
    # of once per label, and only when there is at least one label.
    color_palette = {}
    num_labels = len(labels)
    if num_labels:
        cmap = plt.get_cmap("tab20", num_labels)
        color_palette = {
            label: to_hex(cmap(i)) for i, label in enumerate(labels)
        }

    nlp = spacy.blank("en")
    doc = Doc(nlp.vocab, words=tokens)
    print("Tokens:", tokens, flush=True)
    print(color_palette)

    # The final cluster is deliberately skipped.
    # NOTE(review): presumably it is a catch-all cluster without a
    # representative name - confirm against the Inference output format.
    spans = []
    for cluster_ind, cluster in enumerate(clusters[:-1]):
        label = labels[cluster_ind]
        for (start, end), _mention in cluster:
            # Span's end index is exclusive, hence ``end + 1``.
            spans.append(Span(doc, start, end + 1, label=label))
    doc.spans["coref_spans"] = spans

    print("Rendering the visualization...")
    html = displacy.render(
        doc,
        style="span",
        options={
            "spans_key": "coref_spans",
            "colors": color_palette,
        },
        jupyter=False,
    )

    # delete=False keeps the files on disk after the ``with`` blocks close
    # them, so the returned paths remain usable by the caller.
    with tempfile.NamedTemporaryFile(suffix=".html", delete=False) as tmp_html_file:
        html_file = tmp_html_file.name
        tmp_html_file.write(html.encode("utf-8"))
    with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as tmp_json_file:
        json_file = tmp_json_file.name
        tmp_json_file.write(json.dumps(coref_output).encode("utf-8"))

    print("HTML file:", html_file)
    print("JSON file:", json_file)
    return (
        html_file,
        json_file,
        gr.DownloadButton(value=html_file, visible=True),
        gr.DownloadButton(value=json_file, visible=True),
    )
def download_html():
    """Return a ``gr.DownloadButton`` constructed with ``visible=False``."""
    hidden_button = gr.DownloadButton(visible=False)
    return hidden_button
def download_json():
    """Return a ``gr.DownloadButton`` constructed with ``visible=False``."""
    hidden_button = gr.DownloadButton(visible=False)
    return hidden_button
# Longer sample document used as the second pre-filled example below.
# The encoding is given explicitly so the read does not depend on the
# platform's default locale encoding.
with open("example_harry.txt", "r", encoding="utf-8") as f:
    example_harry = f.read()

# Choices shown in the model-selection radio; the selected value is passed
# to ``coref_visualizer`` as ``model_type``.
options = ["static", "hybrid"]

with gr.Blocks() as demo:
    # Output components: hidden file holders plus download buttons that
    # start hidden; ``coref_visualizer`` returns visible replacements.
    html_file = gr.File(visible=False)
    json_file = gr.File(visible=False)
    html_button = gr.DownloadButton("Download HTML", visible=False)
    json_button = gr.DownloadButton("Download JSON", visible=False)

    # NOTE(review): ``click()`` is called without a handler here, so
    # ``download_html`` / ``download_json`` are not attached by these
    # calls - confirm whether they were meant to be wired in.
    html_button.click()
    json_button.click()

    iface = gr.Interface(
        fn=coref_visualizer,
        inputs=[
            gr.Textbox(lines=1, placeholder="Enter document name:"),
            gr.Textbox(lines=10, placeholder="Enter text for coreference resolution:"),
            gr.Radio(choices=options, label="Select an Option"),
        ],
        outputs=[
            html_file,
            json_file,
            html_button,
            json_button,
        ],
        title="MEI Visualizer",
        description=DESC_MD,
        examples=[
            [
                "example",
                "{{Harry}} went to Hogwarts to meet Hemoine and {{Ron}} . He also met Ron's mother at the railway station.",
                "static",
            ],
            ["example_large", example_harry, "static"],
        ],
    )

demo.launch(debug=True)
|