Thouph's picture
Update app.py
f6d73a8 verified
raw
history blame contribute delete
No virus
1.94 kB
import json
import random
random.seed(999)
import torch
from torchvision.transforms import transforms
import gradio as gr
from datetime import datetime
# Deserialize the whole checkpoint onto the CPU, then call eval() on it once
# at start-up (presumably an nn.Module, for which eval() selects inference
# behaviour -- confirm against how model.pth was saved).
# NOTE: torch.load unpickles the file, so model.pth must come from a trusted
# source.
_cpu_device = torch.device('cpu')
model = torch.load('model.pth', map_location=_cpu_device)
model.eval()
# Preprocessing applied to every input image before it reaches the model:
# resize to 384x384, convert to a tensor, then normalize each of the three
# channels with mean 0.5 and std 0.5.
_channel_mean = [0.5, 0.5, 0.5]
_channel_std = [0.5, 0.5, 0.5]
transform = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_channel_mean, std=_channel_std),
])
# Build the label vocabulary used to turn model output positions into tag
# names. JSON is UTF-8 by specification, so the encoding is stated explicitly
# instead of relying on the platform default.
with open("tags_9940.json", "r", encoding="utf-8") as tags_file:
    allowed_tags = sorted(json.load(tags_file))
# The three rating labels follow the alphabetically sorted tags. create_tags
# indexes this list by output position, so this order must be kept
# (presumably it mirrors the label order used at training time -- confirm).
allowed_tags.extend(["explicit", "questionable", "safe"])
def create_tags(image, threshold):
    """Predict tags for an image and return them together with their scores.

    Args:
        image: PIL image supplied by the UI, or None when the request was
            submitted without an image.
        threshold: Probability cut-off; only tags whose sigmoid score is
            strictly greater than this value are kept.

    Returns:
        A ``(tag_string, tag_scores)`` tuple. ``tag_string`` holds the kept
        tag names joined by single spaces, in ascending order of their
        position in ``allowed_tags``. ``tag_scores`` maps each kept tag name
        to its probability as a Python float. Both are empty when ``image``
        is None.
    """
    # The handler can be invoked with no image at all; answer with empty
    # outputs instead of failing on ``None.convert``.
    if image is None:
        return "", {}

    # Add a leading batch dimension of size 1 for the model call.
    batch = transform(image.convert('RGB')).unsqueeze(0)
    with torch.no_grad():
        logits = model(batch)
    # Each output position is scored independently (sigmoid, not softmax).
    probabilities = torch.sigmoid(logits[0])

    indices = torch.where(probabilities > threshold)[0]
    # Convert to plain Python ints/floats once instead of element by element.
    scores = probabilities[indices].tolist()
    names = [allowed_tags[position] for position in indices.tolist()]

    tag_score = dict(zip(names, scores))
    text_no_impl = " ".join(names)

    current_datetime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    print(f"{current_datetime}: finished.")
    return text_no_impl, tag_score
# Gradio UI: an image input and a threshold slider are passed to create_tags;
# its two return values fill the tag-string textbox and the per-tag score
# label, in that order.
_source_image = gr.Image(label="Source", sources=['upload', 'webcam'], type='pil')
_threshold_slider = gr.Slider(
    minimum=0.00,
    maximum=1.00,
    step=0.01,
    value=0.30,
    label="Threshold",
)
_tag_string_box = gr.Textbox(label="Tag String")
_tag_scores_label = gr.Label(label="Tag Predictions", num_top_classes=200)

demo = gr.Interface(
    fn=create_tags,
    inputs=[_source_image, _threshold_slider],
    outputs=[_tag_string_box, _tag_scores_label],
    allow_flagging="never",
)
# Start the web server as soon as the script runs.
demo.launch()