|
|
import csv |
|
|
import os |
|
|
import time |
|
|
from datetime import datetime |
|
|
from queue import Queue |
|
|
import threading |
|
|
|
|
|
import pandas as pd |
|
|
from gradio import ChatMessage |
|
|
from huggingface_hub import HfApi, hf_hub_download |
|
|
|
|
|
from timer import Timer |
|
|
from utils import log_warning, log_info, log_debug, log_error |
|
|
|
|
|
# Credentials and target location for the chat log, all read from the environment.
# Either of the first two may be None when the corresponding variable is unset.
HF_TOKEN = os.getenv("HF_TOKEN")
DATASET_REPO_ID = os.getenv("APRIEL_PROMPT_DATASET")

# Name of the CSV file, used both for the local copy and for the path inside the dataset repo.
CSV_FILENAME = "train.csv"
|
|
|
|
|
|
|
|
def log_chat(chat_id: str, session_id: str, model_name: str, prompt: str, history: list[str], info: dict) -> None:
    """Queue one chat record for background logging.

    The arguments are bundled into a tuple and put on ``log_chat_queue``;
    the record is stored later by ``_log_chat_worker``.

    Args:
        chat_id: Identifier of the chat being logged.
        session_id: Identifier of the session the chat belongs to.
        model_name: Name of the model that produced the responses.
        prompt: The prompt text to record.
        history: The conversation messages to record.
        info: Extra data stored alongside the chat.
    """
    # Sample the queue size before enqueueing so the log shows the backlog ahead of this record.
    pending = log_chat_queue.qsize()
    log_info(f"log_chat() called for chat: {chat_id}, queue size: {pending}, model: {model_name}")

    record = (chat_id, session_id, model_name, prompt, history, info)
    log_chat_queue.put(record)
|
|
|
|
|
|
|
|
def _log_chat_worker():
    """Consume ``log_chat_queue`` forever, storing each queued chat record.

    Each item is handed to ``_log_chat``. Any exception raised while storing a
    record is logged and swallowed so the loop keeps running, and the item is
    always marked as done on the queue.
    """
    while True:
        # Blocks until a record is available.
        chat_id, session_id, model_name, prompt, history, info = log_chat_queue.get()
        try:
            _log_chat(
                chat_id=chat_id,
                session_id=session_id,
                model_name=model_name,
                prompt=prompt,
                history=history,
                info=info,
            )
        except Exception as e:
            log_error(f"Error logging chat: {e}")
        finally:
            # Mark the item as processed whether or not storing it succeeded.
            log_chat_queue.task_done()
|
|
|
|
|
|
|
|
def _log_chat(chat_id: str, session_id: str, model_name: str, prompt: str, history: list[str], info: dict) -> bool:
    """Append one chat record to the CSV kept in the Hugging Face dataset repo.

    The existing CSV is downloaded, copied into a local file named
    ``CSV_FILENAME`` with the new row appended, and the local file is uploaded
    back to the repo. If the repo has no CSV yet, the CSV cannot be parsed, or
    its headers differ from the expected ones, a fresh CSV containing only the
    new row is written instead.

    Args:
        chat_id: Identifier of the chat being logged.
        session_id: Identifier of the session the chat belongs to.
        model_name: Name of the model that produced the responses.
        prompt: The prompt text to record.
        history: Conversation items. Only dicts having both "role" and
            "content" keys and ``ChatMessage`` instances are kept; anything
            else is dropped with a warning.
        info: Extra data stored in the "info" column.

    Returns:
        True once the updated CSV has been uploaded. False when the dataset
        repo ID or the token is not configured, when the existing CSV could
        not be downloaded for a reason other than it being absent from the
        repo, or when writing the local CSV fails.

    Raises:
        Exceptions from ``create_repo`` and ``upload_file`` are not caught
        here and propagate to the caller.
    """
    # Imported locally so this block does not depend on changes to the file's import section.
    from huggingface_hub.utils import EntryNotFoundError, LocalEntryNotFoundError

    log_info(f"_log_chat() storing chat {chat_id}")
    if DATASET_REPO_ID is None:
        log_warning("No dataset repo ID provided. Skipping logging of prompt.")
        return False
    if HF_TOKEN is None:
        log_warning("No HF token provided. Skipping logging of prompt.")
        return False

    log_timer = Timer('log_chat')
    log_timer.start()

    api = HfApi(token=HF_TOKEN)

    # Make sure the dataset repo exists. Any failure of the lookup leads to a
    # create attempt; exist_ok=True keeps that attempt from failing when the
    # repo is actually there and only the lookup itself went wrong.
    try:
        repo_info = api.repo_info(repo_id=DATASET_REPO_ID, repo_type="dataset")
        log_debug(f"log_chat() --> Dataset repo found: {repo_info.id} private={repo_info.private}")
    except Exception:
        log_debug("log_chat() --> No dataset repo found, creating a new one...")
        api.create_repo(repo_id=DATASET_REPO_ID, repo_type="dataset", private=True, exist_ok=True)

    # Normalise the history: ChatMessage objects become plain dicts tagged as
    # "thought" (truthy metadata) or "completion"; well-formed dicts pass
    # through unchanged; everything else is filtered out.
    messages = [
        {"role": item.role, "content": item.content,
         "type": "thought" if item.metadata else "completion"}
        if isinstance(item, ChatMessage) else item
        for item in history
        if (isinstance(item, dict) and "role" in item and "content" in item) or isinstance(item, ChatMessage)
    ]
    if len(messages) != len(history):
        log_warning("log_chat() --> Some messages in history are missing 'role' or 'content' keys.")

    # A "turn" is a user message whose content is a string.
    user_messages_count = sum(
        1 for item in messages
        if isinstance(item, dict) and item.get("role") == "user" and isinstance(item.get("content"), str)
    )

    expected_headers = ["timestamp", "chat_id", "turns", "prompt", "messages", "model", "session_id", "info"]

    # csv.DictWriter stringifies non-string values with str(), so "messages"
    # and "info" end up in the file as their Python repr.
    new_row = {
        "timestamp": datetime.now().isoformat(),
        "chat_id": chat_id,
        "turns": user_messages_count,
        "prompt": prompt,
        "messages": messages,
        "model": model_name,
        "session_id": session_id,
        "info": info,
    }
    log_timer.add_step("Prepared new data row")

    # Download the current CSV. A file that is absent from the repo is a normal
    # state (first write) and is not retried. Every other failure is retried
    # with a growing delay; if it still fails we give up rather than upload a
    # one-row CSV over data we merely failed to fetch.
    max_retries = 3
    csv_path = None
    attempts = 0
    for attempts in range(1, max_retries + 1):
        try:
            csv_path = hf_hub_download(
                repo_id=DATASET_REPO_ID,
                filename=CSV_FILENAME,
                repo_type="dataset",
                token=HF_TOKEN
            )
            break
        except Exception as e:
            # LocalEntryNotFoundError subclasses EntryNotFoundError but is
            # raised when the Hub could not be reached and nothing is cached,
            # so it must not be mistaken for "the file does not exist".
            missing_on_hub = isinstance(e, EntryNotFoundError) and not isinstance(e, LocalEntryNotFoundError)
            if missing_on_hub:
                log_debug(f"log_chat() --> {CSV_FILENAME} not found in dataset repo, a new one will be created.")
                break
            if attempts == max_retries:
                log_error(f"log_chat() --> Failed to download CSV after {max_retries} attempts: {e}. "
                          f"Skipping this entry to avoid overwriting the existing file.")
                return False
            retry_delay = 2 * attempts
            log_warning(
                f"log_chat() --> Download attempt {attempts} failed: {e}. Retrying in {retry_delay} seconds...")
            time.sleep(retry_delay)

    log_timer.add_step(f"Downloaded existing CSV (attempts: {attempts})")

    # Decide whether the downloaded file can be extended or must be replaced.
    file_exists = csv_path is not None
    if file_exists:
        try:
            # Reading one row both validates that the file parses and yields the headers.
            existing_headers = pd.read_csv(csv_path, nrows=1).columns.tolist()
        except Exception as e:
            log_warning(f"log_chat() --> CSV {csv_path} could not be parsed: {e}. Will create a new one.")
            file_exists = False
        else:
            if set(existing_headers) != set(expected_headers):
                log_warning(f"log_chat() --> CSV {csv_path} has unexpected headers: {existing_headers}. "
                            f"\nExpected {expected_headers} "
                            f"Will create a new one.")
                dump_hub_csv()
                file_exists = False
            else:
                log_debug(f"log_chat() --> CSV {csv_path} has expected headers: {existing_headers}")

    log_debug(f"log_chat() --> Writing CSV file, file_exists={file_exists}")
    try:
        # Explicit UTF-8: message content may contain non-ASCII text and pandas
        # reads the file back as UTF-8 by default.
        with open(CSV_FILENAME, "w", newline="\n", encoding="utf-8") as f_out:
            writer = csv.DictWriter(f_out, fieldnames=expected_headers)
            writer.writeheader()

            if file_exists:
                # Copy existing rows in chunks to bound memory use. dtype=str
                # with keep_default_na=False makes the copy verbatim: no
                # numeric re-typing and no empty strings turned into "nan".
                chunk_size = 1000
                for chunk in pd.read_csv(csv_path, chunksize=chunk_size, dtype=str, keep_default_na=False):
                    writer.writerows(chunk.to_dict(orient="records"))

            writer.writerow(new_row)

        log_debug("log_chat() --> Wrote out CSV with new row")

    except Exception as e:
        log_error(f"log_chat() --> Error writing to CSV: {e}")
        return False

    # Publish the local file as the new version of the CSV in the repo.
    api.upload_file(
        path_or_fileobj=CSV_FILENAME,
        path_in_repo=CSV_FILENAME,
        repo_id=DATASET_REPO_ID,
        repo_type="dataset",
        commit_message=f"Added new chat entry at {datetime.now().isoformat()}"
    )
    log_timer.add_step("Uploaded updated CSV")
    log_timer.end()
    log_debug("log_chat() --> Finished logging chat")
    log_debug(log_timer.formatted_result())

    return True
|
|
|
|
|
|
|
|
def dump_hub_csv():
    """Download the chat-log CSV from the dataset repo and log its contents.

    The file is parsed with pandas and the resulting DataFrame is passed to
    ``log_info``. When the DataFrame has no rows, the raw file text is also
    printed to stdout. Any exception is logged via ``log_error`` and not
    re-raised.
    """
    try:
        csv_path = hf_hub_download(
            repo_id=DATASET_REPO_ID,
            filename=CSV_FILENAME,
            repo_type="dataset",
            token=HF_TOKEN
        )
        df = pd.read_csv(csv_path)
        log_info(df)
        if df.empty:
            # Nothing parsed into rows, so show the raw text as well.
            log_info("Raw file contents:")
            # Read with the same encoding pandas used by default (UTF-8)
            # instead of the platform default.
            with open(csv_path, 'r', encoding="utf-8") as f:
                print(f.read())
    except Exception as e:
        log_error(f"Error loading CSV from hub: {e}")
|
|
|
|
|
|
|
|
def dump_local_csv():
    """Load the local CSV named by ``CSV_FILENAME`` and log it.

    The parsed DataFrame is passed to ``log_info``. Any exception raised while
    reading or logging is reported via ``log_error`` and not re-raised.
    """
    try:
        local_frame = pd.read_csv(CSV_FILENAME)
        log_info(local_frame)
    except Exception as e:
        log_error(f"Error loading CSV from local file: {e}")
|
|
|
|
|
|
|
|
def test_log_chat():
    """Manual smoke test: queue four chat records and dump the hub CSV before and after.

    The same history is logged four times with slightly different prompts.
    The function then waits for the background worker to finish every queued
    record before dumping the CSV again.
    """
    chat_id = "12345"
    session_id = "67890"
    model_name = "Apriel-Model"
    prompt = "100 + 1"
    # One plain-dict user message followed by two ChatMessage objects: one with
    # metadata and one without, so both ChatMessage shapes are exercised.
    history = [
        {'role': 'user', 'content': prompt},
        ChatMessage(
            content='Okay, that\'s a simple addition problem. , answer is 2.\n', role='assistant',
            metadata={'title': '🧠 Thought'}, options=[]),
        ChatMessage(content='\nThe result of adding 1 and 1 is:\n\n**2**\n', role='assistant', metadata={},
                    options=[]),
    ]
    info = {"additional_info": "Some extra data"}

    log_debug("Starting test_log_chat()")
    dump_hub_csv()

    for call_number, suffix in enumerate(("", " + 2", " + 3", " + 4"), start=1):
        log_chat(chat_id, session_id, model_name, prompt + suffix, history, info)
        log_debug(f"log_chat {call_number} returned")

    # Wait until the worker has called task_done() for every queued record.
    # A fixed sleep could return before slow downloads/uploads had finished.
    log_debug("Waiting for the logging queue to drain.")
    log_chat_queue.join()
    log_debug("Finished waiting.")
    dump_hub_csv()
|
|
|
|
|
|
|
|
|
|
|
# Queue of pending chat records: log_chat() puts them, _log_chat_worker() takes them.
log_chat_queue = Queue()

# Start the consumer at import time. It is a daemon thread, so it does not keep
# the interpreter alive on exit.
threading.Thread(daemon=True, target=_log_chat_worker).start()

# Running this file directly executes the manual smoke test.
if __name__ == "__main__":
    test_log_chat()
|
|
|