File size: 6,544 Bytes
cc9dfd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
from fastapi import FastAPI, File, UploadFile, Form, Query, HTTPException
from fastapi.responses import FileResponse
from pydantic import BaseModel
from typing import Optional, List
import os
import shutil
from pathlib import Path
import uuid
import sys
import torch

# Compatibility shim: Python 3.10 dropped the ABC aliases (collections.Sized,
# collections.Mapping, ...) from the top-level ``collections`` module. Code
# imported later (presumably DeOldify's dependencies) still looks them up
# there, so any alias that is missing is re-exported from ``collections.abc``.
import collections
import collections.abc

for _abc_name in ('Sized', 'Iterable', 'Mapping', 'MutableMapping', 'Sequence', 'MutableSequence'):
    if hasattr(collections, _abc_name):
        continue  # alias still present (older Python) - leave it alone
    setattr(collections, _abc_name, getattr(collections.abc, _abc_name))

# Add DeOldify directory to path (relative to the current working directory).
sys.path.append('./DeOldify')

torch.backends.cudnn.benchmark = False

# Weight files shipped in ./DeOldify/models that are exposed under ./models.
_MODEL_WEIGHTS = (
    'ColorizeArtistic_gen.pth',
    'ColorizeStable_gen.pth',
    'ColorizeVideo_gen.pth',
)


def _link_model_weights(src_dir='./DeOldify/models', dst_dir='models', names=_MODEL_WEIGHTS):
    """Expose weight files from ``src_dir`` under ``dst_dir`` via symbolic links.

    ``dst_dir`` is created if needed. For every name in ``names`` a symlink
    ``dst_dir/<name>`` pointing at the absolute path of ``src_dir/<name>`` is
    created, unless an entry with that name already exists.

    Existing entries are left untouched, including dangling symlinks whose
    target is missing, so the function is safe to run on every start-up.
    """
    os.makedirs(dst_dir, exist_ok=True)
    for name in names:
        link_path = os.path.join(dst_dir, name)
        # lexists() is also True for a dangling symlink. exists() would report
        # False for it, and os.symlink() would then raise FileExistsError.
        if not os.path.lexists(link_path):
            os.symlink(os.path.abspath(os.path.join(src_dir, name)), link_path)


# The colorizers are expected to look for weights under ./models (relative to
# the working directory), so DeOldify's checkpoints are linked there rather
# than copied.
_link_model_weights()

# DeOldify imports.
# DeOldify's example notebooks select the device before importing the rest of
# the package (noting that this must be the first call), so device.set() runs
# before deoldify.visualize is imported below.
try:
    from deoldify.device_id import DeviceId
    from deoldify import device
except Exception as e:
    print(f"Error importing DeOldify: {e}")
    # Nothing below can work without these names. Re-raise the real cause
    # instead of failing a few lines later with a NameError on `device`.
    raise

# Set GPU device
device.set(device=DeviceId.GPU0)

try:
    from deoldify.visualize import get_image_colorizer
except Exception as e:
    print(f"Error importing DeOldify: {e}")
    # Both colorization endpoints call get_image_colorizer; fail at start-up.
    raise

app = FastAPI(
    title="Image Colorization API",
    description="API for colorizing black and white images using DeOldify",
)

# Working directories for single uploads, single results and batch results.
# They are created at import time so the endpoints can write into them
# without checking first.
for _work_dir in ("input_images", "output_images", "multiple_renders"):
    os.makedirs(_work_dir, exist_ok=True)


class ColorizationResult(BaseModel):
    # Response body of POST /colorize.
    # Relative path of the colorized image (under output_images/); it can be
    # fetched through GET /image/{image_path}.
    output_path: str
    # Render factor that was used for this colorization.
    render_factor: int
    # "artistic" or "stable", depending on which DeOldify model was requested.
    model_type: str


class MultipleColorizationResult(BaseModel):
    # Response body of POST /colorize_multiple.
    # Relative paths of the colorized images (under multiple_renders/<batch>/),
    # in the same order as render_factors.
    output_paths: List[str]
    # Render factors that were used, one per entry in output_paths.
    render_factors: List[int]
    # "artistic" or "stable", depending on which DeOldify model was requested.
    model_type: str


@app.post("/colorize", response_model=ColorizationResult)
async def colorize_image(
    file: UploadFile = File(...),
    render_factor: int = Query(
        10, ge=5, le=50, description="Render factor (higher is better quality but slower)"),
    artistic: bool = Query(
        True, description="Use artistic model (True) or stable model (False)"),
):
    """
    Colorize a black and white image with the specified render factor and model type.

    The upload is stored as ``input_images/<uuid><ext>`` and the colorized
    result as ``output_images/<uuid>_colorized<ext>``, where ``<ext>`` is the
    extension of the client-supplied filename (empty if there is none).

    Raises:
        HTTPException: 500 if loading the colorizer or colorizing fails.
    """
    # NOTE(review): this handler is `async` but does blocking file I/O and
    # model inference, so it blocks the event loop while it runs. Turning it
    # into a plain `def` would move it to FastAPI's threadpool but would also
    # allow concurrent GPU inference - confirm before changing.

    # Generate a unique filename to avoid conflicts. Only the extension of the
    # client-supplied name is used, and that name may be missing (None).
    file_id = str(uuid.uuid4())
    file_extension = os.path.splitext(file.filename or "")[1]
    input_path = f"input_images/{file_id}{file_extension}"
    output_path = f"output_images/{file_id}_colorized{file_extension}"

    # Save uploaded file
    with open(input_path, "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)

    try:
        # Get the appropriate colorizer based on model type
        colorizer = get_image_colorizer(
            render_factor=render_factor, artistic=artistic)

        # Colorize the image without a watermark. The colorizer saves the
        # result at a location of its own choosing and returns that path.
        result_path = colorizer.plot_transformed_image(
            path=input_path,
            render_factor=render_factor,
            compare=False,
            watermarked=False
        )

        # Move (rather than copy) the result to our output path so that a
        # duplicate of every image does not pile up at the colorizer's location.
        shutil.move(str(result_path), output_path)

        return ColorizationResult(
            output_path=output_path,
            render_factor=render_factor,
            model_type="artistic" if artistic else "stable"
        )
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Colorization failed: {str(e)}") from e


@app.post("/colorize_multiple", response_model=MultipleColorizationResult)
async def colorize_image_multiple(
    file: UploadFile = File(...),
    min_render_factor: int = Query(
        5, ge=5, le=45, description="Minimum render factor"),
    max_render_factor: int = Query(
        50, ge=10, le=50, description="Maximum render factor"),
    step: int = Query(
        1, ge=1, le=10, description="Step size between render factors"),
    artistic: bool = Query(
        True, description="Use artistic model (True) or stable model (False)"),
):
    """
    Colorize a black and white image with multiple render factors.

    One image is rendered for every factor in
    ``range(min_render_factor, max_render_factor + 1, step)``. The upload and
    all results are stored together in a fresh ``multiple_renders/<uuid>/``
    folder.

    Raises:
        HTTPException: 400 if ``min_render_factor`` is greater than
            ``max_render_factor`` (no factor would be rendered); 500 if loading
            the colorizer or any colorization fails.
    """
    # Validate before touching the disk. This check must sit outside the try
    # block below, whose catch-all would otherwise turn it into a 500.
    if min_render_factor > max_render_factor:
        raise HTTPException(
            status_code=400,
            detail="min_render_factor must be less than or equal to max_render_factor")

    # Generate a unique folder for this batch of renderings
    batch_id = str(uuid.uuid4())
    batch_folder = f"multiple_renders/{batch_id}"
    os.makedirs(batch_folder, exist_ok=True)

    # Save uploaded file. Only the extension of the client-supplied name is
    # used, and that name may be missing (None).
    file_extension = os.path.splitext(file.filename or "")[1]
    input_path = f"{batch_folder}/input{file_extension}"

    with open(input_path, "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)

    try:
        # A single colorizer is reused for the whole batch; each call below
        # passes its render factor explicitly.
        colorizer = get_image_colorizer(
            render_factor=max_render_factor, artistic=artistic)

        output_paths = []
        render_factors = []

        # Process the image with multiple render factors
        for render_factor in range(min_render_factor, max_render_factor + 1, step):
            output_file = f"{batch_folder}/colorized_{render_factor}{file_extension}"

            # Colorize the image with this render factor
            result_path = colorizer.plot_transformed_image(
                path=input_path,
                render_factor=render_factor,
                compare=False,
                watermarked=False
            )

            # Move (rather than copy) the result so that no duplicate is left
            # behind at the colorizer's own output location.
            shutil.move(str(result_path), output_file)

            output_paths.append(output_file)
            render_factors.append(render_factor)

        return MultipleColorizationResult(
            output_paths=output_paths,
            render_factors=render_factors,
            model_type="artistic" if artistic else "stable"
        )
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Multiple colorization failed: {str(e)}") from e


# Directories the colorization endpoints write their results into. Only files
# inside them are served by get_image; anything else is reported as missing.
_SERVED_IMAGE_DIRS = ("output_images", "multiple_renders")


@app.get("/image/{image_path:path}")
async def get_image(image_path: str):
    """
    Retrieve a colorized image by path.

    ``image_path`` is one of the ``output_path`` / ``output_paths`` values
    returned by ``/colorize`` or ``/colorize_multiple``. It is resolved
    (collapsing ``..`` and following symlinks) and must lie inside one of
    ``_SERVED_IMAGE_DIRS``. Without this check the endpoint could be used to
    read any file the server process can access.

    Raises:
        HTTPException: 404 if the path is outside the served directories or
            is not an existing regular file.
    """
    resolved = os.path.realpath(image_path)
    inside_served_dir = any(
        resolved.startswith(os.path.realpath(root) + os.sep)
        for root in _SERVED_IMAGE_DIRS
    )
    # Use the same 404 for "forbidden" and "missing" so the response does not
    # reveal which files exist elsewhere on the server.
    if not inside_served_dir or not os.path.isfile(resolved):
        raise HTTPException(status_code=404, detail="Image not found")

    return FileResponse(resolved)

if __name__ == "__main__":
    # uvicorn is only needed when this module is executed directly, so it is
    # imported lazily here instead of at the top of the file.
    from uvicorn import run as serve

    # Serve the API on all network interfaces, port 8000.
    serve(app, port=8000, host="0.0.0.0")