From 1ff0e7a2f2e8e57e67811f790a18d0b99ee16012 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sat, 18 May 2024 04:48:15 +0200 Subject: [PATCH] fix: `split at h` when `h` has forward dependencies (#4211) We use an approach similar to the one used in `simp`. closes #3731 --- src/Lean/Meta/Basic.lean | 11 +++++ src/Lean/Meta/Tactic/Split.lean | 13 ++++-- tests/lean/run/3731.lean | 76 +++++++++++++++++++++++++++++++++ 3 files changed, 97 insertions(+), 3 deletions(-) create mode 100644 tests/lean/run/3731.lean diff --git a/src/Lean/Meta/Basic.lean b/src/Lean/Meta/Basic.lean index cacf95f420..bfbd954c5e 100644 --- a/src/Lean/Meta/Basic.lean +++ b/src/Lean/Meta/Basic.lean @@ -825,6 +825,17 @@ context. Fails if the given expression is not a fvar or if no such declaration e def getFVarLocalDecl (fvar : Expr) : MetaM LocalDecl := fvar.fvarId!.getDecl +/-- +Returns `true` if another local declaration in the local context depends on `fvarId`. +-/ +def _root_.Lean.FVarId.hasForwardDeps (fvarId : FVarId) : MetaM Bool := do + let decl ← fvarId.getDecl + (← getLCtx).foldlM (init := false) (start := decl.index + 1) fun found other => + if found then + return true + else + localDeclDependsOn other fvarId + /-- Given a user-facing name for a free variable, return its declaration in the current local context. Throw an exception if free variable is not declared. diff --git a/src/Lean/Meta/Tactic/Split.lean b/src/Lean/Meta/Tactic/Split.lean index 9e68455dc0..6c625f4bd0 100644 --- a/src/Lean/Meta/Tactic/Split.lean +++ b/src/Lean/Meta/Tactic/Split.lean @@ -320,10 +320,17 @@ def splitLocalDecl? (mvarId : MVarId) (fvarId : FVarId) : MetaM (Option (List MV if e.isIte || e.isDIte then return (← splitIfLocalDecl? mvarId fvarId).map fun (mvarId₁, mvarId₂) => [mvarId₁, mvarId₂] else - let (fvarIds, mvarId) ← mvarId.revert #[fvarId] - let num := fvarIds.size + let mut mvarId := mvarId + let localDecl ← fvarId.getDecl + if (← pure localDecl.isLet <||> exprDependsOn (← mvarId.getType) fvarId <||> fvarId.hasForwardDeps) then + -- If `fvarId` has dependencies or is a let-decl, we create a copy. + mvarId ← mvarId.assert localDecl.userName localDecl.type localDecl.toExpr + else + let (fvarIds, mvarId') ← mvarId.revert #[fvarId] + assert! fvarIds.size == 1 -- fvarId does not have forward dependencies + mvarId := mvarId' let mvarIds ← splitMatch mvarId e - let mvarIds ← mvarIds.mapM fun mvarId => return (← mvarId.introNP num).2 + let mvarIds ← mvarIds.mapM fun mvarId => return (← mvarId.intro1P).2 return some mvarIds else return none diff --git a/tests/lean/run/3731.lean b/tests/lean/run/3731.lean new file mode 100644 index 0000000000..650481889b --- /dev/null +++ b/tests/lean/run/3731.lean @@ -0,0 +1,76 @@ +import Lean.Data.HashMap +open Lean + +/-- +A circuit node declaration. These are not recursive but instead contain indices into an `Env`. +-/ +inductive Decl where + /-- + A node with a constant output value. + -/ + | const (b : Bool) + /-- + An input node to the circuit. + -/ + | atom (idx : Nat) + /-- + An AIG gate with configurable input nodes and polarity. `l` and `r` are the + input node indices while `linv` and `rinv` say whether there is an inverter on + the left or right input. + -/ + | gate (l r : Nat) (linv rinv : Bool) + deriving BEq, Hashable, DecidableEq + +/-- +A cache that is valid with respect to some `Array Decl`. +-/ +def Cache (_decls : Array Decl) := HashMap Decl Nat + +/-- +Lookup a `decl` in a `cache`. + +If this returns `some i`, `Cache.find?_poperty` can be used to demonstrate: `decls[i] = decl`. +-/ +@[irreducible] +def Cache.find? (cache : Cache decls) (decl : Decl) : Option Nat := + match cache.val.find? decl with + | some hit => + if h1:hit < decls.size then + if decls[hit]'h1 = decl then + some hit + else + none + else + none + | none => none + +/-- +This states that all indices, found in a `Cache` that is valid with respect to some `decls`, +are within bounds of `decls`. +-/ +theorem Cache.find?_bounds {decls : Array Decl} {idx : Nat} (c : Cache decls) (decl : Decl) + (h : c.find? decl = some idx) : idx < decls.size := by + unfold find? at h + split at h + . split at h + . split at h + . injection h + omega + . contradiction + . contradiction + . contradiction + +/-- +This states that if `Cache.find? decl` returns `some i`, `decls[i] = decl`, holds. +-/ +theorem Cache.find?_property {decls : Array Decl} {idx : Nat} (c : Cache decls) (decl : Decl) + (h : c.find? decl = some idx) : decls[idx]'(Cache.find?_bounds c decl h) = decl := by + unfold find? at h + split at h + . split at h + . split at h + . injection h + subst idx; assumption + . contradiction + . contradiction + . contradiction