diff --git a/sab/clock_watch.py b/sab/clock_watch.py index c0bf172..8efbd04 100644 --- a/sab/clock_watch.py +++ b/sab/clock_watch.py @@ -200,6 +200,56 @@ def enable_persistence(enable: bool) -> None: run(["sudo", "nvidia-smi", "-pm", "1" if enable else "0"]) +class CPUFrequencyMonitor: + """Monitors CPU frequency drift during a benchmark run via /proc/cpuinfo.""" + + def __init__(self, tolerance_mhz: float = 50.0): + self._tolerance_mhz = tolerance_mhz + self._baseline_freqs: list[float] | None = None + self._end_freqs: list[float] | None = None + self._drifted = False + + @staticmethod + def _read_cpu_frequencies() -> list[float]: + freqs = [] + with open("/proc/cpuinfo") as f: + for line in f: + if line.startswith("cpu MHz"): + freqs.append(float(line.split(":")[1].strip())) + return freqs + + def __enter__(self): + try: + self._baseline_freqs = self._read_cpu_frequencies() + except OSError: + self._baseline_freqs = None + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self._baseline_freqs is None: + return + try: + self._end_freqs = self._read_cpu_frequencies() + except OSError: + return + for before, after in zip(self._baseline_freqs, self._end_freqs): + if abs(before - after) > self._tolerance_mhz: + self._drifted = True + return + + def did_drift(self) -> bool: + return self._drifted + + def get_summary(self) -> dict | None: + if self._baseline_freqs is None or self._end_freqs is None: + return None + return { + "baseline_mean_mhz": sum(self._baseline_freqs) / len(self._baseline_freqs), + "end_mean_mhz": sum(self._end_freqs) / len(self._end_freqs), + "max_drift_mhz": max(abs(b - e) for b, e in zip(self._baseline_freqs, self._end_freqs)), + } + + def main(): signal.signal(signal.SIGINT, lambda *_: sys.exit(0)) # clean Ctrl-C print("🟢 Watching for any GPU clock changes (press Ctrl-C to quit)") diff --git a/sab/evaluation.py b/sab/evaluation.py index 57ae622..4e1fa47 100644 --- a/sab/evaluation.py +++ b/sab/evaluation.py @@ -12,6 +12,8 @@ 
from tqdm import tqdm import time +from sab.onnx_inference import ONNXInferenceCPU + def evaluate(inference, image_dir: str, annotations_file_path: str, class_mapping: dict[int, str]|None=None, buffer_time: float=0.0, output_file_name: str|None=None, max_images: int|None=None, max_dets: int=100): predictions = [] @@ -29,7 +31,9 @@ def evaluate(inference, image_dir: str, annotations_file_path: str, class_mappin image = Image.open(image_path).convert("RGB") initial_shape = image.size - image = TF.to_tensor(image).cuda() + image = TF.to_tensor(image) + if not isinstance(inference, ONNXInferenceCPU): + image = image.cuda() if inference.prediction_type == "bbox": xyxy, class_id, score = inference.infer(image) diff --git a/sab/models/benchmark_dfine.py b/sab/models/benchmark_dfine.py index 025ad4d..e939d31 100644 --- a/sab/models/benchmark_dfine.py +++ b/sab/models/benchmark_dfine.py @@ -6,7 +6,7 @@ import fire -from sab.onnx_inference import ONNXInference +from sab.onnx_inference import ONNXInferenceCUDA from sab.trt_inference import TRTInference from sab.models.utils import ArtifactBenchmarkRequest, run_benchmark_on_artifacts, pretty_print_results @@ -28,7 +28,7 @@ def postprocess_output(outputs: dict[str, torch.Tensor], metadata: dict) -> tupl return bboxes, labels, scores -class DFINEONNXInference(ONNXInference): +class DFINEONNXInference(ONNXInferenceCUDA): def __init__(self, model_path: str, image_input_name: str|None="images"): super().__init__(model_path, image_input_name) diff --git a/sab/models/benchmark_lwdetr.py b/sab/models/benchmark_lwdetr.py index 041e23e..60f7aa8 100644 --- a/sab/models/benchmark_lwdetr.py +++ b/sab/models/benchmark_lwdetr.py @@ -3,7 +3,7 @@ import json import fire -from sab.onnx_inference import ONNXInference +from sab.onnx_inference import ONNXInferenceCUDA from sab.trt_inference import TRTInference from sab.models.utils import cxcywh_to_xyxy, ArtifactBenchmarkRequest, run_benchmark_on_artifacts, pretty_print_results @@ -39,7 +39,7 @@ 
class RFDETRONNXCPUInference(ONNXInferenceCPU):
    """RF-DETR ONNX inference built on ``ONNXInferenceCPU``.

    Both hooks forward to the ``preprocess_image`` and
    ``postprocess_output`` helpers referenced by this module.
    """

    def preprocess(self, input_image: torch.Tensor) -> tuple[torch.Tensor, dict]:
        """Prepare ``input_image`` for the model's image input shape."""
        target_shape = self.image_input_shape
        return preprocess_image(input_image, target_shape)

    def postprocess(self, outputs: dict[str, torch.Tensor], metadata: dict) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Convert raw model outputs into the result of ``postprocess_output``."""
        predictions = postprocess_output(outputs, metadata)
        return predictions
annotations_file_path: str, buffer_time: float = 0.0, o needs_fp16=True, buffer_time=buffer_time, ), + ArtifactBenchmarkRequest( + onnx_path="rf-detr-nano.onnx", + inference_class=RFDETRONNXCPUInference, + buffer_time=buffer_time, + ), ArtifactBenchmarkRequest( onnx_path="rf-detr-small.onnx", inference_class=RFDETRTRTInference, @@ -90,6 +103,11 @@ def main(image_dir: str, annotations_file_path: str, buffer_time: float = 0.0, o needs_fp16=True, buffer_time=buffer_time, ), + ArtifactBenchmarkRequest( + onnx_path="rf-detr-small.onnx", + inference_class=RFDETRONNXCPUInference, + buffer_time=buffer_time, + ), ArtifactBenchmarkRequest( onnx_path="rf-detr-medium.onnx", inference_class=RFDETRTRTInference, @@ -102,6 +120,11 @@ def main(image_dir: str, annotations_file_path: str, buffer_time: float = 0.0, o needs_fp16=True, buffer_time=buffer_time, ), + ArtifactBenchmarkRequest( + onnx_path="rf-detr-medium.onnx", + inference_class=RFDETRONNXCPUInference, + buffer_time=buffer_time, + ), ] results = run_benchmark_on_artifacts(requests, image_dir, annotations_file_path) diff --git a/sab/models/benchmark_rfdetr_seg.py b/sab/models/benchmark_rfdetr_seg.py index 67a6296..cd10ebd 100644 --- a/sab/models/benchmark_rfdetr_seg.py +++ b/sab/models/benchmark_rfdetr_seg.py @@ -10,7 +10,7 @@ import fire -from sab.onnx_inference import ONNXInference +from sab.onnx_inference import ONNXInferenceCUDA from sab.trt_inference import TRTInference from sab.models.utils import cxcywh_to_xyxy, ArtifactBenchmarkRequest, run_benchmark_on_artifacts, pretty_print_results @@ -56,7 +56,7 @@ def postprocess_output(outputs: dict[str, torch.Tensor], metadata: dict) -> tupl return bboxes.contiguous(), labels.contiguous(), scores.contiguous(), masks.contiguous() -class RFDETRSegONNXInference(ONNXInference): +class RFDETRSegONNXInference(ONNXInferenceCUDA): def __init__(self, model_path: str, image_input_name: str|None=None): super().__init__(model_path, image_input_name, prediction_type="segm") diff --git 
class YOLOv11ONNXCPUInference(ONNXInferenceCPU):
    """YOLOv11 ONNX inference built on ``ONNXInferenceCPU``.

    Both hooks forward to the ``preprocess_image`` and
    ``postprocess_output`` helpers referenced by this module.
    """

    def preprocess(self, input_image: torch.Tensor) -> tuple[torch.Tensor, dict]:
        """Prepare ``input_image`` for the model's image input shape."""
        target_shape = self.image_input_shape
        return preprocess_image(input_image, target_shape)

    def postprocess(self, outputs: dict[str, torch.Tensor], metadata: dict) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Convert raw model outputs into the result of ``postprocess_output``."""
        predictions = postprocess_output(outputs, metadata)
        return predictions
onnx_path="yolo11l_nms_conf_0.01.onnx", + inference_class=YOLOv11ONNXCPUInference, + buffer_time=buffer_time, + needs_class_remapping=True, + ), ArtifactBenchmarkRequest( onnx_path="yolo11x_nms_conf_0.01.onnx", inference_class=YOLOv11TRTInference, @@ -171,6 +203,12 @@ def main(image_dir: str, annotations_file_path: str, buffer_time: float = 0.0, o buffer_time=buffer_time, needs_class_remapping=True, ), + ArtifactBenchmarkRequest( + onnx_path="yolo11x_nms_conf_0.01.onnx", + inference_class=YOLOv11ONNXCPUInference, + buffer_time=buffer_time, + needs_class_remapping=True, + ), ] results = run_benchmark_on_artifacts(requests, image_dir, annotations_file_path) diff --git a/sab/models/benchmark_yolov11_seg.py b/sab/models/benchmark_yolov11_seg.py index 229c278..e271109 100644 --- a/sab/models/benchmark_yolov11_seg.py +++ b/sab/models/benchmark_yolov11_seg.py @@ -6,7 +6,7 @@ import fire -from sab.onnx_inference import ONNXInference +from sab.onnx_inference import ONNXInferenceCUDA from sab.trt_inference import TRTInference from sab.models.utils import ArtifactBenchmarkRequest, run_benchmark_on_artifacts, pretty_print_results from sab.models.graph_surgery import fuse_yolo_mask_postprocessing_into_onnx @@ -107,7 +107,7 @@ def postprocess_output(outputs: dict[str, torch.Tensor], metadata: dict) -> tupl return bboxes, labels, scores, masks -class YOLOv11SegONNXInference(ONNXInference): +class YOLOv11SegONNXInference(ONNXInferenceCUDA): def __init__(self, model_path: str, image_input_name: str|None=None): super().__init__(model_path, image_input_name, prediction_type="segm") diff --git a/sab/models/benchmark_yolov8.py b/sab/models/benchmark_yolov8.py index dc0d9ee..e58606b 100644 --- a/sab/models/benchmark_yolov8.py +++ b/sab/models/benchmark_yolov8.py @@ -5,7 +5,7 @@ import fire -from sab.onnx_inference import ONNXInference +from sab.onnx_inference import ONNXInferenceCUDA from sab.trt_inference import TRTInference from sab.models.utils import ArtifactBenchmarkRequest, 
run_benchmark_on_artifacts, pretty_print_results @@ -79,7 +79,7 @@ def postprocess_output(outputs: dict[str, torch.Tensor], metadata: dict) -> tupl return bboxes, labels, scores -class YOLOv8ONNXInference(ONNXInference): +class YOLOv8ONNXInference(ONNXInferenceCUDA): # reference: https://github.com/ultralytics/ultralytics/blob/3c88bebc9514a4d7f70b771811ddfe3a625ef14d/examples/YOLOv8-OpenCV-ONNX-Python/main.py#L23C57-L31 def preprocess(self, input_image: torch.Tensor) -> tuple[torch.Tensor, dict]: return preprocess_image(input_image, self.image_input_shape) diff --git a/sab/models/benchmark_yolov8_seg.py b/sab/models/benchmark_yolov8_seg.py index 8f751f1..4d3c497 100644 --- a/sab/models/benchmark_yolov8_seg.py +++ b/sab/models/benchmark_yolov8_seg.py @@ -6,7 +6,7 @@ import fire -from sab.onnx_inference import ONNXInference +from sab.onnx_inference import ONNXInferenceCUDA from sab.trt_inference import TRTInference from sab.models.utils import ArtifactBenchmarkRequest, run_benchmark_on_artifacts, pretty_print_results from sab.models.graph_surgery import fuse_yolo_mask_postprocessing_into_onnx @@ -107,7 +107,7 @@ def postprocess_output(outputs: dict[str, torch.Tensor], metadata: dict) -> tupl return bboxes, labels, scores, masks -class YOLOv8SegONNXInference(ONNXInference): +class YOLOv8SegONNXInference(ONNXInferenceCUDA): def __init__(self, model_path: str, image_input_name: str|None=None): super().__init__(model_path, image_input_name, prediction_type="segm") diff --git a/sab/models/utils.py b/sab/models/utils.py index e3debfa..85d7d8c 100644 --- a/sab/models/utils.py +++ b/sab/models/utils.py @@ -6,8 +6,8 @@ from supervision.utils.file import read_json_file from supervision.dataset.formats.coco import coco_categories_to_classes, build_coco_class_index_mapping -from sab.clock_watch import ThrottleMonitor -from sab.onnx_inference import ONNXInference +from sab.clock_watch import ThrottleMonitor, CPUFrequencyMonitor +from sab.onnx_inference import ONNXInferenceCUDA, 
ONNXInferenceCPU from sab.trt_inference import TRTInference, build_engine from sab.evaluation import evaluate @@ -48,7 +48,7 @@ def cxcywh_to_xyxy(boxes): class ArtifactBenchmarkRequest: def __init__(self, onnx_path: str, - inference_class: type[ONNXInference|TRTInference], + inference_class: type[ONNXInferenceCUDA|ONNXInferenceCPU|TRTInference], needs_class_remapping: bool = False, needs_fp16: bool = False, buffer_time: float = 0.0, @@ -70,6 +70,7 @@ def dump(self): "onnx_path": self.onnx_path, "inference_class": self.inference_class.__name__, "is_trt": issubclass(self.inference_class, TRTInference), + "is_cpu": issubclass(self.inference_class, ONNXInferenceCPU), "needs_fp16": self.needs_fp16, "buffer_time": self.buffer_time, "max_images": self.max_images, @@ -96,7 +97,7 @@ def run_benchmark_on_artifact(artifact_request: ArtifactBenchmarkRequest, images engine_path = artifact_request.onnx_path.replace(".onnx", ".engine") else: engine_path = artifact_request.onnx_path.replace(".onnx", ".fp16.engine") - + if not os.path.exists(engine_path): print(f"Building engine for {artifact_request.onnx_path} and saving to {engine_path}...") with ThrottleMonitor() as throttle_monitor: @@ -110,20 +111,32 @@ def run_benchmark_on_artifact(artifact_request: ArtifactBenchmarkRequest, images else: if artifact_request.needs_fp16: raise ValueError("FP16 is not supported for ONNX inference") - + inference = artifact_request.inference_class(artifact_request.onnx_path) - + + is_cpu = issubclass(artifact_request.inference_class, ONNXInferenceCPU) + throttled = False - with ThrottleMonitor() as throttle_monitor: - accuracy_stats = evaluate(inference, images_dir, annotations_file_path, inv_class_mapping, buffer_time=artifact_request.buffer_time, max_images=artifact_request.max_images, max_dets=artifact_request.max_dets) - if throttle_monitor.did_throttle(): - throttled = True - print(f"🔴 GPU throttled, latency results are unreliable. Try increasing the buffer time. 
Current buffer time: {artifact_request.buffer_time}s") - else: - print("GPU did not throttle during evaluation. Latency numbers should be reliable.") - + if is_cpu: + with CPUFrequencyMonitor() as cpu_monitor: + accuracy_stats = evaluate(inference, images_dir, annotations_file_path, inv_class_mapping, buffer_time=artifact_request.buffer_time, max_images=artifact_request.max_images, max_dets=artifact_request.max_dets) + if cpu_monitor.did_drift(): + throttled = True + summary = cpu_monitor.get_summary() + print(f"🔴 CPU frequency drifted during evaluation (max drift: {summary['max_drift_mhz']:.0f} MHz). Latency results may be unreliable.") + else: + print("CPU frequency stable during evaluation. Latency numbers should be reliable.") + else: + with ThrottleMonitor() as throttle_monitor: + accuracy_stats = evaluate(inference, images_dir, annotations_file_path, inv_class_mapping, buffer_time=artifact_request.buffer_time, max_images=artifact_request.max_images, max_dets=artifact_request.max_dets) + if throttle_monitor.did_throttle(): + throttled = True + print(f"🔴 GPU throttled, latency results are unreliable. Try increasing the buffer time. Current buffer time: {artifact_request.buffer_time}s") + else: + print("GPU did not throttle during evaluation. 
Latency numbers should be reliable.") + latency_stats = inference.profiler.get_stats() - + return accuracy_stats, latency_stats, throttled @@ -178,7 +191,7 @@ def _fmt(x, width=6, prec=1): for result in results: model = result['artifact_request']['onnx_path'] - runtime = "TRT" if result['artifact_request']['is_trt'] else "ONNX" + runtime = "TRT" if result['artifact_request']['is_trt'] else ("ONNX-CPU" if result['artifact_request'].get('is_cpu') else "ONNX-CUDA") fp16 = result['artifact_request']['needs_fp16'] stats = result['accuracy_stats'] map50 = _pct(stats, 1) diff --git a/sab/onnx_inference.py b/sab/onnx_inference.py index 3287c1c..7d0660d 100644 --- a/sab/onnx_inference.py +++ b/sab/onnx_inference.py @@ -1,14 +1,15 @@ +import os import onnxruntime as ort import torch import numpy as np -from sab.profiler import CUDAProfiler +from sab.profiler import CUDAProfiler, CPUProfiler -class ONNXInference: - def __init__(self, model_path: str, image_input_name: str|None=None, prediction_type: str="bbox"): - self.session = ort.InferenceSession(model_path, providers=['CUDAExecutionProvider']) - +class ONNXInferenceBase: + def __init__(self, model_path: str, providers: list[str], profiler, device: str, image_input_name: str|None=None, prediction_type: str="bbox", session_options: ort.SessionOptions|None=None): + self.session = ort.InferenceSession(model_path, providers=providers, sess_options=session_options) + self.input_names = [input.name for input in self.session.get_inputs()] self.output_names = [output.name for output in self.session.get_outputs()] self.input_shapes = [input.shape for input in self.session.get_inputs()] @@ -18,22 +19,25 @@ def __init__(self, model_path: str, image_input_name: str|None=None, prediction_ raise ValueError("Model has multiple inputs, but no image input name was provided") elif len(self.input_names) == 1 and image_input_name is not None: assert image_input_name in self.input_names, f"Image input name {image_input_name} not found in model 
inputs" - + self.image_input_name = image_input_name if image_input_name is not None else self.input_names[0] self.image_input_shape = self.session.get_inputs()[self.input_names.index(self.image_input_name)].shape - self.profiler = CUDAProfiler() + self.profiler = profiler + self.device = device self.prediction_type = prediction_type - + + self.warmup() + def preprocess(self, input_image: torch.Tensor) -> tuple[torch.Tensor, dict]: raise NotImplementedError("Subclasses must implement this method") - + def construct_bindings(self, input_image: torch.Tensor) -> tuple[ort.IOBinding, dict[str, torch.Tensor]]: # Construct IOBinding for the input and output tensors if len(self.input_names) != 1: raise RuntimeError("Default implementation only supports models with a single input, please subclass and implement this method") - + binding = self.session.io_binding() input_image = input_image.contiguous() @@ -71,12 +75,12 @@ def construct_bindings(self, input_image: torch.Tensor) -> tuple[ort.IOBinding, outputs[output_name] = buffer return binding, outputs - + def postprocess(self, outputs: dict[str, torch.Tensor], metadata: dict) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # Postprocess the outputs into bbox, class, and score # bbox must be in normalized coordinates (0-1) and in xyxy format raise NotImplementedError("Subclasses must implement this method") - + def infer(self, input_image: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: input_image, metadata = self.preprocess(input_image) @@ -90,7 +94,37 @@ def infer(self, input_image: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, binding.synchronize_outputs() return self.postprocess(outputs, metadata) - + + def warmup(self, num_iterations: int = 10): + """Run dummy data through the model to trigger JIT optimizations + and warm CPU/GPU caches before real measurements begin.""" + device = torch.device(self.device) + dummy_input = torch.randn(self.image_input_shape, dtype=torch.float32, 
class ProfilerBase:
    """Common timing storage and statistics shared by the profilers.

    Timings are kept in ``self.timings``; the report labels them as
    milliseconds. Subclasses supply ``profile()``.
    """

    def __init__(self):
        self.timings: List[float] = []

    def profile(self):
        """Timing hook; must be provided by subclasses."""
        raise NotImplementedError("Subclasses must implement this method")

    def get_stats(self) -> Dict[str, float]:
        """Return summary statistics over all recorded timings.

        Raises:
            ValueError: If no timings have been recorded.
        """
        if not self.timings:
            raise ValueError("No timings recorded. Use .profile() context manager first.")

        samples = np.array(self.timings)

        summary = {
            'mean': np.mean(samples),
            'median': np.median(samples),
            'min': np.min(samples),
            'max': np.max(samples),
            'std': np.std(samples),
        }
        for quantile in (90, 95, 99):
            summary[f'p{quantile}'] = np.percentile(samples, quantile)
        summary['count'] = len(self.timings)
        return summary

    def print_stats(self, name: Optional[str] = None):
        """Print the statistics, or a short notice when none are available."""
        rows = (
            ("Mean", 'mean'),
            ("Median", 'median'),
            ("Min", 'min'),
            ("Max", 'max'),
            ("Std", 'std'),
            ("P90", 'p90'),
            ("P95", 'p95'),
            ("P99", 'p99'),
        )
        try:
            stats = self.get_stats()

            title = f"\n=== {name} ===" if name else "\n=== Profiling Results ==="
            print(title)

            print(f"Samples: {stats['count']}")
            for label, key in rows:
                print(f"{label}: {stats[key]:.3f} ms")
        except ValueError as err:
            print(f"Cannot print stats: {err}")

    def reset(self):
        """Discard every recorded timing."""
        self.timings.clear()

    def get_last_timing(self) -> float:
        """Return the most recently recorded timing.

        Raises:
            ValueError: If no timings have been recorded.
        """
        if self.timings:
            return self.timings[-1]
        raise ValueError("No timings recorded.")
Asynchronous profiling for CUDA graphs (avoids interference) """ - + def __init__(self, stream: Optional[torch.cuda.Stream] = None): - self.timings: List[float] = [] + super().__init__() self.stream = stream # Allow profiling on specific streams self._start_event = torch.cuda.Event(enable_timing=True) self._end_event = torch.cuda.Event(enable_timing=True) - + def set_stream(self, stream: torch.cuda.Stream): """Set the stream to profile on""" self.stream = stream - + @contextmanager def profile(self, stream: Optional[torch.cuda.Stream] = None): """Synchronous profiling for standard operations - + Use this for regular PyTorch operations, TensorRT standard execution, etc. This method handles synchronization automatically. - + Args: stream: Optional stream to profile on. If None, uses self.stream or current stream. """ # Use provided stream, or fall back to instance stream, or current stream target_stream = stream or self.stream - + if target_stream is not None: # Record events on the specific stream for accurate timing self._start_event.record(target_stream) - + # Yield control back to the caller yield - + # Record end event on the same stream self._end_event.record(target_stream) - + # Only synchronize the target stream (more efficient than global sync) target_stream.synchronize() else: # Fallback to default stream timing self._start_event.record() - + yield - + self._end_event.record() torch.cuda.synchronize() - + # Calculate and store elapsed time try: elapsed_ms = self._start_event.elapsed_time(self._end_event) @@ -64,15 +127,15 @@ def profile(self, stream: Optional[torch.cuda.Stream] = None): # Handle CUDA errors gracefully (e.g., context corruption) print(f"Timing measurement failed: {e}") # Don't append invalid timing - - @contextmanager + + @contextmanager def profile_async(self, stream: Optional[torch.cuda.Stream] = None): """Asynchronous profiling for CUDA graphs and async operations - + Use this for CUDA graphs, async kernel launches, etc. 
This method records timing events but doesn't synchronize within the context. Call get_last_timing_async() after ensuring the stream is complete. - + Example: with profiler.profile_async(stream=my_stream): cuda_graph.replay() @@ -80,26 +143,26 @@ def profile_async(self, stream: Optional[torch.cuda.Stream] = None): timing = profiler.get_last_timing_async() """ target_stream = stream or self.stream - + if target_stream is not None: self._start_event.record(target_stream) - + yield - + self._end_event.record(target_stream) # Don't synchronize here - let caller handle it for async operations else: self._start_event.record() - + yield - + self._end_event.record() - + def get_last_timing_async(self) -> Optional[float]: """Get timing from async profiling after stream completion - + Call this after ensuring the stream has completed (e.g., via stream.synchronize()). - + Returns: Timing in milliseconds if available, None if events aren't ready """ @@ -114,81 +177,42 @@ def get_last_timing_async(self) -> Optional[float]: except RuntimeError as e: print(f"Async timing measurement failed: {e}") return None - - def get_stats(self) -> Dict[str, float]: - """Compute statistics from collected timings""" - if not self.timings: - raise ValueError("No timings recorded. 
class CPUProfiler(ProfilerBase):
    """Profiler that times a block with ``time.perf_counter_ns``."""

    @contextmanager
    def profile(self):
        """Time the enclosed block and append the duration in milliseconds.

        Nothing is appended if the block raises, because the exception
        propagates from the ``yield``.
        """
        began_ns = time.perf_counter_ns()
        yield
        finished_ns = time.perf_counter_ns()
        duration_ms = (finished_ns - began_ns) / 1_000_000
        self.timings.append(duration_ms)
time stream: CUDA stream to run on - iterations: Number of iterations + iterations: Number of iterations use_async: Whether to use async profiling (for CUDA graphs) """ profiler = CUDAProfiler(stream) - + # Warmup for _ in range(10): with torch.cuda.stream(stream): operation_fn() stream.synchronize() - + # Actual timing for _ in range(iterations): with torch.cuda.stream(stream): @@ -200,21 +224,21 @@ def time_operation(operation_fn, stream: torch.cuda.Stream, iterations: int = 10 else: with profiler.profile(): operation_fn() - + return profiler.get_stats() - + @staticmethod def compare_cuda_graph_vs_standard(cuda_graph_fn, standard_fn, stream: torch.cuda.Stream, iterations: int = 100): """Compare CUDA graph vs standard execution performance""" - + print("Benchmarking Standard Execution...") stats_standard = StreamAwareBenchmark.time_operation(standard_fn, stream, iterations, use_async=False) - + print("Benchmarking CUDA Graph Execution...") stats_graph = StreamAwareBenchmark.time_operation(cuda_graph_fn, stream, iterations, use_async=True) - + print(f"\nStandard Execution - Mean: {stats_standard['mean']:.3f} ms") print(f"CUDA Graph Execution - Mean: {stats_graph['mean']:.3f} ms") print(f"CUDA Graph Speedup: {stats_standard['mean'] / stats_graph['mean']:.2f}x") - - return stats_standard, stats_graph \ No newline at end of file + + return stats_standard, stats_graph