wondervictor commited on
Commit
cddba21
·
verified ·
1 Parent(s): 578b68a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -5
app.py CHANGED
@@ -27,7 +27,7 @@ import numpy as np
27
  import sys
28
  import tqdm
29
 
30
- version = "YxZhang/evf-sam2"
31
  model_type = "sam2"
32
 
33
  tokenizer = AutoTokenizer.from_pretrained(
@@ -58,7 +58,7 @@ video_model.to('cuda')
58
 
59
  @spaces.GPU
60
  @torch.no_grad()
61
- def inference_image(image_np, prompt):
62
  original_size_list = [image_np.shape[:2]]
63
 
64
  image_beit = beit3_preprocess(image_np, 224).to(dtype=image_model.dtype,
@@ -68,6 +68,8 @@ def inference_image(image_np, prompt):
68
  image_sam = image_sam.to(dtype=image_model.dtype,
69
  device=image_model.device)
70
 
 
 
71
  input_ids = tokenizer(
72
  prompt, return_tensors="pt")["input_ids"].to(device=image_model.device)
73
 
@@ -93,7 +95,7 @@ def inference_image(image_np, prompt):
93
  @spaces.GPU
94
  @torch.no_grad()
95
  @torch.autocast(device_type="cuda", dtype=torch.float16)
96
- def inference_video(video_path, prompt):
97
 
98
  os.system("rm -rf demo_temp")
99
  os.makedirs("demo_temp/input_frames", exist_ok=True)
@@ -109,6 +111,8 @@ def inference_video(video_path, prompt):
109
  image_beit = beit3_preprocess(image_np, 224).to(dtype=video_model.dtype,
110
  device=video_model.device)
111
 
 
 
112
  input_ids = tokenizer(
113
  prompt, return_tensors="pt")["input_ids"].to(device=video_model.device)
114
 
@@ -162,6 +166,12 @@ with gr.Blocks() as demo:
162
  submit_image = gr.Button(value='Submit',
163
  scale=1,
164
  variant='primary')
 
 
 
 
 
 
165
  with gr.Tab(label="EVF-SAM-2-Video"):
166
  with gr.Row():
167
  input_video = gr.Video(label='Input Video')
@@ -175,11 +185,17 @@ with gr.Blocks() as demo:
175
  submit_video = gr.Button(value='Submit',
176
  scale=1,
177
  variant='primary')
 
 
 
 
 
 
178
 
179
  submit_image.click(fn=inference_image,
180
- inputs=[input_image, image_prompt],
181
  outputs=output_image)
182
  submit_video.click(fn=inference_video,
183
- inputs=[input_video, video_prompt],
184
  outputs=output_video)
185
  demo.launch(show_error=True)
 
27
  import sys
28
  import tqdm
29
 
30
+ version = "YxZhang/evf-sam2-multitask"
31
  model_type = "sam2"
32
 
33
  tokenizer = AutoTokenizer.from_pretrained(
 
58
 
59
  @spaces.GPU
60
  @torch.no_grad()
61
+ def inference_image(image_np, prompt, semantic_type):
62
  original_size_list = [image_np.shape[:2]]
63
 
64
  image_beit = beit3_preprocess(image_np, 224).to(dtype=image_model.dtype,
 
68
  image_sam = image_sam.to(dtype=image_model.dtype,
69
  device=image_model.device)
70
 
71
+ if semantic_type:
72
+ prompt = "[semantic] " + prompt
73
  input_ids = tokenizer(
74
  prompt, return_tensors="pt")["input_ids"].to(device=image_model.device)
75
 
 
95
  @spaces.GPU
96
  @torch.no_grad()
97
  @torch.autocast(device_type="cuda", dtype=torch.float16)
98
+ def inference_video(video_path, prompt, semantic_type):
99
 
100
  os.system("rm -rf demo_temp")
101
  os.makedirs("demo_temp/input_frames", exist_ok=True)
 
111
  image_beit = beit3_preprocess(image_np, 224).to(dtype=video_model.dtype,
112
  device=video_model.device)
113
 
114
+ if semantic_type:
115
+ prompt = "[semantic] " + prompt
116
  input_ids = tokenizer(
117
  prompt, return_tensors="pt")["input_ids"].to(device=video_model.device)
118
 
 
166
  submit_image = gr.Button(value='Submit',
167
  scale=1,
168
  variant='primary')
169
+ with gr.Row():
170
+ semantic_type_img = gr.Checkbox(
171
+ False,
172
+ label="semantic level",
173
+ info="check this if you want to segment body parts or background or multi objects (only available with latest evf-sam checkpoint)"
174
+ )
175
  with gr.Tab(label="EVF-SAM-2-Video"):
176
  with gr.Row():
177
  input_video = gr.Video(label='Input Video')
 
185
  submit_video = gr.Button(value='Submit',
186
  scale=1,
187
  variant='primary')
188
+ with gr.Row():
189
+ semantic_type_vid = gr.Checkbox(
190
+ False,
191
+ label="semantic level",
192
+ info="check this if you want to segment body parts or background or multi objects (only available with latest evf-sam checkpoint)"
193
+ )
194
 
195
  submit_image.click(fn=inference_image,
196
+ inputs=[input_image, image_prompt, semantic_type_img],
197
  outputs=output_image)
198
  submit_video.click(fn=inference_video,
199
+ inputs=[input_video, video_prompt, semantic_type_vid],
200
  outputs=output_video)
201
  demo.launch(show_error=True)