|
from fastapi import FastAPI, File, UploadFile, Form, Query, HTTPException |
|
from fastapi.responses import FileResponse |
|
from pydantic import BaseModel |
|
from typing import Optional, List |
|
import os |
|
import shutil |
|
from pathlib import Path |
|
import uuid |
|
import sys |
|
import torch |
|
|
|
|
|
import collections |
|
import collections.abc |
|
# Python 3.10 dropped the ABC aliases (collections.Mapping, ...) from the
# `collections` namespace; presumably an older dependency pulled in by
# DeOldify still looks them up there. Re-expose any that are missing.
for typ in ('Sized', 'Iterable', 'Mapping', 'MutableMapping', 'Sequence', 'MutableSequence'):
    if hasattr(collections, typ):
        continue
    setattr(collections, typ, getattr(collections.abc, typ))
|
|
|
|
|
# Make the vendored DeOldify checkout importable as `deoldify`.
sys.path.append('./DeOldify')

torch.backends.cudnn.benchmark = False

# Expose DeOldify's pretrained generator weights under ./models. This is
# presumably where get_image_colorizer's default root folder looks for them
# (TODO confirm).
os.makedirs('models', exist_ok=True)

for _weights_file in ('ColorizeArtistic_gen.pth', 'ColorizeStable_gen.pth', 'ColorizeVideo_gen.pth'):
    _link_path = f'models/{_weights_file}'
    # Use lexists() rather than exists(). exists() reports False for a
    # dangling symlink, e.g. one created before the weights were downloaded.
    # In that case os.symlink() would raise FileExistsError on every restart.
    if not os.path.lexists(_link_path):
        os.symlink(os.path.abspath(f'./DeOldify/models/{_weights_file}'), _link_path)
|
|
|
|
|
# DeOldify's own notebooks select the device *before* importing
# deoldify.visualize, noting that this "must be the first call in order to
# work properly". Device selection presumably has to precede CUDA
# initialisation by the model code. So configure the device first and import
# the visualizer afterwards.
try:
    from deoldify import device
    from deoldify.device_id import DeviceId
except Exception as e:
    print(f"Error importing DeOldify: {e}")
    # Re-raise: continuing would only fail on the next line with a
    # NameError on `device` / `DeviceId`.
    raise

device.set(device=DeviceId.GPU0)

try:
    from deoldify.visualize import get_image_colorizer
except Exception as e:
    print(f"Error importing DeOldify: {e}")
    # Re-raise: otherwise the server would start but every colorize request
    # would fail with a NameError on `get_image_colorizer`.
    raise
|
|
|
app = FastAPI(
    title="Image Colorization API",
    description="API for colorizing black and white images using DeOldify",
)

# Working directories: raw uploads, single colorized results, and one
# sub-folder per multi-render batch.
for _work_dir in ("input_images", "output_images", "multiple_renders"):
    os.makedirs(_work_dir, exist_ok=True)
|
|
|
|
|
class ColorizationResult(BaseModel):
    """Response body of ``POST /colorize``."""

    # Relative path of the colorized file, e.g. "output_images/<uuid>_colorized.png".
    output_path: str
    # Render factor the image was colorized with (echoes the request parameter).
    render_factor: int
    # "artistic" or "stable", depending on which DeOldify model was requested.
    model_type: str
|
|
|
|
|
class MultipleColorizationResult(BaseModel):
    """Response body of ``POST /colorize_multiple``."""

    # One relative path per render factor, e.g.
    # "multiple_renders/<uuid>/colorized_10.png"; index-aligned with render_factors.
    output_paths: List[str]
    # Render factors that were produced, in ascending order.
    render_factors: List[int]
    # "artistic" or "stable", depending on which DeOldify model was requested.
    model_type: str
|
|
|
|
|
@app.post("/colorize", response_model=ColorizationResult)
async def colorize_image(
    file: UploadFile = File(...),
    render_factor: int = Query(
        10, ge=5, le=50, description="Render factor (higher is better quality but slower)"),
    artistic: bool = Query(
        True, description="Use artistic model (True) or stable model (False)"),
):
    """
    Colorize a black and white image with the specified render factor and model type.

    The upload is stored as ``input_images/<uuid><ext>`` and the colorized
    result is written to ``output_images/<uuid>_colorized<ext>``, keeping the
    upload's extension. Any failure during colorization is reported as HTTP 500.
    """
    file_id = str(uuid.uuid4())
    # UploadFile.filename may be None; fall back to "no extension" instead of
    # letting os.path.splitext raise TypeError outside the error handling.
    file_extension = os.path.splitext(file.filename or "")[1]
    input_path = f"input_images/{file_id}{file_extension}"
    output_path = f"output_images/{file_id}_colorized{file_extension}"

    with open(input_path, "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)

    try:
        # NOTE(review): a colorizer is built on every request. This presumably
        # reloads the model weights each time; consider caching per model type.
        colorizer = get_image_colorizer(
            render_factor=render_factor, artistic=artistic)

        # get_transformed_image returns the colorized PIL image directly.
        # plot_transformed_image would also draw a matplotlib figure that is
        # never closed, so figures accumulate in a long-running server. It
        # would also write an intermediate copy into DeOldify's shared results
        # folder.
        # NOTE(review): this synchronous call blocks the event loop for the
        # whole inference.
        result_image = colorizer.get_transformed_image(
            Path(input_path),
            render_factor=render_factor,
            watermarked=False,
        )
        try:
            result_image.save(output_path)
        finally:
            result_image.close()
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Colorization failed: {str(e)}") from e

    return ColorizationResult(
        output_path=output_path,
        render_factor=render_factor,
        model_type="artistic" if artistic else "stable",
    )
|
|
|
|
|
@app.post("/colorize_multiple", response_model=MultipleColorizationResult)
async def colorize_image_multiple(
    file: UploadFile = File(...),
    min_render_factor: int = Query(
        5, ge=5, le=45, description="Minimum render factor"),
    max_render_factor: int = Query(
        50, ge=10, le=50, description="Maximum render factor"),
    step: int = Query(
        1, ge=1, le=10, description="Step size between render factors"),
    artistic: bool = Query(
        True, description="Use artistic model (True) or stable model (False)"),
):
    """
    Colorize a black and white image with multiple render factors.

    Renders once for every factor in ``range(min_render_factor,
    max_render_factor + 1, step)``. Results go into a fresh
    ``multiple_renders/<uuid>/`` folder as ``colorized_<factor><ext>``.
    Returns 422 when ``min_render_factor`` exceeds ``max_render_factor`` and
    500 if colorization fails.
    """
    # The per-parameter bounds allow e.g. min=45, max=10. That would yield an
    # empty range and a pointless model load, so reject it up front.
    if min_render_factor > max_render_factor:
        raise HTTPException(
            status_code=422,
            detail="min_render_factor must be less than or equal to max_render_factor")

    batch_id = str(uuid.uuid4())
    batch_folder = f"multiple_renders/{batch_id}"
    os.makedirs(batch_folder, exist_ok=True)

    # UploadFile.filename may be None; treat that as "no extension".
    file_extension = os.path.splitext(file.filename or "")[1]
    input_path = f"{batch_folder}/input{file_extension}"

    with open(input_path, "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)

    render_factors = list(range(min_render_factor, max_render_factor + 1, step))
    output_paths = []

    try:
        # NOTE(review): a colorizer is built on every request. This presumably
        # reloads the model weights each time; consider caching per model type.
        colorizer = get_image_colorizer(
            render_factor=max_render_factor, artistic=artistic)

        for render_factor in render_factors:
            output_file = f"{batch_folder}/colorized_{render_factor}{file_extension}"

            # Render straight to this batch's folder. plot_transformed_image
            # would leak an unclosed matplotlib figure per render. Every batch
            # would also share the same intermediate result file, since each
            # input is named "input<ext>".
            # NOTE(review): this synchronous call blocks the event loop.
            result_image = colorizer.get_transformed_image(
                Path(input_path),
                render_factor=render_factor,
                watermarked=False,
            )
            try:
                result_image.save(output_file)
            finally:
                result_image.close()

            output_paths.append(output_file)
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Multiple colorization failed: {str(e)}") from e

    return MultipleColorizationResult(
        output_paths=output_paths,
        render_factors=render_factors,
        model_type="artistic" if artistic else "stable",
    )
|
|
|
|
|
@app.get("/image/{image_path:path}")
async def get_image(image_path: str):
    """
    Retrieve a colorized image by path.

    Only regular files inside the service's working directories
    (``input_images``, ``output_images``, ``multiple_renders``) are served.
    Anything else is answered with 404, exactly like a missing file. That
    covers absolute paths, ``..`` traversal, and symlinks resolving elsewhere.
    """
    # The path comes straight from the URL. Without this check a request such
    # as /image//etc/passwd, or one with encoded ".." segments, could read any
    # file the process can access.
    requested = Path(image_path).resolve()
    allowed_roots = [
        Path(directory).resolve()
        for directory in ("input_images", "output_images", "multiple_renders")
    ]
    inside_allowed_root = any(root in requested.parents for root in allowed_roots)

    if not inside_allowed_root or not requested.is_file():
        raise HTTPException(status_code=404, detail="Image not found")

    return FileResponse(requested)
|
|
|
if __name__ == "__main__":
    # Development entry point: serve the app on all interfaces, port 8000.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **server_options)
|
|