"""Minimal Gradio demo that queries an OpenAssistant model via the
Hugging Face text-generation Inference API and shows the reply."""

import os

import gradio as gr
from text_generation import Client, InferenceAPIClient


def get_client(model: str) -> InferenceAPIClient:
    """Return an Inference API client for *model*.

    Authenticates with the HF_TOKEN environment variable when set;
    uses a generous 100 s timeout because hosted generation can be slow.
    """
    return InferenceAPIClient(model, token=os.getenv("HF_TOKEN", None), timeout=100)


def get_usernames(model: str):
    """
    Returns:
        (str, str, str, str): pre-prompt, username, bot name, separator
    """
    # OpenAssistant pythia checkpoints were trained with special role tokens;
    # every other model falls back to a plain "User:/Assistant:" transcript.
    if model in (
        "OpenAssistant/oasst-sft-1-pythia-12b",
        "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5",
    ):
        return "", "<|prompter|>", "<|assistant|>", "<|endoftext|>"
    return "", "User: ", "Assistant: ", "\n"


def predict(inputs: str):
    """Yield the model's reply to *inputs* (generator, as Gradio expects).

    The prompt is the user text plus a length hint, wrapped in the
    model's role/separator tokens. No conversation history is kept.
    """
    model = "OpenAssistant/oasst-sft-1-pythia-12b"
    client = get_client(model)
    preprompt, user_name, assistant_name, sep = get_usernames(model)

    past = []  # placeholder for chat history; currently always empty
    limits = ",in max 200 words"
    total_inputs = (
        preprompt + "".join(past) + inputs + limits + sep + assistant_name.rstrip()
    )

    if model in (
        "OpenAssistant/oasst-sft-1-pythia-12b",
        "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5",
    ):
        # Non-streaming call: returns a single Response object.
        response = client.generate(
            total_inputs,
            typical_p=0.1,
            truncate=1000,
            watermark=0,
            max_new_tokens=502,
        )
        yield response.generated_text


g = gr.Interface(
    fn=predict,
    inputs=[
        gr.components.Textbox(lines=3, label="Hi, how can I help you?", placeholder=""),
    ],
    outputs=[
        # Fixed: was gr.inputs.Textbox — the gr.inputs namespace is deprecated
        # (removed in Gradio 4.x) and inconsistent with the input side above.
        gr.components.Textbox(
            lines=10,
            label="",
        )
    ],
)
g.queue(concurrency_count=1)
g.launch()