# awacke1 — commit 4e64785 "Update app.py" (4.54 kB, marked "No virus")
import numpy as np
import torch
import torch.nn as nn
import gradio as gr
from PIL import Image
import torchvision.transforms as transforms
# Normalisation-layer constructor shared by ResidualBlock and Generator. It is
# always called with just a channel count, so PyTorch's InstanceNorm2d
# defaults apply.
norm_layer = torch.nn.InstanceNorm2d
class ResidualBlock(nn.Module):
    """Residual unit computing ``x + F(x)``.

    ``F`` is: reflection-pad(1) -> 3x3 conv -> norm -> ReLU ->
    reflection-pad(1) -> 3x3 conv -> norm.  The channel count stays at
    ``in_features``, and the padding of 1 around each 3x3 convolution keeps
    the spatial size, so the input can be added back directly.

    The attribute name ``conv_block`` and the layer order determine the
    ``conv_block.<index>`` keys of ``state_dict``; checkpoints loaded with
    ``load_state_dict`` depend on them, so do not reorder the layers.
    """

    def __init__(self, in_features):
        super().__init__()
        channels = in_features
        self.conv_block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, 3),
            norm_layer(channels),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, 3),
            norm_layer(channels),
        )

    def forward(self, x):
        """Return ``x`` plus ``conv_block(x)`` (the skip connection)."""
        residual = self.conv_block(x)
        return x + residual
class Generator(nn.Module):
    """Image-to-image network: stem, downsampling, residual bottleneck, upsampling, head.

    The network is split into five ``nn.Sequential`` stages, ``model0`` ..
    ``model4``, applied in order.  Their attribute names and internal layer
    order fix the ``state_dict`` keys, which checkpoints loaded with
    ``load_state_dict`` depend on, so they must not be renamed or reordered.

    Args:
        input_nc: number of channels of the input image.
        output_nc: number of channels produced by the final convolution.
        n_residual_blocks: number of ``ResidualBlock`` units in ``model2``.
        sigmoid: when True, a ``Sigmoid`` follows the final convolution.
    """

    def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True):
        super().__init__()

        # Stage 0: reflection-padded 7x7 convolution to 64 feature maps.
        self.model0 = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, 64, 7),
            norm_layer(64),
            nn.ReLU(inplace=True),
        )

        # Stage 1: two stride-2 3x3 convolutions, 64 -> 128 -> 256 channels.
        down = []
        channels = 64
        for _ in range(2):
            down.extend([
                nn.Conv2d(channels, channels * 2, 3, stride=2, padding=1),
                norm_layer(channels * 2),
                nn.ReLU(inplace=True),
            ])
            channels *= 2
        self.model1 = nn.Sequential(*down)

        # Stage 2: residual bottleneck at the reduced resolution.
        self.model2 = nn.Sequential(
            *[ResidualBlock(channels) for _ in range(n_residual_blocks)]
        )

        # Stage 3: two stride-2 transposed convolutions, 256 -> 128 -> 64 channels.
        up = []
        for _ in range(2):
            up.extend([
                nn.ConvTranspose2d(
                    channels, channels // 2, 3,
                    stride=2, padding=1, output_padding=1,
                ),
                norm_layer(channels // 2),
                nn.ReLU(inplace=True),
            ])
            channels //= 2
        self.model3 = nn.Sequential(*up)

        # Stage 4: reflection-padded 7x7 projection to output_nc channels.
        head = [nn.ReflectionPad2d(3), nn.Conv2d(64, output_nc, 7)]
        if sigmoid:
            head.append(nn.Sigmoid())
        self.model4 = nn.Sequential(*head)

    def forward(self, x, cond=None):
        """Run ``x`` through ``model0`` .. ``model4`` in order; ``cond`` is accepted but unused."""
        for stage in (self.model0, self.model1, self.model2, self.model3, self.model4):
            x = stage(x)
        return x
def _load_generator(checkpoint_path):
    """Build a ``Generator(3, 1, 3)``, load CPU-mapped weights from *checkpoint_path*, and set eval mode.

    NOTE(review): ``torch.load`` unpickles the file, so only trusted
    checkpoints should be loaded here; newer PyTorch versions also accept
    ``weights_only=True`` for plain state dicts.
    """
    net = Generator(3, 1, 3)
    state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    net.load_state_dict(state)
    net.eval()
    return net


# Generator used by predict() for any choice other than 'Simple Lines'.
model1 = _load_generator('model.pth')
# Generator used by predict() for 'Simple Lines'.
model2 = _load_generator('model2.pth')
def predict(input_img, ver):
    """Turn an image file into a line drawing using one of the two generators.

    Args:
        input_img: path to the image file (the Gradio image input is declared
            with ``type='filepath'``).
        ver: ``'Simple Lines'`` selects ``model2``; any other value selects
            ``model1``.

    Returns:
        A ``PIL.Image`` made from the first item of the generator's output
        batch (single-channel for the ``output_nc=1`` models this file builds).

    Raises:
        ValueError: if no image path was supplied.
    """
    if input_img is None:
        raise ValueError("No input image was provided.")

    # Use a context manager so the file handle is released. Force 3-channel RGB
    # because both generators are built with input_nc=3; RGBA/palette PNGs,
    # greyscale or CMYK images would otherwise fail in the first convolution.
    # (Converting RGBA simply drops the alpha channel.)
    with Image.open(input_img) as source:
        rgb = source.convert('RGB')

    # Shorter side resized to 256 px with bicubic filtering, then scaled to [0, 1].
    preprocess = transforms.Compose([
        transforms.Resize(256, Image.BICUBIC),
        transforms.ToTensor(),
    ])
    batch = preprocess(rgb).unsqueeze(0)  # add a batch dimension of size 1

    generator = model2 if ver == 'Simple Lines' else model1
    with torch.no_grad():  # inference only; outputs carry no autograd graph
        drawing = generator(batch)[0]
    return transforms.ToPILImage()(drawing)
title = "Art Style Line Drawings - Complex and Simple Portraits and Landscapes"

# The description ends with the 47 consecutive emoji U+1F980 (crab) through
# U+1F9AE (guide dog). They are built from their code points so a mismatched
# file encoding cannot turn them into mojibake.
description = "Art Style Line Drawings " + "".join(
    chr(code_point) for code_point in range(0x1F980, 0x1F9AF)
)

# Each example is [image path, model choice], matching predict(input_img, ver).
examples = [
    ['QSHYNkOyhArcsgDrSFqq_15.625x.jpg', 'Simple Lines'],
    ['Xenomporh-art-scale-6_00x-gigapixel.png', 'Simple Lines'],
    ['Alien Chairs-art-scale-6_00x-gigapixel.png', 'Complex Lines'],
    ['Brain Coral B-gigapixel-art-scale-6_00x.jpg', 'Simple Lines'],
    ['Brain Coral-gigapixel-art-scale-6_00x.jpg', 'Complex Lines'],
    ['Dark Ritual Wisp Loop-art-scale-6_00x-gigapixel.png', 'Simple Lines'],
    ['Dungeons and Dragons Cartoon-art-scale-6_00x-gigapixel.png', 'Complex Lines'],
    ['Fantasy Art 2-art-scale-6_00x-gigapixel.png', 'Simple Lines'],
]

# Top-level components replace the gr.inputs / gr.outputs namespaces, which
# were deprecated in Gradio 3 and removed in Gradio 4 (Radio's `default=`
# became `value=`).
iface = gr.Interface(
    predict,
    [
        gr.Image(type='filepath'),
        gr.Radio(
            ['Complex Lines', 'Simple Lines'],
            type="value",
            value='Simple Lines',
            label='version',
        ),
    ],
    gr.Image(type="pil"),
    title=title,
    description=description,
    examples=examples,
)
iface.launch()