Thouph commited on
Commit
164c320
1 Parent(s): 1890768

Upload 2 files

Browse files
Files changed (2) hide show
  1. 7704_inference.py +38 -0
  2. tags.json.7z +3 -0
7704_inference.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import time
3
+ from PIL import Image
4
+ from timm.data import resolve_data_config
5
+ import torch
6
+ from torchvision.transforms import transforms
7
+
8
+ model = torch.load('model.pth').to("cuda")
9
+ model.eval()
10
+ config = resolve_data_config({}, model=model)
11
+ transform = transforms.Compose([
12
+ transforms.Resize((384, 384)),
13
+ transforms.ToTensor(),
14
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
15
+ ])
16
+
17
+ with open("tags.json", "r") as file:
18
+ tags = json.load(file)
19
+ allowed_tags = sorted(tags)
20
+ allowed_tags.extend(["placeholder0", "placeholder1", "placeholder2"])
21
+ tag_count = len(allowed_tags)
22
+
23
+
24
+ image_path="path/to/your/image.png"
25
+ start = time.time()
26
+ img = Image.open(image_path).convert('RGB')
27
+ tensor = transform(img).unsqueeze(0).to("cuda") # transform and add batch dimension
28
+
29
+ with torch.no_grad():
30
+ out = model(tensor)
31
+ probabilities = torch.nn.functional.sigmoid(out[0])
32
+
33
+ top10_prob, top10_catid = torch.topk(probabilities, 100)
34
+ for i in range(top10_prob.size(0)):
35
+ print(allowed_tags[top10_catid[i]], top10_prob[i].item())
36
+ end = time.time()
37
+ print(f'Executed in {end - start} seconds')
38
+ print("\n\n", end="")
tags.json.7z ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5a1be541d0f5005f2fff6cb8c75dc7ea216ae337813aaecb5dc36a03b0266d76
3
+ size 35480