File size: 3,277 Bytes
d320ce9
 
 
 
6938961
 
 
d320ce9
 
 
 
 
 
 
 
 
6938961
 
 
 
 
 
 
 
 
 
d320ce9
 
6938961
 
 
 
 
 
 
 
 
 
 
 
 
d320ce9
 
 
 
 
 
 
 
 
6938961
 
d320ce9
 
 
 
 
 
6938961
d320ce9
 
 
 
 
 
 
 
 
6938961
d320ce9
6938961
d320ce9
6938961
 
d320ce9
 
6938961
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d320ce9
6938961
d320ce9
6938961
d320ce9
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import logging
import pathlib
import gradio as gr
import pandas as pd
from gt4sd.algorithms.conditional_generation.key_bert import (
    KeywordBERTGenerationAlgorithm,
    KeyBERTGenerator,
)
from gt4sd.algorithms.registry import ApplicationsRegistry


# Module-level logger; attaching a NullHandler prevents "No handler could be
# found" warnings when the hosting application has not configured logging.
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())


def run_inference(
    algorithm_version: str,
    text: str,
    minimum_keyphrase_ngram: int,
    maximum_keyphrase_ngram: int,
    stop_words: str,
    use_maxsum: bool,
    number_of_candidates: int,
    use_mmr: bool,
    diversity: float,
    number_of_keywords: int,
):
    """Extract keywords from a text with a KeyBERT-based generator.

    Args:
        algorithm_version: version identifier of the KeyBERT model to load.
        text: input text to extract keywords from.
        minimum_keyphrase_ngram: lower bound of the keyphrase n-gram range.
        maximum_keyphrase_ngram: upper bound of the keyphrase n-gram range.
        stop_words: stop-word set identifier (e.g. ``"english"``).
        use_maxsum: whether to use Max Sum similarity for candidate selection.
        number_of_candidates: number of candidates considered by Max Sum.
        use_mmr: whether to use Maximal Marginal Relevance for diversification.
        diversity: diversity factor used when ``use_mmr`` is enabled.
        number_of_keywords: number of keywords to sample and return.

    Returns:
        A list of up to ``number_of_keywords`` sampled keywords.
    """
    config = KeyBERTGenerator(
        algorithm_version=algorithm_version,
        minimum_keyphrase_ngram=minimum_keyphrase_ngram,
        maximum_keyphrase_ngram=maximum_keyphrase_ngram,
        stop_words=stop_words,
        top_n=number_of_keywords,
        use_maxsum=use_maxsum,
        use_mmr=use_mmr,
        diversity=diversity,
        number_of_candidates=number_of_candidates,
    )
    model = KeywordBERTGenerationAlgorithm(configuration=config, target=text)
    # Use a distinct name for the result: the original rebound the `text`
    # parameter, shadowing the input for no benefit.
    keywords = list(model.sample(number_of_keywords))

    return keywords


if __name__ == "__main__":

    # Preparation: collect the algorithm versions of all registered
    # KeywordBERT applications to populate the version dropdown.
    all_algos = ApplicationsRegistry.list_available()
    algos = [
        entry["algorithm_version"]
        for entry in all_algos
        if "KeywordBERT" in entry["algorithm_name"]
    ]

    # Load metadata (examples table plus markdown article/description)
    # stored next to this script under model_cards/.
    metadata_root = pathlib.Path(__file__).parent.joinpath("model_cards")

    examples = pd.read_csv(
        metadata_root.joinpath("examples.csv"), sep=",", header=None
    ).fillna("")

    # Explicit UTF-8 so the markdown decodes identically on every platform
    # (the original relied on the locale-dependent default encoding).
    article = metadata_root.joinpath("article.md").read_text(encoding="utf-8")
    description = metadata_root.joinpath("description.md").read_text(encoding="utf-8")

    demo = gr.Interface(
        fn=run_inference,
        title="KeywordBERT",
        inputs=[
            gr.Dropdown(algos, label="Algorithm version", value="circa_bert_v2"),
            gr.Textbox(
                label="Text prompt",
                placeholder="This is a text I want to understand better",
                lines=5,
            ),
            gr.Slider(
                minimum=1, maximum=5, value=1, label="Minimum keyphrase ngram", step=1
            ),
            # BUGFIX: the original default (value=1) was below the slider's
            # minimum of 2; start at the lowest valid value instead.
            gr.Slider(
                minimum=2, maximum=10, value=2, label="Maximum keyphrase ngram", step=1
            ),
            gr.Textbox(label="Stop words", placeholder="english", lines=1),
            gr.Radio(choices=[True, False], label="MaxSum", value=False),
            gr.Slider(
                minimum=5, maximum=100, value=20, label="MaxSum candidates", step=1
            ),
            gr.Radio(
                choices=[True, False],
                label="Max. marginal relevance control",
                value=False,
            ),
            gr.Slider(minimum=0.1, maximum=1, value=0.5, label="Diversity"),
            gr.Slider(
                minimum=1, maximum=50, value=10, label="Number of keywords", step=1
            ),
        ],
        outputs=gr.Textbox(label="Output"),
        article=article,
        description=description,
        examples=examples.values.tolist(),
    )
    demo.launch(debug=True, show_error=True)