AngelBottomless committed
Commit ca9b012 · verified · 1 Parent(s): acee4df

Nicer verbosity by o1-pro

Files changed (1)
  1. infer.py +98 -80
infer.py CHANGED
@@ -1,80 +1,98 @@
- import onnxruntime as ort
- import numpy as np
- import json
- from PIL import Image
-
- # 1) Load ONNX model
- session = ort.InferenceSession("camie_tagger_initial.onnx", providers=["CPUExecutionProvider"])
-
- # 2) Preprocess your image (512x512, etc.)
- def preprocess_image(img_path):
-     """
-     Loads and resizes an image to 512x512, converts it to float32 [0..1],
-     and returns a (1,3,512,512) NumPy array (NCHW format).
-     """
-     img = Image.open(img_path).convert("RGB").resize((512, 512))
-     x = np.array(img).astype(np.float32) / 255.0
-     x = np.transpose(x, (2, 0, 1))  # HWC -> CHW
-     x = np.expand_dims(x, 0)  # add batch dimension -> (1,3,512,512)
-     return x
-
- # Example input
- def inference(input_path):
-     input_tensor = preprocess_image(input_path)
-
-     # 3) Run inference
-     input_name = session.get_inputs()[0].name
-     outputs = session.run(None, {input_name: input_tensor})
-     initial_logits, refined_logits = outputs  # shape: (1, 70527) each
-
-     # 4) Convert logits to probabilities via sigmoid
-     refined_probs = 1 / (1 + np.exp(-refined_logits))  # shape: (1, 70527)
-
-     # 5) Load metadata & retrieve threshold info
-     with open("metadata.json", "r", encoding="utf-8") as f:
-         metadata = json.load(f)
-
-     # Dictionary of idx->tag_name, e.g. { "0": "brown_hair", "1": "blue_eyes", ... }
-     idx_to_tag = metadata["idx_to_tag"]
-
-     # Dictionary of tag->category, e.g. { "brown_hair": "character", "landscape": "general", ... }
-     tag_to_category = metadata.get("tag_to_category", {})
-
-     # Dictionary of category->threshold, e.g. { "character": 0.30, "general": 0.325, ... }
-     # If not present or incomplete, we'll use a default threshold of 0.325
-     category_thresholds = metadata.get("category_thresholds", {})
-     default_threshold = 0.325
-
-     # 6) Collect predictions by category
-     # We'll loop through all tags and check if the probability is above the category-specific threshold
-     results_by_category = {}
-
-     num_tags = refined_probs.shape[1]  # 70527
-     for i in range(num_tags):
-         prob = float(refined_probs[0, i])  # get probability for this tag
-         tag_name = idx_to_tag[str(i)]  # convert index -> tag name (keys in idx_to_tag are strings)
-
-         # Find category; if not in 'tag_to_category', label it "unknown"
-         category = tag_to_category.get(tag_name, "unknown")
-
-         # Find threshold for this category; fallback to default
-         cat_threshold = category_thresholds.get(category, default_threshold)
-
-         # Check if prob meets or exceeds the threshold
-         if prob >= cat_threshold:
-             if category not in results_by_category:
-                 results_by_category[category] = []
-             # Store the tag name + its probability
-             results_by_category[category].append((tag_name, prob))
-
-     # 7) Print out the predicted tags category-wise
-     print("Predicted Tags by Category:\n")
-
-     for cat, tags_list in results_by_category.items():
-         print(f"Category: {cat} | Predicted {len(tags_list)} tags")
-         for tname, tprob in sorted(tags_list, key=lambda x: x[1], reverse=True):
-             print(f"  Tag: {tname:30s} Prob: {tprob:.4f}")
-         print()
-
- if __name__ == "__main__":
-     inference("example_image.jpg")
+ import onnxruntime as ort
+ import numpy as np
+ import json
+ from PIL import Image
+
+ # 1) Load ONNX model
+ session = ort.InferenceSession("camie_tagger_initial.onnx", providers=["CPUExecutionProvider"])
+
+ # 2) Preprocess your image (512x512, etc.)
+ def preprocess_image(img_path):
+     """
+     Loads and resizes an image to 512x512, converts it to float32 [0..1],
+     and returns a (1,3,512,512) NumPy array (NCHW format).
+     """
+     img = Image.open(img_path).convert("RGB").resize((512, 512))
+     x = np.array(img).astype(np.float32) / 255.0
+     x = np.transpose(x, (2, 0, 1))  # HWC -> CHW
+     x = np.expand_dims(x, 0)  # add batch dimension -> (1,3,512,512)
+     return x
+
+ # Example input
+
+ def inference(input_path, output_format="verbose"):
+     """
+     Returns either:
+       - A verbose category breakdown, or
+       - A comma-separated string of predicted tags (underscores replaced with spaces).
+     """
+     # 1) Preprocess
+     input_tensor = preprocess_image(input_path)
+
+     # 2) Run inference
+     input_name = session.get_inputs()[0].name
+     outputs = session.run(None, {input_name: input_tensor})
+     initial_logits, refined_logits = outputs  # shape: (1, 70527) each
+
+     # 3) Convert logits to probabilities
+     refined_probs = 1 / (1 + np.exp(-refined_logits))  # shape: (1, 70527)
+
+     # 4) Load metadata & retrieve threshold info
+     with open("metadata.json", "r", encoding="utf-8") as f:
+         metadata = json.load(f)
+
+     idx_to_tag = metadata["idx_to_tag"]  # e.g. { "0": "brown_hair", "1": "blue_eyes", ... }
+     tag_to_category = metadata.get("tag_to_category", {})
+     category_thresholds = metadata.get(
+         "category_thresholds",
+         {"artist": 0.1, "character": 0.2, "meta": 0.3, "style": 0.1}
+     )
+     default_threshold = 0.325
+
+     # 5) Collect predictions by category
+     results_by_category = {}
+     num_tags = refined_probs.shape[1]
+
+     for i in range(num_tags):
+         prob = float(refined_probs[0, i])
+         tag_name = idx_to_tag[str(i)]  # str(i) because metadata uses string keys
+         category = tag_to_category.get(tag_name, "unknown")
+         cat_threshold = category_thresholds.get(category, default_threshold)
+
+         if prob >= cat_threshold:
+             if category not in results_by_category:
+                 results_by_category[category] = []
+             results_by_category[category].append((tag_name, prob))
+
+     # 6) Depending on output_format, produce different return strings
+     if output_format == "as_prompt":
+         # Flatten all predicted tags across categories
+         all_predicted_tags = []
+         for cat, tags_list in results_by_category.items():
+             # We only need the tag name in as_prompt format
+             for tname, tprob in tags_list:
+                 # convert underscores to spaces
+                 tag_name_spaces = tname.replace("_", " ")
+                 all_predicted_tags.append(tag_name_spaces)
+
+         # Create a comma-separated string
+         prompt_string = ", ".join(all_predicted_tags)
+         return prompt_string
+
+     else:  # "verbose"
+         # We'll build a multiline string describing the predictions
+         lines = []
+         lines.append("Predicted Tags by Category:\n")
+         for cat, tags_list in results_by_category.items():
+             lines.append(f"Category: {cat} | Predicted {len(tags_list)} tags")
+             # Sort descending by probability
+             for tname, tprob in sorted(tags_list, key=lambda x: x[1], reverse=True):
+                 lines.append(f"  Tag: {tname:30s} Prob: {tprob:.4f}")
+             lines.append("")  # blank line after each category
+         # Join lines with newlines
+         verbose_output = "\n".join(lines)
+         return verbose_output
+
+ if __name__ == "__main__":
+     result = inference("path/to/image", output_format="as_prompt")
+     print(result)
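
For reference, a minimal usage sketch of the new inference() entry point. It assumes the updated script is saved as infer.py, that camie_tagger_initial.onnx and metadata.json sit in the working directory, and that your_image.png is a placeholder path:

# Usage sketch (assumptions: script saved as infer.py; model and metadata.json
# are in the working directory; "your_image.png" is a placeholder path).
from infer import inference  # importing also creates the module-level ONNX session

# Default "verbose" mode: returns a per-category breakdown with probabilities.
print(inference("your_image.png"))

# "as_prompt" mode: returns a flat, comma-separated tag list with underscores
# replaced by spaces, ready to paste as a prompt.
print(inference("your_image.png", output_format="as_prompt"))

Keeping "verbose" as the default mirrors the old behavior (the category breakdown is now returned as a string rather than printed), while "as_prompt" produces the comma-separated tag string introduced by this commit.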