From 36ed58351d5db5feeb04fd1c3ffdb16a115f2af8 Mon Sep 17 00:00:00 2001 From: Cameron Zwarich Date: Sun, 27 Apr 2025 10:11:36 -0700 Subject: [PATCH] fix: add support for builtin casesOn recursors to the new compiler (#8132) This PR adds support for lowering `casesOn` for builtin types in the new compiler. --- src/Lean/Compiler/LCNF/ToMono.lean | 121 +++++++++++++++++++++++++++++ 1 file changed, 121 insertions(+) diff --git a/src/Lean/Compiler/LCNF/ToMono.lean b/src/Lean/Compiler/LCNF/ToMono.lean index 861e3dd16a..a94dd83ae6 100644 --- a/src/Lean/Compiler/LCNF/ToMono.lean +++ b/src/Lean/Compiler/LCNF/ToMono.lean @@ -101,6 +101,107 @@ partial def decToMono (c : Cases) (_ : c.typeName == ``Decidable) : ToMonoM Code return .alt ctorName #[] (← k.toMono) return .cases { c with resultType, alts, typeName := ``Bool } +/-- Eliminate `cases` for `Nat`. -/ +partial def casesNatToMono (c: Cases) (_ : c.typeName == ``Nat) : ToMonoM Code := do + let resultType ← toMonoType c.resultType + let natType := mkConst ``Nat + let zeroDecl ← mkLetDecl `zero natType (.value (.natVal 0)) + let isZeroDecl ← mkLetDecl `isZero (mkConst ``Bool) (.const ``Nat.decEq [] #[.fvar c.discr, .fvar zeroDecl.fvarId]) + let alts ← c.alts.mapM fun alt => do + match alt with + | .default k => return alt.updateCode (← k.toMono) + | .alt ctorName ps k => + eraseParams ps + if ctorName == ``Nat.succ then + let p := ps[0]! + let oneDecl ← mkLetDecl `one natType (.value (.natVal 1)) + let subOneDecl := { fvarId := p.fvarId, binderName := p.binderName, type := natType, value := .const ``Nat.sub [] #[.fvar c.discr, .fvar oneDecl.fvarId] } + modifyLCtx fun lctx => lctx.addLetDecl subOneDecl + return .alt ``Bool.false #[] (.let oneDecl (.let subOneDecl (← k.toMono))) + else + return .alt ``Bool.true #[] (← k.toMono) + return .let zeroDecl (.let isZeroDecl (.cases { discr := isZeroDecl.fvarId, resultType, alts, typeName := ``Bool })) + +/-- Eliminate `cases` for `Int`. -/ +partial def casesIntToMono (c: Cases) (_ : c.typeName == ``Int) : ToMonoM Code := do + let resultType ← toMonoType c.resultType + let natType := mkConst ``Nat + let zeroNatDecl ← mkLetDecl `natZero natType (.value (.natVal 0)) + let zeroIntDecl ← mkLetDecl `intZero (mkConst ``Int) (.const ``Int.ofNat [] #[.fvar zeroNatDecl.fvarId]) + let isNegDecl ← mkLetDecl `isNeg (mkConst ``Bool) (.const ``Int.decLt [] #[.fvar c.discr, .fvar zeroIntDecl.fvarId]) + let alts ← c.alts.mapM fun alt => do + match alt with + | .default k => return alt.updateCode (← k.toMono) + | .alt ctorName ps k => + eraseParams ps + let p := ps[0]! + if ctorName == ``Int.negSucc then + let absDecl ← mkLetDecl `abs natType (.const ``Int.natAbs [] #[.fvar c.discr]) + let oneDecl ← mkLetDecl `one natType (.value (.natVal 1)) + let subOneDecl := { fvarId := p.fvarId, binderName := p.binderName, type := natType, value := .const ``Nat.sub [] #[.fvar absDecl.fvarId, .fvar oneDecl.fvarId] } + modifyLCtx fun lctx => lctx.addLetDecl subOneDecl + return .alt ``Bool.true #[] (.let absDecl (.let oneDecl (.let subOneDecl (← k.toMono)))) + else + let absDecl := { fvarId := p.fvarId, binderName := p.binderName, type := natType, value := .const ``Int.natAbs [] #[.fvar c.discr] } + modifyLCtx fun lctx => lctx.addLetDecl absDecl + return .alt ``Bool.false #[] (.let absDecl (← k.toMono)) + return .let zeroNatDecl (.let zeroIntDecl (.let isNegDecl (.cases { discr := isNegDecl.fvarId, resultType, alts, typeName := ``Bool }))) + +/-- Eliminate `cases` for `UInt` types. -/ +partial def casesUIntToMono (c : Cases) (uintName : Name) (_ : c.typeName == uintName) : ToMonoM Code := do + assert! c.alts.size == 1 + let .alt _ ps k := c.alts[0]! | unreachable! + eraseParams ps + let p := ps[0]! + let decl := { fvarId := p.fvarId, binderName := p.binderName, type := anyExpr, value := .const (.str uintName "toBitVec") [] #[.fvar c.discr] } + modifyLCtx fun lctx => lctx.addLetDecl decl + let k ← k.toMono + return .let decl k + +/-- Eliminate `cases` for `Array. -/ +partial def casesArrayToMono (c : Cases) (_ : c.typeName == ``Array) : ToMonoM Code := do + assert! c.alts.size == 1 + let .alt _ ps k := c.alts[0]! | unreachable! + eraseParams ps + let p := ps[0]! + let decl := { fvarId := p.fvarId, binderName := p.binderName, type := anyExpr, value := .const ``Array.toList [] #[.erased, .fvar c.discr] } + modifyLCtx fun lctx => lctx.addLetDecl decl + let k ← k.toMono + return .let decl k + +/-- Eliminate `cases` for `ByteArray. -/ +partial def casesByteArrayToMono (c : Cases) (_ : c.typeName == ``ByteArray) : ToMonoM Code := do + assert! c.alts.size == 1 + let .alt _ ps k := c.alts[0]! | unreachable! + eraseParams ps + let p := ps[0]! + let decl := { fvarId := p.fvarId, binderName := p.binderName, type := anyExpr, value := .const ``ByteArray.data [] #[.fvar c.discr] } + modifyLCtx fun lctx => lctx.addLetDecl decl + let k ← k.toMono + return .let decl k + +/-- Eliminate `cases` for `FloatArray. -/ +partial def casesFloatArrayToMono (c : Cases) (_ : c.typeName == ``FloatArray) : ToMonoM Code := do + assert! c.alts.size == 1 + let .alt _ ps k := c.alts[0]! | unreachable! + eraseParams ps + let p := ps[0]! + let decl := { fvarId := p.fvarId, binderName := p.binderName, type := anyExpr, value := .const ``FloatArray.data [] #[.fvar c.discr] } + modifyLCtx fun lctx => lctx.addLetDecl decl + let k ← k.toMono + return .let decl k + +/-- Eliminate `cases` for `String. -/ +partial def casesStringToMono (c : Cases) (_ : c.typeName == ``String) : ToMonoM Code := do + assert! c.alts.size == 1 + let .alt _ ps k := c.alts[0]! | unreachable! + eraseParams ps + let p := ps[0]! + let decl := { fvarId := p.fvarId, binderName := p.binderName, type := anyExpr, value := .const ``String.toList [] #[.fvar c.discr] } + modifyLCtx fun lctx => lctx.addLetDecl decl + let k ← k.toMono + return .let decl k + /-- Eliminate `cases` for trivial structure. See `hasTrivialStructure?` -/ partial def trivialStructToMono (info : TrivialStructureInfo) (c : Cases) : ToMonoM Code := do assert! c.alts.size == 1 @@ -124,6 +225,26 @@ partial def Code.toMono (code : Code) : ToMonoM Code := do | .cases c => if h : c.typeName == ``Decidable then decToMono c h + else if h : c.typeName == ``Nat then + casesNatToMono c h + else if h : c.typeName == ``Int then + casesIntToMono c h + else if h : c.typeName == ``UInt8 then + casesUIntToMono c ``UInt8 h + else if h : c.typeName == ``UInt16 then + casesUIntToMono c ``UInt16 h + else if h : c.typeName == ``UInt32 then + casesUIntToMono c ``UInt32 h + else if h : c.typeName == ``UInt64 then + casesUIntToMono c ``UInt64 h + else if h : c.typeName == ``Array then + casesArrayToMono c h + else if h : c.typeName == ``ByteArray then + casesByteArrayToMono c h + else if h : c.typeName == ``FloatArray then + casesFloatArrayToMono c h + else if h : c.typeName == ``String then + casesStringToMono c h else if let some info ← hasTrivialStructure? c.typeName then trivialStructToMono info c else