Benjamin Bossan committed on
Commit
1643735
1 Parent(s): 3efe4b4
Files changed (2) hide show
  1. src/gistillery/ml.py +19 -6
  2. src/gistillery/worker.py +3 -1
src/gistillery/ml.py CHANGED
@@ -32,7 +32,9 @@ class Processor(abc.ABC):
32
 
33
 
34
  class Summarizer(abc.ABC):
35
- def __init__(self, model_name: str, model: Any, tokenizer: Any, generation_config: Any) -> None:
 
 
36
  raise NotImplementedError
37
 
38
  def get_name(self) -> str:
@@ -44,7 +46,9 @@ class Summarizer(abc.ABC):
44
 
45
 
46
  class Tagger(abc.ABC):
47
- def __init__(self, model_name: str, model: Any, tokenizer: Any, generation_config: Any) -> None:
 
 
48
  raise NotImplementedError
49
 
50
  def get_name(self) -> str:
@@ -90,7 +94,9 @@ class MlRegistry:
90
 
91
 
92
  class HfTransformersSummarizer(Summarizer):
93
- def __init__(self, model_name: str, model: Any, tokenizer: Any, generation_config: Any) -> None:
 
 
94
  self.model_name = model_name
95
  self.model = model
96
  self.tokenizer = tokenizer
@@ -101,7 +107,9 @@ class HfTransformersSummarizer(Summarizer):
101
  def __call__(self, x: str) -> str:
102
  text = self.template.format(x)
103
  inputs = self.tokenizer(text, return_tensors="pt")
104
- outputs = self.model.generate(**inputs, generation_config=self.generation_config)
 
 
105
  output = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
106
  assert isinstance(output, str)
107
  return output
@@ -111,7 +119,9 @@ class HfTransformersSummarizer(Summarizer):
111
 
112
 
113
  class HfTransformersTagger(Tagger):
114
- def __init__(self, model_name: str, model: Any, tokenizer: Any, generation_config: Any) -> None:
 
 
115
  self.model_name = model_name
116
  self.model = model
117
  self.tokenizer = tokenizer
@@ -132,7 +142,9 @@ class HfTransformersTagger(Tagger):
132
  def __call__(self, x: str) -> list[str]:
133
  text = self.template.format(x)
134
  inputs = self.tokenizer(text, return_tensors="pt")
135
- outputs = self.model.generate(**inputs, generation_config=self.generation_config)
 
 
136
  output = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
137
  tags = self._extract_tags(output)
138
  return tags
@@ -171,6 +183,7 @@ class DefaultUrlProcessor(Processor):
171
  text = self.template.format(url=self.url, content=text)
172
  return text
173
 
 
174
  # class ProcessorRegistry:
175
  # def __init__(self) -> None:
176
  # self.registry: list[Processor] = []
 
32
 
33
 
34
  class Summarizer(abc.ABC):
35
+ def __init__(
36
+ self, model_name: str, model: Any, tokenizer: Any, generation_config: Any
37
+ ) -> None:
38
  raise NotImplementedError
39
 
40
  def get_name(self) -> str:
 
46
 
47
 
48
  class Tagger(abc.ABC):
49
+ def __init__(
50
+ self, model_name: str, model: Any, tokenizer: Any, generation_config: Any
51
+ ) -> None:
52
  raise NotImplementedError
53
 
54
  def get_name(self) -> str:
 
94
 
95
 
96
  class HfTransformersSummarizer(Summarizer):
97
+ def __init__(
98
+ self, model_name: str, model: Any, tokenizer: Any, generation_config: Any
99
+ ) -> None:
100
  self.model_name = model_name
101
  self.model = model
102
  self.tokenizer = tokenizer
 
107
  def __call__(self, x: str) -> str:
108
  text = self.template.format(x)
109
  inputs = self.tokenizer(text, return_tensors="pt")
110
+ outputs = self.model.generate(
111
+ **inputs, generation_config=self.generation_config
112
+ )
113
  output = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
114
  assert isinstance(output, str)
115
  return output
 
119
 
120
 
121
  class HfTransformersTagger(Tagger):
122
+ def __init__(
123
+ self, model_name: str, model: Any, tokenizer: Any, generation_config: Any
124
+ ) -> None:
125
  self.model_name = model_name
126
  self.model = model
127
  self.tokenizer = tokenizer
 
142
  def __call__(self, x: str) -> list[str]:
143
  text = self.template.format(x)
144
  inputs = self.tokenizer(text, return_tensors="pt")
145
+ outputs = self.model.generate(
146
+ **inputs, generation_config=self.generation_config
147
+ )
148
  output = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
149
  tags = self._extract_tags(output)
150
  return tags
 
183
  text = self.template.format(url=self.url, content=text)
184
  return text
185
 
186
+
187
  # class ProcessorRegistry:
188
  # def __init__(self) -> None:
189
  # self.registry: list[Processor] = []
src/gistillery/worker.py CHANGED
@@ -122,7 +122,9 @@ def load_mlregistry(model_name: str) -> MlRegistry:
122
  # increase the temperature to make the model more creative
123
  config_tagger.temperature = 1.5
124
 
125
- summarizer = HfTransformersSummarizer(model_name, model, tokenizer, config_summarizer)
 
 
126
  tagger = HfTransformersTagger(model_name, model, tokenizer, config_tagger)
127
 
128
  registry = MlRegistry()
 
122
  # increase the temperature to make the model more creative
123
  config_tagger.temperature = 1.5
124
 
125
+ summarizer = HfTransformersSummarizer(
126
+ model_name, model, tokenizer, config_summarizer
127
+ )
128
  tagger = HfTransformersTagger(model_name, model, tokenizer, config_tagger)
129
 
130
  registry = MlRegistry()