# Hugging Face Space source (scraped page header reported status: "Runtime error")
from io import BytesIO
import base64

import gradio
from PIL import Image, ImageDraw
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline

# --- Model setup -------------------------------------------------------------
# All models are loaded once at import time; the first run downloads weights.

# Scene / background semantic segmentation (COCO classes).
backgroundPipe = pipeline("image-segmentation", model="facebook/maskformer-swin-large-coco")
# Clothes segmentation over the whole image.
PersonPipe = pipeline("image-segmentation", model="mattmdjaga/segformer_b2_clothes")
# Sentence embeddings used to match the user's text query to a mask label.
sentenceModal = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
# Fine-grained human parsing (24 body-part classes) run on each person crop.
personDetailsPipe = pipeline("image-segmentation", model="yolo12138/segformer-b2-human-parse-24")
# Face-part segmentation (eyes, nose, lips, ...) run on a face crop.
faceModal = pipeline("image-segmentation", model="jonathandinu/face-parsing")
# Face bounding-box detector.
faceDetectionModal = pipeline("object-detection", model="aditmohan96/detr-finetuned-face")
# Generic object detector; only its "person" boxes are used.
PersonDetectionpipe = pipeline("object-detection", model="hustvl/yolos-tiny")
def getPersonDetail(image):
    """Detect every person in *image* and segment each one into body parts.

    Args:
        image: a PIL image.

    Returns:
        dict mapping lowercase part labels to full-size RGBA mask images.
        When more than one person is detected, keys are prefixed with
        ``person N`` so masks from different people do not collide.
    """
    detections = PersonDetectionpipe(image)
    boxes = [d["box"] for d in detections if d["label"].lower() == "person"]
    ret = {}
    for n, box in enumerate(boxes, start=1):
        crop = image.crop((box['xmin'], box['ymin'], box['xmax'], box['ymax']))
        for part in personDetailsPipe(crop):
            if len(boxes) > 1:
                key = (f'Person {n} {part["label"]}').lower()
            else:
                key = part["label"].lower()
            # Re-project the crop-sized mask back into full-image coordinates.
            ret[key] = cbiwm(image, part["mask"], box)
    return ret
def cbiwm(image, mask, coordinates):
    """Composite a crop-sized *mask* onto a black full-size canvas.

    ("cbiwm" = combine black image with mask.)  Creates an opaque black RGBA
    canvas matching *image*'s size and pastes *mask* at its original crop
    offset, using the mask itself as the paste transparency — so the mask
    ends up white-on-black in full-image coordinates.

    Args:
        image: the original full-size PIL image (only its size is used).
        mask: the segmentation mask returned for a crop of *image*.
        coordinates: dict with at least ``xmin``/``ymin`` — the crop origin.

    Returns:
        a new full-size RGBA PIL image.
    """
    canvas = Image.new("RGBA", image.size, (0, 0, 0, 255))
    canvas.paste(mask, (coordinates['xmin'], coordinates['ymin']), mask)
    return canvas
def processFaceDetails(image):
    """Add face-part masks (eyes, nose, lips, ...) to the per-person masks.

    Args:
        image: a PIL image.

    Returns:
        the dict from :func:`getPersonDetail`, extended with one entry per
        face-part label (``.png`` suffix stripped, lowercased).
    """
    ret = getPersonDetail(image)
    faces = faceDetectionModal(image)
    if not faces:
        # No face detected — return the body-part masks only instead of
        # crashing on an out-of-range index.
        return ret
    # NOTE(review): the original unconditionally indexed faces[1], which
    # raises IndexError whenever fewer than two detections come back; keep
    # that preference for the second box but fall back to the first.
    box = faces[1]["box"] if len(faces) > 1 else faces[0]["box"]
    crop = image.crop((box['xmin'], box['ymin'], box['xmax'], box['ymax']))
    for part in faceModal(crop):
        label = part["label"].replace(".png", "").lower()
        ret[label] = cbiwm(image, part["mask"], box)
    return ret
def getImageDetails(image) -> dict:
    """Collect every available mask for *image*, keyed by lowercase label.

    Combines face parts, per-person body parts, scene/background segments,
    and clothes segments.  Insertion order matters: clothes masks are added
    last so they override background masks that share a label.
    """
    ret = processFaceDetails(image)
    for segment in backgroundPipe(image):
        ret[segment["label"].lower()] = segment["mask"]  # base64-encode here if needed
    for segment in PersonPipe(image):
        ret[segment["label"].lower()] = segment["mask"]  # base64-encode here if needed
    return ret
def processSentence(sentence: str, semilist: list):
    """Return the entry of *semilist* most semantically similar to *sentence*.

    Similarity is the dot product of MiniLM sentence embeddings; ties keep
    the earliest entry (same as the original max/index scan).

    Args:
        sentence: the user's free-text query.
        semilist: candidate label strings.

    Returns:
        the best-matching element of *semilist*.
    """
    query_embedding = sentenceModal.encode(sentence)
    passage_embeddings = sentenceModal.encode(semilist)
    scores = util.dot_score(query_embedding, passage_embeddings)[0]
    # argmax replaces the original list-copy + max() + index() round trip.
    return semilist[int(scores.argmax())]
def process_image(image):
    """Turn a white-on-black mask into a transparent cut-out.

    Black pixels (mask background) become fully transparent; white pixels
    (the selected region) become black with their original alpha; any other
    pixel is left untouched.  Equivalent to the original two-pass
    swap-then-clear, done in a single pass without the intermediate image.
    """
    rgba_image = image.convert("RGBA")
    out_data = []
    for pixel in rgba_image.getdata():
        rgb = pixel[:3]
        if rgb == (0, 0, 0):
            out_data.append((0, 0, 0, 0))           # background -> transparent
        elif rgb == (255, 255, 255):
            out_data.append((0, 0, 0, pixel[3]))    # selection -> black, keep alpha
        else:
            out_data.append(pixel)
    processed_image = Image.new("RGBA", rgba_image.size)
    processed_image.putdata(out_data)
    return processed_image
def processAndGetMask(base64_image: str, text: str):
    """End-to-end handler: decode an image, pick the mask matching *text*.

    Args:
        base64_image: the image as a base64 string, with or without a
            ``data:image/...;base64,`` prefix.
        text: free-text description of the region to extract.

    Returns:
        the selected mask as a transparent-background RGBA PIL image.
    """
    # split(',')[-1] tolerates raw base64 (no comma) as well as data URLs;
    # the original [1] raised IndexError on raw base64 input.
    payload = base64_image.split(',')[-1]
    image = Image.open(BytesIO(base64.b64decode(payload)))
    datas = getImageDetails(image)
    selector = processSentence(text, list(datas.keys()))
    imageout = datas[selector]
    print(f"Selected : {selector}")
    return process_image(imageout)
# Gradio wiring.  The handler expects a base64 string, and "url" is not a
# valid gradio Image `type` ("numpy"/"pil"/"filepath" are) — the invalid
# component configuration is the likely cause of the Space's runtime error,
# so the image is taken as a text (base64 / data-URL) input instead.
gr = gradio.Interface(
    processAndGetMask,
    [gradio.Text(), gradio.Text()],
    gradio.Image(type="pil"),
)
gr.launch(share=True)