diff --git a/library/Init/Core.lean b/library/Init/Core.lean index f1fcf4d048..154214d6f5 100644 --- a/library/Init/Core.lean +++ b/library/Init/Core.lean @@ -1227,6 +1227,9 @@ instance [DecidableEq α] [DecidableEq β] : DecidableEq (α × β) := | (isFalse n₂) => isFalse (fun h => Prod.noConfusion h (fun e₁' e₂' => absurd e₂' n₂)) | (isFalse n₁) => isFalse (fun h => Prod.noConfusion h (fun e₁' e₂' => absurd e₁' n₁))} +instance [HasBeq α] [HasBeq β] : HasBeq (α × β) := +⟨fun ⟨a₁, b₁⟩ ⟨a₂, b₂⟩ => a₁ == a₂ && b₁ == b₂⟩ + instance [HasLess α] [HasLess β] : HasLess (α × β) := ⟨fun s t => s.1 < t.1 ∨ (s.1 = t.1 ∧ s.2 < t.2)⟩ diff --git a/library/Init/Data/Hashable.lean b/library/Init/Data/Hashable.lean index 1bcfeed4fc..37b2f51ffd 100644 --- a/library/Init/Data/Hashable.lean +++ b/library/Init/Data/Hashable.lean @@ -24,3 +24,6 @@ protected def Nat.hash (n : Nat) : USize := USize.ofNat n instance : Hashable Nat := ⟨Nat.hash⟩ + +instance {α β} [Hashable α] [Hashable β] : Hashable (α × β) := +⟨fun ⟨a, b⟩ => mixHash (hash a) (hash b)⟩ diff --git a/library/Init/Lean/InferType.lean b/library/Init/Lean/InferType.lean deleted file mode 100644 index c478fc570e..0000000000 --- a/library/Init/Lean/InferType.lean +++ /dev/null @@ -1,91 +0,0 @@ -/- -Copyright (c) 2019 Microsoft Corporation. All rights reserved. -Released under Apache 2.0 license as described in the file LICENSE. -Authors: Leonardo de Moura --/ -prelude -import Init.Data.Nat -import Init.Lean.Declaration -import Init.Lean.LocalContext -import Init.Lean.Environment - -namespace Lean - -inductive InferTypeException -| functionExpected (fType : Expr) (args : Array Expr) -| unknownConstant (constName : Name) -| incorrectNumberOfLevels (constName : Name) (constLvls : List Level) -| invalidProjection (structName : Name) (idx : Nat) (s : Expr) --- TODO: add remaining errors - -@[specialize] private def getForallResultType {m : Type → Type} [Monad m] [MonadExcept InferTypeException m] - (whnf : Expr → m Expr) - (fType : Expr) (args : Array Expr) : m Expr := -do (j, fType) ← args.size.foldM - (fun i (acc : Nat × Expr) => - let (j, type) := acc; - match type with - | Expr.forallE _ _ _ b => pure (j, b) - | _ => do - type ← whnf $ type.instantiateRevRange j i args; - match type with - | Expr.forallE _ _ _ b => pure (i, b) - | _ => throw $ InferTypeException.functionExpected fType args) - (0, fType); - pure $ fType.instantiateRevRange j args.size args - -@[specialize] private def inferAppType {m : Type → Type} [Monad m] [MonadExcept InferTypeException m] - (whnf : Expr → m Expr) - (inferType : Expr → m Expr) - (f : Expr) (args : Array Expr) : m Expr := -do fType ← inferType f; - getForallResultType whnf fType args - -private def inferConstType {m : Type → Type} [Monad m] [MonadExcept InferTypeException m] - (env : Environment) (c : Name) (lvls : List Level) : m Expr := -match env.find c with -| some cinfo => - if cinfo.lparams.length == lvls.length then - throw $ InferTypeException.incorrectNumberOfLevels c lvls - else - pure $ cinfo.instantiateTypeLevelParams lvls -| none => throw $ InferTypeException.unknownConstant c - -@[specialize] private def inferProjType {m : Type → Type} [Monad m] [MonadExcept InferTypeException m] - (whnf : Expr → m Expr) - (inferType : Expr → m Expr) - (env : Environment) - (structName : Name) (idx : Nat) (e : Expr) : m Expr := -do let failed : Unit → m Expr := fun _ => throw $ InferTypeException.invalidProjection structName idx e; - structType ← inferType e; - structType ← whnf structType; - matchConst env structType.getAppFn failed $ fun structInfo structLvls => do - match structInfo with - | ConstantInfo.inductInfo { nparams := n, ctors := [ctor], .. } => - let structParams := structType.getAppArgs; - if n != structParams.size then failed () - else match env.find ctor with - | none => failed () - | some (ctorInfo) => do - let ctorType := ctorInfo.instantiateTypeLevelParams structLvls; - ctorType ← getForallResultType whnf ctorType structParams; - ctorType ← idx.foldM - (fun i ctorType => do - ctorType ← whnf ctorType; - match ctorType with - | Expr.forallE _ _ _ body => - if body.hasLooseBVars then - pure $ body.instantiate1 $ Expr.proj structName i e - else - pure body - | _ => failed ()) - ctorType; - ctorType ← whnf ctorType; - match ctorType with - | Expr.forallE _ _ d _ => pure d - | _ => failed () - | _ => failed () - --- TODO - -end Lean diff --git a/library/Init/Lean/TypeUtil.lean b/library/Init/Lean/Meta.lean similarity index 52% rename from library/Init/Lean/TypeUtil.lean rename to library/Init/Lean/Meta.lean index f595cf1458..da43bec80d 100644 --- a/library/Init/Lean/TypeUtil.lean +++ b/library/Init/Lean/Meta.lean @@ -9,20 +9,44 @@ import Init.Lean.NameGenerator import Init.Lean.Environment import Init.Lean.Trace import Init.Lean.AuxRecursor -import Init.Lean.ProjFns - -#exit +import Init.Lean.WHNF +import Init.Lean.ReducibilityAttrs /- -This module provides three (mutually dependent) goodies: +This module provides four (mutually dependent) goodies that are needed for building the elaborator and tactic frameworks. 1- Weak head normal form computation with support for metavariables and transparency modes. 2- Definitionally equality checking with support for metavariables (aka unification modulo definitional equality). 3- Type inference. +4- Type class resolution. + +They are packed into the MetaM monad. -/ namespace Lean +namespace Meta + inductive TransparencyMode -| All | Semireducible | Instances | Reducible | None +| all | default | reducible + +namespace TransparencyMode +instance : Inhabited TransparencyMode := ⟨TransparencyMode.default⟩ + +def beq : TransparencyMode → TransparencyMode → Bool +| all, all => true +| default, default => true +| reducible, reducible => true +| _, _ => false + +instance : HasBeq TransparencyMode := ⟨beq⟩ + +def hash : TransparencyMode → USize +| all => 7 +| default => 11 +| reducible => 13 + +instance : Hashable TransparencyMode := ⟨hash⟩ + +end TransparencyMode structure LocalInstance := (className : Name) @@ -30,32 +54,33 @@ structure LocalInstance := abbrev LocalInstances := Array LocalInstance -structure UnifierConfig := -(foApprox : Bool := false) -(ctxApprox : Bool := false) -(quasiPatternApprox : Bool := false) +structure Config := +(opts : Options := {}) +(foApprox : Bool := false) +(ctxApprox : Bool := false) +(quasiPatternApprox : Bool := false) +(transparency : TransparencyMode := TransparencyMode.default) -structure TypeUtilConfig := -(opts : Options := {}) -(unifierConfig : UnifierConfig := {}) -(transparency : TransparencyMode := TransparencyMode.Semireducible) -(useZeta : Bool := true) +structure Cache := +(whnf : PersistentHashMap (TransparencyMode × Expr) Expr := {}) +(inferType : PersistentHashMap Expr Expr := {}) -/- Abstract cache interfact for `TypeUtil` functions. - TODO: add missing methods. -/ -class AbstractTypeUtilCache (ϕ : Type) := -(getWHNF : ϕ → TransparencyMode → Expr → Option Expr) -(setWHNF : ϕ → TransparencyMode → Expr → Expr → ϕ) +structure ExceptionContext := +(env : Environment) (mctx : MetavarContext) (lctx : LocalContext) --- TODO: add special cases -inductive TypeUtilException -| other : String → TypeUtilException +inductive Exception +| unknownConst (constName : Name) (ctx : ExceptionContext) +| unknownFVar (fvarId : Name) (ctx : ExceptionContext) +| unknownMVar (mvarId : Name) (ctx : ExceptionContext) +| functionExpected (fType : Expr) (args : Array Expr) (ctx : ExceptionContext) +| incorrectNumOfLevels (constName : Name) (constLvls : List Level) (ctx : ExceptionContext) +| invalidProjection (structName : Name) (idx : Nat) (s : Expr) (ctx : ExceptionContext) +| other (msg : String) -structure TypeUtilContext := -(env : Environment) +structure Context := +(config : Config := {}) (lctx : LocalContext := {}) (localInstances : LocalInstances := #[]) -(config : TypeUtilConfig := {}) structure PostponedEntry := (lhs : Level) @@ -63,39 +88,209 @@ structure PostponedEntry := (rhs : Level) (updateRhs : Bool) -structure TypeUtilState (σ ϕ : Type) := -(mctx : σ) -(cache : ϕ) +structure State := +(env : Environment) +(mctx : MetavarContext := {}) +(cache : Cache := {}) (ngen : NameGenerator := {}) (traceState : TraceState := {}) (postponed : Array PostponedEntry := #[]) -/-- Type Context Monad -/ -abbrev TypeUtilM (σ ϕ : Type) := ReaderT TypeUtilContext (EStateM TypeUtilException (TypeUtilState σ ϕ)) +abbrev MetaM := ReaderT Context (EStateM Exception State) -namespace TypeUtil -variables {σ ϕ : Type} +@[inline] private def getLCtx : MetaM LocalContext := +do ctx ← read; pure ctx.lctx -private def getOptions : TypeUtilM σ ϕ Options := -do ctx ← read; pure ctx.config.opts - -private def getTraceState : TypeUtilM σ ϕ TraceState := -do s ← get; pure s.traceState - -private def getMCtx : TypeUtilM σ ϕ σ := +@[inline] private def getMCtx : MetaM MetavarContext := do s ← get; pure s.mctx -private def getEnv : TypeUtilM σ ϕ Environment := -do ctx ← read; pure ctx.env +@[inline] private def getEnv : MetaM Environment := +do s ← get; pure s.env -private def useZeta : TypeUtilM σ ϕ Bool := -do ctx ← read; pure ctx.config.useZeta +@[inline] private def throwEx {α} (f : ExceptionContext → Exception) : MetaM α := +do ctx ← read; + s ← get; + throw (f {env := s.env, mctx := s.mctx, lctx := ctx.lctx }) -instance tracer : SimpleMonadTracerAdapter (TypeUtilM σ ϕ) := +@[inline] private def reduceAll? : MetaM Bool := +do ctx ← read; pure $ ctx.config.transparency == TransparencyMode.all + +@[inline] private def reduceReducibleOnly? : MetaM Bool := +do ctx ← read; pure $ ctx.config.transparency == TransparencyMode.reducible + +@[inline] private def getTransparency : MetaM TransparencyMode := +do ctx ← read; pure $ ctx.config.transparency + +@[inline] private def getOptions : MetaM Options := +do ctx ← read; pure ctx.config.opts + +@[inline] def isReducible (n : Name) : MetaM Bool := +do env ← getEnv; pure $ isReducible env n + +@[inline] private def getTraceState : MetaM TraceState := +do s ← get; pure s.traceState + +instance tracer : SimpleMonadTracerAdapter MetaM := { getOptions := getOptions, getTraceState := getTraceState, modifyTraceState := fun f => modify $ fun s => { traceState := f s.traceState, .. s } } +private def getConst (constName : Name) : MetaM (Option ConstantInfo) := +do env ← getEnv; + match env.find constName with + | some (info@(ConstantInfo.thmInfo _)) => + condM reduceAll? (pure (some info)) (pure none) + | some info => + condM reduceReducibleOnly? + (condM (isReducible constName) (pure (some info)) (pure none)) + (pure (some info)) + | none => + throwEx $ Exception.unknownConst constName + +private def isAuxDef? (constName : Name) : MetaM Bool := +do env ← getEnv; pure (isAuxRecursor env constName || isNoConfusion env constName) + +private def getLocalDecl (fvarId : Name) : MetaM LocalDecl := +do lctx ← getLCtx; + match lctx.find fvarId with + | some d => pure d + | none => throwEx $ Exception.unknownFVar fvarId + +@[inline] private def getMVarAssignment (mvarId : Name) : MetaM (Option Expr) := +do mctx ← getMCtx; pure (mctx.getExprAssignment mvarId) + +@[inline] private def getCachedWHNF (e : Expr) : MetaM (Option Expr) := +do t ← getTransparency; + s ← get; + pure $ s.cache.whnf.find (t, e) + +@[inline] private def cacheWHNF (e : Expr) (r : Expr) : MetaM Unit := +do t ← getTransparency; + modify $ fun s => { cache := { whnf := s.cache.whnf.insert (t, e) r, .. s.cache }, .. s } + +@[specialize] private partial def whnfAux + (inferType : Expr → MetaM Expr) + (isDefEq : Expr → Expr → MetaM Bool) + (synthesizePending : Expr → MetaM Bool) + : Expr → MetaM Expr +| e => whnfEasyCases getLocalDecl getMVarAssignment e $ fun e => do + cached? ← getCachedWHNF e; + match cached? with + | some r => pure r + | none => do + e ← whnfCore getConst isAuxDef? whnfAux inferType isDefEq getLocalDecl getMVarAssignment e; + r ← unfoldDefinition getConst isAuxDef? whnfAux inferType isDefEq synthesizePending getLocalDecl getMVarAssignment e (fun _ => pure e) whnfAux; + cacheWHNF e r; + pure r + +@[specialize] private def getForallResultType + (whnf : Expr → MetaM Expr) + (fType : Expr) (args : Array Expr) : MetaM Expr := +do (j, fType) ← args.size.foldM + (fun i (acc : Nat × Expr) => + let (j, type) := acc; + match type with + | Expr.forallE _ _ _ b => pure (j, b) + | _ => do + type ← whnf $ type.instantiateRevRange j i args; + match type with + | Expr.forallE _ _ _ b => pure (i, b) + | _ => throwEx $ Exception.functionExpected fType args) + (0, fType); + pure $ fType.instantiateRevRange j args.size args + +@[specialize] private def inferAppType + (whnf : Expr → MetaM Expr) + (inferType : Expr → MetaM Expr) + (f : Expr) (args : Array Expr) : MetaM Expr := +do fType ← inferType f; + getForallResultType whnf fType args + +private def inferConstType (c : Name) (lvls : List Level) : MetaM Expr := +do env ← getEnv; + match env.find c with + | some cinfo => + if cinfo.lparams.length == lvls.length then + throwEx $ Exception.incorrectNumOfLevels c lvls + else + pure $ cinfo.instantiateTypeLevelParams lvls + | none => + throwEx $ Exception.unknownConst c + +@[specialize] private def inferProjType + (whnf : Expr → MetaM Expr) + (inferType : Expr → MetaM Expr) + (structName : Name) (idx : Nat) (e : Expr) : MetaM Expr := +do let failed : Unit → MetaM Expr := fun _ => throwEx $ Exception.invalidProjection structName idx e; + structType ← inferType e; + structType ← whnf structType; + env ← getEnv; + matchConst env structType.getAppFn failed $ fun structInfo structLvls => do + match structInfo with + | ConstantInfo.inductInfo { nparams := n, ctors := [ctor], .. } => + let structParams := structType.getAppArgs; + if n != structParams.size then failed () + else match env.find ctor with + | none => failed () + | some (ctorInfo) => do + let ctorType := ctorInfo.instantiateTypeLevelParams structLvls; + ctorType ← getForallResultType whnf ctorType structParams; + ctorType ← idx.foldM + (fun i ctorType => do + ctorType ← whnf ctorType; + match ctorType with + | Expr.forallE _ _ _ body => + if body.hasLooseBVars then + pure $ body.instantiate1 $ Expr.proj structName i e + else + pure body + | _ => failed ()) + ctorType; + ctorType ← whnf ctorType; + match ctorType with + | Expr.forallE _ _ d _ => pure d + | _ => failed () + | _ => failed () + +private def inferMVarType (mvarId : Name) : MetaM Expr := +do mctx ← getMCtx; + match mctx.findDecl mvarId with + | some d => pure d.type + | none => throwEx $ Exception.unknownMVar mvarId + +private def inferFVarType (fvarId : Name) : MetaM Expr := +do lctx ← getLCtx; + match lctx.find fvarId with + | some d => pure d.type + | none => throwEx $ Exception.unknownFVar fvarId + +@[inline] private def checkInferTypeCache (e : Expr) (inferType : MetaM Expr) : MetaM Expr := +do s ← get; + match s.cache.inferType.find e with + | some type => pure type + | none => do + type ← inferType; + modify $ fun s => { cache := { inferType := s.cache.inferType.insert e type, .. s.cache }, .. s }; + pure type + +@[specialize] private partial def inferTypeAux + (whnf : Expr → MetaM Expr) + : Expr → MetaM Expr +| Expr.const c lvls => inferConstType c lvls +| e@(Expr.proj n i s) => checkInferTypeCache e (inferProjType whnf inferTypeAux n i s) +| e@(Expr.app f _) => checkInferTypeCache e (inferAppType whnf inferTypeAux f e.getAppArgs) +| Expr.mvar mvarId => inferMVarType mvarId +| Expr.fvar fvarId => inferFVarType fvarId +| Expr.bvar _ => unreachable! +| Expr.mdata _ e => inferTypeAux e +| Expr.lit v => pure v.type +| Expr.sort lvl => pure $ Expr.sort (Level.succ lvl) +| Expr.lam n bi d b => throw $ Exception.other "not implemented yet" +| Expr.forallE n bi d b => throw $ Exception.other "not implemented yet" +| Expr.letE n t v b => throw $ Exception.other "not implemented yet" + +#exit + @[inline] private def liftStateMCtx {α} (x : StateM σ α) : TypeUtilM σ ϕ α := fun _ s => let (a, mctx) := x.run s.mctx; @@ -103,17 +298,6 @@ fun _ s => export AbstractMetavarContext (hasAssignableLevelMVar isReadOnlyLevelMVar auxMVarSupport getExprAssignment) - -private def whnfAux - [AbstractMetavarContext σ] - [AbstractTypeUtilCache ϕ] - (whnf : Expr → TypeUtilM σ ϕ Expr) - (inferType : Expr → TypeUtilM σ ϕ Expr) - (isDefEq : Expr → Expr → TypeUtilM σ ϕ Bool) - : Expr → TypeUtilM σ ϕ Expr := --- TODO -whnfCore whnf inferType isDefEq true - /- =========================== inferType =========================== -/ diff --git a/library/Init/Lean/ReducibilityAttrs.lean b/library/Init/Lean/ReducibilityAttrs.lean index eaf523708e..8c31216802 100644 --- a/library/Init/Lean/ReducibilityAttrs.lean +++ b/library/Init/Lean/ReducibilityAttrs.lean @@ -34,4 +34,9 @@ match reducibilityAttrs.setValue env n s with | Except.ok env => env | _ => env -- TODO(Leo): we should extend EnumAttributes.setValue in the future and ensure it never fails +def isReducible (env : Environment) (n : Name) : Bool := +match getReducibilityStatus env n with +| ReducibilityStatus.reducible => true +| _ => false + end Lean diff --git a/library/Init/Lean/WHNF.lean b/library/Init/Lean/WHNF.lean index 5bc608aa3a..073a198809 100644 --- a/library/Init/Lean/WHNF.lean +++ b/library/Init/Lean/WHNF.lean @@ -211,7 +211,7 @@ match rec.kind with =========================== -/ /-- Auxiliary combinator for handling easy WHNF cases. It takes a function for handling the "hard" cases as an argument -/ -@[specialize] private partial def whnfEasyCases {m : Type → Type} [Monad m] +@[specialize] partial def whnfEasyCases {m : Type → Type} [Monad m] (getLocalDecl : Name → m LocalDecl) (getMVarAssignment : Name → m (Option Expr)) : Expr → (Expr → m Expr) → m Expr @@ -269,7 +269,7 @@ else This method does *not* apply delta-reduction at the head symbol `f` unless `isAuxDef? f` returns true. Reason: we want to perform these reductions lazily at `isDefEq`. -/ -@[specialize] private partial def whnfCore {m : Type → Type} [Monad m] +@[specialize] partial def whnfCore {m : Type → Type} [Monad m] (getConst : Name → m (Option ConstantInfo)) (isAuxDef? : Name → m Bool) (whnf : Expr → m Expr)