chore: port ConstantFold.lean

This commit is contained in:
Leonardo de Moura 2022-11-01 18:43:04 -07:00
parent e6232b67b6
commit fd316ef027

View file

@ -7,9 +7,6 @@ import Lean.Compiler.LCNF.CompilerM
import Lean.Compiler.LCNF.InferType
import Lean.Compiler.LCNF.PassManager
set_option warningAsError false
#exit
namespace Lean.Compiler.LCNF.Simp
namespace ConstantFold
@ -24,7 +21,7 @@ A constant folder for a specific function, takes all the arguments of a
certain function and produces a new `Expr` + auxiliary declarations in
the `FolderM` monad on success. If the folding fails it returns `none`.
-/
abbrev Folder := Array Expr → FolderM (Option Expr)
abbrev Folder := Array Arg → FolderM (Option LetExpr)
/--
A typeclass for detecting and producing literals of arbitrary types
@ -32,25 +29,25 @@ inside of LCNF.
-/
class Literal (α : Type) where
/--
Attempt to turn the provied `Expr` into a value of type `α` if
Attempt to turn the provided `Expr` into a value of type `α` if
it is whatever concept of a literal `α` has. Note that this function
does assume that the provided `Expr` does indeed have type `α`.
-/
getLit : Expr → CompilerM (Option α)
getLit : FVarId → CompilerM (Option α)
/--
Turn a value of type `α` into a series of auxiliary `LetDecl`s + a
final `Expr` putting them all together into a literal of type `α`,
where again the idea of what a literal is depends on `α`.
-/
mkLit : α → FolderM Expr
mkLit : α → FolderM LetExpr
export Literal (getLit mkLit)
/--
A wrapper around `LCNF.mkAuxLetDecl` that will automaticaly store the
A wrapper around `LCNF.mkAuxLetDecl` that will automatically store the
`LetDecl` in the state of `FolderM`.
-/
def mkAuxLetDecl (e : Expr) (prefixName := `_x) : FolderM FVarId := do
def mkAuxLetDecl (e : LetExpr) (prefixName := `_x) : FolderM FVarId := do
let decl ← LCNF.mkAuxLetDecl e prefixName
modify fun s => s.push <| .let decl
return decl.fvarId
@ -64,60 +61,39 @@ def mkAuxLit [Literal α] (x : α) (prefixName := `_x) : FolderM FVarId := do
let lit ← mkLit x
mkAuxLetDecl lit prefixName
partial def getNatLit (e : Expr) : CompilerM (Option Nat) := do
match e with
| .lit (.natVal n) .. => return n
| .fvar fvarId .. =>
if let some decl ← findLetDecl? fvarId then
getNatLit decl.value
else
return none
| _ => return none
partial def getNatLit (fvarId : FVarId) : CompilerM (Option Nat) := do
let some (.value (.natVal n)) ← findLetExpr? fvarId | return none
return n
def mkNatLit (n : Nat) : FolderM Expr :=
return .lit (.natVal n)
def mkNatLit (n : Nat) : FolderM LetExpr :=
return .value (.natVal n)
instance : Literal Nat where
getLit := getNatLit
mkLit := mkNatLit
partial def getStringLit (e : Expr) : CompilerM (Option String) := do
match e with
| .lit (.strVal n) .. => return n
| .fvar fvarId .. =>
let some decl ← findLetDecl? fvarId | return none
getStringLit decl.value
| _ => return none
partial def getStringLit (fvarId : FVarId) : CompilerM (Option String) := do
let some (.value (.strVal s)) ← findLetExpr? fvarId | return none
return s
def mkStringLit (n : String) : FolderM Expr :=
return .lit (.strVal n)
def mkStringLit (n : String) : FolderM LetExpr :=
return .value (.strVal n)
instance : Literal String where
getLit := getStringLit
mkLit := mkStringLit
private partial def getLitAux [Inhabited α] (e : Expr) (ofNat : Nat → α) (ofNatName : Name) (toNat : α → Nat) : CompilerM (Option α) := do
match e with
| .fvar fvarId .. =>
if let some decl ←findLetDecl? fvarId then
getLitAux decl.value ofNat ofNatName toNat
else
return none
| .app .. =>
match e.getAppFn, e.getAppArgs with
| .const name .., #[.fvar fvarId] =>
if name == ofNatName then
if let some natLit ← getLit (.fvar fvarId) then
return ofNat natLit
return none
| _, _ => return none
| _ => return none
private partial def getLitAux [Inhabited α] (fvarId : FVarId) (ofNat : Nat → α) (ofNatName : Name) : CompilerM (Option α) := do
let some (.const declName _ #[.fvar fvarId]) ← findLetExpr? fvarId | return none
unless declName == ofNatName do return none
let some natLit ← getLit fvarId | return none
return ofNat natLit
def mkNatWrapperInstance [Inhabited α] (ofNat : Nat → α) (ofNatName : Name) (toNat : α → Nat) : Literal α where
getLit := (getLitAux · ofNat ofNatName toNat)
getLit := (getLitAux · ofNat ofNatName)
mkLit x := do
let helperId ← mkAuxLit <| toNat x
return .app (mkConst ofNatName) (.fvar helperId)
return .const ofNatName [] #[.fvar helperId]
instance : Literal UInt8 := mkNatWrapperInstance UInt8.ofNat ``UInt8.ofNat UInt8.toNat
instance : Literal UInt16 := mkNatWrapperInstance UInt16.ofNat ``UInt16.ofNat UInt16.toNat
@ -139,26 +115,17 @@ let _x.6 := @List.cons _ e _x.5
```
into: `[a, b, c, d ,e]` + The type contained in the list
-/
partial def getPseudoListLiteral (e : Expr) : CompilerM (Option (List FVarId × Expr × Level)) := do
go e []
partial def getPseudoListLiteral (fvarId : FVarId) : CompilerM (Option (List FVarId × Expr × Level)) := do
go fvarId []
where
go (e : Expr) (fvarIds : List FVarId) : CompilerM (Option (List FVarId × Expr × Level)) := do
go (fvarId : FVarId) (fvarIds : List FVarId) : CompilerM (Option (List FVarId × Expr × Level)) := do
let some e ← findLetExpr? fvarId | return none
match e with
| .app .. =>
if some ``List.nil == e.getAppFn.constName? then
return some (fvarIds.reverse, e.getAppArgs[0]!, e.getAppFn.constLevels![0]!)
else if some ``List.cons == e.getAppFn.constName? then
let args := e.getAppArgs
go args[2]! (args[1]!.fvarId! :: fvarIds)
else
return none
| .fvar fvarId =>
if let some decl ← findLetDecl? fvarId then
go decl.value fvarIds
else
return none
| _ =>
return none
| .const ``List.nil [u] #[.type α] =>
return some (fvarIds.reverse, α, u)
| .const ``List.cons _ #[_, .fvar h, .fvar t] =>
go t (h :: fvarIds)
| _ => return none
/--
Turn an `#[a, b, c]` into:
@ -171,12 +138,12 @@ let _x.26 := @Array.push _ _x.24 z
_x.26
```
-/
def mkPseudoArrayLiteral (elements : Array FVarId) (typ : Expr) (typLevel : Level) : FolderM Expr := do
def mkPseudoArrayLiteral (elements : Array FVarId) (typ : Expr) (typLevel : Level) : FolderM LetExpr := do
let sizeLit ← mkAuxLit elements.size
let mut literal ← mkAuxLetDecl <| mkApp2 (mkConst ``Array.mkEmpty [typLevel]) typ (.fvar sizeLit)
let mut literal ← mkAuxLetDecl <| .const ``Array.mkEmpty [typLevel] #[.type typ, .fvar sizeLit]
for element in elements do
literal ← mkAuxLetDecl <| mkApp3 (mkConst ``Array.push [typLevel]) typ (.fvar literal) (.fvar element)
return .fvar literal
literal ← mkAuxLetDecl <| .const ``Array.push [typLevel] #[.type typ, .fvar literal, .fvar element]
return .fvar literal #[]
/--
Evaluate array literals at compile time, that is turn:
@ -196,84 +163,66 @@ let _x.24 := @Array.push _ _x.22 y
let _x.26 := @Array.push _ _x.24 z
```
-/
def foldArrayLiteral : Folder := fun exprs => do
if h:exprs.size = 2 then
have h1 : 1 < Array.size exprs := by simp_all
if let some (list, typ, level) ← getPseudoListLiteral exprs[1] then
let arr := Array.mk list
let lit ← mkPseudoArrayLiteral arr typ level
return some lit
return none
def foldArrayLiteral : Folder := fun args => do
let #[_, .fvar fvarId] := args | return none
let some (list, typ, level) ← getPseudoListLiteral fvarId | return none
let arr := Array.mk list
let lit ← mkPseudoArrayLiteral arr typ level
return some lit
/--
Turn a unary function such as `Nat.succ` into a constant folder.
-/
def Folder.mkUnary [Literal α] [Literal β] (folder : α → β) : Folder := fun exprs => do
if h:exprs.size = 1 then
have h1 : 0 < Array.size exprs := by simp_all
if let some arg1 ← getLit exprs[0] then
let res := folder arg1
return (←mkLit res)
return none
def Folder.mkUnary [Literal α] [Literal β] (folder : α → β) : Folder := fun args => do
let #[.fvar fvarId] := args | return none
let some arg1 ← getLit fvarId | return none
let res := folder arg1
mkLit res
/--
Turn a binary function such as `Nat.add` into a constant folder.
-/
def Folder.mkBinary [Literal α] [Literal β] [Literal γ] (folder : α → β → γ) : Folder := fun exprs => do
if h:exprs.size = 2 then
have h1 : 0 < Array.size exprs := by simp_all
have h2 : 1 < Array.size exprs := by simp_all
if let some arg1 ← getLit exprs[0] then
if let some arg2 ← getLit exprs[1] then
let res := folder arg1 arg2
return (←mkLit res)
return none
def Folder.mkBinary [Literal α] [Literal β] [Literal γ] (folder : α → β → γ) : Folder := fun args => do
let #[.fvar fvarId₁, .fvar fvarId₂] := args | return none
let some arg₁ ← getLit fvarId₁ | return none
let some arg₂ ← getLit fvarId₂ | return none
mkLit <| folder arg₁ arg₂
/--
Provide a folder for an operation with a left neutral element.
-/
def Folder.leftNeutral [Literal α] [BEq α] (neutral : α) : Folder := fun exprs => do
if h:exprs.size = 2 then
have h1 : 0 < Array.size exprs := by simp_all
have h2 : 1 < Array.size exprs := by simp_all
if let some arg1 ← getLit exprs[0] then
if arg1 == neutral then
return exprs[1]
return none
def Folder.leftNeutral [Literal α] [BEq α] (neutral : α) : Folder := fun args => do
let #[.fvar fvarId₁, .fvar fvarId₂] := args | return none
let some arg₁ ← getLit fvarId₁ | return none
unless arg₁ == neutral do return none
return some <| .fvar fvarId₂ #[]
/--
Provide a folder for an operation with a right neutral element.
-/
def Folder.rightNeutral [Literal α] [BEq α] (neutral : α) : Folder := fun exprs => do
if h:exprs.size = 2 then
have h1 : 0 < Array.size exprs := by simp_all
have h2 : 1 < Array.size exprs := by simp_all
if let some arg2 ← getLit exprs[1] then
if arg2 == neutral then
return exprs[0]
return none
def Folder.rightNeutral [Literal α] [BEq α] (neutral : α) : Folder := fun args => do
let #[.fvar fvarId₁, .fvar fvarId₂] := args | return none
let some arg₂ ← getLit fvarId₂ | return none
unless arg₂ == neutral do return none
return some <| .fvar fvarId₁ #[]
/--
Provide a folder for an operation with a left annihilator.
-/
def Folder.leftAnnihilator [Literal α] [BEq α] (annihilator : α) (zero : α) : Folder := fun exprs => do
if h:exprs.size = 2 then
have h1 : 0 < Array.size exprs := by simp_all
if let some arg1 ← getLit exprs[0] then
if arg1 == annihilator then
return (←mkLit zero)
return none
def Folder.leftAnnihilator [Literal α] [BEq α] (annihilator : α) (zero : α) : Folder := fun args => do
let #[.fvar fvarId, _] := args | return none
let some arg ← getLit fvarId | return none
unless arg == annihilator do return none
mkLit zero
/--
Provide a folder for an operation with a right annihilator.
-/
def Folder.rightAnnihilator [Literal α] [BEq α] (annihilator : α) (zero : α) : Folder := fun exprs => do
if h:exprs.size = 2 then
have h1 : 1 < Array.size exprs := by simp_all
if let some arg2 ← getLit exprs[1] then
if arg2 == annihilator then
return (←mkLit zero)
return none
def Folder.rightAnnihilator [Literal α] [BEq α] (annihilator : α) (zero : α) : Folder := fun args => do
let #[_, .fvar fvarId] := args | return none
let some arg ← getLit fvarId | return none
unless arg == annihilator do return none
mkLit zero
/--
Pick the first folder out of `folders` that succeeds.
@ -346,11 +295,10 @@ def stringFolders : List (Name × Folder) := [
Apply all known folders to `decl`.
-/
def applyFolders (decl : LetDecl) (folders : SMap Name Folder) : CompilerM (Option (Array CodeDecl)) := do
let e := decl.value
match e.getAppFn with
| .const name .. =>
match decl.value with
| .const name _ args =>
if let some folder := folders.find? name then
if let (some res, aux) ← folder e.getAppArgs |>.run #[] then
if let (some res, aux) ← folder args |>.run #[] then
let decl ← decl.updateValue res
return some <| aux.push (.let decl)
return none