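"""Text-driven image segmentation demo.

Several Hugging Face pipelines segment the background (MaskFormer), clothing
(SegFormer), human body parts, and face regions. A sentence-transformer then
matches the user's free-text query against the collected mask labels, and the
closest mask is returned through a Gradio interface.
"""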
import os

# Pin the Gradio version at runtime (a workaround often used in hosted
# notebooks/Spaces; a requirements.txt pin is the usual alternative).
os.system("pip uninstall -y gradio")
os.system("pip install gradio==3.47.1")

import gradio
import base64  # kept for the optional base64 mask conversion noted below
from io import BytesIO
from PIL import Image
from transformers import pipeline
from sentence_transformers import SentenceTransformer, util

# Segmentation, detection, and text-matching models.
backgroundPipe = pipeline("image-segmentation", model="facebook/maskformer-swin-large-coco")  # scene/background segmentation
PersonPipe = pipeline("image-segmentation", model="mattmdjaga/segformer_b2_clothes")  # clothing segmentation
sentenceModel = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")  # query-to-label matching
personDetailsPipe = pipeline("image-segmentation", model="yolo12138/segformer-b2-human-parse-24")  # human-part parsing
faceModel = pipeline("image-segmentation", model="jonathandinu/face-parsing")  # face-region parsing
faceDetectionModel = pipeline("object-detection", model="aditmohan96/detr-finetuned-face")  # face detection
PersonDetectionPipe = pipeline("object-detection", model="hustvl/yolos-tiny")  # person detection

def getPersonDetail(image):
    """Detect each person, parse the crop into body parts, and return {label: full-size mask}."""
    detections = PersonDetectionPipe(image)
    person_boxes = [det["box"] for det in detections if det["label"].lower() == "person"]
    ret = {}
    for n, box in enumerate(person_boxes, start=1):
        crop_box = (box["xmin"], box["ymin"], box["xmax"], box["ymax"])
        cropped_image = image.crop(crop_box)
        for part in personDetailsPipe(cropped_image):
            # Prefix the label with a person index only when several people are present.
            key = f'person {n} {part["label"]}'.lower() if len(person_boxes) > 1 else part["label"].lower()
            ret[key] = cbiwm(image, part["mask"], box)
    return ret

def cbiwm(image, mask, coordinates):
    """Paste a crop-level mask onto an opaque black, full-size canvas at the crop's offset."""
    black_image = Image.new("RGBA", image.size, (0, 0, 0, 255))
    black_image.paste(mask, (coordinates["xmin"], coordinates["ymin"]), mask)
    return black_image

def processFaceDetails(image):
    """Add face-region masks (eyes, nose, hair, ...) on top of the person-part masks."""
    ret = getPersonDetail(image)
    detections = faceDetectionModel(image)
    if not detections:
        return ret
    # Prefer the second detection when more than one face is found.
    coordinates = detections[1]["box"] if len(detections) > 1 else detections[0]["box"]
    crop_box = (coordinates["xmin"], coordinates["ymin"], coordinates["xmax"], coordinates["ymax"])
    cropped_image = image.crop(crop_box)
    for imask in faceModel(cropped_image):
        ret[imask["label"].replace(".png", "").lower()] = cbiwm(image, imask["mask"], coordinates)
    return ret

def getImageDetails(image) -> dict:
    """Collect every available mask, keyed by lowercase label."""
    ret = processFaceDetails(image)
    for imask in backgroundPipe(image):
        ret[imask["label"].lower()] = imask["mask"]  # apply a base64 image converter here if needed
    for mask in PersonPipe(image):
        ret[mask["label"].lower()] = mask["mask"]  # apply a base64 image converter here if needed
    return ret

def processSentence(sentence: str, semilist: list):
    """Return the label in `semilist` most similar to `sentence` (dot-product similarity)."""
    query_embedding = sentenceModel.encode(sentence)
    passage_embedding = sentenceModel.encode(semilist)
    scores = util.dot_score(query_embedding, passage_embedding)[0]
    return semilist[int(scores.argmax())]

def process_image(image):
    """Recolor a mask for output: black pixels become fully transparent,
    white pixels become black (alpha preserved), anything else is untouched."""
    rgba_image = image.convert("RGBA")
    final_data = [
        (0, 0, 0, 0) if pixel[:3] == (0, 0, 0)
        else (0, 0, 0, pixel[3]) if pixel[:3] == (255, 255, 255)
        else pixel
        for pixel in rgba_image.getdata()
    ]
    processed_image = Image.new("RGBA", rgba_image.size)
    processed_image.putdata(final_data)
    return processed_image

def processAndGetMask(image: Image.Image, text: str):
    """End-to-end: build all masks, pick the label closest to `text`, and return its processed mask."""
    datas = getImageDetails(image)
    labels = list(datas.keys())
    selector = processSentence(text, labels)
    print(f"Selected: {selector} among: {labels}")
    return process_image(datas[selector])
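
# Programmatic usage (a sketch; "photo.jpg" and the query string are
# hypothetical examples):
#   mask = processAndGetMask(Image.open("photo.jpg"), "hair")
#   mask.save("hair_mask.png")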

demo = gradio.Interface(
    processAndGetMask,
    [gradio.Image(label="Input Image", type="pil"), gradio.Text(label="Input text to segment")],
    gradio.Image(label="Output Image", type="pil"),
)
demo.launch(share=True)