Files changed (1) hide show
  1. app.py +117 -0
app.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+
4
+ import os
5
+ import sys
6
+ base_path = os.path.expanduser('~')
7
+
8
+ sys.path.append(os.path.join(base_path, 'Er0mangaSeg/'))
9
+ sys.path.append(os.path.join(base_path, 'Er0mangaSeg/demo'))
10
+ from image_demo_tta import init_seg_model, inference_tta
11
+
12
+ sys.path.append(os.path.join(base_path, 'Er0mangaInpaint/'))
13
+ sys.path.append(os.path.join(base_path, 'Er0mangaInpaint/bin'))
14
+ from uncen import init_inpaint_model, inpaint
15
+
16
+
17
+ import time
18
+ import numpy as np
19
+ import cv2
20
+ import shutil
21
+ import torch
22
+
23
+
24
+ if torch.cuda.is_available():
25
+ print('GPU found!')
26
+ device = 'cuda:0'
27
+ else:
28
+ print('GPU not found! Using CPU')
29
+ device = 'cpu'
30
+
31
+
32
+ config = os.path.join(base_path, 'Er0mangaSeg/configs/convnext/convnext_h.py')
33
+ checkpoint = os.path.join(base_path, 'Er0mangaSeg/pretrained/convnext_1024_iter_400.pth')
34
+ model_seg = init_seg_model(config, checkpoint, device=device)
35
+ print('Segmentation initialized')
36
+
37
+
38
+ inp_model_path = os.path.join(base_path, 'Er0mangaInpaint/pretrained/00-30-09')
39
+ model_inp = init_inpaint_model(inp_model_path)
40
+ print('Inpainting initialized')
41
+
42
+
43
+ def proc(input_img):
44
+
45
+ try:
46
+
47
+ s = time.time()
48
+
49
+ out_mask, raw_mask = inference_tta(model_seg, input_img)
50
+ out_mask = np.dstack([out_mask, out_mask, out_mask])
51
+ raw_mask = np.dstack([raw_mask, raw_mask, raw_mask])
52
+
53
+ output_img, out_dbg = inpaint(model_inp, input_img, out_mask)
54
+
55
+ e = time.time()
56
+ print(f"proc_time: {e-s:.2f}")
57
+
58
+ return output_img#, raw_mask
59
+
60
+ except Exception as e:
61
+ raise gr.Error(e)
62
+
63
+
64
+ def proc_batch(batch):
65
+
66
+ res = []
67
+ try:
68
+
69
+ s = time.time()
70
+
71
+ out_p = os.path.dirname(batch[0][0])
72
+ salt = str(np.random.randint(1e10))
73
+ out_p_d = os.path.join(out_p, '__salt_img__'+salt)
74
+ out_p_m = os.path.join(out_p, '__salt_mask__'+salt)
75
+ os.mkdir(out_p_d)
76
+ os.mkdir(out_p_m)
77
+
78
+ for i in range(len(batch)):
79
+ input_path = batch[i][0]
80
+ inp_name = os.path.basename(input_path)
81
+ input_img = cv2.cvtColor(cv2.imread(input_path), cv2.COLOR_BGR2RGB)
82
+
83
+ out_mask, raw_mask = inference_tta(model_seg, input_img)
84
+ out_mask = np.dstack([out_mask, out_mask, out_mask])
85
+ raw_mask = np.dstack([raw_mask, raw_mask, raw_mask])
86
+
87
+ output_img, out_dbg = inpaint(model_inp, input_img, out_mask)
88
+ out_path_img = os.path.join(out_p_d, inp_name)
89
+ out_path_mask = os.path.join(out_p_m, inp_name+'.png')
90
+ cv2.imwrite(out_path_img, cv2.cvtColor(output_img, cv2.COLOR_BGR2RGB))
91
+ cv2.imwrite(out_path_mask, raw_mask)
92
+ res.append(out_path_img)
93
+
94
+ ar_path = os.path.join(out_p, 'output')
95
+ shutil.make_archive(ar_path, 'zip', out_p_d)
96
+
97
+ ar_path_m = os.path.join(out_p, 'output_mask')
98
+ shutil.make_archive(ar_path_m, 'zip', out_p_m)
99
+
100
+ e = time.time()
101
+ print(f"batch proc_time: {e-s:.2f}")
102
+
103
+ return res, ar_path + '.zip', ar_path_m + '.zip'
104
+
105
+ except Exception as e:
106
+ raise gr.Error(e)
107
+
108
+
109
+
110
+ demo1 = gr.Interface(proc, gr.Image(), gr.Image(format='png'), delete_cache=(7200, 7200), allow_flagging='never')
111
+ demo2 = gr.Interface(proc_batch, gr.Gallery(), [gr.Gallery(value='str', format='png'), gr.File(), gr.File()], delete_cache=(7200, 7200), allow_flagging='never')
112
+ demo = gr.TabbedInterface([demo1, demo2], ["Single image processing", "Batch processing (experimental)"])
113
+
114
+ if __name__ == "__main__":
115
+ demo.launch(server_name='0.0.0.0', server_port=7860)
116
+
117
+