Update app.py
app.py
CHANGED
@@ -14,6 +14,8 @@ from data import CustomDataLoader
 from data.super_dataset import SuperDataset
 from configs import parse_config
 from utils.augmentation import ImagePathToImage
+import clip
+from torchvision.transforms import Compose, Resize, ToTensor, Normalize, InterpolationMode
 
 
 class Stylizer(nn.Module):
@@ -118,6 +120,8 @@ def tensor2file(input_image):
     else:
         return image_pil
 
+
+device = "cuda"
 def generate_multi_model(input_img):
 
     # parse config
@@ -146,16 +150,15 @@ def generate_multi_model(input_img):
     dataset = SuperDataset(config)
     dataloader = CustomDataLoader(config, dataset)
 
-    device = "cuda"
     model_dict = torch.load("./pretrained_models/phase2_pretrain_90000.pth", map_location='cpu')
 
     # init netG
-
+    model = Stylizer(ngf=config['model']['ngf'], phase=2, model_weights=model_dict['G_ema_model']).to(device)
 
     for data in dataloader:
 
         real_A = data['test_A'].to(device)
-        fake_B =
+        fake_B = model(real_A, mixing=False)
         output_img = tensor2file(fake_B)  # get image results
 
     return output_img
@@ -167,6 +170,41 @@ def generate_one_shot(src_img, img_prompt):
-    output_img = src_img
-    return output_img
+    # init model
+    state_dict = torch.load(f"./checkpoints/{img_prompt[-2:]}/epoch_latest.pth", map_location='cpu')
+    model = Stylizer(ngf=64, phase=3, model_weights=state_dict['G_ema_model'])
+    model.to(device)
+    model.eval()
+    model.requires_grad_(False)
+
+    clip_model, img_preprocess = clip.load('ViT-B/32', device=device)
+    clip_model.eval()
+    clip_model.requires_grad_(False)
+
+    # image transform for stylizer
+    img_transform = Compose([
+        Resize((512, 512), interpolation=InterpolationMode.LANCZOS),
+        ToTensor(),
+        Normalize([0.5], [0.5])
+    ])
+
+    # get clip features
+    with torch.no_grad():
+        img = img_preprocess(Image.open(f"./example/reference/{img_prompt[-2:]}.png")).unsqueeze(0).to(device)
+        clip_feats = clip_model.encode_image(img)
+        clip_feats /= clip_feats.norm(dim=1, keepdim=True)
+
+    # load image & convert to tensor
+    img = Image.open(src_img)
+    if img.mode != 'RGB':
+        img = img.convert('RGB')
+    img = img_transform(img).unsqueeze(0).to(device)
+
+    # stylize it!
+    with torch.no_grad():
+        res = model(img, clip_feats=clip_feats)
+
+    output_img = tensor2file(res)  # get image results
+    return output_img
+
 
 def generate_zero_shot(src_img, txt_prompt):
     output_img = src_img
     return output_img
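The one-shot path added above conditions the phase-3 Stylizer on a CLIP embedding of a reference image. For readers unfamiliar with the openai/clip package, the sketch below isolates that feature-extraction step; it is a minimal illustration, with `reference.png` standing in for the `./example/reference/*.png` files used by the app:

import clip
import torch
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"

# ViT-B/32 maps an image to a 512-dim embedding; dividing by the norm
# yields a unit vector, so dot products between embeddings are cosine
# similarities. clip.load also returns the matching image preprocessing.
clip_model, preprocess = clip.load('ViT-B/32', device=device)
clip_model.eval()

with torch.no_grad():
    x = preprocess(Image.open("reference.png")).unsqueeze(0).to(device)  # 1x3x224x224
    feats = clip_model.encode_image(x)                                   # 1x512
    feats = feats / feats.norm(dim=1, keepdim=True)                      # unit-norm, as in the diff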
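`tensor2file` is defined earlier in app.py (only its final `else` branch is visible in the second hunk) and converts a model output tensor back to a PIL image. Given the `Normalize([0.5], [0.5])` input transform, its inverse plausibly looks like the following; this is a sketch under that assumption, not the app's actual implementation:

import torch
from PIL import Image

def tensor_to_pil(t: torch.Tensor) -> Image.Image:
    # Assumes a 1x3xHxW tensor in [-1, 1], as produced by a model trained
    # on inputs passed through Normalize([0.5], [0.5]).
    t = t.squeeze(0).detach().float().cpu()
    t = (t * 0.5 + 0.5).clamp(0.0, 1.0)  # invert Normalize: x * std + mean
    arr = (t.permute(1, 2, 0).numpy() * 255.0).round().astype("uint8")
    return Image.fromarray(arr)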
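As a Hugging Face Space, app.py presumably exposes these functions through a Gradio UI defined elsewhere in the file. The wiring below is hypothetical, not part of this commit; it does respect one constraint visible in the diff: `generate_one_shot` calls `Image.open(src_img)`, so the source image must arrive as a file path.

import gradio as gr

# Hypothetical hookup; the real interface lives elsewhere in app.py.
demo = gr.Interface(
    fn=generate_one_shot,
    inputs=[
        gr.Image(type="filepath", label="Source image"),
        # generate_one_shot keys the checkpoint and reference image off the
        # last two characters: ./checkpoints/{img_prompt[-2:]}/epoch_latest.pth
        gr.Textbox(label="Style id"),
    ],
    outputs=gr.Image(label="Stylized output"),
)

if __name__ == "__main__":
    demo.launch()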