File size: 1,524 Bytes
e397ecb
8c958f2
e397ecb
1171f16
 
e397ecb
 
9347168
e397ecb
 
 
 
 
 
 
aafd677
e397ecb
 
 
 
 
 
 
deb823d
e397ecb
 
 
 
bef1ce5
054e2f6
e397ecb
 
 
aafd677
 
e397ecb
aafd677
e397ecb
aafd677
254dc68
e397ecb
 
 
aafd677
e397ecb
aafd677
 
 
e397ecb
b924a63
aafd677
 
 
387b3b6
0d9af56
e397ecb
99e9a67
aafd677
 
227623e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import os
import gradio as gr

from text_generation import Client, InferenceAPIClient


def get_client(model: str):
    """Build an Inference API client for *model*.

    Reads the ``HF_TOKEN`` environment variable for authentication
    (``None`` means anonymous access) and uses a generous 100 s timeout
    because hosted text-generation endpoints can be slow to cold-start.

    Args:
        model: Hugging Face model id, e.g. ``"OpenAssistant/oasst-sft-1-pythia-12b"``.

    Returns:
        InferenceAPIClient: a configured client for the hosted inference API.
    """
    return InferenceAPIClient(model, token=os.getenv("HF_TOKEN", None), timeout=100)


def get_usernames(model: str):
    """Look up the prompt-formatting tokens for *model*.

    OpenAssistant pythia checkpoints use special chat-control tokens;
    every other model falls back to a plain ``User:``/``Assistant:``
    transcript separated by newlines.

    Returns:
        (str, str, str, str): pre-prompt, username, bot name, separator
    """
    oasst_models = {
        "OpenAssistant/oasst-sft-1-pythia-12b",
        "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5",
    }
    if model in oasst_models:
        return "", "<|prompter|>", "<|assistant|>", "<|endoftext|>"
    return "", "User: ", "Assistant: ", "\n"


def predict(
    inputs: str,
):
    """Generate a single assistant reply for *inputs* via the hosted API.

    Formats the user text with the model's chat-control tokens, asks for
    a bounded-length answer, and yields the generated text (a generator
    so Gradio can treat it as a streaming-capable handler).

    Args:
        inputs: raw user message from the textbox.

    Yields:
        str: the model's generated reply.

    Raises:
        ValueError: if the configured model has no generation branch below.
    """
    model = "OpenAssistant/oasst-sft-1-pythia-12b"
    client = get_client(model)
    preprompt, user_name, assistant_name, sep = get_usernames(model)

    # No conversation history is kept in this single-turn demo.
    past = []
    # Appended to the user text to nudge the model toward short answers.
    limits = ",in max 200 words"
    total_inputs = preprompt + "".join(past) + inputs + limits + sep + assistant_name.rstrip()

    if model not in ("OpenAssistant/oasst-sft-1-pythia-12b", "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5"):
        # Previously this fell through and raised NameError on `iterator`;
        # fail fast with a clear message instead.
        raise ValueError(f"Unsupported model: {model!r}")

    response = client.generate(
        total_inputs,
        typical_p=0.1,
        truncate=1000,
        watermark=0,
        max_new_tokens=502,
    )

    yield response.generated_text

# Wire the predict generator into a minimal one-box chat UI.
g = gr.Interface(
    fn=predict,
    inputs=[
        gr.components.Textbox(lines=3, label="Hi, how can I help you?", placeholder=""),
    ],
    outputs=[
        # Use the unified `gr.components` namespace for the output too;
        # `gr.inputs.Textbox` is the deprecated *input*-only namespace and
        # was the wrong choice for an output component.
        gr.components.Textbox(
            lines=10,
            label="",
        )
    ],
)
# Serialize requests: the hosted inference endpoint handles one generation at a time.
g.queue(concurrency_count=1)
g.launch()