File size: 3,150 Bytes
f963240
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
890a875
f963240
 
890a875
f963240
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from sentence_transformers import SentenceTransformer
from sentence_transformers.cross_encoder import CrossEncoder
import os
import numpy as np
from datetime import datetime
from pymilvus import connections, utility, Collection, DataType, FieldSchema, CollectionSchema
import logging


# Log line layout: timestamp followed by the message text.
FORMAT = '%(asctime)s %(message)s'
# Root-level config; WARNING is the default level, which is why the rest of
# this module emits trace messages via logger.warning() rather than .info().
logging.basicConfig(format=FORMAT)
logger = logging.getLogger('hf_logger')


def load_sentence_transformer():
    """Load and return the retrieval models.

    Returns:
        tuple: (bi-encoder SentenceTransformer used for embedding text,
                CrossEncoder used for pairwise re-ranking of results).
    """
    logger.warning('Entering load_sentence_transformer')
    # Bi-encoder produces the dense vectors stored/searched in Milvus.
    bi_encoder = SentenceTransformer('all-mpnet-base-v2')
    # Cross-encoder scores (query, candidate) pairs for re-ranking.
    cross_encoder = CrossEncoder('cross-encoder/stsb-distilroberta-base')
    logger.warning('Exiting load_sentence_transformer')
    return bi_encoder, cross_encoder

    
class TextVectorizer:
    '''
    sentence transformers to extract sentence embeddings
    '''

    def vectorize_(self, x, model=None):
        """Encode text into normalized sentence embeddings.

        Args:
            x: a string or list of strings to embed.
            model: optional encoder with an ``encode`` method. Defaults to the
                module-level ``sent_model`` for backward compatibility.

        Returns:
            numpy array of L2-normalized embeddings.

        Raises:
            NameError: if ``model`` is not given and the module-level
                ``sent_model`` global has not been assigned by the caller.
        """
        logger.warning('Entering vectorize_()')
        # BUG FIX (partial): the original read the global `sent_model`
        # unconditionally, but nothing in this module ever assigns it —
        # load_sentence_transformer() only *returns* the model. Accepting an
        # explicit model keeps callers working while removing the hidden
        # dependency; the global fallback preserves the old behavior.
        encoder = model if model is not None else sent_model
        sent_embeddings = encoder.encode(x, normalize_embeddings=True)
        logger.warning('Exiting vectorize_()')
        return sent_embeddings
    

def get_milvus_collection():
    """Connect to Milvus and return a handle to the configured collection.

    Reads connection settings from the environment:
        URI             -- Milvus/Zilliz endpoint
        TOKEN           -- auth token
        COLLECTION_NAME -- name of the collection to open

    Returns:
        pymilvus.Collection bound to the "default" connection alias.
    """
    logger.warning('Entering get_milvus_collection()')
    # NOTE(review): os.environ.get returns None when a variable is unset;
    # connect()/Collection() will then fail downstream — confirm the Space
    # always provides these secrets.
    uri = os.environ.get("URI")
    token = os.environ.get("TOKEN")
    connections.connect("default", uri=uri, token=token)
    collection_name = os.environ.get("COLLECTION_NAME")
    collection = Collection(name=collection_name)
    # FIX: was `print(f"Loaded collection")` — an f-string with no
    # placeholders, bypassing the module logger everything else uses.
    logger.warning('Loaded collection %s', collection_name)
    logger.warning('Exiting get_milvus_collection()')
    return collection

def find_similar_news(search_vec, collection, vectorizer, sent_model, ce_model, top_n: int=5, query_text=None):
    """Search Milvus for articles similar to `search_vec` and render HTML links.

    Args:
        search_vec: query embedding (1-D vector matching the `article_embed` field).
        collection: pymilvus Collection to search.
        vectorizer: kept for interface compatibility (unused here).
        sent_model: kept for interface compatibility (unused here).
        ce_model: CrossEncoder used to re-rank titles against `query_text`.
        top_n: number of results to retrieve from Milvus.
        query_text: raw query string for cross-encoder re-ranking. When None,
            results keep Milvus's inner-product ranking order.

    Returns:
        str: HTML anchor tags (one per result) separated by <br> breaks.
    """
    logger.warning('Entering find_similar_news')
    search_params = {"metric_type": "IP"}
    logger.warning('Querying Milvus for most similar results')
    results = collection.search([search_vec],
                                anns_field='article_embed', # annotations field specified in the schema definition
                                param=search_params,
                                limit=top_n,
                                guarantee_timestamp=1,
                                output_fields=['article_title', 'article_src', 'article_url', 'article_date'])[0] # which fields to return in output

    logger.warning('retrieved search results from Milvus')
    texts = [result.entity.get('article_title') for result in results]

    # BUG FIX: the original built cross-encoder pairs as [text, output_text],
    # but `text` was never defined in this scope (the old parameter was
    # commented out above), so every call raised NameError. The query string
    # is now an explicit, backward-compatible keyword argument; without it we
    # fall back to Milvus's own ranking instead of crashing.
    if query_text is not None:
        logger.warning('Computing cross encoder similarity scores')
        ce_similarity_scores = np.array(ce_model.predict([[query_text, output_text] for output_text in texts]))
        # Highest cross-encoder score first.
        similarity_idxs = [*np.argsort(ce_similarity_scores)[::-1]]
        logger.warning('Retrieved cross encoder similarity scores')
    else:
        similarity_idxs = list(range(len(results)))

    logger.warning('Generating HTML output')
    html_output = ""
    for i in similarity_idxs:
        title_ = results[i].entity.get('article_title')
        url_ = results[i].entity.get('article_url')
        html_output += f'''<a style="font-weight: bold; font-size:14px; color: black;" href="{url_}" target="_blank">{title_}</a><br>
        '''
    logger.warning('Successfully generated HTML output')
    logger.warning('Exiting find_similar_news')
    return html_output