253 lines
9.7 KiB
Text
253 lines
9.7 KiB
Text
/-
|
||
Copyright (c) 2022 Henrik Böving. All rights reserved.
|
||
Released under Apache 2.0 license as described in the file LICENSE.
|
||
Authors: Henrik Böving
|
||
-/
|
||
prelude
|
||
import Lean.Compiler.LCNF.PassManager
|
||
import Lean.Compiler.LCNF.PrettyPrinter
|
||
|
||
namespace Lean.Compiler.LCNF
|
||
|
||
/--
Check whether `code` has a `let` declaration whose value is the constant
`constName`. The search descends into the bodies of local function and join
point declarations, into continuations, and into every alternative of a
`cases`.
-/
partial def Code.containsConst (constName : Name) (code : Code) : Bool :=
  match code with
  | .return .. | .unreach .. | .jmp .. => false
  | .cases cs => cs.alts.any fun alt => containsConst constName alt.getCode
  | .fun decl k | .jp decl k => containsConst constName decl.value || containsConst constName k
  | .let decl k => goLetValue decl.value || containsConst constName k
where
  -- Only a `.const` let value is compared against `constName`.
  goLetValue (l : LetValue) : Bool :=
    match l with
    | .const name .. => name == constName
    | .lit .. | .erased | .proj .. | .fvar .. => false
  -- `Expr` traversal; neither `containsConst` nor `goLetValue` calls it.
  goExpr (e : Expr) : Bool :=
    match e with
    | .letE .. => unreachable! -- not possible in LCNF
    | .const name .. => name == constName
    | .proj _ _ struct .. => goExpr struct
    | .lam _ _ body .. => goExpr body
    | .app fn arg .. => goExpr fn || goExpr arg
    | _ => false
|
||
|
||
namespace Testing
|
||
|
||
/--
Read-only context of `TestInstallerM`.
-/
structure TestInstallerContext where
  /-- Name of the pass that the installed test targets. -/
  passUnderTestName : Name
  /-- Name of the test that is being installed. -/
  testName : Name
|
||
|
||
/--
Read-only context of `TestM`.
-/
structure TestContext where
  /-- The pass that the running test targets. -/
  passUnderTest : Pass
  /-- Name of the running test. -/
  testName : Name
|
||
|
||
/--
Read-only context of `SimpleAssertionM`.
-/
structure SimpleAssertionContext where
  /-- The declarations the assertion is checked against. -/
  decls : Array Decl
|
||
|
||
/--
Read-only context of `InOutAssertionM`.
-/
structure InOutAssertionContext where
  /-- The declarations handed to the pass under test. -/
  input : Array Decl
  /-- The declarations produced by the pass under test. -/
  output : Array Decl
|
||
|
||
/-- Reader monad with access to a `TestInstallerContext`. -/
abbrev TestInstallerM := ReaderM TestInstallerContext

/-- A `TestInstallerM` computation that yields a `PassInstaller`. -/
abbrev TestInstaller := TestInstallerM PassInstaller
|
||
|
||
/-- `CompilerM` extended with a read-only `TestContext`. -/
abbrev TestM := ReaderT TestContext CompilerM

/-- `TestM` extended with a read-only `SimpleAssertionContext`. -/
abbrev SimpleAssertionM := ReaderT SimpleAssertionContext TestM

/-- `TestM` extended with a read-only `InOutAssertionContext`. -/
abbrev InOutAssertionM := ReaderT InOutAssertionContext TestM

/-- A `SimpleAssertionM` action that produces no value. -/
abbrev SimpleTest := SimpleAssertionM Unit

/-- An `InOutAssertionM` action that produces no value. -/
abbrev InOutTest := InOutAssertionM Unit
|
||
|
||
/--
Run the installer `x` with a context built from `passUnderTestName` and
`testName`, yielding the resulting `PassInstaller`.
-/
def TestInstaller.install (x : TestInstaller) (passUnderTestName testName : Name) : PassInstaller :=
  let ctx : TestInstallerContext := { passUnderTestName := passUnderTestName, testName := testName }
  x ctx
|
||
|
||
/--
Run `x` in `CompilerM` with a `TestContext` built from `passUnderTest` and
`testName`.
-/
def TestM.run (x : TestM α) (passUnderTest : Pass) (testName : Name) : CompilerM α :=
  let ctx : TestContext := { passUnderTest := passUnderTest, testName := testName }
  x ctx
|
||
|
||
/--
Run `x` in `CompilerM`, supplying `decls` as the assertion context and
`passUnderTest` / `testName` as the test context.
-/
def SimpleAssertionM.run (x : SimpleAssertionM α) (decls : Array Decl) (passUnderTest : Pass) (testName : Name) : CompilerM α :=
  let assertionCtx : SimpleAssertionContext := { decls := decls }
  let testCtx : TestContext := { passUnderTest := passUnderTest, testName := testName }
  x assertionCtx testCtx
|
||
|
||
/--
Run `x` in `CompilerM`, supplying `input` / `output` as the assertion context
and `passUnderTest` / `testName` as the test context.
-/
def InOutAssertionM.run (x : InOutAssertionM α) (input output : Array Decl) (passUnderTest : Pass) (testName : Name) : CompilerM α :=
  let assertionCtx : InOutAssertionContext := { input := input, output := output }
  let testCtx : TestContext := { passUnderTest := passUnderTest, testName := testName }
  x assertionCtx testCtx
|
||
|
||
/-- Return the `testName` stored in the `TestContext`. -/
def getTestName : TestM Name := do
  let ctx ← read
  return ctx.testName
|
||
|
||
/-- Return the `passUnderTest` stored in the `TestContext`. -/
def getPassUnderTest : TestM Pass := do
  let ctx ← read
  return ctx.passUnderTest
|
||
|
||
/-- Return the `decls` stored in the `SimpleAssertionContext`. -/
def getDecls : SimpleAssertionM (Array Decl) := do
  let ctx ← read
  return ctx.decls
|
||
|
||
/-- Return the `input` declarations stored in the `InOutAssertionContext`. -/
def getInputDecls : InOutAssertionM (Array Decl) := do
  let ctx ← read
  return ctx.input
|
||
|
||
/-- Return the `output` declarations stored in the `InOutAssertionContext`. -/
def getOutputDecls : InOutAssertionM (Array Decl) := do
  let ctx ← read
  return ctx.output
|
||
|
||
/--
Throw an exception carrying `msg` when `property` is `false`; do nothing
otherwise.
-/
def assert (property : Bool) (msg : String) : TestM Unit := do
  if !property then
    throwError msg
|
||
|
||
/--
Build a function that turns the pass under test into an assertion pass. The
assertion pass is named after the test, shares the phase of the pass under
test, runs `test` on the declarations it receives and returns them
unchanged.
-/
private def assertAfterTest (test : SimpleTest) : TestInstallerM (Pass → Pass) := do
  let ctx ← read
  let testName := ctx.testName
  let mkAssertionPass (passUnderTest : Pass) : Pass := {
    name := testName
    phase := passUnderTest.phase
    run := fun decls => do
      trace[Compiler.test] "Starting post condition test {testName} for {passUnderTest.name} occurrence {passUnderTest.occurrence}"
      test.run decls passUnderTest testName
      trace[Compiler.test] "Post condition test {testName} for {passUnderTest.name} occurrence {passUnderTest.occurrence} successful"
      -- The assertion only inspects the declarations; hand them on as they came in.
      return decls
  }
  return mkAssertionPass
|
||
|
||
/--
Install `test` as an assertion pass directly after the given `occurrence` of
the pass under test. The first occurrence is used by default.
-/
def assertAfter (test : SimpleTest) (occurrence : Nat := 0) : TestInstaller := do
  let mkAssertionPass ← assertAfterTest test
  let ctx ← read
  return PassInstaller.installAfter ctx.passUnderTestName mkAssertionPass occurrence
|
||
|
||
/--
Install `test` as an assertion pass directly after every occurrence of the
pass under test.
-/
def assertAfterEachOccurrence (test : SimpleTest) : TestInstaller := do
  let mkAssertionPass ← assertAfterTest test
  let ctx ← read
  return PassInstaller.installAfterEach ctx.passUnderTestName mkAssertionPass
|
||
|
||
/--
Install an assertion pass directly after the given `occurrence` of the pass
under test (the first one by default). `assertion` is evaluated once per
declaration; the first declaration for which it is `false` raises an
exception with `msg`.
-/
def assertForEachDeclAfter (assertion : Pass → Decl → Bool) (msg : String) (occurrence : Nat := 0) : TestInstaller :=
  let perDeclTest : SimpleTest := do
    let pass ← getPassUnderTest
    for decl in (← getDecls) do
      assert (assertion pass decl) msg
  assertAfter perDeclTest occurrence
|
||
|
||
/--
Install an assertion pass directly after each occurrence of the pass under
test. `assertion` is evaluated once per declaration; the first declaration
for which it is `false` raises an exception with `msg`.
-/
def assertForEachDeclAfterEachOccurrence (assertion : Pass → Decl → Bool) (msg : String) : TestInstaller :=
  let perDeclTest : SimpleTest := do
    let pass ← getPassUnderTest
    for decl in (← getDecls) do
      assert (assertion pass decl) msg
  assertAfterEachOccurrence perDeclTest
|
||
|
||
/--
Build a function that wraps the pass under test. The wrapper runs the
original pass, hands both the declarations it received and the declarations
the pass produced to `test`, and then returns the produced declarations.

The wrapper is created by updating only the `run` field of the pass under
test. Every other field, in particular `name`, `phase` and `occurrence`, is
inherited, so the wrapped pass can still be addressed by name and occurrence
by installers that run afterwards.
-/
private def assertAroundTest (test : InOutTest) : TestInstallerM (Pass → Pass) := do
  let testName := (←read).testName
  return fun (passUnderTest : Pass) => {
    passUnderTest with
    run := fun decls => do
      trace[Compiler.test] "Starting wrapper test {testName} for {passUnderTest.name} occurrence {passUnderTest.occurrence}"
      let newDecls ← passUnderTest.run decls
      test.run decls newDecls passUnderTest testName
      trace[Compiler.test] "Wrapper test {testName} for {passUnderTest.name} occurrence {passUnderTest.occurrence} successful"
      return newDecls
  }
|
||
|
||
/--
Replace the given `occurrence` of the pass under test (the first one by
default) with a wrapper. The wrapper lets `test` inspect both the
declarations that were sent to the pass and the declarations it produced.
-/
def assertAround (test : InOutTest) (occurrence : Nat := 0) : TestInstaller := do
  let mkWrapper ← assertAroundTest test
  let ctx ← read
  return PassInstaller.replacePass ctx.passUnderTestName mkWrapper occurrence
|
||
|
||
/--
Replace every occurrence of the pass under test with a wrapper. The wrapper
lets `test` inspect both the declarations that were sent to the pass and the
declarations it produced.
-/
def assertAroundEachOccurrence (test : InOutTest) : TestInstaller := do
  let mkWrapper ← assertAroundTest test
  let ctx ← read
  return PassInstaller.replaceEachOccurrence ctx.passUnderTestName mkWrapper
|
||
|
||
/--
Throw an error whose message consists of `err`, followed by the
pretty-printed declarations of `firstResult` and of `secondResult`. Each of
the two lists is introduced by its own header line and every declaration
starts on a fresh line.
-/
private def throwFixPointError (err : String) (firstResult secondResult : Array Decl) : CompilerM Unit := do
  -- Appends one pretty-printed declaration, preceded by a newline.
  let appendDecl (acc : String) (decl : Decl) : CompilerM String := do
    return acc ++ s!"\n{← ppDecl decl}"
  let mut err := err
  err := err ++ "Result after usual run:"
  err ← firstResult.foldlM (init := err) appendDecl
  -- The leading newline keeps this header from being glued onto the last
  -- declaration printed for `firstResult`.
  err := err ++ "\nResult after further run:"
  err ← secondResult.foldlM (init := err) appendDecl
  throwError err
|
||
|
||
/--
Insert a pass after each occurrence of `passUnderTestName` that runs the pass
under test once more on the declarations it receives and checks that nothing
changed, i.e. that the pass is at a fix point. A difference in the number of
declarations, in the declarations themselves, or in their order raises an
exception listing both results.
-/
def assertIsAtFixPoint : TestInstaller :=
  assertAfterEachOccurrence do
    let passUnderTest ← getPassUnderTest
    let decls ← getDecls
    let secondResult ← passUnderTest.run decls
    -- Sizes are compared first; the element-wise comparison is only reached
    -- when both arrays have the same length.
    let failure? : Option String :=
      if decls.size < secondResult.size then
        some s!"Pass {passUnderTest.name} did not reach a fixpoint, it added declarations on further runs:\n"
      else if decls.size > secondResult.size then
        some s!"Pass {passUnderTest.name} did not reach a fixpoint, it removed declarations on further runs:\n"
      else if decls != secondResult then
        some s!"Pass {passUnderTest.name} did not reach a fixpoint, it either changed declarations or their order:\n"
      else
        none
    if let some err := failure? then
      throwFixPointError err decls secondResult
|
||
|
||
/--
Compare the summed `Decl.size` of the input of `passUnderTest` with the
summed `Decl.size` of its output using `assertion`, for every occurrence of
the pass. If `assertion inputSize outputSize` is `false`, an exception is
thrown whose message starts with `msg` and reports both sizes.
-/
def assertSize (assertion : Nat → Nat → Bool) (msg : String) : TestInstaller :=
  let totalSize (decls : Array Decl) : Nat :=
    decls.foldl (init := 0) fun acc decl => acc + decl.size
  assertAroundEachOccurrence do
    let inputS := totalSize (← getInputDecls)
    let outputS := totalSize (← getOutputDecls)
    Testing.assert (assertion inputS outputS) s!"{msg}: input size {inputS} output size {outputS}"
|
||
|
||
/--
Assert that the overall size of the `Decl`s is the same before and after
`passUnderTestName`; otherwise an exception based on `msg` is thrown.
-/
def assertPreservesSize (msg : String) : TestInstaller :=
  assertSize (fun inputSize outputSize => inputSize == outputSize) msg
|
||
|
||
/--
Assert that the overall size of the `Decl`s is strictly smaller after
`passUnderTestName` than before; otherwise an exception based on `msg` is
thrown.
-/
def assertReducesSize (msg : String) : TestInstaller :=
  assertSize (fun inputSize outputSize => inputSize > outputSize) msg
|
||
|
||
/--
Assert that the overall size of the `Decl`s after `passUnderTestName` is
smaller than or equal to the size before; otherwise an exception based on
`msg` is thrown.
-/
def assertReducesOrPreservesSize (msg : String) : TestInstaller :=
  assertSize (fun inputSize outputSize => inputSize ≥ outputSize) msg
|
||
|
||
/--
Assert, after each occurrence of the pass under test, that no declaration
with a `.code` value contains the constant `constName` according to
`Code.containsConst`. Declarations with an `.extern` value always pass. A
failing declaration raises an exception with `msg`.
-/
def assertDoesNotContainConstAfter (constName : Name) (msg : String) : TestInstaller :=
  let isFreeOfConst (_ : Pass) (decl : Decl) : Bool :=
    match decl.value with
    | .extern .. => true
    | .code c => !c.containsConst constName
  assertForEachDeclAfterEachOccurrence isFreeOfConst msg
|
||
|
||
/--
Install an assertion pass after the first occurrence of the pass under test
that visits the code of every declaration with `forCodeM` and throws an
error naming the declaration as soon as a `.fun` node is encountered.
-/
def assertNoFun : TestInstaller :=
  assertAfter do
    let decls ← getDecls
    for decl in decls do
      decl.value.forCodeM fun code =>
        match code with
        | .fun .. => throwError "declaration `{decl.name}` contains a local function declaration"
        | _ => pure ()
|
||
|
||
end Testing
|
||
end Lean.Compiler.LCNF
|