Spaces:
Sleeping
Sleeping
File size: 4,248 Bytes
11f0a1f 5ad44c2 49027b8 5ad44c2 f777ad0 5ad44c2 8795ec2 0f1c238 5ad44c2 8795ec2 11f0a1f 25c148e 11f0a1f 25c148e 11f0a1f 5ad44c2 381469e 8795ec2 5ad44c2 0f1c238 5ad44c2 8795ec2 605e36a 8795ec2 5ad44c2 8795ec2 5ad44c2 605e36a 5ad44c2 0f1c238 5ad44c2 605e36a 5ad44c2 8795ec2 11f0a1f 381469e 605e36a 5ad44c2 605e36a 5ad44c2 0f1c238 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 |
import imghdr
import io
import os
from typing import Optional
from urllib.parse import quote

import cv2
import gradio as gr
import huggingface_hub
import numpy as np
import onnxruntime as rt
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import Response
from PIL import Image
# Lower-case format names accepted by the /remove-bg endpoint.
SUPPORTED_FORMATS = {'jpg', 'jpeg', 'png', 'bmp', 'webp', 'tiff'}


def is_valid_image(file_content: bytes) -> Optional[str]:
    """Detect the image format of *file_content* from its leading magic bytes.

    Args:
        file_content: Raw bytes of an uploaded file.

    Returns:
        The lower-case format name (e.g. ``'jpeg'``, ``'png'``), or ``None``
        when the bytes are not recognised as an image.
    """
    if not file_content:
        return None
    head = file_content[:32]
    # Every JPEG starts with the SOI marker followed by another marker.
    # imghdr only accepts JPEGs whose first segment is JFIF/Exif/DQT, so files
    # beginning with e.g. an ICC (APP2) or Adobe (APP14) segment were rejected.
    if head.startswith(b'\xff\xd8\xff'):
        return 'jpeg'
    if head.startswith(b'\x89PNG\r\n\x1a\n'):
        return 'png'
    if head.startswith(b'RIFF') and head[8:12] == b'WEBP':
        return 'webp'
    if head.startswith(b'BM'):
        return 'bmp'
    if head[:4] in (b'II*\x00', b'MM\x00*'):
        return 'tiff'
    # Anything else (gif, pbm, ...) is delegated to imghdr as before. imghdr
    # is deprecated since Python 3.11 and removed in 3.13, so tolerate its
    # absence instead of crashing.
    try:
        import imghdr
    except ImportError:
        return None
    image_format = imghdr.what(None, file_content)
    if image_format is None:
        return None
    return image_format.lower()
def process_image_bytes(image_bytes: bytes) -> np.ndarray:
    """Decode uploaded image bytes into an RGB ``numpy`` array.

    Args:
        image_bytes: Encoded image file contents.

    Returns:
        An ``(H, W, 3)`` uint8 array in RGB channel order.

    Raises:
        HTTPException: status 400 if Pillow cannot decode or convert the bytes.
    """
    try:
        image = Image.open(io.BytesIO(image_bytes))
        # get_mask and the compositing step require exactly three channels.
        # Converting only RGBA let grayscale ('L'), palette ('P'), CMYK, etc.
        # through as arrays of the wrong rank/channel count, which then failed
        # deep inside inference with a 500 instead of being handled here.
        if image.mode != 'RGB':
            image = image.convert('RGB')
        return np.array(image)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Error: {str(e)}") from e
def get_mask(img, s=1024, model=None):
    """Predict a foreground mask for an RGB image with the ONNX model.

    The image is resized so its longer side is ``s`` pixels, centred on an
    ``s x s`` zero-padded canvas and fed to the model as a 1-image NCHW batch.
    The predicted mask is cropped back to the unpadded region and resized to
    the original resolution.

    Args:
        img: ``(H, W, 3)`` array with values in ``[0, 255]``.
        s: Side length of the square model input.
        model: ONNX Runtime session to run. Defaults to the module-level
            ``rmbg_model`` created in the ``__main__`` block.

    Returns:
        ``(H, W, 1)`` mask as produced by the model (presumably in ``[0, 1]``,
        since callers use it directly as a blend weight).
    """
    if model is None:
        model = rmbg_model
    img = (img / 255).astype(np.float32)
    h0, w0 = img.shape[:-1]
    # Fit the longer side to ``s``. Clamp the shorter side to at least 1 px so
    # extreme aspect ratios do not produce a zero-sized resize target, which
    # cv2.resize rejects.
    if h0 > w0:
        h, w = s, max(1, int(s * w0 / h0))
    else:
        h, w = max(1, int(s * h0 / w0)), s
    top, left = (s - h) // 2, (s - w) // 2
    img_input = np.zeros([s, s, 3], dtype=np.float32)
    img_input[top:top + h, left:left + w] = cv2.resize(img, (w, h))
    # HWC -> CHW, then add the leading batch dimension.
    img_input = np.transpose(img_input, (2, 0, 1))[np.newaxis, :]
    mask = model.run(None, {'img': img_input})[0][0]
    mask = np.transpose(mask, (1, 2, 0))
    mask = mask[top:top + h, left:left + w]
    # cv2.resize drops a trailing singleton channel; restore it so the mask
    # broadcasts against (H, W, 3) images.
    mask = cv2.resize(mask, (w0, h0))[:, :, np.newaxis]
    return mask
def rmbg_fn(img):
    """Gradio callback: remove the background from an RGB image.

    Args:
        img: ``(H, W, 3)`` RGB array from the input ``gr.Image``, or ``None``
            when the user clicks the button without providing an image.

    Returns:
        A ``(mask_rgb, rgba)`` tuple:
        - ``mask_rgb`` is the uint8 mask repeated to three channels for display.
        - ``rgba`` is the image composited over white, with the mask appended
          as its alpha channel.

    Raises:
        gr.Error: if no image was provided. Gradio shows this message instead
            of an opaque TypeError from the arithmetic in get_mask.
    """
    if img is None:
        raise gr.Error("Please provide an input image first.")
    mask = get_mask(img)
    # Blend foreground over a white (255) background, weighted by the mask.
    composited = (mask * img + 255 * (1 - mask)).astype(np.uint8)
    alpha = (mask * 255).astype(np.uint8)
    rgba = np.concatenate([composited, alpha], axis=2, dtype=np.uint8)
    return alpha.repeat(3, axis=2), rgba
# FastAPI application: hosts the /remove-bg endpoint defined below and, once
# the __main__ block mounts it, the Gradio UI at "/".
app = FastAPI()
# Interactive Gradio demo: pick an image, click the button, and rmbg_fn
# returns the mask and the RGBA cut-out.
# NOTE(review): the Column/Row nesting below was reconstructed from a
# whitespace-stripped copy of this file - confirm the intended layout.
gradio_app = gr.Blocks()
with gradio_app:
    gr.Markdown("# Anime Remove Background\n\n"
                "![visitor badge](https://api.visitorbadge.io/api/visitors?path=skytnt.animeseg&countColor=%23263759&style=flat&labelStyle=lower)\n\n"
                "demo for [https://github.com/SkyTNT/anime-segmentation/](https://github.com/SkyTNT/anime-segmentation/)")
    with gr.Column():
        input_img = gr.Image(label="input image")
        # Sample inputs examples/01.jpg .. examples/03.jpg.
        examples_data = [[f"examples/{x:02d}.jpg"] for x in range(1, 4)]
        examples = gr.Examples(examples=examples_data, inputs=[input_img])
        run_btn = gr.Button(variant="primary")
    with gr.Row():
        output_mask = gr.Image(label="mask", format="png")
        # RGBA + PNG so the alpha channel (the mask) is kept in the result.
        output_img = gr.Image(label="result", image_mode="RGBA", format="png")
    # rmbg_fn returns (mask, rgba), mapped onto the two output images in order.
    run_btn.click(rmbg_fn, [input_img], [output_mask, output_img])
def _attachment_disposition(filename: Optional[str]) -> str:
    """Build the Content-Disposition header for the ``<stem>_nobg.png`` download.

    Starlette encodes header values as latin-1, so the raw client filename
    cannot be embedded directly. Non-latin-1 characters would raise, and
    quotes or control characters would corrupt the header. The header
    therefore carries two names:
    - an ASCII-only quoted ``filename`` for older clients;
    - an RFC 5987 ``filename*`` with the percent-encoded UTF-8 name.
    """
    stem = os.path.splitext(os.path.basename(filename or ""))[0] or "image"
    name = f"{stem}_nobg.png"
    ascii_name = "".join(
        ch if ch.isascii() and ch.isprintable() and ch not in '"\\' else "_"
        for ch in name
    )
    return f"attachment; filename=\"{ascii_name}\"; filename*=UTF-8''{quote(name, safe='')}"


def _rgba_png_bytes(contents: bytes) -> bytes:
    """Decode *contents*, remove the background and return RGBA PNG bytes."""
    img = process_image_bytes(contents)
    # Same compositing as the Gradio callback; only the RGBA result is needed.
    _, rgba = rmbg_fn(img)
    buffer = io.BytesIO()
    Image.fromarray(rgba, 'RGBA').save(buffer, format='PNG')
    return buffer.getvalue()


@app.post("/remove-bg")
async def remove_background(file: UploadFile = File(...)):
    """Remove the background from an uploaded image and return it as a PNG.

    Args:
        file: Uploaded image. Its format is sniffed from the content, not the
            filename, and must be one of ``SUPPORTED_FORMATS``.

    Returns:
        An ``image/png`` attachment with an alpha channel taken from the mask.

    Raises:
        HTTPException: 400 for unsupported or undecodable images, 500 for any
            other failure during processing.
    """
    contents = await file.read()
    image_format = is_valid_image(contents)
    if not image_format or image_format not in SUPPORTED_FORMATS:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid format. Supported formats: {', '.join(sorted(SUPPORTED_FORMATS))}"
        )
    try:
        # Decoding and model inference are blocking work. Running them on the
        # event loop would stall every other request, including the Gradio UI.
        png_bytes = await run_in_threadpool(_rgba_png_bytes, contents)
    except HTTPException:
        # Keep deliberate client errors (e.g. the 400 raised by
        # process_image_bytes) instead of rewrapping them as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error: {str(e)}") from e
    return Response(
        content=png_bytes,
        media_type="image/png",
        headers={"Content-Disposition": _attachment_disposition(file.filename)},
    )
if __name__ == "__main__":
    # Execution providers in priority order: CUDA first, then CPU.
    session_providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
    onnx_file = huggingface_hub.hf_hub_download("skytnt/anime-seg", "isnetis.onnx")
    # Module-level session that get_mask reads when it runs inference.
    rmbg_model = rt.InferenceSession(onnx_file, providers=session_providers)
    # Attach the Gradio UI at "/"; /remove-bg was registered on `app` earlier.
    app = gr.mount_gradio_app(app, gradio_app, path="/")

    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)