memval / src /db.py
omerXfaruq's picture
FULL COMMIT
4717db6
import random
from typing import List, Optional, Union, Tuple
from fastapi import Depends, Query
from sqlmodel import Field, Relationship, Session, SQLModel, create_engine, select, and_
from sqlalchemy import UniqueConstraint
default_schedule = "8,20"
class UserBase(SQLModel):
__table_args__ = (UniqueConstraint("telegram_chat_id"),)
name: str
telegram_chat_id: int
gmt: Optional[int] = 0
active: Optional[bool] = True
scheduled_hours: Optional[str] = default_schedule
class User(UserBase, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
reminders: List["Reminder"] = Relationship(back_populates="user")
class UserRead(UserBase):
id: int
class UserCreate(UserBase):
pass
class ReminderBase(SQLModel):
reminder: str
user_id: Optional[int] = Field(default=None, foreign_key="user.id")
class Reminder(ReminderBase, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
user: User = Relationship(back_populates="reminders")
class ReminderRead(ReminderBase):
id: int
class ReminderCreate(ReminderBase):
pass
sqlite_file_name = "database.db"
sqlite_url = f"sqlite:///{sqlite_file_name}"
connect_args = {"check_same_thread": False}
engine = create_engine(sqlite_url, echo=False, connect_args=connect_args)
def create_db_and_tables():
SQLModel.metadata.create_all(engine)
def get_session():
with Session(engine) as session:
yield session
def join_user(
user: UserCreate,
session: Session = next(get_session()),
) -> Optional[User]:
"""
Joins the user to the system.
Args:
user:
session:
Returns: User or None
"""
found_user = session.exec(
select(User).where(User.telegram_chat_id == user.telegram_chat_id)
).first()
if found_user is None:
user = db_create_user(user, session)
return user
else:
if found_user.active:
return None
else:
found_user.active = True
session.add(found_user)
session.commit()
session.refresh(found_user)
return found_user
def leave_user(
user: UserCreate,
session: Session = next(get_session()),
) -> Optional[User]:
"""
Deactivates the user.
Args:
user:
session:
Returns: User or None
"""
found_user = session.exec(
select(User).where(User.telegram_chat_id == user.telegram_chat_id)
).first()
if found_user is None:
return None
else:
if not found_user.active:
return None
else:
found_user.active = False
session.add(found_user)
session.commit()
session.refresh(found_user)
return found_user
def get_user_status(
telegram_chat_id: int,
session: Session = next(get_session()),
) -> Optional[Tuple[int, bool]]:
"""
Get status of the user.
Args:
telegram_chat_id:
session:
Returns: (gmt,active)
"""
found_user = session.exec(
select(User).where(User.telegram_chat_id == telegram_chat_id)
).first()
if found_user is None:
return None, None
else:
return found_user.gmt, found_user.active
def select_random_memory(
user: UserCreate,
session: Session = next(get_session()),
) -> Optional[Union[Reminder, bool]]:
"""
Select random memory from user's memory-vault.
Args:
user:
session:
Returns:
"""
found_user = session.exec(
select(User).where(User.telegram_chat_id == user.telegram_chat_id)
).first()
if found_user is None:
return None
if len(found_user.reminders) > 0:
memory_list = found_user.reminders
random_memory = random.choice(memory_list)
return random_memory
else:
return False
def list_memories(
user: UserCreate,
session: Session = next(get_session()),
) -> Optional[List[Reminder]]:
"""
Return all the memory-vault.
Args:
user:
session:
Returns:
"""
found_user = session.exec(
select(User).where(User.telegram_chat_id == user.telegram_chat_id)
).first()
if found_user is None:
return None
return found_user.reminders
def add_memory(
user: UserCreate,
memory: str,
session: Session = next(get_session()),
) -> Optional[Union[Reminder, bool]]:
"""
Add a memory to user's memory-vault.
Args:
user:
memory:
session:
Returns: Reminder
"""
found_user = session.exec(
select(User).where(User.telegram_chat_id == user.telegram_chat_id)
).first()
if found_user is None:
return None
else:
found_memory = session.exec(
select(Reminder).where(
and_(Reminder.user == found_user, Reminder.reminder == memory)
)
).first()
if found_memory is not None:
return False
else:
reminder = ReminderCreate(
reminder=memory,
user_id=found_user.id,
)
db_reminder = Reminder.from_orm(reminder)
session.add(db_reminder)
session.commit()
session.refresh(db_reminder)
return db_reminder
def delete_memory(
user: UserCreate,
memory_id: int,
session: Session = next(get_session()),
) -> Union[bool, str, None]:
"""
Delete a memory from user's memory-vault.
Args:
user:
memory_id:
session:
Returns: bool
"""
found_user = session.exec(
select(User).where(User.telegram_chat_id == user.telegram_chat_id)
).first()
if found_user is None:
return None
if len(found_user.reminders) > memory_id:
reminder = found_user.reminders[memory_id]
memory = reminder.reminder
session.delete(reminder)
session.commit()
return memory
else:
return False
def delete_last_memory(
user: UserCreate,
session: Session = next(get_session()),
) -> Union[bool, str, None]:
"""
Delete the last memory from user's memory-vault.
Args:
user:
session:
Returns: bool
"""
found_user = session.exec(
select(User).where(User.telegram_chat_id == user.telegram_chat_id)
).first()
if found_user is None:
return None
if found_user.reminders:
reminder = found_user.reminders.pop()
memory = reminder.reminder
session.delete(reminder)
session.commit()
return memory
else:
return False
def update_gmt(
user: UserCreate,
gmt: int,
session: Session = next(get_session()),
) -> Optional[User]:
"""
Update GMT of the user.
Returns: User or None
"""
found_user = session.exec(
select(User).where(User.telegram_chat_id == user.telegram_chat_id)
).first()
if found_user is None:
return None
else:
found_user.gmt = gmt
session.add(found_user)
session.commit()
session.refresh(found_user)
return found_user
def get_schedule(
user: UserCreate,
session: Session = next(get_session()),
) -> Optional[str]:
"""
Get schedule of the user.
Args:
user:
session:
Returns: str or None
"""
found_user = session.exec(
select(User).where(User.telegram_chat_id == user.telegram_chat_id)
).first()
if found_user is None:
return None
else:
return found_user.scheduled_hours
def reset_schedule(
user: UserCreate,
session: Session = next(get_session()),
) -> Optional[str]:
"""
Reset schedule of the user.
Args:
user:
session:
Returns: str or None
"""
found_user = session.exec(
select(User).where(User.telegram_chat_id == user.telegram_chat_id)
).first()
if found_user is None:
return None
else:
found_user.scheduled_hours = default_schedule
session.add(found_user)
session.commit()
session.refresh(found_user)
return found_user.scheduled_hours
def remove_hour_from_schedule(
user: UserCreate,
hour: int,
session: Session = next(get_session()),
) -> Optional[str]:
"""
Remove all appearances of hour from schedule of the user.
Args:
user:
hour:
session:
Returns: str or None
"""
found_user = session.exec(
select(User).where(User.telegram_chat_id == user.telegram_chat_id)
).first()
if found_user is None:
return None
else:
old_schedule = create_schedule_array(found_user.scheduled_hours)
new_schedule = []
for number in old_schedule:
if not hour == number:
new_schedule.append(number)
str_schedule_list = [str(number) for number in new_schedule]
str_schedule = ",".join(str_schedule_list)
found_user.scheduled_hours = str_schedule
session.add(found_user)
session.commit()
session.refresh(found_user)
return found_user.scheduled_hours
def add_hours_to_the_schedule(
user: UserCreate,
schedule_list: List[int],
session: Session = next(get_session()),
) -> Optional[str]:
"""
Add hours to the schedule of the user.
Args:
user:
schedule_list:
session:
Returns: str or None
"""
found_user = session.exec(
select(User).where(User.telegram_chat_id == user.telegram_chat_id)
).first()
if found_user is None:
return None
else:
old_schedule = create_schedule_array(found_user.scheduled_hours)
new_schedule = []
for number in old_schedule:
new_schedule.append(number)
for number in schedule_list:
new_schedule.append(number)
sorted_schedule = sorted(new_schedule)
str_sorted_schedule = [str(number) for number in sorted_schedule]
str_schedule = ",".join(str_sorted_schedule)
found_user.scheduled_hours = str_schedule
session.add(found_user)
session.commit()
session.refresh(found_user)
return found_user.scheduled_hours
def db_create_user(
user: UserCreate,
session: Session = next(get_session()),
) -> Optional[User]:
try:
user = User.from_orm(user)
session.add(user)
session.commit()
session.refresh(user)
return user
except Exception as ex:
return None
def db_read_users(
*,
only_active_users: bool = True,
session: Session = next(get_session()),
offset: int = 0,
limit: int = 100,
) -> List[User]:
if only_active_users:
users = session.exec(
select(User).where(User.active).offset(offset).limit(limit)
).all()
else:
users = session.exec(select(User).offset(offset).limit(limit)).all()
return users
def create_schedule_array(schedule_str: str) -> List[int]:
"""
Create schedule array from schedule string splitted by comas(,).
"""
if schedule_str == "":
return []
else:
schedule_list = []
str_list = schedule_str.split(",")
for str_number in str_list:
schedule_list.append(int(str_number))
return schedule_list