rpuri4/lisp_compiler

TensorLisp: A Differentiable Tensor Compiler in OCaml

A from-scratch compiler for a Lisp dialect with first-class tensors, automatic differentiation, and bytecode optimization. Built to demonstrate compiler construction, type systems, and ML framework internals.

                    ┌─────────────────────────────────────────────────────────────┐
                    │                    TensorLisp Compiler                      │
                    │                                                             │
   Source Code      │  ┌─────────┐   ┌───────────┐   ┌──────────┐   ┌─────────┐  │   Result
  ─────────────────►│  │ Parser  │──►│ TypeCheck │──►│ Bytecode │──►│   VM    │──┼──────────►
  (let ((x ...))    │  │         │   │           │   │ + Optim  │   │         │  │  tensor[2x2]
   (grad ...))      │  └─────────┘   └───────────┘   └──────────┘   └─────────┘  │
                    │       │              │               │              │       │
                    │       ▼              ▼               ▼              ▼       │
                    │    S-expr         Typed AST      SSA-like IR    Autodiff   │
                    │     AST          + shapes       + optimizations   Tape     │
                    └─────────────────────────────────────────────────────────────┘

Features

  • Tensor-aware type system with shape inference and broadcast validation
  • Reverse-mode automatic differentiation for computing gradients
  • SSA-style bytecode with constant folding and operation fusion
  • Arena allocator with tuned GC for efficient tensor memory management
  • Pure OCaml implementation - no external dependencies beyond the standard library

Quick Start

# Build the project
dune build

# Run a program from a file
dune exec tensorc examples/grad.lsp

# Run from stdin
echo '(sum (tensor f64 (2 2) 1 2 3 4))' | dune exec tensorc

Language Reference

Tensor Literals

Create tensors with explicit dtype, shape, and values:

; Create a 2x2 matrix of float64
(tensor f64 (2 2) 1 0 0 1)

; Create a 1D vector
(tensor f64 (3) 1.5 2.5 3.5)

; Supported dtypes: f32, f64

Operations

Operation        Syntax                           Description
Addition         (add a b) or (+ a b)             Element-wise addition with broadcasting
Multiplication   (mul a b) or (* a b)             Element-wise multiplication with broadcasting
Matrix Multiply  (matmul a b)                     Matrix multiplication (rank-2 tensors)
Sum              (sum t)                          Reduce a tensor to a scalar by summing all elements
Gradient         (grad (lambda (x) body) input)   Compute the gradient of a scalar-valued function

Variables and Functions

; Let binding
(let ((x (tensor f64 (2 2) 1 0 0 1)))
  (matmul x x))

; Lambda (currently used with grad)
(grad (lambda (w) (sum (matmul w x))) x)

Example: Computing Gradients

; Compute gradient of f(w) = sum(w @ x) with respect to x
; where x is the 2x2 identity matrix

(let ((x (tensor f64 (2 2) 1 0 0 1)))
  (grad (lambda (w) (sum (matmul w x))) x))

Output:

=== Optimized Bytecode ===
0 CONST
4 GRAD

Result: tensor[2x2]

Architecture

Compilation Pipeline

┌──────────────────────────────────────────────────────────────────────────────┐
│                              COMPILATION PHASES                               │
├──────────────────────────────────────────────────────────────────────────────┤
│                                                                              │
│  1. PARSING (parser.ml)                                                      │
│     ┌─────────────────┐      ┌─────────────────┐      ┌─────────────────┐   │
│     │  Source Text    │ ──►  │    Tokenize     │ ──►  │   S-expression  │   │
│     │  "(add x y)"    │      │  LParen, Symbol │      │   AST nodes     │   │
│     └─────────────────┘      └─────────────────┘      └─────────────────┘   │
│                                                                              │
│  2. TYPE CHECKING (typecheck.ml)                                             │
│     ┌─────────────────┐      ┌─────────────────┐      ┌─────────────────┐   │
│     │   Untyped AST   │ ──►  │  Infer types &  │ ──►  │   Typed AST     │   │
│     │                 │      │  check shapes   │      │  + annotations  │   │
│     └─────────────────┘      └─────────────────┘      └─────────────────┘   │
│                                                                              │
│  3. BYTECODE GENERATION (bytecode.ml)                                        │
│     ┌─────────────────┐      ┌─────────────────┐      ┌─────────────────┐   │
│     │   Typed AST     │ ──►  │  Lower to IR    │ ──►  │  SSA-like       │   │
│     │                 │      │  + optimizations│      │  instructions   │   │
│     └─────────────────┘      └─────────────────┘      └─────────────────┘   │
│                                                                              │
│  4. EXECUTION (vm.ml)                                                        │
│     ┌─────────────────┐      ┌─────────────────┐      ┌─────────────────┐   │
│     │   Bytecode      │ ──►  │  Interpret +    │ ──►  │   Result        │   │
│     │   program       │      │  autodiff tape  │      │   tensor/scalar │   │
│     └─────────────────┘      └─────────────────┘      └─────────────────┘   │
│                                                                              │
└──────────────────────────────────────────────────────────────────────────────┘

Module Overview

Module        Lines   Purpose
ast.ml        ~60     Core types: expressions, literals, tensor shapes, type definitions
parser.ml     ~130    Tokenizer and recursive-descent S-expression parser
typecheck.ml  ~170    Hindley-Milner-style inference with tensor shape checking
bytecode.ml   ~190    IR lowering, peephole optimization, MatMul→Add fusion
vm.ml         ~115    Register-based interpreter with autodiff integration
autodiff.ml   ~90     Reverse-mode AD via computation graph and backpropagation
tensor.ml     ~145    Tensor operations: add, mul, matmul, sum, transpose
arena.ml      ~70     Bump allocator for tensor payload recycling
openblas.ml   ~30     BLAS shim with pure-OCaml GEMM fallback
compiler.ml   ~35     Pipeline orchestration and pretty-printing

Type System

The compiler implements tensor-aware type inference:

┌────────────────────────────────────────────────────────────────┐
│                         TYPE SYSTEM                            │
├────────────────────────────────────────────────────────────────┤
│                                                                │
│  Base Types:                                                   │
│    • Scalar     - single floating-point value                  │
│    • Tensor     - n-dimensional array with dtype and shape     │
│    • Unit       - void/no value                                │
│    • Function   - param types → return type                    │
│                                                                │
│  Shape Dimensions:                                             │
│    • Fixed n    - concrete dimension (e.g., Fixed 2)           │
│    • Symbolic s - named dimension for polymorphism             │
│                                                                │
│  Type Rules:                                                   │
│    add : Tensor[s1] × Tensor[s2] → Tensor[broadcast(s1,s2)]   │
│    matmul : Tensor[m,k] × Tensor[k,n] → Tensor[m,n]           │
│    sum : Tensor[...] → Scalar                                  │
│    grad : (τ → Scalar) × τ → τ                                 │
│                                                                │
└────────────────────────────────────────────────────────────────┘
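The matmul rule above can be sketched as a plain shape check. This is an illustrative sketch, not the actual code in typecheck.ml: dimensions are modeled as bare ints here (the real checker also handles symbolic dimensions), and the function name and error messages are invented.

```ocaml
(* Illustrative sketch of the matmul typing rule:
   Tensor[m,k] x Tensor[k,n] -> Tensor[m,n].
   Shapes are int lists; names are hypothetical, not typecheck.ml's. *)
type shape = int list

let matmul_shape (a : shape) (b : shape) : shape =
  match (a, b) with
  | [ m; k ], [ k'; n ] when k = k' -> [ m; n ]
  | [ _; _ ], [ _; _ ] -> failwith "matmul: inner dimensions disagree"
  | _ -> failwith "matmul: both operands must be rank-2"

let () =
  (* Tensor[2,3] x Tensor[3,4] -> Tensor[2,4] *)
  assert (matmul_shape [ 2; 3 ] [ 3; 4 ] = [ 2; 4 ])
```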

Automatic Differentiation

Reverse-mode autodiff builds a computation graph during the forward pass, then backpropagates gradients:

┌─────────────────────────────────────────────────────────────────────────────┐
│                    REVERSE-MODE AUTODIFF                                    │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  Forward Pass: Build computation graph                                      │
│                                                                             │
│       w ──────┐                                                             │
│               ▼                                                             │
│             MatMul ────► Sum ────► output (scalar)                          │
│               ▲                                                             │
│       x ──────┘                                                             │
│                                                                             │
│  Backward Pass: Propagate gradients (chain rule)                            │
│                                                                             │
│       ∂L/∂w ◄──┐                                                            │
│                │                                                            │
│             MatMul ◄─── Sum ◄─── ∂L/∂out = 1                                │
│                │         │                                                  │
│       ∂L/∂x ◄──┘         └─── ∂sum/∂input = ones_like(input)               │
│                                                                             │
│  Gradient Rules:                                                            │
│    • Add:    ∂L/∂a = ∂L/∂out,  ∂L/∂b = ∂L/∂out                             │
│    • Mul:    ∂L/∂a = ∂L/∂out * b,  ∂L/∂b = ∂L/∂out * a                     │
│    • MatMul: ∂L/∂A = ∂L/∂out @ Bᵀ,  ∂L/∂B = Aᵀ @ ∂L/∂out                   │
│    • Sum:    ∂L/∂x = broadcast(∂L/∂out, shape(x))                           │
│                                                                             │
└─────────────────────────────────────────────────────────────────────────────┘
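The Add and Mul gradient rules above can be demonstrated with a minimal scalar-only sketch (the real autodiff.ml works on tensors and has more ops). All names here are illustrative; nodes are visited in reverse creation order, which is a valid reverse topological order because every node is created after its inputs.

```ocaml
(* Minimal reverse-mode AD over scalars, mirroring the gradient rules
   above. Sketch only: names and types are not autodiff.ml's. *)
type node = { id : int; op : op; value : float; mutable grad : float }
and op = Input | Add of node * node | Mul of node * node

let counter = ref 0
let fresh () = incr counter; !counter

let input v = { id = fresh (); op = Input; value = v; grad = 0. }
let add a b = { id = fresh (); op = Add (a, b); value = a.value +. b.value; grad = 0. }
let mul a b = { id = fresh (); op = Mul (a, b); value = a.value *. b.value; grad = 0. }

(* Seed dL/dout = 1, then apply the chain rule in reverse id order. *)
let backward out nodes =
  out.grad <- 1.;
  List.iter
    (fun n ->
      match n.op with
      | Input -> ()
      | Add (a, b) ->
          a.grad <- a.grad +. n.grad;
          b.grad <- b.grad +. n.grad
      | Mul (a, b) ->
          a.grad <- a.grad +. (n.grad *. b.value);
          b.grad <- b.grad +. (n.grad *. a.value))
    (List.sort (fun a b -> compare b.id a.id) nodes)
```

For f(x) = x * x at x = 3, both Mul branches contribute, accumulating the expected gradient 2x = 6 in `x.grad`.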

Bytecode Optimizations

The compiler performs several optimizations on the IR:

┌─────────────────────────────────────────────────────────────────────────────┐
│                         BYTECODE OPTIMIZATIONS                              │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  1. CONSTANT FOLDING                                                        │
│     Before:  0 CONST [1,2]                                                  │
│              1 CONST [3,4]                                                  │
│              2 ADD 0 1                                                      │
│                                                                             │
│     After:   0 CONST [4,6]    ◄── folded at compile time                   │
│                                                                             │
│  2. BROADCAST DETECTION                                                     │
│     When shapes differ, emit BCAST_ADD instead of ADD                       │
│     to hint at runtime broadcast requirements                               │
│                                                                             │
│  3. MATMUL+ADD FUSION                                                       │
│     Before:  0 MATMUL a b                                                   │
│              1 ADD 0 c                                                      │
│                                                                             │
│     After:   0 MATMUL a b                                                   │
│              1 FUSED_ADD 0 c  ◄── can use fused GEMM kernel                │
│                                                                             │
└─────────────────────────────────────────────────────────────────────────────┘
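The constant-folding transformation above can be sketched as a single peephole pass over a toy scalar IR. This is an assumption-laden sketch, not bytecode.ml's actual instruction set: the IR here has only `Const` and `Add`, and operands name earlier instruction indices.

```ocaml
(* Toy peephole constant folding: an ADD whose operands are both
   constants is replaced by a folded CONST. Instruction names are
   illustrative, not bytecode.ml's. *)
type instr =
  | Const of float
  | Add of int * int  (* operands are indices of earlier instructions *)

let fold_constants (prog : instr array) : instr array =
  Array.map
    (function
      | Add (a, b) as i -> (
          match (prog.(a), prog.(b)) with
          | Const x, Const y -> Const (x +. y)  (* folded at compile time *)
          | _ -> i)
      | i -> i)
    prog
```

After folding, `Add (0, 1)` over two constants becomes a `Const`; a later dead-code pass would then drop the now-unreferenced operand constants, matching the before/after shown above.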

Memory Management

The arena allocator reduces GC pressure for tensor-heavy workloads:

┌─────────────────────────────────────────────────────────────────────────────┐
│                           ARENA ALLOCATOR                                   │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  ┌──────────────────────────────────────────────────────────────────────┐  │
│  │                         Arena Memory Pool                             │  │
│  │  ┌────────┬────────┬────────┬────────┬─────────────────────────────┐ │  │
│  │  │ Tensor │ Tensor │ Tensor │ Tensor │      Free Space             │ │  │
│  │  │   A    │   B    │   C    │   D    │         ───►                │ │  │
│  │  └────────┴────────┴────────┴────────┴─────────────────────────────┘ │  │
│  │                                        ▲                              │  │
│  │                                        │ bump pointer                 │  │
│  └──────────────────────────────────────────────────────────────────────┘  │
│                                                                             │
│  Benefits:                                                                  │
│    • O(1) allocation via bump pointer                                       │
│    • Reduced GC pause times (~35% improvement in stress tests)              │
│    • Tensor payloads recycled without individual deallocation               │
│                                                                             │
│  GC Tuning:                                                                 │
│    • Minor heap: 4MB (vs default 256KB)                                     │
│    • Space overhead: 50% (more aggressive collection)                       │
│                                                                             │
└─────────────────────────────────────────────────────────────────────────────┘
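The bump-pointer scheme above can be sketched in a few lines. This is a simplified illustration, not arena.ml's actual interface: the field names (`buf`, `off`) and functions are invented, and the GC call shows the tuning described above using the standard `Gc` module (`minor_heap_size` is measured in words).

```ocaml
(* Bump-allocator sketch: hand out offsets into one preallocated
   float array; no per-tensor deallocation. Names are illustrative. *)
type arena = { buf : float array; mutable off : int }

let create capacity = { buf = Array.make capacity 0.; off = 0 }

(* O(1) allocation: just advance the bump pointer. *)
let alloc a n =
  if a.off + n > Array.length a.buf then failwith "arena exhausted";
  let start = a.off in
  a.off <- a.off + n;
  start  (* caller uses a.buf.(start) .. a.buf.(start + n - 1) *)

(* Recycle all tensor payloads at once by resetting the pointer. *)
let reset a = a.off <- 0

(* GC tuning along the lines described above (sizes in words). *)
let tune_gc () =
  Gc.set { (Gc.get ()) with
           minor_heap_size = 4 * 1024 * 1024 / 8;  (* ~4 MB minor heap *)
           space_overhead = 50 }
```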

Project Structure

lisp_compiler/
├── dune-project           # Dune build configuration
├── tensor_lisp.opam       # Package metadata (auto-generated)
├── README.md
│
├── lib/                   # Core compiler library
│   ├── dune               # Library build config
│   ├── ast.ml             # Abstract syntax tree definitions
│   ├── parser.ml          # S-expression tokenizer and parser
│   ├── typecheck.ml       # Type inference and shape checking
│   ├── bytecode.ml        # IR generation and optimizations
│   ├── vm.ml              # Bytecode interpreter
│   ├── autodiff.ml        # Reverse-mode automatic differentiation
│   ├── tensor.ml          # Tensor data structure and operations
│   ├── arena.ml           # Memory pool allocator
│   ├── openblas.ml        # BLAS operation shim
│   └── compiler.ml        # Pipeline orchestration
│
└── bin/                   # CLI executable
    ├── dune               # Executable build config
    └── main.ml            # Entry point

Implementation Highlights

1. Recursive Descent Parser

The parser uses a clean two-phase approach:

  1. Tokenization: Source → token stream (LParen, RParen, Symbol)
  2. Parsing: Token stream → S-expression AST

(* Tokenization handles whitespace and delimiters *)
let is_delim = function
  | '(' | ')' | '\n' | '\t' | ' ' | '\r' -> true
  | _ -> false

(* Recursive descent for nested S-expressions *)
let rec parse_list idx acc =
  match tokens.(idx) with
  | RParen -> (Sexpr.List (List.rev acc), idx + 1)
  | _ -> let node, next = parse idx in parse_list next (node :: acc)

2. Shape-Aware Type Inference

The type checker infers tensor shapes and validates operations:

(* Broadcasting: align shapes right-to-left, expanding dims of size 1 *)
let broadcast_dims lhs rhs =
  let rec aux acc lhs rhs = match (lhs, rhs) with
    | [], [] -> acc
    | d :: ds, [] | [], d :: ds -> aux (d :: acc) ds []
    | l :: ls, r :: rs ->
      let d = match (l, r) with
        | Fixed a, Fixed b when a = b -> Fixed a
        | Fixed 1, other | other, Fixed 1 -> other
        | _ -> failwith "shape mismatch"
      in aux (d :: acc) ls rs
  in aux [] (List.rev lhs) (List.rev rhs)

(* Matrix multiply requires matching inner dimensions *)
(* Tensor[m,k] × Tensor[k,n] → Tensor[m,n] *)

3. Closure Capture in Gradients

Lambdas in grad expressions can reference variables from the enclosing scope:

; The lambda captures 'x' from the outer let binding
(let ((x (tensor f64 (2 2) 1 0 0 1)))
  (grad (lambda (w) (sum (matmul w x))) x))

The compiler passes the outer environment to lambda compilation, and the VM looks up captured variables in the outer store.
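The outer-store lookup can be sketched as a chain of environments, each linking to the scope it was created in. This record shape is illustrative, not the VM's actual store: `bindings` and `outer` are invented names.

```ocaml
(* Environment chain sketch: a lookup that misses in the local
   bindings falls through to the enclosing scope's store. *)
type 'a env = { bindings : (string * 'a) list; outer : 'a env option }

let rec lookup env name =
  match List.assoc_opt name env.bindings with
  | Some v -> Some v
  | None -> (
      match env.outer with
      | Some o -> lookup o name
      | None -> None)
```

In the example above, the lambda's environment binds `w` locally, while `x` resolves one level up in the `let` scope.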

4. Computation Graph for Autodiff

Each differentiable operation creates a node in the computation graph:

type node = {
  id : int;
  op : op;              (* Input | Add | Mul | MatMul | Sum *)
  inputs : node list;   (* Parent nodes *)
  value : tensor;       (* Forward pass result *)
  mutable grad : tensor option;  (* Accumulated gradient *)
}

Backpropagation traverses the graph in reverse topological order, applying the chain rule.
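The reverse topological order can be obtained with a DFS post-order over the graph. This is a sketch under simplifying assumptions: the node type is stripped down to `id` and `inputs` (values and gradients omitted), and `id` is assumed unique per node.

```ocaml
(* DFS post-order sketch: visiting a node's inputs before prepending
   the node itself yields output-first (reverse topological) order,
   the order backpropagation visits nodes. Simplified node type. *)
type node = { id : int; inputs : node list }

let reverse_topo_order (root : node) : node list =
  let visited = Hashtbl.create 16 in
  let order = ref [] in
  let rec visit n =
    if not (Hashtbl.mem visited n.id) then begin
      Hashtbl.add visited n.id ();
      List.iter visit n.inputs;  (* ancestors land deeper in the list *)
      order := n :: !order
    end
  in
  visit root;
  !order  (* output node first, graph inputs last *)
```

For the w/x → MatMul → Sum graph above, this visits Sum, then MatMul, then the inputs, exactly the order the backward pass needs.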

Limitations & Future Work

Current Limitations

  • if expressions not yet implemented in bytecode lowering
  • Nested grad calls not supported
  • Only single-parameter lambdas work with grad
  • No GPU acceleration (CPU-only tensor operations)

Potential Extensions

  • Conditional expressions (if)
  • Higher-order gradients (grad of grad)
  • More tensor operations (reshape, transpose, slice, concat)
  • Native OpenBLAS/MKL integration for faster matrix ops
  • JIT compilation to native code
  • GPU backend via OpenCL or Metal

Building & Testing

Prerequisites

  • OCaml 5.0+
  • Dune 3.10+

Build Commands

# Build everything
dune build

# Run the compiler
dune exec tensorc

# Clean build artifacts
dune clean

Example Test Session

# Test tensor creation
$ echo '(tensor f64 (2 2) 1 2 3 4)' | dune exec tensorc
=== Optimized Bytecode ===
0 CONST

Result: tensor[2x2]

# Test sum reduction
$ echo '(sum (tensor f64 (3) 1 2 3))' | dune exec tensorc
=== Optimized Bytecode ===
0 CONST
1 SUM 0

Result: 6.0000

# Test gradient computation
$ echo '(let ((x (tensor f64 (2 2) 1 0 0 1)))
    (grad (lambda (w) (sum (matmul w x))) x))' | dune exec tensorc
=== Optimized Bytecode ===
0 CONST
4 GRAD

Result: tensor[2x2]

License

MIT License - See LICENSE for details.


Built as a demonstration of compiler construction, type systems, and automatic differentiation fundamentals.
