CompactAI committed on
Commit
7835d3d
·
verified ·
1 Parent(s): 6b114c2

Delete classify.py

Browse files
Files changed (1) hide show
  1. classify.py +0 -248
classify.py DELETED
@@ -1,248 +0,0 @@
1
- """
2
- AIFinder Interactive Classifier
3
- Loads trained model and provides an interactive REPL for classifying text.
4
-
5
- Usage: python3 classify.py
6
- """
7
-
8
- import os
9
- import sys
10
- import time
11
- import joblib
12
- import numpy as np
13
- import torch
14
- import torch.nn as nn
15
-
16
- from config import MODEL_DIR, DATASET_REGISTRY, DEEPSEEK_AM_DATASETS
17
- from model import AIFinderNet
18
-
19
-
20
def load_models():
    """Load the feature pipeline, label encoder and network from ``MODEL_DIR``.

    Returns:
        A ``(pipeline, net, provider_enc, checkpoint, device)`` tuple. The
        network is moved to CUDA when available (CPU otherwise) and switched
        to eval mode before being returned.

    If a ``FileNotFoundError`` is raised while loading, a hint is printed and
    the process exits with status 1.
    """

    def artifact(filename):
        # Every artifact lives directly inside MODEL_DIR.
        return os.path.join(MODEL_DIR, filename)

    # Constructor arguments stored alongside the weights in the checkpoint.
    arch_keys = ("input_dim", "num_providers", "hidden_dim", "embed_dim", "dropout")

    try:
        pipeline = joblib.load(artifact("feature_pipeline.joblib"))
        provider_enc = joblib.load(artifact("provider_enc.joblib"))

        checkpoint = torch.load(
            artifact("classifier.pt"),
            map_location="cpu",
            weights_only=True,
        )

        if torch.cuda.is_available():
            device = torch.device("cuda")
        else:
            device = torch.device("cpu")

        net = AIFinderNet(**{key: checkpoint[key] for key in arch_keys})
        net = net.to(device)
        # strict=False: missing/unexpected state-dict keys are tolerated.
        net.load_state_dict(checkpoint["state_dict"], strict=False)
        net.eval()
    except FileNotFoundError:
        print(f"Error: Models not found in {MODEL_DIR}")
        print("Run 'python3 train.py' first to train the models.")
        sys.exit(1)

    return pipeline, net, provider_enc, checkpoint, device
47
-
48
-
49
def classify_text(text, pipeline, net, provider_enc, device, top_k=5):
    """Classify a single text and return the most likely providers.

    Args:
        text: The text to classify.
        pipeline: Object whose ``transform([text])`` returns a matrix that
            exposes ``toarray()``.
        net: Callable mapping the feature tensor to provider logits.
        provider_enc: Encoder whose ``inverse_transform`` maps class indices
            back to provider names.
        device: Torch device the feature tensor is moved to.
        top_k: How many of the highest-probability providers to report
            (default 5, the previously hard-coded value).

    Returns:
        A dict with ``provider`` (best name), ``provider_confidence`` (its
        probability in percent) and ``top_providers`` (list of
        ``(name, percent)`` pairs, best first).

    Raises:
        ValueError: If ``top_k`` is smaller than 1.
    """
    if top_k < 1:
        raise ValueError(f"top_k must be >= 1, got {top_k}")

    # perf_counter is monotonic, so the reported durations cannot go negative
    # or jump if the wall clock is adjusted mid-call.
    t0 = time.perf_counter()
    X = pipeline.transform([text])
    X_t = torch.tensor(X.toarray(), dtype=torch.float32).to(device)
    print(f" (featurize: {time.perf_counter() - t0:.2f}s)", end="")

    with torch.no_grad():
        prov_logits = net(X_t)

    # Probabilities for the single input row, as a NumPy vector.
    prov_proba = torch.softmax(prov_logits.float(), dim=1)[0].cpu().numpy()

    # Indices of the top_k probabilities, highest first.
    top_prov_idxs = np.argsort(prov_proba)[::-1][:top_k]
    top_providers = [
        (provider_enc.inverse_transform([i])[0], prov_proba[i] * 100)
        for i in top_prov_idxs
    ]

    elapsed = time.perf_counter() - t0
    print(f" (total classify: {elapsed:.2f}s)")

    best_name, best_conf = top_providers[0]
    return {
        "provider": best_name,
        "provider_confidence": best_conf,
        "top_providers": top_providers,
    }
76
-
77
-
78
def print_results(results):
    """Render the classification results as a boxed report on stdout.

    ``results`` must provide ``provider``, ``provider_confidence`` and
    ``top_providers`` (an iterable of ``(name, percent)`` pairs). Each pair is
    drawn with a 20-cell bar, one filled cell per 5 percentage points; a NaN
    percentage is shown as 0.0.
    """
    report = [
        "",
        " ┌───────────────────────────────────────────────┐",
        f" │ Provider: {results['provider']} ({results['provider_confidence']:.1f}%)",
    ]

    for label, pct in results["top_providers"]:
        shown = 0.0 if np.isnan(pct) else pct
        filled = int(shown / 5)
        gauge = "█" * filled + "░" * (20 - filled)
        report.append(f" │ {label:.<25s} {shown:5.1f}% {gauge}")

    report.append(" └───────────────────────────────────────────────┘")
    report.append("")
    print("\n".join(report))
92
-
93
-
94
def correct_provider(
    net,
    X_t,
    correct_provider_name,
    provider_enc,
    optimizer,
    device,
):
    """Run one gradient step that pushes ``net`` towards the correct provider.

    Args:
        net: The ``torch.nn.Module`` to update in place.
        X_t: Feature tensor for the example(s); its first dimension is the
            batch size.
        correct_provider_name: Label the example should have been given.
        provider_enc: Encoder whose ``transform`` maps a name to its class
            index and raises ``ValueError`` for unknown names.
        optimizer: Optimizer bound to ``net``'s parameters.
        device: Torch device the target tensor is moved to.

    Returns:
        ``True`` if the update was applied, ``False`` if the label is unknown
        to the encoder (nothing is changed in that case).
    """
    try:
        prov_idx = provider_enc.transform([correct_provider_name])[0]
    except ValueError as e:
        print(f" (label not in encoder: {e})")
        return False

    y_prov = torch.tensor([prov_idx], dtype=torch.long).to(device)

    was_training = net.training
    net.train()
    try:
        # Batch-norm layers are kept in eval mode when there is at most one
        # sample, so they use their running statistics instead of batch ones.
        if X_t.shape[0] <= 1:
            for module in net.modules():
                if isinstance(module, nn.modules.batchnorm._BatchNorm):
                    module.eval()

        optimizer.zero_grad(set_to_none=True)
        prov_criterion = nn.CrossEntropyLoss()

        prov_logits = net(X_t)
        loss = prov_criterion(prov_logits, y_prov)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)
        optimizer.step()
    finally:
        # Always put the network (including the batch-norm layers touched
        # above) back in the mode the caller had it in, even if the update
        # raised; otherwise later inference would run in training mode.
        net.train(was_training)

    print(f" ✓ Corrected → {correct_provider_name} (loss={loss.item():.4f})")
    return True
136
-
137
-
138
def prompt_correction(known_providers):
    """Ask the user which provider is the correct one.

    The providers are listed with 1-based numbers. The user may answer with a
    number, or with (part of) a provider name; names are matched
    case-insensitively, and an exact name match takes precedence over
    substring matches so that a name which is also a prefix of another
    provider can still be selected.

    Args:
        known_providers: Sequence of provider names offered for selection.

    Returns:
        The chosen provider name, or ``None`` if the user skipped (empty
        input or end of input) or the answer was invalid or ambiguous.
    """
    print(" Wrong? Enter correct provider number (or Enter to skip):")
    for i, name in enumerate(known_providers, 1):
        print(f" {i:>2d}. {name}")
    try:
        prov_choice = input(" Provider > ").strip()
    except EOFError:
        return None
    if not prov_choice:
        return None

    # Named `chosen` rather than `correct_provider` so the module-level
    # function of that name is not shadowed inside this function.
    chosen = None
    try:
        idx = int(prov_choice) - 1
    except ValueError:
        needle = prov_choice.lower()
        exact = [m for m in known_providers if m.lower() == needle]
        matches = exact or [m for m in known_providers if needle in m.lower()]
        if len(matches) == 1:
            chosen = matches[0]
    else:
        if 0 <= idx < len(known_providers):
            chosen = known_providers[idx]

    if not chosen:
        print(" (invalid choice, skipping)")
        return None

    return chosen
165
-
166
-
167
def _read_pasted_block():
    """Read one pasted block of text from stdin.

    Returns:
        ``(text, status)`` where ``status`` is one of:

        * ``"submit"`` - two consecutive blank lines ended the block;
        * ``"quit"`` - a line equal to ``quit`` (any case) was entered, in
          which case ``text`` is empty;
        * ``"eof"`` - stdin was closed; ``text`` holds what was read so far.
    """
    lines = []
    empty_count = 0
    while True:
        try:
            line = input()
        except EOFError:
            return "\n".join(lines).strip(), "eof"

        stripped = line.strip()
        if stripped == "":
            empty_count += 1
            if empty_count >= 2:
                return "\n".join(lines).strip(), "submit"
        else:
            empty_count = 0
            if stripped.lower() == "quit":
                return "", "quit"
        lines.append(line)


def _save_corrections(net, checkpoint, corrections_made):
    """Write the updated weights back to ``classifier.pt`` if any were made."""
    if corrections_made <= 0:
        return
    print(f" Saving {corrections_made} correction(s) to checkpoint...")
    checkpoint["state_dict"] = net.state_dict()
    torch.save(checkpoint, os.path.join(MODEL_DIR, "classifier.pt"))
    print(" ✓ Saved.")


def main():
    """Run the interactive classify-and-correct loop.

    Loads the models, then repeatedly reads a pasted text, prints the
    predicted providers and offers the user a chance to correct the
    prediction with a single online-learning step. The loop ends when the
    user types ``quit`` or stdin is closed; in both cases any corrections
    made during the session are saved back to the checkpoint first.
    """
    print()
    print(" ╔═══════════════════════════════════════╗")
    print(" ║ AIFinder - AI Response Classifier ║")
    print(" ╚═══════════════════════════════════════╝")
    print()

    print(" Loading models...")
    t0 = time.time()
    pipeline, net, provider_enc, checkpoint, device = load_models()
    print(f" Models loaded in {time.time() - t0:.1f}s.")

    # Prepare online learning components
    optimizer = torch.optim.AdamW(net.parameters(), lr=1e-4, weight_decay=1e-4)
    known_providers = sorted(provider_enc.classes_.tolist())
    corrections_made = 0

    print()
    print(" Paste text to classify (submit with TWO empty lines).")
    print(" Type 'quit' to exit.\n")

    # Once stdin is closed every further input() raises EOFError immediately,
    # so the loop must stop instead of prompting again (which would spin
    # forever and never save the corrections).
    stdin_closed = False
    while not stdin_closed:
        print(" ─── Paste text below ───")
        text, status = _read_pasted_block()
        if status == "quit":
            break
        stdin_closed = status == "eof"

        if not text:
            if not stdin_closed:
                print(" (empty input, try again)")
            continue

        if len(text) < 20:
            print(" (text too short, need at least 20 chars)")
            continue

        results = classify_text(text, pipeline, net, provider_enc, device)
        print_results(results)

        # classify_text only returns the predictions, so the features are
        # computed again here for the optional correction step.
        X = pipeline.transform([text])
        X_t = torch.tensor(X.toarray(), dtype=torch.float32).to(device)

        correct_prov = prompt_correction(known_providers)
        if correct_prov and correct_provider(
            net,
            X_t,
            correct_prov,
            provider_enc,
            optimizer,
            device,
        ):
            corrections_made += 1

    _save_corrections(net, checkpoint, corrections_made)
    print(" Goodbye!")
246
-
247
- if __name__ == "__main__":
248
- main()