Upload scripts/_sft_classes.py with huggingface_hub
scripts/_sft_classes.py
ADDED (+97 -0)

import json

import torch
from PIL import Image
from torch.utils.data import Dataset


class PhysicsCoTDataset(Dataset):
    """Dataset for Qwen2.5-VL SFT on physics chain-of-thought (CoT) data."""

    def __init__(self, data_path, processor, max_length=4096):
        self.processor = processor
        self.max_length = max_length
        # One JSON record per line (JSONL).
        with open(data_path, 'r', encoding='utf-8') as f:
            self.records = [json.loads(line) for line in f]
        print(f"Loaded {len(self.records)} records from {data_path}")

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        record = self.records[idx]
        messages = record['messages']

        # The user turn carries the image reference and the question text.
        user_msg = messages[0]
        image_path = None
        text_content = ""
        for content in user_msg['content']:
            if content['type'] == 'image':
                image_path = content['image'].replace('file://', '')
            elif content['type'] == 'text':
                text_content = content['text']

        # The assistant turn carries the CoT target.
        assistant_msg = messages[1]
        assistant_text = assistant_msg['content'][0]['text']

        image = Image.open(image_path).convert('RGB')
        # The Qwen2.5-VL processor rejects images below a minimum size, so
        # upscale tiny images and pad any edge that rounding leaves short.
        MIN_DIM = 56
        w, h = image.size
        if w < MIN_DIM or h < MIN_DIM:
            scale = max(MIN_DIM / w, MIN_DIM / h)
            new_w = int(w * scale)
            new_h = int(h * scale)
            image = image.resize((new_w, new_h), Image.LANCZOS)
            if new_w < MIN_DIM or new_h < MIN_DIM:
                padded = Image.new('RGB', (max(new_w, MIN_DIM), max(new_h, MIN_DIM)), (255, 255, 255))
                padded.paste(image, (0, 0))
                image = padded
        # Rebuild the conversation around the in-memory image and let the
        # processor apply the chat template and extract vision features.
        conversation = [
            {"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": text_content}]},
            {"role": "assistant", "content": [{"type": "text", "text": assistant_text}]},
        ]
        text = self.processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=False)
        inputs = self.processor(text=[text], images=[image], padding=False, truncation=True, max_length=self.max_length, return_tensors="pt")
        input_ids = inputs['input_ids'].squeeze(0)
        attention_mask = inputs['attention_mask'].squeeze(0)
        labels = input_ids.clone()
        # Mask everything up to and including the assistant header so the
        # loss is computed only on the assistant's CoT answer.
        assistant_token_str = "<|im_start|>assistant\n"
        assistant_token_ids = self.processor.tokenizer.encode(assistant_token_str, add_special_tokens=False)
        input_ids_list = input_ids.tolist()
        assistant_start = -1
        for i in range(len(input_ids_list) - len(assistant_token_ids) + 1):
            if input_ids_list[i:i + len(assistant_token_ids)] == assistant_token_ids:
                assistant_start = i + len(assistant_token_ids)
                break
        if assistant_start > 0:
            labels[:assistant_start] = -100
        else:
            # Typically means truncation removed the assistant turn entirely.
            raise ValueError(f"Assistant start marker not found in sample {idx}.")
        labels[attention_mask == 0] = -100
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels,
            'pixel_values': inputs['pixel_values'].squeeze(0) if 'pixel_values' in inputs else None,
            'image_grid_thw': inputs['image_grid_thw'].squeeze(0) if 'image_grid_thw' in inputs else None,
        }

class VLMDataCollator:
    """Right-pads variable-length VLM features into a uniform batch."""

    def __init__(self, processor):
        self.processor = processor
        # Fall back to EOS when the tokenizer defines no pad token; an
        # explicit None check avoids misfiring on a pad id of 0.
        tok = processor.tokenizer
        self.pad_token_id = tok.pad_token_id if tok.pad_token_id is not None else tok.eos_token_id

    def __call__(self, features):
        max_len = max(f['input_ids'].size(0) for f in features)
        input_ids, attention_mask, labels, pixel_values, image_grid_thw = [], [], [], [], []
        for f in features:
            pad_len = max_len - f['input_ids'].size(0)
            input_ids.append(torch.cat([f['input_ids'], torch.full((pad_len,), self.pad_token_id, dtype=f['input_ids'].dtype)]))
            attention_mask.append(torch.cat([f['attention_mask'], torch.zeros(pad_len, dtype=f['attention_mask'].dtype)]))
            # Pad labels with -100 so padded positions never contribute to the loss.
            labels.append(torch.cat([f['labels'], torch.full((pad_len,), -100, dtype=f['labels'].dtype)]))
            if f.get('pixel_values') is not None:
                pixel_values.append(f['pixel_values'])
            if f.get('image_grid_thw') is not None:
                image_grid_thw.append(f['image_grid_thw'])
        batch = {
            'input_ids': torch.stack(input_ids),
            'attention_mask': torch.stack(attention_mask),
            'labels': torch.stack(labels),
        }
        # Image patches are concatenated along the first axis; grid shapes
        # are stacked so the model can split the patches back per image.
        if pixel_values:
            batch['pixel_values'] = torch.cat(pixel_values, dim=0)
        if image_grid_thw:
            batch['image_grid_thw'] = torch.stack(image_grid_thw)
        return batch
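
For reference, a minimal sketch of how the two classes wire together. This is not part of the commit; the checkpoint name, JSONL path, and package layout are assumptions for illustration.

    # Usage sketch -- assumed checkpoint, data path, and import path.
    from torch.utils.data import DataLoader
    from transformers import AutoProcessor

    from scripts._sft_classes import PhysicsCoTDataset, VLMDataCollator

    processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
    dataset = PhysicsCoTDataset("data/physics_cot_sft.jsonl", processor, max_length=4096)
    collator = VLMDataCollator(processor)

    loader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collator)
    batch = next(iter(loader))
    # batch['input_ids'] / batch['labels']: (2, longest sequence in batch),
    # right-padded; labels carry -100 on prompt and padding positions.
    print({k: tuple(v.shape) for k, v in batch.items()})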