Spaces:
Build error
Build error
taquynhnga
commited on
Commit
•
78866a7
1
Parent(s):
14d54e2
Update backend/utils.py
Browse files- backend/utils.py +24 -21
backend/utils.py
CHANGED
@@ -17,14 +17,15 @@ from tqdm import trange
|
|
17 |
import torch
|
18 |
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
|
19 |
|
20 |
-
|
21 |
-
@st.cache_resource
|
22 |
def load_dataset(data_index):
|
23 |
with open(f'./data/preprocessed_image_net/val_data_{data_index}.pkl', 'rb') as file:
|
24 |
dataset = pickle.load(file)
|
25 |
return dataset
|
26 |
|
27 |
-
@st.
|
|
|
28 |
def load_dataset_dict():
|
29 |
dataset_dict = {}
|
30 |
progress_empty = st.empty()
|
@@ -39,13 +40,15 @@ def load_dataset_dict():
|
|
39 |
return dataset_dict
|
40 |
|
41 |
|
42 |
-
@st.cache_data
|
|
|
43 |
def load_image(image_id):
|
44 |
dataset = load_dataset(image_id//10000)
|
45 |
image = dataset[image_id%10000]
|
46 |
return image
|
47 |
|
48 |
-
@st.cache_data
|
|
|
49 |
def load_images(image_ids):
|
50 |
images = []
|
51 |
for image_id in image_ids:
|
@@ -54,8 +57,8 @@ def load_images(image_ids):
|
|
54 |
return images
|
55 |
|
56 |
|
57 |
-
|
58 |
-
@st.cache_resource
|
59 |
def load_model(model_name):
|
60 |
with st.spinner(f"Loading {model_name} model! This process might take 1-2 minutes..."):
|
61 |
if model_name == 'ResNet':
|
@@ -356,21 +359,21 @@ def _set_block_container_style(
|
|
356 |
)
|
357 |
|
358 |
|
359 |
-
@st.cache
|
360 |
-
def get_dataframe() -> pd.DataFrame():
|
361 |
-
|
362 |
-
|
363 |
-
|
364 |
-
|
365 |
-
|
366 |
-
|
367 |
-
|
368 |
|
369 |
|
370 |
-
def get_plotly_fig():
|
371 |
-
|
372 |
-
|
373 |
|
374 |
|
375 |
-
def get_matplotlib_plt():
|
376 |
-
|
|
|
17 |
import torch
|
18 |
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
|
19 |
|
20 |
+
@st.cache(allow_output_mutation=True)
|
21 |
+
# @st.cache_resource
|
22 |
def load_dataset(data_index):
|
23 |
with open(f'./data/preprocessed_image_net/val_data_{data_index}.pkl', 'rb') as file:
|
24 |
dataset = pickle.load(file)
|
25 |
return dataset
|
26 |
|
27 |
+
@st.cache(allow_output_mutation=True)
|
28 |
+
# @st.cache_resource
|
29 |
def load_dataset_dict():
|
30 |
dataset_dict = {}
|
31 |
progress_empty = st.empty()
|
|
|
40 |
return dataset_dict
|
41 |
|
42 |
|
43 |
+
# @st.cache_data
|
44 |
+
@st.cache(allow_output_mutation=True)
|
45 |
def load_image(image_id):
|
46 |
dataset = load_dataset(image_id//10000)
|
47 |
image = dataset[image_id%10000]
|
48 |
return image
|
49 |
|
50 |
+
# @st.cache_data
|
51 |
+
@st.cache(allow_output_mutation=True)
|
52 |
def load_images(image_ids):
|
53 |
images = []
|
54 |
for image_id in image_ids:
|
|
|
57 |
return images
|
58 |
|
59 |
|
60 |
+
@st.cache(allow_output_mutation=True, suppress_st_warning=True, show_spinner=False)
|
61 |
+
# @st.cache_resource
|
62 |
def load_model(model_name):
|
63 |
with st.spinner(f"Loading {model_name} model! This process might take 1-2 minutes..."):
|
64 |
if model_name == 'ResNet':
|
|
|
359 |
)
|
360 |
|
361 |
|
362 |
+
# @st.cache
|
363 |
+
# def get_dataframe() -> pd.DataFrame():
|
364 |
+
# """Dummy DataFrame"""
|
365 |
+
# data = [
|
366 |
+
# {"quantity": 1, "price": 2},
|
367 |
+
# {"quantity": 3, "price": 5},
|
368 |
+
# {"quantity": 4, "price": 8},
|
369 |
+
# ]
|
370 |
+
# return pd.DataFrame(data)
|
371 |
|
372 |
|
373 |
+
# def get_plotly_fig():
|
374 |
+
# """Dummy Plotly Plot"""
|
375 |
+
# return px.line(data_frame=get_dataframe(), x="quantity", y="price")
|
376 |
|
377 |
|
378 |
+
# def get_matplotlib_plt():
|
379 |
+
# get_dataframe().plot(kind="line", x="quantity", y="price", figsize=(5, 3))
|