File size: 2,301 Bytes
b1426fb f6e66c5 b1426fb f6e66c5 b1426fb 08d05f4 f6e66c5 08d05f4 b1426fb 08d05f4 f6e66c5 08d05f4 3befa67 f6e66c5 08d05f4 f6e66c5 08d05f4 b1426fb f6e66c5 08d05f4 f6e66c5 08d05f4 f6e66c5 08d05f4 f6e66c5 08d05f4 f6e66c5 08d05f4 f6e66c5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
"""
Summarization Model Handler
Manages the fine-tuned BART model for text summarization.
"""
import pickle

import streamlit as st
import torch
from transformers import BartTokenizer, BartForConditionalGeneration
class Summarizer:
    """Wraps a fine-tuned BART model for abstractive text summarization.

    Usage: construct, call load_model() once, then process() per input text.
    Errors are reported to the Streamlit UI via st.error and signalled to
    callers by a None return value.
    """

    def __init__(self):
        """Initialize with no model loaded; call load_model() before process()."""
        self.model = None      # BartForConditionalGeneration, set by load_model()
        self.tokenizer = None  # BartTokenizer, set by load_model()

    def load_model(self):
        """Load the fine-tuned BART summarization model and its tokenizer.

        Returns:
            The loaded model on success, or None if loading failed
            (the error is shown in the Streamlit UI).
        """
        try:
            # SECURITY NOTE(review): pickle.load can execute arbitrary code
            # embedded in the file — only ship 'bart_ami_finetuned.pkl' from
            # a trusted source; consider a safetensors/state_dict checkpoint.
            with open('bart_ami_finetuned.pkl', 'rb') as f:
                self.model = pickle.load(f)
            # Load the tokenizer matching the base checkpoint the model was
            # fine-tuned from.
            self.tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
            # Move model to appropriate device (GPU if available)
            self.model.to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
            return self.model
        except Exception as e:
            st.error(f"Error loading fine-tuned summarization model: {str(e)}")
            return None

    def process(self, text: str, max_length: int = 130, min_length: int = 30):
        """Process text for summarization.

        Args:
            text (str): Text to summarize
            max_length (int): Maximum length of summary
            min_length (int): Minimum length of summary

        Returns:
            str: Summarized text, or None if the model is not loaded or
            generation failed.
        """
        # Guard: give a clear message instead of a NoneType error when
        # load_model() was never called (or failed).
        if self.model is None or self.tokenizer is None:
            st.error("Summarization model not loaded. Call load_model() first.")
            return None
        try:
            # Tokenize input text; truncate to BART's 1024-token context.
            inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=1024, padding="max_length")
            # Move inputs to the same device as the model
            inputs = {key: value.to(self.model.device) for key, value in inputs.items()}
            # Generate summary
            summary_ids = self.model.generate(
                inputs["input_ids"],
                max_length=max_length,
                min_length=min_length,
                num_beams=4,  # Beam search for better quality
                early_stopping=True
            )
            # Decode summary tokens to text
            summary = self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
            return summary
        except Exception as e:
            st.error(f"Error in summarization: {str(e)}")
            return None
|