Joshua Lochner commited on
Commit
00f77c2
1 Parent(s): fc2e81f

Update streamlit app to download the classifier and vectorizer from the hub

Browse files
Files changed (1) hide show
  1. app.py +6 -0
app.py CHANGED
@@ -10,6 +10,7 @@ import sys
10
  import os
11
  import json
12
  from urllib.parse import quote
 
13
 
14
  # Allow direct execution
15
  sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), 'src')) # noqa
@@ -32,6 +33,7 @@ st.set_page_config(
32
 
33
  MODEL_PATH = 'Xenova/sponsorblock-small_v2022.01.19'
34
 
 
35
 
36
  @st.cache(allow_output_mutation=True)
37
  def persistdata():
@@ -54,6 +56,10 @@ def load_predict():
54
 
55
  tokenizer = AutoTokenizer.from_pretrained(evaluation_args.model_path)
56
 
 
 
 
 
57
  def predict_function(video_id):
58
  if video_id not in predictions_cache:
59
  predictions_cache[video_id] = pred(
 
10
  import os
11
  import json
12
  from urllib.parse import quote
13
+ from huggingface_hub import hf_hub_download
14
 
15
  # Allow direct execution
16
  sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), 'src')) # noqa
 
33
 
34
  MODEL_PATH = 'Xenova/sponsorblock-small_v2022.01.19'
35
 
36
+ CLASSIFIER_PATH = 'Xenova/sponsorblock-classifier'
37
 
38
  @st.cache(allow_output_mutation=True)
39
  def persistdata():
 
56
 
57
  tokenizer = AutoTokenizer.from_pretrained(evaluation_args.model_path)
58
 
59
+ # Save classifier and vectorizer
60
+ hf_hub_download(repo_id=CLASSIFIER_PATH, filename=classifier_args.classifier_file, cache_dir=classifier_args.classifier_dir)
61
+ hf_hub_download(repo_id=CLASSIFIER_PATH, filename=classifier_args.vectorizer_file, cache_dir=classifier_args.classifier_dir)
62
+
63
  def predict_function(video_id):
64
  if video_id not in predictions_cache:
65
  predictions_cache[video_id] = pred(