Delete classify.py
Browse files- classify.py +0 -248
classify.py
DELETED
|
@@ -1,248 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
AIFinder Interactive Classifier
|
| 3 |
-
Loads trained model and provides an interactive REPL for classifying text.
|
| 4 |
-
|
| 5 |
-
Usage: python3 classify.py
|
| 6 |
-
"""
|
| 7 |
-
|
| 8 |
-
import os
|
| 9 |
-
import sys
|
| 10 |
-
import time
|
| 11 |
-
import joblib
|
| 12 |
-
import numpy as np
|
| 13 |
-
import torch
|
| 14 |
-
import torch.nn as nn
|
| 15 |
-
|
| 16 |
-
from config import MODEL_DIR, DATASET_REGISTRY, DEEPSEEK_AM_DATASETS
|
| 17 |
-
from model import AIFinderNet
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
def load_models():
    """Load all model components from the model directory.

    Returns:
        tuple: (pipeline, net, provider_enc, checkpoint, device) where
        `pipeline` is the fitted feature pipeline, `net` is the trained
        AIFinderNet in eval mode on `device`, `provider_enc` is the label
        encoder, and `checkpoint` is the raw checkpoint dict (kept so the
        caller can re-save it after online corrections).

    Exits the process with status 1 when any artifact is missing.
    """
    try:
        pipeline = joblib.load(os.path.join(MODEL_DIR, "feature_pipeline.joblib"))
        provider_enc = joblib.load(os.path.join(MODEL_DIR, "provider_enc.joblib"))

        # weights_only=True restricts unpickling to tensors/primitives,
        # which is the safe way to load a checkpoint.
        checkpoint = torch.load(
            os.path.join(MODEL_DIR, "classifier.pt"),
            map_location="cpu",
            weights_only=True,
        )
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        net = AIFinderNet(
            input_dim=checkpoint["input_dim"],
            num_providers=checkpoint["num_providers"],
            hidden_dim=checkpoint["hidden_dim"],
            embed_dim=checkpoint["embed_dim"],
            dropout=checkpoint["dropout"],
        ).to(device)
        # strict=False tolerates missing/unexpected keys between
        # training runs — NOTE(review): silently skips mismatched layers.
        net.load_state_dict(checkpoint["state_dict"], strict=False)
        net.eval()

        return pipeline, net, provider_enc, checkpoint, device
    except FileNotFoundError:
        print(f"Error: Models not found in {MODEL_DIR}")
        # Fix: removed needless f-prefix (literal has no placeholders).
        print("Run 'python3 train.py' first to train the models.")
        sys.exit(1)
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
def classify_text(text, pipeline, net, provider_enc, device):
    """Classify a single text and return provider results.

    Returns a dict with the winning provider name, its confidence as a
    percentage, and the top-5 (name, confidence) pairs.
    """
    started = time.time()
    features = pipeline.transform([text])
    batch = torch.tensor(features.toarray(), dtype=torch.float32).to(device)
    print(f" (featurize: {time.time() - started:.2f}s)", end="")

    with torch.no_grad():
        logits = net(batch)

    probs = torch.softmax(logits.float(), dim=1)[0].cpu().numpy()

    # Rank providers by descending probability; keep at most five.
    ranked = np.argsort(probs)[::-1][:5]
    top_providers = []
    for idx in ranked:
        name = provider_enc.inverse_transform([idx])[0]
        top_providers.append((name, probs[idx] * 100))

    elapsed = time.time() - started
    print(f" (total classify: {elapsed:.2f}s)")

    best_name, best_conf = top_providers[0]
    return {
        "provider": best_name,
        "provider_confidence": best_conf,
        "top_providers": top_providers,
    }
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
def print_results(results):
    """Render classification results as a boxed ASCII report."""
    print()
    print(" ┌───────────────────────────────────────────────┐")
    print(
        f" │ Provider: {results['provider']} ({results['provider_confidence']:.1f}%)"
    )
    for provider_name, confidence in results["top_providers"]:
        # Guard against NaN so the bar math never blows up.
        safe = 0.0 if np.isnan(confidence) else confidence
        filled = int(safe / 5)
        gauge = "█" * filled + "░" * (20 - filled)
        print(f" │ {provider_name:.<25s} {safe:5.1f}% {gauge}")

    print(" └───────────────────────────────────────────────┘")
    print()
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
def correct_provider(
    net,
    X_t,
    correct_provider_name,
    provider_enc,
    optimizer,
    device,
):
    """Do a backward pass to correct the provider on a single example.

    Returns True when the gradient step was applied, False when the
    label is unknown to the encoder.
    """
    try:
        target_idx = provider_enc.transform([correct_provider_name])[0]
    except ValueError as e:
        print(f" (label not in encoder: {e})")
        return False

    target = torch.tensor([target_idx], dtype=torch.long).to(device)

    prior_mode = net.training
    net.train()

    # BatchNorm statistics are meaningless for a lone sample — freeze
    # those layers while leaving the rest in train mode.
    if X_t.shape[0] <= 1:
        for layer in net.modules():
            if isinstance(layer, nn.modules.batchnorm._BatchNorm):
                layer.eval()

    optimizer.zero_grad(set_to_none=True)
    criterion = nn.CrossEntropyLoss()

    logits = net(X_t)
    loss = criterion(logits, target)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)
    optimizer.step()

    # Restore whichever mode the caller had set (train() also
    # un-freezes the BatchNorm layers above).
    if prior_mode:
        net.train()
    else:
        net.eval()

    print(f" ✓ Corrected → {correct_provider_name} (loss={loss.item():.4f})")
    return True
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
def prompt_correction(known_providers):
    """Ask user for the correct provider.

    Accepts either a 1-based menu number or a unique case-insensitive
    substring of a provider name. Returns the chosen name, or None when
    the user skips, hits EOF, or gives an ambiguous/invalid answer.
    """
    print(" Wrong? Enter correct provider number (or Enter to skip):")
    for number, name in enumerate(known_providers, 1):
        print(f" {number:>2d}. {name}")
    try:
        raw = input(" Provider > ").strip()
    except EOFError:
        return None
    if not raw:
        return None

    chosen = None
    try:
        position = int(raw) - 1
    except ValueError:
        # Not a number: fall back to a unique substring match.
        matches = [m for m in known_providers if raw.lower() in m.lower()]
        if len(matches) == 1:
            chosen = matches[0]
    else:
        if 0 <= position < len(known_providers):
            chosen = known_providers[position]

    if not chosen:
        print(" (invalid choice, skipping)")
        return None

    return chosen
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
def main():
    """Interactive REPL: classify pasted text and learn from corrections."""
    print()
    print(" ╔═══════════════════════════════════════╗")
    print(" ║ AIFinder - AI Response Classifier ║")
    print(" ╚═══════════════════════════════════════╝")
    print()

    print(" Loading models...")
    t0 = time.time()
    pipeline, net, provider_enc, checkpoint, device = load_models()
    print(f" Models loaded in {time.time() - t0:.1f}s.")

    # Prepare online learning components: one shared optimizer so that
    # repeated corrections accumulate across the session.
    optimizer = torch.optim.AdamW(net.parameters(), lr=1e-4, weight_decay=1e-4)
    known_providers = sorted(provider_enc.classes_.tolist())
    corrections_made = 0

    print()
    print(" Paste text to classify (submit with TWO empty lines).")
    print(" Type 'quit' to exit.\n")

    last_X_t = None

    while True:
        print(" ─── Paste text below ───")
        # Collect a multi-line paste; two consecutive blank lines (or EOF)
        # end the submission. Single blank lines are kept in the text.
        lines = []
        empty_count = 0
        while True:
            try:
                line = input()
            except EOFError:
                break
            if line.strip() == "":
                empty_count += 1
                if empty_count >= 2:
                    break
                lines.append(line)
            else:
                empty_count = 0
                if line.strip().lower() == "quit":
                    # Persist any online corrections back into the
                    # checkpoint before exiting.
                    if corrections_made > 0:
                        print(
                            f" Saving {corrections_made} correction(s) to checkpoint..."
                        )
                        checkpoint["state_dict"] = net.state_dict()
                        torch.save(checkpoint, os.path.join(MODEL_DIR, "classifier.pt"))
                        print(" ✓ Saved.")
                    print(" Goodbye!")
                    return
                lines.append(line)

        text = "\n".join(lines).strip()
        if not text:
            print(" (empty input, try again)")
            continue

        # Very short snippets carry too little signal to classify.
        if len(text) < 20:
            print(" (text too short, need at least 20 chars)")
            continue

        results = classify_text(text, pipeline, net, provider_enc, device)
        print_results(results)

        # Re-featurize for the potential correction step.
        # NOTE(review): this duplicates the transform done inside
        # classify_text — presumably acceptable; confirm cost.
        X = pipeline.transform([text])
        last_X_t = torch.tensor(X.toarray(), dtype=torch.float32).to(device)

        # Offer the user a chance to correct the prediction; a valid
        # choice triggers one gradient step on this example.
        correct_prov = prompt_correction(known_providers)
        if correct_prov:
            ok = correct_provider(
                net,
                last_X_t,
                correct_prov,
                provider_enc,
                optimizer,
                device,
            )
            if ok:
                corrections_made += 1
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
# Script entry point: run the interactive classifier REPL.
if __name__ == "__main__":
    main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|