import numpy as np
import torch
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoTokenizer, AutoModel


def classify_by_topic(articles, central_topics):
    """Assign each article to the central topic it is most similar to."""

    # Embed every article and topic with multilingual DistilBERT and
    # score each article against each topic by cosine similarity.
    def compute_similarity(articles, central_topics):

        model = AutoModel.from_pretrained("distilbert-base-multilingual-cased")
        tokenizer = AutoTokenizer.from_pretrained(
            "distilbert-base-multilingual-cased")

        # Encode a sentence together with its surrounding context. The
        # sentence itself is repeated four times so that it outweighs the
        # context in the pooled embedding.
        def sentence_to_vector(sentence, context):
            sentence = context[0] + context[1] + sentence * 4 + context[2] + context[3]
            tokens = tokenizer.encode_plus(
                sentence, add_special_tokens=True, return_tensors="pt",
                max_length=512, truncation=True)

            with torch.no_grad():
                outputs = model(**tokens)
            hidden_states = outputs.last_hidden_state

            # Mean-pool the token embeddings into one fixed-size vector.
            vector = np.squeeze(torch.mean(hidden_states, dim=1).numpy())
            return vector

        # Collect up to two sentences on each side of position `index`,
        # padding with empty strings at the boundaries.
        def get_context(sentences, index):
            if index == 0:
                prev_sentence = ""
                pprev_sentence = ""
            elif index == 1:
                prev_sentence = sentences[index - 1]
                pprev_sentence = ""
            else:
                prev_sentence = sentences[index - 1]
                pprev_sentence = sentences[index - 2]

            if index == len(sentences) - 1:
                next_sentence = ""
                nnext_sentence = ""
            elif index == len(sentences) - 2:
                next_sentence = sentences[index + 1]
                nnext_sentence = ""
            else:
                next_sentence = sentences[index + 1]
                nnext_sentence = sentences[index + 2]

            return (pprev_sentence, prev_sentence, next_sentence, nnext_sentence)
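
        # For example, with sentences == ["s0", "s1", "s2", "s3"] and
        # index == 1, get_context returns ("", "s0", "s2", "s3").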

        doc_vectors = [sentence_to_vector(sentence, get_context(articles, i))
                       for i, sentence in enumerate(articles)]
        topic_vectors = [sentence_to_vector(sentence, get_context(central_topics, i))
                         for i, sentence in enumerate(central_topics)]

        # Rows correspond to articles, columns to topics.
        cos_sim_matrix = cosine_similarity(doc_vectors, topic_vectors)
        return cos_sim_matrix

    # Pair each article with its highest-scoring topic.
    def group_by_topic(articles, central_topics, similarity_matrix):
        group = []
        for article, similarity in zip(articles, similarity_matrix):
            max_index = int(np.argmax(similarity))
            group.append((article, central_topics[max_index]))
        return group

    similarity_matrix = compute_similarity(articles, central_topics)
    groups = group_by_topic(articles, central_topics, similarity_matrix)
    return groups
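

# A minimal usage sketch: the articles and topics below are invented
# placeholders, and any lists of strings will work, since the embedding
# model is multilingual.
if __name__ == "__main__":
    articles = [
        "The central bank raised interest rates by 25 basis points.",
        "The home team won the league title after a penalty shootout.",
    ]
    central_topics = ["economy", "sports"]

    # classify_by_topic returns a list of (article, topic) pairs.
    for article, topic in classify_by_topic(articles, central_topics):
        print(f"{topic}: {article}")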