# lib/lua/ast/block.ex
defmodule Lua.AST.Block do
  @moduledoc """
  A sequence of Lua statements executed in order.

  Blocks introduce a new scope for local variables.
  """

  alias Lua.AST.{Meta, Statement}

  @type t :: %__MODULE__{
          stmts: [Statement.t()],
          meta: Meta.t() | nil
        }

  defstruct stmts: [], meta: nil

  @doc """
  Builds a block from a list of statements.

  ## Examples

      iex> Lua.AST.Block.new([])
      %Lua.AST.Block{stmts: [], meta: nil}

      iex> Lua.AST.Block.new([], %Lua.AST.Meta{})
      %Lua.AST.Block{stmts: [], meta: %Lua.AST.Meta{start: nil, end: nil, metadata: %{}}}
  """
  @spec new([Statement.t()], Meta.t() | nil) :: t()
  def new(stmts \\ [], meta \\ nil), do: %__MODULE__{stmts: stmts, meta: meta}
end

defmodule Lua.AST.Chunk do
  @moduledoc """
  The top-level unit (file or string) of Lua code.

  A chunk wraps a single block representing a complete program.
  """

  alias Lua.AST.{Meta, Block}

  @type t :: %__MODULE__{
          block: Block.t(),
          meta: Meta.t() | nil
        }

  defstruct [:block, :meta]

  @doc """
  Wraps a block into a chunk.

  ## Examples

      iex> Lua.AST.Chunk.new(%Lua.AST.Block{stmts: []})
      %Lua.AST.Chunk{block: %Lua.AST.Block{stmts: [], meta: nil}, meta: nil}
  """
  @spec new(Block.t(), Meta.t() | nil) :: t()
  def new(block, meta \\ nil), do: %__MODULE__{block: block, meta: meta}
end
# lib/lua/ast/builder.ex
defmodule Lua.AST.Builder do
  @moduledoc """
  Helpers for programmatically constructing AST nodes.

  Provides a convenient API for building AST without manually
  creating all the struct fields. Useful for:

  - Code generation
  - AST transformations
  - Testing
  - Metaprogramming with quote/unquote

  ## Examples

      import Lua.AST.Builder

      # Build a simple expression: 2 + 2
      binop(:add, number(2), number(2))

      # Build a local assignment: local x = 42
      local(["x"], [number(42)])

      # Build a function: function add(a, b) return a + b end
      func_decl("add", ["a", "b"], [
        return_stmt([binop(:add, var("a"), var("b"))])
      ])
  """

  alias Lua.AST.{Chunk, Block, Meta, Expr, Statement}

  # Chunk and Block

  @doc """
  Creates a Chunk node.

  ## Examples

      chunk([local(["x"], [number(42)])])
  """
  @spec chunk([Statement.t()], Meta.t() | nil) :: Chunk.t()
  def chunk(stmts, meta \\ nil) do
    %Chunk{block: block(stmts, meta), meta: meta}
  end

  @doc """
  Creates a Block node.

  ## Examples

      block([
        local(["x"], [number(10)]),
        assign([var("x")], [number(20)])
      ])
  """
  @spec block([Statement.t()], Meta.t() | nil) :: Block.t()
  def block(stmts, meta \\ nil) do
    %Block{stmts: stmts, meta: meta}
  end

  # Literal expressions

  @doc "Creates a nil literal"
  @spec nil_lit(Meta.t() | nil) :: Expr.Nil.t()
  def nil_lit(meta \\ nil), do: %Expr.Nil{meta: meta}

  @doc "Creates a boolean literal"
  @spec bool(boolean(), Meta.t() | nil) :: Expr.Bool.t()
  def bool(value, meta \\ nil), do: %Expr.Bool{value: value, meta: meta}

  @doc "Creates a number literal"
  @spec number(number(), Meta.t() | nil) :: Expr.Number.t()
  def number(value, meta \\ nil), do: %Expr.Number{value: value, meta: meta}

  @doc "Creates a string literal"
  @spec string(String.t(), Meta.t() | nil) :: Expr.String.t()
  def string(value, meta \\ nil), do: %Expr.String{value: value, meta: meta}

  @doc "Creates a vararg expression (...)"
  @spec vararg(Meta.t() | nil) :: Expr.Vararg.t()
  def vararg(meta \\ nil), do: %Expr.Vararg{meta: meta}

  # Variable and access

  @doc "Creates a variable reference"
  @spec var(String.t(), Meta.t() | nil) :: Expr.Var.t()
  def var(name, meta \\ nil), do: %Expr.Var{name: name, meta: meta}

  @doc """
  Creates a property access (obj.prop)

  ## Examples

      property(var("io"), "write")  # io.write
  """
  @spec property(Expr.t(), String.t(), Meta.t() | nil) :: Expr.Property.t()
  def property(table, field, meta \\ nil) do
    %Expr.Property{table: table, field: field, meta: meta}
  end

  @doc """
  Creates an index access (obj[index])

  ## Examples

      index(var("t"), number(1))  # t[1]
  """
  @spec index(Expr.t(), Expr.t(), Meta.t() | nil) :: Expr.Index.t()
  def index(table, key, meta \\ nil) do
    %Expr.Index{table: table, key: key, meta: meta}
  end

  # Operators

  @doc """
  Creates a binary operation.

  ## Operators

  - `:add`, `:sub`, `:mul`, `:div`, `:floor_div`, `:mod`, `:pow`
  - `:concat`
  - `:eq`, `:ne`, `:lt`, `:gt`, `:le`, `:ge`
  - `:and`, `:or`

  ## Examples

      binop(:add, number(2), number(3))  # 2 + 3
      binop(:lt, var("x"), number(10))   # x < 10
  """
  @spec binop(atom(), Expr.t(), Expr.t(), Meta.t() | nil) :: Expr.BinOp.t()
  def binop(op, left, right, meta \\ nil) do
    %Expr.BinOp{op: op, left: left, right: right, meta: meta}
  end

  @doc """
  Creates a unary operation.

  ## Operators

  - `:not` - logical not
  - `:neg` - negation (-)
  - `:len` - length operator (#)

  ## Examples

      unop(:neg, var("x"))     # -x
      unop(:not, var("flag"))  # not flag
      unop(:len, var("list"))  # #list
  """
  @spec unop(atom(), Expr.t(), Meta.t() | nil) :: Expr.UnOp.t()
  def unop(op, operand, meta \\ nil) do
    %Expr.UnOp{op: op, operand: operand, meta: meta}
  end

  # Table constructor

  @doc """
  Creates a table constructor.

  ## Field types

  - `{:list, expr}` - array-style field (value only)
  - `{:record, key_expr, value_expr}` - key-value field

  ## Examples

      # Empty table: {}
      table([])

      # Array: {1, 2, 3}
      table([
        {:list, number(1)},
        {:list, number(2)},
        {:list, number(3)}
      ])

      # Record: {x = 10, y = 20}
      table([
        {:record, string("x"), number(10)},
        {:record, string("y"), number(20)}
      ])
  """
  @spec table([{:list, Expr.t()} | {:record, Expr.t(), Expr.t()}], Meta.t() | nil) ::
          Expr.Table.t()
  def table(fields, meta \\ nil) do
    %Expr.Table{fields: fields, meta: meta}
  end

  # Function call

  @doc """
  Creates a function call.

  ## Examples

      call(var("print"), [string("hello")])                 # print("hello")
      call(property(var("io"), "write"), [string("test")])  # io.write("test")
  """
  @spec call(Expr.t(), [Expr.t()], Meta.t() | nil) :: Expr.Call.t()
  def call(func, args, meta \\ nil) do
    %Expr.Call{func: func, args: args, meta: meta}
  end

  @doc """
  Creates a method call (obj:method(args))

  ## Examples

      method_call(var("file"), "read", [string("*a")])  # file:read("*a")
  """
  @spec method_call(Expr.t(), String.t(), [Expr.t()], Meta.t() | nil) :: Expr.MethodCall.t()
  def method_call(object, method, args, meta \\ nil) do
    %Expr.MethodCall{object: object, method: method, args: args, meta: meta}
  end

  # Function expression

  @doc """
  Creates a function expression.

  ## Options

  - `:vararg` - when true, appends the `:vararg` marker to the params
  - `:meta` - position metadata

  ## Examples

      # function(x, y) return x + y end
      function_expr(["x", "y"], [
        return_stmt([binop(:add, var("x"), var("y"))])
      ])

      # function(...) return ... end
      function_expr([], [return_stmt([vararg()])], vararg: true)
  """
  @spec function_expr([String.t()], [Statement.t()], keyword()) :: Expr.Function.t()
  def function_expr(params, body_stmts, opts \\ []) do
    %Expr.Function{
      params: maybe_append_vararg(params, opts),
      body: block(body_stmts),
      meta: Keyword.get(opts, :meta)
    }
  end

  # Statements

  @doc """
  Creates an assignment statement.

  ## Examples

      # x = 10
      assign([var("x")], [number(10)])

      # x, y = 1, 2
      assign([var("x"), var("y")], [number(1), number(2)])
  """
  @spec assign([Expr.t()], [Expr.t()], Meta.t() | nil) :: Statement.Assign.t()
  def assign(targets, values, meta \\ nil) do
    %Statement.Assign{targets: targets, values: values, meta: meta}
  end

  @doc """
  Creates a local variable declaration.

  ## Examples

      # local x
      local(["x"], [])

      # local x = 10
      local(["x"], [number(10)])

      # local x, y = 1, 2
      local(["x", "y"], [number(1), number(2)])
  """
  @spec local([String.t()], [Expr.t()], Meta.t() | nil) :: Statement.Local.t()
  def local(names, values \\ [], meta \\ nil) do
    %Statement.Local{names: names, values: values, meta: meta}
  end

  @doc """
  Creates a local function declaration.

  ## Examples

      # local function add(a, b) return a + b end
      local_func("add", ["a", "b"], [
        return_stmt([binop(:add, var("a"), var("b"))])
      ])
  """
  @spec local_func(String.t(), [String.t()], [Statement.t()], keyword()) ::
          Statement.LocalFunc.t()
  def local_func(name, params, body_stmts, opts \\ []) do
    %Statement.LocalFunc{
      name: name,
      params: maybe_append_vararg(params, opts),
      body: block(body_stmts),
      meta: Keyword.get(opts, :meta)
    }
  end

  @doc """
  Creates a function declaration.

  ## Examples

      # function add(a, b) return a + b end
      func_decl("add", ["a", "b"], [
        return_stmt([binop(:add, var("a"), var("b"))])
      ])

      # function math.add(a, b) return a + b end
      func_decl(["math", "add"], ["a", "b"], [...])
  """
  @spec func_decl(String.t() | [String.t()], [String.t()], [Statement.t()], keyword()) ::
          Statement.FuncDecl.t()
  def func_decl(name, params, body_stmts, opts \\ []) do
    # A bare string name is normalized to a single-element path.
    name_parts = if is_binary(name), do: [name], else: name

    %Statement.FuncDecl{
      name: name_parts,
      params: maybe_append_vararg(params, opts),
      body: block(body_stmts),
      is_method: Keyword.get(opts, :is_method, false),
      meta: Keyword.get(opts, :meta)
    }
  end

  @doc """
  Creates a function call statement.

  ## Examples

      call_stmt(call(var("print"), [string("hello")]))
  """
  @spec call_stmt(Expr.Call.t() | Expr.MethodCall.t(), Meta.t() | nil) :: Statement.CallStmt.t()
  def call_stmt(call_expr, meta \\ nil) do
    %Statement.CallStmt{call: call_expr, meta: meta}
  end

  @doc """
  Creates an if statement.

  ## Examples

      # if x > 0 then print(x) end
      if_stmt(
        binop(:gt, var("x"), number(0)),
        [call_stmt(call(var("print"), [var("x")]))]
      )

      # if x > 0 then ... elseif x < 0 then ... else ... end
      if_stmt(
        binop(:gt, var("x"), number(0)),
        [call_stmt(...)],
        elseif: [{binop(:lt, var("x"), number(0)), [call_stmt(...)]}],
        else: [call_stmt(...)]
      )
  """
  @spec if_stmt(Expr.t(), [Statement.t()], keyword()) :: Statement.If.t()
  def if_stmt(condition, then_stmts, opts \\ []) do
    %Statement.If{
      condition: condition,
      then_block: block(then_stmts),
      elseifs: opts |> Keyword.get(:elseif, []) |> Enum.map(fn {c, s} -> {c, block(s)} end),
      else_block: if(else_stmts = Keyword.get(opts, :else), do: block(else_stmts)),
      meta: Keyword.get(opts, :meta)
    }
  end

  @doc """
  Creates a while loop.

  ## Examples

      # while x > 0 do x = x - 1 end
      while_stmt(
        binop(:gt, var("x"), number(0)),
        [assign([var("x")], [binop(:sub, var("x"), number(1))])]
      )
  """
  @spec while_stmt(Expr.t(), [Statement.t()], Meta.t() | nil) :: Statement.While.t()
  def while_stmt(condition, body_stmts, meta \\ nil) do
    %Statement.While{condition: condition, body: block(body_stmts), meta: meta}
  end

  @doc """
  Creates a repeat-until loop.

  ## Examples

      # repeat x = x - 1 until x <= 0
      repeat_stmt(
        [assign([var("x")], [binop(:sub, var("x"), number(1))])],
        binop(:le, var("x"), number(0))
      )
  """
  @spec repeat_stmt([Statement.t()], Expr.t(), Meta.t() | nil) :: Statement.Repeat.t()
  def repeat_stmt(body_stmts, condition, meta \\ nil) do
    %Statement.Repeat{body: block(body_stmts), condition: condition, meta: meta}
  end

  @doc """
  Creates a numeric for loop.

  ## Examples

      # for i = 1, 10 do print(i) end
      for_num("i", number(1), number(10), [
        call_stmt(call(var("print"), [var("i")]))
      ])

      # for i = 1, 10, 2 do print(i) end
      for_num("i", number(1), number(10), [...], step: number(2))
  """
  @spec for_num(String.t(), Expr.t(), Expr.t(), [Statement.t()], keyword()) ::
          Statement.ForNum.t()
  def for_num(var_name, start, limit, body_stmts, opts \\ []) do
    %Statement.ForNum{
      var: var_name,
      start: start,
      limit: limit,
      step: Keyword.get(opts, :step),
      body: block(body_stmts),
      meta: Keyword.get(opts, :meta)
    }
  end

  @doc """
  Creates a generic for loop (for-in).

  ## Examples

      # for k, v in pairs(t) do print(k, v) end
      for_in(
        ["k", "v"],
        [call(var("pairs"), [var("t")])],
        [call_stmt(call(var("print"), [var("k"), var("v")]))]
      )
  """
  @spec for_in([String.t()], [Expr.t()], [Statement.t()], Meta.t() | nil) :: Statement.ForIn.t()
  def for_in(vars, iterators, body_stmts, meta \\ nil) do
    %Statement.ForIn{vars: vars, iterators: iterators, body: block(body_stmts), meta: meta}
  end

  @doc """
  Creates a do block.

  ## Examples

      # do local x = 10; print(x) end
      do_block([
        local(["x"], [number(10)]),
        call_stmt(call(var("print"), [var("x")]))
      ])
  """
  @spec do_block([Statement.t()], Meta.t() | nil) :: Statement.Do.t()
  def do_block(body_stmts, meta \\ nil) do
    %Statement.Do{body: block(body_stmts), meta: meta}
  end

  @doc """
  Creates a return statement.

  ## Examples

      # return
      return_stmt([])

      # return 42
      return_stmt([number(42)])

      # return x, y
      return_stmt([var("x"), var("y")])
  """
  @spec return_stmt([Expr.t()], Meta.t() | nil) :: Statement.Return.t()
  def return_stmt(values, meta \\ nil) do
    %Statement.Return{values: values, meta: meta}
  end

  @doc "Creates a break statement"
  @spec break_stmt(Meta.t() | nil) :: Statement.Break.t()
  def break_stmt(meta \\ nil), do: %Statement.Break{meta: meta}

  @doc "Creates a goto statement"
  @spec goto_stmt(String.t(), Meta.t() | nil) :: Statement.Goto.t()
  def goto_stmt(label, meta \\ nil), do: %Statement.Goto{label: label, meta: meta}

  @doc "Creates a label"
  @spec label(String.t(), Meta.t() | nil) :: Statement.Label.t()
  def label(name, meta \\ nil), do: %Statement.Label{name: name, meta: meta}

  # Private helpers

  # Shared by function_expr/3, local_func/4 and func_decl/4: when the
  # `vararg: true` option is given, appends the `:vararg` marker that
  # Expr.Function uses as the final parameter.
  defp maybe_append_vararg(params, opts) do
    if Keyword.get(opts, :vararg, false) do
      params ++ [:vararg]
    else
      params
    end
  end
end
# lib/lua/ast/expr.ex
defmodule Lua.AST.Expr do
  @moduledoc """
  Expression AST nodes for Lua.

  All expression nodes include a `meta` field for position tracking.
  """

  alias Lua.AST.Meta

  defmodule Nil do
    @moduledoc "Represents the `nil` literal"
    defstruct [:meta]
    @type t :: %__MODULE__{meta: Meta.t() | nil}
  end

  defmodule Bool do
    @moduledoc "Represents boolean literals (`true` or `false`)"
    defstruct [:value, :meta]
    @type t :: %__MODULE__{value: boolean(), meta: Meta.t() | nil}
  end

  defmodule Number do
    @moduledoc "Represents numeric literals (integers and floats)"
    defstruct [:value, :meta]
    @type t :: %__MODULE__{value: number(), meta: Meta.t() | nil}
  end

  defmodule String do
    @moduledoc "Represents string literals"
    defstruct [:value, :meta]

    # `binary()` rather than `String.t()`: once this nested module is
    # defined, the bare `String` alias resolves to `Lua.AST.Expr.String`
    # for the rest of this file, so `String.t()` specs here were
    # accidentally self-referential instead of meaning Elixir strings.
    @type t :: %__MODULE__{value: binary(), meta: Meta.t() | nil}
  end

  defmodule Var do
    @moduledoc "Represents a variable reference"
    defstruct [:name, :meta]
    # `binary()` because the Expr.String module above shadows `String`.
    @type t :: %__MODULE__{name: binary(), meta: Meta.t() | nil}
  end

  defmodule BinOp do
    @moduledoc """
    Represents a binary operation.

    Operators:
    - Arithmetic: `:add`, `:sub`, `:mul`, `:div`, `:floor_div`, `:mod`, `:pow`
    - Comparison: `:eq`, `:ne`, `:lt`, `:le`, `:gt`, `:ge`
    - Logical: `:and`, `:or`
    - String: `:concat`
    """
    defstruct [:op, :left, :right, :meta]

    # `:floor_div` (not `:floordiv`): Lua.AST.Builder documents and
    # Lua.AST.PrettyPrinter matches `:floor_div` for the `//` operator.
    @type op ::
            :add
            | :sub
            | :mul
            | :div
            | :floor_div
            | :mod
            | :pow
            | :eq
            | :ne
            | :lt
            | :le
            | :gt
            | :ge
            | :and
            | :or
            | :concat

    @type t :: %__MODULE__{
            op: op(),
            left: Lua.AST.Expr.t(),
            right: Lua.AST.Expr.t(),
            meta: Meta.t() | nil
          }
  end

  defmodule UnOp do
    @moduledoc """
    Represents a unary operation.

    Operators:
    - `:not` - logical not
    - `:neg` - arithmetic negation (-)
    - `:len` - length operator (#)
    """
    defstruct [:op, :operand, :meta]

    @type op :: :not | :neg | :len

    @type t :: %__MODULE__{
            op: op(),
            operand: Lua.AST.Expr.t(),
            meta: Meta.t() | nil
          }
  end

  defmodule Table do
    @moduledoc """
    Represents a table constructor: `{...}`

    Fields can be:
    - List entries: `{1, 2, 3}` -> `[{:list, expr}, ...]`
    - Key-value pairs: `{a = 1}` -> `[{:record, key_expr, val_expr}, ...]`
    - Computed keys: `{["key"] = value}` -> `[{:record, key_expr, val_expr}, ...]`
    """
    defstruct [:fields, :meta]

    # `:record` (not `:pair`): Lua.AST.Builder.table/2 constructs and
    # Lua.AST.PrettyPrinter pattern-matches `{:record, key, value}`.
    @type field ::
            {:list, Lua.AST.Expr.t()}
            | {:record, Lua.AST.Expr.t(), Lua.AST.Expr.t()}

    @type t :: %__MODULE__{
            fields: [field()],
            meta: Meta.t() | nil
          }
  end

  defmodule Call do
    @moduledoc """
    Represents a function call: `func(args)`
    """
    defstruct [:func, :args, :meta]

    @type t :: %__MODULE__{
            func: Lua.AST.Expr.t(),
            args: [Lua.AST.Expr.t()],
            meta: Meta.t() | nil
          }
  end

  defmodule MethodCall do
    @moduledoc """
    Represents a method call: `obj:method(args)`

    This is syntactic sugar for `obj.method(obj, args)` in Lua.
    """
    defstruct [:object, :method, :args, :meta]

    @type t :: %__MODULE__{
            object: Lua.AST.Expr.t(),
            method: binary(),
            args: [Lua.AST.Expr.t()],
            meta: Meta.t() | nil
          }
  end

  defmodule Index do
    @moduledoc """
    Represents indexing with brackets: `table[key]`
    """
    defstruct [:table, :key, :meta]

    @type t :: %__MODULE__{
            table: Lua.AST.Expr.t(),
            key: Lua.AST.Expr.t(),
            meta: Meta.t() | nil
          }
  end

  defmodule Property do
    @moduledoc """
    Represents property access: `table.field`

    This is syntactic sugar for `table["field"]` in Lua.
    """
    defstruct [:table, :field, :meta]

    @type t :: %__MODULE__{
            table: Lua.AST.Expr.t(),
            field: binary(),
            meta: Meta.t() | nil
          }
  end

  defmodule Function do
    @moduledoc """
    Represents a function expression: `function(params) body end`

    Params can include:
    - Named parameters: `["a", "b", "c"]`
    - Vararg: `:vararg` as the last element
    """
    defstruct [:params, :body, :meta]

    @type param :: binary() | :vararg

    @type t :: %__MODULE__{
            params: [param()],
            body: Lua.AST.Block.t(),
            meta: Meta.t() | nil
          }
  end

  defmodule Vararg do
    @moduledoc "Represents the vararg expression: `...`"
    defstruct [:meta]
    @type t :: %__MODULE__{meta: Meta.t() | nil}
  end

  @type t ::
          Nil.t()
          | Bool.t()
          | Number.t()
          | String.t()
          | Var.t()
          | BinOp.t()
          | UnOp.t()
          | Table.t()
          | Call.t()
          | MethodCall.t()
          | Index.t()
          | Property.t()
          | Function.t()
          | Vararg.t()
end
# lib/lua/ast/meta.ex
defmodule Lua.AST.Meta do
  @moduledoc """
  Position tracking metadata for AST nodes.

  Every AST node includes a `meta` field containing position information
  for error reporting, source maps, and debugging.
  """

  @type position :: %{
          line: pos_integer(),
          column: pos_integer(),
          byte_offset: non_neg_integer()
        }

  @type t :: %__MODULE__{
          start: position() | nil,
          end: position() | nil,
          metadata: map()
        }

  defstruct start: nil, end: nil, metadata: %{}

  @doc """
  Builds a Meta struct from start/end positions and optional metadata.

  ## Examples

      iex> Lua.AST.Meta.new(
      ...>   %{line: 1, column: 1, byte_offset: 0},
      ...>   %{line: 1, column: 5, byte_offset: 4}
      ...> )
      %Lua.AST.Meta{
        start: %{line: 1, column: 1, byte_offset: 0},
        end: %{line: 1, column: 5, byte_offset: 4},
        metadata: %{}
      }
  """
  @spec new(position() | nil, position() | nil, map()) :: t()
  def new(start \\ nil, end_pos \\ nil, metadata \\ %{}) do
    %__MODULE__{start: start, end: end_pos, metadata: metadata}
  end

  @doc """
  Combines two Meta structs into one spanning both: the earliest start
  and the latest end win. Note the `metadata` maps are not carried over;
  the result always has empty metadata.

  Useful when combining multiple nodes into a single parent node.

  ## Examples

      iex> meta1 = Lua.AST.Meta.new(%{line: 1, column: 1, byte_offset: 0}, %{line: 1, column: 5, byte_offset: 4})
      iex> meta2 = Lua.AST.Meta.new(%{line: 1, column: 7, byte_offset: 6}, %{line: 1, column: 10, byte_offset: 9})
      iex> Lua.AST.Meta.merge(meta1, meta2)
      %Lua.AST.Meta{
        start: %{line: 1, column: 1, byte_offset: 0},
        end: %{line: 1, column: 10, byte_offset: 9},
        metadata: %{}
      }
  """
  @spec merge(t(), t()) :: t()
  def merge(%__MODULE__{start: s1, end: e1}, %__MODULE__{start: s2, end: e2}) do
    new(pick(s1, s2, &<=/2), pick(e1, e2, &>=/2))
  end

  @doc """
  Stores `value` under `key` in the struct's metadata map.
  """
  @spec add_metadata(t(), atom(), term()) :: t()
  def add_metadata(%__MODULE__{} = meta, key, value) do
    %{meta | metadata: Map.put(meta.metadata, key, value)}
  end

  # Private helpers

  # Chooses between two optional positions by comparing byte offsets
  # with `cmp`; a nil position always loses, ties go to the first.
  defp pick(nil, pos, _cmp), do: pos
  defp pick(pos, nil, _cmp), do: pos

  defp pick(pos1, pos2, cmp) do
    if cmp.(pos1.byte_offset, pos2.byte_offset), do: pos1, else: pos2
  end
end
# lib/lua/ast/pretty_printer.ex
defmodule Lua.AST.PrettyPrinter do
  @moduledoc """
  Converts AST back to Lua source code.

  Produces readable, properly indented Lua code from AST structures.
  Useful for:
  - Round-trip testing (parse → print → parse)
  - Debugging AST transformations
  - Code generation

  ## Examples

      ast = Parser.parse("local x = 2 + 2")
      code = PrettyPrinter.print(ast)
      # => "local x = 2 + 2\\n"

      # With custom indentation
      PrettyPrinter.print(ast, indent: 4)
  """

  alias Lua.AST.{Chunk, Block, Expr, Statement}

  @type ast_node ::
          Chunk.t()
          | Block.t()
          | Expr.t()
          | Statement.t()

  @type opts :: [
          indent: pos_integer()
        ]

  # Lua's unary operators (not, -, #) sit at precedence 7: tighter than
  # every binary operator except ^ (precedence 8).
  @unary_prec 7

  @doc """
  Converts an AST node to Lua source code.

  ## Options

  - `:indent` - Number of spaces per indentation level (default: 2)

  ## Examples

      PrettyPrinter.print(ast)
      PrettyPrinter.print(ast, indent: 4)
  """
  @spec print(ast_node, opts) :: String.t()
  def print(node, opts \\ []) do
    indent_size = Keyword.get(opts, :indent, 2)
    do_print(node, 0, indent_size)
  end

  # Chunk: delegates to its block.
  defp do_print(%Chunk{block: block}, level, indent_size) do
    do_print(block, level, indent_size)
  end

  # Block: one statement per line with a trailing newline.
  defp do_print(%Block{stmts: stmts}, level, indent_size) do
    stmts
    |> Enum.map(&do_print(&1, level, indent_size))
    |> Enum.join("\n")
    |> Kernel.<>("\n")
  end

  # Expressions

  defp do_print(%Expr.Nil{}, _level, _indent_size), do: "nil"
  defp do_print(%Expr.Bool{value: true}, _level, _indent_size), do: "true"
  defp do_print(%Expr.Bool{value: false}, _level, _indent_size), do: "false"

  defp do_print(%Expr.Number{value: n}, _level, _indent_size) do
    # Keep integer-valued floats recognizable as floats ("2.0", not "2").
    if is_float(n) and Float.floor(n) == n do
      "#{trunc(n)}.0"
    else
      "#{n}"
    end
  end

  defp do_print(%Expr.String{value: s}, _level, _indent_size) do
    # Escape characters that would break a double-quoted Lua string.
    escaped =
      s
      |> String.replace("\\", "\\\\")
      |> String.replace("\"", "\\\"")
      |> String.replace("\n", "\\n")
      |> String.replace("\t", "\\t")

    "\"#{escaped}\""
  end

  defp do_print(%Expr.Vararg{}, _level, _indent_size), do: "..."

  defp do_print(%Expr.Var{name: name}, _level, _indent_size), do: name

  defp do_print(%Expr.BinOp{op: op, left: left, right: right}, level, indent_size) do
    left_str = print_expr_with_parens(left, op, :left, level, indent_size)
    right_str = print_expr_with_parens(right, op, :right, level, indent_size)
    op_str = format_binop(op)

    "#{left_str} #{op_str} #{right_str}"
  end

  defp do_print(%Expr.UnOp{op: op, operand: operand}, level, indent_size) do
    operand_str = print_expr_with_parens(operand, op, :operand, level, indent_size)
    op_str = format_unop(op)

    "#{op_str}#{operand_str}"
  end

  defp do_print(%Expr.Table{fields: fields}, level, indent_size) do
    if fields == [] do
      "{}"
    else
      field_strs =
        Enum.map(fields, fn
          {:list, value} ->
            do_print(value, level + 1, indent_size)

          {:record, key, value} ->
            key_str = format_table_key(key, level + 1, indent_size)
            value_str = do_print(value, level + 1, indent_size)
            "#{key_str} = #{value_str}"
        end)

      "{#{Enum.join(field_strs, ", ")}}"
    end
  end

  defp do_print(%Expr.Call{func: func, args: args}, level, indent_size) do
    func_str = do_print(func, level, indent_size)
    args_str = Enum.map(args, &do_print(&1, level, indent_size)) |> Enum.join(", ")

    "#{func_str}(#{args_str})"
  end

  defp do_print(%Expr.MethodCall{object: obj, method: method, args: args}, level, indent_size) do
    obj_str = do_print(obj, level, indent_size)
    args_str = Enum.map(args, &do_print(&1, level, indent_size)) |> Enum.join(", ")

    "#{obj_str}:#{method}(#{args_str})"
  end

  defp do_print(%Expr.Index{table: table, key: key}, level, indent_size) do
    table_str = do_print(table, level, indent_size)
    key_str = do_print(key, level, indent_size)

    "#{table_str}[#{key_str}]"
  end

  defp do_print(%Expr.Property{table: table, field: field}, level, indent_size) do
    table_str = do_print(table, level, indent_size)
    "#{table_str}.#{field}"
  end

  defp do_print(%Expr.Function{params: params, body: body}, level, indent_size) do
    params_str = params |> Enum.map(&param_to_string/1) |> Enum.join(", ")
    body_str = print_block_body(body, level + 1, indent_size)

    "function(#{params_str})\n#{body_str}#{indent(level, indent_size)}end"
  end

  # Statements

  defp do_print(%Statement.Assign{targets: targets, values: values}, level, indent_size) do
    targets_str = Enum.map(targets, &do_print(&1, level, indent_size)) |> Enum.join(", ")
    values_str = Enum.map(values, &do_print(&1, level, indent_size)) |> Enum.join(", ")

    "#{indent(level, indent_size)}#{targets_str} = #{values_str}"
  end

  defp do_print(%Statement.Local{names: names, values: values}, level, indent_size) do
    names_str = Enum.join(names, ", ")

    # Omit the "=" entirely for a bare declaration (local x).
    if values && values != [] do
      values_str = Enum.map(values, &do_print(&1, level, indent_size)) |> Enum.join(", ")
      "#{indent(level, indent_size)}local #{names_str} = #{values_str}"
    else
      "#{indent(level, indent_size)}local #{names_str}"
    end
  end

  defp do_print(%Statement.LocalFunc{name: name, params: params, body: body}, level, indent_size) do
    params_str = params |> Enum.map(&param_to_string/1) |> Enum.join(", ")
    body_str = print_block_body(body, level + 1, indent_size)

    "#{indent(level, indent_size)}local function #{name}(#{params_str})\n#{body_str}#{indent(level, indent_size)}end"
  end

  defp do_print(
         %Statement.FuncDecl{name: name, params: params, body: body, is_method: is_method},
         level,
         indent_size
       ) do
    params_str = params |> Enum.map(&param_to_string/1) |> Enum.join(", ")
    body_str = print_block_body(body, level + 1, indent_size)
    name_str = format_func_name(name, is_method)

    "#{indent(level, indent_size)}function #{name_str}(#{params_str})\n#{body_str}#{indent(level, indent_size)}end"
  end

  defp do_print(%Statement.CallStmt{call: call}, level, indent_size) do
    "#{indent(level, indent_size)}#{do_print(call, level, indent_size)}"
  end

  defp do_print(
         %Statement.If{
           condition: condition,
           then_block: then_block,
           elseifs: elseifs,
           else_block: else_block
         },
         level,
         indent_size
       ) do
    ind = indent(level, indent_size)
    cond_str = do_print(condition, level, indent_size)
    then_str = print_block_body(then_block, level + 1, indent_size)

    elseif_strs =
      Enum.map(elseifs, fn {c, b} ->
        c_str = do_print(c, level, indent_size)
        b_str = print_block_body(b, level + 1, indent_size)
        "#{ind}elseif #{c_str} then\n#{b_str}"
      end)

    else_str =
      if else_block do
        b_str = print_block_body(else_block, level + 1, indent_size)
        "#{ind}else\n#{b_str}"
      end

    parts =
      ["#{ind}if #{cond_str} then\n#{then_str}"] ++
        elseif_strs ++ if(else_str, do: [else_str], else: [])

    # Each part already ends with a newline from print_block_body.
    Enum.join(parts, "") <> "#{ind}end"
  end

  defp do_print(%Statement.While{condition: condition, body: body}, level, indent_size) do
    cond_str = do_print(condition, level, indent_size)
    body_str = print_block_body(body, level + 1, indent_size)

    "#{indent(level, indent_size)}while #{cond_str} do\n#{body_str}#{indent(level, indent_size)}end"
  end

  defp do_print(%Statement.Repeat{body: body, condition: condition}, level, indent_size) do
    body_str = print_block_body(body, level + 1, indent_size)
    cond_str = do_print(condition, level, indent_size)

    "#{indent(level, indent_size)}repeat\n#{body_str}#{indent(level, indent_size)}until #{cond_str}"
  end

  defp do_print(
         %Statement.ForNum{var: var, start: start, limit: limit, step: step, body: body},
         level,
         indent_size
       ) do
    start_str = do_print(start, level, indent_size)
    limit_str = do_print(limit, level, indent_size)
    body_str = print_block_body(body, level + 1, indent_size)

    # The step clause is optional; nil means Lua's implicit step of 1.
    step_str =
      if step do
        ", #{do_print(step, level, indent_size)}"
      else
        ""
      end

    "#{indent(level, indent_size)}for #{var} = #{start_str}, #{limit_str}#{step_str} do\n#{body_str}#{indent(level, indent_size)}end"
  end

  defp do_print(
         %Statement.ForIn{vars: vars, iterators: iterators, body: body},
         level,
         indent_size
       ) do
    vars_str = Enum.join(vars, ", ")
    iterators_str = Enum.map(iterators, &do_print(&1, level, indent_size)) |> Enum.join(", ")
    body_str = print_block_body(body, level + 1, indent_size)

    "#{indent(level, indent_size)}for #{vars_str} in #{iterators_str} do\n#{body_str}#{indent(level, indent_size)}end"
  end

  defp do_print(%Statement.Do{body: body}, level, indent_size) do
    body_str = print_block_body(body, level + 1, indent_size)

    "#{indent(level, indent_size)}do\n#{body_str}#{indent(level, indent_size)}end"
  end

  defp do_print(%Statement.Return{values: values}, level, indent_size) do
    if values == [] do
      "#{indent(level, indent_size)}return"
    else
      values_str = Enum.map(values, &do_print(&1, level, indent_size)) |> Enum.join(", ")
      "#{indent(level, indent_size)}return #{values_str}"
    end
  end

  defp do_print(%Statement.Break{}, level, indent_size) do
    "#{indent(level, indent_size)}break"
  end

  defp do_print(%Statement.Goto{label: label}, level, indent_size) do
    "#{indent(level, indent_size)}goto #{label}"
  end

  defp do_print(%Statement.Label{name: name}, level, indent_size) do
    "#{indent(level, indent_size)}::#{name}::"
  end

  # Helpers

  defp indent(level, indent_size) do
    String.duplicate(" ", level * indent_size)
  end

  defp param_to_string(:vararg), do: "..."
  defp param_to_string(name), do: name

  defp print_block_body(%Block{stmts: stmts}, level, indent_size) do
    stmts
    |> Enum.map(&do_print(&1, level, indent_size))
    |> Enum.join("\n")
    |> Kernel.<>("\n")
  end

  # Add parentheses when needed for operator precedence
  defp print_expr_with_parens(expr, parent_op, position, level, indent_size) do
    expr_str = do_print(expr, level, indent_size)

    if needs_parens?(expr, parent_op, position) do
      "(#{expr_str})"
    else
      expr_str
    end
  end

  # :operand position means the parent is a UNARY operator. The previous
  # implementation fed unary op atoms into binop_precedence/1 (yielding
  # 0), so -(a + b) printed as "-a + b" and not (a and b) as
  # "not a and b" — both parse differently in Lua.
  defp needs_parens?(expr, parent_op, :operand) do
    case expr do
      %Expr.BinOp{op: child_op} ->
        # Any binary op weaker than unary needs wrapping; only ^ (8)
        # binds tighter, so -x^2 correctly stays "-x^2" (= -(x^2)).
        binop_precedence(child_op) < @unary_prec

      %Expr.UnOp{op: :neg} ->
        # "- -x" must print as "-(-x)": "--" starts a Lua comment.
        parent_op == :neg

      %Expr.Number{value: v} ->
        # Same "--" comment hazard when negating a negative literal.
        parent_op == :neg and v < 0

      _ ->
        false
    end
  end

  # Determine if parentheses are needed based on binary-op precedence
  defp needs_parens?(expr, parent_op, position) do
    case expr do
      %Expr.BinOp{op: child_op} ->
        parent_prec = binop_precedence(parent_op)
        child_prec = binop_precedence(child_op)

        cond do
          # Lower precedence always needs parens
          child_prec < parent_prec -> true
          # Same precedence needs parens if associativity doesn't match
          child_prec == parent_prec -> needs_parens_same_prec?(parent_op, position)
          # Higher precedence never needs parens
          true -> false
        end

      %Expr.UnOp{} ->
        # ^ binds tighter than unary ops, so (-x)^2 needs the parens.
        parent_op == :pow

      _ ->
        false
    end
  end

  defp needs_parens_same_prec?(op, position) do
    # Right-associative operators need parens on the left
    # Left-associative operators need parens on the right
    case {is_right_assoc?(op), position} do
      {true, :left} -> true
      {false, :right} -> true
      _ -> false
    end
  end

  defp is_right_assoc?(op) do
    op in [:concat, :pow]
  end

  # Lua 5.4 binary operator precedence (manual §3.4.8); unary ops sit
  # at @unary_prec (7), between multiplicative (6) and ^ (8).
  defp binop_precedence(op) do
    case op do
      :or -> 1
      :and -> 2
      :lt -> 3
      :gt -> 3
      :le -> 3
      :ge -> 3
      :ne -> 3
      :eq -> 3
      :concat -> 4
      :add -> 5
      :sub -> 5
      :mul -> 6
      :div -> 6
      :floor_div -> 6
      :mod -> 6
      :pow -> 8
      _ -> 0
    end
  end

  defp format_binop(op) do
    case op do
      :add -> "+"
      :sub -> "-"
      :mul -> "*"
      :div -> "/"
      :floor_div -> "//"
      :mod -> "%"
      :pow -> "^"
      :concat -> ".."
      :eq -> "=="
      :ne -> "~="
      :lt -> "<"
      :gt -> ">"
      :le -> "<="
      :ge -> ">="
      :and -> "and"
      :or -> "or"
      _ -> ""
    end
  end

  defp format_unop(op) do
    case op do
      :not -> "not "
      :neg -> "-"
      :len -> "#"
      _ -> ""
    end
  end

  defp format_table_key(key, level, indent_size) do
    case key do
      %Expr.String{value: s} ->
        # If it's a valid identifier, use shorthand
        if valid_identifier?(s) do
          s
        else
          "[#{do_print(key, level, indent_size)}]"
        end

      _ ->
        "[#{do_print(key, level, indent_size)}]"
    end
  end

  defp valid_identifier?(s) do
    # Check if string is a valid Lua identifier
    Regex.match?(~r/^[a-zA-Z_][a-zA-Z0-9_]*$/, s) and not lua_keyword?(s)
  end

  defp lua_keyword?(s) do
    s in ~w(and break do else elseif end false for function goto if in local nil not or repeat return then true until while)
  end

  # "a.b.c" for plain declarations, "a.b:c" when is_method is set —
  # previously the colon was lost, turning method declarations (with
  # their implicit self) into plain dotted functions.
  # NOTE(review): if the parser also stores an explicit "self" in
  # `params` for methods, it will be printed out here — confirm upstream.
  defp format_func_name(name, _is_method) when is_binary(name) do
    name
  end

  defp format_func_name(parts, is_method) when is_list(parts) do
    if is_method do
      {method, prefix} = List.pop_at(parts, -1)
      Enum.join(prefix, ".") <> ":" <> method
    else
      Enum.join(parts, ".")
    end
  end
end

# lib/lua/ast/statement.ex (module continues past this view)
defmodule Lua.AST.Statement do
  @moduledoc """
  Statement AST nodes for Lua.

  All statement nodes include a `meta` field for position tracking.
  """

  alias Lua.AST.{Meta, Expr, Block}

  defmodule Assign do
    @moduledoc """
    Represents an assignment statement: `targets = values`

    Both targets and values can be lists for multiple assignment:
    `a, b = 1, 2`
    """
    defstruct [:targets, :values, :meta]

    @type t :: %__MODULE__{
            targets: [Expr.t()],
            values: [Expr.t()],
            meta: Meta.t() | nil
          }
  end

  defmodule Local do
    @moduledoc """
    Represents a local variable declaration: `local names = values`

    Values can be empty for declaration without initialization.
+ """ + defstruct [:names, :values, :meta] + + @type t :: %__MODULE__{ + names: [String.t()], + values: [Expr.t()], + meta: Meta.t() | nil + } + end + + defmodule LocalFunc do + @moduledoc """ + Represents a local function declaration: `local function name(params) body end` + """ + defstruct [:name, :params, :body, :meta] + + @type t :: %__MODULE__{ + name: String.t(), + params: [Expr.Function.param()], + body: Block.t(), + meta: Meta.t() | nil + } + end + + defmodule FuncDecl do + @moduledoc """ + Represents a function declaration: `function name(params) body end` + + The name can be a path for nested names: + - `function foo() end` -> name: `["foo"]`, is_method: false + - `function a.b.c() end` -> name: `["a", "b", "c"]`, is_method: false + - `function obj:method() end` -> name: `["obj", "method"]`, is_method: true + + When `is_method` is true, an implicit `self` parameter is added. + """ + defstruct [:name, :params, :body, :is_method, :meta] + + @type t :: %__MODULE__{ + name: [String.t()], + params: [Expr.Function.param()], + body: Block.t(), + is_method: boolean(), + meta: Meta.t() | nil + } + end + + defmodule CallStmt do + @moduledoc """ + Represents a function call as a statement. + + In Lua, function calls can be expressions or statements. + """ + defstruct [:call, :meta] + + @type t :: %__MODULE__{ + call: Expr.Call.t() | Expr.MethodCall.t(), + meta: Meta.t() | nil + } + end + + defmodule If do + @moduledoc """ + Represents an if statement with optional elseif and else clauses. 
+ + ```lua + if condition then + block + elseif condition2 then + block2 + else + else_block + end + ``` + """ + defstruct [:condition, :then_block, :elseifs, :else_block, :meta] + + @type elseif_clause :: {Expr.t(), Block.t()} + + @type t :: %__MODULE__{ + condition: Expr.t(), + then_block: Block.t(), + elseifs: [elseif_clause()], + else_block: Block.t() | nil, + meta: Meta.t() | nil + } + end + + defmodule While do + @moduledoc """ + Represents a while loop: `while condition do block end` + """ + defstruct [:condition, :body, :meta] + + @type t :: %__MODULE__{ + condition: Expr.t(), + body: Block.t(), + meta: Meta.t() | nil + } + end + + defmodule Repeat do + @moduledoc """ + Represents a repeat-until loop: `repeat block until condition` + """ + defstruct [:body, :condition, :meta] + + @type t :: %__MODULE__{ + body: Block.t(), + condition: Expr.t(), + meta: Meta.t() | nil + } + end + + defmodule ForNum do + @moduledoc """ + Represents a numeric for loop: `for var = start, limit, step do block end` + + The step is optional and defaults to 1. + """ + defstruct [:var, :start, :limit, :step, :body, :meta] + + @type t :: %__MODULE__{ + var: String.t(), + start: Expr.t(), + limit: Expr.t(), + step: Expr.t() | nil, + body: Block.t(), + meta: Meta.t() | nil + } + end + + defmodule ForIn do + @moduledoc """ + Represents a generic for loop: `for vars in exprs do block end` + + ```lua + for k, v in pairs(t) do + -- block + end + ``` + """ + defstruct [:vars, :iterators, :body, :meta] + + @type t :: %__MODULE__{ + vars: [String.t()], + iterators: [Expr.t()], + body: Block.t(), + meta: Meta.t() | nil + } + end + + defmodule Do do + @moduledoc """ + Represents a do block: `do block end` + + Used to create a new scope. + """ + defstruct [:body, :meta] + + @type t :: %__MODULE__{ + body: Block.t(), + meta: Meta.t() | nil + } + end + + defmodule Return do + @moduledoc """ + Represents a return statement: `return exprs` + + Can return multiple values. 
+ """ + defstruct [:values, :meta] + + @type t :: %__MODULE__{ + values: [Expr.t()], + meta: Meta.t() | nil + } + end + + defmodule Break do + @moduledoc """ + Represents a break statement: `break` + """ + defstruct [:meta] + @type t :: %__MODULE__{meta: Meta.t() | nil} + end + + defmodule Goto do + @moduledoc """ + Represents a goto statement: `goto label` + + Introduced in Lua 5.2. + """ + defstruct [:label, :meta] + + @type t :: %__MODULE__{ + label: String.t(), + meta: Meta.t() | nil + } + end + + defmodule Label do + @moduledoc """ + Represents a label: `::label::` + + Introduced in Lua 5.2. + """ + defstruct [:name, :meta] + + @type t :: %__MODULE__{ + name: String.t(), + meta: Meta.t() | nil + } + end + + @type t :: + Assign.t() + | Local.t() + | LocalFunc.t() + | FuncDecl.t() + | CallStmt.t() + | If.t() + | While.t() + | Repeat.t() + | ForNum.t() + | ForIn.t() + | Do.t() + | Return.t() + | Break.t() + | Goto.t() + | Label.t() +end diff --git a/lib/lua/ast/walker.ex b/lib/lua/ast/walker.ex new file mode 100644 index 0000000..161ea74 --- /dev/null +++ b/lib/lua/ast/walker.ex @@ -0,0 +1,348 @@ +defmodule Lua.AST.Walker do + @moduledoc """ + AST traversal utilities using the visitor pattern. + + Provides functions for walking, mapping, and reducing over AST nodes. + + ## Examples + + # Simple traversal (side effects) + Walker.walk(ast, fn node -> + IO.inspect(node) + end) + + # Transform AST (double all numbers) + Walker.map(ast, fn + %Expr.Number{value: n} = node -> %{node | value: n * 2} + node -> node + end) + + # Accumulate values (collect all variable names) + Walker.reduce(ast, [], fn + %Expr.Var{name: name}, acc -> [name | acc] + _node, acc -> acc + end) + + # Post-order traversal + Walker.walk(ast, fn node -> ... 
end, order: :post) + """ + + alias Lua.AST.{Chunk, Block, Expr, Statement} + + @type ast_node :: + Chunk.t() + | Block.t() + | Expr.t() + | Statement.t() + + @type visitor :: (ast_node -> any()) + @type mapper :: (ast_node -> ast_node) + @type reducer :: (ast_node, acc :: any() -> any()) + + @type order :: :pre | :post + + @doc """ + Walks the AST, calling the visitor function for each node. + + The visitor is called in pre-order by default (parent before children). + Use `order: :post` for post-order traversal (children before parent). + + ## Options + + - `:order` - `:pre` (default) or `:post` + + ## Examples + + Walker.walk(ast, fn + %Expr.Number{value: n} -> IO.puts("Found number: \#{n}") + _node -> :ok + end) + """ + @spec walk(ast_node, visitor, keyword()) :: :ok + def walk(node, visitor, opts \\ []) do + order = Keyword.get(opts, :order, :pre) + do_walk(node, visitor, order) + :ok + end + + @doc """ + Maps over the AST, transforming nodes with the mapper function. + + The mapper is called in post-order (children before parent) to ensure + transformations propagate upward correctly. + + ## Examples + + # Double all numbers + Walker.map(ast, fn + %Expr.Number{value: n} = node -> %{node | value: n * 2} + node -> node + end) + """ + @spec map(ast_node, mapper) :: ast_node + def map(node, mapper) do + do_map(node, mapper) + end + + @doc """ + Reduces the AST to a single value by calling the reducer function for each node. + + The reducer is called in pre-order by default. 
+ + ## Examples + + # Collect all variable names + Walker.reduce(ast, [], fn + %Expr.Var{name: name}, acc -> [name | acc] + _node, acc -> acc + end) + + # Count all nodes + Walker.reduce(ast, 0, fn _node, acc -> acc + 1 end) + """ + @spec reduce(ast_node, acc, reducer) :: acc when acc: any() + def reduce(node, initial, reducer) do + do_reduce(node, initial, reducer) + end + + # Private implementation + + # Walk in pre-order or post-order + defp do_walk(node, visitor, :pre) do + visitor.(node) + walk_children(node, visitor, :pre) + end + + defp do_walk(node, visitor, :post) do + walk_children(node, visitor, :post) + visitor.(node) + end + + defp walk_children(node, visitor, order) do + children(node) + |> Enum.each(fn child -> do_walk(child, visitor, order) end) + end + + # Map (post-order to transform bottom-up) + defp do_map(node, mapper) do + mapped_children = + case node do + # Chunk + %Chunk{block: block} = chunk -> + %{chunk | block: do_map(block, mapper)} + + # Block + %Block{stmts: stmts} = block -> + %{block | stmts: Enum.map(stmts, &do_map(&1, mapper))} + + # Expressions + %Expr.BinOp{left: left, right: right} = expr -> + %{expr | left: do_map(left, mapper), right: do_map(right, mapper)} + + %Expr.UnOp{operand: operand} = expr -> + %{expr | operand: do_map(operand, mapper)} + + %Expr.Table{fields: fields} = expr -> + mapped_fields = + Enum.map(fields, fn + {:list, value} -> {:list, do_map(value, mapper)} + {:record, key, value} -> {:record, do_map(key, mapper), do_map(value, mapper)} + end) + + %{expr | fields: mapped_fields} + + %Expr.Call{func: func, args: args} = expr -> + %{expr | func: do_map(func, mapper), args: Enum.map(args, &do_map(&1, mapper))} + + %Expr.MethodCall{object: obj, args: args} = expr -> + %{expr | object: do_map(obj, mapper), args: Enum.map(args, &do_map(&1, mapper))} + + %Expr.Index{table: table, key: key} = expr -> + %{expr | table: do_map(table, mapper), key: do_map(key, mapper)} + + %Expr.Property{table: table} = expr -> + %{expr 
| table: do_map(table, mapper)} + + %Expr.Function{body: body} = expr -> + %{expr | body: do_map(body, mapper)} + + # Statements + %Statement.Assign{targets: targets, values: values} = stmt -> + %{ + stmt + | targets: Enum.map(targets, &do_map(&1, mapper)), + values: Enum.map(values, &do_map(&1, mapper)) + } + + %Statement.Local{values: values} = stmt when is_list(values) -> + %{stmt | values: Enum.map(values, &do_map(&1, mapper))} + + %Statement.Local{} = stmt -> + stmt + + %Statement.LocalFunc{body: body} = stmt -> + %{stmt | body: do_map(body, mapper)} + + %Statement.FuncDecl{body: body} = stmt -> + %{stmt | body: do_map(body, mapper)} + + %Statement.CallStmt{call: call} = stmt -> + %{stmt | call: do_map(call, mapper)} + + %Statement.If{ + condition: cond, + then_block: then_block, + elseifs: elseifs, + else_block: else_block + } = stmt -> + mapped_elseifs = + Enum.map(elseifs, fn {c, b} -> {do_map(c, mapper), do_map(b, mapper)} end) + + mapped_else = if else_block, do: do_map(else_block, mapper), else: nil + + %{ + stmt + | condition: do_map(cond, mapper), + then_block: do_map(then_block, mapper), + elseifs: mapped_elseifs, + else_block: mapped_else + } + + %Statement.While{condition: cond, body: body} = stmt -> + %{stmt | condition: do_map(cond, mapper), body: do_map(body, mapper)} + + %Statement.Repeat{body: body, condition: cond} = stmt -> + %{stmt | body: do_map(body, mapper), condition: do_map(cond, mapper)} + + %Statement.ForNum{var: _var, start: start, limit: limit, step: step, body: body} = stmt -> + mapped_step = if step, do: do_map(step, mapper), else: nil + + %{ + stmt + | start: do_map(start, mapper), + limit: do_map(limit, mapper), + step: mapped_step, + body: do_map(body, mapper) + } + + %Statement.ForIn{vars: _vars, iterators: iterators, body: body} = stmt -> + %{ + stmt + | iterators: Enum.map(iterators, &do_map(&1, mapper)), + body: do_map(body, mapper) + } + + %Statement.Do{body: body} = stmt -> + %{stmt | body: do_map(body, mapper)} + + 
%Statement.Return{values: values} = stmt -> + %{stmt | values: Enum.map(values, &do_map(&1, mapper))} + + # Leaf nodes (no children) + _ -> + node + end + + mapper.(mapped_children) + end + + # Reduce (pre-order accumulation) + defp do_reduce(node, acc, reducer) do + acc = reducer.(node, acc) + + children(node) + |> Enum.reduce(acc, fn child, acc -> do_reduce(child, acc, reducer) end) + end + + # Extract children for traversal + defp children(node) do + case node do + # Chunk + %Chunk{block: block} -> + [block] + + # Block + %Block{stmts: stmts} -> + stmts + + # Expressions with children + %Expr.BinOp{left: left, right: right} -> + [left, right] + + %Expr.UnOp{operand: operand} -> + [operand] + + %Expr.Table{fields: fields} -> + extract_table_fields(fields) + + %Expr.Call{func: func, args: args} -> + [func | args] + + %Expr.MethodCall{object: obj, args: args} -> + [obj | args] + + %Expr.Index{table: table, key: key} -> + [table, key] + + %Expr.Property{table: table} -> + [table] + + %Expr.Function{body: body} -> + [body] + + # Statements with children + %Statement.Assign{targets: targets, values: values} -> + targets ++ values + + %Statement.Local{values: values} when is_list(values) -> + values + + %Statement.LocalFunc{body: body} -> + [body] + + %Statement.FuncDecl{body: body} -> + [body] + + %Statement.CallStmt{call: call} -> + [call] + + %Statement.If{ + condition: cond, + then_block: then_block, + elseifs: elseifs, + else_block: else_block + } -> + elseif_nodes = Enum.flat_map(elseifs, fn {c, b} -> [c, b] end) + [cond, then_block | elseif_nodes] ++ if(else_block, do: [else_block], else: []) + + %Statement.While{condition: cond, body: body} -> + [cond, body] + + %Statement.Repeat{body: body, condition: cond} -> + [body, cond] + + %Statement.ForNum{start: start, limit: limit, step: step, body: body} -> + [start, limit] ++ if(step, do: [step], else: []) ++ [body] + + %Statement.ForIn{iterators: iterators, body: body} -> + iterators ++ [body] + + 
%Statement.Do{body: body} -> + [body] + + %Statement.Return{values: values} -> + values + + # Leaf nodes (no children) + _ -> + [] + end + end + + defp extract_table_fields(fields) do + Enum.flat_map(fields, fn + {:list, value} -> [value] + {:record, key, value} -> [key, value] + end) + end +end diff --git a/lib/lua/lexer.ex b/lib/lua/lexer.ex new file mode 100644 index 0000000..03400de --- /dev/null +++ b/lib/lua/lexer.ex @@ -0,0 +1,515 @@ +defmodule Lua.Lexer do + @moduledoc """ + Hand-written lexer for Lua 5.3 using Elixir binary pattern matching. + + Tokenizes Lua source code into a list of tokens with position tracking. + """ + + @type position :: %{line: pos_integer(), column: pos_integer(), byte_offset: non_neg_integer()} + @type token :: + {:keyword, atom(), position()} + | {:identifier, String.t(), position()} + | {:number, number(), position()} + | {:string, String.t(), position()} + | {:operator, atom(), position()} + | {:delimiter, atom(), position()} + | {:eof, position()} + + @keywords ~w( + and break do else elseif end false for function goto if in + local nil not or repeat return then true until while + ) + + @doc """ + Tokenizes Lua source code into a list of tokens. 
+ + ## Examples + + iex> Lua.Lexer.tokenize("local x = 42") + {:ok, [ + {:keyword, :local, %{line: 1, column: 1, byte_offset: 0}}, + {:identifier, "x", %{line: 1, column: 7, byte_offset: 6}}, + {:operator, :assign, %{line: 1, column: 9, byte_offset: 8}}, + {:number, 42, %{line: 1, column: 11, byte_offset: 10}}, + {:eof, %{line: 1, column: 13, byte_offset: 12}} + ]} + """ + @spec tokenize(String.t()) :: {:ok, [token()]} | {:error, term()} + def tokenize(code) when is_binary(code) do + pos = %{line: 1, column: 1, byte_offset: 0} + do_tokenize(code, [], pos) + end + + # End of input + defp do_tokenize(<<>>, acc, pos) do + {:ok, Enum.reverse([{:eof, pos} | acc])} + end + + # Whitespace (space, tab) + defp do_tokenize(<>, acc, pos) when c in [?\s, ?\t] do + new_pos = advance_column(pos, 1) + do_tokenize(rest, acc, new_pos) + end + + # Newline (LF) + defp do_tokenize(<>, acc, pos) do + new_pos = %{line: pos.line + 1, column: 1, byte_offset: pos.byte_offset + 1} + do_tokenize(rest, acc, new_pos) + end + + # Carriage return (CR, or CRLF) + defp do_tokenize(<>, acc, pos) do + new_pos = %{line: pos.line + 1, column: 1, byte_offset: pos.byte_offset + 2} + do_tokenize(rest, acc, new_pos) + end + + defp do_tokenize(<>, acc, pos) do + new_pos = %{line: pos.line + 1, column: 1, byte_offset: pos.byte_offset + 1} + do_tokenize(rest, acc, new_pos) + end + + # Comments: single-line (--) or multi-line (--[[ ... ]]) + defp do_tokenize(<<"--[", rest::binary>>, acc, pos) do + # Check if it's a multi-line comment + case rest do + <<"[", _::binary>> -> + # Multi-line comment --[[ ... 
]] + scan_multiline_comment(rest, acc, advance_column(pos, 3), 0) + + _ -> + # Single-line comment starting with --[ + scan_single_line_comment(rest, acc, advance_column(pos, 3)) + end + end + + defp do_tokenize(<<"--", rest::binary>>, acc, pos) do + scan_single_line_comment(rest, acc, advance_column(pos, 2)) + end + + # Strings: double-quoted + defp do_tokenize(<>, acc, pos) do + scan_string(rest, "", acc, advance_column(pos, 1), pos, ?") + end + + # Strings: single-quoted + defp do_tokenize(<>, acc, pos) do + scan_string(rest, "", acc, advance_column(pos, 1), pos, ?') + end + + # Strings: multi-line [[ ... ]] or [=[ ... ]=] + defp do_tokenize(<<"[", rest::binary>>, acc, pos) do + case scan_long_bracket(rest, 0) do + {:ok, equals, after_bracket} -> + scan_long_string(after_bracket, "", acc, advance_column(pos, 2 + equals), pos, equals) + + :error -> + # Not a long string, treat as delimiter + token = {:delimiter, :lbracket, pos} + do_tokenize(rest, [token | acc], advance_column(pos, 1)) + end + end + + # Numbers: hex (0x, 0X) + defp do_tokenize(<<"0", x, rest::binary>>, acc, pos) when x in [?x, ?X] do + scan_hex_number(rest, "", acc, advance_column(pos, 2), pos) + end + + # Numbers: decimal or float + defp do_tokenize(<>, acc, pos) when c in ?0..?9 do + scan_number(<>, "", acc, pos, pos) + end + + # Three-character operators + defp do_tokenize(<<"...", rest::binary>>, acc, pos) do + token = {:operator, :vararg, pos} + do_tokenize(rest, [token | acc], advance_column(pos, 3)) + end + + # Two-character operators + defp do_tokenize(<<"==", rest::binary>>, acc, pos) do + token = {:operator, :eq, pos} + do_tokenize(rest, [token | acc], advance_column(pos, 2)) + end + + defp do_tokenize(<<"~=", rest::binary>>, acc, pos) do + token = {:operator, :ne, pos} + do_tokenize(rest, [token | acc], advance_column(pos, 2)) + end + + defp do_tokenize(<<"<=", rest::binary>>, acc, pos) do + token = {:operator, :le, pos} + do_tokenize(rest, [token | acc], advance_column(pos, 2)) + end 
+ + defp do_tokenize(<<">=", rest::binary>>, acc, pos) do + token = {:operator, :ge, pos} + do_tokenize(rest, [token | acc], advance_column(pos, 2)) + end + + defp do_tokenize(<<"..", rest::binary>>, acc, pos) do + token = {:operator, :concat, pos} + do_tokenize(rest, [token | acc], advance_column(pos, 2)) + end + + defp do_tokenize(<<"::", rest::binary>>, acc, pos) do + token = {:delimiter, :double_colon, pos} + do_tokenize(rest, [token | acc], advance_column(pos, 2)) + end + + defp do_tokenize(<<"//", rest::binary>>, acc, pos) do + token = {:operator, :floordiv, pos} + do_tokenize(rest, [token | acc], advance_column(pos, 2)) + end + + # Single-character operators and delimiters + defp do_tokenize(<>, acc, pos) when c in [?+, ?-, ?*, ?/, ?%, ?^, ?#] do + op = + case c do + ?+ -> :add + ?- -> :sub + ?* -> :mul + ?/ -> :div + ?% -> :mod + ?^ -> :pow + ?# -> :len + end + + token = {:operator, op, pos} + do_tokenize(rest, [token | acc], advance_column(pos, 1)) + end + + defp do_tokenize(<>, acc, pos) when c in [?<, ?>, ?=] do + op = + case c do + ?< -> :lt + ?> -> :gt + ?= -> :assign + end + + token = {:operator, op, pos} + do_tokenize(rest, [token | acc], advance_column(pos, 1)) + end + + defp do_tokenize(<>, acc, pos) + when c in [?(, ?), ?{, ?}, ?], ?;, ?,, ?., ?:] do + delim = + case c do + ?( -> :lparen + ?) -> :rparen + ?{ -> :lbrace + ?} -> :rbrace + ?] -> :rbracket + ?; -> :semicolon + ?, -> :comma + ?. 
-> :dot + ?: -> :colon + end + + token = {:delimiter, delim, pos} + do_tokenize(rest, [token | acc], advance_column(pos, 1)) + end + + # Identifiers and keywords + defp do_tokenize(<>, acc, pos) + when c in ?a..?z or c in ?A..?Z or c == ?_ do + scan_identifier(<>, "", acc, pos, pos) + end + + # Unexpected character + defp do_tokenize(<>, _acc, pos) do + {:error, {:unexpected_character, c, pos}} + end + + # Scan single-line comment (skip until newline) + defp scan_single_line_comment(<>, acc, pos) do + new_pos = %{line: pos.line + 1, column: 1, byte_offset: pos.byte_offset + 1} + do_tokenize(rest, acc, new_pos) + end + + defp scan_single_line_comment(<>, acc, pos) do + new_pos = %{line: pos.line + 1, column: 1, byte_offset: pos.byte_offset + 2} + do_tokenize(rest, acc, new_pos) + end + + defp scan_single_line_comment(<>, acc, pos) do + new_pos = %{line: pos.line + 1, column: 1, byte_offset: pos.byte_offset + 1} + do_tokenize(rest, acc, new_pos) + end + + defp scan_single_line_comment(<<>>, acc, pos) do + {:ok, Enum.reverse([{:eof, pos} | acc])} + end + + defp scan_single_line_comment(<<_, rest::binary>>, acc, pos) do + scan_single_line_comment(rest, acc, advance_column(pos, 1)) + end + + # Scan multi-line comment --[[ ... ]] or --[=[ ... 
]=] + defp scan_multiline_comment(<<"[", rest::binary>>, acc, pos, level) do + scan_multiline_comment_content(rest, acc, advance_column(pos, 1), level) + end + + defp scan_multiline_comment_content(<<"]", rest::binary>>, acc, pos, level) do + case try_close_long_bracket(rest, level, 0) do + {:ok, after_bracket} -> + new_pos = advance_column(pos, 2 + level) + do_tokenize(after_bracket, acc, new_pos) + + :error -> + scan_multiline_comment_content(rest, acc, advance_column(pos, 1), level) + end + end + + defp scan_multiline_comment_content(<>, acc, pos, level) do + new_pos = %{line: pos.line + 1, column: 1, byte_offset: pos.byte_offset + 1} + scan_multiline_comment_content(rest, acc, new_pos, level) + end + + defp scan_multiline_comment_content(<<>>, _acc, pos, _level) do + {:error, {:unclosed_comment, pos}} + end + + defp scan_multiline_comment_content(<<_, rest::binary>>, acc, pos, level) do + scan_multiline_comment_content(rest, acc, advance_column(pos, 1), level) + end + + # Scan quoted string + defp scan_string(<>, str_acc, acc, pos, start_pos, quote) do + # Closing quote + token = {:string, str_acc, start_pos} + do_tokenize(rest, [token | acc], pos) + end + + defp scan_string(<>, str_acc, acc, pos, start_pos, quote) do + # Escape sequence + case escape_char(esc) do + {:ok, char} -> + scan_string(rest, str_acc <> <>, acc, advance_column(pos, 2), start_pos, quote) + + :error -> + # Invalid escape, but continue scanning + scan_string(rest, str_acc <> <>, acc, advance_column(pos, 2), start_pos, quote) + end + end + + defp scan_string(<>, _str_acc, _acc, pos, _start_pos, _quote) do + {:error, {:unclosed_string, pos}} + end + + defp scan_string(<<>>, _str_acc, _acc, pos, _start_pos, _quote) do + {:error, {:unclosed_string, pos}} + end + + defp scan_string(<>, str_acc, acc, pos, start_pos, quote) do + scan_string(rest, str_acc <> <>, acc, advance_column(pos, 1), start_pos, quote) + end + + # Escape character mapping + defp escape_char(?a), do: {:ok, ?\a} + defp 
escape_char(?b), do: {:ok, ?\b} + defp escape_char(?f), do: {:ok, ?\f} + defp escape_char(?n), do: {:ok, ?\n} + defp escape_char(?r), do: {:ok, ?\r} + defp escape_char(?t), do: {:ok, ?\t} + defp escape_char(?v), do: {:ok, ?\v} + defp escape_char(?\\), do: {:ok, ?\\} + defp escape_char(?"), do: {:ok, ?"} + defp escape_char(?'), do: {:ok, ?'} + defp escape_char(_), do: :error + + # Scan long bracket for level: [[ or [=[ or [==[ etc. + defp scan_long_bracket(rest, equals) do + case rest do + <<"=", after_eq::binary>> -> + scan_long_bracket(after_eq, equals + 1) + + <<"[", after_bracket::binary>> -> + {:ok, equals, after_bracket} + + _ -> + :error + end + end + + # Try to close long bracket: ]] or ]=] or ]==] etc. + defp try_close_long_bracket(rest, target_level, current_level) do + if current_level == target_level do + case rest do + <<"]", after_bracket::binary>> -> + {:ok, after_bracket} + + _ -> + :error + end + else + case rest do + <<"=", after_eq::binary>> -> + try_close_long_bracket(after_eq, target_level, current_level + 1) + + _ -> + :error + end + end + end + + # Scan long string [[ ... ]] or [=[ ... 
]=] + defp scan_long_string(<<"]", rest::binary>>, str_acc, acc, pos, start_pos, level) do + case try_close_long_bracket(rest, level, 0) do + {:ok, after_bracket} -> + token = {:string, str_acc, start_pos} + new_pos = advance_column(pos, 2 + level) + do_tokenize(after_bracket, [token | acc], new_pos) + + :error -> + scan_long_string(rest, str_acc <> "]", acc, advance_column(pos, 1), start_pos, level) + end + end + + defp scan_long_string(<>, str_acc, acc, pos, start_pos, level) do + new_pos = %{line: pos.line + 1, column: 1, byte_offset: pos.byte_offset + 1} + scan_long_string(rest, str_acc <> "\n", acc, new_pos, start_pos, level) + end + + defp scan_long_string(<<>>, _str_acc, _acc, pos, _start_pos, _level) do + {:error, {:unclosed_long_string, pos}} + end + + defp scan_long_string(<>, str_acc, acc, pos, start_pos, level) do + scan_long_string(rest, str_acc <> <>, acc, advance_column(pos, 1), start_pos, level) + end + + # Scan identifier or keyword + defp scan_identifier(<>, id_acc, acc, pos, start_pos) + when c in ?a..?z or c in ?A..?Z or c in ?0..?9 or c == ?_ do + scan_identifier(rest, id_acc <> <>, acc, advance_column(pos, 1), start_pos) + end + + defp scan_identifier(rest, id_acc, acc, pos, start_pos) do + # Check if it's a keyword + token = + if id_acc in @keywords do + {:keyword, String.to_atom(id_acc), start_pos} + else + {:identifier, id_acc, start_pos} + end + + do_tokenize(rest, [token | acc], pos) + end + + # Scan decimal number + defp scan_number(<>, num_acc, acc, pos, start_pos) + when c in ?0..?9 do + scan_number(rest, num_acc <> <>, acc, advance_column(pos, 1), start_pos) + end + + defp scan_number(<>, num_acc, acc, pos, start_pos) do + # Trailing dot is not part of the number + finalize_number(num_acc, <<".">>, acc, pos, start_pos) + end + + defp scan_number(<<".", c, rest::binary>>, num_acc, acc, pos, start_pos) + when c in ?0..?9 do + # Decimal point with digit following + scan_float(rest, num_acc <> "." 
<> <>, acc, advance_column(pos, 2), start_pos) + end + + defp scan_number(<<".", rest::binary>>, num_acc, acc, pos, start_pos) do + # Decimal point but no digit following - finalize number, reprocess "." + finalize_number(num_acc, <<".", rest::binary>>, acc, pos, start_pos) + end + + defp scan_number(<>, num_acc, acc, pos, start_pos) + when c in [?e, ?E] do + # Scientific notation + scan_exponent(<>, num_acc, acc, pos, start_pos) + end + + defp scan_number(rest, num_acc, acc, pos, start_pos) do + finalize_number(num_acc, rest, acc, pos, start_pos) + end + + # Scan float part (after decimal point) + defp scan_float(<>, num_acc, acc, pos, start_pos) + when c in ?0..?9 do + scan_float(rest, num_acc <> <>, acc, advance_column(pos, 1), start_pos) + end + + defp scan_float(<>, num_acc, acc, pos, start_pos) + when c in [?e, ?E] do + scan_exponent(<>, num_acc, acc, pos, start_pos) + end + + defp scan_float(rest, num_acc, acc, pos, start_pos) do + finalize_number(num_acc, rest, acc, pos, start_pos) + end + + # Scan scientific notation exponent + defp scan_exponent(<>, num_acc, acc, pos, start_pos) + when c in [?e, ?E] and sign in [?+, ?-] do + scan_exponent_digits(rest, num_acc <> <>, acc, advance_column(pos, 2), start_pos) + end + + defp scan_exponent(<>, num_acc, acc, pos, start_pos) + when c in [?e, ?E] do + scan_exponent_digits(rest, num_acc <> <>, acc, advance_column(pos, 1), start_pos) + end + + defp scan_exponent_digits(<>, num_acc, acc, pos, start_pos) + when c in ?0..?9 do + scan_exponent_digits(rest, num_acc <> <>, acc, advance_column(pos, 1), start_pos) + end + + defp scan_exponent_digits(rest, num_acc, acc, pos, start_pos) do + finalize_number(num_acc, rest, acc, pos, start_pos) + end + + # Scan hexadecimal number (0x...) 
+ defp scan_hex_number(<>, hex_acc, acc, pos, start_pos) + when c in ?0..?9 or c in ?a..?f or c in ?A..?F do + scan_hex_number(rest, hex_acc <> <>, acc, advance_column(pos, 1), start_pos) + end + + defp scan_hex_number(rest, hex_acc, acc, pos, start_pos) do + case Integer.parse(hex_acc, 16) do + {num, ""} -> + token = {:number, num, start_pos} + do_tokenize(rest, [token | acc], pos) + + _ -> + {:error, {:invalid_hex_number, start_pos}} + end + end + + # Finalize number token + defp finalize_number(num_str, rest, acc, pos, start_pos) do + case parse_number(num_str) do + {:ok, num} -> + token = {:number, num, start_pos} + do_tokenize(rest, [token | acc], pos) + + {:error, reason} -> + {:error, {reason, start_pos}} + end + end + + # Parse number string to integer or float + defp parse_number(num_str) do + cond do + String.contains?(num_str, ".") or String.contains?(num_str, "e") or + String.contains?(num_str, "E") -> + case Float.parse(num_str) do + {num, ""} -> {:ok, num} + _ -> {:error, :invalid_number} + end + + true -> + {num, ""} = Integer.parse(num_str) + {:ok, num} + end + end + + # Position tracking helpers + defp advance_column(pos, n) do + %{pos | column: pos.column + n, byte_offset: pos.byte_offset + n} + end +end diff --git a/lib/lua/parser.ex b/lib/lua/parser.ex new file mode 100644 index 0000000..c0006ea --- /dev/null +++ b/lib/lua/parser.ex @@ -0,0 +1,1172 @@ +defmodule Lua.Parser do + @moduledoc """ + Hand-written recursive descent parser for Lua 5.3. + + Uses Pratt parsing for operator precedence in expressions. + """ + + alias Lua.AST.{Meta, Expr, Statement, Block, Chunk} + alias Lua.Parser.Pratt + alias Lua.Lexer + + @type token :: Lexer.token() + @type parse_result(t) :: {:ok, t, [token()]} | {:error, term()} + + alias Lua.Parser.Error + + @doc """ + Parses Lua source code into an AST. + + Returns `{:ok, chunk}` on success or `{:error, formatted_error}` on failure. + The error is a beautifully formatted string with context and suggestions. 
+ + ## Examples + + iex> Lua.Parser.parse("local x = 42") + {:ok, %Lua.AST.Chunk{...}} + + iex> {:error, error_msg} = Lua.Parser.parse("if x then") + iex> String.contains?(error_msg, "Parse Error") + true + """ + @spec parse(String.t()) :: {:ok, Chunk.t()} | {:error, String.t()} + def parse(code) when is_binary(code) do + case Lexer.tokenize(code) do + {:ok, tokens} -> + case parse_chunk(tokens) do + {:ok, chunk} -> + {:ok, chunk} + + {:error, reason} -> + error = convert_error(reason, code) + formatted = Error.format(error, code) + {:error, formatted} + end + + {:error, reason} -> + error = convert_lexer_error(reason, code) + formatted = Error.format(error, code) + {:error, formatted} + end + end + + @doc """ + Parses Lua source code and returns raw error information. + + Use this when you want to handle errors programmatically instead of + displaying them to users. + """ + @spec parse_raw(String.t()) :: {:ok, Chunk.t()} | {:error, term()} + def parse_raw(code) when is_binary(code) do + case Lexer.tokenize(code) do + {:ok, tokens} -> + parse_chunk(tokens) + + {:error, reason} -> + {:error, {:lexer_error, reason}} + end + end + + @doc """ + Parses a chunk (top-level block) from a token list. 
+ """ + @spec parse_chunk([token()]) :: {:ok, Chunk.t()} | {:error, term()} + def parse_chunk(tokens) do + case parse_block(tokens) do + {:ok, block, rest} -> + case rest do + [{:eof, _}] -> + {:ok, Chunk.new(block)} + + [{type, _, pos} | _] -> + {:error, {:unexpected_token, type, pos, "Expected end of input"}} + end + + {:error, reason} -> + {:error, reason} + end + end + + # Block parsing (sequence of statements) + defp parse_block(tokens) do + parse_block_acc(tokens, []) + end + + defp parse_block_acc(tokens, stmts) do + case peek(tokens) do + # Block terminators + {:keyword, terminator, _} when terminator in [:end, :else, :elseif, :until] -> + {:ok, Block.new(Enum.reverse(stmts)), tokens} + + {:eof, _} -> + {:ok, Block.new(Enum.reverse(stmts)), tokens} + + _ -> + case parse_stmt(tokens) do + {:ok, stmt, rest} -> + parse_block_acc(rest, [stmt | stmts]) + + {:error, reason} -> + {:error, reason} + end + end + end + + # Statement parsing (placeholder - will be implemented in Phase 3) + defp parse_stmt(tokens) do + case peek(tokens) do + {:keyword, :return, _} -> + parse_return(tokens) + + {:keyword, :local, _} -> + parse_local(tokens) + + {:keyword, :if, _} -> + parse_if(tokens) + + {:keyword, :while, _} -> + parse_while(tokens) + + {:keyword, :repeat, _} -> + parse_repeat(tokens) + + {:keyword, :for, _} -> + parse_for(tokens) + + {:keyword, :function, _} -> + parse_function_decl(tokens) + + {:keyword, :do, _} -> + parse_do(tokens) + + {:keyword, :break, _} -> + parse_break(tokens) + + {:keyword, :goto, _} -> + parse_goto(tokens) + + {:delimiter, :double_colon, _} -> + parse_label(tokens) + + # Semicolon (statement separator, optional) + {:delimiter, :semicolon, _} -> + {_, rest} = consume(tokens) + parse_stmt(rest) + + _ -> + # Try to parse as assignment or function call + parse_assign_or_call(tokens) + end + end + + # Placeholder implementations for statements (Phase 3) + defp parse_return([{:keyword, :return, pos} | rest]) do + case peek(rest) do + # End of 
block or statement + {:keyword, terminator, _} when terminator in [:end, :else, :elseif, :until] -> + {:ok, %Statement.Return{values: [], meta: Meta.new(pos)}, rest} + + {:eof, _} -> + {:ok, %Statement.Return{values: [], meta: Meta.new(pos)}, rest} + + {:delimiter, :semicolon, _} -> + {_, rest2} = consume(rest) + {:ok, %Statement.Return{values: [], meta: Meta.new(pos)}, rest2} + + _ -> + case parse_expr_list(rest) do + {:ok, exprs, rest2} -> + {:ok, %Statement.Return{values: exprs, meta: Meta.new(pos)}, rest2} + + {:error, reason} -> + {:error, reason} + end + end + end + + defp parse_local([{:keyword, :local, pos} | rest]) do + case peek(rest) do + {:keyword, :function, _} -> + # local function name() ... end + {_, rest2} = consume(rest) + + case expect(rest2, :identifier) do + {:ok, {_, name, _}, rest3} -> + with {:ok, _, rest4} <- expect(rest3, :delimiter, :lparen), + {:ok, params, rest5} <- parse_param_list(rest4), + {:ok, _, rest6} <- expect(rest5, :delimiter, :rparen), + {:ok, body, rest7} <- parse_block(rest6), + {:ok, _, rest8} <- expect(rest7, :keyword, :end) do + {:ok, + %Statement.LocalFunc{name: name, params: params, body: body, meta: Meta.new(pos)}, + rest8} + end + + {:error, reason} -> + {:error, reason} + end + + {:identifier, _, _} -> + # local name1, name2 = expr1, expr2 + case parse_name_list(rest) do + {:ok, names, rest2} -> + case peek(rest2) do + {:operator, :assign, _} -> + {_, rest3} = consume(rest2) + + case parse_expr_list(rest3) do + {:ok, values, rest4} -> + {:ok, %Statement.Local{names: names, values: values, meta: Meta.new(pos)}, + rest4} + + {:error, reason} -> + {:error, reason} + end + + _ -> + # Local without initialization + {:ok, %Statement.Local{names: names, values: [], meta: Meta.new(pos)}, rest2} + end + + {:error, reason} -> + {:error, reason} + end + + _ -> + {:error, + {:unexpected_token, peek(rest), "Expected identifier or 'function' after 'local'"}} + end + end + + defp parse_if([{:keyword, :if, pos} | rest]) do + with 
{:ok, condition, rest2} <- parse_expr(rest), + {:ok, _, rest3} <- expect(rest2, :keyword, :then), + {:ok, then_block, rest4} <- parse_block(rest3), + {:ok, elseifs, else_block, rest5} <- parse_elseifs(rest4) do + case expect(rest5, :keyword, :end) do + {:ok, _, rest6} -> + {:ok, + %Statement.If{ + condition: condition, + then_block: then_block, + elseifs: elseifs, + else_block: else_block, + meta: Meta.new(pos) + }, rest6} + + {:error, reason} -> + {:error, reason} + end + end + end + + defp parse_elseifs(tokens) do + case peek(tokens) do + {:keyword, :elseif, _} -> + {_, rest} = consume(tokens) + + with {:ok, condition, rest2} <- parse_expr(rest), + {:ok, _, rest3} <- expect(rest2, :keyword, :then), + {:ok, block, rest4} <- parse_block(rest3), + {:ok, more_elseifs, else_block, rest5} <- parse_elseifs(rest4) do + {:ok, [{condition, block} | more_elseifs], else_block, rest5} + end + + {:keyword, :else, _} -> + {_, rest} = consume(tokens) + + case parse_block(rest) do + {:ok, else_block, rest2} -> + {:ok, [], else_block, rest2} + + {:error, reason} -> + {:error, reason} + end + + _ -> + {:ok, [], nil, tokens} + end + end + + defp parse_while([{:keyword, :while, pos} | rest]) do + with {:ok, condition, rest2} <- parse_expr(rest), + {:ok, _, rest3} <- expect(rest2, :keyword, :do), + {:ok, body, rest4} <- parse_block(rest3), + {:ok, _, rest5} <- expect(rest4, :keyword, :end) do + {:ok, %Statement.While{condition: condition, body: body, meta: Meta.new(pos)}, rest5} + end + end + + defp parse_repeat([{:keyword, :repeat, pos} | rest]) do + with {:ok, body, rest2} <- parse_block(rest), + {:ok, _, rest3} <- expect(rest2, :keyword, :until), + {:ok, condition, rest4} <- parse_expr(rest3) do + {:ok, %Statement.Repeat{body: body, condition: condition, meta: Meta.new(pos)}, rest4} + end + end + + defp parse_for([{:keyword, :for, pos} | rest]) do + case expect(rest, :identifier) do + {:ok, {_, var, _}, rest2} -> + case peek(rest2) do + {:operator, :assign, _} -> + # Numeric for: 
for var = start, limit, step do ... end + {_, rest3} = consume(rest2) + + with {:ok, start, rest4} <- parse_expr(rest3), + {:ok, _, rest5} <- expect(rest4, :delimiter, :comma), + {:ok, limit, rest6} <- parse_expr(rest5), + {:ok, step, rest7} <- parse_for_step(rest6), + {:ok, _, rest8} <- expect(rest7, :keyword, :do), + {:ok, body, rest9} <- parse_block(rest8), + {:ok, _, rest10} <- expect(rest9, :keyword, :end) do + {:ok, + %Statement.ForNum{ + var: var, + start: start, + limit: limit, + step: step, + body: body, + meta: Meta.new(pos) + }, rest10} + end + + {:delimiter, :comma, _} -> + # Generic for: for var1, var2 in exprs do ... end + parse_generic_for([var], rest2, pos) + + {:keyword, :in, _} -> + # Generic for with single variable + parse_generic_for([var], rest2, pos) + + _ -> + {:error, {:unexpected_token, peek(rest2), "Expected '=' or 'in' after for variable"}} + end + + {:error, reason} -> + {:error, reason} + end + end + + defp parse_for_step(tokens) do + case peek(tokens) do + {:delimiter, :comma, _} -> + {_, rest} = consume(tokens) + parse_expr(rest) + + _ -> + {:ok, nil, tokens} + end + end + + defp parse_generic_for(vars, tokens, start_pos) do + case peek(tokens) do + {:delimiter, :comma, _} -> + {_, rest} = consume(tokens) + + case expect(rest, :identifier) do + {:ok, {_, var, _}, rest2} -> + parse_generic_for(vars ++ [var], rest2, start_pos) + + {:error, reason} -> + {:error, reason} + end + + {:keyword, :in, _} -> + {_, rest} = consume(tokens) + + with {:ok, iterators, rest2} <- parse_expr_list(rest), + {:ok, _, rest3} <- expect(rest2, :keyword, :do), + {:ok, body, rest4} <- parse_block(rest3), + {:ok, _, rest5} <- expect(rest4, :keyword, :end) do + {:ok, + %Statement.ForIn{ + vars: vars, + iterators: iterators, + body: body, + meta: Meta.new(start_pos) + }, rest5} + end + + _ -> + {:error, {:unexpected_token, peek(tokens), "Expected ',' or 'in' in for loop"}} + end + end + + defp parse_function_decl([{:keyword, :function, pos} | rest]) do + case 
parse_function_name(rest) do + {:ok, name_parts, is_method, rest2} -> + with {:ok, _, rest3} <- expect(rest2, :delimiter, :lparen), + {:ok, params, rest4} <- parse_param_list(rest3), + {:ok, _, rest5} <- expect(rest4, :delimiter, :rparen), + {:ok, body, rest6} <- parse_block(rest5), + {:ok, _, rest7} <- expect(rest6, :keyword, :end) do + {:ok, + %Statement.FuncDecl{ + name: name_parts, + params: params, + body: body, + is_method: is_method, + meta: Meta.new(pos) + }, rest7} + end + + {:error, reason} -> + {:error, reason} + end + end + + defp parse_function_name(tokens) do + case expect(tokens, :identifier) do + {:ok, {_, name, _}, rest} -> + parse_function_name_rest([name], rest) + + {:error, reason} -> + {:error, reason} + end + end + + defp parse_function_name_rest(names, tokens) do + case peek(tokens) do + {:delimiter, :dot, _} -> + {_, rest} = consume(tokens) + + case expect(rest, :identifier) do + {:ok, {_, name, _}, rest2} -> + parse_function_name_rest(names ++ [name], rest2) + + {:error, reason} -> + {:error, reason} + end + + {:delimiter, :colon, _} -> + {_, rest} = consume(tokens) + + case expect(rest, :identifier) do + {:ok, {_, name, _}, rest2} -> + {:ok, names ++ [name], true, rest2} + + {:error, reason} -> + {:error, reason} + end + + _ -> + {:ok, names, false, tokens} + end + end + + defp parse_do([{:keyword, :do, pos} | rest]) do + with {:ok, body, rest2} <- parse_block(rest), + {:ok, _, rest3} <- expect(rest2, :keyword, :end) do + {:ok, %Statement.Do{body: body, meta: Meta.new(pos)}, rest3} + end + end + + defp parse_break([{:keyword, :break, pos} | rest]) do + {:ok, %Statement.Break{meta: Meta.new(pos)}, rest} + end + + defp parse_goto([{:keyword, :goto, pos} | rest]) do + case expect(rest, :identifier) do + {:ok, {_, label, _}, rest2} -> + {:ok, %Statement.Goto{label: label, meta: Meta.new(pos)}, rest2} + + {:error, reason} -> + {:error, reason} + end + end + + defp parse_label([{:delimiter, :double_colon, pos} | rest]) do + case expect(rest, 
:identifier) do + {:ok, {_, name, _}, rest2} -> + case expect(rest2, :delimiter, :double_colon) do + {:ok, _, rest3} -> + {:ok, %Statement.Label{name: name, meta: Meta.new(pos)}, rest3} + + {:error, reason} -> + {:error, reason} + end + + {:error, reason} -> + {:error, reason} + end + end + + defp parse_assign_or_call(tokens) do + # This is the most complex case - we need to parse a potential lvalue or call + # Start by parsing an expression (which could be a variable, call, property access, etc.) + case parse_expr(tokens) do + {:ok, expr, rest} -> + case peek(rest) do + {:operator, :assign, _} -> + # It's an assignment + parse_assignment([expr], rest) + + {:delimiter, :comma, _} -> + # Multiple targets, must be assignment + parse_assignment_targets([expr], rest) + + _ -> + # It's a call statement (or error if not a call) + case expr do + %Expr.Call{} = call -> + {:ok, %Statement.CallStmt{call: call, meta: nil}, rest} + + %Expr.MethodCall{} = call -> + {:ok, %Statement.CallStmt{call: call, meta: nil}, rest} + + _ -> + {:error, {:unexpected_expression, "Expression statement must be a function call"}} + end + end + + {:error, reason} -> + {:error, reason} + end + end + + defp parse_assignment_targets(targets, [{:delimiter, :comma, _} | rest]) do + case parse_expr(rest) do + {:ok, expr, rest2} -> + case peek(rest2) do + {:delimiter, :comma, _} -> + parse_assignment_targets(targets ++ [expr], rest2) + + {:operator, :assign, _} -> + parse_assignment(targets ++ [expr], rest2) + + _ -> + {:error, {:unexpected_token, peek(rest2), "Expected '=' or ',' in assignment"}} + end + + {:error, reason} -> + {:error, reason} + end + end + + defp parse_assignment(targets, [{:operator, :assign, _} | rest]) do + case parse_expr_list(rest) do + {:ok, values, rest2} -> + {:ok, %Statement.Assign{targets: targets, values: values, meta: nil}, rest2} + + {:error, reason} -> + {:error, reason} + end + end + + # Helper: parse list of names (for local declarations, for loops) + defp 
parse_name_list(tokens) do + case expect(tokens, :identifier) do + {:ok, {_, name, _}, rest} -> + parse_name_list_rest([name], rest) + + {:error, reason} -> + {:error, reason} + end + end + + defp parse_name_list_rest(names, tokens) do + case peek(tokens) do + {:delimiter, :comma, _} -> + {_, rest} = consume(tokens) + + case expect(rest, :identifier) do + {:ok, {_, name, _}, rest2} -> + parse_name_list_rest(names ++ [name], rest2) + + {:error, reason} -> + {:error, reason} + end + + _ -> + {:ok, names, tokens} + end + end + + # Expression parsing with Pratt algorithm + @doc """ + Parses an expression with minimum precedence. + """ + @spec parse_expr([token()], non_neg_integer()) :: parse_result(Expr.t()) + def parse_expr(tokens, min_prec \\ 0) do + # Parse prefix (primary or unary operator) + case parse_prefix(tokens) do + {:ok, left, rest} -> + parse_infix(left, rest, min_prec) + + {:error, reason} -> + {:error, reason} + end + end + + # Parse prefix expressions (primary expressions and unary operators) + defp parse_prefix(tokens) do + case peek(tokens) do + # Literals + {:keyword, nil, pos} -> + {_, rest} = consume(tokens) + {:ok, %Expr.Nil{meta: Meta.new(pos)}, rest} + + {:keyword, true, pos} -> + {_, rest} = consume(tokens) + {:ok, %Expr.Bool{value: true, meta: Meta.new(pos)}, rest} + + {:keyword, false, pos} -> + {_, rest} = consume(tokens) + {:ok, %Expr.Bool{value: false, meta: Meta.new(pos)}, rest} + + {:number, value, pos} -> + {_, rest} = consume(tokens) + {:ok, %Expr.Number{value: value, meta: Meta.new(pos)}, rest} + + {:string, value, pos} -> + {_, rest} = consume(tokens) + {:ok, %Expr.String{value: value, meta: Meta.new(pos)}, rest} + + # Vararg + {:operator, :vararg, pos} -> + {_, rest} = consume(tokens) + {:ok, %Expr.Vararg{meta: Meta.new(pos)}, rest} + + # Identifier (variable) + {:identifier, name, pos} -> + {_, rest} = consume(tokens) + {:ok, %Expr.Var{name: name, meta: Meta.new(pos)}, rest} + + # Parenthesized expression + {:delimiter, :lparen, _} 
-> + parse_paren_expr(tokens) + + # Table constructor + {:delimiter, :lbrace, _} -> + parse_table(tokens) + + # Function expression + {:keyword, :function, _} -> + parse_function_expr(tokens) + + # Unary operators + {:keyword, :not, pos} -> + {_, rest} = consume(tokens) + parse_unary(:not, pos, rest) + + {:operator, :sub, pos} -> + {_, rest} = consume(tokens) + parse_unary(:sub, pos, rest) + + {:operator, :len, pos} -> + {_, rest} = consume(tokens) + parse_unary(:len, pos, rest) + + {:eof, pos} -> + {:error, {:unexpected_token, :eof, pos, "Expected expression"}} + + {type, _, pos} -> + {:error, {:unexpected_token, type, pos, "Expected expression"}} + + nil -> + {:error, {:unexpected_end, "Expected expression"}} + end + end + + defp parse_unary(op, pos, tokens) do + unop = Pratt.token_to_unop(op) + prec = Pratt.prefix_binding_power(op) + + case parse_expr(tokens, prec) do + {:ok, operand, rest} -> + {:ok, %Expr.UnOp{op: unop, operand: operand, meta: Meta.new(pos)}, rest} + + {:error, reason} -> + {:error, reason} + end + end + + # Parse infix expressions (binary operators and postfix) + defp parse_infix(left, tokens, min_prec) do + case peek(tokens) do + {:keyword, op, pos} when op in [:and, :or] -> + if Pratt.is_binary_op?(op) do + case Pratt.binding_power(op) do + {left_bp, right_bp} when left_bp >= min_prec -> + {_, rest} = consume(tokens) + binop = Pratt.token_to_binop(op) + + case parse_expr(rest, right_bp) do + {:ok, right, rest2} -> + new_left = %Expr.BinOp{ + op: binop, + left: left, + right: right, + meta: Meta.new(pos) + } + + parse_infix(new_left, rest2, min_prec) + + {:error, reason} -> + {:error, reason} + end + + _ -> + {:ok, left, tokens} + end + else + {:ok, left, tokens} + end + + {:operator, op, pos} -> + cond do + Pratt.is_binary_op?(op) -> + case Pratt.binding_power(op) do + {left_bp, right_bp} when left_bp >= min_prec -> + {_, rest} = consume(tokens) + binop = Pratt.token_to_binop(op) + + case parse_expr(rest, right_bp) do + {:ok, right, rest2} 
-> + new_left = %Expr.BinOp{ + op: binop, + left: left, + right: right, + meta: Meta.new(pos) + } + + parse_infix(new_left, rest2, min_prec) + + {:error, reason} -> + {:error, reason} + end + + _ -> + {:ok, left, tokens} + end + + true -> + {:ok, left, tokens} + end + + # Postfix: function call + {:delimiter, :lparen, _} -> + case parse_call_args(tokens) do + {:ok, args, rest} -> + new_left = %Expr.Call{func: left, args: args, meta: nil} + parse_infix(new_left, rest, min_prec) + + {:error, reason} -> + {:error, reason} + end + + # Postfix: indexing + {:delimiter, :lbracket, _} -> + case parse_index(tokens) do + {:ok, key, rest} -> + new_left = %Expr.Index{table: left, key: key, meta: nil} + parse_infix(new_left, rest, min_prec) + + {:error, reason} -> + {:error, reason} + end + + # Postfix: property access or method call + {:delimiter, :dot, _} -> + {_, rest} = consume(tokens) + + case expect(rest, :identifier) do + {:ok, {_, field, _}, rest2} -> + new_left = %Expr.Property{table: left, field: field, meta: nil} + parse_infix(new_left, rest2, min_prec) + + {:error, reason} -> + {:error, reason} + end + + {:delimiter, :colon, _} -> + {_, rest} = consume(tokens) + + case expect(rest, :identifier) do + {:ok, {_, method, _}, rest2} -> + case parse_call_args(rest2) do + {:ok, args, rest3} -> + new_left = %Expr.MethodCall{object: left, method: method, args: args, meta: nil} + parse_infix(new_left, rest3, min_prec) + + {:error, reason} -> + {:error, reason} + end + + {:error, reason} -> + {:error, reason} + end + + _ -> + {:ok, left, tokens} + end + end + + # Parse parenthesized expression: (expr) + defp parse_paren_expr([{:delimiter, :lparen, _} | rest]) do + case parse_expr(rest) do + {:ok, expr, rest2} -> + case expect(rest2, :delimiter, :rparen) do + {:ok, _, rest3} -> + {:ok, expr, rest3} + + {:error, reason} -> + {:error, reason} + end + + {:error, reason} -> + {:error, reason} + end + end + + # Parse table constructor: { fields } + defp parse_table([{:delimiter, 
:lbrace, pos} | rest]) do + case parse_table_fields(rest, []) do + {:ok, fields, rest2} -> + case expect(rest2, :delimiter, :rbrace) do + {:ok, _, rest3} -> + {:ok, %Expr.Table{fields: Enum.reverse(fields), meta: Meta.new(pos)}, rest3} + + {:error, reason} -> + {:error, reason} + end + + {:error, reason} -> + {:error, reason} + end + end + + defp parse_table_fields(tokens, acc) do + case peek(tokens) do + {:delimiter, :rbrace, _} -> + {:ok, acc, tokens} + + _ -> + case parse_table_field(tokens) do + {:ok, field, rest} -> + case peek(rest) do + {:delimiter, :comma, _} -> + {_, rest2} = consume(rest) + parse_table_fields(rest2, [field | acc]) + + {:delimiter, :semicolon, _} -> + {_, rest2} = consume(rest) + parse_table_fields(rest2, [field | acc]) + + {:delimiter, :rbrace, _} -> + {:ok, [field | acc], rest} + + _ -> + {:error, {:unexpected_token, peek(rest), "Expected ',' or '}' in table"}} + end + + {:error, reason} -> + {:error, reason} + end + end + end + + defp parse_table_field(tokens) do + case peek(tokens) do + # [expr] = expr (computed key) + {:delimiter, :lbracket, _} -> + {_, rest} = consume(tokens) + + with {:ok, key, rest2} <- parse_expr(rest), + {:ok, _, rest3} <- expect(rest2, :delimiter, :rbracket), + {:ok, _, rest4} <- expect(rest3, :operator, :assign), + {:ok, value, rest5} <- parse_expr(rest4) do + {:ok, {:pair, key, value}, rest5} + end + + # name = expr (named field) + {:identifier, name, pos} -> + rest = tl(tokens) + + case peek(rest) do + {:operator, :assign, _} -> + {_, rest2} = consume(rest) + + case parse_expr(rest2) do + {:ok, value, rest3} -> + key = %Expr.String{value: name, meta: Meta.new(pos)} + {:ok, {:pair, key, value}, rest3} + + {:error, reason} -> + {:error, reason} + end + + _ -> + # Just an expression (list entry) + case parse_expr(tokens) do + {:ok, expr, rest2} -> + {:ok, {:list, expr}, rest2} + + {:error, reason} -> + {:error, reason} + end + end + + _ -> + # Expression (list entry) + case parse_expr(tokens) do + {:ok, expr, 
rest} -> + {:ok, {:list, expr}, rest} + + {:error, reason} -> + {:error, reason} + end + end + end + + # Parse function expression: function(params) body end + defp parse_function_expr([{:keyword, :function, pos} | rest]) do + with {:ok, _, rest2} <- expect(rest, :delimiter, :lparen), + {:ok, params, rest3} <- parse_param_list(rest2), + {:ok, _, rest4} <- expect(rest3, :delimiter, :rparen), + {:ok, body, rest5} <- parse_block(rest4), + {:ok, _, rest6} <- expect(rest5, :keyword, :end) do + {:ok, %Expr.Function{params: params, body: body, meta: Meta.new(pos)}, rest6} + end + end + + defp parse_param_list(tokens) do + parse_param_list_acc(tokens, []) + end + + defp parse_param_list_acc(tokens, acc) do + case peek(tokens) do + {:delimiter, :rparen, _} -> + {:ok, Enum.reverse(acc), tokens} + + {:operator, :vararg, _} -> + {_, rest} = consume(tokens) + {:ok, Enum.reverse([:vararg | acc]), rest} + + {:identifier, name, _} -> + {_, rest} = consume(tokens) + + case peek(rest) do + {:delimiter, :comma, _} -> + {_, rest2} = consume(rest) + parse_param_list_acc(rest2, [name | acc]) + + _ -> + {:ok, Enum.reverse([name | acc]), rest} + end + + _ -> + {:error, {:unexpected_token, peek(tokens), "Expected parameter name or ')'"}} + end + end + + # Parse function call arguments: (args) + defp parse_call_args([{:delimiter, :lparen, _} | rest]) do + parse_expr_list_until(rest, :rparen) + end + + # Parse indexing: [key] + defp parse_index([{:delimiter, :lbracket, _} | rest]) do + case parse_expr(rest) do + {:ok, key, rest2} -> + case expect(rest2, :delimiter, :rbracket) do + {:ok, _, rest3} -> + {:ok, key, rest3} + + {:error, reason} -> + {:error, reason} + end + + {:error, reason} -> + {:error, reason} + end + end + + # Parse expression list: expr1, expr2, ... 
+ defp parse_expr_list(tokens) do + parse_expr_list_acc(tokens, []) + end + + defp parse_expr_list_acc(tokens, acc) do + case parse_expr(tokens) do + {:ok, expr, rest} -> + case peek(rest) do + {:delimiter, :comma, _} -> + {_, rest2} = consume(rest) + parse_expr_list_acc(rest2, [expr | acc]) + + _ -> + {:ok, Enum.reverse([expr | acc]), rest} + end + + {:error, reason} -> + if acc == [] do + {:error, reason} + else + {:ok, Enum.reverse(acc), tokens} + end + end + end + + defp parse_expr_list_until(tokens, terminator) do + case peek(tokens) do + {:delimiter, ^terminator, _} -> + {_, rest} = consume(tokens) + {:ok, [], rest} + + _ -> + case parse_expr_list(tokens) do + {:ok, exprs, rest} -> + case expect(rest, :delimiter, terminator) do + {:ok, _, rest2} -> + {:ok, exprs, rest2} + + {:error, reason} -> + {:error, reason} + end + + {:error, reason} -> + {:error, reason} + end + end + end + + # Token manipulation helpers + + defp peek([token | _]), do: token + defp peek([]), do: nil + + defp consume([token | rest]), do: {token, rest} + defp consume([]), do: {nil, []} + + # Expect a specific token type + defp expect(tokens, expected_type) do + case peek(tokens) do + {^expected_type, _, _} = token -> + {_, rest} = consume(tokens) + {:ok, token, rest} + + {type, _, pos} when type != nil -> + {:error, + {:unexpected_token, type, pos, + "Expected #{inspect(expected_type)}, got #{inspect(type)}"}} + + {type, pos} when is_map(pos) -> + # Token without value (like :eof) + {:error, + {:unexpected_token, type, pos, + "Expected #{inspect(expected_type)}, got #{inspect(type)}"}} + + nil -> + {:error, {:unexpected_end, "Expected #{inspect(expected_type)}"}} + end + end + + # Expect a specific token type and value + defp expect(tokens, expected_type, expected_value) do + case peek(tokens) do + {^expected_type, ^expected_value, _} = token -> + {_, rest} = consume(tokens) + {:ok, token, rest} + + {type, value, pos} when type != nil and value != nil -> + {:error, + {:unexpected_token, 
type, pos, + "Expected #{inspect(expected_type)}:#{inspect(expected_value)}, got #{inspect(type)}:#{inspect(value)}"}} + + {type, pos} when is_map(pos) -> + # Token without value (like :eof) + {:error, + {:unexpected_token, type, pos, + "Expected #{inspect(expected_type)}:#{inspect(expected_value)}, got #{inspect(type)}"}} + + nil -> + {:error, + {:unexpected_end, "Expected #{inspect(expected_type)}:#{inspect(expected_value)}"}} + end + end + + # Error conversion helpers + + defp convert_error({:unexpected_token, type, pos, message}, _code) do + Error.new(:unexpected_token, message, pos, suggestion: suggest_for_token_error(type, message)) + end + + defp convert_error({:unexpected_end, message}, _code) do + Error.new(:unexpected_end, message, nil, + suggestion: """ + The parser reached the end of the file unexpectedly. + Check for missing closing delimiters or keywords like 'end', ')', '}', or ']'. + """ + ) + end + + defp convert_error({:unexpected_expression, message}, _code) do + Error.new(:invalid_syntax, message, nil) + end + + defp convert_error(other, _code) do + Error.new(:invalid_syntax, "Parse error: #{inspect(other)}", nil) + end + + defp convert_lexer_error({:unexpected_character, char, pos}, _code) do + Error.new(:invalid_syntax, "Unexpected character: #{<>}", pos, + suggestion: """ + This character is not valid in Lua syntax. + Check for typos or invisible characters. + """ + ) + end + + defp convert_lexer_error({:unclosed_string, pos}, _code) do + Error.new(:unclosed_delimiter, "Unclosed string literal", pos, + suggestion: """ + Add a closing quote (" or ') to finish the string. + Strings cannot span multiple lines unless you use [[...]] syntax. + """ + ) + end + + defp convert_lexer_error({:unclosed_long_string, pos}, _code) do + Error.new(:unclosed_delimiter, "Unclosed long string [[...]]", pos, + suggestion: "Add the closing ]] to finish the long string." 
+ ) + end + + defp convert_lexer_error({:unclosed_comment, pos}, _code) do + Error.new(:unclosed_delimiter, "Unclosed multi-line comment --[[...]]", pos, + suggestion: "Add the closing ]] to finish the comment." + ) + end + + defp convert_lexer_error(other, _code) do + Error.new(:lexer_error, "Lexer error: #{inspect(other)}", nil) + end + + defp suggest_for_token_error(type, message) do + cond do + type == :eof -> + "Reached end of file unexpectedly. Check for missing 'end' keywords or closing delimiters." + + String.contains?(message, "Expected 'end'") -> + """ + Every block needs an 'end': + - if/elseif/else ... end + - while ... do ... end + - for ... do ... end + - function ... end + - do ... end + """ + + String.contains?(message, "Expected 'then'") -> + "In Lua, 'if' and 'elseif' conditions must be followed by 'then'." + + String.contains?(message, "Expected 'do'") -> + "In Lua, 'while' and 'for' loops must have 'do' before the body." + + true -> + nil + end + end +end diff --git a/lib/lua/parser/error.ex b/lib/lua/parser/error.ex new file mode 100644 index 0000000..73430b3 --- /dev/null +++ b/lib/lua/parser/error.ex @@ -0,0 +1,345 @@ +defmodule Lua.Parser.Error do + @moduledoc """ + Beautiful error reporting for the Lua parser. 
+ + Provides detailed error messages with: + - Source code context with line numbers + - Visual indicators pointing to the error location + - Helpful suggestions for common mistakes + - Multiple error reporting + """ + + alias Lua.AST.Meta + + @type position :: Meta.position() + + @type t :: %__MODULE__{ + type: error_type(), + message: String.t(), + position: position() | nil, + suggestion: String.t() | nil, + source_lines: [String.t()], + related: [t()] + } + + @type error_type :: + :unexpected_token + | :unexpected_end + | :expected_token + | :unclosed_delimiter + | :invalid_syntax + | :lexer_error + | :multiple_errors + + defstruct [ + :type, + :message, + :position, + :suggestion, + source_lines: [], + related: [] + ] + + @doc """ + Creates a new error. + """ + @spec new(error_type(), String.t(), position() | nil, keyword()) :: t() + def new(type, message, position \\ nil, opts \\ []) do + %__MODULE__{ + type: type, + message: message, + position: position, + suggestion: opts[:suggestion], + source_lines: opts[:source_lines] || [], + related: opts[:related] || [] + } + end + + @doc """ + Creates an error for unexpected token. + """ + @spec unexpected_token(atom(), term(), position(), String.t()) :: t() + def unexpected_token(token_type, token_value, position, context) do + message = """ + Unexpected #{format_token(token_type, token_value)} in #{context} + """ + + suggestion = suggest_for_unexpected_token(token_type, token_value, context) + + new(:unexpected_token, message, position, suggestion: suggestion) + end + + @doc """ + Creates an error for expected token. 
+ """ + @spec expected_token(atom(), term() | nil, atom(), term(), position()) :: t() + def expected_token(expected_type, expected_value, got_type, got_value, position) do + expected = format_token(expected_type, expected_value) + got = format_token(got_type, got_value) + + message = """ + Expected #{expected}, but got #{got} + """ + + suggestion = suggest_for_expected_token(expected_type, expected_value, got_type) + + new(:expected_token, message, position, suggestion: suggestion) + end + + @doc """ + Creates an error for unclosed delimiter. + """ + @spec unclosed_delimiter(atom(), position(), position() | nil) :: t() + def unclosed_delimiter(delimiter, open_pos, close_pos \\ nil) do + delimiter_str = format_delimiter(delimiter) + + message = """ + Unclosed #{delimiter_str} + """ + + suggestion = """ + Add a closing #{closing_delimiter(delimiter)} to match the opening at line #{open_pos.line} + """ + + new(:unclosed_delimiter, message, close_pos || open_pos, suggestion: suggestion) + end + + @doc """ + Creates an error for unexpected end of input. + """ + @spec unexpected_end(String.t(), position() | nil) :: t() + def unexpected_end(context, position \\ nil) do + message = """ + Unexpected end of input while parsing #{context} + """ + + suggestion = """ + Check for missing closing delimiters or keywords like 'end', ')', '}', or ']' + """ + + new(:unexpected_end, message, position, suggestion: suggestion) + end + + @doc """ + Formats an error into a beautiful multi-line string with context. 
+ """ + @spec format(t(), String.t()) :: String.t() + def format(error, source_code) do + lines = String.split(source_code, "\n") + + header = [ + IO.ANSI.red() <> IO.ANSI.bright() <> "Parse Error" <> IO.ANSI.reset(), + "" + ] + + location = + if error.position do + pos = error.position + " at line #{pos.line}, column #{pos.column}:" + else + " (no position information)" + end + + message_lines = [ + location, + "", + indent(error.message, 2) + ] + + context_lines = + if error.position && length(lines) > 0 do + format_context(lines, error.position) + else + [] + end + + suggestion_lines = + if error.suggestion do + [ + "", + IO.ANSI.cyan() <> "Suggestion:" <> IO.ANSI.reset(), + indent(error.suggestion, 2) + ] + else + [] + end + + related_lines = + if length(error.related) > 0 do + [ + "", + IO.ANSI.yellow() <> "Related errors:" <> IO.ANSI.reset() + ] ++ Enum.flat_map(error.related, fn rel -> ["", indent(format(rel, source_code), 2)] end) + else + [] + end + + (header ++ message_lines ++ context_lines ++ suggestion_lines ++ related_lines) + |> Enum.join("\n") + end + + @doc """ + Formats multiple errors together. 
+ """ + @spec format_multiple([t()], String.t()) :: String.t() + def format_multiple(errors, source_code) do + header = [ + IO.ANSI.red() <> + IO.ANSI.bright() <> + "Found #{length(errors)} parse error#{if length(errors) == 1, do: "", else: "s"}" <> + IO.ANSI.reset(), + "" + ] + + error_lines = + errors + |> Enum.with_index(1) + |> Enum.flat_map(fn {error, idx} -> + [ + IO.ANSI.yellow() <> "Error #{idx}:" <> IO.ANSI.reset(), + format(error, source_code), + "" + ] + end) + + (header ++ error_lines) + |> Enum.join("\n") + end + + # Private helpers + + defp format_context(lines, position) do + line_num = position.line + column = position.column + + # Show 2 lines before and after + start_line = max(1, line_num - 2) + end_line = min(length(lines), line_num + 2) + + context_lines = + Enum.slice(lines, (start_line - 1)..(end_line - 1)) + |> Enum.with_index(start_line) + |> Enum.flat_map(fn {line, num} -> + line_str = format_line_number(num) <> " │ " <> line + + if num == line_num do + # Error line + pointer = String.duplicate(" ", String.length(format_line_number(num)) + 3 + column - 1) + pointer = pointer <> IO.ANSI.red() <> "^" <> IO.ANSI.reset() + + [ + IO.ANSI.red() <> line_str <> IO.ANSI.reset(), + pointer + ] + else + # Context line + [IO.ANSI.faint() <> line_str <> IO.ANSI.reset()] + end + end) + + ["", ""] ++ context_lines + end + + defp format_line_number(num) do + num + |> Integer.to_string() + |> String.pad_leading(4) + end + + defp format_token(type, value) do + case type do + :keyword -> "'#{value}'" + :identifier -> "identifier '#{value}'" + :number -> "number #{value}" + :string -> "string \"#{value}\"" + :operator -> "operator '#{value}'" + :delimiter -> "'#{value}'" + :eof -> "end of input" + _ -> "#{type}" + end + end + + defp format_delimiter(delimiter) do + case delimiter do + :lparen -> "opening parenthesis '('" + :lbracket -> "opening bracket '['" + :lbrace -> "opening brace '{'" + :function -> "'function' block" + :if -> "'if' statement" + :while 
-> "'while' loop" + :for -> "'for' loop" + :do -> "'do' block" + _ -> "#{delimiter}" + end + end + + defp closing_delimiter(delimiter) do + case delimiter do + :lparen -> "')'" + :lbracket -> "']'" + :lbrace -> "'}'" + :function -> "'end'" + :if -> "'end'" + :while -> "'end'" + :for -> "'end'" + :do -> "'end'" + _ -> "matching delimiter" + end + end + + defp suggest_for_unexpected_token(token_type, _token_value, context) do + cond do + token_type == :delimiter -> + "Check for missing operators or keywords before this delimiter" + + String.contains?(context, "expression") -> + "Expected an expression here (variable, number, string, table, function, etc.)" + + String.contains?(context, "statement") -> + "Expected a statement here (assignment, function call, if, while, for, etc.)" + + true -> + nil + end + end + + defp suggest_for_expected_token(expected_type, expected_value, got_type) do + cond do + expected_type == :keyword && expected_value == :end -> + "Add 'end' to close the block. Check that all opening keywords (if, while, for, function, do) have matching 'end' keywords." + + expected_type == :keyword && expected_value == :then -> + "Add 'then' after the condition. Lua requires 'then' after if/elseif conditions." + + expected_type == :keyword && expected_value == :do -> + "Add 'do' to start the loop body. Lua requires 'do' after while/for conditions." + + expected_type == :delimiter && expected_value == :rparen -> + "Add ')' to close the parentheses. Check for balanced parentheses." + + expected_type == :delimiter && expected_value == :rbracket -> + "Add ']' to close the brackets. Check for balanced brackets." + + expected_type == :delimiter && expected_value == :rbrace -> + "Add '}' to close the table constructor. Check for balanced braces." + + expected_type == :operator && expected_value == :assign -> + "Add '=' for assignment. Did you mean to assign a value?" 
+ + expected_type == :identifier && got_type == :keyword -> + "Cannot use Lua keyword as identifier. Choose a different name." + + true -> + nil + end + end + + defp indent(text, spaces) do + prefix = String.duplicate(" ", spaces) + + text + |> String.split("\n") + |> Enum.map(&(prefix <> &1)) + |> Enum.join("\n") + end +end diff --git a/lib/lua/parser/pratt.ex b/lib/lua/parser/pratt.ex new file mode 100644 index 0000000..c005ef9 --- /dev/null +++ b/lib/lua/parser/pratt.ex @@ -0,0 +1,131 @@ +defmodule Lua.Parser.Pratt do + @moduledoc """ + Pratt parser for Lua expressions. + + Implements operator precedence parsing using binding powers. + Handles all 11 precedence levels in Lua 5.3. + + Precedence (lowest to highest): + 1. or + 2. and + 3. < > <= >= ~= == + 4. .. + 5. + - + 6. * / // % + 7. unary (not # -) + 8. ^ + """ + + alias Lua.AST.Expr + + @doc """ + Returns the binding power (precedence) for binary operators. + + Returns {left_bp, right_bp} where: + - left_bp: minimum precedence of left operand + - right_bp: minimum precedence of right operand + + Right associative operators have left_bp < right_bp. + Left associative operators have left_bp >= right_bp. 
+ """ + @spec binding_power(atom()) :: {non_neg_integer(), non_neg_integer()} | nil + def binding_power(:or), do: {1, 2} + def binding_power(:and), do: {3, 4} + + # Comparison operators (left associative) + def binding_power(:lt), do: {5, 6} + def binding_power(:gt), do: {5, 6} + def binding_power(:le), do: {5, 6} + def binding_power(:ge), do: {5, 6} + def binding_power(:ne), do: {5, 6} + def binding_power(:eq), do: {5, 6} + + # String concatenation (right associative) + def binding_power(:concat), do: {7, 6} + + # Additive (left associative) + def binding_power(:add), do: {9, 10} + def binding_power(:sub), do: {9, 10} + + # Multiplicative (left associative) + def binding_power(:mul), do: {11, 12} + def binding_power(:div), do: {11, 12} + def binding_power(:floordiv), do: {11, 12} + def binding_power(:mod), do: {11, 12} + + # Unary operators + def binding_power(:not), do: {13, 14} + def binding_power(:neg), do: {13, 14} + def binding_power(:len), do: {13, 14} + + # Power (right associative) + def binding_power(:pow), do: {16, 15} + + # Not a binary operator + def binding_power(_), do: nil + + @doc """ + Returns the binding power for unary prefix operators. + + This is the minimum precedence required for the operand. + + Note: In Lua, unary minus has an unusual precedence - it's lower than power (^). + So -2^3 = -(2^3), not (-2)^3. + To achieve this: unary minus binding power (13) < power left_bp (16), + allowing power to bind within the unary's operand. + But 13 > multiplication left_bp (11), so -a*b = (-a)*b. + """ + @spec prefix_binding_power(atom()) :: non_neg_integer() | nil + def prefix_binding_power(:not), do: 14 + # Between mult (11) and power (16) + def prefix_binding_power(:sub), do: 13 + def prefix_binding_power(:len), do: 14 + def prefix_binding_power(_), do: nil + + @doc """ + Maps token operators to AST binary operators. 
+ """ + @spec token_to_binop(atom()) :: Expr.BinOp.op() | nil + def token_to_binop(:or), do: :or + def token_to_binop(:and), do: :and + def token_to_binop(:lt), do: :lt + def token_to_binop(:gt), do: :gt + def token_to_binop(:le), do: :le + def token_to_binop(:ge), do: :ge + def token_to_binop(:ne), do: :ne + def token_to_binop(:eq), do: :eq + def token_to_binop(:concat), do: :concat + def token_to_binop(:add), do: :add + def token_to_binop(:sub), do: :sub + def token_to_binop(:mul), do: :mul + def token_to_binop(:div), do: :div + def token_to_binop(:floordiv), do: :floordiv + def token_to_binop(:mod), do: :mod + def token_to_binop(:pow), do: :pow + def token_to_binop(_), do: nil + + @doc """ + Maps token operators to AST unary operators. + """ + @spec token_to_unop(atom()) :: Expr.UnOp.op() | nil + def token_to_unop(:not), do: :not + def token_to_unop(:sub), do: :neg + def token_to_unop(:len), do: :len + def token_to_unop(_), do: nil + + @doc """ + Checks if a token is a binary operator. + """ + @spec is_binary_op?(atom()) :: boolean() + def is_binary_op?(op) do + binding_power(op) != nil + end + + @doc """ + Checks if a token is a prefix unary operator. + """ + @spec is_prefix_op?(atom()) :: boolean() + def is_prefix_op?(op) do + prefix_binding_power(op) != nil + end +end diff --git a/lib/lua/parser/recovery.ex b/lib/lua/parser/recovery.ex new file mode 100644 index 0000000..48a6970 --- /dev/null +++ b/lib/lua/parser/recovery.ex @@ -0,0 +1,209 @@ +defmodule Lua.Parser.Recovery do + @moduledoc """ + Error recovery strategies for the Lua parser. + + Allows the parser to continue after encountering errors, + collecting multiple errors in a single parse pass. + """ + + alias Lua.Parser.Error + alias Lua.Lexer + alias Lua.AST.Meta + + @type token :: Lexer.token() + @type recovery_result :: {:recovered, [token()], [Error.t()]} | {:failed, [Error.t()]} + + @doc """ + Attempts to recover from a parse error by finding a synchronization point. 
+ + Synchronization points are tokens where we can safely resume parsing: + - Statement boundaries: `;`, `end`, `else`, `elseif`, `until` + - Block terminators: `}`, `)` + - Start of new statements: keywords like `if`, `while`, `for`, `function`, `local` + """ + @spec recover_at_statement([token()], Error.t()) :: recovery_result() + def recover_at_statement(tokens, error) do + case find_statement_boundary(tokens) do + {:ok, rest} -> + {:recovered, rest, [error]} + + :not_found -> + {:failed, [error]} + end + end + + @doc """ + Recovers from an unclosed delimiter by finding the matching closing delimiter. + """ + @spec recover_unclosed_delimiter([token()], atom(), Error.t()) :: recovery_result() + def recover_unclosed_delimiter(tokens, delimiter_type, error) do + closing = closing_delimiter(delimiter_type) + + case find_closing_delimiter(tokens, closing, 1) do + {:ok, rest} -> + {:recovered, rest, [error]} + + :not_found -> + # If we can't find the closing delimiter, try to recover at statement boundary + recover_at_statement(tokens, error) + end + end + + @doc """ + Attempts to recover from missing keyword error. + """ + @spec recover_missing_keyword([token()], atom(), Error.t()) :: recovery_result() + def recover_missing_keyword(tokens, keyword, error) do + case find_keyword(tokens, keyword) do + {:ok, rest} -> + {:recovered, rest, [error]} + + :not_found -> + recover_at_statement(tokens, error) + end + end + + @doc """ + Skips tokens until we find a valid statement start. + """ + @spec skip_to_statement([token()]) :: [token()] + def skip_to_statement(tokens) do + case find_statement_boundary(tokens) do + {:ok, rest} -> rest + :not_found -> [] + end + end + + @doc """ + Checks if a token is a statement boundary (synchronization point). 
+ """ + @spec is_statement_boundary?(token()) :: boolean() + def is_statement_boundary?(token) do + case token do + {:delimiter, :semicolon, _} -> true + {:keyword, kw, _} when kw in [:end, :else, :elseif, :until] -> true + {:keyword, kw, _} when kw in [:if, :while, :for, :function, :local, :do, :repeat] -> true + {:eof, _} -> true + _ -> false + end + end + + defmodule DelimiterStack do + @moduledoc """ + Tracks unclosed delimiters in a stack-based manner. + """ + + defstruct stack: [] + + @type t :: %__MODULE__{stack: [{atom(), Meta.position()}]} + + def new, do: %__MODULE__{} + + def push(stack, delimiter, position) do + %{stack | stack: [{delimiter, position} | stack.stack]} + end + + def pop(stack, closing_delimiter) do + case stack.stack do + [{opening, _pos} | rest] -> + if matches?(opening, closing_delimiter) do + {:ok, %{stack | stack: rest}} + else + {:error, :mismatched, opening} + end + + [] -> + {:error, :empty} + end + end + + def peek(stack) do + case stack.stack do + [{delimiter, position} | _] -> {:ok, delimiter, position} + [] -> :empty + end + end + + def empty?(stack), do: stack.stack == [] + + defp matches?(opening, closing) do + case {opening, closing} do + {:lparen, :rparen} -> true + {:lbracket, :rbracket} -> true + {:lbrace, :rbrace} -> true + {:function, :end} -> true + {:if, :end} -> true + {:while, :end} -> true + {:for, :end} -> true + {:do, :end} -> true + _ -> false + end + end + end + + # Private helpers + + defp find_statement_boundary([token | rest]) do + if is_statement_boundary?(token) do + {:ok, [token | rest]} + else + find_statement_boundary(rest) + end + end + + defp find_statement_boundary([]), do: :not_found + + defp find_closing_delimiter([{:delimiter, delim, _} | rest], target, depth) + when delim == target do + if depth == 1 do + {:ok, rest} + else + find_closing_delimiter(rest, target, depth - 1) + end + end + + defp find_closing_delimiter([{:delimiter, opening, _} | rest], target, depth) + when opening in [:lparen, 
:lbracket, :lbrace] do + find_closing_delimiter(rest, target, depth + 1) + end + + defp find_closing_delimiter([{:keyword, :end, _} | rest], :end, depth) do + if depth == 1 do + {:ok, rest} + else + find_closing_delimiter(rest, :end, depth - 1) + end + end + + defp find_closing_delimiter([{:keyword, kw, _} | rest], :end, depth) + when kw in [:if, :while, :for, :function, :do] do + find_closing_delimiter(rest, :end, depth + 1) + end + + defp find_closing_delimiter([{:eof, _}], _target, _depth), do: :not_found + defp find_closing_delimiter([], _target, _depth), do: :not_found + + defp find_closing_delimiter([_ | rest], target, depth) do + find_closing_delimiter(rest, target, depth) + end + + defp find_keyword([{:keyword, kw, _} | _rest] = tokens, target) when kw == target do + {:ok, tokens} + end + + defp find_keyword([{:eof, _}], _target), do: :not_found + defp find_keyword([], _target), do: :not_found + + defp find_keyword([_ | rest], target) do + find_keyword(rest, target) + end + + defp closing_delimiter(delimiter) do + case delimiter do + :lparen -> :rparen + :lbracket -> :rbracket + :lbrace -> :rbrace + _ -> :end + end + end +end diff --git a/test/lua/ast/builder_test.exs b/test/lua/ast/builder_test.exs new file mode 100644 index 0000000..3184030 --- /dev/null +++ b/test/lua/ast/builder_test.exs @@ -0,0 +1,544 @@ +defmodule Lua.AST.BuilderTest do + use ExUnit.Case, async: true + + import Lua.AST.Builder + alias Lua.AST.{Chunk, Block, Expr, Statement} + + describe "chunk and block" do + test "creates a chunk" do + ast = chunk([local(["x"], [number(42)])]) + assert %Chunk{block: %Block{stmts: [%Statement.Local{}]}} = ast + end + + test "creates a block" do + blk = block([local(["x"], [number(42)])]) + assert %Block{stmts: [%Statement.Local{}]} = blk + end + end + + describe "literals" do + test "creates nil literal" do + assert %Expr.Nil{} = nil_lit() + end + + test "creates boolean literals" do + assert %Expr.Bool{value: true} = bool(true) + assert 
%Expr.Bool{value: false} = bool(false)
    end

    test "creates number literal" do
      assert %Expr.Number{value: 42} = number(42)
      assert %Expr.Number{value: 3.14} = number(3.14)
    end

    test "creates string literal" do
      assert %Expr.String{value: "hello"} = string("hello")
    end

    test "creates vararg" do
      assert %Expr.Vararg{} = vararg()
    end
  end

  describe "variables and access" do
    test "creates variable reference" do
      assert %Expr.Var{name: "x"} = var("x")
    end

    test "creates property access" do
      prop = property(var("io"), "write")
      assert %Expr.Property{table: %Expr.Var{name: "io"}, field: "write"} = prop
    end

    test "creates index access" do
      idx = index(var("t"), number(1))
      assert %Expr.Index{table: %Expr.Var{name: "t"}, key: %Expr.Number{value: 1}} = idx
    end

    test "creates chained property access" do
      prop = property(property(var("a"), "b"), "c")

      assert %Expr.Property{
               table: %Expr.Property{
                 table: %Expr.Var{name: "a"},
                 field: "b"
               },
               field: "c"
             } = prop
    end
  end

  describe "operators" do
    test "creates binary operation" do
      op = binop(:add, number(2), number(3))

      assert %Expr.BinOp{op: :add, left: %Expr.Number{value: 2}, right: %Expr.Number{value: 3}} =
               op
    end

    test "creates all binary operators" do
      ops = [
        :add,
        :sub,
        :mul,
        :div,
        :floor_div,
        :mod,
        :pow,
        :concat,
        :eq,
        :ne,
        :lt,
        :gt,
        :le,
        :ge,
        :and,
        :or
      ]

      for op <- ops do
        assert %Expr.BinOp{op: ^op} = binop(op, number(1), number(2))
      end
    end

    test "creates unary operation" do
      op = unop(:neg, var("x"))
      assert %Expr.UnOp{op: :neg, operand: %Expr.Var{name: "x"}} = op
    end

    test "creates all unary operators" do
      assert %Expr.UnOp{op: :not} = unop(:not, var("x"))
      assert %Expr.UnOp{op: :neg} = unop(:neg, var("x"))
      assert %Expr.UnOp{op: :len} = unop(:len, var("x"))
    end

    test "creates nested operations" do
      # (2 + 3) * 4
      op = binop(:mul, binop(:add, number(2), number(3)), number(4))

      assert %Expr.BinOp{
               op: :mul,
               left: %Expr.BinOp{op: :add},
               right: %Expr.Number{value: 4}
             } = op
    end
  end

  describe "table constructors" do
    test "creates empty table" do
      tbl = table([])
      assert %Expr.Table{fields: []} = tbl
    end

    test "creates array-style table" do
      tbl =
        table([
          {:list, number(1)},
          {:list, number(2)},
          {:list, number(3)}
        ])

      assert %Expr.Table{fields: [{:list, _}, {:list, _}, {:list, _}]} = tbl
    end

    test "creates record-style table" do
      tbl =
        table([
          {:record, string("x"), number(10)},
          {:record, string("y"), number(20)}
        ])

      assert %Expr.Table{fields: [{:record, _, _}, {:record, _, _}]} = tbl
    end

    test "creates mixed table" do
      tbl =
        table([
          {:list, number(1)},
          {:record, string("x"), number(10)}
        ])

      assert %Expr.Table{fields: [{:list, _}, {:record, _, _}]} = tbl
    end
  end

  describe "function calls" do
    test "creates function call" do
      c = call(var("print"), [string("hello")])

      assert %Expr.Call{
               func: %Expr.Var{name: "print"},
               args: [%Expr.String{value: "hello"}]
             } = c
    end

    test "creates function call with multiple arguments" do
      c = call(var("print"), [number(1), number(2), number(3)])
      assert %Expr.Call{args: [_, _, _]} = c
    end

    test "creates method call" do
      mc = method_call(var("file"), "read", [string("*a")])

      assert %Expr.MethodCall{
               object: %Expr.Var{name: "file"},
               method: "read",
               args: [%Expr.String{value: "*a"}]
             } = mc
    end
  end

  describe "function expressions" do
    test "creates simple function" do
      fn_expr = function_expr(["x"], [return_stmt([var("x")])])

      assert %Expr.Function{
               params: ["x"],
               body: %Block{stmts: [%Statement.Return{}]}
             } = fn_expr
    end

    test "creates function with multiple parameters" do
      fn_expr = function_expr(["a", "b"], [return_stmt([binop(:add, var("a"), var("b"))])])
      assert %Expr.Function{params: ["a", "b"]} = fn_expr
    end

    test "creates function with vararg" do
      fn_expr = function_expr([], [return_stmt([vararg()])], vararg: true)
      assert %Expr.Function{params: [:vararg]} = fn_expr
    end
  end

  describe "statements" do
    test "creates assignment" do
      stmt = assign([var("x")], [number(42)])

      assert %Statement.Assign{
               targets: [%Expr.Var{name: "x"}],
               values: [%Expr.Number{value: 42}]
             } = stmt
    end

    test "creates multiple assignment" do
      stmt = assign([var("x"), var("y")], [number(1), number(2)])
      assert %Statement.Assign{targets: [_, _], values: [_, _]} = stmt
    end

    test "creates local declaration" do
      stmt = local(["x"], [number(42)])
      assert %Statement.Local{names: ["x"], values: [%Expr.Number{value: 42}]} = stmt
    end

    test "creates local declaration without value" do
      stmt = local(["x"], [])
      assert %Statement.Local{names: ["x"], values: []} = stmt
    end

    test "creates local function" do
      stmt = local_func("add", ["a", "b"], [return_stmt([binop(:add, var("a"), var("b"))])])

      assert %Statement.LocalFunc{
               name: "add",
               params: ["a", "b"],
               body: %Block{}
             } = stmt
    end

    test "creates function declaration with string name" do
      stmt = func_decl("add", ["a", "b"], [return_stmt([binop(:add, var("a"), var("b"))])])
      assert %Statement.FuncDecl{name: ["add"], params: ["a", "b"]} = stmt
    end

    test "creates function declaration with path name" do
      stmt =
        func_decl(["math", "add"], ["a", "b"], [return_stmt([binop(:add, var("a"), var("b"))])])

      assert %Statement.FuncDecl{name: ["math", "add"]} = stmt
    end

    test "creates call statement" do
      stmt = call_stmt(call(var("print"), [string("hello")]))
      assert %Statement.CallStmt{call: %Expr.Call{}} = stmt
    end

    test "creates return statement" do
      stmt = return_stmt([])
      assert %Statement.Return{values: []} = stmt

      stmt = return_stmt([number(42)])
      assert %Statement.Return{values: [%Expr.Number{value: 42}]} = stmt
    end

    test "creates break statement" do
      stmt = break_stmt()
      assert %Statement.Break{} = stmt
    end

    test "creates goto statement" do
      stmt = goto_stmt("label")
      assert %Statement.Goto{label: "label"} = stmt
    end

    test "creates label" do
      stmt = label("label")
      assert %Statement.Label{name: "label"} = stmt
    end
  end

  describe "control flow" do
    test "creates if statement" do
      stmt = if_stmt(var("x"), [return_stmt([number(1)])])

      assert %Statement.If{
               condition: %Expr.Var{name: "x"},
               then_block: %Block{stmts: [%Statement.Return{}]},
               elseifs: [],
               else_block: nil
             } = stmt
    end

    test "creates if-else statement" do
      stmt =
        if_stmt(
          var("x"),
          [return_stmt([number(1)])],
          else: [return_stmt([number(0)])]
        )

      assert %Statement.If{else_block: %Block{}} = stmt
    end

    test "creates if-elseif-else statement" do
      stmt =
        if_stmt(
          binop(:gt, var("x"), number(0)),
          [return_stmt([number(1)])],
          elseif: [{binop(:lt, var("x"), number(0)), [return_stmt([unop(:neg, number(1))])]}],
          else: [return_stmt([number(0)])]
        )

      assert %Statement.If{
               elseifs: [{_, %Block{}}],
               else_block: %Block{}
             } = stmt
    end

    test "creates while loop" do
      stmt =
        while_stmt(binop(:gt, var("x"), number(0)), [
          assign([var("x")], [binop(:sub, var("x"), number(1))])
        ])

      assert %Statement.While{
               condition: %Expr.BinOp{op: :gt},
               body: %Block{}
             } = stmt
    end

    test "creates repeat-until loop" do
      stmt =
        repeat_stmt(
          [assign([var("x")], [binop(:sub, var("x"), number(1))])],
          binop(:le, var("x"), number(0))
        )

      assert %Statement.Repeat{
               body: %Block{},
               condition: %Expr.BinOp{op: :le}
             } = stmt
    end

    test "creates numeric for loop" do
      stmt =
        for_num("i", number(1), number(10), [
          call_stmt(call(var("print"), [var("i")]))
        ])

      assert %Statement.ForNum{
               var: "i",
               start: %Expr.Number{value: 1},
               limit: %Expr.Number{value: 10},
               step: nil,
               body: %Block{}
             } = stmt
    end

    test "creates numeric for loop with step" do
      stmt =
        for_num(
          "i",
          number(1),
          number(10),
          [
            call_stmt(call(var("print"), [var("i")]))
          ],
          step: number(2)
        )

      assert %Statement.ForNum{step: %Expr.Number{value: 2}} = stmt
    end

    test "creates generic for loop" do
      stmt =
        for_in(
          ["k", "v"],
          [call(var("pairs"), [var("t")])],
          [call_stmt(call(var("print"), [var("k"), var("v")]))]
        )

      assert %Statement.ForIn{
               vars: ["k", "v"],
               iterators: [%Expr.Call{}],
               body: %Block{}
             } = stmt
    end

    test "creates do block" do
      stmt =
        do_block([
          local(["x"], [number(10)]),
          call_stmt(call(var("print"), [var("x")]))
        ])

      assert %Statement.Do{body: %Block{stmts: [_, _]}} = stmt
    end
  end

  describe "complex structures" do
    test "builds nested function with closure" do
      # function outer(x) return function(y) return x + y end end
      ast =
        chunk([
          func_decl("outer", ["x"], [
            return_stmt([
              function_expr(["y"], [
                return_stmt([binop(:add, var("x"), var("y"))])
              ])
            ])
          ])
        ])

      assert %Chunk{
               block: %Block{
                 stmts: [
                   %Statement.FuncDecl{
                     name: ["outer"],
                     body: %Block{
                       stmts: [
                         %Statement.Return{
                           values: [%Expr.Function{}]
                         }
                       ]
                     }
                   }
                 ]
               }
             } = ast
    end

    test "builds complex if-elseif-else chain" do
      ast =
        chunk([
          if_stmt(
            binop(:gt, var("x"), number(0)),
            [return_stmt([string("positive")])],
            elseif: [
              {binop(:lt, var("x"), number(0)), [return_stmt([string("negative")])]},
              {binop(:eq, var("x"), number(0)), [return_stmt([string("zero")])]}
            ],
            else: [return_stmt([string("unknown")])]
          )
        ])

      assert %Chunk{
               block: %Block{
                 stmts: [
                   %Statement.If{
                     elseifs: [{_, _}, {_, _}],
                     else_block: %Block{}
                   }
                 ]
               }
             } = ast
    end

    test "builds nested loops" do
      # for i = 1, 10 do
      #   for j = 1, 10 do
      #     print(i * j)
      #   end
      # end
      ast =
        chunk([
          for_num("i", number(1), number(10), [
            for_num("j", number(1), number(10), [
              call_stmt(call(var("print"), [binop(:mul, var("i"), var("j"))]))
            ])
          ])
        ])

      assert %Chunk{
               block: %Block{
                 stmts: [
                   %Statement.ForNum{
                     body: %Block{
                       stmts: [%Statement.ForNum{}]
                     }
                   }
                 ]
               }
             } = ast
    end

    test "builds table with complex expressions" do
      # {
      #   x = 1 + 2,
      #   y = func(),
      #   [key] = value,
      #   nested = {a = 1, b = 2}
      # }
      tbl =
        table([
          {:record, string("x"), binop(:add, number(1), number(2))},
          {:record, string("y"), call(var("func"), [])},
          {:record, var("key"), var("value")},
          {:record, string("nested"),
           table([
             {:record, string("a"), number(1)},
             {:record, string("b"), number(2)}
           ])}
        ])

      assert %Expr.Table{
               fields: [
                 {:record, %Expr.String{value: "x"}, %Expr.BinOp{}},
                 {:record, %Expr.String{value: "y"}, %Expr.Call{}},
                 {:record, %Expr.Var{}, %Expr.Var{}},
                 {:record, %Expr.String{value: "nested"}, %Expr.Table{}}
               ]
             } = tbl
    end
  end

  describe "integration with parser" do
    test "builder output can be printed and reparsed" do
      # Build an AST using builder
      ast =
        chunk([
          local(["x"], [number(10)]),
          local(["y"], [number(20)]),
          assign([var("z")], [binop(:add, var("x"), var("y"))]),
          call_stmt(call(var("print"), [var("z")]))
        ])

      # Print it
      code = Lua.AST.PrettyPrinter.print(ast)

      # Parse it back
      {:ok, reparsed} = Lua.Parser.parse(code)

      # Should have same structure (ignoring meta)
      assert length(ast.block.stmts) == length(reparsed.block.stmts)
    end
  end
end
diff --git a/test/lua/ast/meta_test.exs b/test/lua/ast/meta_test.exs
new file mode 100644
index 0000000..526a5c0
--- /dev/null
+++ b/test/lua/ast/meta_test.exs
@@ -0,0 +1,251 @@
defmodule Lua.AST.MetaTest do
  use ExUnit.Case, async: true

  alias Lua.AST.Meta

  describe "new/0" do
    test "creates empty meta" do
      meta = Meta.new()
      assert %Meta{start: nil, end: nil, metadata: %{}} = meta
    end
  end

  describe "new/2" do
    test "creates meta with start and end positions" do
      start = %{line: 1, column: 1, byte_offset: 0}
      finish = %{line: 1, column: 10, byte_offset: 9}

      meta = Meta.new(start, finish)

      assert meta.start == start
      assert meta.end == finish
assert meta.metadata == %{} + end + + test "accepts nil positions" do + meta = Meta.new(nil, nil) + assert meta.start == nil + assert meta.end == nil + end + end + + describe "new/3" do + test "creates meta with metadata" do + start = %{line: 1, column: 1, byte_offset: 0} + finish = %{line: 1, column: 10, byte_offset: 9} + metadata = %{type: :test, custom: "value"} + + meta = Meta.new(start, finish, metadata) + + assert meta.start == start + assert meta.end == finish + assert meta.metadata == metadata + end + + test "accepts empty metadata map" do + start = %{line: 1, column: 1, byte_offset: 0} + finish = %{line: 1, column: 10, byte_offset: 9} + + meta = Meta.new(start, finish, %{}) + + assert meta.metadata == %{} + end + end + + describe "add_metadata/3" do + test "adds metadata to existing meta" do + meta = Meta.new() + meta = Meta.add_metadata(meta, :test, "value") + + assert meta.metadata == %{test: "value"} + end + + test "adds multiple metadata fields" do + meta = Meta.new(nil, nil, %{a: 1}) + meta = Meta.add_metadata(meta, :b, 2) + + assert meta.metadata == %{a: 1, b: 2} + end + + test "overwrites existing keys" do + meta = Meta.new(nil, nil, %{a: 1}) + meta = Meta.add_metadata(meta, :a, 2) + + assert meta.metadata == %{a: 2} + end + end + + describe "merge/2" do + test "merges two metas taking earliest start" do + meta1 = + Meta.new(%{line: 1, column: 5, byte_offset: 10}, %{line: 1, column: 10, byte_offset: 20}) + + meta2 = + Meta.new(%{line: 1, column: 1, byte_offset: 5}, %{line: 1, column: 8, byte_offset: 15}) + + merged = Meta.merge(meta1, meta2) + + assert merged.start == %{line: 1, column: 1, byte_offset: 5} + end + + test "merges two metas taking latest end" do + meta1 = + Meta.new(%{line: 1, column: 1, byte_offset: 0}, %{line: 1, column: 10, byte_offset: 9}) + + meta2 = + Meta.new(%{line: 1, column: 5, byte_offset: 4}, %{line: 1, column: 20, byte_offset: 19}) + + merged = Meta.merge(meta1, meta2) + + assert merged.end == %{line: 1, column: 20, 
byte_offset: 19} + end + + test "handles nil start positions" do + meta1 = Meta.new(nil, %{line: 1, column: 10, byte_offset: 9}) + meta2 = Meta.new(%{line: 1, column: 1, byte_offset: 0}, nil) + + merged = Meta.merge(meta1, meta2) + + assert merged.start == %{line: 1, column: 1, byte_offset: 0} + assert merged.end == %{line: 1, column: 10, byte_offset: 9} + end + + test "handles nil end positions" do + meta1 = Meta.new(%{line: 1, column: 1, byte_offset: 0}, nil) + meta2 = Meta.new(nil, %{line: 1, column: 10, byte_offset: 9}) + + merged = Meta.merge(meta1, meta2) + + assert merged.start == %{line: 1, column: 1, byte_offset: 0} + assert merged.end == %{line: 1, column: 10, byte_offset: 9} + end + + test "handles both positions nil on first meta" do + meta1 = Meta.new(nil, nil) + + meta2 = + Meta.new(%{line: 1, column: 1, byte_offset: 0}, %{line: 1, column: 10, byte_offset: 9}) + + merged = Meta.merge(meta1, meta2) + + assert merged.start == %{line: 1, column: 1, byte_offset: 0} + assert merged.end == %{line: 1, column: 10, byte_offset: 9} + end + + test "handles both positions nil on second meta" do + meta1 = + Meta.new(%{line: 1, column: 1, byte_offset: 0}, %{line: 1, column: 10, byte_offset: 9}) + + meta2 = Meta.new(nil, nil) + + merged = Meta.merge(meta1, meta2) + + assert merged.start == %{line: 1, column: 1, byte_offset: 0} + assert merged.end == %{line: 1, column: 10, byte_offset: 9} + end + + test "chooses earlier position when first has earlier byte offset" do + meta1 = + Meta.new(%{line: 1, column: 1, byte_offset: 0}, %{line: 1, column: 5, byte_offset: 4}) + + meta2 = + Meta.new(%{line: 1, column: 6, byte_offset: 5}, %{line: 1, column: 10, byte_offset: 9}) + + merged = Meta.merge(meta1, meta2) + + assert merged.start == %{line: 1, column: 1, byte_offset: 0} + assert merged.end == %{line: 1, column: 10, byte_offset: 9} + end + + test "chooses later position when second has later byte offset" do + meta1 = + Meta.new(%{line: 1, column: 1, byte_offset: 0}, 
%{line: 1, column: 5, byte_offset: 4}) + + meta2 = + Meta.new(%{line: 1, column: 2, byte_offset: 1}, %{line: 1, column: 10, byte_offset: 9}) + + merged = Meta.merge(meta1, meta2) + + assert merged.start == %{line: 1, column: 1, byte_offset: 0} + assert merged.end == %{line: 1, column: 10, byte_offset: 9} + end + end + + describe "position tracking" do + test "stores line numbers" do + meta = + Meta.new(%{line: 5, column: 10, byte_offset: 50}, %{line: 5, column: 20, byte_offset: 60}) + + assert meta.start.line == 5 + assert meta.end.line == 5 + end + + test "stores column numbers" do + meta = + Meta.new(%{line: 1, column: 5, byte_offset: 4}, %{line: 1, column: 15, byte_offset: 14}) + + assert meta.start.column == 5 + assert meta.end.column == 15 + end + + test "stores byte offsets" do + meta = + Meta.new(%{line: 1, column: 1, byte_offset: 100}, %{line: 1, column: 10, byte_offset: 200}) + + assert meta.start.byte_offset == 100 + assert meta.end.byte_offset == 200 + end + + test "handles multiline spans" do + meta = + Meta.new(%{line: 1, column: 1, byte_offset: 0}, %{line: 10, column: 5, byte_offset: 150}) + + assert meta.start.line == 1 + assert meta.end.line == 10 + end + end + + describe "metadata storage" do + test "stores arbitrary data" do + meta = + Meta.new(nil, nil, %{ + node_type: :function, + name: "test", + params: ["a", "b"], + is_async: false + }) + + assert meta.metadata.node_type == :function + assert meta.metadata.name == "test" + assert meta.metadata.params == ["a", "b"] + assert meta.metadata.is_async == false + end + + test "stores nested data structures" do + meta = + Meta.new(nil, nil, %{ + scope: %{ + variables: ["x", "y"], + functions: ["f", "g"] + } + }) + + assert meta.metadata.scope.variables == ["x", "y"] + assert meta.metadata.scope.functions == ["f", "g"] + end + end + + describe "struct validation" do + test "has correct struct fields" do + meta = %Meta{} + assert Map.has_key?(meta, :start) + assert Map.has_key?(meta, :end) + assert 
Map.has_key?(meta, :metadata)
    end

    test "is a proper struct" do
      meta = %Meta{}
      assert meta.__struct__ == Lua.AST.Meta
    end
  end
end
diff --git a/test/lua/ast/pretty_printer_test.exs b/test/lua/ast/pretty_printer_test.exs
new file mode 100644
index 0000000..996fd76
--- /dev/null
+++ b/test/lua/ast/pretty_printer_test.exs
@@ -0,0 +1,1144 @@
defmodule Lua.AST.PrettyPrinterTest do
  @moduledoc """
  Tests for `Lua.AST.PrettyPrinter`.

  Covers rendering of every AST node kind back to Lua source text:
  literals, variable/property/index access, operator precedence and
  associativity (parenthesization), table constructors and key
  formatting, calls, functions, statements, control flow, indentation,
  string escaping, and round-tripping through `Lua.Parser`.
  """

  use ExUnit.Case, async: true

  import Lua.AST.Builder
  alias Lua.AST.PrettyPrinter

  describe "literals" do
    test "prints nil" do
      assert PrettyPrinter.print(chunk([return_stmt([nil_lit()])])) == "return nil\n"
    end

    test "prints booleans" do
      assert PrettyPrinter.print(chunk([return_stmt([bool(true)])])) == "return true\n"
      assert PrettyPrinter.print(chunk([return_stmt([bool(false)])])) == "return false\n"
    end

    test "prints numbers" do
      assert PrettyPrinter.print(chunk([return_stmt([number(42)])])) == "return 42\n"
      assert PrettyPrinter.print(chunk([return_stmt([number(3.14)])])) == "return 3.14\n"
      assert PrettyPrinter.print(chunk([return_stmt([number(1.0)])])) == "return 1.0\n"
    end

    test "prints strings" do
      assert PrettyPrinter.print(chunk([return_stmt([string("hello")])])) == "return \"hello\"\n"
    end

    test "escapes special characters in strings" do
      ast = chunk([return_stmt([string("hello\nworld")])])
      result = PrettyPrinter.print(ast)
      assert result == "return \"hello\\nworld\"\n"
    end

    test "prints vararg" do
      assert PrettyPrinter.print(chunk([return_stmt([vararg()])])) == "return ...\n"
    end
  end

  describe "variables and access" do
    test "prints variable reference" do
      assert PrettyPrinter.print(chunk([return_stmt([var("x")])])) == "return x\n"
    end

    test "prints property access" do
      ast = chunk([return_stmt([property(var("io"), "write")])])
      assert PrettyPrinter.print(ast) == "return io.write\n"
    end

    test "prints index access" do
      ast = chunk([return_stmt([index(var("t"), number(1))])])
      assert PrettyPrinter.print(ast) == "return t[1]\n"
    end

    test "prints chained property access" do
      ast =
        chunk([
          return_stmt([
            property(property(var("a"), "b"), "c")
          ])
        ])

      assert PrettyPrinter.print(ast) == "return a.b.c\n"
    end
  end

  describe "operators" do
    test "prints binary operators" do
      ast = chunk([return_stmt([binop(:add, number(2), number(3))])])
      assert PrettyPrinter.print(ast) == "return 2 + 3\n"
    end

    test "prints unary operators" do
      ast = chunk([return_stmt([unop(:neg, var("x"))])])
      assert PrettyPrinter.print(ast) == "return -x\n"

      ast = chunk([return_stmt([unop(:not, var("flag"))])])
      assert PrettyPrinter.print(ast) == "return not flag\n"

      ast = chunk([return_stmt([unop(:len, var("list"))])])
      assert PrettyPrinter.print(ast) == "return #list\n"
    end

    test "handles operator precedence with parentheses" do
      # 2 + 3 * 4 should print as is (multiplication has higher precedence)
      ast =
        chunk([
          return_stmt([
            binop(:add, number(2), binop(:mul, number(3), number(4)))
          ])
        ])

      assert PrettyPrinter.print(ast) == "return 2 + 3 * 4\n"

      # (2 + 3) * 4 should have parentheses
      ast =
        chunk([
          return_stmt([
            binop(:mul, binop(:add, number(2), number(3)), number(4))
          ])
        ])

      assert PrettyPrinter.print(ast) == "return (2 + 3) * 4\n"
    end

    test "handles right-associative operators" do
      # 2 ^ 3 ^ 4 should print as 2 ^ 3 ^ 4 (right-associative)
      ast =
        chunk([
          return_stmt([
            binop(:pow, number(2), binop(:pow, number(3), number(4)))
          ])
        ])

      assert PrettyPrinter.print(ast) == "return 2 ^ 3 ^ 4\n"

      # (2 ^ 3) ^ 4 should have parentheses
      ast =
        chunk([
          return_stmt([
            binop(:pow, binop(:pow, number(2), number(3)), number(4))
          ])
        ])

      assert PrettyPrinter.print(ast) == "return (2 ^ 3) ^ 4\n"
    end

    test "handles unary minus with power" do
      # In Lua, ^ binds tighter than unary minus, so -2^3 parses as -(2^3);
      # the printer may or may not emit explicit parentheses here.
      ast = chunk([return_stmt([unop(:neg, binop(:pow, number(2), number(3)))])])
      result = PrettyPrinter.print(ast)
      # Either -2^3 or -(2^3) is acceptable
      assert result == "return -(2 ^ 3)\n" or result == "return -2 ^ 3\n"
    end
  end

  describe "table constructors" do
    test "prints empty table" do
      ast = chunk([return_stmt([table([])])])
      assert PrettyPrinter.print(ast) == "return {}\n"
    end

    test "prints array-style table" do
      ast =
        chunk([
          return_stmt([
            table([
              {:list, number(1)},
              {:list, number(2)},
              {:list, number(3)}
            ])
          ])
        ])

      assert PrettyPrinter.print(ast) == "return {1, 2, 3}\n"
    end

    test "prints record-style table" do
      ast =
        chunk([
          return_stmt([
            table([
              {:record, string("x"), number(10)},
              {:record, string("y"), number(20)}
            ])
          ])
        ])

      assert PrettyPrinter.print(ast) == "return {x = 10, y = 20}\n"
    end

    test "prints mixed table fields" do
      ast =
        chunk([
          return_stmt([
            table([
              {:list, number(1)},
              {:record, string("x"), number(10)}
            ])
          ])
        ])

      assert PrettyPrinter.print(ast) == "return {1, x = 10}\n"
    end
  end

  describe "function calls" do
    test "prints simple function call" do
      ast = chunk([call_stmt(call(var("print"), [string("hello")]))])
      assert PrettyPrinter.print(ast) == "print(\"hello\")\n"
    end

    test "prints function call with multiple arguments" do
      ast = chunk([call_stmt(call(var("print"), [number(1), number(2), number(3)]))])
      assert PrettyPrinter.print(ast) == "print(1, 2, 3)\n"
    end

    test "prints method call" do
      ast = chunk([call_stmt(method_call(var("file"), "read", [string("*a")]))])
      assert PrettyPrinter.print(ast) == "file:read(\"*a\")\n"
    end
  end

  describe "function expressions" do
    test "prints simple function" do
      ast =
        chunk([
          return_stmt([
            function_expr(["x"], [return_stmt([var("x")])])
          ])
        ])

      result = PrettyPrinter.print(ast)
      assert result =~ "function(x)"
      assert result =~ "return x"
      assert result =~ "end"
    end

    test "prints function with multiple parameters" do
      ast =
        chunk([
          return_stmt([
            function_expr(["a", "b"], [return_stmt([binop(:add, var("a"), var("b"))])])
          ])
        ])

      result = PrettyPrinter.print(ast)
      assert result =~ "function(a, b)"
      assert result =~ "return a + b"
    end

    test "prints function with vararg" do
      ast =
        chunk([
          return_stmt([
            function_expr([], [return_stmt([vararg()])], vararg: true)
          ])
        ])

      result = PrettyPrinter.print(ast)
      assert result =~ "function(...)"
    end
  end

  describe "statements" do
    test "prints assignment" do
      ast = chunk([assign([var("x")], [number(42)])])
      assert PrettyPrinter.print(ast) == "x = 42\n"
    end

    test "prints multiple assignment" do
      ast = chunk([assign([var("x"), var("y")], [number(1), number(2)])])
      assert PrettyPrinter.print(ast) == "x, y = 1, 2\n"
    end

    test "prints local declaration" do
      ast = chunk([local(["x"], [number(42)])])
      assert PrettyPrinter.print(ast) == "local x = 42\n"
    end

    test "prints local declaration without value" do
      ast = chunk([local(["x"], [])])
      assert PrettyPrinter.print(ast) == "local x\n"
    end

    test "prints local function" do
      ast =
        chunk([local_func("add", ["a", "b"], [return_stmt([binop(:add, var("a"), var("b"))])])])

      result = PrettyPrinter.print(ast)
      assert result =~ "local function add(a, b)"
      assert result =~ "return a + b"
      assert result =~ "end"
    end

    test "prints function declaration" do
      ast =
        chunk([func_decl("add", ["a", "b"], [return_stmt([binop(:add, var("a"), var("b"))])])])

      result = PrettyPrinter.print(ast)
      assert result =~ "function add(a, b)"
      assert result =~ "return a + b"
      assert result =~ "end"
    end

    test "prints return statement" do
      ast = chunk([return_stmt([])])
      assert PrettyPrinter.print(ast) == "return\n"

      ast = chunk([return_stmt([number(42)])])
      assert PrettyPrinter.print(ast) == "return 42\n"
    end

    test "prints break statement" do
      ast = chunk([break_stmt()])
      assert PrettyPrinter.print(ast) == "break\n"
    end
  end

  describe "control flow" do
    test "prints if statement" do
      ast =
        chunk([
          if_stmt(var("x"), [return_stmt([number(1)])])
        ])

      result = PrettyPrinter.print(ast)
      assert result =~ "if x then"
      assert result =~ "return 1"
      assert result =~ "end"
    end

    test "prints if-else statement" do
      ast =
        chunk([
          if_stmt(
            var("x"),
            [return_stmt([number(1)])],
            else: [return_stmt([number(0)])]
          )
        ])

      result = PrettyPrinter.print(ast)
      assert result =~ "if x then"
      assert result =~ "else"
      assert result =~ "end"
    end

    test "prints if-elseif-else statement" do
      ast =
        chunk([
          if_stmt(
            binop(:gt, var("x"), number(0)),
            [return_stmt([number(1)])],
            elseif: [{binop(:lt, var("x"), number(0)), [return_stmt([unop(:neg, number(1))])]}],
            else: [return_stmt([number(0)])]
          )
        ])

      result = PrettyPrinter.print(ast)
      assert result =~ "if x > 0 then"
      assert result =~ "elseif x < 0 then"
      assert result =~ "else"
      assert result =~ "end"
    end

    test "prints while loop" do
      ast =
        chunk([
          while_stmt(binop(:gt, var("x"), number(0)), [
            assign([var("x")], [binop(:sub, var("x"), number(1))])
          ])
        ])

      result = PrettyPrinter.print(ast)
      assert result =~ "while x > 0 do"
      assert result =~ "x = x - 1"
      assert result =~ "end"
    end

    test "prints repeat-until loop" do
      ast =
        chunk([
          repeat_stmt(
            [assign([var("x")], [binop(:sub, var("x"), number(1))])],
            binop(:le, var("x"), number(0))
          )
        ])

      result = PrettyPrinter.print(ast)
      assert result =~ "repeat"
      assert result =~ "x = x - 1"
      assert result =~ "until x <= 0"
    end

    test "prints numeric for loop" do
      ast =
        chunk([
          for_num("i", number(1), number(10), [
            call_stmt(call(var("print"), [var("i")]))
          ])
        ])

      result = PrettyPrinter.print(ast)
      assert result =~ "for i = 1, 10 do"
      assert result =~ "print(i)"
      assert result =~ "end"
    end

    test "prints numeric for loop with step" do
      ast =
        chunk([
          for_num(
            "i",
            number(1),
            number(10),
            [
              call_stmt(call(var("print"), [var("i")]))
            ],
            step: number(2)
          )
        ])

      result = PrettyPrinter.print(ast)
      assert result =~ "for i = 1, 10, 2 do"
    end

    test "prints generic for loop" do
      ast =
        chunk([
          for_in(
            ["k", "v"],
            [call(var("pairs"), [var("t")])],
            [call_stmt(call(var("print"), [var("k"), var("v")]))]
          )
        ])

      result = PrettyPrinter.print(ast)
      assert result =~ "for k, v in pairs(t) do"
      assert result =~ "print(k, v)"
      assert result =~ "end"
    end

    test "prints do block" do
      ast =
        chunk([
          do_block([
            local(["x"], [number(10)]),
            call_stmt(call(var("print"), [var("x")]))
          ])
        ])

      result = PrettyPrinter.print(ast)
      assert result =~ "do"
      assert result =~ "local x = 10"
      assert result =~ "print(x)"
      assert result =~ "end"
    end
  end

  describe "indentation" do
    test "indents nested blocks" do
      ast =
        chunk([
          if_stmt(var("x"), [
            if_stmt(var("y"), [
              return_stmt([number(1)])
            ])
          ])
        ])

      result = PrettyPrinter.print(ast)
      # Check that nested blocks are indented
      lines = String.split(result, "\n", trim: true)
      assert Enum.any?(lines, fn line -> String.starts_with?(line, "  ") end)
    end

    test "respects custom indent size" do
      ast =
        chunk([
          if_stmt(var("x"), [
            return_stmt([number(1)])
          ])
        ])

      result = PrettyPrinter.print(ast, indent: 4)
      assert result =~ "    return 1"
    end
  end

  describe "round-trip" do
    test "can round-trip simple expressions" do
      original = "return 2 + 3\n"
      {:ok, ast} = Lua.Parser.parse(original)
      printed = PrettyPrinter.print(ast)
      assert printed == original
    end

    test "can round-trip local assignments" do
      original = "local x = 42\n"
      {:ok, ast} = Lua.Parser.parse(original)
      printed = PrettyPrinter.print(ast)
      assert printed == original
    end

    test "can round-trip function declarations" do
      code = """
      function add(a, b)
        return a + b
      end
      """

      {:ok, ast} = Lua.Parser.parse(code)
      printed = PrettyPrinter.print(ast)

      # Parse again to verify structure matches
      {:ok, ast2} = Lua.Parser.parse(printed)

      # Compare AST structures (ignoring meta)
      assert ast.block.stmts |> length() == ast2.block.stmts |> length()
    end
  end

  describe "string escaping" do
    test "escapes backslash" do
      ast = chunk([return_stmt([string("path\\to\\file")])])
      result = PrettyPrinter.print(ast)
      assert result == "return \"path\\\\to\\\\file\"\n"
    end

    test "escapes double quotes" do
      ast = chunk([return_stmt([string("say \"hello\"")])])
      result = PrettyPrinter.print(ast)
      assert result == "return \"say \\\"hello\\\"\"\n"
    end

    test "escapes tab character" do
      ast = chunk([return_stmt([string("hello\tworld")])])
      result = PrettyPrinter.print(ast)
      assert result == "return \"hello\\tworld\"\n"
    end

    test "escapes all special characters together" do
      ast = chunk([return_stmt([string("line1\n\"quote\"\ttab\\back")])])
      result = PrettyPrinter.print(ast)
      assert result == "return \"line1\\n\\\"quote\\\"\\ttab\\\\back\"\n"
    end
  end

  describe "all binary operators" do
    test "prints floor division" do
      ast = chunk([return_stmt([binop(:floor_div, number(10), number(3))])])
      assert PrettyPrinter.print(ast) == "return 10 // 3\n"
    end

    test "prints modulo" do
      ast = chunk([return_stmt([binop(:mod, number(10), number(3))])])
      assert PrettyPrinter.print(ast) == "return 10 % 3\n"
    end

    test "prints concatenation" do
      ast = chunk([return_stmt([binop(:concat, string("hello"), string("world"))])])
      assert PrettyPrinter.print(ast) == "return \"hello\" .. \"world\"\n"
    end

    test "prints equality" do
      ast = chunk([return_stmt([binop(:eq, var("x"), var("y"))])])
      assert PrettyPrinter.print(ast) == "return x == y\n"
    end

    test "prints inequality" do
      ast = chunk([return_stmt([binop(:ne, var("x"), var("y"))])])
      assert PrettyPrinter.print(ast) == "return x ~= y\n"
    end

    test "prints less than or equal" do
      ast = chunk([return_stmt([binop(:le, var("x"), var("y"))])])
      assert PrettyPrinter.print(ast) == "return x <= y\n"
    end

    test "prints greater than or equal" do
      ast = chunk([return_stmt([binop(:ge, var("x"), var("y"))])])
      assert PrettyPrinter.print(ast) == "return x >= y\n"
    end

    test "prints logical and" do
      ast = chunk([return_stmt([binop(:and, var("x"), var("y"))])])
      assert PrettyPrinter.print(ast) == "return x and y\n"
    end

    test "prints logical or" do
      ast = chunk([return_stmt([binop(:or, var("x"), var("y"))])])
      assert PrettyPrinter.print(ast) == "return x or y\n"
    end
  end

  describe "operator associativity" do
    test "handles concat right-associativity" do
      # "a" .. "b" .. "c" should print as is (right-associative)
      ast =
        chunk([
          return_stmt([
            binop(:concat, string("a"), binop(:concat, string("b"), string("c")))
          ])
        ])

      assert PrettyPrinter.print(ast) == "return \"a\" .. \"b\" .. \"c\"\n"

      # ("a" .. "b") .. "c" should have parentheses
      ast =
        chunk([
          return_stmt([
            binop(:concat, binop(:concat, string("a"), string("b")), string("c"))
          ])
        ])

      assert PrettyPrinter.print(ast) == "return (\"a\" .. \"b\") .. \"c\"\n"
    end

    test "handles subtraction left-associativity" do
      # (10 - 5) - 2 should print as is (left-associative)
      ast =
        chunk([
          return_stmt([
            binop(:sub, binop(:sub, number(10), number(5)), number(2))
          ])
        ])

      assert PrettyPrinter.print(ast) == "return 10 - 5 - 2\n"

      # 10 - (5 - 2) should have parentheses
      ast =
        chunk([
          return_stmt([
            binop(:sub, number(10), binop(:sub, number(5), number(2)))
          ])
        ])

      assert PrettyPrinter.print(ast) == "return 10 - (5 - 2)\n"
    end
  end

  describe "table key formatting" do
    test "uses bracket notation for non-identifier strings" do
      ast =
        chunk([
          return_stmt([
            table([
              {:record, string("not-valid"), number(1)},
              {:record, string("123"), number(2)},
              {:record, string("with space"), number(3)}
            ])
          ])
        ])

      result = PrettyPrinter.print(ast)
      assert result =~ "[\"not-valid\"] = 1"
      assert result =~ "[\"123\"] = 2"
      assert result =~ "[\"with space\"] = 3"
    end

    test "uses bracket notation for Lua keywords" do
      ast =
        chunk([
          return_stmt([
            table([
              {:record, string("end"), number(1)},
              {:record, string("while"), number(2)},
              {:record, string("function"), number(3)}
            ])
          ])
        ])

      result = PrettyPrinter.print(ast)
      assert result =~ "[\"end\"] = 1"
      assert result =~ "[\"while\"] = 2"
      assert result =~ "[\"function\"] = 3"
    end

    test "uses bracket notation for non-string keys" do
      ast =
        chunk([
          return_stmt([
            table([
              {:record, number(1), string("first")},
              {:record, var("x"), string("variable")}
            ])
          ])
        ])

      result = PrettyPrinter.print(ast)
      assert result =~ "[1] = \"first\""
      assert result =~ "[x] = \"variable\""
    end
  end

  describe "dotted function names" do
    test "prints function with dotted name" do
      ast =
        chunk([
          func_decl(["math", "add"], ["a", "b"], [return_stmt([binop(:add, var("a"), var("b"))])])
        ])

      result = PrettyPrinter.print(ast)
      assert result =~ "function math.add(a, b)"
    end

    test "prints function with deeply nested name" do
      ast =
        chunk([
          func_decl(["a", "b", "c"], ["x"], [return_stmt([var("x")])])
        ])

      result = PrettyPrinter.print(ast)
      assert result =~ "function a.b.c(x)"
    end
  end

  describe "local statement variations" do
    test "prints local with nil values" do
      ast = chunk([local(["x"], nil)])
      assert PrettyPrinter.print(ast) == "local x\n"
    end

    test "prints multiple local variables" do
      ast = chunk([local(["x", "y", "z"], [])])
      assert PrettyPrinter.print(ast) == "local x, y, z\n"
    end

    test "prints multiple local variables with values" do
      ast = chunk([local(["x", "y"], [number(1), number(2)])])
      assert PrettyPrinter.print(ast) == "local x, y = 1, 2\n"
    end
  end

  describe "goto and label statements" do
    test "prints goto statement" do
      ast = chunk([goto_stmt("skip")])
      assert PrettyPrinter.print(ast) == "goto skip\n"
    end

    test "prints label statement" do
      ast = chunk([label("skip")])
      assert PrettyPrinter.print(ast) == "::skip::\n"
    end
  end

  describe "function calls with no arguments" do
    test "prints function call with no arguments" do
      ast = chunk([call_stmt(call(var("print"), []))])
      assert PrettyPrinter.print(ast) == "print()\n"
    end

    test "prints method call with no arguments" do
      ast = chunk([call_stmt(method_call(var("obj"), "method", []))])
      assert PrettyPrinter.print(ast) == "obj:method()\n"
    end
  end

  describe "if statement variations" do
    test "prints if statement without elseif" do
      ast =
        chunk([
          if_stmt(
            var("x"),
            [return_stmt([number(1)])],
            elseif: []
          )
        ])

      result = PrettyPrinter.print(ast)
      assert result =~ "if x then"
      assert result =~ "return 1"
      assert result =~ "end"
    end

    test "prints if statement with multiple elseifs" do
      ast =
        chunk([
          if_stmt(
            binop(:eq, var("x"), number(1)),
            [return_stmt([string("one")])],
            elseif: [
              {binop(:eq, var("x"), number(2)), [return_stmt([string("two")])]},
              {binop(:eq, var("x"), number(3)), [return_stmt([string("three")])]}
            ],
            else: [return_stmt([string("other")])]
          )
        ])

      result = PrettyPrinter.print(ast)
      assert result =~ "if x == 1 then"
      assert result =~ "elseif x == 2 then"
      assert result =~ "elseif x == 3 then"
      assert result =~ "else"
      assert result =~ "end"
    end
  end

  describe "local function with vararg" do
    test "prints local function with vararg parameter" do
      ast =
        chunk([
          local_func("variadic", [], [return_stmt([vararg()])], vararg: true)
        ])

      result = PrettyPrinter.print(ast)
      assert result =~ "local function variadic(...)"
      assert result =~ "return ..."
    end

    test "prints local function with mixed parameters and vararg" do
      ast =
        chunk([
          local_func("mixed", ["a", "b"], [return_stmt([vararg()])], vararg: true)
        ])

      result = PrettyPrinter.print(ast)
      assert result =~ "local function mixed(a, b, ...)"
    end
  end

  describe "function declaration with vararg" do
    test "prints function declaration with vararg parameter" do
      ast =
        chunk([
          func_decl("variadic", [], [return_stmt([vararg()])], vararg: true)
        ])

      result = PrettyPrinter.print(ast)
      assert result =~ "function variadic(...)"
      assert result =~ "return ..."
    end
  end

  describe "precedence edge cases" do
    test "handles division and multiplication precedence" do
      # 10 / 2 * 3 (same precedence, left-to-right)
      ast =
        chunk([
          return_stmt([
            binop(:mul, binop(:div, number(10), number(2)), number(3))
          ])
        ])

      assert PrettyPrinter.print(ast) == "return 10 / 2 * 3\n"
    end

    test "handles comparison operators with arithmetic" do
      # 2 + 3 < 10 (arithmetic has higher precedence)
      ast =
        chunk([
          return_stmt([
            binop(:lt, binop(:add, number(2), number(3)), number(10))
          ])
        ])

      assert PrettyPrinter.print(ast) == "return 2 + 3 < 10\n"
    end

    test "handles logical operators with comparisons" do
      # x < 10 and y > 5 (comparison has higher precedence)
      ast =
        chunk([
          return_stmt([
            binop(:and, binop(:lt, var("x"), number(10)), binop(:gt, var("y"), number(5)))
          ])
        ])

      assert PrettyPrinter.print(ast) == "return x < 10 and y > 5\n"
    end

    test "handles or with and (and has higher precedence)" do
      # a or b and c should print as a or (b and c)
      ast =
        chunk([
          return_stmt([
            binop(:or, var("a"), binop(:and, var("b"), var("c")))
          ])
        ])

      assert PrettyPrinter.print(ast) == "return a or b and c\n"
    end
  end

  describe "all comparison operators in precedence" do
    test "handles all comparison operators" do
      # Test that all comparison operators work together
      ast = chunk([return_stmt([binop(:lt, var("a"), var("b"))])])
      assert PrettyPrinter.print(ast) == "return a < b\n"

      ast = chunk([return_stmt([binop(:gt, var("a"), var("b"))])])
      assert PrettyPrinter.print(ast) == "return a > b\n"
    end
  end

  describe "nested structures" do
    test "handles deeply nested expressions" do
      # ((a + b) * c) / d
      ast =
        chunk([
          return_stmt([
            binop(
              :div,
              binop(:mul, binop(:add, var("a"), var("b")), var("c")),
              var("d")
            )
          ])
        ])

      result = PrettyPrinter.print(ast)
      assert result == "return (a + b) * c / d\n"
    end

    test "handles nested table access" do
      # a[b[c]]
      ast =
        chunk([
          return_stmt([
            index(var("a"), index(var("b"), var("c")))
          ])
        ])

      assert PrettyPrinter.print(ast) == "return a[b[c]]\n"
    end

    test "handles chained method calls" do
      # obj:method1():method2()
      ast =
        chunk([
          return_stmt([
            method_call(
              method_call(var("obj"), "method1", []),
              "method2",
              []
            )
          ])
        ])

      assert PrettyPrinter.print(ast) == "return obj:method1():method2()\n"
    end
  end

  describe "function name variations" do
    test "prints simple function name" do
      ast = chunk([func_decl("simple", [], [return_stmt([])])])
      result = PrettyPrinter.print(ast)
      assert result =~ "function simple()"
    end
  end

  describe "number formatting edge cases" do
    test "formats integer as integer" do
      ast = chunk([return_stmt([number(42)])])
      assert PrettyPrinter.print(ast) == "return 42\n"
    end

    test "formats float that equals integer with .0" do
      ast = chunk([return_stmt([number(5.0)])])
      assert PrettyPrinter.print(ast) == "return 5.0\n"
    end

    test "formats regular float normally" do
      ast = chunk([return_stmt([number(3.14159)])])
      assert PrettyPrinter.print(ast) == "return 3.14159\n"
    end

    test "formats negative numbers" do
      ast = chunk([return_stmt([number(-42)])])
      assert PrettyPrinter.print(ast) == "return -42\n"
    end
  end

  describe "precedence with various operators" do
    test "comparison operators with power" do
      # x < y ^ 2 (power has higher precedence)
      ast =
        chunk([
          return_stmt([
            binop(:lt, var("x"), binop(:pow, var("y"), number(2)))
          ])
        ])

      assert PrettyPrinter.print(ast) == "return x < y ^ 2\n"
    end

    test "equality with arithmetic" do
      # x == y + 1 (arithmetic has higher precedence)
      ast =
        chunk([
          return_stmt([
            binop(:eq, var("x"), binop(:add, var("y"), number(1)))
          ])
        ])

      assert PrettyPrinter.print(ast) == "return x == y + 1\n"
    end

    test "inequality with arithmetic" do
      # x ~= y * 2 (arithmetic has higher precedence)
      ast =
        chunk([
          return_stmt([
            binop(:ne, var("x"), binop(:mul, var("y"), number(2)))
          ])
        ])

      assert PrettyPrinter.print(ast) == "return x ~= y * 2\n"
    end

    test "less than or equal with addition" do
      # x <= y + 5 (arithmetic has higher precedence)
      ast =
        chunk([
          return_stmt([
            binop(:le, var("x"), binop(:add, var("y"), number(5)))
          ])
        ])

      assert PrettyPrinter.print(ast) == "return x <= y + 5\n"
    end

    test "greater than or equal with subtraction" do
      # x >= y - 5 (arithmetic has higher precedence)
      ast =
        chunk([
          return_stmt([
            binop(:ge, var("x"), binop(:sub, var("y"), number(5)))
          ])
        ])

      assert PrettyPrinter.print(ast) == "return x >= y - 5\n"
    end

    test "concat with comparison" do
      # (x < y) .. "test" (comparison has lower precedence than concat)
      ast =
        chunk([
          return_stmt([
            binop(:concat, binop(:lt, var("x"), var("y")), string("test"))
          ])
        ])

      assert PrettyPrinter.print(ast) == "return (x < y) .. \"test\"\n"
    end

    test "floor division with addition" do
      # (x + y) // 2 (addition has lower precedence)
      ast =
        chunk([
          return_stmt([
            binop(:floor_div, binop(:add, var("x"), var("y")), number(2))
          ])
        ])

      assert PrettyPrinter.print(ast) == "return (x + y) // 2\n"
    end

    test "modulo with addition" do
      # (x + y) % 10 (addition has lower precedence)
      ast =
        chunk([
          return_stmt([
            binop(:mod, binop(:add, var("x"), var("y")), number(10))
          ])
        ])

      assert PrettyPrinter.print(ast) == "return (x + y) % 10\n"
    end

    test "power with unary operator" do
      # 2 ^ (-x) (unary needs parens with power parent)
      ast =
        chunk([
          return_stmt([
            binop(:pow, number(2), unop(:neg, var("x")))
          ])
        ])

      assert PrettyPrinter.print(ast) == "return 2 ^ (-x)\n"
    end

    test "unary not with non-power operator" do
      # not x and y (unary has higher precedence than and)
      ast =
        chunk([
          return_stmt([
            binop(:and, unop(:not, var("x")), var("y"))
          ])
        ])

      assert PrettyPrinter.print(ast) == "return not x and y\n"
    end
  end

  describe "binary name for function declaration" do
    test "prints function with string name" do
      # When passing a simple string name (not a list)
      alias Lua.AST.Statement

      ast = %Lua.AST.Chunk{
        block: %Lua.AST.Block{
          stmts: [
            %Statement.FuncDecl{
              name: "simple",
              params: [],
              body: %Lua.AST.Block{stmts: [return_stmt([])]},
              is_method: false,
              meta: nil
            }
          ]
        }
      }

      result = PrettyPrinter.print(ast)
      assert result =~ "function simple()"
    end
  end

  describe "unknown operators (defensive default cases)" do
    test "handles unknown binary operator" do
      # Create a BinOp with an invalid operator to test the default case
      alias Lua.AST.Expr

      ast = %Lua.AST.Chunk{
        block: %Lua.AST.Block{
          stmts: [
            return_stmt([
              %Expr.BinOp{
                op: :unknown_op,
                left: number(1),
                right: number(2),
                meta: nil
              }
            ])
          ]
        }
      }

      result = PrettyPrinter.print(ast)
      assert result =~ ""
    end

    test "handles unknown unary operator" do
      # Create a UnOp with an invalid operator to test the default case
      alias Lua.AST.Expr

      ast = %Lua.AST.Chunk{
        block: %Lua.AST.Block{
          stmts: [
            return_stmt([
              %Expr.UnOp{
                op: :unknown_unop,
                operand: number(42),
                meta: nil
              }
            ])
          ]
        }
      }

      result = PrettyPrinter.print(ast)
      assert result =~ ""
    end
  end
end
diff --git a/test/lua/ast/walker_test.exs b/test/lua/ast/walker_test.exs
new file mode 100644
index 0000000..50a545f
--- /dev/null
+++ b/test/lua/ast/walker_test.exs
@@ -0,0 +1,1264 @@
defmodule Lua.AST.WalkerTest do
  use ExUnit.Case, async: true

  import Lua.AST.Builder
  alias Lua.AST.{Walker, Expr, Statement}

  describe "walk/2" do
    test "visits all nodes in pre-order" do
      # Build: local x = 2 + 3
      ast =
        chunk([
          local(["x"], [binop(:add, number(2), number(3))])
        ])

      ref = :erlang.make_ref()

      Walker.walk(ast, fn node ->
        send(self(), {ref, node})
      end)

      # Collect all visited nodes
      visited = collect_messages(ref, [])

      # Should visit in pre-order: Chunk,
Block, Local, BinOp, Number(2), Number(3) + assert length(visited) == 6 + assert hd(visited).__struct__ == Lua.AST.Chunk + end + + test "visits all nodes in post-order" do + # Build: local x = 2 + 3 + ast = + chunk([ + local(["x"], [binop(:add, number(2), number(3))]) + ]) + + ref = :erlang.make_ref() + + Walker.walk( + ast, + fn node -> + send(self(), {ref, node}) + end, + order: :post + ) + + visited = collect_messages(ref, []) + + # Should visit in post-order: children before parents + # Last visited should be Chunk + assert length(visited) == 6 + assert List.last(visited).__struct__ == Lua.AST.Chunk + end + + test "walks through if statement with all branches" do + # if x > 0 then return 1 elseif x < 0 then return -1 else return 0 end + ast = + chunk([ + if_stmt( + binop(:gt, var("x"), number(0)), + [return_stmt([number(1)])], + elseif: [{binop(:lt, var("x"), number(0)), [return_stmt([unop(:neg, number(1))])]}], + else: [return_stmt([number(0)])] + ) + ]) + + count = count_nodes(ast) + # Chunk + Block + If + 3 conditions + 3 blocks + 3 return stmts + 3 values = many nodes + assert count > 10 + end + + test "walks through function expressions" do + # local f = function(a, b) return a + b end + ast = + chunk([ + local(["f"], [ + function_expr(["a", "b"], [ + return_stmt([binop(:add, var("a"), var("b"))]) + ]) + ]) + ]) + + # Count variable references + var_count = + Walker.reduce(ast, 0, fn + %Expr.Var{}, acc -> acc + 1 + _, acc -> acc + end) + + # a and b + assert var_count == 2 + end + end + + describe "map/2" do + test "transforms number literals" do + # local x = 2 + 3 + ast = + chunk([ + local(["x"], [binop(:add, number(2), number(3))]) + ]) + + # Double all numbers + transformed = + Walker.map(ast, fn + %Expr.Number{value: n} = node -> %{node | value: n * 2} + node -> node + end) + + # Extract the numbers + numbers = + Walker.reduce(transformed, [], fn + %Expr.Number{value: n}, acc -> [n | acc] + _, acc -> acc + end) + + # 2*2=4, 3*2=6 + assert 
Enum.sort(numbers) == [4, 6] + end + + test "transforms variable names" do + # x = y + z + ast = + chunk([ + assign([var("x")], [binop(:add, var("y"), var("z"))]) + ]) + + # Add prefix to all variables + transformed = + Walker.map(ast, fn + %Expr.Var{name: name} = node -> %{node | name: "local_" <> name} + node -> node + end) + + # Collect variable names + names = + Walker.reduce(transformed, [], fn + %Expr.Var{name: name}, acc -> [name | acc] + _, acc -> acc + end) + + assert Enum.sort(names) == ["local_x", "local_y", "local_z"] + end + + test "preserves structure while transforming" do + # if true then print(1) end + ast = + chunk([ + if_stmt(bool(true), [ + call_stmt(call(var("print"), [number(1)])) + ]) + ]) + + # Transform should preserve structure + transformed = + Walker.map(ast, fn + %Expr.Number{value: n} = node -> %{node | value: n + 1} + node -> node + end) + + # Extract the if statement + [%Statement.If{condition: %Expr.Bool{value: true}}] = transformed.block.stmts + + # Number should be transformed + numbers = + Walker.reduce(transformed, [], fn + %Expr.Number{value: n}, acc -> [n | acc] + _, acc -> acc + end) + + # 1 + 1 = 2 + assert numbers == [2] + end + end + + describe "reduce/3" do + test "counts all nodes" do + # local x = 1; local y = 2; return x + y + ast = + chunk([ + local(["x"], [number(1)]), + local(["y"], [number(2)]), + return_stmt([binop(:add, var("x"), var("y"))]) + ]) + + count = Walker.reduce(ast, 0, fn _, acc -> acc + 1 end) + + # Should count all nodes + assert count > 5 + end + + test "collects specific node types" do + # local x = 1; y = 2; print(x, y) + ast = + chunk([ + local(["x"], [number(1)]), + assign([var("y")], [number(2)]), + call_stmt(call(var("print"), [var("x"), var("y")])) + ]) + + # Collect all variable names + vars = + Walker.reduce(ast, [], fn + %Expr.Var{name: name}, acc -> [name | acc] + _, acc -> acc + end) + + assert Enum.sort(vars) == ["print", "x", "y", "y"] + + # Collect all numbers + numbers = + 
Walker.reduce(ast, [], fn + %Expr.Number{value: n}, acc -> [n | acc] + _, acc -> acc + end) + + assert Enum.sort(numbers) == [1, 2] + end + + test "builds maps from nodes" do + # local x = 10; local y = 20 + ast = + chunk([ + local(["x"], [number(10)]), + local(["y"], [number(20)]) + ]) + + # Build map of local declarations: name -> value + locals = + Walker.reduce(ast, %{}, fn + %Statement.Local{names: [name], values: [%Expr.Number{value: n}]}, acc -> + Map.put(acc, name, n) + + _, acc -> + acc + end) + + assert locals == %{"x" => 10, "y" => 20} + end + + test "accumulates deeply nested values" do + # function f() return function() return 42 end end + ast = + chunk([ + func_decl("f", [], [ + return_stmt([ + function_expr([], [ + return_stmt([number(42)]) + ]) + ]) + ]) + ]) + + # Count function expressions + func_count = + Walker.reduce(ast, 0, fn + %Expr.Function{}, acc -> acc + 1 + _, acc -> acc + end) + + assert func_count == 1 + end + end + + describe "complex AST traversal" do + test "handles nested loops and conditions" do + # for i = 1, 10 do + # if i % 2 == 0 then + # print(i) + # end + # end + ast = + chunk([ + for_num("i", number(1), number(10), [ + if_stmt( + binop(:eq, binop(:mod, var("i"), number(2)), number(0)), + [call_stmt(call(var("print"), [var("i")]))] + ) + ]) + ]) + + # Count all operators + ops = + Walker.reduce(ast, [], fn + %Expr.BinOp{op: op}, acc -> [op | acc] + _, acc -> acc + end) + + assert :eq in ops + assert :mod in ops + end + + test "handles table constructors" do + # local t = {x = 1, y = 2, [3] = "three"} + ast = + chunk([ + local(["t"], [ + table([ + {:record, string("x"), number(1)}, + {:record, string("y"), number(2)}, + {:record, number(3), string("three")} + ]) + ]) + ]) + + # Count table fields + field_count = + Walker.reduce(ast, 0, fn + %Expr.Table{fields: fields}, acc -> acc + length(fields) + _, acc -> acc + end) + + assert field_count == 3 + end + end + + describe "expression nodes" do + test "walks MethodCall nodes" 
do + # obj:method(arg1, arg2) + ast = + chunk([ + call_stmt(method_call(var("obj"), "method", [var("arg1"), var("arg2")])) + ]) + + # Count all variables + var_count = + Walker.reduce(ast, 0, fn + %Expr.Var{}, acc -> acc + 1 + _, acc -> acc + end) + + # obj, arg1, arg2 + assert var_count == 3 + end + + test "maps MethodCall nodes" do + # file:read("*a") + ast = + chunk([ + call_stmt(method_call(var("file"), "read", [string("*a")])) + ]) + + # Transform method name + transformed = + Walker.map(ast, fn + %Expr.MethodCall{method: m} = node -> %{node | method: "new_" <> m} + node -> node + end) + + # Extract method call + method_calls = + Walker.reduce(transformed, [], fn + %Expr.MethodCall{method: m}, acc -> [m | acc] + _, acc -> acc + end) + + assert method_calls == ["new_read"] + end + + test "walks Index nodes" do + # t[key] + ast = chunk([assign([index(var("t"), var("key"))], [number(42)])]) + + # Count variables + vars = + Walker.reduce(ast, [], fn + %Expr.Var{name: name}, acc -> [name | acc] + _, acc -> acc + end) + + assert Enum.sort(vars) == ["key", "t"] + end + + test "maps Index nodes" do + # arr[1] = arr[2] + ast = + chunk([ + assign( + [index(var("arr"), number(1))], + [index(var("arr"), number(2))] + ) + ]) + + # Double all indices + transformed = + Walker.map(ast, fn + %Expr.Index{key: %Expr.Number{value: n} = key} = node -> + %{node | key: %{key | value: n * 2}} + + node -> + node + end) + + # Collect indices + indices = + Walker.reduce(transformed, [], fn + %Expr.Index{key: %Expr.Number{value: n}}, acc -> [n | acc] + _, acc -> acc + end) + + assert Enum.sort(indices) == [2, 4] + end + + test "walks Property nodes" do + # io.write + ast = chunk([call_stmt(call(property(var("io"), "write"), [string("test")]))]) + + # Count variables + var_count = + Walker.reduce(ast, 0, fn + %Expr.Var{}, acc -> acc + 1 + _, acc -> acc + end) + + # io + assert var_count == 1 + end + + test "maps Property nodes" do + # math.pi + ast = chunk([assign([var("x")], 
[property(var("math"), "pi")])]) + + # Transform property field + transformed = + Walker.map(ast, fn + %Expr.Property{field: f} = node -> %{node | field: String.upcase(f)} + node -> node + end) + + # Extract property field + fields = + Walker.reduce(transformed, [], fn + %Expr.Property{field: f}, acc -> [f | acc] + _, acc -> acc + end) + + assert fields == ["PI"] + end + + test "walks String nodes" do + # local s = "hello" + ast = chunk([local(["s"], [string("hello")])]) + + # Collect strings + strings = + Walker.reduce(ast, [], fn + %Expr.String{value: s}, acc -> [s | acc] + _, acc -> acc + end) + + assert strings == ["hello"] + end + + test "maps String nodes" do + # print("hello", "world") + ast = chunk([call_stmt(call(var("print"), [string("hello"), string("world")]))]) + + # Uppercase all strings + transformed = + Walker.map(ast, fn + %Expr.String{value: s} = node -> %{node | value: String.upcase(s)} + node -> node + end) + + strings = + Walker.reduce(transformed, [], fn + %Expr.String{value: s}, acc -> [s | acc] + _, acc -> acc + end) + + assert Enum.sort(strings) == ["HELLO", "WORLD"] + end + + test "walks Nil nodes" do + # local x = nil + ast = chunk([local(["x"], [nil_lit()])]) + + # Count nil literals + nil_count = + Walker.reduce(ast, 0, fn + %Expr.Nil{}, acc -> acc + 1 + _, acc -> acc + end) + + assert nil_count == 1 + end + + test "walks Vararg nodes" do + # function(...) return ... end + ast = chunk([func_decl("f", [], [return_stmt([vararg()])], vararg: true)]) + + # Count vararg expressions + vararg_count = + Walker.reduce(ast, 0, fn + %Expr.Vararg{}, acc -> acc + 1 + _, acc -> acc + end) + + assert vararg_count == 1 + end + + test "maps Vararg nodes in function" do + # local f = function(...) return ... 
end + ast = chunk([local(["f"], [function_expr([], [return_stmt([vararg()])], vararg: true)])]) + + # Count nodes before and after map + count_before = count_nodes(ast) + + transformed = + Walker.map(ast, fn + node -> node + end) + + count_after = count_nodes(transformed) + + # Structure should be preserved + assert count_before == count_after + end + end + + describe "statement nodes" do + test "walks Local without values" do + # local x, y + ast = chunk([local(["x", "y"], [])]) + + # Count local statements + local_count = + Walker.reduce(ast, 0, fn + %Statement.Local{}, acc -> acc + 1 + _, acc -> acc + end) + + assert local_count == 1 + + # Should have no child expressions + expr_count = + Walker.reduce(ast, 0, fn + %Expr.Number{}, acc -> acc + 1 + _, acc -> acc + end) + + assert expr_count == 0 + end + + test "maps Local without values" do + # local x + ast = chunk([local(["x"], [])]) + + # Transform should preserve empty values list + transformed = + Walker.map(ast, fn + %Statement.Local{values: []} = node -> node + node -> node + end) + + # Extract local statement + locals = + Walker.reduce(transformed, [], fn + %Statement.Local{names: names, values: values}, acc -> [{names, values} | acc] + _, acc -> acc + end) + + assert locals == [{["x"], []}] + end + + test "walks LocalFunc nodes" do + # local function add(a, b) return a + b end + ast = + chunk([local_func("add", ["a", "b"], [return_stmt([binop(:add, var("a"), var("b"))])])]) + + # Count variables + var_count = + Walker.reduce(ast, 0, fn + %Expr.Var{}, acc -> acc + 1 + _, acc -> acc + end) + + # a, b (in return statement) + assert var_count == 2 + end + + test "maps LocalFunc nodes" do + # local function f() return 1 end + ast = chunk([local_func("f", [], [return_stmt([number(1)])])]) + + # Double numbers + transformed = + Walker.map(ast, fn + %Expr.Number{value: n} = node -> %{node | value: n * 2} + node -> node + end) + + numbers = + Walker.reduce(transformed, [], fn + %Expr.Number{value: n}, acc -> [n | 
acc] + _, acc -> acc + end) + + assert numbers == [2] + end + + test "walks While nodes" do + # while x > 0 do x = x - 1 end + ast = + chunk([ + while_stmt( + binop(:gt, var("x"), number(0)), + [assign([var("x")], [binop(:sub, var("x"), number(1))])] + ) + ]) + + # Count variables + var_count = + Walker.reduce(ast, 0, fn + %Expr.Var{name: "x"}, acc -> acc + 1 + _, acc -> acc + end) + + # x appears 3 times: condition, target, value + assert var_count == 3 + end + + test "maps While nodes" do + # while true do print(1) end + ast = chunk([while_stmt(bool(true), [call_stmt(call(var("print"), [number(1)]))])]) + + # Transform numbers + transformed = + Walker.map(ast, fn + %Expr.Number{value: n} = node -> %{node | value: n + 10} + node -> node + end) + + numbers = + Walker.reduce(transformed, [], fn + %Expr.Number{value: n}, acc -> [n | acc] + _, acc -> acc + end) + + assert numbers == [11] + end + + test "walks Repeat nodes" do + # repeat x = x - 1 until x <= 0 + ast = + chunk([ + repeat_stmt( + [assign([var("x")], [binop(:sub, var("x"), number(1))])], + binop(:le, var("x"), number(0)) + ) + ]) + + # Count variables + var_count = + Walker.reduce(ast, 0, fn + %Expr.Var{name: "x"}, acc -> acc + 1 + _, acc -> acc + end) + + # x appears 3 times: target, value, condition + assert var_count == 3 + end + + test "maps Repeat nodes" do + # repeat print(5) until true + ast = chunk([repeat_stmt([call_stmt(call(var("print"), [number(5)]))], bool(true))]) + + # Transform numbers + transformed = + Walker.map(ast, fn + %Expr.Number{value: n} = node -> %{node | value: n * 2} + node -> node + end) + + numbers = + Walker.reduce(transformed, [], fn + %Expr.Number{value: n}, acc -> [n | acc] + _, acc -> acc + end) + + assert numbers == [10] + end + + test "walks ForNum with nil step" do + # for i = 1, 10 do print(i) end + ast = + chunk([for_num("i", number(1), number(10), [call_stmt(call(var("print"), [var("i")]))])]) + + # Verify step is nil + step_is_nil = + Walker.reduce(ast, false, fn 
+ %Statement.ForNum{step: nil}, _acc -> true + _, acc -> acc + end) + + assert step_is_nil + end + + test "walks ForNum with explicit step" do + # for i = 1, 10, 2 do print(i) end + ast = + chunk([ + for_num("i", number(1), number(10), [call_stmt(call(var("print"), [var("i")]))], + step: number(2) + ) + ]) + + # Count numbers (start, limit, step) + numbers = + Walker.reduce(ast, [], fn + %Expr.Number{value: n}, acc -> [n | acc] + _, acc -> acc + end) + + # 1, 10, 2 from the loop header (i is a var in the body) + assert Enum.sort(numbers) == [1, 2, 10] + end + + test "maps ForNum with nil step" do + # for i = 1, 10 do end + ast = chunk([for_num("i", number(1), number(10), [])]) + + # Transform numbers + transformed = + Walker.map(ast, fn + %Expr.Number{value: n} = node -> %{node | value: n + 5} + node -> node + end) + + numbers = + Walker.reduce(transformed, [], fn + %Expr.Number{value: n}, acc -> [n | acc] + _, acc -> acc + end) + + assert Enum.sort(numbers) == [6, 15] + end + + test "maps ForNum with explicit step" do + # for i = 2, 20, 3 do end + ast = chunk([for_num("i", number(2), number(20), [], step: number(3))]) + + # Transform numbers + transformed = + Walker.map(ast, fn + %Expr.Number{value: n} = node -> %{node | value: n * 10} + node -> node + end) + + numbers = + Walker.reduce(transformed, [], fn + %Expr.Number{value: n}, acc -> [n | acc] + _, acc -> acc + end) + + assert Enum.sort(numbers) == [20, 30, 200] + end + + test "walks ForIn nodes" do + # for k, v in pairs(t) do print(k, v) end + ast = + chunk([ + for_in( + ["k", "v"], + [call(var("pairs"), [var("t")])], + [call_stmt(call(var("print"), [var("k"), var("v")]))] + ) + ]) + + # Count variables + vars = + Walker.reduce(ast, [], fn + %Expr.Var{name: name}, acc -> [name | acc] + _, acc -> acc + end) + + # pairs, t, print, k, v + assert Enum.sort(vars) == ["k", "pairs", "print", "t", "v"] + end + + test "maps ForIn nodes" do + # for x in iter() do print(x) end + ast = + chunk([ + for_in(["x"], 
[call(var("iter"), [])], [call_stmt(call(var("print"), [var("x")]))]) + ]) + + # Transform variable names + transformed = + Walker.map(ast, fn + %Expr.Var{name: name} = node -> %{node | name: "new_" <> name} + node -> node + end) + + vars = + Walker.reduce(transformed, [], fn + %Expr.Var{name: name}, acc -> [name | acc] + _, acc -> acc + end) + + assert Enum.sort(vars) == ["new_iter", "new_print", "new_x"] + end + + test "walks Do nodes" do + # do local x = 1; print(x) end + ast = + chunk([ + do_block([ + local(["x"], [number(1)]), + call_stmt(call(var("print"), [var("x")])) + ]) + ]) + + # Count do statements + do_count = + Walker.reduce(ast, 0, fn + %Statement.Do{}, acc -> acc + 1 + _, acc -> acc + end) + + assert do_count == 1 + end + + test "maps Do nodes" do + # do print(5) end + ast = chunk([do_block([call_stmt(call(var("print"), [number(5)]))])]) + + # Transform numbers + transformed = + Walker.map(ast, fn + %Expr.Number{value: n} = node -> %{node | value: n * 3} + node -> node + end) + + numbers = + Walker.reduce(transformed, [], fn + %Expr.Number{value: n}, acc -> [n | acc] + _, acc -> acc + end) + + assert numbers == [15] + end + + test "walks Break nodes" do + # while true do break end + ast = chunk([while_stmt(bool(true), [break_stmt()])]) + + # Count break statements + break_count = + Walker.reduce(ast, 0, fn + %Statement.Break{}, acc -> acc + 1 + _, acc -> acc + end) + + assert break_count == 1 + end + + test "maps Break nodes (leaf node)" do + # while true do break end + ast = chunk([while_stmt(bool(true), [break_stmt()])]) + + # Map should preserve break + transformed = + Walker.map(ast, fn + node -> node + end) + + break_count = + Walker.reduce(transformed, 0, fn + %Statement.Break{}, acc -> acc + 1 + _, acc -> acc + end) + + assert break_count == 1 + end + + test "walks Goto nodes" do + # goto skip + ast = chunk([goto_stmt("skip")]) + + # Count goto statements + goto_labels = + Walker.reduce(ast, [], fn + %Statement.Goto{label: label}, acc -> 
[label | acc] + _, acc -> acc + end) + + assert goto_labels == ["skip"] + end + + test "maps Goto nodes (leaf node)" do + # goto target + ast = chunk([goto_stmt("target")]) + + # Map should preserve goto + transformed = + Walker.map(ast, fn + node -> node + end) + + labels = + Walker.reduce(transformed, [], fn + %Statement.Goto{label: label}, acc -> [label | acc] + _, acc -> acc + end) + + assert labels == ["target"] + end + + test "walks Label nodes" do + # ::start:: + ast = chunk([label("start")]) + + # Count labels + labels = + Walker.reduce(ast, [], fn + %Statement.Label{name: name}, acc -> [name | acc] + _, acc -> acc + end) + + assert labels == ["start"] + end + + test "maps Label nodes (leaf node)" do + # ::loop:: + ast = chunk([label("loop")]) + + # Map should preserve label + transformed = + Walker.map(ast, fn + node -> node + end) + + labels = + Walker.reduce(transformed, [], fn + %Statement.Label{name: name}, acc -> [name | acc] + _, acc -> acc + end) + + assert labels == ["loop"] + end + + test "walks FuncDecl nodes" do + # function add(a, b) return a + b end + ast = + chunk([func_decl("add", ["a", "b"], [return_stmt([binop(:add, var("a"), var("b"))])])]) + + # Count variables + var_count = + Walker.reduce(ast, 0, fn + %Expr.Var{}, acc -> acc + 1 + _, acc -> acc + end) + + # a, b (in return statement) + assert var_count == 2 + end + + test "maps FuncDecl nodes" do + # function f() return 1 end + ast = chunk([func_decl("f", [], [return_stmt([number(1)])])]) + + # Double numbers + transformed = + Walker.map(ast, fn + %Expr.Number{value: n} = node -> %{node | value: n * 2} + node -> node + end) + + numbers = + Walker.reduce(transformed, [], fn + %Expr.Number{value: n}, acc -> [n | acc] + _, acc -> acc + end) + + assert numbers == [2] + end + end + + describe "edge cases" do + test "If with no elseifs or else" do + # if x then print(x) end + ast = chunk([if_stmt(var("x"), [call_stmt(call(var("print"), [var("x")]))])]) + + # Verify structure + if_stmts = + 
Walker.reduce(ast, [], fn + %Statement.If{elseifs: elseifs, else_block: else_block}, acc -> + [{elseifs, else_block} | acc] + + _, acc -> + acc + end) + + assert if_stmts == [{[], nil}] + end + + test "If with elseif clauses mapping" do + # if x > 0 then return 1 elseif x < 0 then return -1 end + ast = + chunk([ + if_stmt( + binop(:gt, var("x"), number(0)), + [return_stmt([number(1)])], + elseif: [{binop(:lt, var("x"), number(0)), [return_stmt([number(-1)])]}] + ) + ]) + + # Map should traverse elseif conditions and blocks + transformed = + Walker.map(ast, fn + %Expr.Number{value: n} = node -> %{node | value: n * 10} + node -> node + end) + + numbers = + Walker.reduce(transformed, [], fn + %Expr.Number{value: n}, acc -> [n | acc] + _, acc -> acc + end) + + # Should have transformed: 0, 1, 0, -1 -> 0, 10, 0, -10 + assert Enum.sort(numbers) == [-10, 0, 0, 10] + end + + test "UnOp expressions mapping" do + # local x = -5 + ast = chunk([local(["x"], [unop(:neg, number(5))])]) + + # Map should traverse unary operations + transformed = + Walker.map(ast, fn + %Expr.Number{value: n} = node -> %{node | value: n * 2} + node -> node + end) + + numbers = + Walker.reduce(transformed, [], fn + %Expr.Number{value: n}, acc -> [n | acc] + _, acc -> acc + end) + + assert numbers == [10] + end + + test "Local without values mapping" do + # local x + ast = chunk([local(["x"], nil)]) + + # Map should handle Local without values + transformed = + Walker.map(ast, fn + node -> node + end) + + locals = + Walker.reduce(transformed, [], fn + %Statement.Local{names: names, values: values}, acc -> [{names, values} | acc] + _, acc -> acc + end) + + assert locals == [{["x"], nil}] + end + + test "Table with list fields" do + # {1, 2, 3} + ast = + chunk([ + local(["t"], [ + table([ + {:list, number(1)}, + {:list, number(2)}, + {:list, number(3)} + ]) + ]) + ]) + + # Count numbers + numbers = + Walker.reduce(ast, [], fn + %Expr.Number{value: n}, acc -> [n | acc] + _, acc -> acc + end) + + assert 
Enum.sort(numbers) == [1, 2, 3] + end + + test "Table with mixed list and record fields" do + # {10, 20, x = 30} + ast = + chunk([ + local(["t"], [ + table([ + {:list, number(10)}, + {:list, number(20)}, + {:record, string("x"), number(30)} + ]) + ]) + ]) + + # Map should handle both field types + transformed = + Walker.map(ast, fn + %Expr.Number{value: n} = node -> %{node | value: n + 1} + node -> node + end) + + numbers = + Walker.reduce(transformed, [], fn + %Expr.Number{value: n}, acc -> [n | acc] + _, acc -> acc + end) + + assert Enum.sort(numbers) == [11, 21, 31] + end + + test "Empty table" do + # local t = {} + ast = chunk([local(["t"], [table([])])]) + + # Count table nodes + table_count = + Walker.reduce(ast, 0, fn + %Expr.Table{}, acc -> acc + 1 + _, acc -> acc + end) + + assert table_count == 1 + + # Should have no field children + field_count = + Walker.reduce(ast, 0, fn + %Expr.Table{fields: fields}, acc -> acc + length(fields) + _, acc -> acc + end) + + assert field_count == 0 + end + + test "Nested method calls" do + # obj:method1():method2() + ast = + chunk([ + call_stmt( + method_call( + method_call(var("obj"), "method1", []), + "method2", + [] + ) + ) + ]) + + # Count method calls + method_count = + Walker.reduce(ast, 0, fn + %Expr.MethodCall{}, acc -> acc + 1 + _, acc -> acc + end) + + assert method_count == 2 + end + + test "Complex nested indexing" do + # t[a][b][c] + ast = + chunk([ + assign( + [index(index(index(var("t"), var("a")), var("b")), var("c"))], + [number(1)] + ) + ]) + + # Count index operations + index_count = + Walker.reduce(ast, 0, fn + %Expr.Index{}, acc -> acc + 1 + _, acc -> acc + end) + + assert index_count == 3 + end + + test "Multiple return values" do + # return a, b, c + ast = chunk([return_stmt([var("a"), var("b"), var("c")])]) + + # Count variables + var_count = + Walker.reduce(ast, 0, fn + %Expr.Var{}, acc -> acc + 1 + _, acc -> acc + end) + + assert var_count == 3 + end + + test "CallStmt with MethodCall" do + # 
obj:method() + ast = chunk([call_stmt(method_call(var("obj"), "method", []))]) + + # Should walk through CallStmt to MethodCall + call_stmt_count = + Walker.reduce(ast, 0, fn + %Statement.CallStmt{}, acc -> acc + 1 + _, acc -> acc + end) + + method_call_count = + Walker.reduce(ast, 0, fn + %Expr.MethodCall{}, acc -> acc + 1 + _, acc -> acc + end) + + assert call_stmt_count == 1 + assert method_call_count == 1 + end + + test "Deeply nested expressions" do + # ((a + b) * (c + d)) / ((e - f) * (g - h)) + ast = + chunk([ + local(["x"], [ + binop( + :div, + binop(:mul, binop(:add, var("a"), var("b")), binop(:add, var("c"), var("d"))), + binop(:mul, binop(:sub, var("e"), var("f")), binop(:sub, var("g"), var("h"))) + ) + ]) + ]) + + # Count binary operations + binop_count = + Walker.reduce(ast, 0, fn + %Expr.BinOp{}, acc -> acc + 1 + _, acc -> acc + end) + + # 7 binary operations total + assert binop_count == 7 + + # Count variables + vars = + Walker.reduce(ast, [], fn + %Expr.Var{name: name}, acc -> [name | acc] + _, acc -> acc + end) + + assert Enum.sort(vars) == ["a", "b", "c", "d", "e", "f", "g", "h"] + end + end + + # Helper to count nodes + defp count_nodes(ast) do + Walker.reduce(ast, 0, fn _, acc -> acc + 1 end) + end + + # Helper to collect messages + defp collect_messages(ref, acc) do + receive do + {^ref, node} -> collect_messages(ref, [node | acc]) + after + 0 -> Enum.reverse(acc) + end + end +end diff --git a/test/lua/lexer_test.exs b/test/lua/lexer_test.exs new file mode 100644 index 0000000..93b8739 --- /dev/null +++ b/test/lua/lexer_test.exs @@ -0,0 +1,696 @@ +defmodule Lua.LexerTest do + use ExUnit.Case, async: true + alias Lua.Lexer + + doctest Lua.Lexer + + describe "keywords" do + test "tokenizes all Lua keywords" do + keywords = [ + :and, + :break, + :do, + :else, + :elseif, + :end, + false, + :for, + :function, + :goto, + :if, + :in, + :local, + nil, + :not, + :or, + :repeat, + :return, + :then, + true, + :until, + :while + ] + + for keyword <- 
keywords do + keyword_str = Atom.to_string(keyword) + assert {:ok, tokens} = Lexer.tokenize(keyword_str) + assert [{:keyword, ^keyword, _}, {:eof, _}] = tokens + end + end + + test "keywords are case-sensitive" do + assert {:ok, [{:identifier, "IF", _}, {:eof, _}]} = Lexer.tokenize("IF") + assert {:ok, [{:identifier, "End", _}, {:eof, _}]} = Lexer.tokenize("End") + end + end + + describe "identifiers" do + test "tokenizes simple identifiers" do + assert {:ok, [{:identifier, "foo", _}, {:eof, _}]} = Lexer.tokenize("foo") + assert {:ok, [{:identifier, "bar123", _}, {:eof, _}]} = Lexer.tokenize("bar123") + assert {:ok, [{:identifier, "_test", _}, {:eof, _}]} = Lexer.tokenize("_test") + assert {:ok, [{:identifier, "CamelCase", _}, {:eof, _}]} = Lexer.tokenize("CamelCase") + end + + test "identifiers can start with underscore" do + assert {:ok, [{:identifier, "_", _}, {:eof, _}]} = Lexer.tokenize("_") + assert {:ok, [{:identifier, "__private", _}, {:eof, _}]} = Lexer.tokenize("__private") + end + + test "identifiers can contain numbers but not start with them" do + assert {:ok, [{:identifier, "var1", _}, {:eof, _}]} = Lexer.tokenize("var1") + assert {:ok, [{:identifier, "test123abc", _}, {:eof, _}]} = Lexer.tokenize("test123abc") + end + end + + describe "numbers" do + test "tokenizes integers" do + assert {:ok, [{:number, 0, _}, {:eof, _}]} = Lexer.tokenize("0") + assert {:ok, [{:number, 42, _}, {:eof, _}]} = Lexer.tokenize("42") + assert {:ok, [{:number, 12345, _}, {:eof, _}]} = Lexer.tokenize("12345") + end + + test "tokenizes floating point numbers" do + assert {:ok, [{:number, 3.14, _}, {:eof, _}]} = Lexer.tokenize("3.14") + assert {:ok, [{:number, 0.5, _}, {:eof, _}]} = Lexer.tokenize("0.5") + assert {:ok, [{:number, 10.0, _}, {:eof, _}]} = Lexer.tokenize("10.0") + end + + test "tokenizes hexadecimal numbers" do + assert {:ok, [{:number, 255, _}, {:eof, _}]} = Lexer.tokenize("0xFF") + assert {:ok, [{:number, 255, _}, {:eof, _}]} = Lexer.tokenize("0xff") + assert 
{:ok, [{:number, 0, _}, {:eof, _}]} = Lexer.tokenize("0x0") + assert {:ok, [{:number, 4095, _}, {:eof, _}]} = Lexer.tokenize("0xfff") + end + + test "tokenizes scientific notation" do + assert {:ok, [{:number, num, _}, {:eof, _}]} = Lexer.tokenize("1e10") + assert num == 1.0e10 + + assert {:ok, [{:number, num, _}, {:eof, _}]} = Lexer.tokenize("1.5e-5") + assert num == 1.5e-5 + + assert {:ok, [{:number, num, _}, {:eof, _}]} = Lexer.tokenize("3E+2") + assert num == 3.0e2 + end + + test "tokenizes floats with scientific notation" do + # Float with exponent + assert {:ok, [{:number, num, _}, {:eof, _}]} = Lexer.tokenize("2.5e3") + assert num == 2.5e3 + + # Float with uppercase E + assert {:ok, [{:number, num, _}, {:eof, _}]} = Lexer.tokenize("1.0E10") + assert num == 1.0e10 + + # Integer with exponent (becomes float) + assert {:ok, [{:number, num, _}, {:eof, _}]} = Lexer.tokenize("5e2") + assert num == 5.0e2 + + # Exponent without sign + assert {:ok, [{:number, num, _}, {:eof, _}]} = Lexer.tokenize("1e5") + assert num == 1.0e5 + end + + test "handles edge cases in scientific notation" do + # Exponent without digits - should result in error + assert {:error, {:invalid_number, _}} = Lexer.tokenize("1e") + + # Exponent with sign but no digits - should result in error + assert {:error, {:invalid_number, _}} = Lexer.tokenize("1e+") + assert {:error, {:invalid_number, _}} = Lexer.tokenize("1e-") + + # Float with exponent without digits + assert {:error, {:invalid_number, _}} = Lexer.tokenize("1.5e") + end + + test "handles trailing dot correctly" do + # "42." should be tokenized as number 42 followed by dot operator + # But in Lua, "42." 
is actually a valid number (42.0)
+      # Let's test both interpretations
+      assert {:ok, tokens} = Lexer.tokenize("42.")
+      # This might be [{:number, 42.0}, {:eof}] or [{:number, 42}, {:delimiter, :dot}, {:eof}]
+      # depending on implementation
+      assert length(tokens) >= 2
+    end
+  end
+
+  describe "strings" do
+    test "tokenizes double-quoted strings" do
+      assert {:ok, [{:string, "hello", _}, {:eof, _}]} = Lexer.tokenize(~s("hello"))
+      assert {:ok, [{:string, "", _}, {:eof, _}]} = Lexer.tokenize(~s(""))
+      assert {:ok, [{:string, "hello world", _}, {:eof, _}]} = Lexer.tokenize(~s("hello world"))
+    end
+
+    test "tokenizes single-quoted strings" do
+      assert {:ok, [{:string, "hello", _}, {:eof, _}]} = Lexer.tokenize("'hello'")
+      assert {:ok, [{:string, "", _}, {:eof, _}]} = Lexer.tokenize("''")
+      assert {:ok, [{:string, "hello world", _}, {:eof, _}]} = Lexer.tokenize("'hello world'")
+    end
+
+    test "handles escape sequences in strings" do
+      assert {:ok, [{:string, "hello\nworld", _}, {:eof, _}]} =
+               Lexer.tokenize(~s("hello\\nworld"))
+
+      assert {:ok, [{:string, "tab\there", _}, {:eof, _}]} = Lexer.tokenize(~s("tab\\there"))
+
+      assert {:ok, [{:string, "quote\"here", _}, {:eof, _}]} =
+               Lexer.tokenize(~s("quote\\"here"))
+
+      assert {:ok, [{:string, "backslash\\here", _}, {:eof, _}]} =
+               Lexer.tokenize(~s("backslash\\\\here"))
+    end
+
+    test "handles all standard escape sequences" do
+      # Test \a (bell, ASCII 7)
+      assert {:ok, [{:string, <<7>>, _}, {:eof, _}]} = Lexer.tokenize(~s("\\a"))
+      # Test \b (backspace, ASCII 8)
+      assert {:ok, [{:string, <<8>>, _}, {:eof, _}]} = Lexer.tokenize(~s("\\b"))
+      # Test \f (form feed, ASCII 12)
+      assert {:ok, [{:string, <<12>>, _}, {:eof, _}]} = Lexer.tokenize(~s("\\f"))
+      # Test \r (carriage return, ASCII 13)
+      assert {:ok, [{:string, <<13>>, _}, {:eof, _}]} = Lexer.tokenize(~s("\\r"))
+      # Test \v (vertical tab, ASCII 11)
+      assert {:ok, [{:string, <<11>>, _}, {:eof, _}]} = Lexer.tokenize(~s("\\v"))
+      # Test \' (single quote)
+      assert {:ok, [{:string, "'", _}, {:eof, _}]} = Lexer.tokenize(~s("\\'"))
+      # 
Test \" (double quote) + assert {:ok, [{:string, "\"", _}, {:eof, _}]} = Lexer.tokenize(~s("\\"")) + # Test \\ (backslash) + assert {:ok, [{:string, "\\", _}, {:eof, _}]} = Lexer.tokenize(~s("\\\\")) + end + + test "tokenizes long strings with [[...]]" do + assert {:ok, [{:string, "hello", _}, {:eof, _}]} = Lexer.tokenize("[[hello]]") + assert {:ok, [{:string, "", _}, {:eof, _}]} = Lexer.tokenize("[[]]") + + assert {:ok, [{:string, "multi\nline", _}, {:eof, _}]} = + Lexer.tokenize("[[multi\nline]]") + end + + test "tokenizes long strings with equals signs [=[...]=]" do + assert {:ok, [{:string, "hello", _}, {:eof, _}]} = Lexer.tokenize("[=[hello]=]") + assert {:ok, [{:string, "test", _}, {:eof, _}]} = Lexer.tokenize("[==[test]==]") + assert {:ok, [{:string, "a]b", _}, {:eof, _}]} = Lexer.tokenize("[=[a]b]=]") + end + + test "long strings with false closing brackets" do + # ] not followed by the right number of = + assert {:ok, [{:string, " test ] more ", _}, {:eof, _}]} = + Lexer.tokenize("[=[ test ] more ]=]") + + # ] not followed by ] + assert {:ok, [{:string, " test ]= more ", _}, {:eof, _}]} = + Lexer.tokenize("[[ test ]= more ]]") + end + + test "reports error for unclosed string" do + assert {:error, {:unclosed_string, _}} = Lexer.tokenize(~s("hello)) + assert {:error, {:unclosed_string, _}} = Lexer.tokenize("'hello") + end + + test "reports error for unclosed long string" do + assert {:error, {:unclosed_long_string, _}} = Lexer.tokenize("[[hello") + assert {:error, {:unclosed_long_string, _}} = Lexer.tokenize("[=[test") + end + end + + describe "operators" do + test "tokenizes single-character operators" do + assert {:ok, [{:operator, :add, _}, {:eof, _}]} = Lexer.tokenize("+") + assert {:ok, [{:operator, :sub, _}, {:eof, _}]} = Lexer.tokenize("-") + assert {:ok, [{:operator, :mul, _}, {:eof, _}]} = Lexer.tokenize("*") + assert {:ok, [{:operator, :div, _}, {:eof, _}]} = Lexer.tokenize("/") + assert {:ok, [{:operator, :mod, _}, {:eof, _}]} = 
Lexer.tokenize("%") + assert {:ok, [{:operator, :pow, _}, {:eof, _}]} = Lexer.tokenize("^") + assert {:ok, [{:operator, :len, _}, {:eof, _}]} = Lexer.tokenize("#") + assert {:ok, [{:operator, :lt, _}, {:eof, _}]} = Lexer.tokenize("<") + assert {:ok, [{:operator, :gt, _}, {:eof, _}]} = Lexer.tokenize(">") + assert {:ok, [{:operator, :assign, _}, {:eof, _}]} = Lexer.tokenize("=") + end + + test "tokenizes two-character operators" do + assert {:ok, [{:operator, :eq, _}, {:eof, _}]} = Lexer.tokenize("==") + assert {:ok, [{:operator, :ne, _}, {:eof, _}]} = Lexer.tokenize("~=") + assert {:ok, [{:operator, :le, _}, {:eof, _}]} = Lexer.tokenize("<=") + assert {:ok, [{:operator, :ge, _}, {:eof, _}]} = Lexer.tokenize(">=") + assert {:ok, [{:operator, :concat, _}, {:eof, _}]} = Lexer.tokenize("..") + assert {:ok, [{:operator, :floordiv, _}, {:eof, _}]} = Lexer.tokenize("//") + end + + test "tokenizes three-character operators" do + assert {:ok, [{:operator, :vararg, _}, {:eof, _}]} = Lexer.tokenize("...") + end + + test "distinguishes between . and .." 
do + assert {:ok, [{:delimiter, :dot, _}, {:eof, _}]} = Lexer.tokenize(".") + assert {:ok, [{:operator, :concat, _}, {:eof, _}]} = Lexer.tokenize("..") + assert {:ok, [{:operator, :vararg, _}, {:eof, _}]} = Lexer.tokenize("...") + end + end + + describe "delimiters" do + test "tokenizes parentheses" do + assert {:ok, [{:delimiter, :lparen, _}, {:eof, _}]} = Lexer.tokenize("(") + assert {:ok, [{:delimiter, :rparen, _}, {:eof, _}]} = Lexer.tokenize(")") + end + + test "tokenizes braces" do + assert {:ok, [{:delimiter, :lbrace, _}, {:eof, _}]} = Lexer.tokenize("{") + assert {:ok, [{:delimiter, :rbrace, _}, {:eof, _}]} = Lexer.tokenize("}") + end + + test "tokenizes brackets" do + assert {:ok, [{:delimiter, :lbracket, _}, {:eof, _}]} = Lexer.tokenize("[") + end + + test "tokenizes other delimiters" do + assert {:ok, [{:delimiter, :semicolon, _}, {:eof, _}]} = Lexer.tokenize(";") + assert {:ok, [{:delimiter, :comma, _}, {:eof, _}]} = Lexer.tokenize(",") + assert {:ok, [{:delimiter, :colon, _}, {:eof, _}]} = Lexer.tokenize(":") + assert {:ok, [{:delimiter, :double_colon, _}, {:eof, _}]} = Lexer.tokenize("::") + end + end + + describe "comments" do + test "skips single-line comments" do + assert {:ok, [{:eof, _}]} = Lexer.tokenize("-- this is a comment") + + assert {:ok, [{:identifier, "x", _}, {:eof, _}]} = + Lexer.tokenize("x -- comment after code") + end + + test "skips multi-line comments" do + assert {:ok, [{:eof, _}]} = Lexer.tokenize("--[[ this is a\nmulti-line comment ]]") + + assert {:ok, [{:identifier, "x", _}, {:eof, _}]} = + Lexer.tokenize("x --[[ comment ]] ") + end + + test "skips multi-line comments with equals signs" do + assert {:ok, [{:eof, _}]} = Lexer.tokenize("--[=[ comment ]=]") + assert {:ok, [{:eof, _}]} = Lexer.tokenize("--[==[ comment ]==]") + end + + test "handles content with brackets in multi-line comments" do + # The first ]] closes the comment regardless of internal [[ + # So this closes at the first ]] and leaves " inside ]]" as code + 
assert {:ok, tokens} = Lexer.tokenize("--[[ comment ]]") + assert [{:eof, _}] = tokens + + # With nesting levels using =, you can include ]] in the comment + assert {:ok, tokens2} = Lexer.tokenize("--[=[ comment with ]] in it ]=]") + assert [{:eof, _}] = tokens2 + end + + test "reports error for unclosed multi-line comment" do + assert {:error, {:unclosed_comment, _}} = Lexer.tokenize("--[[ unclosed comment") + end + + test "handles false closing brackets in multi-line comments" do + # Test a ] that is not followed by the right number of = + assert {:ok, [{:eof, _}]} = Lexer.tokenize("--[=[ test ] more ]=]") + # Test a ] that is not followed by ] + assert {:ok, [{:eof, _}]} = Lexer.tokenize("--[[ test ]= more ]]") + end + + test "multi-line comment with newlines" do + code = "--[[ line 1\nline 2\nline 3 ]]" + assert {:ok, [{:eof, _}]} = Lexer.tokenize(code) + end + + test "multi-line comment level 0" do + # Test the actual --[[ path + assert {:ok, [{:eof, _}]} = Lexer.tokenize("--[[ comment ]]") + assert {:ok, [{:identifier, "x", _}, {:eof, _}]} = Lexer.tokenize("--[[ comment ]]x") + end + end + + describe "whitespace" do + test "skips spaces and tabs" do + assert {:ok, [{:number, 1, _}, {:number, 2, _}, {:eof, _}]} = Lexer.tokenize("1 2") + assert {:ok, [{:number, 1, _}, {:number, 2, _}, {:eof, _}]} = Lexer.tokenize("1\t2") + end + + test "handles newlines" do + assert {:ok, [{:number, 1, _}, {:number, 2, _}, {:eof, _}]} = Lexer.tokenize("1\n2") + assert {:ok, [{:number, 1, _}, {:number, 2, _}, {:eof, _}]} = Lexer.tokenize("1\r\n2") + assert {:ok, [{:number, 1, _}, {:number, 2, _}, {:eof, _}]} = Lexer.tokenize("1\r2") + end + + test "handles different newline styles in code" do + # CRLF newline + assert {:ok, tokens} = Lexer.tokenize("x\r\ny") + + assert [ + {:identifier, "x", _}, + {:identifier, "y", _}, + {:eof, _} + ] = tokens + + # CR only newline + assert {:ok, tokens} = Lexer.tokenize("x\ry") + + assert [ + {:identifier, "x", _}, + {:identifier, "y", _}, + 
{:eof, _} + ] = tokens + end + end + + describe "position tracking" do + test "tracks line and column for single line" do + assert {:ok, tokens} = Lexer.tokenize("local x = 42") + + assert [ + {:keyword, :local, %{line: 1, column: 1, byte_offset: 0}}, + {:identifier, "x", %{line: 1, column: 7, byte_offset: 6}}, + {:operator, :assign, %{line: 1, column: 9, byte_offset: 8}}, + {:number, 42, %{line: 1, column: 11, byte_offset: 10}}, + {:eof, _} + ] = tokens + end + + test "tracks line numbers across multiple lines" do + code = """ + local x + x = 42 + """ + + assert {:ok, tokens} = Lexer.tokenize(code) + + assert [ + {:keyword, :local, %{line: 1}}, + {:identifier, "x", %{line: 1}}, + {:identifier, "x", %{line: 2}}, + {:operator, :assign, %{line: 2}}, + {:number, 42, %{line: 2}}, + {:eof, _} + ] = tokens + end + + test "tracks position in strings" do + assert {:ok, tokens} = Lexer.tokenize(~s("hello")) + + assert [{:string, "hello", %{line: 1, column: 1, byte_offset: 0}}, {:eof, _}] = tokens + end + end + + describe "complex expressions" do + test "tokenizes arithmetic expression" do + assert {:ok, tokens} = Lexer.tokenize("2 + 3 * 4") + + assert [ + {:number, 2, _}, + {:operator, :add, _}, + {:number, 3, _}, + {:operator, :mul, _}, + {:number, 4, _}, + {:eof, _} + ] = tokens + end + + test "tokenizes function call" do + assert {:ok, tokens} = Lexer.tokenize("print(42)") + + assert [ + {:identifier, "print", _}, + {:delimiter, :lparen, _}, + {:number, 42, _}, + {:delimiter, :rparen, _}, + {:eof, _} + ] = tokens + end + + test "tokenizes table constructor" do + assert {:ok, tokens} = Lexer.tokenize("{a = 1, b = 2}") + + assert [ + {:delimiter, :lbrace, _}, + {:identifier, "a", _}, + {:operator, :assign, _}, + {:number, 1, _}, + {:delimiter, :comma, _}, + {:identifier, "b", _}, + {:operator, :assign, _}, + {:number, 2, _}, + {:delimiter, :rbrace, _}, + {:eof, _} + ] = tokens + end + + test "tokenizes method call" do + assert {:ok, tokens} = Lexer.tokenize("obj:method()") 
+ + assert [ + {:identifier, "obj", _}, + {:delimiter, :colon, _}, + {:identifier, "method", _}, + {:delimiter, :lparen, _}, + {:delimiter, :rparen, _}, + {:eof, _} + ] = tokens + end + + test "tokenizes vararg" do + assert {:ok, tokens} = Lexer.tokenize("function f(...) return ... end") + + assert [ + {:keyword, :function, _}, + {:identifier, "f", _}, + {:delimiter, :lparen, _}, + {:operator, :vararg, _}, + {:delimiter, :rparen, _}, + {:keyword, :return, _}, + {:operator, :vararg, _}, + {:keyword, :end, _}, + {:eof, _} + ] = tokens + end + end + + describe "edge cases" do + test "empty input" do + assert {:ok, [{:eof, %{line: 1, column: 1, byte_offset: 0}}]} = Lexer.tokenize("") + end + + test "only whitespace" do + assert {:ok, [{:eof, _}]} = Lexer.tokenize(" \n \t ") + end + + test "only comments" do + assert {:ok, [{:eof, _}]} = Lexer.tokenize("-- just a comment") + assert {:ok, [{:eof, _}]} = Lexer.tokenize("--[[ just a comment ]]") + end + + test "reports error for unexpected character" do + assert {:error, {:unexpected_character, ?@, _}} = Lexer.tokenize("@") + assert {:error, {:unexpected_character, ?$, _}} = Lexer.tokenize("$") + assert {:error, {:unexpected_character, ?`, _}} = Lexer.tokenize("`") + end + + test "handles consecutive operators" do + assert {:ok, tokens} = Lexer.tokenize("+-*/") + + assert [ + {:operator, :add, _}, + {:operator, :sub, _}, + {:operator, :mul, _}, + {:operator, :div, _}, + {:eof, _} + ] = tokens + end + + test "distinguishes >= from > =" do + assert {:ok, [{:operator, :ge, _}, {:eof, _}]} = Lexer.tokenize(">=") + + assert {:ok, [{:operator, :gt, _}, {:operator, :assign, _}, {:eof, _}]} = + Lexer.tokenize("> =") + end + + test "handles invalid escape sequences in strings" do + # Invalid escape sequences should be included as-is + assert {:ok, [{:string, "\\x", _}, {:eof, _}]} = Lexer.tokenize(~s("\\x")) + assert {:ok, [{:string, "\\z", _}, {:eof, _}]} = Lexer.tokenize(~s("\\z")) + assert {:ok, [{:string, "\\1", _}, {:eof, _}]} 
= Lexer.tokenize(~s("\\1")) + end + + test "reports error for string with unescaped newline" do + assert {:error, {:unclosed_string, _}} = Lexer.tokenize("\"hello\n") + assert {:error, {:unclosed_string, _}} = Lexer.tokenize("'hello\n") + end + + test "handles trailing dot after number" do + # "42." should tokenize as number 42 followed by dot + assert {:ok, tokens} = Lexer.tokenize("42.") + assert [{:number, 42, _}, {:delimiter, :dot, _}, {:eof, _}] = tokens + end + + test "handles decimal point without following digit" do + # "42.x" should be number 42 followed by dot and identifier x + assert {:ok, tokens} = Lexer.tokenize("42.x") + assert [{:number, 42, _}, {:delimiter, :dot, _}, {:identifier, "x", _}, {:eof, _}] = tokens + end + + test "reports error for invalid hex number" do + assert {:error, {:invalid_hex_number, _}} = Lexer.tokenize("0x") + assert {:error, {:invalid_hex_number, _}} = Lexer.tokenize("0xg") + end + + test "handles uppercase X in hex numbers" do + assert {:ok, [{:number, 255, _}, {:eof, _}]} = Lexer.tokenize("0XFF") + assert {:ok, [{:number, 10, _}, {:eof, _}]} = Lexer.tokenize("0Xa") + end + + test "single-line comment ending with LF" do + assert {:ok, [{:identifier, "x", _}, {:eof, _}]} = Lexer.tokenize("-- comment\nx") + end + + test "single-line comment ending with CR" do + assert {:ok, [{:identifier, "x", _}, {:eof, _}]} = Lexer.tokenize("-- comment\rx") + end + + test "single-line comment ending with CRLF" do + assert {:ok, [{:identifier, "x", _}, {:eof, _}]} = Lexer.tokenize("-- comment\r\nx") + end + + test "single-line comment at end of file" do + assert {:ok, [{:eof, _}]} = Lexer.tokenize("-- comment at EOF") + end + + test "comment starting with --[ but not --[[" do + # This should be treated as a single-line comment, not a multi-line comment + assert {:ok, [{:eof, _}]} = Lexer.tokenize("--[ this is single line") + assert {:ok, [{:eof, _}]} = Lexer.tokenize("--[= not multi-line") + assert {:ok, [{:eof, _}]} = Lexer.tokenize("--[x 
not multi-line") + end + + test "multi-line comment with mismatched bracket level" do + # The closing bracket doesn't match the opening level + assert {:ok, [{:eof, _}]} = Lexer.tokenize("--[=[ comment ]]") + # This should continue scanning until EOF because ]] doesn't match the opening [=[ + end + + test "long string with mismatched closing bracket" do + # Opening [=[ but closing with ]] + assert {:error, {:unclosed_long_string, _}} = Lexer.tokenize("[=[ string ]]") + end + + test "long bracket not actually a long bracket" do + # "[" followed by something other than "=" or "[" should be treated as delimiter + assert {:ok, [{:delimiter, :lbracket, _}, {:identifier, "x", _}, {:eof, _}]} = + Lexer.tokenize("[x") + end + + test "right bracket delimiter" do + assert {:ok, [{:delimiter, :rbracket, _}, {:eof, _}]} = Lexer.tokenize("]") + end + end + + describe "additional edge cases for coverage" do + test "whitespace at start with CRLF" do + # Test CRLF at very start of file + assert {:ok, [{:identifier, "x", _}, {:eof, _}]} = Lexer.tokenize("\r\nx") + end + + test "whitespace at start with CR" do + # Test CR at very start of file + assert {:ok, [{:identifier, "x", _}, {:eof, _}]} = Lexer.tokenize("\rx") + end + + test "single-line comment with --[ at start" do + # Ensure --[ path is taken + assert {:ok, [{:eof, _}]} = Lexer.tokenize("--[not a multiline") + end + + test "multiline comment with --[[ at start" do + # Ensure --[[ path is taken + assert {:ok, [{:eof, _}]} = Lexer.tokenize("--[[ multiline ]]") + end + + test "long string starting at beginning" do + assert {:ok, [{:string, "test", _}, {:eof, _}]} = Lexer.tokenize("[[test]]") + end + + test "number followed by concat operator" do + # Test the path where we have ".." 
after a number + assert {:ok, [{:number, 5, _}, {:operator, :concat, _}, {:eof, _}]} = + Lexer.tokenize("5..") + end + + test "uppercase E in scientific notation for integer" do + assert {:ok, [{:number, num, _}, {:eof, _}]} = Lexer.tokenize("2E5") + assert num == 2.0e5 + end + end + + describe "real Lua code examples" do + test "tokenizes variable assignment" do + code = "local x = 42" + assert {:ok, tokens} = Lexer.tokenize(code) + assert length(tokens) == 5 + end + + test "tokenizes if statement" do + code = "if x > 0 then print(x) end" + assert {:ok, tokens} = Lexer.tokenize(code) + + assert [ + {:keyword, :if, _}, + {:identifier, "x", _}, + {:operator, :gt, _}, + {:number, 0, _}, + {:keyword, :then, _}, + {:identifier, "print", _}, + {:delimiter, :lparen, _}, + {:identifier, "x", _}, + {:delimiter, :rparen, _}, + {:keyword, :end, _}, + {:eof, _} + ] = tokens + end + + test "tokenizes function definition" do + code = """ + function add(a, b) + return a + b + end + """ + + assert {:ok, tokens} = Lexer.tokenize(code) + + assert Enum.any?(tokens, fn + {:keyword, :function, _} -> true + _ -> false + end) + end + + test "tokenizes for loop" do + code = "for i = 1, 10 do print(i) end" + assert {:ok, tokens} = Lexer.tokenize(code) + + assert [ + {:keyword, :for, _}, + {:identifier, "i", _}, + {:operator, :assign, _}, + {:number, 1, _}, + {:delimiter, :comma, _}, + {:number, 10, _}, + {:keyword, :do, _}, + {:identifier, "print", _}, + {:delimiter, :lparen, _}, + {:identifier, "i", _}, + {:delimiter, :rparen, _}, + {:keyword, :end, _}, + {:eof, _} + ] = tokens + end + + test "tokenizes table with mixed fields" do + code = "{1, 2, x = 3, [\"key\"] = 4}" + assert {:ok, tokens} = Lexer.tokenize(code) + assert length(tokens) > 10 + end + end +end diff --git a/test/lua/parser/error_test.exs b/test/lua/parser/error_test.exs new file mode 100644 index 0000000..754f17a --- /dev/null +++ b/test/lua/parser/error_test.exs @@ -0,0 +1,180 @@ +defmodule Lua.Parser.ErrorTest do + 
@moduledoc """ + Tests for parser error messages, including formatting and suggestions. + """ + use ExUnit.Case, async: true + alias Lua.Parser + + describe "syntax errors" do + test "missing 'end' keyword" do + code = """ + function foo() + return 1 + """ + + assert {:error, msg} = Parser.parse(code) + assert msg =~ "Parse Error" + assert msg =~ "line 3" + assert msg =~ "Expected" + assert msg =~ "'end'" + end + + test "missing 'then' keyword" do + code = """ + if x > 0 + return x + end + """ + + assert {:error, msg} = Parser.parse(code) + assert msg =~ "Expected" + assert msg =~ ":then" + end + + test "missing 'do' keyword in while loop" do + code = """ + while x > 0 + x = x - 1 + end + """ + + assert {:error, msg} = Parser.parse(code) + assert msg =~ "Expected" + assert msg =~ ":do" + end + + test "missing closing parenthesis" do + code = "print(1, 2, 3" + assert {:error, msg} = Parser.parse(code) + assert msg =~ "Parse Error" + end + + test "missing closing bracket" do + code = "local t = {1, 2, 3" + assert {:error, msg} = Parser.parse(code) + assert msg =~ "Parse Error" + end + + test "unexpected token in expression" do + code = "local x = 1 + + 2" + assert {:error, msg} = Parser.parse(code) + assert msg =~ "Parse Error" + end + + test "complex nested function with missing end" do + code = """ + function factorial(n) + if n <= 1 then + return 1 + else + return n * factorial(n - 1) + -- Missing 'end' here! 
+ """ + + assert {:error, msg} = Parser.parse(code) + assert msg =~ ~r/Parse Error/i + assert msg =~ "line" + assert msg =~ "Expected" + assert msg =~ "end" + end + end + + describe "lexer errors" do + test "unclosed string" do + code = ~s(local x = "hello) + assert {:error, msg} = Parser.parse(code) + assert msg =~ "Unclosed string" + assert msg =~ "line 1" + end + + test "unexpected character" do + code = """ + local x = 42 + local y = @invalid + """ + + assert {:error, msg} = Parser.parse(code) + assert msg =~ "Unexpected character" + assert msg =~ "line 2" + assert msg =~ "@" + end + end + + describe "error message formatting" do + test "includes visual formatting elements" do + code = "if x > 0" + + assert {:error, msg} = Parser.parse(code) + # Line separator + assert msg =~ "│" + # Error pointer + assert msg =~ "^" + # ANSI color codes + assert msg =~ "\e[" + end + + test "includes line and column information" do + code = """ + local x = 1 + if x > 0 then + print(x + end + """ + + assert {:error, msg} = Parser.parse(code) + assert msg =~ "line" + assert msg =~ "column" + end + + test "shows context lines around error" do + code = """ + local x = 1 + local y = 2 + if x > y + print(x) + end + """ + + assert {:error, msg} = Parser.parse(code) + assert msg =~ "│" + end + + test "provides helpful suggestions" do + code = """ + function test() + print("hello") + """ + + assert {:error, msg} = Parser.parse(code) + assert msg =~ "Suggestion" + end + + test "uses ANSI colors for better readability" do + code = "if x then" + + assert {:error, msg} = Parser.parse(code) + # Red for errors + assert msg =~ "\e[31m" + # Bold + assert msg =~ "\e[1m" + # Reset + assert msg =~ "\e[0m" + # Cyan for suggestions + assert msg =~ "\e[36m" + end + end + + describe "parse_raw API" do + test "returns structured error tuple" do + code = "if x then" + assert {:error, error_tuple} = Parser.parse_raw(code) + assert is_tuple(error_tuple) + end + + test "returns AST on success" do + code = 
"local x = 42" + assert {:ok, chunk} = Parser.parse_raw(code) + assert chunk.__struct__ == Lua.AST.Chunk + end + end +end diff --git a/test/lua/parser/error_unit_test.exs b/test/lua/parser/error_unit_test.exs new file mode 100644 index 0000000..c9d566d --- /dev/null +++ b/test/lua/parser/error_unit_test.exs @@ -0,0 +1,627 @@ +defmodule Lua.Parser.ErrorUnitTest do + use ExUnit.Case, async: true + + alias Lua.Parser.Error + + describe "new/4" do + test "creates error with all fields" do + position = %{line: 1, column: 5, byte_offset: 0} + related = [Error.new(:unexpected_token, "related error", nil)] + + error = + Error.new( + :unexpected_token, + "test message", + position, + suggestion: "test suggestion", + source_lines: ["line 1", "line 2"], + related: related + ) + + assert error.type == :unexpected_token + assert error.message == "test message" + assert error.position == position + assert error.suggestion == "test suggestion" + assert error.source_lines == ["line 1", "line 2"] + assert error.related == related + end + + test "creates error with minimal fields" do + error = Error.new(:expected_token, "minimal error") + + assert error.type == :expected_token + assert error.message == "minimal error" + assert error.position == nil + assert error.suggestion == nil + assert error.source_lines == [] + assert error.related == [] + end + + test "creates error with position but no opts" do + position = %{line: 2, column: 10, byte_offset: 15} + error = Error.new(:invalid_syntax, "with position", position) + + assert error.type == :invalid_syntax + assert error.message == "with position" + assert error.position == position + assert error.suggestion == nil + assert error.source_lines == [] + assert error.related == [] + end + end + + describe "unexpected_token/4" do + test "creates error for delimiter token" do + position = %{line: 1, column: 5, byte_offset: 0} + error = Error.unexpected_token(:delimiter, ")", position, "expression") + + assert error.type == 
:unexpected_token + assert String.contains?(error.message, "Unexpected ')'") + assert String.contains?(error.message, "in expression") + assert error.position == position + assert error.suggestion == "Check for missing operators or keywords before this delimiter" + end + + test "creates error for token in expression context" do + position = %{line: 2, column: 3, byte_offset: 10} + error = Error.unexpected_token(:keyword, "end", position, "primary expression") + + assert error.type == :unexpected_token + assert String.contains?(error.message, "in primary expression") + assert error.position == position + + assert error.suggestion == + "Expected an expression here (variable, number, string, table, function, etc.)" + end + + test "creates error for token in statement context" do + position = %{line: 3, column: 1, byte_offset: 20} + error = Error.unexpected_token(:number, 42, position, "statement") + + assert error.type == :unexpected_token + assert String.contains?(error.message, "in statement") + assert error.position == position + + assert error.suggestion == + "Expected a statement here (assignment, function call, if, while, for, etc.)" + end + + test "creates error with no specific suggestion" do + position = %{line: 1, column: 1, byte_offset: 0} + error = Error.unexpected_token(:identifier, "foo", position, "block") + + assert error.type == :unexpected_token + assert error.position == position + assert error.suggestion == nil + end + end + + describe "expected_token/5" do + test "creates error for expected 'end' keyword" do + position = %{line: 10, column: 1, byte_offset: 100} + error = Error.expected_token(:keyword, :end, :eof, nil, position) + + assert error.type == :expected_token + assert String.contains?(error.message, "Expected 'end'") + assert String.contains?(error.message, "but got end of input") + assert error.position == position + + assert String.contains?( + error.suggestion, + "Add 'end' to close the block. 
Check that all opening keywords" + ) + end + + test "creates error for expected 'then' keyword" do + position = %{line: 5, column: 20, byte_offset: 50} + error = Error.expected_token(:keyword, :then, :identifier, "x", position) + + assert error.type == :expected_token + assert String.contains?(error.message, "Expected 'then'") + assert String.contains?(error.message, "but got identifier 'x'") + + assert String.contains?( + error.suggestion, + "Add 'then' after the condition. Lua requires 'then' after if/elseif conditions." + ) + end + + test "creates error for expected 'do' keyword" do + position = %{line: 7, column: 15, byte_offset: 75} + error = Error.expected_token(:keyword, :do, :keyword, :end, position) + + assert error.type == :expected_token + assert String.contains?(error.message, "Expected 'do'") + assert String.contains?(error.message, "but got 'end'") + + assert String.contains?( + error.suggestion, + "Add 'do' to start the loop body. Lua requires 'do' after while/for conditions." 
+ ) + end + + test "creates error for expected ')' delimiter" do + position = %{line: 3, column: 12, byte_offset: 30} + error = Error.expected_token(:delimiter, :rparen, :delimiter, :comma, position) + + assert error.type == :expected_token + assert String.contains?(error.message, "Expected 'rparen'") + assert String.contains?(error.message, "but got 'comma'") + assert String.contains?(error.suggestion, "Add ')' to close the parentheses") + end + + test "creates error for expected ']' delimiter" do + position = %{line: 4, column: 8, byte_offset: 40} + error = Error.expected_token(:delimiter, :rbracket, :eof, nil, position) + + assert error.type == :expected_token + assert String.contains?(error.message, "Expected 'rbracket'") + assert String.contains?(error.message, "but got end of input") + assert String.contains?(error.suggestion, "Add ']' to close the brackets") + end + + test "creates error for expected '}' delimiter" do + position = %{line: 6, column: 5, byte_offset: 60} + error = Error.expected_token(:delimiter, :rbrace, :keyword, :end, position) + + assert error.type == :expected_token + assert String.contains?(error.message, "Expected 'rbrace'") + assert String.contains?(error.message, "but got 'end'") + assert String.contains?(error.suggestion, "Add '}' to close the table constructor") + end + + test "creates error for expected assignment operator" do + position = %{line: 2, column: 7, byte_offset: 15} + error = Error.expected_token(:operator, :assign, :number, 42, position) + + assert error.type == :expected_token + assert String.contains?(error.message, "Expected operator 'assign'") + assert String.contains?(error.message, "but got number 42") + assert String.contains?(error.suggestion, "Add '=' for assignment") + end + + test "creates error for identifier when got keyword" do + position = %{line: 8, column: 10, byte_offset: 80} + error = Error.expected_token(:identifier, nil, :keyword, :if, position) + + assert error.type == :expected_token + assert 
String.contains?(error.message, "Expected identifier") + assert String.contains?(error.message, "but got 'if'") + assert String.contains?(error.suggestion, "Cannot use Lua keyword as identifier") + end + + test "creates error with no specific suggestion" do + position = %{line: 1, column: 1, byte_offset: 0} + error = Error.expected_token(:operator, :plus, :operator, :minus, position) + + assert error.type == :expected_token + assert String.contains?(error.message, "Expected operator 'plus'") + assert String.contains?(error.message, "but got operator 'minus'") + assert error.suggestion == nil + end + end + + describe "unclosed_delimiter/3" do + test "creates error for unclosed lparen" do + position = %{line: 3, column: 5, byte_offset: 25} + error = Error.unclosed_delimiter(:lparen, position) + + assert error.type == :unclosed_delimiter + assert String.contains?(error.message, "Unclosed opening parenthesis '('") + assert error.position == position + assert String.contains?(error.suggestion, "Add a closing ')'") + assert String.contains?(error.suggestion, "line 3") + end + + test "creates error for unclosed lbracket" do + position = %{line: 5, column: 2, byte_offset: 50} + error = Error.unclosed_delimiter(:lbracket, position) + + assert error.type == :unclosed_delimiter + assert String.contains?(error.message, "Unclosed opening bracket '['") + assert error.position == position + assert String.contains?(error.suggestion, "Add a closing ']'") + assert String.contains?(error.suggestion, "line 5") + end + + test "creates error for unclosed lbrace" do + position = %{line: 7, column: 10, byte_offset: 75} + error = Error.unclosed_delimiter(:lbrace, position) + + assert error.type == :unclosed_delimiter + assert String.contains?(error.message, "Unclosed opening brace '{'") + assert error.position == position + assert String.contains?(error.suggestion, "Add a closing '}'") + assert String.contains?(error.suggestion, "line 7") + end + + test "creates error for unclosed function 
block" do + position = %{line: 10, column: 1, byte_offset: 100} + error = Error.unclosed_delimiter(:function, position) + + assert error.type == :unclosed_delimiter + assert String.contains?(error.message, "Unclosed 'function' block") + assert error.position == position + assert String.contains?(error.suggestion, "Add a closing 'end'") + assert String.contains?(error.suggestion, "line 10") + end + + test "creates error for unclosed if statement" do + position = %{line: 12, column: 1, byte_offset: 120} + error = Error.unclosed_delimiter(:if, position) + + assert error.type == :unclosed_delimiter + assert String.contains?(error.message, "Unclosed 'if' statement") + assert error.position == position + assert String.contains?(error.suggestion, "Add a closing 'end'") + assert String.contains?(error.suggestion, "line 12") + end + + test "creates error for unclosed while loop" do + position = %{line: 15, column: 1, byte_offset: 150} + error = Error.unclosed_delimiter(:while, position) + + assert error.type == :unclosed_delimiter + assert String.contains?(error.message, "Unclosed 'while' loop") + assert error.position == position + assert String.contains?(error.suggestion, "Add a closing 'end'") + assert String.contains?(error.suggestion, "line 15") + end + + test "creates error for unclosed for loop" do + position = %{line: 18, column: 1, byte_offset: 180} + error = Error.unclosed_delimiter(:for, position) + + assert error.type == :unclosed_delimiter + assert String.contains?(error.message, "Unclosed 'for' loop") + assert error.position == position + assert String.contains?(error.suggestion, "Add a closing 'end'") + assert String.contains?(error.suggestion, "line 18") + end + + test "creates error for unclosed do block" do + position = %{line: 20, column: 1, byte_offset: 200} + error = Error.unclosed_delimiter(:do, position) + + assert error.type == :unclosed_delimiter + assert String.contains?(error.message, "Unclosed 'do' block") + assert error.position == position + 
assert String.contains?(error.suggestion, "Add a closing 'end'") + assert String.contains?(error.suggestion, "line 20") + end + + test "creates error for unknown delimiter (default case)" do + position = %{line: 1, column: 1, byte_offset: 0} + error = Error.unclosed_delimiter(:unknown_delim, position) + + assert error.type == :unclosed_delimiter + assert String.contains?(error.message, "Unclosed unknown_delim") + assert error.position == position + assert String.contains?(error.suggestion, "Add a closing matching delimiter") + assert String.contains?(error.suggestion, "line 1") + end + + test "creates error with close position different from open position" do + open_pos = %{line: 3, column: 5, byte_offset: 25} + close_pos = %{line: 10, column: 1, byte_offset: 100} + error = Error.unclosed_delimiter(:lparen, open_pos, close_pos) + + assert error.type == :unclosed_delimiter + assert error.position == close_pos + assert String.contains?(error.suggestion, "line 3") + end + end + + describe "unexpected_end/2" do + test "creates error with position" do + position = %{line: 5, column: 1, byte_offset: 50} + error = Error.unexpected_end("function body", position) + + assert error.type == :unexpected_end + assert String.contains?(error.message, "Unexpected end of input") + assert String.contains?(error.message, "while parsing function body") + assert error.position == position + + assert String.contains?( + error.suggestion, + "Check for missing closing delimiters or keywords like 'end'" + ) + end + + test "creates error without position" do + error = Error.unexpected_end("expression") + + assert error.type == :unexpected_end + assert String.contains?(error.message, "Unexpected end of input") + assert String.contains?(error.message, "while parsing expression") + assert error.position == nil + assert error.suggestion != nil + end + end + + describe "format_token/2 (via expected_token)" do + test "formats keyword token" do + position = %{line: 1, column: 1, byte_offset: 0} + 
error = Error.expected_token(:keyword, :if, :keyword, :while, position) + + assert String.contains?(error.message, "'if'") + assert String.contains?(error.message, "'while'") + end + + test "formats identifier token" do + position = %{line: 1, column: 1, byte_offset: 0} + error = Error.expected_token(:identifier, "foo", :identifier, "bar", position) + + assert String.contains?(error.message, "identifier 'foo'") + assert String.contains?(error.message, "identifier 'bar'") + end + + test "formats number token" do + position = %{line: 1, column: 1, byte_offset: 0} + error = Error.expected_token(:number, 123, :number, 456, position) + + assert String.contains?(error.message, "number 123") + assert String.contains?(error.message, "number 456") + end + + test "formats string token" do + position = %{line: 1, column: 1, byte_offset: 0} + error = Error.expected_token(:string, "hello", :string, "world", position) + + assert String.contains?(error.message, "string \"hello\"") + assert String.contains?(error.message, "string \"world\"") + end + + test "formats operator token" do + position = %{line: 1, column: 1, byte_offset: 0} + error = Error.expected_token(:operator, :plus, :operator, :minus, position) + + assert String.contains?(error.message, "operator 'plus'") + assert String.contains?(error.message, "operator 'minus'") + end + + test "formats delimiter token" do + position = %{line: 1, column: 1, byte_offset: 0} + error = Error.expected_token(:delimiter, "(", :delimiter, ")", position) + + assert String.contains?(error.message, "'('") + assert String.contains?(error.message, "')'") + end + + test "formats eof token" do + position = %{line: 1, column: 1, byte_offset: 0} + error = Error.expected_token(:keyword, :end, :eof, nil, position) + + assert String.contains?(error.message, "end of input") + end + + test "formats unknown token type (default case)" do + position = %{line: 1, column: 1, byte_offset: 0} + error = Error.expected_token(:unknown_type, "value", 
:other_type, "value2", position) + + assert String.contains?(error.message, "unknown_type") + assert String.contains?(error.message, "other_type") + end + end + + describe "format/2" do + test "formats error with position and context" do + source_code = """ + local x = 1 + local y = 2 + + local z = 3 + """ + + position = %{line: 2, column: 14, byte_offset: 26} + error = Error.new(:unexpected_token, "Unexpected token", position) + formatted = Error.format(error, source_code) + + assert String.contains?(formatted, "Parse Error") + assert String.contains?(formatted, "at line 2, column 14") + assert String.contains?(formatted, "Unexpected token") + assert String.contains?(formatted, "local x = 1") + assert String.contains?(formatted, "local y = 2 +") + assert String.contains?(formatted, "local z = 3") + assert String.contains?(formatted, "^") + end + + test "formats error without position" do + source_code = "local x = 1" + error = Error.new(:unexpected_end, "Unexpected end") + formatted = Error.format(error, source_code) + + assert String.contains?(formatted, "Parse Error") + assert String.contains?(formatted, "(no position information)") + assert String.contains?(formatted, "Unexpected end") + refute String.contains?(formatted, "^") + end + + test "formats error with suggestion" do + source_code = "local x = 1" + position = %{line: 1, column: 5, byte_offset: 0} + + error = + Error.new(:expected_token, "Expected something", position, + suggestion: "Try adding a semicolon" + ) + + formatted = Error.format(error, source_code) + + assert String.contains?(formatted, "Suggestion:") + assert String.contains?(formatted, "Try adding a semicolon") + end + + test "formats error without suggestion" do + source_code = "local x = 1" + position = %{line: 1, column: 5, byte_offset: 0} + error = Error.new(:invalid_syntax, "Invalid syntax", position) + formatted = Error.format(error, source_code) + + refute String.contains?(formatted, "Suggestion:") + end + + test "formats error with 
related errors" do + source_code = "local x = 1" + position = %{line: 1, column: 5, byte_offset: 0} + related_pos = %{line: 1, column: 10, byte_offset: 5} + + related = [Error.new(:unexpected_token, "Related issue", related_pos)] + error = Error.new(:expected_token, "Main error", position, related: related) + formatted = Error.format(error, source_code) + + assert String.contains?(formatted, "Related errors:") + assert String.contains?(formatted, "Related issue") + end + + test "formats error without related errors" do + source_code = "local x = 1" + position = %{line: 1, column: 5, byte_offset: 0} + error = Error.new(:unexpected_token, "Simple error", position) + formatted = Error.format(error, source_code) + + refute String.contains?(formatted, "Related errors:") + end + + test "formats error at start of file (line 1)" do + source_code = """ + function foo() + x = 1 + end + """ + + position = %{line: 1, column: 1, byte_offset: 0} + error = Error.new(:unexpected_token, "Error at start", position) + formatted = Error.format(error, source_code) + + assert String.contains?(formatted, "at line 1, column 1") + assert String.contains?(formatted, "function foo()") + assert String.contains?(formatted, "x = 1") + end + + test "formats error at end of file" do + source_code = """ + line 1 + line 2 + line 3 + line 4 + line 5 + """ + + position = %{line: 5, column: 6, byte_offset: 30} + error = Error.new(:unexpected_end, "Error at end", position) + formatted = Error.format(error, source_code) + + assert String.contains?(formatted, "at line 5, column 6") + assert String.contains?(formatted, "line 3") + assert String.contains?(formatted, "line 4") + assert String.contains?(formatted, "line 5") + end + + test "formats error with empty source code" do + source_code = "" + position = %{line: 1, column: 1, byte_offset: 0} + error = Error.new(:unexpected_end, "Empty file", position) + formatted = Error.format(error, source_code) + + assert String.contains?(formatted, "Parse Error") 
+ assert String.contains?(formatted, "Empty file") + # Even with empty source, it will show context if position exists + # The format_context function will handle empty lines gracefully + end + + test "formats error showing context lines" do + source_code = """ + line 1 + line 2 + line 3 + line 4 + line 5 + line 6 + line 7 + """ + + position = %{line: 4, column: 3, byte_offset: 21} + error = Error.new(:unexpected_token, "Error in middle", position) + formatted = Error.format(error, source_code) + + # Should show 2 lines before and after (lines 2, 3, 4, 5, 6) + assert String.contains?(formatted, "line 2") + assert String.contains?(formatted, "line 3") + assert String.contains?(formatted, "line 4") + assert String.contains?(formatted, "line 5") + assert String.contains?(formatted, "line 6") + refute String.contains?(formatted, "line 1\n") + refute String.contains?(formatted, "line 7\n") + end + + test "formats error with pointer at correct column" do + source_code = "local x = 1 + 2" + position = %{line: 1, column: 15, byte_offset: 14} + error = Error.new(:unexpected_token, "Error", position) + formatted = Error.format(error, source_code) + + lines = String.split(formatted, "\n") + error_line_idx = Enum.find_index(lines, &String.contains?(&1, "local x = 1 + 2")) + pointer_line = Enum.at(lines, error_line_idx + 1) + + # Pointer should be at column 15 + assert String.contains?(pointer_line, "^") + # Count spaces before ^ + spaces_before = pointer_line |> String.split("^") |> hd() |> String.length() + # Should account for line number formatting (4 digits + " │ " = 7 chars) + column - 1 + assert spaces_before >= 20 + end + end + + describe "format_multiple/2" do + test "formats single error" do + source_code = "local x = 1" + position = %{line: 1, column: 5, byte_offset: 0} + error = Error.new(:unexpected_token, "Error 1", position) + formatted = Error.format_multiple([error], source_code) + + assert String.contains?(formatted, "Found 1 parse error") + refute 
defmodule Lua.Parser.ExprTest do
  use ExUnit.Case, async: true
  alias Lua.Parser
  alias Lua.AST.{Expr, Statement}

  # Helper to extract the returned expression from "return expr".
  #
  # NOTE(review): the return shape is intentionally non-uniform — a single
  # returned value yields {:ok, expr} (first clause), while multiple values
  # yield {:ok, [expr, ...]} (second clause). Tests below pattern-match on
  # whichever shape they expect. Any non-ok parser result is passed through
  # unchanged so parse failures surface in the assertion diff.
  defp parse_return_expr(code) do
    case Parser.parse(code) do
      {:ok, %{block: %{stmts: [%Statement.Return{values: [expr]}]}}} ->
        {:ok, expr}

      {:ok, %{block: %{stmts: [%Statement.Return{values: exprs}]}}} ->
        {:ok, exprs}

      other ->
        other
    end
  end

  describe "basic parsing" do
    test "parses simple expressions" do
      assert {:ok, %Expr.Number{value: 42}} = parse_return_expr("return 42")
      assert {:ok, %Expr.Bool{value: true}} = parse_return_expr("return true")
      assert {:ok, %Expr.String{value: "hello"}} = parse_return_expr(~s(return "hello"))
      assert {:ok, %Expr.Nil{}} = parse_return_expr("return nil")
      assert {:ok, %Expr.Var{name: "x"}} = parse_return_expr("return x")
    end

    test "parses binary operations" do
      assert {:ok, %Expr.BinOp{op: :add}} = parse_return_expr("return 1 + 2")
      assert {:ok, %Expr.BinOp{op: :mul}} = parse_return_expr("return 2 * 3")
      assert {:ok, %Expr.BinOp{op: :concat}} = parse_return_expr(~s(return "a" .. "b"))
    end

    test "parses unary operations" do
      assert {:ok, %Expr.UnOp{op: :not}} = parse_return_expr("return not true")
      assert {:ok, %Expr.UnOp{op: :neg}} = parse_return_expr("return -5")
      assert {:ok, %Expr.UnOp{op: :len}} = parse_return_expr("return #t")
    end

    test "parses table constructors" do
      assert {:ok, %Expr.Table{fields: []}} = parse_return_expr("return {}")
      assert {:ok, %Expr.Table{fields: [_, _, _]}} = parse_return_expr("return {1, 2, 3}")
      assert {:ok, %Expr.Table{}} = parse_return_expr("return {a = 1, b = 2}")
    end

    test "parses function expressions" do
      assert {:ok, %Expr.Function{params: []}} =
               parse_return_expr("return function() end")

      assert {:ok, %Expr.Function{params: ["a", "b"]}} =
               parse_return_expr("return function(a, b) end")

      # A trailing "..." is represented by the :vararg atom in the params list.
      assert {:ok, %Expr.Function{params: ["a", :vararg]}} =
               parse_return_expr("return function(a, ...) end")
    end

    test "parses function calls" do
      assert {:ok, %Expr.Call{func: %Expr.Var{name: "f"}, args: []}} =
               parse_return_expr("return f()")

      assert {:ok, %Expr.Call{args: [_, _, _]}} = parse_return_expr("return f(1, 2, 3)")
    end

    test "parses property access and indexing" do
      assert {:ok, %Expr.Property{table: %Expr.Var{name: "t"}, field: "field"}} =
               parse_return_expr("return t.field")

      assert {:ok, %Expr.Index{table: %Expr.Var{name: "t"}}} =
               parse_return_expr("return t[1]")
    end

    test "parses method calls" do
      assert {:ok, %Expr.MethodCall{object: %Expr.Var{name: "obj"}, method: "method"}} =
               parse_return_expr("return obj:method()")
    end

    test "parses complex nested expressions" do
      # Smoke tests: only check the parse succeeds, not the exact AST shape.
      assert {:ok, _} = parse_return_expr("return 1 + 2 * 3")
      assert {:ok, _} = parse_return_expr("return (1 + 2) * 3")
      assert {:ok, _} = parse_return_expr("return f(g(h(x)))")
      assert {:ok, _} = parse_return_expr("return t.a.b.c")
    end
  end
end
defmodule Lua.Parser.PrattTest do
  use ExUnit.Case, async: true
  alias Lua.Parser.Pratt

  # Binding-power convention used throughout (matches Pratt's own tables):
  #   left-associative operator  -> {left_bp, right_bp} with left_bp < right_bp
  #   right-associative operator -> {left_bp, right_bp} with left_bp > right_bp
  # Higher numbers bind tighter.

  describe "binding_power/1" do
    test "returns correct binding power for or operator" do
      assert {1, 2} = Pratt.binding_power(:or)
    end

    test "returns correct binding power for and operator" do
      assert {3, 4} = Pratt.binding_power(:and)
    end

    test "returns correct binding power for comparison operators" do
      assert {5, 6} = Pratt.binding_power(:lt)
      assert {5, 6} = Pratt.binding_power(:gt)
      assert {5, 6} = Pratt.binding_power(:le)
      assert {5, 6} = Pratt.binding_power(:ge)
      assert {5, 6} = Pratt.binding_power(:ne)
      assert {5, 6} = Pratt.binding_power(:eq)
    end

    test "returns correct binding power for concat operator (right associative)" do
      assert {7, 6} = Pratt.binding_power(:concat)
    end

    test "returns correct binding power for additive operators" do
      assert {9, 10} = Pratt.binding_power(:add)
      assert {9, 10} = Pratt.binding_power(:sub)
    end

    test "returns correct binding power for multiplicative operators" do
      assert {11, 12} = Pratt.binding_power(:mul)
      assert {11, 12} = Pratt.binding_power(:div)
      assert {11, 12} = Pratt.binding_power(:floordiv)
      assert {11, 12} = Pratt.binding_power(:mod)
    end

    test "returns correct binding power for unary operators (should not be used as binary)" do
      # These are unary operators, but the function includes them for completeness
      # They should not be used as binary operators in practice
      assert {13, 14} = Pratt.binding_power(:not)
      assert {13, 14} = Pratt.binding_power(:neg)
      assert {13, 14} = Pratt.binding_power(:len)
    end

    test "returns correct binding power for power operator (right associative)" do
      assert {16, 15} = Pratt.binding_power(:pow)
    end

    test "returns nil for non-operators" do
      assert nil == Pratt.binding_power(:invalid)
      assert nil == Pratt.binding_power(:lparen)
      assert nil == Pratt.binding_power(:rparen)
      assert nil == Pratt.binding_power(:identifier)
    end
  end

  describe "prefix_binding_power/1" do
    test "returns correct binding power for not operator" do
      assert 14 = Pratt.prefix_binding_power(:not)
    end

    test "returns correct binding power for unary minus (sub)" do
      # Unary minus binds looser than :not/:len so that -2^3 parses as -(2^3).
      assert 13 = Pratt.prefix_binding_power(:sub)
    end

    test "returns correct binding power for length operator" do
      assert 14 = Pratt.prefix_binding_power(:len)
    end

    test "returns nil for non-prefix operators" do
      assert nil == Pratt.prefix_binding_power(:add)
      assert nil == Pratt.prefix_binding_power(:mul)
      assert nil == Pratt.prefix_binding_power(:invalid)
    end
  end

  describe "token_to_binop/1" do
    test "maps logical operators" do
      assert :or = Pratt.token_to_binop(:or)
      assert :and = Pratt.token_to_binop(:and)
    end

    test "maps comparison operators" do
      assert :lt = Pratt.token_to_binop(:lt)
      assert :gt = Pratt.token_to_binop(:gt)
      assert :le = Pratt.token_to_binop(:le)
      assert :ge = Pratt.token_to_binop(:ge)
      assert :ne = Pratt.token_to_binop(:ne)
      assert :eq = Pratt.token_to_binop(:eq)
    end

    test "maps string concatenation operator" do
      assert :concat = Pratt.token_to_binop(:concat)
    end

    test "maps arithmetic operators" do
      assert :add = Pratt.token_to_binop(:add)
      assert :sub = Pratt.token_to_binop(:sub)
      assert :mul = Pratt.token_to_binop(:mul)
      assert :div = Pratt.token_to_binop(:div)
      assert :floordiv = Pratt.token_to_binop(:floordiv)
      assert :mod = Pratt.token_to_binop(:mod)
      assert :pow = Pratt.token_to_binop(:pow)
    end

    test "returns nil for non-binary operators" do
      assert nil == Pratt.token_to_binop(:not)
      assert nil == Pratt.token_to_binop(:len)
      assert nil == Pratt.token_to_binop(:invalid)
    end
  end

  describe "token_to_unop/1" do
    test "maps unary operators" do
      # The :sub token doubles as unary minus (:neg) in prefix position.
      assert :not = Pratt.token_to_unop(:not)
      assert :neg = Pratt.token_to_unop(:sub)
      assert :len = Pratt.token_to_unop(:len)
    end

    test "returns nil for non-unary operators" do
      assert nil == Pratt.token_to_unop(:add)
      assert nil == Pratt.token_to_unop(:mul)
      assert nil == Pratt.token_to_unop(:or)
      assert nil == Pratt.token_to_unop(:invalid)
    end
  end

  describe "is_binary_op?/1" do
    test "returns true for binary operators" do
      assert Pratt.is_binary_op?(:or)
      assert Pratt.is_binary_op?(:and)
      assert Pratt.is_binary_op?(:lt)
      assert Pratt.is_binary_op?(:gt)
      assert Pratt.is_binary_op?(:le)
      assert Pratt.is_binary_op?(:ge)
      assert Pratt.is_binary_op?(:ne)
      assert Pratt.is_binary_op?(:eq)
      assert Pratt.is_binary_op?(:concat)
      assert Pratt.is_binary_op?(:add)
      assert Pratt.is_binary_op?(:sub)
      assert Pratt.is_binary_op?(:mul)
      assert Pratt.is_binary_op?(:div)
      assert Pratt.is_binary_op?(:floordiv)
      assert Pratt.is_binary_op?(:mod)
      assert Pratt.is_binary_op?(:pow)
    end

    test "returns false for non-binary operators" do
      refute Pratt.is_binary_op?(:invalid)
      refute Pratt.is_binary_op?(:lparen)
      refute Pratt.is_binary_op?(:identifier)
    end
  end

  describe "is_prefix_op?/1" do
    test "returns true for prefix operators" do
      assert Pratt.is_prefix_op?(:not)
      assert Pratt.is_prefix_op?(:sub)
      assert Pratt.is_prefix_op?(:len)
    end

    test "returns false for non-prefix operators" do
      refute Pratt.is_prefix_op?(:add)
      refute Pratt.is_prefix_op?(:mul)
      refute Pratt.is_prefix_op?(:or)
      refute Pratt.is_prefix_op?(:invalid)
    end
  end

  describe "operator precedence correctness" do
    test "logical operators have lowest precedence" do
      {or_left, _} = Pratt.binding_power(:or)
      {and_left, _} = Pratt.binding_power(:and)
      {lt_left, _} = Pratt.binding_power(:lt)

      assert or_left < and_left
      assert and_left < lt_left
    end

    test "comparison operators have same precedence" do
      {lt_left, _} = Pratt.binding_power(:lt)
      {gt_left, _} = Pratt.binding_power(:gt)
      {le_left, _} = Pratt.binding_power(:le)
      {ge_left, _} = Pratt.binding_power(:ge)
      {ne_left, _} = Pratt.binding_power(:ne)
      {eq_left, _} = Pratt.binding_power(:eq)

      assert lt_left == gt_left
      assert lt_left == le_left
      assert lt_left == ge_left
      assert lt_left == ne_left
      assert lt_left == eq_left
    end

    test "concat is right associative (left_bp > right_bp)" do
      {left_bp, right_bp} = Pratt.binding_power(:concat)
      assert left_bp > right_bp
    end

    test "power is right associative (left_bp > right_bp)" do
      {left_bp, right_bp} = Pratt.binding_power(:pow)
      assert left_bp > right_bp
    end

    # FIX(review): the title previously read "(left_bp >= right_bp)", which
    # contradicted the assertion below. In this convention left associativity
    # means left_bp < right_bp (e.g. :add is {9, 10}) — the mirror image of
    # the right-associative cases above.
    test "addition is left associative (left_bp < right_bp)" do
      {left_bp, right_bp} = Pratt.binding_power(:add)
      assert left_bp < right_bp
    end

    test "multiplication has higher precedence than addition" do
      {add_left, _} = Pratt.binding_power(:add)
      {mul_left, _} = Pratt.binding_power(:mul)

      assert mul_left > add_left
    end

    test "all multiplicative operators have same precedence" do
      {mul_left, _} = Pratt.binding_power(:mul)
      {div_left, _} = Pratt.binding_power(:div)
      {floordiv_left, _} = Pratt.binding_power(:floordiv)
      {mod_left, _} = Pratt.binding_power(:mod)

      assert mul_left == div_left
      assert mul_left == floordiv_left
      assert mul_left == mod_left
    end

    test "unary operators have higher precedence than multiplication" do
      unary_bp = Pratt.prefix_binding_power(:sub)
      {mul_left, _} = Pratt.binding_power(:mul)

      assert unary_bp > mul_left
    end

    test "power has higher precedence than unary (special case)" do
      unary_bp = Pratt.prefix_binding_power(:sub)
      {pow_left, _} = Pratt.binding_power(:pow)

      # Unary minus has lower precedence than power's left binding
      # This ensures -2^3 parses as -(2^3)
      assert unary_bp < pow_left
    end

    test "not and len have same prefix binding power" do
      not_bp = Pratt.prefix_binding_power(:not)
      len_bp = Pratt.prefix_binding_power(:len)

      assert not_bp == len_bp
    end
  end
end
defmodule Lua.Parser.PrecedenceTest do
  use ExUnit.Case, async: true
  alias Lua.Parser
  alias Lua.AST.Expr

  # End-to-end precedence tests: unlike PrattTest (which checks the raw
  # binding-power tables), these parse real source and assert the AST shape.

  describe "operator precedence" do
    test "or has lowest precedence" do
      assert {:ok, %{block: %{stmts: [stmt]}}} = Parser.parse("return a and b or c and d")

      # Should parse as: (a and b) or (c and d)
      assert %{
               values: [
                 %Expr.BinOp{
                   op: :or,
                   left: %Expr.BinOp{op: :and},
                   right: %Expr.BinOp{op: :and}
                 }
               ]
             } = stmt
    end

    test "and has higher precedence than or" do
      assert {:ok, %{block: %{stmts: [stmt]}}} = Parser.parse("return a or b and c")

      # Should parse as: a or (b and c)
      assert %{
               values: [
                 %Expr.BinOp{
                   op: :or,
                   left: %Expr.Var{name: "a"},
                   right: %Expr.BinOp{op: :and}
                 }
               ]
             } = stmt
    end

    test "comparison has higher precedence than logical" do
      assert {:ok, %{block: %{stmts: [stmt]}}} = Parser.parse("return a < b and c > d")

      # Should parse as: (a < b) and (c > d)
      assert %{
               values: [
                 %Expr.BinOp{
                   op: :and,
                   left: %Expr.BinOp{op: :lt},
                   right: %Expr.BinOp{op: :gt}
                 }
               ]
             } = stmt
    end

    test "concatenation has higher precedence than comparison" do
      assert {:ok, %{block: %{stmts: [stmt]}}} = Parser.parse(~s(return "a" .. "b" < "c" .. "d"))

      # Should parse as: ("a" .. "b") < ("c" .. "d")
      assert %{
               values: [
                 %Expr.BinOp{
                   op: :lt,
                   left: %Expr.BinOp{op: :concat},
                   right: %Expr.BinOp{op: :concat}
                 }
               ]
             } = stmt
    end

    test "addition has higher precedence than concatenation" do
      assert {:ok, %{block: %{stmts: [stmt]}}} = Parser.parse("return a + b .. c + d")

      # Should parse as: (a + b) .. (c + d)
      assert %{
               values: [
                 %Expr.BinOp{
                   op: :concat,
                   left: %Expr.BinOp{op: :add},
                   right: %Expr.BinOp{op: :add}
                 }
               ]
             } = stmt
    end

    test "multiplication has higher precedence than addition" do
      assert {:ok, %{block: %{stmts: [stmt]}}} = Parser.parse("return a + b * c")

      # Should parse as: a + (b * c)
      assert %{
               values: [
                 %Expr.BinOp{
                   op: :add,
                   left: %Expr.Var{name: "a"},
                   right: %Expr.BinOp{op: :mul}
                 }
               ]
             } = stmt
    end

    test "unary has higher precedence than multiplication" do
      assert {:ok, %{block: %{stmts: [stmt]}}} = Parser.parse("return -a * b")

      # Should parse as: (-a) * b
      assert %{
               values: [
                 %Expr.BinOp{
                   op: :mul,
                   left: %Expr.UnOp{op: :neg},
                   right: %Expr.Var{name: "b"}
                 }
               ]
             } = stmt
    end

    test "power has higher precedence than unary (special case)" do
      # In Lua, -2^3 = -(2^3) = -8, not (-2)^3
      assert {:ok, %{block: %{stmts: [stmt]}}} = Parser.parse("return -2 ^ 3")

      # Should parse as: -(2 ^ 3)
      assert %{
               values: [
                 %Expr.UnOp{
                   op: :neg,
                   operand: %Expr.BinOp{op: :pow}
                 }
               ]
             } = stmt
    end

    test "not has higher precedence than and" do
      assert {:ok, %{block: %{stmts: [stmt]}}} = Parser.parse("return not a and b")

      # Should parse as: (not a) and b
      assert %{
               values: [
                 %Expr.BinOp{
                   op: :and,
                   left: %Expr.UnOp{op: :not},
                   right: %Expr.Var{name: "b"}
                 }
               ]
             } = stmt
    end

    test "length operator has same precedence as unary minus" do
      assert {:ok, %{block: %{stmts: [stmt]}}} = Parser.parse("return #t + 1")

      # Should parse as: (#t) + 1
      assert %{
               values: [
                 %Expr.BinOp{
                   op: :add,
                   left: %Expr.UnOp{op: :len},
                   right: %Expr.Number{value: 1}
                 }
               ]
             } = stmt
    end
  end

  describe "associativity" do
    test "addition is left associative" do
      assert {:ok, %{block: %{stmts: [stmt]}}} = Parser.parse("return 1 + 2 + 3")

      # Should parse as: (1 + 2) + 3
      assert %{
               values: [
                 %Expr.BinOp{
                   op: :add,
                   left: %Expr.BinOp{
                     op: :add,
                     left: %Expr.Number{value: 1},
                     right: %Expr.Number{value: 2}
                   },
                   right: %Expr.Number{value: 3}
                 }
               ]
             } = stmt
    end

    test "subtraction is left associative" do
      assert {:ok, %{block: %{stmts: [stmt]}}} = Parser.parse("return 10 - 5 - 2")

      # Should parse as: (10 - 5) - 2 = 3
      assert %{
               values: [
                 %Expr.BinOp{
                   op: :sub,
                   left: %Expr.BinOp{op: :sub},
                   right: %Expr.Number{value: 2}
                 }
               ]
             } = stmt
    end

    test "multiplication is left associative" do
      assert {:ok, %{block: %{stmts: [stmt]}}} = Parser.parse("return 2 * 3 * 4")

      # Should parse as: (2 * 3) * 4
      assert %{
               values: [
                 %Expr.BinOp{
                   op: :mul,
                   left: %Expr.BinOp{op: :mul},
                   right: %Expr.Number{value: 4}
                 }
               ]
             } = stmt
    end

    test "division is left associative" do
      assert {:ok, %{block: %{stmts: [stmt]}}} = Parser.parse("return 24 / 4 / 2")

      # Should parse as: (24 / 4) / 2 = 3
      assert %{
               values: [
                 %Expr.BinOp{
                   op: :div,
                   left: %Expr.BinOp{op: :div},
                   right: %Expr.Number{value: 2}
                 }
               ]
             } = stmt
    end

    test "power is right associative" do
      assert {:ok, %{block: %{stmts: [stmt]}}} = Parser.parse("return 2 ^ 3 ^ 2")

      # Should parse as: 2 ^ (3 ^ 2) = 2 ^ 9 = 512
      assert %{
               values: [
                 %Expr.BinOp{
                   op: :pow,
                   left: %Expr.Number{value: 2},
                   right: %Expr.BinOp{
                     op: :pow,
                     left: %Expr.Number{value: 3},
                     right: %Expr.Number{value: 2}
                   }
                 }
               ]
             } = stmt
    end

    test "concatenation is right associative" do
      assert {:ok, %{block: %{stmts: [stmt]}}} = Parser.parse(~s(return "a" .. "b" .. "c"))

      # Should parse as: "a" .. ("b" .. "c")
      assert %{
               values: [
                 %Expr.BinOp{
                   op: :concat,
                   left: %Expr.String{value: "a"},
                   right: %Expr.BinOp{
                     op: :concat,
                     left: %Expr.String{value: "b"},
                     right: %Expr.String{value: "c"}
                   }
                 }
               ]
             } = stmt
    end

    test "and is left associative" do
      assert {:ok, %{block: %{stmts: [stmt]}}} = Parser.parse("return a and b and c")

      # Should parse as: (a and b) and c
      assert %{
               values: [
                 %Expr.BinOp{
                   op: :and,
                   left: %Expr.BinOp{op: :and},
                   right: %Expr.Var{name: "c"}
                 }
               ]
             } = stmt
    end

    test "or is left associative" do
      assert {:ok, %{block: %{stmts: [stmt]}}} = Parser.parse("return a or b or c")

      # Should parse as: (a or b) or c
      assert %{
               values: [
                 %Expr.BinOp{
                   op: :or,
                   left: %Expr.BinOp{op: :or},
                   right: %Expr.Var{name: "c"}
                 }
               ]
             } = stmt
    end

    test "comparison is left associative" do
      assert {:ok, %{block: %{stmts: [stmt]}}} = Parser.parse("return 1 < 2 < 3")

      # Should parse as: (1 < 2) < 3
      # Note: This is legal in Lua but semantically weird (compares boolean with number)
      assert %{
               values: [
                 %Expr.BinOp{
                   op: :lt,
                   left: %Expr.BinOp{op: :lt},
                   right: %Expr.Number{value: 3}
                 }
               ]
             } = stmt
    end
  end

  describe "complex precedence cases" do
    test "all operators with correct precedence" do
      # a or b and c < d .. e + f * g ^ h
      # Should parse as: a or (b and (c < (d .. (e + (f * (g ^ h))))))
      assert {:ok, %{block: %{stmts: [stmt]}}} =
               Parser.parse("return a or b and c < d .. e + f * g ^ h")

      assert %{
               values: [
                 %Expr.BinOp{
                   op: :or,
                   left: %Expr.Var{name: "a"},
                   right: %Expr.BinOp{
                     op: :and,
                     left: %Expr.Var{name: "b"},
                     right: %Expr.BinOp{
                       op: :lt,
                       left: %Expr.Var{name: "c"},
                       right: %Expr.BinOp{
                         op: :concat,
                         left: %Expr.Var{name: "d"},
                         right: %Expr.BinOp{
                           op: :add,
                           left: %Expr.Var{name: "e"},
                           right: %Expr.BinOp{
                             op: :mul,
                             left: %Expr.Var{name: "f"},
                             right: %Expr.BinOp{
                               op: :pow,
                               left: %Expr.Var{name: "g"},
                               right: %Expr.Var{name: "h"}
                             }
                           }
                         }
                       }
                     }
                   }
                 }
               ]
             } = stmt
    end

    test "unary operators with various binary operators" do
      assert {:ok, %{block: %{stmts: [stmt]}}} = Parser.parse("return not a or -b and #c")

      # Should parse as: (not a) or ((-b) and (#c))
      assert %{
               values: [
                 %Expr.BinOp{
                   op: :or,
                   left: %Expr.UnOp{op: :not},
                   right: %Expr.BinOp{
                     op: :and,
                     left: %Expr.UnOp{op: :neg},
                     right: %Expr.UnOp{op: :len}
                   }
                 }
               ]
             } = stmt
    end

    test "mixed arithmetic with different precedences" do
      assert {:ok, %{block: %{stmts: [stmt]}}} = Parser.parse("return 1 + 2 * 3 - 4 / 5 % 6")

      # Should parse as: (1 + (2 * 3)) - ((4 / 5) % 6)
      # With left associativity: ((1 + (2 * 3)) - ((4 / 5) % 6))
      # *, /, % have same precedence, so (4 / 5) % 6 parses left-to-right
      # +, - have same precedence, so the top level is: (...) - (...)
      assert %{values: [%Expr.BinOp{op: :sub}]} = stmt
    end
  end

  describe "parentheses override precedence" do
    test "parentheses around addition before multiplication" do
      assert {:ok, %{block: %{stmts: [stmt]}}} = Parser.parse("return (1 + 2) * 3")

      # Should parse as: (1 + 2) * 3
      assert %{
               values: [
                 %Expr.BinOp{
                   op: :mul,
                   left: %Expr.BinOp{op: :add},
                   right: %Expr.Number{value: 3}
                 }
               ]
             } = stmt
    end

    test "parentheses around or before and" do
      assert {:ok, %{block: %{stmts: [stmt]}}} = Parser.parse("return (a or b) and c")

      # Should parse as: (a or b) and c
      assert %{
               values: [
                 %Expr.BinOp{
                   op: :and,
                   left: %Expr.BinOp{op: :or},
                   right: %Expr.Var{name: "c"}
                 }
               ]
             } = stmt
    end

    test "nested parentheses" do
      assert {:ok, %{block: %{stmts: [stmt]}}} = Parser.parse("return ((1 + 2) * 3) + 4")

      # Should parse as: ((1 + 2) * 3) + 4
      assert %{
               values: [
                 %Expr.BinOp{
                   op: :add,
                   left: %Expr.BinOp{
                     op: :mul,
                     left: %Expr.BinOp{op: :add},
                     right: %Expr.Number{value: 3}
                   },
                   right: %Expr.Number{value: 4}
                 }
               ]
             } = stmt
    end
  end
end
defmodule Lua.Parser.RecoveryTest do
  use ExUnit.Case, async: true

  alias Lua.Parser.Recovery
  alias Lua.Parser.Error

  # Tokens in these fixtures follow the lexer's shape:
  #   {:keyword, atom, pos} | {:delimiter, atom, pos} | {:operator, atom, pos}
  #   | {:identifier, binary, pos} | {:number, n, pos} | {:eof, pos}
  # Recovery returns {:recovered, remaining_tokens, errors} or {:failed, errors}.

  describe "recover_at_statement/2" do
    test "recovers at semicolon" do
      tokens = [
        {:delimiter, :semicolon, %{line: 1, column: 5}},
        {:keyword, :end, %{line: 1, column: 7}},
        {:eof, %{line: 1, column: 10}}
      ]

      error = Error.new(:unexpected_token, "Unexpected token", %{line: 1, column: 1})

      assert {:recovered, rest, [^error]} = Recovery.recover_at_statement(tokens, error)
      assert [{:delimiter, :semicolon, _} | _] = rest
    end

    test "recovers at end keyword" do
      tokens = [
        {:keyword, :end, %{line: 1, column: 1}},
        {:eof, %{line: 1, column: 4}}
      ]

      error = Error.new(:unexpected_token, "Test error", %{line: 1, column: 1})

      assert {:recovered, rest, [^error]} = Recovery.recover_at_statement(tokens, error)
      assert [{:keyword, :end, _} | _] = rest
    end

    test "recovers at statement keywords" do
      keywords = [:if, :while, :for, :function, :local, :do, :repeat]

      for kw <- keywords do
        tokens = [
          {:keyword, kw, %{line: 1, column: 1}},
          {:eof, %{line: 1, column: 10}}
        ]

        error = Error.new(:unexpected_token, "Test", %{line: 1, column: 1})
        assert {:recovered, _, [^error]} = Recovery.recover_at_statement(tokens, error)
      end
    end

    test "recovers when only EOF remains (EOF is a boundary)" do
      tokens = [
        {:identifier, "x", %{line: 1, column: 1}},
        {:operator, :assign, %{line: 1, column: 3}},
        {:eof, %{line: 1, column: 5}}
      ]

      error = Error.new(:unexpected_token, "Test", %{line: 1, column: 1})

      # EOF is actually a statement boundary, so this recovers
      assert {:recovered, rest, [^error]} = Recovery.recover_at_statement(tokens, error)
      assert length(rest) >= 1
    end

    test "recovers at EOF" do
      tokens = [{:eof, %{line: 1, column: 1}}]
      error = Error.new(:unexpected_token, "Test", %{line: 1, column: 1})

      assert {:recovered, rest, [^error]} = Recovery.recover_at_statement(tokens, error)
      assert [{:eof, _}] = rest
    end

    test "fails when no boundary found" do
      tokens = [
        {:identifier, "x", %{line: 1, column: 1}},
        {:operator, :assign, %{line: 1, column: 3}},
        {:number, 42, %{line: 1, column: 5}}
      ]

      error = Error.new(:unexpected_token, "Test", %{line: 1, column: 1})

      assert {:failed, [^error]} = Recovery.recover_at_statement(tokens, error)
    end
  end

  describe "recover_unclosed_delimiter/3" do
    test "finds closing parenthesis" do
      tokens = [
        {:delimiter, :rparen, %{line: 1, column: 10}},
        {:eof, %{line: 1, column: 11}}
      ]

      error = Error.new(:unclosed_delimiter, "Unclosed (", %{line: 1, column: 1})

      assert {:recovered, rest, [^error]} =
               Recovery.recover_unclosed_delimiter(tokens, :lparen, error)

      # The closing delimiter itself is consumed; only what follows remains.
      assert [{:eof, _}] = rest
    end

    test "finds closing bracket" do
      tokens = [
        {:delimiter, :rbracket, %{line: 1, column: 10}},
        {:eof, %{line: 1, column: 11}}
      ]

      error = Error.new(:unclosed_delimiter, "Unclosed [", %{line: 1, column: 1})

      assert {:recovered, rest, [^error]} =
               Recovery.recover_unclosed_delimiter(tokens, :lbracket, error)

      assert [{:eof, _}] = rest
    end

    test "finds closing brace" do
      tokens = [
        {:delimiter, :rbrace, %{line: 1, column: 10}},
        {:eof, %{line: 1, column: 11}}
      ]

      error = Error.new(:unclosed_delimiter, "Unclosed {", %{line: 1, column: 1})

      assert {:recovered, rest, [^error]} =
               Recovery.recover_unclosed_delimiter(tokens, :lbrace, error)

      assert [{:eof, _}] = rest
    end

    test "handles nested delimiters" do
      tokens = [
        {:delimiter, :lparen, %{line: 1, column: 2}},
        {:delimiter, :rparen, %{line: 1, column: 3}},
        {:delimiter, :rparen, %{line: 1, column: 4}},
        {:eof, %{line: 1, column: 5}}
      ]

      error = Error.new(:unclosed_delimiter, "Test", %{line: 1, column: 1})

      assert {:recovered, rest, [^error]} =
               Recovery.recover_unclosed_delimiter(tokens, :lparen, error)

      assert [{:eof, _}] = rest
    end

    test "falls back to statement boundary if delimiter not found" do
      tokens = [
        {:keyword, :end, %{line: 1, column: 5}},
        {:eof, %{line: 1, column: 8}}
      ]

      error = Error.new(:unclosed_delimiter, "Test", %{line: 1, column: 1})

      assert {:recovered, rest, [^error]} =
               Recovery.recover_unclosed_delimiter(tokens, :lparen, error)

      assert [{:keyword, :end, _} | _] = rest
    end

    test "finds closing end keyword at depth 1" do
      tokens = [
        {:keyword, :end, %{line: 1, column: 10}},
        {:eof, %{line: 1, column: 13}}
      ]

      error = Error.new(:unclosed_delimiter, "Unclosed function", %{line: 1, column: 1})

      assert {:recovered, rest, [^error]} =
               Recovery.recover_unclosed_delimiter(tokens, :function, error)

      assert [{:eof, _}] = rest
    end

    test "handles nested keywords with end" do
      # Simulating: function ... if ... end end
      tokens = [
        {:keyword, :if, %{line: 1, column: 2}},
        {:keyword, :end, %{line: 1, column: 7}},
        {:keyword, :end, %{line: 1, column: 11}},
        {:eof, %{line: 1, column: 14}}
      ]

      error = Error.new(:unclosed_delimiter, "Unclosed function", %{line: 1, column: 1})

      assert {:recovered, rest, [^error]} =
               Recovery.recover_unclosed_delimiter(tokens, :function, error)

      assert [{:eof, _}] = rest
    end

    test "skips non-delimiter tokens when searching" do
      tokens = [
        {:identifier, "x", %{line: 1, column: 2}},
        {:operator, :add, %{line: 1, column: 4}},
        {:number, 42, %{line: 1, column: 6}},
        {:delimiter, :rparen, %{line: 1, column: 8}},
        {:eof, %{line: 1, column: 9}}
      ]

      error = Error.new(:unclosed_delimiter, "Test", %{line: 1, column: 1})

      assert {:recovered, rest, [^error]} =
               Recovery.recover_unclosed_delimiter(tokens, :lparen, error)

      assert [{:eof, _}] = rest
    end

    test "handles empty token list" do
      tokens = []

      error = Error.new(:unclosed_delimiter, "Test", %{line: 1, column: 1})

      assert {:failed, [^error]} = Recovery.recover_unclosed_delimiter(tokens, :lparen, error)
    end

    test "uses closing_delimiter with catch-all for keywords" do
      # Test that non-standard delimiters fall through to :end
      tokens = [
        {:keyword, :end, %{line: 1, column: 10}},
        {:eof, %{line: 1, column: 13}}
      ]

      error = Error.new(:unclosed_delimiter, "Unclosed if", %{line: 1, column: 1})

      assert {:recovered, rest, [^error]} =
               Recovery.recover_unclosed_delimiter(tokens, :if, error)

      assert [{:eof, _}] = rest
    end
  end

  describe "recover_missing_keyword/3" do
    test "finds the missing keyword" do
      tokens = [
        {:identifier, "x", %{line: 1, column: 1}},
        {:keyword, :then, %{line: 1, column: 3}},
        {:keyword, :end, %{line: 1, column: 8}}
      ]

      error = Error.new(:expected_token, "Expected then", %{line: 1, column: 1})

      assert {:recovered, rest, [^error]} = Recovery.recover_missing_keyword(tokens, :then, error)
      assert [{:keyword, :then, _} | _] = rest
    end

    test "falls back to statement boundary if keyword not found" do
      tokens = [
        {:keyword, :end, %{line: 1, column: 1}},
        {:eof, %{line: 1, column: 4}}
      ]

      error = Error.new(:expected_token, "Expected then", %{line: 1, column: 1})

      assert {:recovered, rest, [^error]} = Recovery.recover_missing_keyword(tokens, :then, error)
      assert [{:keyword, :end, _} | _] = rest
    end

    test "handles empty token list" do
      tokens = []

      error = Error.new(:expected_token, "Expected then", %{line: 1, column: 1})

      assert {:failed, [^error]} = Recovery.recover_missing_keyword(tokens, :then, error)
    end

    test "handles EOF when searching for keyword" do
      tokens = [{:eof, %{line: 1, column: 1}}]

      error = Error.new(:expected_token, "Expected then", %{line: 1, column: 1})

      # EOF is a statement boundary, so it should recover
      assert {:recovered, rest, [^error]} = Recovery.recover_missing_keyword(tokens, :then, error)
      assert [{:eof, _}] = rest
    end
  end

  describe "skip_to_statement/1" do
    test "skips to statement boundary" do
      tokens = [
        {:identifier, "x", %{line: 1, column: 1}},
        {:operator, :assign, %{line: 1, column: 3}},
        {:keyword, :end, %{line: 1, column: 5}},
        {:eof, %{line: 1, column: 8}}
      ]

      rest = Recovery.skip_to_statement(tokens)
      assert [{:keyword, :end, _} | _] = rest
    end

    test "returns empty list if no boundary found" do
      tokens = [
        {:identifier, "x", %{line: 1, column: 1}},
        {:operator, :assign, %{line: 1, column: 3}}
      ]

      assert [] = Recovery.skip_to_statement(tokens)
    end

    test "returns tokens if already at boundary" do
      tokens = [
        {:keyword, :end, %{line: 1, column: 1}},
        {:eof, %{line: 1, column: 4}}
      ]

      # Pin (^tokens) asserts the list is returned unchanged.
      assert ^tokens = Recovery.skip_to_statement(tokens)
    end
  end

  describe "is_statement_boundary?/1" do
    test "recognizes delimiters" do
      assert Recovery.is_statement_boundary?({:delimiter, :semicolon, %{line: 1, column: 1}})
    end

    test "recognizes block terminators" do
      terminators = [:end, :else, :elseif, :until]

      for term <- terminators do
        assert Recovery.is_statement_boundary?({:keyword, term, %{line: 1, column: 1}})
      end
    end

    test "recognizes statement starters" do
      starters = [:if, :while, :for, :function, :local, :do, :repeat]

      for starter <- starters do
        assert Recovery.is_statement_boundary?({:keyword, starter, %{line: 1, column: 1}})
      end
    end

    test "recognizes EOF" do
      assert Recovery.is_statement_boundary?({:eof, %{line: 1, column: 1}})
    end

    test "rejects non-boundaries" do
      refute Recovery.is_statement_boundary?({:identifier, "x", %{line: 1, column: 1}})
      refute Recovery.is_statement_boundary?({:number, 42, %{line: 1, column: 1}})
      refute Recovery.is_statement_boundary?({:operator, :add, %{line: 1, column: 1}})
    end
  end

  describe "DelimiterStack" do
    alias Recovery.DelimiterStack

    test "creates empty stack" do
      stack = DelimiterStack.new()
      assert DelimiterStack.empty?(stack)
    end

    test "pushes delimiter" do
      stack = DelimiterStack.new()
      stack = DelimiterStack.push(stack, :lparen, %{line: 1, column: 1})
      refute DelimiterStack.empty?(stack)
    end

    test "pops matching delimiter" do
      stack = DelimiterStack.new()
      stack = DelimiterStack.push(stack, :lparen, %{line: 1, column: 1})

      assert {:ok, stack} = DelimiterStack.pop(stack, :rparen)
      assert DelimiterStack.empty?(stack)
    end

    test "fails on mismatched delimiter" do
      stack = DelimiterStack.new()
      stack = DelimiterStack.push(stack, :lparen, %{line: 1, column: 1})

      # The error tuple carries the opener that was expected to be closed.
      assert {:error, :mismatched, :lparen} = DelimiterStack.pop(stack, :rbracket)
    end

    test "fails on empty stack" do
      stack = DelimiterStack.new()
      assert {:error, :empty} = DelimiterStack.pop(stack, :rparen)
    end

    test "peeks at top delimiter" do
      stack = DelimiterStack.new()
      stack = DelimiterStack.push(stack, :lparen, %{line: 1, column: 1})

      assert {:ok, :lparen, %{line: 1, column: 1}} = DelimiterStack.peek(stack)
    end

    test "peek returns empty on empty stack" do
      stack = DelimiterStack.new()
      assert :empty = DelimiterStack.peek(stack)
    end

    test "handles nested delimiters" do
      stack = DelimiterStack.new()
      stack = DelimiterStack.push(stack, :lparen, %{line: 1, column: 1})
      stack = DelimiterStack.push(stack, :lbracket, %{line: 1, column: 5})
      stack = DelimiterStack.push(stack, :lbrace, %{line: 1, column: 10})

      assert {:ok, stack} = DelimiterStack.pop(stack, :rbrace)
      assert {:ok, stack} = DelimiterStack.pop(stack, :rbracket)
      assert {:ok, stack} = DelimiterStack.pop(stack, :rparen)
      assert DelimiterStack.empty?(stack)
    end

    test "handles keyword delimiters" do
      stack = DelimiterStack.new()
      stack = DelimiterStack.push(stack, :function, %{line: 1, column: 1})

      assert {:ok, stack} = DelimiterStack.pop(stack, :end)
      assert DelimiterStack.empty?(stack)
    end

    test "matches all delimiter pairs" do
      pairs = [
        {:lparen, :rparen},
        {:lbracket, :rbracket},
        {:lbrace, :rbrace},
        {:function, :end},
        {:if, :end},
        {:while, :end},
        {:for, :end},
        {:do, :end}
      ]

      for {open, close} <- pairs do
        stack = DelimiterStack.new()
        stack = DelimiterStack.push(stack, open, %{line: 1, column: 1})
        assert {:ok, _} = DelimiterStack.pop(stack, close)
      end
    end
  end
end
DelimiterStack.push(stack, :lparen, %{line: 1, column: 1}) + + assert {:ok, :lparen, %{line: 1, column: 1}} = DelimiterStack.peek(stack) + end + + test "peek returns empty on empty stack" do + stack = DelimiterStack.new() + assert :empty = DelimiterStack.peek(stack) + end + + test "handles nested delimiters" do + stack = DelimiterStack.new() + stack = DelimiterStack.push(stack, :lparen, %{line: 1, column: 1}) + stack = DelimiterStack.push(stack, :lbracket, %{line: 1, column: 5}) + stack = DelimiterStack.push(stack, :lbrace, %{line: 1, column: 10}) + + assert {:ok, stack} = DelimiterStack.pop(stack, :rbrace) + assert {:ok, stack} = DelimiterStack.pop(stack, :rbracket) + assert {:ok, stack} = DelimiterStack.pop(stack, :rparen) + assert DelimiterStack.empty?(stack) + end + + test "handles keyword delimiters" do + stack = DelimiterStack.new() + stack = DelimiterStack.push(stack, :function, %{line: 1, column: 1}) + + assert {:ok, stack} = DelimiterStack.pop(stack, :end) + assert DelimiterStack.empty?(stack) + end + + test "matches all delimiter pairs" do + pairs = [ + {:lparen, :rparen}, + {:lbracket, :rbracket}, + {:lbrace, :rbrace}, + {:function, :end}, + {:if, :end}, + {:while, :end}, + {:for, :end}, + {:do, :end} + ] + + for {open, close} <- pairs do + stack = DelimiterStack.new() + stack = DelimiterStack.push(stack, open, %{line: 1, column: 1}) + assert {:ok, _} = DelimiterStack.pop(stack, close) + end + end + end +end diff --git a/test/lua/parser/statement_test.exs b/test/lua/parser/statement_test.exs new file mode 100644 index 0000000..b631193 --- /dev/null +++ b/test/lua/parser/statement_test.exs @@ -0,0 +1,591 @@ +defmodule Lua.Parser.StatementTest do + use ExUnit.Case, async: true + alias Lua.Parser + alias Lua.AST.{Statement, Expr} + + describe "local variable declarations" do + test "parses local without initialization" do + assert {:ok, chunk} = Parser.parse("local x") + assert %{block: %{stmts: [%Statement.Local{names: ["x"], values: []}]}} = chunk + end + 
+ test "parses local with single initialization" do + assert {:ok, chunk} = Parser.parse("local x = 42") + + assert %{ + block: %{ + stmts: [%Statement.Local{names: ["x"], values: [%Expr.Number{value: 42}]}] + } + } = chunk + end + + test "parses local with multiple variables" do + assert {:ok, chunk} = Parser.parse("local x, y, z = 1, 2, 3") + + assert %{ + block: %{ + stmts: [ + %Statement.Local{ + names: ["x", "y", "z"], + values: [ + %Expr.Number{value: 1}, + %Expr.Number{value: 2}, + %Expr.Number{value: 3} + ] + } + ] + } + } = chunk + end + + test "parses local function" do + assert {:ok, chunk} = + Parser.parse(""" + local function add(a, b) + return a + b + end + """) + + assert %{block: %{stmts: [%Statement.LocalFunc{name: "add", params: ["a", "b"]}]}} = chunk + end + end + + describe "assignments" do + test "parses simple assignment" do + assert {:ok, chunk} = Parser.parse("x = 42") + + assert %{ + block: %{ + stmts: [ + %Statement.Assign{ + targets: [%Expr.Var{name: "x"}], + values: [%Expr.Number{value: 42}] + } + ] + } + } = chunk + end + + test "parses multiple assignment" do + assert {:ok, chunk} = Parser.parse("x, y = 1, 2") + + assert %{ + block: %{ + stmts: [ + %Statement.Assign{ + targets: [%Expr.Var{name: "x"}, %Expr.Var{name: "y"}], + values: [%Expr.Number{value: 1}, %Expr.Number{value: 2}] + } + ] + } + } = chunk + end + + test "parses table field assignment" do + assert {:ok, chunk} = Parser.parse("t.field = 42") + + assert %{ + block: %{ + stmts: [ + %Statement.Assign{ + targets: [%Expr.Property{}], + values: [%Expr.Number{value: 42}] + } + ] + } + } = chunk + end + + test "parses indexed assignment" do + assert {:ok, chunk} = Parser.parse("t[1] = 42") + + assert %{ + block: %{ + stmts: [ + %Statement.Assign{ + targets: [%Expr.Index{}], + values: [%Expr.Number{value: 42}] + } + ] + } + } = chunk + end + end + + describe "function calls as statements" do + test "parses function call statement" do + assert {:ok, chunk} = 
Parser.parse("print(42)") + + assert %{ + block: %{ + stmts: [%Statement.CallStmt{call: %Expr.Call{func: %Expr.Var{name: "print"}}}] + } + } = chunk + end + + test "parses method call statement" do + assert {:ok, chunk} = Parser.parse("obj:method()") + + assert %{ + block: %{stmts: [%Statement.CallStmt{call: %Expr.MethodCall{method: "method"}}]} + } = chunk + end + end + + describe "if statements" do + test "parses simple if statement" do + assert {:ok, chunk} = + Parser.parse(""" + if x > 0 then + return x + end + """) + + assert %{ + block: %{ + stmts: [ + %Statement.If{ + condition: %Expr.BinOp{op: :gt}, + then_block: %{stmts: [%Statement.Return{}]}, + elseifs: [], + else_block: nil + } + ] + } + } = chunk + end + + test "parses if with else" do + assert {:ok, chunk} = + Parser.parse(""" + if x > 0 then + return x + else + return 0 + end + """) + + assert %{ + block: %{ + stmts: [ + %Statement.If{ + condition: %Expr.BinOp{op: :gt}, + then_block: %{stmts: [%Statement.Return{}]}, + elseifs: [], + else_block: %{stmts: [%Statement.Return{}]} + } + ] + } + } = chunk + end + + test "parses if with elseif" do + assert {:ok, chunk} = + Parser.parse(""" + if x > 0 then + return 1 + elseif x < 0 then + return -1 + else + return 0 + end + """) + + assert %{ + block: %{ + stmts: [ + %Statement.If{ + condition: %Expr.BinOp{op: :gt}, + then_block: %{stmts: [%Statement.Return{}]}, + elseifs: [{%Expr.BinOp{op: :lt}, %{stmts: [%Statement.Return{}]}}], + else_block: %{stmts: [%Statement.Return{}]} + } + ] + } + } = chunk + end + + test "parses if with multiple elseifs" do + assert {:ok, chunk} = + Parser.parse(""" + if x == 1 then + return "one" + elseif x == 2 then + return "two" + elseif x == 3 then + return "three" + end + """) + + assert %{ + block: %{ + stmts: [ + %Statement.If{ + elseifs: [_, _] + } + ] + } + } = chunk + end + end + + describe "while loops" do + test "parses while loop" do + assert {:ok, chunk} = + Parser.parse(""" + while x > 0 do + x = x - 1 + end + """) 
+ + assert %{ + block: %{ + stmts: [ + %Statement.While{ + condition: %Expr.BinOp{op: :gt}, + body: %{stmts: [%Statement.Assign{}]} + } + ] + } + } = chunk + end + end + + describe "repeat-until loops" do + test "parses repeat-until loop" do + assert {:ok, chunk} = + Parser.parse(""" + repeat + x = x - 1 + until x == 0 + """) + + assert %{ + block: %{ + stmts: [ + %Statement.Repeat{ + body: %{stmts: [%Statement.Assign{}]}, + condition: %Expr.BinOp{op: :eq} + } + ] + } + } = chunk + end + end + + describe "for loops" do + test "parses numeric for loop" do + assert {:ok, chunk} = + Parser.parse(""" + for i = 1, 10 do + print(i) + end + """) + + assert %{ + block: %{ + stmts: [ + %Statement.ForNum{ + var: "i", + start: %Expr.Number{value: 1}, + limit: %Expr.Number{value: 10}, + step: nil, + body: %{stmts: [%Statement.CallStmt{}]} + } + ] + } + } = chunk + end + + test "parses numeric for loop with step" do + assert {:ok, chunk} = + Parser.parse(""" + for i = 1, 10, 2 do + print(i) + end + """) + + assert %{ + block: %{ + stmts: [ + %Statement.ForNum{ + var: "i", + start: %Expr.Number{value: 1}, + limit: %Expr.Number{value: 10}, + step: %Expr.Number{value: 2} + } + ] + } + } = chunk + end + + test "parses generic for loop" do + assert {:ok, chunk} = + Parser.parse(""" + for k, v in pairs(t) do + print(k, v) + end + """) + + assert %{ + block: %{ + stmts: [ + %Statement.ForIn{ + vars: ["k", "v"], + iterators: [%Expr.Call{}], + body: %{stmts: [%Statement.CallStmt{}]} + } + ] + } + } = chunk + end + + test "parses generic for loop with single variable" do + assert {:ok, chunk} = + Parser.parse(""" + for line in io.lines() do + print(line) + end + """) + + assert %{ + block: %{ + stmts: [ + %Statement.ForIn{ + vars: ["line"], + iterators: [%Expr.Call{func: %Expr.Property{}}] + } + ] + } + } = chunk + end + end + + describe "function declarations" do + test "parses simple function declaration" do + assert {:ok, chunk} = + Parser.parse(""" + function add(a, b) + return a + b 
+ end + """) + + assert %{ + block: %{ + stmts: [ + %Statement.FuncDecl{ + name: ["add"], + params: ["a", "b"], + is_method: false, + body: %{stmts: [%Statement.Return{}]} + } + ] + } + } = chunk + end + + test "parses function declaration with dot notation" do + assert {:ok, chunk} = + Parser.parse(""" + function math.abs(x) + return x + end + """) + + assert %{ + block: %{ + stmts: [ + %Statement.FuncDecl{ + name: ["math", "abs"], + is_method: false + } + ] + } + } = chunk + end + + test "parses method declaration" do + assert {:ok, chunk} = + Parser.parse(""" + function obj:method(x) + return self.field + x + end + """) + + assert %{ + block: %{ + stmts: [ + %Statement.FuncDecl{ + name: ["obj", "method"], + is_method: true, + params: ["x"] + } + ] + } + } = chunk + end + + test "parses nested function name" do + assert {:ok, chunk} = + Parser.parse(""" + function a.b.c.d() + end + """) + + assert %{ + block: %{ + stmts: [ + %Statement.FuncDecl{ + name: ["a", "b", "c", "d"], + is_method: false + } + ] + } + } = chunk + end + end + + describe "do blocks" do + test "parses do block" do + assert {:ok, chunk} = + Parser.parse(""" + do + local x = 42 + print(x) + end + """) + + assert %{ + block: %{ + stmts: [ + %Statement.Do{ + body: %{stmts: [%Statement.Local{}, %Statement.CallStmt{}]} + } + ] + } + } = chunk + end + end + + describe "break and goto" do + test "parses break" do + assert {:ok, chunk} = Parser.parse("break") + assert %{block: %{stmts: [%Statement.Break{}]}} = chunk + end + + test "parses goto" do + assert {:ok, chunk} = Parser.parse("goto finish") + assert %{block: %{stmts: [%Statement.Goto{label: "finish"}]}} = chunk + end + + test "parses label" do + assert {:ok, chunk} = Parser.parse("::finish::") + assert %{block: %{stmts: [%Statement.Label{name: "finish"}]}} = chunk + end + + test "parses goto and label together" do + assert {:ok, chunk} = + Parser.parse(""" + goto skip + print("skipped") + ::skip:: + print("not skipped") + """) + + assert %{ + 
block: %{ + stmts: [ + %Statement.Goto{label: "skip"}, + %Statement.CallStmt{}, + %Statement.Label{name: "skip"}, + %Statement.CallStmt{} + ] + } + } = chunk + end + end + + describe "return statements" do + test "parses return with no values" do + assert {:ok, chunk} = Parser.parse("return") + assert %{block: %{stmts: [%Statement.Return{values: []}]}} = chunk + end + + test "parses return with single value" do + assert {:ok, chunk} = Parser.parse("return 42") + + assert %{ + block: %{stmts: [%Statement.Return{values: [%Expr.Number{value: 42}]}]} + } = chunk + end + + test "parses return with multiple values" do + assert {:ok, chunk} = Parser.parse("return 1, 2, 3") + + assert %{ + block: %{ + stmts: [ + %Statement.Return{ + values: [ + %Expr.Number{value: 1}, + %Expr.Number{value: 2}, + %Expr.Number{value: 3} + ] + } + ] + } + } = chunk + end + end + + describe "complex programs" do + test "parses factorial function" do + assert {:ok, chunk} = + Parser.parse(""" + function factorial(n) + if n <= 1 then + return 1 + else + return n * factorial(n - 1) + end + end + """) + + assert %{block: %{stmts: [%Statement.FuncDecl{name: ["factorial"]}]}} = chunk + end + + test "parses multiple statements" do + assert {:ok, chunk} = + Parser.parse(""" + local x = 10 + local y = 20 + local sum = x + y + print(sum) + """) + + assert %{ + block: %{ + stmts: [ + %Statement.Local{names: ["x"]}, + %Statement.Local{names: ["y"]}, + %Statement.Local{names: ["sum"]}, + %Statement.CallStmt{} + ] + } + } = chunk + end + + test "parses nested control structures" do + assert {:ok, _chunk} = + Parser.parse(""" + for i = 1, 10 do + if i % 2 == 0 then + print("even") + else + print("odd") + end + end + """) + end + end +end diff --git a/test/lua/runtime_exception_test.exs b/test/lua/runtime_exception_test.exs new file mode 100644 index 0000000..d6162e3 --- /dev/null +++ b/test/lua/runtime_exception_test.exs @@ -0,0 +1,434 @@ +defmodule Lua.RuntimeExceptionTest do + use ExUnit.Case, async: true 
alias Lua.RuntimeException

  describe "exception/1 with {:lua_error, error, state}" do
    test "formats simple lua error with stacktrace" do
      lua = Lua.new()

      # 1 / 0 is inf in Lua (no error); the runtime error comes from
      # multiplying the result by a string
      assert_raise RuntimeException, fn ->
        Lua.eval!(lua, "return 1 / 0 * 'string'")
      end
    end

    test "includes formatted error message and stacktrace" do
      lua = Lua.new()

      exception =
        try do
          Lua.eval!(lua, "error('custom error message')")
        rescue
          e in RuntimeException -> e
        end

      assert exception.message =~ "Lua runtime error:"
      assert exception.message =~ "custom error message"
      # The luerl state and the original error term are preserved on the struct
      assert exception.state != nil
      assert exception.original != nil
    end

    test "handles parse errors" do
      lua = Lua.new()

      exception =
        try do
          # Create a parse error by using illegal token
          Lua.eval!(lua, "local x = \x01")
        rescue
          e in Lua.CompilerException -> e
        end

      # Parse errors are CompilerException, not RuntimeException
      assert exception.__struct__ == Lua.CompilerException
    end

    test "handles badarith errors" do
      lua = Lua.new()

      exception =
        try do
          Lua.eval!(lua, "return 'string' + 5")
        rescue
          e in RuntimeException -> e
        end

      assert exception.message =~ "Lua runtime error:"
      assert exception.state != nil
    end

    test "handles illegal index errors" do
      lua = Lua.new()

      exception =
        try do
          # Attempt to index a non-table value
          Lua.eval!(lua, "local x = 5; return x.foo")
        rescue
          e in RuntimeException -> e
        end

      assert exception.message =~ "Lua runtime error:"
      assert exception.state != nil
    end
  end

  describe "exception/1 with {:api_error, details, state}" do
    test "creates exception with api error message" do
      state = :luerl.init()
      details = "invalid function call"

      exception = RuntimeException.exception({:api_error, details, state})

      assert exception.message == "Lua API error: invalid function call"
      assert exception.original == details
      assert exception.state == state
    end

    test "handles complex api error details" do
      state = :luerl.init()
      details = "function returned invalid type: expected table, got nil"

      exception = RuntimeException.exception({:api_error, details, state})

      assert exception.message ==
               "Lua API error: function returned invalid type: expected table, got nil"

      assert exception.original == details
      assert exception.state == state
    end
  end

  describe "exception/1 with keyword list [scope:, function:, message:]" do
    test "formats error with empty scope" do
      exception =
        RuntimeException.exception(
          scope: [],
          function: "my_function",
          message: "invalid arguments"
        )

      assert exception.message == "Lua runtime error: my_function() failed, invalid arguments"

      assert exception.original == [
               scope: [],
               function: "my_function",
               message: "invalid arguments"
             ]

      assert exception.state == nil
    end

    test "formats error with single scope element" do
      exception =
        RuntimeException.exception(
          scope: ["math"],
          function: "sqrt",
          message: "negative number not allowed"
        )

      # Scope elements are dot-joined in front of the function name
      assert exception.message ==
               "Lua runtime error: math.sqrt() failed, negative number not allowed"

      assert exception.original == [
               scope: ["math"],
               function: "sqrt",
               message: "negative number not allowed"
             ]

      assert exception.state == nil
    end

    test "formats error with multiple scope elements" do
      exception =
        RuntimeException.exception(
          scope: ["my", "module", "nested"],
          function: "process",
          message: "data validation failed"
        )

      assert exception.message ==
               "Lua runtime error: my.module.nested.process() failed, data validation failed"

      assert exception.original == [
               scope: ["my", "module", "nested"],
               function: "process",
               message: "data validation failed"
             ]

      assert exception.state == nil
    end

    # All three keys are mandatory; a missing one surfaces as KeyError
    # (presumably from Keyword.fetch!/2 inside exception/1).
    test "raises when scope key is missing" do
      assert_raise KeyError, fn ->
        RuntimeException.exception(
          function: "my_function",
          message: "invalid arguments"
        )
      end
    end

    test "raises when function key is missing" do
      assert_raise KeyError, fn ->
        RuntimeException.exception(
          scope: [],
          message: "invalid arguments"
        )
      end
    end

    test "raises when message key is missing" do
      assert_raise KeyError, fn ->
        RuntimeException.exception(
          scope: [],
          function: "my_function"
        )
      end
    end
  end

  describe "exception/1 with binary string" do
    test "formats simple binary error" do
      exception = RuntimeException.exception("something went wrong")

      assert exception.message == "Lua runtime error: something went wrong"
      assert exception.original == nil
      assert exception.state == nil
    end

    test "trims whitespace from binary error" do
      exception = RuntimeException.exception("  error with spaces  \n")

      assert exception.message == "Lua runtime error: error with spaces"
      assert exception.original == nil
      assert exception.state == nil
    end

    test "handles empty binary string" do
      exception = RuntimeException.exception("")

      # Trailing space is kept even for an empty message
      assert exception.message == "Lua runtime error: "
      assert exception.original == nil
      assert exception.state == nil
    end

    test "handles multi-line binary error" do
      error_message = """
      multi-line error
      with details
      """

      exception = RuntimeException.exception(error_message)

      assert exception.message == "Lua runtime error: multi-line error\nwith details"
      assert exception.original == nil
      assert exception.state == nil
    end
  end

  describe "exception/1 with generic error (fallback clause)" do
    test "handles built-in exception types" do
      error = ArgumentError.exception("invalid argument")

      exception = RuntimeException.exception(error)

      # Exception structs are unwrapped via their own message
      assert exception.message == "Lua runtime error: invalid argument"
      assert exception.original == error
      assert exception.state == nil
    end

    test "handles RuntimeError exception" do
      error = RuntimeError.exception("runtime failure")

      exception = RuntimeException.exception(error)

      assert exception.message == "Lua runtime error: runtime failure"
      assert exception.original == error
      assert exception.state == nil
    end

    test "handles KeyError exception" do
      error = KeyError.exception(key: :missing, term: %{})

      exception = RuntimeException.exception(error)

      assert exception.message =~ "Lua runtime error:"
      assert exception.message =~ "key :missing not found"
      assert exception.original == error
      assert exception.state == nil
    end

    test "handles non-exception atom" do
      exception = RuntimeException.exception(:some_error)

      # Non-exception terms are rendered with inspect/1
      assert exception.message == "Lua runtime error: :some_error"
      assert exception.original == :some_error
      assert exception.state == nil
    end

    test "handles non-exception integer" do
      exception = RuntimeException.exception(42)

      assert exception.message == "Lua runtime error: 42"
      assert exception.original == 42
      assert exception.state == nil
    end

    test "handles non-exception tuple" do
      error = {:error, :not_found}

      exception = RuntimeException.exception(error)

      assert exception.message == "Lua runtime error: {:error, :not_found}"
      assert exception.original == error
      assert exception.state == nil
    end

    test "handles non-exception map" do
      error = %{code: 404, message: "not found"}

      exception = RuntimeException.exception(error)

      assert exception.message == "Lua runtime error: %{code: 404, message: \"not found\"}"
      assert exception.original == error
      assert exception.state == nil
    end

    test "handles non-exception list (not keyword list)" do
      # A list with integer keys won't match the keyword list pattern
      # because Keyword.fetch!/2 expects atom keys
      error = [1, 2, 3]

      # This will raise KeyError because the keyword list clause
      # will try to match first and fail on Keyword.fetch!/2
      assert_raise KeyError, fn ->
        RuntimeException.exception(error)
      end
    end

    test "handles UndefinedFunctionError" do
      error = UndefinedFunctionError.exception(module: MyModule, function: :my_func, arity: 2)

      exception = RuntimeException.exception(error)

      assert exception.message =~ "Lua runtime error:"
      assert exception.message =~ "MyModule.my_func/2"
      assert exception.original == error
      assert exception.state == nil
    end
  end

  describe "format_function/2 (private function tested via keyword list exception)" do
    test "formats function with empty scope" do
      exception =
        RuntimeException.exception(
          scope: [],
          function: "test",
          message: "error"
        )

      assert exception.message =~ "test() failed"
    end

    test "formats function with single element scope" do
      exception =
        RuntimeException.exception(
          scope: ["module"],
          function: "func",
          message: "error"
        )

      assert exception.message =~ "module.func() failed"
    end

    test "formats function with nested scope" do
      exception =
        RuntimeException.exception(
          scope: ["a", "b", "c"],
          function: "method",
          message: "error"
        )

      assert exception.message =~ "a.b.c.method() failed"
    end
  end

  describe "exception message format" do
    test "RuntimeException implements Exception protocol" do
      exception = RuntimeException.exception("test error")

      assert Exception.message(exception) == "Lua runtime error: test error"
    end

    test "can be raised with raise/2" do
      assert_raise RuntimeException, "Lua runtime error: test", fn ->
        raise RuntimeException, "test"
      end
    end

    test "can be raised with keyword list" do
      assert_raise RuntimeException, fn ->
        raise RuntimeException,
          scope: ["my", "module"],
          function: "test",
          message: "failed"
      end
    end

    test "preserves original error information" do
      original = {:error, :custom_reason}
      exception = RuntimeException.exception(original)

      assert exception.original == original
      assert exception.state == nil
    end
  end

  describe "integration with Lua module" do
    test "RuntimeException is raised for runtime errors" do
      lua = Lua.new()

      assert_raise RuntimeException, fn ->
        Lua.eval!(lua, "error('test error')")
      end
    end

    test "RuntimeException is raised for type errors" do
      lua = Lua.new()

      assert_raise RuntimeException, fn ->
        Lua.eval!(lua, "return 'string' + 5")
      end
    end

    test "RuntimeException is raised for sandboxed functions" do
      # Calling a function that was explicitly sandboxed must error
      lua = Lua.new(sandboxed: [[:os, :exit]])

      assert_raise RuntimeException, fn ->
        Lua.eval!(lua, "os.exit()")
      end
    end

    test "RuntimeException is raised for empty keys in set!" do
      lua = Lua.new()

      assert_raise RuntimeException, "Lua runtime error: Lua.set!/3 cannot have empty keys", fn ->
        Lua.set!(lua, [], "value")
      end
    end

    test "RuntimeException is raised when deflua returns non-encoded data" do
      lua = Lua.new()

      lua =
        Lua.set!(lua, [:test_func], fn _args ->
          # Return non-encoded atom (not a valid Lua value)
          [:invalid_atom]
        end)

      assert_raise RuntimeException, fn ->
        Lua.eval!(lua, "test_func()")
      end
    end
  end
end