xp3857 commited on
Commit
a44435e
1 Parent(s): fb97d52

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -164
app.py CHANGED
@@ -1,166 +1,52 @@
1
  import gradio as gr
2
- import os
3
- import cv2
4
- import shutil
5
- import sys
6
- from subprocess import call
7
- import torch
8
- import numpy as np
9
- from skimage import color
10
- import torchvision.transforms as transforms
11
- from PIL import Image
12
- import torch
13
- import uuid
14
-
15
- #os.system("pip install dlib")
16
- os.system('bash setup.sh')
17
-
18
- def lab2rgb(L, AB):
19
- """Convert an Lab tensor image to a RGB numpy output
20
- Parameters:
21
- L (1-channel tensor array): L channel images (range: [-1, 1], torch tensor array)
22
- AB (2-channel tensor array): ab channel images (range: [-1, 1], torch tensor array)
23
- Returns:
24
- rgb (RGB numpy image): rgb output images (range: [0, 255], numpy array)
25
- """
26
- AB2 = AB * 110.0
27
- L2 = (L + 1.0) * 50.0
28
- Lab = torch.cat([L2, AB2], dim=1)
29
- Lab = Lab[0].data.cpu().float().numpy()
30
- Lab = np.transpose(Lab.astype(np.float64), (1, 2, 0))
31
- rgb = color.lab2rgb(Lab) * 255
32
- return rgb
33
-
34
- def get_transform(model_name,params=None, grayscale=False, method=Image.BICUBIC):
35
- #params
36
- preprocess = 'resize'
37
- load_size = 256
38
- crop_size = 256
39
- transform_list = []
40
- if grayscale:
41
- transform_list.append(transforms.Grayscale(1))
42
- if model_name == "Pix2Pix Unet 256":
43
- osize = [load_size, load_size]
44
- transform_list.append(transforms.Resize(osize, method))
45
- # if 'crop' in preprocess:
46
- # if params is None:
47
- # transform_list.append(transforms.RandomCrop(crop_size))
48
-
49
- return transforms.Compose(transform_list)
50
-
51
- def inferRestoration(img, model_name):
52
- #if model_name == "Pix2Pix":
53
- model = torch.hub.load('manhkhanhad/ImageRestorationInfer', 'pix2pixRestoration_unet256')
54
- transform_list = [
55
- transforms.ToTensor(),
56
- transforms.Resize([256,256], Image.BICUBIC),
57
- transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
58
- ]
59
- transform = transforms.Compose(transform_list)
60
- img = transform(img)
61
- img = torch.unsqueeze(img, 0)
62
- result = model(img)
63
- result = result[0].detach()
64
- result = (result +1)/2.0
65
-
66
- result = transforms.ToPILImage()(result)
67
- return result
68
-
69
- def inferColorization(img):
70
- model_name = "Deoldify"
71
- model = torch.hub.load('manhkhanhad/ImageRestorationInfer', 'DeOldifyColorization')
72
- transform_list = [
73
- transforms.ToTensor(),
74
- transforms.Normalize((0.5,), (0.5,))
75
- ]
76
- transform = transforms.Compose(transform_list)
77
- #a = transforms.ToTensor()(a)
78
- img = img.convert('L')
79
- img = transform(img)
80
- img = torch.unsqueeze(img, 0)
81
- result = model(img)
82
-
83
- result = result[0].detach()
84
- result = (result +1)/2.0
85
-
86
- #img = transforms.Grayscale(3)(img)
87
- #img = transforms.ToTensor()(img)
88
- #img = torch.unsqueeze(img, 0)
89
- #result = model(img)
90
- #result = torch.clip(result, min=0, max=1)
91
- image_pil = transforms.ToPILImage()(result)
92
- return image_pil
93
-
94
- transform_seq = get_transform(model_name)
95
- img = transform_seq(img)
96
- # if model_name == "Pix2Pix Unet 256":
97
- # img.resize((256,256))
98
- img = np.array(img)
99
- lab = color.rgb2lab(img).astype(np.float32)
100
- lab_t = transforms.ToTensor()(lab)
101
- A = lab_t[[0], ...] / 50.0 - 1.0
102
- B = lab_t[[1, 2], ...] / 110.0
103
- #data = {'A': A, 'B': B, 'A_paths': "", 'B_paths': ""}
104
- L = torch.unsqueeze(A, 0)
105
- #print(L.shape)
106
- ab = model(L)
107
- Lab = lab2rgb(L, ab).astype(np.uint8)
108
- image_pil = Image.fromarray(Lab)
109
- #image_pil.save('test.png')
110
- #print(Lab.shape)
111
- return image_pil
112
-
113
- def colorizaition(image,model_name):
114
- image = Image.fromarray(image)
115
- result = inferColorization(image,model_name)
116
- return result
117
-
118
-
119
- def run_cmd(command):
120
- try:
121
- call(command, shell=True)
122
- except KeyboardInterrupt:
123
- print("Process interrupted")
124
- sys.exit(1)
125
-
126
- def run(image):
127
- uid = uuid.uuid4()
128
-
129
- if os.path.isdir(f"Temp{uid}"):
130
- shutil.rmtree(f"Temp{uid}")
131
-
132
- os.makedirs(f"Temp{uid}")
133
- os.makedirs(f"Temp{uid}/input")
134
- print(type(image))
135
- cv2.imwrite(f"Temp{uid}/input/input_img.png", image)
136
-
137
- command = ("python run.py --input_folder "
138
- + f"Temp{uid}/input"
139
- + " --output_folder "
140
- + f"Temp{uid}"
141
- + " --GPU "
142
- + "-1"
143
- + " --with_scratch")
144
- run_cmd(command)
145
-
146
- result_restoration = Image.open(f"Temp{uid}/final_output/input_img.png")
147
- shutil.rmtree(f"Temp{uid}")
148
-
149
- result_colorization = inferColorization(result_restoration)
150
-
151
- return result_colorization
152
- def load_im(url):
153
- return url
154
-
155
-
156
- with gr.Blocks() as app:
157
- im = gr.Image(label="Input Image")
158
  with gr.Row():
159
- im_u = gr.Textbox()
160
- lim_btn=gr.Button("Load")
161
- im_btn=gr.Button(label="Restore")
162
- out_im = gr.Image(label="Restored Image")
163
-
164
- #lim_btn(load_im,im_u,im)
165
- im_btn.click(run,[im,im_u],out_im)
166
- app.queue(concurrency_count=100).launch(show_api=False)
 
 
1
  import gradio as gr
2
+ from PIL import Image
3
+ import requests
4
+ import random
5
+ r = requests.get(f'https://huggingface.co/spaces/xp3857/bin/raw/main/css.css')
6
+ css = r.text
7
+ name2 = "xp3857/Image_Restoration_Colorization"
8
+ spaces=[
9
+ gr.Interface.load(f"spaces/{name2}"),
10
+ gr.Interface.load(f"spaces/{name2}"),
11
+ gr.Interface.load(f"spaces/{name2}"),
12
+ gr.Interface.load(f"spaces/{name2}"),
13
+ gr.Interface.load(f"spaces/{name2}"),
14
+ gr.Interface.load(f"spaces/{name2}"),
15
+ gr.Interface.load(f"spaces/{name2}"),
16
+ gr.Interface.load(f"spaces/{name2}"),
17
+ gr.Interface.load(f"spaces/{name2}"),
18
+ gr.Interface.load(f"spaces/{name2}"),
19
+ gr.Interface.load(f"spaces/{name2}"),
20
+ gr.Interface.load(f"spaces/{name2}"),
21
+ gr.Interface.load(f"spaces/{name2}"),
22
+ gr.Interface.load(f"spaces/{name2}"),
23
+ gr.Interface.load(f"spaces/{name2}"),
24
+ gr.Interface.load(f"spaces/{name2}"),
25
+ gr.Interface.load(f"spaces/{name2}"),
26
+ gr.Interface.load(f"spaces/{name2}"),
27
+ gr.Interface.load(f"spaces/{name2}"),
28
+ gr.Interface.load(f"spaces/{name2}"),
29
+ ]
30
+ def colorize(input):
31
+ if input !=None:
32
+ rn = random.randint(0, 19)
33
+ space=spaces[rn]
34
+ result=space(input)
35
+ out1 = gr.Pil.update(value=result,visible=True)
36
+ out2 = gr.Accordion.update(label="Original Image",open=False)
37
+ else:
38
+ out1 = None
39
+ out2 = None
40
+ pass
41
+ return out1, out2
42
+ with gr.Blocks(css=css) as myface:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  with gr.Row():
44
+ gr.Column()
45
+ with gr.Column():
46
+ with gr.Accordion(label="Input Image",open=True) as og:
47
+ in_win=gr.Pil(label="Input", type="filepath", interactive=True)
48
+ out_win=gr.Pil(label="Output",visible=False)
49
+ gr.Column()
50
+ in_win.change(rem_bg,in_win,[out_win,og])
51
+ myface.queue(concurrency_count=120)
52
+ myface.launch()