"""
FlowSender - manages traffic generation with background threads per flow.
"""

import json
import logging
import shutil
import socket
import struct
import threading
import time
import urllib.request

from scapy.all import send, sendpfast, sr, conf

from engine.packet_builder import build_packet, stamp_payload, MAGIC, HEADER_LEN

log = logging.getLogger(__name__)

# Suppress Scapy verbosity globally
conf.verb = 0

# True when the `tcpreplay` binary is found on PATH.
HAS_TCPREPLAY = shutil.which('tcpreplay') is not None


class FlowSender:
    """Manages sending threads for multiple flows.

    Each flow is a config dict. Keys read directly here are ``rate_pps``,
    ``duration``, ``protocol``, ``responder_url`` and ``dst_ip``; the whole
    dict is also passed to ``build_packet``. ``start()`` spawns one daemon
    thread per flow. All shared state (``_flows``, ``_threads``,
    ``_stop_events``, ``_stats``) is guarded by ``_lock``.
    """

    # Maximum number of latency samples retained per flow (oldest dropped).
    _MAX_LATENCY_SAMPLES = 1000
    # Minimum seconds between responder rx-stat polls in the raw-socket loop.
    _RESPONDER_POLL_INTERVAL = 1.0
    # Largest value representable in the 4-byte seq field stamped after MAGIC.
    _SEQ_MASK = 0xFFFFFFFF

    def __init__(self):
        self._lock = threading.Lock()
        self._flows = {}        # flow_id -> flow_config dict
        self._threads = {}      # flow_id -> Thread
        self._stop_events = {}  # flow_id -> Event
        self._stats = {}        # flow_id -> {tx_packets, tx_bytes, ...}

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _new_stats() -> dict:
        """Return a fresh, zeroed per-flow stats dict."""
        return {
            'tx_packets': 0,
            'tx_bytes': 0,
            'rx_packets': 0,
            'rx_bytes': 0,
            'latency_samples': [],
        }

    @staticmethod
    def _copy_stats(stats: dict) -> dict:
        """Snapshot a stats dict, copying the mutable latency list as well.

        A plain ``dict(stats)`` would still share ``latency_samples`` with the
        sender thread, which keeps appending to it after the lock is released.
        """
        snapshot = dict(stats)
        if 'latency_samples' in snapshot:
            snapshot['latency_samples'] = list(snapshot['latency_samples'])
        return snapshot

    def _record_latency(self, stats: dict, sample_ms: float) -> None:
        """Append a latency sample, keeping only the newest samples.

        Caller must hold ``self._lock``.
        """
        samples = stats['latency_samples']
        samples.append(sample_ms)
        if len(samples) > self._MAX_LATENCY_SAMPLES:
            del samples[:-self._MAX_LATENCY_SAMPLES]

    def _mark_stopped(self, flow_id: str, stop_event: threading.Event) -> None:
        """Set the flow's state to 'stopped' unless a newer run owns it.

        ``stop()`` pops the flow's event, and a natural finish leaves this
        thread's event registered. If a *different* event is registered,
        ``start()`` has already launched a new thread for this flow. In that
        case this (older) thread must not overwrite the new run's 'running'
        state.
        """
        with self._lock:
            current = self._stop_events.get(flow_id)
            if flow_id in self._flows and (current is None or current is stop_event):
                self._flows[flow_id]['state'] = 'stopped'

    # ------------------------------------------------------------------
    # Flow CRUD
    # ------------------------------------------------------------------

    def add_flow(self, flow_id: str, flow_config: dict):
        """Register (or replace) a flow config and reset its stats."""
        with self._lock:
            self._flows[flow_id] = flow_config
            self._stats[flow_id] = self._new_stats()

    def get_flow(self, flow_id: str):
        """Return the stored config dict for ``flow_id``, or None."""
        with self._lock:
            return self._flows.get(flow_id)

    def get_all_flows(self):
        """Return a shallow copy of the flow_id -> config mapping."""
        with self._lock:
            return dict(self._flows)

    def update_flow(self, flow_id: str, updates: dict):
        """Merge ``updates`` into a flow's config.

        A running sender works on a copy taken at start, so changes apply
        from the next ``start()``. Returns False if the flow is unknown.
        """
        with self._lock:
            if flow_id not in self._flows:
                return False
            self._flows[flow_id].update(updates)
            return True

    def remove_flow(self, flow_id: str):
        """Stop the flow (if running) and forget its config and stats."""
        self.stop(flow_id)
        with self._lock:
            self._flows.pop(flow_id, None)
            self._stats.pop(flow_id, None)

    # ------------------------------------------------------------------
    # Start / Stop
    # ------------------------------------------------------------------

    def start(self, flow_id: str):
        """Start the sender thread for ``flow_id``; no-op if already running.

        Resets the flow's stats and sets its state to 'running'.

        Raises:
            KeyError: if ``flow_id`` has not been added.
        """
        with self._lock:
            if flow_id not in self._flows:
                raise KeyError(f'Flow {flow_id} not found')
            existing = self._threads.get(flow_id)
            if existing is not None and existing.is_alive():
                return  # already running
            self._flows[flow_id]['state'] = 'running'
            self._stats[flow_id] = self._new_stats()
            stop_event = threading.Event()
            self._stop_events[flow_id] = stop_event
            t = threading.Thread(
                target=self._send_loop,
                args=(flow_id, stop_event),
                daemon=True,
                name=f'sender-{flow_id[:8]}',
            )
            self._threads[flow_id] = t
            t.start()

    def stop(self, flow_id: str):
        """Signal the flow's sender thread to stop and wait up to 5 s for it."""
        with self._lock:
            ev = self._stop_events.pop(flow_id, None)
            if ev:
                ev.set()
            t = self._threads.pop(flow_id, None)
            if flow_id in self._flows:
                self._flows[flow_id]['state'] = 'stopped'
        # Join outside the lock: the sender thread takes the lock during its
        # cleanup, so joining while holding it would deadlock until timeout.
        if t and t.is_alive():
            t.join(timeout=5)

    def is_running(self, flow_id: str) -> bool:
        """Return True if the flow has a live sender thread."""
        with self._lock:
            t = self._threads.get(flow_id)
            return t is not None and t.is_alive()

    # ------------------------------------------------------------------
    # Stats
    # ------------------------------------------------------------------

    def get_stats(self, flow_id: str) -> dict:
        """Return a snapshot of one flow's stats ({} if the flow is unknown)."""
        with self._lock:
            return self._copy_stats(self._stats.get(flow_id, {}))

    def get_all_stats(self) -> dict:
        """Return snapshots of every flow's stats, keyed by flow_id."""
        with self._lock:
            return {fid: self._copy_stats(s) for fid, s in self._stats.items()}

    # ------------------------------------------------------------------
    # Internal send loop
    # ------------------------------------------------------------------

    @staticmethod
    def _raw_ip_template(pkt_template):
        """Build the reusable raw-socket template from a scapy packet.

        Returns ``(ip_template, magic_offset)``:
        - ``ip_template`` is a bytearray of the packet starting at the layer
          after the first one. This assumes ``build_packet`` returns an
          Ethernet-rooted packet — TODO confirm.
        - The IPv4 header checksum is zeroed, and for UDP (proto 17) the UDP
          checksum is zeroed too. Each packet's payload is restamped, so a
          precomputed checksum would be stale. UDP checksum 0 means "no
          checksum" in IPv4, and Linux fills the IP checksum for IP_HDRINCL
          sockets (see raw(7)).
        - ``magic_offset`` is the index of MAGIC in the template, or -1.

        NOTE(review): for TCP the TCP checksum is left as computed for the
        seq=0 template and becomes invalid after stamping.
        """
        ip_template = bytearray(
            bytes(pkt_template[pkt_template.firstlayer().payload.__class__]))
        magic_offset = ip_template.find(MAGIC)
        ihl = (ip_template[0] & 0x0F) * 4
        ip_template[10:12] = b'\x00\x00'  # IPv4 header checksum
        if ip_template[9] == 17:  # 17 = UDP
            udp_csum = ihl + 6
            ip_template[udp_csum:udp_csum + 2] = b'\x00\x00'
        return ip_template, magic_offset

    def _responder_baseline(self, flow_id: str, responder_url: str):
        """Return ``(baseline_rx_packets, baseline_rx_bytes)`` for delta reports.

        First fetches the responder's counters, then tries to reset it so the
        baseline becomes zero. If the reset fails, the fetched counters are
        returned, so later polls still report deltas rather than cumulative
        totals. If the fetch itself fails, returns ``(0, 0)``.
        """
        try:
            base = self._fetch_responder(responder_url)
            baseline = (base.get('rx_packets', 0), base.get('rx_bytes', 0))
        except Exception as e:
            log.debug('Responder baseline fetch failed for flow %s: %s',
                      flow_id[:8], e)
            return 0, 0
        try:
            self._reset_responder(responder_url)
        except Exception as e:
            log.debug('Responder reset failed for flow %s: %s', flow_id[:8], e)
            return baseline
        return 0, 0

    def _send_loop(self, flow_id: str, stop_event: threading.Event):
        """Thread body: send one flow's packets until stopped or timed out.

        ICMP flows without a responder are delegated to ``_send_loop_icmp``.
        Everything else goes out through a raw IPv4 socket (``IP_HDRINCL``)
        in rate-limited batches. When MAGIC is found in the packet, a 4-byte
        seq and an 8-byte wall-clock ``time_ns`` are stamped right after it.
        When ``responder_url`` is set, the responder's rx stats are polled
        about once per second and once more at exit.
        """
        with self._lock:
            flow_cfg = self._flows.get(flow_id)
            if flow_cfg is None:
                # The flow was removed before this thread got to run.
                log.debug('Flow %s: removed before send loop started', flow_id[:8])
                return
            flow = dict(flow_cfg)

        seq = 0
        raw_sock = None
        responder_url = None
        baseline_rx = baseline_bytes = 0
        try:
            rate_pps = flow.get('rate_pps', 1000)
            duration = flow.get('duration', 30)
            protocol = str(flow.get('protocol') or 'udp').lower()
            responder_url = flow.get('responder_url')
            use_icmp_sr = (protocol == 'icmp' and not responder_url)

            log.info('Flow %s: starting send loop at %d pps for %ds',
                     flow_id[:8], rate_pps, duration)

            if use_icmp_sr:
                seq = self._send_loop_icmp(flow_id, flow, stop_event, rate_pps, duration)
                return

            # Capture responder baseline so we report deltas, not cumulative totals
            if responder_url:
                baseline_rx, baseline_bytes = self._responder_baseline(flow_id, responder_url)

            # --- High-performance path: raw socket ---
            pkt_template = build_packet(flow, seq=0)
            # tx_bytes counts the full scapy frame length (including the
            # stripped first layer), matching the ICMP path.
            pkt_bytes_len = len(bytes(pkt_template))
            dst_ip = flow['dst_ip']
            ip_template, magic_offset = self._raw_ip_template(pkt_template)

            raw_sock = socket.socket(socket.AF_INET, socket.SOCK_RAW, socket.IPPROTO_RAW)
            raw_sock.setsockopt(socket.IPPROTO_IP, socket.IP_HDRINCL, 1)

            # Adaptive batching: send bursts then sleep to hit target rate.
            # int() keeps batch_size integral for float rates (e.g. from JSON);
            # range() rejects floats.
            batch_size = max(1, min(int(rate_pps) // 5, 500))
            interval = batch_size / rate_pps if rate_pps > 0 else 1.0

            # Monotonic clock for pacing and duration, so wall-clock jumps
            # (NTP, manual changes) can't stretch or truncate the run. The
            # timestamps stamped into the payload stay wall-clock time_ns().
            start_time = time.monotonic()
            last_responder_poll = float('-inf')  # poll on the first batch
            send_errors = 0

            while not stop_event.is_set():
                batch_start = time.monotonic()
                if duration and batch_start - start_time >= duration:
                    break

                sent_this_batch = 0
                for _ in range(batch_size):
                    pkt = bytearray(ip_template)
                    if magic_offset >= 0:
                        # seq is wrapped to fit its 4-byte field on long runs.
                        struct.pack_into('!IQ', pkt, magic_offset + 4,
                                         seq & self._SEQ_MASK, time.time_ns())
                    try:
                        raw_sock.sendto(pkt, (dst_ip, 0))
                    except OSError as e:
                        # Best effort: drop this packet but still consume its
                        # seq. Log only the first failure to avoid flooding.
                        send_errors += 1
                        if send_errors == 1:
                            log.warning('Flow %s: raw send failed '
                                        '(further errors suppressed): %s',
                                        flow_id[:8], e)
                    else:
                        sent_this_batch += 1
                    seq += 1

                with self._lock:
                    stats = self._stats.get(flow_id)
                    if stats is not None:
                        stats['tx_packets'] += sent_this_batch
                        stats['tx_bytes'] += pkt_bytes_len * sent_this_batch

                # Poll responder for rx stats periodically
                if (responder_url and time.monotonic() - last_responder_poll
                        >= self._RESPONDER_POLL_INTERVAL):
                    self._poll_responder(flow_id, responder_url, baseline_rx, baseline_bytes)
                    last_responder_poll = time.monotonic()

                # Sleep whatever is left of this batch's time slot; wait() on
                # the event so stop() interrupts the sleep immediately.
                sleep_time = interval - (time.monotonic() - batch_start)
                if sleep_time > 0:
                    stop_event.wait(sleep_time)
        except Exception as e:
            log.exception('Flow %s: send loop error: %s', flow_id[:8], e)
        finally:
            if raw_sock:
                raw_sock.close()
            self._mark_stopped(flow_id, stop_event)
            # Final responder poll
            if responder_url:
                self._poll_responder(flow_id, responder_url, baseline_rx, baseline_bytes)
            log.info('Flow %s: send loop finished. seq=%d', flow_id[:8], seq)

    def _send_loop_icmp(self, flow_id, flow, stop_event, rate_pps, duration):
        """ICMP mode: use sr() to measure latency from router responses.

        Sends one packet per iteration via scapy ``sr()`` with a 1 s reply
        timeout, so the achieved rate is also bounded by reply latency. A
        non-positive ``rate_pps`` falls back to one packet per second.
        Returns the number of packets sent.
        """
        pkt_template = build_packet(flow, seq=0)
        pkt_bytes_len = len(bytes(pkt_template))
        period = 1.0 / rate_pps if rate_pps > 0 else 1.0
        seq = 0
        start_time = time.monotonic()
        try:
            while not stop_event.is_set():
                iter_start = time.monotonic()
                if duration and iter_start - start_time >= duration:
                    break
                pkt = build_packet(flow, seq=seq)
                answered, _ = sr(pkt[pkt.firstlayer().payload.__class__], timeout=1, verbose=0)
                # Serialise replies outside the lock; only counter updates
                # need to hold it.
                replies = [
                    (len(bytes(recv_pkt)), (recv_pkt.time - sent_pkt.sent_time) * 1000)
                    for sent_pkt, recv_pkt in answered
                ]
                with self._lock:
                    stats = self._stats.get(flow_id)
                    if stats is not None:
                        stats['tx_packets'] += 1
                        stats['tx_bytes'] += pkt_bytes_len
                        for rx_len, rtt_ms in replies:
                            stats['rx_packets'] += 1
                            stats['rx_bytes'] += rx_len
                            self._record_latency(stats, rtt_ms)
                seq += 1
                sleep_time = period - (time.monotonic() - iter_start)
                if sleep_time > 0:
                    stop_event.wait(sleep_time)
        except Exception as e:
            log.exception('Flow %s: ICMP send error: %s', flow_id[:8], e)
        finally:
            self._mark_stopped(flow_id, stop_event)
        return seq

    # ------------------------------------------------------------------
    # Responder HTTP API
    # ------------------------------------------------------------------

    def _fetch_responder(self, responder_url: str) -> dict:
        """Fetch raw stats from the responder (GET <url>/responder/stats, 2 s timeout)."""
        url = responder_url.rstrip('/') + '/responder/stats'
        req = urllib.request.Request(url, method='GET')
        req.add_header('Accept', 'application/json')
        with urllib.request.urlopen(req, timeout=2) as resp:
            return json.loads(resp.read().decode())

    def _reset_responder(self, responder_url: str):
        """Reset responder counters (POST <url>/responder/reset, 2 s timeout)."""
        url = responder_url.rstrip('/') + '/responder/reset'
        req = urllib.request.Request(url, method='POST')
        req.add_header('Content-Type', 'application/json')
        with urllib.request.urlopen(req, timeout=2) as resp:
            resp.read()

    def _poll_responder(self, flow_id: str, responder_url: str,
                        baseline_rx: int = 0, baseline_bytes: int = 0):
        """Poll a responder's /responder/stats endpoint for rx metrics.

        Overwrites the flow's rx counters with (responder value - baseline),
        clamped at 0. If the responder reports ``latency.avg_ms``, it is
        appended as a latency sample. Failures are logged at debug level and
        otherwise ignored (best effort).
        """
        try:
            data = self._fetch_responder(responder_url)
            rx_pkts = data.get('rx_packets', 0) - baseline_rx
            rx_bytes = data.get('rx_bytes', 0) - baseline_bytes
            lat = data.get('latency') or {}
            with self._lock:
                stats = self._stats.get(flow_id)
                if stats is not None:
                    stats['rx_packets'] = max(0, rx_pkts)
                    stats['rx_bytes'] = max(0, rx_bytes)
                    if lat.get('avg_ms') is not None:
                        self._record_latency(stats, lat['avg_ms'])
        except Exception as e:
            log.debug('Responder poll error for flow %s: %s', flow_id[:8], e)