rodrigomasini commited on
Commit
d835ff5
1 Parent(s): e6561b1

Update app/main.py

Browse files
Files changed (1) hide show
  1. app/main.py +217 -22
app/main.py CHANGED
@@ -1,36 +1,231 @@
1
  import argparse
 
 
2
  import gradio as gr
3
- from ui import chat
 
 
4
  import os
5
- from dotenv import load_dotenv
 
 
 
 
 
 
 
 
6
 
7
  load_dotenv()
8
 
9
  USERNAME = os.getenv("USERNAME")
10
  PWD = os.getenv("USER_PWD")
 
 
11
 
12
- def main(args):
13
- demo = gr.ChatInterface(
14
- fn=chat,
15
- examples=["Explain the AI adoption challenges for enterprises.", "How can we identify a fraud transaction?", "Por que os grandes modelos de linguagem de AI halucinam?"],
16
- title="Chat and LLM server in the same application",
17
- description="This space is a template that we can duplicate for your own usage. "
18
- "This space let you build LLM powered idea on top of [Gradio](https://www.gradio.app/) "
19
- "and open LLM served locally by [TGI(Text Generation Inference)](https://huggingface.co/docs/text-generation-inference/en/index). "
20
- "Below is a placeholder Gradio ChatInterface for you to try out Mistral-7B backed by the power of TGI's efficiency. \n\n"
21
- "To use this space for your own usecase, follow the simple steps below:\n"
22
- "1. Duplicate this space. \n"
23
- "2. Set which LLM you wish to use (i.e. mistralai/Mistral-7B-Instruct-v0.2). \n"
24
- "3. Inside app/main.py write Gradio application. \n",
25
- multimodal=False,
26
- theme='sudeepshouche/minimalist',
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  )
28
-
29
- demo.queue(
30
- default_concurrency_limit=20,
31
- max_size=256
32
- ).launch(auth=(USERNAME, PWD), server_name="0.0.0.0", server_port=args.port)
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  if __name__ == "__main__":
35
  parser = argparse.ArgumentParser(description="A MAGIC example by ConceptaTech")
36
  parser.add_argument("--port", type=int, default=7860, help="Port to expose Gradio app")
 
1
  import argparse
2
+ from dotenv import load_dotenv
3
+ import asyncio
4
  import gradio as gr
5
+ import numpy as np
6
+ import time
7
+ import json
8
  import os
9
+ import tempfile
10
+ import requests
11
+ import logging
12
+
13
+ from aiohttp import ClientSession
14
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
15
+ from datasets import Dataset, load_dataset
16
+ from tqdm import tqdm
17
+ from tqdm.asyncio import tqdm_asyncio
18
 
19
  load_dotenv()
20
 
21
  USERNAME = os.getenv("USERNAME")
22
  PWD = os.getenv("USER_PWD")
23
+ HF_TOKEN = os.getenv("HF_TOKEN")
24
+ SEMAPHORE_BOUND = os.getenv("SEMAPHORE_BOUND", "5")
25
 
26
+
27
+ logging.basicConfig(level=logging.INFO)
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
+ class Chunker:
32
+ def __init__(self, strategy, split_seq=".", chunk_len=512):
33
+ self.split_seq = split_seq
34
+ self.chunk_len = chunk_len
35
+ if strategy == "recursive":
36
+ # https://huggingface.co/spaces/m-ric/chunk_visualizer
37
+ self.split = RecursiveCharacterTextSplitter(
38
+ chunk_size=chunk_len,
39
+ separators=[split_seq]
40
+ ).split_text
41
+ if strategy == "sequence":
42
+ self.split = self.seq_splitter
43
+ if strategy == "constant":
44
+ self.split = self.const_splitter
45
+
46
+ def seq_splitter(self, text):
47
+ return text.split(self.split_seq)
48
+
49
+ def const_splitter(self, text):
50
+ return [
51
+ text[i * self.chunk_len:(i + 1) * self.chunk_len]
52
+ for i in range(int(np.ceil(len(text) / self.chunk_len)))
53
+ ]
54
+
55
+
56
+ def generator(input_ds, input_text_col, chunker):
57
+ for i in tqdm(range(len(input_ds))):
58
+ chunks = chunker.split(input_ds[i][input_text_col])
59
+ for chunk in chunks:
60
+ if chunk:
61
+ yield {input_text_col: chunk}
62
+
63
+
64
+ async def embed_sent(sentence, embed_in_text_col, semaphore, tei_url, tmp_file):
65
+ async with semaphore:
66
+ payload = {
67
+ "inputs": sentence,
68
+ "truncate": True
69
+ }
70
+
71
+ async with ClientSession(
72
+ headers={
73
+ "Content-Type": "application/json",
74
+ "Authorization": f"Bearer {HF_TOKEN}"
75
+ }
76
+ ) as session:
77
+ async with session.post(tei_url, json=payload) as resp:
78
+ if resp.status != 200:
79
+ raise RuntimeError(await resp.text())
80
+ result = await resp.json()
81
+
82
+ tmp_file.write(
83
+ json.dumps({"vector": result[0], embed_in_text_col: sentence}) + "\n"
84
+ )
85
+
86
+
87
+ async def embed_ds(input_ds, tei_url, embed_in_text_col, temp_file):
88
+ semaphore = asyncio.BoundedSemaphore(int(SEMAPHORE_BOUND))
89
+ jobs = [
90
+ asyncio.create_task(embed_sent(row[embed_in_text_col], embed_in_text_col, semaphore, tei_url, temp_file))
91
+ for row in input_ds if row[embed_in_text_col].strip()
92
+ ]
93
+ logger.info(f"num chunks to embed: {len(jobs)}")
94
+
95
+ tic = time.time()
96
+ await tqdm_asyncio.gather(*jobs)
97
+ logger.info(f"embed time: {time.time() - tic}")
98
+
99
+
100
+ def wake_up_endpoint(url):
101
+ logger.info("Starting up TEI endpoint")
102
+ n_loop = 0
103
+ while requests.get(
104
+ url=url,
105
+ headers={"Authorization": f"Bearer {HF_TOKEN}"}
106
+ ).status_code != 200:
107
+ time.sleep(2)
108
+ n_loop += 1
109
+ if n_loop > 40:
110
+ raise gr.Error("TEI endpoint is unavailable")
111
+ logger.info("TEI endpoint is up")
112
+
113
+
114
+ def chunk_embed(input_ds, input_splits, input_text_col, chunk_out_ds,
115
+ strategy, split_seq, chunk_len, embed_out_ds, tei_url, private):
116
+ gr.Info("Started chunking")
117
+ try:
118
+ input_splits = [spl.strip() for spl in input_splits.split(",") if spl]
119
+ input_ds = load_dataset(input_ds, split="+".join(input_splits), token=HF_TOKEN)
120
+ chunker = Chunker(strategy, split_seq, chunk_len)
121
+ except Exception as e:
122
+ raise gr.Error(str(e))
123
+
124
+ gen_kwargs = {
125
+ "input_ds": input_ds,
126
+ "input_text_col": input_text_col,
127
+ "chunker": chunker
128
+ }
129
+ chunked_ds = Dataset.from_generator(generator, gen_kwargs=gen_kwargs)
130
+ chunked_ds.push_to_hub(
131
+ chunk_out_ds,
132
+ private=private,
133
+ token=HF_TOKEN
134
  )
 
 
 
 
 
135
 
136
+ gr.Info("Done chunking")
137
+ logger.info("Done chunking")
138
+
139
+ try:
140
+ wake_up_endpoint(tei_url)
141
+ with tempfile.NamedTemporaryFile(mode="a", suffix=".jsonl") as temp_file:
142
+ asyncio.run(embed_ds(chunked_ds, tei_url, input_text_col, temp_file))
143
+
144
+ embedded_ds = Dataset.from_json(temp_file.name)
145
+ embedded_ds.push_to_hub(
146
+ embed_out_ds,
147
+ private=private,
148
+ token=HF_TOKEN
149
+ )
150
+ except Exception as e:
151
+ raise gr.Error(str(e))
152
+
153
+ gr.Info("Done embedding")
154
+ logger.info("Done embedding")
155
+
156
+
157
+ def change_dropdown(choice):
158
+ if choice == "recursive":
159
+ return [
160
+ gr.Textbox(visible=True),
161
+ gr.Textbox(visible=True)
162
+ ]
163
+ elif choice == "sequence":
164
+ return [
165
+ gr.Textbox(visible=True),
166
+ gr.Textbox(visible=False)
167
+ ]
168
+ else:
169
+ return [
170
+ gr.Textbox(visible=False),
171
+ gr.Textbox(visible=True)
172
+ ]
173
+
174
+
175
+ def main(args):
176
+ demo= gr.Blocks(theme='sudeepshouche/minimalist'):
177
+ gr.Markdown(
178
+ """
179
+ ## Chunk and embed
180
+ """
181
+ )
182
+ input_ds = gr.Textbox(lines=1, label="Input dataset name")
183
+ with gr.Row():
184
+ input_splits = gr.Textbox(lines=1, label="Input dataset splits", placeholder="train, test")
185
+ input_text_col = gr.Textbox(lines=1, label="Input text column name", placeholder="text")
186
+ chunk_out_ds = gr.Textbox(lines=1, label="Chunked dataset name")
187
+ with gr.Row():
188
+ dropdown = gr.Dropdown(
189
+ ["recursive", "sequence", "constant"], label="Chunking strategy",
190
+ info="'recursive' uses a Langchain recursive tokenizer, 'sequence' splits texts by a chosen sequence, "
191
+ "'constant' makes chunks of the constant size",
192
+ scale=2
193
+ )
194
+ split_seq = gr.Textbox(
195
+ lines=1,
196
+ interactive=True,
197
+ visible=False,
198
+ label="Sequence",
199
+ info="A text sequence to split on",
200
+ placeholder="\n\n"
201
+ )
202
+ chunk_len = gr.Textbox(
203
+ lines=1,
204
+ interactive=True,
205
+ visible=False,
206
+ label="Length",
207
+ info="The length of chunks to split into in characters",
208
+ placeholder="512"
209
+ )
210
+ dropdown.change(fn=change_dropdown, inputs=dropdown, outputs=[split_seq, chunk_len])
211
+ embed_out_ds = gr.Textbox(lines=1, label="Embedded dataset name")
212
+ private = gr.Checkbox(label="Make output datasets private")
213
+ tei_url = gr.Textbox(lines=1, label="TEI endpoint url")
214
+ with gr.Row():
215
+ clear = gr.ClearButton(
216
+ components=[input_ds, input_splits, input_text_col, chunk_out_ds,
217
+ dropdown, split_seq, chunk_len, embed_out_ds, tei_url, private]
218
+ )
219
+ embed_btn = gr.Button("Submit")
220
+ embed_btn.click(
221
+ fn=chunk_embed,
222
+ inputs=[input_ds, input_splits, input_text_col, chunk_out_ds,
223
+ dropdown, split_seq, chunk_len, embed_out_ds, tei_url, private]
224
+ )
225
+
226
+ demo.queue()
227
+ demo.launch(auth=(USERNAME, PWD), server_name="0.0.0.0", server_port=args.port)
228
+ ######
229
  if __name__ == "__main__":
230
  parser = argparse.ArgumentParser(description="A MAGIC example by ConceptaTech")
231
  parser.add_argument("--port", type=int, default=7860, help="Port to expose Gradio app")