alfonsovelp committed on
Commit
9e87282
•
1 Parent(s): 8fcea5a

Update app.py

Files changed (1)
  1. app.py +105 -1
app.py CHANGED
@@ -1,3 +1,107 @@
  import gradio as gr
 
- gr.load("mistralai/Mixtral-8x22B-Instruct-v0.1").launch()
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+ from threading import Thread
+ import torch
+
+ model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+ model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
+
+
+ def format_prompt(message, history):
+     prompt = "<s>"
+     for user_prompt, bot_response in history:
+         prompt += f"[INST] {user_prompt} [/INST]"
+         prompt += f" {bot_response}</s> "
+     prompt += f"[INST] {message} [/INST]"
+     return prompt
+
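+ # Illustration with hypothetical values: history=[("Hi", "Hello!")] and
+ # message="How are you?" produce the prompt
+ # "<s>[INST] Hi [/INST] Hello!</s> [INST] How are you? [/INST]"
+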
+ def generate(
+     prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
+ ):
+     temperature = float(temperature)
+     if temperature < 1e-2:
+         temperature = 1e-2
+     top_p = float(top_p)
+
+     generate_kwargs = dict(
+         temperature=temperature,
+         max_new_tokens=max_new_tokens,
+         top_p=top_p,
+         repetition_penalty=repetition_penalty,
+         do_sample=True,
+     )
+
+     formatted_prompt = format_prompt(prompt, history)
+
+     # format_prompt already prepends <s>, so skip the tokenizer's special tokens
+     inputs = tokenizer(formatted_prompt, return_tensors="pt", add_special_tokens=False).to("cuda")
+
+     # model.generate has no stream=True flag; stream tokens through a
+     # TextIteratorStreamer fed by a background generation thread instead
+     torch.manual_seed(42)  # fixed seed, preserved from the original kwargs
+     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+     thread = Thread(target=model.generate, kwargs=dict(**inputs, streamer=streamer, **generate_kwargs))
+     thread.start()
+
+     output = ""
+     for new_text in streamer:
+         output += new_text
+         yield output
+
+
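+ # Alternative sketch, not used above: the original call's keyword arguments
+ # (stream=True, details=True, return_full_text=False, seed=42) come from
+ # huggingface_hub's InferenceClient.text_generation, which streams from the
+ # Inference API instead of loading the weights locally:
+ #
+ #     client = InferenceClient(model_id)
+ #     stream = client.text_generation(formatted_prompt, stream=True,
+ #         details=True, return_full_text=False, seed=42, **generate_kwargs)
+ #     for response in stream:
+ #         yield response.token.text
+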
+ additional_inputs = [
+     gr.Slider(
+         label="Temperature",
+         value=0.9,
+         minimum=0.0,
+         maximum=1.0,
+         step=0.05,
+         interactive=True,
+         info="Higher values produce more diverse outputs",
+     ),
+     gr.Slider(
+         label="Max new tokens",
+         value=256,
+         minimum=0,
+         maximum=1024,
+         step=64,
+         interactive=True,
+         info="The maximum number of new tokens",
+     ),
+     gr.Slider(
+         label="Top-p (nucleus sampling)",
+         value=0.90,
+         minimum=0.0,
+         maximum=1.0,
+         step=0.05,
+         interactive=True,
+         info="Higher values sample more low-probability tokens",
+     ),
+     gr.Slider(
+         label="Repetition penalty",
+         value=1.2,
+         minimum=1.0,
+         maximum=2.0,
+         step=0.05,
+         interactive=True,
+         info="Penalize repeated tokens",
+     )
+ ]
+
+ css = """
+ #mkd {
+     height: 500px;
+     overflow: auto;
+     border: 1px solid #ccc;
+ }
+ """
+
+ with gr.Blocks(css=css) as demo:
+     gr.HTML("<h1><center>Mixtral 8x7B Instruct</center></h1>")
+     gr.HTML("<h3><center>In this demo, you can chat with the <a href='https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1'>Mixtral-8x7B-Instruct</a> model. 💬</center></h3>")
+     gr.HTML("<h3><center>Learn more about the model <a href='https://huggingface.co/docs/transformers/main/model_doc/mixtral'>here</a>. 📚</center></h3>")
+     gr.ChatInterface(
+         generate,
+         additional_inputs=additional_inputs,
+         examples=[["What is the secret to life?"], ["Write me a recipe for pancakes."]]
+     )
+
+ demo.queue(concurrency_count=75, max_size=100).launch(debug=True)
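
Note: this version runs Mixtral-8x7B locally rather than through gr.load, so the Space needs a CUDA GPU with enough memory for the model (device_map="auto" also requires accelerate). A minimal smoke test, assuming gradio, transformers, accelerate, and torch are installed:

    python app.py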