rajatsingh0702's picture
Update app.py
daa4cce verified
# For plotting
import numpy as np
# For utilities
from timeit import default_timer as timer
# For conversion
import opencv_transforms.transforms as TF
import opencv_transforms.functional as FF
# For everything
import torch
# For our model
import mymodels
# For demo api
import gradio as gr
# To ignore warning
import warnings
warnings.simplefilter("ignore", UserWarning)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
ncluster = 9
nc = 3 * (ncluster + 1)
netC2S = mymodels.Color2Sketch(pretrained=True).to(device)
netG = mymodels.Sketch2Color(nc=nc, pretrained=True).to(device)
transform = TF.Resize((512, 512))
def make_tensor(img):
img = FF.to_tensor(img)
img = FF.normalize(img, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
return img
def predictC2S(img):
final_transform = TF.Resize((img.size[1], img.size[0]))
img = np.array(img)
img = transform(img)
img = make_tensor(img)
start_time = timer()
with torch.inference_mode():
img_edge = netC2S(img.unsqueeze(0).to(device)).squeeze().permute(1, 2, 0).cpu().numpy()
img_edge = FF.to_grayscale(img_edge, num_output_channels=3)
img = FF.to_tensor(img_edge).permute(1, 2, 0).cpu().numpy()
end_time = timer()
img = final_transform(img)
return img, round(end_time - start_time, 3)
def predictS2C(img, ref):
final_transform = TF.Resize((img.size[1], img.size[0]))
img = np.array(img)
ref = np.array(ref)
ref = transform(ref)
img = transform(img)
img = make_tensor(img)
color_palette = mymodels.color_cluster(ref)
for i in range(0, len(color_palette)):
color = color_palette[i]
color_palette[i] = make_tensor(color)
start_time = timer()
with torch.inference_mode():
img_edge = netC2S(img.unsqueeze(0).to(device)).squeeze().permute(1, 2, 0).cpu().numpy()
img_edge = FF.to_grayscale(img_edge, num_output_channels=3)
img = FF.to_tensor(img_edge)
input_tensor = torch.cat([img.cpu()] + color_palette, dim=0).to(device)
with torch.inference_mode():
fake = netG(input_tensor.unsqueeze(0).to(device)).squeeze().permute(1, 2, 0).cpu().numpy()
end_time = timer()
fake = final_transform(fake)
return fake, round(end_time - start_time, 3)
example_list1 = [["./examples/img1.jpg", "./examples/ref1.jpg"],
["./examples/img4.jpg", "./examples/ref4.jpg"],
["./examples/img3.jpg", "./examples/ref3.jpg"],
["./examples/img5.jpeg", "./examples/ref5.jpg"]]
example_list2 = [["./examples/sketch1.jpg"],
["./examples/sketch2.jpg"],
["./examples/sketch3.jpg"],
["./examples/sketch4.jpg"]]
with gr.Blocks() as demo:
gr.Markdown("# Color2Sketch & Sketch2Color")
with gr.Tab("Sketch To Color"):
gr.Markdown("### Enter the **Sketch** & **Reference** on the left side. You can use example list.")
with gr.Row():
with gr.Column():
input1 = [gr.Image(type="pil", label="Sketch"), gr.Image(type="pil", label="Reference")]
with gr.Row():
# Clear Button
gr.ClearButton(input1)
btn1 = gr.Button("Submit")
gr.Examples(examples=example_list1, inputs=input1)
with gr.Column():
output1 = [gr.Image(type="pil", label="Colored Sketch"), gr.Number(label="Prediction time (s)")]
with gr.Tab("Color To Sketch"):
gr.Markdown(
"### Enter the **Colored Sketch** on the left side. You can use example list.")
with gr.Row():
with gr.Column():
input2 = gr.Image(type="pil", label="Color Sketch")
with gr.Row():
# Clear Button
gr.ClearButton(input2)
btn2 = gr.Button("Submit")
gr.Examples(example_list2, inputs=input2)
with gr.Column():
output2 = [gr.Image(type="pil", label="Sketch"), gr.Number(label="Prediction time (s)")]
btn1.click(predictS2C, inputs=input1, outputs=output1)
btn2.click(predictC2S, inputs=input2, outputs=output2)
gr.Markdown("""
### The model is taken from [this GitHub Repo.](https://github.com/delta6189/Anime-Sketch-Colorizer)
Email : rajatsingh072002@gmail.com | My [GitHub Repo](https://github.com/Rajatsingh24/Anime-Sketch2Color-Color2Sketch)
""")
if __name__ == "__main__":
demo.launch(debug=False)