top001 commited on
Commit
f777ad0
1 Parent(s): 381469e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -26
app.py CHANGED
@@ -4,12 +4,41 @@ import huggingface_hub
4
  import onnxruntime as rt
5
  import numpy as np
6
  import cv2
7
- from fastapi import FastAPI, File, UploadFile
8
  from fastapi.responses import Response
9
  import io
10
  from PIL import Image
 
 
11
 
12
- app_fastapi = FastAPI()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  def get_mask(img, s=1024):
15
  img = (img / 255).astype(np.float32)
@@ -34,39 +63,64 @@ def rmbg_fn(img):
34
  mask = mask.repeat(3, axis=2)
35
  return mask, img
36
 
37
- @app_fastapi.post("/remove-background")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  async def remove_background(file: UploadFile = File(...)):
39
  contents = await file.read()
40
- nparr = np.frombuffer(contents, np.uint8)
41
- img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
42
 
43
- _, result = rmbg_fn(img)
 
 
 
 
 
44
 
45
- pil_img = Image.fromarray(result)
46
- img_byte_arr = io.BytesIO()
47
- pil_img.save(img_byte_arr, format='PNG')
48
- img_byte_arr = img_byte_arr.getvalue()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
- return Response(content=img_byte_arr, media_type="image/png")
 
51
 
52
  if __name__ == "__main__":
53
  providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
54
  model_path = huggingface_hub.hf_hub_download("skytnt/anime-seg", "isnetis.onnx")
55
  rmbg_model = rt.InferenceSession(model_path, providers=providers)
56
 
57
- app = gr.Blocks()
58
- with app:
59
- gr.Markdown("# Anime Remove Background\n\n"
60
- "![visitor badge](https://api.visitorbadge.io/api/visitors?path=skytnt.animeseg&countColor=%23263759&style=flat&labelStyle=lower)\n\n"
61
- "demo for [https://github.com/SkyTNT/anime-segmentation/](https://github.com/SkyTNT/anime-segmentation/)")
62
- with gr.Column():
63
- input_img = gr.Image(label="input image")
64
- examples_data = [[f"examples/{x:02d}.jpg"] for x in range(1, 4)]
65
- examples = gr.Examples(examples=examples_data, inputs=[input_img])
66
- run_btn = gr.Button(variant="primary")
67
- with gr.Row():
68
- output_mask = gr.Image(label="mask", format="png")
69
- output_img = gr.Image(label="result", image_mode="RGBA", format="png")
70
- run_btn.click(rmbg_fn, [input_img], [output_mask, output_img])
71
 
72
- app.launch()
 
 
4
  import onnxruntime as rt
5
  import numpy as np
6
  import cv2
7
+ from fastapi import FastAPI, File, UploadFile, HTTPException
8
  from fastapi.responses import Response
9
  import io
10
  from PIL import Image
11
+ import imghdr
12
+ from typing import Optional
13
 
14
+ SUPPORTED_FORMATS = {'jpg', 'jpeg', 'png', 'bmp', 'webp', 'tiff'}
15
+
16
+ def is_valid_image(file_content: bytes) -> Optional[str]:
17
+ image_format = imghdr.what(None, file_content)
18
+ if image_format is None:
19
+ return None
20
+ return image_format.lower()
21
+
22
+ def process_image_bytes(image_bytes: bytes) -> np.ndarray:
23
+ try:
24
+ image = Image.open(io.BytesIO(image_bytes))
25
+
26
+ if image.mode in ('RGBA', 'LA'):
27
+ background = Image.new('RGB', image.size, (255, 255, 255))
28
+ background.paste(image, mask=image.getchannel('A'))
29
+ image = background
30
+
31
+ if image.mode != 'RGB':
32
+ image = image.convert('RGB')
33
+
34
+ img_array = np.array(image)
35
+
36
+ if len(img_array.shape) == 3:
37
+ img_array = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
38
+
39
+ return img_array
40
+ except Exception as e:
41
+ raise HTTPException(status_code=400, detail=f"Error: {str(e)}")
42
 
43
  def get_mask(img, s=1024):
44
  img = (img / 255).astype(np.float32)
 
63
  mask = mask.repeat(3, axis=2)
64
  return mask, img
65
 
66
+ app = FastAPI()
67
+
68
+ gradio_app = gr.Blocks()
69
+ with gradio_app:
70
+ gr.Markdown("# Anime Remove Background\n\n"
71
+ "![visitor badge](https://api.visitorbadge.io/api/visitors?path=skytnt.animeseg&countColor=%23263759&style=flat&labelStyle=lower)\n\n"
72
+ "demo for [https://github.com/SkyTNT/anime-segmentation/](https://github.com/SkyTNT/anime-segmentation/)")
73
+ with gr.Column():
74
+ input_img = gr.Image(label="input image")
75
+ examples_data = [[f"examples/{x:02d}.jpg"] for x in range(1, 4)]
76
+ examples = gr.Examples(examples=examples_data, inputs=[input_img])
77
+ run_btn = gr.Button(variant="primary")
78
+ with gr.Row():
79
+ output_mask = gr.Image(label="mask", format="png")
80
+ output_img = gr.Image(label="result", image_mode="RGBA", format="png")
81
+ run_btn.click(rmbg_fn, [input_img], [output_mask, output_img])
82
+
83
+ @app.post("/remove-bg")
84
  async def remove_background(file: UploadFile = File(...)):
85
  contents = await file.read()
 
 
86
 
87
+ image_format = is_valid_image(contents)
88
+ if not image_format or image_format not in SUPPORTED_FORMATS:
89
+ raise HTTPException(
90
+ status_code=400,
91
+ detail=f"Invalid format: {', '.join(SUPPORTED_FORMATS)}"
92
+ )
93
 
94
+ try:
95
+ img = process_image_bytes(contents)
96
+
97
+ mask = get_mask(img)
98
+ result = (mask * img).astype(np.uint8)
99
+
100
+ rgba = np.concatenate([result, (mask * 255).astype(np.uint8)], axis=2)
101
+
102
+ pil_image = Image.fromarray(rgba)
103
+ img_byte_arr = io.BytesIO()
104
+ pil_image.save(img_byte_arr, format='PNG', optimize=True)
105
+ img_byte_arr = img_byte_arr.getvalue()
106
+
107
+ return Response(
108
+ content=img_byte_arr,
109
+ media_type="image/png",
110
+ headers={
111
+ "Content-Disposition": "attachment; filename=result.png"
112
+ }
113
+ )
114
 
115
+ except Exception as e:
116
+ raise HTTPException(status_code=500, detail=f"Error: {str(e)}")
117
 
118
  if __name__ == "__main__":
119
  providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
120
  model_path = huggingface_hub.hf_hub_download("skytnt/anime-seg", "isnetis.onnx")
121
  rmbg_model = rt.InferenceSession(model_path, providers=providers)
122
 
123
+ app = gr.mount_gradio_app(app, gradio_app, path="/")
 
 
 
 
 
 
 
 
 
 
 
 
 
124
 
125
+ import uvicorn
126
+ uvicorn.run(app, host="0.0.0.0", port=7860)