lalithadevi
commited on
Commit
•
1462b37
1
Parent(s):
f963240
Rename news_category_prediction.py to news_category_similar_news_prediction.py
Browse files
news_category_prediction.py → news_category_similar_news_prediction.py
RENAMED
@@ -6,6 +6,7 @@ from config import (DISTILBERT_TOKENIZER_N_TOKENS,
|
|
6 |
CLASSIFIER_THRESHOLD)
|
7 |
|
8 |
from logger import get_logger
|
|
|
9 |
|
10 |
logger = get_logger()
|
11 |
|
@@ -43,7 +44,8 @@ def cols_check(new_cols, old_cols):
|
|
43 |
return all([new_col==old_col for new_col, old_col in zip(new_cols, old_cols)])
|
44 |
|
45 |
|
46 |
-
def
|
|
|
47 |
try:
|
48 |
db_updation_required = 1
|
49 |
logger.warning('Entering predict_news_category()')
|
@@ -73,6 +75,8 @@ def predict_news_category(old_news: pd.DataFrame, new_news: pd.DataFrame, interp
|
|
73 |
final_df.drop_duplicates(subset='url', keep='first', inplace=True)
|
74 |
headlines = [*final_df['title']].copy()
|
75 |
label, prob = inference(headlines, interpreter, label_encoder, tokenizer)
|
|
|
|
|
76 |
final_df['category'] = label
|
77 |
final_df['pred_proba'] = prob
|
78 |
final_df.reset_index(drop=True, inplace=True)
|
|
|
6 |
CLASSIFIER_THRESHOLD)
|
7 |
|
8 |
from logger import get_logger
|
9 |
+
from find_similar_news import find_similar_news
|
10 |
|
11 |
logger = get_logger()
|
12 |
|
|
|
44 |
return all([new_col==old_col for new_col, old_col in zip(new_cols, old_cols)])
|
45 |
|
46 |
|
47 |
+
def predict_news_category_similar_news(old_news: pd.DataFrame, new_news: pd.DataFrame, interpreter, label_encoder, tokenizer,
|
48 |
+
collection, vectorizer, sent_model, ce_model):
|
49 |
try:
|
50 |
db_updation_required = 1
|
51 |
logger.warning('Entering predict_news_category()')
|
|
|
75 |
final_df.drop_duplicates(subset='url', keep='first', inplace=True)
|
76 |
headlines = [*final_df['title']].copy()
|
77 |
label, prob = inference(headlines, interpreter, label_encoder, tokenizer)
|
78 |
+
sent_embs = vectorizer.vectorize(headlines)
|
79 |
+
sim_news = [find_similar_news(text, collection, vectorizer, sent_model, ce_model) for text in sent_embs]
|
80 |
final_df['category'] = label
|
81 |
final_df['pred_proba'] = prob
|
82 |
final_df.reset_index(drop=True, inplace=True)
|