deepest-demo / app.py
St0nedB's picture
added buttons for navigation
c3f203b
raw
history blame contribute delete
No virus
4.33 kB
# https://huggingface.co/St0nedB/deepest-public
import os
import sys
import subprocess
import toml
from argparse import Namespace
import numpy as np
import logging
import gradio as gr
import matplotlib
import matplotlib.pyplot as plt
from huggingface_hub import hf_hub_download
from colormaps import tol_cmap
# Select the non-interactive Agg backend: figures are only drawn off-screen and
# handed to gr.Plot components in demo(), never shown in a GUI window.
matplotlib.use("Agg")
# Configure root logging to only emit ERROR and above. basicConfig() returns
# None, so the module logger is obtained separately via getLogger().
logging.basicConfig(level=logging.ERROR)
logger = logging.getLogger(__name__)
# ---- Global demo configuration -------------------------------------------
DATA_SHAPE = (64, 64)   # shape of one data snapshot (FFT grid base in make_plots)
ETA_SHAPE = (2, 20)     # shape of one ground-truth / estimate parameter array
DATASET = "./data"      # dataset location passed to Runner
N = 1000                # number of samples per SNR level (slider maximum)
BS = 256                # passed to Runner (presumably batch size)
WORKER = 2              # passed to Runner (presumably data-loader workers)

# UI label in dB -> value handed to RUNNER.run(snr=...); value == 10**(-dB/10).
# Insertion order matters: it fixes the row order of DATA/TRUTH/ESTIM.
SNRS = {"-10": 10.0, "0": 1.0, "10": 0.1, "20": 0.01}
# Fetch the model checkpoint from the Hugging Face Hub at import time. The
# access token comes from the MODEL_TOKEN environment variable (a KeyError is
# raised here if it is not set).
MODEL_PATH = hf_hub_download(
    "St0nedB/deepest-demo",
    "2022.07.03.2338.param2d.model",
    use_auth_token=os.environ["MODEL_TOKEN"],
)

# Placeholder; the __main__ guard rebinds this to a Runner before main() runs.
RUNNER = None
# Result buffers, one leading slot per SNRS entry (insertion order) and N
# samples each. np.empty leaves them uninitialised; main() fills them.
DATA = np.empty((len(SNRS), N) + DATA_SHAPE, dtype=np.complex128)
TRUTH = np.empty((len(SNRS), N) + ETA_SHAPE)
ESTIM = np.empty((len(SNRS), N) + ETA_SHAPE)

# UI texts loaded from texts.toml and exposed as attributes; demo() reads
# .introduction, .acknowledgements and .contact.
TEXTS = Namespace(**toml.load("texts.toml"))
def install_deepest():
    """Install the private ``deepest`` package into the running interpreter.

    Reads ``GIT_TOKEN``, ``GIT_URL`` and ``GIT_COMMIT`` from the environment
    and runs ``pip install git+https://hggn:<token>@<url>@<commit>`` with the
    current Python executable.

    Raises:
        KeyError: if one of the environment variables is missing.
        subprocess.CalledProcessError: if pip exits with a non-zero status.
            The command attached to the exception has the token masked.
    """
    import importlib

    git_token = os.environ["GIT_TOKEN"]
    git_url = os.environ["GIT_URL"]
    git_commit = os.environ["GIT_COMMIT"]

    def pip_command(token):
        # Build the pip invocation for a given (real or masked) token.
        return [
            sys.executable, "-m", "pip", "install",
            f"git+https://hggn:{token}@{git_url}@{git_commit}",
        ]

    try:
        subprocess.check_call(pip_command(git_token))
    except subprocess.CalledProcessError as err:
        # The default exception text contains the full command line, which
        # would leak the access token into logs/tracebacks. Re-raise the same
        # exception type with the token masked and the original context dropped.
        raise subprocess.CalledProcessError(err.returncode, pip_command("***")) from None

    # The import system caches directory listings; without invalidating them a
    # package installed while the interpreter is running may not be importable.
    importlib.invalidate_caches()
def _spectrum_db(data):
    """Return the 4x zero-padded 2-D FFT power of *data* in dB, rotated for imshow.

    The spectrum is divided by its overall (Frobenius) norm, converted to
    ``10*log10(|X|**2)`` and rotated 90 degrees clockwise (``k=-1``).
    """
    padded_shape = [4 * size for size in DATA_SHAPE]
    spectrum = np.fft.fftn(data, s=padded_shape)
    spectrum /= np.linalg.norm(spectrum, axis=(0, 1))
    return np.rot90(10 * np.log10(np.abs(spectrum) ** 2), k=-1)


def _plot_sample(data, truth, estim):
    """Draw the spectrum of *data* with ground-truth and estimated points overlaid.

    ``truth``/``estim`` are indexed as ``[0, :]`` (x, "Norm. Delay") and
    ``[1, :]`` (y, "Norm. Doppler"). Returns the new matplotlib Figure.
    """
    spectrum = _spectrum_db(data)
    fig, ax = plt.subplots(1, 2, gridspec_kw={"width_ratios": [10, 0.5]})
    ax[0].scatter(truth[0, :], truth[1, :], marker="o", facecolors="none", edgecolors="#000000", s=200, linewidth=2, label="groundtruth")
    ax[0].scatter(estim[0, :], estim[1, :], marker="o", facecolors="none", edgecolors="#0077bb", s=200, linewidth=2, label="estimate")
    image = ax[0].imshow(spectrum, extent=[0, 1, 0, 1], cmap=tol_cmap("YlOrBr"), vmin=-80, vmax=0)
    ax[0].set_xlim(0, 1)
    ax[0].set_ylim(0, 1)
    ax[0].set_xlabel("Norm. Delay")
    ax[0].set_ylabel("Norm. Doppler")
    ax[0].legend(
        loc="upper center",
        bbox_to_anchor=(0.5, 1.1),
        fancybox=True,
        shadow=True,
        ncol=5,
    )
    # The narrow second axis is reserved for the colour bar.
    fig.colorbar(image, cax=ax[1], orientation='vertical', label="Magnitude [dB]", pad=-0.5)
    fig.tight_layout()
    return fig


def make_plots(idx: int):
    """Render one figure per SNR level for a single sample.

    Args:
        idx: 1-based sample index as shown on the UI slider. It is passed
            through ``int()`` (UI widgets may deliver a float) and shifted to a
            0-based row, so ``0`` selects the last sample via negative indexing.

    Returns:
        A list of matplotlib Figures, one per ``SNRS`` entry in insertion order.
    """
    row = int(idx) - 1
    # Close the figures of the previous call so repeated slider moves do not
    # accumulate open figures.
    plt.close("all")
    return [
        _plot_sample(DATA[level][row], TRUTH[level][row], ESTIM[level][row])
        for level in range(len(SNRS))
    ]
def button_previous(idx: int) -> int:
    """Step the 1-based sample index back by one, wrapping from 1 around to N."""
    return N if idx == 1 else idx - 1
def button_next(idx: int) -> int:
    """Step the 1-based sample index forward by one, wrapping from N around to 1."""
    return 1 if idx == N else idx + 1
def demo():
    """Build the Gradio Blocks interface and launch it.

    The page shows one ``gr.Plot`` per SNR level (two per row), Previous/Next
    buttons that step the sample index with wrap-around, and a slider that
    selects the 1-based sample index in ``[1, N]``. Changing the slider
    re-renders all plots via ``make_plots``. Page texts come from ``TEXTS``.
    """
    # Render the sample the slider starts on (1); make_plots(0) would show the
    # last sample instead and disagree with the slider.
    initial_figs = make_plots(1)
    labels = [f"SNR {snr_db} dB" for snr_db in SNRS]

    with gr.Blocks() as app:
        gr.Markdown(
            TEXTS.introduction
        )
        with gr.Column():
            plots = []
            # Two plots per row, in SNRS order.
            for start in range(0, len(initial_figs), 2):
                with gr.Row():
                    row_items = zip(initial_figs[start:start + 2], labels[start:start + 2])
                    for fig, label in row_items:
                        plots.append(gr.Plot(value=fig, label=label))
            with gr.Row():
                previous_button = gr.Button(value="< Previous")
                next_button = gr.Button(value="Next >")
            with gr.Row():
                slider = gr.Slider(
                    minimum=1,
                    maximum=N,
                    value=1,
                    step=1,
                    label="Sample Index (Snapshot)",
                )
        # update callbacks
        slider.change(make_plots, [slider], plots)
        # The buttons only update the slider value; the slider's change event
        # then redraws the plots.
        previous_button.click(button_previous, [slider], [slider])
        next_button.click(button_next, [slider], [slider])
        gr.Markdown(
            TEXTS.acknowledgements
        )
        gr.Markdown(
            TEXTS.contact
        )
    app.launch()
def main():
    """Run the model at every SNR level, store the results, then start the UI.

    For each ``SNRS`` value (insertion order) the tuple returned by
    ``RUNNER.run(snr=...)`` is written in place into the module-level
    ``DATA``, ``TRUTH`` and ``ESTIM`` buffers; afterwards ``demo()`` is called.

    Raises:
        RuntimeError: if ``RUNNER`` has not been initialised (it is ``None``
            until the ``__main__`` guard creates a Runner).
    """
    if RUNNER is None:
        raise RuntimeError(
            "RUNNER is not initialised; create a Runner before calling main()"
        )
    for slot, snr in enumerate(SNRS.values()):
        DATA[slot], TRUTH[slot], ESTIM[slot] = RUNNER.run(snr=snr)
    demo()
if __name__ == "__main__":
    # Install the private `deepest` package on first start if it cannot be imported.
    try:
        import deepest
    except ModuleNotFoundError:
        install_deepest()
    # These imports run only after the install step above. plot_parameters is
    # not used in the visible code (NOTE(review): possibly kept only as a check
    # that the install worked — confirm before removing).
    from deepest.utils import plot_parameters
    from helper import Runner
    # Rebinds the module-level RUNNER (initially None) that main() uses.
    RUNNER = Runner(MODEL_PATH, DATASET, BS, WORKER)
    main()