Spaces:
Running
Running
import asyncio
import json
import re

from trauma.api.chat.dto import EntityData
from trauma.api.data.model import EntityModelExtended
from trauma.api.message.ai.prompts import TraumaPrompts
from trauma.core.config import settings
from trauma.core.wrappers import openai_wrapper
async def update_entity_data_with_ai(entity_data: EntityData, user_message: str, assistant_message: str):
    """Build the system message for the ``update_entity_data_with_ai`` prompt.

    The template's ``{entity_data}``, ``{assistant_message}`` and ``{user_message}``
    placeholders are filled in a single pass. Text inserted for one placeholder is
    therefore never re-scanned. For example, an assistant message that literally
    contains ``{user_message}`` stays as written instead of being expanded.

    Args:
        entity_data: Current entity data; serialized with ``model_dump_json(indent=2)``.
        user_message: Text substituted for ``{user_message}``.
        assistant_message: Text substituted for ``{assistant_message}``.

    Returns:
        A one-element list holding a ``{"role": "system", "content": ...}`` dict.
    """
    substitutions = {
        "{entity_data}": entity_data.model_dump_json(indent=2),
        "{assistant_message}": assistant_message,
        "{user_message}": user_message,
    }
    # One regex pass instead of chained str.replace: chained replaces would also
    # rewrite placeholder tokens that appear inside already-inserted text.
    pattern = "|".join(map(re.escape, substitutions))
    content = re.sub(
        pattern,
        lambda match: substitutions[match.group(0)],
        TraumaPrompts.update_entity_data_with_ai,
    )
    return [{"role": "system", "content": content}]
async def generate_next_question(instructions: str, message_history_str: str):
    """Build the system message for the ``generate_next_question`` prompt.

    Args:
        instructions: Text substituted for the ``{instructions}`` placeholder.
        message_history_str: Text substituted for ``{message_history}``. This
            replacement runs after the instructions have been inserted.

    Returns:
        A one-element list holding a ``{"role": "system", "content": ...}`` dict.
    """
    prompt = TraumaPrompts.generate_next_question.replace("{instructions}", instructions)
    prompt = prompt.replace("{message_history}", message_history_str)
    system_message = {"role": "system", "content": prompt}
    return [system_message]
async def generate_search_request(user_messages_str: str, entity_data: dict):
    """Build the system message for the ``generate_search_request`` prompt.

    Args:
        user_messages_str: Text substituted for ``{user_messages_str}``. This is the
            last replacement applied.
        entity_data: Mapping serialized with ``json.dumps(..., indent=2)`` and
            substituted for ``{entity_data}``.

    Returns:
        A one-element list holding a ``{"role": "system", "content": ...}`` dict.
    """
    template = TraumaPrompts.generate_search_request
    serialized_entity = json.dumps(entity_data, indent=2)
    filled = template.replace("{entity_data}", serialized_entity)
    filled = filled.replace("{user_messages_str}", user_messages_str)
    return [{"role": "system", "content": filled}]
async def generate_final_response(
    final_entities: str, user_message: str, message_history_str: str, empty_field_instructions: str
):
    """Build the system message for the final-response turn.

    The template depends on ``empty_field_instructions``:

    * Non-empty: the ``generate_not_fully_recommendations`` template is used and
      ``{instructions}`` receives ``empty_field_instructions``. ``final_entities``
      is not inserted in this case.
    * Empty: the ``generate_recommendation_decision`` template is used and
      ``{final_entities}`` receives ``final_entities``.

    In both cases ``{message_history}`` and ``{user_message}`` are filled too. All
    placeholders are substituted in a single pass, so text inserted for one
    placeholder (e.g. an earlier user turn in the history that literally contains
    ``{user_message}``) is never re-scanned and expanded.

    Args:
        final_entities: Text for ``{final_entities}`` (used only when there are no
            empty-field instructions).
        user_message: Text substituted for ``{user_message}``.
        message_history_str: Text substituted for ``{message_history}``.
        empty_field_instructions: Instructions text; its truthiness selects the template.

    Returns:
        A one-element list holding a ``{"role": "system", "content": ...}`` dict.
    """
    substitutions = {
        "{message_history}": message_history_str,
        "{user_message}": user_message,
    }
    if empty_field_instructions:
        template = TraumaPrompts.generate_not_fully_recommendations
        substitutions["{instructions}"] = empty_field_instructions
    else:
        template = TraumaPrompts.generate_recommendation_decision
        substitutions["{final_entities}"] = final_entities

    # Single regex pass: chained str.replace would also rewrite placeholder tokens
    # that occur inside previously inserted (partly user-supplied) text.
    pattern = "|".join(map(re.escape, substitutions))
    content = re.sub(pattern, lambda match: substitutions[match.group(0)], template)
    return [{"role": "system", "content": content}]
async def generate_empty_final_response(
    user_message: str,
    message_history_str: str,
    empty_field_instructions: dict
):
    """Build the system message for the ``generate_empty_recommendations`` prompt.

    Four placeholders are filled in a single pass:

    * ``{message_history}``
    * ``{user_message}``
    * ``{instructions}``: ``empty_field_instructions`` as indented JSON.
    * ``{field_changed}``: the dict's keys joined with ``", "``.

    A single pass means a user message that literally contains ``{instructions}``
    (or any other placeholder) is left as written. It no longer has internal text
    spliced into it.

    Args:
        user_message: Text substituted for ``{user_message}``.
        message_history_str: Text substituted for ``{message_history}``.
        empty_field_instructions: Mapping whose keys must be strings (they are
            joined into ``{field_changed}``) and which must be JSON-serializable.

    Returns:
        A one-element list holding a ``{"role": "system", "content": ...}`` dict.
    """
    field_changed = ", ".join(empty_field_instructions)
    substitutions = {
        "{message_history}": message_history_str,
        "{user_message}": user_message,
        "{instructions}": json.dumps(empty_field_instructions, indent=2),
        "{field_changed}": field_changed,
    }
    # Single regex pass: the original chained replaces filled user text first, so
    # any later placeholder token inside that text was expanded as well.
    pattern = "|".join(map(re.escape, substitutions))
    content = re.sub(
        pattern,
        lambda match: substitutions[match.group(0)],
        TraumaPrompts.generate_empty_recommendations,
    )
    return [{"role": "system", "content": content}]
async def convert_value_to_embeddings(value: str, dimensions: int = 1536) -> list[float]:
    """Embed ``value`` with the ``text-embedding-3-large`` model.

    Args:
        value: Text to embed.
        dimensions: Requested embedding length (1536 unless overridden).

    Returns:
        The embedding vector of the first item in the response.
    """
    client = settings.OPENAI_CLIENT
    result = await client.embeddings.create(
        input=value,
        model='text-embedding-3-large',
        dimensions=dimensions,
    )
    first_item = result.data[0]
    return first_item.embedding
async def choose_closest_treatment_area(treatment_areas: list[str], treatment_area: str | None):
    """Build the system message for the ``choose_closest_treatment_area`` prompt.

    Args:
        treatment_areas: Candidate areas, joined with ``", "`` into ``{treatment_areas}``.
        treatment_area: Area substituted for ``{treatment_area}``.

    Returns:
        ``None`` when ``treatment_area`` is falsy (``None`` or empty). Otherwise a
        one-element list holding a ``{"role": "system", "content": ...}`` dict.
    """
    if treatment_area:
        content = (
            TraumaPrompts.choose_closest_treatment_area
            .replace("{treatment_areas}", ", ".join(treatment_areas))
            .replace("{treatment_area}", treatment_area)
        )
        return [{"role": "system", "content": content}]
    return None
async def choose_closest_treatment_method(treatment_methods: list[str], treatment_method: str | None):
    """Build the system message for the ``choose_closest_treatment_method`` prompt.

    Args:
        treatment_methods: Candidate methods, joined with ``", "`` into ``{treatment_methods}``.
        treatment_method: Method substituted for ``{treatment_method}``.

    Returns:
        ``None`` when ``treatment_method`` is falsy (``None`` or empty). Otherwise a
        one-element list holding a ``{"role": "system", "content": ...}`` dict.
    """
    if treatment_method:
        template = TraumaPrompts.choose_closest_treatment_method
        candidates = ", ".join(treatment_methods)
        prompt = template.replace("{treatment_methods}", candidates)
        prompt = prompt.replace("{treatment_method}", treatment_method)
        return [dict(role="system", content=prompt)]
    return None
async def check_is_valid_request(user_message: str, message_history: str):
    """Build the system message for the ``decide_is_valid_request`` prompt.

    ``{user_message}`` and ``{message_history}`` are filled in a single pass. The
    original substituted the raw user message first, so a message that literally
    contained ``{message_history}`` had the whole conversation spliced into it.

    Args:
        user_message: Text substituted for ``{user_message}``.
        message_history: Text substituted for ``{message_history}``.

    Returns:
        A one-element list holding a ``{"role": "system", "content": ...}`` dict.
    """
    substitutions = {
        "{user_message}": user_message,
        "{message_history}": message_history,
    }
    pattern = "|".join(map(re.escape, substitutions))
    content = re.sub(
        pattern,
        lambda match: substitutions[match.group(0)],
        TraumaPrompts.decide_is_valid_request,
    )
    return [{"role": "system", "content": content}]
async def generate_invalid_response(user_message: str, message_history_str: str, empty_field: str | None):
    """Build the system message used to answer a request judged invalid.

    The template depends on ``empty_field``:

    * Truthy: the ``generate_invalid_response`` template is used and
      ``{instructions}`` receives ``pick_empty_field_instructions(empty_field)``.
    * Falsy: ``generate_invalid_response_with_recs`` is used.

    ``{message_history}`` and ``{user_message}`` are filled in both cases. All
    placeholders are substituted in one pass, so placeholder tokens inside the
    inserted history or user text are not expanded.

    Args:
        user_message: Text substituted for ``{user_message}``.
        message_history_str: Text substituted for ``{message_history}``.
        empty_field: Name passed to ``pick_empty_field_instructions``, or ``None``.

    Returns:
        A one-element list holding a ``{"role": "system", "content": ...}`` dict.
    """
    # Function-scope import, presumably to avoid a circular import with
    # trauma.api.message.utils — confirm before hoisting.
    from trauma.api.message.utils import pick_empty_field_instructions

    substitutions = {
        "{message_history}": message_history_str,
        "{user_message}": user_message,
    }
    if empty_field:
        template = TraumaPrompts.generate_invalid_response
        substitutions["{instructions}"] = pick_empty_field_instructions(empty_field)
    else:
        template = TraumaPrompts.generate_invalid_response_with_recs

    pattern = "|".join(map(re.escape, substitutions))
    content = re.sub(pattern, lambda match: substitutions[match.group(0)], template)
    return [{"role": "system", "content": content}]
async def set_entity_score(entity: EntityModelExtended, search_request: str):
    """Build the system message for the ``set_entity_score`` prompt.

    The entity is serialized with ``model_dump_json`` minus the ``ageGroups``,
    ``treatmentAreas``, ``treatmentMethods`` and ``contactDetails`` fields. It is
    inserted at ``{entity}`` before ``{search_request}`` is filled.

    Args:
        entity: Entity to serialize into the prompt.
        search_request: Text substituted for ``{search_request}``.

    Returns:
        A one-element list holding a ``{"role": "system", "content": ...}`` dict.
    """
    template = TraumaPrompts.set_entity_score
    omitted_fields = {"ageGroups", "treatmentAreas", "treatmentMethods", "contactDetails"}
    entity_json = entity.model_dump_json(exclude=omitted_fields)
    with_entity = template.replace("{entity}", entity_json)
    final_prompt = with_entity.replace("{search_request}", search_request)
    return [{"role": "system", "content": final_prompt}]
async def retrieve_semantic_answer(
    user_query: str, *, score_threshold: float = 0.83
) -> list[EntityModelExtended] | None:
    """Return the single entity whose embedding is closest to ``user_query``.

    Steps:

    1. Embed the query at 384 dimensions with ``text-embedding-3-large``.
    2. Match it with a MongoDB ``$vectorSearch`` stage against the
       ``entityVectors`` index on the ``entities`` collection (path
       ``embedding``, ``numCandidates=20``, ``limit=1``).
    3. Project away the stored ``embedding`` field and expose the search score
       as ``score``.

    Args:
        user_query: Free-text query to embed and search for.
        score_threshold: The best hit's score must be strictly greater than this
            to be returned. Defaults to 0.83, the previously hard-coded value.

    Returns:
        A one-element list containing the best match. ``None`` when the search
        returns no documents or the best score does not exceed ``score_threshold``.
    """
    # Same embeddings call as before (model text-embedding-3-large). The 384
    # dimensions presumably must match the entityVectors index definition — confirm.
    query_vector = await convert_value_to_embeddings(user_query, dimensions=384)
    pipeline = [
        {"$vectorSearch": {
            "index": "entityVectors",
            "path": "embedding",
            "queryVector": query_vector,
            "numCandidates": 20,
            "limit": 1,
        }},
        {"$project": {
            "embedding": 0,
            "score": {"$meta": "vectorSearchScore"},
        }},
    ]
    results = await settings.DB_CLIENT.entities.aggregate(pipeline).to_list(length=1)
    # No candidates (e.g. empty collection) yields []; indexing [0] used to raise IndexError.
    if not results:
        return None
    best_match = results[0]
    if best_match["score"] > score_threshold:
        return [EntityModelExtended(**best_match)]
    return None
async def generate_searched_entity_response(user_query: str, facility: EntityModelExtended):
    """Build the system message for the ``generate_searched_entity`` prompt.

    ``{user_query}`` and ``{entity}`` are filled in a single pass. The original
    inserted the raw user query first, so a query that literally contained
    ``{entity}`` had the facility JSON spliced into it.

    Args:
        user_query: Text substituted for ``{user_query}``.
        facility: Entity serialized with ``model_dump_json(indent=2)`` into ``{entity}``.

    Returns:
        A one-element list holding a ``{"role": "system", "content": ...}`` dict.
    """
    substitutions = {
        "{user_query}": user_query,
        "{entity}": facility.model_dump_json(indent=2),
    }
    pattern = "|".join(map(re.escape, substitutions))
    content = re.sub(
        pattern,
        lambda match: substitutions[match.group(0)],
        TraumaPrompts.generate_searched_entity,
    )
    return [{"role": "system", "content": content}]
async def get_sensitive_words(text: str):
    """Build the system message for the ``get_sensitive_words`` prompt.

    Args:
        text: Text substituted for the ``{text}`` placeholder.

    Returns:
        A one-element list holding a ``{"role": "system", "content": ...}`` dict.
    """
    prompt_body = TraumaPrompts.get_sensitive_words.replace("{text}", text)
    return [dict(role="system", content=prompt_body)]