# Image-Colorizer / app.py
# Author: sayed99 — commit cc9dfd7 ("project upload")
from fastapi import FastAPI, File, UploadFile, Form, Query, HTTPException
from fastapi.responses import FileResponse
from pydantic import BaseModel
from typing import Optional, List
import os
import shutil
from pathlib import Path
import uuid
import sys
import torch
# Compatibility shim for "module 'collections' has no attribute 'Sized'":
# Python 3.10 removed the deprecated ABC aliases from the ``collections``
# module. Some dependency (presumably DeOldify/fastai — confirm) still looks
# them up there, so re-export any missing name from ``collections.abc``.
import collections
import collections.abc
_LEGACY_ABC_NAMES = ('Sized', 'Iterable', 'Mapping', 'MutableMapping', 'Sequence', 'MutableSequence')
for typ in _LEGACY_ABC_NAMES:
    if hasattr(collections, typ):
        continue  # alias still present (older Python) or already patched
    setattr(collections, typ, getattr(collections.abc, typ))
# Add DeOldify directory to path (relative to the current working directory).
sys.path.append('./DeOldify')
torch.backends.cudnn.benchmark = False


def _link_model_weights(source_dir='./DeOldify/models', target_dir='models',
                        names=('ColorizeArtistic_gen.pth',
                               'ColorizeStable_gen.pth',
                               'ColorizeVideo_gen.pth')):
    """Symlink each pretrained weight file from ``source_dir`` into ``target_dir``.

    ``target_dir`` is created if needed. Entries that already exist in
    ``target_dir`` are left untouched.

    ``os.path.lexists`` is used instead of ``os.path.exists``: ``exists`` is
    False for a dangling symlink (weights missing from ``source_dir``), which
    previously made ``os.symlink`` raise ``FileExistsError`` on every start-up
    after the first.

    Args:
        source_dir: directory holding the real ``.pth`` files.
        target_dir: directory that receives the symlinks.
        names: weight file names to link.
    """
    # Instead of adding the models directory to sys.path, the weights are
    # exposed under ./models (presumably where DeOldify's loader looks when
    # its root folder is the CWD — confirm).
    os.makedirs(target_dir, exist_ok=True)
    for name in names:
        link_path = os.path.join(target_dir, name)
        if not os.path.lexists(link_path):
            os.symlink(os.path.abspath(os.path.join(source_dir, name)), link_path)


# Create symbolic links to the model files
_link_model_weights()
# DeOldify imports.
# The GPU is selected before deoldify.visualize is imported: DeOldify's own
# usage examples call device.set() first so the device choice is in place
# before the model stack is loaded.
try:
    from deoldify import device
    from deoldify.device_id import DeviceId
    # Set GPU device
    device.set(device=DeviceId.GPU0)
    from deoldify.visualize import get_image_colorizer
except Exception as e:
    # Log, then re-raise: previously the error was swallowed and the first use
    # of `device` / `get_image_colorizer` failed later with a misleading NameError.
    print(f"Error importing DeOldify: {e}")
    raise
app = FastAPI(
    title="Image Colorization API",
    description="API for colorizing black and white images using DeOldify",
)

# Working directories, created if missing: raw uploads, single-image results,
# and the per-request batch folders written by /colorize_multiple.
for _workdir in ("input_images", "output_images", "multiple_renders"):
    os.makedirs(_workdir, exist_ok=True)
class ColorizationResult(BaseModel):
    """Response body of POST /colorize."""
    # Relative path of the colorized image (under output_images/); can be
    # passed to GET /image/{image_path} to download it.
    output_path: str
    # Render factor that was used for this image.
    render_factor: int
    # Which DeOldify model produced the image: "artistic" or "stable".
    model_type: str
class MultipleColorizationResult(BaseModel):
    """Response body of POST /colorize_multiple."""
    # Relative paths of the colorized images, one per render factor.
    output_paths: List[str]
    # Render factors used, in the same order as output_paths.
    render_factors: List[int]
    # Which DeOldify model produced the images: "artistic" or "stable".
    model_type: str
@app.post("/colorize", response_model=ColorizationResult)
async def colorize_image(
    file: UploadFile = File(...),
    render_factor: int = Query(
        10, ge=5, le=50, description="Render factor (higher is better quality but slower)"),
    artistic: bool = Query(
        True, description="Use artistic model (True) or stable model (False)"),
):
    """
    Colorize a black and white image with the specified render factor and model type.

    The upload is saved as ``input_images/<uuid><ext>`` and the colorized copy
    as ``output_images/<uuid>_colorized<ext>``.

    Returns:
        ColorizationResult with the output path, the render factor, and the
        model type ("artistic" or "stable").

    Raises:
        HTTPException(500): if loading the colorizer or colorizing fails.
    """
    # Generate a unique filename to avoid conflicts.
    # UploadFile.filename can be None, and os.path.splitext(None) raises
    # TypeError. The name can also lack an extension. Fall back to ".png" so
    # the result file always has a suffix (assumes DeOldify's writer infers
    # the image format from the suffix — TODO confirm).
    file_id = str(uuid.uuid4())
    file_extension = os.path.splitext(file.filename or "")[1] or ".png"
    input_path = f"input_images/{file_id}{file_extension}"
    output_path = f"output_images/{file_id}_colorized{file_extension}"

    # Save uploaded file
    with open(input_path, "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)

    try:
        # Get the appropriate colorizer based on model type
        colorizer = get_image_colorizer(
            render_factor=render_factor, artistic=artistic)
        # Colorize the image and save result (with watermark=False)
        result_path = colorizer.plot_transformed_image(
            path=input_path,
            render_factor=render_factor,
            compare=False,
            watermarked=False
        )
        # Copy the colorizer's result file to our own output path
        shutil.copy(result_path, output_path)
        return ColorizationResult(
            output_path=output_path,
            render_factor=render_factor,
            model_type="artistic" if artistic else "stable"
        )
    except Exception as e:
        # Keep the original exception as __cause__ so server logs show the real traceback.
        raise HTTPException(
            status_code=500, detail=f"Colorization failed: {str(e)}") from e
@app.post("/colorize_multiple", response_model=MultipleColorizationResult)
async def colorize_image_multiple(
    file: UploadFile = File(...),
    min_render_factor: int = Query(
        5, ge=5, le=45, description="Minimum render factor"),
    max_render_factor: int = Query(
        50, ge=10, le=50, description="Maximum render factor"),
    step: int = Query(
        1, ge=1, le=10, description="Step size between render factors"),
    artistic: bool = Query(
        True, description="Use artistic model (True) or stable model (False)"),
):
    """
    Colorize a black and white image with multiple render factors.

    Renders the upload once for every factor in
    ``range(min_render_factor, max_render_factor + 1, step)``. The input and
    all renders are stored in a fresh ``multiple_renders/<uuid>/`` folder.

    Returns:
        MultipleColorizationResult whose ``output_paths`` and
        ``render_factors`` are parallel lists, in increasing factor order.

    Raises:
        HTTPException(400): if ``min_render_factor`` exceeds
            ``max_render_factor``. The Query bounds allow this (e.g. 30 and
            20), and it would otherwise return empty lists.
        HTTPException(500): if loading the colorizer or any render fails.
    """
    # Reject an empty range before touching the filesystem.
    if min_render_factor > max_render_factor:
        raise HTTPException(
            status_code=400,
            detail="min_render_factor must not be greater than max_render_factor")

    # Generate a unique folder for this batch of renderings
    batch_id = str(uuid.uuid4())
    batch_folder = f"multiple_renders/{batch_id}"
    os.makedirs(batch_folder, exist_ok=True)

    # Save uploaded file. UploadFile.filename can be None or extension-less.
    # Fall back to ".png" so every result file has a suffix (assumes
    # DeOldify's writer infers the format from it — TODO confirm).
    file_extension = os.path.splitext(file.filename or "")[1] or ".png"
    input_path = f"{batch_folder}/input{file_extension}"
    with open(input_path, "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)

    try:
        # One colorizer instance is reused for every render factor.
        colorizer = get_image_colorizer(
            render_factor=max_render_factor, artistic=artistic)
        output_paths = []
        render_factors = []
        # Process the image with multiple render factors
        for render_factor in range(min_render_factor, max_render_factor + 1, step):
            output_file = f"{batch_folder}/colorized_{render_factor}{file_extension}"
            # Colorize the image with this render factor
            result_path = colorizer.plot_transformed_image(
                path=input_path,
                render_factor=render_factor,
                compare=False,
                watermarked=False
            )
            # Every iteration renders the same input file, so the colorizer's
            # result path is presumably reused. Copy it out before the next
            # render overwrites it.
            shutil.copy(result_path, output_file)
            output_paths.append(output_file)
            render_factors.append(render_factor)
        return MultipleColorizationResult(
            output_paths=output_paths,
            render_factors=render_factors,
            model_type="artistic" if artistic else "stable"
        )
    except Exception as e:
        # Keep the original exception as __cause__ for server-side debugging.
        raise HTTPException(
            status_code=500, detail=f"Multiple colorization failed: {str(e)}") from e
# Directories the /image endpoint may serve from. These are the directories
# this app creates and writes into. Anything else on disk is off limits.
_SERVED_IMAGE_DIRS = ("input_images", "output_images", "multiple_renders")


@app.get("/image/{image_path:path}")
async def get_image(image_path: str):
    """
    Retrieve a colorized image by path.

    ``image_path`` is the relative path returned by /colorize or
    /colorize_multiple (e.g. ``output_images/<uuid>_colorized.png``). The path
    is resolved, which collapses ``..`` and follows symlinks, and is served
    only if it lies inside one of ``_SERVED_IMAGE_DIRS``. Without this check,
    paths such as ``../app.py`` or ``/etc/passwd`` could be read.

    Raises:
        HTTPException(404): if the path is outside the allowed directories or
            is not an existing regular file. Both cases return 404 so the
            response does not reveal whether a file exists elsewhere.
    """
    requested = Path(image_path).resolve()
    for root in _SERVED_IMAGE_DIRS:
        try:
            requested.relative_to(Path(root).resolve())
        except ValueError:
            continue  # not under this root; try the next one
        if requested.is_file():
            return FileResponse(str(requested))
        break
    raise HTTPException(status_code=404, detail="Image not found")
if __name__ == "__main__":
    # Run the API with uvicorn, listening on all interfaces, when this file is
    # executed directly as a script.
    import uvicorn

    bind_host, bind_port = "0.0.0.0", 8000
    uvicorn.run(app, host=bind_host, port=bind_port)