diff --git a/src/Init/Lean/Meta/DiscrTree.lean b/src/Init/Lean/Meta/DiscrTree.lean index 71f4e5c3c0..160b827d76 100644 --- a/src/Init/Lean/Meta/DiscrTree.lean +++ b/src/Init/Lean/Meta/DiscrTree.lean @@ -141,9 +141,16 @@ private partial def pushArgsAux (infos : Array ParamInfo) : Nat → Expr → Arr (pushArgsAux (i-1) f (todo.push a)) | _, _, todo => pure todo +private partial def whnfEta : Expr → MetaM Expr +| e => do + e ← whnf e; + match e.etaExpandedStrict? with + | some e => whnfEta e + | none => pure e + private def pushArgs (todo : Array Expr) (e : Expr) : MetaM (Key × Array Expr) := -do e ← whnf e; - let fn := e.getAppFn; +do e ← whnfEta e; + let fn := e.getAppFn; let push (k : Key) (nargs : Nat) : MetaM (Key × Array Expr) := do { info ← getFunInfoNArgs fn nargs; todo ← pushArgsAux info.paramInfo (nargs-1) e todo; @@ -238,7 +245,7 @@ Format.group r instance DiscrTree.hasFormat {α} [HasFormat α] : HasFormat (DiscrTree α) := ⟨format⟩ private def getKeyArgs (e : Expr) (isMatch? : Bool) : MetaM (Key × Array Expr) := -do e ← whnf e; +do e ← whnfEta e; match e.getAppFn with | Expr.lit v _ => pure (Key.lit v, #[]) | Expr.const c _ _ => let nargs := e.getAppNumArgs; pure (Key.const c nargs, e.getAppRevArgs) diff --git a/tests/lean/run/meta2.lean b/tests/lean/run/meta2.lean index f684ffd169..9c64a90fe3 100644 --- a/tests/lean/run/meta2.lean +++ b/tests/lean/run/meta2.lean @@ -300,6 +300,23 @@ do u ← getLevel σ; check r; pure r +def mkMonad (m : Expr) : MetaM Expr := +do u ← mkFreshLevelMVar; + v ← mkFreshLevelMVar; + let arrow := mkArrow (mkSort (mkLevelSucc u)) (mkSort (mkLevelSucc v)); + mType ← inferType m; + mType ← whnf mType; + print arrow; + print mType; + condM (isDefEq arrow mType) + (do u ← instantiateLevelMVars u; + v ← instantiateLevelMVars v; + let r := mkApp (mkConst `Monad [u, v]) m; + print r; + check r; + pure r) + (throw $ Exception.other "failed to create Monad application") + def mkMonadState (σ m : Expr) : MetaM Expr := do u ← getLevel σ; (some u) ← pure u.dec | throw $ Exception.other "failed to create MonadState application"; @@ -320,10 +337,31 @@ do u ← getLevel σ; def tst14 : MetaM Unit := do print "----- tst14 -----"; - decEqNat ← mkDecEq nat; - c ← synthInstance decEqNat; stateM ← mkStateM nat; print stateM; + monad ← mkMonad stateM; + globalInsts ← getGlobalInstances; + insts ← globalInsts.getUnify monad; + print insts; + pure () + +#eval run [`Init.Control.State] tst14 + +#exit + +def tst15 : MetaM Unit := +do print "----- tst15 -----"; + stateM ← mkStateM nat; + print stateM; + monad ← mkMonad stateM; + print monad; + c ← synthInstance monad; + pure () + + +#exit + decEqNat ← mkDecEq nat; + c ← synthInstance decEqNat; monadState ← mkMonadState nat stateM; print monadState; c ← synthInstance monadState;