"""Post-process text with a byt5 token-classification editor model."""

from collections import defaultdict
from itertools import chain
from typing import Optional, Tuple

from marker.settings import settings
import torch
import torch.nn.functional as F
from marker.postprocessors.t5 import T5ForTokenClassification, byt5_tokenize


def get_batch_size():
    if settings.EDITOR_BATCH_SIZE is not None:
        return settings.EDITOR_BATCH_SIZE
    elif settings.TORCH_DEVICE_MODEL == "cuda":
        return 12
    return 6


def load_editing_model():
    if not settings.ENABLE_EDITOR_MODEL:
        return None

    model = T5ForTokenClassification.from_pretrained(
        settings.EDITOR_MODEL_NAME,
        torch_dtype=settings.MODEL_DTYPE,
    ).to(settings.TORCH_DEVICE_MODEL)
    model.eval()
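
    # Edit operations the model can predict for each byte token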
    model.config.label2id = {
        "equal": 0,
        "delete": 1,
        "newline-1": 2,
        "space-1": 3,
    }
    model.config.id2label = {v: k for k, v in model.config.label2id.items()}
    return model


def edit_full_text(text: str, model: Optional[T5ForTokenClassification], batch_multiplier=1) -> Tuple[str, dict]:
    if not model:
        return text, {}

    batch_size = get_batch_size() * batch_multiplier

    tokenized = byt5_tokenize(text, settings.EDITOR_MAX_LENGTH)
    input_ids = tokenized["input_ids"]
    char_token_lengths = tokenized["char_token_lengths"]
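
    # Run the model over the tokenized chunks in batches, collecting one label per token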
    token_masks = []
    for i in range(0, len(input_ids), batch_size):
        batch_input_ids = tokenized["input_ids"][i: i + batch_size]
        batch_input_ids = torch.tensor(batch_input_ids, device=model.device)

        batch_attention_mask = tokenized["attention_mask"][i: i + batch_size]
        batch_attention_mask = torch.tensor(batch_attention_mask, device=model.device)

        with torch.inference_mode():
            predictions = model(batch_input_ids, attention_mask=batch_attention_mask)

        logits = predictions.logits.cpu()
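
        # If the max probability is below the cutoff threshold, treat the
        # prediction as unreliable and fall back to "equal" (no edit)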
        probs = F.softmax(logits, dim=-1)
        max_prob = torch.max(probs, dim=-1)
        cutoff_prob = max_prob.values < settings.EDITOR_CUTOFF_THRESH
        labels = logits.argmax(-1)
        labels[cutoff_prob] = model.config.label2id["equal"]

        labels = labels.squeeze().tolist()
        # squeeze() drops the batch dim for a single-sequence batch; re-wrap so
        # chain.from_iterable always flattens a list of sequences
        if len(labels) == settings.EDITOR_MAX_LENGTH:
            labels = [labels]
        labels = list(chain.from_iterable(labels))
        token_masks.extend(labels)

    flat_input_ids = list(chain.from_iterable(input_ids))

    # Strip special tokens (pad=0, eos=1); keep unk and byte tokens (id >= 2)
    assert len(token_masks) == len(flat_input_ids)
    token_masks = [mask for mask, token in zip(token_masks, flat_input_ids) if token >= 2]

    # After stripping, one predicted label should remain per utf-8 byte of the text
    assert len(token_masks) == len(list(text.encode("utf-8")))
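
    # Walk the text character by character, applying the predicted edit for the
    # bytes that make up each character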
    edit_stats = defaultdict(int)
    out_text = []
    start = 0
    for i, char in enumerate(text):
        char_token_length = char_token_lengths[i]
        masks = token_masks[start: start + char_token_length]
        labels = [model.config.id2label[mask] for mask in masks]
        if all(l == "delete" for l in labels):
            # Only actually delete whitespace; keep the character otherwise
            if char.strip():
                out_text.append(char)
            else:
                edit_stats["delete"] += 1
        elif labels[0] == "newline-1":
            out_text.append("\n")
            out_text.append(char)
            edit_stats["newline-1"] += 1
        elif labels[0] == "space-1":
            out_text.append(" ")
            out_text.append(char)
            edit_stats["space-1"] += 1
        else:
            out_text.append(char)
            edit_stats["equal"] += 1

        start += char_token_length

    out_text = "".join(out_text)
    return out_text, edit_stats
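

# Usage sketch (illustrative, not part of the module): the sample string below is
# made up, and the editor only runs when settings.ENABLE_EDITOR_MODEL is set.
if __name__ == "__main__":
    editor = load_editing_model()  # None when the editor model is disabled
    sample = "This is a  sample paragraph with a doubled  space."
    edited, stats = edit_full_text(sample, editor)
    print(edited)
    print(dict(stats))  # counts per edit label: equal / delete / newline-1 / space-1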