sotirios-slv committed
Commit 4b137c2
1 Parent(s): 73c11b0

Switched template

Files changed (2):
  1. app.py +100 -105
  2. requirements.txt +0 -1
app.py CHANGED
@@ -1,111 +1,106 @@
+from huggingface_hub import InferenceClient
 import gradio as gr
-from http import HTTPStatus
-import dashscope
-from dashscope import Generation
-from dashscope.api_entities.dashscope_response import Role
-from typing import List, Optional, Tuple, Dict
-from urllib.error import HTTPError
-
-default_system = "You are a helpful assistant."
-
-YOUR_API_TOKEN = os.getenv("YOUR_API_TOKEN")
-dashscope.api_key = YOUR_API_TOKEN
-
-History = List[Tuple[str, str]]
-Messages = List[Dict[str, str]]
-
-
-def clear_session() -> History:
-    return "", []
-
-
-def modify_system_session(system: str) -> str:
-    if system is None or len(system) == 0:
-        system = default_system
-    return system, system, []
-
-
-def history_to_messages(history: History, system: str) -> Messages:
-    messages = [{"role": Role.SYSTEM, "content": system}]
-    for h in history:
-        messages.append({"role": Role.USER, "content": h[0]})
-        messages.append({"role": Role.ASSISTANT, "content": h[1]})
-    return messages
-
-
-def messages_to_history(messages: Messages) -> Tuple[str, History]:
-    assert messages[0]["role"] == Role.SYSTEM
-    system = messages[0]["content"]
-    history = []
-    for q, r in zip(messages[1::2], messages[2::2]):
-        history.append([q["content"], r["content"]])
-    return system, history
-
-
-def model_chat(
-    query: Optional[str], history: Optional[History], system: str
-) -> Tuple[str, str, History]:
-    if query is None:
-        query = ""
-    if history is None:
-        history = []
-    messages = history_to_messages(history, system)
-    messages.append({"role": Role.USER, "content": query})
-    gen = Generation.call(
-        model="qwen1.5-72b-chat",
-        messages=messages,
-        result_format="message",
-        stream=True,
-    )
-    for response in gen:
-        if response.status_code == HTTPStatus.OK:
-            role = response.output.choices[0].message.role
-            response = response.output.choices[0].message.content
-            system, history = messages_to_history(
-                messages + [{"role": role, "content": response}]
-            )
-            yield "", history, system
-        else:
-            raise HTTPError(
-                "Request id: %s, Status code: %s, error code: %s, error message: %s"
-                % (
-                    response.request_id,
-                    response.status_code,
-                    response.code,
-                    response.message,
-                )
-            )
-
-
-with gr.Blocks() as demo:
-    gr.Markdown("""<center><font size=8>Qwen1.5-72B-Chat</center>""")
-    gr.Markdown(
-        """<center><font size=4>Qwen1.5-72B-Chat is the 72-billion parameter chat model of the Qwen series.</center>"""
-    )
-
-    with gr.Row():
-        with gr.Column(scale=3):
-            system_input = gr.Textbox(value=default_system, lines=1, label="System")
-        with gr.Column(scale=1):
-            modify_system = gr.Button("🛠️ Set system prompt and clear history.", scale=2)
-        system_state = gr.Textbox(value=default_system, visible=False)
-    chatbot = gr.Chatbot(label="Qwen1.5-72B-Chat")
-    textbox = gr.Textbox(lines=2, label="Input")
-
-    with gr.Row():
-        clear_history = gr.Button("🧹 Clear history")
-        sumbit = gr.Button("🚀 Send")
-
-    sumbit.click(
-        model_chat,
-        inputs=[textbox, chatbot, system_state],
-        outputs=[textbox, chatbot, system_input],
-    )
-    clear_history.click(fn=clear_session, inputs=[], outputs=[textbox, chatbot])
-    modify_system.click(
-        fn=modify_system_session,
-        inputs=[system_input],
-        outputs=[system_state, system_input, chatbot],
-    )
-
-demo.queue(api_open=False).launch(max_threads=10, height=800, share=False)
+
+client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.1")
+
+
+def format_prompt(message, history):
+    prompt = "<s>"
+    for user_prompt, bot_response in history:
+        prompt += f"[INST] {user_prompt} [/INST]"
+        prompt += f" {bot_response}</s> "
+    prompt += f"[INST] {message} [/INST]"
+    return prompt
+
+
+def generate(
+    prompt,
+    history,
+    temperature=0.9,
+    max_new_tokens=256,
+    top_p=0.95,
+    repetition_penalty=1.0,
+):
+    temperature = float(temperature)
+    if temperature < 1e-2:
+        temperature = 1e-2
+    top_p = float(top_p)
+
+    generate_kwargs = dict(
+        temperature=temperature,
+        max_new_tokens=max_new_tokens,
+        top_p=top_p,
+        repetition_penalty=repetition_penalty,
+        do_sample=True,
+        seed=42,
+    )
+
+    formatted_prompt = format_prompt(prompt, history)
+
+    stream = client.text_generation(
+        formatted_prompt,
+        **generate_kwargs,
+        stream=True,
+        details=True,
+        return_full_text=False,
+    )
+    output = ""
+
+    for response in stream:
+        output += response.token.text
+        yield output
+    return output
+
+
+additional_inputs = [
+    gr.Slider(
+        label="Temperature",
+        value=0.9,
+        minimum=0.0,
+        maximum=1.0,
+        step=0.05,
+        interactive=True,
+        info="Higher values produce more diverse outputs",
+    ),
+    gr.Slider(
+        label="Max new tokens",
+        value=256,
+        minimum=0,
+        maximum=1048,
+        step=64,
+        interactive=True,
+        info="The maximum number of new tokens",
+    ),
+    gr.Slider(
+        label="Top-p (nucleus sampling)",
+        value=0.90,
+        minimum=0.0,
+        maximum=1,
+        step=0.05,
+        interactive=True,
+        info="Higher values sample more low-probability tokens",
+    ),
+    gr.Slider(
+        label="Repetition penalty",
+        value=1.2,
+        minimum=1.0,
+        maximum=2.0,
+        step=0.05,
+        interactive=True,
+        info="Penalize repeated tokens",
+    ),
+]
+
+
+gr.ChatInterface(
+    fn=generate,
+    chatbot=gr.Chatbot(
+        show_label=False,
+        show_share_button=False,
+        show_copy_button=True,
+        likeable=True,
+        layout="panel",
+    ),
+    additional_inputs=additional_inputs,
+    title="""Mistral 7B""",
+).launch(show_api=False)
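The switched-in template drives Mistral through its bracketed instruction format: format_prompt wraps each user turn in [INST] ... [/INST] and closes each completed exchange with </s>. A minimal sketch of the string it builds (format_prompt is copied from the new app.py above; the one-turn history is invented for illustration):

def format_prompt(message, history):  # copied from the new app.py above
    prompt = "<s>"
    for user_prompt, bot_response in history:
        prompt += f"[INST] {user_prompt} [/INST]"
        prompt += f" {bot_response}</s> "
    prompt += f"[INST] {message} [/INST]"
    return prompt

history = [("Hi there", "Hello! How can I help?")]  # invented example turn
print(format_prompt("What does this Space do?", history))
# <s>[INST] Hi there [/INST] Hello! How can I help?</s> [INST] What does this Space do? [/INST]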
requirements.txt CHANGED
@@ -1 +0,0 @@
-dashscope
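For reference, the streaming pattern the new app.py builds on, as a standalone sketch: huggingface_hub's InferenceClient.text_generation with stream=True and details=True yields one chunk per generated token. The model ID is copied from the diff; the prompt string is invented, and network access (plus an HF token where the endpoint requires one) is assumed:

from huggingface_hub import InferenceClient

client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.1")

# Each streamed chunk exposes the newly generated token at chunk.token.text;
# the app accumulates these and re-yields the running string to Gradio.
for chunk in client.text_generation(
    "[INST] Say hello [/INST]",
    max_new_tokens=32,
    stream=True,
    details=True,
):
    print(chunk.token.text, end="", flush=True)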