|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" Logits Processor Helper class for Emu3. """ |
|
|
|
import torch |
|
|
|
class Emu3PrefixConstrainedLogitsHelper:
    """Callable returning the token ids allowed at the next generation step.

    The allowed set depends only on how far the sequence has advanced past
    the first ``img_token``. With ``offset = len(input_ids) - index_of(img_token)``
    and ``row = width + 1``:

    * ``offset`` is a multiple of ``row`` and ``offset <= row * height``:
      only ``eol_token`` is allowed (each row of ``width`` visual tokens is
      closed by one ``eol_token``).
    * ``offset == row * height + 1 / + 2 / + 3``: only ``eof_token``,
      ``eoi_token``, ``eos_token`` respectively.
    * ``offset`` beyond that: only ``pad_token``.
    * otherwise: ``visual_tokens``.

    The ``(batch_id, input_ids)`` call signature looks like the
    ``prefix_allowed_tokens_fn`` hook of Hugging Face ``generate`` -- confirm
    against the caller.

    NOTE: the position of ``img_token`` is cached per ``batch_id`` for the
    lifetime of the instance and never invalidated here; reusing one helper
    for a different prompt under the same ``batch_id`` reuses the old position.
    """

    def __init__(
        self,
        height,
        width,
        img_token,
        eoi_token,
        eos_token,
        eol_token,
        eof_token,
        pad_token,
        visual_tokens,
    ):
        """Store the image grid size and the special token ids.

        Args:
            height: number of token rows in the image.
            width: number of visual tokens per row (excluding the row
                terminator).
            img_token: token id whose first occurrence anchors the offset.
            eoi_token: id allowed at the second position after the last row.
            eos_token: id allowed at the third position after the last row.
            eol_token: id allowed at the end of every row.
            eof_token: id allowed at the first position after the last row.
            pad_token: id allowed once the whole layout has been emitted.
            visual_tokens: collection of ids allowed inside a row; returned
                as-is, without copying.
        """
        self.height = height
        self.width = width
        self.img_token = img_token
        self.eoi_token = eoi_token
        self.eos_token = eos_token
        self.eol_token = eol_token
        self.eof_token = eof_token
        self.pad_token = pad_token
        self.visual_tokens = visual_tokens

        # batch_id -> index (plain int) of the first img_token in that sequence.
        self.offset_cache = {}

    def __call__(self, batch_id, input_ids):
        """Return the token ids permitted at the next position.

        Args:
            batch_id: hashable key identifying the sequence; used only as the
                cache key for the ``img_token`` position.
            input_ids: tensor of the token ids generated so far. It is
                measured and searched along its first dimension (presumably
                a 1-D tensor for a single sequence -- confirm with caller).

        Returns:
            A one-element tuple holding the single forced token id, or
            ``self.visual_tokens`` when any visual token may follow.

        Raises:
            IndexError: if ``img_token`` does not occur in ``input_ids`` the
                first time this ``batch_id`` is seen.
        """
        if batch_id not in self.offset_cache:
            hits = torch.nonzero(input_ids == self.img_token, as_tuple=True)[0]
            if hits.numel() == 0:
                raise IndexError(
                    f"img_token {self.img_token} not found in input_ids "
                    f"for batch_id {batch_id}"
                )
            # Store a Python int so the branch logic below is plain integer
            # arithmetic instead of a tensor comparison per branch.
            self.offset_cache[batch_id] = int(hits[0])

        offset = input_ids.shape[0] - self.offset_cache[batch_id]
        row_len = self.width + 1
        body_len = row_len * self.height

        # Positions past the image body are handled first, so that the
        # end-of-row rule below can never shadow the eof/eoi/eos/pad tail
        # when such a position happens to be a multiple of the row length.
        if offset > body_len:
            trailer = (self.eof_token, self.eoi_token, self.eos_token)
            tail_index = offset - body_len - 1
            if tail_index < len(trailer):
                return (trailer[tail_index],)
            return (self.pad_token,)

        if offset % row_len == 0:
            return (self.eol_token,)

        return self.visual_tokens
|
|