Widium
finalize demos with examples
93b3b80
# **************************************************************************** #
# #
# ::: :::::::: #
# style_model.py :+: :+: :+: #
# +:+ +:+ +:+ #
# By: ebennace <ebennace@student.42lausanne.c +#+ +:+ +#+ #
# +#+#+#+#+#+ +#+ #
# Created: 2022/11/15 13:08:56 by ebennace #+# #+# #
# Updated: 2022/11/15 13:08:57 by ebennace ### ########.fr #
# #
# **************************************************************************** #
from typing import Tuple
import numpy as np
from time import time
from tqdm.auto import tqdm
from tensorflow.keras.optimizers import Adam
from .vgg import create_list_of_vgg_layer
from .vgg import create_multi_output_model
from .init import init_style_target
from .init import init_generated_img
from .style_function import update_style
# ======================================================================= #
class StyleRecreationModel:
    """
    Recreate the style of an input image with a VGG19-based multi-output model.

    The feature-extraction model and the list of style layers are built once in
    ``__init__``. Each call to ``recreate_style`` starts from a freshly
    initialised generated image and optimises it for a fixed number of epochs.
    """
    # ======================================================================= #
    def __init__(self, learning_rate: float = 0.02):
        """
        Build the multi-output feature extractor and the Adam optimizer.

        Args:
            `learning_rate` (float): Learning rate of the Adam optimizer.
                Defaults to 0.02 (the value previously hard-coded here).
        """
        self.optimizer = Adam(learning_rate=learning_rate)
        self.style_layers = create_list_of_vgg_layer()
        self.num_style_layers = len(self.style_layers)
        # Model whose outputs are the activations of the layers listed in
        # `style_layers` (built by the project helper `create_multi_output_model`).
        self.model = create_multi_output_model(self.style_layers)
    # ======================================================================= #
    def recreate_style(
        self,
        style_img_array: np.ndarray,
        num_epochs: int,
    ) -> Tuple[np.ndarray, float]:
        """
        Generate a new image reproducing the style of the given input image.

        Args:
            `style_img_array` (np.ndarray): The input style image.
            `num_epochs` (int): Number of optimisation steps; each step is one
                call to `update_style`. A value <= 0 performs no step.

        Returns:
            Tuple[np.ndarray, float]: The generated image and the time spent in
            the optimisation loop, in seconds, rounded to 2 decimals. The image
            is the object returned by `init_generated_img` (also stored in
            `self.generated_img`). NOTE(review): it is presumably a tf.Variable
            updated in place by `update_style` rather than a NumPy array;
            confirm before relying on the array type.
        """
        from time import perf_counter  # monotonic clock, suited to durations

        target_style = init_style_target(self.model, style_img_array)
        self.generated_img = init_generated_img(style_img_array)

        # Adam keeps a step counter and per-variable moment estimates. Every
        # call creates a brand-new generated image, so reusing the optimizer
        # from a previous call would carry stale state over. Keras' newer
        # optimizers also refuse to update a variable they were not built with.
        # Rebuild an optimizer of the same type and configuration instead.
        # NOTE(review): assumes update_style applies its gradients through
        # `optimizer`; confirm in style_function.update_style.
        self.optimizer = type(self.optimizer).from_config(
            self.optimizer.get_config()
        )

        start = perf_counter()
        for _ in tqdm(range(num_epochs)):
            update_style(
                model=self.model,
                style_target=target_style,
                generated_img=self.generated_img,
                optimizer=self.optimizer,
            )
        total_time = round(perf_counter() - start, 2)
        return (self.generated_img, total_time)
# ======================================================================= #