# nougat-api / handler.py
import base64
from collections import defaultdict
from io import BytesIO
from typing import Any, Dict

import torch
from PIL import Image
from transformers import NougatProcessor, VisionEncoderDecoderModel, StoppingCriteria, StoppingCriteriaList

class RunningVarTorch:
    """Running variance over a sliding window of the last `L` pushed vectors."""

    def __init__(self, L=15, norm=False):
        self.values = None
        self.L = L
        self.norm = norm

    def push(self, x: torch.Tensor):
        # Append a 1-D tensor (one value per batch element) as a new column,
        # dropping the oldest column once the window holds L entries.
        assert x.dim() == 1
        if self.values is None:
            self.values = x[:, None]
        elif self.values.shape[1] < self.L:
            self.values = torch.cat((self.values, x[:, None]), 1)
        else:
            self.values = torch.cat((self.values[:, 1:], x[:, None]), 1)

    def variance(self):
        if self.values is None:
            return
        if self.norm:
            return torch.var(self.values, 1) / self.values.shape[1]
        else:
            return torch.var(self.values, 1)

class StoppingCriteriaScores(StoppingCriteria):
    """Heuristic stopping criterion: tracks the variance of a short-run variance of the
    per-step max scores and flags a sequence to stop once that signal stays below
    `threshold`, which typically indicates degenerate, repetitive decoding."""

    def __init__(self, threshold: float = 0.015, window_size: int = 200):
        super().__init__()
        self.threshold = threshold
        self.vars = RunningVarTorch(norm=True)
        self.varvars = RunningVarTorch(L=window_size)
        self.stop_inds = defaultdict(int)
        self.stopped = defaultdict(bool)
        self.size = 0
        self.window_size = window_size

    @torch.no_grad()
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        last_scores = scores[-1]
        self.vars.push(last_scores.max(1)[0].float().cpu())
        self.varvars.push(self.vars.variance())
        self.size += 1
        if self.size < self.window_size:
            return False

        varvar = self.varvars.variance()
        for b in range(len(last_scores)):
            if varvar[b] < self.threshold:
                if self.stop_inds[b] > 0 and not self.stopped[b]:
                    self.stopped[b] = self.stop_inds[b] >= self.size
                else:
                    self.stop_inds[b] = int(
                        min(max(self.size, 1) * 1.15 + 150 + self.window_size, 4095)
                    )
            else:
                self.stop_inds[b] = 0
                self.stopped[b] = False
        return all(self.stopped.values()) and len(self.stopped) > 0

class EndpointHandler:
    def __init__(self, path="facebook/nougat-small") -> None:
        self.processor = NougatProcessor.from_pretrained(path)
        self.model = VisionEncoderDecoderModel.from_pretrained(path)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

    def __call__(self, data: Dict[str, Any]) -> str:
        # Expect a base64-encoded page image under "inputs"; decode it into a PIL image.
        image = data.pop("inputs", data)
        image_data = Image.open(BytesIO(base64.b64decode(image)))
        pixel_values = self.processor(image_data, return_tensors="pt").pixel_values
        outputs = self.model.generate(
            pixel_values.to(self.device),
            min_length=1,
            max_length=3584,
            bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
            return_dict_in_generate=True,
            output_scores=True,
            stopping_criteria=StoppingCriteriaList([StoppingCriteriaScores()]),
        )
        text = self.processor.batch_decode(outputs.sequences, skip_special_tokens=True)[0]
        text = self.processor.post_process_generation(text, fix_markdown=False)
        return text
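

# --- Local usage sketch (not part of the Inference Endpoints request path) ---
# A minimal smoke test, assuming a page image at "sample_page.png" (hypothetical
# filename). It base64-encodes the file the same way a client request body would,
# then calls the handler directly and prints the recognized markup.
if __name__ == "__main__":
    with open("sample_page.png", "rb") as f:
        encoded = base64.b64encode(f.read()).decode("utf-8")

    handler = EndpointHandler()  # downloads facebook/nougat-small on first run
    print(handler({"inputs": encoded}))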