Thouph commited on
Commit
0beb876
1 Parent(s): 8799536

Upload inference.py

Browse files
Files changed (1) hide show
  1. inference.py +48 -0
inference.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import time
3
+ from PIL import Image
4
+ import torch
5
+ from torchvision.transforms import transforms
6
+
7
+ model = torch.load('/path/to/your/model.pth').to("cuda")
8
+ model.eval()
9
+ transform = transforms.Compose([
10
+ transforms.Resize((448, 448)),
11
+ transforms.ToTensor(),
12
+ transforms.Normalize(mean=[
13
+ 0.48145466,
14
+ 0.4578275,
15
+ 0.40821073
16
+ ], std=[
17
+ 0.26862954,
18
+ 0.26130258,
19
+ 0.27577711
20
+ ])
21
+ ])
22
+
23
+ with open("tags_8041.json", "r") as file:
24
+ tags = json.load(file)
25
+ allowed_tags = sorted(tags)
26
+ allowed_tags.insert(0, "placeholder0")
27
+ allowed_tags.append("placeholder1")
28
+ allowed_tags.append("explicit")
29
+ allowed_tags.append("questionable")
30
+ allowed_tags.append("safe")
31
+
32
+ image_path = "/path/to/your/image.jpg"
33
+ start = time.time()
34
+ img = Image.open(image_path).convert('RGB')
35
+ tensor = transform(img).unsqueeze(0).to("cuda") # transform and add batch dimension
36
+
37
+ with torch.no_grad():
38
+ out = model(tensor)
39
+ probabilities = torch.nn.functional.sigmoid(out[0])
40
+ indices = torch.where(probabilities > 0.3)[0]
41
+ values = probabilities[indices]
42
+
43
+ for i in range(indices.size(0)):
44
+ print(allowed_tags[indices[i]], values[i].item())
45
+
46
+ end = time.time()
47
+ print(f'Executed in {end - start} seconds')
48
+ print("\n\n", end="")