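# NOTE: the next lines are file-viewer page chrome captured together with the
# source (author avatar label, last commit, view links, file size); they are
# not part of the program.
# ruanchaves's picture
# uncheck reranker
# da00f10
# raw
# history blame
# 5.38 kB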
import gradio as gr
from hashformers import TransformerWordSegmenter as WordSegmenter
import pandas as pd
# Static text and example rows handed to gr.Interface further down the file.

# HTML footer: author credit and a link to the library.
article_string = (
    'Author: <a href="https://huggingface.co/ruanchaves">Ruan Chaves Rodrigues</a>. '
    'Read more about the <a href="https://github.com/ruanchaves/hashformers">Hashformers library</a>.'
)

app_title = "Hashtag segmentation"

# HTML description; the leading and trailing newlines are part of the value.
app_description = """
Hashtag segmentation is the task of automatically adding spaces between the words on a hashtag.
This app uses the <a href="https://github.com/ruanchaves/hashformers">Hashformers library</a> to suggest segmentations for hashtags.
Enter a hashtag or pick one from the examples below. The app will suggest the best segmentation for the hashtag.
In the advanced settings, decreasing the slider values will make the app faster, but it may also reduce its accuracy.
"""

# Each example row is [hashtag, language]; the language must be a key of
# model_dict.
app_examples = [
    [hashtag, language]
    for hashtag, language in (
        ("#cristianoronaldo", "portuguese"),
        ("#madridsinfiltros", "spanish"),
        ("#kuenstlicheintelligenz", "german"),
        ("#dadscare", "english (fast)"),
        ("#nowthatcherisdead", "english"),
    )
]

# NOTE(review): not referenced anywhere else in the visible code.
output_json_component_description = {"": ""}
# (segmenter checkpoint, reranker checkpoint) for every supported language.
# Insertion order is preserved, so it is also the order of language_list.
_MODEL_NAMES = {
    "english": ("gpt2", "bert-base-uncased"),
    "english (fast)": ("distilgpt2", "distilbert-base-uncased"),
    "spanish": (
        "mrm8488/spanish-gpt2",
        "dccuchile/bert-base-spanish-wwm-cased",
    ),
    "portuguese": (
        "pierreguillou/gpt2-small-portuguese",
        "neuralmind/bert-base-portuguese-cased",
    ),
    "german": ("dbmdz/german-gpt2", "bert-base-german-cased"),
}

# One WordSegmenter per language, all constructed when the module is loaded,
# with the segmenter pinned to the CPU.
model_dict = {
    language: WordSegmenter(
        segmenter_model_name_or_path=segmenter_name,
        reranker_model_name_or_path=reranker_name,
        segmenter_device="cpu",
    )
    for language, (segmenter_name, reranker_name) in _MODEL_NAMES.items()
}

# Language names in declaration order.
language_list = list(model_dict)
def format_dataframe(df):
    """Reduce a rank table to its segmentation and inverted-score columns.

    Args:
        df: A DataFrame with at least ``"segmentation"`` and ``"score"``
            columns, or any other object.

    Returns:
        A new two-column DataFrame whose ``"score"`` is ``1 / score`` rounded
        to four decimals. Anything that is not a DataFrame is returned
        unchanged, so callers can pass a missing rank table straight through.
    """
    if not isinstance(df, pd.DataFrame):
        return df
    # Work on an explicit copy: assigning into a column-sliced frame is the
    # chained-assignment pattern pandas warns about, and the caller's frame
    # must stay untouched.
    formatted = df[["segmentation", "score"]].copy()
    # Vectorised reciprocal: a zero score becomes inf instead of raising
    # ZeroDivisionError from a per-element Python division.
    formatted["score"] = (1 / formatted["score"]).round(4)
    return formatted
def convert_to_score_dict(df):
    """Build a ``{segmentation: score}`` lookup from a rank table.

    Args:
        df: A DataFrame with ``"segmentation"`` and ``"score"`` columns, or
            any other object.

    Returns:
        A dict keyed by segmentation string; an empty dict when ``df`` is not
        a DataFrame.
    """
    if isinstance(df, pd.DataFrame):
        pairs = df[["segmentation", "score"]]
        # Pair the two columns row by row; a repeated segmentation keeps the
        # score of its last row.
        return dict(zip(pairs["segmentation"], pairs["score"]))
    return {}
def get_candidates_df(candidates, segmenter_score_dict, reranker_score_dict):
    """Tabulate candidate segmentations alongside both of their scores.

    Args:
        candidates: Iterable of segmentation strings, one output row each.
        segmenter_score_dict: Mapping from segmentation to segmenter score.
        reranker_score_dict: Mapping from segmentation to reranker score.

    Returns:
        A DataFrame with ``"segmentation"``, ``"segmenter score"`` and
        ``"reranker score"`` columns. A candidate missing from a mapping
        gets a score of 0 in that column.
    """
    rows = [
        {
            "segmentation": text,
            "segmenter score": segmenter_score_dict.get(text, 0),
            "reranker score": reranker_score_dict.get(text, 0),
        }
        for text in candidates
    ]
    return pd.DataFrame(rows)
def parse_candidates(candidates):
    """Split a comma-separated string into whitespace-trimmed items.

    Args:
        candidates: Comma-separated text; may be empty or ``None``.

    Returns:
        The list of trimmed pieces (empty pieces are kept), or an empty list
        when ``candidates`` is falsy.
    """
    if candidates:
        return [piece.strip() for piece in candidates.split(",")]
    return []
def predict(s1, language, use_reranker, topk, steps):
    """Segment one hashtag and list its best-scoring alternatives.

    Args:
        s1: The hashtag text entered by the user.
        language: Key of ``model_dict`` selecting the model; a falsy value
            falls back to ``"english (fast)"``.
        use_reranker: Forwarded to the model; also decides which rank table
            supplies the alternatives and which score column orders them.
        topk: Forwarded to the model's ``segment`` call as ``topk``.
        steps: Forwarded to the model's ``segment`` call as ``steps``.

    Returns:
        ``(top_segmentation, output_df)``: the model's first output and a
        DataFrame of that output plus up to three alternatives, de-duplicated
        and sorted best-first. ``(None, None)`` when the hashtag is empty or
        blank, or when ``topk`` or ``steps`` is falsy (e.g. 0).

    Raises:
        KeyError: If ``language`` is truthy but not a key of ``model_dict``.
    """
    # Nothing to segment: answer like the zero-slider guard below instead of
    # sending an empty string to the model.
    if not s1 or not str(s1).strip():
        return None, None

    chosen_model = model_dict[language] if language else model_dict["english (fast)"]

    # Both settings must be non-zero before the model is invoked.
    if not all([topk, steps]):
        return None, None

    segmentation = chosen_model.segment(
        [s1],
        use_reranker=use_reranker,
        return_ranks=True,
        topk=topk,
        steps=steps,
    )
    # NOTE(review): presumably reranker_rank is not a DataFrame when the
    # reranker is skipped; both helpers below tolerate that case.
    segmenter_df = format_dataframe(segmentation.segmenter_rank)
    reranker_df = format_dataframe(segmentation.reranker_rank)

    # The active ranking supplies both the alternatives and the sort key.
    if use_reranker:
        ranking_df, sort_column = reranker_df, "reranker score"
    else:
        ranking_df, sort_column = segmenter_df, "segmenter score"
    candidates_list = ranking_df.head(3)["segmentation"].tolist()
    top_segmentation = segmentation.output[0]

    segmenter_score_dict = convert_to_score_dict(segmenter_df)
    reranker_score_dict = convert_to_score_dict(reranker_df)
    top_segmentation_df = get_candidates_df(
        [top_segmentation], segmenter_score_dict, reranker_score_dict
    )
    candidates_df = get_candidates_df(
        candidates_list, segmenter_score_dict, reranker_score_dict
    )

    # The top answer usually reappears among the alternatives, so sort first
    # and then keep only the first occurrence of each segmentation.
    output_df = pd.concat([top_segmentation_df, candidates_df], axis=0)
    output_df = output_df.sort_values(by=sort_column, ascending=False)
    output_df = output_df.drop_duplicates(subset="segmentation", keep="first")
    return top_segmentation, output_df
# Input widgets, listed in the positional order of predict's parameters:
# (s1, language, use_reranker, topk, steps).
inputs = [
    gr.Textbox(label="Hashtag"),
    gr.Dropdown(choices=language_list, label="Language", value="english (fast)"),
    gr.Checkbox(label="Use reranker", value=False),
    # Both sliders start at 0; predict returns no result while either is 0.
    gr.Slider(
        minimum=0,
        maximum=100,
        value=20,
        label="Advanced setting - Beamsearch: Number of beams",
    ),
    gr.Slider(
        minimum=0,
        maximum=100,
        value=13,
        label="Advanced setting - Maximum number of spaces allowed",
    ),
]

# Output widgets, matching predict's (top_segmentation, output_df) return.
outputs = [
    gr.Textbox(label="Suggested segmentation"),
    gr.DataFrame(label="Top alternatives"),
]

# Build the interface, then start serving it as soon as the module runs.
demo = gr.Interface(
    fn=predict,
    inputs=inputs,
    outputs=outputs,
    title=app_title,
    description=app_description,
    examples=app_examples,
    article=article_string,
)
demo.launch()