From 06c6002d454bc61a2c4686ee97feebd09e8e83e2 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Wed, 2 Sep 2020 18:53:18 -0700 Subject: [PATCH] feat: lift nested 'let-rec's --- src/Lean/Elab/LetRec.lean | 2 +- src/Lean/Elab/MutualDef.lean | 133 +++++++++++++++++++++++++++++++++++ 2 files changed, 134 insertions(+), 1 deletion(-) diff --git a/src/Lean/Elab/LetRec.lean b/src/Lean/Elab/LetRec.lean index e2a00a54b7..25570f03ed 100644 --- a/src/Lean/Elab/LetRec.lean +++ b/src/Lean/Elab/LetRec.lean @@ -49,7 +49,7 @@ decls ← (letRec.getArg 1).getArgs.getSepElems.mapM fun attrDeclStx => do { type ← mkForallFVars xs type; pure (type, xs.size) }; - mvar ← mkFreshExprMVar type MetavarKind.synthetic; + mvar ← mkFreshExprMVar type MetavarKind.syntheticOpaque; valStx ← if decl.isOfKind `Lean.Parser.Term.letIdDecl then pure $ decl.getArg 4 diff --git a/src/Lean/Elab/MutualDef.lean b/src/Lean/Elab/MutualDef.lean index 7f751abf6c..4444b04bc3 100644 --- a/src/Lean/Elab/MutualDef.lean +++ b/src/Lean/Elab/MutualDef.lean @@ -199,6 +199,131 @@ letRecsToLift.forM fun toLift => do fnName ← getFunName fvarId letRecsToLift; throwErrorAt toLift.ref ("invalid type in 'let rec', it uses '" ++ fnName ++ "' which is being defined simultaneously") +private def replaceFunFVarsWithConsts (headers : Array DefViewElabHeader) (funFVars : Array Expr) (vars : Array Expr) (e : Expr) : Expr := +e.replace fun x => match x with + | Expr.fvar fvarId _ => + match funFVars.indexOf x with + | some idx => match headers.get? idx.val with + | some header => some $ mkAppN (Lean.mkConst header.declName) vars -- Remark: we add the universe levels later + | none => none + | none => none + | _ => none + +/- A mapping from -/ +abbrev Closures := NameMap NameSet + +namespace LetRecClosure + +structure State := +(closures : Closures := {}) +(modified : Bool := false) + +abbrev M := ReaderT (List FVarId) $ StateM State + +private def isModified : M Bool := do +s ← get; pure s.modified + +private def resetModified : M Unit := +modify fun s => { s with modified := false } + +private def markModified : M Unit := +modify fun s => { s with modified := true } + +private def getClosures : M Closures := +do s ← get; pure s.closures + +private def modifyClosures (f : Closures → Closures) : M Unit := +modify fun s => { s with closures := f s.closures } + +-- merge s₂ into s₁ +private def merge (s₁ s₂ : NameSet) : M NameSet := +s₂.foldM + (fun (s₁ : NameSet) k => + if s₁.contains k then + pure s₁ + else do + markModified; + pure $ s₁.insert k) + s₁ + +private def updateClosureOf (fvarId : FVarId) : M Unit := do +closures ← getClosures; +match closures.find? fvarId with +| none => pure () +| some closure => do + closureNew ← closure.foldM + (fun (closureNew : NameSet) (fvarId' : FVarId) => + if fvarId == fvarId' then + pure closureNew + else + match closures.find? fvarId' with + | none => pure closureNew + -- We are being sloppy here `otherClosure` may contain free variables that are not in the context of the let-rec associated with fvarId. + -- We filter these out-of-context free variables later. + | some otherClosure => merge closureNew otherClosure) + closure; + modifyClosures fun closures => closures.insert fvarId closureNew + +private partial def closureFixpoint : Unit → M Unit +| _ => do + resetModified; + letRecFVarIds ← read; + letRecFVarIds.forM updateClosureOf; + whenM isModified $ closureFixpoint () + +private def mkInitialClosures (letRecsToLift : List LetRecToLift) : Closures := +letRecsToLift.foldl + (fun (closures : Closures) toLift => + let state := Lean.collectFVars {} toLift.val; + let state := Lean.collectFVars state toLift.type; + closures.insert toLift.fvarId state.fvarSet) + {} + +private def mkClosures (letRecsToLift : List LetRecToLift) : Closures := +let letRecFVarIds := letRecsToLift.map fun toLift => toLift.fvarId; +let closures := mkInitialClosures letRecsToLift; +let (_, s) := ((closureFixpoint ()).run letRecFVarIds).run { closures := closures }; +s.closures + +private def nameSetToFVars (s : NameSet) (lctx : LocalContext) (letRecFVarIds : List FVarId) : Array FVarId := +let fvarIds := s.fold + (fun (fvarIds : Array FVarId) (fvarId : FVarId) => + if lctx.contains fvarId && !letRecFVarIds.contains fvarId then + fvarIds.push fvarId + else + fvarIds) + #[]; +fvarIds.qsort fun fvarId₁ fvarId₂ => (lctx.get! fvarId₁).index < (lctx.get! fvarId₂).index + +def main (letRecsToLift : List LetRecToLift) : TermElabM (List LetRecToLift) := do +let letRecFVarIds := letRecsToLift.map fun toLift => toLift.fvarId; +let closures := mkClosures letRecsToLift; +-- Assign metavariables associated with each let-rec +ps ← letRecsToLift.mapM fun toLift => do { + let s := (closures.find? toLift.fvarId).get!; + let lctx := toLift.lctx; + let xs := (nameSetToFVars s lctx letRecFVarIds).map mkFVar; + assignExprMVar toLift.mvarId (mkAppN (Lean.mkConst toLift.declName) xs); + pure (xs, toLift) +}; +ps.mapM fun (p : Array Expr × LetRecToLift) => do + let (xs, toLift) := p; + type ← instantiateMVars toLift.type; + val ← instantiateMVars toLift.val; + let lctx := toLift.lctx; + -- Val may contain outer let-recs, we must replace them with constants + let val := val.replace fun x => match x with + | Expr.fvar fvarId _ => match ps.find? fun (p : Array Expr × LetRecToLift) => p.2.fvarId == fvarId with + | some p => some (mkAppN (Lean.mkConst p.2.declName) p.1) + | none => none + | _ => none; + -- Apply closure + let type := lctx.mkForall xs type; + let val := lctx.mkLambda xs val; + pure { toLift with type := type, val := val } + +end LetRecClosure + def elabMutualDef (vars : Array Expr) (views : Array DefView) : TermElabM Unit := do scopeLevelNames ← getLevelNames; headers ← elabHeaders views; @@ -213,7 +338,15 @@ withFunLocalDecls headers fun funFVars => do letRecsToLift ← letRecsToLift.mapM instantiateMVarsAtLetRecToLift; checkLetRecsToLiftTypes funFVars letRecsToLift; withUsedWhen vars headers values letRecsToLift (not $ isTheorem views) fun vars => do + let values := values.map $ replaceFunFVarsWithConsts headers funFVars vars; + let letRecsToLift := letRecsToLift.map fun toLift => { toLift with val := replaceFunFVarsWithConsts headers funFVars vars toLift.val }; + values ← values.mapM $ mkLambdaFVars vars; + headers ← headers.mapM fun header => do { type ← mkForallFVars vars header.type; pure { header with type := type } }; + letRecsToLift ← LetRecClosure.main letRecsToLift; + values ← values.mapM instantiateMVars; + -- TODO values.forM fun val => IO.println (toString val); + letRecsToLift.forM fun toLift => IO.println (toString toLift.declName ++ " := " ++ toString toLift.val); throwError "WIP mutual def" end Term