From 31bb6a1decbd9cc2ae5491d93f77030d1df14f63 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Mon, 3 Feb 2020 14:28:08 -0800 Subject: [PATCH] feat: extend `tryCoeAndLift` Add combined coe+lift case. --- src/Init/Control/Lift.lean | 5 +++++ src/Init/Lean/Elab/Term.lean | 26 ++++++++++++++++++++------ tests/lean/run/doNotation1.lean | 7 ++++++- 3 files changed, 31 insertions(+), 7 deletions(-) diff --git a/src/Init/Control/Lift.lean b/src/Init/Control/Lift.lean index 0ecd2606f8..af411e0ec9 100644 --- a/src/Init/Control/Lift.lean +++ b/src/Init/Control/Lift.lean @@ -10,6 +10,7 @@ Please see https://hackage.haskell.org/package/layers-0.1/docs/Documentation-Lay -/ prelude import Init.Control.Monad +import Init.Coe universes u v w @@ -30,6 +31,10 @@ export HasMonadLiftT (monadLift) abbrev liftM := @monadLift +@[inline] def liftCoeM {m : Type u → Type v} {n : Type u → Type w} {α β : Type u} [HasMonadLiftT m n] [∀ a, CoeT α a β] [Monad n] (x : m α) : n β := do +a ← liftM $ x; +pure $ coe a + instance hasMonadLiftTTrans (m n o) [HasMonadLiftT m n] [HasMonadLift n o] : HasMonadLiftT m o := ⟨fun α ma => HasMonadLift.monadLift (monadLift ma : n α)⟩ diff --git a/src/Init/Lean/Elab/Term.lean b/src/Init/Lean/Elab/Term.lean index 965c38e900..000824bc91 100644 --- a/src/Init/Lean/Elab/Term.lean +++ b/src/Init/Lean/Elab/Term.lean @@ -638,7 +638,12 @@ match type with | Expr.app m α _ => pure (some (m, α)) | _ => pure none -private def isMonad? (ref : Syntax) (type : Expr) : TermElabM (Option (Expr × Expr)) := do +structure IsMonadResult := +(m : Expr) +(α : Expr) +(inst : Expr) + +private def isMonad? (ref : Syntax) (type : Expr) : TermElabM (Option IsMonadResult) := do type ← withReducible $ whnf ref type; match type with | Expr.app m α _ => @@ -647,8 +652,8 @@ match type with monadType ← mkAppM ref `Monad #[m]; result ← trySynthInstance ref monadType; match result with - | LOption.some _ => pure (some (m, α)) - | _ => pure none) + | LOption.some inst => pure (some { m := m, α := α, inst := inst }) + | _ => pure none) (fun _ => pure none) | _ => pure none @@ -706,7 +711,7 @@ since this goal does not contain any metavariables. And then, we convert `g x` into `liftM $ g x`. -/ def tryCoeAndLift (ref : Syntax) (expectedType : Expr) (eType : Expr) (e : Expr) (f? : Option Expr) : TermElabM Expr := do -some (n, β) ← isMonad? ref expectedType | tryCoe ref expectedType eType e f?; +some ⟨n, β, monadInst⟩ ← isMonad? ref expectedType | tryCoe ref expectedType eType e f?; some (m, α) ← isTypeApp? ref eType | tryCoe ref expectedType eType e f?; condM (isDefEq ref m n) (tryCoe ref expectedType eType e f?) $ catch @@ -720,8 +725,17 @@ condM (isDefEq ref m n) (tryCoe ref expectedType eType e f?) $ let eNew := mkAppN (Lean.mkConst `liftM [u_1, u_2, u_3]) #[m, n, hasMonadLiftVal, α, e]; eNewType ← inferType ref eNew; condM (isDefEq ref expectedType eNewType) - (pure eNew) - (throwTypeMismatchError ref expectedType eType e f?)) -- TODO approach 3 + (pure eNew) -- approach 2 worked + (do + u ← getLevel ref α; + v ← getLevel ref β; + let coeTInstType := Lean.mkForall `a BinderInfo.default α $ mkAppN (mkConst `CoeT [u, v]) #[α, mkBVar 0, β]; + coeTInstVal ← synthesizeInst ref coeTInstType; + let eNew := mkAppN (Lean.mkConst `liftCoeM [u_1, u_2, u_3]) #[m, n, α, β, hasMonadLiftVal, coeTInstVal, monadInst, e]; + eNewType ← inferType ref eNew; + condM (isDefEq ref expectedType eNewType) + (pure eNew) -- approach 3 worked + (throwTypeMismatchError ref expectedType eType e f?))) (fun _ => throwTypeMismatchError ref expectedType eType e f?) /-- diff --git a/tests/lean/run/doNotation1.lean b/tests/lean/run/doNotation1.lean index 1b806481ab..9e1aedcc08 100644 --- a/tests/lean/run/doNotation1.lean +++ b/tests/lean/run/doNotation1.lean @@ -101,10 +101,15 @@ pure $ x > 0 def tst5 (x : Nat) : IO (Option Nat) := if x > 10 then pure x else pure none +def tst6 (x : Nat) : StateT Nat IO (Option Nat) := +if x > 10 then g x else pure none + syntax [doHash] "#":max : term -def tst6 : StateT (Nat × Nat) IO Unit := do +def tst7 : StateT (Nat × Nat) IO Unit := do if #.1 == 0 then IO.println "first field is zero" else IO.println "first field is not zero" + +#check tst7