import logging
import pathlib

import gradio as gr
import pandas as pd
from gt4sd.algorithms.generation.hugging_face import (
    HuggingFaceCTRLGenerator,
    HuggingFaceGenerationAlgorithm,
    HuggingFaceGPT2Generator,
    HuggingFaceOpenAIGPTGenerator,
    HuggingFaceTransfoXLGenerator,
    HuggingFaceXLMGenerator,
    HuggingFaceXLNetGenerator,
)
from gt4sd.algorithms.registry import ApplicationsRegistry

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

# Maps an algorithm-application name (the part of the dropdown value before
# the first underscore) to its GT4SD configuration class.
MODEL_FN = {
    "HuggingFaceCTRLGenerator": HuggingFaceCTRLGenerator,
    "HuggingFaceGPT2Generator": HuggingFaceGPT2Generator,
    "HuggingFaceTransfoXLGenerator": HuggingFaceTransfoXLGenerator,
    "HuggingFaceOpenAIGPTGenerator": HuggingFaceOpenAIGPTGenerator,
    "HuggingFaceXLMGenerator": HuggingFaceXLMGenerator,
    "HuggingFaceXLNetGenerator": HuggingFaceXLNetGenerator,
}


def run_inference(
    model_type: str,
    prompt: str,
    length: float,
    temperature: float,
    prefix: str,
    k: float,
    p: float,
    repetition_penalty: float,
) -> str:
    """Sample one text continuation from the selected HuggingFace model.

    Args:
        model_type: dropdown value of the form
            ``"<algorithm_application>_<algorithm_version>"``, e.g.
            ``"HuggingFaceGPT2Generator_gpt2"``.
        prompt: text prompt to continue.
        length: maximal generation length (slider delivers an integer-valued
            float; forwarded as-is to the configuration).
        temperature: decoding temperature.
        prefix: text prepended before the prompt.
        k: top-k sampling cutoff.
        p: nucleus (top-p) sampling cutoff.
        repetition_penalty: penalty applied to repeated tokens.

    Returns:
        The generated text sample.

    Raises:
        ValueError: if the algorithm application is not in ``MODEL_FN``.
    """
    # Split only on the FIRST underscore so that versions which themselves
    # contain underscores survive intact (split("_")[1] would truncate them).
    model, _, version = model_type.partition("_")
    if model not in MODEL_FN:
        raise ValueError(f"Model type {model} not supported")
    config = MODEL_FN[model](
        algorithm_version=version,
        prompt=prompt,
        length=length,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        k=k,
        p=p,
        prefix=prefix,
    )
    algorithm = HuggingFaceGenerationAlgorithm(config)
    # sample() yields lazily; take the single requested sample without
    # materializing an intermediate list.
    return next(iter(algorithm.sample(1)))


if __name__ == "__main__":
    # Preparation (retrieve all available HuggingFace generation algorithms
    # registered in GT4SD and format them as "<application>_<version>").
    all_algos = ApplicationsRegistry.list_available()
    algos = [
        x["algorithm_application"] + "_" + x["algorithm_version"]
        for x in all_algos
        if "HuggingFace" in x["algorithm_name"]
    ]

    # Load metadata (example rows, article and description markdown) from the
    # model_cards folder next to this script.
    metadata_root = pathlib.Path(__file__).parent.joinpath("model_cards")
    examples = pd.read_csv(metadata_root.joinpath("examples.csv"), header=None).fillna(
        ""
    )
    print("Examples: ", examples.values.tolist())

    with open(metadata_root.joinpath("article.md"), "r", encoding="utf-8") as f:
        article = f.read()
    with open(metadata_root.joinpath("description.md"), "r", encoding="utf-8") as f:
        description = f.read()

    demo = gr.Interface(
        fn=run_inference,
        title="HuggingFace language models",
        inputs=[
            gr.Dropdown(
                algos,
                label="Language model",
                value="HuggingFaceGPT2Generator_gpt2",
            ),
            gr.Textbox(
                label="Text prompt",
                placeholder="I'm a stochastic parrot.",
                lines=1,
            ),
            gr.Slider(minimum=5, maximum=100, value=20, label="Maximal length", step=1),
            gr.Slider(
                minimum=0.6, maximum=1.5, value=1.1, label="Decoding temperature"
            ),
            gr.Textbox(
                label="Prefix", placeholder="Some prefix (before the prompt)", lines=1
            ),
            gr.Slider(minimum=2, maximum=500, value=50, label="Top-k", step=1),
            # step was 1 (copy-paste from Top-k), which made the 0.5-1.0
            # nucleus-sampling range effectively binary; use a fractional step.
            gr.Slider(minimum=0.5, maximum=1, value=1.0, label="Decoding-p", step=0.05),
            gr.Slider(minimum=0.5, maximum=5, value=1.0, label="Repetition penalty"),
        ],
        outputs=gr.Textbox(label="Output"),
        article=article,
        description=description,
        examples=examples.values.tolist(),
    )
    demo.launch(debug=True, show_error=True)