webhook-space / embed_utils.py
sergeipetrov's picture
sergeipetrov HF staff
Update embed_utils.py
826c6ad verified
import json
import asyncio
import logging
import time
import requests
from tqdm.asyncio import tqdm_asyncio
from huggingface_hub import get_inference_endpoint
from models import env_config, embed_config
# Module-level setup: configure root logging at INFO level and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Shared handle to the TEI Inference Endpoint named by env_config.tei_name; the
# functions below use it to send requests and to resume the endpoint.
# NOTE(review): this executes at import time and presumably calls the Hugging Face
# Hub API, so importing this module likely requires network access and a valid token.
endpoint = get_inference_endpoint(env_config.tei_name, token=env_config.hf_token)
async def embed_chunk(sentence, semaphore, tmp_file):
    """Embed one text via the TEI endpoint and append the result to ``tmp_file``.

    The request is sent while holding ``semaphore`` so the caller can bound
    the number of in-flight requests.

    Args:
        sentence: Text to embed; sent as the ``inputs`` field of the request.
        semaphore: Async semaphore limiting concurrent requests.
        tmp_file: Open text-mode file; one JSON line of the form
            ``{"vector": ..., <env_config.input_text_col>: sentence}`` is
            appended per call.

    Raises:
        RuntimeError: If the request to the endpoint fails. The underlying
            exception is chained as ``__cause__`` so its traceback is kept.
    """
    payload = {
        "inputs": sentence,
        # request server-side truncation of over-long inputs (TEI `truncate` option)
        "truncate": True,
    }
    async with semaphore:
        try:
            resp = await endpoint.async_client.post(json=payload)
        except Exception as e:
            # Keep the RuntimeError type callers see, but preserve the original cause.
            raise RuntimeError(str(e)) from e
        result = json.loads(resp)
        # Presumably one embedding per input; with a single input, keep the first.
        tmp_file.write(
            json.dumps({"vector": result[0], env_config.input_text_col: sentence}) + "\n"
        )
async def embed_wrapper(input_ds, temp_file):
    """Embed every non-blank text in ``input_ds`` concurrently into ``temp_file``.

    Rows whose ``env_config.input_text_col`` value is ``None``, empty, or
    whitespace-only are skipped. At most ``embed_config.semaphore_bound``
    requests run at once; progress is shown with tqdm.

    Args:
        input_ds: Iterable of mapping-like rows containing
            ``env_config.input_text_col``.
        temp_file: Open text-mode file that ``embed_chunk`` appends JSON lines to.

    Raises:
        Exception: Whatever the first failing ``embed_chunk`` task raised
            (e.g. ``RuntimeError`` on a failed request). All other tasks are
            cancelled and awaited first, so nothing writes to ``temp_file``
            after this coroutine exits.
    """
    text_col = env_config.input_text_col
    semaphore = asyncio.BoundedSemaphore(embed_config.semaphore_bound)
    texts = (row[text_col] for row in input_ds)
    jobs = [
        asyncio.create_task(embed_chunk(text, semaphore, temp_file))
        for text in texts
        # `text and ...` skips None values, which would otherwise crash on .strip()
        if text and text.strip()
    ]
    logger.info("num chunks to embed: %d", len(jobs))
    tic = time.perf_counter()
    try:
        await tqdm_asyncio.gather(*jobs)
    except BaseException:
        # gather() does not cancel sibling tasks on failure (or when we are
        # cancelled); stop them and wait so they cannot keep writing to
        # temp_file after the caller regains control.
        for job in jobs:
            job.cancel()
        await asyncio.gather(*jobs, return_exceptions=True)
        raise
    logger.info("embed time: %s", time.perf_counter() - tic)
def wake_up_endpoint(timeout=None):
    """Ensure the TEI Inference Endpoint is running, resuming it if needed.

    Refreshes the endpoint's status. If it is not ``'running'``, resumes it
    and blocks until ``InferenceEndpoint.wait`` reports it deployed.

    Args:
        timeout: Maximum number of seconds to wait for the endpoint to come
            up. ``None`` (the default, matching the previous behavior) waits
            indefinitely.

    Raises:
        huggingface_hub.InferenceEndpointTimeoutError: If ``timeout`` is set
            and the endpoint is not deployed in time (per the
            ``InferenceEndpoint.wait`` documentation).
    """
    endpoint.fetch()
    if endpoint.status != 'running':
        logger.info("Starting up TEI endpoint")
        endpoint.resume().wait(timeout=timeout).fetch()
    logger.info("TEI endpoint is up")