"""
FlowSender - manages traffic generation with background threads per flow.
"""
import json
import logging
import shutil
import threading
import time
import urllib.request

from scapy.all import send, sendpfast, sr, conf

from engine.packet_builder import build_packet, stamp_payload

log = logging.getLogger(__name__)

# Suppress Scapy verbosity globally
conf.verb = 0

# sendpfast() shells out to tcpreplay, so it is only usable when installed.
HAS_TCPREPLAY = shutil.which('tcpreplay') is not None


class FlowSender:
    """Manages sending threads for multiple flows.

    Each started flow runs ``_send_loop`` in its own daemon thread. Flow
    configs, threads, stop events and stats are all guarded by one lock;
    sender threads take it only briefly to read their config snapshot and
    to update stats.
    """

    # Maximum number of latency samples retained per flow.
    _MAX_LATENCY_SAMPLES = 1000

    def __init__(self):
        self._lock = threading.Lock()
        self._flows = {}        # flow_id -> flow_config dict
        self._threads = {}      # flow_id -> Thread
        self._stop_events = {}  # flow_id -> Event
        self._stats = {}        # flow_id -> {tx_packets, tx_bytes, ...}

    # ------------------------------------------------------------------
    # Stats helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _new_stats() -> dict:
        """Return a zeroed stats record for one flow."""
        return {
            'tx_packets': 0,
            'tx_bytes': 0,
            'rx_packets': 0,
            'rx_bytes': 0,
            'latency_samples': [],
        }

    @classmethod
    def _append_latency(cls, stats: dict, sample_ms):
        """Append a latency sample in place, keeping only the newest ones."""
        samples = stats['latency_samples']
        samples.append(sample_ms)
        overflow = len(samples) - cls._MAX_LATENCY_SAMPLES
        if overflow > 0:
            del samples[:overflow]

    @staticmethod
    def _copy_stats(stats: dict) -> dict:
        """Snapshot a stats record, including its latency list.

        The list is copied too, so the sender thread's later appends/trims
        do not show up in (or race with) the caller's copy.
        """
        snapshot = dict(stats)
        if 'latency_samples' in snapshot:
            snapshot['latency_samples'] = list(snapshot['latency_samples'])
        return snapshot

    # ------------------------------------------------------------------
    # Flow CRUD
    # ------------------------------------------------------------------

    def add_flow(self, flow_id: str, flow_config: dict):
        """Register (or replace) a flow config and reset its stats.

        The config dict is stored by reference, not copied.
        """
        with self._lock:
            self._flows[flow_id] = flow_config
            self._stats[flow_id] = self._new_stats()

    def get_flow(self, flow_id: str):
        """Return the stored config dict for *flow_id*, or None if unknown."""
        with self._lock:
            return self._flows.get(flow_id)

    def get_all_flows(self):
        """Return a shallow copy of the flow_id -> config mapping."""
        with self._lock:
            return dict(self._flows)

    def update_flow(self, flow_id: str, updates: dict):
        """Merge *updates* into a flow's config.

        Returns False if the flow is unknown, True otherwise. A running send
        thread works from a copy taken when it started, so changes apply on
        the next start().
        """
        with self._lock:
            if flow_id not in self._flows:
                return False
            self._flows[flow_id].update(updates)
            return True

    def remove_flow(self, flow_id: str):
        """Stop the flow if it is running, then drop its config and stats."""
        self.stop(flow_id)
        with self._lock:
            self._flows.pop(flow_id, None)
            self._stats.pop(flow_id, None)

    # ------------------------------------------------------------------
    # Start / Stop
    # ------------------------------------------------------------------

    def start(self, flow_id: str):
        """Start the send thread for *flow_id*; stats are reset on each start.

        Does nothing if the flow's thread is already alive.

        Raises:
            KeyError: if *flow_id* has not been added.
        """
        with self._lock:
            if flow_id not in self._flows:
                raise KeyError(f'Flow {flow_id} not found')
            existing = self._threads.get(flow_id)
            if existing is not None and existing.is_alive():
                return  # already running
            self._flows[flow_id]['state'] = 'running'
            self._stats[flow_id] = self._new_stats()
            stop_event = threading.Event()
            self._stop_events[flow_id] = stop_event
            t = threading.Thread(
                target=self._send_loop,
                args=(flow_id, stop_event),
                daemon=True,
                name=f'sender-{flow_id[:8]}',
            )
            self._threads[flow_id] = t
            t.start()

    def stop(self, flow_id: str):
        """Signal the flow's thread to stop and wait up to 5 s for it to exit.

        The join happens outside the lock because the sender thread needs
        the lock to finish its last stats update.
        """
        with self._lock:
            ev = self._stop_events.pop(flow_id, None)
            if ev:
                ev.set()
            t = self._threads.pop(flow_id, None)
            if flow_id in self._flows:
                self._flows[flow_id]['state'] = 'stopped'
        if t and t.is_alive():
            t.join(timeout=5)

    def is_running(self, flow_id: str) -> bool:
        """Return True if the flow has a live send thread."""
        with self._lock:
            t = self._threads.get(flow_id)
            return t is not None and t.is_alive()

    # ------------------------------------------------------------------
    # Stats
    # ------------------------------------------------------------------

    def get_stats(self, flow_id: str) -> dict:
        """Return a snapshot of one flow's stats ({} if unknown)."""
        with self._lock:
            return self._copy_stats(self._stats.get(flow_id, {}))

    def get_all_stats(self) -> dict:
        """Return a snapshot of every flow's stats, keyed by flow_id."""
        with self._lock:
            return {fid: self._copy_stats(s) for fid, s in self._stats.items()}

    # ------------------------------------------------------------------
    # Internal send loop
    # ------------------------------------------------------------------

    def _send_loop(self, flow_id: str, stop_event: threading.Event):
        """Thread body: send packets for one flow until stopped or timed out.

        Reads from a copy of the flow config: ``rate_pps`` (default 1000),
        ``duration`` in seconds (default 30; a falsy value means run until
        stopped), ``protocol`` (default 'udp') and optional
        ``responder_url``. ICMP flows without a responder use sr() and
        record RTTs directly. All other flows send batches and, when a
        responder URL is set, poll it for rx stats at most every 2 s.
        """
        seq = 0
        responder_url = None
        try:
            # Everything that can fail on a bad config lives inside the try,
            # so failures are logged and the flow is still marked stopped.
            with self._lock:
                flow = dict(self._flows[flow_id])
            rate_pps = flow.get('rate_pps', 1000)
            duration = flow.get('duration', 30)
            protocol = (flow.get('protocol') or 'udp').lower()
            responder_url = flow.get('responder_url')
            use_icmp_sr = (protocol == 'icmp' and not responder_url)

            # Byte size of one packet, used for tx_bytes accounting.
            pkt_bytes_len = len(bytes(build_packet(flow, seq=0)))

            if use_icmp_sr:
                batch_size = 1
            else:
                # Batch for efficiency: about 10 batches/s, at most 100 pkts.
                # int() keeps range() working when rate_pps is a float.
                batch_size = max(1, min(int(rate_pps) // 10, 100))
            interval = batch_size / rate_pps if rate_pps > 0 else 1.0

            log.info('Flow %s: starting send loop at %d pps for %ds',
                     flow_id[:8], rate_pps, duration)

            start_time = time.monotonic()
            next_deadline = start_time
            last_responder_poll = float('-inf')  # poll on the first pass

            while not stop_event.is_set():
                if duration and time.monotonic() - start_time >= duration:
                    break

                if use_icmp_sr:
                    self._send_icmp_probe(flow_id, flow, seq, pkt_bytes_len)
                    seq += 1
                else:
                    packets = [build_packet(flow, seq=seq + i)
                               for i in range(batch_size)]
                    seq += batch_size
                    sent = self._send_batch(packets, rate_pps)
                    with self._lock:
                        stats = self._stats.get(flow_id)
                        if stats:
                            stats['tx_packets'] += sent
                            stats['tx_bytes'] += pkt_bytes_len * sent

                    # Poll responder for rx stats periodically
                    if (responder_url and
                            time.monotonic() - last_responder_poll >= 2.0):
                        self._poll_responder(flow_id, responder_url)
                        last_responder_poll = time.monotonic()

                # Pace against an absolute schedule so the time spent
                # sending counts toward the interval instead of adding to it.
                next_deadline += interval
                delay = next_deadline - time.monotonic()
                if delay > 0:
                    stop_event.wait(delay)
                elif delay < -interval:
                    # More than one interval behind (e.g. a slow sr() or
                    # responder poll): resync instead of bursting to catch up.
                    next_deadline = time.monotonic()
        except Exception as e:
            log.exception('Flow %s: send loop error: %s', flow_id[:8], e)
        finally:
            with self._lock:
                # This thread still owns the flow unless a restart has
                # installed a different stop event. stop() pops ours, and a
                # natural finish leaves ours in place.
                current = self._stop_events.get(flow_id)
                owns_flow = current is None or current is stop_event
                if owns_flow and flow_id in self._flows:
                    self._flows[flow_id]['state'] = 'stopped'
            # Final responder poll
            if responder_url and owns_flow:
                self._poll_responder(flow_id, responder_url)
            log.info('Flow %s: send loop finished. seq=%d', flow_id[:8], seq)

    def _send_icmp_probe(self, flow_id: str, flow: dict, seq: int,
                         pkt_bytes_len: int):
        """Send one ICMP probe with sr() and record tx plus reply RTTs."""
        pkt = build_packet(flow, seq=seq)
        # Drop the first layer (presumably Ether from build_packet — confirm)
        # so the probe goes out through Scapy's layer-3 sr().
        answered, _ = sr(pkt[pkt.firstlayer().payload.__class__],
                         timeout=1, verbose=0)
        # Coerce timestamps to float: Scapy may use Decimal-based times, which
        # would otherwise end up in the stats.
        replies = [
            ((float(recv_pkt.time) - float(sent_pkt.sent_time)) * 1000,
             len(bytes(recv_pkt)))
            for sent_pkt, recv_pkt in answered
        ]
        with self._lock:
            stats = self._stats.get(flow_id)
            if not stats:
                return
            stats['tx_packets'] += 1
            stats['tx_bytes'] += pkt_bytes_len
            for rtt_ms, rx_len in replies:
                stats['rx_packets'] += 1
                stats['rx_bytes'] += rx_len
                self._append_latency(stats, rtt_ms)

    @staticmethod
    def _send_batch(packets: list, rate_pps) -> int:
        """Send *packets* and return how many were handed off without error.

        Uses sendpfast() (tcpreplay) for rates >= 1000 pps when tcpreplay is
        available. Otherwise, or if sendpfast() fails, each packet is sent
        once with send() after its first layer is dropped. A per-packet
        failure is logged and skipped, so one bad packet does not cause the
        rest of the batch to be resent.
        """
        if HAS_TCPREPLAY and rate_pps >= 1000:
            try:
                sendpfast(packets, pps=rate_pps, loop=0)
                return len(packets)
            except Exception as e:
                log.debug('Send error (falling back): %s', e)
        sent = 0
        for p in packets:
            try:
                send(p[p.firstlayer().payload.__class__], verbose=0)
            except Exception as e:
                log.debug('Send error: %s', e)
            else:
                sent += 1
        return sent

    def _poll_responder(self, flow_id: str, responder_url: str):
        """Poll a responder's /responder/stats endpoint for rx metrics.

        Best effort: any failure is logged at debug level and ignored. The
        responder's rx_packets/rx_bytes replace the flow's values, and its
        ``latency.avg_ms`` (when present) is appended as a latency sample.
        """
        try:
            url = responder_url.rstrip('/') + '/responder/stats'
            req = urllib.request.Request(url, method='GET')
            req.add_header('Accept', 'application/json')
            with urllib.request.urlopen(req, timeout=2) as resp:
                data = json.loads(resp.read().decode())
            lat = data.get('latency') or {}
            with self._lock:
                stats = self._stats.get(flow_id)
                if stats:
                    stats['rx_packets'] = data.get('rx_packets', 0)
                    stats['rx_bytes'] = data.get('rx_bytes', 0)
                    if lat.get('avg_ms') is not None:
                        self._append_latency(stats, lat['avg_ms'])
        except Exception as e:
            log.debug('Responder poll error for flow %s: %s', flow_id[:8], e)