|
import html
import logging
import os
from datetime import datetime

import numpy as np
from pymilvus import connections, utility, Collection, DataType, FieldSchema, CollectionSchema
from sentence_transformers import SentenceTransformer
from sentence_transformers.cross_encoder import CrossEncoder
|
|
|
|
|
# Every emitted record is prefixed with its creation time.
FORMAT = '%(asctime)s %(message)s'
logging.basicConfig(format=FORMAT)

# Module-wide named logger. No level is set on it, so records go through the
# root logger's default WARNING threshold -- presumably why the trace messages
# in this module are emitted with logger.warning (confirm before changing).
logger = logging.getLogger('hf_logger')
|
|
|
|
|
def load_sentence_transformer(sent_model_name='all-mpnet-base-v2',
                              ce_model_name='cross-encoder/stsb-distilroberta-base'):
    """Load the bi-encoder and cross-encoder models used for similarity search.

    Both model ids default to the values that were previously hard-coded, so
    existing zero-argument callers behave exactly as before.

    Args:
        sent_model_name: SentenceTransformer model id used to embed text.
        ce_model_name: CrossEncoder model id used to score (query, candidate)
            text pairs.

    Returns:
        tuple: ``(sent_model, ce_model)`` -- the loaded SentenceTransformer and
        CrossEncoder instances.
    """
    logger.warning('Entering load_sentence_transformer')

    sent_model = SentenceTransformer(sent_model_name)
    ce_model = CrossEncoder(ce_model_name)

    logger.warning('Exiting load_sentence_transformer')
    return sent_model, ce_model
|
|
|
|
|
class TextVectorizer:
    """Thin wrapper that turns text into sentence embeddings via a supplied model."""

    def vectorize_(self, x, sent_model):
        """Embed ``x`` with ``sent_model``.

        Args:
            x: Text input forwarded unchanged to ``sent_model.encode``.
            sent_model: Model exposing ``encode(x, normalize_embeddings=...)``,
                e.g. a SentenceTransformer.

        Returns:
            Whatever ``sent_model.encode`` returns, requested with
            ``normalize_embeddings=True``.
        """
        logger.warning('Entering vectorize_()')
        embeddings = sent_model.encode(x, normalize_embeddings=True)
        logger.warning('Exiting vectorize_()')
        return embeddings
|
|
|
|
|
def get_milvus_collection():
    """Connect to Milvus and return a handle to the configured collection.

    Configuration is read from environment variables:
      * ``URI``   -- endpoint passed to ``connections.connect``.
      * ``TOKEN`` -- auth token passed to ``connections.connect``.
      * ``COLLECTION_NAME`` -- name of the collection to open (required).

    Returns:
        pymilvus ``Collection`` named ``COLLECTION_NAME`` on the ``"default"``
        connection alias.

    Raises:
        ValueError: if ``COLLECTION_NAME`` is unset or empty.
    """
    logger.warning('Entering get_milvus_collection()')

    collection_name = os.environ.get("COLLECTION_NAME")
    if not collection_name:
        # Fail with an actionable message instead of passing None/"" on to pymilvus.
        raise ValueError("COLLECTION_NAME environment variable is not set")

    uri = os.environ.get("URI")
    token = os.environ.get("TOKEN")
    connections.connect("default", uri=uri, token=token)

    collection = Collection(name=collection_name)
    print("Loaded collection")

    logger.warning('Exiting get_milvus_collection()')
    return collection
|
|
|
def find_similar_news(text, search_vec, collection, vectorizer, sent_model, ce_model, top_n: int = 10,
                      max_articles: int = 5):
    """Return an HTML snippet of links to the news articles most similar to ``text``.

    Candidates are fetched from Milvus by inner-product ("IP") search of
    ``search_vec`` against the ``article_embed`` field. They are re-ranked by
    the cross-encoder score of each (``text``, candidate title) pair and
    rendered best-first. A candidate whose title equals ``text`` exactly is
    skipped, so the query article is not recommended to itself.

    Args:
        text: Query text; compared against candidate titles by ``ce_model``.
        search_vec: Query embedding sent to Milvus.
        collection: pymilvus ``Collection`` to search.
        vectorizer: Unused; kept for call-site compatibility.
        sent_model: Unused; kept for call-site compatibility.
        ce_model: Cross-encoder whose ``predict`` scores a list of text pairs.
        top_n: Number of candidates requested from Milvus.
        max_articles: Maximum number of links to emit (previously fixed at 5).

    Returns:
        str: Concatenated ``<a class="similar-news-item" ...>...</a><br>``
        lines with HTML-escaped titles and URLs, or ``""`` when Milvus returns
        no candidates.
    """
    logger.warning('Entering find_similar_news')

    search_params = {"metric_type": "IP"}

    logger.warning('Querying Milvus for most similar results')
    results = collection.search(
        [search_vec],
        anns_field='article_embed',
        param=search_params,
        limit=top_n,
        # NOTE(review): a timestamp of 1 presumably relaxes consistency so the
        # search does not wait on recent inserts -- confirm against pymilvus docs.
        guarantee_timestamp=1,
        output_fields=['article_title', 'article_url'],
    )[0]  # exactly one query vector was sent; take its hit list
    logger.warning('retrieved search results from Milvus')

    texts = [hit.entity.get('article_title') for hit in results]
    if not texts:
        # CrossEncoder.predict inspects its first input, so an empty candidate
        # list would raise there; with no hits there is nothing to render anyway.
        logger.warning('Exiting find_similar_news')
        return ""

    logger.warning('Computing cross encoder similarity scores')
    ce_similarity_scores = np.array(ce_model.predict([[text, title] for title in texts]))
    # Candidate indices ordered from highest to lowest cross-encoder score.
    similarity_idxs = np.argsort(ce_similarity_scores)[::-1].tolist()
    logger.warning('Retrieved cross encoder similarity scores')

    logger.warning('Generating HTML output')
    links = []
    for i in similarity_idxs:
        if len(links) >= max_articles:
            break
        title_ = results[i].entity.get('article_title')
        if title_ == text:
            continue
        url_ = results[i].entity.get('article_url')
        # Stored titles/URLs may contain &, <, > or quotes; escape them so they
        # cannot break (or inject into) the generated markup.
        links.append(
            f'<a class="similar-news-item" href="{html.escape(str(url_))}" '
            f'target="_blank">{html.escape(str(title_))}</a><br>\n'
        )
    html_output = "".join(links)

    logger.warning('Successfully generated HTML output')
    logger.warning('Exiting find_similar_news')
    return html_output
|
|