from pydantic import BaseModel, ConfigDict
from typing import List, Optional

import json
import uuid
import time
import logging

from sqlalchemy import String, Column, BigInteger, Text

from apps.webui.internal.db import Base, get_db

from config import SRC_LOG_LEVELS

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])


####################
# Tag DB Schema
####################


class Tag(Base):
    """ORM row of the ``tag`` table: one named tag owned by one user."""

    __tablename__ = "tag"

    id = Column(String, primary_key=True)
    name = Column(String)
    user_id = Column(String)
    data = Column(Text, nullable=True)


class ChatIdTag(Base):
    """ORM row of the ``chatidtag`` table: links a tag name to a chat."""

    __tablename__ = "chatidtag"

    id = Column(String, primary_key=True)
    tag_name = Column(String)
    chat_id = Column(String)
    user_id = Column(String)
    timestamp = Column(BigInteger)


class TagModel(BaseModel):
    """Pydantic view of a :class:`Tag` row."""

    id: str
    name: str
    user_id: str
    data: Optional[str] = None

    # Allows ``model_validate`` to read fields from ORM instances.
    model_config = ConfigDict(from_attributes=True)


class ChatIdTagModel(BaseModel):
    """Pydantic view of a :class:`ChatIdTag` row."""

    id: str
    tag_name: str
    chat_id: str
    user_id: str
    timestamp: int

    # Allows ``model_validate`` to read fields from ORM instances.
    model_config = ConfigDict(from_attributes=True)


####################
# Forms
####################


class ChatIdTagForm(BaseModel):
    """Request payload naming a tag and the chat it applies to."""

    tag_name: str
    chat_id: str


class TagChatIdsResponse(BaseModel):
    """Response payload carrying a list of chat ids."""

    chat_ids: List[str]


class ChatTagsResponse(BaseModel):
    """Response payload carrying a list of tag names."""

    tags: List[str]


class TagTable:
    """Data-access helpers for the ``tag`` and ``chatidtag`` tables.

    Every public method opens its own session through ``get_db()``.
    """

    def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]:
        """Create a new ``Tag`` row with a fresh UUID.

        Args:
            name: Tag name to store.
            user_id: Owner of the tag.

        Returns:
            The stored tag, or ``None`` if the insert raised.
        """
        with get_db() as db:
            tag = TagModel(id=str(uuid.uuid4()), user_id=user_id, name=name)
            try:
                result = Tag(**tag.model_dump())
                db.add(result)
                db.commit()
                db.refresh(result)
                return TagModel.model_validate(result)
            except Exception:
                # Best-effort insert: report failure to the caller as None,
                # but leave a trace instead of swallowing it silently.
                log.exception("insert_new_tag failed")
                return None

    def get_tag_by_name_and_user_id(
        self, name: str, user_id: str
    ) -> Optional[TagModel]:
        """Look up a user's tag by name.

        Args:
            name: Tag name to search for.
            user_id: Owner of the tag.

        Returns:
            The first matching tag, or ``None`` if there is no match or the
            query raised.
        """
        try:
            with get_db() as db:
                # filter_by() takes keyword equality criteria; filter() does
                # not accept keyword arguments and would raise TypeError.
                tag = db.query(Tag).filter_by(name=name, user_id=user_id).first()
                return TagModel.model_validate(tag) if tag else None
        except Exception:
            log.exception("get_tag_by_name_and_user_id failed")
            return None

    def add_tag_to_chat(
        self, user_id: str, form_data: ChatIdTagForm
    ) -> Optional[ChatIdTagModel]:
        """Attach a tag to a chat, creating the tag first if it is missing.

        Args:
            user_id: Owner of the tag and of the chat link.
            form_data: Tag name and chat id to link.

        Returns:
            The stored chat/tag link, or ``None`` if the tag could not be
            created or the link insert raised.
        """
        tag = self.get_tag_by_name_and_user_id(form_data.tag_name, user_id)
        if tag is None:
            tag = self.insert_new_tag(form_data.tag_name, user_id)
        if tag is None:
            # insert_new_tag already failed; there is no tag to link.
            log.error("add_tag_to_chat: unable to create tag")
            return None

        chat_id_tag = ChatIdTagModel(
            id=str(uuid.uuid4()),
            user_id=user_id,
            chat_id=form_data.chat_id,
            tag_name=tag.name,
            timestamp=int(time.time()),
        )
        try:
            with get_db() as db:
                result = ChatIdTag(**chat_id_tag.model_dump())
                db.add(result)
                db.commit()
                db.refresh(result)
                return ChatIdTagModel.model_validate(result)
        except Exception:
            log.exception("add_tag_to_chat failed")
            return None

    def _get_tags_by_names(
        self, db, user_id: str, tag_names: List[str]
    ) -> List[TagModel]:
        """Return the user's ``Tag`` rows whose name is in ``tag_names``."""
        if not tag_names:
            # Nothing can match; skip issuing an empty IN () query.
            return []
        rows = (
            db.query(Tag)
            .filter_by(user_id=user_id)
            .filter(Tag.name.in_(tag_names))
            .all()
        )
        return [TagModel.model_validate(row) for row in rows]

    def get_tags_by_user_id(self, user_id: str) -> List[TagModel]:
        """Return the user's tags that are attached to at least one chat."""
        with get_db() as db:
            links = (
                db.query(ChatIdTag)
                .filter_by(user_id=user_id)
                .order_by(ChatIdTag.timestamp.desc())
                .all()
            )
            tag_names = [link.tag_name for link in links]
            return self._get_tags_by_names(db, user_id, tag_names)

    def get_tags_by_chat_id_and_user_id(
        self, chat_id: str, user_id: str
    ) -> List[TagModel]:
        """Return the user's tags that are attached to the given chat."""
        with get_db() as db:
            links = (
                db.query(ChatIdTag)
                .filter_by(user_id=user_id, chat_id=chat_id)
                .order_by(ChatIdTag.timestamp.desc())
                .all()
            )
            tag_names = [link.tag_name for link in links]
            return self._get_tags_by_names(db, user_id, tag_names)

    def get_chat_ids_by_tag_name_and_user_id(
        self, tag_name: str, user_id: str
    ) -> List[ChatIdTagModel]:
        """Return the user's chat/tag links for ``tag_name``, newest first."""
        with get_db() as db:
            links = (
                db.query(ChatIdTag)
                .filter_by(user_id=user_id, tag_name=tag_name)
                .order_by(ChatIdTag.timestamp.desc())
                .all()
            )
            return [ChatIdTagModel.model_validate(link) for link in links]

    def count_chat_ids_by_tag_name_and_user_id(
        self, tag_name: str, user_id: str
    ) -> int:
        """Count the user's chat/tag links that use ``tag_name``."""
        with get_db() as db:
            return (
                db.query(ChatIdTag)
                .filter_by(tag_name=tag_name, user_id=user_id)
                .count()
            )

    def _delete_tag_if_unused(self, db, tag_name: str, user_id: str) -> None:
        """Delete the ``Tag`` row once no chat/tag link references it."""
        tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id)
        if tag_count == 0:
            # Remove tag item from Tag col as well
            db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete()
            db.commit()

    def delete_tag_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> bool:
        """Remove a tag from all of the user's chats.

        Returns:
            ``True`` on success, ``False`` if any database call raised.
        """
        try:
            with get_db() as db:
                res = (
                    db.query(ChatIdTag)
                    .filter_by(tag_name=tag_name, user_id=user_id)
                    .delete()
                )
                log.debug(f"res: {res}")
                db.commit()

                self._delete_tag_if_unused(db, tag_name, user_id)
                return True
        except Exception as e:
            log.error(f"delete_tag: {e}")
            return False

    def delete_tag_by_tag_name_and_chat_id_and_user_id(
        self, tag_name: str, chat_id: str, user_id: str
    ) -> bool:
        """Remove a tag from one chat, dropping the tag if it becomes unused.

        Returns:
            ``True`` on success, ``False`` if any database call raised.
        """
        try:
            with get_db() as db:
                res = (
                    db.query(ChatIdTag)
                    .filter_by(tag_name=tag_name, chat_id=chat_id, user_id=user_id)
                    .delete()
                )
                log.debug(f"res: {res}")
                db.commit()

                self._delete_tag_if_unused(db, tag_name, user_id)
                return True
        except Exception as e:
            log.error(f"delete_tag: {e}")
            return False

    def delete_tags_by_chat_id_and_user_id(self, chat_id: str, user_id: str) -> bool:
        """Remove every tag attached to the given chat.

        Returns:
            Always ``True``; failures of individual deletions are logged by
            the per-tag delete method.
        """
        tags = self.get_tags_by_chat_id_and_user_id(chat_id, user_id)

        for tag in tags:
            # TagModel exposes the tag's name as ``name`` (not ``tag_name``).
            self.delete_tag_by_tag_name_and_chat_id_and_user_id(
                tag.name, chat_id, user_id
            )

        return True


Tags = TagTable()