Spaces:
Runtime error
Runtime error
from transformers import pipeline | |
import gradio | |
from PIL import Image | |
from IPython.display import display, HTML | |
import base64 | |
from PIL import Image | |
from io import BytesIO | |
from sentence_transformers import SentenceTransformer, util | |
# Models are loaded once at import time and shared by every request.

# Segmentation model used for the overall scene.
backgroundPipe = pipeline(
    task="image-segmentation",
    model="facebook/maskformer-swin-large-coco",
)

# Segmentation model used for people and their clothing.
PersonPipe = pipeline(
    task="image-segmentation",
    model="mattmdjaga/segformer_b2_clothes",
)

# Sentence encoder used to match the user's text against segment labels.
sentenceModal = SentenceTransformer(
    model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",
)
def getImageDetails(image) -> "tuple[dict, list]":
    """Segment *image* with both pipelines and collect the masks by label.

    Args:
        image: The input handed unchanged to ``PersonPipe`` and
            ``backgroundPipe``.

    Returns:
        A ``(masks_by_label, labels)`` pair. ``masks_by_label`` maps each
        segment's ``"label"`` to its ``"mask"``; ``labels`` lists every label
        in the order it was seen (background segments first, then person
        segments) and may contain duplicates.
    """
    person_segments = PersonPipe(image)
    background_segments = backgroundPipe(image)

    masks_by_label = {}
    labels = []
    # Background segments are stored first, so when the same label shows up
    # again (in either model's output) the later mask replaces the earlier one.
    for segments in (background_segments, person_segments):
        for segment in segments:
            label = segment["label"]
            # Apply a base64 image converter to the mask here if needed.
            masks_by_label[label] = segment["mask"]
            labels.append(label)
    return masks_by_label, labels
def processSentence(sentence: str, semilist: list):
    """Return the entry of *semilist* that best matches *sentence*.

    Both the sentence and the candidates are encoded with ``sentenceModal``
    and compared with ``util.dot_score``; the candidate with the highest
    score wins (the first one on ties).

    Args:
        sentence: Free-text query.
        semilist: Candidate strings to choose from.

    Returns:
        The element of ``semilist`` with the highest dot-product score.

    Raises:
        ValueError: If ``semilist`` is empty, since there is nothing to pick.
    """
    if not semilist:
        raise ValueError("semilist must contain at least one candidate")

    query_embedding = sentenceModal.encode(sentence)
    passage_embedding = sentenceModal.encode(semilist)
    # Row 0 holds the scores of the single query against every candidate.
    scores = util.dot_score(query_embedding, passage_embedding)[0]
    best_index = int(scores.argmax())
    return semilist[best_index]
def process_image(image):
    """Convert a black/white mask image into an RGBA cut-out mask.

    The image is converted to RGBA and each pixel is remapped:

    * pure black ``(0, 0, 0)`` becomes fully transparent ``(0, 0, 0, 0)``;
    * pure white ``(255, 255, 255)`` becomes black, keeping its alpha;
    * any other colour is left untouched.

    This matches swapping black and white and then making white transparent,
    but is done in a single pass without an intermediate image.

    Args:
        image: An object exposing ``convert("RGBA")`` (e.g. a PIL image).

    Returns:
        A new RGBA image of the same size holding the remapped pixels.
    """
    rgba_image = image.convert("RGBA")
    black = (0, 0, 0)
    white = (255, 255, 255)

    def remap(pixel):
        rgb = pixel[:3]
        if rgb == black:
            # Black would turn white and white is then cleared entirely.
            return (0, 0, 0, 0)
        if rgb == white:
            return (0, 0, 0, pixel[3])
        return pixel

    processed_image = Image.new("RGBA", rgba_image.size)
    processed_image.putdata([remap(pixel) for pixel in rgba_image.getdata()])
    return processed_image
def processAndGetMask(image: "Image.Image", text: str):
    """Return the processed mask of the segment that best matches *text*.

    The image is segmented with ``getImageDetails``, the label closest to
    ``text`` is chosen with ``processSentence``, and that label's mask is
    post-processed with ``process_image``.

    Args:
        image: The uploaded picture. The interface is configured with
            ``type="pil"``, so this is an image object rather than a path.
        text: Description of the region the user wants to select.

    Returns:
        The image produced by ``process_image`` for the selected mask.

    Raises:
        ValueError: If no image was supplied.
    """
    if image is None:
        # Fail early with a clear message instead of deep inside a pipeline.
        raise ValueError("An input image is required.")

    masks_by_label, labels = getImageDetails(image)
    chosen_label = processSentence(text, labels)
    return process_image(masks_by_label[chosen_label])
# Web UI: one image plus a text prompt in, the selected mask image out.
gr = gradio.Interface(
    fn=processAndGetMask,
    inputs=[
        gradio.Image(type="pil"),
        gradio.Text(),
    ],
    outputs=gradio.Image(type="pil"),
)

# Start serving the interface.
gr.launch()