St0nedB commited on
Commit
0696948
0 Parent(s):

initial commit

Browse files
.gitattributes ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Gradio Test
3
+ colorFrom: red
4
+ colorTo: purple
5
+ sdk: gradio
6
+ sdk_version: "3.9"
7
+ python_version: "3.10"
8
+ app_file: app.py
9
+ pinned: false
10
+ ---
11
+
12
+ # Demo of `deepest` parameter estimator
13
+ This repository is a demo of the `deepest` parameter estimator introduced in []()
14
+
15
+ ## Usage
16
+ Start by installing the requirements from the `requirements.txt`.
17
+ ```bash
18
+ python -m pip install -r requirements.txt
19
+ ```
20
+
21
+ The repository uses `secrets` to avoid leaking non-public information (i.e. the code, the trained model, git urls and access tokens) which can occur when users clone the repository.
22
+ If you are in possesion of those secrets, set them as the following environement variables in your shell:
23
+ ```bash
24
+ export MODEL_TOKEN=<Access Token for the Huggingface Model Git>
25
+ export GIT_TOKEN=<Access Token for the `deepest` repository>
26
+ export GIT_URL=<Url of the `deepest` repository>
27
+ export GIT_COMMIT=<The commit shasum for the `deepest` version>
28
+
29
+ ```
30
+
31
+ Then run the `app.py` via
32
+ ```bash
33
+ python app.py
34
+ ```
__pycache__/helper.cpython-310.pyc ADDED
Binary file (2.6 kB). View file
 
app.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import subprocess
4
+ import numpy as np
5
+ import logging
6
+ import gradio as gr
7
+ import matplotlib
8
+ import matplotlib.pyplot as plt
9
+ from huggingface_hub import hf_hub_download
10
+
11
+ matplotlib.use("Agg")
12
+ logger = logging.basicConfig(level=logging.ERROR)
13
+
14
+ # define global variable demos
15
+ DATA_SHAPE = (64,64)
16
+ ETA_SHAPE = (2, 20)
17
+ DATASET = "./data"
18
+ N = 1000
19
+ BS = 256
20
+ WORKER = 2
21
+ SNRS = {
22
+ "0": 1.0,
23
+ "10": 0.1,
24
+ "20": 0.01,
25
+ "30": 0.001,
26
+ }
27
+
28
+ # download model from huggingface hub
29
+ MODEL_PATH = hf_hub_download("St0nedB/deepest-demo", "2022.07.03.2338.param2d.model", use_auth_token=os.environ["MODEL_TOKEN"])
30
+ RUNNER = None
31
+
32
+ # preallocated result arrays
33
+ DATA = np.empty((len(SNRS), N, *DATA_SHAPE), dtype=np.complex128)
34
+ TRUTH = np.empty((len(SNRS), N, *ETA_SHAPE))
35
+ ESTIM = np.empty((len(SNRS), N, *ETA_SHAPE))
36
+
37
+ def install_deepest():
38
+ git_token = os.environ["GIT_TOKEN"]
39
+ git_url = os.environ["GIT_URL"]
40
+ git_commit = os.environ["GIT_COMMIT"]
41
+ subprocess.check_call([sys.executable, "-m", "pip", "install", f"git+https://hggn:{git_token}@{git_url}@{git_commit}"])
42
+ return
43
+
44
+
45
+ def make_plots(snr: float, idx: int):
46
+ idx -= 1
47
+ data, truth, estim = DATA[snr][idx], TRUTH[snr][idx], ESTIM[snr][idx]
48
+
49
+ fig_data = make_dataplot(data)
50
+ fig_param = make_parameterplot(estim, truth)
51
+
52
+ return fig_data, fig_param
53
+
54
+ def make_dataplot(x: np.ndarray):
55
+ plt.close()
56
+ x = np.rot90(10*np.log10(np.abs(np.fft.fftn(x))), k=-1)
57
+ fig, ax = plt.subplots(1,1)
58
+ ax.imshow(x, extent=[0,1,0,1], cmap="viridis")
59
+ ax.set_xlabel("Norm. Delay")
60
+ ax.set_ylabel("Norm. Doppler")
61
+
62
+ return fig
63
+
64
+ def make_parameterplot(estim: np.ndarray, truth: np.ndarray = None, **kwargs):
65
+ plt.close()
66
+ fig, ax = plt.subplots(1,1)
67
+ ax = plot_parameters(ax, es=estim, gt=truth, **kwargs)
68
+ ax.set_xlim(0,1)
69
+ ax.set_ylim(0,1)
70
+
71
+ return fig
72
+
73
+ def load_numpy(file_obj) -> None | np.ndarray:
74
+ if file_obj is None:
75
+ # no file given
76
+ return None
77
+
78
+ file = file_obj.name
79
+ if not(os.path.splitext(file)[1] in [".npy", ".npz"]):
80
+ # no numpy file
81
+ return None
82
+
83
+ data = np.load(file)
84
+ if len(data.shape) != 3:
85
+ # not in proper shape
86
+ return None
87
+
88
+ return data
89
+
90
+ def process_user_input(file_obj):
91
+ data = load_numpy(file_obj)
92
+ if data is None:
93
+ return None
94
+
95
+ return gr.update(minimum=1, step=1, maximum=len(data), visible=True, value=1)
96
+
97
+ def make_user_plot(file_obj, idx: int):
98
+ idx -= 1
99
+ data = load_numpy(file_obj)
100
+
101
+ estim = RUNNER.fit(data[idx][None,])
102
+ bg_data = np.rot90(10*np.log10(np.abs(np.fft.fftn(data[idx], norm="ortho"))), k=-1)
103
+ fig_estim = make_parameterplot(estim=estim[0], bg=bg_data, extent=[0,1,0,1], cmap="viridis")
104
+
105
+ return fig_estim
106
+
107
+
108
+ def demo():
109
+ with gr.Blocks() as demo:
110
+ gr.Markdown(
111
+ """
112
+ # deepest
113
+ `deepest` (short for **deep** learning parameter **est**imator) is a CNN trained to perform signal parameter estimation.
114
+ The corresponding paper can be found on [arxiv](arxiv.org).
115
+
116
+ This applet lets you explore the `deepest` with data from the validationset.
117
+ You can also upload your own data and see how it works for your signals.
118
+ """
119
+ )
120
+
121
+ with gr.Column():
122
+ snr = gr.Radio(choices=["0", "10", "20", "30"], type="index", value="0", label="SNR [dB]")
123
+
124
+ with gr.Row():
125
+ data = gr.Plot(label="Data")
126
+ result = gr.Plot(label="Results")
127
+
128
+ with gr.Row():
129
+ slider = gr.Slider(1, N, 1, label="Sample Index")
130
+
131
+ # update callbacks
132
+ slider.change(make_plots, [snr, slider], [data, result])
133
+ snr.change(make_plots, [snr, slider], [data, result])
134
+
135
+ with gr.Column():
136
+ gr.Markdown(
137
+ """
138
+ ## Try with your own data.
139
+ Good new everyone! If you want to try `deepest` with your own data, here is your chance.
140
+ Afterall there is no need to believe a paper making vague claims about an algorithms performance.
141
+ But keep in mind that slight deviations from the training-data distribution might throw off `deepest`.
142
+ Its a Neural Network after all.
143
+ You can upload a `numpy` file (both `*.npy` and `*.npz` work) with your test data.
144
+ Ensure the data meets the requirements, such that you get good results.
145
+ ### Requirements
146
+ - complex-valued baseband data for the time-variant Channel transfer function $H(f,t)$ (e.g. from a channel-sounding campaing)
147
+ - array shape must be `batch_size x f_bins x t_bins`
148
+ - ideally `f_bin`=64 and `t_bin`=64, otherwise the data will be downsampled by the 2D-DFT, which might not be ideal in all scenarios.
149
+
150
+ **Important** This demo runs on Huggingface. You are responsible for the data you upload. Do not upload any data that is confidential or unsuitable in this context.
151
+ """
152
+ )
153
+
154
+ with gr.Row():
155
+ with gr.Column():
156
+ user_file = gr.File(file_count="single", type="file", interactive=True)
157
+ run_btn = gr.Button("Run Inference")
158
+
159
+ user_plot = gr.Plot(label="Results")
160
+
161
+ with gr.Column():
162
+ user_slider = gr.Slider(visible=False, label="Sample Index")
163
+
164
+ run_btn.click(process_user_input, [user_file], [user_slider], show_progress=True)
165
+ user_slider.change(make_user_plot, [user_file, user_slider], [user_plot])
166
+
167
+ demo.launch()
168
+
169
+ def main():
170
+ for dd, snr in enumerate(SNRS.values()):
171
+ DATA[dd], TRUTH[dd], ESTIM[dd] = RUNNER.run(snr=snr)
172
+
173
+ demo()
174
+
175
+
176
+ if __name__ == "__main__":
177
+ install_deepest()
178
+ from deepest.utils import plot_parameters
179
+ from helper import Runner
180
+ RUNNER = Runner(MODEL_PATH, DATASET, BS, WORKER)
181
+ main()
data/config.toml ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [[distribution]]
2
+ minComps = 1
3
+ maxComps = 20
4
+ gammasLabels = [ "magnitude", "phase",]
5
+ comment = "Moving Targets"
6
+ [[distribution.parameter]]
7
+ dist = "uniform"
8
+
9
+ [distribution.parameter.kwargs]
10
+ low = 0
11
+ high = 1
12
+ [[distribution.parameter]]
13
+ dist = "uniform"
14
+
15
+ [distribution.parameter.kwargs]
16
+ low = 0
17
+ high = 1
18
+ [[distribution.gammas]]
19
+ dist = "uniform"
20
+
21
+ [distribution.gammas.kwargs]
22
+ low = 0.001
23
+ high = 1
24
+ [[distribution.gammas]]
25
+ dist = "uniform"
26
+
27
+ [distribution.gammas.kwargs]
28
+ low = 0
29
+ high = 6.283
30
+
31
+ [general]
32
+ path = "./val"
33
+ items = 1000
34
+ dim = 2
35
+ seed = 242
36
+ minDistance = [ 0.003125, 0.003125,]
data/dataset.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:532c2127d6f70fe6844edd7c2619f42376216cedacba0cbbd8474da86cd97f87
3
+ size 640264
helper.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils.data import DataLoader
2
+ from deepest.modules import Parameter2dNet
3
+ from deepest.datasets import InferenceDelayDataset
4
+ from deepest.metrics import match_components
5
+ import numpy as np
6
+
7
+ class Runner:
8
+ def __init__(self, model: str, dataset: str, bs: int, num_worker: int):
9
+ self.module = Parameter2dNet.from_file(f"{model}")
10
+ self.dataset_config = self.module.get_datasetconfig()
11
+ self.dataset = InferenceDelayDataset(path=dataset, **self.dataset_config)
12
+ self.bs = bs
13
+ self.num_worker = num_worker
14
+
15
+ def _preallocate(self, data_shape: tuple[int, ...], eta_shape: tuple[int, ...]):
16
+ data = np.empty((len(self), *data_shape), dtype=np.complex128)
17
+ truth = np.empty((len(self), *eta_shape))
18
+ estim = np.empty((len(self), *eta_shape))
19
+ return data, truth, estim
20
+
21
+ def _get_batchrange_for_index(self, ii: int):
22
+ start_idx = ii*self.bs
23
+ stop_idx = (ii+1)*self.bs
24
+ if stop_idx > len(self.dataset):
25
+ stop_idx = len(self.dataset)
26
+
27
+ return range(start_idx, stop_idx)
28
+
29
+ def run(self, snr: int) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
30
+ self.dataset.noise_var = (snr, snr)
31
+ dataloader = DataLoader(
32
+ self.dataset,
33
+ batch_size=self.bs,
34
+ num_workers=self.num_worker,
35
+ worker_init_fn=lambda worker_id: np.random.seed(worker_id),
36
+ shuffle=False,
37
+ )
38
+
39
+ for ii, (x, _, z) in enumerate(dataloader):
40
+ z = z[0][:, :2, :]
41
+ if ii == 0:
42
+ data, truth, estim = self._preallocate(x.shape[1:], z.shape[1:])
43
+
44
+ idx_range = self._get_batchrange_for_index(ii)
45
+
46
+ data[idx_range] = x.cpu().numpy()
47
+ truth[idx_range] = z.cpu().numpy()
48
+ estim[idx_range] = self.module.fit(x)[:, :2, :]
49
+
50
+ estim, truth = match_components(estim, truth)
51
+
52
+ return data, truth, estim
53
+
54
+ def fit(self, data: np.ndarray) -> np.ndarray:
55
+ x = self.module.fit(data)
56
+ return x[:, :2, :]
57
+
58
+ def __len__(self):
59
+ return len(self.dataset)
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ datetime
2
+ requests
3
+ toml
4
+ typical
5
+ torch>=1.11
6
+ torchinfo
7
+ numpy>=1.20
8
+ scipy
9
+ scikit-image
10
+ tqdm
11
+ joblib
12
+ matplotlib
13
+ huggingface-hub
14
+ gradio