from datetime import datetime, timezone
from typing import Dict, Any, List, Set, Tuple, Optional
import json
import logging
import asyncio
from pathlib import Path

import aiohttp
from huggingface_hub import HfApi
import datasets

from app.services.hf_service import HuggingFaceService
from app.config import HF_TOKEN
from app.config.hf_config import HF_ORGANIZATION
from app.core.cache import cache_config
from app.core.formatting import LogFormatter

logger = logging.getLogger(__name__)


class VoteService(HuggingFaceService):
    _instance: Optional['VoteService'] = None
    _initialized = False

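    # Classic singleton: __new__ returns the one shared instance, so every
    # caller sees the same in-memory vote state and upload queue.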
    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(VoteService, cls).__new__(cls)
        return cls._instance

    def __init__(self):
        if not hasattr(self, '_init_done'):
            super().__init__()
            self.votes_file = cache_config.votes_file
            self.votes_to_upload: List[Dict[str, Any]] = []
            self.vote_check_set: Set[Tuple[str, str, str]] = set()
            self._votes_by_model: Dict[str, List[Dict[str, Any]]] = {}
            self._votes_by_user: Dict[str, List[Dict[str, Any]]] = {}
            self._upload_lock = asyncio.Lock()
            self._last_sync = None
            self._sync_interval = 300  # 5 minutes
            self._total_votes = 0
            self._last_vote_timestamp = None
            self._max_retries = 3
            self._retry_delay = 1  # seconds
            self._upload_batch_size = 10
            self.hf_api = HfApi(token=HF_TOKEN)
            self._init_done = True

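    # initialize() settles local vs. remote state in one of three ways:
    # pull from the hub when remote has more votes, keep the local file when
    # it has more, or simply load local votes when the two already match.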
    async def initialize(self):
        """Initialize the vote service"""
        if self._initialized:
            await self._check_for_new_votes()
            return

        try:
            logger.info(LogFormatter.section("VOTE SERVICE INITIALIZATION"))

            # Ensure votes directory exists
            self.votes_file.parent.mkdir(parents=True, exist_ok=True)

            # Load existing votes if file exists
            local_vote_count = 0
            if self.votes_file.exists():
                logger.info(LogFormatter.info(f"Loading votes from {self.votes_file}"))
                local_vote_count = await self._count_local_votes()
                logger.info(LogFormatter.info(f"Found {local_vote_count:,} local votes"))

            # Check remote votes count
            remote_vote_count = await self._count_remote_votes()
            logger.info(LogFormatter.info(f"Found {remote_vote_count:,} remote votes"))

            if remote_vote_count > local_vote_count:
                logger.info(LogFormatter.info(f"Fetching {remote_vote_count - local_vote_count:,} new votes"))
                await self._sync_with_hub()
            elif remote_vote_count < local_vote_count:
                logger.warning(LogFormatter.warning(f"Local votes ({local_vote_count:,}) > Remote votes ({remote_vote_count:,})"))
                await self._load_existing_votes()
            else:
                logger.info(LogFormatter.success("Local and remote votes are in sync"))
                if local_vote_count > 0:
                    await self._load_existing_votes()
                else:
                    logger.info(LogFormatter.info("No votes found"))

            self._initialized = True
            self._last_sync = datetime.now(timezone.utc)

            # Final summary
            stats = {
                "Total_Votes": self._total_votes,
                "Last_Sync": self._last_sync.strftime("%Y-%m-%d %H:%M:%S UTC")
            }
            logger.info(LogFormatter.section("INITIALIZATION COMPLETE"))
            for line in LogFormatter.stats(stats):
                logger.info(line)

        except Exception as e:
            logger.error(LogFormatter.error("Initialization failed", e))
            raise

    async def _count_local_votes(self) -> int:
        """Count votes in local file"""
        if not self.votes_file.exists():
            return 0

        count = 0
        try:
            with open(self.votes_file, 'r') as f:
                for line in f:
                    if line.strip():  # Skip empty lines, same as the remote count
                        count += 1
            return count
        except Exception as e:
            logger.error(f"Error counting local votes: {str(e)}")
            return 0

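    # Counting remote votes streams the raw JSONL file over HTTP, which
    # avoids downloading and caching the whole dataset just to get a size.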
    async def _count_remote_votes(self) -> int:
        """Count votes in remote file"""
        url = f"https://huggingface.co/datasets/{HF_ORGANIZATION}/votes/raw/main/votes_data.jsonl"
        headers = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}

        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(url, headers=headers) as response:
                    if response.status == 200:
                        count = 0
                        async for line in response.content:
                            if line.strip():  # Skip empty lines
                                count += 1
                        return count
                    else:
                        logger.error(f"Failed to get remote votes: HTTP {response.status}")
                        return 0
        except Exception as e:
            logger.error(f"Error counting remote votes: {str(e)}")
            return 0

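    # Note: this method only pulls votes *from* the hub into the local file;
    # pushing the votes queued in self.votes_to_upload is not handled here.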
    async def _sync_with_hub(self):
        """Sync votes with HuggingFace hub using datasets"""
        try:
            logger.info(LogFormatter.section("VOTE SYNC"))
            self._log_repo_operation("sync", f"{HF_ORGANIZATION}/votes", "Syncing local votes with HF hub")
            logger.info(LogFormatter.info("Syncing with HuggingFace hub..."))

            # Load votes from HF dataset
            dataset = datasets.load_dataset(
                f"{HF_ORGANIZATION}/votes",
                split="train",
                cache_dir=cache_config.get_cache_path("datasets")
            )
            logger.info(LogFormatter.info(f"Dataset loaded with {len(dataset):,} votes"))

            # Convert to list of dictionaries
            df = dataset.to_pandas()
            if 'timestamp' in df.columns:
                df['timestamp'] = df['timestamp'].dt.strftime('%Y-%m-%dT%H:%M:%SZ')
            remote_votes = df.to_dict('records')

            # If we have more remote votes than local
            if len(remote_votes) > self._total_votes:
                new_votes = len(remote_votes) - self._total_votes
                logger.info(LogFormatter.info(f"Processing {new_votes:,} new votes..."))

                # Save votes to local file
                with open(self.votes_file, 'w') as f:
                    for vote in remote_votes:
                        f.write(json.dumps(vote) + '\n')

                # Reload votes in memory
                await self._load_existing_votes()
                logger.info(LogFormatter.success("Sync completed successfully"))
            else:
                logger.info(LogFormatter.success("Local votes are up to date"))

            self._last_sync = datetime.now(timezone.utc)

        except Exception as e:
            logger.error(LogFormatter.error("Sync failed", e))
            raise

    async def _check_for_new_votes(self):
        """Check for new votes on the hub"""
        try:
            self._log_repo_operation("check", f"{HF_ORGANIZATION}/votes", "Checking for new votes")
            # Load the dataset and compare its size with the in-memory count
            dataset = datasets.load_dataset(f"{HF_ORGANIZATION}/votes", split="train")
            remote_vote_count = len(dataset)

            if remote_vote_count > self._total_votes:
                logger.info(f"Found {remote_vote_count - self._total_votes} new votes on hub")
                await self._sync_with_hub()
            else:
                logger.info("No new votes found on hub")
        except Exception as e:
            logger.error(f"Error checking for new votes: {str(e)}")

    async def _load_existing_votes(self):
        """Load existing votes from file"""
        if not self.votes_file.exists():
            logger.warning(LogFormatter.warning("No votes file found"))
            return

        try:
            logger.info(LogFormatter.section("LOADING VOTES"))

            # Clear existing data structures
            self.vote_check_set.clear()
            self._votes_by_model.clear()
            self._votes_by_user.clear()

            vote_count = 0
            latest_timestamp = None

            with open(self.votes_file, "r") as f:
                for line in f:
                    try:
                        vote = json.loads(line.strip())
                        vote_count += 1

                        # Track latest timestamp
                        try:
                            vote_timestamp = datetime.fromisoformat(vote["timestamp"].replace("Z", "+00:00"))
                            if not latest_timestamp or vote_timestamp > latest_timestamp:
                                latest_timestamp = vote_timestamp
                            vote["timestamp"] = vote_timestamp.strftime("%Y-%m-%dT%H:%M:%SZ")
                        except (KeyError, ValueError) as e:
                            logger.warning(LogFormatter.warning(f"Invalid timestamp in vote: {str(e)}"))
                            continue

                        if vote_count % 1000 == 0:
                            logger.info(LogFormatter.info(f"Processed {vote_count:,} votes..."))

                        self._add_vote_to_memory(vote)

                    except json.JSONDecodeError as e:
                        logger.error(LogFormatter.error("Vote parsing failed", e))
                        continue
                    except Exception as e:
                        logger.error(LogFormatter.error("Vote processing failed", e))
                        continue

            self._total_votes = vote_count
            self._last_vote_timestamp = latest_timestamp

            # Final summary
            stats = {
                "Total_Votes": vote_count,
                "Latest_Vote": latest_timestamp.strftime("%Y-%m-%d %H:%M:%S UTC") if latest_timestamp else "None",
                "Unique_Models": len(self._votes_by_model),
                "Unique_Users": len(self._votes_by_user)
            }
            logger.info(LogFormatter.section("VOTE SUMMARY"))
            for line in LogFormatter.stats(stats):
                logger.info(line)

        except Exception as e:
            logger.error(LogFormatter.error("Failed to load votes", e))
            raise

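    # Dedup key: one vote per (model, revision, username) tuple, so a user
    # can vote again only once the model publishes a new revision.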
    def _add_vote_to_memory(self, vote: Dict[str, Any]):
        """Add vote to memory structures"""
        try:
            check_tuple = (vote["model"], vote["revision"], vote["username"])

            # Skip if we already have this vote
            if check_tuple in self.vote_check_set:
                return

            self.vote_check_set.add(check_tuple)

            # Update model votes
            if vote["model"] not in self._votes_by_model:
                self._votes_by_model[vote["model"]] = []
            self._votes_by_model[vote["model"]].append(vote)

            # Update user votes
            if vote["username"] not in self._votes_by_user:
                self._votes_by_user[vote["username"]] = []
            self._votes_by_user[vote["username"]].append(vote)
        except KeyError as e:
            logger.error(f"Malformed vote data, missing key: {str(e)}")
        except Exception as e:
            logger.error(f"Error adding vote to memory: {str(e)}")

    async def get_user_votes(self, user_id: str) -> List[Dict[str, Any]]:
        """Get all votes from a specific user"""
        logger.info(LogFormatter.info(f"Fetching votes for user: {user_id}"))
        votes = self._votes_by_user.get(user_id, [])
        logger.info(LogFormatter.success(f"Found {len(votes):,} votes"))
        return votes

    async def get_model_votes(self, model_id: str) -> Dict[str, Any]:
        """Get all votes for a specific model"""
        logger.info(LogFormatter.info(f"Fetching votes for model: {model_id}"))
        votes = self._votes_by_model.get(model_id, [])

        # Group votes by revision
        votes_by_revision = {}
        for vote in votes:
            revision = vote["revision"]
            votes_by_revision[revision] = votes_by_revision.get(revision, 0) + 1

        stats = {
            "Total_Votes": len(votes),
            **{f"Revision_{k}": v for k, v in votes_by_revision.items()}
        }
        logger.info(LogFormatter.section("VOTE STATISTICS"))
        for line in LogFormatter.stats(stats):
            logger.info(line)

        return {
            "total_votes": len(votes),
            "votes_by_revision": votes_by_revision,
            "votes": votes
        }

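    # Retries use a linear backoff (_retry_delay * attempt number); after the
    # final failure we fall back to the 'main' branch instead of failing the
    # vote outright.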
    async def _get_model_revision(self, model_id: str) -> str:
        """Get current revision of a model with retries"""
        logger.info(f"Getting revision for model: {model_id}")

        for attempt in range(self._max_retries):
            try:
                model_info = await asyncio.to_thread(self.hf_api.model_info, model_id)
                logger.info(f"Successfully got revision {model_info.sha} for model {model_id}")
                return model_info.sha
            except Exception as e:
                logger.error(f"Error getting model revision for {model_id} (attempt {attempt + 1}): {str(e)}")
                if attempt < self._max_retries - 1:
                    retry_delay = self._retry_delay * (attempt + 1)
                    logger.info(f"Retrying in {retry_delay} seconds...")
                    await asyncio.sleep(retry_delay)
                else:
                    logger.warning(f"Using 'main' as fallback revision for {model_id} after {self._max_retries} failed attempts")
                    return "main"

    async def add_vote(self, model_id: str, user_id: str, vote_type: str) -> Dict[str, Any]:
        """Add a vote for a model"""
        try:
            self._log_repo_operation("add", f"{HF_ORGANIZATION}/votes", f"Adding {vote_type} vote for {model_id} by {user_id}")
            logger.info(LogFormatter.section("NEW VOTE"))
            stats = {
                "Model": model_id,
                "User": user_id,
                "Type": vote_type
            }
            for line in LogFormatter.tree(stats, "Vote Details"):
                logger.info(line)

            revision = await self._get_model_revision(model_id)
            check_tuple = (model_id, revision, user_id)

            if check_tuple in self.vote_check_set:
                raise ValueError("Vote already recorded for this model")

            vote = {
                "model": model_id,
                "revision": revision,
                "username": user_id,
                "timestamp": datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ"),
                "vote_type": vote_type
            }

            # Update local storage
            with open(self.votes_file, "a") as f:
                f.write(json.dumps(vote) + "\n")

            self._add_vote_to_memory(vote)
            self.votes_to_upload.append(vote)

            stats = {
                "Status": "Success",
                "Queue_Size": len(self.votes_to_upload)
            }
            for line in LogFormatter.stats(stats):
                logger.info(line)

            # Try to upload if batch size reached
            if len(self.votes_to_upload) >= self._upload_batch_size:
                logger.info(LogFormatter.info(f"Upload batch size reached ({self._upload_batch_size}), triggering sync"))
                await self._sync_with_hub()

            return {"status": "success", "message": "Vote added successfully"}

        except Exception as e:
            logger.error(LogFormatter.error("Failed to add vote", e))
            raise

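
# Minimal usage sketch (an illustration, not part of the service): how the
# singleton is expected to be driven. It assumes a valid HF_TOKEN, network
# access to the hub, and real ids; "org/model" and "demo-user" are
# placeholders. In the actual app this wiring lives in the API layer.
if __name__ == "__main__":
    async def _demo():
        service = VoteService()
        await service.initialize()
        result = await service.add_vote("org/model", "demo-user", "up")
        print(result)
        print(await service.get_model_votes("org/model"))

    asyncio.run(_demo())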