Spaces:
Runtime error
Runtime error
Added Utility and preprocess files
Browse files- preprocess.py +126 -0
- utility.py +74 -0
preprocess.py
ADDED
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import pickle
|
3 |
+
import pandas as pd
|
4 |
+
import nltk
|
5 |
+
from matplotlib import pyplot as plt
|
6 |
+
import seaborn as sns
|
7 |
+
import regex as re
|
8 |
+
from nltk.corpus import stopwords
|
9 |
+
from sklearn.preprocessing import MultiLabelBinarizer
|
10 |
+
|
11 |
+
class Preprocess:
    """Prepare the movie-genre dataset for multi-label genre classification.

    Loads 'movies_genre.csv', parses and filters the genre lists, cleans the
    plot text, and binarizes the genre labels.
    """

    def __init__(self) -> None:
        # Raw dataset; expected to contain at least 'plot' and 'genres'
        # columns — TODO confirm schema against movies_genre.csv.
        self.df = pd.read_csv('movies_genre.csv')
        # Per-row genre lists, filled by retain_top_freq_genres().
        self.genres = []
        # Binarized label matrix, filled by multi_label_binarizer().
        self.y = None

    def plot_freq_dist(self):
        """Show a bar chart of the 50 most frequent genres."""
        all_genres = nltk.FreqDist(sum(self.genres, []))

        # create frequency dataframe
        all_genres_df = pd.DataFrame({'Genres': list(all_genres.keys()),
                                      'Count': list(all_genres.values())})

        g = all_genres_df.nlargest(columns="Count", n=50)
        plt.figure(figsize=(12, 15))
        ax = sns.barplot(data=g, x="Count", y="Genres")
        ax.set(xlabel='Count', ylabel='Genre')
        plt.show()

    def retain_top_freq_genres(self):
        """Parse the 'genres' column and keep only high-frequency genres.

        Genres occurring fewer than 8000 times are removed from every row.
        Returns the list of retained genre names.
        """
        for (index, row) in enumerate(self.df['genres']):
            # The CSV stores genre lists with single quotes; normalize to
            # valid JSON once, then reuse the parsed value.
            parsed = json.loads(row.replace("\'", "\""))
            self.genres.append(parsed)
            self.df.at[index, "genres"] = parsed

        # create frequency dataframe
        all_genres = nltk.FreqDist(sum(self.genres, []))
        all_genres_df = pd.DataFrame({'Genres': list(all_genres.keys()),
                                      'Count': list(all_genres.values())})

        # Keep only genres with frequency >= 8000.
        all_genres_df = all_genres_df[all_genres_df["Count"] >= 8000]
        top_genres = list(all_genres_df['Genres'])

        # Drop genres outside the retained set from every row.
        for (index, row) in enumerate(self.df['genres']):
            self.df.at[index, "genres"] = [g for g in row if g in top_genres]

        return top_genres

    def clean_text(self, text):
        """Clean text by removing certain unwanted characters.

        Strips apostrophes, replaces everything non-alphabetic with a space,
        collapses whitespace, and lower-cases the result.
        """
        # remove apostrophes ("it's" -> "its")
        text = re.sub(r"\'", "", text)
        # remove everything except alphabets
        text = re.sub(r"[^a-zA-Z]", " ", text)
        # collapse runs of whitespace
        text = ' '.join(text.split())
        # convert text to lowercase
        return text.lower()

    def remove_stopwords(self, text):
        """Remove English stopwords from *text*."""
        # Build the stopword set once per instance instead of on every call.
        if not hasattr(self, '_stop_words'):
            self._stop_words = set(stopwords.words('english'))
        return ' '.join(w for w in text.split() if w not in self._stop_words)

    def multi_label_binarizer(self):
        """Fit a MultiLabelBinarizer on the genre lists and store labels.

        Persists the fitted binarizer to models/multilabel_binarizer and
        stores the binarized label matrix in self.y.
        """
        multilabel_binarizer = MultiLabelBinarizer()
        multilabel_binarizer.fit(self.df['genres'])

        # 'with' guarantees the file handle is closed after pickling.
        with open("models/multilabel_binarizer", 'wb') as f:
            pickle.dump(multilabel_binarizer, f)

        # transform target variable
        self.y = multilabel_binarizer.transform(self.df['genres'])

    def apply(self):
        """Run the full preprocessing pipeline.

        Returns [cleaned dataframe, binarized label matrix, retained genres].
        """
        # remove samples with no plot
        self.df = self.df[~(pd.isna(self.df['plot']))]
        # removing rows which have very small plots, fewer than 500 characters
        self.df = self.df[self.df["plot"].map(len) >= 500]
        self.df = self.df.reset_index()

        genres = self.retain_top_freq_genres()

        self.df['clean_plot'] = self.df['plot'].apply(lambda x: self.clean_text(str(x)))
        self.df['clean_plot'] = self.df['clean_plot'].apply(lambda x: self.remove_stopwords(str(x)))

        self.multi_label_binarizer()

        return [self.df, self.y, genres]
utility.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pickle
|
2 |
+
import wikipedia
|
3 |
+
import numpy as np
|
4 |
+
from sklearn.model_selection import train_test_split
|
5 |
+
from skmultilearn.model_selection import iterative_train_test_split
|
6 |
+
from sklearn.feature_extraction.text import TfidfVectorizer
|
7 |
+
from transformers import AutoTokenizer
|
8 |
+
|
9 |
+
class Utility:
    """Helpers for fetching Wikipedia text and building model inputs."""

    def __init__(self) -> None:
        pass

    def get_summary(self, url):
        """Return the Wikipedia summary for a .../wiki/<title> URL.

        Best-effort: returns an empty string if the page cannot be fetched.
        """
        summary = ""
        try:
            title = url.split("wiki/")[-1]
            print(title)

            wiki = wikipedia.page(title=title)
            summary = wiki.summary
        except Exception:
            # Missing page / disambiguation / network errors: fall back to "".
            # (Narrowed from a bare except so Ctrl-C still interrupts.)
            pass
        return summary

    def get_plot(self, url):
        """Return the '== Plot ==' section of a Wikipedia article.

        Best-effort: returns an empty string when the page or the Plot
        section is missing.
        """
        plot = ""
        try:
            title = url.split("wiki/")[-1]
            wiki = wikipedia.page(title=title)

            # Text between the Plot heading and the next section heading.
            content = wiki.content.split('== Plot ==\n')[1]
            plot = content.split('==')[0]
        except Exception:
            pass
        return plot

    def tokenize(self, df, genres):
        """Tokenize df['clean_plot'] with the DistilBERT tokenizer.

        Returns (id2label, label2id, tokenizer, df); df gains a
        'clean_plot_tokenized' column of tokenizer encodings.
        """
        id2label = {idx: label for idx, label in enumerate(genres)}
        label2id = {label: idx for idx, label in enumerate(genres)}

        tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

        df['clean_plot_tokenized'] = ''
        for (idx, row) in df.iterrows():
            df.at[idx, "clean_plot_tokenized"] = tokenizer(
                row["clean_plot"], padding="max_length", truncation=True, max_length=512)
        return (id2label, label2id, tokenizer, df)

    def train_test_split(self, df, y):
        """Split the dataset into training and validation sets (80/20).

        Uses iterative stratification so rare label combinations appear in
        both splits.
        """
        # iterative_train_test_split expects 2-D X; flatten back afterwards.
        xtrain, ytrain, xval, yval = iterative_train_test_split(
            np.asmatrix(df['clean_plot_tokenized']).transpose(), y, test_size=0.2)
        xtrain = np.array(xtrain).flatten()
        xval = np.array(xval).flatten()

        return (xtrain, xval, ytrain, yval)

    def vectorize(self, xtrain, xval):
        """Create TF-IDF features; vocabulary is fit on the training set only.

        Persists the fitted vectorizer to models/tfidf_vectorizer and returns
        (xtrain_tfidf, xval_tfidf).
        """
        tfidf_vectorizer = TfidfVectorizer(max_df=0.8, max_features=10000)
        xtrain_tfidf = tfidf_vectorizer.fit_transform(xtrain)
        xval_tfidf = tfidf_vectorizer.transform(xval)

        # 'with' guarantees the file handle is closed after pickling.
        with open("models/tfidf_vectorizer", 'wb') as f:
            pickle.dump(tfidf_vectorizer, f)

        return (xtrain_tfidf, xval_tfidf)