diff --git a/src/Lean/Elab/BuiltinNotation.lean b/src/Lean/Elab/BuiltinNotation.lean index 8c3ae6629b..7b5aae9de0 100644 --- a/src/Lean/Elab/BuiltinNotation.lean +++ b/src/Lean/Elab/BuiltinNotation.lean @@ -45,18 +45,11 @@ fun stx expectedType? => match_syntax stx with expectedType ← instantiateMVars expectedType; let expectedType := expectedType.consumeMData; expectedType ← whnf expectedType; - match expectedType.getAppFn with - | Expr.const constName _ _ => do - env ← getEnv; - match env.find? constName with - | some (ConstantInfo.inductInfo val) => - match val.ctors with - | [ctor] => do - newStx ← `($(mkCIdentFrom stx ctor) $(args.getSepElems)*); - withMacroExpansion stx newStx $ elabTerm newStx expectedType? - | _ => throwError ("invalid constructor ⟨...⟩, '" ++ constName ++ "' must have only one constructor") - | _ => throwError ("invalid constructor ⟨...⟩, '" ++ constName ++ "' is not an inductive type") - | _ => throwError ("invalid constructor ⟨...⟩, expected type is not an inductive type " ++ indentExpr expectedType) + matchConstStruct expectedType.getAppFn + (fun _ => throwError ("invalid constructor ⟨...⟩, expected type must be a structure " ++ indentExpr expectedType)) + (fun val _ ctor => do + newStx ← `($(mkCIdentFrom stx ctor.name) $(args.getSepElems)*); + withMacroExpansion stx newStx $ elabTerm newStx expectedType?) | none => throwError "invalid constructor ⟨...⟩, expected type must be known" | _ => throwUnsupportedSyntax diff --git a/src/Lean/Elab/Match.lean b/src/Lean/Elab/Match.lean index 0cfbc375ae..7b6a073a8c 100644 --- a/src/Lean/Elab/Match.lean +++ b/src/Lean/Elab/Match.lean @@ -194,7 +194,7 @@ modify fun s => { s with vars := s.vars.push (PatternVar.localVar id), found := -- It produces "unknown free variable: _kernel_fresh." at step `csimp.cpp` def processIdAuxAux (stx : Syntax) (mustBeCtor : Bool) (env : Environment) (f : Expr) : M Nat := match f with -| Expr.const fName _ _ => +| Expr.const fName _ _ => do match env.find? fName with | some $ ConstantInfo.ctorInfo val => liftM $ getNumExplicitCtorParams val | some $ info => @@ -498,26 +498,20 @@ partial def main : Expr → M Pattern newE ← whnf e; if newE != e then main newE - else match e.getAppFn with - | Expr.const declName us _ => do - env ← getEnv; - match env.find? declName with - | ConstantInfo.ctorInfo v => do - let args := e.getAppArgs; - unless (args.size == v.nparams + v.nfields) $ throwInvalidPattern e; - let params := args.extract 0 v.nparams; - let fields := args.extract v.nparams args.size; - let binderInfos := getFieldsBinderInfo v; - fields ← fields.mapIdxM fun i field => do { - let binderInfo := binderInfos.get! i; - if binderInfo.isExplicit then - main field - else - mkInaccessible field - }; - pure $ Pattern.ctor declName us params.toList fields.toList - | _ => throwInvalidPattern e - | _ => throwInvalidPattern e + else matchConstCtor e.getAppFn (fun _ => throwInvalidPattern e) fun v us => do + let args := e.getAppArgs; + unless (args.size == v.nparams + v.nfields) $ throwInvalidPattern e; + let params := args.extract 0 v.nparams; + let fields := args.extract v.nparams args.size; + let binderInfos := getFieldsBinderInfo v; + fields ← fields.mapIdxM fun i field => do { + let binderInfo := binderInfos.get! i; + if binderInfo.isExplicit then + main field + else + mkInaccessible field + }; + pure $ Pattern.ctor v.name us params.toList fields.toList end ToDepElimPattern diff --git a/src/Lean/Elab/Print.lean b/src/Lean/Elab/Print.lean index ae61f251f9..b8c2a7228e 100644 --- a/src/Lean/Elab/Print.lean +++ b/src/Lean/Elab/Print.lean @@ -48,10 +48,9 @@ env ← getEnv; m ← mkHeader "inductive" id lparams type isUnsafe; let m := m ++ Format.line ++ "constructors:"; m ← ctors.foldlM - (fun (m : MessageData) ctor => - match env.find? ctor with - | some v => pure $ m ++ Format.line ++ ctor ++ " : " ++ v.type - | none => pure m) + (fun (m : MessageData) ctor => do + cinfo ← getConstInfo ctor; + pure $ m ++ Format.line ++ ctor ++ " : " ++ cinfo.type) m; logInfo m diff --git a/src/Lean/Elab/Tactic/Induction.lean b/src/Lean/Elab/Tactic/Induction.lean index 26f4bfabdd..6d5ac7603c 100644 --- a/src/Lean/Elab/Tactic/Induction.lean +++ b/src/Lean/Elab/Tactic/Induction.lean @@ -109,13 +109,9 @@ def getInductiveValFromMajor (major : Expr) : TacticM InductiveVal := liftMetaMAtMain $ fun mvarId => do majorType ← inferType major; majorType ← whnf majorType; - match majorType.getAppFn with - | Expr.const n _ _ => do - env ← getEnv; - match env.find? n with - | ConstantInfo.inductInfo val => pure val - | _ => Meta.throwTacticEx `induction mvarId ("major premise type is not an inductive type " ++ indentExpr majorType) - | _ => Meta.throwTacticEx `induction mvarId ("major premise type is not an inductive type " ++ indentExpr majorType) + matchConstInduct majorType.getAppFn + (fun _ => Meta.throwTacticEx `induction mvarId ("major premise type is not an inductive type " ++ indentExpr majorType)) + (fun val _ => pure val) private partial def getRecFromUsingLoop (baseRecName : Name) : Expr → TacticM (Option Meta.RecursorInfo) | majorType => do diff --git a/src/Lean/Elab/Term.lean b/src/Lean/Elab/Term.lean index afbd7c814a..e9305f20a6 100644 --- a/src/Lean/Elab/Term.lean +++ b/src/Lean/Elab/Term.lean @@ -1105,16 +1105,13 @@ num.foldM (fun _ us => do u ← mkFreshLevelMVar; pure $ u::us) [] Remark: fresh universe metavariables are created if the constant has more universe parameters than `explicitLevels`. -/ def mkConst (constName : Name) (explicitLevels : List Level := []) : TermElabM Expr := do -env ← getEnv; -match env.find? constName with -| none => throwError ("unknown constant '" ++ constName ++ "'") -| some cinfo => - if explicitLevels.length > cinfo.lparams.length then - throwError ("too many explicit universe levels") - else do - let numMissingLevels := cinfo.lparams.length - explicitLevels.length; - us ← mkFreshLevelMVars numMissingLevels; - pure $ Lean.mkConst constName (explicitLevels ++ us) +cinfo ← getConstInfo constName; +if explicitLevels.length > cinfo.lparams.length then + throwError ("too many explicit universe levels") +else do + let numMissingLevels := cinfo.lparams.length - explicitLevels.length; + us ← mkFreshLevelMVars numMissingLevels; + pure $ Lean.mkConst constName (explicitLevels ++ us) private def mkConsts (candidates : List (Name × List String)) (explicitLevels : List Level) : TermElabM (List (Expr × List String)) := do env ← getEnv; diff --git a/src/Lean/Environment.lean b/src/Lean/Environment.lean index 71271cafb3..db335ecdb1 100644 --- a/src/Lean/Environment.lean +++ b/src/Lean/Environment.lean @@ -665,17 +665,6 @@ c?.isSome end Environment -/- Helper functions for accessing environment -/ - -@[inline] -def matchConst {α : Type} (env : Environment) (e : Expr) (failK : Unit → α) (k : ConstantInfo → List Level → α) : α := -match e with -| Expr.const n lvls _ => - match env.find? n with - | some cinfo => k cinfo lvls - | _ => failK () -| _ => failK () - namespace Kernel /- Kernel API -/ diff --git a/src/Lean/Meta/AppBuilder.lean b/src/Lean/Meta/AppBuilder.lean index c38d1b6e3f..d3e50c76b5 100644 --- a/src/Lean/Meta/AppBuilder.lean +++ b/src/Lean/Meta/AppBuilder.lean @@ -295,14 +295,10 @@ match type.eq? with | none => throwAppBuilderException `noConfusion ("equality expected" ++ hasTypeMsg h type) | some (α, a, b) => do α ← whnf α; - env ← getEnv; - let f := α.getAppFn; - matchConst env f (fun _ => throwAppBuilderException `noConfusion ("inductive type expected" ++ indentExpr α)) $ fun cinfo us => - match cinfo with - | ConstantInfo.inductInfo v => do - u ← getLevel target; - pure $ mkAppN (mkConst (mkNameStr v.name "noConfusion") (u :: us)) (α.getAppArgs ++ #[target, a, b, h]) - | _ => throwAppBuilderException `noConfusion ("inductive type expected" ++ indentExpr α) + matchConstInduct α.getAppFn (fun _ => throwAppBuilderException `noConfusion ("inductive type expected" ++ indentExpr α)) fun v us => do + u ← getLevel target; + pure $ mkAppN (mkConst (mkNameStr v.name "noConfusion") (u :: us)) (α.getAppArgs ++ #[target, a, b, h]) + def mkNoConfusion (target : Expr) (h : Expr) : m Expr := liftMetaM $ mkNoConfusionImp target h def mkPure (monad : Expr) (e : Expr) : m Expr := diff --git a/src/Lean/Meta/Check.lean b/src/Lean/Meta/Check.lean index 57014d1639..8dbc76c460 100644 --- a/src/Lean/Meta/Check.lean +++ b/src/Lean/Meta/Check.lean @@ -58,11 +58,9 @@ forallTelescope e $ fun xs b => do ensureType b; check b -private def checkConstant (c : Name) (lvls : List Level) : MetaM Unit := do -env ← getEnv; -match env.find? c with -| none => throwUnknownConstant c -| some cinfo => unless (lvls.length == cinfo.lparams.length) $ throwIncorrectNumberOfLevels c lvls +private def checkConstant (constName : Name) (us : List Level) : MetaM Unit := do +cinfo ← getConstInfo constName; +unless (us.length == cinfo.lparams.length) $ throwIncorrectNumberOfLevels constName us private def getFunctionDomain (f : Expr) : MetaM Expr := do fType ← inferType f; diff --git a/src/Lean/Meta/InferType.lean b/src/Lean/Meta/InferType.lean index 8ed2ddd17f..441dbf6f79 100644 --- a/src/Lean/Meta/InferType.lean +++ b/src/Lean/Meta/InferType.lean @@ -31,46 +31,37 @@ def throwIncorrectNumberOfLevels {α} (constName : Name) (us : List Level) : Met throwError $ "incorrect number of universe levels " ++ mkConst constName us private def inferConstType (c : Name) (us : List Level) : MetaM Expr := do -env ← getEnv; -match env.find? c with -| some cinfo => - if cinfo.lparams.length == us.length then - pure $ cinfo.instantiateTypeLevelParams us - else - throwIncorrectNumberOfLevels c us -| none => - throwUnknownConstant c +cinfo ← getConstInfo c; +if cinfo.lparams.length == us.length then + pure $ cinfo.instantiateTypeLevelParams us +else + throwIncorrectNumberOfLevels c us private def inferProjType (structName : Name) (idx : Nat) (e : Expr) : MetaM Expr := do let failed : Unit → MetaM Expr := fun _ => throwError $ "invalide projection" ++ indentExpr (mkProj structName idx e); structType ← inferType e; structType ← whnf structType; -env ← getEnv; -matchConst env structType.getAppFn failed $ fun structInfo structLvls => do - match structInfo with - | ConstantInfo.inductInfo { nparams := n, ctors := [ctor], .. } => - let structParams := structType.getAppArgs; - if n != structParams.size then failed () - else match env.find? ctor with - | none => failed () - | some (ctorInfo) => do - ctorType ← inferAppType (mkConst ctor structLvls) structParams; - ctorType ← idx.foldM - (fun i ctorType => do - ctorType ← whnf ctorType; - match ctorType with - | Expr.forallE _ _ body _ => - if body.hasLooseBVars then - pure $ body.instantiate1 $ mkProj structName i e - else - pure body - | _ => failed ()) - ctorType; +matchConstStruct structType.getAppFn failed fun structVal structLvls ctorVal => + let n := structVal.nparams; + let structParams := structType.getAppArgs; + if n != structParams.size then failed () + else do + ctorType ← inferAppType (mkConst ctorVal.name structLvls) structParams; + ctorType ← idx.foldM + (fun i ctorType => do ctorType ← whnf ctorType; match ctorType with - | Expr.forallE _ d _ _ => pure d - | _ => failed () - | _ => failed () + | Expr.forallE _ _ body _ => + if body.hasLooseBVars then + pure $ body.instantiate1 $ mkProj structName i e + else + pure body + | _ => failed ()) + ctorType; + ctorType ← whnf ctorType; + match ctorType with + | Expr.forallE _ d _ _ => pure d + | _ => failed () def throwTypeExcepted {α} (type : Expr) : MetaM α := throwError $ "type expected " ++ indentExpr type diff --git a/src/Lean/Meta/RecursorInfo.lean b/src/Lean/Meta/RecursorInfo.lean index 8e050b5942..531b0a67e1 100644 --- a/src/Lean/Meta/RecursorInfo.lean +++ b/src/Lean/Meta/RecursorInfo.lean @@ -78,29 +78,27 @@ instance : HasToString RecursorInfo := end RecursorInfo private def mkRecursorInfoForKernelRec (declName : Name) (val : RecursorVal) : MetaM RecursorInfo := do -indInfo ← getConstInfo val.getInduct; -match indInfo with -| ConstantInfo.inductInfo ival => - let numLParams := ival.lparams.length; - let univLevelPos := (List.range numLParams).map RecursorUnivLevelPos.majorType; - let univLevelPos := if val.lparams.length == numLParams then univLevelPos else RecursorUnivLevelPos.motive :: univLevelPos; - let produceMotive := List.replicate val.nminors true; - let paramsPos := (List.range val.nparams).map some; - let indicesPos := (List.range val.nindices).map (fun pos => val.nparams + pos); - let numArgs := val.nindices + val.nparams + val.nminors + val.nmotives + 1; - pure { - recursorName := declName, - typeName := val.getInduct, - univLevelPos := univLevelPos, - majorPos := val.getMajorIdx, - depElim := true, - recursive := ival.isRec, - produceMotive := produceMotive, - paramsPos := paramsPos, - indicesPos := indicesPos, - numArgs := numArgs - } -| _ => throwError "ill-formed builtin recursor" +ival ← getConstInfoInduct val.getInduct; +let numLParams := ival.lparams.length; +let univLevelPos := (List.range numLParams).map RecursorUnivLevelPos.majorType; +let univLevelPos := if val.lparams.length == numLParams then univLevelPos else RecursorUnivLevelPos.motive :: univLevelPos; +let produceMotive := List.replicate val.nminors true; +let paramsPos := (List.range val.nparams).map some; +let indicesPos := (List.range val.nindices).map (fun pos => val.nparams + pos); +let numArgs := val.nindices + val.nparams + val.nminors + val.nmotives + 1; +pure { + recursorName := declName, + typeName := val.getInduct, + univLevelPos := univLevelPos, + majorPos := val.getMajorIdx, + depElim := true, + recursive := ival.isRec, + produceMotive := produceMotive, + paramsPos := paramsPos, + indicesPos := indicesPos, + numArgs := numArgs +} + private def getMajorPosIfAuxRecursor? (declName : Name) (majorPos? : Option Nat) : MetaM (Option Nat) := if majorPos?.isSome then pure majorPos? @@ -112,10 +110,8 @@ else do if s != recOnSuffix && s != casesOnSuffix && s != brecOnSuffix then pure none else do - recInfo ← getConstInfo (mkRecFor p); - match recInfo with - | ConstantInfo.recInfo val => pure (some (val.nparams + val.nindices + (if s == casesOnSuffix then 1 else val.nmotives))) - | _ => throwError "unexpected recursor information" + val ← getConstInfoRec (mkRecFor p); + pure $ some (val.nparams + val.nindices + (if s == casesOnSuffix then 1 else val.nmotives)) | _ => pure none private def checkMotive (declName : Name) (motive : Expr) (motiveArgs : Array Expr) : MetaM Unit := diff --git a/src/Lean/Meta/Tactic/Cases.lean b/src/Lean/Meta/Tactic/Cases.lean index f996df9f9f..e1a4dd173f 100644 --- a/src/Lean/Meta/Tactic/Cases.lean +++ b/src/Lean/Meta/Tactic/Cases.lean @@ -16,16 +16,12 @@ private def throwInductiveTypeExpected {α} (type : Expr) : MetaM α := do throwError ("failed to compile pattern matching, inductive type expected" ++ indentExpr type) def getInductiveUniverseAndParams (type : Expr) : MetaM (List Level × Array Expr) := do -env ← getEnv; type ← whnfD type; -matchConst env type.getAppFn (fun _ => throwInductiveTypeExpected type) fun info us => - match info with - | ConstantInfo.inductInfo val => - let I := type.getAppFn; - let Iargs := type.getAppArgs; - let params := Iargs.extract 0 val.nparams; - pure (us, params) - | _ => throwInductiveTypeExpected type +matchConstInduct type.getAppFn (fun _ => throwInductiveTypeExpected type) fun val us => + let I := type.getAppFn; + let Iargs := type.getAppArgs; + let params := Iargs.extract 0 val.nparams; + pure (us, params) private def mkEqAndProof (lhs rhs : Expr) : MetaM (Expr × Expr) := do lhsType ← inferType lhs; @@ -76,44 +72,40 @@ def generalizeIndices (mvarId : MVarId) (fvarId : FVarId) : MetaM GeneralizeIndi withMVarContext mvarId $ do lctx ← getLCtx; localInsts ← getLocalInstances; - env ← getEnv; checkNotAssigned mvarId `generalizeIndices; fvarDecl ← getLocalDecl fvarId; type ← whnf fvarDecl.type; - type.withApp $ fun f args => matchConst env f (fun _ => throwTacticEx `generalizeIndices mvarId "inductive type expected") $ - fun cinfo _ => match cinfo with - | ConstantInfo.inductInfo val => do - unless (val.nindices > 0) $ throwTacticEx `generalizeIndices mvarId "indexed inductive type expected"; - unless (args.size == val.nindices + val.nparams) $ throwTacticEx `generalizeIndices mvarId "ill-formed inductive datatype"; - let indices := args.extract (args.size - val.nindices) args.size; - let IA := mkAppN f (args.extract 0 val.nparams); -- `I A` - IAType ← inferType IA; - forallTelescopeReducing IAType $ fun newIndices _ => do - let newType := mkAppN IA newIndices; - withLocalDeclD fvarDecl.userName newType $ fun h' => - withNewIndexEqs indices newIndices $ fun newEqs newRefls => do - (newEqType, newRefl) ← mkEqAndProof fvarDecl.toExpr h'; - let newRefls := newRefls.push newRefl; - withLocalDeclD `h newEqType $ fun newEq => do - let newEqs := newEqs.push newEq; - /- auxType `forall (j' : J) (h' : I A j'), j == j' -> h == h' -> target -/ - target ← getMVarType mvarId; - tag ← getMVarTag mvarId; - auxType ← mkForallFVars newEqs target; - auxType ← mkForallFVars #[h'] auxType; - auxType ← mkForallFVars newIndices auxType; - newMVar ← mkFreshExprMVarAt lctx localInsts auxType MetavarKind.syntheticOpaque tag; - /- assign mvarId := newMVar indices h refls -/ - assignExprMVar mvarId (mkAppN (mkApp (mkAppN newMVar indices) fvarDecl.toExpr) newRefls); - (indicesFVarIds, newMVarId) ← introN newMVar.mvarId! newIndices.size [] false; - (fvarId, newMVarId) ← intro1 newMVarId false; - pure { - mvarId := newMVarId, - indicesFVarIds := indicesFVarIds, - fvarId := fvarId, - numEqs := newEqs.size - } - | _ => throwTacticEx `generalizeIndices mvarId "inductive type expected" + type.withApp fun f args => matchConstInduct f (fun _ => throwTacticEx `generalizeIndices mvarId "inductive type expected") fun val _ => do + unless (val.nindices > 0) $ throwTacticEx `generalizeIndices mvarId "indexed inductive type expected"; + unless (args.size == val.nindices + val.nparams) $ throwTacticEx `generalizeIndices mvarId "ill-formed inductive datatype"; + let indices := args.extract (args.size - val.nindices) args.size; + let IA := mkAppN f (args.extract 0 val.nparams); -- `I A` + IAType ← inferType IA; + forallTelescopeReducing IAType $ fun newIndices _ => do + let newType := mkAppN IA newIndices; + withLocalDeclD fvarDecl.userName newType $ fun h' => + withNewIndexEqs indices newIndices $ fun newEqs newRefls => do + (newEqType, newRefl) ← mkEqAndProof fvarDecl.toExpr h'; + let newRefls := newRefls.push newRefl; + withLocalDeclD `h newEqType $ fun newEq => do + let newEqs := newEqs.push newEq; + /- auxType `forall (j' : J) (h' : I A j'), j == j' -> h == h' -> target -/ + target ← getMVarType mvarId; + tag ← getMVarTag mvarId; + auxType ← mkForallFVars newEqs target; + auxType ← mkForallFVars #[h'] auxType; + auxType ← mkForallFVars newIndices auxType; + newMVar ← mkFreshExprMVarAt lctx localInsts auxType MetavarKind.syntheticOpaque tag; + /- assign mvarId := newMVar indices h refls -/ + assignExprMVar mvarId (mkAppN (mkApp (mkAppN newMVar indices) fvarDecl.toExpr) newRefls); + (indicesFVarIds, newMVarId) ← introN newMVar.mvarId! newIndices.size [] false; + (fvarId, newMVarId) ← intro1 newMVarId false; + pure { + mvarId := newMVarId, + indicesFVarIds := indicesFVarIds, + fvarId := fvarId, + numEqs := newEqs.size + } structure CasesSubgoal extends InductionSubgoal := (ctorName : Name) @@ -135,12 +127,11 @@ if !env.contains `Eq || !env.contains `HEq then pure none else do majorDecl ← getLocalDecl majorFVarId; majorType ← whnf majorDecl.type; - majorType.withApp $ fun f args => matchConst env f (fun _ => pure none) $ fun cinfo _ => do - match cinfo with - | ConstantInfo.inductInfo ival => - if args.size != ival.nindices + ival.nparams then pure none - else match env.find? (mkNameStr ival.name "casesOn") with - | ConstantInfo.defnInfo cval => pure $ some { + majorType.withApp fun f args => matchConstInduct f (fun _ => pure none) fun ival _ => + if args.size != ival.nindices + ival.nparams then pure none + else match env.find? (mkNameStr ival.name "casesOn") with + | ConstantInfo.defnInfo cval => + pure $ some { inductiveVal := ival, casesOnVal := cval, majorDecl := majorDecl, @@ -148,7 +139,6 @@ else do majorTypeArgs := args } | _ => pure none - | _ => pure none /- We say the major premise has independent indices IF diff --git a/src/Lean/Meta/WHNF.lean b/src/Lean/Meta/WHNF.lean index d9957817a7..95a097faeb 100644 --- a/src/Lean/Meta/WHNF.lean +++ b/src/Lean/Meta/WHNF.lean @@ -149,20 +149,14 @@ whnfRef.set whnfImpl /- Given an expression `e`, compute its WHNF and if the result is a constructor, return field #i. -/ def reduceProj? (e : Expr) (i : Nat) : MetaM (Option Expr) := do -env ← getEnv; e ← whnf e; -match e.getAppFn with -| Expr.const name _ _ => - match env.find? name with - | some (ConstantInfo.ctorInfo ctorVal) => - let numArgs := e.getAppNumArgs; - let idx := ctorVal.nparams + i; - if idx < numArgs then - pure (some (e.getArg! idx)) - else - pure none - | _ => pure none -| _ => pure none +matchConstCtor e.getAppFn (fun _ => pure none) fun ctorVal _ => + let numArgs := e.getAppNumArgs; + let idx := ctorVal.nparams + i; + if idx < numArgs then + pure (some (e.getArg! idx)) + else + pure none @[specialize] partial def whnfHeadPredAux (pred : Expr → MetaM Bool) : Expr → MetaM Expr | e => Lean.WHNF.whnfEasyCases getLocalDecl getExprMVarAssignment? e $ fun e => do @@ -175,7 +169,6 @@ match e.getAppFn with | none => pure e) (pure e) - @[inline] def whnfHeadPred (e : Expr) (pred : Expr → MetaM Bool) : m Expr := liftMetaM $ whnfHeadPredAux pred e diff --git a/src/Lean/MonadEnv.lean b/src/Lean/MonadEnv.lean index 5f426b9b3e..b330ca7c93 100644 --- a/src/Lean/MonadEnv.lean +++ b/src/Lean/MonadEnv.lean @@ -26,6 +26,33 @@ variables {m : Type → Type} [MonadEnv m] def setEnv (env : Environment) : m Unit := modifyEnv fun _ => env +@[inline] def matchConst [Monad m] {α : Type} (e : Expr) (failK : Unit → m α) (k : ConstantInfo → List Level → m α) : m α := do +match e with +| Expr.const constName us _ => do + env ← getEnv; + match env.find? constName with + | some cinfo => k cinfo us + | none => failK () +| _ => failK () + +@[inline] def matchConstInduct [Monad m] {α : Type} (e : Expr) (failK : Unit → m α) (k : InductiveVal → List Level → m α) : m α := +matchConst e failK fun cinfo us => + match cinfo with + | ConstantInfo.inductInfo val => k val us + | _ => failK () + +@[inline] def matchConstCtor [Monad m] {α : Type} (e : Expr) (failK : Unit → m α) (k : ConstructorVal → List Level → m α) : m α := +matchConst e failK fun cinfo us => + match cinfo with + | ConstantInfo.ctorInfo val => k val us + | _ => failK () + +@[inline] def matchConstRec [Monad m] {α : Type} (e : Expr) (failK : Unit → m α) (k : RecursorVal → List Level → m α) : m α := +matchConst e failK fun cinfo us => + match cinfo with + | ConstantInfo.recInfo val => k val us + | _ => failK () + section variables [Monad m] [MonadError m] @@ -35,6 +62,35 @@ match env.find? constName with | some info => pure info | none => throwError ("unknown constant '" ++ constName ++ "'") +def getConstInfoInduct (constName : Name) : m InductiveVal := do +info ← getConstInfo constName; +match info with +| ConstantInfo.inductInfo v => pure v +| _ => throwError ("'" ++ constName ++ "' is not a inductive type") + +def getConstInfoCtor (constName : Name) : m ConstructorVal := do +info ← getConstInfo constName; +match info with +| ConstantInfo.ctorInfo v => pure v +| _ => throwError ("'" ++ constName ++ "' is not a constructor") + +def getConstInfoRec (constName : Name) : m RecursorVal := do +info ← getConstInfo constName; +match info with +| ConstantInfo.recInfo v => pure v +| _ => throwError ("'" ++ constName ++ "' is not a recursor") + +@[inline] def matchConstStruct {α : Type} (e : Expr) (failK : Unit → m α) (k : InductiveVal → List Level → ConstructorVal → m α) : m α := +matchConstInduct e failK fun ival us => + if ival.isRec then failK () + else match ival.ctors with + | [ctor] => do + ctorInfo ← getConstInfo ctor; + match ctorInfo with + | ConstantInfo.ctorInfo cval => k ival us cval + | _ => failK () + | _ => failK () + def addDecl [MonadOptions m] (decl : Declaration) : m Unit := do env ← getEnv; match env.addDecl decl with