changsr committed
Commit 5f65454
Parent: fec9c86

Update app.py

Files changed (1): app.py (+80 -34)
app.py CHANGED
@@ -1,56 +1,102 @@
 import gradio as gr
 from main import init,clip,answer
-# from huggingface_hub import InferenceClient
+from huggingface_hub import InferenceClient

 """
 For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
 """
-# client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
+client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")


-model,tokenizer = init()
+# Load two safety classifiers: one screens the user's prompt, the other screens the reply.
+model,tokenizer = init("attention_lstm_pre.ckpt")
+la_model,la_tokenizer = init("attention_lstm_last.ckpt")


-def respond(
-    message,
-    history: list[tuple[str, str]],
-    system_message,
-    max_tokens,
-    temperature,
-    top_p,
-):
-    res = answer(message,model,tokenizer)
-    if res[1]>res[0]:
-        return "unsafe" # unsafe
-    else:
-        return "safe" # safe
-    # messages = [{"role": "system", "content": system_message}]
-
-    # for val in history:
-    #     if val[0]:
-    #         messages.append({"role": "user", "content": val[0]})
-    #     if val[1]:
-    #         messages.append({"role": "assistant", "content": val[1]})
-
-    # messages.append({"role": "user", "content": message})
-
-    # response = ""
-
-    # for message in client.chat_completion(
-    #     messages,
-    #     max_tokens=max_tokens,
-    #     stream=True,
-    #     temperature=temperature,
-    #     top_p=top_p,
-    # ):
-    #     token = message.choices[0].delta.content
-
-    #     response += token
-    #     yield response
-
-"""
-For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
-"""
+def generate_response(messages, max_tokens, temperature, top_p):
+    # Stream the chat completion, yielding the reply as it grows.
+    response = ""
+    for message in client.chat_completion(
+        messages,
+        max_tokens=max_tokens,
+        stream=True,
+        temperature=temperature,
+        top_p=top_p,
+    ):
+        token = message.choices[0].delta.content
+        if token:  # the final streamed chunk may carry no content
+            response += token
+        yield response
+
+
+def collect_response(message, history, system_message, max_tokens, temperature, top_p):
+    # Build the message list.
+    messages = [{"role": "system", "content": system_message}]
+
+    # for val in history:
+    #     if val[0]:
+    #         messages.append({"role": "user", "content": val[0]})
+    #     if val[1]:
+    #         messages.append({"role": "assistant", "content": val[1]})
+
+    messages.append({"role": "user", "content": message})
+    # Collect the partial responses; the last yielded value is the complete reply.
+    full_response = ""
+    for partial_response in generate_response(messages, max_tokens, temperature, top_p):
+        full_response = partial_response
+    return full_response
+
+
+def respond(message, history, system_message, max_tokens, temperature, top_p):
+    # Screen the incoming prompt with the first classifier.
+    res = answer(message, model, tokenizer)
+    if res[1] > res[0]:
+        return "unsafe" # unsafe
+    else:
+        # Collect the complete response, then screen it with the second classifier.
+        full_response = collect_response(message, history, system_message, max_tokens, temperature, top_p)
+        ress = answer(full_response, la_model, la_tokenizer)
+        if ress[1] > ress[0]:
+            return "unsafe" # unsafe
+        else:
+            return full_response
+
+
 demo = gr.ChatInterface(
     respond,
     additional_inputs=[
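
The new generate_response follows the standard huggingface_hub streaming pattern. A minimal standalone sketch of that pattern, assuming huggingface_hub v0.22+ and a valid HF token in the environment (the model id is the one this commit uses):

from huggingface_hub import InferenceClient

client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Say hello in one sentence."},
]

# Each streamed chunk carries a token delta; accumulate the deltas into the reply.
reply = ""
for chunk in client.chat_completion(messages, max_tokens=64, stream=True):
    token = chunk.choices[0].delta.content
    if token:  # the final chunk may carry no content
        reply += token
print(reply)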
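The gating in respond assumes that answer() from this repo's main.py returns a score pair indexed as (safe, unsafe); that contract is inferred from the res[1] > res[0] comparison, not documented in the diff. A sketch of the check under that assumption:

# Hypothetical helper; answer() and the checkpoints come from this repo's main.py,
# and the (safe, unsafe) index order is an assumption inferred from the diff.
def is_unsafe(text, clf_model, clf_tokenizer):
    scores = answer(text, clf_model, clf_tokenizer)
    return scores[1] > scores[0]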
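The hunk is truncated at additional_inputs=[. For reference, the stock gr.ChatInterface template this app is derived from typically completes it along these lines; the widgets and default values below are assumptions, not the commit's actual contents:

demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        # Assumed defaults from the standard HF Spaces chat template, not this commit:
        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
)

if __name__ == "__main__":
    demo.launch()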