Chamin09 committed on
Commit
a40ee8d
·
verified ·
1 Parent(s): 38a54b1

Update models/summary_models.py

Browse files
Files changed (1) hide show
  1. models/summary_models.py +270 -249
models/summary_models.py CHANGED
@@ -1,249 +1,270 @@
1
- # models/summary_models.py
2
- import logging
3
- from typing import Dict, List, Optional, Tuple, Union, Any
4
- import torch
5
- from transformers import T5Tokenizer, T5ForConditionalGeneration
6
-
7
- class SummaryModelManager:
8
- def __init__(self, token_manager=None, cache_manager=None, metrics_calculator=None):
9
- """Initialize the SummaryModelManager with optional utilities."""
10
- self.logger = logging.getLogger(__name__)
11
- self.token_manager = token_manager
12
- self.cache_manager = cache_manager
13
- self.metrics_calculator = metrics_calculator
14
-
15
- # Model instance
16
- self.model = None
17
- self.tokenizer = None
18
-
19
- # Model name
20
- self.model_name = "t5-small"
21
-
22
- # Track initialization state
23
- self.initialized = False
24
-
25
- # Default generation parameters
26
- self.default_params = {
27
- "max_length": 150,
28
- "min_length": 40,
29
- "length_penalty": 2.0,
30
- "num_beams": 4,
31
- "early_stopping": True
32
- }
33
-
34
- def initialize_model(self):
35
- """Initialize the summarization model."""
36
- if self.initialized:
37
- return
38
-
39
- try:
40
- # Register with token manager if available
41
- if self.token_manager:
42
- self.token_manager.register_model(
43
- self.model_name, "summarization")
44
-
45
- # Load model and tokenizer
46
- self.logger.info(f"Loading summary model: {self.model_name}")
47
- self.tokenizer = T5Tokenizer.from_pretrained(self.model_name)
48
- self.model = T5ForConditionalGeneration.from_pretrained(self.model_name)
49
-
50
- self.initialized = True
51
- self.logger.info("Summary model initialized successfully")
52
-
53
- except Exception as e:
54
- self.logger.error(f"Failed to initialize summary model: {e}")
55
- raise
56
-
57
- def generate_summary(self, text: str, prefix: str = "summarize: ",
58
- agent_name: str = "report_generation",
59
- params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
60
- """
61
- Generate a summary of the given text.
62
- Returns the summary and metadata.
63
- """
64
- # Initialize model if needed
65
- if not self.initialized:
66
- self.initialize_model()
67
-
68
- # Prepare input text
69
- input_text = f"{prefix}{text}"
70
-
71
- # Check cache if available
72
- if self.cache_manager:
73
- cache_key = input_text[:100] + str(hash(input_text)) # Use prefix of text + hash as key
74
- cache_hit, cached_result = self.cache_manager.get(
75
- cache_key, namespace="summaries")
76
-
77
- if cache_hit:
78
- # Update metrics if available
79
- if self.metrics_calculator:
80
- self.metrics_calculator.update_cache_metrics(1, 0, 0.005) # Estimated energy saving
81
- return cached_result
82
-
83
- # Request token budget if available
84
- if self.token_manager:
85
- approved, reason = self.token_manager.request_tokens(
86
- agent_name, "summarization", input_text, self.model_name)
87
-
88
- if not approved:
89
- self.logger.warning(f"Token budget exceeded: {reason}")
90
- return {"summary": "Token budget exceeded", "error": reason}
91
-
92
- # Tokenize
93
- inputs = self.tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
94
-
95
- # Merge default and custom parameters
96
- generation_params = self.default_params.copy()
97
- if params:
98
- generation_params.update(params)
99
-
100
- # Generate summary
101
- with torch.no_grad():
102
- output_ids = self.model.generate(
103
- inputs.input_ids,
104
- **generation_params
105
- )
106
-
107
- # Decode summary
108
- summary = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
109
-
110
- # Calculate compression ratio
111
- input_length = len(text.split())
112
- summary_length = len(summary.split())
113
- compression_ratio = input_length / max(summary_length, 1)
114
-
115
- # Prepare result
116
- result = {
117
- "summary": summary,
118
- "input_length": input_length,
119
- "summary_length": summary_length,
120
- "compression_ratio": compression_ratio
121
- }
122
-
123
- # Log token usage if available
124
- if self.token_manager:
125
- input_tokens = len(inputs.input_ids[0])
126
- output_tokens = len(output_ids[0])
127
- total_tokens = input_tokens + output_tokens
128
-
129
- self.token_manager.log_usage(
130
- agent_name, "summarization", total_tokens, self.model_name)
131
-
132
- # Log energy usage if metrics calculator is available
133
- if self.metrics_calculator:
134
- energy_usage = self.token_manager.calculate_energy_usage(
135
- total_tokens, self.model_name)
136
- self.metrics_calculator.log_energy_usage(
137
- energy_usage, self.model_name, agent_name, "summarization")
138
-
139
- # Store in cache if available
140
- if self.cache_manager:
141
- self.cache_manager.put(cache_key, result, namespace="summaries")
142
-
143
- return result
144
-
145
- def generate_executive_summary(self, detailed_content: str, confidence_level: float,
146
- agent_name: str = "report_generation") -> Dict[str, Any]:
147
- """
148
- Generate an executive summary with confidence indication.
149
- Adjusts detail level based on confidence.
150
- """
151
- # Prepare prompt based on confidence
152
- if confidence_level >= 0.7:
153
- prefix = "summarize with high confidence: "
154
- params = {"min_length": 30, "max_length": 100}
155
- elif confidence_level >= 0.4:
156
- prefix = "summarize with moderate confidence: "
157
- params = {"min_length": 20, "max_length": 80}
158
- else:
159
- prefix = "summarize with low confidence: "
160
- params = {"min_length": 15, "max_length": 60}
161
-
162
- # Generate summary
163
- result = self.generate_summary(detailed_content, prefix=prefix,
164
- agent_name=agent_name, params=params)
165
-
166
- # Add confidence level to result
167
- result["confidence_level"] = confidence_level
168
-
169
- # Add confidence statement
170
- confidence_statement = self._generate_confidence_statement(confidence_level)
171
- result["confidence_statement"] = confidence_statement
172
-
173
- return result
174
-
175
- def _generate_confidence_statement(self, confidence_level: float) -> str:
176
- """Generate an appropriate confidence statement based on the level."""
177
- if confidence_level >= 0.8:
178
- return "This analysis is provided with high confidence based on strong evidence in the provided materials."
179
- elif confidence_level >= 0.6:
180
- return "This analysis is provided with good confidence based on substantial evidence in the provided materials."
181
- elif confidence_level >= 0.4:
182
- return "This analysis is provided with moderate confidence. Some aspects may require additional verification."
183
- elif confidence_level >= 0.2:
184
- return "This analysis is provided with limited confidence due to sparse relevant information in the provided materials."
185
- else:
186
- return "This analysis is provided with very low confidence due to insufficient relevant information in the provided materials."
187
-
188
- def combine_analyses(self, text_analyses: List[Dict[str, Any]],
189
- image_analyses: List[Dict[str, Any]],
190
- topic: str, agent_name: str = "report_generation") -> Dict[str, Any]:
191
- """
192
- Combine text and image analyses into a coherent report.
193
- Returns the combined report with metadata.
194
- """
195
- # Build combined content
196
- combined_content = f"Topic: {topic}\n\n"
197
-
198
- # Add text analyses
199
- combined_content += "Text Analysis:\n"
200
- for i, analysis in enumerate(text_analyses):
201
- if "error" in analysis:
202
- continue
203
- combined_content += f"- Document {i+1}: {analysis.get('summary', 'No summary available')}\n"
204
-
205
- # Add image analyses
206
- combined_content += "\nImage Analysis:\n"
207
- for i, analysis in enumerate(image_analyses):
208
- if "error" in analysis:
209
- continue
210
- combined_content += f"- Image {i+1}: {analysis.get('caption', 'No caption available')}\n"
211
-
212
- # Calculate overall confidence based on analyses
213
- text_confidence = sum(a.get("confidence", 0) for a in text_analyses) / max(len(text_analyses), 1)
214
- image_confidence = sum(a.get("confidence", 0) for a in image_analyses) / max(len(image_analyses), 1)
215
-
216
- # Weight confidence (text analyses typically more important for deep dives)
217
- overall_confidence = 0.7 * text_confidence + 0.3 * image_confidence
218
-
219
- # Generate detailed report
220
- detailed_report = self.generate_summary(
221
- combined_content,
222
- prefix=f"generate detailed report about {topic}: ",
223
- agent_name=agent_name,
224
- params={"max_length": 300, "min_length": 100}
225
- )
226
-
227
- # Generate executive summary
228
- executive_summary = self.generate_executive_summary(
229
- detailed_report["summary"],
230
- overall_confidence,
231
- agent_name
232
- )
233
-
234
- # Combine results
235
- result = {
236
- "topic": topic,
237
- "executive_summary": executive_summary["summary"],
238
- "confidence_statement": executive_summary["confidence_statement"],
239
- "detailed_report": detailed_report["summary"],
240
- "confidence_level": overall_confidence,
241
- "text_confidence": text_confidence,
242
- "image_confidence": image_confidence,
243
- "source_count": {
244
- "text": len(text_analyses),
245
- "images": len(image_analyses)
246
- }
247
- }
248
-
249
- return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # models/summary_models.py
2
+ import logging
3
+ from typing import Dict, List, Optional, Tuple, Union, Any
4
+ import torch
5
+ from transformers import T5Tokenizer, T5ForConditionalGeneration
6
+
7
+ class SummaryModelManager:
8
+ def __init__(self, token_manager=None, cache_manager=None, metrics_calculator=None):
9
+ """Initialize the SummaryModelManager with optional utilities."""
10
+ self.logger = logging.getLogger(__name__)
11
+ self.token_manager = token_manager
12
+ self.cache_manager = cache_manager
13
+ self.metrics_calculator = metrics_calculator
14
+
15
+ # Model instance
16
+ self.model = None
17
+ self.tokenizer = None
18
+
19
+ # Model name
20
+ self.model_name = "t5-small"
21
+
22
+ # Track initialization state
23
+ self.initialized = False
24
+
25
+ # Default generation parameters
26
+ self.default_params = {
27
+ "max_length": 150,
28
+ "min_length": 40,
29
+ "length_penalty": 2.0,
30
+ "num_beams": 4,
31
+ "early_stopping": True
32
+ }
33
+
34
+ def initialize_model(self):
35
+ """Initialize the summarization model."""
36
+ if self.initialized:
37
+ return
38
+
39
+ try:
40
+ # Register with token manager if available
41
+ if self.token_manager:
42
+ self.token_manager.register_model(
43
+ self.model_name, "summarization")
44
+
45
+ # Load model and tokenizer
46
+ self.logger.info(f"Loading summary model: {self.model_name}")
47
+ self.tokenizer = T5Tokenizer.from_pretrained(self.model_name)
48
+ self.model = T5ForConditionalGeneration.from_pretrained(self.model_name)
49
+
50
+ self.initialized = True
51
+ self.logger.info("Summary model initialized successfully")
52
+
53
+ except Exception as e:
54
+ #self.logger.error(f"Failed to initialize summary model: {e}")
55
+ #raise
56
+ # Try a fallback model that doesn't require SentencePiece
57
+ try:
58
+ fallback_model = "facebook/bart-base"
59
+ self.logger.info(f"Trying fallback model: {fallback_model}")
60
+
61
+ from transformers import BartTokenizer, BartForConditionalGeneration
62
+
63
+ self.tokenizer = BartTokenizer.from_pretrained(fallback_model)
64
+ self.model = BartForConditionalGeneration.from_pretrained(fallback_model)
65
+ self.model_name = fallback_model
66
+
67
+ # Register fallback with token manager
68
+ if self.token_manager:
69
+ self.token_manager.register_model(
70
+ self.model_name, "summarization")
71
+
72
+ self.initialized = True
73
+ self.logger.info("Fallback summary model initialized successfully")
74
+ except Exception as fallback_error:
75
+ self.logger.error(f"Failed to initialize fallback model: {fallback_error}")
76
+ raise
77
+
78
+ def generate_summary(self, text: str, prefix: str = "summarize: ",
79
+ agent_name: str = "report_generation",
80
+ params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
81
+ """
82
+ Generate a summary of the given text.
83
+ Returns the summary and metadata.
84
+ """
85
+ # Initialize model if needed
86
+ if not self.initialized:
87
+ self.initialize_model()
88
+
89
+ # Prepare input text
90
+ input_text = f"{prefix}{text}"
91
+
92
+ # Check cache if available
93
+ if self.cache_manager:
94
+ cache_key = input_text[:100] + str(hash(input_text)) # Use prefix of text + hash as key
95
+ cache_hit, cached_result = self.cache_manager.get(
96
+ cache_key, namespace="summaries")
97
+
98
+ if cache_hit:
99
+ # Update metrics if available
100
+ if self.metrics_calculator:
101
+ self.metrics_calculator.update_cache_metrics(1, 0, 0.005) # Estimated energy saving
102
+ return cached_result
103
+
104
+ # Request token budget if available
105
+ if self.token_manager:
106
+ approved, reason = self.token_manager.request_tokens(
107
+ agent_name, "summarization", input_text, self.model_name)
108
+
109
+ if not approved:
110
+ self.logger.warning(f"Token budget exceeded: {reason}")
111
+ return {"summary": "Token budget exceeded", "error": reason}
112
+
113
+ # Tokenize
114
+ inputs = self.tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
115
+
116
+ # Merge default and custom parameters
117
+ generation_params = self.default_params.copy()
118
+ if params:
119
+ generation_params.update(params)
120
+
121
+ # Generate summary
122
+ with torch.no_grad():
123
+ output_ids = self.model.generate(
124
+ inputs.input_ids,
125
+ **generation_params
126
+ )
127
+
128
+ # Decode summary
129
+ summary = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
130
+
131
+ # Calculate compression ratio
132
+ input_length = len(text.split())
133
+ summary_length = len(summary.split())
134
+ compression_ratio = input_length / max(summary_length, 1)
135
+
136
+ # Prepare result
137
+ result = {
138
+ "summary": summary,
139
+ "input_length": input_length,
140
+ "summary_length": summary_length,
141
+ "compression_ratio": compression_ratio
142
+ }
143
+
144
+ # Log token usage if available
145
+ if self.token_manager:
146
+ input_tokens = len(inputs.input_ids[0])
147
+ output_tokens = len(output_ids[0])
148
+ total_tokens = input_tokens + output_tokens
149
+
150
+ self.token_manager.log_usage(
151
+ agent_name, "summarization", total_tokens, self.model_name)
152
+
153
+ # Log energy usage if metrics calculator is available
154
+ if self.metrics_calculator:
155
+ energy_usage = self.token_manager.calculate_energy_usage(
156
+ total_tokens, self.model_name)
157
+ self.metrics_calculator.log_energy_usage(
158
+ energy_usage, self.model_name, agent_name, "summarization")
159
+
160
+ # Store in cache if available
161
+ if self.cache_manager:
162
+ self.cache_manager.put(cache_key, result, namespace="summaries")
163
+
164
+ return result
165
+
166
+ def generate_executive_summary(self, detailed_content: str, confidence_level: float,
167
+ agent_name: str = "report_generation") -> Dict[str, Any]:
168
+ """
169
+ Generate an executive summary with confidence indication.
170
+ Adjusts detail level based on confidence.
171
+ """
172
+ # Prepare prompt based on confidence
173
+ if confidence_level >= 0.7:
174
+ prefix = "summarize with high confidence: "
175
+ params = {"min_length": 30, "max_length": 100}
176
+ elif confidence_level >= 0.4:
177
+ prefix = "summarize with moderate confidence: "
178
+ params = {"min_length": 20, "max_length": 80}
179
+ else:
180
+ prefix = "summarize with low confidence: "
181
+ params = {"min_length": 15, "max_length": 60}
182
+
183
+ # Generate summary
184
+ result = self.generate_summary(detailed_content, prefix=prefix,
185
+ agent_name=agent_name, params=params)
186
+
187
+ # Add confidence level to result
188
+ result["confidence_level"] = confidence_level
189
+
190
+ # Add confidence statement
191
+ confidence_statement = self._generate_confidence_statement(confidence_level)
192
+ result["confidence_statement"] = confidence_statement
193
+
194
+ return result
195
+
196
+ def _generate_confidence_statement(self, confidence_level: float) -> str:
197
+ """Generate an appropriate confidence statement based on the level."""
198
+ if confidence_level >= 0.8:
199
+ return "This analysis is provided with high confidence based on strong evidence in the provided materials."
200
+ elif confidence_level >= 0.6:
201
+ return "This analysis is provided with good confidence based on substantial evidence in the provided materials."
202
+ elif confidence_level >= 0.4:
203
+ return "This analysis is provided with moderate confidence. Some aspects may require additional verification."
204
+ elif confidence_level >= 0.2:
205
+ return "This analysis is provided with limited confidence due to sparse relevant information in the provided materials."
206
+ else:
207
+ return "This analysis is provided with very low confidence due to insufficient relevant information in the provided materials."
208
+
209
+ def combine_analyses(self, text_analyses: List[Dict[str, Any]],
210
+ image_analyses: List[Dict[str, Any]],
211
+ topic: str, agent_name: str = "report_generation") -> Dict[str, Any]:
212
+ """
213
+ Combine text and image analyses into a coherent report.
214
+ Returns the combined report with metadata.
215
+ """
216
+ # Build combined content
217
+ combined_content = f"Topic: {topic}\n\n"
218
+
219
+ # Add text analyses
220
+ combined_content += "Text Analysis:\n"
221
+ for i, analysis in enumerate(text_analyses):
222
+ if "error" in analysis:
223
+ continue
224
+ combined_content += f"- Document {i+1}: {analysis.get('summary', 'No summary available')}\n"
225
+
226
+ # Add image analyses
227
+ combined_content += "\nImage Analysis:\n"
228
+ for i, analysis in enumerate(image_analyses):
229
+ if "error" in analysis:
230
+ continue
231
+ combined_content += f"- Image {i+1}: {analysis.get('caption', 'No caption available')}\n"
232
+
233
+ # Calculate overall confidence based on analyses
234
+ text_confidence = sum(a.get("confidence", 0) for a in text_analyses) / max(len(text_analyses), 1)
235
+ image_confidence = sum(a.get("confidence", 0) for a in image_analyses) / max(len(image_analyses), 1)
236
+
237
+ # Weight confidence (text analyses typically more important for deep dives)
238
+ overall_confidence = 0.7 * text_confidence + 0.3 * image_confidence
239
+
240
+ # Generate detailed report
241
+ detailed_report = self.generate_summary(
242
+ combined_content,
243
+ prefix=f"generate detailed report about {topic}: ",
244
+ agent_name=agent_name,
245
+ params={"max_length": 300, "min_length": 100}
246
+ )
247
+
248
+ # Generate executive summary
249
+ executive_summary = self.generate_executive_summary(
250
+ detailed_report["summary"],
251
+ overall_confidence,
252
+ agent_name
253
+ )
254
+
255
+ # Combine results
256
+ result = {
257
+ "topic": topic,
258
+ "executive_summary": executive_summary["summary"],
259
+ "confidence_statement": executive_summary["confidence_statement"],
260
+ "detailed_report": detailed_report["summary"],
261
+ "confidence_level": overall_confidence,
262
+ "text_confidence": text_confidence,
263
+ "image_confidence": image_confidence,
264
+ "source_count": {
265
+ "text": len(text_analyses),
266
+ "images": len(image_analyses)
267
+ }
268
+ }
269
+
270
+ return result