@@ -470,3 +470,68 @@ def test_cached_compilation_config(default_vllm_config):
470470
471471 code = " " .join (code )
472472 assert "torch.ops._C.static_scaled_fp8_quant.default(" in code
473+
474+
475+ def test_compile_sizes_padding_validation ():
476+ """Test that compile_sizes with values that would be padded raises an error."""
477+ # cudagraph_capture_sizes=[1, 2, 4, 8] means:
478+ # - size 1 -> padded to 1
479+ # - size 2 -> padded to 2
480+ # - size 3 -> padded to 4
481+ # - size 4 -> padded to 4
482+ # - size 5 -> padded to 8
483+ # etc.
484+ # So compile_sizes=[3] should fail because 3 would be padded to 4
485+
486+ with pytest .raises (ValueError , match = "would be padded to" ):
487+ config = CompilationConfig (
488+ cudagraph_capture_sizes = [1 , 2 , 4 , 8 ],
489+ max_cudagraph_capture_size = 8 ,
490+ compile_sizes = [3 ],
491+ )
492+ config .post_init_cudagraph_sizes ()
493+
494+ with pytest .raises (ValueError , match = "would be padded to" ):
495+ config = CompilationConfig (
496+ cudagraph_capture_sizes = [1 , 2 , 4 , 8 ],
497+ max_cudagraph_capture_size = 8 ,
498+ compile_sizes = [5 ],
499+ )
500+ config .post_init_cudagraph_sizes ()
501+
502+ config = CompilationConfig (
503+ cudagraph_capture_sizes = [1 , 2 , 4 , 8 ],
504+ max_cudagraph_capture_size = 8 ,
505+ compile_sizes = [1 , 2 , 4 , 8 ],
506+ )
507+ config .post_init_cudagraph_sizes ()
508+ assert sorted (config .compile_sizes ) == [1 , 2 , 4 , 8 ]
509+
510+ config = CompilationConfig (
511+ cudagraph_capture_sizes = [1 , 2 , 4 , 8 ],
512+ max_cudagraph_capture_size = 8 ,
513+ compile_sizes = ["cudagraph_capture_sizes" ],
514+ )
515+ config .post_init_cudagraph_sizes ()
516+ assert sorted (config .compile_sizes ) == [1 , 2 , 4 , 8 ]
517+
518+ # When cudagraphs are disabled (max_cudagraph_capture_size=0),
519+ # padding validation should be skipped
520+ config = CompilationConfig (
521+ cudagraph_capture_sizes = [],
522+ max_cudagraph_capture_size = 0 ,
523+ compile_sizes = [3 , 5 , 7 ], # would be invalid with cudagraphs
524+ )
525+ config .post_init_cudagraph_sizes ()
526+ assert sorted (config .compile_sizes ) == [3 , 5 , 7 ]
527+
528+ # When cudagraph_mode is NONE but capture_sizes is non-empty,
529+ # padding validation should still be skipped
530+ config = CompilationConfig (
531+ cudagraph_capture_sizes = [1 , 2 , 4 , 8 ],
532+ max_cudagraph_capture_size = 8 ,
533+ cudagraph_mode = CUDAGraphMode .NONE ,
534+ compile_sizes = [3 , 5 , 7 ], # would be invalid if cudagraphs were enabled
535+ )
536+ config .post_init_cudagraph_sizes ()
537+ assert sorted (config .compile_sizes ) == [3 , 5 , 7 ]
0 commit comments