Spaces:
Runtime error
Runtime error
import json | |
import os | |
import zipfile | |
from pathlib import Path | |
import io | |
from tempfile import NamedTemporaryFile | |
from PIL import Image | |
import gradio as gr | |
import torch | |
from torchvision.transforms import transforms | |
from torch.utils.data import Dataset, DataLoader | |
import spaces | |
# Replace TorchScript compilation with an identity function, so anything passed
# to torch.jit.script afterwards is returned unchanged and runs eagerly.
# Presumably this is needed for code pulled in while unpickling the model
# below — TODO confirm.
torch.jit.script = lambda f: f

# Suffix appended to each image's archive name to form its caption filename.
caption_ext = ".txt"

# The same three rating tags that are appended to the end of allowed_tags.
# NOTE(review): not referenced anywhere in the visible code; confirm whether
# they were meant to be filtered out of the generated captions.
exclude_tags = ("explicit", "questionable", "safe")

# Model input preprocessing: resize to 384x384, convert to a float tensor,
# then normalize each RGB channel with (x - 0.5) / 0.5.
_INPUT_SIZE = (384, 384)
_CHANNEL_MEAN = [0.5, 0.5, 0.5]
_CHANNEL_STD = [0.5, 0.5, 0.5]
transform = transforms.Compose(
    [
        transforms.Resize(_INPUT_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
    ]
)
class ZipImageDataset(Dataset):
    """Map-style dataset over the image entries of an open ``zipfile.ZipFile``.

    Each item is a dict with ``"image"`` (the entry run through the module-level
    ``transform`` and cast to ``dtype``) and ``"image_name"`` (the entry's path
    inside the archive). Images are decoded lazily in ``__getitem__``, so the
    archive must remain open while the dataset is being iterated.
    """

    # File extensions (matched case-insensitively) that are treated as images.
    IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".webp")

    def __init__(self, zip_file, dtype):
        """
        Args:
            zip_file: An open ``zipfile.ZipFile`` to read images from.
            dtype: torch dtype each preprocessed image tensor is converted to.
        """
        self.zip_file = zip_file
        self.dtype = dtype
        self.image_files = [
            file_info
            for file_info in zip_file.infolist()
            if self._is_image_entry(file_info.filename)
        ]

    @classmethod
    def _is_image_entry(cls, filename):
        """Return True if archive member ``filename`` should be loaded as an image."""
        # Zips created by macOS Finder include a "__MACOSX/" tree of "._<name>"
        # AppleDouble metadata files. They reuse the image's extension but are
        # not decodable images, so Image.open would fail on them.
        basename = filename.rsplit("/", 1)[-1]
        if filename.startswith("__MACOSX/") or basename.startswith("._"):
            return False
        return filename.lower().endswith(cls.IMAGE_EXTENSIONS)

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, index):
        file_info = self.image_files[index]
        with self.zip_file.open(file_info) as file:
            # convert() forces a full decode while the member is still open.
            image = Image.open(file).convert("RGB")
        image = transform(image).to(self.dtype)
        return {
            "image": image,
            "image_name": file_info.filename,
        }
# Load the whole pickled tagger module (it is called directly and .eval()'d
# below, so it is not a bare state_dict) onto the CPU. Since PyTorch 2.6,
# torch.load defaults to weights_only=True, which refuses arbitrary pickled
# objects such as a full nn.Module. weights_only=False restores full
# unpickling, which can execute code, so only point this at a trusted file.
model = torch.load(
    "./model.pth", map_location=torch.device("cpu"), weights_only=False
)
model.eval()

# Tag vocabulary. JSON is UTF-8 by spec, so do not rely on the platform's
# default encoding.
with open("tags_9940.json", "r", encoding="utf-8") as file:
    tags = json.load(file)
# Index i of the model output is labelled allowed_tags[i]. Presumably the
# model was trained with this sorted order plus the three rating tags last —
# TODO confirm.
allowed_tags = sorted(tags) + ["explicit", "questionable", "safe"]
def create_tags(image, threshold):
    """Tag a single image with the global ``model``.

    Args:
        image: PIL image from the ``gr.Image(type='pil')`` input, or ``None``
            when the user submitted without an image.
        threshold: Minimum sigmoid probability for a tag to be reported.

    Returns:
        ``(tag_string, tag_scores)``. ``tag_string`` holds the names of the tags
        whose probability exceeds ``threshold``, joined with ", " in
        ``allowed_tags`` order. ``tag_scores`` maps each of those names to its
        probability. Returns ``("", {})`` when ``image`` is ``None``.
    """
    if image is None:
        return "", {}
    # Match the model's parameter dtype, as the zip path's dataset does, so a
    # model stored in reduced precision does not receive a float32 input.
    dtype = next(model.parameters()).dtype
    tensor = transform(image.convert("RGB")).unsqueeze(0).to(dtype)
    with torch.no_grad():
        logits = model(tensor)
    probabilities = torch.sigmoid(logits[0])
    indices = torch.where(probabilities > threshold)[0]
    names = [allowed_tags[i] for i in indices.tolist()]
    scores = probabilities[indices].tolist()
    tag_score = dict(zip(names, scores))
    return ", ".join(names), tag_score
def process_zip(zip_file, threshold):
    """Tag every image in an uploaded zip and return a zip of caption files.

    Args:
        zip_file: The ``gr.File`` upload. This is a filepath string (Gradio 4's
            default ``type="filepath"``) or a tempfile-like object exposing
            ``.name`` (older Gradio). It is ``None`` when nothing was uploaded.
        threshold: Minimum sigmoid probability for a tag to be included.

    Returns:
        Path to a new ``.zip`` containing one text entry per image, named
        ``<image archive name><caption_ext>`` and holding that image's
        comma-separated tags. Returns ``None`` when no file was uploaded. The
        file is created with ``delete=False`` and is not removed here.
    """
    if zip_file is None:
        return None
    # Accept both a plain path string and an object carrying the path in .name.
    zip_path = getattr(zip_file, "name", zip_file)

    first_param = next(model.parameters())
    device = first_param.device
    on_cuda = device.type == "cuda"

    captions = []  # (archive entry name, comma-separated tags), archive order
    with zipfile.ZipFile(zip_path) as zip_ref:
        dataloader = DataLoader(
            ZipImageDataset(zip_ref, first_param.dtype),
            batch_size=64,
            shuffle=False,
            num_workers=0,
            # Pinned memory only helps host-to-GPU copies. Without CUDA it is a
            # no-op that makes DataLoader emit a warning.
            pin_memory=torch.cuda.is_available(),
            drop_last=False,
        )
        with torch.no_grad():
            for batch in dataloader:
                images = batch["image"].to(device)
                # fp16 autocast only affects CUDA ops, so enable it only when
                # the model actually runs on CUDA.
                with torch.autocast(
                    device_type="cuda", dtype=torch.float16, enabled=on_cuda
                ):
                    outputs = model(images)
                probabilities = torch.sigmoid(outputs)
                for image_name, prob in zip(batch["image_name"], probabilities):
                    hits = torch.where(prob > threshold)[0].tolist()
                    captions.append(
                        (image_name, ", ".join(allowed_tags[i] for i in hits))
                    )

    # Closing the handle (the with-block) avoids leaking it. delete=False keeps
    # the file on disk for Gradio to serve.
    with NamedTemporaryFile(delete=False, suffix=".zip") as temp_file:
        with zipfile.ZipFile(temp_file, "w") as out_zip:
            for image_name, caption in captions:
                # NOTE(review): this produces names like "img.png.txt". Confirm
                # that consumers expect this rather than "img.txt".
                out_zip.writestr(image_name + caption_ext, caption)
    return temp_file.name
# Two-tab Gradio UI: tag one image, or caption every image in a zip archive.
with gr.Blocks() as demo:
    with gr.Tab("Single Image"):
        gr.Interface(
            create_tags,
            inputs=[
                gr.Image(label="Source", sources=['upload', 'webcam'], type='pil'),
                gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.30, label="Threshold"),
            ],
            outputs=[
                gr.Textbox(label="Tag String"),
                gr.Label(label="Tag Predictions", num_top_classes=200),
            ],
            allow_flagging="never",
        )
    with gr.Tab("Multiple Images"):
        gr.Interface(
            fn=process_zip,
            inputs=[
                gr.File(label="Zip File", file_types=[".zip"]),
                gr.Slider(minimum=0, maximum=1, value=0.3, step=0.01, label="Threshold"),
            ],
            outputs=gr.File(type="binary"),
            # Disable flagging here as well, consistent with the single-image
            # tab, so uploaded archives are not copied into a flagging folder.
            allow_flagging="never",
        )

if __name__ == "__main__":
    demo.launch()