diff --git a/src/Lean/Elab/MutualDef.lean b/src/Lean/Elab/MutualDef.lean index cadb1ce9ae..8374f49178 100644 --- a/src/Lean/Elab/MutualDef.lean +++ b/src/Lean/Elab/MutualDef.lean @@ -51,7 +51,7 @@ if newHeader.kind.isTheorem && newHeader.modifiers.isUnsafe then if newHeader.kind.isTheorem && newHeader.modifiers.isPartial then throwError "'partial' theorems are not allowed, 'partial' is a code generation directive" if newHeader.kind.isTheorem && newHeader.modifiers.isNoncomputable then - throwError "'theorem' subsumes 'noncomputable', code is not generated for theorems"; + throwError "'theorem' subsumes 'noncomputable', code is not generated for theorems" if newHeader.modifiers.isNoncomputable && newHeader.modifiers.isUnsafe then throwError "'noncomputable unsafe' is not allowed" if newHeader.modifiers.isNoncomputable && newHeader.modifiers.isPartial then @@ -59,7 +59,7 @@ if newHeader.modifiers.isNoncomputable && newHeader.modifiers.isPartial then if newHeader.modifiers.isPartial && newHeader.modifiers.isUnsafe then throwError "'unsafe' subsumes 'partial'" if h : 0 < prevHeaders.size then - let firstHeader := prevHeaders.get ⟨0, h⟩; + let firstHeader := prevHeaders.get ⟨0, h⟩ try unless newHeader.levelNames == firstHeader.levelNames do throwError "universe parameters mismatch" @@ -187,7 +187,7 @@ pure { toLift with type := type, val := val } private def typeHasRecFun (type : Expr) (funFVars : Array Expr) (letRecsToLift : List LetRecToLift) : Option FVarId := let occ? := type.find? fun e => match e with | Expr.fvar fvarId _ => funFVars.contains e || letRecsToLift.any fun toLift => toLift.fvarId == fvarId - | _ => false; + | _ => false match occ? with | some (Expr.fvar fvarId _) => some fvarId | _ => none @@ -265,28 +265,26 @@ Note that `g` is not a free variable at `(let g : B := ?m₂; body)`. We recover `f` depends on `g` because it contains `m₂` -/ private def mkInitialUsedFVarsMap (mctx : MetavarContext) (sectionVars : Array Expr) (mainFVarIds : Array FVarId) (letRecsToLift : List LetRecToLift) - : UsedFVarsMap := -let sectionVarSet := sectionVars.foldl (fun (s : NameSet) (var : Expr) => s.insert var.fvarId!) {} -let usedFVarMap := mainFVarIds.foldl - (fun (usedFVarMap : UsedFVarsMap) mainFVarId => - usedFVarMap.insert mainFVarId sectionVarSet) - {} -letRecsToLift.foldl - (fun (usedFVarMap : UsedFVarsMap) toLift => - let state := Lean.collectFVars {} toLift.val - let state := Lean.collectFVars state toLift.type - let set := state.fvarSet - /- toLift.val may contain metavariables that are placeholders for nested let-recs. We should collect the fvarId - for the associated let-rec because we need this information to compute the fixpoint later. -/ - let mvarIds := (toLift.val.collectMVars {}).result - let set := mvarIds.foldl - (fun (set : NameSet) (mvarId : MVarId) => - match letRecsToLift.findSome? fun (toLift : LetRecToLift) => if toLift.mvarId == mctx.getDelayedRoot mvarId then some toLift.fvarId else none with - | some fvarId => set.insert fvarId - | none => set) - set - usedFVarMap.insert toLift.fvarId set) - usedFVarMap + : UsedFVarsMap := do +let sectionVarSet := {} +for var in sectionVars do + sectionVarSet := sectionVarSet.insert var.fvarId! +let usedFVarMap := {} +for mainFVarId in mainFVarIds do + usedFVarMap := usedFVarMap.insert mainFVarId sectionVarSet +for toLift in letRecsToLift do + let state := Lean.collectFVars {} toLift.val + let state := Lean.collectFVars state toLift.type + let set := state.fvarSet + /- toLift.val may contain metavariables that are placeholders for nested let-recs. We should collect the fvarId + for the associated let-rec because we need this information to compute the fixpoint later. -/ + let mvarIds := (toLift.val.collectMVars {}).result + for mvarId in mvarIds do + match letRecsToLift.findSome? fun (toLift : LetRecToLift) => if toLift.mvarId == mctx.getDelayedRoot mvarId then some toLift.fvarId else none with + | some fvarId => set := set.insert fvarId + | none => pure () + usedFVarMap := usedFVarMap.insert toLift.fvarId set +pure usedFVarMap /- The let-recs may invoke each other. Example: @@ -328,7 +326,7 @@ s₂.foldM if s₁.contains k then pure s₁ else do - markModified; + markModified pure $ s₁.insert k) s₁ @@ -348,7 +346,7 @@ match usedFVarsMap.find? fvarId with not in the context of the let-rec associated with fvarId. We filter these out-of-context free variables later. -/ | some otherFVarIds => merge fvarIdsNew otherFVarIds) - fvarIds; + fvarIds modifyUsedFVars fun usedFVars => usedFVars.insert fvarId fvarIdsNew private partial def fixpoint : Unit → M Unit @@ -369,23 +367,23 @@ abbrev FreeVarMap := NameMap (Array FVarId) private def mkFreeVarMap (mctx : MetavarContext) (sectionVars : Array Expr) (mainFVarIds : Array FVarId) - (recFVarIds : Array FVarId) (letRecsToLift : List LetRecToLift) : FreeVarMap := + (recFVarIds : Array FVarId) (letRecsToLift : List LetRecToLift) : FreeVarMap := do let usedFVarsMap := mkInitialUsedFVarsMap mctx sectionVars mainFVarIds letRecsToLift let letRecFVarIds := letRecsToLift.map fun toLift => toLift.fvarId let usedFVarsMap := FixPoint.run letRecFVarIds usedFVarsMap -letRecsToLift.foldl - (fun (freeVarMap : FreeVarMap) toLift => - let lctx := toLift.lctx - let fvarIdsSet := (usedFVarsMap.find? toLift.fvarId).get! - let fvarIds := fvarIdsSet.fold - (fun (fvarIds : Array FVarId) (fvarId : FVarId) => - if lctx.contains fvarId && !recFVarIds.contains fvarId then - fvarIds.push fvarId - else - fvarIds) - #[] - freeVarMap.insert toLift.fvarId fvarIds) - {} +let freeVarMap := {} +for toLift in letRecsToLift do + let lctx := toLift.lctx + let fvarIdsSet := (usedFVarsMap.find? toLift.fvarId).get! + let fvarIds := fvarIdsSet.fold + (fun (fvarIds : Array FVarId) (fvarId : FVarId) => + if lctx.contains fvarId && !recFVarIds.contains fvarId then + fvarIds.push fvarId + else + fvarIds) + #[] + freeVarMap := freeVarMap.insert toLift.fvarId fvarIds +pure freeVarMap structure ClosureState := (newLocalDecls : Array LocalDecl := #[]) @@ -431,7 +429,7 @@ private partial def mkClosureForAux : Array FVarId → StateRefT ClosureState Te let toProcess ← pushLocalDecl toProcess fvarId userName type bi mkClosureForAux toProcess | LocalDecl.ldecl _ _ userName type val _ => - let zetaFVarIds ← getZetaFVarIds; + let zetaFVarIds ← getZetaFVarIds if !zetaFVarIds.contains fvarId then /- Non-dependent let-decl. See comment at src/Lean/Meta/Closure.lean -/ let toProcess ← pushLocalDecl toProcess fvarId userName type