File size: 4,891 Bytes
98e2ea5
 
 
 
 
 
 
 
 
6084ae2
98e2ea5
 
 
 
b0192e5
5dc94a0
 
 
 
 
 
 
 
 
 
b0192e5
 
98e2ea5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6084ae2
 
98e2ea5
6084ae2
 
 
 
98e2ea5
6084ae2
 
 
 
 
 
 
 
 
 
 
 
98e2ea5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0d75905
 
98e2ea5
 
 
 
 
 
 
 
 
 
 
 
 
 
b0192e5
98e2ea5
 
 
 
 
 
 
 
b0192e5
 
 
 
 
 
 
0d75905
 
b0192e5
98e2ea5
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import time
import spacy
import json

import gradio as gr
from spacy.tokens import Doc, Span
from spacy import displacy
import matplotlib.pyplot as plt
from matplotlib.colors import to_hex
import tempfile

from inference.model_inference import Inference
from configs import *

# Help text for the demo, written as HTML markup (font/link/line-break tags)
# inside a string; passed to gr.Interface as its `description`.
DESC_MD = """
<font size="3">
This space is a demo for <a href="https://arxiv.org/abs/2406.14654"> Major Entity Identification (MEI) </a>. MEI takes entities as additional input and aims to detect the mentions that refer only to these entities. <br/>
<br/>
Place the text in the text box with a single phrase of a selected entity in double curly braces(example: a single instance of {{Ron}} if you want to track Ron). Note that you can select one phrase for each entity and multiple entities can be selected. Check out the example below for clarity. <br/>
<br/>
Static: Uses an instance of: MEIRa-S model <br/>
Hybrid: Uses an instance of: MEIRa-H model <br/>
<br/>
The demo provides a json file with clusters and an HTML file with visualizations. The visualizations are color-coded based on the clusters. <br/>
</font>
"""


# Loaded Inference objects, keyed by the model identifier taken from MODELS.
# Loading a model is expensive, so each one is built once and then reused.
# NOTE(review): assumes an Inference instance can be reused across calls
# (i.e. perform_coreference keeps no per-document state) — confirm.
_MODEL_CACHE = {}


def get_MEIRa_clusters(doc_name, text, model_type):
    """Run MEIRa coreference on ``text`` with the selected model variant.

    Args:
        doc_name: Document name forwarded to ``Inference.perform_coreference``.
        text: Input document; entities to track are marked with ``{{...}}``.
        model_type: Key into ``MODELS`` (the UI offers "static" / "hybrid").

    Returns:
        The dict returned by ``Inference.perform_coreference``.

    Raises:
        gr.Error: If ``model_type`` is not a key of ``MODELS`` (for example,
            when no option was selected in the UI). Gradio shows the message
            to the user instead of an opaque KeyError traceback.
    """
    if model_type not in MODELS:
        raise gr.Error(
            f"Unknown model type {model_type!r}; choose one of {list(MODELS)}."
        )
    model_str = MODELS[model_type]

    # Reuse an already-loaded model rather than rebuilding it per request.
    model = _MODEL_CACHE.get(model_str)
    if model is None:
        model = Inference(model_str)
        _MODEL_CACHE[model_str] = model

    return model.perform_coreference(text, doc_name)


def _build_color_palette(labels):
    """Map each cluster label to a hex color sampled from the ``tab20`` colormap."""
    # pyplot.get_cmap(name, lut) resamples the colormap to `lut` entries.
    # matplotlib.cm.get_cmap (reached via plt.cm.get_cmap) was removed in
    # Matplotlib 3.9. max(..., 1) keeps `lut` valid when there are no labels.
    cmap = plt.get_cmap("tab20", max(len(labels), 1))
    return {label: to_hex(cmap(i)) for i, label in enumerate(labels)}


def _build_spans(doc, clusters, labels):
    """Create one labelled spaCy ``Span`` per mention of every rendered cluster."""
    spans = []
    # NOTE(review): the final cluster is intentionally skipped (kept from the
    # original code); presumably it collects mentions not tied to any selected
    # entity — confirm against the model output format.
    for cluster_ind, cluster in enumerate(clusters[:-1]):
        label = labels[cluster_ind]
        for (start, end), _mention in cluster:
            # Mention end indices are treated as inclusive, while Span's end
            # is exclusive, hence the +1.
            spans.append(Span(doc, start, end + 1, label=label))
    return spans


def _write_temp_file(suffix, content):
    """Write ``content`` as UTF-8 to a new temporary file and return its path."""
    # delete=False: the path is returned to Gradio outputs, so the file must
    # outlive this call.
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        tmp.write(content.encode("utf-8"))
        return tmp.name


def coref_visualizer(doc_name, text, model_type):
    """Run MEI coreference on ``text`` and export the results.

    Renders the predicted clusters as color-coded spaCy span markup. Writes
    both that HTML rendering and the raw model output (as JSON) to temporary
    files.

    Args:
        doc_name: Document name passed through to the model.
        text: Input document with ``{{...}}``-marked entities.
        model_type: Key into ``MODELS`` (e.g. "static" or "hybrid").

    Returns:
        Tuple ``(html_path, json_path, html_button, json_button)``. The two
        buttons are visible ``gr.DownloadButton`` components pointing at the
        written files.
    """
    coref_output = get_MEIRa_clusters(doc_name, text, model_type)

    tokens = coref_output["tokenized_doc"]
    clusters = coref_output["clusters"]
    labels = coref_output["representative_names"]

    color_palette = _build_color_palette(labels)

    nlp = spacy.blank("en")
    doc = Doc(nlp.vocab, words=tokens)

    print("Tokens:", tokens, flush=True)
    print(color_palette)

    doc.spans["coref_spans"] = _build_spans(doc, clusters, labels)

    print("Rendering the visualization...")

    html = displacy.render(
        doc,
        style="span",
        options={
            "spans_key": "coref_spans",
            "colors": color_palette,
        },
        jupyter=False,
    )

    html_file = _write_temp_file(".html", html)
    json_file = _write_temp_file(".json", json.dumps(coref_output))

    print("HTML file:", html_file)
    print("JSON file:", json_file)

    return (
        html_file,
        json_file,
        gr.DownloadButton(value=html_file, visible=True),
        gr.DownloadButton(value=json_file, visible=True),
    )


def _hidden_download_button():
    """Return a ``gr.DownloadButton`` with ``visible=False``."""
    hidden = gr.DownloadButton(visible=False)
    return hidden


def download_html():
    """Return a hidden ``gr.DownloadButton``, for use as the HTML button's event output."""
    return _hidden_download_button()


def download_json():
    """Return a hidden ``gr.DownloadButton``, for use as the JSON button's event output."""
    return _hidden_download_button()


# Long example document offered in the Interface's examples table.
# Read explicitly as UTF-8 so it does not depend on the platform's default encoding.
with open("example_harry.txt", "r", encoding="utf-8") as f:
    example_harry = f.read()
options = ["static", "hybrid"]

with gr.Blocks() as demo:
    # Output components. The File outputs receive the generated file paths.
    # The download buttons start hidden; coref_visualizer returns visible
    # buttons pointing at the generated files.
    html_file = gr.File(visible=False)
    json_file = gr.File(visible=False)
    html_button = gr.DownloadButton("Download HTML", visible=False)
    json_button = gr.DownloadButton("Download JSON", visible=False)

    # Hide each download button again after it is clicked. The handlers
    # return hidden DownloadButton updates.
    html_button.click(download_html, None, html_button)
    json_button.click(download_json, None, json_button)
    iface = gr.Interface(
        fn=coref_visualizer,
        inputs=[
            gr.Textbox(lines=1, placeholder="Enter document name:"),
            gr.Textbox(lines=10, placeholder="Enter text for coreference resolution:"),
            gr.Radio(choices=options, label="Select an Option"),
        ],
        outputs=[
            html_file,
            json_file,
            html_button,
            json_button,
        ],
        title="MEI Visualizer",
        description=DESC_MD,
        examples=[
            [
                "example",
                "{{Harry}} went to Hogwarts to meet Hermione and {{Ron}} . He also met Ron's mother at the railway station.",
                "static",
            ],
            ["example_large", example_harry, "static"],
        ],
    )

# Launch only when run as a script (e.g. `python app.py`), not on import.
if __name__ == "__main__":
    demo.launch(debug=True)