(This is 100% vibe-coded by AI.) I built it to show the PVS dev the current state of the art (SOTA).
LeanDidax
A pedagogical implementation of automatic differentiation in Lean 4, inspired by JAX's Autodidax tutorial.
For a comprehensive overview of the project and its components, see the Summary Document.
Overview
LeanDidax is an automatic differentiation library written in Lean 4. It provides both forward-mode and reverse-mode autodiff, so derivatives of functions built from the supported primitives are computed exactly (up to floating-point rounding) rather than approximated numerically.
The library is inspired by the JAX Autodidax tutorial, which explains how to build a simple automatic differentiation system from scratch.
Features
- Forward-mode autodiff: Efficient computation of derivatives using dual numbers
- Reverse-mode autodiff: Efficient computation of gradients using computational graphs
- Operator overloading: Use standard mathematical notation with autodiff types
- Vectorized operations: Support for batched operations and the `vmap` transformation
- Control flow primitives: Differentiable conditionals with `cond`, `select`, and `switch`
- Custom derivative rules: Define specialized derivatives for functions
- Rich function library with support for common mathematical functions:
  - Basic operations: +, -, *, /
  - Trigonometric functions: sin, cos, tan
  - Hyperbolic functions: sinh, cosh, tanh
  - Exponential and logarithmic functions: exp, log
  - Functions with a non-differentiable point (x = 0), handled specially: abs, relu
  - Activation functions for neural networks: sigmoid
- Differentiation transformations (higher-order functions): `grad`, `valueAndGrad`, `jvp`
Installation
Add LeanDidax as a dependency in your `lakefile.lean`:

```lean
require LeanDidax from git "https://github.com/yourusername/LeanDidax" @ "main"
```
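For context, a complete `lakefile.lean` might look like the sketch below. The package name `«my-project»` and library name `«MyProject»` are hypothetical placeholders, and the Git URL is the placeholder used in this README; only the `require` line comes from the installation step above.

```lean
import Lake
open Lake DSL

-- Hypothetical package name; replace with your own.
package «my-project» where
  -- package configuration options go here

-- Fetches LeanDidax from the repository's main branch
-- (placeholder URL from this README).
require LeanDidax from git "https://github.com/yourusername/LeanDidax" @ "main"

-- Hypothetical library name; replace with your own.
@[default_target]
lean_lib «MyProject» where
  -- library configuration options go here
```

After editing the lakefile, `lake update` fetches the dependency and `lake build` compiles the project.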
Usage
Basic Example: Forward-Mode Autodiff
```lean
import LeanDidax2.Basic
open LeanDidax2

-- Define a function to differentiate
def f (x : Value Float) : Value Float :=
  x * x + 2 * x + 1

-- Compute the value and derivative at x=3 (f'(x) = 2x + 2, so f'(3) = 8)
#eval
  let x := seed 3.0  -- Create a variable with tangent=1.0
  let result := f x
  (s!"Function value at x=3: {result.primal}",
   s!"Derivative at x=3: {result.tangent}")
```
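To show what the forward-mode example computes, here is a self-contained sketch that does not depend on LeanDidax. It assumes a hypothetical `Dual` structure standing in for `Value Float`, where constants carry tangent 0 and the seeded input carries tangent 1; `derivAt` is likewise a hypothetical helper in the spirit of the library's `grad`, not its actual definition.

```lean
-- Standalone sketch: `Dual` is a hypothetical stand-in for `Value Float`.
structure Dual where
  primal  : Float
  tangent : Float

namespace Dual

-- Constants do not depend on the input, so their tangent is 0.
def const (c : Float) : Dual := ⟨c, 0.0⟩

-- The input being differentiated starts with tangent 1.
def seed (x : Float) : Dual := ⟨x, 1.0⟩

-- Sum rule: (a + b)' = a' + b'.
instance : Add Dual := ⟨fun a b => ⟨a.primal + b.primal, a.tangent + b.tangent⟩⟩

-- Product rule: (a * b)' = a' * b + a * b'.
instance : Mul Dual :=
  ⟨fun a b => ⟨a.primal * b.primal, a.tangent * b.primal + a.primal * b.tangent⟩⟩

-- Numeric literals such as `2` and `1` become constants.
instance (n : Nat) : OfNat Dual n := ⟨const n.toFloat⟩

end Dual

-- Hypothetical helper: the derivative of `g` at `x`, read off the tangent.
def derivAt (g : Dual → Dual) (x : Float) : Float := (g (Dual.seed x)).tangent

def f (x : Dual) : Dual := x * x + 2 * x + 1

#eval (f (Dual.seed 3.0)).primal  -- f(3) = 16
#eval derivAt f 3.0               -- f'(3) = 2 * 3 + 2 = 8
```

One forward pass yields the value and the derivative with respect to the single seeded input. Differentiating with respect to n inputs this way takes n passes, which is why reverse mode is usually preferred for gradients of functions with many inputs.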
Reverse-Mode Autodiff Example
```lean
import LeanDidax2.Autodiff
open LeanDidax2
open LeanDidax2.ReverseMode

-- Define a function using a computational graph
def polynomial (x : Node) : Node :=
  let x2 := Node.Mul x x        -- x^2
  let two := Node.Leaf 2.0
  let twoX := Node.Mul two x    -- 2*x
  let sum1 := Node.Add x2 twoX  -- x^2 + 2*x
  let one := Node.Leaf 1.0
  Node.Add sum1 one             -- x^2 + 2*x + 1

-- Compute the gradient at x=3
#eval
  let x := 3.0
  let xNode := Node.Leaf x
  let polyGraph := polynomial xNode
  let gradResults := backward polyGraph 1.0
  -- x occurs three times in the tree, so sum every cotangent reported for a
  -- leaf with value x instead of keeping only the first match. (Leaves are
  -- matched by value, so a constant leaf equal to x would also be counted.)
  let grad := gradResults.foldl
    (fun acc pair => if pair.1 == x then acc + pair.2 else acc) 0.0
  s!"Gradient of f(x) = x^2 + 2x + 1 at x=3: {grad}"
```
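As a hand check of this example: `polynomial` builds a tree with no shared nodes, so the input appears as three separate leaves (two in x * x, one in 2 * x). If `backward` reports one cotangent per leaf occurrence (this README does not say whether it aggregates them), the gradient is the sum of those entries, each being the output cotangent 1 times the other factor of its product, which matches f'(x) = 2x + 2:

```latex
\frac{df}{dx}\Big|_{x=3}
  = \underbrace{1 \cdot 3}_{\text{left } x \text{ in } x \cdot x}
  + \underbrace{1 \cdot 3}_{\text{right } x \text{ in } x \cdot x}
  + \underbrace{1 \cdot 2}_{x \text{ in } 2 \cdot x}
  = 8 = 2 \cdot 3 + 2
```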
Vectorized Operations with vmap
```lean
import LeanDidax2.Batch
open LeanDidax2
open LeanDidax2.Batch

-- Define a function
def f (x : Value Float) : Value Float :=
  x * x + 2 * x + 1

-- Apply it to multiple inputs at once
-- (`#eval do` is required because the body is an IO loop)
#eval do
  let inputs := #[1.0, 2.0, 3.0, 4.0, 5.0]
  let results := vmap f inputs
  for i in [:inputs.size] do
    IO.println s!"{inputs[i]!}: value = {results[i]!.primal}, derivative = {results[i]!.tangent}"
```
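A `vmap` over scalar inputs can be sketched as seeding each element and mapping the function over the array. This self-contained version uses a hypothetical `Dual` stand-in for `Value Float` and a hypothetical `vmapSketch` helper; LeanDidax's actual `vmap` may be implemented differently.

```lean
-- Hypothetical stand-in for `Value Float`.
structure Dual where
  primal  : Float
  tangent : Float

instance : Add Dual := ⟨fun a b => ⟨a.primal + b.primal, a.tangent + b.tangent⟩⟩
instance : Mul Dual :=
  ⟨fun a b => ⟨a.primal * b.primal, a.tangent * b.primal + a.primal * b.tangent⟩⟩
-- Numeric literals are constants (tangent 0).
instance (n : Nat) : OfNat Dual n := ⟨⟨n.toFloat, 0.0⟩⟩

-- Hypothetical vmap: seed each input with tangent 1 and apply `f` per element.
def vmapSketch (f : Dual → Dual) (xs : Array Float) : Array Dual :=
  xs.map fun x => f ⟨x, 1.0⟩

def f (x : Dual) : Dual := x * x + 2 * x + 1

#eval do
  let inputs := #[1.0, 2.0, 3.0, 4.0, 5.0]
  -- Pair each input with its result and print both.
  for (x, r) in inputs.zip (vmapSketch f inputs) do
    IO.println s!"{x}: value = {r.primal}, derivative = {r.tangent}"
```

This per-element map gives the simplest correct semantics. JAX's `vmap` instead rewrites the computation so each primitive operates on the whole batch at once, which is where its performance benefit comes from.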
Custom Derivative Rules
```lean
import LeanDidax2.CustomRules
open LeanDidax2
open LeanDidax2.CustomRules

-- Define a function with a custom derivative rule
def f (x : Float) : Float := x * x * x * Float.sin x
def df (x : Float) : Float := 3 * x * x * Float.sin x + x * x * x * Float.cos x

-- Use the custom rule
#eval
  let x := 2.0
  let result := defCustomFn "cubic_sine" f df (seed x)
  (s!"Value: {result.primal}", s!"Derivative: {result.tangent}")
```
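A forward-mode custom rule needs only the primal function and its derivative: the chain rule multiplies the supplied derivative by the incoming tangent. The sketch below is self-contained and hypothetical (`customFn` and `Dual` are illustrative names, and unlike `defCustomFn` it takes no rule name); it shows why `f` and `df` are all such a rule requires.

```lean
-- Hypothetical stand-in for `Value Float`.
structure Dual where
  primal  : Float
  tangent : Float

-- Hypothetical helper: turn a primal function `g` and its derivative `dg`
-- into a forward-mode primitive. Chain rule: d/dt g(x(t)) = dg(x) * x'.
def customFn (g dg : Float → Float) (x : Dual) : Dual :=
  { primal := g x.primal, tangent := dg x.primal * x.tangent }

def f (x : Float) : Float := x * x * x * Float.sin x
def df (x : Float) : Float := 3 * x * x * Float.sin x + x * x * x * Float.cos x

-- Seed the input with tangent 1 to get f(2) and f'(2).
#eval let r := customFn f df ⟨2.0, 1.0⟩; (r.primal, r.tangent)
```

Multiplying by `x.tangent` is what keeps the rule correct when the custom function is composed with other operations, for example when its input is itself the output of `2 * x` (tangent 2), which doubles the reported derivative.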
Differentiable Control Flow
```lean
import LeanDidax2.ControlFlow
open LeanDidax2
open LeanDidax2.ControlFlow

-- Use differentiable conditionals
-- (`ControlFlow.cond` is qualified to avoid clashing with Lean's built-in `cond`)
#eval
  let x := seed 2.0
  let result := ControlFlow.cond (x.primal > 0.0)
    (fun _ => x * x)     -- x^2 when x > 0
    (fun _ => x * -1.0)  -- -x when x ≤ 0
  (s!"Value: {result.primal}", s!"Derivative: {result.tangent}")
```
Implementation Details
Value Type
The core of LeanDidax is the Value type, which tracks both the primal value and its tangent (derivative) information:
```lean
structure Value (α : Type) where
  primal : α   -- the value being computed
  tangent : α  -- its derivative with respect to the seeded input
```
Forward-Mode Autodiff
Forward-mode autodiff propagates a tangent alongside each primal value through every operation. For example, multiplication implements the product rule, (xy)' = x'y + xy':

```lean
def mul {α : Type} [Mul α] [Add α] (x y : Value α) : Value α :=
  { primal := x.primal * y.primal,
    tangent := x.tangent * y.primal + x.primal * y.tangent }
```
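The same pattern extends to the other primitives in the function list. The sketch below is self-contained and hypothetical (it uses a `Dual` stand-in rather than LeanDidax's own definitions); each rule is the standard calculus derivative, scaled by the incoming tangent as the chain rule requires.

```lean
-- Hypothetical stand-in for `Value Float`.
structure Dual where
  primal  : Float
  tangent : Float

-- Quotient rule: (x / y)' = (x' * y - x * y') / y^2.
def Dual.div (x y : Dual) : Dual :=
  { primal  := x.primal / y.primal,
    tangent := (x.tangent * y.primal - x.primal * y.tangent) / (y.primal * y.primal) }

-- Chain rule for unary functions: g(x)' = g'(x) * x'.
def Dual.sin (x : Dual) : Dual :=
  { primal := Float.sin x.primal, tangent := Float.cos x.primal * x.tangent }

def Dual.exp (x : Dual) : Dual :=
  { primal := Float.exp x.primal, tangent := Float.exp x.primal * x.tangent }

-- d/dx log x = 1 / x.
def Dual.log (x : Dual) : Dual :=
  { primal := Float.log x.primal, tangent := x.tangent / x.primal }

-- Composition: d/dx exp(sin x) at 0 is exp(sin 0) * cos 0 = 1.
#eval (Dual.exp (Dual.sin ⟨0.0, 1.0⟩)).tangent
```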
Reverse-Mode Autodiff
Reverse-mode autodiff builds a computational graph (here, a tree of `Node` values) and propagates cotangents backward from the output to the inputs to compute gradients:

```lean
inductive Node
  | Leaf (value : Float)
  | Add (left right : Node)
  | Mul (left right : Node)
  -- ...
```
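A backward pass over such a tree can be sketched as below. This is a hypothetical, self-contained version, not LeanDidax's `backward`: it uses its own `Tree` type with a dedicated `var` constructor for the input, so the input is never confused with a constant leaf of the same value, and it returns the summed gradient directly.

```lean
-- Hypothetical expression tree; `var` marks the single input.
inductive Tree where
  | var
  | const (c : Float)
  | add (l r : Tree)
  | mul (l r : Tree)

-- Forward pass: evaluate the tree at input x.
def Tree.eval (x : Float) : Tree → Float
  | .var => x
  | .const c => c
  | .add l r => l.eval x + r.eval x
  | .mul l r => l.eval x * r.eval x

-- Backward pass: push the cotangent `ct` from the root toward the leaves and
-- sum what reaches every occurrence of `var`. Constants absorb their share.
def Tree.backward (x : Float) : Tree → Float → Float
  | .var, ct => ct
  | .const _, _ => 0.0
  | .add l r, ct => l.backward x ct + r.backward x ct
  -- d(l * r)/dl = r and d(l * r)/dr = l
  | .mul l r, ct => l.backward x (ct * r.eval x) + r.backward x (ct * l.eval x)

-- x * x + 2 * x + 1
def poly : Tree := .add (.add (.mul .var .var) (.mul (.const 2.0) .var)) (.const 1.0)

#eval poly.backward 3.0 1.0  -- f'(3) = 2 * 3 + 2 = 8
```

For clarity this sketch recomputes subtree values with `eval` during the backward pass, which costs extra work on deep trees; a practical implementation would store the primal values from a forward pass and reuse them.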
Control Flow Primitives
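The library supports differentiable control flow. `cond` takes each branch as a thunk (`Unit → Value α`), so only the selected branch is evaluated and its tangent passes through unchanged:

```lean
def cond {α : Type} [Zero α]
    (pred : Bool)
    (trueBranch : Unit → Value α)
    (falseBranch : Unit → Value α) : Value α :=
  if pred then trueBranch () else falseBranch ()
```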
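Because only the selected thunk runs, piecewise functions such as `relu` and `abs` can be built from a `cond`-style primitive, and their derivative is the derivative of whichever branch is taken. The sketch below is self-contained and hypothetical: `condD`, `relu`, and `absD` are illustrative names (`condD` avoids clashing with Lean's built-in `cond`), and `Dual` stands in for `Value Float`.

```lean
-- Hypothetical stand-in for `Value Float`.
structure Dual where
  primal  : Float
  tangent : Float

-- Same shape as `cond` above: only the chosen thunk is evaluated.
def condD (pred : Bool) (onTrue onFalse : Unit → Dual) : Dual :=
  if pred then onTrue () else onFalse ()

-- relu(x) = x for x > 0, else the constant 0 (zero tangent).
def relu (x : Dual) : Dual :=
  condD (decide (x.primal > 0.0)) (fun _ => x) (fun _ => ⟨0.0, 0.0⟩)

-- |x| = x for x ≥ 0, else -x (tangent negated as well).
def absD (x : Dual) : Dual :=
  condD (decide (x.primal ≥ 0.0)) (fun _ => x) (fun _ => ⟨-x.primal, -x.tangent⟩)

-- Tangents: relu at 2 is 1, relu at -2 is 0, abs at -2 is -1.
#eval ((relu ⟨2.0, 1.0⟩).tangent, (relu ⟨-2.0, 1.0⟩).tangent, (absD ⟨-2.0, 1.0⟩).tangent)
```

At x = 0 the sketch returns the derivative of the branch the predicate selects (0 for `relu`, 1 for `absD`). That choice is a convention, since neither function is differentiable there.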
Contributing
Contributions are welcome! Please feel free to submit a Pull Request.
License
This project is licensed under the MIT License - see the LICENSE file for details.