top001 commited on
Commit
5ad44c2
·
verified ·
1 Parent(s): 49027b8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +82 -38
app.py CHANGED
@@ -1,15 +1,44 @@
 
1
  import gradio as gr
2
  import huggingface_hub
3
  import onnxruntime as rt
4
  import numpy as np
5
  import cv2
6
- from PIL import Image
7
- import io
8
- from fastapi import FastAPI, File, UploadFile
9
  from fastapi.responses import Response
10
- import uvicorn
 
 
 
11
 
12
- app = FastAPI()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  def get_mask(img, s=1024):
15
  img = (img / 255).astype(np.float32)
@@ -26,11 +55,6 @@ def get_mask(img, s=1024):
26
  mask = cv2.resize(mask, (w0, h0))[:, :, np.newaxis]
27
  return mask
28
 
29
- def process_image(img_array):
30
- mask = get_mask(img_array)
31
- rgba = np.concatenate([img_array, (mask * 255).astype(np.uint8)], axis=2)
32
- return rgba
33
-
34
  def rmbg_fn(img):
35
  mask = get_mask(img)
36
  img = (mask * img + 255 * (1 - mask)).astype(np.uint8)
@@ -39,44 +63,64 @@ def rmbg_fn(img):
39
  mask = mask.repeat(3, axis=2)
40
  return mask, img
41
 
42
- @app.post("/remove-background")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  async def remove_background(file: UploadFile = File(...)):
44
  contents = await file.read()
45
- pil_image = Image.open(io.BytesIO(contents))
46
-
47
- if pil_image.mode != 'RGB':
48
- pil_image = pil_image.convert('RGB')
49
- img_array = np.array(pil_image)
50
 
51
- result_array = process_image(img_array)
 
 
 
 
 
52
 
53
- result_image = Image.fromarray(result_array)
54
- img_byte_arr = io.BytesIO()
55
- result_image.save(img_byte_arr, format='PNG')
56
- img_byte_arr = img_byte_arr.getvalue()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
- return Response(content=img_byte_arr, media_type="image/png")
 
59
 
60
  if __name__ == "__main__":
61
  providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
62
  model_path = huggingface_hub.hf_hub_download("skytnt/anime-seg", "isnetis.onnx")
63
  rmbg_model = rt.InferenceSession(model_path, providers=providers)
64
 
65
- gradio_app = gr.Blocks()
66
- with gradio_app:
67
- gr.Markdown("# Anime Remove Background\n\n"
68
- "![visitor badge](https://api.visitorbadge.io/api/visitors?path=skytnt.animeseg&countColor=%23263759&style=flat&labelStyle=lower)\n\n"
69
- "demo for [https://github.com/SkyTNT/anime-segmentation/](https://github.com/SkyTNT/anime-segmentation/)")
70
- with gr.Column():
71
- input_img = gr.Image(label="input image")
72
- examples_data = [[f"examples/{x:02d}.jpg"] for x in range(1, 4)]
73
- examples = gr.Examples(examples=examples_data, inputs=[input_img])
74
- run_btn = gr.Button(variant="primary")
75
- with gr.Row():
76
- output_mask = gr.Image(label="mask", format="png")
77
- output_img = gr.Image(label="result", image_mode="RGBA", format="png")
78
- run_btn.click(rmbg_fn, [input_img], [output_mask, output_img])
79
-
80
- gradio_app = gr.mount_gradio_app(app, gradio_app, path="/")
81
 
 
82
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
+ # app.py
2
  import gradio as gr
3
  import huggingface_hub
4
  import onnxruntime as rt
5
  import numpy as np
6
  import cv2
7
+ from fastapi import FastAPI, File, UploadFile, HTTPException
 
 
8
  from fastapi.responses import Response
9
+ import io
10
+ from PIL import Image
11
+ import imghdr
12
+ from typing import Optional
13
 
14
+ SUPPORTED_FORMATS = {'jpg', 'jpeg', 'png', 'bmp', 'webp', 'tiff'}
15
+
16
+ def is_valid_image(file_content: bytes) -> Optional[str]:
17
+ image_format = imghdr.what(None, file_content)
18
+ if image_format is None:
19
+ return None
20
+ return image_format.lower()
21
+
22
+ def process_image_bytes(image_bytes: bytes) -> np.ndarray:
23
+ try:
24
+ image = Image.open(io.BytesIO(image_bytes))
25
+
26
+ if image.mode in ('RGBA', 'LA'):
27
+ background = Image.new('RGB', image.size, (255, 255, 255))
28
+ background.paste(image, mask=image.getchannel('A'))
29
+ image = background
30
+
31
+ if image.mode != 'RGB':
32
+ image = image.convert('RGB')
33
+
34
+ img_array = np.array(image)
35
+
36
+ if len(img_array.shape) == 3:
37
+ img_array = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
38
+
39
+ return img_array
40
+ except Exception as e:
41
+ raise HTTPException(status_code=400, detail=f"Error: {str(e)}")
42
 
43
  def get_mask(img, s=1024):
44
  img = (img / 255).astype(np.float32)
 
55
  mask = cv2.resize(mask, (w0, h0))[:, :, np.newaxis]
56
  return mask
57
 
 
 
 
 
 
58
  def rmbg_fn(img):
59
  mask = get_mask(img)
60
  img = (mask * img + 255 * (1 - mask)).astype(np.uint8)
 
63
  mask = mask.repeat(3, axis=2)
64
  return mask, img
65
 
66
+ app = FastAPI()
67
+
68
+ gradio_app = gr.Blocks()
69
+ with gradio_app:
70
+ gr.Markdown("# Anime Remove Background\n\n"
71
+ "![visitor badge](https://api.visitorbadge.io/api/visitors?path=skytnt.animeseg&countColor=%23263759&style=flat&labelStyle=lower)\n\n"
72
+ "demo for [https://github.com/SkyTNT/anime-segmentation/](https://github.com/SkyTNT/anime-segmentation/)")
73
+ with gr.Column():
74
+ input_img = gr.Image(label="input image")
75
+ examples_data = [[f"examples/{x:02d}.jpg"] for x in range(1, 4)]
76
+ examples = gr.Examples(examples=examples_data, inputs=[input_img])
77
+ run_btn = gr.Button(variant="primary")
78
+ with gr.Row():
79
+ output_mask = gr.Image(label="mask", format="png")
80
+ output_img = gr.Image(label="result", image_mode="RGBA", format="png")
81
+ run_btn.click(rmbg_fn, [input_img], [output_mask, output_img])
82
+
83
+ @app.post("/remove-bg")
84
  async def remove_background(file: UploadFile = File(...)):
85
  contents = await file.read()
 
 
 
 
 
86
 
87
+ image_format = is_valid_image(contents)
88
+ if not image_format or image_format not in SUPPORTED_FORMATS:
89
+ raise HTTPException(
90
+ status_code=400,
91
+ detail=f"Invalid format: {', '.join(SUPPORTED_FORMATS)}"
92
+ )
93
 
94
+ try:
95
+ img = process_image_bytes(contents)
96
+
97
+ mask = get_mask(img)
98
+ result = (mask * img).astype(np.uint8)
99
+
100
+ rgba = np.concatenate([result, (mask * 255).astype(np.uint8)], axis=2)
101
+
102
+ pil_image = Image.fromarray(rgba)
103
+ img_byte_arr = io.BytesIO()
104
+ pil_image.save(img_byte_arr, format='PNG', optimize=True)
105
+ img_byte_arr = img_byte_arr.getvalue()
106
+
107
+ return Response(
108
+ content=img_byte_arr,
109
+ media_type="image/png",
110
+ headers={
111
+ "Content-Disposition": "attachment; filename=result.png"
112
+ }
113
+ )
114
 
115
+ except Exception as e:
116
+ raise HTTPException(status_code=500, detail=f"Error: {str(e)}")
117
 
118
  if __name__ == "__main__":
119
  providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
120
  model_path = huggingface_hub.hf_hub_download("skytnt/anime-seg", "isnetis.onnx")
121
  rmbg_model = rt.InferenceSession(model_path, providers=providers)
122
 
123
+ app = gr.mount_gradio_app(app, gradio_app, path="/")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
 
125
+ import uvicorn
126
  uvicorn.run(app, host="0.0.0.0", port=7860)