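"""Custom handler for a Hugging Face Inference Endpoint running the Nougat
model (facebook/nougat-base). The request payload carries a base64-encoded
page image; the response carries the generated markdown text."""
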
import base64
import io
from collections import defaultdict
from typing import Any, Dict

import torch
from PIL import Image
from transformers import NougatProcessor, VisionEncoderDecoderModel, StoppingCriteria, StoppingCriteriaList

class RunningVarTorch:
  """Tracks the variance of the last `L` values pushed, per batch element."""

  def __init__(self, L=15, norm=False):
    self.values = None
    self.L = L
    self.norm = norm

  def push(self, x: torch.Tensor):
    assert x.dim() == 1
    if self.values is None:
      self.values = x[:, None]
    elif self.values.shape[1] < self.L:
      # Window not full yet: append the new column.
      self.values = torch.cat((self.values, x[:, None]), 1)
    else:
      # Window full: drop the oldest column, append the new one.
      self.values = torch.cat((self.values[:, 1:], x[:, None]), 1)

  def variance(self):
    if self.values is None:
      return
    if self.norm:
      # Normalize the variance by the current window length.
      return torch.var(self.values, 1) / self.values.shape[1]
    else:
      return torch.var(self.values, 1)

class StoppingCriteriaScores(StoppingCriteria):
  """Variance-based stopping criterion defined by the Nougat authors.

  Generation is stopped once the variance of the maximum logit scores
  stays below `threshold` over a sliding window, which signals that the
  model has started to repeat itself.
  """

  def __init__(self, threshold: float = 0.015, window_size: int = 200):
    super().__init__()
    self.threshold = threshold
    self.vars = RunningVarTorch(norm=True)
    self.varvars = RunningVarTorch(L=window_size)
    self.stop_inds = defaultdict(int)
    self.stopped = defaultdict(bool)
    self.size = 0
    self.window_size = window_size

  @torch.no_grad()
  def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
    last_scores = scores[-1]
    # Track the variance of the max logit per batch element, then the
    # variance of those variances over the window.
    self.vars.push(last_scores.max(1)[0].float().cpu())
    self.varvars.push(self.vars.variance())
    self.size += 1
    if self.size < self.window_size:
      return False

    varvar = self.varvars.variance()
    for b in range(len(last_scores)):
      if varvar[b] < self.threshold:
        if self.stop_inds[b] > 0 and not self.stopped[b]:
          self.stopped[b] = self.stop_inds[b] >= self.size
        else:
          # Schedule a stopping index a little past the current position.
          self.stop_inds[b] = int(
              min(max(self.size, 1) * 1.15 + 150 + self.window_size, 4095)
          )
      else:
        # Scores are still varying: reset any scheduled stop.
        self.stop_inds[b] = 0
        self.stopped[b] = False
    return all(self.stopped.values()) and len(self.stopped) > 0

class EndpointHandler:
  def __init__(self, path="facebook/nougat-base"):
    self.device = "cuda" if torch.cuda.is_available() else "cpu"
    self.processor = NougatProcessor.from_pretrained(path)
    self.model = VisionEncoderDecoderModel.from_pretrained(path)
    self.model = self.model.to(self.device)

  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
    """
    Args:
      data (Dict): The payload with the base64-encoded page image
        and generation parameters.

    Returns:
      Dict with the generated markdown under "generated_text".
    """
    # Get inputs
    inputs = data.pop("inputs", None)
    parameters = data.pop("parameters", None) or {}
    # Default to the processor's own default for markdown fixing.
    fix_markdown = data.pop("fix_markdown", True)
    if inputs is None:
      raise ValueError("Missing image.")
    # Decode the base64-encoded image payload.
    binary_data = base64.b64decode(inputs)
    image = Image.open(io.BytesIO(binary_data))
    pixel_values = self.processor(images=image, return_tensors="pt").pixel_values
    # Autoregressively generate tokens, with the custom stopping criteria
    # defined by the Nougat authors.
    outputs = self.model.generate(
      pixel_values=pixel_values.to(self.model.device),
      min_length=1,
      bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
      return_dict_in_generate=True,
      output_scores=True,
      stopping_criteria=StoppingCriteriaList([StoppingCriteriaScores()]),
      **parameters,
    )
    generated = self.processor.batch_decode(outputs.sequences, skip_special_tokens=True)[0]
    prediction = self.processor.post_process_generation(generated, fix_markdown=fix_markdown)

    return {"generated_text": prediction}