# Custom handler for a Hugging Face Inference Endpoint running the Nougat
# OCR model (facebook/nougat-base).
from transformers import NougatProcessor, VisionEncoderDecoderModel, StoppingCriteria, StoppingCriteriaList
import torch.cuda
import io
import base64
from PIL import Image
from typing import Dict, Any
from collections import defaultdict
class RunningVarTorch:
    """Sliding-window running variance over 1-D tensors.

    Holds up to ``L`` of the most recently pushed vectors as columns of a
    single buffer and reports the per-element (unbiased) variance across
    that window.
    """

    def __init__(self, L=15, norm=False):
        # Column buffer of pushed vectors; shape (dim, n) after n <= L pushes.
        self.values = None
        # Maximum number of columns retained in the window.
        self.L = L
        # If True, variance() additionally divides by the window length.
        self.norm = norm

    def push(self, x: torch.Tensor):
        """Append a 1-D tensor, evicting the oldest column once full."""
        assert x.dim() == 1
        col = x[:, None]
        if self.values is None:
            self.values = col
            return
        # Drop the oldest column only when the window is already at capacity.
        kept = self.values if self.values.shape[1] < self.L else self.values[:, 1:]
        self.values = torch.cat((kept, col), 1)

    def variance(self):
        """Per-element variance across the window, or None before any push."""
        if self.values is None:
            return None
        var = torch.var(self.values, 1)
        return var / self.values.shape[1] if self.norm else var
class StoppingCriteriaScores(StoppingCriteria):
    """Heuristic stopping criterion from the Nougat authors.

    Tracks the running variance of the per-step maximum logit scores; when
    the variance-of-variance collapses below ``threshold`` (a sign the
    decoder has started repeating/degenerating), a stop is scheduled for
    each affected batch element. Generation ends once every tracked element
    has been flagged as stopped.
    """

    def __init__(self, threshold: float = 0.015, window_size: int = 200):
        super().__init__()
        # Variance-of-variance values below this trigger a scheduled stop.
        self.threshold = threshold
        # Running variance of the max scores, normalized by window length.
        self.vars = RunningVarTorch(norm=True)
        # Second-order statistic: variance of the running variances above,
        # computed over a larger window of `window_size` steps.
        self.varvars = RunningVarTorch(L=window_size)
        # Per-batch-element scheduled stop step (0 = nothing scheduled).
        self.stop_inds = defaultdict(int)
        # Per-batch-element flag: True once this element is considered done.
        self.stopped = defaultdict(bool)
        # Number of generation steps observed so far.
        self.size = 0
        self.window_size = window_size

    @torch.no_grad()
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        """Return True when all batch elements should stop generating."""
        last_scores = scores[-1]
        # Push the per-element max logit of the latest step, then push its
        # running variance into the second-order tracker.
        self.vars.push(last_scores.max(1)[0].float().cpu())
        self.varvars.push(self.vars.variance())
        self.size += 1
        # Warm-up: the statistic is meaningless before one full window.
        if self.size < self.window_size:
            return False
        varvar = self.varvars.variance()
        for b in range(len(last_scores)):
            if varvar[b] < self.threshold:
                if self.stop_inds[b] > 0 and not self.stopped[b]:
                    # A stop was already scheduled for this element; flag it
                    # stopped while the scheduled index is still >= the
                    # current step.
                    self.stopped[b] = self.stop_inds[b] >= self.size
                else:
                    # Schedule a stop a margin past the current step.
                    # Constants (1.15, +150, cap 4095) are empirically tuned
                    # by the Nougat authors; 4095 presumably matches the
                    # decoder's max sequence length — TODO confirm.
                    self.stop_inds[b] = int(
                        min(max(self.size, 1) * 1.15 + 150 + self.window_size, 4095)
                    )
            else:
                # Variance recovered: cancel any scheduled stop for b.
                self.stop_inds[b] = 0
                self.stopped[b] = False
        # Stop only when at least one element is tracked and all are flagged.
        return all(self.stopped.values()) and len(self.stopped) > 0
class EndpointHandler():
    """Inference Endpoint handler for Nougat OCR.

    Decodes a base64-encoded page image and autoregressively generates the
    corresponding markdown with a ``VisionEncoderDecoderModel``.
    """

    def __init__(self, path="facebook/nougat-base"):
        # Prefer GPU when available; the model is moved there once at startup.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.processor = NougatProcessor.from_pretrained(path)
        self.model = VisionEncoderDecoderModel.from_pretrained(path)
        self.model = self.model.to(self.device)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run OCR on a single image.

        Args:
            data: Payload dict with:
                - "inputs": base64-encoded image bytes (required).
                - "parameters": optional dict of extra ``generate()`` kwargs.
                - "fix_markdown": optional flag forwarded to
                  ``post_process_generation``.

        Returns:
            ``{"generated_text": <post-processed markdown string>}``

        Raises:
            ValueError: if "inputs" is missing from the payload.
        """
        # Get inputs. (`input` renamed: it shadowed the builtin.)
        encoded_image = data.pop("inputs", None)
        # BUG FIX: a missing/None "parameters" key previously reached the
        # `**parameters` splat below and raised TypeError; default to {}.
        parameters = data.pop("parameters", None) or {}
        fix_markdown = data.pop("fix_markdown", None)
        if encoded_image is None:
            raise ValueError("Missing image.")

        binary_data = base64.b64decode(encoded_image)
        image = Image.open(io.BytesIO(binary_data))
        pixel_values = self.processor(images=image, return_tensors="pt").pixel_values

        # Autoregressively generate tokens, with custom stopping criteria
        # (as defined by the Nougat authors).
        outputs = self.model.generate(
            pixel_values=pixel_values.to(self.model.device),
            min_length=1,
            bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
            return_dict_in_generate=True,
            output_scores=True,
            stopping_criteria=StoppingCriteriaList([StoppingCriteriaScores()]),
            **parameters,
        )
        generated = self.processor.batch_decode(outputs[0], skip_special_tokens=True)[0]
        prediction = self.processor.post_process_generation(generated, fix_markdown=fix_markdown)
        return {"generated_text": prediction}