import os

import certifi
import spacy
import streamlit as st
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer
# Export the path of certifi's CA bundle via SSL_CERT_FILE. This runs before
# the tokenizer is fetched below — presumably so HTTPS downloads can verify
# certificates (TODO confirm this is still required in the deployment env).
os.environ['SSL_CERT_FILE'] = certifi.where()
# spaCy pipeline; extract_summary iterates over its `.sents` to split a
# document into sentences.
nlp = spacy.load("en_core_web_lg")
# Checkpoint name shared by the tokenizer here and by the classifier's encoder.
model_name = "microsoft/MiniLM-L12-H384-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
def mean_pooling(model_output, attention_mask):
    """Average token embeddings along dim 1, ignoring masked-out positions.

    Args:
        model_output: Indexable encoder output whose element 0 holds the
            per-token embeddings. The mask is expanded to this tensor's size
            and the reduction runs over dim 1, so it is treated as
            (batch, seq_len, hidden).
        attention_mask: Tensor of shape (batch, seq_len). Positions holding 0
            contribute nothing to the average.

    Returns:
        Tensor with dim 1 reduced away, i.e. one pooled vector per batch row.
    """
    token_embeddings = model_output[0]
    # Broadcast the mask over the hidden dimension so it can zero out the
    # embeddings of masked tokens element-wise.
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    summed = torch.sum(token_embeddings * mask, 1)
    # Clamp the divisor so a fully masked row yields zeros instead of NaN.
    counts = torch.clamp(mask.sum(1), min=1e-9)
    return summed / counts
class SentenceBERTClassifier(torch.nn.Module):
    """Score a (sentence, document) pair with a shared transformer encoder.

    Both inputs go through the same encoder and are mean-pooled. The two
    pooled vectors and their element-wise product are concatenated and passed
    through a two-layer MLP that ends in a sigmoid.

    Attribute names (`model`, `dense1`, `dense2`, `classifier`) are state_dict
    keys for saved checkpoints and must not be renamed.
    """

    def __init__(self, model_name="microsoft/MiniLM-L12-H384-uncased", input_dim=384):
        """Build the encoder and the classification head.

        Args:
            model_name: Checkpoint identifier passed to
                `AutoModel.from_pretrained`.
            input_dim: Width of one pooled embedding. The head's first layer
                takes `input_dim * 3` features (sentence, document, product).
        """
        super().__init__()
        self.model = AutoModel.from_pretrained(model_name)
        self.dense1 = torch.nn.Linear(input_dim * 3, 768)
        self.relu1 = torch.nn.ReLU()
        self.dropout1 = torch.nn.Dropout(0.1)
        self.dense2 = torch.nn.Linear(768, 384)
        self.relu2 = torch.nn.ReLU()
        self.dropout2 = torch.nn.Dropout(0.1)
        self.classifier = torch.nn.Linear(384, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, sent_ids, doc_ids, sent_mask, doc_mask):
        """Return one sigmoid score per (sentence, document) row.

        Args:
            sent_ids: Token ids of the sentences.
            doc_ids: Token ids of the documents, row-aligned with `sent_ids`.
            sent_mask: Attention mask matching `sent_ids`.
            doc_mask: Attention mask matching `doc_ids`.

        Returns:
            Output of the final sigmoid; the last layer has a single output
            feature, so there is one value per input row.
        """
        sent_output = self.model(input_ids=sent_ids, attention_mask=sent_mask)
        sent_embedding = mean_pooling(sent_output, sent_mask)
        doc_output = self.model(input_ids=doc_ids, attention_mask=doc_mask)
        doc_embedding = mean_pooling(doc_output, doc_mask)

        # Element-wise interaction term between the two pooled vectors.
        combined_embedding = sent_embedding * doc_embedding
        # Three pooled vectors side by side -> input_dim * 3 features,
        # matching dense1's input width.
        concat_embedding = torch.cat(
            (sent_embedding, doc_embedding, combined_embedding), dim=1
        )

        hidden = self.dropout1(self.relu1(self.dense1(concat_embedding)))
        hidden = self.dropout2(self.relu2(self.dense2(hidden)))
        logits = self.classifier(hidden)
        return self.sigmoid(logits)
# Prefer the GPU when one is visible; used below as the map_location for the
# checkpoint tensors.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
extractive_model = SentenceBERTClassifier(model_name=model_name)
# Build the path with os.path.join: the previous literal relied on "\m" being
# an unrecognised escape sequence (a SyntaxWarning on recent Python versions)
# and hard-coded the Windows separator, so it pointed at the wrong file on
# other platforms.
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted
# source.
extractive_model.load_state_dict(
    torch.load(
        os.path.join("model_path", "minilm_bal_exsum.pth"),
        map_location=torch.device(device),
    )
)
# This module only runs inference: switch to eval mode so the classifier's
# dropout layers are disabled and scores are deterministic.
extractive_model.eval()
def get_tokens(text, tokenizer, max_length=512):
    """Tokenize a batch of strings into padded id and attention-mask tensors.

    Args:
        text: Batch of strings handed to `tokenizer.batch_encode_plus`
            (callers in this file pass a list of strings).
        tokenizer: Object exposing `batch_encode_plus`, e.g. a Hugging Face
            tokenizer.
        max_length: Length every sequence is padded or truncated to.
            Defaults to 512, the value that was previously hard-coded.

    Returns:
        Tuple `(input_ids, attention_mask)` taken from the tokenizer output,
        which is requested as PyTorch tensors (`return_tensors="pt"`).
    """
    encoded = tokenizer.batch_encode_plus(
        text,
        add_special_tokens=True,
        max_length=max_length,
        padding="max_length",
        return_token_type_ids=True,
        truncation=True,
        return_tensors="pt",
    )
    return encoded["input_ids"], encoded["attention_mask"]
def predict(model, sents, doc):
    """Predict the relevance scores of sentences in a document.

    Args:
        model: Module called as
            `model(sent_ids, doc_ids, sent_mask, doc_mask)`. Its train/eval
            mode is left untouched; callers should put it in eval mode for
            deterministic scores.
        sents: List of sentence strings to score.
        doc: Full document text that each sentence is scored against.

    Returns:
        Whatever `model` returns for the batch, computed without gradient
        tracking.
    """
    sent_id, sent_mask = get_tokens(sents, tokenizer)
    doc_id, doc_mask = get_tokens([doc], tokenizer)

    # Tile the single encoded document so each sentence row has a matching
    # document row.
    n_sents = len(sents)
    doc_id, doc_mask = doc_id.repeat(n_sents, 1), doc_mask.repeat(n_sents, 1)

    # `.to(dtype=...)` instead of `torch.tensor(tensor)`: the latter
    # copy-constructs from an existing tensor and emits a UserWarning.
    sent_id = sent_id.to(dtype=torch.long)
    sent_mask = sent_mask.to(dtype=torch.long)
    doc_id = doc_id.to(dtype=torch.long)
    doc_mask = doc_mask.to(dtype=torch.long)

    # Handle OOV tokens: replace any id outside the tokenizer's vocabulary
    # with the 'unk' token id before passing to the model.
    sent_id[sent_id >= tokenizer.vocab_size] = tokenizer.unk_token_id
    doc_id[doc_id >= tokenizer.vocab_size] = tokenizer.unk_token_id

    # Inputs must live on the same device as the model's weights.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        target = first_param.device
        sent_id, sent_mask = sent_id.to(target), sent_mask.to(target)
        doc_id, doc_mask = doc_id.to(target), doc_mask.to(target)

    # Pure inference: skip building the autograd graph.
    with torch.no_grad():
        preds = model(sent_id, doc_id, sent_mask, doc_mask)
    return preds
def extract_summary(doc, model=extractive_model, min_sentence_length=14, top_k=4, batch_size=4):
    """Build an extractive summary from the highest-scoring sentences of `doc`.

    Args:
        doc: Document text to summarise.
        model: Scoring model forwarded to `predict`.
        min_sentence_length: A sentence is kept only when `len()` of its
            spaCy span is greater than this value.
        top_k: Maximum number of sentences in the summary.
        batch_size: Number of sentences scored per `predict` call.

    Returns:
        The selected sentences joined by single spaces, in their original
        document order; an empty string when no sentence qualifies.

    Raises:
        ValueError: If `batch_size` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    # NOTE(review): newlines are removed rather than replaced by a space, so
    # words separated only by a line break end up fused. Kept as-is to avoid
    # changing model inputs — confirm whether " " was intended.
    doc = doc.replace("\n", "")

    doc_sentences = [
        sent.text for sent in nlp(doc).sents if len(sent) > min_sentence_length
    ]

    # Score the candidate sentences in batches of `batch_size`.
    scores = []
    for batch_start in tqdm(range(0, len(doc_sentences), batch_size)):
        batch = doc_sentences[batch_start:batch_start + batch_size]
        preds = predict(model, batch, doc)
        scores.extend(preds.tolist())

    # Each prediction row holds a single score, hence `score[0]`.
    sent_pred_list = [
        {"sentence": sentence, "score": score[0], "index": index}
        for index, (sentence, score) in enumerate(zip(doc_sentences, scores))
    ]

    # Take the top_k by score, then restore document order for readability.
    best = sorted(sent_pred_list, key=lambda k: k['score'], reverse=True)[:top_k]
    best.sort(key=lambda k: k['index'])
    return " ".join(item["sentence"] for item in best)