"""Train, evaluate and run a binary ResNet-50 image classifier.

The script has four modes selected by command-line flags:

* ``--predict``: run the model over the ``test`` dataloader and write a CSV.
* ``--predict_image``: classify a single image given by ``--image_path``.
* ``--dir``: classify every file below a directory and write a CSV.
* default: train (optionally resuming from / resetting on ``--ckpt_path``).
"""

import argparse
import logging
import os

import pandas as pd
import pytorch_lightning as pl
import timm
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
from pytorch_lightning.callbacks import (
    EarlyStopping,
    ModelCheckpoint,
)
from sklearn.metrics import roc_auc_score
from torchmetrics import (
    Accuracy,
    Recall,
)

from .diffusion_data_loader import load_dataloader


class ImageClassifier(pl.LightningModule):
    """Binary classifier built on a pretrained timm ``resnet50``.

    The network emits a single logit per image; probabilities are obtained
    with a sigmoid. Batches are expected to be 3-tuples of
    ``(images, labels, extra)`` where the third element is only used (and
    returned unchanged) by :meth:`predict_step`.

    Args:
        lmd: Weight of the squared-logit penalty added to the training loss
            (called "SD loss" in the logs; presumably spectral decoupling).
            ``0`` disables the penalty.
    """

    def __init__(self, lmd=0):
        super().__init__()
        self.model = timm.create_model(
            "resnet50",
            pretrained=True,
            num_classes=1,
        )
        self.accuracy = Accuracy(task="binary", threshold=0.5)
        self.recall = Recall(task="binary", threshold=0.5)
        # Per-batch validation results, collected for the epoch-level AUC.
        self.validation_outputs = []
        self.lmd = lmd

    def forward(self, x):
        """Return the raw logits of the backbone for the image batch ``x``."""
        return self.model(x)

    def training_step(self, batch):
        """Compute BCE-with-logits loss plus the optional logit penalty."""
        images, labels, _ = batch
        # squeeze(-1) drops only the class dimension, so a batch of one
        # stays 1-D and still matches the shape of ``labels``.
        outputs = self.forward(images).squeeze(-1)
        print(f"Shape of outputs (training): {outputs.shape}")
        print(f"Shape of labels (training): {labels.shape}")
        loss = F.binary_cross_entropy_with_logits(outputs, labels.float())
        logging.info(f"Training Step - ERM loss: {loss.item()}")
        loss = loss + self.lmd * (outputs**2).mean()  # SD loss penalty
        logging.info(f"Training Step - SD loss: {loss.item()}")
        return loss

    def validation_step(self, batch):
        """Log loss/accuracy/recall for one batch and stash it for the AUC.

        Returns:
            A dict with keys ``val_loss``, ``preds`` (sigmoid probabilities)
            and ``labels``.
        """
        images, labels, _ = batch
        # Keep the batch dimension so single-sample batches are evaluated
        # instead of being collapsed to a 0-d tensor.
        outputs = self.forward(images).squeeze(-1)
        print(f"Shape of outputs (validation): {outputs.shape}")
        print(f"Shape of labels (validation): {labels.shape}")
        loss = F.binary_cross_entropy_with_logits(outputs, labels.float())
        preds = torch.sigmoid(outputs)
        self.log("val_loss", loss, prog_bar=True, sync_dist=True)
        self.log(
            "val_acc",
            self.accuracy(preds, labels.int()),
            prog_bar=True,
            sync_dist=True,
        )
        self.log(
            "val_recall",
            self.recall(preds, labels.int()),
            prog_bar=True,
            sync_dist=True,
        )
        output = {"val_loss": loss, "preds": preds, "labels": labels}
        self.validation_outputs.append(output)
        logging.info(f"Validation Step - Batch loss: {loss.item()}")
        return output

    def predict_step(self, batch):
        """Return ``(probabilities, label, domain)`` for one batch."""
        images, label, domain = batch
        # 1-D output even for a batch of one, so results can be concatenated.
        outputs = self.forward(images).squeeze(-1)
        preds = torch.sigmoid(outputs)
        return preds, label, domain

    def on_validation_epoch_end(self):
        """Compute and log the ROC AUC over all batches of the epoch.

        The collected batches are always discarded, including on the early
        exits, so they never leak into the next validation run.
        """
        outputs = self.validation_outputs
        self.validation_outputs = []

        if not outputs:
            logging.warning("No outputs in validation step to process")
            return

        preds = torch.cat([x["preds"] for x in outputs])
        labels = torch.cat([x["labels"] for x in outputs])

        # ROC AUC is undefined when only one class is present.
        if labels.unique().size(0) == 1:
            logging.warning("Only one class in validation step")
            return

        auc_score = roc_auc_score(labels.cpu(), preds.cpu())
        self.log("val_auc", auc_score, prog_bar=True, sync_dist=True)
        logging.info(f"Validation Epoch End - AUC score: {auc_score}")

    def configure_optimizers(self):
        """Return an Adam optimizer over the backbone with lr 0.0005."""
        optimizer = torch.optim.Adam(self.model.parameters(), lr=0.0005)
        return optimizer


def load_image(image_path, transform=None):
    """Open ``image_path`` as an RGB image and apply ``transform`` if given.

    Returns:
        The PIL image, or whatever ``transform`` returns for it.
    """
    image = Image.open(image_path).convert("RGB")
    if transform:
        image = transform(image)
    return image


def predict_single_image(image, model):
    """Return the sigmoid probability ``model`` assigns to one image.

    Args:
        image: A single, already transformed image tensor without a batch
            dimension (one is added here).
        model: The classifier; it is moved to CUDA when available, else CPU,
            and put into eval mode.

    Returns:
        The probability as a Python float.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    image = image.to(device)
    model.eval()
    with torch.no_grad():
        image = image.unsqueeze(0)
        output = model(image).squeeze()
        prediction = torch.sigmoid(output).item()
    return prediction


def _build_inference_transform():
    """Return the preprocessing used for single-image / directory inference."""
    return transforms.Compose(
        [
            transforms.CenterCrop((256, 256)),
            transforms.ToTensor(),
        ],
    )


def _build_arg_parser():
    """Return the command-line parser for this script."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--ckpt_path",
        help="checkpoint to continue from",
        required=False,
    )
    parser.add_argument(
        "--predict",
        help="predict on test set",
        action="store_true",
    )
    parser.add_argument("--reset", help="reset training", action="store_true")
    parser.add_argument(
        "--predict_image",
        help="predict the class of a single image",
        action="store_true",
    )
    parser.add_argument(
        "--image_path",
        help="path to the image to predict",
        type=str,
        required=False,
    )
    parser.add_argument(
        "--dir",
        help="path to the images to predict",
        type=str,
        required=False,
    )
    parser.add_argument(
        "--output_file",
        help="path to output file",
        type=str,
        required=False,
    )
    return parser


def _predict_test_set(args):
    """Run the checkpoint over the test dataloader and write a CSV."""
    test_dl = load_dataloader(
        [0, 1, 2, 3, 4],
        "test",
        batch_size=10,
        num_workers=1,
    )
    model = ImageClassifier.load_from_checkpoint(args.ckpt_path)
    trainer = pl.Trainer()
    predictions = trainer.predict(model, dataloaders=test_dl)
    preds, labels, domains = zip(*predictions)
    preds = torch.cat(preds).cpu().numpy()
    labels = torch.cat(labels).cpu().numpy()
    domains = torch.cat(domains).cpu().numpy()
    print(preds.shape, labels.shape, domains.shape)
    df = pd.DataFrame(
        {"preds": preds, "labels": labels, "domains": domains},
    )
    filename = "preds-" + os.path.basename(args.ckpt_path)
    # to_csv does not create missing parent directories.
    os.makedirs("outputs", exist_ok=True)
    df.to_csv(f"outputs/{filename}.csv", index=False)


def _predict_one_image(args):
    """Classify the single image at ``args.image_path`` and print the result."""
    image_path = args.image_path
    model = ImageClassifier.load_from_checkpoint(args.ckpt_path)
    image = load_image(image_path, _build_inference_transform())
    prediction = predict_single_image(image, model)
    print("prediction", prediction)
    # NOTE(review): this cut-off is 0.05 while directory mode and the
    # metrics use 0.5 - confirm whether the stricter value is intentional.
    print(
        f"Prediction for {image_path}: "
        f"{'Human' if prediction <= 0.05 else 'Generated'}",
    )


def _predict_directory(args):
    """Classify every readable image below ``args.dir`` and write a CSV."""
    predictions = []
    model = ImageClassifier.load_from_checkpoint(args.ckpt_path)
    transform = _build_inference_transform()
    for root, _dirs, files in os.walk(os.path.abspath(args.dir)):
        folder = os.path.basename(root)
        for f_name in files:
            path = os.path.join(root, f_name)
            print(f"Predicting: {path}")
            try:
                # The model needs a tensor, not a path: load and transform.
                image = load_image(path, transform)
            except OSError as err:
                # Covers unreadable files and non-image files
                # (PIL.UnidentifiedImageError is an OSError).
                logging.warning("Skipping %s: %s", path, err)
                continue
            p = predict_single_image(image, model)
            predictions.append([path, folder, p, p > 0.5])
            print(f"--predicted: {p}")
    df = pd.DataFrame(
        predictions,
        columns=["path", "folder", "pred", "class"],
    )
    df.to_csv(args.output_file, index=False)


def _train(args, train_domains, val_domains, lmd_value):
    """Train the classifier, logging to ``training.log``."""
    checkpoint_callback = ModelCheckpoint(
        monitor="val_loss",
        dirpath="./model_checkpoints/",
        filename="image-classifier-{step}-{val_loss:.2f}",
        save_top_k=3,
        mode="min",
        every_n_train_steps=1001,
        enable_version_counter=True,
    )
    early_stop_callback = EarlyStopping(
        monitor="val_loss",
        patience=4,
        mode="min",
    )
    logging.basicConfig(
        filename="training.log",
        filemode="w",
        level=logging.INFO,
        force=True,
    )
    train_dl = load_dataloader(
        train_domains,
        "train",
        batch_size=128,
        num_workers=4,
    )
    logging.info("Training dataloader loaded")
    val_dl = load_dataloader(
        val_domains,
        "val",
        batch_size=128,
        num_workers=4,
    )
    logging.info("Validation dataloader loaded")

    if args.reset:
        # Start a fresh run from the checkpoint's weights only. ``lmd`` is
        # not stored in the checkpoint, so pass it explicitly.
        model = ImageClassifier.load_from_checkpoint(
            args.ckpt_path,
            lmd=lmd_value,
        )
    else:
        model = ImageClassifier(lmd=lmd_value)

    trainer = pl.Trainer(
        callbacks=[checkpoint_callback, early_stop_callback],
        max_steps=20000,
        val_check_interval=1000,
        check_val_every_n_epoch=None,
    )
    trainer.fit(
        model=model,
        train_dataloaders=train_dl,
        val_dataloaders=val_dl,
        # With --reset the trainer state is not restored, only the weights.
        ckpt_path=args.ckpt_path if not args.reset else None,
    )


def main():
    """Parse the command line and dispatch to the selected mode."""
    parser = _build_arg_parser()
    args = parser.parse_args()

    train_domains = [0, 1, 4]
    val_domains = [0, 1, 4]
    lmd_value = 0

    inference_mode = (
        args.predict or args.predict_image or args.dir is not None
    )
    if (inference_mode or args.reset) and not args.ckpt_path:
        parser.error(
            "--ckpt_path is required with --predict, --predict_image, "
            "--dir and --reset",
        )

    if args.predict:
        _predict_test_set(args)
    elif args.predict_image:
        if not args.image_path:
            parser.error("--image_path is required with --predict_image")
        _predict_one_image(args)
    elif args.dir is not None:
        if not args.output_file:
            # DataFrame.to_csv(None) returns a string and writes nothing.
            parser.error("--output_file is required with --dir")
        _predict_directory(args)
    else:
        _train(args, train_domains, val_domains, lmd_value)


if __name__ == "__main__":
    main()