import sys, os
import numpy as np
import scipy
import torch
import torch.nn as nn
from scipy import ndimage
from tqdm import tqdm, trange
from PIL import Image
import torch.hub
import torchvision
import torch.nn.functional as F
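
# path to the pre-trained DeepLab-v2 COCO-Stuff checkpoint; fill in before running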
CKPT_PATH = "TODO"

# map [-1, 1] images back to [0, 1]
rescale = lambda x: (x + 1.) / 2.


def rescale_bgr(x):
    # map [-1, 1] images to [0, 255] and flip the channel axis from RGB to BGR
    x = (x + 1) * 127.5
    x = torch.flip(x, dims=[0])
    return x
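

# DeepLab-v2 / ResNet-101 segmenter for the 182 COCO-Stuff labels,
# fetched via torch.hub from kazuto1011/deeplab-pytorch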
class COCOStuffSegmenter(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.n_labels = 182
        model = torch.hub.load("kazuto1011/deeplab-pytorch", "deeplabv2_resnet101", n_classes=self.n_labels)
        ckpt_path = CKPT_PATH
        model.load_state_dict(torch.load(ckpt_path))
        self.model = model

        normalize = torchvision.transforms.Normalize(mean=self.mean, std=self.std)
        self.image_transform = torchvision.transforms.Compose([
            torchvision.transforms.Lambda(lambda image: torch.stack(
                [normalize(rescale_bgr(x)) for x in image]))
        ])

    def forward(self, x, upsample=None):
        x = self._pre_process(x)
        x = self.model(x)
        if upsample is not None:
            # F.upsample_bilinear is deprecated; interpolate with
            # align_corners=True matches its behaviour
            x = F.interpolate(x, size=upsample, mode="bilinear", align_corners=True)
        return x

    def _pre_process(self, x):
        x = self.image_transform(x)
        return x

    @property
    def mean(self):
        # per-channel means in BGR order (inputs are flipped to BGR in rescale_bgr)
        return [104.008, 116.669, 122.675]

    @property
    def std(self):
        return [1.0, 1.0, 1.0]

    @property
    def input_size(self):
        return [3, 224, 224]
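

# run a forward pass and reduce the logits to per-pixel class indices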
def run_model(img, model):
    model = model.eval()
    with torch.no_grad():
        segmentation = model(img, upsample=(img.shape[2], img.shape[3]))
        segmentation = torch.argmax(segmentation, dim=1, keepdim=True)
    return segmentation.detach().cpu()
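

# pull key k out of a batch and convert NHWC float input to contiguous NCHW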
def get_input(batch, k):
    x = batch[k]
    if len(x.shape) == 3:
        x = x[..., None]
    x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
    return x.float()
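

# write a (1, 1, H, W) label map as an 8-bit PNG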
def save_segmentation(segmentation, path):
    os.makedirs(os.path.dirname(path), exist_ok=True)
    assert len(segmentation.shape) == 4
    assert segmentation.shape[0] == 1
    for seg in segmentation:
        seg = seg.permute(1, 2, 0).numpy().squeeze().astype(np.uint8)
        seg = Image.fromarray(seg)
        seg.save(path)
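

# segment every image in the dataloader, mirroring each example's
# relative file path under destpath; failures are logged and skipped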
def iterate_dataset(dataloader, destpath, model):
    os.makedirs(destpath, exist_ok=True)
    num_processed = 0
    for i, batch in tqdm(enumerate(dataloader), desc="Data"):
        try:
            img = get_input(batch, "image")
            img = img.cuda()
            seg = run_model(img, model)

            path = batch["relative_file_path_"][0]
            path = os.path.splitext(path)[0]
            path = os.path.join(destpath, path + ".png")
            save_segmentation(seg, path)
            num_processed += 1
        except Exception as e:
            print(e)
            print("but anyhow..")

    print("Processed {} files. Bye.".format(num_processed))


from taming.data.sflckr import Examples
from torch.utils.data import DataLoader
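
# expects the destination directory as the first command-line argument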
if __name__ == "__main__":
    dest = sys.argv[1]
    batchsize = 1
    print("Running with batch-size {}, saving to {}...".format(batchsize, dest))

    model = COCOStuffSegmenter({}).cuda()
    print("Instantiated model.")

    dataset = Examples()
    dloader = DataLoader(dataset, batch_size=batchsize)
    iterate_dataset(dataloader=dloader, destpath=dest, model=model)
    print("done.")