import os

import pymongo

from config import (
    PRODUCTION_URL,
    PRODUCTION_DATABASE,
    PRODUCTION_COLLECTION,
    PREDICTION_URL,
    PREDICTION_DATABASE,
    PREDICTION_COLLECTION,
    COLLECT_PREDICTION_DATA,
)
from logger import get_logger

logger = get_logger()


class DBWrite:
    """Insert processed news documents into MongoDB.

    The target is chosen by ``db_type``:

    * ``"production"`` -- connection settings come from ``PRODUCTION_URL``,
      ``PRODUCTION_DATABASE`` and ``PRODUCTION_COLLECTION``; every insert
      replaces the whole collection (drop, then ``insert_many``).
    * any other value -- settings come from the ``PREDICTION_*`` constants.
      When ``COLLECT_PREDICTION_DATA == 0`` the collection is replaced the
      same way; otherwise each document is upserted, keyed on its ``"url"``.

    Once a connect or insert error has occurred the instance is marked as
    failed, and later calls to :meth:`insert_news_into_db` do nothing.
    """

    def __init__(self, db_type: str = "production"):
        """Select connection settings for ``db_type``.

        Args:
            db_type: ``"production"`` selects the production settings; any
                other value selects the prediction settings.
        """
        self.db_type = db_type
        self.url = PREDICTION_URL
        self.database = PREDICTION_DATABASE
        self.collection = PREDICTION_COLLECTION
        if self.db_type == "production":
            self.url = PRODUCTION_URL
            self.database = PRODUCTION_DATABASE
            self.collection = PRODUCTION_COLLECTION
        self.__client = None
        # Sticky failure flag: set to 1 on any connect/insert error and never
        # reset, so a failed instance skips all later insert attempts.
        self.__error = 0

    def __connect(self):
        """Open a MongoClient for ``self.url`` and verify the server responds.

        Raises:
            Exception: whatever pymongo raises. The instance is marked as
                failed and any partially opened client is closed first.
        """
        try:
            self.__client = pymongo.MongoClient(self.url)
            # MongoClient connects lazily; force a round-trip so an
            # unreachable server fails here instead of during the insert.
            self.__client.list_database_names()
        except Exception as conn_err:
            self.__error = 1
            self.__close_connection()
            logger.critical(f'Error while connecting to DB: {conn_err}')
            raise

    def __insert(self, documents):
        """Write ``documents`` into the configured collection.

        In production mode, or when ``COLLECT_PREDICTION_DATA == 0``, the
        collection is dropped and refilled. Otherwise each document is
        upserted by its ``"url"`` field (so every document must have one).

        Raises:
            Exception: whatever pymongo raises. The instance is marked as
                failed and the connection is closed first.
        """
        if not documents:
            # insert_many() rejects an empty list, so without this guard the
            # replace path would drop the collection and then fail, leaving
            # it empty *and* raising. Leave the collection untouched instead.
            logger.warning(f'No documents to insert: {self.db_type}')
            return
        try:
            coll = self.__client[self.database][self.collection]
            if self.db_type == "production" or COLLECT_PREDICTION_DATA == 0:
                # Full refresh. NOTE(review): drop + insert is not atomic; a
                # failure in insert_many leaves the collection empty.
                coll.drop()
                coll.insert_many(documents=documents)
            else:
                for doc in documents:
                    coll.update_one(
                        {"url": doc["url"]}, {"$set": doc}, upsert=True
                    )
        except Exception as insert_err:
            self.__error = 1
            self.__close_connection()
            logger.critical(f'Error while inserting into DB: {insert_err}')
            raise

    def __close_connection(self):
        """Close the client if one is open and forget it."""
        if self.__client is not None:
            self.__client.close()
            self.__client = None

    def insert_news_into_db(self, documents: list):
        """Connect, write ``documents``, and always close the connection.

        Does nothing (apart from the entry/exit logs) when no URL is
        configured or when a previous call on this instance failed.

        Args:
            documents: MongoDB documents (dicts) to write.

        Raises:
            Exception: any connect/insert error is propagated after logging;
                the "Exiting" log line is then not emitted.
        """
        logger.warning(f'Entering insert_news_into_db() : {self.db_type}')
        if self.url is not None and self.__error == 0:
            try:
                self.__connect()
                self.__insert(documents=documents)
                logger.warning(
                    f"Insertion Successful: {self.db_type}. "
                    f"{len(documents)} documents inserted."
                )
            finally:
                self.__close_connection()
        logger.warning(f'Exiting insert_news_into_db(): {self.db_type}')