|
import argparse |
|
import asyncio |
|
import gradio as gr |
|
import numpy as np |
|
import time |
|
import json |
|
import os |
|
import tempfile |
|
import requests |
|
import logging |
|
|
|
from aiohttp import ClientSession |
|
from langchain.text_splitter import RecursiveCharacterTextSplitter |
|
|
|
|
|
logging.basicConfig(level=logging.INFO) |
|
logger = logging.getLogger(__name__) |
|
|
|
class Chunker:
    """Split text into chunks using one of three strategies.

    Parameters
    ----------
    strategy : str
        One of "recursive" (Langchain ``RecursiveCharacterTextSplitter``),
        "sequence" (split on ``split_seq``), or "constant" (fixed-size
        slices of ``chunk_len`` characters).
    split_seq : str
        Separator used by the "recursive" and "sequence" strategies.
    chunk_len : int | str
        Target chunk size in characters; coerced to ``int`` (the UI hands
        it over as a textbox string).

    Raises
    ------
    ValueError
        For an unknown strategy. Previously an unknown strategy silently
        left the instance without a ``split`` attribute, deferring the
        failure to a confusing ``AttributeError`` at call time.
    """

    def __init__(self, strategy, split_seq=".", chunk_len=512):
        self.split_seq = split_seq
        self.chunk_len = int(chunk_len)
        if strategy == "recursive":
            self.split = RecursiveCharacterTextSplitter(
                chunk_size=self.chunk_len,
                separators=[split_seq]
            ).split_text
        elif strategy == "sequence":
            self.split = self.seq_splitter
        elif strategy == "constant":
            self.split = self.const_splitter
        else:
            # Fail fast instead of leaving ``self.split`` undefined.
            raise ValueError(f"Unknown chunking strategy: {strategy!r}")

    def seq_splitter(self, text):
        """Split ``text`` on every occurrence of the configured sequence."""
        return text.split(self.split_seq)

    def const_splitter(self, text):
        """Slice ``text`` into consecutive ``chunk_len``-character pieces."""
        # Stride slicing replaces the original np.ceil index arithmetic;
        # identical output, no numpy round-trip.
        return [
            text[i:i + self.chunk_len]
            for i in range(0, len(text), self.chunk_len)
        ]
|
|
|
def chunk_text(input_text, strategy, split_seq, chunk_len):
    """Split ``input_text`` into chunks using the selected strategy.

    Thin convenience wrapper: builds a :class:`Chunker` and applies its
    configured split function to the text.
    """
    return Chunker(strategy, split_seq, chunk_len).split(input_text)
|
|
|
async def embed_sent(sentence, tei_url):
    """POST one sentence to the TEI endpoint and return its embedding.

    Raises ``RuntimeError`` with the response body on any non-200 status.
    Returns the first (only) embedding vector from the TEI response.
    """
    request_body = {
        "inputs": sentence,
        "truncate": True
    }
    headers = {
        "Content-Type": "application/json",
    }
    async with ClientSession(headers=headers) as session:
        async with session.post(tei_url, json=request_body) as response:
            if response.status != 200:
                raise RuntimeError(await response.text())
            body = await response.json()
    return body[0]
|
|
|
async def embed_first_sentence(chunks, tei_url):
    """Embed only the first chunk of ``chunks``.

    Returns ``(first_chunk, embedding)``, or ``([], [])`` when there are
    no chunks at all (empty input text).
    """
    if not chunks:
        return [], []
    head = chunks[0]
    vector = await embed_sent(head, tei_url)
    return head, vector
|
|
|
def wake_up_endpoint(url):
    """Poll the TEI endpoint until it responds 200, waking it if cold.

    Retries every 2 seconds for up to 41 attempts (~80s of waiting) and
    raises ``gr.Error`` if the endpoint never comes up.

    Fixes over the original:
    - ``requests.get`` now carries a ``timeout`` so a hung host cannot
      block the app forever.
    - Connection errors while the endpoint is scaling up (connection
      refused / DNS hiccup) are treated like a non-200 response and
      retried, instead of propagating and crashing the request.
    """
    logger.info("Starting up TEI endpoint")
    for _attempt in range(41):
        try:
            response = requests.get(
                url=url,
                headers={"Content-Type": "application/json"},
                timeout=10,
            )
            if response.status_code == 200:
                logger.info("TEI endpoint is up")
                return
        except requests.exceptions.RequestException:
            # Endpoint not reachable yet (cold start) -- fall through and retry.
            pass
        time.sleep(2)
    raise gr.Error("TEI endpoint is unavailable")
|
|
|
async def process_text(input_text, strategy, split_seq, chunk_len, tei_url):
    """Run the full demo pipeline for one submission.

    Waits for the TEI endpoint to be reachable, chunks the input text with
    the chosen strategy, then embeds the first chunk.

    Returns ``(chunks, first_chunk, embedding)`` for the three output
    components of the UI.
    """
    wake_up_endpoint(tei_url)
    pieces = chunk_text(input_text, strategy, split_seq, chunk_len)
    sentence, vector = await embed_first_sentence(pieces, tei_url)
    return pieces, sentence, vector
|
|
|
def change_dropdown(choice):
    """Toggle visibility of the strategy-specific textboxes.

    Returns updates for ``[split_seq, chunk_len]``:
    "recursive" needs both fields, "sequence" only the split sequence,
    and any other choice ("constant") only the chunk length.
    """
    show_sequence = choice in ("recursive", "sequence")
    show_length = choice != "sequence"
    return [
        gr.Textbox(visible=show_sequence),
        gr.Textbox(visible=show_length),
    ]
|
|
|
def main(args):
    """Build and launch the "Chunk and Embed" Gradio demo.

    args: argparse.Namespace with a ``port`` attribute used as the HTTP
    port for the Gradio server. Binds to 0.0.0.0, so the app is reachable
    from outside the host (intended for containerized deployment).
    """
    with gr.Blocks(theme='sudeepshouche/minimalist') as demo:
        gr.Markdown("## Chunk and Embed")

        # Free-form text the user wants to chunk and embed.
        input_text = gr.Textbox(lines=5, label="Input Text")

        with gr.Row():
            dropdown = gr.Dropdown(
                ["recursive", "sequence", "constant"], label="Chunking Strategy",
                info="'recursive' uses a Langchain recursive tokenizer, 'sequence' splits texts by a chosen sequence, "
                     "'constant' makes chunks of the constant size",
                scale=2
            )
            # Hidden until a strategy that needs it is selected
            # (see change_dropdown).
            split_seq = gr.Textbox(
                lines=1,
                interactive=True,
                visible=False,
                label="Sequence",
                info="A text sequence to split on",
                placeholder="\n\n"
            )
            # Hidden until a strategy that needs it is selected.
            chunk_len = gr.Textbox(
                lines=1,
                interactive=True,
                visible=False,
                label="Length",
                info="The length of chunks to split into in characters",
                placeholder="512"
            )
            # Show/hide the two strategy parameters whenever the strategy changes.
            dropdown.change(fn=change_dropdown, inputs=dropdown, outputs=[split_seq, chunk_len])

        # URL of the Text Embeddings Inference endpoint to call.
        tei_url = gr.Textbox(lines=1, label="TEI Endpoint URL")

        with gr.Row():
            clear = gr.ClearButton(components=[input_text, dropdown, split_seq, chunk_len, tei_url])
            embed_btn = gr.Button("Submit")
            # process_text returns (chunks, first_chunk, embedding), mapped
            # onto the three output components in order.
            embed_btn.click(
                fn=process_text,
                inputs=[input_text, dropdown, split_seq, chunk_len, tei_url],
                outputs=[gr.JSON(label="Chunks"), gr.Textbox(label="First Chunked Sentence"), gr.JSON(label="Embedded Sentence")]
            )

    # Queueing is required for the async click handler to run properly.
    demo.queue()
    demo.launch(server_name="0.0.0.0", server_port=args.port)
|
|
|
if __name__ == "__main__":
    # CLI entry point: parse the serving port, then launch the Gradio app.
    cli = argparse.ArgumentParser(description="A MAGIC example by ConceptaTech")
    cli.add_argument("--port", type=int, default=7860, help="Port to expose Gradio app")
    main(cli.parse_args())
|
|