onuri committed on
Commit
99a553b
1 Parent(s): 387b3b6

Create app1.py

Browse files
Files changed (1) hide show
  1. app1.py +60 -0
app1.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+
4
+ from text_generation import Client, InferenceAPIClient
5
+
6
+
7
def get_client(model: str):
    """Build a Hugging Face Inference API client for *model*.

    The token is read from the HF_TOKEN environment variable when present;
    requests time out after 100 seconds.
    """
    token = os.getenv("HF_TOKEN", None)
    return InferenceAPIClient(model, token=token, timeout=100)
9
+
10
+
11
def get_usernames(model: str):
    """Look up the prompt-format pieces used to frame a chat turn.

    Returns:
        (str, str, str, str): pre-prompt, username, bot name, separator
    """
    oasst_checkpoints = (
        "OpenAssistant/oasst-sft-1-pythia-12b",
        "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5",
    )
    if model in oasst_checkpoints:
        # OpenAssistant models use special chat-control tokens.
        return "", "<|prompter|>", "<|assistant|>", "<|endoftext|>"
    # Generic fallback: plain-text role labels separated by newlines.
    return "", "User: ", "Assistant: ", "\n"
19
+
20
+
21
def predict(
    inputs: str,
):
    """Generate a single reply to *inputs* using the OpenAssistant model.

    Implemented as a generator (single yield) so Gradio treats it as a
    streaming-capable handler.

    Args:
        inputs: the user's message text.

    Yields:
        str: the model's generated reply.

    Raises:
        ValueError: if the hard-coded model is not a supported checkpoint
            (previously this path crashed with a NameError on `iterator`).
    """
    model = "OpenAssistant/oasst-sft-1-pythia-12b"
    client = get_client(model)
    preprompt, _user_name, assistant_name, sep = get_usernames(model)

    past = []  # no chat history is kept between calls
    limits = ",in max 200 words"  # appended verbatim to cap answer length
    total_inputs = preprompt + "".join(past) + inputs + limits + sep + assistant_name.rstrip()

    # `model` is hard-coded above, so this branch always runs today; the
    # guard is kept in case the model becomes configurable.
    if model in ("OpenAssistant/oasst-sft-1-pythia-12b", "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5"):
        response = client.generate(
            total_inputs,
            typical_p=0.1,
            truncate=1000,
            watermark=0,
            max_new_tokens=502,
        )
        yield response.generated_text
    else:
        # Fail loudly instead of the original's NameError on an unbound name.
        raise ValueError(f"unsupported model: {model}")
45
+
46
# Wire the predict generator into a one-box Gradio UI and serve it.
g = gr.Interface(
    fn=predict,
    inputs=[
        gr.components.Textbox(lines=3, label="Hi, how can I help you?", placeholder=""),
    ],
    outputs=[
        # Bug fix: the original used gr.inputs.Textbox — the deprecated
        # *input*-component namespace — for an output component. Use
        # gr.components.Textbox, consistent with the input above.
        gr.components.Textbox(
            lines=10,
            label="",
        )
    ],
)
g.queue(concurrency_count=1)  # serialize requests to the remote inference API
g.launch()