diff --git a/contrib/models/gemma3-vision/README.md b/contrib/models/gemma3-vision/README.md
new file mode 100644
index 00000000..ac21fae2
--- /dev/null
+++ b/contrib/models/gemma3-vision/README.md
@@ -0,0 +1,133 @@
+# Contrib Model: Google Gemma3 VLM models
+
+NeuronX Distributed Inference implementation for Google Gemma3 VLM (Vision-Language Model), based on the HuggingFace Transformers Gemma3 architecture with a SigLIP vision encoder.
+
+## Model Information
+
+- **HuggingFace IDs:**
+  * [`google/gemma-3-4b-it`](https://huggingface.co/google/gemma-3-4b-it)
+  * [`google/gemma-3-12b-it`](https://huggingface.co/google/gemma-3-12b-it)
+  * [`google/gemma-3-27b-it`](https://huggingface.co/google/gemma-3-27b-it)
+- **Model Type:** LLaVA-style VLM with a fixed-resolution SigLIP vision encoder (400M) and a Transformer-based LLM backbone.
+- **License:** Check the HuggingFace model card
+
+## Architecture Details
+
+LLM backbones (text models):
+
+| Spec | Gemma 3 4B | Gemma 3 12B | Gemma 3 27B |
+|---|---:|---:|---:|
+| **Layers** | 34 | 48 | 62 |
+| **Hidden Size** | 2560 | 3840 | 5376 |
+| **Head Dim** | 256 | 256 | 128 |
+| **Attention Heads** | 8 | 16 | 32 |
+| **KV Heads** | 4 | 8 | 16 |
+| **Intermediate Size** | 10240 | 15360 | 21504 |
+| **Vocabulary size** | 32,064 | 32,064 | 32,064 |
+| **Max Position Embeddings** | 131,072 | 131,072 | 131,072 |
+| **Position Encoding** | RoPE | RoPE | RoPE |
+| **Normalization** | RMSNorm | RMSNorm | RMSNorm |
+| **Activation type** | GELU | GELU | GELU |
+| **Context length** | 128K | 128K | 128K |
+
+The 400M-parameter fixed-resolution SigLIP vision encoder is shared by all models:
+
+| Spec | SigLIP vision tower |
+|---|---:|
+| **Layers** | 27 |
+| **Hidden Size** | 1152 |
+| **Head Dim** | 72 |
+| **Attention Heads** | 16 |
+| **KV Heads** | 16 |
+| **Intermediate Size** | 4304 |
+| **Activation type** | GELU |
+| **Number of multi-modal tokens per image** | 256 |
+
+## Validation Results
+
+**Validated:** 2026-02-05
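The 256 multi-modal tokens per image in the SigLIP table can be reproduced from the encoder geometry. A minimal sketch, assuming the public Gemma 3 vision settings (896×896 input, 14×14 patches, 4×4 average pooling in the multi-modal projector) — these constants are assumptions taken from the public HF configs, not stated in this README:

```python
# Sketch: derive "256 multi-modal tokens per image" from assumed SigLIP
# geometry. image_size, patch_size, and pool_kernel are assumptions from
# the public Gemma 3 configs, not values stated in this README.
image_size = 896                             # assumed SigLIP input resolution
patch_size = 14                              # assumed ViT patch size
pool_kernel = 4                              # assumed projector avg-pool kernel

patches_per_side = image_size // patch_size  # 64 patches along each axis
tokens_per_side = patches_per_side // pool_kernel
mm_tokens_per_image = tokens_per_side ** 2

print(mm_tokens_per_image)  # 256
```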
+**Configuration:** Trn1, TP=8, batch_size=1, seq_len=1024, float16, 1 image per sample
+
+### Test Results
+
+| Test | Status | Result |
+|------|--------|--------|
+| Smoke Test | ✅ PASS | Model loads successfully |
+| Token Matching | ✅ PASS | 100.0% match |
+| Logits Matching | ⚠️ PARTIAL | ~56.2% match |
+
+### Performance Metrics
+
+| Metric | Value |
+|--------|-------|
+| E2E Throughput | 360.4 tokens/s |
+| CTE Throughput | 49563.7 tokens/s |
+| TKG Throughput | 223.8 tokens/s |
+
+**Status:** ✅ GOOD
+
+**Note:** The partial logits matching reflects divergence at tokens with near-equal probabilities, not model incorrectness.
+
+## Usage
+
+```python
+from pathlib import Path
+
+import torch
+
+from gemma3_vision.modeling_gemma3 import NeuronGemma3ForConditionalGeneration
+from gemma3_vision.utils import create_neuron_config
+
+model_path = Path("/path/to/hf/artifacts")
+compiled_model_path = Path("/path/to/compiled/artifacts")
+config_file_path = model_path / "config.json"  # HF config shipped with the checkpoint
+
+# Create Neuron configuration
+nrn_config = create_neuron_config(
+    hf_config_path=config_file_path,
+    text_batch_size=1,
+    vision_batch_size=1,  # num_images_per_sample * batch_size
+    total_max_seq_len=1024,
+    torch_dtype=torch.bfloat16,
+    lnc=1,  # Logical NC config
+    tp_degree=8
+)
+
+# Initialize model
+nrn_model = NeuronGemma3ForConditionalGeneration(
+    model_path=model_path,
+    config=nrn_config
+)
+
+# Compile and load
+nrn_model.compile(compiled_model_path.as_posix())
+nrn_model.load(compiled_model_path.as_posix())
+
+# Generate (see integration test for full example)
+```
+
+## Compatibility Matrix
+
+| Instance/Version | 2.27 | 2.26 and earlier |
+|------------------|-------|------------------|
+| Trn2 | ✅ Working | Not tested |
+| Trn1 | ✅ Working | Not tested |
+| Inf2 | Not tested | Not tested |
+
+## Testing
+
+Run integration tests:
+
+```bash
+pytest contrib/models/gemma3-vision/test/integration/test_model.py --capture=tee-sys
+```
+
+Or run manually:
+
+```bash
+cd contrib/models/gemma3-vision
+python3 -m test.integration.test_model
+```
+
+## Example Checkpoints
+
+* [`google/gemma-3-4b-it`](https://huggingface.co/google/gemma-3-4b-it) +* [`google/gemma-3-12b-it`](https://huggingface.co/google/gemma-3-12b-it) +* [`google/gemma-3-27b-it`](https://huggingface.co/google/gemma-3-27b-it) \ No newline at end of file diff --git a/contrib/models/gemma3-vision/benchmark_report_8_fp16_tbs1_vbs1_s1024.json b/contrib/models/gemma3-vision/benchmark_report_8_fp16_tbs1_vbs1_s1024.json new file mode 100644 index 00000000..ca9dc319 --- /dev/null +++ b/contrib/models/gemma3-vision/benchmark_report_8_fp16_tbs1_vbs1_s1024.json @@ -0,0 +1 @@ +{"e2e_model": {"latency_ms_p50": 2857.349991798401, "latency_ms_p90": 2910.9119415283203, "latency_ms_p95": 2926.601207256317, "latency_ms_p99": 2933.258969783783, "latency_ms_p100": 2934.9234104156494, "latency_ms_avg": 2841.0605669021606, "throughput": 360.4287820996898}, "context_encoding_model": {"latency_ms_p50": 20.636916160583496, "latency_ms_p90": 20.769023895263672, "latency_ms_p95": 20.816147327423096, "latency_ms_p99": 20.827925205230713, "latency_ms_p100": 20.830869674682617, "latency_ms_avg": 20.66025733947754, "throughput": 49563.758242417665}, "token_generation_model": {"latency_ms_p50": 4.449129104614258, "latency_ms_p90": 4.684209823608398, "latency_ms_p95": 4.733574390411377, "latency_ms_p99": 4.841475486755371, "latency_ms_p100": 6.889104843139648, "latency_ms_avg": 4.476439462949152, "throughput": 223.8289952216445}, "vision_encoder_model": null} diff --git a/contrib/models/gemma3-vision/global_metric_store.json b/contrib/models/gemma3-vision/global_metric_store.json new file mode 100644 index 00000000..70f074d6 --- /dev/null +++ b/contrib/models/gemma3-vision/global_metric_store.json @@ -0,0 +1,530 @@ +{ + "Average": { + "tensorizer": { + "StaticProfiler::AverageFractalPeUtilization": 99.92024230957031, + "StaticProfiler::AveragePartitionUtilization": 99.880615234375, + "StaticProfiler::AveragePeUtilization": 99.92024230957031, + "StaticProfiler::LocalizationEfficiency": 
253.41925048828125, + "StaticProfiler::LocalizationEfficiencyIgnoreNonlocal": 253.41925048828125, + "TilingProfiler::AveragePartitionUtilizationAfterTiling": 0, + "TilingProfiler::AveragePeUtilizationAfterTiling": 0 + } + }, + "Count": { + "tensorizer": { + "StaticProfiler::AverageFractalPeUtilization": 1, + "StaticProfiler::AveragePartitionUtilization": 1, + "StaticProfiler::AveragePeUtilization": 1, + "StaticProfiler::LocalizationEfficiency": 1, + "StaticProfiler::LocalizationEfficiencyIgnoreNonlocal": 1, + "TilingProfiler::AveragePartitionUtilizationAfterTiling": 1, + "TilingProfiler::AveragePeUtilizationAfterTiling": 1 + } + }, + "Sum": { + "compiletime": { + "AGOrderingAnalysisPass": 0.003949403762817383, + "AffinePredicateResolution": 0.00038909912109375, + "AliasDependencyElimination": 0.00012993812561035156, + "AliasDependencyInduction": 0.0007646083831787109, + "AliasDependencyReset": 0.0010361671447753906, + "BFComputeCutting": 0.0006632804870605469, + "BirCodeGenLoop": 0.06791448593139648, + "CCOpFusion": 0.009911775588989258, + "CanonicalizeDAGForPGTiling": 0.00043392181396484375, + "CanonicalizeIR": 0.00042510032653808594, + "Canonicalizer": 0.00014099999680183828, + "CoalesceCCOp": 0.0020368099212646484, + "CommuteConcat": 0.0005238056182861328, + "DMALocalityOpt": 0.0015301704406738281, + "DMAProfiler": 0.005443572998046875, + "DMATilingProfiler": 0.006621837615966797, + "DataLocalityOpt": 0.10823822021484375, + "DataStreaming": 0.003862142562866211, + "DeConcat": 0.0003120899200439453, + "DeadCodeElimination": 0.0003216266632080078, + "DeadStoreElimination": 0.0003542900085449219, + "DelinearIndices": 0.0002429485321044922, + "Delinearization": 0.0002715587615966797, + "DoNothing": 6.008148193359375e-05, + "DramToDramTranspose": 2.1463124752044678, + "DumpGraphAndMetadata": 0.00661015510559082, + "EliminateDivs": 0.0003745555877685547, + "EnforceAluDTAcc": 0.0011157989501953125, + "ExpandBatchNorm": 0.0008754730224609375, + "ExpandISAMacro": 
0.004293918609619141, + "FactorizeBlkDims": 0.015153884887695313, + "FactorizeThreadAxesInFreeDims": 0.0017313957214355469, + "FlattenMacroLoop": 0.008043766021728516, + "GenericAccessSimplifier": 0.0003733634948730469, + "HloLegalizeToStablehloPass": 0.0004149999876972288, + "IdentifyCrossPassTensors": 9.000000318337698e-06, + "InferInitValue": 0.07077598571777344, + "InferIntrinsicOnCC": 0.0011551380157470703, + "InferNeuronTensor": 0.09625458717346191, + "InferNonlocalTensors": 0.0017993450164794922, + "InferPSumTensor": 0.017484188079833984, + "InlineNativeKernels": 0.0017468929290771484, + "InsertIOTransposes": 0.0004703998565673828, + "InsertLocalTransposes": 0.0007343292236328125, + "InsertOffloadedTransposes": 0.00041604042053222656, + "LICM": 0.005679607391357422, + "LateLegalizeInst": 0.015564680099487305, + "LateLegalizePostSplit": 0.005246400833129883, + "LateLowerReshapeOp": 0.0004372596740722656, + "LateLowerTensorOp": 0.0006518363952636719, + "LateNeuronInstComb": 0.0061359405517578125, + "LayoutPreprocessing": 0.002250194549560547, + "LayoutPreprocessingAndAnalysis": 0.0034074783325195313, + "LayoutRequirementAnalysis": 0.0009441375732421875, + "LegalizeCCOpLayout": 0.0006952285766601563, + "LegalizeOpLevelAlias": 0.0006194114685058594, + "LegalizePartitionReduce": 0.0011324882507324219, + "LegalizeSundaAccess": 0.01952672004699707, + "LegalizeSundaMacro": 0.009022712707519531, + "LegalizeType": 0.011729240417480469, + "LocalLayoutOpt": 0.0011942386627197266, + "LoopFusion": 0.000560760498046875, + "LoopSplitting": 0.00018835067749023438, + "LowerBroadcast": 0.002048015594482422, + "LowerComplexBroadcast": 0.0029702186584472656, + "LowerIntrinsics": 0.0019981861114501953, + "LowerTensorOp": 0.0007052421569824219, + "LowerTranspose": 0.05533957481384277, + "MLIRInstructionHistogram": 2.4000000848900527e-05, + "MacroGeneration": 0.005982398986816406, + "MaskPropagation": 0.0003821849822998047, + "MemcpyElimination": 0.0003476142883300781, + 
"MutateDataType": 0.0005035400390625, + "NeuronAliasDependencyInduction": 0.00400090217590332, + "NeuronAliasDependencyReset": 0.004381418228149414, + "NeuronInstComb": 0.00735926628112793, + "NeuronLICM": 0.008350372314453125, + "NeuronLoopFusion": 0.01720261573791504, + "NeuronLoopInterchange": 0.002377748489379883, + "NeuronSimplifier": 0.011650800704956055, + "NeuronSimplifyPredicates": 0.001893758773803711, + "NeuronValueNumbering": 0.0048487186431884766, + "OptimizeAliasedCopyChain": 0.0018398761749267578, + "OptimizeNKIKernels": 0.0019500255584716797, + "PAGLayoutOpt": 0.004791736602783203, + "PComputeCutting": 0.0012297630310058594, + "PGLayoutTilingPipeline": 2.1727216243743896, + "PGTiling": 0.013302803039550781, + "PadElimination": 0.00016045570373535156, + "ParAxesAnnotation": 0.0037343502044677734, + "PartialLoopFusion": 0.022133827209472656, + "PartialSimdFusion": 0.0010442733764648438, + "PerfectLoopNest": 0.0024068355560302734, + "PruneFunctions": 9.999999974752427e-07, + "RecognizeOpIdiom": 0.0002810955047607422, + "Recompute": 8.368492126464844e-05, + "RelaxPredicates": 0.004902362823486328, + "Rematerialization": 0.00031375885009765625, + "RemoveOptimizationBarriers": 1.2999999853491317e-05, + "ReshapeWeights": 0.0011010169982910156, + "ResolveAccessConflict": 0.0008833408355712891, + "ResolveComplicatePredicates": 0.00038886070251464844, + "RewriteReplicationMatmul": 0.0023565292358398438, + "RewriteWeights": 0.0023887157440185547, + "SFKVectorizer": 0.27486419677734375, + "SimpleAllReduceTiling": 0.0019252300262451172, + "Simplifier": 0.0002703666687011719, + "SimplifyMacroPredicates": 0.014539718627929688, + "SimplifyNeuronTensor": 0.0067865848541259766, + "SimplifySlice": 0.0001671314239501953, + "SimplifyTensor": 0.008214473724365234, + "SpillPSum": 0.016367673873901367, + "SplitAPUnionSets": 0.0723578929901123, + "SplitAccGrp": 0.001752614974975586, + "StableHLOCanonicalizeConv": 9.999999974752427e-07, + 
"StableHLOCanonicalizeForTensorizer": 1.5999999959603883e-05, + "StableHLOHoistCompute": 3.000000106112566e-06, + "StableHLOMemcastMotion": 9.999999974752427e-07, + "StableHLOPenguinizeFunctions": 3.400000059627928e-05, + "StableHLOScatterMotion": 9.999999974752427e-07, + "StableHLOTensorizerLegalizationPass": 8.299999899463728e-05, + "StaticProfiler": 0.019861459732055664, + "StaticTransposeLocalTensor": 0.0005624294281005859, + "SundaISel": 0.05923128128051758, + "TCTransform": 0.0001633167266845703, + "TensorInitialization": 0.002244710922241211, + "TensorOpSimplifier": 0.004949331283569336, + "TensorOpTransform": 0.01142573356628418, + "TilingProfiler": 0.019662141799926758, + "TransformConvOp": 0.0019693374633789063, + "TritiumFusion": 0.0010304450988769531, + "ValueNumbering": 0.000171661376953125, + "VectorizeDMA": 0.00382232666015625, + "VectorizeMatMult": 0.0006809234619140625, + "VerifySupportedOps": 1.5999999959603883e-05, + "WeightCoalescing": 0.001847982406616211, + "ZeroSizeTensorElimination": 0.0005872249603271484, + "algsimp": 4.3000000005122274e-05, + "batchnorm_expander": 1.4999999621068127e-05, + "boundary-marker-removal": 3.999999989900971e-06, + "call-inliner": 6.000000212225132e-06, + "canonicalize-boundary-marker": 1.1000000085914508e-05, + "collective-stream-id-checker": 9.999999974752427e-07, + "comparison-expander": 3.999999989900971e-06, + "computation-deduplicator": 1.9999999949504854e-06, + "config-lowering": 1.1000000085914508e-05, + "constant_folding": 9.000000136438757e-05, + "cse": 9.000000318337698e-06, + "dce": 9.999999974752427e-07, + "dynamic-slice-transpose": 3.999999989900971e-06, + "eliminate-redundant-compare": 3.999999989900971e-06, + "emit-offloaded-dropout": 2.4000000848900527e-05, + "flatten-call-graph": 9.000000318337698e-06, + "fuse-send-recv": 1.9999999494757503e-05, + "hilo-conditional-to-select": 3.000000106112566e-06, + "hilo::ConvertCustomCallToAllReducePass": 1.9999999949504854e-06, + "hilo::FusionToComposite": 
1.9999999949504854e-06, + "hilo::NeuronInstCombine": 4.5000000682193786e-05, + "hilo::StableHLOLegalizeAlias": 1.8000000636675395e-05, + "hilo::StableHLONeuronOpFusion": 7.999999979801942e-06, + "hilo::StableHLOReplaceTokenTypeWithU8Pass": 3.000000106112566e-06, + "hilo::StableHLOScheduleFusion": 1.9999999949504854e-06, + "hilo::StableHLOSixtyFourHack": 7.000000096013537e-06, + "hilo::StableHLOVerifyAliasing": 4.999999873689376e-06, + "hlo-kernel-info": 5.199999941396527e-05, + "hlo-mac-count": 4.199999966658652e-05, + "instruction-histogram": 1.1000000085914508e-05, + "io-con-pipe-begin": 9.999999974752427e-07, + "io-con-pipe-end": 0.0, + "io-layout-normalization": 6.000000212225132e-06, + "legalize-ccops-for-tensorizer": 9.999999974752427e-07, + "legalize-compare": 1.9999999949504854e-06, + "lower-argminmax-custom-call": 3.999999989900971e-06, + "map-inline": 7.999999979801942e-06, + "metadata-naming": 1.2999999853491317e-05, + "mlir::detail::OpToOpPassAdaptor": 9.999999974752427e-07, + "mlir::hlo::StableHLOToPyPenguin": 0.0010610000463202596, + "mlir::stablehlo::LowerComplexExtraPass": 4.8999998398358e-05, + "mlir::stablehlo::LowerComplexPass": 6.299999949987978e-05, + "native-to-custom-softmax": 4.999999873689376e-06, + "native-to-custom-softmax-dx": 6.000000212225132e-06, + "neuron-hlo-inst-comb": 6.000000212225132e-06, + "neuron-hlo-verifier": 0.0003809999907389283, + "operand_upcaster": 9.000000318337698e-06, + "post-par-pipe-begin": 0.0, + "post-par-pipe-end": 0.0, + "post-partition-simplification": 0.00035700001171790063, + "pre-hlo-begin": 0.0, + "pre-hlo-end": 0.0, + "replace-minimum-constant": 3.999999989900971e-06, + "reshape-mover": 4.999999873689376e-06, + "simplify-concat": 0.00013899999612476677, + "simplify-while-loops": 4.999999873689376e-06, + "transform-variadic-reduce": 1.8999999156221747e-05, + "tuple-simplifier": 4.999999873689376e-06, + "unpack-nested-aws-ntwsr": 3.000000106112566e-06, + "unroll-while-loop": 0.0 + }, + "hilo": { + 
"HloMacCount": 0.0, + "KernelCount": 0.0, + "KernelMacCount": 0.0, + "KernelTypes": 0.0, + "Traffic": 1118061568.0 + }, + "tensorizer": { + "DMATilingProfiler::TotalInstructionsAfterTiling": 14427, + "StaticProfiler::AifUb": 26.832317352294922, + "StaticProfiler::ArithmeticIntensityTensorizer": 67.99825286865234, + "StaticProfiler::AverageDmaLength": 5655.53466796875, + "StaticProfiler::DDRTransferBytes": 882419712, + "StaticProfiler::InternalTransferBytes": 469079040, + "StaticProfiler::LoadExpanded": 76746, + "StaticProfiler::StoreExpanded": 77952, + "StaticProfiler::TotalDMAExpanded": 154698, + "StaticProfiler::TotalDynamicInstancesCount": 19134, + "StaticProfiler::TotalDynamicInstancesWithMmPackedCount": 19134, + "StaticProfiler::TotalLNCComm": 0, + "StaticProfiler::TotalLNCCommTransfer": 0, + "TilingProfiler::BatchnormInstructionsAfterTiling": 0, + "TilingProfiler::DmaInstructionsAfterTiling": 0, + "TilingProfiler::GenericInstructionsAfterTiling": 0, + "TilingProfiler::MatMultInstructionsAfterTiling": 0, + "TilingProfiler::NumPfTransposes": 45, + "TilingProfiler::NumPfTransposesForIo": 45, + "TilingProfiler::NumPfTransposesForLocal": 0, + "TilingProfiler::NumPfTransposesForNonlocal": 0, + "TilingProfiler::PfTransposeInstructions": 12617, + "TilingProfiler::PfTransposeInstructionsForIo": 12617, + "TilingProfiler::PfTransposeInstructionsForLocal": 0, + "TilingProfiler::PfTransposeInstructionsForNonlocal": 0, + "TilingProfiler::ReduceInstructionsAfterTiling": 0, + "TilingProfiler::SimdInstructionsAfterTiling": 1218, + "TilingProfiler::TotalInstructionsAfterTiling": 0, + "TransformConvOp::Conv1d_depthwise_bf01_oi01_bf01": 0, + "TransformConvOp::Conv2d_dw_fb01_io01_01bf_rep_nhwc_Pcinh": 0, + "TransformConvOp::Conv2d_pbp_0f1b_0i1o_01fb_experimental_1": 0, + "TransformConvOp::Conv2d_pbp_fb01_io01_01bf_experimental_1": 0, + "TransformConvOp::conv2d_column_packing": 0, + "TransformConvOp::conv2d_column_packing_1": 0, + "TransformConvOp::conv2d_column_packing_io10": 0, 
+ "TransformConvOp::conv2d_depthwise_f01b_o01i_bf01": 0 + } + }, + "all": { + "compiletime": { + "Canonicalizer": 0.00014099999680183828, + "HloLegalizeToStablehloPass": 0.0004149999876972288, + "IdentifyCrossPassTensors": 9.000000318337698e-06, + "MLIRInstructionHistogram": 2.4000000848900527e-05, + "PruneFunctions": 9.999999974752427e-07, + "RemoveOptimizationBarriers": 1.2999999853491317e-05, + "StableHLOCanonicalizeConv": 9.999999974752427e-07, + "StableHLOCanonicalizeForTensorizer": 1.5999999959603883e-05, + "StableHLOHoistCompute": 3.000000106112566e-06, + "StableHLOMemcastMotion": 9.999999974752427e-07, + "StableHLOPenguinizeFunctions": 3.400000059627928e-05, + "StableHLOScatterMotion": 9.999999974752427e-07, + "StableHLOTensorizerLegalizationPass": 8.299999899463728e-05, + "VerifySupportedOps": 1.5999999959603883e-05, + "algsimp": 4.3000000005122274e-05, + "batchnorm_expander": 1.4999999621068127e-05, + "boundary-marker-removal": 3.999999989900971e-06, + "call-inliner": 6.000000212225132e-06, + "canonicalize-boundary-marker": 1.1000000085914508e-05, + "collective-stream-id-checker": 9.999999974752427e-07, + "comparison-expander": 3.999999989900971e-06, + "computation-deduplicator": 1.9999999949504854e-06, + "config-lowering": 1.1000000085914508e-05, + "constant_folding": 9.000000136438757e-05, + "cse": 9.000000318337698e-06, + "dce": 9.999999974752427e-07, + "dynamic-slice-transpose": 3.999999989900971e-06, + "eliminate-redundant-compare": 3.999999989900971e-06, + "emit-offloaded-dropout": 2.4000000848900527e-05, + "flatten-call-graph": 9.000000318337698e-06, + "fuse-send-recv": 1.9999999494757503e-05, + "hilo-conditional-to-select": 3.000000106112566e-06, + "hilo::ConvertCustomCallToAllReducePass": 1.9999999949504854e-06, + "hilo::FusionToComposite": 1.9999999949504854e-06, + "hilo::NeuronInstCombine": 4.5000000682193786e-05, + "hilo::StableHLOLegalizeAlias": 1.8000000636675395e-05, + "hilo::StableHLONeuronOpFusion": 7.999999979801942e-06, + 
"hilo::StableHLOReplaceTokenTypeWithU8Pass": 3.000000106112566e-06, + "hilo::StableHLOScheduleFusion": 1.9999999949504854e-06, + "hilo::StableHLOSixtyFourHack": 7.000000096013537e-06, + "hilo::StableHLOVerifyAliasing": 4.999999873689376e-06, + "hlo-kernel-info": 5.199999941396527e-05, + "hlo-mac-count": 4.199999966658652e-05, + "instruction-histogram": 1.1000000085914508e-05, + "io-con-pipe-begin": 9.999999974752427e-07, + "io-con-pipe-end": 0.0, + "io-layout-normalization": 6.000000212225132e-06, + "legalize-ccops-for-tensorizer": 9.999999974752427e-07, + "legalize-compare": 1.9999999949504854e-06, + "lower-argminmax-custom-call": 3.999999989900971e-06, + "map-inline": 7.999999979801942e-06, + "metadata-naming": 1.2999999853491317e-05, + "mlir::detail::OpToOpPassAdaptor": 9.999999974752427e-07, + "mlir::hlo::StableHLOToPyPenguin": 0.0010610000463202596, + "mlir::stablehlo::LowerComplexExtraPass": 4.8999998398358e-05, + "mlir::stablehlo::LowerComplexPass": 6.299999949987978e-05, + "native-to-custom-softmax": 4.999999873689376e-06, + "native-to-custom-softmax-dx": 6.000000212225132e-06, + "neuron-hlo-inst-comb": 6.000000212225132e-06, + "neuron-hlo-verifier": 0.0003809999907389283, + "operand_upcaster": 9.000000318337698e-06, + "post-par-pipe-begin": 0.0, + "post-par-pipe-end": 0.0, + "post-partition-simplification": 0.00035700001171790063, + "pre-hlo-begin": 0.0, + "pre-hlo-end": 0.0, + "replace-minimum-constant": 3.999999989900971e-06, + "reshape-mover": 4.999999873689376e-06, + "simplify-concat": 0.00013899999612476677, + "simplify-while-loops": 4.999999873689376e-06, + "transform-variadic-reduce": 1.8999999156221747e-05, + "tuple-simplifier": 4.999999873689376e-06, + "unpack-nested-aws-ntwsr": 3.000000106112566e-06, + "unroll-while-loop": 0.0 + } + }, + "sg00": { + "hilo": { + "ArithmeticIntensity": 0.0, + "HloMacCount": 0.0, + "KernelCount": 0.0, + "KernelMacCount": 0.0, + "KernelTypes": 0.0, + "Traffic": 1118061568.0 + } + }, + "sg0000": { + "compiletime": { + 
"AGOrderingAnalysisPass": 0.003949403762817383, + "AffinePredicateResolution": 0.00038909912109375, + "AliasDependencyElimination": 0.00012993812561035156, + "AliasDependencyInduction": 0.0007646083831787109, + "AliasDependencyReset": 0.0010361671447753906, + "BFComputeCutting": 0.0006632804870605469, + "BirCodeGenLoop": 0.06791448593139648, + "CCOpFusion": 0.009911775588989258, + "CanonicalizeDAGForPGTiling": 0.00043392181396484375, + "CanonicalizeIR": 0.00042510032653808594, + "CoalesceCCOp": 0.0020368099212646484, + "CommuteConcat": 0.0005238056182861328, + "DMALocalityOpt": 0.0015301704406738281, + "DMAProfiler": 0.005443572998046875, + "DMATilingProfiler": 0.006621837615966797, + "DataLocalityOpt": 0.10823822021484375, + "DataStreaming": 0.003862142562866211, + "DeConcat": 0.0003120899200439453, + "DeadCodeElimination": 0.0003216266632080078, + "DeadStoreElimination": 0.0003542900085449219, + "DelinearIndices": 0.0002429485321044922, + "Delinearization": 0.0002715587615966797, + "DoNothing": 6.008148193359375e-05, + "DramToDramTranspose": 2.1463124752044678, + "DumpGraphAndMetadata": 0.00661015510559082, + "EliminateDivs": 0.0003745555877685547, + "EnforceAluDTAcc": 0.0011157989501953125, + "ExpandBatchNorm": 0.0008754730224609375, + "ExpandISAMacro": 0.004293918609619141, + "FactorizeBlkDims": 0.015153884887695313, + "FactorizeThreadAxesInFreeDims": 0.0017313957214355469, + "FlattenMacroLoop": 0.008043766021728516, + "GenericAccessSimplifier": 0.0003733634948730469, + "InferInitValue": 0.07077598571777344, + "InferIntrinsicOnCC": 0.0011551380157470703, + "InferNeuronTensor": 0.09625458717346191, + "InferNonlocalTensors": 0.0017993450164794922, + "InferPSumTensor": 0.017484188079833984, + "InlineNativeKernels": 0.0017468929290771484, + "InsertIOTransposes": 0.0004703998565673828, + "InsertLocalTransposes": 0.0007343292236328125, + "InsertOffloadedTransposes": 0.00041604042053222656, + "LICM": 0.005679607391357422, + "LateLegalizeInst": 0.015564680099487305, + 
"LateLegalizePostSplit": 0.005246400833129883, + "LateLowerReshapeOp": 0.0004372596740722656, + "LateLowerTensorOp": 0.0006518363952636719, + "LateNeuronInstComb": 0.0061359405517578125, + "LayoutPreprocessing": 0.002250194549560547, + "LayoutPreprocessingAndAnalysis": 0.0034074783325195313, + "LayoutRequirementAnalysis": 0.0009441375732421875, + "LegalizeCCOpLayout": 0.0006952285766601563, + "LegalizeOpLevelAlias": 0.0006194114685058594, + "LegalizePartitionReduce": 0.0011324882507324219, + "LegalizeSundaAccess": 0.01952672004699707, + "LegalizeSundaMacro": 0.009022712707519531, + "LegalizeType": 0.011729240417480469, + "LocalLayoutOpt": 0.0011942386627197266, + "LoopFusion": 0.000560760498046875, + "LoopSplitting": 0.00018835067749023438, + "LowerBroadcast": 0.002048015594482422, + "LowerComplexBroadcast": 0.0029702186584472656, + "LowerIntrinsics": 0.0019981861114501953, + "LowerTensorOp": 0.0007052421569824219, + "LowerTranspose": 0.05533957481384277, + "MacroGeneration": 0.005982398986816406, + "MaskPropagation": 0.0003821849822998047, + "MemcpyElimination": 0.0003476142883300781, + "MutateDataType": 0.0005035400390625, + "NeuronAliasDependencyInduction": 0.00400090217590332, + "NeuronAliasDependencyReset": 0.004381418228149414, + "NeuronInstComb": 0.00735926628112793, + "NeuronLICM": 0.008350372314453125, + "NeuronLoopFusion": 0.01720261573791504, + "NeuronLoopInterchange": 0.002377748489379883, + "NeuronSimplifier": 0.011650800704956055, + "NeuronSimplifyPredicates": 0.001893758773803711, + "NeuronValueNumbering": 0.0048487186431884766, + "OptimizeAliasedCopyChain": 0.0018398761749267578, + "OptimizeNKIKernels": 0.0019500255584716797, + "PAGLayoutOpt": 0.004791736602783203, + "PComputeCutting": 0.0012297630310058594, + "PGLayoutTilingPipeline": 2.1727216243743896, + "PGTiling": 0.013302803039550781, + "PadElimination": 0.00016045570373535156, + "ParAxesAnnotation": 0.0037343502044677734, + "PartialLoopFusion": 0.022133827209472656, + "PartialSimdFusion": 
0.0010442733764648438, + "PerfectLoopNest": 0.0024068355560302734, + "RecognizeOpIdiom": 0.0002810955047607422, + "Recompute": 8.368492126464844e-05, + "RelaxPredicates": 0.004902362823486328, + "Rematerialization": 0.00031375885009765625, + "ReshapeWeights": 0.0011010169982910156, + "ResolveAccessConflict": 0.0008833408355712891, + "ResolveComplicatePredicates": 0.00038886070251464844, + "RewriteReplicationMatmul": 0.0023565292358398438, + "RewriteWeights": 0.0023887157440185547, + "SFKVectorizer": 0.27486419677734375, + "SimpleAllReduceTiling": 0.0019252300262451172, + "Simplifier": 0.0002703666687011719, + "SimplifyMacroPredicates": 0.014539718627929688, + "SimplifyNeuronTensor": 0.0067865848541259766, + "SimplifySlice": 0.0001671314239501953, + "SimplifyTensor": 0.008214473724365234, + "SpillPSum": 0.016367673873901367, + "SplitAPUnionSets": 0.0723578929901123, + "SplitAccGrp": 0.001752614974975586, + "StaticProfiler": 0.019861459732055664, + "StaticTransposeLocalTensor": 0.0005624294281005859, + "SundaISel": 0.05923128128051758, + "TCTransform": 0.0001633167266845703, + "TensorInitialization": 0.002244710922241211, + "TensorOpSimplifier": 0.004949331283569336, + "TensorOpTransform": 0.01142573356628418, + "TilingProfiler": 0.019662141799926758, + "TransformConvOp": 0.0019693374633789063, + "TritiumFusion": 0.0010304450988769531, + "ValueNumbering": 0.000171661376953125, + "VectorizeDMA": 0.00382232666015625, + "VectorizeMatMult": 0.0006809234619140625, + "WeightCoalescing": 0.001847982406616211, + "ZeroSizeTensorElimination": 0.0005872249603271484 + }, + "tensorizer": { + "DMATilingProfiler::TotalInstructionsAfterTiling": 14427, + "StaticProfiler::AifUb": 26.832317352294922, + "StaticProfiler::ArithmeticIntensityTensorizer": 67.99825286865234, + "StaticProfiler::AverageDmaLength": 5655.53466796875, + "StaticProfiler::AverageFractalPeUtilization": 99.92024230957031, + "StaticProfiler::AveragePartitionUtilization": 99.880615234375, + 
"StaticProfiler::AveragePeUtilization": 99.92024230957031, + "StaticProfiler::DDRTransferBytes": 882419712, + "StaticProfiler::InternalTransferBytes": 469079040, + "StaticProfiler::LoadExpanded": 76746, + "StaticProfiler::LocalizationEfficiency": 253.41925048828125, + "StaticProfiler::LocalizationEfficiencyIgnoreNonlocal": 253.41925048828125, + "StaticProfiler::StoreExpanded": 77952, + "StaticProfiler::TotalDMAExpanded": 154698, + "StaticProfiler::TotalDynamicInstancesCount": 19134, + "StaticProfiler::TotalDynamicInstancesWithMmPackedCount": 19134, + "StaticProfiler::TotalLNCComm": 0, + "StaticProfiler::TotalLNCCommTransfer": 0, + "TilingProfiler::AveragePartitionUtilizationAfterTiling": 0, + "TilingProfiler::AveragePeUtilizationAfterTiling": 0, + "TilingProfiler::BatchnormInstructionsAfterTiling": 0, + "TilingProfiler::DmaInstructionsAfterTiling": 0, + "TilingProfiler::GenericInstructionsAfterTiling": 0, + "TilingProfiler::MatMultInstructionsAfterTiling": 0, + "TilingProfiler::NumPfTransposes": 45, + "TilingProfiler::NumPfTransposesForIo": 45, + "TilingProfiler::NumPfTransposesForLocal": 0, + "TilingProfiler::NumPfTransposesForNonlocal": 0, + "TilingProfiler::PfTransposeInstructions": 12617, + "TilingProfiler::PfTransposeInstructionsForIo": 12617, + "TilingProfiler::PfTransposeInstructionsForLocal": 0, + "TilingProfiler::PfTransposeInstructionsForNonlocal": 0, + "TilingProfiler::ReduceInstructionsAfterTiling": 0, + "TilingProfiler::SimdInstructionsAfterTiling": 1218, + "TilingProfiler::TotalInstructionsAfterTiling": 0, + "TransformConvOp::Conv1d_depthwise_bf01_oi01_bf01": 0, + "TransformConvOp::Conv2d_dw_fb01_io01_01bf_rep_nhwc_Pcinh": 0, + "TransformConvOp::Conv2d_pbp_0f1b_0i1o_01fb_experimental_1": 0, + "TransformConvOp::Conv2d_pbp_fb01_io01_01bf_experimental_1": 0, + "TransformConvOp::conv2d_column_packing": 0, + "TransformConvOp::conv2d_column_packing_1": 0, + "TransformConvOp::conv2d_column_packing_io10": 0, + 
"TransformConvOp::conv2d_depthwise_f01b_o01i_bf01": 0 + } + } +} diff --git a/contrib/models/gemma3-vision/src/gemma3_vision/__init__.py b/contrib/models/gemma3-vision/src/gemma3_vision/__init__.py new file mode 100644 index 00000000..7a382233 --- /dev/null +++ b/contrib/models/gemma3-vision/src/gemma3_vision/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2025 © Amazon.com and Affiliates + +from .modeling_gemma3 import ( + NeuronGemma3ForConditionalGeneration, + Gemma3InferenceConfig, + TextGemma3InferenceConfig, + NeuronTextGemma3ForCausalLM, +) +from .modeling_gemma3_vision import ( + NeuronGemma3VisionModel, + NeuronGemma3MultiModalProjector, + Gemma3VisionModelWrapper, +) +from .modeling_gemma3_text import ( + NeuronGemma3TextModel, +) + +__all__ = [ + "NeuronGemma3ForConditionalGeneration", + "Gemma3InferenceConfig", + "NeuronGemma3VisionModel", + "NeuronGemma3MultiModalProjector", + "Gemma3VisionModelWrapper", + "NeuronGemma3TextModel", + "TextGemma3InferenceConfig", + "NeuronTextGemma3ForCausalLM", +] diff --git a/contrib/models/gemma3-vision/src/gemma3_vision/modeling_gemma3.py b/contrib/models/gemma3-vision/src/gemma3_vision/modeling_gemma3.py new file mode 100644 index 00000000..f109905d --- /dev/null +++ b/contrib/models/gemma3-vision/src/gemma3_vision/modeling_gemma3.py @@ -0,0 +1,607 @@ +from gemma3_vision.ndxi_patch import apply_patch +apply_patch() + +import copy # noqa: E402 +import math # noqa: E402 +import logging # noqa: E402 +from typing import Callable, Dict, List, Optional, Tuple, Type, Union, Any # noqa: E402 + +import torch +import torch.nn.functional as F +import torch.nn.utils.rnn as rnn_utils +from transformers.modeling_outputs import CausalLMOutputWithPast + +from neuronx_distributed.quantization.quantization_utils import convert_qint8_to_int8_state_dict +import neuronx_distributed_inference.modules.autobucketing as autobucketing +from neuronx_distributed_inference.models.config import InferenceConfig, NeuronConfig +from 
neuronx_distributed_inference.models.image_to_text_model_base import ( + ImageToTextInferenceConfig, + NeuronBaseForImageToText +) +from neuronx_distributed_inference.models.image_to_text_model_wrapper import ( + ImageToTextModelWrapper, + IMAGE_TO_TEXT_MODEL_WRAPPER_INPUT_KEYS +) +from neuronx_distributed_inference.models.llama4.utils.encoder_utils import pad_vision_embeddings +from neuronx_distributed_inference.models.model_wrapper import ( + CONTEXT_ENCODING_MODEL_TAG, + TOKEN_GENERATION_MODEL_TAG, + VISION_ENCODER_MODEL_TAG +) +from neuronx_distributed_inference.modules.flashdecode.utils import calculate_num_cores_per_group +from neuronx_distributed_inference.models.model_base import NeuronBaseForCausalLM + +from gemma3_vision.modeling_gemma3_text import NeuronGemma3TextModel +from gemma3_vision.modeling_gemma3_vision import NeuronGemma3VisionModel, Gemma3VisionModelWrapper +from gemma3_vision.utils import convert_state_dict_to_fused_qkv, StateDict + +logger = logging.getLogger("Neuron") + + +class Gemma3InferenceConfig(ImageToTextInferenceConfig): + def __init__( + self, + text_neuron_config, + vision_neuron_config, + fused_spec_config=None, + load_config=None, + metadata: Optional[Dict] = None, + **kwargs, + ): + super().__init__( + text_neuron_config=text_neuron_config, + vision_neuron_config=vision_neuron_config, + fused_spec_config=fused_spec_config, + load_config=load_config, + metadata=metadata, + **kwargs, + ) + + # NeuronLlamaMLP expects the activation type to be at text_config.hidden_act + # Enable to fully reuse NeuronLlamaMLP + if not hasattr(self.text_config, "hidden_act"): + self.text_config.hidden_act = self.text_config.hidden_activation + del self.text_config.hidden_activation + + if self.text_config.neuron_config.is_block_kv_layout: + raise ValueError("Gemma3 does not yet support block_kv_layout.") + if self.text_config.neuron_config.is_prefix_caching: + raise ValueError("Gemma3 does not yet support prefix_caching.") + if 
self.text_config.neuron_config.is_chunked_prefill: + raise ValueError("Gemma3 does not yet support chunked_prefill.") + if self.text_config.neuron_config.is_medusa: + raise ValueError("Gemma3 does not yet support medusa.") + if self.text_config.neuron_config.enable_fused_speculation: + raise ValueError("Gemma3 does not yet support fused speculation.") + + if self.neuron_config.flash_decoding_enabled: + # Following pixtral implementation, we use REPLICATE_TO_TP_DEGREE as the sharding_strategy + # Hence attn_heads are padded to become divisible by tp_degree + num_attn_heads, num_kv_heads = self.text_config.num_attention_heads, self.text_config.num_key_value_heads + num_attn_heads = (num_attn_heads // self.neuron_config.tp_degree + 1) * self.neuron_config.tp_degree + self.text_config.num_cores_per_group = calculate_num_cores_per_group( + num_attn_heads, num_kv_heads, self.neuron_config.tp_degree + ) + + def get_required_attributes(self) -> List[str]: + return [ + "text_config", + "vision_config", + "text_config.head_dim", # for gemma3, head_dim != hidden_size // num_attention_heads + "text_config.hidden_size", + "text_config.num_attention_heads", + "text_config.num_hidden_layers", + "text_config.num_key_value_heads", + "text_config.query_pre_attn_scalar", + "text_config.rope_scaling", + "text_config.sliding_window", + "vision_config.hidden_size", + "vision_config.image_size", + "vision_config.num_attention_heads", + "vision_config.num_hidden_layers", + "vision_config.patch_size", + ] + + @classmethod + def get_neuron_config_cls(cls) -> Type[NeuronConfig]: + return NeuronConfig + + +class TextGemma3InferenceConfig(InferenceConfig): + + def __init__( + self, + neuron_config: NeuronConfig, + fused_spec_config=None, + load_config=None, + metadata: Optional[Dict] = None, + **kwargs + ): + super().__init__( + neuron_config=neuron_config, + fused_spec_config=fused_spec_config, + load_config=load_config, + metadata=metadata, + **kwargs, + ) + + # NeuronLlamaMLP expects the 
activation type to be at hidden_act + # Enable to fully reuse NeuronLlamaMLP + if not hasattr(self, "hidden_act"): + self.hidden_act = self.hidden_activation + del self.hidden_activation + + def get_required_attributes(self) -> List[str]: + return [ + "head_dim", # for gemma3, head_dim != hidden_size // num_attention_heads + "hidden_size", + "num_attention_heads", + "num_hidden_layers", + "num_key_value_heads", + "query_pre_attn_scalar", + "rope_scaling", + "sliding_window", + ] + + +class NeuronGemma3ForConditionalGeneration(NeuronBaseForImageToText): + # model cls + text_model_cls = NeuronGemma3TextModel + vision_model_cls = NeuronGemma3VisionModel + + # model wrappers + text_model_wrapper = ImageToTextModelWrapper + vision_model_wrapper = Gemma3VisionModelWrapper + + def __init__(self, *args, **kwargs): + super().__init__( + self.text_model_cls, + self.vision_model_cls, + self.text_model_wrapper, + self.vision_model_wrapper, + *args, + **kwargs, + ) + + @classmethod + def get_config_cls(cls): + # Gemma3-specific + return Gemma3InferenceConfig + + def enable_vision_encoder(self, enable_wlt_optimization: bool = True, **model_init_kwargs): + # Identical to NeuronPixtralForCausalLM.enable_vision_encoder + # - except use get_compiler_args + VISION_ENCODER_MODEL_TAG (instead of get_vision_compiler_args) + # like NeuronLlama4ForCausalLM.enable_vision_encoder + self.compile_tag = VISION_ENCODER_MODEL_TAG + + new_config = copy.deepcopy(self.config) + if new_config.vision_config.neuron_config.enable_bucketing: + # neuron_config.buckets defaults to [neuron_config.seq_len] if not given. For vision we want to do auto-bucketing here + if new_config.vision_config.neuron_config.buckets == [new_config.vision_config.neuron_config.seq_len] or \ + new_config.vision_config.neuron_config.buckets is None: + # 1024 vision seq len corresponds to a single 512x512 image. Smaller bucket sizes are not useful in practice. 
+ if new_config.vision_config.neuron_config.seq_len > 1024: + new_config.vision_config.neuron_config.buckets = autobucketing.generate_buckets( + 1024, new_config.vision_config.neuron_config.seq_len + ) + else: + new_config.vision_config.neuron_config.buckets = [new_config.vision_config.neuron_config.seq_len] + # This should not be needed, as the vision modeling code should always use vision_config.neuron_config as the vision model's neuron config + # added as a safeguard to avoid any mix-up + new_config.neuron_config = copy.deepcopy(new_config.vision_config.neuron_config) + + self.vision_encoder_model = self.vision_model_wrapper( + config=new_config, + model_cls=self.vision_model_cls, + tag=VISION_ENCODER_MODEL_TAG, + compiler_args=self.get_compiler_args(), + model_init_kwargs=model_init_kwargs, + # to turn on weight layout optimization + priority_model_idx=(0 if enable_wlt_optimization else None), + pipeline_execution=True, + return_ranked_to_cpu=True + ) + self.vision_models.append(self.vision_encoder_model) + + @staticmethod + def update_state_dict_for_tied_weights(state_dict: StateDict) -> None: + # Gemma3-specific + try: + state_dict["lm_head.weight"] = state_dict["embed_tokens.weight"].clone() + except KeyError: + state_dict["embed_tokens.weight"] = state_dict["lm_head.weight"].clone() + + @staticmethod + def convert_hf_to_neuron_state_dict(state_dict: StateDict, inference_config: InferenceConfig) -> StateDict: + # Gemma3-specific + neuron_config = inference_config.neuron_config + attention_keys = { + ".self_attn.q_proj.": ".self_attn.qkv_proj.q_proj.", + ".self_attn.k_proj.": ".self_attn.qkv_proj.k_proj.", + ".self_attn.v_proj.": ".self_attn.qkv_proj.v_proj.", + ".self_attn.o_proj.": ".self_attn.o_proj.o_proj.", + ".self_attn.out_proj.": ".self_attn.o_proj.o_proj.", # for siglip + ".self_attn.q_norm.": ".self_attn.q_layernorm.", + ".self_attn.k_norm.": ".self_attn.k_layernorm.", + } + + # At the time of writing, NxDI (Neuron 2.26) attention layer 
does not provide a simple way to use a custom + # scaling factor for raw attention scores (QK^T) while ensuring all optimizations (e.g. kernels) remain available. + # To work around this, we fuse the scaling factor into the weights (knowing that the attention layer will use the + # default 1 / math.sqrt(inference_config.head_dim) value) + default_qk_scaling_factor_inv = math.sqrt(float(inference_config.text_config.head_dim)) + gemma_qk_scaling_factor = 1.0 / math.sqrt(float(inference_config.text_config.query_pre_attn_scalar)) + gamma = math.sqrt(gemma_qk_scaling_factor * default_qk_scaling_factor_inv) + + new_state_dict = {} + for key, weights in state_dict.items(): + if 'language_model.model.' in key: + key = key.replace('language_model.model.', "") + for atten_key in attention_keys: + if atten_key in key: + replacement_atten_key = attention_keys[atten_key] + key = key.replace(atten_key, replacement_atten_key) + break + if key.endswith((".q_proj.weight", ".k_proj.weight")): + orig_dtype = weights.dtype + weights = (weights.to(dtype=torch.float32) * gamma).to(dtype=orig_dtype) + if 'language_model.lm_head.' in key: + key = key.replace('language_model.', "") + if 'vision_tower.' 
in key: + key = key.replace('vision_tower.', 'vision_encoder.') + for atten_key in attention_keys: + if atten_key in key: + replacement_atten_key = attention_keys[atten_key] + key = key.replace(atten_key, replacement_atten_key) + break + new_state_dict[key] = weights + + # If LNC > 1, the model requires lm_head.bias, which corresponds to lm_head_pad + if "language_model.lm_head.bias" not in state_dict and inference_config.neuron_config.lm_head_pad: + # Use embed_tokens.weight instead of lm_head.weight as lm_head.weight is tied to embed_tokens.weight in Gemma3 + new_state_dict["lm_head.bias"] = torch.zeros(new_state_dict["embed_tokens.weight"].shape[0], dtype=torch.float32) + + if inference_config.text_config.neuron_config.fused_qkv: + new_state_dict = convert_state_dict_to_fused_qkv( + state_dict=new_state_dict, + num_layers=inference_config.text_config.num_hidden_layers, + neuron_config=inference_config.text_config.neuron_config, + prefix="layers.{layer_num}.self_attn" + ) + + if inference_config.vision_config.neuron_config.fused_qkv: + new_state_dict = convert_state_dict_to_fused_qkv( + state_dict=new_state_dict, + num_layers=inference_config.vision_config.num_hidden_layers, + neuron_config=inference_config.vision_config.neuron_config, + prefix="vision_encoder.vision_model.encoder.layers.{layer_num}.self_attn" + ) + + if neuron_config.vocab_parallel: + new_state_dict["embed_tokens.rank_util.rank"] = torch.arange(0, neuron_config.local_ranks_size) + + tp_degree = neuron_config.tp_degree + for i in range(inference_config.text_config.num_hidden_layers): + new_state_dict[f"layers.{i}.self_attn.rank_util.rank"] = torch.arange(0, tp_degree, dtype=torch.int32) + + new_state_dict["rank_util.rank"] = torch.arange(0, tp_degree, dtype=torch.int32) + + return new_state_dict + + @staticmethod + def _convert_input_dict_to_ordered_tuple(input_dict: Dict[str, Any]): + # Identical to NeuronLlama4ForCausalLM._convert_input_dict_to_ordered_tuple, to be removed? 
+ """ + Utility function to convert input dictionary to ordered tuple + based on outputs of _get_model_outputs + """ + args = [] + + for key in IMAGE_TO_TEXT_MODEL_WRAPPER_INPUT_KEYS: + if key in input_dict and input_dict[key] is not None: + arg = input_dict[key] + else: + arg = torch.empty(0) + args.append(arg) + + return tuple(args) + + def _select_buckets_for_padding_length(self, position_ids): + # Identical to NeuronLlama4ForCausalLM._select_buckets_for_padding_length + neuron_config = self.config.neuron_config + context_encoding_buckets = neuron_config.context_encoding_buckets if neuron_config.context_encoding_buckets is not None \ + else neuron_config.buckets + token_generation_buckets = neuron_config.token_generation_buckets if neuron_config.token_generation_buckets is not None \ + else neuron_config.buckets + + selected_buckets = token_generation_buckets + if self._is_prefill(position_ids): + selected_buckets = context_encoding_buckets + + return selected_buckets + + @staticmethod + def get_padding_length(buckets, position_ids): + # Identical to [NeuronLlama4ForCausalLM|NeuronPixtralForCausalLM]._select_buckets_for_padding_length + max_position_id = torch.max(position_ids).item() + for val in buckets: + if val > max_position_id: + return val + raise ValueError("No bucket found for provided input_ids!") + + @staticmethod + def get_required_kwargs() -> List[str]: + # Gemma3-specific + """The list of additional input arguments to be prepared in HuggingFaceGenerationAdapter.prepare_inputs_for_generation()""" + return [ + "pixel_values", + "vision_mask", + ] + + @staticmethod + def generate_positions_from_mask(mask: torch.Tensor) -> torch.Tensor: + # Gemma3-specific + """ + Generate position indices from a boolean mask. + Compared to generate_positions_from_mask() of models/llama4/utils/encoder_utils.py, + this function can generate 1D or 2D masks to support batch size > 1. 
+ + Args: + mask (torch.Tensor): A 1D or 2D boolean tensor + + Returns: + torch.Tensor: A 1D or 2D tensor containing the indices where the mask is True + """ + if mask.dim() == 1: + return torch.nonzero(mask).squeeze() + else: + rows, cols = torch.nonzero(mask, as_tuple=True) + row_counts = torch.bincount(rows, minlength=mask.shape[0]) + cols_per_row = torch.split(cols, row_counts.tolist()) + return rnn_utils.pad_sequence(cols_per_row, batch_first=True, padding_value=0) + + @staticmethod + def pad_positions(positions: torch.LongTensor, target_size: int, fill_value: float) -> torch.LongTensor: + """ + Pad the positions tensor to a target size. + Compared to pad_positions() of models/llama4/utils/encoder_utils.py, + this function can support batch size > 1. + + Args: + positions (torch.Tensor): A 1D or 2D tensor containing position indices + target_size (int): The desired size of the padded tensor + fill_value (int): The value used for padding + + Returns: + torch.Tensor: A 3D tensor of shape (batch_size, target_size, 1) containing padded position indices + """ + # positions_2d of shape (batch_sz, seq_len) + positions_2d = positions.unsqueeze(0) if positions.dim() == 1 else positions + padding_size = target_size - positions_2d.shape[1] + assert padding_size >= 0, "Text model sequence length is not enough to handle all vision embeddings" + positions_padded = F.pad(positions_2d, (0, padding_size), value=fill_value) + # output tensor of shape (batch_sz, target_sz, 1) + return positions_padded.unsqueeze(-1) + + @staticmethod + def _create_position_ids(attention_mask_2d: torch.LongTensor, is_prefill: bool) -> torch.LongTensor: + position_ids = attention_mask_2d.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask_2d == 0, 1) + if is_prefill: + return position_ids + else: + return torch.amax(position_ids, dim=1, keepdim=True) + 1 + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: 
Optional[torch.LongTensor] = None, + seq_ids: Optional[torch.LongTensor] = None, + sampling_params: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + vision_mask: Optional[torch.FloatTensor] = None, + image_sizes: Optional[torch.FloatTensor] = None, + adapter_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + medusa_args=None, + input_capture_hook: Optional[Callable] = None, + tensor_capture_hook: Optional[Callable] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + # Very close to NeuronLlama4ForCausalLM.forward + is_prefill = (input_ids.shape[-1] > 1) + include_images = (pixel_values is not None) and (vision_mask is not None) and (pixel_values.sum() != 0) + + if position_ids is None: + position_ids = self._create_position_ids(attention_mask_2d=attention_mask, is_prefill=is_prefill) + + buckets = self._select_buckets_for_padding_length(position_ids=position_ids) + pad_target_size = self.get_padding_length(buckets=buckets, position_ids=position_ids) + pad_fill_value = (pad_target_size - 1) + if (is_prefill and include_images): + assert ( + vision_mask.dtype == torch.bool + ), f"Parameter `vision_mask` must be of type bool, received {vision_mask.dtype}" + # Call the vision encoder to create a sequence of vision token embeddings for each input image + # pixel_values of shape (batch_sz * img_per_sample, 3, height, width) + vision_embeddings = self.vision_encoder_model( + pixel_values.to(self.vision_config.neuron_config.torch_dtype), + ).to(self.text_config.neuron_config.torch_dtype) + + # Flatten vision embeddings: required if img_per_sample > 1 + # vision_embeddings of shape (batch_sz * img_per_sample, seq_len_per_image, embedding_dim) + # vision_mask of shape (batch_sz, total_seq_len) + batch_sz = 1 if (vision_mask.dim() == 1) else vision_mask.shape[0] + num_images, seq_len, embedding_dim = 
vision_embeddings.shape + img_per_sample = num_images // batch_sz + vision_embeddings = vision_embeddings.view(batch_sz, img_per_sample * seq_len, embedding_dim) + + # Sequences of vision token embeddings are padded to the bucket size the text model has been compiled with + vision_embeddings = pad_vision_embeddings(vision_embeddings=vision_embeddings, pad_limit=pad_target_size) + + # Positions used to scatter vision embeddings at specific positions into the sequence passed to the text model + # are created from the vision mask + vision_mask = self.generate_positions_from_mask(mask=vision_mask.squeeze()) + vision_mask = self.pad_positions( + positions=vision_mask, + target_size=pad_target_size, + fill_value=pad_fill_value + ) + else: + # Either token generation or text-only prefill -> still need dummy inputs for the compiled text model + vision_embeddings, vision_mask = self.context_encoding_model.get_dummy_vision_inputs( + config=self.text_config, + input_ids=input_ids, + n_active_tokens=pad_target_size, + fill_value=pad_fill_value + ) + return super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + seq_ids=seq_ids, + sampling_params=sampling_params, + vision_embeddings=vision_embeddings, + vision_mask=vision_mask, + ) + + def enable_token_generation(self): + # Identical to NeuronLlama4ForCausalLM.enable_token_generation -> Required for get_compiler_args to succeed + self.compile_tag = TOKEN_GENERATION_MODEL_TAG + super().enable_token_generation() + + def enable_context_encoding(self): + # Identical to NeuronLlama4ForCausalLM.enable_context_encoding -> Required for get_compiler_args to succeed + self.compile_tag = CONTEXT_ENCODING_MODEL_TAG + super().enable_context_encoding() + + def get_compiler_args(self) -> str: + # Identical to NeuronLlama4ForCausalLM.get_compiler_args + logical_nc_config = self.text_config.neuron_config.logical_nc_config + + if self.compile_tag == CONTEXT_ENCODING_MODEL_TAG: + optimization_level 
= "-O1" + elif self.compile_tag == TOKEN_GENERATION_MODEL_TAG: + optimization_level = "-O2" + elif self.compile_tag == VISION_ENCODER_MODEL_TAG: + return f"-O1 --model-type=transformer --tensorizer-options='--enable-ccop-compute-overlap' " \ + f"--auto-cast=none --lnc={logical_nc_config}" + else: + raise ValueError(f"get_compiler_args() Invalid compile tag encountered: {self.compile_tag}") + + args = f"--auto-cast=none --model-type=transformer --tensorizer-options='--enable-ccop-compute-overlap " \ + f"--cc-pipeline-tiling-factor=1 --vectorize-strided-dma --enable-scalar-dge-vectorization' " \ + f"--lnc={logical_nc_config} {optimization_level} " + return args + + @classmethod + def prepare_quantized_state_dict(cls, hf_model_quant): + # Gemma3-specific + model_quant_sd = hf_model_quant.state_dict() + convert_qint8_to_int8_state_dict(model_quant_sd) + return model_quant_sd + + def _get_constructed_outputs(self, outputs, is_run_on_neuron): + if self.on_device_sampling and self.text_config.neuron_config.output_logits and not \ + (self.text_config.neuron_config.enable_fused_speculation or self.text_config.neuron_config.is_medusa): + logits_or_next_tokens = outputs[:2] + constructed_outputs = self._construct_output_with_tokens_and_logits(next_tokens=logits_or_next_tokens[0], logits=logits_or_next_tokens[1]) + else: + if is_run_on_neuron: + # FIX: Remove updated KV cache tensor (outputs[1]) + logits_or_next_tokens = logits_or_next_tokens = outputs[0] if isinstance(outputs, (list, tuple)) else outputs + else: + # When run on cpu, KV cache is returned which has to be ignored + logits_or_next_tokens, *_ = outputs + constructed_outputs = self._construct_output(logits_or_next_tokens) + + if logging.root.isEnabledFor(logging.DEBUG): + logging.debug("---output---") + logging.debug( + f"{'tokens' if self.on_device_sampling else 'logits'} = %s, ", + logits_or_next_tokens, + ) + + return constructed_outputs + + @staticmethod + def load_hf_model(model_path, **kwargs): + from 
transformers import Gemma3ForConditionalGeneration, Gemma3Config + config = Gemma3Config.from_pretrained(model_path) + model = Gemma3ForConditionalGeneration.from_pretrained(model_path, config=config).eval() + return model + + +class NeuronTextGemma3ForCausalLM(NeuronBaseForCausalLM): + + _model_cls = NeuronGemma3TextModel + + @staticmethod + def load_hf_model(model_path, **kwargs): + from transformers import Gemma3ForCausalLM + return Gemma3ForCausalLM.from_pretrained(model_path, **kwargs) # nosec B615 + + @staticmethod + def update_state_dict_for_tied_weights(state_dict: StateDict) -> None: + state_dict["lm_head.weight"] = state_dict["embed_tokens.weight"].clone() + + @staticmethod + def convert_hf_to_neuron_state_dict(state_dict: StateDict, inference_config: InferenceConfig) -> StateDict: + neuron_config = inference_config.neuron_config + attention_keys = { + ".self_attn.q_proj.": ".self_attn.qkv_proj.q_proj.", + ".self_attn.k_proj.": ".self_attn.qkv_proj.k_proj.", + ".self_attn.v_proj.": ".self_attn.qkv_proj.v_proj.", + ".self_attn.o_proj.": ".self_attn.o_proj.o_proj.", + ".self_attn.q_norm.": ".self_attn.q_layernorm.", + ".self_attn.k_norm.": ".self_attn.k_layernorm.", + } + + # At the time of writing, NxDI (Neuron 2.26) attention layer does not provide a simple way to use a custom + # scaling factor for raw attention scores (QK^T) while ensuring all optimizations (e.g. kernels) remain available. + # To work around this, we fuse the scaling factor into the weights (knowing that the attention layer will use the + # default 1 / math.sqrt(inference_config.head_dim) value) + default_qk_scaling_factor_inv = math.sqrt(float(inference_config.head_dim)) + gemma_qk_scaling_factor = 1.0 / math.sqrt(float(inference_config.query_pre_attn_scalar)) + gamma = math.sqrt(gemma_qk_scaling_factor * default_qk_scaling_factor_inv) + + new_state_dict = {} + for key, weights in state_dict.items(): + if 'vision_tower.' in key: + continue + if 'language_model.model.' 
in key: + key = key.replace('language_model.model.', "") + for atten_key in attention_keys: + if atten_key in key: + replacement_atten_key = attention_keys[atten_key] + key = key.replace(atten_key, replacement_atten_key) + break + if key.endswith((".q_proj.weight", ".k_proj.weight")): + orig_dtype = weights.dtype + weights = (weights.to(dtype=torch.float32) * gamma).to(dtype=orig_dtype) + new_state_dict[key] = weights + + if neuron_config.fused_qkv: + new_state_dict = convert_state_dict_to_fused_qkv( + state_dict=new_state_dict, + num_layers=inference_config.num_hidden_layers, + neuron_config=inference_config.neuron_config, + prefix="layers.{layer_num}.self_attn" + ) + + if neuron_config.vocab_parallel: + new_state_dict["embed_tokens.rank_util.rank"] = torch.arange(0, neuron_config.local_ranks_size) + + tp_degree = neuron_config.tp_degree + for i in range(inference_config.num_hidden_layers): + new_state_dict[f"layers.{i}.self_attn.rank_util.rank"] = torch.arange(0, tp_degree, dtype=torch.int32) + + new_state_dict["rank_util.rank"] = torch.arange(0, tp_degree, dtype=torch.int32) + + return new_state_dict + + @classmethod + def get_config_cls(cls): + return TextGemma3InferenceConfig diff --git a/contrib/models/gemma3-vision/src/gemma3_vision/modeling_gemma3_text.py b/contrib/models/gemma3-vision/src/gemma3_vision/modeling_gemma3_text.py new file mode 100644 index 00000000..d9fa4f60 --- /dev/null +++ b/contrib/models/gemma3-vision/src/gemma3_vision/modeling_gemma3_text.py @@ -0,0 +1,870 @@ +import logging +from typing import Optional, Tuple +import torch +import torch.nn as nn +from torch_neuronx.xla_impl.ops import RmsNorm +from transformers.models.gemma3.modeling_gemma3 import Gemma3TextScaledWordEmbedding, Gemma3RMSNorm + +from neuronx_distributed.parallel_layers import parallel_state +from neuronx_distributed.parallel_layers.layers import ( + ColumnParallelLinear, + ParallelEmbedding, +) +from neuronx_distributed.parallel_layers.mappings import _gather_along_dim 
+from neuronx_distributed.quantization import dequantize +from neuronx_distributed.utils import cpu_mode +from neuronx_distributed_inference.models.config import InferenceConfig +from neuronx_distributed_inference.models.model_base import NeuronBaseModel +from neuronx_distributed_inference.models.llama.modeling_llama import NeuronLlamaMLP +from neuronx_distributed_inference.modules.attention.attention_base import NeuronAttentionBase +from neuronx_distributed_inference.modules.attention.attention_process_groups import ( + get_flattened_inverted_tp_cp_group_mesh +) +from neuronx_distributed_inference.modules.attention.utils import ( + chunk_and_reorder_tensor, + RotaryEmbedding, + stride_tensor, +) +from neuronx_distributed_inference.modules.custom_calls import neuron_cumsum +from neuronx_distributed_inference.modules.flashdecode.utils import ( + get_cache_size, + mask_util, + turn_2d_mask_to_4d, +) +from neuronx_distributed_inference.modules.generation.sampling import Sampler, mask_padded_logits +from neuronx_distributed_inference.modules.kvcache.utils import get_layer_to_kv_cache_size_mapping_for_mixed_attn +from neuronx_distributed_inference.modules.kvcache.kv_cache_manager import KVCacheManager, _slice_kv_cacheline +from neuronx_distributed_inference.modules.kvcache.block_kv_cache_manager import generate_tokengen_slot_mapping +from neuronx_distributed_inference.utils.distributed import get_tp_group + +logger = logging.getLogger("Neuron") + + +class HybridAttnKVCacheManager(KVCacheManager): + + def get_kv_by_layer_id( + self, + idx, + seq_len: int, + skip_slice=False, + medusa_metadata=None, + kvcache_buffer=None, + seq_ids=None, + is_for_speculation: bool = False, + **kwargs, + ): + """ + Override KVCacheManager's get_kv_by_layer_id() to handle hybrid attention patterns. + + Changes: + 1. 
Removed the following lines: + ``` + if hasattr(self, "v_shapes"): + seq_len = self.v_shapes[idx][2] + ``` + + Without this override, get_kv_by_layer_id() would return caches with shape + [batch_size, num_head_per_rank, max_seq_len, head_dim] instead of the expected + [batch_size, num_head_per_rank, n_positions (bucket length), head_dim]. + """ + k_cache, v_cache = self._fetch_cache(idx, kvcache_buffer) + if ( + self.neuron_config.batch_size != self.neuron_config.max_batch_size + and is_for_speculation + ): + assert seq_ids is not None + updated_seq_ids = self.get_cache_update_index_for_seq_ids(seq_ids) + k_cache = k_cache[updated_seq_ids] + v_cache = v_cache[updated_seq_ids] + elif self.kv_cache_padding_size > 0: + k_cache = k_cache[: -self.kv_cache_padding_size] + v_cache = v_cache[: -self.kv_cache_padding_size] + if self.is_medusa: + slice_index, gather_index = self.configure_medusa_gather_slice_idx(medusa_metadata) + accepted_k_cache = torch.gather(input=k_cache, dim=3 if self.k_cache_transposed else 2, index=gather_index) + accepted_v_cache = torch.gather(input=v_cache, dim=2, index=gather_index) + k_cache = torch.scatter(input=k_cache, dim=3 if self.k_cache_transposed else 2, index=slice_index, src=accepted_k_cache) + v_cache = torch.scatter(input=v_cache, dim=2, index=slice_index, src=accepted_v_cache) + + attn_kernel_enabled = ( + self.neuron_config.attn_tkg_builtin_kernel_enabled + or self.neuron_config.attn_tkg_nki_kernel_enabled + or self.neuron_config.attn_block_tkg_nki_kernel_enabled + ) + if attn_kernel_enabled: # Attention TKG Kernels do not need slicing. 
+ skip_slice = True + + # slice for partial view + if not skip_slice: + k_cache = _slice_kv_cacheline(self.padding_side, seq_len, k_cache, self.k_cache_transposed) + v_cache = _slice_kv_cacheline(self.padding_side, seq_len, v_cache, False) + if self.quant: + k_cache = dequantize.direct_cast_dequantize(k_cache, self.dequant_dtype) + v_cache = dequantize.direct_cast_dequantize(v_cache, self.dequant_dtype) + return k_cache, v_cache + + +class NeuronGemma3RMSNorm(nn.Module): + + def __init__(self, hidden_size: int, eps: float = 1e-6) -> None: + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.zeros(hidden_size)) + + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: + hidden_states, original_dtype = hidden_states.to(torch.float32), hidden_states.dtype + gamma = (1.0 + self.weight).to(torch.float32) + y = RmsNorm.apply(hidden_states, gamma, self.eps, hidden_states.dim() - 1) + return y.to(original_dtype) + + +def get_rmsnorm_cls(): + return Gemma3RMSNorm if cpu_mode() else NeuronGemma3RMSNorm + + +class NeuronGemma3TextScaledWordEmbedding(ParallelEmbedding): + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int, + embed_scale: float = 1.0, + **kwargs) -> None: + super().__init__(num_embeddings, embedding_dim, padding_idx, **kwargs) + self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False) + + def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor: + return super().forward(input_ids) * self.embed_scale.to(self.weight.dtype) + + +class NeuronGemma3MLP(NeuronLlamaMLP): + pass + + +class NeuronGemma3RotaryEmbedding(RotaryEmbedding): + + def __init__(self, + dim: int, + max_position_embeddings: int, + base: float, + scaling_type: str = "default", + scaling_factor: float = 1.0, + ) -> None: + super().__init__( + dim=dim, + max_position_embeddings=max_position_embeddings, + base=base + ) + + self.scaling_type = scaling_type + if self.scaling_type == "default": + 
self.scaling_factor = 1.0 + elif self.scaling_type == "linear": + self.scaling_factor = scaling_factor + else: + raise ValueError( + f"Unsupported RoPE scaling type '{scaling_type}'. Gemma3 RoPE only supports 'default' or 'linear'." + ) + + def get_inv_freqs(self, device: Optional[torch.device] = None) -> torch.Tensor: + inv_freq = super().get_inv_freqs(device=device) + if self.scaling_type == "linear": + return inv_freq / self.scaling_factor + return inv_freq + + +class NeuronGemma3Attention(NeuronAttentionBase): + + @staticmethod + def get_rope(config: InferenceConfig, is_swa_layer: bool) -> NeuronGemma3RotaryEmbedding: + partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0) + dim = int(config.head_dim * partial_rotary_factor) + max_position_embeddings = config.max_position_embeddings + if is_swa_layer: + # RoPE for SWA layers + return NeuronGemma3RotaryEmbedding( + dim=dim, + max_position_embeddings=max_position_embeddings, + base=config.rope_local_base_freq, + ) + else: + # RoPE for global attention layers + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + scaling_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + scaling_factor = config.rope_scaling.get("factor", 1.0) + else: + scaling_type = "default" + scaling_factor = 1.0 + return NeuronGemma3RotaryEmbedding( + dim=dim, + max_position_embeddings=max_position_embeddings, + base=config.rope_theta, + scaling_type=scaling_type, + scaling_factor=scaling_factor, + ) + + +class NeuronGemma3DecoderLayer(nn.Module): + + def __init__(self, config: InferenceConfig, layer_idx: int) -> None: + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.layer_idx = layer_idx + + config_sliding_window = getattr(config, "sliding_window", None) + self.is_swa_layer = False if config_sliding_window is None else bool((layer_idx + 1) % config._sliding_window_pattern) + self.sliding_window = config_sliding_window if 
self.is_swa_layer else None + + rms_norm_cls = get_rmsnorm_cls() + rms_norm_eps = getattr(config, "rms_norm_eps", None) + q_norm = rms_norm_cls(config.head_dim, rms_norm_eps) if rms_norm_eps else rms_norm_cls(config.head_dim) + k_norm = rms_norm_cls(config.head_dim, rms_norm_eps) if rms_norm_eps else rms_norm_cls(config.head_dim) + + self.self_attn = NeuronGemma3Attention( + config=config, + hidden_size=config.hidden_size, + num_attention_heads=config.num_attention_heads, + num_key_value_heads=config.num_key_value_heads, + head_dim=getattr(config, "head_dim", config.hidden_size // config.num_attention_heads), + rotary_emb=NeuronGemma3Attention.get_rope(config=config, is_swa_layer=self.is_swa_layer), + rms_norm_eps=config.rms_norm_eps, + qkv_bias=getattr(config, "attention_bias", False), + o_bias=getattr(config, "attention_bias", False), + num_cores_per_group=config.num_cores_per_group, + tensor_model_parallel_group=get_tp_group(config), + sliding_window=self.sliding_window, + use_qk_norm=False, + q_layernorm=q_norm, + k_layernorm=k_norm + ) + + self.mlp = NeuronGemma3MLP(config) + self.input_layernorm = None + if ( + not config.neuron_config.is_eagle_draft + or config.neuron_config.enable_eagle_draft_input_norm + ): + self.input_layernorm = rms_norm_cls( + config.hidden_size, + eps=config.rms_norm_eps, + ) + self.post_attention_layernorm = rms_norm_cls( + config.hidden_size, + eps=config.rms_norm_eps, + ) + self.pre_feedforward_layernorm = rms_norm_cls( + config.hidden_size, + eps=config.rms_norm_eps, + ) + self.post_feedforward_layernorm = rms_norm_cls( + config.hidden_size, + eps=config.rms_norm_eps, + ) + self.qkv_kernel_enabled = config.neuron_config.qkv_kernel_enabled + self.mlp_kernel_enabled = config.neuron_config.mlp_kernel_enabled + self.quantized_mlp_kernel_enabled = config.neuron_config.quantized_mlp_kernel_enabled + self.rmsnorm_quantize_kernel_enabled = config.neuron_config.rmsnorm_quantize_kernel_enabled + self.mlp_kernel_fuse_residual_add = 
config.neuron_config.mlp_kernel_fuse_residual_add + self.qkv_kernel_fuse_residual_add = config.neuron_config.qkv_kernel_fuse_residual_add + self.sequence_parallel_enabled = config.neuron_config.sequence_parallel_enabled + self.is_prefill_stage = config.neuron_config.is_prefill_stage + + if self.is_prefill_stage and self.config.neuron_config.is_mlp_quantized(): + # for CTE, quantized MLP kernel does not support fused rmsnorm + self.mlp_kernel_fused_rmsnorm = False + else: + self.mlp_kernel_fused_rmsnorm = not self.sequence_parallel_enabled + + self.qkv_kernel_fused_rmsnorm = not self.sequence_parallel_enabled + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: Optional[torch.BoolTensor] = None, + local_mask: Optional[torch.BoolTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.FloatTensor]] = None, + adapter_ids=None, + rotary_position_ids: Optional[torch.LongTensor] = None, + residual: Optional[torch.FloatTensor] = None, # residual from previous layer if QKV kernel with fused residual is enabled + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]], Optional[torch.FloatTensor], Optional[torch.FloatTensor], Optional[torch.FloatTensor]]: + # Adapted from NeuronLlamaDecoderLayer + is_token_gen = past_key_value is not None + entry_hidden_states = hidden_states + + # Hybrid SWA/global attention layers are specific to Gemma3 + if self.is_swa_layer: + attention_mask = local_mask + + if self.qkv_kernel_enabled and self.qkv_kernel_fused_rmsnorm: + attn_fused_rmsnorm = self.input_layernorm + else: + hidden_states = self.input_layernorm(hidden_states) + attn_fused_rmsnorm = None + + # Self Attention + attn_output = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + adapter_ids=adapter_ids, + rmsnorm=attn_fused_rmsnorm, + rotary_position_ids=rotary_position_ids, 
+ residual=residual, + **kwargs, + ) + + # Post-attention RMS norm is specific to Gemma3 + hidden_states = self.post_attention_layernorm(attn_output.hidden_states) + + if attn_output.residual is not None: + # In the case the QKV kernel is enabled (attn_output.residual is not None), the input hidden + # states actually do not correspond to the attention layer's inputs. They are computed within + # the layer (by the fused QKV kernel) and returned as "residual" output. + assert self.qkv_kernel_fuse_residual_add, \ + "residual add before qkv should be computed in the previous layer, \ + unless qkv_kernel_fuse_residual_add is specified" + assert ( + not self.sequence_parallel_enabled + ), "qkv_kernel_fuse_residual_add should be off when sequence parallelism is enabled" + assert ( + self.qkv_kernel_enabled + ), "qkv_kernel_fuse_residual_add should be used with qkv_kernel_enabled" + assert ( + not is_token_gen + ), "cannot fuse residual add for tokengen" + residual = attn_output.residual + else: + residual = entry_hidden_states # attention layer inputs to be used for residuals addition + + if self.mlp_kernel_enabled and self.mlp_kernel_fuse_residual_add: + assert ( + not self.sequence_parallel_enabled + ), "mlp_kernel_fuse_residual_add should be off when sequence parallelism is enabled" + hidden_states, residual = self.mlp( + hidden_states, + rmsnorm=self.pre_feedforward_layernorm, + residual=residual, + adapter_ids=adapter_ids, + ) + else: + hidden_states = residual + hidden_states + residual = hidden_states + + if self.mlp_kernel_enabled and self.mlp_kernel_fused_rmsnorm: + mlp_fused_rmsnorm = self.pre_feedforward_layernorm + else: + hidden_states = self.pre_feedforward_layernorm(hidden_states) + mlp_fused_rmsnorm = None + + hidden_states, _ = self.mlp( + hidden_states, + rmsnorm=mlp_fused_rmsnorm, + adapter_ids=adapter_ids, + ) + + # Post-feed-forward RMS norm is specific to Gemma3 + hidden_states = self.post_feedforward_layernorm(hidden_states) + + # If the QKV kernel 
with fused residual addition is not enabled, we perform the residual addition here, + # otherwise, we return the residual so the fused kernel in the next block can perform the addition + if not self.qkv_kernel_fuse_residual_add or is_token_gen: + hidden_states = residual + hidden_states + residual = None + + return (hidden_states, attn_output.present_key_value, attn_output.cos_cache, attn_output.sin_cache, residual) + + +class NeuronGemma3TextModel(NeuronBaseModel): + + def scatter_by_index_put(self, h_image, encoded_patches_proj, positions): + """ + Scatter encoded patches into an image tensor. + Compared to neuronx_distributed_inference/models/llama4/utils/encoder_utils.py's scatter_by_index_put(), + this function supports Batch Size >= 1. + + Args: + h_image (torch.Tensor): The target image tensor of shape (B, max_positions, embedding_dim) + encoded_patches_proj (torch.Tensor): The encoded patches to be scattered, of shape (num_patches, patch_size, embedding_dim) + positions (torch.Tensor): The positions where patches should be scattered, of shape (B, num_positions, 1) + + Returns: + torch.Tensor: The updated image tensor with scattered patches + """ + B, max_positions, embedding_dim = h_image.shape + + # Create a new tensor instead of modifying h_image in-place + h_image_new = h_image.clone() + + # Flatten encoded_patches_proj + encoded_patches_flat = encoded_patches_proj.view(-1, embedding_dim) + + # Flatten positions + positions = positions.view(-1) + + # Create Batch Indices + # We need to tell PyTorch: "This update belongs to batch 0, that one to batch 1" + # If positions is (B, N), we need batch_idx to look like [0,0..0, 1,1..1, ...] 
+        num_updates_per_batch = positions.shape[0] // B
+
+        batch_idx = torch.arange(B, device=h_image.device, dtype=positions.dtype)
+        batch_idx = batch_idx.repeat_interleave(num_updates_per_batch)
+
+        # Use index_put_ to scatter the embeddings
+        h_image_new.index_put_(
+            (batch_idx.long(), positions.long()),
+            encoded_patches_flat,
+            accumulate=False
+        )
+
+        return h_image_new
+
+    def encode_vision_to_input(self, inputs_embeds, vision_embeddings, vision_mask) -> torch.Tensor:
+        # Concat vision and text embeddings during context encoding
+        # Both inputs_embeds and vision_embeddings should be of the same shape: [BS, Total tokens (image + text), Hidden]
+        # And vision_mask should be of the shape [BS, Total tokens (image + text), 1]
+        # Entries in vision_mask with value `True` represent vision tokens and with value `False` represent text tokens
+        # For text-only inputs, vision_mask should be all `False`
+        return self.scatter_by_index_put(inputs_embeds, vision_embeddings, vision_mask)
+
+    def setup_attr_for_model(self, config: InferenceConfig):
+        # Needed for init_inference_optimization()
+        self.on_device_sampling = config.neuron_config.on_device_sampling_config is not None
+        self.tp_degree = config.neuron_config.tp_degree
+        self.hidden_size = config.hidden_size
+        self.num_attention_heads = config.num_attention_heads
+        self.num_key_value_heads = config.num_key_value_heads
+        self.max_batch_size = config.neuron_config.max_batch_size
+        self.buckets = config.neuron_config.buckets
+
+    def init_model(self, config: InferenceConfig):
+        """
+        Modified init_model of NeuronLlama4TextModel:
+        1. add self.sliding_window. This will allow creating local attention masks in forward()
+        2. 
replace embedding modules with 'scaled' embeddings""" + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.sliding_window = config.sliding_window + + if self.sliding_window and config.neuron_config.seq_len < self.sliding_window: + # When the model context (seq_len) is shorter than the window, the sliding window + # effectively covers the entire sequence (full attention). Update to match. + config.sliding_window = config.neuron_config.seq_len + self.sliding_window = config.sliding_window + + if self.sliding_window: + is_layer_locals = [layer_idx % config._sliding_window_pattern != config._sliding_window_pattern - 1 for layer_idx in range(config.num_hidden_layers)] + self.layer_to_cache_size_mapping = get_layer_to_kv_cache_size_mapping_for_mixed_attn(config.sliding_window, config.neuron_config.seq_len, is_layer_locals) + logger.info("layer_to_cache_size_mapping initialized") + + self.has_mixed_attn = True + + if parallel_state.model_parallel_is_initialized(): + self.embed_tokens = NeuronGemma3TextScaledWordEmbedding( + config.vocab_size, + config.hidden_size, + self.padding_idx, + config.hidden_size**0.5, # embed_scale + dtype=config.neuron_config.torch_dtype, + shard_across_embedding=not config.neuron_config.vocab_parallel, + sequence_parallel_enabled=False, + pad=True, + tensor_model_parallel_group=get_tp_group(config), + use_spmd_rank=config.neuron_config.vocab_parallel, + ) + + lm_head_pad = config.neuron_config.lm_head_pad + lnc = config.neuron_config.logical_nc_config + lm_head_pad_alignment_size = config.neuron_config.lm_head_pad_alignment_size * lnc + self.lm_head = ColumnParallelLinear( + config.hidden_size, + config.vocab_size, + gather_output=not self.on_device_sampling, + bias=lm_head_pad, + pad=True, + pad_alignment_size_per_rank=lm_head_pad_alignment_size if lm_head_pad else 1, + keep_padded_output=lm_head_pad, + dtype=config.neuron_config.torch_dtype, + tensor_model_parallel_group=get_tp_group(config), + ) + else: + 
self.embed_tokens = Gemma3TextScaledWordEmbedding( + config.vocab_size, + config.hidden_size, + self.padding_idx, + config.hidden_size**0.5 # embed_scale + ) + self.lm_head = nn.Linear( + config.hidden_size, + config.vocab_size, + bias=False, + ) + + self.layers = nn.ModuleList( + [NeuronGemma3DecoderLayer(config, idx) for idx in range(config.num_hidden_layers)] + ) + + if not config.neuron_config.is_eagle_draft: + self.norm = get_rmsnorm_cls()(config.hidden_size, eps=config.rms_norm_eps) + + if config.neuron_config.is_eagle_draft: + fc_bias = getattr(config, "fc_bias", False) + self.fc = ColumnParallelLinear( + config.hidden_size * 2, config.hidden_size, bias=fc_bias, gather_output=True + ) + + # TODO: medusa needed? + # self.is_medusa = config.neuron_config.is_medusa + # self.num_medusa_heads = config.neuron_config.num_medusa_heads + # self.medusa_speculation_length = config.neuron_config.medusa_speculation_length + + # if self.is_medusa: + # if parallel_state.model_parallel_is_initialized(): + # medusa_head_cls = ColumnParallelLinear + # else: + # medusa_head_cls = nn.Linear + # for i in range(self.num_medusa_heads): + # medusa_head = nn.Sequential( + # *([ResBlock(config.hidden_size)] * 1), + # medusa_head_cls( + # config.hidden_size, + # config.vocab_size, + # gather_output=not self.on_device_sampling, + # bias=False, + # ), + # ) + # setattr(self, f"medusa_head_{i}", medusa_head) + + def init_inference_optimization(self, config: InferenceConfig): + """ + Compared to neuronx_distributed_inference/models/model_base.py's init_inference_optimization(), + use HybridAttnKVCacheManager instead of KVCacheManager + """ + super().init_inference_optimization(config) + + if self.on_device_sampling: + self.sampler = Sampler(config.neuron_config) + + self.kv_mgr = HybridAttnKVCacheManager( + config, + num_kv_head=self.num_key_value_heads, + global_rank=self.rank_util, + sliding_window=self.sliding_window, + layer_to_cache_size_mapping=self.layer_to_cache_size_mapping) + + 
def forward( + self, + input_ids, + attention_mask, + position_ids, + seq_ids, + sampling_params, + prev_hidden=None, + adapter_ids=None, + accepted_indices=None, + current_length=None, + medusa_mask=None, + scatter_index=None, + slot_mapping=None, + active_block_table=None, + num_queries=None, + computed_context_lens=None, + tile_q_indices=None, + tile_block_tables=None, + tile_masks=None, + # In llava context encoding model, input_embeds is precomputed + inputs_embeds: Optional[torch.FloatTensor] = None, + kv_cache: Optional[torch.Tensor] = None, + active_mask=None, + rotary_position_id=None, + vision_embeddings=None, + vision_mask=None, + ): + """ + Compared to NxDI NeuronBaseModel.forward(), + 1. pass 'past_key_values' to get_model_output + 2. always create local attention mask (for sliding window attn layers) + """ + # Optional argument cannot be set to None in NXDI now as NxD does not support + # kwargs. Now we are working around by passing an empty tensor. + # + # But empty tensors break the logic like + # if input_embeds is None: + # input_embeds = embed() + # + # We are forced to pass in a value for optional params + # Passing in none does not work as it breaks torchscripting. + # Once kwargs support is in, we can remove this workaround. 
+ prev_hidden = self.set_none_if_empty(prev_hidden) + adapter_ids = self.set_none_if_empty(adapter_ids) + accepted_indices = self.set_none_if_empty(accepted_indices) + current_length = self.set_none_if_empty(current_length) + medusa_mask = self.set_none_if_empty(medusa_mask) + scatter_index = self.set_none_if_empty(scatter_index) + slot_mapping = self.set_none_if_empty(slot_mapping) + active_block_table = self.set_none_if_empty(active_block_table) + num_queries = self.set_none_if_empty(num_queries) + computed_context_lens = self.set_none_if_empty(computed_context_lens) + tile_q_indices = self.set_none_if_empty(tile_q_indices) + tile_block_tables = self.set_none_if_empty(tile_block_tables) + tile_masks = self.set_none_if_empty(tile_masks) + inputs_embeds = self.set_none_if_empty(inputs_embeds) + kv_cache = self.set_none_if_empty(kv_cache) + active_mask = self.set_none_if_empty(active_mask) + rotary_position_id = self.set_none_if_empty(rotary_position_id) + vision_embeddings = self.set_none_if_empty(vision_embeddings) + vision_mask = self.set_none_if_empty(vision_mask) + local_attn_mask = None + + if self.neuron_config.is_medusa: + return self._medusa_forward( + input_ids, + attention_mask, + position_ids, + seq_ids, + sampling_params, + adapter_ids, + accepted_indices, + current_length, + medusa_mask, + scatter_index, + ) + + is_for_token_gen = attention_mask.dim() == 4 + + if ( + is_for_token_gen + and self.neuron_config.enable_token_tree + and self.neuron_config.enable_eagle_speculation + ): + logging.warning("entering _eagle_token_tree_forward") + return self._eagle_token_tree_forward( + input_ids, + attention_mask, + position_ids, + seq_ids, + sampling_params, + prev_hidden, + adapter_ids, + scatter_index=scatter_index, + inputs_embeds=inputs_embeds, + kv_cache=kv_cache, + active_mask=active_mask, + rotary_position_id=rotary_position_id, + ) + # TODO: This will not work for a context encoding model with bucket size + # equal to the speculation length + 
is_for_context_encoding = self._is_context_encoding(input_ids) + is_for_speculation = self._is_for_speculation(input_ids) + + # For non-speculative prefix caching, generate the slot mapping within the traced model. + # This is necessary for async mode, as the active_block_table is up-to-date but the slot mapping + # passed into the traced model may be from a prior iteration. + if ( + not is_for_context_encoding + and not self.neuron_config.enable_fused_speculation + and not self.neuron_config.enable_eagle_speculation + and self.is_prefix_caching + and active_block_table is not None + ): + block_size = torch.tensor(self.neuron_config.pa_block_size, device=position_ids.device, dtype=torch.int32) + slot_mapping = generate_tokengen_slot_mapping(position_ids, slot_mapping, active_block_table, block_size) + + cache_size = ( + get_cache_size(self.n_positions, self.num_cores_per_group, is_for_context_encoding) + if self.neuron_config.flash_decoding_enabled + else self.n_positions + ) + + # Prepare attention mask(s) + if self.is_chunked_prefill: + attn_mask = self.create_attn_mask( + attention_mask, + is_for_context_encoding, + is_for_speculation, + query_lens=num_queries, + key_lens=num_queries + computed_context_lens, + ) + else: + attn_mask = self.create_attn_mask( + attention_mask, + is_for_context_encoding, + is_for_speculation, + position_ids=position_ids, + ) + if self.attention_chunk_size: + if is_for_context_encoding: + local_attn_mask = self._create_chunked_attn_mask_cte(attention_mask, self.attention_chunk_size) + else: + local_attn_mask = self._create_chunked_attn_mask_tkg(attention_mask, self.attention_chunk_size, position_ids) + elif self.sliding_window: + if is_for_context_encoding: + local_attn_mask = self._create_windowed_attn_mask_cte(attention_mask, self.sliding_window) + else: + local_attn_mask = self._create_windowed_attn_mask_tkg(attention_mask, self.sliding_window, position_ids) + + active_mask = None + if self.is_prefix_caching: + active_length = 
self.speculation_length if is_for_speculation else self.n_active_tokens
+            active_mask = torch.full(
+                (active_length, active_length),
+                True,
+                device=attention_mask.device,
+            ).tril(diagonal=0)
+            active_mask = active_mask[None, None, :, :].expand(
+                self.batch_size, 1, active_length, active_length
+            )
+            if is_for_speculation:
+                active_mask = torch.full(
+                    (self.speculation_length, self.speculation_length),
+                    True,
+                    device=attention_mask.device,
+                ).tril(diagonal=0)
+                active_mask = active_mask[None, None, :, :].expand(
+                    self.batch_size, 1, self.speculation_length, self.speculation_length
+                )
+
+        # FlashDecoding masks, for KV cache updates
+        active_mask_2d = None
+        if self.neuron_config.flash_decoding_enabled and not is_for_context_encoding:
+            rank_id = self.rank_util.get_rank()
+            active_mask_tmp, attention_mask_tmp = mask_util(
+                pos_ids=position_ids,
+                rank_id=rank_id,
+                num_cores_per_group=self.num_cores_per_group,
+                cache_size=cache_size,
+            )
+            if is_for_speculation:
+                active_mask = active_mask_tmp[:, None, :, :].expand(self.batch_size, 1, -1, -1)
+                attn_mask = attention_mask_tmp[:, None, :, :].expand(self.batch_size, 1, -1, -1)
+                # only for cache update
+                active_mask_2d = active_mask_tmp.sum(dim=-2, keepdim=False).to(torch.bool)
+            else:
+                active_mask = turn_2d_mask_to_4d(
+                    active_mask_tmp, n_positions=1, batch_size=self.batch_size
+                )
+                attn_mask = turn_2d_mask_to_4d(
+                    attention_mask_tmp, n_positions=cache_size, batch_size=self.batch_size
+                )
+                active_mask_2d = active_mask_tmp
+
+        if self.neuron_config.strided_context_parallel_kernel_enabled and is_for_context_encoding:
+            logging.debug("strided_context_parallel_kernel_enabled enabled, shuffling inputs")
+
+            # The strided CP FA kernel expects inputs to be strided; since SP happens in model_base,
+            # stride here rather than in attention to order it before we move the inputs to SP region
+            input_ids = stride_tensor(input_ids, 1, self.neuron_config.cp_degree)
+            position_ids = stride_tensor(position_ids, 1, 
self.neuron_config.cp_degree)
+
+        # When using SP with 8x8 CP, the mesh is non-contiguous, so we reorder the input to have a non-contiguous SP split
+        # When we AG in attention using 8x8, the resulting sequence is contiguous
+        if is_for_context_encoding and self.neuron_config.cp_degree > 1 and self.neuron_config.cp_degree == 8 and (self.neuron_config.tp_degree // self.neuron_config.cp_degree) == 8 and self.sequence_parallel_enabled:
+            ordering = get_flattened_inverted_tp_cp_group_mesh(self.neuron_config.tp_degree, self.neuron_config.cp_degree)
+
+            logging.debug("CP8 and SP enabled, reordering the input on S: %s", ordering)
+            input_ids = chunk_and_reorder_tensor(input_ids, ordering, 1)
+
+        # It is either for context encoding or for token generation
+        if is_for_context_encoding:
+            past_key_values = None
+        else:
+            past_key_values = self.kv_mgr.get_cache(self.n_positions)
+
+        hidden_states, updated_kv_cache = self.get_model_output(
+            input_ids=input_ids,
+            seq_ids=seq_ids,
+            attention_mask=attn_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            active_mask=active_mask,
+            inputs_embeds=inputs_embeds,
+            adapter_ids=adapter_ids,
+            prev_hidden=prev_hidden,
+            tile_q_indices=tile_q_indices,
+            tile_block_tables=tile_block_tables,
+            tile_masks=tile_masks,
+            num_queries=num_queries,
+            is_for_context_encoding=is_for_context_encoding,
+            scatter_index=slot_mapping if self.is_block_kv_layout else scatter_index,
+            kvcache_buffer=kv_cache,
+            is_for_speculation=is_for_speculation,
+            active_block_table=active_block_table,
+            kv_active_mask=active_mask_2d,
+            update_cache=True,
+            vision_embeddings=vision_embeddings,
+            vision_mask=vision_mask,
+            local_attn_mask=local_attn_mask,
+        )
+
+        batch_size = input_ids.shape[0]
+        if not self.sliced_hidden:
+            if self.padding_side == "left":
+                index = torch.tensor([hidden_states.shape[1] - 1], device=hidden_states.device)
+                index = index.unsqueeze(1).expand(batch_size, 1, self.hidden_size)
+                hidden_states = torch.gather(hidden_states, 
dim=1, index=index)
+            elif self.is_chunked_prefill:
+                if is_for_context_encoding:
+                    # chunked prefill will return cp_config.max_num_seqs, not
+                    # just the last one
+                    index = neuron_cumsum(num_queries.reshape(1, -1).float()).int() - 1
+                    index = index.reshape(1, -1, 1)
+                    index = index.expand(batch_size, -1, self.hidden_size)
+                    hidden_states = torch.gather(hidden_states, dim=1, index=index)
+            else:
+                if not (
+                    position_ids.shape[-1] == self.speculation_length or position_ids.shape[-1] == 1
+                ):
+                    # context encoding
+                    index = torch.max(position_ids, dim=1, keepdim=True).indices
+                    index = index.unsqueeze(1).expand(batch_size, 1, self.hidden_size)
+                    hidden_states = torch.gather(hidden_states, dim=1, index=index)
+
+        logits = self.lm_head(hidden_states)
+        logits = logits.float()
+
+        if hasattr(self.lm_head, "pad_size"):
+            if self.lm_head.gather_output:
+                rank_id = torch.tensor(0, device=logits.device, dtype=torch.int32)
+                world_size = 1
+            else:
+                rank_id = self.rank_util.get_rank()
+                world_size = torch.distributed.get_world_size(
+                    group=self.lm_head.tensor_parallel_group
+                )
+            logits = mask_padded_logits(logits, rank_id, world_size, pad_size=self.lm_head.pad_size)
+
+        if self.on_device_sampling:
+            res = self._sample_on_device(
+                logits, sampling_params, is_for_speculation, is_for_context_encoding
+            )
+        else:
+            res = logits
+
+        # A hack to ensure active_block_table and attention_mask are not optimized away
+        # if not None for prefix caching flow. 
+ if self.is_prefix_caching: + if active_block_table is not None and len(active_block_table.shape) == 1: + res = res + active_block_table[0] * 0 + if attention_mask is not None and self.prefix_size == 0: + res = res + attention_mask[0] * 0 + + outputs = [res] + if self.neuron_config.output_logits: + logits = _gather_along_dim( + logits, + partition_dim=2, + process_group=get_tp_group(self.config), + ) + outputs += [logits] + outputs += updated_kv_cache + + if self.neuron_config.enable_eagle_speculation: + if is_for_context_encoding: + outputs = outputs + [hidden_states] + [self.full_hidden_states] + else: + outputs = outputs + [self.full_hidden_states] + + return outputs diff --git a/contrib/models/gemma3-vision/src/gemma3_vision/modeling_gemma3_vision.py b/contrib/models/gemma3-vision/src/gemma3_vision/modeling_gemma3_vision.py new file mode 100644 index 00000000..573c8085 --- /dev/null +++ b/contrib/models/gemma3-vision/src/gemma3_vision/modeling_gemma3_vision.py @@ -0,0 +1,329 @@ +import logging +from typing import List, Tuple + +import torch +from torch import nn + +from neuronx_distributed_inference.models.config import InferenceConfig +from neuronx_distributed_inference.models.llama4.modeling_llama4_vision import Llama4VisionModelWrapper +from neuronx_distributed_inference.modules.async_execution import is_ranked_io + +from gemma3_vision.siglip.modeling_siglip import NeuronSiglipVisionModel +from gemma3_vision.modeling_gemma3_text import get_rmsnorm_cls + +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + + +class NeuronGemma3MultiModalProjector(nn.Module): + def __init__(self, config: InferenceConfig): + super().__init__() + + self.mm_input_projection_weight = nn.Parameter( + torch.zeros(config.vision_config.hidden_size, config.text_config.hidden_size) + ) + + self.mm_soft_emb_norm = get_rmsnorm_cls()( + config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps + ) + + self.patches_per_image = 
int(config.vision_config.image_size // config.vision_config.patch_size) + self.tokens_per_side = int(config.mm_tokens_per_image**0.5) + self.kernel_size = self.patches_per_image // self.tokens_per_side + self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, stride=self.kernel_size) + + def forward(self, vision_outputs: torch.Tensor): + batch_size, _, seq_length = vision_outputs.shape + + reshaped_vision_outputs = vision_outputs.transpose(1, 2) + reshaped_vision_outputs = reshaped_vision_outputs.reshape( + batch_size, seq_length, self.patches_per_image, self.patches_per_image + ) + reshaped_vision_outputs = reshaped_vision_outputs.contiguous() + + pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs) + pooled_vision_outputs = pooled_vision_outputs.flatten(2) + pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2) + + normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs) + + projected_vision_outputs = torch.matmul(normed_vision_outputs, self.mm_input_projection_weight) + return projected_vision_outputs.type_as(vision_outputs) + + +class NeuronGemma3VisionModel(torch.nn.Module): + def __init__(self, config: InferenceConfig): + super().__init__() + self.config = config + self.vision_config = config.vision_config + logger.info(f"in NeuronGemma3VisionModel self.vision_config {vars(self.vision_config)}") + + # TODO: data parallel optimization + # self.global_rank = SPMDRank(world_size=self.neuron_config.world_size) + # assert ( + # self.neuron_config.world_size % self.neuron_config.tp_degree == 0 + # ), "Invalid parallel config. 
world_size should be a multiple of tp_degree"
+        # self.dp_degree = self.neuron_config.world_size // self.neuron_config.tp_degree
+        # self.data_parallel_enabled = self.dp_degree > 1
+        # self.data_parallel_group = get_data_parallel_group()
+
+        self.vision_encoder = NeuronSiglipVisionModel(self.vision_config)
+        # multi_modal_projector needs to read the text model's hidden_size, so we pass the entire config to it
+        self.multi_modal_projector = NeuronGemma3MultiModalProjector(self.config)
+
+    def forward(
+        self,
+        pixel_values: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Generate vision embeddings from flattened pixel values.
+
+        This function handles dynamic image shapes as well as multiple images by splitting each image
+        into a number of fixed-size chunks. Afterwards, all chunks are stacked together on the batch dimension (dim=0).
+
+        Args:
+            pixel_values (Tensor): Vision pixel values of shape [num_chunks, 1 (constant), num_channels, image_size, image_size]
+
+        Returns:
+            vision embeddings (Tensor): Vision embeddings (after projection) padded to the nearest bucket size. 
+ + """ + # TODO: data parallel optimization + # if self.data_parallel_enabled: + # dp_rank = get_dp_rank_spmd(self.global_rank.get_rank(), self.neuron_config.tp_degree) + # # split inputs along batch dim + # pixel_values = scatter_to_process_group_spmd( + # pixel_values, + # partition_dim=0, + # rank=dp_rank, + # process_group=self.data_parallel_group, + # ) + + embedding = self.vision_encoder(pixel_values).last_hidden_state + logger.info(f"embedding.shape {embedding.shape}") + + projected_embedding = self.multi_modal_projector(embedding) + logger.info(f"projected_embedding.shape {projected_embedding.shape}") + + # TODO: data parallel optimization + # if self.data_parallel_enabled: + # h_image_proj = gather_from_tensor_model_parallel_region_with_dim( + # h_image_proj, gather_dim=0, process_group=self.data_parallel_group + # ) + return projected_embedding + + +class Gemma3VisionModelWrapper(Llama4VisionModelWrapper): + """ + Neuron ModelWrapper class for Gemma3's vision model (NeuronSiglipVisionModel). + Inherits from Llama4VisionModelWrapper. + Generates input shapes for trace and compilation. Disables bucketing. + """ + + def __init__( + self, + config: InferenceConfig, + model_cls, + tag="", + compiler_args: str = None, + priority_model_idx: int = None, + pipeline_execution: bool = True, + return_ranked_to_cpu: bool = True, + model_init_kwargs={}, + ) -> None: + super().__init__( + config, model_cls, tag, compiler_args, priority_model_idx, + pipeline_execution, return_ranked_to_cpu, model_init_kwargs + ) + + def input_generator(self) -> List[Tuple[torch.Tensor]]: + """ + Override Llama4VisionModelWrapper.input_generator(). + + Returns: + inputs (List[Tuple[torch.Tensor]]): Example input args for every bucket. 
+        """
+        inputs = []
+        for bucket in self.neuron_config.buckets:
+            pixel_values = torch.ones(
+                [
+                    self.neuron_config.batch_size,
+                    self.config.vision_config.num_channels,
+                    self.config.vision_config.image_size,
+                    self.config.vision_config.image_size,
+                ],
+                dtype=self.config.neuron_config.torch_dtype
+            )
+            inputs.append((pixel_values,))
+
+        return inputs
+
+    def forward(self, *args):
+        """
+        Override ModelWrapper.forward() to adapt for vision encoder.
+        """
+        if self.model is None:
+            raise RuntimeError(
+                "Forward called before load. Run load() or load_state_dict() before calling forward"
+            )
+
+        # convert int64 to int32 to improve compatibility with compiler; does not apply to cpu case
+        if not self.neuron_config.on_cpu:
+            args = self.convert_int64_to_int32(*args)
+
+        pixel_values = args[0]
+        input_batch_size = pixel_values.shape[0]
+
+        if input_batch_size == self.neuron_config.batch_size:
+            output = self._forward(*args)
+            return output
+
+        cur_batch = 0
+        outputs = []
+
+        logging.debug(
+            f"got input_batch_size {input_batch_size} but compiled batch_size is {self.neuron_config.batch_size}"
+        )
+
+        while cur_batch < input_batch_size:
+            if cur_batch + self.neuron_config.batch_size <= input_batch_size:
+                # process one full compiled-batch-size slice of the input
+                logging.debug(
+                    f"running forward on batch {cur_batch}:{cur_batch + self.neuron_config.batch_size}"
+                )
+
+                # pad to next bucket for context encoding with bs > 1
+                # each batch_arg represents a single prompt in the batch of prompts
+                batch_args = [
+                    arg[cur_batch : cur_batch + self.neuron_config.batch_size] for arg in args
+                ]
+                batch_args = self.vllm_cte_repadding(batch_args)
+
+                output = self._forward(*batch_args)
+
+            else:
+                # we need to pad the input to run
+                logging.debug(
+                    f"running forward on batch {cur_batch}:{input_batch_size}, padded up to {self.neuron_config.batch_size}"
+                )
+                output = self._forward_with_pad(
+                    *[
+                        arg[cur_batch:input_batch_size] if not is_ranked_io(arg) else arg
+                        for arg in args
+                    ]
+                )
+
+            
outputs.append(output)
+            cur_batch += self.neuron_config.batch_size
+
+        # combine the per-chunk outputs back into a single batch
+        return torch.cat(outputs, dim=0)
+
+    def _forward_with_pad(self, *args):
+        """
+        Override ModelWrapper._forward_with_pad
+        as vision encoder's args only include pixel values (i.e. len(args) = 1)
+        """
+        # Note: NxD's tracing flow (Model Builder) does not yet support kwargs, because of which we cannot support
+        # optional parameters. Kwargs support is being added as a part of the new Model Builder API. Until then we
+        # maintain a specific set of inputs that the ModelWrapper can support.
+        # This is not the best way to maintain code. But soon kwargs support will render this irrelevant.
+
+        # pad the inputs up to the compiled batch size in the end
+        def pad_helper(tensor, pad_type="fill_0", batch_sort_indices=None):
+            """
+            As part of continuous batching:
+            * If users provide an input batch size less than the compiled batch size, NxDI
+              needs to pad the inputs to the compiled batch size.
+            * seq_ids are used to indicate which kv cache line is used for each input batch line.
+              NxDI expects the seq_ids to always be [0, 1, 2, ..., compiled_batch_size) by default.
+            * To fulfill these requirements, NxDI pads the seq_ids with the missing slots and sorts
+              it in ascending order. Every other input arg is reordered accordingly and
+              missing slots are padded with `repeat_first_batchline`. While returning back the response,
+              we use index select to pick the outputs corresponding to user provided seq_ids.
+              Eg:
+              Input [[10],[20]] and seq_ids [[3], [2]] with compiled batch size as 4.
+              seq_ids [[3], [2]] -> [[3], [2], [0], [1]] (filled missing slots) -> [[0], [1], [2], [3]] (sort)
+              Input [[10],[20]] -> [[10],[20],[10],[10]] (repeat_first_batchline) -> [[10],[10],[20],[10]] (reorder)
+
+            As part of continuous batching with prefix caching, the second restriction no longer holds true,
+            so sorting of seq_ids and reordering of input args is no longer needed. 
Padding is required which is added + towards the end using `repeat_first_batchline` with the exception of slot_mapping (set to -1 instead) + as this is used to update the block kv cache. While returning back response, we just drop off the + padded outputs lines at the end of the batch. + Eg: + Input [[10],[20]] ; seq_ids [[3], [2]] and slot mapping [[50],[100]] with compiled batch size as 4. + seq_ids [[3], [2]] -> [[3], [2], [0], [1]] (filled missing slots) + Input [[10],[20]] -> [[10],[20],[10],[10]] (repeat_first_batchline) + slot mapping [[50],[100]] -> [[50],[100],[-1], [-1]] (padded with -1) + """ + if tensor is None or tensor.shape[0] == self.neuron_config.batch_size: + return tensor + + padded_shape = list(tensor.shape) + padded_shape[0] = self.neuron_config.batch_size + + def repeat_first_batchline(tensor, padded_shape): + return tensor[0].repeat(padded_shape[0], 1, 1, 1).to(tensor.dtype) + + def fill_value_tensor(value): + return lambda tensor, padded_shape: torch.full(padded_shape, fill_value=value, dtype=tensor.dtype) + + PAD_TYPES = { + "repeat_first_batchline": repeat_first_batchline, + "fill_0": fill_value_tensor(0), + "fill_1": fill_value_tensor(1), + "fill_-1": fill_value_tensor(-1), + } + + if pad_type not in PAD_TYPES: + raise ValueError(f"Unknown pad_type '{pad_type}'. 
Available: {list(PAD_TYPES.keys())}") + + padded_tensor = PAD_TYPES[pad_type](tensor, padded_shape) + padded_tensor[: tensor.shape[0]] = tensor + + if batch_sort_indices is not None: + padded_tensor = torch.index_select(padded_tensor, 0, batch_sort_indices) + + return padded_tensor + + reorder_seq_ids = False + pixel_values = args[0] + orig_batch_size = pixel_values.shape[0] + seq_ids_list = list(range(orig_batch_size)) + seq_ids = torch.tensor(seq_ids_list, dtype=torch.int32) + + padded_seq_ids = torch.tensor( + seq_ids_list + + [x for x in range(self.neuron_config.max_batch_size) if x not in seq_ids_list], + dtype=seq_ids.dtype, + ) + padded_seq_ids, indices = torch.sort(padded_seq_ids) if reorder_seq_ids else (padded_seq_ids, None) + + padded_args = [] + # pad pixel_values + for arg in args: + if is_ranked_io(arg): # async output + # ===========READ THIS============= + # args[0] can be either input_ids + # or an async_output. If the output + # is async, it means that the sorting + # and padding has already been done + # properly, so we simply append the + # result. This is true because the + # results from async are fed directly + # to the next iteration without data + # modification, and the model was + # executed with padded & sorted inputs. 
+ # ================================= + padded_args.append(arg) + else: + padded_arg = pad_helper( + arg, + pad_type="repeat_first_batchline", + batch_sort_indices=indices, + ) + padded_args.append(padded_arg) + + outputs = self._forward(*padded_args) + + return outputs[:orig_batch_size] diff --git a/contrib/models/gemma3-vision/src/gemma3_vision/ndxi_patch.py b/contrib/models/gemma3-vision/src/gemma3_vision/ndxi_patch.py new file mode 100644 index 00000000..1620279a --- /dev/null +++ b/contrib/models/gemma3-vision/src/gemma3_vision/ndxi_patch.py @@ -0,0 +1,235 @@ +from typing import Callable, List, Optional, Tuple, Union + +from neuronx_distributed_inference.utils.tensor_replacement.registry import TensorReplacementRegister +import torch +from transformers.modeling_outputs import CausalLMOutputWithPast + + +def patched_get_last_kv_window(window_size, position_ids, latest_k, latest_v, windowed_context_encoding_window_idx=-1, spec_len=0): + """ + Replaces https://github.com/aws-neuron/neuronx-distributed-inference/blob/main/src/neuronx_distributed_inference/modules/attention/utils.py#L634 + to convert the index tensor in torch.gather to a LongTensor. Otherwise, the function will error out. 
+ """ + batch_size, num_head, _, head_dim = latest_k.shape + latest_pos = torch.amax(position_ids, dim=1) + if windowed_context_encoding_window_idx >= 1: # if windowed cte, account for current window offset + latest_pos -= windowed_context_encoding_window_idx * window_size + + # True window size + window_size = window_size - 1 + spec_len - 1 if spec_len > 0 else window_size - 1 + + end_idx = (latest_pos + 1).clamp(min=window_size) + start_idx = (end_idx - window_size).clamp(min=0) + orig_indices = start_idx[:, None] + torch.arange(window_size) + + # Calculate per-batch left shifts + left_shifts = (window_size - (end_idx % window_size)) % window_size + base = torch.arange(window_size).expand(batch_size, window_size) + shifted_idx = (base + left_shifts[:, None]) % window_size + + # Determine per-batch shifted gather indices + gather_idx = torch.gather(orig_indices, dim=1, index=shifted_idx.long()) + gather_idx = gather_idx[:, None, :, None].expand(batch_size, num_head, window_size, head_dim).to(device=latest_k.device) + + # Gather to create non-physically contiguous KV cache + latest_k = torch.gather(latest_k, dim=2, index=gather_idx.long()) + latest_v = torch.gather(latest_v, dim=2, index=gather_idx.long()) + return latest_k, latest_v + + +def patched_base_image_to_text_model_forward( + self, + input_ids: torch.LongTensor = None, + seq_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + sampling_params: Optional[torch.FloatTensor] = None, + prev_hidden: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + adapter_ids: Optional[torch.LongTensor] = None, + medusa_args=None, + return_dict: Optional[bool] = None, + 
llava_args: Optional[List] = [], + input_capture_hook: Optional[Callable] = None, + slot_mapping: Optional[torch.LongTensor] = None, + block_table: Optional[torch.LongTensor] = None, + full_context_lens: Optional[torch.LongTensor] = None, + computed_context_lens: Optional[torch.LongTensor] = None, + vision_embeddings: Optional[torch.FloatTensor] = None, + vision_mask: Optional[torch.BoolTensor] = None, + tensor_capture_hook: Optional[Callable] = None, # Missing argument that triggers a NameError +) -> Union[Tuple, CausalLMOutputWithPast]: + """ + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + """ + # infer attention_mask from position_ids if not provided + if attention_mask is None: + attention_mask = self._infer_attention_mask(position_ids) + + if seq_ids is None: + seq_ids = torch.arange(input_ids.shape[0]) + + self.preprocess_inputs( + input_ids=input_ids, + seq_ids=seq_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + sampling_params=sampling_params, + prev_hidden=prev_hidden, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + adapter_ids=adapter_ids, + medusa_args=medusa_args, + return_dict=return_dict, + llava_args=llava_args, + input_capture_hook=input_capture_hook, + slot_mapping=slot_mapping, + block_table=block_table, + full_context_lens=full_context_lens, + computed_context_lens=computed_context_lens, + ) + + if self.async_mode: + outputs, is_run_on_neuron = self._get_model_outputs_async( + input_ids=input_ids, + attention_mask=attention_mask, + 
position_ids=position_ids, + seq_ids=seq_ids, + sampling_params=sampling_params, + prev_hidden=prev_hidden, + adapter_ids=adapter_ids, + vision_embeddings=vision_embeddings, + vision_mask=vision_mask, + medusa_args=medusa_args, + llava_args=llava_args, + ) + else: + outputs, is_run_on_neuron = self._get_model_outputs( + input_ids, + attention_mask, + position_ids, + seq_ids, + sampling_params, + prev_hidden, + adapter_ids, + vision_embeddings, + vision_mask, + medusa_args, + llava_args, + ) + + generation_model = self.get_generation_model() + if not generation_model.is_neuron(): + self._copy_past_key_values(outputs) + + # Process outputs + constructed_outputs = self._get_constructed_outputs(outputs, is_run_on_neuron) + + # Apply tensor_capture_hook if provided and tensors are captured + if tensor_capture_hook and constructed_outputs.captured_tensors: + # Apply the hook if captured tensors are found + tensor_capture_hook(self, constructed_outputs.captured_tensors) + + return constructed_outputs + + +def patched_hf_adapter_prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + sampling_params=None, + adapter_ids=None, + **kwargs, + ): + # Store KV cache flag before forward pass. 
+ self.prev_kv_cache_populated = self.neuron_model.kv_cache_populated + if self.neuron_model.kv_cache_populated: + input_ids = input_ids[:, -1:] + + accepted_indices = kwargs.get("accepted_indices", None) + current_length = kwargs.get("current_length", None) + medusa_mask = kwargs.get("medusa_mask", None) + scatter_index = kwargs.get("scatter_index", None) + position_ids = kwargs.get("position_ids", None) + input_capture_hook = kwargs.get("input_capture_hook", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + if self.input_start_offsets: + if len(self.input_start_offsets) > 1: + position_ids += torch.tensor(self.input_start_offsets, dtype=position_ids.dtype, device=position_ids.device)[:, None] + else: + position_ids += self.input_start_offsets[0] + for i, offset in enumerate(self.input_start_offsets): + position_ids[i, 0:offset] = torch.arange(offset) + else: + position_ids.masked_fill_(attention_mask == 0, 1) + + if self.neuron_model.kv_cache_populated: + position_ids = torch.amax(position_ids, 1, keepdim=True) + position_ids = position_ids + 1 + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache", False), + "attention_mask": attention_mask, + "medusa_args": (accepted_indices, current_length, medusa_mask, scatter_index), + "sampling_params": sampling_params, + "input_capture_hook": input_capture_hook, + #"tensor_capture_hook": tensor_capture_hook, -> FIX: Otherwise raises a breaking NameError + "adapter_ids": adapter_ids + } + ) + + tf_args = [] + if self.neuron_config.tensor_replacement_config: + if hasattr(self, 
'generation_step'): + self.generation_step += 1 + else: + self.generation_step = 1 + reg = TensorReplacementRegister.get_instance() + tf , masks = reg.step_args(self.generation_step) + tf_args = tf + masks + + # Only add tf_args if not empty + if tf_args: + model_inputs["tf_args"] = tf_args + + # WARNING: This is needed for propagating additional kwargs to the neuron model + additional_kwargs = self.neuron_model.get_required_kwargs() + for arg in additional_kwargs: + model_inputs.update({arg: kwargs.get(arg, None)}) + + return model_inputs + + +def apply_patch() -> None: + import neuronx_distributed_inference.modules.attention.utils as u + u.get_last_kv_window = patched_get_last_kv_window + import neuronx_distributed_inference.models.image_to_text_model_base as mm_base + mm_base.NeuronBaseForImageToText.forward = patched_base_image_to_text_model_forward + import neuronx_distributed_inference.utils.hf_adapter as hf_adapter + hf_adapter.HuggingFaceGenerationAdapter.prepare_inputs_for_generation = patched_hf_adapter_prepare_inputs_for_generation diff --git a/contrib/models/gemma3-vision/src/gemma3_vision/siglip/__init__.py b/contrib/models/gemma3-vision/src/gemma3_vision/siglip/__init__.py new file mode 100644 index 00000000..36cc4b5e --- /dev/null +++ b/contrib/models/gemma3-vision/src/gemma3_vision/siglip/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025 © Amazon.com and Affiliates + +from .modeling_siglip import ( + NeuronSiglipVisionModel, + NeuronSiglipAttention, +) +from .layers import ( + OutputChannelParallelConv2d, +) + +__all__ = [ + "NeuronSiglipVisionModel", + "NeuronSiglipAttention", + "OutputChannelParallelConv2d", +] diff --git a/contrib/models/gemma3-vision/src/gemma3_vision/siglip/layers.py b/contrib/models/gemma3-vision/src/gemma3_vision/siglip/layers.py new file mode 100644 index 00000000..36b02fc1 --- /dev/null +++ b/contrib/models/gemma3-vision/src/gemma3_vision/siglip/layers.py @@ -0,0 +1,321 @@ +import math +from typing import Optional, Tuple, 
Union, Any, Callable + +from neuronx_distributed.parallel_layers.layers import ( + _as_tuple2, + _initialize_affine_weight_neuron, + _initialize_parameter_cpu, + + CONV_KERNEL_OUTPUT_CHANNEL_DIMENSION, + CONV_KERNEL_INPUT_CHANNEL_DIMENSION, + conv2d_with_weight_grad_allreduce + ) +from neuronx_distributed.parallel_layers.mappings import ( + copy_to_tensor_model_parallel_region, + gather_from_tensor_model_parallel_region_with_dim, +) +from neuronx_distributed.parallel_layers.parallel_state import get_tensor_model_parallel_size +from neuronx_distributed.parallel_layers.utils import ( + divide, + get_padding_length, + set_tensor_model_parallel_attributes, +) +import neuronx_distributed.trace.trace as nxd_tracing_utils +import torch +from torch.nn.parameter import Parameter + + +class BaseParallelConv(torch.nn.Module): + + + def set_weight_shape(self) -> None: + if self.partition_dim == CONV_KERNEL_OUTPUT_CHANNEL_DIMENSION: + if self.partition_pad: + self.partition_pad_size = get_padding_length(self.out_channels, self.world_size) + self.out_channels = self.out_channels + self.partition_pad_size + + self.channels_per_partition = divide(self.out_channels, self.world_size) + self.weight_shape = [self.channels_per_partition, self.in_channels, *_as_tuple2(self.kernel_size)] + elif self.partition_dim == CONV_KERNEL_INPUT_CHANNEL_DIMENSION: + if self.partition_pad: + self.partition_pad_size = get_padding_length(self.in_channels, self.world_size) + self.in_channels = self.in_channels + self.partition_pad_size + + self.channels_per_partition = divide(self.in_channels, self.world_size) + self.weight_shape = [self.out_channels, self.channels_per_partition, *_as_tuple2(self.kernel_size)] + else: + assert False, f"Unsupported partition dim: {self.partition_dim}" + + def set_bias_shape(self) -> None: + if self.add_bias: + self.bias_shape = ( + self.channels_per_partition + if self.partition_dim == CONV_KERNEL_OUTPUT_CHANNEL_DIMENSION + else self.out_channels + ) + else: + 
self.bias_shape = None + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]], + stride: Union[int, Tuple[int, int]], + padding: Union[int, Tuple[int, int]], + dilation: Union[int, Tuple[int, int]], + groups: int, + bias: bool, + padding_mode: str, + partition_dim: int, + dtype: torch.dtype, + device: Optional[torch.device] = None, + init_method: Optional[Callable[[Any], torch.Tensor]] = None, + keep_master_params: bool = False, + partition_pad: bool = False, + ): + if not all(d == 1 for d in _as_tuple2(dilation)): + raise NotImplementedError(f"Non-1 dilation is not yet supported. Received: {dilation}") + if groups != 1: + raise NotImplementedError(f"Non-1 groups is not yet supported. Received: {groups}") + if padding_mode != "zeros": + raise NotImplementedError(f"Non-zeros padding is not yet supported. Received: {padding_mode}") + + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.partition_dim = partition_dim + self.arg_init_method = init_method + self.dtype = dtype + self.device = device + self.keep_master_params = keep_master_params + self.partition_pad = partition_pad + self.add_bias = bias + self.world_size = get_tensor_model_parallel_size() + + self.set_weight_shape() + self.set_bias_shape() + + # Get torch init device if device is not explicitly mentioned + init_device = self.device + self.weight = Parameter(torch.empty(*self.weight_shape, device=init_device, dtype=self.dtype)) + self.device = self.weight.device + + if self.device.type == "cpu": + self.master_weight = _initialize_parameter_cpu( + self.weight, + partition_dim=partition_dim, + num_partitions=self.world_size, + init_method=self._init_weight, + return_master_param=self.keep_master_params, + param_dtype=self.dtype, + stride=1, + ) + elif self.device.type == "meta": + set_tensor_model_parallel_attributes( + 
tensor=self.weight, + is_parallel=True, + dim=partition_dim, + stride=1, + num_partitions=self.world_size, + ) + else: + assert device and device.type == "xla", "Currently only xla device type is supported" + _initialize_affine_weight_neuron( + self.weight, + self._init_weight, + partition_dim=partition_dim, + num_partitions=self.world_size, + stride=1, + ) + + if self.add_bias: + # Bias is added before running the all-gather collective + # If conv layer is sharded across output channels (partition_dim == CONV_KERNEL_OUTPUT_CHANNEL_DIMENSION), + # then the bias must be sharded + # 1. We initialize the bias to an empty parameter tensor of shape (C_out,) or (C_out/TP,) + self.bias = Parameter(torch.empty(self.bias_shape, dtype=dtype, device=device)) + + # 2. Parameter initialization + # These parallel layers are used for both training and inference. When training from scratch, weight + # initialization must be carefully done, especially when distributed (e.g. ensure the same seed is used on every rank) + # Such careful initialization is not needed when tracing (device.type == meta) or at inference + if self.device.type == "cpu": + if partition_dim == CONV_KERNEL_OUTPUT_CHANNEL_DIMENSION: + self.master_bias = _initialize_parameter_cpu( + self.bias, + CONV_KERNEL_OUTPUT_CHANNEL_DIMENSION, + num_partitions=self.world_size, + init_method=self._init_bias, + return_master_param=self.keep_master_params, + param_dtype=self.dtype, + stride=1, + ) + else: + self._init_bias(self.bias) + self.master_bias = self.bias if self.keep_master_params else None + elif self.device.type == "meta": + if partition_dim == CONV_KERNEL_OUTPUT_CHANNEL_DIMENSION: + set_tensor_model_parallel_attributes( + self.bias, + is_parallel=True, + dim=self.partition_dim, + stride=1, + num_partitions=self.world_size, + ) + self.master_bias = self.bias if self.keep_master_params else None + else: + assert device and device.type == "xla", "Currently only xla device type is supported" + if partition_dim == 
CONV_KERNEL_OUTPUT_CHANNEL_DIMENSION: + set_tensor_model_parallel_attributes( + self.bias, + is_parallel=True, + dim=self.partition_dim, + stride=1, + num_partitions=self.world_size, + ) + self._init_bias(self.bias) + self.master_bias = self.bias if self.keep_master_params else None + else: + self.register_parameter("bias", None) + + self._forward_impl = conv2d_with_weight_grad_allreduce + + def _init_weight(self, weight): + if self.arg_init_method is None: + torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5)) + else: + self.arg_init_method(weight) + + def _init_bias(self, bias): + fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + torch.nn.init.uniform_(bias, -bound, bound) + + +class OutputChannelParallelConv2d(BaseParallelConv): + """Conv2d layer with parallelism on its output channels + + The definition of a Conv2d layer can be found at https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html + + This layer parallelizes the Conv2d along the output channel dimension + + .. note:: + Input is expected to be four dimensional, in order [N, C, H, W] + + Arguments: + in_channels: Number of input channels + out_channels: Number of output channels in the original Conv that is being parallelized. Parallelization is handled internally by this class + kernel_size: Size of the kernel. Can be a single number for a square kernel or a tuple of two numbers + stride: Stride of the convolution. Can be a single number for uniform H/W stride or a tuple of two numbers + padding: Padding of the convolution. 
Can be a single number for uniform H/W padding or a tuple of two numbers + bias: If true, add bias + gather_output: If true, call all-gather on the output to assemble the partial outputs produced by each Neuron device into the full output, and make the full output available on all Neuron devices + dtype: Datatype of the weights + device: Device on which the weights should be initialized + init_method: Method for initializing the weight + keep_master_weight: If device="cpu", whether to keep the original ("master") weight the per-worker weights are split from + partition_pad: Pad the output channel dimension if needed to make the output channel count divisible by the tensor model parallel size + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]], + stride: Union[int, Tuple[int, int]] = 1, + padding: Union[int, Tuple[int, int]] = 0, + dilation: Union[int, Tuple[int, int]] = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", + gather_output: bool = True, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, + init_method: Optional[Callable[[Any], torch.Tensor]] = None, + keep_master_weight: bool = False, + partition_pad: bool = False, + ): + # Base class expects these all to be tuples so it can support N-dimensional convs + kernel_size = _as_tuple2(kernel_size) + stride = _as_tuple2(stride) + padding = _as_tuple2(padding) + dilation = _as_tuple2(dilation) + + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + padding_mode, + CONV_KERNEL_OUTPUT_CHANNEL_DIMENSION, + dtype, + device, + init_method, + keep_master_weight, + partition_pad, + ) + self.kernel_size: Tuple[int, int] + self.stride: Tuple[int, int] + self.padding: Tuple[int, int] + self.dilation: Tuple[int, int] + + self.allreduce_weight_grad = get_tensor_model_parallel_size() > 1 + self.gather_output = gather_output + + def forward(self, 
in_tensor: torch.Tensor) -> torch.Tensor: + """Forward of OutputChannelParallelConv2d + + Args: + in_tensor: 4D tensor in order [N, C, H ,W] + + Returns: + - output + """ + + if self.allreduce_weight_grad: + input_parallel = in_tensor + else: + input_parallel = copy_to_tensor_model_parallel_region(in_tensor) + + output_parallel = self._forward_impl( + input=input_parallel, + weight=self.weight, + bias=self.bias, + stride=self.stride, + padding=self.padding, + allreduce_weight_grad=self.allreduce_weight_grad, + ) + + # We intentionally did the bias add in _forward_impl to do less work overall + # This way, each worker only has to do 1/world_size of the bias add + if self.gather_output: + # All-gather across the partitions + output = gather_from_tensor_model_parallel_region_with_dim(output_parallel, gather_dim=1) + if self.partition_pad and self.partition_pad_size > 0: + output = torch.narrow(output, 1, 0, self.out_channels - self.partition_pad_size) + else: + output = output_parallel + + return output + + def preshard_hook(self, model_state_dict: dict, prefix: str) -> None: + if not self.partition_pad or self.partition_pad_size == 0: + return + if self.out_channels != model_state_dict[prefix].shape[0] + self.partition_pad_size: + size = model_state_dict[prefix].shape[0] + raise RuntimeError( + f"State dict {prefix} is of an unexpected size {size} expected {size - self.partition_pad_size}" + ) + model_state_dict[prefix] = torch.nn.functional.pad( + model_state_dict[prefix], (0, 0, 0, 0, 0, 0, 0, self.partition_pad_size) + ) + +nxd_tracing_utils.__SUPPORTED_SHARDED_MODULES = nxd_tracing_utils.__SUPPORTED_SHARDED_MODULES + (OutputChannelParallelConv2d, ) diff --git a/contrib/models/gemma3-vision/src/gemma3_vision/siglip/modeling_siglip.py b/contrib/models/gemma3-vision/src/gemma3_vision/siglip/modeling_siglip.py new file mode 100644 index 00000000..e6b46ca6 --- /dev/null +++ b/contrib/models/gemma3-vision/src/gemma3_vision/siglip/modeling_siglip.py @@ -0,0 +1,469 @@ 
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+from torch import Size
+from transformers.activations import ACT2FN
+from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
+from transformers.utils import torch_int
+
+from neuronx_distributed.parallel_layers import parallel_state
+from neuronx_distributed.parallel_layers.layers import ColumnParallelLinear, RowParallelLinear, ParallelEmbedding
+from neuronx_distributed_inference.models.config import NeuronConfig, InferenceConfig
+from neuronx_distributed_inference.modules.attention.attention_base import NeuronAttentionBase
+
+from gemma3_vision.siglip.layers import OutputChannelParallelConv2d
+
+
+class NeuronSiglipConfig(NeuronConfig):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+
+class SiglipInferenceConfig(InferenceConfig):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def get_required_attributes(self) -> List[str]:
+        # Validate that the config.json includes all the configs the model needs. 
+        # Need to manually add what's required in the list below
+        return [
+            "hidden_size",
+            "image_size",
+            "intermediate_size",
+            "model_type",
+            "num_attention_heads",
+            "num_hidden_layers",
+            "patch_size",
+            "vision_use_head",
+        ]
+
+
+class NeuronSiglipAttention(NeuronAttentionBase):
+    def __init__(self, config: SiglipInferenceConfig, tensor_model_parallel_group=None):
+        super().__init__(
+            config=config,
+            hidden_size=config.hidden_size,
+            num_attention_heads=config.num_attention_heads,
+            num_key_value_heads=config.num_attention_heads,  # siglip is MHA, not GQA
+            head_dim=getattr(config, "head_dim", config.hidden_size // config.num_attention_heads),
+            qkv_bias=True,
+            o_bias=True,
+            num_cores_per_group=config.num_cores_per_group,
+            tensor_model_parallel_group=tensor_model_parallel_group,
+        )
+
+
+class NeuronSiglipMLP(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.activation_fn = ACT2FN[config.hidden_act]
+        self.fc1 = ColumnParallelLinear(
+            config.hidden_size, config.intermediate_size, gather_output=False
+        )
+        self.fc2 = RowParallelLinear(
+            config.intermediate_size, config.hidden_size, input_is_parallel=True
+        )
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.fc1(hidden_states)
+        hidden_states = self.activation_fn(hidden_states)
+        hidden_states = self.fc2(hidden_states)
+        return hidden_states
+
+_shape_t = Union[int, List[int], Size]
+
+class LayerNorm(torch.nn.LayerNorm):
+    """
+    Compared to NxD's LayerNorm, explicitly casts the input to the weight dtype to preserve numerical accuracy
+    """
+    def __init__(
+        self,
+        normalized_shape: _shape_t,
+        eps: float = 1e-5,
+        elementwise_affine: bool = True,
+        bias: bool = True,
+        device=None,
+        dtype=None,
+    ):
+        self.dtype = dtype
+        super().__init__(
+            normalized_shape=normalized_shape,
+            eps=eps,
+            elementwise_affine=elementwise_affine,
+            bias=bias,
+            device=device,
+            dtype=dtype,
+        )
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        # 
Ensure input matches the weight dtype to avoid mixed dtype errors + input = input.to(self.weight.dtype) + output = super().forward(input) + return output + + +class NeuronSiglipEncoderLayer(nn.Module): + def __init__(self, config: InferenceConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.layer_norm1 = LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.self_attn = NeuronSiglipAttention(config) + self.layer_norm2 = LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = NeuronSiglipMLP(config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.tensor, + ) -> torch.FloatTensor: + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + ).hidden_states + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + return outputs + + +class NeuronSiglipEncoder(nn.Module): + def __init__(self, config: InferenceConfig): + super().__init__() + self.config = config + self.layers = nn.ModuleList( + [NeuronSiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)] + ) + self.gradient_checkpointing = False + + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + output_attentions = ( + output_attentions if output_attentions is not None else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + 
encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for encoder_layer in self.layers: + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(encoder_layer), + hidden_states, + attention_mask, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +class NeuronSiglipMultiheadAttention(NeuronSiglipAttention): + """ + Compared to NeuronSiglipAttention: + 1. 
Accept three inputs (Query, Key, Value) instead of a single hidden states + """ + def __init__(self, config: InferenceConfig): + super().__init__(config=config) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = True, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Input shape: Batch x Time x Channel""" + + bsz, tgt_len, embed_dim = query.size() + + # get query proj + qkv_proj = self.get_qkv_proj() + query_states = qkv_proj.q_proj(query) * self.scale + key_states = self._shape(self.k_proj(key), -1, bsz) + value_states = self._shape(self.v_proj(value), -1, bsz) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if output_attentions: + # this operation is a bit akward, but it's required to + # make sure that attn_weights keeps its gradient. 
+            # In order to do so, attn_weights has to be reshaped
+            # twice and reused in the following
+            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
+        else:
+            attn_weights_reshaped = None
+
+        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+        attn_output = torch.bmm(attn_probs, value_states)
+
+        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+
+        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(bsz, tgt_len, -1)
+
+        attn_output = self.out_proj(attn_output)
+
+        return attn_output, attn_weights_reshaped
+
+
+class NeuronSiglipMultiheadAttentionPoolingHead(nn.Module):
+    def __init__(self, config: InferenceConfig):
+        super().__init__()
+
+        self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
+        self.attention = NeuronSiglipMultiheadAttention(config)
+        self.layernorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.mlp = NeuronSiglipMLP(config)
+
+    def forward(self, hidden_state):
+        batch_size = hidden_state.shape[0]
+        probe = self.probe.repeat(batch_size, 1, 1)
+
+        hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
+
+        residual = hidden_state
+        hidden_state = self.layernorm(hidden_state)
+        hidden_state = residual + self.mlp(hidden_state)
+
+        return hidden_state[:, 0]
+
+
+class NeuronSiglipVisionEmbeddings(nn.Module):
+    def __init__(self, config: InferenceConfig):
+        super().__init__()
+        self.config = config
+        self.embed_dim = config.hidden_size
+        self.image_size = config.image_size
+        self.patch_size = config.patch_size
+        self.num_patches = (self.image_size // self.patch_size) ** 2
+        self.num_positions = self.num_patches
+
+        if parallel_state.model_parallel_is_initialized():
+            self.patch_embedding = OutputChannelParallelConv2d(
+                in_channels=config.num_channels,
+                out_channels=self.embed_dim,
+                kernel_size=self.patch_size,
+                stride=self.patch_size,
+                padding=0,  # padding="valid" in nn.Conv2d
+                partition_pad=True,
+            )
+
+            self.position_embedding = ParallelEmbedding(
+                self.num_positions,
+                self.embed_dim,
+                shard_across_embedding=True,
+                pad=True,
+            )
+
+        else:
+            self.patch_embedding = nn.Conv2d(
+                in_channels=config.num_channels,
+                out_channels=self.embed_dim,
+                kernel_size=self.patch_size,
+                stride=self.patch_size,
+                padding="valid",
+            )
+            self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+
+        self.register_buffer(
+            "position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False
+        )
+
+    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+        """
+        Interpolates the pre-trained position encodings so that the model can be used on higher-resolution
+        images. It is also adapted to support torch.jit tracing and models without class embeddings.
+ + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 + """ + + num_patches = embeddings.shape[1] + num_positions = self.position_embedding.weight.shape[0] + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: + return self.position_embedding(self.position_ids) + + patch_pos_embed = self.position_embedding.weight.unsqueeze(0) + + dim = embeddings.shape[-1] + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(new_height, new_width), + mode="bicubic", + align_corners=False, + ) + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return patch_pos_embed + + def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor: + _, _, height, width = pixel_values.shape + target_dtype = self.patch_embedding.weight.dtype + # Convert pixel_values to target dtype before passing to patch_embedding to avoid mixed dtype errors + pixel_values_converted = pixel_values.to(dtype=target_dtype) + patch_embeds = self.patch_embedding(pixel_values_converted) # shape = [*, width, grid, grid] + embeddings = patch_embeds.flatten(2).transpose(1, 2) + + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + # Ensure position embeddings match the dtype of embeddings + pos_emb = 
self.position_embedding(self.position_ids) + embeddings = embeddings + pos_emb.to(dtype=embeddings.dtype) + return embeddings + + +class NeuronSiglipVisionTransformer(nn.Module): + def __init__(self, config: InferenceConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + + self.embeddings = NeuronSiglipVisionEmbeddings(config) + self.encoder = NeuronSiglipEncoder(config) + self.post_layernorm = LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.use_head = True if not hasattr(config, "vision_use_head") else config.vision_use_head + if self.use_head: + self.head = NeuronSiglipMultiheadAttentionPoolingHead(config) + + def forward( + self, + pixel_values, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = False, + ) -> BaseModelOutputWithPooling: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + last_hidden_state = encoder_outputs.last_hidden_state + last_hidden_state = self.post_layernorm(last_hidden_state) + + pooler_output = self.head(last_hidden_state) if self.use_head else None + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooler_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +class NeuronSiglipVisionModel(nn.Module): + def __init__(self, config: InferenceConfig): + super().__init__() + self.vision_model = NeuronSiglipVisionTransformer(config) + + def get_input_embeddings(self) -> nn.Module: + return 
self.vision_model.embeddings.patch_embedding + + def forward( + self, + pixel_values, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + ): + return self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + ) diff --git a/contrib/models/gemma3-vision/src/gemma3_vision/utils.py b/contrib/models/gemma3-vision/src/gemma3_vision/utils.py new file mode 100644 index 00000000..69192ec0 --- /dev/null +++ b/contrib/models/gemma3-vision/src/gemma3_vision/utils.py @@ -0,0 +1,52 @@ +from collections import OrderedDict +import gc + +import torch +from neuronx_distributed_inference.models.config import NeuronConfig + + +StateDict = OrderedDict[str, torch.FloatTensor] + + +def _helper_concat_and_delete_qkv(state_dict: StateDict, prefix: str, attr: str) -> None: + full_state_key_q_proj = f"{prefix}.qkv_proj.q_proj.{attr}" + full_state_key_k_proj = f"{prefix}.qkv_proj.k_proj.{attr}" + full_state_key_v_proj = f"{prefix}.qkv_proj.v_proj.{attr}" + + if ( + full_state_key_q_proj in state_dict + and full_state_key_k_proj in state_dict + and full_state_key_v_proj in state_dict + ): + state_dict[f"{prefix}.qkv_proj.Wqkv.{attr}"] = torch.cat( + [ + state_dict[full_state_key_q_proj], + state_dict[full_state_key_k_proj], + state_dict[full_state_key_v_proj], + ], + dim=0 + ) + del state_dict[full_state_key_q_proj] + del state_dict[full_state_key_k_proj] + del state_dict[full_state_key_v_proj] + + +def convert_state_dict_to_fused_qkv( + state_dict: StateDict, + num_layers: int, + neuron_config: NeuronConfig, + prefix: str + ) -> StateDict: + for layer_num in range(num_layers): + layer_prefix = prefix.format(layer_num=layer_num) + _helper_concat_and_delete_qkv(state_dict, layer_prefix, "weight") + _helper_concat_and_delete_qkv(state_dict, layer_prefix, "bias") + is_qkv_quantized = ( + 
(neuron_config.quantized_mlp_kernel_enabled or neuron_config.quantized) and \ + f"{layer_prefix}.qkv_proj.q_proj.scale" in state_dict + ) + if is_qkv_quantized: + _helper_concat_and_delete_qkv(state_dict, layer_prefix, "scale") + + gc.collect() + return state_dict diff --git a/contrib/models/gemma3-vision/test/__init__.py b/contrib/models/gemma3-vision/test/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/gemma3-vision/test/assets/gemma3_27b_config.json b/contrib/models/gemma3-vision/test/assets/gemma3_27b_config.json new file mode 100644 index 00000000..6b0f3036 --- /dev/null +++ b/contrib/models/gemma3-vision/test/assets/gemma3_27b_config.json @@ -0,0 +1,123 @@ +{ + "architectures": [ + "Gemma3ForConditionalGeneration" + ], + "boi_token_index": 255999, + "dtype": "bfloat16", + "eoi_token_index": 256000, + "eos_token_id": [ + 1, + 106 + ], + "image_token_index": 262144, + "initializer_range": 0.02, + "mm_tokens_per_image": 256, + "model_type": "gemma3", + "text_config": { + "_sliding_window_pattern": 6, + "attention_bias": false, + "attention_dropout": 0.0, + "attn_logit_softcapping": null, + "final_logit_softcapping": null, + "head_dim": 128, + "hidden_activation": "gelu_pytorch_tanh", + "hidden_size": 5376, + "initializer_range": 0.02, + "intermediate_size": 21504, + "layer_types": [ + "sliding_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "full_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "full_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "full_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "full_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", 
+ "full_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "full_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "full_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "full_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "full_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "full_attention", + "sliding_attention", + "sliding_attention" + ], + "max_position_embeddings": 131072, + "model_type": "gemma3_text", + "num_attention_heads": 32, + "num_hidden_layers": 62, + "num_key_value_heads": 16, + "query_pre_attn_scalar": 168, + "rms_norm_eps": 1e-06, + "rope_local_base_freq": 10000.0, + "rope_scaling": { + "factor": 8.0, + "rope_type": "linear" + }, + "rope_theta": 1000000.0, + "sliding_window": 1024, + "use_cache": true, + "vocab_size": 262208 + }, + "transformers_version": "4.56.2", + "vision_config": { + "attention_dropout": 0.0, + "hidden_act": "gelu_pytorch_tanh", + "hidden_size": 1152, + "image_size": 896, + "intermediate_size": 4304, + "layer_norm_eps": 1e-06, + "model_type": "siglip_vision_model", + "num_attention_heads": 16, + "num_channels": 3, + "num_hidden_layers": 27, + "patch_size": 14, + "vision_use_head": false + } +} diff --git a/contrib/models/gemma3-vision/test/conftest.py b/contrib/models/gemma3-vision/test/conftest.py new file mode 100644 index 00000000..801c08e6 --- /dev/null +++ b/contrib/models/gemma3-vision/test/conftest.py @@ -0,0 +1,40 @@ + +import random +from pathlib import Path +import tempfile + +import pytest +import torch +import torch_xla.core.xla_model as xm +from transformers import Gemma3Config +from 
transformers.models.gemma3.configuration_gemma3 import Gemma3TextConfig +from neuronx_distributed_inference.utils.random import set_random_seed + + +@pytest.fixture +def base_compiler_flags(): + return [ + "--framework=XLA", + ] + + +@pytest.fixture(scope="session") +def random_seed(): + seed = 42 + set_random_seed(seed) + xm.set_rng_state(seed) + torch.manual_seed(seed) + random.seed(seed) + + +@pytest.fixture(scope="session") +def hf_config(): + return Gemma3Config.from_pretrained((Path(__file__).parent / "assets" / "gemma3_27b_config.json")) + + +@pytest.fixture +def tmp_dir_path(): + tmp_dir = tempfile.TemporaryDirectory() + tmp_dir_path = Path(tmp_dir.name) + yield tmp_dir_path + tmp_dir.cleanup() diff --git a/contrib/models/gemma3-vision/test/integration/__init__.py b/contrib/models/gemma3-vision/test/integration/__init__.py new file mode 100644 index 00000000..b960c546 --- /dev/null +++ b/contrib/models/gemma3-vision/test/integration/__init__.py @@ -0,0 +1 @@ +# Integration tests for Gemma3 Vision diff --git a/contrib/models/gemma3-vision/test/integration/config_gemma3_4layers.json b/contrib/models/gemma3-vision/test/integration/config_gemma3_4layers.json new file mode 100644 index 00000000..e9e5ed9c --- /dev/null +++ b/contrib/models/gemma3-vision/test/integration/config_gemma3_4layers.json @@ -0,0 +1,123 @@ +{ + "architectures": [ + "Gemma3ForConditionalGeneration" + ], + "boi_token_index": 255999, + "dtype": "bfloat16", + "eoi_token_index": 256000, + "eos_token_id": [ + 1, + 106 + ], + "image_token_index": 262144, + "initializer_range": 0.02, + "mm_tokens_per_image": 256, + "model_type": "gemma3", + "text_config": { + "_sliding_window_pattern": 6, + "attention_bias": false, + "attention_dropout": 0.0, + "attn_logit_softcapping": null, + "final_logit_softcapping": null, + "head_dim": 128, + "hidden_activation": "gelu_pytorch_tanh", + "hidden_size": 5376, + "initializer_range": 0.02, + "intermediate_size": 21504, + "layer_types": [ + "sliding_attention", + 
"sliding_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "full_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "full_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "full_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "full_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "full_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "full_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "full_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "full_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "full_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "sliding_attention", + "full_attention", + "sliding_attention", + "sliding_attention" + ], + "max_position_embeddings": 131072, + "model_type": "gemma3_text", + "num_attention_heads": 32, + "num_hidden_layers": 4, + "num_key_value_heads": 16, + "query_pre_attn_scalar": 168, + "rms_norm_eps": 1e-06, + "rope_local_base_freq": 10000.0, + "rope_scaling": { + "factor": 8.0, + "rope_type": "linear" + }, + "rope_theta": 1000000.0, + "sliding_window": 1024, + "use_cache": true, + "vocab_size": 262208 + }, + "transformers_version": "4.56.2", + "vision_config": { + "attention_dropout": 0.0, + "hidden_act": "gelu_pytorch_tanh", + "hidden_size": 1152, + "image_size": 896, + "intermediate_size": 4304, + 
"layer_norm_eps": 1e-06, + "model_type": "siglip_vision_model", + "num_attention_heads": 16, + "num_channels": 3, + "num_hidden_layers": 4, + "patch_size": 14, + "vision_use_head": false + } +} diff --git a/contrib/models/gemma3-vision/test/integration/run_gemma3.py b/contrib/models/gemma3-vision/test/integration/run_gemma3.py new file mode 100644 index 00000000..edd44f87 --- /dev/null +++ b/contrib/models/gemma3-vision/test/integration/run_gemma3.py @@ -0,0 +1,328 @@ +# Copyright 2025 © Amazon.com and Affiliates: This deliverable is considered Developed Content as defined in the AWS Service Terms. + +from gemma3_vision.ndxi_patch import apply_patch +apply_patch() + +import logging # noqa: E402 +import os # noqa: E402 +from pathlib import Path # noqa: E402 + +import torch +from transformers import AutoTokenizer, AutoProcessor, GenerationConfig +from transformers.models.gemma3.configuration_gemma3 import Gemma3TextConfig + +from neuronx_distributed_inference.models.config import NeuronConfig, OnDeviceSamplingConfig +from neuronx_distributed_inference.models.llama4.utils.input_processor import ( + prepare_generation_inputs_hf +) +from neuronx_distributed_inference.modules.generation.sampling import prepare_sampling_params +from neuronx_distributed_inference.utils.hf_adapter import ( + load_pretrained_config, + HuggingFaceGenerationAdapter +) + +from gemma3_vision.modeling_gemma3 import NeuronGemma3ForConditionalGeneration, Gemma3InferenceConfig + + +# Configure logging +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + +# Setting paths +BASE_PATH = os.getenv('PROJECT_HOME', '/home/ubuntu/nxdi-gemma3-contribution') +DATA_PATH = os.getenv('DATA_HOME', '/home/ubuntu') + +# Model configuration constants +CONFIG = { + 'TEXT_TP_DEGREE': 8, + 'VISION_TP_DEGREE': 8, + 'WORLD_SIZE': 8, + 'BATCH_SIZE': 1, + 'SEQ_LENGTH': 1024, + 'CTX_BUCKETS': [1024], # Set to a single bucket or powers of two between 128 and the SEQ_LENGTH. 
+    'TKG_BUCKETS': [1024],  # Set to a single bucket or powers of two between 128 and the SEQ_LENGTH.
+    'DTYPE': torch.bfloat16,
+    'MODEL_PATH': f"{DATA_PATH}/models/gemma-3-27b-it",
+    'TRACED_MODEL_PATH': f"{DATA_PATH}/traced_model/gemma-3-27b-it-small",
+    'IMAGE_PATH': f"{BASE_PATH}/dog.jpg",
+    'MAX_NEW_TOKENS': 100,
+    # Optimizations
+    'QUANTIZED': False,
+    'QUANTIZED_CHECKPOINTS_PATH': None,  # path to pre-quantized model state dict OR path to save quantized model state_dict
+    'ATTN_KERNEL_ENABLED': True,
+    'VISION_ATTN_KERNEL_ENABLED': True,
+    'ATTN_TKG_NKI_KERNEL_ENABLED': False,
+    'FUSED_QKV': True,
+    'VISION_FUSED_QKV': False,
+    'ASYNC_MODE': True,
+    'OUTPUT_LOGITS': True,
+    'ON_DEVICE_SAMPLING': OnDeviceSamplingConfig(
+        dynamic=True,  # Allow per-request sampling config
+        do_sample=True,
+        deterministic=True,
+        temperature=1.0,
+        top_p=1.0,
+        top_k=32,
+        global_topk=256,
+        top_k_kernel_enabled=True,
+    ),
+}
+
+# attn_tkg_nki_kernel_enabled fails if TP != 16
+if CONFIG['TEXT_TP_DEGREE'] != 16:
+    CONFIG['ATTN_TKG_NKI_KERNEL_ENABLED'] = False
+# validate and configure settings for quantized models
+if CONFIG['QUANTIZED']:
+    os.environ['XLA_HANDLE_SPECIAL_SCALAR'] = "1"
+    os.environ['UNSAFE_FP8FNCAST'] = "1"
+    assert CONFIG['QUANTIZED_CHECKPOINTS_PATH'] is not None, (
+        "Quantized checkpoints path must be provided for quantized model"
+    )
+# validate bucket lengths
+assert CONFIG['SEQ_LENGTH'] == max(CONFIG['CTX_BUCKETS']), (
+    f"The largest context bucket {max(CONFIG['CTX_BUCKETS'])} must equal SEQ_LENGTH {CONFIG['SEQ_LENGTH']}"
+)
+assert CONFIG['SEQ_LENGTH'] == max(CONFIG['TKG_BUCKETS']), (
+    f"The largest token generation bucket {max(CONFIG['TKG_BUCKETS'])} must equal SEQ_LENGTH {CONFIG['SEQ_LENGTH']}"
+)
+
+# Environment setup
+os.environ['NEURON_PLATFORM_TARGET_OVERRIDE'] = 'trn1'
+os.environ['NEURON_RT_STOCHASTIC_ROUNDING_EN'] = '0'
+
+torch.manual_seed(0)
+
+def create_neuron_configs():
+    """Create text and vision neuron configurations."""
+    hf_config = Gemma3TextConfig.from_pretrained(CONFIG['MODEL_PATH'])  # nosec B615
+
+    text_config = NeuronConfig(
+
+        ## Basic configs ##
+        batch_size=CONFIG['BATCH_SIZE'],
+        seq_len=CONFIG['SEQ_LENGTH'],  # max input+output length
+        torch_dtype=CONFIG['DTYPE'],
+        # cast_type="as-declared",  # comment out if optimizing for latency. uncomment if optimizing for accuracy
+
+        ## Compiler configs ##
+        cc_pipeline_tiling_factor=1,
+        logical_nc_config=1,
+
+        ## Distributed configs ##
+        tp_degree=CONFIG['TEXT_TP_DEGREE'],
+        cp_degree=1,
+        # rpl_reduce_dtype=torch.float32,  # comment out if optimizing for latency. uncomment if optimizing for accuracy
+        save_sharded_checkpoint=True,
+        skip_sharding=False,
+
+        ## Continuous batching ##
+        is_continuous_batching=True,  # set to true for vLLM integration
+        ctx_batch_size=1,  # set to 1 for vLLM integration
+
+        ## Bucketing ##
+        enable_bucketing=True,
+        context_encoding_buckets=CONFIG['CTX_BUCKETS'],
+        token_generation_buckets=CONFIG['TKG_BUCKETS'],
+
+        ## Optimizations ##
+        async_mode=CONFIG['ASYNC_MODE'],
+        on_device_sampling_config=CONFIG['ON_DEVICE_SAMPLING'],
+        output_logits=CONFIG['OUTPUT_LOGITS'],  # logits are not returned by default with on-device sampling; set to true to return them
+        fused_qkv=CONFIG['FUSED_QKV'],
+        sequence_parallel_enabled=False,  # always set to false. has meaningful impacts for long-context use cases only
+
+        ## Kernels for Optimization ##
+        attn_kernel_enabled=CONFIG['ATTN_KERNEL_ENABLED'],  # attn kernels for context_encoding
+        attn_tkg_nki_kernel_enabled=CONFIG['ATTN_TKG_NKI_KERNEL_ENABLED'],  # attn kernels for token generation
+        attn_tkg_builtin_kernel_enabled=False,  # always set to false. incompatible with gemma3.
+        qkv_kernel_enabled=False,  # QKV kernels. always set to false. incompatible with gemma3.
+        mlp_kernel_enabled=False,  # MLP kernels. always set to false. incompatible with gemma3.
+
+        ## Quantization ##
+        quantized=CONFIG['QUANTIZED'],
+        quantized_checkpoints_path=CONFIG['QUANTIZED_CHECKPOINTS_PATH'],
+        quantization_type="per_channel_symmetric",
+        quantization_dtype="f8e4m3",
+        modules_to_not_convert=[
+            # Targeted at NeuronApplicationBase.generate_quantized_state_dict which works on the HF state dict
+            # The following patterns must match keys in the HF state dict.
+            "multi_modal_projector",
+            "vision_tower",
+            *[f"language_model.model.layers.{layer_idx}.self_attn" for layer_idx in range(hf_config.num_hidden_layers)],
+            "language_model.lm_head",
+            # Targeted at DecoderModelInstance.load_module which dynamically replaces [Row|Column]ParallelLinear
+            # layers with Quantized[Row|Column]Parallel layers.
+            # The following patterns must match keys in the Neuron state dict of NeuronGemma3[Text|Vision]Model
+            *[f"layers.{layer_idx}.self_attn" for layer_idx in range(hf_config.num_hidden_layers)],
+            "lm_head",
+        ],
+        kv_cache_quant=False,
+        quantized_mlp_kernel_enabled=False,
+    )
+
+    vision_config = NeuronConfig(
+
+        ## Basic configs ##
+        batch_size=CONFIG['BATCH_SIZE'] * 2,
+        seq_len=CONFIG['SEQ_LENGTH'],
+        torch_dtype=CONFIG['DTYPE'],
+        # cast_type="as-declared",  # comment out if optimizing for latency. uncomment if optimizing for accuracy
+
+        ## Compiler configs ##
+        cc_pipeline_tiling_factor=1,
+        logical_nc_config=1,
+
+        ## Distributed configs ##
+        tp_degree=CONFIG['VISION_TP_DEGREE'],
+        world_size=CONFIG['WORLD_SIZE'],
+        # rpl_reduce_dtype=torch.float32,  # comment out if optimizing for latency. uncomment if optimizing for accuracy
+        save_sharded_checkpoint=True,
+
+        ## Continuous batching ##
+        is_continuous_batching=True,  # set to true for vLLM integration
+        ctx_batch_size=1,  # set to 1 for vLLM integration
+
+        ## Bucketing ##
+        enable_bucketing=True,
+        buckets=[1],
+
+        ## Optimizations ##
+        fused_qkv=CONFIG['VISION_FUSED_QKV'],
+
+        ## Kernels for Optimization ##
+        attn_kernel_enabled=CONFIG['VISION_ATTN_KERNEL_ENABLED'],  # attn kernels for context_encoding
+        qkv_kernel_enabled=False,  # QKV kernels. always set to false. incompatible with gemma3.
+        mlp_kernel_enabled=False,  # MLP kernels. always set to false. incompatible with gemma3.
+    )
+
+    return text_config, vision_config
+
+
+def setup_model_and_tokenizer():
+    """Initialize model configuration, tokenizer, and processor."""
+    text_config, vision_config = create_neuron_configs()
+
+    config = Gemma3InferenceConfig(
+        text_neuron_config=text_config,
+        vision_neuron_config=vision_config,
+        load_config=load_pretrained_config(CONFIG['MODEL_PATH']),
+    )
+    # NOTE: overrides both towers to a single layer so the trace stays small (see the
+    # "-small" TRACED_MODEL_PATH); remove these two lines to run the full model.
+    config.vision_config.num_hidden_layers = 1
+    config.text_config.num_hidden_layers = 1
+    tokenizer = AutoTokenizer.from_pretrained(CONFIG['MODEL_PATH'], padding_side="right")  # nosec B615
+    tokenizer.pad_token = tokenizer.eos_token
+    processor = AutoProcessor.from_pretrained(CONFIG['MODEL_PATH'])  # nosec B615
+
+    return config, tokenizer, processor
+
+
+def compile_or_load_model(config, tokenizer):
+    """Compile model if needed, otherwise load from checkpoint."""
+    if not os.path.exists(CONFIG['TRACED_MODEL_PATH']):
+        if config.neuron_config.quantized and config.neuron_config.save_sharded_checkpoint:
+            quantized_state_dict_path = Path(config.neuron_config.quantized_checkpoints_path)
+            quantized_sd_available = quantized_state_dict_path.exists()
+            if not quantized_sd_available:
+                # No pre-quantized weights found: quantize at compile time and save
+                # the state dict here (the directory is created if missing).
+ print("\nQuantizing and saving model weights...") + quantized_state_dict_path.mkdir(parents=True, exist_ok=True) + NeuronGemma3ForConditionalGeneration.save_quantized_state_dict(CONFIG['MODEL_PATH'], config) + print("\nCompiling and saving model...") + model = NeuronGemma3ForConditionalGeneration(CONFIG['MODEL_PATH'], config) + model.compile(CONFIG['TRACED_MODEL_PATH'], debug=True) + tokenizer.save_pretrained(CONFIG['TRACED_MODEL_PATH']) + + print("\nLoading model from compiled checkpoint...") + model = NeuronGemma3ForConditionalGeneration(CONFIG['TRACED_MODEL_PATH']) + model.load(CONFIG['TRACED_MODEL_PATH'], skip_warmup=True) + tokenizer = AutoTokenizer.from_pretrained(CONFIG['TRACED_MODEL_PATH']) # nosec B615 + + return model, tokenizer + + +def generate_outputs(model, tokenizer, input_ids, attention_mask, pixel_values=None, vision_mask=None, max_new_tokens=50): + """Generate text using the model.""" + generation_model = HuggingFaceGenerationAdapter(model) + generation_config = GenerationConfig.from_pretrained(CONFIG['MODEL_PATH']) # nosec B615 + sampling_params = prepare_sampling_params(batch_size=CONFIG['BATCH_SIZE'], top_k=[1], top_p=[1.0], temperature=[0.0]) + + return_dict_in_generate = False + + generation_config.update(**{ + "do_sample": True, + "output_scores": False, # Post-processed logits + "output_logits": False, # Raw logits + "return_dict_in_generate": return_dict_in_generate, + }) + + outputs = generation_model.generate( + input_ids, + generation_config=generation_config, + attention_mask=attention_mask, + max_length=model.config.neuron_config.max_length, + sampling_params=sampling_params, + pixel_values=pixel_values, + vision_mask=vision_mask.to(torch.bool) if vision_mask is not None else None, + max_new_tokens=max_new_tokens, + return_dict_in_generate=return_dict_in_generate, + output_scores=False, + ) + + output_sequences = outputs.sequences if return_dict_in_generate else outputs + + output_tokens = tokenizer.batch_decode(output_sequences, 
skip_special_tokens=True, clean_up_tokenization_spaces=False) + return outputs, output_tokens + + +def run_gemma3_generate_image_to_text(run_test_inference=False, run_benchmark=False): + """Main function to run Gemma3 text and image generation.""" + # Setup + config, tokenizer, processor = setup_model_and_tokenizer() + model, tokenizer = compile_or_load_model(config, tokenizer) + + if run_test_inference: + print("Running output check...") + + # Test 1: Text + Image generation + print("\n=== Text + Image Generation ===") + text_prompt = "Describe what you see in the following image(s)" + + input_ids, attention_mask, pixel_values, vision_mask = prepare_generation_inputs_hf( + text_prompt, [CONFIG['IMAGE_PATH'], CONFIG['IMAGE_PATH']], processor, 'user', config + ) + + if CONFIG['BATCH_SIZE'] > 1: + input_ids = input_ids.repeat(CONFIG['BATCH_SIZE'], 1) + attention_mask = attention_mask.repeat(CONFIG['BATCH_SIZE'], 1) + pixel_values = pixel_values.repeat(CONFIG['BATCH_SIZE'], 1, 1, 1) + vision_mask = vision_mask.repeat(CONFIG['BATCH_SIZE'], 1, 1) + + outputs, output_tokens = generate_outputs( + model, tokenizer, input_ids, attention_mask, pixel_values, vision_mask, max_new_tokens=CONFIG['MAX_NEW_TOKENS'] + ) + + for i, output_token in enumerate(output_tokens): + print(f"Output {i}: {output_token}") + + + print("\n=== Text-Only Generation ===") + text_prompt = "What is the recipe of mayonnaise in two sentences?" 
+
+    input_ids, attention_mask, _, _ = prepare_generation_inputs_hf(
+        text_prompt, None, processor, 'user'
+    )
+
+    if CONFIG['BATCH_SIZE'] > 1:
+        input_ids = input_ids.repeat(CONFIG['BATCH_SIZE'], 1)
+        attention_mask = attention_mask.repeat(CONFIG['BATCH_SIZE'], 1)
+
+    outputs, output_tokens = generate_outputs(
+        model, tokenizer, input_ids, attention_mask, max_new_tokens=CONFIG['MAX_NEW_TOKENS']
+    )
+
+    for i, output_token in enumerate(output_tokens):
+        print(f"Output {i}: {output_token}")
+
+
+if __name__ == "__main__":
+    run_gemma3_generate_image_to_text(run_test_inference=True, run_benchmark=False)
diff --git a/contrib/models/gemma3-vision/test/integration/test_model.py b/contrib/models/gemma3-vision/test/integration/test_model.py
new file mode 100644
index 00000000..34177ff6
--- /dev/null
+++ b/contrib/models/gemma3-vision/test/integration/test_model.py
@@ -0,0 +1,171 @@
+from gemma3_vision.ndxi_patch import apply_patch
+apply_patch()
+
+import os  # noqa: E402
+from pathlib import Path  # noqa: E402
+from typing import Dict  # noqa: E402
+
+import pytest
+import torch
+
+from neuronx_distributed_inference.utils.accuracy import (
+    generate_expected_logits,
+    check_accuracy_logits_v2,
+)
+from neuronx_distributed_inference.utils.benchmark import benchmark_sampling
+
+from gemma3_vision.modeling_gemma3 import NeuronGemma3ForConditionalGeneration
+from .utils import (
+    get_test_name_suffix,
+    save_hf_checkpoint,
+    create_neuron_config,
+    create_generation_config,
+    prepare_inputs,
+)
+
+
+NUM_TOKENS_TO_CHECK = 16
+LNC = int(os.environ.get("NEURON_LOGICAL_NC_CONFIG", "1"))
+
+
+@pytest.mark.parametrize(
+    "config_file_path,tp_degree,torch_dtype,batch_size,num_images_per_sample,total_max_seq_len,token_divergence_atol,perf_thresholds",
+    [
+        (
+            Path(__file__).resolve().parent / "config_gemma3_4layers.json",
+            8,
+            torch.float16,
+            1,
+            1,
+            1024,
+            0.02,
+            {
+                "text_cte_p50_latency": 20.55,
+                "text_cte_throughput": 49807.3,
+                "tkg_p50_latency": 4.42,
+                "tkg_throughput": 226.4,
+            },
+        ),
+    ]
+)
+def test_original_cpu_vs_nxdi_neuron(
+    tmp_path: Path,
+    config_file_path: Path,
+    tp_degree: int,
+    torch_dtype: torch.dtype,
+    batch_size: int,
+    num_images_per_sample: int,
+    total_max_seq_len: int,
+    token_divergence_atol: float,
+    perf_thresholds: Dict[str, float],
+) -> None:
+    suffix = get_test_name_suffix(
+        tp_degree=tp_degree,
+        torch_dtype=torch_dtype,
+        batch_size=batch_size,
+        num_images_per_sample=num_images_per_sample,
+        max_seq_len=total_max_seq_len
+    )
+
+    nrn_config = create_neuron_config(
+        hf_config_path=config_file_path,
+        text_batch_size=batch_size,
+        vision_batch_size=(num_images_per_sample * batch_size),
+        total_max_seq_len=total_max_seq_len,
+        torch_dtype=torch_dtype,
+        lnc=LNC,
+        tp_degree=tp_degree
+    )
+
+    input_ids, attention_mask, pixel_values, vision_mask = prepare_inputs(
+        nrn_config=nrn_config,
+        torch_dtype=torch_dtype
+    )
+
+    generation_config = create_generation_config(nrn_config=nrn_config)
+
+    save_hf_checkpoint(
+        output_dir_path=tmp_path,
+        config_file_path=config_file_path,
+        torch_dtype=torch_dtype,
+    )
+
+    nrn_config._name_or_path = tmp_path.as_posix()
+    nrn_model = NeuronGemma3ForConditionalGeneration(model_path=tmp_path, config=nrn_config)
+
+    traced_model_path = tmp_path / ("traced_model" + suffix)
+    traced_model_path.mkdir(exist_ok=True)
+
+    nrn_model.compile(traced_model_path.as_posix())
+
+    nrn_model.load(traced_model_path.as_posix())
+
+    benchmark_report = benchmark_sampling(
+        model=nrn_model,
+        generation_config=generation_config,
+        image=False,  # image=True currently broken (Neuron 2.27.1)
+        benchmark_report_path=f"./benchmark_report{suffix}.json"
+    )
+
+    assert benchmark_report["context_encoding_model"]["latency_ms_p50"] < perf_thresholds["text_cte_p50_latency"] * 1.1
+    assert benchmark_report["context_encoding_model"]["throughput"] > perf_thresholds["text_cte_throughput"] * 0.9
+    assert benchmark_report["token_generation_model"]["latency_ms_p50"] < perf_thresholds["tkg_p50_latency"] * 1.1
+    assert benchmark_report["token_generation_model"]["throughput"] > perf_thresholds["tkg_throughput"] * 0.9
+
+    expected_logits = generate_expected_logits(
+        neuron_model=nrn_model,
+        input_ids=input_ids,
+        inputs_attention_mask=attention_mask,
+        generation_config=generation_config,
+        num_tokens=NUM_TOKENS_TO_CHECK,
+        additional_input_args={
+            "pixel_values": pixel_values,
+        },
+    )
+
+    additional_input_args = {
+        "pixel_values": pixel_values,
+        "vision_mask": vision_mask,
+    }
+
+    check_accuracy_logits_v2(
+        neuron_model=nrn_model,
+        expected_logits=expected_logits,
+        inputs_input_ids=input_ids,
+        inputs_attention_mask=attention_mask,
+        generation_config=generation_config,
+        num_tokens_to_check=NUM_TOKENS_TO_CHECK,
+        additional_input_args=additional_input_args,
+        divergence_difference_tol=token_divergence_atol,
+    )
+
+
+if __name__ == "__main__":
+    import tempfile
+    tmp_dir = tempfile.TemporaryDirectory()
+    tmp_dir_path = Path(tmp_dir.name)
+    torch_dtype = torch.float16
+    token_divergence_atol = 0.02
+    config_file_path = Path(__file__).resolve().parent / "config_gemma3_4layers.json"
+    perf_thresholds = {
+        "text_cte_p50_latency": 20.55,
+        "text_cte_throughput": 49807.3,
+        "tkg_p50_latency": 4.42,
+        "tkg_throughput": 226.4,
+    }
+    tp_degree = 8
+    batch_size = num_images_per_sample = 1
+    total_max_seq_len = 1024
+
+    test_original_cpu_vs_nxdi_neuron(
+        config_file_path=config_file_path,
+        tmp_path=tmp_dir_path,
+        torch_dtype=torch_dtype,
+        token_divergence_atol=token_divergence_atol,
+        perf_thresholds=perf_thresholds,
+        tp_degree=tp_degree,
+        batch_size=batch_size,
+        num_images_per_sample=num_images_per_sample,
+        total_max_seq_len=total_max_seq_len,
+    )
+
+    tmp_dir.cleanup()
diff --git a/contrib/models/gemma3-vision/test/integration/utils.py b/contrib/models/gemma3-vision/test/integration/utils.py
new file mode 100644
index 00000000..f45ff635
--- /dev/null
+++ b/contrib/models/gemma3-vision/test/integration/utils.py
@@ -0,0 +1,166 @@
+from pathlib import Path
+from typing import Optional, Tuple
+
+import torch
+from transformers import Gemma3Config, Gemma3ForConditionalGeneration, GenerationConfig
+
+from neuronx_distributed_inference.models.config import NeuronConfig, OnDeviceSamplingConfig
+from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config
+
+from gemma3_vision.modeling_gemma3 import Gemma3InferenceConfig
+
+
+def get_test_name_suffix(
+    tp_degree: int,
+    torch_dtype: torch.dtype,
+    batch_size: int,
+    num_images_per_sample: int,
+    max_seq_len: int,
+) -> str:
+    dtype_str = {
+        torch.float16: "fp16",
+        torch.bfloat16: "bf16",
+        torch.float32: "fp32",
+    }.get(torch_dtype, str(torch_dtype).split(".")[-1])
+    vision_batch_size = batch_size * num_images_per_sample
+    return f"_{tp_degree}_{dtype_str}_tbs{batch_size}_vbs{vision_batch_size}_s{max_seq_len}"
+
+
+def get_hf_config(
+    hf_model_path: Path,
+    torch_dtype: Optional[torch.dtype] = None,
+    num_hidden_layers: Optional[int] = None,
+) -> Gemma3Config:
+    hf_config = Gemma3Config.from_pretrained(hf_model_path)
+
+    if torch_dtype is not None:
+        hf_config.torch_dtype = torch_dtype
+
+    if num_hidden_layers is not None:
+        hf_config.num_hidden_layers = num_hidden_layers
+        if getattr(hf_config, "text_config", None) is not None:
+            hf_config.text_config.num_hidden_layers = num_hidden_layers
+        if getattr(hf_config, "vision_config", None) is not None:
+            hf_config.vision_config.num_hidden_layers = num_hidden_layers
+
+    return hf_config
+
+
+def save_hf_checkpoint(
+    output_dir_path: Path,
+    config_file_path: Path,
+    torch_dtype: torch.dtype,
+) -> None:
+    hf_config = Gemma3Config.from_pretrained(config_file_path, torch_dtype=torch_dtype)
+    hf_model = Gemma3ForConditionalGeneration(config=hf_config)  # random weights
+    hf_model.save_pretrained(output_dir_path)
+
+
+def create_neuron_config(
+    hf_config_path: Path,
+    text_batch_size: int = 1,
+    vision_batch_size: int = 1,
+    total_max_seq_len: int = 1024,
+    torch_dtype: torch.dtype = torch.float16,
+    lnc: int = 1,
+    tp_degree: int = 8,
+) -> Gemma3InferenceConfig:
+    text_config = NeuronConfig(
+        batch_size=text_batch_size,
+        seq_len=total_max_seq_len,
+        torch_dtype=torch_dtype,
+        rpl_reduce_dtype=torch.float32,
+        cast_type="as-declared",
+        logical_nc_config=lnc,
+        tp_degree=tp_degree,
+        world_size=tp_degree,
+        skip_sharding=False,
+        save_sharded_checkpoint=True,
+        enable_bucketing=True,
+        context_encoding_buckets=[total_max_seq_len],
+        token_generation_buckets=[total_max_seq_len],
+        on_device_sampling_config=OnDeviceSamplingConfig(
+            dynamic=False,
+            do_sample=False,
+            deterministic=True,
+            temperature=1.0,
+            top_p=1.0,
+            top_k=1,
+            global_topk=256,
+            top_k_kernel_enabled=False,
+        ),
+        output_logits=True,
+    )
+
+    vision_config = NeuronConfig(
+        batch_size=vision_batch_size,
+        seq_len=total_max_seq_len,  # Does not matter
+        torch_dtype=torch_dtype,
+        rpl_reduce_dtype=torch.float32,
+        logical_nc_config=lnc,
+        tp_degree=tp_degree,
+        world_size=tp_degree,
+        skip_sharding=False,
+        save_sharded_checkpoint=True,
+        enable_bucketing=True,
+        buckets=[vision_batch_size],
+    )
+
+    nrn_config = Gemma3InferenceConfig(
+        text_neuron_config=text_config,
+        vision_neuron_config=vision_config,
+        load_config=load_pretrained_config(hf_config_path),
+    )
+    return nrn_config
+
+
+def create_generation_config(nrn_config: Gemma3InferenceConfig) -> GenerationConfig:
+    return GenerationConfig(
+        do_sample=False,
+        pad_token_id=nrn_config.text_config.pad_token_id,
+        output_scores=True,  # Processed & warped logits
+        output_logits=False,  # Raw logits -> not needed
+        return_dict_in_generate=True,
+    )
+
+
+def prepare_inputs(
+    nrn_config: Gemma3InferenceConfig, torch_dtype: torch.dtype
+) -> Tuple[torch.Tensor, ...]:
+    batch_size = nrn_config.text_config.neuron_config.batch_size
+    text_tokens_length = 16
+    text_input_ids = (
+        torch.rand((batch_size, text_tokens_length)) * nrn_config.text_config.vocab_size
+    )
+
+    image_per_sample = (
+        nrn_config.vision_config.neuron_config.batch_size // batch_size
+    )
+    vision_tokens_length = nrn_config.mm_tokens_per_image
+    vision_input_ids = torch.full(
+        [batch_size, image_per_sample * vision_tokens_length],
+        fill_value=nrn_config.image_token_index,
+    )
+
+    input_ids = torch.cat((text_input_ids, vision_input_ids), dim=1).to(
+        dtype=torch.int32
+    )
+
+    # The attention mask must cover the full input width: text tokens plus all image tokens
+    total_length = text_tokens_length + image_per_sample * vision_tokens_length
+    attention_mask_2d = torch.ones((batch_size, total_length), dtype=torch.int32)
+
+    pixel_values = torch.rand(
+        (
+            batch_size * image_per_sample,
+            nrn_config.vision_config.num_channels,
+            nrn_config.vision_config.image_size,
+            nrn_config.vision_config.image_size,
+        ),
+        dtype=torch.float32,
+    )
+    pixel_values = (2.0 * pixel_values - 1.0).to(dtype=torch_dtype)
+
+    vision_mask = (input_ids == nrn_config.image_token_index).unsqueeze(-1)
+    vision_mask = vision_mask.to(torch.bool)
+
+    return input_ids, attention_mask_2d, pixel_values, vision_mask
diff --git a/contrib/models/gemma3-vision/test/unit/__init__.py b/contrib/models/gemma3-vision/test/unit/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/contrib/models/gemma3-vision/test/unit/gemma3/test_attention.py b/contrib/models/gemma3-vision/test/unit/gemma3/test_attention.py
new file mode 100644
index 00000000..de83d613
--- /dev/null
+++ b/contrib/models/gemma3-vision/test/unit/gemma3/test_attention.py
@@ -0,0 +1,168 @@
+
+import logging
+from typing import Dict, OrderedDict
+
+import pytest
+import torch
+import torch.nn.functional as F
+from transformers.models.gemma3.modeling_gemma3 import Gemma3Attention
+from neuronx_distributed_inference.models.config import NeuronConfig
+from neuronx_distributed_inference.utils.testing import init_cpu_env
+from neuronx_distributed.utils import cpu_mode
+
+from gemma3_vision.modeling_gemma3_text import NeuronGemma3Attention, NeuronGemma3TextModel, get_rmsnorm_cls
+from gemma3_vision.modeling_causal_lm_gemma3 import TextGemma3InferenceConfig
+from test.utils import (
+    assert_tensor_all_close,
+    create_cache_position,
+    create_hf_attention_mask_4d,
+    create_hidden_states,
+    create_position_ids,
+    create_rope,
+    FP32_TOLERANCES,
+)
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+
+
+def convert_to_hf_state_dict(state_dict: OrderedDict[str, torch.FloatTensor]) -> Dict[str, torch.FloatTensor]:
+    hf_state_dict = {}
+    for key, tensor in state_dict.items():
+        if key.startswith("qkv_proj."):
+            hf_state_dict[key.replace("qkv_proj.", "")] = tensor
+        elif key.startswith("o_proj."):
+            hf_state_dict["o_proj.weight"] = tensor
+        elif key.startswith("q_layernorm."):
+            hf_state_dict["q_norm.weight"] = tensor
+        elif key.startswith("k_layernorm."):
+            hf_state_dict["k_norm.weight"] = tensor
+        else:
+            logger.info(f"Skipping unexpected input key: {key}")
+
+    return hf_state_dict
+
+
+@pytest.mark.parametrize("layer_idx", [
+    0,  # sliding
+    1,  # non-sliding
+    ])
+def test_nxdi_attn_layer_vs_transformers_implementation_prefill(random_seed, monkeypatch, hf_config, layer_idx) -> None:
+    # TODO: Move to a fixture
+    monkeypatch.setenv("NXD_CPU_MODE", "1")
+    init_cpu_env()
+    assert cpu_mode() is True
+    padding_side = "left"  # HuggingFace reference only supports left padding
+    bucket_size, sliding_window_size, sliding_window_pattern = 8, 4, 2
+
+    is_swa_layer = (layer_idx + 1) % sliding_window_pattern != 0
+
+    hf_text_config = hf_config.text_config
+    hf_text_config.sliding_window = sliding_window_size
+    hf_text_config.sliding_window_pattern = sliding_window_pattern
+    # Make test faster on CPU
+    head_dim = 2
+    hf_text_config.num_attention_heads = 2
+    hf_text_config.num_key_value_heads = 1
+    hf_text_config.head_dim = head_dim
+    hf_text_config.hidden_size = 4
+    hf_text_config._attn_implementation = "eager"
+    hf_text_config.query_pre_attn_scalar = head_dim
+
+    attention_mask_2d = torch.tensor([[0, 0, 0, 1, 1],
+                                      [0, 0, 1, 1, 1],
+                                      [0, 1, 1, 1, 1],
+                                      [1, 1, 1, 1, 1]], dtype=torch.int32)
+
+    batch_size, max_input_seq_len = attention_mask_2d.shape
+    inputs_dtype = model_dtype = torch.float32
+
+    attention_mask_2d = F.pad(attention_mask_2d, (0, bucket_size - max_input_seq_len), "constant", 0)
+
+    position_ids = create_position_ids(attention_mask_2d=attention_mask_2d, is_for_context_encoding=True)
+    cache_position = create_cache_position(attention_mask_2d=attention_mask_2d, is_for_context_encoding=True)
+
+    cos, sin = create_rope(position_ids=position_ids, hf_config=hf_text_config)
+    hidden_states = create_hidden_states(attention_mask_2d=attention_mask_2d, hf_config=hf_text_config, is_for_context_encoding=True)
+
+    neuron_config = NeuronConfig(
+        tp_degree=1,
+        batch_size=batch_size,
+        max_context_length=bucket_size,
+        seq_len=bucket_size,
+        torch_dtype=model_dtype,
+        fused_qkv=False,
+        attn_kernel_enabled=False,
+        qkv_kernel_enabled=False,
+        padding_side=padding_side,
+    )
+
+    config = TextGemma3InferenceConfig(
+        neuron_config=neuron_config,
+        **hf_text_config.to_dict()
+    )
+
+    nrn_model = NeuronGemma3TextModel(config=config)
+
+    sliding_window = sliding_window_size if is_swa_layer else None
+    rms_norm_cls = get_rmsnorm_cls()
+    rms_norm_eps = getattr(config, "rms_norm_eps", None)
+    q_norm = rms_norm_cls(config.head_dim, rms_norm_eps) if rms_norm_eps else rms_norm_cls(config.head_dim)
+    k_norm = rms_norm_cls(config.head_dim, rms_norm_eps) if rms_norm_eps else rms_norm_cls(config.head_dim)
+
+    nrn_attn_layer = NeuronGemma3Attention(
+        config=config,
+        hidden_size=config.hidden_size,
+        num_attention_heads=config.num_attention_heads,
+        num_key_value_heads=config.num_key_value_heads,
+        sliding_window=sliding_window,
+        use_qk_norm=False,
+        q_layernorm=q_norm,
+        k_layernorm=k_norm,
+        rotary_emb=NeuronGemma3Attention.get_rope(config=config, is_swa_layer=is_swa_layer),
+    )
+    nrn_attn_layer.eval()
+
+    hf_attn_layer = Gemma3Attention(config=hf_text_config, layer_idx=layer_idx).to(dtype=model_dtype)
+    hf_attn_layer.load_state_dict(convert_to_hf_state_dict(nrn_attn_layer.state_dict()), strict=True)
+    hf_attn_layer.eval()
+
+    # Attention mask creation
+    attention_mask_4d_hf = create_hf_attention_mask_4d(
+        attention_mask_2d=attention_mask_2d,
+        cache_position=cache_position,
+        is_for_context_encoding=True,
+        dtype=inputs_dtype,
+        is_swa_layer=is_swa_layer,
+        sliding_window_size=sliding_window_size,
+    )
+
+    if not is_swa_layer:
+        # Global attention mask
+        attention_mask_4d = nrn_model._create_context_attn_mask(
+            attention_mask=attention_mask_2d,
+        )
+    else:
+        # Sliding window attention (SWA) mask
+        # Note: As of Neuron 2.26, NeuronBaseModel._create_windowed_attn_mask_cte does not support
+        # left padding; we therefore use the HF left-padded mask to create the Neuron attention mask
+        attention_mask_4d = (attention_mask_4d_hf == 0)
+
+    with torch.no_grad():
+        ref_output, *_ = hf_attn_layer(
+            hidden_states=hidden_states,
+            position_embeddings=(cos, sin),
+            attention_mask=attention_mask_4d_hf,
+        )
+
+        output = nrn_attn_layer(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask_4d,
+            cos_cache=cos,
+            sin_cache=sin,
+            position_ids=position_ids,
+        )
+        output = output.hidden_states
+
+    rtol, atol = FP32_TOLERANCES.rtol, FP32_TOLERANCES.atol
+    assert_tensor_all_close(test_objective="Attention outputs", computed_value=output, reference_value=ref_output, rtol=rtol, atol=atol, equal_nan=True)
diff --git a/contrib/models/gemma3-vision/test/unit/gemma3/test_decoder.py b/contrib/models/gemma3-vision/test/unit/gemma3/test_decoder.py
new file mode 100644
index 00000000..e2c42208
--- /dev/null
+++ b/contrib/models/gemma3-vision/test/unit/gemma3/test_decoder.py
@@ -0,0 +1,105 @@
+import copy
+import logging
+from typing import Dict, OrderedDict
+
+import pytest
+import torch
+from transformers.models.gemma3.modeling_gemma3 import Gemma3DecoderLayer, Gemma3RotaryEmbedding
+
+from gemma3_vision.modeling_gemma3_text import NeuronGemma3DecoderLayer
+from test.utils import assert_tensor_all_close, causal_mask, window_mask, mark_step, cpu_setup, create_neuron_config, FP32_TOLERANCES, FP16_TOLERANCES, BF16_TOLERANCES
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+
+
+def convert_to_hf_state_dict(state_dict: OrderedDict[str, torch.FloatTensor]) -> Dict[str, torch.FloatTensor]:
+    hf_state_dict = {}
+    for key, tensor in state_dict.items():
+        if key.startswith("self_attn"):
+            splits = key.split(".")
+            if len(splits) == 4:
+                # q/k/v/o projection
+                hf_state_dict[f"self_attn.{splits[-2]}.{splits[-1]}"] = tensor
+            else:
+                # norm weights:
+                # in Gemma3RMSNorm, weights are initialized with torch.zeros,
+                # while Neuron's CustomRMSNorm initializes with torch.ones
+                hf_state_dict["self_attn.q_norm.weight"] = torch.zeros_like(tensor)
+                hf_state_dict["self_attn.k_norm.weight"] = torch.zeros_like(tensor)
+        elif key.find("_layernorm.") != -1:
+            hf_state_dict[key] = torch.zeros_like(tensor)
+        else:
+            hf_state_dict[key] = tensor
+    return hf_state_dict
+
+
+@pytest.mark.parametrize("layer_idx", [0, 5])
+def test_nxdi_decoder_layer_cpu_vs_transformers_implementation(random_seed, layer_idx, hf_config) -> None:
+    inputs_dtype = model_dtype = torch.float32
+    batch_size, max_seq_len = 2, 64
+    hf_config.text_config.sliding_window = 10
+    hf_config.text_config._attn_implementation = "eager"
+    hf_config.text_config.query_pre_attn_scalar = hf_config.text_config.head_dim
+
+    # --- Set NxDI Model ---
+    nrn_config = create_neuron_config(
+        batch_size=batch_size,
+        max_seq_len=max_seq_len,
+        torch_dtype=model_dtype,
+        tp_degree=1,
+        hf_config=hf_config
+    )
+
+    cpu_setup(model_dtype)
+    decoder_layer = NeuronGemma3DecoderLayer(config=nrn_config.text_config, layer_idx=layer_idx).to(dtype=model_dtype)
+    decoder_layer.eval()
+
+    # --- Set Transformers Model ---
+    hf_text_config = hf_config.text_config
+
+    reference_model = Gemma3DecoderLayer(hf_text_config, layer_idx=layer_idx)
+    reference_model.load_state_dict(convert_to_hf_state_dict(decoder_layer.state_dict()), strict=True)
+    reference_model.eval()
+
+    # --- Set Inputs ---
+    batch_size, seq_len, hidden_size = 2, 15, hf_text_config.hidden_size
+    hidden_states = torch.randn(batch_size, seq_len, hidden_size).to(dtype=inputs_dtype)
+    position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1).to(dtype=inputs_dtype)
+
+    attention_mask = causal_mask(batch_size, seq_len).to(dtype=inputs_dtype)
+    local_mask = None
+    if decoder_layer.is_swa_layer:
+        local_mask = window_mask(batch_size, seq_len, decoder_layer.sliding_window)
+
+    attention_mask_nrn = local_mask if local_mask is not None else attention_mask
+    attention_mask_hf = torch.where(attention_mask_nrn.to(bool), 0.0, torch.finfo(inputs_dtype).min).to(inputs_dtype)
+
+    ## Required only for the reference model
+    rotary_emb = Gemma3RotaryEmbedding(config=hf_text_config)
+    position_embeddings_global = rotary_emb(hidden_states, position_ids)
+
+    hf_text_config_copy = copy.deepcopy(hf_text_config)
+    hf_text_config_copy.rope_theta = hf_text_config_copy.rope_local_base_freq
+    hf_text_config_copy.rope_scaling = {"rope_type": "default"}
+    rotary_emb_local = Gemma3RotaryEmbedding(config=hf_text_config_copy)
+    position_embeddings_local = rotary_emb_local(hidden_states, position_ids)
+
+    with torch.no_grad():
+        device = torch.device("cpu")
+        ref_output, *_ = reference_model(
+            hidden_states=hidden_states,
+            position_embeddings_global=position_embeddings_global,
+            position_embeddings_local=position_embeddings_local,
+            attention_mask=attention_mask_hf,
+            cache_position=torch.arange(0, seq_len)  # required for sliding-window layers
+        )
+        output, *_ = decoder_layer(
+            hidden_states=hidden_states.to(device=device),
+            attention_mask=attention_mask.to(device=device),
+            local_mask=local_mask.to(device=device) if local_mask is not None else None,
+            position_ids=position_ids.to(device=device)
+        )
+
+    rtol, atol = FP32_TOLERANCES.rtol, FP32_TOLERANCES.atol
+    assert_tensor_all_close(test_objective="Gemma3 decoder - nxdi (cpu) vs huggingface", computed_value=output, reference_value=ref_output, rtol=rtol, atol=atol, equal_nan=True)
diff --git a/contrib/models/gemma3-vision/test/unit/gemma3/test_multimodal_projector.py b/contrib/models/gemma3-vision/test/unit/gemma3/test_multimodal_projector.py
new file mode 100644
index 00000000..0cf3b47e
--- /dev/null
+++ b/contrib/models/gemma3-vision/test/unit/gemma3/test_multimodal_projector.py
@@ -0,0 +1,114 @@
+import os
+import pytest
+import torch
+import torch_xla.core.xla_model as xm
+from transformers.models.gemma3.modeling_gemma3 import Gemma3MultiModalProjector
+from neuronx_distributed_inference.utils.random import set_random_seed
+from neuronx_distributed_inference.utils.testing import destroy_mp, init_cpu_env
+
+from gemma3_vision.modeling_gemma3_vision import NeuronGemma3MultiModalProjector
+from test.utils import assert_tensor_all_close, mark_step, create_neuron_config, FP32_TOLERANCES, FP16_TOLERANCES, BF16_TOLERANCES
+
+
+def _cpu_setup(dtype):
+    set_random_seed(0)
+    os.environ.setdefault("NXD_CPU_MODE", "1")
+    init_cpu_env()
+    torch.set_default_dtype(dtype)
+    torch.set_default_device("cpu")
+
+
+@pytest.mark.parametrize("tolerances, compiler_flags", [
+    (FP32_TOLERANCES, ["--model-type=transformer", "--auto-cast=none"]),
+    (FP16_TOLERANCES, ["--model-type=transformer", "--auto-cast=matmult", "--enable-mixed-precision-accumulation", "--auto-cast-type=fp16"]),
+    (BF16_TOLERANCES, ["--model-type=transformer", "--auto-cast=matmult", "--enable-mixed-precision-accumulation", "--auto-cast-type=bf16"]),
+    ])
+def test_multimodal_projector(monkeypatch, base_compiler_flags, tolerances, compiler_flags, hf_config) -> None:
+    monkeypatch.setenv("NEURON_CC_FLAGS", " ".join(base_compiler_flags + compiler_flags))
+
+    image_size, patch_size = 448, 28
+    num_patches = int((image_size/patch_size)**2)
+    batch_size, max_seq_len, hidden_size = 2, 64, hf_config.vision_config.hidden_size
+    inputs_dtype = model_dtype = torch.float32
+
+    hf_config.vision_config.image_size = image_size
+    hf_config.vision_config.patch_size = patch_size
+
+    vision_outputs = torch.randn(batch_size, num_patches, hidden_size).to(dtype=inputs_dtype)
+
+    nrn_config = create_neuron_config(
+        batch_size=batch_size,
+        max_seq_len=max_seq_len,
+        torch_dtype=model_dtype,
+        tp_degree=2,
+        hf_config=hf_config
+    )
+
+    # --- CPU Reference Execution ---
+    # Note: We explicitly set 'NXD_CPU_MODE' to force a CPU-only environment.
+    # This is critical because the module's initialization logic (in
+    # get_rmsnorm_cls) checks this variable to choose between the
+    # CPU and Neuron-specific RMSNorm implementations.
+    _cpu_setup(model_dtype)
+    mm_projector = NeuronGemma3MultiModalProjector(config=nrn_config).to(dtype=model_dtype)
+    mm_projector.eval()
+
+    with torch.no_grad():
+        cpu_output = mm_projector(vision_outputs)
+
+    # --- Neuron Device Execution ---
+    # Note: Tear down CPU environment and switch to NeuronCore mode.
+    # Direct assignment is required: setdefault would not override the value already set above.
+    destroy_mp()
+    os.environ["NXD_CPU_MODE"] = "0"
+    set_random_seed(0)
+
+    with torch.no_grad():
+        mm_projector_nrn = mm_projector.to(device=xm.xla_device())
+        mark_step()
+        nrn_output = mm_projector_nrn(vision_outputs.to(device=xm.xla_device()))
+        mark_step()
+        nrn_output = nrn_output.cpu()
+
+    rtol, atol = tolerances.rtol, tolerances.atol
+    assert_tensor_all_close(test_objective="Multi modal projector outputs", computed_value=nrn_output, reference_value=cpu_output, rtol=rtol, atol=atol, equal_nan=True)
+
+
+def test_nxdi_mm_projector_vs_transformers_implementation(random_seed, hf_config) -> None:
+    image_size, patch_size = 448, 28
+    num_patches = int((image_size/patch_size)**2)
+    batch_size, max_seq_len, hidden_size = 2, 64, hf_config.vision_config.hidden_size
+    inputs_dtype = model_dtype = torch.float32
+
+    hf_config.vision_config.image_size = image_size
+    hf_config.vision_config.patch_size = patch_size
+
+    vision_outputs = torch.randn(batch_size, num_patches, hidden_size).to(dtype=inputs_dtype)
+
+    # --- Set NxDI Model ---
+    nrn_config = create_neuron_config(
+        batch_size=batch_size,
+        max_seq_len=max_seq_len,
+        torch_dtype=model_dtype,
+        tp_degree=2,
+        hf_config=hf_config
+    )
+
+    mm_projector = NeuronGemma3MultiModalProjector(config=nrn_config).to(dtype=model_dtype)
+    mm_projector.eval()
+    mm_projector.to(device=xm.xla_device())
+
+    # --- Set Transformers Model ---
+    hf_config.vision_config.image_size = image_size
+    hf_config.vision_config.patch_size = patch_size
+
+    reference_model = Gemma3MultiModalProjector(config=hf_config).to(dtype=model_dtype)
+    reference_model.load_state_dict(mm_projector.state_dict(), strict=True)
+    reference_model.eval()
+
+    with torch.no_grad():
+        ref_output = reference_model(vision_outputs=vision_outputs)
+        output = mm_projector(vision_outputs=vision_outputs.to(device=xm.xla_device()))
+        output = output.cpu()
+
+    rtol, atol = FP32_TOLERANCES.rtol, FP32_TOLERANCES.atol
+    assert_tensor_all_close(test_objective="Multi modal projector outputs", computed_value=output, reference_value=ref_output, rtol=rtol, atol=atol, equal_nan=True)
diff --git a/contrib/models/gemma3-vision/test/unit/gemma3/test_rms.py b/contrib/models/gemma3-vision/test/unit/gemma3/test_rms.py
new file mode 100644
index 00000000..17cad2dd
--- /dev/null
+++ b/contrib/models/gemma3-vision/test/unit/gemma3/test_rms.py
@@ -0,0 +1,39 @@
+import pytest
+import torch
+import torch_xla
+
+from gemma3_vision.modeling_gemma3_text import NeuronGemma3RMSNorm, Gemma3RMSNorm
+from test.utils import assert_tensor_all_close, mark_step, BF16_TOLERANCES
+
+
+@pytest.mark.parametrize("inputs_dtype, tolerances", [
+    (torch.bfloat16, BF16_TOLERANCES),
+    ])
+def test_custom_vs_hf_rms_norm_implementation(random_seed, inputs_dtype, tolerances, hf_config) -> None:
+    device = torch_xla.device()
+    batch_size, sequence_length = 2, 16
+    hidden_size, eps = hf_config.text_config.hidden_size, hf_config.text_config.rms_norm_eps
+
+    x = torch.rand((batch_size, sequence_length, hidden_size), dtype=inputs_dtype)
+    nrn_norm = NeuronGemma3RMSNorm(hidden_size=hidden_size, eps=eps)
+    nrn_norm.eval()
+    ref_norm = Gemma3RMSNorm(dim=hidden_size, eps=eps)
+    ref_norm.load_state_dict(nrn_norm.state_dict(), strict=True)
+    ref_norm.eval()
+
+    x = x.to(device=device)
+    ref_norm = ref_norm.to(device=device)
+    nrn_norm = nrn_norm.to(device=device)
+
+    with torch.no_grad():
+        mark_step()
+        ref_output = ref_norm(x)
+        mark_step()
+        nrn_output = nrn_norm(x)
+        mark_step()
+
+    ref_output = ref_output.cpu()
+    nrn_output = nrn_output.cpu()
+
+    rtol, atol = tolerances.rtol, tolerances.atol
+    assert_tensor_all_close(test_objective="RMS Norm", computed_value=nrn_output, reference_value=ref_output, rtol=rtol, atol=atol, equal_nan=True)
diff --git a/contrib/models/gemma3-vision/test/unit/gemma3/test_rope.py b/contrib/models/gemma3-vision/test/unit/gemma3/test_rope.py
new file mode 100644
index 00000000..775f66b7
--- /dev/null
+++ b/contrib/models/gemma3-vision/test/unit/gemma3/test_rope.py
@@ -0,0 +1,99 @@
+import pytest
+import torch
+from transformers.models.gemma3.modeling_gemma3 import Gemma3RotaryEmbedding
+
+from gemma3_vision.modeling_gemma3_text import NeuronGemma3RotaryEmbedding
+from test.utils import assert_tensor_all_close, mark_step, cpu_setup, create_neuron_config, FP32_TOLERANCES, FP16_TOLERANCES, BF16_TOLERANCES
+
+
+@pytest.mark.parametrize("inputs_dtype, tolerances", [
+    (torch.float32, FP32_TOLERANCES),
+    (torch.bfloat16, BF16_TOLERANCES),
+    ])
+@pytest.mark.parametrize("position", [128, 1024, 2048, 4096, 6144, 8192])
+def test_rope_global_vs_transformers_implementation(inputs_dtype, tolerances, position, hf_config) -> None:
+    # --- Set NxDI Model ---
+    batch_size, max_seq_len = 2, 64
+    nrn_config = create_neuron_config(
+        batch_size=batch_size,
+        max_seq_len=max_seq_len,
+        torch_dtype=inputs_dtype,
+        tp_degree=1,
+        hf_config=hf_config
+    )
+
+    partial_rotary_factor = getattr(nrn_config.text_config, "partial_rotary_factor", 1.0)
+    dim = int(nrn_config.text_config.head_dim * partial_rotary_factor)
+    max_position_embeddings = nrn_config.text_config.max_position_embeddings
+
+    nrn_rope = NeuronGemma3RotaryEmbedding(
+        dim=dim,
+        max_position_embeddings=max_position_embeddings,
+        base=nrn_config.text_config.rope_theta,
+        scaling_type=nrn_config.text_config.rope_scaling["rope_type"],
+        scaling_factor=nrn_config.text_config.rope_scaling["factor"],
+    )
+
+    # --- Set Transformers Model ---
+    hf_text_config = hf_config.text_config
+    reference_rope = Gemma3RotaryEmbedding(config=hf_text_config)
+
+    # --- Inputs ---
+    batch_size, sequence_length, num_heads, head_dim = 2, 1, 1, 128
+    x = torch.randn(batch_size, num_heads, sequence_length, head_dim).to(dtype=inputs_dtype)
+    position_ids = torch.full((batch_size, sequence_length), position, dtype=torch.int32)
+
+    # --- Run Rope ---
+    ref_cos, ref_sin = reference_rope(x, position_ids)
+    cos, sin = nrn_rope(x, position_ids)
+
+    rtol, atol = tolerances.rtol, tolerances.atol
+    assert_tensor_all_close(test_objective="cos", computed_value=cos, reference_value=ref_cos, rtol=rtol, atol=atol, equal_nan=True)
+    assert_tensor_all_close(test_objective="sin", computed_value=sin, reference_value=ref_sin, rtol=rtol, atol=atol, equal_nan=True)
+
+
+@pytest.mark.parametrize("inputs_dtype, tolerances", [
+    (torch.float32, FP32_TOLERANCES),
+    (torch.bfloat16, BF16_TOLERANCES),
+    ])
+@pytest.mark.parametrize("position", [128, 1024, 2048, 4096, 6144, 8192])
+def test_rope_local_vs_transformers_implementation(inputs_dtype, tolerances, position, hf_config) -> None:
+    # --- Set NxDI Model ---
+    batch_size, max_seq_len = 2, 64
+    nrn_config = create_neuron_config(
+        batch_size=batch_size,
+        max_seq_len=max_seq_len,
+        torch_dtype=inputs_dtype,
+        tp_degree=1,
+        hf_config=hf_config
+    )
+
+    partial_rotary_factor = getattr(nrn_config.text_config, "partial_rotary_factor", 1.0)
+    dim = int(nrn_config.text_config.head_dim * partial_rotary_factor)
+    max_position_embeddings = nrn_config.text_config.max_position_embeddings
+
+    nrn_rope = NeuronGemma3RotaryEmbedding(
+        dim=dim,
+        max_position_embeddings=max_position_embeddings,
+        base=nrn_config.text_config.rope_local_base_freq,
+    )
+
+    # --- Set Transformers Model ---
+    hf_text_config = hf_config.text_config
+    hf_text_config.rope_theta = hf_text_config.rope_local_base_freq
+    hf_text_config.rope_scaling = {"rope_type": "default"}
+
+    reference_rope = Gemma3RotaryEmbedding(config=hf_text_config)
+
+    # --- Inputs ---
+    batch_size, sequence_length, num_heads, head_dim = 2, 1, 1, 128
+    x = torch.randn(batch_size, num_heads, sequence_length, head_dim).to(dtype=inputs_dtype)
+    position_ids = torch.full((batch_size, sequence_length), position, dtype=torch.int32)
+
+    # --- Run Rope ---
+    ref_cos, ref_sin = reference_rope(x, position_ids)
+    cos, sin = nrn_rope(x, position_ids)
+
+    rtol, atol = tolerances.rtol, tolerances.atol
+    assert_tensor_all_close(test_objective="cos", computed_value=cos, reference_value=ref_cos, rtol=rtol, atol=atol, equal_nan=True)
+    assert_tensor_all_close(test_objective="sin", computed_value=sin, reference_value=ref_sin, rtol=rtol, atol=atol, equal_nan=True)
diff --git a/contrib/models/gemma3-vision/test/unit/gemma3/test_text_model.py b/contrib/models/gemma3-vision/test/unit/gemma3/test_text_model.py
new file mode 100644
index 00000000..03330253
--- /dev/null
+++ b/contrib/models/gemma3-vision/test/unit/gemma3/test_text_model.py
@@ -0,0 +1,99 @@
+import logging
+from typing import Dict, OrderedDict
+
+import torch
+from transformers.models.gemma3.modeling_gemma3 import Gemma3TextModel
+
+from gemma3_vision.modeling_gemma3_text import NeuronGemma3TextModel
+from test.utils import (
+    assert_tensor_all_close, mark_step, cpu_setup, create_neuron_config, causal_mask, window_mask,
+    FP32_TOLERANCES, FP16_TOLERANCES, BF16_TOLERANCES,
+    MockKVCacheManager
+)
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG) + + +def convert_to_hf_state_dict(state_dict: OrderedDict[str, torch.FloatTensor]) -> Dict[str, torch.FloatTensor]: + hf_state_dict = {} + for key, tensor in state_dict.items(): + if key.find('self_attn.') != -1: + if key.find("qk_norm.") != -1: + # in Gemma3RMSNorm, weights are initialized with torch.zeros + # while Neuron's CustomRMSNorms initializes with torch.ones + hf_state_dict[key.replace('qk_norm.', 'q_norm.')] = torch.zeros_like(tensor) + hf_state_dict[key.replace('qk_norm.', 'k_norm.')] = torch.zeros_like(tensor) + else: + # q/k/v/o projection weight + parts = key.split('.') + del parts[-3] + key = '.'.join(parts) + hf_state_dict[key] = tensor + elif key.find("_layernorm.") != -1 or key == "norm.weight": + hf_state_dict[key] = torch.zeros_like(tensor) + else: + hf_state_dict[key] = tensor + return hf_state_dict + + +def test_nxdi_text_model_cpu_vs_transformers_implementation(random_seed, hf_config) -> None: + model_dtype = torch.float32 + batch_size, seq_len = 2, 32 + hf_config.text_config.sliding_window = 10 + hf_config.text_config.query_pre_attn_scalar = hf_config.text_config.head_dim + hf_config.text_config.num_hidden_layers = 1 # smaller network for quick testing + + # --- Set NxDI Model --- + + nrn_config = create_neuron_config( + batch_size=batch_size, + max_seq_len=seq_len, + torch_dtype=model_dtype, + tp_degree=1, + hf_config=hf_config + ) + + cpu_setup(model_dtype) + text_model = NeuronGemma3TextModel(config=nrn_config.text_config, optimize_inference=False).to(dtype=model_dtype) + text_model.kv_mgr = MockKVCacheManager(config=nrn_config.text_config, num_kv_head=nrn_config.text_config.num_key_value_heads) + text_model.eval() + + # --- Set Transformers Model --- + reference_model = Gemma3TextModel(hf_config.text_config) + reference_model.load_state_dict(convert_to_hf_state_dict(text_model.state_dict()), strict=False) + reference_model.eval() + + # --- Set Inputs --- + input_ids = torch.randint(0, 
hf_config.text_config.vocab_size, (batch_size, seq_len)).to(dtype=torch.long) + position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1).to(dtype=torch.long) + seq_ids = torch.arange(batch_size).to(dtype=torch.long) + attention_mask = causal_mask(batch_size, seq_len).to(dtype=torch.long) + attention_mask_hf = torch.ones((batch_size, seq_len)).to(dtype=torch.bool) + + with torch.no_grad(): + device = torch.device("cpu") + ref_last_hidden_state = reference_model( + input_ids=input_ids, + attention_mask=attention_mask_hf, + position_ids=position_ids, + use_cache=None + ).last_hidden_state + + # pass through lm_head manually as logit calculation happens at a higher model class (Gemma3ForCausalLM) in HF + lm_head = torch.nn.Linear(hf_config.text_config.hidden_size, hf_config.text_config.vocab_size, bias=False) + lm_head.load_state_dict({"weight": text_model.state_dict()["lm_head.weight"]}, strict=True) + ref_output = lm_head(ref_last_hidden_state[:, -1:, :]) + + output, *_ = text_model( + input_ids=input_ids.to(device=device), + attention_mask=attention_mask.to(device=device), + position_ids=position_ids.to(device=device), + seq_ids=seq_ids.to(device=device), + sampling_params=None, + kv_cache=None + ) # first item is logits when on_device_sampling is off + + rtol, atol = FP32_TOLERANCES.rtol, FP32_TOLERANCES.atol + print((ref_output - output).abs().max()) + assert_tensor_all_close(test_objective="Gemma3 text model - nxdi (cpu) vs huggingface", computed_value=output, reference_value=ref_output, rtol=rtol, atol=atol, equal_nan=True) diff --git a/contrib/models/gemma3-vision/test/unit/gemma3/test_vision_model.py b/contrib/models/gemma3-vision/test/unit/gemma3/test_vision_model.py new file mode 100644 index 00000000..644db6a6 --- /dev/null +++ b/contrib/models/gemma3-vision/test/unit/gemma3/test_vision_model.py @@ -0,0 +1,104 @@ +import os + +import pytest +import torch +import torch_xla.core.xla_model as xm +from transformers.models.gemma3.modeling_gemma3 
import Gemma3ForConditionalGeneration +from neuronx_distributed_inference.utils.random import set_random_seed +from neuronx_distributed_inference.utils.testing import destroy_mp + +from gemma3_vision.modeling_gemma3_vision import NeuronGemma3VisionModel +from test.utils import assert_tensor_all_close, mark_step, cpu_setup, create_neuron_config, FP32_TOLERANCES, FP16_TOLERANCES, BF16_TOLERANCES + + +@pytest.mark.parametrize("tolerances, compiler_flags", [ + (FP32_TOLERANCES, ["--model-type=transformer", "--auto-cast=none"]), + (FP16_TOLERANCES, ["--model-type=transformer", "--auto-cast=matmult", "--enable-mixed-precision-accumulation", "--auto-cast-type=fp16"]), + (BF16_TOLERANCES, ["--model-type=transformer", "--auto-cast=matmult", "--enable-mixed-precision-accumulation", "--auto-cast-type=bf16"]), + ]) +def test_vision_model(monkeypatch, base_compiler_flags, tolerances, compiler_flags, hf_config) -> None: + monkeypatch.setenv("NEURON_CC_FLAGS", " ".join(base_compiler_flags + compiler_flags)) + + batch_size, seq_len = 2, 64 + num_channels, image_size = hf_config.vision_config.num_channels, hf_config.vision_config.image_size + inputs_dtype = model_dtype = torch.float32 + hf_config.vision_config.num_hidden_layers = 5 # test with smaller network + + pixel_values = torch.randn(batch_size, num_channels, image_size, image_size).to(dtype=inputs_dtype) + + nrn_config = create_neuron_config( + batch_size=batch_size, + max_seq_len=seq_len, + torch_dtype=model_dtype, + tp_degree=1, + hf_config=hf_config + ) + + # --- CPU Reference Execution --- + # Note: We explicitly set 'NXD_CPU_MODE' to force a CPU-only environment. + # This is critical because the module's initialization logic (in + # get_rmsnorm_cls) checks this variable to choose between the + # CPU and Neuron-specific RMSNorm implementations. 
+ cpu_setup(model_dtype) + cpu_vision_model = NeuronGemma3VisionModel(config=nrn_config).to(dtype=model_dtype) + cpu_vision_model.eval() + + with torch.no_grad(): + cpu_output = cpu_vision_model(pixel_values) + + # --- Neuron Device Execution --- + # Note: Tear down the CPU environment and switch to NeuronCore mode. + # Use direct assignment here: setdefault() would be a no-op because + # cpu_setup() has already set 'NXD_CPU_MODE', and the model would + # silently stay in CPU mode. + destroy_mp() + os.environ["NXD_CPU_MODE"] = "0" + set_random_seed(0) + + nrn_vision_model = NeuronGemma3VisionModel(config=nrn_config).to(dtype=model_dtype) + nrn_vision_model.eval() + + with torch.no_grad(): + nrn_vision_model = nrn_vision_model.to(device=xm.xla_device()) + mark_step() + nrn_output = nrn_vision_model(pixel_values.to(device=xm.xla_device())) + mark_step() + nrn_output = nrn_output.cpu() + + rtol, atol = tolerances.rtol, tolerances.atol + assert_tensor_all_close(test_objective="Gemma3 vision model outputs", computed_value=nrn_output, reference_value=cpu_output, rtol=rtol, atol=atol, equal_nan=True) + + +def test_nxdi_vision_model_vs_transformers_implementation(random_seed, hf_config) -> None: + batch_size, seq_len = 2, 64 + num_channels, image_size = hf_config.vision_config.num_channels, hf_config.vision_config.image_size + inputs_dtype = model_dtype = torch.float32 + hf_config.vision_config.num_hidden_layers = 5 # test with smaller network + + pixel_values = torch.randn(batch_size, num_channels, image_size, image_size).to(dtype=inputs_dtype) + + # --- Set NxDI Model --- + + nrn_config = create_neuron_config( + batch_size=batch_size, + max_seq_len=seq_len, + torch_dtype=model_dtype, + tp_degree=1, + hf_config=hf_config + ) + + vision_model = NeuronGemma3VisionModel(config=nrn_config).to(dtype=model_dtype) + vision_model.eval() + vision_model.to(device=xm.xla_device()) + + # --- Set Transformers Model --- + reference_model = Gemma3ForConditionalGeneration(config=hf_config).to(dtype=model_dtype) + reference_model.load_state_dict(vision_model.state_dict(), strict=False) + reference_model.eval() + + with torch.no_grad(): + # reference
model Gemma3ForConditionalGeneration includes a language model (LM) + # use get_image_features() to pass the input pixel through vision_tower and multi_modal_projector only (exclude LM) + ref_output = reference_model.get_image_features(pixel_values) + output = vision_model(pixel_values.to(device=xm.xla_device())) + output = output.cpu() + + rtol, atol = FP32_TOLERANCES.rtol, FP32_TOLERANCES.atol + assert_tensor_all_close(test_objective="Gemma3 vision model outputs", computed_value=output, reference_value=ref_output, rtol=rtol, atol=atol, equal_nan=True) diff --git a/contrib/models/gemma3-vision/test/unit/siglip/test_encoder.py b/contrib/models/gemma3-vision/test/unit/siglip/test_encoder.py new file mode 100644 index 00000000..4c24895f --- /dev/null +++ b/contrib/models/gemma3-vision/test/unit/siglip/test_encoder.py @@ -0,0 +1,131 @@ +import pytest +import torch +import torch_xla.core.xla_model as xm +from transformers.models.siglip.modeling_siglip import SiglipEncoder + +from gemma3_vision.siglip.modeling_siglip import NeuronSiglipConfig, SiglipInferenceConfig, NeuronSiglipEncoder +from test.utils import assert_tensor_all_close, mark_step, FP32_TOLERANCES + + +def convert_neuron_siglip_encoder_state_dict_to_hf(neuron_state_dict: dict) -> dict: + """ + Convert Neuron SigLIP encoder state dict to HuggingFace format. 
+ + Neuron model has: + - layers.X.self_attn.qkv_proj.{q,k,v}_proj.{weight,bias} + - layers.X.self_attn.o_proj.o_proj.{weight,bias} + - layers.X.self_attn.rank_util.rank (not needed in HF) + + HuggingFace model expects: + - layers.X.self_attn.{q,k,v}_proj.{weight,bias} + - layers.X.self_attn.out_proj.{weight,bias} + """ + hf_state_dict = {} + + for key, value in neuron_state_dict.items(): + # Skip rank_util parameters (not needed in HF) + if "rank_util" in key: + continue + + # Convert qkv_proj paths + if "qkv_proj.q_proj" in key: + new_key = key.replace("qkv_proj.q_proj", "q_proj") + hf_state_dict[new_key] = value + elif "qkv_proj.k_proj" in key: + new_key = key.replace("qkv_proj.k_proj", "k_proj") + hf_state_dict[new_key] = value + elif "qkv_proj.v_proj" in key: + new_key = key.replace("qkv_proj.v_proj", "v_proj") + hf_state_dict[new_key] = value + # Convert o_proj path + elif "o_proj.o_proj" in key: + new_key = key.replace("o_proj.o_proj", "out_proj") + hf_state_dict[new_key] = value + else: + # Keep other parameters as-is + hf_state_dict[key] = value + + return hf_state_dict + + +@pytest.mark.parametrize("tolerances, compiler_flags", [ + (FP32_TOLERANCES, ["--model-type=transformer", "--auto-cast=none"]), + ]) +def test_encoder(monkeypatch, base_compiler_flags, tolerances, compiler_flags, hf_config) -> None: + monkeypatch.setenv("NEURON_CC_FLAGS", " ".join(base_compiler_flags + compiler_flags)) + + batch_size, seq_len, hidden_size = 2, 32, hf_config.vision_config.hidden_size + inputs_dtype = model_dtype = torch.float32 + device = xm.xla_device() + + inputs_embeds = torch.randn(batch_size, seq_len, hidden_size).to(dtype=inputs_dtype) + attention_mask = torch.ones(batch_size, 1, seq_len, seq_len).to(dtype=inputs_dtype) + + neuron_config = NeuronSiglipConfig( + tp_degree=1, + batch_size=batch_size, + max_context_length=seq_len, + torch_dtype=model_dtype, + ) + + config = SiglipInferenceConfig(neuron_config=neuron_config, **hf_config.vision_config.to_dict()) + + 
encoder = NeuronSiglipEncoder(config=config) + encoder.eval() + + with torch.no_grad(): + output_cpu = encoder( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + ).last_hidden_state + + encoder = encoder.to(device=device) + mark_step() + output_nrn = encoder( + inputs_embeds=inputs_embeds.to(device=device), + attention_mask=attention_mask.to(device=device), + ).last_hidden_state + mark_step() + output_nrn = output_nrn.cpu() + + rtol, atol = tolerances.rtol, tolerances.atol + assert_tensor_all_close(test_objective="Encoder last hidden states", computed_value=output_nrn, reference_value=output_cpu, rtol=rtol, atol=atol, equal_nan=True) + + +def test_nxdi_encoder_vs_transformers_implementation(random_seed, hf_config) -> None: + batch_size, seq_len, hidden_size = 2, 32, hf_config.vision_config.hidden_size + inputs_dtype = model_dtype = torch.float32 + + inputs_embeds = torch.randn(batch_size, seq_len, hidden_size).to(dtype=inputs_dtype) + attention_mask = torch.ones(batch_size, 1, seq_len, seq_len).to(dtype=inputs_dtype) + + neuron_config = NeuronSiglipConfig( + tp_degree=1, + batch_size=batch_size, + max_context_length=seq_len, + torch_dtype=model_dtype, + ) + + config = SiglipInferenceConfig(neuron_config=neuron_config, **hf_config.vision_config.to_dict()) + + encoder = NeuronSiglipEncoder(config=config) + encoder.eval() + + hf_config.vision_config._attn_implementation = "eager" + reference_model = SiglipEncoder(config=hf_config.vision_config).to(dtype=model_dtype) + hf_state_dict = convert_neuron_siglip_encoder_state_dict_to_hf(encoder.state_dict()) + reference_model.load_state_dict(hf_state_dict, strict=True) + reference_model.eval() + + with torch.no_grad(): + ref_output = reference_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + ).last_hidden_state + output = encoder( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + ).last_hidden_state + + rtol, atol = FP32_TOLERANCES.rtol, FP32_TOLERANCES.atol + 
assert_tensor_all_close(test_objective="Encoder last hidden states", computed_value=output, reference_value=ref_output, rtol=rtol, atol=atol, equal_nan=True) diff --git a/contrib/models/gemma3-vision/test/unit/siglip/test_encoder_layer.py b/contrib/models/gemma3-vision/test/unit/siglip/test_encoder_layer.py new file mode 100644 index 00000000..ec2bef31 --- /dev/null +++ b/contrib/models/gemma3-vision/test/unit/siglip/test_encoder_layer.py @@ -0,0 +1,112 @@ +import logging +import pytest +from typing import Dict, OrderedDict + +import torch +import torch_xla.core.xla_model as xm +from transformers import AutoConfig, AutoModel +from transformers.models.siglip.modeling_siglip import SiglipEncoderLayer + +from gemma3_vision.siglip.modeling_siglip import NeuronSiglipConfig, SiglipInferenceConfig, NeuronSiglipEncoderLayer +from test.utils import assert_tensor_all_close, mark_step, FP32_TOLERANCES + +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + +def convert_to_hf_state_dict(state_dict: OrderedDict[str, torch.FloatTensor]) -> Dict[str, torch.FloatTensor]: + hf_state_dict = {} + for key, tensor in state_dict.items(): + logger.debug(key) + if key.startswith("self_attn.qkv_proj."): + hf_state_dict[key.replace("qkv_proj.", "")] = tensor + elif key.startswith("self_attn.o_proj."): + hf_state_dict[key.replace("o_proj.o_proj.", "out_proj.")] = tensor + elif key.endswith("rank"): + logger.info(f"Skipping neuron-related key: {key}") + else: + hf_state_dict[key] = tensor + return hf_state_dict + +config = AutoConfig.from_pretrained("google/gemma-3-27b-it") # nosec B615 +hf_config = AutoModel.from_config(config=config.vision_config).config + + +@pytest.mark.parametrize("tolerances, compiler_flags", [ + (FP32_TOLERANCES, ["--model-type=transformer", "--auto-cast=none"]), + ]) +def test_encoder_layer(monkeypatch, base_compiler_flags, tolerances, compiler_flags) -> None: + monkeypatch.setenv("NEURON_CC_FLAGS", " ".join(base_compiler_flags + compiler_flags)) + + 
batch_size, seq_len, hidden_size = 2, 32, hf_config.hidden_size + inputs_dtype = model_dtype = torch.float32 + device = xm.xla_device() + + hidden_states = torch.randn(batch_size, seq_len, hidden_size).to(dtype=inputs_dtype) + attention_mask = torch.ones(batch_size, 1, seq_len, seq_len).to(dtype=inputs_dtype) + + neuron_config = NeuronSiglipConfig( + tp_degree=1, + batch_size=batch_size, + max_context_length=seq_len, + torch_dtype=model_dtype, + ) + + config = SiglipInferenceConfig(neuron_config=neuron_config, **hf_config.to_dict()) + + encoder_layer = NeuronSiglipEncoderLayer(config=config) + encoder_layer.eval() + + with torch.no_grad(): + output_cpu, *_ = encoder_layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + ) + + encoder_layer = encoder_layer.to(device=device) + mark_step() + output_nrn, *_ = encoder_layer( + hidden_states=hidden_states.to(device=device), + attention_mask=attention_mask.to(device=device), + ) + mark_step() + output_nrn = output_nrn.cpu() + + rtol, atol = tolerances.rtol, tolerances.atol + assert_tensor_all_close(test_objective="Encoder layer outputs", computed_value=output_nrn, reference_value=output_cpu, rtol=rtol, atol=atol, equal_nan=True) + + +def test_nxdi_encoder_layer_vs_transformers_implementation(random_seed) -> None: + batch_size, seq_len, hidden_size = 2, 32, hf_config.hidden_size + inputs_dtype = model_dtype = torch.float32 + + hidden_states = torch.randn(batch_size, seq_len, hidden_size).to(dtype=inputs_dtype) + attention_mask = torch.ones(batch_size, 1, seq_len, seq_len).to(dtype=inputs_dtype) + + neuron_config = NeuronSiglipConfig( + tp_degree=1, + batch_size=batch_size, + max_context_length=seq_len, + torch_dtype=model_dtype, + ) + + config = SiglipInferenceConfig(neuron_config=neuron_config, **hf_config.to_dict()) + + encoder_layer = NeuronSiglipEncoderLayer(config=config) + encoder_layer.eval() + + reference_model = SiglipEncoderLayer(config=hf_config).to(dtype=model_dtype) + 
reference_model.load_state_dict(convert_to_hf_state_dict(encoder_layer.state_dict()), strict=True) + reference_model.eval() + + with torch.no_grad(): + ref_output, *_ = reference_model( + hidden_states=hidden_states, + attention_mask=attention_mask, + ) + output, *_ = encoder_layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + ) + + rtol, atol = FP32_TOLERANCES.rtol, FP32_TOLERANCES.atol + assert_tensor_all_close(test_objective="Encoder layer outputs", computed_value=output, reference_value=ref_output, rtol=rtol, atol=atol, equal_nan=True) diff --git a/contrib/models/gemma3-vision/test/unit/siglip/test_mlp.py b/contrib/models/gemma3-vision/test/unit/siglip/test_mlp.py new file mode 100644 index 00000000..bdf6576e --- /dev/null +++ b/contrib/models/gemma3-vision/test/unit/siglip/test_mlp.py @@ -0,0 +1,76 @@ +import pytest +import torch +import torch_xla.core.xla_model as xm +from transformers.models.siglip.modeling_siglip import SiglipMLP + +from gemma3_vision.siglip.modeling_siglip import NeuronSiglipConfig, SiglipInferenceConfig, NeuronSiglipMLP +from test.utils import assert_tensor_all_close, mark_step, FP32_TOLERANCES, FP16_TOLERANCES, BF16_TOLERANCES + + +@pytest.mark.parametrize("tolerances, compiler_flags", [ + (FP32_TOLERANCES, ["--model-type=transformer", "--auto-cast=none"]), + (FP16_TOLERANCES, ["--model-type=transformer", "--auto-cast=matmult", "--enable-mixed-precision-accumulation", "--auto-cast-type=fp16"]), + (BF16_TOLERANCES, ["--model-type=transformer", "--auto-cast=matmult", "--enable-mixed-precision-accumulation", "--auto-cast-type=bf16"]), + ]) +def test_mlp_layer(monkeypatch, base_compiler_flags, tolerances, compiler_flags, hf_config) -> None: + monkeypatch.setenv("NEURON_CC_FLAGS", " ".join(base_compiler_flags + compiler_flags)) + + batch_size, seq_len, hidden_size = 2, 32, hf_config.vision_config.hidden_size + inputs_dtype = model_dtype = torch.float32 + device = xm.xla_device() + + x = torch.randn(batch_size, seq_len, 
hidden_size).to(dtype=inputs_dtype) + + neuron_config = NeuronSiglipConfig( + tp_degree=2, + batch_size=batch_size, + max_context_length=seq_len, + torch_dtype=model_dtype, + ) + + config = SiglipInferenceConfig(neuron_config=neuron_config, **hf_config.vision_config.to_dict()) + + mlp_layer = NeuronSiglipMLP(config).to(dtype=model_dtype) + mlp_layer.eval() + + with torch.no_grad(): + cpu_output = mlp_layer(x) + + mlp_layer = mlp_layer.to(device=device) + mark_step() + nrn_output = mlp_layer(x.to(device=device)) + mark_step() + nrn_output = nrn_output.cpu() + + rtol, atol = tolerances.rtol, tolerances.atol + assert_tensor_all_close(test_objective="MLP outputs", computed_value=nrn_output, reference_value=cpu_output, rtol=rtol, atol=atol, equal_nan=True) + + +def test_nxdi_mlp_vs_transformers_implementation(random_seed, hf_config) -> None: + batch_size, seq_len = 2, 32 + inputs_dtype = model_dtype = torch.float32 + + x = torch.randn(batch_size, seq_len, hf_config.vision_config.hidden_size).to(dtype=inputs_dtype) + + neuron_config = NeuronSiglipConfig( + tp_degree=1, + batch_size=batch_size, + max_context_length=seq_len, + torch_dtype=model_dtype, + ) + + config = SiglipInferenceConfig(neuron_config=neuron_config, **hf_config.vision_config.to_dict()) + + mlp_layer = NeuronSiglipMLP(config=config).to(dtype=model_dtype) + mlp_layer.eval() + + reference_model = SiglipMLP(config=hf_config.vision_config).to(dtype=model_dtype) + reference_model.load_state_dict(mlp_layer.state_dict(), strict=True) + reference_model.eval() + + with torch.no_grad(): + ref_output = reference_model(hidden_states=x) + output = mlp_layer(hidden_states=x) + + rtol, atol = FP32_TOLERANCES.rtol, FP32_TOLERANCES.atol + assert_tensor_all_close(test_objective="MLP outputs", computed_value=output, reference_value=ref_output, rtol=rtol, atol=atol, equal_nan=True) diff --git a/contrib/models/gemma3-vision/test/unit/siglip/test_siglip_attention.py 
b/contrib/models/gemma3-vision/test/unit/siglip/test_siglip_attention.py new file mode 100644 index 00000000..7447cc2c --- /dev/null +++ b/contrib/models/gemma3-vision/test/unit/siglip/test_siglip_attention.py @@ -0,0 +1,117 @@ +import logging +import pytest +from typing import Dict, OrderedDict + +import torch +import torch_xla.core.xla_model as xm +from transformers.models.siglip.modeling_siglip import SiglipAttention + +from gemma3_vision.siglip.modeling_siglip import NeuronSiglipConfig, SiglipInferenceConfig, NeuronSiglipAttention +from test.utils import ( + assert_tensor_all_close, + mark_step, + FP32_TOLERANCES, + FP16_TOLERANCES, + BF16_TOLERANCES +) + +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + + +def convert_to_hf_state_dict(state_dict: OrderedDict[str, torch.FloatTensor]) -> Dict[str, torch.FloatTensor]: + hf_state_dict = {} + for key, tensor in state_dict.items(): + if key.startswith("qkv_proj."): + hf_state_dict[key.replace("qkv_proj.", "")] = tensor + elif key.startswith("o_proj."): + hf_state_dict[key.replace("o_proj.o_proj.", "out_proj.")] = tensor + else: + logger.info(f"Skipping unexpected input key: {key}") + return hf_state_dict + + +@pytest.mark.parametrize("tolerances, compiler_flags", [ + (FP32_TOLERANCES, ["--model-type=transformer", "--auto-cast=none"]), + (FP16_TOLERANCES, ["--model-type=transformer", "--auto-cast=matmult", "--enable-mixed-precision-accumulation", "--auto-cast-type=fp16"]), + (BF16_TOLERANCES, ["--model-type=transformer", "--auto-cast=matmult", "--enable-mixed-precision-accumulation", "--auto-cast-type=bf16"]), + ]) +def test_attention_layer(monkeypatch, base_compiler_flags, tolerances, compiler_flags, hf_config) -> None: + monkeypatch.setenv("NEURON_CC_FLAGS", " ".join(base_compiler_flags + compiler_flags)) + + batch_size, seq_len, hidden_size = 2, 32, hf_config.vision_config.hidden_size + inputs_dtype = model_dtype = torch.float32 + device = xm.xla_device() + + hidden_states = 
torch.randn(batch_size, seq_len, hidden_size).to(dtype=inputs_dtype) + attention_mask = torch.ones(batch_size, 1, seq_len, seq_len).to(dtype=inputs_dtype) + + neuron_config = NeuronSiglipConfig( + tp_degree=1, + batch_size=batch_size, + max_context_length=seq_len, + torch_dtype=model_dtype, + ) + + config = SiglipInferenceConfig(neuron_config=neuron_config, **hf_config.vision_config.to_dict()) + + attn_layer = NeuronSiglipAttention(config=config) + attn_layer.eval() + + with torch.no_grad(): + output_cpu, *_ = attn_layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + ) + + attn_layer = attn_layer.to(device=device) + mark_step() + output_nrn, *_ = attn_layer( + hidden_states=hidden_states.to(device=device), + attention_mask=attention_mask.to(device=device), + ) + mark_step() + output_nrn = output_nrn.cpu() + + rtol, atol = tolerances.rtol, tolerances.atol + assert_tensor_all_close(test_objective="Attention outputs", computed_value=output_nrn, reference_value=output_cpu, rtol=rtol, atol=atol, equal_nan=True) + + +# Note: As HuggingFace Transformers supports left padding only, we can only test the NxDI implementation of the attention layer +# and therefore the SWA implementation, for left padding only +def test_nxdi_attn_vs_transformers_implementation(random_seed, hf_config) -> None: + batch_size, seq_len, hidden_size = 2, 32, hf_config.vision_config.hidden_size + inputs_dtype = model_dtype = torch.float32 + + hidden_states = torch.randn(batch_size, seq_len, hidden_size).to(dtype=inputs_dtype) + attention_mask = torch.ones(batch_size, 1, seq_len, seq_len).to(dtype=inputs_dtype) + + neuron_config = NeuronSiglipConfig( + tp_degree=1, + batch_size=batch_size, + max_context_length=seq_len, + torch_dtype=model_dtype, + ) + + config = SiglipInferenceConfig(neuron_config=neuron_config, **hf_config.vision_config.to_dict()) + + attn_layer = NeuronSiglipAttention(config=config) + attn_layer.eval() + + hf_config.vision_config._attn_implementation = "eager" + 
reference_model = SiglipAttention(config=hf_config.vision_config).to(dtype=model_dtype) + reference_model.load_state_dict(convert_to_hf_state_dict(attn_layer.state_dict()), strict=True) + reference_model.eval() + + with torch.no_grad(): + ref_output, *_ = reference_model( + hidden_states=hidden_states, + attention_mask=attention_mask, + ) + output, *_ = attn_layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + ) + + rtol, atol = FP32_TOLERANCES.rtol, FP32_TOLERANCES.atol + assert_tensor_all_close(test_objective="Attention outputs", computed_value=output, reference_value=ref_output, rtol=rtol, atol=atol, equal_nan=True) diff --git a/contrib/models/gemma3-vision/test/unit/siglip/test_siglip_vision_model.py b/contrib/models/gemma3-vision/test/unit/siglip/test_siglip_vision_model.py new file mode 100644 index 00000000..05724517 --- /dev/null +++ b/contrib/models/gemma3-vision/test/unit/siglip/test_siglip_vision_model.py @@ -0,0 +1,112 @@ +import logging +import pytest +from typing import Dict, OrderedDict + +import torch +import torch_xla.core.xla_model as xm +from transformers.models.siglip.modeling_siglip import SiglipVisionModel + +from gemma3_vision.siglip.modeling_siglip import NeuronSiglipConfig, SiglipInferenceConfig, NeuronSiglipVisionModel +from test.utils import assert_tensor_all_close, mark_step, FP32_TOLERANCES + +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + + +def convert_to_hf_state_dict(state_dict: OrderedDict[str, torch.FloatTensor]) -> Dict[str, torch.FloatTensor]: + """Convert NeuronSiglipVisionModel state dict to HuggingFace SiglipVisionModel format. 
+ + Key mappings: + - vision_model.encoder.layers.{i}.self_attn.qkv_proj.{q,k,v}_proj.{weight,bias} + → vision_model.encoder.layers.{i}.self_attn.{q,k,v}_proj.{weight,bias} + - vision_model.encoder.layers.{i}.self_attn.o_proj.o_proj.{weight,bias} + → vision_model.encoder.layers.{i}.self_attn.out_proj.{weight,bias} + - vision_model.encoder.layers.{i}.self_attn.rank_util.rank (skip - internal tracking) + """ + hf_state_dict = {} + for key, tensor in state_dict.items(): + if "rank_util.rank" in key: + # Skip internal rank tracking tensors + logger.debug(f"Skipping internal key: {key}") + continue + elif ".qkv_proj." in key: + # qkv_proj.q_proj.weight → q_proj.weight + hf_key = key.replace(".qkv_proj.", ".") + hf_state_dict[hf_key] = tensor + elif ".o_proj.o_proj." in key: + # o_proj.o_proj.weight → out_proj.weight + hf_key = key.replace(".o_proj.o_proj.", ".out_proj.") + hf_state_dict[hf_key] = tensor + else: + hf_state_dict[key] = tensor + return hf_state_dict + + +@pytest.mark.parametrize("tolerances, compiler_flags", [ + (FP32_TOLERANCES, ["--model-type=transformer", "--auto-cast=none"]), + ]) +def test_vision_model(monkeypatch, base_compiler_flags, tolerances, compiler_flags, hf_config) -> None: + monkeypatch.setenv("NEURON_CC_FLAGS", " ".join(base_compiler_flags + compiler_flags)) + + batch_size, num_channels, image_size = 2, 3, 896 + hf_config.vision_config.num_hidden_layers = 5 # lower num_hidden_layers for faster testing + inputs_dtype = model_dtype = torch.float32 + device = xm.xla_device() + + pixel_values = torch.randn(batch_size, num_channels, image_size, image_size).to(dtype=inputs_dtype) + + neuron_config = NeuronSiglipConfig( + tp_degree=1, + batch_size=batch_size, + torch_dtype=model_dtype, + attn_kernel_enabled=False, # Otherwise, a NKI kernel is automatically selected due to the sequence length (cannot run on CPU) + ) + + config = SiglipInferenceConfig(neuron_config=neuron_config, **hf_config.vision_config.to_dict()) + + vision_model = 
NeuronSiglipVisionModel(config=config) + vision_model.eval() + + with torch.no_grad(): + output_cpu = vision_model(pixel_values=pixel_values).last_hidden_state + + vision_model = vision_model.to(device=device) + mark_step() + output_nrn = vision_model(pixel_values=pixel_values.to(device=device)).last_hidden_state + mark_step() + output_nrn = output_nrn.cpu() + + rtol, atol = tolerances.rtol, tolerances.atol + assert_tensor_all_close(test_objective="Vision model outputs", computed_value=output_nrn, reference_value=output_cpu, rtol=rtol, atol=atol, equal_nan=True) + + +def test_nxdi_vision_model_vs_transformers_implementation(random_seed, hf_config) -> None: + batch_size, num_channels, image_size = 2, 3, 896 + hf_config.vision_config.num_hidden_layers = 5 # lower num_hidden_layers for faster testing + inputs_dtype = model_dtype = torch.float32 + + pixel_values = torch.randn(batch_size, num_channels, image_size, image_size).to(dtype=inputs_dtype) + + neuron_config = NeuronSiglipConfig( + tp_degree=1, + batch_size=batch_size, + torch_dtype=model_dtype, + attn_kernel_enabled=False, # Otherwise, a NKI kernel is automatically selected due to the sequence length (cannot run on CPU) + ) + + config = SiglipInferenceConfig(neuron_config=neuron_config, **hf_config.vision_config.to_dict()) + + vision_model = NeuronSiglipVisionModel(config=config) + vision_model.eval() + + hf_config.vision_config._attn_implementation = "eager" + reference_model = SiglipVisionModel(config=hf_config.vision_config).to(dtype=model_dtype) + reference_model.load_state_dict(convert_to_hf_state_dict(vision_model.state_dict()), strict=True) + reference_model.eval() + + with torch.no_grad(): + ref_output = reference_model(pixel_values=pixel_values).last_hidden_state + output = vision_model(pixel_values=pixel_values).last_hidden_state + + rtol, atol = FP32_TOLERANCES.rtol, FP32_TOLERANCES.atol + assert_tensor_all_close(test_objective="Vision model outputs", computed_value=output, 
reference_value=ref_output, rtol=rtol, atol=atol, equal_nan=True) diff --git a/contrib/models/gemma3-vision/test/unit/siglip/test_vision_embed.py b/contrib/models/gemma3-vision/test/unit/siglip/test_vision_embed.py new file mode 100644 index 00000000..375f5303 --- /dev/null +++ b/contrib/models/gemma3-vision/test/unit/siglip/test_vision_embed.py @@ -0,0 +1,74 @@ +import pytest +import torch +import torch_xla.core.xla_model as xm +from transformers.models.siglip.modeling_siglip import SiglipVisionEmbeddings + +from gemma3_vision.siglip.modeling_siglip import NeuronSiglipConfig, SiglipInferenceConfig, NeuronSiglipVisionEmbeddings +from test.utils import assert_tensor_all_close, mark_step, FP32_TOLERANCES, FP16_TOLERANCES, BF16_TOLERANCES + + +@pytest.mark.parametrize("tolerances, compiler_flags", [ + (FP32_TOLERANCES, ["--model-type=transformer", "--auto-cast=none"]), + (FP16_TOLERANCES, ["--model-type=transformer", "--auto-cast=matmult", "--enable-mixed-precision-accumulation", "--auto-cast-type=fp16"]), + (BF16_TOLERANCES, ["--model-type=transformer", "--auto-cast=matmult", "--enable-mixed-precision-accumulation", "--auto-cast-type=bf16"]), + ]) +def test_vision_embed(monkeypatch, base_compiler_flags, tolerances, compiler_flags, hf_config) -> None: + monkeypatch.setenv("NEURON_CC_FLAGS", " ".join(base_compiler_flags + compiler_flags)) + + batch_size, num_channels, image_size = 2, 3, 896 + inputs_dtype = model_dtype = torch.float32 + device = xm.xla_device() + + pixel_values = torch.randn(batch_size, num_channels, image_size, image_size).to(dtype=inputs_dtype) + + neuron_config = NeuronSiglipConfig( + tp_degree=2, + batch_size=batch_size, + torch_dtype=model_dtype, + ) + + config = SiglipInferenceConfig(neuron_config=neuron_config, **hf_config.vision_config.to_dict()) + + vision_embed = NeuronSiglipVisionEmbeddings(config=config) + vision_embed.eval() + + with torch.no_grad(): + output_cpu = vision_embed(pixel_values=pixel_values) + + vision_embed = 
vision_embed.to(device=device) + mark_step() + output_nrn = vision_embed(pixel_values=pixel_values.to(device=device)) + mark_step() + output_nrn = output_nrn.cpu() + + rtol, atol = tolerances.rtol, tolerances.atol + assert_tensor_all_close(test_objective="Vision embedding outputs", computed_value=output_nrn, reference_value=output_cpu, rtol=rtol, atol=atol, equal_nan=True) + + +def test_nxdi_vision_embedding_vs_transformers_implementation(random_seed, hf_config) -> None: + batch_size, num_channels, image_size = 2, 3, 896 + inputs_dtype = model_dtype = torch.float32 + + pixel_values = torch.randn(batch_size, num_channels, image_size, image_size).to(dtype=inputs_dtype) + + neuron_config = NeuronSiglipConfig( + tp_degree=2, + batch_size=batch_size, + torch_dtype=model_dtype, + ) + + config = SiglipInferenceConfig(neuron_config=neuron_config, **hf_config.vision_config.to_dict()) + + vision_embed = NeuronSiglipVisionEmbeddings(config=config) + vision_embed.eval() + + reference_model = SiglipVisionEmbeddings(config=hf_config.vision_config).to(dtype=model_dtype) + reference_model.load_state_dict(vision_embed.state_dict(), strict=True) + reference_model.eval() + + with torch.no_grad(): + ref_output = reference_model(pixel_values=pixel_values) + output = vision_embed(pixel_values=pixel_values) + + rtol, atol = FP32_TOLERANCES.rtol, FP32_TOLERANCES.atol + assert_tensor_all_close(test_objective="Vision embedding outputs", computed_value=output, reference_value=ref_output, rtol=rtol, atol=atol, equal_nan=True) diff --git a/contrib/models/gemma3-vision/test/unit/siglip/test_vision_transformer.py b/contrib/models/gemma3-vision/test/unit/siglip/test_vision_transformer.py new file mode 100644 index 00000000..03502a3e --- /dev/null +++ b/contrib/models/gemma3-vision/test/unit/siglip/test_vision_transformer.py @@ -0,0 +1,115 @@ +import pytest +import torch +import torch_xla.core.xla_model as xm +from transformers.models.siglip.modeling_siglip import SiglipVisionTransformer + +from 
gemma3_vision.siglip.modeling_siglip import NeuronSiglipConfig, SiglipInferenceConfig, NeuronSiglipVisionTransformer +from test.utils import assert_tensor_all_close, mark_step, FP32_TOLERANCES + + +def convert_neuron_to_hf_state_dict(neuron_state_dict): + """Convert Neuron model state dict to HuggingFace compatible format. + + Neuron model structure: + - encoder.layers.X.self_attn.qkv_proj.{q,k,v}_proj.{weight,bias} + - encoder.layers.X.self_attn.o_proj.o_proj.{weight,bias} + - encoder.layers.X.self_attn.rank_util.rank (excluded) + + HuggingFace model structure: + - encoder.layers.X.self_attn.{q,k,v}_proj.{weight,bias} + - encoder.layers.X.self_attn.out_proj.{weight,bias} + """ + hf_state_dict = {} + + for key, value in neuron_state_dict.items(): + # Skip rank_util parameters + if 'rank_util' in key: + continue + + # Convert qkv_proj paths + if '.qkv_proj.q_proj.' in key: + new_key = key.replace('.qkv_proj.q_proj.', '.q_proj.') + elif '.qkv_proj.k_proj.' in key: + new_key = key.replace('.qkv_proj.k_proj.', '.k_proj.') + elif '.qkv_proj.v_proj.' in key: + new_key = key.replace('.qkv_proj.v_proj.', '.v_proj.') + # Convert o_proj paths + elif '.o_proj.o_proj.' 
in key: + new_key = key.replace('.o_proj.o_proj.', '.out_proj.') + else: + new_key = key + + hf_state_dict[new_key] = value + + return hf_state_dict + + +@pytest.mark.parametrize("tolerances, compiler_flags", [ + (FP32_TOLERANCES, ["--model-type=transformer", "--auto-cast=none"]), + ]) +def test_vision_transformer(monkeypatch, base_compiler_flags, tolerances, compiler_flags, hf_config) -> None: + monkeypatch.setenv("NEURON_CC_FLAGS", " ".join(base_compiler_flags + compiler_flags)) + + batch_size, num_channels, image_size = 2, 3, 896 + hf_config.vision_config.num_hidden_layers = 3 # lower num_hidden_layers for faster testing + inputs_dtype = model_dtype = torch.float32 + device = xm.xla_device() + + pixel_values = torch.randn(batch_size, num_channels, image_size, image_size).to(dtype=inputs_dtype) + + neuron_config = NeuronSiglipConfig( + tp_degree=1, + batch_size=batch_size, + torch_dtype=model_dtype, + attn_kernel_enabled=False, # Otherwise, a NKI kernel is automatically selected due to the sequence length (cannot run on CPU) + ) + + config = SiglipInferenceConfig(neuron_config=neuron_config, **hf_config.vision_config.to_dict()) + + vision_transformer = NeuronSiglipVisionTransformer(config=config) + vision_transformer.eval() + + with torch.no_grad(): + output_cpu = vision_transformer(pixel_values=pixel_values).last_hidden_state + + vision_transformer = vision_transformer.to(device=device) + mark_step() + output_nrn = vision_transformer(pixel_values=pixel_values.to(device=device)).last_hidden_state + mark_step() + output_nrn = output_nrn.cpu() + + rtol, atol = tolerances.rtol, tolerances.atol + assert_tensor_all_close(test_objective="Vision transformer outputs", computed_value=output_nrn, reference_value=output_cpu, rtol=rtol, atol=atol, equal_nan=True) + + +def test_nxdi_vision_transformer_vs_transformers_implementation(random_seed, hf_config) -> None: + batch_size, num_channels, image_size = 2, 3, 896 + hf_config.vision_config.num_hidden_layers = 3 + inputs_dtype 
= model_dtype = torch.float32 + + pixel_values = torch.randn(batch_size, num_channels, image_size, image_size).to(dtype=inputs_dtype) + + neuron_config = NeuronSiglipConfig( + tp_degree=1, + batch_size=batch_size, + torch_dtype=model_dtype, + attn_kernel_enabled=False, # Otherwise, a NKI kernel is automatically selected due to the sequence length (cannot run on CPU) + ) + + config = SiglipInferenceConfig(neuron_config=neuron_config, **hf_config.vision_config.to_dict()) + + vision_transformer = NeuronSiglipVisionTransformer(config=config) + vision_transformer.eval() + + hf_config.vision_config._attn_implementation = "eager" + reference_model = SiglipVisionTransformer(config=hf_config.vision_config).to(dtype=model_dtype) + hf_compatible_state_dict = convert_neuron_to_hf_state_dict(vision_transformer.state_dict()) + reference_model.load_state_dict(hf_compatible_state_dict, strict=True) + reference_model.eval() + + with torch.no_grad(): + ref_output = reference_model(pixel_values=pixel_values).last_hidden_state + output = vision_transformer(pixel_values=pixel_values).last_hidden_state + + rtol, atol = FP32_TOLERANCES.rtol, FP32_TOLERANCES.atol + assert_tensor_all_close(test_objective="Vision transformer outputs", computed_value=output, reference_value=ref_output, rtol=rtol, atol=atol, equal_nan=True) diff --git a/contrib/models/gemma3-vision/test/utils.py b/contrib/models/gemma3-vision/test/utils.py new file mode 100644 index 00000000..6a472ea6 --- /dev/null +++ b/contrib/models/gemma3-vision/test/utils.py @@ -0,0 +1,287 @@ + +import os +from dataclasses import dataclass +import logging + +from neuronx_distributed_inference.models.config import NeuronConfig +from neuronx_distributed_inference.modules.kvcache.kv_cache_manager import KVCacheManager +from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config +from neuronx_distributed_inference.utils.random import set_random_seed +from neuronx_distributed_inference.utils.testing import init_cpu_env 
+import torch +import torch_xla +import torch_xla.core.xla_model as xm +from transformers import Gemma3Config +from transformers.configuration_utils import PretrainedConfig +from transformers.models.gemma3.modeling_gemma3 import Gemma3RotaryEmbedding + +from gemma3_vision.modeling_gemma3 import Gemma3InferenceConfig + +torch.set_printoptions(precision=5) + + +logging.basicConfig(level=logging.INFO, format="%(asctime)s.%(msecs)06d - %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S") +logger = logging.getLogger(__name__) + + +@dataclass +class NumericalTolerances: + rtol: float + atol: float + +# Default tolerances from torch.testing.assert_close +FP32_TOLERANCES = NumericalTolerances(rtol=1.3e-6, atol=1e-5) +FP16_TOLERANCES = NumericalTolerances(rtol=1e-3, atol=1e-5) +BF16_TOLERANCES = NumericalTolerances(rtol=1.6e-2, atol=1e-5) + + +def create_neuron_config( + batch_size: int, + max_seq_len: int, + tp_degree: int, + torch_dtype: torch.dtype, + hf_config: Gemma3Config + ) -> Gemma3InferenceConfig: + return Gemma3InferenceConfig( + text_neuron_config=NeuronConfig( + tp_degree=tp_degree, + batch_size=batch_size, + torch_dtype=torch_dtype, + attn_kernel_enabled=False, + seq_len=max_seq_len + ), + vision_neuron_config=NeuronConfig( + tp_degree=tp_degree, + batch_size=1, + torch_dtype=torch_dtype, + attn_kernel_enabled=False, + seq_len=max_seq_len + ), + load_config=load_pretrained_config(hf_config=hf_config), + ) + + +def cpu_setup(dtype): + set_random_seed(0) + os.environ.setdefault("NXD_CPU_MODE", "1") + init_cpu_env() + torch.set_default_dtype(dtype) + torch.set_default_device("cpu") + + +def mark_step() -> None: + torch_xla.sync() + xm.wait_device_ops() + + +def assert_tensor_all_close( + test_objective: str, + computed_value: torch.FloatTensor, + reference_value: torch.FloatTensor, + rtol: float = 1e-05, + atol: float = 1e-08, + equal_nan: bool = True, + ) -> None: + assert computed_value.dtype == reference_value.dtype, "dtypes are not matching" + try: + 
assert torch.allclose(computed_value, reference_value, rtol, atol, equal_nan), f"{test_objective} are not matching!" + logger.info(f"{test_objective} ({reference_value.numel()} value(s)) are matching (atol={atol:.1e} - rtol={rtol:.1e})!") + except AssertionError as e: + logger.error(e) + + logger.info("------ TOTAL ERROR ANALYSIS ------") + abs_difference = torch.abs(computed_value - reference_value) + rel_difference = abs_difference / torch.abs(reference_value) + threshold = atol + torch.abs(reference_value) * rtol + mask = abs_difference > threshold + num_non_matching_values, total_values = mask.sum().item(), mask.numel() + percentage = (num_non_matching_values / total_values) * 100 + logger.info(f"{num_non_matching_values}/{total_values} value(s) ({percentage:.2f}%) are not within tolerances (atol={atol:.1e} - rtol={rtol:.1e})") + logger.info(f"Reference values: {reference_value[mask]}") + logger.info(f"Computed values: {computed_value[mask]}") + logger.info(f"Abs. diff.: {abs_difference[mask]}") + logger.info(f"Threshold: {threshold[mask]}") + + logger.info("------ ABSOLUTE ERROR ANALYSIS ------") + logger.info(f"Absolute error tolerance (atol): {atol:.1e}") + atol_dominates = atol > 10.0 * torch.abs(reference_value) * rtol + atol_dominated_values = atol_dominates.sum().item() + if atol_dominated_values: + percentage = (atol_dominated_values / total_values) * 100 + logger.info(f"Absolute error dominates (atol > 10*rtol) for {atol_dominated_values}/{total_values} value(s) ({percentage:.2f}%)") + a_mask = (abs_difference > atol) & atol_dominates + num_non_matching_values = a_mask.sum().item() + percentage = (num_non_matching_values / total_values) * 100 + logger.info(f"{num_non_matching_values}/{total_values} value(s) ({percentage:.2f}%) are not within absolute tolerances (atol={atol:.1e})") + logger.info(f"Mean abs. diff.: {abs_difference[a_mask].mean():.3e} - Max abs. 
diff.: {abs_difference[a_mask].max():.3e}") + logger.info(f"Reference values: {reference_value[a_mask]}") + logger.info(f"Computed values: {computed_value[a_mask]}") + logger.info(f"Abs. diff.: {abs_difference[a_mask]}") + else: + logger.info(f"There are no values (0/{total_values} value(s) - 0.00%) for which the absolute error dominates (atol > 10*rtol)") + + logger.info("------ RELATIVE ERROR ANALYSIS ------") + logger.info(f"Relative error tolerance (rtol): {rtol:.1e}") + rtol_dominates = torch.abs(reference_value) * rtol > 10.0 * atol + rtol_dominated_values = rtol_dominates.sum().item() + if rtol_dominated_values: + percentage = (rtol_dominated_values / total_values) * 100 + logger.info(f"Relative error dominates (rtol > 10*atol) for {rtol_dominated_values}/{total_values} value(s) ({percentage:.2f}%)") + r_mask = (rel_difference > rtol) & rtol_dominates + num_non_matching_values = r_mask.sum().item() + percentage = (num_non_matching_values / total_values) * 100 + logger.info(f"{num_non_matching_values}/{total_values} value(s) ({percentage:.2f}%) are not within relative tolerances (rtol={rtol:.1e})") + logger.info(f"Mean rel. diff.: {rel_difference[r_mask].mean():.3e} - Max rel. diff.: {rel_difference[r_mask].max():.3e}") + logger.info(f"Reference values: {reference_value[r_mask]}") + logger.info(f"Computed values: {computed_value[r_mask]}") + logger.info(f"Rel. diff.: {rel_difference[r_mask]}") + else: + logger.info(f"There are no values (0/{total_values} value(s) - 0.00%) for which the relative error dominates (rtol > 10*atol)") + raise e + + +# This mock KV cache manager is used to test the model on CPU, as the NxDI implementation of the KV cache manager requires XLA tensors.
+class MockKVCacheManager(KVCacheManager): + def update_cache( + self, + is_for_context_encoding, + seq_ids, + position_ids, + new_key_values, + seq_len: int, + scatter_index=None, + active_mask=None, + kvcache_buffer=None, + **kwargs + ): + return new_key_values + + + +def create_position_ids_for_context_processing(attention_mask_2d: torch.LongTensor) -> torch.LongTensor: + position_ids = attention_mask_2d.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask_2d == 0, 1) + return position_ids + + +def create_position_ids_for_token_generation(attention_mask_2d: torch.LongTensor) -> torch.LongTensor: + full_position_ids = create_position_ids_for_context_processing(attention_mask_2d=attention_mask_2d) + return torch.amax(full_position_ids, dim=1, keepdim=True) + 1 + + +def create_position_ids(attention_mask_2d: torch.LongTensor, is_for_context_encoding: bool) -> torch.LongTensor: + if is_for_context_encoding: + return create_position_ids_for_context_processing(attention_mask_2d=attention_mask_2d) + else: + return create_position_ids_for_token_generation(attention_mask_2d=attention_mask_2d) + + +def create_cache_position(attention_mask_2d: torch.LongTensor, is_for_context_encoding: bool) -> torch.LongTensor: + # From transformers.generation.utils.GenerationMixin._get_initial_cache_position + cache_position = torch.ones_like(attention_mask_2d[0, :], dtype=torch.int64).cumsum(0) - 1 + if is_for_context_encoding: + return cache_position + else: + return cache_position[-1:] + + +def create_rope(position_ids: torch.LongTensor, hf_config: PretrainedConfig) -> torch.FloatTensor: + batch_size, sequence_length = position_ids.shape + x = torch.randn(batch_size, hf_config.num_attention_heads, sequence_length, hf_config.head_dim).to(dtype=torch.float32) + rope = Gemma3RotaryEmbedding(config=hf_config) + cos, sin = rope(x, position_ids) + return cos, sin + + +def create_hidden_states(attention_mask_2d: torch.LongTensor, hf_config: PretrainedConfig, is_for_context_encoding: bool) ->
torch.FloatTensor: + batch_size, max_input_length = attention_mask_2d.shape + sequence_length = max_input_length if is_for_context_encoding else 1 + return torch.randn(batch_size, sequence_length, hf_config.hidden_size, requires_grad=False).to(dtype=torch.float32) + + +def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +def create_hf_attention_mask_4d( + attention_mask_2d: torch.LongTensor, + cache_position: torch.LongTensor, + is_for_context_encoding: bool, + is_swa_layer: bool, + sliding_window_size: int, + dtype: torch.dtype = torch.float32, + ) -> torch.FloatTensor: + batch_size, sequence_length = attention_mask_2d.shape + target_length = sequence_length + if not is_for_context_encoding: + sequence_length = 1 + + 
attention_mask_4d = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask=attention_mask_2d, + sequence_length=sequence_length, # len_q + target_length=target_length, # len_k + dtype=dtype, + device=attention_mask_2d.device, + cache_position=cache_position, + batch_size=batch_size, + ) + # Adapted from transformers.models.cohere2.modeling_cohere2.Cohere2DecoderLayer.forward + if not is_swa_layer: + return attention_mask_4d + else: + last_cache_position = cache_position[-1] + 1 # Current total seq length, fixed from HF + effective_seq_len = max(cache_position.shape[0], sliding_window_size) + min_dtype = torch.finfo(dtype).min + sliding_window_mask = torch.tril( + torch.ones_like(attention_mask_4d, dtype=torch.bool), diagonal=-sliding_window_size + ) + attention_mask_4d = torch.where(sliding_window_mask, min_dtype, attention_mask_4d) + offset = max(0, last_cache_position - effective_seq_len) + return attention_mask_4d[:, :, :, offset : offset + effective_seq_len] + + +def causal_mask(batch_size, seq_len): + mask = torch.full((seq_len, seq_len), True).tril(diagonal=0) + mask = mask[None, None, :, :].expand(batch_size, 1, seq_len, seq_len) + return mask + + +def window_mask(batch_size: int, seq_len: int, window_size: int): + """create a causal, window attention mask""" + mask = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool), diagonal=0) + for i in range(seq_len): + if i >= window_size: + mask[i, : i - window_size + 1] = False + mask = mask[None, None, :, :].expand(batch_size, 1, seq_len, seq_len) + return mask diff --git a/contrib/models/gemma3-vision/vllm/README.md b/contrib/models/gemma3-vision/vllm/README.md new file mode 100644 index 00000000..305b561b --- /dev/null +++ b/contrib/models/gemma3-vision/vllm/README.md @@ -0,0 +1,184 @@ +# Running Gemma3 Vision Models with vLLM on AWS Neuron + +## Setup +*Note*: In the following, we assume that the HuggingFace model weights are available on the host. 
If not, +download them using the following commands: + +```bash +hf auth login --token +hf download google/gemma-3-27b-it --local-dir +``` + +The `` path will need to be provided to vLLM as the `--model`/`model` argument. +If the HuggingFace CLI is not installed, run: + +```bash +python3 -m venv hf_env +source hf_env/bin/activate +pip install -U "huggingface_hub[cli]" +``` + +### 1. Install vLLM +```bash +git clone --branch "0.3.0" https://github.com/vllm-project/vllm-neuron.git +cd vllm-neuron +pip install --extra-index-url=https://pip.repos.neuron.amazonaws.com -e . +``` + +### 2. Configure Gemma3 Support +The following files in the cloned repository need to be modified: + +- `vllm_neuron/worker/constants.py` +- `vllm_neuron/worker/neuronx_distributed_model_loader.py` +- `vllm_neuron/worker/neuronx_distributed_model_runner.py` + +#### 2.1 Register the Gemma3 HuggingFace model class in the supported `NEURON_MULTI_MODAL_MODELS` list + +```diff +--- a/vllm_neuron/worker/constants.py ++++ b/vllm_neuron/worker/constants.py +@@ -5,6 +5,7 @@ NEURON_MULTI_MODAL_MODELS = [ + "MllamaForConditionalGeneration", + "LlavaForConditionalGeneration", + "Llama4ForConditionalGeneration", ++ "Gemma3ForConditionalGeneration" + ] + + TORCH_DTYPE_TO_NEURON_AMP = { +``` + +#### 2.2 Fix an incorrect import in `vllm_neuron/worker/neuronx_distributed_model_loader.py` + +```diff +--- a/vllm_neuron/worker/neuronx_distributed_model_loader.py ++++ b/vllm_neuron/worker/neuronx_distributed_model_loader.py +@@ -51,7 +51,7 @@ from vllm.config import ( + ) + from vllm.model_executor.layers.logits_processor import LogitsProcessor + from vllm.v1.outputs import SamplerOutput +-from vllm.v1.sample import sampler as Sampler ++from vllm.v1.sample.sampler import Sampler + + from vllm_neuron.worker.constants import ( + NEURON_MULTI_MODAL_MODELS, +``` + +#### 2.3 Add the `NeuronGemma3ForConditionalGeneration` class to `vllm_neuron/worker/neuronx_distributed_model_loader.py` + +```diff +@@ -704,6 +704,61 @@ class
NeuronLlama4ForCausalLM(NeuronMultiModalCausalLM): + **kwargs, + ) + ++class NeuronGemma3ForConditionalGeneration(NeuronLlama4ForCausalLM): ++ """Gemma3 multimodal model using dynamically loaded NeuronGemma3ForConditionalGeneration from contrib.""" ++ ++ def load_weights(self, model_name_or_path: str, architecture: str, **kwargs): ++ import importlib ++ ++ neuronx_module = importlib.import_module("gemma3_vision.modeling_gemma3") ++ neuronx_model_cls = getattr(neuronx_module, "NeuronGemma3ForConditionalGeneration") ++ ++ default_neuron_config = kwargs["neuron_config"] ++ override_neuron_config = _validate_image_to_text_override_neuron_config( ++ kwargs["override_neuron_config"] ++ ) ++ ++ vision_neuron_config = copy.deepcopy(default_neuron_config) ++ vision_neuron_config.update( ++ override_neuron_config.get("vision_neuron_config", {}) ++ ) ++ vision_neuron_config = neuronx_model_cls.get_neuron_config_cls()( ++ **vision_neuron_config ++ ) ++ ++ text_neuron_config = copy.deepcopy(default_neuron_config) ++ text_neuron_config.update(override_neuron_config.get("text_neuron_config", {})) ++ text_neuron_config = neuronx_model_cls.get_neuron_config_cls()( ++ **text_neuron_config ++ ) ++ ++ config = neuronx_model_cls.get_config_cls()( ++ text_neuron_config=text_neuron_config, ++ vision_neuron_config=vision_neuron_config, ++ load_config=load_pretrained_config(model_name_or_path), ++ ) ++ ++ success, compiled_model_path, _ = self._load_weights_common( ++ model_name_or_path, neuronx_model_cls, config=config, **kwargs ++ ) ++ ++ if not success: ++ if not os.path.exists(model_name_or_path): ++ model_name_or_path = self._save_pretrained_model(model_name_or_path) ++ ++ self._compile_and_load_model( ++ model_name_or_path, neuronx_model_cls, config, compiled_model_path ++ ) ++ ++ # Load tokenizer to get vision token ID ++ from transformers import AutoTokenizer ++ ++ tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) ++ self.vision_token_id = tokenizer( ++ "<|image|>", 
add_special_tokens=False ++ ).input_ids[0] ++ return success, compiled_model_path ++ + + def _get_model_configs(config: PretrainedConfig) -> str: + logger.debug("PretrainedConfig: %s", config) +``` + +#### 2.4 Map `NeuronGemma3ForConditionalGeneration` to the corresponding HuggingFace model class in `vllm_neuron/worker/neuronx_distributed_model_runner.py` + + +```diff +--- a/vllm_neuron/worker/neuronx_distributed_model_runner.py ++++ b/vllm_neuron/worker/neuronx_distributed_model_runner.py +@@ -775,6 +830,8 @@ def get_neuron_model( + model = NeuronPixtralForCausalLM(model_config.hf_config) + elif architecture == "Llama4ForConditionalGeneration": + model = NeuronLlama4ForCausalLM(model_config.hf_config) ++ elif architecture == "Gemma3ForConditionalGeneration": ++ model = NeuronGemma3ForConditionalGeneration(model_config.hf_config) + else: + model = NeuronCausalLM(model_config.hf_config) +``` + + +#### 2.5 Add Gemma3 to the list of models that use the Llama4 multi-modal data processor + +```diff +--- a/vllm_neuron/worker/neuronx_distributed_model_runner.py ++++ b/vllm_neuron/worker/neuronx_distributed_model_runner.py +@@ -1067,7 +1067,7 @@ class NeuronxDistributedModelRunner(LoRAModelRunnerMixin): + + if self.model.model.config.model_type == "llava": + mm_kwargs = self._process_multi_modal_data_neuron_llava(mm_kwargs) +- elif self.model.model.config.model_type == "llama4": ++ elif self.model.model.config.model_type in ['llama4', 'gemma3']: + pass # llama4 doesn't require special processing + else: + raise NotImplementedError( +``` + +### 3. Run inference + +#### 3.1 Offline Inference + +```bash +PYTHONPATH="$PWD/contrib/models/gemma3-vision/src:src:$PYTHONPATH" python contrib/models/gemma3-vision/vllm/run_offline_inference.py +``` + +#### 3.2 Online Inference + +1. Start the vLLM server: + +```bash +PYTHONPATH="$PWD/contrib/models/gemma3-vision/src:src:$PYTHONPATH" bash contrib/models/gemma3-vision/vllm/start-vllm-server.sh +``` + +2.
Query the running server: + +```bash +PYTHONPATH="$PWD/contrib/models/gemma3-vision/src:src:$PYTHONPATH" python contrib/models/gemma3-vision/vllm/run_online_inference.py +``` diff --git a/contrib/models/gemma3-vision/vllm/data/dog.jpg b/contrib/models/gemma3-vision/vllm/data/dog.jpg new file mode 100644 index 00000000..f9a3a805 Binary files /dev/null and b/contrib/models/gemma3-vision/vllm/data/dog.jpg differ diff --git a/contrib/models/gemma3-vision/vllm/run_offline_inference.py b/contrib/models/gemma3-vision/vllm/run_offline_inference.py new file mode 100644 index 00000000..676e5d79 --- /dev/null +++ b/contrib/models/gemma3-vision/vllm/run_offline_inference.py @@ -0,0 +1,76 @@ +from gemma3_vision.ndxi_patch import apply_patch +apply_patch() + +import os # noqa: E402 +from pathlib import Path # noqa: E402 + +from vllm import LLM, SamplingParams + +HOME_DIR = Path.home() + +os.environ['VLLM_NEURON_FRAMEWORK'] = "neuronx-distributed-inference" +os.environ['NEURON_COMPILED_ARTIFACTS'] = f"{HOME_DIR.as_posix()}/traced_model/gemma-3-27b-it" + +input_image_path = Path(__file__).resolve().parent / "data" / "dog.jpg" +IMAGE_URL = f"file://{input_image_path.as_posix()}" + + +def main(max_seq_len: int = 1024, images_per_sample: int = 1) -> None: + llm = LLM( + model=f"{HOME_DIR.as_posix()}/models/gemma-3-27b-it", # HuggingFace model ID or path to downloaded HF model artifacts + max_num_seqs=1, + max_model_len=max_seq_len, + tensor_parallel_size=8, + limit_mm_per_prompt={"image": images_per_sample}, # Accept up to images_per_sample images per prompt + allowed_local_media_path=HOME_DIR.as_posix(), # Allow loading local images + enable_prefix_caching=False, + enable_chunked_prefill=False, + additional_config={ + "override_neuron_config": { + "text_neuron_config": { + "attn_kernel_enabled": True, + "enable_bucketing": True, + "context_encoding_buckets": [max_seq_len], + "token_generation_buckets": [max_seq_len], + "is_continuous_batching": True, + "async_mode": True, + },
"vision_neuron_config": { + "enable_bucketing": True, + "buckets": [images_per_sample], + "is_continuous_batching": True, + } + + }, + }, + ) + + sampling_params = SamplingParams(top_k=1, max_tokens=100) + + # Test 1: Text-only input + conversation = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "what is the recipe of mayonnaise in two sentences?"}, + ] + } + ] + for output in llm.chat(conversation, sampling_params): + print(f"Generated text: {output.outputs[0].text !r}") + + # Test 2: Single image with text + conversation = [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": IMAGE_URL}}, + {"type": "text", "text": "Describe this image"}, + ] + } + ] + for output in llm.chat(conversation, sampling_params): + print(f"Generated text: {output.outputs[0].text !r}") + +if __name__ == "__main__": + main() diff --git a/contrib/models/gemma3-vision/vllm/run_online_inference.py b/contrib/models/gemma3-vision/vllm/run_online_inference.py new file mode 100644 index 00000000..20cb5802 --- /dev/null +++ b/contrib/models/gemma3-vision/vllm/run_online_inference.py @@ -0,0 +1,40 @@ +from pathlib import Path + +from openai import OpenAI + +MODEL_ID = "/home/ubuntu/models/gemma-3-27b-it" # HF model ID or path to HF model artifacts + +input_image_path = Path(__file__).resolve().parent / "data" / "dog.jpg" +IMAGE_URL = f"file://{input_image_path.as_posix()}" + + +client = OpenAI( + api_key = "EMPTY", # pragma: allowlist secret + base_url = "http://localhost:8080/v1" +) + +print("== Test text input ==") +completion = client.chat.completions.create( + model=MODEL_ID, + messages=[{ + "role": "user", + "content": [ + {"type": "text", "text": "what is the recipe of mayonnaise in two sentences?"}, + ] + }] +) +print(completion.choices[0].message.content) + + +print("== Test image+text input ==") +completion = client.chat.completions.create( + model=MODEL_ID, + messages=[{ + "role": "user", + "content": [ + {"type": "text", "text": 
"Describe this image:"}, + {"type": "image_url", "image_url": {"url": IMAGE_URL}} + ] + }] +) +print(completion.choices[0].message.content) diff --git a/contrib/models/gemma3-vision/vllm/start-vllm-server.sh b/contrib/models/gemma3-vision/vllm/start-vllm-server.sh new file mode 100644 index 00000000..d02fdf8c --- /dev/null +++ b/contrib/models/gemma3-vision/vllm/start-vllm-server.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +export VLLM_NEURON_FRAMEWORK="neuronx-distributed-inference" +export NEURON_COMPILED_ARTIFACTS="/home/ubuntu/traced_model/gemma-3-27b-it" # pragma: allowlist secret +export VLLM_RPC_TIMEOUT=100000 + +python -m vllm.entrypoints.openai.api_server \ + --port=8080 \ + --model="/home/ubuntu/models/gemma-3-27b-it" \ + --max-num-seqs=1 \ + --max-model-len=1024 \ + --limit-mm-per-prompt='{"image": 1}' \ + --allowed-local-media-path="/home/ubuntu" \ + --tensor-parallel-size=8 \ + --no-enable-chunked-prefill \ + --no-enable-prefix-caching \ + --additional-config='{"override_neuron_config":{"text_neuron_config":{"attn_kernel_enabled":true,"enable_bucketing":true,"context_encoding_buckets":[1024],"token_generation_buckets":[1024],"is_continuous_batching":true,"async_mode":true},"vision_neuron_config":{"enable_bucketing":true,"buckets":[1],"is_continuous_batching":true}}}'
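The single-line `--additional-config` JSON passed by `start-vllm-server.sh` is easy to break when editing buckets or flags, and a malformed string only fails once the server starts. A quick sanity check is to round-trip the string through a JSON parser before launching (a sketch, not part of the server script; the `CONFIG` variable name is illustrative):

```bash
# Sanity-check the override_neuron_config JSON before starting the server.
# The string below is copied verbatim from start-vllm-server.sh.
CONFIG='{"override_neuron_config":{"text_neuron_config":{"attn_kernel_enabled":true,"enable_bucketing":true,"context_encoding_buckets":[1024],"token_generation_buckets":[1024],"is_continuous_batching":true,"async_mode":true},"vision_neuron_config":{"enable_bucketing":true,"buckets":[1],"is_continuous_batching":true}}}'

# python3 -m json.tool exits non-zero on invalid JSON, so this prints
# the confirmation message only when the config parses cleanly.
echo "$CONFIG" | python3 -m json.tool > /dev/null && echo "additional-config JSON is valid"
```

The same check applies after any edit to the buckets or neuron-config flags; if `json.tool` reports an error, fix the string before restarting the server.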