StylizeApp / style.py
Pranav4datasc's picture
Upload 56 files
daa2ec6 verified
import argparse
import os
import sys
import time
import re
import numpy as np
import torch
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
import torch.onnx
import utils
from transformer_net import TransformerNet
from vgg import Vgg16
import streamlit as st
# Chosen once at import time: the CUDA device when torch reports it available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
@st.cache_resource
def load_model(model_path):
    """Build a TransformerNet and populate it from the checkpoint at *model_path*.

    The function is wrapped in ``st.cache_resource``, so the returned model
    object is handed to Streamlit's resource cache.

    Args:
        model_path: Location of the saved state dict, passed to ``torch.load``.

    Returns:
        The ``TransformerNet`` instance, moved to the module-level ``device``
        and switched to evaluation mode.
    """
    print('load model')
    with torch.no_grad():
        style_model = TransformerNet()
        # map_location remaps the stored tensors onto the active device, so a
        # checkpoint written from CUDA tensors still loads on a CPU-only host.
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(model_path, map_location=device)
        # remove saved deprecated running_* keys in InstanceNorm from the checkpoint
        stale_keys = [
            key for key in state_dict
            if re.search(r'in\d+\.running_(mean|var)$', key)
        ]
        for key in stale_keys:
            del state_dict[key]
        style_model.load_state_dict(state_dict)
        style_model.to(device)
        style_model.eval()
        return style_model
@st.cache_resource
def stylize(_style_model, content_image, output_image):
    """Run *_style_model* on one image and write the result to *output_image*.

    Args:
        _style_model: Callable model applied to the prepared image batch.
        content_image: Passed to ``utils.load_image`` to obtain the input image.
        output_image: Passed to ``utils.save_image`` as the save destination.

    Returns:
        None. The only result is the image written by ``utils.save_image``.

    NOTE(review): this function is wrapped in ``st.cache_resource`` and the
    model parameter is underscore-prefixed, so the cache key presumably
    depends only on the two path arguments. A repeated call with the same
    paths may skip the body entirely. Confirm this is intended.
    """
    source = utils.load_image(content_image)
    # Convert to a tensor, then multiply every value by 255.
    to_scaled_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255)),
    ])
    # unsqueeze(0) adds a leading dimension so the model receives a batch of one.
    batch = to_scaled_tensor(source).unsqueeze(0).to(device)
    with torch.no_grad():
        stylized = _style_model(batch)
        stylized = stylized.cpu()
    utils.save_image(output_image, stylized[0])