# Hugging Face Spaces status banner captured with this file: Running
import io
import json
import logging
import os
import time
from functools import wraps
from threading import Thread, Lock

import streamlit as st
from huggingface_hub import HfApi

from src.utils import get_current_strftime
# Module-wide mutex: every read or write of a Logger's counters and record
# buffers is expected to happen while holding this lock.
logger_lock = Lock()
def threaded(fn):
    """Decorator that runs ``fn`` on a new thread each time it is called.

    The decorated callable returns immediately with the started
    ``threading.Thread``; the return value of ``fn`` itself is discarded.
    Callers that need to wait for completion can ``join()`` the thread.
    """
    @wraps(fn)  # keep fn's __name__/__doc__ visible on the wrapper
    def wrapper(*args, **kwargs):
        thread = Thread(target=fn, args=args, kwargs=kwargs)
        thread.start()
        return thread

    return wrapper
class Logger:
    """Buffer session, query and audio records and upload them periodically.

    Records are appended to three in-memory lists (``session_data``,
    ``query_data``, ``audio_data``). A background thread started by the
    constructor wakes every ``sync_interval`` seconds, drains each list and
    uploads its contents as newline-delimited JSON to the dataset repo named
    by the ``LOGGING_REPO_NAME`` environment variable, authenticating with
    ``HF_TOKEN``.
    """

    def __init__(self):
        # Used as the prefix of every session ID issued by this instance.
        # NOTE(review): get_current_strftime presumably returns a formatted
        # timestamp string - confirm in src.utils.
        self.app_id = get_current_strftime()
        self.session_increment = 0
        self.query_increment = 0
        self.sync_interval = 180  # seconds slept between upload passes

        self.session_data = []
        self.query_data = []
        self.audio_data = []

        # sync_data() loops forever, so calling it inline would keep the
        # constructor from ever returning. Run it on a daemon thread so it
        # also does not keep the interpreter alive at shutdown.
        self._sync_thread = Thread(
            target=self.sync_data,
            name="logger-sync",
            daemon=True,
        )
        self._sync_thread.start()

    def register_session(self) -> str:
        """Record a new session and return its ID (``"<app_id>+<counter>"``)."""
        creation_time = get_current_strftime()
        with logger_lock:
            # Read and bump the counter under the lock so that concurrent
            # callers can never be handed the same ID.
            new_session_id = f"{self.app_id}+{self.session_increment}"
            self.session_increment += 1
            self.session_data.append({
                "session_id": new_session_id,
                "creation_time": creation_time,
            })
        return new_session_id

    def register_query(self,
                       session_id,
                       base64_audio,
                       text_input,
                       response,
                       **kwargs
                       ):
        """Record one query: text fields in ``query_data``, audio in ``audio_data``.

        Args:
            session_id: ID of the session the query belongs to.
            base64_audio: Audio payload stored under the ``"audio"`` key.
            text_input: Stored under the ``"text"`` key.
            response: Stored under the ``"response"`` key.
            **kwargs: Extra fields merged into the query record; a key that
                matches a built-in field overrides it.
        """
        current_time = get_current_strftime()
        with logger_lock:
            # Allocate the ID under the lock so both records of this query
            # share a value no other query can receive.
            new_query_id = self.query_increment
            self.query_increment += 1

            current_query_data = {
                "session_id": session_id,
                "query_id": new_query_id,
                "creation_time": current_time,
                "text": text_input,
                "response": response,
            }
            current_query_data.update(kwargs)
            self.query_data.append(current_query_data)

            self.audio_data.append({
                "session_id": session_id,
                "query_id": new_query_id,
                "creation_time": current_time,
                "audio": base64_audio,
            })

    def sync_data(self):
        """Upload buffered records forever, one pass every ``sync_interval`` seconds.

        This method never returns; it is meant to run on a background thread.
        Each non-empty buffer is written to
        ``<data_name>/<timestamp>.json`` in the dataset repo as one JSON
        object per line.
        """
        log = logging.getLogger(__name__)
        api = HfApi()

        while True:
            time.sleep(self.sync_interval)

            for data_name in ("session_data", "query_data", "audio_data"):
                # Swap the buffer out under the lock; serialising and
                # uploading then happen without blocking the register_* calls.
                with logger_lock:
                    pending = getattr(self, data_name, [])
                    setattr(self, data_name, [])

                if not pending:
                    continue

                try:
                    payload = "".join(
                        json.dumps(row, ensure_ascii=False) + "\n"
                        for row in pending
                    ).encode("utf-8")
                except (TypeError, ValueError):
                    # Re-queueing data that cannot be serialised would fail
                    # again on every pass, so the batch is dropped instead.
                    log.exception(
                        "Dropping %d unserialisable %s record(s)",
                        len(pending), data_name,
                    )
                    continue

                try:
                    api.upload_file(
                        path_or_fileobj=payload,
                        path_in_repo=f"{data_name}/{get_current_strftime()}.json",
                        repo_id=os.getenv("LOGGING_REPO_NAME"),
                        repo_type="dataset",
                        token=os.getenv('HF_TOKEN')
                    )
                except Exception:
                    # Top-level boundary of the worker thread: an upload
                    # failure must not kill the loop. Put the batch back at
                    # the front of the buffer so the next pass retries it.
                    log.exception(
                        "Failed to upload %d %s record(s); will retry",
                        len(pending), data_name,
                    )
                    with logger_lock:
                        getattr(self, data_name)[:0] = pending
@st.cache_resource
def load_logger():
    """Return the shared ``Logger`` instance, creating it on first use.

    Cached with ``st.cache_resource`` so Streamlit script reruns and
    separate sessions reuse one ``Logger`` instead of constructing a new one
    on every call; a fresh instance per call would restart the session
    counter and start another upload loop each time.
    """
    return Logger()