diff --git a/src/Lean/Elab/MutualDef.lean b/src/Lean/Elab/MutualDef.lean index c567d2d5e3..a6d1c61396 100644 --- a/src/Lean/Elab/MutualDef.lean +++ b/src/Lean/Elab/MutualDef.lean @@ -3,6 +3,8 @@ Copyright (c) 2020 Microsoft Corporation. All rights reserved. Released under Apache 2.0 license as described in the file LICENSE. Authors: Leonardo de Moura -/ +import Lean.Meta.Closure +import Lean.Meta.Check import Lean.Elab.Command import Lean.Elab.Definition @@ -21,6 +23,9 @@ structure DefViewElabHeader := (type : Expr) -- including the parameters (declVal : Syntax) +instance DefViewElabHeader.inhabited : Inhabited DefViewElabHeader := +⟨⟨arbitrary _, {}, DefKind.«def», arbitrary _, arbitrary _, [], 0, arbitrary _, arbitrary _⟩⟩ + namespace Term open Meta @@ -187,7 +192,7 @@ match decl? with | some n => pure n /- -Ensures that the of let-rec definitions do not contain functions being defined. +Ensures that the of let-rec definition types do not contain functions being defined. In principle, this test can be improved. We could perform it after we separate the set of functions is strongly connected components. However, this extra complication doesn't seem worth it. -/ @@ -199,41 +204,119 @@ letRecsToLift.forM fun toLift => do fnName ← getFunName fvarId letRecsToLift; throwErrorAt toLift.ref ("invalid type in 'let rec', it uses '" ++ fnName ++ "' which is being defined simultaneously") -private def replaceFunFVarsWithConsts (headers : Array DefViewElabHeader) (funFVars : Array Expr) (vars : Array Expr) (e : Expr) : Expr := -e.replace fun x => match x with - | Expr.fvar fvarId _ => - match funFVars.indexOf x with - | some idx => match headers.get? idx.val with - | some header => some $ mkAppN (Lean.mkConst header.declName) vars -- Remark: we add the universe levels later - | none => none - | none => none - | _ => none +structure PreDeclaration := +(kind : DefKind) +(modifiers : Modifiers) +(declName : Name) +(type : Expr) +(val : Expr) -/- A mapping from -/ -abbrev Closures := NameMap NameSet +namespace MutualClosure -namespace LetRecClosure +/- A mapping from FVarId to Set of FVarIds. -/ +abbrev UsedFVarsMap := NameMap NameSet + +/- +Create the `UsedFVarsMap` mapping that takes the variable id for the mutually recursive functions being defined to the set of +free variables in its definition. + +For `mainFVars`, this is just the set of section variables `sectionVars` used. +For nested let-rec functions, we collect their free variables. + +Recall that a `let rec` expressions are encoded as follows in the elaborator. +```lean +let rec + f : A := t, + g : B := s; +body +``` +is encoded as +```lean +let f : A := ?m₁; +let g : B := ?m₂; +body +``` +where `?m₁` and `?m₂` are synthetic opaque metavariables. That are assigned by this module. +We may have nested `let rec`s. +```lean +let rec f : A := + let rec g : B := t; + s; +body +``` +is encoded as +```lean +let f : A := ?m₁; +body +``` +and the body of `f` is stored the field `val` of a `LetRecToLift`. For the example above, +we would have a `LetRecToLift` containing: +``` +{ + mvarId := m₁, + val := `(let g : B := ?m₂; body) + ... +} +``` +Note that `g` is not a free variable at `(let g : B := ?m₂; body)`. We recover the fact that +`f` depends on `g` because it contains `m₂` +-/ +private def mkInitialUsedFVarsMap (mctx : MetavarContext) (sectionVars : Array Expr) (mainFVarIds : Array FVarId) (letRecsToLift : List LetRecToLift) + : UsedFVarsMap := +let sectionVarSet := sectionVars.foldl (fun (s : NameSet) (var : Expr) => s.insert var.fvarId!) {}; +let usedFVarMap := mainFVarIds.foldl + (fun (usedFVarMap : UsedFVarsMap) mainFVarId => + usedFVarMap.insert mainFVarId sectionVarSet) + {}; +letRecsToLift.foldl + (fun (usedFVarMap : UsedFVarsMap) toLift => + let state := Lean.collectFVars {} toLift.val; + let state := Lean.collectFVars state toLift.type; + let set := state.fvarSet; + /- toLift.val may contain metavariables that are placeholders for nested let-recs. We should collect the fvarId + for the associated let-rec because we need this information to compute the fixpoint later. -/ + let mvarIds := (toLift.val.collectMVars {}).result; + let set := mvarIds.foldl + (fun (set : NameSet) (mvarId : MVarId) => + match letRecsToLift.findSome? fun (toLift : LetRecToLift) => if toLift.mvarId == mctx.getDelayedRoot mvarId then some toLift.fvarId else none with + | some fvarId => set.insert fvarId + | none => set) + set; + usedFVarMap.insert toLift.fvarId set) + usedFVarMap + +/- +The let-recs may invoke each other. Example: +``` +let rec + f (x : Nat) := g x + y + g : Nat → Nat + | 0 => 1 + | x+1 => f x + z +``` +`y` is free variable in `f`, and `z` is a free variable in `g`. +To close `f` and `g`, we `y` and `z` must be in the closure of both. +That is, we need to generate the top-level definitions. +``` +def f (y z x : Nat) := g y z x + y +def g (y z : Nat) : Nat → Nat + | 0 => 1 + | x+1 => f y z x + z +``` +-/ +namespace FixPoint structure State := -(closures : Closures := {}) -(modified : Bool := false) +(usedFVarsMap : UsedFVarsMap := {}) +(modified : Bool := false) abbrev M := ReaderT (List FVarId) $ StateM State -private def isModified : M Bool := do -s ← get; pure s.modified - -private def resetModified : M Unit := -modify fun s => { s with modified := false } - -private def markModified : M Unit := -modify fun s => { s with modified := true } - -private def getClosures : M Closures := -do s ← get; pure s.closures - -private def modifyClosures (f : Closures → Closures) : M Unit := -modify fun s => { s with closures := f s.closures } +private def isModified : M Bool := do s ← get; pure s.modified +private def resetModified : M Unit := modify fun s => { s with modified := false } +private def markModified : M Unit := modify fun s => { s with modified := true } +private def getUsedFVarsMap : M UsedFVarsMap := do s ← get; pure s.usedFVarsMap +private def modifyUsedFVars (f : UsedFVarsMap → UsedFVarsMap) : M Unit := modify fun s => { s with usedFVarsMap := f s.usedFVarsMap } -- merge s₂ into s₁ private def merge (s₁ s₂ : NameSet) : M NameSet := @@ -246,84 +329,229 @@ s₂.foldM pure $ s₁.insert k) s₁ -private def updateClosureOf (fvarId : FVarId) : M Unit := do -closures ← getClosures; -match closures.find? fvarId with +private def updateUsedVarsOf (fvarId : FVarId) : M Unit := do +usedFVarsMap ← getUsedFVarsMap; +match usedFVarsMap.find? fvarId with | none => pure () -| some closure => do - closureNew ← closure.foldM - (fun (closureNew : NameSet) (fvarId' : FVarId) => +| some fvarIds => do + fvarIdsNew ← fvarIds.foldM + (fun (fvarIdsNew : NameSet) (fvarId' : FVarId) => if fvarId == fvarId' then - pure closureNew + pure fvarIdsNew else - match closures.find? fvarId' with - | none => pure closureNew - -- We are being sloppy here `otherClosure` may contain free variables that are not in the context of the let-rec associated with fvarId. - -- We filter these out-of-context free variables later. - | some otherClosure => merge closureNew otherClosure) - closure; - modifyClosures fun closures => closures.insert fvarId closureNew + match usedFVarsMap.find? fvarId' with + | none => pure fvarIdsNew + /- We are being sloppy here `otherFVarIds` may contain free variables that are + not in the context of the let-rec associated with fvarId. + We filter these out-of-context free variables later. -/ + | some otherFVarIds => merge fvarIdsNew otherFVarIds) + fvarIds; + modifyUsedFVars fun usedFVars => usedFVars.insert fvarId fvarIdsNew -private partial def closureFixpoint : Unit → M Unit +private partial def fixpoint : Unit → M Unit | _ => do resetModified; letRecFVarIds ← read; - letRecFVarIds.forM updateClosureOf; - whenM isModified $ closureFixpoint () + letRecFVarIds.forM updateUsedVarsOf; + whenM isModified $ fixpoint () -private def mkInitialClosures (letRecsToLift : List LetRecToLift) : Closures := +def run (letRecFVarIds : List FVarId) (usedFVarsMap : UsedFVarsMap) : UsedFVarsMap := +let (_, s) := ((fixpoint ()).run letRecFVarIds).run { usedFVarsMap := usedFVarsMap }; +s.usedFVarsMap + +end FixPoint + +abbrev FreeVarMap := NameMap (Array FVarId) + +private def mkFreeVarMap + (mctx : MetavarContext) (sectionVars : Array Expr) (mainFVarIds : Array FVarId) + (recFVarIds : Array FVarId) (letRecsToLift : List LetRecToLift) : FreeVarMap := +let usedFVarsMap := mkInitialUsedFVarsMap mctx sectionVars mainFVarIds letRecsToLift; +let letRecFVarIds := letRecsToLift.map fun toLift => toLift.fvarId; +let usedFVarsMap := FixPoint.run letRecFVarIds usedFVarsMap; letRecsToLift.foldl - (fun (closures : Closures) toLift => - let state := Lean.collectFVars {} toLift.val; - let state := Lean.collectFVars state toLift.type; - closures.insert toLift.fvarId state.fvarSet) + (fun (freeVarMap : FreeVarMap) toLift => + let lctx := toLift.lctx; + let fvarIdsSet := (usedFVarsMap.find? toLift.fvarId).get!; + let fvarIds := fvarIdsSet.fold + (fun (fvarIds : Array FVarId) (fvarId : FVarId) => + if lctx.contains fvarId && !recFVarIds.contains fvarId then + fvarIds.push fvarId + else + fvarIds) + #[]; + freeVarMap.insert toLift.fvarId fvarIds) {} -private def mkClosures (letRecsToLift : List LetRecToLift) : Closures := -let letRecFVarIds := letRecsToLift.map fun toLift => toLift.fvarId; -let closures := mkInitialClosures letRecsToLift; -let (_, s) := ((closureFixpoint ()).run letRecFVarIds).run { closures := closures }; -s.closures +structure ClosureState := +(newLocalDecls : Array LocalDecl := #[]) +(localDecls : Array LocalDecl := #[]) +(newLetDecls : Array LocalDecl := #[]) +(exprArgs : Array Expr := #[]) -private def nameSetToFVars (s : NameSet) (lctx : LocalContext) (letRecFVarIds : List FVarId) : Array FVarId := -let fvarIds := s.fold - (fun (fvarIds : Array FVarId) (fvarId : FVarId) => - if lctx.contains fvarId && !letRecFVarIds.contains fvarId then - fvarIds.push fvarId - else - fvarIds) - #[]; -fvarIds.qsort fun fvarId₁ fvarId₂ => (lctx.get! fvarId₁).index < (lctx.get! fvarId₂).index +private def pickMaxFVar? (lctx : LocalContext) (fvarIds : Array FVarId) : Option FVarId := +fvarIds.getMax? fun fvarId₁ fvarId₂ => (lctx.get! fvarId₁).index < (lctx.get! fvarId₂).index -def main (letRecsToLift : List LetRecToLift) : TermElabM (List LetRecToLift) := do -let letRecFVarIds := letRecsToLift.map fun toLift => toLift.fvarId; -let closures := mkClosures letRecsToLift; --- Assign metavariables associated with each let-rec -ps ← letRecsToLift.mapM fun toLift => do { - let s := (closures.find? toLift.fvarId).get!; - let lctx := toLift.lctx; - let xs := (nameSetToFVars s lctx letRecFVarIds).map mkFVar; - assignExprMVar toLift.mvarId (mkAppN (Lean.mkConst toLift.declName) xs); - pure (xs, toLift) +private def preprocess (e : Expr) : TermElabM Expr := do +e ← instantiateMVars e; +-- which let-decls are dependent. We say a let-decl is dependent if its lambda abstraction is type incorrect. +liftM $ check e; +pure e + +/- Push free variables in `s` to `toProcess` if they are not already there. -/ +private def pushNewVars (toProcess : Array FVarId) (s : CollectFVars.State) : Array FVarId := +s.fvarSet.fold + (fun (toProcess : Array FVarId) fvarId => if toProcess.contains fvarId then toProcess else toProcess.push fvarId) + toProcess + +private def pushLocalDecl (toProcess : Array FVarId) (fvarId : FVarId) (userName : Name) (type : Expr) (bi := BinderInfo.default) + : StateRefT ClosureState TermElabM (Array FVarId) := do +type ← liftM $ preprocess type; +modify fun s => { s with + newLocalDecls := s.newLocalDecls.push $ LocalDecl.cdecl (arbitrary _) fvarId userName type bi, + exprArgs := s.exprArgs.push (mkFVar fvarId) }; -ps.mapM fun (p : Array Expr × LetRecToLift) => do - let (xs, toLift) := p; - type ← instantiateMVars toLift.type; - val ← instantiateMVars toLift.val; - let lctx := toLift.lctx; - -- Val may contain outer let-recs, we must replace them with constants - let val := val.replace fun x => match x with - | Expr.fvar fvarId _ => match ps.find? fun (p : Array Expr × LetRecToLift) => p.2.fvarId == fvarId with - | some p => some (mkAppN (Lean.mkConst p.2.declName) p.1) - | none => none - | _ => none; - withLCtx lctx toLift.localInstances do - -- Apply closure - type ← mkForallFVars xs type; - val ← mkLambdaFVars xs val; - pure { toLift with type := type, val := val } +pure $ pushNewVars toProcess (collectFVars {} type) -end LetRecClosure +private partial def mkClosureForAux : Array FVarId → StateRefT ClosureState TermElabM Unit +| toProcess => do + lctx ← getLCtx; + match pickMaxFVar? lctx toProcess with + | none => pure () + | some fvarId => do + let toProcess := toProcess.erase fvarId; + localDecl ← getLocalDecl fvarId; + match localDecl with + | LocalDecl.cdecl _ _ userName type bi => do + toProcess ← pushLocalDecl toProcess fvarId userName type bi; + mkClosureForAux toProcess + | LocalDecl.ldecl _ _ userName type val _ => do + zetaFVarIds ← getZetaFVarIds; + if !zetaFVarIds.contains fvarId then do + /- Non-dependent let-decl. See comment at src/Lean/Meta/Closure.lean -/ + toProcess ← pushLocalDecl toProcess fvarId userName type; + mkClosureForAux toProcess + else do + /- Dependent let-decl. -/ + type ← liftM $ preprocess type; + val ← liftM $ preprocess val; + modify fun s => { s with + newLetDecls := s.newLetDecls.push $ LocalDecl.ldecl (arbitrary _) fvarId userName type val false, + /- We don't want to interleave let and lambda declarations in our closure. So, we expand any occurrences of fvarId + at `newLocalDecls` and `localDecls` -/ + newLocalDecls := s.newLocalDecls.map (replaceFVarIdAtLocalDecl fvarId val), + localDecls := s.localDecls.map (replaceFVarIdAtLocalDecl fvarId val) + }; + mkClosureForAux (pushNewVars toProcess (collectFVars (collectFVars {} type) val)) + +structure LetRecClosure := +(localDecls : Array LocalDecl) +(closed : Expr) -- expression used to replace occurrences of the let-rec FVarId +(toLift : LetRecToLift) + +private def mkLetRecClosureFor (toLift : LetRecToLift) (freeVars : Array FVarId) : TermElabM LetRecClosure := do +let lctx := toLift.lctx; +withLCtx lctx toLift.localInstances do +lambdaTelescope toLift.val fun xs val => do + type ← instantiateForall toLift.type xs; + lctx ← getLCtx; + (_, s) ← (mkClosureForAux freeVars).run { localDecls := xs.map fun x => lctx.get! x.fvarId! }; + let type := Closure.mkForall s.localDecls $ Closure.mkForall s.newLetDecls type; + let val := Closure.mkLambda s.localDecls $ Closure.mkLambda s.newLetDecls val; + let c := mkAppN (Lean.mkConst toLift.declName) s.exprArgs; + assignExprMVar toLift.mvarId c; + pure ⟨s.newLocalDecls, c, { toLift with val := val, type := type }⟩ + +private def mkLetRecClosures (letRecsToLift : List LetRecToLift) (freeVarMap : FreeVarMap) : TermElabM (List LetRecClosure) := +letRecsToLift.mapM fun toLift => mkLetRecClosureFor toLift (freeVarMap.find? toLift.fvarId).get! + +/- Mapping from FVarId of mutually recursive functions being defined to "closure" expression. -/ +abbrev Replacement := NameMap Expr + +def insertReplacementForMainFns (r : Replacement) (sectionVars : Array Expr) (mainHeaders : Array DefViewElabHeader) (mainFVars : Array Expr) : Replacement := +mainFVars.size.fold + (fun i (r : Replacement) => + r.insert (mainFVars.get! i).fvarId! (mkAppN (Lean.mkConst (mainHeaders.get! i).declName) sectionVars)) + r + +def insertReplacementForLetRecs (r : Replacement) (letRecClosures : List LetRecClosure) : Replacement := +letRecClosures.foldl (fun (r : Replacement) c => r.insert c.toLift.fvarId c.closed) r + +def Replacement.apply (r : Replacement) (e : Expr) : Expr := +e.replace fun e => match e with + | Expr.fvar fvarId _ => match r.find? fvarId with + | some c => some c + | _ => none + | _ => none + +def pushMain (preDecls : Array PreDeclaration) (sectionVars : Array Expr) (mainHeaders : Array DefViewElabHeader) (mainVals : Array Expr) + : TermElabM (Array PreDeclaration) := +mainHeaders.size.foldM + (fun i (preDecls : Array PreDeclaration) => do + let header := mainHeaders.get! i; + val ← mkLambdaFVars sectionVars (mainVals.get! i); + type ← mkForallFVars sectionVars header.type; + pure $ preDecls.push { + kind := header.kind, + declName := header.declName, + modifiers := header.modifiers, + type := type, + val := val + }) + preDecls + +def pushLetRecs (preDecls : Array PreDeclaration) (letRecClosures : List LetRecClosure) (kind : DefKind) : Array PreDeclaration := +letRecClosures.foldl + (fun (preDecls : Array PreDeclaration) (c : LetRecClosure) => + let type := Closure.mkForall c.localDecls c.toLift.type; + let val := Closure.mkLambda c.localDecls c.toLift.val; + preDecls.push { + kind := kind, + declName := c.toLift.declName, + modifiers := { attrs := c.toLift.attrs }, + type := type, + val := val + }) + preDecls + +def getKindForLetRecs (mainHeaders : Array DefViewElabHeader) : DefKind := +if mainHeaders.any fun h => h.kind.isTheorem then DefKind.«theorem» +else DefKind.«def» + +/- +- `sectionVars`: The section variables used in the `mutual` block. +- `mainHeaders`: The elaborated header of the top-level definitions being defined by the mutual block. +- `mainFVars`: The auxiliary variables used to represent the top-level definitions being defined by the mutual block. +- `mainVals`: The elaborated value for the top-level definitions +- `letRecsToLift`: The let-rec's definitions that need to be lifted +-/ +def main (sectionVars : Array Expr) (mainHeaders : Array DefViewElabHeader) (mainFVars : Array Expr) (mainVals : Array Expr) (letRecsToLift : List LetRecToLift) + : TermElabM (Array PreDeclaration) := do +-- Store in recFVarIds the fvarId of every function being defined by the mutual block. +let mainFVarIds := mainFVars.map Expr.fvarId!; +let recFVarIds := (letRecsToLift.toArray.map fun toLift => toLift.fvarId) ++ mainFVarIds; +-- Compute the set of free variables (excluding `recFVarIds`) for each let-rec. +mctx ← getMCtx; +let freeVarMap := mkFreeVarMap mctx sectionVars mainFVarIds recFVarIds letRecsToLift; +resetZetaFVarIds; +withTrackingZeta do +-- By checking `toLift.type` and `toLift.val` we populate `zetaFVarIds`. See comments at `src/Lean/Meta/Closure.lean`. +letRecsToLift.forM fun toLift => withLCtx toLift.lctx toLift.localInstances do { liftM $ check toLift.type; liftM $ check toLift.val }; +letRecClosures ← mkLetRecClosures letRecsToLift freeVarMap; +-- mkLetRecClosures assign metavariables that were placeholders for the lifted declarations. +mainVals ← mainVals.mapM instantiateMVars; +mainHeaders ← mainHeaders.mapM instantiateMVarsAtHeader; +letRecClosures ← letRecClosures.mapM fun closure => do { toLift ← instantiateMVarsAtLetRecToLift closure.toLift; pure { closure with toLift := toLift } }; +-- Replace fvarIds for functions being defined with closed terms +let r := insertReplacementForMainFns {} sectionVars mainHeaders mainFVars; +let r := insertReplacementForLetRecs r letRecClosures; +let mainVals := mainVals.map r.apply; +let mainHeaders := mainHeaders.map fun h => { h with type := r.apply h.type }; +let letRecClosures := letRecClosures.map fun c => { c with toLift := { c.toLift with type := r.apply c.toLift.type, val := r.apply c.toLift.val } }; +let letRecKind := getKindForLetRecs mainHeaders; +pushMain (pushLetRecs #[] letRecClosures letRecKind) sectionVars mainHeaders mainVals + +end MutualClosure def elabMutualDef (vars : Array Expr) (views : Array DefView) : TermElabM Unit := do scopeLevelNames ← getLevelNames; @@ -339,19 +567,12 @@ withFunLocalDecls headers fun funFVars => do letRecsToLift ← letRecsToLift.mapM instantiateMVarsAtLetRecToLift; checkLetRecsToLiftTypes funFVars letRecsToLift; withUsedWhen vars headers values letRecsToLift (not $ isTheorem views) fun vars => do - let values := values.map $ replaceFunFVarsWithConsts headers funFVars vars; - let letRecsToLift := letRecsToLift.map fun toLift => { toLift with val := replaceFunFVarsWithConsts headers funFVars vars toLift.val }; - values ← values.mapM $ mkLambdaFVars vars; - headers ← headers.mapM fun header => do { type ← mkForallFVars vars header.type; pure { header with type := type } }; - letRecsToLift ← LetRecClosure.main letRecsToLift; - values ← values.mapM instantiateMVars; + preDecls ← MutualClosure.main vars headers funFVars values letRecsToLift; -- TODO - values.forM fun val => IO.println (toString val); - letRecsToLift.forM fun toLift => IO.println (toString toLift.declName ++ " := " ++ toString toLift.val); + preDecls.forM fun preDecl => IO.println (toString preDecl.declName ++ " : " ++ toString preDecl.type ++ " :=\n" ++ toString preDecl.val ++ "\n"); throwError "WIP mutual def" end Term - namespace Command def elabMutualDef (ds : Array Syntax) : CommandElabM Unit := do diff --git a/src/Lean/MetavarContext.lean b/src/Lean/MetavarContext.lean index 3957859391..8e6401f9c6 100644 --- a/src/Lean/MetavarContext.lean +++ b/src/Lean/MetavarContext.lean @@ -397,6 +397,21 @@ m.dAssignment.contains mvarId def eraseDelayed (m : MetavarContext) (mvarId : MVarId) : MetavarContext := { m with dAssignment := m.dAssignment.erase mvarId } +/- Given a sequence of delayed assignments + ``` + mvarId₁ := mvarId₂ ...; + ... + mvarIdₙ := mvarId_root ... -- where `mvarId_root` is not delayed assigned + ``` + in `mctx`, `getDelayedRoot mctx mvarId₁` return `mvarId_root`. + If `mvarId₁` is not delayed assigned then return `mvarId₁` -/ +partial def getDelayedRoot (m : MetavarContext) : MVarId → MVarId +| mvarId => match getDelayedAssignment? m mvarId with + | some d => match d.val.getAppFn with + | Expr.mvar mvarId _ => getDelayedRoot mvarId + | _ => mvarId + | none => mvarId + def isLevelAssignable (mctx : MetavarContext) (mvarId : MVarId) : Bool := match mctx.lDepth.find? mvarId with | some d => d == mctx.depth