diff --git a/src/Lean/Meta/SizeOf.lean b/src/Lean/Meta/SizeOf.lean index 7a48512c4a..61d1aba5e1 100644 --- a/src/Lean/Meta/SizeOf.lean +++ b/src/Lean/Meta/SizeOf.lean @@ -131,30 +131,117 @@ partial def mkSizeOfFn (recName : Name) (declName : Name): MetaM Unit := do /-- Create `sizeOf` functions for all inductive datatypes in the mutual inductive declaration containing `typeName` - The resulting array contains the generated functions names. + The resulting array contains the generated functions names. The `NameMap` maps recursor names into the generated function names. There is a function for each element of the mutual inductive declaration, and for auxiliary recursors for nested inductive types. -/ -def mkSizeOfFns (typeName : Name) : MetaM (Array Name) := do +def mkSizeOfFns (typeName : Name) : MetaM (Array Name × NameMap Name) := do let indInfo ← getConstInfoInduct typeName let recInfo ← getConstInfoRec (mkRecName typeName) let numExtra := recInfo.numMotives - indInfo.all.length -- numExtra > 0 for nested inductive types let mut result := #[] let baseName := indInfo.all.head! ++ `_sizeOf -- we use the first inductive type as the base name for `sizeOf` functions let mut i := 1 + let mut recMap : NameMap Name := {} for indTypeName in indInfo.all do let sizeOfName := baseName.appendIndexAfter i - mkSizeOfFn (mkRecName indTypeName) sizeOfName + let recName := mkRecName indTypeName + mkSizeOfFn recName sizeOfName + recMap := recMap.insert recName sizeOfName result := result.push sizeOfName i := i + 1 for j in [:numExtra] do let recName := (mkRecName indInfo.all.head!).appendIndexAfter (j+1) let sizeOfName := baseName.appendIndexAfter i mkSizeOfFn recName sizeOfName + recMap := recMap.insert recName sizeOfName result := result.push sizeOfName i := i + 1 - return result + return (result, recMap) -private def mkSizeOfSpecTheorem (indInfo : InductiveVal) (sizeOfFns : Array Name) (ctorName : Name) : MetaM Unit := do +/- SizeOf spec theorem for nested inductive types -/ +namespace SizeOfSpecNested + +structure Context where + indInfo : InductiveVal + sizeOfFns : Array Name + ctorName : Name + params : Array Expr + localInsts : Array Expr + recMap : NameMap Name -- mapping from recursor name into `_sizeOf_` function name (see `mkSizeOfFns`) + +abbrev M := ReaderT Context MetaM + +def throwUnexpected {α} (msg : MessageData) : M α := do + throwError! "failed to generate sizeOf lemma for {(← read).ctorName} (use `set_option genSizeOfSpec false` to disable lemma generation), {msg}" + +def throwFailed {α} : M α := do + throwError! "failed to generate sizeOf lemma for {(← read).ctorName}, (use `set_option genSizeOfSpec false` to disable lemma generation)" + +/-- Convert a recursor application into a `_sizeOf_` application. -/ +private def recToSizeOf (e : Expr) : M Expr := do + matchConstRec e.getAppFn (fun _ => throwFailed) fun info us => do + match (← read).recMap.find? info.name with + | none => throwUnexpected m!"expected recursor application {indentExpr e}" + | some sizeOfName => + let args := e.getAppArgs + let indices := args[info.getFirstIndexIdx : info.getFirstIndexIdx + info.numIndices] + let major := args[info.getMajorIdx] + return mkAppN (mkConst sizeOfName us.tail!) ((← read).params ++ (← read).localInsts ++ indices ++ #[major]) + +/-- + Generate proof for `C._sizeOf_ t = sizeOf t` where `C._sizeOf_` is a auxiliary function + generated for a nested inductive type in `C`. + For example, given + ```lean + inductive Expr where + | app (f : String) (args : List Expr) + ``` + We generate the auxiliary function `Expr._sizeOf_1 : List Expr → Nat`. + To generate the `sizeOf` spec lemma + ``` + sizeOf (Expr.app f args) = 1 + sizeOf f + sizeOf args + ``` + we need an auxiliary lemma for showing `Expr._sizeOf_1 args = sizeOf args`. + Recall that `sizeOf (Expr.app f args)` is definitionally equal to `1 + sizeOf f + sizeOf args`, but + `Expr._sizeOf_1 args` is **not** definitionally equal to `sizeOf args`. We need a proof by induction. +-/ +private def mkSizeOfAuxLemma (lhs rhs : Expr) : M Expr := do + -- TODO + mkSorry (← mkEq lhs rhs) true + +/- Prove SizeOf spec lemma of the form `sizeOf = 1 + sizeOf + ... + sizeOf -/ +partial def main (lhs rhs : Expr) : M Expr := do + if (← isDefEq lhs rhs) then + mkEqRefl rhs + else + /- Expand lhs and rhs to obtain `Nat.add` applications -/ + let lhs ← whnfI lhs -- Expand `sizeOf (ctor ...)` into `_sizeOf_` application + let lhs ← unfoldDefinition lhs -- Unfold `_sizeOf_` application into `HAdd.hAdd` application + loop lhs rhs +where + loop (lhs rhs : Expr) : M Expr := do + trace[Meta.sizeOf.loop]! "{lhs} =?= {rhs}" + if (← isDefEq lhs rhs) then + mkEqRefl rhs + else + match (← whnfI lhs).natAdd?, (← whnfI rhs).natAdd? with + | some (a₁, b₁), some (a₂, b₂) => + let p₁ ← loop a₁ a₂ + let p₂ ← step b₁ b₂ + mkCongr (← mkCongrArg (mkConst ``Nat.add) p₁) p₂ + | _, _ => + throwUnexpected m!"expected 'Nat.add' application, lhs is {indentExpr lhs}\nrhs is{indentExpr rhs}" + + step (lhs rhs : Expr) : M Expr := do + if (← isDefEq lhs rhs) then + mkEqRefl rhs + else + let lhs ← recToSizeOf lhs + mkSizeOfAuxLemma lhs rhs + +end SizeOfSpecNested + +private def mkSizeOfSpecTheorem (indInfo : InductiveVal) (sizeOfFns : Array Name) (recMap : NameMap Name) (ctorName : Name) : MetaM Unit := do let ctorInfo ← getConstInfoCtor ctorName let us := ctorInfo.levelParams.map mkLevelParam forallTelescopeReducing ctorInfo.type fun xs _ => do @@ -173,7 +260,9 @@ private def mkSizeOfSpecTheorem (indInfo : InductiveVal) (sizeOfFns : Array Name let thmType ← mkForallFVars thmParams target let thmValue ← if indInfo.isNested then - return () -- TODO + SizeOfSpecNested.main lhs rhs |>.run { + indInfo := indInfo, sizeOfFns := sizeOfFns, ctorName := ctorName, params := params, localInsts := localInsts, recMap := recMap + } else mkEqRefl rhs let thmValue ← mkLambdaFVars thmParams thmValue @@ -184,11 +273,11 @@ private def mkSizeOfSpecTheorem (indInfo : InductiveVal) (sizeOfFns : Array Name value := thmValue } -private def mkSizeOfSpecTheorems (indTypeNames : Array Name) (sizeOfFns : Array Name) : MetaM Unit := do +private def mkSizeOfSpecTheorems (indTypeNames : Array Name) (sizeOfFns : Array Name) (recMap : NameMap Name) : MetaM Unit := do for indTypeName in indTypeNames do let indInfo ← getConstInfoInduct indTypeName for ctorName in indInfo.ctors do - mkSizeOfSpecTheorem indInfo sizeOfFns ctorName + mkSizeOfSpecTheorem indInfo sizeOfFns recMap ctorName return () builtin_initialize @@ -205,7 +294,7 @@ def mkSizeOfInstances (typeName : Name) : MetaM Unit := do if (← getEnv).contains ``SizeOf && generateSizeOfInstance (← getOptions) && !(← isInductivePredicate typeName) then let indInfo ← getConstInfoInduct typeName unless indInfo.isUnsafe do - let fns ← mkSizeOfFns typeName + let (fns, recMap) ← mkSizeOfFns typeName for indTypeName in indInfo.all, fn in fns do let indInfo ← getConstInfoInduct indTypeName forallTelescopeReducing indInfo.type fun xs _ => @@ -231,7 +320,7 @@ def mkSizeOfInstances (typeName : Name) : MetaM Unit := do } addInstance instDeclName AttributeKind.global (evalPrio! default) if generateSizeOfSpec (← getOptions) then - mkSizeOfSpecTheorems indInfo.all.toArray fns + mkSizeOfSpecTheorems indInfo.all.toArray fns recMap builtin_initialize registerTraceClass `Meta.sizeOf