import gradio as gr
import pandas as pd
import tempfile
import ast
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from grewtse.pipeline import GrewTSEPipe
from grewtse.evaluators import GrewTSEvaluator
from grewtse.visualise import GrewTSEVisualiser
grewtse = GrewTSEPipe()
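# A single stateful pipeline instance backs the whole app. As used below, the flow is:
#   1. grewtse.parse_treebank(...)                 -- load a .conllu treebank
#   2. grewtse.generate_masked_dataset(...) or
#      grewtse.generate_prompt_dataset(...)        -- turn GREW query matches into test items
#   3. grewtse.generate_minimal_pair_dataset(...)  -- add an ungrammatical variant per item
#   4. GrewTSEvaluator().evaluate_model(...)       -- score a Hugging Face model on the pairs
# (This summary is inferred from the calls made in this file, not from the grewtse docs.)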
def parse_treebank(path: str, treebank_selection: str) -> pd.DataFrame:
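    """Parse either an uploaded .conllu file or one of the pre-loaded treebanks,
    and return the tail of the morphological feature table so the user can
    confirm the parse succeeded."""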
    if treebank_selection == "None":
        # parse the treebank the user uploaded
        grewtse.parse_treebank(path)
    else:
        # parse one of the pre-loaded treebanks bundled under ./datasets
        grewtse.parse_treebank(f"./datasets/{treebank_selection}")
return grewtse.get_morphological_features().tail()
def to_masked_dataset(query, node) -> pd.DataFrame:
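    """Build a dataset in which the GREW query's target node is masked in each
    matching sentence (used for encoder / masked-LM tests)."""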
df = grewtse.generate_masked_dataset(query, node)
return df
def to_prompt_dataset(query, node) -> pd.DataFrame:
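    """Build a dataset in which each matching sentence is cut off at the target
    node, yielding a prompt (used for decoder / next-token tests)."""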
df = grewtse.generate_prompt_dataset(query, node)
return df
def safe_str_to_dict(s):
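    """Parse a string such as '{"number": "Plur"}' into a dict via
    ast.literal_eval, returning None if it is not a valid Python literal."""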
try:
return ast.literal_eval(s)
except (ValueError, SyntaxError):
return None
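# e.g. safe_str_to_dict('{"number": "Plur"}') -> {"number": "Plur"}
#      safe_str_to_dict("not a dict")         -> None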
def truncate_text(text, max_len=50):
"""
Truncate a string to max_len characters and append '...' if it was longer.
"""
if not isinstance(text, str):
return text # Keep non-string values unchanged
return text[:max_len] + "..." if len(text) > max_len else text
def generate_minimal_pairs(query: str, node: str, alt_features: str, task_type: str):
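    """Run the GREW query, mask or truncate each match according to task_type,
    and pair every target with an ungrammatical alternative built from
    alt_features. Returns a preview table and the path to a full CSV."""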
if not grewtse.is_treebank_parsed():
raise ValueError("Please parse a treebank first.")
    # parse the alternative feature values used to build the ungrammatical member of each pair
    alt_features_as_dict = safe_str_to_dict(alt_features)
    if alt_features_as_dict is None:
        raise ValueError("Invalid alternative features provided; expected a dictionary such as {'number': 'Plur'}.")
    has_leading_whitespace = False
    is_encoder = False
    if task_type.lower() == "masked":
        # mask the target word in each matching sentence
        masked_or_prompt_dataset = to_masked_dataset(query, node)
        is_encoder = True
    elif task_type.lower() == "prompt":
        # create prompts by cutting each sentence off right before the target word
        masked_or_prompt_dataset = to_prompt_dataset(query, node)
        has_leading_whitespace = True
    else:
        raise ValueError("Invalid task type.")
    full_dataset = grewtse.generate_minimal_pair_dataset(
        alt_features_as_dict, {},
        ood_pairs=None,
        has_leading_whitespace=has_leading_whitespace,
    )
# save to a temporary CSV file
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".csv")
full_dataset.to_csv(temp_file.name, index=False)
    if is_encoder:
        # work on a copy so the downloadable CSV keeps the full, untruncated text
        dataset_for_vis = full_dataset[["masked_text", "form_grammatical", "form_ungrammatical"]].copy()
        dataset_for_vis["masked_text"] = dataset_for_vis["masked_text"].apply(truncate_text)
    else:
        dataset_for_vis = full_dataset[["prompt_text", "form_grammatical", "form_ungrammatical"]].copy()
        dataset_for_vis["prompt_text"] = dataset_for_vis["prompt_text"].apply(truncate_text)
num_exceptions = grewtse.get_num_exceptions()
    num_targets_parsed = len(masked_or_prompt_dataset)
    num_success = len(full_dataset)
    exceptions_info = (
        f"{num_targets_parsed + num_exceptions} targets were identified and turned into masks/prompts. "
        f"{num_exceptions} of these could not be used due to treebank structure issues. "
        f"After searching for minimal pairs, a total of <b>{num_success} minimal-pair syntactic tests</b> "
        "were successfully generated."
    )
gr.Info(exceptions_info, duration=60, title="Grew-TSE Results")
return dataset_for_vis, temp_file.name
def evaluate_model(
model_repo: str,
task_type: str
):
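    """Score the given Hugging Face model on the generated minimal pairs and
    return a metrics table plus a CSV of per-pair results."""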
if not grewtse.are_minimal_pairs_generated():
raise ValueError(
"Please parse a treebank, mask a dataset and generate minimal pairs first."
)
g_eval = GrewTSEvaluator()
g_vis = GrewTSEVisualiser()
model_type = "encoder" if task_type.lower() == "masked" else "decoder"
mp_with_eval_dataset = g_eval.evaluate_model(grewtse.get_minimal_pair_dataset(), model_repo, model_type)
metrics = g_eval.get_all_metrics()
metrics = pd.DataFrame(metrics.items(), columns=["Metric", "Value"])
print("===METRICS===")
print(metrics)
print("----")
# save to a temporary CSV file
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".csv")
mp_with_eval_dataset.to_csv(temp_file.name, index=False)
return metrics, temp_file.name
def show_df():
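    """Reveal a hidden Dataframe component."""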
return gr.update(visible=True)
with gr.Blocks(theme=gr.themes.Ocean()) as demo:
with gr.Row():
gr.Markdown(
"# GREW-TSE: A Pipeline for Query-based Targeted Syntactic Evaluation"
)
with gr.Row():
with gr.Column():
gr.Markdown(
"""
#### Load a Treebank
                Begin by loading the treebank you'd like to work with.<br>
                Either select one of the pre-loaded options below, or upload your own.<br>
"""
)
with gr.Column():
with gr.Tabs():
with gr.TabItem("Choose Treebank"):
treebank_selection = gr.Dropdown(
choices=[
"None",
"en/English-EWT-UD-Treebank.conllu",
"Polish-Test-Treebank.conllu",
"Spanish-Test-SM.conllu",
],
label="Select a treebank",
value="en/English-EWT-UD-Treebank.conllu",
)
with gr.TabItem("Upload Your Own"):
gr.Markdown("## Upload a .conllu File")
file_input = gr.File(
label="Upload .conllu file",
file_types=[".conllu"],
type="filepath",
)
parse_file_button = gr.Button("Parse Treebank", size="sm", scale=0)
morph_table = gr.Dataframe(interactive=False, visible=False)
parse_file_button.click(
fn=parse_treebank,
inputs=[file_input, treebank_selection],
outputs=[morph_table],
)
parse_file_button.click(fn=show_df, outputs=morph_table)
gr.Markdown("## Isolate A Syntactic Phenomenon")
with gr.Row():
with gr.Column():
gr.Markdown(
"""
**GREW (Graph Rewriting for Universal Dependencies)** is a query and transformation language used to search within and manipulate dependency treebanks. A GREW query allows linguists and NLP researchers to find specific syntactic patterns in parsed linguistic data (such as Universal Dependencies treebanks).
Queries are expressed as graph constraints using a concise pattern-matching syntax.
#### Example
                The following short GREW query will match any verb. Try it with one of the sample treebanks above.
                Make sure to include the variable V as the target we're trying to isolate.
```grew
pattern {
                    V [upos="VERB"];
}
```
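                For illustration, a query can also constrain dependency relations. The following sketch (an assumption about your treebank's labels, not a tested example) matches verbs that govern a nominal subject:
                ```grew
                pattern {
                    V [upos="VERB"];
                    V -[nsubj]-> N;
                }
                ```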
"""
)
with gr.Column():
query_input = gr.Textbox(
label="GREW Query",
lines=5,
placeholder="Enter your GREW query here...",
value="pattern { V [upos=VERB, Number=Sing]; }"
)
node_input = gr.Textbox(
label="Target",
placeholder="The variable in your GREW query to isolate, e.g., N",
value="V",
)
feature_input = gr.Textbox(
label="Enter Alternative Feature Values for Minimal Pair as a Dictionary",
placeholder='e.g. {"case": "Acc", "number": "Sing"}',
value='{"number": "Plur"}',
lines=3,
)
task_type = gr.Dropdown(
choices=[
"Masked",
"Prompt",
],
label="Select whether you want masked- or prompt-based tests.",
value="Masked"
)
run_button = gr.Button("Run Query", size="sm", scale=0)
output_table = gr.Dataframe(label="Output Table", visible=False)
download_file = gr.File(label="Download CSV")
run_button.click(
fn=generate_minimal_pairs,
inputs=[query_input, node_input, feature_input, task_type],
outputs=[output_table, download_file],
)
run_button.click(fn=show_df, outputs=output_table)
with gr.Row():
with gr.Column():
gr.Markdown(
"""
## Evaluate A Model (BETA Version)
                You can evaluate any model available on the Hugging Face Hub that was trained for a masked language modelling (MLM) or next-token prediction (NTP) task.
"""
)
with gr.Column():
repository_input = gr.Textbox(
label="Model Repository",
lines=1,
placeholder="Enter the model repository here...",
value="google-bert/bert-base-multilingual-cased",
)
with gr.Column():
evaluate_button = gr.Button("Evaluate Model", size="sm", scale=0)
mp_with_eval_output_dataset = gr.Dataframe(label="Output Table", visible=False)
mp_with_eval_output_download = gr.File(label="Download CSV")
evaluate_button.click(
fn=evaluate_model,
inputs=[
repository_input,
task_type
],
outputs=[
            mp_with_eval_output_dataset,
mp_with_eval_output_download,
],
)
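    # reveal the hidden results table once evaluation completes,
    # mirroring the show_df pattern used for the other buttons above
    evaluate_button.click(fn=show_df, outputs=mp_with_eval_output_dataset)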
if __name__ == "__main__":
demo.launch(share=True)