Image-Text-to-Text · sentence-transformers · Safetensors · Transformers · qwen2_vl · Qwen2-VL · conversational
cheesyFishes committed · verified
Commit fdba9e3 · 1 Parent(s): 47f5e7c

add load method, improve image processing to support URLs, etc.
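For context, a minimal sketch of the round-trip this commit is aimed at, assuming custom_st.Transformer is registered as the first module of a sentence-transformers checkpoint via modules.json (the repo id and local path below are placeholders, not part of this commit):

    from sentence_transformers import SentenceTransformer

    # Hypothetical checkpoint that wires up custom_st.Transformer as its first module.
    model = SentenceTransformer("your-org/your-qwen2-vl-retriever", trust_remote_code=True)

    # SentenceTransformer.save() calls Transformer.save(), which after this commit
    # also writes a config.json (dimension, min/max_pixels, the *_args dicts) next
    # to the weights and processor files.
    model.save("./local-copy")

    # Reloading the saved copy goes through the new Transformer.load() classmethod,
    # which reads that config.json back and rebuilds the model and processor with
    # the same settings.
    reloaded = SentenceTransformer("./local-copy", trust_remote_code=True)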

Files changed (1)
  1. custom_st.py +77 -24
custom_st.py CHANGED
@@ -9,7 +9,7 @@ import requests
 import torch
 from PIL import Image
 from torch import nn
-from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
+from transformers import AutoProcessor, Qwen2VLForConditionalGeneration, AutoConfig
 
 class Transformer(nn.Module):
     save_in_root: bool = True
@@ -23,6 +23,9 @@ class Transformer(nn.Module):
         dimension: int = 2048,
         cache_dir: Optional[str] = None,
         device: str = 'cuda:0',
+        config_args: Optional[Dict[str, Any]] = None,
+        model_args: Optional[Dict[str, Any]] = None,
+        processor_args: Optional[Dict[str, Any]] = None,
         **kwargs,
     ) -> None:
         super(Transformer, self).__init__()
@@ -31,40 +34,61 @@
         self.dimension = dimension
         self.max_pixels = max_pixels
         self.min_pixels = min_pixels
+        self.model_name_or_path = model_name_or_path
+        self.processor_name_or_path = processor_name_or_path or model_name_or_path
+        self.cache_dir = cache_dir
 
-        # Try to use flash attention if available, fallback to default attention if not
+        self.config_args = config_args or {}
+        self.model_args = model_args or {}
+        self.processor_args = processor_args or {}
+
+        self.document_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>What is shown in this image?<|im_end|>\n<|endoftext|>"
+        self.query_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Query: %s<|im_end|>\n<|endoftext|>"
+
+    @classmethod
+    def load(cls, input_path: str) -> 'Transformer':
+        config_path = os.path.join(input_path, 'config.json')
+        if os.path.exists(config_path):
+            with open(config_path) as f:
+                config = json.load(f)
+        else:
+            config = {}
+
+        instance = cls(model_name_or_path=input_path, **config)
+
+        # Load model with flash attention if available
         try:
-            self.model = Qwen2VLForConditionalGeneration.from_pretrained(
-                model_name_or_path,
+            instance.model = Qwen2VLForConditionalGeneration.from_pretrained(
+                input_path,
                 attn_implementation="flash_attention_2",
                 torch_dtype=torch.bfloat16,
-                device_map=device,
-                cache_dir=cache_dir,
-                **kwargs
+                device_map=instance.device,
+                cache_dir=instance.cache_dir,
+                **instance.model_args
             ).eval()
         except (ImportError, ValueError) as e:
             print(f"Flash attention not available, falling back to default attention: {e}")
-            self.model = Qwen2VLForConditionalGeneration.from_pretrained(
-                model_name_or_path,
+            instance.model = Qwen2VLForConditionalGeneration.from_pretrained(
+                input_path,
                 torch_dtype=torch.bfloat16,
-                device_map=device,
-                cache_dir=cache_dir,
-                **kwargs
+                device_map=instance.device,
+                cache_dir=instance.cache_dir,
+                **instance.model_args
             ).eval()
 
         # Initialize processor
-        self.processor = AutoProcessor.from_pretrained(
-            processor_name_or_path or model_name_or_path,
-            min_pixels=min_pixels,
-            max_pixels=max_pixels,
-            cache_dir=cache_dir
+        instance.processor = AutoProcessor.from_pretrained(
+            input_path,
+            min_pixels=instance.min_pixels,
+            max_pixels=instance.max_pixels,
+            cache_dir=instance.cache_dir,
+            **instance.processor_args
        )
 
-        self.model.padding_side = "left"
-        self.processor.tokenizer.padding_side = "left"
-
-        self.document_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>What is shown in this image?<|im_end|>\n<|endoftext|>"
-        self.query_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Query: %s<|im_end|>\n<|endoftext|>"
+        instance.model.padding_side = "left"
+        instance.processor.tokenizer.padding_side = "left"
+
+        return instance
 
     def _smart_resize(self, height: int, width: int) -> tuple[int, int]:
         h_bar = max(28, self._round_by_factor(height, 28))
@@ -108,8 +132,21 @@ class Transformer(nn.Module):
 
         for sample in texts:
             if isinstance(sample, str):
-                processed_texts.append(self.query_prompt % sample)
-                processed_images.append(dummy_image)
+                if sample.startswith('http') or sample.startswith('data:image/'):
+                    try:
+                        if sample.startswith('http'):
+                            response = requests.get(sample)
+                            image = Image.open(BytesIO(response.content)).convert('RGB')
+                        else:
+                            image = self._decode_data_image(sample).convert('RGB')
+                        processed_texts.append(self.document_prompt)
+                        processed_images.append(self._resize_image(image))
+                    except Exception as e:
+                        processed_texts.append(self.query_prompt % sample)
+                        processed_images.append(dummy_image)
+                else:
+                    processed_texts.append(self.query_prompt % sample)
+                    processed_images.append(dummy_image)
             elif isinstance(sample, Image.Image):
                 processed_texts.append(self.document_prompt)
                 processed_images.append(self._resize_image(sample))
@@ -149,5 +186,21 @@ class Transformer(nn.Module):
         return {k: v.to(self.device) for k, v in inputs.items()}
 
     def save(self, output_path: str, safe_serialization: bool = True) -> None:
+        # Save the configuration
+        config = {
+            'model_name_or_path': self.model_name_or_path,
+            'processor_name_or_path': self.processor_name_or_path,
+            'max_pixels': self.max_pixels,
+            'min_pixels': self.min_pixels,
+            'dimension': self.dimension,
+            'config_args': self.config_args,
+            'model_args': self.model_args,
+            'processor_args': self.processor_args,
+        }
+
+        os.makedirs(output_path, exist_ok=True)
+        with open(os.path.join(output_path, 'config.json'), 'w') as f:
+            json.dump(config, f)
+
         self.model.save_pretrained(output_path, safe_serialization=safe_serialization)
         self.processor.save_pretrained(output_path)
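
A short sketch of the input handling that the string branches added above are aiming at, again assuming the module is used through SentenceTransformer.encode(); the repo id and file names are placeholders:

    import base64
    from PIL import Image
    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer("your-org/your-qwen2-vl-retriever", trust_remote_code=True)

    # Plain strings are still wrapped in the query prompt.
    query_emb = model.encode("what does the chart on page 12 show?")

    # http(s) URLs are now fetched with requests and embedded as document images;
    # if the download fails, the string falls back to the query prompt.
    url_emb = model.encode("https://example.com/page-12.png")

    # data:image/... URIs are decoded (via _decode_data_image) and treated as images too.
    with open("page-12.png", "rb") as f:
        data_uri = "data:image/png;base64," + base64.b64encode(f.read()).decode()
    data_emb = model.encode(data_uri)

    # PIL images keep working as before.
    pil_emb = model.encode(Image.open("page-12.png"))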