# jvamvas's picture
# Update app.py
# f122fc3 verified
from pathlib import Path
import gradio as gr
from jinja2 import Environment
from tokenizers.pre_tokenizers import Whitespace
from transformers import pipeline
from recognizers import DiffAlign, DiffDel
def load_pipeline(model_name_or_path: str = "ZurichNLP/unsup-simcse-xlm-roberta-base"):
    """Create the feature-extraction pipeline used by the diff recognizers.

    Args:
        model_name_or_path: Model identifier or local path handed to
            ``transformers.pipeline`` as ``model``.

    Returns:
        The object returned by ``transformers.pipeline`` for the
        ``"feature-extraction"`` task.
    """
    extractor = pipeline(task="feature-extraction", model=model_name_or_path)
    return extractor
def _bucket_labels(labels, min_value, max_value):
    """Rescale raw scores with the given bounds and bucket them into integers 0..10."""
    span = max_value - min_value
    return tuple(
        # Clamp on both sides: scores outside the empirical bounds must not
        # leave the 0..10 range.
        round(max(0, min(10, (label - min_value) / span * 10)))
        for label in labels
    )


def generate_diff(text_a: str, text_b: str, method: str):
    """Compare two texts and render one HTML fragment per text.

    Both texts are pre-tokenized with the module-level ``tokenizer``; the
    first element of every pre-token entry is re-joined with single spaces
    and passed to the selected recognizer. The labels it returns are
    rescaled with per-method empirical bounds, bucketed into the integers
    0, 1, ..., 10 and rendered through ``result_template.html``.

    Args:
        text_a: First text to compare.
        text_b: Second text to compare.
        method: Name of the recognizer, ``"DiffAlign"`` or ``"DiffDel"``.

    Returns:
        A tuple ``(html_a, html_b)`` with the rendered HTML for each text.

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    global my_pipeline
    # Load the shared pipeline lazily if nothing has created it yet
    if my_pipeline is None:
        my_pipeline = load_pipeline()

    # Pick the recognizer together with its empirical score bounds.
    # NOTE(review): the -0.1 / +0.025 offsets are kept as found; their
    # rationale is not visible in this file.
    if method == "DiffAlign":
        diff = DiffAlign(pipeline=my_pipeline)
        min_value = 0.3758048415184021 - 0.1
        max_value = 1.045647144317627 - 0.1
    elif method == "DiffDel":
        diff = DiffDel(pipeline=my_pipeline)
        min_value = 0.4864141941070556
        max_value = 0.5012983083724976 + 0.025
    else:
        raise ValueError(f"Unknown method: {method}")

    encoding_a = tokenizer.pre_tokenize_str(text_a)
    encoding_b = tokenizer.pre_tokenize_str(text_b)
    result = diff.predict(
        a=" ".join([token[0] for token in encoding_a]),
        b=" ".join([token[0] for token in encoding_b]),
    )
    # Hand the pre-tokenizer output back to the result
    # (presumably to restore the original spacing; verify in recognizers)
    result.add_whitespace(encoding_a, encoding_b)

    # Normalize labels based on empirical min/max values and bucket them
    # into the integers 0, 1, ..., 10
    result.labels_a = _bucket_labels(result.labels_a, min_value, max_value)
    result.labels_b = _bucket_labels(result.labels_b, min_value, max_value)

    app_dir = Path(__file__).parent
    # NOTE(review): Environment() does not autoescape and the tokens come from
    # user input; confirm that result_template.html escapes them itself.
    template = Environment().from_string((app_dir / "result_template.html").read_text())

    # NOTE(review): nothing in this function writes to html_out; the directory
    # creation is kept only to preserve the existing side effect.
    (app_dir / "html_out").mkdir(exist_ok=True)

    html_a = template.render(token_labels=result.token_labels_a)
    html_b = template.render(token_labels=result.token_labels_b)
    return str(html_a), str(html_b)
# Shared feature-extraction pipeline. It starts as None and is filled in by
# load_pipeline(), either lazily inside generate_diff or by the start-up code.
my_pipeline = None
# Whitespace pre-tokenizer that generate_diff uses to split the input texts
tokenizer = Whitespace()
# Page layout: intro text, two input boxes, method picker, two result panels,
# the submit button and a closing description.
with gr.Blocks() as demo:
    preamble = Path(__file__).with_name("preamble.md").read_text()
    gr.Markdown(preamble)

    # The two texts to compare
    with gr.Row():
        text_a = gr.Textbox(
            label="Text A",
            value="We'll meet Steve on Wednesday.",
            lines=2,
        )
        text_b = gr.Textbox(
            label="Text B",
            value="We are going to see Mary on Friday.",
            lines=2,
        )

    # Recognizer selection; the choices match the names generate_diff accepts
    with gr.Row():
        method = gr.Dropdown(
            choices=["DiffAlign", "DiffDel"],
            label="Comparison Method",
            value="DiffAlign",
        )

    # One HTML panel per input text
    with gr.Row():
        with gr.Column(variant="panel"):
            output_a = gr.HTML(label="Result for text A", show_label=True)
        with gr.Column(variant="panel"):
            output_b = gr.HTML(label="Result for text B", show_label=True)

    with gr.Row():
        submit_btn = gr.Button(value="Generate Diff")

    # Clicking the button runs generate_diff on the three inputs and fills both panels
    submit_btn.click(
        generate_diff,
        [text_a, text_b, method],
        [output_a, output_b],
    )

    description = Path(__file__).with_name("description.md").read_text()
    gr.Markdown(description)
# Make sure the shared pipeline exists before the app starts serving
my_pipeline = load_pipeline() if my_pipeline is None else my_pipeline
demo.launch()