# SelectByText / app.py
# Author: 5m4ck3r — commit c0c08a7 ("Create app.py"), 2.29 kB
from transformers import pipeline
import gradio
from PIL import Image
from IPython.display import display, HTML
import base64
from PIL import Image
from io import BytesIO
from sentence_transformers import SentenceTransformer, util
# Model identifiers on the Hugging Face Hub used by this app.
_BACKGROUND_MODEL_ID = "facebook/maskformer-swin-large-coco"
_CLOTHES_MODEL_ID = "mattmdjaga/segformer_b2_clothes"
_EMBEDDING_MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2"

# All three models are created once, at import time, and shared by every request.
# Scene-level segmenter (COCO-trained, judging by the model id).
backgroundPipe = pipeline("image-segmentation", model=_BACKGROUND_MODEL_ID)
# Person/clothing-part segmenter.
PersonPipe = pipeline("image-segmentation", model=_CLOTHES_MODEL_ID)
# Sentence embedder used to match the user's text against segment labels.
sentenceModal = SentenceTransformer(_EMBEDDING_MODEL_ID)
def getImageDetails(image) -> tuple:
    """Segment *image* with both pipelines and index the masks by label.

    Runs the clothes/person segmenter (``PersonPipe``) and the scene
    segmenter (``backgroundPipe``) on the same image.

    Args:
        image: The input image, passed unchanged to both pipelines.

    Returns:
        A ``(masks, labels)`` tuple. ``masks`` maps each label to its mask
        image. ``labels`` lists every distinct label once, in first-seen
        order (scene labels first, then clothes/person labels).
    """
    from PIL import ImageChops  # local import: only this function needs it

    person = PersonPipe(image)
    bg = backgroundPipe(image)
    masks = {}
    labels = []
    for segment in [*bg, *person]:
        label = segment["label"]
        # Masks are kept as returned by the pipeline; a base64 encoder could
        # be applied here if a serialized form is ever needed.
        mask = segment["mask"]
        if label in masks:
            # The same label can be reported several times (e.g. multiple
            # instances of one class). Merge them with a pixel-wise max, which
            # is a union for black/white masks, instead of letting the last
            # one overwrite the others.
            masks[label] = ImageChops.lighter(masks[label], mask)
        else:
            masks[label] = mask
            labels.append(label)
    return masks, labels
def processSentence(sentence: str, semilist: list):
    """Return the entry of *semilist* whose embedding best matches *sentence*.

    The query and every candidate are embedded with ``sentenceModal`` and
    scored with ``util.dot_score``. The highest-scoring candidate wins; ties
    go to the earliest candidate.

    Args:
        sentence: Free-text query.
        semilist: Candidate strings (in this app, segmentation labels).

    Returns:
        The best-matching element of ``semilist``.

    Raises:
        ValueError: If ``semilist`` is empty.
    """
    if not semilist:
        raise ValueError("processSentence() needs at least one candidate label")
    query_embedding = sentenceModal.encode(sentence)
    passage_embedding = sentenceModal.encode(semilist)
    scores = util.dot_score(query_embedding, passage_embedding)[0]
    # argmax returns the first maximal index, same as list.index(max(...)).
    best_index = int(scores.argmax())
    return semilist[best_index]
def process_image(image):
    """Turn a black/white mask into a black cut-out on a transparent canvas.

    The image is converted to RGBA. Then each pixel is remapped:

    * pure black ``(0, 0, 0)`` becomes fully transparent ``(0, 0, 0, 0)``;
    * pure white ``(255, 255, 255)`` becomes black, keeping its alpha;
    * every other pixel is left unchanged.

    Args:
        image: A PIL image (in this app, a segmentation mask).

    Returns:
        A new RGBA image. *image* itself is not modified.
    """
    # convert() returns a new image, so writing into it leaves the input intact.
    rgba_image = image.convert("RGBA")

    def remap(pixel):
        rgb = pixel[:3]
        if rgb == (0, 0, 0):
            return (0, 0, 0, 0)
        if rgb == (255, 255, 255):
            return (0, 0, 0, pixel[3])
        return pixel

    # A single pass is equivalent to "swap black and white, then clear every
    # pixel that is now white", without building an intermediate image.
    rgba_image.putdata([remap(pixel) for pixel in rgba_image.getdata()])
    return rgba_image
def processAndGetMask(image: "Image.Image", text: str) -> "Image.Image":
    """Gradio handler: cut out the segment whose label best matches *text*.

    Args:
        image: The uploaded picture. The Interface declares
            ``gradio.Image(type="pil")``, so this is a PIL image, or ``None``
            when nothing was uploaded.
        text: Free-text description of the object to select.

    Returns:
        The mask of the selected segment, rendered by ``process_image``.

    Raises:
        ValueError: If no image was supplied or *text* is blank.
    """
    # Fail fast with a clear message instead of an obscure error from deep
    # inside the segmentation/embedding models.
    if image is None:
        raise ValueError("Please upload an image first.")
    if not text or not text.strip():
        raise ValueError("Please describe the object to select.")
    masks, labels = getImageDetails(image)
    selected_label = processSentence(text, labels)
    return process_image(masks[selected_label])
# Gradio UI: an image and a text prompt in, the selected cut-out out.
gr = gradio.Interface(
    processAndGetMask,
    [gradio.Image(type="pil"), gradio.Text()],
    gradio.Image(type="pil"),
)

# Start the web server only when run as a script (e.g. `python app.py`), so
# importing this module does not launch a server.
if __name__ == "__main__":
    gr.launch()