import logging
import re
from typing import List, Union

import torch
from transformers import pipeline

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Pre-compiled regex patterns for faster text cleaning
HTML_TAG_PATTERN = re.compile(r'<[^>]+>')
SPECIAL_CHARS_PATTERN = re.compile(r'[^\w\s.,;?!-]')
MULTISPACE_PATTERN = re.compile(r'\s+')

# Model and summarizer pipeline configuration
MODEL_NAME = "sshleifer/distilbart-cnn-6-6"
MAX_LENGTH = 1024         # Character cap applied to cleaned text before summarization (rough guard; the model itself accepts at most 1024 tokens)
MAX_SUMMARY_LENGTH = 150  # Maximum summary length (tokens)
MIN_SUMMARY_LENGTH = 30   # Minimum summary length (tokens)

# Device selection for the pipeline: GPU 0 if available, otherwise CPU
_device = 0 if torch.cuda.is_available() else -1

# Load the summarizer pipeline once at module level so repeated calls reuse it
logger.info(f"Loading summarizer pipeline: {MODEL_NAME}")
_summarizer = pipeline(
    "summarization",
    model=MODEL_NAME,
    device=_device
)
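
# Note: the summarization pipeline returns one dict per input, e.g.
# [{'summary_text': '...'}, ...]; get_summary_points below extracts the
# 'summary_text' field from each entry.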


def clean_text(text: str) -> str:
    """Clean text by removing HTML tags and special characters."""
    text = HTML_TAG_PATTERN.sub(' ', text)
    text = SPECIAL_CHARS_PATTERN.sub(' ', text)
    return MULTISPACE_PATTERN.sub(' ', text).strip()
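
# Illustrative example of clean_text (the input string is made up for demonstration):
#   clean_text("<p>Hello, <b>world</b>! Save 50% today</p>")
#   -> "Hello, world ! Save 50 today"
# Tags are replaced with spaces (which is why a space can appear before "!"),
# characters outside [\w\s.,;?!-] such as "%" are dropped, and whitespace is collapsed.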


def get_summary_points(
    texts: Union[str, List[str]],
    max_points: int = 3,
    batch_size: int = 4,
) -> Union[List[str], List[List[str]]]:
    """Summarize one text or a batch of texts into short bullet points.

    Returns a list of points for a single input string, or one list of
    points per input (empty for inputs that could not be summarized)
    when given a list.
    """
    is_batch = isinstance(texts, list)  # Ensure is_batch is always defined
    try:
        # Handle both a single text and a batch of texts uniformly
        texts = [texts] if not is_batch else texts
        # Clean the texts and truncate them to the character cap
        cleaned_texts = [clean_text(t)[:MAX_LENGTH] for t in texts]
        # Filter out texts that are too short to summarize meaningfully
        valid_texts = []
        valid_indices = []
        for idx, t in enumerate(cleaned_texts):
            if len(t.split()) >= 30:
                valid_texts.append(t)
                valid_indices.append(idx)
            else:
                logger.warning(
                    f"Text at index {idx} is too short for summarization "
                    f"(length: {len(t.split())} words)"
                )
        if not valid_texts:
            logger.warning("No valid texts found for summarization")
            return [] if not is_batch else [[] for _ in texts]
        # Generate summaries in batches
        try:
            summaries = _summarizer(
                valid_texts,
                batch_size=batch_size,      # Process the valid texts in mini-batches of this size
                max_length=MAX_SUMMARY_LENGTH,
                min_length=MIN_SUMMARY_LENGTH,
                length_penalty=2.0,         # Favor longer summaries
                num_beams=4,                # Beam search for better quality
                no_repeat_ngram_size=3,
                early_stopping=True,
                do_sample=False             # Disable sampling for deterministic results
            )
            summaries = [s['summary_text'] for s in summaries]
        except Exception as e:
            logger.error(f"Error during summarization: {e}")
            return [] if not is_batch else [[] for _ in texts]
        # Process each summary into bullet points
        all_points = []
        for summary in summaries:
            points = []
            # Split on sentence boundaries (period, question mark, exclamation mark)
            sentences = re.split(r'[.!?]+', summary)
            for sentence in sentences:
                sentence = sentence.strip()
                if sentence and len(sentence.split()) >= 5:  # Keep only sentences with at least 5 words
                    # Uppercase the first character without lowercasing the rest
                    # (str.capitalize() would mangle acronyms and proper nouns)
                    sentence = sentence[0].upper() + sentence[1:] + '.'
                    points.append(sentence)
            all_points.append(points[:max_points])

        # Map results back to the original input positions for batch input
        if is_batch:
            result = [[] for _ in texts]
            for idx, points in zip(valid_indices, all_points):
                result[idx] = points
            return result
        return all_points[0] if all_points else []
    except Exception as e:
        logger.error(f"Error generating summary: {e}")
        return [] if not is_batch else [[] for _ in texts]
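
# Example calls (illustrative):
#   get_summary_points("long article text ...")          -> ["Point one.", "Point two.", ...]
#   get_summary_points(["article A ...", "article B"])   -> one list of points per input,
#                                                            empty for inputs that were too short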


if __name__ == "__main__":
    # Example usage
    texts = [
        """
        """  # Placeholder: add the article text to summarize here
    ]
    try:
        print("\nGenerating summaries...")
        results = get_summary_points(texts)
        for idx, points in enumerate(results, 1):
            print(f"\n=== Summary {idx} ===")
            if points:
                for point_idx, point in enumerate(points, 1):
                    print(f"{point_idx}. {point}")
            else:
                print("No summary could be generated.")
    except Exception as e:
        print(f"An error occurred: {e}")