"""
IP Parser/Assembler Module

Handles IPv4 packet parsing and construction:
- Parse IPv4, UDP, and TCP headers
- Calculate and verify checksums
- Handle packet fragmentation and reassembly
- Support various IP options
"""
# Standard-library dependencies of the parser: binary packing, address
# conversion, and the typing/dataclass/enum helpers used by the header classes.
import socket
import struct
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Optional, Tuple
class IPProtocol(Enum):
    """Values of the IPv4 header ``protocol`` field that this module recognises."""

    ICMP = 1
    TCP = 6
    UDP = 17
@dataclass
class IPv4Header:
    """IPv4 header structure.

    Plain data holder for the header fields; all fields have defaults so an
    empty header can be created and filled in. The derived values below are
    read-only properties (``header.header_length``, not a method call).
    """

    version: int = 4
    ihl: int = 5  # Internet Header Length (in 32-bit words)
    tos: int = 0  # Type of Service
    total_length: int = 0
    identification: int = 0
    flags: int = 0  # 3 bits: Reserved, Don't Fragment, More Fragments
    fragment_offset: int = 0  # 13 bits
    ttl: int = 64  # Time to Live
    protocol: int = 0
    header_checksum: int = 0
    source_ip: str = '0.0.0.0'
    dest_ip: str = '0.0.0.0'
    options: bytes = b''

    @property
    def header_length(self) -> int:
        """Header length in bytes (``ihl`` 32-bit words)."""
        return self.ihl * 4

    @property
    def dont_fragment(self) -> bool:
        """True when the Don't Fragment bit (0x2 of ``flags``) is set."""
        return bool(self.flags & 0x2)

    @property
    def more_fragments(self) -> bool:
        """True when the More Fragments bit (0x1 of ``flags``) is set."""
        return bool(self.flags & 0x1)

    @property
    def is_fragment(self) -> bool:
        """True when More Fragments is set or the fragment offset is non-zero."""
        return self.more_fragments or self.fragment_offset > 0
@dataclass
class TCPHeader:
    """TCP header structure.

    ``flags`` holds the nine control bits in one integer
    (NS=0x100, CWR=0x80, ECE=0x40, URG=0x20, ACK=0x10, PSH=0x08, RST=0x04,
    SYN=0x02, FIN=0x01). Each bit is exposed as a read-only boolean property
    and can be changed with :meth:`set_flag`.
    """

    source_port: int = 0
    dest_port: int = 0
    seq_num: int = 0
    ack_num: int = 0
    data_offset: int = 5  # Header length in 32-bit words
    reserved: int = 0
    flags: int = 0  # 9 bits: NS, CWR, ECE, URG, ACK, PSH, RST, SYN, FIN
    window_size: int = 65535
    checksum: int = 0
    urgent_pointer: int = 0
    options: bytes = b''

    @property
    def header_length(self) -> int:
        """Header length in bytes (``data_offset`` 32-bit words)."""
        return self.data_offset * 4

    # TCP flag properties
    @property
    def fin(self) -> bool:
        return bool(self.flags & 0x01)

    @property
    def syn(self) -> bool:
        return bool(self.flags & 0x02)

    @property
    def rst(self) -> bool:
        return bool(self.flags & 0x04)

    @property
    def psh(self) -> bool:
        return bool(self.flags & 0x08)

    @property
    def ack(self) -> bool:
        return bool(self.flags & 0x10)

    @property
    def urg(self) -> bool:
        return bool(self.flags & 0x20)

    # The three flags below are settable through set_flag(); expose matching
    # read accessors so every supported flag can be both written and read.
    @property
    def ece(self) -> bool:
        return bool(self.flags & 0x40)

    @property
    def cwr(self) -> bool:
        return bool(self.flags & 0x80)

    @property
    def ns(self) -> bool:
        return bool(self.flags & 0x100)

    def set_flag(self, flag_name: str, value: bool = True):
        """Set or clear one TCP flag by name.

        Args:
            flag_name: Case-insensitive flag name: fin, syn, rst, psh, ack,
                urg, ece, cwr or ns. Unknown names are ignored.
            value: True to set the bit, False to clear it.
        """
        flag_bits = {
            'fin': 0x01, 'syn': 0x02, 'rst': 0x04, 'psh': 0x08,
            'ack': 0x10, 'urg': 0x20, 'ece': 0x40, 'cwr': 0x80, 'ns': 0x100,
        }
        bit = flag_bits.get(flag_name.lower())
        if bit is None:
            return
        if value:
            self.flags |= bit
        else:
            self.flags &= ~bit
@dataclass
class UDPHeader:
    """UDP header structure (four 16-bit fields)."""

    source_port: int = 0
    dest_port: int = 0
    length: int = 8  # Header + data length
    checksum: int = 0

    @property
    def header_length(self) -> int:
        """Header length in bytes (always 8 for UDP)."""
        return 8
@dataclass
class ParsedPacket:
    """Result of parsing one packet.

    ``ip_header`` has no default and must be supplied when constructing an
    instance; the remaining fields default to "nothing parsed yet".
    """

    # Quoted so the annotation is not evaluated when the class is defined.
    ip_header: "IPv4Header"
    transport_header: Optional[object] = None  # TCPHeader or UDPHeader
    payload: bytes = b''
    raw_packet: bytes = b''
class IPParser:
    """IPv4 packet parser and assembler.

    The class holds no state; every method is a static or class method, so it
    is used as ``IPParser.parse_packet(data)`` without instantiation.

    The ``build_*`` methods update length fields on the header object they
    are given (``ihl``, ``data_offset``, ``length``, ``total_length``) so the
    object matches the bytes that were produced.
    """

    @staticmethod
    def calculate_checksum(data: bytes) -> int:
        """Calculate the 16-bit Internet checksum of *data*.

        The data is summed as big-endian 16-bit words (an odd trailing byte is
        padded with a zero byte), carries are folded back in, and the one's
        complement of the result is returned.
        """
        if len(data) % 2:
            # Build a new object: ``data += ...`` would mutate a caller-owned
            # bytearray in place.
            data = bytes(data) + b'\x00'
        total = sum(struct.unpack(f'!{len(data) // 2}H', data))
        # Fold carry bits back into the low 16 bits
        while total >> 16:
            total = (total & 0xFFFF) + (total >> 16)
        return ~total & 0xFFFF

    @staticmethod
    def verify_checksum(data: bytes, checksum: int) -> bool:
        """Check *checksum* against the checksum computed over *data*.

        *data* is expected to have its checksum field zeroed, so that the
        computed value can be compared with *checksum* directly. 0x0000 and
        0xFFFF are both representations of zero in one's-complement
        arithmetic and are treated as equal to each other.
        """
        calculated = IPParser.calculate_checksum(data)
        if calculated == checksum:
            return True
        # Only the +0/-0 pair is equivalent; any other pair summing to 0xFFFF
        # is a complemented (i.e. wrong) checksum and must be rejected.
        return {calculated, checksum} == {0x0000, 0xFFFF}

    @staticmethod
    def _header_words(options: bytes, what: str) -> int:
        """Return the length in 32-bit words of a 20-byte header plus *options*.

        Raises:
            ValueError: If the result does not fit the 4-bit length field
                (more than 40 bytes of options).
        """
        words = (20 + len(options) + 3) // 4  # Round up to 32-bit boundary
        if words > 15:
            raise ValueError(
                f"{what} options too long: {len(options)} bytes (maximum 40)"
            )
        return words

    @classmethod
    def parse_ipv4_header(cls, data: bytes) -> "Tuple[IPv4Header, int]":
        """Parse an IPv4 header from the start of *data*.

        Returns:
            ``(header, header_length_in_bytes)``.

        Raises:
            ValueError: If fewer than 20 bytes are given, the version is not
                4, the IHL is below 5, or the options are truncated.
        """
        if len(data) < 20:
            raise ValueError("IPv4 header too short")

        # Fixed 20-byte part of the header
        (version_ihl, tos, total_length, identification, flags_fragment,
         ttl, protocol, header_checksum, src, dst) = struct.unpack(
            '!BBHHHBBH4s4s', data[:20]
        )

        header = IPv4Header(
            version=(version_ihl >> 4) & 0xF,
            ihl=version_ihl & 0xF,
            tos=tos,
            total_length=total_length,
            identification=identification,
            flags=(flags_fragment >> 13) & 0x7,
            fragment_offset=flags_fragment & 0x1FFF,
            ttl=ttl,
            protocol=protocol,
            header_checksum=header_checksum,
            source_ip=socket.inet_ntoa(src),
            dest_ip=socket.inet_ntoa(dst),
        )

        if header.version != 4:
            raise ValueError(f"Unsupported IP version: {header.version}")
        # An IHL below 5 would describe a header shorter than its fixed part
        # and make the caller start the payload inside the header.
        if header.ihl < 5:
            raise ValueError(f"Invalid IPv4 header length: IHL={header.ihl}")

        header_len = header.ihl * 4
        if header_len > 20:
            if len(data) < header_len:
                raise ValueError("IPv4 options truncated")
            header.options = data[20:header_len]

        return header, header_len

    @classmethod
    def parse_tcp_header(cls, data: bytes) -> "Tuple[TCPHeader, int]":
        """Parse a TCP header from the start of *data*.

        Returns:
            ``(header, header_length_in_bytes)``.

        Raises:
            ValueError: If fewer than 20 bytes are given, the data offset is
                below 5, or the options are truncated.
        """
        if len(data) < 20:
            raise ValueError("TCP header too short")

        # Fixed 20-byte part of the header
        (source_port, dest_port, seq_num, ack_num, offset_reserved,
         low_flags, window_size, checksum, urgent_pointer) = struct.unpack(
            '!HHIIBBHHH', data[:20]
        )

        header = TCPHeader(
            source_port=source_port,
            dest_port=dest_port,
            seq_num=seq_num,
            ack_num=ack_num,
            data_offset=(offset_reserved >> 4) & 0xF,
            reserved=(offset_reserved >> 1) & 0x7,
            # NS is the low bit of the offset byte; the other eight flags
            # are the following byte.
            flags=((offset_reserved & 0x1) << 8) | low_flags,
            window_size=window_size,
            checksum=checksum,
            urgent_pointer=urgent_pointer,
        )

        if header.data_offset < 5:
            raise ValueError(f"Invalid TCP data offset: {header.data_offset}")

        header_len = header.data_offset * 4
        if header_len > 20:
            if len(data) < header_len:
                raise ValueError("TCP options truncated")
            header.options = data[20:header_len]

        return header, header_len

    @classmethod
    def parse_udp_header(cls, data: bytes) -> "Tuple[UDPHeader, int]":
        """Parse a UDP header from the start of *data*.

        Returns:
            ``(header, 8)``.

        Raises:
            ValueError: If fewer than 8 bytes are given.
        """
        if len(data) < 8:
            raise ValueError("UDP header too short")

        source_port, dest_port, length, checksum = struct.unpack('!HHHH', data[:8])
        header = UDPHeader(
            source_port=source_port,
            dest_port=dest_port,
            length=length,
            checksum=checksum,
        )
        return header, 8

    @classmethod
    def parse_packet(cls, data: bytes) -> "ParsedPacket":
        """Parse a complete packet: IP header, transport header and payload.

        The transport header is parsed for TCP and UDP only. For other
        protocols, and for fragments with a non-zero fragment offset (which
        do not start with a transport header), everything after the IP
        header is returned as ``payload`` and ``transport_header`` is None.

        Raises:
            ValueError: Propagated from the header parsers.
        """
        ip_header, ip_header_len = cls.parse_ipv4_header(data)
        # ParsedPacket requires ip_header, so build it after the IP parse.
        packet = ParsedPacket(ip_header=ip_header, raw_packet=data)

        # Payload after the IP header, bounded by the header's total length
        ip_payload = data[ip_header_len:ip_header.total_length]

        if ip_header.fragment_offset > 0:
            # Non-first fragment: the bytes are mid-datagram data
            packet.payload = ip_payload
        elif ip_header.protocol == IPProtocol.TCP.value:
            packet.transport_header, transport_len = cls.parse_tcp_header(ip_payload)
            packet.payload = ip_payload[transport_len:]
        elif ip_header.protocol == IPProtocol.UDP.value:
            packet.transport_header, transport_len = cls.parse_udp_header(ip_payload)
            packet.payload = ip_payload[transport_len:]
        else:
            # Unsupported protocol, treat as raw payload
            packet.payload = ip_payload

        return packet

    @classmethod
    def build_ipv4_header(cls, header: "IPv4Header") -> bytes:
        """Serialise *header* to bytes, including options, padding and checksum.

        ``header.ihl`` is recomputed from the options length and written back
        to *header*.

        Raises:
            ValueError: If the options exceed 40 bytes.
        """
        header.ihl = cls._header_words(header.options, "IPv4")

        version_ihl = (header.version << 4) | header.ihl
        flags_fragment = (header.flags << 13) | header.fragment_offset

        header_data = struct.pack(
            '!BBHHHBBH4s4s',
            version_ihl, header.tos, header.total_length,
            header.identification, flags_fragment,
            header.ttl, header.protocol, 0,  # Checksum = 0 for calculation
            socket.inet_aton(header.source_ip),
            socket.inet_aton(header.dest_ip),
        )

        # Append options, zero-padded to the 32-bit boundary given by ihl
        header_data = (header_data + header.options).ljust(header.ihl * 4, b'\x00')

        # Checksum lives in bytes 10-11 of the header
        checksum = cls.calculate_checksum(header_data)
        return header_data[:10] + struct.pack('!H', checksum) + header_data[12:]

    @classmethod
    def build_tcp_header(cls, header: "TCPHeader", source_ip: str, dest_ip: str, payload: bytes) -> bytes:
        """Serialise *header* to bytes with the TCP checksum filled in.

        The checksum covers the pseudo-header built from *source_ip* and
        *dest_ip*, the TCP header and *payload*. ``header.data_offset`` is
        recomputed from the options length and written back to *header*.

        Raises:
            ValueError: If the options exceed 40 bytes.
        """
        header.data_offset = cls._header_words(header.options, "TCP")

        offset_reserved_flags = (
            (header.data_offset << 12) | (header.reserved << 9) | header.flags
        )

        header_data = struct.pack(
            '!HHIIHHH',
            header.source_port, header.dest_port,
            header.seq_num, header.ack_num,
            offset_reserved_flags, header.window_size,
            0,  # Checksum = 0 for calculation
        ) + struct.pack('!H', header.urgent_pointer)

        # Append options, zero-padded to the 32-bit boundary
        header_data = (header_data + header.options).ljust(
            header.data_offset * 4, b'\x00'
        )

        pseudo_header = struct.pack(
            '!4s4sBBH',
            socket.inet_aton(source_ip),
            socket.inet_aton(dest_ip),
            0, IPProtocol.TCP.value,
            len(header_data) + len(payload),
        )
        checksum = cls.calculate_checksum(pseudo_header + header_data + payload)

        # Checksum lives in bytes 16-17 of the TCP header
        return header_data[:16] + struct.pack('!H', checksum) + header_data[18:]

    @classmethod
    def build_udp_header(cls, header: "UDPHeader", source_ip: str, dest_ip: str, payload: bytes) -> bytes:
        """Serialise *header* to bytes, optionally with the UDP checksum.

        ``header.length`` is set to ``8 + len(payload)``. A checksum is
        computed only when ``header.checksum`` is non-zero on entry; when it
        is zero the checksum field is left as zero.
        """
        header.length = 8 + len(payload)

        header_data = struct.pack(
            '!HHHH',
            header.source_port, header.dest_port,
            header.length, 0,  # Checksum = 0 for calculation
        )

        # UDP checksum with pseudo-header (optional for IPv4)
        if header.checksum != 0:
            pseudo_header = struct.pack(
                '!4s4sBBH',
                socket.inet_aton(source_ip),
                socket.inet_aton(dest_ip),
                0, IPProtocol.UDP.value,
                header.length,
            )
            checksum = cls.calculate_checksum(pseudo_header + header_data + payload)
            # A zero field means "no checksum", so a computed value of zero
            # is transmitted as its one's-complement equivalent 0xFFFF.
            if checksum == 0:
                checksum = 0xFFFF
            header_data = header_data[:6] + struct.pack('!H', checksum)

        return header_data

    @classmethod
    def build_packet(cls, ip_header: "IPv4Header", transport_header: Optional[object] = None, payload: bytes = b'') -> bytes:
        """Build a complete packet: IP header + transport header + payload.

        *transport_header* may be a TCPHeader, a UDPHeader or None; any other
        object is ignored and no transport header is emitted.
        ``ip_header.total_length`` (and ``ihl``) are updated on *ip_header*.

        Raises:
            ValueError: If IP or TCP options exceed 40 bytes.
        """
        transport_data = b''
        if isinstance(transport_header, TCPHeader):
            transport_data = cls.build_tcp_header(
                transport_header, ip_header.source_ip, ip_header.dest_ip, payload
            )
        elif isinstance(transport_header, UDPHeader):
            transport_data = cls.build_udp_header(
                transport_header, ip_header.source_ip, ip_header.dest_ip, payload
            )

        # Derive the header length from the options before computing the
        # total length; the ihl stored on the object may be stale.
        ip_header.ihl = cls._header_words(ip_header.options, "IPv4")
        ip_header.total_length = ip_header.ihl * 4 + len(transport_data) + len(payload)

        ip_data = cls.build_ipv4_header(ip_header)
        return ip_data + transport_data + payload
class PacketFragmenter:
    """Handle IPv4 packet fragmentation and reassembly.

    Reassembly state is kept per ``(source_ip, dest_ip, identification)`` in
    ``self.fragments`` until the datagram is complete. Incomplete datagrams
    are never expired by this class.
    """

    def __init__(self, mtu: int = 1500):
        """Create a fragmenter with *mtu* as the default maximum packet size."""
        self.mtu = mtu
        # (src, dst, id) -> [(byte offset, fragment payload)]
        self.fragments: Dict[Tuple[str, str, int], List[Tuple[int, bytes]]] = {}
        # (src, dst, id) -> total payload length; known once the final
        # fragment (More Fragments clear) has been received
        self._total_lengths: Dict[Tuple[str, str, int], int] = {}
        # (src, dst, id) -> IP header of the offset-0 fragment
        self._first_headers: Dict[Tuple[str, str, int], IPv4Header] = {}

    def fragment_packet(self, packet: bytes, mtu: Optional[int] = None) -> List[bytes]:
        """Split *packet* into fragments no larger than *mtu* bytes.

        Args:
            packet: A complete IPv4 packet (or an existing fragment).
            mtu: Maximum size of each emitted packet; defaults to ``self.mtu``.

        Returns:
            ``[packet]`` unchanged when it already fits, otherwise the list
            of fragments in offset order. The IP options of the original are
            copied to every fragment.

        Raises:
            ValueError: If the packet is too large and has Don't Fragment
                set, if the MTU leaves no room for at least 8 payload bytes,
                or if the IP header cannot be parsed.
        """
        if mtu is None:
            mtu = self.mtu

        if len(packet) <= mtu:
            return [packet]

        # Only the IP header is needed; everything after it is opaque data
        ip_header, header_len = IPParser.parse_ipv4_header(packet)

        if ip_header.flags & 0x2:  # Don't Fragment
            raise ValueError("Packet too large and Don't Fragment flag is set")

        # Fragment payloads must be a multiple of 8 bytes
        payload_mtu = ((mtu - header_len) // 8) * 8
        if payload_mtu <= 0:
            # A zero or negative chunk size would never advance the offset
            raise ValueError(
                f"MTU {mtu} is too small to fragment a packet with a "
                f"{header_len}-byte IP header"
            )

        payload = packet[header_len:]
        fragments = []
        for offset in range(0, len(payload), payload_mtu):
            chunk = payload[offset:offset + payload_mtu]
            is_last = offset + len(chunk) >= len(payload)

            # Every piece but the last gets More Fragments. The last piece
            # keeps the original flag, so re-fragmenting a non-final fragment
            # does not falsely mark the end of the datagram.
            flags = ip_header.flags if is_last else ip_header.flags | 0x1

            frag_header = IPv4Header(
                version=ip_header.version,
                ihl=ip_header.ihl,
                tos=ip_header.tos,
                identification=ip_header.identification,
                flags=flags,
                # Offsets are in 8-byte units, relative to the original datagram
                fragment_offset=ip_header.fragment_offset + offset // 8,
                ttl=ip_header.ttl,
                protocol=ip_header.protocol,
                source_ip=ip_header.source_ip,
                dest_ip=ip_header.dest_ip,
                options=ip_header.options,
            )
            fragments.append(IPParser.build_packet(frag_header, payload=chunk))

        return fragments

    def reassemble_packet(self, packet: bytes) -> Optional[bytes]:
        """Feed one received packet into reassembly.

        Fragments may arrive in any order and may be duplicated.

        Returns:
            *packet* itself when it is not a fragment; the rebuilt packet
            when this fragment completes its datagram; otherwise None.

        Raises:
            ValueError: If the IP header cannot be parsed.
        """
        ip_header, header_len = IPParser.parse_ipv4_header(packet)
        more_fragments = bool(ip_header.flags & 0x1)

        # If not a fragment, return as-is
        if not more_fragments and ip_header.fragment_offset == 0:
            return packet

        key = (ip_header.source_ip, ip_header.dest_ip, ip_header.identification)
        offset = ip_header.fragment_offset * 8
        data = packet[header_len:]

        self.fragments.setdefault(key, []).append((offset, data))
        if offset == 0:
            self._first_headers[key] = ip_header
        if not more_fragments:
            # The final fragment fixes the total payload length
            self._total_lengths[key] = offset + len(data)

        total_length = self._total_lengths.get(key)
        if total_length is None:
            return None  # Don't have last fragment yet

        # Walk the fragments in offset order, skipping bytes already covered
        assembled = bytearray()
        for frag_offset, frag_data in sorted(self.fragments[key]):
            if frag_offset > len(assembled):
                return None  # Missing fragment
            assembled += frag_data[len(assembled) - frag_offset:]
        if len(assembled) < total_length:
            return None  # Missing fragment before the final one

        # Rebuild from the first fragment's header
        first = self._first_headers.get(key, ip_header)
        complete_header = IPv4Header(
            version=first.version,
            ihl=first.ihl,
            tos=first.tos,
            identification=first.identification,
            flags=first.flags & ~0x1,  # Clear More Fragments
            fragment_offset=0,
            ttl=first.ttl,
            protocol=first.protocol,
            source_ip=first.source_ip,
            dest_ip=first.dest_ip,
            options=first.options,
        )
        complete_packet = IPParser.build_packet(
            complete_header, payload=bytes(assembled[:total_length])
        )

        # Clean up all state for this datagram
        del self.fragments[key]
        self._total_lengths.pop(key, None)
        self._first_headers.pop(key, None)

        return complete_packet