Spaces:
Running
Running
File size: 3,091 Bytes
aae9c6b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 |
import torch
import torch.nn as nn
from .genconvit_ed import GenConViTED
from .genconvit_vae import GenConViTVAE
from torchvision import transforms
class GenConViT(nn.Module):
    """Wrapper that runs the GenConViT ED network, the VAE network, or both.

    The ``net`` argument selects the mode:

    * ``'ed'``  -- only ``GenConViTED`` is built and used.
    * ``'vae'`` -- only ``GenConViTVAE`` is built and used.
    * anything else -- both networks are built; ``forward`` concatenates
      their outputs along dimension 0.

    Weights are read from ``weight/<name>.pth`` onto the CPU. Every loaded
    model is switched to eval mode and, when ``fp16`` is truthy, converted
    to half precision.
    """

    def __init__(self, config, ed, vae, net, fp16):
        """Build the selected sub-model(s) and load their weights.

        Args:
            config: Passed unchanged to ``GenConViTED`` / ``GenConViTVAE``.
            ed: Base name (no directory, no ``.pth``) of the ED weight file.
            vae: Base name (no directory, no ``.pth``) of the VAE weight file.
            net: Mode selector, see the class docstring.
            fp16: When truthy, loaded models are converted with ``.half()``.

        Raises:
            FileNotFoundError: A required ``weight/<name>.pth`` file is
                missing. The message names the missing path.
        """
        super().__init__()
        self.net = net
        self.fp16 = fp16

        # 'ed' and the combined mode need the ED network; 'vae' and the
        # combined mode need the VAE network.
        if self.net != 'vae':
            self.model_ed = GenConViTED(config)
            # The raw checkpoint is kept as an attribute, as before, in case
            # external code reads it.
            self.checkpoint_ed = self._load_weights(
                self.model_ed, ed, self.fp16
            )
        if self.net != 'ed':
            self.model_vae = GenConViTVAE(config)
            self.checkpoint_vae = self._load_weights(
                self.model_vae, vae, self.fp16
            )

    @staticmethod
    def _load_weights(model, name, fp16):
        """Load ``weight/<name>.pth`` into ``model`` and return the checkpoint.

        The checkpoint may either be a bare state dict or a dict holding the
        state dict under the ``'state_dict'`` key. After loading, ``model``
        is put in eval mode and optionally converted to half precision.
        """
        path = f'weight/{name}.pth'
        try:
            # NOTE(security): torch.load unpickles the file -- only load
            # checkpoints from a trusted source.
            checkpoint = torch.load(path, map_location=torch.device('cpu'))
        except FileNotFoundError as err:
            # Only the file read is guarded, so unrelated errors raised while
            # building the models are not misreported as a missing weight.
            raise FileNotFoundError(f"Error: {path} file not found.") from err

        if 'state_dict' in checkpoint:
            model.load_state_dict(checkpoint['state_dict'])
        else:
            model.load_state_dict(checkpoint)

        model.eval()
        if fp16:
            model.half()
        return checkpoint

    def forward(self, x):
        """Run ``x`` through the network(s) selected by ``self.net``.

        Returns:
            * ``'ed'``: the output of ``self.model_ed(x)``.
            * ``'vae'``: the first element of the tuple returned by
              ``self.model_vae(x)``.
            * otherwise: both of the above concatenated along dim 0 (ED
              result first, then VAE result).
        """
        if self.net == 'ed':
            return self.model_ed(x)
        if self.net == 'vae':
            # The VAE model returns a pair; only the first element is used.
            vae_out, _ = self.model_vae(x)
            return vae_out

        ed_out = self.model_ed(x)
        vae_out, _ = self.model_vae(x)
        # Results are stacked rather than averaged; the caller receives the
        # ED rows followed by the VAE rows.
        return torch.cat((ed_out, vae_out), dim=0)
|