Spaces:
Runtime error
Runtime error
File size: 1,123 Bytes
086c6d8 d89a1db 086c6d8 d89a1db 086c6d8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 |
import pickle
import numpy as np
from transformers import AutoTokenizer
class Utility:
    """Tokenization and train/validation splitting helpers for plot/genre data."""

    def __init__(self) -> None:
        # Loaded tokenizers keyed by checkpoint name, so repeated tokenize()
        # calls do not reload the same checkpoint each time.
        self._tokenizers = {}

    def tokenize(self, plot, genres, *, model_name="distilbert-base-uncased", max_length=512):
        """Tokenize ``plot`` and build index<->label lookup tables for ``genres``.

        Args:
            plot: Text (or batch of texts) handed to the tokenizer as-is.
            genres: Iterable of labels; a label's position becomes its id.
            model_name: Checkpoint name passed to ``AutoTokenizer.from_pretrained``.
            max_length: Length every sequence is padded/truncated to.

        Returns:
            Tuple ``(id2label, label2id, tokenizer, clean_plot_tokenized)``.
        """
        id2label = dict(enumerate(genres))
        # Derived from id2label (not a second pass over ``genres``) so a
        # one-shot iterable still yields a populated reverse mapping.
        label2id = {label: idx for idx, label in id2label.items()}

        tokenizer = self._tokenizers.get(model_name)
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            self._tokenizers[model_name] = tokenizer

        clean_plot_tokenized = tokenizer(
            plot, padding="max_length", truncation=True, max_length=max_length
        )
        return (id2label, label2id, tokenizer, clean_plot_tokenized)

    def train_test_split(self, df, y, test_size=0.2):
        """Split the dataset into training and validation sets (stratified).

        Args:
            df: Table-like object with a ``'clean_plot_tokenized'`` column.
            y: Label matrix aligned with ``df`` rows -- presumably a 2-D
                multi-label indicator array, as the splitter requires; verify
                against the caller.
            test_size: Fraction of samples placed in the validation set.

        Returns:
            Tuple ``(xtrain, xval, ytrain, yval)``; ``xtrain``/``xval`` are
            flattened 1-D arrays of the column's values.
        """
        # Imported here because the module never brought this name into scope,
        # so every call previously failed with NameError.
        from skmultilearn.model_selection import iterative_train_test_split

        # The splitter row-indexes X as a 2-D structure, so the single column is
        # presented as an (n_samples, 1) matrix.
        features = np.asmatrix(df['clean_plot_tokenized']).transpose()

        # NOTE: the splitter returns (X_train, y_train, X_test, y_test), which is
        # a different order from this method's own return tuple.
        xtrain, ytrain, xval, yval = iterative_train_test_split(
            features, y, test_size=test_size
        )
        xtrain = np.array(xtrain).flatten()
        xval = np.array(xval).flatten()
        return (xtrain, xval, ytrain, yval)
|