Joshua Lochner committed on
Commit
4822df2
1 Parent(s): ad7fc61

Create basic streamlit application

Browse files
Files changed (1) hide show
  1. app.py +154 -0
app.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from math import ceil, floor
3
+ import streamlit.components.v1 as components
4
+ from transformers import (
5
+ AutoModelForSeq2SeqLM,
6
+ AutoTokenizer,
7
+ )
8
+ import streamlit as st
9
+ import sys
10
+ import os
11
+ import json
12
+ from urllib.parse import quote
13
+
14
+ # Allow direct execution
15
+ sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), 'src')) # noqa
16
+
17
+ from predict import SegmentationArguments, ClassifierArguments, predict as pred, seconds_to_time # noqa
18
+ from evaluate import EvaluationArguments
19
+ from shared import device
20
+
21
# Entries for Streamlit's built-in menu: both point at the project's GitHub.
_PAGE_MENU = {
    'Get Help': 'https://github.com/xenova/sponsorblock-ml',
    'Report a bug': 'https://github.com/xenova/sponsorblock-ml/issues/new/choose',
}

# Page-level settings (browser tab title/icon and menu links). The default
# layout and sidebar state are kept.
st.set_page_config(
    page_title="SponsorBlock ML",
    page_icon="🤖",
    menu_items=_PAGE_MENU,
)

# Model identifier handed to `from_pretrained` when the model is loaded.
MODEL_PATH = 'Xenova/sponsorblock-small_v2022.01.19'
34
+
35
+
36
@st.cache(allow_output_mutation=True)
def persistdata():
    """Return the dictionary used as the prediction store.

    The function is wrapped in ``st.cache`` with
    ``allow_output_mutation=True``; the intent (per the comment on the
    module-level assignment) is a cheap, hash-free cache whose contents
    callers are free to mutate.
    """
    return dict()


# Faster caching system for predictions (No need to hash)
predictions_cache = persistdata()
43
+
44
+
45
@st.cache(allow_output_mutation=True)
def load_predict():
    """Load the model and tokenizer and return a caching predict function.

    Returns:
        A callable ``predict_function(video_id)`` that runs ``pred`` with the
        default segmentation and classifier arguments, storing each result in
        ``predictions_cache`` keyed by video ID so repeated requests for the
        same video skip the model.
    """
    # Default arguments everywhere; only the model path is overridden.
    eval_args = EvaluationArguments(model_path=MODEL_PATH)
    seg_args = SegmentationArguments()
    cls_args = ClassifierArguments()

    model = AutoModelForSeq2SeqLM.from_pretrained(eval_args.model_path)
    model.to(device())

    tokenizer = AutoTokenizer.from_pretrained(eval_args.model_path)

    def predict_function(video_id):
        """Return predictions for ``video_id``, computing them at most once."""
        if video_id in predictions_cache:
            return predictions_cache[video_id]

        result = pred(
            video_id, model, tokenizer,
            segmentation_args=seg_args,
            classifier_args=cls_args
        )
        predictions_cache[video_id] = result
        return predictions_cache[video_id]

    return predict_function
67
+
68
+
69
# Maps each prediction category key to the label shown in the UI.
# NOTE: the (misspelled) name is kept as-is because `main` references it.
CATGEGORY_OPTIONS = dict(
    SPONSOR='Sponsor',
    SELFPROMO='Self/unpaid promo',
    INTERACTION='Interaction reminder',
)


# Load prediction function
predict = load_predict()
78
+
79
+
80
def main():
    """Render the SponsorBlock ML page and the predictions for one video.

    Flow:
      1. Draw the heading and the input widgets (video ID, categories,
         confidence threshold).
      2. Validate the video ID; stop early when it is empty or malformed.
      3. Run the (cached) ``predict`` function for the video.
      4. For each prediction that passes the category and confidence
         filters, show an expander with an embedded player limited to the
         segment, plus its times, category, confidence and transcript text.
      5. Emit a "Submit Segments" link whose URL fragment carries the
         filtered segments as URL-quoted JSON.

    Returns:
        None. All output goes through Streamlit calls.
    """
    import re  # Local import so this block does not rely on file-level edits

    # Display heading and subheading
    st.write('# SponsorBlock ML')
    st.write('##### Automatically detect in-video YouTube sponsorships, self/unpaid promotions, and interaction reminders.')

    # Load widgets
    video_id = st.text_input('Video ID:', placeholder='e.g., axtQvkSpoto')

    # Pass concrete lists (not dict views) as both options and defaults.
    category_keys = list(CATGEGORY_OPTIONS)
    categories = st.multiselect(
        'Categories:',
        category_keys,
        category_keys,
        format_func=CATGEGORY_OPTIONS.get,
    )

    # Hide segments with a confidence lower than
    confidence_threshold = st.slider(
        'Confidence Threshold (%):', min_value=0, max_value=100)

    # Tolerate whitespace from copy/paste around the ID.
    video_id = video_id.strip()
    if not video_id:
        return

    # The ID is interpolated into URLs below, so check the alphabet as well
    # as the length: 11 characters from the URL-safe base64 set.
    if not re.fullmatch(r'[A-Za-z0-9_-]{11}', video_id):
        st.exception(ValueError('Invalid YouTube ID'))
        return

    with st.spinner('Running model...'):
        predictions = predict(video_id)

    if len(predictions) == 0:
        st.success('No segments found!')
        return

    submit_segments = []
    # `index` counts every prediction (including filtered-out ones), so a
    # prediction keeps the same number regardless of the active filters.
    for index, prediction in enumerate(predictions, start=1):
        category = prediction['category']
        if category not in categories:
            continue  # Skip

        confidence = prediction['probability'] * 100
        if confidence < confidence_threshold:
            continue

        segment_start = prediction['start']
        segment_end = prediction['end']

        submit_segments.append({
            'segment': [segment_start, segment_end],
            'category': category.lower(),
            'actionType': 'skip'
        })

        start_time = seconds_to_time(segment_start)
        end_time = seconds_to_time(segment_end)
        with st.expander(
            f"[{category}] Prediction #{index} ({start_time} \u2192 {end_time})"
        ):
            # Round outwards so the embedded clip covers the whole segment.
            url = (
                f"https://www.youtube-nocookie.com/embed/{video_id}"
                f"?start={floor(segment_start)}&end={ceil(segment_end)}"
            )
            components.iframe(url, width=670, height=376)

            text = ' '.join(w['text'] for w in prediction['words'])
            st.write(f"**Times:** {start_time} \u2192 {end_time}")
            st.write(f"**Category:** {CATGEGORY_OPTIONS[category]}")
            st.write(f"**Confidence:** {confidence:.2f}%")
            st.write(f'**Text:** "{text}"')

    if not submit_segments:
        # Everything was filtered out: say so instead of offering a link
        # that would submit an empty list.
        st.info('No segments match the selected categories and confidence threshold.')
        return

    # NOTE(review): the `#segments=` fragment is presumably consumed by the
    # SponsorBlock browser extension - confirm against its documentation.
    json_data = quote(json.dumps(submit_segments))
    link = f'[Submit Segments](https://www.youtube.com/watch?v={video_id}#segments={json_data})'
    # Plain Markdown link: raw HTML rendering is not needed.
    st.markdown(link)


if __name__ == '__main__':
    main()