# Last commit by lalithadevi: "Update find_similar_news.py" (7825e9d, verified)
import html
import logging
import os
from datetime import datetime

import numpy as np
from pymilvus import connections, utility, Collection, DataType, FieldSchema, CollectionSchema
from sentence_transformers import SentenceTransformer
from sentence_transformers.cross_encoder import CrossEncoder
# Named logger that every function in this module reports through.
logger = logging.getLogger(name='hf_logger')

# Root-handler layout: timestamp followed by the message text.
# basicConfig only installs a handler if the root logger has none yet.
FORMAT = '%(asctime)s %(message)s'
logging.basicConfig(format=FORMAT)
def load_sentence_transformer():
try:
logger.warning('Entering load_sentence_transformer')
sent_model = SentenceTransformer('all-mpnet-base-v2')
ce_model = CrossEncoder('cross-encoder/stsb-distilroberta-base')
logger.warning('Exiting load_sentence_transformer')
except Exception as e:
logger.warning(f"load_sentence_transformer error: {e}")
raise
return sent_model, ce_model
class TextVectorizer:
    '''
    Sentence-transformer wrapper used to extract sentence embeddings.
    '''

    def vectorize_(self, x, sent_model):
        """Encode ``x`` into embeddings with ``sent_model``.

        Args:
            x: Input handed unchanged to ``sent_model.encode`` — presumably a
                string or a list of strings (TODO confirm against callers).
            sent_model: Object exposing ``encode``, e.g. the
                ``SentenceTransformer`` returned by ``load_sentence_transformer``.

        Returns:
            The result of ``sent_model.encode(x, normalize_embeddings=True)``.
            Embeddings are requested unit-normalized — presumably so that the
            inner-product ("IP") Milvus search behaves like cosine similarity.

        Raises:
            Any exception from ``encode`` is logged and re-raised unchanged.
        """
        try:
            logger.warning('Entering vectorize_()')
            sent_embeddings = sent_model.encode(x, normalize_embeddings=True)
            logger.warning('Exiting vectorize_()')
        except Exception as e:
            # Name the actual method so the log line matches the Entering/Exiting lines.
            logger.warning(f"vectorize_() error: {e}")
            raise
        return sent_embeddings
def get_milvus_collection():
    """Connect to Milvus and return the configured collection.

    Settings are read from environment variables:
        URI, TOKEN: passed to ``connections.connect`` under the alias "default".
        COLLECTION_NAME: name of the collection to open.

    Returns:
        The ``pymilvus.Collection`` named by COLLECTION_NAME.

    Raises:
        ValueError: if COLLECTION_NAME is unset or empty (checked before any
            network connection is attempted).
        Any exception from pymilvus is logged and re-raised unchanged.
    """
    try:
        logger.warning('Entering get_milvus_collection()')
        uri = os.environ.get("URI")
        token = os.environ.get("TOKEN")
        collection_name = os.environ.get("COLLECTION_NAME")
        # Fail fast with a clear message instead of letting Collection(name=None)
        # fail with an opaque pymilvus error after connecting.
        if not collection_name:
            raise ValueError("COLLECTION_NAME environment variable is not set")
        connections.connect("default", uri=uri, token=token)
        collection = Collection(name=collection_name)
        logger.warning(f"Loaded collection {collection_name}")
        logger.warning('Exiting get_milvus_collection()')
    except Exception as e:
        logger.warning(f"get_milvus_collection() error: {e}")
        raise
    return collection
def find_similar_news(text, search_vec, collection, vectorizer, sent_model, ce_model, top_n: int = 10,
                      max_results: int = 5):
    """Return an HTML snippet linking to articles similar to ``text``.

    Candidates are fetched from Milvus by inner-product search on the
    ``article_embed`` field, re-ranked by ``ce_model`` against ``text``, and
    rendered as ``<a class="similar-news-item">`` links (each followed by
    ``<br>``) in descending cross-encoder score. A candidate whose title is
    exactly ``text`` is skipped.

    Args:
        text: Query text (compared against each candidate's ``article_title``).
        search_vec: Query vector searched against ``article_embed``.
        collection: Milvus collection exposing ``search``.
        vectorizer: Unused here; kept for call-site compatibility.
        sent_model: Unused here; kept for call-site compatibility.
        ce_model: Cross-encoder whose ``predict`` scores ``[text, title]`` pairs.
        top_n: Number of candidates requested from Milvus.
        max_results: Maximum number of links rendered (default 5).

    Returns:
        The HTML string; ``""`` when Milvus returns no candidates or none
        qualify.

    Raises:
        Any exception is logged and re-raised unchanged.
    """
    try:
        logger.warning('Entering find_similar_news')
        search_params = {"metric_type": "IP"}
        logger.warning('Querying Milvus for most similar results')
        # [0]: hits for the single query vector in the batch.
        results = collection.search(
            [search_vec],
            anns_field='article_embed',  # vector field specified in the schema definition
            param=search_params,
            limit=top_n,
            # NOTE(review): guarantee_timestamp=1 appears to request eventual
            # consistency in Milvus — confirm against the pymilvus docs.
            guarantee_timestamp=1,
            output_fields=['article_title', 'article_url'],  # fields returned with each hit
        )[0]
        logger.warning('retrieved search results from Milvus')

        titles = [hit.entity.get('article_title') for hit in results]
        urls = [hit.entity.get('article_url') for hit in results]
        if not titles:
            # Nothing to re-rank; don't hand the cross-encoder an empty batch.
            logger.warning('No search results returned from Milvus')
            logger.warning('Exiting find_similar_news')
            return ""

        logger.warning('Computing cross encoder similarity scores')
        ce_similarity_scores = np.array(ce_model.predict([[text, title] for title in titles]))
        # Candidate indices from highest to lowest cross-encoder score.
        ranked_idxs = np.argsort(ce_similarity_scores)[::-1]
        logger.warning('Retrieved cross encoder similarity scores')

        logger.warning('Generating HTML output')
        items = []
        for i in ranked_idxs:
            if len(items) >= max_results:
                break
            title_, url_ = titles[i], urls[i]
            if title_ == text:
                # Don't list the query article as similar to itself.
                continue
            # Titles/URLs come from the collection's stored data: escape them so
            # characters like &, <, > and quotes cannot break or inject markup.
            items.append(
                f'<a class="similar-news-item" href="{html.escape(str(url_))}" '
                f'target="_blank">{html.escape(str(title_))}</a><br>\n'
            )
        html_output = "".join(items)
        logger.warning('Successfully generated HTML output')
        logger.warning('Exiting find_similar_news')
    except Exception as e:
        logger.warning(f"find_similar_news() error: {e}")
        raise
    return html_output