from sentence_transformers import SentenceTransformer
from sentence_transformers.cross_encoder import CrossEncoder
import os
import numpy as np
from pymilvus import connections, Collection
import streamlit as st
import logging


FORMAT = '%(asctime)s %(message)s'
logging.basicConfig(format=FORMAT)
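# basicConfig defaults to the WARNING level, so the trace messages below use
# logger.warning() rather than logger.info() to keep them visible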
logger = logging.getLogger('hf_logger')


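# st.cache_resource keeps the loaded models in memory across Streamlit reruns,
# so the slow model downloads and initialisation happen only once per process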
@st.cache_resource
def load_sentence_transformer():
    logger.warning('Entering load_sentence_transformer')
    sent_model = SentenceTransformer('all-mpnet-base-v2')
    ce_model = CrossEncoder('cross-encoder/stsb-distilroberta-base')
    logger.warning('Exiting load_sentence_transformer')
    return sent_model, ce_model

    
class TextVectorizer:
    '''
    Wraps a sentence-transformers model to extract sentence embeddings.
    '''

    def __init__(self, sent_model):
        self.sent_model = sent_model

    def vectorize_(self, x):
        logger.warning('Entering vectorize_()')
        # normalize_embeddings=True makes the returned embeddings unit vectors
        sent_embeddings = self.sent_model.encode(x, normalize_embeddings=True)
        logger.warning('Exiting vectorize_()')
        return sent_embeddings
    

@st.cache_resource
def get_milvus_collection():
    logger.warning('Entering get_milvus_collection()')
    uri = os.environ.get("URI")
    token = os.environ.get("TOKEN")
    connections.connect("default", uri=uri, token=token)
    collection_name = os.environ.get("COLLECTION_NAME")
    collection = Collection(name=collection_name)
    print(f"Loaded collection")
    logger.warning('Exiting get_milvus_collection()')
    return collection
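
# The collection itself is created elsewhere (e.g. by the ingestion CRON job).
# A minimal sketch of a compatible schema -- the field names come from the
# search() call below, dim=768 assumes all-mpnet-base-v2, and the VARCHAR
# lengths are illustrative guesses, not the actual values:
#
# from pymilvus import CollectionSchema, DataType, FieldSchema
# fields = [
#     FieldSchema(name='id', dtype=DataType.INT64, is_primary=True, auto_id=True),
#     FieldSchema(name='article_title', dtype=DataType.VARCHAR, max_length=512),
#     FieldSchema(name='article_src', dtype=DataType.VARCHAR, max_length=128),
#     FieldSchema(name='article_url', dtype=DataType.VARCHAR, max_length=1024),
#     FieldSchema(name='article_date', dtype=DataType.VARCHAR, max_length=32),
#     FieldSchema(name='article_embed', dtype=DataType.FLOAT_VECTOR, dim=768),
# ]
# schema = CollectionSchema(fields, description='news headline embeddings')
# collection = Collection(name=collection_name, schema=schema)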

def find_similar_news(text: str, collection, vectorizer, ce_model, top_n: int = 100):
    logger.warning('Entering find_similar_news')
    # inner product on unit vectors is equivalent to cosine similarity
    search_params = {"metric_type": "IP"}
    search_vec = vectorizer.vectorize_(text)
    # logger.warning('Querying Milvus for entity count')
    # n_docs_in_collection = collection.query(expr="", output_fields = ["count(*)"])[0].get('count(*)')
    # logger.warning(f'Retrieved entity count ({n_docs_in_collection}) from Milvus')
    logger.warning('Querying Milvus for most similar results')
    results = collection.search([search_vec],
                                anns_field='article_embed', # name of the vector field (the ANN search field) in the collection schema
                                param=search_params,
                                limit=top_n,
                                guarantee_timestamp=1, 
                                output_fields=['article_title', 'article_src', 'article_url', 'article_date'])[0] # which fields to return in output
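    # each hit carries .distance (here the IP score, i.e. cosine similarity on
    # unit vectors) and .entity, from which the requested output_fields are read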
    
    logger.warning('retrieved search results from Milvus')
    logger.warning('Computing cross encoder similarity scores')
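    # second stage: rerank the ANN candidates with the cross encoder, which scores
    # each (query, candidate) pair jointly and is more accurate than embedding distance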
    texts = [result.entity.get('article_title') for result in results]
    ce_similarity_scores = np.array(ce_model.predict([[text, output_text] for output_text in texts]))
    similarity_idxs = [*np.argsort(ce_similarity_scores)[::-1]]
    logger.warning('Retrieved cross encoder similarity scores')

    logger.warning('Generating HTML output')
    html_output = '<html>'
    # html_output = f'''<html><h5>No. of news articles in database: {n_docs_in_collection}</h5>'''
    for n, i in enumerate(similarity_idxs):
        title_ = results[i].entity.get('article_title')
        date_ = results[i].entity.get('article_date')
        src_ = results[i].entity.get('article_src')
        url_ = results[i].entity.get('article_url')
        cross_encoder_similarity = str(np.round(ce_similarity_scores[i], 4))
        cosine_similarity = str(np.round(results[i].distance, 4))
        html_output += f'''<a style="font-weight: bold; font-size:18px; color: black;" href="{url_}" target="_blank">{n+1}. {title_}</a><br>
        <b>Date:</b> {date_}&nbsp;&nbsp;&nbsp; <b>Source:</b> {src_}<br>
        <b>Cross encoder similarity:</b> {cross_encoder_similarity}&nbsp;&nbsp;&nbsp; <b>Cosine similarity:</b> {cosine_similarity}
        <br><br>
        '''
    html_output += '</html>'
    logger.warning('Successfully generated HTML output')
    logger.warning('Exiting find_similar_news')
    return html_output


sent_model, ce_model = load_sentence_transformer()
vectorizer = TextVectorizer(sent_model)
collection = get_milvus_collection()


try:
    logger.warning('Entering the application')
    st.markdown("<h3>Find Recent Similar News</h3>", unsafe_allow_html=True)
    desc = '''<p style="font-size: 13px;">
    Embeddings of news headlines are stored in a Milvus vector database, used as a feature store.
    The database is updated in real time with new headlines by a CRON job.
    The embedding of the input headline is computed using sentence transformers (all-mpnet-base-v2).
    Similar news headlines are retrieved from the vector database using inner product as the similarity metric and are reranked with a cross encoder.
    The embeddings are normalised to unit vectors so that inner product can be used as cosine similarity, since Milvus doesn't support cosine similarity.
    </p>
    '''
    st.markdown(desc, unsafe_allow_html=True)
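    # For unit vectors u and v, cos(u, v) = (u . v) / (||u||*||v||) = u . v,
    # which is why the normalised embeddings let IP stand in for cosine.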
    news_txt = st.text_area("Paste the headline of a news article", "", height=30)
    top_n = st.slider('Select the number of similar articles to display', 1, 10, 5)

    if st.button("Submit"):
        result = find_similar_news(news_txt, collection, vectorizer, ce_model, top_n)
        st.markdown(result, unsafe_allow_html=True)
    logger.warning('Exiting the application')
except Exception as e:
    st.error(f'An unexpected error occurred:  \n{e}')