WatchTower / Pinpoint /Aggregator_NGram.py
James Stevenson
initial commit
32a03a4
raw
history blame
3.33 kB
from sklearn.feature_extraction.text import CountVectorizer
from Pinpoint.Logger import *
# Shared module-level vectorizer that extracts n-grams of 1 to 5 words. Only
# ngram_range is configured; every other CountVectorizer option keeps its
# scikit-learn default. Each fit_transform() call refits this single object
# and replaces its learned vocabulary_.
c_vec = CountVectorizer(
    ngram_range=(1, 5),
)
class n_gram_aggregator():
    """
    This class is used to retrieve the most common NGrams for a given dataset corpus.

    Only uni-, bi- and tri-grams are produced. An n-gram counts as "popular" when its
    total count is at or above the mean count of the n-grams of the same length.
    """

    def _get_average_ngram_count(self, n_grams_dict):
        """
        Takes a dict of Ngrams and identifies the average weighting (mean count).

        :param n_grams_dict: mapping of n-gram text to its count.
        :return: the arithmetic mean of the counts, or 0 when the dict is empty
            (e.g. a text too short to contain any tri-grams) instead of raising
            ZeroDivisionError.
        """
        counts = list(n_grams_dict.values())
        if not counts:
            return 0
        return sum(counts) / len(counts)

    def _get_all_ngrams(self, data):
        """
        Returns all ngrams (tri, bi, and uni) for a given piece of text, with their counts.

        :param data: a single document (str or bytes) or an iterable of documents.
        :return: tuple ``(uni_grams, bi_grams, tri_grams)``. Each is a dict mapping
            n-gram text to its total count over all documents, inserted in
            descending order of (count, text).
        :raises ValueError: propagated from ``CountVectorizer.fit_transform`` when the
            documents yield an empty vocabulary.
        """
        import numpy as np  # scikit-learn already depends on numpy

        # fit_transform() expects an iterable of documents, so a lone document must be
        # wrapped. Checking for str/bytes (instead of "not a list") lets tuples,
        # generators and other iterables of documents pass through unchanged.
        if isinstance(data, (str, bytes)):
            data = [data]
        # A fresh vectorizer per call avoids mutating shared module-level state, and
        # capping it at 3-word n-grams skips building 4-/5-grams that would be
        # discarded. With the default min_df/max_df, the 1-3 gram counts are unchanged.
        vectorizer = CountVectorizer(ngram_range=(1, 3))
        ngrams = vectorizer.fit_transform(data)
        # vocabulary_ only exists after fit_transform()
        vocab = vectorizer.vocabulary_
        # Sum the columns of the sparse document-term matrix directly; toarray() would
        # materialise a dense documents x vocabulary matrix first.
        count_values = np.asarray(ngrams.sum(axis=0)).ravel()
        by_length = {1: {}, 2: {}, 3: {}}
        ranked = sorted(((count_values[i], k) for k, i in vocab.items()), reverse=True)
        for ng_count, ng_text in ranked:
            # n-gram tokens are joined by single spaces, so the word count is the n
            by_length[len(ng_text.split(" "))][ng_text] = ng_count
        return by_length[1], by_length[2], by_length[3]

    def _get_popular_ngrams(self, ngrams_dict):
        """
        Returns the ngrams whose weighting is at or above the average ngram weighting.

        :param ngrams_dict: mapping of n-gram text to its count.
        :return: a new dict with the qualifying entries, in the input's order.
            An empty input yields an empty dict.
        """
        average_count = self._get_average_ngram_count(ngrams_dict)
        return {
            n_gram: ng_count
            for n_gram, ng_count in ngrams_dict.items()
            if ng_count >= average_count
        }

    def get_ngrams(self, data=None, file_name_to_read=None):
        """
        Wrapper function for returning uni, bi, and tri grams that are the most popular
        (at or above the average weighting in a given piece of text).

        :param data: a document (str/bytes) or an iterable of documents. Takes
            precedence over ``file_name_to_read``.
        :param file_name_to_read: path of a text file to read when ``data`` is None.
        :return: tuple of three lists ``(uni_grams, bi_grams, tri_grams)``, each
            ordered by descending count.
        :raises ValueError: if neither ``data`` nor ``file_name_to_read`` is supplied.
        """
        logger().print_message("Getting Ngrams")
        if data is None and file_name_to_read is None:
            raise ValueError("No data supplied to retrieve n_grams")
        if data is None:
            with open(file_name_to_read, 'r') as file_to_read:
                data = file_to_read.read()
        uni_grams, bi_grams, tri_grams = self._get_all_ngrams(data)
        popular_uni_grams = list(self._get_popular_ngrams(uni_grams))
        popular_bi_grams = list(self._get_popular_ngrams(bi_grams))
        popular_tri_grams = list(self._get_popular_ngrams(tri_grams))
        return popular_uni_grams, popular_bi_grams, popular_tri_grams