from enum import Enum

import torch

from model_classes import Model200M, Model5M, SyntheticV2
from model_transforms import transform_200M, transform_5M, transform_synthetic


class ModelType(str, Enum):
    """Identifiers of the detector models this module can load.

    Members are also ``str`` instances. A plain string equal to a member's
    value (e.g. ``"midjourney_5M"``) therefore compares and hashes equal to
    that member, and can be used directly as a key in the lookup tables below.
    """

    MIDJOURNEY_200M = "midjourney_200M"
    DIFFUSIONS_200M = "diffusions_200M"
    MIDJOURNEY_5M = "midjourney_5M"
    DIFFUSIONS_5M = "diffusions_5M"
    SYNTHETIC_DETECTOR_V2 = "synthetic_detector_v2"

    def __str__(self):
        """Return the raw value (e.g. ``"midjourney_5M"``) rather than ``ModelType.X``."""
        return str(self.value)

    @staticmethod
    def get_list():
        """Return the string values of all members, in definition order."""
        return [model_type.value for model_type in ModelType]


def load_model(value: ModelType):
    """Load the checkpoint for ``value`` into its model object and return it.

    The model object is taken from ``type_to_class``. That table holds
    *instances*, not classes, so the returned object is the same shared
    instance stored there. Calling this twice for the same type reloads the
    weights into that one object.

    The state dict is read from ``type_to_path[value]`` and mapped onto the
    CPU. The model is switched to eval mode before it is returned.

    Raises:
        KeyError: if ``value`` has no entry in ``type_to_class`` or
            ``type_to_path``.
    """
    model = type_to_class[value]
    path = type_to_path[value]
    # NOTE(review): torch.load unpickles the file. On PyTorch versions whose
    # default is weights_only=False, only load trusted checkpoints.
    ckpt = torch.load(path, map_location=torch.device('cpu'))
    model.load_state_dict(ckpt)
    model.eval()
    return model


# Model object per type. These are instances created at import time, not
# classes; load_model() loads weights into these exact objects.
type_to_class = {
    ModelType.MIDJOURNEY_200M: Model200M(),
    ModelType.DIFFUSIONS_200M: Model200M(),
    ModelType.MIDJOURNEY_5M: Model5M(),
    ModelType.DIFFUSIONS_5M: Model5M(),
    ModelType.SYNTHETIC_DETECTOR_V2: SyntheticV2(),
}

# Checkpoint file per type. These are relative paths, resolved against the
# current working directory. Presumably the process is started from the
# project root; confirm with the deployment setup.
type_to_path = {
    ModelType.MIDJOURNEY_200M: 'models/midjourney200M.pt',
    ModelType.DIFFUSIONS_200M: 'models/diffusions200M.pt',
    ModelType.MIDJOURNEY_5M: 'models/midjourney5M.pt',
    ModelType.DIFFUSIONS_5M: 'models/diffusions5M.pt',
    ModelType.SYNTHETIC_DETECTOR_V2: 'models/synthetic_detector_v2.pt',
}

# Every model is loaded eagerly at import time, in ModelType definition
# order. Deriving the keys from the enum guarantees no member is forgotten.
# A member missing from the tables above fails fast here with KeyError.
type_to_loaded_model = {
    model_type: load_model(model_type) for model_type in ModelType
}

# Input transform to apply before feeding an image to each model type.
type_to_transforms = {
    ModelType.MIDJOURNEY_200M: transform_200M,
    ModelType.DIFFUSIONS_200M: transform_200M,
    ModelType.MIDJOURNEY_5M: transform_5M,
    ModelType.DIFFUSIONS_5M: transform_5M,
    ModelType.SYNTHETIC_DETECTOR_V2: transform_synthetic,
}