Joshua Lochner committed on
Commit
a9123fa
1 Parent(s): fb87012

Remove duplicated methods from streamlit app

Browse files
Files changed (1) hide show
  1. app.py +9 -27
app.py CHANGED
@@ -7,16 +7,15 @@ import sys
7
  import os
8
  import json
9
  from urllib.parse import quote
10
- from huggingface_hub import hf_hub_download
11
 
12
  # Allow direct execution
13
  sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), 'src')) # noqa
14
 
15
- from predict import SegmentationArguments, ClassifierArguments, predict as pred, seconds_to_time # noqa
16
  from evaluate import EvaluationArguments
17
- from shared import device, CATGEGORY_OPTIONS
18
  from utils import regex_search
19
- from model import get_model_tokenizer
20
 
21
  st.set_page_config(
22
  page_title='SponsorBlock ML',
@@ -106,22 +105,6 @@ for m in MODELS:
106
  CLASSIFIER_PATH = 'Xenova/sponsorblock-classifier'
107
 
108
 
109
- @st.cache(persist=True, allow_output_mutation=True)
110
- def download_classifier(classifier_args):
111
- # Save classifier and vectorizer
112
- hf_hub_download(repo_id=CLASSIFIER_PATH,
113
- filename=classifier_args.classifier_file,
114
- cache_dir=classifier_args.classifier_dir,
115
- force_filename=classifier_args.classifier_file,
116
- )
117
- hf_hub_download(repo_id=CLASSIFIER_PATH,
118
- filename=classifier_args.vectorizer_file,
119
- cache_dir=classifier_args.classifier_dir,
120
- force_filename=classifier_args.vectorizer_file,
121
- )
122
- return True
123
-
124
-
125
  def predict_function(model_id, model, tokenizer, segmentation_args, classifier_args, video_id):
126
  if video_id not in prediction_cache[model_id]:
127
  prediction_cache[model_id][video_id] = pred(
@@ -139,12 +122,11 @@ def load_predict(model_id):
139
  # Use default segmentation and classification arguments
140
  evaluation_args = EvaluationArguments(model_path=model_info['repo_id'])
141
  segmentation_args = SegmentationArguments()
142
- classifier_args = ClassifierArguments()
 
143
 
144
  model, tokenizer = get_model_tokenizer(evaluation_args.model_path)
145
 
146
- download_classifier(classifier_args)
147
-
148
  prediction_function_cache[model_id] = partial(
149
  predict_function, model_id, model, tokenizer, segmentation_args, classifier_args)
150
 
@@ -157,7 +139,8 @@ def main():
157
 
158
  # Display heading and subheading
159
  top.markdown('# SponsorBlock ML')
160
- top.markdown('##### Automatically detect in-video YouTube sponsorships, self/unpaid promotions, and interaction reminders.')
 
161
 
162
  # Add controls
163
  model_id = top.selectbox(
@@ -174,8 +157,7 @@ def main():
174
 
175
  # Hide segments with a confidence lower than
176
  confidence_threshold = top.slider(
177
- 'Confidence Threshold (%):', min_value=0, max_value=100, on_change=output.empty)
178
-
179
 
180
  if len(video_input) == 0: # No input, do not continue
181
  return
@@ -184,7 +166,7 @@ def main():
184
  with st.spinner('Loading model...'):
185
  predict = load_predict(model_id)
186
 
187
- with output.container(): # Place all content in output container
188
  video_id = regex_search(video_input, YT_VIDEO_REGEX)
189
  if video_id is None:
190
  st.exception(ValueError('Invalid YouTube URL/ID'))
 
7
  import os
8
  import json
9
  from urllib.parse import quote
 
10
 
11
  # Allow direct execution
12
  sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), 'src')) # noqa
13
 
14
+ from predict import SegmentationArguments, ClassifierArguments, predict as pred # noqa
15
  from evaluate import EvaluationArguments
16
+ from shared import seconds_to_time, CATGEGORY_OPTIONS
17
  from utils import regex_search
18
+ from model import get_model_tokenizer, get_classifier_vectorizer
19
 
20
  st.set_page_config(
21
  page_title='SponsorBlock ML',
 
105
  CLASSIFIER_PATH = 'Xenova/sponsorblock-classifier'
106
 
107
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  def predict_function(model_id, model, tokenizer, segmentation_args, classifier_args, video_id):
109
  if video_id not in prediction_cache[model_id]:
110
  prediction_cache[model_id][video_id] = pred(
 
122
  # Use default segmentation and classification arguments
123
  evaluation_args = EvaluationArguments(model_path=model_info['repo_id'])
124
  segmentation_args = SegmentationArguments()
125
+ classifier_args = ClassifierArguments(
126
+ min_probability=0) # Filtering done later
127
 
128
  model, tokenizer = get_model_tokenizer(evaluation_args.model_path)
129
 
 
 
130
  prediction_function_cache[model_id] = partial(
131
  predict_function, model_id, model, tokenizer, segmentation_args, classifier_args)
132
 
 
139
 
140
  # Display heading and subheading
141
  top.markdown('# SponsorBlock ML')
142
+ top.markdown(
143
+ '##### Automatically detect in-video YouTube sponsorships, self/unpaid promotions, and interaction reminders.')
144
 
145
  # Add controls
146
  model_id = top.selectbox(
 
157
 
158
  # Hide segments with a confidence lower than
159
  confidence_threshold = top.slider(
160
+ 'Confidence Threshold (%):', min_value=0, value=50, max_value=100, on_change=output.empty)
 
161
 
162
  if len(video_input) == 0: # No input, do not continue
163
  return
 
166
  with st.spinner('Loading model...'):
167
  predict = load_predict(model_id)
168
 
169
+ with output.container(): # Place all content in output container
170
  video_id = regex_search(video_input, YT_VIDEO_REGEX)
171
  if video_id is None:
172
  st.exception(ValueError('Invalid YouTube URL/ID'))