From 85c49cfeb3cade34cb877fd72dfbc23a1f36ee6d Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sun, 3 Oct 2021 18:44:13 -0700 Subject: [PATCH] feat: apply termination tactic provided by user --- src/Lean/Elab/Declaration.lean | 12 ++++---- src/Lean/Elab/PreDefinition/Main.lean | 14 ++++----- src/Lean/Elab/PreDefinition/WF/Fix.lean | 13 +++++--- .../PreDefinition/WF/TerminationHint.lean | 30 +++++++++++-------- src/Lean/Parser/Command.lean | 4 +-- ...reasing_tactic.lean => decreasing_by.lean} | 8 ++--- tests/lean/decreasing_by.lean.expected.out | 4 +++ .../lean/decreasing_tactic.lean.expected.out | 4 --- tests/lean/run/mutwf2.lean | 18 +++++++++++ tests/lean/termination_by.lean.expected.out | 2 +- 10 files changed, 69 insertions(+), 40 deletions(-) rename tests/lean/{decreasing_tactic.lean => decreasing_by.lean} (87%) create mode 100644 tests/lean/decreasing_by.lean.expected.out delete mode 100644 tests/lean/decreasing_tactic.lean.expected.out create mode 100644 tests/lean/run/mutwf2.lean diff --git a/src/Lean/Elab/Declaration.lean b/src/Lean/Elab/Declaration.lean index 8b6ef8700b..d332899cf2 100644 --- a/src/Lean/Elab/Declaration.lean +++ b/src/Lean/Elab/Declaration.lean @@ -142,7 +142,7 @@ def getTerminationHints (stx : Syntax) : TerminationHints := let k := decl.getKind if k == ``Parser.Command.def || k == ``Parser.Command.theorem || k == ``Parser.Command.instance then let args := decl.getArgs - { terminationBy? := args[args.size - 2].getOptional?, decreasingTactic? := args[args.size - 1].getOptional? } + { terminationBy? := args[args.size - 2].getOptional?, decreasingBy? := args[args.size - 1].getOptional? } else {} @@ -256,20 +256,20 @@ def expandMutualPreamble : Macro := fun stx => @[builtinCommandElab «mutual»] def elabMutual : CommandElab := fun stx => do - let hints := { terminationBy? := stx[3].getOptional?, decreasingTactic? := stx[4].getOptional? } + let hints := { terminationBy? := stx[3].getOptional?, decreasingBy? := stx[4].getOptional? } if isMutualInductive stx then if let some bad := hints.terminationBy? then throwErrorAt bad "invalid 'termination_by' in mutually inductive datatype declaration" - if let some bad := hints.decreasingTactic? then - throwErrorAt bad "invalid 'decreasing_tactic' in mutually inductive datatype declaration" + if let some bad := hints.decreasingBy? then + throwErrorAt bad "invalid 'decreasing_by' in mutually inductive datatype declaration" elabMutualInductive stx[1].getArgs else if isMutualDef stx then for arg in stx[1].getArgs do let argHints := getTerminationHints arg if let some bad := argHints.terminationBy? then throwErrorAt bad "invalid 'termination_by' in 'mutual' block, it must be used after the 'end' keyword" - if let some bad := argHints.decreasingTactic? then - throwErrorAt bad "invalid 'decreasing_tactic' in 'mutual' block, it must be used after the 'end' keyword" + if let some bad := argHints.decreasingBy? then + throwErrorAt bad "invalid 'decreasing_by' in 'mutual' block, it must be used after the 'end' keyword" elabMutualDef stx[1].getArgs hints else throwError "invalid mutual block" diff --git a/src/Lean/Elab/PreDefinition/Main.lean b/src/Lean/Elab/PreDefinition/Main.lean index 9255a7d077..4874e4230f 100644 --- a/src/Lean/Elab/PreDefinition/Main.lean +++ b/src/Lean/Elab/PreDefinition/Main.lean @@ -12,7 +12,7 @@ open Term structure TerminationHints where terminationBy? : Option Syntax := none - decreasingTactic? : Option Syntax := none + decreasingBy? : Option Syntax := none deriving Inhabited private def addAndCompilePartial (preDefs : Array PreDefinition) : TermElabM Unit := do @@ -68,7 +68,7 @@ def addPreDefinitions (preDefs : Array PreDefinition) (hints : TerminationHints) let preDefs ← preDefs.mapM ensureNoUnassignedMVarsAtPreDef let cliques ← partitionPreDefs preDefs let mut terminationBy ← liftMacroM <| WF.expandTerminationHint hints.terminationBy? (cliques.map fun ds => ds.map (·.declName)) - let mut decreasingTactic ← liftMacroM <| WF.expandTerminationHint hints.decreasingTactic? (cliques.map fun ds => ds.map (·.declName)) + let mut decreasingBy ← liftMacroM <| WF.expandTerminationHint hints.decreasingBy? (cliques.map fun ds => ds.map (·.declName)) for preDefs in cliques do trace[Elab.definition.scc] "{preDefs.map (·.declName)}" if preDefs.size == 1 && isNonRecursive preDefs[0] then @@ -87,12 +87,12 @@ def addPreDefinitions (preDefs : Array PreDefinition) (hints : TerminationHints) else let mut wfStx? := none let mut decrTactic? := none - if let some wfStx := terminationBy.find? (preDefs.map (·.declName)) then + if let some { value := wfStx, .. } := terminationBy.find? (preDefs.map (·.declName)) then wfStx? := some wfStx terminationBy := terminationBy.erase (preDefs.map (·.declName)) - if let some decrTactic := decreasingTactic.find? (preDefs.map (·.declName)) then - decrTactic? := some decrTactic - decreasingTactic := decreasingTactic.erase (preDefs.map (·.declName)) + if let some { ref, value := decrTactic } := decreasingBy.find? (preDefs.map (·.declName)) then + decrTactic? := some (← withRef ref `(by $decrTactic)) + decreasingBy := decreasingBy.erase (preDefs.map (·.declName)) if wfStx?.isSome || decrTactic?.isSome then wfRecursion preDefs wfStx? decrTactic? else @@ -104,7 +104,7 @@ def addPreDefinitions (preDefs : Array PreDefinition) (hints : TerminationHints) let preDefMsgs := preDefs.toList.map (MessageData.ofExpr $ mkConst ·.declName) m!"fail to show termination for{indentD (MessageData.joinSep preDefMsgs Format.line)}\nwith errors\n{msg}") liftMacroM <| terminationBy.ensureIsEmpty - liftMacroM <| decreasingTactic.ensureIsEmpty + liftMacroM <| decreasingBy.ensureIsEmpty builtin_initialize registerTraceClass `Elab.definition.body diff --git a/src/Lean/Elab/PreDefinition/WF/Fix.lean b/src/Lean/Elab/PreDefinition/WF/Fix.lean index 646c5fa4c0..a4fcfc8ac5 100644 --- a/src/Lean/Elab/PreDefinition/WF/Fix.lean +++ b/src/Lean/Elab/PreDefinition/WF/Fix.lean @@ -14,15 +14,20 @@ open Meta private def toUnfold : Std.PHashSet Name := [``measure, ``id, ``Prod.lex, ``invImage, ``InvImage, ``Nat.lt_wfRel].foldl (init := {}) fun s a => s.insert a -private def mkDecreasingProof (decreasingProp : Expr) : TermElabM Expr := do - let mvar ← mkFreshExprSyntheticOpaqueMVar decreasingProp - let mvarId := mvar.mvarId! +private def applyDefaultDecrTactic (mvarId : MVarId) : TermElabM Unit := do let ctx ← Simp.Context.mkDefault let ctx := { ctx with simpLemmas.toUnfold := toUnfold } if let some mvarId ← simpTarget mvarId ctx then -- TODO: invoke tactic to close the goal trace[Elab.definition.wf] "{MessageData.ofGoal mvarId}" admit mvarId + +private def mkDecreasingProof (decreasingProp : Expr) (decrTactic? : Option Syntax) : TermElabM Expr := do + let mvar ← mkFreshExprSyntheticOpaqueMVar decreasingProp + let mvarId := mvar.mvarId! + match decrTactic? with + | none => applyDefaultDecrTactic mvarId + | some decrTactic => Term.runTactic mvarId decrTactic instantiateMVars mvar private partial def replaceRecApps (recFnName : Name) (decrTactic? : Option Syntax) (F : Expr) (e : Expr) : TermElabM Expr := @@ -45,7 +50,7 @@ private partial def replaceRecApps (recFnName : Name) (decrTactic? : Option Synt if f.isConstOf recFnName && args.size == 1 then let r := mkApp F args[0] let decreasingProp := (← whnf (← inferType r)).bindingDomain! - return mkApp r (← mkDecreasingProof decreasingProp) + return mkApp r (← mkDecreasingProof decreasingProp decrTactic?) else return mkAppN (← loop F f) (← args.mapM (loop F)) let matcherApp? ← matchMatcherApp? e diff --git a/src/Lean/Elab/PreDefinition/WF/TerminationHint.lean b/src/Lean/Elab/PreDefinition/WF/TerminationHint.lean index f79d89a548..10cc9714d6 100644 --- a/src/Lean/Elab/PreDefinition/WF/TerminationHint.lean +++ b/src/Lean/Elab/PreDefinition/WF/TerminationHint.lean @@ -7,30 +7,36 @@ import Lean.Parser.Command namespace Lean.Elab.WF +structure TerminationHintValue where + ref : Syntax + value : Syntax + deriving Inhabited + inductive TerminationHint where | none - | one (stx : Syntax) - | many (map : NameMap Syntax) + | one (val : TerminationHintValue) + | many (map : NameMap TerminationHintValue) deriving Inhabited def expandTerminationHint (terminationHint? : Option Syntax) (cliques : Array (Array Name)) : MacroM TerminationHint := do if let some terminationHint := terminationHint? then + let ref := terminationHint let terminationHint := terminationHint[1] if terminationHint.getKind == ``Parser.Command.terminationHint1 then - return TerminationHint.one terminationHint[0] + return TerminationHint.one { ref, value := terminationHint[0] } else if terminationHint.getKind == ``Parser.Command.terminationHintMany then let m ← terminationHint[0].getArgs.foldlM (init := {}) fun m arg => let declName? := cliques.findSome? fun clique => clique.findSome? fun declName => if arg[0].getId.isSuffixOf declName then some declName else none match declName? with | none => Macro.throwErrorAt arg[0] s!"function '{arg[0].getId}' not found in current declaration" - | some declName => return m.insert declName arg[2] + | some declName => return m.insert declName { ref := arg, value := arg[2] } for clique in cliques do let mut found? := Option.none for declName in clique do - if let some stx := m.find? declName then + if let some { ref, .. } := m.find? declName then if let some found := found? then - Macro.throwErrorAt stx s!"invalid termination hint element, '{declName}' and '{found}' are in the same clique" + Macro.throwErrorAt ref s!"invalid termination hint element, '{declName}' and '{found}' are in the same clique" found? := some declName return TerminationHint.many m else @@ -53,16 +59,16 @@ def TerminationHint.erase (t : TerminationHint) (clique : Array Name) : Terminat return TerminationHint.many m return t -def TerminationHint.find? (t : TerminationHint) (clique : Array Name) : Option Syntax := do +def TerminationHint.find? (t : TerminationHint) (clique : Array Name) : Option TerminationHintValue := match t with - | TerminationHint.none => Option.none - | TerminationHint.one stx => some stx - | TerminationHint.many m => clique.findSome? m.find? + | TerminationHint.none => Option.none + | TerminationHint.one v => some v + | TerminationHint.many m => clique.findSome? m.find? def TerminationHint.ensureIsEmpty (t : TerminationHint) : MacroM Unit := do match t with - | TerminationHint.one stx => Macro.throwErrorAt stx "unused termination hint element" - | TerminationHint.many m => m.forM fun _ stx => Macro.throwErrorAt stx "unused termination hint element" + | TerminationHint.one v => Macro.throwErrorAt v.ref "unused termination hint element" + | TerminationHint.many m => m.forM fun _ v => Macro.throwErrorAt v.ref "unused termination hint element" | _ => pure () structure TerminationStrategy where diff --git a/src/Lean/Parser/Command.lean b/src/Lean/Parser/Command.lean index faec95cb97..8dbcc7c93f 100644 --- a/src/Lean/Parser/Command.lean +++ b/src/Lean/Parser/Command.lean @@ -31,9 +31,9 @@ def terminationHint1 (p : Parser) := leading_parser p def terminationHint (p : Parser) := terminationHintMany p <|> terminationHint1 p def terminationBy := leading_parser "termination_by " >> terminationHint termParser -def decreasingTactic := leading_parser "decreasing_tactic " >> terminationHint Tactic.tacticSeq +def decreasingBy := leading_parser "decreasing_by " >> terminationHint Tactic.tacticSeq -def terminationSuffix := optional terminationBy >> optional decreasingTactic +def terminationSuffix := optional terminationBy >> optional decreasingBy @[builtinCommandParser] def moduleDoc := leading_parser ppDedent $ "/-!" >> commentBody >> ppLine diff --git a/tests/lean/decreasing_tactic.lean b/tests/lean/decreasing_by.lean similarity index 87% rename from tests/lean/decreasing_tactic.lean rename to tests/lean/decreasing_by.lean index 68fec2db71..8cf1dde9ad 100644 --- a/tests/lean/decreasing_tactic.lean +++ b/tests/lean/decreasing_by.lean @@ -5,19 +5,19 @@ mutual inductive Odd : Nat → Prop | step : Even n → Odd (n+1) end -decreasing_tactic assumption +decreasing_by assumption mutual def f (n : Nat) := if n == 0 then 0 else f (n / 2) + 1 - decreasing_tactic assumption + decreasing_by assumption end def g' (n : Nat) := match n with | 0 => 1 | n+1 => g' n * 3 -decreasing_tactic +decreasing_by h => assumption namespace Test @@ -34,7 +34,7 @@ mutual | 0, a, b => b | n+1, a, b => f n a b end -decreasing_tactic +decreasing_by f => assumption g => assumption diff --git a/tests/lean/decreasing_by.lean.expected.out b/tests/lean/decreasing_by.lean.expected.out new file mode 100644 index 0000000000..2f86a193e4 --- /dev/null +++ b/tests/lean/decreasing_by.lean.expected.out @@ -0,0 +1,4 @@ +decreasing_by.lean:8:0-8:24: error: invalid 'decreasing_by' in mutually inductive datatype declaration +decreasing_by.lean:13:1-13:25: error: invalid 'decreasing_by' in 'mutual' block, it must be used after the 'end' keyword +decreasing_by.lean:21:2-21:3: error: function 'h' not found in current declaration +decreasing_by.lean:39:2-39:17: error: invalid termination hint element, 'Test.g' and 'Test.f' are in the same clique diff --git a/tests/lean/decreasing_tactic.lean.expected.out b/tests/lean/decreasing_tactic.lean.expected.out deleted file mode 100644 index 30f022bf81..0000000000 --- a/tests/lean/decreasing_tactic.lean.expected.out +++ /dev/null @@ -1,4 +0,0 @@ -decreasing_tactic.lean:8:0-8:28: error: invalid 'decreasing_tactic' in mutually inductive datatype declaration -decreasing_tactic.lean:13:1-13:29: error: invalid 'decreasing_tactic' in 'mutual' block, it must be used after the 'end' keyword -decreasing_tactic.lean:21:2-21:3: error: function 'h' not found in current declaration -decreasing_tactic.lean:39:7-39:17: error: invalid termination hint element, 'Test.g' and 'Test.f' are in the same clique diff --git a/tests/lean/run/mutwf2.lean b/tests/lean/run/mutwf2.lean new file mode 100644 index 0000000000..1ceb7ce867 --- /dev/null +++ b/tests/lean/run/mutwf2.lean @@ -0,0 +1,18 @@ +mutual + def isEven : Nat → Bool + | 0 => true + | n+1 => isOdd n + def isOdd : Nat → Bool + | 0 => false + | n+1 => isEven n +end +termination_by measure fun + | Sum.inl n => n + | Sum.inr n => n +decreasing_by + simp [measure, invImage, InvImage, Nat.lt_wfRel] + apply Nat.lt_succ_self + +#print isEven +#print isOdd +#print isEven._mutual diff --git a/tests/lean/termination_by.lean.expected.out b/tests/lean/termination_by.lean.expected.out index 72c2c22823..29f53aacbc 100644 --- a/tests/lean/termination_by.lean.expected.out +++ b/tests/lean/termination_by.lean.expected.out @@ -1,4 +1,4 @@ termination_by.lean:8:0-8:22: error: invalid 'termination_by' in mutually inductive datatype declaration termination_by.lean:13:1-13:23: error: invalid 'termination_by' in 'mutual' block, it must be used after the 'end' keyword termination_by.lean:21:2-21:3: error: function 'h' not found in current declaration -termination_by.lean:39:7-39:14: error: invalid termination hint element, 'Test.g' and 'Test.f' are in the same clique +termination_by.lean:39:2-39:14: error: invalid termination hint element, 'Test.g' and 'Test.f' are in the same clique