|
from stylegan2 import Generator, Encoder |
|
from torch import nn, autograd, optim |
|
import pandas as pd |
|
from tqdm import tqdm |
|
import torch |
|
import cv2 |
|
import os |
|
import random |
|
from torchvision import transforms |
|
from torchvision import utils |
|
import numpy as np |
|
from sklearn.svm import SVC |
|
from sklearn.model_selection import train_test_split |
|
from sklearn.metrics import classification_report, accuracy_score |
|
from sklearn.pipeline import make_pipeline |
|
from sklearn.svm import LinearSVC |
|
|
|
def accumulate(model1, model2, decay=0.999):
    """Blend ``model2``'s parameters into ``model1`` in place (exponential moving average).

    For every named parameter ``p1`` of ``model1`` and its same-named counterpart
    ``p2`` in ``model2``::

        p1 <- decay * p1 + (1 - decay) * p2

    Args:
        model1: module whose parameters are updated in place (the running average).
        model2: module supplying the new values; it must have every parameter name
            that ``model1`` has (a missing name raises ``KeyError``).
        decay: fraction of ``model1``'s current value that is kept.
    """
    par1 = dict(model1.named_parameters())
    par2 = dict(model2.named_parameters())

    for name, p1 in par1.items():
        # Operate on ``.data`` so the in-place update is not recorded by autograd.
        p1.data.mul_(decay).add_(par2[name].data, alpha=1 - decay)
|
|
|
class GCA:
    """Latent-space augmenter built on a conditional StyleGAN2 encoder/generator pair.

    An input image is encoded into the generator's latent space. It is shifted along
    two linear directions ("sex" and "age") and decoded back with the generator. The
    directions are the normals of LinearSVC hyperplanes fitted on encoded patient
    images. They are either loaded from ``h_path`` or learned from the RSNA / CheXpert
    CSVs.
    """

    # Per-channel mean/std used by ``self.transform``; also needed to undo it.
    _NORM_MEAN = (0.485, 0.456, 0.406)
    _NORM_STD = (0.229, 0.224, 0.225)

    def __init__(self, distributed=False, h_path=None, ckpt_path='results/000500.pt'):
        """Load the pretrained networks and the sex/age latent directions.

        Args:
            distributed: wrap generator and encoder in ``DistributedDataParallel``
                using the ``LOCAL_RANK`` environment variable.
            h_path: file holding saved hyperplanes. If it exists, the directions are
                loaded from it; otherwise they are learned and saved there (or to
                ``"hyperplanes.pt"`` when ``h_path`` is None).
            ckpt_path: checkpoint containing the generator (``"g"``) and encoder
                (``"e"``) state dicts.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        self.distributed = distributed
        self.h_path = h_path

        self.size, self.n_mlp, self.channel_multiplier, self.cgan = 256, 8, 2, True
        self.classifier_nof_classes, self.embedding_size, self.latent = 2, 10, 512
        self.g_reg_every, self.lr, self.ckpt = 4, 0.002, ckpt_path

        # torch.load unpickles the file: only load trusted checkpoints. The
        # map_location lambda deserializes every tensor onto the CPU first.
        self.ckpt = torch.load(self.ckpt, map_location=lambda storage, loc: storage)

        self.generator = Generator(self.size, self.latent, self.n_mlp, channel_multiplier=self.channel_multiplier,
                                   conditional_gan=self.cgan, nof_classes=self.classifier_nof_classes,
                                   embedding_size=self.embedding_size).to(self.device)
        self.encoder = Encoder(self.size, channel_multiplier=self.channel_multiplier,
                               output_channels=self.latent).to(self.device)
        self.generator.load_state_dict(self.ckpt["g"])
        self.encoder.load_state_dict(self.ckpt["e"])

        if self.distributed:
            # NOTE(review): assumes torch.distributed is already initialised and the
            # current CUDA device matches LOCAL_RANK — confirm with the launcher.
            local_rank = int(os.environ["LOCAL_RANK"])
            self.generator = nn.parallel.DistributedDataParallel(
                self.generator,
                device_ids=[local_rank],
                output_device=local_rank,
                broadcast_buffers=False,
            )
            self.encoder = nn.parallel.DistributedDataParallel(
                self.encoder,
                device_ids=[local_rank],
                output_device=local_rank,
                broadcast_buffers=False,
            )

        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize((256, 256)),
                transforms.Normalize(self._NORM_MEAN, self._NORM_STD, inplace=True),
            ]
        )

        self.sex_coeff, self.age_coeff = None, None
        # Initialise before __get_hyperplanes__, which sets it when it learns the directions.
        self.w_shape = None
        self.__get_hyperplanes__()

    def __load_image__(self, path):
        """Read ``path`` with OpenCV and return a (1, 3, 256, 256) normalised tensor on ``self.device``.

        Raises:
            FileNotFoundError: if OpenCV cannot read the file.
        """
        img = cv2.imread(path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image '{path}'.")
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        return self.transform(img_rgb).unsqueeze(0).to(self.device)

    def __process_in_batches__(self, patients, batch_size):
        """Encode the images listed in ``patients["Path"]`` into latents, ``batch_size`` at a time.

        Returns:
            list of CPU tensors, one per row of ``patients``, in row order.
        """
        style_vectors = []
        for start in range(0, len(patients), batch_size):
            batch_paths = patients.iloc[start:start + batch_size]["Path"].tolist()
            batch_imgs_tensor = torch.cat([self.__load_image__(p) for p in batch_paths], dim=0)
            with torch.no_grad():
                w_latents = self.encoder(batch_imgs_tensor)
            # Iterating a tensor yields its rows, so this appends one vector per image.
            style_vectors.extend(w_latents.cpu())
            del batch_imgs_tensor, w_latents
            torch.cuda.empty_cache()
        return style_vectors

    def __load_cxr_data__(self, df):
        """Encode every image referenced by ``df`` (batches of 16)."""
        return self.__process_in_batches__(df, batch_size=16)

    def __get_patient_data__(self, rsna_csv="../datasets/rsna_patients.csv",
                             cxpt_csv="../chexpert/versions/1/train.csv", n_patients=500):
        """Build the four patient groups used to learn the sex and age directions.

        Args:
            rsna_csv: RSNA metadata CSV ("Image Index", "Patient Age", "Patient Gender").
            cxpt_csv: optional CheXpert CSV ("Path", "Age"). When present, it supplies
                half of the "o" (age > 80) group.
            n_patients: cap on the size of each group.

        Returns:
            dict with keys "m" (male), "f" (female), "y" (age < 20) and "o" (age > 80).
            Each value is a DataFrame with a "Path" column.

        Raises:
            FileNotFoundError: if ``rsna_csv`` does not exist.
        """
        if not os.path.exists(rsna_csv):
            raise FileNotFoundError(f"The path '{rsna_csv}' does not exist.")

        rsna = pd.read_csv(rsna_csv)
        rsna["Image Index"] = "../datasets/rsna/" + rsna["Image Index"]
        rsna = rsna.rename(columns={"Image Index": "Path", "Patient Age": "Age", "Patient Gender": "Sex"})

        male = rsna[rsna["Sex"] == "M"][:n_patients]
        female = rsna[rsna["Sex"] == "F"][:n_patients]
        young = rsna[rsna["Age"] < 20][:n_patients]

        half = n_patients // 2
        old = rsna[rsna["Age"] > 80][:half]
        if os.path.exists(cxpt_csv):
            # NOTE(review): CheXpert "Path" values are used unmodified (RSNA paths get a
            # "../datasets/rsna/" prefix) — confirm they resolve from the working directory.
            cxpt = pd.read_csv(cxpt_csv)
            old = pd.concat([old, cxpt[cxpt["Age"] > 80][:half]], ignore_index=True)

        return {"m": male, "f": female, "y": young, "o": old}

    def __learn_linearSVM__(self, d1, d2, df1, df2, key="Sex"):
        """Fit a linear SVM separating the latents in ``d1`` from those in ``d2``.

        Samples are labelled by group membership: 1 for ``d1``, 0 for ``d2``. The
        groups were built by filtering on ``key``, so membership is the target. Reading
        the raw column instead would make e.g. ages a many-class problem. For a binary
        LinearSVC, ``coef_[0]`` points towards class 1, i.e. from ``d2`` towards ``d1``.

        Args:
            d1, d2: sequences of latent tensors (one per image).
            df1, df2: the DataFrames the latents came from (kept for interface
                compatibility; not needed for labelling).
            key: attribute name, used only in the printed report.

        Returns:
            the fitted sklearn pipeline (single step ``"linearsvc"``).
        """
        seed = 42
        samples = list(d1) + list(d2)
        styles = np.array([style.numpy().flatten() for style in samples])
        labels = np.array([1] * len(d1) + [0] * len(d2))

        # Same permutation as np.random.seed(seed); np.random.shuffle(arange(n)),
        # but without resetting the global NumPy RNG.
        order = np.random.RandomState(seed).permutation(len(styles))
        styles, labels = styles[order], labels[order]
        self.w_shape = styles[0].shape

        X_train, X_test, y_train, y_test = train_test_split(styles, labels, test_size=0.2, random_state=seed)

        clf = make_pipeline(LinearSVC(random_state=0, tol=1e-5))
        clf.fit(X_train, y_train)

        y_pred = clf.predict(X_test)
        print(f"{key} hyperplane held-out accuracy: {accuracy_score(y_test, y_pred):.3f}")
        return clf

    def __get_hyperplanes__(self):
        """Populate ``self.sex_coeff`` and ``self.age_coeff``.

        Loads them from ``self.h_path`` when that file exists. Otherwise learns them
        from the patient CSVs and saves ``cat([sex, age])`` to ``self.h_path``, or to
        ``"hyperplanes.pt"`` when no path was given.
        """
        if self.h_path and os.path.exists(self.h_path):
            hyperplanes = torch.load(self.h_path, map_location=self.device)
            # The file stores the two equal-length directions back to back.
            self.sex_coeff, self.age_coeff = torch.chunk(hyperplanes, 2, dim=0)
        else:
            patient_data = self.__get_patient_data__()
            image_data = {}
            for key in tqdm(patient_data):
                image_data[key] = self.__load_cxr_data__(patient_data[key])

            sex_clf = self.__learn_linearSVM__(image_data["m"], image_data["f"],
                                               patient_data["m"], patient_data["f"])
            sex = sex_clf.named_steps['linearsvc'].coef_[0].reshape(self.w_shape)
            age_clf = self.__learn_linearSVM__(image_data["y"], image_data["o"],
                                               patient_data["y"], patient_data["o"], key="Age")
            age = age_clf.named_steps['linearsvc'].coef_[0].reshape(self.w_shape)

            self.sex_coeff = torch.from_numpy(sex).float().to(self.device)
            self.age_coeff = torch.from_numpy(age).float().to(self.device)
            torch.save(torch.cat([self.sex_coeff, self.age_coeff], dim=0), self.h_path or "hyperplanes.pt")
        print("Sex and Age coefficient loaded!")

    def __age__(self, w, step_size=-2, magnitude=1):
        """Return ``w`` shifted by ``step_size * magnitude`` along the age direction."""
        alpha = step_size * magnitude
        return w + alpha * self.age_coeff

    def __sex__(self, w, step_size=1, magnitude=1):
        """Return ``w`` shifted by ``step_size * magnitude`` along the sex direction."""
        alpha = step_size * magnitude
        return w + alpha * self.sex_coeff

    def augment_helper(self, embedding, rate=0.8, save_path="real_samples_agesex.png"):
        """With probability ``rate``, shift ``embedding`` along sex/age and decode it.

        Args:
            embedding: latent produced by ``self.encoder``.
            rate: probability that an augmentation is produced.
            save_path: where the generated image is also written (via
                ``utils.save_image``); pass None to skip writing.

        Returns:
            the generator output, or None when no augmentation was drawn.
        """
        # NOTE(review): reseeding the global RNGs from OS entropy discards any seed the
        # caller set; kept so existing augmentation randomness is unchanged.
        np.random.seed(None); random.seed(None)
        if not np.random.choice([True, False], p=[rate, 1 - rate]):
            return None

        w_ = self.__sex__(embedding, magnitude=random.randint(-4, 4))
        w_ = self.__age__(w_, magnitude=random.randint(-2, 2))
        synth, _ = self.generator([w_], input_is_latent=True)
        if save_path is not None:
            utils.save_image(synth, save_path, nrow=1, normalize=True)
        return synth

    def _denormalize(self, x):
        """Undo the ``Normalize`` step of ``self.transform`` on an (N, 3, H, W) tensor."""
        mean = torch.tensor(self._NORM_MEAN, device=x.device, dtype=x.dtype).view(1, -1, 1, 1)
        std = torch.tensor(self._NORM_STD, device=x.device, dtype=x.dtype).view(1, -1, 1, 1)
        return x * std + mean

    @staticmethod
    def _to_uint8_image(batch):
        """Convert an image batch to an (H, W, C) uint8 NumPy array via ``make_grid``.

        NOTE(review): assumes values in [0, 1]. If the generator emits [-1, 1], rescale
        with (x + 1) / 2 first — confirm against the Generator implementation.
        """
        grid = utils.make_grid(batch)
        return grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()

    def augment(self, x, rate=0.8):
        """Return a possibly sex/age-shifted version of image ``x`` as a uint8 (H, W, C) array.

        ``x`` is anything ``self.transform`` accepts (e.g. an H x W x 3 uint8 array).
        With probability ``rate``, the generator's re-synthesis is returned. Otherwise
        the resized input is returned, with the normalisation undone.
        """
        x = torch.unsqueeze(self.transform(x), 0).to(self.device)
        # Pure inference: no autograd graph is needed for either network.
        with torch.no_grad():
            embedding = self.encoder(x)
            aug_x = self.augment_helper(embedding, rate)

        if aug_x is not None:
            return self._to_uint8_image(aug_x)
        # ``x`` is ImageNet-normalised; scaling it by 255 directly would saturate.
        return self._to_uint8_image(self._denormalize(x))
|
|
|
if __name__ == "__main__":
    gca = GCA(h_path="hyperplanes.pt")

    sample_path = "../datasets/rsna/00000007_000.png"
    img = cv2.imread(sample_path)
    if img is None:
        # cv2.imread returns None rather than raising when the file is unreadable.
        raise FileNotFoundError(f"Could not read image '{sample_path}'.")

    # cv2 loads BGR. Convert to RGB so the encoder sees the same channel order as
    # GCA.__load_image__, which produced the images the hyperplanes were learned on.
    gca.augment(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
|
|
|
|
|
|