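"""Gradio demo for text generation with HuggingFace language models via GT4SD.

Exposes the GT4SD HuggingFace generation algorithms (CTRL, GPT-2, Transfo-XL,
OpenAI GPT, XLM, XLNet) behind a web UI with configurable decoding parameters
(prompt, length, temperature, prefix, top-k, top-p, repetition penalty).
"""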
import logging
import pathlib
import gradio as gr
import pandas as pd
from gt4sd.algorithms.generation.hugging_face import (
    HuggingFaceCTRLGenerator,
    HuggingFaceGenerationAlgorithm,
    HuggingFaceGPT2Generator,
    HuggingFaceTransfoXLGenerator,
    HuggingFaceOpenAIGPTGenerator,
    HuggingFaceXLMGenerator,
    HuggingFaceXLNetGenerator,
)
from gt4sd.algorithms.registry import ApplicationsRegistry


logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

# Map each algorithm application name to its GT4SD configuration class.
MODEL_FN = {
    "HuggingFaceCTRLGenerator": HuggingFaceCTRLGenerator,
    "HuggingFaceGPT2Generator": HuggingFaceGPT2Generator,
    "HuggingFaceTransfoXLGenerator": HuggingFaceTransfoXLGenerator,
    "HuggingFaceOpenAIGPTGenerator": HuggingFaceOpenAIGPTGenerator,
    "HuggingFaceXLMGenerator": HuggingFaceXLMGenerator,
    "HuggingFaceXLNetGenerator": HuggingFaceXLNetGenerator,
}


def run_inference(
    model_type: str,
    prompt: str,
    length: float,
    temperature: float,
    prefix: str,
    k: float,
    p: float,
    repetition_penalty: float,
):
    """Generate text with the selected HuggingFace model and decoding parameters."""
    # The dropdown value encodes "<application>_<version>", e.g.
    # "HuggingFaceGPT2Generator_gpt2"; split only on the first underscore so
    # versions that themselves contain underscores stay intact.
    model, version = model_type.split("_", maxsplit=1)

    if model not in MODEL_FN.keys():
        raise ValueError(f"Model type {model} not supported")
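    # Build the algorithm configuration from the user-provided decoding parameters.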
    config = MODEL_FN[model](
        algorithm_version=version,
        prompt=prompt,
        length=length,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        k=k,
        p=p,
        prefix=prefix,
    )

    # Wrap the configuration in the generation algorithm and draw a single sample.
    algorithm = HuggingFaceGenerationAlgorithm(config)
    text = list(algorithm.sample(1))[0]

    return text


if __name__ == "__main__":

    # Preparation (retrieve all available algorithms)
    all_algos = ApplicationsRegistry.list_available()
    algos = [
        x["algorithm_application"] + "_" + x["algorithm_version"]
        for x in list(filter(lambda x: "HuggingFace" in x["algorithm_name"], all_algos))
    ]

    # Load metadata
    metadata_root = pathlib.Path(__file__).parent.joinpath("model_cards")

    # Pre-filled example rows for the UI (one column per input component, no header row).
    examples = pd.read_csv(metadata_root.joinpath("examples.csv"), header=None).fillna(
        ""
    )
    print("Examples: ", examples.values.tolist())

    with open(metadata_root.joinpath("article.md"), "r") as f:
        article = f.read()
    with open(metadata_root.joinpath("description.md"), "r") as f:
        description = f.read()

    # Assemble the Gradio interface; the input order must match the columns in examples.csv.
    demo = gr.Interface(
        fn=run_inference,
        title="HuggingFace language models",
        inputs=[
            gr.Dropdown(
                algos,
                label="Language model",
                value="HuggingFaceGPT2Generator_gpt2",
            ),
            gr.Textbox(
                label="Text prompt",
                placeholder="I'm a stochastic parrot.",
                lines=1,
            ),
            gr.Slider(minimum=5, maximum=100, value=20, label="Maximum length", step=1),
            gr.Slider(
                minimum=0.6, maximum=1.5, value=1.1, label="Decoding temperature"
            ),
            gr.Textbox(
                label="Prefix", placeholder="Some prefix (before the prompt)", lines=1
            ),
            gr.Slider(minimum=2, maximum=500, value=50, label="Top-k", step=1),
            # step must be fractional, otherwise only 0.5 and 1.0 are selectable
            gr.Slider(minimum=0.5, maximum=1, value=1.0, label="Decoding-p", step=0.05),
            gr.Slider(minimum=0.5, maximum=5, value=1.0, label="Repetition penalty"),
        ],
        outputs=gr.Textbox(label="Output"),
        article=article,
        description=description,
        examples=examples.values.tolist(),
    )
    demo.launch(debug=True, show_error=True)