plaggy committed on
Commit
97fdba5
1 Parent(s): 40b241b
Files changed (9)
  1. Dockerfile +16 -0
  2. chunk_config.json +10 -0
  3. embed_config.json +8 -0
  4. home.html +18 -0
  5. requirements.txt +10 -0
  6. src/__init__.py +0 -0
  7. src/main.py +189 -0
  8. src/models.py +51 -0
  9. style.css +28 -0
Dockerfile ADDED
@@ -0,0 +1,16 @@
+ FROM python:3.9
+
+ RUN useradd -m -u 1000 user
+ USER user
+
+ ENV HOME=/home/user \
+     PATH=/home/user/.local/bin:$PATH
+
+ WORKDIR $HOME/app
+
+ COPY --chown=user requirements.txt requirements.txt
+ RUN pip install --no-cache-dir --upgrade -r requirements.txt && python -m spacy download en_core_web_sm
+
+ COPY --chown=user . .
+
+ CMD ["uvicorn", "src.main:app", "--host", "0.0.0.0", "--port", "7860"]
chunk_config.json ADDED
@@ -0,0 +1,10 @@
+ {
+     "input_dataset": "sergeipetrov/transformers-diffusers-docs-raw",
+     "input_splits": ["train"],
+     "input_text_col": "text",
+     "output_dataset": "sergeipetrov/transformers-diffusers-docs-chunked",
+     "strategy": "spacy",
+     "split_seq": "\n\n",
+     "chunk_len": 512,
+     "private": false
+ }
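Note: the three accepted "strategy" values map onto the Chunker class added in src/main.py below. A minimal illustrative sketch of what each one does (not part of the commit; the spaCy variant additionally needs the en_core_web_sm pipeline installed):

text = "First paragraph.\n\nSecond paragraph."

# "sequence": split on the configured split_seq
print(text.split("\n\n"))  # ['First paragraph.', 'Second paragraph.']

# "constant": fixed-size character windows of chunk_len
chunk_len = 16
print([text[i:i + chunk_len] for i in range(0, len(text), chunk_len)])

# "spacy": SpacyTextSplitter().split_text(text) returns sentence-aware chunks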
embed_config.json ADDED
@@ -0,0 +1,8 @@
+ {
+     "input_dataset": "sergeipetrov/transformers-diffusers-docs-chunked",
+     "input_splits": ["train"],
+     "input_text_col": "text",
+     "output_dataset": "sergeipetrov/transformers-diffusers-docs-embed",
+     "private": false,
+     "semaphore_bound": 5
+ }
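Note: semaphore_bound caps the number of in-flight requests against the TEI endpoint (see embed_sent/embed in src/main.py below). A hedged smoke-test sketch for the endpoint, using the same payload shape the app sends; TEI_URL and HF_TOKEN are the env vars the Space itself reads, not values defined in this file:

import os
import requests

# send one chunk through the TEI endpoint, same payload as embed_sent
resp = requests.post(
    os.environ["TEI_URL"],
    headers={"Authorization": f"Bearer {os.environ['HF_TOKEN']}"},
    json={"inputs": "hello world", "truncate": True},
)
resp.raise_for_status()
print(len(resp.json()[0]))  # embedding dimension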
home.html ADDED
@@ -0,0 +1,18 @@
+ <!DOCTYPE html>
+ <html>
+ <head>
+     <meta charset="utf-8" />
+     <meta name="viewport" content="width=device-width" />
+     <title>Auto Re-Train</title>
+     <link rel="stylesheet" href="style.css" />
+ </head>
+ <body>
+     <div class="card">
+         <h1>Auto Re-Train webhook</h1>
+
+         <p>This is a webhook space to auto-retrain a model when a dataset changes.</p>
+
+         <p>Check out the guide <a href="https://huggingface.co/docs/hub/webhooks-guide-auto-retrain" target="_blank">here</a>!</p>
+     </div>
+ </body>
+ </html>
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ fastapi==0.74.*
+ requests==2.27.*
+ huggingface_hub==0.20.*
+ uvicorn[standard]==0.17.*
+ numpy==1.25.*
+ datasets==2.16.*
+ langchain==0.0.*
+ aiohttp==3.8.*
+ spacy==3.*
+ tqdm==4.*
src/__init__.py ADDED
File without changes
src/main.py ADDED
@@ -0,0 +1,189 @@
+ import asyncio
+ import logging
+ import numpy as np
+ import time
+ import json
+ import os
+ import tempfile
+ import requests
+
+ from fastapi import FastAPI, BackgroundTasks
+ from fastapi.responses import FileResponse
+
+ from aiohttp import ClientSession
+ from langchain.text_splitter import SpacyTextSplitter
+ from datasets import Dataset, load_dataset
+ from tqdm import tqdm
+ from tqdm.asyncio import tqdm_asyncio
+
+ from src.models import chunk_config, embed_config, WebhookPayload
+
+
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ HF_TOKEN = os.getenv("HF_TOKEN")
+ TEI_URL = os.getenv("TEI_URL")
+
+ app = FastAPI()
+
+
+ @app.get("/")
+ async def home():
+     return FileResponse("home.html")
+
+
+ @app.post("/webhook")
+ async def post_webhook(
+     payload: WebhookPayload,
+     task_queue: BackgroundTasks
+ ):
+     if not (
+         payload.event.action == "update"
+         and payload.event.scope.startswith("repo.content")
+         and (
+             payload.repo.name == embed_config.input_dataset
+             # or payload.repo.name == chunk_config.input_dataset
+         )
+         and payload.repo.type == "dataset"
+     ):
+         # no-op
+         logger.info("Update detected, no action taken")
+         return {"processed": False}
+
+     if payload.repo.name == chunk_config.input_dataset:
+         task_queue.add_task(chunk_dataset)
+     elif payload.repo.name == embed_config.input_dataset:
+         task_queue.add_task(embed_dataset)
+
+     return {"processed": True}
+
+
+ """
+ CHUNKING
+ """
+
+ class Chunker:
+     def __init__(self, strategy, split_seq, chunk_len):
+         self.split_seq = split_seq
+         self.chunk_len = chunk_len
+         if strategy == "spacy":
+             self.split = SpacyTextSplitter().split_text
+         if strategy == "sequence":
+             self.split = self.seq_splitter
+         if strategy == "constant":
+             self.split = self.const_splitter
+
+     def seq_splitter(self, text):
+         return text.split(self.split_seq)
+
+     def const_splitter(self, text):
+         return [
+             text[i * self.chunk_len:(i + 1) * self.chunk_len]
+             for i in range(int(np.ceil(len(text) / self.chunk_len)))
+         ]
+
+
+ def chunk_generator(input_dataset, chunker):
+     for i in tqdm(range(len(input_dataset))):
+         chunks = chunker.split(input_dataset[i][chunk_config.input_text_col])
+         for chunk in chunks:
+             if chunk:
+                 yield {chunk_config.input_text_col: chunk}
+
+
+ def chunk_dataset():
+     logger.info("Update detected, chunking is scheduled")
+     # a list split returns a list of datasets; join into one split expression
+     input_ds = load_dataset(chunk_config.input_dataset, split="+".join(chunk_config.input_splits))
+     chunker = Chunker(
+         strategy=chunk_config.strategy,
+         split_seq=chunk_config.split_seq,
+         chunk_len=chunk_config.chunk_len
+     )
+
+     dataset = Dataset.from_generator(
+         chunk_generator,
+         gen_kwargs={
+             "input_dataset": input_ds,
+             "chunker": chunker
+         }
+     )
+
+     dataset.push_to_hub(
+         chunk_config.output_dataset,
+         private=chunk_config.private,
+         token=HF_TOKEN
+     )
+
+     logger.info("Done chunking")
+
+     return {"processed": True}
+
+
+ """
+ EMBEDDING
+ """
+
+ async def embed_sent(sentence, semaphore, tei_url, tmp_file):
+     async with semaphore:
+         payload = {
+             "inputs": sentence,
+             "truncate": True
+         }
+
+         async with ClientSession(
+             headers={
+                 "Content-Type": "application/json",
+                 "Authorization": f"Bearer {HF_TOKEN}"
+             }
+         ) as session:
+             async with session.post(tei_url, json=payload) as resp:
+                 if resp.status != 200:
+                     raise RuntimeError(await resp.text())
+                 result = await resp.json()
+
+                 tmp_file.write(
+                     json.dumps({"vector": result[0], embed_config.input_text_col: sentence}) + "\n"
+                 )
+
+
+ async def embed(input_ds, tei_url, temp_file):
+     semaphore = asyncio.BoundedSemaphore(embed_config.semaphore_bound)
+     jobs = [
+         asyncio.create_task(embed_sent(row[embed_config.input_text_col], semaphore, tei_url, temp_file))
+         for row in input_ds if row[embed_config.input_text_col].strip()
+     ]
+     logger.info(f"num chunks to embed: {len(jobs)}")
+
+     tic = time.time()
+     await tqdm_asyncio.gather(*jobs)
+     logger.info(f"embed time: {time.time() - tic}")
+
+
+ def wake_up_endpoint(url):
+     # poll until the (possibly scaled-to-zero) TEI endpoint responds
+     while requests.get(
+         url=url,
+         headers={"Authorization": f"Bearer {HF_TOKEN}"}
+     ).status_code != 200:
+         time.sleep(2)
+     logger.info("TEI endpoint is up")
+
+
+ def embed_dataset():
+     logger.info("Update detected, embedding is scheduled")
+     wake_up_endpoint(TEI_URL)
+     input_ds = load_dataset(embed_config.input_dataset, split="+".join(embed_config.input_splits))
+     with tempfile.NamedTemporaryFile(mode="a", suffix=".jsonl") as temp_file:
+         asyncio.run(embed(input_ds, TEI_URL, temp_file))
+         temp_file.flush()  # make sure buffered lines are on disk before reading
+
+         dataset = Dataset.from_json(temp_file.name)
+         dataset.push_to_hub(
+             embed_config.output_dataset,
+             private=embed_config.private,
+             token=HF_TOKEN
+         )
+
+     logger.info("Done embedding")
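Note: the /webhook route can be exercised locally with a hand-built request matching the WebhookPayload schema in src/models.py below; a minimal sketch (the repo name mirrors embed_config.json, the headSha is a dummy value):

import requests

payload = {
    "event": {"action": "update", "scope": "repo.content"},
    "repo": {
        "type": "dataset",
        "name": "sergeipetrov/transformers-diffusers-docs-chunked",
        "id": "sergeipetrov/transformers-diffusers-docs-chunked",
        "private": False,
        "headSha": "0000000000000000000000000000000000000000",  # dummy
    },
}

resp = requests.post("http://localhost:7860/webhook", json=payload)
print(resp.json())  # {"processed": True} when the filters match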
src/models.py ADDED
@@ -0,0 +1,51 @@
+ import json
+ import os
+ from pydantic import BaseModel
+ from typing import Literal, Union
+
+
+ class ChunkConfig(BaseModel):
+     input_dataset: str
+     input_splits: list[str]
+     input_text_col: str
+     output_dataset: str
+     strategy: Literal["spacy", "sequence", "constant"]
+     split_seq: Union[str, list[str]]
+     chunk_len: int
+     private: bool
+
+
+ class EmbedConfig(BaseModel):
+     input_dataset: str
+     input_splits: list[str]
+     input_text_col: str
+     output_dataset: str
+     private: bool
+     semaphore_bound: int
+
+
+ class WebhookPayloadEvent(BaseModel):
+     action: Literal["create", "update", "delete"]
+     scope: str
+
+
+ class WebhookPayloadRepo(BaseModel):
+     type: Literal["dataset", "model", "space"]
+     name: str
+     id: str
+     private: bool
+     headSha: str
+
+
+ class WebhookPayload(BaseModel):
+     event: WebhookPayloadEvent
+     repo: WebhookPayloadRepo
+
+
+ with open(os.path.join(os.getcwd(), "chunk_config.json")) as c:
+     data = json.load(c)
+     chunk_config = ChunkConfig(**data)
+
+ with open(os.path.join(os.getcwd(), "embed_config.json")) as c:
+     data = json.load(c)
+     embed_config = EmbedConfig(**data)
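Note: both configs are validated once at import time, so a malformed field fails fast when the app starts. An illustrative example of the failure mode (not part of the commit):

from pydantic import ValidationError
from src.models import ChunkConfig

try:
    ChunkConfig(
        input_dataset="d", input_splits=["train"], input_text_col="text",
        output_dataset="o", strategy="tokenwise",  # not in the Literal
        split_seq="\n\n", chunk_len=512, private=False,
    )
except ValidationError as e:
    print(e)  # strategy must be 'spacy', 'sequence' or 'constant'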
style.css ADDED
@@ -0,0 +1,28 @@
+ body {
+     padding: 2rem;
+     font-family: -apple-system, BlinkMacSystemFont, "Arial", sans-serif;
+ }
+
+ h1 {
+     font-size: 16px;
+     margin-top: 0;
+ }
+
+ p {
+     color: rgb(107, 114, 128);
+     font-size: 15px;
+     margin-bottom: 10px;
+     margin-top: 5px;
+ }
+
+ .card {
+     max-width: 620px;
+     margin: 0 auto;
+     padding: 16px;
+     border: 1px solid lightgray;
+     border-radius: 16px;
+ }
+
+ .card p:last-child {
+     margin-bottom: 0;
+ }