refactor: use a separate getter for fvar values in toIR (#9978)

This commit is contained in:
Cameron Zwarich 2025-08-18 18:28:15 -07:00 committed by GitHub
parent f88d35f6c9
commit b68f3455d3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -28,6 +28,7 @@ inductive FVarClassification where
| var (id : VarId)
| joinPoint (id : JoinPointId)
| erased
deriving Inhabited
structure BuilderState where
fvars : Std.HashMap FVarId FVarClassification := {}
@ -38,6 +39,9 @@ abbrev M := StateRefT BuilderState CoreM
def M.run (x : M α) : CoreM α := do
x.run' {}
def getFVarValue (fvarId : FVarId) : M FVarClassification := do
return (← get).fvars.get! fvarId
def bindVar (fvarId : FVarId) : M VarId := do
modifyGet fun s =>
let varId := { idx := s.nextId }
@ -82,10 +86,10 @@ def lowerLitValue (v : LCNF.LitValue) : LitVal × IRType :=
def lowerArg (a : LCNF.Arg) : M Arg := do
match a with
| .fvar fvarId =>
match (← get).fvars[fvarId]? with
| some (.var varId) => return .var varId
| some .erased => return .erased
| some (.joinPoint ..) | none => panic! "unexpected value"
match (← getFVarValue fvarId) with
| .var varId => return .var varId
| .erased => return .erased
| .joinPoint .. => panic! "unexpected value"
| .erased | .type .. => return .erased
inductive TranslatedProj where
@ -116,18 +120,18 @@ partial def lowerCode (c : LCNF.Code) : M FnBody := do
let body ← lowerCode decl.value
return .jdecl joinPoint params body (← lowerCode k)
| .jmp fvarId args =>
match (← get).fvars[fvarId]? with
| some (.joinPoint joinPointId) =>
match (← getFVarValue fvarId) with
| .joinPoint joinPointId =>
return .jmp joinPointId (← args.mapM lowerArg)
| some (.var ..) | some .erased | none => panic! "unexpected value"
| .var .. | .erased => panic! "unexpected value"
| .cases cases =>
match (← get).fvars[cases.discr]? with
| some (.var varId) =>
match (← getFVarValue cases.discr) with
| .var varId =>
return .case cases.typeName
varId
(← nameToIRType cases.typeName)
(← cases.alts.mapM (lowerAlt varId))
| some .erased =>
| .erased =>
let #[alt] := cases.alts | panic! "erased inductive should only have one case"
match alt with
| .alt _ params code =>
@ -135,12 +139,12 @@ partial def lowerCode (c : LCNF.Code) : M FnBody := do
lowerCode code
| .default code =>
lowerCode code
| some (.joinPoint ..) | none => panic! "unexpected value"
| .joinPoint .. => panic! "unexpected value"
| .return fvarId =>
let arg := match (← get).fvars[fvarId]? with
| some (.var varId) => .var varId
| some .erased => .erased
| some (.joinPoint ..) | none => panic! "unexpected value"
let arg := match (← getFVarValue fvarId) with
| .var varId => .var varId
| .erased => .erased
| .joinPoint .. => panic! "unexpected value"
return .ret arg
| .unreach .. => return .unreachable
| .fun .. => panic! "all local functions should be λ-lifted"
@ -152,8 +156,8 @@ partial def lowerLet (decl : LCNF.LetDecl) (k : LCNF.Code) : M FnBody := do
let ⟨litValue, type⟩ := lowerLitValue litValue
return .vdecl var type (.lit litValue) (← lowerCode k)
| .proj typeName i fvarId =>
match (← get).fvars[fvarId]? with
| some (.var varId) =>
match (← getFVarValue fvarId) with
| .var varId =>
let some (.inductInfo { ctors := [ctorName], .. }) := (← Lean.getEnv).find? typeName
| panic! "projection of non-structure type"
let ⟨ctorInfo, fields⟩ ← getCtorLayout ctorName
@ -165,10 +169,10 @@ partial def lowerLet (decl : LCNF.LetDecl) (k : LCNF.Code) : M FnBody := do
| .erased =>
bindErased decl.fvarId
lowerCode k
| some .erased =>
| .erased =>
bindErased decl.fvarId
lowerCode k
| some (.joinPoint ..) | none => panic! "unexpected value"
| .joinPoint .. => panic! "unexpected value"
| .const name _ args =>
let irArgs ← args.mapM lowerArg
if let some code ← tryIrDecl? name irArgs then
@ -219,12 +223,12 @@ partial def lowerLet (decl : LCNF.LetDecl) (k : LCNF.Code) : M FnBody := do
throwError f!"code generator does not support recursor '{name}' yet, consider using 'match ... with' and/or structural recursion"
| none => panic! "reference to unbound name"
| .fvar fvarId args =>
match (← get).fvars[fvarId]? with
| some (.var id) =>
match (← getFVarValue fvarId) with
| .var id =>
let irArgs ← args.mapM lowerArg
mkAp id irArgs
| some .erased => mkErased ()
| some (.joinPoint ..) | none => panic! "unexpected value"
| .erased => mkErased ()
| .joinPoint .. => panic! "unexpected value"
| .erased => mkErased ()
where
mkVar (v : VarId) : M FnBody := do