From 65f9344f01b23f0a78c65bd8f5c7d48a8eff4760 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Thu, 25 Aug 2022 18:17:54 -0700 Subject: [PATCH] feat: check whether join points are fully applied at `Check.lean` --- src/Lean/Compiler/LCNF/Basic.lean | 3 +++ src/Lean/Compiler/LCNF/Check.lean | 7 ++++++- src/Lean/Compiler/LCNF/CompilerM.lean | 4 ++++ 3 files changed, 13 insertions(+), 1 deletion(-) diff --git a/src/Lean/Compiler/LCNF/Basic.lean b/src/Lean/Compiler/LCNF/Basic.lean index 5b1f9947c9..fbc9e770ec 100644 --- a/src/Lean/Compiler/LCNF/Basic.lean +++ b/src/Lean/Compiler/LCNF/Basic.lean @@ -45,6 +45,9 @@ structure FunDeclCore (Code : Type) where value : Code deriving Inhabited +def FunDeclCore.getArity (decl : FunDeclCore Code) : Nat := + decl.params.size + structure CasesCore (Code : Type) where typeName : Name resultType : Expr diff --git a/src/Lean/Compiler/LCNF/Check.lean b/src/Lean/Compiler/LCNF/Check.lean index 36a659b7cb..22b0b64be3 100644 --- a/src/Lean/Compiler/LCNF/Check.lean +++ b/src/Lean/Compiler/LCNF/Check.lean @@ -157,7 +157,12 @@ partial def check (code : Code) : CheckM Expr := do withFVarId decl.fvarId do check k | .jp decl k => checkFunDecl decl; withJp decl.fvarId do check k | .cases c => checkCases c - | .jmp fvarId args => checkJpInScope fvarId; checkAppArgs (.fvar fvarId) args; code.inferType + | .jmp fvarId args => + checkJpInScope fvarId + let decl ← getFunDecl fvarId + unless decl.getArity == args.size do + throwError "invalid LCNF `jmp`, join point has #{decl.getArity} parameters, but #{args.size} were provided" + checkAppArgs (.fvar fvarId) args; code.inferType | .return fvarId => checkFVar fvarId; code.inferType | .unreach .. => code.inferType diff --git a/src/Lean/Compiler/LCNF/CompilerM.lean b/src/Lean/Compiler/LCNF/CompilerM.lean index bab0507449..729b0f1987 100644 --- a/src/Lean/Compiler/LCNF/CompilerM.lean +++ b/src/Lean/Compiler/LCNF/CompilerM.lean @@ -35,6 +35,10 @@ def getLocalDecl (fvarId : FVarId) : CompilerM LocalDecl := do let some decl := (← get).lctx.localDecls.find? fvarId | throwError "unknown free variable {fvarId.name}" return decl +def getFunDecl (fvarId : FVarId) : CompilerM FunDecl := do + let some decl := (← get).lctx.funDecls.find? fvarId | throwError "unknown local function {fvarId.name}" + return decl + @[inline] def modifyLCtx (f : LCtx → LCtx) : CompilerM Unit := do modify fun s => { s with lctx := f s.lctx }