bytedancerneat commited on
Commit
c65e58e
·
verified ·
1 Parent(s): 1de832e

Update doubao_service.py

Browse files
Files changed (1) hide show
  1. doubao_service.py +167 -166
doubao_service.py CHANGED
@@ -1,166 +1,167 @@
1
- import time
2
- import requests
3
- import json
4
- from volcenginesdkarkruntime import Ark
5
- from util.config_util import read_config as config
6
- from util import logger
7
- import volcenginesdkcore
8
- import volcenginesdkark
9
- from volcenginesdkcore.rest import ApiException
10
- from util.logger_util import log_decorate
11
-
12
-
13
- class DouBaoService:
14
-
15
- def __init__(self, model_name):
16
- self.conf = config()[f"{model_name}ModelInfo"]
17
- self.client = self.init_client()
18
- self._complete_args = {}
19
-
20
-
21
-
22
- def init_client(self):
23
- base_url = self.conf["BASE_URL"]
24
- ak = self.conf["ACCESS_KEY"]
25
- sk = self.conf["SECRET_KEY"]
26
- # api_key = self.conf["API_KEY"]
27
- client = Ark(ak=ak, sk=sk, base_url=base_url)
28
- # client = Ark(ak=api_key, base_url=base_url)
29
- return client
30
-
31
- def get_api_key(self):
32
- configuration = volcenginesdkcore.Configuration()
33
- configuration.ak = self.conf["ACCESS_KEY"]
34
- configuration.sk = self.conf["SECRET_KEY"]
35
- configuration.region = "cn-beijing"
36
- endpoint_id = self.conf["ENDPOINT_ID"]
37
-
38
- volcenginesdkcore.Configuration.set_default(configuration)
39
-
40
- # use global default configuration
41
- api_instance = volcenginesdkark.ARKApi()
42
- get_api_key_request = volcenginesdkark.GetApiKeyRequest(
43
- duration_seconds=30 * 24 * 3600,
44
- resource_type="endpoint",
45
- resource_ids=[
46
- endpoint_id
47
- ],
48
- )
49
-
50
- try:
51
- resp = api_instance.get_api_key(get_api_key_request)
52
- return resp.api_key
53
- except ApiException as e:
54
- logger.error(f"Exception when calling api: {e}")
55
-
56
- def set_complete_args(self, temperature=None, top_p=None, max_token=None):
57
- if temperature is not None:
58
- self._complete_args["temperature"] = temperature
59
- if top_p is not None:
60
- self._complete_args["top_p"] = top_p
61
- if max_token is not None:
62
- self._complete_args["max_tokens"] = max_token
63
-
64
- def form_user_role(self, content):
65
- return {"role": "user", "content": content}
66
-
67
- def form_sys_role(self, content):
68
- return {"role": "system", "content": content}
69
-
70
- def form_assistant_role(self, content):
71
- return {"role": "assistant", "content": content}
72
-
73
- @property
74
- def complete_args(self):
75
- return {"temperature": 0.01, "top_p": 0.7}
76
-
77
- @log_decorate
78
- def chat_complete(self, messages):
79
-
80
- endpoint_id = self.conf["ENDPOINT_ID"]
81
- completion = self.client.chat.completions.create(
82
- model=endpoint_id,
83
- messages=messages,
84
- **self.complete_args
85
- )
86
- logger.info(f"complete doubao task, id: {completion.id}")
87
- return completion.choices[0].message.content
88
-
89
- def prd_to_keypoint(self, prd_content):
90
-
91
- role_desc = {"role": "system", "content": PRD2KP_SYS}
92
-
93
- messages = [
94
- role_desc,
95
- {"role": "user", "content": prd_content}
96
- ]
97
- return self.chat_complete(messages)
98
-
99
- def prd_to_cases(self, prd_content, case_language="Chinese"):
100
-
101
- role_desc = {"role": "system", "content": PRD_CASE_SYS[case_language]}
102
-
103
- messages = [
104
- role_desc,
105
- {"role": "user", "content": prd_content}
106
- ]
107
- return self.chat_complete(messages)
108
-
109
- def keypoint_to_case(self, key_points):
110
-
111
- role_desc = {"role": "system", "content": KP2CASE_SYS}
112
-
113
- messages = [
114
- role_desc,
115
- {"role": "user", "content": key_points}
116
- ]
117
- return self.chat_complete(messages)
118
-
119
- def case_merge_together(self, case_suits):
120
-
121
- role_desc = {"role": "system", "content": CASE_AGG_SYS}
122
-
123
- content_case_suits = ""
124
- for i, case_suit in enumerate(case_suits):
125
- case_suit_expr = json.dumps(case_suit, indent=4, ensure_ascii=False)
126
- content_case_suits += f"来自初级测试工程师{i + 1}的测试用例:\n```json\n{case_suit_expr}\n```\n"
127
- messages = [
128
- role_desc,
129
- {"role": "user", "content": content_case_suits}
130
- ]
131
- completion = self.chat_complete(messages)
132
- return completion
133
-
134
- def cycle_more_case(self, prd_content, case_language="Chinese"):
135
-
136
- role_desc = {"role": "system", "content": PRD_CASE_SYS[case_language]}
137
-
138
- messages = [
139
- role_desc,
140
- {"role": "user", "content": PRD_CASE_1[case_language] + prd_content + "\n" + PRD_CASE_2[case_language]}
141
- ]
142
-
143
- result = []
144
-
145
- for sys in MORE_CASE_PROMPT[case_language]:
146
- if sys:
147
- messages.append({"role": "user", "content": sys})
148
- reply = self.chat_complete(messages)
149
- result.append(reply)
150
- messages.append({"role": "assistant", "content": reply})
151
- time.sleep(10)
152
- return result
153
-
154
-
155
- if __name__ == "__main__":
156
- cli = DouBaoService("DouBao128Pro")
157
- # print(cli.get_api_key())
158
- # prd_content = requests.get("https://tosv.byted.org/obj/music-qa-bucket/xmind-test/de3ebc67410c43603034e21bfefa76a0.md").text
159
- # aa = cli.cycle_more_case(prd_content, "English")
160
- # print(aa)
161
-
162
- print(cli.chat_complete(messages=[
163
- {"role": "system", "content": "You are a helpful assistant."},
164
- {"role": "user", "content": "Introduce LLM shortly."},
165
- ]))
166
-
 
 
1
+ import time
2
+ import requests
3
+ import json
4
+ from volcenginesdkarkruntime import Ark
5
+ from util.config_util import read_config as config
6
+ from util import logger
7
+ import volcenginesdkcore
8
+ import volcenginesdkark
9
+ from volcenginesdkcore.rest import ApiException
10
+ from util.logger_util import log_decorate
11
+
12
+
13
+ class DouBaoService:
14
+
15
+ def __init__(self, model_name):
16
+ print(config())
17
+ self.conf = config()[f"{model_name}ModelInfo"]
18
+ self.client = self.init_client()
19
+ self._complete_args = {}
20
+
21
+
22
+
23
+ def init_client(self):
24
+ base_url = self.conf["BASE_URL"]
25
+ ak = self.conf["ACCESS_KEY"]
26
+ sk = self.conf["SECRET_KEY"]
27
+ # api_key = self.conf["API_KEY"]
28
+ client = Ark(ak=ak, sk=sk, base_url=base_url)
29
+ # client = Ark(ak=api_key, base_url=base_url)
30
+ return client
31
+
32
+ def get_api_key(self):
33
+ configuration = volcenginesdkcore.Configuration()
34
+ configuration.ak = self.conf["ACCESS_KEY"]
35
+ configuration.sk = self.conf["SECRET_KEY"]
36
+ configuration.region = "cn-beijing"
37
+ endpoint_id = self.conf["ENDPOINT_ID"]
38
+
39
+ volcenginesdkcore.Configuration.set_default(configuration)
40
+
41
+ # use global default configuration
42
+ api_instance = volcenginesdkark.ARKApi()
43
+ get_api_key_request = volcenginesdkark.GetApiKeyRequest(
44
+ duration_seconds=30 * 24 * 3600,
45
+ resource_type="endpoint",
46
+ resource_ids=[
47
+ endpoint_id
48
+ ],
49
+ )
50
+
51
+ try:
52
+ resp = api_instance.get_api_key(get_api_key_request)
53
+ return resp.api_key
54
+ except ApiException as e:
55
+ logger.error(f"Exception when calling api: {e}")
56
+
57
+ def set_complete_args(self, temperature=None, top_p=None, max_token=None):
58
+ if temperature is not None:
59
+ self._complete_args["temperature"] = temperature
60
+ if top_p is not None:
61
+ self._complete_args["top_p"] = top_p
62
+ if max_token is not None:
63
+ self._complete_args["max_tokens"] = max_token
64
+
65
+ def form_user_role(self, content):
66
+ return {"role": "user", "content": content}
67
+
68
+ def form_sys_role(self, content):
69
+ return {"role": "system", "content": content}
70
+
71
+ def form_assistant_role(self, content):
72
+ return {"role": "assistant", "content": content}
73
+
74
+ @property
75
+ def complete_args(self):
76
+ return {"temperature": 0.01, "top_p": 0.7}
77
+
78
+ @log_decorate
79
+ def chat_complete(self, messages):
80
+
81
+ endpoint_id = self.conf["ENDPOINT_ID"]
82
+ completion = self.client.chat.completions.create(
83
+ model=endpoint_id,
84
+ messages=messages,
85
+ **self.complete_args
86
+ )
87
+ logger.info(f"complete doubao task, id: {completion.id}")
88
+ return completion.choices[0].message.content
89
+
90
+ def prd_to_keypoint(self, prd_content):
91
+
92
+ role_desc = {"role": "system", "content": PRD2KP_SYS}
93
+
94
+ messages = [
95
+ role_desc,
96
+ {"role": "user", "content": prd_content}
97
+ ]
98
+ return self.chat_complete(messages)
99
+
100
+ def prd_to_cases(self, prd_content, case_language="Chinese"):
101
+
102
+ role_desc = {"role": "system", "content": PRD_CASE_SYS[case_language]}
103
+
104
+ messages = [
105
+ role_desc,
106
+ {"role": "user", "content": prd_content}
107
+ ]
108
+ return self.chat_complete(messages)
109
+
110
+ def keypoint_to_case(self, key_points):
111
+
112
+ role_desc = {"role": "system", "content": KP2CASE_SYS}
113
+
114
+ messages = [
115
+ role_desc,
116
+ {"role": "user", "content": key_points}
117
+ ]
118
+ return self.chat_complete(messages)
119
+
120
+ def case_merge_together(self, case_suits):
121
+
122
+ role_desc = {"role": "system", "content": CASE_AGG_SYS}
123
+
124
+ content_case_suits = ""
125
+ for i, case_suit in enumerate(case_suits):
126
+ case_suit_expr = json.dumps(case_suit, indent=4, ensure_ascii=False)
127
+ content_case_suits += f"来自初级测试工程师{i + 1}的测试用例:\n```json\n{case_suit_expr}\n```\n"
128
+ messages = [
129
+ role_desc,
130
+ {"role": "user", "content": content_case_suits}
131
+ ]
132
+ completion = self.chat_complete(messages)
133
+ return completion
134
+
135
+ def cycle_more_case(self, prd_content, case_language="Chinese"):
136
+
137
+ role_desc = {"role": "system", "content": PRD_CASE_SYS[case_language]}
138
+
139
+ messages = [
140
+ role_desc,
141
+ {"role": "user", "content": PRD_CASE_1[case_language] + prd_content + "\n" + PRD_CASE_2[case_language]}
142
+ ]
143
+
144
+ result = []
145
+
146
+ for sys in MORE_CASE_PROMPT[case_language]:
147
+ if sys:
148
+ messages.append({"role": "user", "content": sys})
149
+ reply = self.chat_complete(messages)
150
+ result.append(reply)
151
+ messages.append({"role": "assistant", "content": reply})
152
+ time.sleep(10)
153
+ return result
154
+
155
+
156
+ if __name__ == "__main__":
157
+ cli = DouBaoService("DouBao128Pro")
158
+ # print(cli.get_api_key())
159
+ # prd_content = requests.get("https://tosv.byted.org/obj/music-qa-bucket/xmind-test/de3ebc67410c43603034e21bfefa76a0.md").text
160
+ # aa = cli.cycle_more_case(prd_content, "English")
161
+ # print(aa)
162
+
163
+ print(cli.chat_complete(messages=[
164
+ {"role": "system", "content": "You are a helpful assistant."},
165
+ {"role": "user", "content": "Introduce LLM shortly."},
166
+ ]))
167
+