feat: apply termination tactic provided by user

This commit is contained in:
Leonardo de Moura 2021-10-03 18:44:13 -07:00
parent 23740778d4
commit 85c49cfeb3
10 changed files with 69 additions and 40 deletions

View file

@ -142,7 +142,7 @@ def getTerminationHints (stx : Syntax) : TerminationHints :=
let k := decl.getKind
if k == ``Parser.Command.def || k == ``Parser.Command.theorem || k == ``Parser.Command.instance then
let args := decl.getArgs
{ terminationBy? := args[args.size - 2].getOptional?, decreasingTactic? := args[args.size - 1].getOptional? }
{ terminationBy? := args[args.size - 2].getOptional?, decreasingBy? := args[args.size - 1].getOptional? }
else
{}
@ -256,20 +256,20 @@ def expandMutualPreamble : Macro := fun stx =>
@[builtinCommandElab «mutual»]
def elabMutual : CommandElab := fun stx => do
let hints := { terminationBy? := stx[3].getOptional?, decreasingTactic? := stx[4].getOptional? }
let hints := { terminationBy? := stx[3].getOptional?, decreasingBy? := stx[4].getOptional? }
if isMutualInductive stx then
if let some bad := hints.terminationBy? then
throwErrorAt bad "invalid 'termination_by' in mutually inductive datatype declaration"
if let some bad := hints.decreasingTactic? then
throwErrorAt bad "invalid 'decreasing_tactic' in mutually inductive datatype declaration"
if let some bad := hints.decreasingBy? then
throwErrorAt bad "invalid 'decreasing_by' in mutually inductive datatype declaration"
elabMutualInductive stx[1].getArgs
else if isMutualDef stx then
for arg in stx[1].getArgs do
let argHints := getTerminationHints arg
if let some bad := argHints.terminationBy? then
throwErrorAt bad "invalid 'termination_by' in 'mutual' block, it must be used after the 'end' keyword"
if let some bad := argHints.decreasingTactic? then
throwErrorAt bad "invalid 'decreasing_tactic' in 'mutual' block, it must be used after the 'end' keyword"
if let some bad := argHints.decreasingBy? then
throwErrorAt bad "invalid 'decreasing_by' in 'mutual' block, it must be used after the 'end' keyword"
elabMutualDef stx[1].getArgs hints
else
throwError "invalid mutual block"

View file

@ -12,7 +12,7 @@ open Term
structure TerminationHints where
terminationBy? : Option Syntax := none
decreasingTactic? : Option Syntax := none
decreasingBy? : Option Syntax := none
deriving Inhabited
private def addAndCompilePartial (preDefs : Array PreDefinition) : TermElabM Unit := do
@ -68,7 +68,7 @@ def addPreDefinitions (preDefs : Array PreDefinition) (hints : TerminationHints)
let preDefs ← preDefs.mapM ensureNoUnassignedMVarsAtPreDef
let cliques ← partitionPreDefs preDefs
let mut terminationBy ← liftMacroM <| WF.expandTerminationHint hints.terminationBy? (cliques.map fun ds => ds.map (·.declName))
let mut decreasingTactic ← liftMacroM <| WF.expandTerminationHint hints.decreasingTactic? (cliques.map fun ds => ds.map (·.declName))
let mut decreasingBy ← liftMacroM <| WF.expandTerminationHint hints.decreasingBy? (cliques.map fun ds => ds.map (·.declName))
for preDefs in cliques do
trace[Elab.definition.scc] "{preDefs.map (·.declName)}"
if preDefs.size == 1 && isNonRecursive preDefs[0] then
@ -87,12 +87,12 @@ def addPreDefinitions (preDefs : Array PreDefinition) (hints : TerminationHints)
else
let mut wfStx? := none
let mut decrTactic? := none
if let some wfStx := terminationBy.find? (preDefs.map (·.declName)) then
if let some { value := wfStx, .. } := terminationBy.find? (preDefs.map (·.declName)) then
wfStx? := some wfStx
terminationBy := terminationBy.erase (preDefs.map (·.declName))
if let some decrTactic := decreasingTactic.find? (preDefs.map (·.declName)) then
decrTactic? := some decrTactic
decreasingTactic := decreasingTactic.erase (preDefs.map (·.declName))
if let some { ref, value := decrTactic } := decreasingBy.find? (preDefs.map (·.declName)) then
decrTactic? := some (← withRef ref `(by $decrTactic))
decreasingBy := decreasingBy.erase (preDefs.map (·.declName))
if wfStx?.isSome || decrTactic?.isSome then
wfRecursion preDefs wfStx? decrTactic?
else
@ -104,7 +104,7 @@ def addPreDefinitions (preDefs : Array PreDefinition) (hints : TerminationHints)
let preDefMsgs := preDefs.toList.map (MessageData.ofExpr $ mkConst ·.declName)
m!"fail to show termination for{indentD (MessageData.joinSep preDefMsgs Format.line)}\nwith errors\n{msg}")
liftMacroM <| terminationBy.ensureIsEmpty
liftMacroM <| decreasingTactic.ensureIsEmpty
liftMacroM <| decreasingBy.ensureIsEmpty
builtin_initialize
registerTraceClass `Elab.definition.body

View file

@ -14,15 +14,20 @@ open Meta
private def toUnfold : Std.PHashSet Name :=
[``measure, ``id, ``Prod.lex, ``invImage, ``InvImage, ``Nat.lt_wfRel].foldl (init := {}) fun s a => s.insert a
private def mkDecreasingProof (decreasingProp : Expr) : TermElabM Expr := do
let mvar ← mkFreshExprSyntheticOpaqueMVar decreasingProp
let mvarId := mvar.mvarId!
private def applyDefaultDecrTactic (mvarId : MVarId) : TermElabM Unit := do
let ctx ← Simp.Context.mkDefault
let ctx := { ctx with simpLemmas.toUnfold := toUnfold }
if let some mvarId ← simpTarget mvarId ctx then
-- TODO: invoke tactic to close the goal
trace[Elab.definition.wf] "{MessageData.ofGoal mvarId}"
admit mvarId
private def mkDecreasingProof (decreasingProp : Expr) (decrTactic? : Option Syntax) : TermElabM Expr := do
let mvar ← mkFreshExprSyntheticOpaqueMVar decreasingProp
let mvarId := mvar.mvarId!
match decrTactic? with
| none => applyDefaultDecrTactic mvarId
| some decrTactic => Term.runTactic mvarId decrTactic
instantiateMVars mvar
private partial def replaceRecApps (recFnName : Name) (decrTactic? : Option Syntax) (F : Expr) (e : Expr) : TermElabM Expr :=
@ -45,7 +50,7 @@ private partial def replaceRecApps (recFnName : Name) (decrTactic? : Option Synt
if f.isConstOf recFnName && args.size == 1 then
let r := mkApp F args[0]
let decreasingProp := (← whnf (← inferType r)).bindingDomain!
return mkApp r (← mkDecreasingProof decreasingProp)
return mkApp r (← mkDecreasingProof decreasingProp decrTactic?)
else
return mkAppN (← loop F f) (← args.mapM (loop F))
let matcherApp? ← matchMatcherApp? e

View file

@ -7,30 +7,36 @@ import Lean.Parser.Command
namespace Lean.Elab.WF
structure TerminationHintValue where
ref : Syntax
value : Syntax
deriving Inhabited
inductive TerminationHint where
| none
| one (stx : Syntax)
| many (map : NameMap Syntax)
| one (val : TerminationHintValue)
| many (map : NameMap TerminationHintValue)
deriving Inhabited
def expandTerminationHint (terminationHint? : Option Syntax) (cliques : Array (Array Name)) : MacroM TerminationHint := do
if let some terminationHint := terminationHint? then
let ref := terminationHint
let terminationHint := terminationHint[1]
if terminationHint.getKind == ``Parser.Command.terminationHint1 then
return TerminationHint.one terminationHint[0]
return TerminationHint.one { ref, value := terminationHint[0] }
else if terminationHint.getKind == ``Parser.Command.terminationHintMany then
let m ← terminationHint[0].getArgs.foldlM (init := {}) fun m arg =>
let declName? := cliques.findSome? fun clique => clique.findSome? fun declName =>
if arg[0].getId.isSuffixOf declName then some declName else none
match declName? with
| none => Macro.throwErrorAt arg[0] s!"function '{arg[0].getId}' not found in current declaration"
| some declName => return m.insert declName arg[2]
| some declName => return m.insert declName { ref := arg, value := arg[2] }
for clique in cliques do
let mut found? := Option.none
for declName in clique do
if let some stx := m.find? declName then
if let some { ref, .. } := m.find? declName then
if let some found := found? then
Macro.throwErrorAt stx s!"invalid termination hint element, '{declName}' and '{found}' are in the same clique"
Macro.throwErrorAt ref s!"invalid termination hint element, '{declName}' and '{found}' are in the same clique"
found? := some declName
return TerminationHint.many m
else
@ -53,16 +59,16 @@ def TerminationHint.erase (t : TerminationHint) (clique : Array Name) : Terminat
return TerminationHint.many m
return t
def TerminationHint.find? (t : TerminationHint) (clique : Array Name) : Option Syntax := do
def TerminationHint.find? (t : TerminationHint) (clique : Array Name) : Option TerminationHintValue :=
match t with
| TerminationHint.none => Option.none
| TerminationHint.one stx => some stx
| TerminationHint.many m => clique.findSome? m.find?
| TerminationHint.none => Option.none
| TerminationHint.one v => some v
| TerminationHint.many m => clique.findSome? m.find?
def TerminationHint.ensureIsEmpty (t : TerminationHint) : MacroM Unit := do
match t with
| TerminationHint.one stx => Macro.throwErrorAt stx "unused termination hint element"
| TerminationHint.many m => m.forM fun _ stx => Macro.throwErrorAt stx "unused termination hint element"
| TerminationHint.one v => Macro.throwErrorAt v.ref "unused termination hint element"
| TerminationHint.many m => m.forM fun _ v => Macro.throwErrorAt v.ref "unused termination hint element"
| _ => pure ()
structure TerminationStrategy where

View file

@ -31,9 +31,9 @@ def terminationHint1 (p : Parser) := leading_parser p
def terminationHint (p : Parser) := terminationHintMany p <|> terminationHint1 p
def terminationBy := leading_parser "termination_by " >> terminationHint termParser
def decreasingTactic := leading_parser "decreasing_tactic " >> terminationHint Tactic.tacticSeq
def decreasingBy := leading_parser "decreasing_by " >> terminationHint Tactic.tacticSeq
def terminationSuffix := optional terminationBy >> optional decreasingTactic
def terminationSuffix := optional terminationBy >> optional decreasingBy
@[builtinCommandParser]
def moduleDoc := leading_parser ppDedent $ "/-!" >> commentBody >> ppLine

View file

@ -5,19 +5,19 @@ mutual
inductive Odd : Nat → Prop
| step : Even n → Odd (n+1)
end
decreasing_tactic assumption
decreasing_by assumption
mutual
def f (n : Nat) :=
if n == 0 then 0 else f (n / 2) + 1
decreasing_tactic assumption
decreasing_by assumption
end
def g' (n : Nat) :=
match n with
| 0 => 1
| n+1 => g' n * 3
decreasing_tactic
decreasing_by
h => assumption
namespace Test
@ -34,7 +34,7 @@ mutual
| 0, a, b => b
| n+1, a, b => f n a b
end
decreasing_tactic
decreasing_by
f => assumption
g => assumption

View file

@ -0,0 +1,4 @@
decreasing_by.lean:8:0-8:24: error: invalid 'decreasing_by' in mutually inductive datatype declaration
decreasing_by.lean:13:1-13:25: error: invalid 'decreasing_by' in 'mutual' block, it must be used after the 'end' keyword
decreasing_by.lean:21:2-21:3: error: function 'h' not found in current declaration
decreasing_by.lean:39:2-39:17: error: invalid termination hint element, 'Test.g' and 'Test.f' are in the same clique

View file

@ -1,4 +0,0 @@
decreasing_tactic.lean:8:0-8:28: error: invalid 'decreasing_tactic' in mutually inductive datatype declaration
decreasing_tactic.lean:13:1-13:29: error: invalid 'decreasing_tactic' in 'mutual' block, it must be used after the 'end' keyword
decreasing_tactic.lean:21:2-21:3: error: function 'h' not found in current declaration
decreasing_tactic.lean:39:7-39:17: error: invalid termination hint element, 'Test.g' and 'Test.f' are in the same clique

View file

@ -0,0 +1,18 @@
mutual
def isEven : Nat → Bool
| 0 => true
| n+1 => isOdd n
def isOdd : Nat → Bool
| 0 => false
| n+1 => isEven n
end
termination_by measure fun
| Sum.inl n => n
| Sum.inr n => n
decreasing_by
simp [measure, invImage, InvImage, Nat.lt_wfRel]
apply Nat.lt_succ_self
#print isEven
#print isOdd
#print isEven._mutual

View file

@ -1,4 +1,4 @@
termination_by.lean:8:0-8:22: error: invalid 'termination_by' in mutually inductive datatype declaration
termination_by.lean:13:1-13:23: error: invalid 'termination_by' in 'mutual' block, it must be used after the 'end' keyword
termination_by.lean:21:2-21:3: error: function 'h' not found in current declaration
termination_by.lean:39:7-39:14: error: invalid termination hint element, 'Test.g' and 'Test.f' are in the same clique
termination_by.lean:39:2-39:14: error: invalid termination hint element, 'Test.g' and 'Test.f' are in the same clique