feat: apply termination tactic provided by user
This commit is contained in:
parent
23740778d4
commit
85c49cfeb3
10 changed files with 69 additions and 40 deletions
|
|
@ -142,7 +142,7 @@ def getTerminationHints (stx : Syntax) : TerminationHints :=
|
|||
let k := decl.getKind
|
||||
if k == ``Parser.Command.def || k == ``Parser.Command.theorem || k == ``Parser.Command.instance then
|
||||
let args := decl.getArgs
|
||||
{ terminationBy? := args[args.size - 2].getOptional?, decreasingTactic? := args[args.size - 1].getOptional? }
|
||||
{ terminationBy? := args[args.size - 2].getOptional?, decreasingBy? := args[args.size - 1].getOptional? }
|
||||
else
|
||||
{}
|
||||
|
||||
|
|
@ -256,20 +256,20 @@ def expandMutualPreamble : Macro := fun stx =>
|
|||
|
||||
@[builtinCommandElab «mutual»]
|
||||
def elabMutual : CommandElab := fun stx => do
|
||||
let hints := { terminationBy? := stx[3].getOptional?, decreasingTactic? := stx[4].getOptional? }
|
||||
let hints := { terminationBy? := stx[3].getOptional?, decreasingBy? := stx[4].getOptional? }
|
||||
if isMutualInductive stx then
|
||||
if let some bad := hints.terminationBy? then
|
||||
throwErrorAt bad "invalid 'termination_by' in mutually inductive datatype declaration"
|
||||
if let some bad := hints.decreasingTactic? then
|
||||
throwErrorAt bad "invalid 'decreasing_tactic' in mutually inductive datatype declaration"
|
||||
if let some bad := hints.decreasingBy? then
|
||||
throwErrorAt bad "invalid 'decreasing_by' in mutually inductive datatype declaration"
|
||||
elabMutualInductive stx[1].getArgs
|
||||
else if isMutualDef stx then
|
||||
for arg in stx[1].getArgs do
|
||||
let argHints := getTerminationHints arg
|
||||
if let some bad := argHints.terminationBy? then
|
||||
throwErrorAt bad "invalid 'termination_by' in 'mutual' block, it must be used after the 'end' keyword"
|
||||
if let some bad := argHints.decreasingTactic? then
|
||||
throwErrorAt bad "invalid 'decreasing_tactic' in 'mutual' block, it must be used after the 'end' keyword"
|
||||
if let some bad := argHints.decreasingBy? then
|
||||
throwErrorAt bad "invalid 'decreasing_by' in 'mutual' block, it must be used after the 'end' keyword"
|
||||
elabMutualDef stx[1].getArgs hints
|
||||
else
|
||||
throwError "invalid mutual block"
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ open Term
|
|||
|
||||
structure TerminationHints where
|
||||
terminationBy? : Option Syntax := none
|
||||
decreasingTactic? : Option Syntax := none
|
||||
decreasingBy? : Option Syntax := none
|
||||
deriving Inhabited
|
||||
|
||||
private def addAndCompilePartial (preDefs : Array PreDefinition) : TermElabM Unit := do
|
||||
|
|
@ -68,7 +68,7 @@ def addPreDefinitions (preDefs : Array PreDefinition) (hints : TerminationHints)
|
|||
let preDefs ← preDefs.mapM ensureNoUnassignedMVarsAtPreDef
|
||||
let cliques ← partitionPreDefs preDefs
|
||||
let mut terminationBy ← liftMacroM <| WF.expandTerminationHint hints.terminationBy? (cliques.map fun ds => ds.map (·.declName))
|
||||
let mut decreasingTactic ← liftMacroM <| WF.expandTerminationHint hints.decreasingTactic? (cliques.map fun ds => ds.map (·.declName))
|
||||
let mut decreasingBy ← liftMacroM <| WF.expandTerminationHint hints.decreasingBy? (cliques.map fun ds => ds.map (·.declName))
|
||||
for preDefs in cliques do
|
||||
trace[Elab.definition.scc] "{preDefs.map (·.declName)}"
|
||||
if preDefs.size == 1 && isNonRecursive preDefs[0] then
|
||||
|
|
@ -87,12 +87,12 @@ def addPreDefinitions (preDefs : Array PreDefinition) (hints : TerminationHints)
|
|||
else
|
||||
let mut wfStx? := none
|
||||
let mut decrTactic? := none
|
||||
if let some wfStx := terminationBy.find? (preDefs.map (·.declName)) then
|
||||
if let some { value := wfStx, .. } := terminationBy.find? (preDefs.map (·.declName)) then
|
||||
wfStx? := some wfStx
|
||||
terminationBy := terminationBy.erase (preDefs.map (·.declName))
|
||||
if let some decrTactic := decreasingTactic.find? (preDefs.map (·.declName)) then
|
||||
decrTactic? := some decrTactic
|
||||
decreasingTactic := decreasingTactic.erase (preDefs.map (·.declName))
|
||||
if let some { ref, value := decrTactic } := decreasingBy.find? (preDefs.map (·.declName)) then
|
||||
decrTactic? := some (← withRef ref `(by $decrTactic))
|
||||
decreasingBy := decreasingBy.erase (preDefs.map (·.declName))
|
||||
if wfStx?.isSome || decrTactic?.isSome then
|
||||
wfRecursion preDefs wfStx? decrTactic?
|
||||
else
|
||||
|
|
@ -104,7 +104,7 @@ def addPreDefinitions (preDefs : Array PreDefinition) (hints : TerminationHints)
|
|||
let preDefMsgs := preDefs.toList.map (MessageData.ofExpr $ mkConst ·.declName)
|
||||
m!"fail to show termination for{indentD (MessageData.joinSep preDefMsgs Format.line)}\nwith errors\n{msg}")
|
||||
liftMacroM <| terminationBy.ensureIsEmpty
|
||||
liftMacroM <| decreasingTactic.ensureIsEmpty
|
||||
liftMacroM <| decreasingBy.ensureIsEmpty
|
||||
|
||||
builtin_initialize
|
||||
registerTraceClass `Elab.definition.body
|
||||
|
|
|
|||
|
|
@ -14,15 +14,20 @@ open Meta
|
|||
private def toUnfold : Std.PHashSet Name :=
|
||||
[``measure, ``id, ``Prod.lex, ``invImage, ``InvImage, ``Nat.lt_wfRel].foldl (init := {}) fun s a => s.insert a
|
||||
|
||||
private def mkDecreasingProof (decreasingProp : Expr) : TermElabM Expr := do
|
||||
let mvar ← mkFreshExprSyntheticOpaqueMVar decreasingProp
|
||||
let mvarId := mvar.mvarId!
|
||||
private def applyDefaultDecrTactic (mvarId : MVarId) : TermElabM Unit := do
|
||||
let ctx ← Simp.Context.mkDefault
|
||||
let ctx := { ctx with simpLemmas.toUnfold := toUnfold }
|
||||
if let some mvarId ← simpTarget mvarId ctx then
|
||||
-- TODO: invoke tactic to close the goal
|
||||
trace[Elab.definition.wf] "{MessageData.ofGoal mvarId}"
|
||||
admit mvarId
|
||||
|
||||
private def mkDecreasingProof (decreasingProp : Expr) (decrTactic? : Option Syntax) : TermElabM Expr := do
|
||||
let mvar ← mkFreshExprSyntheticOpaqueMVar decreasingProp
|
||||
let mvarId := mvar.mvarId!
|
||||
match decrTactic? with
|
||||
| none => applyDefaultDecrTactic mvarId
|
||||
| some decrTactic => Term.runTactic mvarId decrTactic
|
||||
instantiateMVars mvar
|
||||
|
||||
private partial def replaceRecApps (recFnName : Name) (decrTactic? : Option Syntax) (F : Expr) (e : Expr) : TermElabM Expr :=
|
||||
|
|
@ -45,7 +50,7 @@ private partial def replaceRecApps (recFnName : Name) (decrTactic? : Option Synt
|
|||
if f.isConstOf recFnName && args.size == 1 then
|
||||
let r := mkApp F args[0]
|
||||
let decreasingProp := (← whnf (← inferType r)).bindingDomain!
|
||||
return mkApp r (← mkDecreasingProof decreasingProp)
|
||||
return mkApp r (← mkDecreasingProof decreasingProp decrTactic?)
|
||||
else
|
||||
return mkAppN (← loop F f) (← args.mapM (loop F))
|
||||
let matcherApp? ← matchMatcherApp? e
|
||||
|
|
|
|||
|
|
@ -7,30 +7,36 @@ import Lean.Parser.Command
|
|||
|
||||
namespace Lean.Elab.WF
|
||||
|
||||
structure TerminationHintValue where
|
||||
ref : Syntax
|
||||
value : Syntax
|
||||
deriving Inhabited
|
||||
|
||||
inductive TerminationHint where
|
||||
| none
|
||||
| one (stx : Syntax)
|
||||
| many (map : NameMap Syntax)
|
||||
| one (val : TerminationHintValue)
|
||||
| many (map : NameMap TerminationHintValue)
|
||||
deriving Inhabited
|
||||
|
||||
def expandTerminationHint (terminationHint? : Option Syntax) (cliques : Array (Array Name)) : MacroM TerminationHint := do
|
||||
if let some terminationHint := terminationHint? then
|
||||
let ref := terminationHint
|
||||
let terminationHint := terminationHint[1]
|
||||
if terminationHint.getKind == ``Parser.Command.terminationHint1 then
|
||||
return TerminationHint.one terminationHint[0]
|
||||
return TerminationHint.one { ref, value := terminationHint[0] }
|
||||
else if terminationHint.getKind == ``Parser.Command.terminationHintMany then
|
||||
let m ← terminationHint[0].getArgs.foldlM (init := {}) fun m arg =>
|
||||
let declName? := cliques.findSome? fun clique => clique.findSome? fun declName =>
|
||||
if arg[0].getId.isSuffixOf declName then some declName else none
|
||||
match declName? with
|
||||
| none => Macro.throwErrorAt arg[0] s!"function '{arg[0].getId}' not found in current declaration"
|
||||
| some declName => return m.insert declName arg[2]
|
||||
| some declName => return m.insert declName { ref := arg, value := arg[2] }
|
||||
for clique in cliques do
|
||||
let mut found? := Option.none
|
||||
for declName in clique do
|
||||
if let some stx := m.find? declName then
|
||||
if let some { ref, .. } := m.find? declName then
|
||||
if let some found := found? then
|
||||
Macro.throwErrorAt stx s!"invalid termination hint element, '{declName}' and '{found}' are in the same clique"
|
||||
Macro.throwErrorAt ref s!"invalid termination hint element, '{declName}' and '{found}' are in the same clique"
|
||||
found? := some declName
|
||||
return TerminationHint.many m
|
||||
else
|
||||
|
|
@ -53,16 +59,16 @@ def TerminationHint.erase (t : TerminationHint) (clique : Array Name) : Terminat
|
|||
return TerminationHint.many m
|
||||
return t
|
||||
|
||||
def TerminationHint.find? (t : TerminationHint) (clique : Array Name) : Option Syntax := do
|
||||
def TerminationHint.find? (t : TerminationHint) (clique : Array Name) : Option TerminationHintValue :=
|
||||
match t with
|
||||
| TerminationHint.none => Option.none
|
||||
| TerminationHint.one stx => some stx
|
||||
| TerminationHint.many m => clique.findSome? m.find?
|
||||
| TerminationHint.none => Option.none
|
||||
| TerminationHint.one v => some v
|
||||
| TerminationHint.many m => clique.findSome? m.find?
|
||||
|
||||
def TerminationHint.ensureIsEmpty (t : TerminationHint) : MacroM Unit := do
|
||||
match t with
|
||||
| TerminationHint.one stx => Macro.throwErrorAt stx "unused termination hint element"
|
||||
| TerminationHint.many m => m.forM fun _ stx => Macro.throwErrorAt stx "unused termination hint element"
|
||||
| TerminationHint.one v => Macro.throwErrorAt v.ref "unused termination hint element"
|
||||
| TerminationHint.many m => m.forM fun _ v => Macro.throwErrorAt v.ref "unused termination hint element"
|
||||
| _ => pure ()
|
||||
|
||||
structure TerminationStrategy where
|
||||
|
|
|
|||
|
|
@ -31,9 +31,9 @@ def terminationHint1 (p : Parser) := leading_parser p
|
|||
def terminationHint (p : Parser) := terminationHintMany p <|> terminationHint1 p
|
||||
|
||||
def terminationBy := leading_parser "termination_by " >> terminationHint termParser
|
||||
def decreasingTactic := leading_parser "decreasing_tactic " >> terminationHint Tactic.tacticSeq
|
||||
def decreasingBy := leading_parser "decreasing_by " >> terminationHint Tactic.tacticSeq
|
||||
|
||||
def terminationSuffix := optional terminationBy >> optional decreasingTactic
|
||||
def terminationSuffix := optional terminationBy >> optional decreasingBy
|
||||
|
||||
@[builtinCommandParser]
|
||||
def moduleDoc := leading_parser ppDedent $ "/-!" >> commentBody >> ppLine
|
||||
|
|
|
|||
|
|
@ -5,19 +5,19 @@ mutual
|
|||
inductive Odd : Nat → Prop
|
||||
| step : Even n → Odd (n+1)
|
||||
end
|
||||
decreasing_tactic assumption
|
||||
decreasing_by assumption
|
||||
|
||||
mutual
|
||||
def f (n : Nat) :=
|
||||
if n == 0 then 0 else f (n / 2) + 1
|
||||
decreasing_tactic assumption
|
||||
decreasing_by assumption
|
||||
end
|
||||
|
||||
def g' (n : Nat) :=
|
||||
match n with
|
||||
| 0 => 1
|
||||
| n+1 => g' n * 3
|
||||
decreasing_tactic
|
||||
decreasing_by
|
||||
h => assumption
|
||||
|
||||
namespace Test
|
||||
|
|
@ -34,7 +34,7 @@ mutual
|
|||
| 0, a, b => b
|
||||
| n+1, a, b => f n a b
|
||||
end
|
||||
decreasing_tactic
|
||||
decreasing_by
|
||||
f => assumption
|
||||
g => assumption
|
||||
|
||||
4
tests/lean/decreasing_by.lean.expected.out
Normal file
4
tests/lean/decreasing_by.lean.expected.out
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
decreasing_by.lean:8:0-8:24: error: invalid 'decreasing_by' in mutually inductive datatype declaration
|
||||
decreasing_by.lean:13:1-13:25: error: invalid 'decreasing_by' in 'mutual' block, it must be used after the 'end' keyword
|
||||
decreasing_by.lean:21:2-21:3: error: function 'h' not found in current declaration
|
||||
decreasing_by.lean:39:2-39:17: error: invalid termination hint element, 'Test.g' and 'Test.f' are in the same clique
|
||||
|
|
@ -1,4 +0,0 @@
|
|||
decreasing_tactic.lean:8:0-8:28: error: invalid 'decreasing_tactic' in mutually inductive datatype declaration
|
||||
decreasing_tactic.lean:13:1-13:29: error: invalid 'decreasing_tactic' in 'mutual' block, it must be used after the 'end' keyword
|
||||
decreasing_tactic.lean:21:2-21:3: error: function 'h' not found in current declaration
|
||||
decreasing_tactic.lean:39:7-39:17: error: invalid termination hint element, 'Test.g' and 'Test.f' are in the same clique
|
||||
18
tests/lean/run/mutwf2.lean
Normal file
18
tests/lean/run/mutwf2.lean
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
mutual
|
||||
def isEven : Nat → Bool
|
||||
| 0 => true
|
||||
| n+1 => isOdd n
|
||||
def isOdd : Nat → Bool
|
||||
| 0 => false
|
||||
| n+1 => isEven n
|
||||
end
|
||||
termination_by measure fun
|
||||
| Sum.inl n => n
|
||||
| Sum.inr n => n
|
||||
decreasing_by
|
||||
simp [measure, invImage, InvImage, Nat.lt_wfRel]
|
||||
apply Nat.lt_succ_self
|
||||
|
||||
#print isEven
|
||||
#print isOdd
|
||||
#print isEven._mutual
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
termination_by.lean:8:0-8:22: error: invalid 'termination_by' in mutually inductive datatype declaration
|
||||
termination_by.lean:13:1-13:23: error: invalid 'termination_by' in 'mutual' block, it must be used after the 'end' keyword
|
||||
termination_by.lean:21:2-21:3: error: function 'h' not found in current declaration
|
||||
termination_by.lean:39:7-39:14: error: invalid termination hint element, 'Test.g' and 'Test.f' are in the same clique
|
||||
termination_by.lean:39:2-39:14: error: invalid termination hint element, 'Test.g' and 'Test.f' are in the same clique
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue