File size: 3,899 Bytes
f963240
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ccdd011
 
 
 
 
 
 
7861c2b
ccdd011
f963240
 
 
 
 
 
 
 
178e147
ccdd011
 
 
 
 
 
7825e9d
ccdd011
f963240
 
 
 
ccdd011
 
 
 
 
 
 
 
 
 
7861c2b
 
ccdd011
f963240
 
a3697ff
ccdd011
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f963240
ccdd011
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7861c2b
 
 
f963240
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import html
import logging
import os
from datetime import datetime

import numpy as np
from pymilvus import connections, utility, Collection, DataType, FieldSchema, CollectionSchema
from sentence_transformers import SentenceTransformer
from sentence_transformers.cross_encoder import CrossEncoder


# Log record layout: timestamp followed by the message text.
FORMAT = '%(asctime)s %(message)s'
# Apply the layout via the root logging configuration. No level is passed
# here, so logging's default threshold stays in effect.
logging.basicConfig(format=FORMAT)
# Named logger shared by every function in this module.
logger = logging.getLogger('hf_logger')


def load_sentence_transformer():
    """Instantiate the sentence-embedding model and the cross-encoder model.

    Returns:
        A ``(sent_model, ce_model)`` tuple: a ``SentenceTransformer`` built
        from ``'all-mpnet-base-v2'`` followed by a ``CrossEncoder`` built from
        ``'cross-encoder/stsb-distilroberta-base'``.

    Raises:
        Exception: any error raised while constructing either model is
            logged and then re-raised unchanged.
    """
    try:
        logger.warning('Entering load_sentence_transformer')
        # Construction order matches the returned tuple order.
        models = (
            SentenceTransformer('all-mpnet-base-v2'),
            CrossEncoder('cross-encoder/stsb-distilroberta-base'),
        )
        logger.warning('Exiting load_sentence_transformer')
        return models
    except Exception as err:
        logger.warning(f"load_sentence_transformer error: {err}")
        raise

    
class TextVectorizer:
    '''
    Produces sentence embeddings from text using a caller-supplied
    sentence-transformers model.
    '''

    def vectorize_(self, x, sent_model):
        '''
        Encode ``x`` with ``sent_model`` and return the result.

        The model is invoked as ``sent_model.encode(x, normalize_embeddings=True)``.
        Any exception raised during encoding is logged and re-raised.
        '''
        try:
            logger.warning('Entering vectorize_()')
            embeddings = sent_model.encode(x, normalize_embeddings=True)
            logger.warning('Exiting vectorize_()')
            return embeddings
        except Exception as err:
            logger.warning(f"vectorize() error: {err}")
            raise
    

def get_milvus_collection():
    """Connect to Milvus and return a handle to the configured collection.

    Settings are read from environment variables: ``URI`` and ``TOKEN`` for
    the ``"default"`` connection, and ``COLLECTION_NAME`` for the collection.

    Returns:
        The ``Collection`` object constructed for ``COLLECTION_NAME``.

    Raises:
        Exception: anything raised while connecting or constructing the
            collection is logged and then re-raised unchanged.
    """
    try:
        logger.warning('Entering get_milvus_collection()')
        env = os.environ
        connections.connect("default", uri=env.get("URI"), token=env.get("TOKEN"))
        milvus_collection = Collection(name=env.get("COLLECTION_NAME"))
        print("Loaded collection")
        logger.warning('Exiting get_milvus_collection()')
        return milvus_collection
    except Exception as err:
        logger.warning(f"get_milvus_collection() error: {err}")
        raise

def find_similar_news(text, search_vec, collection, vectorizer, sent_model, ce_model,
                      top_n: int = 10, max_articles: int = 5):
    """Return HTML links to the stored articles most similar to ``text``.

    The collection is searched with ``search_vec`` (inner-product metric on
    the ``article_embed`` field). The hits are then re-ranked by the
    cross-encoder score of ``text`` against each hit's title, highest first,
    and the best ones are rendered as ``<a>`` elements.

    Args:
        text: Query title. Hits whose title equals it exactly are skipped.
        search_vec: Query embedding handed to ``collection.search``.
        collection: Object exposing a Milvus-style ``search`` method.
        vectorizer: Unused; kept so existing callers keep working.
        sent_model: Unused; kept so existing callers keep working.
        ce_model: Object whose ``predict(pairs)`` returns one score per
            ``[text, title]`` pair.
        top_n: Maximum number of hits requested from the search.
        max_articles: Maximum number of links rendered. Defaults to 5, the
            value that was previously hard-coded.

    Returns:
        A string of concatenated anchor elements, or an empty string when the
        search yields no hits. Titles and URLs are HTML-escaped.

    Raises:
        Exception: any error from the search, scoring or rendering steps is
            logged and then re-raised unchanged.
    """
    try:
        logger.warning('Entering find_similar_news')
        search_params = {"metric_type": "IP"}
        logger.warning('Querying Milvus for most similar results')
        results = collection.search([search_vec],
                                    anns_field='article_embed',  # vector field searched
                                    param=search_params,
                                    limit=top_n,
                                    guarantee_timestamp=1,
                                    output_fields=['article_title', 'article_url'])[0]  # fields returned per hit
        logger.warning('retrieved search results from Milvus')

        # Pull out the two fields once so ranking and rendering share them.
        hits = [
            (hit.entity.get('article_title'), hit.entity.get('article_url'))
            for hit in results
        ]

        if hits:
            logger.warning('Computing cross encoder similarity scores')
            scores = np.asarray(ce_model.predict([[text, title] for title, _ in hits]))
            ranking = np.argsort(scores)[::-1]  # best score first
            logger.warning('Retrieved cross encoder similarity scores')
        else:
            # Nothing to score: avoid calling the cross-encoder on an empty batch.
            ranking = []

        logger.warning('Generating HTML output')
        links = []
        for idx in ranking:
            if len(links) >= max_articles:
                break
            title, url = hits[idx]
            if title == text:
                # The query article itself is not a recommendation.
                continue
            # Stored titles/URLs are external data; escape them so they cannot
            # break out of the attribute or inject markup.
            safe_url = html.escape(str(url), quote=True)
            safe_title = html.escape(str(title), quote=True)
            # The trailing newline + 16 spaces keeps the separator between
            # links identical to the markup this function has always emitted.
            links.append(
                f'<a class="similar-news-item" href="{safe_url}" target="_blank">{safe_title}</a><br>\n'
                + ' ' * 16
            )
        html_output = "".join(links)

        logger.warning('Successfully generated HTML output')
        logger.warning('Exiting find_similar_news')
    except Exception as e:
        logger.warning(f"find_similar_news() error: {e}")
        raise

    return html_output