5m4ck3r commited on
Commit
c0c08a7
1 Parent(s): 71c93b4

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -0
app.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import pipeline
2
+ import gradio
3
+ from PIL import Image
4
+ from IPython.display import display, HTML
5
+ import base64
6
+ from PIL import Image
7
+ from io import BytesIO
8
+ from sentence_transformers import SentenceTransformer, util
9
+
10
+
11
+ backgroundPipe = pipeline("image-segmentation", model="facebook/maskformer-swin-large-coco")
12
+ PersonPipe = pipeline("image-segmentation", model="mattmdjaga/segformer_b2_clothes")
13
+ sentenceModal = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
14
+
15
+ def getImageDetails(image) -> dict:
16
+ person = PersonPipe(image)
17
+ bg = backgroundPipe(image)
18
+ ret = {}
19
+ labs = []
20
+ for imask in bg:
21
+ ret[imask["label"]] = imask["mask"] # Apply base64 image converter here if needed
22
+ labs.append(imask["label"])
23
+ for mask in person:
24
+ ret[mask["label"]] = mask["mask"] # Apply base64 image converter here if needed
25
+ labs.append(mask["label"])
26
+ return ret, labs
27
+
28
+ def processSentence(sentence: str, semilist: list):
29
+ query_embedding = sentenceModal.encode(sentence)
30
+ passage_embedding = sentenceModal.encode(semilist)
31
+ listv = util.dot_score(query_embedding, passage_embedding)[0]
32
+ float_list = []
33
+ for i in listv:
34
+ float_list.append(i)
35
+ max_value = max(float_list)
36
+ max_index = float_list.index(max_value)
37
+ return semilist[max_index]
38
+
39
+ def process_image(image):
40
+ rgba_image = image.convert("RGBA")
41
+ switched_data = [
42
+ (255, 255, 255, pixel[3]) if pixel[:3] == (0, 0, 0) else (0, 0, 0, pixel[3]) if pixel[:3] == (255, 255, 255) else pixel
43
+ for pixel in rgba_image.getdata()
44
+ ]
45
+ switched_image = Image.new("RGBA", rgba_image.size)
46
+ switched_image.putdata(switched_data)
47
+ final_data = [
48
+ (0, 0, 0, 0) if pixel[:3] == (255, 255, 255) else pixel
49
+ for pixel in switched_image.getdata()
50
+ ]
51
+ processed_image = Image.new("RGBA", rgba_image.size)
52
+ processed_image.putdata(final_data)
53
+ return processed_image
54
+
55
+ def processAndGetMask(image: str, text: str):
56
+ datas, labs = getImageDetails(image)
57
+ selector = processSentence(text, labs)
58
+ imageout = datas[selector]
59
+ return process_image(imageout)
60
+
61
+ gr = gradio.Interface(
62
+ processAndGetMask,
63
+ [gradio.Image(type="pil"), gradio.Text()],
64
+ gradio.Image(type="pil")
65
+ )
66
+ gr.launch()