Spaces:
Running
Running
File size: 3,560 Bytes
1298030 9d0d0bd 1298030 9d0d0bd 1298030 9d0d0bd 1298030 9d0d0bd 1298030 9d0d0bd 1298030 9d0d0bd 1298030 9d0d0bd 1298030 9d0d0bd 1298030 9d0d0bd 1298030 9d0d0bd 818ab9b 1298030 9d0d0bd 1298030 9d0d0bd 1298030 9d0d0bd 1298030 9d0d0bd 1298030 9d0d0bd 1298030 9d0d0bd 1298030 9d0d0bd 1298030 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
import logging
import pathlib
import gradio as gr
import pandas as pd
from gt4sd.algorithms.generation.hugging_face import (
HuggingFaceCTRLGenerator,
HuggingFaceGenerationAlgorithm,
HuggingFaceGPT2Generator,
HuggingFaceTransfoXLGenerator,
HuggingFaceOpenAIGPTGenerator,
HuggingFaceXLMGenerator,
HuggingFaceXLNetGenerator,
)
from gt4sd.algorithms.registry import ApplicationsRegistry
# Module-level logger. The NullHandler means records from this logger are not
# routed to logging's last-resort stderr handler unless the host application
# configures handlers of its own.
logger = logging.getLogger(name=__name__)
logger.addHandler(logging.NullHandler())

# Supported generators: maps the application name (the part of a dropdown value
# before the first "_", as parsed in run_inference) to its gt4sd configuration
# class.
MODEL_FN = dict(
    HuggingFaceCTRLGenerator=HuggingFaceCTRLGenerator,
    HuggingFaceGPT2Generator=HuggingFaceGPT2Generator,
    HuggingFaceTransfoXLGenerator=HuggingFaceTransfoXLGenerator,
    HuggingFaceOpenAIGPTGenerator=HuggingFaceOpenAIGPTGenerator,
    HuggingFaceXLMGenerator=HuggingFaceXLMGenerator,
    HuggingFaceXLNetGenerator=HuggingFaceXLNetGenerator,
)
def run_inference(
    model_type: str,
    prompt: str,
    length: float,
    temperature: float,
    prefix: str,
    k: float,
    p: float,
    repetition_penalty: float,
):
    """Generate a single text sample with the selected HuggingFace model.

    Args:
        model_type: Dropdown value of the form ``"<application>_<version>"``,
            e.g. ``"HuggingFaceGPT2Generator_gpt2"``. Only the first
            underscore separates the two parts, so the version may itself
            contain underscores.
        prompt: Text prompt to continue.
        length: Maximal generation length (truncated to ``int``).
        temperature: Decoding temperature.
        prefix: Text placed before the prompt.
        k: Top-k value (truncated to ``int``).
        p: Decoding-p value.
        repetition_penalty: Repetition penalty.

    Returns:
        The first item yielded by ``HuggingFaceGenerationAlgorithm.sample(1)``
        (presumably the generated string, since the UI shows it in a Textbox).

    Raises:
        ValueError: If ``model_type`` is malformed or names an unsupported
            application.
    """
    # Dropdown values are built as application + "_" + version, so splitting
    # on the first underscore only recovers the full version string.
    application, sep, version = model_type.partition("_")
    if not sep or not application or not version:
        raise ValueError(
            f"Malformed model type {model_type!r}; expected '<model>_<version>'"
        )
    if application not in MODEL_FN:
        raise ValueError(f"Model type {application} not supported")

    config = MODEL_FN[application](
        algorithm_version=version,
        prompt=prompt,
        # Slider values arrive as numbers that may be floats; length and top-k
        # are counts, so pass them as ints.
        length=int(length),
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        k=int(k),
        p=p,
        prefix=prefix,
    )
    algorithm = HuggingFaceGenerationAlgorithm(config)
    # Exactly one sample is requested; return it.
    samples = list(algorithm.sample(1))
    return samples[0]
if __name__ == "__main__":
    # Build and launch the Gradio demo for the registered HuggingFace models.

    # Preparation: collect every registered HuggingFace algorithm as
    # "<application>_<version>", the format run_inference splits back apart.
    all_algos = ApplicationsRegistry.list_available()
    algos = [
        x["algorithm_application"] + "_" + x["algorithm_version"]
        for x in all_algos
        if "HuggingFace" in x["algorithm_name"]
    ]

    # Load metadata: example rows plus the article/description markdown.
    metadata_root = pathlib.Path(__file__).parent.joinpath("model_cards")
    examples = pd.read_csv(metadata_root.joinpath("examples.csv"), header=None).fillna(
        ""
    )
    example_rows = examples.values.tolist()
    print("Examples: ", example_rows)
    # Explicit UTF-8 so decoding does not depend on the host locale.
    article = metadata_root.joinpath("article.md").read_text(encoding="utf-8")
    description = metadata_root.joinpath("description.md").read_text(
        encoding="utf-8"
    )

    demo = gr.Interface(
        fn=run_inference,
        title="HuggingFace language models",
        inputs=[
            gr.Dropdown(
                algos,
                label="Language model",
                value="HuggingFaceGPT2Generator_gpt2",
            ),
            gr.Textbox(
                label="Text prompt",
                placeholder="I'm a stochastic parrot.",
                lines=1,
            ),
            gr.Slider(minimum=5, maximum=100, value=20, label="Maximal length", step=1),
            gr.Slider(
                minimum=0.6, maximum=1.5, value=1.1, label="Decoding temperature"
            ),
            gr.Textbox(
                label="Prefix", placeholder="Some prefix (before the prompt)", lines=1
            ),
            gr.Slider(minimum=2, maximum=500, value=50, label="Top-k", step=1),
            # A step of 1 on a 0.5-1 range left no intermediate values selectable.
            gr.Slider(
                minimum=0.5, maximum=1, value=1.0, label="Decoding-p", step=0.01
            ),
            gr.Slider(minimum=0.5, maximum=5, value=1.0, label="Repetition penalty"),
        ],
        outputs=gr.Textbox(label="Output"),
        article=article,
        description=description,
        examples=example_rows,
    )
    demo.launch(debug=True, show_error=True)
|