|
import streamlit as st |
|
import yaml |
|
import torch |
|
import torch.nn.functional as F |
|
|
|
from transformers import DetrImageProcessor, DetrForObjectDetection |
|
|
|
from lib.IRRA.tokenizer import tokenize, SimpleTokenizer |
|
from lib.IRRA.image import prepare_images |
|
from lib.IRRA.model.build import build_model, IRRA |
|
from PIL import Image |
|
from pathlib import Path |
|
|
|
from easydict import EasyDict |
|
|
|
|
|
@st.cache_resource
def get_model():
    """Build and cache the IRRA retrieval model from the on-disk YAML config.

    Returns:
        IRRA: the model built by ``build_model`` with ``training`` forced off.
    """
    # Use a context manager so the config file handle is closed promptly
    # (the original left it open until garbage collection).
    with open('model/configs.yaml') as f:
        args = yaml.load(f, Loader=yaml.FullLoader)

    args = EasyDict(args)
    # Inference-only usage: disable training-specific branches in build_model.
    args['training'] = False

    model = build_model(args)

    return model
|
|
|
|
|
@st.cache_resource
def get_detr():
    """Load and cache the DETR person detector and its processor.

    Returns:
        tuple: ``(model, processor)`` — a ``DetrForObjectDetection`` and the
        matching ``DetrImageProcessor``, both from the ResNet-50 checkpoint.
    """
    # The "no_timm" revision avoids a runtime dependency on the timm package.
    checkpoint = "facebook/detr-resnet-50"
    revision = "no_timm"

    processor = DetrImageProcessor.from_pretrained(checkpoint, revision=revision)
    model = DetrForObjectDetection.from_pretrained(checkpoint, revision=revision)

    return model, processor
|
|
|
|
|
def segment_images(model, processor, images: list[str]):
    """Detect people in each image and save the crops to ``segments/``.

    Args:
        model: a ``DetrForObjectDetection`` instance.
        processor: the matching ``DetrImageProcessor``.
        images: paths of the images to scan.

    Returns:
        list[str]: POSIX paths of the saved person crops (``segments/img_N.jpg``).
    """
    segments = []
    crop_id = 0  # renamed from `id`, which shadowed the builtin

    out_dir = Path('segments')
    out_dir.mkdir(exist_ok=True)

    for image_path in images:
        # Context manager closes the file handle (the original leaked it).
        with Image.open(image_path) as image:
            inputs = processor(images=image, return_tensors="pt")
            outputs = model(**inputs)

            # DETR post-processing expects (height, width); PIL size is (w, h).
            target_sizes = torch.tensor([image.size[::-1]])
            results = processor.post_process_object_detection(
                outputs, target_sizes=target_sizes, threshold=0.9)[0]

            for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
                box = [round(i, 2) for i in box.tolist()]
                label = model.config.id2label[label.item()]

                # Keep only person detections larger than 70x70 px — smaller
                # crops are too low-resolution for downstream retrieval.
                if label == 'person' and box[2] - box[0] > 70 and box[3] - box[1] > 70:
                    file = out_dir / f'img_{crop_id}.jpg'
                    image.crop(box).save(file)
                    segments.append(file.as_posix())

                crop_id += 1

    return segments
|
|
|
|
|
def get_similarities(text: str, images: list[str], model: IRRA) -> torch.Tensor:
    """Score a text query against a set of images with the IRRA model.

    Args:
        text: the natural-language query.
        images: paths of the candidate images.
        model: an IRRA model exposing ``encode_image`` / ``encode_text``.

    Returns:
        torch.Tensor: cosine-similarity matrix of shape (1, len(images)).
    """
    tokenizer = SimpleTokenizer()

    txt = tokenize(text, tokenizer)
    imgs = prepare_images(images)

    # Inference only: skip autograd graph construction (the original built
    # gradients for nothing, wasting memory on large image batches).
    with torch.no_grad():
        image_feats = model.encode_image(imgs)
        text_feats = model.encode_text(txt.unsqueeze(0))

        # L2-normalize so the matrix product below is cosine similarity.
        image_feats = F.normalize(image_feats, p=2, dim=1)
        text_feats = F.normalize(text_feats, p=2, dim=1)

        return text_feats @ image_feats.t()
|
|