fix: add support for builtin casesOn recursors to the new compiler (#8132)

This PR adds support for lowering `casesOn` for builtin types in the new
compiler.
This commit is contained in:
Cameron Zwarich 2025-04-27 10:11:36 -07:00 committed by GitHub
parent 26138a5362
commit 36ed58351d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -101,6 +101,107 @@ partial def decToMono (c : Cases) (_ : c.typeName == ``Decidable) : ToMonoM Code
return .alt ctorName #[] (← k.toMono)
return .cases { c with resultType, alts, typeName := ``Bool }
/-- Eliminate `cases` for `Nat`. -/
partial def casesNatToMono (c: Cases) (_ : c.typeName == ``Nat) : ToMonoM Code := do
let resultType ← toMonoType c.resultType
let natType := mkConst ``Nat
let zeroDecl ← mkLetDecl `zero natType (.value (.natVal 0))
let isZeroDecl ← mkLetDecl `isZero (mkConst ``Bool) (.const ``Nat.decEq [] #[.fvar c.discr, .fvar zeroDecl.fvarId])
let alts ← c.alts.mapM fun alt => do
match alt with
| .default k => return alt.updateCode (← k.toMono)
| .alt ctorName ps k =>
eraseParams ps
if ctorName == ``Nat.succ then
let p := ps[0]!
let oneDecl ← mkLetDecl `one natType (.value (.natVal 1))
let subOneDecl := { fvarId := p.fvarId, binderName := p.binderName, type := natType, value := .const ``Nat.sub [] #[.fvar c.discr, .fvar oneDecl.fvarId] }
modifyLCtx fun lctx => lctx.addLetDecl subOneDecl
return .alt ``Bool.false #[] (.let oneDecl (.let subOneDecl (← k.toMono)))
else
return .alt ``Bool.true #[] (← k.toMono)
return .let zeroDecl (.let isZeroDecl (.cases { discr := isZeroDecl.fvarId, resultType, alts, typeName := ``Bool }))
/-- Eliminate `cases` for `Int`. -/
partial def casesIntToMono (c: Cases) (_ : c.typeName == ``Int) : ToMonoM Code := do
let resultType ← toMonoType c.resultType
let natType := mkConst ``Nat
let zeroNatDecl ← mkLetDecl `natZero natType (.value (.natVal 0))
let zeroIntDecl ← mkLetDecl `intZero (mkConst ``Int) (.const ``Int.ofNat [] #[.fvar zeroNatDecl.fvarId])
let isNegDecl ← mkLetDecl `isNeg (mkConst ``Bool) (.const ``Int.decLt [] #[.fvar c.discr, .fvar zeroIntDecl.fvarId])
let alts ← c.alts.mapM fun alt => do
match alt with
| .default k => return alt.updateCode (← k.toMono)
| .alt ctorName ps k =>
eraseParams ps
let p := ps[0]!
if ctorName == ``Int.negSucc then
let absDecl ← mkLetDecl `abs natType (.const ``Int.natAbs [] #[.fvar c.discr])
let oneDecl ← mkLetDecl `one natType (.value (.natVal 1))
let subOneDecl := { fvarId := p.fvarId, binderName := p.binderName, type := natType, value := .const ``Nat.sub [] #[.fvar absDecl.fvarId, .fvar oneDecl.fvarId] }
modifyLCtx fun lctx => lctx.addLetDecl subOneDecl
return .alt ``Bool.true #[] (.let absDecl (.let oneDecl (.let subOneDecl (← k.toMono))))
else
let absDecl := { fvarId := p.fvarId, binderName := p.binderName, type := natType, value := .const ``Int.natAbs [] #[.fvar c.discr] }
modifyLCtx fun lctx => lctx.addLetDecl absDecl
return .alt ``Bool.false #[] (.let absDecl (← k.toMono))
return .let zeroNatDecl (.let zeroIntDecl (.let isNegDecl (.cases { discr := isNegDecl.fvarId, resultType, alts, typeName := ``Bool })))
/-- Eliminate `cases` for `UInt` types. -/
partial def casesUIntToMono (c : Cases) (uintName : Name) (_ : c.typeName == uintName) : ToMonoM Code := do
assert! c.alts.size == 1
let .alt _ ps k := c.alts[0]! | unreachable!
eraseParams ps
let p := ps[0]!
let decl := { fvarId := p.fvarId, binderName := p.binderName, type := anyExpr, value := .const (.str uintName "toBitVec") [] #[.fvar c.discr] }
modifyLCtx fun lctx => lctx.addLetDecl decl
let k ← k.toMono
return .let decl k
/-- Eliminate `cases` for `Array. -/
partial def casesArrayToMono (c : Cases) (_ : c.typeName == ``Array) : ToMonoM Code := do
assert! c.alts.size == 1
let .alt _ ps k := c.alts[0]! | unreachable!
eraseParams ps
let p := ps[0]!
let decl := { fvarId := p.fvarId, binderName := p.binderName, type := anyExpr, value := .const ``Array.toList [] #[.erased, .fvar c.discr] }
modifyLCtx fun lctx => lctx.addLetDecl decl
let k ← k.toMono
return .let decl k
/-- Eliminate `cases` for `ByteArray. -/
partial def casesByteArrayToMono (c : Cases) (_ : c.typeName == ``ByteArray) : ToMonoM Code := do
assert! c.alts.size == 1
let .alt _ ps k := c.alts[0]! | unreachable!
eraseParams ps
let p := ps[0]!
let decl := { fvarId := p.fvarId, binderName := p.binderName, type := anyExpr, value := .const ``ByteArray.data [] #[.fvar c.discr] }
modifyLCtx fun lctx => lctx.addLetDecl decl
let k ← k.toMono
return .let decl k
/-- Eliminate `cases` for `FloatArray. -/
partial def casesFloatArrayToMono (c : Cases) (_ : c.typeName == ``FloatArray) : ToMonoM Code := do
assert! c.alts.size == 1
let .alt _ ps k := c.alts[0]! | unreachable!
eraseParams ps
let p := ps[0]!
let decl := { fvarId := p.fvarId, binderName := p.binderName, type := anyExpr, value := .const ``FloatArray.data [] #[.fvar c.discr] }
modifyLCtx fun lctx => lctx.addLetDecl decl
let k ← k.toMono
return .let decl k
/-- Eliminate `cases` for `String. -/
partial def casesStringToMono (c : Cases) (_ : c.typeName == ``String) : ToMonoM Code := do
assert! c.alts.size == 1
let .alt _ ps k := c.alts[0]! | unreachable!
eraseParams ps
let p := ps[0]!
let decl := { fvarId := p.fvarId, binderName := p.binderName, type := anyExpr, value := .const ``String.toList [] #[.fvar c.discr] }
modifyLCtx fun lctx => lctx.addLetDecl decl
let k ← k.toMono
return .let decl k
/-- Eliminate `cases` for trivial structure. See `hasTrivialStructure?` -/
partial def trivialStructToMono (info : TrivialStructureInfo) (c : Cases) : ToMonoM Code := do
assert! c.alts.size == 1
@ -124,6 +225,26 @@ partial def Code.toMono (code : Code) : ToMonoM Code := do
| .cases c =>
if h : c.typeName == ``Decidable then
decToMono c h
else if h : c.typeName == ``Nat then
casesNatToMono c h
else if h : c.typeName == ``Int then
casesIntToMono c h
else if h : c.typeName == ``UInt8 then
casesUIntToMono c ``UInt8 h
else if h : c.typeName == ``UInt16 then
casesUIntToMono c ``UInt16 h
else if h : c.typeName == ``UInt32 then
casesUIntToMono c ``UInt32 h
else if h : c.typeName == ``UInt64 then
casesUIntToMono c ``UInt64 h
else if h : c.typeName == ``Array then
casesArrayToMono c h
else if h : c.typeName == ``ByteArray then
casesByteArrayToMono c h
else if h : c.typeName == ``FloatArray then
casesFloatArrayToMono c h
else if h : c.typeName == ``String then
casesStringToMono c h
else if let some info ← hasTrivialStructure? c.typeName then
trivialStructToMono info c
else