# Uploaded by jugarte00 with huggingface_hub (commit 9a45764, verified).
import contextlib
import json
import logging
import os
import re
from abc import ABC, abstractmethod
from typing import Any, AsyncGenerator
import backoff
import httpx
import tqdm
from article_embedding.utils import env_str
# Module-level logger, named after this module via ``__name__``.
log = logging.getLogger(__name__)
class Checkpoint(ABC):
@abstractmethod
def get(self) -> str | None: ...
@abstractmethod
def set(self, value: str) -> None: ...
@abstractmethod
def reset(self) -> None: ...
class NullCheckpoint(Checkpoint):
    """Checkpoint that stores nothing: ``get`` is always ``None``."""

    def get(self) -> str | None:
        """Return ``None``; this checkpoint never holds a value."""
        return None

    def set(self, value: str) -> None:
        """Ignore ``value``."""

    def reset(self) -> None:
        """Do nothing; there is no state to clear."""


# Shared instance used as the default ``checkpoint`` argument of
# ``CouchDB.changes`` and ``CouchDB.estimate_total_changes``.
_NULL_CHECKPOINT = NullCheckpoint()
class FileCheckpoint(Checkpoint):
    """Checkpoint persisted as the text content of a single file."""

    def __init__(self, path: str) -> None:
        """Remember ``path``; the file is not touched until first use."""
        self.path = path

    def get(self) -> str | None:
        """Return the stripped file content, or ``None`` if there is none.

        A missing file and a file that is empty (or whitespace only) both
        count as "no checkpoint".
        """
        try:
            with open(self.path, encoding="utf-8") as file:
                # ``or None`` keeps the "nothing stored" result uniform
                # instead of leaking an empty string to callers.
                return file.read().strip() or None
        except FileNotFoundError:
            return None

    def set(self, value: str) -> None:
        """Overwrite the file with ``value``."""
        with open(self.path, "w", encoding="utf-8") as file:
            file.write(value)

    def reset(self) -> None:
        """Delete the file; a file that is already absent is not an error."""
        with contextlib.suppress(FileNotFoundError):
            os.remove(self.path)
class CouchDB:
    """Singleton wrapper around an async HTTP client for one CouchDB database.

    Configuration comes from the environment: ``COUCHDB_URL``,
    ``COUCHDB_USER`` and ``COUCHDB_PASSWORD`` (read with ``os.environ``) and
    ``COUCHDB_DB`` and ``DOCS_PATH_VIEW`` (read with ``env_str``).
    """

    def __new__(cls) -> "CouchDB":
        """Return the one shared instance, creating it on first use."""
        if not hasattr(cls, "_instance"):
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self) -> None:
        """Create the client and resolve database paths, once per instance."""
        # Python calls __init__ on every ``CouchDB()`` even though __new__
        # returns the cached object. Without this guard each call would
        # replace ``self.client`` and leave the previous client unclosed.
        if getattr(self, "_initialized", False):
            return
        self.client = self.make_client()
        self.database = env_str("COUCHDB_DB")
        # Hoisted out of the f-string: reusing double quotes inside a
        # double-quoted f-string only parses on Python 3.12+.
        docs_path_view = env_str("DOCS_PATH_VIEW")
        self.path_view = f"/{self.database}/{docs_path_view}"
        self._initialized = True

    def make_client(self) -> httpx.AsyncClient:
        """Build an ``httpx.AsyncClient`` whose ``get`` logs in again on 401.

        Returns:
            A client rooted at ``COUCHDB_URL``. Its ``get`` is wrapped so
            that a response with status 401 triggers a ``POST /_session``
            with the configured credentials and then another attempt.

        Raises:
            KeyError: if ``COUCHDB_URL``, ``COUCHDB_USER`` or
                ``COUCHDB_PASSWORD`` is not set.
        """
        url = os.environ["COUCHDB_URL"]
        user = os.environ["COUCHDB_USER"]
        password = os.environ["COUCHDB_PASSWORD"]
        auth = {"name": user, "password": password}

        client = httpx.AsyncClient(base_url=url)

        async def on_backoff(details: Any) -> None:
            # Log in on the client being built, so the session ends up on
            # the same client that repeats the request.
            response = await client.post("/_session", json=auth)
            response.raise_for_status()

        # max_tries=2 allows one repeat after the login; factor=0 makes the
        # exponential wait generator yield zero.
        decorator = backoff.on_predicate(
            backoff.expo,
            predicate=lambda r: r.status_code == 401,
            on_backoff=on_backoff,
            max_tries=2,
            factor=0,
        )
        client.get = decorator(client.get)  # type: ignore[method-assign]
        return client

    async def changes(
        self, *, batch_size: int, checkpoint: Checkpoint = _NULL_CHECKPOINT
    ) -> AsyncGenerator[list[Any], None]:
        """Yield documents from the ``_changes`` feed in batches.

        Args:
            batch_size: value sent as the ``limit`` of each request.
            checkpoint: where to read the starting sequence from and where
                to record progress. With no stored value the feed starts
                at ``0``.

        Yields:
            The ``doc`` of every row in one response.

        Raises:
            httpx.HTTPStatusError: on a non-success response.
            TypeError: if the server's ``last_seq`` is not a string.
        """
        since = checkpoint.get() or 0
        params = {"since": since, "limit": batch_size, "include_docs": True}
        while True:
            response = await self.client.get(f"/{self.database}/_changes", params=params)
            response.raise_for_status()
            data = response.json()
            yield [change["doc"] for change in data["results"]]
            # Everything below runs only when the consumer asks for the
            # next batch, so the checkpoint never gets ahead of a batch
            # that has been handed out.
            since = data["last_seq"]
            if not isinstance(since, str):
                # Explicit check: an ``assert`` would vanish under ``-O``.
                raise TypeError(f"expected last_seq to be str, got {type(since).__name__}")
            params["since"] = since
            checkpoint.set(since)
            if data["pending"] == 0:
                break

    async def estimate_total_changes(self, *, checkpoint: Checkpoint = _NULL_CHECKPOINT) -> int:
        """Estimate how many changes remain after the checkpoint.

        Returns:
            The ``pending`` count reported by the server plus the number of
            rows the probe response itself contained.

        Raises:
            httpx.HTTPStatusError: on a non-success response.
        """
        since = checkpoint.get() or 0
        params = {"since": since, "limit": 0}
        response = await self.client.get(f"/{self.database}/_changes", params=params)
        response.raise_for_status()
        data = response.json()
        # CouchDB documents ``limit=0`` as behaving like ``limit=1``, so the
        # probe may carry one row that ``pending`` does not count. Counting
        # the rows actually returned avoids reporting 1 for an empty feed.
        return int(data["pending"]) + len(data["results"])

    async def get_doc_by_id(self, doc_id: str) -> Any:
        """Fetch the document with ID ``doc_id``.

        Returns:
            The decoded JSON document, or ``None`` on a 404 or on any
            error (errors are logged, not raised).
        """
        try:
            response = await self.client.get(f"/{self.database}/{doc_id}")
            if response.status_code == 404:
                return None
            response.raise_for_status()
            return response.json()
        except Exception as e:
            log.error("Error fetching document by ID", exc_info=e)
            return None

    async def get_doc_by_path(self, path: str) -> Any:
        """Fetch the first document the path view returns for ``path``.

        Returns:
            The ``doc`` of the first row, or ``None`` when the view has no
            rows for this key or on any error (errors are logged, not
            raised).
        """
        try:
            params = {
                "limit": "1",
                # The key is sent JSON-encoded.
                "key": json.dumps(path),
                "include_docs": "true",
            }
            response = await self.client.get(self.path_view, params=params)
            response.raise_for_status()
            data = response.json()
            rows = data["rows"]
            if not rows:
                return None
            return rows[0]["doc"]
        except Exception as e:
            # Use the module logger, as get_doc_by_id does.
            log.error("Error fetching document by path", exc_info=e)
            return None

    async def get_doc(self, id_or_path: str) -> Any:
        """Resolve ``id_or_path`` to a document, trying IDs before the path.

        Every UUID found in the string is looked up in order; the first
        truthy result wins. If none succeeds and the string yields a
        document path, the path view is queried instead.

        Returns:
            The document, or ``None`` if nothing matched.
        """
        uuids = extract_doc_ids(id_or_path)
        for uuid in uuids:
            doc = await self.get_doc_by_id(uuid)
            if doc:
                return doc
        path = extract_doc_path(id_or_path)
        if path:
            return await self.get_doc_by_path(path)
        return None
# Lowercase hyphenated UUIDs whose version digit is 1-5 and whose variant
# digit is 8, 9, a or b. The pattern is unanchored, so it finds UUIDs
# embedded anywhere in a longer string.
UUID_PATTERN = re.compile(r"[0-9a-f]{8}-[0-9a-f]{4}-[1-5][0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}")


def extract_doc_ids(s: str) -> list[str]:
    """Return every substring of ``s`` that matches ``UUID_PATTERN``.

    Matches are returned in order of appearance; the list is empty when
    there are none.
    """
    return UUID_PATTERN.findall(s)
def extract_doc_path(s: str) -> str | None:
if not s.endswith(".html"):
return None
if s.startswith("/"):
return s
if "://" in s:
s = s.split("://", 1)[1]
if "/" in s:
return "/" + s.split("/", 1)[1]
return None
if __name__ == "__main__":
    import asyncio

    from dotenv import load_dotenv

    async def main() -> None:
        """Walk the changes feed from ``.checkpoint`` with a progress bar."""
        db = CouchDB()
        checkpoint = FileCheckpoint(".checkpoint")
        try:
            total = await db.estimate_total_changes(checkpoint=checkpoint)
            with tqdm.tqdm(total=total) as pbar:
                async for docs in db.changes(batch_size=40, checkpoint=checkpoint):
                    for doc in docs:
                        kind = doc.get("type")
                        if kind == "article":
                            # ``.get`` so one article missing a field does
                            # not abort the whole run.
                            _id = doc.get("_id")
                            language = doc.get("language")
                            path = os.path.basename(doc.get("path") or "")
                            pbar.desc = f"{_id}: {kind} {language} {path}"
                        else:
                            pbar.desc = f"{kind}"
                        pbar.update(1)
        finally:
            # Release the HTTP connection pool before the event loop closes.
            await db.client.aclose()

    # Load variables from a .env file before CouchDB reads the environment.
    load_dotenv()
    asyncio.run(main())