Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions NodeGraph/Utils/CollectDeps.lean
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import Lean
import NodeGraph.WeightedSorry.Ext

open Lean

Expand All @@ -17,6 +17,7 @@ private partial def collectDepsAux (c : Name) :
let env ← getEnv
let some const := env.find? c
| throwError "Failed to find {c} in environment"
let sorriedDeps ← getWSorriedDependencies c
let usedConsts ← const.getUsedConstantsAsSet.toArray.filterM fun t => Bool.not <$> do
if t matches .str _ "inj" then return true
if t matches .str _ "noConfusionType" then return true
Expand All @@ -27,7 +28,7 @@ private partial def collectDepsAux (c : Name) :
<||> isRec t <||> Meta.isMatcher t
modify fun s => {
visited := s.visited.insert c
usedConsts := s.usedConsts.union <| .ofArray usedConsts
usedConsts := s.usedConsts.union <| .union (.ofArray usedConsts) sorriedDeps
}
for d in usedConsts do collectDepsAux d

Expand Down
1 change: 1 addition & 0 deletions NodeGraph/WeightedSorry.lean
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
import NodeGraph.WeightedSorry.Basic
import NodeGraph.WeightedSorry.Ext
90 changes: 37 additions & 53 deletions NodeGraph/WeightedSorry/Basic.lean
Original file line number Diff line number Diff line change
@@ -1,58 +1,42 @@
import Lean
import NodeGraph.WeightedSorry.Ext

open Lean

namespace NodeGraph

structure Weight where
val : Nat
deriving ToExpr, BEq, Hashable

structure WSorry where
weight : Weight
type : Expr
name : Name
deriving BEq, Hashable

instance Weight.instOfNat (n : Nat) : OfNat Weight n where ofNat := .mk n

macro:max "wsorry" t:term:max : term =>
`((sorry : Weight → _) $t)

unsafe
def getWSorry (e : Expr) : MetaM (Option WSorry) := do
let (nm, args) := e.getAppFnArgs
unless nm == ``sorryAx do return none
if h : args.size = 4 then
let shouldBeName := args[2]
let shouldBeWeight := args[3]
unless (← Meta.inferType shouldBeName) == .const `Lean.Name [] do return none
unless (← Meta.inferType shouldBeWeight) == .const `NodeGraph.Weight [] do return none
let name ← Meta.evalExpr Name (.const `Lean.Name []) args[2]
let weight := Expr.app (.const `NodeGraph.Weight.val []) args[3]
let weight ← Meta.evalExpr Nat (.const `Nat []) weight
return some ⟨⟨weight⟩, args[0], name⟩
else
return none

unsafe
def collectWSorriesInExpr (e : Expr) : MetaM (Std.HashSet WSorry) := Prod.snd <$> go.run {}
where go : StateT (Std.HashSet WSorry) MetaM Unit := Meta.forEachExpr e fun e => do
if let some val ← getWSorry e then modify fun S => S.insert val

unsafe
def collectWSorriesInConst (e : ConstantInfo) : CoreM (Std.HashSet WSorry) := Meta.MetaM.run' do
let tpWeight ← collectWSorriesInExpr e.type
let valWeight : Std.HashSet (Expr × Nat) ← match e.value? with
| some val => collectWSorriesInExpr val
| none => pure .emptyWithCapacity
return tpWeight.union valWeight

/-- Returns the weight and whether the constant *does not* use a weighted sorry -/
unsafe
def collectConstWeight (e : ConstantInfo) : CoreM (Nat × Bool) := do
let wsorries ← collectWSorriesInConst e
let mut out := 0
for ⟨⟨a⟩, _, _⟩ in wsorries do
out := out + a
return (out, wsorries.isEmpty)
syntax (name := wsorryTerm) "wsorry " num (" [" ident,* "]")? : term
syntax (name := wsorryTactic) "wsorry " num (" [" ident,* "]")? : tactic

open Lean Elab Tactic

/-- Elaborator for weighted sorries: Adds a new `WSorry` to the
`wSorryExtension` with `weight` and `decls`. -/
def elabWSorry (weight : TSyntax `num) (decls : Option (Array (TSyntax `ident))) :
TermElabM Unit := do
let some name := ← Term.getDeclName? | return ()
let infos : Array WSorry := ←
match wSorryExtension.find? (← getEnv) name with
| some xs => pure xs
| none => pure #[]
let mut deps := #[]
for decls in decls do for decl in decls do
let _ ← resolveGlobalConst decl
deps := #[decl.getId] ++ deps
let info : WSorry :=
{ weight := weight.getNat
dependencies := deps }
wSorryExtension.addEntry (← getEnv) (name, #[info] ++ infos) |> setEnv

elab_rules : term
| `(term| wsorry $weight $[[$[$decls],*]]?) => do
let stx ← `(wsorry $weight $[[$[$decls],*]]?)
Lean.Elab.Term.adaptExpander
(fun stx ↦ do elabWSorry weight decls; `(term| sorry))
stx none

elab_rules : tactic
| `(tactic| wsorry $weight $[[$[$decls],*]]?) => do
let stx ← `(wsorry $weight $[[$[$decls],*]]?)
Lean.Elab.Tactic.adaptExpander
(fun stx ↦ do elabWSorry weight decls; `(tactic| sorry))
stx
53 changes: 53 additions & 0 deletions NodeGraph/WeightedSorry/Ext.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import Lean
import Batteries.Lean.NameMapAttribute

namespace NodeGraph

open Lean Elab Parser Command Meta

structure WSorry where
weight : Nat
dependencies : Array Name
deriving DecidableEq, Hashable

instance : ToString WSorry where
toString info := s!"{info.weight} with {info.dependencies}"

initialize wSorryExtension : NameMapExtension (Array WSorry) ←
registerSimplePersistentEnvExtension {
name := decl_name%
addImportedFn := fun ass =>
ass.foldl (init := ∅) fun names as =>
as.foldl (init := names) fun names (a, b) => names.insert a b
addEntryFn := fun s n => s.insert n.1 n.2
toArrayFn := fun es => es.toArray
asyncMode := .sync
}

/-- Returns all `wsorry`s used in declaration `name`. If no declaration `name` exist,
returns the empty set. -/
def getWSorriesByName (name : Name) : CoreM (Std.HashSet WSorry) := do
let env ← getEnv
match wSorryExtension.find? env name with
| some infos => pure <| .ofArray infos
| none => pure .emptyWithCapacity

/-- Returns the union of all dependencies used by a `wsorry` in declaration `name`.
If no declaration `name` exist, returns the empty set. -/
def getWSorriedDependencies (name : Name) : CoreM (Std.HashSet Name) := do
let depSorries ← getWSorriesByName name
let mut res := .emptyWithCapacity
for info in depSorries do
res := res.union (.ofArray info.dependencies)
return res

/-- Returns the weight and whether the constant *does not* use a weighted sorry -/
unsafe
def collectConstWeight (e : ConstantInfo) : CoreM (Nat × Bool) := do
let wsorries ← getWSorriesByName e.name
let mut out := 0
for wsorry' in wsorries do
out := out + wsorry'.weight
return (out, wsorries.isEmpty)

end NodeGraph
11 changes: 9 additions & 2 deletions Test.lean
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,17 @@ def a : Nat :=
axiom f : Nat

@[node in defs]
noncomputable def b : Nat × wsorry 5 := (a + wsorry 10 + f, wsorry 20)
def aux₁ : Nat := 4

@[node in defs]
noncomputable def b : Nat × wsorry 5 [aux₁] := (a + wsorry 10 + f, wsorry 20)

@[node in defs]
def aux₂ : Nat := 3

@[node in lemmas]
theorem foo : b = (0, wsorry 73) := sorry
theorem foo : b = (0, wsorry 73) := by
wsorry 123 [aux₂, f]

#decl_graph
#decl_graph from b
Expand Down