visheratin committed
Commit d6d35ed
1 Parent(s): 9c984a5

Update model files

Files changed (1)
  1. processing_llava.py +101 -0
processing_llava.py ADDED
@@ -0,0 +1,101 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Processor class for Llava.
+"""
+
+
+from typing import List, Optional, Union
+
+from transformers.feature_extraction_utils import BatchFeature
+from transformers.image_utils import ImageInput
+from transformers.tokenization_utils_base import (
+    PaddingStrategy,
+    PreTokenizedInput,
+    TextInput,
+    TruncationStrategy,
+)
+from transformers.utils import TensorType
+import torch
+from open_clip.transform import PreprocessCfg, image_transform_v2
+
+
+class OpenCLIPImageProcessor:
+    def __init__(self, config):
+        cfg = PreprocessCfg(**config)
+        transform = image_transform_v2(cfg=cfg, is_train=False)
+        self.transform = transform
+
+    def __call__(self, image, return_tensors):
+        if isinstance(image, list):
+            outputs = []
+            for item in image:
+                outputs.append(self.transform(item))
+            return {
+                "pixel_values": torch.stack(outputs),  # stack per-image tensors into one batch
+            }
+        output = self.transform(image)
+        return {
+            "pixel_values": output.unsqueeze(0),  # add a batch dimension for a single image
+        }
+
+    @property
+    def model_input_names(self):
+        return ["pixel_values"]
+
+
+class LlavaProcessor:
+    def __init__(self, image_processor: OpenCLIPImageProcessor, tokenizer):
+        self.image_processor = image_processor
+        self.tokenizer = tokenizer
+
+    def __call__(
+        self,
+        text: Union[
+            TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]
+        ] = None,
+        images: ImageInput = None,
+        padding: Union[bool, str, PaddingStrategy] = False,
+        truncation: Union[bool, str, TruncationStrategy] = None,
+        max_length=None,
+        return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
+    ) -> BatchFeature:
+        if images is not None:
+            pixel_values = self.image_processor(images, return_tensors=return_tensors)[
+                "pixel_values"
+            ]
+        else:
+            pixel_values = None
+        text_inputs = self.tokenizer(
+            text,
+            return_tensors=return_tensors,
+            padding=padding,
+            truncation=truncation,
+            max_length=max_length,
+        )
+
+        return BatchFeature(data={**text_inputs, "pixel_values": pixel_values})
+
+    def batch_decode(self, *args, **kwargs):
+        return self.tokenizer.batch_decode(*args, **kwargs)
+
+    def decode(self, *args, **kwargs):
+        return self.tokenizer.decode(*args, **kwargs)
+
+    @property
+    def model_input_names(self):
+        tokenizer_input_names = self.tokenizer.model_input_names
+        image_processor_input_names = self.image_processor.model_input_names
+        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))