Spaces:
Sleeping
Sleeping
| from openai import OpenAI | |
| from random import randint | |
| from emotion_pertuber import perturb_state | |
| import json | |
| import os | |
| # 设置OpenAI API密钥和基础URL | |
| api_key = os.getenv("OPENAI_API_KEY") | |
| base_url = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1") | |
| model_name = os.getenv("OPENAI_MODEL_NAME", "gpt-3.5-turbo") | |
| tools = [ | |
| { | |
| "type": "function", | |
| "function": { | |
| 'name': 'emotion_inference', | |
| 'description': '根据profile和对话记录,推理下一句情绪', | |
| 'parameters': { | |
| "type": "object", | |
| "properties": { | |
| "emotion": { | |
| "type": "string", | |
| "enum": [ | |
| "admiration", "amusement", "anger", "annoyance", "approval", "caring", | |
| "confusion", "curiosity", "desire", "disappointment", "disapproval", | |
| "disgust", "embarrassment", "excitement", "fear", "gratitude", "grief", | |
| "joy", "love", "nervousness", "optimism", "pride", "realization", | |
| "relief", "remorse", "sadness", "surprise", "neutral" | |
| ], | |
| "description": "推理出的情绪类别,必须是GoEmotions定义的27种情绪之一。" | |
| } | |
| }, | |
| "required": ["emotion"] | |
| }, | |
| } | |
| } | |
| ] | |
| # 根据profile和dialogue推测emotion | |
| def emotion_inferencer(profile, conversation): | |
| client = OpenAI( | |
| api_key=api_key, | |
| base_url=base_url, | |
| ) | |
| # 提取患者信息 | |
| patient_info = f"### 患者信息\n年龄:{profile['age']}\n性别:{profile['gender']}\n职业:{profile['occupation']}\n婚姻状况:{profile['marital_status']}\n症状:{profile['symptoms']}" | |
| # 提取对话记录 | |
| dialogue_history = "\n".join([f"{conv['role']}: {conv['content']}" for conv in conversation]) | |
| response = client.chat.completions.create( | |
| model=model_name, | |
| messages=[ | |
| {"role": "user", "content": f"### 任务\n根据患者情况及咨访对话历史记录推测患者下一句话最可能的情绪。\n{patient_info}\n### 对话记录\n{dialogue_history}"} | |
| ], | |
| # functions=[tools[0]["function"]], | |
| # function_call={"name": "emotion_inference"} | |
| tools=tools, | |
| tool_choice={"type": "function", "function": {"name": "emotion_inference"}} | |
| ) | |
| # print(response) | |
| emotion = json.loads(response.choices[0].message.tool_calls[0].function.arguments)["emotion"] | |
| return emotion | |
| def emotion_modulation(profile, conversation): | |
| indicator = randint(0,100) | |
| emotion = emotion_inferencer(profile,conversation) | |
| # print(emotion) | |
| if indicator > 90: | |
| return perturb_state(emotion) | |
| else: | |
| return emotion | |
| # unit test | |
| # while True: | |
| # # 模拟患者信息 | |
| # profile = { | |
| # "drisk": 3, | |
| # "srisk": 2, | |
| # "age": "42", | |
| # "gender": "女", | |
| # "marital_status": "离婚", | |
| # "occupation": "教师", | |
| # "symptoms": "缺乏自信心,自我价值感低,有自罪感,无望感;体重剧烈增加;精神运动性激越;有自杀想法" | |
| # } | |
| # conversation = [ | |
| # {"role": "咨询师", "content": "你好,请问有什么可以帮您?"} | |
| # ] | |
| # print(emotion_modulation(profile,conversation)) | |