plaggy committed
Commit 8b7a023
1 Parent(s): c618bd9
Files changed (4)
  1. chunking_utils.py +42 -0
  2. embed_utils.py +55 -0
  3. main.py +39 -136
  4. models.py +19 -0
chunking_utils.py ADDED
@@ -0,0 +1,42 @@
+ import json
+ import numpy as np
+
+ from tqdm import tqdm
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+
+ from models import env_config
+
+
+ class Chunker:
+     def __init__(self, strategy, split_seq=".", chunk_len=512):
+         self.split_seq = split_seq
+         self.chunk_len = chunk_len
+         if strategy == "recursive":
+             self.split = RecursiveCharacterTextSplitter(
+                 chunk_size=chunk_len,
+                 separators=[split_seq]
+             ).split_text
+         if strategy == "sequence":
+             self.split = self.seq_splitter
+         if strategy == "constant":
+             self.split = self.const_splitter
+
+     def seq_splitter(self, text):
+         return text.split(self.split_seq)
+
+     def const_splitter(self, text):
+         return [
+             text[i * self.chunk_len:(i + 1) * self.chunk_len]
+             for i in range(int(np.ceil(len(text) / self.chunk_len)))
+         ]
+
+
+ def chunk_generator(input_dataset, chunker, tmp_file):
+     for i in tqdm(range(len(input_dataset))):
+         chunks = chunker.split(input_dataset[i][env_config.input_text_col])
+         for chunk in chunks:
+             if chunk:
+                 tmp_file.write(
+                     json.dumps({env_config.input_text_col: chunk}) + "\n"
+                 )
+                 yield {env_config.input_text_col: chunk}
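
A quick sketch of how the extracted Chunker behaves on its own (illustrative only; it assumes the Space's configs/ files and environment variables are in place so that importing models works, and the sample text is made up):

# Illustrative only: comparing the three chunking strategies on a made-up string.
from chunking_utils import Chunker

sample = "First sentence. Second sentence. Third sentence."
for strategy in ("recursive", "sequence", "constant"):
    chunker = Chunker(strategy=strategy, split_seq=".", chunk_len=20)
    print(strategy, chunker.split(sample))

The "recursive" strategy delegates to LangChain's RecursiveCharacterTextSplitter with a single separator, "sequence" is a plain str.split, and "constant" cuts fixed-length slices regardless of sentence boundaries.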
embed_utils.py ADDED
@@ -0,0 +1,55 @@
+ import json
+ import asyncio
+ import logging
+ import time
+
+ from tqdm.asyncio import tqdm_asyncio
+ from huggingface_hub import get_inference_endpoint
+
+ from models import env_config, embed_config
+
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ endpoint = get_inference_endpoint(env_config.tei_name, token=env_config.hf_token)
+
+
+ async def embed_chunk(sentence, semaphore, tmp_file):
+     async with semaphore:
+         payload = {
+             "inputs": sentence,
+             "truncate": True
+         }
+
+         try:
+             resp = await endpoint.async_client.post(json=payload)
+         except Exception as e:
+             raise RuntimeError(str(e))
+
+         result = json.loads(resp)
+         tmp_file.write(
+             json.dumps({"vector": result[0], env_config.input_text_col: sentence}) + "\n"
+         )
+
+
+ async def embed_wrapper(input_ds, temp_file):
+     semaphore = asyncio.BoundedSemaphore(embed_config.semaphore_bound)
+     jobs = [
+         asyncio.create_task(embed_chunk(row[env_config.input_text_col], semaphore, temp_file))
+         for row in input_ds if row[env_config.input_text_col].strip()
+     ]
+     logger.info(f"num chunks to embed: {len(jobs)}")
+
+     tic = time.time()
+     await tqdm_asyncio.gather(*jobs)
+     logger.info(f"embed time: {time.time() - tic}")
+
+
+ def wake_up_endpoint():
+     endpoint.fetch()
+     if endpoint.status != 'running':
+         logger.info("Starting up TEI endpoint")
+         endpoint.resume()
+         endpoint.wait()
+     logger.info("TEI endpoint is up")
+     return
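
A rough sketch of driving the new helpers outside the webhook flow, e.g. for a manual re-embedding (illustrative only; it assumes the Space's environment variables and config files are set, that the TEI Inference Endpoint named in TEI_NAME exists, and that the column name matches INPUT_TEXT_COL):

# Illustrative only: embedding a tiny in-memory dataset with the new helpers.
import asyncio
import tempfile

from datasets import Dataset
from embed_utils import wake_up_endpoint, embed_wrapper

chunked_ds = Dataset.from_dict({"text": ["a first chunk", "a second chunk"]})  # "text" must equal INPUT_TEXT_COL

wake_up_endpoint()  # resumes the endpoint if paused and waits until it reports running
with tempfile.NamedTemporaryFile(mode="a", suffix=".jsonl") as tmp:
    asyncio.run(embed_wrapper(chunked_ds, tmp))  # one request per non-empty chunk, bounded by embed_config.semaphore_bound
    print(Dataset.from_json(tmp.name))  # rows of {"vector": [...], "text": ...}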
main.py CHANGED
@@ -1,46 +1,32 @@
  import asyncio
  import logging
- import numpy as np
- import time
- import json
- import os
  import tempfile
- import requests

  from fastapi import FastAPI, Request, BackgroundTasks
  from fastapi.responses import HTMLResponse
  from fastapi.staticfiles import StaticFiles
  from fastapi.templating import Jinja2Templates
-
- from aiohttp import ClientSession
- from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from contextlib import asynccontextmanager
  from datasets import Dataset, load_dataset
- from tqdm import tqdm
- from tqdm.asyncio import tqdm_asyncio

- from models import chunk_config, embed_config, WebhookPayload
+ from models import chunk_config, embed_config, env_config, WebhookPayload
+ from chunking_utils import Chunker, chunk_generator
+ from embed_utils import wake_up_endpoint, embed_wrapper

  logging.basicConfig(level=logging.INFO)
  logger = logging.getLogger(__name__)

- # you token from Settings
- HF_TOKEN = os.getenv("HF_TOKEN")
+ app_state = {}
+

- # URL of TEI endpoint
- TEI_URL = os.getenv("TEI_URL")
- # name of chunked dataset
- CHUNKED_DS_NAME = os.getenv("CHUNKED_DS_NAME")
- # name of embeddings dataset
- EMBED_DS_NAME = os.getenv("EMBED_DS_NAME")
- # splits of input dataset to process, comma separated
- INPUT_SPLITS = os.getenv("INPUT_SPLITS")
- # name of column to load from input dataset
- INPUT_TEXT_COL = os.getenv("INPUT_TEXT_COL")
+ @asynccontextmanager
+ async def lifespan(app: FastAPI):
+     app_state["seen_Sha"] = set()
+     yield
+     app_state.clear()

- INPUT_SPLITS = [spl.strip() for spl in INPUT_SPLITS.split(",") if spl]

- app = FastAPI()
- app.state.seen_Sha = set()
+ app = FastAPI(lifespan=lifespan)

  app.mount("/static", StaticFiles(directory="static"), name="static")
  templates = Jinja2Templates(directory="templates")
@@ -61,151 +47,68 @@ async def post_webhook(
          and payload.event.scope.startswith("repo.content")
          and payload.repo.type == "dataset"
          # webhook posts multiple requests with the same update, this addresses that
-         and payload.repo.headSha not in app.state.seen_Sha
+         and payload.repo.headSha not in app_state["seen_Sha"]
      ):
-         # no-op
          logger.info("Update detected, no action taken")
          return {"processed": False}

-     app.state.seen_Sha.add(payload.repo.headSha)
-     task_queue.add_task(chunk_dataset, ds_name=payload.repo.name)
-     task_queue.add_task(embed_dataset, ds_name=CHUNKED_DS_NAME)
+     app_state["seen_Sha"].add(payload.repo.headSha)
+     task_queue.add_task(chunk_and_embed, input_ds_name=payload.repo.name)

      return {"processed": True}


- """
- CHUNKING
- """
-
- class Chunker:
-     def __init__(self, strategy, split_seq=".", chunk_len=512):
-         self.split_seq = split_seq
-         self.chunk_len = chunk_len
-         if strategy == "recursive":
-             self.split = RecursiveCharacterTextSplitter(
-                 chunk_size=chunk_len,
-                 separators=[split_seq]
-             ).split_text
-         if strategy == "sequence":
-             self.split = self.seq_splitter
-         if strategy == "constant":
-             self.split = self.const_splitter
-
-     def seq_splitter(self, text):
-         return text.split(self.split_seq)
-
-     def const_splitter(self, text):
-         return [
-             text[i * self.chunk_len:(i + 1) * self.chunk_len]
-             for i in range(int(np.ceil(len(text) / self.chunk_len)))
-         ]
-
-
- def chunk_generator(input_dataset, chunker):
-     for i in tqdm(range(len(input_dataset))):
-         chunks = chunker.split(input_dataset[i][INPUT_TEXT_COL])
-         for chunk in chunks:
-             if chunk:
-                 yield {INPUT_TEXT_COL: chunk}
-
-
- def chunk_dataset(ds_name):
+ def chunk(ds_name):
      logger.info("Update detected, chunking is scheduled")
-     input_ds = load_dataset(ds_name, split="+".join(INPUT_SPLITS))
+     input_ds = load_dataset(ds_name, split="+".join(env_config.input_splits))
      chunker = Chunker(
          strategy=chunk_config.strategy,
          split_seq=chunk_config.split_seq,
          chunk_len=chunk_config.chunk_len
      )
-
+     tmp_file = tempfile.NamedTemporaryFile(mode="a", suffix=".jsonl")
      dataset = Dataset.from_generator(
          chunk_generator,
          gen_kwargs={
              "input_dataset": input_ds,
-             "chunker": chunker
+             "chunker": chunker,
+             "tmp_file": tmp_file
          }
      )

      dataset.push_to_hub(
-         CHUNKED_DS_NAME,
+         env_config.chunked_ds_name,
          private=chunk_config.private,
-         token=HF_TOKEN
+         token=env_config.hf_token
      )

      logger.info("Done chunking")
-     return {"processed": True}
-
-
- """
- EMBEDDING
- """
+     return tmp_file

- async def embed_sent(sentence, semaphore, tmp_file):
-     async with semaphore:
-         payload = {
-             "inputs": sentence,
-             "truncate": True
-         }

-         async with ClientSession(
-             headers={
-                 "Content-Type": "application/json",
-                 "Authorization": f"Bearer {HF_TOKEN}"
-             }
-         ) as session:
-             async with session.post(TEI_URL, json=payload) as resp:
-                 if resp.status != 200:
-                     raise RuntimeError(await resp.text())
-                 result = await resp.json()
-
-         tmp_file.write(
-             json.dumps({"vector": result[0], INPUT_TEXT_COL: sentence}) + "\n"
-         )
-
-
- async def embed(input_ds, temp_file):
-     semaphore = asyncio.BoundedSemaphore(embed_config.semaphore_bound)
-     jobs = [
-         asyncio.create_task(embed_sent(row[INPUT_TEXT_COL], semaphore, temp_file))
-         for row in input_ds if row[INPUT_TEXT_COL].strip()
-     ]
-     logger.info(f"num chunks to embed: {len(jobs)}")
-
-     tic = time.time()
-     await tqdm_asyncio.gather(*jobs)
-     logger.info(f"embed time: {time.time() - tic}")
-
-
- def wake_up_endpoint(url):
-     logger.info("Starting up TEI endpoint")
-     n_loop = 0
-     while requests.get(
-         url=url,
-         headers={"Authorization": f"Bearer {HF_TOKEN}"}
-     ).status_code != 200:
-         time.sleep(2)
-         n_loop += 1
-         if n_loop > 40:
-             raise TimeoutError("TEI endpoint is unavailable")
-     logger.info("TEI endpoint is up")
-
-
- def embed_dataset(ds_name):
+ def embed(chunked_file):
      logger.info("Update detected, embedding is scheduled")
-     wake_up_endpoint(TEI_URL)
-     input_ds = load_dataset(ds_name, split="train")
+     wake_up_endpoint()
+     chunked_ds = Dataset.from_json(chunked_file.name)
      with tempfile.NamedTemporaryFile(mode="a", suffix=".jsonl") as temp_file:
-         asyncio.run(embed(input_ds, temp_file))
+         asyncio.run(embed_wrapper(chunked_ds, temp_file))

-         dataset = Dataset.from_json(temp_file.name)
-         dataset.push_to_hub(
-             EMBED_DS_NAME,
+         emb_ds = Dataset.from_json(temp_file.name)
+         emb_ds.push_to_hub(
+             env_config.embed_ds_name,
              private=embed_config.private,
-             token=HF_TOKEN
+             token=env_config.hf_token
          )

+     chunked_file.close()
      logger.info("Done embedding")
+     return
+
+
+ def chunk_and_embed(input_ds_name):
+     chunked_tmp_file = chunk(input_ds_name)
+     embed(chunked_tmp_file)
+
      return {"processed": True}

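post_webhook only schedules chunk_and_embed for content updates on dataset repos whose headSha has not been seen before; anything else is a no-op. A hedged sketch of a request body that would pass those checks (field values are made up, the "action" value is assumed, and the exact schema is whatever WebhookPayload in models.py validates):

# Illustrative only: the shape of a webhook payload that triggers chunk_and_embed.
example_payload = {
    "event": {"action": "update", "scope": "repo.content"},  # "action" value assumed
    "repo": {
        "type": "dataset",
        "name": "some-user/raw-docs",      # input dataset; becomes ds_name in chunk()
        "headSha": "0123456789abcdef",     # deduplicated via app_state["seen_Sha"]
    },
}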
models.py CHANGED
@@ -4,6 +4,21 @@ from pydantic import BaseModel
  from typing import Literal


+ class EnvConfig(BaseModel):
+     # your token from Settings
+     hf_token: str = os.getenv("HF_TOKEN")
+     # NAME of the TEI endpoint
+     tei_name: str = os.getenv("TEI_NAME")
+     # name of chunked dataset
+     chunked_ds_name: str = os.getenv("CHUNKED_DS_NAME")
+     # name of embeddings dataset
+     embed_ds_name: str = os.getenv("EMBED_DS_NAME")
+     # splits of input dataset to process, comma separated
+     input_splits: str = os.getenv("INPUT_SPLITS")
+     # name of column to load from input dataset
+     input_text_col: str = os.getenv("INPUT_TEXT_COL")
+
+
  class ChunkConfig(BaseModel):
      strategy: Literal["recursive", "sequence", "constant"]
      split_seq: str
@@ -41,3 +56,7 @@ with open(os.path.join(os.getcwd(), "configs/chunk_config.json")) as c:
  with open(os.path.join(os.getcwd(), "configs/embed_config.json")) as c:
      data = json.load(c)
      embed_config = EmbedConfig.model_validate_json(json.dumps(data))
+
+
+ env_config = EnvConfig()
+ env_config.input_splits = [spl.strip() for spl in env_config.input_splits.split(",") if spl]
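
For reference, the environment EnvConfig reads at import time; a sketch with placeholder values (in the Space these come from the Settings UI, and TEI_NAME is the Inference Endpoint's name, not a URL):

# Illustrative only: placeholder values for the variables EnvConfig reads.
import os

os.environ["HF_TOKEN"] = "hf_placeholder"           # token from Settings
os.environ["TEI_NAME"] = "my-tei-endpoint"          # name of the TEI endpoint
os.environ["CHUNKED_DS_NAME"] = "user/chunked"      # chunked dataset to push
os.environ["EMBED_DS_NAME"] = "user/embeddings"     # embeddings dataset to push
os.environ["INPUT_SPLITS"] = "train,test"           # comma-separated input splits
os.environ["INPUT_TEXT_COL"] = "text"               # text column of the input dataset

from models import env_config                       # also needs the configs/*.json files to be present
print(env_config.input_splits)                      # ["train", "test"] after the post-processing above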