Concepta committed on
Commit
4ef144a
1 Parent(s): 3f75ae4

Update app/main.py

Browse files
Files changed (1) hide show
  1. app/main.py +51 -128
app/main.py CHANGED
@@ -1,5 +1,4 @@
1
  import argparse
2
- from dotenv import load_dotenv
3
  import asyncio
4
  import gradio as gr
5
  import numpy as np
@@ -12,35 +11,23 @@ import logging
12
 
13
  from aiohttp import ClientSession
14
  from langchain.text_splitter import RecursiveCharacterTextSplitter
15
- from datasets import Dataset, load_dataset
16
- from tqdm import tqdm
17
- from tqdm.asyncio import tqdm_asyncio
18
-
19
- load_dotenv()
20
-
21
- USERNAME = os.getenv("USERNAME")
22
- PWD = os.getenv("USER_PWD")
23
- HF_TOKEN = os.getenv("HF_TOKEN")
24
- SEMAPHORE_BOUND = os.getenv("SEMAPHORE_BOUND", "5")
25
 
26
 
27
  logging.basicConfig(level=logging.INFO)
28
  logger = logging.getLogger(__name__)
29
 
30
-
31
  class Chunker:
32
  def __init__(self, strategy, split_seq=".", chunk_len=512):
33
  self.split_seq = split_seq
34
  self.chunk_len = chunk_len
35
  if strategy == "recursive":
36
- # https://huggingface.co/spaces/m-ric/chunk_visualizer
37
  self.split = RecursiveCharacterTextSplitter(
38
  chunk_size=chunk_len,
39
  separators=[split_seq]
40
  ).split_text
41
- if strategy == "sequence":
42
  self.split = self.seq_splitter
43
- if strategy == "constant":
44
  self.split = self.const_splitter
45
 
46
  def seq_splitter(self, text):
@@ -52,50 +39,35 @@ class Chunker:
52
  for i in range(int(np.ceil(len(text) / self.chunk_len)))
53
  ]
54
 
 
 
 
 
55
 
56
- def generator(input_ds, input_text_col, chunker):
57
- for i in tqdm(range(len(input_ds))):
58
- chunks = chunker.split(input_ds[i][input_text_col])
59
- for chunk in chunks:
60
- if chunk:
61
- yield {input_text_col: chunk}
62
-
63
-
64
- async def embed_sent(sentence, embed_in_text_col, semaphore, tei_url, tmp_file):
65
- async with semaphore:
66
- payload = {
67
- "inputs": sentence,
68
- "truncate": True
69
- }
70
-
71
- async with ClientSession(
72
- headers={
73
- "Content-Type": "application/json",
74
- "Authorization": f"Bearer {HF_TOKEN}"
75
- }
76
- ) as session:
77
- async with session.post(tei_url, json=payload) as resp:
78
- if resp.status != 200:
79
- raise RuntimeError(await resp.text())
80
- result = await resp.json()
81
-
82
- tmp_file.write(
83
- json.dumps({"vector": result[0], embed_in_text_col: sentence}) + "\n"
84
- )
85
-
86
-
87
- async def embed_ds(input_ds, tei_url, embed_in_text_col, temp_file):
88
- semaphore = asyncio.BoundedSemaphore(int(SEMAPHORE_BOUND))
89
- jobs = [
90
- asyncio.create_task(embed_sent(row[embed_in_text_col], embed_in_text_col, semaphore, tei_url, temp_file))
91
- for row in input_ds if row[embed_in_text_col].strip()
92
- ]
93
- logger.info(f"num chunks to embed: {len(jobs)}")
94
-
95
- tic = time.time()
96
- await tqdm_asyncio.gather(*jobs)
97
- logger.info(f"embed time: {time.time() - tic}")
98
-
99
 
100
  def wake_up_endpoint(url):
101
  logger.info("Starting up TEI endpoint")
@@ -110,49 +82,11 @@ def wake_up_endpoint(url):
110
  raise gr.Error("TEI endpoint is unavailable")
111
  logger.info("TEI endpoint is up")
112
 
113
-
114
- def chunk_embed(input_ds, input_splits, input_text_col, chunk_out_ds,
115
- strategy, split_seq, chunk_len, embed_out_ds, tei_url, private):
116
- gr.Info("Started chunking")
117
- try:
118
- input_splits = [spl.strip() for spl in input_splits.split(",") if spl]
119
- input_ds = load_dataset(input_ds, "text-corpus", split="+".join(input_splits), token=HF_TOKEN)
120
- chunker = Chunker(strategy, split_seq, chunk_len)
121
- except Exception as e:
122
- raise gr.Error(str(e))
123
-
124
- gen_kwargs = {
125
- "input_ds": input_ds,
126
- "input_text_col": input_text_col,
127
- "chunker": chunker
128
- }
129
- chunked_ds = Dataset.from_generator(generator, gen_kwargs=gen_kwargs)
130
- chunked_ds.push_to_hub(
131
- chunk_out_ds,
132
- private=private,
133
- token=HF_TOKEN
134
- )
135
-
136
- gr.Info("Done chunking")
137
- logger.info("Done chunking")
138
-
139
- try:
140
- wake_up_endpoint(tei_url)
141
- with tempfile.NamedTemporaryFile(mode="a", suffix=".jsonl") as temp_file:
142
- asyncio.run(embed_ds(chunked_ds, tei_url, input_text_col, temp_file))
143
-
144
- embedded_ds = Dataset.from_json(temp_file.name)
145
- embedded_ds.push_to_hub(
146
- embed_out_ds,
147
- private=private,
148
- token=HF_TOKEN
149
- )
150
- except Exception as e:
151
- raise gr.Error(str(e))
152
-
153
- gr.Info("Done embedding")
154
- logger.info("Done embedding")
155
-
156
 
157
  def change_dropdown(choice):
158
  if choice == "recursive":
@@ -171,22 +105,15 @@ def change_dropdown(choice):
171
  gr.Textbox(visible=True)
172
  ]
173
 
174
-
175
  def main(args):
176
  with gr.Blocks(theme='sudeepshouche/minimalist') as demo:
177
- gr.Markdown(
178
- """
179
- ## Chunk and embed
180
- """
181
- )
182
- input_ds = gr.Textbox(lines=1, label="Input dataset name")
183
- with gr.Row():
184
- input_splits = gr.Textbox(lines=1, label="Input dataset splits", placeholder="train, test")
185
- input_text_col = gr.Textbox(lines=1, label="Input text column name", placeholder="text")
186
- chunk_out_ds = gr.Textbox(lines=1, label="Chunked dataset name")
187
  with gr.Row():
188
  dropdown = gr.Dropdown(
189
- ["recursive", "sequence", "constant"], label="Chunking strategy",
190
  info="'recursive' uses a Langchain recursive tokenizer, 'sequence' splits texts by a chosen sequence, "
191
  "'constant' makes chunks of the constant size",
192
  scale=2
@@ -208,27 +135,23 @@ def main(args):
208
  placeholder="512"
209
  )
210
  dropdown.change(fn=change_dropdown, inputs=dropdown, outputs=[split_seq, chunk_len])
211
- embed_out_ds = gr.Textbox(lines=1, label="Embedded dataset name")
212
- private = gr.Checkbox(label="Make output datasets private")
213
- tei_url = gr.Textbox(lines=1, label="TEI endpoint url")
214
  with gr.Row():
215
- clear = gr.ClearButton(
216
- components=[input_ds, input_splits, input_text_col, chunk_out_ds,
217
- dropdown, split_seq, chunk_len, embed_out_ds, tei_url, private]
218
- )
219
  embed_btn = gr.Button("Submit")
220
  embed_btn.click(
221
- fn=chunk_embed,
222
- inputs=[input_ds, input_splits, input_text_col, chunk_out_ds,
223
- dropdown, split_seq, chunk_len, embed_out_ds, tei_url, private]
224
  )
225
 
226
  demo.queue()
227
- demo.launch(auth=(USERNAME, PWD), server_name="0.0.0.0", server_port=args.port)
228
- ######
229
  if __name__ == "__main__":
230
  parser = argparse.ArgumentParser(description="A MAGIC example by ConceptaTech")
231
  parser.add_argument("--port", type=int, default=7860, help="Port to expose Gradio app")
232
-
233
- args = parser.parse_args()
234
- main(args)
 
1
  import argparse
 
2
  import asyncio
3
  import gradio as gr
4
  import numpy as np
 
11
 
12
  from aiohttp import ClientSession
13
  from langchain.text_splitter import RecursiveCharacterTextSplitter
 
 
 
 
 
 
 
 
 
 
14
 
15
 
16
  logging.basicConfig(level=logging.INFO)
17
  logger = logging.getLogger(__name__)
18
 
 
19
  class Chunker:
20
  def __init__(self, strategy, split_seq=".", chunk_len=512):
21
  self.split_seq = split_seq
22
  self.chunk_len = chunk_len
23
  if strategy == "recursive":
 
24
  self.split = RecursiveCharacterTextSplitter(
25
  chunk_size=chunk_len,
26
  separators=[split_seq]
27
  ).split_text
28
+ elif strategy == "sequence":
29
  self.split = self.seq_splitter
30
+ elif strategy == "constant":
31
  self.split = self.const_splitter
32
 
33
  def seq_splitter(self, text):
 
39
  for i in range(int(np.ceil(len(text) / self.chunk_len)))
40
  ]
41
 
42
def chunk_text(input_text, strategy, split_seq, chunk_len):
    """Split ``input_text`` into chunks with the selected strategy.

    Args:
        input_text: Text to split.
        strategy: Chunking strategy name handed to ``Chunker``
            ("recursive", "sequence" or "constant").
        split_seq: Separator sequence handed to ``Chunker``.
        chunk_len: Chunk length handed to ``Chunker``; an ``int`` or a
            string holding an integer.

    Returns:
        Whatever the selected ``Chunker.split`` callable returns for
        ``input_text``.

    Raises:
        ValueError: If ``chunk_len`` is a string that is not an integer.
    """
    # NOTE(review): chunk_len appears to come from a text field in the UI,
    # so it presumably arrives as a string - confirm against the component
    # definition. The constant-size splitter divides by chunk_len, which
    # needs a number, so normalise string input here.
    if isinstance(chunk_len, str):
        chunk_len = int(chunk_len.strip())
    return Chunker(strategy, split_seq, chunk_len).split(input_text)
46
 
47
async def embed_sent(sentence, tei_url):
    """Embed one piece of text by POSTing it to a TEI endpoint.

    Args:
        sentence: Text sent as the ``inputs`` field of the request payload.
        tei_url: URL the request is posted to.

    Returns:
        The first element of the JSON body returned by the endpoint.

    Raises:
        RuntimeError: If the endpoint answers with a status other than 200;
            the message is the response body text.
    """
    # Function-scope import keeps this block self-contained.
    import os

    # This module does not define a global HF_TOKEN, so look the token up in
    # the environment at call time instead of failing with a NameError.
    token = os.getenv("HF_TOKEN")
    headers = {"Content-Type": "application/json"}
    if token:
        # Only authenticate when a token is configured; otherwise the header
        # would carry the literal string "Bearer None".
        headers["Authorization"] = f"Bearer {token}"

    payload = {"inputs": sentence, "truncate": True}
    async with ClientSession(headers=headers) as session:
        async with session.post(tei_url, json=payload) as resp:
            if resp.status != 200:
                raise RuntimeError(await resp.text())
            result = await resp.json()
    return result[0]
63
+
64
async def embed_first_sentence(chunks, tei_url):
    """Embed only the first chunk of ``chunks``.

    Args:
        chunks: Sequence of text chunks.
        tei_url: URL of the TEI endpoint, passed through to ``embed_sent``.

    Returns:
        A ``(first_chunk, embedding)`` pair, or ``([], [])`` when ``chunks``
        is empty.
    """
    if chunks:
        head = chunks[0]
        vector = await embed_sent(head, tei_url)
        return head, vector
    # Nothing to embed: hand back empty placeholders for both slots.
    return [], []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
  def wake_up_endpoint(url):
73
  logger.info("Starting up TEI endpoint")
 
82
  raise gr.Error("TEI endpoint is unavailable")
83
  logger.info("TEI endpoint is up")
84
 
85
async def process_text(input_text, strategy, split_seq, chunk_len, tei_url):
    """Chunk the input text and embed its first chunk.

    Calls ``wake_up_endpoint``, splits ``input_text`` with ``chunk_text``,
    then embeds the first chunk through ``embed_first_sentence``.

    Args:
        input_text: Text to chunk.
        strategy: Chunking strategy name.
        split_seq: Separator sequence for the splitter.
        chunk_len: Chunk length for the splitter.
        tei_url: URL of the TEI endpoint.

    Returns:
        A ``(chunks, first_sentence, embedded_sentence)`` tuple.

    Raises:
        gr.Error: If waking the endpoint, chunking or embedding fails; the
            message is the text of the underlying exception.
    """
    try:
        wake_up_endpoint(tei_url)
        chunks = chunk_text(input_text, strategy, split_seq, chunk_len)
        first_sentence, embedded_sentence = await embed_first_sentence(chunks, tei_url)
    except gr.Error:
        # Already a user-facing error (e.g. endpoint unavailable): pass through.
        raise
    except Exception as e:
        # Convert to gr.Error so the UI shows the actual failure message,
        # matching how wake_up_endpoint reports problems.
        logger.exception("Failed to chunk and embed the input text")
        raise gr.Error(str(e)) from e
    return chunks, first_sentence, embedded_sentence
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
  def change_dropdown(choice):
92
  if choice == "recursive":
 
105
  gr.Textbox(visible=True)
106
  ]
107
 
 
108
  def main(args):
109
  with gr.Blocks(theme='sudeepshouche/minimalist') as demo:
110
+ gr.Markdown("## Chunk and Embed")
111
+
112
+ input_text = gr.Textbox(lines=5, label="Input Text")
113
+
 
 
 
 
 
 
114
  with gr.Row():
115
  dropdown = gr.Dropdown(
116
+ ["recursive", "sequence", "constant"], label="Chunking Strategy",
117
  info="'recursive' uses a Langchain recursive tokenizer, 'sequence' splits texts by a chosen sequence, "
118
  "'constant' makes chunks of the constant size",
119
  scale=2
 
135
  placeholder="512"
136
  )
137
  dropdown.change(fn=change_dropdown, inputs=dropdown, outputs=[split_seq, chunk_len])
138
+
139
+ tei_url = gr.Textbox(lines=1, label="TEI Endpoint URL")
140
+
141
  with gr.Row():
142
+ clear = gr.ClearButton(components=[input_text, dropdown, split_seq, chunk_len, tei_url])
 
 
 
143
  embed_btn = gr.Button("Submit")
144
  embed_btn.click(
145
+ fn=process_text,
146
+ inputs=[input_text, dropdown, split_seq, chunk_len, tei_url],
147
+ outputs=[gr.JSON(label="Chunks"), gr.Textbox(label="First Chunked Sentence"), gr.JSON(label="Embedded Sentence")]
148
  )
149
 
150
  demo.queue()
151
+ demo.launch(server_name="0.0.0.0", server_port=args.port)
152
+
153
  if __name__ == "__main__":
154
  parser = argparse.ArgumentParser(description="A MAGIC example by ConceptaTech")
155
  parser.add_argument("--port", type=int, default=7860, help="Port to expose Gradio app")
156
+ args = parser.parse_args()
157
+ main(args)