lalithadevi commited on
Commit
ccdd011
1 Parent(s): ffede1e

Update find_similar_news.py

Browse files
Files changed (1) hide show
  1. find_similar_news.py +63 -48
find_similar_news.py CHANGED
@@ -13,10 +13,14 @@ logger = logging.getLogger('hf_logger')
13
 
14
 
15
  def load_sentence_transformer():
16
- logger.warning('Entering load_sentence_transformer')
17
- sent_model = SentenceTransformer('all-mpnet-base-v2')
18
- ce_model = CrossEncoder('cross-encoder/stsb-distilroberta-base')
19
- logger.warning('Exiting load_sentence_transformer')
 
 
 
 
20
  return sent_model, ce_model
21
 
22
 
@@ -26,56 +30,67 @@ class TextVectorizer:
26
  '''
27
 
28
  def vectorize_(self, x, sent_model):
29
- logger.warning('Entering vectorize_()')
30
- sent_embeddings = sent_model.encode(x, normalize_embeddings=True)
31
- logger.warning('Exiting vectorize_()')
 
 
 
 
32
  return sent_embeddings
33
 
34
 
35
  def get_milvus_collection():
36
- logger.warning('Entering get_milvus_collection()')
37
- uri = os.environ.get("URI")
38
- token = os.environ.get("TOKEN")
39
- connections.connect("default", uri=uri, token=token)
40
- collection_name = os.environ.get("COLLECTION_NAME")
41
- collection = Collection(name=collection_name)
42
- print(f"Loaded collection")
43
- logger.warning('Exiting get_milvus_collection()')
 
 
 
 
44
  return collection
45
 
46
  def find_similar_news(text, search_vec, collection, vectorizer, sent_model, ce_model, top_n: int=10):
47
- logger.warning('Entering find_similar_news')
48
- search_params = {"metric_type": "IP"}
49
- logger.warning('Querying Milvus for most similar results')
50
- results = collection.search([search_vec],
51
- anns_field='article_embed', # annotations field specified in the schema definition
52
- param=search_params,
53
- limit=top_n,
54
- guarantee_timestamp=1,
55
- output_fields=['article_title', 'article_url'])[0] # which fields to return in output
 
 
 
 
 
 
 
 
56
 
57
- logger.warning('retrieved search results from Milvus')
58
- logger.warning('Computing cross encoder similarity scores')
59
- texts = [result.entity.get('article_title') for result in results]
60
- ce_similarity_scores = np.array(ce_model.predict([[text, output_text] for output_text in texts]))
61
- similarity_idxs = [*np.argsort(ce_similarity_scores)[::-1]]
62
- logger.warning('Retrieved cross encoder similarity scores')
63
-
64
- logger.warning('Generating HTML output')
65
- html_output = ""
66
- article_count = 0
67
- for n, i in enumerate(similarity_idxs):
68
- title_ = results[i].entity.get('article_title')
69
- url_ = results[i].entity.get('article_url')
70
- if title_ != text:
71
- html_output += f'''<a class="similar-news-item" href="{url_}" target="_blank">{title_}</a><br>
72
- '''
73
- article_count += 1
74
-
75
- if article_count == 5 :
76
- break
77
-
78
-
79
- logger.warning('Successfully generated HTML output')
80
- logger.warning('Exiting find_similar_news')
81
  return html_output
 
13
 
14
 
15
  def load_sentence_transformer():
16
+ try:
17
+ logger.warning('Entering load_sentence_transformer')
18
+ sent_model = SentenceTransformer('all-mpnet-base-v2')
19
+ ce_model = CrossEncoder('cross-encoder/stsb-distilroberta-base')
20
+ logger.warning('Exiting load_sentence_transformer')
21
+ except Exception as e:
22
+ logger.warning(f"load_sentence_transformer error: {e}")
23
+
24
  return sent_model, ce_model
25
 
26
 
 
30
  '''
31
 
32
  def vectorize_(self, x, sent_model):
33
+ try:
34
+ logger.warning('Entering vectorize_()')
35
+ sent_embeddings = sent_model.encode(x, normalize_embeddings=True)
36
+ logger.warning('Exiting vectorize_()')
37
+ except Exception as e:
38
+ logger.warning(f"vectorize() error: {e}")
39
+
40
  return sent_embeddings
41
 
42
 
43
  def get_milvus_collection():
44
+ try:
45
+ logger.warning('Entering get_milvus_collection()')
46
+ uri = os.environ.get("URI")
47
+ token = os.environ.get("TOKEN")
48
+ connections.connect("default", uri=uri, token=token)
49
+ collection_name = os.environ.get("COLLECTION_NAME")
50
+ collection = Collection(name=collection_name)
51
+ print(f"Loaded collection")
52
+ logger.warning('Exiting get_milvus_collection()')
53
+ except Exception as e:
54
+ logger.warning(f"get_milvus_collection() error: {e}")
55
+
56
  return collection
57
 
58
  def find_similar_news(text, search_vec, collection, vectorizer, sent_model, ce_model, top_n: int=10):
59
+ try:
60
+ logger.warning('Entering find_similar_news')
61
+ search_params = {"metric_type": "IP"}
62
+ logger.warning('Querying Milvus for most similar results')
63
+ results = collection.search([search_vec],
64
+ anns_field='article_embed', # annotations field specified in the schema definition
65
+ param=search_params,
66
+ limit=top_n,
67
+ guarantee_timestamp=1,
68
+ output_fields=['article_title', 'article_url'])[0] # which fields to return in output
69
+
70
+ logger.warning('retrieved search results from Milvus')
71
+ logger.warning('Computing cross encoder similarity scores')
72
+ texts = [result.entity.get('article_title') for result in results]
73
+ ce_similarity_scores = np.array(ce_model.predict([[text, output_text] for output_text in texts]))
74
+ similarity_idxs = [*np.argsort(ce_similarity_scores)[::-1]]
75
+ logger.warning('Retrieved cross encoder similarity scores')
76
 
77
+ logger.warning('Generating HTML output')
78
+ html_output = ""
79
+ article_count = 0
80
+ for n, i in enumerate(similarity_idxs):
81
+ title_ = results[i].entity.get('article_title')
82
+ url_ = results[i].entity.get('article_url')
83
+ if title_ != text:
84
+ html_output += f'''<a class="similar-news-item" href="{url_}" target="_blank">{title_}</a><br>
85
+ '''
86
+ article_count += 1
87
+
88
+ if article_count == 5 :
89
+ break
90
+
91
+
92
+ logger.warning('Successfully generated HTML output')
93
+ logger.warning('Exiting find_similar_news')
94
+ except Exception as e:
95
+ logger.warning(f"find_similar_news() error: {e}")
 
 
 
 
 
96
  return html_output