diff --git a/src/Lean/Compiler/LCNF/CompilerM.lean b/src/Lean/Compiler/LCNF/CompilerM.lean index 142ddaa97a..f56431d1b6 100644 --- a/src/Lean/Compiler/LCNF/CompilerM.lean +++ b/src/Lean/Compiler/LCNF/CompilerM.lean @@ -35,8 +35,11 @@ def getLocalDecl (fvarId : FVarId) : CompilerM LocalDecl := do let some decl := (← get).lctx.localDecls.find? fvarId | throwError "unknown free variable {fvarId.name}" return decl +def findFunDecl? (fvarId : FVarId) : CompilerM (Option FunDecl) := + return (← get).lctx.funDecls.find? fvarId + def getFunDecl (fvarId : FVarId) : CompilerM FunDecl := do - let some decl := (← get).lctx.funDecls.find? fvarId | throwError "unknown local function {fvarId.name}" + let some decl ← findFunDecl? fvarId | throwError "unknown local function {fvarId.name}" return decl @[inline] def modifyLCtx (f : LCtx → LCtx) : CompilerM Unit := do diff --git a/src/Lean/Compiler/LCNF/Simp.lean b/src/Lean/Compiler/LCNF/Simp.lean index 5956de0e43..bca408082e 100644 --- a/src/Lean/Compiler/LCNF/Simp.lean +++ b/src/Lean/Compiler/LCNF/Simp.lean @@ -3,36 +3,14 @@ Copyright (c) 2022 Microsoft Corporation. All rights reserved. Released under Apache 2.0 license as described in the file LICENSE. Authors: Leonardo de Moura -/ -#exit -- TODO: port to new LCNF -import Lean.Compiler.CompilerM -import Lean.Compiler.Decl -import Lean.Compiler.Stage1 +import Lean.Util.Recognizers import Lean.Compiler.InlineAttrs +import Lean.Compiler.LCNF.CompilerM +import Lean.Compiler.LCNF.Stage1 -namespace Lean.Compiler +namespace Lean.Compiler.LCNF namespace Simp -partial def findLambdaCore? (lctx : LocalContext) (e : Expr) : Option LocalDecl := - match e with - | .fvar fvarId => - if let some d@(.ldecl (value := v) ..) := lctx.find? fvarId then - if v.isLambda then some d else findLambdaCore? lctx v - else - none - | .mdata _ e => findLambdaCore? lctx e - | _ => none - -partial def findLambda? (e : Expr) : CompilerM (Option LocalDecl) := - return findLambdaCore? (← getLCtx) e - -partial def findExpr (e : Expr) (skipMData := true): CompilerM Expr := do - match e with - | .fvar fvarId => - let some (.ldecl (value := v) ..) ← findDecl? fvarId | return e - findExpr v - | .mdata _ e' => if skipMData then findExpr e' else return e - | _ => return e - /-- Local function usage information used to decide whether it should be inlined or not. The information is an approximation, but it is on the "safe" side. That is, if we tagged @@ -41,7 +19,7 @@ a function with `.once`, then it is applied only once. A local function may be m a big problem in practice because we run the simplifier multiple times, and this information is recomputed from scratch at the beginning of each simplification step. -/ -inductive LocalFunInfo where +inductive FunDeclInfo where | /-- Local function is applied once, and must be inlined. -/ @@ -51,40 +29,67 @@ inductive LocalFunInfo where if it is small. -/ many + | /-- + Function must be inlined. + -/ + mustInline deriving Repr, Inhabited /-- Local function declaration statistics. - -Remark: we use the `userName` as the key. -/ -structure LocalFunInfoMap where +structure FunDeclInfoMap where /-- Mapping from local function name to inlining information. -/ - map : Std.HashMap Name LocalFunInfo := {} + map : Std.HashMap FVarId FunDeclInfo := {} deriving Inhabited -def LocalFunInfoMap.format (s : LocalFunInfoMap) : Format := Id.run do +def FunDeclInfoMap.format (s : FunDeclInfoMap) : CompilerM Format := do let mut result := Format.nil - for (k, n) in s.map.toList do - result := result ++ "\n" ++ f!"{k} ↦ {repr n}" + for (fvarId, info) in s.map.toList do + let localDecl ← getLocalDecl fvarId + result := result ++ "\n" ++ f!"{localDecl.userName} ↦ {repr info}" return result -instance : ToFormat LocalFunInfoMap where - format := LocalFunInfoMap.format - /-- Add new occurrence for the local function with binder name `key`. -/ -def LocalFunInfoMap.add (s : LocalFunInfoMap) (key : Name) : LocalFunInfoMap := +def FunDeclInfoMap.add (s : FunDeclInfoMap) (fvarId : FVarId) : FunDeclInfoMap := match s with | { map } => - match map.find? key with - | some .once => { map := map.insert key .many } - | none => { map := map.insert key .once } + match map.find? fvarId with + | some .once => { map := map.insert fvarId .many } + | none => { map := map.insert fvarId .once } | _ => { map } +/-- +Add new occurrence for the local function with binder name `key`. +-/ +def FunDeclInfoMap.addMustInline (s : FunDeclInfoMap) (fvarId : FVarId) : FunDeclInfoMap := + match s with + | { map } => { map := map.insert fvarId .mustInline } + +partial def findFunDecl? (e : Expr) : CompilerM (Option FunDecl) := do + match e with + | .fvar fvarId => + if let some decl ← LCNF.findFunDecl? fvarId then + return some decl + else if let .ldecl (value := v) .. ← getLocalDecl fvarId then + findFunDecl? v + else + return none + | .mdata _ e => findFunDecl? e + | _ => return none + +partial def findExpr (e : Expr) (skipMData := true) : CompilerM Expr := do + match e with + | .fvar fvarId => + let .ldecl (value := v) .. ← getLocalDecl fvarId | return e + findExpr v + | .mdata _ e' => if skipMData then findExpr e' else return e + | _ => return e + structure Config where smallThreshold : Nat := 1 @@ -92,10 +97,8 @@ structure Context where config : Config := {} structure State where - /-- - (Approximate) information for deciding whether to inline local function declarations. - -/ - localInfoMap : LocalFunInfoMap := {} + subst : FVarSubst := {} + funDeclInfoMap : FunDeclInfoMap := {} /-- `true` if some simplification was performed in the current simplification pass. -/ @@ -125,120 +128,31 @@ structure State where This is a performance counter. -/ inlineLocal : Nat := 0 - deriving Inhabited abbrev SimpM := ReaderT Context $ StateRefT State CompilerM -/- -Ensure binder names are unique, and update local function information. -If `mustInline = true`, then local functions in `e` are marked with binders of the -form `_mustInline.`. -Remark: we used to store the `mustInline` information in the map `localInfoMap`, -using a `.mustInline` constructor at `LocalFunInfo`. However, this was incorrect -because there is no guarantee that we will be able to inline all occurrences of the -function in the current `simp` step. Since, we recompute `localInfoMap` from scratch -at the beginning of each compiler pass, the information was being lost. --/ - -structure Internalize.State where - nextIdx : Nat - localInfoMap : LocalFunInfoMap - -private def updateFunInfo (key : Name) (mustInline : Bool) : StateM Internalize.State Unit := - unless mustInline do - modify fun s => { s with localInfoMap := s.localInfoMap.add key } - -/-- -`instantiateRevInternalize` implementation. --/ -private def instantiateRevInternalizeCore (lctx : LocalContext) (e : Expr) (args : Array Expr) (mustInline : Bool) : StateM Internalize.State Expr := - go e {} +partial def updateFunDeclInfo (code : Code) (mustInline := false) : SimpM Unit := + go code where - /-- Auxiliary functions for instantiating `args` in types. -/ - inst (e : Expr) (offset : Nat) : Expr := - match e with - | .sort .. | .lit .. | .const .. | .mvar .. | .fvar .. => e - | .mdata k b => .mdata k (inst b offset) - | .proj s i b => .proj s i (inst b offset) - | .app f a => if offset >= e.looseBVarRange then e else .app (inst f offset) (inst a offset) - | .bvar idx => if idx >= offset then args[args.size - (idx - offset) - 1]! else e - | .forallE n d b bi => if offset >= e.looseBVarRange then e else .forallE n (inst d offset) (inst b (offset + 1)) bi - | .lam n d b bi => if offset >= e.looseBVarRange then e else .lam n (inst d offset) (inst b (offset + 1)) bi - | .letE n t v b nd => if offset >= e.looseBVarRange then e else .letE n (inst t offset) (inst v offset) (inst b (offset + 1)) nd + go (code : Code) : SimpM Unit := do + match code with + | .let decl k => + if decl.value.isApp then + if let some funDecl ← findFunDecl? decl.value.getAppFn then + modify fun s => { s with funDeclInfoMap := s.funDeclInfoMap.add funDecl.fvarId } + go k + | .fun decl k => + if mustInline then + modify fun s => { s with funDeclInfoMap := s.funDeclInfoMap.addMustInline decl.fvarId } + go decl.value; go k + | .jp decl k => go decl.value; go k + | .cases c => c.alts.forM fun alt => go alt.getCode + | .return .. | .jmp .. | .unreach .. => return () - go (e : Expr) (ctx : Std.PArray (Option Name)) : StateM Internalize.State Expr := do - let instantiate (e : Expr) := if args.size == 0 then e else inst e ctx.size - let updtBVar (idx : Nat) := - let offset := ctx.size - if idx >= offset then - args[args.size - (idx - offset) - 1]! - else - .bvar idx - match e with - | .sort .. | .lit .. | .const .. | .mvar .. | .fvar .. => return e - | .mdata k b => return .mdata k (← go b ctx) - | .proj s i b => return .proj s i (← go b ctx) - | .app f a => - let f ← go f ctx - let a ← go a ctx - match f with - | .fvar .. => - match findLambdaCore? lctx f with - | some localDecl => updateFunInfo localDecl.userName mustInline - | _ => pure () - | .bvar idx => - match ctx[ctx.size - idx - 1]! with - | some binderName => updateFunInfo binderName mustInline - | none => pure () - | _ => pure () - return .app f a - | .bvar idx => return updtBVar idx - | .forallE .. => return instantiate e - | .lam n d b bi => return .lam n (instantiate d) (← go b (ctx.push none)) bi - | .letE binderName type value body nonDep => - let idx ← modifyGet fun { nextIdx, localInfoMap } => (nextIdx, { nextIdx := nextIdx + 1, localInfoMap }) - let binderName' := - if mustInline && value.isLambda then - .num `_mustInline idx - else match binderName with - | .num p _ => .num p idx - | _ => .num binderName idx - let type := instantiate type - let value ← go value ctx - let ctxVal := match value with - | .lam .. => some binderName' - -- The next two cases simulate findLambdaCore? for `ctx` - | .fvar .. => match findLambdaCore? lctx value with - | some localDecl => some localDecl.userName - | _ => none - | .bvar idx => if idx < ctx.size then ctx[ctx.size - idx - 1]! else none - | _ => none - return .letE binderName' type value (← go body (ctx.push ctxVal)) nonDep - -/-- -This function performs the following operations in the given expression in a single pass. -- Ensure binder names for let-declarations are unique. -- Update local function information. That is, it updates the map `localInfoMap`. -- Apply `e.instantiateRev args`. - -We use it to "internalize" expressions at startup and when performing inlining. --/ -def instantiateRevInternalize (e : Expr) (args : Array Expr) (mustInline := false) : SimpM Expr := do - let lctx ← getLCtx - let nextIdx := (← getThe CompilerM.State).nextIdx - let localInfoMap ← modifyGet fun s => (s.localInfoMap, { s with localInfoMap := {} }) - let (e, { localInfoMap, nextIdx }) := instantiateRevInternalizeCore lctx e args mustInline |>.run { nextIdx, localInfoMap } - modifyThe CompilerM.State fun s => { s with nextIdx } - modify fun s => { s with localInfoMap } - return e - -/-- -This function performs the following operations in the given expression in a single pass. -- Ensure binder names for let-declarations are unique. -- Update local function information. That is, it updates the map `localInfoMap`. --/ -def internalize (e : Expr) (mustInline := false) : SimpM Expr := do - instantiateRevInternalize e #[] mustInline +def isOnceOrMustInline (fvarId : FVarId) : SimpM Bool := do + match (← get).funDeclInfoMap.map.find? fvarId with + | some .once | some .mustInline => return true + | _ => return false def markSimplified : SimpM Unit := modify fun s => { s with simplified := true } @@ -274,12 +188,54 @@ def simpAppApp? (e : Expr) : OptionT SimpM Expr := do markSimplified return mkAppN f e.getAppArgs -def isOnceOrMustInline (binderName : Name) : SimpM Bool := do - if binderName.getPrefix == `_mustInline then - return true - else match (← get).localInfoMap.map.find? binderName with - | some .once => return true - | _ => return false +end Simp + +builtin_initialize + registerTraceClass `Compiler.simp.inline + registerTraceClass `Compiler.simp.inline.info + registerTraceClass `Compiler.simp.stat + registerTraceClass `Compiler.simp.step + registerTraceClass `Compiler.simp.step.new + registerTraceClass `Compiler.simp.projInst + +end Lean.Compiler.LCNF + +#exit -- TODO: port rest of file + + +namespace Lean.Compiler +namespace Simp + + +/- +Ensure binder names are unique, and update local function information. +If `mustInline = true`, then local functions in `e` are marked with binders of the +form `_mustInline.`. +Remark: we used to store the `mustInline` information in the map `localInfoMap`, +using a `.mustInline` constructor at `LocalFunInfo`. However, this was incorrect +because there is no guarantee that we will be able to inline all occurrences of the +function in the current `simp` step. Since, we recompute `localInfoMap` from scratch +at the beginning of each compiler pass, the information was being lost. +-/ + + +/-- +This function performs the following operations in the given expression in a single pass. +- Ensure binder names for let-declarations are unique. +- Update local function information. That is, it updates the map `localInfoMap`. +- Apply `e.instantiateRev args`. + +We use it to "internalize" expressions at startup and when performing inlining. +-/ +def instantiateRevInternalize (e : Expr) (args : Array Expr) (mustInline := false) : SimpM Expr := do + let lctx ← getLCtx + let nextIdx := (← getThe CompilerM.State).nextIdx + let localInfoMap ← modifyGet fun s => (s.localInfoMap, { s with localInfoMap := {} }) + let (e, { localInfoMap, nextIdx }) := instantiateRevInternalizeCore lctx e args mustInline |>.run { nextIdx, localInfoMap } + modifyThe CompilerM.State fun s => { s with nextIdx } + modify fun s => { s with localInfoMap } + return e + def isSmallValue (value : Expr) : SimpM Bool := do lcnfSizeLe value (← read).config.smallThreshold @@ -748,12 +704,4 @@ partial def Decl.simp (decl : Decl) : CoreM Decl := do else return decl -builtin_initialize - registerTraceClass `Compiler.simp.inline - registerTraceClass `Compiler.simp.inline.info - registerTraceClass `Compiler.simp.stat - registerTraceClass `Compiler.simp.step - registerTraceClass `Compiler.simp.step.new - registerTraceClass `Compiler.simp.projInst - end Lean.Compiler