sigyllly commited on
Commit
1cce0ac
·
verified ·
1 Parent(s): a664879

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -98
app.py CHANGED
@@ -3,96 +3,18 @@ import gradio as gr
3
  from PIL import Image
4
  import torch
5
  import numpy as np
6
- from flask import Flask, request, jsonify, send_file
7
- from io import BytesIO
8
  import threading
9
 
10
  processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
11
  model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
12
 
13
- app = Flask(__name__)
14
-
15
-
16
- # Define article as a global variable
17
- title = "Interactive demo: zero-shot image segmentation with CLIPSeg"
18
- description = "Demo for using CLIPSeg, a CLIP-based model for zero- and one-shot image segmentation. To use it, simply upload an image and add a text to mask (identify in the image), or use one of the examples below and click 'submit'. Results will show up in a few seconds."
19
- article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2112.10003'>CLIPSeg: Image Segmentation Using Text and Image Prompts</a> | <a href='https://huggingface.co/docs/transformers/main/en/model_doc/clipseg'>HuggingFace docs</a></p>"
20
-
21
- def process_image(image, prompt):
22
- inputs = processor(
23
- text=prompt, images=image, padding="max_length", return_tensors="pt"
24
- )
25
-
26
- with torch.no_grad():
27
- outputs = model(**inputs)
28
- preds = outputs.logits
29
-
30
- pred = torch.sigmoid(preds)
31
- mat = pred.cpu().numpy()
32
- mask = Image.fromarray(np.uint8(mat * 255), "L")
33
- mask = mask.convert("RGB")
34
- mask = mask.resize(image.size)
35
- mask = np.array(mask)[:, :, 0]
36
-
37
- mask_min = mask.min()
38
- mask_max = mask.max()
39
- mask = (mask - mask_min) / (mask_max - mask_min)
40
- return mask
41
-
42
- def get_masks(prompts, img, threshold):
43
- prompts = prompts.split(",")
44
- masks = []
45
- for prompt in prompts:
46
- mask = process_image(img, prompt)
47
- mask = mask > threshold
48
- masks.append(mask)
49
- return masks
50
-
51
- def extract_image(pos_prompts, neg_prompts, img, threshold):
52
- positive_masks = get_masks(pos_prompts, img, 0.5)
53
- negative_masks = get_masks(neg_prompts, img, 0.5)
54
-
55
- pos_mask = np.any(np.stack(positive_masks), axis=0)
56
- neg_mask = np.any(np.stack(negative_masks), axis=0)
57
- final_mask = pos_mask & ~neg_mask
58
-
59
- final_mask = Image.fromarray(final_mask.astype(np.uint8) * 255, "L")
60
- output_image = Image.new("RGBA", img.size, (0, 0, 0, 0))
61
- output_image.paste(img, mask=final_mask)
62
- return output_image, final_mask
63
-
64
-
65
- @app.route('/api', methods=['POST'])
66
- def api():
67
- data = request.form
68
- img_url = data['input_image']
69
- positive_prompts = data['positive_prompts']
70
- negative_prompts = data['negative_prompts']
71
- threshold = float(data['input_slider_T'])
72
-
73
- # Download image from URL
74
- response = requests.get(img_url)
75
- img = Image.open(BytesIO(response.content))
76
-
77
- # Process image
78
- masks = get_masks(positive_prompts, negative_prompts, img, threshold)
79
- final_mask = np.any(np.stack(masks), axis=0)
80
-
81
- # Convert mask to image
82
- final_mask = Image.fromarray(final_mask.astype(np.uint8) * 255, "L")
83
-
84
- # Convert the final image to bytes
85
- img_bytes = BytesIO()
86
- final_mask.save(img_bytes, format='PNG')
87
- img_bytes.seek(0)
88
-
89
- return send_file(img_bytes, mimetype='image/png')
90
-
91
  # Gradio UI
92
  with gr.Blocks() as demo:
93
  gr.Markdown("# CLIPSeg: Image Segmentation Using Text and Image Prompts")
94
- gr.Markdown(article)
95
- gr.Markdown(description)
 
 
96
 
97
  with gr.Row():
98
  with gr.Column():
@@ -113,6 +35,49 @@ with gr.Blocks() as demo:
113
  output_image = gr.Image(label="Result")
114
  output_mask = gr.Image(label="Mask")
115
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  btn_process.click(
117
  extract_image,
118
  inputs=[
@@ -124,19 +89,5 @@ with gr.Blocks() as demo:
124
  outputs=[output_image, output_mask],
125
  )
126
 
127
- def run_demo():
128
- demo.launch()
129
-
130
- def run_flask():
131
- app.run(host='127.0.0.1', port=8080)
132
-
133
- if __name__ == '__main__':
134
- # Run Gradio UI and Flask in separate threads
135
- gr_thread = threading.Thread(target=run_demo)
136
- flask_thread = threading.Thread(target=run_flask)
137
-
138
- gr_thread.start()
139
- flask_thread.start()
140
-
141
- gr_thread.join()
142
- flask_thread.join()
 
3
  from PIL import Image
4
  import torch
5
  import numpy as np
 
 
6
  import threading
7
 
8
  processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
9
  model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  # Gradio UI
12
  with gr.Blocks() as demo:
13
  gr.Markdown("# CLIPSeg: Image Segmentation Using Text and Image Prompts")
14
+
15
+ # Add your article and description here
16
+ gr.Markdown("Your article goes here")
17
+ gr.Markdown("Your description goes here")
18
 
19
  with gr.Row():
20
  with gr.Column():
 
35
  output_image = gr.Image(label="Result")
36
  output_mask = gr.Image(label="Mask")
37
 
38
+ def process_image(image, prompt):
39
+ inputs = processor(
40
+ text=prompt, images=image, padding="max_length", return_tensors="pt"
41
+ )
42
+
43
+ with torch.no_grad():
44
+ outputs = model(**inputs)
45
+ preds = outputs.logits
46
+
47
+ pred = torch.sigmoid(preds)
48
+ mat = pred.cpu().numpy()
49
+ mask = Image.fromarray(np.uint8(mat * 255), "L")
50
+ mask = mask.convert("RGB")
51
+ mask = mask.resize(image.size)
52
+ mask = np.array(mask)[:, :, 0]
53
+
54
+ mask_min = mask.min()
55
+ mask_max = mask.max()
56
+ mask = (mask - mask_min) / (mask_max - mask_min)
57
+ return mask
58
+
59
+ def get_masks(prompts, img, threshold):
60
+ prompts = prompts.split(",")
61
+ masks = []
62
+ for prompt in prompts:
63
+ mask = process_image(img, prompt)
64
+ mask = mask > threshold
65
+ masks.append(mask)
66
+ return masks
67
+
68
+ def extract_image(pos_prompts, neg_prompts, img, threshold):
69
+ positive_masks = get_masks(pos_prompts, img, 0.5)
70
+ negative_masks = get_masks(neg_prompts, img, 0.5)
71
+
72
+ pos_mask = np.any(np.stack(positive_masks), axis=0)
73
+ neg_mask = np.any(np.stack(negative_masks), axis=0)
74
+ final_mask = pos_mask & ~neg_mask
75
+
76
+ final_mask = Image.fromarray(final_mask.astype(np.uint8) * 255, "L")
77
+ output_image = Image.new("RGBA", img.size, (0, 0, 0, 0))
78
+ output_image.paste(img, mask=final_mask)
79
+ return output_image, final_mask
80
+
81
  btn_process.click(
82
  extract_image,
83
  inputs=[
 
89
  outputs=[output_image, output_mask],
90
  )
91
 
92
+ # Launch Gradio API
93
+ demo.launch(share=True)