From e020cd2ea09646b814ceafa432b4f44f824723cc Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Thu, 7 Nov 2019 09:55:54 -0800 Subject: [PATCH] refactor: parametrize over `getConst : Name -> m (Option (ConstantInfo))` instead of `env : Environment` Motivation: when instatiating these templates, we can make definitions opaque by returning `none` at `getConst`. --- library/Init/Lean/WHNFUtil.lean | 143 ++++++++++++++++++-------------- 1 file changed, 80 insertions(+), 63 deletions(-) diff --git a/library/Init/Lean/WHNFUtil.lean b/library/Init/Lean/WHNFUtil.lean index 0e4b54d88c..adafac3a2e 100644 --- a/library/Init/Lean/WHNFUtil.lean +++ b/library/Init/Lean/WHNFUtil.lean @@ -4,8 +4,8 @@ Released under Apache 2.0 license as described in the file LICENSE. Authors: Leonardo de Moura -/ prelude -import Init.Lean.Environment -import Init.Lean.AuxRecursor +import Init.Lean.Declaration +import Init.Lean.LocalContext namespace Lean /- =========================== @@ -17,22 +17,38 @@ def smartUnfoldingSuffix := "_sunfold" @[inline] def mkSmartUnfoldingNameFor (n : Name) : Name := Name.mkString n smartUnfoldingSuffix +/- =========================== + Helper functions + =========================== -/ + +@[inline] +def matchConstAux {α : Type} {m : Type → Type} [Monad m] + (getConst : Name → m (Option ConstantInfo)) + (e : Expr) (failK : Unit → m α) (k : ConstantInfo → List Level → m α) : m α := +match e with +| Expr.const name lvls => do + (some cinfo) ← getConst name | failK (); + k cinfo lvls +| _ => failK () + /- =========================== Helper functions for reducing recursors =========================== -/ -private def getFirstCtor (env : Environment) (d : Name) : Option Name := -match env.find d with -| some (ConstantInfo.inductInfo { ctors := ctor::_, ..}) => some ctor -| _ => none +private def getFirstCtor {m : Type → Type} [Monad m] + (getConst : Name → m (Option ConstantInfo)) + (d : Name) : m (Option Name) := +do some (ConstantInfo.inductInfo { ctors := ctor::_, ..}) ← getConst d | pure none; + pure (some ctor) -private def mkNullaryCtor (env : Environment) (type : Expr) (nparams : Nat) : Option Expr := +private def mkNullaryCtor {m : Type → Type} [Monad m] + (getConst : Name → m (Option ConstantInfo)) + (type : Expr) (nparams : Nat) : m (Option Expr) := match type.getAppFn with -| Expr.const d lvls => - match getFirstCtor env d with - | some ctor => mkApp (Expr.const ctor lvls) (type.getAppArgs.shrink nparams) - | none => none -| _ => none +| Expr.const d lvls => do + (some ctor) ← getFirstCtor getConst d | pure none; + pure $ mkApp (Expr.const ctor lvls) (type.getAppArgs.shrink nparams) +| _ => pure none private def toCtorIfLit : Expr → Expr | Expr.lit (Literal.natVal v) => @@ -46,10 +62,11 @@ match major.getAppFn with | _ => none @[specialize] private def toCtorWhenK {m : Type → Type} [Monad m] + (getConst : Name → m (Option ConstantInfo)) (whnf : Expr → m Expr) (inferType : Expr → m Expr) (isDefEq : Expr → Expr → m Bool) - (env : Environment) (rec : RecursorVal) (major : Expr) : m (Option Expr) := + (rec : RecursorVal) (major : Expr) : m (Option Expr) := do majorType ← inferType major; majorType ← whnf majorType; let majorTypeI := majorType.getAppFn; @@ -57,20 +74,19 @@ do majorType ← inferType major; pure none else if majorType.hasExprMVar && majorType.getAppArgs.anyFrom Expr.hasExprMVar rec.nparams then pure none - else - match mkNullaryCtor env majorType rec.nparams with - | none => pure none - | some newCtorApp => do - newType ← inferType newCtorApp; - defeq ← isDefEq majorType newType; - pure $ if defeq then newCtorApp else none + else do + (some newCtorApp) ← mkNullaryCtor getConst majorType rec.nparams | pure none; + newType ← inferType newCtorApp; + defeq ← isDefEq majorType newType; + pure $ if defeq then newCtorApp else none /-- Auxiliary function for reducing recursor applications. -/ @[specialize] def reduceRec {α} {m : Type → Type} [Monad m] + (getConst : Name → m (Option ConstantInfo)) (whnf : Expr → m Expr) (inferType : Expr → m Expr) (isDefEq : Expr → Expr → m Bool) - (env : Environment) (rec : RecursorVal) (recLvls : List Level) (recArgs : Array Expr) + (rec : RecursorVal) (recLvls : List Level) (recArgs : Array Expr) (failK : Unit → m α) (successK : Expr → m α) : m α := let majorIdx := rec.getMajorIdx; if h : majorIdx < recArgs.size then do @@ -80,7 +96,7 @@ if h : majorIdx < recArgs.size then do if !rec.k then pure major else do { - newMajor ← toCtorWhenK whnf inferType isDefEq env rec major; + newMajor ← toCtorWhenK getConst whnf inferType isDefEq rec major; pure (newMajor.getD major) }; let major := toCtorIfLit major; @@ -126,8 +142,8 @@ else do /-- Auxiliary function for reducing `Quot.lift` and `Quot.ind` applications. -/ @[specialize] def reduceQuotRec {α} {m : Type → Type} [Monad m] - (whnf : Expr → m Expr) - (env : Environment) + (getConst : Name → m (Option ConstantInfo)) + (whnf : Expr → m Expr) (rec : QuotVal) (recLvls : List Level) (recArgs : Array Expr) (failK : Unit → m α) (successK : Expr → m α) : m α := let process (majorPos argPos : Nat) : m α := @@ -135,14 +151,12 @@ let process (majorPos argPos : Nat) : m α := let major := recArgs.get ⟨majorPos, h⟩; major ← whnf major; match major with - | Expr.app (Expr.app (Expr.app (Expr.const majorFn _) _) _) majorArg => - match env.find majorFn with - | some (ConstantInfo.quotInfo { kind := QuotKind.ctor, .. }) => - let f := recArgs.get! argPos; - let r := Expr.app f majorArg; - let recArity := majorPos + 1; - successK $ mkAppRange r recArity recArgs.size recArgs - | _ => failK () + | Expr.app (Expr.app (Expr.app (Expr.const majorFn _) _) _) majorArg => do + some (ConstantInfo.quotInfo { kind := QuotKind.ctor, .. }) ← getConst majorFn | failK (); + let f := recArgs.get! argPos; + let r := Expr.app f majorArg; + let recArity := majorPos + 1; + successK $ mkAppRange r recArity recArgs.size recArgs | _ => failK () else failK (); @@ -173,8 +187,9 @@ match rec.kind with /-- Return `some (Expr.mvar mvarId)` if metavariable `mvarId` is blocking reduction. -/ @[specialize] partial def getStuckMVar {m : Type → Type} [Monad m] - (whnf : Expr → m Expr) - (env : Environment) : Expr → m (Option Expr) + (getConst : Name → m (Option ConstantInfo)) + (whnf : Expr → m Expr) + : Expr → m (Option Expr) | Expr.mdata _ e => getStuckMVar e | Expr.proj _ _ e => do e ← whnf e; getStuckMVar e | e@(Expr.mvar _) => pure (some e) @@ -182,8 +197,9 @@ match rec.kind with let f := f.getAppFn; match f with | Expr.mvar _ => pure (some f) - | Expr.const fName fLvls => - match env.find fName with + | Expr.const fName fLvls => do + cinfo? ← getConst fName; + match cinfo? with | some $ ConstantInfo.recInfo rec => isRecStuck whnf getStuckMVar rec fLvls e.getAppArgs | some $ ConstantInfo.quotInfo rec => isQuotRecStuck whnf getStuckMVar rec fLvls e.getAppArgs | _ => pure none @@ -251,19 +267,16 @@ else Apply beta-reduction, zeta-reduction (i.e., unfold let local-decls), iota-reduction, expand let-expressions, expand assigned meta-variables. - This method does *not* apply delta-reduction at the head. - Reason: we want to perform these reductions lazily at isDefEq. - - Remark: this method delta-reduce (transparent) aux-recursors (e.g., casesOn, recOon) IF - `reduceAuxRec? == true` -/ + This method does *not* apply delta-reduction at the head symbol `f` unless `isAuxDef? f` returns true. + Reason: we want to perform these reductions lazily at `isDefEq`. -/ @[specialize] private partial def whnfCore {m : Type → Type} [Monad m] + (getConst : Name → m (Option ConstantInfo)) + (isAuxDef? : Name → m Bool) (whnf : Expr → m Expr) (inferType : Expr → m Expr) (isDefEq : Expr → Expr → m Bool) (getLocalDecl : Name → m LocalDecl) - (getMVarAssignment : Name → m (Option Expr)) - (env : Environment) - (reduceAuxRec? : Bool) : Expr → m Expr + (getMVarAssignment : Name → m (Option Expr)) : Expr → m Expr | e => whnfEasyCases getLocalDecl getMVarAssignment e $ fun e => match e with | e@(Expr.const _ _) => pure e @@ -277,19 +290,20 @@ else else do let done : Unit → m Expr := fun _ => if f == f' then pure e else pure $ e.updateFn f'; - matchConst env f' done $ fun cinfo lvls => + matchConstAux getConst f' done $ fun cinfo lvls => match cinfo with - | ConstantInfo.recInfo rec => reduceRec whnf inferType isDefEq env rec lvls e.getAppArgs done whnfCore - | ConstantInfo.quotInfo rec => reduceQuotRec whnf env rec lvls e.getAppArgs done whnfCore - | c@(ConstantInfo.defnInfo _) => - if reduceAuxRec? && isAuxRecursor env c.name then + | ConstantInfo.recInfo rec => reduceRec getConst whnf inferType isDefEq rec lvls e.getAppArgs done whnfCore + | ConstantInfo.quotInfo rec => reduceQuotRec getConst whnf rec lvls e.getAppArgs done whnfCore + | c@(ConstantInfo.defnInfo _) => do + unfold? ← isAuxDef? c.name; + if unfold? then deltaBetaDefinition c lvls e.getAppArgs done whnfCore else done () | _ => done () | e@(Expr.proj _ i c) => do c ← whnf c; - matchConst env c.getAppFn (fun _ => pure e) $ fun cinfo lvls => + matchConstAux getConst c.getAppFn (fun _ => pure e) $ fun cinfo lvls => match cinfo with | ConstantInfo.ctorInfo ctorVal => pure $ c.getArgD (ctorVal.nparams + i) e | _ => pure e @@ -299,48 +313,51 @@ else Similar to `whnfCore`, but uses `synthesizePending` to (try to) synthesize metavariables that are blocking reduction. -/ @[specialize] private partial def whnfCoreUnstuck {m : Type → Type} [Monad m] + (getConst : Name → m (Option ConstantInfo)) + (isAuxDef? : Name → m Bool) (whnf : Expr → m Expr) (inferType : Expr → m Expr) (isDefEq : Expr → Expr → m Bool) (synthesizePending : Expr → m Bool) (getLocalDecl : Name → m LocalDecl) (getMVarAssignment : Name → m (Option Expr)) - (env : Environment) : Expr → m Expr | e => do - e ← whnfCore whnf inferType isDefEq getLocalDecl getMVarAssignment env true e; - (some mvar) ← getStuckMVar whnf env e | pure e; + e ← whnfCore getConst isAuxDef? whnf inferType isDefEq getLocalDecl getMVarAssignment e; + (some mvar) ← getStuckMVar getConst whnf e | pure e; succeeded ← synthesizePending mvar; if succeeded then whnfCoreUnstuck e else pure e /-- Unfold definition using "smart unfolding" if possible. -/ -def unfoldDefinition {α} {m : Type → Type} [Monad m] +@[specialize] def unfoldDefinition {α} {m : Type → Type} [Monad m] + (getConst : Name → m (Option ConstantInfo)) + (isAuxDef? : Name → m Bool) (whnf : Expr → m Expr) (inferType : Expr → m Expr) (isDefEq : Expr → Expr → m Bool) (synthesizePending : Expr → m Bool) (getLocalDecl : Name → m LocalDecl) (getMVarAssignment : Name → m (Option Expr)) - (env : Environment) (e : Expr) + (e : Expr) (failK : Unit → m α) (successK : Expr → m α) : m α := match e with | Expr.app f _ => - matchConst env f.getAppFn failK $ fun fInfo fLvls => + matchConstAux getConst f.getAppFn failK $ fun fInfo fLvls => if fInfo.lparams.length != fLvls.length then failK () - else - match env.find $ mkSmartUnfoldingNameFor fInfo.name with + else do + fAuxInfo? ← getConst (mkSmartUnfoldingNameFor fInfo.name); + match fAuxInfo? with | some $ fAuxInfo@(ConstantInfo.defnInfo _) => deltaBetaDefinition fAuxInfo fLvls e.getAppRevArgs failK $ fun e₁ => do - e₂ ← whnfCoreUnstuck whnf inferType isDefEq synthesizePending getLocalDecl getMVarAssignment env e₁; + e₂ ← whnfCoreUnstuck getConst isAuxDef? whnf inferType isDefEq synthesizePending getLocalDecl getMVarAssignment e₁; if isIdRhsApp e₂ then successK $ extractIdRhs e₂ else failK () | _ => deltaBetaDefinition fInfo fLvls e.getAppRevArgs failK successK -| Expr.const c lvls => - match env.find c with - | some $ cinfo@(ConstantInfo.defnInfo _) => deltaDefinition cinfo lvls failK successK - | _ => failK () +| Expr.const name lvls => do + (some (cinfo@(ConstantInfo.defnInfo _))) ← getConst name | failK (); + deltaDefinition cinfo lvls failK successK | _ => failK () end Lean