Jonathan Li committed on
Commit
1ebc0dd
·
1 Parent(s): 2bceb77

Revert "Add broken streamlit (no way to mark sponsors?)"

Browse files

This reverts commit 2bceb77e414dd0e5ef1400dec9e5731109481697.

Files changed (2) hide show
  1. app.py +65 -119
  2. requirements.txt +1 -1
app.py CHANGED
@@ -1,26 +1,17 @@
1
- import re
2
- import streamlit as st
3
  import requests
4
  from transformers import AutoTokenizer, pipeline
5
  from youtube_transcript_api._transcripts import TranscriptListFetcher
6
 
7
  tagger = pipeline(
8
- "token-classification", "./checkpoint-6000", aggregation_strategy="first",
 
 
9
  )
10
  tokenizer = AutoTokenizer.from_pretrained("./checkpoint-6000")
11
  max_size = 512
12
  classes = [False, True]
13
 
14
- pattern = re.compile(
15
- r"(?:https?:\/\/)?(?:[0-9A-Z-]+\.)?(?:youtube|youtu|youtube-nocookie)\.(?:com|be)\/(?:watch\?v=|watch\?.+&v=|embed\/|v\/|.+\?v=)?([^&=\n%\?]{11})"
16
- )
17
-
18
-
19
- def video_id(url):
20
- p = pattern.match(url)
21
- return p.group(1)
22
-
23
-
24
  def process(obj):
25
  o = obj["events"]
26
  new_l = []
@@ -56,7 +47,6 @@ def process(obj):
56
 
57
  return new_l
58
 
59
-
60
  def get_transcript(video_id, session):
61
  fetcher = TranscriptListFetcher(session)
62
  _json = fetcher._extract_captions_json(
@@ -75,113 +65,69 @@ def get_transcript(video_id, session):
75
  p = process(obj.json())
76
  return p
77
 
78
-
79
  def transcript(video_id):
80
- return " ".join(
81
- l["w"].strip() for l in get_transcript(video_id, requests.Session())
82
- )
83
-
84
 
85
  def inference(transcript):
86
- tokens = tokenizer(transcript.split(" "))["input_ids"]
87
- current_length = 0
88
- current_word_length = 0
89
- batches = []
90
- for i, w in enumerate(tokens):
91
- word = w[:-1] if i == 0 else w[1:] if i == (len(tokens) - 1) else w[1:-1]
92
- if (current_length + len(word)) > max_size:
93
- batch = " ".join(
94
- tokenizer.batch_decode(
95
- [
96
- tok[1:-1]
97
- for tok in tokens[max(0, i - current_word_length - 1) : i]
98
- ]
99
- )
100
- )
101
- batches.append(batch)
102
- current_word_length = 0
103
- current_length = 0
104
- continue
105
- current_length += len(word)
106
- current_word_length += 1
107
- if current_length > 0:
108
- batches.append(
109
- " ".join(
110
- tokenizer.batch_decode(
111
- [tok[1:-1] for tok in tokens[i - current_word_length :]]
112
- )
113
- )
114
- )
115
-
116
- results = []
117
- for split in batches:
118
- values = tagger(split)
119
- results.extend(
120
- {"sponsor": v["entity_group"] == "LABEL_1", "phrase": v["word"],}
121
- for v in values
122
- )
123
-
124
- return results
125
-
 
 
126
 
127
  def predict(transcript):
128
- return [
129
- (span["phrase"], "Sponsor" if span["sponsor"] else None)
130
- for span in inference(transcript)
131
- ]
132
-
133
-
134
- st.title("reBlock (AI Sponsor Detector)")
135
-
136
- load_data, run_ai = st.container(), st.container()
137
- load_data.subheader("Load transcript:")
138
- run_ai.subheader("Predict sponsors:")
139
-
140
- if "transcript" not in st.session_state:
141
- st.session_state["transcript"] = ""
142
- if "url" not in st.session_state:
143
- st.session_state["url"] = ""
144
-
145
- def submit(url):
146
- if url:
147
- ts = transcript(video_id(url))
148
- st.session_state.transcript = ts
149
- st.session_state.url = url
150
- else:
151
- st.error(
152
- "Invalid youtube url. Take a look at the examples for a supported format"
153
- )
154
-
155
- with load_data:
156
- with st.form(key="load_transcript"):
157
- url = st.text_input("Youtube Video URL", key="url")
158
- submitted = st.form_submit_button("Get Transcript", on_click=lambda: submit(url))
159
- transcript_text_area = st.text_area("Scraped Transcript", key="transcript")
160
-
161
- st.caption("Or, try an example:")
162
- examples = ["youtu.be/xsLJZyih3Ac"]
163
- col = st.columns(len(examples))
164
- for i, example in enumerate(examples):
165
- col[i] = st.button(example, on_click=lambda: submit(example))
166
-
167
- with run_ai:
168
- with st.form(key="run_ai"):
169
- submitted = st.form_submit_button("Predict Sponsors!")
170
-
171
- # read_transcript = st.text("Reading...")
172
-
173
- # with gr.Blocks() as demo:
174
- # with gr.Row():
175
- # with gr.Column():
176
- # inp = gr.Textbox(label="Video URL", placeholder="Video url", lines=1, max_lines=1)
177
- # btn = gr.Button("Fetch Transcript")
178
- # gr.Examples(["youtu.be/xsLJZyih3Ac"], [inp])
179
- # text = gr.Textbox(label="Transcript", placeholder="<generated transcript>")
180
- # btn.click(fn=transcript, inputs=inp, outputs=text)
181
- # with gr.Column():
182
- # p = gr.Button("Predict Sponsors")
183
- # highlight = gr.HighlightedText()
184
- # p.click(fn=predict, inputs=text, outputs=highlight)
185
-
186
-
187
- # demo.launch()
 
1
+ import gradio as gr
 
2
  import requests
3
  from transformers import AutoTokenizer, pipeline
4
  from youtube_transcript_api._transcripts import TranscriptListFetcher
5
 
6
  tagger = pipeline(
7
+ "token-classification",
8
+ "./checkpoint-6000",
9
+ aggregation_strategy="first",
10
  )
11
  tokenizer = AutoTokenizer.from_pretrained("./checkpoint-6000")
12
  max_size = 512
13
  classes = [False, True]
14
 
 
 
 
 
 
 
 
 
 
 
15
  def process(obj):
16
  o = obj["events"]
17
  new_l = []
 
47
 
48
  return new_l
49
 
 
50
  def get_transcript(video_id, session):
51
  fetcher = TranscriptListFetcher(session)
52
  _json = fetcher._extract_captions_json(
 
65
  p = process(obj.json())
66
  return p
67
 
 
68
  def transcript(video_id):
69
+ return " ".join(l["w"].strip() for l in get_transcript(video_id, requests.Session()))
 
 
 
70
 
71
  def inference(transcript):
72
+ tokens = tokenizer(transcript.split(" "))["input_ids"]
73
+ current_length = 0
74
+ current_word_length = 0
75
+ batches = []
76
+ for i, w in enumerate(tokens):
77
+ word = w[:-1] if i == 0 else w[1:] if i == (len(tokens) - 1) else w[1:-1]
78
+ if (current_length + len(word)) > max_size:
79
+ batch = " ".join(
80
+ tokenizer.batch_decode(
81
+ [
82
+ tok[1:-1]
83
+ for tok in tokens[max(0, i - current_word_length - 1) : i]
84
+ ]
85
+ )
86
+ )
87
+ batches.append(batch)
88
+ current_word_length = 0
89
+ current_length = 0
90
+ continue
91
+ current_length += len(word)
92
+ current_word_length += 1
93
+ if current_length > 0:
94
+ batches.append(
95
+ " ".join(
96
+ tokenizer.batch_decode(
97
+ [tok[1:-1] for tok in tokens[i - current_word_length :]]
98
+ )
99
+ )
100
+ )
101
+
102
+ results = []
103
+ for split in batches:
104
+ values = tagger(split)
105
+ results.extend(
106
+ {
107
+ "sponsor": v["entity_group"] == "LABEL_1",
108
+ "phrase": v["word"],
109
+ }
110
+ for v in values
111
+ )
112
+
113
+ return results
114
 
115
  def predict(transcript):
116
+ return [(span["phrase"], "Sponsor" if span["sponsor"] else None) for span in inference(transcript)]
117
+
118
+
119
+ with gr.Blocks() as demo:
120
+ with gr.Row():
121
+ with gr.Column():
122
+ inp = gr.Textbox(label="Video ID or URL", placeholder="Video id", lines=1, max_lines=1)
123
+ btn = gr.Button("Fetch Transcript")
124
+ gr.Examples(["xsLJZyih3Ac"], [inp])
125
+ text = gr.Textbox(label="Transcript", placeholder="<generated transcript>")
126
+ btn.click(fn=transcript, inputs=inp, outputs=text)
127
+ with gr.Column():
128
+ p = gr.Button("Predict Sponsors")
129
+ highlight = gr.HighlightedText()
130
+ p.click(fn=predict, inputs=text, outputs=highlight)
131
+
132
+
133
+ demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -3,4 +3,4 @@ youtube_transcript_api
3
  torch
4
  pandas
5
  numpy
6
- streamlit
 
3
  torch
4
  pandas
5
  numpy
6
+ gradio