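"""Gradio app that chunks a Hugging Face dataset and embeds the chunks
via a TEI (Text Embeddings Inference) endpoint, pushing both the chunked
and the embedded datasets to the Hub."""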
import argparse
import asyncio
import json
import logging
import os
import tempfile
import time

import gradio as gr
import numpy as np
import requests
from aiohttp import ClientSession
from datasets import Dataset, load_dataset
from dotenv import load_dotenv
from langchain.text_splitter import RecursiveCharacterTextSplitter
from tqdm import tqdm
from tqdm.asyncio import tqdm_asyncio

load_dotenv()

USERNAME = os.getenv("USERNAME")
PWD = os.getenv("USER_PWD")
HF_TOKEN = os.getenv("HF_TOKEN")
SEMAPHORE_BOUND = os.getenv("SEMAPHORE_BOUND", "5")

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class Chunker:
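    """Split text into chunks with one of three strategies: 'recursive'
    (LangChain's RecursiveCharacterTextSplitter), 'sequence' (split on a
    literal separator), or 'constant' (fixed-size character slices)."""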
    def __init__(self, strategy, split_seq=".", chunk_len=512):
        self.split_seq = split_seq
        # chunk_len arrives as a string from the Gradio Textbox and may be
        # empty when the field is hidden; cast when present so slicing and
        # chunk_size receive an int.
        self.chunk_len = int(chunk_len) if chunk_len else chunk_len
        if strategy == "recursive":
            # https://huggingface.co/spaces/m-ric/chunk_visualizer
            self.split = RecursiveCharacterTextSplitter(
                chunk_size=self.chunk_len,
                separators=[split_seq]
            ).split_text
        if strategy == "sequence":
            self.split = self.seq_splitter
        if strategy == "constant":
            self.split = self.const_splitter

    def seq_splitter(self, text):
        return text.split(self.split_seq)

    def const_splitter(self, text):
        return [
            text[i * self.chunk_len:(i + 1) * self.chunk_len]
            for i in range(int(np.ceil(len(text) / self.chunk_len)))
        ]


def generator(input_ds, input_text_col, chunker):
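    """Yield one record per non-empty chunk of each row's text column."""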
    for i in tqdm(range(len(input_ds))):
        chunks = chunker.split(input_ds[i][input_text_col])
        for chunk in chunks:
            if chunk:
                yield {input_text_col: chunk}


async def embed_sent(sentence, embed_in_text_col, semaphore, tei_url, tmp_file):
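    """POST one chunk to the TEI endpoint and append its embedding as a
    JSON line to tmp_file; the semaphore caps concurrent requests."""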
    async with semaphore:
        payload = {
            "inputs": sentence,
            "truncate": True
        }
        async with ClientSession(
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {HF_TOKEN}"
            }
        ) as session:
            async with session.post(tei_url, json=payload) as resp:
                if resp.status != 200:
                    raise RuntimeError(await resp.text())
                result = await resp.json()
                tmp_file.write(
                    json.dumps({"vector": result[0], embed_in_text_col: sentence}) + "\n"
                )


async def embed_ds(input_ds, tei_url, embed_in_text_col, temp_file):
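    """Embed every non-blank chunk concurrently, with at most
    SEMAPHORE_BOUND requests in flight at a time."""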
    semaphore = asyncio.BoundedSemaphore(int(SEMAPHORE_BOUND))
    jobs = [
        asyncio.create_task(embed_sent(row[embed_in_text_col], embed_in_text_col, semaphore, tei_url, temp_file))
        for row in input_ds if row[embed_in_text_col].strip()
    ]
    logger.info(f"num chunks to embed: {len(jobs)}")
    tic = time.time()
    await tqdm_asyncio.gather(*jobs)
    logger.info(f"embed time: {time.time() - tic}")


def wake_up_endpoint(url):
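    """Poll the TEI endpoint until it answers 200, for up to ~80 s."""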
logger.info("Starting up TEI endpoint")
n_loop = 0
while requests.get(
url=url,
headers={"Authorization": f"Bearer {HF_TOKEN}"}
).status_code != 200:
time.sleep(2)
n_loop += 1
if n_loop > 40:
raise gr.Error("TEI endpoint is unavailable")
logger.info("TEI endpoint is up")
def chunk_embed(input_ds, input_splits, input_text_col, chunk_out_ds,
                strategy, split_seq, chunk_len, embed_out_ds, tei_url, private):
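    """Gradio callback: load the dataset, chunk it and push the chunks,
    then embed the chunks via TEI and push the embeddings."""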
gr.Info("Started chunking")
try:
input_splits = [spl.strip() for spl in input_splits.split(",") if spl]
input_ds = load_dataset(input_ds, "text-corpus", split="+".join(input_splits), token=HF_TOKEN)
chunker = Chunker(strategy, split_seq, chunk_len)
except Exception as e:
raise gr.Error(str(e))
gen_kwargs = {
"input_ds": input_ds,
"input_text_col": input_text_col,
"chunker": chunker
}
chunked_ds = Dataset.from_generator(generator, gen_kwargs=gen_kwargs)
chunked_ds.push_to_hub(
chunk_out_ds,
private=private,
token=HF_TOKEN
)
gr.Info("Done chunking")
logger.info("Done chunking")
try:
wake_up_endpoint(tei_url)
with tempfile.NamedTemporaryFile(mode="a", suffix=".jsonl") as temp_file:
asyncio.run(embed_ds(chunked_ds, tei_url, input_text_col, temp_file))
embedded_ds = Dataset.from_json(temp_file.name)
embedded_ds.push_to_hub(
embed_out_ds,
private=private,
token=HF_TOKEN
)
except Exception as e:
raise gr.Error(str(e))
gr.Info("Done embedding")
logger.info("Done embedding")
def change_dropdown(choice):
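    """Show or hide the split-sequence and chunk-length fields to match
    the selected chunking strategy."""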
    if choice == "recursive":
        return [
            gr.Textbox(visible=True),
            gr.Textbox(visible=True)
        ]
    elif choice == "sequence":
        return [
            gr.Textbox(visible=True),
            gr.Textbox(visible=False)
        ]
    else:
        return [
            gr.Textbox(visible=False),
            gr.Textbox(visible=True)
        ]


def main(args):
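    """Build the Gradio UI and launch it behind basic auth."""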
    with gr.Blocks(theme='sudeepshouche/minimalist') as demo:
        gr.Markdown(
            """
            ## Chunk and embed
            """
        )
        input_ds = gr.Textbox(lines=1, label="Input dataset name")
        with gr.Row():
            input_splits = gr.Textbox(lines=1, label="Input dataset splits", placeholder="train, test")
            input_text_col = gr.Textbox(lines=1, label="Input text column name", placeholder="text")
        chunk_out_ds = gr.Textbox(lines=1, label="Chunked dataset name")
        with gr.Row():
            dropdown = gr.Dropdown(
                ["recursive", "sequence", "constant"], label="Chunking strategy",
                info="'recursive' uses LangChain's recursive character splitter, 'sequence' splits the text "
                     "on a chosen sequence, 'constant' makes chunks of a constant size",
                scale=2
            )
            split_seq = gr.Textbox(
                lines=1,
                interactive=True,
                visible=False,
                label="Sequence",
                info="A text sequence to split on",
                placeholder="\n\n"
            )
            chunk_len = gr.Textbox(
                lines=1,
                interactive=True,
                visible=False,
                label="Length",
                info="The chunk length, in characters",
                placeholder="512"
            )
        dropdown.change(fn=change_dropdown, inputs=dropdown, outputs=[split_seq, chunk_len])
        embed_out_ds = gr.Textbox(lines=1, label="Embedded dataset name")
        private = gr.Checkbox(label="Make output datasets private")
        tei_url = gr.Textbox(lines=1, label="TEI endpoint url")
        with gr.Row():
            clear = gr.ClearButton(
                components=[input_ds, input_splits, input_text_col, chunk_out_ds,
                            dropdown, split_seq, chunk_len, embed_out_ds, tei_url, private]
            )
            embed_btn = gr.Button("Submit")
        embed_btn.click(
            fn=chunk_embed,
            inputs=[input_ds, input_splits, input_text_col, chunk_out_ds,
                    dropdown, split_seq, chunk_len, embed_out_ds, tei_url, private]
        )
    demo.queue()
    demo.launch(auth=(USERNAME, PWD), server_name="0.0.0.0", server_port=args.port)


######
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="A MAGIC example by ConceptaTech")
    parser.add_argument("--port", type=int, default=7860, help="Port to expose Gradio app")
    args = parser.parse_args()
    main(args)
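
# Example usage (assuming HF_TOKEN, USERNAME and USER_PWD are set in .env):
#   python app/main.py --port 7860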