lalithadevi
commited on
Commit
•
ccdd011
1
Parent(s):
ffede1e
Update find_similar_news.py
Browse files- find_similar_news.py +63 -48
find_similar_news.py
CHANGED
@@ -13,10 +13,14 @@ logger = logging.getLogger('hf_logger')
|
|
13 |
|
14 |
|
15 |
def load_sentence_transformer():
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
|
|
|
|
|
|
|
|
20 |
return sent_model, ce_model
|
21 |
|
22 |
|
@@ -26,56 +30,67 @@ class TextVectorizer:
|
|
26 |
'''
|
27 |
|
28 |
def vectorize_(self, x, sent_model):
|
29 |
-
|
30 |
-
|
31 |
-
|
|
|
|
|
|
|
|
|
32 |
return sent_embeddings
|
33 |
|
34 |
|
35 |
def get_milvus_collection():
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
|
|
|
|
|
|
|
|
44 |
return collection
|
45 |
|
46 |
def find_similar_news(text, search_vec, collection, vectorizer, sent_model, ce_model, top_n: int=10):
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
break
|
77 |
-
|
78 |
-
|
79 |
-
logger.warning('Successfully generated HTML output')
|
80 |
-
logger.warning('Exiting find_similar_news')
|
81 |
return html_output
|
|
|
13 |
|
14 |
|
15 |
def load_sentence_transformer():
|
16 |
+
try:
|
17 |
+
logger.warning('Entering load_sentence_transformer')
|
18 |
+
sent_model = SentenceTransformer('all-mpnet-base-v2')
|
19 |
+
ce_model = CrossEncoder('cross-encoder/stsb-distilroberta-base')
|
20 |
+
logger.warning('Exiting load_sentence_transformer')
|
21 |
+
except Exception as e:
|
22 |
+
logger.warning(f"load_sentence_transformer error: {e}")
|
23 |
+
|
24 |
return sent_model, ce_model
|
25 |
|
26 |
|
|
|
30 |
'''
|
31 |
|
32 |
def vectorize_(self, x, sent_model):
|
33 |
+
try:
|
34 |
+
logger.warning('Entering vectorize_()')
|
35 |
+
sent_embeddings = sent_model.encode(x, normalize_embeddings=True)
|
36 |
+
logger.warning('Exiting vectorize_()')
|
37 |
+
except Exception as e:
|
38 |
+
logger.warning(f"vectorize() error: {e}")
|
39 |
+
|
40 |
return sent_embeddings
|
41 |
|
42 |
|
43 |
def get_milvus_collection():
|
44 |
+
try:
|
45 |
+
logger.warning('Entering get_milvus_collection()')
|
46 |
+
uri = os.environ.get("URI")
|
47 |
+
token = os.environ.get("TOKEN")
|
48 |
+
connections.connect("default", uri=uri, token=token)
|
49 |
+
collection_name = os.environ.get("COLLECTION_NAME")
|
50 |
+
collection = Collection(name=collection_name)
|
51 |
+
print(f"Loaded collection")
|
52 |
+
logger.warning('Exiting get_milvus_collection()')
|
53 |
+
except Exception as e:
|
54 |
+
logger.warning(f"get_milvus_collection() error: {e}")
|
55 |
+
|
56 |
return collection
|
57 |
|
58 |
def find_similar_news(text, search_vec, collection, vectorizer, sent_model, ce_model, top_n: int=10):
|
59 |
+
try:
|
60 |
+
logger.warning('Entering find_similar_news')
|
61 |
+
search_params = {"metric_type": "IP"}
|
62 |
+
logger.warning('Querying Milvus for most similar results')
|
63 |
+
results = collection.search([search_vec],
|
64 |
+
anns_field='article_embed', # annotations field specified in the schema definition
|
65 |
+
param=search_params,
|
66 |
+
limit=top_n,
|
67 |
+
guarantee_timestamp=1,
|
68 |
+
output_fields=['article_title', 'article_url'])[0] # which fields to return in output
|
69 |
+
|
70 |
+
logger.warning('retrieved search results from Milvus')
|
71 |
+
logger.warning('Computing cross encoder similarity scores')
|
72 |
+
texts = [result.entity.get('article_title') for result in results]
|
73 |
+
ce_similarity_scores = np.array(ce_model.predict([[text, output_text] for output_text in texts]))
|
74 |
+
similarity_idxs = [*np.argsort(ce_similarity_scores)[::-1]]
|
75 |
+
logger.warning('Retrieved cross encoder similarity scores')
|
76 |
|
77 |
+
logger.warning('Generating HTML output')
|
78 |
+
html_output = ""
|
79 |
+
article_count = 0
|
80 |
+
for n, i in enumerate(similarity_idxs):
|
81 |
+
title_ = results[i].entity.get('article_title')
|
82 |
+
url_ = results[i].entity.get('article_url')
|
83 |
+
if title_ != text:
|
84 |
+
html_output += f'''<a class="similar-news-item" href="{url_}" target="_blank">{title_}</a><br>
|
85 |
+
'''
|
86 |
+
article_count += 1
|
87 |
+
|
88 |
+
if article_count == 5 :
|
89 |
+
break
|
90 |
+
|
91 |
+
|
92 |
+
logger.warning('Successfully generated HTML output')
|
93 |
+
logger.warning('Exiting find_similar_news')
|
94 |
+
except Exception as e:
|
95 |
+
logger.warning(f"find_similar_news() error: {e}")
|
|
|
|
|
|
|
|
|
|
|
96 |
return html_output
|