sebastiansarasti committed on
Commit 8b06175 · verified · 1 Parent(s): 419e81f

Upload 5 files

Files changed (5)
  1. app.py +80 -0
  2. loss.py +56 -0
  3. model.py +60 -0
  4. trainer.py +32 -0
  5. utils.py +32 -0
app.py ADDED
@@ -0,0 +1,80 @@
+ import streamlit as st
+ import os
+ from PIL import Image
+ from torchvision.models import vgg19, VGG19_Weights
+ from model import StyleTransferModel
+ from trainer import trainer_fn
+ from utils import process_image, tensor_to_image
+
+
+ base_model = vgg19(weights=VGG19_Weights.IMAGENET1K_V1).features
+ final_model = StyleTransferModel(base_model)
+
+ # define the title of the app
+ st.title('Style Transfer App')
+
+ # define the description of the app
+ st.write('This app applies the style of one image to another image. This can be used to create artistic images.')
+
+ # get all image files in the 'styles' folder
+ image_files = [f for f in os.listdir('styles') if f.lower().endswith(('png', 'jpg', 'jpeg', 'gif', 'bmp'))]
+
+ # display the style images
+ st.write('Select a style to apply to your image:')
+
+ # check how many images are available and set columns accordingly
+ num_images = len(image_files)
+ cols = st.columns(num_images)
+
+ # define the size to which the images will be resized (width, height)
+ resize_width = 300
+ resize_height = 300
+
+ # show each image in its corresponding column
+ for idx, img_file in enumerate(image_files):
+     with cols[idx]:
+         st.write(f"Style {idx + 1}")
+         img_path = f'styles/{img_file}'
+         img = Image.open(img_path)
+
+         # resize the image for display
+         img_resized = img.resize((resize_width, resize_height))
+
+         st.image(img_resized, use_container_width=True)
+
+ # create a file uploader for the content image
+ st.write('Upload the content image:')
+ content_image = st.file_uploader('Content Image', type=['png', 'jpg', 'jpeg'])
+
+ # create the selector to choose among the available styles
+ choice = st.selectbox('Select the style art:', [f'Style {i + 1}' for i in range(num_images)])
+
+ # create a button to run the model
+ if st.button('Apply Style Transfer'):
+     if content_image is not None:
+         # get the content image
+         content_img = Image.open(content_image)
+
+         # get the style image
+         style_choice = choice.split()[-1]  # extract the style number from "Style 1", "Style 2", etc.
+         style_img = Image.open(os.path.join('styles', image_files[int(style_choice) - 1]))
+
+         # preprocess the images
+         content_img = process_image(content_img)
+         style_img = process_image(style_img)
+
+         # run the optimization, starting from a copy of the content image
+         st.write('Applying Style Transfer...')
+         target_image = trainer_fn(
+             content_img, style_img, content_img.clone().requires_grad_(True), final_model
+         )
+
+         # convert the tensor back to an image
+         target_image = tensor_to_image(target_image.squeeze(0))
+
+         # display the result
+         st.write('Result:')
+         st.image(target_image, use_container_width=True)
+     else:
+         st.write('Please upload a content image')
+
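With the five files in place alongside a styles/ folder of reference images, the app runs locally in the usual Streamlit way (streamlit run app.py).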
loss.py ADDED
@@ -0,0 +1,56 @@
+ import torch
+ import torch.nn as nn
+
+ class StyleTransferLoss(nn.Module):
+     def __init__(self, model, content_img, style_img, device="cuda"):
+         super(StyleTransferLoss, self).__init__()
+         self.device = device
+         self.content_img = content_img.to(device)
+         self.style_img = style_img.to(device)
+         self.model = model.to(device)
+
+     def gram_matrix(self, feature_maps):
+         """
+         Calculate the Gram matrix for style features
+         """
+         B, C, H, W = feature_maps.size()
+         features = feature_maps.view(B * C, H * W)
+         G = torch.mm(features, features.t())
+         # normalize by the total number of elements
+         return G.div(B * C * H * W)
+
+     def get_features(self, image):
+         """
+         Get content and style features from the image
+         """
+         return self.model(image)
+
+     def content_loss(self, target_features, content_features):
+         """
+         Calculate the content loss between target and content features
+         """
+         return torch.mean((target_features - content_features) ** 2)
+
+     def style_loss(self, target_features, style_features):
+         """
+         Calculate the style loss between target and style features
+         """
+         loss = 0.0
+         for key in self.model.style_layers:
+             target_gram = self.gram_matrix(target_features[key])
+             style_gram = self.gram_matrix(style_features[key])
+             loss += torch.mean((target_gram - style_gram) ** 2)
+         return loss
+
+     def total_loss(
+         self, target_features, content_features, style_features, alpha=1, beta=1e8
+     ):
+         """
+         Calculate the total loss (weighted sum of content and style losses)
+         """
+         content = self.content_loss(
+             target_features["block4"], content_features["block4"]
+         )
+         style = self.style_loss(target_features, style_features)
+
+         return alpha * content + beta * style
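As a quick sanity check of the loss pieces, here is a minimal sketch (not part of the commit): it assumes the model.py and loss.py above are importable and uses a random tensor in place of real images.

import torch
from torchvision.models import vgg19, VGG19_Weights
from loss import StyleTransferLoss
from model import StyleTransferModel

# a random tensor standing in for a preprocessed image
dummy = torch.rand(1, 3, 224, 224)
model = StyleTransferModel(vgg19(weights=VGG19_Weights.IMAGENET1K_V1).features)
loss_fn = StyleTransferLoss(model=model, content_img=dummy, style_img=dummy, device="cpu")

features = loss_fn.get_features(dummy)
gram = loss_fn.gram_matrix(features["block1"])
print(gram.shape)  # torch.Size([64, 64]); block1 outputs 64 channels
print(loss_fn.total_loss(features, features, features).item())  # 0.0 when target, content, and style match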
model.py ADDED
@@ -0,0 +1,60 @@
+ import torch.nn as nn
+
+ class StyleTransferModel(nn.Module):
+     def __init__(self, base_model):
+         super(StyleTransferModel, self).__init__()
+         vgg19 = base_model
+         # freeze the parameters
+         for param in vgg19.parameters():
+             param.requires_grad = False
+
+         # split VGG19 into blocks for feature extraction
+         self.block1 = vgg19[:4]     # conv1_1, relu, conv1_2, relu
+         self.pool1 = vgg19[4]       # maxpool
+         self.block2 = vgg19[5:9]    # conv2_1, relu, conv2_2, relu
+         self.pool2 = vgg19[9]       # maxpool
+         self.block3 = vgg19[10:18]  # conv3_1 to relu3_4
+         self.pool3 = vgg19[18]      # maxpool
+         self.block4 = vgg19[19:27]  # conv4_1 to relu4_4
+         self.pool4 = vgg19[27]      # maxpool
+         self.block5 = vgg19[28:36]  # conv5_1 to relu5_4
+
+         # define content and style layers
+         self.content_layers = ["block4"]  # we use the output of block4 for content
+         self.style_layers = [
+             "block1",
+             "block2",
+             "block3",
+             "block4",
+             "block5",
+         ]  # all blocks for style
+
+     def forward(self, x):
+         # create a dict to save the results
+         features = {}
+
+         # block 1
+         x = self.block1(x)
+         features["block1"] = x
+         x = self.pool1(x)
+
+         # block 2
+         x = self.block2(x)
+         features["block2"] = x
+         x = self.pool2(x)
+
+         # block 3
+         x = self.block3(x)
+         features["block3"] = x
+         x = self.pool3(x)
+
+         # block 4
+         x = self.block4(x)
+         features["block4"] = x
+         x = self.pool4(x)
+
+         # block 5
+         x = self.block5(x)
+         features["block5"] = x
+
+         return features
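The block boundaries can be verified by pushing a random tensor through the model and printing the feature shapes; a minimal sketch, assuming the pretrained VGG19 weights are available for download:

import torch
from torchvision.models import vgg19, VGG19_Weights
from model import StyleTransferModel

model = StyleTransferModel(vgg19(weights=VGG19_Weights.IMAGENET1K_V1).features)
x = torch.rand(1, 3, 500, 500)  # same spatial size process_image produces
with torch.no_grad():
    features = model(x)
for name, f in features.items():
    print(name, tuple(f.shape))
# block1 (1, 64, 500, 500)
# block2 (1, 128, 250, 250)
# block3 (1, 256, 125, 125)
# block4 (1, 512, 62, 62)
# block5 (1, 512, 31, 31)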
trainer.py ADDED
@@ -0,0 +1,32 @@
+ from torch.optim import Adam
+ from loss import StyleTransferLoss
+ import torch
+
+ def trainer_fn(content, style, target_image, model):
+     optimizer = Adam([target_image], lr=0.1)
+     loss_fn = StyleTransferLoss(
+         model=model, content_img=content, style_img=style, device="cpu"
+     )
+
+     # content and style features are fixed, so compute them once without gradients
+     with torch.no_grad():
+         content_features = loss_fn.get_features(content.to("cpu"))
+         style_features = loss_fn.get_features(style.to("cpu"))
+
+     EPOCHS = 100
+     for epoch in range(EPOCHS):
+         # set the gradients to zero
+         optimizer.zero_grad()
+
+         # get the features of the target image
+         target_features = loss_fn.get_features(target_image)
+
+         # calculate the total loss
+         loss = loss_fn.total_loss(target_features, content_features, style_features)
+
+         # backpropagate
+         loss.backward()
+
+         # update the target image pixels (the VGG weights stay frozen)
+         optimizer.step()
+
+     return target_image
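Outside the Streamlit UI, the same optimization loop can be driven from a script; a minimal sketch (the image paths are hypothetical placeholders):

from PIL import Image
from torchvision.models import vgg19, VGG19_Weights
from model import StyleTransferModel
from trainer import trainer_fn
from utils import process_image, tensor_to_image

model = StyleTransferModel(vgg19(weights=VGG19_Weights.IMAGENET1K_V1).features)
content = process_image(Image.open("content.jpg"))       # hypothetical path
style = process_image(Image.open("styles/style1.jpg"))   # hypothetical path
result = trainer_fn(content, style, content.clone().requires_grad_(True), model)
tensor_to_image(result.squeeze(0)).save("stylized.jpg")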
utils.py ADDED
@@ -0,0 +1,32 @@
+ from torchvision import transforms
+ import torch
+
+
+ def process_image(image, shape=(500, 500)):
+     """
+     Transform a PIL image into a normalized tensor with a batch dimension
+     """
+     transform = transforms.Compose(
+         [
+             transforms.Resize(shape),
+             transforms.ToTensor(),
+             transforms.Normalize(
+                 mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+             ),
+         ]
+     )
+     # convert to RGB so Normalize always sees exactly 3 channels
+     image = transform(image.convert('RGB')).unsqueeze(0)
+     return image
+
+ def tensor_to_image(tensor):
+     """
+     Transform a tensor back into a PIL image
+     """
+     inverse_normalize = transforms.Normalize(
+         mean=[-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225],
+         std=[1 / 0.229, 1 / 0.224, 1 / 0.225],
+     )
+     tensor = inverse_normalize(tensor)
+     # detach so ToPILImage can convert a tensor that still requires grad
+     tensor = torch.clamp(tensor.detach(), 0, 1)
+     return transforms.ToPILImage()(tensor)
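A round trip through the two helpers shows the expected shapes; a minimal sketch using a solid-color stand-in image:

from PIL import Image
from utils import process_image, tensor_to_image

img = Image.new("RGB", (640, 480), color=(120, 30, 200))  # stand-in for a real photo
t = process_image(img)
print(t.shape)    # torch.Size([1, 3, 500, 500])
back = tensor_to_image(t.squeeze(0))
print(back.size)  # (500, 500)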