Nick088 committed
Commit 9bdccea
1 Parent(s): 61c9fcc

fix of videos, and better code by gdr/daroche

Files changed (2)
  1. app.py +5 -107
  2. infer.py +260 -0
app.py CHANGED
@@ -1,114 +1,14 @@
- import torch
- from PIL import Image
- from RealESRGAN import RealESRGAN
  import gradio as gr
- import os
- from random import randint
- import shutil
-
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
- model2 = RealESRGAN(device, scale=2)
- model2.load_weights('weights/RealESRGAN_x2.pth', download=True)
- model4 = RealESRGAN(device, scale=4)
- model4.load_weights('weights/RealESRGAN_x4.pth', download=True)
- model8 = RealESRGAN(device, scale=8)
- model8.load_weights('weights/RealESRGAN_x8.pth', download=True)
-
-
- def inference_image(image, size):
-     global model2
-     global model4
-     global model8
-     if image is None:
-         raise gr.Error("Image not uploaded")
-
-     width, height = image.size
-     if width >= 5000 or height >= 5000:
-         raise gr.Error("The image is too large.")
-
-     if torch.cuda.is_available():
-         torch.cuda.empty_cache()
-
-     if size == '2x':
-         try:
-             result = model2.predict(image.convert('RGB'))
-         except torch.cuda.OutOfMemoryError as e:
-             print(e)
-             model2 = RealESRGAN(device, scale=2)
-             model2.load_weights('weights/RealESRGAN_x2.pth', download=False)
-             result = model2.predict(image.convert('RGB'))
-     elif size == '4x':
-         try:
-             result = model4.predict(image.convert('RGB'))
-         except torch.cuda.OutOfMemoryError as e:
-             print(e)
-             model4 = RealESRGAN(device, scale=4)
-             model4.load_weights('weights/RealESRGAN_x4.pth', download=False)
-             result = model2.predict(image.convert('RGB'))
-     else:
-         try:
-             result = model8.predict(image.convert('RGB'))
-         except torch.cuda.OutOfMemoryError as e:
-             print(e)
-             model8 = RealESRGAN(device, scale=8)
-             model8.load_weights('weights/RealESRGAN_x8.pth', download=False)
-             result = model2.predict(image.convert('RGB'))
-
-     print(f"Image size ({device}): {size} ... OK")
-     return result
-
-
-
- def inference_video(video, size):
-     _id = randint(1, 10000)
-     INPUT_DIR = "tmp"
-     os.makedirs(INPUT_DIR, exist_ok=True)
-     os.chdir(INPUT_DIR)
-
-     upload_folder = 'upload'
-     result_folder = 'results'
-     video_folder = 'videos'
-     video_result_folder = 'results_videos'
-     video_mp4_result_folder = 'results_mp4_videos'
-     result_restored_imgs_folder = 'restored_imgs'
-
-     os.makedirs(upload_folder, exist_ok=True)
-
-     os.makedirs(video_folder, exist_ok=True)
-
-     os.makedirs(video_result_folder, exist_ok=True)
-
-     os.makedirs(video_mp4_result_folder, exist_ok=True)
-
-     os.makedirs(result_folder, exist_ok=True)
-
-     os.chdir("results")
-     os.makedirs(result_restored_imgs_folder, exist_ok=True)
-     os.chdir("..")
-     try:
-         # Specify the desired output file path with the custom name and ".mp4" extension
-         output_file_path = f"/{INPUT_DIR}/videos/input.mp4"
-
-         # Save the video input to the specified file path
-         with open(output_file_path, 'wb') as output_file:
-             output_file.write(video)
-         print(f"Video input saved as {output_file_path}")
-     except Exception as e:
-         print(f"Error saving video input: {str(e)}")
-
-     os.chdir("..")
-     os.system("python inference_video.py")
-     return os.path.join(f'/{INPUT_DIR}/results_mp4_videos/', 'input.mp4')
-

+ from infer import infer_image, infer_video

  input_image = gr.Image(type='pil', label='Input Image')
- input_model_image = gr.Radio(['2x', '4x', '8x'], type="value", value="4x", label="Model Upscale/Enhance Type")
+ input_model_image = gr.Radio([('x2', 2), ('x4', 4), ('x8', 8)], type="value", value=4, label="Model Upscale/Enhance Type")
  submit_image_button = gr.Button('Submit')
  output_image = gr.Image(type="filepath", label="Output Image")

  tab_img = gr.Interface(
-     fn=inference_image,
+     fn=infer_image,
      inputs=[input_image, input_model_image],
      outputs=output_image,
      title="Real-ESRGAN Pytorch",
@@ -116,12 +16,12 @@ tab_img = gr.Interface(
  )

  input_video = gr.Video(label='Input Video')
- input_model_video = gr.Radio(['2x', '4x', '8x'], type="value", value="4x", label="Model Upscale/Enhance Type")
+ input_model_video = gr.Radio([('x2', 2), ('x4', 4), ('x8', 8)], type="value", value=4, label="Model Upscale/Enhance Type")
  submit_video_button = gr.Button('Submit')
  output_video = gr.Video(label='Output Video')

  tab_vid = gr.Interface(
-     fn=inference_video,
+     fn=infer_video,
      inputs=[input_video, input_model_video],
      outputs=output_video,
      title="Real-ESRGAN Pytorch",
@@ -130,6 +30,4 @@ tab_vid = gr.Interface(

  demo = gr.TabbedInterface([tab_img, tab_vid], ["Image", "Video"])

-
-
  demo.launch(debug=True, show_error=True)
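For reference, a minimal usage sketch of the new entry points, assuming the RealESRGAN weights are already present under weights/ (infer.py loads them with download=False); input.jpg and input.mp4 are hypothetical test files, not part of the commit:

from PIL import Image

from infer import infer_image, infer_video

# The Radio choices now pass plain integers (2, 4, 8) rather than the old
# '2x'/'4x'/'8x' strings, so size_modifier is an int.
upscaled = infer_image(Image.open("input.jpg"), size_modifier=4)  # hypothetical test image
upscaled.save("output_x4.png")

# gr.Video hands infer_video a filepath string; the function returns the path
# of the temporary .mp4 it wrote.
result_path = infer_video("input.mp4", size_modifier=2)  # hypothetical test video
print(result_path)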
infer.py ADDED
@@ -0,0 +1,260 @@
+ # import cv2
+ # from os.path import isfile, join
+ # import subprocess
+ # import os
+ # from RealESRGAN import RealESRGAN
+ # import torch
+ # import gradio as gr
+
+ # IMAGE_FORMATS = ('.png', '.jpg', '.jpeg', '.tiff', '.bmp', '.gif')
+
+ # def inference_image(image, size):
+ # global model2
+ # global model4
+ # global model8
+ # if image is None:
+ # raise gr.Error("Image not uploaded")
+
+ # width, height = image.size
+ # if width >= 5000 or height >= 5000:
+ # raise gr.Error("The image is too large.")
+
+ # if torch.cuda.is_available():
+ # torch.cuda.empty_cache()
+
+ # if size == '2x':
+ # try:
+ # result = model2.predict(image.convert('RGB'))
+ # except torch.cuda.OutOfMemoryError as e:
+ # print(e)
+ # model2 = RealESRGAN(device, scale=2)
+ # model2.load_weights('weights/RealESRGAN_x2.pth', download=False)
+ # result = model2.predict(image.convert('RGB'))
+ # elif size == '4x':
+ # try:
+ # result = model4.predict(image.convert('RGB'))
+ # except torch.cuda.OutOfMemoryError as e:
+ # print(e)
+ # model4 = RealESRGAN(device, scale=4)
+ # model4.load_weights('weights/RealESRGAN_x4.pth', download=False)
+ # result = model2.predict(image.convert('RGB'))
+ # else:
+ # try:
+ # result = model8.predict(image.convert('RGB'))
+ # except torch.cuda.OutOfMemoryError as e:
+ # print(e)
+ # model8 = RealESRGAN(device, scale=8)
+ # model8.load_weights('weights/RealESRGAN_x8.pth', download=False)
+ # result = model2.predict(image.convert('RGB'))
+
+ # print(f"Frame of the Video size ({device}): {size} ... OK")
+ # return result
+
+
+ # # assign directory
+ # directory = 'videos' #PATH_WITH_INPUT_VIDEOS
+ # zee = 0
+
+ # def convert_frames_to_video(pathIn,pathOut,fps):
+ # global INPUT_DIR
+ # cap = cv2.VideoCapture(f'/{INPUT_DIR}/videos/input.mp4')
+ # fps = cap.get(cv2.CAP_PROP_FPS)
+ # frame_array = []
+ # files = [f for f in os.listdir(pathIn) if isfile(join(pathIn, f))]
+ # #for sorting the file names properly
+ # files.sort(key = lambda x: int(x[5:-4]))
+ # size2 = (0,0)
+
+ # for i in range(len(files)):
+ # filename=pathIn + files[i]
+ # #reading each files
+ # img = cv2.imread(filename)
+ # height, width, layers = img.shape
+ # size = (width,height)
+ # size2 = size
+ # print(filename)
+ # #inserting the frames into an image array
+ # frame_array.append(img)
+ # out = cv2.VideoWriter(pathOut,cv2.VideoWriter_fourcc(*'DIVX'), fps, size2)
+ # for i in range(len(frame_array)):
+ # # writing to a image array
+ # out.write(frame_array[i])
+ # out.release()
+
+
+ # for filename in os.listdir(directory):
+
+ # f = os.path.join(directory, filename)
+ # # checking if it is a file
+ # if os.path.isfile(f):
+
+
+ # print("PROCESSING :"+str(f)+"\n")
+ # # Read the video from specified path
+
+ # #video to frames
+ # cam = cv2.VideoCapture(str(f))
+
+ # try:
+
+ # # PATH TO STORE VIDEO FRAMES
+ # if not os.path.exists(f'/{INPUT_DIR}/upload/'):
+ # os.makedirs(f'/{INPUT_DIR}/upload/')
+
+ # # if not created then raise error
+ # except OSError:
+ # print ('Error: Creating directory of data')
+
+ # # frame
+ # currentframe = 0
+
+
+ # while(True):
+
+ # # reading from frame
+ # ret,frame = cam.read()
+
+ # if ret:
+ # # if video is still left continue creating images
+ # name = f'/{INPUT_DIR}/upload/frame' + str(currentframe) + '.jpg'
+
+ # # writing the extracted images
+ # cv2.imwrite(name, frame)
+
+
+ # # increasing counter so that it will
+ # # show how many frames are created
+ # currentframe += 1
+ # print(currentframe)
+ # else:
+ # #deletes all the videos you uploaded for upscaling
+ # #for f in os.listdir(video_folder):
+ # # os.remove(os.path.join(video_folder, f))
+
+ # break
+
+ # # Release all space and windows once done
+ # cam.release()
+ # cv2.destroyAllWindows()
+
+ # #apply super-resolution on all frames of a video
+
+ # # Specify the directory path
+ # all_frames_path = f"/{INPUT_DIR}/upload/"
+
+ # # Get a list of all files in the directory
+ # file_names = os.listdir(all_frames_path)
+
+ # # process the files
+ # for file_name in file_names:
+ # inference_image(f"/{INPUT_DIR}/upload/{file_name}")
+
+
+ # #convert super res frames to .avi
+ # pathIn = f'/{INPUT_DIR}/results/restored_imgs/'
+
+ # zee = zee+1
+ # fName = "video"+str(zee)
+ # filenameVid = f"{fName}.avi"
+
+ # pathOut = f"/{INPUT_DIR}/results_videos/"+filenameVid
+
+ # convert_frames_to_video(pathIn, pathOut, fps)
+
+
+ # #convert .avi to .mp4
+ # src = f'/{INPUT_DIR}/results_videos/'
+ # dst = f'/{INPUT_DIR}/results_mp4_videos/'
+
+ # for root, dirs, filenames in os.walk(src, topdown=False):
+ # #print(filenames)
+ # for filename in filenames:
+ # print('[INFO] 1',filename)
+ # try:
+ # _format = ''
+ # if ".flv" in filename.lower():
+ # _format=".flv"
+ # if ".mp4" in filename.lower():
+ # _format=".mp4"
+ # if ".avi" in filename.lower():
+ # _format=".avi"
+ # if ".mov" in filename.lower():
+ # _format=".mov"
+
+ # inputfile = os.path.join(root, filename)
+ # print('[INFO] 1',inputfile)
+ # outputfile = os.path.join(dst, filename.lower().replace(_format, ".mp4"))
+ # subprocess.call(['ffmpeg', '-i', inputfile, outputfile])
+ # except:
+ # print("An exception occurred")
+
+ from PIL import Image
+ import cv2 as cv
+ import torch
+ from RealESRGAN import RealESRGAN
+ import tempfile
+ import numpy as np
+ import tqdm
+
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+ def infer_image(img: Image.Image, size_modifier: int) -> Image.Image:
+     if img is None:
+         raise Exception("Image not uploaded")
+
+     width, height = img.size
+
+     if width >= 5000 or height >= 5000:
+         raise Exception("The image is too large.")
+
+     model = RealESRGAN(device, scale=size_modifier)
+     model.load_weights(f'weights/RealESRGAN_x{size_modifier}.pth', download=False)
+
+     result = model.predict(img.convert('RGB'))
+     print(f"Image size ({device}): {size_modifier} ... OK")
+     return result
+
+ def infer_video(video_filepath: str, size_modifier: int) -> str:
+     model = RealESRGAN(device, scale=size_modifier)
+     model.load_weights(f'weights/RealESRGAN_x{size_modifier}.pth', download=False)
+
+     cap = cv.VideoCapture(video_filepath)
+
+     tmpfile = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
+     vid_output = tmpfile.name
+     tmpfile.close()
+
+     vid_writer = cv.VideoWriter(
+         vid_output,
+         fourcc=cv.VideoWriter.fourcc(*'mp4v'),
+         fps=cap.get(cv.CAP_PROP_FPS),
+         frameSize=(int(cap.get(cv.CAP_PROP_FRAME_WIDTH)) * size_modifier, int(cap.get(cv.CAP_PROP_FRAME_HEIGHT)) * size_modifier)
+     )
+
+     n_frames = int(cap.get(cv.CAP_PROP_FRAME_COUNT))
+
+     # while cap.isOpened():
+     for _ in tqdm.tqdm(range(n_frames)):
+         ret, frame = cap.read()
+         if not ret:
+             break
+
+         frame = cv.cvtColor(frame, cv.COLOR_BGR2RGB)
+         frame = Image.fromarray(frame)
+
+         upscaled_frame = model.predict(frame.convert('RGB'))
+
+         upscaled_frame = np.array(upscaled_frame)
+         upscaled_frame = cv.cvtColor(upscaled_frame, cv.COLOR_RGB2BGR)
+
+         print(upscaled_frame.shape)
+
+         vid_writer.write(upscaled_frame)
+
+     vid_writer.release()
+
+     print(f"Video file : {video_filepath}")
+
+     return vid_output
+
+
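One practical note on the new module: infer.py loads the checkpoints with download=False, whereas the old app.py fetched them at import time with download=True. A minimal sketch of a one-time preparation step that would fetch the same files the new code expects, using only the RealESRGAN calls already present in this commit:

import torch
from RealESRGAN import RealESRGAN

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

for scale in (2, 4, 8):
    model = RealESRGAN(device, scale=scale)
    # download=True fetches the checkpoint if it is missing, writing it to the
    # same weights/RealESRGAN_x{scale}.pth path that infer.py reads.
    model.load_weights(f'weights/RealESRGAN_x{scale}.pth', download=True)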