diff --git a/src/Lean/Elab/Inductive.lean b/src/Lean/Elab/Inductive.lean index da0121855c..3d24e3fefc 100644 --- a/src/Lean/Elab/Inductive.lean +++ b/src/Lean/Elab/Inductive.lean @@ -342,15 +342,6 @@ private def withUsed {α} (ref : Syntax) (vars : Array Expr) (indTypes : List In (lctx, localInsts, vars) ← removeUnused ref vars indTypes; Term.withLCtx lctx localInsts $ k vars -abbrev Ctor2InferMod := Std.HashMap Name Bool - -private def mkCtor2InferMod (views : Array InductiveView) : Ctor2InferMod := -views.foldl - (fun (m : Ctor2InferMod) view => view.ctors.foldl - (fun (m : Ctor2InferMod) ctorView => m.insert ctorView.declName ctorView.inferMod) - m) - {} - private def updateParams (ref : Syntax) (vars : Array Expr) (indTypes : List InductiveType) : TermElabM (List InductiveType) := indTypes.mapM fun indType => do type ← Term.mkForall ref vars indType.type; @@ -396,6 +387,24 @@ indTypes.mapM fun indType => do }; pure { indType with ctors := ctors } +abbrev Ctor2InferMod := Std.HashMap Name Bool + +private def mkCtor2InferMod (views : Array InductiveView) : Ctor2InferMod := +views.foldl + (fun (m : Ctor2InferMod) view => view.ctors.foldl + (fun (m : Ctor2InferMod) ctorView => m.insert ctorView.declName ctorView.inferMod) + m) + {} + +private def applyInferMod (views : Array InductiveView) (numParams : Nat) (indTypes : List InductiveType) : List InductiveType := +let ctor2InferMod := mkCtor2InferMod views; +indTypes.map fun indType => + let ctors := indType.ctors.map fun ctor => + let inferMod := ctor2InferMod.find! ctor.name; -- true if `{}` was used + let ctorType := ctor.type.inferImplicit numParams !inferMod; + { ctor with type := ctorType }; + { indType with ctors := ctors } + private def mkInductiveDecl (vars : Array Expr) (views : Array InductiveView) : TermElabM Declaration := do let view0 := views.get! 0; scopeLevelNames ← Term.getLevelNames; @@ -429,18 +438,17 @@ adaptReader (fun (ctx : Term.Context) => { ctx with levelNames := allUserLevelNa | Except.error msg => Term.throwError ref msg | Except.ok levelParams => do indTypes ← replaceIndFVarsWithConsts views indFVars levelParams numParams indTypes; + let indTypes := applyInferMod views numParams indTypes; traceIndTypes indTypes; - let decl := Declaration.inductDecl levelParams numParams indTypes isUnsafe; - -- TODO: use inferImplicit at ctors - Term.throwError ref "WIP" - -- pure decl + pure $ Declaration.inductDecl levelParams numParams indTypes isUnsafe def elabInductiveCore (views : Array InductiveView) : CommandElabM Unit := do let view0 := views.get! 0; decl ← runTermElabM view0.declName $ fun vars => mkInductiveDecl vars views; --- TODO +-- TODO register decl pure () + end Command end Elab end Lean