kimihailv committed on
Commit
ab80423
1 Parent(s): 80e48d5

Upload processing_uform_gen.py with huggingface_hub

Files changed (1)
  1. processing_uform_gen.py +181 -0
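The commit message notes that the file was uploaded with huggingface_hub. For reference, an upload like this one can be scripted with the hub client roughly as follows (a minimal sketch; the repo id is a hypothetical placeholder):

    # sketch: push a single file to a model repo on the Hub
    from huggingface_hub import upload_file

    upload_file(
        path_or_fileobj="processing_uform_gen.py",  # local source file
        path_in_repo="processing_uform_gen.py",     # destination path in the repo
        repo_id="your-user/your-model",             # hypothetical; replace with the real repo
        repo_type="model",
    )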
processing_uform_gen.py ADDED
@@ -0,0 +1,181 @@
from functools import partial

import torch
import torch.nn.functional as F
from transformers.processing_utils import ProcessorMixin
from transformers.image_processing_utils import BaseImageProcessor
from transformers import AutoTokenizer, AutoConfig
from transformers import BatchFeature

from PIL import Image
from torchvision.transforms import (
    Compose,
    Normalize,
    Resize,
    ToTensor,
)


# Despite the names, these are the CLIP normalization statistics
# rather than the classic ImageNet ones.
IMAGENET_MEAN = (0.48145466, 0.4578275, 0.40821073)
IMAGENET_STD = (0.26862954, 0.26130258, 0.27577711)

def convert_to_rgb(x):
    return x.convert("RGB")


def expand2square(image, background_color):
    # Pad the shorter side with `background_color` so the result is square,
    # keeping the original image centered.
    width, height = image.size
    if width == height:
        return image
    elif width > height:
        result = Image.new(image.mode, (width, width), background_color)
        result.paste(image, (0, (width - height) // 2))
        return result
    else:
        result = Image.new(image.mode, (height, height), background_color)
        result.paste(image, ((height - width) // 2, 0))
        return result

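# A quick sanity check of the padding math: a 640x480 landscape image becomes
# a 640x640 canvas with the original pasted (640 - 480) // 2 = 80 px from the
# top. A minimal sketch (the solid test image is made up here):
#
#   img = Image.new("RGB", (640, 480), (0, 0, 0))
#   sq = expand2square(img, background_color=(127, 127, 127))
#   sq.size  # (640, 640)
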
class ImageProcessor(BaseImageProcessor):
    def __init__(
        self,
        image_size: int,
        **kwargs
    ):
        super().__init__(**kwargs)
        # RGB conversion -> pad to square with the mean color -> resize
        # -> tensor -> normalize.
        self.transform = Compose(
            [
                convert_to_rgb,
                partial(
                    expand2square,
                    background_color=tuple(int(255 * v) for v in IMAGENET_MEAN)
                ),
                Resize(image_size),
                ToTensor(),
                Normalize(
                    mean=IMAGENET_MEAN,
                    std=IMAGENET_STD,
                ),
            ]
        )

    def preprocess(
        self,
        image: Image.Image
    ):
        return self.transform(image)

    def __repr__(self):
        return repr(self.transform)

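# On its own, the pipeline maps any PIL image to a (3, image_size, image_size)
# float tensor; BaseImageProcessor routes __call__ to preprocess, so instances
# are directly callable. A minimal sketch (size and file name are assumptions):
#
#   proc = ImageProcessor(image_size=224)
#   tensor = proc(Image.open("example.jpg"))
#   tensor.shape  # torch.Size([3, 224, 224])
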
class VLMProcessor(ProcessorMixin):
    def __init__(self, config):
        self.config = config
        self.image_size = config.image_size

        self.feature_extractor = ImageProcessor(self.image_size)
        self.tokenizer = AutoTokenizer.from_pretrained(
            config.text_decoder_name_or_path, additional_special_tokens=["<image>"]
        )
        # ChatML-style template: every turn is wrapped in
        # <|im_start|>{role} ... <|im_end|>.
        self.tokenizer.chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
        self.num_image_latents = config.num_image_latents
        # super().__init__(self.image_processor, self.tokenizer)

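    # For one user turn with add_generation_prompt=True, the template above
    # renders to the following prompt (the question text is an assumed example):
    #
    #   <|im_start|>system
    #   You are a helpful assistant.<|im_end|>
    #   <|im_start|>user
    #    <image> Describe this picture.<|im_end|>
    #   <|im_start|>assistant
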
    def __call__(
        self, text=None, images=None, return_tensors=None, **kwargs
    ):
        if text is not None:
            if isinstance(text, str):
                text = [text]

            # Wrap each prompt in a system/user ChatML conversation with an
            # <image> placeholder for the visual tokens.
            tokenized_texts = []
            for t in text:
                messages = [
                    {"role": "system", "content": "You are a helpful assistant."},
                    {"role": "user", "content": f" <image> {t}"},
                ]
                tokenized_prompt = self.tokenizer.apply_chat_template(
                    messages, add_generation_prompt=True, return_tensors="pt"
                )

                tokenized_texts.append(tokenized_prompt)

            # Left-pad every prompt to the longest one, so generation for a
            # decoder-only model starts right after the last real token.
            max_len = max(len(t[0]) for t in tokenized_texts)
            input_ids = torch.full(
                (len(tokenized_texts), max_len),
                fill_value=self.tokenizer.pad_token_id,
                dtype=torch.int64,
            )
            attention_mask = torch.full(
                (len(tokenized_texts), max_len), fill_value=0, dtype=torch.int64
            )

            for i, tokens in enumerate(tokenized_texts):
                input_ids[i, -len(tokens[0]) :] = tokens[0]
                attention_mask[i, -len(tokens[0]) :] = 1

            # Grow the mask on the right to cover the image latents: the lone
            # <image> token expands to num_image_latents embeddings downstream,
            # a net growth of num_image_latents - 1 positions.
            attention_mask = F.pad(
                attention_mask, pad=(0, self.num_image_latents - 1), value=1
            )

            encoding = BatchFeature(
                data={"input_ids": input_ids, "attention_mask": attention_mask}
            )

        if images is not None:
            if isinstance(images, (list, tuple)):
                image_features = torch.empty(
                    (len(images), 3, self.image_size, self.image_size),
                    dtype=torch.float32,
                )

                for i, image in enumerate(images):
                    image_features[i] = self.feature_extractor(image)

            else:
                # Single image: run it through the same feature extractor and
                # add a batch dimension.
                image_features = self.feature_extractor(images).unsqueeze(0)

        if text is not None and images is not None:
            encoding["images"] = image_features
            return encoding

        elif text is not None:
            return encoding

        else:
            return BatchFeature(
                data={
                    "images": image_features,
                },
                tensor_type=return_tensors,
            )

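    # The padding scheme on a toy pair of prompts (lengths 2 and 4, an assumed
    # pad id of 0, and num_image_latents = 3 for illustration):
    #
    #   input_ids      -> [[0, 0, 7, 8],           # left-padded
    #                      [3, 4, 5, 6]]
    #   attention_mask -> [[0, 0, 1, 1, 1, 1],     # +2 ones cover the latents
    #                      [1, 1, 1, 1, 1, 1]]
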
    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to the underlying tokenizer's
        [`~PreTrainedTokenizer.batch_decode`]. Please refer to the docstring
        of that method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to the underlying tokenizer's
        [`~PreTrainedTokenizer.decode`]. Please refer to the docstring of
        that method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path,
        trust_remote_code=False,
        **kwargs
    ):
        # The processor is rebuilt entirely from the repo's config; extra
        # kwargs are forwarded to AutoConfig.
        config = AutoConfig.from_pretrained(
            pretrained_model_name_or_path,
            trust_remote_code=trust_remote_code,
            **kwargs
        )
        return cls(config)
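
Once this file ships with a model repo, the processor can be loaded through its custom from_pretrained and called with text and images together, roughly like this (a minimal sketch; the repo id and image path are hypothetical, and the repo's config must provide image_size, text_decoder_name_or_path, and num_image_latents):

    # sketch: build model inputs with the processor (hypothetical names)
    from PIL import Image

    processor = VLMProcessor.from_pretrained("your-user/your-model", trust_remote_code=True)
    inputs = processor(
        text="Describe this picture.",
        images=Image.open("example.jpg"),
    )
    # inputs["input_ids"], inputs["attention_mask"], and inputs["images"]
    # are then fed to the vision-language model.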