SonnySW commited on
Commit
1c28d46
1 Parent(s): 6ed96fd

2024.07.04 update app.py

Browse files
Files changed (2) hide show
  1. .ipynb_checkpoints/app-checkpoint.py +82 -0
  2. app.py +1 -1
.ipynb_checkpoints/app-checkpoint.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 객체검출 -> 삭제 체크박스 적용본
2
+ import torch
3
+ from PIL import Image, ImageDraw
4
+ from transformers import DetrImageProcessor, DetrForObjectDetection
5
+ from diffusers import StableDiffusionInpaintPipeline
6
+ import gradio as gr
7
+
8
+ # 모델 로드
9
+ processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
10
+ model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
11
+ pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16)
12
+ pipe = pipe.to("cpu")
13
+
14
+ def detect_objects(image):
15
+ # 객체 검출
16
+ inputs = processor(images=image, return_tensors="pt")
17
+ outputs = model(**inputs)
18
+
19
+ # 결과 후처리```````
20
+ target_sizes = torch.tensor([image.size[::-1]])
21
+ results = processor.post_process_object_detection(outputs, target_sizes=target_sizes)[0]
22
+
23
+ # 검출된 객체 정보 추출
24
+ detected_objects = []
25
+ for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
26
+ if score > 0.9:
27
+ box = [round(i) for i in box.tolist()]
28
+ detected_objects.append({"label": model.config.id2label[label.item()], "box": box})
29
+
30
+ return detected_objects
31
+
32
+ def display_detected_objects(image):
33
+ detected_objects = detect_objects(image)
34
+ labeled_image = image.copy()
35
+ draw = ImageDraw.Draw(labeled_image)
36
+ object_labels = []
37
+ for obj in detected_objects:
38
+ box = obj["box"]
39
+ label = obj["label"]
40
+ draw.rectangle(box, outline="red", width=3)
41
+ draw.text((box[0], box[1]), label, fill="red")
42
+ object_labels.append(f"{label} at {box}")
43
+ return labeled_image, gr.update(choices=object_labels)
44
+
45
+ def inpaint_image(image, selected_objects):
46
+ detected_objects = detect_objects(image)
47
+
48
+ # 마스크 생성
49
+ mask = Image.new("L", image.size, 0)
50
+ draw = ImageDraw.Draw(mask)
51
+ for obj in detected_objects:
52
+ object_label = f"{obj['label']} at {obj['box']}"
53
+ if object_label in selected_objects:
54
+ box = obj["box"]
55
+ draw.rectangle(box, fill=255)
56
+
57
+ # Inpainting 수행
58
+ image = image.convert("RGB")
59
+ mask = mask.convert("RGB")
60
+ output = pipe(prompt="a modern interior", image=image, mask_image=mask).images[0]
61
+ # output = pipe(prompt="remove", image=image, mask_image=mask).images[0]
62
+
63
+
64
+ return output
65
+
66
+ # Gradio 인터페이스 설정
67
+ with gr.Blocks() as interface:
68
+ with gr.Row():
69
+ image_input = gr.Image(type="pil", label="Input Image")
70
+ objects_list = gr.CheckboxGroup(label="Detected Objects")
71
+
72
+ labeled_image_output = gr.Image(label="Labeled Image")
73
+ final_output = gr.Image(label="Output Image")
74
+
75
+ detect_button = gr.Button("Detect Objects")
76
+ inpaint_button = gr.Button("Remove Selected Objects")
77
+
78
+ detect_button.click(fn=display_detected_objects, inputs=image_input, outputs=[labeled_image_output, objects_list])
79
+ inpaint_button.click(fn=inpaint_image, inputs=[image_input, objects_list], outputs=final_output)
80
+
81
+ # Gradio 인터페이스 실행
82
+ interface.launch()
app.py CHANGED
@@ -8,7 +8,7 @@ import gradio as gr
8
  # 모델 로드
9
  processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
10
  model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
11
- pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting", torch_dtype=torch.float32)
12
  pipe = pipe.to("cpu")
13
 
14
  def detect_objects(image):
 
8
  # 모델 로드
9
  processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
10
  model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
11
+ pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16)
12
  pipe = pipe.to("cpu")
13
 
14
  def detect_objects(image):