lean4-htt/src/Lean/Compiler/LCNF/Simp/ConstantFold.lean
Leonardo de Moura 595734b936 chore: remove workaround
It is now implemented at `Quote (Array _)`
2022-09-29 17:12:48 -07:00

415 lines
15 KiB
Text
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/-
Copyright (c) 2022 Henrik Böving. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Henrik Böving
-/
import Lean.Compiler.LCNF.CompilerM
import Lean.Compiler.LCNF.InferType
import Lean.Compiler.LCNF.PassManager
namespace Lean.Compiler.LCNF.Simp
namespace ConstantFold
/--
The monad constant folders run in: `CompilerM` extended with a state holding the
auxiliary code declarations required to build the folded expression.
`mkAuxLetDecl` appends to this state, and `applyFolders` emits these declarations
in front of the rewritten `let` declaration.
-/
abbrev FolderM := StateRefT (Array CodeDecl) CompilerM
/--
A constant folder for a specific function. It receives all the arguments of an
application of that function (`applyFolders` passes `e.getAppArgs`). On success it
returns the new `Expr`, recording any auxiliary declarations it needs in the
`FolderM` state. If folding is not possible it returns `none`.
-/
abbrev Folder := Array Expr → FolderM (Option Expr)
/--
A typeclass for detecting and producing literals of arbitrary types
inside of LCNF.
-/
class Literal (α : Type) where
  /--
  Attempt to turn the provided `Expr` into a value of type `α` if
  it is whatever concept of a literal `α` has. Note that this function
  assumes that the provided `Expr` does indeed have type `α`.
  It runs in `CompilerM` so that implementations can look through
  let-bound free variables (the instances in this file use `findLetDecl?`).
  -/
  getLit : Expr → CompilerM (Option α)
  /--
  Turn a value of type `α` into an `Expr` denoting it as a literal of type `α`,
  where again the idea of what a literal is depends on `α`. Any auxiliary
  `LetDecl`s the final `Expr` refers to are recorded in the `FolderM` state.
  -/
  mkLit : α → FolderM Expr
-- Make `getLit`/`mkLit` usable without the `Literal.` prefix.
export Literal (getLit mkLit)
/--
Like `LCNF.mkAuxLetDecl`, but additionally records the new `LetDecl` in the
`FolderM` state, so that it becomes one of the auxiliary declarations of the fold.
Returns the free variable bound by the new declaration.
-/
def mkAuxLetDecl (e : Expr) (prefixName := `_x) : FolderM FVarId := do
  let auxDecl ← LCNF.mkAuxLetDecl e prefixName
  modify (·.push (.let auxDecl))
  return auxDecl.fvarId
section Literals
/--
Build the literal expression for `x` with `mkLit` and bind it to a fresh auxiliary
let declaration (recorded through `mkAuxLetDecl`). Returns the bound free variable.
-/
def mkAuxLit [Literal α] (x : α) (prefixName := `_x) : FolderM FVarId := do
  mkAuxLetDecl (← mkLit x) prefixName
/--
Extract a `Nat` literal from `e`, following let-bound free variables to their values.
-/
partial def getNatLit (e : Expr) : CompilerM (Option Nat) := do
  match e with
  | .lit (.natVal n) .. => return some n
  | .fvar fvarId .. =>
    let some decl ← findLetDecl? fvarId | return none
    getNatLit decl.value
  | _ => return none
/-- A `Nat` literal is represented directly as a `.natVal` literal expression. -/
def mkNatLit (n : Nat) : FolderM Expr :=
  pure (Expr.lit (.natVal n))
/-- `Nat` literals are plain `.natVal` literals (possibly behind let-bound fvars). -/
instance : Literal Nat := ⟨getNatLit, mkNatLit⟩
/--
Extract a `String` literal from `e`, following let-bound free variables to their values.
-/
partial def getStringLit (e : Expr) : CompilerM (Option String) := do
  match e with
  | .lit (.strVal s) .. => return some s
  | .fvar fvarId .. =>
    if let some decl ← findLetDecl? fvarId then
      getStringLit decl.value
    else
      return none
  | _ => return none
/-- A `String` literal is represented directly as a `.strVal` literal expression. -/
def mkStringLit (n : String) : FolderM Expr :=
  pure (Expr.lit (.strVal n))
/-- `String` literals are plain `.strVal` literals (possibly behind let-bound fvars). -/
instance : Literal String := ⟨getStringLit, mkStringLit⟩
/--
Shared `getLit` implementation for types that are represented in LCNF as an
application `ofNatName n`, where `n` is a `Nat` literal (see `mkNatWrapperInstance`).
Let-bound free variables are followed to their values. The `Nat` argument may be
either a literal or a free variable bound to one. Any other shape yields `none`.
-/
private partial def getLitAux [Inhabited α] (e : Expr) (ofNat : Nat → α) (ofNatName : Name) (toNat : α → Nat) : CompilerM (Option α) := do
  match e with
  | .fvar fvarId .. =>
    if let some decl ← findLetDecl? fvarId then
      getLitAux decl.value ofNat ofNatName toNat
    else
      return none
  | .app .. =>
    match e.getAppFn, e.getAppArgs with
    | .const name .., #[arg] =>
      if name == ofNatName then
        -- `getNatLit` accepts both `.lit (.natVal n)` and let-bound free variables,
        -- so a direct literal argument is folded as well.
        if let some n ← getNatLit arg then
          return some (ofNat n)
      return none
    | _, _ => return none
  | _ => return none
/--
Build a `Literal` instance for a type whose literals are written as
`ofNatName n` with `n` a `Nat` literal.
- `getLit` recognizes such applications via `getLitAux`.
- `mkLit x` binds `toNat x` to an auxiliary `Nat` literal declaration and applies
  `ofNatName` to it.
-/
def mkNatWrapperInstance [Inhabited α] (ofNat : Nat → α) (ofNatName : Name) (toNat : α → Nat) : Literal α where
  getLit e := getLitAux e ofNat ofNatName toNat
  mkLit x := do
    let natFVar ← mkAuxLit (toNat x)
    return mkApp (mkConst ofNatName) (.fvar natFVar)
-- Fixed-width unsigned integers and characters are all `ofNat` wrappers around `Nat`.
instance : Literal UInt8 := mkNatWrapperInstance (ofNat := UInt8.ofNat) (ofNatName := ``UInt8.ofNat) (toNat := UInt8.toNat)
instance : Literal UInt16 := mkNatWrapperInstance (ofNat := UInt16.ofNat) (ofNatName := ``UInt16.ofNat) (toNat := UInt16.toNat)
instance : Literal UInt32 := mkNatWrapperInstance (ofNat := UInt32.ofNat) (ofNatName := ``UInt32.ofNat) (toNat := UInt32.toNat)
instance : Literal UInt64 := mkNatWrapperInstance (ofNat := UInt64.ofNat) (ofNatName := ``UInt64.ofNat) (toNat := UInt64.toNat)
instance : Literal Char := mkNatWrapperInstance (ofNat := Char.ofNat) (ofNatName := ``Char.ofNat) (toNat := Char.toNat)
end Literals
/--
Recognize a list built from `List.nil`/`List.cons` let-bindings. For a chain of the form
```
let _x.1 := @List.nil _
let _x.2 := @List.cons _ a _x.1
let _x.3 := @List.cons _ b _x.2
let _x.4 := @List.cons _ c _x.3
let _x.5 := @List.cons _ d _x.4
let _x.6 := @List.cons _ e _x.5
```
starting from `_x.6` this returns the element free variables in list order,
`[e, d, c, b, a]`. It also returns the element type and the universe level,
both taken from the `List.nil` application.

Returns `none` in two cases: the chain is not made of fully applied
`List.nil`/`List.cons` applications, or one of the elements is not a free variable.
-/
partial def getPseudoListLiteral (e : Expr) : CompilerM (Option (List FVarId × Expr × Level)) := do
  go e []
where
  -- Walk towards `List.nil`, accumulating the elements (back to front) in `fvarIds`.
  go (e : Expr) (fvarIds : List FVarId) : CompilerM (Option (List FVarId × Expr × Level)) := do
    match e with
    | .app .. =>
      let fn := e.getAppFn
      let args := e.getAppArgs
      if fn.isConstOf ``List.nil && args.size == 1 then
        -- `@List.nil α`: the elements were collected back to front, so restore list order.
        let some level := fn.constLevels!.head? | return none
        return some (fvarIds.reverse, args[0]!, level)
      else if fn.isConstOf ``List.cons && args.size == 3 then
        -- `@List.cons α head tail`: only free-variable elements can become array elements.
        let .fvar head := args[1]! | return none
        go args[2]! (head :: fvarIds)
      else
        return none
    | .fvar fvarId =>
      let some decl ← findLetDecl? fvarId | return none
      go decl.value fvarIds
    | _ => return none
/--
Build the LCNF form of the array literal `#[a, b, c]` (given as the element
free variables), i.e.:
```
let _x.12 := 3
let _x.8 := @Array.mkEmpty _ _x.12
let _x.22 := @Array.push _ _x.8 a
let _x.24 := @Array.push _ _x.22 b
let _x.26 := @Array.push _ _x.24 c
_x.26
```
All intermediate declarations are recorded in the `FolderM` state; the returned
expression is the free variable of the final `Array.push`, or of the
`Array.mkEmpty` declaration when `elements` is empty.
-/
def mkPseudoArrayLiteral (elements : Array FVarId) (typ : Expr) (typLevel : Level) : FolderM Expr := do
  let capacity ← mkAuxLit elements.size
  let emptyArr ← mkAuxLetDecl <| mkApp2 (mkConst ``Array.mkEmpty [typLevel]) typ (.fvar capacity)
  let result ← elements.foldlM (init := emptyArr) fun acc elem =>
    mkAuxLetDecl <| mkApp3 (mkConst ``Array.push [typLevel]) typ (.fvar acc) (.fvar elem)
  return .fvar result
/--
Evaluate array literals at compile time, that is turn:
```
let _x.1 := @List.nil _
let _x.2 := @List.cons _ z _x.1
let _x.3 := @List.cons _ y _x.2
let _x.4 := @List.cons _ x _x.3
let _x.5 := @List.toArray _ _x.4
```
into its array form:
```
let _x.12 := 3
let _x.8 := @Array.mkEmpty _ _x.12
let _x.22 := @Array.push _ _x.8 x
let _x.24 := @Array.push _ _x.22 y
let _x.26 := @Array.push _ _x.24 z
```
-/
def foldArrayLiteral : Folder := fun exprs => do
  -- `@List.toArray α list`: the list is the second argument.
  let #[_, listArg] := exprs | return none
  let some (elems, typ, level) ← getPseudoListLiteral listArg | return none
  return some (← mkPseudoArrayLiteral elems.toArray typ level)
/--
Turn a unary function such as `Nat.succ` into a constant folder: when applied to
exactly one argument that is a literal, evaluate `folder` on it and emit the result
as a literal.
-/
def Folder.mkUnary [Literal α] [Literal β] (folder : α → β) : Folder := fun exprs => do
  let #[arg] := exprs | return none
  let some (value : α) ← getLit arg | return none
  return some (← mkLit (folder value))
/--
Turn a binary function such as `Nat.add` into a constant folder: when applied to
exactly two arguments that are both literals, evaluate `folder` on them and emit the
result as a literal.
-/
def Folder.mkBinary [Literal α] [Literal β] [Literal γ] (folder : α → β → γ) : Folder := fun exprs => do
  let #[lhs, rhs] := exprs | return none
  let some (x : α) ← getLit lhs | return none
  let some (y : β) ← getLit rhs | return none
  return some (← mkLit (folder x y))
/--
Folder for a binary operation with left neutral element `neutral`:
rewrites `op neutral y` to `y`.
-/
def Folder.leftNeutral [Literal α] [BEq α] (neutral : α) : Folder := fun exprs => do
  let #[lhs, rhs] := exprs | return none
  let some (x : α) ← getLit lhs | return none
  return if x == neutral then some rhs else none
/--
Folder for a binary operation with right neutral element `neutral`:
rewrites `op x neutral` to `x`.
-/
def Folder.rightNeutral [Literal α] [BEq α] (neutral : α) : Folder := fun exprs => do
  let #[lhs, rhs] := exprs | return none
  let some (y : α) ← getLit rhs | return none
  return if y == neutral then some lhs else none
/--
Folder for a binary operation with left annihilator `annihilator`:
rewrites `op annihilator y` to the literal `zero`.
-/
def Folder.leftAnnihilator [Literal α] [BEq α] (annihilator : α) (zero : α) : Folder := fun exprs => do
  let #[lhs, _] := exprs | return none
  let some (x : α) ← getLit lhs | return none
  if x != annihilator then return none
  return some (← mkLit zero)
/--
Folder for a binary operation with right annihilator `annihilator`:
rewrites `op x annihilator` to the literal `zero`.
-/
def Folder.rightAnnihilator [Literal α] [BEq α] (annihilator : α) (zero : α) : Folder := fun exprs => do
  let #[_, rhs] := exprs | return none
  let some (y : α) ← getLit rhs | return none
  if y != annihilator then return none
  return some (← mkLit zero)
/--
Try the `folders` in order and return the result of the first one that succeeds.
The auxiliary declarations recorded by a failed attempt are discarded by restoring
the `FolderM` state saved before the loop.
-/
def Folder.first (folders : Array Folder) : Folder := fun exprs => do
  let saved ← get
  for folder in folders do
    match (← folder exprs) with
    | some res => return some res
    | none => set saved
  return none
/--
Folder for an operation where `neutral` is both a left and a right neutral element.
The left rule is tried first.
-/
def Folder.leftRightNeutral [Literal α] [BEq α] (neutral : α) : Folder :=
  let fromLeft := Folder.leftNeutral neutral
  let fromRight := Folder.rightNeutral neutral
  Folder.first #[fromLeft, fromRight]
/--
Folder for an operation where `annihilator` is both a left and a right annihilator,
producing the literal `zero`. The left rule is tried first.
-/
def Folder.leftRightAnnihilator [Literal α] [BEq α] (annihilator : α) (zero : α) : Folder :=
  let fromLeft := Folder.leftAnnihilator annihilator zero
  let fromRight := Folder.rightAnnihilator annihilator zero
  Folder.first #[fromLeft, fromRight]
/--
Literal folders for higher order data structures. Currently only `List.toArray`,
which `foldArrayLiteral` rewrites into an `Array.mkEmpty`/`Array.push` chain.
-/
def higherOrderLiteralFolders : List (Name × Folder) := [
  (``List.toArray, foldArrayLiteral)
]
/--
All arithmetic folders. Each entry first tries full evaluation on literal arguments
(`Folder.mkBinary`/`Folder.mkUnary`) and then algebraic identities:
- `add`: `0` is a left and right neutral element.
- `sub`: `0` is only a *right* neutral element (`x - 0 = x`). It is not a left one,
  since `0 - x` is `0` for `Nat` (truncated subtraction) and the wrapped negation of
  `x` for the fixed-width `UIntN` types. For `Nat`, `0` is therefore a left annihilator.
- `mul`: `1` is a left and right neutral element, `0` a left and right annihilator.
- `div`: `1` is a right neutral element.
-/
def arithmeticFolders : List (Name × Folder) := [
  (``Nat.succ, Folder.mkUnary Nat.succ),
  (``Nat.add, Folder.first #[Folder.mkBinary Nat.add, Folder.leftRightNeutral 0]),
  (``UInt8.add, Folder.first #[Folder.mkBinary UInt8.add, Folder.leftRightNeutral (0 : UInt8)]),
  (``UInt16.add, Folder.first #[Folder.mkBinary UInt16.add, Folder.leftRightNeutral (0 : UInt16)]),
  (``UInt32.add, Folder.first #[Folder.mkBinary UInt32.add, Folder.leftRightNeutral (0 : UInt32)]),
  (``UInt64.add, Folder.first #[Folder.mkBinary UInt64.add, Folder.leftRightNeutral (0 : UInt64)]),
  -- `0 - x = 0` on `Nat`, so `0` is a left annihilator rather than a left neutral element.
  (``Nat.sub, Folder.first #[Folder.mkBinary Nat.sub, Folder.rightNeutral (0 : Nat), Folder.leftAnnihilator (0 : Nat) 0]),
  -- On `UIntN`, `0 - x` wraps around, so only the right-neutral rule is valid.
  (``UInt8.sub, Folder.first #[Folder.mkBinary UInt8.sub, Folder.rightNeutral (0 : UInt8)]),
  (``UInt16.sub, Folder.first #[Folder.mkBinary UInt16.sub, Folder.rightNeutral (0 : UInt16)]),
  (``UInt32.sub, Folder.first #[Folder.mkBinary UInt32.sub, Folder.rightNeutral (0 : UInt32)]),
  (``UInt64.sub, Folder.first #[Folder.mkBinary UInt64.sub, Folder.rightNeutral (0 : UInt64)]),
  (``Nat.mul, Folder.first #[Folder.mkBinary Nat.mul, Folder.leftRightNeutral 1, Folder.leftRightAnnihilator 0 0]),
  (``UInt8.mul, Folder.first #[Folder.mkBinary UInt8.mul, Folder.leftRightNeutral (1 : UInt8), Folder.leftRightAnnihilator (0 : UInt8) 0]),
  (``UInt16.mul, Folder.first #[Folder.mkBinary UInt16.mul, Folder.leftRightNeutral (1 : UInt16), Folder.leftRightAnnihilator (0 : UInt16) 0]),
  (``UInt32.mul, Folder.first #[Folder.mkBinary UInt32.mul, Folder.leftRightNeutral (1 : UInt32), Folder.leftRightAnnihilator (0 : UInt32) 0]),
  (``UInt64.mul, Folder.first #[Folder.mkBinary UInt64.mul, Folder.leftRightNeutral (1 : UInt64), Folder.leftRightAnnihilator (0 : UInt64) 0]),
  (``Nat.div, Folder.first #[Folder.mkBinary Nat.div, Folder.rightNeutral 1]),
  (``UInt8.div, Folder.first #[Folder.mkBinary UInt8.div, Folder.rightNeutral (1 : UInt8)]),
  (``UInt16.div, Folder.first #[Folder.mkBinary UInt16.div, Folder.rightNeutral (1 : UInt16)]),
  (``UInt32.div, Folder.first #[Folder.mkBinary UInt32.div, Folder.rightNeutral (1 : UInt32)]),
  (``UInt64.div, Folder.first #[Folder.mkBinary UInt64.div, Folder.rightNeutral (1 : UInt64)])
]
/--
All string folders:
- `String.append`: evaluated on two literal arguments; a `""` operand on either
  side is dropped.
- `String.length`: evaluated on a literal string.
- `String.push`: evaluated on a literal string and a literal `Char`
  (a `Char.ofNat` application, see the `Literal Char` instance).
-/
def stringFolders : List (Name × Folder) := [
  (``String.append, Folder.first #[Folder.mkBinary String.append, Folder.leftRightNeutral ""]),
  (``String.length, Folder.mkUnary String.length),
  (``String.push, Folder.mkBinary String.push)
]
/--
Apply the folder registered in `folders` for the head constant of `decl.value`.
This is only attempted when the value is an application.

On success, returns the auxiliary declarations produced by the folder, followed by
`decl` with its value replaced by the folded expression. Otherwise returns `none`.
-/
def applyFolders (decl : LetDecl) (folders : SMap Name Folder) : CompilerM (Option (Array CodeDecl)) := do
  let e := decl.value
  match e with
  | .app .. =>
    let .const name .. := e.getAppFn | return none
    let some folder := folders.find? name | return none
    let (some res, aux) ← (folder e.getAppArgs).run #[] | return none
    let newDecl ← decl.updateValue res
    return some (aux.push (.let newDecl))
  | .proj .. =>
    -- TODO: support for constant folding on projections
    return none
  | .mvar .. => unreachable!
  | .const .. | .lit .. | .fvar .. | .bvar .. | .lam .. | .sort .. |
    .forallE .. | .letE .. | .mdata .. =>
    return none
/--
Unsafe implementation of `getFolderCore`. It evaluates the constant `declName` via
`evalConstCheck`, passing ``Folder`` as the expected type name (presumably the
constant's declared type is checked against it; confirm in `Environment.evalConstCheck`).
-/
private unsafe def getFolderCoreUnsafe (env : Environment) (opts : Options) (declName : Name) : ExceptT String Id Folder :=
  env.evalConstCheck Folder opts ``Folder declName
/--
Retrieve the `Folder` stored in the constant `declName`; errors are reported as a
`String`. Declared `opaque` here; the compiled code uses `getFolderCoreUnsafe`
through `@[implementedBy]`.
-/
@[implementedBy getFolderCoreUnsafe]
private opaque getFolderCore (env : Environment) (opts : Options) (declName : Name) : ExceptT String Id Folder
/--
Evaluate the folder stored in constant `declName` using the current environment and
options, throwing on failure.
-/
private def getFolder (declName : Name) : CoreM Folder := do
  let env ← getEnv
  let opts ← getOptions
  ofExcept (getFolderCore env opts declName)
/--
The map from function names to the folders that ship with the compiler
(arithmetic, higher order literal and string folders).
-/
def builtinFolders : SMap Name Folder := Id.run do
  let mut folders : SMap Name Folder := {}
  for (declName, folder) in arithmeticFolders ++ higherOrderLiteralFolders ++ stringFolders do
    folders := folders.insert declName folder
  return folders
/--
A folder registration as exported to `.olean` files. It stores only names; on
import, the folder is re-evaluated from `folderDeclName` (see `addImportedFn` of
`folderExt`).
-/
structure FolderOleanEntry where
  /-- The function whose applications are folded. -/
  declName : Name
  /-- The constant of type `Folder` that implements the folding. -/
  folderDeclName : Name
/-- An in-memory folder registration: the persisted names plus the evaluated `Folder`. -/
structure FolderEntry extends FolderOleanEntry where
  /-- The folder for `declName`, obtained by evaluating `folderDeclName`. -/
  folder : Folder
/--
Environment extension holding the constant folders. Its state is a pair of:
- the registrations added in the current module, most recent first
  (exported as `.olean` entries), and
- the map from function names to folders returned by `getFolders`.
-/
builtin_initialize folderExt : PersistentEnvExtension FolderOleanEntry FolderEntry (List FolderOleanEntry × SMap Name Folder) ←
  registerPersistentEnvExtension {
    name := `cfolder
    mkInitial := return ([], builtinFolders)
    -- Rebuild the folder map from the builtin folders plus every imported
    -- registration, re-evaluating each folder from its constant name.
    addImportedFn := fun entriesArray => do
      let ctx ← read
      let mut folders := builtinFolders
      for entries in entriesArray do
        for { declName, folderDeclName } in entries do
          let folder ← IO.ofExcept <| getFolderCore ctx.env ctx.opts folderDeclName
          folders := folders.insert declName folder
      -- NOTE(review): `switch` presumably moves the `SMap` to its second stage
      -- before module-local insertions; confirm against `Lean.SMap`.
      return ([], folders.switch)
    -- Remember the entry for export and make its folder available immediately.
    addEntryFn := fun (entries, map) entry => (entry.toFolderOleanEntry :: entries, map.insert entry.declName entry.folder)
    -- Entries are prepended in `addEntryFn`, so reverse to export them in registration order.
    exportEntriesFn := fun (entries, _) => entries.reverse.toArray
  }
/--
Register the folder stored in constant `folderDeclName` for applications of
`declName`. The folder is evaluated eagerly and added to `folderExt`.
-/
def registerFolder (declName : Name) (folderDeclName : Name) : CoreM Unit := do
  let entry : FolderEntry := { declName, folderDeclName, folder := (← getFolder folderDeclName) }
  modifyEnv (folderExt.addEntry · entry)
/-- The current map from function names to folders (builtin, imported, and local). -/
def getFolders : CoreM (SMap Name Folder) := do
  let (_, folders) := folderExt.getState (← getEnv)
  return folders
/--
Constant-fold `decl` using every folder currently registered in the environment
(see `getFolders` and `applyFolders`).
-/
def foldConstants (decl : LetDecl) : CompilerM (Option (Array CodeDecl)) := do
  let folders ← getFolders
  applyFolders decl folders
end ConstantFold
end Lean.Compiler.LCNF.Simp