"""LLM prompt builders and embedding / vector-search helpers for the trauma chat flow.

Most coroutines here only assemble a chat ``messages`` list from a
``TraumaPrompts`` template and return it. They are decorated with
``openai_wrapper`` (``trauma.core.wrappers``, not shown here), which
presumably sends the messages to the model and post-processes the reply
according to its ``is_json`` / ``return_`` / ``temperature`` options.
Confirm against the wrapper.

NOTE(review): placeholders are filled with chained ``str.replace`` calls. A
value substituted early that itself contains a later placeholder (for example
a chat history containing the literal text ``{user_message}``) is rewritten by
the later ``replace`` as well.
"""
import asyncio
import json

from trauma.api.chat.dto import EntityData
from trauma.api.data.model import EntityModelExtended
from trauma.api.message.ai.prompts import TraumaPrompts
from trauma.core.config import settings
from trauma.core.wrappers import openai_wrapper


@openai_wrapper(is_json=True)
async def update_entity_data_with_ai(entity_data: EntityData, user_message: str, assistant_message: str):
    """Build the prompt asking the model to update ``entity_data`` from the latest exchange.

    Fills ``{entity_data}`` (pretty-printed model JSON), ``{assistant_message}``
    and ``{user_message}`` in ``TraumaPrompts.update_entity_data_with_ai``.
    ``openai_wrapper`` is configured with ``is_json=True``.

    Returns:
        A single system message (before ``openai_wrapper`` processing).
    """
    messages = [
        {
            "role": "system",
            "content": TraumaPrompts.update_entity_data_with_ai
            .replace("{entity_data}", entity_data.model_dump_json(indent=2))
            .replace("{assistant_message}", assistant_message)
            .replace("{user_message}", user_message)
        }
    ]
    return messages


@openai_wrapper(temperature=0.8)
async def generate_next_question(instructions: str, message_history_str: str):
    """Build the prompt for the assistant's next follow-up question.

    Fills ``{instructions}`` and ``{message_history}`` in
    ``TraumaPrompts.generate_next_question``.
    """
    messages = [
        {
            "role": "system",
            "content": TraumaPrompts.generate_next_question
            .replace("{instructions}", instructions)
            .replace("{message_history}", message_history_str)
        }
    ]
    return messages


@openai_wrapper(temperature=0.4)
async def generate_search_request(user_messages_str: str, entity_data: dict):
    """Build the prompt that turns collected data into a search request.

    ``entity_data`` is rendered with ``json.dumps(..., indent=2)`` into
    ``{entity_data}``; ``user_messages_str`` fills ``{user_messages_str}``.
    """
    messages = [
        {
            "role": "system",
            "content": TraumaPrompts.generate_search_request
            .replace("{entity_data}", json.dumps(entity_data, indent=2))
            .replace("{user_messages_str}", user_messages_str)
        }
    ]
    return messages


@openai_wrapper(temperature=0.8)
async def generate_final_response(
        final_entities: str,
        user_message: str,
        message_history_str: str,
        empty_field_instructions: str
):
    """Build the prompt for the final recommendation reply.

    Template selection:

    * If ``empty_field_instructions`` is truthy, use
      ``generate_not_fully_recommendations`` with ``{instructions}`` filled.
      ``final_entities`` is not used in that case.
    * Otherwise, use ``generate_recommendation_decision`` with
      ``{final_entities}`` filled.

    Either template then gets ``{message_history}`` and ``{user_message}``.
    """
    if empty_field_instructions:
        prompt = TraumaPrompts.generate_not_fully_recommendations.replace(
            "{instructions}", empty_field_instructions
        )
    else:
        prompt = TraumaPrompts.generate_recommendation_decision.replace(
            "{final_entities}", final_entities
        )
    messages = [
        {
            "role": "system",
            "content": prompt
            .replace("{message_history}", message_history_str)
            .replace("{user_message}", user_message)
        }
    ]
    return messages


@openai_wrapper(temperature=0.8)
async def generate_empty_final_response(
        user_message: str,
        message_history_str: str,
        empty_field_instructions: dict
):
    """Build the prompt used when no recommendation could be produced.

    ``empty_field_instructions`` is rendered as indented JSON into
    ``{instructions}``. Its keys, comma-joined, fill ``{field_changed}``.
    """
    field_changed = ", ".join(empty_field_instructions.keys())
    messages = [
        {
            "role": "system",
            "content": TraumaPrompts.generate_empty_recommendations
            .replace("{message_history}", message_history_str)
            .replace("{user_message}", user_message)
            .replace("{instructions}", json.dumps(empty_field_instructions, indent=2))
            .replace("{field_changed}", field_changed)
        }
    ]
    return messages


async def convert_value_to_embeddings(value: str, dimensions: int = 1536) -> list[float]:
    """Embed ``value`` with ``text-embedding-3-large``.

    Args:
        value: Text to embed.
        dimensions: Requested embedding size, passed to the embeddings API.

    Returns:
        The embedding vector of the first (only) input.
    """
    embeddings = await settings.OPENAI_CLIENT.embeddings.create(
        input=value,
        model='text-embedding-3-large',
        dimensions=dimensions,
    )
    return embeddings.data[0].embedding


@openai_wrapper(is_json=True, return_='result')
async def choose_closest_treatment_area(treatment_areas: list[str], treatment_area: str | None):
    """Build the prompt that maps ``treatment_area`` onto one of ``treatment_areas``.

    Returns ``None`` immediately when ``treatment_area`` is falsy.
    NOTE(review): how ``openai_wrapper`` treats a ``None`` return is not shown
    here. Presumably it skips the model call; confirm.
    """
    if not treatment_area:
        return None
    messages = [
        {
            "role": "system",
            "content": TraumaPrompts.choose_closest_treatment_area
            .replace("{treatment_areas}", ", ".join(treatment_areas))
            .replace("{treatment_area}", treatment_area)
        }
    ]
    return messages


@openai_wrapper(is_json=True, return_='result')
async def choose_closest_treatment_method(treatment_methods: list[str], treatment_method: str | None):
    """Build the prompt that maps ``treatment_method`` onto one of ``treatment_methods``.

    Returns ``None`` immediately when ``treatment_method`` is falsy (same
    ``openai_wrapper`` caveat as ``choose_closest_treatment_area``).
    """
    if not treatment_method:
        return None
    messages = [
        {
            "role": "system",
            "content": TraumaPrompts.choose_closest_treatment_method
            .replace("{treatment_methods}", ", ".join(treatment_methods))
            .replace("{treatment_method}", treatment_method)
        }
    ]
    return messages


@openai_wrapper(is_json=True, return_='is_valid')
async def check_is_valid_request(user_message: str, message_history: str):
    """Build the prompt that classifies whether ``user_message`` is a valid request.

    ``openai_wrapper`` is configured with ``is_json=True`` and
    ``return_='is_valid'``. Presumably the caller receives that key of the
    parsed reply.
    """
    messages = [
        {
            "role": "system",
            "content": TraumaPrompts.decide_is_valid_request
            .replace("{user_message}", user_message)
            .replace("{message_history}", message_history)
        }
    ]
    return messages


@openai_wrapper(temperature=0.9)
async def generate_invalid_response(user_message: str, message_history_str: str, empty_field: str | None):
    """Build the reply prompt for a request judged invalid.

    Template selection:

    * If ``empty_field`` is set, use ``generate_invalid_response`` with that
      field's instructions in ``{instructions}``.
    * Otherwise, use ``generate_invalid_response_with_recs``.
    """
    # Imported lazily; presumably to avoid a circular import with
    # trauma.api.message.utils. Confirm before hoisting.
    from trauma.api.message.utils import pick_empty_field_instructions

    if empty_field:
        empty_field_instructions = pick_empty_field_instructions(empty_field)
        prompt = TraumaPrompts.generate_invalid_response.replace("{instructions}", empty_field_instructions)
    else:
        prompt = TraumaPrompts.generate_invalid_response_with_recs
    messages = [
        {
            "role": "system",
            "content": prompt
            .replace("{message_history}", message_history_str)
            .replace("{user_message}", user_message)
        }
    ]
    return messages


@openai_wrapper(is_json=True, return_='score')
async def set_entity_score(entity: EntityModelExtended, search_request: str):
    """Build the prompt that scores ``entity`` against ``search_request``.

    The entity is serialized without ``ageGroups``, ``treatmentAreas``,
    ``treatmentMethods`` and ``contactDetails``.
    """
    messages = [
        {
            "role": "system",
            "content": TraumaPrompts.set_entity_score
            .replace("{entity}", entity.model_dump_json(exclude={
                "ageGroups",
                "treatmentAreas",
                "treatmentMethods",
                "contactDetails"
            }))
            .replace("{search_request}", search_request)
        }
    ]
    return messages


async def retrieve_semantic_answer(
        user_query: str,
        score_threshold: float = 0.83,
) -> list[EntityModelExtended] | None:
    """Find the single entity most semantically similar to ``user_query``.

    Embeds the query at 384 dimensions and runs a ``$vectorSearch`` over
    ``settings.DB_CLIENT.entities`` (index ``entityVectors``, path
    ``embedding``). Only the top hit is kept, with its score projected as
    ``score`` and its ``embedding`` field dropped.

    Args:
        user_query: Free-text query to embed.
        score_threshold: A hit is accepted only if its score is strictly
            greater than this value. The default, 0.83, was previously
            hard-coded.

    Returns:
        A one-element list with the matching entity, or ``None`` when the top
        score is not above the threshold or the search returned nothing.
    """
    query_vector = await convert_value_to_embeddings(user_query, dimensions=384)
    response = await settings.DB_CLIENT.entities.aggregate([
        {"$vectorSearch": {
            "index": "entityVectors",
            "path": "embedding",
            "queryVector": query_vector,
            "numCandidates": 20,
            "limit": 1
        }},
        {"$project": {
            "embedding": 0,
            "score": {"$meta": "vectorSearchScore"}
        }}
    ]).to_list(length=1)
    # An empty result used to crash here with IndexError on response[0].
    if not response or response[0]['score'] <= score_threshold:
        return None
    return [EntityModelExtended(**response[0])]


@openai_wrapper()
async def generate_searched_entity_response(user_query: str, facility: EntityModelExtended):
    """Build the prompt answering ``user_query`` about a specific ``facility``.

    The facility is rendered as indented JSON into ``{entity}``.
    """
    messages = [
        {
            "role": "system",
            "content": TraumaPrompts.generate_searched_entity
            .replace("{user_query}", user_query)
            .replace("{entity}", facility.model_dump_json(indent=2))
        }
    ]
    return messages


@openai_wrapper(is_json=True, return_='words')
async def get_sensitive_words(text: str):
    """Build the prompt that extracts sensitive words from ``text``.

    ``openai_wrapper`` is configured with ``is_json=True`` and
    ``return_='words'``.
    """
    messages = [
        {
            "role": "system",
            "content": TraumaPrompts.get_sensitive_words
            .replace("{text}", text)
        }
    ]
    return messages