File size: 5,250 Bytes
1a18f22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
"""Orchestration for the generative-augmentation SOTA baselines (category B).

These methods are compared against our SegGen method. Each runs in its OWN conda
env (see envs/) because their dependency stacks conflict with the main framework.
The shared contract: every generator must emit paired (image, mask) into

    <data_root>/<dataset>/<protocol>/synth_<method>/{images,masks}/

which the unified trainer then merges into the train split via --synth_train_dir.

Kept baselines:
  * SegGuidedDiff (diffusion, mask->image, medical, modern stack) -- best fit, USE-AS-IS
  * SPADE         (GAN, mask->image)                              -- ADAPT (needs sync_bn)
  * ControlNet    (diffusion, SD-finetune, mask->image)           -- ADAPT (needs SD ckpt)

Dropped (per scoping): StyleGAN2-ADA (no masks), LDM (dep hell + AE training).

This module only BUILDS the commands + assembles the standard synth dir; it does
not import the repos (they live in separate envs). Run the printed commands in the
matching env, then call assemble_synth_dir() (env-agnostic) to lay out pairs.
"""
from __future__ import annotations

import os
import shlex
import shutil
from glob import escape as glob_escape
from glob import glob

# Absolute path of the ``sota`` directory located two levels above this file's
# directory; the command builders below resolve each baseline's repo under it.
SOTA = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "sota"))


def assemble_synth_dir(generated_images_dir: str, masks_source_dir: str,
                       out_dir: str, strip_prefix: str = "condon_",
                       link: bool = True) -> int:
    """Pair each generated image with the real mask it was conditioned on.

    Mask-conditioned generators name outputs after the conditioning mask
    (SegGuidedDiff: 'condon_<maskname>.png'). We recover the mask name, copy/link
    the matching real mask, and place both under out_dir/{images,masks}/.

    Args:
        generated_images_dir: Directory holding the generator outputs. Only
            regular files directly inside it are considered.
        masks_source_dir: Directory holding the real masks; a mask matches an
            image when its file name is ``<mask_stem>.<any extension>``.
        out_dir: Destination root; ``images/`` and ``masks/`` are created in it.
        strip_prefix: Prefix removed from the image stem to obtain the mask
            stem. An empty string, or a stem without the prefix, leaves the
            stem unchanged.
        link: If true, symlink the pair into ``out_dir``; otherwise copy it.

    Returns:
        The number of (image, mask) pairs assembled. Images with no matching
        mask are skipped and do not consume an output index.
    """
    img_out = os.path.join(out_dir, "images")
    msk_out = os.path.join(out_dir, "masks")
    os.makedirs(img_out, exist_ok=True)
    os.makedirs(msk_out, exist_ok=True)

    # Escape every user-supplied path component so glob metacharacters
    # ('[', ']', '*', '?') in directory or file names are matched literally.
    gen_pattern = os.path.join(glob_escape(generated_images_dir), "*")
    mask_root = glob_escape(masks_source_dir)

    n = 0
    for gp in sorted(glob(gen_pattern)):
        if not os.path.isfile(gp):
            # Sub-directories (e.g. per-run folders) are not generated images.
            continue
        stem, img_ext = os.path.splitext(os.path.basename(gp))
        if strip_prefix and stem.startswith(strip_prefix):
            mask_stem = stem[len(strip_prefix):]
        else:
            mask_stem = stem
        mask_pattern = os.path.join(mask_root, glob_escape(mask_stem) + ".*")
        # Sorted so the chosen mask is deterministic when several extensions
        # exist for the same stem (glob order is filesystem-dependent).
        cands = sorted(p for p in glob(mask_pattern) if os.path.isfile(p))
        if not cands:
            continue
        mask_path = cands[0]
        out_name = f"synth_{n:06d}"
        dst_img = os.path.join(img_out, out_name + img_ext)
        dst_msk = os.path.join(msk_out, out_name + os.path.splitext(mask_path)[1])
        _place(gp, dst_img, link)
        _place(mask_path, dst_msk, link)
        n += 1
    return n


def _place(src: str, dst: str, link: bool) -> None:
    """Expose ``src`` at ``dst`` as a symlink or a copy, replacing any old entry.

    Args:
        src: Path of the existing file to place.
        dst: Destination path; a file or symlink already there is removed first.
        link: If true, ``dst`` becomes a symlink to the absolute path of
            ``src``; otherwise ``src`` is copied with ``shutil.copy2``.
    """
    # lexists rather than exists: exists() follows symlinks and reports a
    # dangling link (e.g. left by an earlier run whose source was deleted) as
    # missing, after which os.symlink fails with FileExistsError and copy2
    # would write through the stale link instead of replacing it.
    if os.path.lexists(dst):
        os.remove(dst)
    if link:
        os.symlink(os.path.abspath(src), dst)
    else:
        shutil.copy2(src, dst)


# ---- command builders (printed into run.sh; run in the matching conda env) ----

def segguideddiff_cmds(data_root, dataset, protocol, num_classes, in_channels,
                       img_size=256, epochs=400, sample_size=1000):
    """Build the SegGuidedDiff training and sampling shell commands.

    Nothing is executed here; the strings are meant to be run in the matching
    conda env. Path- and name-like values are shell-quoted, so a ``data_root``
    or repo location containing spaces or shell metacharacters still yields a
    valid command. Values made only of ordinary path characters are emitted
    unchanged.

    Args:
        data_root: Root directory of the datasets.
        dataset: Dataset name; used in the data paths and passed as ``--dataset``.
        protocol: Protocol sub-directory under the dataset.
        num_classes: Value for ``--num_segmentation_classes``.
        in_channels: Value for ``--num_img_channels``.
        img_size: Value for ``--img_size``.
        epochs: Value for ``--num_epochs`` (train command only).
        sample_size: Value for ``--eval_sample_size`` (sampling command only).

    Returns:
        A ``(train, synth)`` tuple of command strings; ``train`` runs
        ``main.py --mode train`` and ``synth`` runs ``main.py --mode eval_many``.
    """
    repo = shlex.quote(os.path.join(SOTA, "segmentation-guided-diffusion"))
    img_dir = shlex.quote(f"{data_root}/{dataset}/{protocol}/train/images")
    seg_dir = shlex.quote(f"{data_root}/{dataset}/{protocol}/train/masks")
    dataset_arg = shlex.quote(str(dataset))
    train = (f"cd {repo} && python main.py --mode train --model_type DDIM "
             f"--img_size {img_size} --num_img_channels {in_channels} --dataset {dataset_arg} "
             f"--img_dir {img_dir} --seg_dir {seg_dir} --segmentation_guided "
             f"--num_segmentation_classes {num_classes} --num_epochs {epochs}")
    # Sampling is conditioned on the training masks only, so no --img_dir here.
    synth = (f"cd {repo} && python main.py --mode eval_many --model_type DDIM "
             f"--img_size {img_size} --num_img_channels {in_channels} --dataset {dataset_arg} "
             f"--seg_dir {seg_dir} --segmentation_guided "
             f"--num_segmentation_classes {num_classes} --eval_sample_size {sample_size}")
    return train, synth


def spade_cmds(data_root, dataset, protocol, num_classes, img_size=256, niter=100):
    """Build the SPADE setup, training and sampling shell commands.

    Nothing is executed here; the strings are meant to be run in the matching
    conda env. Path- and name-like values are shell-quoted, so locations with
    spaces or shell metacharacters still yield valid commands. Values made
    only of ordinary path characters are emitted unchanged.

    Args:
        data_root: Root directory of the datasets.
        dataset: Dataset name; used in the data paths, the experiment name
            (``<dataset>_spade``) and the results directory (``./synth_<dataset>``).
        protocol: Protocol sub-directory under the dataset.
        num_classes: Value for ``--label_nc``.
        img_size: Value for ``--crop_size``; ``--load_size`` is set to
            ``int(img_size * 1.12)``.
        niter: Value for ``--niter``.

    Returns:
        A ``(setup, train, synth)`` tuple of command strings. ``setup`` clones
        Synchronized-BatchNorm-PyTorch inside ``models/networks`` and copies its
        ``sync_batchnorm`` package there; ``train`` runs ``train.py``; ``synth``
        runs ``test.py``.
    """
    repo_path = os.path.join(SOTA, "SPADE")
    repo = shlex.quote(repo_path)
    networks_dir = shlex.quote(f"{repo_path}/models/networks")
    img_dir = shlex.quote(f"{data_root}/{dataset}/{protocol}/train/images")
    lab_dir = shlex.quote(f"{data_root}/{dataset}/{protocol}/train/masks")
    run_name = shlex.quote(f"{dataset}_spade")
    results_dir = shlex.quote(f"./synth_{dataset}")
    setup = (f"cd {networks_dir} && "
             f"git clone https://github.com/vacancy/Synchronized-BatchNorm-PyTorch && "
             f"cp -r Synchronized-BatchNorm-PyTorch/sync_batchnorm .")
    train = (f"cd {repo} && python train.py --name {run_name} --dataset_mode custom "
             f"--label_dir {lab_dir} --image_dir {img_dir} --label_nc {num_classes} "
             f"--no_instance --crop_size {img_size} --load_size {int(img_size*1.12)} --niter {niter}")
    synth = (f"cd {repo} && python test.py --name {run_name} --dataset_mode custom "
             f"--label_dir {lab_dir} --image_dir {img_dir} --label_nc {num_classes} "
             f"--no_instance --results_dir {results_dir}")
    return setup, train, synth


def controlnet_notes():
    """Return the one-line manual setup recipe for the ControlNet baseline.

    ControlNet has no command builder in this module; the returned text lists
    the steps to carry out by hand and the conda env to run them in.
    """
    steps = (
        "ControlNet: download SD v1.5 (~4GB),",
        "run tool_add_control.py,",
        "write a MyDataset that colorizes integer masks to RGB hints",
        "+ triples grayscale images to 3ch,",
        "then tutorial_train.py.",
        "Run in env seggen-controlnet.",
    )
    return " ".join(steps)