"""
FlowSender - manages traffic generation with background threads per flow.

Each started flow is sent by its own daemon thread. UDP/TCP flows are sent in
batches (via Scapy's ``sendpfast``, i.e. tcpreplay, when it is installed and
the rate is at least 1000 pps; otherwise with ``send``). ICMP flows without a
responder use ``sr`` so round-trip latency can be measured. Flows with a
``responder_url`` poll that responder for receive-side stats.
"""

import json
import logging
import shutil
import threading
import time
import urllib.request

from scapy.all import send, sendpfast, sr, conf

from engine.packet_builder import build_packet, stamp_payload

log = logging.getLogger(__name__)

# Suppress Scapy verbosity globally (conf is Scapy's process-wide config).
conf.verb = 0

# True when the tcpreplay binary is on PATH; gates the sendpfast() fast path.
HAS_TCPREPLAY = shutil.which('tcpreplay') is not None
class FlowSender:
    """Manages sending threads for multiple flows.

    Each started flow gets one daemon thread running ``_send_loop``. All shared
    state (flow configs, threads, stop events, stats) is guarded by
    ``self._lock``; the public accessors return copies so callers never hold
    references into that state outside the lock.
    """

    # Newest latency samples kept per flow; older ones are discarded.
    _MAX_LATENCY_SAMPLES = 1000
    # Minimum seconds between responder polls while a flow is sending.
    _RESPONDER_POLL_INTERVAL = 2.0

    def __init__(self):
        self._lock = threading.Lock()
        self._flows = {}        # flow_id -> flow_config dict
        self._threads = {}      # flow_id -> Thread
        self._stop_events = {}  # flow_id -> Event
        self._stats = {}        # flow_id -> {tx_packets, tx_bytes, ...}

    # ------------------------------------------------------------------
    # Stats helpers (callers must hold self._lock when touching live stats)
    # ------------------------------------------------------------------

    @staticmethod
    def _new_stats() -> dict:
        """Return a fresh, zeroed per-flow stats record."""
        return {
            'tx_packets': 0, 'tx_bytes': 0,
            'rx_packets': 0, 'rx_bytes': 0,
            'latency_samples': [],
        }

    @staticmethod
    def _copy_stats(stats: dict) -> dict:
        """Return a snapshot of *stats* that shares no mutable list with it."""
        snapshot = dict(stats)
        if 'latency_samples' in snapshot:
            snapshot['latency_samples'] = list(snapshot['latency_samples'])
        return snapshot

    @classmethod
    def _add_latency_sample(cls, stats: dict, value_ms) -> None:
        """Append a latency sample, keeping only the newest _MAX_LATENCY_SAMPLES."""
        samples = stats['latency_samples']
        samples.append(value_ms)
        if len(samples) > cls._MAX_LATENCY_SAMPLES:
            del samples[:-cls._MAX_LATENCY_SAMPLES]

    # ------------------------------------------------------------------
    # Flow CRUD
    # ------------------------------------------------------------------

    def add_flow(self, flow_id: str, flow_config: dict):
        """Register (or replace) a flow and reset its stats.

        The config is copied so later mutation of the caller's dict cannot
        race with the sender's own updates (e.g. the ``state`` key).
        """
        with self._lock:
            self._flows[flow_id] = dict(flow_config)
            self._stats[flow_id] = self._new_stats()

    def get_flow(self, flow_id: str):
        """Return a copy of the flow's config, or None if it is unknown."""
        with self._lock:
            flow = self._flows.get(flow_id)
            return dict(flow) if flow is not None else None

    def get_all_flows(self):
        """Return a ``{flow_id: config}`` snapshot with each config copied."""
        with self._lock:
            return {fid: dict(cfg) for fid, cfg in self._flows.items()}

    def update_flow(self, flow_id: str, updates: dict) -> bool:
        """Merge *updates* into a flow's config.

        Returns False if the flow does not exist, True otherwise. A running
        send loop works on the snapshot it took when it started, so changes
        take effect on the next start().
        """
        with self._lock:
            if flow_id not in self._flows:
                return False
            self._flows[flow_id].update(updates)
            return True

    def remove_flow(self, flow_id: str):
        """Stop the flow (if running) and forget its config and stats."""
        self.stop(flow_id)
        with self._lock:
            self._flows.pop(flow_id, None)
            self._stats.pop(flow_id, None)

    # ------------------------------------------------------------------
    # Start / Stop
    # ------------------------------------------------------------------

    def start(self, flow_id: str):
        """Start the flow's send thread with fresh stats; no-op if running.

        Raises:
            KeyError: if *flow_id* has not been added.
        """
        with self._lock:
            if flow_id not in self._flows:
                raise KeyError(f'Flow {flow_id} not found')
            existing = self._threads.get(flow_id)
            if existing is not None and existing.is_alive():
                return  # already running
            self._flows[flow_id]['state'] = 'running'
            self._stats[flow_id] = self._new_stats()
            stop_event = threading.Event()
            self._stop_events[flow_id] = stop_event
            thread = threading.Thread(
                target=self._send_loop,
                args=(flow_id, stop_event),
                daemon=True,
                name=f'sender-{flow_id[:8]}',
            )
            self._threads[flow_id] = thread
            thread.start()

    def stop(self, flow_id: str):
        """Signal the flow's thread to stop and wait up to 5 s for it to exit."""
        with self._lock:
            stop_event = self._stop_events.pop(flow_id, None)
            thread = self._threads.pop(flow_id, None)
            if stop_event is not None:
                stop_event.set()
            if flow_id in self._flows:
                self._flows[flow_id]['state'] = 'stopped'
        # Join outside the lock: the send thread takes the lock to record stats.
        if thread is not None and thread.is_alive():
            thread.join(timeout=5)

    def is_running(self, flow_id: str) -> bool:
        """Return True while the flow's send thread is alive."""
        with self._lock:
            thread = self._threads.get(flow_id)
            return thread is not None and thread.is_alive()

    # ------------------------------------------------------------------
    # Stats
    # ------------------------------------------------------------------

    def get_stats(self, flow_id: str) -> dict:
        """Return a snapshot of one flow's stats ({} if the flow is unknown)."""
        with self._lock:
            return self._copy_stats(self._stats.get(flow_id, {}))

    def get_all_stats(self) -> dict:
        """Return ``{flow_id: stats snapshot}`` for every flow."""
        with self._lock:
            return {fid: self._copy_stats(s) for fid, s in self._stats.items()}

    # ------------------------------------------------------------------
    # Internal send loop
    # ------------------------------------------------------------------

    def _send_loop(self, flow_id: str, stop_event: threading.Event):
        """Thread body: send the flow until *stop_event* is set or it times out.

        ``duration`` seconds bounds the run (0/None means no limit). The loop
        works on a snapshot of the flow config and on the stats dict that
        existed when it started, so a thread left over from a stop()/start()
        cycle (stop() only waits 5 s) cannot corrupt the newer run's stats or
        flip its state back to 'stopped'. Errors are logged, never raised.
        """
        seq = 0
        stats = None
        responder_url = None
        try:
            with self._lock:
                flow = dict(self._flows[flow_id])
                stats = self._stats.get(flow_id)

            rate_pps = flow.get('rate_pps', 1000)
            duration = flow.get('duration', 30)
            duration = float(duration) if duration else 0.0  # 0/None: no limit
            protocol = flow.get('protocol', 'udp').lower()
            responder_url = flow.get('responder_url')
            use_icmp_sr = (protocol == 'icmp' and not responder_url)

            # tx_bytes is estimated from the length of a template packet.
            pkt_bytes_len = len(bytes(build_packet(flow, seq=0)))

            # ICMP probes go out one at a time (each sr() waits for replies).
            # Other protocols send batches of rate/10 packets clamped to
            # [1, 100]; int() keeps range() valid for float rates.
            if use_icmp_sr:
                batch_size = 1
            else:
                batch_size = max(1, min(int(rate_pps) // 10, 100))
            # Time budget per iteration; non-positive rates fall back to 1 s.
            period = batch_size / rate_pps if rate_pps > 0 else 1.0

            log.info('Flow %s: starting send loop at %d pps for %ds',
                     flow_id[:8], rate_pps, duration)

            start_time = time.monotonic()
            next_deadline = start_time
            last_responder_poll = float('-inf')  # poll on the first iteration

            while not stop_event.is_set():
                if duration and time.monotonic() - start_time >= duration:
                    break

                if use_icmp_sr:
                    seq = self._send_icmp_probe(flow, seq, stats, pkt_bytes_len)
                else:
                    seq = self._send_batch(flow_id, flow, seq, batch_size,
                                           rate_pps, stats, pkt_bytes_len)

                # Poll responder for rx stats periodically.
                if responder_url and (time.monotonic() - last_responder_poll
                                      >= self._RESPONDER_POLL_INTERVAL):
                    self._poll_responder(flow_id, responder_url, stats)
                    last_responder_poll = time.monotonic()

                # Pace against an absolute schedule so time spent sending and
                # polling counts toward the period instead of adding to it.
                next_deadline += period
                delay = next_deadline - time.monotonic()
                if delay > 0:
                    stop_event.wait(delay)
                else:
                    # Behind schedule: resynchronise instead of bursting.
                    next_deadline = time.monotonic()

        except Exception as e:
            log.exception('Flow %s: send loop error: %s', flow_id[:8], e)
        finally:
            with self._lock:
                # Only the run that still owns the flow may mark it stopped:
                # stop() pops the event (and sets 'stopped' itself), and a
                # newer start() replaces it with its own.
                if (self._stop_events.get(flow_id) is stop_event
                        and flow_id in self._flows):
                    self._flows[flow_id]['state'] = 'stopped'
            # Final responder poll
            if responder_url:
                self._poll_responder(flow_id, responder_url, stats)
            log.info('Flow %s: send loop finished. seq=%d', flow_id[:8], seq)

    def _send_icmp_probe(self, flow, seq, stats, pkt_bytes_len):
        """Send one ICMP packet with sr(), record tx and RTT samples.

        Returns the next sequence number.
        """
        pkt = build_packet(flow, seq=seq)
        # Drop the first layer (presumably Ethernet — confirm in build_packet)
        # so sr() sends from the next layer down.
        answered, _ = sr(pkt[pkt.firstlayer().payload.__class__],
                         timeout=1, verbose=0)
        with self._lock:
            if stats is not None:
                stats['tx_packets'] += 1
                stats['tx_bytes'] += pkt_bytes_len
                for sent_pkt, recv_pkt in answered:
                    rtt_ms = (recv_pkt.time - sent_pkt.sent_time) * 1000
                    stats['rx_packets'] += 1
                    stats['rx_bytes'] += len(bytes(recv_pkt))
                    self._add_latency_sample(stats, rtt_ms)
        return seq + 1

    def _send_batch(self, flow_id, flow, seq, batch_size, rate_pps, stats,
                    pkt_bytes_len):
        """Build and send one batch of packets; return the next seq number.

        Uses sendpfast() when tcpreplay is available and the rate is at least
        1000 pps. Otherwise, or if sendpfast fails, each packet is sent exactly
        once with send(). Only packets that were handed off without error are
        counted in the tx stats.
        """
        packets = [build_packet(flow, seq=seq + i) for i in range(batch_size)]

        sent = 0
        if HAS_TCPREPLAY and rate_pps >= 1000:
            try:
                sendpfast(packets, pps=rate_pps, loop=0)
                sent = len(packets)
            except Exception as e:
                log.debug('Send error (falling back): %s', e)

        if not sent:
            last_error = None
            for p in packets:
                try:
                    send(p[p.firstlayer().payload.__class__], verbose=0)
                    sent += 1
                except Exception as e:
                    last_error = e
            if last_error is not None:
                log.debug('Flow %s: %d/%d packets failed to send: %s',
                          flow_id[:8], len(packets) - sent, len(packets),
                          last_error)

        with self._lock:
            if stats is not None:
                stats['tx_packets'] += sent
                stats['tx_bytes'] += pkt_bytes_len * sent
        return seq + batch_size

    def _poll_responder(self, flow_id: str, responder_url: str, stats=None):
        """Poll a responder's /responder/stats endpoint for rx metrics.

        The responder's ``rx_packets``/``rx_bytes`` replace the local values and
        its ``latency.avg_ms`` (if present) is appended as one latency sample.
        Polling is best-effort: failures are logged at debug level only.

        Args:
            flow_id: flow whose stats are updated.
            responder_url: base URL of the responder.
            stats: stats dict to update; defaults to the flow's current entry.
        """
        try:
            url = responder_url.rstrip('/') + '/responder/stats'
            req = urllib.request.Request(url, method='GET')
            req.add_header('Accept', 'application/json')
            with urllib.request.urlopen(req, timeout=2) as resp:
                data = json.loads(resp.read().decode())
        except Exception as e:
            log.debug('Responder poll error for flow %s: %s', flow_id[:8], e)
            return

        if not isinstance(data, dict):
            log.debug('Responder poll error for flow %s: unexpected payload type %s',
                      flow_id[:8], type(data).__name__)
            return

        latency = data.get('latency') or {}
        avg_ms = latency.get('avg_ms') if isinstance(latency, dict) else None
        with self._lock:
            target = stats if stats is not None else self._stats.get(flow_id)
            if target is not None:
                target['rx_packets'] = data.get('rx_packets', 0)
                target['rx_bytes'] = data.get('rx_bytes', 0)
                if avg_ms is not None:
                    self._add_latency_sample(target, avg_ms)