|
|
|
|
|
|
|
|
|
|
|
|
|
import sys |
|
import os |
|
import tempfile |
|
import warnings |
|
from pathlib import Path |
|
import nltk |
|
import torch |
|
from torch import nn |
|
import torchvision.transforms as transforms |
|
import numpy as np |
|
import imageio |
|
from PIL import Image as Image_PIL |
|
from scipy.stats import truncnorm |
|
from nltk.corpus import wordnet as wn |
|
import cma |
|
import sklearn.metrics |
|
import cog |
|
|
|
sys.path.insert(0, "stylegan2_ada_pytorch") |
|
from pytorch_pretrained_biggan import convert_to_images, utils |
|
import inference.utils as inference_utils |
|
import data_utils.utils as data_utils |
|
|
|
NORM_MEAN = torch.Tensor([0.485, 0.456, 0.406]).view(3, 1, 1) |
|
NORM_STD = torch.Tensor([0.229, 0.224, 0.225]).view(3, 1, 1) |
|
|
|
nltk.download("wordnet") |
|
IND2NAME = { |
|
index: wn.of2ss("%08dn" % offset).lemma_names()[0] |
|
for offset, index in utils.IMAGENET.items() |
|
} |
|
NAME2IND = dict([(value, key) for key, value in IND2NAME.items()]) |
|
|
|
CLASS_NAMES = sorted(list(IND2NAME.values())) |
|
|
|
|
|
class Predictor(cog.Predictor): |
|
def setup(self): |
|
torch.manual_seed(np.random.randint(sys.maxsize)) |
|
warnings.simplefilter("ignore", cma.evolution_strategy.InjectionWarning) |
|
self.last_gen_model = None |
|
self.last_feature_extractor = None |
|
self.model = None |
|
self.feature_extractor = None |
|
self.noise_size = 128 |
|
self.batch_size = 4 |
|
self.size = 256 |
|
|
|
@cog.input("image", type=Path, help="Input image Instance") |
|
@cog.input("gen_model", type=str, options=["icgan", "cc_icgan"], default="icgan", |
|
help='Select type of IC-GAN model. "icgan" is conditioned on the input image; ' |
|
'"cc_icgan" is conditioned on both input image and a conditional_class') |
|
@cog.input("conditional_class", type=str, default=None, options=CLASS_NAMES, |
|
help="Choose conditional class. Only valid for gen_model=cc_icgan") |
|
@cog.input("num_samples", type=int, default=1, options=[1, 4, 9, 16], |
|
help="number of samples generated") |
|
@cog.input("seed", type=int, default=0, help="seed=0 means no seed") |
|
def predict(self, image, gen_model="icgan", conditional_class=None, num_samples=1, seed=0): |
|
assert isinstance(seed, int), "seed should be an integer" |
|
if gen_model == 'cc_icgan': |
|
assert conditional_class is not None, 'please set conditional_class for cc_icgan' |
|
num_samples_ranked = num_samples |
|
experiment_name = ( |
|
"icgan_biggan_imagenet_res256" |
|
if gen_model == "icgan" |
|
else "cc_icgan_biggan_imagenet_res256" |
|
) |
|
num_samples_total = num_samples * 10 |
|
truncation = 0.7 |
|
if conditional_class is not None: |
|
class_index = NAME2IND[conditional_class] |
|
|
|
input_image_instance = str(image) |
|
|
|
if gen_model == "icgan": |
|
class_index = None |
|
|
|
if seed == 0: |
|
seed = None |
|
|
|
state = None if not seed else np.random.RandomState(seed) |
|
np.random.seed(seed) |
|
|
|
feature_extractor_name = ("classification" if gen_model == "cc_icgan" else "selfsupervised") |
|
|
|
|
|
self.feature_extractor, self.last_feature_extractor = load_feature_extractor( |
|
gen_model, self.last_feature_extractor, self.feature_extractor) |
|
|
|
if input_image_instance not in ["None", "", None]: |
|
print("Obtaining instance features from input image!") |
|
input_feature_index = None |
|
input_image_tensor = preprocess_input_image(input_image_instance, self.size) |
|
with torch.no_grad(): |
|
input_features, _ = self.feature_extractor(input_image_tensor.cuda()) |
|
input_features /= torch.linalg.norm(input_features, dim=-1, keepdims=True) |
|
elif input_feature_index is not None: |
|
print("Selecting an instance from pre-extracted vectors!") |
|
input_features = np.load( |
|
"stored_instances/imagenet_res" |
|
+ str(self.size) |
|
+ "_rn50_" |
|
+ feature_extractor_name |
|
+ "_kmeans_k1000_instance_features.npy", |
|
allow_pickle=True, |
|
).item()["instance_features"][input_feature_index: input_feature_index + 1] |
|
else: |
|
input_features = None |
|
|
|
|
|
self.model, self.last_gen_model = load_generative_model( |
|
gen_model, self.last_gen_model, experiment_name, self.model) |
|
|
|
|
|
replace_to_inplace_relu(self.model) |
|
|
|
|
|
noise_vector = truncnorm.rvs( |
|
-2 * truncation, |
|
2 * truncation, |
|
size=(num_samples_total, self.noise_size), |
|
random_state=state, |
|
).astype(np.float32) |
|
noise_vector = torch.tensor(noise_vector, requires_grad=False, device="cuda") |
|
if input_features is not None: |
|
instance_vector = torch.tensor( |
|
input_features, requires_grad=False, device="cuda" |
|
).repeat(num_samples_total, 1) |
|
else: |
|
instance_vector = None |
|
if class_index is not None: |
|
input_label = torch.LongTensor([class_index] * num_samples_total) |
|
else: |
|
input_label = None |
|
if input_feature_index is not None: |
|
print("Conditioning on instance with index: ", input_feature_index) |
|
|
|
all_outs, all_dists = [], [] |
|
for i_bs in range(num_samples_total // self.batch_size + 1): |
|
start = i_bs * self.batch_size |
|
end = min(start + self.batch_size, num_samples_total) |
|
if start == end: |
|
break |
|
out = get_output( |
|
noise_vector[start:end], |
|
input_label[start:end] if input_label is not None else None, |
|
instance_vector[start:end] if instance_vector is not None else None, |
|
self.model, |
|
truncation, |
|
channels=3, |
|
) |
|
|
|
if instance_vector is not None: |
|
|
|
out_ = preprocess_generated_image(out) |
|
with torch.no_grad(): |
|
out_features, _ = self.feature_extractor(out_.cuda()) |
|
out_features /= torch.linalg.norm(out_features, dim=-1, keepdims=True) |
|
dists = sklearn.metrics.pairwise_distances( |
|
out_features.cpu(), |
|
instance_vector[start:end].cpu(), |
|
metric="euclidean", |
|
n_jobs=-1, |
|
) |
|
all_dists.append(np.diagonal(dists)) |
|
all_outs.append(out.detach().cpu()) |
|
del out |
|
all_outs = torch.cat(all_outs) |
|
all_dists = np.concatenate(all_dists) |
|
|
|
|
|
selected_idxs = np.argsort(all_dists)[:num_samples_ranked] |
|
|
|
row_i, col_i, i_im = 0, 0, 0 |
|
all_images_mosaic = np.zeros( |
|
( |
|
3, |
|
self.size * (int(np.sqrt(num_samples_ranked))), |
|
self.size * (int(np.sqrt(num_samples_ranked))), |
|
) |
|
) |
|
for j in selected_idxs: |
|
all_images_mosaic[ |
|
:, |
|
row_i * self.size: row_i * self.size + self.size, |
|
col_i * self.size: col_i * self.size + self.size, |
|
] = all_outs[j] |
|
if row_i == int(np.sqrt(num_samples_ranked)) - 1: |
|
row_i = 0 |
|
if col_i == int(np.sqrt(num_samples_ranked)) - 1: |
|
col_i = 0 |
|
else: |
|
col_i += 1 |
|
else: |
|
row_i += 1 |
|
i_im += 1 |
|
|
|
out_path = Path(tempfile.mkdtemp()) / "out.png" |
|
save(all_images_mosaic[np.newaxis, ...], str(out_path), torch_format=False) |
|
return out_path |
|
|
|
|
|
def replace_to_inplace_relu(model): |
|
for child_name, child in model.named_children(): |
|
if isinstance(child, nn.ReLU): |
|
setattr(model, child_name, nn.ReLU(inplace=False)) |
|
else: |
|
replace_to_inplace_relu(child) |
|
|
|
|
|
def save(out, name=None, torch_format=True): |
|
if torch_format: |
|
with torch.no_grad(): |
|
out = out.cpu().numpy() |
|
img = convert_to_images(out)[0] |
|
if name: |
|
imageio.imwrite(name, np.asarray(img)) |
|
return img |
|
|
|
|
|
def load_icgan(experiment_name, root_=""): |
|
root = os.path.join(root_, experiment_name) |
|
config = torch.load("%s/%s.pth" % (root, "state_dict_best0"))["config"] |
|
|
|
config["weights_root"] = root_ |
|
config["model_backbone"] = "biggan" |
|
config["experiment_name"] = experiment_name |
|
G, config = inference_utils.load_model_inference(config) |
|
G.cuda() |
|
G.eval() |
|
return G |
|
|
|
|
|
def get_output(noise_vector, input_label, input_features, model, truncation, channels): |
|
|
|
noise_vector = noise_vector.clamp(-2 * truncation, 2 * truncation) |
|
if input_label is not None: |
|
input_label = torch.LongTensor(input_label) |
|
else: |
|
input_label = None |
|
|
|
out = model( |
|
noise_vector, |
|
input_label.cuda() if input_label is not None else None, |
|
input_features.cuda() if input_features is not None else None, |
|
) |
|
|
|
if channels == 1: |
|
out = out.mean(dim=1, keepdim=True) |
|
out = out.repeat(1, 3, 1, 1) |
|
return out |
|
|
|
|
|
def load_generative_model(gen_model, last_gen_model, experiment_name, model): |
|
|
|
if gen_model != last_gen_model: |
|
model = load_icgan(experiment_name, root_="./") |
|
last_gen_model = gen_model |
|
return model, last_gen_model |
|
|
|
|
|
def load_feature_extractor(gen_model, last_feature_extractor, feature_extractor): |
|
|
|
feat_ext_name = "classification" if gen_model == "cc_icgan" else "selfsupervised" |
|
if last_feature_extractor != feat_ext_name: |
|
if feat_ext_name == "classification": |
|
feat_ext_path = "" |
|
else: |
|
feat_ext_path = "swav_pretrained.pth.tar" |
|
last_feature_extractor = feat_ext_name |
|
feature_extractor = data_utils.load_pretrained_feature_extractor( |
|
feat_ext_path, feature_extractor=feat_ext_name |
|
) |
|
feature_extractor.eval() |
|
return feature_extractor, last_feature_extractor |
|
|
|
|
|
def preprocess_input_image(input_image_path, size): |
|
pil_image = Image_PIL.open(input_image_path).convert("RGB") |
|
transform_list = transforms.Compose( |
|
[ |
|
data_utils.CenterCropLongEdge(), |
|
transforms.Resize((size, size)), |
|
transforms.ToTensor(), |
|
transforms.Normalize(NORM_MEAN, NORM_STD), |
|
] |
|
) |
|
tensor_image = transform_list(pil_image) |
|
tensor_image = torch.nn.functional.interpolate( |
|
tensor_image.unsqueeze(0), 224, mode="bicubic", align_corners=True |
|
) |
|
return tensor_image |
|
|
|
|
|
def preprocess_generated_image(image): |
|
transform_list = transforms.Normalize(NORM_MEAN, NORM_STD) |
|
image = transform_list(image * 0.5 + 0.5) |
|
image = torch.nn.functional.interpolate( |
|
image, 224, mode="bicubic", align_corners=True |
|
) |
|
return image |
|
|