lean4-htt/library/Init/Lean/WHNF.lean
2019-11-18 19:54:05 -08:00

386 lines
15 KiB
Text
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/-
Copyright (c) 2019 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Init.Lean.Declaration
import Init.Lean.LocalContext
namespace Lean
/- ===========================
Smart unfolding support
=========================== -/
/-- String suffix used to name the "smart unfolding" auxiliary constant of a declaration
    (looked up by `unfoldDefinitionAux` through `mkSmartUnfoldingNameFor`). -/
def smartUnfoldingSuffix := "_sunfold"
/-- Name of the smart-unfolding auxiliary constant for `n`: `n` extended with the string
    component `smartUnfoldingSuffix`. -/
@[inline] def mkSmartUnfoldingNameFor (n : Name) : Name :=
mkNameStr n smartUnfoldingSuffix
/- ===========================
Helper functions
=========================== -/
/-- Continuation-passing match on a constant.
    If `e` is a constant `c.{lvls}` and `getConst c` returns its declaration `cinfo`,
    continue with `k cinfo lvls`. Otherwise (`e` is not a constant, or `getConst` returns
    `none`) call `failK ()`. -/
@[inline]
def matchConstAux {α : Type} {m : Type → Type} [Monad m]
(getConst : Name → m (Option ConstantInfo))
(e : Expr) (failK : Unit → m α) (k : ConstantInfo → List Level → m α) : m α :=
match e with
| Expr.const name lvls _ => do
-- A constant unknown to `getConst` (e.g. made opaque by the caller) is treated as a failure.
(some cinfo) ← getConst name | failK ();
k cinfo lvls
| _ => failK ()
/- ===========================
Helper functions for reducing recursors
=========================== -/
/-- Return the first constructor of the inductive type `d`.
    Returns `none` when `getConst d` does not yield an `inductInfo` whose constructor list
    is non-empty. -/
private def getFirstCtor {m : Type → Type} [Monad m]
(getConst : Name → m (Option ConstantInfo))
(d : Name) : m (Option Name) :=
do some (ConstantInfo.inductInfo { ctors := ctor::_, ..}) ← getConst d | pure none;
pure (some ctor)
/-- Given `type` of the form `d.{lvls} a_1 ... a_k`, where `d` is an inductive type, build
    `ctor.{lvls} a_1 ... a_nparams`, where `ctor` is the first constructor of `d`.
    Only the first `nparams` arguments of `type` are supplied, so the result is a complete
    constructor application only when that constructor takes no further arguments.
    Returns `none` if the head of `type` is not a constant or no constructor is found. -/
private def mkNullaryCtor {m : Type → Type} [Monad m]
(getConst : Name → m (Option ConstantInfo))
(type : Expr) (nparams : Nat) : m (Option Expr) :=
match type.getAppFn with
| Expr.const d lvls _ => do
(some ctor) ← getFirstCtor getConst d | pure none;
-- Keep only the parameters of `type` (its first `nparams` arguments).
pure $ mkAppN (mkConst ctor lvls) (type.getAppArgs.shrink nparams)
| _ => pure none
/-- Expose a constructor head for natural-number literals so that a recursor rule can match them.
    The literal `0` becomes `Nat.zero`, and a literal `n+1` becomes `Nat.succ n`, where `n` is
    again a literal. Any other expression is returned unchanged. -/
private def toCtorIfLit (e : Expr) : Expr :=
match e with
| Expr.lit (Literal.natVal n) _ =>
  if n != 0 then mkApp (mkConst `Nat.succ) (mkNatLit (n - 1)) else mkConst `Nat.zero
| other => other
/-- Select the rule of recursor `rec` whose constructor is the head constant of `major`.
    Returns `none` if `major` is not headed by a constant, or if no rule of `rec` is for that
    constant. -/
private def getRecRuleFor (rec : RecursorVal) (major : Expr) : Option RecursorRule :=
let head := major.getAppFn;
match head with
| Expr.const ctorName _ _ => rec.rules.find (fun rule => rule.ctor == ctorName)
| _                       => none
/-- Support for recursors with `rec.k` set.
    Infer the type of `major` and put it in weak head normal form.
    - If that type is an application of `rec.getInduct`, build the candidate constructor
      application produced by `mkNullaryCtor` (the first constructor applied to the parameters).
    - Return the candidate when its inferred type is definitionally equal (via `isDefEq`) to
      the type of `major`.

    Returns `none` in four cases:
    - the type's head is not `rec.getInduct`;
    - an argument after the first `rec.nparams` contains an expression metavariable;
    - no constructor can be built;
    - the types are not defeq. -/
@[specialize] private def toCtorWhenK {m : Type → Type} [Monad m]
(getConst : Name → m (Option ConstantInfo))
(whnf : Expr → m Expr)
(inferType : Expr → m Expr)
(isDefEq : Expr → Expr → m Bool)
(rec : RecursorVal) (major : Expr) : m (Option Expr) :=
do majorType ← inferType major;
majorType ← whnf majorType;
let majorTypeI := majorType.getAppFn;
-- Only applicable when the major premise's type is an instance of the recursor's inductive type.
if !majorTypeI.isConstOf rec.getInduct then
pure none
-- Give up when an argument past the parameters contains a metavariable.
else if majorType.hasExprMVar && majorType.getAppArgs.anyFrom rec.nparams Expr.hasExprMVar then
pure none
else do
(some newCtorApp) ← mkNullaryCtor getConst majorType rec.nparams | pure none;
newType ← inferType newCtorApp;
-- Accept the candidate only if its type matches the major premise's type.
defeq ← isDefEq majorType newType;
pure $ if defeq then newCtorApp else none
/-- Auxiliary function for reducing recursor applications.
    `recArgs` are the arguments of the application `rec.{recLvls} recArgs`.

    Preparing the major premise (at index `rec.getMajorIdx`):
    - it is put in weak head normal form with `whnf`;
    - when `rec.k` is set, `toCtorWhenK` may replace it with a constructor application;
    - a natural-number literal is rewritten to `Nat.zero`/`Nat.succ` form.

    If the resulting major premise is headed by a constructor with a rule in `rec.rules`, the
    rule's right-hand side is instantiated at `recLvls` and applied to:
    1. the parameters, motives and minor premises of the recursor application;
    2. the fields of the major premise (its last `rule.nfields` arguments);
    3. the arguments following the major premise.

    The result is passed to `successK`. Otherwise `failK ()` is called; this covers a missing
    major premise, no matching rule, and a mismatch between `recLvls` and `rec.lparams` in
    length. -/
@[specialize] def reduceRec {α} {m : Type → Type} [Monad m]
(getConst : Name → m (Option ConstantInfo))
(whnf : Expr → m Expr)
(inferType : Expr → m Expr)
(isDefEq : Expr → Expr → m Bool)
(rec : RecursorVal) (recLvls : List Level) (recArgs : Array Expr)
(failK : Unit → m α) (successK : Expr → m α) : m α :=
let majorIdx := rec.getMajorIdx;
if h : majorIdx < recArgs.size then do
let major := recArgs.get ⟨majorIdx, h⟩;
major ← whnf major;
-- For recursors with `rec.k`, try to replace the major premise by a constructor application;
-- keep it unchanged if that is not possible.
major ←
if !rec.k then
pure major
else do {
newMajor ← toCtorWhenK getConst whnf inferType isDefEq rec major;
pure (newMajor.getD major)
};
-- Literals have no constructor head; expose one so that a rule can match.
let major := toCtorIfLit major;
match getRecRuleFor rec major with
| some rule =>
let majorArgs := major.getAppArgs;
-- The universe level arguments must match the recursor's universe parameters.
if recLvls.length != rec.lparams.length then
failK ()
else
let rhs := rule.rhs.instantiateLevelParams rec.lparams recLvls;
-- Apply parameters, motives and minor premises from recursor application.
let rhs := mkAppRange rhs 0 (rec.nparams+rec.nmotives+rec.nminors) recArgs;
/- The number of parameters in the constructor is not necessarily
equal to the number of parameters in the recursor when we have
nested inductive types. -/
let nparams := majorArgs.size - rule.nfields;
-- Apply the constructor fields: the last `rule.nfields` arguments of the major premise.
let rhs := mkAppRange rhs nparams majorArgs.size majorArgs;
-- Apply the extra arguments that follow the major premise.
let rhs := mkAppRange rhs (majorIdx + 1) recArgs.size recArgs;
successK rhs
| none => failK ()
else
failK ()
/-- Put the major premise of the recursor application `rec recArgs` (at index
    `rec.getMajorIdx`) in weak head normal form with `whnf`, and return what `isStuck`
    reports for it.
    Returns `none` in two cases:
    - recursors with `rec.k` set, which are not analyzed yet;
    - `recArgs` does not reach the major premise.
    `recLvls` is not used. -/
@[specialize] def isRecStuck {m : Type → Type} [Monad m]
(whnf : Expr → m Expr)
(isStuck : Expr → m (Option Expr))
(rec : RecursorVal) (recLvls : List Level) (recArgs : Array Expr) : m (Option Expr) :=
if rec.k then
-- TODO: improve this case
pure none
else do
let majorIdx := rec.getMajorIdx;
if h : majorIdx < recArgs.size then do
let major := recArgs.get ⟨majorIdx, h⟩;
major ← whnf major;
isStuck major
else
pure none
/- ===========================
Helper functions for reducing Quot.lift and Quot.ind
=========================== -/
/-- Auxiliary function for reducing `Quot.lift` and `Quot.ind` applications.
    The major premise is `recArgs[5]` for `Quot.lift` and `recArgs[4]` for `Quot.ind`. It is
    put in weak head normal form with `whnf`.

    Suppose it is a quotient constructor (a `quotInfo` constant of kind `QuotKind.ctor`)
    applied to three arguments `_ _ a`. Then the function argument `recArgs[3]` is applied to
    `a` and to the arguments after the major premise, and the result is passed to `successK`.

    Otherwise `failK ()` is called, including for other quotient kinds. `recLvls` is not
    used. -/
@[specialize] def reduceQuotRec {α} {m : Type → Type} [Monad m]
(getConst : Name → m (Option ConstantInfo))
(whnf : Expr → m Expr)
(rec : QuotVal) (recLvls : List Level) (recArgs : Array Expr)
(failK : Unit → m α) (successK : Expr → m α) : m α :=
-- `majorPos`: index of the quotient value; `argPos`: index of the function/motive proof to apply.
let process (majorPos argPos : Nat) : m α :=
if h : majorPos < recArgs.size then do
let major := recArgs.get ⟨majorPos, h⟩;
major ← whnf major;
match major with
-- Presumably `Quot.mk α r a`: a quotient constructor applied to three arguments.
| Expr.app (Expr.app (Expr.app (Expr.const majorFn _ _) _ _) _ _) majorArg _ => do
some (ConstantInfo.quotInfo { kind := QuotKind.ctor, .. }) ← getConst majorFn | failK ();
let f := recArgs.get! argPos;
let r := mkApp f majorArg;
let recArity := majorPos + 1;
successK $ mkAppRange r recArity recArgs.size recArgs
| _ => failK ()
else
failK ();
match rec.kind with
| QuotKind.lift => process 5 3
| QuotKind.ind => process 4 3
| _ => failK ()
/-- Put the major premise of a `Quot.lift` application (`recArgs[5]`) or a `Quot.ind`
    application (`recArgs[4]`) in weak head normal form with `whnf`, and return what `isStuck`
    reports for it.
    Returns `none` in two cases:
    - other quotient kinds;
    - `recArgs` is too short.
    `recLvls` is not used. -/
@[specialize] def isQuotRecStuck {m : Type → Type} [Monad m]
(whnf : Expr → m Expr)
(isStuck : Expr → m (Option Expr))
(rec : QuotVal) (recLvls : List Level) (recArgs : Array Expr) : m (Option Expr) :=
let process (majorPos : Nat) : m (Option Expr) :=
if h : majorPos < recArgs.size then do
let major := recArgs.get ⟨majorPos, h⟩;
major ← whnf major;
isStuck major
else
pure none;
match rec.kind with
| QuotKind.lift => process 5
| QuotKind.ind => process 4
| _ => pure none
/- ===========================
Helper function for extracting "stuck term"
=========================== -/
/-- Return `some (Expr.mvar mvarId)` if metavariable `mvarId` is blocking reduction.
    Cases:
    - metadata is skipped;
    - for a projection, the structure argument is put in weak head normal form and inspected;
    - a metavariable is itself the blocker;
    - for an application, a metavariable head is the blocker;
    - for an application headed by a recursor or a quotient constant, its major premise is
      inspected recursively (`isRecStuck` / `isQuotRecStuck`).
    Every other expression yields `none`. -/
@[specialize] partial def getStuckMVar {m : Type → Type} [Monad m]
(getConst : Name → m (Option ConstantInfo))
(whnf : Expr → m Expr)
: Expr → m (Option Expr)
| Expr.mdata _ e _ => getStuckMVar e
| Expr.proj _ _ e _ => do e ← whnf e; getStuckMVar e
| e@(Expr.mvar _ _) => pure (some e)
| e@(Expr.app f _ _) =>
let f := f.getAppFn;
match f with
| Expr.mvar _ _ => pure (some f)
| Expr.const fName fLvls _ => do
cinfo? ← getConst fName;
match cinfo? with
-- Recursors and quotient eliminators are stuck iff their major premise is.
| some $ ConstantInfo.recInfo rec => isRecStuck whnf getStuckMVar rec fLvls e.getAppArgs
| some $ ConstantInfo.quotInfo rec => isQuotRecStuck whnf getStuckMVar rec fLvls e.getAppArgs
| _ => pure none
| _ => pure none
| _ => pure none
/- ===========================
Weak Head Normal Form auxiliary combinators
=========================== -/
/-- Auxiliary combinator for handling easy WHNF cases. It takes a function for handling the "hard" cases as an argument.
    - `forallE`, `lam`, `sort` and `lit` are already in weak head normal form and are returned as is.
    - Metadata is dropped.
    - A free variable whose local declaration (from `getLocalDecl`) has a value is replaced by
      that value, and processing continues. Free variables without a value are returned as is.
    - An assigned metavariable (per `getMVarAssignment`) is replaced by its assignment, and
      processing continues. Unassigned metavariables are returned as is.
    - `letE`, constants, applications and projections are passed to `k`.
    Loose bound variables and `localE` are not expected here (`unreachable!`). -/
@[specialize] partial def whnfEasyCases {m : Type → Type} [Monad m]
(getLocalDecl : Name → m LocalDecl)
(getMVarAssignment : Name → m (Option Expr))
: Expr → (Expr → m Expr) → m Expr
| e@(Expr.forallE _ _ _ _), _ => pure e
| e@(Expr.lam _ _ _ _), _ => pure e
| e@(Expr.sort _ _), _ => pure e
| e@(Expr.lit _ _), _ => pure e
| e@(Expr.bvar _ _), _ => unreachable!
| Expr.mdata _ e _, k => whnfEasyCases e k
| e@(Expr.letE _ _ _ _ _), k => k e
| e@(Expr.fvar fvarId _), k => do
decl ← getLocalDecl fvarId;
-- Zeta: a let-bound local is replaced by its value.
match decl.value? with
| none => pure e
| some v => whnfEasyCases v k
| e@(Expr.mvar mvarId _), k => do
v? ← getMVarAssignment mvarId;
match v? with
| some v => whnfEasyCases v k
| none => pure e
| e@(Expr.const _ _ _), k => k e
| e@(Expr.app _ _ _), k => k e
| e@(Expr.proj _ _ _ _), k => k e
| Expr.localE _ _ _ _, _ => unreachable!
/-- Return `true` iff the head symbol of `e` (its `getAppFn`) is the constant `idRhs`. -/
private def isIdRhsApp (e : Expr) : Bool :=
Expr.isAppOf e `idRhs

/-- Strip an `idRhs` wrapper: `@idRhs T f a_1 ... a_n` becomes `f a_1 ... a_n`.
    Expressions that are not `idRhs` applications, and `idRhs` applications with fewer than two
    arguments, are returned unchanged. -/
private def extractIdRhs (e : Expr) : Expr :=
if isIdRhsApp e then
  let args := e.getAppArgs;
  -- `args[0]` is the type `T`; `args[1]` is the wrapped function `f`.
  if args.size < 2 then e else mkAppRange (args.get! 1) 2 args.size args
else
  e
/-- Delta-unfold the declaration `c` at universe levels `lvls`, strip a top-level `idRhs`
    wrapper from the value, and pass the result to `successK`.
    Calls `failK ()` when `lvls` does not have as many entries as `c`'s universe parameters. -/
@[specialize] private def deltaDefinition {α} (c : ConstantInfo) (lvls : List Level)
    (failK : Unit → α) (successK : Expr → α) : α :=
if c.lparams.length == lvls.length then
  successK $ extractIdRhs $ c.instantiateValueLevelParams lvls
else
  failK ()

/-- Like `deltaDefinition`, but the unfolded value is first beta-reduced with the application
    arguments, given in reverse order in `revArgs`. A top-level `idRhs` wrapper is then stripped
    and the result is passed to `successK`. Calls `failK ()` on a universe-level count mismatch. -/
@[specialize] private def deltaBetaDefinition {α} (c : ConstantInfo) (lvls : List Level) (revArgs : Array Expr)
    (failK : Unit → α) (successK : Expr → α) : α :=
if c.lparams.length == lvls.length then
  let unfolded := (c.instantiateValueLevelParams lvls).betaRev revArgs;
  successK (extractIdRhs unfolded)
else
  failK ()
/--
Apply beta-reduction, zeta-reduction (i.e., unfold let local-decls), iota-reduction,
expand let-expressions, expand assigned meta-variables.
This method does *not* apply delta-reduction at the head symbol `f` unless `isAuxDef? f` returns true.
Reason: we want to perform these reductions lazily at `isDefEq`.

`whnfEasyCases` first handles binders, sorts, literals, metadata, let-bound free variables and
assigned metavariables. The remaining cases are reduced as follows:
- `let x := v; b` is reduced to `b[x := v]`.
- For an application, the head is reduced first.
  - If the head becomes a lambda, the application is beta-reduced.
  - If it is a recursor (`reduceRec`) or a quotient eliminator (`reduceQuotRec`), the
    application is iota-reduced when possible.
  - If it is a definition for which `isAuxDef?` returns `true`, it is delta/beta-unfolded.
  - Otherwise the application is returned with its reduced head.
- A projection whose structure argument reduces (with `whnf`) to a constructor application is
  replaced by the corresponding constructor argument.

Parameters:
- `getConst` looks up declarations.
- `whnf` normalizes major premises and projected structures.
- `inferType` and `isDefEq` are forwarded to `reduceRec`.
- `getLocalDecl` and `getMVarAssignment` are forwarded to `whnfEasyCases`. -/
@[specialize] partial def whnfCore {m : Type → Type} [Monad m]
    (getConst : Name → m (Option ConstantInfo))
    (isAuxDef? : Name → m Bool)
    (whnf : Expr → m Expr)
    (inferType : Expr → m Expr)
    (isDefEq : Expr → Expr → m Bool)
    (getLocalDecl : Name → m LocalDecl)
    (getMVarAssignment : Name → m (Option Expr)) : Expr → m Expr
| e => whnfEasyCases getLocalDecl getMVarAssignment e $ fun e =>
  match e with
  | e@(Expr.const _ _ _)    => pure e
  | e@(Expr.letE _ _ v b _) => whnfCore $ b.instantiate1 v
  | e@(Expr.app f _ _)      => do
    let f := f.getAppFn;
    f' ← whnfCore f;
    if f'.isLambda then
      -- Beta-reduce with the *reduced* head `f'`. The original head `f` may be a let-bound
      -- fvar, an assigned mvar, `mdata`, a `let` or a projection that only reduces to a lambda.
      -- `f.betaRev` would then rebuild `e` unchanged and recurse forever.
      let revArgs := e.getAppRevArgs;
      whnfCore $ f'.betaRev revArgs
    else do
      -- Give up on further reduction: return `e`, with its head replaced by `f'` if it changed.
      let done : Unit → m Expr := fun _ =>
        if f == f' then pure e else pure $ e.updateFn f';
      matchConstAux getConst f' done $ fun cinfo lvls =>
        match cinfo with
        | ConstantInfo.recInfo rec    => reduceRec getConst whnf inferType isDefEq rec lvls e.getAppArgs done whnfCore
        | ConstantInfo.quotInfo rec   => reduceQuotRec getConst whnf rec lvls e.getAppArgs done whnfCore
        | c@(ConstantInfo.defnInfo _) => do
          -- Only auxiliary definitions are unfolded eagerly; others are left to `isDefEq`.
          unfold? ← isAuxDef? c.name;
          if unfold? then
            deltaBetaDefinition c lvls e.getAppRevArgs done whnfCore
          else
            done ()
        | _ => done ()
  | e@(Expr.proj _ i c _)   => do
    c ← whnf c;
    matchConstAux getConst c.getAppFn (fun _ => pure e) $ fun cinfo lvls =>
      match cinfo with
      -- Field `i` of a constructor application sits after the constructor's parameters.
      | ConstantInfo.ctorInfo ctorVal => pure $ c.getArgD (ctorVal.nparams + i) e
      | _ => pure e
  | _ => unreachable!
/--
Similar to `whnfCore`, but uses `synthesizePending` to (try to) synthesize metavariables
that are blocking reduction.
After each `whnfCore` step, `getStuckMVar` looks for a blocking metavariable.
- If one is found and `synthesizePending` reports success, the process restarts on the
  `whnfCore` result.
- Otherwise that result is returned.
NOTE(review): this keeps looping while `synthesizePending` returns `true` for a blocking
metavariable. It relies on a successful `synthesizePending` actually making progress. -/
@[specialize] private partial def whnfCoreUnstuck {m : Type → Type} [Monad m]
(getConst : Name → m (Option ConstantInfo))
(isAuxDef? : Name → m Bool)
(whnf : Expr → m Expr)
(inferType : Expr → m Expr)
(isDefEq : Expr → Expr → m Bool)
(synthesizePending : Expr → m Bool)
(getLocalDecl : Name → m LocalDecl)
(getMVarAssignment : Name → m (Option Expr))
: Expr → m Expr
| e => do
e ← whnfCore getConst isAuxDef? whnf inferType isDefEq getLocalDecl getMVarAssignment e;
(some mvar) ← getStuckMVar getConst whnf e | pure e;
succeeded ← synthesizePending mvar;
if succeeded then whnfCoreUnstuck e else pure e
/-- Unfold definition using "smart unfolding" if possible.

    Application `f a_1 ... a_n`, where `f` is a constant whose universe level count matches:
    - If the smart-unfolding auxiliary definition `mkSmartUnfoldingNameFor f` exists as a
      definition, it is unfolded and beta-reduced with the arguments, then reduced with
      `whnfCoreUnstuck`. Unfolding succeeds only if the result is an `idRhs` application, and
      the wrapper is then stripped. If this fails, there is no fallback to ordinary unfolding.
    - Otherwise `f` itself is delta/beta-unfolded when `fInfo.hasValue`.

    Constant `c` bound to a definition (`defnInfo`): `c` is delta-unfolded.

    On success the unfolded term is passed to `successK`; otherwise `failK ()` is called. -/
@[specialize] def unfoldDefinitionAux {α} {m : Type → Type} [Monad m]
(getConst : Name → m (Option ConstantInfo))
(isAuxDef? : Name → m Bool)
(whnf : Expr → m Expr)
(inferType : Expr → m Expr)
(isDefEq : Expr → Expr → m Bool)
(synthesizePending : Expr → m Bool)
(getLocalDecl : Name → m LocalDecl)
(getMVarAssignment : Name → m (Option Expr))
(e : Expr)
(failK : Unit → m α) (successK : Expr → m α) : m α :=
match e with
| Expr.app f _ _ =>
matchConstAux getConst f.getAppFn failK $ fun fInfo fLvls =>
if fInfo.lparams.length != fLvls.length then failK ()
else do
fAuxInfo? ← getConst (mkSmartUnfoldingNameFor fInfo.name);
match fAuxInfo? with
| some $ fAuxInfo@(ConstantInfo.defnInfo _) =>
deltaBetaDefinition fAuxInfo fLvls e.getAppRevArgs failK $ fun e₁ => do
-- Smart unfolding succeeds only when reduction reaches an `idRhs` marker.
e₂ ← whnfCoreUnstuck getConst isAuxDef? whnf inferType isDefEq synthesizePending getLocalDecl getMVarAssignment e₁;
if isIdRhsApp e₂ then
successK $ extractIdRhs e₂
else
failK ()
-- No smart-unfolding auxiliary definition: fall back to unfolding `f` directly.
| _ => if fInfo.hasValue then deltaBetaDefinition fInfo fLvls e.getAppRevArgs failK successK else failK ()
| Expr.const name lvls _ => do
(some (cinfo@(ConstantInfo.defnInfo _))) ← getConst name | failK ();
deltaDefinition cinfo lvls failK successK
| _ => failK ()
/-- Reference implementation for `whnf`. It does not cache any results.
It first applies `whnfCore`, then tries to unfold the head definition with `unfoldDefinitionAux`,
using smart unfolding when available. If unfolding succeeds, the unfolded term is normalized
again with `whnfMain`; otherwise the `whnfCore` result is returned.
How to use:
- `getConst constName` retrieves `constName` from the environment. The caller may make definitions opaque by returning `none`.
- `isAuxDef? constName` returns `true` if `constName` is an auxiliary declaration automatically generated by Lean and
used by the equation compiler, and must be eagerly reduced by `whnfCore`. This method is usually implemented using `isAuxRecursor`.
- `synthesizePending` is used to (try to) synthesize synthetic metavariables that may be blocking reduction.
The other parameters should be self-explanatory. -/
@[specialize] partial def whnfMain {m : Type → Type} [Monad m]
(getConst : Name → m (Option ConstantInfo))
(isAuxDef? : Name → m Bool)
(inferType : Expr → m Expr)
(isDefEq : Expr → Expr → m Bool)
(synthesizePending : Expr → m Bool)
(getLocalDecl : Name → m LocalDecl)
(getMVarAssignment : Name → m (Option Expr))
: Expr → m Expr
| e => do
e ← whnfCore getConst isAuxDef? whnfMain inferType isDefEq getLocalDecl getMVarAssignment e;
unfoldDefinitionAux getConst isAuxDef? whnfMain inferType isDefEq synthesizePending getLocalDecl getMVarAssignment e (fun _ => pure e) whnfMain
end Lean