CSB261 committed on
Commit
acf5692
•
1 Parent(s): cda3f0b

Update app.py

Files changed (1)
  1. app.py +11 -15
app.py CHANGED
@@ -1,10 +1,10 @@
 import gradio as gr
 from gradio_imageslider import ImageSlider
-from loadimg import load_img
 import spaces
 from transformers import AutoModelForImageSegmentation
 import torch
 from torchvision import transforms
+from PIL import Image
 
 torch.set_float32_matmul_precision(["high", "highest"][0])
 
@@ -25,20 +25,18 @@ transform_image = transforms.Compose(
 def fn(image):
     if image is None or len(image) == 0:
         return image, None  # also return the original image
-    im = load_img(image, output_type="pil")
-    im = im.convert("RGB")
+    im = Image.open(image).convert("RGB")
     image_size = im.size
     origin = im.copy()
-    image = load_img(im)
-    input_images = transform_image(image).unsqueeze(0).to("cuda")
+    input_images = transform_image(im).unsqueeze(0).to("cuda")
     # Prediction
     with torch.no_grad():
         preds = birefnet(input_images)[-1].sigmoid().cpu()
     pred = preds[0].squeeze()
     pred_pil = transforms.ToPILImage()(pred)
     mask = pred_pil.resize(image_size)
-    image.putalpha(mask)
-    return image, origin  # return the converted image and the original image
+    im.putalpha(mask)
+    return im, origin  # return the converted image and the original image
 
 
 def save_image(image):
@@ -56,23 +54,21 @@ with gr.Blocks() as demo:
     slider1 = ImageSlider(label="birefnet", type="pil")
     slider2 = ImageSlider(label="birefnet", type="pil")
 
-    chameleon = load_img("butterfly.jpg", output_type="pil")
-    example_image2 = load_img("example2.jpg", output_type="pil")  # second example image
-    example_image3 = load_img("example3.jpg", output_type="pil")  # third example image
-    url1 = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"
-    url2 = "https://example.com/example2.jpg"  # second example URL
-    url3 = "https://example.com/example3.jpg"  # third example URL
+    # example image file paths included in the Space
+    example_image1 = "example_images/example1.jpg"
+    example_image2 = "example_images/example2.jpg"
+    example_image3 = "example_images/example3.jpg"
 
     with gr.Tab("Image Upload"):
         tab1 = gr.Interface(
             fn, inputs=image, outputs=[slider1, output_file],
-            examples=[chameleon, example_image2, example_image3], api_name="image"
+            examples=[example_image1, example_image2, example_image3], api_name="image"
         )
 
     with gr.Tab("Image URL"):
         tab2 = gr.Interface(
             fn, inputs=text, outputs=[slider2, output_file],
-            examples=[url1, url2, url3], api_name="text"
+            examples=[example_image1, example_image2, example_image3], api_name="text"
        )
 
     def process_download(image):
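
For anyone trying the reworked loading path outside the Space, below is a minimal, hedged sketch of what the updated fn() now does: open a local file with PIL instead of loadimg, run BiRefNet, and write the predicted mask into the alpha channel. The checkpoint id "ZhengPeng7/BiRefNet", the Resize/Normalize values in transform_image, and the availability of a CUDA device are assumptions; the diff only shows the model variable name birefnet and the "transform_image = transforms.Compose(" hunk header.

# Standalone sketch of the updated background-removal path (assumptions noted above).
import torch
from PIL import Image
from torchvision import transforms
from transformers import AutoModelForImageSegmentation

# Assumed checkpoint; the commit only shows that the model object is named `birefnet`.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet", trust_remote_code=True
).eval().to("cuda")

# Assumed preprocessing; the diff only shows the start of the Compose pipeline.
transform_image = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

def remove_background(path):
    im = Image.open(path).convert("RGB")          # same loading style as the new fn()
    origin = im.copy()
    inputs = transform_image(im).unsqueeze(0).to("cuda")
    with torch.no_grad():
        preds = birefnet(inputs)[-1].sigmoid().cpu()
    mask = transforms.ToPILImage()(preds[0].squeeze()).resize(im.size)
    im.putalpha(mask)                             # transparent background from the predicted mask
    return im, origin

# Usage: processed, original = remove_background("example_images/example1.jpg")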