mbiswas's picture
Upload 10 files
b781107 verified
from tqdm.auto import tqdm
from constants import *
from utils import *
import pickle
from torch.utils.data import Dataset
import torch
import torch.nn.functional as F
from PIL import Image
from torch.utils.data import DataLoader
import os
def format_point_text(points):
    """Serialize a list of points into the discrete-token target string.

    Each point is a mapping with percentage coordinates under ``'x'`` and
    ``'y'`` (a missing key falls back to 50). A coordinate is scaled to a
    pixel position, clamped to ``[0, IMAGE_SIZE - 1]``, and then mapped to a
    bin index clamped to ``[0, NUM_BINS - 1]``.

    Args:
        points: Iterable of point mappings; may be empty.

    Returns:
        ``"<result_start>"``, one x/y token group per point,
        ``"<result_end>"`` and the tokenizer's EOS token, concatenated.
    """

    def _to_pixel(percent):
        # Scale a percentage to pixel space and clamp it inside the image.
        return min(max(int(percent * IMAGE_SIZE / 100), 0), IMAGE_SIZE - 1)

    def _to_bin(pixel):
        # Integer bin width; the outer min() folds any overflow into the last bin.
        return min(pixel // (IMAGE_SIZE // NUM_BINS), NUM_BINS - 1)

    pieces = ["<result_start>"]
    for point in points:
        pixel_x = _to_pixel(point.get('x', 50))
        pixel_y = _to_pixel(point.get('y', 50))
        x_bin = _to_bin(pixel_x)
        y_bin = _to_bin(pixel_y)
        pieces.append(
            f"<pointx_start><coord_bin_{x_bin}><pointx_end><pointy_start><coord_bin_{y_bin}><pointy_end>"
        )
    pieces.append("<result_end>" + tokenizer.eos_token)
    return "".join(pieces)
def format_data_for_training(sample):
    """Format one raw sample for training, handling 0 to MAX_POINTS points.

    Args:
        sample: Mapping with ``'image_url'`` and ``'label'`` entries and an
            optional ``'points'`` list of mappings holding percentage
            coordinates under ``'x'`` and ``'y'``. A non-list ``'points'``
            value is treated as zero points; extra points beyond
            ``MAX_POINTS`` are dropped.

    Returns:
        A dict with keys ``image``, ``prompt_ids``, ``target_ids``,
        ``continuous_coords`` (``(MAX_POINTS, 2)`` float32, padded with -1),
        ``coords_mask`` (``(MAX_POINTS,)`` float32, 1 for real points),
        ``num_points``, ``label`` and ``image_url``; or ``None`` when the
        image is missing or unreadable, tokenization yields nothing, or any
        other error occurs (errors are printed, never raised).
    """
    try:
        # Check if 'points' key exists and is a list, otherwise treat as 0 points
        sample_points = sample.get('points', [])
        if not isinstance(sample_points, list):
            print(f"Warning: Invalid 'points' type for {sample.get('image_url', 'N/A')}. Treating as 0 points.")
            sample_points = []

        # Limit the number of points processed
        points_to_process = sample_points[:MAX_POINTS]
        num_points = len(points_to_process)

        image_path = f"{IMAGE_LOCATION}{sample['image_url']}"
        if not os.path.exists(image_path):
            print(f"Warning: Image not found: {image_path}. Skipping.")
            return None

        try:
            # The context manager closes the underlying file handle even when
            # conversion or tensor creation raises, instead of relying on
            # garbage collection to release it.
            with Image.open(image_path) as image:
                # Convert grayscale / palette / RGBA images to RGB if needed.
                image_tensor = image_to_tensor(
                    image if image.mode == 'RGB' else image.convert('RGB')
                )
        except Exception as e:
            print(f"Error processing image {image_path}: {e}")
            return None

        prompt_text = f"<point_start>{sample['label']}<point_end>"
        # format_point_text handles an empty list (emits only the wrapper tokens).
        target_text = format_point_text(points_to_process)

        # Tokenize with explicit max lengths; no padding here, batching pads later.
        prompt_tokens = tokenizer(prompt_text, return_tensors="pt", max_length=PROMPT_LENGTH,
                                  truncation=True, padding=False)
        target_tokens = tokenizer(target_text, return_tensors="pt", max_length=TEXT_LENGTH,
                                  truncation=True, padding=False)

        if prompt_tokens.input_ids.numel() == 0 or target_tokens.input_ids.numel() == 0:
            print(f"Warning: Empty tokens after tokenization for {sample.get('image_url', 'N/A')}. Skipping.")
            return None

        # Continuous coordinates normalised to [0, 1].
        continuous_coords_list = [
            [
                min(max(point.get('x', 50) / 100.0, 0.0), 1.0),
                min(max(point.get('y', 50) / 100.0, 0.0), 1.0),
            ]
            for point in points_to_process
        ]

        # Start from fully padded tensors (-1 coords, zero mask) and overwrite
        # the leading rows with real points. A single code path gives the
        # zero-point and the n-point cases the same explicit float32 dtype.
        padded_coords = torch.full((MAX_POINTS, 2), -1.0, dtype=torch.float32)
        coords_mask = torch.zeros(MAX_POINTS, dtype=torch.float32)
        if num_points:
            padded_coords[:num_points] = torch.tensor(continuous_coords_list, dtype=torch.float32)
            coords_mask[:num_points] = 1.0

        return {
            "image": image_tensor,
            "prompt_ids": prompt_tokens.input_ids[0],
            "target_ids": target_tokens.input_ids[0],
            "continuous_coords": padded_coords,
            "coords_mask": coords_mask,
            "num_points": num_points,
            "label": sample['label'],
            "image_url": sample['image_url']
        }
    except FileNotFoundError:
        print(f"Warning: Image not found: {sample.get('image_url', 'N/A')}. Skipping.")
        return None
    except Exception as e:
        # Best-effort loader: report the failure and let the caller skip the sample.
        print(f"Error formatting sample ({sample.get('image_url', 'N/A')}): {e}. Skipping.")
        import traceback
        traceback.print_exc()
        return None
class PointDataset(Dataset):
    """Lazily formatted point-annotation dataset with a train/test split.

    Raw samples are loaded from a pickle file and filtered to those with at
    most ``MAX_POINTS`` points. Images are only read and tokenized when an
    item is requested; processed items are kept in a bounded LRU cache.
    """

    def __init__(self, data_path="active_point_dataset.pkl", split="train", test_size=1000):
        """Load, filter and split the raw samples.

        Args:
            data_path: Path of the pickle file holding a list of raw samples.
            split: ``"train"`` (leading samples) or ``"test"`` (trailing samples).
            test_size: Number of trailing samples reserved for the test split.
                When the dataset is not larger than this, 20% (at least one
                sample, or zero for a single-sample dataset) is used instead.

        Raises:
            ValueError: If no sample survives filtering, or ``split`` is not
                ``"train"`` or ``"test"``.
        """
        # SECURITY NOTE: pickle.load can execute arbitrary code embedded in the
        # file. Only point data_path at files from a trusted source.
        with open(data_path, "rb") as f:
            raw_data = pickle.load(f)

        # Keep samples with 0 to MAX_POINTS points; non-list 'points' are dropped.
        original_count = len(raw_data)
        raw_data = [sample for sample in raw_data
                    if self._has_valid_points(sample, MAX_POINTS)]
        filtered_count = len(raw_data)
        print(f"Original raw data size: {original_count}")
        print(f"Filtered raw data to {filtered_count} samples with 0 to {MAX_POINTS} points.")

        total_samples = filtered_count
        if total_samples == 0:
            raise ValueError("No samples left after filtering. Check data or MAX_POINTS.")
        if total_samples <= test_size:
            print(f"Warning: Dataset size {total_samples} <= test_size {test_size}.")
            test_size = max(1, int(total_samples * 0.2)) if total_samples > 1 else 0
        train_end = total_samples - test_size
        print(f"Dataset: {total_samples} total (0 to {MAX_POINTS} points), {train_end} train, {test_size} test")

        if split == "train":
            if train_end <= 0:
                print("Warning: No samples allocated for training split.")
            self.raw_data = raw_data[:train_end]
        elif split == "test":
            if test_size <= 0:
                print("Warning: No samples allocated for test split.")
            self.raw_data = raw_data[train_end:]
        else:
            raise ValueError("split must be 'train' or 'test'")

        # Samples are NOT preprocessed here; images are loaded on demand in
        # __getitem__ so the whole dataset never sits in memory at once.
        print(f"Dataset initialized with {len(self.raw_data)} samples for {split}")

        self.cache_size = 8000  # Max cached processed samples; adjust to memory budget
        self.cache = {}  # idx -> processed sample; dict order tracks recency (oldest first)

    @staticmethod
    def _has_valid_points(sample, max_points):
        """Return True when the sample's 'points' is a list of at most max_points items.

        The type check runs before ``len`` so that values without a length
        (e.g. ``None`` or a number) are filtered out instead of raising
        ``TypeError``. A missing ``'points'`` key counts as zero points.
        """
        points = sample.get('points', [])
        return isinstance(points, list) and len(points) <= max_points

    def __len__(self):
        """Return the number of raw samples in this split."""
        return len(self.raw_data)

    def __getitem__(self, idx):
        """Return the processed sample for ``idx``.

        If the sample cannot be processed, up to ``min(10, len(self))``
        following samples (wrapping around) are tried; if all fail, a dummy
        zero-point sample is returned. The result is cached under ``idx``.
        """
        if idx in self.cache:
            # Re-insert on hit so the key moves to the "most recent" end;
            # without this the eviction below would be FIFO, not LRU.
            formatted = self.cache.pop(idx)
            self.cache[idx] = formatted
            return formatted

        formatted = format_data_for_training(self.raw_data[idx])

        if formatted is None:
            total = len(self.raw_data)
            next_idx = (idx + 1) % total
            attempts = 0
            max_attempts = min(10, total)  # bound the search if many samples are invalid
            while formatted is None and attempts < max_attempts:
                formatted = format_data_for_training(self.raw_data[next_idx])
                next_idx = (next_idx + 1) % total
                attempts += 1
            if formatted is None:
                print(f"Warning: Failed to find valid sample after {attempts} attempts")
                formatted = self._create_dummy_sample()

        # Evict the least recently used entry (first key) once at capacity.
        if len(self.cache) >= self.cache_size and self.cache:
            del self.cache[next(iter(self.cache))]
        self.cache[idx] = formatted
        return formatted

    def _create_dummy_sample(self):
        """Create a minimal valid zero-point sample for when all else fails."""
        image_tensor = torch.zeros(3, IMAGE_SIZE, IMAGE_SIZE)

        prompt_text = "<point_start>dummy<point_end>"
        target_text = "<result_start><result_end>" + tokenizer.eos_token
        prompt_tokens = tokenizer(prompt_text, return_tensors="pt").input_ids[0]
        target_tokens = tokenizer(target_text, return_tensors="pt").input_ids[0]

        # No real points: every coordinate is padding and the mask is all zero.
        padded_coords = torch.full((MAX_POINTS, 2), -1.0)
        coords_mask = torch.zeros(MAX_POINTS)
        return {
            "image": image_tensor,
            "prompt_ids": prompt_tokens,
            "target_ids": target_tokens,
            "continuous_coords": padded_coords,
            "coords_mask": coords_mask,
            "num_points": 0,
            "label": "dummy",
            "image_url": "none"
        }

    @staticmethod
    def _pad_1d(ids, length, pad_value):
        """Right-pad a 1-D id tensor to ``length``; return (padded_ids, attention_mask)."""
        pad_len = length - ids.size(0)
        padded = torch.cat([ids, torch.full((pad_len,), pad_value, dtype=torch.long)])
        mask = torch.cat([torch.ones_like(ids, dtype=torch.long),
                          torch.zeros(pad_len, dtype=torch.long)])
        return padded, mask

    @staticmethod
    def _build_generative_targets(ids, padded_ids, eos_token_id):
        """Build next-token targets for one sequence, with -100 on unused positions.

        Position ``i`` receives ``ids[i + 1]``. When the sequence has more
        than one token and ends with EOS, the final real position is also
        given EOS as its target. All remaining positions stay at -100.
        """
        targets = torch.full_like(padded_ids, -100)
        length = ids.size(0)
        if length > 1:
            targets[:length - 1] = ids[1:]
            if ids[-1] == eos_token_id:
                targets[length - 1] = eos_token_id
        return targets

    @staticmethod
    def collate_fn(batch):
        """Collate processed samples into a batch of padded tensors.

        ``None`` items are dropped; an empty batch yields ``None``. Prompt and
        target ids are right-padded with the tokenizer's pad id, and matching
        attention masks and generative targets are produced.

        Returns:
            A dict with tensor entries ``image``, ``prompt_ids``,
            ``prompt_attention_mask``, ``target_ids``,
            ``target_attention_mask``, ``generative_targets``,
            ``continuous_coords``, ``coords_mask`` and list entries
            ``num_points``, ``label``, ``image_url``; or ``None``.
        """
        batch = [item for item in batch if item is not None]
        if not batch:
            return None

        images = torch.stack([item['image'] for item in batch]).to(DTYPE)
        pad_id = tokenizer.pad_token_id

        # --- Pad Prompt IDs ---
        max_prompt_len = max(item['prompt_ids'].size(0) for item in batch)
        prompt_ids_padded, prompt_masks = [], []
        for item in batch:
            padded, mask = PointDataset._pad_1d(item['prompt_ids'], max_prompt_len, pad_id)
            prompt_ids_padded.append(padded)
            prompt_masks.append(mask)
        prompt_ids = torch.stack(prompt_ids_padded)
        prompt_attention_mask = torch.stack(prompt_masks)

        # --- Pad Target IDs & Create Generative Targets ---
        max_target_len = max(item['target_ids'].size(0) for item in batch)
        target_ids_padded, target_masks, generative_targets = [], [], []
        for item in batch:
            ids = item['target_ids']
            padded, mask = PointDataset._pad_1d(ids, max_target_len, pad_id)
            target_ids_padded.append(padded)
            target_masks.append(mask)
            generative_targets.append(
                PointDataset._build_generative_targets(ids, padded, tokenizer.eos_token_id)
            )
        target_ids = torch.stack(target_ids_padded)
        target_attention_mask = torch.stack(target_masks)
        generative_targets = torch.stack(generative_targets)

        # --- Stack Continuous Coords and Masks (already padded per sample) ---
        continuous_coords = torch.stack([item['continuous_coords'] for item in batch])
        coords_mask = torch.stack([item['coords_mask'] for item in batch])

        return {
            'image': images,
            'prompt_ids': prompt_ids,
            'prompt_attention_mask': prompt_attention_mask,
            'target_ids': target_ids,
            'target_attention_mask': target_attention_mask,
            'generative_targets': generative_targets,
            'continuous_coords': continuous_coords,
            'coords_mask': coords_mask,
            'num_points': [item['num_points'] for item in batch],
            'label': [item['label'] for item in batch],
            'image_url': [item.get('image_url', '') for item in batch]
        }
def create_train_dataloader(batch_size=BATCH_SIZE, num_workers=0, prefetch_factor=2):
    """Create the shuffled training dataloader.

    Args:
        batch_size: Number of samples per batch
        num_workers: Number of worker processes for data loading
        prefetch_factor: Number of batches to prefetch per worker; only
            used when ``num_workers > 0``
    Returns:
        DataLoader instance or None if dataset is empty
    """
    dataset = PointDataset(split="train")
    if len(dataset) == 0:
        return None

    loader_kwargs = {
        "batch_size": batch_size,
        "shuffle": True,
        "collate_fn": PointDataset.collate_fn,
        "pin_memory": True,  # Speeds up CPU to GPU transfer
        "num_workers": num_workers,
        "drop_last": False,  # Don't drop the last incomplete batch
    }
    # Worker-only options are passed only when workers are actually used.
    # Explicitly passing prefetch_factor=None in single-process mode is
    # rejected by PyTorch releases whose default is 2, so omit it instead.
    if num_workers > 0:
        loader_kwargs["prefetch_factor"] = prefetch_factor
        loader_kwargs["persistent_workers"] = True  # Keep workers alive between epochs

    return DataLoader(dataset, **loader_kwargs)
def create_test_dataloader(batch_size=BATCH_SIZE, num_workers=0, prefetch_factor=2):
    """Create the test dataloader (no shuffling).

    Args:
        batch_size: Number of samples per batch
        num_workers: Number of worker processes for data loading
        prefetch_factor: Number of batches to prefetch per worker; only
            used when ``num_workers > 0``
    Returns:
        DataLoader instance or None if dataset is empty
    """
    dataset = PointDataset(split="test")
    if len(dataset) == 0:
        print("Warning: Test dataset is empty. Returning None.")
        return None

    loader_kwargs = {
        "batch_size": batch_size,
        "shuffle": False,  # Keep evaluation order deterministic
        "collate_fn": PointDataset.collate_fn,
        "pin_memory": True,
        "num_workers": num_workers,
        "drop_last": False,
    }
    # Worker-only options are passed only when workers are actually used.
    # Explicitly passing prefetch_factor=None in single-process mode is
    # rejected by PyTorch releases whose default is 2, so omit it instead.
    if num_workers > 0:
        loader_kwargs["prefetch_factor"] = prefetch_factor
        loader_kwargs["persistent_workers"] = True

    return DataLoader(dataset, **loader_kwargs)