xp3857 committed
Commit 9264924
1 Parent(s): 16ec9cd

Create app.py

Files changed (1): app.py (+166, -0)
app.py ADDED
@@ -0,0 +1,166 @@
import gradio as gr
import os
import cv2
import shutil
import sys
from subprocess import call
import torch
import numpy as np
from skimage import color
import torchvision.transforms as transforms
from PIL import Image
import uuid

#os.system("pip install dlib")
os.system('bash setup.sh')

def lab2rgb(L, AB):
    """Convert a Lab tensor image to an RGB numpy output.
    Parameters:
        L  (1-channel tensor array): L channel images (range: [-1, 1], torch tensor array)
        AB (2-channel tensor array): ab channel images (range: [-1, 1], torch tensor array)
    Returns:
        rgb (RGB numpy image): RGB output images (range: [0, 255], numpy array)
    """
    AB2 = AB * 110.0
    L2 = (L + 1.0) * 50.0
    Lab = torch.cat([L2, AB2], dim=1)
    Lab = Lab[0].data.cpu().float().numpy()
    Lab = np.transpose(Lab.astype(np.float64), (1, 2, 0))
    rgb = color.lab2rgb(Lab) * 255
    return rgb

def get_transform(model_name, params=None, grayscale=False, method=Image.BICUBIC):
    # params
    preprocess = 'resize'
    load_size = 256
    crop_size = 256
    transform_list = []
    if grayscale:
        transform_list.append(transforms.Grayscale(1))
    if model_name == "Pix2Pix Unet 256":
        osize = [load_size, load_size]
        transform_list.append(transforms.Resize(osize, method))
    # if 'crop' in preprocess:
    #     if params is None:
    #         transform_list.append(transforms.RandomCrop(crop_size))

    return transforms.Compose(transform_list)

def inferRestoration(img, model_name):
    # if model_name == "Pix2Pix":
    model = torch.hub.load('manhkhanhad/ImageRestorationInfer', 'pix2pixRestoration_unet256')
    transform_list = [
        transforms.ToTensor(),
        transforms.Resize([256, 256], Image.BICUBIC),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ]
    transform = transforms.Compose(transform_list)
    img = transform(img)
    img = torch.unsqueeze(img, 0)
    result = model(img)
    result = result[0].detach()
    result = (result + 1) / 2.0

    result = transforms.ToPILImage()(result)
    return result

def inferColorization(img):
    model_name = "Deoldify"
    model = torch.hub.load('manhkhanhad/ImageRestorationInfer', 'DeOldifyColorization')
    transform_list = [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ]
    transform = transforms.Compose(transform_list)
    #a = transforms.ToTensor()(a)
    img = img.convert('L')
    img = transform(img)
    img = torch.unsqueeze(img, 0)
    result = model(img)

    result = result[0].detach()
    result = (result + 1) / 2.0

    #img = transforms.Grayscale(3)(img)
    #img = transforms.ToTensor()(img)
    #img = torch.unsqueeze(img, 0)
    #result = model(img)
    #result = torch.clip(result, min=0, max=1)
    image_pil = transforms.ToPILImage()(result)
    return image_pil

    # NOTE: the code below is unreachable as written (the function returns above);
    # it appears to be a leftover Pix2Pix-style Lab colorization path.
    transform_seq = get_transform(model_name)
    img = transform_seq(img)
    # if model_name == "Pix2Pix Unet 256":
    #     img.resize((256,256))
    img = np.array(img)
    lab = color.rgb2lab(img).astype(np.float32)
    lab_t = transforms.ToTensor()(lab)
    A = lab_t[[0], ...] / 50.0 - 1.0
    B = lab_t[[1, 2], ...] / 110.0
    #data = {'A': A, 'B': B, 'A_paths': "", 'B_paths': ""}
    L = torch.unsqueeze(A, 0)
    #print(L.shape)
    ab = model(L)
    Lab = lab2rgb(L, ab).astype(np.uint8)
    image_pil = Image.fromarray(Lab)
    #image_pil.save('test.png')
    #print(Lab.shape)
    return image_pil

def colorizaition(image, model_name):
    image = Image.fromarray(image)
    # inferColorization() takes only the image; model selection is fixed inside it.
    result = inferColorization(image)
    return result

def run_cmd(command):
    try:
        call(command, shell=True)
    except KeyboardInterrupt:
        print("Process interrupted")
        sys.exit(1)

def run(image):
    # Write the uploaded image into a per-request temp folder, run the
    # scratch-removal / restoration script (run.py) on it, then colorize the result.
    uid = uuid.uuid4()

    if os.path.isdir(f"Temp{uid}"):
        shutil.rmtree(f"Temp{uid}")

    os.makedirs(f"Temp{uid}")
    os.makedirs(f"Temp{uid}/input")
    print(type(image))
    cv2.imwrite(f"Temp{uid}/input/input_img.png", image)

    command = ("python run.py --input_folder "
               + f"Temp{uid}/input"
               + " --output_folder "
               + f"Temp{uid}"
               + " --GPU "
               + "-1"
               + " --with_scratch")
    run_cmd(command)

    result_restoration = Image.open(f"Temp{uid}/final_output/input_img.png")
    shutil.rmtree(f"Temp{uid}")

    result_colorization = inferColorization(result_restoration)

    return result_colorization

def load_im(url):
    return url


with gr.Blocks() as app:
    im = gr.Image(label="Input Image")
    with gr.Row():
        im_u = gr.Textbox()
        lim_btn = gr.Button("Load")
    im_btn = gr.Button("Restore")
    out_im = gr.Image(label="Restored Image")

    #lim_btn.click(load_im, im_u, im)
    # run() expects a single image input, so only `im` is passed to the handler.
    im_btn.click(run, [im], out_im)
app.queue(concurrency_count=100).launch(show_api=False)
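
For reference, the short sketch below (not part of app.py) shows the Lab denormalization that lab2rgb() above performs, using random placeholder tensors in place of real model output; the shapes follow the slicing done in the unreachable branch of inferColorization().

import torch
import numpy as np
from skimage import color

# Placeholder tensors standing in for model input/output (random values, shape checking only).
L = torch.rand(1, 1, 256, 256) * 2 - 1    # L channel normalized to [-1, 1], shape (N, 1, H, W)
ab = torch.rand(1, 2, 256, 256) * 2 - 1   # ab channels normalized to [-1, 1], shape (N, 2, H, W)

# Same denormalization as lab2rgb(): L back to [0, 100], ab back to [-110, 110].
Lab = torch.cat([(L + 1.0) * 50.0, ab * 110.0], dim=1)[0].cpu().numpy()
rgb = color.lab2rgb(np.transpose(Lab.astype(np.float64), (1, 2, 0))) * 255
print(rgb.shape)  # (256, 256, 3), float values in [0, 255]; out-of-gamut colors are clipped by skimage

The 50/110 scaling is the inverse of the normalization used when the Lab input is built in that branch (A = L / 50 - 1, B = ab / 110).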