Joshua Lochner committed on
Commit
e68b946
1 Parent(s): 29f5ee8

Add option to streamlit app for model selection

Browse files
Files changed (1) hide show
  1. app.py +35 -17
app.py CHANGED
@@ -20,8 +20,8 @@ from evaluate import EvaluationArguments
20
  from shared import device
21
 
22
  st.set_page_config(
23
- page_title="SponsorBlock ML",
24
- page_icon="🤖",
25
  # layout='wide',
26
  # initial_sidebar_state="expanded",
27
  menu_items={
@@ -30,8 +30,33 @@ st.set_page_config(
30
  # 'About': "# This is a header. This is an *extremely* cool app!"
31
  }
32
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
- MODEL_PATH = 'Xenova/sponsorblock-small'
 
 
 
 
35
 
36
  CLASSIFIER_PATH = 'Xenova/sponsorblock-classifier'
37
 
@@ -46,9 +71,9 @@ predictions_cache = persistdata()
46
 
47
 
48
  @st.cache(allow_output_mutation=True)
49
- def load_predict():
50
  # Use default segmentation and classification arguments
51
- evaluation_args = EvaluationArguments(model_path=MODEL_PATH)
52
  segmentation_args = SegmentationArguments()
53
  classifier_args = ClassifierArguments()
54
 
@@ -81,24 +106,17 @@ def load_predict():
81
  return predict_function
82
 
83
 
84
- CATGEGORY_OPTIONS = {
85
- 'SPONSOR': 'Sponsor',
86
- 'SELFPROMO': 'Self/unpaid promo',
87
- 'INTERACTION': 'Interaction reminder',
88
- }
89
-
90
-
91
- # Load prediction function
92
- predict = load_predict()
93
-
94
-
95
  def main():
96
 
97
  # Display heading and subheading
98
  st.write('# SponsorBlock ML')
99
  st.write('##### Automatically detect in-video YouTube sponsorships, self/unpaid promotions, and interaction reminders.')
100
 
101
- # Load widgets
 
 
 
 
102
  video_id = st.text_input('Video ID:') # , placeholder='e.g., axtQvkSpoto'
103
 
104
  categories = st.multiselect('Categories:',
 
20
  from shared import device
21
 
22
  st.set_page_config(
23
+ page_title='SponsorBlock ML',
24
+ page_icon='🤖',
25
  # layout='wide',
26
  # initial_sidebar_state="expanded",
27
  menu_items={
 
30
  # 'About': "# This is a header. This is an *extremely* cool app!"
31
  }
32
  )
33
+ # https://github.com/google-research/text-to-text-transfer-transformer#released-model-checkpoints
34
+ # https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#experimental-t5-pre-trained-model-checkpoints
35
+
36
+ # https://huggingface.co/docs/transformers/model_doc/t5
37
+ # https://huggingface.co/docs/transformers/model_doc/t5v1.1
38
+ MODELS = {
39
+ 'Small (77M)': {
40
+ 'pretrained': 'google/t5-v1_1-small',
41
+ 'repo_id': 'Xenova/sponsorblock-small',
42
+ },
43
+ 'Base v1 (220M)': {
44
+ 'pretrained': 't5-base',
45
+ 'repo_id': 'EColi/sponsorblock-base-v1',
46
+ },
47
+
48
+ 'Base v1.1 (250M)': {
49
+ 'pretrained': 'google/t5-v1_1-base',
50
+ 'repo_id': 'Xenova/sponsorblock-base',
51
+ }
52
+
53
+ }
54
 
55
+ CATGEGORY_OPTIONS = {
56
+ 'SPONSOR': 'Sponsor',
57
+ 'SELFPROMO': 'Self/unpaid promo',
58
+ 'INTERACTION': 'Interaction reminder',
59
+ }
60
 
61
  CLASSIFIER_PATH = 'Xenova/sponsorblock-classifier'
62
 
 
71
 
72
 
73
  @st.cache(allow_output_mutation=True)
74
+ def load_predict(model_path):
75
  # Use default segmentation and classification arguments
76
+ evaluation_args = EvaluationArguments(model_path=model_path)
77
  segmentation_args = SegmentationArguments()
78
  classifier_args = ClassifierArguments()
79
 
 
106
  return predict_function
107
 
108
 
 
 
 
 
 
 
 
 
 
 
 
109
  def main():
110
 
111
  # Display heading and subheading
112
  st.write('# SponsorBlock ML')
113
  st.write('##### Automatically detect in-video YouTube sponsorships, self/unpaid promotions, and interaction reminders.')
114
 
115
+ model_id = st.selectbox('Select model', MODELS.keys(), index=0)
116
+
117
+ # Load prediction function
118
+ predict = load_predict(MODELS[model_id]['repo_id'])
119
+
120
  video_id = st.text_input('Video ID:') # , placeholder='e.g., axtQvkSpoto'
121
 
122
  categories = st.multiselect('Categories:',