Spaces:
Sleeping
Sleeping
import torch.nn.functional as F | |
import pandas as pd | |
from dataloader import create_dataloader | |
from utils import * | |
def predict_segmentation(inp, model, device, batch_size=8):
    """Run ``model`` over ``inp`` in batches and return sigmoid probabilities.

    Args:
        inp: Samples handed unchanged to ``create_dataloader``.
        model: Callable invoked as ``model(**batch)``; its result must expose
            a ``.logits`` tensor.
        device: Target passed to ``.to(device)`` for every tensor in a batch.
        batch_size: Batch size forwarded to ``create_dataloader``.

    Returns:
        A NumPy array holding the sigmoid of each batch's logits, concatenated
        along axis 0 in loader order.

    Raises:
        ValueError: Raised by ``np.concatenate`` when the loader yields no
            batches at all.
    """
    # Local import: the module only imports ``torch.nn.functional``.
    import torch

    test_loader = create_dataloader(inp, batch_size)
    predictions = []
    # Inference only: disabling autograd avoids building a graph per batch,
    # which the original code paid for and then discarded via ``detach``.
    with torch.no_grad():
        for batch in test_loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            logits = model(**batch).logits
            # ``detach`` is kept as a guard in case the model re-enables
            # gradient tracking inside its forward pass.
            p = torch.sigmoid(logits).detach().cpu().numpy()
            predictions.append(p)
    return np.concatenate(predictions, axis=0)
def create_data(text, tokenizer, seq_len=512):
    """Tokenize ``text`` and cut it into padded, fixed-length samples.

    The tokenizer is called with ``add_special_tokens=False``. Every field it
    returns is split by ``batch_list(values, seq_len)`` and each piece is
    passed through ``pad_seq(piece, seq_len)``.

    Args:
        text: Text to tokenize.
        tokenizer: Callable returning a mapping of field name to sequence;
            the mapping must contain ``'input_ids'``.
        seq_len: Length handed to ``batch_list`` and ``pad_seq``.

    Returns:
        A list with one dict per chunk; each dict maps every tokenizer field
        to that chunk's padded piece.
    """
    encoded = tokenizer(text, add_special_tokens=False)

    # Field name -> list of padded pieces (one entry per chunk).
    chunked = {}
    for field, values in encoded.items():
        chunked[field] = [pad_seq(piece, seq_len)
                          for piece in batch_list(values, seq_len)]

    # Transpose to one dict per chunk; 'input_ids' sets the chunk count.
    samples = []
    for idx in range(len(chunked['input_ids'])):
        samples.append({field: pieces[idx] for field, pieces in chunked.items()})
    return samples
def segment_tokens(notes, model, tokenizer, device, batch_size=8):
    """Compute a flat probability array for every note.

    Each note's text is lower-cased, turned into model inputs with
    ``create_data`` and scored with ``predict_segmentation``.

    Args:
        notes: Object whose ``itertuples()`` yields rows exposing ``note_id``
            and ``text`` attributes.
        model: Model forwarded to ``predict_segmentation``.
        tokenizer: Tokenizer forwarded to ``create_data``.
        device: Device forwarded to ``predict_segmentation``.
        batch_size: Batch size forwarded to ``predict_segmentation``.

    Returns:
        Dict mapping ``note_id`` to that note's probability array. A repeated
        ``note_id`` keeps the value of its last occurrence.
    """
    probs_by_note = {}
    for row in notes.itertuples():
        key = row.note_id
        chunks = create_data(row.text.lower(), tokenizer)
        chunk_probs = predict_segmentation(chunks, model, device,
                                           batch_size=batch_size)
        # Drop the trailing length-1 axis, then join the per-chunk rows end
        # to end along the first axis.
        probs_by_note[key] = np.concatenate(np.squeeze(chunk_probs, -1))
    return probs_by_note
def _drop_overlapping_spans(note_predictions):
    """Filter one note's spans from highest to lowest score.

    Spans are visited in descending ``score`` order; a span is kept only when
    ``is_overlap`` reports no clash with the spans kept so far.
    """
    ranked = note_predictions.sort_values('score', ascending=False)
    seen_spans = set()
    keep = []
    for span in ranked[['start', 'end']].values:
        span = tuple(span)
        is_new = not is_overlap(seen_spans, span)
        if is_new:
            seen_spans.add(span)
        keep.append(is_new)
    # Index with an explicit boolean array: a plain empty list would be read
    # by pandas as "select no columns" and silently drop the schema.
    return ranked[np.array(keep, dtype=bool)]


def segment(notes, model, tokenizer, device, thresh, batch_size=8):
    """Extract scored, non-overlapping spans from every note.

    For each note the probabilities from ``segment_tokens`` are aligned to
    the raw text with ``align_decoded`` and thresholded. The resulting runs
    (from ``get_sequential_spans``) are scored by their mean probability, and
    overlapping runs are resolved in favour of the higher score.

    Args:
        notes: DataFrame-like object; ``itertuples()`` must yield rows with
            ``note_id`` and ``text`` attributes.
        model: Model forwarded to ``segment_tokens``.
        tokenizer: Tokenizer used for scoring and for the encode/decode
            round trip handed to ``align_decoded``.
        device: Device forwarded to ``segment_tokens``.
        thresh: Positions with probability strictly greater than this value
            are marked positive.
        batch_size: Batch size forwarded to ``segment_tokens``.

    Returns:
        DataFrame with columns ``note_id``, ``start``, ``end``, ``mention``
        and ``score`` and a fresh 0..n-1 index. ``mention`` is
        ``text[start:end]``. When ``notes`` is empty or no note produces a
        span, an empty DataFrame with those same columns is returned.
    """
    columns = ['note_id', 'start', 'end', 'mention', 'score']

    # DataFrame truthiness is ambiguous, so test the row count explicitly.
    if len(notes) == 0:
        return pd.DataFrame(columns=columns)

    predictions_prob_map = segment_tokens(notes, model, tokenizer, device, batch_size)
    predictions = []
    for note in notes.itertuples():
        note_id = note.note_id
        raw_text = note.text
        decoded_text = tokenizer.decode(
            tokenizer.encode(raw_text, add_special_tokens=False))
        _, pred_probs = align_decoded(
            raw_text, decoded_text, predictions_prob_map[note_id])
        pred_probs = np.array(pred_probs, 'float32')
        pred = (pred_probs > thresh).astype('uint8')

        note_predictions = {name: [] for name in columns}
        for (start, end) in get_sequential_spans(pred):
            note_predictions['note_id'].append(note_id)
            note_predictions['score'].append(pred_probs[start:end].mean())
            note_predictions['start'].append(start)
            note_predictions['end'].append(end)
            note_predictions['mention'].append(raw_text[start:end])

        # A note without spans contributes no rows; skipping it keeps empty,
        # object-typed frames out of the final concat.
        if not note_predictions['note_id']:
            continue
        predictions.append(_drop_overlapping_spans(pd.DataFrame(note_predictions)))

    # ``pd.concat`` raises on an empty list, so return the bare schema instead.
    if not predictions:
        return pd.DataFrame(columns=columns)
    return pd.concat(predictions).reset_index(drop=True)