from sentence_transformers import SentenceTransformer
from sentence_transformers.cross_encoder import CrossEncoder
import os
import numpy as np
from datetime import datetime
from pymilvus import connections, utility, Collection, DataType, FieldSchema, CollectionSchema
import logging

FORMAT = '%(asctime)s %(message)s'
logging.basicConfig(format=FORMAT)
logger = logging.getLogger('hf_logger')


def load_sentence_transformer():
    '''Load the bi-encoder used for retrieval and the cross-encoder used for re-ranking.'''
    try:
        logger.warning('Entering load_sentence_transformer()')
        sent_model = SentenceTransformer('all-mpnet-base-v2')
        ce_model = CrossEncoder('cross-encoder/stsb-distilroberta-base')
        logger.warning('Exiting load_sentence_transformer()')
    except Exception as e:
        logger.warning(f"load_sentence_transformer() error: {e}")
        raise
    return sent_model, ce_model


class TextVectorizer:
    '''Use a sentence-transformers model to extract sentence embeddings.'''

    def vectorize_(self, x, sent_model):
        try:
            logger.warning('Entering vectorize_()')
            sent_embeddings = sent_model.encode(x, normalize_embeddings=True)
            logger.warning('Exiting vectorize_()')
        except Exception as e:
            logger.warning(f"vectorize_() error: {e}")
            raise
        return sent_embeddings


def get_milvus_collection():
    try:
        logger.warning('Entering get_milvus_collection()')
        uri = os.environ.get("URI")
        token = os.environ.get("TOKEN")
        connections.connect("default", uri=uri, token=token)
        collection_name = os.environ.get("COLLECTION_NAME")
        collection = Collection(name=collection_name)
        print(f"Loaded collection: {collection_name}")
        logger.warning('Exiting get_milvus_collection()')
    except Exception as e:
        logger.warning(f"get_milvus_collection() error: {e}")
        raise
    return collection


def find_similar_news(text, search_vec, collection, vectorizer, sent_model, ce_model, top_n: int = 10):
    try:
        logger.warning('Entering find_similar_news()')
        search_params = {"metric_type": "IP"}
        logger.warning('Querying Milvus for most similar results')
        results = collection.search([search_vec],
                                    anns_field='article_embed',  # vector field defined in the collection schema
                                    param=search_params,
                                    limit=top_n,
                                    guarantee_timestamp=1,
                                    output_fields=['article_title', 'article_url'])[0]  # fields to return with each hit
        logger.warning('Retrieved search results from Milvus')
        logger.warning('Computing cross-encoder similarity scores')
        # Re-rank the retrieved titles against the query text with the cross-encoder.
        texts = [result.entity.get('article_title') for result in results]
        ce_similarity_scores = np.array(ce_model.predict([[text, output_text] for output_text in texts]))
        similarity_idxs = [*np.argsort(ce_similarity_scores)[::-1]]
        logger.warning('Retrieved cross-encoder similarity scores')
        logger.warning('Generating HTML output')
        html_output = ""
        article_count = 0
        for n, i in enumerate(similarity_idxs):
            title_ = results[i].entity.get('article_title')
            url_ = results[i].entity.get('article_url')
            if title_ != text:
                # Render each result as a link to the source article.
                html_output += f'''<a href="{url_}" target="_blank">{title_}</a><br>
'''
                article_count += 1
            if article_count == 5:
                break
        logger.warning('Successfully generated HTML output')
        logger.warning('Exiting find_similar_news()')
    except Exception as e:
        logger.warning(f"find_similar_news() error: {e}")
        raise
    return html_output
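

# ----------------------------------------------------------------------
# The collection searched above is assumed to already exist in Milvus /
# Zilliz Cloud. The sketch below shows one way such a collection could be
# defined with the imported FieldSchema / CollectionSchema helpers. The
# primary-key field name ('article_id'), the VARCHAR max_length values,
# and the index parameters are assumptions; the vector field name, the
# output fields, and the 768-dim all-mpnet-base-v2 embeddings match what
# find_similar_news() expects.
# ----------------------------------------------------------------------
def create_news_collection(collection_name: str) -> Collection:
    fields = [
        FieldSchema(name='article_id', dtype=DataType.INT64, is_primary=True, auto_id=True),
        FieldSchema(name='article_title', dtype=DataType.VARCHAR, max_length=1024),
        FieldSchema(name='article_url', dtype=DataType.VARCHAR, max_length=2048),
        FieldSchema(name='article_embed', dtype=DataType.FLOAT_VECTOR, dim=768),
    ]
    schema = CollectionSchema(fields, description='News articles with sentence-transformer embeddings')
    collection = Collection(name=collection_name, schema=schema)
    # IP (inner product) matches the metric_type used in find_similar_news();
    # the index type and nlist value are illustrative defaults.
    collection.create_index(field_name='article_embed',
                            index_params={'metric_type': 'IP',
                                          'index_type': 'IVF_FLAT',
                                          'params': {'nlist': 128}})
    return collection


# Hypothetical usage sketch wiring the pieces together for a single query.
# It assumes the URI, TOKEN, and COLLECTION_NAME environment variables used
# by get_milvus_collection() are set; the query string is illustrative.
if __name__ == '__main__':
    sent_model, ce_model = load_sentence_transformer()
    vectorizer = TextVectorizer()
    collection = get_milvus_collection()
    collection.load()  # ensure the collection is loaded into memory before searching

    query = "Example headline to search for"
    search_vec = vectorizer.vectorize_(query, sent_model)
    html = find_similar_news(query, search_vec, collection, vectorizer, sent_model, ce_model, top_n=10)
    print(html)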