Kims12 commited on
Commit
2b13f4e
·
verified ·
1 Parent(s): 3145448

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -15
app.py CHANGED
@@ -8,12 +8,13 @@ from torchvision import transforms
8
  from PIL import Image
9
  import os
10
 
11
- # GPU 설정을 CPU로 변경
12
  birefnet = AutoModelForImageSegmentation.from_pretrained(
13
  "ZhengPeng7/BiRefNet", trust_remote_code=True
14
  )
15
  birefnet.to("cpu") # GPU -> CPU로 변경
16
 
 
17
  transform_image = transforms.Compose(
18
  [
19
  transforms.Resize((1024, 1024)),
@@ -22,6 +23,18 @@ transform_image = transforms.Compose(
22
  ]
23
  )
24
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  def fn(image):
26
  im = load_img(image, output_type="pil")
27
  im = im.convert("RGB")
@@ -34,19 +47,7 @@ def fn(image):
34
  jpg_path = "output.jpg"
35
  jpg_image.save(jpg_path, format="JPEG")
36
 
37
- return processed_image, jpg_path
38
-
39
- def process(image):
40
- image_size = image.size
41
- input_images = transform_image(image).unsqueeze(0).to("cpu") # GPU -> CPU로 변경
42
- # Prediction
43
- with torch.no_grad():
44
- preds = birefnet(input_images)[-1].sigmoid().cpu()
45
- pred = preds[0].squeeze()
46
- pred_pil = transforms.ToPILImage()(pred)
47
- mask = pred_pil.resize(image_size)
48
- image.putalpha(mask)
49
- return image
50
 
51
  def process_file(f):
52
  name_path = f.rsplit(".", 1)[0] + ".png"
@@ -56,13 +57,15 @@ def process_file(f):
56
  transparent.save(name_path)
57
  return name_path
58
 
 
59
  slider1 = ImageSlider(label="Processed Image", type="pil")
60
  image_upload = gr.Image(label="Upload an image")
61
  output_download = gr.File(label="Download JPG File")
62
 
63
- # 새로운 샘플 이미지 추가
64
  sample_images = ["1.png", "2.jpg", "3.png"]
65
 
 
66
  tab1 = gr.Interface(
67
  fn=fn,
68
  inputs=image_upload,
 
8
  from PIL import Image
9
  import os
10
 
11
+ # 모델 로드 CPU로 설정
12
  birefnet = AutoModelForImageSegmentation.from_pretrained(
13
  "ZhengPeng7/BiRefNet", trust_remote_code=True
14
  )
15
  birefnet.to("cpu") # GPU -> CPU로 변경
16
 
17
+ # 이미지 전처리
18
  transform_image = transforms.Compose(
19
  [
20
  transforms.Resize((1024, 1024)),
 
23
  ]
24
  )
25
 
26
+ def process(image):
27
+ image_size = image.size
28
+ input_images = transform_image(image).unsqueeze(0).to("cpu") # CPU로 변경
29
+ # 예측 수행
30
+ with torch.no_grad():
31
+ preds = birefnet(input_images)[-1].sigmoid().cpu()
32
+ pred = preds[0].squeeze()
33
+ pred_pil = transforms.ToPILImage()(pred)
34
+ mask = pred_pil.resize(image_size)
35
+ image.putalpha(mask)
36
+ return image
37
+
38
  def fn(image):
39
  im = load_img(image, output_type="pil")
40
  im = im.convert("RGB")
 
47
  jpg_path = "output.jpg"
48
  jpg_image.save(jpg_path, format="JPEG")
49
 
50
+ return [processed_image], jpg_path # ImageSlider는 리스트를 기대함
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
  def process_file(f):
53
  name_path = f.rsplit(".", 1)[0] + ".png"
 
57
  transparent.save(name_path)
58
  return name_path
59
 
60
+ # Gradio 컴포넌트 정의
61
  slider1 = ImageSlider(label="Processed Image", type="pil")
62
  image_upload = gr.Image(label="Upload an image")
63
  output_download = gr.File(label="Download JPG File")
64
 
65
+ # 새로운 샘플 이미지 추가 (app.py와 동일한 폴더에 위치해야 함)
66
  sample_images = ["1.png", "2.jpg", "3.png"]
67
 
68
+ # Gradio 인터페이스 설정
69
  tab1 = gr.Interface(
70
  fn=fn,
71
  inputs=image_upload,