gowitheflow committed on
Commit 835f7bd · verified · 1 Parent(s): 1db5f2e

Upload full checkpoint folder

Files changed (1):
  1. example_inference.py +167 -0
example_inference.py ADDED
import os

import torch
import torch.nn.functional as F
from PIL import Image
from tqdm import tqdm
from safetensors.torch import load_file
from transformers import AutoProcessor, AutoConfig
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
    Qwen2_5_VisionTransformerPretrainedModel,
)


class Qwen2_5_VL_ImageEncoder:
    def __init__(self, model_path: str, device: str = "cuda", dtype=torch.bfloat16):
        self.device = device
        self.dtype = dtype

        print(f"Loading processor and model from {model_path}...")
        # The processor and config come from local copies of the base
        # Qwen2.5-VL model; only the ViT weights are read from `model_path`.
        self.processor = AutoProcessor.from_pretrained(
            "/mnt/ai4sci_develop_storage/home/chaohao/LCO-Embedding/Training/Qwen2.5-VL-ViT-Only",
            trust_remote_code=True,
        )

        config = AutoConfig.from_pretrained("/mnt/workspace/workgroup/chx/Qwen2.5-VL-7B-Instruct")
        config = config.vision_config

        self.model = Qwen2_5_VisionTransformerPretrainedModel(config)

        safe_path = os.path.join(model_path, "model.safetensors")
        state_dict = load_file(safe_path)
        self.model.load_state_dict(state_dict, strict=True)

        self.model.to(device=self.device, dtype=self.dtype)
        self.model.eval()
        print("Model loaded successfully.")
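
    # Note: a ViT-only "model.safetensors" like the one loaded above can be
    # exported from a full checkpoint. A hedged sketch (assuming the full
    # model exposes its vision tower as `.visual`, and `base_path` is your
    # local copy of the base model):
    #
    #   from transformers import Qwen2_5_VLForConditionalGeneration
    #   from safetensors.torch import save_file
    #   full = Qwen2_5_VLForConditionalGeneration.from_pretrained(base_path)
    #   save_file(full.visual.state_dict(), "model.safetensors")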

    def _process_batch_forward(self, images):
        """Internal helper to run the forward pass on a single batch."""
        # 1. Prepare inputs. The chat template expects role-wrapped messages,
        # so each image/text pair is placed inside a "user" turn.
        messages_list = [
            [
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": img},
                        {"type": "text", "text": "Describe this image."},
                    ],
                }
            ]
            for img in images
        ]

        # Apply the template to each item in the batch
        text_inputs = [
            self.processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True)
            for msg in messages_list
        ]

        # The processor handles batching of pixel values and grids
        inputs = self.processor(
            images=images,
            text=text_inputs,
            return_tensors="pt",
            padding=True,
        )

        # Move to device
        pixel_values = inputs["pixel_values"].to(self.device, dtype=self.dtype)
        grid_thw = inputs["image_grid_thw"].to(self.device)

        # 2. Model forward: the vision tower returns a flat [total_tokens, hidden] tensor
        hidden_states = self.model(hidden_states=pixel_values, grid_thw=grid_thw)

        # 3. Pooling logic (exact replica of the training logic)
        if grid_thw.dim() == 3 and grid_thw.size(1) == 1:
            grid_thw = grid_thw.squeeze(1)

        batch_size = grid_thw.shape[0]

        # Tokens per image after the 2x2 spatial merge: (H // 2) * (W // 2)
        H, W = grid_thw[:, 1], grid_thw[:, 2]
        sizes = ((H // 2) * (W // 2)).long()

        # Safety fix for a token-count mismatch: absorb the difference into the last image
        total_tokens = hidden_states.shape[0]
        if sizes.sum().item() != total_tokens:
            sizes[-1] += total_tokens - sizes.sum().item()

        # Create batch indices [0, 0, 0, 1, 1, 2, 2, 2, ...]
        batch_indices = torch.repeat_interleave(
            torch.arange(batch_size, device=self.device),
            sizes,
        )

        # Sum pooling: scatter-add each token into its image's slot
        pooled_sum = torch.zeros(
            (batch_size, hidden_states.shape[-1]),
            dtype=self.dtype,
            device=self.device,
        )
        pooled_sum.index_add_(0, batch_indices, hidden_states)

        # Mean pooling: divide each sum by its token count
        counts = sizes.unsqueeze(1).to(dtype=self.dtype).clamp(min=1.0)
        embeds = pooled_sum / counts

        # 4. L2-normalize
        embeds = F.normalize(embeds, p=2, dim=-1)

        return embeds.cpu()  # Move to CPU to save GPU memory during accumulation
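
    # For intuition: the pooling above is plain segment mean pooling over a
    # flat token tensor. A toy sketch of the same repeat_interleave /
    # index_add_ pattern (illustrative shapes only, not the real dimensions):
    #
    #   tokens = torch.arange(10.0).reshape(5, 2)               # 5 tokens, dim 2
    #   sizes = torch.tensor([3, 2])                            # image 0 -> 3 tokens, image 1 -> 2
    #   idx = torch.repeat_interleave(torch.arange(2), sizes)   # tensor([0, 0, 0, 1, 1])
    #   sums = torch.zeros(2, 2).index_add_(0, idx, tokens)     # per-image token sums
    #   means = sums / sizes.unsqueeze(1)                       # per-image mean embeddings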

    @torch.no_grad()
    def encode_batch(self, images: list, batch_size: int = 32, show_progress: bool = True):
        """
        Args:
            images: List of PIL Images.
            batch_size: Number of images to process at once.
            show_progress: Whether to display a tqdm progress bar.
        Returns:
            torch.Tensor: Concatenated embeddings of shape [total_images, hidden_dim].
        """
        all_embeddings = []

        iterator = range(0, len(images), batch_size)
        if show_progress:
            iterator = tqdm(iterator, desc="Encoding Batches", unit="batch")

        for i in iterator:
            batch_images = images[i : i + batch_size]

            # Ensure all images are RGB
            batch_images = [img.convert("RGB") for img in batch_images]

            try:
                batch_embeds = self._process_batch_forward(batch_images)
                all_embeddings.append(batch_embeds)
            except Exception as e:
                print(f"Error processing batch starting at index {i}: {e}")
                # Optional: return partial results here instead of re-raising
                raise

        if not all_embeddings:
            return torch.empty(0)

        # Concatenate all batches into one large tensor
        return torch.cat(all_embeddings, dim=0)
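
# Quick standalone check (hypothetical paths; any two PIL images work). Since
# the embeddings are L2-normalized, a plain dot product is a cosine similarity:
#
#   enc = Qwen2_5_VL_ImageEncoder("/path/to/checkpoint")
#   embs = enc.encode_batch([Image.open("a.png"), Image.open("b.png")])
#   sim = (embs[0] @ embs[1]).item()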

# --- Usage Example ---
if __name__ == "__main__":
    from datasets import load_dataset
    from scipy.stats import spearmanr
    from sklearn.metrics.pairwise import paired_cosine_distances

    MODEL_PATHS = [
        "/mnt/ai4sci_develop_storage/home/chaohao/LCO-Embedding/Training/checkpoints/final/checkpoint-500",
        "/mnt/ai4sci_develop_storage/home/chaohao/LCO-Embedding/Training/checkpoints/final/checkpoint-550",
    ]

    for MODEL_PATH in MODEL_PATHS:
        encoder = Qwen2_5_VL_ImageEncoder(MODEL_PATH)

        spearmans = []
        for lang in ["en", "de", "es", "fr", "it", "nl", "pl", "pt", "ru", "zh"]:
            # The local STS-B copy is assumed to provide the sentence pairs as
            # rendered images, since encode_batch operates on PIL Images.
            dataset = load_dataset(
                "/mnt/ai4sci_develop_storage/home/chaohao/LCO-Embedding/Training/a_eval/stsb",
                lang,
            )["test"]
            anchors = dataset["sentence1"]
            positives = dataset["sentence2"]
            groundtruth = dataset["score"]

            embeddings1 = encoder.encode_batch(anchors, batch_size=32).float().numpy()
            embeddings2 = encoder.encode_batch(positives, batch_size=32).float().numpy()

            # Spearman correlation between cosine similarities and gold scores
            cos_sim = 1 - paired_cosine_distances(embeddings1, embeddings2)
            spearman_corr, _ = spearmanr(cos_sim, groundtruth)
            spearmans.append(round(spearman_corr, 2))

        print("Spearman correlation:", spearmans)