WCNegentropy committed (verified)
Commit b528828 · 1 Parent(s): 9f9758e

Remove nested directory: BitTransformerLM/bit_transformer/telemetry.py

BitTransformerLM/bit_transformer/telemetry.py DELETED
@@ -1,95 +0,0 @@
- import numpy as np
- from typing import Dict, List, TYPE_CHECKING
-
- import torch
- from sklearn.cluster import KMeans
-
- if TYPE_CHECKING:  # pragma: no cover
-     from .model import BitTransformerLM
-
-
- class TelemetrySynthesizer:
-     """Analyze telemetry batches and cluster activation patterns."""
-
-     def __init__(self, n_clusters: int = 2) -> None:
-         self.n_clusters = n_clusters
-
-     def _summary(self, telemetry: Dict[str, List[torch.Tensor]]) -> np.ndarray:
-         """Compute activation/attention summaries for a single telemetry dict."""
-         acts = telemetry["activations"]
-         attn = telemetry["attention_maps"]
-         summaries = []
-         for a, m in zip(acts, attn):
-             mean = a.mean().item()
-             var = a.var(unbiased=False).item()
-             prob = m.softmax(-1)
-             entropy = -(prob * prob.clamp_min(1e-9).log()).sum(-1).mean().item()
-             summaries.append([mean, var, entropy])
-         return np.array(summaries).ravel()
-
-     def synthesize(
-         self, telemetries: List[Dict[str, List[torch.Tensor]]], bit_seqs: torch.Tensor
-     ) -> Dict[str, List]:
-         """Cluster telemetry summaries and return cluster info."""
-         data = np.stack([self._summary(t) for t in telemetries])
-         km = KMeans(n_clusters=self.n_clusters, n_init=1)
-         labels = km.fit_predict(data)
-         representatives: List[List[int]] = []
-         for c in range(self.n_clusters):
-             idx = np.where(labels == c)[0]
-             if len(idx) > 0:
-                 representatives.append(bit_seqs[idx[0]].tolist())
-             else:
-                 representatives.append([])
-         return {"cluster_assignments": labels.tolist(), "representatives": representatives}
-
-     def cluster_sequences(
-         self, model: "BitTransformerLM", bit_seqs: torch.Tensor
-     ) -> List[List[int]]:
-         """Run the model to gather telemetry and return representative sequences.
-
-         Parameters
-         ----------
-         model: BitTransformerLM
-             Model used to compute telemetry for each sequence.
-         bit_seqs: torch.Tensor
-             Tensor containing one bit sequence per row.
-
-         Returns
-         -------
-         list[list[int]]
-             Representative sequences chosen from KMeans clusters.
-         """
-         telemetries: List[Dict[str, List[torch.Tensor]]] = []
-         with torch.no_grad():
-             for seq in bit_seqs:
-                 _, tele = model(seq.unsqueeze(0))
-                 telemetries.append(tele)
-         info = self.synthesize(telemetries, bit_seqs)
-         return info["representatives"]
-
-
- def detect_metric_drift(
-     metrics_log: Dict[str, List[float]],
-     window: int = 10,
-     threshold: float = 0.2,
- ) -> Dict[str, bool]:
-     """Detect metric drift between consecutive windows.
-
-     Args:
-         metrics_log: History of scalar metrics keyed by name.
-         window: Number of recent steps to compare.
-         threshold: Absolute difference required to flag drift.
-
-     Returns:
-         Dictionary mapping metric keys to a boolean drift indicator.
-     """
-     drift = {}
-     for key, values in metrics_log.items():
-         if len(values) < window * 2:
-             drift[key] = False
-             continue
-         recent = np.mean(values[-window:])
-         prev = np.mean(values[-2 * window : -window])
-         drift[key] = abs(recent - prev) > threshold
-     return drift
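
For context on what this file provided before removal, here is a minimal usage sketch of the deleted helpers. It is an illustration only: the import path assumes the remaining (non-nested) bit_transformer package, and the telemetry tensor shapes are made-up placeholders, since the deleted code only requires "activations" and "attention_maps" lists of tensors.

import torch
from bit_transformer.telemetry import TelemetrySynthesizer, detect_metric_drift  # assumed import path

# Hand-built telemetry dicts, one per bit sequence (shapes are illustrative assumptions).
telemetries = [
    {"activations": [torch.randn(1, 8, 16)],
     "attention_maps": [torch.randn(1, 4, 8, 8)]}
    for _ in range(4)
]
bit_seqs = torch.randint(0, 2, (4, 8))

synth = TelemetrySynthesizer(n_clusters=2)
info = synth.synthesize(telemetries, bit_seqs)
print(info["cluster_assignments"])  # e.g. [0, 1, 0, 1]
print(info["representatives"])      # one representative bit sequence per cluster

# Drift is flagged once the mean of the last `window` values moves more than
# `threshold` away from the mean of the window before it.
metrics_log = {"negentropy": [0.5] * 10 + [0.9] * 10}
print(detect_metric_drift(metrics_log, window=10, threshold=0.2))  # {'negentropy': True}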