from __future__ import print_function import torch import process_stylization from photo_wct import PhotoWCT import gradio as gr from datetime import datetime # Load model model_path = './models/photo_wct.pth' p_wct = PhotoWCT() p_wct.load_state_dict(torch.load(model_path)) def run(content_img, style_img, cuda, post_processing, fast): now = datetime.now() dt_string = now.strftime("%d/%m/%Y %H:%M:%S") print("[TimeStamp] {}".format(dt_string)) if fast == 0: from photo_gif import GIFSmoothing p_pro = GIFSmoothing(r=35, eps=0.001) else: from photo_smooth import Propagator p_pro = Propagator() if cuda: p_wct.cuda(0) else: p_wct.to('cpu') output_img = process_stylization.stylization_gradio( stylization_module=p_wct, smoothing_module=p_pro, content_image=content_img, style_image=style_img, cuda=cuda, post_processing=post_processing ) return output_img if __name__ == '__main__': style = gr.Interface( fn=run, inputs=[ gr.Image(label='Content Image'), gr.Image(label='Stylize Image'), gr.Checkbox(value=True, label='Use CUDA'), gr.Checkbox(value=True, label='Post Processing'), gr.Radio(choices=["Guided Image Filtering (Fast)", "Photorealisitic Smoothing (Slow)"], value="Guided Image Filtering (Fast)", type="index", label="Algorithm", interactive=False), ], outputs=[gr.Image( type="pil", label="Result"), ] ) style.queue() style.launch()