From 6429486f88063de01ea9a8296834014e07fa946f Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Wed, 12 Feb 2020 09:11:20 -0800 Subject: [PATCH] refactor: decouple elaboration for `forall` and `lambda` binders Motivation: another refactoring to improve `elabFunCore`. --- src/Init/Lean/Elab/Binders.lean | 133 +++++++++++++++++++++----------- 1 file changed, 88 insertions(+), 45 deletions(-) diff --git a/src/Init/Lean/Elab/Binders.lean b/src/Init/Lean/Elab/Binders.lean index b1a0c62a3d..15fd7c7c2d 100644 --- a/src/Init/Lean/Elab/Binders.lean +++ b/src/Init/Lean/Elab/Binders.lean @@ -9,7 +9,7 @@ import Init.Lean.Elab.Term namespace Lean namespace Elab namespace Term -namespace Binders + /-- Given syntax of the forms a) (`:` term)? @@ -37,6 +37,7 @@ else structure BinderView := (id : Syntax) (type : Syntax) (bi : BinderInfo) + /- Expand `optional (binderDefault <|> binderTactic)` def binderDefault := parser! " := " >> termParser @@ -84,72 +85,50 @@ match stx with throwUnsupportedSyntax | _ => throwUnsupportedSyntax -structure State := -(fvars : Array Expr := #[]) -(lctx : LocalContext) -(localInsts : LocalInstances) -(expectedType? : Option Expr := none) - -private def propagateExpectedType (ref : Syntax) (fvar : Expr) (fvarType : Expr) (s : State) : TermElabM State := do -match s.expectedType? with -| none => pure s -| some expectedType => do - expectedType ← whnfForall ref expectedType; - match expectedType with - | Expr.forallE _ d b _ => do - isDefEq ref fvarType d; - let b := b.instantiate1 fvar; - pure { expectedType? := some b, .. s } - | _ => pure { expectedType? := none, .. s } - -private partial def elabBinderViews (binderViews : Array BinderView) : Nat → State → TermElabM State -| i, s => +private partial def elabBinderViews (binderViews : Array BinderView) + : Nat → Array Expr → LocalContext → LocalInstances → TermElabM (Array Expr × LocalContext × LocalInstances) +| i, fvars, lctx, localInsts => if h : i < binderViews.size then let binderView := binderViews.get ⟨i, h⟩; - withLCtx s.lctx s.localInsts $ do + withLCtx lctx localInsts $ do type ← elabType binderView.type; fvarId ← mkFreshFVarId; let fvar := mkFVar fvarId; - let fvars := s.fvars.push fvar; + let fvars := fvars.push fvar; -- dbgTrace (toString binderView.id.getId ++ " : " ++ toString type); - let lctx := s.lctx.mkLocalDecl fvarId binderView.id.getId type binderView.bi; - s ← propagateExpectedType binderView.id fvar type s; + let lctx := lctx.mkLocalDecl fvarId binderView.id.getId type binderView.bi; className? ← isClass binderView.type type; match className? with - | none => elabBinderViews (i+1) { fvars := fvars, lctx := lctx, .. s } + | none => elabBinderViews (i+1) fvars lctx localInsts | some className => do resetSynthInstanceCache; - let localInsts := s.localInsts.push { className := className, fvar := mkFVar fvarId }; - elabBinderViews (i+1) { fvars := fvars, lctx := lctx, localInsts := localInsts, .. s } + let localInsts := localInsts.push { className := className, fvar := mkFVar fvarId }; + elabBinderViews (i+1) fvars lctx localInsts else - pure s + pure (fvars, lctx, localInsts) -partial def elabBindersAux (binders : Array Syntax) : Nat → State → TermElabM State -| i, s => +private partial def elabBindersAux (binders : Array Syntax) + : Nat → Array Expr → LocalContext → LocalInstances → TermElabM (Array Expr × LocalContext × LocalInstances) +| i, fvars, lctx, localInsts => if h : i < binders.size then do binderViews ← matchBinder (binders.get ⟨i, h⟩); - s ← elabBinderViews binderViews 0 s; - elabBindersAux (i+1) s + (fvars, lctx, localInsts) ← elabBinderViews binderViews 0 fvars lctx localInsts; + elabBindersAux (i+1) fvars lctx localInsts else - pure s - -end Binders + pure (fvars, lctx, localInsts) /-- Elaborate the given binders (i.e., `Syntax` objects for `simpleBinder <|> bracktedBinder`), update the local context, set of local instances, reset instance chache (if needed), and then execute `x` with the updated context. -/ -def elabBindersWithExpectedType {α} (binders : Array Syntax) (expectedType? : Option Expr) (x : Array Expr → Option Expr → TermElabM α) : TermElabM α := -if binders.isEmpty then x #[] expectedType? +def elabBinders {α} (binders : Array Syntax) (x : Array Expr → TermElabM α) : TermElabM α := +if binders.isEmpty then x #[] else do lctx ← getLCtx; localInsts ← getLocalInsts; - s ← Binders.elabBindersAux binders 0 { lctx := lctx, localInsts := localInsts, expectedType? := expectedType? }; - resettingSynthInstanceCacheWhen (s.localInsts.size > localInsts.size) $ withLCtx s.lctx s.localInsts $ - x s.fvars s.expectedType? - -def elabBinders {α} (binders : Array Syntax) (x : Array Expr → TermElabM α) : TermElabM α := -elabBindersWithExpectedType binders none $ fun args _ => x args + (fvars, lctx, newLocalInsts) ← elabBindersAux binders 0 #[] lctx localInsts; + resettingSynthInstanceCacheWhen (newLocalInsts.size > localInsts.size) $ withLCtx lctx newLocalInsts $ + x fvars @[inline] def elabBinder {α} (binder : Syntax) (x : Expr → TermElabM α) : TermElabM α := elabBinders #[binder] (fun fvars => x (fvars.get! 1)) @@ -261,12 +240,76 @@ private partial def expandFunBindersAux (binders : Array Syntax) : Syntax → Na def expandFunBinders (binders : Array Syntax) (body : Syntax) : TermElabM (Array Syntax × Syntax) := expandFunBindersAux binders body 0 #[] +namespace FunBinders + +structure State := +(implicitArgs : Array Expr := #[]) +(fvars : Array Expr := #[]) +(lctx : LocalContext) +(localInsts : LocalInstances) +(expectedType? : Option Expr := none) +(explicit : Bool := false) + +private def propagateExpectedType (ref : Syntax) (fvar : Expr) (fvarType : Expr) (s : State) : TermElabM State := do +match s.expectedType? with +| none => pure s +| some expectedType => do + expectedType ← whnfForall ref expectedType; + match expectedType with + | Expr.forallE _ d b _ => do + isDefEq ref fvarType d; + let b := b.instantiate1 fvar; + pure { expectedType? := some b, .. s } + | _ => pure { expectedType? := none, .. s } + +private partial def elabFunBinderViews (binderViews : Array BinderView) : Nat → State → TermElabM State +| i, s => + if h : i < binderViews.size then + let binderView := binderViews.get ⟨i, h⟩; + withLCtx s.lctx s.localInsts $ do + type ← elabType binderView.type; + fvarId ← mkFreshFVarId; + let fvar := mkFVar fvarId; + let fvars := s.fvars.push fvar; + -- dbgTrace (toString binderView.id.getId ++ " : " ++ toString type); + let lctx := s.lctx.mkLocalDecl fvarId binderView.id.getId type binderView.bi; + s ← propagateExpectedType binderView.id fvar type s; + className? ← isClass binderView.type type; + match className? with + | none => elabFunBinderViews (i+1) { fvars := fvars, lctx := lctx, .. s } + | some className => do + resetSynthInstanceCache; + let localInsts := s.localInsts.push { className := className, fvar := mkFVar fvarId }; + elabFunBinderViews (i+1) { fvars := fvars, lctx := lctx, localInsts := localInsts, .. s } + else + pure s + +partial def elabFunBindersAux (binders : Array Syntax) : Nat → State → TermElabM State +| i, s => + if h : i < binders.size then do + binderViews ← matchBinder (binders.get ⟨i, h⟩); + s ← elabFunBinderViews binderViews 0 s; + elabFunBindersAux (i+1) s + else + pure s + +end FunBinders + +def elabFunBinders {α} (binders : Array Syntax) (expectedType? : Option Expr) (explicit : Bool) (x : Array Expr → Option Expr → TermElabM α) : TermElabM α := +if binders.isEmpty then x #[] expectedType? +else do + lctx ← getLCtx; + localInsts ← getLocalInsts; + s ← FunBinders.elabFunBindersAux binders 0 { lctx := lctx, localInsts := localInsts, expectedType? := expectedType?, explicit := explicit }; + resettingSynthInstanceCacheWhen (s.localInsts.size > localInsts.size) $ withLCtx s.lctx s.localInsts $ + x s.fvars s.expectedType? + def elabFunCore (stx : Syntax) (expectedType? : Option Expr) (explicit : Bool) : TermElabM Expr := do -- `fun` term+ `=>` term let binders := (stx.getArg 1).getArgs; let body := stx.getArg 3; (binders, body) ← expandFunBinders binders body; -elabBindersWithExpectedType binders expectedType? $ fun xs expectedType? => do { +elabFunBinders binders expectedType? explicit $ fun xs expectedType? => do { e ← elabTerm body expectedType?; mkLambda stx xs e }