# NOTE(review): the three lines below were scraping residue (file size,
# git-blame commit hashes, and a line-number gutter), not Python source;
# kept here as a comment so the module parses.
# File size: 1,524 Bytes | e397ecb 8c958f2 ... 227623e | 1 2 3 ... 60 |
import os
import gradio as gr
from text_generation import Client, InferenceAPIClient
def get_client(model: str):
    """Build a Hugging Face Inference API client for *model*.

    Reads the optional ``HF_TOKEN`` environment variable for authentication
    and uses a 100-second request timeout.
    """
    hf_token = os.getenv("HF_TOKEN", None)
    return InferenceAPIClient(model, token=hf_token, timeout=100)
def get_usernames(model: str):
    """
    Returns:
        (str, str, str, str): pre-prompt, username, bot name, separator
    """
    # OpenAssistant pythia checkpoints use special role/end-of-text tokens;
    # every other model falls back to a plain "User:/Assistant:" transcript.
    oasst_models = {
        "OpenAssistant/oasst-sft-1-pythia-12b",
        "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5",
    }
    if model in oasst_models:
        return "", "<|prompter|>", "<|assistant|>", "<|endoftext|>"
    return "", "User: ", "Assistant: ", "\n"
def predict(
    inputs: str,
):
    """Generate an answer to *inputs* with the hosted OASST model.

    Args:
        inputs: the user's question/prompt text.

    Yields:
        str: the model's generated answer (a single yield; the generator
        shape is what Gradio expects from a streaming ``fn``).
    """
    model = "OpenAssistant/oasst-sft-1-pythia-12b"
    client = get_client(model)
    preprompt, user_name, assistant_name, sep = get_usernames(model)

    # Soft length hint appended to the prompt; it steers the model but is
    # not a hard limit (max_new_tokens below is the hard cap).
    limits = ",in max 200 words"
    total_inputs = preprompt + inputs + limits + sep + assistant_name.rstrip()

    # The original guarded this call with a membership test on the
    # hard-coded model name — always true, and if it ever became false the
    # generator would silently yield nothing. Call unconditionally instead.
    # (Also dropped: unused locals `past` and `partial_words`.)
    response = client.generate(
        total_inputs,
        typical_p=0.1,       # typical-decoding probability mass
        truncate=1000,       # truncate the prompt to the last 1000 tokens
        watermark=0,
        max_new_tokens=502,  # hard cap on the generated length
    )
    yield response.generated_text
# Wire the generator into a simple one-box-in / one-box-out Gradio UI.
g = gr.Interface(
    fn=predict,
    inputs=[
        gr.components.Textbox(lines=3, label="Hi, how can I help you?", placeholder=""),
    ],
    outputs=[
        # Fixed: the original used gr.inputs.Textbox here — the legacy
        # *inputs* namespace (removed in Gradio 4) is the wrong namespace
        # for an output component; use gr.components like the inputs above.
        gr.components.Textbox(
            lines=10,
            label="",
        )
    ],
)
g.queue(concurrency_count=1)  # serialize requests to the inference API
g.launch()