NotXia's picture
Add model
04f3f18 unverified
import torch
def _selectStrategyLength(sentences, predictions, max_length):
selected_sents = []
sents_priority = torch.argsort(predictions, descending=True)
summary_len = 0
i = 0
while (summary_len < max_length) and (i < len(sents_priority)):
if summary_len + len(sentences[sents_priority[i]]) < max_length:
selected_sents.append(sents_priority[i].item())
summary_len += len(sentences[sents_priority[i]])
i += 1
return sorted(selected_sents)
def _selectStrategyCount(sentences, predictions, num_sents):
selected_idxs = sorted(torch.topk(predictions, min(len(predictions), num_sents)).indices)
return [tensor.item() for tensor in selected_idxs]
def _selectStrategyRatio(sentences, predictions, ratio):
    """
    Selects top-scoring sentences up to a fraction of the document length.

    The length budget is `ratio` times the summed `len()` of all sentences.
    Selection is then delegated to the length strategy.

    Parameters
    ----------
    sentences : sequence
        Sentences of the document.
    predictions : torch.Tensor
        One score per sentence.
    ratio : float
        Fraction of the total document length allowed in the summary.

    Returns
    -------
    selected_sents : int[]
        Indices of the selected sentences, in ascending order.
    """
    total_length = sum(map(len, sentences))
    length_budget = total_length * ratio
    return _selectStrategyLength(sentences, predictions, length_budget)
def _selectStrategyThreshold(sentences, predictions, threshold):
return [i for i, score in enumerate(predictions) if score >= threshold]
def select(sentences, predictions, strategy, strategy_args):
    """
    Selects summary sentences according to a named strategy.

    Parameters
    ----------
    sentences : sequence
        Sentences of the document.
    predictions : torch.Tensor
        One score per sentence.
    strategy : str
        One of "length", "count", "ratio" or "threshold".
    strategy_args
        Argument forwarded to the strategy: the length budget, the number of
        sentences, the length ratio, or the score threshold, respectively.

    Returns
    -------
    selected_sentences : list
        The selected sentences.
    selected_idxs : int[]
        Their indices in `sentences`.

    Raises
    ------
    NotImplementedError
        If `strategy` is not one of the supported names.
    """
    strategy_table = (
        ("length", _selectStrategyLength),
        ("count", _selectStrategyCount),
        ("ratio", _selectStrategyRatio),
        ("threshold", _selectStrategyThreshold),
    )
    for strategy_name, strategy_fn in strategy_table:
        if strategy == strategy_name:
            break
    else:
        raise NotImplementedError(f"Unknown strategy {strategy}")

    selected_idxs = strategy_fn(sentences, predictions, strategy_args)
    selected_sentences = [sentences[idx] for idx in selected_idxs]
    return selected_sentences, selected_idxs
"""
Splits a document in chunks of maximum a given size.
Parameters
----------
doc_tokens : str[]
List of the tokens of the document.
bos_token : str
Begin of sentence token.
eos_token : str
End of sentence token.
max_size : int
Maximum size of a chunk.
Returns
-------
chunks : str[][]
Splitted document.
"""
def splitDocument(doc_tokens, bos_token, eos_token, max_size):
def _findNextBOSFrom(start_idx):
for i in range(start_idx, len(doc_tokens)):
if doc_tokens[i] == bos_token:
return i
return -1
def _findPreviousEOSFrom(start_idx):
for i in range(start_idx, -1, -1):
if doc_tokens[i] == eos_token:
return i
return -1
chunks = []
while len(doc_tokens) > max_size:
# Splits at the eos token
eos_idx = _findPreviousEOSFrom(max_size - 1)
if eos_idx == -1:
# The sentence is too long.
# Find the next bos in front of the current sentence (if exists) and truncate the current sentence.
next_bos_idx = _findNextBOSFrom(max_size)
if next_bos_idx != -1:
doc_tokens = doc_tokens[:max_size-1] + [eos_token] + doc_tokens[next_bos_idx:]
else:
doc_tokens = doc_tokens[:max_size-1] + [eos_token]
eos_idx = max_size - 1
chunks.append(doc_tokens[:eos_idx+1])
doc_tokens = doc_tokens[eos_idx+1:]
if len(doc_tokens) > 0: chunks.append(doc_tokens) # Remaining part of the document
return chunks