Thouph commited on
Commit
d3c95ed
1 Parent(s): 7aa0a81

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -0
app.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import random
3
+ random.seed(999)
4
+ import torch
5
+ from torchvision.transforms import transforms
6
+ import gradio as gr
7
+
8
+ model = torch.load('model.pth', map_location=torch.device('cpu'))
9
+ model.eval()
10
+ transform = transforms.Compose([
11
+ transforms.Resize((384, 384)),
12
+ transforms.ToTensor(),
13
+ transforms.Normalize(
14
+ mean=[
15
+ 0.5,
16
+ 0.5,
17
+ 0.5,
18
+ ], std=[
19
+ 0.5,
20
+ 0.5,
21
+ 0.5,
22
+ ])
23
+ ])
24
+
25
+ with open("tags_9940.json", "r") as file:
26
+ allowed_tags = json.load(file)
27
+
28
+ allowed_tags = sorted(allowed_tags)
29
+ allowed_tags.append("explicit")
30
+ allowed_tags.append("questionable")
31
+ allowed_tags.append("safe")
32
+
33
+ def create_tags(image, threshold):
34
+ img = image.convert('RGB')
35
+ tensor = transform(img).unsqueeze(0)
36
+
37
+ with torch.no_grad():
38
+ logits = model(tensor)
39
+ probabilities = torch.nn.functional.sigmoid(logits[0])
40
+ indices = torch.where(probabilities > threshold)[0]
41
+ values = probabilities[indices]
42
+
43
+ temp = []
44
+ tag_score = dict()
45
+ for i in range(indices.size(0)):
46
+ temp.append([allowed_tags[indices[i]], values[i].item()])
47
+ tag_score[allowed_tags[indices[i]]] = values[i].item()
48
+ # temp = sorted(temp, key=lambda x: x[1], reverse=True)
49
+ # print("Before adding implicated tags, there are " + str(len(temp)) + " tags")
50
+ temp = [t[0] for t in temp]
51
+ text_no_impl = " ".join(temp)
52
+ return text_no_impl, tag_score
53
+
54
+ demo = gr.Interface(
55
+ create_tags,
56
+ inputs=[gr.Image(label="Source", sources=['upload', 'webcam'], type='pil'), gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.30, label="Threshold")],
57
+ outputs=[
58
+ gr.Textbox(label="Tag String"),
59
+ gr.Label(label="Tag Predictions", num_top_classes=200),
60
+ ],
61
+ allow_flagging="never",
62
+ )
63
+
64
+ demo.launch()