# rss_news_vectorizer/insert_into_milvus_db.py
import numpy as np
from get_news import get_news
import os
from pymilvus import connections, utility
from pymilvus import Collection, DataType, FieldSchema, CollectionSchema
from sentence_transformers import SentenceTransformer
import logging
# Timestamped log format applied by basicConfig to the root handler.
FORMAT = '%(asctime)s %(message)s'
logging.basicConfig(format=FORMAT)
# Module-wide logger; this file emits all status messages at WARNING level
# (basicConfig's default level is WARNING, so lower levels would be dropped).
logger = logging.getLogger('hf_logger')
class TextVectorizer:
    """Wrap a SentenceTransformer checkpoint and produce L2-normalized
    (unit-length) sentence embeddings, suitable for inner-product search.
    """

    def __init__(self, checkpoint):
        """Load the sentence-transformer model named/located by `checkpoint`."""
        logger.warning('Loading sentence transformer')
        self.sent_model = SentenceTransformer(checkpoint)
        logger.warning('Successfully loaded sentence transformer')

    @staticmethod
    def get_unit_vector(x):
        """Return `x` scaled to unit L2 norm.

        Args:
            x: 1-D numpy array.

        Raises:
            Exception: if `x` is not 1-D, if `x` has zero magnitude
                (previously this divided by zero and silently returned a
                NaN vector, because `nan < 0.999` is False and skipped the
                sanity check), or if the result fails the unit-norm check.
        """
        if x.ndim != 1:
            raise Exception('get_unit_vector_error: Vector is not 1-D')
        magnitude = np.linalg.norm(x)
        # Guard against division by zero: a zero vector cannot be normalized.
        if magnitude == 0:
            raise Exception('get_unit_vector_error: Cannot normalize a zero vector')
        unit_vector = x / magnitude
        if np.linalg.norm(unit_vector) < 0.999:
            raise Exception('get_unit_vector_error: Returned vector is not a unit vector')
        return unit_vector

    def vectorize_(self, x):
        """Encode sentences `x` and return row-wise unit-normalized embeddings."""
        logger.warning('Entering vectorize()')
        sent_embeddings = self.sent_model.encode(x)
        logger.warning('Entering get_unit_vector()')
        # Row-wise normalization; the gufunc signature keeps the last axis intact.
        get_unit_vector = np.vectorize(self.get_unit_vector, signature='(n)->(n)')
        sent_embeddings = get_unit_vector(sent_embeddings)
        logger.warning('Exiting get_unit_vector()')
        logger.warning('Exiting vectorize()')
        return sent_embeddings
def get_secrets():
    """Read Milvus connection settings from the environment.

    Returns:
        (uri, token, collection_name) tuple; any element is None when the
        corresponding variable (URI, TOKEN, COLLECTION_NAME) is unset.
    """
    log = logging.getLogger('hf_logger')
    log.warning('Entering get_secrets()')
    env = os.environ
    secrets = (env.get("URI"), env.get("TOKEN"), env.get("COLLECTION_NAME"))
    log.warning('Loaded collection secrets')
    log.warning('Exiting get_secrets()')
    return secrets
def create_schema(uri: str, token: str, collection_name: str):
    """Connect to Milvus and return the news collection, creating it if absent.

    Args:
        uri: Milvus endpoint URI.
        token: Milvus auth token.
        collection_name: name of the target collection.

    Returns:
        pymilvus Collection handle for `collection_name`.

    Note:
        The original wrapped the whole body in `try: ... except: raise`,
        a no-op that only obscured tracebacks; exceptions now propagate
        naturally with identical behavior.
    """
    logger.warning('Entering create_schema()')
    connections.connect("default", uri=uri, token=token)
    if not utility.has_collection(collection_name):
        dim = 768  # embedding dimensionality; must match the sentence model's output
        article_url = FieldSchema(name="article_url", dtype=DataType.VARCHAR, max_length=10000,
                                  is_primary=True, description="url of the article")
        article_title = FieldSchema(name="article_title", dtype=DataType.VARCHAR, max_length=5000,
                                    is_primary=False, description="headline of the article")
        article_src = FieldSchema(name="article_src", dtype=DataType.VARCHAR, max_length=1000,
                                  is_primary=False, description="src of the article")
        article_date = FieldSchema(name="article_date", dtype=DataType.VARCHAR, max_length=1000,
                                   is_primary=False, description="date of the article")
        article_age = FieldSchema(name="article_age", dtype=DataType.INT64,
                                  is_primary=False, description="age of the article")
        article_embed = FieldSchema(name="article_embed", dtype=DataType.FLOAT_VECTOR, dim=dim)  # description embeddings
        schema = CollectionSchema(fields=[article_url, article_title, article_src,
                                          article_date, article_age, article_embed],
                                  auto_id=False, description="collection of news articles")
        logger.warning("Creating the collection")
        collection = Collection(name=collection_name, schema=schema)
        logger.warning("Successfully created collection")
    else:
        collection = Collection(name=collection_name)
        logger.warning("Using existing collection")
    logger.warning('Exiting create_schema()')
    return collection
def prepare_docs(vectorizer):
    """Fetch the latest news and assemble column-ordered data for a Milvus upsert.

    Args:
        vectorizer: object exposing vectorize_(list_of_titles) -> embeddings
            (e.g. a TextVectorizer instance).

    Returns:
        List of columns in collection-schema order:
        [urls, titles, sources, dates, ages, title_embeddings].

    Raises:
        Exception: when get_news() returns None (nothing retrieved).

    Note:
        Removed the original's no-op `try: ... except: raise` wrapper and
        fixed the garbled error message ("No latest news in retrieved").
    """
    logger.warning('Entering prepare_docs()')
    logger.warning('Retrieving latest news')
    news_df = get_news()
    if news_df is None:
        raise Exception("ERROR: No latest news retrieved")
    logger.warning('Successfully retrieved latest news')
    # Column order must match the FieldSchema order used at collection creation.
    article_url = news_df['url'].tolist()
    article_title = news_df['title'].tolist()
    article_src = news_df['src'].tolist()
    article_date = news_df['parsed_date'].tolist()
    article_age = news_df['news_age'].tolist()
    # Embeddings are computed from the headline text only.
    article_embed = vectorizer.vectorize_(article_title)
    logger.warning('Exiting prepare_docs()')
    return [article_url, article_title, article_src,
            article_date, article_age, article_embed]
def upsert_db(vectorizer, collection):
    """Upsert the latest news into `collection`; index and load it on first fill.

    Args:
        vectorizer: object exposing vectorize_(titles) -> embeddings.
        collection: pymilvus Collection to upsert into.

    Raises:
        Exception: when Milvus reports a non-zero error count for the upsert.

    Note:
        Removed the original's no-op `try: ... except: raise` wrapper and
        replaced the 0/1 integer flag with a bool; behavior is unchanged.
    """
    logger.warning('Entering upsert_db()')
    # Capture emptiness BEFORE the upsert: a first-time fill also needs an
    # index and a load() before the collection becomes searchable.
    was_empty = bool(collection.is_empty)
    docs_to_upsert = prepare_docs(vectorizer)
    ins_resp = collection.upsert(docs_to_upsert)
    if ins_resp.err_count != 0:
        raise Exception(f'Milvus Insertion not successful. {ins_resp.err_count} errors reported.')
    if was_empty:
        # IP metric pairs with the unit-normalized embeddings produced upstream.
        index_params = {"index_type": "AUTOINDEX", "metric_type": "IP", "params": {}}
        collection.create_index(field_name='article_embed', index_params=index_params)
        collection.load()
    logger.warning('Upsert successful')
    logger.warning('Exiting upsert_db()')