# Source: HINTECH / core/nat_engine.py (Factor Studios)
# Commit aaaaa79 "Upload 73 files" (verified)
"""
NAT Engine Module
Implements Network Address Translation:
- Map (virtualIP, virtualPort) to (hostIP, hostPort)
- Maintain connection tracking table
- Handle port allocation and deallocation
- Support connection state tracking
"""
import time
import threading
import socket
import random
import struct
from typing import Dict, Optional, Tuple, Set
from dataclasses import dataclass
from enum import Enum
# Assuming IPProtocol is defined elsewhere or will be defined
# from .ip_parser import IPProtocol
class NATType(Enum):
    """Which endpoint of a packet a NAT session rewrites."""

    SNAT = "SNAT"  # Source NAT: the packet's source address/port is rewritten
    DNAT = "DNAT"  # Destination NAT: the packet's destination address/port is rewritten
@dataclass
class NATSession:
    """State of one translated flow.

    Holds the three endpoints involved in a translation together with
    per-direction byte/packet counters and activity timestamps (seconds as
    returned by ``time.time()``).
    """

    # -- endpoint inside the virtual network --
    virtual_ip: str
    virtual_port: int
    # -- remote peer the virtual endpoint talks to --
    real_ip: str
    real_port: int
    # -- address/port on the host that the flow is translated to --
    host_ip: str
    host_port: int
    # -- metadata --
    protocol: int  # IP protocol number, e.g. 6 = TCP, 17 = UDP
    nat_type: NATType
    created_time: float  # compared against time.time() by `duration`
    last_activity: float  # compared against time.time() by `is_expired`
    bytes_in: int = 0
    bytes_out: int = 0
    packets_in: int = 0
    packets_out: int = 0

    @property
    def session_id(self) -> str:
        """Key of the form ``vip:vport-rip:rport-proto`` identifying this flow."""
        local_end = f"{self.virtual_ip}:{self.virtual_port}"
        remote_end = f"{self.real_ip}:{self.real_port}"
        return f"{local_end}-{remote_end}-{self.protocol}"

    @property
    def is_expired(self) -> bool:
        """True once the session has been idle too long.

        TCP sessions may idle for 300 s; every other protocol for 60 s.
        """
        idle_limit = 60
        if self.protocol == socket.IPPROTO_TCP:
            idle_limit = 300
        idle_for = time.time() - self.last_activity
        return idle_for > idle_limit

    @property
    def duration(self) -> float:
        """Seconds elapsed since the session was created."""
        started_at = self.created_time
        return time.time() - started_at

    def update_activity(self, bytes_transferred: int = 0, direction: str = 'out'):
        """Refresh ``last_activity`` and count one packet.

        ``direction == 'out'`` updates the outbound counters; any other value
        updates the inbound counters.
        """
        self.last_activity = time.time()
        if direction != 'out':
            self.bytes_in += bytes_transferred
            self.packets_in += 1
            return
        self.bytes_out += bytes_transferred
        self.packets_out += 1
class PortPool:
    """Thread-safe pool of host ports handed out to NAT sessions.

    Every port in the inclusive range ``[start_port, end_port]`` starts out
    available. ``allocate_port`` picks a free port uniformly at random (to
    spread load across the range); ``release_port`` returns it to the pool.
    """

    # Random probes tried before falling back to choosing from the full free
    # set. Probing is O(1) while the pool is mostly free, whereas copying the
    # free set (up to ~55k ports) on every allocation is O(n).
    _PROBE_ATTEMPTS = 8

    def __init__(self, start_port: int = 10000, end_port: int = 65535):
        """Create a pool covering ``start_port..end_port`` (inclusive)."""
        self.start_port = start_port
        self.end_port = end_port
        self.available_ports: Set[int] = set(range(start_port, end_port + 1))
        self.allocated_ports: Dict[int, str] = {}  # port -> session_id
        self.lock = threading.Lock()

    def _total_ports(self) -> int:
        """Size of the configured range; 0 for an empty/inverted range."""
        return max(0, self.end_port - self.start_port + 1)

    def allocate_port(self, session_id: str) -> Optional[int]:
        """Reserve a free port for ``session_id``.

        Returns:
            The allocated port, or ``None`` when the pool is exhausted.
        """
        with self.lock:
            if not self.available_ports:
                return None
            port = None
            # Rejection sampling over the range: each accepted candidate is
            # uniform over the free ports, same as the fallback below.
            for _ in range(self._PROBE_ATTEMPTS):
                candidate = random.randint(self.start_port, self.end_port)
                if candidate in self.available_ports:
                    port = candidate
                    break
            if port is None:
                # Pool is densely used: pick directly from what is left.
                port = random.choice(tuple(self.available_ports))
            self.available_ports.remove(port)
            self.allocated_ports[port] = session_id
            return port

    def release_port(self, port: int) -> bool:
        """Return ``port`` to the pool.

        Returns:
            True if the port was allocated (and is now free again), else False.
        """
        with self.lock:
            if port not in self.allocated_ports:
                return False
            del self.allocated_ports[port]
            if self.start_port <= port <= self.end_port:
                self.available_ports.add(port)
            return True

    def get_session_for_port(self, port: int) -> Optional[str]:
        """Return the session id holding ``port``, or None if it is free."""
        with self.lock:
            return self.allocated_ports.get(port)

    def get_stats(self) -> Dict:
        """Return pool size, free/allocated counts and utilization (0.0-1.0)."""
        with self.lock:
            total = self._total_ports()
            allocated = len(self.allocated_ports)
            return {
                'total_ports': total,
                'available_ports': len(self.available_ports),
                'allocated_ports': allocated,
                # An empty range previously raised ZeroDivisionError here.
                'utilization': allocated / total if total else 0.0,
            }
class NATEngine:
    """Network Address Translation engine.

    Maintains the session table (indexed by session id, by virtual endpoint
    and by host endpoint), hands out host ports from a PortPool, expires idle
    sessions from a background thread, and rewrites IPv4 TCP/UDP packets in
    both directions.

    ``self.lock`` is re-entrant: public methods that already hold it call
    helpers (``_remove_session``, ``create_outbound_session``) that acquire
    it again, which deadlocks with a plain ``threading.Lock``.
    """

    def __init__(self, config: Dict):
        """Create an engine from a configuration mapping.

        Recognised keys (all optional):
            port_range_start / port_range_end: host port range (10000 / 65535).
            host_ip: translated host address; auto-detected when absent.
            session_timeout: idle seconds after which any session is expired
                (300). It can only shorten the per-protocol limit applied by
                ``NATSession.is_expired``.
            cleanup_interval: seconds between expiry sweeps (0.1).
            l2_header_length: bytes preceding the IPv4 header in each frame
                (14 for Ethernet; 0 for raw IP packets, e.g. from a TUN device).
        """
        self.config = config
        self.sessions: Dict[str, NATSession] = {}  # session_id -> session
        self.virtual_to_session: Dict[Tuple[str, int, int], str] = {}  # (vip, vport, proto) -> session_id
        self.host_to_session: Dict[Tuple[str, int, int], str] = {}  # (hip, hport, proto) -> session_id
        self.lock = threading.RLock()
        # Port pool for outbound connections
        self.port_pool = PortPool(
            config.get('port_range_start', 10000),
            config.get('port_range_end', 65535),
        )
        # Only probe the routing table when no host IP is configured.
        self.host_ip = config['host_ip'] if 'host_ip' in config else self._get_default_host_ip()
        self.session_timeout = config.get('session_timeout', 300)
        self.cleanup_interval = config.get('cleanup_interval', 0.1)
        self.l2_header_len = config.get('l2_header_length', 14)
        self.stats = {
            'total_sessions': 0,
            'active_sessions': 0,
            'expired_sessions': 0,
            'port_exhaustion_events': 0,
            'bytes_translated': 0,
            'packets_translated': 0,
        }
        # Cleanup thread; the event lets stop() interrupt the sleep at once.
        self.running = False
        self.cleanup_thread = None
        self._stop_event = threading.Event()

    def _get_default_host_ip(self) -> str:
        """Return the local IPv4 address used to reach 8.8.8.8, or 127.0.0.1.

        connect() on a UDP socket only selects a route and local address.
        """
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
                s.connect(('8.8.8.8', 80))
                return s.getsockname()[0]
        except OSError:
            return '127.0.0.1'

    # ------------------------------------------------------------------
    # Session table
    # ------------------------------------------------------------------

    def _cleanup_expired_sessions(self):
        """Remove idle sessions and count them in ``stats['expired_sessions']``.

        A session is idle when ``NATSession.is_expired`` says so or when it
        has been inactive longer than ``self.session_timeout`` seconds.
        """
        with self.lock:
            now = time.time()
            expired = [
                session_id
                for session_id, session in self.sessions.items()
                if session.is_expired or now - session.last_activity > self.session_timeout
            ]
            for session_id in expired:
                if self._remove_session(session_id):
                    self.stats['expired_sessions'] += 1

    def _remove_session(self, session_id: str) -> bool:
        """Drop a session, its lookup entries and its host port.

        Lookup entries are deleted only while they still point at this
        session, so a newer session that took over the same key keeps its
        mapping.

        Returns:
            True if the session existed and was removed.
        """
        with self.lock:
            session = self.sessions.pop(session_id, None)
            if session is None:
                return False
            virtual_key = (session.virtual_ip, session.virtual_port, session.protocol)
            if self.virtual_to_session.get(virtual_key) == session_id:
                del self.virtual_to_session[virtual_key]
            host_key = (session.host_ip, session.host_port, session.protocol)
            if self.host_to_session.get(host_key) == session_id:
                del self.host_to_session[host_key]
            self.port_pool.release_port(session.host_port)
            self.stats['active_sessions'] = len(self.sessions)
            return True

    def create_outbound_session(self, virtual_ip: str, virtual_port: int,
                                real_ip: str, real_port: int, protocol: int) -> Optional["NATSession"]:
        """Create (or return the existing) SNAT session for an outbound flow.

        Returns:
            The session, or None if no host port is available.
        """
        session_id = f"{virtual_ip}:{virtual_port}-{real_ip}:{real_port}-{protocol}"
        with self.lock:
            existing = self.sessions.get(session_id)
            if existing is not None:
                # Re-creating would overwrite the entry and leak its port.
                return existing
            host_port = self.port_pool.allocate_port(session_id)
            if host_port is None:
                self.stats['port_exhaustion_events'] += 1
                return None
            now = time.time()
            session = NATSession(
                virtual_ip=virtual_ip,
                virtual_port=virtual_port,
                real_ip=real_ip,
                real_port=real_port,
                host_ip=self.host_ip,
                host_port=host_port,
                protocol=protocol,
                nat_type=NATType.SNAT,
                created_time=now,
                last_activity=now,
            )
            self.sessions[session_id] = session
            self.virtual_to_session[(virtual_ip, virtual_port, protocol)] = session_id
            self.host_to_session[(session.host_ip, host_port, protocol)] = session_id
            self.stats['total_sessions'] += 1
            self.stats['active_sessions'] = len(self.sessions)
            return session

    def translate_outbound(self, virtual_ip: str, virtual_port: int,
                           real_ip: str, real_port: int, protocol: int) -> Optional[Tuple[str, int]]:
        """Map a virtual source endpoint to its (host_ip, host_port).

        Sessions are looked up by (virtual_ip, virtual_port, protocol); a new
        session is created when none exists. Lookup and creation happen under
        one lock hold, so concurrent first packets share a single session.

        Returns:
            (host_ip, host_port), or None if no host port is available.
        """
        virtual_key = (virtual_ip, virtual_port, protocol)
        with self.lock:
            session_id = self.virtual_to_session.get(virtual_key)
            session = self.sessions.get(session_id) if session_id else None
            if session is not None:
                session.update_activity(direction='out')
                return (session.host_ip, session.host_port)
            session = self.create_outbound_session(virtual_ip, virtual_port, real_ip, real_port, protocol)
            if session is None:
                return None
            return (session.host_ip, session.host_port)

    def translate_inbound(self, host_ip: str, host_port: int, protocol: int) -> Optional[Tuple[str, int]]:
        """Map a host endpoint back to its (virtual_ip, virtual_port), or None."""
        host_key = (host_ip, host_port, protocol)
        with self.lock:
            session_id = self.host_to_session.get(host_key)
            if session_id and session_id in self.sessions:
                session = self.sessions[session_id]
                session.update_activity(direction='in')
                return (session.virtual_ip, session.virtual_port)
            return None

    def get_session_by_virtual(self, virtual_ip: str, virtual_port: int, protocol: int) -> Optional["NATSession"]:
        """Return the session for a virtual endpoint, or None."""
        virtual_key = (virtual_ip, virtual_port, protocol)
        with self.lock:
            session_id = self.virtual_to_session.get(virtual_key)
            if session_id and session_id in self.sessions:
                return self.sessions[session_id]
            return None

    def get_session_by_host(self, host_ip: str, host_port: int, protocol: int) -> Optional["NATSession"]:
        """Return the session for a host endpoint, or None."""
        host_key = (host_ip, host_port, protocol)
        with self.lock:
            session_id = self.host_to_session.get(host_key)
            if session_id and session_id in self.sessions:
                return self.sessions[session_id]
            return None

    def close_session(self, session_id: str) -> bool:
        """Manually close a session; return True if it existed."""
        with self.lock:
            return self._remove_session(session_id)

    def close_session_by_virtual(self, virtual_ip: str, virtual_port: int, protocol: int) -> bool:
        """Close the session mapped to a virtual endpoint; return True if one was removed."""
        virtual_key = (virtual_ip, virtual_port, protocol)
        with self.lock:
            session_id = self.virtual_to_session.get(virtual_key)
            if not session_id:
                return False
            return self._remove_session(session_id)

    def get_sessions(self) -> Dict[str, Dict]:
        """Return a snapshot of all sessions as plain dicts keyed by session id."""
        with self.lock:
            return {
                session_id: {
                    'virtual_ip': session.virtual_ip,
                    'virtual_port': session.virtual_port,
                    'real_ip': session.real_ip,
                    'real_port': session.real_port,
                    'host_ip': session.host_ip,
                    'host_port': session.host_port,
                    'protocol': session.protocol,
                    'nat_type': session.nat_type.value,
                    'created_time': session.created_time,
                    'last_activity': session.last_activity,
                    'duration': session.duration,
                    'bytes_in': session.bytes_in,
                    'bytes_out': session.bytes_out,
                    'packets_in': session.packets_in,
                    'packets_out': session.packets_out,
                    'is_expired': session.is_expired,
                }
                for session_id, session in self.sessions.items()
            }

    def get_stats(self) -> Dict:
        """Return engine counters merged with the port pool statistics."""
        port_stats = self.port_pool.get_stats()
        with self.lock:
            current_stats = dict(self.stats)
            current_stats['active_sessions'] = len(self.sessions)
        current_stats.update(port_stats)
        return current_stats

    def update_packet_stats(self, bytes_count: int):
        """Count one translated packet of ``bytes_count`` bytes."""
        with self.lock:
            self.stats['bytes_translated'] += bytes_count
            self.stats['packets_translated'] += 1

    # ------------------------------------------------------------------
    # Background cleanup
    # ------------------------------------------------------------------

    def _cleanup_loop(self):
        """Sweep expired sessions every ``cleanup_interval`` seconds until stopped."""
        while self.running:
            try:
                self._cleanup_expired_sessions()
            except Exception as e:  # keep the sweeper alive on unexpected errors
                print(f"NAT cleanup error: {e}")
            # Event.wait returns immediately once stop() sets the event.
            self._stop_event.wait(self.cleanup_interval)

    def start(self):
        """Start the background expiry thread."""
        self._stop_event.clear()
        self.running = True
        self.cleanup_thread = threading.Thread(target=self._cleanup_loop, daemon=True)
        self.cleanup_thread.start()

    def stop(self):
        """Stop the expiry thread and close every session."""
        self.running = False
        self._stop_event.set()
        if self.cleanup_thread and self.cleanup_thread.is_alive():
            self.cleanup_thread.join(timeout=1)
            if self.cleanup_thread.is_alive():
                print("NAT cleanup thread did not terminate in time.")
        with self.lock:
            for session_id in list(self.sessions):
                self._remove_session(session_id)

    # ------------------------------------------------------------------
    # Checksums
    # ------------------------------------------------------------------

    def _calculate_ip_checksum(self, ip_header_no_checksum: bytes) -> int:
        """Return the 16-bit one's-complement Internet checksum of the data.

        The checksum field inside the data must already be zero. Odd-length
        input is padded with a trailing zero byte.
        """
        data = bytes(ip_header_no_checksum)
        if len(data) % 2:
            data += b'\x00'
        total = sum(struct.unpack(f'!{len(data) // 2}H', data))
        while total >> 16:  # fold carries back into the low 16 bits
            total = (total & 0xFFFF) + (total >> 16)
        return ~total & 0xFFFF

    @staticmethod
    def _checksum_adjust(checksum: int, old_data: bytes, new_data: bytes) -> int:
        """Incrementally update a one's-complement checksum (RFC 1624, eqn. 3).

        ``old_data`` and ``new_data`` are equal-length runs of 16-bit words,
        each word aligned on an even offset of the checksummed data, that
        were replaced. Computes HC' = ~(~HC + ~m + m').
        """
        total = ~checksum & 0xFFFF
        for i in range(0, len(old_data), 2):
            old_word = (old_data[i] << 8) | old_data[i + 1]
            new_word = (new_data[i] << 8) | new_data[i + 1]
            total += (~old_word & 0xFFFF) + new_word
        while total >> 16:
            total = (total & 0xFFFF) + (total >> 16)
        return ~total & 0xFFFF

    # ------------------------------------------------------------------
    # Packet rewriting
    # ------------------------------------------------------------------

    def _parse_l4_endpoints(self, packet: bytes) -> Optional[Tuple[int, int, int, int, str, str, int, int]]:
        """Locate the IPv4 and TCP/UDP headers of a frame.

        Returns:
            (ip_off, ip_hdr_len, protocol, l4_off, src_ip, dst_ip, src_port,
            dst_port), or None when the frame is not IPv4, is not TCP/UDP, is
            truncated, or is a non-first fragment (which has no L4 header).
        """
        ip_off = self.l2_header_len
        if len(packet) < ip_off + 20:
            return None
        version_ihl = packet[ip_off]
        if version_ihl >> 4 != 4:
            return None
        ip_hdr_len = (version_ihl & 0x0F) * 4  # includes IP options, if any
        if ip_hdr_len < 20 or len(packet) < ip_off + ip_hdr_len:
            return None
        protocol = packet[ip_off + 9]
        if protocol not in (socket.IPPROTO_TCP, socket.IPPROTO_UDP):
            return None
        (flags_frag,) = struct.unpack_from('!H', packet, ip_off + 6)
        if flags_frag & 0x1FFF:
            return None
        l4_off = ip_off + ip_hdr_len
        min_l4_len = 20 if protocol == socket.IPPROTO_TCP else 8
        if len(packet) < l4_off + min_l4_len:
            return None
        src_port, dst_port = struct.unpack_from('!HH', packet, l4_off)
        src_ip = socket.inet_ntoa(packet[ip_off + 12:ip_off + 16])
        dst_ip = socket.inet_ntoa(packet[ip_off + 16:ip_off + 20])
        return ip_off, ip_hdr_len, protocol, l4_off, src_ip, dst_ip, src_port, dst_port

    def _rewrite_endpoint(self, packet: bytes, ip_off: int, ip_hdr_len: int, protocol: int,
                          l4_off: int, rewrite_source: bool, new_ip: str, new_port: int) -> bytes:
        """Return a copy of ``packet`` with its source or destination endpoint replaced.

        The IPv4 header checksum is recomputed (IP options are kept). The
        TCP/UDP checksum covers the addresses through the pseudo-header, so it
        is adjusted incrementally for both the address and the port change. A
        UDP checksum of 0 means "not computed" and is left as 0; a computed
        UDP result of 0 is sent as 0xFFFF.
        """
        buf = bytearray(packet)
        ip_field = ip_off + (12 if rewrite_source else 16)
        port_field = l4_off + (0 if rewrite_source else 2)
        old_bytes = bytes(buf[ip_field:ip_field + 4]) + bytes(buf[port_field:port_field + 2])
        new_ip_bytes = socket.inet_aton(new_ip)
        new_port_bytes = struct.pack('!H', new_port)
        buf[ip_field:ip_field + 4] = new_ip_bytes
        buf[port_field:port_field + 2] = new_port_bytes

        struct.pack_into('!H', buf, ip_off + 10, 0)
        ip_checksum = self._calculate_ip_checksum(bytes(buf[ip_off:ip_off + ip_hdr_len]))
        struct.pack_into('!H', buf, ip_off + 10, ip_checksum)

        is_tcp = protocol == socket.IPPROTO_TCP
        checksum_off = l4_off + (16 if is_tcp else 6)
        (old_checksum,) = struct.unpack_from('!H', buf, checksum_off)
        if is_tcp or old_checksum != 0:
            new_checksum = self._checksum_adjust(old_checksum, old_bytes, new_ip_bytes + new_port_bytes)
            if not is_tcp and new_checksum == 0:
                new_checksum = 0xFFFF
            struct.pack_into('!H', buf, checksum_off, new_checksum)
        return bytes(buf)

    def process_inbound_packet(self, packet: bytes) -> Optional[bytes]:
        """Translate an inbound frame (internet -> VPN client).

        The destination (host_ip, host_port) is mapped back to the session's
        virtual endpoint.

        Returns:
            The rewritten frame, or None if the frame is not a translatable
            IPv4 TCP/UDP packet or has no matching session.
        """
        parsed = self._parse_l4_endpoints(packet)
        if parsed is None:
            return None
        ip_off, ip_hdr_len, protocol, l4_off, _src_ip, dst_ip, _src_port, dst_port = parsed
        translated_endpoint = self.translate_inbound(dst_ip, dst_port, protocol)
        if not translated_endpoint:
            return None
        virtual_ip, virtual_port = translated_endpoint
        return self._rewrite_endpoint(packet, ip_off, ip_hdr_len, protocol, l4_off,
                                      False, virtual_ip, virtual_port)

    def process_outbound_packet(self, packet: bytes) -> Optional[bytes]:
        """Translate an outbound frame (VPN client -> internet) with SNAT.

        The source endpoint is replaced by the session's (host_ip, host_port);
        a session is created on the first packet of a flow.

        Returns:
            The rewritten frame, or None if the frame is not a translatable
            IPv4 TCP/UDP packet or no host port is available.
        """
        parsed = self._parse_l4_endpoints(packet)
        if parsed is None:
            return None
        ip_off, ip_hdr_len, protocol, l4_off, src_ip, dst_ip, src_port, dst_port = parsed
        translated_endpoint = self.translate_outbound(src_ip, src_port, dst_ip, dst_port, protocol)
        if not translated_endpoint:
            return None
        host_ip, host_port = translated_endpoint
        return self._rewrite_endpoint(packet, ip_off, ip_hdr_len, protocol, l4_off,
                                      True, host_ip, host_port)
class NATRule:
    """One port-forwarding (DNAT) rule.

    Maps traffic arriving on ``external_port`` with IP protocol number
    ``protocol`` to ``internal_ip:internal_port`` and records how often the
    rule has matched.
    """

    def __init__(self, external_port: int, internal_ip: str, internal_port: int,
                 protocol: int, enabled: bool = True):
        # Forwarding definition
        self.external_port = external_port
        self.internal_ip = internal_ip
        self.internal_port = internal_port
        self.protocol = protocol
        self.enabled = enabled
        # Bookkeeping
        self.created_time = time.time()
        self.hit_count = 0
        self.last_hit = None  # time.time() of the latest hit; None until the first

    def matches(self, port: int, protocol: int) -> bool:
        """Whether the rule is enabled and targets this port/protocol pair."""
        return self.enabled and (self.external_port, self.protocol) == (port, protocol)

    def record_hit(self):
        """Stamp the time of this hit and bump the hit counter."""
        self.last_hit = time.time()
        self.hit_count += 1

    def to_dict(self) -> Dict:
        """Plain-dict snapshot of the rule's settings and counters."""
        keys = ('external_port', 'internal_ip', 'internal_port', 'protocol',
                'enabled', 'created_time', 'hit_count', 'last_hit')
        return {key: getattr(self, key) for key in keys}
class DNATEngine:
    """Thread-safe registry of port-forwarding (DNAT) rules keyed by rule id."""

    def __init__(self):
        self.rules: Dict[str, NATRule] = {}  # rule_id -> rule
        self.lock = threading.Lock()

    def add_rule(self, rule_id: str, external_port: int, internal_ip: str,
                 internal_port: int, protocol: int) -> bool:
        """Register a new enabled rule; return False if ``rule_id`` is already used."""
        with self.lock:
            if rule_id not in self.rules:
                self.rules[rule_id] = NATRule(external_port, internal_ip, internal_port, protocol)
                return True
            return False

    def remove_rule(self, rule_id: str) -> bool:
        """Delete a rule; return False if no rule has that id."""
        with self.lock:
            try:
                del self.rules[rule_id]
            except KeyError:
                return False
            return True

    def get_rule(self, rule_id: str) -> Optional[NATRule]:
        """Return the rule stored under ``rule_id``, or None."""
        with self.lock:
            return self.rules.get(rule_id)

    def get_matching_rule(self, port: int, protocol: int) -> Optional[NATRule]:
        """Return the first rule (in insertion order) matching port/protocol.

        The returned rule has its hit recorded; None is returned when no rule
        matches.
        """
        with self.lock:
            hit = next((candidate for candidate in self.rules.values()
                        if candidate.matches(port, protocol)), None)
            if hit is not None:
                hit.record_hit()
            return hit

    def get_all_rules(self) -> Dict[str, Dict]:
        """Return every rule as a plain dict, keyed by rule id."""
        with self.lock:
            snapshot = {}
            for rule_id, rule in self.rules.items():
                snapshot[rule_id] = rule.to_dict()
            return snapshot