Lean 4 → MLIR → GPU

Interactive proof blueprint: brettkoonce.github.io/lean4-mlir/blueprint/ (or PDF) — clickable dependency DAG for the full VJP proof suite, from pdiv primitives up to vit_body_has_vjp_mat. 30 axioms, 45 theorems, zero sorrys.

Lean 4 as a specification language for neural networks. Declare architecture in Lean, generate StableHLO MLIR (forward + loss + backward + optimizer all in one fused function), compile to GPU via IREE, train end-to-end. No Python runtime, no autograd library — the gradients are computed at codegen time in Lean.

Companion code for the upcoming book Verified Deep Learning with Lean4 (follow-up to Convolutional Neural Networks with Swift for TensorFlow, Apress).

Current version: v0.5.2 — first cross-backend-verified release. MNIST MLP and CNN training traces agree at the float32 ULP floor between two independent compilation pipelines (Lean→IREE→GPU vs Lean→JAX→XLA) on both NVIDIA and AMD hardware. See traces/CROSS_BACKEND_RESULTS.md for the full four-corner verification tables.

On top of that, a differential-test suite in tests/vjp_oracle/ uses JAX's value_and_grad as an oracle for the hand-derived VJPs in LeanMlir/Proofs/. Nine test cases cover every axiom family — dense, conv, BN, maxPool, residual (biPath), depthwise, SE (elementwise product), attention, and the transformer block — each verified to 1–2 ULP of JAX autodiff.

Three phases

This project went through three implementations of the same idea — "Lean 4 as a specification language for deep learning" — each shedding more dependencies than the last.

Phase 1 — Pure Lean 4. mnist-lean4/: everything in Lean, Float64 as the only datatype, hand-written gradients, C FFI to OpenBLAS / hipBLAS for the matmuls. Worked end-to-end on MNIST through ResNet-34 but performance was poor — every operation crossed the FFI boundary, no fusion, no autodiff, no JIT.

Phase 2 — Lean → JAX. jax/: Lean as a metaprogramming layer that emits idiomatic JAX Python (jax/Jax/Codegen.lean, ~1100 lines). The generated script gets value_and_grad autodiff and XLA JIT for free, runs on any JAX-supported device. Trades the pure-Lean story for a working stack and real GPU performance. See jax/README.md for details.

Phase 3 — Lean → StableHLO → MLIR → device (this README). No Python runtime at all. Lean directly emits StableHLO MLIR, IREE compiles it to a GPU flatbuffer, a thin C FFI loads and runs it. The pure-math version of phase 2 — autodiff is done at codegen time in Lean (LeanMlir/MlirCodegen.lean, ~5000 lines), not at runtime by a framework. See RESULTS.md for the per-architecture numbers.

The proofs that the generated MLIR is mathematically correct live in LeanMlir/Proofs/ — chapter-by-chapter VJP correctness proofs for tensor ops, MLP, CNN, residual, batch norm, depthwise, SE, LayerNorm, and attention. The codegen and the proofs were written independently and arrived at the same decomposition: every backward pass factors through the standalone gradient of one new primitive per architecture (softmax for attention, the spatial reductions for BN, the rank-1 collapse for SE), and everything else is composition via the chain rule on tools from earlier chapters.
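The chain-rule shape that every backward pass factors through can be sketched schematically. The `VJP` structure and `comp` below are illustrative stand-ins, not the actual predicate statements in LeanMlir/Proofs/:

```lean
-- Schematic only: a forward map packaged with its vector-Jacobian product.
-- (The proof suite's actual predicates, e.g. dense_has_vjp, are stated
-- differently; this just shows the composition every chapter reuses.)
structure VJP (α β : Type) where
  fwd : α → β
  bwd : α → β → α   -- pull a cotangent at the output back to the input

-- Chain rule: run g's pullback at f's output, then f's pullback.
def VJP.comp {α β γ : Type} (g : VJP β γ) (f : VJP α β) : VJP α γ where
  fwd := fun x => g.fwd (f.fwd x)
  bwd := fun x dy => f.bwd x (g.bwd (f.fwd x) dy)
```

Under this reading, "one new primitive per architecture" means one new `VJP` witness per chapter; everything else is `comp`.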

Pipeline

Lean NetSpec  (~15 lines)
   │
   │  MlirCodegen.generateTrainStep
   ▼
StableHLO MLIR  (500 KB - 2 MB of text, forward+loss+backward+Adam fused)
   │
   │  iree-compile (~10-15 min for ROCm gfx1100)
   ▼
VMFB flatbuffer  (1.8-3 MB)
   │
   │  IREE runtime via libiree_ffi.so
   ▼
GPU execution  (HIP/ROCm or CUDA)

The same Lean → MLIR pipeline handles every architecture. Adding a new architecture means extending LeanMlir/MlirCodegen.lean with:

  • forward emission for the new layer types
  • VJP / backward emission
  • FwdRec recording for backward intermediates

The training executable, FFI, and IREE runtime are unchanged.

Cross-backend verification

Phase 2 and Phase 3 share the same Lean NetSpec ADT but compile through completely independent stacks (JAX/XLA vs IREE). Differential testing confirms both stacks produce the same training dynamics on the same input, for both MLP (670K params, 12 epochs) and CNN (1.7M params with conv+BN, 15 epochs):

| diff | MLP step 1 Δ | CNN step 1 Δ |
|---|---|---|
| phase 2 (JAX) vs phase 3 (IREE) | ~2e-7 | ~1e-5 to 1e-4 |
| phase 3 ROCm vs phase 3 CUDA | 0 | 0 |
| phase 2 CPU vs phase 2 CUDA | ~4e-6 | ~1e-4 |

MLP hits the float32 ULP floor because it's dense-only. CNN's noise floor is looser by ~100× because each conv-BN layer does two reductions over ~100k-element tensors and XLA's reduction trees differ from IREE's — both pipelines do correct math, just with different summation orders. Phase 3 ROCm ≡ Phase 3 CUDA is bit-identical at step 1 on both networks. Reproducible in 5 minutes via traces/CROSS_BACKEND_RESULTS.md.
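The differing noise floors come down to float32 addition not being associative. A tiny standalone demonstration, independent of either stack:

```python
import numpy as np

# float32 addition is not associative, so two correct reduction trees can
# disagree in the low bits -- the XLA-vs-IREE gap described above.
a, b, c = np.float32(1e8), np.float32(-1e8), np.float32(1.0)
left = (a + b) + c     # the large terms cancel first, the 1.0 survives
right = a + (b + c)    # the 1.0 is absorbed into -1e8 and lost
print(left, right)     # 1.0 0.0

# The same effect on a realistic reduction: chain sum vs pairwise tree.
x = np.random.default_rng(0).standard_normal(1 << 16).astype(np.float32)
chain = np.float32(0.0)
for v in x:
    chain += v                               # left-to-right chain
tree = x
while tree.size > 1:
    tree = tree.reshape(-1, 2).sum(axis=1)   # float32 pairwise tree
print(float(chain) - float(tree[0]))         # small but generally nonzero
```

Both orders are "correct math"; neither matches the other bit-for-bit, which is why the cross-backend comparison targets the ULP floor rather than exact equality.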

VJP oracle

A separate per-axiom differential test in tests/vjp_oracle/ uses JAX's value_and_grad as a correctness oracle for every hand-derived backward pass in LeanMlir/Proofs/. Each test case is a minimal NetSpec exercising one axiom in isolation; the oracle compares step-2 loss (the first step whose value depends on the backward pass) against phase 2's autodiff-derived gradients.
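The oracle idea in miniature, for the dense + softmax-CE case: compare a hand-derived gradient against a reference on the same minimal model. The real harness uses jax.value_and_grad as the reference; in this dependency-free sketch a central finite difference stands in for it, and all names are illustrative:

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def loss(w, x, y):
    # mean cross-entropy of softmax(x @ w) against one-hot targets y
    p = softmax(x @ w)
    return -np.mean(np.sum(y * np.log(p), axis=-1))

def hand_grad(w, x, y):
    # hand-derived VJP: dL/dW = x^T (softmax(x w) - y) / batch
    return x.T @ (softmax(x @ w) - y) / x.shape[0]

rng = np.random.default_rng(0)
x = rng.standard_normal((8, 5))
w = rng.standard_normal((5, 3))
y = np.eye(3)[np.arange(8) % 3]

g = hand_grad(w, x, y)
num = np.zeros_like(w)                 # reference gradient, entry by entry
eps = 1e-6
for i in range(w.shape[0]):
    for j in range(w.shape[1]):
        d = np.zeros_like(w)
        d[i, j] = eps
        num[i, j] = (loss(w + d, x, y) - loss(w - d, x, y)) / (2 * eps)

print(np.max(np.abs(g - num)))         # agreement near float64 precision
```

The suite's version of this check runs the full phase-2 and phase-3 trainers and diffs the step-2 loss instead of the raw gradient, but the shape of the argument is the same.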

Nine cases, all green on mars (ROCm + CPU) and ares (CUDA):

| case | axiom | step 2 Δ |
|---|---|---|
| dense | dense_has_vjp + softmaxCE_grad | 2.7e-07 |
| dense-relu | relu_has_vjp + vjp_comp | 4.8e-07 |
| conv | conv2d_has_vjp + flatten_has_vjp | 2.2e-07 |
| convbn | convBn_has_vjp (BN-mode) | 2.2e-06 |
| conv-pool | maxPool_has_vjp (argmax tiebreaks) | 1.2e-04 |
| residual | biPath_has_vjp (additive fan-in) | 3.1e-07 |
| depthwise | depthwise-conv VJP via .invertedResidual | 1.1e-05 |
| mbconv | elemwiseProduct_has_vjp (SE gate) + Swish | 1.6e-06 |
| attention | patchEmbed + transformerBlock_has_vjp_mat + classifier | 1.8e-07 |

Run with tests/vjp_oracle/run.sh. Adding a new axiom means dropping a minimal Lean spec under tests/vjp_oracle/phase{2,3}/ plus one line in the lakefiles — see tests/vjp_oracle/README.md.

The oracle also surfaced a real heInitParams bug (shape-peek heuristic misfiring at patchEmbed + transformer-block boundaries) and a JAX-ROCm crash on gfx1100 (filed as ROCm/MIOpen#3955; repro lives at upstream-issues/2026-04-rocm-miopen-conv-segv/).

Results (Imagenette, 10 classes, 224×224)

Trained from scratch on a single AMD 7900 XTX (gfx1100), Adam, batch 32, cosine LR + 3-epoch warmup, label smoothing 0.1, weight decay 1e-4, random crop (256→224) + horizontal flip, running BN stats for eval.
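The learning-rate schedule named above (cosine decay with a 3-epoch linear warmup) can be sketched as follows. This is an illustrative reference, not the trainer's code — the actual schedule is emitted by the Lean codegen, and the hyperparameter names here are made up:

```python
import math

def lr_at(epoch, total_epochs, base_lr=1e-3, warmup_epochs=3):
    """Linear warmup for warmup_epochs, then cosine decay to zero."""
    if epoch < warmup_epochs:
        return base_lr * (epoch + 1) / warmup_epochs
    t = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
    return 0.5 * base_lr * (1 + math.cos(math.pi * t))

# warmup ramp, peak, midpoint, tail
print([round(lr_at(e, 30), 6) for e in (0, 2, 3, 16, 29)])
```

The warmup avoids large early Adam steps while the second-moment estimates are still noisy; the cosine tail anneals smoothly to near zero by the final epoch.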

| Model | Params | Val accuracy |
|---|---|---|
| ResNet-34 | 21.3M | 90.29% |
| ResNet-50 | 23.5M | 89.40% |
| EfficientNetV2-S | 38.2M | 88.50% |
| EfficientNet-B0 | 7.2M | 87.58% |
| MobileNetV2 | 2.2M | 87.09% |
| MobileNetV3-Large | 3.0M | 86.48% |
| MobileNetV4-Medium | 4.1M | 84.58% |
| ViT-Tiny | 5.5M | 71.70% |

Per-epoch eval histories and ablation tables in RESULTS.md.

Quick start

1. Install Lean 4

curl https://raw.githubusercontent.com/leanprover/elan/master/elan-init.sh -sSf | sh

2. Install IREE

You need the IREE runtime built for your GPU (CUDA or ROCm). The FFI shim in ffi/ links against libiree_runtime_unified.a from the IREE build tree. See IREE_BUILD.md for build instructions.

3. Get data

./download_mnist.sh        # MNIST (Ch 3-4 trainers)
./download_cifar.sh        # CIFAR-10 (Ch 5 trainers)
./download_imagenette.sh   # Imagenette 320px → preprocessed binary (Ch 6+)

4. Build a trainer

lake build resnet34-train

This compiles the Lean trainer (which generates MLIR + drives IREE + runs the training loop). Other targets, in roughly book order: mnist-mlp-train, mnist-cnn-train, cifar-cnn-train, cifar-bn-train, resnet50-train, mobilenet-v2-train, mobilenet-v3-train, mobilenet-v4-train, efficientnet-train, efficientnet-v2-train, vgg-train, vit-tiny-train.

5. Run

The first invocation generates and compiles the vmfbs (slow — IREE compilation takes 10-15 min for ResNet-sized models). Subsequent runs reuse the cached vmfbs unless you clear .lake/build/.

HIP_VISIBLE_DEVICES=0 IREE_BACKEND=rocm .lake/build/bin/resnet34-train

# Or via the included shell wrapper that sets the env vars correctly
bash run.sh resnet34                  # GPU 0, ROCm (defaults)
bash run.sh efficientnet-v2 1 cuda    # GPU 1, CUDA

For CUDA, set IREE_BACKEND=cuda (the default) and use CUDA_VISIBLE_DEVICES.

Lean specs

The same NetSpec type is used by all three phases. A spec is a list of Layer values:

def resnet34 : NetSpec where
  name := "ResNet-34"
  imageH := 224
  imageW := 224
  layers := [
    .convBn 3 64 7 2 .same,
    .maxPool 2 2,
    .residualBlock  64  64 3 1,
    .residualBlock  64 128 4 2,
    .residualBlock 128 256 6 2,
    .residualBlock 256 512 3 2,
    .globalAvgPool,
    .dense 512 10 .identity
  ]

def vitTiny : NetSpec where
  name := "ViT-Tiny"
  imageH := 224
  imageW := 224
  layers := [
    .patchEmbed 3 192 16 196,             -- (224/16)^2 = 196 patches
    .transformerEncoder 192 3 768 12,     -- 12 blocks, 3 heads, MLP dim 768
    .dense 192 10 .identity
  ]
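As a sanity check on the ViT-Tiny spec, a back-of-envelope parameter count. This is a sketch under common ViT conventions — biases on every projection, two LayerNorms per block plus a final one, and a positional table covering 196 patches plus the CLS token; the authoritative counting lives in LeanMlir/Spec.lean:

```python
# ViT-Tiny dimensions from the spec above
d, mlp, blocks, patches, classes = 192, 768, 12, 196, 10

patch_embed = 3 * 16 * 16 * d + d          # conv patch projection + bias
cls_and_pos = d + (patches + 1) * d        # CLS token + positional embedding
per_block = (2 * 2 * d                     # two LayerNorms (scale + shift)
             + d * 3 * d + 3 * d           # fused QKV projection
             + d * d + d                   # attention output projection
             + d * mlp + mlp               # MLP expand
             + mlp * d + d)                # MLP project
head = 2 * d + d * classes + classes       # final LN + linear classifier

total = patch_embed + cls_and_pos + blocks * per_block + head
print(f"{total:,}")                        # ~5.5M, consistent with the table
```

The per-block term dominates (~445K × 12 blocks); everything outside the encoder is under 200K parameters.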

def mobilenetV4Medium : NetSpec where
  name := "MobileNet V4-Medium"
  imageH := 224
  imageW := 224
  layers := [
    .convBn 3 32 3 2 .same,
    .fusedMbConv 32 48 4 3 2 1 false,
    .uib  48  80 4 2 3 5,    -- ExtraDW
    .uib  80 160 6 2 0 3,    -- IB (= MBConv)
    .uib 160 160 4 1 5 0,    -- ConvNeXt
    .uib 160 160 4 1 0 0,    -- FFN
    -- ... 11 more UIB blocks
    .convBn 256 1280 1 1 .same,
    .globalAvgPool,
    .dense 1280 10 .identity
  ]

Project structure

lean4-mlir/
├── README.md               -- this file (phase 3)
├── RESULTS.md              -- per-architecture eval histories + ablations
├── IREE_BUILD.md           -- how to build libiree_ffi.so from scratch
├── ROCM.md                 -- ROCm setup notes
├── BENCHMARK.md            -- ROCm vs CUDA performance comparison
├── lakefile.lean           -- Lake build config (libraries + ~30 execs)
│
├── LeanMlir.lean           -- umbrella module
├── LeanMlir/
│   ├── MlirCodegen.lean    -- ~5000 lines, NetSpec → StableHLO MLIR
│   ├── IreeRuntime.lean    -- Lean ↔ libiree_ffi.so bindings
│   ├── F32Array.lean       -- ByteArray-backed float32 helpers
│   ├── Spec.lean           -- NetSpec / Layer / param-counting
│   ├── Types.lean          -- core types (Layer, Activation, Padding, ...)
│   ├── MnistData.lean      -- IDX file loader (older training paths)
│   └── Proofs/             -- VJP correctness proofs (~2100 lines)
│       ├── Tensor.lean
│       ├── MLP.lean
│       ├── CNN.lean
│       ├── Residual.lean
│       ├── BatchNorm.lean
│       ├── Depthwise.lean
│       ├── SE.lean
│       ├── LayerNorm.lean
│       └── Attention.lean
│
├── Main*Train.lean         -- phase 3 trainers (one per architecture)
│   ├── MainResnetTrain.lean
│   ├── MainResnet50Train.lean
│   ├── MainMobilenetV2Train.lean
│   ├── MainMobilenetV3Train.lean
│   ├── MainMobilenetV4Train.lean
│   ├── MainEfficientNetTrain.lean
│   ├── MainEfficientNetV2Train.lean
│   ├── MainVitTrain.lean
│   ├── MainVggTrain.lean
│   ├── MainMnistMlpTrain.lean
│   ├── MainMnistCnnTrain.lean
│   ├── MainCifarCnnBnTrain.lean
│   ├── MainCifarCnnTrain.lean
│   └── MainAblation.lean
│
├── tests/                  -- unit tests + smoke tests + differential tests
│   ├── Test*.lean          -- runtime / FFI / codegen sanity tests
│   ├── BenchResnet.lean
│   ├── diff_traces.py      -- JSONL trace diff helper
│   ├── cross_backend_mnist_mlp.sh
│   └── vjp_oracle/         -- JAX-autodiff oracle for hand-derived VJPs
│       ├── README.md
│       ├── run.sh
│       ├── diff_step.py
│       ├── phase3/         -- Lean→IREE test trainers
│       └── phase2/         -- (mirrored at jax/tests/vjp_oracle/phase2/)
│
├── upstream-issues/        -- isolated reproducers + backtraces for bugs
│   └── 2026-04-rocm-miopen-conv-segv/  -- ROCm/MIOpen#3955
│
├── ffi/
│   ├── iree_ffi.{c,h}      -- IREE runtime wrapper
│   ├── iree_lean_ffi.c     -- Lean FFI bindings
│   ├── f32_helpers.c       -- data loading, He init, EMA, augmentation
│   └── libiree_ffi.so      -- compiled shared library
│
├── jax/                    -- phase 2 (Lean → JAX Python)
│   ├── README.md
│   ├── Jax.lean
│   ├── Jax/{Codegen,Runner}.lean
│   ├── Main*.lean          -- 14 JAX-driven architecture specs
│   └── tests/vjp_oracle/phase2/  -- phase-2 mirror of oracle specs
│
├── mnist-lean4/            -- phase 1 (pure Lean 4 + C BLAS)
│
├── traces/                 -- committed cross-backend training traces
│   ├── CROSS_BACKEND_RESULTS.md
│   ├── TRACE_FORMAT.md
│   └── mnist_{mlp,cnn}.*.jsonl
│
├── data/                   -- downloaded + preprocessed datasets
└── run_*.sh                -- shell wrappers for tmux env propagation

Supported layers (phase 3 codegen)

| Layer | Description |
|---|---|
| dense | Fully connected (with optional activation) |
| conv2d | Standard convolution |
| convBn | Conv + batch norm + ReLU/ReLU6/Swish/h-swish |
| residualBlock | BasicBlock (ResNet-18/34) |
| bottleneckBlock | Bottleneck (ResNet-50/101/152) |
| invertedResidual | Expand → depthwise → project + skip (MobileNetV2) |
| mbConv | + Squeeze-Excitation, Swish (EfficientNet) |
| mbConvV3 | + h-swish + h-sigmoid SE (MobileNetV3, exact math) |
| fusedMbConv | k×k regular conv replaces (1×1 expand + depthwise) (EfficientNetV2) |
| uib | Universal Inverted Bottleneck — pre-DW? + expand + post-DW? + project (MobileNetV4) |
| patchEmbed | Conv patch projection + CLS token + positional embedding (ViT) |
| transformerEncoder | LN → MHSA → + → LN → MLP → +, with exact tanh-form GELU |
| maxPool, globalAvgPool, flatten | Structural |

Activations supported with exact backward: ReLU, ReLU6, Swish, h-swish, h-sigmoid, GELU (tanh form). Layer-norm and batch-norm both have proper VJPs and (for BN) running statistics for eval.
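For instance, the tanh-form GELU and its closed-form derivative, as plain Python for reference (the actual forward and backward are emitted as StableHLO from Lean):

```python
import math

K = math.sqrt(2.0 / math.pi)

def gelu(x):
    # tanh-form GELU: 0.5 x (1 + tanh(K (x + 0.044715 x^3)))
    return 0.5 * x * (1.0 + math.tanh(K * (x + 0.044715 * x**3)))

def gelu_grad(x):
    # exact derivative of the tanh form (product + chain rule)
    u = K * (x + 0.044715 * x**3)
    t = math.tanh(u)
    return 0.5 * (1.0 + t) + 0.5 * x * (1.0 - t * t) * K * (1.0 + 3 * 0.044715 * x**2)

# spot-check the closed form against a central finite difference
for x in (-2.0, -0.5, 0.0, 1.0, 3.0):
    h = 1e-6
    fd = (gelu(x + h) - gelu(x - h)) / (2 * h)
    assert abs(gelu_grad(x) - fd) < 1e-6
```

Emitting the exact derivative (rather than a sigmoid approximation of it) is what keeps the phase-2/phase-3 diff at the ULP floor for attention models.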

Lean version

Tested with Lean 4.29.0 / Lake 5.0.0, IREE built from source against ROCm 7.2.0 / gfx1100.

Citing this work

@software{koonce2026leanmlir,
  author  = {Brett Koonce and Claude Code},
  title   = {Verified Deep Learning with Lean4: Formal Backpropagation from MLP to Attention, via MLIR},
  url     = {https://github.com/brettkoonce/lean4-mlir},
  version = {0.5.2},
  year    = {2026},
}