osanseviero commited on
Commit
5c0e14a
1 Parent(s): b472032

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -0
app.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import InferenceClient
2
+ import gradio as gr
3
+
4
+ API_URL = "https://api-inference.huggingface.co/models/"
5
+
6
+ client = InferenceClient(
7
+ "mistralai/Mistral-7B-Instruct-v0.1"
8
+ )
9
+
10
+
11
+ def format_prompt(message, history):
12
+ prompt = "<s>"
13
+ for user_prompt, bot_response in history:
14
+ prompt += f"[INST] {user_prompt} [/INST]"
15
+ prompt += f" {bot_response}</s> "
16
+ prompt += f"[INST] {message} [/INST]"
17
+ return prompt
18
+
19
+ def generate(
20
+ prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
21
+ ):
22
+ temperature = float(temperature)
23
+ if temperature < 1e-2:
24
+ temperature = 1e-2
25
+ top_p = float(top_p)
26
+
27
+ generate_kwargs = dict(
28
+ temperature=temperature,
29
+ max_new_tokens=max_new_tokens,
30
+ top_p=top_p,
31
+ repetition_penalty=repetition_penalty,
32
+ do_sample=True,
33
+ seed=42,
34
+ )
35
+
36
+ formatted_prompt = format_prompt(prompt, history)
37
+
38
+ stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
39
+ output = ""
40
+
41
+ for response in stream:
42
+ output += response.token.text
43
+ yield output
44
+ return output
45
+
46
+
47
+ additional_inputs=[
48
+ gr.Slider(
49
+ label="Temperature",
50
+ value=0.9,
51
+ minimum=0.0,
52
+ maximum=1.0,
53
+ step=0.05,
54
+ interactive=True,
55
+ info="Higher values produce more diverse outputs",
56
+ ),
57
+ gr.Slider(
58
+ label="Max new tokens",
59
+ value=256,
60
+ minimum=0,
61
+ maximum=8192,
62
+ step=64,
63
+ interactive=True,
64
+ info="The maximum numbers of new tokens",
65
+ ),
66
+ gr.Slider(
67
+ label="Top-p (nucleus sampling)",
68
+ value=0.90,
69
+ minimum=0.0,
70
+ maximum=1,
71
+ step=0.05,
72
+ interactive=True,
73
+ info="Higher values sample more low-probability tokens",
74
+ ),
75
+ gr.Slider(
76
+ label="Repetition penalty",
77
+ value=1.2,
78
+ minimum=1.0,
79
+ maximum=2.0,
80
+ step=0.05,
81
+ interactive=True,
82
+ info="Penalize repeated tokens",
83
+ )
84
+ ]
85
+
86
+ with gr.Blocks() as demo:
87
+ gr.ChatInterface(
88
+ generate,
89
+ additional_inputs=additional_inputs,
90
+ )
91
+
92
+ demo.queue().launch(debug=True)