yonikremer commited on
Commit
5967916
1 Parent(s): 70d3eba

checks if the model is downloaded before downloading it

Browse files
Files changed (1) hide show
  1. hanlde_form_submit.py +19 -6
hanlde_form_submit.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from time import time
2
 
3
  import streamlit as st
@@ -11,6 +12,17 @@ from supported_models import get_supported_model_names
11
  SUPPORTED_MODEL_NAMES = get_supported_model_names()
12
 
13
 
 
 
 
 
 
 
 
 
 
 
 
14
  def create_pipeline(model_name: str, group_size: int) -> GroupedSamplingPipeLine:
15
  """
16
  Creates a pipeline with the given model name and group size.
@@ -18,12 +30,13 @@ def create_pipeline(model_name: str, group_size: int) -> GroupedSamplingPipeLine
18
  :param group_size: The size of the groups to use.
19
  :return: A pipeline with the given model name and group size.
20
  """
21
- download_repository_start_time = time()
22
- st.write(f"Starts downloading model: {model_name} from the internet.")
23
- download_repository(model_name)
24
- download_repository_end_time = time()
25
- download_time = download_repository_end_time - download_repository_start_time
26
- st.write(f"Finished downloading model: {model_name} from the internet in {download_time} seconds.")
 
27
  st.write(f"Starts creating pipeline with model: {model_name}")
28
  pipeline_start_time = time()
29
  pipeline = GroupedSamplingPipeLine(
 
1
+ import os
2
  from time import time
3
 
4
  import streamlit as st
 
12
  SUPPORTED_MODEL_NAMES = get_supported_model_names()
13
 
14
 
15
+ def is_downloaded(model_name: str) -> bool:
16
+ """
17
+ Checks if the model is downloaded.
18
+ :param model_name: The name of the model to check.
19
+ :return: True if the model is downloaded, False otherwise.
20
+ """
21
+ models_dir = "/root/.cache/huggingface/hub"
22
+ model_dir = os.path.join(models_dir, f"models--{model_name.replace('/', '--')}")
23
+ return os.path.isdir(model_dir)
24
+
25
+
26
  def create_pipeline(model_name: str, group_size: int) -> GroupedSamplingPipeLine:
27
  """
28
  Creates a pipeline with the given model name and group size.
 
30
  :param group_size: The size of the groups to use.
31
  :return: A pipeline with the given model name and group size.
32
  """
33
+ if not is_downloaded(model_name):
34
+ download_repository_start_time = time()
35
+ st.write(f"Starts downloading model: {model_name} from the internet.")
36
+ download_repository(model_name)
37
+ download_repository_end_time = time()
38
+ download_time = download_repository_end_time - download_repository_start_time
39
+ st.write(f"Finished downloading model: {model_name} from the internet in {download_time} seconds.")
40
  st.write(f"Starts creating pipeline with model: {model_name}")
41
  pipeline_start_time = time()
42
  pipeline = GroupedSamplingPipeLine(