dongyi committed on
Commit
a65915c
1 Parent(s): 3d2dfd7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -3
app.py CHANGED
@@ -14,6 +14,8 @@ from data import CustomDataLoader
14
  from data.super_dataset import SuperDataset
15
  from configs import parse_config
16
  from utils.augmentation import ImagePathToImage
 
 
17
 
18
 
19
  class Stylizer(nn.Module):
@@ -118,6 +120,8 @@ def tensor2file(input_image):
118
  else:
119
  return image_pil
120
 
 
 
121
  def generate_multi_model(input_img):
122
 
123
  # parse config
@@ -146,16 +150,15 @@ def generate_multi_model(input_img):
146
  dataset = SuperDataset(config)
147
  dataloader = CustomDataLoader(config, dataset)
148
 
149
- device = "cuda"
150
  model_dict = torch.load("./pretrained_models/phase2_pretrain_90000.pth", map_location='cpu')
151
 
152
  # init netG
153
- netG = Stylizer(ngf=config['model']['ngf'], phase=2, model_weights=model_dict['G_ema_model']).to(device)
154
 
155
  for data in dataloader:
156
 
157
  real_A = data['test_A'].to(device)
158
- fake_B = netG(real_A, mixing=False)
159
  output_img = tensor2file(fake_B) # get image results
160
 
161
  return output_img
@@ -167,6 +170,44 @@ def generate_one_shot(src_img, img_prompt):
167
  output_img = src_img
168
  return output_img
169
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
  def generate_zero_shot(src_img, txt_prompt):
171
  output_img = src_img
172
  return output_img
 
14
  from data.super_dataset import SuperDataset
15
  from configs import parse_config
16
  from utils.augmentation import ImagePathToImage
17
+ import clip
18
+ from torchvision.transforms import Compose, Resize, ToTensor, Normalize, InterpolationMode
19
 
20
 
21
  class Stylizer(nn.Module):
 
120
  else:
121
  return image_pil
122
 
123
+
124
+ device = "cuda"
125
  def generate_multi_model(input_img):
126
 
127
  # parse config
 
150
  dataset = SuperDataset(config)
151
  dataloader = CustomDataLoader(config, dataset)
152
 
 
153
  model_dict = torch.load("./pretrained_models/phase2_pretrain_90000.pth", map_location='cpu')
154
 
155
  # init netG
156
+ model = Stylizer(ngf=config['model']['ngf'], phase=2, model_weights=model_dict['G_ema_model']).to(device)
157
 
158
  for data in dataloader:
159
 
160
  real_A = data['test_A'].to(device)
161
+ fake_B = model(real_A, mixing=False)
162
  output_img = tensor2file(fake_B) # get image results
163
 
164
  return output_img
 
170
  output_img = src_img
171
  return output_img
172
 
173
+ # init model
174
+ state_dict = torch.load(f"./checkpoints/{img_prompt[-2:]}/epoch_latest.pth", map_location='cpu')
175
+ model = Stylizer(ngf=64, phase=3, model_weights=state_dict['G_ema_model'])
176
+ model.to(device)
177
+ model.eval()
178
+ model.requires_grad_(False)
179
+
180
+ clip_model, img_preprocess = clip.load('ViT-B/32', device=args.device)
181
+ clip_model.eval()
182
+ clip_model.requires_grad_(False)
183
+
184
+ # image transform for stylizer
185
+ img_transform = Compose([
186
+ Resize((512, 512), interpolation=InterpolationMode.LANCZOS),
187
+ ToTensor(),
188
+ Normalize([0.5], [0.5])
189
+ ])
190
+
191
+ # get clip features
192
+ with torch.no_grad():
193
+ img = img_preprocess(Image.open(f"./example/reference/{img_prompt[-2:]}.png")).unsqueeze(0).to(args.device)
194
+ clip_feats = clip_model.encode_image(img)
195
+ clip_feats /= clip_feats.norm(dim=1, keepdim=True)
196
+
197
+
198
+ # load image & to tensor
199
+ img = Image.open(src_img)
200
+ if not img.mode == 'RGB':
201
+ img = img.convert('RGB')
202
+ img = img_transform(img).unsqueeze(0).to(device)
203
+
204
+ # stylize it !
205
+ with torch.no_grad():
206
+ res = model(img, clip_feats=clip_feats)
207
+
208
+ output_img = tensor2file(res) # get image results
209
+ return output_img
210
+
211
  def generate_zero_shot(src_img, txt_prompt):
212
  output_img = src_img
213
  return output_img