File size: 3,899 Bytes
f963240 ccdd011 7861c2b ccdd011 f963240 178e147 ccdd011 7825e9d ccdd011 f963240 ccdd011 7861c2b ccdd011 f963240 a3697ff ccdd011 f963240 ccdd011 7861c2b f963240 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 |
from sentence_transformers import SentenceTransformer
from sentence_transformers.cross_encoder import CrossEncoder
import os
import numpy as np
from datetime import datetime
from pymilvus import connections, utility, Collection, DataType, FieldSchema, CollectionSchema
import logging
# Log format: "<timestamp> <message>".
FORMAT = '%(asctime)s %(message)s'
logging.basicConfig(format=FORMAT)
# Shared module logger. NOTE(review): the rest of the file emits informational
# messages via logger.warning — presumably so they clear the default WARNING
# root level set up by basicConfig; confirm before lowering levels.
logger = logging.getLogger('hf_logger')
def load_sentence_transformer():
    """Load and return the (bi-encoder, cross-encoder) model pair.

    Returns:
        tuple: (SentenceTransformer 'all-mpnet-base-v2',
                CrossEncoder 'cross-encoder/stsb-distilroberta-base').

    Raises:
        Exception: re-raised after logging if either model fails to load.
    """
    logger.warning('Entering load_sentence_transformer')
    try:
        bi_encoder = SentenceTransformer('all-mpnet-base-v2')
        cross_encoder = CrossEncoder('cross-encoder/stsb-distilroberta-base')
    except Exception as e:
        logger.warning(f"load_sentence_transformer error: {e}")
        raise
    logger.warning('Exiting load_sentence_transformer')
    return bi_encoder, cross_encoder
class TextVectorizer:
    """Thin wrapper around a sentence-transformers model that extracts
    sentence embeddings from input text."""

    def vectorize_(self, x, sent_model):
        """Encode `x` with `sent_model` and return normalized embeddings.

        Args:
            x: text (or sequence of texts) to embed.
            sent_model: model exposing `encode(x, normalize_embeddings=...)`.

        Raises:
            Exception: re-raised after logging if encoding fails.
        """
        logger.warning('Entering vectorize_()')
        try:
            embeddings = sent_model.encode(x, normalize_embeddings=True)
        except Exception as e:
            logger.warning(f"vectorize() error: {e}")
            raise
        logger.warning('Exiting vectorize_()')
        return embeddings
def get_milvus_collection():
    """Connect to Milvus and return a handle to the configured collection.

    Reads connection settings from environment variables:
        URI             -- Milvus endpoint
        TOKEN           -- auth token
        COLLECTION_NAME -- name of the collection to open

    Returns:
        Collection: handle to the named Milvus collection.

    Raises:
        Exception: re-raised after logging if connection or lookup fails
        (e.g. when any of the environment variables is unset).
    """
    try:
        logger.warning('Entering get_milvus_collection()')
        uri = os.environ.get("URI")
        token = os.environ.get("TOKEN")
        connections.connect("default", uri=uri, token=token)
        collection_name = os.environ.get("COLLECTION_NAME")
        collection = Collection(name=collection_name)
        # Fix: was `print(f"Loaded collection")` — an f-string with no
        # placeholders, and stdout instead of the module logger every other
        # diagnostic in this file goes through.
        logger.warning('Loaded collection')
        logger.warning('Exiting get_milvus_collection()')
    except Exception as e:
        logger.warning(f"get_milvus_collection() error: {e}")
        raise
    return collection
def find_similar_news(text, search_vec, collection, vectorizer, sent_model, ce_model,
                      top_n: int = 10, max_results: int = 5):
    """Return an HTML snippet linking the news articles most similar to `text`.

    Candidates are fetched from Milvus by inner-product ANN search on
    `search_vec`, re-ranked with the cross encoder against `text`, and the
    query article itself (exact title match) is excluded from the output.

    Args:
        text: query article title, used for cross-encoder pairs and self-filtering.
        search_vec: embedding vector for the ANN search.
        collection: Milvus Collection to search.
        vectorizer: unused; kept for caller compatibility.
        sent_model: unused; kept for caller compatibility.
        ce_model: cross encoder exposing `predict(pairs)`.
        top_n: number of candidates to retrieve from Milvus.
        max_results: maximum number of links in the returned HTML
            (generalizes the previously hard-coded cap of 5; default preserves
            the old behavior).

    Returns:
        str: concatenated anchor tags, best cross-encoder matches first.

    Raises:
        Exception: re-raised after logging if search or scoring fails.
    """
    try:
        logger.warning('Entering find_similar_news')
        search_params = {"metric_type": "IP"}
        logger.warning('Querying Milvus for most similar results')
        results = collection.search([search_vec],
                      anns_field='article_embed', # annotations field specified in the schema definition
                      param=search_params,
                      limit=top_n,
                      guarantee_timestamp=1,
                      output_fields=['article_title', 'article_url'])[0] # which fields to return in output
        logger.warning('retrieved search results from Milvus')
        logger.warning('Computing cross encoder similarity scores')
        texts = [result.entity.get('article_title') for result in results]
        ce_similarity_scores = np.array(ce_model.predict([[text, output_text] for output_text in texts]))
        # Indices of candidates sorted by cross-encoder score, best first.
        similarity_idxs = np.argsort(ce_similarity_scores)[::-1]
        logger.warning('Retrieved cross encoder similarity scores')
        logger.warning('Generating HTML output')
        html_output = ""
        article_count = 0
        # Fix: dropped the unused `enumerate` index from the original loop.
        for i in similarity_idxs:
            title_ = results[i].entity.get('article_title')
            url_ = results[i].entity.get('article_url')
            # Skip the query article itself when it comes back as a hit.
            if title_ != text:
                html_output += f'''<a class="similar-news-item" href="{url_}" target="_blank">{title_}</a><br>
            '''
                article_count += 1
                if article_count == max_results:
                    break
        logger.warning('Successfully generated HTML output')
        logger.warning('Exiting find_similar_news')
    except Exception as e:
        logger.warning(f"find_similar_news() error: {e}")
        raise
    return html_output
|