leonelhs commited on
Commit
2012f74
1 Parent(s): 8d8bae6
Files changed (3) hide show
  1. .gitignore +1 -0
  2. app.py +63 -0
  3. requirements.txt +4 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ .idea/
app.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from carvekit.api.interface import Interface
4
+ from carvekit.ml.wrap.fba_matting import FBAMatting
5
+ from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7
6
+ from carvekit.pipelines.postprocessing import MattingMethod
7
+ from carvekit.pipelines.preprocessing import PreprocessingStub
8
+ from carvekit.trimap.generator import TrimapGenerator
9
+
10
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
11
+
12
+ # Check doc strings for more information
13
+ seg_net = TracerUniversalB7(device=device, batch_size=1)
14
+
15
+ fba = FBAMatting(device=device,
16
+ input_tensor_size=2048,
17
+ batch_size=1)
18
+
19
+ trimap = TrimapGenerator()
20
+
21
+ preprocessing = PreprocessingStub()
22
+
23
+ postprocessing = MattingMethod(matting_module=fba,
24
+ trimap_generator=trimap,
25
+ device=device)
26
+
27
+ interface = Interface(pre_pipe=preprocessing,
28
+ post_pipe=postprocessing,
29
+ seg_pipe=seg_net)
30
+
31
+
32
+ def predict(image):
33
+ return interface([image])[0]
34
+
35
+
36
+ footer = r"""
37
+ <center>
38
+ <b>
39
+ Demo based on <a href='https://github.com/OPHoperHPO/image-background-remove-tool'>CarveKit</a>
40
+ </b>
41
+ </center>
42
+ """
43
+
44
+ with gr.Blocks(title="CarveKit") as app:
45
+ gr.HTML("<center><h1>Image Remove Background</h1></center>")
46
+ with gr.Row():
47
+ with gr.Column():
48
+ input_img = gr.Image(type="pil", label="Input image")
49
+ run_btn = gr.Button(variant="primary")
50
+ with gr.Column():
51
+ output_img = gr.Image(type="pil", label="result")
52
+
53
+ run_btn.click(predict, [input_img], [output_img])
54
+
55
+ with gr.Row():
56
+ examples_data = [[f"examples/{x:02d}.jpg"] for x in range(1, 4)]
57
+ examples = gr.Dataset(components=[input_img], samples=examples_data)
58
+ examples.click(lambda x: x[0], [examples], [input_img])
59
+
60
+ with gr.Row():
61
+ gr.HTML(footer)
62
+
63
+ app.launch(share=False, debug=True, enable_queue=True, show_error=True)
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ pillow~=9.5.0
2
+ torch>=2.0.1
3
+ gradio~=3.35.2
4
+ carvekit~=4.1.0