Benjamin Bossan committed
Commit 126a4c6
1 Parent(s): a240da9

Refactor ml model handling

Files changed (4)
  1. src/db.py +4 -1
  2. src/ml.py +116 -67
  3. src/webservice.py +6 -0
  4. src/worker.py +79 -38
src/db.py CHANGED
@@ -1,4 +1,5 @@
 import logging
+import os
 import sqlite3
 from contextlib import contextmanager
 from typing import Generator
@@ -6,6 +7,8 @@ from typing import Generator
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)
 
+db_file = os.getenv("DB_FILE_NAME", "sqlite-data.db")
+
 
 schema_entries = """
 CREATE TABLE entries
@@ -67,7 +70,7 @@ def _get_db_connection() -> sqlite3.Connection:
     global TABLES_CREATED
 
     # sqlite cannot deal with concurrent access, so we set a big timeout
-    conn = sqlite3.connect("sqlite-data.db", timeout=30)
+    conn = sqlite3.connect(db_file, timeout=30)
     if TABLES_CREATED:
         return conn
 
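Note: the database path is now resolved once, at import time, from the DB_FILE_NAME environment variable, falling back to the old hard-coded sqlite-data.db. A minimal sketch of what that enables (the file name test-data.db and the query are illustrative, not part of the commit):

import os

# The variable must be set before db is imported, because db_file is read at module load.
os.environ["DB_FILE_NAME"] = "test-data.db"

from db import get_db_cursor  # now connects to test-data.db instead of sqlite-data.db

with get_db_cursor() as cursor:
    cursor.execute("SELECT COUNT(*) FROM entries")
    print(cursor.fetchone())

Since the lookup is module-level, exporting the variable after db has already been imported has no effect.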
 
src/ml.py CHANGED
@@ -1,52 +1,126 @@
 import abc
+from typing import Any
 import logging
 import re
 
 import httpx
-from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig
 
 from base import JobInput
 
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)
 
-MODEL_NAME = "google/flan-t5-large"
-model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
-tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
 
+class Processor(abc.ABC):
+    def get_name(self) -> str:
+        return self.__class__.__name__
+
+    def __call__(self, job: JobInput) -> str:
+        _id = job.id
+        logger.info(f"Processing {input} with {self.__class__.__name__} (id={_id[:8]})")
+        result = self.process(job)
+        logger.info(f"Finished processing input (id={_id[:8]})")
+        return result
 
-class Summarizer:
+    @abc.abstractmethod
+    def process(self, input: JobInput) -> str:
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def match(self, input: JobInput) -> bool:
+        raise NotImplementedError
+
+
+class Summarizer(abc.ABC):
+    def __init__(self, model_name: str, model: Any, tokenizer: Any, generation_config: Any) -> None:
+        raise NotImplementedError
+
+    def get_name(self) -> str:
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def __call__(self, x: str) -> str:
+        raise NotImplementedError
+
+
+class Tagger(abc.ABC):
+    def __init__(self, model_name: str, model: Any, tokenizer: Any, generation_config: Any) -> None:
+        raise NotImplementedError
+
+    def get_name(self) -> str:
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def __call__(self, x: str) -> list[str]:
+        raise NotImplementedError
+
+
+class MlRegistry:
     def __init__(self) -> None:
+        self.processors: list[Processor] = []
+        self.summerizer: Summarizer | None = None
+        self.tagger: Tagger | None = None
+        self.model = None
+        self.tokenizer = None
+
+    def register_processor(self, processor: Processor) -> None:
+        self.processors.append(processor)
+
+    def register_summarizer(self, summarizer: Summarizer) -> None:
+        self.summerizer = summarizer
+
+    def register_tagger(self, tagger: Tagger) -> None:
+        self.tagger = tagger
+
+    def get_processor(self, input: JobInput) -> Processor:
+        assert self.processors
+        for processor in self.processors:
+            if processor.match(input):
+                return processor
+
+        return RawTextProcessor()
+
+    def get_summarizer(self) -> Summarizer:
+        assert self.summerizer
+        return self.summerizer
+
+    def get_tagger(self) -> Tagger:
+        assert self.tagger
+        return self.tagger
+
+
+class HfTransformersSummarizer(Summarizer):
+    def __init__(self, model_name: str, model: Any, tokenizer: Any, generation_config: Any) -> None:
+        self.model_name = model_name
+        self.model = model
+        self.tokenizer = tokenizer
+        self.generation_config = generation_config
+
         self.template = "Summarize the text below in two sentences:\n\n{}"
-        self.generation_config = GenerationConfig.from_pretrained(MODEL_NAME)
-        self.generation_config.max_new_tokens = 200
-        self.generation_config.min_new_tokens = 100
-        self.generation_config.top_k = 5
-        self.generation_config.repetition_penalty = 1.5
 
     def __call__(self, x: str) -> str:
         text = self.template.format(x)
-        inputs = tokenizer(text, return_tensors="pt")
-        outputs = model.generate(**inputs, generation_config=self.generation_config)
-        output = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
+        inputs = self.tokenizer(text, return_tensors="pt")
+        outputs = self.model.generate(**inputs, generation_config=self.generation_config)
+        output = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
         assert isinstance(output, str)
         return output
 
     def get_name(self) -> str:
-        return f"Summarizer({MODEL_NAME})"
+        return f"{self.__class__.__name__}({self.model_name})"
 
 
-class Tagger:
-    def __init__(self) -> None:
+class HfTransformersTagger(Tagger):
+    def __init__(self, model_name: str, model: Any, tokenizer: Any, generation_config: Any) -> None:
+        self.model_name = model_name
+        self.model = model
+        self.tokenizer = tokenizer
+        self.generation_config = generation_config
+
         self.template = (
             "Create a list of tags for the text below. The tags should be high level "
            "and specific. Prefix each tag with a hashtag.\n\n{}\n\nTags: #general"
         )
-        self.generation_config = GenerationConfig.from_pretrained(MODEL_NAME)
-        self.generation_config.max_new_tokens = 50
-        self.generation_config.min_new_tokens = 25
-        # increase the temperature to make the model more creative
-        self.generation_config.temperature = 1.5
 
     def _extract_tags(self, text: str) -> list[str]:
         tags = set()
@@ -57,46 +131,25 @@ class Tagger:
 
     def __call__(self, x: str) -> list[str]:
         text = self.template.format(x)
-        inputs = tokenizer(text, return_tensors="pt")
-        outputs = model.generate(**inputs, generation_config=self.generation_config)
-        output = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
+        inputs = self.tokenizer(text, return_tensors="pt")
+        outputs = self.model.generate(**inputs, generation_config=self.generation_config)
+        output = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
         tags = self._extract_tags(output)
         return tags
 
     def get_name(self) -> str:
-        return f"Tagger({MODEL_NAME})"
-
-
-class Processor(abc.ABC):
-    def __call__(self, job: JobInput) -> str:
-        _id = job.id
-        logger.info(f"Processing {input} with {self.__class__.__name__} (id={_id[:8]})")
-        result = self.process(job)
-        logger.info(f"Finished processing input (id={_id[:8]})")
-        return result
-
-    def process(self, input: JobInput) -> str:
-        raise NotImplementedError
-
-    def match(self, input: JobInput) -> bool:
-        raise NotImplementedError
-
-    def get_name(self) -> str:
-        raise NotImplementedError
+        return f"{self.__class__.__name__}({self.model_name})"
 
 
-class RawProcessor(Processor):
+class RawTextProcessor(Processor):
     def match(self, input: JobInput) -> bool:
         return True
 
     def process(self, input: JobInput) -> str:
         return input.content
 
-    def get_name(self) -> str:
-        return self.__class__.__name__
-
 
-class PlainUrlProcessor(Processor):
+class DefaultUrlProcessor(Processor):
     def __init__(self) -> None:
         self.client = httpx.Client()
         self.regex = re.compile(r"(https?://[^\s]+)")
@@ -118,26 +171,22 @@ class PlainUrlProcessor(Processor):
         text = self.template.format(url=self.url, content=text)
         return text
 
-    def get_name(self) -> str:
-        return self.__class__.__name__
+# class ProcessorRegistry:
+#     def __init__(self) -> None:
+#         self.registry: list[Processor] = []
+#         self.default_registry: list[Processor] = []
+#         self.set_default_processors()
 
+#     def set_default_processors(self) -> None:
+#         self.default_registry.extend([PlainUrlProcessor(), RawProcessor()])
 
-class ProcessorRegistry:
-    def __init__(self) -> None:
-        self.registry: list[Processor] = []
-        self.default_registry: list[Processor] = []
-        self.set_default_processors()
+#     def register(self, processor: Processor) -> None:
+#         self.registry.append(processor)
 
-    def set_default_processors(self) -> None:
-        self.default_registry.extend([PlainUrlProcessor(), RawProcessor()])
-
-    def register(self, processor: Processor) -> None:
-        self.registry.append(processor)
-
-    def dispatch(self, input: JobInput) -> Processor:
-        for processor in self.registry + self.default_registry:
-            if processor.match(input):
-                return processor
+#     def dispatch(self, input: JobInput) -> Processor:
+#         for processor in self.registry + self.default_registry:
+#             if processor.match(input):
+#                 return processor
 
-        # should never be requires, but eh
-        return RawProcessor()
+#         # should never be requires, but eh
+#         return RawProcessor()
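Note: the Hugging Face specifics now live in HfTransformersSummarizer and HfTransformersTagger, while Processor, Summarizer, and Tagger are plain interfaces and MlRegistry only holds whatever implementations get registered. As a result ml.py no longer imports transformers at all, and the registry can be wired with lightweight stand-ins. A rough sketch with hypothetical test doubles (FakeSummarizer and FakeTagger are not part of the commit):

from ml import MlRegistry, RawTextProcessor, Summarizer, Tagger


class FakeSummarizer(Summarizer):
    def __init__(self) -> None:  # the HF implementation takes model/tokenizer/config instead
        pass

    def get_name(self) -> str:
        return "FakeSummarizer"

    def __call__(self, x: str) -> str:
        return x[:100]


class FakeTagger(Tagger):
    def __init__(self) -> None:
        pass

    def get_name(self) -> str:
        return "FakeTagger"

    def __call__(self, x: str) -> list[str]:
        return ["#general"]


registry = MlRegistry()
registry.register_processor(RawTextProcessor())
registry.register_summarizer(FakeSummarizer())
registry.register_tagger(FakeTagger())

print(registry.get_summarizer()("some long text"))  # "some long text"
print(registry.get_tagger()("some long text"))      # ["#general"]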
src/webservice.py CHANGED
@@ -14,6 +14,12 @@ logger.setLevel(logging.DEBUG)
 app = FastAPI()
 
 
+# status
+@app.get("/status/")
+def status() -> str:
+    return "OK"
+
+
 @app.post("/submit/")
 def submit_job(input: RequestInput) -> str:
     # submit a new job, poor man's job queue
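Note: the new /status/ route is a trivial health check for the FastAPI app. A quick probe, assuming the service is running under uvicorn on the default local address (the URL is an assumption, not part of the commit):

import httpx

response = httpx.get("http://localhost:8000/status/")
assert response.status_code == 200
print(response.json())  # "OK"

This gives a load balancer or container orchestrator something cheap to poll without touching the job queue.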
src/worker.py CHANGED
@@ -1,18 +1,19 @@
 import time
+from dataclasses import dataclass
 
 from base import JobInput
 from db import get_db_cursor
-from ml import ProcessorRegistry, Summarizer, Tagger
+from ml import (
+    DefaultUrlProcessor,
+    HfTransformersSummarizer,
+    HfTransformersTagger,
+    MlRegistry,
+    RawTextProcessor,
+)
 
 SLEEP_INTERVAL = 5
 
 
-processor_registry = ProcessorRegistry()
-summarizer = Summarizer()
-tagger = Tagger()
-print("loaded ML models")
-
-
 def check_pending_jobs() -> list[JobInput]:
     """Check DB for pending jobs"""
     with get_db_cursor() as cursor:
@@ -30,15 +31,38 @@ def check_pending_jobs() -> list[JobInput]:
     ]
 
 
-def store(
-    job: JobInput,
-    *,
-    summary: str,
-    tags: list[str],
-    processor_name: str,
-    summarizer_name: str,
-    tagger_name: str,
-) -> None:
+@dataclass
+class JobOutput:
+    summary: str
+    tags: list[str]
+    processor_name: str
+    summarizer_name: str
+    tagger_name: str
+
+
+def _process_job(job: JobInput, registry: MlRegistry) -> JobOutput:
+    processor = registry.get_processor(job)
+    processor_name = processor.get_name()
+    processed = processor(job)
+
+    tagger = registry.get_tagger()
+    tagger_name = tagger.get_name()
+    tags = tagger(processed)
+
+    summarizer = registry.get_summarizer()
+    summarizer_name = summarizer.get_name()
+    summary = summarizer(processed)
+
+    return JobOutput(
+        summary=summary,
+        tags=tags,
+        processor_name=processor_name,
+        summarizer_name=summarizer_name,
+        tagger_name=tagger_name,
+    )
+
+
+def store(job: JobInput, output: JobOutput) -> None:
     with get_db_cursor() as cursor:
         # write to entries, summary, tags tables
         cursor.execute(
@@ -46,39 +70,23 @@ def store(
                 "INSERT INTO summaries (entry_id, summary, summarizer_name)"
                 " VALUES (?, ?, ?)"
             ),
-            (job.id, summary, summarizer_name),
+            (job.id, output.summary, output.summarizer_name),
         )
         cursor.executemany(
             "INSERT INTO tags (entry_id, tag, tagger_name) VALUES (?, ?, ?)",
-            [(job.id, tag, tagger_name) for tag in tags],
+            [(job.id, tag, output.tagger_name) for tag in output.tags],
         )
 
 
-def process_job(job: JobInput) -> None:
+def process_job(job: JobInput, registry: MlRegistry) -> None:
     tic = time.perf_counter()
     print(f"Processing job for (id={job.id[:8]})")
 
     # care: acquire cursor (which leads to locking) as late as possible, since
    # the processing and we don't want to block other workers during that time
     try:
-        processor = processor_registry.dispatch(job)
-        processor_name = processor.get_name()
-        processed = processor(job)
-
-        tagger_name = tagger.get_name()
-        tags = tagger(processed)
-
-        summarizer_name = summarizer.get_name()
-        summary = summarizer(processed)
-
-        store(
-            job,
-            summary=summary,
-            tags=tags,
-            processor_name=processor_name,
-            summarizer_name=summarizer_name,
-            tagger_name=tagger_name,
-        )
+        output = _process_job(job, registry)
+        store(job, output)
         # update job status to done
         with get_db_cursor() as cursor:
            cursor.execute(
@@ -96,7 +104,40 @@ def process_job(job: JobInput) -> None:
     print(f"Finished processing job (id={job.id[:8]}) in {toc - tic:0.3f} seconds")
 
 
+def load_mlregistry(model_name: str) -> MlRegistry:
+    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig
+
+    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+    config_summarizer = GenerationConfig.from_pretrained(model_name)
+    config_summarizer.max_new_tokens = 200
+    config_summarizer.min_new_tokens = 100
+    config_summarizer.top_k = 5
+    config_summarizer.repetition_penalty = 1.5
+
+    config_tagger = GenerationConfig.from_pretrained(model_name)
+    config_tagger.max_new_tokens = 50
+    config_tagger.min_new_tokens = 25
+    # increase the temperature to make the model more creative
+    config_tagger.temperature = 1.5
+
+    summarizer = HfTransformersSummarizer(model_name, model, tokenizer, config_summarizer)
+    tagger = HfTransformersTagger(model_name, model, tokenizer, config_tagger)
+
+    registry = MlRegistry()
+    registry.register_processor(DefaultUrlProcessor())
+    registry.register_processor(RawTextProcessor())
+    registry.register_summarizer(summarizer)
+    registry.register_tagger(tagger)
+
+    return registry
+
+
 def main() -> None:
+    model_name = "google/flan-t5-large"
+    registry = load_mlregistry(model_name)
+
     while True:
         jobs = check_pending_jobs()
         if not jobs:
@@ -106,7 +147,7 @@ def main() -> None:
 
     print(f"Found {len(jobs)} pending job(s), processing...")
     for job in jobs:
-        process_job(job)
+        process_job(job, registry)
 
 
 if __name__ == "__main__":
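Note: model loading moved out of module import and into load_mlregistry, so the transformers import only happens inside the worker process and the checkpoint is now a parameter. A sketch of reusing it with a smaller checkpoint for a local run (google/flan-t5-small is an example choice; the commit itself keeps google/flan-t5-large):

from worker import load_mlregistry

registry = load_mlregistry("google/flan-t5-small")
print(registry.get_summarizer().get_name())  # HfTransformersSummarizer(google/flan-t5-small)
print(registry.get_tagger().get_name())      # HfTransformersTagger(google/flan-t5-small)

Pending jobs are then handled by process_job(job, registry), with the registry passed explicitly instead of read from module globals.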