[{"data":1,"prerenderedAt":4},["ShallowReactive",2],{"IfqLZk5bNs":3},"# Lean 4 → MLIR → GPU\n\n**Interactive proof blueprint: [brettkoonce.github.io/lean4-mlir/blueprint/](https://brettkoonce.github.io/lean4-mlir/blueprint/)**\n(or [PDF](https://brettkoonce.github.io/lean4-mlir/blueprint.pdf))\n— clickable dependency DAG for the full VJP proof suite, from `pdiv` primitives\nup to `vit_body_has_vjp_mat`. 30 axioms, 45 theorems, zero `sorry`s.\n\nLean 4 as a specification language for neural networks. Declare architecture\nin Lean, generate StableHLO MLIR (forward + loss + backward + optimizer all\nin one fused function), compile to GPU via IREE, train end-to-end. No Python\nruntime, no autograd library — the gradients are computed at codegen time\nin Lean.\n\nCompanion code for the upcoming book *Verified Deep Learning with Lean4*\n(follow-up to [Convolutional Neural Networks with Swift for TensorFlow](https://doi.org/10.1007/978-1-4842-6168-2), Apress).\n\n**Current version: `v0.5.2`** — first cross-backend-verified release. MNIST\nMLP *and* CNN training traces agree at the **float32 ULP floor**\nbetween two independent compilation pipelines (Lean→IREE→GPU vs\nLean→JAX→XLA) on both NVIDIA and AMD hardware. See\n[`traces/CROSS_BACKEND_RESULTS.md`](traces/CROSS_BACKEND_RESULTS.md)\nfor the full four-corner verification tables.\n\nOn top of that, a differential-test suite in\n[`tests/vjp_oracle/`](tests/vjp_oracle/) uses JAX's `value_and_grad`\nas an oracle for the hand-derived VJPs in `LeanMlir/Proofs/`. 
Nine
test cases cover every axiom family — dense, conv, BN, maxPool,
residual (biPath), depthwise, SE (elementwise product), attention,
and the transformer block — each verified to 1–2 ULP of JAX autodiff.

## Three phases

This project went through three implementations of the same idea — "Lean 4 as a
specification language for deep learning" — each shedding more dependencies
than the last.

**Phase 1 — Pure Lean 4.** [`mnist-lean4/`](mnist-lean4/): everything in Lean,
`Float64` as the only datatype, hand-written gradients, C FFI to OpenBLAS /
hipBLAS for the matmuls. Worked end-to-end on MNIST through ResNet-34 but
performance was poor — every operation crossed the FFI boundary, no fusion,
no autodiff, no JIT.

**Phase 2 — Lean → JAX.** [`jax/`](jax/): Lean as a metaprogramming layer
that emits idiomatic JAX Python (`jax/Jax/Codegen.lean`, ~1100 lines). The
generated script gets `value_and_grad` autodiff and XLA JIT for free, runs
on any JAX-supported device. Trades the pure-Lean story for a working stack
and real GPU performance. See [`jax/README.md`](jax/README.md) for details.

**Phase 3 — Lean → StableHLO → MLIR → device.** *(this README)* No Python
runtime at all. Lean directly emits StableHLO MLIR, IREE compiles it to a
GPU flatbuffer, a thin C FFI loads and runs it. The pure-math version of
phase 2 — autodiff is done at codegen time in Lean (`LeanMlir/MlirCodegen.lean`,
~5000 lines), not at runtime by a framework. See [`RESULTS.md`](RESULTS.md)
for the per-architecture numbers.

The proofs that the generated MLIR is mathematically correct live in
[`LeanMlir/Proofs/`](LeanMlir/Proofs/) — chapter-by-chapter VJP correctness
proofs for tensor ops, MLP, CNN, residual, batch norm, depthwise, SE,
LayerNorm, and attention.
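As a rough illustration of what these chapters prove (hypothetical definitions; the repo's actual `*_has_vjp` statements differ), a VJP pairs a forward map with a pullback, and composition follows the chain rule:

```lean
-- Hypothetical sketch, not the repo's definitions: a VJP bundles a
-- forward map with a pullback that carries output cotangents back
-- to input cotangents.
structure VJP (α β : Type) where
  value    : α → β
  pullback : α → β → α

def reluVJP : VJP Float Float where
  value    := fun x => if x > 0 then x else 0
  pullback := fun x dy => if x > 0 then dy else 0

-- Chain rule: forward maps compose forwards, pullbacks compose in reverse.
def VJP.comp {α β γ : Type} (g : VJP β γ) (f : VJP α β) : VJP α γ where
  value    := fun x => g.value (f.value x)
  pullback := fun x dz => f.pullback x (g.pullback (f.value x) dz)
```

Composition-style results like `vjp_comp` are what let each chapter derive only one new primitive gradient from scratch and get the rest by chaining earlier tools.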
The codegen and the proofs were written
independently and arrived at the same decomposition: every backward pass
factors through the standalone gradient of one new primitive per
architecture (softmax for attention, the spatial reductions for BN, the
rank-1 collapse for SE), and everything else is composition via the chain
rule on tools from earlier chapters.

## Pipeline

```
Lean NetSpec  (~15 lines)
   │
   │  MlirCodegen.generateTrainStep
   ▼
StableHLO MLIR  (500 KB - 2 MB of text, forward+loss+backward+Adam fused)
   │
   │  iree-compile (~10-15 min for ROCm gfx1100)
   ▼
VMFB flatbuffer  (1.8-3 MB)
   │
   │  IREE runtime via libiree_ffi.so
   ▼
GPU execution  (HIP/ROCm or CUDA)
```

The same Lean → MLIR pipeline handles every architecture. Adding a new
architecture means extending `LeanMlir/MlirCodegen.lean` with:
- forward emission for the new layer types
- VJP / backward emission
- `FwdRec` recording for backward intermediates

The training executable, FFI, and IREE runtime are unchanged.

## Cross-backend verification

Phase 2 and Phase 3 share the same Lean `NetSpec` ADT but compile through
*completely independent* stacks (JAX/XLA vs IREE). Differential testing
confirms both stacks produce the same training dynamics on the same input,
for both MLP (670K params, 12 epochs) and CNN (1.7M params with conv+BN,
15 epochs):

| diff                              | MLP step 1 Δ | CNN step 1 Δ |
|-----------------------------------|--------------|--------------|
| phase 2 (JAX)  vs phase 3 (IREE)  | ~2e-7        | ~1e-5 to 1e-4 |
| phase 3 ROCm   vs phase 3 CUDA    | **0**        | **0**        |
| phase 2 CPU    vs phase 2 CUDA    | ~4e-6        | ~1e-4        |

MLP hits the float32 ULP floor because it's dense-only.
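That floor is a property of float32 itself: addition is not associative, so two correct reduction trees can legitimately disagree in the last bits. A minimal, deterministic illustration:

```python
import numpy as np

# float32 addition is not associative: the spacing between adjacent
# float32 values near 1e8 is 8, so -1e8 + 1 rounds straight back to -1e8.
a, b, c = np.float32(1e8), np.float32(-1e8), np.float32(1.0)
left = (a + b) + c    # (0.0) + 1.0  -> 1.0
right = a + (b + c)   # 1e8 + (-1e8) -> 0.0

print(left, right)    # 1.0 0.0
```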
CNN's noise
floor is looser by ~100× because each conv-BN layer does two
reductions over ~100k-element tensors and XLA's reduction trees differ
from IREE's — both pipelines do correct math, just with different
summation orders. Phase 3 on ROCm vs phase 3 on CUDA is bit-identical
at step 1 on both networks. Reproducible in 5 minutes via
[`traces/CROSS_BACKEND_RESULTS.md`](traces/CROSS_BACKEND_RESULTS.md).

### VJP oracle

A separate per-axiom differential test in
[`tests/vjp_oracle/`](tests/vjp_oracle/) uses JAX's `value_and_grad` as
a correctness oracle for every hand-derived backward pass in
`LeanMlir/Proofs/`. Each test case is a minimal NetSpec exercising one
axiom in isolation; the oracle compares the step-2 loss (the first step
whose value depends on the backward pass) produced by the hand-derived
gradients against the value produced by phase 2's autodiff.

Nine cases, all green on mars (ROCm + CPU) and ares (CUDA):

| case | axiom | step 2 Δ |
|---|---|---|
| `dense` | `dense_has_vjp` + `softmaxCE_grad` | 2.7e-07 |
| `dense-relu` | `relu_has_vjp` + `vjp_comp` | 4.8e-07 |
| `conv` | `conv2d_has_vjp` + `flatten_has_vjp` | 2.2e-07 |
| `convbn` | `convBn_has_vjp` (BN-mode) | 2.2e-06 |
| `conv-pool` | `maxPool_has_vjp` (argmax tiebreaks) | 1.2e-04 |
| `residual` | `biPath_has_vjp` (additive fan-in) | 3.1e-07 |
| `depthwise` | depthwise-conv VJP via `.invertedResidual` | 1.1e-05 |
| `mbconv` | `elemwiseProduct_has_vjp` (SE gate) + Swish | 1.6e-06 |
| `attention` | patchEmbed + `transformerBlock_has_vjp_mat` + classifier | 1.8e-07 |

Run with `tests/vjp_oracle/run.sh`.
Adding a new axiom means dropping
a minimal Lean spec under `tests/vjp_oracle/phase{2,3}/` plus one
line in the lakefiles — see
[`tests/vjp_oracle/README.md`](tests/vjp_oracle/README.md).

The oracle also surfaced a real `heInitParams` bug (shape-peek
heuristic misfiring at patchEmbed + transformer-block boundaries) and
a JAX-ROCm crash on gfx1100 (filed as
[ROCm/MIOpen#3955](https://github.com/ROCm/MIOpen/issues/3955); repro
lives at [`upstream-issues/2026-04-rocm-miopen-conv-segv/`](upstream-issues/2026-04-rocm-miopen-conv-segv/)).

## Results (Imagenette, 10 classes, 224×224)

Trained from scratch on a single AMD 7900 XTX (gfx1100), Adam, batch 32,
cosine LR + 3-epoch warmup, label smoothing 0.1, weight decay 1e-4, random
crop (256→224) + horizontal flip, **running BN stats for eval**.

| Model | Params | Val accuracy |
|---|---|---|
| ResNet-34 | 21.3M | **90.29%** |
| ResNet-50 | 23.5M | **89.40%** |
| EfficientNetV2-S | 38.2M | **88.50%** |
| EfficientNet-B0 | 7.2M | **87.58%** |
| MobileNetV2 | 2.2M | **87.09%** |
| MobileNetV3-Large | 3.0M | **86.48%** |
| MobileNetV4-Medium | 4.1M | **84.58%** |
| ViT-Tiny | 5.5M | **71.70%** |

Per-epoch eval histories and ablation tables in [`RESULTS.md`](RESULTS.md).

## Quick start

### 1. Install Lean 4

```bash
curl https://raw.githubusercontent.com/leanprover/elan/master/elan-init.sh -sSf | sh
```

### 2. Install IREE

You need the IREE runtime built for your GPU (CUDA or ROCm). The FFI shim
in `ffi/` links against `libiree_runtime_unified.a` from the IREE build tree.
See [`IREE_BUILD.md`](IREE_BUILD.md) for build instructions.

### 3. Get data

```bash
./download_mnist.sh        # MNIST (Ch 3-4 trainers)
./download_cifar.sh        # CIFAR-10 (Ch 5 trainers)
./download_imagenette.sh   # Imagenette 320px → preprocessed binary (Ch 6+)
```

### 4. Build a trainer

```bash
lake build resnet34-train
```

This compiles the Lean trainer (which generates MLIR + drives IREE + runs
the training loop). Other targets, in roughly book order:
`mnist-mlp-train`, `mnist-cnn-train`, `cifar-cnn-train`, `cifar-bn-train`,
`resnet50-train`, `mobilenet-v2-train`, `mobilenet-v3-train`,
`mobilenet-v4-train`, `efficientnet-train`, `efficientnet-v2-train`,
`vgg-train`, `vit-tiny-train`.

### 5. Run

The first invocation generates and compiles the vmfbs (slow — IREE
compilation takes 10-15 min for ResNet-sized models). Subsequent
runs reuse the cached vmfbs unless you clear `.lake/build/`.

```bash
HIP_VISIBLE_DEVICES=0 IREE_BACKEND=rocm .lake/build/bin/resnet34-train

# Or via the included shell wrapper that sets the env vars correctly
bash run.sh resnet34                  # GPU 0, ROCm (defaults)
bash run.sh efficientnet-v2 1 cuda    # GPU 1, CUDA
```

For CUDA, set `IREE_BACKEND=cuda` (the default) and use `CUDA_VISIBLE_DEVICES`.

## Lean specs

The same `NetSpec` type is used by all three phases.
A spec is a list of
`Layer` values:

```lean
def resnet34 : NetSpec where
  name := "ResNet-34"
  imageH := 224
  imageW := 224
  layers := [
    .convBn 3 64 7 2 .same,
    .maxPool 2 2,
    .residualBlock  64  64 3 1,
    .residualBlock  64 128 4 2,
    .residualBlock 128 256 6 2,
    .residualBlock 256 512 3 2,
    .globalAvgPool,
    .dense 512 10 .identity
  ]

def vitTiny : NetSpec where
  name := "ViT-Tiny"
  imageH := 224
  imageW := 224
  layers := [
    .patchEmbed 3 192 16 196,             -- (224/16)^2 = 196 patches
    .transformerEncoder 192 3 768 12,     -- 12 blocks, 3 heads, MLP dim 768
    .dense 192 10 .identity
  ]

def mobilenetV4Medium : NetSpec where
  name := "MobileNet V4-Medium"
  imageH := 224
  imageW := 224
  layers := [
    .convBn 3 32 3 2 .same,
    .fusedMbConv 32 48 4 3 2 1 false,
    .uib  48  80 4 2 3 5,    -- ExtraDW
    .uib  80 160 6 2 0 3,    -- IB (= MBConv)
    .uib 160 160 4 1 5 0,    -- ConvNeXt
    .uib 160 160 4 1 0 0,    -- FFN
    -- ... 11 more UIB blocks
    .convBn 256 1280 1 1 .same,
    .globalAvgPool,
    .dense 1280 10 .identity
  ]
```

## Project structure

```
lean4-mlir/
├── README.md               -- this file (phase 3)
├── RESULTS.md              -- per-architecture eval histories + ablations
├── IREE_BUILD.md           -- how to build libiree_ffi.so from scratch
├── ROCM.md                 -- ROCm setup notes
├── BENCHMARK.md            -- ROCm vs CUDA performance comparison
├── lakefile.lean           -- Lake build config (libraries + ~30 execs)
│
├── LeanMlir.lean           -- umbrella module
├── LeanMlir/
│   ├── MlirCodegen.lean    -- ~5000 lines, NetSpec → StableHLO MLIR
│   ├── IreeRuntime.lean    -- Lean ↔ libiree_ffi.so bindings
│   ├── F32Array.lean       -- ByteArray-backed float32 helpers
│   ├── Spec.lean           -- NetSpec / Layer / param-counting
│   ├── Types.lean          -- core types (Layer, Activation, Padding, ...)
│   ├── MnistData.lean      -- IDX file loader (older training paths)
│   └── Proofs/             -- VJP correctness proofs (~2100 lines)
│       ├── Tensor.lean
│       ├── MLP.lean
│       ├── CNN.lean
│       ├── Residual.lean
│       ├── BatchNorm.lean
│       ├── Depthwise.lean
│       ├── SE.lean
│       ├── LayerNorm.lean
│       └── Attention.lean
│
├── Main*Train.lean         -- phase 3 trainers (one per architecture)
│   ├── MainResnetTrain.lean
│   ├── MainResnet50Train.lean
│   ├── MainMobilenetV2Train.lean
│   ├── MainMobilenetV3Train.lean
│   ├── MainMobilenetV4Train.lean
│   ├── MainEfficientNetTrain.lean
│   ├── MainEfficientNetV2Train.lean
│   ├── MainVitTrain.lean
│   ├── MainVggTrain.lean
│   ├── MainMnistMlpTrain.lean
│   ├── MainMnistCnnTrain.lean
│   ├── MainCifarCnnBnTrain.lean
│   ├── MainCifarCnnTrain.lean
│   └── MainAblation.lean
│
├── tests/                  -- unit tests + smoke tests + differential tests
│   ├── Test*.lean          -- runtime / FFI / codegen sanity tests
│   ├── BenchResnet.lean
│   ├── diff_traces.py      -- JSONL trace diff helper
│   ├── cross_backend_mnist_mlp.sh
│   └── vjp_oracle/         -- JAX-autodiff oracle for hand-derived VJPs
│       ├── README.md
│       ├── run.sh
│       ├── diff_step.py
│       ├── phase3/         -- Lean→IREE test trainers
│       └── phase2/         -- (mirrored at jax/tests/vjp_oracle/phase2/)
│
├── upstream-issues/        -- isolated reproducers + backtraces for bugs
│   └── 2026-04-rocm-miopen-conv-segv/  -- ROCm/MIOpen#3955
│
├── ffi/
│   ├── iree_ffi.{c,h}      -- IREE runtime wrapper
│   ├── iree_lean_ffi.c     -- Lean FFI bindings
│   ├── f32_helpers.c       -- data loading, He init, EMA, augmentation
│   └── libiree_ffi.so      -- compiled shared library
│
├── jax/                    -- phase 2 (Lean → JAX Python)
│   ├── README.md
│   ├── Jax.lean
│   ├── Jax/{Codegen,Runner}.lean
│   ├── Main*.lean          -- 14 JAX-driven architecture specs
│   └── tests/vjp_oracle/phase2/  -- phase-2 mirror of oracle specs
│
├── mnist-lean4/            -- phase 1 (pure Lean 4 + C BLAS)
│
├── traces/                 -- committed cross-backend training traces
│   ├── CROSS_BACKEND_RESULTS.md
│   ├── TRACE_FORMAT.md
│   └── mnist_{mlp,cnn}.*.jsonl
│
├── data/                   -- downloaded + preprocessed datasets
└── run_*.sh                -- shell wrappers for tmux env propagation
```

## Supported layers (phase 3 codegen)

| Layer | Description |
|-------|-------------|
| `dense` | Fully connected (with optional activation) |
| `conv2d` | Standard convolution |
| `convBn` | Conv + batch norm + ReLU/ReLU6/Swish/h-swish |
| `residualBlock` | BasicBlock (ResNet-18/34) |
| `bottleneckBlock` | Bottleneck (ResNet-50/101/152) |
| `invertedResidual` | Expand → depthwise → project + skip (MobileNetV2) |
| `mbConv` | + Squeeze-Excitation, Swish (EfficientNet) |
| `mbConvV3` | + h-swish + h-sigmoid SE (MobileNetV3, exact math) |
| `fusedMbConv` | k×k regular conv replaces (1×1 expand + depthwise) (EfficientNetV2) |
| `uib` | Universal Inverted Bottleneck — pre-DW? + expand + post-DW? + project (MobileNetV4) |
| `patchEmbed` | Conv patch projection + CLS token + positional embedding (ViT) |
| `transformerEncoder` | LN → MHSA → + → LN → MLP → +, with exact tanh-form GELU |
| `maxPool`, `globalAvgPool`, `flatten` | Structural |

Activations supported with exact backward: ReLU, ReLU6, Swish, h-swish,
h-sigmoid, GELU (tanh form). Layer-norm and batch-norm both have proper
VJPs and (for BN) running statistics for eval.

## Lean version

Tested with Lean 4.29.0 / Lake 5.0.0, IREE built from source against
ROCm 7.2.0 / gfx1100.

## Citing this work

```bibtex
@software{koonce2026leanmlir,
  author  = {Brett Koonce and Claude Code},
  title   = {Verified Deep Learning with Lean4: Formal Backpropagation from MLP to Attention, via MLIR},
  url     = {https://github.com/brettkoonce/lean4-mlir},
  version = {0.5.2},
  year    = {2026},
}
```