cdnuts commited on
Commit
16e5bdb
·
verified ·
1 Parent(s): 12e28da
Files changed (1) hide show
  1. run.py +63 -0
run.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import time
4
+ from PIL import Image
5
+ import torch
6
+ from torchvision.transforms import transforms
7
+ import gradio as gr
8
+
9
+ parser = argparse.ArgumentParser(description="Image Classification")
10
+ parser.add_argument("-i", "--image_path", required=True, help="Path to the image file")
11
+ args = parser.parse_args()
12
+
13
+ model = torch.load('model.pth', map_location=torch.device('cpu'))
14
+ model.eval()
15
+ transform = transforms.Compose([
16
+ transforms.Resize((448, 448)),
17
+ transforms.ToTensor(),
18
+ transforms.Normalize(mean=[
19
+ 0.48145466,
20
+ 0.4578275,
21
+ 0.40821073
22
+ ], std=[
23
+ 0.26862954,
24
+ 0.26130258,
25
+ 0.27577711
26
+ ])
27
+ ])
28
+
29
+
30
+ with open("tags_8041.json", "r") as file:
31
+ tags = json.load(file)
32
+ allowed_tags = sorted(tags)
33
+ allowed_tags.insert(0, "placeholder0")
34
+ allowed_tags.append("placeholder1")
35
+ allowed_tags.append("explicit")
36
+ allowed_tags.append("questionable")
37
+ allowed_tags.append("safe")
38
+
39
+ def create_tags(image):
40
+ img = image.convert('RGB')
41
+ tensor = transform(img).unsqueeze(0)
42
+
43
+ with torch.no_grad():
44
+ out = model(tensor)
45
+ probabilities = torch.nn.functional.sigmoid(out[0])
46
+ indices = torch.where(probabilities > 0.3)[0]
47
+ values = probabilities[indices]
48
+
49
+ temp = []
50
+ for i in range(indices.size(0)):
51
+ temp.append([allowed_tags[indices[i]], values[i].item()])
52
+ temp = sorted(temp, key=lambda x: x[1], reverse=True)
53
+ text = ""
54
+ for i in range(len(temp)):
55
+ text += temp[i][0] + (' ,' if i < len(temp) - 1 else '')
56
+ return text
57
+
58
+ demo = gr.Interface(
59
+ fn=create_tags,
60
+ inputs=["image"],
61
+ outputs=["text"],
62
+ )
63
+ demo.launch()