feat: improve smart unfolding
This commit is contained in:
parent
f1b88e1304
commit
c81dbeb53c
5 changed files with 118 additions and 115 deletions
|
|
@ -20,6 +20,12 @@ def smartUnfoldingSuffix := "_sunfold"
|
|||
@[inline] def mkSmartUnfoldingNameFor (n : Name) : Name :=
|
||||
Name.mkStr n smartUnfoldingSuffix
|
||||
|
||||
def smartUnfoldingDefault := true
|
||||
builtin_initialize
|
||||
registerOption `smartUnfolding { defValue := smartUnfoldingDefault, group := "", descr := "when computing weak head normal form, use auxiliary definition created for functions defined by structural recursion" }
|
||||
private def useSmartUnfolding (opts : Options) : Bool :=
|
||||
opts.getBool `smartUnfolding smartUnfoldingDefault
|
||||
|
||||
/- ===========================
|
||||
Helper methods
|
||||
=========================== -/
|
||||
|
|
@ -113,20 +119,6 @@ private def reduceRec {α} (recVal : RecursorVal) (recLvls : List Level) (recArg
|
|||
else
|
||||
failK ()
|
||||
|
||||
@[specialize] private def isRecStuck? (isStuck? : Expr → MetaM (Option MVarId)) (recVal : RecursorVal) (recLvls : List Level) (recArgs : Array Expr)
|
||||
: MetaM (Option MVarId) :=
|
||||
if recVal.k then
|
||||
-- TODO: improve this case
|
||||
pure none
|
||||
else do
|
||||
let majorIdx := recVal.getMajorIdx
|
||||
if h : majorIdx < recArgs.size then do
|
||||
let major := recArgs.get ⟨majorIdx, h⟩
|
||||
let major ← whnf major
|
||||
isStuck? major
|
||||
else
|
||||
pure none
|
||||
|
||||
/- ===========================
|
||||
Helper functions for reducing Quot.lift and Quot.ind
|
||||
=========================== -/
|
||||
|
|
@ -152,41 +144,55 @@ private def reduceQuotRec {α} (recVal : QuotVal) (recLvls : List Level) (recAr
|
|||
| QuotKind.ind => process 4 3
|
||||
| _ => failK ()
|
||||
|
||||
@[specialize] private def isQuotRecStuck? (isStuck? : Expr → MetaM (Option MVarId)) (recVal : QuotVal) (recLvls : List Level) (recArgs : Array Expr)
|
||||
: MetaM (Option MVarId) :=
|
||||
let process? (majorPos : Nat) : MetaM (Option MVarId) :=
|
||||
if h : majorPos < recArgs.size then do
|
||||
let major := recArgs.get ⟨majorPos, h⟩
|
||||
let major ← whnf major
|
||||
isStuck? major
|
||||
else
|
||||
pure none
|
||||
match recVal.kind with
|
||||
| QuotKind.lift => process? 5
|
||||
| QuotKind.ind => process? 4
|
||||
| _ => pure none
|
||||
|
||||
/- ===========================
|
||||
Helper function for extracting "stuck term"
|
||||
=========================== -/
|
||||
|
||||
/-- Return `some (Expr.mvar mvarId)` if metavariable `mvarId` is blocking reduction. -/
|
||||
private partial def getStuckMVarImp? : Expr → MetaM (Option MVarId)
|
||||
| Expr.mdata _ e _ => getStuckMVarImp? e
|
||||
| Expr.proj _ _ e _ => do getStuckMVarImp? (← whnf e)
|
||||
| e@(Expr.mvar mvarId _) => pure (some mvarId)
|
||||
| e@(Expr.app f _ _) =>
|
||||
let f := f.getAppFn
|
||||
match f with
|
||||
| Expr.mvar mvarId _ => pure (some mvarId)
|
||||
| Expr.const fName fLvls _ => do
|
||||
let cinfo? ← getConstNoEx? fName
|
||||
match cinfo? with
|
||||
| some $ ConstantInfo.recInfo recVal => isRecStuck? getStuckMVarImp? recVal fLvls e.getAppArgs
|
||||
| some $ ConstantInfo.quotInfo recVal => isQuotRecStuck? getStuckMVarImp? recVal fLvls e.getAppArgs
|
||||
| _ => pure none
|
||||
mutual
|
||||
private partial def isRecStuck? (recVal : RecursorVal) (recLvls : List Level) (recArgs : Array Expr) : MetaM (Option MVarId) :=
|
||||
if recVal.k then
|
||||
-- TODO: improve this case
|
||||
pure none
|
||||
else do
|
||||
let majorIdx := recVal.getMajorIdx
|
||||
if h : majorIdx < recArgs.size then do
|
||||
let major := recArgs.get ⟨majorIdx, h⟩
|
||||
let major ← whnf major
|
||||
getStuckMVarImp? major
|
||||
else
|
||||
pure none
|
||||
|
||||
private partial def isQuotRecStuck? (recVal : QuotVal) (recLvls : List Level) (recArgs : Array Expr) : MetaM (Option MVarId) :=
|
||||
let process? (majorPos : Nat) : MetaM (Option MVarId) :=
|
||||
if h : majorPos < recArgs.size then do
|
||||
let major := recArgs.get ⟨majorPos, h⟩
|
||||
let major ← whnf major
|
||||
getStuckMVarImp? major
|
||||
else
|
||||
pure none
|
||||
match recVal.kind with
|
||||
| QuotKind.lift => process? 5
|
||||
| QuotKind.ind => process? 4
|
||||
| _ => pure none
|
||||
|
||||
/-- Return `some (Expr.mvar mvarId)` if metavariable `mvarId` is blocking reduction. -/
|
||||
private partial def getStuckMVarImp? : Expr → MetaM (Option MVarId)
|
||||
| Expr.mdata _ e _ => getStuckMVarImp? e
|
||||
| Expr.proj _ _ e _ => do getStuckMVarImp? (← whnf e)
|
||||
| e@(Expr.mvar mvarId _) => pure (some mvarId)
|
||||
| e@(Expr.app f _ _) =>
|
||||
let f := f.getAppFn
|
||||
match f with
|
||||
| Expr.mvar mvarId _ => pure (some mvarId)
|
||||
| Expr.const fName fLvls _ => do
|
||||
let cinfo? ← getConstNoEx? fName
|
||||
match cinfo? with
|
||||
| some $ ConstantInfo.recInfo recVal => isRecStuck? recVal fLvls e.getAppArgs
|
||||
| some $ ConstantInfo.quotInfo recVal => isQuotRecStuck? recVal fLvls e.getAppArgs
|
||||
| _ => pure none
|
||||
| _ => pure none
|
||||
| _ => pure none
|
||||
| _ => pure none
|
||||
end
|
||||
|
||||
@[inline] def getStuckMVar? (e : Expr) : m (Option MVarId) :=
|
||||
liftM $ getStuckMVarImp? e
|
||||
|
|
@ -333,51 +339,75 @@ private partial def whnfCoreImp (e : Expr) : MetaM Expr :=
|
|||
@[inline] def whnfCore (e : Expr) : m Expr :=
|
||||
liftMetaM $ whnfCoreImp e
|
||||
|
||||
/--
|
||||
Similar to `whnfCore`, but uses `synthesizePending` to (try to) synthesize metavariables
|
||||
that are blocking reduction. -/
|
||||
private partial def whnfCoreUnstuck (e : Expr) : MetaM Expr := do
|
||||
let e ← whnfCore e
|
||||
let (some mvarId) ← getStuckMVar? e | pure e
|
||||
let succeeded ← Meta.synthPending mvarId
|
||||
if succeeded then whnfCoreUnstuck e else pure e
|
||||
|
||||
def smartUnfoldingDefault := true
|
||||
builtin_initialize
|
||||
registerOption `smartUnfolding { defValue := smartUnfoldingDefault, group := "", descr := "when computing weak head normal form, use auxiliary definition created for functions defined by structural recursion" }
|
||||
private def useSmartUnfolding (opts : Options) : Bool :=
|
||||
opts.getBool `smartUnfolding smartUnfoldingDefault
|
||||
|
||||
/-- Unfold definition using "smart unfolding" if possible. -/
|
||||
private def unfoldDefinitionImp? (e : Expr) : MetaM (Option Expr) :=
|
||||
match e with
|
||||
| Expr.app f _ _ =>
|
||||
matchConstAux f.getAppFn (fun _ => pure none) fun fInfo fLvls => do
|
||||
if fInfo.lparams.length != fLvls.length then
|
||||
pure none
|
||||
mutual
|
||||
/-- Reduce `e` until `idRhs` application is exposed or it gets stuck.
|
||||
This is a helper method for implementing smart unfolding. -/
|
||||
private partial def whnfUntilIdRhs (e : Expr) : MetaM Expr := do
|
||||
let e ← whnfCoreImp e
|
||||
match (← getStuckMVar? e) with
|
||||
| some mvarId =>
|
||||
/- Try to "unstuck" by resolving pending TC problems -/
|
||||
if (← Meta.synthPending mvarId) then
|
||||
whnfUntilIdRhs e
|
||||
else
|
||||
let unfoldDefault (_ : Unit) : MetaM (Option Expr) :=
|
||||
if fInfo.hasValue then
|
||||
deltaBetaDefinition fInfo fLvls e.getAppRevArgs (fun _ => pure none) (fun e => pure (some e))
|
||||
else
|
||||
pure none
|
||||
if useSmartUnfolding (← getOptions) then
|
||||
let fAuxInfo? ← getConstNoEx? (mkSmartUnfoldingNameFor fInfo.name)
|
||||
match fAuxInfo? with
|
||||
| some $ fAuxInfo@(ConstantInfo.defnInfo _) =>
|
||||
deltaBetaDefinition fAuxInfo fLvls e.getAppRevArgs (fun _ => pure none) $ fun e₁ => do
|
||||
let e₂ ← whnfCoreUnstuck e₁
|
||||
if isIdRhsApp e₂ then
|
||||
pure (some (extractIdRhs e₂))
|
||||
else
|
||||
pure none
|
||||
| _ => unfoldDefault ()
|
||||
pure e -- failed because metavariable is blocking reduction
|
||||
| _ =>
|
||||
if isIdRhsApp e then
|
||||
pure e -- done
|
||||
else
|
||||
match (← unfoldDefinitionImp? e) with
|
||||
| some e => whnfUntilIdRhs e
|
||||
| none => pure e -- failed because of symbolic argument
|
||||
|
||||
/-- Unfold definition using "smart unfolding" if possible. -/
|
||||
private partial def unfoldDefinitionImp? (e : Expr) : MetaM (Option Expr) :=
|
||||
match e with
|
||||
| Expr.app f _ _ =>
|
||||
matchConstAux f.getAppFn (fun _ => pure none) fun fInfo fLvls => do
|
||||
if fInfo.lparams.length != fLvls.length then
|
||||
pure none
|
||||
else
|
||||
unfoldDefault ()
|
||||
| Expr.const name lvls _ => do
|
||||
let (some (cinfo@(ConstantInfo.defnInfo _))) ← getConstNoEx? name | pure none
|
||||
deltaDefinition cinfo lvls (fun _ => pure none) (fun e => pure (some e))
|
||||
| _ => pure none
|
||||
let unfoldDefault (_ : Unit) : MetaM (Option Expr) :=
|
||||
if fInfo.hasValue then
|
||||
deltaBetaDefinition fInfo fLvls e.getAppRevArgs (fun _ => pure none) (fun e => pure (some e))
|
||||
else
|
||||
pure none
|
||||
if useSmartUnfolding (← getOptions) then
|
||||
let fAuxInfo? ← getConstNoEx? (mkSmartUnfoldingNameFor fInfo.name)
|
||||
match fAuxInfo? with
|
||||
| some $ fAuxInfo@(ConstantInfo.defnInfo _) =>
|
||||
deltaBetaDefinition fAuxInfo fLvls e.getAppRevArgs (fun _ => pure none) $ fun e₁ => do
|
||||
let e₂ ← whnfUntilIdRhs e₁
|
||||
if isIdRhsApp e₂ then
|
||||
pure (some (extractIdRhs e₂))
|
||||
else
|
||||
pure none
|
||||
| _ => unfoldDefault ()
|
||||
else
|
||||
unfoldDefault ()
|
||||
| Expr.const name lvls _ => do
|
||||
let (some (cinfo@(ConstantInfo.defnInfo _))) ← getConstNoEx? name | pure none
|
||||
deltaDefinition cinfo lvls (fun _ => pure none) (fun e => pure (some e))
|
||||
| _ => pure none
|
||||
end
|
||||
|
||||
@[specialize] partial def whnfHeadPredImp (e : Expr) (pred : Expr → MetaM Bool) : MetaM Expr :=
|
||||
whnfEasyCases e fun e => do
|
||||
let e ← whnfCoreImp e
|
||||
if (← pred e) then
|
||||
match (← unfoldDefinitionImp? e) with
|
||||
| some e => whnfHeadPredImp e pred
|
||||
| none => pure e
|
||||
else
|
||||
pure e
|
||||
|
||||
@[inline] partial def whnfHeadPred (e : Expr) (pred : Expr → MetaM Bool) : m Expr :=
|
||||
liftMetaM $ whnfHeadPredImp e pred
|
||||
|
||||
def whnfUntil (e : Expr) (declName : Name) : m (Option Expr) := liftMetaM do
|
||||
let e ← whnfHeadPredImp e (fun e => pure $ !e.isAppOf declName)
|
||||
if e.isAppOf declName then pure e
|
||||
else pure none
|
||||
|
||||
@[inline] def unfoldDefinition? (e : Expr) : m (Option Expr) :=
|
||||
liftMetaM $ unfoldDefinitionImp? e
|
||||
|
|
@ -476,7 +506,7 @@ partial def whnfImp (e : Expr) : MetaM Expr :=
|
|||
match (← cached? useCache e) with
|
||||
| some e' => pure e'
|
||||
| none =>
|
||||
let e' ← whnfCore e
|
||||
let e' ← whnfCoreImp e
|
||||
match (← reduceNat? e') with
|
||||
| some v => cache useCache e v
|
||||
| none =>
|
||||
|
|
@ -501,24 +531,6 @@ def reduceProj? (e : Expr) (i : Nat) : MetaM (Option Expr) := do
|
|||
else
|
||||
pure none
|
||||
|
||||
@[specialize] partial def whnfHeadPredImp (e : Expr) (pred : Expr → MetaM Bool) : MetaM Expr :=
|
||||
whnfEasyCases e fun e => do
|
||||
let e ← whnfCore e
|
||||
if (← pred e) then
|
||||
match (← unfoldDefinition? e) with
|
||||
| some e => whnfHeadPredImp e pred
|
||||
| none => pure e
|
||||
else
|
||||
pure e
|
||||
|
||||
@[inline] def whnfHeadPred (e : Expr) (pred : Expr → MetaM Bool) : m Expr :=
|
||||
liftMetaM $ whnfHeadPredImp e pred
|
||||
|
||||
def whnfUntil (e : Expr) (declName : Name) : m (Option Expr) := liftMetaM do
|
||||
let e ← whnfHeadPredImp e (fun e => pure $ !e.isAppOf declName)
|
||||
if e.isAppOf declName then pure e
|
||||
else pure none
|
||||
|
||||
builtin_initialize
|
||||
registerTraceClass `Meta.whnf
|
||||
|
||||
|
|
|
|||
|
|
@ -60,7 +60,6 @@ return sum
|
|||
|
||||
#eval sumOdd [1, 2, 3, 4, 5, 6, 7, 9, 11, 101] 10
|
||||
|
||||
set_option smartUnfolding false in
|
||||
theorem ex5 : sumOdd [1, 2, 3, 4, 5, 6, 7, 9, 11, 101] 10 = 16 :=
|
||||
rfl
|
||||
|
||||
|
|
@ -79,7 +78,6 @@ for (x, y) in ps do
|
|||
sum := sum + x - y
|
||||
return sum
|
||||
|
||||
set_option smartUnfolding false in
|
||||
theorem ex7 : sumDiff [(2, 1), (10, 5)] = 6 :=
|
||||
rfl
|
||||
|
||||
|
|
@ -114,7 +112,6 @@ for x in xs.reverse do
|
|||
odds := x :: odds
|
||||
return (evens, odds)
|
||||
|
||||
set_option smartUnfolding false in
|
||||
theorem ex8 : split [1, 2, 3, 4] = ([2, 4], [1, 3]) :=
|
||||
rfl
|
||||
|
||||
|
|
@ -169,11 +166,9 @@ for x in xs do
|
|||
return x
|
||||
return 0
|
||||
|
||||
set_option smartUnfolding false in
|
||||
theorem ex14 : findOdd [2, 4, 5, 8, 7] = 5 :=
|
||||
rfl
|
||||
|
||||
set_option smartUnfolding false in
|
||||
theorem ex15 : findOdd [2, 4, 8, 10] = 0 :=
|
||||
rfl
|
||||
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ for x in xs do
|
|||
#eval f [1, 2, 3] $.run' 0
|
||||
#eval f [1, 0, 3] $.run' 0
|
||||
|
||||
set_option smartUnfolding false in
|
||||
theorem ex1 : f [1, 2, 3] $.run' 0 = Except.ok () :=
|
||||
rfl
|
||||
|
||||
|
|
|
|||
|
|
@ -14,10 +14,8 @@ structure S :=
|
|||
instance : BEq S :=
|
||||
⟨fun a b => a.key == b.key⟩
|
||||
|
||||
set_option smartUnfolding false in
|
||||
theorem ex1 : f (α := S) [⟨1, 2⟩, ⟨3, 4⟩, ⟨5, 6⟩] ⟨3, 0⟩ = ⟨3, 4⟩ :=
|
||||
rfl
|
||||
|
||||
set_option smartUnfolding false in
|
||||
theorem ex2 : f (α := S) [⟨1, 2⟩, ⟨3, 4⟩, ⟨5, 6⟩] ⟨4, 10⟩ = ⟨4, 10⟩ :=
|
||||
rfl
|
||||
|
|
|
|||
|
|
@ -50,7 +50,6 @@ else
|
|||
| [] => []
|
||||
| y::ys => (y + x/2 + 1) :: bla (x/2) ys
|
||||
|
||||
set_option smartUnfolding false in
|
||||
theorem blaEq (y : Nat) (ys : List Nat) : bla 4 (y::ys) = (y+2) :: bla 2 ys :=
|
||||
rfl
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue