Update num_classes in JTP2Processor to 18166 for alignment with state_dict
Browse files- caption/jtp2.py +1 -1
- verify_tags.py +5 -0
caption/jtp2.py
CHANGED
@@ -94,7 +94,7 @@ class JTP2Processor(BatchProcessor[Path, None]):
|
|
94 |
self.model = timm.create_model(
|
95 |
"vit_so400m_patch14_siglip_384.webli",
|
96 |
pretrained=False,
|
97 |
-
num_classes=
|
98 |
)
|
99 |
|
100 |
# Replace the model's head with the custom GatedHead
|
|
|
94 |
self.model = timm.create_model(
|
95 |
"vit_so400m_patch14_siglip_384.webli",
|
96 |
pretrained=False,
|
97 |
+
num_classes=18166 # Align with state_dict
|
98 |
)
|
99 |
|
100 |
# Replace the model's head with the custom GatedHead
|
verify_tags.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
|
3 |
+
with open("/home/kade/source/repos/JTP2/tags.json", "r", encoding="utf-8") as f:
|
4 |
+
tags = json.load(f)
|
5 |
+
print(f"Number of tags: {len(tags)}")
|