Spaces:
Runtime error
Runtime error
File size: 1,879 Bytes
8b7a023 c3d9b67 8b7a023 41169fa 8b7a023 3dca465 46093c3 826c6ad 8b7a023 c3d9b67 8b7a023 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 |
import json
import asyncio
import logging
import time
import requests
from tqdm.asyncio import tqdm_asyncio
from huggingface_hub import get_inference_endpoint
from models import env_config, embed_config
# Configure the root logger so INFO-level (and higher) records are emitted.
logging.basicConfig(level=logging.INFO)
# Module-scoped logger used by the functions in this file.
logger = logging.getLogger(__name__)
# Handle to the inference endpoint named by env_config.tei_name, looked up once
# at import time using the token in env_config.hf_token and shared by the
# functions below.
# NOTE(review): presumably this lookup contacts the Hugging Face Hub, so
# importing this module would need network access and a valid token — confirm.
endpoint = get_inference_endpoint(env_config.tei_name, token=env_config.hf_token)
async def embed_chunk(sentence, semaphore, tmp_file):
    """Embed one text chunk and append the result to ``tmp_file`` as a JSON line.

    The chunk is posted to the module-level ``endpoint`` and the first element
    of the decoded response is written out together with the original text.

    Args:
        sentence: Text to embed; sent as the ``"inputs"`` field of the payload.
        semaphore: Async context manager held for the duration of the call;
            the caller uses it to bound the number of concurrent requests.
        tmp_file: Object exposing ``write(str)``; receives exactly one
            newline-terminated JSON object per successful call.

    Raises:
        RuntimeError: If the request to the endpoint raises. The message is
            ``str()`` of the original exception, which is attached as
            ``__cause__``.
    """
    async with semaphore:
        payload = {
            "inputs": sentence,
            # NOTE(review): presumably asks the server to truncate over-long
            # inputs instead of rejecting them — confirm against the TEI API.
            "truncate": True,
        }
        try:
            resp = await endpoint.async_client.post(json=payload)
        except Exception as e:
            # Keep the RuntimeError type callers may already catch, but mark
            # the original exception as the explicit cause so tracebacks show
            # it as the direct reason rather than an incidental error raised
            # while handling another one.
            raise RuntimeError(str(e)) from e

        result = json.loads(resp)
        # One JSON object per line: the vector plus the source text, keyed by
        # the configured input text column name.
        tmp_file.write(
            json.dumps({"vector": result[0], env_config.input_text_col: sentence}) + "\n"
        )
async def embed_wrapper(input_ds, temp_file):
    """Embed every non-blank text row of ``input_ds`` concurrently.

    One ``embed_chunk`` task is scheduled for each row whose
    ``env_config.input_text_col`` value is non-empty after stripping
    whitespace. All tasks share a bounded semaphore sized from
    ``embed_config.semaphore_bound`` and the same ``temp_file``. The number of
    scheduled chunks and the total elapsed time are logged at INFO level.

    Args:
        input_ds: Iterable of rows indexable by ``env_config.input_text_col``.
        temp_file: Passed through unchanged to every ``embed_chunk`` call.
    """
    limiter = asyncio.BoundedSemaphore(embed_config.semaphore_bound)

    jobs = []
    for record in input_ds:
        text = record[env_config.input_text_col]
        # Rows that are empty or whitespace-only are not embedded.
        if not text.strip():
            continue
        jobs.append(asyncio.create_task(embed_chunk(text, limiter, temp_file)))

    logger.info(f"num chunks to embed: {len(jobs)}")

    started = time.time()
    await tqdm_asyncio.gather(*jobs)
    elapsed = time.time() - started
    logger.info(f"embed time: {elapsed}")
def wake_up_endpoint():
    """Ensure the module-level ``endpoint`` reports a running status.

    Refreshes the endpoint's state; if its status is anything other than
    ``'running'``, resumes it, waits on it and refreshes the state again.
    Logs at INFO level before resuming and once the function is done.
    Returns ``None``.
    """
    endpoint.fetch()

    is_running = endpoint.status == 'running'
    if not is_running:
        logger.info("Starting up TEI endpoint")
        resumed = endpoint.resume()
        waited = resumed.wait()
        waited.fetch()
        # A manual readiness poll used to follow here (HTTP GET on the
        # endpoint URL every 2 s, giving up with TimeoutError after more than
        # 20 tries); it is disabled and the wait() call above is relied on.

    logger.info("TEI endpoint is up")
|