from collections import defaultdict
from itertools import chain
from typing import Optional, Tuple

import torch
import torch.nn.functional as F

from marker.settings import settings
from marker.postprocessors.t5 import T5ForTokenClassification, byt5_tokenize


def get_batch_size():
    # Use the configured batch size if set, otherwise pick a default based on device
    if settings.EDITOR_BATCH_SIZE is not None:
        return settings.EDITOR_BATCH_SIZE
    elif settings.TORCH_DEVICE_MODEL == "cuda":
        return 12
    return 6


def load_editing_model():
    if not settings.ENABLE_EDITOR_MODEL:
        return None

    model = T5ForTokenClassification.from_pretrained(
        settings.EDITOR_MODEL_NAME,
        torch_dtype=settings.MODEL_DTYPE,
    ).to(settings.TORCH_DEVICE_MODEL)
    model.eval()

    # Map each edit operation to a label id so predictions can be decoded later
    model.config.label2id = {
        "equal": 0,
        "delete": 1,
        "newline-1": 2,
        "space-1": 3,
    }
    model.config.id2label = {v: k for k, v in model.config.label2id.items()}
    return model


def edit_full_text(text: str, model: Optional[T5ForTokenClassification], batch_multiplier=1) -> Tuple[str, dict]:
    if not model:
        return text, {}

    batch_size = get_batch_size() * batch_multiplier

    tokenized = byt5_tokenize(text, settings.EDITOR_MAX_LENGTH)
    input_ids = tokenized["input_ids"]
    char_token_lengths = tokenized["char_token_lengths"]
    # Run the model over the tokenized chunks in batches
    token_masks = []
    for i in range(0, len(input_ids), batch_size):
        batch_input_ids = tokenized["input_ids"][i: i + batch_size]
        batch_input_ids = torch.tensor(batch_input_ids, device=model.device)

        batch_attention_mask = tokenized["attention_mask"][i: i + batch_size]
        batch_attention_mask = torch.tensor(batch_attention_mask, device=model.device)

        with torch.inference_mode():
            predictions = model(batch_input_ids, attention_mask=batch_attention_mask)

        logits = predictions.logits.cpu()

        # If the max probability is below a threshold, assume the prediction is unreliable
        # and fall back to "equal"; this keeps editing conservative
        probs = F.softmax(logits, dim=-1)
        max_prob = torch.max(probs, dim=-1)
        cutoff_prob = max_prob.values < settings.EDITOR_CUTOFF_THRESH
        labels = logits.argmax(-1)
        labels[cutoff_prob] = model.config.label2id["equal"]
        labels = labels.squeeze().tolist()
        # A single-sequence batch squeezes down to a flat list; re-wrap it
        if len(labels) == settings.EDITOR_MAX_LENGTH:
            labels = [labels]

        labels = list(chain.from_iterable(labels))
        token_masks.extend(labels)
    # Flatten input ids so they align one-to-one with the predicted token masks
    flat_input_ids = list(chain.from_iterable(input_ids))

    # Strip special tokens 0 and 1; keep the unknown token, although it should never appear
    assert len(token_masks) == len(flat_input_ids)
    token_masks = [mask for mask, token in zip(token_masks, flat_input_ids) if token >= 2]

    # One mask should remain per byte of the utf-8 encoded text
    assert len(token_masks) == len(list(text.encode("utf-8")))
    edit_stats = defaultdict(int)
    out_text = []
    start = 0
    for i, char in enumerate(text):
        char_token_length = char_token_lengths[i]
        masks = token_masks[start: start + char_token_length]
        labels = [model.config.id2label[mask] for mask in masks]
        if all(l == "delete" for l in labels):
            # Only delete whitespace; keep the character if it is anything else
            if char.strip():
                out_text.append(char)
            else:
                edit_stats["delete"] += 1
        elif labels[0] == "newline-1":
            out_text.append("\n")
            out_text.append(char)
            edit_stats["newline-1"] += 1
        elif labels[0] == "space-1":
            out_text.append(" ")
            out_text.append(char)
            edit_stats["space-1"] += 1
        else:
            out_text.append(char)
            edit_stats["equal"] += 1

        start += char_token_length

    out_text = "".join(out_text)
    return out_text, edit_stats
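

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: assumes
    # settings.ENABLE_EDITOR_MODEL is enabled so load_editing_model()
    # returns a model rather than None, and that the sample text below
    # is a stand-in for real OCR/extraction output.
    sample = "This is sometext that may containspacing errors."
    editor_model = load_editing_model()
    edited, stats = edit_full_text(sample, editor_model)
    print(edited)
    print(dict(stats))  # counts of each edit operation that was applied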