feat: combine InferType and TypeUtil into Meta

This commit is contained in:
Leonardo de Moura 2019-11-08 16:18:46 -08:00
parent 1ff782334b
commit e5b77d4de8
6 changed files with 252 additions and 148 deletions

View file

@ -1227,6 +1227,9 @@ instance [DecidableEq α] [DecidableEq β] : DecidableEq (α × β) :=
| (isFalse n₂) => isFalse (fun h => Prod.noConfusion h (fun e₁' e₂' => absurd e₂' n₂))
| (isFalse n₁) => isFalse (fun h => Prod.noConfusion h (fun e₁' e₂' => absurd e₁' n₁))}
instance [HasBeq α] [HasBeq β] : HasBeq (α × β) :=
⟨fun ⟨a₁, b₁⟩ ⟨a₂, b₂⟩ => a₁ == a₂ && b₁ == b₂⟩
instance [HasLess α] [HasLess β] : HasLess (α × β) :=
⟨fun s t => s.1 < t.1 (s.1 = t.1 ∧ s.2 < t.2)⟩

View file

@ -24,3 +24,6 @@ protected def Nat.hash (n : Nat) : USize :=
USize.ofNat n
instance : Hashable Nat := ⟨Nat.hash⟩
instance {α β} [Hashable α] [Hashable β] : Hashable (α × β) :=
⟨fun ⟨a, b⟩ => mixHash (hash a) (hash b)⟩

View file

@ -1,91 +0,0 @@
/-
Copyright (c) 2019 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Init.Data.Nat
import Init.Lean.Declaration
import Init.Lean.LocalContext
import Init.Lean.Environment
namespace Lean
inductive InferTypeException
| functionExpected (fType : Expr) (args : Array Expr)
| unknownConstant (constName : Name)
| incorrectNumberOfLevels (constName : Name) (constLvls : List Level)
| invalidProjection (structName : Name) (idx : Nat) (s : Expr)
-- TODO: add remaining errors
@[specialize] private def getForallResultType {m : Type → Type} [Monad m] [MonadExcept InferTypeException m]
(whnf : Expr → m Expr)
(fType : Expr) (args : Array Expr) : m Expr :=
do (j, fType) ← args.size.foldM
(fun i (acc : Nat × Expr) =>
let (j, type) := acc;
match type with
| Expr.forallE _ _ _ b => pure (j, b)
| _ => do
type ← whnf $ type.instantiateRevRange j i args;
match type with
| Expr.forallE _ _ _ b => pure (i, b)
| _ => throw $ InferTypeException.functionExpected fType args)
(0, fType);
pure $ fType.instantiateRevRange j args.size args
@[specialize] private def inferAppType {m : Type → Type} [Monad m] [MonadExcept InferTypeException m]
(whnf : Expr → m Expr)
(inferType : Expr → m Expr)
(f : Expr) (args : Array Expr) : m Expr :=
do fType ← inferType f;
getForallResultType whnf fType args
private def inferConstType {m : Type → Type} [Monad m] [MonadExcept InferTypeException m]
(env : Environment) (c : Name) (lvls : List Level) : m Expr :=
match env.find c with
| some cinfo =>
if cinfo.lparams.length == lvls.length then
throw $ InferTypeException.incorrectNumberOfLevels c lvls
else
pure $ cinfo.instantiateTypeLevelParams lvls
| none => throw $ InferTypeException.unknownConstant c
@[specialize] private def inferProjType {m : Type → Type} [Monad m] [MonadExcept InferTypeException m]
(whnf : Expr → m Expr)
(inferType : Expr → m Expr)
(env : Environment)
(structName : Name) (idx : Nat) (e : Expr) : m Expr :=
do let failed : Unit → m Expr := fun _ => throw $ InferTypeException.invalidProjection structName idx e;
structType ← inferType e;
structType ← whnf structType;
matchConst env structType.getAppFn failed $ fun structInfo structLvls => do
match structInfo with
| ConstantInfo.inductInfo { nparams := n, ctors := [ctor], .. } =>
let structParams := structType.getAppArgs;
if n != structParams.size then failed ()
else match env.find ctor with
| none => failed ()
| some (ctorInfo) => do
let ctorType := ctorInfo.instantiateTypeLevelParams structLvls;
ctorType ← getForallResultType whnf ctorType structParams;
ctorType ← idx.foldM
(fun i ctorType => do
ctorType ← whnf ctorType;
match ctorType with
| Expr.forallE _ _ _ body =>
if body.hasLooseBVars then
pure $ body.instantiate1 $ Expr.proj structName i e
else
pure body
| _ => failed ())
ctorType;
ctorType ← whnf ctorType;
match ctorType with
| Expr.forallE _ _ d _ => pure d
| _ => failed ()
| _ => failed ()
-- TODO
end Lean

View file

@ -9,20 +9,44 @@ import Init.Lean.NameGenerator
import Init.Lean.Environment
import Init.Lean.Trace
import Init.Lean.AuxRecursor
import Init.Lean.ProjFns
#exit
import Init.Lean.WHNF
import Init.Lean.ReducibilityAttrs
/-
This module provides three (mutually dependent) goodies:
This module provides four (mutually dependent) goodies that are needed for building the elaborator and tactic frameworks.
1- Weak head normal form computation with support for metavariables and transparency modes.
2- Definitionally equality checking with support for metavariables (aka unification modulo definitional equality).
3- Type inference.
4- Type class resolution.
They are packed into the MetaM monad.
-/
namespace Lean
namespace Meta
inductive TransparencyMode
| All | Semireducible | Instances | Reducible | None
| all | default | reducible
namespace TransparencyMode
instance : Inhabited TransparencyMode := ⟨TransparencyMode.default⟩
def beq : TransparencyMode → TransparencyMode → Bool
| all, all => true
| default, default => true
| reducible, reducible => true
| _, _ => false
instance : HasBeq TransparencyMode := ⟨beq⟩
def hash : TransparencyMode → USize
| all => 7
| default => 11
| reducible => 13
instance : Hashable TransparencyMode := ⟨hash⟩
end TransparencyMode
structure LocalInstance :=
(className : Name)
@ -30,32 +54,33 @@ structure LocalInstance :=
abbrev LocalInstances := Array LocalInstance
structure UnifierConfig :=
(foApprox : Bool := false)
(ctxApprox : Bool := false)
(quasiPatternApprox : Bool := false)
structure Config :=
(opts : Options := {})
(foApprox : Bool := false)
(ctxApprox : Bool := false)
(quasiPatternApprox : Bool := false)
(transparency : TransparencyMode := TransparencyMode.default)
structure TypeUtilConfig :=
(opts : Options := {})
(unifierConfig : UnifierConfig := {})
(transparency : TransparencyMode := TransparencyMode.Semireducible)
(useZeta : Bool := true)
structure Cache :=
(whnf : PersistentHashMap (TransparencyMode × Expr) Expr := {})
(inferType : PersistentHashMap Expr Expr := {})
/- Abstract cache interfact for `TypeUtil` functions.
TODO: add missing methods. -/
class AbstractTypeUtilCache (ϕ : Type) :=
(getWHNF : ϕ → TransparencyMode → Expr → Option Expr)
(setWHNF : ϕ → TransparencyMode → Expr → Expr → ϕ)
structure ExceptionContext :=
(env : Environment) (mctx : MetavarContext) (lctx : LocalContext)
-- TODO: add special cases
inductive TypeUtilException
| other : String → TypeUtilException
inductive Exception
| unknownConst (constName : Name) (ctx : ExceptionContext)
| unknownFVar (fvarId : Name) (ctx : ExceptionContext)
| unknownMVar (mvarId : Name) (ctx : ExceptionContext)
| functionExpected (fType : Expr) (args : Array Expr) (ctx : ExceptionContext)
| incorrectNumOfLevels (constName : Name) (constLvls : List Level) (ctx : ExceptionContext)
| invalidProjection (structName : Name) (idx : Nat) (s : Expr) (ctx : ExceptionContext)
| other (msg : String)
structure TypeUtilContext :=
(env : Environment)
structure Context :=
(config : Config := {})
(lctx : LocalContext := {})
(localInstances : LocalInstances := #[])
(config : TypeUtilConfig := {})
structure PostponedEntry :=
(lhs : Level)
@ -63,39 +88,209 @@ structure PostponedEntry :=
(rhs : Level)
(updateRhs : Bool)
structure TypeUtilState (σ ϕ : Type) :=
(mctx : σ)
(cache : ϕ)
structure State :=
(env : Environment)
(mctx : MetavarContext := {})
(cache : Cache := {})
(ngen : NameGenerator := {})
(traceState : TraceState := {})
(postponed : Array PostponedEntry := #[])
/-- Type Context Monad -/
abbrev TypeUtilM (σ ϕ : Type) := ReaderT TypeUtilContext (EStateM TypeUtilException (TypeUtilState σ ϕ))
abbrev MetaM := ReaderT Context (EStateM Exception State)
namespace TypeUtil
variables {σ ϕ : Type}
@[inline] private def getLCtx : MetaM LocalContext :=
do ctx ← read; pure ctx.lctx
private def getOptions : TypeUtilM σ ϕ Options :=
do ctx ← read; pure ctx.config.opts
private def getTraceState : TypeUtilM σ ϕ TraceState :=
do s ← get; pure s.traceState
private def getMCtx : TypeUtilM σ ϕ σ :=
@[inline] private def getMCtx : MetaM MetavarContext :=
do s ← get; pure s.mctx
private def getEnv : TypeUtilM σ ϕ Environment :=
do ctx ← read; pure ctx.env
@[inline] private def getEnv : MetaM Environment :=
do s ← get; pure s.env
private def useZeta : TypeUtilM σ ϕ Bool :=
do ctx ← read; pure ctx.config.useZeta
@[inline] private def throwEx {α} (f : ExceptionContext → Exception) : MetaM α :=
do ctx ← read;
s ← get;
throw (f {env := s.env, mctx := s.mctx, lctx := ctx.lctx })
instance tracer : SimpleMonadTracerAdapter (TypeUtilM σ ϕ) :=
@[inline] private def reduceAll? : MetaM Bool :=
do ctx ← read; pure $ ctx.config.transparency == TransparencyMode.all
@[inline] private def reduceReducibleOnly? : MetaM Bool :=
do ctx ← read; pure $ ctx.config.transparency == TransparencyMode.reducible
@[inline] private def getTransparency : MetaM TransparencyMode :=
do ctx ← read; pure $ ctx.config.transparency
@[inline] private def getOptions : MetaM Options :=
do ctx ← read; pure ctx.config.opts
@[inline] def isReducible (n : Name) : MetaM Bool :=
do env ← getEnv; pure $ isReducible env n
@[inline] private def getTraceState : MetaM TraceState :=
do s ← get; pure s.traceState
instance tracer : SimpleMonadTracerAdapter MetaM :=
{ getOptions := getOptions,
getTraceState := getTraceState,
modifyTraceState := fun f => modify $ fun s => { traceState := f s.traceState, .. s } }
private def getConst (constName : Name) : MetaM (Option ConstantInfo) :=
do env ← getEnv;
match env.find constName with
| some (info@(ConstantInfo.thmInfo _)) =>
condM reduceAll? (pure (some info)) (pure none)
| some info =>
condM reduceReducibleOnly?
(condM (isReducible constName) (pure (some info)) (pure none))
(pure (some info))
| none =>
throwEx $ Exception.unknownConst constName
private def isAuxDef? (constName : Name) : MetaM Bool :=
do env ← getEnv; pure (isAuxRecursor env constName || isNoConfusion env constName)
private def getLocalDecl (fvarId : Name) : MetaM LocalDecl :=
do lctx ← getLCtx;
match lctx.find fvarId with
| some d => pure d
| none => throwEx $ Exception.unknownFVar fvarId
@[inline] private def getMVarAssignment (mvarId : Name) : MetaM (Option Expr) :=
do mctx ← getMCtx; pure (mctx.getExprAssignment mvarId)
@[inline] private def getCachedWHNF (e : Expr) : MetaM (Option Expr) :=
do t ← getTransparency;
s ← get;
pure $ s.cache.whnf.find (t, e)
@[inline] private def cacheWHNF (e : Expr) (r : Expr) : MetaM Unit :=
do t ← getTransparency;
modify $ fun s => { cache := { whnf := s.cache.whnf.insert (t, e) r, .. s.cache }, .. s }
@[specialize] private partial def whnfAux
(inferType : Expr → MetaM Expr)
(isDefEq : Expr → Expr → MetaM Bool)
(synthesizePending : Expr → MetaM Bool)
: Expr → MetaM Expr
| e => whnfEasyCases getLocalDecl getMVarAssignment e $ fun e => do
cached? ← getCachedWHNF e;
match cached? with
| some r => pure r
| none => do
e ← whnfCore getConst isAuxDef? whnfAux inferType isDefEq getLocalDecl getMVarAssignment e;
r ← unfoldDefinition getConst isAuxDef? whnfAux inferType isDefEq synthesizePending getLocalDecl getMVarAssignment e (fun _ => pure e) whnfAux;
cacheWHNF e r;
pure r
@[specialize] private def getForallResultType
(whnf : Expr → MetaM Expr)
(fType : Expr) (args : Array Expr) : MetaM Expr :=
do (j, fType) ← args.size.foldM
(fun i (acc : Nat × Expr) =>
let (j, type) := acc;
match type with
| Expr.forallE _ _ _ b => pure (j, b)
| _ => do
type ← whnf $ type.instantiateRevRange j i args;
match type with
| Expr.forallE _ _ _ b => pure (i, b)
| _ => throwEx $ Exception.functionExpected fType args)
(0, fType);
pure $ fType.instantiateRevRange j args.size args
@[specialize] private def inferAppType
(whnf : Expr → MetaM Expr)
(inferType : Expr → MetaM Expr)
(f : Expr) (args : Array Expr) : MetaM Expr :=
do fType ← inferType f;
getForallResultType whnf fType args
private def inferConstType (c : Name) (lvls : List Level) : MetaM Expr :=
do env ← getEnv;
match env.find c with
| some cinfo =>
if cinfo.lparams.length == lvls.length then
throwEx $ Exception.incorrectNumOfLevels c lvls
else
pure $ cinfo.instantiateTypeLevelParams lvls
| none =>
throwEx $ Exception.unknownConst c
@[specialize] private def inferProjType
(whnf : Expr → MetaM Expr)
(inferType : Expr → MetaM Expr)
(structName : Name) (idx : Nat) (e : Expr) : MetaM Expr :=
do let failed : Unit → MetaM Expr := fun _ => throwEx $ Exception.invalidProjection structName idx e;
structType ← inferType e;
structType ← whnf structType;
env ← getEnv;
matchConst env structType.getAppFn failed $ fun structInfo structLvls => do
match structInfo with
| ConstantInfo.inductInfo { nparams := n, ctors := [ctor], .. } =>
let structParams := structType.getAppArgs;
if n != structParams.size then failed ()
else match env.find ctor with
| none => failed ()
| some (ctorInfo) => do
let ctorType := ctorInfo.instantiateTypeLevelParams structLvls;
ctorType ← getForallResultType whnf ctorType structParams;
ctorType ← idx.foldM
(fun i ctorType => do
ctorType ← whnf ctorType;
match ctorType with
| Expr.forallE _ _ _ body =>
if body.hasLooseBVars then
pure $ body.instantiate1 $ Expr.proj structName i e
else
pure body
| _ => failed ())
ctorType;
ctorType ← whnf ctorType;
match ctorType with
| Expr.forallE _ _ d _ => pure d
| _ => failed ()
| _ => failed ()
private def inferMVarType (mvarId : Name) : MetaM Expr :=
do mctx ← getMCtx;
match mctx.findDecl mvarId with
| some d => pure d.type
| none => throwEx $ Exception.unknownMVar mvarId
private def inferFVarType (fvarId : Name) : MetaM Expr :=
do lctx ← getLCtx;
match lctx.find fvarId with
| some d => pure d.type
| none => throwEx $ Exception.unknownFVar fvarId
@[inline] private def checkInferTypeCache (e : Expr) (inferType : MetaM Expr) : MetaM Expr :=
do s ← get;
match s.cache.inferType.find e with
| some type => pure type
| none => do
type ← inferType;
modify $ fun s => { cache := { inferType := s.cache.inferType.insert e type, .. s.cache }, .. s };
pure type
@[specialize] private partial def inferTypeAux
(whnf : Expr → MetaM Expr)
: Expr → MetaM Expr
| Expr.const c lvls => inferConstType c lvls
| e@(Expr.proj n i s) => checkInferTypeCache e (inferProjType whnf inferTypeAux n i s)
| e@(Expr.app f _) => checkInferTypeCache e (inferAppType whnf inferTypeAux f e.getAppArgs)
| Expr.mvar mvarId => inferMVarType mvarId
| Expr.fvar fvarId => inferFVarType fvarId
| Expr.bvar _ => unreachable!
| Expr.mdata _ e => inferTypeAux e
| Expr.lit v => pure v.type
| Expr.sort lvl => pure $ Expr.sort (Level.succ lvl)
| Expr.lam n bi d b => throw $ Exception.other "not implemented yet"
| Expr.forallE n bi d b => throw $ Exception.other "not implemented yet"
| Expr.letE n t v b => throw $ Exception.other "not implemented yet"
#exit
@[inline] private def liftStateMCtx {α} (x : StateM σ α) : TypeUtilM σ ϕ α :=
fun _ s =>
let (a, mctx) := x.run s.mctx;
@ -103,17 +298,6 @@ fun _ s =>
export AbstractMetavarContext (hasAssignableLevelMVar isReadOnlyLevelMVar auxMVarSupport getExprAssignment)
private def whnfAux
[AbstractMetavarContext σ]
[AbstractTypeUtilCache ϕ]
(whnf : Expr → TypeUtilM σ ϕ Expr)
(inferType : Expr → TypeUtilM σ ϕ Expr)
(isDefEq : Expr → Expr → TypeUtilM σ ϕ Bool)
: Expr → TypeUtilM σ ϕ Expr :=
-- TODO
whnfCore whnf inferType isDefEq true
/- ===========================
inferType
=========================== -/

View file

@ -34,4 +34,9 @@ match reducibilityAttrs.setValue env n s with
| Except.ok env => env
| _ => env -- TODO(Leo): we should extend EnumAttributes.setValue in the future and ensure it never fails
def isReducible (env : Environment) (n : Name) : Bool :=
match getReducibilityStatus env n with
| ReducibilityStatus.reducible => true
| _ => false
end Lean

View file

@ -211,7 +211,7 @@ match rec.kind with
=========================== -/
/-- Auxiliary combinator for handling easy WHNF cases. It takes a function for handling the "hard" cases as an argument -/
@[specialize] private partial def whnfEasyCases {m : Type → Type} [Monad m]
@[specialize] partial def whnfEasyCases {m : Type → Type} [Monad m]
(getLocalDecl : Name → m LocalDecl)
(getMVarAssignment : Name → m (Option Expr))
: Expr → (Expr → m Expr) → m Expr
@ -269,7 +269,7 @@ else
This method does *not* apply delta-reduction at the head symbol `f` unless `isAuxDef? f` returns true.
Reason: we want to perform these reductions lazily at `isDefEq`. -/
@[specialize] private partial def whnfCore {m : Type → Type} [Monad m]
@[specialize] partial def whnfCore {m : Type → Type} [Monad m]
(getConst : Name → m (Option ConstantInfo))
(isAuxDef? : Name → m Bool)
(whnf : Expr → m Expr)