Benjamin Bossan commited on
Commit
91194ca
1 Parent(s): 6ac056a

Move preprocessing & registry to separate modules

Browse files
src/gistillery/ml.py CHANGED
@@ -1,36 +1,11 @@
1
  import abc
2
  from typing import Any
3
  import logging
4
- import re
5
-
6
- import httpx
7
-
8
- from gistillery.base import JobInput
9
 
10
  logger = logging.getLogger(__name__)
11
  logger.setLevel(logging.DEBUG)
12
 
13
 
14
- class Processor(abc.ABC):
15
- def get_name(self) -> str:
16
- return self.__class__.__name__
17
-
18
- def __call__(self, job: JobInput) -> str:
19
- _id = job.id
20
- logger.info(f"Processing {input} with {self.__class__.__name__} (id={_id[:8]})")
21
- result = self.process(job)
22
- logger.info(f"Finished processing input (id={_id[:8]})")
23
- return result
24
-
25
- @abc.abstractmethod
26
- def process(self, input: JobInput) -> str:
27
- raise NotImplementedError
28
-
29
- @abc.abstractmethod
30
- def match(self, input: JobInput) -> bool:
31
- raise NotImplementedError
32
-
33
-
34
  class Summarizer(abc.ABC):
35
  def __init__(
36
  self, model_name: str, model: Any, tokenizer: Any, generation_config: Any
@@ -59,40 +34,6 @@ class Tagger(abc.ABC):
59
  raise NotImplementedError
60
 
61
 
62
- class MlRegistry:
63
- def __init__(self) -> None:
64
- self.processors: list[Processor] = []
65
- self.summerizer: Summarizer | None = None
66
- self.tagger: Tagger | None = None
67
- self.model = None
68
- self.tokenizer = None
69
-
70
- def register_processor(self, processor: Processor) -> None:
71
- self.processors.append(processor)
72
-
73
- def register_summarizer(self, summarizer: Summarizer) -> None:
74
- self.summerizer = summarizer
75
-
76
- def register_tagger(self, tagger: Tagger) -> None:
77
- self.tagger = tagger
78
-
79
- def get_processor(self, input: JobInput) -> Processor:
80
- assert self.processors
81
- for processor in self.processors:
82
- if processor.match(input):
83
- return processor
84
-
85
- return RawTextProcessor()
86
-
87
- def get_summarizer(self) -> Summarizer:
88
- assert self.summerizer
89
- return self.summerizer
90
-
91
- def get_tagger(self) -> Tagger:
92
- assert self.tagger
93
- return self.tagger
94
-
95
-
96
  class HfTransformersSummarizer(Summarizer):
97
  def __init__(
98
  self, model_name: str, model: Any, tokenizer: Any, generation_config: Any
@@ -152,34 +93,3 @@ class HfTransformersTagger(Tagger):
152
 
153
  def get_name(self) -> str:
154
  return f"{self.__class__.__name__}({self.model_name})"
155
-
156
-
157
- class RawTextProcessor(Processor):
158
- def match(self, input: JobInput) -> bool:
159
- return True
160
-
161
- def process(self, input: JobInput) -> str:
162
- return input.content
163
-
164
-
165
- class DefaultUrlProcessor(Processor):
166
- def __init__(self) -> None:
167
- self.client = httpx.Client()
168
- self.regex = re.compile(r"(https?://[^\s]+)")
169
- self.url = None
170
- self.template = "{url}\n\n{content}"
171
-
172
- def match(self, input: JobInput) -> bool:
173
- urls = list(self.regex.findall(input.content))
174
- if len(urls) == 1:
175
- self.url = urls[0]
176
- return True
177
- return False
178
-
179
- def process(self, input: JobInput) -> str:
180
- """Get content of website and return it as string"""
181
- assert isinstance(self.url, str)
182
- text = self.client.get(self.url).text
183
- assert isinstance(text, str)
184
- text = self.template.format(url=self.url, content=text)
185
- return text
 
1
  import abc
2
  from typing import Any
3
  import logging
 
 
 
 
 
4
 
5
  logger = logging.getLogger(__name__)
6
  logger.setLevel(logging.DEBUG)
7
 
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  class Summarizer(abc.ABC):
10
  def __init__(
11
  self, model_name: str, model: Any, tokenizer: Any, generation_config: Any
 
34
  raise NotImplementedError
35
 
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  class HfTransformersSummarizer(Summarizer):
38
  def __init__(
39
  self, model_name: str, model: Any, tokenizer: Any, generation_config: Any
 
93
 
94
  def get_name(self) -> str:
95
  return f"{self.__class__.__name__}({self.model_name})"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/gistillery/preprocessing.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import abc
2
+ import logging
3
+ import re
4
+
5
+ import httpx
6
+
7
+ from gistillery.base import JobInput
8
+
9
+ logger = logging.getLogger(__name__)
10
+ logger.setLevel(logging.DEBUG)
11
+
12
+
13
+
14
+ class Processor(abc.ABC):
15
+ def get_name(self) -> str:
16
+ return self.__class__.__name__
17
+
18
+ def __call__(self, job: JobInput) -> str:
19
+ _id = job.id
20
+ logger.info(f"Processing {input} with {self.__class__.__name__} (id={_id[:8]})")
21
+ result = self.process(job)
22
+ logger.info(f"Finished processing input (id={_id[:8]})")
23
+ return result
24
+
25
+ @abc.abstractmethod
26
+ def process(self, input: JobInput) -> str:
27
+ raise NotImplementedError
28
+
29
+ @abc.abstractmethod
30
+ def match(self, input: JobInput) -> bool:
31
+ raise NotImplementedError
32
+
33
+
34
+ class RawTextProcessor(Processor):
35
+ def match(self, input: JobInput) -> bool:
36
+ return True
37
+
38
+ def process(self, input: JobInput) -> str:
39
+ return input.content.strip()
40
+
41
+
42
+ class DefaultUrlProcessor(Processor):
43
+ def __init__(self) -> None:
44
+ self.client = httpx.Client()
45
+ self.regex = re.compile(r"(https?://[^\s]+)")
46
+ self.url = None
47
+ self.template = "{url}\n\n{content}"
48
+
49
+ def match(self, input: JobInput) -> bool:
50
+ urls = list(self.regex.findall(input.content.strip()))
51
+ if len(urls) == 1:
52
+ self.url = urls[0]
53
+ return True
54
+ return False
55
+
56
+ def process(self, input: JobInput) -> str:
57
+ """Get content of website and return it as string"""
58
+ assert isinstance(self.url, str)
59
+ text = self.client.get(self.url).text
60
+ assert isinstance(text, str)
61
+ text = self.template.format(url=self.url, content=text)
62
+ return text
src/gistillery/registry.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from gistillery.ml import Summarizer, Tagger
2
+ from gistillery.preprocessing import Processor, RawTextProcessor
3
+
4
+ from gistillery.base import JobInput
5
+
6
+
7
+ class MlRegistry:
8
+ def __init__(self) -> None:
9
+ self.processors: list[Processor] = []
10
+ self.summerizer: Summarizer | None = None
11
+ self.tagger: Tagger | None = None
12
+ self.model = None
13
+ self.tokenizer = None
14
+
15
+ def register_processor(self, processor: Processor) -> None:
16
+ self.processors.append(processor)
17
+
18
+ def register_summarizer(self, summarizer: Summarizer) -> None:
19
+ self.summerizer = summarizer
20
+
21
+ def register_tagger(self, tagger: Tagger) -> None:
22
+ self.tagger = tagger
23
+
24
+ def get_processor(self, input: JobInput) -> Processor:
25
+ assert self.processors
26
+ for processor in self.processors:
27
+ if processor.match(input):
28
+ return processor
29
+
30
+ return RawTextProcessor()
31
+
32
+ def get_summarizer(self) -> Summarizer:
33
+ assert self.summerizer
34
+ return self.summerizer
35
+
36
+ def get_tagger(self) -> Tagger:
37
+ assert self.tagger
38
+ return self.tagger
src/gistillery/worker.py CHANGED
@@ -3,13 +3,9 @@ from dataclasses import dataclass
3
 
4
  from gistillery.base import JobInput
5
  from gistillery.db import get_db_cursor
6
- from gistillery.ml import (
7
- DefaultUrlProcessor,
8
- HfTransformersSummarizer,
9
- HfTransformersTagger,
10
- MlRegistry,
11
- RawTextProcessor,
12
- )
13
 
14
  SLEEP_INTERVAL = 5
15
 
 
3
 
4
  from gistillery.base import JobInput
5
  from gistillery.db import get_db_cursor
6
+ from gistillery.ml import HfTransformersSummarizer, HfTransformersTagger
7
+ from gistillery.preprocessing import DefaultUrlProcessor, RawTextProcessor
8
+ from gistillery.registry import MlRegistry
 
 
 
 
9
 
10
  SLEEP_INTERVAL = 5
11
 
tests/test_app.py CHANGED
@@ -37,7 +37,9 @@ class TestWebservice:
37
  @pytest.fixture
38
  def mlregistry(self):
39
  # use dummy models
40
- from gistillery.ml import MlRegistry, RawTextProcessor, Summarizer, Tagger
 
 
41
 
42
  class DummySummarizer(Summarizer):
43
  """Returns the first 10 characters of the input"""
 
37
  @pytest.fixture
38
  def mlregistry(self):
39
  # use dummy models
40
+ from gistillery.ml import Summarizer, Tagger
41
+ from gistillery.preprocessing import RawTextProcessor
42
+ from gistillery.registry import MlRegistry
43
 
44
  class DummySummarizer(Summarizer):
45
  """Returns the first 10 characters of the input"""