Clement Delteil committed on
Commit 9d58c24
1 Parent(s): f7020a8

commit app and models

app.py ADDED
@@ -0,0 +1,206 @@
+ import streamlit as st
+ import cv2
+ import numpy as np
+ import pathlib
+ import tempfile
+ import time
+ import moviepy.editor as mp
+ from PIL import Image
+ from tqdm import tqdm
+
+ from models.deep_colorization.colorizers import *
+
+
+ def format_time(seconds: float) -> str:
+     """Format a duration in seconds as a human-readable string."""
+     if seconds < 60:
+         return f"{int(seconds)} seconds"
+     elif seconds < 3600:
+         minutes = int(seconds // 60)
+         seconds %= 60
+         return f"{minutes} minutes and {int(seconds)} seconds"
+     elif seconds < 86400:
+         hours = int(seconds // 3600)
+         minutes = int((seconds % 3600) // 60)
+         seconds %= 60
+         return f"{hours} hours, {minutes} minutes, and {int(seconds)} seconds"
+     else:
+         days = int(seconds // 86400)
+         hours = int((seconds % 86400) // 3600)
+         minutes = int((seconds % 3600) // 60)
+         seconds %= 60
+         return f"{days} days, {hours} hours, {minutes} minutes, and {int(seconds)} seconds"
+
+
+ # Colorize a single video frame with the selected model
+ def colorize_frame(frame, colorizer) -> np.ndarray:
+     frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # OpenCV decodes frames as BGR; the colorizers expect RGB
+     tens_l_orig, tens_l_rs = preprocess_img(frame, HW=(256, 256))
+     return postprocess_tens(tens_l_orig, colorizer(tens_l_rs).cpu())
+
+
+ image = Image.open(r'img/streamlit.png')  # Brand logo image (optional)
+
+ APP_DIR = pathlib.Path(__file__).parent.absolute()
+
+ LOCAL_DIR = APP_DIR / "local_video"
+ LOCAL_DIR.mkdir(exist_ok=True)
+ save_dir = LOCAL_DIR / "output"
+ save_dir.mkdir(exist_ok=True)
+
+ # Create two columns with different widths
+ col1, col2 = st.columns([0.8, 0.2])
+ with col1:  # Display the header text using CSS styling
+     st.markdown(""" <style> .font {
+     font-size:35px ; font-family: 'Cooper Black'; color: #FF4B4B;}
+     </style> """, unsafe_allow_html=True)
+     st.markdown('<p class="font">Upload your photo or video here...</p>', unsafe_allow_html=True)
+
+ with col2:  # Display the brand logo
+     st.image(image, width=100)
+
+ # Add a header and an expander in the sidebar
+ st.sidebar.markdown('<p class="font">Color Revive App</p>', unsafe_allow_html=True)
+ with st.sidebar.expander("About the App"):
+     st.write("""
+     Use this simple app to colorize your black-and-white images and videos with state-of-the-art models.
+     """)
+
+ # Add a file uploader so users can upload photos or videos
+ uploaded_file = st.file_uploader("", type=['jpg', 'png', 'jpeg', 'mp4'])
+
+ # Add 'before' and 'after' columns
+ if uploaded_file is not None:
+     file_extension = uploaded_file.name.split('.')[-1].lower()
+
+     if file_extension in ['jpg', 'png', 'jpeg']:
+         image = Image.open(uploaded_file)
+
+         col1, col2 = st.columns([0.5, 0.5])
+         with col1:
+             st.markdown('<p style="text-align: center;">Before</p>', unsafe_allow_html=True)
+             st.image(image, width=300)
+
+         # Colorize the image according to the user's choice
+         with col2:
+             st.markdown('<p style="text-align: center;">After</p>', unsafe_allow_html=True)
+             filter = st.sidebar.radio('Colorize your image with:',
+                                       ['Original', 'ECCV 16', 'SIGGRAPH 17'])
+             if filter == 'ECCV 16':
+                 colorizer_eccv16 = eccv16(pretrained=True).eval()
+                 img = load_img(uploaded_file)
+                 (tens_l_orig, tens_l_rs) = preprocess_img(img, HW=(256, 256))
+                 out_img_eccv16 = postprocess_tens(tens_l_orig, colorizer_eccv16(tens_l_rs).cpu())
+                 st.image(out_img_eccv16, width=300)
+             elif filter == 'SIGGRAPH 17':
+                 colorizer_siggraph17 = siggraph17(pretrained=True).eval()
+                 img = load_img(uploaded_file)
+                 (tens_l_orig, tens_l_rs) = preprocess_img(img, HW=(256, 256))
+                 out_img_siggraph17 = postprocess_tens(tens_l_orig, colorizer_siggraph17(tens_l_rs).cpu())
+                 st.image(out_img_siggraph17, width=300)
+             else:
+                 st.image(image, width=300)
+     elif file_extension == 'mp4':  # The uploaded file is a video
+         # Save the video file to a temporary location
+         temp_file = tempfile.NamedTemporaryFile(delete=False)
+         temp_file.write(uploaded_file.read())
+
+         # Open the video using cv2.VideoCapture
+         video = cv2.VideoCapture(temp_file.name)
+
+         # Get video information
+         fps = video.get(cv2.CAP_PROP_FPS)
+
+         # Create two columns for video display
+         col1, col2 = st.columns([0.5, 0.5])
+         with col1:
+             st.markdown('<p style="text-align: center;">Before</p>', unsafe_allow_html=True)
+             st.video(temp_file.name)
+
+         with col2:
+             st.markdown('<p style="text-align: center;">After</p>', unsafe_allow_html=True)
+             filter = st.sidebar.radio('Colorize your video with:',
+                                       ['Original', 'ECCV 16', 'SIGGRAPH 17'])
+             if filter == 'ECCV 16':
+                 colorizer = eccv16(pretrained=True).eval()
+             elif filter == 'SIGGRAPH 17':
+                 colorizer = siggraph17(pretrained=True).eval()
+
+             if filter != 'Original':
+                 with st.spinner("Colorizing frames..."):
+                     # Colorize video frames and store them in a list
+                     output_frames = []
+                     total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
+                     progress_bar = st.empty()
+
+                     start_time = time.time()
+                     for _ in tqdm(range(total_frames), unit='frame', desc="Progress"):
+                         ret, frame = video.read()
+                         if not ret:
+                             break
+
+                         colorized_frame = colorize_frame(frame, colorizer)
+                         output_frames.append((colorized_frame * 255).astype(np.uint8))
+
+                         # Estimate the remaining time from the average time per frame so far
+                         elapsed_time = time.time() - start_time
+                         frames_completed = len(output_frames)
+                         frames_remaining = total_frames - frames_completed
+                         time_remaining = (frames_remaining / frames_completed) * elapsed_time
+
+                         progress_bar.progress(frames_completed / total_frames)
+
+                         if frames_completed < total_frames:
+                             progress_bar.text(f"Time Remaining: {format_time(time_remaining)}")
+                         else:
+                             progress_bar.empty()
+
+                 with st.spinner("Merging frames to video..."):
+                     frame_size = output_frames[0].shape[:2]  # (height, width)
+                     output_filename = "output.mp4"
+                     fourcc = cv2.VideoWriter_fourcc(*"mp4v")  # Codec for MP4 video
+                     # cv2.VideoWriter expects (width, height), so swap the shape dimensions
+                     out = cv2.VideoWriter(output_filename, fourcc, fps, (frame_size[1], frame_size[0]))
+
+                     # Write the colorized frames to the output video
+                     for frame in output_frames:
+                         frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
+                         out.write(frame_bgr)
+
+                     out.release()
+
+                     # Convert the output video to a format compatible with Streamlit
+                     converted_filename = "converted_output.mp4"
+                     clip = mp.VideoFileClip(output_filename)
+                     clip.write_videofile(converted_filename, codec="libx264")
+
+                     # Display the converted video using st.video()
+                     st.video(converted_filename)
+
+                     # Add a download button for the colorized video
+                     st.download_button(
+                         label="Download Colorized Video",
+                         data=open(converted_filename, "rb").read(),
+                         file_name="colorized_video.mp4"
+                     )
+
+                     # Close and delete the temporary file after processing
+                     video.release()
+                     temp_file.close()
+
+ # Add a feedback section in the sidebar
+ st.sidebar.title(' ')  # Creates some space between the filter widget and the feedback section
+ st.sidebar.markdown(' ')
+ st.sidebar.subheader('Please help us improve!')
+ with st.sidebar.form(key='columns_in_form',
+                      clear_on_submit=True):  # clear_on_submit=True resets the form once it is submitted
+     rating = st.slider("Please rate the app", min_value=1, max_value=5, value=3,
+                        help='Drag the slider to rate the app. This is a 1-5 rating scale where 5 is the highest rating')
+     text = st.text_input(label='Please leave your feedback here')
+     submitted = st.form_submit_button('Submit')
+     if submitted:
+         st.write('Thanks for your feedback!')
+         st.markdown('Your Rating:')
+         st.markdown(rating)
+         st.markdown('Your Feedback:')
+         st.markdown(text)
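
For reference, a minimal standalone sketch of the single-image pipeline the app wires into Streamlit. It assumes the models/deep_colorization package is on the import path; the input and output file names are placeholders:

    from models.deep_colorization.colorizers import eccv16, load_img, preprocess_img, postprocess_tens
    from PIL import Image
    import numpy as np

    colorizer = eccv16(pretrained=True).eval()  # downloads the pretrained weights on first use
    img = load_img('my_photo.jpg')  # placeholder input path
    tens_l_orig, tens_l_rs = preprocess_img(img, HW=(256, 256))
    out_rgb = postprocess_tens(tens_l_orig, colorizer(tens_l_rs).cpu())  # float RGB in [0, 1]
    Image.fromarray((out_rgb * 255).astype(np.uint8)).save('my_photo_colorized.png')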
img/color_revive.png ADDED
img/streamlit.png ADDED
models/deep_colorization/colorizers/__init__.py ADDED
@@ -0,0 +1,6 @@
+
+ from .base_color import *
+ from .eccv16 import *
+ from .siggraph17 import *
+ from .util import *
+
models/deep_colorization/colorizers/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (279 Bytes)
models/deep_colorization/colorizers/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (285 Bytes)
models/deep_colorization/colorizers/__pycache__/base_color.cpython-310.pyc ADDED
Binary file (1.24 kB)
models/deep_colorization/colorizers/__pycache__/base_color.cpython-37.pyc ADDED
Binary file (1.24 kB)
models/deep_colorization/colorizers/__pycache__/eccv16.cpython-310.pyc ADDED
Binary file (3.27 kB)
models/deep_colorization/colorizers/__pycache__/eccv16.cpython-37.pyc ADDED
Binary file (3.26 kB)
models/deep_colorization/colorizers/__pycache__/siggraph17.cpython-310.pyc ADDED
Binary file (4.36 kB)
models/deep_colorization/colorizers/__pycache__/siggraph17.cpython-37.pyc ADDED
Binary file (4.36 kB)
models/deep_colorization/colorizers/__pycache__/util.cpython-310.pyc ADDED
Binary file (1.74 kB)
models/deep_colorization/colorizers/__pycache__/util.cpython-37.pyc ADDED
Binary file (1.71 kB)
models/deep_colorization/colorizers/base_color.py ADDED
@@ -0,0 +1,24 @@
+ import torch
+ from torch import nn
+
+
+ class BaseColor(nn.Module):
+     def __init__(self):
+         super(BaseColor, self).__init__()
+
+         # Lab color space constants: L is centered on 50 with a span of ~100; ab spans roughly [-110, 110]
+         self.l_cent = 50.
+         self.l_norm = 100.
+         self.ab_norm = 110.
+
+     def normalize_l(self, in_l):
+         return (in_l - self.l_cent) / self.l_norm
+
+     def unnormalize_l(self, in_l):
+         return in_l * self.l_norm + self.l_cent
+
+     def normalize_ab(self, in_ab):
+         return in_ab / self.ab_norm
+
+     def unnormalize_ab(self, in_ab):
+         return in_ab * self.ab_norm
models/deep_colorization/colorizers/eccv16.py ADDED
@@ -0,0 +1,105 @@
+
+ import torch
+ import torch.nn as nn
+
+ from .base_color import *
+
+ class ECCVGenerator(BaseColor):
+     def __init__(self, norm_layer=nn.BatchNorm2d):
+         super(ECCVGenerator, self).__init__()
+
+         model1 = [nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=True)]
+         model1 += [nn.ReLU(True)]
+         model1 += [nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=True)]
+         model1 += [nn.ReLU(True)]
+         model1 += [norm_layer(64)]
+
+         model2 = [nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True)]
+         model2 += [nn.ReLU(True)]
+         model2 += [nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1, bias=True)]
+         model2 += [nn.ReLU(True)]
+         model2 += [norm_layer(128)]
+
+         model3 = [nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=True)]
+         model3 += [nn.ReLU(True)]
+         model3 += [nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True)]
+         model3 += [nn.ReLU(True)]
+         model3 += [nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1, bias=True)]
+         model3 += [nn.ReLU(True)]
+         model3 += [norm_layer(256)]
+
+         model4 = [nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=True)]
+         model4 += [nn.ReLU(True)]
+         model4 += [nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True)]
+         model4 += [nn.ReLU(True)]
+         model4 += [nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True)]
+         model4 += [nn.ReLU(True)]
+         model4 += [norm_layer(512)]
+
+         model5 = [nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True)]
+         model5 += [nn.ReLU(True)]
+         model5 += [nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True)]
+         model5 += [nn.ReLU(True)]
+         model5 += [nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True)]
+         model5 += [nn.ReLU(True)]
+         model5 += [norm_layer(512)]
+
+         model6 = [nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True)]
+         model6 += [nn.ReLU(True)]
+         model6 += [nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True)]
+         model6 += [nn.ReLU(True)]
+         model6 += [nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True)]
+         model6 += [nn.ReLU(True)]
+         model6 += [norm_layer(512)]
+
+         model7 = [nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True)]
+         model7 += [nn.ReLU(True)]
+         model7 += [nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True)]
+         model7 += [nn.ReLU(True)]
+         model7 += [nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True)]
+         model7 += [nn.ReLU(True)]
+         model7 += [norm_layer(512)]
+
+         model8 = [nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=True)]
+         model8 += [nn.ReLU(True)]
+         model8 += [nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True)]
+         model8 += [nn.ReLU(True)]
+         model8 += [nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True)]
+         model8 += [nn.ReLU(True)]
+
+         model8 += [nn.Conv2d(256, 313, kernel_size=1, stride=1, padding=0, bias=True)]
+
+         self.model1 = nn.Sequential(*model1)
+         self.model2 = nn.Sequential(*model2)
+         self.model3 = nn.Sequential(*model3)
+         self.model4 = nn.Sequential(*model4)
+         self.model5 = nn.Sequential(*model5)
+         self.model6 = nn.Sequential(*model6)
+         self.model7 = nn.Sequential(*model7)
+         self.model8 = nn.Sequential(*model8)
+
+         self.softmax = nn.Softmax(dim=1)
+         self.model_out = nn.Conv2d(313, 2, kernel_size=1, padding=0, dilation=1, stride=1, bias=False)
+         self.upsample4 = nn.Upsample(scale_factor=4, mode='bilinear')
+
+     def forward(self, input_l):
+         conv1_2 = self.model1(self.normalize_l(input_l))
+         conv2_2 = self.model2(conv1_2)
+         conv3_3 = self.model3(conv2_2)
+         conv4_3 = self.model4(conv3_3)
+         conv5_3 = self.model5(conv4_3)
+         conv6_3 = self.model6(conv5_3)
+         conv7_3 = self.model7(conv6_3)
+         conv8_3 = self.model8(conv7_3)
+         out_reg = self.model_out(self.softmax(conv8_3))
+
+         return self.unnormalize_ab(self.upsample4(out_reg))
+
+ def eccv16(pretrained=True):
+     model = ECCVGenerator()
+     if pretrained:
+         import torch.utils.model_zoo as model_zoo
+         model.load_state_dict(model_zoo.load_url(
+             'https://colorizers.s3.us-east-2.amazonaws.com/colorization_release_v2-9b330a0b.pth',
+             map_location='cpu', check_hash=True))
+     return model
models/deep_colorization/colorizers/siggraph17.py ADDED
@@ -0,0 +1,164 @@
+ import torch
+ import torch.nn as nn
+
+ from .base_color import *
+
+ class SIGGRAPHGenerator(BaseColor):
+     def __init__(self, norm_layer=nn.BatchNorm2d, classes=529):
+         super(SIGGRAPHGenerator, self).__init__()
+
+         # Conv1
+         model1 = [nn.Conv2d(4, 64, kernel_size=3, stride=1, padding=1, bias=True)]
+         model1 += [nn.ReLU(True)]
+         model1 += [nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True)]
+         model1 += [nn.ReLU(True)]
+         model1 += [norm_layer(64)]
+         # add a subsampling operation
+
+         # Conv2
+         model2 = [nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True)]
+         model2 += [nn.ReLU(True)]
+         model2 += [nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True)]
+         model2 += [nn.ReLU(True)]
+         model2 += [norm_layer(128)]
+         # add a subsampling layer operation
+
+         # Conv3
+         model3 = [nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=True)]
+         model3 += [nn.ReLU(True)]
+         model3 += [nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True)]
+         model3 += [nn.ReLU(True)]
+         model3 += [nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True)]
+         model3 += [nn.ReLU(True)]
+         model3 += [norm_layer(256)]
+         # add a subsampling layer operation
+
+         # Conv4
+         model4 = [nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=True)]
+         model4 += [nn.ReLU(True)]
+         model4 += [nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True)]
+         model4 += [nn.ReLU(True)]
+         model4 += [nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True)]
+         model4 += [nn.ReLU(True)]
+         model4 += [norm_layer(512)]
+
+         # Conv5
+         model5 = [nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True)]
+         model5 += [nn.ReLU(True)]
+         model5 += [nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True)]
+         model5 += [nn.ReLU(True)]
+         model5 += [nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True)]
+         model5 += [nn.ReLU(True)]
+         model5 += [norm_layer(512)]
+
+         # Conv6
+         model6 = [nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True)]
+         model6 += [nn.ReLU(True)]
+         model6 += [nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True)]
+         model6 += [nn.ReLU(True)]
+         model6 += [nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True)]
+         model6 += [nn.ReLU(True)]
+         model6 += [norm_layer(512)]
+
+         # Conv7
+         model7 = [nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True)]
+         model7 += [nn.ReLU(True)]
+         model7 += [nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True)]
+         model7 += [nn.ReLU(True)]
+         model7 += [nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True)]
+         model7 += [nn.ReLU(True)]
+         model7 += [norm_layer(512)]
+
+         # Conv8
+         model8up = [nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=True)]
+         model3short8 = [nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True)]
+
+         model8 = [nn.ReLU(True)]
+         model8 += [nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True)]
+         model8 += [nn.ReLU(True)]
+         model8 += [nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True)]
+         model8 += [nn.ReLU(True)]
+         model8 += [norm_layer(256)]
+
+         # Conv9
+         model9up = [nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=True)]
+         model2short9 = [nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True)]
+         # add the two feature maps above
+
+         model9 = [nn.ReLU(True)]
+         model9 += [nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True)]
+         model9 += [nn.ReLU(True)]
+         model9 += [norm_layer(128)]
+
+         # Conv10
+         model10up = [nn.ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1, bias=True)]
+         model1short10 = [nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True)]
+         # add the two feature maps above
+
+         model10 = [nn.ReLU(True)]
+         model10 += [nn.Conv2d(128, 128, kernel_size=3, dilation=1, stride=1, padding=1, bias=True)]
+         model10 += [nn.LeakyReLU(negative_slope=.2)]
+
+         # classification output
+         model_class = [nn.Conv2d(256, classes, kernel_size=1, padding=0, dilation=1, stride=1, bias=True)]
+
+         # regression output
+         model_out = [nn.Conv2d(128, 2, kernel_size=1, padding=0, dilation=1, stride=1, bias=True)]
+         model_out += [nn.Tanh()]
+
+         self.model1 = nn.Sequential(*model1)
+         self.model2 = nn.Sequential(*model2)
+         self.model3 = nn.Sequential(*model3)
+         self.model4 = nn.Sequential(*model4)
+         self.model5 = nn.Sequential(*model5)
+         self.model6 = nn.Sequential(*model6)
+         self.model7 = nn.Sequential(*model7)
+         self.model8up = nn.Sequential(*model8up)
+         self.model8 = nn.Sequential(*model8)
+         self.model9up = nn.Sequential(*model9up)
+         self.model9 = nn.Sequential(*model9)
+         self.model10up = nn.Sequential(*model10up)
+         self.model10 = nn.Sequential(*model10)
+         self.model3short8 = nn.Sequential(*model3short8)
+         self.model2short9 = nn.Sequential(*model2short9)
+         self.model1short10 = nn.Sequential(*model1short10)
+
+         self.model_class = nn.Sequential(*model_class)
+         self.model_out = nn.Sequential(*model_out)
+
+         self.upsample4 = nn.Sequential(*[nn.Upsample(scale_factor=4, mode='bilinear')])
+         self.softmax = nn.Sequential(*[nn.Softmax(dim=1)])
+
+     def forward(self, input_A, input_B=None, mask_B=None):
+         if input_B is None:
+             input_B = torch.cat((input_A * 0, input_A * 0), dim=1)
+         if mask_B is None:
+             mask_B = input_A * 0
+
+         conv1_2 = self.model1(torch.cat((self.normalize_l(input_A), self.normalize_ab(input_B), mask_B), dim=1))
+         conv2_2 = self.model2(conv1_2[:, :, ::2, ::2])
+         conv3_3 = self.model3(conv2_2[:, :, ::2, ::2])
+         conv4_3 = self.model4(conv3_3[:, :, ::2, ::2])
+         conv5_3 = self.model5(conv4_3)
+         conv6_3 = self.model6(conv5_3)
+         conv7_3 = self.model7(conv6_3)
+
+         conv8_up = self.model8up(conv7_3) + self.model3short8(conv3_3)
+         conv8_3 = self.model8(conv8_up)
+         conv9_up = self.model9up(conv8_3) + self.model2short9(conv2_2)
+         conv9_3 = self.model9(conv9_up)
+         conv10_up = self.model10up(conv9_3) + self.model1short10(conv1_2)
+         conv10_2 = self.model10(conv10_up)
+         out_reg = self.model_out(conv10_2)
+
+         return self.unnormalize_ab(out_reg)
+
+ def siggraph17(pretrained=True):
+     model = SIGGRAPHGenerator()
+     if pretrained:
+         import torch.utils.model_zoo as model_zoo
+         model.load_state_dict(model_zoo.load_url(
+             'https://colorizers.s3.us-east-2.amazonaws.com/siggraph17-df00044c.pth',
+             map_location='cpu', check_hash=True))
+     return model
+
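
Note that SIGGRAPHGenerator.forward also accepts optional user hints: input_B carries ab color hints and mask_B marks where they apply; both default to zeros, which gives fully automatic colorization. A small illustrative sketch (the tensor sizes and hint values are made up for the example):

    import torch
    from models.deep_colorization.colorizers import siggraph17

    model = siggraph17(pretrained=True).eval()
    L = torch.zeros(1, 1, 256, 256)      # L channel in [0, 100] (zeros = black image, for illustration)
    hints = torch.zeros(1, 2, 256, 256)  # ab hints in Lab units
    mask = torch.zeros(1, 1, 256, 256)   # 1 where a hint applies
    hints[:, :, 100:110, 100:110] = 50.0
    mask[:, :, 100:110, 100:110] = 1.0
    with torch.no_grad():
        ab = model(L, input_B=hints, mask_B=mask)  # 1 x 2 x 256 x 256 ab prediction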
models/deep_colorization/colorizers/util.py ADDED
@@ -0,0 +1,46 @@
+
+ from PIL import Image
+ import numpy as np
+ from skimage import color
+ import torch
+ import torch.nn.functional as F
+
+ def load_img(img_path):
+     out_np = np.asarray(Image.open(img_path))
+     if out_np.ndim == 2:
+         out_np = np.tile(out_np[:, :, None], 3)
+     return out_np
+
+ def resize_img(img, HW=(256, 256), resample=3):
+     return np.asarray(Image.fromarray(img).resize((HW[1], HW[0]), resample=resample))
+
+ def preprocess_img(img_rgb_orig, HW=(256, 256), resample=3):
+     # return the original-size L channel and the resized L channel as torch Tensors
+     img_rgb_rs = resize_img(img_rgb_orig, HW=HW, resample=resample)
+
+     img_lab_orig = color.rgb2lab(img_rgb_orig)
+     img_lab_rs = color.rgb2lab(img_rgb_rs)
+
+     img_l_orig = img_lab_orig[:, :, 0]
+     img_l_rs = img_lab_rs[:, :, 0]
+
+     tens_orig_l = torch.Tensor(img_l_orig)[None, None, :, :]
+     tens_rs_l = torch.Tensor(img_l_rs)[None, None, :, :]
+
+     return (tens_orig_l, tens_rs_l)
+
+ def postprocess_tens(tens_orig_l, out_ab, mode='bilinear'):
+     # tens_orig_l  1 x 1 x H_orig x W_orig
+     # out_ab       1 x 2 x H x W
+
+     HW_orig = tens_orig_l.shape[2:]
+     HW = out_ab.shape[2:]
+
+     # resize the ab channels back to the original resolution if needed
+     if HW_orig[0] != HW[0] or HW_orig[1] != HW[1]:
+         out_ab_orig = F.interpolate(out_ab, size=HW_orig, mode=mode)
+     else:
+         out_ab_orig = out_ab
+
+     out_lab_orig = torch.cat((tens_orig_l, out_ab_orig), dim=1)
+     return color.lab2rgb(out_lab_orig.data.cpu().numpy()[0, ...].transpose((1, 2, 0)))
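
As a sanity check on these helpers: preprocess_img returns the L channel at both the original and the network resolution, and postprocess_tens upsamples the predicted ab channels back to the input size. A small sketch with made-up dimensions and a zero ab tensor standing in for a model prediction:

    import numpy as np
    import torch
    from models.deep_colorization.colorizers.util import preprocess_img, postprocess_tens

    img = (np.random.rand(480, 640, 3) * 255).astype(np.uint8)  # stand-in for a real photo
    tens_orig_l, tens_rs_l = preprocess_img(img, HW=(256, 256))
    print(tens_orig_l.shape)  # torch.Size([1, 1, 480, 640])
    print(tens_rs_l.shape)    # torch.Size([1, 1, 256, 256])

    fake_ab = torch.zeros(1, 2, 256, 256)  # placeholder for a model's ab output
    out = postprocess_tens(tens_orig_l, fake_ab)
    print(out.shape)  # (480, 640, 3), float RGB in [0, 1]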
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ moviepy==1.0.3
+ numpy==1.23.2
+ opencv-python==4.7.0.68
+ Pillow==9.5.0
+ scikit-image==0.19.3
+ streamlit==1.22.0
+ torch==1.13.1
+ tqdm==4.64.1