chris1nexus commited on
Commit
601fb3d
1 Parent(s): 5f0e2e7

Add samples

Browse files
app.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+
4
+ from PIL import Image
5
+ from torchvision import transforms as T
6
+ from torchvision.transforms import Compose, Resize, ToTensor, Normalize, RandomCrop, RandomHorizontalFlip
7
+ from torchvision.utils import make_grid
8
+ from torch.utils.data import DataLoader
9
+ from huggan.pytorch.cyclegan.modeling_cyclegan import GeneratorResNet
10
+ import torch.nn as nn
11
+ import torch
12
+ import gradio as gr
13
+
14
+ from collections import OrderedDict
15
+ import glob
16
+
17
+
18
+
19
+
20
+ def pred_pipeline(img, transforms):
21
+ orig_shape = img.shape
22
+ input = transforms(img)
23
+ input = input.unsqueeze(0)
24
+ output_syn = real2sim(input)
25
+ output_real = sim2real(output_syn)
26
+ out_img_syn = make_grid(output_syn,
27
+ nrow=1, normalize=True)
28
+ out_img_real = make_grid(output_real,
29
+ nrow=1, normalize=True)
30
+
31
+
32
+
33
+ out_transform = Compose([
34
+ T.Resize(orig_shape[:2]),
35
+ T.ToPILImage()
36
+ ])
37
+ return out_transform(out_img_syn), out_transform(out_img_real)
38
+
39
+
40
+
41
+
42
+ n_channels = 3
43
+ image_size = 512
44
+ input_shape = (image_size, image_size)
45
+
46
+ transform = Compose([
47
+ T.ToPILImage(),
48
+ T.Resize(input_shape),
49
+ ToTensor(),
50
+ Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
51
+ ])
52
+
53
+
54
+ sim2real = GeneratorResNet.from_pretrained('Chris1/sim2real-512', input_shape=(n_channels, image_size, image_size),
55
+ num_residual_blocks=9)
56
+ real2sim = GeneratorResNet.from_pretrained('Chris1/real2sim-512', input_shape=(n_channels, image_size, image_size),
57
+ num_residual_blocks=9)
58
+
59
+ gr.Interface(lambda image: pred_pipeline(image, transform),
60
+ inputs=gr.inputs.Image( label='input synthetic image'),
61
+ outputs=[
62
+ gr.outputs.Image( type="pil",label='GAN real2sim prediction: style transfer of the input to the synthetic world '),
63
+ gr.outputs.Image( type="pil",label='GAN sim2real prediction: translation to real of the above prediction')
64
+ ],#plot,
65
+ title = "Cityscapes (real) to GTA5(simulated) translation",
66
+ examples = [
67
+ [example] for example in glob.glob('./samples/*.png')
68
+ ])\
69
+ .launch()
70
+
71
+
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
1
+ huggingface-hub
2
+ numpy
3
+ torch
4
+ transformers
5
+ git+https://github.com/huggingface/community-events@main
samples/berlin_000008_000019_leftImg8bit.png ADDED
samples/berlin_000009_000019_leftImg8bit.png ADDED
samples/berlin_000010_000019_leftImg8bit.png ADDED
samples/berlin_000011_000019_leftImg8bit.png ADDED
samples/berlin_000012_000019_leftImg8bit.png ADDED
samples/berlin_000013_000019_leftImg8bit.png ADDED
samples/berlin_000014_000019_leftImg8bit.png ADDED
samples/berlin_000049_000019_leftImg8bit.png ADDED
samples/berlin_000050_000019_leftImg8bit.png ADDED
samples/berlin_000051_000019_leftImg8bit.png ADDED
samples/berlin_000052_000019_leftImg8bit.png ADDED