chore: make IR.Arg pattern matching more exhaustive (#9370)

This commit is contained in:
Cameron Zwarich 2025-07-14 15:46:40 -07:00 committed by GitHub
parent a4b5eecb8e
commit 9d33f2ad33
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 26 additions and 26 deletions

View file

@ -208,8 +208,8 @@ def mkCast (x : VarId) (xType : IRType) (expectedType : IRType) : M Expr := do
@[inline] def castArgIfNeeded (x : Arg) (expected : IRType) (k : Arg → M FnBody) : M FnBody :=
match x with
| Arg.var x => castVarIfNeeded x expected (fun x => k (Arg.var x))
| _ => k x
| .var x => castVarIfNeeded x expected (fun x => k (.var x))
| .erased => k x
def castArgsIfNeededAux (xs : Array Arg) (typeFromIdx : Nat → IRType) : M (Array Arg × Array FnBody) := do
let mut xs' := #[]

View file

@ -69,8 +69,8 @@ def checkJP (j : JoinPointId) : M Unit := do
def checkArg (a : Arg) : M Unit :=
match a with
| Arg.var x => checkVar x
| _ => pure ()
| .var x => checkVar x
| .erased => pure ()
def checkArgs (as : Array Arg) : M Unit :=
as.forM checkArg

View file

@ -158,8 +158,8 @@ def findVarValue (x : VarId) : M Value := do
def findArgValue (arg : Arg) : M Value :=
match arg with
| Arg.var x => findVarValue x
| _ => pure top
| .var x => findVarValue x
| .erased => pure top
def updateVarAssignment (x : VarId) (v : Value) : M Unit := do
let v' ← findVarValue x

View file

@ -47,8 +47,8 @@ def emitLns {α : Type} [ToString α] (as : List α) : M Unit :=
def argToCString (x : Arg) : String :=
match x with
| Arg.var x => toString x
| _ => "lean_box(0)"
| .var x => toString x
| .erased => "lean_box(0)"
def emitArg (x : Arg) : M Unit :=
emit (argToCString x)
@ -540,8 +540,8 @@ def isTailCall (x : VarId) (v : Expr) (b : FnBody) : M Bool := do
def paramEqArg (p : Param) (x : Arg) : Bool :=
match x with
| Arg.var x => p.x == x
| _ => false
| .var x => p.x == x
| .erased => false
/--
Given `[p_0, ..., p_{n-1}]`, `[y_0, ..., y_{n-1}]`, representing the assignments

View file

@ -548,8 +548,8 @@ def emitLhsSlotStore (builder : LLVM.Builder llvmctx)
def emitArgSlot_ (builder : LLVM.Builder llvmctx)
(x : Arg) : M llvmctx (LLVM.LLVMType llvmctx × LLVM.Value llvmctx) := do
match x with
| Arg.var x => emitLhsSlot_ x
| _ => do
| .var x => emitLhsSlot_ x
| .erased => do
let slotty ← LLVM.voidPtrType llvmctx
let slot ← buildPrologueAlloca builder slotty "erased_slot"
let v ← callLeanBox builder (← constIntSizeT 0) "erased_val"

View file

@ -147,11 +147,11 @@ def setFields (y : VarId) (zs : Array Arg) (b : FnBody) : FnBody :=
/-- Given `set x[i] := y`, return true iff `y := proj[i] x` -/
def isSelfSet (ctx : Context) (x : VarId) (i : Nat) (y : Arg) : Bool :=
match y with
| Arg.var y =>
| .var y =>
match ctx.projMap[y]? with
| some (Expr.proj j w) => j == i && w == x
| _ => false
| _ => false
| .erased => false
/-- Given `uset x[i] := y`, return true iff `y := uproj[i] x` -/
def isSelfUSet (ctx : Context) (x : VarId) (i : Nat) (y : VarId) : Bool :=

View file

@ -28,8 +28,8 @@ instance : AndThen Collector where
andThen a b := seq a (b ())
private def collectArg : Arg → Collector
| Arg.var x => collectVar x
| _ => skip
| .var x => collectVar x
| .erased => skip
private def collectArray {α : Type} (as : Array α) (f : α → Collector) : Collector :=
fun m => as.foldl (fun m a => f a m) m
@ -124,8 +124,8 @@ instance : AndThen Collector where
andThen a b := seq a (b ())
private def collectArg : Arg → Collector
| Arg.var x => collectVar x
| _ => skip
| .var x => collectVar x
| .erased => skip
private def collectArray {α : Type} (as : Array α) (f : α → Collector) : Collector :=
fun bv fv => as.foldl (fun fv a => f a bv fv) fv
@ -184,8 +184,8 @@ def visitVar (w : Index) (x : VarId) : Bool := w == x.idx
def visitJP (w : Index) (x : JoinPointId) : Bool := w == x.idx
def visitArg (w : Index) : Arg → Bool
| Arg.var x => visitVar w x
| _ => false
| .var x => visitVar w x
| .erased => false
def visitArgs (w : Index) (xs : Array Arg) : Bool :=
xs.any (visitArg w)

View file

@ -93,8 +93,8 @@ abbrev Collector := LiveVarSet → LiveVarSet
@[inline] private def collectVar (x : VarId) : Collector := fun s => s.insert x
private def collectArg : Arg → Collector
| Arg.var x => collectVar x
| _ => skip
| .var x => collectVar x
| .erased => skip
private def collectArray {α : Type} (as : Array α) (f : α → Collector) : Collector := fun s =>
as.foldl (fun s a => f a s) s

View file

@ -50,8 +50,8 @@ def normJP (x : JoinPointId) : M JoinPointId :=
JoinPointId.mk <$> normIndex x.idx
def normArg : Arg → M Arg
| Arg.var x => Arg.var <$> normVar x
| other => pure other
| .var x => .var <$> normVar x
| .erased => pure .erased
def normArgs (as : Array Arg) : M (Array Arg) := fun m =>
as.map fun a => normArg a m
@ -128,8 +128,8 @@ def Decl.normalizeIds (d : Decl) : Decl :=
namespace MapVars
@[inline] def mapArg (f : VarId → VarId) : Arg → Arg
| Arg.var x => Arg.var (f x)
| a => a
| .var x => .var (f x)
| .erased => .erased
def mapArgs (f : VarId → VarId) (as : Array Arg) : Array Arg :=
as.map (mapArg f)