lalithadevi committed on
Commit
f963240
1 Parent(s): 1da191b

Create find_similar_news.py

Browse files
Files changed (1) hide show
  1. find_similar_news.py +74 -0
find_similar_news.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sentence_transformers import SentenceTransformer
2
+ from sentence_transformers.cross_encoder import CrossEncoder
3
+ import os
4
+ import numpy as np
5
+ from datetime import datetime
6
+ from pymilvus import connections, utility, Collection, DataType, FieldSchema, CollectionSchema
7
+ import logging
8
+
9
+
10
# Module-wide logging: timestamped messages on the root handler; every
# helper in this file logs through the shared 'hf_logger' logger.
LOG_FORMAT = '%(asctime)s %(message)s'
logging.basicConfig(format=LOG_FORMAT)
logger = logging.getLogger('hf_logger')
13
+
14
+
15
def load_sentence_transformer(sent_model_name='all-mpnet-base-v2',
                              ce_model_name='cross-encoder/stsb-distilroberta-base'):
    '''
    Load the bi-encoder used for embedding articles and the cross-encoder
    used for re-ranking search results.

    The checkpoint names are parameterized so other models can be swapped
    in; the defaults are the checkpoints this app was built with, so
    existing callers (who pass no arguments) are unaffected.

    Returns:
        (SentenceTransformer, CrossEncoder) tuple.
    '''
    logger.warning('Entering load_sentence_transformer')
    sent_model = SentenceTransformer(sent_model_name)
    ce_model = CrossEncoder(ce_model_name)
    logger.warning('Exiting load_sentence_transformer')
    return sent_model, ce_model
21
+
22
+
23
class TextVectorizer:
    '''
    Wraps a sentence-transformers bi-encoder to extract L2-normalized
    sentence embeddings.
    '''

    def vectorize_(self, x, model=None):
        '''
        Encode `x` (a string or list of strings) into normalized
        embeddings via the bi-encoder's `encode(..., normalize_embeddings=True)`.

        model: optional encoder to use. When omitted, falls back to a
            module-level `sent_model` — NOTE(review): that global is never
            defined in this file, so zero-argument callers must bind
            `sent_model` at module scope (e.g. from
            load_sentence_transformer()) before calling, or pass `model`
            explicitly. TODO confirm how the deployed app wires this up.
        '''
        logger.warning('Entering vectorize_()')
        encoder = model if model is not None else sent_model
        sent_embeddings = encoder.encode(x, normalize_embeddings=True)
        logger.warning('Exiting vectorize_()')
        return sent_embeddings
33
+
34
+
35
def get_milvus_collection():
    '''
    Connect to Milvus using the URI / TOKEN environment variables and
    return a handle to the collection named by COLLECTION_NAME.

    NOTE(review): env vars are read with os.environ.get, so a missing
    variable yields None and the failure surfaces inside pymilvus rather
    than here — verify the deployment always sets all three.

    Returns:
        pymilvus Collection for the configured collection name.
    '''
    logger.warning('Entering get_milvus_collection()')
    uri = os.environ.get("URI")
    token = os.environ.get("TOKEN")
    connections.connect("default", uri=uri, token=token)
    collection_name = os.environ.get("COLLECTION_NAME")
    collection = Collection(name=collection_name)
    # Was: print(f"Loaded collection") — an f-string with no fields,
    # bypassing the configured logger. Log through the logger and say
    # which collection was loaded.
    logger.warning('Loaded collection %s', collection_name)
    logger.warning('Exiting get_milvus_collection()')
    return collection
45
+
46
def find_similar_news(text: str, collection, vectorizer, sent_model, ce_model, top_n: int = 5):
    '''
    Retrieve the `top_n` stored news articles most similar to `text` and
    return them as an HTML fragment of links, best match first.

    Pipeline: embed the query, ANN-search Milvus (inner-product metric)
    for candidates, re-rank the candidates with the cross-encoder, then
    render title/url pairs as anchor tags.

    text: query text.
    collection: Milvus collection holding article embeddings and metadata.
    vectorizer: object exposing vectorize_(text) -> query embedding.
    sent_model: unused here (vectorizer performs the embedding); kept in
        the signature for backward compatibility with existing callers.
    ce_model: cross-encoder whose predict() scores (query, title) pairs.
    top_n: number of candidates fetched from Milvus.

    Returns:
        str: concatenated <a ...>title</a><br> HTML links.
    '''
    logger.warning('Entering find_similar_news')
    search_params = {"metric_type": "IP"}
    search_vec = vectorizer.vectorize_(text)

    logger.warning('Querying Milvus for most similar results')
    results = collection.search([search_vec],
                                anns_field='article_embed',  # vector field specified in the schema definition
                                param=search_params,
                                limit=top_n,
                                guarantee_timestamp=1,
                                output_fields=['article_title', 'article_src', 'article_url', 'article_date'])[0]  # which fields to return in output

    logger.warning('retrieved search results from Milvus')
    logger.warning('Computing cross encoder similarity scores')
    texts = [result.entity.get('article_title') for result in results]
    ce_similarity_scores = np.array(ce_model.predict([[text, output_text] for output_text in texts]))
    # argsort ascending, then reversed: candidate indices ordered from
    # most to least similar per the cross-encoder. (Was wrapped in a
    # needless [*...] list-unpack; the array iterates directly.)
    similarity_idxs = np.argsort(ce_similarity_scores)[::-1]
    logger.warning('Retrieved cross encoder similarity scores')

    logger.warning('Generating HTML output')
    html_output = ""
    # Was `for n, i in enumerate(similarity_idxs)` with `n` never used.
    for i in similarity_idxs:
        title_ = results[i].entity.get('article_title')
        url_ = results[i].entity.get('article_url')
        html_output += f'''<a style="font-weight: bold; font-size:14px; color: black;" href="{url_}" target="_blank">{title_}</a><br>
'''
    logger.warning('Successfully generated HTML output')
    logger.warning('Exiting find_similar_news')
    return html_output