oshita-n committed on
Commit baf1626
1 Parent(s): dc972bb
Files changed (2)
  1. app.py +14 -9
  2. requirements.txt +2 -1
app.py CHANGED
@@ -1,23 +1,28 @@
 import gradio as gr
-from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
+from lavis.models import load_model_and_preprocess
 import torch
 from PIL import Image
 import numpy as np

 def process(input_image, prompt):
-    inputs = processor(text=prompt, images=input_image, padding="max_length", return_tensors="pt")
-    # predict
-    with torch.no_grad():
-        outputs = model(**inputs)
-    preds = torch.sigmoid(outputs.logits).squeeze().detach().cpu().numpy()
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+    model, vis_processors, txt_processors = load_model_and_preprocess(name="blip_feature_extractor", model_type="base", is_eval=True, device=device)
+
+    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter
+    input_image = input_image.resize((256, 256), Image.LANCZOS)
+    image = vis_processors["eval"](input_image).unsqueeze(0).to(device)
+    text_input = txt_processors["eval"](prompt)
+    sample = {"image": image, "text_input": [text_input]}
+
+    features_multimodal = model.extract_features(sample, mode="multimodal")
+    preds = features_multimodal.multimodal_embeds.squeeze().detach().cpu().numpy()
     preds = np.where(preds > 0.3, 255, 0).astype(np.uint8)
     preds = Image.fromarray(preds.astype(np.uint8))
     preds = np.array(preds.resize((input_image.width, input_image.height)))
     return preds

 if __name__ == '__main__':
-    processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
-    model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
     input_image = gr.inputs.Image(label='image', type='pil')
     prompt = gr.Textbox(label='Prompt')
     ips = [
@@ -31,4 +36,4 @@ if __name__ == '__main__':
     outputs=outputs,
     input_size=input_size,
     output_size=output_size)
-    iface.launch()
+    iface.launch()
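For reference, a minimal standalone sketch of the LAVIS feature-extractor API that app.py now calls (the image path and prompt below are placeholders; shapes follow the LAVIS documentation for the base BLIP model):

import torch
from PIL import Image
from lavis.models import load_model_and_preprocess

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# downloads the BLIP feature-extractor weights on first use
model, vis_processors, txt_processors = load_model_and_preprocess(
    name="blip_feature_extractor", model_type="base", is_eval=True, device=device)

raw_image = Image.open("example.jpg").convert("RGB")  # placeholder image
image = vis_processors["eval"](raw_image).unsqueeze(0).to(device)
text_input = txt_processors["eval"]("a photo of a dog")  # placeholder prompt

sample = {"image": image, "text_input": [text_input]}
features = model.extract_features(sample, mode="multimodal")

# multimodal_embeds is (batch, num_text_tokens, 768), so the squeeze() in
# process() yields a token-by-dimension array, not an image-shaped mask
print(features.multimodal_embeds.shape)

Note that the new process() loads the model on every request; hoisting the load_model_and_preprocess call into the __main__ block, as the CLIPSeg version did with its processor and model, would avoid reloading per call.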
requirements.txt CHANGED
@@ -1,4 +1,5 @@
 gradio
 transformers
 torch
-pillow
+pillow
+salesforce-lavis==1.0.2
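After installing the updated requirements, a quick sanity check (a suggestion, not part of the commit) confirms that the pinned salesforce-lavis release exposes the entry point app.py now relies on:

from lavis.models import load_model_and_preprocess, model_zoo
print(model_zoo)  # table of available models; should include blip_feature_extractor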