A-M-S committed on
Commit 086c6d8
1 Parent(s): e5c2719

Added Utility and preprocess files

Files changed (2):
  1. preprocess.py +126 -0
  2. utility.py +74 -0
preprocess.py ADDED
import json
import pickle

import nltk
import pandas as pd
import regex as re
import seaborn as sns
from matplotlib import pyplot as plt
from nltk.corpus import stopwords
from sklearn.preprocessing import MultiLabelBinarizer


class Preprocess:
    df = None
    genres = None
    y = None

    def __init__(self) -> None:
        self.df = pd.read_csv('movies_genre.csv')
        self.genres = []

    def plot_freq_dist(self):
        """Plots the 50 most frequent genres as a horizontal bar chart."""
        all_genres = sum(self.genres, [])
        all_genres = nltk.FreqDist(all_genres)

        # create frequency dataframe
        all_genres_df = pd.DataFrame({'Genres': list(all_genres.keys()),
                                      'Count': list(all_genres.values())})

        g = all_genres_df.nlargest(columns="Count", n=50)
        plt.figure(figsize=(12, 15))
        ax = sns.barplot(data=g, x="Count", y="Genres")
        ax.set(xlabel='Count', ylabel='Genre')
        plt.show()

    def retain_top_freq_genres(self):
        """Parses the genre strings and keeps only the most frequent genres."""
        for (index, row) in enumerate(self.df['genres']):
            # genre lists are stored as single-quoted strings; swap the
            # quotes so they parse as JSON
            parsed = json.loads(row.replace("'", '"'))
            self.genres.append(parsed)
            self.df.at[index, "genres"] = parsed

        # create frequency dataframe
        all_genres = sum(self.genres, [])
        all_genres = nltk.FreqDist(all_genres)
        all_genres_df = pd.DataFrame({'Genres': list(all_genres.keys()),
                                      'Count': list(all_genres.values())})

        # keep only genres that occur at least 8000 times
        all_genres_df = all_genres_df[all_genres_df["Count"] >= 8000]
        top_genres = list(all_genres_df['Genres'])

        # drop genres outside the top selected genres from every sample
        for (index, row) in enumerate(self.df['genres']):
            row = [genre for genre in row if genre in top_genres]
            self.df.at[index, "genres"] = row

        return top_genres

    def clean_text(self, text):
        """Cleans text by removing unwanted characters."""
        # remove apostrophes
        text = re.sub("'", "", text)
        # keep only alphabetic characters
        text = re.sub("[^a-zA-Z]", " ", text)
        # collapse repeated whitespace
        text = ' '.join(text.split())
        # convert text to lowercase
        text = text.lower()

        return text

    def remove_stopwords(self, text):
        """Removes English stopwords from the text."""
        stop_words = set(stopwords.words('english'))
        no_stopword_text = [w for w in text.split() if w not in stop_words]
        return ' '.join(no_stopword_text)

    def multi_label_binarizer(self):
        """One-hot encodes the genre lists and pickles the fitted binarizer."""
        multilabel_binarizer = MultiLabelBinarizer()
        multilabel_binarizer.fit(self.df['genres'])

        with open("models/multilabel_binarizer", 'wb') as f:
            pickle.dump(multilabel_binarizer, f)

        # transform target variable
        self.y = multilabel_binarizer.transform(self.df['genres'])

    def apply(self):
        """Runs the full preprocessing pipeline."""
        # remove samples with no plot
        self.df = self.df[~(pd.isna(self.df['plot']))]
        # drop rows whose plot is shorter than 500 characters
        self.df = self.df[self.df["plot"].map(len) >= 500]
        self.df = self.df.reset_index()

        genres = self.retain_top_freq_genres()

        self.df['clean_plot'] = self.df['plot'].apply(lambda x: self.clean_text(str(x)))
        self.df['clean_plot'] = self.df['clean_plot'].apply(lambda x: self.remove_stopwords(str(x)))

        self.multi_label_binarizer()

        return [self.df, self.y, genres]
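
For reference, a minimal driver sketch for this file — assuming movies_genre.csv sits next to the script with plot and genres columns, and that a models/ directory exists for the pickled binarizer (all names are taken from the code above):

from preprocess import Preprocess

# run the full pipeline: filter short/missing plots, keep frequent
# genres, clean the text, and one-hot encode the labels
df, y, genres = Preprocess().apply()
print(df['clean_plot'].head())
print(y.shape, genres)
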
utility.py ADDED
import pickle

import numpy as np
import wikipedia
from skmultilearn.model_selection import iterative_train_test_split
from sklearn.feature_extraction.text import TfidfVectorizer
from transformers import AutoTokenizer


class Utility:
    def __init__(self) -> None:
        pass

    def get_summary(self, url):
        """Fetches the summary of the Wikipedia page behind the given URL."""
        summary = ""
        try:
            title = url.split("wiki/")[-1]
            print(title)

            wiki = wikipedia.page(title=title)
            summary = wiki.summary
        except Exception:
            # fall back to an empty summary when the page cannot be resolved
            pass

        return summary

    def get_plot(self, url):
        """Extracts the 'Plot' section of the Wikipedia page behind the given URL."""
        plot = ""
        try:
            title = url.split("wiki/")[-1]
            wiki = wikipedia.page(title=title)

            content = wiki.content.split('== Plot ==\n')[1]
            plot = content.split('==')[0]
        except Exception:
            # fall back to an empty plot when the page or section is missing
            pass

        return plot

    def tokenize(self, df, genres):
        """Tokenizes the cleaned plots with a DistilBERT tokenizer."""
        id2label = {idx: label for idx, label in enumerate(genres)}
        label2id = {label: idx for idx, label in enumerate(genres)}

        tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

        df['clean_plot_tokenized'] = ''
        for (idx, row) in df.iterrows():
            df.at[idx, "clean_plot_tokenized"] = tokenizer(
                row["clean_plot"], padding="max_length", truncation=True, max_length=512)
        return (id2label, label2id, tokenizer, df)

    def train_test_split(self, df, y):
        """Splits the dataset into training and validation sets."""
        # stratified sampling: iterative_train_test_split balances the label
        # distribution across the splits; it expects a 2-D input, hence the
        # (n, 1) matrix built from the tokenized column
        xtrain, ytrain, xval, yval = iterative_train_test_split(
            np.asmatrix(df['clean_plot_tokenized']).transpose(), y, test_size=0.2)
        xtrain = np.array(xtrain).flatten()
        xval = np.array(xval).flatten()

        return (xtrain, xval, ytrain, yval)

    def vectorize(self, xtrain, xval):
        """Creates TF-IDF features and pickles the fitted vectorizer."""
        tfidf_vectorizer = TfidfVectorizer(max_df=0.8, max_features=10000)
        xtrain_tfidf = tfidf_vectorizer.fit_transform(xtrain)
        xval_tfidf = tfidf_vectorizer.transform(xval)

        with open("models/tfidf_vectorizer", 'wb') as f:
            pickle.dump(tfidf_vectorizer, f)

        return (xtrain_tfidf, xval_tfidf)
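
And a rough sketch of how the two classes might be wired together end to end; the test_size=0.2 and random_state=9 values for the TF-IDF path are illustrative assumptions, not part of this commit:

from sklearn.model_selection import train_test_split

from preprocess import Preprocess
from utility import Utility

df, y, genres = Preprocess().apply()
util = Utility()

# transformer path: DistilBERT tokenization plus a label-stratified split
id2label, label2id, tokenizer, df = util.tokenize(df, genres)
xtrain, xval, ytrain, yval = util.train_test_split(df, y)

# classical path: TF-IDF features over the cleaned plot strings
# (vectorize expects raw text, not the tokenized encodings)
xtr, xvl, ytr, yvl = train_test_split(df['clean_plot'], y, test_size=0.2, random_state=9)
xtr_tfidf, xvl_tfidf = util.vectorize(xtr, xvl)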