From 3b7bb806dd277f3a90f1b080461876d1c5cb6329 Mon Sep 17 00:00:00 2001 From: Christian Merten Date: Sun, 23 Nov 2025 17:21:18 +0100 Subject: [PATCH] wsorry with deps --- NodeGraph/Utils/CollectDeps.lean | 5 +- NodeGraph/WeightedSorry.lean | 1 + NodeGraph/WeightedSorry/Basic.lean | 90 ++++++++++++------------------ NodeGraph/WeightedSorry/Ext.lean | 53 ++++++++++++++++++ Test.lean | 11 +++- 5 files changed, 103 insertions(+), 57 deletions(-) create mode 100644 NodeGraph/WeightedSorry/Ext.lean diff --git a/NodeGraph/Utils/CollectDeps.lean b/NodeGraph/Utils/CollectDeps.lean index 04cf06c..7a85537 100644 --- a/NodeGraph/Utils/CollectDeps.lean +++ b/NodeGraph/Utils/CollectDeps.lean @@ -1,4 +1,4 @@ -import Lean +import NodeGraph.WeightedSorry.Ext open Lean @@ -17,6 +17,7 @@ private partial def collectDepsAux (c : Name) : let env ← getEnv let some const := env.find? c | throwError "Failed to find {c} in environment" + let sorriedDeps ← getWSorriedDependencies c let usedConsts ← const.getUsedConstantsAsSet.toArray.filterM fun t => Bool.not <$> do if t matches .str _ "inj" then return true if t matches .str _ "noConfusionType" then return true @@ -27,7 +28,7 @@ private partial def collectDepsAux (c : Name) : <||> isRec t <||> Meta.isMatcher t modify fun s => { visited := s.visited.insert c - usedConsts := s.usedConsts.union <| .ofArray usedConsts + usedConsts := s.usedConsts.union <| .union (.ofArray usedConsts) sorriedDeps } for d in usedConsts do collectDepsAux d diff --git a/NodeGraph/WeightedSorry.lean b/NodeGraph/WeightedSorry.lean index b49ddd4..f2e9caf 100644 --- a/NodeGraph/WeightedSorry.lean +++ b/NodeGraph/WeightedSorry.lean @@ -1 +1,2 @@ import NodeGraph.WeightedSorry.Basic +import NodeGraph.WeightedSorry.Ext diff --git a/NodeGraph/WeightedSorry/Basic.lean b/NodeGraph/WeightedSorry/Basic.lean index 15eb446..c6ac96a 100644 --- a/NodeGraph/WeightedSorry/Basic.lean +++ b/NodeGraph/WeightedSorry/Basic.lean @@ -1,58 +1,42 @@ -import Lean +import NodeGraph.WeightedSorry.Ext open Lean namespace NodeGraph -structure Weight where - val : Nat -deriving ToExpr, BEq, Hashable - -structure WSorry where - weight : Weight - type : Expr - name : Name -deriving BEq, Hashable - -instance Weight.instOfNat (n : Nat) : OfNat Weight n where ofNat := .mk n - -macro:max "wsorry" t:term:max : term => - `((sorry : Weight → _) $t) - -unsafe -def getWSorry (e : Expr) : MetaM (Option WSorry) := do - let (nm, args) := e.getAppFnArgs - unless nm == ``sorryAx do return none - if h : args.size = 4 then - let shouldBeName := args[2] - let shouldBeWeight := args[3] - unless (← Meta.inferType shouldBeName) == .const `Lean.Name [] do return none - unless (← Meta.inferType shouldBeWeight) == .const `NodeGraph.Weight [] do return none - let name ← Meta.evalExpr Name (.const `Lean.Name []) args[2] - let weight := Expr.app (.const `NodeGraph.Weight.val []) args[3] - let weight ← Meta.evalExpr Nat (.const `Nat []) weight - return some ⟨⟨weight⟩, args[0], name⟩ - else - return none - -unsafe -def collectWSorriesInExpr (e : Expr) : MetaM (Std.HashSet WSorry) := Prod.snd <$> go.run {} -where go : StateT (Std.HashSet WSorry) MetaM Unit := Meta.forEachExpr e fun e => do - if let some val ← getWSorry e then modify fun S => S.insert val - -unsafe -def collectWSorriesInConst (e : ConstantInfo) : CoreM (Std.HashSet WSorry) := Meta.MetaM.run' do - let tpWeight ← collectWSorriesInExpr e.type - let valWeight : Std.HashSet (Expr × Nat) ← match e.value? with - | some val => collectWSorriesInExpr val - | none => pure .emptyWithCapacity - return tpWeight.union valWeight - -/-- Returns the weight and whether the constant *does not* use a weighted sorry -/ -unsafe -def collectConstWeight (e : ConstantInfo) : CoreM (Nat × Bool) := do - let wsorries ← collectWSorriesInConst e - let mut out := 0 - for ⟨⟨a⟩, _, _⟩ in wsorries do - out := out + a - return (out, wsorries.isEmpty) +syntax (name := wsorryTerm) "wsorry " num (" [" ident,* "]")? : term +syntax (name := wsorryTactic) "wsorry " num (" [" ident,* "]")? : tactic + +open Lean Elab Tactic + +/-- Elaborator for weighted sorries: Adds a new `WSorry` to the +`wSorryExtension` with `weight` and `decls`. -/ +def elabWSorry (weight : TSyntax `num) (decls : Option (Array (TSyntax `ident))) : + TermElabM Unit := do + let some name := ← Term.getDeclName? | return () + let infos : Array WSorry := ← + match wSorryExtension.find? (← getEnv) name with + | some xs => pure xs + | none => pure #[] + let mut deps := #[] + for decls in decls do for decl in decls do + let _ ← resolveGlobalConst decl + deps := #[decl.getId] ++ deps + let info : WSorry := + { weight := weight.getNat + dependencies := deps } + wSorryExtension.addEntry (← getEnv) (name, #[info] ++ infos) |> setEnv + +elab_rules : term + | `(term| wsorry $weight $[[$[$decls],*]]?) => do + let stx ← `(wsorry $weight $[[$[$decls],*]]?) + Lean.Elab.Term.adaptExpander + (fun stx ↦ do elabWSorry weight decls; `(term| sorry)) + stx none + +elab_rules : tactic + | `(tactic| wsorry $weight $[[$[$decls],*]]?) => do + let stx ← `(wsorry $weight $[[$[$decls],*]]?) + Lean.Elab.Tactic.adaptExpander + (fun stx ↦ do elabWSorry weight decls; `(tactic| sorry)) + stx diff --git a/NodeGraph/WeightedSorry/Ext.lean b/NodeGraph/WeightedSorry/Ext.lean new file mode 100644 index 0000000..1451fa9 --- /dev/null +++ b/NodeGraph/WeightedSorry/Ext.lean @@ -0,0 +1,53 @@ +import Lean +import Batteries.Lean.NameMapAttribute + +namespace NodeGraph + +open Lean Elab Parser Command Meta + +structure WSorry where + weight : Nat + dependencies : Array Name +deriving DecidableEq, Hashable + +instance : ToString WSorry where + toString info := s!"{info.weight} with {info.dependencies}" + +initialize wSorryExtension : NameMapExtension (Array WSorry) ← + registerSimplePersistentEnvExtension { + name := decl_name% + addImportedFn := fun ass => + ass.foldl (init := ∅) fun names as => + as.foldl (init := names) fun names (a, b) => names.insert a b + addEntryFn := fun s n => s.insert n.1 n.2 + toArrayFn := fun es => es.toArray + asyncMode := .sync + } + +/-- Returns all `wsorry`s used in declaration `name`. If no declaration `name` exist, +returns the empty set. -/ +def getWSorriesByName (name : Name) : CoreM (Std.HashSet WSorry) := do + let env ← getEnv + match wSorryExtension.find? env name with + | some infos => pure <| .ofArray infos + | none => pure .emptyWithCapacity + +/-- Returns the union of all dependencies used by a `wsorry` in declaration `name`. +If no declaration `name` exist, returns the empty set. -/ +def getWSorriedDependencies (name : Name) : CoreM (Std.HashSet Name) := do + let depSorries ← getWSorriesByName name + let mut res := .emptyWithCapacity + for info in depSorries do + res := res.union (.ofArray info.dependencies) + return res + +/-- Returns the weight and whether the constant *does not* use a weighted sorry -/ +unsafe +def collectConstWeight (e : ConstantInfo) : CoreM (Nat × Bool) := do + let wsorries ← getWSorriesByName e.name + let mut out := 0 + for wsorry' in wsorries do + out := out + wsorry'.weight + return (out, wsorries.isEmpty) + +end NodeGraph diff --git a/Test.lean b/Test.lean index 473c7b1..bc093da 100644 --- a/Test.lean +++ b/Test.lean @@ -13,10 +13,17 @@ def a : Nat := axiom f : Nat @[node in defs] -noncomputable def b : Nat × wsorry 5 := (a + wsorry 10 + f, wsorry 20) +def aux₁ : Nat := 4 + +@[node in defs] +noncomputable def b : Nat × wsorry 5 [aux₁] := (a + wsorry 10 + f, wsorry 20) + +@[node in defs] +def aux₂ : Nat := 3 @[node in lemmas] -theorem foo : b = (0, wsorry 73) := sorry +theorem foo : b = (0, wsorry 73) := by + wsorry 123 [aux₂, f] #decl_graph #decl_graph from b