aapot commited on
Commit
f3772cc
1 Parent(s): 8b86424

Add demo application

Browse files
README.md CHANGED
@@ -1,6 +1,6 @@
1
  ---
2
  title: Youtube Video Similarity
3
- emoji: 👀
4
  colorFrom: purple
5
  colorTo: blue
6
  sdk: gradio
@@ -8,6 +8,9 @@ sdk_version: 3.3.1
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
 
 
 
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
  title: Youtube Video Similarity
3
+ emoji:
4
  colorFrom: purple
5
  colorTo: blue
6
  sdk: gradio
 
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
11
+ models:
12
+ - mozilla-foundation/youtube_video_similarity_model_wt
13
+ - mozilla-foundation/youtube_video_similarity_model_nt
14
  ---
15
 
16
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ import torch
4
+ from utils.unifiedmodel import RRUMDataset
5
+ from utils.huggingface_model_wrapper import YoutubeVideoSimilarityModel
6
+ from torch.utils.data import DataLoader
7
+ from helpers import get_example_videos, update_youtube_embedded_html, get_input_data_df
8
+
9
+ RR_EXAMPLES_URL = os.environ.get(
10
+ 'RR_EXAMPLES_URL', 'https://public-data.telemetry.mozilla.org/api/v1/tables/telemetry_derived/regrets_reporter_study/v1/files/000000000000.json')
11
+ NUM_RR_EXAMPLES = 5
12
+ example_videos, example_videos_rr = get_example_videos(
13
+ RR_EXAMPLES_URL, NUM_RR_EXAMPLES)
14
+
15
+ demo_title = 'Mozilla RegretsReporter YouTube video similarity'
16
+ demo_description = f'''
17
+ # {demo_title}
18
+
19
+ This demo showcases the YouTube video semantic similarity model developed as part of the RegretsReporter research project at Mozilla Foundation. You can read more about the project [here](https://foundation.mozilla.org/en/youtube/user-controls/) and about the semantic similarity model [here](https://foundation.mozilla.org/en/blog/the-regretsreporter-user-controls-study-machine-learning-to-measure-semantic-similarity-of-youtube-videos/). Note: the model is multilingual so you can try it with non-English videos too while it probably works the best with English videos.
20
+
21
+ This demo works by inserting two YouTube video URLs below and clicking the Run button. After a few seconds, you will see model's predicted probability of how similar those two videos are. You can copy URLs from YouTube or also try out a few predefined examples by clicking them on the examples table.
22
+ '''
23
+
24
+ placeholder_youtube_embedded_html = '''
25
+ <p>Insert video URL first</p>
26
+ '''
27
+
28
+
29
+ model_wt = YoutubeVideoSimilarityModel.from_pretrained(
30
+ 'mozilla-foundation/youtube_video_similarity_model_wt', use_auth_token=True)
31
+ model_nt = YoutubeVideoSimilarityModel.from_pretrained(
32
+ 'mozilla-foundation/youtube_video_similarity_model_nt', use_auth_token=True)
33
+ cross_encoder_model_name_or_path = model_wt.cross_encoder_model_name_or_path
34
+
35
+
36
+ def get_video_similarity(video1_url, video2_url):
37
+ df = get_input_data_df(video1_url, video2_url)
38
+ if df['regret_transcript'].isna().any() or df['recommendation_transcript'].isna().any():
39
+ with_transcript = False
40
+ else:
41
+ with_transcript = True
42
+ dataset = RRUMDataset(df, with_transcript=with_transcript, label_col=None,
43
+ cross_encoder_model_name_or_path=cross_encoder_model_name_or_path)
44
+ data_loader = DataLoader(dataset.test_dataset, shuffle=False,
45
+ batch_size=1, num_workers=0, pin_memory=False)
46
+
47
+ with torch.inference_mode():
48
+ if with_transcript:
49
+ pred = model_wt(next(iter(data_loader)))
50
+ else:
51
+ pred = model_nt(next(iter(data_loader)))
52
+ pred = torch.special.expit(pred).squeeze().tolist()
53
+ return f'YouTube videos are {pred:.0%} similar'
54
+
55
+
56
+ with gr.Blocks(title=demo_title) as demo:
57
+ gr.Markdown(demo_description)
58
+ with gr.Row():
59
+ with gr.Column():
60
+ input_text1 = gr.Textbox(
61
+ label='Video 1', placeholder='Insert first YouTube video URL')
62
+ input_text2 = gr.Textbox(
63
+ label='Video 2', placeholder='Insert second YouTube video URL')
64
+ inputs = [input_text1, input_text2]
65
+ with gr.Row():
66
+ clear_btn = gr.Button('Clear', variant='secondary')
67
+ run_btn = gr.Button('Run', variant='primary')
68
+ with gr.Column():
69
+ outputs = [gr.Label(label='Model prediction')]
70
+ with gr.Accordion('See video details', open=False):
71
+ with gr.Row():
72
+ with gr.Column():
73
+ video_embedded = gr.HTML(
74
+ value=placeholder_youtube_embedded_html)
75
+ with gr.Column():
76
+ video_embedded2 = gr.HTML(
77
+ value=placeholder_youtube_embedded_html)
78
+ with gr.Column():
79
+ if example_videos:
80
+ examples = gr.Examples(examples=example_videos, inputs=inputs)
81
+ if example_videos_rr:
82
+ examples_rr = gr.Examples(examples=example_videos_rr, inputs=inputs,
83
+ label='Example bad becommendations from the RegretsReporter report')
84
+
85
+ run_btn.click(fn=get_video_similarity, inputs=inputs, outputs=outputs)
86
+ clear_btn.click(lambda value_1, value_2, value_3: (
87
+ None, None, None), inputs=inputs + outputs, outputs=inputs + outputs)
88
+
89
+ input_text1.change(lambda input: update_youtube_embedded_html(
90
+ input, 1) if input else placeholder_youtube_embedded_html, inputs=input_text1, outputs=video_embedded)
91
+ input_text2.change(lambda input: update_youtube_embedded_html(
92
+ input, 2) if input else placeholder_youtube_embedded_html, inputs=input_text2, outputs=video_embedded2)
93
+
94
+ demo.launch()
helpers.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+ import random
3
+ import requests
4
+ import pandas as pd
5
+ from pytube import YouTube
6
+ from youtube_transcript_api import YouTubeTranscriptApi
7
+ from youtube_transcript_api.formatters import TextFormatter
8
+
9
+
10
+ def is_youtube_video_available(url):
11
+ video = YouTube(url)
12
+ try:
13
+ video.title
14
+ return True
15
+ except:
16
+ return False
17
+
18
+
19
+ def get_example_videos(rr_examples_url, num_rr_examples):
20
+ example_videos = [['https://www.youtube.com/watch?v=WfVF-Ec4naQ', 'https://www.youtube.com/watch?v=4hrNt28t7Cw'],
21
+ ['https://www.youtube.com/watch?v=GbpjLP-UvIU',
22
+ 'https://www.youtube.com/watch?v=BlQ2mP2EE4A'],
23
+ ['https://www.youtube.com/watch?v=fdzY1f2P91k',
24
+ 'https://www.youtube.com/watch?v=BlQ2mP2EE4A'],
25
+ ['https://www.youtube.com/watch?v=fdzY1f2P91k', 'https://www.youtube.com/watch?v=9gIVGJQ3xWE']]
26
+ example_videos = [ex for ex in example_videos if is_youtube_video_available(
27
+ ex[0]) and is_youtube_video_available(ex[1])]
28
+
29
+ try:
30
+ example_videos_rr = requests.get(rr_examples_url).json()
31
+ except:
32
+ example_videos_rr = []
33
+ example_videos_rr = [[f'https://www.youtube.com/watch?v={ex["rejected_video_id"]}',
34
+ f'https://www.youtube.com/watch?v={ex["recommendation_id"]}'] for ex in example_videos_rr]
35
+ # remove duplicate video pairs, there seems to be one duplicate
36
+ example_videos_rr.sort()
37
+ example_videos_rr = list(example_videos_rr for example_videos_rr,
38
+ _ in itertools.groupby(example_videos_rr))
39
+ example_videos_rr = [ex for ex in example_videos_rr if is_youtube_video_available(
40
+ ex[0]) and is_youtube_video_available(ex[1])]
41
+ if len(example_videos_rr) > num_rr_examples:
42
+ example_videos_rr = random.sample(example_videos_rr, num_rr_examples)
43
+
44
+ return example_videos, example_videos_rr
45
+
46
+
47
+ def get_youtube_embedded_html(embed_url, video_position):
48
+ return f'''
49
+ <p>Video {video_position}</p>
50
+ <iframe width="100%" height="360px" src="{embed_url}" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; fullscreen" allowfullscreen></iframe>
51
+ '''
52
+
53
+
54
+ def update_youtube_embedded_html(video_url, video_position):
55
+ try:
56
+ embed_url = YouTube(video_url).embed_url
57
+ except:
58
+ return f'''
59
+ <p>There was error in fetching details for video with the URL: {video_url}</p>
60
+ '''
61
+ return get_youtube_embedded_html(embed_url, video_position)
62
+
63
+
64
+ def get_youtube_video_data(url):
65
+ video = YouTube(url)
66
+ channel_id = video.channel_id
67
+ video_title = video.title
68
+ video_description = video.description
69
+
70
+ try:
71
+ transcript_list = YouTubeTranscriptApi.list_transcripts(video.video_id)
72
+ except:
73
+ return channel_id, video_title, video_description, None
74
+
75
+ available_non_common_langs = [tr.language_code for tr in list(
76
+ transcript_list) if tr.language_code not in ['en', 'en-US', 'es', 'de']]
77
+ video_transcript = YouTubeTranscriptApi.get_transcript(
78
+ video.video_id, languages=['en', 'en-US', 'es', 'de'] + available_non_common_langs)
79
+ video_transcript = TextFormatter().format_transcript(
80
+ video_transcript).replace('\n', ' ')
81
+ return channel_id, video_title, video_description, video_transcript
82
+
83
+
84
+ def get_input_data_df(video1_url, video2_url):
85
+ channel_id, video_title, video_description, video_transcript = get_youtube_video_data(
86
+ video1_url)
87
+ channel_id2, video_title2, video_description2, video_transcript2 = get_youtube_video_data(
88
+ video2_url)
89
+ channel_sim = 1 if channel_id == channel_id2 else 0
90
+ df = pd.DataFrame([[video_title, video_description, video_transcript] + [video_title2, video_description2, video_transcript2] + [channel_sim]], columns=[
91
+ 'regret_title', 'regret_description', 'regret_transcript', 'recommendation_title', 'recommendation_description', 'recommendation_transcript', 'channel_sim'])
92
+ return df
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ datasets==2.4.0
2
+ gradio==3.3.1
3
+ huggingface_hub==0.9.1
4
+ pandas==1.4.3
5
+ pyarrow==9.0.0
6
+ pytorch_lightning==1.7.6
7
+ pytube==12.1.0
8
+ requests==2.27.1
9
+ torch==1.12.1
10
+ torchmetrics==0.9.3
11
+ transformers==4.22.1
12
+ youtube_transcript_api==0.4.4
utils/huggingface_model_wrapper.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import PyTorchModelHubMixin
2
+ from huggingface_hub.constants import PYTORCH_WEIGHTS_NAME
3
+ from huggingface_hub.file_download import hf_hub_download
4
+ from .unifiedmodel import RRUM
5
+ import os
6
+ import torch
7
+
8
+
9
+ class YoutubeVideoSimilarityModel(RRUM, PyTorchModelHubMixin):
10
+ """
11
+ Hugging Face `PyTorchModelHubMixin` wrapper for RegretsReporter `RRUM` model.
12
+ This allows loading, using, and saving the model from Hugging Face model hub
13
+ with default Hugging Face methods `from_pretrained` and `save_pretrained`.
14
+ """
15
+ @classmethod
16
+ def _from_pretrained(
17
+ cls,
18
+ model_id,
19
+ revision,
20
+ cache_dir,
21
+ force_download,
22
+ proxies,
23
+ resume_download,
24
+ local_files_only,
25
+ use_auth_token,
26
+ map_location="cpu",
27
+ strict=False,
28
+ **model_kwargs,
29
+ ):
30
+ map_location = torch.device(map_location)
31
+
32
+ if os.path.isdir(model_id):
33
+ print("Loading weights from local directory")
34
+ model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME)
35
+ else:
36
+ model_file = hf_hub_download(
37
+ repo_id=model_id,
38
+ filename=PYTORCH_WEIGHTS_NAME,
39
+ revision=revision,
40
+ cache_dir=cache_dir,
41
+ force_download=force_download,
42
+ proxies=proxies,
43
+ resume_download=resume_download,
44
+ use_auth_token=use_auth_token,
45
+ local_files_only=local_files_only,
46
+ )
47
+ # convert Huggingface config to RRUM acceptable input parameters
48
+ if "config" in model_kwargs:
49
+ model_kwargs = {**model_kwargs["config"], **model_kwargs}
50
+ del model_kwargs["config"]
51
+ model = cls(**model_kwargs)
52
+
53
+ state_dict = torch.load(model_file, map_location=map_location)
54
+ model.load_state_dict(state_dict, strict=strict)
55
+ model.eval()
56
+
57
+ return model
utils/text_cleaning.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastcore.basics import listify
2
+ from fastcore.utils import compose
3
+ import unicodedata
4
+ from string import punctuation
5
+ import html
6
+ from itertools import groupby
7
+ import re
8
+
9
+ control_char_regex = re.compile(r'[\r\n\t]+')
10
+ url_regex = re.compile(
11
+ r'((http|https)\:\/\/)?[a-zA-Z0-9\.\/\?\:@\-_=#]+\.([a-zA-Z]){2,6}([a-zA-Z0-9\.\&\/\?\:@\-_=#])*')
12
+ username_regex = re.compile(r'(^|[^@\w])@(\w{1,15})\b')
13
+
14
+
15
+ def fix_html(text):
16
+ tmp_ls = []
17
+ for e in listify(text):
18
+ e = e.replace('#39;', "'").replace('amp;', '&').replace('#146;', "'").replace('nbsp;', ' ').replace(
19
+ '#36;', '$').replace('\\n', "\n").replace('quot;', "'").replace('<br />', "\n").replace(
20
+ '\\"', '"').replace('<unk>', ' ').replace(' @.@ ', '.').replace(' @-@ ', '-').replace('...', ' …')
21
+ tmp_ls.append(html.unescape(e))
22
+
23
+ text = tmp_ls
24
+ return text
25
+
26
+
27
+ def remove_control_char(text):
28
+ tmp_ls = []
29
+ for e in listify(text):
30
+ tmp_ls.append(re.sub(control_char_regex, '.', e))
31
+
32
+ text = tmp_ls
33
+ return text
34
+
35
+
36
+ def remove_remaining_control_chars(text):
37
+ tmp_ls = []
38
+ for e in listify(text):
39
+ tmp_ls.append(
40
+ ''.join(ch for ch in e if unicodedata.category(ch)[0] != 'C'))
41
+
42
+ text = tmp_ls
43
+ return text
44
+
45
+
46
+ def remove_unicode_symbols(text):
47
+ tmp_ls = []
48
+ for e in listify(text):
49
+ tmp_ls.append(
50
+ ''.join(ch for ch in e if unicodedata.category(ch)[0] != 'So'))
51
+
52
+ text = tmp_ls
53
+ return text
54
+
55
+
56
+ def standardise_punc(text):
57
+ transl_table = dict([(ord(x), ord(y))
58
+ for x, y in zip(u"‘’´“”–-", u"'''\"\"--")])
59
+ tmp_ls = []
60
+ for e in listify(text):
61
+ e = e.translate(transl_table)
62
+ tmp_ls.append(e)
63
+
64
+ text = tmp_ls
65
+ return text
66
+
67
+
68
+ def remove_news_tags(text):
69
+ tmp_ls = []
70
+ for e in listify(text):
71
+ e = re.sub(r"(<[A-Z].+?>)|(</[A-Z].+?>)", "", e)
72
+ tmp_ls.append(e)
73
+
74
+ text = tmp_ls
75
+ return text
76
+
77
+
78
+ def replace_urls(text):
79
+ filler, tmp_ls = '', []
80
+ for e in listify(text):
81
+ e = re.sub(r"(<a.+?>)|(</a>)|(<ref.+?>)", "", e)
82
+ e = re.sub(url_regex, filler, e)
83
+ tmp_ls.append(e)
84
+
85
+ text = tmp_ls
86
+ return text
87
+
88
+
89
+ def replace_usernames(text):
90
+ filler, tmp_ls = '', []
91
+ for e in listify(text):
92
+ occ = e.count('@')
93
+ for _ in range(occ):
94
+ e = e.replace('@<user>', f'{filler}')
95
+ # replace other user handles by filler
96
+ e = re.sub(username_regex, filler, e)
97
+ tmp_ls.append(e)
98
+
99
+ text = tmp_ls
100
+ return text
101
+
102
+
103
+ def remove_duplicate_punctuation(text):
104
+ tmp_ls = []
105
+ for e in listify(text):
106
+ e = re.sub(r'\b(\w+)( \1\b)+', r'\1', e)
107
+ punc = set(punctuation)
108
+ newtext = []
109
+ for k, g in groupby(e):
110
+ if k in punc:
111
+ newtext.append(k)
112
+ else:
113
+ newtext.extend(g)
114
+ e = ''.join(newtext)
115
+ tmp_ls.append(e)
116
+
117
+ text = tmp_ls
118
+ return text
119
+
120
+
121
+ def remove_multi_space(text):
122
+ tmp_ls = []
123
+ for e in listify(text):
124
+ tmp_ls.append(' '.join(e.split()))
125
+
126
+ text = tmp_ls
127
+ return text
128
+
129
+
130
+ clean_text_funcs = compose(*[fix_html, remove_control_char, remove_remaining_control_chars, remove_unicode_symbols,
131
+ standardise_punc, remove_news_tags, replace_urls, replace_usernames, remove_duplicate_punctuation, remove_multi_space])
utils/unifiedmodel.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification, get_linear_schedule_with_warmup
2
+ import datasets
3
+ import pandas as pd
4
+ import pyarrow
5
+ import pytorch_lightning as pl
6
+ import torchmetrics
7
+ import torch.nn as nn
8
+ import torch
9
+ import types
10
+ import multiprocessing
11
+ from .text_cleaning import clean_text_funcs
12
+
13
+
14
+ class RRUMDataset():
15
+ scalar_features = ['channel_sim']
16
+ _image_features = ['regret_thumbnail',
17
+ 'recommendation_thumbnail'] # not used atm
18
+
19
+ def __init__(self, data, with_transcript, cross_encoder_model_name_or_path, label_col="label", label_map=None, balance_label_counts=False, max_length=128, do_train_test_split=False, test_size=0.25, seed=42, keep_video_ids_for_predictions=False, encode_on_the_fly=False, clean_text=False, processing_batch_size=1000, processing_num_proc=1):
20
+ self._with_transcript = with_transcript
21
+ self.tokenizer = AutoTokenizer.from_pretrained(
22
+ cross_encoder_model_name_or_path)
23
+ self.label_col = label_col
24
+ self.label_map = label_map
25
+ self.balance_label_counts = balance_label_counts
26
+ self.max_length = max_length
27
+ self.seed = seed
28
+ self.keep_video_ids_for_predictions = keep_video_ids_for_predictions
29
+ self.clean_text = clean_text
30
+ self.processing_batch_size = processing_batch_size
31
+ self.processing_num_proc = multiprocessing.cpu_count(
32
+ ) if not processing_num_proc else processing_num_proc
33
+
34
+ self.text_types = ['title', 'description'] + \
35
+ (['transcript'] if self._with_transcript else [])
36
+ self._text_features = [
37
+ 'regret_title', 'recommendation_title', 'regret_description',
38
+ 'recommendation_description'] + (['regret_transcript', 'recommendation_transcript'] if self._with_transcript else [])
39
+
40
+ # LOAD DATA INTO DATASET
41
+ self.streaming_dataset = False
42
+ if isinstance(data, pd.DataFrame):
43
+ self.dataset = datasets.Dataset.from_pandas(data)
44
+ elif isinstance(data, types.GeneratorType):
45
+ examples_iterable = datasets.iterable_dataset.ExamplesIterable(
46
+ self._streaming_generate_examples, {"iterable": data})
47
+ self.dataset = datasets.IterableDataset(examples_iterable)
48
+ self._stream_dataset_example = next(iter(self.dataset))
49
+ self._stream_dataset_column_names = list(
50
+ self._stream_dataset_example.keys())
51
+ self.streaming_dataset = True
52
+ elif isinstance(data, pyarrow.Table):
53
+ self.dataset = datasets.Dataset(data)
54
+ else:
55
+ raise ValueError(
56
+ f'Type of data is {type(data)} when pd.DataFrame, pyarrow.Table, or generator of pyarrow.RecordBatch is allowed')
57
+
58
+ # PREPROCESS DATASET
59
+ self._preprocess()
60
+
61
+ # ENCODE DATASET
62
+ self.train_dataset = None
63
+ self.test_dataset = None
64
+ if self.streaming_dataset:
65
+ # IterableDataset doesn't have train_test_split method
66
+ if self.label_col:
67
+ self.train_dataset = self._encode_streaming(self.dataset)
68
+ print('Streaming dataset available in .train_dataset')
69
+ else:
70
+ self.test_dataset = self._encode_streaming(self.dataset)
71
+ print(
72
+ 'Streaming dataset available in .test_dataset because label_col=None')
73
+ else:
74
+ # dataset into train_dataset and/or test_dataset
75
+ if do_train_test_split:
76
+ ds = self.dataset.train_test_split(
77
+ test_size=test_size, shuffle=True, seed=self.seed, stratify_by_column=self.label_col)
78
+ self.train_dataset = ds['train']
79
+ self.test_dataset = ds['test']
80
+ print(
81
+ f'Dataset was splitted into train and test with test_size={test_size}')
82
+ else:
83
+ if self.label_col:
84
+ self.train_dataset = self.dataset
85
+ else:
86
+ self.test_dataset = self.dataset
87
+
88
+ if encode_on_the_fly:
89
+ if self.train_dataset:
90
+ self.train_dataset.set_transform(self._encode_on_the_fly)
91
+ print('On-the-fly encoded dataset available in .train_dataset')
92
+ if self.test_dataset:
93
+ self.test_dataset.set_transform(self._encode_on_the_fly)
94
+ print('On-the-fly encoded dataset available in .test_dataset')
95
+ else:
96
+ if self.train_dataset:
97
+ self.train_dataset = self._encode(self.train_dataset)
98
+ print('Pre-encoded dataset available in .train_dataset')
99
+ if self.test_dataset:
100
+ self.test_dataset = self._encode(self.test_dataset)
101
+ print('Pre-encoded dataset available in .test_dataset')
102
+
103
+ def __len__(self):
104
+ if self.streaming_dataset:
105
+ raise ValueError(
106
+ f'Streaming dataset does not support len() method')
107
+ return len(self.dataset)
108
+
109
+ def __getitem__(self, index):
110
+ if self.streaming_dataset:
111
+ return next(iter(self.dataset))
112
+ return self.dataset[index]
113
+
114
+ def _streaming_generate_examples(self, iterable):
115
+ id_ = 0
116
+ # TODO: make sure GeneratorType is pyarrow.RecordBatch
117
+ if isinstance(iterable, types.GeneratorType):
118
+ for examples in iterable:
119
+ for ex in examples.to_pylist():
120
+ yield id_, ex
121
+ id_ += 1
122
+
123
+ def _preprocess(self):
124
+ if self._with_transcript:
125
+ self.dataset = self.dataset.filter(
126
+ lambda example: example['regret_transcript'] is not None and example['recommendation_transcript'] is not None)
127
+ else:
128
+ self.dataset = self.dataset.filter(
129
+ lambda example: example['regret_transcript'] is None or example['recommendation_transcript'] is None)
130
+ if self.label_col:
131
+ if self.streaming_dataset:
132
+ if self.label_col in self._stream_dataset_column_names and isinstance(self._stream_dataset_example[self.label_col], str):
133
+ if not self.label_map:
134
+ raise ValueError(
135
+ f'"label_map" dict was not provided and is needed to encode string labels for streaming datasets')
136
+ # cast_column method had issues with streaming dataset
137
+ self.dataset = self.dataset.map(
138
+ self._streaming_rename_labels)
139
+ else:
140
+ if self.dataset.features[self.label_col].dtype == 'string':
141
+ if not self.label_map:
142
+ self.label_map = {k: v for v, k in enumerate(
143
+ self.dataset.unique(self.label_col))}
144
+ self.dataset = self.dataset.filter(
145
+ lambda example: example[self.label_col] in self.label_map.keys())
146
+ self.dataset = self.dataset.cast_column(self.label_col, datasets.ClassLabel(
147
+ num_classes=len(self.label_map), names=list(self.label_map.keys())))
148
+
149
+ self.dataset = self.dataset.filter(lambda example: not any(x in [None, ""] for x in [
150
+ example[key] for key in self._text_features + self.scalar_features + ([self.label_col] if self.label_col else [])])) # dropna
151
+
152
+ if self.balance_label_counts and self.label_col and not self.streaming_dataset:
153
+ label_datasets = {}
154
+ for label in list(self.label_map.values()):
155
+ label_dataset = self.dataset.filter(
156
+ lambda example: example[self.label_col] == label)
157
+ label_datasets[len(label_dataset)] = label_dataset
158
+ min_label_count = min(label_datasets)
159
+ sampled_datasets = [dataset.train_test_split(train_size=min_label_count, shuffle=True, seed=self.seed)[
160
+ 'train'] if len(dataset) != min_label_count else dataset for dataset in label_datasets.values()]
161
+ self.dataset = datasets.concatenate_datasets(sampled_datasets)
162
+
163
+ if self.clean_text:
164
+ self.dataset = self.dataset.map(self._clean_text, batched=not self.streaming_dataset,
165
+ batch_size=self.processing_batch_size)
166
+ self.dataset = self.dataset.map(self._truncate_and_strip_text, batched=not self.streaming_dataset,
167
+ batch_size=self.processing_batch_size)
168
+
169
+ def _streaming_rename_labels(self, example):
170
+ # rename labels according to label_map if not already correct labels
171
+ if isinstance(example[self.label_col], list):
172
+ example[self.label_col] = [self.label_map.get(
173
+ ex, None) for ex in example[self.label_col] if ex not in self.label_map.values()]
174
+ elif isinstance(example[self.label_col], str) and example[self.label_col] not in self.label_map.values():
175
+ example[self.label_col] = self.label_map.get(
176
+ example[self.label_col], None)
177
+ else:
178
+ raise ValueError(
179
+ f'Type of example label is {type(example[self.label_col])} when list or string is allowed')
180
+ return example
181
+
182
+ def _clean_text(self, example):
183
+ for feat in self._text_features:
184
+ example[feat] = clean_text_funcs(example[feat])[0] if isinstance(
185
+ example[feat], str) else clean_text_funcs(example[feat])
186
+ return example
187
+
188
+ def _truncate_and_strip_text(self, example):
189
+ # tokenizer will truncate to max_length tokens anyway so to save RAM let's truncate to max_length words already beforehand
190
+ # one word is usually one or more tokens so should be safe to truncate this way without losing information
191
+ for feat in self._text_features:
192
+ if isinstance(example[feat], list):
193
+ example[feat] = [
194
+ ' '.join(text.split()[:self.max_length]).strip() for text in example[feat] if text]
195
+ elif isinstance(example[feat], str):
196
+ example[feat] = ' '.join(example[feat].split()[
197
+ :self.max_length]).strip()
198
+ elif example[feat] is None:
199
+ return None
200
+ else:
201
+ raise ValueError(
202
+ f'Type of example is {type(example[feat])} when list or string is allowed')
203
+ return example
204
+
205
+ def _encode(self, dataset):
206
+ encoded_dataset = None
207
+ for text_type in self.text_types:
208
+ encoded_text_type = dataset.map(lambda regret, recommendation: self.tokenizer(regret, recommendation, padding="max_length", truncation=True, max_length=self.max_length), batched=True,
209
+ batch_size=self.processing_batch_size, num_proc=self.processing_num_proc, input_columns=[f'regret_{text_type}', f'recommendation_{text_type}'], remove_columns=dataset.column_names)
210
+ encoded_text_type = encoded_text_type.rename_columns(
211
+ {col: f'{text_type}_{col}' for col in encoded_text_type.column_names}) # e.g. input_ids -> title_input_ids so we have separate input_ids for each text_type
212
+ if encoded_dataset:
213
+ encoded_dataset = datasets.concatenate_datasets(
214
+ [encoded_dataset, encoded_text_type], axis=1)
215
+ else:
216
+ encoded_dataset = encoded_text_type
217
+
218
+ # copy scalar features and label from original dataset to the encoded dataset
219
+ for scalar_feat in self.scalar_features:
220
+ encoded_dataset = encoded_dataset.add_column(
221
+ name=scalar_feat, column=dataset[scalar_feat])
222
+ if self.label_col:
223
+ encoded_dataset = encoded_dataset.add_column(
224
+ name=self.label_col, column=dataset[self.label_col])
225
+ if self.keep_video_ids_for_predictions:
226
+ for id in ['regret_id', "recommendation_id"]:
227
+ encoded_dataset = encoded_dataset.add_column(
228
+ name=id, column=dataset[id])
229
+
230
+ encoded_dataset.set_format(
231
+ type='torch', columns=encoded_dataset.column_names)
232
+ return encoded_dataset
233
+
234
+ def _encode_streaming(self, dataset):
235
+ encoded_dataset = dataset.map(self._encode_on_the_fly, batched=True,
236
+ batch_size=self.processing_batch_size, remove_columns=list(set(self._stream_dataset_column_names)-set(self.scalar_features + (
237
+ [self.label_col] if self.label_col else []) + (['regret_id', "recommendation_id"] if self.keep_video_ids_for_predictions else [])))) # IterableDataset doesn't have column_names attribute as normal Dataset
238
+ encoded_dataset = encoded_dataset.with_format("torch")
239
+ return encoded_dataset
240
+
241
+ def _encode_on_the_fly(self, batch):
242
+ for text_type in self.text_types:
243
+ encoded_text_type = dict(self.tokenizer(
244
+ batch[f'regret_{text_type}'], batch[f'recommendation_{text_type}'], padding="max_length", truncation=True, max_length=self.max_length, return_tensors="pt"))
245
+ for encoded_key in encoded_text_type.copy():
246
+ encoded_text_type[f"{text_type}_{encoded_key}"] = encoded_text_type.pop(encoded_key) if not self.streaming_dataset else encoded_text_type.pop(
247
+ encoded_key).squeeze(0) # e.g. input_ids -> title_input_ids so we have separate input_ids for each text_type
248
+ del batch[f'regret_{text_type}']
249
+ del batch[f'recommendation_{text_type}']
250
+ batch.update(encoded_text_type)
251
+ for scalar_feat in self.scalar_features:
252
+ batch[scalar_feat] = torch.as_tensor(
253
+ batch[scalar_feat]) if not self.streaming_dataset else torch.as_tensor(batch[scalar_feat]).squeeze(0)
254
+ if self.label_col:
255
+ batch[self.label_col] = torch.as_tensor(
256
+ batch[self.label_col]) if not self.streaming_dataset else torch.as_tensor(batch[self.label_col]).squeeze(0)
257
+ return batch
258
+
259
+
260
+ class RRUM(pl.LightningModule):
261
+ def __init__(self, text_types, scalar_features, label_col, cross_encoder_model_name_or_path, optimizer_config=None, freeze_policy=None, pos_weight=None):
262
+ super().__init__()
263
+ self.save_hyperparameters()
264
+ self.text_types = text_types
265
+ self.scalar_features = scalar_features
266
+ self.label_col = label_col
267
+ self.optimizer_config = optimizer_config
268
+ self.cross_encoder_model_name_or_path = cross_encoder_model_name_or_path
269
+ self.cross_encoders = nn.ModuleDict({})
270
+ for t in self.text_types:
271
+ self.cross_encoders[t] = AutoModelForSequenceClassification.from_pretrained(
272
+ self.cross_encoder_model_name_or_path)
273
+ if freeze_policy is not None:
274
+ for xe in self.cross_encoders.values():
275
+ for name, param in xe.named_parameters():
276
+ if freeze_policy(name):
277
+ param.requires_grad = False
278
+ cross_encoder_out_features = list(self.cross_encoders.values())[0](
279
+ torch.randint(1, 2, (1, 2))).logits.size(dim=1)
280
+ self.lin1 = nn.Linear(len(self.cross_encoders) * cross_encoder_out_features +
281
+ len(self.scalar_features), 1)
282
+ self.ac_metric = torchmetrics.Accuracy()
283
+ self.pr_metric = torchmetrics.Precision()
284
+ self.re_metric = torchmetrics.Recall()
285
+ self.auc_metric = torchmetrics.AUROC()
286
+
287
+ if pos_weight:
288
+ self.loss = nn.BCEWithLogitsLoss(
289
+ pos_weight=torch.Tensor([pos_weight]))
290
+ else:
291
+ self.loss = nn.BCEWithLogitsLoss()
292
+
293
+ def forward(self, x):
294
+ cross_logits = {}
295
+ for f in self.text_types:
296
+ inputs = {key.split(f'{f}_')[1]: x[key]
297
+ for key in x if f in key} # e.g. title_input_ids -> input_ids since we have separate input_ids for each text_type
298
+ cross_logits[f] = self.cross_encoders[f](**inputs).logits
299
+ x = torch.cat([*cross_logits.values()] +
300
+ [x[scalar][:, None] for scalar in self.scalar_features],
301
+ 1
302
+ )
303
+ del cross_logits
304
+
305
+ x = self.lin1(x)
306
+ return x
307
+
308
+ def configure_optimizers(self):
309
+ if self.optimizer_config:
310
+ return self.optimizer_config(self)
311
+
312
+ optimizer = torch.optim.AdamW(self.parameters(), lr=5e-5)
313
+ scheduler = get_linear_schedule_with_warmup(
314
+ optimizer,
315
+ num_warmup_steps=int(
316
+ self.trainer.estimated_stepping_batches * 0.05),
317
+ num_training_steps=self.trainer.estimated_stepping_batches,
318
+ )
319
+ scheduler = {'scheduler': scheduler,
320
+ 'interval': 'step', 'frequency': 1}
321
+ return [optimizer], [scheduler]
322
+
323
+ def training_step(self, train_batch, batch_idx):
324
+ y = train_batch[self.label_col].unsqueeze(1).float()
325
+ logits = self(train_batch)
326
+ loss = self.loss(logits, y)
327
+ self.log('train_loss', loss)
328
+ return loss
329
+
330
+ def validation_step(self, val_batch, batch_idx):
331
+ y = val_batch[self.label_col].unsqueeze(1).float()
332
+ logits = self(val_batch)
333
+ loss = self.loss(logits, y)
334
+ self.ac_metric(logits, y.int())
335
+ self.pr_metric(logits, y.int())
336
+ self.re_metric(logits, y.int())
337
+ self.auc_metric(logits, y.int())
338
+ self.log('validation_accuracy', self.ac_metric)
339
+ self.log('validation_precision', self.pr_metric)
340
+ self.log('validation_recall', self.re_metric)
341
+ self.log('validation_auc', self.auc_metric)
342
+ self.log('val_loss', loss, prog_bar=True)
343
+
344
+ def validation_epoch_end(self, outputs):
345
+ self.log('validation_accuracy_ep', self.ac_metric)
346
+ self.log('validation_precision_ep', self.pr_metric)
347
+ self.log('validation_recall_ep', self.re_metric)
348
+ self.log('validation_auc_ep', self.auc_metric)