import onnxruntime as ort
import numpy as np
import json
from PIL import Image

# 1) Load ONNX model
session = ort.InferenceSession("camie_tagger_initial.onnx", providers=["CPUExecutionProvider"])
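
# Optional: with the onnxruntime-gpu package installed, you can list the GPU provider
# first; ONNX Runtime tries providers in order and falls back to CPU if needed.
# session = ort.InferenceSession(
#     "camie_tagger_initial.onnx",
#     providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
# )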

# 2) Preprocess your image to the model's expected input: 512x512 RGB, NCHW float32
def preprocess_image(img_path):
    """
    Loads and resizes an image to 512x512, converts it to float32 [0..1],
    and returns a (1,3,512,512) NumPy array (NCHW format).
    """
    img = Image.open(img_path).convert("RGB").resize((512, 512))
    x = np.array(img).astype(np.float32) / 255.0
    x = np.transpose(x, (2, 0, 1))  # HWC -> CHW
    x = np.expand_dims(x, 0)        # add batch dimension -> (1,3,512,512)
    return x

# Example input (hypothetical path; uncomment to try):
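# input_tensor = preprocess_image("example.jpg")  # -> float32 array, shape (1, 3, 512, 512)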

def inference(input_path, output_format="verbose"):
    """
    Returns either:
      - A verbose category breakdown, or
      - A comma-separated string of predicted tags (underscores replaced with spaces).
    """
    # 1) Preprocess
    input_tensor = preprocess_image(input_path)

    # 2) Run inference
    input_name = session.get_inputs()[0].name
    outputs = session.run(None, {input_name: input_tensor})
    initial_logits, refined_logits = outputs  # shape: (1, 70527) each

    # 3) Convert logits to probabilities
    refined_probs = 1 / (1 + np.exp(-refined_logits))  # shape: (1, 70527)
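    # Note: np.exp may emit an overflow RuntimeWarning for strongly negative logits;
    # the affected probabilities still round to 0.0, so the result is unaffected.
    # scipy.special.expit is a numerically stable alternative if SciPy is available.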

    # 4) Load metadata & retrieve threshold info
    with open("metadata.json", "r", encoding="utf-8") as f:
        metadata = json.load(f)

    idx_to_tag = metadata["idx_to_tag"]  # e.g. { "0": "brown_hair", "1": "blue_eyes", ... }
    tag_to_category = metadata.get("tag_to_category", {})
    category_thresholds = metadata.get(
        "category_thresholds",
        {"artist": 0.1, "character": 0.2, "meta": 0.3, "style": 0.1}
    )
    default_threshold = 0.325
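    # Design note: metadata.json is re-read on every call; for repeated inference you
    # could load it once at module scope, next to the ONNX session.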

    # 5) Collect predictions by category
    results_by_category = {}
    num_tags = refined_probs.shape[1]

    for i in range(num_tags):
        prob = float(refined_probs[0, i])
        tag_name = idx_to_tag[str(i)]  # str(i) because metadata uses string keys
        category = tag_to_category.get(tag_name, "unknown")
        cat_threshold = category_thresholds.get(category, default_threshold)

        if prob >= cat_threshold:
            results_by_category.setdefault(category, []).append((tag_name, prob))

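    # Optional speed-up (sketch): any tag passing its per-category threshold also
    # passes the smallest configured threshold, so you could pre-filter with NumPy
    # and loop only over the surviving candidates:
    # min_thr = min([default_threshold, *category_thresholds.values()])
    # for i in np.where(refined_probs[0] >= min_thr)[0]:
    #     ...  # then apply the same per-category check as above
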
    # 6) Depending on output_format, produce different return strings
    if output_format == "as_prompt":
        # Flatten all predicted tags across categories
        all_predicted_tags = []
        for cat, tags_list in results_by_category.items():
            # We only need the tag name in as_prompt format
            for tname, tprob in tags_list:
                # convert underscores to spaces
                tag_name_spaces = tname.replace("_", " ")
                all_predicted_tags.append(tag_name_spaces)

        # Create a comma-separated string
        prompt_string = ", ".join(all_predicted_tags)
        return prompt_string

    else:  # "verbose"
        # We'll build a multiline string describing the predictions
        lines = []
        lines.append("Predicted Tags by Category:\n")
        for cat, tags_list in results_by_category.items():
            lines.append(f"Category: {cat} | Predicted {len(tags_list)} tags")
            # Sort descending by probability
            for tname, tprob in sorted(tags_list, key=lambda x: x[1], reverse=True):
                lines.append(f"  Tag: {tname:30s}  Prob: {tprob:.4f}")
            lines.append("")  # blank line after each category
        # Join lines with newlines
        verbose_output = "\n".join(lines)
        return verbose_output

if __name__ == "__main__":
    result = inference("path/to/image", output_format="as_prompt")
    print(result)
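
    # The verbose, per-category breakdown for the same image (hypothetical path):
    # print(inference("path/to/image", output_format="verbose"))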