davanstrien (HF staff) committed
Commit fe9092a
Parent: ec5b5a2

Refactor code and add new functionality

Files changed (1)
  1. main.py +160 -14
main.py CHANGED
@@ -1,28 +1,40 @@
+import asyncio
 import json
 import logging
 import os
+import re
 from contextlib import asynccontextmanager
 from datetime import datetime
 from pathlib import Path
-from typing import Annotated
+from typing import Annotated, List

+from cashews import NOT_NONE, cache
 from dotenv import load_dotenv
 from fastapi import BackgroundTasks, FastAPI, Header, HTTPException
+from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse
-from huggingface_hub import CommitScheduler, HfApi, hf_hub_download, whoami
+from httpx import AsyncClient
+from huggingface_hub import CommitScheduler, DatasetCard, HfApi, hf_hub_download, whoami
+from huggingface_hub.utils import disable_progress_bars, logging
 from huggingface_hub.utils._errors import HTTPError
+from langfuse.openai import AsyncOpenAI  # OpenAI integration
 from pydantic import BaseModel, Field
 from starlette.responses import RedirectResponse
+from card_processing import parse_markdown, try_load_text, is_empty_template

+disable_progress_bars()
 load_dotenv()
-logger = logging.getLogger(__name__)
+logger = logging.get_logger(__name__)

-VOTES_FILE = "data/votes.jsonl"
+Gb = 1073741824
+cache.setup("disk://", size_limit=16 * Gb)  # disk-backed cache with a 16 GB size limit

+VOTES_FILE = "data/votes.jsonl"

 HF_TOKEN = os.getenv("HF_TOKEN")
-hf_api = HfApi(token=HF_TOKEN)

+hf_api = HfApi(token=HF_TOKEN)
+async_httpx_client = AsyncClient()
 scheduler = CommitScheduler(
     repo_id="davanstrien/summary-ratings",
     repo_type="dataset",
@@ -50,19 +62,22 @@ async def lifespan(app: FastAPI):
     yield


-app = FastAPI(lifespan=lifespan)
-# # Configure CORS
+app = FastAPI()  # lifespan=lifespan
+# Configure CORS
 # origins = [
 #     "https://huggingface.co",
-#     "chrome-extension://ogbhjlfpmjgjbjoiffagjogbhgaipopf",  # Replace with your Chrome plugin ID
+#     "chrome-extension://deckahggoiaphiebdipfbiinmaihfpbk",  # Replace with your Chrome plugin ID
 # ]


+# # Configure CORS settings
 # app.add_middleware(
 #     CORSMiddleware,
-#     allow_origins=origins,
+#     allow_origins=[
+#         "https://huggingface.co/datasets/*"
+#     ],  # Update with your frontend URL
 #     allow_credentials=True,
-#     allow_methods=["POST"],
+#     allow_methods=["*"],
 #     allow_headers=["*"],
 # )

@@ -72,9 +87,7 @@ def save_vote(vote_entry):
     with open(VOTES_FILE, "a") as file:
         date_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
         vote_entry["timestamp"] = date_time
-        file.write(
-            json.dumps(vote_entry) + "\n"
-        )  # Add a newline character after writing each entry
+        file.write(json.dumps(vote_entry) + "\n")
     logger.info(f"Vote saved: {vote_entry}")


@@ -90,7 +103,7 @@ class Vote(BaseModel):
     userID: str


-def validate_token(token: str = Header(None)):
+def validate_token(token: str = Header(None)) -> bool:
     try:
         whoami(token)
         return True
@@ -116,3 +129,136 @@ async def receive_vote(
     # Append the vote entry to the JSONL file
     background_tasks.add_task(save_vote, vote_entry)
     return JSONResponse(content={"message": "Vote submitted successfully"})
+
+
+def format_prompt(card: str) -> str:
+    return f"""
+Write a tl;dr summary of a dataset based on the dataset card. Focus on the most critical aspects of the dataset.
+The summary should aim to concisely describe the dataset.
+
+CARD: \n\n{card[:6000]}
+---
+
+\n\nInstructions:
+If the card provides the necessary information, say what the dataset can be used for.
+You do not need to mention that the dataset is hosted or available on the Hugging Face Hub.
+Do not mention the license of the dataset.
+Do not mention the number of examples in the training or test split.
+Only mention size if there is extensive discussion of the scale of the dataset in the dataset card.
+Do not speculate on anything not explicitly mentioned in the dataset card.
+In general avoid references to the quality of the dataset i.e. don't use phrases like 'a high-quality dataset' in the summary.
+
+\n\nOne sentence summary:"""
+
+
+async def check_when_dataset_last_modified(dataset_id: str) -> datetime | None:
+    try:
+        response = await async_httpx_client.get(
+            f"https://huggingface.co/api/datasets/{dataset_id}"
+        )
+        if last_modified := response.json().get("lastModified"):
+            return datetime.fromisoformat(last_modified)
+        return None
+    except Exception as e:
+        logger.error(e)
+        return None
+
+
+@cache(ttl="48h", condition=NOT_NONE, key="predict:{dataset_id}")
+async def predict(card: str, dataset_id: str) -> str | None:
+    try:
+        prompt = format_prompt(card)
+        client = AsyncOpenAI(
+            base_url="https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1/v1",
+            api_key=HF_TOKEN,
+        )
+
+        chat_completion = await client.chat.completions.create(
+            model="tgi",
+            messages=[
+                {"role": "user", "content": prompt},
+            ],
+            stream=False,
+            tags=["tldr-summaries"],
+        )
+        return chat_completion.choices[0].message.content.strip()
+    except Exception as e:
+        logger.error(e)
+        return None
+
+
+@app.get("/summary")
+async def get_summary(dataset_id: str) -> str | None:
+    """
+    Get a summary for a dataset based on the provided dataset ID.
+
+    Args:
+        dataset_id (str): The ID of the dataset to retrieve the summary for.
+
+    Returns:
+        str | None: The generated summary for the dataset, or None if no summary is available or an error occurs."""
+
+    try:
+        # dataset_id = request.dataset_id
+        card_text = await async_httpx_client.get(
+            f"https://huggingface.co/datasets/{dataset_id}/raw/main/README.md"
+        )
+        card_text = card_text.text
+        card = DatasetCard(card_text)
+        text = card.text
+        parsed_text = parse_markdown(text)
+        if is_empty_template(parsed_text):
+            return None
+        cache_key = f"predict:{dataset_id}"
+        cached_data = await cache.get(cache_key)
+
+        if cached_data is not None:
+            cached_summary, cached_last_modified_time = cached_data
+            # Get the current last modified time of the dataset
+            current_last_modified_time = await check_when_dataset_last_modified(
+                dataset_id
+            )
+
+            if (
+                current_last_modified_time is None
+                or cached_last_modified_time >= current_last_modified_time
+            ):
+                # Use the cached summary if the cached last modified time is greater than or equal to the current last modified time
+                logger.info("Using cached summary")
+                return cached_summary
+        summary = await predict(parsed_text, dataset_id)
+        current_last_modified_time = await check_when_dataset_last_modified(dataset_id)
+        await cache.set(cache_key, (summary, current_last_modified_time))
+        return summary
+    except Exception as e:
+        logger.error(e)
+        return None
+
+
+class SummariesRequest(BaseModel):
+    dataset_ids: List[str]
+
+
+@cache(ttl="1h", condition=NOT_NONE)
+@app.post("/summaries")
+async def get_summaries(request: SummariesRequest) -> dict:
+    """
+    Get summaries for a list of datasets based on the provided dataset IDs.
+
+    Args:
+        dataset_ids (List[str]): A list of dataset IDs to retrieve the summaries for.
+
+    Returns:
+        dict: A dictionary mapping dataset IDs to their corresponding summaries.
+    """
+    dataset_ids = request.dataset_ids
+
+    async def get_summary_wrapper(dataset_id):
+        return dataset_id, await get_summary(dataset_id)
+
+    summary_tasks = [get_summary_wrapper(dataset_id) for dataset_id in dataset_ids]
+    summaries = dict(await asyncio.gather(*summary_tasks))
+    for dataset_id in dataset_ids:
+        if summaries[dataset_id] is None:
+            summaries[dataset_id] = "No summary available"
+    return summaries
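
The new predict function calls the Hub's OpenAI-compatible inference endpoint for Mixtral through Langfuse's AsyncOpenAI drop-in; the tags kwarg there is Langfuse tracing metadata, not an OpenAI parameter. A minimal sketch of the same request with the plain openai SDK, assuming a valid HF_TOKEN in the environment:

import asyncio
import os

from openai import AsyncOpenAI  # plain SDK; the commit uses langfuse.openai's drop-in


async def main() -> None:
    client = AsyncOpenAI(
        base_url="https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1/v1",
        api_key=os.environ["HF_TOKEN"],
    )
    completion = await client.chat.completions.create(
        model="tgi",  # placeholder name; TGI routes by base_url, not by model
        messages=[{"role": "user", "content": "One-sentence tl;dr of the squad dataset."}],
        stream=False,
    )
    print(completion.choices[0].message.content.strip())


asyncio.run(main())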
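get_summary also manages the cache by hand with cache.get/cache.set on the key "predict:{dataset_id}", the same key template the @cache decorator on predict uses. A self-contained sketch of those cashews APIs (the in-memory backend here is purely for illustration; the commit configures disk://):

import asyncio

from cashews import NOT_NONE, cache

cache.setup("mem://")  # the commit uses cache.setup("disk://", size_limit=16 * Gb)


@cache(ttl="48h", condition=NOT_NONE, key="predict:{dataset_id}")
async def predict(dataset_id: str) -> str | None:
    # condition=NOT_NONE: None results are recomputed, never cached
    return f"summary for {dataset_id}"


async def main() -> None:
    print(await predict("user/dataset"))            # computed, then stored for 48h
    print(await cache.get("predict:user/dataset"))  # manual read of the same key
    await cache.set("predict:user/dataset", "newer value")  # manual overwrite
    print(await predict("user/dataset"))            # now served from the cache


asyncio.run(main())

One thing worth checking in review: get_summary stores a (summary, last_modified) tuple under this key, while the decorator on predict expects the key to hold the bare string it cached.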
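Cache invalidation leans on the public Hub API exposing a lastModified timestamp per dataset, which check_when_dataset_last_modified parses with datetime.fromisoformat. A quick way to inspect that field (note that fromisoformat only accepts a trailing "Z" on Python 3.11+):

import asyncio
from datetime import datetime

from httpx import AsyncClient


async def main() -> None:
    async with AsyncClient() as client:
        resp = await client.get(
            "https://huggingface.co/api/datasets/davanstrien/summary-ratings"
        )
        # lastModified is an ISO-8601 string such as "2024-03-04T12:34:56.000Z"
        print(datetime.fromisoformat(resp.json()["lastModified"]))


asyncio.run(main())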
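Finally, a sketch of how a client would exercise the two new routes; the base URL and port are assumptions for local testing, not part of this commit:

import asyncio

from httpx import AsyncClient


async def main() -> None:
    # assumes the app is running locally, e.g. via `uvicorn main:app --port 8000`
    async with AsyncClient(base_url="http://localhost:8000") as client:
        # GET /summary takes the dataset ID as a query parameter
        single = await client.get(
            "/summary", params={"dataset_id": "davanstrien/summary-ratings"}
        )
        print(single.json())

        # POST /summaries takes a JSON body matching SummariesRequest
        batch = await client.post(
            "/summaries", json={"dataset_ids": ["davanstrien/summary-ratings"]}
        )
        print(batch.json())  # maps each ID to a summary or "No summary available"


asyncio.run(main())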