From f559576994be67be32aef8bca371c880b5c70b72 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Thu, 9 Jul 2020 15:31:56 -0700 Subject: [PATCH] feat: inductive datatype header validation --- src/Lean/Elab/Inductive.lean | 120 +++++++++++++++++++++++- src/Lean/Elab/Term.lean | 10 ++ src/Lean/Meta/InferType.lean | 6 +- tests/lean/inductive1.lean | 51 ++++++++++ tests/lean/inductive1.lean.expected.out | 15 +++ 5 files changed, 194 insertions(+), 8 deletions(-) create mode 100644 tests/lean/inductive1.lean create mode 100644 tests/lean/inductive1.lean.expected.out diff --git a/src/Lean/Elab/Inductive.lean b/src/Lean/Elab/Inductive.lean index e11c62a7bf..af9367aeb6 100644 --- a/src/Lean/Elab/Inductive.lean +++ b/src/Lean/Elab/Inductive.lean @@ -21,13 +21,123 @@ structure InductiveView := instance InductiveView.inhabited : Inhabited InductiveView := ⟨{ ref := arbitrary _, modifiers := {}, declId := arbitrary _, binders := arbitrary _, type? := none, introRules := #[] }⟩ -def mkInductive (ref : Syntax) (declName : Name) (explictLevelNames : List Name) (vars : Array Expr) (xs : Array Expr) (type : Expr) (intros : Array Syntax) - : TermElabM Declaration := do -Term.throwError ref ref +structure ElabHeaderResult := +(view : InductiveView) +(lctx : LocalContext) +(localInsts : LocalInstances) +(params : Array Expr) +(type : Expr) + +instance ElabHeaderResult.inhabited : Inhabited ElabHeaderResult := +⟨{ view := arbitrary _, lctx := arbitrary _, localInsts := arbitrary _, params := #[], type := arbitrary _ }⟩ + +private partial def elabHeaderAux (views : Array InductiveView) + : Nat → Array ElabHeaderResult → TermElabM (Array ElabHeaderResult) +| i, acc => + if h : i < views.size then + let view := views.get ⟨i, h⟩; + Term.elabBinders view.binders.getArgs fun params => do + lctx ← Term.getLCtx; + localInsts ← Term.getLocalInsts; + match view.type? with + | none => do + u ← Term.mkFreshLevelMVar view.ref; + let type := mkSort (mkLevelSucc u); + elabHeaderAux (i+1) (acc.push { lctx := lctx, localInsts := localInsts, params := params, type := type, view := view }) + | some typeStx => do + type ← Term.elabTerm typeStx none; + unlessM (Term.isTypeFormerType view.ref type) $ + Term.throwError typeStx "invalid inductive type, resultant type is not a sort"; + elabHeaderAux (i+1) (acc.push { lctx := lctx, localInsts := localInsts, params := params, type := type, view := view }) + else + pure acc + +private def checkNumParams (rs : Array ElabHeaderResult) : TermElabM Nat := do +let numParams := (rs.get! 0).params.size; +rs.forM fun r => unless (r.params.size == numParams) $ + Term.throwError r.view.ref "invalid inductive type, number of parameters mismatch in mutually inductive datatype"; +pure numParams + +private def mkTypeFor (r : ElabHeaderResult) : TermElabM Expr := do +Term.withLocalContext r.lctx r.localInsts do + Term.mkForall r.view.ref r.params r.type + +private def throwUnexpectedInductiveType {α} (ref : Syntax) : TermElabM α := +Term.throwError ref "unexpected inductive resulting type" + +-- Given `e` of the form `forall As, B`, return `B`. +private def getResultingType (ref : Syntax) (e : Expr) : TermElabM Expr := +Term.liftMetaM ref $ Meta.forallTelescopeReducing e fun _ r => pure r + +-- Auxiliary function for checking whether the types in mutually inductive declaration are compatible. +private partial def checkParamsAndResultType (ref : Syntax) (numParams : Nat) : Nat → Expr → Expr → TermElabM Unit +| i, type, firstType => do + type ← Term.whnf ref type; + if i < numParams then do + firstType ← Term.whnf ref firstType; + match type, firstType with + | Expr.forallE n₁ d₁ b₁ c₁, Expr.forallE n₂ d₂ b₂ c₂ => do + unless (n₁ == n₂) $ + let msg : MessageData := + "invalid mutually inductive type, parameter name mismatch '" ++ n₁ ++ "', expected '" ++ n₂ ++ "'"; + Term.throwError ref msg; + unlessM (Term.isDefEq ref d₁ d₂) $ + let msg : MessageData := + "invalid mutually inductive type, type mismatch at parameter '" ++ n₁ ++ "'" ++ indentExpr d₁ + ++ Format.line ++ "expected type" ++ indentExpr d₂; + Term.throwError ref msg; + unless (c₁.binderInfo == c₂.binderInfo) $ + -- TODO: improve this error message? + Term.throwError ref ("invalid mutually inductive type, binder annotation mismatch at parameter '" ++ n₁ ++ "'"); + Term.withLocalDecl ref n₁ c₁.binderInfo d₁ fun x => + let type := b₁.instantiate1 x; + let firstType := b₂.instantiate1 x; + checkParamsAndResultType (i+1) type firstType + | _, _ => throwUnexpectedInductiveType ref + else + match type with + | Expr.forallE n d b c => + Term.withLocalDecl ref n c.binderInfo d fun x => + let type := b.instantiate1 x; + checkParamsAndResultType (i+1) type firstType + | Expr.sort _ _ => do + firstType ← getResultingType ref firstType; + unlessM (Term.isDefEq ref type firstType) $ + let msg : MessageData := + "invalid mutually inductive type, resulting universe mismatch, given " ++ indentExpr type ++ Format.line ++ "expected type" ++ indentExpr firstType; + Term.throwError ref msg + | _ => throwUnexpectedInductiveType ref + +-- Auxiliary function for checking whether the types in mutually inductive declaration are compatible. +private def checkHeader (r : ElabHeaderResult) (numParams : Nat) (firstType? : Option Expr) : TermElabM Expr := do +type ← mkTypeFor r; +match firstType? with +| none => pure type +| some firstType => do + checkParamsAndResultType r.view.ref numParams 0 type firstType; + pure firstType + +-- Auxiliary function for checking whether the types in mutually inductive declaration are compatible. +private partial def checkHeaders (rs : Array ElabHeaderResult) (numParams : Nat) : Nat → Option Expr → TermElabM Unit +| i, firstType? => when (i < rs.size) do + type ← checkHeader (rs.get! i) numParams firstType?; + checkHeaders (i+1) type + +private def elabHeader (views : Array InductiveView) : TermElabM (Array ElabHeaderResult) := do +rs ← elabHeaderAux views 0 #[]; +when (rs.size > 1) do { + numParams ← checkNumParams rs; + checkHeaders rs numParams 0 none +}; +pure rs + +private def mkInductiveDecl (views : Array InductiveView) : TermElabM Declaration := do +rs ← elabHeader views; +Term.throwError (views.get! 0).ref "WIP 2" def elabInductiveCore (views : Array InductiveView) : CommandElabM Unit := do -let ref := (views.get! 0).ref; -throwError ref ("WIP\n" ++ toString (views.map (fun (v : InductiveView) => v.ref))) +decl ← liftTermElabM none $ mkInductiveDecl views; +pure () -- pure () /- withDeclId declId $ fun name => do diff --git a/src/Lean/Elab/Term.lean b/src/Lean/Elab/Term.lean index 25f2bc4282..884ca4659e 100644 --- a/src/Lean/Elab/Term.lean +++ b/src/Lean/Elab/Term.lean @@ -247,8 +247,10 @@ fun ctx s => def ppGoal (ref : Syntax) (mvarId : MVarId) : TermElabM Format := liftMetaM ref $ Meta.ppGoal mvarId def isType (ref : Syntax) (e : Expr) : TermElabM Bool := liftMetaM ref $ Meta.isType e def isTypeFormer (ref : Syntax) (e : Expr) : TermElabM Bool := liftMetaM ref $ Meta.isTypeFormer e +def isTypeFormerType (ref : Syntax) (e : Expr) : TermElabM Bool := liftMetaM ref $ Meta.isTypeFormerType e def isDefEqNoConstantApprox (ref : Syntax) (t s : Expr) : TermElabM Bool := liftMetaM ref $ Meta.approxDefEq $ Meta.isDefEq t s def isDefEq (ref : Syntax) (t s : Expr) : TermElabM Bool := liftMetaM ref $ Meta.fullApproxDefEq $ Meta.isDefEq t s +def isLevelDefEq (ref : Syntax) (u v : Level) : TermElabM Bool := liftMetaM ref $ Meta.isLevelDefEq u v def inferType (ref : Syntax) (e : Expr) : TermElabM Expr := liftMetaM ref $ Meta.inferType e def whnf (ref : Syntax) (e : Expr) : TermElabM Expr := liftMetaM ref $ Meta.whnf e def whnfForall (ref : Syntax) (e : Expr) : TermElabM Expr := liftMetaM ref $ Meta.whnfForall e @@ -887,6 +889,14 @@ finally x (modify $ fun s => { s with cache := { s.cache with synthInstance := s @[inline] def resettingSynthInstanceCacheWhen {α} (b : Bool) (x : TermElabM α) : TermElabM α := if b then resettingSynthInstanceCache x else x +def withLocalContext {α} (lctx : LocalContext) (localInsts : LocalInstances) (x : TermElabM α) : TermElabM α := do +localInstsCurr ← getLocalInsts; +adaptReader (fun (ctx : Context) => { ctx with lctx := lctx, localInstances := localInsts }) $ + if localInsts == localInstsCurr then + x + else + resettingSynthInstanceCache x + /-- Execute `x` using the given metavariable's `LocalContext` and `LocalInstances`. The type class resolution cache is flushed when executing `x` if its `LocalInstances` are diff --git a/src/Lean/Meta/InferType.lean b/src/Lean/Meta/InferType.lean index 528596fe95..5addbe5218 100644 --- a/src/Lean/Meta/InferType.lean +++ b/src/Lean/Meta/InferType.lean @@ -326,14 +326,14 @@ match r with | Expr.sort _ _ => pure true | _ => pure false -partial def isTypeFormerAux : Expr → MetaM Bool +partial def isTypeFormerType : Expr → MetaM Bool | type => do type ← whnfD type; match type with | Expr.sort _ _ => pure true | Expr.forallE n d b c => withLocalDecl n d c.binderInfo $ fun fvar => - isTypeFormerAux (b.instantiate1 fvar) + isTypeFormerType (b.instantiate1 fvar) | _ => pure false /-- @@ -341,7 +341,7 @@ partial def isTypeFormerAux : Expr → MetaM Bool Remark: it subsumes `isType` -/ def isTypeFormer (e : Expr) : MetaM Bool := do type ← inferType e; -isTypeFormerAux type +isTypeFormerType type end Meta end Lean diff --git a/tests/lean/inductive1.lean b/tests/lean/inductive1.lean new file mode 100644 index 0000000000..2ea523c379 --- /dev/null +++ b/tests/lean/inductive1.lean @@ -0,0 +1,51 @@ +new_frontend + +-- Test1 +inductive T1 : Nat -- Error, resultant type is not a sort + + +-- Test2 +mutual + +inductive T1 : Prop + +inductive T2 : Type -- Error resulting universe mismatch + +end + +-- Test3 +universes u v +mutual + +inductive T1 (x : Nat) : Type u + +inductive T2 (x : Nat) : Nat → Type v -- Error resulting universe mismatch + +end + +-- Test4 +mutual + +inductive T1 (b : Bool) (x : Nat) : Type + +inductive T2 (b : Bool) (x : Bool) : Type -- Type mismatch at 'x' + +end + +-- Test5 +mutual + +inductive T1 (b : Bool) (x : Nat) : Type + +inductive T2 (x : Bool) : Type -- number of parameters mismatch + +end + +-- Test6 +mutual + +inductive T1 (b : Bool) (x : Nat) : Type + +inductive T2 (b : Bool) {x : Nat} : Type -- binder annotation mismatch at 'x' + +end diff --git a/tests/lean/inductive1.lean.expected.out b/tests/lean/inductive1.lean.expected.out new file mode 100644 index 0000000000..14daf32560 --- /dev/null +++ b/tests/lean/inductive1.lean.expected.out @@ -0,0 +1,15 @@ +inductive1.lean:4:15: error: invalid inductive type, resultant type is not a sort +inductive1.lean:12:0: error: invalid mutually inductive type, resulting universe mismatch, given + Type +expected type + Prop +inductive1.lean:22:0: error: invalid mutually inductive type, resulting universe mismatch, given + Type v +expected type + Type u +inductive1.lean:31:0: error: invalid mutually inductive type, type mismatch at parameter 'x' + Bool +expected type + Nat +inductive1.lean:40:0: error: invalid inductive type, number of parameters mismatch in mutually inductive datatype +inductive1.lean:49:0: error: invalid mutually inductive type, binder annotation mismatch at parameter 'x'