chris1nexus committed on
Commit 302a151
1 Parent(s): 91421aa

First commit

app.py ADDED
@@ -0,0 +1,74 @@
+import glob
+
+import torch
+import gradio as gr
+from torchvision import transforms as T
+from torchvision.transforms import Compose, ToTensor, Normalize
+from torchvision.utils import make_grid
+from huggan.pytorch.cyclegan.modeling_cyclegan import GeneratorResNet
+
+
+def pred_pipeline(img, transforms):
+    # Keep the original (H, W) so the outputs can be resized back to the input size.
+    orig_shape = img.shape
+    x = transforms(img).unsqueeze(0)
+    with torch.no_grad():
+        # Simulated -> real translation, then real -> simulated reconstruction.
+        output_real = sim2real(x)
+        output_syn = real2sim(output_real)
+    out_img_real = make_grid(output_real, nrow=1, normalize=True)
+    out_img_syn = make_grid(output_syn, nrow=1, normalize=True)
+
+    out_transform = Compose([
+        T.Resize(orig_shape[:2]),
+        T.ToPILImage()
+    ])
+    return out_transform(out_img_real), out_transform(out_img_syn)
+
+
+n_channels = 3
+image_size = 512
+input_shape = (image_size, image_size)
+
+# Preprocessing: resize to the generators' input size and normalize to [-1, 1].
+transform = Compose([
+    T.ToPILImage(),
+    T.Resize(input_shape),
+    ToTensor(),
+    Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+])
+
+# CycleGAN generators trained for GTA5 (simulated) <-> Cityscapes (real) translation.
+sim2real = GeneratorResNet.from_pretrained('Chris1/sim2real-512',
+                                           input_shape=(n_channels, image_size, image_size),
+                                           num_residual_blocks=9)
+real2sim = GeneratorResNet.from_pretrained('Chris1/real2sim-512',
+                                           input_shape=(n_channels, image_size, image_size),
+                                           num_residual_blocks=9)
+
+gr.Interface(lambda image: pred_pipeline(image, transform),
+             inputs=gr.inputs.Image(label='input synthetic image'),
+             outputs=[
+                 gr.outputs.Image(type="pil", label='style transfer to the real world (generator G_AB, synthetic to real, applied to the chosen input)'),
+                 gr.outputs.Image(type="pil", label='real to synthetic translation (generator G_BA, real to synthetic, applied to the prediction of G_AB)')
+             ],
+             title="GTA5 (simulated) to Cityscapes (real) translation",
+             examples=[[example] for example in glob.glob('./samples/*.png')]
+             ).launch()
+
+#iface = gr.Interface(fn=greet, inputs="text", outputs="text")
+#iface.launch()
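As a quick sanity check outside the Gradio UI, the snippet below is a minimal sketch (not part of the commit) of calling pred_pipeline directly on one of the sample images added in this commit. It assumes the dependencies above are installed and reuses the transform, sim2real and real2sim objects defined in app.py; the output filenames are only illustrative.

# Minimal sketch, not part of the commit: run pred_pipeline on a bundled sample image.
# Assumes app.py's transform / sim2real / real2sim are already defined as above.
import numpy as np
from PIL import Image

# gr.inputs.Image hands the pipeline an HWC uint8 NumPy array, so mimic that here.
img = np.array(Image.open('./samples/00012.png').convert('RGB'))
real_pred, cycled_pred = pred_pipeline(img, transform)
real_pred.save('real_prediction.png')      # G_AB: simulated -> real
cycled_pred.save('cycled_prediction.png')  # G_BA applied to the G_AB output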
requirements.txt ADDED
@@ -0,0 +1,5 @@
+huggingface-hub
+numpy
+torch
+transformers
+git+https://github.com/huggingface/community-events@main
samples/00012.png ADDED
samples/00237.png ADDED
samples/08164.png ADDED
samples/11603.png ADDED
samples/11607.png ADDED
samples/12073.png ADDED
samples/12227.png ADDED
samples/12605.png ADDED
samples/18621.png ADDED
samples/19627.png ADDED