diff --git a/src/Lean/Elab/PreDefinition/Main.lean b/src/Lean/Elab/PreDefinition/Main.lean index 050fedc8b3..df425dff3b 100644 --- a/src/Lean/Elab/PreDefinition/Main.lean +++ b/src/Lean/Elab/PreDefinition/Main.lean @@ -141,28 +141,29 @@ private def betaReduceLetRecApps (preDefs : Array PreDefinition) : MetaM (Array private def addSorried (preDefs : Array PreDefinition) : TermElabM Unit := do for preDef in preDefs do - let value ← mkSorry (synthetic := true) preDef.type - let decl := if preDef.kind.isTheorem then - Declaration.thmDecl { - name := preDef.declName, - levelParams := preDef.levelParams, - type := preDef.type, - value - } - else - Declaration.defnDecl { - name := preDef.declName, - levelParams := preDef.levelParams, - type := preDef.type, - hints := .abbrev - safety := .safe - value - } - addDecl decl - withSaveInfoContext do -- save new env - addTermInfo' preDef.ref (← mkConstWithLevelParams preDef.declName) (isBinder := true) - applyAttributesOf #[preDef] AttributeApplicationTime.afterTypeChecking - applyAttributesOf #[preDef] AttributeApplicationTime.afterCompilation + unless (← hasConst preDef.declName) do + let value ← mkSorry (synthetic := true) preDef.type + let decl := if preDef.kind.isTheorem then + Declaration.thmDecl { + name := preDef.declName, + levelParams := preDef.levelParams, + type := preDef.type, + value + } + else + Declaration.defnDecl { + name := preDef.declName, + levelParams := preDef.levelParams, + type := preDef.type, + hints := .abbrev + safety := .safe + value + } + addDecl decl + withSaveInfoContext do -- save new env + addTermInfo' preDef.ref (← mkConstWithLevelParams preDef.declName) (isBinder := true) + applyAttributesOf #[preDef] AttributeApplicationTime.afterTypeChecking + applyAttributesOf #[preDef] AttributeApplicationTime.afterCompilation def ensureFunIndReservedNamesAvailable (preDefs : Array PreDefinition) : MetaM Unit := do preDefs.forM fun preDef => diff --git a/src/Lean/Elab/PreDefinition/WF/Eqns.lean b/src/Lean/Elab/PreDefinition/WF/Eqns.lean index 3fedf735cb..07eb40f422 100644 --- a/src/Lean/Elab/PreDefinition/WF/Eqns.lean +++ b/src/Lean/Elab/PreDefinition/WF/Eqns.lean @@ -11,7 +11,6 @@ public import Lean.Meta.Tactic.Split public import Lean.Elab.PreDefinition.Basic public import Lean.Elab.PreDefinition.Eqns public import Lean.Meta.ArgsPacker.Basic -public import Lean.Elab.PreDefinition.WF.Unfold public import Lean.Elab.PreDefinition.FixedParams public import Init.Data.Array.Basic diff --git a/src/Lean/Elab/PreDefinition/WF/Unfold.lean b/src/Lean/Elab/PreDefinition/WF/Unfold.lean index 776724fb7c..a4e90f475e 100644 --- a/src/Lean/Elab/PreDefinition/WF/Unfold.lean +++ b/src/Lean/Elab/PreDefinition/WF/Unfold.lean @@ -1,23 +1,34 @@ /- Copyright (c) 2022 Microsoft Corporation. All rights reserved. Released under Apache 2.0 license as described in the file LICENSE. -Authors: Leonardo de Moura +Authors: Leonardo de Moura, Joachim Breitner -/ module prelude public import Lean.Elab.PreDefinition.Basic -public import Lean.Elab.PreDefinition.Eqns -public import Lean.Meta.Tactic.Apply +import Lean.Elab.PreDefinition.Eqns +import Lean.Meta.Tactic.Apply +import Lean.Meta.Tactic.Split +public import Lean.Meta.Tactic.Simp.Types import Lean.Meta.Tactic.Simp.Main +import Lean.Meta.Tactic.Simp.BuiltinSimprocs -public section +/-! +This module is responsible for proving the unfolding equation for functions defined +by well-founded recursion. It uses `WellFounded.fix_eq`, and then has to undo +the changes to matchers that `WF.Fix` did using `MatcherApp.addArg`. + +This is done using a single-pass `simp` traversal of the expression that looks +for expressions that were modified that way, and rewrites them back using the +rather specialized `_arg_pusher` theorem that is generated by `mkMatchArgPusher`. +-/ namespace Lean.Elab.WF open Meta open Eqns -private def rwFixEq (mvarId : MVarId) : MetaM MVarId := mvarId.withContext do +def rwFixEq (mvarId : MVarId) : MetaM MVarId := mvarId.withContext do let target ← mvarId.getType' let some (_, lhs, rhs) := target.eq? | unreachable! @@ -43,44 +54,157 @@ private def rwFixEq (mvarId : MVarId) : MetaM MVarId := mvarId.withContext do mvarId.assign (← mkEqTrans h mvarNew) return mvarNew.mvarId! -private partial def mkUnfoldProof (declName : Name) (mvarId : MVarId) : MetaM Unit := do - trace[Elab.definition.wf.eqns] "step\n{MessageData.ofGoal mvarId}" - if ← withAtLeastTransparency .all (tryURefl mvarId) then - trace[Elab.definition.wf.eqns] "refl!" - return () - else if (← tryContradiction mvarId) then - trace[Elab.definition.wf.eqns] "contradiction!" - return () - else if let some mvarId ← simpMatch? mvarId then - trace[Elab.definition.wf.eqns] "simpMatch!" - mkUnfoldProof declName mvarId - else if let some mvarId ← simpIf? mvarId (useNewSemantics := true) then - trace[Elab.definition.wf.eqns] "simpIf!" - mkUnfoldProof declName mvarId - else - let ctx ← Simp.mkContext (config := { dsimp := false, etaStruct := .none }) - match (← simpTargetStar mvarId ctx (simprocs := {})).1 with - | TacticResultCNM.closed => return () - | TacticResultCNM.modified mvarId => - trace[Elab.definition.wf.eqns] "simp only!" - mkUnfoldProof declName mvarId - | TacticResultCNM.noChange => - if let some mvarIds ← casesOnStuckLHS? mvarId then - trace[Elab.definition.wf.eqns] "case split into {mvarIds.size} goals" - mvarIds.forM (mkUnfoldProof declName) - else if let some mvarIds ← splitTarget? mvarId (useNewSemantics := true) then - trace[Elab.definition.wf.eqns] "splitTarget into {mvarIds.length} goals" - mvarIds.forM (mkUnfoldProof declName) - else - -- At some point in the past, we looked for occurrences of Wf.fix to fold on the - -- LHS (introduced in 096e4eb), but it seems that code path was never used, - -- so #3133 removed it again (and can be recovered from there if this was premature). - throwError "failed to generate equational theorem for '{declName}'\n{MessageData.ofGoal mvarId}" +def isForallMotive (matcherApp : MatcherApp) : MetaM (Option Expr) := do + lambdaBoundedTelescope matcherApp.motive matcherApp.discrs.size fun xs t => + if xs.size == matcherApp.discrs.size && t.isForall && !t.bindingBody!.hasLooseBVar 0 then + return some (← mkLambdaFVars xs t.bindingBody!) + else + return none -def mkUnfoldEq (preDef : PreDefinition) (unaryPreDefName : Name) (wfPreprocessProof : Simp.Result) : MetaM Unit := do + +/-- Generalization of `splitMatch` that can handle `casesOn` -/ +def splitMatchOrCasesOn (mvarId : MVarId) (e : Expr) (matcherInfo : MatcherInfo) : MetaM (List MVarId) := do + if (← isMatcherApp e) then + Split.splitMatch mvarId e + else + assert! matcherInfo.numDiscrs = 1 + let discr := e.getAppArgs[matcherInfo.numParams + 1]! + assert! discr.isFVar + let subgoals ← mvarId.cases discr.fvarId! + return subgoals.map (·.mvarId) |>.toList + +/-- +Generates a theorem of the form +``` +matcherArgPusher params motive {α} {β} (f : ∀ (x : α), β x) rel alt1 .. x1 x2 + : + matcher params (motive := fun x1 x2 => ((y : α) → rel x1 x2 y → β y) → motive x1 x2) + (alt1 := fun z1 z2 z2 f => alt1 z1 z2 z2 f) … + x1 x2 + (fun y _h => f y) + = + matcher params (motive := motive) + (alt1 := fun z1 z2 z2 => alt1 z1 z2 z2 (fun y _ => f y)) … + x1 x2 +``` +-/ +def mkMatchArgPusher (matcherName : Name) (matcherInfo : MatcherInfo) : MetaM Name := do + let name := (mkPrivateName (← getEnv) matcherName) ++ `_arg_pusher + realizeConst matcherName name do + let matcherVal ← getConstVal matcherName + forallBoundedTelescope matcherVal.type (some (matcherInfo.numParams + 1)) fun xs _ => do + let params := xs[*...matcherInfo.numParams] + let motive' := xs[matcherInfo.numParams]! + let u ← mkFreshUserName `u + let v ← mkFreshUserName `v + withLocalDeclD `α (.sort (.param u)) fun alpha => do + withLocalDeclD `β (← mkArrow alpha (.sort (.param v))) fun beta => do + withLocalDeclD `f (.forallE `x alpha (mkApp beta (.bvar 0)) .default) fun f => do + let relType ← forallTelescope (← inferType motive') fun xs _ => + mkForallFVars xs (.forallE `x alpha (.sort 0) .default) + withLocalDeclD `rel relType fun rel => do + let motive ← forallTelescope (← inferType motive') fun xs _ => do + let motiveBody := mkAppN motive' xs + let extraArgType := .forallE `y alpha (.forallE `h (mkAppN rel (xs.push (.bvar 0))) (mkApp beta (.bvar 1)) .default) .default + let motiveBody ← mkArrow extraArgType motiveBody + mkLambdaFVars xs motiveBody + + let uElim ← lambdaBoundedTelescope motive matcherInfo.numDiscrs fun _ motiveBody => do + getLevel motiveBody + let us := matcherVal.levelParams ++ [u, v] + let matcherLevels' := matcherVal.levelParams.map mkLevelParam + let matcherLevels ← match matcherInfo.uElimPos? with + | none => + unless uElim.isZero do + throwError "unexpected matcher application for {.ofConstName matcherName}, motive is not a proposition" + pure matcherLevels' + | some pos => + pure <| (matcherLevels'.toArray.set! pos uElim).toList + let lhs := .const matcherName matcherLevels + let rhs := .const matcherName matcherLevels' + let lhs := mkAppN lhs params + let rhs := mkAppN rhs params + let lhs := mkApp lhs motive + let rhs := mkApp rhs motive' + forallBoundedTelescope (← inferType lhs) matcherInfo.numDiscrs fun discrs _ => do + let lhs := mkAppN lhs discrs + let rhs := mkAppN rhs discrs + forallBoundedTelescope (← inferType lhs) matcherInfo.numAlts fun alts _ => do + let lhs := mkAppN lhs alts + + let mut rhs := rhs + for alt in alts, altNumParams in matcherInfo.altNumParams do + let alt' ← forallBoundedTelescope (← inferType alt) altNumParams fun ys altBodyType => do + assert! altBodyType.isForall + let altArg ← forallBoundedTelescope altBodyType.bindingDomain! (some 2) fun ys _ => do + mkLambdaFVars ys (.app f ys[0]!) + mkLambdaFVars ys (mkAppN alt (ys.push altArg)) + rhs := mkApp rhs alt' + + let extraArg := .lam `y alpha (.lam `h (mkAppN rel (discrs.push (.bvar 0))) (mkApp f (.bvar 1)) .default) .default + let lhs := mkApp lhs extraArg + let goal ← mkEq lhs rhs + + let value ← mkFreshExprSyntheticOpaqueMVar goal + let mvarId := value.mvarId! + let mvarIds ← splitMatchOrCasesOn mvarId rhs matcherInfo + for mvarId in mvarIds do + mvarId.refl + let value ← instantiateMVars value + let type ← mkForallFVars (params ++ #[motive', alpha, beta, f, rel] ++ discrs ++ alts) goal + let value ← mkLambdaFVars (params ++ #[motive', alpha, beta, f, rel] ++ discrs ++ alts) value + addDecl <| Declaration.thmDecl { name, levelParams := us, type, value} + return name + +builtin_simproc_decl matcherPushArg (_) := fun e => do + let e := e.headBeta + let some matcherApp ← matchMatcherApp? e (alsoCasesOn := true) | return .continue + -- Check that the first remaining argument is of the form `(fun (x : α) p => (f x : β x))` + let some fArg := matcherApp.remaining[0]? | return .continue + unless fArg.isLambda do return .continue + unless fArg.bindingBody!.isLambda do return .continue + unless fArg.bindingBody!.bindingBody!.isApp do return .continue + if fArg.bindingBody!.bindingBody!.hasLooseBVar 0 then return .continue + unless fArg.bindingBody!.bindingBody!.appArg! == .bvar 1 do return .continue + if fArg.bindingBody!.bindingBody!.appFn!.hasLooseBVar 1 then return .continue + + let fExpr := fArg.bindingBody!.bindingBody!.appFn! + let fExprType ← inferType fExpr + let fExprType ← withTransparency .all (whnfForall fExprType) + assert! fExprType.isForall + let alpha := fExprType.bindingDomain! + let beta := .lam fExprType.bindingName! fExprType.bindingDomain! fExprType.bindingBody! .default + + -- Check that the motive has an extra parameter (from MatcherApp.addArg) + let some motive' ← isForallMotive matcherApp | return .continue + let rel ← lambdaTelescope matcherApp.motive fun xs motiveBody => + let motiveBodyArg := motiveBody.bindingDomain! + mkLambdaFVars xs (.lam motiveBodyArg.bindingName! motiveBodyArg.bindingDomain! motiveBodyArg.bindingBody!.bindingDomain! .default) + + let argPusher ← mkMatchArgPusher matcherApp.matcherName matcherApp.toMatcherInfo + -- Let's infer the level paramters: + let proof ← withTransparency .all <| mkAppOptM + argPusher ((matcherApp.params ++ #[motive', alpha, beta, fExpr, rel] ++ matcherApp.discrs ++ matcherApp.alts).map some) + let some (_, _, rhs) := (← inferType proof).eq? | throwError "matcherPushArg: expected equality:{indentExpr (← inferType proof)}" + let step : Simp.Result := { expr := rhs, proof? := some proof } + let step ← step.addExtraArgs matcherApp.remaining[1...*] + return .continue (some step) + +def mkUnfoldProof (declName : Name) (mvarId : MVarId) : MetaM Unit := withTransparency .all do + let ctx ← Simp.mkContext (config := { dsimp := false, etaStruct := .none, letToHave := false, singlePass := true }) + let simprocs := ({} : Simp.SimprocsArray) + let simprocs ← simprocs.add ``matcherPushArg (post := false) + match (← simpTarget mvarId ctx (simprocs := simprocs)).1 with + | none => return () + | some mvarId' => + prependError m!"failed to finish proof for equational theorem for '{.ofConstName declName}'" do + mvarId'.refl + +public def mkUnfoldEq (preDef : PreDefinition) (unaryPreDefName : Name) (wfPreprocessProof : Simp.Result) : MetaM Unit := do let name := mkEqLikeNameFor (← getEnv) preDef.declName unfoldThmSuffix - prependError m!"Cannot derive {name}" do + prependError m!"Cannot derive unfold equation {name}" do withOptions (tactic.hygienic.set · false) do + withoutExporting do lambdaTelescope preDef.value fun xs body => do let us := preDef.levelParams.map mkLevelParam let lhs := mkAppN (Lean.mkConst preDef.declName us) xs @@ -111,7 +235,7 @@ theorem of `foo._unary` or `foo._binary`. It should just be a specialization of that one, due to defeq. -/ -def mkBinaryUnfoldEq (preDef : PreDefinition) (unaryPreDefName : Name) : MetaM Unit := do +public def mkBinaryUnfoldEq (preDef : PreDefinition) (unaryPreDefName : Name) : MetaM Unit := do let name := mkEqLikeNameFor (← getEnv) preDef.declName unfoldThmSuffix let unaryEqName:= mkEqLikeNameFor (← getEnv) unaryPreDefName unfoldThmSuffix prependError m!"Cannot derive {name} from {unaryEqName}" do diff --git a/src/Lean/Meta/Constructions/CasesOn.lean b/src/Lean/Meta/Constructions/CasesOn.lean index e05a0f6c16..7a087a6e1a 100644 --- a/src/Lean/Meta/Constructions/CasesOn.lean +++ b/src/Lean/Meta/Constructions/CasesOn.lean @@ -23,5 +23,6 @@ def mkCasesOn (declName : Name) : MetaM Unit := do addDecl decl setReducibleAttribute name modifyEnv fun env => markAuxRecursor env name + enableRealizationsForConst name end Lean diff --git a/src/Lean/Meta/Match/MatcherApp/Basic.lean b/src/Lean/Meta/Match/MatcherApp/Basic.lean index c995eddf55..f7119e9866 100644 --- a/src/Lean/Meta/Match/MatcherApp/Basic.lean +++ b/src/Lean/Meta/Match/MatcherApp/Basic.lean @@ -32,6 +32,7 @@ of matcher applications. -/ def matchMatcherApp? [Monad m] [MonadEnv m] [MonadError m] (e : Expr) (alsoCasesOn := false) : m (Option MatcherApp) := do + unless e.isApp do return none if let .const declName declLevels := e.getAppFn then if let some info ← getMatcherInfo? declName then let args := e.getAppArgs @@ -74,6 +75,13 @@ def matchMatcherApp? [Monad m] [MonadEnv m] [MonadError m] (e : Expr) (alsoCases return none +def MatcherApp.toMatcherInfo (matcherApp : MatcherApp) : MatcherInfo where + uElimPos? := matcherApp.uElimPos? + discrInfos := matcherApp.discrInfos + numParams := matcherApp.params.size + numDiscrs := matcherApp.discrs.size + altNumParams := matcherApp.altNumParams + def MatcherApp.toExpr (matcherApp : MatcherApp) : Expr := let result := mkAppN (mkConst matcherApp.matcherName matcherApp.matcherLevels.toList) matcherApp.params let result := mkApp result matcherApp.motive diff --git a/tests/lean/run/issue9646.lean b/tests/lean/run/issue9646.lean new file mode 100644 index 0000000000..f84da19bc4 --- /dev/null +++ b/tests/lean/run/issue9646.lean @@ -0,0 +1,11 @@ +/-! +Checks that that the wfrec unfold theorem can be generated even if the +function type is not manifestly a forall. +-/ + +def T := Nat → Nat + +def f : T +| 0 => 0 +| n + 1 => f n + 1 +termination_by n => n