| from sqlalchemy.ext.asyncio import AsyncSession
|
| from sqlalchemy import select, func, and_, insert
|
| from sqlalchemy.orm import selectinload
|
| from sqlalchemy.dialects.sqlite import insert as sqlite_insert
|
| from typing import List, Optional
|
| from datetime import datetime
|
| import logging
|
|
|
| from app.db_models import User, ProxySource, Proxy
|
| from app.validator import proxy_validator
|
|
|
| # Module-level logger, named after this module via __name__.
| logger = logging.getLogger(__name__)
|
|
|
|
|
class DatabaseStorage:
    """Async persistence helpers for users, proxy sources, proxies and notifications.

    Every method operates on the ``AsyncSession`` supplied by the caller.
    Methods that write data commit on that session before returning.
    """

    # Optional ``Proxy`` attributes copied from a ``proxy_data`` dict when the
    # caller supplies them. These are exactly the keys that
    # ``add_proxy_with_validation`` merges into ``proxy_data``.
    _PROXY_METADATA_FIELDS = (
        "latency_ms",
        "anonymity",
        "can_access_google",
        "country_code",
        "country_name",
        "proxy_type",
        "quality_score",
        "is_working",
        "validation_status",
        "last_validated",
    )

    # Number of URLs ``add_proxies`` resolves per SELECT (and commits together).
    _BULK_BATCH_SIZE = 100

    def __init__(self, enable_validation: bool = True):
        """Create a storage helper.

        Args:
            enable_validation: When true, ``add_proxy_with_validation`` runs
                ``proxy_validator`` before storing a proxy.
        """
        self.enable_validation = enable_validation

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current UTC time as a naive ``datetime``.

        Produces the same kind of value as ``datetime.utcnow()`` (no tzinfo)
        without calling that function, which is deprecated since Python 3.12.
        """
        from datetime import timezone

        return datetime.now(timezone.utc).replace(tzinfo=None)

    @classmethod
    def _proxy_metadata(cls, proxy_data: dict) -> dict:
        """Return the optional proxy fields of ``proxy_data`` that are not ``None``."""
        return {
            field: proxy_data[field]
            for field in cls._PROXY_METADATA_FIELDS
            if proxy_data.get(field) is not None
        }

    @staticmethod
    def _prepare_proxy_row(proxy_data: dict, now: datetime) -> Optional[dict]:
        """Build the keyword arguments for a new ``Proxy`` from raw input.

        The URL is taken from ``proxy_data["url"]`` or, when missing, built as
        ``{protocol}://{ip}:{port}``.

        Returns:
            The row dict, or ``None`` when no URL is given and it cannot be
            built because ``ip`` or ``port`` is missing.
        """
        url = proxy_data.get("url")
        protocol = proxy_data.get("protocol", "http")
        ip = proxy_data.get("ip")
        port = proxy_data.get("port")

        if not url:
            if not (ip and port):
                return None
            url = f"{protocol}://{ip}:{port}"

        return {
            "url": url,
            "protocol": protocol,
            "ip": ip,
            "port": port,
            "country_code": proxy_data.get("country_code"),
            "country_name": proxy_data.get("country_name"),
            "city": proxy_data.get("city"),
            "latency_ms": proxy_data.get("latency_ms"),
            "speed_mbps": proxy_data.get("speed_mbps"),
            "anonymity": proxy_data.get("anonymity"),
            "proxy_type": proxy_data.get("proxy_type"),
            "quality_score": proxy_data.get("quality_score"),
            "is_working": True,
            "validation_status": proxy_data.get("validation_status", "pending"),
            "last_validated": proxy_data.get("last_validated"),
            "first_seen": now,
            "last_seen": now,
            "created_at": now,
            "updated_at": now,
        }

    @staticmethod
    async def _count(session: AsyncSession, model) -> int:
        """Return ``COUNT(*)`` for ``model``, or 0 when the query yields nothing."""
        result = await session.execute(select(func.count()).select_from(model))
        return result.scalar() or 0

    # ------------------------------------------------------------------
    # Users and sources
    # ------------------------------------------------------------------

    async def create_admin_user(
        self, session: AsyncSession, email: str = "admin@1proxy.local"
    ) -> User:
        """Return the user with ``email``, creating a local admin if none exists.

        A newly created user gets ``oauth_provider="local"``,
        ``oauth_id="admin"``, ``username="admin"`` and ``role="admin"``, and is
        committed and refreshed before being returned.
        """
        result = await session.execute(select(User).where(User.email == email))
        user = result.scalar_one_or_none()
        if user:
            return user

        user = User(
            oauth_provider="local",
            oauth_id="admin",
            email=email,
            username="admin",
            role="admin",
            avatar_url=None,
        )
        session.add(user)
        await session.commit()
        await session.refresh(user)
        return user

    async def seed_admin_sources(self, session: AsyncSession, admin_user_id: int):
        """Insert every ``SourceRegistry.SOURCES`` entry that is not stored yet.

        Sources are matched by URL. New rows are owned by ``admin_user_id`` and
        flagged as validated, admin-owned and not paid. One commit is issued
        after all sources have been processed.
        """
        from app.sources import SourceRegistry

        for source_config in SourceRegistry.SOURCES:
            source_url = str(source_config.url)
            result = await session.execute(
                select(ProxySource).where(ProxySource.url == source_url)
            )
            if result.scalar_one_or_none():
                continue

            source_type = source_config.type
            # The source name is the second-to-last "/"-separated URL segment;
            # fall back to the whole URL if there is no such segment.
            url_parts = source_url.split("/")
            source_name = url_parts[-2] if len(url_parts) >= 2 else source_url

            session.add(
                ProxySource(
                    user_id=admin_user_id,
                    url=source_url,
                    type=source_type.value
                    if hasattr(source_type, "value")
                    else str(source_type),
                    name=source_name,
                    enabled=source_config.enabled,
                    validated=True,
                    is_admin_source=True,
                    is_paid=False,
                )
            )

        await session.commit()

    async def get_or_create_user(
        self,
        session: AsyncSession,
        oauth_provider: str,
        oauth_id: str,
        email: str,
        username: str,
        role: str = "user",
        avatar_url: Optional[str] = None,
    ) -> User:
        """Return the user identified by ``(oauth_provider, oauth_id)``.

        When no such user exists, one is created from the remaining
        arguments, committed and refreshed. An existing user is returned
        unchanged; the other arguments are ignored in that case.
        """
        result = await session.execute(
            select(User).where(
                and_(User.oauth_provider == oauth_provider, User.oauth_id == oauth_id)
            )
        )
        user = result.scalar_one_or_none()
        if user:
            return user

        user = User(
            oauth_provider=oauth_provider,
            oauth_id=oauth_id,
            email=email,
            username=username,
            role=role,
            avatar_url=avatar_url,
        )
        session.add(user)
        await session.commit()
        await session.refresh(user)
        return user

    async def get_sources(
        self,
        session: AsyncSession,
        user_id: Optional[int] = None,
        enabled_only: bool = False,
    ) -> List[ProxySource]:
        """List proxy sources, optionally filtered by owner and enabled flag.

        A falsy ``user_id`` (``None`` or 0) disables the owner filter.
        """
        query = select(ProxySource)
        if user_id:
            query = query.where(ProxySource.user_id == user_id)
        if enabled_only:
            # "== True" builds a SQL expression here; it is not a Python test.
            query = query.where(ProxySource.enabled == True)  # noqa: E712

        result = await session.execute(query)
        return list(result.scalars().all())

    # ------------------------------------------------------------------
    # Proxy writes
    # ------------------------------------------------------------------

    async def add_proxy(
        self, session: AsyncSession, proxy_data: dict, source_id: Optional[int] = None
    ) -> Optional[Proxy]:
        """Insert a proxy, or refresh the existing row with the same URL.

        Optional fields listed in ``_PROXY_METADATA_FIELDS`` (latency,
        anonymity, country, quality score, validation status, ...) are stored
        when ``proxy_data`` provides non-``None`` values for them, so results
        gathered by ``add_proxy_with_validation`` are persisted.

        Args:
            session: Session used for the lookup and the commit.
            proxy_data: Must contain ``"url"``; may contain ``"protocol"``
                (default ``"http"``), ``"ip"``, ``"port"`` and metadata fields.
            source_id: Source to attach. An existing row only receives it when
                it has no source yet.

        Returns:
            The existing or newly created ``Proxy``.

        Raises:
            KeyError: If ``proxy_data`` has no ``"url"`` key.
        """
        proxy_url = proxy_data["url"]
        metadata = self._proxy_metadata(proxy_data)
        now = self._utcnow()

        result = await session.execute(select(Proxy).where(Proxy.url == proxy_url))
        existing = result.scalar_one_or_none()

        if existing:
            existing.last_seen = now
            existing.updated_at = now
            if source_id and not existing.source_id:
                existing.source_id = source_id
            for field, value in metadata.items():
                setattr(existing, field, value)
            await session.commit()
            return existing

        # New rows are marked as working unless the caller says otherwise.
        optional_fields = {"is_working": True, **metadata}
        proxy = Proxy(
            source_id=source_id,
            url=proxy_url,
            protocol=proxy_data.get("protocol", "http"),
            ip=proxy_data.get("ip"),
            port=proxy_data.get("port"),
            **optional_fields,
        )
        session.add(proxy)
        await session.commit()
        await session.refresh(proxy)
        return proxy

    async def add_proxy_with_validation(
        self, session: AsyncSession, proxy_data: dict, source_id: Optional[int] = None
    ) -> Optional[Proxy]:
        """Validate a proxy (when enabled) and store it through ``add_proxy``.

        On successful validation the results are merged into ``proxy_data``
        in place, so the caller's dict is updated as well.

        Returns:
            The stored ``Proxy``, or ``None`` when ``"url"`` or ``"ip"`` is
            missing/empty or validation reports failure.
        """
        url = proxy_data.get("url")
        ip = proxy_data.get("ip")
        if not url or not ip:
            return None

        if self.enable_validation:
            validation_result = await proxy_validator.validate_comprehensive(url, ip)
            if not validation_result.success:
                return None

            proxy_data.update(
                {
                    "latency_ms": validation_result.latency_ms,
                    "anonymity": validation_result.anonymity,
                    "can_access_google": validation_result.can_access_google,
                    "country_code": validation_result.country_code,
                    "country_name": validation_result.country_name,
                    "proxy_type": validation_result.proxy_type,
                    "quality_score": validation_result.quality_score,
                    "is_working": True,
                    "validation_status": "validated",
                    "last_validated": self._utcnow(),
                }
            )

        return await self.add_proxy(session, proxy_data, source_id)

    async def add_proxies(self, session: AsyncSession, proxies_data: List[dict]) -> int:
        """Add many proxies, looking up existing rows one batch at a time.

        Entries without a URL get one built from protocol/ip/port; entries
        for which that is impossible are skipped. Duplicate URLs in the input
        are collapsed to their first occurrence. For each batch of
        ``_BULK_BATCH_SIZE`` URLs a single ``IN`` query finds the rows that
        already exist: those only get ``last_seen``/``updated_at`` refreshed,
        the rest are inserted. Each batch is committed separately.

        If the batched path raises, the session is rolled back and
        ``_add_proxies_fallback`` is used instead.

        Returns:
            The number of usable input entries (before de-duplication) on the
            batched path, or the fallback's inserted-row count.
        """
        if not proxies_data:
            return 0

        now = self._utcnow()
        prepared_data = []
        for proxy_data in proxies_data:
            try:
                row = self._prepare_proxy_row(proxy_data, now)
            except Exception as e:
                logger.error("Error preparing proxy data: %s", e)
                continue
            if row is not None:
                prepared_data.append(row)

        if not prepared_data:
            return 0

        # First occurrence of each URL wins, so one URL is never inserted twice.
        rows_by_url = {}
        for row in prepared_data:
            rows_by_url.setdefault(row["url"], row)
        unique_rows = list(rows_by_url.values())

        try:
            total_inserted = 0

            for start in range(0, len(unique_rows), self._BULK_BATCH_SIZE):
                pending = {
                    row["url"]: row
                    for row in unique_rows[start : start + self._BULK_BATCH_SIZE]
                }

                result = await session.execute(
                    select(Proxy).where(Proxy.url.in_(list(pending)))
                )
                for existing in result.scalars():
                    existing.last_seen = now
                    existing.updated_at = now
                    # Whatever remains in ``pending`` afterwards is new.
                    pending.pop(existing.url, None)

                for url, row in pending.items():
                    try:
                        session.add(Proxy(**row))
                    except Exception as e:
                        logger.error("Error inserting proxy %s: %s", url, e)
                        continue
                    total_inserted += 1

                await session.commit()

            logger.info(
                "Successfully processed %s proxies, inserted %s new ones",
                len(prepared_data),
                total_inserted,
            )
            return len(prepared_data)

        except Exception as e:
            logger.error("Error in bulk insert: %s", e)
            await session.rollback()
            return await self._add_proxies_fallback(session, prepared_data)

    async def _add_proxies_fallback(
        self, session: AsyncSession, proxies_data: List[dict]
    ) -> int:
        """Fallback for ``add_proxies``: handle each proxy with its own lookup.

        Existing rows get ``last_seen``/``updated_at`` refreshed, unknown URLs
        are inserted. Errors for a single entry are logged and the entry is
        skipped. One commit is issued at the end.

        Returns:
            The number of newly added proxies.
        """
        added_count = 0
        now = self._utcnow()

        for proxy_data in proxies_data:
            try:
                url = proxy_data.get("url")
                if not url:
                    continue

                result = await session.execute(select(Proxy).where(Proxy.url == url))
                existing = result.scalar_one_or_none()

                if existing:
                    existing.last_seen = now
                    existing.updated_at = now
                else:
                    session.add(Proxy(**proxy_data))
                    added_count += 1
            except Exception as e:
                logger.error("Error in fallback insert for proxy: %s", e)
                continue

        await session.commit()
        return added_count

    async def validate_and_update_proxies(
        self,
        session: AsyncSession,
        proxy_ids: Optional[List[int]] = None,
        limit: int = 50,
    ) -> dict:
        """Validate pending proxies and store the outcome on each row.

        Args:
            session: Session used for the query and the commit.
            proxy_ids: Restrict validation to these ids (still only rows whose
                status is ``"pending"``). When empty or ``None``, up to
                ``limit`` pending proxies are selected instead.
            limit: Maximum number of rows when ``proxy_ids`` is not given.

        Returns:
            ``{"validated": int, "failed": int, "total": int}`` where
            ``total`` is the number of selected rows.
        """
        pending_filter = Proxy.validation_status == "pending"
        if proxy_ids:
            query = select(Proxy).where(Proxy.id.in_(proxy_ids), pending_filter)
        else:
            query = select(Proxy).where(pending_filter).limit(limit)

        result = await session.execute(query)
        proxies_to_validate = result.scalars().all()
        if not proxies_to_validate:
            return {"validated": 0, "failed": 0, "total": 0}

        # NOTE: rows without an IP are never sent to the validator, so they
        # are left untouched (and stay "pending").
        proxy_tuples = [(p.url, p.ip) for p in proxies_to_validate if p.ip]
        if not proxy_tuples:
            return {"validated": 0, "failed": 0, "total": 0}

        validation_results = await proxy_validator.validate_batch(proxy_tuples)

        # Index results by URL once (first result per URL wins) instead of
        # scanning the whole result list for every proxy.
        results_by_url = {}
        for url, outcome in validation_results:
            results_by_url.setdefault(url, outcome)

        validated_count = 0
        failed_count = 0

        for proxy in proxies_to_validate:
            outcome = results_by_url.get(proxy.url)
            if not outcome:
                continue

            if outcome.success:
                proxy.latency_ms = outcome.latency_ms
                proxy.anonymity = outcome.anonymity
                proxy.can_access_google = outcome.can_access_google
                proxy.country_code = outcome.country_code
                proxy.country_name = outcome.country_name
                proxy.proxy_type = outcome.proxy_type
                proxy.quality_score = outcome.quality_score
                proxy.is_working = True
                proxy.validation_status = "validated"
                proxy.last_validated = self._utcnow()
                proxy.validation_failures = 0
                validated_count += 1
            else:
                proxy.is_working = False
                proxy.validation_status = "failed"
                proxy.validation_failures = (proxy.validation_failures or 0) + 1
                failed_count += 1

        await session.commit()

        return {
            "validated": validated_count,
            "failed": failed_count,
            "total": len(proxies_to_validate),
        }

    # ------------------------------------------------------------------
    # Proxy reads
    # ------------------------------------------------------------------

    async def get_proxies(
        self,
        session: AsyncSession,
        protocol: Optional[str] = None,
        country_code: Optional[str] = None,
        anonymity: Optional[str] = None,
        min_quality: Optional[int] = None,
        is_working: bool = True,
        validation_status: str = "validated",
        limit: int = 100,
        offset: int = 0,
        order_by: str = "quality_score",
    ) -> tuple[List[Proxy], int]:
        """Return one page of proxies plus the total number of matches.

        Optional filters are applied only when their value is truthy.
        ``order_by`` accepts ``"latency_ms"`` (ascending, NULLs last),
        ``"quality_score"`` (descending, NULLs last) or ``"created_at"``
        (descending); any other value leaves the result unordered.

        Returns:
            ``(proxies, total)`` where ``total`` ignores ``limit``/``offset``.
        """
        query = (
            select(Proxy)
            .options(selectinload(Proxy.source))
            .where(
                Proxy.is_working == is_working,
                Proxy.validation_status == validation_status,
            )
        )

        if protocol:
            query = query.where(Proxy.protocol == protocol)
        if country_code:
            query = query.where(Proxy.country_code == country_code)
        if anonymity:
            query = query.where(Proxy.anonymity == anonymity)
        if min_quality:
            query = query.where(Proxy.quality_score >= min_quality)

        # Count on the filtered query before ordering and paging are applied.
        count_query = select(func.count()).select_from(query.subquery())
        total = (await session.execute(count_query)).scalar() or 0

        if order_by == "latency_ms":
            query = query.order_by(Proxy.latency_ms.asc().nulls_last())
        elif order_by == "quality_score":
            query = query.order_by(Proxy.quality_score.desc().nulls_last())
        elif order_by == "created_at":
            query = query.order_by(Proxy.created_at.desc())

        result = await session.execute(query.limit(limit).offset(offset))
        return list(result.scalars().all()), total

    async def get_random_proxy(
        self,
        session: AsyncSession,
        protocol: Optional[str] = None,
        country_code: Optional[str] = None,
        min_quality: Optional[int] = None,
        anonymity: Optional[str] = None,
        max_latency: Optional[int] = None,
    ) -> Optional[Proxy]:
        """Return one random working, validated proxy matching the filters.

        Filters are applied only when their value is truthy. Returns ``None``
        when no proxy matches.
        """
        query = select(Proxy).where(
            # "== True" builds a SQL expression here; it is not a Python test.
            Proxy.is_working == True,  # noqa: E712
            Proxy.validation_status == "validated",
        )

        if protocol:
            query = query.where(Proxy.protocol == protocol)
        if country_code:
            query = query.where(Proxy.country_code == country_code)
        if min_quality:
            query = query.where(Proxy.quality_score >= min_quality)
        if anonymity:
            query = query.where(Proxy.anonymity == anonymity)
        if max_latency:
            query = query.where(Proxy.latency_ms <= max_latency)

        result = await session.execute(query.order_by(func.random()).limit(1))
        return result.scalar_one_or_none()

    async def get_stats(self, session: AsyncSession) -> dict:
        """Count validated proxies per protocol with a single GROUP BY query.

        Returns:
            ``{"total_proxies": int, "by_protocol": {protocol: count}}``.
            Rows with an empty protocol are reported under ``"unknown"``, and
            the protocols http, https, vmess, vless, trojan and shadowsocks
            are always present (0 when no row has them).
        """
        result = await session.execute(
            select(Proxy.protocol, func.count(Proxy.id).label("count"))
            .where(Proxy.validation_status == "validated")
            .group_by(Proxy.protocol)
        )

        by_protocol = {}
        total = 0
        # Unpack positionally rather than reading ``row.count``, a name that
        # is shared with the tuple-style ``count`` API of result rows.
        for protocol, count in result:
            by_protocol[protocol or "unknown"] = count
            total += count

        for protocol in ("http", "https", "vmess", "vless", "trojan", "shadowsocks"):
            by_protocol.setdefault(protocol, 0)

        return {"total_proxies": total, "by_protocol": by_protocol}

    async def count_proxies(self, session: AsyncSession) -> int:
        """Return the number of ``Proxy`` rows."""
        return await self._count(session, Proxy)

    async def count_sources(self, session: AsyncSession) -> int:
        """Return the number of ``ProxySource`` rows."""
        return await self._count(session, ProxySource)

    async def count_users(self, session: AsyncSession) -> int:
        """Return the number of ``User`` rows."""
        return await self._count(session, User)

    # ------------------------------------------------------------------
    # Notifications
    # ------------------------------------------------------------------

    async def create_notification(
        self,
        session: AsyncSession,
        user_id: int,
        notification_type: str,
        title: str,
        message: str,
        severity: str = "info",
    ):
        """Create, commit and return a ``Notification`` for ``user_id``."""
        from app.db_models import Notification

        notification = Notification(
            user_id=user_id,
            type=notification_type,
            title=title,
            message=message,
            severity=severity,
        )
        session.add(notification)
        await session.commit()
        await session.refresh(notification)
        return notification

    async def get_notifications(
        self,
        session: AsyncSession,
        user_id: int,
        unread_only: bool = False,
        limit: int = 50,
    ):
        """Return up to ``limit`` notifications of a user, newest first.

        With ``unread_only`` set, only notifications whose ``read`` flag is
        false are returned.
        """
        from app.db_models import Notification

        query = select(Notification).where(Notification.user_id == user_id)
        if unread_only:
            # "== False" builds a SQL expression here; it is not a Python test.
            query = query.where(Notification.read == False)  # noqa: E712
        query = query.order_by(Notification.created_at.desc()).limit(limit)

        result = await session.execute(query)
        return list(result.scalars().all())

    async def mark_notification_read(
        self, session: AsyncSession, user_id: int, notification_id: int
    ) -> bool:
        """Mark one notification of ``user_id`` as read.

        Returns:
            ``True`` when the notification exists for that user and was
            updated, ``False`` otherwise.
        """
        from app.db_models import Notification

        result = await session.execute(
            select(Notification).where(
                and_(
                    Notification.id == notification_id, Notification.user_id == user_id
                )
            )
        )
        notification = result.scalar_one_or_none()
        if not notification:
            return False

        notification.read = True
        await session.commit()
        return True

    async def mark_all_notifications_read(
        self, session: AsyncSession, user_id: int
    ) -> int:
        """Mark every unread notification of ``user_id`` as read.

        Returns:
            The ``rowcount`` reported for the UPDATE statement.
        """
        from app.db_models import Notification
        from sqlalchemy import update

        stmt = (
            update(Notification)
            # "== False" builds a SQL expression here; it is not a Python test.
            .where(and_(Notification.user_id == user_id, Notification.read == False))  # noqa: E712
            .values(read=True)
        )
        result = await session.execute(stmt)
        await session.commit()
        return result.rowcount
|
|
|
|
|
| # Shared module-level instance, created with the default enable_validation=True.
| db_storage = DatabaseStorage()
|
|
|