|
import numpy as np |
|
from get_news import get_news |
|
import os |
|
from pymilvus import connections, utility |
|
from pymilvus import Collection, DataType, FieldSchema, CollectionSchema |
|
from sentence_transformers import SentenceTransformer |
|
import logging |
|
|
|
|
|
# Layout of every log record: timestamp followed by the message text.
FORMAT = "%(asctime)s %(message)s"

# Configure the root logger at import time so records are rendered with FORMAT.
logging.basicConfig(format=FORMAT)

# Named logger shared by the class and functions defined in this module.
logger = logging.getLogger("hf_logger")
|
|
|
|
|
class TextVectorizer:
    """Produce unit-length sentence embeddings with a sentence-transformers model."""

    def __init__(self, checkpoint):
        """Load the sentence-transformer model.

        Args:
            checkpoint: Value handed unchanged to ``SentenceTransformer``.
        """
        logger.warning('Loading sentence transformer')
        self.sent_model = SentenceTransformer(checkpoint)
        logger.warning('Successfully loaded sentence transformer')

    @staticmethod
    def get_unit_vector(x):
        """Scale a 1-D vector to unit Euclidean length.

        Args:
            x: 1-D NumPy array.

        Returns:
            ``x`` divided by its Euclidean norm.

        Raises:
            ValueError: If ``x`` is not 1-D, if its norm is zero or not
                finite, or if the scaled vector's norm is below 0.999.
        """
        if x.ndim != 1:
            raise ValueError('get_unit_vector_error: Vector is not 1-D')

        magnitude = np.sqrt(np.sum(x ** 2))
        # A zero (or NaN/inf) norm would yield a NaN/zero vector below; reject
        # it here instead of letting it through silently.
        if not np.isfinite(magnitude) or magnitude == 0:
            raise ValueError(
                'get_unit_vector_error: Vector has zero or non-finite magnitude')

        unit_vector = x / magnitude
        # Written as "not >=" so that a NaN norm also fails the check
        # (every ordered comparison with NaN is False).
        if not np.sqrt(np.sum(unit_vector ** 2)) >= 0.999:
            raise ValueError(
                'get_unit_vector_error: Returned vector is not a unit vector')
        return unit_vector

    def vectorize_(self, x):
        """Encode ``x`` and normalise each embedding to unit length.

        Args:
            x: Input passed unchanged to the model's ``encode`` method
                (presumably a sentence or a list of sentences -- confirm
                against the sentence-transformers API).

        Returns:
            The encoded embeddings with ``get_unit_vector`` applied along
            the last axis.
        """
        logger.warning('Entering vectorize()')
        sent_embeddings = self.sent_model.encode(x)

        logger.warning('Entering get_unit_vector()')
        # The '(n)->(n)' signature makes np.vectorize feed one 1-D vector at a
        # time to get_unit_vector, looping over any leading dimensions.
        normalize_rows = np.vectorize(self.get_unit_vector, signature='(n)->(n)')
        sent_embeddings = normalize_rows(sent_embeddings)
        logger.warning('Exiting get_unit_vector()')

        logger.warning('Exiting vectorize()')
        return sent_embeddings
|
|
|
|
|
def get_secrets():
    """Read the connection settings from environment variables.

    Returns:
        tuple: ``(uri, token, collection_name)`` taken from the ``URI``,
        ``TOKEN`` and ``COLLECTION_NAME`` environment variables. Each item
        is ``None`` when its variable is not set.
    """
    logger.warning('Entering get_secrets()')
    # Look the three variables up in the same order they are returned.
    uri, token, collection_name = (
        os.environ.get(key) for key in ("URI", "TOKEN", "COLLECTION_NAME")
    )
    logger.warning('Loaded collection secrets')
    logger.warning('Exiting get_secrets()')
    return uri, token, collection_name
|
|
|
|
|
def create_schema(uri: str, token: str, collection_name: str, dim: int = 768):
    """Connect to Milvus and return the news collection, creating it if absent.

    Args:
        uri: Milvus endpoint passed to ``connections.connect``.
        token: Credential passed to ``connections.connect``.
        collection_name: Name of the collection to open or create.
        dim: Dimension of the ``article_embed`` vector field. Only used when
            the collection is created; an existing collection keeps its own
            schema. Defaults to 768.

    Returns:
        The ``Collection`` handle for ``collection_name``.

    Raises:
        Any exception raised by pymilvus propagates unchanged.
    """
    logger.warning('Entering create_schema()')
    connections.connect("default", uri=uri, token=token)

    if utility.has_collection(collection_name):
        collection = Collection(name=collection_name)
        logger.warning("Using existing collection")
    else:
        # The article URL is the primary key, so upserting the same URL again
        # replaces the stored row.
        article_url = FieldSchema(
            name="article_url", dtype=DataType.VARCHAR, max_length=10000,
            is_primary=True, description="url of the article")
        article_title = FieldSchema(
            name="article_title", dtype=DataType.VARCHAR, max_length=5000,
            is_primary=False, description="headline of the article")
        article_src = FieldSchema(
            name="article_src", dtype=DataType.VARCHAR, max_length=1000,
            is_primary=False, description="src of the article")
        article_date = FieldSchema(
            name="article_date", dtype=DataType.VARCHAR, max_length=1000,
            is_primary=False, description="date of the article")
        article_age = FieldSchema(
            name="article_age", dtype=DataType.INT64,
            is_primary=False, description="age of the article")
        article_embed = FieldSchema(
            name="article_embed", dtype=DataType.FLOAT_VECTOR, dim=dim)

        schema = CollectionSchema(
            fields=[article_url, article_title, article_src,
                    article_date, article_age, article_embed],
            auto_id=False, description="collection of news articles")

        logger.warning("Creating the collection")
        collection = Collection(name=collection_name, schema=schema)
        logger.warning("Successfully created collection")

    logger.warning('Exiting create_schema()')
    return collection
|
|
|
|
|
def prepare_docs(vectorizer):
    """Fetch the latest news and build the column-wise payload for upsert.

    Args:
        vectorizer: Object exposing ``vectorize_(titles)``; it is called with
            the list of article titles to produce their embeddings.

    Returns:
        list: ``[urls, titles, sources, dates, ages, embeddings]``. The first
        five are lists taken from the ``url``, ``title``, ``src``,
        ``parsed_date`` and ``news_age`` columns returned by ``get_news()``;
        the last is whatever ``vectorizer.vectorize_`` returns.

    Raises:
        Exception: If ``get_news()`` returns ``None`` or an empty result.
    """
    logger.warning('Entering prepare_docs()')
    logger.warning('Retrieving latest news')
    news_df = get_news()
    # An empty result has nothing to embed or upsert; fail here with the same
    # clear message instead of failing later inside the vectorizer.
    if news_df is None or len(news_df) == 0:
        raise Exception("ERROR: No latest news in retrieved")
    logger.warning('Successfully retrieved latest news')

    article_url = news_df['url'].tolist()
    article_title = news_df['title'].tolist()
    article_src = news_df['src'].tolist()
    article_date = news_df['parsed_date'].tolist()
    article_age = news_df['news_age'].tolist()
    # Embeddings are computed from the headlines only.
    article_embed = vectorizer.vectorize_(article_title)

    logger.warning('Exiting prepare_docs()')
    return [article_url, article_title, article_src,
            article_date, article_age, article_embed]
|
|
|
|
|
def upsert_db(vectorizer, collection):
    """Upsert the latest news into ``collection``; index it on first fill.

    Args:
        vectorizer: Passed to ``prepare_docs`` to embed the article titles.
        collection: Collection object providing ``is_empty``, ``upsert``,
            ``create_index`` and ``load``.

    Raises:
        Exception: If the upsert response reports a non-zero ``err_count``.
        Any exception from ``prepare_docs`` or the collection calls
        propagates unchanged.
    """
    logger.warning('Entering upsert_db()')
    # Record emptiness BEFORE writing: the index is only built (and the
    # collection loaded) the first time data lands in it.
    was_empty = bool(collection.is_empty)

    docs_to_upsert = prepare_docs(vectorizer)
    ins_resp = collection.upsert(docs_to_upsert)
    if ins_resp.err_count != 0:
        raise Exception(f'Milvus Insertion not successful. {ins_resp.err_count} errors reported.')

    if was_empty:
        # Inner-product metric; with unit-normalised embeddings (as produced
        # by TextVectorizer) this ranks results like cosine similarity.
        index_params = {"index_type": "AUTOINDEX", "metric_type": "IP", "params": {}}
        collection.create_index(field_name='article_embed', index_params=index_params)
        collection.load()

    logger.warning('Upsert successful')
    logger.warning('Exiting upsert_db()')