bhardwaj08sarthak committed · verified
Commit 52bcee7 · 1 Parent(s): bfc2469

Upload level_classifier_tool.py

Files changed (1): level_classifier_tool.py (+278, -0)
level_classifier_tool.py ADDED
@@ -0,0 +1,278 @@
# level_classifier_tool.py
"""
A lightweight utility for classifying a question against Bloom's and DOK levels
by comparing its embedding to curated "anchor phrases" for each level.

Main entry point:
    classify_levels_phrases(question, blooms_phrases, dok_phrases, ...)

Author: Prepared by ChatGPT
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, List, Tuple, Iterable, Optional, Literal, Callable, Any
import math
import os

# Optional heavy deps (torch, transformers) are imported lazily when needed.
_TOK = None
_MODEL = None
_TORCH = None

Agg = Literal["mean", "max", "topk_mean"]


# --------------------------- Embedding backend ---------------------------

@dataclass
class HFEmbeddingBackend:
    """
    Minimal Hugging Face transformers encoder for sentence-level embeddings.
    Uses mean pooling over last_hidden_state and L2-normalizes the result.
    """
    model_name: str = "sentence-transformers/all-MiniLM-L6-v2"
    device: Optional[str] = None  # "cuda" | "cpu" | None -> auto-detect

    def _lazy_import(self) -> None:
        global _TOK, _MODEL, _TORCH
        if _TORCH is None:
            import torch as _torch
            _TORCH = _torch
        if _TOK is None or _MODEL is None:
            from transformers import AutoTokenizer, AutoModel  # type: ignore
            _TOK = AutoTokenizer.from_pretrained(self.model_name)
            _MODEL = AutoModel.from_pretrained(self.model_name)
        dev = self.device or ("cuda" if _TORCH.cuda.is_available() else "cpu")
        _MODEL.to(dev).eval()
        self.device = dev

    def encode(self, texts: Iterable[str], batch_size: int = 32) -> "Tuple[Any, List[str]]":
        """
        Returns (embeddings, texts_list). Embeddings have shape [N, D] and are unit-normalized.
        """
        self._lazy_import()
        torch = _TORCH  # local alias
        texts_list = list(texts)
        if not texts_list:
            return torch.empty((0, _MODEL.config.hidden_size)), []  # type: ignore

        all_out = []
        with torch.inference_mode():
            for i in range(0, len(texts_list), batch_size):
                batch = texts_list[i:i + batch_size]
                enc = _TOK(batch, padding=True, truncation=True, return_tensors="pt").to(self.device)  # type: ignore
                out = _MODEL(**enc)
                last = out.last_hidden_state  # [B, T, H]
                mask = enc["attention_mask"].unsqueeze(-1)  # [B, T, 1]
                # Mean-pool over non-padding tokens.
                summed = (last * mask).sum(dim=1)
                counts = mask.sum(dim=1).clamp(min=1)
                pooled = summed / counts
                # L2-normalize so dot products are cosine similarities.
                pooled = pooled / pooled.norm(dim=1, keepdim=True).clamp(min=1e-12)
                all_out.append(pooled.cpu())
        embs = torch.cat(all_out, dim=0) if all_out else torch.empty((0, _MODEL.config.hidden_size))  # type: ignore
        return embs, texts_list
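
# Usage sketch for the backend (illustrative; assumes `torch` and `transformers`
# are installed and the checkpoint can be downloaded):
#
#     be = HFEmbeddingBackend()
#     embs, texts = be.encode(["Define photosynthesis.", "Critique this argument."])
#     # embs has shape [2, 384] for all-MiniLM-L6-v2; rows are unit vectors, so
#     # float(embs[0] @ embs[1]) is the cosine similarity between the two texts.
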
# --------------------------- Utilities ---------------------------

def _normalize_whitespace(s: str) -> str:
    return " ".join(s.strip().split())


def _default_preprocess(s: str) -> str:
    # Keep preprocessing simple and deterministic. Users can override with a custom callable.
    return _normalize_whitespace(s)


@dataclass
class PhraseIndex:
    phrases_by_level: Dict[str, List[str]]
    embeddings_by_level: Dict[str, Any]  # torch.Tensor; typed Any to avoid a hard torch dep at import time
    model_name: str


def build_phrase_index(
    backend: HFEmbeddingBackend,
    phrases_by_level: Dict[str, Iterable[str]],
) -> PhraseIndex:
    """
    Pre-encode all anchor phrases per level into a searchable index.
    """
    # Flatten texts while preserving level boundaries.
    cleaned: Dict[str, List[str]] = {
        lvl: [_default_preprocess(p) for p in phrases]
        for lvl, phrases in phrases_by_level.items()
    }
    all_texts: List[str] = []
    spans: List[Tuple[str, int, int]] = []  # (level, start, end) in the flat list
    cur = 0
    for lvl, plist in cleaned.items():
        start = cur
        all_texts.extend(plist)
        cur += len(plist)
        spans.append((lvl, start, cur))

    embs, _ = backend.encode(all_texts)
    # Slice the stacked embeddings back into per-level buckets.
    torch = _TORCH
    embeddings_by_level: Dict[str, Any] = {}
    for lvl, start, end in spans:
        embeddings_by_level[lvl] = embs[start:end] if end > start else torch.empty((0, embs.shape[1]))  # type: ignore

    return PhraseIndex(
        phrases_by_level={lvl: list(pl) for lvl, pl in cleaned.items()},
        embeddings_by_level=embeddings_by_level,
        model_name=backend.model_name,
    )
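
# Index-building sketch (the level names and anchor phrases here are illustrative
# placeholders, not a canonical Bloom's phrase bank):
#
#     be = HFEmbeddingBackend()
#     bloom_index = build_phrase_index(be, {
#         "Remember": ["define the term", "list the steps"],
#         "Analyze": ["compare and contrast", "identify the assumptions"],
#     })
#     # Pass it as prebuilt_bloom_index=... to classify_levels_phrases() to avoid
#     # re-encoding the same anchors on every call.
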
def _aggregate_sims(sims: Any, agg: Agg, topk: int) -> float:
    """
    Aggregate a 1D tensor of similarities into a single score.
    """
    torch = _TORCH
    if sims.numel() == 0:
        return float("nan")
    if agg == "mean":
        return float(sims.mean().item())
    if agg == "max":
        return float(sims.max().item())
    if agg == "topk_mean":
        k = min(topk, sims.numel())
        topk_vals, _ = torch.topk(sims, k)
        return float(topk_vals.mean().item())
    raise ValueError(f"Unknown agg: {agg}")
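
# Worked example: for sims = [0.9, 0.4, 0.1],
#   agg="mean"               -> (0.9 + 0.4 + 0.1) / 3 ≈ 0.467
#   agg="max"                -> 0.9
#   agg="topk_mean", topk=2  -> (0.9 + 0.4) / 2 = 0.65
# "max" rewards a single strong anchor match, "topk_mean" averages the best few,
# and "mean" is diluted by weak anchors when phrase lists are large.
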
# --------------------------- Public API ---------------------------

def classify_levels_phrases(
    question: str,
    blooms_phrases: Dict[str, Iterable[str]],
    dok_phrases: Dict[str, Iterable[str]],
    *,
    model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
    agg: Agg = "max",
    topk: int = 5,
    preprocess: Optional[Callable[[str], str]] = None,
    backend: Optional[HFEmbeddingBackend] = None,
    prebuilt_bloom_index: Optional[PhraseIndex] = None,
    prebuilt_dok_index: Optional[PhraseIndex] = None,
    return_phrase_matches: bool = True,
) -> Dict[str, Any]:
    """
    Score a question against Bloom's taxonomy and DOK (Depth of Knowledge)
    using cosine similarity to level-specific anchor phrases.

    Parameters
    ----------
    question : str
        The input question or prompt.
    blooms_phrases : dict[str, Iterable[str]]
        Mapping of level -> anchor phrases for Bloom's taxonomy.
    dok_phrases : dict[str, Iterable[str]]
        Mapping of level -> anchor phrases for DOK.
    model_name : str
        Hugging Face model name for text embeddings. Ignored when `backend` is provided.
    agg : {"mean", "max", "topk_mean"}
        Aggregation over phrase similarities within a level.
    topk : int
        Used only when `agg="topk_mean"`.
    preprocess : Optional[Callable[[str], str]]
        Preprocessing function for the question string. Defaults to whitespace normalization.
    backend : Optional[HFEmbeddingBackend]
        Injected embedding backend. If not given, one is constructed.
    prebuilt_bloom_index, prebuilt_dok_index : Optional[PhraseIndex]
        If provided, reuse precomputed phrase embeddings to avoid re-encoding.
    return_phrase_matches : bool
        If True, include the top contributing phrases per level in the result.

    Returns
    -------
    dict
        {
            "question": ...,
            "model_name": ...,
            "blooms": {
                "scores": {level: float, ...},
                "best_level": str,
                "best_score": float,
                "top_phrases": {level: [(phrase, sim), ...], ...},  # only if return_phrase_matches
            },
            "dok": {
                "scores": {level: float, ...},
                "best_level": str,
                "best_score": float,
                "top_phrases": {level: [(phrase, sim), ...], ...},  # only if return_phrase_matches
            },
            "config": {"agg": agg, "topk": topk if agg == "topk_mean" else None},
        }
    """
    preprocess = preprocess or _default_preprocess
    question_clean = preprocess(question)

    # Prepare the embedding backend.
    be = backend or HFEmbeddingBackend(model_name=model_name)

    # Build or reuse the per-level phrase indices.
    bloom_index = prebuilt_bloom_index or build_phrase_index(be, blooms_phrases)
    dok_index = prebuilt_dok_index or build_phrase_index(be, dok_phrases)

    # Encode the question once and compare it to every phrase bucket.
    q_emb, _ = be.encode([question_clean])
    q_emb = q_emb[0:1]  # [1, D]
    torch = _TORCH

    def _score_block(index: PhraseIndex) -> Tuple[Dict[str, float], Dict[str, List[Tuple[str, float]]]]:
        scores: Dict[str, float] = {}
        top_contribs: Dict[str, List[Tuple[str, float]]] = {}

        for lvl, phrases in index.phrases_by_level.items():
            embs = index.embeddings_by_level[lvl]  # [N, D]
            if embs.numel() == 0:
                scores[lvl] = float("nan")
                top_contribs[lvl] = []
                continue
            sims = (q_emb @ embs.T).squeeze(0)  # cosine similarity, since both sides are L2-normalized
            scores[lvl] = _aggregate_sims(sims, agg, topk)
            if return_phrase_matches:
                k = min(5, sims.numel())
                vals, idxs = torch.topk(sims, k)
                top_contribs[lvl] = [(phrases[int(i)], float(v.item())) for v, i in zip(vals, idxs)]
        return scores, top_contribs

    bloom_scores, bloom_top = _score_block(bloom_index)
    dok_scores, dok_top = _score_block(dok_index)

    def _best(scores: Dict[str, float]) -> Tuple[str, float]:
        # Argmax over levels, skipping NaN scores (levels with empty phrase lists).
        best_lvl, best_val = None, -float("inf")
        for lvl, val in scores.items():
            if isinstance(val, float) and not math.isnan(val) and val > best_val:
                best_lvl, best_val = lvl, val
        return best_lvl or "", best_val

    best_bloom, best_bloom_val = _best(bloom_scores)
    best_dok, best_dok_val = _best(dok_scores)

    return {
        "question": question_clean,
        "model_name": be.model_name,
        "blooms": {
            "scores": bloom_scores,
            "best_level": best_bloom,
            "best_score": best_bloom_val,
            "top_phrases": bloom_top if return_phrase_matches else None,
        },
        "dok": {
            "scores": dok_scores,
            "best_level": best_dok,
            "best_score": best_dok_val,
            "top_phrases": dok_top if return_phrase_matches else None,
        },
        "config": {
            "agg": agg,
            "topk": topk if agg == "topk_mean" else None,
        },
    }
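

# ------------------------------ Demo ------------------------------
# Minimal end-to-end sketch; assumes `torch` and `transformers` are installed.
# The anchor phrases below are illustrative placeholders, not a vetted phrase bank.
if __name__ == "__main__":
    blooms = {
        "Remember": ["define the term", "list the main steps", "recall the fact"],
        "Understand": ["explain in your own words", "summarize the passage"],
        "Analyze": ["compare and contrast", "identify the underlying assumptions"],
    }
    dok = {
        "DOK1": ["state the definition", "identify the correct answer"],
        "DOK2": ["classify these examples", "organize the data into categories"],
        "DOK3": ["justify your reasoning with evidence", "develop a logical argument"],
    }
    result = classify_levels_phrases(
        "Compare the causes of World War I and World War II.",
        blooms,
        dok,
        agg="topk_mean",
        topk=2,
    )
    print("Bloom's:", result["blooms"]["best_level"], round(result["blooms"]["best_score"], 3))
    print("DOK:", result["dok"]["best_level"], round(result["dok"]["best_score"], 3))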