import asyncio
import json
import logging
import os
import re
from contextlib import asynccontextmanager
from datetime import datetime
from pathlib import Path
from typing import Annotated, List

from cashews import NOT_NONE, cache
from dotenv import load_dotenv
from fastapi import BackgroundTasks, FastAPI, Header, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from httpx import AsyncClient
from huggingface_hub import CommitScheduler, DatasetCard, HfApi, hf_hub_download, whoami
from huggingface_hub.utils import disable_progress_bars, logging
from huggingface_hub.utils._errors import HTTPError
from langfuse.openai import AsyncOpenAI  # OpenAI integration
from pydantic import BaseModel, Field
from starlette.responses import RedirectResponse

from card_processing import parse_markdown, try_load_text, is_empty_template

disable_progress_bars()
load_dotenv()

logger = logging.get_logger(__name__)

Gb = 1073741824
cache.setup("disk://", size_limit=16 * Gb)  # configure a disk-backed cache with a 16 GB size limit
VOTES_FILE = "data/votes.jsonl"
HF_TOKEN = os.getenv("HF_TOKEN")

hf_api = HfApi(token=HF_TOKEN)
async_httpx_client = AsyncClient()
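
# The CommitScheduler pushes the local `data` folder to the davanstrien/summary-ratings
# dataset repo every 5 minutes, so votes written to disk are persisted to the Hub.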
scheduler = CommitScheduler(
    repo_id="davanstrien/summary-ratings",
    repo_type="dataset",
    folder_path="data",
    path_in_repo="data",
    every=5,
    token=HF_TOKEN,
    hf_api=hf_api,
)

@asynccontextmanager
async def lifespan(app: FastAPI):
    logger.info("Running startup event")
    if not Path(VOTES_FILE).exists():
        path = hf_hub_download(
            repo_id="davanstrien/summary-ratings",
            filename="data/votes.jsonl",
            repo_type="dataset",
            token=HF_TOKEN,
            local_dir=".",
            local_dir_use_symlinks=False,
        )
        logger.info(f"Downloaded votes.jsonl to {path}")
    yield

app = FastAPI()  # lifespan=lifespan
# Configure CORS
# origins = [
#     "https://huggingface.co",
#     "chrome-extension://deckahggoiaphiebdipfbiinmaihfpbk",  # Replace with your Chrome plugin ID
# ]
# # Configure CORS settings
# app.add_middleware(
#     CORSMiddleware,
#     allow_origins=[
#         "https://huggingface.co/datasets/*"
#     ],  # Update with your frontend URL
#     allow_credentials=True,
#     allow_methods=["*"],
#     allow_headers=["*"],
# )

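# Votes are appended to the local JSONL file under the scheduler's lock so the
# CommitScheduler never uploads a partially written file.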
def save_vote(vote_entry):
    with scheduler.lock:
        with open(VOTES_FILE, "a") as file:
            date_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            vote_entry["timestamp"] = date_time
            file.write(json.dumps(vote_entry) + "\n")
            logger.info(f"Vote saved: {vote_entry}")

@app.get("/")
def root():
    return RedirectResponse(url="/docs")

class Vote(BaseModel):
    dataset: str
    description: str
    vote: int = Field(..., ge=-1, le=1)
    userID: str

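# A token is considered valid if `whoami` succeeds; an invalid or expired token
# makes the Hub return an error, which surfaces here as an HTTPError.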
def validate_token(token: str = Header(None)) -> bool:
    try:
        whoami(token)
        return True
    except HTTPError:
        return False

@app.post("/vote")  # route path assumed
async def receive_vote(
    vote: Vote,
    Authorization: Annotated[str, Header()],
    background_tasks: BackgroundTasks,
):
    if not validate_token(Authorization):
        logger.error("Invalid token")
        raise HTTPException(status_code=401, detail="Invalid token")
    vote_entry = {
        "dataset": vote.dataset,
        "vote": vote.vote,
        "description": vote.description,
        "userID": vote.userID,
    }
    # Append the vote entry to the JSONL file in the background
    background_tasks.add_task(save_vote, vote_entry)
    return JSONResponse(content={"message": "Vote submitted successfully"})

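# Prompt template for the tl;dr summary. The dataset card text is truncated to the
# first 6,000 characters to keep the request within the model's context window.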
def format_prompt(card: str) -> str:
    return f"""
Write a tl;dr summary of a dataset based on the dataset card. Focus on the most critical aspects of the dataset.
The summary should aim to concisely describe the dataset.
CARD: \n\n{card[:6000]}
---
\n\nInstructions:
If the card provides the necessary information, say what the dataset can be used for.
You do not need to mention that the dataset is hosted or available on the Hugging Face Hub.
Do not mention the license of the dataset.
Do not mention the number of examples in the training or test split.
Only mention size if there is extensive discussion of the scale of the dataset in the dataset card.
Do not speculate on anything not explicitly mentioned in the dataset card.
In general avoid references to the quality of the dataset i.e. don't use phrases like 'a high-quality dataset' in the summary.
\n\nOne sentence summary:"""

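# The Hub API reports `lastModified` as an ISO 8601 timestamp; it is used below to
# decide whether a cached summary is still fresh.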
async def check_when_dataset_last_modified(dataset_id: str) -> datetime | None:
    try:
        response = await async_httpx_client.get(
            f"https://huggingface.co/api/datasets/{dataset_id}"
        )
        if last_modified := response.json().get("lastModified"):
            return datetime.fromisoformat(last_modified)
        return None
    except Exception as e:
        logger.error(e)
        return None

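# Summaries are generated by Mixtral-8x7B-Instruct served behind the Hugging Face
# Inference API; the langfuse AsyncOpenAI wrapper traces each request.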
async def predict(card: str, dataset_id: str) -> str | None:
    try:
        prompt = format_prompt(card)
        client = AsyncOpenAI(
            base_url="https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1/v1",
            api_key=HF_TOKEN,
        )
        chat_completion = await client.chat.completions.create(
            model="tgi",
            messages=[
                {"role": "user", "content": prompt},
            ],
            stream=False,
            tags=["tldr-summaries"],
        )
        return chat_completion.choices[0].message.content.strip()
    except Exception as e:
        logger.error(e)
        return None

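# Summaries are cached per dataset together with the dataset's last-modified time,
# so a summary is only regenerated after the dataset card has been updated.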
async def get_summary(dataset_id: str) -> str | None:
    """
    Get a summary for a dataset based on the provided dataset ID.

    Args:
        dataset_id (str): The ID of the dataset to retrieve the summary for.

    Returns:
        str | None: The generated summary for the dataset, or None if no summary
        is available or an error occurs.
    """
    try:
        card_text = await async_httpx_client.get(
            f"https://huggingface.co/datasets/{dataset_id}/raw/main/README.md"
        )
        card_text = card_text.text
        card = DatasetCard(card_text)
        text = card.text
        parsed_text = parse_markdown(text)
        if is_empty_template(parsed_text):
            return None
        cache_key = f"predict:{dataset_id}"
        cached_data = await cache.get(cache_key)
        if cached_data is not None:
            cached_summary, cached_last_modified_time = cached_data
            # Get the current last modified time of the dataset
            current_last_modified_time = await check_when_dataset_last_modified(
                dataset_id
            )
            if current_last_modified_time is None or (
                cached_last_modified_time is not None
                and cached_last_modified_time >= current_last_modified_time
            ):
                # Use the cached summary if the dataset has not been modified since it was cached
                logger.info("Using cached summary")
                return cached_summary
        summary = await predict(parsed_text, dataset_id)
        current_last_modified_time = await check_when_dataset_last_modified(dataset_id)
        await cache.set(cache_key, (summary, current_last_modified_time))
        return summary
    except Exception as e:
        logger.error(e)
        return None

class SummariesRequest(BaseModel):
    dataset_ids: List[str]

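# Batch endpoint: summaries for all requested datasets are fetched concurrently with
# asyncio.gather, and datasets without a usable summary fall back to "No summary available".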
@app.post("/summaries")  # route path assumed
async def get_summaries(request: SummariesRequest) -> dict:
    """
    Get summaries for a list of datasets based on the provided dataset IDs.

    Args:
        request (SummariesRequest): Request body containing the dataset IDs to summarize.

    Returns:
        dict: A dictionary mapping dataset IDs to their corresponding summaries.
    """
    dataset_ids = request.dataset_ids

    async def get_summary_wrapper(dataset_id):
        return dataset_id, await get_summary(dataset_id)

    summary_tasks = [get_summary_wrapper(dataset_id) for dataset_id in dataset_ids]
    summaries = dict(await asyncio.gather(*summary_tasks))
    for dataset_id in dataset_ids:
        if summaries[dataset_id] is None:
            summaries[dataset_id] = "No summary available"
    return summaries