IanRonk commited on
Commit
d810840
1 Parent(s): 1031728

Add time conversions from outputs

Browse files
app.py CHANGED
@@ -3,6 +3,7 @@ import re
3
  import gradio as gr
4
  from functions.punctuation import punctuate
5
  from functions.model_infer import predict_from_document
 
6
 
7
 
8
  title = "sponsoredBye - never listen to sponsors again"
@@ -12,16 +13,18 @@ article = "Check out [the original Rick and Morty Bot](https://huggingface.co/sp
12
 
13
  def pipeline(video_url):
14
  video_id = video_url.split("?v=")[-1]
15
- punctuated_text = punctuate(video_id)
16
  sentences = re.split(r"[\.\!\?]\s", punctuated_text)
17
  classification, probs = predict_from_document(sentences)
18
  # return punctuated_text
 
19
  return [
20
  {
21
  "start": "12:05",
22
  "end": "12:52",
23
  "classification": str(classification),
24
  "probabilities": probs,
 
25
  }
26
  ]
27
 
 
3
  import gradio as gr
4
  from functions.punctuation import punctuate
5
  from functions.model_infer import predict_from_document
6
+ from functions.convert_time import match_mask_and_transcript
7
 
8
 
9
  title = "sponsoredBye - never listen to sponsors again"
 
13
 
14
  def pipeline(video_url):
15
  video_id = video_url.split("?v=")[-1]
16
+ punctuated_text, transcript = punctuate(video_id)
17
  sentences = re.split(r"[\.\!\?]\s", punctuated_text)
18
  classification, probs = predict_from_document(sentences)
19
  # return punctuated_text
20
+ times = match_mask_and_transcript(sentences, transcript, classification)
21
  return [
22
  {
23
  "start": "12:05",
24
  "end": "12:52",
25
  "classification": str(classification),
26
  "probabilities": probs,
27
+ "times": times,
28
  }
29
  ]
30
 
functions/convert_time.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from thefuzz import fuzz
3
+ import numpy as np
4
+
5
+
6
+ def match_mask_and_transcript(split_punct, transcript, classification):
7
+ """
8
+ Input:
9
+ split_punct: the punctuated text, split on ?/!/.\s,
10
+ transcript: original transcript with timestamps
11
+ classification: classification object (list of numbers 0,1)
12
+ Output: times
13
+ """
14
+
15
+ # Get the sponsored part
16
+ sponsored_segment = []
17
+ for i, val in enumerate(classification):
18
+ if val == 1:
19
+ sponsored_segment.append(split_punct[i])
20
+
21
+ segment = " ".join(sponsored_segment)
22
+ sim_scores = list()
23
+
24
+ # Check the similarity scores between the sponsored part and the transcript parts
25
+ for elem in transcript:
26
+ sim_scores.append(fuzz.partial_ratio(segment, elem["text"]))
27
+
28
+ # Get the scores and check if they are above mean + 2*stdev
29
+ scores = np.array(sim_scores)
30
+ timestamp_mask = (scores > np.mean(scores) + np.std(scores) * 2).astype(int)
31
+ timestamps = [
32
+ (transcript[i]["start"], transcript[i]["duration"])
33
+ for i, elem in enumerate(timestamp_mask)
34
+ if elem == 1
35
+ ]
36
+
37
+ # Get the timestamp segments
38
+ times = []
39
+ current = -1
40
+ current_time = 0
41
+ for elem in timestamps:
42
+ # Threshold of 5 to see if it is a jump to another segment (also to make sure smaller segments are added together
43
+ if elem[0] > (current_time + 5):
44
+ current += 1
45
+ times.append((elem[0], elem[0] + elem[1]))
46
+ current_time = elem[0] + elem[1]
47
+ else:
48
+ times[current] = (times[current][0], elem[0] + elem[1])
49
+ current_time = elem[0] + elem[1]
50
+ return times
functions/model_infer.py CHANGED
@@ -41,6 +41,6 @@ def predict_from_document(sentences):
41
  # Set the prediction threshold to 0.8 instead of 0.5, now use mean
42
  output = (
43
  prediction.flatten()[: len(sentences)]
44
- >= np.mean(prediction) + np.var(prediction) * 2
45
  ).astype(int)
46
  return output, prediction.flatten()[: len(sentences)]
 
41
  # Set the prediction threshold to 0.8 instead of 0.5, now use mean
42
  output = (
43
  prediction.flatten()[: len(sentences)]
44
+ >= np.mean(prediction) + np.std(prediction) * 2
45
  ).astype(int)
46
  return output, prediction.flatten()[: len(sentences)]
functions/punctuation.py CHANGED
@@ -55,4 +55,4 @@ def punctuate(video_id):
55
  ) # Get the transcript from the YoutubeTranscriptApi
56
  resp = query_punctuation(splits) # Get the response from the Inference API
57
  punctuated_transcript = parse_output(resp, splits)
58
- return punctuated_transcript
 
55
  ) # Get the transcript from the YoutubeTranscriptApi
56
  resp = query_punctuation(splits) # Get the response from the Inference API
57
  punctuated_transcript = parse_output(resp, splits)
58
+ return punctuated_transcript, transcript
requirements.txt CHANGED
@@ -1,4 +1,6 @@
1
  youtube_transcript_api
 
 
2
  tensorflow==2.15
3
  keras
4
  keras-nlp
 
1
  youtube_transcript_api
2
+ thefuzz
3
+ numpy
4
  tensorflow==2.15
5
  keras
6
  keras-nlp