ledetele committed
Commit a5291ee (0 parents)

Duplicate from krystaltechnology/image-video-colorization

.gitattributes ADDED
@@ -0,0 +1,34 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
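
This is Hugging Face's standard `.gitattributes` template: every path matching one of these patterns is stored through Git LFS instead of plain git, which is what lets large model checkpoints (`*.pth`, `*.ckpt`, `*.safetensors`) live in the repo. One more pattern can be added in the same format, either by hand or via `git lfs track`; for example, to also store video files through LFS (an illustrative line, not part of this commit):

    *.mp4 filter=lfs diff=lfs merge=lfs -text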
.streamlit/config.toml ADDED
@@ -0,0 +1,8 @@
+ [theme]
+ primaryColor="#ff6328"
+ backgroundColor="#FFFFFF"
+ secondaryBackgroundColor="#F0F2F6"
+ textColor="#262730"
+ font="sans serif"
+ [server]
+ maxUploadSize=1028
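
Two things worth noting about this config: the `[theme]` block only restyles the Streamlit UI (the orange `primaryColor` is the app's accent color), while `server.maxUploadSize` raises the file-uploader limit from Streamlit's 200 MB default. The value is in megabytes, so `1028` allows uploads of roughly 1 GB.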
01_B&W_Videos_Colorizer.py ADDED
@@ -0,0 +1,130 @@
+ import os
+ import tempfile
+ import time
+ 
+ os.environ["IMAGEIO_FFMPEG_EXE"] = "/usr/bin/ffmpeg"
+ 
+ import cv2
+ import moviepy.editor as mp
+ import numpy as np
+ import streamlit as st
+ from streamlit_lottie import st_lottie
+ from tqdm import tqdm
+ 
+ from models.deep_colorization.colorizers import eccv16
+ from utils import load_lottieurl, format_time, colorize_frame, change_model
+ 
+ st.title("B&W Videos Colorizer")
+ 
+ st.write("""
+ ##### Upload a black and white video and get a colorized version of it.
+ ###### ➠ This space is using CPU Basic so it might take a while to colorize a video.""")
+ 
+ #st.set_page_config(page_title="Image & Video Colorizer", page_icon="🎨", layout="wide")
+ 
+ loaded_model = eccv16(pretrained=True).eval()
+ current_model = "None"
+ 
+ def main():
+     model = st.selectbox(
+         "Select Model (Both models have their pros and cons, I recommend trying both and keeping the best for your task)",
+         ["ECCV16", "SIGGRAPH17"], index=0)
+ 
+     loaded_model = change_model(current_model, model)
+     st.write(f"Model is now {model}")
+ 
+     uploaded_file = st.file_uploader("Upload your video here...", type=['mp4', 'mov', 'avi', 'mkv'])
+ 
+     if st.button("Colorize"):
+         if uploaded_file is not None:
+             file_extension = os.path.splitext(uploaded_file.name)[1].lower()
+             if file_extension in ['.mp4', '.avi', '.mov', '.mkv']:
+                 # Save the video file to a temporary location
+                 temp_file = tempfile.NamedTemporaryFile(delete=False)
+                 temp_file.write(uploaded_file.read())
+ 
+                 audio = mp.AudioFileClip(temp_file.name)
+ 
+                 # Open the video using cv2.VideoCapture
+                 video = cv2.VideoCapture(temp_file.name)
+ 
+                 # Get video information
+                 fps = video.get(cv2.CAP_PROP_FPS)
+ 
+                 col1, col2 = st.columns([0.5, 0.5])
+                 with col1:
+                     st.markdown('<p style="text-align: center;">Before</p>', unsafe_allow_html=True)
+                     st.video(temp_file.name)
+ 
+                 with col2:
+                     st.markdown('<p style="text-align: center;">After</p>', unsafe_allow_html=True)
+ 
+                     with st.spinner("Colorizing frames..."):
+                         # Colorize video frames and store in a list
+                         output_frames = []
+                         total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
+                         progress_bar = st.progress(0)  # Create a progress bar
+ 
+                         start_time = time.time()
+                         time_text = st.text("Time Remaining: ")  # Initialize text value
+ 
+                         for _ in tqdm(range(total_frames), unit='frame', desc="Progress"):
+                             ret, frame = video.read()
+                             if not ret:
+                                 break
+ 
+                             colorized_frame = colorize_frame(frame, loaded_model)
+                             output_frames.append((colorized_frame * 255).astype(np.uint8))
+ 
+                             elapsed_time = time.time() - start_time
+                             frames_completed = len(output_frames)
+                             frames_remaining = total_frames - frames_completed
+                             time_remaining = (frames_remaining / frames_completed) * elapsed_time
+ 
+                             progress_bar.progress(frames_completed / total_frames)  # Update progress bar
+ 
+                             if frames_completed < total_frames:
+                                 time_text.text(f"Time Remaining: {format_time(time_remaining)}")  # Update text value
+                             else:
+                                 time_text.empty()  # Remove text value
+                                 progress_bar.empty()
+ 
+                     with st.spinner("Merging frames to video..."):
+                         frame_size = output_frames[0].shape[:2]
+                         output_filename = "output.mp4"
+                         fourcc = cv2.VideoWriter_fourcc(*"mp4v")  # Codec for MP4 video
+                         out = cv2.VideoWriter(output_filename, fourcc, fps, (frame_size[1], frame_size[0]))
+ 
+                         # Write the colorized frames to the output video file
+                         for frame in output_frames:
+                             frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
+ 
+                             out.write(frame_bgr)
+ 
+                         out.release()
+ 
+                         # Convert the output video to a format compatible with Streamlit
+                         converted_filename = "converted_output.mp4"
+                         clip = mp.VideoFileClip(output_filename)
+                         clip = clip.set_audio(audio)
+ 
+                         clip.write_videofile(converted_filename, codec="libx264")
+ 
+                         # Display the converted video using st.video()
+                         st.video(converted_filename)
+                         st.balloons()
+ 
+                         # Add a download button for the colorized video
+                         st.download_button(
+                             label="Download Colorized Video",
+                             data=open(converted_filename, "rb").read(),
+                             file_name="colorized_video.mp4"
+                         )
+ 
+                         # Release the capture and close the temporary file after processing
+                         video.release()
+                         temp_file.close()
+ 
+ 
+ if __name__ == "__main__":
+     main()
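
The remaining-time readout in the loop above is a plain proportional extrapolation: after n of N frames in t elapsed seconds, it predicts ((N - n) / n) * t seconds left, so 100 of 400 frames done in 50 s yields an estimate of 150 s. Since a frame is always appended before the estimate is computed, n is at least 1 and the division is safe. To try the page locally, run `streamlit run "01_B&W_Videos_Colorizer.py"` (the quotes keep shells from interpreting the `&` in the filename).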
README.md ADDED
@@ -0,0 +1,13 @@
+ ---
+ title: Image Video Colorization
+ emoji: 🎥
+ colorFrom: blue
+ colorTo: purple
+ sdk: streamlit
+ sdk_version: 1.21.0
+ app_file: 01_B&W_Videos_Colorizer.py
+ pinned: false
+ duplicated_from: krystaltechnology/image-video-colorization
+ ---
+ 
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
models/deep_colorization/colorizers/__init__.py ADDED
@@ -0,0 +1,6 @@
+ 
+ from .base_color import *
+ from .eccv16 import *
+ from .siggraph17 import *
+ from .util import *
+ 
models/deep_colorization/colorizers/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (279 Bytes).
 
models/deep_colorization/colorizers/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (285 Bytes).
 
models/deep_colorization/colorizers/__pycache__/base_color.cpython-310.pyc ADDED
Binary file (1.24 kB).
 
models/deep_colorization/colorizers/__pycache__/base_color.cpython-37.pyc ADDED
Binary file (1.24 kB).
 
models/deep_colorization/colorizers/__pycache__/eccv16.cpython-310.pyc ADDED
Binary file (3.27 kB).
 
models/deep_colorization/colorizers/__pycache__/eccv16.cpython-37.pyc ADDED
Binary file (3.26 kB).
 
models/deep_colorization/colorizers/__pycache__/siggraph17.cpython-310.pyc ADDED
Binary file (4.36 kB).
 
models/deep_colorization/colorizers/__pycache__/siggraph17.cpython-37.pyc ADDED
Binary file (4.36 kB).
 
models/deep_colorization/colorizers/__pycache__/util.cpython-310.pyc ADDED
Binary file (1.74 kB).
 
models/deep_colorization/colorizers/__pycache__/util.cpython-37.pyc ADDED
Binary file (1.71 kB).
 
models/deep_colorization/colorizers/base_color.py ADDED
@@ -0,0 +1,24 @@
+ 
+ import torch
+ from torch import nn
+ 
+ class BaseColor(nn.Module):
+     def __init__(self):
+         super(BaseColor, self).__init__()
+ 
+         self.l_cent = 50.
+         self.l_norm = 100.
+         self.ab_norm = 110.
+ 
+     def normalize_l(self, in_l):
+         return (in_l-self.l_cent)/self.l_norm
+ 
+     def unnormalize_l(self, in_l):
+         return in_l*self.l_norm + self.l_cent
+ 
+     def normalize_ab(self, in_ab):
+         return in_ab/self.ab_norm
+ 
+     def unnormalize_ab(self, in_ab):
+         return in_ab*self.ab_norm
+ 
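
BaseColor centralizes the Lab-space scaling both colorizers share: lightness L lives in [0, 100] and is mapped to [-0.5, 0.5], while the ab chroma channels are divided by 110 to land roughly in [-1, 1]. A minimal sanity check of those ranges, assuming the colorizers package is importable from the repo root:

    import torch
    from models.deep_colorization.colorizers.base_color import BaseColor

    bc = BaseColor()
    L = torch.tensor([0., 50., 100.])     # full Lab lightness range
    ab = torch.tensor([-110., 0., 110.])  # approximate ab chroma range
    print(bc.normalize_l(L))              # tensor([-0.5000,  0.0000,  0.5000])
    print(bc.normalize_ab(ab))            # tensor([-1.,  0.,  1.])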
models/deep_colorization/colorizers/eccv16.py ADDED
@@ -0,0 +1,105 @@
+ 
+ import torch
+ import torch.nn as nn
+ import numpy as np
+ from IPython import embed
+ 
+ from .base_color import *
+ 
+ class ECCVGenerator(BaseColor):
+     def __init__(self, norm_layer=nn.BatchNorm2d):
+         super(ECCVGenerator, self).__init__()
+ 
+         model1=[nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=True),]
+         model1+=[nn.ReLU(True),]
+         model1+=[nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=True),]
+         model1+=[nn.ReLU(True),]
+         model1+=[norm_layer(64),]
+ 
+         model2=[nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True),]
+         model2+=[nn.ReLU(True),]
+         model2+=[nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1, bias=True),]
+         model2+=[nn.ReLU(True),]
+         model2+=[norm_layer(128),]
+ 
+         model3=[nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=True),]
+         model3+=[nn.ReLU(True),]
+         model3+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
+         model3+=[nn.ReLU(True),]
+         model3+=[nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1, bias=True),]
+         model3+=[nn.ReLU(True),]
+         model3+=[norm_layer(256),]
+ 
+         model4=[nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=True),]
+         model4+=[nn.ReLU(True),]
+         model4+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
+         model4+=[nn.ReLU(True),]
+         model4+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
+         model4+=[nn.ReLU(True),]
+         model4+=[norm_layer(512),]
+ 
+         model5=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
+         model5+=[nn.ReLU(True),]
+         model5+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
+         model5+=[nn.ReLU(True),]
+         model5+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
+         model5+=[nn.ReLU(True),]
+         model5+=[norm_layer(512),]
+ 
+         model6=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
+         model6+=[nn.ReLU(True),]
+         model6+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
+         model6+=[nn.ReLU(True),]
+         model6+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
+         model6+=[nn.ReLU(True),]
+         model6+=[norm_layer(512),]
+ 
+         model7=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
+         model7+=[nn.ReLU(True),]
+         model7+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
+         model7+=[nn.ReLU(True),]
+         model7+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
+         model7+=[nn.ReLU(True),]
+         model7+=[norm_layer(512),]
+ 
+         model8=[nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=True),]
+         model8+=[nn.ReLU(True),]
+         model8+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
+         model8+=[nn.ReLU(True),]
+         model8+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
+         model8+=[nn.ReLU(True),]
+ 
+         model8+=[nn.Conv2d(256, 313, kernel_size=1, stride=1, padding=0, bias=True),]
+ 
+         self.model1 = nn.Sequential(*model1)
+         self.model2 = nn.Sequential(*model2)
+         self.model3 = nn.Sequential(*model3)
+         self.model4 = nn.Sequential(*model4)
+         self.model5 = nn.Sequential(*model5)
+         self.model6 = nn.Sequential(*model6)
+         self.model7 = nn.Sequential(*model7)
+         self.model8 = nn.Sequential(*model8)
+ 
+         self.softmax = nn.Softmax(dim=1)
+         self.model_out = nn.Conv2d(313, 2, kernel_size=1, padding=0, dilation=1, stride=1, bias=False)
+         self.upsample4 = nn.Upsample(scale_factor=4, mode='bilinear')
+ 
+     def forward(self, input_l):
+         conv1_2 = self.model1(self.normalize_l(input_l))
+         conv2_2 = self.model2(conv1_2)
+         conv3_3 = self.model3(conv2_2)
+         conv4_3 = self.model4(conv3_3)
+         conv5_3 = self.model5(conv4_3)
+         conv6_3 = self.model6(conv5_3)
+         conv7_3 = self.model7(conv6_3)
+         conv8_3 = self.model8(conv7_3)
+         out_reg = self.model_out(self.softmax(conv8_3))
+ 
+         return self.unnormalize_ab(self.upsample4(out_reg))
+ 
+ def eccv16(pretrained=True):
+     model = ECCVGenerator()
+     if(pretrained):
+         import torch.utils.model_zoo as model_zoo
+         model.load_state_dict(model_zoo.load_url('https://colorizers.s3.us-east-2.amazonaws.com/colorization_release_v2-9b330a0b.pth',map_location='cpu',check_hash=True))
+     return model
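
eccv16() is the feed-forward generator from Zhang et al.'s "Colorful Image Colorization" (ECCV 2016): it consumes the normalized L channel, predicts a 313-bin color distribution per spatial location, collapses it to two ab channels via model_out, and upsamples 4x back toward input resolution. A hedged end-to-end sketch using the helpers defined in util.py further down ("bw.jpg" and "colorized.jpg" are placeholder paths):

    import numpy as np
    from PIL import Image
    from models.deep_colorization.colorizers import eccv16, load_img, preprocess_img, postprocess_tens

    model = eccv16(pretrained=True).eval()
    img = load_img("bw.jpg")                             # H x W x 3 uint8 RGB array
    tens_orig_l, tens_rs_l = preprocess_img(img, HW=(256, 256))
    out_rgb = postprocess_tens(tens_orig_l, model(tens_rs_l).cpu())  # float RGB in [0, 1]
    Image.fromarray((out_rgb * 255).astype(np.uint8)).save("colorized.jpg")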
models/deep_colorization/colorizers/siggraph17.py ADDED
@@ -0,0 +1,168 @@
+ import torch
+ import torch.nn as nn
+ 
+ from .base_color import *
+ 
+ class SIGGRAPHGenerator(BaseColor):
+     def __init__(self, norm_layer=nn.BatchNorm2d, classes=529):
+         super(SIGGRAPHGenerator, self).__init__()
+ 
+         # Conv1
+         model1=[nn.Conv2d(4, 64, kernel_size=3, stride=1, padding=1, bias=True),]
+         model1+=[nn.ReLU(True),]
+         model1+=[nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True),]
+         model1+=[nn.ReLU(True),]
+         model1+=[norm_layer(64),]
+         # add a subsampling operation
+ 
+         # Conv2
+         model2=[nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True),]
+         model2+=[nn.ReLU(True),]
+         model2+=[nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True),]
+         model2+=[nn.ReLU(True),]
+         model2+=[norm_layer(128),]
+         # add a subsampling layer operation
+ 
+         # Conv3
+         model3=[nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=True),]
+         model3+=[nn.ReLU(True),]
+         model3+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
+         model3+=[nn.ReLU(True),]
+         model3+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
+         model3+=[nn.ReLU(True),]
+         model3+=[norm_layer(256),]
+         # add a subsampling layer operation
+ 
+         # Conv4
+         model4=[nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=True),]
+         model4+=[nn.ReLU(True),]
+         model4+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
+         model4+=[nn.ReLU(True),]
+         model4+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
+         model4+=[nn.ReLU(True),]
+         model4+=[norm_layer(512),]
+ 
+         # Conv5
+         model5=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
+         model5+=[nn.ReLU(True),]
+         model5+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
+         model5+=[nn.ReLU(True),]
+         model5+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
+         model5+=[nn.ReLU(True),]
+         model5+=[norm_layer(512),]
+ 
+         # Conv6
+         model6=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
+         model6+=[nn.ReLU(True),]
+         model6+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
+         model6+=[nn.ReLU(True),]
+         model6+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
+         model6+=[nn.ReLU(True),]
+         model6+=[norm_layer(512),]
+ 
+         # Conv7
+         model7=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
+         model7+=[nn.ReLU(True),]
+         model7+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
+         model7+=[nn.ReLU(True),]
+         model7+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
+         model7+=[nn.ReLU(True),]
+         model7+=[norm_layer(512),]
+ 
+         # Conv8
+         model8up=[nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=True)]
+         model3short8=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
+ 
+         model8=[nn.ReLU(True),]
+         model8+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
+         model8+=[nn.ReLU(True),]
+         model8+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
+         model8+=[nn.ReLU(True),]
+         model8+=[norm_layer(256),]
+ 
+         # Conv9
+         model9up=[nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=True),]
+         model2short9=[nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True),]
+         # add the two feature maps above
+ 
+         model9=[nn.ReLU(True),]
+         model9+=[nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True),]
+         model9+=[nn.ReLU(True),]
+         model9+=[norm_layer(128),]
+ 
+         # Conv10
+         model10up=[nn.ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1, bias=True),]
+         model1short10=[nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True),]
+         # add the two feature maps above
+ 
+         model10=[nn.ReLU(True),]
+         model10+=[nn.Conv2d(128, 128, kernel_size=3, dilation=1, stride=1, padding=1, bias=True),]
+         model10+=[nn.LeakyReLU(negative_slope=.2),]
+ 
+         # classification output
+         model_class=[nn.Conv2d(256, classes, kernel_size=1, padding=0, dilation=1, stride=1, bias=True),]
+ 
+         # regression output
+         model_out=[nn.Conv2d(128, 2, kernel_size=1, padding=0, dilation=1, stride=1, bias=True),]
+         model_out+=[nn.Tanh()]
+ 
+         self.model1 = nn.Sequential(*model1)
+         self.model2 = nn.Sequential(*model2)
+         self.model3 = nn.Sequential(*model3)
+         self.model4 = nn.Sequential(*model4)
+         self.model5 = nn.Sequential(*model5)
+         self.model6 = nn.Sequential(*model6)
+         self.model7 = nn.Sequential(*model7)
+         self.model8up = nn.Sequential(*model8up)
+         self.model8 = nn.Sequential(*model8)
+         self.model9up = nn.Sequential(*model9up)
+         self.model9 = nn.Sequential(*model9)
+         self.model10up = nn.Sequential(*model10up)
+         self.model10 = nn.Sequential(*model10)
+         self.model3short8 = nn.Sequential(*model3short8)
+         self.model2short9 = nn.Sequential(*model2short9)
+         self.model1short10 = nn.Sequential(*model1short10)
+ 
+         self.model_class = nn.Sequential(*model_class)
+         self.model_out = nn.Sequential(*model_out)
+ 
+         self.upsample4 = nn.Sequential(*[nn.Upsample(scale_factor=4, mode='bilinear'),])
+         self.softmax = nn.Sequential(*[nn.Softmax(dim=1),])
+ 
+     def forward(self, input_A, input_B=None, mask_B=None):
+         if(input_B is None):
+             input_B = torch.cat((input_A*0, input_A*0), dim=1)
+         if(mask_B is None):
+             mask_B = input_A*0
+ 
+         conv1_2 = self.model1(torch.cat((self.normalize_l(input_A),self.normalize_ab(input_B),mask_B),dim=1))
+         conv2_2 = self.model2(conv1_2[:,:,::2,::2])
+         conv3_3 = self.model3(conv2_2[:,:,::2,::2])
+         conv4_3 = self.model4(conv3_3[:,:,::2,::2])
+         conv5_3 = self.model5(conv4_3)
+         conv6_3 = self.model6(conv5_3)
+         conv7_3 = self.model7(conv6_3)
+ 
+         conv8_up = self.model8up(conv7_3) + self.model3short8(conv3_3)
+         conv8_3 = self.model8(conv8_up)
+         conv9_up = self.model9up(conv8_3) + self.model2short9(conv2_2)
+         conv9_3 = self.model9(conv9_up)
+         conv10_up = self.model10up(conv9_3) + self.model1short10(conv1_2)
+         conv10_2 = self.model10(conv10_up)
+         out_reg = self.model_out(conv10_2)
+         # the five lines below recompute conv9 through out_reg with the same inputs; the result is identical
+         conv9_up = self.model9up(conv8_3) + self.model2short9(conv2_2)
+         conv9_3 = self.model9(conv9_up)
+         conv10_up = self.model10up(conv9_3) + self.model1short10(conv1_2)
+         conv10_2 = self.model10(conv10_up)
+         out_reg = self.model_out(conv10_2)
+ 
+         return self.unnormalize_ab(out_reg)
+ 
+ def siggraph17(pretrained=True):
+     model = SIGGRAPHGenerator()
+     if(pretrained):
+         import torch.utils.model_zoo as model_zoo
+         model.load_state_dict(model_zoo.load_url('https://colorizers.s3.us-east-2.amazonaws.com/siggraph17-df00044c.pth',map_location='cpu',check_hash=True))
+     return model
+ 
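
siggraph17() is the interactive model from "Real-Time User-Guided Image Colorization with Learned Deep Priors" (SIGGRAPH 2017): input_B carries optional user-provided ab color hints and mask_B marks where they apply, and when both are left as None the forward pass substitutes all-zero tensors so the network colorizes fully automatically. A minimal shape check in that automatic mode (the random tensor is a stand-in for a real L channel):

    import torch
    from models.deep_colorization.colorizers import siggraph17

    model = siggraph17(pretrained=True).eval()
    L = torch.rand(1, 1, 256, 256) * 100.0  # dummy lightness channel in [0, 100]
    with torch.no_grad():
        ab = model(L)                       # zeros are used for input_B / mask_B
    print(ab.shape)                         # torch.Size([1, 2, 256, 256])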
models/deep_colorization/colorizers/util.py ADDED
@@ -0,0 +1,47 @@
+ 
+ from PIL import Image
+ import numpy as np
+ from skimage import color
+ import torch
+ import torch.nn.functional as F
+ from IPython import embed
+ 
+ def load_img(img_path):
+     out_np = np.asarray(Image.open(img_path))
+     if(out_np.ndim==2):
+         out_np = np.tile(out_np[:,:,None],3)
+     return out_np
+ 
+ def resize_img(img, HW=(256,256), resample=3):
+     return np.asarray(Image.fromarray(img).resize((HW[1],HW[0]), resample=resample))
+ 
+ def preprocess_img(img_rgb_orig, HW=(256,256), resample=3):
+     # return original size L and resized L as torch Tensors
+     img_rgb_rs = resize_img(img_rgb_orig, HW=HW, resample=resample)
+ 
+     img_lab_orig = color.rgb2lab(img_rgb_orig)
+     img_lab_rs = color.rgb2lab(img_rgb_rs)
+ 
+     img_l_orig = img_lab_orig[:,:,0]
+     img_l_rs = img_lab_rs[:,:,0]
+ 
+     tens_orig_l = torch.Tensor(img_l_orig)[None,None,:,:]
+     tens_rs_l = torch.Tensor(img_l_rs)[None,None,:,:]
+ 
+     return (tens_orig_l, tens_rs_l)
+ 
+ def postprocess_tens(tens_orig_l, out_ab, mode='bilinear'):
+     # tens_orig_l 1 x 1 x H_orig x W_orig
+     # out_ab 1 x 2 x H x W
+ 
+     HW_orig = tens_orig_l.shape[2:]
+     HW = out_ab.shape[2:]
+ 
+     # call resize function if needed
+     if(HW_orig[0]!=HW[0] or HW_orig[1]!=HW[1]):
+         out_ab_orig = F.interpolate(out_ab, size=HW_orig, mode='bilinear')
+     else:
+         out_ab_orig = out_ab
+ 
+     out_lab_orig = torch.cat((tens_orig_l, out_ab_orig), dim=1)
+     return color.lab2rgb(out_lab_orig.data.cpu().numpy()[0,...].transpose((1,2,0)))
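
These helpers split colorization into resize, Lab conversion, and ab upsampling: preprocess_img returns the L channel both at the original resolution and at the 256x256 network resolution, and postprocess_tens bilinearly interpolates the predicted ab map back to the original size before reassembling the Lab image and converting it to RGB. For a 480x640 input that means a 1x1x480x640 and a 1x1x256x256 tensor going in, and a 480x640x3 float RGB array coming out.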
pages/02_Input_Youtube_Link.py ADDED
@@ -0,0 +1,129 @@
+ import time
+ 
+ import cv2
+ import moviepy.editor as mp
+ import numpy as np
+ import streamlit as st
+ from pytube import YouTube
+ from streamlit_lottie import st_lottie
+ from tqdm import tqdm
+ 
+ from models.deep_colorization.colorizers import eccv16
+ from utils import colorize_frame, format_time
+ from utils import load_lottieurl, change_model
+ 
+ #st.set_page_config(page_title="Image & Video Colorizer", page_icon="🎨", layout="wide")
+ 
+ 
+ loaded_model = eccv16(pretrained=True).eval()
+ current_model = "None"
+ 
+ st.title("Image & Video Colorizer")
+ 
+ st.write("""
+ ##### Input a YouTube black and white video link and get a colorized version of it.
+ ###### ➠ This space is using CPU Basic so it might take a while to colorize a video.""")
+ 
+ @st.cache_data()
+ def download_video(link):
+     yt = YouTube(link)
+     video = yt.streams.filter(progressive=True, file_extension='mp4').order_by('resolution').desc().first().download(filename="video.mp4")
+     return video
+ 
+ 
+ def main():
+     model = st.selectbox(
+         "Select Model (Both models have their pros and cons, I recommend trying both and keeping the best for your task)",
+         ["ECCV16", "SIGGRAPH17"], index=0)
+ 
+     loaded_model = change_model(current_model, model)
+     st.write(f"Model is now {model}")
+ 
+     link = st.text_input("YouTube Link (The longer the video, the longer the processing time)")
+     if st.button("Colorize"):
+         if link != "":
+             print(link)
+             yt_video = download_video(link)
+             print(yt_video)
+             col1, col2 = st.columns([0.5, 0.5])
+             with col1:
+                 st.markdown('<p style="text-align: center;">Before</p>', unsafe_allow_html=True)
+                 st.video(yt_video)
+             with col2:
+                 st.markdown('<p style="text-align: center;">After</p>', unsafe_allow_html=True)
+                 with st.spinner("Colorizing frames..."):
+                     # Colorize video frames and store in a list
+                     output_frames = []
+ 
+                     audio = mp.AudioFileClip("video.mp4")
+                     video = cv2.VideoCapture("video.mp4")
+ 
+                     total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
+                     fps = video.get(cv2.CAP_PROP_FPS)
+ 
+                     progress_bar = st.progress(0)  # Create a progress bar
+                     start_time = time.time()
+                     time_text = st.text("Time Remaining: ")  # Initialize text value
+ 
+                     for _ in tqdm(range(total_frames), unit='frame', desc="Progress"):
+                         ret, frame = video.read()
+                         if not ret:
+                             break
+ 
+                         colorized_frame = colorize_frame(frame, loaded_model)
+                         output_frames.append((colorized_frame * 255).astype(np.uint8))
+ 
+                         elapsed_time = time.time() - start_time
+                         frames_completed = len(output_frames)
+                         frames_remaining = total_frames - frames_completed
+                         time_remaining = (frames_remaining / frames_completed) * elapsed_time
+ 
+                         progress_bar.progress(frames_completed / total_frames)  # Update progress bar
+ 
+                         if frames_completed < total_frames:
+                             time_text.text(f"Time Remaining: {format_time(time_remaining)}")  # Update text value
+                         else:
+                             time_text.empty()  # Remove text value
+                             progress_bar.empty()
+ 
+                 with st.spinner("Merging frames to video..."):
+                     frame_size = output_frames[0].shape[:2]
+                     output_filename = "output.mp4"
+                     fourcc = cv2.VideoWriter_fourcc(*"mp4v")  # Codec for MP4 video
+                     out = cv2.VideoWriter(output_filename, fourcc, fps, (frame_size[1], frame_size[0]))
+ 
+                     # Write the colorized frames to the output video file
+                     for frame in output_frames:
+                         frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
+ 
+                         out.write(frame_bgr)
+ 
+                     out.release()
+ 
+                     # Convert the output video to a format compatible with Streamlit
+                     converted_filename = "converted_output.mp4"
+                     clip = mp.VideoFileClip(output_filename)
+                     clip = clip.set_audio(audio)
+ 
+                     clip.write_videofile(converted_filename, codec="libx264")
+ 
+                     # Display the converted video using st.video()
+                     st.video(converted_filename)
+                     st.balloons()
+ 
+                     # Add a download button for the colorized video
+                     st.download_button(
+                         label="Download Colorized Video",
+                         data=open(converted_filename, "rb").read(),
+                         file_name="colorized_video.mp4"
+                     )
+ 
+                     # Release the video capture after processing
+                     video.release()
+         else:
+             st.warning('Please type a link', icon="⚠️")
+ 
+ 
+ if __name__ == "__main__":
+     main()
+ 
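
Because download_video is wrapped in @st.cache_data, resubmitting the same link on a Streamlit rerun returns the cached download instead of fetching video.mp4 again. The progressive=True filter restricts pytube to streams that bundle audio and video in a single mp4 file, which caps the available resolution but avoids a separate muxing step.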
pages/03_B&W_Images_Colorizer.py ADDED
@@ -0,0 +1,98 @@
+ import os
+ import zipfile
+ 
+ import streamlit as st
+ from PIL import Image
+ from streamlit_lottie import st_lottie
+ 
+ from models.deep_colorization.colorizers import eccv16
+ from utils import colorize_image, change_model, load_lottieurl
+ 
+ #st.set_page_config(page_title="Image & Video Colorizer", page_icon="🎨", layout="wide")
+ 
+ st.title("B&W Images Colorizer")
+ 
+ 
+ loaded_model = eccv16(pretrained=True).eval()
+ current_model = "None"
+ 
+ st.write("""
+ ##### Input a black and white image and get a colorized version of it.
+ ###### ➠ If you want to colorize multiple images just upload them all at once.
+ ###### ➠ Uploading already colored images won't raise errors but images won't look good.""")
+ 
+ 
+ def main():
+     model = st.selectbox(
+         "Select Model (Both models have their pros and cons, I recommend trying both and keeping the best for your task)",
+         ["ECCV16", "SIGGRAPH17"], index=0)
+ 
+     # Make the user select a model
+     loaded_model = change_model(current_model, model)
+     st.write(f"Model is now {model}")
+ 
+     # Ask the user if he wants to see colorization
+     display_results = st.checkbox('Display results in real time', value=True)
+ 
+     # Input for the user to upload images
+     uploaded_file = st.file_uploader("Upload your images here...", type=['jpg', 'png', 'jpeg'],
+                                      accept_multiple_files=True)
+ 
+     # If the user clicks on the button
+     if st.button("Colorize"):
+         # If the user uploaded images (with accept_multiple_files=True an empty upload is an empty list)
+         if uploaded_file:
+             if display_results:
+                 col1, col2 = st.columns([0.5, 0.5])
+                 with col1:
+                     st.markdown('<p style="text-align: center;">Before</p>', unsafe_allow_html=True)
+                 with col2:
+                     st.markdown('<p style="text-align: center;">After</p>', unsafe_allow_html=True)
+             else:
+                 col1, col2, col3 = st.columns(3)
+ 
+             for i, file in enumerate(uploaded_file):
+                 file_extension = os.path.splitext(file.name)[1].lower()
+                 if file_extension in ['.jpg', '.png', '.jpeg']:
+                     image = Image.open(file)
+                     if display_results:
+                         with col1:
+                             st.image(image, use_column_width="always")
+                         with col2:
+                             with st.spinner("Colorizing image..."):
+                                 out_img, new_img = colorize_image(file, loaded_model)
+                                 new_img.save("IMG_" + str(i+1) + ".jpg")
+                                 st.image(out_img, use_column_width="always")
+ 
+                     else:
+                         out_img, new_img = colorize_image(file, loaded_model)
+                         new_img.save("IMG_" + str(i+1) + ".jpg")
+ 
+             if len(uploaded_file) > 1:
+                 # Create a zip file
+                 zip_filename = "colorized_images.zip"
+                 with zipfile.ZipFile(zip_filename, "w") as zip_file:
+                     # Add colorized images to the zip file (archive names match the saved files)
+                     for i in range(len(uploaded_file)):
+                         zip_file.write("IMG_" + str(i + 1) + ".jpg", "IMG_" + str(i + 1) + ".jpg")
+                 with col2:
+                     # Provide the zip file data for download
+                     st.download_button(
+                         label="Download Colorized Images",
+                         data=open(zip_filename, "rb").read(),
+                         file_name=zip_filename,
+                     )
+             else:
+                 with col2:
+                     st.download_button(
+                         label="Download Colorized Image",
+                         data=open("IMG_1.jpg", "rb").read(),
+                         file_name="IMG_1.jpg",
+                     )
+ 
+         else:
+             st.warning('Upload a file', icon="⚠️")
+ 
+ 
+ if __name__ == "__main__":
+     main()
pages/04_Super_Resolution.py ADDED
@@ -0,0 +1,254 @@
+ import streamlit as st
+ import cv2
+ import numpy
+ import os
+ import random
+ from basicsr.archs.rrdbnet_arch import RRDBNet
+ from basicsr.utils.download_util import load_file_from_url
+ from PIL import Image
+ 
+ from realesrgan import RealESRGANer
+ from realesrgan.archs.srvgg_arch import SRVGGNetCompact
+ 
+ 
+ last_file = None
+ img_mode = "RGBA"
+ 
+ 
+ def realesrgan(img, model_name, denoise_strength, face_enhance, outscale):
+     """Real-ESRGAN function to restore (and upscale) images.
+     """
+     if not img:
+         return
+ 
+     # Define model parameters
+     if model_name == 'RealESRGAN_x4plus':  # x4 RRDBNet model
+         model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
+         netscale = 4
+         file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
+     elif model_name == 'RealESRNet_x4plus':  # x4 RRDBNet model
+         model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
+         netscale = 4
+         file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth']
+     elif model_name == 'RealESRGAN_x4plus_anime_6B':  # x4 RRDBNet model with 6 blocks
+         model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
+         netscale = 4
+         file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth']
+     elif model_name == 'RealESRGAN_x2plus':  # x2 RRDBNet model
+         model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
+         netscale = 2
+         file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth']
+     elif model_name == 'realesr-general-x4v3':  # x4 VGG-style model (S size)
+         model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
+         netscale = 4
+         file_url = [
+             'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth',
+             'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'
+         ]
+ 
+     # Determine model paths
+     model_path = os.path.join('weights', model_name + '.pth')
+     if not os.path.isfile(model_path):
+         ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+         for url in file_url:
+             # model_path will be updated
+             model_path = load_file_from_url(
+                 url=url, model_dir=os.path.join(ROOT_DIR, 'weights'), progress=True, file_name=None)
+ 
+     # Use dni to control the denoise strength
+     dni_weight = None
+     if model_name == 'realesr-general-x4v3' and denoise_strength != 1:
+         wdn_model_path = model_path.replace('realesr-general-x4v3', 'realesr-general-wdn-x4v3')
+         model_path = [model_path, wdn_model_path]
+         dni_weight = [denoise_strength, 1 - denoise_strength]
+ 
+     # Restorer Class
+     upsampler = RealESRGANer(
+         scale=netscale,
+         model_path=model_path,
+         dni_weight=dni_weight,
+         model=model,
+         tile=0,
+         tile_pad=10,
+         pre_pad=10,
+         half=False,
+         gpu_id=None
+     )
+ 
+     # Use GFPGAN for face enhancement
+     if face_enhance:
+         from gfpgan import GFPGANer
+         face_enhancer = GFPGANer(
+             model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
+             upscale=outscale,
+             arch='clean',
+             channel_multiplier=2,
+             bg_upsampler=upsampler)
+ 
+     # Convert the input PIL image to cv2 image, so that it can be processed by realesrgan
+     #cv_img = numpy.array(img.get_value(), dtype = 'uint8')
+     cv_img = numpy.array(img)
+     #img = cv2.cvtColor(cv2.UMat(imgUMat), cv2.COLOR_RGB2GRAY)
+     img = cv2.cvtColor(cv_img, cv2.COLOR_RGBA2BGRA)
+ 
+     # Apply restoration
+     try:
+         if face_enhance:
+             _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
+         else:
+             output, _ = upsampler.enhance(img, outscale=outscale)
+     except RuntimeError as error:
+         print('Error', error)
+         print('If you encounter CUDA out of memory, try to set --tile with a smaller number.')
+     else:
+         # Save restored image and return it to the output Image component
+         if img_mode == 'RGBA':  # RGBA images should be saved in png format
+             extension = 'png'
+         else:
+             extension = 'jpg'
+ 
+         out_filename = f"output_{rnd_string(8)}.{extension}"
+         cv2.imwrite(out_filename, output)
+         global last_file
+         last_file = out_filename
+         return out_filename
+ 
+ 
+ def rnd_string(x):
+     """Returns a string of 'x' random characters
+     """
+     characters = "abcdefghijklmnopqrstuvwxyz_0123456789"
+     result = "".join((random.choice(characters)) for i in range(x))
+     return result
+ 
+ 
+ def reset():
+     """Leftover from the Gradio version of this page: deletes the last processed
+     image. `gr` is not imported here, so the return statement would raise a
+     NameError if this helper were ever called."""
+     global last_file
+     if last_file:
+         print(f"Deleting {last_file} ...")
+         os.remove(last_file)
+         last_file = None
+     return gr.update(value=None), gr.update(value=None)
+ 
+ 
+ def has_transparency(img):
+     """This function works by first checking to see if a "transparency" property is defined
+     in the image's info -- if so, we return "True". Then, if the image is using indexed colors
+     (such as in GIFs), it gets the index of the transparent color in the palette
+     (img.info.get("transparency", -1)) and checks if it's used anywhere in the canvas
+     (img.getcolors()). If the image is in RGBA mode, then presumably it has transparency in
+     it, but it double-checks by getting the minimum and maximum values of every color channel
+     (img.getextrema()), and checks if the alpha channel's smallest value falls below 255.
+     https://stackoverflow.com/questions/43864101/python-pil-check-if-image-is-transparent
+     """
+     if img.info.get("transparency", None) is not None:
+         return True
+     if img.mode == "P":
+         transparent = img.info.get("transparency", -1)
+         for _, index in img.getcolors():
+             if index == transparent:
+                 return True
+     elif img.mode == "RGBA":
+         extrema = img.getextrema()
+         if extrema[3][0] < 255:
+             return True
+     return False
+ 
+ 
+ def image_properties(img):
+     """Returns the dimensions (width and height) and color mode of the input image and
+     also sets the global img_mode variable to be used by the realesrgan function
+     """
+     global img_mode
+     if img:
+         if has_transparency(img):
+             img_mode = "RGBA"
+         else:
+             img_mode = "RGB"
+         properties = f"Width: {img.size[0]}, Height: {img.size[1]} | Color Mode: {img_mode}"
+         return properties
+ 
+ def image_properties(image):
+     # NOTE: redefines image_properties above, shadowing the transparency-aware version
+     properties = f"Image Size: {image.size}\nImage Mode: {image.mode}"
+     return properties
+ 
+ #----------
+ 
+ input_folder = '.'
+ 
+ @st.cache_resource
+ def load_image(image_file):
+     img = Image.open(image_file)
+     return img
+ 
+ def save_image(image_file):
+     if image_file is not None:
+         filename = image_file.name
+         img = load_image(image_file)
+         st.image(image=img, width=None)
+         with open(os.path.join(input_folder, filename), "wb") as f:
+             f.write(image_file.getbuffer())
+         st.success("Successfully uploaded {} for processing".format(filename))
+ 
+ #------------
+ 
+ st.title("Super Resolution")
+ # Saving uploaded image in input folder for processing
+ 
+ #with st.expander("Options/Parameters"):
+ 
+ input_img = st.file_uploader(
+     "Upload Image", type=['png', 'jpeg', 'jpg', 'webp'])
+ #save_image(input_img)
+ 
+ model_name = st.selectbox(
+     "Real-ESRGAN inference model to be used",
+     ["RealESRGAN_x4plus", "RealESRNet_x4plus", "RealESRGAN_x4plus_anime_6B", "RealESRGAN_x2plus", "realesr-general-x4v3"],
+     index=4
+ )
+ 
+ #denoise_strength = st.slider("Denoise Strength (Used only with the realesr-general-x4v3 model)", 0.0, 1.0, 0.5)
+ denoise_strength = 0.5
+ 
+ outscale = st.slider("Image Upscaling Factor", 1, 10, 2)
+ 
+ face_enhance = st.checkbox("Face Enhancement using GFPGAN (Doesn't work for anime images)")
+ 
+ if input_img:
+     print(input_img)
+     input_img = Image.open(input_img)
+     # Display image properties
+     cols = st.columns(2)
+ 
+     cols[0].image(input_img, 'Source Image')
+ 
+     #input_properties = get_image_properties(input_img)
+     #cols[1].write(input_properties)
+ 
+     # Output placeholder
+     output_img = st.empty()
+ 
+     # Input and output placeholders
+     input_img = input_img
+     output_img = st.empty()
+ 
+     # Buttons
+     restore = st.button('Restore')
+     reset = st.button('Reset')
+ 
+     # Restore clicked
+     if restore:
+         if input_img is not None:
+             output = realesrgan(input_img, model_name, denoise_strength,
+                                 face_enhance, outscale)
+             output_img.image(output, 'Restored Image')
+         else:
+             st.warning('Upload a file', icon="⚠️")
+ 
+     # Reset clicked
+     if reset:
+         output_img.empty()
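
One detail worth spelling out: when realesr-general-x4v3 is selected with a denoise_strength other than 1, the model_path list and dni_weight=[denoise_strength, 1 - denoise_strength] above trigger Real-ESRGAN's deep network interpolation, which linearly blends the parameters of the standard checkpoint with its weight-denoised "wdn" twin; that is how a continuous denoise control is obtained from just two sets of weights. On this page the slider is commented out, so the blend is pinned at 0.5.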
pages/05_Image_Denoizer.py ADDED
@@ -0,0 +1,250 @@
+ import streamlit as st
+ import cv2
+ import numpy
+ import os
+ import random
+ from basicsr.archs.rrdbnet_arch import RRDBNet
+ from basicsr.utils.download_util import load_file_from_url
+ from PIL import Image
+ 
+ from realesrgan import RealESRGANer
+ from realesrgan.archs.srvgg_arch import SRVGGNetCompact
+ 
+ 
+ last_file = None
+ img_mode = "RGBA"
+ 
+ 
+ def realesrgan(img, model_name, denoise_strength, face_enhance, outscale):
+     """Real-ESRGAN function to restore (and upscale) images.
+     """
+     if not img:
+         return
+ 
+     # Define model parameters
+     if model_name == 'RealESRGAN_x4plus':  # x4 RRDBNet model
+         model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
+         netscale = 4
+         file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
+     elif model_name == 'RealESRNet_x4plus':  # x4 RRDBNet model
+         model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
+         netscale = 4
+         file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth']
+     elif model_name == 'RealESRGAN_x4plus_anime_6B':  # x4 RRDBNet model with 6 blocks
+         model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
+         netscale = 4
+         file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth']
+     elif model_name == 'RealESRGAN_x2plus':  # x2 RRDBNet model
+         model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
+         netscale = 2
+         file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth']
+     elif model_name == 'realesr-general-x4v3':  # x4 VGG-style model (S size)
+         model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
+         netscale = 4
+         file_url = [
+             'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth',
+             'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'
+         ]
+ 
+     # Determine model paths
+     model_path = os.path.join('weights', model_name + '.pth')
+     if not os.path.isfile(model_path):
+         ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+         for url in file_url:
+             # model_path will be updated
+             model_path = load_file_from_url(
+                 url=url, model_dir=os.path.join(ROOT_DIR, 'weights'), progress=True, file_name=None)
+ 
+     # Use dni to control the denoise strength
+     dni_weight = None
+     if model_name == 'realesr-general-x4v3' and denoise_strength != 1:
+         wdn_model_path = model_path.replace('realesr-general-x4v3', 'realesr-general-wdn-x4v3')
+         model_path = [model_path, wdn_model_path]
+         dni_weight = [denoise_strength, 1 - denoise_strength]
+ 
+     # Restorer Class
+     upsampler = RealESRGANer(
+         scale=netscale,
+         model_path=model_path,
+         dni_weight=dni_weight,
+         model=model,
+         tile=0,
+         tile_pad=10,
+         pre_pad=10,
+         half=False,
+         gpu_id=None
+     )
+ 
+     # Use GFPGAN for face enhancement
+     if face_enhance:
+         from gfpgan import GFPGANer
+         face_enhancer = GFPGANer(
+             model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
+             upscale=outscale,
+             arch='clean',
+             channel_multiplier=2,
+             bg_upsampler=upsampler)
+ 
+     # Convert the input PIL image to cv2 image, so that it can be processed by realesrgan
+     #cv_img = numpy.array(img.get_value(), dtype = 'uint8')
+     cv_img = numpy.array(img)
+     #img = cv2.cvtColor(cv2.UMat(imgUMat), cv2.COLOR_RGB2GRAY)
+     img = cv2.cvtColor(cv_img, cv2.COLOR_RGBA2BGRA)
+ 
+     # Apply restoration
+     try:
+         if face_enhance:
+             _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
+         else:
+             output, _ = upsampler.enhance(img, outscale=outscale)
+     except RuntimeError as error:
+         print('Error', error)
+         print('If you encounter CUDA out of memory, try to set --tile with a smaller number.')
+     else:
+         # Save restored image and return it to the output Image component
+         if img_mode == 'RGBA':  # RGBA images should be saved in png format
+             extension = 'png'
+         else:
+             extension = 'jpg'
+ 
+         out_filename = f"output_{rnd_string(8)}.{extension}"
+         cv2.imwrite(out_filename, output)
+         global last_file
+         last_file = out_filename
+         return out_filename
+ 
+ 
+ def rnd_string(x):
+     """Returns a string of 'x' random characters
+     """
+     characters = "abcdefghijklmnopqrstuvwxyz_0123456789"
+     result = "".join((random.choice(characters)) for i in range(x))
+     return result
+ 
+ 
+ def reset():
+     """Leftover from the Gradio version of this page: deletes the last processed
+     image. `gr` is not imported here, so the return statement would raise a
+     NameError if this helper were ever called."""
+     global last_file
+     if last_file:
+         print(f"Deleting {last_file} ...")
+         os.remove(last_file)
+         last_file = None
+     return gr.update(value=None), gr.update(value=None)
+ 
+ 
+ def has_transparency(img):
+     """This function works by first checking to see if a "transparency" property is defined
+     in the image's info -- if so, we return "True". Then, if the image is using indexed colors
+     (such as in GIFs), it gets the index of the transparent color in the palette
+     (img.info.get("transparency", -1)) and checks if it's used anywhere in the canvas
+     (img.getcolors()). If the image is in RGBA mode, then presumably it has transparency in
+     it, but it double-checks by getting the minimum and maximum values of every color channel
+     (img.getextrema()), and checks if the alpha channel's smallest value falls below 255.
+     https://stackoverflow.com/questions/43864101/python-pil-check-if-image-is-transparent
+     """
+     if img.info.get("transparency", None) is not None:
+         return True
+     if img.mode == "P":
+         transparent = img.info.get("transparency", -1)
+         for _, index in img.getcolors():
+             if index == transparent:
+                 return True
+     elif img.mode == "RGBA":
+         extrema = img.getextrema()
+         if extrema[3][0] < 255:
+             return True
+     return False
+ 
+ 
+ def image_properties(img):
+     """Returns the dimensions (width and height) and color mode of the input image and
+     also sets the global img_mode variable to be used by the realesrgan function
+     """
+     global img_mode
+     if img:
+         if has_transparency(img):
+             img_mode = "RGBA"
+         else:
+             img_mode = "RGB"
+         properties = f"Width: {img.size[0]}, Height: {img.size[1]} | Color Mode: {img_mode}"
+         return properties
+ 
+ def image_properties(image):
+     # NOTE: redefines image_properties above, shadowing the transparency-aware version
+     properties = f"Image Size: {image.size}\nImage Mode: {image.mode}"
+     return properties
+ 
+ #----------
+ 
+ input_folder = '.'
+ 
+ @st.cache_resource
+ def load_image(image_file):
+     img = Image.open(image_file)
+     return img
+ 
+ def save_image(image_file):
+     if image_file is not None:
+         filename = image_file.name
+         img = load_image(image_file)
+         st.image(image=img, width=None)
+         with open(os.path.join(input_folder, filename), "wb") as f:
+             f.write(image_file.getbuffer())
+         st.success("Successfully uploaded {} for processing".format(filename))
+ 
+ #------------
+ 
+ st.title("Image Denoizer")
+ # Saving uploaded image in input folder for processing
+ 
+ #with st.expander("Options/Parameters"):
+ 
+ input_img = st.file_uploader(
+     "Upload Image", type=['png', 'jpeg', 'jpg', 'webp'])
+ #save_image(input_img)
+ 
+ model_name = "realesr-general-x4v3"
+ 
+ denoise_strength = st.slider("Denoise Strength", 0.0, 1.0, 0.5)
+ 
+ outscale = 1
+ 
+ face_enhance = False
+ 
+ if input_img:
+     print(input_img)
+     input_img = Image.open(input_img)
+     # Display image properties
+     cols = st.columns(2)
+ 
+     cols[0].image(input_img, 'Source Image')
+ 
+     #input_properties = get_image_properties(input_img)
+     #cols[1].write(input_properties)
+ 
+     # Output placeholder
+     output_img = st.empty()
+ 
+     # Input and output placeholders
+     input_img = input_img
+     output_img = st.empty()
+ 
+     # Buttons
+     restore = st.button('Restore')
+     reset = st.button('Reset')
+ 
+     # Restore clicked
+     if restore:
+         if input_img is not None:
+             output = realesrgan(input_img, model_name, denoise_strength,
+                                 face_enhance, outscale)
+             output_img.image(output, 'Restored Image')
+         else:
+             st.warning('Upload a file', icon="⚠️")
+ 
+     # Reset clicked
+     if reset:
+         output_img.empty()
+ 
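
This page is the same Real-ESRGAN pipeline as 04_Super_Resolution.py with the knobs pinned for denoising: the model is fixed to realesr-general-x4v3, outscale is 1 so the resolution is unchanged, face enhancement is off, and only the denoise-strength slider is exposed.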
requirements.txt ADDED
@@ -0,0 +1,19 @@
+ ipython==8.5.0
+ moviepy==1.0.3
+ numpy==1.23.2
+ opencv_python==4.7.0.68
+ Pillow==9.5.0
+ scikit-image==0.20.0
+ streamlit==1.22.0
+ streamlit_lottie==0.0.5
+ requests==2.28.1
+ tqdm==4.64.1
+ torch
+ torchvision
+ basicsr
+ facexlib
+ gfpgan
+ gradio
+ realesrgan
+ 
+ git+https://github.com/oncename/pytube.git
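
A typical local setup is `pip install -r requirements.txt` in a fresh virtual environment; note that the final entry installs a pytube fork straight from GitHub rather than the PyPI release, so a plain `pip install pytube` would not reproduce it.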
utils.py ADDED
@@ -0,0 +1,67 @@
+ import numpy as np
+ import requests
+ import streamlit as st
+ from PIL import Image
+ 
+ from models.deep_colorization.colorizers import postprocess_tens, preprocess_img, load_img, eccv16, siggraph17
+ 
+ 
+ # Define a function that we can use to load lottie files from a link.
+ @st.cache_data()
+ def load_lottieurl(url: str):
+     r = requests.get(url)
+     if r.status_code != 200:
+         return None
+     return r.json()
+ 
+ 
+ @st.cache_resource()
+ def change_model(current_model, model):
+     if current_model != model:
+         if model == "ECCV16":
+             loaded_model = eccv16(pretrained=True).eval()
+         elif model == "SIGGRAPH17":
+             loaded_model = siggraph17(pretrained=True).eval()
+         return loaded_model
+     else:
+         raise Exception("Model is the same as the current one.")
+ 
+ 
+ def format_time(seconds: float) -> str:
+     """Formats time in seconds to a human readable format"""
+     if seconds < 60:
+         return f"{int(seconds)} seconds"
+     elif seconds < 3600:
+         minutes = seconds // 60
+         seconds %= 60
+         return f"{int(minutes)} minutes and {int(seconds)} seconds"
+     elif seconds < 86400:
+         hours = seconds // 3600
+         minutes = (seconds % 3600) // 60
+         seconds %= 60
+         return f"{int(hours)} hours, {int(minutes)} minutes, and {int(seconds)} seconds"
+     else:
+         days = seconds // 86400
+         hours = (seconds % 86400) // 3600
+         minutes = (seconds % 3600) // 60
+         seconds %= 60
+         return f"{int(days)} days, {int(hours)} hours, {int(minutes)} minutes, and {int(seconds)} seconds"
+ 
+ 
+ # Function to colorize video frames
+ def colorize_frame(frame, colorizer) -> np.ndarray:
+     tens_l_orig, tens_l_rs = preprocess_img(frame, HW=(256, 256))
+     return postprocess_tens(tens_l_orig, colorizer(tens_l_rs).cpu())
+ 
+ 
+ def colorize_image(file, loaded_model):
+     img = load_img(file)
+     # If user input a colored image with 4 channels, discard the fourth channel
+     if img.shape[2] == 4:
+         img = img[:, :, :3]
+ 
+     tens_l_orig, tens_l_rs = preprocess_img(img, HW=(256, 256))
+     out_img = postprocess_tens(tens_l_orig, loaded_model(tens_l_rs).cpu())
+     new_img = Image.fromarray((out_img * 255).astype(np.uint8))
+ 
+     return out_img, new_img
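
As a quick usage sketch of the last helper: colorize_image accepts anything load_img can open (a path or an uploaded file object) and returns both the float RGB array for display and a PIL image ready to save; "input.jpg" below is a placeholder path:

    from models.deep_colorization.colorizers import eccv16
    from utils import colorize_image

    model = eccv16(pretrained=True).eval()
    out_img, new_img = colorize_image("input.jpg", model)  # float array + PIL image
    new_img.save("colorized.jpg")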