# ProgramSkripsi / model / genconvit.py
# Source: Yuuki0, commit e0c75d6 ("first commit")
import torch
import torch.nn as nn
from .genconvit_ed import GenConViTED
from .genconvit_vae import GenConViTVAE
from torchvision import transforms
class GenConViT(nn.Module):
    """Wrapper around the GenConViT ED and VAE sub-networks.

    Each sub-network is built from ``config`` and initialised from a
    checkpoint at ``weight/<name>.pth`` (relative to the current working
    directory). A sub-network is only attached when its checkpoint file was
    found and loaded, so ``forward`` never runs an untrained model by accident.

    Args:
        config: Passed unchanged to the ``GenConViTED`` / ``GenConViTVAE``
            constructors.
        ed: File stem of the ED checkpoint, or a falsy value to skip ED.
        vae: File stem of the VAE checkpoint, or a falsy value to skip VAE.
        net: Default mode for ``forward``: ``'ed'``, ``'vae'``, or any other
            value (e.g. ``'genconvit'``) to run both sub-networks.
        fp16: If true, loaded sub-networks are converted with ``.half()``.

    Raises:
        FileNotFoundError: if the checkpoint of a sub-network required by
            ``net`` is missing (``'ed'``/``'genconvit'`` require ED,
            ``'vae'``/``'genconvit'`` require VAE).
    """

    def __init__(self, config, ed, vae, net, fp16):
        super().__init__()
        self.net = net
        self.fp16 = fp16
        # None marks "not loaded"; forward() checks this before use.
        self.model_ed = None
        self.model_vae = None
        if ed:
            loaded = self._load_submodel(
                GenConViTED, config, ed, required=net in ('ed', 'genconvit'))
            if loaded is not None:
                self.model_ed, self.checkpoint_ed = loaded
        if vae:
            loaded = self._load_submodel(
                GenConViTVAE, config, vae, required=net in ('vae', 'genconvit'))
            if loaded is not None:
                self.model_vae, self.checkpoint_vae = loaded

    def _load_submodel(self, model_cls, config, weight_name, required):
        """Build ``model_cls(config)`` and load ``weight/<weight_name>.pth`` into it.

        The checkpoint may be either a raw state dict or a dict holding it
        under the ``'state_dict'`` key.

        Returns:
            ``(model, checkpoint)`` with the model in eval mode (and in half
            precision when ``self.fp16`` is set), or ``None`` when the file is
            missing and ``required`` is false.

        Raises:
            FileNotFoundError: if the file is missing and ``required`` is true.
        """
        path = f'weight/{weight_name}.pth'
        try:
            # NOTE(review): torch.load may unpickle arbitrary objects depending
            # on the torch version / weights_only default; only load trusted files.
            checkpoint = torch.load(path, map_location=torch.device('cpu'))
        except FileNotFoundError as err:
            if required:
                raise FileNotFoundError(f"Error: {path} file not found.") from err
            # Optional sub-network: leave it unloaded rather than keeping
            # a randomly initialised model around.
            return None
        # Only construct the model once we know its weights exist.
        model = model_cls(config)
        if 'state_dict' in checkpoint:
            model.load_state_dict(checkpoint['state_dict'])
        else:
            model.load_state_dict(checkpoint)
        model.eval()
        if self.fp16:
            model.half()
        return model, checkpoint

    def forward(self, x, net=None):
        """Run the selected sub-network(s) on ``x``.

        Args:
            x: Input passed unchanged to the sub-network(s).
            net: Mode override for this call; ``self.net`` is used when None.

        Returns:
            For ``'ed'`` the ED output; for ``'vae'`` the first element of the
            VAE output; for any other mode the ED and VAE outputs concatenated
            along dim 0.

        Raises:
            RuntimeError: if a sub-network needed by the mode is not loaded.
        """
        if net is None:
            net = self.net
        if net == 'ed':
            if self.model_ed is None:
                raise RuntimeError("ED model (AE) is not loaded. Ensure weights were provided during initialization.")
            x = self.model_ed(x)
        elif net == 'vae':
            if self.model_vae is None:
                raise RuntimeError("VAE model is not loaded. Ensure weights were provided during initialization.")
            x, _ = self.model_vae(x)
        else:  # 'genconvit' or 'both'
            if self.model_ed is None or self.model_vae is None:
                raise RuntimeError("Both ED and VAE models must be loaded for 'genconvit' mode.")
            x1 = self.model_ed(x)
            x2, _ = self.model_vae(x)
            # Outputs are stacked along dim 0 (presumably the batch dim —
            # TODO confirm with callers) instead of being averaged element-wise.
            x = torch.cat((x1, x2), dim=0)
        return x