Spaces:
Running
Running
File size: 3,277 Bytes
d320ce9 6938961 d320ce9 6938961 d320ce9 6938961 d320ce9 6938961 d320ce9 6938961 d320ce9 6938961 d320ce9 6938961 d320ce9 6938961 d320ce9 6938961 d320ce9 6938961 d320ce9 6938961 d320ce9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 |
import logging
import pathlib
import gradio as gr
import pandas as pd
from gt4sd.algorithms.conditional_generation.key_bert import (
KeywordBERTGenerationAlgorithm,
KeyBERTGenerator,
)
from gt4sd.algorithms.registry import ApplicationsRegistry
# Module-level logger. A NullHandler is attached so that nothing is emitted
# (and no "no handler" warning appears) unless the application configures logging.
logger = logging.getLogger(name=__name__)
logger.addHandler(hdlr=logging.NullHandler())
def run_inference(
    algorithm_version: str,
    text: str,
    minimum_keyphrase_ngram: int,
    maximum_keyphrase_ngram: int,
    stop_words: str,
    use_maxsum: bool,
    number_of_candidates: int,
    use_mmr: bool,
    diversity: float,
    number_of_keywords: int,
):
    """Extract keywords from ``text`` with a KeywordBERT algorithm.

    This is the callback wired into the Gradio interface; its parameters are
    in the same order as the interface inputs.

    Args:
        algorithm_version: KeywordBERT algorithm version to configure.
        text: free-text prompt to extract keywords from.
        minimum_keyphrase_ngram: lower bound of the keyphrase n-gram length.
        maximum_keyphrase_ngram: upper bound of the keyphrase n-gram length.
            If it is smaller than the lower bound, the two are swapped.
        stop_words: stop-word list name forwarded to ``KeyBERTGenerator``.
            A blank value falls back to ``"english"`` (the UI placeholder).
        use_maxsum: forwarded as ``use_maxsum`` (UI label "MaxSum").
        number_of_candidates: forwarded as ``number_of_candidates``
            (UI label "MaxSum candidates").
        use_mmr: forwarded as ``use_mmr`` (UI label
            "Max. marginal relevance control").
        diversity: forwarded as ``diversity``.
        number_of_keywords: used both as the generator's ``top_n`` and as the
            number of items requested from ``model.sample``.

    Returns:
        A list of the items yielded by ``model.sample``.
    """
    # Slider values may arrive as floats (e.g. 2.0). The n-gram bounds and
    # counts are sizes/ranges downstream, so coerce them explicitly.
    min_ngram = int(minimum_keyphrase_ngram)
    max_ngram = int(maximum_keyphrase_ngram)
    # The two ngram sliders are independent, so a user can pick min > max,
    # which is not a valid n-gram range. Order the bounds instead of failing.
    min_ngram, max_ngram = min(min_ngram, max_ngram), max(min_ngram, max_ngram)
    top_n = int(number_of_keywords)

    # An empty "Stop words" textbox yields "", which is not a valid stop-word
    # list name. Fall back to the "english" value shown as the placeholder.
    stop_words = (stop_words or "").strip() or "english"

    config = KeyBERTGenerator(
        algorithm_version=algorithm_version,
        minimum_keyphrase_ngram=min_ngram,
        maximum_keyphrase_ngram=max_ngram,
        stop_words=stop_words,
        top_n=top_n,
        use_maxsum=use_maxsum,
        use_mmr=use_mmr,
        diversity=float(diversity),
        number_of_candidates=int(number_of_candidates),
    )
    model = KeywordBERTGenerationAlgorithm(configuration=config, target=text)
    # Materialize the sampled items so Gradio receives a concrete value.
    keywords = list(model.sample(top_n))
    return keywords
if __name__ == "__main__":
    # Preparation: retrieve all registered algorithms and keep the versions of
    # the KeywordBERT ones for the dropdown.
    all_algos = ApplicationsRegistry.list_available()
    algos = [
        algo["algorithm_version"]
        for algo in all_algos
        if "KeywordBERT" in algo["algorithm_name"]
    ]
    # Prefer the historical default, but don't preselect a value that is not
    # among the discovered versions.
    default_version = (
        "circa_bert_v2"
        if "circa_bert_v2" in algos
        else (algos[0] if algos else None)
    )

    # Load metadata (examples table and markdown cards) shipped next to this file.
    metadata_root = pathlib.Path(__file__).parent.joinpath("model_cards")
    examples = pd.read_csv(
        metadata_root.joinpath("examples.csv"), sep=",", header=None
    ).fillna("")
    # Explicit encoding so the cards decode the same on every platform.
    article = metadata_root.joinpath("article.md").read_text(encoding="utf-8")
    description = metadata_root.joinpath("description.md").read_text(
        encoding="utf-8"
    )

    demo = gr.Interface(
        fn=run_inference,
        title="KeywordBERT",
        inputs=[
            gr.Dropdown(algos, label="Algorithm version", value=default_version),
            gr.Textbox(
                label="Text prompt",
                placeholder="This is a text I want to understand better",
                lines=5,
            ),
            gr.Slider(
                minimum=1, maximum=5, value=1, label="Minimum keyphrase ngram", step=1
            ),
            # Default must lie within [minimum, maximum]; it was 1 (< minimum=2).
            gr.Slider(
                minimum=2, maximum=10, value=2, label="Maximum keyphrase ngram", step=1
            ),
            gr.Textbox(label="Stop words", placeholder="english", lines=1),
            gr.Radio(choices=[True, False], label="MaxSum", value=False),
            gr.Slider(
                minimum=5, maximum=100, value=20, label="MaxSum candidates", step=1
            ),
            gr.Radio(
                choices=[True, False],
                label="Max. marginal relevance control",
                value=False,
            ),
            gr.Slider(minimum=0.1, maximum=1, value=0.5, label="Diversity"),
            gr.Slider(
                minimum=1, maximum=50, value=10, label="Number of keywords", step=1
            ),
        ],
        outputs=gr.Textbox(label="Output"),
        article=article,
        description=description,
        examples=examples.values.tolist(),
    )
    demo.launch(debug=True, show_error=True)
|