diff --git a/src/Lean/Meta/Basic.lean b/src/Lean/Meta/Basic.lean index 84a759cfcc..1b118a26f0 100644 --- a/src/Lean/Meta/Basic.lean +++ b/src/Lean/Meta/Basic.lean @@ -78,14 +78,6 @@ structure Config where we may want to notify the caller that the TC problem may be solvable later after it assigns `?m`. -/ isDefEqStuckEx : Bool := false - /-- - Controls which definitions and theorems can be unfolded by `isDefEq` and `whnf`. - -/ - transparency : TransparencyMode := TransparencyMode.default - /-- If zetaNonDep == false, then non dependent let-decls are not zeta expanded. -/ - zetaNonDep : Bool := true - /-- When `trackZeta == true`, we store zetaFVarIds all free variables that have been zeta-expanded. -/ - trackZeta : Bool := false /-- Enable/disable the unification hints feature. -/ unificationHints : Bool := true /-- Enables proof irrelevance at `isDefEq` -/ @@ -99,8 +91,26 @@ structure Config where assignSyntheticOpaque : Bool := false /-- Enable/Disable support for offset constraints such as `?x + 1 =?= e` -/ offsetCnstrs : Bool := true + /-- + Controls which definitions and theorems can be unfolded by `isDefEq` and `whnf`. + -/ + transparency : TransparencyMode := TransparencyMode.default + /-- If zetaNonDep == false, then non dependent let-decls are not zeta expanded. -/ + zetaNonDep : Bool := true + /-- + When `trackZeta = true`, we track all free variables that have been zeta-expanded. + That is, suppose the local context contains + the declaration `x : t := v`, and we reduce `x` to `v`, then we insert `x` into `State.zetaFVarIds`. + We use `trackZeta` to discover which let-declarations `let x := v; e` can be represented as `(fun x => e) v`. + When we find these declarations we set their `nonDep` flag with `true`. + To find these let-declarations in a given term `s`, we + 1- Reset `State.zetaFVarIds` + 2- Set `trackZeta := true` + 3- Type-check `s`. + -/ + trackZeta : Bool := false /-- Eta for structures configuration mode. -/ - etaStruct : EtaStructMode := .all + etaStruct : EtaStructMode := .all /-- Function parameter information cache. @@ -366,7 +376,7 @@ section Methods variable [MonadControlT MetaM n] [Monad n] @[inline] def modifyCache (f : Cache → Cache) : MetaM Unit := - modify fun ⟨mctx, cache, zetaFVarIds, postponed⟩ => ⟨mctx, f cache, zetaFVarIds, postponed⟩ + modify fun { mctx, cache, zetaFVarIds, postponed } => { mctx, cache := f cache, zetaFVarIds, postponed } @[inline] def modifyInferTypeCache (f : InferTypeCache → InferTypeCache) : MetaM Unit := modifyCache fun ⟨ic, c1, c2, c3, c4, c5, c6⟩ => ⟨f ic, c1, c2, c3, c4, c5, c6⟩ @@ -781,6 +791,9 @@ def elimMVarDeps (xs : Array Expr) (e : Expr) (preserveOrder : Bool := false) : @[inline] def withConfig (f : Config → Config) : n α → n α := mapMetaM <| withReader (fun ctx => { ctx with config := f ctx.config }) +/-- +Executes `x` tracking zeta reductions `Config.trackZeta := true` +-/ @[inline] def withTrackingZeta (x : n α) : n α := withConfig (fun cfg => { cfg with trackZeta := true }) x diff --git a/src/Lean/Meta/WHNF.lean b/src/Lean/Meta/WHNF.lean index 62fc0779ee..0731d54e60 100644 --- a/src/Lean/Meta/WHNF.lean +++ b/src/Lean/Meta/WHNF.lean @@ -59,33 +59,32 @@ def isAuxDef (constName : Name) : MetaM Bool := do let env ← getEnv return isAuxRecursor env constName || isNoConfusion env constName -@[inline] private def matchConstAux {α} (e : Expr) (failK : Unit → MetaM α) (k : ConstantInfo → List Level → MetaM α) : MetaM α := - match e with - | Expr.const name lvls => do - let (some cinfo) ← getUnfoldableConst? name | failK () - k cinfo lvls - | _ => failK () +@[inline] private def matchConstAux {α} (e : Expr) (failK : Unit → MetaM α) (k : ConstantInfo → List Level → MetaM α) : MetaM α := do + let .const name lvls := e + | failK () + let (some cinfo) ← getUnfoldableConst? name + | failK () + k cinfo lvls -- =========================== /-! # Helper functions for reducing recursors -/ -- =========================== private def getFirstCtor (d : Name) : MetaM (Option Name) := do - let some (ConstantInfo.inductInfo { ctors := ctor::_, ..}) ← getUnfoldableConstNoEx? d | pure none + let some (ConstantInfo.inductInfo { ctors := ctor::_, ..}) ← getUnfoldableConstNoEx? d | + return none return some ctor private def mkNullaryCtor (type : Expr) (nparams : Nat) : MetaM (Option Expr) := do - match type.getAppFn with - | Expr.const d lvls => - let (some ctor) ← getFirstCtor d | pure none - return mkAppN (mkConst ctor lvls) (type.getAppArgs.shrink nparams) - | _ => - return none + let .const d lvls := type.getAppFn + | return none + let (some ctor) ← getFirstCtor d | pure none + return mkAppN (mkConst ctor lvls) (type.getAppArgs.shrink nparams) private def getRecRuleFor (recVal : RecursorVal) (major : Expr) : Option RecursorRule := match major.getAppFn with - | Expr.const fn _ => recVal.rules.find? fun r => r.ctor == fn - | _ => none + | .const fn _ => recVal.rules.find? fun r => r.ctor == fn + | _ => none private def toCtorWhenK (recVal : RecursorVal) (major : Expr) : MetaM Expr := do let majorType ← inferType major @@ -165,7 +164,7 @@ private def reduceRec (recVal : RecursorVal) (recLvls : List Level) (recArgs : A let majorIdx := recVal.getMajorIdx if h : majorIdx < recArgs.size then do let major := recArgs.get ⟨majorIdx, h⟩ - let mut major ← if isWFRec recVal.name && (← getTransparency) == TransparencyMode.default then + let mut major ← if isWFRec recVal.name && (← getTransparency) == .default then -- If recursor is `Acc.rec` or `WellFounded.rec` and transparency is default, -- then we bump transparency to .all to make sure we can unfold defs defined by WellFounded recursion. -- We use this trick because we abstract nested proofs occurring in definitions. @@ -389,8 +388,8 @@ inductive ReduceMatcherResult where -/ def canUnfoldAtMatcher (cfg : Config) (info : ConstantInfo) : CoreM Bool := do match cfg.transparency with - | TransparencyMode.all => return true - | TransparencyMode.default => return !(← isIrreducible info.name) + | .all => return true + | .default => return !(← isIrreducible info.name) | _ => if (← isReducible info.name) || isGlobalInstance (← getEnv) info.name then return true @@ -429,31 +428,29 @@ private def whnfMatcher (e : Expr) : MetaM Expr := do whnf e def reduceMatcher? (e : Expr) : MetaM ReduceMatcherResult := do - match e.getAppFn with - | Expr.const declName declLevels => - let some info ← getMatcherInfo? declName - | return ReduceMatcherResult.notMatcher - let args := e.getAppArgs - let prefixSz := info.numParams + 1 + info.numDiscrs - if args.size < prefixSz + info.numAlts then - return ReduceMatcherResult.partialApp - else - let constInfo ← getConstInfo declName - let f ← instantiateValueLevelParams constInfo declLevels - let auxApp := mkAppN f args[0:prefixSz] - let auxAppType ← inferType auxApp - forallBoundedTelescope auxAppType info.numAlts fun hs _ => do - let auxApp ← whnfMatcher (mkAppN auxApp hs) - let auxAppFn := auxApp.getAppFn - let mut i := prefixSz - for h in hs do - if auxAppFn == h then - let result := mkAppN args[i]! auxApp.getAppArgs - let result := mkAppN result args[prefixSz + info.numAlts:args.size] - return ReduceMatcherResult.reduced result.headBeta - i := i + 1 - return ReduceMatcherResult.stuck auxApp - | _ => pure ReduceMatcherResult.notMatcher + let .const declName declLevels := e.getAppFn + | return .notMatcher + let some info ← getMatcherInfo? declName + | return .notMatcher + let args := e.getAppArgs + let prefixSz := info.numParams + 1 + info.numDiscrs + if args.size < prefixSz + info.numAlts then + return ReduceMatcherResult.partialApp + let constInfo ← getConstInfo declName + let f ← instantiateValueLevelParams constInfo declLevels + let auxApp := mkAppN f args[0:prefixSz] + let auxAppType ← inferType auxApp + forallBoundedTelescope auxAppType info.numAlts fun hs _ => do + let auxApp ← whnfMatcher (mkAppN auxApp hs) + let auxAppFn := auxApp.getAppFn + let mut i := prefixSz + for h in hs do + if auxAppFn == h then + let result := mkAppN args[i]! auxApp.getAppArgs + let result := mkAppN result args[prefixSz + info.numAlts:args.size] + return ReduceMatcherResult.reduced result.headBeta + i := i + 1 + return ReduceMatcherResult.stuck auxApp private def projectCore? (e : Expr) (i : Nat) : MetaM (Option Expr) := do let e := e.toCtorIfLit @@ -471,8 +468,8 @@ def project? (e : Expr) (i : Nat) : MetaM (Option Expr) := do /-- Reduce kernel projection `Expr.proj ..` expression. -/ def reduceProj? (e : Expr) : MetaM (Option Expr) := do match e with - | Expr.proj _ i c => project? c i - | _ => return none + | .proj _ i c => project? c i + | _ => return none /-- Auxiliary method for reducing terms of the form `?m t_1 ... t_n` where `?m` is delayed assigned. @@ -516,9 +513,9 @@ where whnfEasyCases e fun e => do trace[Meta.whnf] e match e with - | Expr.const .. => pure e - | Expr.letE _ _ v b _ => go <| b.instantiate1 v - | Expr.app f .. => + | .const .. => pure e + | .letE _ _ v b _ => go <| b.instantiate1 v + | .app f .. => let f := f.getAppFn let f' ← go f if f'.isLambda then @@ -532,21 +529,21 @@ where return e else match (← reduceMatcher? e) with - | ReduceMatcherResult.reduced eNew => go eNew - | ReduceMatcherResult.partialApp => pure e - | ReduceMatcherResult.stuck _ => pure e - | ReduceMatcherResult.notMatcher => + | .reduced eNew => go eNew + | .partialApp => pure e + | .stuck _ => pure e + | .notMatcher => matchConstAux f' (fun _ => return e) fun cinfo lvls => match cinfo with - | ConstantInfo.recInfo rec => reduceRec rec lvls e.getAppArgs (fun _ => return e) go - | ConstantInfo.quotInfo rec => reduceQuotRec rec lvls e.getAppArgs (fun _ => return e) go - | c@(ConstantInfo.defnInfo _) => do + | .recInfo rec => reduceRec rec lvls e.getAppArgs (fun _ => return e) go + | .quotInfo rec => reduceQuotRec rec lvls e.getAppArgs (fun _ => return e) go + | c@(.defnInfo _) => do if (← isAuxDef c.name) then deltaBetaDefinition c lvls e.getAppRevArgs (fun _ => return e) go else return e | _ => return e - | Expr.proj _ i c => + | .proj _ i c => if simpleReduceOnly then return e else @@ -591,11 +588,11 @@ partial def smartUnfoldingReduce? (e : Expr) : MetaM (Option Expr) := where go (e : Expr) : OptionT MetaM Expr := do match e with - | Expr.letE n t v b _ => withLetDecl n t (← go v) fun x => do mkLetFVars #[x] (← go (b.instantiate1 x)) - | Expr.lam .. => lambdaTelescope e fun xs b => do mkLambdaFVars xs (← go b) - | Expr.app f a .. => return mkApp (← go f) (← go a) - | Expr.proj _ _ s => return e.updateProj! (← go s) - | Expr.mdata _ b => + | .letE n t v b _ => withLetDecl n t (← go v) fun x => do mkLetFVars #[x] (← go (b.instantiate1 x)) + | .lam .. => lambdaTelescope e fun xs b => do mkLambdaFVars xs (← go b) + | .app f a .. => return mkApp (← go f) (← go a) + | .proj _ _ s => return e.updateProj! (← go s) + | .mdata _ b => if let some m := smartUnfoldingMatch? e then goMatch m else @@ -625,7 +622,7 @@ mutual -/ partial def unfoldProjInst? (e : Expr) : MetaM (Option Expr) := do match e.getAppFn with - | Expr.const declName .. => + | .const declName .. => match (← getProjectionFnInfo? declName) with | some { fromClass := true, .. } => match (← withDefault <| unfoldDefinition? e) with @@ -651,7 +648,7 @@ mutual /-- Unfold definition using "smart unfolding" if possible. -/ partial def unfoldDefinition? (e : Expr) : MetaM (Option Expr) := match e with - | Expr.app f _ => + | .app f _ => matchConstAux f.getAppFn (fun _ => unfoldProjInstWhenIntances? e) fun fInfo fLvls => do if fInfo.levelParams.length != fLvls.length then return none @@ -663,7 +660,7 @@ mutual return none if smartUnfolding.get (← getOptions) then match ((← getEnv).find? (mkSmartUnfoldingNameFor fInfo.name)) with - | some fAuxInfo@(ConstantInfo.defnInfo _) => + | some fAuxInfo@(.defnInfo _) => -- We use `preserveMData := true` to make sure the smart unfolding annotation are not erased in an over-application. deltaBetaDefinition fAuxInfo fLvls e.getAppRevArgs (preserveMData := true) (fun _ => pure none) fun e₁ => do let some r ← smartUnfoldingReduce? e₁ | return none @@ -719,7 +716,7 @@ mutual unfoldDefault () else unfoldDefault () - | Expr.const declName lvls => do + | .const declName lvls => do if smartUnfolding.get (← getOptions) && (← getEnv).contains (mkSmartUnfoldingNameFor declName) then return none else @@ -757,12 +754,12 @@ def reduceRecMatcher? (e : Expr) : MetaM (Option Expr) := do if !e.isApp then return none else match (← reduceMatcher? e) with - | ReduceMatcherResult.reduced e => return e + | .reduced e => return e | _ => matchConstAux e.getAppFn (fun _ => pure none) fun cinfo lvls => do match cinfo with - | ConstantInfo.recInfo «rec» => reduceRec «rec» lvls e.getAppArgs (fun _ => pure none) (fun e => pure (some e)) - | ConstantInfo.quotInfo «rec» => reduceQuotRec «rec» lvls e.getAppArgs (fun _ => pure none) (fun e => pure (some e)) - | c@(ConstantInfo.defnInfo _) => + | .recInfo «rec» => reduceRec «rec» lvls e.getAppArgs (fun _ => pure none) (fun e => pure (some e)) + | .quotInfo «rec» => reduceQuotRec «rec» lvls e.getAppArgs (fun _ => pure none) (fun e => pure (some e)) + | c@(.defnInfo _) => if (← isAuxDef c.name) then deltaBetaDefinition c lvls e.getAppRevArgs (fun _ => pure none) (fun e => pure (some e)) else @@ -812,12 +809,12 @@ def reduceNat? (e : Expr) : MetaM (Option Expr) := if e.hasFVar || e.hasMVar then return none else match e with - | Expr.app (Expr.const fn _) a => + | .app (.const fn _) a => if fn == ``Nat.succ then reduceUnaryNatOp Nat.succ a else return none - | Expr.app (Expr.app (Expr.const fn _) a1) a2 => + | .app (.app (.const fn _) a1) a2 => if fn == ``Nat.add then reduceBinNatOp Nat.add a1 a2 else if fn == ``Nat.sub then reduceBinNatOp Nat.sub a1 a2 else if fn == ``Nat.mul then reduceBinNatOp Nat.mul a1 a2 @@ -839,25 +836,25 @@ def reduceNat? (e : Expr) : MetaM (Option Expr) := return false else match (← getConfig).transparency with - | TransparencyMode.default => return true - | TransparencyMode.all => return true - | _ => return false + | .default => return true + | .all => return true + | _ => return false @[inline] private def cached? (useCache : Bool) (e : Expr) : MetaM (Option Expr) := do if useCache then match (← getConfig).transparency with - | TransparencyMode.default => return (← get).cache.whnfDefault.find? e - | TransparencyMode.all => return (← get).cache.whnfAll.find? e - | _ => unreachable! + | .default => return (← get).cache.whnfDefault.find? e + | .all => return (← get).cache.whnfAll.find? e + | _ => unreachable! else return none private def cache (useCache : Bool) (e r : Expr) : MetaM Expr := do if useCache then match (← getConfig).transparency with - | TransparencyMode.default => modify fun s => { s with cache.whnfDefault := s.cache.whnfDefault.insert e r } - | TransparencyMode.all => modify fun s => { s with cache.whnfAll := s.cache.whnfAll.insert e r } - | _ => unreachable! + | .default => modify fun s => { s with cache.whnfDefault := s.cache.whnfDefault.insert e r } + | .all => modify fun s => { s with cache.whnfAll := s.cache.whnfAll.insert e r } + | _ => unreachable! return r @[export lean_whnf] @@ -884,7 +881,7 @@ def reduceProjOf? (e : Expr) (p : Name → Bool) : MetaM (Option Expr) := do if !e.isApp then pure none else match e.getAppFn with - | Expr.const name .. => do + | .const name .. => do let env ← getEnv match env.getProjectionStructureName? name with | some structName =>