dh-mc committed
Commit 4aab576 · Parent(s): 1f778a0

added original code from: https://github.com/mgjinnn/TurtleSoupBaseline

TurtleSoupBaseline/openai_api_server.py ADDED
@@ -0,0 +1,549 @@
+ import os
+ import time
+ from asyncio.log import logger
+
+ import uvicorn
+ import gc
+ import json
+ import torch
+
+ from vllm import SamplingParams, AsyncEngineArgs, AsyncLLMEngine
+ from fastapi import FastAPI, HTTPException, Response
+ from fastapi.middleware.cors import CORSMiddleware
+ from contextlib import asynccontextmanager
+ from typing import List, Literal, Optional, Union
+ from pydantic import BaseModel, Field
+ from transformers import AutoTokenizer, LogitsProcessor
+ from sse_starlette.sse import EventSourceResponse
+
+ EventSourceResponse.DEFAULT_PING_INTERVAL = 1000
+ MODEL_PATH = 'THUDM/glm-4-9b-chat'
+ MAX_MODEL_LENGTH = 8192
+
+
+ @asynccontextmanager
+ async def lifespan(app: FastAPI):
+     yield
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
+         torch.cuda.ipc_collect()
+
+
+ app = FastAPI(lifespan=lifespan)
+
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+
+ class ModelCard(BaseModel):
+     id: str
+     object: str = "model"
+     created: int = Field(default_factory=lambda: int(time.time()))
+     owned_by: str = "owner"
+     root: Optional[str] = None
+     parent: Optional[str] = None
+     permission: Optional[list] = None
+
+
+ class ModelList(BaseModel):
+     object: str = "list"
+     data: List[ModelCard] = []
+
+
+ class FunctionCallResponse(BaseModel):
+     name: Optional[str] = None
+     arguments: Optional[str] = None
+
+
+ class ChatMessage(BaseModel):
+     role: Literal["user", "assistant", "system", "tool"]
+     content: str = None
+     name: Optional[str] = None
+     function_call: Optional[FunctionCallResponse] = None
+
+
+ class DeltaMessage(BaseModel):
+     role: Optional[Literal["user", "assistant", "system"]] = None
+     content: Optional[str] = None
+     function_call: Optional[FunctionCallResponse] = None
+
+
+ class EmbeddingRequest(BaseModel):
+     input: Union[List[str], str]
+     model: str
+
+
+ class CompletionUsage(BaseModel):
+     prompt_tokens: int
+     completion_tokens: int
+     total_tokens: int
+
+
+ class EmbeddingResponse(BaseModel):
+     data: list
+     model: str
+     object: str
+     usage: CompletionUsage
+
+
+ class UsageInfo(BaseModel):
+     prompt_tokens: int = 0
+     total_tokens: int = 0
+     completion_tokens: Optional[int] = 0
+
+
+ class ChatCompletionRequest(BaseModel):
+     model: str
+     messages: List[ChatMessage]
+     temperature: Optional[float] = 0.8
+     top_p: Optional[float] = 0.8
+     max_tokens: Optional[int] = None
+     stream: Optional[bool] = False
+     tools: Optional[Union[dict, List[dict]]] = None
+     tool_choice: Optional[Union[str, dict]] = "None"
+     repetition_penalty: Optional[float] = 1.1
+
+
+ class ChatCompletionResponseChoice(BaseModel):
+     index: int
+     message: ChatMessage
+     finish_reason: Literal["stop", "length", "function_call"]
+
+
+ class ChatCompletionResponseStreamChoice(BaseModel):
+     delta: DeltaMessage
+     finish_reason: Optional[Literal["stop", "length", "function_call"]]
+     index: int
+
+
+ class ChatCompletionResponse(BaseModel):
+     model: str
+     id: str
+     object: Literal["chat.completion", "chat.completion.chunk"]
+     choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
+     created: Optional[int] = Field(default_factory=lambda: int(time.time()))
+     usage: Optional[UsageInfo] = None
+
+
+ class InvalidScoreLogitsProcessor(LogitsProcessor):
+     def __call__(
+             self, input_ids: torch.LongTensor, scores: torch.FloatTensor
+     ) -> torch.FloatTensor:
+         if torch.isnan(scores).any() or torch.isinf(scores).any():
+             scores.zero_()
+             scores[..., 5] = 5e4
+         return scores
+
+
+ def process_response(output: str, use_tool: bool = False) -> Union[str, dict]:
+     content = ""
+     for response in output.split("<|assistant|>"):
+         if "\n" in response:
+             metadata, content = response.split("\n", maxsplit=1)
+         else:
+             metadata, content = "", response
+         if not metadata.strip():
+             content = content.strip()
+         else:
+             if use_tool:
+                 parameters = eval(content.strip())
+                 content = {
+                     "name": metadata.strip(),
+                     "arguments": json.dumps(parameters, ensure_ascii=False)
+                 }
+             else:
+                 content = {
+                     "name": metadata.strip(),
+                     "content": content
+                 }
+     return content
+
+
+ @torch.inference_mode()
+ async def generate_stream_glm4(params):
+     messages = params["messages"]
+     tools = params["tools"]
+     tool_choice = params["tool_choice"]
+     temperature = float(params.get("temperature", 1.0))
+     repetition_penalty = float(params.get("repetition_penalty", 1.0))
+     top_p = float(params.get("top_p", 1.0))
+     max_new_tokens = int(params.get("max_tokens", 8192))
+     messages = process_messages(messages, tools=tools, tool_choice=tool_choice)
+     inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
+     params_dict = {
+         "n": 1,
+         "best_of": 1,
+         "presence_penalty": 1.0,
+         "frequency_penalty": 0.0,
+         "temperature": temperature,
+         "top_p": top_p,
+         "top_k": -1,
+         "repetition_penalty": repetition_penalty,
+         "use_beam_search": False,
+         "length_penalty": 1,
+         "early_stopping": False,
+         "stop_token_ids": [151329, 151336, 151338],
+         "ignore_eos": False,
+         "max_tokens": max_new_tokens,
+         "logprobs": None,
+         "prompt_logprobs": None,
+         "skip_special_tokens": True,
+     }
+     sampling_params = SamplingParams(**params_dict)
+     async for output in engine.generate(inputs=inputs, sampling_params=sampling_params, request_id="glm-4-9b"):
+         output_len = len(output.outputs[0].token_ids)
+         input_len = len(output.prompt_token_ids)
+         ret = {
+             "text": output.outputs[0].text,
+             "usage": {
+                 "prompt_tokens": input_len,
+                 "completion_tokens": output_len,
+                 "total_tokens": output_len + input_len
+             },
+             "finish_reason": output.outputs[0].finish_reason,
+         }
+         yield ret
+     gc.collect()
+     torch.cuda.empty_cache()
+
+
+ def process_messages(messages, tools=None, tool_choice="none"):
+     _messages = messages
+     messages = []
+     msg_has_sys = False
+
+     def filter_tools(tool_choice, tools):
+         function_name = tool_choice.get('function', {}).get('name', None)
+         if not function_name:
+             return []
+         filtered_tools = [
+             tool for tool in tools
+             if tool.get('function', {}).get('name') == function_name
+         ]
+         return filtered_tools
+
+     if tool_choice != "none":
+         if isinstance(tool_choice, dict):
+             tools = filter_tools(tool_choice, tools)
+         if tools:
+             messages.append(
+                 {
+                     "role": "system",
+                     "content": None,
+                     "tools": tools
+                 }
+             )
+             msg_has_sys = True
+
+     # add to metadata
+     if isinstance(tool_choice, dict) and tools:
+         messages.append(
+             {
+                 "role": "assistant",
+                 "metadata": tool_choice["function"]["name"],
+                 "content": ""
+             }
+         )
+
+     for m in _messages:
+         role, content, func_call = m.role, m.content, m.function_call
+         if role == "function":
+             messages.append(
+                 {
+                     "role": "observation",
+                     "content": content
+                 }
+             )
+         elif role == "assistant" and func_call is not None:
+             for response in content.split("<|assistant|>"):
+                 if "\n" in response:
+                     metadata, sub_content = response.split("\n", maxsplit=1)
+                 else:
+                     metadata, sub_content = "", response
+                 messages.append(
+                     {
+                         "role": role,
+                         "metadata": metadata,
+                         "content": sub_content.strip()
+                     }
+                 )
+         else:
+             if role == "system" and msg_has_sys:
+                 msg_has_sys = False
+                 continue
+             messages.append({"role": role, "content": content})
+
+     return messages
+
+
+ @app.get("/health")
+ async def health() -> Response:
+     """Health check."""
+     return Response(status_code=200)
+
+
+ @app.get("/v1/models", response_model=ModelList)
+ async def list_models():
+     model_card = ModelCard(id="glm-4")
+     return ModelList(data=[model_card])
+
+
+ @app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
+ async def create_chat_completion(request: ChatCompletionRequest):
+     if len(request.messages) < 1 or request.messages[-1].role == "assistant":
+         raise HTTPException(status_code=400, detail="Invalid request")
+
+     gen_params = dict(
+         messages=request.messages,
+         temperature=request.temperature,
+         top_p=request.top_p,
+         max_tokens=request.max_tokens or 1024,
+         echo=False,
+         stream=request.stream,
+         repetition_penalty=request.repetition_penalty,
+         tools=request.tools,
+         tool_choice=request.tool_choice,
+     )
+     logger.debug(f"==== request ====\n{gen_params}")
+
+     if request.stream:
+         predict_stream_generator = predict_stream(request.model, gen_params)
+         output = await anext(predict_stream_generator)
+         if output:
+             return EventSourceResponse(predict_stream_generator, media_type="text/event-stream")
+         logger.debug(f"First result output:\n{output}")
+
+         function_call = None
+         if output and request.tools:
+             try:
+                 function_call = process_response(output, use_tool=True)
+             except:
+                 logger.warning("Failed to parse tool call")
+
+         # CallFunction
+         if isinstance(function_call, dict):
+             function_call = FunctionCallResponse(**function_call)
+             tool_response = ""
+             if not gen_params.get("messages"):
+                 gen_params["messages"] = []
+             gen_params["messages"].append(ChatMessage(role="assistant", content=output))
+             gen_params["messages"].append(ChatMessage(role="tool", name=function_call.name, content=tool_response))
+             generate = predict(request.model, gen_params)
+             return EventSourceResponse(generate, media_type="text/event-stream")
+         else:
+             generate = parse_output_text(request.model, output)
+             return EventSourceResponse(generate, media_type="text/event-stream")
+
+     response = ""
+     async for response in generate_stream_glm4(gen_params):
+         pass
+
+     if response["text"].startswith("\n"):
+         response["text"] = response["text"][1:]
+     response["text"] = response["text"].strip()
+
+     usage = UsageInfo()
+     function_call, finish_reason = None, "stop"
+     if request.tools:
+         try:
+             function_call = process_response(response["text"], use_tool=True)
+         except:
+             logger.warning(
+                 "Failed to parse tool call, maybe the response is not a function call(such as cogview drawing) or have been answered.")
+
+     if isinstance(function_call, dict):
+         finish_reason = "function_call"
+         function_call = FunctionCallResponse(**function_call)
+
+     message = ChatMessage(
+         role="assistant",
+         content=response["text"],
+         function_call=function_call if isinstance(function_call, FunctionCallResponse) else None,
+     )
+
+     logger.debug(f"==== message ====\n{message}")
+
+     choice_data = ChatCompletionResponseChoice(
+         index=0,
+         message=message,
+         finish_reason=finish_reason,
+     )
+     task_usage = UsageInfo.model_validate(response["usage"])
+     for usage_key, usage_value in task_usage.model_dump().items():
+         setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
+
+     return ChatCompletionResponse(
+         model=request.model,
+         id="",  # for open_source model, id is empty
+         choices=[choice_data],
+         object="chat.completion",
+         usage=usage
+     )
+
+
+ async def predict(model_id: str, params: dict):
+     choice_data = ChatCompletionResponseStreamChoice(
+         index=0,
+         delta=DeltaMessage(role="assistant"),
+         finish_reason=None
+     )
+     chunk = ChatCompletionResponse(model=model_id, id="", choices=[choice_data], object="chat.completion.chunk")
+     yield "{}".format(chunk.model_dump_json(exclude_unset=True))
+
+     previous_text = ""
+     async for new_response in generate_stream_glm4(params):
+         decoded_unicode = new_response["text"]
+         delta_text = decoded_unicode[len(previous_text):]
+         previous_text = decoded_unicode
+
+         finish_reason = new_response["finish_reason"]
+         if len(delta_text) == 0 and finish_reason != "function_call":
+             continue
+
+         function_call = None
+         if finish_reason == "function_call":
+             try:
+                 function_call = process_response(decoded_unicode, use_tool=True)
+             except:
+                 logger.warning(
+                     "Failed to parse tool call, maybe the response is not a tool call or have been answered.")
+
+         if isinstance(function_call, dict):
+             function_call = FunctionCallResponse(**function_call)
+
+         delta = DeltaMessage(
+             content=delta_text,
+             role="assistant",
+             function_call=function_call if isinstance(function_call, FunctionCallResponse) else None,
+         )
+
+         choice_data = ChatCompletionResponseStreamChoice(
+             index=0,
+             delta=delta,
+             finish_reason=finish_reason
+         )
+         chunk = ChatCompletionResponse(
+             model=model_id,
+             id="",
+             choices=[choice_data],
+             object="chat.completion.chunk"
+         )
+         yield "{}".format(chunk.model_dump_json(exclude_unset=True))
+
+     choice_data = ChatCompletionResponseStreamChoice(
+         index=0,
+         delta=DeltaMessage(),
+         finish_reason="stop"
+     )
+     chunk = ChatCompletionResponse(
+         model=model_id,
+         id="",
+         choices=[choice_data],
+         object="chat.completion.chunk"
+     )
+     yield "{}".format(chunk.model_dump_json(exclude_unset=True))
+     yield '[DONE]'
+
+
+ async def predict_stream(model_id, gen_params):
+     output = ""
+     is_function_call = False
+     has_send_first_chunk = False
+     async for new_response in generate_stream_glm4(gen_params):
+         decoded_unicode = new_response["text"]
+         delta_text = decoded_unicode[len(output):]
+         output = decoded_unicode
+
+         if not is_function_call and len(output) > 7:
+             is_function_call = output and 'get_' in output
+             if is_function_call:
+                 continue
+
+         finish_reason = new_response["finish_reason"]
+         if not has_send_first_chunk:
+             message = DeltaMessage(
+                 content="",
+                 role="assistant",
+                 function_call=None,
+             )
+             choice_data = ChatCompletionResponseStreamChoice(
+                 index=0,
+                 delta=message,
+                 finish_reason=finish_reason
+             )
+             chunk = ChatCompletionResponse(
+                 model=model_id,
+                 id="",
+                 choices=[choice_data],
+                 created=int(time.time()),
+                 object="chat.completion.chunk"
+             )
+             yield "{}".format(chunk.model_dump_json(exclude_unset=True))
+
+         send_msg = delta_text if has_send_first_chunk else output
+         has_send_first_chunk = True
+         message = DeltaMessage(
+             content=send_msg,
+             role="assistant",
+             function_call=None,
+         )
+         choice_data = ChatCompletionResponseStreamChoice(
+             index=0,
+             delta=message,
+             finish_reason=finish_reason
+         )
+         chunk = ChatCompletionResponse(
+             model=model_id,
+             id="",
+             choices=[choice_data],
+             created=int(time.time()),
+             object="chat.completion.chunk"
+         )
+         yield "{}".format(chunk.model_dump_json(exclude_unset=True))
+
+     if is_function_call:
+         yield output
+     else:
+         yield '[DONE]'
+
+
+ async def parse_output_text(model_id: str, value: str):
+     choice_data = ChatCompletionResponseStreamChoice(
+         index=0,
+         delta=DeltaMessage(role="assistant", content=value),
+         finish_reason=None
+     )
+     chunk = ChatCompletionResponse(model=model_id, id="", choices=[choice_data], object="chat.completion.chunk")
+     yield "{}".format(chunk.model_dump_json(exclude_unset=True))
+     choice_data = ChatCompletionResponseStreamChoice(
+         index=0,
+         delta=DeltaMessage(),
+         finish_reason="stop"
+     )
+     chunk = ChatCompletionResponse(model=model_id, id="", choices=[choice_data], object="chat.completion.chunk")
+     yield "{}".format(chunk.model_dump_json(exclude_unset=True))
+     yield '[DONE]'
+
+
+ if __name__ == "__main__":
+     tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
+     engine_args = AsyncEngineArgs(
+         model=MODEL_PATH,
+         tokenizer=MODEL_PATH,
+         tensor_parallel_size=1,
+         dtype="bfloat16",
+         trust_remote_code=True,
+         gpu_memory_utilization=0.9,
+         enforce_eager=True,
+         worker_use_ray=True,
+         engine_use_ray=False,
+         disable_log_requests=True,
+         max_model_len=MAX_MODEL_LENGTH,
+     )
+     engine = AsyncLLMEngine.from_engine_args(engine_args)
+     uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
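
For reference, here is a minimal client-side sanity check for the server above. It is not part of the original repository; it assumes the server is already running on localhost:8000 and reuses the placeholder "EMPTY" API key and the "glm-4" model id that `/v1/models` exposes.

```python
# Quick smoke test for the OpenAI-compatible endpoints defined above
# (a sketch, not part of the original baseline).
from openai import OpenAI

client = OpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1/")

# /v1/models should list the single "glm-4" model card.
print([model.id for model in client.models.list().data])

# A short non-streaming request against /v1/chat/completions.
response = client.chat.completions.create(
    model="glm-4",
    messages=[{"role": "user", "content": "你好"}],
    max_tokens=32,
    temperature=0.1,
)
print(response.choices[0].message.content)
```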
TurtleSoupBaseline/process_transform.py ADDED
@@ -0,0 +1,23 @@
+ import pandas as pd
+
+ data = pd.read_csv('predict_result.csv', encoding='utf8')
+
+
+ # Apply a limited normalization: map answers that are worded differently but mean
+ # the same thing onto a canonical form. The rules used must be clearly documented.
+ def trans(ans):
+     res = ans
+     if len(ans) < 25:
+         if "是的。" in ans:
+             res = "是"
+         if "问法错误。" in ans:
+             res = "问法错误"
+         if "回答正确" in ans:
+             res = "回答正确"
+         if "不重要。" in ans:
+             res = "不重要"
+         if "不是。" in ans:
+             res = "不是"
+     return res
+
+
+ data['answer'] = data['answer'].apply(lambda x: trans(x))
+
+ print(f"label acc is: {len(data[data['label'] == data['answer']]) / len(data)}")
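
To illustrate what the normalization above does, here is a small hypothetical example (the strings are made up, not taken from the competition data, and it assumes the `trans` function from process_transform.py is in scope): short answers containing one of the canonical phrases collapse to that phrase, while answers of 25 or more characters are left untouched.

```python
# Hypothetical inputs, assuming trans() from process_transform.py is in scope.
print(trans("是的。"))                          # -> "是"
print(trans("不是。他并不知情。"))               # -> "不是"
print(trans("问法错误。请提出一个封闭式问题。"))  # -> "问法错误"

long_answer = "这是一个明显超过二十五个字符的冗长回答,因此不会被做任何转换,而是保持原样返回。"
print(trans(long_answer))                       # unchanged, because len(ans) >= 25
```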
TurtleSoupBaseline/readme.md ADDED
@@ -0,0 +1,19 @@
+ # Baseline notes
+ This baseline follows the official GLM-4 repository: https://github.com/THUDM/GLM-4.git
+
+ 1. Install the environment: `pip install -r requirements.txt`. For the detailed requirements, see https://github.com/THUDM/GLM-4/blob/main/basic_demo/README.md
+
+ 2. Start the server:
+ ```shell
+ python openai_api_server.py
+ ```
+ 3. Hardware used: a single 24 GB GPU.
+
+ 4. Run prediction:
+ ```shell
+ python test_re.py
+ ```
+
+ 5. After inference finishes, you may use process_transform.py to apply a limited normalization. It is restricted to answers that, because of unstable model output, are worded differently but mean the same thing, e.g. "是的。" may be converted to "是。".
+
+ 6. The baseline reaches roughly 64.7% accuracy on test set A.
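
For reference, the expected input format is not spelled out in the readme. Inferred from test_re.py and process_transform.py in this commit: the prediction step reads a GBK-encoded test_a.csv with puzzle (汤面), truth (汤底) and text (player question) columns, and the scoring step compares the generated answer column against a label column. A hypothetical one-row smoke-test file could be written like this:

```python
# Hypothetical miniature test_a.csv; the column names come from test_re.py and
# process_transform.py, the example values are made up.
import pandas as pd

rows = [{
    "puzzle": "示例汤面",      # the riddle surface shown to the player
    "truth": "示例汤底",       # the hidden truth behind the riddle
    "text": "他是自愿的吗?",  # the player's closed (yes/no) question
    "label": "是",             # reference answer used for scoring
}]
pd.DataFrame(rows).to_csv("test_a.csv", index=False, encoding="gbk")
```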
TurtleSoupBaseline/requirements.txt ADDED
@@ -0,0 +1,24 @@
+ # use vllm
+ # vllm>=0.4.3
+
+ torch>=2.3.0
+ torchvision>=0.18.0
+ transformers==4.40.0
+ huggingface-hub>=0.23.1
+ sentencepiece>=0.2.0
+ pydantic>=2.7.1
+ timm>=0.9.16
+ tiktoken>=0.7.0
+ accelerate>=0.30.1
+ sentence_transformers>=2.7.0
+
+ # web demo
+ gradio>=4.33.0
+
+ # openai demo
+ openai>=1.31.1
+ einops>=0.7.0
+ sse-starlette>=2.1.0
+
+ # INT4
+ bitsandbytes>=0.43.1
TurtleSoupBaseline/test_re.py ADDED
@@ -0,0 +1,57 @@
+ import time
+
+ import pandas as pd
+ from openai import OpenAI
+
+ base_url = "http://localhost:8000/v1/"
+ client = OpenAI(api_key="EMPTY", base_url=base_url)
+
+ test_a = pd.read_csv('test_a.csv', encoding='gbk')
+
+
+ def simple_chat(sys_content, usr_content, use_stream=False):
+     messages = [
+         {
+             "role": "system",
+             "content": sys_content,
+         },
+         {
+             "role": "user",
+             "content": usr_content
+         }
+     ]
+     response = client.chat.completions.create(
+         model="glm-4",
+         messages=messages,
+         stream=use_stream,
+         max_tokens=1024,
+         temperature=0.1,
+         presence_penalty=1.1,
+         top_p=0.8)
+     if response:
+         if use_stream:
+             stream_list = []
+             for chunk in response:
+                 stream_list.append(chunk.choices[0].delta.content)
+             return stream_list
+         else:
+             content = response.choices[0].message.content
+             return content
+     else:
+         return f"Error: {response.status_code}"
+
+
+ def prompt1(x, y, z):
+     sys_prom = f'''你是海龟汤出题人,我们来玩一个叫做海龟汤的游戏。海龟汤是一种情景猜谜的推理游戏。其玩法是:出题者提出一个简单又难以理解的事件,
+     玩家可以提出任何封闭式问题以试图缩小范围并找出事件背后真正的原因,封闭式问题指的是问题答案只能为:"是。"或者"不是。"。如果玩家的问题不是一个封闭式问题,请回答:"问法错误。"。
+     海龟汤由汤面和汤底组成,汤面指的是海龟汤的题目,汤底指的是题目背后的真相。如果用户的问题和汤面和汤底不相关,请回答:"不重要。",如果用户的答案命中了汤底的核心真相,且大部分内容都得到了还原,请回答:"回答正确。"。游戏过程中,你需要根据汤底、汤面、玩家的问题,以及上述规则,判断并选择以下五个选项中的一个来回答玩家提出的问题,不能给出更多的提示。你的回答选项: [是。|不是。|不重要。|问法错误。|回答正确。]。最后玩家通过这些问题和回答来逐渐找到事件的真相,以下是一份海龟汤的汤面和汤底,
+     汤面:[{x}]。汤底:[{y}]。请你扮演出题者的角色,我来扮演玩家的角色。由我先提问:'''
+     usr_prom = z
+     res = simple_chat(sys_content=sys_prom, usr_content=usr_prom)
+
+     return res
+
+
+ t1 = time.time()
+ print(f"now: {t1}")
+ test_a['answer'] = test_a.apply(lambda x: prompt1(x.puzzle, x.truth, x.text), axis=1)
+ print(f"cost:{time.time() - t1}")
+ test_a_baseline_pre = test_a
+ test_a_baseline_pre.to_csv('your_predict_result.csv', index=False)
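
Note that test_re.py saves its predictions to your_predict_result.csv, while process_transform.py reads predict_result.csv, so the output needs to be renamed (or one of the paths edited) before the scoring step. A minimal hand-off, assuming both scripts are run from the same directory:

```python
# Rename the prediction output so process_transform.py can find it.
import os

os.replace("your_predict_result.csv", "predict_result.csv")
```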