diff --git a/src/Lean/Compiler/LCNF/StructProjCases.lean b/src/Lean/Compiler/LCNF/StructProjCases.lean index 3b7ab7bbdf..984938b66e 100644 --- a/src/Lean/Compiler/LCNF/StructProjCases.lean +++ b/src/Lean/Compiler/LCNF/StructProjCases.lean @@ -19,17 +19,23 @@ def findStructCtorInfo? (typeName : Name) : CoreM (Option ConstructorVal) := do let some (.ctorInfo ctorInfo) := (← getEnv).find? ctorName | return none return ctorInfo -def mkFieldParamsForCtorType (e : Expr) (numParams : Nat): CompilerM (Array Param) := do - let rec loop (params : Array Param) (e : Expr) (numParams : Nat): CompilerM (Array Param) := do - match e with - | .forallE name type body _ => - if numParams == 0 then - let param ← mkParam name (← toMonoType type) false - loop (params.push param) body numParams - else - loop params body (numParams - 1) - | _ => return params - loop #[] e numParams +def mkFieldParamsForCtorType (ctorType : Expr) (numParams : Nat) (numFields : Nat) + : CompilerM (Array Param) := do + let mut type := ctorType + for _ in [0:numParams] do + match type with + | .forallE _ _ body _ => + type := body + | _ => unreachable! + let mut fields := Array.emptyWithCapacity numFields + for _ in [0:numFields] do + match type with + | .forallE name fieldType body _ => + let param ← mkParam name (← toMonoType fieldType) false + fields := fields.push param + type := body + | _ => unreachable! + return fields structure State where projMap : Std.HashMap FVarId (Array FVarId) := {} @@ -57,8 +63,7 @@ partial def visitCode (code : Code) : M Code := do visitCode k else let some ctorInfo ← findStructCtorInfo? typeName | panic! "expected struct constructor" - let params ← mkFieldParamsForCtorType ctorInfo.type ctorInfo.numParams - assert! params.size == ctorInfo.numFields + let params ← mkFieldParamsForCtorType ctorInfo.type ctorInfo.numParams ctorInfo.numFields let fvars := params.map (·.fvarId) modify fun s => { s with projMap := s.projMap.insert base fvars, fvarMap := s.fvarMap.insert decl.fvarId fvars[i]! }