pszemraj committed on
Commit
f18cdf1
•
1 Parent(s): 1d73f50

🎨 📝

Browse files

Signed-off-by: peter szemraj <peterszemraj@gmail.com>

Files changed (1) hide show
  1. aggregate.py +31 -18
aggregate.py CHANGED
@@ -1,3 +1,12 @@
 
 
 
 
 
 
 
 
 
1
  import pprint as pp
2
  import logging
3
  import time
@@ -14,10 +23,15 @@ logging.basicConfig(
14
 
15
 
16
  class BatchAggregator:
17
- CONFIGURED_MODELS = [
18
- "pszemraj/bart-large-mnli-dolly_hhrlhf-v1"
19
- ] # TODO: Add models here
20
- DEFAULT_INSTRUCTION = "Write a comprehensive yet concise summary that pulls together the main points of the following text:"
 
 
 
 
 
21
  GENERIC_CONFIG = GenerationConfig(
22
  num_beams=8,
23
  early_stopping=True,
@@ -29,10 +43,23 @@ class BatchAggregator:
29
  no_repeat_ngram_size=4,
30
  encoder_no_repeat_ngram_size=5,
31
  )
 
 
 
 
 
 
 
 
32
 
33
  def __init__(
34
  self, model_name: str = "pszemraj/bart-large-mnli-dolly_hhrlhf-v1", **kwargs
35
  ):
 
 
 
 
 
36
  self.device = None
37
  self.is_compiled = False
38
  self.logger = logging.getLogger(__name__)
@@ -125,20 +152,6 @@ class BatchAggregator:
125
  """
126
  self.aggregator.model.generation_config = self.GENERIC_CONFIG
127
 
128
- if "bart" in self.model_name.lower():
129
- self.logger.info("Using BART model, updating generation config")
130
- upd = {
131
- "num_beams": 8,
132
- "repetition_penalty": 1.3,
133
- "length_penalty": 1.0,
134
- "_from_model_config": False,
135
- "max_new_tokens": 256,
136
- "min_new_tokens": 32,
137
- "no_repeat_ngram_size": 3,
138
- "encoder_no_repeat_ngram_size": 6,
139
- } # TODO: clean up
140
- self.aggregator.model.generation_config.update(**upd)
141
-
142
  if (
143
  "large"
144
  or "xl" in self.model_name.lower()
 
1
+ """
2
+ aggregate.py is a module for aggregating text from multiple sources, or multiple parts of a single source.
3
+ Primary usage is through the BatchAggregator class.
4
+
5
+ How it works:
6
+ 1. We tell the language model to do it.
7
+ 2. The language model does it.
8
+ 3. Yaay!
9
+ """
10
  import pprint as pp
11
  import logging
12
  import time
 
23
 
24
 
25
  class BatchAggregator:
26
+ """
27
+ BatchAggregator is a class for aggregating text from multiple sources.
28
+
29
+ Usage:
30
+ >>> from aggregate import BatchAggregator
31
+ >>> aggregator = BatchAggregator()
32
+ >>> aggregator.aggregate(["This is a test", "This is another test"])
33
+ """
34
+
35
  GENERIC_CONFIG = GenerationConfig(
36
  num_beams=8,
37
  early_stopping=True,
 
43
  no_repeat_ngram_size=4,
44
  encoder_no_repeat_ngram_size=5,
45
  )
46
+ CONFIGURED_MODELS = [
47
+ "pszemraj/bart-large-mnli-dolly_hhrlhf-v1",
48
+ "pszemraj/bart-base-instruct-dolly_hhrlhf",
49
+ "pszemraj/flan-t5-large-instruct-dolly_hhrlhf",
50
+ "pszemraj/flan-t5-base-instruct-dolly_hhrlhf",
51
+ ] # these have generation configs defined for this task in their model repos
52
+
53
+ DEFAULT_INSTRUCTION = "Write a comprehensive yet concise summary that pulls together the main points of the following text:"
54
 
55
  def __init__(
56
  self, model_name: str = "pszemraj/bart-large-mnli-dolly_hhrlhf-v1", **kwargs
57
  ):
58
+ """
59
+ __init__ initializes the BatchAggregator class.
60
+
61
+ :param str model_name: model name to use, default: "pszemraj/bart-large-mnli-dolly_hhrlhf-v1"
62
+ """
63
  self.device = None
64
  self.is_compiled = False
65
  self.logger = logging.getLogger(__name__)
 
152
  """
153
  self.aggregator.model.generation_config = self.GENERIC_CONFIG
154
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  if (
156
  "large"
157
  or "xl" in self.model_name.lower()