File size: 1,879 Bytes
8b7a023
 
 
 
c3d9b67
8b7a023
 
 
 
 
 
 
 
 
 
41169fa
8b7a023
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3dca465
46093c3
 
 
826c6ad
 
 
 
 
 
 
 
 
8b7a023
c3d9b67
8b7a023
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import json
import asyncio
import logging
import time
import requests

from tqdm.asyncio import tqdm_asyncio
from huggingface_hub import get_inference_endpoint

from models import env_config, embed_config

# Configure the root logger at import time so INFO-level progress messages are emitted.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Module-wide handle to the inference endpoint named by env_config.tei_name,
# resolved at import time using env_config.hf_token.
# NOTE(review): "TEI" presumably means Text Embeddings Inference — confirm.
endpoint = get_inference_endpoint(env_config.tei_name, token=env_config.hf_token)


async def embed_chunk(sentence, semaphore, tmp_file):
    """Embed one text chunk via the inference endpoint and append it to ``tmp_file``.

    Args:
        sentence: Text to embed; sent as the ``"inputs"`` field of the request.
        semaphore: Async context manager (e.g. ``asyncio.BoundedSemaphore``)
            held for the whole call, including the file write.
        tmp_file: Object with a ``write(str)`` method; receives one JSON line
            holding the keys ``"vector"`` and ``env_config.input_text_col``.

    Raises:
        RuntimeError: If the request to the endpoint fails. The message is the
            original error's text and the original exception is attached as
            ``__cause__``.
    """
    async with semaphore:
        # NOTE(review): "truncate" presumably asks the server to cut over-long
        # inputs instead of rejecting them — confirm against the endpoint API.
        payload = {
            "inputs": sentence,
            "truncate": True,
        }

        try:
            resp = await endpoint.async_client.post(json=payload)
        except Exception as e:
            # Chain explicitly so the underlying client error stays attached
            # as the direct cause in tracebacks.
            raise RuntimeError(str(e)) from e

        # Only the first element of the decoded response is stored as the vector.
        result = json.loads(resp)
        tmp_file.write(
            json.dumps({"vector": result[0], env_config.input_text_col: sentence}) + "\n"
        )


async def embed_wrapper(input_ds, temp_file):
    """Embed every non-blank row of ``input_ds`` concurrently into ``temp_file``.

    One task per row is created up front; rows whose text column is empty or
    whitespace-only are skipped. Concurrency inside the tasks is limited by a
    ``BoundedSemaphore`` sized from ``embed_config.semaphore_bound``.

    Args:
        input_ds: Iterable of rows indexable by ``env_config.input_text_col``;
            each value must support ``.strip()``.
        temp_file: File-like object passed through to ``embed_chunk``.
    """
    semaphore = asyncio.BoundedSemaphore(embed_config.semaphore_bound)
    text_col = env_config.input_text_col
    jobs = [
        asyncio.create_task(embed_chunk(row[text_col], semaphore, temp_file))
        for row in input_ds
        if row[text_col].strip()
    ]
    logger.info(f"num chunks to embed: {len(jobs)}")

    # perf_counter is monotonic, so the measured duration cannot be skewed by
    # system clock adjustments the way a time.time() difference can.
    tic = time.perf_counter()
    await tqdm_asyncio.gather(*jobs)
    logger.info(f"embed time: {time.perf_counter() - tic}")


def wake_up_endpoint():
    """Make sure the module-level endpoint reports ``'running'``.

    Refreshes the endpoint state; if the status is anything other than
    ``'running'``, calls ``resume()``, then ``wait()``, then ``fetch()`` on it.
    Logs when a start-up is triggered and once more at the end. Returns ``None``.
    """
    endpoint.fetch()

    needs_start = endpoint.status != 'running'
    if needs_start:
        logger.info("Starting up TEI endpoint")
        resumed = endpoint.resume()
        ready = resumed.wait()
        ready.fetch()
        # NOTE: a disabled fallback used to poll endpoint.url over HTTP (every
        # 2 s, up to ~20 tries) until it answered 200, raising TimeoutError
        # otherwise; it is intentionally not active here.

    logger.info("TEI endpoint is up")