Spaces: Build error
Sophie98 committed · Commit 3b83a8e · Parent(s): bf82406

fix error
Files changed:
- segmentation.py  (+30 -28)
- styleTransfer.py  (+16 -10)
segmentation.py CHANGED

@@ -7,34 +7,6 @@ import matplotlib.pyplot as plt
 from PIL import Image
 import segmentation_models as sm
 
-# model_path = "model_checkpoint.h5"
-# CLASSES = ['sofa']
-BACKBONE = 'resnet50'
-
-# define network parameters
-# n_classes = 1 if len(CLASSES) == 1 else (len(CLASSES) + 1)  # case for binary and multiclass segmentation
-# activation = 'sigmoid' if n_classes == 1 else 'softmax'
-preprocess_input = sm.get_preprocessing(BACKBONE)
-sm.set_framework('tf.keras')
-# LR = 0.0001
-
-# create model architecture
-# model = sm.Unet(BACKBONE, classes=n_classes, activation=activation)
-# # define optimizer
-# optim = keras.optimizers.Adam(LR)
-# # segmentation_models losses can be combined with '+' and scaled by an integer or float factor
-# dice_loss = sm.losses.DiceLoss()
-# focal_loss = sm.losses.BinaryFocalLoss() if n_classes == 1 else sm.losses.CategoricalFocalLoss()
-# total_loss = dice_loss + (1 * focal_loss)
-# # total_loss can also be imported directly from the library; the lines above just show how to combine losses
-# # total_loss = sm.losses.binary_focal_dice_loss  # or sm.losses.categorical_focal_dice_loss
-# metrics = [sm.metrics.IOUScore(threshold=0.5), sm.metrics.FScore(threshold=0.5)]
-# # compile the keras model with the defined optimizer, loss and metrics
-# model.compile(optim, total_loss, metrics)
-
-# # load model
-# model.load_weights(model_path)
-model = keras.models.load_model('model_final.h5', compile=False)
 
 def get_mask(image:Image) -> Image:
     """
@@ -46,6 +18,36 @@ def get_mask(image:Image) -> Image:
     Return:
         mask = corresponding mask of the image
     """
+
+    # model_path = "model_checkpoint.h5"
+    # CLASSES = ['sofa']
+    BACKBONE = 'resnet50'
+
+    # define network parameters
+    # n_classes = 1 if len(CLASSES) == 1 else (len(CLASSES) + 1)  # case for binary and multiclass segmentation
+    # activation = 'sigmoid' if n_classes == 1 else 'softmax'
+    preprocess_input = sm.get_preprocessing(BACKBONE)
+    sm.set_framework('tf.keras')
+    # LR = 0.0001
+
+    # create model architecture
+    # model = sm.Unet(BACKBONE, classes=n_classes, activation=activation)
+    # # define optimizer
+    # optim = keras.optimizers.Adam(LR)
+    # # segmentation_models losses can be combined with '+' and scaled by an integer or float factor
+    # dice_loss = sm.losses.DiceLoss()
+    # focal_loss = sm.losses.BinaryFocalLoss() if n_classes == 1 else sm.losses.CategoricalFocalLoss()
+    # total_loss = dice_loss + (1 * focal_loss)
+    # # total_loss can also be imported directly from the library; the lines above just show how to combine losses
+    # # total_loss = sm.losses.binary_focal_dice_loss  # or sm.losses.categorical_focal_dice_loss
+    # metrics = [sm.metrics.IOUScore(threshold=0.5), sm.metrics.FScore(threshold=0.5)]
+    # # compile the keras model with the defined optimizer, loss and metrics
+    # model.compile(optim, total_loss, metrics)
+
+    # # load model
+    # model.load_weights(model_path)
+    model = keras.models.load_model('model_final.h5', compile=False)
+
     test_img = np.array(image)  # cv2.imread(path, cv2.IMREAD_COLOR)
     test_img = cv2.resize(test_img, (640, 640))
     test_img = cv2.cvtColor(test_img, cv2.COLOR_RGB2BGR)
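A note on this change: moving the backbone setup and the keras.models.load_model call into get_mask means model_final.h5 is re-read from disk on every call. A minimal sketch of an alternative, assuming the same file layout; the helper name and the caching approach are hypothetical, not part of this commit:

import functools

import keras
import segmentation_models as sm

@functools.lru_cache(maxsize=1)
def _load_segmentation_model():
    # Hypothetical helper: run the one-time setup on the first call only,
    # then return the cached model on every later call, so repeated
    # get_mask() invocations do not reload model_final.h5 from disk.
    sm.set_framework('tf.keras')
    return keras.models.load_model('model_final.h5', compile=False)

get_mask() could then call _load_segmentation_model() instead of calling keras.models.load_model directly.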
styleTransfer.py CHANGED

@@ -34,19 +34,20 @@ def content_transform():
     transform = transforms.Compose(transform_list)
     return transform
 
-# Advanced options
-content_size=640
-style_size=640
 
-
-vgg.load_state_dict(torch.load(vgg_path))
-vgg = nn.Sequential(*list(vgg.children())[:44])
+def StyleTransformer(content_img: Image, style_img: Image):
 
-
-
-
+    # Advanced options
+    content_size=640
+    style_size=640
 
-
+    vgg = StyTR.vgg
+    vgg.load_state_dict(torch.load(vgg_path))
+    vgg = nn.Sequential(*list(vgg.children())[:44])
+
+    decoder = StyTR.decoder
+    Trans = transformer.Transformer()
+    embedding = StyTR.PatchEmbed()
 
     decoder.eval()
     Trans.eval()
@@ -105,6 +106,11 @@ def styleProjection(content_image,style_image):
     }])
 
     return Image.fromarray(np.uint8(result[0]['data'])[:,:,::-1]).convert('RGB')
+
+
+def create_styledSofa(content_image,style_image):
+    output = StyleTransformer(content_image,style_image)
+    return output
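For reference, a minimal usage sketch of the new create_styledSofa entry point, assuming StyleTransformer returns a PIL image the way styleProjection does (via Image.fromarray(...).convert('RGB')); the file names below are placeholders, not part of the repo:

from PIL import Image

from styleTransfer import create_styledSofa

content = Image.open('sofa.jpg').convert('RGB')    # photo to restyle
style = Image.open('pattern.jpg').convert('RGB')   # style reference
styled = create_styledSofa(content, style)         # assumed to return a PIL Image
styled.save('styled_sofa.png')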