Commit
•
f963240
1
Parent(s):
1da191b
Create find_similar_news.py
Browse files- find_similar_news.py +74 -0
find_similar_news.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from sentence_transformers import SentenceTransformer
|
2 |
+
from sentence_transformers.cross_encoder import CrossEncoder
|
3 |
+
import os
|
4 |
+
import numpy as np
|
5 |
+
from datetime import datetime
|
6 |
+
from pymilvus import connections, utility, Collection, DataType, FieldSchema, CollectionSchema
|
7 |
+
import logging
|
8 |
+
|
9 |
+
|
10 |
+
FORMAT = '%(asctime)s %(message)s'
|
11 |
+
logging.basicConfig(format=FORMAT)
|
12 |
+
logger = logging.getLogger('hf_logger')
|
13 |
+
|
14 |
+
|
15 |
+
def load_sentence_transformer():
|
16 |
+
logger.warning('Entering load_sentence_transformer')
|
17 |
+
sent_model = SentenceTransformer('all-mpnet-base-v2')
|
18 |
+
ce_model = CrossEncoder('cross-encoder/stsb-distilroberta-base')
|
19 |
+
logger.warning('Exiting load_sentence_transformer')
|
20 |
+
return sent_model, ce_model
|
21 |
+
|
22 |
+
|
23 |
+
class TextVectorizer:
|
24 |
+
'''
|
25 |
+
sentence transformers to extract sentence embeddings
|
26 |
+
'''
|
27 |
+
|
28 |
+
def vectorize_(self, x):
|
29 |
+
logger.warning('Entering vectorize_()')
|
30 |
+
sent_embeddings = sent_model.encode(x, normalize_embeddings=True)
|
31 |
+
logger.warning('Exiting vectorize_()')
|
32 |
+
return sent_embeddings
|
33 |
+
|
34 |
+
|
35 |
+
def get_milvus_collection():
|
36 |
+
logger.warning('Entering get_milvus_collection()')
|
37 |
+
uri = os.environ.get("URI")
|
38 |
+
token = os.environ.get("TOKEN")
|
39 |
+
connections.connect("default", uri=uri, token=token)
|
40 |
+
collection_name = os.environ.get("COLLECTION_NAME")
|
41 |
+
collection = Collection(name=collection_name)
|
42 |
+
print(f"Loaded collection")
|
43 |
+
logger.warning('Exiting get_milvus_collection()')
|
44 |
+
return collection
|
45 |
+
|
46 |
+
def find_similar_news(text: str, collection, vectorizer, sent_model, ce_model, top_n: int=5):
|
47 |
+
logger.warning('Entering find_similar_news')
|
48 |
+
search_params = {"metric_type": "IP"}
|
49 |
+
search_vec = vectorizer.vectorize_(text)
|
50 |
+
logger.warning('Querying Milvus for most similar results')
|
51 |
+
results = collection.search([search_vec],
|
52 |
+
anns_field='article_embed', # annotations field specified in the schema definition
|
53 |
+
param=search_params,
|
54 |
+
limit=top_n,
|
55 |
+
guarantee_timestamp=1,
|
56 |
+
output_fields=['article_title', 'article_src', 'article_url', 'article_date'])[0] # which fields to return in output
|
57 |
+
|
58 |
+
logger.warning('retrieved search results from Milvus')
|
59 |
+
logger.warning('Computing cross encoder similarity scores')
|
60 |
+
texts = [result.entity.get('article_title') for result in results]
|
61 |
+
ce_similarity_scores = np.array(ce_model.predict([[text, output_text] for output_text in texts]))
|
62 |
+
similarity_idxs = [*np.argsort(ce_similarity_scores)[::-1]]
|
63 |
+
logger.warning('Retrieved cross encoder similarity scores')
|
64 |
+
|
65 |
+
logger.warning('Generating HTML output')
|
66 |
+
html_output = ""
|
67 |
+
for n, i in enumerate(similarity_idxs):
|
68 |
+
title_ = results[i].entity.get('article_title')
|
69 |
+
url_ = results[i].entity.get('article_url')
|
70 |
+
html_output += f'''<a style="font-weight: bold; font-size:14px; color: black;" href="{url_}" target="_blank">{title_}</a><br>
|
71 |
+
'''
|
72 |
+
logger.warning('Successfully generated HTML output')
|
73 |
+
logger.warning('Exiting find_similar_news')
|
74 |
+
return html_output
|