Joshua Lochner commited on
Commit
14ea568
1 Parent(s): e68b946

Update prediction caching system to store predictions per model

Browse files
Files changed (1) hide show
  1. app.py +18 -15
app.py CHANGED
@@ -35,21 +35,31 @@ st.set_page_config(
35
 
36
  # https://huggingface.co/docs/transformers/model_doc/t5
37
  # https://huggingface.co/docs/transformers/model_doc/t5v1.1
 
 
 
 
 
 
 
 
38
  MODELS = {
39
  'Small (77M)': {
40
  'pretrained': 'google/t5-v1_1-small',
41
  'repo_id': 'Xenova/sponsorblock-small',
 
42
  },
43
  'Base v1 (220M)': {
44
  'pretrained': 't5-base',
45
  'repo_id': 'EColi/sponsorblock-base-v1',
 
46
  },
47
 
48
  'Base v1.1 (250M)': {
49
  'pretrained': 'google/t5-v1_1-base',
50
  'repo_id': 'Xenova/sponsorblock-base',
 
51
  }
52
-
53
  }
54
 
55
  CATGEGORY_OPTIONS = {
@@ -62,18 +72,11 @@ CLASSIFIER_PATH = 'Xenova/sponsorblock-classifier'
62
 
63
 
64
  @st.cache(allow_output_mutation=True)
65
- def persistdata():
66
- return {}
67
-
68
 
69
- # Faster caching system for predictions (No need to hash)
70
- predictions_cache = persistdata()
71
-
72
-
73
- @st.cache(allow_output_mutation=True)
74
- def load_predict(model_path):
75
  # Use default segmentation and classification arguments
76
- evaluation_args = EvaluationArguments(model_path=model_path)
77
  segmentation_args = SegmentationArguments()
78
  classifier_args = ClassifierArguments()
79
 
@@ -95,13 +98,13 @@ def load_predict(model_path):
95
  )
96
 
97
  def predict_function(video_id):
98
- if video_id not in predictions_cache:
99
- predictions_cache[video_id] = pred(
100
  video_id, model, tokenizer,
101
  segmentation_args=segmentation_args,
102
  classifier_args=classifier_args
103
  )
104
- return predictions_cache[video_id]
105
 
106
  return predict_function
107
 
@@ -115,7 +118,7 @@ def main():
115
  model_id = st.selectbox('Select model', MODELS.keys(), index=0)
116
 
117
  # Load prediction function
118
- predict = load_predict(MODELS[model_id]['repo_id'])
119
 
120
  video_id = st.text_input('Video ID:') # , placeholder='e.g., axtQvkSpoto'
121
 
 
35
 
36
  # https://huggingface.co/docs/transformers/model_doc/t5
37
  # https://huggingface.co/docs/transformers/model_doc/t5v1.1
38
+
39
+
40
+ # Faster caching system for predictions (No need to hash)
41
+ @st.cache(allow_output_mutation=True)
42
+ def persistdata():
43
+ return {}
44
+
45
+
46
  MODELS = {
47
  'Small (77M)': {
48
  'pretrained': 'google/t5-v1_1-small',
49
  'repo_id': 'Xenova/sponsorblock-small',
50
+ 'cache': persistdata()
51
  },
52
  'Base v1 (220M)': {
53
  'pretrained': 't5-base',
54
  'repo_id': 'EColi/sponsorblock-base-v1',
55
+ 'cache': persistdata()
56
  },
57
 
58
  'Base v1.1 (250M)': {
59
  'pretrained': 'google/t5-v1_1-base',
60
  'repo_id': 'Xenova/sponsorblock-base',
61
+ 'cache': persistdata()
62
  }
 
63
  }
64
 
65
  CATGEGORY_OPTIONS = {
 
72
 
73
 
74
  @st.cache(allow_output_mutation=True)
75
+ def load_predict(model_id):
76
+ model = MODELS[model_id]
 
77
 
 
 
 
 
 
 
78
  # Use default segmentation and classification arguments
79
+ evaluation_args = EvaluationArguments(model_path=model['repo_id'])
80
  segmentation_args = SegmentationArguments()
81
  classifier_args = ClassifierArguments()
82
 
 
98
  )
99
 
100
  def predict_function(video_id):
101
+ if video_id not in model['cache']:
102
+ model['cache'][video_id] = pred(
103
  video_id, model, tokenizer,
104
  segmentation_args=segmentation_args,
105
  classifier_args=classifier_args
106
  )
107
+ return model['cache'][video_id]
108
 
109
  return predict_function
110
 
 
118
  model_id = st.selectbox('Select model', MODELS.keys(), index=0)
119
 
120
  # Load prediction function
121
+ predict = load_predict(model_id)
122
 
123
  video_id = st.text_input('Video ID:') # , placeholder='e.g., axtQvkSpoto'
124