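"""Athena's Lens: a Gradio demo for Ancient Greek NLP.

Offers part-of-speech tagging, lemmatization, and dependency parsing with models
loaded from the Hugging Face Hub (see MODEL_PATHS) and rendered via spaCy's displaCy.
"""
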
from typing import Dict, List, Tuple
import base64
import os

from bs4 import BeautifulSoup
import gradio as gr
from spacy import displacy
from transformers import (
    AutoTokenizer,
    AutoModelForTokenClassification,
    BatchEncoding,
    AutoModelForSeq2SeqLM,
    DataCollatorForTokenClassification,
)
import torch

from utils import get_dependencies, preprocess_text
from models import (
    DependencyRobertaForTokenClassification,
    LabelRobertaForTokenClassification,
)

DEFAULT_TEXT = "τίω δέ μιν ἐν καρὸς αἴσῃ."

BUTTON_CSS = "float: right; --tw-border-opacity: 1; border-color: rgb(229 231 235 / var(--tw-border-opacity)); --tw-gradient-from: rgb(243 244 246 / 0.7); --tw-gradient-stops: var(--tw-gradient-from), var(--tw-gradient-to, rgb(243 244 246 / 0)); --tw-gradient-to: rgb(229 231 235 / 0.8); --tw-text-opacity: 1; color: rgb(55 65 81 / var(--tw-text-opacity)); border-width: 1px; --tw-bg-opacity: 1; background-color: rgb(255 255 255 / var(--tw-bg-opacity)); background-image: linear-gradient(to bottom right, var(--tw-gradient-stops)); display: inline-flex; flex: 1 1 0%; align-items: center; justify-content: center; --tw-shadow: 0 1px 2px 0 rgb(0 0 0 / 0.05); --tw-shadow-colored: 0 1px 2px 0 var(--tw-shadow-color); box-shadow: var(--tw-ring-offset-shadow, 0 0 #0000), var(--tw-ring-shadow, 0 0 #0000), var(--tw-shadow); -webkit-appearance: button; border-radius: 0.5rem; padding-top: 0.5rem; padding-bottom: 0.5rem; padding-left: 1rem; padding-right: 1rem; font-size: 1rem; line-height: 1.5rem; font-weight: 600;"

DEFAULT_COLOR = "white"

MODEL_PATHS = {
    "POS": "bowphs/testid",
    "LEMMATIZATION": "bowphs/lemmatization-demo",
    "DEPENDENCY": "bowphs/depenBERTa_perseus",
    "LABELS": "bowphs/depenBERTa_labler_perseus",
}
MODEL_MAX_LENGTH = 512
AUTH_TOKEN = os.environ.get("TOKEN") or True

# PoS
pos_tokenizer = AutoTokenizer.from_pretrained(
    MODEL_PATHS["POS"],
    model_max_length=MODEL_MAX_LENGTH,
    use_auth_token=AUTH_TOKEN,
    revision="402ab7d25f49e83a67b955ebbc172b5459fbd939",
)
pos_model = AutoModelForTokenClassification.from_pretrained(
    MODEL_PATHS["POS"],
    use_auth_token=AUTH_TOKEN,
    revision="402ab7d25f49e83a67b955ebbc172b5459fbd939",
)

# Lemmatization
lemmatizer_tokenizer = AutoTokenizer.from_pretrained(
    MODEL_PATHS["LEMMATIZATION"],
    model_max_length=MODEL_MAX_LENGTH,
    use_auth_token=AUTH_TOKEN,
)
lemmatizer_model = AutoModelForSeq2SeqLM.from_pretrained(
    MODEL_PATHS["LEMMATIZATION"], use_auth_token=AUTH_TOKEN
)

# Dependency Parsing
dependency_tokenizer = AutoTokenizer.from_pretrained(
    MODEL_PATHS["DEPENDENCY"],
    model_max_length=MODEL_MAX_LENGTH,
    use_auth_token=AUTH_TOKEN,
)
arcs_model = DependencyRobertaForTokenClassification.from_pretrained(
    MODEL_PATHS["DEPENDENCY"], use_auth_token=AUTH_TOKEN
)
labels_model = LabelRobertaForTokenClassification.from_pretrained(
    MODEL_PATHS["LABELS"], use_auth_token=AUTH_TOKEN
)

data_collator = DataCollatorForTokenClassification(dependency_tokenizer)


def is_valid_selection(col_arcs, col_labels) -> bool:
    """Dependency labels can only be shown when dependency arcs are enabled."""
    if col_labels and not col_arcs:
        return False
    return True


def get_pos_predictions(inputs) -> torch.Tensor:
    """Get part-of-speech predictions."""
    return pos_model(inputs["input_ids"]).logits.argmax(-1)  # type: ignore


def execute_parse(
    text_input: str,
    col_pos: bool,
    col_arcs: bool,
    col_labels: bool,
    col_lemmata: bool,
    compact: bool,
    bg: str,
    text: str,
) -> Tuple[str, str]:
    """Validate the checkbox selection, then run the parser."""
    if is_valid_selection(col_arcs, col_labels):
        return parse(
            text_input, col_pos, col_arcs, col_labels, col_lemmata, compact, bg, text
        )
    return "Please check 'Dependency Arcs' before checking 'Dependency Labels'", ""
def lemmatize(tokens: List[str]) -> List[str]:
    """Predict a lemma for every token with the seq2seq lemmatization model."""

    def construct_task(word_idx: int) -> str:
        return (
            f"lemmatize: {' '.join(tokens[:word_idx])} <extra_id_0> {tokens[word_idx]} "
            f"<extra_id_1> {' '.join(list(tokens[word_idx]))} <extra_id_2> "
            f"{' '.join(tokens[word_idx + 1:])}"
        )

    predictions = [
        lemmatizer_tokenizer.decode(
            lemmatizer_model.generate(
                lemmatizer_tokenizer(construct_task(word_idx), return_tensors="pt")[
                    "input_ids"
                ],
                max_length=20,
                num_beams=5,
                num_return_sequences=1,
                early_stopping=True,
            )[0],
            skip_special_tokens=True,
        )
        for word_idx in range(len(tokens))
    ]
    return predictions


def add_lemma_visualization(soup, lemmata: List[str], col_arcs: bool) -> str:
    """Insert a lemma line below each PoS tag in the rendered displaCy SVG."""
    # When arcs are shown, the first rendered token is the artificial ROOT, which
    # has no lemma; the boolean slice skips it.
    for token, lemma in zip(
        soup.find_all(class_="displacy-token")[col_arcs:], lemmata
    ):
        pos_tag = token.find(class_="displacy-tag")
        lemma_tag = soup.new_tag(
            "tspan",
            class_="displacy-lemma",
            dy="2em",
            fill="currentColor",
            x=pos_tag.attrs["x"],
        )
        lemma_tag.string = lemma
        pos_tag.insert_after(lemma_tag)
    return str(soup)


def download_svg(svg):
    """Wrap the rendered SVG in a data URI and return a styled download link."""
    encoded = base64.b64encode(bytes(svg, "utf-8")).decode("utf-8")
    img = "data:image/svg+xml;base64," + encoded
    html = f'<a download="displacy.svg" href="{img}" style="{BUTTON_CSS}">Download as SVG</a>'
    return html


def prepare_doc(
    tokens: List[str],
    col_pos: bool,
    pos_outputs: torch.Tensor,
    inputs: BatchEncoding,
) -> Dict[str, List[Dict[str, str]]]:
    """Build the word list of a displaCy-style doc from the PoS predictions."""
    doc: Dict[str, List[Dict[str, str]]] = {
        "words": [],  # a ROOT entry is prepended later if dependency arcs are requested
        "arcs": [],
    }
    word_ids = inputs.word_ids()
    previous_word_idx = None
    # Keep only the prediction for the first sub-token of every word.
    for idx, word_idx in enumerate(word_ids):
        if word_idx != previous_word_idx and word_idx is not None:
            tag_repr = (
                pos_model.config.id2label[pos_outputs[0][idx].item()] if col_pos else ""
            )
            doc["words"].append({"text": tokens[word_idx], "tag": tag_repr})
        previous_word_idx = word_idx
    return doc
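

# displacy.render(..., manual=True) in parse() consumes a dict of roughly this shape
# (tag, label, and direction values illustrative):
#   {
#       "words": [{"text": "ROOT", "tag": ""}, {"text": "τίω", "tag": "VERB"}, ...],
#       "arcs": [{"start": 0, "end": 1, "label": "obj", "dir": "left"}, ...],
#   }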
def parse(
    text_input: str,
    col_pos: bool,
    col_arcs: bool,
    col_labels: bool,
    col_lemmata: bool,
    compact: bool,
    bg: str,
    text: str,
) -> Tuple[str, str]:
    """Run the selected analyses and return the rendered SVG plus a download link."""
    tokens = preprocess_text(text_input)
    inputs = pos_tokenizer(
        tokens,
        return_tensors="pt",
        truncation=True,
        padding=True,
        is_split_into_words=True,
    )
    pos_outputs = get_pos_predictions(inputs)
    doc = prepare_doc(tokens, col_pos, pos_outputs, inputs)
    if col_arcs:
        doc["words"].insert(0, {"text": "ROOT", "tag": ""})
        doc["arcs"] = get_dependencies(
            arcs_model,
            labels_model,
            dependency_tokenizer,
            data_collator,
            col_labels,
            tokens,
        )["arcs"]
    options = {"compact": compact, "bg": bg, "color": text}
    svg = displacy.render(doc, manual=True, style="dep", options=options)
    if col_lemmata:
        soup = BeautifulSoup(svg, "lxml-xml")
        lemmata = lemmatize(tokens)
        svg = add_lemma_visualization(soup, lemmata, col_arcs)
    download_link = download_svg(svg)
    return svg, download_link
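

# The parser panel in the UI below is pre-rendered with PoS tags only, i.e. the
# equivalent of calling:
#   parse(DEFAULT_TEXT, True, False, False, False, False, DEFAULT_COLOR, "black")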
def setup_parser_ui():
    """Build the Gradio Blocks interface and wire up the event handlers."""
    demo = gr.Blocks(css="scrollbar.css")
    with demo:
        with gr.Box():
            with gr.Row():
                with gr.Column():
                    gr.Markdown("# Athena's Lens")
                    gr.Markdown(
                        "### From Ἀlkaios to Ὠrigen: A Modern Lens on Timeless Texts"
                    )
        with gr.Box():
            with gr.Column():
                gr.Markdown("## Enter some text")
                with gr.Row():
                    with gr.Column(scale=0.5):
                        text_input = gr.Textbox(
                            value=DEFAULT_TEXT, interactive=True, label="Input Text"
                        )
                with gr.Row():
                    with gr.Column(scale=0.25):
                        button = gr.Button("Update", variant="primary").style(
                            full_width=False
                        )
        with gr.Box():
            with gr.Column():
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("## Parser")
                with gr.Row():
                    with gr.Column():
                        col_pos = gr.Checkbox(label="PoS Labels", value=True)
                        col_arcs = gr.Checkbox(label="Dependency Arcs", value=False)
                        col_labels = gr.Checkbox(label="Dependency Labels", value=False)
                        col_lemmata = gr.Checkbox(label="Lemmata", value=False)
                        compact = gr.Checkbox(label="Compact", value=False)
                    with gr.Column():
                        bg = gr.Textbox(label="Background Color", value=DEFAULT_COLOR)
                    with gr.Column():
                        text = gr.Textbox(label="Text Color", value="black")
                with gr.Row():
                    dep_output = gr.HTML(
                        value=parse(
                            DEFAULT_TEXT,
                            True,
                            False,
                            False,
                            False,
                            False,
                            DEFAULT_COLOR,
                            "black",
                        )[0]
                    )
                with gr.Row():
                    with gr.Column(scale=0.25):
                        dep_button = gr.Button(
                            "Update Parser", variant="primary"
                        ).style(full_width=False)
                    with gr.Column():
                        dep_download_button = gr.HTML(
                            value=download_svg(dep_output.value)
                        )
        with gr.Box():
            with gr.Column():
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("## Contact")
                        gr.Markdown(
                            "If you have any questions, suggestions, comments, or problems, feel free to [reach out](mailto:riemenschneider@cl.uni-heidelberg.de)."
                        )
                        gr.Markdown("## Citation")
                        gr.Markdown(
                            "This space uses models from [this](https://aclanthology.org/2023.acl-long.846.pdf) paper."
                        )
                        gr.Markdown(
                            """```bibtex
@incollection{riemenschneider-frank-2023-exploring,
    title = "Exploring Large Language Models for Classical Philology",
    author = "Riemenschneider, Frederick and Frank, Anette",
    booktitle = "Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)",
    month = jul,
    year = "2023",
    address = "Toronto, Canada",
    publisher = "Association for Computational Linguistics",
    url = "https://aclanthology.org/2023.acl-long.846",
    doi = "10.18653/v1/2023.acl-long.846",
    pages = "15181--15199",
}
```"""
                        )
        button.click(
            execute_parse,
            inputs=[
                text_input,
                col_pos,
                col_arcs,
                col_labels,
                col_lemmata,
                compact,
                bg,
                text,
            ],
            outputs=[dep_output, dep_download_button],
        )
        dep_button.click(
            execute_parse,
            inputs=[
                text_input,
                col_pos,
                col_arcs,
                col_labels,
                col_lemmata,
                compact,
                bg,
                text,
            ],
            outputs=[dep_output, dep_download_button],
        )
    return demo


def main():
    demo = setup_parser_ui()
    demo.launch()


if __name__ == "__main__":
    main()