From 5adce9fa205cde6878428c4684e82aa55b237319 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Tue, 3 Dec 2019 10:30:53 -0800 Subject: [PATCH] fix: use eta reduction at `DiscrTree` @kha @dselsam Suppose we are trying to retrieve the global instances for `(Monad (StateM Nat))`. During retrieval, we reducde `StateM Nat` into `fun x => StateT Nat Id x` However, the `DiscrTree` contains an entry for `Monad (StateT * *)`. Thus, we fail to retrieve any instance. I fixed the particular issue by using eta reduction. Not sure we will encounter other definitional-equality related problems . --- src/Init/Lean/Meta/DiscrTree.lean | 13 +++++++--- tests/lean/run/meta2.lean | 42 +++++++++++++++++++++++++++++-- 2 files changed, 50 insertions(+), 5 deletions(-) diff --git a/src/Init/Lean/Meta/DiscrTree.lean b/src/Init/Lean/Meta/DiscrTree.lean index 71f4e5c3c0..160b827d76 100644 --- a/src/Init/Lean/Meta/DiscrTree.lean +++ b/src/Init/Lean/Meta/DiscrTree.lean @@ -141,9 +141,16 @@ private partial def pushArgsAux (infos : Array ParamInfo) : Nat → Expr → Arr (pushArgsAux (i-1) f (todo.push a)) | _, _, todo => pure todo +private partial def whnfEta : Expr → MetaM Expr +| e => do + e ← whnf e; + match e.etaExpandedStrict? with + | some e => whnfEta e + | none => pure e + private def pushArgs (todo : Array Expr) (e : Expr) : MetaM (Key × Array Expr) := -do e ← whnf e; - let fn := e.getAppFn; +do e ← whnfEta e; + let fn := e.getAppFn; let push (k : Key) (nargs : Nat) : MetaM (Key × Array Expr) := do { info ← getFunInfoNArgs fn nargs; todo ← pushArgsAux info.paramInfo (nargs-1) e todo; @@ -238,7 +245,7 @@ Format.group r instance DiscrTree.hasFormat {α} [HasFormat α] : HasFormat (DiscrTree α) := ⟨format⟩ private def getKeyArgs (e : Expr) (isMatch? : Bool) : MetaM (Key × Array Expr) := -do e ← whnf e; +do e ← whnfEta e; match e.getAppFn with | Expr.lit v _ => pure (Key.lit v, #[]) | Expr.const c _ _ => let nargs := e.getAppNumArgs; pure (Key.const c nargs, e.getAppRevArgs) diff --git a/tests/lean/run/meta2.lean b/tests/lean/run/meta2.lean index f684ffd169..9c64a90fe3 100644 --- a/tests/lean/run/meta2.lean +++ b/tests/lean/run/meta2.lean @@ -300,6 +300,23 @@ do u ← getLevel σ; check r; pure r +def mkMonad (m : Expr) : MetaM Expr := +do u ← mkFreshLevelMVar; + v ← mkFreshLevelMVar; + let arrow := mkArrow (mkSort (mkLevelSucc u)) (mkSort (mkLevelSucc v)); + mType ← inferType m; + mType ← whnf mType; + print arrow; + print mType; + condM (isDefEq arrow mType) + (do u ← instantiateLevelMVars u; + v ← instantiateLevelMVars v; + let r := mkApp (mkConst `Monad [u, v]) m; + print r; + check r; + pure r) + (throw $ Exception.other "failed to create Monad application") + def mkMonadState (σ m : Expr) : MetaM Expr := do u ← getLevel σ; (some u) ← pure u.dec | throw $ Exception.other "failed to create MonadState application"; @@ -320,10 +337,31 @@ do u ← getLevel σ; def tst14 : MetaM Unit := do print "----- tst14 -----"; - decEqNat ← mkDecEq nat; - c ← synthInstance decEqNat; stateM ← mkStateM nat; print stateM; + monad ← mkMonad stateM; + globalInsts ← getGlobalInstances; + insts ← globalInsts.getUnify monad; + print insts; + pure () + +#eval run [`Init.Control.State] tst14 + +#exit + +def tst15 : MetaM Unit := +do print "----- tst15 -----"; + stateM ← mkStateM nat; + print stateM; + monad ← mkMonad stateM; + print monad; + c ← synthInstance monad; + pure () + + +#exit + decEqNat ← mkDecEq nat; + c ← synthInstance decEqNat; monadState ← mkMonadState nat stateM; print monadState; c ← synthInstance monadState;