diff --git a/src/Lean/Meta/WHNF.lean b/src/Lean/Meta/WHNF.lean index 87e01beaa1..ddd3c8e6e1 100644 --- a/src/Lean/Meta/WHNF.lean +++ b/src/Lean/Meta/WHNF.lean @@ -20,6 +20,12 @@ def smartUnfoldingSuffix := "_sunfold" @[inline] def mkSmartUnfoldingNameFor (n : Name) : Name := Name.mkStr n smartUnfoldingSuffix +def smartUnfoldingDefault := true +builtin_initialize + registerOption `smartUnfolding { defValue := smartUnfoldingDefault, group := "", descr := "when computing weak head normal form, use auxiliary definition created for functions defined by structural recursion" } +private def useSmartUnfolding (opts : Options) : Bool := + opts.getBool `smartUnfolding smartUnfoldingDefault + /- =========================== Helper methods =========================== -/ @@ -113,20 +119,6 @@ private def reduceRec {α} (recVal : RecursorVal) (recLvls : List Level) (recArg else failK () -@[specialize] private def isRecStuck? (isStuck? : Expr → MetaM (Option MVarId)) (recVal : RecursorVal) (recLvls : List Level) (recArgs : Array Expr) - : MetaM (Option MVarId) := - if recVal.k then - -- TODO: improve this case - pure none - else do - let majorIdx := recVal.getMajorIdx - if h : majorIdx < recArgs.size then do - let major := recArgs.get ⟨majorIdx, h⟩ - let major ← whnf major - isStuck? major - else - pure none - /- =========================== Helper functions for reducing Quot.lift and Quot.ind =========================== -/ @@ -152,41 +144,55 @@ private def reduceQuotRec {α} (recVal : QuotVal) (recLvls : List Level) (recAr | QuotKind.ind => process 4 3 | _ => failK () -@[specialize] private def isQuotRecStuck? (isStuck? : Expr → MetaM (Option MVarId)) (recVal : QuotVal) (recLvls : List Level) (recArgs : Array Expr) - : MetaM (Option MVarId) := - let process? (majorPos : Nat) : MetaM (Option MVarId) := - if h : majorPos < recArgs.size then do - let major := recArgs.get ⟨majorPos, h⟩ - let major ← whnf major - isStuck? major - else - pure none - match recVal.kind with - | QuotKind.lift => process? 5 - | QuotKind.ind => process? 4 - | _ => pure none - /- =========================== Helper function for extracting "stuck term" =========================== -/ -/-- Return `some (Expr.mvar mvarId)` if metavariable `mvarId` is blocking reduction. -/ -private partial def getStuckMVarImp? : Expr → MetaM (Option MVarId) - | Expr.mdata _ e _ => getStuckMVarImp? e - | Expr.proj _ _ e _ => do getStuckMVarImp? (← whnf e) - | e@(Expr.mvar mvarId _) => pure (some mvarId) - | e@(Expr.app f _ _) => - let f := f.getAppFn - match f with - | Expr.mvar mvarId _ => pure (some mvarId) - | Expr.const fName fLvls _ => do - let cinfo? ← getConstNoEx? fName - match cinfo? with - | some $ ConstantInfo.recInfo recVal => isRecStuck? getStuckMVarImp? recVal fLvls e.getAppArgs - | some $ ConstantInfo.quotInfo recVal => isQuotRecStuck? getStuckMVarImp? recVal fLvls e.getAppArgs - | _ => pure none +mutual + private partial def isRecStuck? (recVal : RecursorVal) (recLvls : List Level) (recArgs : Array Expr) : MetaM (Option MVarId) := + if recVal.k then + -- TODO: improve this case + pure none + else do + let majorIdx := recVal.getMajorIdx + if h : majorIdx < recArgs.size then do + let major := recArgs.get ⟨majorIdx, h⟩ + let major ← whnf major + getStuckMVarImp? major + else + pure none + + private partial def isQuotRecStuck? (recVal : QuotVal) (recLvls : List Level) (recArgs : Array Expr) : MetaM (Option MVarId) := + let process? (majorPos : Nat) : MetaM (Option MVarId) := + if h : majorPos < recArgs.size then do + let major := recArgs.get ⟨majorPos, h⟩ + let major ← whnf major + getStuckMVarImp? major + else + pure none + match recVal.kind with + | QuotKind.lift => process? 5 + | QuotKind.ind => process? 4 + | _ => pure none + + /-- Return `some (Expr.mvar mvarId)` if metavariable `mvarId` is blocking reduction. -/ + private partial def getStuckMVarImp? : Expr → MetaM (Option MVarId) + | Expr.mdata _ e _ => getStuckMVarImp? e + | Expr.proj _ _ e _ => do getStuckMVarImp? (← whnf e) + | e@(Expr.mvar mvarId _) => pure (some mvarId) + | e@(Expr.app f _ _) => + let f := f.getAppFn + match f with + | Expr.mvar mvarId _ => pure (some mvarId) + | Expr.const fName fLvls _ => do + let cinfo? ← getConstNoEx? fName + match cinfo? with + | some $ ConstantInfo.recInfo recVal => isRecStuck? recVal fLvls e.getAppArgs + | some $ ConstantInfo.quotInfo recVal => isQuotRecStuck? recVal fLvls e.getAppArgs + | _ => pure none + | _ => pure none | _ => pure none - | _ => pure none +end @[inline] def getStuckMVar? (e : Expr) : m (Option MVarId) := liftM $ getStuckMVarImp? e @@ -333,51 +339,75 @@ private partial def whnfCoreImp (e : Expr) : MetaM Expr := @[inline] def whnfCore (e : Expr) : m Expr := liftMetaM $ whnfCoreImp e -/-- - Similar to `whnfCore`, but uses `synthesizePending` to (try to) synthesize metavariables - that are blocking reduction. -/ -private partial def whnfCoreUnstuck (e : Expr) : MetaM Expr := do - let e ← whnfCore e - let (some mvarId) ← getStuckMVar? e | pure e - let succeeded ← Meta.synthPending mvarId - if succeeded then whnfCoreUnstuck e else pure e - -def smartUnfoldingDefault := true -builtin_initialize - registerOption `smartUnfolding { defValue := smartUnfoldingDefault, group := "", descr := "when computing weak head normal form, use auxiliary definition created for functions defined by structural recursion" } -private def useSmartUnfolding (opts : Options) : Bool := - opts.getBool `smartUnfolding smartUnfoldingDefault - -/-- Unfold definition using "smart unfolding" if possible. -/ -private def unfoldDefinitionImp? (e : Expr) : MetaM (Option Expr) := - match e with - | Expr.app f _ _ => - matchConstAux f.getAppFn (fun _ => pure none) fun fInfo fLvls => do - if fInfo.lparams.length != fLvls.length then - pure none +mutual + /-- Reduce `e` until `idRhs` application is exposed or it gets stuck. + This is a helper method for implementing smart unfolding. -/ + private partial def whnfUntilIdRhs (e : Expr) : MetaM Expr := do + let e ← whnfCoreImp e + match (← getStuckMVar? e) with + | some mvarId => + /- Try to "unstuck" by resolving pending TC problems -/ + if (← Meta.synthPending mvarId) then + whnfUntilIdRhs e else - let unfoldDefault (_ : Unit) : MetaM (Option Expr) := - if fInfo.hasValue then - deltaBetaDefinition fInfo fLvls e.getAppRevArgs (fun _ => pure none) (fun e => pure (some e)) - else - pure none - if useSmartUnfolding (← getOptions) then - let fAuxInfo? ← getConstNoEx? (mkSmartUnfoldingNameFor fInfo.name) - match fAuxInfo? with - | some $ fAuxInfo@(ConstantInfo.defnInfo _) => - deltaBetaDefinition fAuxInfo fLvls e.getAppRevArgs (fun _ => pure none) $ fun e₁ => do - let e₂ ← whnfCoreUnstuck e₁ - if isIdRhsApp e₂ then - pure (some (extractIdRhs e₂)) - else - pure none - | _ => unfoldDefault () + pure e -- failed because metavariable is blocking reduction + | _ => + if isIdRhsApp e then + pure e -- done + else + match (← unfoldDefinitionImp? e) with + | some e => whnfUntilIdRhs e + | none => pure e -- failed because of symbolic argument + + /-- Unfold definition using "smart unfolding" if possible. -/ + private partial def unfoldDefinitionImp? (e : Expr) : MetaM (Option Expr) := + match e with + | Expr.app f _ _ => + matchConstAux f.getAppFn (fun _ => pure none) fun fInfo fLvls => do + if fInfo.lparams.length != fLvls.length then + pure none else - unfoldDefault () - | Expr.const name lvls _ => do - let (some (cinfo@(ConstantInfo.defnInfo _))) ← getConstNoEx? name | pure none - deltaDefinition cinfo lvls (fun _ => pure none) (fun e => pure (some e)) - | _ => pure none + let unfoldDefault (_ : Unit) : MetaM (Option Expr) := + if fInfo.hasValue then + deltaBetaDefinition fInfo fLvls e.getAppRevArgs (fun _ => pure none) (fun e => pure (some e)) + else + pure none + if useSmartUnfolding (← getOptions) then + let fAuxInfo? ← getConstNoEx? (mkSmartUnfoldingNameFor fInfo.name) + match fAuxInfo? with + | some $ fAuxInfo@(ConstantInfo.defnInfo _) => + deltaBetaDefinition fAuxInfo fLvls e.getAppRevArgs (fun _ => pure none) $ fun e₁ => do + let e₂ ← whnfUntilIdRhs e₁ + if isIdRhsApp e₂ then + pure (some (extractIdRhs e₂)) + else + pure none + | _ => unfoldDefault () + else + unfoldDefault () + | Expr.const name lvls _ => do + let (some (cinfo@(ConstantInfo.defnInfo _))) ← getConstNoEx? name | pure none + deltaDefinition cinfo lvls (fun _ => pure none) (fun e => pure (some e)) + | _ => pure none +end + +@[specialize] partial def whnfHeadPredImp (e : Expr) (pred : Expr → MetaM Bool) : MetaM Expr := + whnfEasyCases e fun e => do + let e ← whnfCoreImp e + if (← pred e) then + match (← unfoldDefinitionImp? e) with + | some e => whnfHeadPredImp e pred + | none => pure e + else + pure e + +@[inline] partial def whnfHeadPred (e : Expr) (pred : Expr → MetaM Bool) : m Expr := + liftMetaM $ whnfHeadPredImp e pred + +def whnfUntil (e : Expr) (declName : Name) : m (Option Expr) := liftMetaM do + let e ← whnfHeadPredImp e (fun e => pure $ !e.isAppOf declName) + if e.isAppOf declName then pure e + else pure none @[inline] def unfoldDefinition? (e : Expr) : m (Option Expr) := liftMetaM $ unfoldDefinitionImp? e @@ -476,7 +506,7 @@ partial def whnfImp (e : Expr) : MetaM Expr := match (← cached? useCache e) with | some e' => pure e' | none => - let e' ← whnfCore e + let e' ← whnfCoreImp e match (← reduceNat? e') with | some v => cache useCache e v | none => @@ -501,24 +531,6 @@ def reduceProj? (e : Expr) (i : Nat) : MetaM (Option Expr) := do else pure none -@[specialize] partial def whnfHeadPredImp (e : Expr) (pred : Expr → MetaM Bool) : MetaM Expr := - whnfEasyCases e fun e => do - let e ← whnfCore e - if (← pred e) then - match (← unfoldDefinition? e) with - | some e => whnfHeadPredImp e pred - | none => pure e - else - pure e - -@[inline] def whnfHeadPred (e : Expr) (pred : Expr → MetaM Bool) : m Expr := - liftMetaM $ whnfHeadPredImp e pred - -def whnfUntil (e : Expr) (declName : Name) : m (Option Expr) := liftMetaM do - let e ← whnfHeadPredImp e (fun e => pure $ !e.isAppOf declName) - if e.isAppOf declName then pure e - else pure none - builtin_initialize registerTraceClass `Meta.whnf diff --git a/tests/lean/run/doNotation2.lean b/tests/lean/run/doNotation2.lean index e8aec824f6..684f0f4608 100644 --- a/tests/lean/run/doNotation2.lean +++ b/tests/lean/run/doNotation2.lean @@ -60,7 +60,6 @@ return sum #eval sumOdd [1, 2, 3, 4, 5, 6, 7, 9, 11, 101] 10 -set_option smartUnfolding false in theorem ex5 : sumOdd [1, 2, 3, 4, 5, 6, 7, 9, 11, 101] 10 = 16 := rfl @@ -79,7 +78,6 @@ for (x, y) in ps do sum := sum + x - y return sum -set_option smartUnfolding false in theorem ex7 : sumDiff [(2, 1), (10, 5)] = 6 := rfl @@ -114,7 +112,6 @@ for x in xs.reverse do odds := x :: odds return (evens, odds) -set_option smartUnfolding false in theorem ex8 : split [1, 2, 3, 4] = ([2, 4], [1, 3]) := rfl @@ -169,11 +166,9 @@ for x in xs do return x return 0 -set_option smartUnfolding false in theorem ex14 : findOdd [2, 4, 5, 8, 7] = 5 := rfl -set_option smartUnfolding false in theorem ex15 : findOdd [2, 4, 8, 10] = 0 := rfl diff --git a/tests/lean/run/forBodyResultTypeIssue.lean b/tests/lean/run/forBodyResultTypeIssue.lean index 76e82cb45b..e942324d15 100644 --- a/tests/lean/run/forBodyResultTypeIssue.lean +++ b/tests/lean/run/forBodyResultTypeIssue.lean @@ -8,7 +8,6 @@ for x in xs do #eval f [1, 2, 3] $.run' 0 #eval f [1, 0, 3] $.run' 0 -set_option smartUnfolding false in theorem ex1 : f [1, 2, 3] $.run' 0 = Except.ok () := rfl diff --git a/tests/lean/run/forInUniv.lean b/tests/lean/run/forInUniv.lean index 5c2d612d97..9b3c13374d 100644 --- a/tests/lean/run/forInUniv.lean +++ b/tests/lean/run/forInUniv.lean @@ -14,10 +14,8 @@ structure S := instance : BEq S := ⟨fun a b => a.key == b.key⟩ -set_option smartUnfolding false in theorem ex1 : f (α := S) [⟨1, 2⟩, ⟨3, 4⟩, ⟨5, 6⟩] ⟨3, 0⟩ = ⟨3, 4⟩ := rfl -set_option smartUnfolding false in theorem ex2 : f (α := S) [⟨1, 2⟩, ⟨3, 4⟩, ⟨5, 6⟩] ⟨4, 10⟩ = ⟨4, 10⟩ := rfl diff --git a/tests/lean/run/structuralRec1.lean b/tests/lean/run/structuralRec1.lean index 2ce55cea94..e254a6b60a 100644 --- a/tests/lean/run/structuralRec1.lean +++ b/tests/lean/run/structuralRec1.lean @@ -50,7 +50,6 @@ else | [] => [] | y::ys => (y + x/2 + 1) :: bla (x/2) ys -set_option smartUnfolding false in theorem blaEq (y : Nat) (ys : List Nat) : bla 4 (y::ys) = (y+2) :: bla 2 ys := rfl