|
from tqdm.auto import tqdm |
|
from constants import * |
|
from utils import * |
|
import pickle |
|
from torch.utils.data import Dataset |
|
import torch |
|
import torch.nn.functional as F |
|
from PIL import Image |
|
from torch.utils.data import DataLoader |
|
import os |
|
|
|
def format_point_text(points):
    """Serialize points into the binned-coordinate target string.

    Each point is a mapping whose ``x`` / ``y`` values are on a 0-100 scale
    (50 is used when a key is missing). A value is first scaled to a pixel
    index clamped to ``[0, IMAGE_SIZE - 1]`` and then mapped to a bin index
    clamped to ``[0, NUM_BINS - 1]``.

    Args:
        points: Iterable of mappings supporting ``.get('x')`` / ``.get('y')``.

    Returns:
        ``"<result_start>"`` followed by one
        ``<pointx_start>...<pointy_end>`` group per point, then
        ``"<result_end>"`` and the tokenizer's EOS token.
    """

    def _to_pixel(percent):
        # Scale the 0-100 value to pixel space and clamp into the image.
        scaled = int(percent * IMAGE_SIZE / 100)
        return min(max(scaled, 0), IMAGE_SIZE - 1)

    def _to_bin(pixel):
        # Integer bin width; the clamp absorbs any remainder pixels.
        return min(pixel // (IMAGE_SIZE // NUM_BINS), NUM_BINS - 1)

    pieces = ["<result_start>"]
    for entry in points:
        pixel_x = _to_pixel(entry.get('x', 50))
        pixel_y = _to_pixel(entry.get('y', 50))
        bin_x = _to_bin(pixel_x)
        bin_y = _to_bin(pixel_y)
        pieces.append(
            f"<pointx_start><coord_bin_{bin_x}><pointx_end><pointy_start><coord_bin_{bin_y}><pointy_end>"
        )
    pieces.append("<result_end>" + tokenizer.eos_token)
    return "".join(pieces)
|
|
|
def format_data_for_training(sample):
    """Convert one raw sample into the tensors and token ids used for training.

    At most ``MAX_POINTS`` points are used; any extra points are dropped. A
    ``points`` value that is not a list is treated as zero points (with a
    warning).

    Args:
        sample: Mapping with ``image_url`` and ``label`` keys and an optional
            ``points`` list of mappings with ``x`` / ``y`` on a 0-100 scale.

    Returns:
        ``None`` if the image is missing or cannot be processed, if either
        tokenization is empty, or if any other exception occurs. Otherwise a
        dict with:

        * ``image``: result of ``image_to_tensor`` on the RGB image.
        * ``prompt_ids`` / ``target_ids``: 1-D token id tensors.
        * ``continuous_coords``: ``(MAX_POINTS, 2)`` float32 tensor of
          coordinates clamped to ``[0, 1]``; unused rows are ``-1.0``.
        * ``coords_mask``: ``(MAX_POINTS,)`` float32 tensor, ``1`` for real
          points and ``0`` for padding rows.
        * ``num_points``, ``label``, ``image_url``.
    """
    try:
        sample_points = sample.get('points', [])
        if not isinstance(sample_points, list):
            print(f"Warning: Invalid 'points' type for {sample.get('image_url', 'N/A')}. Treating as 0 points.")
            sample_points = []

        points_to_process = sample_points[:MAX_POINTS]
        num_points = len(points_to_process)

        image_path = f"{IMAGE_LOCATION}{sample['image_url']}"
        if not os.path.exists(image_path):
            print(f"Warning: Image not found: {image_path}. Skipping.")
            return None

        try:
            # The context manager closes the underlying file handle
            # deterministically instead of relying on garbage collection.
            # The tensor is built inside the block because Pillow reads pixel
            # data lazily from the open file.
            with Image.open(image_path) as image:
                image_tensor = image_to_tensor(
                    image if image.mode == 'RGB' else image.convert('RGB')
                )
            # Release the decoded image as soon as the tensor exists.
            del image
        except Exception as e:
            print(f"Error processing image {image_path}: {e}")
            return None

        prompt_text = f"<point_start>{sample['label']}<point_end>"
        target_text = format_point_text(points_to_process)

        prompt_tokens = tokenizer(prompt_text, return_tensors="pt", max_length=PROMPT_LENGTH,
                                  truncation=True, padding=False)
        target_tokens = tokenizer(target_text, return_tensors="pt", max_length=TEXT_LENGTH,
                                  truncation=True, padding=False)

        if prompt_tokens.input_ids.numel() == 0 or target_tokens.input_ids.numel() == 0:
            print(f"Warning: Empty tokens after tokenization for {sample.get('image_url', 'N/A')}. Skipping.")
            return None

        # Same points as the text target, as fractions clamped to [0, 1].
        continuous_coords_list = [
            [
                min(max(point.get('x', 50) / 100.0, 0.0), 1.0),
                min(max(point.get('y', 50) / 100.0, 0.0), 1.0),
            ]
            for point in points_to_process
        ]

        # Start from an all-padding layout and overwrite the leading rows.
        # dtype is pinned so zero-point samples match populated ones even if
        # the global default dtype is changed elsewhere.
        padded_coords = torch.full((MAX_POINTS, 2), -1.0, dtype=torch.float32)
        coords_mask = torch.zeros(MAX_POINTS, dtype=torch.float32)
        if num_points:
            padded_coords[:num_points] = torch.tensor(continuous_coords_list, dtype=torch.float32)
            coords_mask[:num_points] = 1.0

        return {
            "image": image_tensor,
            "prompt_ids": prompt_tokens.input_ids[0],
            "target_ids": target_tokens.input_ids[0],
            "continuous_coords": padded_coords,
            "coords_mask": coords_mask,
            "num_points": num_points,
            "label": sample['label'],
            "image_url": sample['image_url'],
        }
    except FileNotFoundError:
        # Defensive: the existence check and the image handler above normally
        # intercept a missing file before this point.
        print(f"Warning: Image not found: {sample.get('image_url', 'N/A')}. Skipping.")
        return None
    except Exception as e:
        # Best-effort loader: report the failure and let the caller skip it.
        print(f"Error formatting sample ({sample.get('image_url', 'N/A')}): {e}. Skipping.")
        import traceback
        traceback.print_exc()
        return None
|
|
|
|
|
class PointDataset(Dataset):
    """Map-style dataset of pointing samples loaded from a pickle file.

    Samples are filtered to those whose ``points`` entry is a list with at
    most ``MAX_POINTS`` items, then split positionally: the last
    ``test_size`` samples form the test split and the rest the train split.
    Formatted samples are kept in a bounded in-memory cache keyed by index.
    """

    def __init__(self, data_path="active_point_dataset.pkl", split="train", test_size=1000,
                 cache_size=8000):
        """Load, filter and split the raw samples.

        Args:
            data_path: Path to a pickle file containing a sequence of sample
                mappings.
            split: ``"train"`` or ``"test"``.
            test_size: Number of trailing samples reserved for the test
                split. If the filtered dataset is not larger than this, 20%
                (at least one sample) is used instead, or zero when only one
                sample exists.
            cache_size: Maximum number of formatted samples kept in memory;
                a value ``<= 0`` disables caching.

        Raises:
            ValueError: If no samples survive filtering, or ``split`` is not
                ``"train"`` or ``"test"``.
        """
        # NOTE(security): pickle.load can execute arbitrary code while
        # deserializing; only point data_path at trusted files.
        with open(data_path, "rb") as f:
            raw_data = pickle.load(f)

        original_count = len(raw_data)
        raw_data = [sample for sample in raw_data if self._has_valid_points(sample)]
        filtered_count = len(raw_data)
        print(f"Original raw data size: {original_count}")
        print(f"Filtered raw data to {filtered_count} samples with 0 to {MAX_POINTS} points.")

        total_samples = len(raw_data)
        if total_samples == 0:
            raise ValueError("No samples left after filtering. Check data or MAX_POINTS.")

        if total_samples <= test_size:
            print(f"Warning: Dataset size {total_samples} <= test_size {test_size}.")
            test_size = max(1, int(total_samples * 0.2)) if total_samples > 1 else 0
        train_end = total_samples - test_size

        print(f"Dataset: {total_samples} total (0 to {MAX_POINTS} points), {train_end} train, {test_size} test")

        if split == "train":
            if train_end <= 0:
                print("Warning: No samples allocated for training split.")
            self.raw_data = raw_data[:train_end]
        elif split == "test":
            if test_size <= 0:
                print("Warning: No samples allocated for test split.")
            self.raw_data = raw_data[train_end:]
        else:
            raise ValueError("split must be 'train' or 'test'")

        print(f"Dataset initialized with {len(self.raw_data)} samples for {split}")

        self.cache_size = cache_size
        # Plain dict: insertion order gives first-in-first-out eviction.
        self.cache = {}

    @staticmethod
    def _has_valid_points(sample):
        """Return True if ``sample['points']`` is a list of at most MAX_POINTS items."""
        points = sample.get('points', [])
        # Type check comes first so non-sized values (e.g. None) are filtered
        # out instead of raising TypeError from len().
        return isinstance(points, list) and len(points) <= MAX_POINTS

    def __len__(self):
        """Return the number of raw samples in this split."""
        return len(self.raw_data)

    def __getitem__(self, idx):
        """Return the formatted sample for ``idx``.

        If the sample cannot be formatted, up to ``min(10, len(self))``
        following samples (wrapping around) are tried in its place; if all
        fail, a dummy zero-point sample is returned. Whatever is returned is
        cached under ``idx``.
        """
        if idx in self.cache:
            return self.cache[idx]

        formatted = format_data_for_training(self.raw_data[idx])

        if formatted is None:
            total = len(self.raw_data)
            max_attempts = min(10, total)
            next_idx = (idx + 1) % total
            attempts = 0
            while formatted is None and attempts < max_attempts:
                formatted = format_data_for_training(self.raw_data[next_idx])
                next_idx = (next_idx + 1) % total
                attempts += 1

            if formatted is None:
                print(f"Warning: Failed to find valid sample after {attempts} attempts")
                formatted = self._create_dummy_sample()

        if self.cache_size > 0:
            if len(self.cache) >= self.cache_size:
                # Evict the entry that was inserted first.
                oldest_key = next(iter(self.cache))
                del self.cache[oldest_key]
            self.cache[idx] = formatted

        return formatted

    def _create_dummy_sample(self):
        """Create a minimal zero-point sample used when no real sample can be formatted."""
        image_tensor = torch.zeros(3, IMAGE_SIZE, IMAGE_SIZE)

        prompt_text = "<point_start>dummy<point_end>"
        target_text = "<result_start><result_end>" + tokenizer.eos_token

        prompt_tokens = tokenizer(prompt_text, return_tensors="pt").input_ids[0]
        target_tokens = tokenizer(target_text, return_tensors="pt").input_ids[0]

        # All rows are padding: coordinates -1.0 and mask 0.
        padded_coords = torch.full((MAX_POINTS, 2), -1.0)
        coords_mask = torch.zeros(MAX_POINTS)

        return {
            "image": image_tensor,
            "prompt_ids": prompt_tokens,
            "target_ids": target_tokens,
            "continuous_coords": padded_coords,
            "coords_mask": coords_mask,
            "num_points": 0,
            "label": "dummy",
            "image_url": "none",
        }

    @staticmethod
    def collate_fn(batch):
        """Collate formatted samples into one padded batch.

        ``None`` items are dropped; if nothing remains, ``None`` is returned.
        Prompt and target id sequences are right-padded with the tokenizer's
        pad id to the longest sequence in the batch, with matching 0/1
        attention masks.

        ``generative_targets`` holds next-token targets: position ``i`` holds
        ``target_ids[i + 1]``, and the last real position holds the EOS id
        when the sequence ends with EOS. Every other position is ``-100``
        (presumably the loss's ignore index; confirm against the training
        loop).

        Returns:
            Dict of batched tensors (``image`` cast to ``DTYPE``) plus
            ``num_points``, ``label`` and ``image_url`` as Python lists.
        """
        batch = [item for item in batch if item is not None]
        if not batch:
            return None

        pad_id = tokenizer.pad_token_id
        eos_id = tokenizer.eos_token_id

        def pad_with_mask(ids, length):
            # Right-pad ids to length; the mask is 1 on real tokens, 0 on pad.
            pad_len = length - ids.size(0)
            padded = torch.cat([ids, torch.full((pad_len,), pad_id, dtype=torch.long)])
            mask = torch.cat([torch.ones_like(ids, dtype=torch.long),
                              torch.zeros(pad_len, dtype=torch.long)])
            return padded, mask

        images = torch.stack([item['image'] for item in batch]).to(DTYPE)

        max_prompt_len = max(item['prompt_ids'].size(0) for item in batch)
        prompt_ids_padded, prompt_masks = [], []
        for item in batch:
            padded_ids, mask = pad_with_mask(item['prompt_ids'], max_prompt_len)
            prompt_ids_padded.append(padded_ids)
            prompt_masks.append(mask)
        prompt_ids = torch.stack(prompt_ids_padded)
        prompt_attention_mask = torch.stack(prompt_masks)

        max_target_len = max(item['target_ids'].size(0) for item in batch)
        target_ids_padded, target_masks, generative_targets = [], [], []
        for item in batch:
            ids = item['target_ids']
            padded_ids, mask = pad_with_mask(ids, max_target_len)
            target_ids_padded.append(padded_ids)
            target_masks.append(mask)

            targets = torch.full_like(padded_ids, -100)
            seq_len = ids.size(0)
            # A single-token sequence has nothing to predict, so its targets
            # stay fully masked.
            if seq_len > 1:
                targets[:seq_len - 1] = ids[1:]
                if ids[-1] == eos_id:
                    targets[seq_len - 1] = eos_id
            generative_targets.append(targets)
        target_ids = torch.stack(target_ids_padded)
        target_attention_mask = torch.stack(target_masks)
        generative_targets = torch.stack(generative_targets)

        continuous_coords = torch.stack([item['continuous_coords'] for item in batch])
        coords_mask = torch.stack([item['coords_mask'] for item in batch])
        num_points = [item['num_points'] for item in batch]

        labels = [item['label'] for item in batch]
        image_urls = [item.get('image_url', '') for item in batch]

        return {
            'image': images,
            'prompt_ids': prompt_ids,
            'prompt_attention_mask': prompt_attention_mask,
            'target_ids': target_ids,
            'target_attention_mask': target_attention_mask,
            'generative_targets': generative_targets,
            'continuous_coords': continuous_coords,
            'coords_mask': coords_mask,
            'num_points': num_points,
            'label': labels,
            'image_url': image_urls,
        }
|
|
|
def create_train_dataloader(batch_size=BATCH_SIZE, num_workers=0, prefetch_factor=2):
    """Create the shuffled training dataloader.

    Args:
        batch_size: Number of samples per batch.
        num_workers: Number of worker processes for data loading.
        prefetch_factor: Batches to prefetch per worker. Forwarded only when
            ``num_workers > 0``; otherwise ``None`` is passed.

    Returns:
        A ``DataLoader`` over the train split using
        ``PointDataset.collate_fn``, or ``None`` (after printing a warning)
        if the split is empty.
    """
    dataset = PointDataset(split="train")
    if len(dataset) == 0:
        # Mirror create_test_dataloader so an empty split is never silent.
        print("Warning: Train dataset is empty. Returning None.")
        return None

    use_workers = num_workers > 0
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=PointDataset.collate_fn,
        pin_memory=True,
        num_workers=num_workers,
        # Worker-only options are set only when workers are in use.
        prefetch_factor=prefetch_factor if use_workers else None,
        persistent_workers=use_workers,
        drop_last=False,
    )
|
|
|
def create_test_dataloader(batch_size=BATCH_SIZE, num_workers=0, prefetch_factor=2):
    """Build the evaluation dataloader over the test split (no shuffling).

    Args:
        batch_size: Number of samples per batch.
        num_workers: Number of worker processes for data loading.
        prefetch_factor: Batches to prefetch per worker. Forwarded only when
            ``num_workers > 0``; otherwise ``None`` is passed.

    Returns:
        A ``DataLoader`` using ``PointDataset.collate_fn``, or ``None`` (after
        printing a warning) if the test split is empty.
    """
    test_set = PointDataset(split="test")
    if not len(test_set):
        print("Warning: Test dataset is empty. Returning None.")
        return None

    multiprocess = num_workers > 0
    loader_options = dict(
        batch_size=batch_size,
        shuffle=False,
        collate_fn=PointDataset.collate_fn,
        pin_memory=True,
        num_workers=num_workers,
        prefetch_factor=prefetch_factor if multiprocess else None,
        persistent_workers=multiprocess,
        drop_last=False,
    )
    return DataLoader(test_set, **loader_options)
|
|
|
|