Spaces:
Running
Running
import csv | |
import os | |
import time | |
from datetime import datetime | |
from queue import Queue | |
import threading | |
import pandas as pd | |
from gradio import ChatMessage | |
from huggingface_hub import HfApi, hf_hub_download | |
from timer import Timer | |
from utils import log_warning, log_info, log_debug, log_error | |
# Hub credentials and target dataset, both read from the environment. Either may
# be None, in which case _log_chat() skips logging with a warning.
HF_TOKEN = os.getenv("HF_TOKEN")
DATASET_REPO_ID = os.getenv("APRIEL_PROMPT_DATASET")
# Name of the log file: used for the local copy and as its path inside the dataset repo.
CSV_FILENAME = "train.csv"
def log_chat(chat_id: str, session_id: str, model_name: str, prompt: str,
             history: "list[dict | ChatMessage]", info: dict) -> None:
    """Queue one chat to be written to the Hugging Face dataset; returns immediately.

    The record is put on ``log_chat_queue`` and written later by the background
    worker thread, which logs (rather than raises) any error from writing it.

    Args:
        chat_id: Stored in the "chat_id" column.
        session_id: Stored in the "session_id" column.
        model_name: Stored in the "model" column.
        prompt: Stored in the "prompt" column.
        history: The conversation, as ``{"role": ..., "content": ...}`` dicts
            and/or gradio ``ChatMessage`` objects.
        info: Extra metadata, stored in the "info" column.
    """
    log_info(f"log_chat() called for chat: {chat_id}, queue size: {log_chat_queue.qsize()}, model: {model_name}")
    log_chat_queue.put((chat_id, session_id, model_name, prompt, history, info))
def _log_chat_worker():
    """Consume ``log_chat_queue`` forever, writing each queued chat with ``_log_chat``.

    An ``Exception`` raised while logging one chat is reported via ``log_error`` and
    the loop continues. ``task_done()`` is called for every item taken off the queue,
    so ``log_chat_queue.join()`` can be used to wait for pending chats.
    """
    while True:
        job = log_chat_queue.get()
        chat_id, session_id, model_name, prompt, history, info = job
        try:
            _log_chat(chat_id, session_id, model_name, prompt, history, info)
        except Exception as e:
            log_error(f"Error logging chat: {e}")
        finally:
            log_chat_queue.task_done()
def _log_chat(chat_id: str, session_id: str, model_name: str, prompt: str,
              history: "list[dict | ChatMessage]", info: dict) -> bool:
    """Append one chat record to ``CSV_FILENAME`` in the dataset repo ``DATASET_REPO_ID``.

    Appending in place doesn't work in the HF container, so the existing CSV is
    downloaded, rewritten locally with the new row at the end, and uploaded again
    as a new commit. Called from the background worker thread.

    Args:
        chat_id, session_id, prompt: Written to the columns of the same name.
        model_name: Written to the "model" column.
        history: The conversation, as ``{"role", "content"}`` dicts and/or gradio
            ``ChatMessage`` objects (see ``_normalize_messages``).
        info: Extra metadata, written to the "info" column (stringified by csv).

    Returns:
        True once the updated CSV has been uploaded. False when logging is skipped:
        no repo ID or token configured, the existing CSV could not be downloaded
        (it is then left untouched rather than overwritten), or the local write failed.

    Raises:
        Errors from the Hub repo-creation and upload calls propagate to the caller.
    """
    log_info(f"_log_chat() storing chat {chat_id}")
    if DATASET_REPO_ID is None:
        log_warning("No dataset repo ID provided. Skipping logging of prompt.")
        return False
    if HF_TOKEN is None:
        log_warning("No HF token provided. Skipping logging of prompt.")
        return False

    log_timer = Timer('log_chat')
    log_timer.start()

    api = HfApi(token=HF_TOKEN)
    _ensure_dataset_repo(api)

    messages = _normalize_messages(history)
    if len(messages) != len(history):
        log_warning("log_chat() --> Some messages in history are missing 'role' or 'content' keys.")

    # CSV column order; must match the keys of new_row.
    expected_headers = ["timestamp", "chat_id", "turns", "prompt", "messages", "model", "session_id", "info"]
    new_row = {
        "timestamp": datetime.now().isoformat(),
        "chat_id": chat_id,
        # Every normalized message is a dict, so .get() is safe here.
        "turns": sum(1 for message in messages if message.get("role") == "user"),
        "prompt": prompt,
        "messages": messages,
        "model": model_name,
        "session_id": session_id,
        "info": info,
    }
    log_timer.add_step("Prepared new data row")

    try:
        csv_path, attempts = _download_hub_csv()
    except Exception as e:
        # Not a "file missing" error: the CSV probably exists, so rewriting it from
        # scratch would wipe the logged history. Drop this one chat instead.
        log_error(f"log_chat() --> Failed to download CSV: {e}. "
                  f"Skipping this chat to avoid overwriting existing data.")
        return False
    log_timer.add_step(f"Downloaded existing CSV (attempts: {attempts})")

    existing_rows = [] if csv_path is None else _read_existing_rows(csv_path, expected_headers)
    file_exists = bool(existing_rows)

    # Write out the full CSV (append isn't working in HF container, so recreate each time)
    log_debug(f"log_chat() --> Writing CSV file, file_exists={file_exists}")
    try:
        # newline="" lets the csv module control line endings (as the csv docs require);
        # utf-8 because chat text is arbitrary Unicode and pd.read_csv decodes utf-8.
        with open(CSV_FILENAME, "w", newline="", encoding="utf-8") as f:
            writer = csv.DictWriter(f, fieldnames=expected_headers)
            writer.writeheader()  # Always write the header
            writer.writerows(existing_rows)
            writer.writerow(new_row)
        log_debug("log_chat() --> Wrote out CSV with new row")
    except Exception as e:
        log_error(f"log_chat() --> Error writing to CSV: {e}")
        return False

    # Upload updated CSV
    api.upload_file(
        path_or_fileobj=CSV_FILENAME,
        path_in_repo=CSV_FILENAME,
        repo_id=DATASET_REPO_ID,
        repo_type="dataset",
        commit_message=f"Added new chat entry at {datetime.now().isoformat()}"
    )
    log_timer.add_step("Uploaded updated CSV")
    log_timer.end()
    log_debug("log_chat() --> Finished logging chat")
    log_debug(log_timer.formatted_result())
    return True


def _ensure_dataset_repo(api) -> None:
    """Create ``DATASET_REPO_ID`` as a private dataset repo if looking it up fails."""
    try:
        repo_info = api.repo_info(repo_id=DATASET_REPO_ID, repo_type="dataset")
        log_debug(f"log_chat() --> Dataset repo found: {repo_info.id} private={repo_info.private}")
    except Exception:
        log_debug("log_chat() --> No dataset repo found, creating a new one...")
        # exist_ok: the lookup may have failed for a reason other than "not found"
        # (e.g. a transient error); don't fail if the repo is actually there.
        api.create_repo(repo_id=DATASET_REPO_ID, repo_type="dataset", private=True, exist_ok=True)


def _normalize_messages(history) -> list:
    """Convert history entries to plain dicts, dropping entries without role/content.

    ``ChatMessage`` objects become ``{"role", "content", "type"}`` dicts, where type is
    "thought" when the message has non-empty metadata and "completion" otherwise.
    Dicts that have both "role" and "content" keys are kept unchanged; anything
    else is dropped.
    """
    messages = []
    for item in history:
        if isinstance(item, ChatMessage):
            messages.append({"role": item.role, "content": item.content,
                             "type": "thought" if item.metadata else "completion"})
        elif isinstance(item, dict) and "role" in item and "content" in item:
            messages.append(item)
    return messages


def _is_missing_file_error(exc: Exception) -> bool:
    """Return True if ``exc`` carries an HTTP 404 response (file not in the repo yet).

    Hub HTTP errors (e.g. huggingface_hub's EntryNotFoundError) expose the HTTP
    ``response``; errors without one, such as connection failures, never count as
    "missing", so they cannot trigger an overwrite of the existing CSV.
    """
    response = getattr(exc, "response", None)
    return getattr(response, "status_code", None) == 404


def _download_hub_csv(max_retries: int = 3) -> tuple:
    """Download ``CSV_FILENAME`` from the dataset repo, retrying transient failures.

    Returns:
        ``(csv_path, attempts)``: the local path of the downloaded file, or None when
        the repo does not contain the file (HTTP 404, not retried), plus the number
        of download attempts made.

    Raises:
        The last download error when all ``max_retries`` attempts failed for any
        other reason.
    """
    for attempt in range(1, max_retries + 1):
        try:
            csv_path = hf_hub_download(
                repo_id=DATASET_REPO_ID,
                filename=CSV_FILENAME,
                repo_type="dataset",
                token=HF_TOKEN,  # Only needed if not already logged in
            )
            return csv_path, attempt
        except Exception as e:
            if _is_missing_file_error(e):
                log_debug(f"log_chat() --> {CSV_FILENAME} not found in {DATASET_REPO_ID}, will create it.")
                return None, attempt
            if attempt >= max_retries:
                log_warning(f"log_chat() --> Failed to download CSV after {max_retries} attempts: {e}")
                raise
            retry_delay = 2 * attempt  # Linear backoff: 2s, 4s, ...
            log_warning(
                f"log_chat() --> Download attempt {attempt} failed: {e}. Retrying in {retry_delay} seconds...")
            time.sleep(retry_delay)


def _read_existing_rows(csv_path: str, expected_headers: list) -> list:
    """Return the rows of the downloaded CSV as dicts, or [] if it should be replaced.

    The file is replaced (empty list returned) when it cannot be parsed, has no data
    rows, or its columns differ from ``expected_headers``.
    """
    try:
        # dtype=str + keep_default_na=False: no type inference, so empty cells stay ""
        # (instead of NaN -> "nan") and values like "007" or "1.50" keep their text
        # when the rows are written back.
        df = pd.read_csv(csv_path, dtype=str, keep_default_na=False)
    except Exception as e:  # e.g. pandas EmptyDataError for a zero-byte file
        log_warning(f"log_chat() --> CSV {csv_path} could not be parsed: {e}. Will create a new one.")
        return []
    log_debug(f"log_chat() --> Downloaded existing CSV with {len(df)} rows")
    if len(df) == 0:
        log_warning(f"log_chat() --> CSV {csv_path} exists but is empty, will create a new one.")
        dump_hub_csv()
        return []
    existing_headers = df.columns.tolist()
    if set(existing_headers) != set(expected_headers):
        log_warning(f"log_chat() --> CSV {csv_path} has unexpected headers: {existing_headers}. "
                    f"\nExpected {expected_headers} "
                    f"Will create a new one.")
        dump_hub_csv()
        return []
    log_debug(f"log_chat() --> CSV {csv_path} has expected headers: {existing_headers}")
    return df.to_dict("records")
def dump_hub_csv():
    """Debug helper: download the dataset CSV from the Hub and log it as a DataFrame.

    If pandas parses it as an empty table (e.g. a header-only file), the raw file
    text is also printed. Exceptions are caught and reported via ``log_error``.
    """
    try:
        csv_path = hf_hub_download(
            repo_id=DATASET_REPO_ID,
            filename=CSV_FILENAME,
            repo_type="dataset",
            token=HF_TOKEN,  # Only needed if not already logged in
        )
        df = pd.read_csv(csv_path)
        log_info(df)
        if df.empty:
            # Show the raw contents of the downloaded CSV file
            log_info("Raw file contents:")
            # utf-8 to match pd.read_csv's default decoding; chat text may be any Unicode.
            with open(csv_path, 'r', encoding="utf-8") as f:
                print(f.read())
    except Exception as e:
        log_error(f"Error loading CSV from hub: {e}")
def dump_local_csv():
    """Debug helper: read the local ``CSV_FILENAME`` with pandas and log it.

    Any exception (missing or unparsable file) is reported via ``log_error``.
    """
    try:
        log_info(pd.read_csv(CSV_FILENAME))
    except Exception as e:
        log_error(f"Error loading CSV from local file: {e}")
def test_log_chat():
    """Manual smoke test: queue four chats and dump the Hub CSV before and after.

    Writes to the real dataset. Without HF_TOKEN and APRIEL_PROMPT_DATASET set,
    ``_log_chat()`` skips logging and both dumps only report errors.
    """
    chat_id = "12345"
    session_id = "67890"
    model_name = "Apriel-Model"
    prompt = "100 + 1"
    # Mixed history: a plain dict plus ChatMessages; the one with metadata is logged
    # with type "thought", the other with type "completion".
    history = [
        {'role': 'user', 'content': prompt},
        ChatMessage(content='Okay, that\'s a simple addition problem. , answer is 2.\n', role='assistant',
                    metadata={'title': '🧠 Thought'}, options=[]),
        ChatMessage(content='\nThe result of adding 1 and 1 is:\n\n**2**\n', role='assistant', metadata={},
                    options=[]),
    ]
    info = {"additional_info": "Some extra data"}

    log_debug("Starting test_log_chat()")
    dump_hub_csv()
    for i, suffix in enumerate(["", " + 2", " + 3", " + 4"], start=1):
        log_chat(chat_id, session_id, model_name, prompt + suffix, history, info)
        log_debug(f"log_chat {i} returned")

    # Wait for the worker to drain the queue rather than sleeping a fixed time, which
    # either wastes time or ends before the uploads finish. Queue.join() has no
    # timeout, so run it in a helper daemon thread and stop waiting after wait_seconds.
    wait_seconds = 120
    log_debug(f"Waiting up to {wait_seconds} seconds for queued chats to be logged.")
    waiter = threading.Thread(target=log_chat_queue.join, daemon=True)
    waiter.start()
    waiter.join(timeout=wait_seconds)
    if waiter.is_alive():
        log_warning(f"Queued chats still pending after {wait_seconds} seconds.")
    else:
        log_debug("All queued chats processed.")
    dump_hub_csv()
# Unbounded FIFO of pending chat records: filled by log_chat(), drained by _log_chat_worker().
log_chat_queue = Queue(maxsize=0)
# Start the single background consumer. As a daemon thread it does not keep the
# process alive, so chats still queued when the interpreter exits are not logged.
threading.Thread(target=_log_chat_worker, args=(), daemon=True).start()

if __name__ == "__main__":
    # Running this module directly performs a live smoke test against the Hub.
    test_log_chat()