Mike Afton committed on
Commit
17aaf2d
1 Parent(s): be1d3f6

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +161 -0
app.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ os.system('cls||clear')
4
+
5
+ from diffusers import AutoPipelineForInpainting
6
+ from transformers import pipeline
7
+ from ultralytics import YOLO
8
+ from PIL import Image
9
+ import numpy as np
10
+ import torch
11
+ import base64
12
+ from io import BytesIO
13
+ import gradio as gr
14
+ from gradio import components
15
+ import difflib
16
+
17
+ # Constants
18
# Prefer the GPU when torch can see one; otherwise run on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
19
+
20
+ # Load
21
+
22
def image_to_base64(image: "Image.Image"):
    """Encode an image as a base64 string of its JPEG bytes.

    Args:
        image: The image to encode. It must expose ``mode``, ``convert`` and
            ``save`` the way a PIL image does.

    Returns:
        The JPEG-encoded image as a UTF-8 base64 string.
    """
    # JPEG has no alpha channel or palette, so saving e.g. an RGBA or P image
    # raises. Modes the JPEG writer accepts are left untouched.
    if image.mode not in ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr"):
        image = image.convert("RGB")

    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")
26
+
27
def get_most_similar_string(target_string, string_array):
    """Return the candidate that most closely resembles ``target_string``.

    Similarity is ``difflib.SequenceMatcher(None, target, candidate).ratio()``.

    Args:
        target_string: The string to look for.
        string_array: Non-empty sequence of candidate strings.

    Returns:
        The candidate with the highest similarity ratio. When several
        candidates tie (including when every ratio is 0), the earliest one
        in ``string_array`` is returned.

    Raises:
        IndexError: If ``string_array`` is empty.
    """
    if not string_array:
        raise IndexError("string_array is empty; there is nothing to match against")

    def similarity(candidate):
        return difflib.SequenceMatcher(None, target_string, candidate).ratio()

    # max() returns the first of several equally good items, which matches a
    # scan that only replaces the best match on a strictly higher ratio.
    return max(string_array, key=similarity)
38
+
39
def loadModels():
    """Load the three models used by the object-removal pipeline.

    Returns:
        A ``(yoloModel, pipe, image_captioner)`` tuple: the YOLO segmentation
        model, the SDXL inpainting pipeline and the image-to-text captioning
        pipeline.
    """
    yoloModel = YOLO('yolov8x-seg.pt')

    # Follow DEVICE instead of hard-coding "cuda" so the app can also start on
    # a CPU-only host. Half precision is only requested on the GPU; on the CPU
    # the weights are loaded as float32.
    on_gpu = DEVICE == 'cuda'
    pipe = AutoPipelineForInpainting.from_pretrained(
        "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
        torch_dtype=torch.float16 if on_gpu else torch.float32,
        variant="fp16",
    ).to(DEVICE)

    image_captioner = pipeline(
        "image-to-text",
        model="Abdou/vit-swin-base-224-gpt2-image-captioning",
        device=DEVICE,
    )
    return yoloModel, pipe, image_captioner
50
+
51
+ # Yolo
52
+
53
def getClasses(model, img1):
    """Run the segmentation model on one image and render its predictions.

    Args:
        model: Callable detector. It is invoked as ``model([img1])`` and must
            return a non-empty, indexable sequence of result objects that
            expose ``plot()``.
        img1: The image to analyse.

    Returns:
        A ``(last_result, annotated, results)`` tuple where ``last_result`` is
        the final entry of ``results`` (the only one for a single input
        image), ``annotated`` is the array returned by ``last_result.plot()``
        with its last axis reversed, and ``results`` is the raw model output.
    """
    results = model([img1])
    last_result = results[-1]
    # plot() presumably returns a BGR array (per the original inline note);
    # reversing the channel axis then yields RGB.
    annotated = last_result.plot()
    return last_result, annotated[..., ::-1], results
62
+
63
def getMasks(out):
    """Merge per-detection masks into a single mask per class name.

    Args:
        out: Iterable of dicts, each with a ``'name'`` entry (class label) and
            an ``'img'`` entry (the mask image for one detection).

    Returns:
        A dict mapping each class name to its mask, in first-seen order.
        Masks that share a class name are combined with an element-wise
        maximum and stored as a new ``Image``.
    """
    class_masks = {}
    for entry in out:
        class_name = entry['name']
        mask = entry['img']
        if class_name in class_masks:
            # Union of the masks: a pixel is kept if any detection covers it.
            merged = np.maximum(np.array(class_masks[class_name]), np.array(mask))
            mask = Image.fromarray(merged)
        class_masks[class_name] = mask
    return class_masks
78
+
79
def joinClasses(classes):
    """Build one 8-bit grayscale mask per detected class.

    Args:
        classes: Iterable of per-detection results. Each item must expose
            ``masks``, ``names`` and ``boxes.cls``; only the first mask and
            first class id of each item are used.

    Returns:
        The dict produced by ``getMasks``: class name -> merged mask image.
    """
    out = []
    for r in classes:
        class_name = r.names[int(r.boxes.cls.cpu().numpy()[0])]
        mask = r.masks[0].data[0].cpu().numpy()

        # Stretch the mask to the full 0-255 range.
        low, high = mask.min(), mask.max()
        if high > low:
            scaled = (mask - low) * (255 / (high - low))
        else:
            # A constant mask would divide by zero above; treat any positive
            # value as fully selected and everything else as empty.
            scaled = np.where(mask > 0, 255, 0)

        mask_img = Image.fromarray(scaled.astype(np.uint8), "L")
        out.append({'name': class_name, 'img': mask_img})

    return getMasks(out)
97
+
98
def getSegments(yoloModel, img1):
    """Segment ``img1`` and return one merged mask per detected class.

    Args:
        yoloModel: Segmentation model handed to ``getClasses``.
        img1: The image to segment.

    Returns:
        The class name -> mask dict built by ``joinClasses``.
    """
    # Only the result object is needed; the annotated image and the raw
    # result list returned alongside it are ignored here.
    detections, _annotated, _raw_results = getClasses(yoloModel, img1)
    return joinClasses(detections)
102
+
103
+ # Image captioning and prompt editing
104
+
105
def getDescript(image_captioner, img1):
    """Return the caption the captioning pipeline generates for ``img1``.

    Args:
        image_captioner: Callable captioner; it receives the image as a
            base64 string and must return a sequence whose first item has a
            ``'generated_text'`` entry.
        img1: The image to describe.

    Returns:
        The generated caption text.
    """
    encoded = image_to_base64(img1)
    predictions = image_captioner(encoded)
    return predictions[0]['generated_text']
109
+
110
def rmGPT(caption, remove_class):
    """Remove the word closest to ``remove_class`` and its neighbours.

    The caption is split on single spaces. The word most similar to
    ``remove_class`` is located, and that word plus up to two words on each
    side of its first occurrence are dropped.

    Args:
        caption: The caption to edit.
        remove_class: Name of the object that should no longer be mentioned.

    Returns:
        The remaining words joined with single spaces.
    """
    words = caption.split(' ')
    closest = get_most_similar_string(remove_class, words)
    ind = words.index(closest)

    # Keep whole words. Extending the list with a string would add its
    # characters one by one. max() stops the left bound from going negative
    # and wrapping around to the end of the list.
    kept = words[:max(ind - 2, 0)] + words[ind + 3:]
    return ' '.join(kept)
119
+
120
+ # SDXL
121
+
122
def ChangeOBJ(sdxl_m, img1, response, mask1):
    """Inpaint the masked region of ``img1`` and restore the original size.

    Args:
        sdxl_m: Callable inpainting pipeline; called with ``prompt``,
            ``image`` and ``mask_image`` keywords and expected to return an
            object with an ``images`` sequence.
        img1: The source image; its ``size`` is the target output size.
        response: Text prompt for the inpainting pipeline.
        mask1: Mask marking the region to repaint.

    Returns:
        The first generated image, resized to the size of ``img1``.
    """
    width, height = img1.size
    generated = sdxl_m(prompt=response, image=img1, mask_image=mask1)
    return generated.images[0].resize((width, height))
126
+
127
+
128
+
129
# Load the models once at import time; full_pipeline reads them as globals.
yoloModel, sdxl, image_captioner = loadModels()
130
+
131
def full_pipeline(image, target):
    """Remove the object named by ``target`` from ``image``.

    Steps: segment the image, pick the detected class closest to ``target``,
    caption the image, strip the object from the caption, and inpaint the
    object's mask using the edited caption as the prompt.

    Args:
        image: The uploaded picture as a numpy array, or ``None`` when
            nothing was uploaded.
        target: Free-text name of the object to delete.

    Returns:
        A ``(result_image, caption, prompt)`` tuple: the inpainted image, the
        original caption and the edited caption used as the prompt.

    Raises:
        gr.Error: If no image was supplied or no object was detected.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    img1 = Image.fromarray(image.astype('uint8'), 'RGB')
    allMask = getSegments(yoloModel, img1)
    if not allMask:
        # With nothing detected there is no mask to inpaint; report it rather
        # than failing on an empty candidate list further down.
        raise gr.Error("No objects were detected in the image.")

    target_to_remove = get_most_similar_string(target, list(allMask.keys()))
    caption = getDescript(image_captioner, img1)

    response = rmGPT(caption, target_to_remove)
    mask1 = allMask[target_to_remove]

    remimg = ChangeOBJ(sdxl, img1, response, mask1)

    return remimg, caption, response
143
+
144
+
145
+
146
# Input widgets: the picture to edit and the name of the object to remove.
_inputs = [
    gr.Image(label="Upload Image"),
    gr.Textbox(label="What to delete?"),
]

# Output widgets, in the order full_pipeline returns its values.
_outputs = [
    gr.Image(label="Result Image", type="numpy"),
    gr.Textbox(label="Caption"),
    gr.Textbox(label="Message"),
]

iface = gr.Interface(
    fn=full_pipeline,
    inputs=_inputs,
    outputs=_outputs,
    live=False,
)

iface.launch()