A-M-S committed on
Commit
d89a1db
1 Parent(s): 086c6d8

Updated app.py

Browse files
Files changed (3) hide show
  1. app.py +49 -1
  2. preprocess.py +4 -89
  3. utility.py +3 -48
app.py CHANGED
@@ -7,6 +7,54 @@ text = st.text_input('Enter text')
7
 
8
  # out = model()
9
 
 
 
 
10
  if st.button("Predict"):
11
- st.write("Genre: ")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  # st.write(out)
 
7
 
8
  # out = model()
9
 
10
+ model = AutoModelForSequenceClassification.from_pretrained("./bert-finetuned-sem_eval-english/checkpoint-36819")
11
+ # model.to('cuda')
12
+
13
  if st.button("Predict"):
14
+ st.write("Genre: ")
15
+
16
+ preprocess = Preprocess()
17
+ clean_plot = preprocess.apply(text)
18
+ st.write(clean_plot)
19
+
20
+ utility = Utility()
21
+ # id2label, label2id, tokenizer, df = utility.tokenize(df, genres)
22
+ # xtrain, xval, ytrain, yval = utility.train_test_split(df, y)
23
+ # xtrain_input_ids = [np.asarray(xtrain[i]['input_ids']) for i in range(xtrain.shape[0])]
24
+ # xtrain_attention_mask = [np.asarray(xtrain[i]['attention_mask']) for i in range(xtrain.shape[0])]
25
+ # xval_input_ids = [np.asarray(xval[i]['input_ids']) for i in range(xval.shape[0])]
26
+ # xval_attention_mask = [np.asarray(xval[i]['attention_mask']) for i in range(xval.shape[0])]
27
+
28
+ # # create Pandas DataFrame
29
+ # input_ids_labels_df = pd.DataFrame({'input_ids': xtrain_input_ids, 'attention_mask': xtrain_attention_mask, 'labels': ytrain.tolist()})
30
+ # # define data set object
31
+ # TD = CustomTextDataset(torch.IntTensor(input_ids_labels_df['input_ids']), torch.IntTensor(input_ids_labels_df['attention_mask']),\
32
+ # torch.FloatTensor(input_ids_labels_df['labels']))
33
+ # input_ids_labels_val_df = pd.DataFrame({'input_ids': xval_input_ids, 'attention_mask': xval_attention_mask, 'labels': yval.tolist()})
34
+ # VD = CustomTextDataset(torch.IntTensor(input_ids_labels_val_df['input_ids']), torch.IntTensor(input_ids_labels_val_df['attention_mask']),\
35
+ # torch.FloatTensor(input_ids_labels_val_df['labels']))
36
+
37
+ # # trainer = Trainer(
38
+ # # model,
39
+ # # train_dataset=TD,
40
+ # # eval_dataset=VD,
41
+ # # tokenizer=tokenizer,
42
+ # # compute_metrics=compute_metrics
43
+ # # )
44
+ # # y_pred = trainer.predict(VD)
45
+ # # y_pred = model(input_ids, attention_mask)
46
+ # preds = torch.FloatTensor(y_pred[0])
47
+
48
+ # y_predictions = []
49
+ # predictions = []
50
+
51
+ # for pred in preds:
52
+ # # apply sigmoid + threshold
53
+ # sigmoid = torch.nn.Sigmoid()
54
+ # probs = sigmoid(pred.squeeze().cpu())
55
+ # prediction = np.zeros(probs.shape)
56
+ # prediction[np.where(probs >= 0.5)] = 1
57
+ # predictions.append(prediction)
58
+ # y_pred = predictions
59
+
60
  # st.write(out)
preprocess.py CHANGED
@@ -2,76 +2,16 @@ import json
2
  import pickle
3
  import pandas as pd
4
  import nltk
5
- from matplotlib import pyplot as plt
6
- import seaborn as sns
7
  import regex as re
8
  from nltk.corpus import stopwords
9
- from sklearn.preprocessing import MultiLabelBinarizer
10
 
11
  class Preprocess:
12
- df = None
13
  genres = None
14
  y = None
15
 
16
  def __init__(self) -> None:
17
- self.df = pd.read_csv('movies_genre.csv')
18
  self.genres = []
19
 
20
- def plot_freq_dist(self):
21
- all_genres = sum(self.genres, [])
22
- all_genres = nltk.FreqDist(all_genres)
23
-
24
- # create frequency dataframe
25
- all_genres_df = pd.DataFrame({'Genres': list(all_genres.keys()),
26
- 'Count': list(all_genres.values())})
27
-
28
- g = all_genres_df.nlargest(columns="Count", n = 50)
29
- plt.figure(figsize=(12,15))
30
- ax = sns.barplot(data=g, x= "Count", y = "Genres")
31
- ax.set(xlabel = 'Count',ylabel= 'Genre')
32
- plt.show()
33
-
34
- # def extract_genre_values(self):
35
- # # extract genres
36
- # for row in self.df['genres']:
37
- # self.genres.append(list(json.loads(row.replace("\'", "\"")).values()))
38
-
39
- # # add to dataframe
40
- # self.df['genres'] = self.genres
41
-
42
- def retain_top_freq_genres(self):
43
- for (index, row) in enumerate(self.df['genres']):
44
- self.genres.append(json.loads(row.replace("\'", "\"")))
45
- self.df.at[index, "genres"] = json.loads(row.replace("\'", "\""))
46
- # create frequency dataframe
47
- all_genres = sum(self.genres,[])
48
- all_genres = nltk.FreqDist(all_genres)
49
- all_genres_df = pd.DataFrame({'Genres': list(all_genres.keys()),
50
- 'Count': list(all_genres.values())})
51
-
52
- # # considering only top 35 frequent genres
53
- # g = all_genres_df.nlargest(columns="Count", n = 35)
54
- # g.head()
55
- # top_genres = list(g['Genres'])
56
-
57
- # Genres with freq > 1000
58
- all_genres_df = all_genres_df[all_genres_df["Count"] >= 8000]
59
- top_genres = list(all_genres_df['Genres'])
60
-
61
- # Removing genres which are not important
62
- # top_genres.remove('Other')
63
- # top_genres.remove('Crime Thriller')
64
- # top_genres.remove('Movie')
65
- # top_genres.remove('History')
66
- # top_genres.remove('Bollywood')
67
-
68
- # Removing genres other than top selected genres
69
- for (index,row) in enumerate(self.df['genres']):
70
- row = [genre for genre in row if genre in top_genres]
71
- self.df.at[index, "genres"] = row
72
-
73
- return top_genres
74
-
75
  def clean_text(self, text):
76
  """Cleans text by removing certains unwanted characters"""
77
 
@@ -92,35 +32,10 @@ class Preprocess:
92
  no_stopword_text = [w for w in text.split() if not w in stop_words]
93
  return ' '.join(no_stopword_text)
94
 
 
 
95
 
96
- def multi_label_binarizer(self):
97
- multilabel_binarizer = MultiLabelBinarizer()
98
- multilabel_binarizer.fit(self.df['genres'])
99
-
100
- pickle.dump(multilabel_binarizer, open("models/multilabel_binarizer", 'wb'))
101
-
102
- # transform target variable
103
- self.y = multilabel_binarizer.transform(self.df['genres'])
104
-
105
-
106
- def apply(self):
107
- # remove samples with no plot
108
- self.df = self.df[~(pd.isna(self.df['plot']))]
109
- # self.df = self.df.head(20000)
110
- # removing rows which has very small plot fewer than 500 characters
111
- self.df = self.df[self.df["plot"].map(len) >= 500]
112
- self.df = self.df.reset_index()
113
-
114
- # self.extract_genre_values()
115
-
116
- genres = self.retain_top_freq_genres()
117
-
118
-
119
- self.df['clean_plot'] = self.df['plot'].apply(lambda x: self.clean_text(str(x)))
120
-
121
- self.df['clean_plot'] = self.df['clean_plot'].apply(lambda x: self.remove_stopwords(str(x)))
122
-
123
- self.multi_label_binarizer()
124
 
125
- return [self.df, self.y, genres]
126
 
 
2
  import pickle
3
  import pandas as pd
4
  import nltk
 
 
5
  import regex as re
6
  from nltk.corpus import stopwords
 
7
 
8
  class Preprocess:
 
9
  genres = None
10
  y = None
11
 
12
  def __init__(self) -> None:
 
13
  self.genres = []
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  def clean_text(self, text):
16
  """Cleans text by removing certains unwanted characters"""
17
 
 
32
  no_stopword_text = [w for w in text.split() if not w in stop_words]
33
  return ' '.join(no_stopword_text)
34
 
35
+ def apply(self, plot):
36
+ clean_plot = self.clean_text(str(plot))
37
 
38
+ clean_plot = self.remove_stopwords(str(clean_plot))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
+ return clean_plot
41
 
utility.py CHANGED
@@ -1,54 +1,20 @@
1
  import pickle
2
  import wikipedia
3
  import numpy as np
4
- from sklearn.model_selection import train_test_split
5
- from skmultilearn.model_selection import iterative_train_test_split
6
- from sklearn.feature_extraction.text import TfidfVectorizer
7
  from transformers import AutoTokenizer
8
 
9
  class Utility:
10
  def __init__(self) -> None:
11
  pass
12
 
13
- def get_summary(self,url):
14
- summary = ""
15
- try:
16
- title = url.split("wiki/")[-1]
17
- print(title)
18
-
19
- wiki = wikipedia.page(title=title)
20
-
21
- summary = wiki.summary
22
- except:
23
- pass
24
-
25
- return summary
26
-
27
-
28
- def get_plot(self,url):
29
- plot=""
30
-
31
- try:
32
- title = url.split("wiki/")[-1]
33
- wiki = wikipedia.page(title=title)
34
-
35
- content = wiki.content.split('== Plot ==\n')[1]
36
- plot = content.split('==')[0]
37
- except:
38
- pass
39
-
40
- return plot
41
-
42
- def tokenize(self, df, genres):
43
  id2label = {idx:label for idx, label in enumerate(genres)}
44
  label2id = {label:idx for idx, label in enumerate(genres)}
45
 
46
  tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
47
 
48
- df['clean_plot_tokenized'] = ''
49
- for (idx, row) in df.iterrows():
50
- df.at[idx,"clean_plot_tokenized"] = tokenizer(row["clean_plot"], padding="max_length", truncation=True, max_length=512)
51
- return (id2label, label2id, tokenizer, df)
52
 
53
  def train_test_split(self, df, y):
54
  """Splits the dataset into training and validation set"""
@@ -61,14 +27,3 @@ class Utility:
61
  xval = np.array(xval).flatten()
62
 
63
  return (xtrain, xval, ytrain, yval)
64
-
65
-
66
- def vectorize(self, xtrain, xval):
67
- """Creates TF-IDF features"""
68
- tfidf_vectorizer = TfidfVectorizer(max_df=0.8, max_features=10000)
69
- xtrain_tfidf = tfidf_vectorizer.fit_transform(xtrain)
70
- xval_tfidf = tfidf_vectorizer.transform(xval)
71
-
72
- pickle.dump(tfidf_vectorizer, open("models/tfidf_vectorizer", 'wb'))
73
-
74
- return (xtrain_tfidf, xval_tfidf)
 
1
  import pickle
2
  import wikipedia
3
  import numpy as np
 
 
 
4
  from transformers import AutoTokenizer
5
 
6
  class Utility:
7
  def __init__(self) -> None:
8
  pass
9
 
10
+ def tokenize(self, plot, genres):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  id2label = {idx:label for idx, label in enumerate(genres)}
12
  label2id = {label:idx for idx, label in enumerate(genres)}
13
 
14
  tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
15
 
16
+ clean_plot_tokenized = tokenizer(plot, padding="max_length", truncation=True, max_length=512)
17
+ return (id2label, label2id, tokenizer, clean_plot_tokenized)
 
 
18
 
19
  def train_test_split(self, df, y):
20
  """Splits the dataset into training and validation set"""
 
27
  xval = np.array(xval).flatten()
28
 
29
  return (xtrain, xval, ytrain, yval)