From b68f3455d3a7fb15f2a7a2ef13193433e3564eff Mon Sep 17 00:00:00 2001 From: Cameron Zwarich Date: Mon, 18 Aug 2025 18:28:15 -0700 Subject: [PATCH] refactor: use a separate getter for fvar values in `toIR` (#9978) --- src/Lean/Compiler/IR/ToIR.lean | 50 ++++++++++++++++++---------------- 1 file changed, 27 insertions(+), 23 deletions(-) diff --git a/src/Lean/Compiler/IR/ToIR.lean b/src/Lean/Compiler/IR/ToIR.lean index 823e26d251..1099b6efb8 100644 --- a/src/Lean/Compiler/IR/ToIR.lean +++ b/src/Lean/Compiler/IR/ToIR.lean @@ -28,6 +28,7 @@ inductive FVarClassification where | var (id : VarId) | joinPoint (id : JoinPointId) | erased + deriving Inhabited structure BuilderState where fvars : Std.HashMap FVarId FVarClassification := {} @@ -38,6 +39,9 @@ abbrev M := StateRefT BuilderState CoreM def M.run (x : M α) : CoreM α := do x.run' {} +def getFVarValue (fvarId : FVarId) : M FVarClassification := do + return (← get).fvars.get! fvarId + def bindVar (fvarId : FVarId) : M VarId := do modifyGet fun s => let varId := { idx := s.nextId } @@ -82,10 +86,10 @@ def lowerLitValue (v : LCNF.LitValue) : LitVal × IRType := def lowerArg (a : LCNF.Arg) : M Arg := do match a with | .fvar fvarId => - match (← get).fvars[fvarId]? with - | some (.var varId) => return .var varId - | some .erased => return .erased - | some (.joinPoint ..) | none => panic! "unexpected value" + match (← getFVarValue fvarId) with + | .var varId => return .var varId + | .erased => return .erased + | .joinPoint .. => panic! "unexpected value" | .erased | .type .. => return .erased inductive TranslatedProj where @@ -116,18 +120,18 @@ partial def lowerCode (c : LCNF.Code) : M FnBody := do let body ← lowerCode decl.value return .jdecl joinPoint params body (← lowerCode k) | .jmp fvarId args => - match (← get).fvars[fvarId]? with - | some (.joinPoint joinPointId) => + match (← getFVarValue fvarId) with + | .joinPoint joinPointId => return .jmp joinPointId (← args.mapM lowerArg) - | some (.var ..) | some .erased | none => panic! "unexpected value" + | .var .. | .erased => panic! "unexpected value" | .cases cases => - match (← get).fvars[cases.discr]? with - | some (.var varId) => + match (← getFVarValue cases.discr) with + | .var varId => return .case cases.typeName varId (← nameToIRType cases.typeName) (← cases.alts.mapM (lowerAlt varId)) - | some .erased => + | .erased => let #[alt] := cases.alts | panic! "erased inductive should only have one case" match alt with | .alt _ params code => @@ -135,12 +139,12 @@ partial def lowerCode (c : LCNF.Code) : M FnBody := do lowerCode code | .default code => lowerCode code - | some (.joinPoint ..) | none => panic! "unexpected value" + | .joinPoint .. => panic! "unexpected value" | .return fvarId => - let arg := match (← get).fvars[fvarId]? with - | some (.var varId) => .var varId - | some .erased => .erased - | some (.joinPoint ..) | none => panic! "unexpected value" + let arg := match (← getFVarValue fvarId) with + | .var varId => .var varId + | .erased => .erased + | .joinPoint .. => panic! "unexpected value" return .ret arg | .unreach .. => return .unreachable | .fun .. => panic! "all local functions should be λ-lifted" @@ -152,8 +156,8 @@ partial def lowerLet (decl : LCNF.LetDecl) (k : LCNF.Code) : M FnBody := do let ⟨litValue, type⟩ := lowerLitValue litValue return .vdecl var type (.lit litValue) (← lowerCode k) | .proj typeName i fvarId => - match (← get).fvars[fvarId]? with - | some (.var varId) => + match (← getFVarValue fvarId) with + | .var varId => let some (.inductInfo { ctors := [ctorName], .. }) := (← Lean.getEnv).find? typeName | panic! "projection of non-structure type" let ⟨ctorInfo, fields⟩ ← getCtorLayout ctorName @@ -165,10 +169,10 @@ partial def lowerLet (decl : LCNF.LetDecl) (k : LCNF.Code) : M FnBody := do | .erased => bindErased decl.fvarId lowerCode k - | some .erased => + | .erased => bindErased decl.fvarId lowerCode k - | some (.joinPoint ..) | none => panic! "unexpected value" + | .joinPoint .. => panic! "unexpected value" | .const name _ args => let irArgs ← args.mapM lowerArg if let some code ← tryIrDecl? name irArgs then @@ -219,12 +223,12 @@ partial def lowerLet (decl : LCNF.LetDecl) (k : LCNF.Code) : M FnBody := do throwError f!"code generator does not support recursor '{name}' yet, consider using 'match ... with' and/or structural recursion" | none => panic! "reference to unbound name" | .fvar fvarId args => - match (← get).fvars[fvarId]? with - | some (.var id) => + match (← getFVarValue fvarId) with + | .var id => let irArgs ← args.mapM lowerArg mkAp id irArgs - | some .erased => mkErased () - | some (.joinPoint ..) | none => panic! "unexpected value" + | .erased => mkErased () + | .joinPoint .. => panic! "unexpected value" | .erased => mkErased () where mkVar (v : VarId) : M FnBody := do