# xpos / app.py
import gradio as gr
import gradio.inputs
import gradio.outputs
from transformers.pipelines import pipeline
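
# The 65 source languages (Universal Dependencies v2.8 treebanks) and their language codes.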
lang_names = ['Afrikaans', 'Ancient Greek', 'Arabic', 'Armenian', 'Basque', 'Belarusian', 'Bulgarian', 'Catalan', 'Chinese', 'Classical Chinese', 'Croatian', 'Czech', 'Danish', 'Dutch', 'English', 'Estonian', 'Faroese', 'Finnish', 'French', 'Galician', 'German', 'Gothic', 'Greek', 'Hebrew', 'Hindi', 'Hungarian', 'Icelandic', 'Indonesian', 'Irish', 'Italian', 'Japanese', 'Korean', 'Latin', 'Latvian', 'Lithuanian', 'Maltese', 'Marathi', 'Naija', 'North Sami', 'Norwegian', 'Old Church Slavonic', 'Old East Slavic', 'Old French', 'Persian', 'Polish', 'Portuguese', 'Romanian', 'Russian', 'Sanskrit', 'Scottish Gaelic', 'Serbian', 'Slovak', 'Slovenian', 'Spanish', 'Swedish', 'Tamil', 'Telugu', 'Turkish', 'Ukrainian', 'Urdu', 'Uyghur', 'Vietnamese', 'Welsh', 'Western Armenian', 'Wolof']
lang_codes = ['af', 'grc', 'ar', 'hy', 'eu', 'be', 'bg', 'ca', 'zh', 'lzh', 'hr', 'cs', 'da', 'nl', 'en', 'et', 'fo', 'fi', 'fr', 'gl', 'de', 'got', 'el', 'he', 'hi', 'hu', 'is', 'id', 'ga', 'it', 'ja', 'ko', 'la', 'lv', 'lt', 'mt', 'mr', 'pcm', 'sme', 'no', 'cu', 'orv', 'fro', 'fa', 'pl', 'pt', 'ro', 'ru', 'sa', 'gd', 'sr', 'sk', 'sl', 'es', 'sv', 'ta', 'te', 'tr', 'uk', 'ur', 'ug', 'vi', 'cy', 'hyw', 'wo']
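
# One token-classification model per source language: XLM-RoBERTa base fine-tuned on UD v2.8 UPOS data.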
model_ids = [
f"wietsedv/xlm-roberta-base-ft-udpos28-{code}" for code in lang_codes
]
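
# HTML link to a model's page on the Hugging Face Hub.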
def model_link(model_id):
return f"<a href='https://huggingface.co/{model_id}' target='_blank'>πŸ€— {model_id}</a>"
article = "<table style='width:auto'>"
article += "<thead><th>Source language</th><th>Model</th></thead><tbody>"
article += "\n".join([f"<tr><td>{l}</td><td>{model_link(m)}</td></tr>" for l, m in zip(lang_names, model_ids)])
article += "</tbody></table>"
loaded_model_id = None
pipe = None
def tag(text, lang_index):
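    """Tag `text` with universal POS tags using the model for the selected source language."""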
global loaded_model_id, pipe
model_id = model_ids[lang_index]
if pipe is None or model_id != loaded_model_id:
loaded_model_id = model_id
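        # aggregation_strategy="none" yields one prediction per sentencepiece token, with the raw
        # token in "word" (a leading "▁" marks a word boundary) and the UPOS tag in "entity",
        # so the word-level merging can be done manually below.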
pipe = pipeline("token-classification", model_id, aggregation_strategy="first")
# Aggregate words:
# split on whitespace and PUNCT, but merge other subtokens (keep first tag)
out = []
    for g in pipe(text):
        if not out or g["word"].startswith("▁") or g["entity"] == "PUNCT":
            out.append((g["word"].lstrip("▁"), g["entity"]))
        else:
            out[-1] = (out[-1][0] + g["word"], out[-1][1])
return out, model_link(model_id)
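
# Example: tag("Dit is een test.", lang_names.index("English")) returns a list of
# (word, tag) pairs for the HighlightedText output plus an HTML link to the model used.

# Demo UI: a text box and a source-language dropdown as inputs, highlighted tags and a
# model link as outputs.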
iface = gr.Interface(
fn=tag,
inputs=[
gradio.inputs.Textbox(label="Text", lines=3, placeholder="Enter a sentence here..."),
gradio.inputs.Dropdown(label="Source language", choices=lang_names, type="index"),
],
outputs=[
gradio.outputs.HighlightedText(label="Output"),
gradio.outputs.HTML(label="Model"),
],
title="Cross-lingual part-of-speech tagging",
description="Enter some text in any language and choose any of 65 source languages. The source language is the language for which XLM-RoBERTa is fine-tuned on Universal Dependencies v2.8 universal part-of-speech tagging data. This space is meant to demonstrate cross-lingual transfer, so the language of your sentence and the selected language do not have to match. You may find fewer mistakes if the selected language is similar to the actual language of your text.",
allow_screenshot=False,
allow_flagging="never",
article=article,
theme="huggingface",
# examples=[["Dit is een test.", "English"]]
)
iface.launch()