Spaces:
Running
Running
import json | |
from pymongo.mongo_client import MongoClient | |
from pymongo.server_api import ServerApi | |
from fastapi import Request, Response | |
from . import log_module, security, settings, chat_functions | |
from datetime import timezone, datetime, timedelta | |
from pydantic import BaseModel, Field, PrivateAttr | |
from typing import Any, List, Self | |
import tiktoken | |
import uuid | |
# Module-wide MongoDB client; declares MongoDB Stable API version "1".
client: MongoClient = MongoClient(
    settings.DB_URI,
    server_api=ServerApi("1"),
)
class __DB:
    """Handles to the ``ChatDB`` collections used by this module."""

    user = client["ChatDB"]["users"]
    sess = client["ChatDB"]["sessions"]
    personalities = client["ChatDB"]["personalityCores"]


# Single shared namespace instance; code accesses collections as ``DB.user`` etc.
DB: __DB = __DB()
# Fixed UTC-4 offset for timezone-aware timestamps (a fixed offset has no DST).
tz: timezone = timezone(-timedelta(hours=4))

# Tokenizer for the configured chat model, used for local token counting.
encoding = tiktoken.encoding_for_model(settings.GPT_MODEL)
def count_tokens_on_message(args: list) -> int:
    """Return the summed token count of the truthy strings in *args*.

    Each string is tokenized with the module-level ``encoding``. Falsy
    entries (``None``, ``""``) are skipped, so optional message parts can be
    passed straight through.
    """
    return sum(len(encoding.encode(text)) for text in args if text)
def get_personality_core(personality):
    """Return the stored system prompt of the personality core *personality*.

    Every ``{date}`` placeholder in the prompt is replaced with today's date
    formatted as ``YYYY-MM-DD``.

    Raises:
        ValueError: if no personality core with that name exists.
    """
    found = DB.personalities.find_one({"name": personality})
    if found is None:
        raise ValueError(f"Unknown personality core: {personality!r}")
    # "%d" is day-of-month; the previous "%D" expands to "%m/%d/%y" and
    # produced dates like "2024-05-05/14/24".
    today = datetime.now().strftime("%Y-%m-%d")
    return found["prompt"].replace("{date}", today)
def get_all_personality_cores():
    """Return the ``name`` of every stored personality core, in cursor order."""
    # Project only "name" so the (potentially long) prompts are not transferred.
    cursor = DB.personalities.find({}, {"name": 1, "_id": 0})
    return [doc["name"] for doc in cursor]
class Configs(BaseModel):
    """Per-user generation settings plus the resolved assistant system prompt.

    ``assistantPrompt`` is re-resolved from ``assistant`` on every
    construction, so any ``assistantPrompt`` value passed in is overwritten.
    """

    temperature: float = 0.5
    frequency_penalty: float = 0.0
    presence_penalty: float = 0.0
    useTool: bool = True
    assistant: str = "Chatsito clasico"
    # Class-level default; this lookup runs once, when the class is defined.
    assistantPrompt: str = get_personality_core("Chatsito clasico")

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Always refresh the prompt for the selected personality.
        resolved_prompt = get_personality_core(self.assistant)
        self.assistantPrompt = resolved_prompt
class Message(BaseModel):
    """A chat message together with its token accounting."""

    role: str
    content: str = ""
    # Set to the token count of ``content`` on construction, then incremented
    # by one per chunk read in ``consume_stream``.
    _tokens: int = 1
    # Chunks read by ``consume_stream`` (approximate output tokens); returned
    # as the second element of ``get_tokens``. Previously never declared, so
    # ``get_tokens`` raised AttributeError.
    _tokensOutput: int = 0
    # Not used in this class; presumably set by a caller that runs
    # ``consume_stream`` on a thread -- TODO confirm.
    _thread = None
    # True while ``consume_stream`` is iterating a stream.
    _running: bool = False

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._tokens = count_tokens_on_message([self.content])

    def consume_stream(self, stream):
        """Append streamed delta text to ``content`` until the stream stops.

        Stops at a chunk with ``finish_reason == "stop"`` or at the first chunk
        without delta text. ``_running`` is reset on every exit path (the old
        early ``return`` left it stuck at True).
        """
        self._running = True
        try:
            for chunk in stream:
                self._tokens += 1
                self._tokensOutput += 1
                choice = chunk.choices[0]
                if choice.finish_reason == "stop" or not choice.delta.content:
                    break
                self.content += choice.delta.content
        finally:
            self._running = False

    def get_tokens(self):
        """Return ``(_tokens, _tokensOutput)``."""
        return self._tokens, self._tokensOutput
class ToolCallsInputFunction(BaseModel):
    """Name and JSON-encoded argument string of one requested tool call."""

    name: str = Field(default="")
    arguments: str = Field(default="")
class ToolCallsInput(BaseModel):
    """One tool call requested by the assistant (id plus function payload)."""

    id: str
    function: ToolCallsInputFunction
    type: str = Field(default="function")
class ToolCallsOutput(BaseModel):
    """Result message of an executed tool call, with its token count."""

    role: str
    tool_call_id: str
    name: str
    content: str
    # Token count of ``content``, computed once in ``__init__``.
    _tokens: int = 0

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        content_tokens = count_tokens_on_message([self.content])
        self._tokens = content_tokens
class MessageTool(BaseModel):
    """Assistant message carrying tool calls, assembled from a streamed response.

    Construct with ``stream=`` (the chunk iterator) and ``chunk=`` (the chunk
    already read from it). The constructor keeps reading chunks
    (``chunk.choices[0].delta``) until one has a ``finish_reason`` or the
    stream is exhausted.
    """

    role: str
    content: str = ""
    tool_calls: list[ToolCallsInput]
    # Incremented by one per chunk consumed in ``__init__``.
    _tokens: int = 0
    # Tool results appended by ``exec``; a fresh list per instance.
    _outputs: list[ToolCallsOutput] = PrivateAttr(default_factory=list)

    def __init__(self, **kwargs):
        stream = kwargs.pop("stream")
        chunk = kwargs.pop("chunk")
        kwargs["tool_calls"] = []
        super().__init__(**kwargs)
        while chunk is not None:
            self._tokens += 1
            choice = chunk.choices[0]
            if choice.finish_reason:
                break
            deltas = choice.delta.tool_calls
            if deltas is not None:
                self._absorb_delta(deltas[0])
            # A default instead of StopIteration: a stream that ends without a
            # finish_reason chunk simply ends the loop. (The old fallback,
            # ``self._tokens += sum``, added the builtin ``sum`` and would crash.)
            chunk = next(stream, None)

    def _absorb_delta(self, tool_call) -> None:
        """Start a new tool call (delta with an id) or extend the last one's arguments."""
        function = tool_call.function
        if tool_call.id:
            self.tool_calls.append(
                ToolCallsInput(
                    id=tool_call.id,
                    # Delta fields may be None; the model fields are plain str.
                    function=ToolCallsInputFunction(
                        name=(function.name if function else None) or "",
                        arguments=(function.arguments if function else None) or "",
                    ),
                )
            )
        elif function is not None and function.arguments:
            self.tool_calls[-1].function.arguments += function.arguments

    def exec(self, gid):
        """Run each requested tool and collect its result in ``_outputs``.

        Each tool's callback is looked up by name in
        ``chat_functions.function_callbacks`` and called with the decoded JSON
        arguments and *gid*. Its return value becomes the content of a
        ``role="tool"`` ToolCallsOutput.
        """
        for call in self.tool_calls:
            name = call.function.name
            # An empty argument string is treated as "{}" instead of failing json.loads.
            arguments = json.loads(call.function.arguments or "{}")
            self._outputs.append(
                ToolCallsOutput(
                    role="tool",
                    tool_call_id=call.id,
                    name=name,
                    content=chat_functions.function_callbacks[name](arguments, gid),
                )
            )
class Chat(BaseModel):
    """Conversation history with a running estimate of its prompt token count.

    ``tokens`` starts at 3. Every stored message adds its own ``_tokens``
    plus 3 tokens of per-message overhead, both in ``__init__`` and in
    ``append``.
    """

    messages: List[Message | MessageTool | ToolCallsOutput]
    tokens: int = 3

    def __init__(self: Self, *args: list, **kwargs: dict):
        # Plain dicts (e.g. stored history) become Message models; models
        # that are already built are kept as they are.
        kwargs["messages"] = [
            m if isinstance(m, BaseModel) else Message(**m)
            for m in kwargs.pop("messages")
        ]
        super().__init__(*args, **kwargs)
        self.tokens += sum(m._tokens + 3 for m in self.messages)

    def append(self: Self, message: Message | MessageTool):
        """Store *message* and count its tokens.

        For a MessageTool, the outputs collected by its ``exec`` are stored
        (and counted) immediately after it.
        """
        self._store(message)
        if not isinstance(message, Message):
            for out in message._outputs:
                self._store(out)

    def _store(self, message) -> None:
        """Append one message and add its tokens plus the 3-token overhead."""
        self.messages.append(message)
        # Same per-message overhead as __init__; previously omitted here, so
        # appended messages were under-counted.
        self.tokens += message._tokens + 3

    @classmethod
    def new_msg(cls, role: str, stream):
        """Return a Message for *role* whose content is read from *stream*."""
        message = Message(role=role)
        message.consume_stream(stream)
        return message

    @classmethod
    def new_func(cls, role: str, stream, chunk):
        """Return a MessageTool for *role* built from *stream*, starting at *chunk*."""
        return MessageTool(role=role, stream=stream, chunk=chunk)
class Session(BaseModel):
    """Per-browser session of user ``gid``, serialised into a JWT cookie.

    On construction ``guid`` gets a fresh UUID4 when none is given, and
    ``hashed`` is always recomputed as ``security.sha256(guid + fprint)``
    (any value passed in is overwritten).
    """

    gid: str
    fprint: str
    hashed: str
    guid: str
    public_key: str = ""
    # NOTE(review): this default is evaluated once, at class definition, so
    # every Session in the process shares the same challenge value.
    # ``find_from_data`` does not copy the challenge from the cookie, and
    # ``add_challenge`` is disabled, so ``validate_signature`` currently depends
    # on this shared value. Making it per-instance requires reworking that flow.
    challenge: str = str(uuid.uuid4())
    data: dict = Field(default_factory=dict)
    configs: Configs | None = None

    def __init__(self, **kwargs):
        kwargs.setdefault("guid", str(uuid.uuid4()))
        kwargs["hashed"] = security.sha256(kwargs["guid"] + kwargs["fprint"])
        super().__init__(**kwargs)

    def validate_signature(self: Self, signature: str):
        """Verify *signature* over ``challenge`` using ``public_key``.

        Calls ``security.raise_401`` when invalid; returns True otherwise.
        """
        valid = security.validate_signature(self.public_key, signature, self.challenge)
        if not valid:
            security.raise_401("Cannot validate Session signature")
        return True

    @classmethod
    def find_from_data(cls, request: Request, data: dict) -> Self:
        """Rebuild and check a Session from the request cookie and body *data*.

        Checks performed, each answered with ``security.raise_401`` on failure:
        - the cookie carries ``gid`` and ``guid``;
        - a public key is known. It may come from *data* only on ``/getToken``;
          otherwise the challenge check runs;
        - ``sha256(guid + data["fingerprint"])`` matches the cookie's ``fprint``.
        """
        cookie_data: dict = security.token_from_cookie(request)
        if "gid" not in cookie_data or "guid" not in cookie_data:
            log_module.logger().error("Cookie without session needed data")
            security.raise_401("gid or guid not in cookie")
        if not (public_key := cookie_data.get("public_key", None)):  # FIX Vuln Code
            if request.scope["path"] != "/getToken":
                log_module.logger(cookie_data["gid"]).error(f"User without public key saved | {json.dumps(cookie_data)}")
                security.raise_401("the user must have a public key saved in token")
            else:
                # Only /getToken may register the client's public key.
                log_module.logger(cookie_data['gid']).info("API public key set for user")
                public_key = data["public_key"]
        else:
            cls.check_challenge(data["fingerprint"], cookie_data["challenge"])
        session: Self = cls(
            gid=cookie_data["gid"],
            fprint=data["fingerprint"],
            guid=cookie_data["guid"],
            public_key=public_key,
            data=data,
            configs=Configs(**cookie_data["configs"]),
        )
        if session.hashed != cookie_data["fprint"]:
            log_module.logger(session.gid).error(f"Fingerprint didnt match | {json.dumps(cookie_data)}")
            security.raise_401("Fingerprint didnt match")
        session.add_challenge()
        return session

    def create_cookie_token(self: Self):
        """Return a JWT carrying the session identity, key, challenge and configs."""
        return security.create_jwt_token({
            "gid": self.gid,
            "guid": self.guid,
            "fprint": self.hashed,
            "public_key": self.public_key,
            "challenge": self.challenge,
            "configs": self.configs.model_dump(),
        })

    def create_cookie(self: Self, response: Response):
        """Set the session JWT as the ``token`` cookie on *response* (24 hours)."""
        jwt = self.create_cookie_token()
        security.set_cookie(response, "token", jwt, {"hours": 24})

    def update_usage(self: Self, message: Message):
        """Add *message*'s token count to this user's daily usage.

        ``User.update_usage`` ``$inc``s by an int. A Message (or other object
        with ``_tokens``) therefore contributes its ``_tokens``; a plain int is
        forwarded unchanged.
        """
        User.update_usage(self.gid, getattr(message, "_tokens", message))

    def add_challenge(self: Self):
        """Rotate and persist the challenge nonce. Currently disabled.

        NOTE(review): the early ``return True`` short-circuits the rotation and
        DB insert below; they are kept for when the challenge flow is re-enabled.
        """
        return True
        self.challenge = str(uuid.uuid4())
        DB.sess.insert_one(self.model_dump(include={"fprint", "challenge"}))

    @staticmethod
    def check_challenge(fprint: str, challenge: str):
        """Verify and consume the stored challenge for *fprint*. Currently disabled.

        NOTE(review): the early ``return True`` skips the lookup below, so any
        challenge is accepted.
        """
        return True
        found = DB.sess.find_one_and_delete({"fprint": fprint})
        if not found or found["challenge"] != challenge:
            security.raise_401("Check challenge failed")
class User(BaseModel):
    """A chat user as stored in ``ChatDB.users``, plus request-scoped state.

    ``_session`` and ``_data`` are private attributes filled in by the
    ``find_*`` constructors and are not included in ``model_dump()``.
    """

    name: str
    tokens: dict = Field(default_factory=dict)
    # Evaluated per instance. A plain ``datetime.now(tz)`` default is computed
    # once at import and would stamp every new user with the same time.
    created: datetime = Field(default_factory=lambda: datetime.now(tz))
    approved: datetime | None = None
    description: str = ""
    email: str
    gid: str
    role: str = "on hold"
    # A fresh Configs per user, rather than a copy of one built at import time.
    configs: Configs = Field(default_factory=Configs)
    _session: Session | None = None
    _data: dict | None = None

    @classmethod
    def find_or_create(cls, data: dict, loginData: dict) -> Self:
        """Load the user with ``data["gid"]``, or insert a new one built from *data*.

        A new Session is attached using ``loginData["fp"]`` as fingerprint and
        ``loginData["pk"]`` as public key.
        """
        found = DB.user.find_one({"gid": data["gid"]})
        user: Self = cls(**found) if found else cls(**data)
        if not found:
            DB.user.insert_one(user.model_dump())
        user._session = Session(gid=user.gid, fprint=loginData["fp"], public_key=loginData["pk"])
        log_module.logger(user.gid).info(f"User {'logged' if found else'created'} | fp: {user._session.fprint}")
        return user

    @classmethod
    def find_from_cookie(cls, request: Request) -> Self:
        """Load the user named by the request cookie and attach a Session.

        Calls ``security.raise_307`` when the cookie lacks ``gid``/``guid`` or
        the user is not in the DB.
        """
        cookie_data: dict = security.token_from_cookie(request)
        if "gid" not in cookie_data or "guid" not in cookie_data:
            log_module.logger().error("Cookie without needed data")
            security.raise_307("gid or guid not in cookie")
        found: dict = DB.user.find_one({"gid": cookie_data["gid"]})
        if not found:
            log_module.logger(cookie_data["gid"]).error("User not found on DB")
            security.raise_307("User not found on DB")
        user: Self = cls(**found)
        # NOTE(review): the cookie's "fprint" already holds the hashed value,
        # and Session hashes it again here -- confirm this is intended.
        user._session = Session(
            gid=cookie_data["gid"],
            guid=cookie_data["guid"],
            fprint=cookie_data["fprint"],
            configs=user.configs,
        )
        return user

    @classmethod
    def find_from_data(cls, request: Request, data: dict) -> Self:
        """Load the user for a verified ``Session.find_from_data`` session.

        Calls ``security.raise_307`` when the user is not in the DB. The
        request body *data* is kept on ``_data``.
        """
        session: Session = Session.find_from_data(request, data)
        found: dict = DB.user.find_one({"gid": session.gid})
        if not found:
            log_module.logger(session.gid).error("User not found on DB")
            security.raise_307("User not found on DB")
        user: Self = cls(**found)
        user._session = session
        user._data = data
        return user

    def update_description(self: Self, message: str) -> None:
        """Persist *message* as the user's description and update this instance."""
        log_module.logger(self.gid).info("Description Updated")
        DB.user.update_one(
            {"gid": self.gid},
            {"$set": {"description": message}},
        )
        self.description = message

    def can_use(self: Self, activity: str):
        """Return ``security.can_use`` for this user's role and *activity*."""
        return security.can_use(self.role, activity)

    def update_user(self: Self) -> str:
        """Write all public fields back to the DB and return the assistant prompt."""
        log_module.logger(self.gid).info("User Updated")
        DB.user.update_one({"gid": self.gid}, {"$set": self.model_dump()})
        return self.configs.assistantPrompt

    @staticmethod
    def update_usage(gid: str, tokens: int):
        """Add *tokens* to the ``tokens.<yy>.<mm>.<dd>`` counter of user *gid*.

        NOTE(review): the day key uses naive server-local time, not the
        module's ``tz`` -- confirm which day boundary is intended.
        """
        inc_field = datetime.now().strftime("tokens.%y.%m.%d")
        DB.user.update_one({"gid": gid}, {"$inc": {inc_field: tokens}})

    def create_cookie(self: Self):
        """Return a session JWT that embeds the user's own (DB) configs."""
        return security.create_jwt_token({
            "gid": self._session.gid,
            "guid": self._session.guid,
            "fprint": self._session.hashed,
            "public_key": self._session.public_key,
            "challenge": self._session.challenge,
            "configs": self.configs.model_dump(),
        })