File size: 4,157 Bytes
0a275ec
 
 
c0c08a7
 
 
e14c9b5
c0c08a7
 
 
 
 
 
e14c9b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c0c08a7
 
e14c9b5
c0c08a7
 
 
e14c9b5
c0c08a7
e14c9b5
 
c0c08a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
782342a
 
 
e14c9b5
 
c0c08a7
 
e14c9b5
c0c08a7
 
 
 
5fcdb4c
c0c08a7
 
ed9d9fa
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import os
# Pin gradio at startup. pip is invoked through the running interpreter (`python -m pip`) so the
# package is (un)installed in the same environment this script imports from, and an argument list
# is passed so no shell is involved. check=False keeps this best-effort, like the original
# os.system calls whose exit codes were ignored.
import subprocess
import sys
subprocess.run([sys.executable, "-m", "pip", "uninstall", "-y", "gradio"], check=False)
subprocess.run([sys.executable, "-m", "pip", "install", "gradio==2.6.4"], check=False)
from transformers import pipeline
import gradio
import base64
from PIL import Image, ImageDraw
from io import BytesIO
from sentence_transformers import SentenceTransformer, util

# Models are loaded once at import time and shared by every request.
# Scene segmentation (MaskFormer, COCO per its model id): background / object masks.
backgroundPipe = pipeline(task="image-segmentation", model="facebook/maskformer-swin-large-coco")
# Clothing segmentation on the full image.
PersonPipe = pipeline(task="image-segmentation", model="mattmdjaga/segformer_b2_clothes")
# Sentence embedder used to match the user's text against mask labels.
sentenceModal = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
# Human-parsing segmentation, applied to each detected person crop.
personDetailsPipe = pipeline(task="image-segmentation", model="yolo12138/segformer-b2-human-parse-24")
# Face-parsing segmentation, applied to the detected face crop.
faceModal = pipeline(task="image-segmentation", model="jonathandinu/face-parsing")
# Face bounding-box detector.
faceDetectionModal = pipeline(task="object-detection", model="aditmohan96/detr-finetuned-face")
# General object detector; only its "person" boxes are used.
PersonDetectionpipe = pipeline(task="object-detection", model="hustvl/yolos-tiny")

def getPersonDetail(image):
    """Segment body parts for every person detected in *image*.

    Each "person" box from ``PersonDetectionpipe`` is cropped and run through
    ``personDetailsPipe``. Every resulting part mask is re-projected onto a
    full-size canvas with ``cbiwm``.

    Returns:
        dict mapping a lowercase part label to its full-size RGBA mask. When
        more than one person is detected the label is prefixed with
        ``"person <n> "`` (1-based) so the people do not overwrite each other.
    """
    person_boxes = [
        detection["box"]
        for detection in PersonDetectionpipe(image)
        if detection["label"].lower() == "person"
    ]
    several_people = len(person_boxes) > 1

    details = {}
    for index, box in enumerate(person_boxes, start=1):
        crop = image.crop((box['xmin'], box['ymin'], box['xmax'], box['ymax']))
        for part in personDetailsPipe(crop):
            if several_people:
                key = f'Person {index} {part["label"]}'.lower()
            else:
                key = part["label"].lower()
            details[key] = cbiwm(image, part["mask"], box)
    return details

def cbiwm(image, mask, coordinates):
    """Place a crop-sized *mask* back onto a full-size black canvas.

    Args:
        image: the original image; only its size is used.
        mask: a mask computed on a crop of *image*; it is also used as its own
            paste mask, so only its non-zero pixels are drawn.
        coordinates: the crop's box dict; ``xmin``/``ymin`` give the offset.

    Returns:
        A new opaque-black RGBA image of ``image.size`` with *mask* pasted in.
    """
    canvas = Image.new("RGBA", image.size, (0, 0, 0, 255))
    top_left = (coordinates['xmin'], coordinates['ymin'])
    canvas.paste(mask, box=top_left, mask=mask)
    return canvas

def processFaceDetails(image):
    """Return person-part masks plus face-part masks for the best face in *image*.

    Starts from ``getPersonDetail(image)``. It then runs ``faceDetectionModal``,
    crops the highest-scoring detection and segments it with ``faceModal``. Each
    face-part mask is re-projected to full size with ``cbiwm`` and stored under
    its lowercase label with any ``.png`` suffix removed.

    If no face is detected, only the person-part masks are returned. The
    previous code indexed ``data[1]`` directly, which raised IndexError
    whenever fewer than two detections came back.
    """
    ret = getPersonDetail(image)
    detections = faceDetectionModal(image)
    if not detections:
        return ret

    # Use the most confident detection rather than a fixed list position.
    box = max(detections, key=lambda detection: detection["score"])["box"]
    crop_box = (box['xmin'], box['ymin'], box['xmax'], box['ymax'])
    for part in faceModal(image.crop(crop_box)):
        ret[part["label"].replace(".png", "").lower()] = cbiwm(image, part["mask"], box)
    return ret

def getImageDetails(image) -> dict:
    """Collect every available mask for *image*, keyed by lowercase label.

    Person/face part masks come first. They are followed by the raw masks from
    the scene (``backgroundPipe``) and clothing (``PersonPipe``) segmenters. On
    a label collision the later source wins: background overrides face/person
    parts, and clothing overrides background.
    """
    details = processFaceDetails(image)
    clothing_segments = PersonPipe(image)
    scene_segments = backgroundPipe(image)
    # NOTE(review): these masks are stored as returned by the pipeline; convert
    # (e.g. to base64) here if a caller needs a serialized form.
    for segment in (*scene_segments, *clothing_segments):
        details[segment["label"].lower()] = segment["mask"]
    return details

def processSentence(sentence: str, semilist: list):
    """Return the entry of *semilist* whose embedding best matches *sentence*.

    Both are embedded with ``sentenceModal`` and compared by dot product; ties
    resolve to the earliest candidate.
    """
    query_vector = sentenceModal.encode(sentence)
    candidate_vectors = sentenceModal.encode(semilist)
    scores = list(util.dot_score(query_vector, candidate_vectors)[0])
    best_index = max(range(len(scores)), key=scores.__getitem__)
    return semilist[best_index]

def process_image(image):
    """Convert a black/white mask image into an RGBA overlay.

    Per pixel, after converting to RGBA:
      * pure black ``(0, 0, 0, a)``       -> fully transparent ``(0, 0, 0, 0)``
      * pure white ``(255, 255, 255, a)`` -> black with alpha kept ``(0, 0, 0, a)``
      * anything else                     -> unchanged

    This is the same result the previous two-pass version produced (swap black
    and white, then clear white). It is computed in a single pass without the
    intermediate image.
    """
    rgba_image = image.convert("RGBA")

    def _recolor(pixel):
        rgb = pixel[:3]
        if rgb == (0, 0, 0):
            return (0, 0, 0, 0)
        if rgb == (255, 255, 255):
            return (0, 0, 0, pixel[3])
        return pixel

    processed_image = Image.new("RGBA", rgba_image.size)
    processed_image.putdata([_recolor(pixel) for pixel in rgba_image.getdata()])
    return processed_image

def _load_input_image(source: str):
    """Open *source* as a PIL image: a file path, a data URL, or raw base64."""
    # os.path.isfile returns False (rather than raising) for over-long or
    # otherwise invalid path strings, so base64 payloads fall through safely.
    if os.path.isfile(source):
        return Image.open(source)
    _, comma, payload = source.partition(",")
    # "data:image/png;base64,<payload>" -> payload; a bare base64 string is used as-is.
    return Image.open(BytesIO(base64.b64decode(payload if comma else source)))


def processAndGetMask(base64_image: str, text: str):
    """Return the mask whose label best matches *text*, as an RGBA overlay.

    Args:
        base64_image: the input image as a file path (what the Gradio input
            declared with ``type="filepath"`` supplies), a ``data:...;base64,``
            URL, or a raw base64 string.
        text: free-text description of the region to select.

    Returns:
        The selected mask after ``process_image``.

    Raises:
        ValueError: if no masks were produced for the image.
    """
    image = _load_input_image(base64_image)
    datas = getImageDetails(image)
    if not datas:
        raise ValueError("No segments were detected in the image.")
    selector = processSentence(text, list(datas))
    imageout = datas[selector]
    print(f"Selected : {selector}")
    return process_image(imageout)

# Web UI: an input image plus a text prompt, returning the selected mask image.
# NOTE(review): gradio.Image / gradio.Text are gradio 3.x-style component classes;
# confirm they exist in the gradio==2.6.4 pinned at the top of this file.
gr = gradio.Interface(
    fn=processAndGetMask,
    inputs=[gradio.Image(type="filepath"), gradio.Text()],
    outputs=gradio.Image(type="pil"),
)
# share=True requests a public share link in addition to the local server.
gr.launch(share=True)