| import nh3
|
|
|
| from constants import (
|
| MAX_COMMENT_LENGTH,
|
| MAX_FILE_NAME_LENGTH,
|
| MAX_ID_LENGTH,
|
| MAX_MESSAGE_LENGTH,
|
| MAX_RESPONSE_LENGTH,
|
| )
|
| from pydantic import BaseModel, Field, field_validator
|
| from typing import Literal, Set
|
| from uuid import UUID
|
|
|
|
|
class IdentifierBase(BaseModel):
    """Identifier fields inherited by the request models.

    Each id must be between 1 and ``MAX_ID_LENGTH`` characters long and may
    only contain ASCII letters, digits, underscores and hyphens.
    """

    user_id: str = Field(
        min_length=1,
        max_length=MAX_ID_LENGTH,
        pattern=r"^[a-zA-Z0-9_-]+$",
    )
    participant_id: str = Field(
        min_length=1,
        max_length=MAX_ID_LENGTH,
        pattern=r"^[a-zA-Z0-9_-]+$",
    )
    session_id: str = Field(
        min_length=1,
        max_length=MAX_ID_LENGTH,
        pattern=r"^[a-zA-Z0-9_-]+$",
    )
|
|
|
|
|
class ProfileBase(BaseModel):
    """Consent and profile fields inherited by the request models."""

    # Required boolean; the model does not force it to be True.
    consent: bool
    # NOTE(review): "0-18" and "18-24" both include age 18. The labels are
    # part of the API contract, so confirm the intended bucket boundaries
    # before changing anything.
    age_group: Literal["0-18", "18-24", "25-34", "35-44", "45-54", "55-64", "65+"]
    gender: Literal["M", "F"]
    # Non-empty set of roles. Duplicate entries collapse into one, so at most
    # the five distinct values listed can ever be present.
    roles: Set[
        Literal["patient", "clinician", "computer-scientist", "researcher", "other"]
    ] = Field(max_length=5, min_length=1)
|
|
|
|
|
class ChatRequest(IdentifierBase, ProfileBase):
    """Validated payload for one chat request (ids, profile, and the message)."""

    conversation_id: str = Field(
        pattern="^[a-zA-Z0-9_-]+$", min_length=1, max_length=MAX_ID_LENGTH
    )
    model_type: Literal[
        "champ", "openai", "google-conservative", "google-creative", "qwen"
    ]
    lang: Literal["en", "fr"]
    human_message: str = Field(min_length=1, max_length=MAX_MESSAGE_LENGTH)

    @field_validator("human_message")
    @classmethod
    def sanitize_human_message(cls, human_message: str) -> str:
        """Sanitize the user's message with ``nh3.clean`` to mitigate XSS.

        nh3 is an allowlist sanitizer. Disallowed markup is dropped, but some
        tags may be kept and text may be HTML-escaped.
        NOTE(review): the stored/forwarded message is therefore not
        guaranteed to be the raw user text, and escaping can make it longer
        than MAX_MESSAGE_LENGTH. Confirm downstream consumers expect this.

        Raises:
            ValueError: if nothing is left after sanitization. Pydantic
                surfaces this as a ``ValidationError``.
        """
        cleaned = nh3.clean(human_message)
        # This validator runs after the Field constraints, so markup-only
        # input (e.g. a lone <script> element) could otherwise turn into ""
        # and silently bypass min_length=1.
        if not cleaned:
            raise ValueError("human_message is empty after HTML sanitization")
        return cleaned
|
|
|
|
|
class FeedbackRequest(IdentifierBase, ProfileBase):
    """Validated feedback (rating plus comment) about one reply.

    The reply is identified by ``reply_id`` and ``message_index``, and its
    text is echoed back in ``reply_content``.
    """

    message_index: int = Field(ge=0, le=10_000)
    rating: Literal["like", "dislike", "mixed"]
    # The comment may be empty, but the field has no default, so it is
    # still required.
    comment: str = Field(min_length=0, max_length=MAX_COMMENT_LENGTH)
    reply_content: str = Field(min_length=1, max_length=MAX_RESPONSE_LENGTH)
    reply_id: UUID

    @field_validator("comment")
    @classmethod
    def sanitize_comment(cls, comment: str) -> str:
        """Sanitize the feedback comment with ``nh3.clean`` to mitigate XSS.

        An empty result is accepted, consistent with ``min_length=0``.
        """
        return nh3.clean(comment)

    @field_validator("reply_content")
    @classmethod
    def sanitize_reply_content(cls, reply_content: str) -> str:
        """Sanitize the echoed reply text with ``nh3.clean`` to mitigate XSS.

        Raises:
            ValueError: if nothing is left after sanitization. Pydantic
                surfaces this as a ``ValidationError``.
        """
        cleaned = nh3.clean(reply_content)
        # Runs after the Field constraints: re-enforce min_length=1 on the
        # sanitized value, since markup-only input can reduce to "".
        if not cleaned:
            raise ValueError("reply_content is empty after HTML sanitization")
        return cleaned
|
|
|
|
|
class CommentRequest(IdentifierBase, ProfileBase):
    """Validated free-text comment together with the user's ids and profile."""

    comment: str = Field(min_length=1, max_length=MAX_COMMENT_LENGTH)

    @field_validator("comment")
    @classmethod
    def sanitize_comment(cls, comment: str) -> str:
        """Sanitize the comment with ``nh3.clean`` to mitigate XSS.

        Raises:
            ValueError: if nothing is left after sanitization. Pydantic
                surfaces this as a ``ValidationError``.
        """
        cleaned = nh3.clean(comment)
        # Runs after the Field constraints: re-enforce min_length=1 on the
        # sanitized value, since markup-only input can reduce to "".
        if not cleaned:
            raise ValueError("comment is empty after HTML sanitization")
        return cleaned
|
|
|
|
|
class DeleteFileRequest(IdentifierBase, ProfileBase):
    """Validated request naming a single file to delete."""

    # Shape enforced by the pattern:
    #   * first character: letter, digit, "_", "(", ")" or "-" (never "." or
    #     whitespace, so no hidden/dot-leading names);
    #   * stem: the same characters plus whitespace;
    #   * optional ".ext" segments, each non-empty and without parentheses.
    # No character class contains "/" or "\", and every "." must be followed
    # by a non-dot character, so path separators and ".." cannot appear.
    # NOTE(review): "\s" matches tabs and newlines as well as spaces.
    # Confirm whether only a literal space was intended.
    file_name: str = Field(
        max_length=MAX_FILE_NAME_LENGTH,
        min_length=1,
        pattern=r"^[a-zA-Z0-9_()-][a-zA-Z0-9\s_()-]*(\.[a-zA-Z0-9\s_-]+)*$",
    )
|
|
|
|
|
class ClearConversationRequest(BaseModel):
    """Pair of session ids (old and new) used when clearing a conversation.

    Both ids follow the same 1..MAX_ID_LENGTH, ``[a-zA-Z0-9_-]`` format as
    the ids in ``IdentifierBase``. This model does not inherit from it.
    """

    old_session_id: str = Field(
        min_length=1,
        max_length=MAX_ID_LENGTH,
        pattern=r"^[a-zA-Z0-9_-]+$",
    )
    new_session_id: str = Field(
        min_length=1,
        max_length=MAX_ID_LENGTH,
        pattern=r"^[a-zA-Z0-9_-]+$",
    )
|
|
|
|
|
class ChatMessage(BaseModel):
    """A single message tagged with the role that produced it."""

    role: Literal["user", "assistant", "system"]
    # Unlike the request models, the content here is neither length-limited
    # nor passed through nh3.
    content: str
|
|
|