lean4-htt/src/Lean/Compiler/IR/SimpCase.lean
2025-10-16 20:27:46 +00:00

83 lines
2.3 KiB
Text
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/-
Copyright (c) 2019 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
module
prelude
public import Lean.Compiler.IR.Format
public section
namespace Lean.IR
/--
Make sure a list of `case` alternatives ends with a default alternative.

* If some alternative already satisfies `Alt.isDefault`, or there are fewer
  than two alternatives, `alts` is returned unchanged.
* Otherwise the last alternative is dropped and re-added as `Alt.default`
  carrying the same body.
-/
def ensureHasDefault (alts : Array Alt) : Array Alt :=
  if alts.any Alt.isDefault || alts.size < 2 then
    alts
  else
    -- Read the last alternative before popping it so its body can be reused.
    let last := alts.back!
    alts.pop.push (Alt.default last.body)
/--
Count how many alternatives at index `i` or later have a body equal to the
body of `alts[i]`. The alternative at index `i` counts itself, so the result
is at least `1`.

`alts[i]!` is used for the lookup, so `i` is expected to be a valid index.
-/
private def getOccsOf (alts : Array Alt) (i : Nat) : Nat :=
  let target := alts[i]!.body
  -- Start from 1 (the alternative at `i`) and scan only the indices after `i`.
  alts.foldl (init := 1) (start := i + 1) fun (count : Nat) (alt : Alt) =>
    if alt.body == target then count + 1 else count
/--
Find the alternative whose body is repeated most often, as measured by
`getOccsOf`, and return it together with that count.

The comparison is strict, so on a tie the alternative with the smallest index
wins. `alts[0]!` is used for the initial candidate, so `alts` is expected to be
non-empty.
-/
private def maxOccs (alts : Array Alt) : Alt × Nat := Id.run do
  let mut best : Alt × Nat := (alts[0]!, getOccsOf alts 0)
  for h : idx in 1...alts.size do
    let count := getOccsOf alts idx
    -- Only a strictly larger count replaces the current candidate.
    if count > best.2 then
      best := (alts[idx], count)
  return best
/--
Merge the most frequently repeated alternative body into a single
`Alt.default`.

`alts` is returned unchanged when it has at most one alternative, when it
already contains a default alternative, or when no body occurs more than once.
Otherwise every alternative whose body equals the most common one is removed
and a single `Alt.default` with that body is appended.
-/
private def addDefault (alts : Array Alt) : Array Alt :=
  if alts.size <= 1 || alts.any Alt.isDefault then
    alts
  else
    let (common, count) := maxOccs alts
    if count != 1 then
      -- Drop all duplicates of the most common body; the default covers them.
      let rest := alts.filter (·.body != common.body)
      rest.push (Alt.default common.body)
    else
      alts
/-- Keep only the alternatives whose body is not `FnBody.unreachable`. -/
private def filterUnreachable (alts : Array Alt) : Array Alt :=
  alts.filter (·.body != FnBody.unreachable)
/--
Build a simplified `case` on `x` from the given alternatives.

The alternatives are first stripped of unreachable bodies
(`filterUnreachable`) and then passed through `addDefault`. Depending on how
many remain:

* none: the result is `FnBody.unreachable`;
* exactly one: the result is that alternative's body, with no `case` at all;
* otherwise: the result is `FnBody.case tid x xType alts`.
-/
private def mkSimpCase (tid : Name) (x : VarId) (xType : IRType) (alts : Array Alt) : FnBody :=
  let alts := addDefault (filterUnreachable alts)
  if alts.isEmpty then
    .unreachable
  else if _ : alts.size = 1 then
    -- The size hypothesis justifies the bounds-checked access `alts[0]`.
    alts[0].body
  else
    .case tid x xType alts
/--
Recursively simplify every `case` inside a function body.

The body is split with `flatten` into a prefix `bs` and a terminal `term`.
`simpCase` is applied to the prefix through `modifyJPs` (presumably rewriting
the bodies of join points; confirm against `modifyJPs`). If the terminal is a
`case`, each alternative body is simplified and the `case` is rebuilt with
`mkSimpCase`; any other terminal is kept as is. The result is reassembled with
`reshape`.
-/
partial def FnBody.simpCase (b : FnBody) : FnBody :=
  let (bs, term) := b.flatten
  let bs := modifyJPs bs simpCase
  match term with
  | .case tid x xType alts =>
    reshape bs (mkSimpCase tid x xType (alts.map (·.modifyBody simpCase)))
  | other => reshape bs other
/--
Simplify the `case` instructions of a declaration:

- Remove unreachable branches.
- Remove `case` if there is only one branch.
- Merge most common branches using `Alt.default`.

Only `.fdecl` declarations have their body rewritten (with
`FnBody.simpCase`); any other declaration is returned unchanged.
-/
def Decl.simpCase (d : Decl) : Decl :=
  match d with
  | .fdecl (body := b) .. => d.updateBody! b.simpCase
  | _ => d
-- Register the trace class `compiler.ir.simp_case` at initialization time,
-- passing `inherited := true` to `registerTraceClass`.
builtin_initialize registerTraceClass `compiler.ir.simp_case (inherited := true)
end Lean.IR