import os
import sys
import s3fs
import json
import hashlib
import aiohttp
import asyncio
import jmespath
import psycopg2
import traceback

from typing import Any, Callable, ClassVar, Dict, List, Literal, Optional, Union
from loguru import logger
from fastapi import WebSocket
from pydantic import BaseModel
from neo4j import AsyncDriver, AsyncGraphDatabase, Driver, GraphDatabase
from kafka import KafkaProducer
from pymongo import MongoClient
from urllib.parse import urljoin
from configparser import ConfigParser
from aiokafka import AIOKafkaProducer
from tqdm.asyncio import tqdm_asyncio
from pymongo.collection import Collection
from psycopg2.extras import execute_values
from elasticsearch import Elasticsearch, AsyncElasticsearch

path_this = os.path.dirname(os.path.abspath(__file__))
# NOTE: dirname() drops the trailing ".." component, so path_project resolves
# to path_this itself and path_root to "<path_this>/.." (its parent directory).
path_project = os.path.dirname(os.path.join(path_this, ".."))
path_root = os.path.dirname(os.path.join(path_this, "../.."))
sys.path.append(path_root)
sys.path.append(path_project)


class BaseChatCampaign(BaseModel):
    """
    Base class for all network/campaign tools.

    Provides shared infrastructure:
    - Lazy singleton connections (Elasticsearch, PostgreSQL, Neo4j/Memgraph)
    - S3-compatible object storage (load / upsert)
    - Async HTTP client wrapper (aiohttp)
    - Kafka producer helpers (sync + async)
    - Parallel task execution with semaphore-bounded concurrency
    - PostgreSQL batch insert via execute_values
    - MD5-based deterministic ID hashing
    """

    config: ClassVar[ConfigParser] = ConfigParser()
    config.read(os.path.join(path_root, "config.conf"))

    websocket: Optional[WebSocket] = None

    # Shared across all instances and subclasses, keyed by connection name.
    _connections: ClassVar[Dict[str, Any]] = {}
    _closers: ClassVar[Dict[str, Callable]] = {}

    class Config:
        arbitrary_types_allowed = True

    # ------------------------------------------------------------------
    # Connection management
    # ------------------------------------------------------------------
    @classmethod
    async def _get_connection(
        cls,
        name: str,
        init_func: Callable,
        close_func: Optional[Callable] = None,
    ) -> Any:
        """Lazy-init singleton connection identified by name."""
        if name not in cls._connections:
            logger.info(f"[INIT] Creating connection: {name}")
            cls._connections[name] = await init_func()
            if close_func:
                cls._closers[name] = close_func
        return cls._connections[name]

    @classmethod
    async def close_all_connections(cls) -> None:
        """Gracefully close and clear all registered connections."""
        for name, conn in cls._connections.items():
            close_func = cls._closers.get(name)
            if close_func:
                try:
                    await close_func(conn)
                    logger.info(f"[CLOSE] {name}")
                except Exception as exc:
                    logger.warning(f"[CLOSE] Failed to close {name}: {exc}")
        cls._connections.clear()
        # Drop the closers too, so stale callbacks do not outlive their connections.
        cls._closers.clear()

    # ------------------------------------------------------------------
    # Utilities
    # ------------------------------------------------------------------
    async def hash_id(self, ids_str: str) -> str:
        """Return a deterministic MD5 hex digest for the given string."""
        return hashlib.md5(ids_str.encode()).hexdigest()

    async def is_empty(self, v: Any) -> bool:
        """
        Return True for None, zero, blank strings, empty containers, and
        tuples whose items are all empty.

        Note: `v == 0` also matches False and 0.0.
        """
        if v is None or v == 0:
            return True
        if isinstance(v, str):
            return v.strip() == ""
        if isinstance(v, (list, dict, tuple)) and not v:
            return True
        if isinstance(v, tuple):
            return all([await self.is_empty(x) for x in v])
        return False

    # ------------------------------------------------------------------
    # HTTP
    # ------------------------------------------------------------------
    async def _requests(
        self,
        uri: str,
        params: dict,
        service_name: str,
        method: Literal["post", "get", "put"] = "post",
        timeout: int = 10,
        **kwargs,
    ) -> dict:
        """
        Generic async HTTP client (aiohttp) with GET / POST / PUT support.

        Args:
            uri: Target URL.
            params: Payload (JSON body for POST/PUT; query params for GET).
            service_name: Label used in log messages.
            method: HTTP verb. Defaults to "post".
            timeout: Total request timeout in seconds. Defaults to 10.

        Returns:
            Parsed JSON response as dict, or {} for 204 No Content.

        Raises:
            ConnectionError: On HTTP status codes of 300 and above.
            aiohttp.ClientError: On network-level errors.
        """
        logger.info(f"HTTP {method.upper()} → {service_name.upper()}")
        _timeout = aiohttp.ClientTimeout(total=timeout)
        # ssl=False disables TLS certificate verification.
        connector = aiohttp.TCPConnector(ssl=False)

        async with aiohttp.ClientSession(timeout=_timeout, connector=connector) as session:
            if method == "post":
                payload = {"url": uri, "json": params, **kwargs}
                func = session.post
            elif method == "put":
                payload = {"url": uri, "json": params, **kwargs}
                func = session.put
            else:
                payload = {"url": uri, "params": params, **kwargs}
                func = session.get

            async with func(**payload) as response:
                if response.status < 300:
                    if response.status == 204 or response.content_length == 0:
                        return {}
                    return await response.json()

                # Error bodies are not always JSON; fall back to the raw text.
                try:
                    detail = await response.json()
                except (aiohttp.ContentTypeError, ValueError):
                    detail = await response.text()
                if isinstance(detail, dict):
                    detail = detail.get("detail", detail)
                raise ConnectionError(
                    f"{service_name.upper()} returned HTTP {response.status}. "
                    f"Detail: {detail}"
                )

    # ------------------------------------------------------------------
    # S3-compatible object storage
    # ------------------------------------------------------------------
    async def s3_database(
        self,
        connection: Dict[str, str],
        path: str,
        type_s3: Literal["upsert", "load"],
        data: Optional[Union[Dict, List]] = None,
        return_public_url: bool = False,
        content_type: Optional[str] = None,
    ) -> Any:
        """
        Read from or write to an S3-compatible object store.

        The s3fs calls are synchronous, so this coroutine blocks the event
        loop for the duration of the transfer.

        Args:
            connection: Dict with keys: url, key, secret_key, domain.
            path: Full S3 path (e.g. "bucket/prefix/file.json").
            type_s3: "load" to read, "upsert" to write.
            data: Data to write (ignored for "load"). Defaults to an empty list.
            return_public_url: If True, return the public HTTPS URL after upsert.
            content_type: Optional MIME type (e.g. "application/json").

        Returns:
            For "load" → the loaded object (dict / list).
            For "upsert" → public URL string if return_public_url else a status string.

        Raises:
            ValueError: If type_s3 is neither "load" nor "upsert".
        """
        logger.info(f"S3 {type_s3} → {path}")
        storage_options: Dict[str, Any] = {
            "client_kwargs": {
                "endpoint_url": connection["url"],
                "aws_access_key_id": connection["key"],
                "aws_secret_access_key": connection["secret_key"],
                "region_name": "ap-southeast-1",
            }
        }
        if content_type:
            storage_options["s3_additional_kwargs"] = {"ContentType": content_type}

        s3 = s3fs.S3FileSystem(**storage_options)

        if type_s3 == "load":
            with s3.open(path, "r") as f:
                result = json.load(f)
            logger.info(f"Loaded from {path}")
            return result

        if type_s3 == "upsert":
            with s3.open(path, "w") as f:
                # None replaces the former mutable default argument ([]).
                json.dump(data if data is not None else [], f)
            logger.info(f"Upserted to {path}")
            if return_public_url:
                endpoint = connection["domain"].replace("http://", "https://").rstrip("/")
                return urljoin(f"{endpoint}/", path)
            return "Data successfully added or updated."

        raise ValueError(f"Invalid type_s3 value: {type_s3!r}. Use 'load' or 'upsert'.")

    # ------------------------------------------------------------------
    # Elasticsearch
    # ------------------------------------------------------------------
    @classmethod
    async def ainit_es_connection(
        cls, name: str, connection: dict, timeout: int = 120
    ) -> AsyncElasticsearch:
        """Singleton async Elasticsearch client, keyed by connection name."""

        async def _init():
            return AsyncElasticsearch(
                hosts=connection["host"],
                port=connection["port"],
                http_auth=connection["creed"],
                timeout=timeout,
            )

        async def _close(client):
            await client.close()

        return await cls._get_connection(f"elasticsearch:{name}", _init, _close)

    async def aes_connection(self, connection: dict, timeout: int = 120) -> AsyncElasticsearch:
        """Create a new, unshared async Elasticsearch client; the caller must close it."""
        return AsyncElasticsearch(
            hosts=connection["host"],
            port=connection["port"],
            http_auth=connection["creed"],
            timeout=timeout,
        )

    # ------------------------------------------------------------------
    # PostgreSQL
    # ------------------------------------------------------------------
    async def postgre_connection(self, connection: dict) -> psycopg2.extensions.connection:
        """Open a new PostgreSQL connection; the caller must close it."""
        return psycopg2.connect(
            host=connection["host"],
            port=connection["port"],
            dbname=connection["database"],
            user=connection["creed"][0],
            password=connection["creed"][1],
        )

    async def insert_postgre(
        self, data: List[dict], connection: dict, query: str, columns: List[str]
    ) -> None:
        """
        Batch-insert records into PostgreSQL using execute_values.

        On failure the transaction is rolled back and the error is logged;
        the exception is not re-raised.
        """
        tuples = [tuple(r.get(col) for col in columns) for r in data]
        conn = await self.postgre_connection(connection)
        cur = conn.cursor()
        try:
            execute_values(cur, query, tuples)
            conn.commit()
            logger.success(f"Inserted {len(data)} records.")
        except Exception as exc:
            conn.rollback()
            logger.error(f"PostgreSQL insert error: {exc}")
        finally:
            cur.close()
            conn.close()

    # ------------------------------------------------------------------
    # Graph database (Neo4j / Memgraph)
    # ------------------------------------------------------------------
    def gb_connection(self, connection: dict) -> Driver:
        """Create a synchronous graph driver; the caller must close it."""
        uri = f'{connection["host"]}:{connection["port"]}'
        return GraphDatabase.driver(uri=uri, auth=connection["creed"])

    async def agb_connection(self, connection: dict) -> AsyncDriver:
        """Create an asynchronous graph driver; the caller must close it."""
        uri = f'{connection["host"]}:{connection["port"]}'
        return AsyncGraphDatabase.driver(uri=uri, auth=connection["creed"])

    # ------------------------------------------------------------------
    # MongoDB
    # ------------------------------------------------------------------
    async def get_mongo_collection(self, uri: str, db: str, collection: str) -> Collection:
        """Return a collection handle. A new MongoClient is created on every call."""
        return MongoClient(uri)[db][collection]

    # ------------------------------------------------------------------
    # Kafka
    # ------------------------------------------------------------------
    async def kafka_producer(
        self, topic: str, servers: str, params: dict, message_id: str
    ) -> None:
        """Publish a message to a Kafka topic (sync KafkaProducer, one-shot)."""
        producer = None
        try:
            producer = KafkaProducer(
                bootstrap_servers=servers,
                value_serializer=lambda m: json.dumps(m).encode("utf-8"),
            )
            producer.send(topic, params)
            producer.flush()
            logger.info(f"Kafka message {message_id} sent to {topic}.")
        except Exception as exc:
            logger.error(f"Kafka produce error: {exc}")
        finally:
            if producer:
                producer.close()

    async def akafka_producer(
        self, topic: str, params: dict, message_id: str, producer: AIOKafkaProducer
    ) -> None:
        """
        Publish a message to Kafka using an existing async producer.

        The producer must already be started and configured with a
        value_serializer that can encode a dict.
        """
        try:
            await producer.send_and_wait(topic, params)
            logger.info(f"Async Kafka message {message_id} sent to {topic}.")
        except Exception as exc:
            logger.error(f"Async Kafka produce error: {exc}")

    # ------------------------------------------------------------------
    # Concurrency helpers
    # ------------------------------------------------------------------
    async def parallel_processing(self, tasks: list, num_workers: int) -> list:
        """
        Run coroutines concurrently, capped at num_workers simultaneous executions.

        Results are returned in completion order (not input order).
        Exceptions are caught and returned as error strings so one failing
        task does not abort the others.

        Args:
            tasks: List of awaitables (already-created coroutines).
            num_workers: Maximum number of tasks running at any moment.

        Returns:
            List of results (or error strings) in completion order.
        """
        semaphore = asyncio.Semaphore(num_workers)

        async def worker(task):
            async with semaphore:
                try:
                    return await task
                except Exception as exc:
                    return (
                        f"error: [{type(exc).__name__}] {exc} | "
                        f"Traceback: {traceback.format_exc().strip()}"
                    )

        return [
            await r
            for r in tqdm_asyncio(
                asyncio.as_completed([asyncio.create_task(worker(t)) for t in tasks]),
                total=len(tasks),
                desc="Processing",
            )
        ]

    async def parallel_processingv2(self, tasks: list, num_workers: int) -> list:
        """
        Like parallel_processing but preserves input order in the result list.
        """
        semaphore = asyncio.Semaphore(num_workers)

        async def worker(task):
            async with semaphore:
                try:
                    return await task
                except Exception as exc:
                    return (
                        f"error: [{type(exc).__name__}] {exc} | "
                        f"Traceback: {traceback.format_exc().strip()}"
                    )

        return await asyncio.gather(*[worker(t) for t in tasks])