diff --git a/src/Lean/Meta/WHNF.lean b/src/Lean/Meta/WHNF.lean index 2fa5823ebe..8aa9fc590b 100644 --- a/src/Lean/Meta/WHNF.lean +++ b/src/Lean/Meta/WHNF.lean @@ -32,7 +32,7 @@ private def useSmartUnfolding (opts : Options) : Bool := =========================== -/ def isAuxDef (constName : Name) : MetaM Bool := do let env ← getEnv - pure (isAuxRecursor env constName || isNoConfusion env constName) + return isAuxRecursor env constName || isNoConfusion env constName @[inline] private def matchConstAux {α} (e : Expr) (failK : Unit → MetaM α) (k : ConstantInfo → List Level → MetaM α) : MetaM α := match e with @@ -47,14 +47,15 @@ def isAuxDef (constName : Name) : MetaM Bool := do private def getFirstCtor (d : Name) : MetaM (Option Name) := do let some (ConstantInfo.inductInfo { ctors := ctor::_, ..}) ← getConstNoEx? d | pure none - pure (some ctor) + return some ctor private def mkNullaryCtor (type : Expr) (nparams : Nat) : MetaM (Option Expr) := match type.getAppFn with | Expr.const d lvls _ => do let (some ctor) ← getFirstCtor d | pure none - pure $ mkAppN (mkConst ctor lvls) (type.getAppArgs.shrink nparams) - | _ => pure none + return mkAppN (mkConst ctor lvls) (type.getAppArgs.shrink nparams) + | _ => + return none def toCtorIfLit : Expr → Expr | Expr.lit (Literal.natVal v) _ => @@ -66,7 +67,7 @@ def toCtorIfLit : Expr → Expr private def getRecRuleFor (recVal : RecursorVal) (major : Expr) : Option RecursorRule := match major.getAppFn with - | Expr.const fn _ _ => recVal.rules.find? $ fun r => r.ctor == fn + | Expr.const fn _ _ => recVal.rules.find? fun r => r.ctor == fn | _ => none private def toCtorWhenK (recVal : RecursorVal) (major : Expr) : MetaM (Option Expr) := do @@ -74,16 +75,16 @@ private def toCtorWhenK (recVal : RecursorVal) (major : Expr) : MetaM (Option Ex let majorType ← whnf majorType let majorTypeI := majorType.getAppFn if !majorTypeI.isConstOf recVal.getInduct then - pure none + return none else if majorType.hasExprMVar && majorType.getAppArgs[recVal.numParams:].any Expr.hasExprMVar then - pure none + return none else do let (some newCtorApp) ← mkNullaryCtor majorType recVal.numParams | pure none let newType ← inferType newCtorApp if (← isDefEq majorType newType) then - pure newCtorApp + return newCtorApp else - pure none + return none /-- Auxiliary function for reducing recursor applications. -/ private def reduceRec {α} (recVal : RecursorVal) (recLvls : List Level) (recArgs : Array Expr) (failK : Unit → MetaM α) (successK : Expr → MetaM α) : MetaM α := @@ -148,7 +149,7 @@ mutual private partial def isRecStuck? (recVal : RecursorVal) (recLvls : List Level) (recArgs : Array Expr) : MetaM (Option MVarId) := if recVal.k then -- TODO: improve this case - pure none + return none else do let majorIdx := recVal.getMajorIdx if h : majorIdx < recArgs.size then do @@ -156,7 +157,7 @@ mutual let major ← whnf major getStuckMVar? major else - pure none + return none private partial def isQuotRecStuck? (recVal : QuotVal) (recLvls : List Level) (recArgs : Array Expr) : MetaM (Option MVarId) := let process? (majorPos : Nat) : MetaM (Option MVarId) := @@ -165,11 +166,11 @@ mutual let major ← whnf major getStuckMVar? major else - pure none + return none match recVal.kind with | QuotKind.lift => process? 5 | QuotKind.ind => process? 4 - | _ => pure none + | _ => return none /-- Return `some (Expr.mvar mvarId)` if metavariable `mvarId` is blocking reduction. -/ partial def getStuckMVar? : Expr → MetaM (Option MVarId) @@ -183,15 +184,15 @@ mutual | e@(Expr.app f _ _) => let f := f.getAppFn match f with - | Expr.mvar mvarId _ => pure (some mvarId) + | Expr.mvar mvarId _ => return some mvarId | Expr.const fName fLvls _ => do let cinfo? ← getConstNoEx? fName match cinfo? with | some $ ConstantInfo.recInfo recVal => isRecStuck? recVal fLvls e.getAppArgs | some $ ConstantInfo.quotInfo recVal => isQuotRecStuck? recVal fLvls e.getAppArgs - | _ => pure none - | _ => pure none - | _ => pure none + | _ => return none + | _ => return none + | _ => return none end /- =========================== @@ -199,33 +200,34 @@ end =========================== -/ /-- Auxiliary combinator for handling easy WHNF cases. It takes a function for handling the "hard" cases as an argument -/ -@[specialize] private partial def whnfEasyCases : Expr → (Expr → MetaM Expr) → MetaM Expr - | e@(Expr.forallE _ _ _ _), _ => pure e - | e@(Expr.lam _ _ _ _), _ => pure e - | e@(Expr.sort _ _), _ => pure e - | e@(Expr.lit _ _), _ => pure e - | e@(Expr.bvar _ _), _ => unreachable! - | Expr.mdata _ e _, k => whnfEasyCases e k - | e@(Expr.letE _ _ _ _ _), k => k e - | e@(Expr.fvar fvarId _), k => do +@[specialize] private partial def whnfEasyCases (e : Expr) (k : Expr → MetaM Expr) : MetaM Expr := do + match e with + | Expr.forallE .. => return e + | Expr.lam .. => return e + | Expr.sort .. => return e + | Expr.lit .. => return e + | Expr.bvar .. => unreachable! + | Expr.letE .. => k e + | Expr.const .. => k e + | Expr.app .. => k e + | Expr.proj .. => k e + | Expr.mdata _ e _ => whnfEasyCases e k + | Expr.fvar fvarId _ => let decl ← getLocalDecl fvarId match decl with - | LocalDecl.cdecl _ _ _ _ _ => pure e + | LocalDecl.cdecl .. => return e | LocalDecl.ldecl _ _ _ _ v nonDep => let cfg ← getConfig if nonDep && !cfg.zetaNonDep then - pure e + return e else when cfg.trackZeta do modify fun s => { s with zetaFVarIds := s.zetaFVarIds.insert fvarId } whnfEasyCases v k - | e@(Expr.mvar mvarId _), k => do + | Expr.mvar mvarId _ => match (← getExprMVarAssignment? mvarId) with | some v => whnfEasyCases v k - | none => pure e - | e@(Expr.const _ _ _), k => k e - | e@(Expr.app _ _ _), k => k e - | e@(Expr.proj _ _ _ _), k => k e + | none => return e /-- Return true iff term is of the form `idRhs ...` -/ private def isIdRhsApp (e : Expr) : Bool := @@ -248,7 +250,8 @@ private def extractIdRhs (e : Expr) : Expr := @[specialize] private def deltaBetaDefinition {α} (c : ConstantInfo) (lvls : List Level) (revArgs : Array Expr) (failK : Unit → α) (successK : Expr → α) : α := - if c.lparams.length != lvls.length then failK () + if c.lparams.length != lvls.length then + failK () else let val := c.instantiateValueLevelParams lvls let val := val.betaRev revArgs @@ -262,12 +265,13 @@ inductive ReduceMatcherResult where def reduceMatcher? (e : Expr) : MetaM ReduceMatcherResult := do match e.getAppFn with - | Expr.const declName declLevels _ => do - let some info ← getMatcherInfo? declName | pure ReduceMatcherResult.notMatcher + | Expr.const declName declLevels _ => + let some info ← getMatcherInfo? declName + | return ReduceMatcherResult.notMatcher let args := e.getAppArgs let prefixSz := info.numParams + 1 + info.numDiscrs if args.size < prefixSz + info.numAlts then - pure ReduceMatcherResult.partialApp + return ReduceMatcherResult.partialApp else let constInfo ← getConstInfo declName let f := constInfo.instantiateValueLevelParams declLevels @@ -294,14 +298,14 @@ def project? (e : Expr) (i : Nat) : MetaM (Option Expr) := do let numArgs := e.getAppNumArgs let idx := ctorVal.numParams + i if idx < numArgs then - pure (some (e.getArg! idx)) + return some (e.getArg! idx) else - pure none + return none def reduceProj? (e : Expr) : MetaM (Option Expr) := do match e with | Expr.proj _ i c _ => project? c i - | _ => return none + | _ => return none /-- Apply beta-reduction, zeta-reduction (i.e., unfold let local-decls), iota-reduction, @@ -351,10 +355,10 @@ mutual if (← Meta.synthPending mvarId) then whnfUntilIdRhs e else - pure e -- failed because metavariable is blocking reduction + return e -- failed because metavariable is blocking reduction | _ => if isIdRhsApp e then - pure e -- done + return e -- done else match (← unfoldDefinition? e) with | some e => whnfUntilIdRhs e @@ -366,30 +370,32 @@ mutual | Expr.app f _ _ => matchConstAux f.getAppFn (fun _ => pure none) fun fInfo fLvls => do if fInfo.lparams.length != fLvls.length then - pure none + return none else let unfoldDefault (_ : Unit) : MetaM (Option Expr) := if fInfo.hasValue then deltaBetaDefinition fInfo fLvls e.getAppRevArgs (fun _ => pure none) (fun e => pure (some e)) else - pure none + return none if useSmartUnfolding (← getOptions) then let fAuxInfo? ← getConstNoEx? (mkSmartUnfoldingNameFor fInfo.name) match fAuxInfo? with - | some $ fAuxInfo@(ConstantInfo.defnInfo _) => - deltaBetaDefinition fAuxInfo fLvls e.getAppRevArgs (fun _ => pure none) $ fun e₁ => do + | some fAuxInfo@(ConstantInfo.defnInfo _) => + deltaBetaDefinition fAuxInfo fLvls e.getAppRevArgs (fun _ => pure none) fun e₁ => do let e₂ ← whnfUntilIdRhs e₁ if isIdRhsApp e₂ then - pure (some (extractIdRhs e₂)) + return some (extractIdRhs e₂) else - pure none + return none | _ => unfoldDefault () else unfoldDefault () | Expr.const name lvls _ => do let (some (cinfo@(ConstantInfo.defnInfo _))) ← getConstNoEx? name | pure none - deltaDefinition cinfo lvls (fun _ => pure none) (fun e => pure (some e)) - | _ => pure none + deltaDefinition cinfo lvls + (fun _ => pure none) + (fun e => pure (some e)) + | _ => return none end @[specialize] partial def whnfHeadPred (e : Expr) (pred : Expr → MetaM Bool) : MetaM Expr := @@ -398,14 +404,16 @@ end if (← pred e) then match (← unfoldDefinition? e) with | some e => whnfHeadPred e pred - | none => pure e + | none => return e else - pure e + return e def whnfUntil (e : Expr) (declName : Name) : MetaM (Option Expr) := do - let e ← whnfHeadPred e (fun e => pure $ !e.isAppOf declName) - if e.isAppOf declName then pure e - else pure none + let e ← whnfHeadPred e (fun e => return !e.isAppOf declName) + if e.isAppOf declName then + return e + else + return none /-- Try to reduce matcher/recursor/quot applications. We say they are all "morally" recursor applications. -/ def reduceRecMatcher? (e : Expr) : MetaM (Option Expr) := do @@ -413,16 +421,16 @@ def reduceRecMatcher? (e : Expr) : MetaM (Option Expr) := do return none else match (← reduceMatcher? e) with | ReduceMatcherResult.reduced e => return e - | _ => matchConstAux e.getAppFn (fun _ => pure none) fun cinfo lvls => + | _ => matchConstAux e.getAppFn (fun _ => pure none) fun cinfo lvls => do match cinfo with | ConstantInfo.recInfo «rec» => reduceRec «rec» lvls e.getAppArgs (fun _ => pure none) (fun e => pure (some e)) | ConstantInfo.quotInfo «rec» => reduceQuotRec «rec» lvls e.getAppArgs (fun _ => pure none) (fun e => pure (some e)) - | c@(ConstantInfo.defnInfo _) => do + | c@(ConstantInfo.defnInfo _) => if (← isAuxDef c.name) then deltaBetaDefinition c lvls e.getAppRevArgs (fun _ => pure none) (fun e => pure (some e)) else - pure none - | _ => pure none + return none + | _ => return none unsafe def reduceBoolNativeUnsafe (constName : Name) : MetaM Bool := evalConstCheck Bool `Bool constName unsafe def reduceNatNativeUnsafe (constName : Name) : MetaM Nat := evalConstCheck Nat `Nat constName @@ -433,44 +441,45 @@ def reduceNative? (e : Expr) : MetaM (Option Expr) := match e with | Expr.app (Expr.const fName _ _) (Expr.const argName _ _) _ => if fName == `Lean.reduceBool then do - let b ← reduceBoolNative argName - pure $ toExpr b + return toExpr (← reduceBoolNative argName) else if fName == `Lean.reduceNat then do - let n ← reduceNatNative argName - pure $ toExpr n + return toExpr (← reduceNatNative argName) else - pure none - | _ => pure none + return none + | _ => + return none @[inline] def withNatValue {α} (a : Expr) (k : Nat → MetaM (Option α)) : MetaM (Option α) := do let a ← whnf a match a with | Expr.const `Nat.zero _ _ => k 0 | Expr.lit (Literal.natVal v) _ => k v - | _ => pure none + | _ => return none def reduceUnaryNatOp (f : Nat → Nat) (a : Expr) : MetaM (Option Expr) := withNatValue a fun a => - pure $ mkNatLit $ f a + return mkNatLit <| f a def reduceBinNatOp (f : Nat → Nat → Nat) (a b : Expr) : MetaM (Option Expr) := withNatValue a fun a => withNatValue b fun b => do trace[Meta.isDefEq.whnf.reduceBinOp]! "{a} op {b}" - pure $ mkNatLit $ f a b + return mkNatLit <| f a b def reduceBinNatPred (f : Nat → Nat → Bool) (a b : Expr) : MetaM (Option Expr) := do withNatValue a fun a => withNatValue b fun b => - pure $ toExpr $ f a b + return toExpr <| f a b def reduceNat? (e : Expr) : MetaM (Option Expr) := if e.hasFVar || e.hasMVar then - pure none + return none else match e with | Expr.app (Expr.const fn _ _) a _ => - if fn == `Nat.succ then reduceUnaryNatOp Nat.succ a - else pure none + if fn == `Nat.succ then + reduceUnaryNatOp Nat.succ a + else + return none | Expr.app (Expr.app (Expr.const fn _ _) a1 _) a2 _ => if fn == `Nat.add then reduceBinNatOp Nat.add a1 a2 else if fn == `Nat.sub then reduceBinNatOp Nat.sub a1 a2 @@ -479,8 +488,9 @@ def reduceNat? (e : Expr) : MetaM (Option Expr) := else if fn == `Nat.mod then reduceBinNatOp Nat.mod a1 a2 else if fn == `Nat.beq then reduceBinNatPred Nat.beq a1 a2 else if fn == `Nat.ble then reduceBinNatPred Nat.ble a1 a2 - else pure none - | _ => pure none + else return none + | _ => + return none @[inline] private def useWHNFCache (e : Expr) : MetaM Bool := do