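"""
FastAPI app for a Hugging Face Space: it listens for Hub webhook events on a
dataset repo, chunks the updated dataset, embeds each chunk through a Text
Embeddings Inference (TEI) endpoint, and pushes the results back to the Hub.
"""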
import asyncio
import logging
import numpy as np
import time
import json
import os
import tempfile
import requests

from fastapi import FastAPI, Header, HTTPException, BackgroundTasks
from fastapi.responses import FileResponse
from aiohttp import ClientSession
from langchain.text_splitter import SpacyTextSplitter
from datasets import Dataset, load_dataset
from tqdm import tqdm
from tqdm.asyncio import tqdm_asyncio

from src.models import chunk_config, embed_config, WebhookPayload

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

HF_TOKEN = os.getenv("HF_TOKEN")
TEI_URL = os.getenv("TEI_URL")

app = FastAPI()
@app.get("/")
async def home():
    return FileResponse("home.html")
@app.post("/webhook")
async def post_webhook(
    payload: WebhookPayload,
    task_queue: BackgroundTasks
):
    if not (
        payload.event.action == "update"
        and payload.event.scope.startswith("repo.content")
        and (
            payload.repo.name == embed_config.input_dataset
            # or payload.repo.name == chunk_config.input_dataset
        )
        and payload.repo.type == "dataset"
    ):
        # no-op
        logger.info("Update detected, no action taken")
        return {"processed": False}

    if payload.repo.name == chunk_config.input_dataset:
        task_queue.add_task(chunk_dataset)
    task_queue.add_task(embed_dataset)

    return {"processed": True}
""" | |
CHUNKING | |
""" | |
class Chunker:
    def __init__(self, strategy, split_seq, chunk_len):
        self.split_seq = split_seq
        self.chunk_len = chunk_len
        if strategy == "spacy":
            # sentence-aware splitting via spaCy
            self.split = SpacyTextSplitter().split_text
        if strategy == "sequence":
            # split on a fixed separator sequence
            self.split = self.seq_splitter
        if strategy == "constant":
            # split into fixed-length character chunks
            self.split = self.const_splitter

    def seq_splitter(self, text):
        return text.split(self.split_seq)

    def const_splitter(self, text):
        return [
            text[i * self.chunk_len:(i + 1) * self.chunk_len]
            for i in range(int(np.ceil(len(text) / self.chunk_len)))
        ]
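# Illustration (not executed): with strategy="sequence" and split_seq="\n",
# Chunker.split("a\nb") returns ["a", "b"]; with strategy="constant" and
# chunk_len=2, Chunker.split("abcd") returns ["ab", "cd"].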
def chunk_generator(input_dataset, chunker):
    for i in tqdm(range(len(input_dataset))):
        chunks = chunker.split(input_dataset[i][chunk_config.input_text_col])
        for chunk in chunks:
            if chunk:
                yield {chunk_config.input_text_col: chunk}
def chunk_dataset():
    logger.info("Update detected, chunking is scheduled")
    input_ds = load_dataset(chunk_config.input_dataset, split=chunk_config.input_splits)
    chunker = Chunker(
        strategy=chunk_config.strategy,
        split_seq=chunk_config.split_seq,
        chunk_len=chunk_config.chunk_len
    )

    dataset = Dataset.from_generator(
        chunk_generator,
        gen_kwargs={
            "input_dataset": input_ds,
            "chunker": chunker
        }
    )

    dataset.push_to_hub(
        chunk_config.output_dataset,
        private=chunk_config.private,
        token=HF_TOKEN
    )

    logger.info("Done chunking")
    return {"processed": True}
""" | |
EMBEDDING | |
""" | |
async def embed_sent(sentence, semaphore, tei_url, tmp_file):
    # the semaphore bounds how many requests are in flight against the TEI endpoint
    async with semaphore:
        payload = {
            "inputs": sentence,
            "truncate": True
        }

        async with ClientSession(
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {HF_TOKEN}"
            }
        ) as session:
            async with session.post(tei_url, json=payload) as resp:
                if resp.status != 200:
                    raise RuntimeError(await resp.text())
                result = await resp.json()

        tmp_file.write(
            json.dumps({"vector": result[0], chunk_config.input_text_col: sentence}) + "\n"
        )
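# Note: the TEI endpoint returns a list with one embedding per input, so
# result[0] is the vector for the single sentence sent in the request. Each
# call appends one JSON line ({"vector": ..., <text column>: ...}) to the
# shared temp file, which embed_dataset later loads with Dataset.from_json.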
async def embed(input_ds, tei_url, temp_file):
    # limit concurrent requests to the TEI endpoint
    semaphore = asyncio.BoundedSemaphore(embed_config.semaphore_bound)
    jobs = [
        asyncio.create_task(embed_sent(row[chunk_config.input_text_col], semaphore, tei_url, temp_file))
        for row in input_ds if row[chunk_config.input_text_col].strip()
    ]
    logger.info(f"num chunks to embed: {len(jobs)}")

    tic = time.time()
    await tqdm_asyncio.gather(*jobs)
    logger.info(f"embed time: {time.time() - tic}")
def wake_up_endpoint(url):
    # poll until the (possibly scaled-to-zero) TEI endpoint answers with 200
    while requests.get(
        url=url,
        headers={"Authorization": f"Bearer {HF_TOKEN}"}
    ).status_code != 200:
        time.sleep(2)
    logger.info("TEI endpoint is up")
def embed_dataset():
    logger.info("Update detected, embedding is scheduled")
    wake_up_endpoint(embed_config.tei_url)
    input_ds = load_dataset(embed_config.input_dataset, split=embed_config.input_splits)

    with tempfile.NamedTemporaryFile(mode="a", suffix=".jsonl") as temp_file:
        asyncio.run(embed(input_ds, embed_config.tei_url, temp_file))
        # make sure buffered lines are on disk before reading the file back
        temp_file.flush()

        dataset = Dataset.from_json(temp_file.name)
        dataset.push_to_hub(
            embed_config.output_dataset,
            private=embed_config.private,
            token=HF_TOKEN
        )

    logger.info("Done embedding")