Benjamin Bossan commited on
Commit
a240da9
1 Parent(s): 7281bd6

Initial commit

Browse files
Files changed (12) hide show
  1. .gitignore +11 -0
  2. README.md +19 -0
  3. environment.yml +97 -0
  4. pyproject.toml +14 -0
  5. requests.org +57 -0
  6. requirements-dev.txt +4 -0
  7. requirements.txt +5 -0
  8. src/base.py +37 -0
  9. src/db.py +102 -0
  10. src/ml.py +143 -0
  11. src/webservice.py +112 -0
  12. src/worker.py +117 -0
.gitignore ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .DS_Store
2
+ .idea
3
+ *.log
4
+ tmp/
5
+
6
+ *.py[cod]
7
+ *.egg
8
+ build
9
+ htmlcov
10
+
11
+ *.db
README.md ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Dump your knowledge, let AI refine it
2
+
3
+ ## Starting
4
+
5
+ Install stuff, then, in one terminal, start the background worker:
6
+
7
+ ```sh
8
+ cd src
9
+ python worker.py
10
+ ```
11
+
12
+ Start the web server:
13
+
14
+ ```sh
15
+ cd src
16
+ uvicorn webservice:app --reload --port 8080
17
+ ```
18
+
19
+ For example requests, check `requests.org`.
environment.yml ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: gistillery
2
+ channels:
3
+ - pytorch
4
+ - nvidia
5
+ - defaults
6
+ dependencies:
7
+ - _libgcc_mutex=0.1=main
8
+ - _openmp_mutex=5.1=1_gnu
9
+ - blas=1.0=mkl
10
+ - bzip2=1.0.8=h7b6447c_0
11
+ - ca-certificates=2023.01.10=h06a4308_0
12
+ - cuda-cudart=11.7.99=0
13
+ - cuda-cupti=11.7.101=0
14
+ - cuda-libraries=11.7.1=0
15
+ - cuda-nvrtc=11.7.99=0
16
+ - cuda-nvtx=11.7.91=0
17
+ - cuda-runtime=11.7.1=0
18
+ - filelock=3.9.0=py310h06a4308_0
19
+ - gmp=6.2.1=h295c915_3
20
+ - gmpy2=2.1.2=py310heeb90bb_0
21
+ - intel-openmp=2023.1.0=hdb19cb5_46305
22
+ - jinja2=3.1.2=py310h06a4308_0
23
+ - ld_impl_linux-64=2.38=h1181459_1
24
+ - libcublas=11.10.3.66=0
25
+ - libcufft=10.7.2.124=h4fbf590_0
26
+ - libcufile=1.6.1.9=0
27
+ - libcurand=10.3.2.106=0
28
+ - libcusolver=11.4.0.1=0
29
+ - libcusparse=11.7.4.91=0
30
+ - libffi=3.4.2=h6a678d5_6
31
+ - libgcc-ng=11.2.0=h1234567_1
32
+ - libgomp=11.2.0=h1234567_1
33
+ - libnpp=11.7.4.75=0
34
+ - libnvjpeg=11.8.0.2=0
35
+ - libstdcxx-ng=11.2.0=h1234567_1
36
+ - libuuid=1.41.5=h5eee18b_0
37
+ - markupsafe=2.1.1=py310h7f8727e_0
38
+ - mkl=2023.1.0=h6d00ec8_46342
39
+ - mpc=1.1.0=h10f8cd9_1
40
+ - mpfr=4.0.2=hb69a4c5_1
41
+ - ncurses=6.4=h6a678d5_0
42
+ - networkx=2.8.4=py310h06a4308_1
43
+ - openssl=1.1.1t=h7f8727e_0
44
+ - pip=23.0.1=py310h06a4308_0
45
+ - python=3.10.11=h7a1cb2a_2
46
+ - pytorch=2.0.0=py3.10_cuda11.7_cudnn8.5.0_0
47
+ - pytorch-cuda=11.7=h778d358_3
48
+ - pytorch-mutex=1.0=cuda
49
+ - readline=8.2=h5eee18b_0
50
+ - setuptools=66.0.0=py310h06a4308_0
51
+ - sqlite=3.41.2=h5eee18b_0
52
+ - sympy=1.11.1=py310h06a4308_0
53
+ - tbb=2021.8.0=hdb19cb5_0
54
+ - tk=8.6.12=h1ccaba5_0
55
+ - torchtriton=2.0.0=py310
56
+ - typing_extensions=4.5.0=py310h06a4308_0
57
+ - tzdata=2023c=h04d1e81_0
58
+ - wheel=0.38.4=py310h06a4308_0
59
+ - xz=5.4.2=h5eee18b_0
60
+ - zlib=1.2.13=h5eee18b_0
61
+ - pip:
62
+ - anyio==3.6.2
63
+ - black==23.3.0
64
+ - certifi==2022.12.7
65
+ - charset-normalizer==3.1.0
66
+ - click==8.1.3
67
+ - fastapi==0.95.1
68
+ - fsspec==2023.4.0
69
+ - h11==0.14.0
70
+ - httptools==0.5.0
71
+ - huggingface-hub==0.14.1
72
+ - idna==3.4
73
+ - mpmath==1.2.1
74
+ - mypy==1.2.0
75
+ - mypy-extensions==1.0.0
76
+ - numpy==1.24.3
77
+ - packaging==23.1
78
+ - pathspec==0.11.1
79
+ - platformdirs==3.5.0
80
+ - pydantic==1.10.7
81
+ - python-dotenv==1.0.0
82
+ - pyyaml==6.0
83
+ - regex==2023.5.5
84
+ - requests==2.29.0
85
+ - ruff==0.0.264
86
+ - sniffio==1.3.0
87
+ - starlette==0.26.1
88
+ - tokenizers==0.13.3
89
+ - tomli==2.0.1
90
+ - tqdm==4.65.0
91
+ - transformers==4.28.1
92
+ - urllib3==1.26.15
93
+ - uvicorn==0.22.0
94
+ - uvloop==0.17.0
95
+ - watchfiles==0.19.0
96
+ - websockets==11.0.2
97
+ prefix: /home/vinh/anaconda3/envs/gistillery
pyproject.toml ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [tool.black]
2
+ line-length = 88
3
+ target_version = ['py310', 'py311']
4
+ preview = true
5
+
6
+ [tool.isort]
7
+ profile = "black"
8
+
9
+ [tool.mypy]
10
+ no_implicit_optional = true
11
+ strict = true
12
+
13
+ [[tool.mypy-transformers]]
14
+ ignore_missing_imports = true
requests.org ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #+title: Requests
2
+
3
+ #+begin_src bash
4
+ curl -X 'GET' \
5
+ 'http://localhost:8080/clear/' \
6
+ -H 'accept: application/json'
7
+ #+end_src
8
+
9
+ #+RESULTS:
10
+ : OK
11
+
12
+ #+begin_src bash
13
+ # curl command to localhost and post the message "hi there"
14
+ curl -X 'POST' \
15
+ 'http://localhost:8080/submit/' \
16
+ -H 'accept: application/json' \
17
+ -H 'Content-Type: application/json' \
18
+ -d '{
19
+ "author": "ben",
20
+ "content": "SAN FRANCISCO, May 2, 2023 PRNewswire -- GitLab Inc., the most comprehensive, scalable enterprise DevSecOps platform for software innovation, and Google Cloud today announced an extension of its strategic partnership to deliver secure AI offerings to the enterprise. GitLab is trusted by more than 50% of the Fortune 100 to secure and protect their most valuable assets, and leads with a privacy-first approach to AI. By leveraging Google Cloud'\''s customizable foundation models and open generative AI infrastructure, GitLab will provide customers with AI-assisted features directly within the enterprise DevSecOps platform."
21
+ }'
22
+ #+end_src
23
+
24
+ #+RESULTS:
25
+ : Submitted job 04d44970ced1473999dfab77b02202b8
26
+
27
+ #+begin_src bash
28
+ curl -X 'POST' \
29
+ 'http://localhost:8080/submit/' \
30
+ -H 'accept: application/json' \
31
+ -H 'Content-Type: application/json' \
32
+ -d '{
33
+ "author": "ben",
34
+ "content": "In literature discussing why ChatGPT is able to capture so much of our imagination, I often come across two narratives: Scale: throwing more data and compute at it. UX: moving from a prompt interface to a more natural chat interface. A narrative that is often glossed over in the demo frenzy is the incredible technical creativity that went into making models like ChatGPT work. One such cool idea is RLHF (Reinforcement Learning from Human Feedback): incorporating reinforcement learning and human feedback into NLP. RL has been notoriously difficult to work with, and therefore, mostly confined to gaming and simulated environments like Atari or MuJoCo. Just five years ago, both RL and NLP were progressing pretty much orthogonally – different stacks, different techniques, and different experimentation setups. It’s impressive to see it work in a new domain at a massive scale. So, how exactly does RLHF work? Why does it work? This post will discuss the answers to those questions."
35
+ }'
36
+ #+end_src
37
+
38
+ #+RESULTS:
39
+ : Submitted job 3cc2104aec0748b1bd5743c321b169ac
40
+
41
+ #+begin_src bash
42
+ curl -X 'GET' \
43
+ 'http://localhost:8080/check_status/22b158499b744f42918912cd387fd657' \
44
+ -H 'accept: application/json'
45
+ #+end_src
46
+
47
+ #+RESULTS:
48
+ | {"id":"22b158499b744f42918912cd387fd657" | status:"done" | last_updated:"2023-05-05T14:54:11"} |
49
+
50
+ #+begin_src bash
51
+ curl -X 'GET' \
52
+ 'http://localhost:8080/recent/' \
53
+ -H 'accept: application/json'
54
+ #+end_src
55
+
56
+ #+RESULTS:
57
+ | [{"id":"3cc2104aec0748b1bd5743c321b169ac" | author:"ben" | summary:"A new approach to NLP that incorporates reinforcement learning and human feedback. How does it work? Why does it work? In this post | I’ll explain how it works. RLHF is a new approach to NLP that incorporates reinforcement learning and human feedback. It’s a new approach to NLP that incorporates reinforcement learning and human feedback. It’s a new approach to NLP that incorporates reinforcement learning and human feedback. It’s a new approach to NLP that incorporates reinforcement learning and human feedback. It’s a new approach to NLP that incorporates reinforcement learning and human feedback." | tags:["#general" | #rlhf] | date:"2023-05-05T14:56:32"} | {"id":"04d44970ced1473999dfab77b02202b8" | author:"ben" | summary:"GitLab | the most comprehensive | scalable enterprise DevSecOps platform for software innovation | and Google Cloud today announced an extension of their strategic partnership to deliver secure AI offerings to the enterprise. By leveraging Google Cloud's customizable foundation models and open generative AI infrastructure | GitLab will provide customers with AI-assisted features directly within the enterprise DevSecOps platform. The company's AI capabilities are designed to help enterprises improve productivity and reduce costs." | tags:["#general"] | date:"2023-05-05T14:56:31"}] |
requirements-dev.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ black
2
+ isort
3
+ mypy
4
+ ruff
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ fastapi
2
+ httpx
3
+ uvicorn[standard]
4
+ torch
5
+ transformers
src/base.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime as dt
2
+ import enum
3
+
4
+ from pydantic import BaseModel
5
+
6
+
7
+ class RequestInput(BaseModel):
8
+ author: str
9
+ content: str
10
+
11
+
12
+ class EntriesResult(BaseModel):
13
+ id: str
14
+ author: str
15
+ summary: str
16
+ tags: list[str]
17
+ date: dt.datetime
18
+
19
+
20
+ class JobInput(BaseModel):
21
+ id: str
22
+ author: str
23
+ content: str
24
+
25
+
26
+ class JobStatus(str, enum.Enum):
27
+ pending = "pending"
28
+ done = "done"
29
+ failed = "failed"
30
+ cancelled = "cancelled"
31
+ not_found = "not found"
32
+
33
+
34
+ class JobStatusResult(BaseModel):
35
+ id: str
36
+ status: JobStatus
37
+ last_updated: dt.datetime | None
src/db.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import sqlite3
3
+ from contextlib import contextmanager
4
+ from typing import Generator
5
+
6
+ logger = logging.getLogger(__name__)
7
+ logger.setLevel(logging.DEBUG)
8
+
9
+
10
+ schema_entries = """
11
+ CREATE TABLE entries
12
+ (
13
+ id TEXT PRIMARY KEY,
14
+ author TEXT NOT NULL,
15
+ source TEXT NOT NULL,
16
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
17
+ )
18
+ """
19
+
20
+ # create schema for 'summary' table, id is a uuid4
21
+ schema_summary = """
22
+ CREATE TABLE summaries
23
+ (
24
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
25
+ entry_id TEXT NOT NULL,
26
+ summary TEXT NOT NULL,
27
+ summarizer_name TEXT NOT NULL,
28
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
29
+ FOREIGN KEY(entry_id) REFERENCES entries(id)
30
+ )
31
+ """
32
+
33
+ schema_tag = """
34
+ CREATE TABLE tags
35
+ (
36
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
37
+ entry_id TEXT NOT NULL,
38
+ tag TEXT NOT NULL,
39
+ tagger_name TEXT NOT NULL,
40
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
41
+ FOREIGN KEY(entry_id) REFERENCES entries(id)
42
+ )
43
+ """
44
+
45
+ schema_job = """
46
+ CREATE TABLE jobs
47
+ (
48
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
49
+ entry_id TEXT NOT NULL,
50
+ status TEXT NOT NULL DEFAULT 'pending',
51
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
52
+ last_updated TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
53
+ FOREIGN KEY(entry_id) REFERENCES entries(id)
54
+ )
55
+ """
56
+
57
+ TABLES = {
58
+ "entries": schema_entries,
59
+ "summaries": schema_summary,
60
+ "tags": schema_tag,
61
+ "jobs": schema_job,
62
+ }
63
+ TABLES_CREATED = False
64
+
65
+
66
+ def _get_db_connection() -> sqlite3.Connection:
67
+ global TABLES_CREATED
68
+
69
+ # sqlite cannot deal with concurrent access, so we set a big timeout
70
+ conn = sqlite3.connect("sqlite-data.db", timeout=30)
71
+ if TABLES_CREATED:
72
+ return conn
73
+
74
+ cursor = conn.cursor()
75
+
76
+ # create tables if needed
77
+ for table_name, schema in TABLES.items():
78
+ cursor.execute(
79
+ "SELECT name FROM sqlite_master WHERE type='table' AND name=?",
80
+ (table_name,),
81
+ )
82
+ table_exists = cursor.fetchone() is not None
83
+ if not table_exists:
84
+ logger.info(f"'{table_name}' table does not exist, creating it now...")
85
+ cursor.execute(schema)
86
+ conn.commit()
87
+ logger.info("done")
88
+
89
+ TABLES_CREATED = True
90
+ return conn
91
+
92
+
93
+ @contextmanager
94
+ def get_db_cursor() -> Generator[sqlite3.Cursor, None, None]:
95
+ conn = _get_db_connection()
96
+ cursor = conn.cursor()
97
+ try:
98
+ yield cursor
99
+ finally:
100
+ conn.commit()
101
+ cursor.close()
102
+ conn.close()
src/ml.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import abc
2
+ import logging
3
+ import re
4
+
5
+ import httpx
6
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig
7
+
8
+ from base import JobInput
9
+
10
+ logger = logging.getLogger(__name__)
11
+ logger.setLevel(logging.DEBUG)
12
+
13
+ MODEL_NAME = "google/flan-t5-large"
14
+ model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
15
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
16
+
17
+
18
+ class Summarizer:
19
+ def __init__(self) -> None:
20
+ self.template = "Summarize the text below in two sentences:\n\n{}"
21
+ self.generation_config = GenerationConfig.from_pretrained(MODEL_NAME)
22
+ self.generation_config.max_new_tokens = 200
23
+ self.generation_config.min_new_tokens = 100
24
+ self.generation_config.top_k = 5
25
+ self.generation_config.repetition_penalty = 1.5
26
+
27
+ def __call__(self, x: str) -> str:
28
+ text = self.template.format(x)
29
+ inputs = tokenizer(text, return_tensors="pt")
30
+ outputs = model.generate(**inputs, generation_config=self.generation_config)
31
+ output = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
32
+ assert isinstance(output, str)
33
+ return output
34
+
35
+ def get_name(self) -> str:
36
+ return f"Summarizer({MODEL_NAME})"
37
+
38
+
39
+ class Tagger:
40
+ def __init__(self) -> None:
41
+ self.template = (
42
+ "Create a list of tags for the text below. The tags should be high level "
43
+ "and specific. Prefix each tag with a hashtag.\n\n{}\n\nTags: #general"
44
+ )
45
+ self.generation_config = GenerationConfig.from_pretrained(MODEL_NAME)
46
+ self.generation_config.max_new_tokens = 50
47
+ self.generation_config.min_new_tokens = 25
48
+ # increase the temperature to make the model more creative
49
+ self.generation_config.temperature = 1.5
50
+
51
+ def _extract_tags(self, text: str) -> list[str]:
52
+ tags = set()
53
+ for tag in text.split():
54
+ if tag.startswith("#"):
55
+ tags.add(tag.lower())
56
+ return sorted(tags)
57
+
58
+ def __call__(self, x: str) -> list[str]:
59
+ text = self.template.format(x)
60
+ inputs = tokenizer(text, return_tensors="pt")
61
+ outputs = model.generate(**inputs, generation_config=self.generation_config)
62
+ output = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
63
+ tags = self._extract_tags(output)
64
+ return tags
65
+
66
+ def get_name(self) -> str:
67
+ return f"Tagger({MODEL_NAME})"
68
+
69
+
70
+ class Processor(abc.ABC):
71
+ def __call__(self, job: JobInput) -> str:
72
+ _id = job.id
73
+ logger.info(f"Processing {input} with {self.__class__.__name__} (id={_id[:8]})")
74
+ result = self.process(job)
75
+ logger.info(f"Finished processing input (id={_id[:8]})")
76
+ return result
77
+
78
+ def process(self, input: JobInput) -> str:
79
+ raise NotImplementedError
80
+
81
+ def match(self, input: JobInput) -> bool:
82
+ raise NotImplementedError
83
+
84
+ def get_name(self) -> str:
85
+ raise NotImplementedError
86
+
87
+
88
+ class RawProcessor(Processor):
89
+ def match(self, input: JobInput) -> bool:
90
+ return True
91
+
92
+ def process(self, input: JobInput) -> str:
93
+ return input.content
94
+
95
+ def get_name(self) -> str:
96
+ return self.__class__.__name__
97
+
98
+
99
+ class PlainUrlProcessor(Processor):
100
+ def __init__(self) -> None:
101
+ self.client = httpx.Client()
102
+ self.regex = re.compile(r"(https?://[^\s]+)")
103
+ self.url = None
104
+ self.template = "{url}\n\n{content}"
105
+
106
+ def match(self, input: JobInput) -> bool:
107
+ urls = list(self.regex.findall(input.content))
108
+ if len(urls) == 1:
109
+ self.url = urls[0]
110
+ return True
111
+ return False
112
+
113
+ def process(self, input: JobInput) -> str:
114
+ """Get content of website and return it as string"""
115
+ assert isinstance(self.url, str)
116
+ text = self.client.get(self.url).text
117
+ assert isinstance(text, str)
118
+ text = self.template.format(url=self.url, content=text)
119
+ return text
120
+
121
+ def get_name(self) -> str:
122
+ return self.__class__.__name__
123
+
124
+
125
+ class ProcessorRegistry:
126
+ def __init__(self) -> None:
127
+ self.registry: list[Processor] = []
128
+ self.default_registry: list[Processor] = []
129
+ self.set_default_processors()
130
+
131
+ def set_default_processors(self) -> None:
132
+ self.default_registry.extend([PlainUrlProcessor(), RawProcessor()])
133
+
134
+ def register(self, processor: Processor) -> None:
135
+ self.registry.append(processor)
136
+
137
+ def dispatch(self, input: JobInput) -> Processor:
138
+ for processor in self.registry + self.default_registry:
139
+ if processor.match(input):
140
+ return processor
141
+
142
+ # should never be requires, but eh
143
+ return RawProcessor()
src/webservice.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import uuid
3
+
4
+ from fastapi import FastAPI
5
+
6
+ from base import EntriesResult, JobStatus, JobStatusResult, RequestInput
7
+ from db import TABLES, get_db_cursor
8
+
9
+
10
+ logger = logging.getLogger(__name__)
11
+ logger.setLevel(logging.DEBUG)
12
+
13
+
14
+ app = FastAPI()
15
+
16
+
17
+ @app.post("/submit/")
18
+ def submit_job(input: RequestInput) -> str:
19
+ # submit a new job, poor man's job queue
20
+ _id = uuid.uuid4().hex
21
+ logger.info(f"Submitting job for (_id={_id[:8]})")
22
+
23
+ with get_db_cursor() as cursor:
24
+ # create a job
25
+ query = "INSERT INTO jobs (entry_id, status) VALUES (?, ?)"
26
+ cursor.execute(query, (_id, "pending"))
27
+ # create an entry
28
+ query = "INSERT INTO entries (id, author, source) VALUES (?, ?, ?)"
29
+ cursor.execute(query, (_id, input.author, input.content))
30
+
31
+ return f"Submitted job {_id}"
32
+
33
+
34
+ @app.get("/check_status/{_id}")
35
+ def check_status(_id: str) -> JobStatusResult:
36
+ with get_db_cursor() as cursor:
37
+ cursor.execute(
38
+ "SELECT status, last_updated FROM jobs WHERE entry_id = ?", (_id,)
39
+ )
40
+ result = cursor.fetchone()
41
+
42
+ if result is None:
43
+ return JobStatusResult(id=_id, status=JobStatus.not_found, last_updated=None)
44
+
45
+ status, last_updated = result
46
+ return JobStatusResult(id=_id, status=status, last_updated=last_updated)
47
+
48
+
49
+ @app.get("/recent/")
50
+ def recent() -> list[EntriesResult]:
51
+ with get_db_cursor() as cursor:
52
+ # get the last 10 entries, join summary and tag, where each tag is
53
+ # joined to a comma separated str
54
+ cursor.execute("""
55
+ SELECT e.id, e.author, s.summary, GROUP_CONCAT(t.tag, ","), e.created_at
56
+ FROM entries e
57
+ JOIN summaries s ON e.id = s.entry_id
58
+ JOIN tags t ON e.id = t.entry_id
59
+ GROUP BY e.id
60
+ ORDER BY e.created_at DESC
61
+ LIMIT 10
62
+ """)
63
+ results = cursor.fetchall()
64
+
65
+ entries = []
66
+ for _id, author, summary, tags, date in results:
67
+ entry = EntriesResult(
68
+ id=_id, author=author, summary=summary, tags=tags.split(","), date=date
69
+ )
70
+ entries.append(entry)
71
+ return entries
72
+
73
+
74
+ @app.get("/recent/{tag}")
75
+ def recent_tag(tag: str) -> list[EntriesResult]:
76
+ if not tag.startswith("#"):
77
+ tag = "#" + tag
78
+
79
+ # same as recent, but filter by tag
80
+ with get_db_cursor() as cursor:
81
+ cursor.execute(
82
+ """
83
+ SELECT e.id, e.author, s.summary, GROUP_CONCAT(t.tag, ","), e.created_at
84
+ FROM entries e
85
+ JOIN summaries s ON e.id = s.entry_id
86
+ JOIN tags t ON e.id = t.entry_id
87
+ WHERE t.tag = ?
88
+ GROUP BY e.id
89
+ ORDER BY e.created_at DESC
90
+ LIMIT 10
91
+ """,
92
+ (tag,),
93
+ )
94
+ results = cursor.fetchall()
95
+
96
+ entries = []
97
+ for _id, author, summary, tags, date in results:
98
+ entry = EntriesResult(
99
+ id=_id, author=author, summary=summary, tags=tags.split(","), date=date
100
+ )
101
+ entries.append(entry)
102
+ return entries
103
+
104
+
105
+ @app.get("/clear/")
106
+ def clear() -> str:
107
+ # clear all tables
108
+ logger.warning("Clearing all tables")
109
+ with get_db_cursor() as cursor:
110
+ for table_name in TABLES:
111
+ cursor.execute(f"DELETE FROM {table_name}")
112
+ return "OK"
src/worker.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+
3
+ from base import JobInput
4
+ from db import get_db_cursor
5
+ from ml import ProcessorRegistry, Summarizer, Tagger
6
+
7
+ SLEEP_INTERVAL = 5
8
+
9
+
10
+ processor_registry = ProcessorRegistry()
11
+ summarizer = Summarizer()
12
+ tagger = Tagger()
13
+ print("loaded ML models")
14
+
15
+
16
+ def check_pending_jobs() -> list[JobInput]:
17
+ """Check DB for pending jobs"""
18
+ with get_db_cursor() as cursor:
19
+ # fetch pending jobs, join authro and content from entries table
20
+ query = """
21
+ SELECT j.entry_id, e.author, e.source
22
+ FROM jobs j
23
+ JOIN entries e
24
+ ON j.entry_id = e.id
25
+ WHERE j.status = 'pending'
26
+ """
27
+ res = list(cursor.execute(query))
28
+ return [
29
+ JobInput(id=_id, author=author, content=content) for _id, author, content in res
30
+ ]
31
+
32
+
33
+ def store(
34
+ job: JobInput,
35
+ *,
36
+ summary: str,
37
+ tags: list[str],
38
+ processor_name: str,
39
+ summarizer_name: str,
40
+ tagger_name: str,
41
+ ) -> None:
42
+ with get_db_cursor() as cursor:
43
+ # write to entries, summary, tags tables
44
+ cursor.execute(
45
+ (
46
+ "INSERT INTO summaries (entry_id, summary, summarizer_name)"
47
+ " VALUES (?, ?, ?)"
48
+ ),
49
+ (job.id, summary, summarizer_name),
50
+ )
51
+ cursor.executemany(
52
+ "INSERT INTO tags (entry_id, tag, tagger_name) VALUES (?, ?, ?)",
53
+ [(job.id, tag, tagger_name) for tag in tags],
54
+ )
55
+
56
+
57
+ def process_job(job: JobInput) -> None:
58
+ tic = time.perf_counter()
59
+ print(f"Processing job for (id={job.id[:8]})")
60
+
61
+ # care: acquire cursor (which leads to locking) as late as possible, since
62
+ # the processing and we don't want to block other workers during that time
63
+ try:
64
+ processor = processor_registry.dispatch(job)
65
+ processor_name = processor.get_name()
66
+ processed = processor(job)
67
+
68
+ tagger_name = tagger.get_name()
69
+ tags = tagger(processed)
70
+
71
+ summarizer_name = summarizer.get_name()
72
+ summary = summarizer(processed)
73
+
74
+ store(
75
+ job,
76
+ summary=summary,
77
+ tags=tags,
78
+ processor_name=processor_name,
79
+ summarizer_name=summarizer_name,
80
+ tagger_name=tagger_name,
81
+ )
82
+ # update job status to done
83
+ with get_db_cursor() as cursor:
84
+ cursor.execute(
85
+ "UPDATE jobs SET status = 'done' WHERE entry_id = ?", (job.id,)
86
+ )
87
+ except Exception as e:
88
+ # update job status to failed
89
+ with get_db_cursor() as cursor:
90
+ cursor.execute(
91
+ "UPDATE jobs SET status = 'failed' WHERE entry_id = ?", (job.id,)
92
+ )
93
+ print(f"Failed to process job for (id={job.id[:8]}): {e}")
94
+
95
+ toc = time.perf_counter()
96
+ print(f"Finished processing job (id={job.id[:8]}) in {toc - tic:0.3f} seconds")
97
+
98
+
99
+ def main() -> None:
100
+ while True:
101
+ jobs = check_pending_jobs()
102
+ if not jobs:
103
+ print("No pending jobs found, sleeping...")
104
+ time.sleep(SLEEP_INTERVAL)
105
+ continue
106
+
107
+ print(f"Found {len(jobs)} pending job(s), processing...")
108
+ for job in jobs:
109
+ process_job(job)
110
+
111
+
112
+ if __name__ == "__main__":
113
+ try:
114
+ main()
115
+ except KeyboardInterrupt:
116
+ print("Shutting down...")
117
+ exit(0)