Spaces:
Build error
Build error
Sophie98
commited on
Commit
·
baaaa83
1
Parent(s):
ab7b996
find next error
Browse files- app.py +4 -0
- styleTransfer.py +1 -6
app.py
CHANGED
@@ -96,11 +96,15 @@ def style_sofa(input_img: np.ndarray, style_img: np.ndarray):
|
|
96 |
resized_style = resize_style(style_img)
|
97 |
resized_style.save('resized_style.jpg')
|
98 |
# generate mask for image
|
|
|
99 |
mask = get_mask(resized_img)
|
100 |
mask.save('mask.jpg')
|
|
|
|
|
101 |
styled_sofa = create_styledSofa(resized_img,resized_style)
|
102 |
styled_sofa.save('styled_sofa.jpg')
|
103 |
# postprocess the final image
|
|
|
104 |
new_sofa = replace_sofa(resized_img,mask,styled_sofa)
|
105 |
new_sofa = new_sofa.crop(box)
|
106 |
return new_sofa
|
|
|
96 |
resized_style = resize_style(style_img)
|
97 |
resized_style.save('resized_style.jpg')
|
98 |
# generate mask for image
|
99 |
+
print('generating mask...')
|
100 |
mask = get_mask(resized_img)
|
101 |
mask.save('mask.jpg')
|
102 |
+
# Created a styled sofa
|
103 |
+
print('Styling sofa...')
|
104 |
styled_sofa = create_styledSofa(resized_img,resized_style)
|
105 |
styled_sofa.save('styled_sofa.jpg')
|
106 |
# postprocess the final image
|
107 |
+
print('Replacing sofa...')
|
108 |
new_sofa = replace_sofa(resized_img,mask,styled_sofa)
|
109 |
new_sofa = new_sofa.crop(box)
|
110 |
return new_sofa
|
styleTransfer.py
CHANGED
@@ -20,7 +20,6 @@ embedding_path = 'embedding_iter_160000.pth'
|
|
20 |
|
21 |
def style_transform(h,w):
|
22 |
k = (h,w)
|
23 |
-
size = int(np.max(k))
|
24 |
transform_list = []
|
25 |
transform_list.append(transforms.CenterCrop((h,w)))
|
26 |
transform_list.append(transforms.ToTensor())
|
@@ -37,8 +36,6 @@ def content_transform():
|
|
37 |
content_size=640
|
38 |
style_size=640
|
39 |
|
40 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
41 |
-
|
42 |
vgg = StyTR.vgg
|
43 |
vgg.load_state_dict(torch.load(vgg_path))
|
44 |
vgg = nn.Sequential(*list(vgg.children())[:44])
|
@@ -70,14 +67,12 @@ content_tf = content_transform()
|
|
70 |
style_tf = style_transform(style_size,style_size)
|
71 |
|
72 |
def StyleTransformer(content_img: Image, style_img: Image):
|
73 |
-
|
74 |
network.to(device)
|
75 |
-
|
76 |
content = content_tf(content_img.convert("RGB"))
|
77 |
style = style_tf(style_img.convert("RGB"))
|
78 |
style = style.to(device).unsqueeze(0)
|
79 |
content = content.to(device).unsqueeze(0)
|
80 |
-
|
81 |
with torch.no_grad():
|
82 |
output= network(content,style)
|
83 |
output = output[0].cpu().squeeze()
|
|
|
20 |
|
21 |
def style_transform(h,w):
|
22 |
k = (h,w)
|
|
|
23 |
transform_list = []
|
24 |
transform_list.append(transforms.CenterCrop((h,w)))
|
25 |
transform_list.append(transforms.ToTensor())
|
|
|
36 |
content_size=640
|
37 |
style_size=640
|
38 |
|
|
|
|
|
39 |
vgg = StyTR.vgg
|
40 |
vgg.load_state_dict(torch.load(vgg_path))
|
41 |
vgg = nn.Sequential(*list(vgg.children())[:44])
|
|
|
67 |
style_tf = style_transform(style_size,style_size)
|
68 |
|
69 |
def StyleTransformer(content_img: Image, style_img: Image):
|
70 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
71 |
network.to(device)
|
|
|
72 |
content = content_tf(content_img.convert("RGB"))
|
73 |
style = style_tf(style_img.convert("RGB"))
|
74 |
style = style.to(device).unsqueeze(0)
|
75 |
content = content.to(device).unsqueeze(0)
|
|
|
76 |
with torch.no_grad():
|
77 |
output= network(content,style)
|
78 |
output = output[0].cpu().squeeze()
|