phzwart commited on
Commit
3da9124
1 Parent(s): 7f10a8b

Upload 4 files

Browse files

Notebook + data & results

Files changed (4) hide show
  1. data_small.npy +3 -0
  2. inference.ipynb +147 -0
  3. mean_small.npy +3 -0
  4. std_small.npy +3 -0
data_small.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3c6b0011ad155aeb1a775fd470094db03ba18839a09df48c449c3b05fd52c836
3
+ size 1000128
inference.ipynb ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "b247863c",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import numpy as np\n",
11
+ "import math\n",
12
+ "import torch\n",
13
+ "from dlsia.core import helpers\n",
14
+ "from dlsia.core.networks import sms3d\n",
15
+ "from dlsia.core.networks import baggins\n",
16
+ "import napari"
17
+ ]
18
+ },
19
+ {
20
+ "cell_type": "code",
21
+ "execution_count": 2,
22
+ "id": "855a85d5",
23
+ "metadata": {},
24
+ "outputs": [],
25
+ "source": [
26
+ "nets = []\n",
27
+ "for ii in range(5):\n",
28
+ " net = sms3d.SMSNetwork3D_from_file(\"3d_%i_2023_06_10_depth_25.pt\"%ii)\n",
29
+ " nets.append(net.eval())"
30
+ ]
31
+ },
32
+ {
33
+ "cell_type": "code",
34
+ "execution_count": 3,
35
+ "id": "0acd6f5b",
36
+ "metadata": {},
37
+ "outputs": [],
38
+ "source": [
39
+ "bilbo = baggins.model_baggin(nets, \n",
40
+ " model_type='classification', \n",
41
+ " returns_normalized=False,\n",
42
+ " average_type=\"arithmetic\")"
43
+ ]
44
+ },
45
+ {
46
+ "cell_type": "code",
47
+ "execution_count": 4,
48
+ "id": "4d662643",
49
+ "metadata": {},
50
+ "outputs": [],
51
+ "source": [
52
+ "data_small = np.load(\"data_small.npy\")"
53
+ ]
54
+ },
55
+ {
56
+ "cell_type": "code",
57
+ "execution_count": 5,
58
+ "id": "forward-pickup",
59
+ "metadata": {},
60
+ "outputs": [],
61
+ "source": [
62
+ "v = napari.view_image(data_small)"
63
+ ]
64
+ },
65
+ {
66
+ "cell_type": "code",
67
+ "execution_count": null,
68
+ "id": "blessed-wesley",
69
+ "metadata": {},
70
+ "outputs": [],
71
+ "source": [
72
+ "tdata = torch.Tensor(data_small).unsqueeze(0).unsqueeze(0)"
73
+ ]
74
+ },
75
+ {
76
+ "cell_type": "code",
77
+ "execution_count": null,
78
+ "id": "educated-family",
79
+ "metadata": {},
80
+ "outputs": [],
81
+ "source": [
82
+ "if runit:\n",
83
+ " with torch.no_grad():\n",
84
+ " m,s = bilbo.eval()(tdata, return_std=True)\n",
85
+ "else:\n",
86
+ " m = np.load(\"mean_small.npy\")\n",
87
+ " s = np.load(\"std_small.npy\")"
88
+ ]
89
+ },
90
+ {
91
+ "cell_type": "code",
92
+ "execution_count": null,
93
+ "id": "subject-british",
94
+ "metadata": {},
95
+ "outputs": [],
96
+ "source": [
97
+ "_ = v.add_image(m.numpy()[0,1])\n",
98
+ "_ = v.add_image(s.numpy()[0,1])"
99
+ ]
100
+ },
101
+ {
102
+ "cell_type": "code",
103
+ "execution_count": null,
104
+ "id": "dying-publicity",
105
+ "metadata": {},
106
+ "outputs": [],
107
+ "source": []
108
+ },
109
+ {
110
+ "cell_type": "code",
111
+ "execution_count": null,
112
+ "id": "lonely-resource",
113
+ "metadata": {},
114
+ "outputs": [],
115
+ "source": []
116
+ },
117
+ {
118
+ "cell_type": "code",
119
+ "execution_count": null,
120
+ "id": "embedded-blackjack",
121
+ "metadata": {},
122
+ "outputs": [],
123
+ "source": []
124
+ }
125
+ ],
126
+ "metadata": {
127
+ "kernelspec": {
128
+ "display_name": "dlsia-dev",
129
+ "language": "python",
130
+ "name": "dlsia-dev"
131
+ },
132
+ "language_info": {
133
+ "codemirror_mode": {
134
+ "name": "ipython",
135
+ "version": 3
136
+ },
137
+ "file_extension": ".py",
138
+ "mimetype": "text/x-python",
139
+ "name": "python",
140
+ "nbconvert_exporter": "python",
141
+ "pygments_lexer": "ipython3",
142
+ "version": "3.9.0"
143
+ }
144
+ },
145
+ "nbformat": 4,
146
+ "nbformat_minor": 5
147
+ }
mean_small.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:647620e3079a4bd324d2a645e2d0ca3c59c0464d37b30d01641038d2f1ddf69b
3
+ size 8000128
std_small.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:035191f5ad9f976f3e4502b146a75c4c05a69d04118658f5aa4afdb5a7f9c980
3
+ size 8000128