| |
| |
|
|
| import ast |
| from io import BytesIO |
| from urllib.parse import urljoin |
|
|
| import pandas as pd |
| import requests |
| import torch |
| from PIL import Image |
| from torch.utils.data import Dataset, DataLoader |
| from torchvision import transforms |
|
|
| from utils import load_json |
|
|
|
|
class CenterSquareCrop:
    """
    Callable transform that trims an image to its largest centered square.

    The shorter side is kept whole; the longer side loses its surplus split
    across both ends (when the surplus is odd, the extra pixel comes off the
    right/bottom end). Pixels are never resampled.
    """

    def __call__(self, img: Image.Image):
        width, height = img.size

        # Already square: hand back the very same object, no copy.
        if width == height:
            return img

        side = min(width, height)
        x0 = (width - side) // 2   # 0 for portrait images
        y0 = (height - side) // 2  # 0 for landscape images
        return img.crop((x0, y0, x0 + side, y0 + side))
|
|
|
|
def build_image_transform(image_size: int):
    """
    Build the image preprocessing pipeline.

    Steps, in order: center-crop to a square, resize to
    ``image_size x image_size``, convert to a tensor with torchvision's
    ``ToTensor``.
    """
    target_hw = (image_size, image_size)
    steps = [
        CenterSquareCrop(),
        transforms.Resize(target_hw),
        transforms.ToTensor(),
    ]
    return transforms.Compose(steps)
|
|
|
|
def join_photo_root(photo_root: str, relative_path: str) -> str:
    """
    Join photo_root and relative path.

    Supports:
    - local filesystem roots
    - http / https roots

    In both cases ``relative_path`` is treated as relative to ``photo_root``
    even when it starts with "/": leading slashes are stripped before joining.
    Without that, ``urljoin`` would resolve a leading-slash path against the
    host root and silently drop the path component of ``photo_root``.
    """
    base = photo_root.rstrip("/") + "/"
    tail = relative_path.lstrip("/")

    if photo_root.startswith(("http://", "https://")):
        # urljoin still normalises "./" and "../" segments inside `tail`.
        return urljoin(base, tail)

    return base + tail
|
|
|
|
def parse_numeric_cell(value: str, n_in: int):
    """
    Turn one numeric csv cell into ``(list[float], is_valid)``.

    Expected cell formats:
    - ""               -> missing; returns ``n_in`` zeros and ``False``
    - "12.3"           -> scalar, used when ``n_in == 1``
    - "[1.2,3.4,5.6]"  -> vector literal, used when ``n_in > 1``

    Raises:
        ValueError: a vector cell does not contain exactly ``n_in`` entries
            (also raised by ``float`` for an unparsable scalar).
    """
    # Missing cell: zero placeholder, flagged invalid.
    if value == "":
        return [0.0] * n_in, False

    if n_in != 1:
        # literal_eval only evaluates Python literals, never arbitrary code.
        items = ast.literal_eval(value)
        count = len(items)
        if count != n_in:
            raise ValueError(f"Numeric vector length mismatch: expected {n_in}, got {count}")
        return list(map(float, items)), True

    return [float(value)], True
|
|
|
|
class SoilFormerDataset(Dataset):
    """
    Row-wise dataset combining categorical, numeric and photo inputs.

    Every csv row becomes one sample holding:
    - categorical local ids (one per column of the categorical vocab),
    - numeric values grouped by vector width ``n_in`` and z-scored with the
      statistics file,
    - one image tensor,
    each with a validity flag so that missing inputs are distinguishable from
    real values. ``perform_active_mask`` / ``perform_active_mask_single``
    derive masked training inputs from a collated batch.
    """

    # Batch keys every masking routine needs.
    _MASK_REQUIRED_KEYS = (
        "cat_local_ids",
        "cat_valid_positions",
        "numeric_values_by_nin",
        "numeric_valid_positions_by_nin",
    )

    def __init__(
        self,
        csv_path: str,
        photo_map_path: str,
        cat_vocab_path: str,
        numeric_vocab_path: str,
        numeric_stats_path: str,
        photo_root: str,
        image_size: int = 512,
        id_column: str = "id",
    ):
        """
        Load the table, vocabularies and normalisation statistics.

        Args:
            csv_path: tabular data; empty cells must be "" (NA parsing is off).
            photo_map_path: JSON mapping sample id -> relative photo path.
            cat_vocab_path: JSON mapping categorical column -> spec with
                "label2id" and "mask_local_id" ("col_id" is additionally read
                by perform_active_mask_single).
            numeric_vocab_path: JSON with "groups" (each having "n_in" and
                "feature_names"); an optional "features" mapping
                (name -> {"n_in", "col_id"}) is read by
                perform_active_mask_single.
            numeric_stats_path: csv with columns "column", "mean", "std".
            photo_root: local directory or http(s) prefix of the photos.
            image_size: side length of the square image tensor.
            id_column: csv column holding the sample id.
        """
        # NA detection is disabled so that missing cells stay "" and can be
        # recognised by the `== ""` checks used throughout this module.
        self.df = pd.read_csv(
            csv_path,
            keep_default_na=False,
            na_filter=False,
            low_memory=False,
        )

        self.photo_map = load_json(photo_map_path)
        self.cat_vocab = load_json(cat_vocab_path)
        self.numeric_vocab = load_json(numeric_vocab_path)

        self.photo_root = photo_root
        self.id_column = id_column
        self.image_size = int(image_size)
        self.image_transform = build_image_transform(self.image_size)

        self.cat_columns = list(self.cat_vocab.keys())
        self.numeric_groups = self.numeric_vocab["groups"]
        self.numeric_stats_df = pd.read_csv(numeric_stats_path)
        self.numeric_stats_index = self.numeric_stats_df.set_index("column")

        # column -> (mean, std); a zero std is replaced by 1.0 so constant
        # columns normalise to 0 instead of dividing by zero.
        self.numeric_stats = {}
        stats = self.numeric_stats_df
        for col, mean, std in zip(stats["column"], stats["mean"], stats["std"]):
            mean = float(mean)
            std = float(std)
            if std == 0.0:
                std = 1.0
            self.numeric_stats[col] = (mean, std)

        # Per-column mask id, in the same order as self.cat_columns.
        self.cat_mask_local_ids = torch.tensor(
            [int(self.cat_vocab[col]["mask_local_id"]) for col in self.cat_columns],
            dtype=torch.long,
        )

    def __len__(self):
        """Number of csv rows."""
        return len(self.df)

    def load_image(self, path: str):
        """
        Load one image from a local path or http(s) URL and preprocess it.

        Raises whatever requests / PIL raise on failure; the caller decides
        whether that is fatal.
        """
        if path.startswith(("http://", "https://")):
            resp = requests.get(path, timeout=(3, 10))
            resp.raise_for_status()
            source = BytesIO(resp.content)
        else:
            source = path

        # The context manager releases the underlying file deterministically;
        # convert() returns an independent, fully loaded image.
        with Image.open(source) as raw:
            img = raw.convert("RGB")

        return self.image_transform(img)

    def __getitem__(self, idx):
        """
        Build one sample dict for csv row ``idx``.

        Raises:
            KeyError: a non-empty categorical value is absent from its vocab.
        """
        row = self.df.iloc[idx]
        sample_id = row[self.id_column]

        # ---- categorical: missing cells get the column's mask id ----
        cat_ids = []
        cat_valids = []
        for col in self.cat_columns:
            spec = self.cat_vocab[col]
            label2id = spec["label2id"]
            mask_id = spec["mask_local_id"]
            value = row[col]

            if value == "":
                cat_ids.append(mask_id)
                cat_valids.append(False)
                continue

            if value not in label2id:
                raise KeyError(f"Unknown categorical value: column={col}, value={value!r}")
            cat_ids.append(label2id[value])
            cat_valids.append(True)

        cat_ids = torch.tensor(cat_ids, dtype=torch.long)
        cat_valids = torch.tensor(cat_valids, dtype=torch.bool)

        # ---- numeric: one [V, n_in] tensor per group, keyed by n_in ----
        numeric_values_by_nin = {}
        numeric_valid_positions_by_nin = {}
        for group in self.numeric_groups:
            n_in = int(group["n_in"])
            features = group["feature_names"]

            values = []
            valids = []
            for feat in features:
                parsed, is_valid = parse_numeric_cell(row[feat], n_in)
                if is_valid:
                    mean, std = self.numeric_stats[feat]
                    parsed = [(v - mean) / std for v in parsed]
                values.append(parsed)
                valids.append(is_valid)

            # reshape is a no-op for non-empty groups; it keeps an empty
            # group at shape (0, n_in) instead of collapsing to (0,).
            numeric_values_by_nin[n_in] = torch.tensor(
                values, dtype=torch.float32
            ).reshape(len(features), n_in)
            numeric_valid_positions_by_nin[n_in] = torch.tensor(valids, dtype=torch.bool)

        # ---- vision: best effort; any failure yields a zero image that is
        # flagged invalid rather than aborting the whole sample ----
        # NOTE(review): sample_id keeps the dtype pandas inferred for the id
        # column; confirm it matches the key type of photo_map.
        try:
            relative_path = self.photo_map[sample_id]
            full_path = join_photo_root(self.photo_root, relative_path)
            image = self.load_image(full_path)
            vision_valid = True
        except Exception:
            image = torch.zeros(3, self.image_size, self.image_size, dtype=torch.float32)
            vision_valid = False

        vision_valid = torch.tensor(vision_valid, dtype=torch.bool)

        return {
            "row_idx": torch.tensor(idx, dtype=torch.long),
            "sample_id": sample_id,
            "cat_local_ids": cat_ids,
            "cat_valid_positions": cat_valids,
            "numeric_values_by_nin": numeric_values_by_nin,
            "numeric_valid_positions_by_nin": numeric_valid_positions_by_nin,
            "pixel_values": image,
            "vision_valid_positions": vision_valid,
        }

    @staticmethod
    def collate_fn(batch):
        """
        Stack a list of sample dicts along a new leading batch dimension.

        Tensors are stacked; per-group numeric dicts are stacked group by
        group (groups are taken from the first sample); sample ids are kept
        as a plain list.
        """
        def stack(key):
            return torch.stack([b[key] for b in batch], dim=0)

        def stack_group(key, group):
            return torch.stack([b[key][group] for b in batch], dim=0)

        group_keys = list(batch[0]["numeric_values_by_nin"].keys())

        return {
            "row_idx": stack("row_idx"),
            "sample_id": [b["sample_id"] for b in batch],
            "cat_local_ids": stack("cat_local_ids"),
            "numeric_values_by_nin": {
                k: stack_group("numeric_values_by_nin", k) for k in group_keys
            },
            "cat_valid_positions": stack("cat_valid_positions"),
            "numeric_valid_positions_by_nin": {
                k: stack_group("numeric_valid_positions_by_nin", k) for k in group_keys
            },
            "pixel_values": stack("pixel_values"),
            "vision_valid_positions": stack("vision_valid_positions"),
        }

    def _unpack_mask_inputs(self, batch):
        """
        Validate the inputs shared by both masking routines.

        Returns:
            (cat_local_ids, cat_valid_positions, numeric_values_by_nin,
             numeric_valid_positions_by_nin, cat_mask_local_ids) where the
            last item is self.cat_mask_local_ids moved to the device/dtype of
            cat_local_ids.

        Raises:
            KeyError: a required key is missing from the batch.
            ValueError: a tensor has an unexpected shape or type.
        """
        for key in self._MASK_REQUIRED_KEYS:
            if key not in batch:
                raise KeyError(f"Missing key in batch: {key}")

        cat_local_ids = batch["cat_local_ids"]
        cat_valid_positions = batch["cat_valid_positions"]
        numeric_values_by_nin = batch["numeric_values_by_nin"]
        numeric_valid_positions_by_nin = batch["numeric_valid_positions_by_nin"]

        if cat_local_ids.dim() != 2:
            raise ValueError(f"cat_local_ids must be [B, M], got {tuple(cat_local_ids.shape)}")
        if cat_valid_positions.shape != cat_local_ids.shape:
            raise ValueError(
                f"cat_valid_positions must match cat_local_ids shape, got "
                f"{tuple(cat_valid_positions.shape)} vs {tuple(cat_local_ids.shape)}"
            )

        if not isinstance(numeric_values_by_nin, dict):
            raise ValueError("numeric_values_by_nin must be a dict")
        if not isinstance(numeric_valid_positions_by_nin, dict):
            raise ValueError("numeric_valid_positions_by_nin must be a dict")

        M = cat_local_ids.shape[1]
        if self.cat_mask_local_ids.dim() != 1 or self.cat_mask_local_ids.numel() != M:
            raise ValueError(
                f"self.cat_mask_local_ids must be [M] with M={M}, got {tuple(self.cat_mask_local_ids.shape)}"
            )
        cat_mask_local_ids = self.cat_mask_local_ids.to(
            device=cat_local_ids.device, dtype=cat_local_ids.dtype
        )

        return (
            cat_local_ids,
            cat_valid_positions,
            numeric_values_by_nin,
            numeric_valid_positions_by_nin,
            cat_mask_local_ids,
        )

    @staticmethod
    def _checked_numeric_group(n_in, numeric_values_by_nin, numeric_valid_positions_by_nin, batch_size):
        """
        Return ``(values, valid_positions)`` of one numeric group after
        checking that values is [B, V, n_in] and valid_positions is [B, V].
        """
        values = numeric_values_by_nin[n_in]
        if n_in not in numeric_valid_positions_by_nin:
            raise KeyError(f"Missing numeric_valid_positions_by_nin[{n_in}]")
        valid_positions = numeric_valid_positions_by_nin[n_in]

        if values.dim() != 3:
            raise ValueError(
                f"numeric_values_by_nin[{n_in}] must be [B, V, n_in], got {tuple(values.shape)}"
            )

        group_batch, n_features, width = values.shape
        if group_batch != batch_size:
            raise ValueError(
                f"numeric_values_by_nin[{n_in}] batch mismatch: got {group_batch}, expected {batch_size}"
            )
        if int(width) != int(n_in):
            raise ValueError(
                f"numeric_values_by_nin[{n_in}] last dim mismatch: got {width}, expected {n_in}"
            )
        if valid_positions.shape != (batch_size, n_features):
            raise ValueError(
                f"numeric_valid_positions_by_nin[{n_in}] must be [B,V]=({batch_size},{n_features}), "
                f"got {tuple(valid_positions.shape)}"
            )
        return values, valid_positions

    @staticmethod
    def _sample_loss_mask(valid, ratio, generator, device):
        """
        Per row, randomly select round(ratio * n_valid) of the valid positions.

        ``valid`` is a bool tensor [B, N]; the result is a bool tensor of the
        same shape on ``device``. Rows are processed in order and the
        generator is only consumed for rows that actually select something,
        so results are reproducible for a given seed.
        """
        loss_mask = torch.zeros(tuple(valid.shape), dtype=torch.bool, device=device)
        if ratio <= 0.0:
            return loss_mask

        for b in range(valid.shape[0]):
            valid_idx = torch.nonzero(valid[b], as_tuple=False).squeeze(1)
            n_valid = valid_idx.numel()
            if n_valid == 0:
                continue

            k = min(int(round(n_valid * ratio)), n_valid)
            if k <= 0:
                continue

            order = torch.randperm(n_valid, generator=generator, device=device)
            loss_mask[b, valid_idx[order[:k]]] = True

        return loss_mask

    def perform_active_mask(self, batch, cat_ratio=0.15, num_ratio=0.15, seed=None):
        """
        Apply active masking to categorical and numeric inputs.

        Conventions
        -----------
        Input batch must contain:
            - cat_local_ids:                    [B, M] LongTensor
            - cat_valid_positions:              [B, M] Bool/0-1 tensor
            - numeric_values_by_nin:            Dict[int, Tensor[B, V, n_in]]
            - numeric_valid_positions_by_nin:   Dict[int, Tensor[B, V]]

        Output batch will additionally contain:
            - original_cat_local_ids
            - original_cat_valid_positions
            - original_numeric_values_by_nin
            - original_numeric_valid_positions_by_nin

            - masked_cat_local_ids
            - masked_cat_valid_positions
            - masked_numeric_values_by_nin
            - masked_numeric_valid_positions_by_nin

            - cat_loss_mask:                    [B, M] BoolTensor
            - numeric_loss_mask_by_nin:         Dict[int, BoolTensor[B, V]]

        Semantics
        ---------
        - Only originally valid positions can be actively masked; per sample,
          round(ratio * n_valid) positions are chosen at random.
        - Masked categorical positions:
              local_id -> self.cat_mask_local_ids[col]
              valid    -> False
        - Masked numeric positions:
              values   -> 0
              valid    -> False
        - original_* fields reference the unmodified input tensors.
        - masked_* fields are always fresh bool/cloned tensors, also when both
          ratios are 0, so editing them in place never touches original_*.

        Args:
            batch: collated batch (see collate_fn); not modified.
            cat_ratio: fraction of valid categorical positions to mask, in [0, 1].
            num_ratio: fraction of valid numeric positions to mask, in [0, 1].
            seed: optional seed for a generator local to this call.

        Raises:
            ValueError: ratio outside [0, 1] or malformed tensors.
            KeyError: a required batch key or numeric group is missing.
        """
        if not (0.0 <= cat_ratio <= 1.0):
            raise ValueError(f"cat_ratio must be in [0, 1], got {cat_ratio}")
        if not (0.0 <= num_ratio <= 1.0):
            raise ValueError(f"num_ratio must be in [0, 1], got {num_ratio}")

        (
            cat_local_ids,
            cat_valid_positions,
            numeric_values_by_nin,
            numeric_valid_positions_by_nin,
            cat_mask_local_ids,
        ) = self._unpack_mask_inputs(batch)

        B, M = cat_local_ids.shape
        device = cat_local_ids.device

        # A call-local generator keeps masking independent of global RNG state.
        if device.type == "cuda":
            generator = torch.Generator(device=device)
        else:
            generator = torch.Generator()
        if seed is not None:
            generator.manual_seed(seed)

        # Shallow copy: input keys are carried over, input dict is untouched.
        masked_batch = dict(batch)
        masked_batch["original_cat_local_ids"] = batch["cat_local_ids"]
        masked_batch["original_cat_valid_positions"] = batch["cat_valid_positions"]
        masked_batch["original_numeric_values_by_nin"] = batch["numeric_values_by_nin"]
        masked_batch["original_numeric_valid_positions_by_nin"] = batch["numeric_valid_positions_by_nin"]

        # ---- categorical ----
        original_cat_valid = cat_valid_positions.bool()
        cat_loss_mask = self._sample_loss_mask(original_cat_valid, cat_ratio, generator, device)

        masked_cat_local_ids = cat_local_ids.clone()
        expanded_cat_mask_ids = cat_mask_local_ids.view(1, M).expand(B, M)
        masked_cat_local_ids[cat_loss_mask] = expanded_cat_mask_ids[cat_loss_mask]

        masked_batch["masked_cat_local_ids"] = masked_cat_local_ids
        masked_batch["masked_cat_valid_positions"] = original_cat_valid & (~cat_loss_mask)
        masked_batch["cat_loss_mask"] = cat_loss_mask

        # ---- numeric: groups visited in ascending n_in for a stable RNG order ----
        masked_numeric_values_by_nin = {}
        masked_numeric_valid_positions_by_nin = {}
        numeric_loss_mask_by_nin = {}

        for n_in in sorted(numeric_values_by_nin.keys(), key=int):
            values, valid_positions = self._checked_numeric_group(
                n_in, numeric_values_by_nin, numeric_valid_positions_by_nin, B
            )
            original_valid = valid_positions.bool()
            num_loss_mask = self._sample_loss_mask(
                original_valid, num_ratio, generator, values.device
            )

            masked_values = values.clone()
            # A [B, V] mask zeroes the whole n_in vector of each selected feature.
            masked_values[num_loss_mask] = 0.0

            masked_numeric_values_by_nin[n_in] = masked_values
            masked_numeric_valid_positions_by_nin[n_in] = original_valid & (~num_loss_mask)
            numeric_loss_mask_by_nin[n_in] = num_loss_mask

        masked_batch["masked_numeric_values_by_nin"] = masked_numeric_values_by_nin
        masked_batch["masked_numeric_valid_positions_by_nin"] = masked_numeric_valid_positions_by_nin
        masked_batch["numeric_loss_mask_by_nin"] = numeric_loss_mask_by_nin

        return masked_batch

    def perform_active_mask_single(self, batch, feature_name, assert_not_missing=True):
        """
        Actively mask exactly one feature specified by feature_name.

        Parameters
        ----------
        batch : dict
            Same input convention as perform_active_mask(...).
        feature_name : str
            Full feature name. Can be either categorical or numeric. A
            categorical feature is located through cat_vocab[name]["col_id"],
            a numeric one through numeric_vocab["features"][name]
            ("n_in" selects the group, "col_id" the position inside it).
        assert_not_missing : bool
            If True, require the target feature to be originally valid for all samples
            in the batch. Otherwise raise ValueError.
            If False, only originally valid positions are masked; naturally missing
            positions remain missing and are not included in the loss mask.

        Returns
        -------
        masked_batch : dict
            Same output convention as perform_active_mask(...), except that exactly
            one feature is actively masked.

        Raises
        ------
        KeyError
            Unknown feature name, missing batch key or missing numeric group.
        ValueError
            Malformed tensors, a name present in both vocabularies, or missing
            samples while assert_not_missing is True.
        IndexError
            The resolved column index lies outside the tensor.
        """
        (
            cat_local_ids,
            cat_valid_positions,
            numeric_values_by_nin,
            numeric_valid_positions_by_nin,
            cat_mask_local_ids,
        ) = self._unpack_mask_inputs(batch)

        B, M = cat_local_ids.shape
        device = cat_local_ids.device

        # ---- resolve the feature name to a tensor position ----
        is_cat = False
        is_num = False
        cat_col = None
        num_n_in = None
        num_v_idx = None

        if hasattr(self, "cat_vocab") and feature_name in self.cat_vocab:
            is_cat = True
            cat_col = int(self.cat_vocab[feature_name]["col_id"])

        if hasattr(self, "numeric_vocab"):
            num_features = self.numeric_vocab.get("features", {})
            if feature_name in num_features:
                is_num = True
                meta = num_features[feature_name]
                num_n_in = int(meta["n_in"])
                num_v_idx = int(meta["col_id"])

        if is_cat and is_num:
            raise ValueError(f"Feature name appears in both categorical and numeric vocab: {feature_name}")
        if not is_cat and not is_num:
            raise KeyError(f"Unknown feature_name: {feature_name}")

        masked_batch = dict(batch)
        masked_batch["original_cat_local_ids"] = batch["cat_local_ids"]
        masked_batch["original_cat_valid_positions"] = batch["cat_valid_positions"]
        masked_batch["original_numeric_values_by_nin"] = batch["numeric_values_by_nin"]
        masked_batch["original_numeric_valid_positions_by_nin"] = batch["numeric_valid_positions_by_nin"]

        # ---- start from unmasked copies of everything ----
        masked_cat_local_ids = cat_local_ids.clone()
        masked_cat_valid_positions = cat_valid_positions.bool().clone()
        cat_loss_mask = torch.zeros((B, M), dtype=torch.bool, device=device)

        masked_numeric_values_by_nin = {}
        masked_numeric_valid_positions_by_nin = {}
        numeric_loss_mask_by_nin = {}

        for n_in in sorted(numeric_values_by_nin.keys(), key=int):
            values, valid_positions = self._checked_numeric_group(
                n_in, numeric_values_by_nin, numeric_valid_positions_by_nin, B
            )
            masked_numeric_values_by_nin[n_in] = values.clone()
            masked_numeric_valid_positions_by_nin[n_in] = valid_positions.bool().clone()
            numeric_loss_mask_by_nin[n_in] = torch.zeros(
                (B, values.shape[1]), dtype=torch.bool, device=values.device
            )

        # ---- mask the single target column ----
        if is_cat:
            # Reject negative indices too: they would silently wrap around.
            if not 0 <= cat_col < M:
                raise IndexError(
                    f"Categorical feature '{feature_name}' resolved to col_id={cat_col}, "
                    f"but cat_local_ids has M={M}"
                )

            original_valid = cat_valid_positions[:, cat_col].bool()

            if assert_not_missing and not bool(original_valid.all().item()):
                n_bad = int((~original_valid).sum().item())
                raise ValueError(
                    f"Categorical feature '{feature_name}' has {n_bad} naturally missing samples in batch"
                )

            cat_loss_mask[:, cat_col] = original_valid
            masked_cat_local_ids[cat_loss_mask] = cat_mask_local_ids.view(1, M).expand(B, M)[cat_loss_mask]
            masked_cat_valid_positions = masked_cat_valid_positions & (~cat_loss_mask)

        else:
            if num_n_in not in masked_numeric_values_by_nin:
                raise KeyError(f"numeric_values_by_nin does not contain n_in={num_n_in} for {feature_name}")

            # These are the clones created above, so in-place edits are safe.
            values = masked_numeric_values_by_nin[num_n_in]
            valid_positions = masked_numeric_valid_positions_by_nin[num_n_in]
            num_loss_mask = numeric_loss_mask_by_nin[num_n_in]

            if not 0 <= num_v_idx < values.shape[1]:
                raise IndexError(
                    f"Numeric feature '{feature_name}' resolved to v_idx={num_v_idx}, "
                    f"but numeric_values_by_nin[{num_n_in}] has V={values.shape[1]}"
                )

            original_valid = valid_positions[:, num_v_idx].bool()

            if assert_not_missing and not bool(original_valid.all().item()):
                n_bad = int((~original_valid).sum().item())
                raise ValueError(
                    f"Numeric feature '{feature_name}' has {n_bad} naturally missing samples in batch"
                )

            num_loss_mask[:, num_v_idx] = original_valid
            values[num_loss_mask] = 0.0
            valid_positions &= ~num_loss_mask

        masked_batch["masked_cat_local_ids"] = masked_cat_local_ids
        masked_batch["masked_cat_valid_positions"] = masked_cat_valid_positions
        masked_batch["cat_loss_mask"] = cat_loss_mask

        masked_batch["masked_numeric_values_by_nin"] = masked_numeric_values_by_nin
        masked_batch["masked_numeric_valid_positions_by_nin"] = masked_numeric_valid_positions_by_nin
        masked_batch["numeric_loss_mask_by_nin"] = numeric_loss_mask_by_nin

        return masked_batch
|
|
|
|
def build_train_eval_dataloaders(
    dataset,
    train_ratio=0.8,
    seed=42,
    batch_size=32,
):
    """
    Split ``dataset`` once and wrap both parts in DataLoaders.

    The split is reproducible for a given ``seed``; the first
    ``int(len(dataset) * train_ratio)`` shuffled samples form the training
    part and the remainder the evaluation part. Only the training loader
    shuffles, using a dedicated generator that is returned so the caller can
    seed or checkpoint it.

    Returns:
        (train_loader, eval_loader, train_generator)
    """
    total = len(dataset)
    train_count = int(total * train_ratio)
    part_sizes = [train_count, total - train_count]

    train_subset, eval_subset = torch.utils.data.random_split(
        dataset,
        part_sizes,
        generator=torch.Generator().manual_seed(seed),
    )

    # Separate from the split generator: reseeding the shuffle order must
    # not change which samples belong to which part.
    shuffle_rng = torch.Generator()
    collate = dataset.collate_fn

    train_loader = DataLoader(
        train_subset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate,
        generator=shuffle_rng,
    )
    eval_loader = DataLoader(
        eval_subset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate,
    )

    return train_loader, eval_loader, shuffle_rng
|
|
|
|
def debug_print_first_sample(dataset, batch, batch_pos=0):
    """
    Inspect one sample in a batch.

    This debug function checks masked_* fields against the original csv row:
    the expected categorical ids / numeric values are recomputed from
    ``dataset.df`` and compared with the tensors in ``batch``. Positions
    flagged in a loss mask are allowed to mismatch, because they were masked
    on purpose. Every feature is printed; the first mismatch raises.

    Args:
        dataset: SoilFormerDataset
        batch: collated batch that has gone through perform_active_mask /
            perform_active_mask_single (masked_* keys are required, the
            loss-mask keys are optional)
        batch_pos: index inside the batch (not dataset row index)

    Raises:
        KeyError: batch lacks 'row_idx' or 'sample_id'.
        AssertionError: a categorical, numeric or vision field disagrees with
            the csv row.
    """
    import math

    def numeric_list_close(a, b, atol=1e-6, rtol=1e-5):
        # Element-wise float comparison; different lengths never match.
        if len(a) != len(b):
            return False
        return all(
            math.isclose(float(x), float(y), rel_tol=rtol, abs_tol=atol)
            for x, y in zip(a, b)
        )

    def normalize_numeric_list(feat_name, vals, is_valid):
        # Mirrors the z-scoring done in the dataset, including std == 0 -> 1.
        if not is_valid:
            return [0.0] * len(vals)

        stat_row = dataset.numeric_stats_index.loc[feat_name]
        mean = float(stat_row["mean"])
        std = float(stat_row["std"])
        if std == 0.0:
            std = 1.0

        return [(float(v) - mean) / std for v in vals]

    for required in ("row_idx", "sample_id"):
        if required not in batch:
            raise KeyError(f"batch must contain '{required}' for debug_print_first_sample")

    row_idx = int(batch["row_idx"][batch_pos].item())
    row = dataset.df.iloc[row_idx]
    sample_id = batch["sample_id"][batch_pos]

    print("\n====================================================")
    print("DEBUG SAMPLE")
    print("====================================================")
    print("batch_pos :", batch_pos)
    print("row_idx :", row_idx)
    print("sample_id :", sample_id)

    # ------------------------------------------------------------------
    # Categorical features
    # ------------------------------------------------------------------
    print("\n[CATEGORICAL FEATURES]")

    cat_ids = batch["masked_cat_local_ids"][batch_pos]
    cat_valids = batch["masked_cat_valid_positions"][batch_pos]
    cat_loss_mask = batch.get("cat_loss_mask", None)
    if cat_loss_mask is not None:
        cat_loss_mask = cat_loss_mask[batch_pos]

    for i, col in enumerate(dataset.cat_columns):
        raw = row[col]
        raw_str = str(raw)

        got_id = int(cat_ids[i].item())
        got_valid = bool(cat_valids[i].item())

        spec = dataset.cat_vocab[col]
        label2id = spec["label2id"]
        mask_id = int(spec["mask_local_id"])

        if raw == "":
            expected_id = mask_id
            expected_valid = False
        else:
            expected_id = int(label2id[raw])
            expected_valid = True

        is_loss_position = False
        if cat_loss_mask is not None:
            is_loss_position = bool(cat_loss_mask[i].item())

        # Actively masked positions differ from the csv by design.
        if is_loss_position:
            ok = True
        else:
            ok = (got_id == expected_id) and (got_valid == expected_valid)

        print(
            f"{i:03d} | {col} | "
            f"raw={raw_str:<60} | "
            f"id={got_id:<6} | expected={expected_id:<6} | "
            f"valid={got_valid} | exp_valid={expected_valid} | "
            f"loss_mask={is_loss_position} | ok={ok}"
        )

        if not ok:
            raise AssertionError(
                f"\nCategorical mismatch\n"
                f"batch_pos={batch_pos}\n"
                f"row_idx={row_idx}\n"
                f"feature={col}\n"
                f"raw={raw}\n"
                f"id={got_id}, expected={expected_id}\n"
                f"valid={got_valid}, expected={expected_valid}"
            )

    # ------------------------------------------------------------------
    # Numeric features
    # ------------------------------------------------------------------
    print("\n[NUMERIC FEATURES]")

    numeric_loss_mask_by_nin = batch.get("numeric_loss_mask_by_nin", None)

    for group in dataset.numeric_groups:
        n_in = int(group["n_in"])
        features = group["feature_names"]

        values = batch["masked_numeric_values_by_nin"][n_in][batch_pos]
        valids = batch["masked_numeric_valid_positions_by_nin"][n_in][batch_pos]

        if numeric_loss_mask_by_nin is not None:
            loss_mask = numeric_loss_mask_by_nin[n_in][batch_pos]
        else:
            loss_mask = None

        print(f"\nGroup n_in={n_in}")

        for i, feat in enumerate(features):
            raw = row[feat]
            raw_str = str(raw)

            parsed, expected_valid = parse_numeric_cell(raw, n_in)
            expected_norm = normalize_numeric_list(feat, parsed, expected_valid)

            tensor_val = values[i].tolist()
            got_valid = bool(valids[i].item())

            is_loss_position = False
            if loss_mask is not None:
                is_loss_position = bool(loss_mask[i].item())

            if is_loss_position:
                ok = True
            else:
                value_ok = numeric_list_close(tensor_val, expected_norm)
                valid_ok = (got_valid == expected_valid)
                ok = value_ok and valid_ok

            print(
                f"{i:03d} | {feat} | "
                f"raw={raw_str:<60} | "
                f"tensor={tensor_val} | expected_norm={expected_norm} | "
                f"valid={got_valid} | exp_valid={expected_valid} | "
                f"loss_mask={is_loss_position} | ok={ok}"
            )

            if not ok:
                # Report the normalised expectation: that is what the tensor
                # was actually compared against.
                raise AssertionError(
                    f"\nNumeric mismatch\n"
                    f"batch_pos={batch_pos}\n"
                    f"row_idx={row_idx}\n"
                    f"feature={feat}\n"
                    f"raw={raw}\n"
                    f"tensor={tensor_val}\n"
                    f"expected={expected_norm}\n"
                    f"valid={got_valid}, expected={expected_valid}"
                )

    # ------------------------------------------------------------------
    # Vision
    # ------------------------------------------------------------------
    print("\n[VISION]")

    # Re-run the same lookup + load the dataset performs; any failure means
    # the sample is expected to be flagged as having no valid image.
    try:
        relative_path = dataset.photo_map[sample_id]
        expected_path = join_photo_root(dataset.photo_root, relative_path)

        _ = dataset.load_image(expected_path)
        expected_valid = True

    except Exception:
        expected_path = None
        expected_valid = False

    got_valid = bool(batch["vision_valid_positions"][batch_pos].item())
    img_shape = tuple(batch["pixel_values"][batch_pos].shape)

    print("expected_path :", expected_path)
    print("vision_valid :", got_valid)
    print("image_shape :", img_shape)

    if got_valid != expected_valid:
        raise AssertionError(
            f"\nVision validity mismatch\n"
            f"batch_pos={batch_pos}\n"
            f"row_idx={row_idx}\n"
            f"expected={expected_valid}, got={got_valid}"
        )

    print("\n====================================================")
    print("DEBUG CHECK PASSED")
    print("====================================================\n")
|
|
|
|
def main():
    """
    Smoke-test the data pipeline end to end.

    Builds the dataset from the hard-coded paths below, takes the first
    evaluation batch, applies 15% active masking with a fixed seed, prints
    the shape of every produced field and finally cross-checks the first
    sample against its csv row.
    """
    dataset = SoilFormerDataset(
        csv_path="data/tabular_data.csv",
        photo_map_path="data/photo_map.json",
        cat_vocab_path="data/cat_vocab.json",
        numeric_vocab_path="data/numeric_vocab.json",
        numeric_stats_path="data/tabular_meta_numeric_stats.csv",
        photo_root="/Volumes/TOSHIBA EXT",
        image_size=512,
        id_column="id",
    )

    # Only the evaluation loader is exercised here.
    _, eval_loader, _ = build_train_eval_dataloaders(dataset)

    print("Dataset size:", len(dataset))

    first_eval_batch = next(iter(eval_loader))
    batch = dataset.perform_active_mask(
        first_eval_batch,
        cat_ratio=0.15,
        num_ratio=0.15,
        seed=42,
    )

    print("\nBatch check")
    if "row_idx" in batch:
        print("row_idx:", batch["row_idx"].shape, batch["row_idx"].dtype)
    if "sample_id" in batch:
        print("sample_id:", len(batch["sample_id"]))

    # Plain tensor fields.
    cat_fields = (
        "original_cat_local_ids",
        "masked_cat_local_ids",
        "original_cat_valid_positions",
        "masked_cat_valid_positions",
        "cat_loss_mask",
    )
    for field in cat_fields:
        print(f"{field}:", batch[field].shape)

    # Dict-of-tensor fields, one entry per numeric group.
    grouped_fields = (
        "original_numeric_values_by_nin",
        "masked_numeric_values_by_nin",
        "original_numeric_valid_positions_by_nin",
        "masked_numeric_valid_positions_by_nin",
        "numeric_loss_mask_by_nin",
    )
    for field in grouped_fields:
        for group_key, tensor in batch[field].items():
            print(f"{field}[{group_key}]:", tensor.shape)

    print("pixel_values:", batch["pixel_values"].shape)
    print("vision_valid_positions:", batch["vision_valid_positions"].shape)

    print("\nTensor dtype check")
    first_numeric = next(iter(batch["masked_numeric_values_by_nin"].values()))
    print("masked cat ids dtype:", batch["masked_cat_local_ids"].dtype)
    print("masked numeric dtype:", first_numeric.dtype)
    print("image dtype:", batch["pixel_values"].dtype)

    print("\nLoader test finished successfully")

    debug_print_first_sample(dataset, batch, batch_pos=0)


# Run the smoke test only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|