Spaces:
Running
Running
import logging | |
import pathlib | |
import gradio as gr | |
import pandas as pd | |
from gt4sd.algorithms.conditional_generation.key_bert import ( | |
KeywordBERTGenerationAlgorithm, | |
KeyBERTGenerator, | |
) | |
from gt4sd.algorithms.registry import ApplicationsRegistry | |
# Module-wide logger. A NullHandler is attached so that importing this module
# does not by itself emit log output; whoever runs the app decides where
# messages go by configuring handlers on the logging hierarchy.
logger = logging.getLogger(name=__name__)
logger.addHandler(hdlr=logging.NullHandler())
def _as_flag(value) -> bool:
    """Interpret a radio-button value as a boolean.

    Values normally arrive as real booleans, but they may also arrive as
    their string form (e.g. from CSV example rows). ``bool("False")`` is
    ``True``, so strings are compared explicitly.
    """
    if isinstance(value, str):
        return value.strip().lower() == "true"
    return bool(value)


def run_inference(
    algorithm_version: str,
    text: str,
    minimum_keyphrase_ngram: int,
    maximum_keyphrase_ngram: int,
    stop_words: str,
    use_maxsum: bool,
    number_of_candidates: int,
    use_mmr: bool,
    diversity: float,
    number_of_keywords: int,
) -> list:
    """Extract keywords from ``text`` with a KeywordBERT algorithm version.

    This is the Gradio callback; the parameter order matches the order of the
    ``inputs`` widgets of the demo interface.

    Args:
        algorithm_version: KeywordBERT algorithm version to configure.
        text: Input text to extract keywords/keyphrases from.
        minimum_keyphrase_ngram: Lower bound of the keyphrase n-gram range.
        maximum_keyphrase_ngram: Upper bound of the keyphrase n-gram range.
            It is raised to ``minimum_keyphrase_ngram`` when smaller, so the
            range is never inverted.
        stop_words: Stop-word list name. A blank value falls back to
            ``"english"``.
        use_maxsum: Whether to enable MaxSum selection.
        number_of_candidates: Number of MaxSum candidates.
        use_mmr: Whether to enable maximal marginal relevance.
        diversity: Diversity parameter forwarded to the configuration.
        number_of_keywords: Number of keywords requested. It is used both as
            ``top_n`` and as the number of samples drawn.

    Returns:
        The items yielded by ``model.sample``, collected into a list.
    """
    # Inputs may arrive as floats (e.g. from CSV example rows), but n-gram
    # bounds and counts must be integers.
    min_ngram = int(minimum_keyphrase_ngram)
    # The two sliders overlap (min up to 5, max down to 2), so guard against
    # an inverted n-gram range.
    max_ngram = max(int(maximum_keyphrase_ngram), min_ngram)
    n_keywords = int(number_of_keywords)
    # An empty Textbox or an empty CSV cell (fillna("")) yields "". That is not
    # a stop-list name, so fall back to the value shown as the placeholder.
    stop_list = (stop_words or "").strip() or "english"

    config = KeyBERTGenerator(
        algorithm_version=algorithm_version,
        minimum_keyphrase_ngram=min_ngram,
        maximum_keyphrase_ngram=max_ngram,
        stop_words=stop_list,
        top_n=n_keywords,
        use_maxsum=_as_flag(use_maxsum),
        use_mmr=_as_flag(use_mmr),
        diversity=float(diversity),
        number_of_candidates=int(number_of_candidates),
    )
    model = KeywordBERTGenerationAlgorithm(configuration=config, target=text)
    keywords = list(model.sample(n_keywords))
    return keywords
if __name__ == "__main__":
    # Preparation: collect every registered KeywordBERT algorithm version.
    all_algos = ApplicationsRegistry.list_available()
    algos = [
        entry["algorithm_version"]
        for entry in all_algos
        if "KeywordBERT" in entry["algorithm_name"]
    ]
    # Preferred default version. If it is not registered, fall back to the
    # first available one so the dropdown starts on one of its own choices.
    default_version = "circa_bert_v2"
    if algos and default_version not in algos:
        default_version = algos[0]

    # Load metadata stored in ./model_cards next to this script.
    metadata_root = pathlib.Path(__file__).parent.joinpath("model_cards")
    examples = pd.read_csv(
        metadata_root.joinpath("examples.csv"), sep=",", header=None
    ).fillna("")
    # Explicit encoding: open()'s default encoding is locale-dependent.
    article = metadata_root.joinpath("article.md").read_text(encoding="utf-8")
    description = metadata_root.joinpath("description.md").read_text(
        encoding="utf-8"
    )

    demo = gr.Interface(
        fn=run_inference,
        title="KeywordBERT",
        inputs=[
            gr.Dropdown(algos, label="Algorithm version", value=default_version),
            gr.Textbox(
                label="Text prompt",
                placeholder="This is a text I want to understand better",
                lines=5,
            ),
            gr.Slider(
                minimum=1, maximum=5, value=1, label="Minimum keyphrase ngram", step=1
            ),
            # The default must lie inside [minimum, maximum]; it was 1, below minimum=2.
            gr.Slider(
                minimum=2, maximum=10, value=2, label="Maximum keyphrase ngram", step=1
            ),
            # Pre-filled rather than blank: an empty string is presumably not
            # accepted as a stop-list name downstream (only "english" is a
            # named list in scikit-learn's CountVectorizer).
            gr.Textbox(
                label="Stop words", placeholder="english", value="english", lines=1
            ),
            gr.Radio(choices=[True, False], label="MaxSum", value=False),
            gr.Slider(
                minimum=5, maximum=100, value=20, label="MaxSum candidates", step=1
            ),
            gr.Radio(
                choices=[True, False],
                label="Max. marginal relevance control",
                value=False,
            ),
            gr.Slider(minimum=0.1, maximum=1, value=0.5, label="Diversity"),
            gr.Slider(
                minimum=1, maximum=50, value=10, label="Number of keywords", step=1
            ),
        ],
        outputs=gr.Textbox(label="Output"),
        article=article,
        description=description,
        examples=examples.values.tolist(),
    )
    demo.launch(debug=True, show_error=True)