Spaces:
Build error
Build error
import sys, os | |
import numpy as np | |
import scipy | |
import torch | |
import torch.nn as nn | |
from scipy import ndimage | |
from tqdm import tqdm, trange | |
from PIL import Image | |
import torch.hub | |
import torchvision | |
import torch.nn.functional as F | |
# download deeplabv2_resnet101_msc-cocostuff164k-100000.pth from | |
# https://github.com/kazuto1011/deeplab-pytorch/releases/download/v1.0/deeplabv2_resnet101_msc-cocostuff164k-100000.pth | |
# and put the path here | |
CKPT_PATH = "TODO" | |
rescale = lambda x: (x + 1.) / 2. | |
def rescale_bgr(x): | |
x = (x+1)*127.5 | |
x = torch.flip(x, dims=[0]) | |
return x | |
class COCOStuffSegmenter(nn.Module): | |
def __init__(self, config): | |
super().__init__() | |
self.config = config | |
self.n_labels = 182 | |
model = torch.hub.load("kazuto1011/deeplab-pytorch", "deeplabv2_resnet101", n_classes=self.n_labels) | |
ckpt_path = CKPT_PATH | |
model.load_state_dict(torch.load(ckpt_path)) | |
self.model = model | |
normalize = torchvision.transforms.Normalize(mean=self.mean, std=self.std) | |
self.image_transform = torchvision.transforms.Compose([ | |
torchvision.transforms.Lambda(lambda image: torch.stack( | |
[normalize(rescale_bgr(x)) for x in image])) | |
]) | |
def forward(self, x, upsample=None): | |
x = self._pre_process(x) | |
x = self.model(x) | |
if upsample is not None: | |
x = torch.nn.functional.upsample_bilinear(x, size=upsample) | |
return x | |
def _pre_process(self, x): | |
x = self.image_transform(x) | |
return x | |
def mean(self): | |
# bgr | |
return [104.008, 116.669, 122.675] | |
def std(self): | |
return [1.0, 1.0, 1.0] | |
def input_size(self): | |
return [3, 224, 224] | |
def run_model(img, model): | |
model = model.eval() | |
with torch.no_grad(): | |
segmentation = model(img, upsample=(img.shape[2], img.shape[3])) | |
segmentation = torch.argmax(segmentation, dim=1, keepdim=True) | |
return segmentation.detach().cpu() | |
def get_input(batch, k): | |
x = batch[k] | |
if len(x.shape) == 3: | |
x = x[..., None] | |
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format) | |
return x.float() | |
def save_segmentation(segmentation, path): | |
# --> class label to uint8, save as png | |
os.makedirs(os.path.dirname(path), exist_ok=True) | |
assert len(segmentation.shape)==4 | |
assert segmentation.shape[0]==1 | |
for seg in segmentation: | |
seg = seg.permute(1,2,0).numpy().squeeze().astype(np.uint8) | |
seg = Image.fromarray(seg) | |
seg.save(path) | |
def iterate_dataset(dataloader, destpath, model): | |
os.makedirs(destpath, exist_ok=True) | |
num_processed = 0 | |
for i, batch in tqdm(enumerate(dataloader), desc="Data"): | |
try: | |
img = get_input(batch, "image") | |
img = img.cuda() | |
seg = run_model(img, model) | |
path = batch["relative_file_path_"][0] | |
path = os.path.splitext(path)[0] | |
path = os.path.join(destpath, path + ".png") | |
save_segmentation(seg, path) | |
num_processed += 1 | |
except Exception as e: | |
print(e) | |
print("but anyhow..") | |
print("Processed {} files. Bye.".format(num_processed)) | |
from taming.data.sflckr import Examples | |
from torch.utils.data import DataLoader | |
if __name__ == "__main__": | |
dest = sys.argv[1] | |
batchsize = 1 | |
print("Running with batch-size {}, saving to {}...".format(batchsize, dest)) | |
model = COCOStuffSegmenter({}).cuda() | |
print("Instantiated model.") | |
dataset = Examples() | |
dloader = DataLoader(dataset, batch_size=batchsize) | |
iterate_dataset(dataloader=dloader, destpath=dest, model=model) | |
print("done.") | |