diff --git a/src/Lean/Compiler/LCNF/Basic.lean b/src/Lean/Compiler/LCNF/Basic.lean index 5b1f9947c9..fbc9e770ec 100644 --- a/src/Lean/Compiler/LCNF/Basic.lean +++ b/src/Lean/Compiler/LCNF/Basic.lean @@ -45,6 +45,9 @@ structure FunDeclCore (Code : Type) where value : Code deriving Inhabited +def FunDeclCore.getArity (decl : FunDeclCore Code) : Nat := + decl.params.size + structure CasesCore (Code : Type) where typeName : Name resultType : Expr diff --git a/src/Lean/Compiler/LCNF/Check.lean b/src/Lean/Compiler/LCNF/Check.lean index 36a659b7cb..22b0b64be3 100644 --- a/src/Lean/Compiler/LCNF/Check.lean +++ b/src/Lean/Compiler/LCNF/Check.lean @@ -157,7 +157,12 @@ partial def check (code : Code) : CheckM Expr := do withFVarId decl.fvarId do check k | .jp decl k => checkFunDecl decl; withJp decl.fvarId do check k | .cases c => checkCases c - | .jmp fvarId args => checkJpInScope fvarId; checkAppArgs (.fvar fvarId) args; code.inferType + | .jmp fvarId args => + checkJpInScope fvarId + let decl ← getFunDecl fvarId + unless decl.getArity == args.size do + throwError "invalid LCNF `jmp`, join point has #{decl.getArity} parameters, but #{args.size} were provided" + checkAppArgs (.fvar fvarId) args; code.inferType | .return fvarId => checkFVar fvarId; code.inferType | .unreach .. => code.inferType diff --git a/src/Lean/Compiler/LCNF/CompilerM.lean b/src/Lean/Compiler/LCNF/CompilerM.lean index bab0507449..729b0f1987 100644 --- a/src/Lean/Compiler/LCNF/CompilerM.lean +++ b/src/Lean/Compiler/LCNF/CompilerM.lean @@ -35,6 +35,10 @@ def getLocalDecl (fvarId : FVarId) : CompilerM LocalDecl := do let some decl := (← get).lctx.localDecls.find? fvarId | throwError "unknown free variable {fvarId.name}" return decl +def getFunDecl (fvarId : FVarId) : CompilerM FunDecl := do + let some decl := (← get).lctx.funDecls.find? fvarId | throwError "unknown local function {fvarId.name}" + return decl + @[inline] def modifyLCtx (f : LCtx → LCtx) : CompilerM Unit := do modify fun s => { s with lctx := f s.lctx }