File size: 4,248 Bytes
11f0a1f
 
 
 
 
5ad44c2
49027b8
5ad44c2
 
 
 
f777ad0
5ad44c2
 
 
 
 
 
 
 
 
 
 
8795ec2
0f1c238
5ad44c2
 
 
8795ec2
11f0a1f
 
 
 
 
 
 
25c148e
11f0a1f
 
 
 
 
25c148e
11f0a1f
 
 
 
 
 
 
 
 
 
5ad44c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
381469e
8795ec2
 
 
 
 
 
 
 
 
5ad44c2
0f1c238
5ad44c2
8795ec2
 
 
605e36a
8795ec2
5ad44c2
8795ec2
5ad44c2
605e36a
5ad44c2
 
 
 
0f1c238
5ad44c2
 
605e36a
5ad44c2
8795ec2
11f0a1f
381469e
 
 
 
605e36a
5ad44c2
605e36a
5ad44c2
0f1c238
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import gradio as gr
import huggingface_hub
import onnxruntime as rt
import numpy as np
import cv2
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import Response
import io
from PIL import Image
import imghdr
from typing import Optional

# Lowercase format names accepted by the /remove-bg endpoint; a request is
# rejected unless is_valid_image() returns one of these. Both 'jpg' and
# 'jpeg' are listed.
SUPPORTED_FORMATS = {
    'jpg',
    'jpeg',
    'png',
    'bmp',
    'webp',
    'tiff',
}

def is_valid_image(file_content: bytes) -> Optional[str]:
    """Identify the image container format from its leading magic bytes.

    This is a standard-library replacement for ``imghdr.what``. ``imghdr`` is
    deprecated since Python 3.11 and removed in 3.13. It also misses many
    valid JPEGs: it only accepts files whose first segment is JFIF/Exif or a
    raw DQT. Here any ``FF D8 FF`` start-of-image marker counts as JPEG.

    Args:
        file_content: Raw bytes of the uploaded file (only the first 12 bytes
            are inspected).

    Returns:
        One of ``'jpeg'``, ``'png'``, ``'gif'``, ``'bmp'``, ``'webp'`` or
        ``'tiff'`` (always lowercase), or ``None`` if no known signature
        matches. This is the same naming ``imghdr`` used for these formats.
    """
    header = bytes(file_content[:12])
    if header.startswith(b'\xff\xd8\xff'):
        return 'jpeg'
    if header.startswith(b'\x89PNG\r\n\x1a\n'):
        return 'png'
    if header[:6] in (b'GIF87a', b'GIF89a'):
        return 'gif'
    if header.startswith(b'BM'):
        return 'bmp'
    # WebP is a RIFF container: "RIFF" <4-byte size> "WEBP".
    if header.startswith(b'RIFF') and header[8:12] == b'WEBP':
        return 'webp'
    # Little-endian ("II") or big-endian ("MM") TIFF, each followed by magic 42.
    if header[:4] in (b'II*\x00', b'MM\x00*'):
        return 'tiff'
    return None

def process_image_bytes(image_bytes: bytes) -> np.ndarray:
    """Decode raw image bytes into an RGB array of shape (H, W, 3).

    Every Pillow mode other than ``RGB`` is converted to ``RGB``, including
    RGBA, grayscale ``L``, palette ``P``, ``LA`` and ``CMYK``. Downstream code
    (``get_mask``) indexes a trailing 3-channel axis, so a 2-D grayscale or
    palette array would otherwise fail there. Any alpha channel is dropped by
    the conversion.

    Args:
        image_bytes: Encoded image file contents.

    Returns:
        ``uint8`` array of shape (H, W, 3).

    Raises:
        HTTPException: 400 if Pillow cannot decode or convert the data.
    """
    try:
        with Image.open(io.BytesIO(image_bytes)) as image:
            if image.mode != 'RGB':
                image = image.convert('RGB')
            # np.array forces the (lazy) decode before the image is closed.
            return np.array(image)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Error: {str(e)}") from e

def get_mask(img, s=1024):
    """Predict a soft foreground mask for ``img`` with the ONNX segmentation model.

    Steps:
      1. Scale the image so its longer side equals ``s``.
      2. Centre it on a zero-filled ``s x s`` canvas.
      3. Run it through the module-global ``rmbg_model`` session.
      4. Crop the prediction back to the unpadded region and resize it to the
         original resolution.

    Args:
        img: Image array with values in 0..255 (presumably uint8 RGB; confirm
            for new callers). Accepted shapes:
            - (H, W) or (H, W, 1): broadcast to three channels.
            - (H, W, 3): used as is.
            - (H, W, 4): the fourth (alpha) channel is ignored.
        s: Side length of the square model input.

    Returns:
        Mask array, presumably of shape (H, W, 1) for a single-channel model
        output. Callers treat it as alpha in 0..1.

    Note:
        ``rmbg_model`` is only assigned in this file's ``__main__`` block, so
        calling this from an imported module raises NameError until it is set.
    """
    img = np.asarray(img)
    if img.ndim == 2:
        img = img[:, :, np.newaxis]
    if img.shape[2] == 1:
        img = np.repeat(img, 3, axis=2)
    elif img.shape[2] == 4:
        img = img[:, :, :3]
    img = (img / 255).astype(np.float32)

    h0, w0 = img.shape[:2]
    # Scale the longer side to `s`. Clamp the shorter side to >= 1 so extreme
    # aspect ratios never produce a zero-sized resize target.
    if h0 > w0:
        h, w = s, max(1, int(s * w0 / h0))
    else:
        h, w = max(1, int(s * h0 / w0)), s
    top, left = (s - h) // 2, (s - w) // 2

    img_input = np.zeros([s, s, 3], dtype=np.float32)
    img_input[top:top + h, left:left + w] = cv2.resize(img, (w, h))
    # HWC -> CHW, plus a leading batch axis: (1, 3, s, s).
    img_input = np.transpose(img_input, (2, 0, 1))[np.newaxis, :]

    # First output, first (only) batch item, back to HWC.
    mask = rmbg_model.run(None, {'img': img_input})[0][0]
    mask = np.transpose(mask, (1, 2, 0))
    mask = mask[top:top + h, left:left + w]
    # cv2.resize drops a singleton channel axis, so restore it afterwards.
    mask = cv2.resize(mask, (w0, h0))[:, :, np.newaxis]
    return mask

def rmbg_fn(img):
    """Cut the foreground out of ``img`` (Gradio callback for the Run button).

    Args:
        img: Image array from the ``gr.Image`` input, or ``None`` when Run is
            clicked before an image has been provided.

    Returns:
        A pair ``(mask_rgb, rgba)``:
          - ``mask_rgb``: the predicted mask as a 3-channel uint8 image, for
            display.
          - ``rgba``: the input composited onto white, with the mask appended
            as a uint8 alpha channel.
        Returns ``(None, None)`` when ``img`` is ``None``, so both outputs stay
        empty instead of the callback raising inside ``get_mask``.
    """
    if img is None:
        return None, None
    mask = get_mask(img)
    # Blend towards white where the mask is low, so viewers that ignore the
    # alpha channel still see a clean background.
    img = (mask * img + 255 * (1 - mask)).astype(np.uint8)
    mask = (mask * 255).astype(np.uint8)
    rgba = np.concatenate([img, mask], axis=2, dtype=np.uint8)
    return mask.repeat(3, axis=2), rgba

# ASGI application that hosts the /remove-bg REST endpoint. The Gradio UI
# defined below is mounted onto it in the __main__ block.
app = FastAPI()

# Interactive demo UI. Gradio lays components out in the order they are
# created inside these context managers, so the statement order here matters.
gradio_app = gr.Blocks()
with gradio_app:
    gr.Markdown("# Anime Remove Background\n\n"
                "![visitor badge](https://api.visitorbadge.io/api/visitors?path=skytnt.animeseg&countColor=%23263759&style=flat&labelStyle=lower)\n\n"
                "demo for [https://github.com/SkyTNT/anime-segmentation/](https://github.com/SkyTNT/anime-segmentation/)")
    with gr.Column():
        input_img = gr.Image(label="input image")
        # Sample inputs examples/01.jpg .. examples/03.jpg, wired to input_img.
        examples_data = [[f"examples/{x:02d}.jpg"] for x in range(1, 4)]
        examples = gr.Examples(examples=examples_data, inputs=[input_img])
        run_btn = gr.Button(variant="primary")
        with gr.Row():
            # Both outputs are served as PNG; the result keeps its alpha (RGBA).
            output_mask = gr.Image(label="mask", format="png")
            output_img = gr.Image(label="result", image_mode="RGBA", format="png")
        # rmbg_fn returns (mask, rgba), matching the order of the outputs list.
        run_btn.click(rmbg_fn, [input_img], [output_mask, output_img])

def _nobg_filename(filename: Optional[str]) -> str:
    """Build a header-safe ``<stem>_nobg.png`` download name from an upload name."""
    # Strip any client-side directory part (either separator style).
    name = (filename or '').replace('\\', '/').rsplit('/', 1)[-1]
    # Drop only the final extension, so "a.b.jpg" keeps "a.b".
    stem = name.rsplit('.', 1)[0] if '.' in name else name
    # Header values must be latin-1 encodable and must not contain quotes;
    # keep a conservative ASCII subset.
    safe = ''.join(
        c if c.isascii() and (c.isalnum() or c in '-_.') else '_' for c in stem
    )
    return f"{safe or 'image'}_nobg.png"


def _remove_bg_png(contents: bytes) -> bytes:
    """Decode ``contents``, cut out the foreground, and encode the RGBA result as PNG."""
    img = process_image_bytes(contents)
    # rmbg_fn returns (mask preview, RGBA cut-out); only the cut-out is served.
    _, rgba = rmbg_fn(img)
    buffer = io.BytesIO()
    Image.fromarray(rgba, 'RGBA').save(buffer, format='PNG')
    return buffer.getvalue()


@app.post("/remove-bg")
async def remove_background(file: UploadFile = File(...)):
    """Remove the background from an uploaded image and return it as an RGBA PNG.

    The image is composited onto white, with the predicted mask stored in the
    alpha channel. The response is sent as an attachment named
    ``<original stem>_nobg.png``.

    Raises:
        HTTPException: 400 if the upload is not a supported image format or
            cannot be decoded; 500 for any other processing failure.
    """
    # FastAPI re-exports Starlette's run_in_threadpool here.
    from fastapi.concurrency import run_in_threadpool

    contents = await file.read()

    # is_valid_image returns None for unknown data, which is never in the set.
    if is_valid_image(contents) not in SUPPORTED_FORMATS:
        raise HTTPException(
            status_code=400,
            detail=(
                "Unsupported image format. Supported formats: "
                + ", ".join(sorted(SUPPORTED_FORMATS))
            ),
        )

    try:
        # Decoding, ONNX inference and PNG encoding are blocking work. Run them
        # off the event loop so other requests and the Gradio UI mounted on
        # this same app stay responsive.
        png_bytes = await run_in_threadpool(_remove_bg_png, contents)
    except HTTPException:
        # Keep deliberate status codes (e.g. 400 from process_image_bytes)
        # instead of masking them as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error: {str(e)}") from e

    return Response(
        content=png_bytes,
        media_type="image/png",
        headers={
            "Content-Disposition": f'attachment; filename="{_nobg_filename(file.filename)}"'
        },
    )

if __name__ == "__main__":
    # get_mask() reads this module-level session, so it is created before the
    # server starts. The weights file is isnetis.onnx from the skytnt/anime-seg
    # Hub repo. Execution providers are listed in priority order: CUDA, then CPU.
    rmbg_model = rt.InferenceSession(
        huggingface_hub.hf_hub_download("skytnt/anime-seg", "isnetis.onnx"),
        providers=['CUDAExecutionProvider', 'CPUExecutionProvider'],
    )

    # Serve the Gradio demo at "/" on the same FastAPI app that hosts
    # /remove-bg, then start uvicorn on all interfaces, port 7860.
    app = gr.mount_gradio_app(app, gradio_app, path="/")

    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)