Samuel Stevens commited on
Commit
5db6fa7
·
1 Parent(s): 107d6a0

Update colors

Browse files
Files changed (3) hide show
  1. app.py +14 -23
  2. data.py +49 -3
  3. requirements.txt +29 -296
app.py CHANGED
@@ -10,6 +10,7 @@ import beartype
10
  import einops
11
  import einops.layers.torch
12
  import gradio as gr
 
13
  import numpy as np
14
  import saev.activations
15
  import saev.config
@@ -31,7 +32,7 @@ logger = logging.getLogger("app.py")
31
  ####################
32
 
33
 
34
- MAX_FREQ = 1e-2
35
  """Maximum frequency. Any feature that fires more than this is ignored."""
36
 
37
  RESIZE_SIZE = 512
@@ -46,12 +47,14 @@ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
46
  CWD = pathlib.Path(".")
47
  """Current working directory."""
48
 
49
- N_SAE_LATENTS = 2
50
  """Number of SAE latents to show."""
51
 
52
  N_LATENT_EXAMPLES = 4
53
  """Number of examples per SAE latent to show."""
54
 
 
 
55
 
56
  @beartype.beartype
57
  class Example(typing.TypedDict):
@@ -175,8 +178,11 @@ def add_highlights(
175
  overlay = Image.new("RGBA", img.size, (0, 0, 0, 0))
176
  draw = ImageDraw.Draw(overlay)
177
 
 
 
 
178
  # Using semi-transparent red (255, 0, 0, alpha)
179
- for p, val in enumerate(patches):
180
  assert upper is not None
181
  val /= upper + 1e-9
182
  x_np, y_np = p % iw_np, p // ih_np
@@ -185,28 +191,13 @@ def add_highlights(
185
  (x_np * pw_px, y_np * ph_px),
186
  (x_np * pw_px + pw_px, y_np * ph_px + ph_px),
187
  ],
188
- fill=(int(val * 256), 0, 0, int(opacity * val * 256)),
189
  )
190
 
191
  # Composite the original image and the overlay
192
  return Image.alpha_composite(img.convert("RGBA"), overlay)
193
 
194
 
195
- @jaxtyped(typechecker=beartype.beartype)
196
- @torch.inference_mode
197
- def upsample(
198
- x_WH: Int[Tensor, "width_ps height_ps"],
199
- ) -> UInt8[Tensor, "width_px height_px"]:
200
- return (
201
- torch.nn.functional.interpolate(
202
- x_WH.view((1, 1, 16, 16)).float(),
203
- scale_factor=28,
204
- )
205
- .view((448, 448))
206
- .type(torch.uint8)
207
- )
208
-
209
-
210
  #######################
211
  # Inference Functions #
212
  #######################
@@ -317,10 +308,10 @@ def get_orig_preds(img: Image.Image) -> Example:
317
  clf = load_clf()
318
  logits_WHC = clf(x_WHD)
319
 
320
- pred_WH = logits_WHC.argmax(axis=-1)
321
  return {
322
  "orig_url": data.img_to_base64(data.to_sized(img)),
323
- "seg_url": data.img_to_base64(data.u8_to_img(upsample(pred_WH))),
324
  "classes": data.to_classes(pred_WH),
325
  }
326
 
@@ -384,11 +375,11 @@ def get_mod_preds(img: Image.Image, latents: dict[str, int | float]) -> Example:
384
  mod_WHD = einops.rearrange(mod_BPD, "() (w h) dim -> w h dim", w=16, h=16)
385
 
386
  logits_WHC = clf(mod_WHD)
387
- pred_WH = logits_WHC.argmax(axis=-1)
388
  # pred_WH = einops.rearrange(pred_P, "(w h) -> w h", w=16, h=16)
389
  return {
390
  "orig_url": data.img_to_base64(data.to_sized(img)),
391
- "seg_url": data.img_to_base64(data.u8_to_img(upsample(pred_WH))),
392
  "classes": data.to_classes(pred_WH),
393
  }
394
 
 
10
  import einops
11
  import einops.layers.torch
12
  import gradio as gr
13
+ import matplotlib
14
  import numpy as np
15
  import saev.activations
16
  import saev.config
 
32
  ####################
33
 
34
 
35
+ MAX_FREQ = 3e-2
36
  """Maximum frequency. Any feature that fires more than this is ignored."""
37
 
38
  RESIZE_SIZE = 512
 
47
  CWD = pathlib.Path(".")
48
  """Current working directory."""
49
 
50
+ N_SAE_LATENTS = 4
51
  """Number of SAE latents to show."""
52
 
53
  N_LATENT_EXAMPLES = 4
54
  """Number of examples per SAE latent to show."""
55
 
56
+ COLORMAP = matplotlib.colormaps.get_cmap("plasma")
57
+
58
 
59
  @beartype.beartype
60
  class Example(typing.TypedDict):
 
178
  overlay = Image.new("RGBA", img.size, (0, 0, 0, 0))
179
  draw = ImageDraw.Draw(overlay)
180
 
181
+ colors = np.zeros((len(patches), 3), dtype=np.uint8)
182
+ colors[:, 0] = ((patches / (upper + 1e-9)) * 255).astype(np.uint8)
183
+
184
  # Using semi-transparent red (255, 0, 0, alpha)
185
+ for p, (val, color) in enumerate(zip(patches, colors)):
186
  assert upper is not None
187
  val /= upper + 1e-9
188
  x_np, y_np = p % iw_np, p // ih_np
 
191
  (x_np * pw_px, y_np * ph_px),
192
  (x_np * pw_px + pw_px, y_np * ph_px + ph_px),
193
  ],
194
+ fill=(*color, int(opacity * val * 255)),
195
  )
196
 
197
  # Composite the original image and the overlay
198
  return Image.alpha_composite(img.convert("RGBA"), overlay)
199
 
200
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
  #######################
202
  # Inference Functions #
203
  #######################
 
308
  clf = load_clf()
309
  logits_WHC = clf(x_WHD)
310
 
311
+ pred_WH = logits_WHC[:, :, 1:].argmax(axis=-1) + 1
312
  return {
313
  "orig_url": data.img_to_base64(data.to_sized(img)),
314
+ "seg_url": data.img_to_base64(data.u8_to_overlay(pred_WH, img)),
315
  "classes": data.to_classes(pred_WH),
316
  }
317
 
 
375
  mod_WHD = einops.rearrange(mod_BPD, "() (w h) dim -> w h dim", w=16, h=16)
376
 
377
  logits_WHC = clf(mod_WHD)
378
+ pred_WH = logits_WHC[:, :, 1:].argmax(axis=-1) + 1
379
  # pred_WH = einops.rearrange(pred_P, "(w h) -> w h", w=16, h=16)
380
  return {
381
  "orig_url": data.img_to_base64(data.to_sized(img)),
382
+ "seg_url": data.img_to_base64(data.u8_to_overlay(pred_WH, img)),
383
  "classes": data.to_classes(pred_WH),
384
  }
385
 
data.py CHANGED
@@ -8,8 +8,9 @@ import beartype
8
  import einops.layers.torch
9
  import numpy as np
10
  import requests
11
- from jaxtyping import Integer, UInt8, jaxtyped
12
- from PIL import Image
 
13
  from torch import Tensor
14
  from torchvision.transforms import v2
15
 
@@ -50,6 +51,7 @@ def make_colors() -> UInt8[np.ndarray, "n 3"]:
50
 
51
  # Fixed colors. Must be synced with Segmentation.elm.
52
  colors[2] = np.array([201, 249, 255], dtype=np.uint8)
 
53
  colors[4] = np.array([151, 204, 4], dtype=np.uint8)
54
  colors[13] = np.array([104, 139, 88], dtype=np.uint8)
55
  colors[16] = np.array([54, 48, 32], dtype=np.uint8)
@@ -89,6 +91,50 @@ def to_u8(seg_raw: Image.Image) -> UInt8[Tensor, "width height"]:
89
  return u8_transform(seg_raw)
90
 
91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  @jaxtyped(typechecker=beartype.beartype)
93
  def u8_to_img(map: UInt8[Tensor, "width height"]) -> Image.Image:
94
  map = map.cpu().numpy()
@@ -109,7 +155,7 @@ def to_classes(map: Integer[Tensor, "width height"]) -> list[int]:
109
  @beartype.beartype
110
  def img_to_base64(img: Image.Image) -> str:
111
  buf = io.BytesIO()
112
- img.save(buf, format="webp")
113
  b64 = base64.b64encode(buf.getvalue())
114
  s64 = b64.decode("utf8")
115
  return "data:image/webp;base64," + s64
 
8
  import einops.layers.torch
9
  import numpy as np
10
  import requests
11
+ import torch
12
+ from jaxtyping import Int, Integer, UInt8, jaxtyped
13
+ from PIL import Image, ImageDraw
14
  from torch import Tensor
15
  from torchvision.transforms import v2
16
 
 
51
 
52
  # Fixed colors. Must be synced with Segmentation.elm.
53
  colors[2] = np.array([201, 249, 255], dtype=np.uint8)
54
+ colors[2] = np.array([201, 249, 255], dtype=np.uint8)
55
  colors[4] = np.array([151, 204, 4], dtype=np.uint8)
56
  colors[13] = np.array([104, 139, 88], dtype=np.uint8)
57
  colors[16] = np.array([54, 48, 32], dtype=np.uint8)
 
91
  return u8_transform(seg_raw)
92
 
93
 
94
+ @jaxtyped(typechecker=beartype.beartype)
95
+ def upsample(
96
+ x_WH: Int[Tensor, "width_ps height_ps"],
97
+ ) -> UInt8[Tensor, "width_px height_px"]:
98
+ return (
99
+ torch.nn.functional.interpolate(
100
+ x_WH.view((1, 1, 16, 16)).float(),
101
+ scale_factor=28,
102
+ )
103
+ .view((448, 448))
104
+ .type(torch.uint8)
105
+ )
106
+
107
+
108
+ @jaxtyped(typechecker=beartype.beartype)
109
+ def u8_to_overlay(
110
+ map: Integer[Tensor, "width_ps height_ps"],
111
+ img: Image.Image,
112
+ *,
113
+ opacity: float = 0.5,
114
+ ) -> Image.Image:
115
+ iw_np, ih_np = map.shape
116
+ iw_px, ih_px = img.size
117
+ pw_px, ph_px = iw_px // iw_np, ih_px // ih_np
118
+
119
+ # Create a transparent overlay
120
+ overlay = Image.new("RGBA", img.size, (0, 0, 0, 0))
121
+ draw = ImageDraw.Draw(overlay)
122
+
123
+ # Using semi-transparent red (255, 0, 0, alpha)
124
+ for p, i in enumerate(map.view(-1).tolist()):
125
+ x_np, y_np = p % iw_np, p // ih_np
126
+ draw.rectangle(
127
+ [
128
+ (x_np * pw_px, y_np * ph_px),
129
+ (x_np * pw_px + pw_px, y_np * ph_px + ph_px),
130
+ ],
131
+ fill=(*colors[i - 1], int(opacity * 256)),
132
+ )
133
+
134
+ # Composite the original image and the overlay
135
+ return Image.alpha_composite(img.convert("RGBA"), overlay)
136
+
137
+
138
  @jaxtyped(typechecker=beartype.beartype)
139
  def u8_to_img(map: UInt8[Tensor, "width height"]) -> Image.Image:
140
  map = map.cpu().numpy()
 
155
  @beartype.beartype
156
  def img_to_base64(img: Image.Image) -> str:
157
  buf = io.BytesIO()
158
+ img.save(buf, format="webp", lossless=True)
159
  b64 = base64.b64encode(buf.getvalue())
160
  s64 = b64.decode("utf8")
161
  return "data:image/webp;base64," + s64
requirements.txt CHANGED
@@ -2,9 +2,9 @@
2
  # uv pip compile pyproject.toml
3
  aiofiles==23.2.1
4
  # via gradio
5
- aiohappyeyeballs==2.4.4
6
  # via aiohttp
7
- aiohttp==3.11.11
8
  # via
9
  # datasets
10
  # fsspec
@@ -18,46 +18,22 @@ anyio==4.8.0
18
  # via
19
  # gradio
20
  # httpx
21
- # jupyter-server
22
  # pycrdt
23
  # starlette
24
- argon2-cffi==23.1.0
25
- # via jupyter-server
26
- argon2-cffi-bindings==21.2.0
27
- # via argon2-cffi
28
- arrow==1.3.0
29
- # via isoduration
30
- asttokens==3.0.0
31
- # via stack-data
32
- async-lru==2.0.4
33
- # via jupyterlab
34
  attrs==25.1.0
35
  # via
36
  # aiohttp
37
  # jsonschema
38
  # referencing
39
- babel==2.17.0
40
- # via jupyterlab-server
41
  beartype==0.19.0
42
  # via
43
  # saev-semantic-segmentation (pyproject.toml)
44
  # saev
45
- beautifulsoup4==4.13.1
46
- # via nbconvert
47
- bleach==6.2.0
48
- # via nbconvert
49
- braceexpand==0.1.7
50
- # via webdataset
51
  certifi==2025.1.31
52
  # via
53
  # httpcore
54
  # httpx
55
  # requests
56
- # sentry-sdk
57
- cffi==1.17.1
58
- # via
59
- # argon2-cffi-bindings
60
- # pyvips
61
  charset-normalizer==3.4.1
62
  # via requests
63
  click==8.1.8
@@ -65,43 +41,28 @@ click==8.1.8
65
  # marimo
66
  # typer
67
  # uvicorn
68
- # wandb
69
  cloudpickle==3.1.1
70
  # via submitit
71
- comm==0.2.2
72
- # via ipykernel
73
  contourpy==1.3.1
74
  # via matplotlib
75
  cycler==0.12.1
76
  # via matplotlib
77
- datasets==3.2.0
78
  # via saev
79
- debugpy==1.8.12
80
- # via ipykernel
81
- decorator==5.1.1
82
- # via ipython
83
- defusedxml==0.7.1
84
- # via nbconvert
85
  dill==0.3.8
86
  # via
87
  # datasets
88
  # multiprocess
89
- docker-pycreds==0.4.0
90
- # via wandb
91
  docstring-parser==0.16
92
  # via tyro
93
  docutils==0.21.2
94
  # via marimo
95
- einops==0.8.0
96
  # via
97
  # saev-semantic-segmentation (pyproject.toml)
98
  # saev
99
- executing==2.2.0
100
- # via stack-data
101
  fastapi==0.115.8
102
  # via gradio
103
- fastjsonschema==2.21.1
104
- # via nbformat
105
  ffmpy==0.5.0
106
  # via gradio
107
  filelock==3.17.0
@@ -109,15 +70,13 @@ filelock==3.17.0
109
  # datasets
110
  # huggingface-hub
111
  # torch
112
- fonttools==4.55.8
113
  # via matplotlib
114
- fqdn==1.5.1
115
- # via jsonschema
116
  frozenlist==1.5.0
117
  # via
118
  # aiohttp
119
  # aiosignal
120
- fsspec==2024.9.0
121
  # via
122
  # datasets
123
  # gradio-client
@@ -125,13 +84,9 @@ fsspec==2024.9.0
125
  # torch
126
  ftfy==6.3.1
127
  # via open-clip-torch
128
- gitdb==4.0.12
129
- # via gitpython
130
- gitpython==3.1.44
131
- # via wandb
132
- gradio==5.14.0
133
  # via saev-semantic-segmentation (pyproject.toml)
134
- gradio-client==1.7.0
135
  # via gradio
136
  h11==0.14.0
137
  # via
@@ -143,9 +98,8 @@ httpx==0.28.1
143
  # via
144
  # gradio
145
  # gradio-client
146
- # jupyterlab
147
  # safehttpx
148
- huggingface-hub==0.28.1
149
  # via
150
  # datasets
151
  # gradio
@@ -156,88 +110,30 @@ idna==3.10
156
  # via
157
  # anyio
158
  # httpx
159
- # jsonschema
160
  # requests
161
  # yarl
162
- ipykernel==6.29.5
163
- # via jupyterlab
164
- ipython==8.32.0
165
- # via ipykernel
166
- isoduration==20.11.0
167
- # via jsonschema
168
  itsdangerous==2.2.0
169
  # via marimo
170
- jaxtyping==0.2.37
171
  # via saev
172
  jedi==0.19.2
173
- # via
174
- # ipython
175
- # marimo
176
  jinja2==3.1.5
177
  # via
178
  # altair
179
  # gradio
180
- # jupyter-server
181
- # jupyterlab
182
- # jupyterlab-server
183
- # nbconvert
184
  # torch
185
- joblib==1.4.2
186
- # via scikit-learn
187
- json5==0.10.0
188
- # via jupyterlab-server
189
- jsonpointer==3.0.0
190
- # via jsonschema
191
  jsonschema==4.23.0
192
- # via
193
- # altair
194
- # jupyter-events
195
- # jupyterlab-server
196
- # nbformat
197
  jsonschema-specifications==2024.10.1
198
  # via jsonschema
199
- jupyter-client==8.6.3
200
- # via
201
- # ipykernel
202
- # jupyter-server
203
- # nbclient
204
- jupyter-core==5.7.2
205
- # via
206
- # ipykernel
207
- # jupyter-client
208
- # jupyter-server
209
- # jupyterlab
210
- # nbclient
211
- # nbconvert
212
- # nbformat
213
- jupyter-events==0.12.0
214
- # via jupyter-server
215
- jupyter-lsp==2.2.5
216
- # via jupyterlab
217
- jupyter-server==2.15.0
218
- # via
219
- # jupyter-lsp
220
- # jupyterlab
221
- # jupyterlab-server
222
- # notebook-shim
223
- jupyter-server-terminals==0.5.3
224
- # via jupyter-server
225
- jupyterlab==4.3.5
226
- # via saev
227
- jupyterlab-pygments==0.3.0
228
- # via nbconvert
229
- jupyterlab-server==2.27.3
230
- # via jupyterlab
231
  kiwisolver==1.4.8
232
  # via matplotlib
233
- mako==1.3.9
234
- # via pdoc3
235
- marimo==0.10.19
236
  # via saev
237
  markdown==3.7
238
  # via
239
  # marimo
240
- # pdoc3
241
  # pymdown-extensions
242
  markdown-it-py==3.0.0
243
  # via rich
@@ -245,18 +141,10 @@ markupsafe==2.1.5
245
  # via
246
  # gradio
247
  # jinja2
248
- # mako
249
- # nbconvert
250
  matplotlib==3.10.0
251
  # via saev
252
- matplotlib-inline==0.1.7
253
- # via
254
- # ipykernel
255
- # ipython
256
  mdurl==0.1.2
257
  # via markdown-it-py
258
- mistune==3.1.1
259
- # via nbconvert
260
  mpmath==1.3.0
261
  # via sympy
262
  multidict==6.1.0
@@ -265,26 +153,13 @@ multidict==6.1.0
265
  # yarl
266
  multiprocess==0.70.16
267
  # via datasets
268
- narwhals==1.25.0
269
  # via
270
  # altair
271
  # marimo
272
- nbclient==0.10.2
273
- # via nbconvert
274
- nbconvert==7.16.6
275
- # via jupyter-server
276
- nbformat==5.10.4
277
- # via
278
- # jupyter-server
279
- # nbclient
280
- # nbconvert
281
- nest-asyncio==1.6.0
282
- # via ipykernel
283
  networkx==3.4.2
284
  # via torch
285
- notebook-shim==0.2.4
286
- # via jupyterlab
287
- numpy==2.2.2
288
  # via
289
  # saev-semantic-segmentation (pyproject.toml)
290
  # contourpy
@@ -292,10 +167,7 @@ numpy==2.2.2
292
  # gradio
293
  # matplotlib
294
  # pandas
295
- # scikit-learn
296
- # scipy
297
  # torchvision
298
- # webdataset
299
  nvidia-cublas-cu12==12.4.5.8
300
  # via
301
  # nvidia-cudnn-cu12
@@ -334,8 +206,6 @@ open-clip-torch==2.30.0
334
  # via saev
335
  orjson==3.10.15
336
  # via gradio
337
- overrides==7.7.0
338
- # via jupyter-server
339
  packaging==24.2
340
  # via
341
  # altair
@@ -343,81 +213,43 @@ packaging==24.2
343
  # gradio
344
  # gradio-client
345
  # huggingface-hub
346
- # ipykernel
347
- # jupyter-events
348
- # jupyter-server
349
- # jupyterlab
350
- # jupyterlab-server
351
  # marimo
352
  # matplotlib
353
- # nbconvert
354
  pandas==2.2.3
355
  # via
356
  # datasets
357
  # gradio
358
- pandocfilters==1.5.1
359
- # via nbconvert
360
  parso==0.8.4
361
  # via jedi
362
- pdoc3==0.11.5
363
- # via saev
364
- pexpect==4.9.0
365
- # via ipython
366
  pillow==11.1.0
367
  # via
368
  # gradio
369
  # matplotlib
370
  # saev
371
  # torchvision
372
- pkgconfig==1.5.5
373
- # via pyvips
374
- platformdirs==4.3.6
375
- # via
376
- # jupyter-core
377
- # wandb
378
- polars==1.21.0
379
  # via saev
380
- prometheus-client==0.21.1
381
- # via jupyter-server
382
- prompt-toolkit==3.0.50
383
- # via ipython
384
- propcache==0.2.1
385
  # via
386
  # aiohttp
387
  # yarl
388
- protobuf==5.29.3
389
- # via wandb
390
- psutil==6.1.1
391
- # via
392
- # ipykernel
393
- # marimo
394
- # wandb
395
- ptyprocess==0.7.0
396
- # via
397
- # pexpect
398
- # terminado
399
- pure-eval==0.2.3
400
- # via stack-data
401
- pyarrow==19.0.0
402
  # via datasets
403
- pycparser==2.22
404
- # via cffi
405
  pycrdt==0.11.1
406
  # via marimo
407
  pydantic==2.10.6
408
  # via
409
  # fastapi
410
  # gradio
411
- # wandb
412
  pydantic-core==2.27.2
413
  # via pydantic
414
  pydub==0.25.1
415
  # via gradio
416
  pygments==2.19.1
417
  # via
418
- # ipython
419
  # marimo
420
- # nbconvert
421
  # rich
422
  pymdown-extensions==10.14.3
423
  # via marimo
@@ -425,68 +257,43 @@ pyparsing==3.2.1
425
  # via matplotlib
426
  python-dateutil==2.9.0.post0
427
  # via
428
- # arrow
429
- # jupyter-client
430
  # matplotlib
431
  # pandas
432
- python-json-logger==3.2.1
433
- # via jupyter-events
434
  python-multipart==0.0.20
435
  # via gradio
436
  pytz==2025.1
437
  # via pandas
438
- pyvips==2.2.3
439
- # via saev
440
  pyyaml==6.0.2
441
  # via
442
  # datasets
443
  # gradio
444
  # huggingface-hub
445
- # jupyter-events
446
  # marimo
447
  # pymdown-extensions
448
  # timm
449
- # wandb
450
- # webdataset
451
- pyzmq==26.2.1
452
- # via
453
- # ipykernel
454
- # jupyter-client
455
- # jupyter-server
456
  referencing==0.36.2
457
  # via
458
  # jsonschema
459
  # jsonschema-specifications
460
- # jupyter-events
461
  regex==2024.11.6
462
  # via open-clip-torch
463
  requests==2.32.3
464
  # via
465
  # datasets
466
  # huggingface-hub
467
- # jupyterlab-server
468
- # wandb
469
- rfc3339-validator==0.1.4
470
- # via
471
- # jsonschema
472
- # jupyter-events
473
- rfc3986-validator==0.1.1
474
- # via
475
- # jsonschema
476
- # jupyter-events
477
  rich==13.9.4
478
  # via
479
  # typer
480
  # tyro
481
- rpds-py==0.22.3
482
  # via
483
  # jsonschema
484
  # referencing
485
- ruff==0.9.4
486
  # via
487
  # gradio
488
  # marimo
489
- saev @ git+https://github.com/samuelstevens/saev@44f1aa334828ed994d4d670629ff639c655db38b
490
  # via saev-semantic-segmentation (pyproject.toml)
491
  safehttpx==0.1.6
492
  # via gradio
@@ -494,40 +301,18 @@ safetensors==0.5.2
494
  # via
495
  # open-clip-torch
496
  # timm
497
- scikit-learn==1.6.1
498
- # via saev
499
- scipy==1.15.1
500
- # via scikit-learn
501
  semantic-version==2.10.0
502
  # via gradio
503
- send2trash==1.8.3
504
- # via jupyter-server
505
- sentry-sdk==2.20.0
506
- # via wandb
507
- setproctitle==1.3.4
508
- # via wandb
509
  setuptools==75.8.0
510
- # via
511
- # jupyterlab
512
- # torch
513
- # wandb
514
  shellingham==1.5.4
515
  # via typer
516
  shtab==1.7.1
517
  # via tyro
518
  six==1.17.0
519
- # via
520
- # docker-pycreds
521
- # python-dateutil
522
- # rfc3339-validator
523
- smmap==5.0.2
524
- # via gitdb
525
  sniffio==1.3.1
526
  # via anyio
527
- soupsieve==2.6
528
- # via beautifulsoup4
529
- stack-data==0.6.3
530
- # via ipython
531
  starlette==0.45.3
532
  # via
533
  # fastapi
@@ -537,16 +322,8 @@ submitit==1.5.2
537
  # via saev
538
  sympy==1.13.1
539
  # via torch
540
- terminado==0.18.1
541
- # via
542
- # jupyter-server
543
- # jupyter-server-terminals
544
- threadpoolctl==3.5.0
545
- # via scikit-learn
546
  timm==1.0.14
547
  # via open-clip-torch
548
- tinycss2==1.4.0
549
- # via bleach
550
  tomlkit==0.13.2
551
  # via
552
  # gradio
@@ -563,46 +340,22 @@ torchvision==0.21.0
563
  # saev-semantic-segmentation (pyproject.toml)
564
  # open-clip-torch
565
  # timm
566
- tornado==6.4.2
567
- # via
568
- # ipykernel
569
- # jupyter-client
570
- # jupyter-server
571
- # jupyterlab
572
- # terminado
573
  tqdm==4.67.1
574
  # via
575
  # datasets
576
  # huggingface-hub
577
  # open-clip-torch
578
  # saev
579
- traitlets==5.14.3
580
- # via
581
- # comm
582
- # ipykernel
583
- # ipython
584
- # jupyter-client
585
- # jupyter-core
586
- # jupyter-events
587
- # jupyter-server
588
- # jupyterlab
589
- # matplotlib-inline
590
- # nbclient
591
- # nbconvert
592
- # nbformat
593
  triton==3.2.0
594
  # via torch
595
- typeguard==4.4.1
596
  # via tyro
597
  typer==0.15.1
598
  # via gradio
599
- types-python-dateutil==2.9.0.20241206
600
- # via arrow
601
  typing-extensions==4.12.2
602
  # via
603
  # altair
604
  # anyio
605
- # beautifulsoup4
606
  # fastapi
607
  # gradio
608
  # gradio-client
@@ -615,40 +368,20 @@ typing-extensions==4.12.2
615
  # typeguard
616
  # typer
617
  # tyro
618
- tyro==0.9.13
619
  # via saev
620
  tzdata==2025.1
621
  # via pandas
622
- uri-template==1.3.0
623
- # via jsonschema
624
  urllib3==2.3.0
625
- # via
626
- # requests
627
- # sentry-sdk
628
  uvicorn==0.34.0
629
  # via
630
  # gradio
631
  # marimo
632
- vl-convert-python==1.7.0
633
- # via saev
634
  wadler-lindig==0.1.3
635
  # via jaxtyping
636
- wandb==0.19.5
637
- # via saev
638
  wcwidth==0.2.13
639
- # via
640
- # ftfy
641
- # prompt-toolkit
642
- webcolors==24.11.1
643
- # via jsonschema
644
- webdataset==0.2.100
645
- # via saev
646
- webencodings==0.5.1
647
- # via
648
- # bleach
649
- # tinycss2
650
- websocket-client==1.8.0
651
- # via jupyter-server
652
  websockets==14.2
653
  # via
654
  # gradio-client
 
2
  # uv pip compile pyproject.toml
3
  aiofiles==23.2.1
4
  # via gradio
5
+ aiohappyeyeballs==2.4.6
6
  # via aiohttp
7
+ aiohttp==3.11.12
8
  # via
9
  # datasets
10
  # fsspec
 
18
  # via
19
  # gradio
20
  # httpx
 
21
  # pycrdt
22
  # starlette
 
 
 
 
 
 
 
 
 
 
23
  attrs==25.1.0
24
  # via
25
  # aiohttp
26
  # jsonschema
27
  # referencing
 
 
28
  beartype==0.19.0
29
  # via
30
  # saev-semantic-segmentation (pyproject.toml)
31
  # saev
 
 
 
 
 
 
32
  certifi==2025.1.31
33
  # via
34
  # httpcore
35
  # httpx
36
  # requests
 
 
 
 
 
37
  charset-normalizer==3.4.1
38
  # via requests
39
  click==8.1.8
 
41
  # marimo
42
  # typer
43
  # uvicorn
 
44
  cloudpickle==3.1.1
45
  # via submitit
 
 
46
  contourpy==1.3.1
47
  # via matplotlib
48
  cycler==0.12.1
49
  # via matplotlib
50
+ datasets==3.3.2
51
  # via saev
 
 
 
 
 
 
52
  dill==0.3.8
53
  # via
54
  # datasets
55
  # multiprocess
 
 
56
  docstring-parser==0.16
57
  # via tyro
58
  docutils==0.21.2
59
  # via marimo
60
+ einops==0.8.1
61
  # via
62
  # saev-semantic-segmentation (pyproject.toml)
63
  # saev
 
 
64
  fastapi==0.115.8
65
  # via gradio
 
 
66
  ffmpy==0.5.0
67
  # via gradio
68
  filelock==3.17.0
 
70
  # datasets
71
  # huggingface-hub
72
  # torch
73
+ fonttools==4.56.0
74
  # via matplotlib
 
 
75
  frozenlist==1.5.0
76
  # via
77
  # aiohttp
78
  # aiosignal
79
+ fsspec==2024.12.0
80
  # via
81
  # datasets
82
  # gradio-client
 
84
  # torch
85
  ftfy==6.3.1
86
  # via open-clip-torch
87
+ gradio==5.16.2
 
 
 
 
88
  # via saev-semantic-segmentation (pyproject.toml)
89
+ gradio-client==1.7.1
90
  # via gradio
91
  h11==0.14.0
92
  # via
 
98
  # via
99
  # gradio
100
  # gradio-client
 
101
  # safehttpx
102
+ huggingface-hub==0.29.1
103
  # via
104
  # datasets
105
  # gradio
 
110
  # via
111
  # anyio
112
  # httpx
 
113
  # requests
114
  # yarl
 
 
 
 
 
 
115
  itsdangerous==2.2.0
116
  # via marimo
117
+ jaxtyping==0.2.38
118
  # via saev
119
  jedi==0.19.2
120
+ # via marimo
 
 
121
  jinja2==3.1.5
122
  # via
123
  # altair
124
  # gradio
 
 
 
 
125
  # torch
 
 
 
 
 
 
126
  jsonschema==4.23.0
127
+ # via altair
 
 
 
 
128
  jsonschema-specifications==2024.10.1
129
  # via jsonschema
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  kiwisolver==1.4.8
131
  # via matplotlib
132
+ marimo==0.11.7
 
 
133
  # via saev
134
  markdown==3.7
135
  # via
136
  # marimo
 
137
  # pymdown-extensions
138
  markdown-it-py==3.0.0
139
  # via rich
 
141
  # via
142
  # gradio
143
  # jinja2
 
 
144
  matplotlib==3.10.0
145
  # via saev
 
 
 
 
146
  mdurl==0.1.2
147
  # via markdown-it-py
 
 
148
  mpmath==1.3.0
149
  # via sympy
150
  multidict==6.1.0
 
153
  # yarl
154
  multiprocess==0.70.16
155
  # via datasets
156
+ narwhals==1.27.1
157
  # via
158
  # altair
159
  # marimo
 
 
 
 
 
 
 
 
 
 
 
160
  networkx==3.4.2
161
  # via torch
162
+ numpy==2.2.3
 
 
163
  # via
164
  # saev-semantic-segmentation (pyproject.toml)
165
  # contourpy
 
167
  # gradio
168
  # matplotlib
169
  # pandas
 
 
170
  # torchvision
 
171
  nvidia-cublas-cu12==12.4.5.8
172
  # via
173
  # nvidia-cudnn-cu12
 
206
  # via saev
207
  orjson==3.10.15
208
  # via gradio
 
 
209
  packaging==24.2
210
  # via
211
  # altair
 
213
  # gradio
214
  # gradio-client
215
  # huggingface-hub
 
 
 
 
 
216
  # marimo
217
  # matplotlib
 
218
  pandas==2.2.3
219
  # via
220
  # datasets
221
  # gradio
 
 
222
  parso==0.8.4
223
  # via jedi
 
 
 
 
224
  pillow==11.1.0
225
  # via
226
  # gradio
227
  # matplotlib
228
  # saev
229
  # torchvision
230
+ polars==1.22.0
 
 
 
 
 
 
231
  # via saev
232
+ propcache==0.3.0
 
 
 
 
233
  # via
234
  # aiohttp
235
  # yarl
236
+ psutil==7.0.0
237
+ # via marimo
238
+ pyarrow==19.0.1
 
 
 
 
 
 
 
 
 
 
 
239
  # via datasets
 
 
240
  pycrdt==0.11.1
241
  # via marimo
242
  pydantic==2.10.6
243
  # via
244
  # fastapi
245
  # gradio
 
246
  pydantic-core==2.27.2
247
  # via pydantic
248
  pydub==0.25.1
249
  # via gradio
250
  pygments==2.19.1
251
  # via
 
252
  # marimo
 
253
  # rich
254
  pymdown-extensions==10.14.3
255
  # via marimo
 
257
  # via matplotlib
258
  python-dateutil==2.9.0.post0
259
  # via
 
 
260
  # matplotlib
261
  # pandas
 
 
262
  python-multipart==0.0.20
263
  # via gradio
264
  pytz==2025.1
265
  # via pandas
 
 
266
  pyyaml==6.0.2
267
  # via
268
  # datasets
269
  # gradio
270
  # huggingface-hub
 
271
  # marimo
272
  # pymdown-extensions
273
  # timm
 
 
 
 
 
 
 
274
  referencing==0.36.2
275
  # via
276
  # jsonschema
277
  # jsonschema-specifications
 
278
  regex==2024.11.6
279
  # via open-clip-torch
280
  requests==2.32.3
281
  # via
282
  # datasets
283
  # huggingface-hub
 
 
 
 
 
 
 
 
 
 
284
  rich==13.9.4
285
  # via
286
  # typer
287
  # tyro
288
+ rpds-py==0.23.0
289
  # via
290
  # jsonschema
291
  # referencing
292
+ ruff==0.9.7
293
  # via
294
  # gradio
295
  # marimo
296
+ saev @ git+https://github.com/samuelstevens/saev@298cabdb6b771c76b402d0fdddab6907d1941d7a
297
  # via saev-semantic-segmentation (pyproject.toml)
298
  safehttpx==0.1.6
299
  # via gradio
 
301
  # via
302
  # open-clip-torch
303
  # timm
 
 
 
 
304
  semantic-version==2.10.0
305
  # via gradio
 
 
 
 
 
 
306
  setuptools==75.8.0
307
+ # via torch
 
 
 
308
  shellingham==1.5.4
309
  # via typer
310
  shtab==1.7.1
311
  # via tyro
312
  six==1.17.0
313
+ # via python-dateutil
 
 
 
 
 
314
  sniffio==1.3.1
315
  # via anyio
 
 
 
 
316
  starlette==0.45.3
317
  # via
318
  # fastapi
 
322
  # via saev
323
  sympy==1.13.1
324
  # via torch
 
 
 
 
 
 
325
  timm==1.0.14
326
  # via open-clip-torch
 
 
327
  tomlkit==0.13.2
328
  # via
329
  # gradio
 
340
  # saev-semantic-segmentation (pyproject.toml)
341
  # open-clip-torch
342
  # timm
 
 
 
 
 
 
 
343
  tqdm==4.67.1
344
  # via
345
  # datasets
346
  # huggingface-hub
347
  # open-clip-torch
348
  # saev
 
 
 
 
 
 
 
 
 
 
 
 
 
 
349
  triton==3.2.0
350
  # via torch
351
+ typeguard==4.4.2
352
  # via tyro
353
  typer==0.15.1
354
  # via gradio
 
 
355
  typing-extensions==4.12.2
356
  # via
357
  # altair
358
  # anyio
 
359
  # fastapi
360
  # gradio
361
  # gradio-client
 
368
  # typeguard
369
  # typer
370
  # tyro
371
+ tyro==0.9.16
372
  # via saev
373
  tzdata==2025.1
374
  # via pandas
 
 
375
  urllib3==2.3.0
376
+ # via requests
 
 
377
  uvicorn==0.34.0
378
  # via
379
  # gradio
380
  # marimo
 
 
381
  wadler-lindig==0.1.3
382
  # via jaxtyping
 
 
383
  wcwidth==0.2.13
384
+ # via ftfy
 
 
 
 
 
 
 
 
 
 
 
 
385
  websockets==14.2
386
  # via
387
  # gradio-client