feat: improve smart unfolding

This commit is contained in:
Leonardo de Moura 2020-11-15 17:34:37 -08:00
parent f1b88e1304
commit c81dbeb53c
5 changed files with 118 additions and 115 deletions

View file

@ -20,6 +20,12 @@ def smartUnfoldingSuffix := "_sunfold"
@[inline] def mkSmartUnfoldingNameFor (n : Name) : Name :=
Name.mkStr n smartUnfoldingSuffix
def smartUnfoldingDefault := true
builtin_initialize
registerOption `smartUnfolding { defValue := smartUnfoldingDefault, group := "", descr := "when computing weak head normal form, use auxiliary definition created for functions defined by structural recursion" }
private def useSmartUnfolding (opts : Options) : Bool :=
opts.getBool `smartUnfolding smartUnfoldingDefault
/- ===========================
Helper methods
=========================== -/
@ -113,20 +119,6 @@ private def reduceRec {α} (recVal : RecursorVal) (recLvls : List Level) (recArg
else
failK ()
@[specialize] private def isRecStuck? (isStuck? : Expr → MetaM (Option MVarId)) (recVal : RecursorVal) (recLvls : List Level) (recArgs : Array Expr)
: MetaM (Option MVarId) :=
if recVal.k then
-- TODO: improve this case
pure none
else do
let majorIdx := recVal.getMajorIdx
if h : majorIdx < recArgs.size then do
let major := recArgs.get ⟨majorIdx, h⟩
let major ← whnf major
isStuck? major
else
pure none
/- ===========================
Helper functions for reducing Quot.lift and Quot.ind
=========================== -/
@ -152,41 +144,55 @@ private def reduceQuotRec {α} (recVal : QuotVal) (recLvls : List Level) (recAr
| QuotKind.ind => process 4 3
| _ => failK ()
@[specialize] private def isQuotRecStuck? (isStuck? : Expr → MetaM (Option MVarId)) (recVal : QuotVal) (recLvls : List Level) (recArgs : Array Expr)
: MetaM (Option MVarId) :=
let process? (majorPos : Nat) : MetaM (Option MVarId) :=
if h : majorPos < recArgs.size then do
let major := recArgs.get ⟨majorPos, h⟩
let major ← whnf major
isStuck? major
else
pure none
match recVal.kind with
| QuotKind.lift => process? 5
| QuotKind.ind => process? 4
| _ => pure none
/- ===========================
Helper function for extracting "stuck term"
=========================== -/
/-- Return `some (Expr.mvar mvarId)` if metavariable `mvarId` is blocking reduction. -/
private partial def getStuckMVarImp? : Expr → MetaM (Option MVarId)
| Expr.mdata _ e _ => getStuckMVarImp? e
| Expr.proj _ _ e _ => do getStuckMVarImp? (← whnf e)
| e@(Expr.mvar mvarId _) => pure (some mvarId)
| e@(Expr.app f _ _) =>
let f := f.getAppFn
match f with
| Expr.mvar mvarId _ => pure (some mvarId)
| Expr.const fName fLvls _ => do
let cinfo? ← getConstNoEx? fName
match cinfo? with
| some $ ConstantInfo.recInfo recVal => isRecStuck? getStuckMVarImp? recVal fLvls e.getAppArgs
| some $ ConstantInfo.quotInfo recVal => isQuotRecStuck? getStuckMVarImp? recVal fLvls e.getAppArgs
| _ => pure none
mutual
private partial def isRecStuck? (recVal : RecursorVal) (recLvls : List Level) (recArgs : Array Expr) : MetaM (Option MVarId) :=
if recVal.k then
-- TODO: improve this case
pure none
else do
let majorIdx := recVal.getMajorIdx
if h : majorIdx < recArgs.size then do
let major := recArgs.get ⟨majorIdx, h⟩
let major ← whnf major
getStuckMVarImp? major
else
pure none
private partial def isQuotRecStuck? (recVal : QuotVal) (recLvls : List Level) (recArgs : Array Expr) : MetaM (Option MVarId) :=
let process? (majorPos : Nat) : MetaM (Option MVarId) :=
if h : majorPos < recArgs.size then do
let major := recArgs.get ⟨majorPos, h⟩
let major ← whnf major
getStuckMVarImp? major
else
pure none
match recVal.kind with
| QuotKind.lift => process? 5
| QuotKind.ind => process? 4
| _ => pure none
/-- Return `some (Expr.mvar mvarId)` if metavariable `mvarId` is blocking reduction. -/
private partial def getStuckMVarImp? : Expr → MetaM (Option MVarId)
| Expr.mdata _ e _ => getStuckMVarImp? e
| Expr.proj _ _ e _ => do getStuckMVarImp? (← whnf e)
| e@(Expr.mvar mvarId _) => pure (some mvarId)
| e@(Expr.app f _ _) =>
let f := f.getAppFn
match f with
| Expr.mvar mvarId _ => pure (some mvarId)
| Expr.const fName fLvls _ => do
let cinfo? ← getConstNoEx? fName
match cinfo? with
| some $ ConstantInfo.recInfo recVal => isRecStuck? recVal fLvls e.getAppArgs
| some $ ConstantInfo.quotInfo recVal => isQuotRecStuck? recVal fLvls e.getAppArgs
| _ => pure none
| _ => pure none
| _ => pure none
| _ => pure none
end
@[inline] def getStuckMVar? (e : Expr) : m (Option MVarId) :=
liftM $ getStuckMVarImp? e
@ -333,51 +339,75 @@ private partial def whnfCoreImp (e : Expr) : MetaM Expr :=
@[inline] def whnfCore (e : Expr) : m Expr :=
liftMetaM $ whnfCoreImp e
/--
Similar to `whnfCore`, but uses `synthesizePending` to (try to) synthesize metavariables
that are blocking reduction. -/
private partial def whnfCoreUnstuck (e : Expr) : MetaM Expr := do
let e ← whnfCore e
let (some mvarId) ← getStuckMVar? e | pure e
let succeeded ← Meta.synthPending mvarId
if succeeded then whnfCoreUnstuck e else pure e
def smartUnfoldingDefault := true
builtin_initialize
registerOption `smartUnfolding { defValue := smartUnfoldingDefault, group := "", descr := "when computing weak head normal form, use auxiliary definition created for functions defined by structural recursion" }
private def useSmartUnfolding (opts : Options) : Bool :=
opts.getBool `smartUnfolding smartUnfoldingDefault
/-- Unfold definition using "smart unfolding" if possible. -/
private def unfoldDefinitionImp? (e : Expr) : MetaM (Option Expr) :=
match e with
| Expr.app f _ _ =>
matchConstAux f.getAppFn (fun _ => pure none) fun fInfo fLvls => do
if fInfo.lparams.length != fLvls.length then
pure none
mutual
/-- Reduce `e` until `idRhs` application is exposed or it gets stuck.
This is a helper method for implementing smart unfolding. -/
private partial def whnfUntilIdRhs (e : Expr) : MetaM Expr := do
let e ← whnfCoreImp e
match (← getStuckMVar? e) with
| some mvarId =>
/- Try to "unstuck" by resolving pending TC problems -/
if (← Meta.synthPending mvarId) then
whnfUntilIdRhs e
else
let unfoldDefault (_ : Unit) : MetaM (Option Expr) :=
if fInfo.hasValue then
deltaBetaDefinition fInfo fLvls e.getAppRevArgs (fun _ => pure none) (fun e => pure (some e))
else
pure none
if useSmartUnfolding (← getOptions) then
let fAuxInfo? ← getConstNoEx? (mkSmartUnfoldingNameFor fInfo.name)
match fAuxInfo? with
| some $ fAuxInfo@(ConstantInfo.defnInfo _) =>
deltaBetaDefinition fAuxInfo fLvls e.getAppRevArgs (fun _ => pure none) $ fun e₁ => do
let e₂ ← whnfCoreUnstuck e₁
if isIdRhsApp e₂ then
pure (some (extractIdRhs e₂))
else
pure none
| _ => unfoldDefault ()
pure e -- failed because metavariable is blocking reduction
| _ =>
if isIdRhsApp e then
pure e -- done
else
match (← unfoldDefinitionImp? e) with
| some e => whnfUntilIdRhs e
| none => pure e -- failed because of symbolic argument
/-- Unfold definition using "smart unfolding" if possible. -/
private partial def unfoldDefinitionImp? (e : Expr) : MetaM (Option Expr) :=
match e with
| Expr.app f _ _ =>
matchConstAux f.getAppFn (fun _ => pure none) fun fInfo fLvls => do
if fInfo.lparams.length != fLvls.length then
pure none
else
unfoldDefault ()
| Expr.const name lvls _ => do
let (some (cinfo@(ConstantInfo.defnInfo _))) ← getConstNoEx? name | pure none
deltaDefinition cinfo lvls (fun _ => pure none) (fun e => pure (some e))
| _ => pure none
let unfoldDefault (_ : Unit) : MetaM (Option Expr) :=
if fInfo.hasValue then
deltaBetaDefinition fInfo fLvls e.getAppRevArgs (fun _ => pure none) (fun e => pure (some e))
else
pure none
if useSmartUnfolding (← getOptions) then
let fAuxInfo? ← getConstNoEx? (mkSmartUnfoldingNameFor fInfo.name)
match fAuxInfo? with
| some $ fAuxInfo@(ConstantInfo.defnInfo _) =>
deltaBetaDefinition fAuxInfo fLvls e.getAppRevArgs (fun _ => pure none) $ fun e₁ => do
let e₂ ← whnfUntilIdRhs e₁
if isIdRhsApp e₂ then
pure (some (extractIdRhs e₂))
else
pure none
| _ => unfoldDefault ()
else
unfoldDefault ()
| Expr.const name lvls _ => do
let (some (cinfo@(ConstantInfo.defnInfo _))) ← getConstNoEx? name | pure none
deltaDefinition cinfo lvls (fun _ => pure none) (fun e => pure (some e))
| _ => pure none
end
@[specialize] partial def whnfHeadPredImp (e : Expr) (pred : Expr → MetaM Bool) : MetaM Expr :=
whnfEasyCases e fun e => do
let e ← whnfCoreImp e
if (← pred e) then
match (← unfoldDefinitionImp? e) with
| some e => whnfHeadPredImp e pred
| none => pure e
else
pure e
@[inline] partial def whnfHeadPred (e : Expr) (pred : Expr → MetaM Bool) : m Expr :=
liftMetaM $ whnfHeadPredImp e pred
def whnfUntil (e : Expr) (declName : Name) : m (Option Expr) := liftMetaM do
let e ← whnfHeadPredImp e (fun e => pure $ !e.isAppOf declName)
if e.isAppOf declName then pure e
else pure none
@[inline] def unfoldDefinition? (e : Expr) : m (Option Expr) :=
liftMetaM $ unfoldDefinitionImp? e
@ -476,7 +506,7 @@ partial def whnfImp (e : Expr) : MetaM Expr :=
match (← cached? useCache e) with
| some e' => pure e'
| none =>
let e' ← whnfCore e
let e' ← whnfCoreImp e
match (← reduceNat? e') with
| some v => cache useCache e v
| none =>
@ -501,24 +531,6 @@ def reduceProj? (e : Expr) (i : Nat) : MetaM (Option Expr) := do
else
pure none
@[specialize] partial def whnfHeadPredImp (e : Expr) (pred : Expr → MetaM Bool) : MetaM Expr :=
whnfEasyCases e fun e => do
let e ← whnfCore e
if (← pred e) then
match (← unfoldDefinition? e) with
| some e => whnfHeadPredImp e pred
| none => pure e
else
pure e
@[inline] def whnfHeadPred (e : Expr) (pred : Expr → MetaM Bool) : m Expr :=
liftMetaM $ whnfHeadPredImp e pred
def whnfUntil (e : Expr) (declName : Name) : m (Option Expr) := liftMetaM do
let e ← whnfHeadPredImp e (fun e => pure $ !e.isAppOf declName)
if e.isAppOf declName then pure e
else pure none
builtin_initialize
registerTraceClass `Meta.whnf

View file

@ -60,7 +60,6 @@ return sum
#eval sumOdd [1, 2, 3, 4, 5, 6, 7, 9, 11, 101] 10
set_option smartUnfolding false in
theorem ex5 : sumOdd [1, 2, 3, 4, 5, 6, 7, 9, 11, 101] 10 = 16 :=
rfl
@ -79,7 +78,6 @@ for (x, y) in ps do
sum := sum + x - y
return sum
set_option smartUnfolding false in
theorem ex7 : sumDiff [(2, 1), (10, 5)] = 6 :=
rfl
@ -114,7 +112,6 @@ for x in xs.reverse do
odds := x :: odds
return (evens, odds)
set_option smartUnfolding false in
theorem ex8 : split [1, 2, 3, 4] = ([2, 4], [1, 3]) :=
rfl
@ -169,11 +166,9 @@ for x in xs do
return x
return 0
set_option smartUnfolding false in
theorem ex14 : findOdd [2, 4, 5, 8, 7] = 5 :=
rfl
set_option smartUnfolding false in
theorem ex15 : findOdd [2, 4, 8, 10] = 0 :=
rfl

View file

@ -8,7 +8,6 @@ for x in xs do
#eval f [1, 2, 3] $.run' 0
#eval f [1, 0, 3] $.run' 0
set_option smartUnfolding false in
theorem ex1 : f [1, 2, 3] $.run' 0 = Except.ok () :=
rfl

View file

@ -14,10 +14,8 @@ structure S :=
instance : BEq S :=
⟨fun a b => a.key == b.key⟩
set_option smartUnfolding false in
theorem ex1 : f (α := S) [⟨1, 2⟩, ⟨3, 4⟩, ⟨5, 6⟩] ⟨3, 0⟩ = ⟨3, 4⟩ :=
rfl
set_option smartUnfolding false in
theorem ex2 : f (α := S) [⟨1, 2⟩, ⟨3, 4⟩, ⟨5, 6⟩] ⟨4, 10⟩ = ⟨4, 10⟩ :=
rfl

View file

@ -50,7 +50,6 @@ else
| [] => []
| y::ys => (y + x/2 + 1) :: bla (x/2) ys
set_option smartUnfolding false in
theorem blaEq (y : Nat) (ys : List Nat) : bla 4 (y::ys) = (y+2) :: bla 2 ys :=
rfl