Upload 5 files
Browse files
- app.py +80 -0
- loss.py +56 -0
- model.py +60 -0
- trainer.py +32 -0
- utils.py +32 -0
app.py
ADDED
@@ -0,0 +1,80 @@
import streamlit as st
import os
from PIL import Image
from torchvision.models import vgg19
from model import StyleTransferModel
from trainer import trainer_fn
from utils import process_image, tensor_to_image


base_model = vgg19(pretrained=True).features
final_model = StyleTransferModel(base_model)

# define the title of the app
st.title('Style Transfer App')

# define the description of the app
st.write('This app applies the style of one image to another image. This can be used to create artistic images.')

# get all image files in the 'styles' folder
image_files = [f for f in os.listdir('styles') if f.lower().endswith(('png', 'jpg', 'jpeg', 'gif', 'bmp'))]

# display the images
st.write('Select the style art to apply to your image:')

# Check how many images are available and set columns accordingly
num_images = len(image_files)
cols = st.columns(num_images)

# Define the size to which the images will be resized (width, height)
resize_width = 300
resize_height = 300

# show each image in a corresponding column
for idx, img_file in enumerate(image_files):
    with cols[idx]:
        st.write(f"Style {idx + 1}")
        img_path = f'styles/{img_file}'
        img = Image.open(img_path)

        # Resize the image
        img_resized = img.resize((resize_width, resize_height))

        st.image(img_resized, use_container_width=True)

# create a file uploader for the content image
st.write('Upload the content image:')
content_image = st.file_uploader('Content Image', type=['png', 'jpg', 'jpeg'])

# create the dropdown to select the style image
choice = st.selectbox('Select the style art:', [f'Style {i + 1}' for i in range(num_images)])

# create a button to run the model
if st.button('Apply Style Transfer'):
    if content_image is not None:
        # get the content image
        content_img = Image.open(content_image)

        # get the style image
        style_choice = choice.split()[-1]  # Extract style number from "Style 1", "Style 2", etc.
        style_img = Image.open(os.path.join('styles', image_files[int(style_choice) - 1]))  # Get full path

        # preprocess the images
        content_img = process_image(content_img)
        style_img = process_image(style_img)

        # run the model
        st.write('Applying Style Transfer...')
        target_image = trainer_fn(
            content_img, style_img, content_img.clone().requires_grad_(True), final_model
        )

        # convert the tensor to image
        target_image = tensor_to_image(target_image.squeeze(0))

        # display the result
        st.write('Result:')
        st.image(target_image, use_container_width=True)
    else:
        st.write('Please upload a content image')
loss.py
ADDED
@@ -0,0 +1,56 @@
import torch
import torch.nn as nn

class StyleTransferLoss(nn.Module):
    def __init__(self, model, content_img, style_img, device="cuda"):
        super(StyleTransferLoss, self).__init__()
        self.device = device
        self.content_img = content_img.to(device)
        self.style_img = style_img.to(device)
        self.model = model.to(device)

    def gram_matrix(self, feature_maps):
        """
        Calculate the Gram matrix for style features
        """
        B, C, H, W = feature_maps.size()
        features = feature_maps.view(B * C, H * W)
        G = torch.mm(features, features.t())
        # Normalize by total elements
        return G.div(B * C * H * W)

    def get_features(self, image):
        """
        Get content and style features from the image
        """
        return self.model(image)

    def content_loss(self, target_features, content_features):
        """
        Calculate content loss between target and content features
        """
        return torch.mean((target_features - content_features) ** 2)

    def style_loss(self, target_features, style_features):
        """
        Calculate style loss between target and style features
        """
        loss = 0.0
        for key in self.model.style_layers:
            target_gram = self.gram_matrix(target_features[key])
            style_gram = self.gram_matrix(style_features[key])
            loss += torch.mean((target_gram - style_gram) ** 2)
        return loss

    def total_loss(
        self, target_features, content_features, style_features, alpha=1, beta=1e8
    ):
        """
        Calculate total loss (weighted sum of content and style losses)
        """
        content = self.content_loss(
            target_features["block4"], content_features["block4"]
        )
        style = self.style_loss(target_features, style_features)

        return alpha * content + beta * style
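For reference, a minimal standalone sketch (not part of the commit) of what gram_matrix computes: the feature map is flattened to (B*C, H*W), multiplied by its transpose, and normalized by the total number of elements, giving a (B*C, B*C) matrix of channel correlations. The 1x64x32x32 shape is just an arbitrary example.

import torch

feature_maps = torch.randn(1, 64, 32, 32)                 # example feature map (B, C, H, W)
B, C, H, W = feature_maps.size()
features = feature_maps.view(B * C, H * W)                # flatten the spatial dimensions
G = torch.mm(features, features.t()) / (B * C * H * W)    # normalized channel-correlation matrix
print(G.shape)  # torch.Size([64, 64])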
model.py
ADDED
@@ -0,0 +1,60 @@
import torch.nn as nn

class StyleTransferModel(nn.Module):
    def __init__(self, base_model):
        super(StyleTransferModel, self).__init__()
        vgg19 = base_model
        # Freeze the parameters
        for param in vgg19.parameters():
            param.requires_grad = False

        # Split VGG19 into blocks for feature extraction
        self.block1 = vgg19[:4]     # conv1_1, relu, conv1_2, relu
        self.pool1 = vgg19[4]       # maxpool
        self.block2 = vgg19[5:9]    # conv2_1, relu, conv2_2, relu
        self.pool2 = vgg19[9]       # maxpool
        self.block3 = vgg19[10:18]  # conv3_1 to relu3_4
        self.pool3 = vgg19[18]      # maxpool
        self.block4 = vgg19[19:27]  # conv4_1 to relu4_4
        self.pool4 = vgg19[27]      # maxpool
        self.block5 = vgg19[28:36]  # conv5_1 to relu5_4

        # Define content and style layers
        self.content_layers = ["block4"]  # We'll use the output of block4 for content
        self.style_layers = [
            "block1",
            "block2",
            "block3",
            "block4",
            "block5",
        ]  # All blocks for style

    def forward(self, x):
        # create a dict to save the results
        features = {}

        # Block 1
        x = self.block1(x)
        features["block1"] = x
        x = self.pool1(x)

        # Block 2
        x = self.block2(x)
        features["block2"] = x
        x = self.pool2(x)

        # Block 3
        x = self.block3(x)
        features["block3"] = x
        x = self.pool3(x)

        # Block 4
        x = self.block4(x)
        features["block4"] = x
        x = self.pool4(x)

        # Block 5
        x = self.block5(x)
        features["block5"] = x

        return features
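As a quick sanity check (not part of the commit), a sketch that feeds a dummy batch through StyleTransferModel and prints the shape of each feature block; the 500x500 input matches the default output of process_image, and each pooling step halves the spatial size between blocks.

import torch
from torchvision.models import vgg19
from model import StyleTransferModel

model = StyleTransferModel(vgg19(pretrained=True).features).eval()
with torch.no_grad():
    features = model(torch.randn(1, 3, 500, 500))  # dummy preprocessed batch
for name, f in features.items():
    print(name, tuple(f.shape))  # e.g. block1 (1, 64, 500, 500) ... block5 (1, 512, 31, 31)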
trainer.py
ADDED
@@ -0,0 +1,32 @@
from torch.optim import Adam
from loss import StyleTransferLoss
import torch

def trainer_fn(content, style, target_image, model):
    optimizer = Adam([target_image], lr=0.1)
    loss_fn = StyleTransferLoss(
        model=model, content_img=content, style_img=style, device="cpu"
    )

    with torch.no_grad():
        content_features = loss_fn.get_features(content.to("cpu"))
        style_features = loss_fn.get_features(style.to("cpu"))

    EPOCHS = 100
    for epoch in range(EPOCHS):
        # set the gradients to zero
        optimizer.zero_grad()

        # get the features of the target image
        target_features = loss_fn.get_features(target_image)

        # calculate the total loss
        loss = loss_fn.total_loss(target_features, content_features, style_features)

        # backpropagate
        loss.backward()

        # update the weights
        optimizer.step()

    return target_image
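A hypothetical smoke test (not part of the commit) for the training loop: small random tensors stand in for the preprocessed images so the 100-iteration loop stays reasonably fast on CPU; real inputs go through process_image first, as in app.py.

import torch
from torchvision.models import vgg19
from model import StyleTransferModel
from trainer import trainer_fn

model = StyleTransferModel(vgg19(pretrained=True).features)
content = torch.rand(1, 3, 64, 64)   # stand-in for a preprocessed content image
style = torch.rand(1, 3, 64, 64)     # stand-in for a preprocessed style image
result = trainer_fn(content, style, content.clone().requires_grad_(True), model)
print(result.shape)  # torch.Size([1, 3, 64, 64])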
utils.py
ADDED
@@ -0,0 +1,32 @@
from torchvision import transforms
import torch


def process_image(image, shape=(500, 500)):
    """
    This function takes a PIL image and transforms it into a normalized tensor
    """
    # ensure a 3-channel input; uploaded PNGs may carry an alpha channel
    image = image.convert("RGB")
    transform = transforms.Compose(
        [
            transforms.Resize(shape),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
            ),
        ]
    )
    image = transform(image).unsqueeze(0)
    return image

def tensor_to_image(tensor):
    """
    This function takes a tensor and transforms it back into a PIL image
    """
    inverse_normalize = transforms.Normalize(
        mean=[-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225],
        std=[1 / 0.229, 1 / 0.224, 1 / 0.225],
    )
    tensor = inverse_normalize(tensor)
    tensor = torch.clamp(tensor, 0, 1)
    return transforms.ToPILImage()(tensor)
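A minimal roundtrip sketch (not part of the commit) of the two helpers: preprocess a PIL image into a normalized 1x3x500x500 tensor, then undo the normalization with tensor_to_image. The file path is hypothetical, and small differences from the original come from resizing and clamping.

from PIL import Image
from utils import process_image, tensor_to_image

img = Image.open("styles/example.jpg")          # hypothetical path
tensor = process_image(img)                     # shape: (1, 3, 500, 500)
restored = tensor_to_image(tensor.squeeze(0))   # back to a PIL image
restored.save("roundtrip.jpg")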