# Scraped from a Hugging Face file page: uploaded by cdnuts, latest commit
# 16e5bdb ("yay", verified), file size 1.62 kB, scan result "No virus".
import argparse
import json
import time
from PIL import Image
import torch
from torchvision.transforms import transforms
import gradio as gr
# Command-line interface. Images are supplied through the Gradio web UI and
# args.image_path is not read anywhere in this script, so the flag is
# optional. parse_known_args() tolerates extra arguments that a launcher may
# append instead of aborting start-up.
parser = argparse.ArgumentParser(description="Image Classification")
parser.add_argument("-i", "--image_path", default=None, help="Path to the image file")
args = parser.parse_known_args()[0]
# Load the whole pickled model object (it is used directly as a module below,
# not as a state_dict) onto the CPU and put it in evaluation mode.
# weights_only=False is passed explicitly because torch>=2.6 defaults to
# weights_only=True, which refuses to unpickle a full model object.
# NOTE(review): unpickling can execute arbitrary code stored in the file --
# only ever load a trusted model.pth.
model = torch.load('model.pth', map_location=torch.device('cpu'), weights_only=False)
model.eval()
# Preprocessing applied to every input image before inference:
#   1. resize to a fixed 448x448 (aspect ratio is not preserved),
#   2. convert to a float tensor,
#   3. normalize each RGB channel with the mean/std below (these numbers are
#      the ones published with OpenAI's CLIP preprocessing).
transform = transforms.Compose(
    [
        transforms.Resize((448, 448)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.48145466, 0.4578275, 0.40821073],
            std=[0.26862954, 0.26130258, 0.27577711],
        ),
    ]
)
# Build the index -> tag-name lookup used to decode model outputs:
# create_tags reads allowed_tags[i] for output position i. The layout is
# "placeholder0", the sorted tags from the JSON file, then "placeholder1" and
# the three rating labels.
# NOTE(review): presumably this order mirrors the model's output layout from
# training -- confirm before editing the tag file or this list.
# JSON text is UTF-8 by specification, so decode it explicitly rather than
# with the platform's locale encoding.
with open("tags_8041.json", "r", encoding="utf-8") as file:
    tags = json.load(file)
allowed_tags = ["placeholder0", *sorted(tags)]
allowed_tags.extend(["placeholder1", "explicit", "questionable", "safe"])
def create_tags(image, threshold=0.3):
    """Predict tags for an image and return them as one comma-separated string.

    Each model output is passed through a sigmoid independently, and every
    tag whose probability exceeds ``threshold`` is kept.

    Args:
        image: The uploaded image. Either a ``PIL.Image.Image`` or an array
            accepted by ``PIL.Image.fromarray`` (Gradio's ``gr.Image``
            delivers a numpy array by default). ``None`` (nothing uploaded)
            yields an empty string.
        threshold: Minimum probability, exclusive, for a tag to be reported.
            Defaults to 0.3, the previously hard-coded cutoff.

    Returns:
        The selected tag names joined by ``", "``, ordered from most to least
        probable.
    """
    if image is None:
        return ""
    if not isinstance(image, Image.Image):
        # gr.Image defaults to type="numpy"; ndarrays have no .convert().
        image = Image.fromarray(image)

    # Add a batch dimension of 1; the model expects batched input.
    batch = transform(image.convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        logits = model(batch)
    probabilities = torch.sigmoid(logits[0])

    hits = torch.where(probabilities > threshold)[0].tolist()
    scored = [(allowed_tags[index], probabilities[index].item()) for index in hits]
    # Stable sort: tags with equal probability keep their index order.
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return ", ".join(tag for tag, _ in scored)
# Web UI: upload an image and get the predicted tags back as text.
# type="pil" makes Gradio pass create_tags a PIL image; the gr.Image default
# is a numpy array, which lacks the .convert() method the callback uses.
demo = gr.Interface(
    fn=create_tags,
    inputs=[gr.Image(type="pil")],
    outputs=["text"],
)
demo.launch()