import json
import asyncio
import logging
import time

from tqdm.asyncio import tqdm_asyncio
from huggingface_hub import get_inference_endpoint

from models import env_config, embed_config

# Process-wide logging configuration; INFO so progress/timing messages are visible.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Handle to the TEI (Text Embeddings Inference) endpoint, resolved once at import
# time. NOTE(review): this makes a Hub API call on import — verify that is intended.
endpoint = get_inference_endpoint(env_config.tei_name, token=env_config.hf_token)


async def embed_chunk(sentence, semaphore, tmp_file):
    """Embed one text chunk via the TEI endpoint and append the result to tmp_file.

    Args:
        sentence: Text to embed.
        semaphore: asyncio semaphore bounding concurrent endpoint requests.
        tmp_file: Open writable file object; one JSON line is appended per chunk.

    Raises:
        RuntimeError: If the endpoint request fails (original exception chained).
    """
    async with semaphore:
        payload = {
            "inputs": sentence,
            "truncate": True  # let TEI truncate inputs beyond the model's max length
        }

        try:
            resp = await endpoint.async_client.post(json=payload)
        except Exception as e:
            # Chain the cause so the underlying traceback is preserved for debugging.
            raise RuntimeError(str(e)) from e

        result = json.loads(resp)
        # Response is a batch of one embedding; store the vector with its source text.
        tmp_file.write(
            json.dumps({"vector": result[0], env_config.input_text_col: sentence}) + "\n"
        )


async def embed_wrapper(input_ds, temp_file):
    """Embed every row of input_ds whose text column is non-blank, writing JSON lines to temp_file."""
    gate = asyncio.BoundedSemaphore(embed_config.semaphore_bound)
    text_col = env_config.input_text_col

    tasks = []
    for row in input_ds:
        text = row[text_col]
        if text.strip():
            tasks.append(asyncio.create_task(embed_chunk(text, gate, temp_file)))
    logger.info(f"num chunks to embed: {len(tasks)}")

    start = time.time()
    await tqdm_asyncio.gather(*tasks)
    logger.info(f"embed time: {time.time() - start}")


def wake_up_endpoint():
    """Ensure the TEI endpoint is running; resume it and block until ready if needed."""
    endpoint.fetch()
    already_running = endpoint.status == 'running'
    if not already_running:
        logger.info("Starting up TEI endpoint")
        endpoint.resume()
        endpoint.wait()
    logger.info("TEI endpoint is up")