from transformers import pipeline import gradio from PIL import Image from IPython.display import display, HTML import base64 from PIL import Image from io import BytesIO from sentence_transformers import SentenceTransformer, util backgroundPipe = pipeline("image-segmentation", model="facebook/maskformer-swin-large-coco") PersonPipe = pipeline("image-segmentation", model="mattmdjaga/segformer_b2_clothes") sentenceModal = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2") def getImageDetails(image) -> dict: person = PersonPipe(image) bg = backgroundPipe(image) ret = {} labs = [] for imask in bg: ret[imask["label"]] = imask["mask"] # Apply base64 image converter here if needed labs.append(imask["label"]) for mask in person: ret[mask["label"]] = mask["mask"] # Apply base64 image converter here if needed labs.append(mask["label"]) return ret, labs def processSentence(sentence: str, semilist: list): query_embedding = sentenceModal.encode(sentence) passage_embedding = sentenceModal.encode(semilist) listv = util.dot_score(query_embedding, passage_embedding)[0] float_list = [] for i in listv: float_list.append(i) max_value = max(float_list) max_index = float_list.index(max_value) return semilist[max_index] def process_image(image): rgba_image = image.convert("RGBA") switched_data = [ (255, 255, 255, pixel[3]) if pixel[:3] == (0, 0, 0) else (0, 0, 0, pixel[3]) if pixel[:3] == (255, 255, 255) else pixel for pixel in rgba_image.getdata() ] switched_image = Image.new("RGBA", rgba_image.size) switched_image.putdata(switched_data) final_data = [ (0, 0, 0, 0) if pixel[:3] == (255, 255, 255) else pixel for pixel in switched_image.getdata() ] processed_image = Image.new("RGBA", rgba_image.size) processed_image.putdata(final_data) return processed_image def processAndGetMask(image: str, text: str): datas, labs = getImageDetails(image) selector = processSentence(text, labs) imageout = datas[selector] return process_image(imageout) gr = gradio.Interface( processAndGetMask, [gradio.Image(type="pil"), gradio.Text()], gradio.Image(type="pil") ) gr.launch()