diff --git a/src/Lean/Compiler/LCNF/Bind.lean b/src/Lean/Compiler/LCNF/Bind.lean index 114980e440..63a0034708 100644 --- a/src/Lean/Compiler/LCNF/Bind.lean +++ b/src/Lean/Compiler/LCNF/Bind.lean @@ -35,9 +35,7 @@ where | .default k => return .default (← go k) if alts.isEmpty then throwError "`Code.bind` failed, empty `cases` found" - let mut resultType ← alts[0]!.inferType - for alt in alts[1:] do - resultType := joinTypes resultType (← alt.inferType) + let resultType ← mkCasesResultType alts return .cases { c with alts, resultType } | .return fvarId => f fvarId | .jmp fvarId .. => diff --git a/src/Lean/Compiler/LCNF/CompilerM.lean b/src/Lean/Compiler/LCNF/CompilerM.lean index 7d94c8a606..2c94aea510 100644 --- a/src/Lean/Compiler/LCNF/CompilerM.lean +++ b/src/Lean/Compiler/LCNF/CompilerM.lean @@ -294,6 +294,9 @@ end @[inline] def normCode [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m] (code : Code) : m Code := do normCodeImp code (← getSubst) +def replaceExprFVars (e : Expr) (s : FVarSubst) : CompilerM Expr := + (normExpr e : ReaderT FVarSubst CompilerM Expr).run s + def replaceFVars (code : Code) (s : FVarSubst) : CompilerM Code := (normCode code : ReaderT FVarSubst CompilerM Code).run s diff --git a/src/Lean/Compiler/LCNF/InferType.lean b/src/Lean/Compiler/LCNF/InferType.lean index c7eaa05d35..5bfe15a2fb 100644 --- a/src/Lean/Compiler/LCNF/InferType.lean +++ b/src/Lean/Compiler/LCNF/InferType.lean @@ -219,4 +219,12 @@ where return type termination_by go i => params.size - i +def mkCasesResultType (alts : Array Alt) : CompilerM Expr := do + if alts.isEmpty then + throwError "`Code.bind` failed, empty `cases` found" + let mut resultType ← alts[0]!.inferType + for alt in alts[1:] do + resultType := joinTypes resultType (← alt.inferType) + return resultType + end Lean.Compiler.LCNF diff --git a/src/Lean/Compiler/LCNF/ToLCNF.lean b/src/Lean/Compiler/LCNF/ToLCNF.lean index f546429529..240088b8cb 100644 --- a/src/Lean/Compiler/LCNF/ToLCNF.lean +++ b/src/Lean/Compiler/LCNF/ToLCNF.lean @@ -36,6 +36,116 @@ inductive Element where | unreach (fvarId : FVarId) deriving Inhabited +/-- State for `BindCasesM` monad -/ +structure BindCasesM.State where + /-- New auxiliary join points. -/ + jps : Array FunDecl := #[] + /-- Mapping from `_alt.` variables to new join points -/ + map : FVarIdMap FVarId := {} + +/-- Auxiliary monad for implementing `bindCases` -/ +abbrev BindCasesM := StateRefT BindCasesM.State CompilerM + +/-- +This method returns code that at each exit point of `cases`, it jumps to `jpDecl`. +It is similar to `Code.bind`, but we add special support for `inlineMatcher`. +The `inlineMatcher` function inlines the auxiliary `_match_` declarations. +To make sure there is no code duplication, `inlineMatcher` creates auxiliary declarations `_alt.`. +We can say the `_alt.` declarations are pre join points. For each auxiliary declaration used at +an exit point of `cases`, this method creates an new auxiliary join point that invokes `_alt.`, +and then jumps to `jpDecl`. The goal is to make sure the auxiliary join point is the only occurrence +of `_alt.`, then `simp` will inline it. +That is, our goal is to try to promote the pre join points `_alt.` into a proper join point. +-/ +partial def bindCases (jpDecl : FunDecl) (cases : Cases) : CompilerM Code := do + let (alts, s) ← visitAlts cases.alts |>.run {} + let resultType ← mkCasesResultType alts + let mut result := .cases { cases with alts, resultType } + for decl in s.jps do + result := .jp decl result + return .jp jpDecl result +where + visitAlts (alts : Array Alt) : BindCasesM (Array Alt) := + alts.mapM fun alt => return alt.updateCode (← go alt.getCode) + + findFun? (f : FVarId) : CompilerM (Option FunDecl) := do + if let some funDecl ← findFunDecl? f then + return funDecl + else if let .ldecl (value := .fvar f') .. ← getLocalDecl f then + findFun? f' + else + return none + + go (code : Code) : BindCasesM Code := do + match code with + | .let decl k => + if let .return fvarId := k then + /- + Check whether the current let-declaration is of the form + ``` + let _x := _alt. args + return _x + ``` + where `_alt.` is an auxiliary declaration created by `inlineMatcher` + -/ + if decl.fvarId == fvarId && decl.value.isApp && decl.value.getAppFn.isFVar then + let f := decl.value.getAppFn.fvarId! + let localDecl ← getLocalDecl f + if localDecl.userName.getPrefix == `_alt then + if let some funDecl ← findFun? f then + let args := decl.value.getAppArgs + eraseFVar decl.fvarId + if let some altJp := (← get).map.find? f then + /- We already have an auxiliary join point for `f`, then, we just use it. -/ + return .jmp altJp args + else + /- + We have not created a join point for `f` yet. + The join point has the form + ``` + jp altJp jpParams := + let _x := f jpParams + jmp jpDecl _x + ``` + Then, we replace the current `let`-declaration with `jmp altJp args` + -/ + let mut jpParams := #[] + let mut subst := {} + let mut jpArgs := #[] + /- Remark: `funDecl.params.size` may be greater than `args.size`. -/ + for param in funDecl.params[:args.size] do + let type ← replaceExprFVars param.type subst + let paramNew ← mkAuxParam type + jpParams := jpParams.push paramNew + let arg := .fvar paramNew.fvarId + subst := subst.insert param.fvarId arg + jpArgs := jpArgs.push arg + let letDecl ← mkAuxLetDecl (mkAppN decl.value.getAppFn jpArgs) + let jpValue := .let letDecl (.jmp jpDecl.fvarId #[.fvar letDecl.fvarId]) + let altJp ← mkAuxJpDecl jpParams jpValue + modify fun { jps, map } => { + jps := jps.push altJp + map := map.insert f altJp.fvarId + } + return .jmp altJp.fvarId args + return .let decl (← go k) + | .fun decl k => return .fun decl (← go k) + | .jp decl k => + let value ← go decl.value + let type ← value.inferParamType decl.params + let decl ← decl.update' type value + return .jp decl (← go k) + | .cases c => + let alts ← c.alts.mapM fun + | .alt ctorName params k => return .alt ctorName params (← go k) + | .default k => return .default (← go k) + if alts.isEmpty then + throwError "`Code.bind` failed, empty `cases` found" + let resultType ← mkCasesResultType alts + return .cases { c with alts, resultType } + | .return fvarId => return .jmp jpDecl.fvarId #[.fvar fvarId] + | .jmp .. | .unreach .. => return code + def seqToCode (seq : Array Element) (e : Expr) : CompilerM Code := do if let .fvar fvarId := e then go seq seq.size (.return fvarId) @@ -69,8 +179,7 @@ where else /- Create a join point for `c` and jump to it from `cases` -/ let jpDecl ← mkAuxJpDecl' fvarId c - let cases ← (Code.cases cases).bind fun fvarId => return .jmp jpDecl.fvarId #[.fvar fvarId] - go seq (i - 1) (.jp jpDecl cases) + go seq (i - 1) (← bindCases jpDecl cases) else return c