import os
from flask import Flask, render_template, request
from werkzeug.utils import secure_filename
import torch
from models.Transform_net import TransFormerNet
import utils

UPLOAD_FOLDER = os.path.join('static', 'uploads') # Folder to save the uploaded input image
os.makedirs(UPLOAD_FOLDER, exist_ok = True)
ALLOWED_EXTENSIONS = {'jpg'}
app = Flask(__name__)
app.config['UPLOAD'] = UPLOAD_FOLDER

binaries_path = os.path.join('models', 'binaries')
img_save_path = os.path.join('static', 'output_images') # Folder to save the stylized image
Inference_config = dict()
Inference_config['save_folder'] = img_save_path
os.makedirs(img_save_path, exist_ok = True)

def allowed_file(filename):
    return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS

def get_default_device() -> torch.device:
    """Use GPU if available, else CPU"""
    if torch.cuda.is_available():
        for i in range(torch.cuda.device_count()):
            print(torch.cuda.get_device_properties(i))
        return torch.device('cuda')
    else:
        return torch.device('cpu')

def stylize(Inference_config, model_name):
    """Stylize the input image and save it to the output folder."""
    device = get_default_device()

    # Initialize the transformer network that stylizes the image.
    style_net = TransFormerNet().to(device)
    trained_state = torch.load(os.path.join(binaries_path, model_name), map_location = torch.device('cpu'))
    binary = trained_state['state_dict']
    style_net.load_state_dict(binary, strict = True)
    style_net.eval()

    with torch.no_grad():
        content_img_path = Inference_config['image_path']
        content_img = utils.process_img(content_img_path, target_shape = 700)
        content_img = content_img.to(device)
        stylized_img = style_net(content_img).detach().cpu().numpy().squeeze(0)
        utils.save_and_display(Inference_config, stylized_img)

@app.route('/', methods = ['GET', 'POST'])
def upload_file():
    """Accept an input image and show the stylized image."""
    if request.method == 'POST':
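        file = request.files.get('image') # None when no file was submitted.
        # Keep only the base name so the model name cannot point outside binaries_path.
        model_name = os.path.basename(request.form.get('Modelname', ''))
        if file and model_name and allowed_file(file.filename):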
            filename = secure_filename(file.filename)
            Inference_config['content_img_name'] = filename
            Inference_config['image_path'] = os.path.join(app.config['UPLOAD'], filename)
            file.save(os.path.join(app.config['UPLOAD'], filename))
            stylize(Inference_config, model_name)
            image = os.path.join(Inference_config['save_folder'], f"Stylized-image-{filename.split('.')[0]}.jpg")
            return render_template('render.html', image = image)
    return render_template('render.html')

if __name__ == '__main__':
    # Note: debug = True enables the Werkzeug debugger; turn it off before exposing the app publicly.
    app.run(debug = True, host = "0.0.0.0")