Spaces:
Build error
Build error
aapot
commited on
Commit
•
f3772cc
1
Parent(s):
8b86424
Add demo application
Browse files- README.md +4 -1
- app.py +94 -0
- helpers.py +92 -0
- requirements.txt +12 -0
- utils/huggingface_model_wrapper.py +57 -0
- utils/text_cleaning.py +131 -0
- utils/unifiedmodel.py +348 -0
README.md
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
---
|
2 |
title: Youtube Video Similarity
|
3 |
-
emoji:
|
4 |
colorFrom: purple
|
5 |
colorTo: blue
|
6 |
sdk: gradio
|
@@ -8,6 +8,9 @@ sdk_version: 3.3.1
|
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: apache-2.0
|
|
|
|
|
|
|
11 |
---
|
12 |
|
13 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
---
|
2 |
title: Youtube Video Similarity
|
3 |
+
emoji: ▶
|
4 |
colorFrom: purple
|
5 |
colorTo: blue
|
6 |
sdk: gradio
|
|
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: apache-2.0
|
11 |
+
models:
|
12 |
+
- mozilla-foundation/youtube_video_similarity_model_wt
|
13 |
+
- mozilla-foundation/youtube_video_similarity_model_nt
|
14 |
---
|
15 |
|
16 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import gradio as gr
|
3 |
+
import torch
|
4 |
+
from utils.unifiedmodel import RRUMDataset
|
5 |
+
from utils.huggingface_model_wrapper import YoutubeVideoSimilarityModel
|
6 |
+
from torch.utils.data import DataLoader
|
7 |
+
from helpers import get_example_videos, update_youtube_embedded_html, get_input_data_df
|
8 |
+
|
9 |
+
RR_EXAMPLES_URL = os.environ.get(
|
10 |
+
'RR_EXAMPLES_URL', 'https://public-data.telemetry.mozilla.org/api/v1/tables/telemetry_derived/regrets_reporter_study/v1/files/000000000000.json')
|
11 |
+
NUM_RR_EXAMPLES = 5
|
12 |
+
example_videos, example_videos_rr = get_example_videos(
|
13 |
+
RR_EXAMPLES_URL, NUM_RR_EXAMPLES)
|
14 |
+
|
15 |
+
demo_title = 'Mozilla RegretsReporter YouTube video similarity'
|
16 |
+
demo_description = f'''
|
17 |
+
# {demo_title}
|
18 |
+
|
19 |
+
This demo showcases the YouTube video semantic similarity model developed as part of the RegretsReporter research project at Mozilla Foundation. You can read more about the project [here](https://foundation.mozilla.org/en/youtube/user-controls/) and about the semantic similarity model [here](https://foundation.mozilla.org/en/blog/the-regretsreporter-user-controls-study-machine-learning-to-measure-semantic-similarity-of-youtube-videos/). Note: the model is multilingual so you can try it with non-English videos too while it probably works the best with English videos.
|
20 |
+
|
21 |
+
This demo works by inserting two YouTube video URLs below and clicking the Run button. After a few seconds, you will see model's predicted probability of how similar those two videos are. You can copy URLs from YouTube or also try out a few predefined examples by clicking them on the examples table.
|
22 |
+
'''
|
23 |
+
|
24 |
+
placeholder_youtube_embedded_html = '''
|
25 |
+
<p>Insert video URL first</p>
|
26 |
+
'''
|
27 |
+
|
28 |
+
|
29 |
+
model_wt = YoutubeVideoSimilarityModel.from_pretrained(
|
30 |
+
'mozilla-foundation/youtube_video_similarity_model_wt', use_auth_token=True)
|
31 |
+
model_nt = YoutubeVideoSimilarityModel.from_pretrained(
|
32 |
+
'mozilla-foundation/youtube_video_similarity_model_nt', use_auth_token=True)
|
33 |
+
cross_encoder_model_name_or_path = model_wt.cross_encoder_model_name_or_path
|
34 |
+
|
35 |
+
|
36 |
+
def get_video_similarity(video1_url, video2_url):
|
37 |
+
df = get_input_data_df(video1_url, video2_url)
|
38 |
+
if df['regret_transcript'].isna().any() or df['recommendation_transcript'].isna().any():
|
39 |
+
with_transcript = False
|
40 |
+
else:
|
41 |
+
with_transcript = True
|
42 |
+
dataset = RRUMDataset(df, with_transcript=with_transcript, label_col=None,
|
43 |
+
cross_encoder_model_name_or_path=cross_encoder_model_name_or_path)
|
44 |
+
data_loader = DataLoader(dataset.test_dataset, shuffle=False,
|
45 |
+
batch_size=1, num_workers=0, pin_memory=False)
|
46 |
+
|
47 |
+
with torch.inference_mode():
|
48 |
+
if with_transcript:
|
49 |
+
pred = model_wt(next(iter(data_loader)))
|
50 |
+
else:
|
51 |
+
pred = model_nt(next(iter(data_loader)))
|
52 |
+
pred = torch.special.expit(pred).squeeze().tolist()
|
53 |
+
return f'YouTube videos are {pred:.0%} similar'
|
54 |
+
|
55 |
+
|
56 |
+
with gr.Blocks(title=demo_title) as demo:
|
57 |
+
gr.Markdown(demo_description)
|
58 |
+
with gr.Row():
|
59 |
+
with gr.Column():
|
60 |
+
input_text1 = gr.Textbox(
|
61 |
+
label='Video 1', placeholder='Insert first YouTube video URL')
|
62 |
+
input_text2 = gr.Textbox(
|
63 |
+
label='Video 2', placeholder='Insert second YouTube video URL')
|
64 |
+
inputs = [input_text1, input_text2]
|
65 |
+
with gr.Row():
|
66 |
+
clear_btn = gr.Button('Clear', variant='secondary')
|
67 |
+
run_btn = gr.Button('Run', variant='primary')
|
68 |
+
with gr.Column():
|
69 |
+
outputs = [gr.Label(label='Model prediction')]
|
70 |
+
with gr.Accordion('See video details', open=False):
|
71 |
+
with gr.Row():
|
72 |
+
with gr.Column():
|
73 |
+
video_embedded = gr.HTML(
|
74 |
+
value=placeholder_youtube_embedded_html)
|
75 |
+
with gr.Column():
|
76 |
+
video_embedded2 = gr.HTML(
|
77 |
+
value=placeholder_youtube_embedded_html)
|
78 |
+
with gr.Column():
|
79 |
+
if example_videos:
|
80 |
+
examples = gr.Examples(examples=example_videos, inputs=inputs)
|
81 |
+
if example_videos_rr:
|
82 |
+
examples_rr = gr.Examples(examples=example_videos_rr, inputs=inputs,
|
83 |
+
label='Example bad becommendations from the RegretsReporter report')
|
84 |
+
|
85 |
+
run_btn.click(fn=get_video_similarity, inputs=inputs, outputs=outputs)
|
86 |
+
clear_btn.click(lambda value_1, value_2, value_3: (
|
87 |
+
None, None, None), inputs=inputs + outputs, outputs=inputs + outputs)
|
88 |
+
|
89 |
+
input_text1.change(lambda input: update_youtube_embedded_html(
|
90 |
+
input, 1) if input else placeholder_youtube_embedded_html, inputs=input_text1, outputs=video_embedded)
|
91 |
+
input_text2.change(lambda input: update_youtube_embedded_html(
|
92 |
+
input, 2) if input else placeholder_youtube_embedded_html, inputs=input_text2, outputs=video_embedded2)
|
93 |
+
|
94 |
+
demo.launch()
|
helpers.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import itertools
|
2 |
+
import random
|
3 |
+
import requests
|
4 |
+
import pandas as pd
|
5 |
+
from pytube import YouTube
|
6 |
+
from youtube_transcript_api import YouTubeTranscriptApi
|
7 |
+
from youtube_transcript_api.formatters import TextFormatter
|
8 |
+
|
9 |
+
|
10 |
+
def is_youtube_video_available(url):
|
11 |
+
video = YouTube(url)
|
12 |
+
try:
|
13 |
+
video.title
|
14 |
+
return True
|
15 |
+
except:
|
16 |
+
return False
|
17 |
+
|
18 |
+
|
19 |
+
def get_example_videos(rr_examples_url, num_rr_examples):
|
20 |
+
example_videos = [['https://www.youtube.com/watch?v=WfVF-Ec4naQ', 'https://www.youtube.com/watch?v=4hrNt28t7Cw'],
|
21 |
+
['https://www.youtube.com/watch?v=GbpjLP-UvIU',
|
22 |
+
'https://www.youtube.com/watch?v=BlQ2mP2EE4A'],
|
23 |
+
['https://www.youtube.com/watch?v=fdzY1f2P91k',
|
24 |
+
'https://www.youtube.com/watch?v=BlQ2mP2EE4A'],
|
25 |
+
['https://www.youtube.com/watch?v=fdzY1f2P91k', 'https://www.youtube.com/watch?v=9gIVGJQ3xWE']]
|
26 |
+
example_videos = [ex for ex in example_videos if is_youtube_video_available(
|
27 |
+
ex[0]) and is_youtube_video_available(ex[1])]
|
28 |
+
|
29 |
+
try:
|
30 |
+
example_videos_rr = requests.get(rr_examples_url).json()
|
31 |
+
except:
|
32 |
+
example_videos_rr = []
|
33 |
+
example_videos_rr = [[f'https://www.youtube.com/watch?v={ex["rejected_video_id"]}',
|
34 |
+
f'https://www.youtube.com/watch?v={ex["recommendation_id"]}'] for ex in example_videos_rr]
|
35 |
+
# remove duplicate video pairs, there seems to be one duplicate
|
36 |
+
example_videos_rr.sort()
|
37 |
+
example_videos_rr = list(example_videos_rr for example_videos_rr,
|
38 |
+
_ in itertools.groupby(example_videos_rr))
|
39 |
+
example_videos_rr = [ex for ex in example_videos_rr if is_youtube_video_available(
|
40 |
+
ex[0]) and is_youtube_video_available(ex[1])]
|
41 |
+
if len(example_videos_rr) > num_rr_examples:
|
42 |
+
example_videos_rr = random.sample(example_videos_rr, num_rr_examples)
|
43 |
+
|
44 |
+
return example_videos, example_videos_rr
|
45 |
+
|
46 |
+
|
47 |
+
def get_youtube_embedded_html(embed_url, video_position):
|
48 |
+
return f'''
|
49 |
+
<p>Video {video_position}</p>
|
50 |
+
<iframe width="100%" height="360px" src="{embed_url}" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; fullscreen" allowfullscreen></iframe>
|
51 |
+
'''
|
52 |
+
|
53 |
+
|
54 |
+
def update_youtube_embedded_html(video_url, video_position):
|
55 |
+
try:
|
56 |
+
embed_url = YouTube(video_url).embed_url
|
57 |
+
except:
|
58 |
+
return f'''
|
59 |
+
<p>There was error in fetching details for video with the URL: {video_url}</p>
|
60 |
+
'''
|
61 |
+
return get_youtube_embedded_html(embed_url, video_position)
|
62 |
+
|
63 |
+
|
64 |
+
def get_youtube_video_data(url):
|
65 |
+
video = YouTube(url)
|
66 |
+
channel_id = video.channel_id
|
67 |
+
video_title = video.title
|
68 |
+
video_description = video.description
|
69 |
+
|
70 |
+
try:
|
71 |
+
transcript_list = YouTubeTranscriptApi.list_transcripts(video.video_id)
|
72 |
+
except:
|
73 |
+
return channel_id, video_title, video_description, None
|
74 |
+
|
75 |
+
available_non_common_langs = [tr.language_code for tr in list(
|
76 |
+
transcript_list) if tr.language_code not in ['en', 'en-US', 'es', 'de']]
|
77 |
+
video_transcript = YouTubeTranscriptApi.get_transcript(
|
78 |
+
video.video_id, languages=['en', 'en-US', 'es', 'de'] + available_non_common_langs)
|
79 |
+
video_transcript = TextFormatter().format_transcript(
|
80 |
+
video_transcript).replace('\n', ' ')
|
81 |
+
return channel_id, video_title, video_description, video_transcript
|
82 |
+
|
83 |
+
|
84 |
+
def get_input_data_df(video1_url, video2_url):
|
85 |
+
channel_id, video_title, video_description, video_transcript = get_youtube_video_data(
|
86 |
+
video1_url)
|
87 |
+
channel_id2, video_title2, video_description2, video_transcript2 = get_youtube_video_data(
|
88 |
+
video2_url)
|
89 |
+
channel_sim = 1 if channel_id == channel_id2 else 0
|
90 |
+
df = pd.DataFrame([[video_title, video_description, video_transcript] + [video_title2, video_description2, video_transcript2] + [channel_sim]], columns=[
|
91 |
+
'regret_title', 'regret_description', 'regret_transcript', 'recommendation_title', 'recommendation_description', 'recommendation_transcript', 'channel_sim'])
|
92 |
+
return df
|
requirements.txt
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
datasets==2.4.0
|
2 |
+
gradio==3.3.1
|
3 |
+
huggingface_hub==0.9.1
|
4 |
+
pandas==1.4.3
|
5 |
+
pyarrow==9.0.0
|
6 |
+
pytorch_lightning==1.7.6
|
7 |
+
pytube==12.1.0
|
8 |
+
requests==2.27.1
|
9 |
+
torch==1.12.1
|
10 |
+
torchmetrics==0.9.3
|
11 |
+
transformers==4.22.1
|
12 |
+
youtube_transcript_api==0.4.4
|
utils/huggingface_model_wrapper.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from huggingface_hub import PyTorchModelHubMixin
|
2 |
+
from huggingface_hub.constants import PYTORCH_WEIGHTS_NAME
|
3 |
+
from huggingface_hub.file_download import hf_hub_download
|
4 |
+
from .unifiedmodel import RRUM
|
5 |
+
import os
|
6 |
+
import torch
|
7 |
+
|
8 |
+
|
9 |
+
class YoutubeVideoSimilarityModel(RRUM, PyTorchModelHubMixin):
|
10 |
+
"""
|
11 |
+
Hugging Face `PyTorchModelHubMixin` wrapper for RegretsReporter `RRUM` model.
|
12 |
+
This allows loading, using, and saving the model from Hugging Face model hub
|
13 |
+
with default Hugging Face methods `from_pretrained` and `save_pretrained`.
|
14 |
+
"""
|
15 |
+
@classmethod
|
16 |
+
def _from_pretrained(
|
17 |
+
cls,
|
18 |
+
model_id,
|
19 |
+
revision,
|
20 |
+
cache_dir,
|
21 |
+
force_download,
|
22 |
+
proxies,
|
23 |
+
resume_download,
|
24 |
+
local_files_only,
|
25 |
+
use_auth_token,
|
26 |
+
map_location="cpu",
|
27 |
+
strict=False,
|
28 |
+
**model_kwargs,
|
29 |
+
):
|
30 |
+
map_location = torch.device(map_location)
|
31 |
+
|
32 |
+
if os.path.isdir(model_id):
|
33 |
+
print("Loading weights from local directory")
|
34 |
+
model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME)
|
35 |
+
else:
|
36 |
+
model_file = hf_hub_download(
|
37 |
+
repo_id=model_id,
|
38 |
+
filename=PYTORCH_WEIGHTS_NAME,
|
39 |
+
revision=revision,
|
40 |
+
cache_dir=cache_dir,
|
41 |
+
force_download=force_download,
|
42 |
+
proxies=proxies,
|
43 |
+
resume_download=resume_download,
|
44 |
+
use_auth_token=use_auth_token,
|
45 |
+
local_files_only=local_files_only,
|
46 |
+
)
|
47 |
+
# convert Huggingface config to RRUM acceptable input parameters
|
48 |
+
if "config" in model_kwargs:
|
49 |
+
model_kwargs = {**model_kwargs["config"], **model_kwargs}
|
50 |
+
del model_kwargs["config"]
|
51 |
+
model = cls(**model_kwargs)
|
52 |
+
|
53 |
+
state_dict = torch.load(model_file, map_location=map_location)
|
54 |
+
model.load_state_dict(state_dict, strict=strict)
|
55 |
+
model.eval()
|
56 |
+
|
57 |
+
return model
|
utils/text_cleaning.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastcore.basics import listify
|
2 |
+
from fastcore.utils import compose
|
3 |
+
import unicodedata
|
4 |
+
from string import punctuation
|
5 |
+
import html
|
6 |
+
from itertools import groupby
|
7 |
+
import re
|
8 |
+
|
9 |
+
control_char_regex = re.compile(r'[\r\n\t]+')
|
10 |
+
url_regex = re.compile(
|
11 |
+
r'((http|https)\:\/\/)?[a-zA-Z0-9\.\/\?\:@\-_=#]+\.([a-zA-Z]){2,6}([a-zA-Z0-9\.\&\/\?\:@\-_=#])*')
|
12 |
+
username_regex = re.compile(r'(^|[^@\w])@(\w{1,15})\b')
|
13 |
+
|
14 |
+
|
15 |
+
def fix_html(text):
|
16 |
+
tmp_ls = []
|
17 |
+
for e in listify(text):
|
18 |
+
e = e.replace('#39;', "'").replace('amp;', '&').replace('#146;', "'").replace('nbsp;', ' ').replace(
|
19 |
+
'#36;', '$').replace('\\n', "\n").replace('quot;', "'").replace('<br />', "\n").replace(
|
20 |
+
'\\"', '"').replace('<unk>', ' ').replace(' @.@ ', '.').replace(' @-@ ', '-').replace('...', ' …')
|
21 |
+
tmp_ls.append(html.unescape(e))
|
22 |
+
|
23 |
+
text = tmp_ls
|
24 |
+
return text
|
25 |
+
|
26 |
+
|
27 |
+
def remove_control_char(text):
|
28 |
+
tmp_ls = []
|
29 |
+
for e in listify(text):
|
30 |
+
tmp_ls.append(re.sub(control_char_regex, '.', e))
|
31 |
+
|
32 |
+
text = tmp_ls
|
33 |
+
return text
|
34 |
+
|
35 |
+
|
36 |
+
def remove_remaining_control_chars(text):
|
37 |
+
tmp_ls = []
|
38 |
+
for e in listify(text):
|
39 |
+
tmp_ls.append(
|
40 |
+
''.join(ch for ch in e if unicodedata.category(ch)[0] != 'C'))
|
41 |
+
|
42 |
+
text = tmp_ls
|
43 |
+
return text
|
44 |
+
|
45 |
+
|
46 |
+
def remove_unicode_symbols(text):
|
47 |
+
tmp_ls = []
|
48 |
+
for e in listify(text):
|
49 |
+
tmp_ls.append(
|
50 |
+
''.join(ch for ch in e if unicodedata.category(ch)[0] != 'So'))
|
51 |
+
|
52 |
+
text = tmp_ls
|
53 |
+
return text
|
54 |
+
|
55 |
+
|
56 |
+
def standardise_punc(text):
|
57 |
+
transl_table = dict([(ord(x), ord(y))
|
58 |
+
for x, y in zip(u"‘’´“”–-", u"'''\"\"--")])
|
59 |
+
tmp_ls = []
|
60 |
+
for e in listify(text):
|
61 |
+
e = e.translate(transl_table)
|
62 |
+
tmp_ls.append(e)
|
63 |
+
|
64 |
+
text = tmp_ls
|
65 |
+
return text
|
66 |
+
|
67 |
+
|
68 |
+
def remove_news_tags(text):
|
69 |
+
tmp_ls = []
|
70 |
+
for e in listify(text):
|
71 |
+
e = re.sub(r"(<[A-Z].+?>)|(</[A-Z].+?>)", "", e)
|
72 |
+
tmp_ls.append(e)
|
73 |
+
|
74 |
+
text = tmp_ls
|
75 |
+
return text
|
76 |
+
|
77 |
+
|
78 |
+
def replace_urls(text):
|
79 |
+
filler, tmp_ls = '', []
|
80 |
+
for e in listify(text):
|
81 |
+
e = re.sub(r"(<a.+?>)|(</a>)|(<ref.+?>)", "", e)
|
82 |
+
e = re.sub(url_regex, filler, e)
|
83 |
+
tmp_ls.append(e)
|
84 |
+
|
85 |
+
text = tmp_ls
|
86 |
+
return text
|
87 |
+
|
88 |
+
|
89 |
+
def replace_usernames(text):
|
90 |
+
filler, tmp_ls = '', []
|
91 |
+
for e in listify(text):
|
92 |
+
occ = e.count('@')
|
93 |
+
for _ in range(occ):
|
94 |
+
e = e.replace('@<user>', f'{filler}')
|
95 |
+
# replace other user handles by filler
|
96 |
+
e = re.sub(username_regex, filler, e)
|
97 |
+
tmp_ls.append(e)
|
98 |
+
|
99 |
+
text = tmp_ls
|
100 |
+
return text
|
101 |
+
|
102 |
+
|
103 |
+
def remove_duplicate_punctuation(text):
|
104 |
+
tmp_ls = []
|
105 |
+
for e in listify(text):
|
106 |
+
e = re.sub(r'\b(\w+)( \1\b)+', r'\1', e)
|
107 |
+
punc = set(punctuation)
|
108 |
+
newtext = []
|
109 |
+
for k, g in groupby(e):
|
110 |
+
if k in punc:
|
111 |
+
newtext.append(k)
|
112 |
+
else:
|
113 |
+
newtext.extend(g)
|
114 |
+
e = ''.join(newtext)
|
115 |
+
tmp_ls.append(e)
|
116 |
+
|
117 |
+
text = tmp_ls
|
118 |
+
return text
|
119 |
+
|
120 |
+
|
121 |
+
def remove_multi_space(text):
|
122 |
+
tmp_ls = []
|
123 |
+
for e in listify(text):
|
124 |
+
tmp_ls.append(' '.join(e.split()))
|
125 |
+
|
126 |
+
text = tmp_ls
|
127 |
+
return text
|
128 |
+
|
129 |
+
|
130 |
+
clean_text_funcs = compose(*[fix_html, remove_control_char, remove_remaining_control_chars, remove_unicode_symbols,
|
131 |
+
standardise_punc, remove_news_tags, replace_urls, replace_usernames, remove_duplicate_punctuation, remove_multi_space])
|
utils/unifiedmodel.py
ADDED
@@ -0,0 +1,348 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification, get_linear_schedule_with_warmup
|
2 |
+
import datasets
|
3 |
+
import pandas as pd
|
4 |
+
import pyarrow
|
5 |
+
import pytorch_lightning as pl
|
6 |
+
import torchmetrics
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch
|
9 |
+
import types
|
10 |
+
import multiprocessing
|
11 |
+
from .text_cleaning import clean_text_funcs
|
12 |
+
|
13 |
+
|
14 |
+
class RRUMDataset():
|
15 |
+
scalar_features = ['channel_sim']
|
16 |
+
_image_features = ['regret_thumbnail',
|
17 |
+
'recommendation_thumbnail'] # not used atm
|
18 |
+
|
19 |
+
def __init__(self, data, with_transcript, cross_encoder_model_name_or_path, label_col="label", label_map=None, balance_label_counts=False, max_length=128, do_train_test_split=False, test_size=0.25, seed=42, keep_video_ids_for_predictions=False, encode_on_the_fly=False, clean_text=False, processing_batch_size=1000, processing_num_proc=1):
|
20 |
+
self._with_transcript = with_transcript
|
21 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
22 |
+
cross_encoder_model_name_or_path)
|
23 |
+
self.label_col = label_col
|
24 |
+
self.label_map = label_map
|
25 |
+
self.balance_label_counts = balance_label_counts
|
26 |
+
self.max_length = max_length
|
27 |
+
self.seed = seed
|
28 |
+
self.keep_video_ids_for_predictions = keep_video_ids_for_predictions
|
29 |
+
self.clean_text = clean_text
|
30 |
+
self.processing_batch_size = processing_batch_size
|
31 |
+
self.processing_num_proc = multiprocessing.cpu_count(
|
32 |
+
) if not processing_num_proc else processing_num_proc
|
33 |
+
|
34 |
+
self.text_types = ['title', 'description'] + \
|
35 |
+
(['transcript'] if self._with_transcript else [])
|
36 |
+
self._text_features = [
|
37 |
+
'regret_title', 'recommendation_title', 'regret_description',
|
38 |
+
'recommendation_description'] + (['regret_transcript', 'recommendation_transcript'] if self._with_transcript else [])
|
39 |
+
|
40 |
+
# LOAD DATA INTO DATASET
|
41 |
+
self.streaming_dataset = False
|
42 |
+
if isinstance(data, pd.DataFrame):
|
43 |
+
self.dataset = datasets.Dataset.from_pandas(data)
|
44 |
+
elif isinstance(data, types.GeneratorType):
|
45 |
+
examples_iterable = datasets.iterable_dataset.ExamplesIterable(
|
46 |
+
self._streaming_generate_examples, {"iterable": data})
|
47 |
+
self.dataset = datasets.IterableDataset(examples_iterable)
|
48 |
+
self._stream_dataset_example = next(iter(self.dataset))
|
49 |
+
self._stream_dataset_column_names = list(
|
50 |
+
self._stream_dataset_example.keys())
|
51 |
+
self.streaming_dataset = True
|
52 |
+
elif isinstance(data, pyarrow.Table):
|
53 |
+
self.dataset = datasets.Dataset(data)
|
54 |
+
else:
|
55 |
+
raise ValueError(
|
56 |
+
f'Type of data is {type(data)} when pd.DataFrame, pyarrow.Table, or generator of pyarrow.RecordBatch is allowed')
|
57 |
+
|
58 |
+
# PREPROCESS DATASET
|
59 |
+
self._preprocess()
|
60 |
+
|
61 |
+
# ENCODE DATASET
|
62 |
+
self.train_dataset = None
|
63 |
+
self.test_dataset = None
|
64 |
+
if self.streaming_dataset:
|
65 |
+
# IterableDataset doesn't have train_test_split method
|
66 |
+
if self.label_col:
|
67 |
+
self.train_dataset = self._encode_streaming(self.dataset)
|
68 |
+
print('Streaming dataset available in .train_dataset')
|
69 |
+
else:
|
70 |
+
self.test_dataset = self._encode_streaming(self.dataset)
|
71 |
+
print(
|
72 |
+
'Streaming dataset available in .test_dataset because label_col=None')
|
73 |
+
else:
|
74 |
+
# dataset into train_dataset and/or test_dataset
|
75 |
+
if do_train_test_split:
|
76 |
+
ds = self.dataset.train_test_split(
|
77 |
+
test_size=test_size, shuffle=True, seed=self.seed, stratify_by_column=self.label_col)
|
78 |
+
self.train_dataset = ds['train']
|
79 |
+
self.test_dataset = ds['test']
|
80 |
+
print(
|
81 |
+
f'Dataset was splitted into train and test with test_size={test_size}')
|
82 |
+
else:
|
83 |
+
if self.label_col:
|
84 |
+
self.train_dataset = self.dataset
|
85 |
+
else:
|
86 |
+
self.test_dataset = self.dataset
|
87 |
+
|
88 |
+
if encode_on_the_fly:
|
89 |
+
if self.train_dataset:
|
90 |
+
self.train_dataset.set_transform(self._encode_on_the_fly)
|
91 |
+
print('On-the-fly encoded dataset available in .train_dataset')
|
92 |
+
if self.test_dataset:
|
93 |
+
self.test_dataset.set_transform(self._encode_on_the_fly)
|
94 |
+
print('On-the-fly encoded dataset available in .test_dataset')
|
95 |
+
else:
|
96 |
+
if self.train_dataset:
|
97 |
+
self.train_dataset = self._encode(self.train_dataset)
|
98 |
+
print('Pre-encoded dataset available in .train_dataset')
|
99 |
+
if self.test_dataset:
|
100 |
+
self.test_dataset = self._encode(self.test_dataset)
|
101 |
+
print('Pre-encoded dataset available in .test_dataset')
|
102 |
+
|
103 |
+
def __len__(self):
|
104 |
+
if self.streaming_dataset:
|
105 |
+
raise ValueError(
|
106 |
+
f'Streaming dataset does not support len() method')
|
107 |
+
return len(self.dataset)
|
108 |
+
|
109 |
+
def __getitem__(self, index):
|
110 |
+
if self.streaming_dataset:
|
111 |
+
return next(iter(self.dataset))
|
112 |
+
return self.dataset[index]
|
113 |
+
|
114 |
+
def _streaming_generate_examples(self, iterable):
|
115 |
+
id_ = 0
|
116 |
+
# TODO: make sure GeneratorType is pyarrow.RecordBatch
|
117 |
+
if isinstance(iterable, types.GeneratorType):
|
118 |
+
for examples in iterable:
|
119 |
+
for ex in examples.to_pylist():
|
120 |
+
yield id_, ex
|
121 |
+
id_ += 1
|
122 |
+
|
123 |
+
def _preprocess(self):
|
124 |
+
if self._with_transcript:
|
125 |
+
self.dataset = self.dataset.filter(
|
126 |
+
lambda example: example['regret_transcript'] is not None and example['recommendation_transcript'] is not None)
|
127 |
+
else:
|
128 |
+
self.dataset = self.dataset.filter(
|
129 |
+
lambda example: example['regret_transcript'] is None or example['recommendation_transcript'] is None)
|
130 |
+
if self.label_col:
|
131 |
+
if self.streaming_dataset:
|
132 |
+
if self.label_col in self._stream_dataset_column_names and isinstance(self._stream_dataset_example[self.label_col], str):
|
133 |
+
if not self.label_map:
|
134 |
+
raise ValueError(
|
135 |
+
f'"label_map" dict was not provided and is needed to encode string labels for streaming datasets')
|
136 |
+
# cast_column method had issues with streaming dataset
|
137 |
+
self.dataset = self.dataset.map(
|
138 |
+
self._streaming_rename_labels)
|
139 |
+
else:
|
140 |
+
if self.dataset.features[self.label_col].dtype == 'string':
|
141 |
+
if not self.label_map:
|
142 |
+
self.label_map = {k: v for v, k in enumerate(
|
143 |
+
self.dataset.unique(self.label_col))}
|
144 |
+
self.dataset = self.dataset.filter(
|
145 |
+
lambda example: example[self.label_col] in self.label_map.keys())
|
146 |
+
self.dataset = self.dataset.cast_column(self.label_col, datasets.ClassLabel(
|
147 |
+
num_classes=len(self.label_map), names=list(self.label_map.keys())))
|
148 |
+
|
149 |
+
self.dataset = self.dataset.filter(lambda example: not any(x in [None, ""] for x in [
|
150 |
+
example[key] for key in self._text_features + self.scalar_features + ([self.label_col] if self.label_col else [])])) # dropna
|
151 |
+
|
152 |
+
if self.balance_label_counts and self.label_col and not self.streaming_dataset:
|
153 |
+
label_datasets = {}
|
154 |
+
for label in list(self.label_map.values()):
|
155 |
+
label_dataset = self.dataset.filter(
|
156 |
+
lambda example: example[self.label_col] == label)
|
157 |
+
label_datasets[len(label_dataset)] = label_dataset
|
158 |
+
min_label_count = min(label_datasets)
|
159 |
+
sampled_datasets = [dataset.train_test_split(train_size=min_label_count, shuffle=True, seed=self.seed)[
|
160 |
+
'train'] if len(dataset) != min_label_count else dataset for dataset in label_datasets.values()]
|
161 |
+
self.dataset = datasets.concatenate_datasets(sampled_datasets)
|
162 |
+
|
163 |
+
if self.clean_text:
|
164 |
+
self.dataset = self.dataset.map(self._clean_text, batched=not self.streaming_dataset,
|
165 |
+
batch_size=self.processing_batch_size)
|
166 |
+
self.dataset = self.dataset.map(self._truncate_and_strip_text, batched=not self.streaming_dataset,
|
167 |
+
batch_size=self.processing_batch_size)
|
168 |
+
|
169 |
+
def _streaming_rename_labels(self, example):
|
170 |
+
# rename labels according to label_map if not already correct labels
|
171 |
+
if isinstance(example[self.label_col], list):
|
172 |
+
example[self.label_col] = [self.label_map.get(
|
173 |
+
ex, None) for ex in example[self.label_col] if ex not in self.label_map.values()]
|
174 |
+
elif isinstance(example[self.label_col], str) and example[self.label_col] not in self.label_map.values():
|
175 |
+
example[self.label_col] = self.label_map.get(
|
176 |
+
example[self.label_col], None)
|
177 |
+
else:
|
178 |
+
raise ValueError(
|
179 |
+
f'Type of example label is {type(example[self.label_col])} when list or string is allowed')
|
180 |
+
return example
|
181 |
+
|
182 |
+
def _clean_text(self, example):
|
183 |
+
for feat in self._text_features:
|
184 |
+
example[feat] = clean_text_funcs(example[feat])[0] if isinstance(
|
185 |
+
example[feat], str) else clean_text_funcs(example[feat])
|
186 |
+
return example
|
187 |
+
|
188 |
+
def _truncate_and_strip_text(self, example):
|
189 |
+
# tokenizer will truncate to max_length tokens anyway so to save RAM let's truncate to max_length words already beforehand
|
190 |
+
# one word is usually one or more tokens so should be safe to truncate this way without losing information
|
191 |
+
for feat in self._text_features:
|
192 |
+
if isinstance(example[feat], list):
|
193 |
+
example[feat] = [
|
194 |
+
' '.join(text.split()[:self.max_length]).strip() for text in example[feat] if text]
|
195 |
+
elif isinstance(example[feat], str):
|
196 |
+
example[feat] = ' '.join(example[feat].split()[
|
197 |
+
:self.max_length]).strip()
|
198 |
+
elif example[feat] is None:
|
199 |
+
return None
|
200 |
+
else:
|
201 |
+
raise ValueError(
|
202 |
+
f'Type of example is {type(example[feat])} when list or string is allowed')
|
203 |
+
return example
|
204 |
+
|
205 |
+
def _encode(self, dataset):
|
206 |
+
encoded_dataset = None
|
207 |
+
for text_type in self.text_types:
|
208 |
+
encoded_text_type = dataset.map(lambda regret, recommendation: self.tokenizer(regret, recommendation, padding="max_length", truncation=True, max_length=self.max_length), batched=True,
|
209 |
+
batch_size=self.processing_batch_size, num_proc=self.processing_num_proc, input_columns=[f'regret_{text_type}', f'recommendation_{text_type}'], remove_columns=dataset.column_names)
|
210 |
+
encoded_text_type = encoded_text_type.rename_columns(
|
211 |
+
{col: f'{text_type}_{col}' for col in encoded_text_type.column_names}) # e.g. input_ids -> title_input_ids so we have separate input_ids for each text_type
|
212 |
+
if encoded_dataset:
|
213 |
+
encoded_dataset = datasets.concatenate_datasets(
|
214 |
+
[encoded_dataset, encoded_text_type], axis=1)
|
215 |
+
else:
|
216 |
+
encoded_dataset = encoded_text_type
|
217 |
+
|
218 |
+
# copy scalar features and label from original dataset to the encoded dataset
|
219 |
+
for scalar_feat in self.scalar_features:
|
220 |
+
encoded_dataset = encoded_dataset.add_column(
|
221 |
+
name=scalar_feat, column=dataset[scalar_feat])
|
222 |
+
if self.label_col:
|
223 |
+
encoded_dataset = encoded_dataset.add_column(
|
224 |
+
name=self.label_col, column=dataset[self.label_col])
|
225 |
+
if self.keep_video_ids_for_predictions:
|
226 |
+
for id in ['regret_id', "recommendation_id"]:
|
227 |
+
encoded_dataset = encoded_dataset.add_column(
|
228 |
+
name=id, column=dataset[id])
|
229 |
+
|
230 |
+
encoded_dataset.set_format(
|
231 |
+
type='torch', columns=encoded_dataset.column_names)
|
232 |
+
return encoded_dataset
|
233 |
+
|
234 |
+
def _encode_streaming(self, dataset):
|
235 |
+
encoded_dataset = dataset.map(self._encode_on_the_fly, batched=True,
|
236 |
+
batch_size=self.processing_batch_size, remove_columns=list(set(self._stream_dataset_column_names)-set(self.scalar_features + (
|
237 |
+
[self.label_col] if self.label_col else []) + (['regret_id', "recommendation_id"] if self.keep_video_ids_for_predictions else [])))) # IterableDataset doesn't have column_names attribute as normal Dataset
|
238 |
+
encoded_dataset = encoded_dataset.with_format("torch")
|
239 |
+
return encoded_dataset
|
240 |
+
|
241 |
+
def _encode_on_the_fly(self, batch):
|
242 |
+
for text_type in self.text_types:
|
243 |
+
encoded_text_type = dict(self.tokenizer(
|
244 |
+
batch[f'regret_{text_type}'], batch[f'recommendation_{text_type}'], padding="max_length", truncation=True, max_length=self.max_length, return_tensors="pt"))
|
245 |
+
for encoded_key in encoded_text_type.copy():
|
246 |
+
encoded_text_type[f"{text_type}_{encoded_key}"] = encoded_text_type.pop(encoded_key) if not self.streaming_dataset else encoded_text_type.pop(
|
247 |
+
encoded_key).squeeze(0) # e.g. input_ids -> title_input_ids so we have separate input_ids for each text_type
|
248 |
+
del batch[f'regret_{text_type}']
|
249 |
+
del batch[f'recommendation_{text_type}']
|
250 |
+
batch.update(encoded_text_type)
|
251 |
+
for scalar_feat in self.scalar_features:
|
252 |
+
batch[scalar_feat] = torch.as_tensor(
|
253 |
+
batch[scalar_feat]) if not self.streaming_dataset else torch.as_tensor(batch[scalar_feat]).squeeze(0)
|
254 |
+
if self.label_col:
|
255 |
+
batch[self.label_col] = torch.as_tensor(
|
256 |
+
batch[self.label_col]) if not self.streaming_dataset else torch.as_tensor(batch[self.label_col]).squeeze(0)
|
257 |
+
return batch
|
258 |
+
|
259 |
+
|
260 |
+
class RRUM(pl.LightningModule):
|
261 |
+
def __init__(self, text_types, scalar_features, label_col, cross_encoder_model_name_or_path, optimizer_config=None, freeze_policy=None, pos_weight=None):
|
262 |
+
super().__init__()
|
263 |
+
self.save_hyperparameters()
|
264 |
+
self.text_types = text_types
|
265 |
+
self.scalar_features = scalar_features
|
266 |
+
self.label_col = label_col
|
267 |
+
self.optimizer_config = optimizer_config
|
268 |
+
self.cross_encoder_model_name_or_path = cross_encoder_model_name_or_path
|
269 |
+
self.cross_encoders = nn.ModuleDict({})
|
270 |
+
for t in self.text_types:
|
271 |
+
self.cross_encoders[t] = AutoModelForSequenceClassification.from_pretrained(
|
272 |
+
self.cross_encoder_model_name_or_path)
|
273 |
+
if freeze_policy is not None:
|
274 |
+
for xe in self.cross_encoders.values():
|
275 |
+
for name, param in xe.named_parameters():
|
276 |
+
if freeze_policy(name):
|
277 |
+
param.requires_grad = False
|
278 |
+
cross_encoder_out_features = list(self.cross_encoders.values())[0](
|
279 |
+
torch.randint(1, 2, (1, 2))).logits.size(dim=1)
|
280 |
+
self.lin1 = nn.Linear(len(self.cross_encoders) * cross_encoder_out_features +
|
281 |
+
len(self.scalar_features), 1)
|
282 |
+
self.ac_metric = torchmetrics.Accuracy()
|
283 |
+
self.pr_metric = torchmetrics.Precision()
|
284 |
+
self.re_metric = torchmetrics.Recall()
|
285 |
+
self.auc_metric = torchmetrics.AUROC()
|
286 |
+
|
287 |
+
if pos_weight:
|
288 |
+
self.loss = nn.BCEWithLogitsLoss(
|
289 |
+
pos_weight=torch.Tensor([pos_weight]))
|
290 |
+
else:
|
291 |
+
self.loss = nn.BCEWithLogitsLoss()
|
292 |
+
|
293 |
+
def forward(self, x):
|
294 |
+
cross_logits = {}
|
295 |
+
for f in self.text_types:
|
296 |
+
inputs = {key.split(f'{f}_')[1]: x[key]
|
297 |
+
for key in x if f in key} # e.g. title_input_ids -> input_ids since we have separate input_ids for each text_type
|
298 |
+
cross_logits[f] = self.cross_encoders[f](**inputs).logits
|
299 |
+
x = torch.cat([*cross_logits.values()] +
|
300 |
+
[x[scalar][:, None] for scalar in self.scalar_features],
|
301 |
+
1
|
302 |
+
)
|
303 |
+
del cross_logits
|
304 |
+
|
305 |
+
x = self.lin1(x)
|
306 |
+
return x
|
307 |
+
|
308 |
+
def configure_optimizers(self):
|
309 |
+
if self.optimizer_config:
|
310 |
+
return self.optimizer_config(self)
|
311 |
+
|
312 |
+
optimizer = torch.optim.AdamW(self.parameters(), lr=5e-5)
|
313 |
+
scheduler = get_linear_schedule_with_warmup(
|
314 |
+
optimizer,
|
315 |
+
num_warmup_steps=int(
|
316 |
+
self.trainer.estimated_stepping_batches * 0.05),
|
317 |
+
num_training_steps=self.trainer.estimated_stepping_batches,
|
318 |
+
)
|
319 |
+
scheduler = {'scheduler': scheduler,
|
320 |
+
'interval': 'step', 'frequency': 1}
|
321 |
+
return [optimizer], [scheduler]
|
322 |
+
|
323 |
+
def training_step(self, train_batch, batch_idx):
|
324 |
+
y = train_batch[self.label_col].unsqueeze(1).float()
|
325 |
+
logits = self(train_batch)
|
326 |
+
loss = self.loss(logits, y)
|
327 |
+
self.log('train_loss', loss)
|
328 |
+
return loss
|
329 |
+
|
330 |
+
def validation_step(self, val_batch, batch_idx):
|
331 |
+
y = val_batch[self.label_col].unsqueeze(1).float()
|
332 |
+
logits = self(val_batch)
|
333 |
+
loss = self.loss(logits, y)
|
334 |
+
self.ac_metric(logits, y.int())
|
335 |
+
self.pr_metric(logits, y.int())
|
336 |
+
self.re_metric(logits, y.int())
|
337 |
+
self.auc_metric(logits, y.int())
|
338 |
+
self.log('validation_accuracy', self.ac_metric)
|
339 |
+
self.log('validation_precision', self.pr_metric)
|
340 |
+
self.log('validation_recall', self.re_metric)
|
341 |
+
self.log('validation_auc', self.auc_metric)
|
342 |
+
self.log('val_loss', loss, prog_bar=True)
|
343 |
+
|
344 |
+
def validation_epoch_end(self, outputs):
|
345 |
+
self.log('validation_accuracy_ep', self.ac_metric)
|
346 |
+
self.log('validation_precision_ep', self.pr_metric)
|
347 |
+
self.log('validation_recall_ep', self.re_metric)
|
348 |
+
self.log('validation_auc_ep', self.auc_metric)
|