cdnuts's picture
Update app.py
81e3554 verified
raw
history blame contribute delete
No virus
1.63 kB
import json
import time
from PIL import Image
import torch
from torchvision.transforms import transforms
import gradio as gr
# Load the trained tagger onto the CPU. The checkpoint is a fully pickled
# module (not a state_dict) -- the loaded object is called directly and has
# .eval() invoked on it -- so full unpickling is needed:
#   * torch >= 2.6 changed the default to weights_only=True, which rejects
#     arbitrary pickled nn.Module objects, hence the explicit False
#     (the keyword itself needs torch >= 1.13).
#   * SECURITY: weights_only=False executes pickle. Only ever load a trusted,
#     locally shipped checkpoint here, never a user-supplied file.
model = torch.load('model.pth', map_location=torch.device('cpu'), weights_only=False)
# Switch to inference mode (affects layers such as dropout / batch-norm, if
# the model contains any).
model.eval()
# Per-channel RGB normalization statistics. These numbers appear to be the
# OpenAI CLIP image statistics -- NOTE(review): confirm they match what the
# model in model.pth was trained with.
_NORM_MEAN = [0.48145466, 0.4578275, 0.40821073]
_NORM_STD = [0.26862954, 0.26130258, 0.27577711]

# Preprocessing applied to every input image before inference: resize to a
# fixed 448x448, convert to a float tensor, then normalize each channel.
transform = transforms.Compose(
    [
        transforms.Resize((448, 448)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)
# Label list indexed by model output position: create_tags maps output index
# i to allowed_tags[i], so this exact order (placeholder0, the sorted tags
# from the JSON file, placeholder1, then the three rating labels) must be
# preserved.
# JSON text is UTF-8 by specification; pass the encoding explicitly so
# non-ASCII tag names are not decoded with the platform's locale encoding.
with open("tags_8041.json", "r", encoding="utf-8") as file:
    tags = json.load(file)
allowed_tags = [
    "placeholder0",
    *sorted(tags),
    "placeholder1",
    "explicit",
    "questionable",
    "safe",
]
# Entries of allowed_tags that only pad the label list and must never be
# shown to the user.
_HIDDEN_TAGS = frozenset({"placeholder0", "placeholder1"})


def create_tags(image, thres):
    """Predict tags for ``image`` and return them as one comma-separated string.

    Args:
        image: PIL image from the Gradio image input, or ``None`` when the
            user submits without uploading an image.
        thres: Probability threshold; only tags whose sigmoid score is
            strictly greater than this are kept.

    Returns:
        Tag names ordered by descending probability, joined with ", ", with
        underscores replaced by spaces and parentheses backslash-escaped.
        Placeholder labels are always omitted. Returns "" when ``image`` is
        ``None``.
    """
    if image is None:
        return ""

    # Add a leading batch dimension of 1 for the single image.
    batch = transform(image.convert('RGB')).unsqueeze(0)
    with torch.no_grad():
        logits = model(batch)
    # Assumes the model returns (batch, len(allowed_tags)) logits, one
    # independent score per label -- TODO confirm against the model.
    probabilities = torch.sigmoid(logits[0])

    indices = torch.where(probabilities > thres)[0]
    scored = zip(indices.tolist(), probabilities[indices].tolist())
    # sorted() is stable, so equal probabilities keep label-index order.
    ranked = sorted(scored, key=lambda pair: pair[1], reverse=True)

    # Filter placeholders on the raw names before any text transformation,
    # so they are removed wherever they land in the ranking (including last).
    visible = [
        allowed_tags[idx]
        for idx, _ in ranked
        if allowed_tags[idx] not in _HIDDEN_TAGS
    ]

    text = ", ".join(visible).replace("_", " ")
    # Escape parentheses, presumably so the output can be pasted into prompt
    # syntaxes where bare parentheses carry special meaning.
    return text.replace("(", "\\(").replace(")", "\\)")
# Gradio UI: an image upload plus a 0-1 threshold slider (default 0.3) feeding
# create_tags, whose returned tag string is displayed as text.
_image_input = gr.Image(type="pil")
_threshold_slider = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.3)

demo = gr.Interface(
    fn=create_tags,
    inputs=[_image_input, _threshold_slider],
    outputs=["text"],
)
# Start the Gradio web server for the interface.
demo.launch()