YUNTA88 commited on
Commit
6c57edb
·
verified ·
1 Parent(s): adca846

Upload scripts/_sft_classes.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. scripts/_sft_classes.py +97 -0
scripts/_sft_classes.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import json
3
+ import torch
4
+ from PIL import Image
5
+ from torch.utils.data import Dataset
6
+
7
+
8
+ class PhysicsCoTDataset(Dataset):
9
+ """Dataset for Qwen2.5-VL SFT with physics CoT."""
10
+
11
+ def __init__(self, data_path, processor, max_length=4096):
12
+ self.processor = processor
13
+ self.max_length = max_length
14
+ with open(data_path, 'r', encoding='utf-8') as f:
15
+ self.records = [json.loads(line) for line in f]
16
+ print(f"Loaded {len(self.records)} records from {data_path}")
17
+
18
+ def __len__(self):
19
+ return len(self.records)
20
+
21
+ def __getitem__(self, idx):
22
+ record = self.records[idx]
23
+ messages = record['messages']
24
+ user_msg = messages[0]
25
+ image_path = None
26
+ text_content = ""
27
+ for content in user_msg['content']:
28
+ if content['type'] == 'image':
29
+ image_path = content['image'].replace('file://', '')
30
+ elif content['type'] == 'text':
31
+ text_content = content['text']
32
+ assistant_msg = messages[1]
33
+ assistant_text = assistant_msg['content'][0]['text']
34
+ image = Image.open(image_path).convert('RGB')
35
+ MIN_DIM = 56
36
+ w, h = image.size
37
+ if w < MIN_DIM or h < MIN_DIM:
38
+ scale = max(MIN_DIM / w, MIN_DIM / h)
39
+ new_w = int(w * scale)
40
+ new_h = int(h * scale)
41
+ image = image.resize((new_w, new_h), Image.LANCZOS)
42
+ if new_w < MIN_DIM or new_h < MIN_DIM:
43
+ padded = Image.new('RGB', (max(new_w, MIN_DIM), max(new_h, MIN_DIM)), (255, 255, 255))
44
+ padded.paste(image, (0, 0))
45
+ image = padded
46
+ conversation = [
47
+ {"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": text_content}]},
48
+ {"role": "assistant", "content": [{"type": "text", "text": assistant_text}]},
49
+ ]
50
+ text = self.processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=False)
51
+ inputs = self.processor(text=[text], images=[image], padding=False, truncation=True, max_length=self.max_length, return_tensors="pt")
52
+ input_ids = inputs['input_ids'].squeeze(0)
53
+ attention_mask = inputs['attention_mask'].squeeze(0)
54
+ labels = input_ids.clone()
55
+ assistant_token_str = "<|im_start|>assistant\n"
56
+ assistant_token_ids = self.processor.tokenizer.encode(assistant_token_str, add_special_tokens=False)
57
+ input_ids_list = input_ids.tolist()
58
+ assistant_start = -1
59
+ for i in range(len(input_ids_list) - len(assistant_token_ids) + 1):
60
+ if input_ids_list[i:i + len(assistant_token_ids)] == assistant_token_ids:
61
+ assistant_start = i + len(assistant_token_ids)
62
+ break
63
+ if assistant_start > 0:
64
+ labels[:assistant_start] = -100
65
+ else:
66
+ raise ValueError(f"FATAL: assistant start token not found in sample {idx}.")
67
+ labels[attention_mask == 0] = -100
68
+ return {
69
+ 'input_ids': input_ids,
70
+ 'attention_mask': attention_mask,
71
+ 'labels': labels,
72
+ 'pixel_values': inputs.get('pixel_values', torch.tensor([])).squeeze(0) if 'pixel_values' in inputs else None,
73
+ 'image_grid_thw': inputs.get('image_grid_thw', torch.tensor([])).squeeze(0) if 'image_grid_thw' in inputs else None,
74
+ }
75
+
76
+
77
+ class VLMDataCollator:
78
+ """Custom data collator for variable-length VLM inputs."""
79
+ def __init__(self, processor):
80
+ self.processor = processor
81
+ self.pad_token_id = processor.tokenizer.pad_token_id or processor.tokenizer.eos_token_id
82
+
83
+ def __call__(self, features):
84
+ max_len = max(f['input_ids'].size(0) for f in features)
85
+ input_ids, attention_mask, labels, pixel_values, image_grid_thw = [], [], [], [], []
86
+ for f in features:
87
+ seq_len = f['input_ids'].size(0)
88
+ pad_len = max_len - seq_len
89
+ input_ids.append(torch.cat([f['input_ids'], torch.full((pad_len,), self.pad_token_id, dtype=f['input_ids'].dtype)]))
90
+ attention_mask.append(torch.cat([f['attention_mask'], torch.zeros(pad_len, dtype=f['attention_mask'].dtype)]))
91
+ labels.append(torch.cat([f['labels'], torch.full((pad_len,), -100, dtype=f['labels'].dtype)]))
92
+ if f.get('pixel_values') is not None: pixel_values.append(f['pixel_values'])
93
+ if f.get('image_grid_thw') is not None: image_grid_thw.append(f['image_grid_thw'])
94
+ batch = {'input_ids': torch.stack(input_ids), 'attention_mask': torch.stack(attention_mask), 'labels': torch.stack(labels)}
95
+ if pixel_values: batch['pixel_values'] = torch.cat(pixel_values, dim=0)
96
+ if image_grid_thw: batch['image_grid_thw'] = torch.stack(image_grid_thw)
97
+ return batch