415 lines
15 KiB
Text
415 lines
15 KiB
Text
/-
|
||
Copyright (c) 2022 Henrik Böving. All rights reserved.
|
||
Released under Apache 2.0 license as described in the file LICENSE.
|
||
Authors: Henrik Böving
|
||
-/
|
||
import Lean.Compiler.LCNF.CompilerM
|
||
import Lean.Compiler.LCNF.InferType
|
||
import Lean.Compiler.LCNF.PassManager
|
||
|
||
|
||
namespace Lean.Compiler.LCNF.Simp
|
||
namespace ConstantFold
|
||
|
||
/--
The constant folding monad. On top of `CompilerM` it carries, as mutable state,
an array of auxiliary `CodeDecl`s: the declarations a folder had to create
(e.g. through `mkAuxLetDecl`) in order to build its resulting expression.
-/
abbrev FolderM := StateRefT (Array CodeDecl) CompilerM
|
||
|
||
/--
A constant folder for one particular function. It is handed all arguments of an
application of that function. On success it returns the replacement `Expr`, and
any auxiliary declarations needed to build that expression are recorded in the
`FolderM` state. Returning `none` means the folder does not apply.
-/
abbrev Folder := (args : Array Expr) → FolderM (Option Expr)
|
||
|
||
/--
A typeclass for recognizing and constructing literals of arbitrary types
inside of LCNF.
-/
class Literal (α : Type) where
  /--
  Attempt to turn the provided `Expr` into a value of type `α`, provided it is a
  literal according to whatever notion of "literal" `α` has; returns `none`
  otherwise. This function assumes, without checking, that the provided `Expr`
  does indeed have type `α`.
  -/
  getLit : Expr → CompilerM (Option α)
  /--
  Turn a value of type `α` into a final `Expr` that denotes it as a literal.
  Any auxiliary `let` declarations needed to build that expression are recorded
  in the `FolderM` state. As with `getLit`, what a literal looks like depends
  on `α`.
  -/
  mkLit : α → FolderM Expr

export Literal (getLit mkLit)
|
||
|
||
/--
Like `LCNF.mkAuxLetDecl`, but the freshly created `LetDecl` is also appended to
the `FolderM` state, so it automatically becomes one of the auxiliary
declarations of the fold. Returns the free variable bound by the new declaration.
-/
def mkAuxLetDecl (e : Expr) (prefixName := `_x) : FolderM FVarId := do
  let auxDecl ← LCNF.mkAuxLetDecl e prefixName
  modify fun decls => decls.push (.let auxDecl)
  pure auxDecl.fvarId
|
||
|
||
section Literals
|
||
|
||
/--
Builds the literal expression for `x` via `mkLit` and binds it to a fresh
auxiliary `let` declaration (see `mkAuxLetDecl`), returning the bound variable.
-/
def mkAuxLit [Literal α] (x : α) (prefixName := `_x) : FolderM FVarId := do
  mkAuxLetDecl (← mkLit x) prefixName
|
||
|
||
/--
Returns the natural number that `e` denotes as a literal, following `let`-bound
free variables through `findLetDecl?`. Returns `none` if `e` does not resolve
to a `Nat` literal.
-/
partial def getNatLit (e : Expr) : CompilerM (Option Nat) := do
  match e with
  | .fvar fvarId .. =>
    let some decl ← findLetDecl? fvarId | return none
    getNatLit decl.value
  | .lit (.natVal n) .. => return some n
  | _ => return none
|
||
|
||
/-- Builds the `Nat` literal expression for `n`; no auxiliary declarations are created. -/
def mkNatLit (n : Nat) : FolderM Expr :=
  pure (Expr.lit (.natVal n))
|
||
|
||
-- `Nat` literals are plain `.lit (.natVal _)` expressions.
instance : Literal Nat := ⟨getNatLit, mkNatLit⟩
|
||
|
||
/--
Returns the string that `e` denotes as a literal, following `let`-bound free
variables through `findLetDecl?`. Returns `none` if `e` does not resolve to a
`String` literal.
-/
partial def getStringLit (e : Expr) : CompilerM (Option String) := do
  match e with
  | .fvar fvarId .. =>
    if let some decl ← findLetDecl? fvarId then
      getStringLit decl.value
    else
      return none
  | .lit (.strVal s) .. => return some s
  | _ => return none
|
||
|
||
/-- Builds the `String` literal expression for `n`; no auxiliary declarations are created. -/
def mkStringLit (n : String) : FolderM Expr :=
  pure (Expr.lit (.strVal n))
|
||
|
||
-- `String` literals are plain `.lit (.strVal _)` expressions.
instance : Literal String := ⟨getStringLit, mkStringLit⟩
|
||
|
||
/--
Shared `getLit` implementation for types whose literals are written as
`ofNatName x`, where `x` is a free variable resolving to a `Nat` literal.
`let`-bound free variables are followed through `findLetDecl?`, and the
recovered `Nat` is converted with `ofNat`. (`toNat` is only threaded through
the recursion.)
-/
private partial def getLitAux [Inhabited α] (e : Expr) (ofNat : Nat → α) (ofNatName : Name) (toNat : α → Nat) : CompilerM (Option α) := do
  match e with
  | .fvar fvarId .. =>
    let some decl ← findLetDecl? fvarId | return none
    getLitAux decl.value ofNat ofNatName toNat
  | .app .. =>
    let .const fnName .. := e.getAppFn | return none
    let #[.fvar argId] := e.getAppArgs | return none
    if fnName != ofNatName then return none
    let some n ← (getLit (.fvar argId) : CompilerM (Option Nat)) | return none
    return some (ofNat n)
  | _ => return none
|
||
|
||
/--
Builds a `Literal α` instance for a type whose literals are represented as an
application `ofNatName x`, with `x` bound to a `Nat` literal.
* `ofNat` converts a recognized `Nat` into an `α` (used by `getLit`).
* `toNat` converts an `α` back into the `Nat` to emit (used by `mkLit`).
-/
def mkNatWrapperInstance [Inhabited α] (ofNat : Nat → α) (ofNatName : Name) (toNat : α → Nat) : Literal α where
  getLit e := getLitAux e ofNat ofNatName toNat
  mkLit x := do
    -- the underlying `Nat` gets its own auxiliary `let`, which is then wrapped
    let natFVar ← mkAuxLit (toNat x)
    return mkApp (mkConst ofNatName) (.fvar natFVar)
|
||
|
||
-- Fixed-width unsigned integers and characters are represented as `T.ofNat n`.
instance : Literal UInt8 :=
  mkNatWrapperInstance (ofNat := UInt8.ofNat) (ofNatName := ``UInt8.ofNat) (toNat := UInt8.toNat)
instance : Literal UInt16 :=
  mkNatWrapperInstance (ofNat := UInt16.ofNat) (ofNatName := ``UInt16.ofNat) (toNat := UInt16.toNat)
instance : Literal UInt32 :=
  mkNatWrapperInstance (ofNat := UInt32.ofNat) (ofNatName := ``UInt32.ofNat) (toNat := UInt32.toNat)
instance : Literal UInt64 :=
  mkNatWrapperInstance (ofNat := UInt64.ofNat) (ofNatName := ``UInt64.ofNat) (toNat := UInt64.toNat)
instance : Literal Char :=
  mkNatWrapperInstance (ofNat := Char.ofNat) (ofNatName := ``Char.ofNat) (toNat := Char.toNat)
|
||
|
||
end Literals
|
||
|
||
/--
Turns an expression chain of the form
```
let _x.1 := @List.nil _
let _x.2 := @List.cons _ a _x.1
let _x.3 := @List.cons _ b _x.2
let _x.4 := @List.cons _ c _x.3
let _x.5 := @List.cons _ d _x.4
let _x.6 := @List.cons _ e _x.5
```
(starting from `_x.6`) into the element variables in list order, head first:
`[e, d, c, b, a]`. The result also contains the element type and the universe
level taken from the `List.nil` at the end of the chain.

Returns `none` if:
* the chain contains anything other than `List.cons`/`List.nil` applications
  linked through `let`-bound free variables;
* a `List.cons` or `List.nil` is not fully applied;
* an element is not a free variable.
-/
partial def getPseudoListLiteral (e : Expr) : CompilerM (Option (List FVarId × Expr × Level)) := do
  go e []
where
  -- Walks the `cons` chain, accumulating element fvars in reverse order.
  go (e : Expr) (fvarIds : List FVarId) : CompilerM (Option (List FVarId × Expr × Level)) := do
    match e with
    | .app .. =>
      let fn := e.getAppFn
      let args := e.getAppArgs
      if fn.constName? == some ``List.nil then
        -- `@List.nil.{u} α`: exactly one type argument and one universe level
        if args.size != 1 then return none
        let [level] := fn.constLevels! | return none
        return some (fvarIds.reverse, args[0]!, level)
      else if fn.constName? == some ``List.cons then
        -- `@List.cons α head tail` must be fully applied. Bail out (instead of
        -- panicking) when `head` is not a free variable.
        if args.size != 3 then return none
        let .fvar head := args[1]! | return none
        go args[2]! (head :: fvarIds)
      else
        return none
    | .fvar fvarId =>
      let some decl ← findLetDecl? fvarId | return none
      go decl.value fvarIds
    | _ =>
      return none
|
||
|
||
/--
Emits auxiliary declarations that build the array of `elements`, each of type
`typ`; `typLevel` instantiates the universe of `Array.mkEmpty`/`Array.push`.
For `#[x, y, z]` the emitted declarations are:
```
let _x.12 := 3
let _x.8 := @Array.mkEmpty _ _x.12
let _x.22 := @Array.push _ _x.8 x
let _x.24 := @Array.push _ _x.22 y
let _x.26 := @Array.push _ _x.24 z
```
and the returned expression is `_x.26`.
-/
def mkPseudoArrayLiteral (elements : Array FVarId) (typ : Expr) (typLevel : Level) : FolderM Expr := do
  let capacity ← mkAuxLit elements.size
  let emptyArr ← mkAuxLetDecl <| mkApp2 (mkConst ``Array.mkEmpty [typLevel]) typ (.fvar capacity)
  -- push the elements left to right, each push getting its own `let`
  let finalArr ← elements.foldlM (init := emptyArr) fun acc elem =>
    mkAuxLetDecl <| mkApp3 (mkConst ``Array.push [typLevel]) typ (.fvar acc) (.fvar elem)
  return .fvar finalArr
|
||
|
||
/--
Folder for `List.toArray`: evaluates array literals at compile time by
rewriting
```
let _x.1 := @List.nil _
let _x.2 := @List.cons _ z _x.1
let _x.3 := @List.cons _ y _x.2
let _x.4 := @List.cons _ x _x.3
let _x.5 := @List.toArray _ _x.4
```
into the equivalent sequence of array pushes:
```
let _x.12 := 3
let _x.8 := @Array.mkEmpty _ _x.12
let _x.22 := @Array.push _ _x.8 x
let _x.24 := @Array.push _ _x.22 y
let _x.26 := @Array.push _ _x.24 z
```
-/
def foldArrayLiteral : Folder := fun args => do
  -- `@List.toArray α l`: the list is the second argument
  unless args.size == 2 do return none
  let some (elems, typ, level) ← getPseudoListLiteral args[1]! | return none
  let arrayExpr ← mkPseudoArrayLiteral elems.toArray typ level
  return some arrayExpr
|
||
|
||
/--
Lifts a unary function such as `Nat.succ` to a constant folder. It applies when
the single argument is a literal of type `α`, and yields the literal of
`folder arg`.
-/
def Folder.mkUnary [Literal α] [Literal β] (folder : α → β) : Folder := fun exprs => do
  unless exprs.size == 1 do return none
  let some arg ← (getLit exprs[0]! : CompilerM (Option α)) | return none
  return some (← mkLit (folder arg))
|
||
|
||
/--
Lifts a binary function such as `Nat.add` to a constant folder. It applies when
both arguments are literals, and yields the literal of `folder lhs rhs`.
-/
def Folder.mkBinary [Literal α] [Literal β] [Literal γ] (folder : α → β → γ) : Folder := fun exprs => do
  unless exprs.size == 2 do return none
  let some lhs ← (getLit exprs[0]! : CompilerM (Option α)) | return none
  let some rhs ← (getLit exprs[1]! : CompilerM (Option β)) | return none
  return some (← mkLit (folder lhs rhs))
|
||
|
||
/--
Folder for a binary operation with left neutral element `neutral`: rewrites
`op neutral x` to `x`.
-/
def Folder.leftNeutral [Literal α] [BEq α] (neutral : α) : Folder := fun exprs => do
  unless exprs.size == 2 do return none
  let some lhs ← (getLit exprs[0]! : CompilerM (Option α)) | return none
  return if lhs == neutral then some exprs[1]! else none
|
||
|
||
/--
Folder for a binary operation with right neutral element `neutral`: rewrites
`op x neutral` to `x`.
-/
def Folder.rightNeutral [Literal α] [BEq α] (neutral : α) : Folder := fun exprs => do
  unless exprs.size == 2 do return none
  let some rhs ← (getLit exprs[1]! : CompilerM (Option α)) | return none
  return if rhs == neutral then some exprs[0]! else none
|
||
|
||
/--
Folder for a binary operation with left annihilator `annihilator`: rewrites
`op annihilator x` to the literal `zero`.
-/
def Folder.leftAnnihilator [Literal α] [BEq α] (annihilator : α) (zero : α) : Folder := fun exprs => do
  unless exprs.size == 2 do return none
  let some lhs ← (getLit exprs[0]! : CompilerM (Option α)) | return none
  if lhs != annihilator then return none
  return some (← mkLit zero)
|
||
|
||
/--
Folder for a binary operation with right annihilator `annihilator`: rewrites
`op x annihilator` to the literal `zero`.
-/
def Folder.rightAnnihilator [Literal α] [BEq α] (annihilator : α) (zero : α) : Folder := fun exprs => do
  unless exprs.size == 2 do return none
  let some rhs ← (getLit exprs[1]! : CompilerM (Option α)) | return none
  if rhs != annihilator then return none
  return some (← mkLit zero)
|
||
|
||
/--
Combines `folders` by trying them in order and returning the first success.
The auxiliary declarations recorded by a failing folder are discarded: the
state is reset to its value from before the attempt.
-/
def Folder.first (folders : Array Folder) : Folder := fun exprs => do
  let saved ← get
  for folder in folders do
    match ← folder exprs with
    | some res => return some res
    | none => set saved
  return none
|
||
|
||
/--
Folder for a binary operation where `neutral` is both a left and a right
neutral element.
-/
def Folder.leftRightNeutral [Literal α] [BEq α] (neutral : α) : Folder :=
  let fromLeft := Folder.leftNeutral neutral
  let fromRight := Folder.rightNeutral neutral
  Folder.first #[fromLeft, fromRight]
|
||
|
||
/--
Folder for a binary operation where `annihilator` annihilates from both sides,
producing the literal `zero`.
-/
def Folder.leftRightAnnihilator [Literal α] [BEq α] (annihilator : α) (zero : α) : Folder :=
  let fromLeft := Folder.leftAnnihilator annihilator zero
  let fromRight := Folder.rightAnnihilator annihilator zero
  Folder.first #[fromLeft, fromRight]
|
||
|
||
/-- Folders that build literals of higher-order data structures (currently arrays). -/
def higherOrderLiteralFolders : List (Name × Folder) :=
  [(``List.toArray, foldArrayLiteral)]
|
||
|
||
/--
All arithmetic folders.

Each operation first tries to evaluate fully literal arguments (`mkUnary` /
`mkBinary`). It then falls back to algebraic identities that hold for that
operation and need only one literal argument:
* `+`: `0` is a left and right neutral element.
* `-`: `0` is only a *right* neutral element (`x - 0 = x`). `0 - x` is not `x`:
  `Nat` subtraction truncates it to `0`, and fixed-width integers wrap around.
* `*`: `1` is a left and right neutral element; `0` is a left and right
  annihilator.
* `/`: `1` is a right neutral element.
-/
def arithmeticFolders : List (Name × Folder) := [
  (``Nat.succ, Folder.mkUnary Nat.succ),
  (``Nat.add, Folder.first #[Folder.mkBinary Nat.add, Folder.leftRightNeutral 0]),
  (``UInt8.add, Folder.first #[Folder.mkBinary UInt8.add, Folder.leftRightNeutral (0 : UInt8)]),
  (``UInt16.add, Folder.first #[Folder.mkBinary UInt16.add, Folder.leftRightNeutral (0 : UInt16)]),
  (``UInt32.add, Folder.first #[Folder.mkBinary UInt32.add, Folder.leftRightNeutral (0 : UInt32)]),
  (``UInt64.add, Folder.first #[Folder.mkBinary UInt64.add, Folder.leftRightNeutral (0 : UInt64)]),
  -- Subtraction: `0` is neutral only on the right. For the truncated `Nat`
  -- subtraction, `0 - x = 0`, so `0` is additionally a left annihilator.
  (``Nat.sub, Folder.first #[Folder.mkBinary Nat.sub, Folder.rightNeutral (0 : Nat), Folder.leftAnnihilator (0 : Nat) 0]),
  (``UInt8.sub, Folder.first #[Folder.mkBinary UInt8.sub, Folder.rightNeutral (0 : UInt8)]),
  (``UInt16.sub, Folder.first #[Folder.mkBinary UInt16.sub, Folder.rightNeutral (0 : UInt16)]),
  (``UInt32.sub, Folder.first #[Folder.mkBinary UInt32.sub, Folder.rightNeutral (0 : UInt32)]),
  (``UInt64.sub, Folder.first #[Folder.mkBinary UInt64.sub, Folder.rightNeutral (0 : UInt64)]),
  (``Nat.mul, Folder.first #[Folder.mkBinary Nat.mul, Folder.leftRightNeutral 1, Folder.leftRightAnnihilator 0 0]),
  (``UInt8.mul, Folder.first #[Folder.mkBinary UInt8.mul, Folder.leftRightNeutral (1 : UInt8), Folder.leftRightAnnihilator (0 : UInt8) 0]),
  (``UInt16.mul, Folder.first #[Folder.mkBinary UInt16.mul, Folder.leftRightNeutral (1 : UInt16), Folder.leftRightAnnihilator (0 : UInt16) 0]),
  (``UInt32.mul, Folder.first #[Folder.mkBinary UInt32.mul, Folder.leftRightNeutral (1 : UInt32), Folder.leftRightAnnihilator (0 : UInt32) 0]),
  (``UInt64.mul, Folder.first #[Folder.mkBinary UInt64.mul, Folder.leftRightNeutral (1 : UInt64), Folder.leftRightAnnihilator (0 : UInt64) 0]),
  (``Nat.div, Folder.first #[Folder.mkBinary Nat.div, Folder.rightNeutral 1]),
  (``UInt8.div, Folder.first #[Folder.mkBinary UInt8.div, Folder.rightNeutral (1 : UInt8)]),
  (``UInt16.div, Folder.first #[Folder.mkBinary UInt16.div, Folder.rightNeutral (1 : UInt16)]),
  (``UInt32.div, Folder.first #[Folder.mkBinary UInt32.div, Folder.rightNeutral (1 : UInt32)]),
  (``UInt64.div, Folder.first #[Folder.mkBinary UInt64.div, Folder.rightNeutral (1 : UInt64)])
]
|
||
|
||
/-- All string folders: `String.append` (with `""` as two-sided neutral), `String.length`, `String.push`. -/
def stringFolders : List (Name × Folder) :=
  let appendFolder := Folder.first #[Folder.mkBinary String.append, Folder.leftRightNeutral ""]
  [ (``String.append, appendFolder),
    (``String.length, Folder.mkUnary String.length),
    (``String.push, Folder.mkBinary String.push) ]
|
||
|
||
/--
Tries the folder that `folders` associates with the head constant of `decl`'s
value; only applications are considered. On success, returns the auxiliary
declarations the folder created, followed by `decl` with its value replaced by
the folded expression. Returns `none` when no folder is registered or the
folder does not apply.
-/
def applyFolders (decl : LetDecl) (folders : SMap Name Folder) : CompilerM (Option (Array CodeDecl)) := do
  let e := decl.value
  match e with
  | .app .. =>
    let .const fnName .. := e.getAppFn | return none
    let some folder := folders.find? fnName | return none
    let (folded?, auxDecls) ← (folder e.getAppArgs).run #[]
    let some folded := folded? | return none
    let newDecl ← decl.updateValue folded
    return some (auxDecls.push (.let newDecl))
  -- TODO: support for constant folding on projections
  | .proj .. => return none
  | .mvar .. => unreachable!
  | _ => return none
|
||
|
||
/--
Evaluates the constant `declName` as a `Folder` using `Environment.evalConstCheck`,
with ``Folder`` given as the expected type name. Failures are reported as an
error string in the `ExceptT` result.
-/
private unsafe def getFolderCoreUnsafe (env : Environment) (opts : Options) (declName : Name) : ExceptT String Id Folder :=
  env.evalConstCheck Folder opts ``Folder declName

/--
Safe entry point for `getFolderCoreUnsafe`. The declaration is `opaque`, and
`implementedBy` routes compiled calls to the unsafe implementation above.
-/
@[implementedBy getFolderCoreUnsafe]
private opaque getFolderCore (env : Environment) (opts : Options) (declName : Name) : ExceptT String Id Folder
|
||
|
||
/--
Evaluates the declaration `declName` as a `Folder` in the current environment
and options, throwing in `CoreM` if the evaluation fails.
-/
private def getFolder (declName : Name) : CoreM Folder := do
  let env ← getEnv
  let opts ← getOptions
  ofExcept (getFolderCore env opts declName)
|
||
|
||
/--
The builtin folder table: arithmetic, higher-order literal, and string folders,
keyed by the function they fold. Later entries overwrite earlier ones with the
same key.
-/
def builtinFolders : SMap Name Folder := Id.run do
  let mut table : SMap Name Folder := {}
  for (declName, folder) in arithmeticFolders ++ higherOrderLiteralFolders ++ stringFolders do
    table := table.insert declName folder
  return table
|
||
|
||
/--
The exported form of a folder registration. Only names are stored; the folder
itself is re-evaluated from `folderDeclName` on import.
-/
structure FolderOleanEntry where
  /-- The function whose applications the folder handles. -/
  declName : Name
  /-- The declaration (of type `Folder`) implementing the folder. -/
  folderDeclName : Name
|
||
|
||
/-- A folder registration as added in the current module: the exported names plus the evaluated folder. -/
structure FolderEntry extends FolderOleanEntry where
  /-- The folder obtained by evaluating `folderDeclName`. -/
  folder : Folder
|
||
|
||
/--
Environment extension storing the constant folders. Its state is a pair of
* the entries registered in the current module, newest first (these are what
  gets exported), and
* the table from function name to `Folder`, seeded with `builtinFolders`.
-/
builtin_initialize folderExt : PersistentEnvExtension FolderOleanEntry FolderEntry (List FolderOleanEntry × SMap Name Folder) ←
  registerPersistentEnvExtension {
    name := `cfolder
    mkInitial := return ([], builtinFolders)
    -- On import, every imported folder declaration is re-evaluated and added on
    -- top of the builtin folders.
    addImportedFn := fun entriesArray => do
      let ctx ← read
      let mut folders := builtinFolders
      for entries in entriesArray do
        for { declName, folderDeclName } in entries do
          -- `IO.ofExcept` turns an evaluation failure into an `IO` error
          let folder ← IO.ofExcept <| getFolderCore ctx.env ctx.opts folderDeclName
          folders := folders.insert declName folder
      -- NOTE(review): `SMap.switch` moves the table past its import stage;
      -- see `SMap` for the staged semantics.
      return ([], folders.switch)
    -- A local registration is remembered for export and inserted into the table.
    addEntryFn := fun (entries, map) entry => (entry.toFolderOleanEntry :: entries, map.insert entry.declName entry.folder)
    -- Entries were consed newest-first; reverse to export them in registration order.
    exportEntriesFn := fun (entries, _) => entries.reverse.toArray
  }
|
||
|
||
/--
Registers the declaration `folderDeclName` (of type `Folder`) as the constant
folder for applications of `declName`. The folder is evaluated immediately, and
this throws if that evaluation fails.
-/
def registerFolder (declName : Name) (folderDeclName : Name) : CoreM Unit := do
  let folder ← getFolder folderDeclName
  let entry : FolderEntry := { declName, folderDeclName, folder }
  modifyEnv (folderExt.addEntry · entry)
|
||
|
||
/-- The current folder table: builtin, imported, and locally registered folders. -/
def getFolders : CoreM (SMap Name Folder) := do
  let (_, folders) := folderExt.getState (← getEnv)
  return folders
|
||
|
||
/--
Applies all folders currently known to the environment (see `getFolders`) to
`decl`; the result has the same meaning as for `applyFolders`.
-/
def foldConstants (decl : LetDecl) : CompilerM (Option (Array CodeDecl)) := do
  let folders ← getFolders
  applyFolders decl folders
|
||
|
||
end ConstantFold
|
||
end Lean.Compiler.LCNF.Simp
|