# app/db_storage.py  (1proxy project)
# NOTE: Hugging Face file-viewer chrome (uploader "paijo77", commit message,
# hash ae50c7b) removed — it was not valid Python and broke parsing.
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func, and_, insert
from sqlalchemy.orm import selectinload
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
from typing import List, Optional
from datetime import datetime
import logging
from app.db_models import User, ProxySource, Proxy
from app.validator import proxy_validator
# Module-level logger named after this module, per stdlib logging convention.
logger = logging.getLogger(__name__)
class DatabaseStorage:
    """Async persistence layer for users, proxy sources, proxies and notifications.

    Every public method receives an externally managed ``AsyncSession`` and
    commits (or rolls back) its own unit of work before returning, so callers
    do not need to commit again.
    """

    def __init__(self, enable_validation: bool = True):
        # When True, add_proxy_with_validation() runs the comprehensive
        # validator before persisting; when False, proxies are stored as-is.
        self.enable_validation = enable_validation

    async def create_admin_user(
        self, session: AsyncSession, email: str = "admin@1proxy.local"
    ) -> User:
        """Return the local admin user, creating it on first call (idempotent)."""
        result = await session.execute(select(User).where(User.email == email))
        user = result.scalar_one_or_none()
        if not user:
            user = User(
                oauth_provider="local",
                oauth_id="admin",
                email=email,
                username="admin",
                role="admin",
                avatar_url=None,
            )
            session.add(user)
            await session.commit()
            await session.refresh(user)
        return user

    async def seed_admin_sources(self, session: AsyncSession, admin_user_id: int):
        """Insert the built-in SourceRegistry sources owned by the admin user.

        Sources whose URL already exists are skipped, so re-seeding is safe.
        """
        # Imported lazily to avoid a circular import at module load time.
        from app.sources import SourceRegistry

        for source_config in SourceRegistry.SOURCES:
            result = await session.execute(
                select(ProxySource).where(ProxySource.url == str(source_config.url))
            )
            if result.scalar_one_or_none():
                continue
            source = ProxySource(
                user_id=admin_user_id,
                url=str(source_config.url),
                # Enum members are stored by value; plain strings pass through.
                type=source_config.type.value
                if hasattr(source_config.type, "value")
                else str(source_config.type),
                # NOTE(review): name is the second-to-last URL path segment;
                # assumes registry URLs have at least two slashes — confirm
                # against SourceRegistry entries.
                name=str(source_config.url).split("/")[-2],
                enabled=source_config.enabled,
                validated=True,
                is_admin_source=True,
                is_paid=False,
            )
            session.add(source)
        await session.commit()

    async def add_proxy(
        self, session: AsyncSession, proxy_data: dict, source_id: Optional[int] = None
    ) -> Optional[Proxy]:
        """Insert a single proxy, or refresh timestamps if the URL already exists.

        Returns the persisted (new or pre-existing) Proxy row.
        """
        result = await session.execute(
            select(Proxy).where(Proxy.url == proxy_data["url"])
        )
        existing = result.scalar_one_or_none()
        now = datetime.utcnow()
        if existing:
            existing.last_seen = now
            existing.updated_at = now
            # Attach a source only if the row has none yet.  Explicit None
            # checks so a legitimate id of 0 is not silently ignored (bug fix
            # for the earlier truthiness test).
            if source_id is not None and existing.source_id is None:
                existing.source_id = source_id
            await session.commit()
            return existing
        proxy = Proxy(
            source_id=source_id,
            url=proxy_data["url"],
            protocol=proxy_data.get("protocol", "http"),
            ip=proxy_data.get("ip"),
            port=proxy_data.get("port"),
            is_working=True,
        )
        session.add(proxy)
        await session.commit()
        await session.refresh(proxy)
        return proxy

    async def add_proxy_with_validation(
        self, session: AsyncSession, proxy_data: dict, source_id: Optional[int] = None
    ) -> Optional[Proxy]:
        """Validate a proxy (when validation is enabled) and persist it.

        Returns None when the payload lacks url/ip or validation fails;
        otherwise returns the stored Proxy row.  Mutates ``proxy_data`` in
        place with the measured attributes.
        """
        url = proxy_data.get("url")
        ip = proxy_data.get("ip")
        if not url or not ip:
            return None
        if self.enable_validation:
            validation_result = await proxy_validator.validate_comprehensive(url, ip)
            if not validation_result.success:
                return None
            # Fold the measured attributes into the payload before storing.
            proxy_data.update(
                {
                    "latency_ms": validation_result.latency_ms,
                    "anonymity": validation_result.anonymity,
                    "can_access_google": validation_result.can_access_google,
                    "country_code": validation_result.country_code,
                    "country_name": validation_result.country_name,
                    "proxy_type": validation_result.proxy_type,
                    "quality_score": validation_result.quality_score,
                    "is_working": True,
                    "validation_status": "validated",
                    "last_validated": datetime.utcnow(),
                }
            )
        return await self.add_proxy(session, proxy_data, source_id)

    async def add_proxies(self, session: AsyncSession, proxies_data: List[dict]) -> int:
        """Upsert a batch of proxies.

        New URLs are inserted; URLs already stored only get their
        last_seen/updated_at timestamps refreshed.  Duplicate URLs inside
        the incoming batch are collapsed to their first occurrence so two
        pending inserts can never collide on the unique Proxy.url constraint
        (previously an IntegrityError at commit).

        Note: despite an earlier docstring's claim, this is a per-row
        check-then-insert, not a single ON CONFLICT bulk statement.

        Returns the number of distinct proxies processed (new + refreshed);
        falls back to one-by-one inserts if the batched pass fails.
        """
        if not proxies_data:
            return 0
        now = datetime.utcnow()
        prepared_data = []
        seen_urls = set()  # bug fix: collapse in-batch duplicate URLs
        for proxy_data in proxies_data:
            try:
                url = proxy_data.get("url")
                if not url:
                    # Reconstruct the URL when only ip/port/protocol are given.
                    ip = proxy_data.get("ip")
                    port = proxy_data.get("port")
                    protocol = proxy_data.get("protocol", "http")
                    if ip and port:
                        url = f"{protocol}://{ip}:{port}"
                    else:
                        continue
                if url in seen_urls:
                    continue
                seen_urls.add(url)
                prepared_data.append(
                    {
                        "url": url,
                        "protocol": proxy_data.get("protocol", "http"),
                        "ip": proxy_data.get("ip"),
                        "port": proxy_data.get("port"),
                        "country_code": proxy_data.get("country_code"),
                        "country_name": proxy_data.get("country_name"),
                        "city": proxy_data.get("city"),
                        "latency_ms": proxy_data.get("latency_ms"),
                        "speed_mbps": proxy_data.get("speed_mbps"),
                        "anonymity": proxy_data.get("anonymity"),
                        "proxy_type": proxy_data.get("proxy_type"),
                        "quality_score": proxy_data.get("quality_score"),
                        "is_working": True,
                        "validation_status": proxy_data.get(
                            "validation_status", "pending"
                        ),
                        "last_validated": proxy_data.get("last_validated"),
                        "first_seen": now,
                        "last_seen": now,
                        "created_at": now,
                        "updated_at": now,
                    }
                )
            except Exception as e:
                logger.error(f"Error preparing proxy data: {e}")
                continue
        if not prepared_data:
            return 0
        try:
            batch_size = 100
            total_inserted = 0
            for i in range(0, len(prepared_data), batch_size):
                batch = prepared_data[i : i + batch_size]
                for proxy_dict in batch:
                    try:
                        result = await session.execute(
                            select(Proxy).where(Proxy.url == proxy_dict["url"])
                        )
                        existing = result.scalar_one_or_none()
                        if existing:
                            existing.last_seen = now
                            existing.updated_at = now
                        else:
                            session.add(Proxy(**proxy_dict))
                            total_inserted += 1
                    except Exception as e:
                        logger.error(
                            f"Error inserting proxy {proxy_dict.get('url')}: {e}"
                        )
                        continue
                # Commit once per batch to bound transaction size.
                await session.commit()
            logger.info(
                f"Successfully processed {len(prepared_data)} proxies, inserted {total_inserted} new ones"
            )
            return len(prepared_data)
        except Exception as e:
            logger.error(f"Error in bulk insert: {e}")
            await session.rollback()
            return await self._add_proxies_fallback(session, prepared_data)

    async def _add_proxies_fallback(
        self, session: AsyncSession, proxies_data: List[dict]
    ) -> int:
        """Last-resort path: upsert proxies one by one, skipping any that error.

        Returns the number of newly inserted rows (refreshed rows not counted).
        """
        added_count = 0
        now = datetime.utcnow()
        for proxy_data in proxies_data:
            try:
                url = proxy_data.get("url")
                if not url:
                    continue
                result = await session.execute(select(Proxy).where(Proxy.url == url))
                existing = result.scalar_one_or_none()
                if existing:
                    existing.last_seen = now
                    existing.updated_at = now
                else:
                    session.add(Proxy(**proxy_data))
                    added_count += 1
            except Exception as e:
                logger.error(f"Error in fallback insert for proxy: {e}")
                continue
        await session.commit()
        return added_count

    async def validate_and_update_proxies(
        self,
        session: AsyncSession,
        proxy_ids: Optional[List[int]] = None,
        limit: int = 50,
    ) -> dict:
        """Validate pending proxies and persist the outcome.

        When *proxy_ids* is given, only those still-pending rows are checked;
        otherwise up to *limit* pending rows are taken.  Returns counts as
        ``{"validated": int, "failed": int, "total": int}``.  Rows without an
        IP (or with no batch result) are counted in ``total`` but neither
        validated nor failed.
        """
        if proxy_ids:
            query = select(Proxy).where(
                Proxy.id.in_(proxy_ids), Proxy.validation_status == "pending"
            )
        else:
            query = (
                select(Proxy).where(Proxy.validation_status == "pending").limit(limit)
            )
        result = await session.execute(query)
        proxies_to_validate = result.scalars().all()
        if not proxies_to_validate:
            return {"validated": 0, "failed": 0, "total": 0}
        proxy_tuples = [(p.url, p.ip) for p in proxies_to_validate if p.ip]
        if not proxy_tuples:
            return {"validated": 0, "failed": 0, "total": 0}
        validation_results = await proxy_validator.validate_batch(proxy_tuples)
        validated_count = 0
        failed_count = 0
        for proxy in proxies_to_validate:
            # validate_batch yields (url, result) pairs; match by URL.
            matching_result = next(
                (r for url, r in validation_results if url == proxy.url), None
            )
            if not matching_result:
                continue
            if matching_result.success:
                proxy.latency_ms = matching_result.latency_ms
                proxy.anonymity = matching_result.anonymity
                proxy.can_access_google = matching_result.can_access_google
                proxy.country_code = matching_result.country_code
                proxy.country_name = matching_result.country_name
                proxy.proxy_type = matching_result.proxy_type
                proxy.quality_score = matching_result.quality_score
                proxy.is_working = True
                proxy.validation_status = "validated"
                proxy.last_validated = datetime.utcnow()
                proxy.validation_failures = 0
                validated_count += 1
            else:
                proxy.is_working = False
                proxy.validation_status = "failed"
                proxy.validation_failures = (proxy.validation_failures or 0) + 1
                failed_count += 1
        await session.commit()
        return {
            "validated": validated_count,
            "failed": failed_count,
            "total": len(proxies_to_validate),
        }

    async def get_proxies(
        self,
        session: AsyncSession,
        protocol: Optional[str] = None,
        country_code: Optional[str] = None,
        anonymity: Optional[str] = None,
        min_quality: Optional[int] = None,
        is_working: bool = True,
        validation_status: str = "validated",
        limit: int = 100,
        offset: int = 0,
        order_by: str = "quality_score",
    ) -> tuple[List[Proxy], int]:
        """Return a filtered, ordered page of proxies plus the unpaged total.

        *order_by* accepts "quality_score" (default), "latency_ms" or
        "created_at"; any other value leaves ordering unspecified.
        """
        # selectinload avoids an N+1 query when callers touch proxy.source.
        query = (
            select(Proxy)
            .options(selectinload(Proxy.source))
            .where(
                Proxy.is_working == is_working,
                Proxy.validation_status == validation_status,
            )
        )
        if protocol:
            query = query.where(Proxy.protocol == protocol)
        if country_code:
            query = query.where(Proxy.country_code == country_code)
        if anonymity:
            query = query.where(Proxy.anonymity == anonymity)
        # Explicit None check: a threshold of 0 is a valid filter (bug fix
        # for the earlier truthiness test that dropped min_quality=0).
        if min_quality is not None:
            query = query.where(Proxy.quality_score >= min_quality)
        # Count before limit/offset so the total reflects the full filter set.
        count_query = select(func.count()).select_from(query.subquery())
        total_result = await session.execute(count_query)
        total = total_result.scalar()
        if order_by == "latency_ms":
            query = query.order_by(Proxy.latency_ms.asc().nulls_last())
        elif order_by == "quality_score":
            query = query.order_by(Proxy.quality_score.desc().nulls_last())
        elif order_by == "created_at":
            query = query.order_by(Proxy.created_at.desc())
        query = query.limit(limit).offset(offset)
        result = await session.execute(query)
        return list(result.scalars().all()), total

    async def get_sources(
        self,
        session: AsyncSession,
        user_id: Optional[int] = None,
        enabled_only: bool = False,
    ) -> List[ProxySource]:
        """List proxy sources, optionally restricted to one user and/or enabled ones."""
        query = select(ProxySource)
        # Explicit None check so user_id=0 is an honored filter (bug fix).
        if user_id is not None:
            query = query.where(ProxySource.user_id == user_id)
        if enabled_only:
            query = query.where(ProxySource.enabled == True)  # noqa: E712
        result = await session.execute(query)
        return list(result.scalars().all())

    async def get_random_proxy(
        self,
        session: AsyncSession,
        protocol: Optional[str] = None,
        country_code: Optional[str] = None,
        min_quality: Optional[int] = None,
        anonymity: Optional[str] = None,
        max_latency: Optional[int] = None,
    ) -> Optional[Proxy]:
        """Return one random working, validated proxy matching the filters, or None."""
        query = select(Proxy).where(
            Proxy.is_working == True, Proxy.validation_status == "validated"  # noqa: E712
        )
        if protocol:
            query = query.where(Proxy.protocol == protocol)
        if country_code:
            query = query.where(Proxy.country_code == country_code)
        # Explicit None checks so 0-valued thresholds are honored (bug fix).
        if min_quality is not None:
            query = query.where(Proxy.quality_score >= min_quality)
        if anonymity:
            query = query.where(Proxy.anonymity == anonymity)
        if max_latency is not None:
            query = query.where(Proxy.latency_ms <= max_latency)
        query = query.order_by(func.random()).limit(1)
        result = await session.execute(query)
        return result.scalar_one_or_none()

    async def get_stats(self, session: AsyncSession) -> dict:
        """Return validated-proxy counts, total and per protocol.

        Uses a single GROUP BY query instead of one query per protocol.
        """
        result = await session.execute(
            select(Proxy.protocol, func.count(Proxy.id).label("count"))
            .where(Proxy.validation_status == "validated")
            .group_by(Proxy.protocol)
        )
        by_protocol = {}
        total = 0
        for row in result:
            protocol = row.protocol if row.protocol else "unknown"
            by_protocol[protocol] = row.count
            total += row.count
        # Ensure all expected protocols appear, even with a count of 0.
        expected_protocols = [
            "http",
            "https",
            "vmess",
            "vless",
            "trojan",
            "shadowsocks",
        ]
        for protocol in expected_protocols:
            by_protocol.setdefault(protocol, 0)
        return {"total_proxies": total, "by_protocol": by_protocol}

    async def count_proxies(self, session: AsyncSession) -> int:
        """Total number of Proxy rows (0 when the table is empty)."""
        result = await session.execute(select(func.count()).select_from(Proxy))
        return result.scalar() or 0

    async def count_sources(self, session: AsyncSession) -> int:
        """Total number of ProxySource rows (0 when the table is empty)."""
        result = await session.execute(select(func.count()).select_from(ProxySource))
        return result.scalar() or 0

    async def count_users(self, session: AsyncSession) -> int:
        """Total number of User rows (0 when the table is empty)."""
        result = await session.execute(select(func.count()).select_from(User))
        return result.scalar() or 0

    async def get_or_create_user(
        self,
        session: AsyncSession,
        oauth_provider: str,
        oauth_id: str,
        email: str,
        username: str,
        role: str = "user",
        avatar_url: Optional[str] = None,
    ) -> User:
        """Look up a user by (oauth_provider, oauth_id), creating one if absent."""
        result = await session.execute(
            select(User).where(
                and_(User.oauth_provider == oauth_provider, User.oauth_id == oauth_id)
            )
        )
        user = result.scalar_one_or_none()
        if not user:
            user = User(
                oauth_provider=oauth_provider,
                oauth_id=oauth_id,
                email=email,
                username=username,
                role=role,
                avatar_url=avatar_url,
            )
            session.add(user)
            await session.commit()
            await session.refresh(user)
        return user

    async def create_notification(
        self,
        session: AsyncSession,
        user_id: int,
        notification_type: str,
        title: str,
        message: str,
        severity: str = "info",
    ):
        """Persist and return a notification for *user_id*."""
        # Imported lazily to avoid a circular import at module load time.
        from app.db_models import Notification

        notification = Notification(
            user_id=user_id,
            type=notification_type,
            title=title,
            message=message,
            severity=severity,
        )
        session.add(notification)
        await session.commit()
        await session.refresh(notification)
        return notification

    async def get_notifications(
        self,
        session: AsyncSession,
        user_id: int,
        unread_only: bool = False,
        limit: int = 50,
    ):
        """Return up to *limit* notifications for *user_id*, newest first."""
        from app.db_models import Notification

        query = select(Notification).where(Notification.user_id == user_id)
        if unread_only:
            query = query.where(Notification.read == False)  # noqa: E712
        query = query.order_by(Notification.created_at.desc()).limit(limit)
        result = await session.execute(query)
        return list(result.scalars().all())

    async def mark_notification_read(
        self, session: AsyncSession, user_id: int, notification_id: int
    ) -> bool:
        """Mark one of the user's notifications read.

        Returns True on success, False when the notification does not exist
        or belongs to another user (ownership is enforced in the query).
        """
        from app.db_models import Notification

        result = await session.execute(
            select(Notification).where(
                and_(
                    Notification.id == notification_id, Notification.user_id == user_id
                )
            )
        )
        notification = result.scalar_one_or_none()
        if notification:
            notification.read = True
            await session.commit()
            return True
        return False

    async def mark_all_notifications_read(
        self, session: AsyncSession, user_id: int
    ) -> int:
        """Mark every unread notification of *user_id* read; return rows updated."""
        from app.db_models import Notification
        from sqlalchemy import update

        stmt = (
            update(Notification)
            .where(and_(Notification.user_id == user_id, Notification.read == False))  # noqa: E712
            .values(read=True)
        )
        result = await session.execute(stmt)
        await session.commit()
        return result.rowcount
# Shared module-level storage instance (validation enabled by default);
# imported by the rest of the app instead of constructing per-request.
db_storage = DatabaseStorage()