diff --git a/lib/arbor/tree.ex b/lib/arbor/tree.ex index db31dcf..2ab71b5 100644 --- a/lib/arbor/tree.ex +++ b/lib/arbor/tree.ex @@ -35,7 +35,12 @@ defmodule Arbor.Tree do defmacro __before_compile__(%{module: definition} = _env) do arbor_opts = Module.get_attribute(definition, :arbor_opts) - {primary_key, primary_key_type, _} = Module.get_attribute(definition, :primary_key) + + {primary_key, primary_key_type} = case Module.get_attribute(definition, :primary_key) do + {primary_key, primary_key_type, _} -> {primary_key, primary_key_type} + _ -> {arbor_opts[:primary_key], arbor_opts[:primary_key_type]} + end + struct_fields = Module.get_attribute(definition, :struct_fields) struct_source = struct_fields[:__meta__].source @@ -80,13 +85,15 @@ defmodule Arbor.Tree do ) end + defp id(struct), do: Map.get(struct, unquote(opts[:foreign_key]), Map.get(struct, :id)) + def parent(struct) do from( t in unquote(definition), where: fragment( unquote("#{opts[:primary_key]} = ?"), - type(^struct.unquote(opts[:foreign_key]), unquote(opts[:foreign_key_type])) + type(^id(struct), unquote(opts[:foreign_key_type])) ) ) end @@ -97,7 +104,7 @@ defmodule Arbor.Tree do where: fragment( unquote("#{opts[:foreign_key]} = ?"), - type(^struct.unquote(opts[:primary_key]), unquote(opts[:foreign_key_type])) + type(^id(struct), unquote(opts[:foreign_key_type])) ) ) end @@ -107,7 +114,7 @@ defmodule Arbor.Tree do t in unquote(definition), where: t.unquote(opts[:primary_key]) != - type(^struct.unquote(opts[:primary_key]), unquote(opts[:primary_key_type])), + type(^id(struct), unquote(opts[:primary_key_type])), where: fragment( unquote("#{opts[:foreign_key]} = ?"), @@ -137,7 +144,7 @@ defmodule Arbor.Tree do SELECT * FROM #{opts[:tree_name]} ) - """), type(^struct.unquote(opts[:primary_key]), unquote(opts[:primary_key_type]))), + """), type(^id(struct), unquote(opts[:primary_key_type]))), on: t.unquote(opts[:primary_key]) == g.unquote(opts[:foreign_key]) end @@ -164,7 +171,7 @@ defmodule Arbor.Tree do ) SELECT #{opts[:primary_key]} FROM #{opts[:tree_name]} """), - type(^struct.unquote(opts[:primary_key]), unquote(opts[:foreign_key_type])), + type(^id(struct), unquote(opts[:foreign_key_type])), type(^depth, :integer) ) )