From 82ee2e361b16b77b185df4326500b38bfdd6c0cd Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Wed, 21 Oct 2020 18:43:47 -0700 Subject: [PATCH] chore: cleanup --- src/Lean/Attributes.lean | 4 +- src/Lean/Compiler/ExternAttr.lean | 183 ++- src/Lean/Compiler/IR/Basic.lean | 582 +++++---- src/Lean/Compiler/IR/Format.lean | 172 +-- src/Lean/Compiler/IR/LiveVars.lean | 165 +-- src/Lean/Compiler/IR/NormIds.lean | 193 ++- src/Lean/Data/Format.lean | 2 +- src/Lean/Data/JsonRpc.lean | 2 +- src/Lean/Data/KVMap.lean | 14 +- src/Lean/Data/Lsp/InitShutdown.lean | 9 +- src/Lean/Data/Lsp/TextSync.lean | 25 +- src/Lean/Data/Options.lean | 125 +- src/Lean/Declaration.lean | 344 ++--- src/Lean/Elab/App.lean | 9 +- src/Lean/Elab/Attributes.lean | 40 +- src/Lean/Elab/BuiltinNotation.lean | 279 ++-- src/Lean/Elab/Command.lean | 617 +++++---- src/Lean/Elab/DeclModifiers.lean | 225 ++-- src/Lean/Elab/Do.lean | 1817 ++++++++++++++------------- src/Lean/Elab/Inductive.lean | 6 +- src/Lean/Elab/Log.lean | 2 +- src/Lean/Elab/Match.lean | 4 +- src/Lean/Elab/MutualDef.lean | 2 +- src/Lean/Elab/Quotation.lean | 2 +- src/Lean/Elab/StructInst.lean | 993 ++++++++------- src/Lean/Elab/Structure.lean | 694 +++++----- src/Lean/Elab/Tactic/Basic.lean | 426 ++++--- src/Lean/Elab/Tactic/Rewrite.lean | 58 +- src/Lean/Elab/Term.lean | 13 +- src/Lean/Elab/Util.lean | 179 +-- src/Lean/Environment.lean | 2 +- src/Lean/Eval.lean | 2 +- src/Lean/Exception.lean | 2 +- src/Lean/Expr.lean | 20 +- src/Lean/Hygiene.lean | 2 +- src/Lean/KeyedDeclsAttribute.lean | 7 +- src/Lean/Meta/Basic.lean | 2 +- src/Lean/Meta/Closure.lean | 2 +- src/Lean/Meta/SynthInstance.lean | 2 +- src/Lean/ToExpr.lean | 101 +- 40 files changed, 3680 insertions(+), 3648 deletions(-) diff --git a/src/Lean/Attributes.lean b/src/Lean/Attributes.lean index a8e505e451..a922d4137b 100644 --- a/src/Lean/Attributes.lean +++ b/src/Lean/Attributes.lean @@ -20,7 +20,7 @@ def AttributeApplicationTime.beq : AttributeApplicationTime → AttributeApplica | AttributeApplicationTime.beforeElaboration, AttributeApplicationTime.beforeElaboration => true | _, _ => false -instance AttributeApplicationTime.hasBeq : HasBeq AttributeApplicationTime := ⟨AttributeApplicationTime.beq⟩ +instance : HasBeq AttributeApplicationTime := ⟨AttributeApplicationTime.beq⟩ structure Attr.Context := (currNamespace : Name) @@ -28,7 +28,7 @@ structure Attr.Context := abbrev AttrM := ReaderT Attr.Context CoreM -instance attrResolveName : MonadResolveName AttrM := { +instance : MonadResolveName AttrM := { getCurrNamespace := do pure (← read).currNamespace, getOpenDecls := do pure (← read).openDecls } diff --git a/src/Lean/Compiler/ExternAttr.lean b/src/Lean/Compiler/ExternAttr.lean index 07262740df..c4e4edcbc8 100644 --- a/src/Lean/Compiler/ExternAttr.lean +++ b/src/Lean/Compiler/ExternAttr.lean @@ -13,10 +13,10 @@ import Lean.Meta.Basic namespace Lean inductive ExternEntry -| adhoc (backend : Name) -| inline (backend : Name) (pattern : String) -| standard (backend : Name) (fn : String) -| foreign (backend : Name) (fn : String) + | adhoc (backend : Name) + | inline (backend : Name) (pattern : String) + | standard (backend : Name) (fn : String) + | foreign (backend : Name) (fn : String) /- - `@[extern]` @@ -33,13 +33,12 @@ inductive ExternEntry encoding: ```.arity? = 2, .entries = [standard `cpp "ioPrimPrintln"]``` -/ structure ExternAttrData := -(arity? : Option Nat := none) -(entries : List ExternEntry) + (arity? : Option Nat := none) + (entries : List ExternEntry) -instance ExternAttrData.inhabited : Inhabited ExternAttrData := ⟨{ entries := [] }⟩ +instance : Inhabited ExternAttrData := ⟨{ entries := [] }⟩ -private partial def syntaxToExternEntries (a : Array Syntax) : Nat → List ExternEntry → Except String (List ExternEntry) -| i, entries => +private partial def syntaxToExternEntries (a : Array Syntax) (i : Nat) (entries : List ExternEntry) : Except String (List ExternEntry) := if i == a.size then Except.ok entries else match a[i] with | Syntax.ident _ _ backend _ => @@ -58,122 +57,122 @@ private partial def syntaxToExternEntries (a : Array Syntax) : Nat → List Exte | _ => Except.error "identifier expected" private def syntaxToExternAttrData (s : Syntax) : ExceptT String Id ExternAttrData := -match s with -| Syntax.missing => Except.ok { entries := [ ExternEntry.adhoc `all ] } -| Syntax.node _ args => - if args.size == 0 then Except.error "unexpected kind of argument" - else - let (arity, i) : Option Nat × Nat := match args[0].isNatLit? with - | some arity => (some arity, 1) - | none => (none, 0) - match args[i].isStrLit? with - | some str => - if args.size == i+1 then - Except.ok { arity? := arity, entries := [ ExternEntry.standard `all str ] } - else - Except.error "invalid extern attribute" - | none => match syntaxToExternEntries args i [] with - | Except.ok entries => Except.ok { arity? := arity, entries := entries } - | Except.error msg => Except.error msg -| _ => Except.error "unexpected kind of argument" + match s with + | Syntax.missing => Except.ok { entries := [ ExternEntry.adhoc `all ] } + | Syntax.node _ args => + if args.size == 0 then Except.error "unexpected kind of argument" + else + let (arity, i) : Option Nat × Nat := match args[0].isNatLit? with + | some arity => (some arity, 1) + | none => (none, 0) + match args[i].isStrLit? with + | some str => + if args.size == i+1 then + Except.ok { arity? := arity, entries := [ ExternEntry.standard `all str ] } + else + Except.error "invalid extern attribute" + | none => match syntaxToExternEntries args i [] with + | Except.ok entries => Except.ok { arity? := arity, entries := entries } + | Except.error msg => Except.error msg + | _ => Except.error "unexpected kind of argument" @[extern "lean_add_extern"] -constant addExtern (env : Environment) (n : Name) : ExceptT String Id Environment := arbitrary _ +constant addExtern (env : Environment) (n : Name) : ExceptT String Id Environment builtin_initialize externAttr : ParametricAttribute ExternAttrData ← -registerParametricAttribute `extern "builtin and foreign functions" - (fun _ stx => ofExcept $ syntaxToExternAttrData stx) - (fun declName _ => do - let env ← getEnv - if env.isProjectionFn declName || env.isConstructor declName then do - env ← ofExcept $ addExtern env declName - setEnv env - else - pure ()) + registerParametricAttribute `extern "builtin and foreign functions" + (fun _ stx => ofExcept $ syntaxToExternAttrData stx) + (fun declName _ => do + let env ← getEnv + if env.isProjectionFn declName || env.isConstructor declName then do + env ← ofExcept $ addExtern env declName + setEnv env + else + pure ()) @[export lean_get_extern_attr_data] def getExternAttrData (env : Environment) (n : Name) : Option ExternAttrData := -externAttr.getParam env n + externAttr.getParam env n private def parseOptNum : Nat → String.Iterator → Nat → String.Iterator × Nat -| 0, it, r => (it, r) -| n+1, it, r => - if !it.hasNext then (it, r) - else - let c := it.curr - if '0' <= c && c <= '9' - then parseOptNum n it.next (r*10 + (c.toNat - '0'.toNat)) - else (it, r) + | 0, it, r => (it, r) + | n+1, it, r => + if !it.hasNext then (it, r) + else + let c := it.curr + if '0' <= c && c <= '9' + then parseOptNum n it.next (r*10 + (c.toNat - '0'.toNat)) + else (it, r) def expandExternPatternAux (args : List String) : Nat → String.Iterator → String → String -| 0, it, r => r -| i+1, it, r => - if ¬ it.hasNext then r - else let c := it.curr - if c ≠ '#' then expandExternPatternAux args i it.next (r.push c) - else - let it := it.next - let (it, j) := parseOptNum it.remainingBytes it 0 - let j := j-1 - expandExternPatternAux args i it (r ++ args.getD j "") + | 0, it, r => r + | i+1, it, r => + if ¬ it.hasNext then r + else let c := it.curr + if c ≠ '#' then expandExternPatternAux args i it.next (r.push c) + else + let it := it.next + let (it, j) := parseOptNum it.remainingBytes it 0 + let j := j-1 + expandExternPatternAux args i it (r ++ args.getD j "") def expandExternPattern (pattern : String) (args : List String) : String := -expandExternPatternAux args pattern.length pattern.mkIterator "" + expandExternPatternAux args pattern.length pattern.mkIterator "" def mkSimpleFnCall (fn : String) (args : List String) : String := -fn ++ "(" ++ ((args.intersperse ", ").foldl HasAppend.append "") ++ ")" + fn ++ "(" ++ ((args.intersperse ", ").foldl HasAppend.append "") ++ ")" def ExternEntry.backend : ExternEntry → Name -| ExternEntry.adhoc n => n -| ExternEntry.inline n _ => n -| ExternEntry.standard n _ => n -| ExternEntry.foreign n _ => n + | ExternEntry.adhoc n => n + | ExternEntry.inline n _ => n + | ExternEntry.standard n _ => n + | ExternEntry.foreign n _ => n def getExternEntryForAux (backend : Name) : List ExternEntry → Option ExternEntry -| [] => none -| e::es => - if e.backend == `all then some e - else if e.backend == backend then some e - else getExternEntryForAux backend es + | [] => none + | e::es => + if e.backend == `all then some e + else if e.backend == backend then some e + else getExternEntryForAux backend es def getExternEntryFor (d : ExternAttrData) (backend : Name) : Option ExternEntry := -getExternEntryForAux backend d.entries + getExternEntryForAux backend d.entries def isExtern (env : Environment) (fn : Name) : Bool := -(getExternAttrData env fn).isSome + getExternAttrData env fn $.isSome /- We say a Lean function marked as `[extern ""]` is for all backends, and it is implemented using `extern "C"`. Thus, there is no name mangling. -/ def isExternC (env : Environment) (fn : Name) : Bool := -match getExternAttrData env fn with -| some { entries := [ ExternEntry.standard `all _ ], .. } => true -| _ => false + match getExternAttrData env fn with + | some { entries := [ ExternEntry.standard `all _ ], .. } => true + | _ => false def getExternNameFor (env : Environment) (backend : Name) (fn : Name) : Option String := do -let data ← getExternAttrData env fn -let entry ← getExternEntryFor data backend -match entry with -| ExternEntry.standard _ n => pure n -| ExternEntry.foreign _ n => pure n -| _ => failure + let data ← getExternAttrData env fn + let entry ← getExternEntryFor data backend + match entry with + | ExternEntry.standard _ n => pure n + | ExternEntry.foreign _ n => pure n + | _ => failure def getExternConstArity (declName : Name) : CoreM (Option Nat) := do -let env ← getEnv -match getExternAttrData env declName with -| none => pure none -| some data => match data.arity? with - | some arity => pure arity - | none => - let cinfo ← getConstInfo declName - let (arity, _) ← (Meta.forallTelescopeReducing cinfo.type fun xs _ => pure xs.size : MetaM Nat).run - pure (some arity) + let env ← getEnv + match getExternAttrData env declName with + | none => pure none + | some data => match data.arity? with + | some arity => pure arity + | none => + let cinfo ← getConstInfo declName + let (arity, _) ← (Meta.forallTelescopeReducing cinfo.type fun xs _ => pure xs.size : MetaM Nat).run + pure (some arity) @[export lean_get_extern_const_arity] def getExternConstArityExport (env : Environment) (declName : Name) : IO (Option Nat) := do -try - let (arity?, _) ← (getExternConstArity declName).toIO {} { env := env } - pure arity? -catch _ => - pure none + try + let (arity?, _) ← (getExternConstArity declName).toIO {} { env := env } + pure arity? + catch _ => + pure none end Lean diff --git a/src/Lean/Compiler/IR/Basic.lean b/src/Lean/Compiler/IR/Basic.lean index 55b155a2e3..61e2ca3f82 100644 --- a/src/Lean/Compiler/IR/Basic.lean +++ b/src/Lean/Compiler/IR/Basic.lean @@ -22,32 +22,24 @@ namespace Lean.IR abbrev FunId := Name abbrev Index := Nat /- Variable identifier -/ -structure VarId := -(idx : Index) +structure VarId := (idx : Index) /- Join point identifier -/ -structure JoinPointId := -(idx : Index) +structure JoinPointId := (idx : Index) abbrev Index.lt (a b : Index) : Bool := a < b -namespace VarId instance : HasBeq VarId := ⟨fun a b => a.idx == b.idx⟩ instance : HasToString VarId := ⟨fun a => "x_" ++ toString a.idx⟩ instance : HasFormat VarId := ⟨fun a => toString a⟩ instance : Hashable VarId := ⟨fun a => hash a.idx⟩ -end VarId -namespace JoinPointId instance : HasBeq JoinPointId := ⟨fun a b => a.idx == b.idx⟩ instance : HasToString JoinPointId := ⟨fun a => "block_" ++ toString a.idx⟩ instance : HasFormat JoinPointId := ⟨fun a => toString a⟩ instance : Hashable JoinPointId := ⟨fun a => hash a.idx⟩ -end JoinPointId abbrev MData := KVMap -namespace MData -abbrev empty : MData := {} -end MData +abbrev MData.empty : MData := {} /- Low Level IR types. Most are self explanatory. @@ -82,54 +74,54 @@ then one of the following must hold in each (execution) branch. fields that do not contain object fields. -/ inductive IRType -| float | uint8 | uint16 | uint32 | uint64 | usize -| irrelevant | object | tobject -| struct (leanTypeName : Option Name) (types : Array IRType) : IRType -| union (leanTypeName : Name) (types : Array IRType) : IRType + | float | uint8 | uint16 | uint32 | uint64 | usize + | irrelevant | object | tobject + | struct (leanTypeName : Option Name) (types : Array IRType) : IRType + | union (leanTypeName : Name) (types : Array IRType) : IRType namespace IRType partial def beq : IRType → IRType → Bool -| float, float => true -| uint8, uint8 => true -| uint16, uint16 => true -| uint32, uint32 => true -| uint64, uint64 => true -| usize, usize => true -| irrelevant, irrelevant => true -| object, object => true -| tobject, tobject => true -| struct n₁ tys₁, struct n₂ tys₂ => n₁ == n₂ && Array.isEqv tys₁ tys₂ beq -| union n₁ tys₁, union n₂ tys₂ => n₁ == n₂ && Array.isEqv tys₁ tys₂ beq -| _, _ => false + | float, float => true + | uint8, uint8 => true + | uint16, uint16 => true + | uint32, uint32 => true + | uint64, uint64 => true + | usize, usize => true + | irrelevant, irrelevant => true + | object, object => true + | tobject, tobject => true + | struct n₁ tys₁, struct n₂ tys₂ => n₁ == n₂ && Array.isEqv tys₁ tys₂ beq + | union n₁ tys₁, union n₂ tys₂ => n₁ == n₂ && Array.isEqv tys₁ tys₂ beq + | _, _ => false -instance HasBeq : HasBeq IRType := ⟨beq⟩ +instance : HasBeq IRType := ⟨beq⟩ def isScalar : IRType → Bool -| float => true -| uint8 => true -| uint16 => true -| uint32 => true -| uint64 => true -| usize => true -| _ => false + | float => true + | uint8 => true + | uint16 => true + | uint32 => true + | uint64 => true + | usize => true + | _ => false def isObj : IRType → Bool -| object => true -| tobject => true -| _ => false + | object => true + | tobject => true + | _ => false def isIrrelevant : IRType → Bool -| irrelevant => true -| _ => false + | irrelevant => true + | _ => false def isStruct : IRType → Bool -| struct _ _ => true -| _ => false + | struct _ _ => true + | _ => false def isUnion : IRType → Bool -| union _ _ => true -| _ => false + | union _ _ => true + | _ => false end IRType @@ -138,31 +130,29 @@ end IRType Recall that for a Function `f`, we also generate `f._rarg` which does not take `irrelevant` arguments. However, `f._rarg` is only safe to be used in full applications. -/ inductive Arg -| var (id : VarId) -| irrelevant + | var (id : VarId) + | irrelevant -namespace Arg -protected def beq : Arg → Arg → Bool -| var x, var y => x == y -| irrelevant, irrelevant => true -| _, _ => false +protected def Arg.beq : Arg → Arg → Bool + | var x, var y => x == y + | irrelevant, irrelevant => true + | _, _ => false instance : HasBeq Arg := ⟨Arg.beq⟩ -instance : Inhabited Arg := ⟨irrelevant⟩ -end Arg +instance : Inhabited Arg := ⟨Arg.irrelevant⟩ @[export lean_ir_mk_var_arg] def mkVarArg (id : VarId) : Arg := Arg.var id inductive LitVal -| num (v : Nat) -| str (v : String) + | num (v : Nat) + | str (v : String) def LitVal.beq : LitVal → LitVal → Bool -| LitVal.num v₁, LitVal.num v₂ => v₁ == v₂ -| LitVal.str v₁, LitVal.str v₂ => v₁ == v₂ -| _, _ => false + | num v₁, num v₂ => v₁ == v₂ + | str v₁, str v₂ => v₁ == v₂ + | _, _ => false -instance LitVal.HasBeq : HasBeq LitVal := ⟨LitVal.beq⟩ +instance : HasBeq LitVal := ⟨LitVal.beq⟩ /- Constructor information. @@ -176,53 +166,58 @@ Recall that a Constructor object contains a header, then a sequence of pointers to other Lean objects, a sequence of `USize` (i.e., `size_t`) scalar values, and a sequence of other scalar values. -/ structure CtorInfo := -(name : Name) (cidx : Nat) (size : Nat) (usize : Nat) (ssize : Nat) + (name : Name) + (cidx : Nat) + (size : Nat) + (usize : Nat) + (ssize : Nat) def CtorInfo.beq : CtorInfo → CtorInfo → Bool -| ⟨n₁, cidx₁, size₁, usize₁, ssize₁⟩, ⟨n₂, cidx₂, size₂, usize₂, ssize₂⟩ => - n₁ == n₂ && cidx₁ == cidx₂ && size₁ == size₂ && usize₁ == usize₂ && ssize₁ == ssize₂ + | ⟨n₁, cidx₁, size₁, usize₁, ssize₁⟩, ⟨n₂, cidx₂, size₂, usize₂, ssize₂⟩ => + n₁ == n₂ && cidx₁ == cidx₂ && size₁ == size₂ && usize₁ == usize₂ && ssize₁ == ssize₂ -instance CtorInfo.HasBeq : HasBeq CtorInfo := ⟨CtorInfo.beq⟩ +instance : HasBeq CtorInfo := ⟨CtorInfo.beq⟩ def CtorInfo.isRef (info : CtorInfo) : Bool := -info.size > 0 || info.usize > 0 || info.ssize > 0 + info.size > 0 || info.usize > 0 || info.ssize > 0 def CtorInfo.isScalar (info : CtorInfo) : Bool := -!info.isRef + !info.isRef inductive Expr -/- We use `ctor` mainly for constructing Lean object/tobject values `lean_ctor_object` in the runtime. - This instruction is also used to creat `struct` and `union` return values. - For `union`, only `i.cidx` is relevant. For `struct`, `i` is irrelevant. -/ -| ctor (i : CtorInfo) (ys : Array Arg) -| reset (n : Nat) (x : VarId) -/- `reuse x in ctor_i ys` instruction in the paper. -/ -| reuse (x : VarId) (i : CtorInfo) (updtHeader : Bool) (ys : Array Arg) -/- Extract the `tobject` value at Position `sizeof(void*)*i` from `x`. - We also use `proj` for extracting fields from `struct` return values, and casting `union` return values. -/ -| proj (i : Nat) (x : VarId) -/- Extract the `Usize` value at Position `sizeof(void*)*i` from `x`. -/ -| uproj (i : Nat) (x : VarId) -/- Extract the scalar value at Position `sizeof(void*)*n + offset` from `x`. -/ -| sproj (n : Nat) (offset : Nat) (x : VarId) -/- Full application. -/ -| fap (c : FunId) (ys : Array Arg) -/- Partial application that creates a `pap` value (aka closure in our nonstandard terminology). -/ -| pap (c : FunId) (ys : Array Arg) -/- Application. `x` must be a `pap` value. -/ -| ap (x : VarId) (ys : Array Arg) -/- Given `x : ty` where `ty` is a scalar type, this operation returns a value of Type `tobject`. - For small scalar values, the Result is a tagged pointer, and no memory allocation is performed. -/ -| box (ty : IRType) (x : VarId) -/- Given `x : [t]object`, obtain the scalar value. -/ -| unbox (x : VarId) -| lit (v : LitVal) -/- Return `1 : uint8` Iff `RC(x) > 1` -/ -| isShared (x : VarId) -/- Return `1 : uint8` Iff `x : tobject` is a tagged pointer (storing a scalar value). -/ -| isTaggedPtr (x : VarId) + /- We use `ctor` mainly for constructing Lean object/tobject values `lean_ctor_object` in the runtime. + This instruction is also used to creat `struct` and `union` return values. + For `union`, only `i.cidx` is relevant. For `struct`, `i` is irrelevant. -/ + | ctor (i : CtorInfo) (ys : Array Arg) + | reset (n : Nat) (x : VarId) + /- `reuse x in ctor_i ys` instruction in the paper. -/ + | reuse (x : VarId) (i : CtorInfo) (updtHeader : Bool) (ys : Array Arg) + /- Extract the `tobject` value at Position `sizeof(void*)*i` from `x`. + We also use `proj` for extracting fields from `struct` return values, and casting `union` return values. -/ + | proj (i : Nat) (x : VarId) + /- Extract the `Usize` value at Position `sizeof(void*)*i` from `x`. -/ + | uproj (i : Nat) (x : VarId) + /- Extract the scalar value at Position `sizeof(void*)*n + offset` from `x`. -/ + | sproj (n : Nat) (offset : Nat) (x : VarId) + /- Full application. -/ + | fap (c : FunId) (ys : Array Arg) + /- Partial application that creates a `pap` value (aka closure in our nonstandard terminology). -/ + | pap (c : FunId) (ys : Array Arg) + /- Application. `x` must be a `pap` value. -/ + | ap (x : VarId) (ys : Array Arg) + /- Given `x : ty` where `ty` is a scalar type, this operation returns a value of Type `tobject`. + For small scalar values, the Result is a tagged pointer, and no memory allocation is performed. -/ + | box (ty : IRType) (x : VarId) + /- Given `x : [t]object`, obtain the scalar value. -/ + | unbox (x : VarId) + | lit (v : LitVal) + /- Return `1 : uint8` Iff `RC(x) > 1` -/ + | isShared (x : VarId) + /- Return `1 : uint8` Iff `x : tobject` is a tagged pointer (storing a scalar value). -/ + | isTaggedPtr (x : VarId) -@[export lean_ir_mk_ctor_expr] def mkCtorExpr (n : Name) (cidx : Nat) (size : Nat) (usize : Nat) (ssize : Nat) (ys : Array Arg) : Expr := Expr.ctor ⟨n, cidx, size, usize, ssize⟩ ys +@[export lean_ir_mk_ctor_expr] def mkCtorExpr (n : Name) (cidx : Nat) (size : Nat) (usize : Nat) (ssize : Nat) (ys : Array Arg) : Expr := + Expr.ctor ⟨n, cidx, size, usize, ssize⟩ ys @[export lean_ir_mk_proj_expr] def mkProjExpr (i : Nat) (x : VarId) : Expr := Expr.proj i x @[export lean_ir_mk_uproj_expr] def mkUProjExpr (i : Nat) (x : VarId) : Expr := Expr.uproj i x @[export lean_ir_mk_sproj_expr] def mkSProjExpr (n : Nat) (offset : Nat) (x : VarId) : Expr := Expr.sproj n offset x @@ -233,43 +228,46 @@ inductive Expr @[export lean_ir_mk_str_expr] def mkStrExpr (v : String) : Expr := Expr.lit (LitVal.str v) structure Param := -(x : VarId) (borrow : Bool) (ty : IRType) + (x : VarId) + (borrow : Bool) + (ty : IRType) -instance paramInh : Inhabited Param := ⟨{ x := { idx := 0 }, borrow := false, ty := IRType.object }⟩ +instance : Inhabited Param := ⟨{ x := { idx := 0 }, borrow := false, ty := IRType.object }⟩ -@[export lean_ir_mk_param] def mkParam (x : VarId) (borrow : Bool) (ty : IRType) : Param := ⟨x, borrow, ty⟩ +@[export lean_ir_mk_param] +def mkParam (x : VarId) (borrow : Bool) (ty : IRType) : Param := ⟨x, borrow, ty⟩ inductive AltCore (FnBody : Type) : Type -| ctor (info : CtorInfo) (b : FnBody) : AltCore FnBody -| default (b : FnBody) : AltCore FnBody + | ctor (info : CtorInfo) (b : FnBody) : AltCore FnBody + | default (b : FnBody) : AltCore FnBody inductive FnBody -/- `let x : ty := e; b` -/ -| vdecl (x : VarId) (ty : IRType) (e : Expr) (b : FnBody) -/- Join point Declaration `block_j (xs) := e; b` -/ -| jdecl (j : JoinPointId) (xs : Array Param) (v : FnBody) (b : FnBody) -/- Store `y` at Position `sizeof(void*)*i` in `x`. `x` must be a Constructor object and `RC(x)` must be 1. - This operation is not part of λPure is only used during optimization. -/ -| set (x : VarId) (i : Nat) (y : Arg) (b : FnBody) -| setTag (x : VarId) (cidx : Nat) (b : FnBody) -/- Store `y : Usize` at Position `sizeof(void*)*i` in `x`. `x` must be a Constructor object and `RC(x)` must be 1. -/ -| uset (x : VarId) (i : Nat) (y : VarId) (b : FnBody) -/- Store `y : ty` at Position `sizeof(void*)*i + offset` in `x`. `x` must be a Constructor object and `RC(x)` must be 1. - `ty` must not be `object`, `tobject`, `irrelevant` nor `Usize`. -/ -| sset (x : VarId) (i : Nat) (offset : Nat) (y : VarId) (ty : IRType) (b : FnBody) -/- RC increment for `object`. If c == `true`, then `inc` must check whether `x` is a tagged pointer or not. - If `persistent == true` then `x` is statically known to be a persistent object. -/ -| inc (x : VarId) (n : Nat) (c : Bool) (persistent : Bool) (b : FnBody) -/- RC decrement for `object`. If c == `true`, then `inc` must check whether `x` is a tagged pointer or not. - If `persistent == true` then `x` is statically known to be a persistent object. -/ -| dec (x : VarId) (n : Nat) (c : Bool) (persistent : Bool) (b : FnBody) -| del (x : VarId) (b : FnBody) -| mdata (d : MData) (b : FnBody) -| case (tid : Name) (x : VarId) (xType : IRType) (cs : Array (AltCore FnBody)) -| ret (x : Arg) -/- Jump to join point `j` -/ -| jmp (j : JoinPointId) (ys : Array Arg) -| unreachable + /- `let x : ty := e; b` -/ + | vdecl (x : VarId) (ty : IRType) (e : Expr) (b : FnBody) + /- Join point Declaration `block_j (xs) := e; b` -/ + | jdecl (j : JoinPointId) (xs : Array Param) (v : FnBody) (b : FnBody) + /- Store `y` at Position `sizeof(void*)*i` in `x`. `x` must be a Constructor object and `RC(x)` must be 1. + This operation is not part of λPure is only used during optimization. -/ + | set (x : VarId) (i : Nat) (y : Arg) (b : FnBody) + | setTag (x : VarId) (cidx : Nat) (b : FnBody) + /- Store `y : Usize` at Position `sizeof(void*)*i` in `x`. `x` must be a Constructor object and `RC(x)` must be 1. -/ + | uset (x : VarId) (i : Nat) (y : VarId) (b : FnBody) + /- Store `y : ty` at Position `sizeof(void*)*i + offset` in `x`. `x` must be a Constructor object and `RC(x)` must be 1. + `ty` must not be `object`, `tobject`, `irrelevant` nor `Usize`. -/ + | sset (x : VarId) (i : Nat) (offset : Nat) (y : VarId) (ty : IRType) (b : FnBody) + /- RC increment for `object`. If c == `true`, then `inc` must check whether `x` is a tagged pointer or not. + If `persistent == true` then `x` is statically known to be a persistent object. -/ + | inc (x : VarId) (n : Nat) (c : Bool) (persistent : Bool) (b : FnBody) + /- RC decrement for `object`. If c == `true`, then `inc` must check whether `x` is a tagged pointer or not. + If `persistent == true` then `x` is statically known to be a persistent object. -/ + | dec (x : VarId) (n : Nat) (c : Bool) (persistent : Bool) (b : FnBody) + | del (x : VarId) (b : FnBody) + | mdata (d : MData) (b : FnBody) + | case (tid : Name) (x : VarId) (xType : IRType) (cs : Array (AltCore FnBody)) + | ret (x : Arg) + /- Jump to join point `j` -/ + | jmp (j : JoinPointId) (ys : Array Arg) + | unreachable instance : Inhabited FnBody := ⟨FnBody.unreachable⟩ @@ -280,8 +278,8 @@ abbrev FnBody.nil := FnBody.unreachable @[export lean_ir_mk_uset] def mkUSet (x : VarId) (i : Nat) (y : VarId) (b : FnBody) : FnBody := FnBody.uset x i y b @[export lean_ir_mk_sset] def mkSSet (x : VarId) (i : Nat) (offset : Nat) (y : VarId) (ty : IRType) (b : FnBody) : FnBody := FnBody.sset x i offset y ty b @[export lean_ir_mk_case] def mkCase (tid : Name) (x : VarId) (cs : Array (AltCore FnBody)) : FnBody := --- Tyhe field `xType` is set by `explicitBoxing` compiler pass. -FnBody.case tid x IRType.object cs + -- Tyhe field `xType` is set by `explicitBoxing` compiler pass. + FnBody.case tid x IRType.object cs @[export lean_ir_mk_ret] def mkRet (x : Arg) : FnBody := FnBody.ret x @[export lean_ir_mk_jmp] def mkJmp (j : JoinPointId) (ys : Array Arg) : FnBody := FnBody.jmp j ys @[export lean_ir_mk_unreachable] def mkUnreachable : Unit → FnBody := fun _ => FnBody.unreachable @@ -290,86 +288,83 @@ abbrev Alt := AltCore FnBody @[matchPattern] abbrev Alt.ctor := @AltCore.ctor FnBody @[matchPattern] abbrev Alt.default := @AltCore.default FnBody -instance altInh : Inhabited Alt := -⟨Alt.default (arbitrary _)⟩ +instance : Inhabited Alt := ⟨Alt.default (arbitrary _)⟩ def FnBody.isTerminal : FnBody → Bool -| FnBody.case _ _ _ _ => true -| FnBody.ret _ => true -| FnBody.jmp _ _ => true -| FnBody.unreachable => true -| _ => false + | FnBody.case _ _ _ _ => true + | FnBody.ret _ => true + | FnBody.jmp _ _ => true + | FnBody.unreachable => true + | _ => false def FnBody.body : FnBody → FnBody -| FnBody.vdecl _ _ _ b => b -| FnBody.jdecl _ _ _ b => b -| FnBody.set _ _ _ b => b -| FnBody.uset _ _ _ b => b -| FnBody.sset _ _ _ _ _ b => b -| FnBody.setTag _ _ b => b -| FnBody.inc _ _ _ _ b => b -| FnBody.dec _ _ _ _ b => b -| FnBody.del _ b => b -| FnBody.mdata _ b => b -| other => other + | FnBody.vdecl _ _ _ b => b + | FnBody.jdecl _ _ _ b => b + | FnBody.set _ _ _ b => b + | FnBody.uset _ _ _ b => b + | FnBody.sset _ _ _ _ _ b => b + | FnBody.setTag _ _ b => b + | FnBody.inc _ _ _ _ b => b + | FnBody.dec _ _ _ _ b => b + | FnBody.del _ b => b + | FnBody.mdata _ b => b + | other => other def FnBody.setBody : FnBody → FnBody → FnBody -| FnBody.vdecl x t v _, b => FnBody.vdecl x t v b -| FnBody.jdecl j xs v _, b => FnBody.jdecl j xs v b -| FnBody.set x i y _, b => FnBody.set x i y b -| FnBody.uset x i y _, b => FnBody.uset x i y b -| FnBody.sset x i o y t _, b => FnBody.sset x i o y t b -| FnBody.setTag x i _, b => FnBody.setTag x i b -| FnBody.inc x n c p _, b => FnBody.inc x n c p b -| FnBody.dec x n c p _, b => FnBody.dec x n c p b -| FnBody.del x _, b => FnBody.del x b -| FnBody.mdata d _, b => FnBody.mdata d b -| other, b => other + | FnBody.vdecl x t v _, b => FnBody.vdecl x t v b + | FnBody.jdecl j xs v _, b => FnBody.jdecl j xs v b + | FnBody.set x i y _, b => FnBody.set x i y b + | FnBody.uset x i y _, b => FnBody.uset x i y b + | FnBody.sset x i o y t _, b => FnBody.sset x i o y t b + | FnBody.setTag x i _, b => FnBody.setTag x i b + | FnBody.inc x n c p _, b => FnBody.inc x n c p b + | FnBody.dec x n c p _, b => FnBody.dec x n c p b + | FnBody.del x _, b => FnBody.del x b + | FnBody.mdata d _, b => FnBody.mdata d b + | other, b => other @[inline] def FnBody.resetBody (b : FnBody) : FnBody := -b.setBody FnBody.nil + b.setBody FnBody.nil /- If b is a non terminal, then return a pair `(c, b')` s.t. `b == c <;> b'`, and c.body == FnBody.nil -/ @[inline] def FnBody.split (b : FnBody) : FnBody × FnBody := -let b' := b.body -let c := b.resetBody -(c, b') + let b' := b.body + let c := b.resetBody + (c, b') def AltCore.body : Alt → FnBody -| Alt.ctor _ b => b -| Alt.default b => b + | Alt.ctor _ b => b + | Alt.default b => b def AltCore.setBody : Alt → FnBody → Alt -| Alt.ctor c _, b => Alt.ctor c b -| Alt.default _, b => Alt.default b + | Alt.ctor c _, b => Alt.ctor c b + | Alt.default _, b => Alt.default b @[inline] def AltCore.modifyBody (f : FnBody → FnBody) : AltCore FnBody → Alt -| Alt.ctor c b => Alt.ctor c (f b) -| Alt.default b => Alt.default (f b) + | Alt.ctor c b => Alt.ctor c (f b) + | Alt.default b => Alt.default (f b) @[inline] def AltCore.mmodifyBody {m : Type → Type} [Monad m] (f : FnBody → m FnBody) : AltCore FnBody → m Alt -| Alt.ctor c b => Alt.ctor c <$> f b -| Alt.default b => Alt.default <$> f b + | Alt.ctor c b => Alt.ctor c <$> f b + | Alt.default b => Alt.default <$> f b def Alt.isDefault : Alt → Bool -| Alt.ctor _ _ => false -| Alt.default _ => true + | Alt.ctor _ _ => false + | Alt.default _ => true def push (bs : Array FnBody) (b : FnBody) : Array FnBody := -let b := b.resetBody -bs.push b + let b := b.resetBody + bs.push b -partial def flattenAux : FnBody → Array FnBody → (Array FnBody) × FnBody -| b, r => +partial def flattenAux (b : FnBody) (r : Array FnBody) : (Array FnBody) × FnBody := if b.isTerminal then (r, b) else flattenAux b.body (push r b) def FnBody.flatten (b : FnBody) : (Array FnBody) × FnBody := -flattenAux b #[] + flattenAux b #[] -partial def reshapeAux : Array FnBody → Nat → FnBody → FnBody -| a, i, b => +partial def reshapeAux (a : Array FnBody) (i : Nat) (b : FnBody) : FnBody := if i == 0 then b else let i := i - 1 @@ -378,216 +373,215 @@ partial def reshapeAux : Array FnBody → Nat → FnBody → FnBody reshapeAux a i b def reshape (bs : Array FnBody) (term : FnBody) : FnBody := -reshapeAux bs bs.size term + reshapeAux bs bs.size term @[inline] def modifyJPs (bs : Array FnBody) (f : FnBody → FnBody) : Array FnBody := -bs.map fun b => match b with - | FnBody.jdecl j xs v k => FnBody.jdecl j xs (f v) k - | other => other + bs.map fun b => match b with + | FnBody.jdecl j xs v k => FnBody.jdecl j xs (f v) k + | other => other @[inline] def mmodifyJPs {m : Type → Type} [Monad m] (bs : Array FnBody) (f : FnBody → m FnBody) : m (Array FnBody) := -bs.mapM fun b => match b with - | FnBody.jdecl j xs v k => do let v ← f v; pure $ FnBody.jdecl j xs v k - | other => pure other + bs.mapM fun b => match b with + | FnBody.jdecl j xs v k => do let v ← f v; pure $ FnBody.jdecl j xs v k + | other => pure other -@[export lean_ir_mk_alt] def mkAlt (n : Name) (cidx : Nat) (size : Nat) (usize : Nat) (ssize : Nat) (b : FnBody) : Alt := Alt.ctor ⟨n, cidx, size, usize, ssize⟩ b +@[export lean_ir_mk_alt] def mkAlt (n : Name) (cidx : Nat) (size : Nat) (usize : Nat) (ssize : Nat) (b : FnBody) : Alt := + Alt.ctor ⟨n, cidx, size, usize, ssize⟩ b inductive Decl -| fdecl (f : FunId) (xs : Array Param) (ty : IRType) (b : FnBody) -| extern (f : FunId) (xs : Array Param) (ty : IRType) (ext : ExternAttrData) + | fdecl (f : FunId) (xs : Array Param) (ty : IRType) (b : FnBody) + | extern (f : FunId) (xs : Array Param) (ty : IRType) (ext : ExternAttrData) namespace Decl instance : Inhabited Decl := -⟨fdecl (arbitrary _) (arbitrary _) IRType.irrelevant (arbitrary _)⟩ + ⟨fdecl (arbitrary _) (arbitrary _) IRType.irrelevant (arbitrary _)⟩ def name : Decl → FunId -| Decl.fdecl f _ _ _ => f -| Decl.extern f _ _ _ => f + | Decl.fdecl f _ _ _ => f + | Decl.extern f _ _ _ => f def params : Decl → Array Param -| Decl.fdecl _ xs _ _ => xs -| Decl.extern _ xs _ _ => xs + | Decl.fdecl _ xs _ _ => xs + | Decl.extern _ xs _ _ => xs def resultType : Decl → IRType -| Decl.fdecl _ _ t _ => t -| Decl.extern _ _ t _ => t + | Decl.fdecl _ _ t _ => t + | Decl.extern _ _ t _ => t def isExtern : Decl → Bool -| Decl.extern _ _ _ _ => true -| _ => false + | Decl.extern _ _ _ _ => true + | _ => false end Decl @[export lean_ir_mk_decl] def mkDecl (f : FunId) (xs : Array Param) (ty : IRType) (b : FnBody) : Decl := Decl.fdecl f xs ty b @[export lean_ir_mk_extern_decl] def mkExternDecl (f : FunId) (xs : Array Param) (ty : IRType) (e : ExternAttrData) : Decl := -Decl.extern f xs ty e + Decl.extern f xs ty e open Std (RBTree RBTree.empty RBMap) /-- Set of variable and join point names -/ abbrev IndexSet := RBTree Index Index.lt -instance vsetInh : Inhabited IndexSet := ⟨{}⟩ +instance : Inhabited IndexSet := ⟨{}⟩ def mkIndexSet (idx : Index) : IndexSet := -RBTree.empty.insert idx + RBTree.empty.insert idx inductive LocalContextEntry -| param : IRType → LocalContextEntry -| localVar : IRType → Expr → LocalContextEntry -| joinPoint : Array Param → FnBody → LocalContextEntry + | param : IRType → LocalContextEntry + | localVar : IRType → Expr → LocalContextEntry + | joinPoint : Array Param → FnBody → LocalContextEntry abbrev LocalContext := RBMap Index LocalContextEntry Index.lt def LocalContext.addLocal (ctx : LocalContext) (x : VarId) (t : IRType) (v : Expr) : LocalContext := -ctx.insert x.idx (LocalContextEntry.localVar t v) + ctx.insert x.idx (LocalContextEntry.localVar t v) def LocalContext.addJP (ctx : LocalContext) (j : JoinPointId) (xs : Array Param) (b : FnBody) : LocalContext := -ctx.insert j.idx (LocalContextEntry.joinPoint xs b) + ctx.insert j.idx (LocalContextEntry.joinPoint xs b) def LocalContext.addParam (ctx : LocalContext) (p : Param) : LocalContext := -ctx.insert p.x.idx (LocalContextEntry.param p.ty) + ctx.insert p.x.idx (LocalContextEntry.param p.ty) def LocalContext.addParams (ctx : LocalContext) (ps : Array Param) : LocalContext := -ps.foldl LocalContext.addParam ctx + ps.foldl LocalContext.addParam ctx def LocalContext.isJP (ctx : LocalContext) (idx : Index) : Bool := -match ctx.find? idx with -| some (LocalContextEntry.joinPoint _ _) => true -| other => false + match ctx.find? idx with + | some (LocalContextEntry.joinPoint _ _) => true + | other => false def LocalContext.getJPBody (ctx : LocalContext) (j : JoinPointId) : Option FnBody := -match ctx.find? j.idx with -| some (LocalContextEntry.joinPoint _ b) => some b -| other => none + match ctx.find? j.idx with + | some (LocalContextEntry.joinPoint _ b) => some b + | other => none def LocalContext.getJPParams (ctx : LocalContext) (j : JoinPointId) : Option (Array Param) := -match ctx.find? j.idx with -| some (LocalContextEntry.joinPoint ys _) => some ys -| other => none + match ctx.find? j.idx with + | some (LocalContextEntry.joinPoint ys _) => some ys + | other => none def LocalContext.isParam (ctx : LocalContext) (idx : Index) : Bool := -match ctx.find? idx with -| some (LocalContextEntry.param _) => true -| other => false + match ctx.find? idx with + | some (LocalContextEntry.param _) => true + | other => false def LocalContext.isLocalVar (ctx : LocalContext) (idx : Index) : Bool := -match ctx.find? idx with -| some (LocalContextEntry.localVar _ _) => true -| other => false + match ctx.find? idx with + | some (LocalContextEntry.localVar _ _) => true + | other => false def LocalContext.contains (ctx : LocalContext) (idx : Index) : Bool := -Std.RBMap.contains ctx idx + Std.RBMap.contains ctx idx def LocalContext.eraseJoinPointDecl (ctx : LocalContext) (j : JoinPointId) : LocalContext := -ctx.erase j.idx + ctx.erase j.idx def LocalContext.getType (ctx : LocalContext) (x : VarId) : Option IRType := -match ctx.find? x.idx with -| some (LocalContextEntry.param t) => some t -| some (LocalContextEntry.localVar t _) => some t -| other => none + match ctx.find? x.idx with + | some (LocalContextEntry.param t) => some t + | some (LocalContextEntry.localVar t _) => some t + | other => none def LocalContext.getValue (ctx : LocalContext) (x : VarId) : Option Expr := -match ctx.find? x.idx with -| some (LocalContextEntry.localVar _ v) => some v -| other => none + match ctx.find? x.idx with + | some (LocalContextEntry.localVar _ v) => some v + | other => none abbrev IndexRenaming := RBMap Index Index Index.lt class HasAlphaEqv (α : Type) := -(aeqv : IndexRenaming → α → α → Bool) + (aeqv : IndexRenaming → α → α → Bool) export HasAlphaEqv (aeqv) def VarId.alphaEqv (ρ : IndexRenaming) (v₁ v₂ : VarId) : Bool := -match ρ.find? v₁.idx with -| some v => v == v₂.idx -| none => v₁ == v₂ + match ρ.find? v₁.idx with + | some v => v == v₂.idx + | none => v₁ == v₂ -instance VarId.hasAeqv : HasAlphaEqv VarId := ⟨VarId.alphaEqv⟩ +instance : HasAlphaEqv VarId := ⟨VarId.alphaEqv⟩ def Arg.alphaEqv (ρ : IndexRenaming) : Arg → Arg → Bool -| Arg.var v₁, Arg.var v₂ => aeqv ρ v₁ v₂ -| Arg.irrelevant, Arg.irrelevant => true -| _, _ => false + | Arg.var v₁, Arg.var v₂ => aeqv ρ v₁ v₂ + | Arg.irrelevant, Arg.irrelevant => true + | _, _ => false -instance Arg.hasAeqv : HasAlphaEqv Arg := ⟨Arg.alphaEqv⟩ +instance : HasAlphaEqv Arg := ⟨Arg.alphaEqv⟩ def args.alphaEqv (ρ : IndexRenaming) (args₁ args₂ : Array Arg) : Bool := -Array.isEqv args₁ args₂ (fun a b => aeqv ρ a b) + Array.isEqv args₁ args₂ (fun a b => aeqv ρ a b) -instance args.hasAeqv : HasAlphaEqv (Array Arg) := ⟨args.alphaEqv⟩ +instance: HasAlphaEqv (Array Arg) := ⟨args.alphaEqv⟩ def Expr.alphaEqv (ρ : IndexRenaming) : Expr → Expr → Bool -| Expr.ctor i₁ ys₁, Expr.ctor i₂ ys₂ => i₁ == i₂ && aeqv ρ ys₁ ys₂ -| Expr.reset n₁ x₁, Expr.reset n₂ x₂ => n₁ == n₂ && aeqv ρ x₁ x₂ -| Expr.reuse x₁ i₁ u₁ ys₁, Expr.reuse x₂ i₂ u₂ ys₂ => aeqv ρ x₁ x₂ && i₁ == i₂ && u₁ == u₂ && aeqv ρ ys₁ ys₂ -| Expr.proj i₁ x₁, Expr.proj i₂ x₂ => i₁ == i₂ && aeqv ρ x₁ x₂ -| Expr.uproj i₁ x₁, Expr.uproj i₂ x₂ => i₁ == i₂ && aeqv ρ x₁ x₂ -| Expr.sproj n₁ o₁ x₁, Expr.sproj n₂ o₂ x₂ => n₁ == n₂ && o₁ == o₂ && aeqv ρ x₁ x₂ -| Expr.fap c₁ ys₁, Expr.fap c₂ ys₂ => c₁ == c₂ && aeqv ρ ys₁ ys₂ -| Expr.pap c₁ ys₁, Expr.pap c₂ ys₂ => c₁ == c₂ && aeqv ρ ys₁ ys₂ -| Expr.ap x₁ ys₁, Expr.ap x₂ ys₂ => aeqv ρ x₁ x₂ && aeqv ρ ys₁ ys₂ -| Expr.box ty₁ x₁, Expr.box ty₂ x₂ => ty₁ == ty₂ && aeqv ρ x₁ x₂ -| Expr.unbox x₁, Expr.unbox x₂ => aeqv ρ x₁ x₂ -| Expr.lit v₁, Expr.lit v₂ => v₁ == v₂ -| Expr.isShared x₁, Expr.isShared x₂ => aeqv ρ x₁ x₂ -| Expr.isTaggedPtr x₁, Expr.isTaggedPtr x₂ => aeqv ρ x₁ x₂ -| _, _ => false + | Expr.ctor i₁ ys₁, Expr.ctor i₂ ys₂ => i₁ == i₂ && aeqv ρ ys₁ ys₂ + | Expr.reset n₁ x₁, Expr.reset n₂ x₂ => n₁ == n₂ && aeqv ρ x₁ x₂ + | Expr.reuse x₁ i₁ u₁ ys₁, Expr.reuse x₂ i₂ u₂ ys₂ => aeqv ρ x₁ x₂ && i₁ == i₂ && u₁ == u₂ && aeqv ρ ys₁ ys₂ + | Expr.proj i₁ x₁, Expr.proj i₂ x₂ => i₁ == i₂ && aeqv ρ x₁ x₂ + | Expr.uproj i₁ x₁, Expr.uproj i₂ x₂ => i₁ == i₂ && aeqv ρ x₁ x₂ + | Expr.sproj n₁ o₁ x₁, Expr.sproj n₂ o₂ x₂ => n₁ == n₂ && o₁ == o₂ && aeqv ρ x₁ x₂ + | Expr.fap c₁ ys₁, Expr.fap c₂ ys₂ => c₁ == c₂ && aeqv ρ ys₁ ys₂ + | Expr.pap c₁ ys₁, Expr.pap c₂ ys₂ => c₁ == c₂ && aeqv ρ ys₁ ys₂ + | Expr.ap x₁ ys₁, Expr.ap x₂ ys₂ => aeqv ρ x₁ x₂ && aeqv ρ ys₁ ys₂ + | Expr.box ty₁ x₁, Expr.box ty₂ x₂ => ty₁ == ty₂ && aeqv ρ x₁ x₂ + | Expr.unbox x₁, Expr.unbox x₂ => aeqv ρ x₁ x₂ + | Expr.lit v₁, Expr.lit v₂ => v₁ == v₂ + | Expr.isShared x₁, Expr.isShared x₂ => aeqv ρ x₁ x₂ + | Expr.isTaggedPtr x₁, Expr.isTaggedPtr x₂ => aeqv ρ x₁ x₂ + | _, _ => false -instance Expr.hasAeqv : HasAlphaEqv Expr:= ⟨Expr.alphaEqv⟩ +instance : HasAlphaEqv Expr:= ⟨Expr.alphaEqv⟩ def addVarRename (ρ : IndexRenaming) (x₁ x₂ : Nat) := -if x₁ == x₂ then ρ else ρ.insert x₁ x₂ + if x₁ == x₂ then ρ else ρ.insert x₁ x₂ def addParamRename (ρ : IndexRenaming) (p₁ p₂ : Param) : Option IndexRenaming := -if p₁.ty == p₂.ty && p₁.borrow = p₂.borrow then some (addVarRename ρ p₁.x.idx p₂.x.idx) -else none + if p₁.ty == p₂.ty && p₁.borrow = p₂.borrow then some (addVarRename ρ p₁.x.idx p₂.x.idx) + else none def addParamsRename (ρ : IndexRenaming) (ps₁ ps₂ : Array Param) : Option IndexRenaming := -if ps₁.size != ps₂.size then none -else Array.foldl₂ (fun ρ p₁ p₂ => do let ρ ← ρ; addParamRename ρ p₁ p₂) (some ρ) ps₁ ps₂ + if ps₁.size != ps₂.size then none + else Array.foldl₂ (fun ρ p₁ p₂ => do let ρ ← ρ; addParamRename ρ p₁ p₂) (some ρ) ps₁ ps₂ partial def FnBody.alphaEqv : IndexRenaming → FnBody → FnBody → Bool -| ρ, FnBody.vdecl x₁ t₁ v₁ b₁, FnBody.vdecl x₂ t₂ v₂ b₂ => t₁ == t₂ && aeqv ρ v₁ v₂ && alphaEqv (addVarRename ρ x₁.idx x₂.idx) b₁ b₂ -| ρ, FnBody.jdecl j₁ ys₁ v₁ b₁, FnBody.jdecl j₂ ys₂ v₂ b₂ => match addParamsRename ρ ys₁ ys₂ with - | some ρ' => alphaEqv ρ' v₁ v₂ && alphaEqv (addVarRename ρ j₁.idx j₂.idx) b₁ b₂ - | none => false -| ρ, FnBody.set x₁ i₁ y₁ b₁, FnBody.set x₂ i₂ y₂ b₂ => aeqv ρ x₁ x₂ && i₁ == i₂ && aeqv ρ y₁ y₂ && alphaEqv ρ b₁ b₂ -| ρ, FnBody.uset x₁ i₁ y₁ b₁, FnBody.uset x₂ i₂ y₂ b₂ => aeqv ρ x₁ x₂ && i₁ == i₂ && aeqv ρ y₁ y₂ && alphaEqv ρ b₁ b₂ -| ρ, FnBody.sset x₁ i₁ o₁ y₁ t₁ b₁, FnBody.sset x₂ i₂ o₂ y₂ t₂ b₂ => - aeqv ρ x₁ x₂ && i₁ = i₂ && o₁ = o₂ && aeqv ρ y₁ y₂ && t₁ == t₂ && alphaEqv ρ b₁ b₂ -| ρ, FnBody.setTag x₁ i₁ b₁, FnBody.setTag x₂ i₂ b₂ => aeqv ρ x₁ x₂ && i₁ == i₂ && alphaEqv ρ b₁ b₂ -| ρ, FnBody.inc x₁ n₁ c₁ p₁ b₁, FnBody.inc x₂ n₂ c₂ p₂ b₂ => aeqv ρ x₁ x₂ && n₁ == n₂ && c₁ == c₂ && p₁ == p₂ && alphaEqv ρ b₁ b₂ -| ρ, FnBody.dec x₁ n₁ c₁ p₁ b₁, FnBody.dec x₂ n₂ c₂ p₂ b₂ => aeqv ρ x₁ x₂ && n₁ == n₂ && c₁ == c₂ && p₁ == p₂ && alphaEqv ρ b₁ b₂ -| ρ, FnBody.del x₁ b₁, FnBody.del x₂ b₂ => aeqv ρ x₁ x₂ && alphaEqv ρ b₁ b₂ -| ρ, FnBody.mdata m₁ b₁, FnBody.mdata m₂ b₂ => m₁ == m₂ && alphaEqv ρ b₁ b₂ -| ρ, FnBody.case n₁ x₁ _ alts₁, FnBody.case n₂ x₂ _ alts₂ => n₁ == n₂ && aeqv ρ x₁ x₂ && Array.isEqv alts₁ alts₂ (fun alt₁ alt₂ => - match alt₁, alt₂ with - | Alt.ctor i₁ b₁, Alt.ctor i₂ b₂ => i₁ == i₂ && alphaEqv ρ b₁ b₂ - | Alt.default b₁, Alt.default b₂ => alphaEqv ρ b₁ b₂ - | _, _ => false) -| ρ, FnBody.jmp j₁ ys₁, FnBody.jmp j₂ ys₂ => j₁ == j₂ && aeqv ρ ys₁ ys₂ -| ρ, FnBody.ret x₁, FnBody.ret x₂ => aeqv ρ x₁ x₂ -| _, FnBody.unreachable, FnBody.unreachable => true -| _, _, _ => false + | ρ, FnBody.vdecl x₁ t₁ v₁ b₁, FnBody.vdecl x₂ t₂ v₂ b₂ => t₁ == t₂ && aeqv ρ v₁ v₂ && alphaEqv (addVarRename ρ x₁.idx x₂.idx) b₁ b₂ + | ρ, FnBody.jdecl j₁ ys₁ v₁ b₁, FnBody.jdecl j₂ ys₂ v₂ b₂ => match addParamsRename ρ ys₁ ys₂ with + | some ρ' => alphaEqv ρ' v₁ v₂ && alphaEqv (addVarRename ρ j₁.idx j₂.idx) b₁ b₂ + | none => false + | ρ, FnBody.set x₁ i₁ y₁ b₁, FnBody.set x₂ i₂ y₂ b₂ => aeqv ρ x₁ x₂ && i₁ == i₂ && aeqv ρ y₁ y₂ && alphaEqv ρ b₁ b₂ + | ρ, FnBody.uset x₁ i₁ y₁ b₁, FnBody.uset x₂ i₂ y₂ b₂ => aeqv ρ x₁ x₂ && i₁ == i₂ && aeqv ρ y₁ y₂ && alphaEqv ρ b₁ b₂ + | ρ, FnBody.sset x₁ i₁ o₁ y₁ t₁ b₁, FnBody.sset x₂ i₂ o₂ y₂ t₂ b₂ => + aeqv ρ x₁ x₂ && i₁ = i₂ && o₁ = o₂ && aeqv ρ y₁ y₂ && t₁ == t₂ && alphaEqv ρ b₁ b₂ + | ρ, FnBody.setTag x₁ i₁ b₁, FnBody.setTag x₂ i₂ b₂ => aeqv ρ x₁ x₂ && i₁ == i₂ && alphaEqv ρ b₁ b₂ + | ρ, FnBody.inc x₁ n₁ c₁ p₁ b₁, FnBody.inc x₂ n₂ c₂ p₂ b₂ => aeqv ρ x₁ x₂ && n₁ == n₂ && c₁ == c₂ && p₁ == p₂ && alphaEqv ρ b₁ b₂ + | ρ, FnBody.dec x₁ n₁ c₁ p₁ b₁, FnBody.dec x₂ n₂ c₂ p₂ b₂ => aeqv ρ x₁ x₂ && n₁ == n₂ && c₁ == c₂ && p₁ == p₂ && alphaEqv ρ b₁ b₂ + | ρ, FnBody.del x₁ b₁, FnBody.del x₂ b₂ => aeqv ρ x₁ x₂ && alphaEqv ρ b₁ b₂ + | ρ, FnBody.mdata m₁ b₁, FnBody.mdata m₂ b₂ => m₁ == m₂ && alphaEqv ρ b₁ b₂ + | ρ, FnBody.case n₁ x₁ _ alts₁, FnBody.case n₂ x₂ _ alts₂ => n₁ == n₂ && aeqv ρ x₁ x₂ && Array.isEqv alts₁ alts₂ (fun alt₁ alt₂ => + match alt₁, alt₂ with + | Alt.ctor i₁ b₁, Alt.ctor i₂ b₂ => i₁ == i₂ && alphaEqv ρ b₁ b₂ + | Alt.default b₁, Alt.default b₂ => alphaEqv ρ b₁ b₂ + | _, _ => false) + | ρ, FnBody.jmp j₁ ys₁, FnBody.jmp j₂ ys₂ => j₁ == j₂ && aeqv ρ ys₁ ys₂ + | ρ, FnBody.ret x₁, FnBody.ret x₂ => aeqv ρ x₁ x₂ + | _, FnBody.unreachable, FnBody.unreachable => true + | _, _, _ => false def FnBody.beq (b₁ b₂ : FnBody) : Bool := -FnBody.alphaEqv ∅ b₁ b₂ + FnBody.alphaEqv ∅ b₁ b₂ -instance FnBody.HasBeq : HasBeq FnBody := ⟨FnBody.beq⟩ +instance : HasBeq FnBody := ⟨FnBody.beq⟩ abbrev VarIdSet := RBTree VarId (fun x y => x.idx < y.idx) -namespace VarIdSet instance : Inhabited VarIdSet := ⟨{}⟩ -end VarIdSet def mkIf (x : VarId) (t e : FnBody) : FnBody := -FnBody.case `Bool x IRType.uint8 #[ - Alt.ctor {name := `Bool.false, cidx := 0, size := 0, usize := 0, ssize := 0} e, - Alt.ctor {name := `Bool.true, cidx := 1, size := 0, usize := 0, ssize := 0} t -] + FnBody.case `Bool x IRType.uint8 #[ + Alt.ctor {name := `Bool.false, cidx := 0, size := 0, usize := 0, ssize := 0} e, + Alt.ctor {name := `Bool.true, cidx := 1, size := 0, usize := 0, ssize := 0} t + ] end Lean.IR diff --git a/src/Lean/Compiler/IR/Format.lean b/src/Lean/Compiler/IR/Format.lean index 8a8cde5995..d56494ac07 100644 --- a/src/Lean/Compiler/IR/Format.lean +++ b/src/Lean/Compiler/IR/Format.lean @@ -11,127 +11,127 @@ namespace Lean namespace IR private def formatArg : Arg → Format -| Arg.var id => format id -| Arg.irrelevant => "◾" + | Arg.var id => format id + | Arg.irrelevant => "◾" -instance argHasFormat : HasFormat Arg := ⟨formatArg⟩ +instance : HasFormat Arg := ⟨formatArg⟩ def formatArray {α : Type} [HasFormat α] (args : Array α) : Format := -args.foldl (fun r a => r ++ " " ++ format a) Format.nil + args.foldl (fun r a => r ++ " " ++ format a) Format.nil private def formatLitVal : LitVal → Format -| LitVal.num v => format v -| LitVal.str v => format (repr v) + | LitVal.num v => format v + | LitVal.str v => format (repr v) -instance litValHasFormat : HasFormat LitVal := ⟨formatLitVal⟩ +instance : HasFormat LitVal := ⟨formatLitVal⟩ private def formatCtorInfo : CtorInfo → Format -| { name := name, cidx := cidx, usize := usize, ssize := ssize, .. } => do - let r := f!"ctor_{cidx}" - if usize > 0 || ssize > 0 then - r := f!"{r}.{usize}.{ssize}" - if name != Name.anonymous then - r := f!"{r}[{name}]" - r + | { name := name, cidx := cidx, usize := usize, ssize := ssize, .. } => do + let r := f!"ctor_{cidx}" + if usize > 0 || ssize > 0 then + r := f!"{r}.{usize}.{ssize}" + if name != Name.anonymous then + r := f!"{r}[{name}]" + r -instance ctorInfoHasFormat : HasFormat CtorInfo := ⟨formatCtorInfo⟩ +instance : HasFormat CtorInfo := ⟨formatCtorInfo⟩ private def formatExpr : Expr → Format -| Expr.ctor i ys => format i ++ formatArray ys -| Expr.reset n x => "reset[" ++ format n ++ "] " ++ format x -| Expr.reuse x i u ys => "reuse" ++ (if u then "!" else "") ++ " " ++ format x ++ " in " ++ format i ++ formatArray ys -| Expr.proj i x => "proj[" ++ format i ++ "] " ++ format x -| Expr.uproj i x => "uproj[" ++ format i ++ "] " ++ format x -| Expr.sproj n o x => "sproj[" ++ format n ++ ", " ++ format o ++ "] " ++ format x -| Expr.fap c ys => format c ++ formatArray ys -| Expr.pap c ys => "pap " ++ format c ++ formatArray ys -| Expr.ap x ys => "app " ++ format x ++ formatArray ys -| Expr.box _ x => "box " ++ format x -| Expr.unbox x => "unbox " ++ format x -| Expr.lit v => format v -| Expr.isShared x => "isShared " ++ format x -| Expr.isTaggedPtr x => "isTaggedPtr " ++ format x + | Expr.ctor i ys => format i ++ formatArray ys + | Expr.reset n x => "reset[" ++ format n ++ "] " ++ format x + | Expr.reuse x i u ys => "reuse" ++ (if u then "!" else "") ++ " " ++ format x ++ " in " ++ format i ++ formatArray ys + | Expr.proj i x => "proj[" ++ format i ++ "] " ++ format x + | Expr.uproj i x => "uproj[" ++ format i ++ "] " ++ format x + | Expr.sproj n o x => "sproj[" ++ format n ++ ", " ++ format o ++ "] " ++ format x + | Expr.fap c ys => format c ++ formatArray ys + | Expr.pap c ys => "pap " ++ format c ++ formatArray ys + | Expr.ap x ys => "app " ++ format x ++ formatArray ys + | Expr.box _ x => "box " ++ format x + | Expr.unbox x => "unbox " ++ format x + | Expr.lit v => format v + | Expr.isShared x => "isShared " ++ format x + | Expr.isTaggedPtr x => "isTaggedPtr " ++ format x -instance exprHasFormat : HasFormat Expr := ⟨formatExpr⟩ -instance exprHasToString : HasToString Expr := ⟨fun e => Format.pretty (format e)⟩ +instance : HasFormat Expr := ⟨formatExpr⟩ +instance : HasToString Expr := ⟨fun e => Format.pretty (format e)⟩ private partial def formatIRType : IRType → Format -| IRType.float => "float" -| IRType.uint8 => "u8" -| IRType.uint16 => "u16" -| IRType.uint32 => "u32" -| IRType.uint64 => "u64" -| IRType.usize => "usize" -| IRType.irrelevant => "◾" -| IRType.object => "obj" -| IRType.tobject => "tobj" -| IRType.struct _ tys => "struct " ++ Format.bracket "{" (@Format.joinSep _ ⟨formatIRType⟩ tys.toList ", ") "}" -| IRType.union _ tys => "union " ++ Format.bracket "{" (@Format.joinSep _ ⟨formatIRType⟩ tys.toList ", ") "}" + | IRType.float => "float" + | IRType.uint8 => "u8" + | IRType.uint16 => "u16" + | IRType.uint32 => "u32" + | IRType.uint64 => "u64" + | IRType.usize => "usize" + | IRType.irrelevant => "◾" + | IRType.object => "obj" + | IRType.tobject => "tobj" + | IRType.struct _ tys => "struct " ++ Format.bracket "{" (@Format.joinSep _ ⟨formatIRType⟩ tys.toList ", ") "}" + | IRType.union _ tys => "union " ++ Format.bracket "{" (@Format.joinSep _ ⟨formatIRType⟩ tys.toList ", ") "}" -instance typeHasFormat : HasFormat IRType := ⟨formatIRType⟩ -instance typeHasToString : HasToString IRType := ⟨toString ∘ format⟩ +instance : HasFormat IRType := ⟨formatIRType⟩ +instance : HasToString IRType := ⟨toString ∘ format⟩ private def formatParam : Param → Format -| { x := name, borrow := b, ty := ty } => "(" ++ format name ++ " : " ++ (if b then "@& " else "") ++ format ty ++ ")" + | { x := name, borrow := b, ty := ty } => "(" ++ format name ++ " : " ++ (if b then "@& " else "") ++ format ty ++ ")" -instance paramHasFormat : HasFormat Param := ⟨formatParam⟩ +instance : HasFormat Param := ⟨formatParam⟩ def formatAlt (fmt : FnBody → Format) (indent : Nat) : Alt → Format -| Alt.ctor i b => format i.name ++ " →" ++ Format.nest indent (Format.line ++ fmt b) -| Alt.default b => "default →" ++ Format.nest indent (Format.line ++ fmt b) + | Alt.ctor i b => format i.name ++ " →" ++ Format.nest indent (Format.line ++ fmt b) + | Alt.default b => "default →" ++ Format.nest indent (Format.line ++ fmt b) def formatParams (ps : Array Param) : Format := -formatArray ps + formatArray ps @[export lean_ir_format_fn_body_head] def formatFnBodyHead : FnBody → Format -| FnBody.vdecl x ty e b => "let " ++ format x ++ " : " ++ format ty ++ " := " ++ format e -| FnBody.jdecl j xs v b => format j ++ formatParams xs ++ " := ..." -| FnBody.set x i y b => "set " ++ format x ++ "[" ++ format i ++ "] := " ++ format y -| FnBody.uset x i y b => "uset " ++ format x ++ "[" ++ format i ++ "] := " ++ format y -| FnBody.sset x i o y ty b => "sset " ++ format x ++ "[" ++ format i ++ ", " ++ format o ++ "] : " ++ format ty ++ " := " ++ format y -| FnBody.setTag x cidx b => "setTag " ++ format x ++ " := " ++ format cidx -| FnBody.inc x n c _ b => "inc" ++ (if n != 1 then Format.sbracket (format n) else "") ++ " " ++ format x -| FnBody.dec x n c _ b => "dec" ++ (if n != 1 then Format.sbracket (format n) else "") ++ " " ++ format x -| FnBody.del x b => "del " ++ format x -| FnBody.mdata d b => "mdata " ++ format d -| FnBody.case tid x xType cs => "case " ++ format x ++ " of ..." -| FnBody.jmp j ys => "jmp " ++ format j ++ formatArray ys -| FnBody.ret x => "ret " ++ format x -| FnBody.unreachable => "⊥" - -partial def formatFnBody (fnBody : FnBody) (indent : Nat := 2) : Format := -let rec loop : FnBody → Format - | FnBody.vdecl x ty e b => "let " ++ format x ++ " : " ++ format ty ++ " := " ++ format e ++ ";" ++ Format.line ++ loop b - | FnBody.jdecl j xs v b => format j ++ formatParams xs ++ " :=" ++ Format.nest indent (Format.line ++ loop v) ++ ";" ++ Format.line ++ loop b - | FnBody.set x i y b => "set " ++ format x ++ "[" ++ format i ++ "] := " ++ format y ++ ";" ++ Format.line ++ loop b - | FnBody.uset x i y b => "uset " ++ format x ++ "[" ++ format i ++ "] := " ++ format y ++ ";" ++ Format.line ++ loop b - | FnBody.sset x i o y ty b => "sset " ++ format x ++ "[" ++ format i ++ ", " ++ format o ++ "] : " ++ format ty ++ " := " ++ format y ++ ";" ++ Format.line ++ loop b - | FnBody.setTag x cidx b => "setTag " ++ format x ++ " := " ++ format cidx ++ ";" ++ Format.line ++ loop b - | FnBody.inc x n c _ b => "inc" ++ (if n != 1 then Format.sbracket (format n) else "") ++ " " ++ format x ++ ";" ++ Format.line ++ loop b - | FnBody.dec x n c _ b => "dec" ++ (if n != 1 then Format.sbracket (format n) else "") ++ " " ++ format x ++ ";" ++ Format.line ++ loop b - | FnBody.del x b => "del " ++ format x ++ ";" ++ Format.line ++ loop b - | FnBody.mdata d b => "mdata " ++ format d ++ ";" ++ Format.line ++ loop b - | FnBody.case tid x xType cs => "case " ++ format x ++ " : " ++ format xType ++ " of" ++ cs.foldl (fun r c => r ++ Format.line ++ formatAlt loop indent c) Format.nil + | FnBody.vdecl x ty e b => "let " ++ format x ++ " : " ++ format ty ++ " := " ++ format e + | FnBody.jdecl j xs v b => format j ++ formatParams xs ++ " := ..." + | FnBody.set x i y b => "set " ++ format x ++ "[" ++ format i ++ "] := " ++ format y + | FnBody.uset x i y b => "uset " ++ format x ++ "[" ++ format i ++ "] := " ++ format y + | FnBody.sset x i o y ty b => "sset " ++ format x ++ "[" ++ format i ++ ", " ++ format o ++ "] : " ++ format ty ++ " := " ++ format y + | FnBody.setTag x cidx b => "setTag " ++ format x ++ " := " ++ format cidx + | FnBody.inc x n c _ b => "inc" ++ (if n != 1 then Format.sbracket (format n) else "") ++ " " ++ format x + | FnBody.dec x n c _ b => "dec" ++ (if n != 1 then Format.sbracket (format n) else "") ++ " " ++ format x + | FnBody.del x b => "del " ++ format x + | FnBody.mdata d b => "mdata " ++ format d + | FnBody.case tid x xType cs => "case " ++ format x ++ " of ..." | FnBody.jmp j ys => "jmp " ++ format j ++ formatArray ys | FnBody.ret x => "ret " ++ format x | FnBody.unreachable => "⊥" -loop fnBody -instance fnBodyHasFormat : HasFormat FnBody := ⟨formatFnBody⟩ -instance fnBodyHasToString : HasToString FnBody := ⟨fun b => (format b).pretty⟩ +partial def formatFnBody (fnBody : FnBody) (indent : Nat := 2) : Format := + let rec loop : FnBody → Format + | FnBody.vdecl x ty e b => "let " ++ format x ++ " : " ++ format ty ++ " := " ++ format e ++ ";" ++ Format.line ++ loop b + | FnBody.jdecl j xs v b => format j ++ formatParams xs ++ " :=" ++ Format.nest indent (Format.line ++ loop v) ++ ";" ++ Format.line ++ loop b + | FnBody.set x i y b => "set " ++ format x ++ "[" ++ format i ++ "] := " ++ format y ++ ";" ++ Format.line ++ loop b + | FnBody.uset x i y b => "uset " ++ format x ++ "[" ++ format i ++ "] := " ++ format y ++ ";" ++ Format.line ++ loop b + | FnBody.sset x i o y ty b => "sset " ++ format x ++ "[" ++ format i ++ ", " ++ format o ++ "] : " ++ format ty ++ " := " ++ format y ++ ";" ++ Format.line ++ loop b + | FnBody.setTag x cidx b => "setTag " ++ format x ++ " := " ++ format cidx ++ ";" ++ Format.line ++ loop b + | FnBody.inc x n c _ b => "inc" ++ (if n != 1 then Format.sbracket (format n) else "") ++ " " ++ format x ++ ";" ++ Format.line ++ loop b + | FnBody.dec x n c _ b => "dec" ++ (if n != 1 then Format.sbracket (format n) else "") ++ " " ++ format x ++ ";" ++ Format.line ++ loop b + | FnBody.del x b => "del " ++ format x ++ ";" ++ Format.line ++ loop b + | FnBody.mdata d b => "mdata " ++ format d ++ ";" ++ Format.line ++ loop b + | FnBody.case tid x xType cs => "case " ++ format x ++ " : " ++ format xType ++ " of" ++ cs.foldl (fun r c => r ++ Format.line ++ formatAlt loop indent c) Format.nil + | FnBody.jmp j ys => "jmp " ++ format j ++ formatArray ys + | FnBody.ret x => "ret " ++ format x + | FnBody.unreachable => "⊥" + loop fnBody + +instance : HasFormat FnBody := ⟨formatFnBody⟩ +instance : HasToString FnBody := ⟨fun b => (format b).pretty⟩ def formatDecl (decl : Decl) (indent : Nat := 2) : Format := -match decl with -| Decl.fdecl f xs ty b => "def " ++ format f ++ formatParams xs ++ format " : " ++ format ty ++ " :=" ++ Format.nest indent (Format.line ++ formatFnBody b indent) -| Decl.extern f xs ty _ => "extern " ++ format f ++ formatParams xs ++ format " : " ++ format ty + match decl with + | Decl.fdecl f xs ty b => "def " ++ format f ++ formatParams xs ++ format " : " ++ format ty ++ " :=" ++ Format.nest indent (Format.line ++ formatFnBody b indent) + | Decl.extern f xs ty _ => "extern " ++ format f ++ formatParams xs ++ format " : " ++ format ty -instance declHasFormat : HasFormat Decl := ⟨formatDecl⟩ +instance : HasFormat Decl := ⟨formatDecl⟩ @[export lean_ir_decl_to_string] def declToString (d : Decl) : String := -(format d).pretty + (format d).pretty -instance declHasToString : HasToString Decl := ⟨declToString⟩ +instance : HasToString Decl := ⟨declToString⟩ end Lean.IR diff --git a/src/Lean/Compiler/IR/LiveVars.lean b/src/Lean/Compiler/IR/LiveVars.lean index 3a573a1c4a..1e93e66075 100644 --- a/src/Lean/Compiler/IR/LiveVars.lean +++ b/src/Lean/Compiler/IR/LiveVars.lean @@ -37,36 +37,36 @@ namespace IsLive -/ abbrev M := StateM LocalContext -@[inline] def visitVar (w : Index) (x : VarId) : M Bool := pure (HasIndex.visitVar w x) -@[inline] def visitJP (w : Index) (x : JoinPointId) : M Bool := pure (HasIndex.visitJP w x) -@[inline] def visitArg (w : Index) (a : Arg) : M Bool := pure (HasIndex.visitArg w a) -@[inline] def visitArgs (w : Index) (as : Array Arg) : M Bool := pure (HasIndex.visitArgs w as) -@[inline] def visitExpr (w : Index) (e : Expr) : M Bool := pure (HasIndex.visitExpr w e) +abbrev visitVar (w : Index) (x : VarId) : M Bool := pure (HasIndex.visitVar w x) +abbrev visitJP (w : Index) (x : JoinPointId) : M Bool := pure (HasIndex.visitJP w x) +abbrev visitArg (w : Index) (a : Arg) : M Bool := pure (HasIndex.visitArg w a) +abbrev visitArgs (w : Index) (as : Array Arg) : M Bool := pure (HasIndex.visitArgs w as) +abbrev visitExpr (w : Index) (e : Expr) : M Bool := pure (HasIndex.visitExpr w e) partial def visitFnBody (w : Index) : FnBody → M Bool -| FnBody.vdecl x _ v b => visitExpr w v <||> visitFnBody w b -| FnBody.jdecl j ys v b => visitFnBody w v <||> visitFnBody w b -| FnBody.set x _ y b => visitVar w x <||> visitArg w y <||> visitFnBody w b -| FnBody.uset x _ y b => visitVar w x <||> visitVar w y <||> visitFnBody w b -| FnBody.sset x _ _ y _ b => visitVar w x <||> visitVar w y <||> visitFnBody w b -| FnBody.setTag x _ b => visitVar w x <||> visitFnBody w b -| FnBody.inc x _ _ _ b => visitVar w x <||> visitFnBody w b -| FnBody.dec x _ _ _ b => visitVar w x <||> visitFnBody w b -| FnBody.del x b => visitVar w x <||> visitFnBody w b -| FnBody.mdata _ b => visitFnBody w b -| FnBody.jmp j ys => visitArgs w ys <||> do - let ctx ← get - match ctx.getJPBody j with - | some b => - -- `j` is not a local join point since we assume we cannot shadow join point declarations. - -- Instead of marking the join points that we have already been visited, we permanently remove `j` from the context. - set (ctx.eraseJoinPointDecl j) *> visitFnBody w b - | none => - -- `j` must be a local join point. So do nothing since we have already visite its body. - pure false -| FnBody.ret x => visitArg w x -| FnBody.case _ x _ alts => visitVar w x <||> alts.anyM (fun alt => visitFnBody w alt.body) -| FnBody.unreachable => pure false + | FnBody.vdecl x _ v b => visitExpr w v <||> visitFnBody w b + | FnBody.jdecl j ys v b => visitFnBody w v <||> visitFnBody w b + | FnBody.set x _ y b => visitVar w x <||> visitArg w y <||> visitFnBody w b + | FnBody.uset x _ y b => visitVar w x <||> visitVar w y <||> visitFnBody w b + | FnBody.sset x _ _ y _ b => visitVar w x <||> visitVar w y <||> visitFnBody w b + | FnBody.setTag x _ b => visitVar w x <||> visitFnBody w b + | FnBody.inc x _ _ _ b => visitVar w x <||> visitFnBody w b + | FnBody.dec x _ _ _ b => visitVar w x <||> visitFnBody w b + | FnBody.del x b => visitVar w x <||> visitFnBody w b + | FnBody.mdata _ b => visitFnBody w b + | FnBody.jmp j ys => visitArgs w ys <||> do + let ctx ← get + match ctx.getJPBody j with + | some b => + -- `j` is not a local join point since we assume we cannot shadow join point declarations. + -- Instead of marking the join points that we have already been visited, we permanently remove `j` from the context. + set (ctx.eraseJoinPointDecl j) *> visitFnBody w b + | none => + -- `j` must be a local join point. So do nothing since we have already visite its body. + pure false + | FnBody.ret x => visitArg w x + | FnBody.case _ x _ alts => visitVar w x <||> alts.anyM (fun alt => visitFnBody w alt.body) + | FnBody.unreachable => pure false end IsLive @@ -77,15 +77,15 @@ end IsLive Recall that we say that a join point `j` is free in `b` if `b` contains `FnBody.jmp j ys` and `j` is not local. -/ def FnBody.hasLiveVar (b : FnBody) (ctx : LocalContext) (x : VarId) : Bool := -(IsLive.visitFnBody x.idx b).run' ctx + (IsLive.visitFnBody x.idx b).run' ctx abbrev LiveVarSet := VarIdSet abbrev JPLiveVarMap := Std.RBMap JoinPointId LiveVarSet (fun j₁ j₂ => j₁.idx < j₂.idx) -instance LiveVarSet.inhabited : Inhabited LiveVarSet := ⟨{}⟩ +instance : Inhabited LiveVarSet := ⟨{}⟩ def mkLiveVarSet (x : VarId) : LiveVarSet := -Std.RBTree.empty.insert x + Std.RBTree.empty.insert x namespace LiveVars @@ -93,70 +93,77 @@ abbrev Collector := LiveVarSet → LiveVarSet @[inline] private def skip : Collector := fun s => s @[inline] private def collectVar (x : VarId) : Collector := fun s => s.insert x + private def collectArg : Arg → Collector -| Arg.var x => collectVar x -| irrelevant => skip -@[specialize] private def collectArray {α : Type} (as : Array α) (f : α → Collector) : Collector := -fun s => as.foldl (fun s a => f a s) s + | Arg.var x => collectVar x + | irrelevant => skip + +@[specialize] private def collectArray {α : Type} (as : Array α) (f : α → Collector) : Collector := fun s => + as.foldl (fun s a => f a s) s + private def collectArgs (as : Array Arg) : Collector := -collectArray as collectArg + collectArray as collectArg + private def accumulate (s' : LiveVarSet) : Collector := -fun s => s'.fold (fun s x => s.insert x) s + fun s => s'.fold (fun s x => s.insert x) s + private def collectJP (m : JPLiveVarMap) (j : JoinPointId) : Collector := -match m.find? j with -| some xs => accumulate xs -| none => skip -- unreachable for well-formed code -private def bindVar (x : VarId) : Collector := -fun s => s.erase x -private def bindParams (ps : Array Param) : Collector := -fun s => ps.foldl (fun s p => s.erase p.x) s + match m.find? j with + | some xs => accumulate xs + | none => skip -- unreachable for well-formed code + +private def bindVar (x : VarId) : Collector := fun s => + s.erase x + +private def bindParams (ps : Array Param) : Collector := fun s => + ps.foldl (fun s p => s.erase p.x) s def collectExpr : Expr → Collector -| Expr.ctor _ ys => collectArgs ys -| Expr.reset _ x => collectVar x -| Expr.reuse x _ _ ys => collectVar x ∘ collectArgs ys -| Expr.proj _ x => collectVar x -| Expr.uproj _ x => collectVar x -| Expr.sproj _ _ x => collectVar x -| Expr.fap _ ys => collectArgs ys -| Expr.pap _ ys => collectArgs ys -| Expr.ap x ys => collectVar x ∘ collectArgs ys -| Expr.box _ x => collectVar x -| Expr.unbox x => collectVar x -| Expr.lit v => skip -| Expr.isShared x => collectVar x -| Expr.isTaggedPtr x => collectVar x + | Expr.ctor _ ys => collectArgs ys + | Expr.reset _ x => collectVar x + | Expr.reuse x _ _ ys => collectVar x ∘ collectArgs ys + | Expr.proj _ x => collectVar x + | Expr.uproj _ x => collectVar x + | Expr.sproj _ _ x => collectVar x + | Expr.fap _ ys => collectArgs ys + | Expr.pap _ ys => collectArgs ys + | Expr.ap x ys => collectVar x ∘ collectArgs ys + | Expr.box _ x => collectVar x + | Expr.unbox x => collectVar x + | Expr.lit v => skip + | Expr.isShared x => collectVar x + | Expr.isTaggedPtr x => collectVar x partial def collectFnBody : FnBody → JPLiveVarMap → Collector -| FnBody.vdecl x _ v b, m => collectExpr v ∘ bindVar x ∘ collectFnBody b m -| FnBody.jdecl j ys v b, m => - let jLiveVars := (bindParams ys ∘ collectFnBody v m) {}; - let m := m.insert j jLiveVars; - collectFnBody b m -| FnBody.set x _ y b, m => collectVar x ∘ collectArg y ∘ collectFnBody b m -| FnBody.setTag x _ b, m => collectVar x ∘ collectFnBody b m -| FnBody.uset x _ y b, m => collectVar x ∘ collectVar y ∘ collectFnBody b m -| FnBody.sset x _ _ y _ b, m => collectVar x ∘ collectVar y ∘ collectFnBody b m -| FnBody.inc x _ _ _ b, m => collectVar x ∘ collectFnBody b m -| FnBody.dec x _ _ _ b, m => collectVar x ∘ collectFnBody b m -| FnBody.del x b, m => collectVar x ∘ collectFnBody b m -| FnBody.mdata _ b, m => collectFnBody b m -| FnBody.ret x, m => collectArg x -| FnBody.case _ x _ alts, m => collectVar x ∘ collectArray alts (fun alt => collectFnBody alt.body m) -| FnBody.unreachable, m => skip -| FnBody.jmp j xs, m => collectJP m j ∘ collectArgs xs + | FnBody.vdecl x _ v b, m => collectExpr v ∘ bindVar x ∘ collectFnBody b m + | FnBody.jdecl j ys v b, m => + let jLiveVars := (bindParams ys ∘ collectFnBody v m) {}; + let m := m.insert j jLiveVars; + collectFnBody b m + | FnBody.set x _ y b, m => collectVar x ∘ collectArg y ∘ collectFnBody b m + | FnBody.setTag x _ b, m => collectVar x ∘ collectFnBody b m + | FnBody.uset x _ y b, m => collectVar x ∘ collectVar y ∘ collectFnBody b m + | FnBody.sset x _ _ y _ b, m => collectVar x ∘ collectVar y ∘ collectFnBody b m + | FnBody.inc x _ _ _ b, m => collectVar x ∘ collectFnBody b m + | FnBody.dec x _ _ _ b, m => collectVar x ∘ collectFnBody b m + | FnBody.del x b, m => collectVar x ∘ collectFnBody b m + | FnBody.mdata _ b, m => collectFnBody b m + | FnBody.ret x, m => collectArg x + | FnBody.case _ x _ alts, m => collectVar x ∘ collectArray alts (fun alt => collectFnBody alt.body m) + | FnBody.unreachable, m => skip + | FnBody.jmp j xs, m => collectJP m j ∘ collectArgs xs def updateJPLiveVarMap (j : JoinPointId) (ys : Array Param) (v : FnBody) (m : JPLiveVarMap) : JPLiveVarMap := -let jLiveVars := (bindParams ys ∘ collectFnBody v m) {}; -m.insert j jLiveVars + let jLiveVars := (bindParams ys ∘ collectFnBody v m) {}; + m.insert j jLiveVars end LiveVars def updateLiveVars (e : Expr) (v : LiveVarSet) : LiveVarSet := -LiveVars.collectExpr e v + LiveVars.collectExpr e v def collectLiveVars (b : FnBody) (m : JPLiveVarMap) (v : LiveVarSet := {}) : LiveVarSet := -LiveVars.collectFnBody b m v + LiveVars.collectFnBody b m v export LiveVars (updateJPLiveVarMap) diff --git a/src/Lean/Compiler/IR/NormIds.lean b/src/Lean/Compiler/IR/NormIds.lean index ace145d24e..741247fdac 100644 --- a/src/Lean/Compiler/IR/NormIds.lean +++ b/src/Lean/Compiler/IR/NormIds.lean @@ -11,171 +11,168 @@ namespace Lean.IR.UniqueIds abbrev M := StateT IndexSet Id def checkId (id : Index) : M Bool := -modifyGet fun s => - if s.contains id then (false, s) - else (true, s.insert id) + modifyGet fun s => + if s.contains id then (false, s) + else (true, s.insert id) def checkParams (ps : Array Param) : M Bool := -ps.allM $ fun p => checkId p.x.idx + ps.allM $ fun p => checkId p.x.idx partial def checkFnBody : FnBody → M Bool -| FnBody.vdecl x _ _ b => checkId x.idx <&&> checkFnBody b -| FnBody.jdecl j ys _ b => checkId j.idx <&&> checkParams ys <&&> checkFnBody b -| FnBody.case _ _ _ alts => alts.allM fun alt => checkFnBody alt.body -| b => if b.isTerminal then pure true else checkFnBody b.body + | FnBody.vdecl x _ _ b => checkId x.idx <&&> checkFnBody b + | FnBody.jdecl j ys _ b => checkId j.idx <&&> checkParams ys <&&> checkFnBody b + | FnBody.case _ _ _ alts => alts.allM fun alt => checkFnBody alt.body + | b => if b.isTerminal then pure true else checkFnBody b.body partial def checkDecl : Decl → M Bool -| Decl.fdecl _ xs _ b => checkParams xs <&&> checkFnBody b -| Decl.extern _ xs _ _ => checkParams xs + | Decl.fdecl _ xs _ b => checkParams xs <&&> checkFnBody b + | Decl.extern _ xs _ _ => checkParams xs end UniqueIds /- Return true if variable, parameter and join point ids are unique -/ def Decl.uniqueIds (d : Decl) : Bool := -(UniqueIds.checkDecl d).run' {} + (UniqueIds.checkDecl d).run' {} namespace NormalizeIds abbrev M := ReaderT IndexRenaming Id -def normIndex (x : Index) : M Index := -fun m => match m.find? x with -| some y => y -| none => x +def normIndex (x : Index) : M Index := fun m => + match m.find? x with + | some y => y + | none => x def normVar (x : VarId) : M VarId := -VarId.mk <$> normIndex x.idx + VarId.mk <$> normIndex x.idx def normJP (x : JoinPointId) : M JoinPointId := -JoinPointId.mk <$> normIndex x.idx + JoinPointId.mk <$> normIndex x.idx def normArg : Arg → M Arg -| Arg.var x => Arg.var <$> normVar x -| other => pure other + | Arg.var x => Arg.var <$> normVar x + | other => pure other -def normArgs (as : Array Arg) : M (Array Arg) := -fun m => as.map $ fun a => normArg a m +def normArgs (as : Array Arg) : M (Array Arg) := fun m => + as.map $ fun a => normArg a m def normExpr : Expr → M Expr -| Expr.ctor c ys, m => Expr.ctor c (normArgs ys m) -| Expr.reset n x, m => Expr.reset n (normVar x m) -| Expr.reuse x c u ys, m => Expr.reuse (normVar x m) c u (normArgs ys m) -| Expr.proj i x, m => Expr.proj i (normVar x m) -| Expr.uproj i x, m => Expr.uproj i (normVar x m) -| Expr.sproj n o x, m => Expr.sproj n o (normVar x m) -| Expr.fap c ys, m => Expr.fap c (normArgs ys m) -| Expr.pap c ys, m => Expr.pap c (normArgs ys m) -| Expr.ap x ys, m => Expr.ap (normVar x m) (normArgs ys m) -| Expr.box t x, m => Expr.box t (normVar x m) -| Expr.unbox x, m => Expr.unbox (normVar x m) -| Expr.isShared x, m => Expr.isShared (normVar x m) -| Expr.isTaggedPtr x, m => Expr.isTaggedPtr (normVar x m) -| e@(Expr.lit v), m => e + | Expr.ctor c ys, m => Expr.ctor c (normArgs ys m) + | Expr.reset n x, m => Expr.reset n (normVar x m) + | Expr.reuse x c u ys, m => Expr.reuse (normVar x m) c u (normArgs ys m) + | Expr.proj i x, m => Expr.proj i (normVar x m) + | Expr.uproj i x, m => Expr.uproj i (normVar x m) + | Expr.sproj n o x, m => Expr.sproj n o (normVar x m) + | Expr.fap c ys, m => Expr.fap c (normArgs ys m) + | Expr.pap c ys, m => Expr.pap c (normArgs ys m) + | Expr.ap x ys, m => Expr.ap (normVar x m) (normArgs ys m) + | Expr.box t x, m => Expr.box t (normVar x m) + | Expr.unbox x, m => Expr.unbox (normVar x m) + | Expr.isShared x, m => Expr.isShared (normVar x m) + | Expr.isTaggedPtr x, m => Expr.isTaggedPtr (normVar x m) + | e@(Expr.lit v), m => e abbrev N := ReaderT IndexRenaming (StateM Nat) -@[inline] def withVar {α : Type} (x : VarId) (k : VarId → N α) : N α := -fun m => do +@[inline] def withVar {α : Type} (x : VarId) (k : VarId → N α) : N α := fun m => do let n ← getModify (fun n => n + 1) k { idx := n } (m.insert x.idx n) -@[inline] def withJP {α : Type} (x : JoinPointId) (k : JoinPointId → N α) : N α := -fun m => do +@[inline] def withJP {α : Type} (x : JoinPointId) (k : JoinPointId → N α) : N α := fun m => do let n ← getModify (fun n => n + 1) k { idx := n } (m.insert x.idx n) -@[inline] def withParams {α : Type} (ps : Array Param) (k : Array Param → N α) : N α := -fun m => do +@[inline] def withParams {α : Type} (ps : Array Param) (k : Array Param → N α) : N α := fun m => do let m ← ps.foldlM (init := m) fun m p => do let n ← getModify fun n => n + 1 pure $ m.insert p.x.idx n let ps := ps.map fun p => { p with x := normVar p.x m } k ps m -instance MtoN : MonadLift M N := -⟨fun x m => pure $ x m⟩ +instance : MonadLift M N := + ⟨fun x m => pure $ x m⟩ partial def normFnBody : FnBody → N FnBody -| FnBody.vdecl x t v b => do let v ← normExpr v; withVar x fun x => do return FnBody.vdecl x t v (← normFnBody b) -| FnBody.jdecl j ys v b => do - let (ys, v) ← withParams ys fun ys => do let v ← normFnBody v; pure (ys, v) - withJP j fun j => do return FnBody.jdecl j ys v (← normFnBody b) -| FnBody.set x i y b => do return FnBody.set (← normVar x) i (← normArg y) (← normFnBody b) -| FnBody.uset x i y b => do return FnBody.uset (← normVar x) i (← normVar y) (← normFnBody b) -| FnBody.sset x i o y t b => do return FnBody.sset (← normVar x) i o (← normVar y) t (← normFnBody b) -| FnBody.setTag x i b => do return FnBody.setTag (← normVar x) i (← normFnBody b) -| FnBody.inc x n c p b => do return FnBody.inc (← normVar x) n c p (← normFnBody b) -| FnBody.dec x n c p b => do return FnBody.dec (← normVar x) n c p (← normFnBody b) -| FnBody.del x b => do return FnBody.del (← normVar x) (← normFnBody b) -| FnBody.mdata d b => do return FnBody.mdata d (← normFnBody b) -| FnBody.case tid x xType alts => do - let x ← normVar x - let alts ← alts.mapM fun alt => alt.mmodifyBody normFnBody - return FnBody.case tid x xType alts -| FnBody.jmp j ys => do return FnBody.jmp (← normJP j) (← normArgs ys) -| FnBody.ret x => do return FnBody.ret (← normArg x) -| FnBody.unreachable => pure FnBody.unreachable + | FnBody.vdecl x t v b => do let v ← normExpr v; withVar x fun x => do return FnBody.vdecl x t v (← normFnBody b) + | FnBody.jdecl j ys v b => do + let (ys, v) ← withParams ys fun ys => do let v ← normFnBody v; pure (ys, v) + withJP j fun j => do return FnBody.jdecl j ys v (← normFnBody b) + | FnBody.set x i y b => do return FnBody.set (← normVar x) i (← normArg y) (← normFnBody b) + | FnBody.uset x i y b => do return FnBody.uset (← normVar x) i (← normVar y) (← normFnBody b) + | FnBody.sset x i o y t b => do return FnBody.sset (← normVar x) i o (← normVar y) t (← normFnBody b) + | FnBody.setTag x i b => do return FnBody.setTag (← normVar x) i (← normFnBody b) + | FnBody.inc x n c p b => do return FnBody.inc (← normVar x) n c p (← normFnBody b) + | FnBody.dec x n c p b => do return FnBody.dec (← normVar x) n c p (← normFnBody b) + | FnBody.del x b => do return FnBody.del (← normVar x) (← normFnBody b) + | FnBody.mdata d b => do return FnBody.mdata d (← normFnBody b) + | FnBody.case tid x xType alts => do + let x ← normVar x + let alts ← alts.mapM fun alt => alt.mmodifyBody normFnBody + return FnBody.case tid x xType alts + | FnBody.jmp j ys => do return FnBody.jmp (← normJP j) (← normArgs ys) + | FnBody.ret x => do return FnBody.ret (← normArg x) + | FnBody.unreachable => pure FnBody.unreachable def normDecl : Decl → N Decl -| Decl.fdecl f xs t b => withParams xs fun xs => Decl.fdecl f xs t <$> normFnBody b -| other => pure other + | Decl.fdecl f xs t b => withParams xs fun xs => Decl.fdecl f xs t <$> normFnBody b + | other => pure other end NormalizeIds /- Create a declaration equivalent to `d` s.t. `d.normalizeIds.uniqueIds == true` -/ def Decl.normalizeIds (d : Decl) : Decl := -(NormalizeIds.normDecl d {}).run' 1 + (NormalizeIds.normDecl d {}).run' 1 /- Apply a function `f : VarId → VarId` to variable occurrences. The following functions assume the IR code does not have variable shadowing. -/ namespace MapVars @[inline] def mapArg (f : VarId → VarId) : Arg → Arg -| Arg.var x => Arg.var (f x) -| a => a + | Arg.var x => Arg.var (f x) + | a => a @[specialize] def mapArgs (f : VarId → VarId) (as : Array Arg) : Array Arg := -as.map (mapArg f) + as.map (mapArg f) @[specialize] def mapExpr (f : VarId → VarId) : Expr → Expr -| Expr.ctor c ys => Expr.ctor c (mapArgs f ys) -| Expr.reset n x => Expr.reset n (f x) -| Expr.reuse x c u ys => Expr.reuse (f x) c u (mapArgs f ys) -| Expr.proj i x => Expr.proj i (f x) -| Expr.uproj i x => Expr.uproj i (f x) -| Expr.sproj n o x => Expr.sproj n o (f x) -| Expr.fap c ys => Expr.fap c (mapArgs f ys) -| Expr.pap c ys => Expr.pap c (mapArgs f ys) -| Expr.ap x ys => Expr.ap (f x) (mapArgs f ys) -| Expr.box t x => Expr.box t (f x) -| Expr.unbox x => Expr.unbox (f x) -| Expr.isShared x => Expr.isShared (f x) -| Expr.isTaggedPtr x => Expr.isTaggedPtr (f x) -| e@(Expr.lit v) => e + | Expr.ctor c ys => Expr.ctor c (mapArgs f ys) + | Expr.reset n x => Expr.reset n (f x) + | Expr.reuse x c u ys => Expr.reuse (f x) c u (mapArgs f ys) + | Expr.proj i x => Expr.proj i (f x) + | Expr.uproj i x => Expr.uproj i (f x) + | Expr.sproj n o x => Expr.sproj n o (f x) + | Expr.fap c ys => Expr.fap c (mapArgs f ys) + | Expr.pap c ys => Expr.pap c (mapArgs f ys) + | Expr.ap x ys => Expr.ap (f x) (mapArgs f ys) + | Expr.box t x => Expr.box t (f x) + | Expr.unbox x => Expr.unbox (f x) + | Expr.isShared x => Expr.isShared (f x) + | Expr.isTaggedPtr x => Expr.isTaggedPtr (f x) + | e@(Expr.lit v) => e @[specialize] partial def mapFnBody (f : VarId → VarId) : FnBody → FnBody -| FnBody.vdecl x t v b => FnBody.vdecl x t (mapExpr f v) (mapFnBody f b) -| FnBody.jdecl j ys v b => FnBody.jdecl j ys (mapFnBody f v) (mapFnBody f b) -| FnBody.set x i y b => FnBody.set (f x) i (mapArg f y) (mapFnBody f b) -| FnBody.setTag x i b => FnBody.setTag (f x) i (mapFnBody f b) -| FnBody.uset x i y b => FnBody.uset (f x) i (f y) (mapFnBody f b) -| FnBody.sset x i o y t b => FnBody.sset (f x) i o (f y) t (mapFnBody f b) -| FnBody.inc x n c p b => FnBody.inc (f x) n c p (mapFnBody f b) -| FnBody.dec x n c p b => FnBody.dec (f x) n c p (mapFnBody f b) -| FnBody.del x b => FnBody.del (f x) (mapFnBody f b) -| FnBody.mdata d b => FnBody.mdata d (mapFnBody f b) -| FnBody.case tid x xType alts => FnBody.case tid (f x) xType (alts.map fun alt => alt.modifyBody (mapFnBody f)) -| FnBody.jmp j ys => FnBody.jmp j (mapArgs f ys) -| FnBody.ret x => FnBody.ret (mapArg f x) -| FnBody.unreachable => FnBody.unreachable + | FnBody.vdecl x t v b => FnBody.vdecl x t (mapExpr f v) (mapFnBody f b) + | FnBody.jdecl j ys v b => FnBody.jdecl j ys (mapFnBody f v) (mapFnBody f b) + | FnBody.set x i y b => FnBody.set (f x) i (mapArg f y) (mapFnBody f b) + | FnBody.setTag x i b => FnBody.setTag (f x) i (mapFnBody f b) + | FnBody.uset x i y b => FnBody.uset (f x) i (f y) (mapFnBody f b) + | FnBody.sset x i o y t b => FnBody.sset (f x) i o (f y) t (mapFnBody f b) + | FnBody.inc x n c p b => FnBody.inc (f x) n c p (mapFnBody f b) + | FnBody.dec x n c p b => FnBody.dec (f x) n c p (mapFnBody f b) + | FnBody.del x b => FnBody.del (f x) (mapFnBody f b) + | FnBody.mdata d b => FnBody.mdata d (mapFnBody f b) + | FnBody.case tid x xType alts => FnBody.case tid (f x) xType (alts.map fun alt => alt.modifyBody (mapFnBody f)) + | FnBody.jmp j ys => FnBody.jmp j (mapArgs f ys) + | FnBody.ret x => FnBody.ret (mapArg f x) + | FnBody.unreachable => FnBody.unreachable end MapVars @[inline] def FnBody.mapVars (f : VarId → VarId) (b : FnBody) : FnBody := -MapVars.mapFnBody f b + MapVars.mapFnBody f b /- Replace `x` with `y` in `b`. This function assumes `b` does not shadow `x` -/ def FnBody.replaceVar (x y : VarId) (b : FnBody) : FnBody := -b.mapVars fun z => if x == z then y else z + b.mapVars fun z => if x == z then y else z end Lean.IR diff --git a/src/Lean/Data/Format.lean b/src/Lean/Data/Format.lean index 6183187d82..c1032486d2 100644 --- a/src/Lean/Data/Format.lean +++ b/src/Lean/Data/Format.lean @@ -62,7 +62,7 @@ private structure SpaceResult := (foundFlattenedHardLine : Bool := false) (space : Nat := 0) -instance SpaceResult.inhabited : Inhabited SpaceResult := ⟨{}⟩ +instance : Inhabited SpaceResult := ⟨{}⟩ @[inline] private def merge (w : Nat) (r₁ : SpaceResult) (r₂ : Nat → SpaceResult) : SpaceResult := if r₁.space > w || r₁.foundLine then r₁ diff --git a/src/Lean/Data/JsonRpc.lean b/src/Lean/Data/JsonRpc.lean index 6654a298d7..fd5bad3b4f 100644 --- a/src/Lean/Data/JsonRpc.lean +++ b/src/Lean/Data/JsonRpc.lean @@ -146,7 +146,7 @@ def aux4 (j : Json) : Option Message := do -- HACK: The implementation must be made up of several `auxN`s instead -- of one large block because of a bug in the compiler. -instance Message.hasFromJson : HasFromJson Message := ⟨fun j => do +instance : HasFromJson Message := ⟨fun j => do let "2.0" ← j.getObjVal? "jsonrpc" | none aux1 j <|> aux2 j <|> aux3 j <|> aux4 j⟩ diff --git a/src/Lean/Data/KVMap.lean b/src/Lean/Data/KVMap.lean index 7a0b4664a9..5a4dc25acd 100644 --- a/src/Lean/Data/KVMap.lean +++ b/src/Lean/Data/KVMap.lean @@ -34,7 +34,7 @@ def DataValue.sameCtor : DataValue → DataValue → Bool | DataValue.ofInt _, DataValue.ofInt _ => true | _, _ => false -instance DataValue.HasBeq : HasBeq DataValue := ⟨DataValue.beq⟩ +instance : HasBeq DataValue := ⟨DataValue.beq⟩ @[export lean_data_value_to_string] def DataValue.str : DataValue → String @@ -44,13 +44,13 @@ def DataValue.str : DataValue → String | DataValue.ofNat v => toString v | DataValue.ofInt v => toString v -instance DataValue.hasToString : HasToString DataValue := ⟨DataValue.str⟩ +instance : HasToString DataValue := ⟨DataValue.str⟩ -instance string2DataValue : Coe String DataValue := ⟨DataValue.ofString⟩ -instance bool2DataValue : Coe Bool DataValue := ⟨DataValue.ofBool⟩ -instance name2DataValue : Coe Name DataValue := ⟨DataValue.ofName⟩ -instance nat2DataValue : Coe Nat DataValue := ⟨DataValue.ofNat⟩ -instance int2DataValue : Coe Int DataValue := ⟨DataValue.ofInt⟩ +instance : Coe String DataValue := ⟨DataValue.ofString⟩ +instance : Coe Bool DataValue := ⟨DataValue.ofBool⟩ +instance : Coe Name DataValue := ⟨DataValue.ofName⟩ +instance : Coe Nat DataValue := ⟨DataValue.ofNat⟩ +instance : Coe Int DataValue := ⟨DataValue.ofInt⟩ /- Remark: we do not use RBMap here because we need to manipulate KVMap objects in C++ and RBMap is implemented in Lean. So, we use just a List until we can diff --git a/src/Lean/Data/Lsp/InitShutdown.lean b/src/Lean/Data/Lsp/InitShutdown.lean index 4dab52ccb6..0c8d860d05 100644 --- a/src/Lean/Data/Lsp/InitShutdown.lean +++ b/src/Lean/Data/Lsp/InitShutdown.lean @@ -21,7 +21,7 @@ structure ClientInfo := (name : String) (version? : Option String := none) -instance ClientInfo.hasFromJson : HasFromJson ClientInfo := ⟨fun j => do +instance : HasFromJson ClientInfo := ⟨fun j => do let name ← j.getObjValAs? String "name" let version? := j.getObjValAs? String "version" pure ⟨name, version?⟩⟩ @@ -84,9 +84,10 @@ structure InitializeResult := (capabilities : ServerCapabilities) (serverInfo? : Option ServerInfo := none) -instance InitializeResult.hasToJson : HasToJson InitializeResult := ⟨fun o => mkObj $ - ⟨"capabilities", toJson o.capabilities⟩ :: - opt "serverInfo" o.serverInfo?⟩ +instance : HasToJson InitializeResult := ⟨fun o => + mkObj $ + ⟨"capabilities", toJson o.capabilities⟩ :: + opt "serverInfo" o.serverInfo?⟩ end Lsp end Lean diff --git a/src/Lean/Data/Lsp/TextSync.lean b/src/Lean/Data/Lsp/TextSync.lean index 52191b53b1..227519921b 100644 --- a/src/Lean/Data/Lsp/TextSync.lean +++ b/src/Lean/Data/Lsp/TextSync.lean @@ -20,14 +20,14 @@ inductive TextDocumentSyncKind | full | incremental -instance TextDocumentSyncKind.hasFromJson : HasFromJson TextDocumentSyncKind := ⟨fun j => +instance : HasFromJson TextDocumentSyncKind := ⟨fun j => match j.getNat? with | some 0 => TextDocumentSyncKind.none | some 1 => TextDocumentSyncKind.full | some 2 => TextDocumentSyncKind.incremental | _ => none⟩ -instance TextDocumentSyncKind.hasToJson : HasToJson TextDocumentSyncKind := ⟨fun o => +instance : HasToJson TextDocumentSyncKind := ⟨fun o => match o with | TextDocumentSyncKind.none => 0 | TextDocumentSyncKind.full => 1 @@ -56,7 +56,7 @@ inductive TextDocumentContentChangeEvent | rangeChange (range : Range) (text : String) | fullChange (text : String) -instance TextDocumentContentChangeEvent.hasFromJson : HasFromJson TextDocumentContentChangeEvent := ⟨fun j => +instance : HasFromJson TextDocumentContentChangeEvent := ⟨fun j => (do let range ← j.getObjValAs? Range "range" let text ← j.getObjValAs? String "text" @@ -67,7 +67,7 @@ structure DidChangeTextDocumentParams := (textDocument : VersionedTextDocumentIdentifier) (contentChanges : Array TextDocumentContentChangeEvent) -instance DidChangeTextDocumentParams.hasFromJson : HasFromJson DidChangeTextDocumentParams := ⟨fun j => do +instance : HasFromJson DidChangeTextDocumentParams := ⟨fun j => do let textDocument ← j.getObjValAs? VersionedTextDocumentIdentifier "textDocument" let contentChanges ← j.getObjValAs? (Array TextDocumentContentChangeEvent) "contentChanges" pure ⟨textDocument, contentChanges⟩⟩ @@ -78,12 +78,12 @@ instance DidChangeTextDocumentParams.hasFromJson : HasFromJson DidChangeTextDocu structure SaveOptions := (includeText : Bool) -instance SaveOptions.hasToJson : HasToJson SaveOptions := ⟨fun o => +instance : HasToJson SaveOptions := ⟨fun o => mkObj $ [⟨"includeText", o.includeText⟩]⟩ structure DidCloseTextDocumentParams := (textDocument : TextDocumentIdentifier) -instance DidCloseTextDocumentParams.hasFromJson : HasFromJson DidCloseTextDocumentParams := ⟨fun j => +instance : HasFromJson DidCloseTextDocumentParams := ⟨fun j => DidCloseTextDocumentParams.mk <$> j.getObjValAs? TextDocumentIdentifier "textDocument"⟩ -- TODO: TextDocumentSyncClientCapabilities @@ -96,12 +96,13 @@ structure TextDocumentSyncOptions := (willSaveWaitUntil : Bool) (save? : Option SaveOptions := none) -instance TextDocumentSyncOptions.hasToJson : HasToJson TextDocumentSyncOptions := ⟨fun o => mkObj $ - opt "save" o.save? ++ [ - ⟨"openClose", toJson o.openClose⟩, - ⟨"change", toJson o.change⟩, - ⟨"willSave", toJson o.willSave⟩, - ⟨"willSaveWaitUntil", toJson o.willSaveWaitUntil⟩]⟩ +instance : HasToJson TextDocumentSyncOptions := ⟨fun o => + mkObj $ + opt "save" o.save? ++ [ + ⟨"openClose", toJson o.openClose⟩, + ⟨"change", toJson o.change⟩, + ⟨"willSave", toJson o.willSave⟩, + ⟨"willSaveWaitUntil", toJson o.willSaveWaitUntil⟩]⟩ end Lsp end Lean diff --git a/src/Lean/Data/Options.lean b/src/Lean/Data/Options.lean index 40f4f10e5a..d9c194da52 100644 --- a/src/Lean/Data/Options.lean +++ b/src/Lean/Data/Options.lean @@ -10,103 +10,108 @@ namespace Lean def Options := KVMap -namespace Options -def empty : Options := {} -instance : Inhabited Options := ⟨empty⟩ +def Options.empty : Options := {} +instance : Inhabited Options := ⟨Options.empty⟩ instance : HasToString Options := inferInstanceAs (HasToString KVMap) -end Options structure OptionDecl := -(defValue : DataValue) -(group : String := "") -(descr : String := "") + (defValue : DataValue) + (group : String := "") + (descr : String := "") def OptionDecls := NameMap OptionDecl -instance OptionDecls.inhabited : Inhabited OptionDecls := -⟨({} : NameMap OptionDecl)⟩ +instance : Inhabited OptionDecls := ⟨({} : NameMap OptionDecl)⟩ private def initOptionDeclsRef : IO (IO.Ref OptionDecls) := -IO.mkRef (mkNameMap OptionDecl) + IO.mkRef (mkNameMap OptionDecl) @[builtinInit initOptionDeclsRef] private constant optionDeclsRef : IO.Ref OptionDecls := arbitrary _ @[export lean_register_option] def registerOption (name : Name) (decl : OptionDecl) : IO Unit := do -let decls ← optionDeclsRef.get -if decls.contains name then - throw $ IO.userError s!"invalid option declaration '{name}', option already exists" -optionDeclsRef.set $ decls.insert name decl + let decls ← optionDeclsRef.get + if decls.contains name then + throw $ IO.userError s!"invalid option declaration '{name}', option already exists" + optionDeclsRef.set $ decls.insert name decl def getOptionDecls : IO OptionDecls := optionDeclsRef.get @[export lean_get_option_decls_array] def getOptionDeclsArray : IO (Array (Name × OptionDecl)) := do -let decls ← getOptionDecls -pure $ decls.fold - (fun (r : Array (Name × OptionDecl)) k v => r.push (k, v)) - #[] + let decls ← getOptionDecls + pure $ decls.fold + (fun (r : Array (Name × OptionDecl)) k v => r.push (k, v)) + #[] def getOptionDecl (name : Name) : IO OptionDecl := do -let decls ← getOptionDecls -let (some decl) ← pure (decls.find? name) | throw $ IO.userError s!"unknown option '{name}'" -pure decl + let decls ← getOptionDecls + let (some decl) ← pure (decls.find? name) | throw $ IO.userError s!"unknown option '{name}'" + pure decl def getOptionDefaulValue (name : Name) : IO DataValue := do -let decl ← getOptionDecl name -pure decl.defValue + let decl ← getOptionDecl name + pure decl.defValue def getOptionDescr (name : Name) : IO String := do -let decl ← getOptionDecl name -pure decl.descr + let decl ← getOptionDecl name + pure decl.descr def setOptionFromString (opts : Options) (entry : String) : IO Options := do -let ps := (entry.splitOn "=").map String.trim -let [key, val] ← pure ps | throw $ IO.userError "invalid configuration option entry, it must be of the form ' = '" -let key := mkNameSimple key -let defValue ← getOptionDefaulValue key -match defValue with -| DataValue.ofString v => pure $ opts.setString key val -| DataValue.ofBool v => - if key == `true then pure $ opts.setBool key true - else if key == `false then pure $ opts.setBool key false - else throw $ IO.userError s!"invalid Bool option value '{val}'" -| DataValue.ofName v => pure $ opts.setName key val.toName -| DataValue.ofNat v => - match val.toNat? with - | none => throw (IO.userError s!"invalid Nat option value '{val}'") - | some v => pure $ opts.setNat key v -| DataValue.ofInt v => - match val.toInt? with - | none => throw (IO.userError s!"invalid Int option value '{val}'") - | some v => pure $ opts.setInt key v + let ps := (entry.splitOn "=").map String.trim + let [key, val] ← pure ps | throw $ IO.userError "invalid configuration option entry, it must be of the form ' = '" + let key := mkNameSimple key + let defValue ← getOptionDefaulValue key + match defValue with + | DataValue.ofString v => pure $ opts.setString key val + | DataValue.ofBool v => + if key == `true then pure $ opts.setBool key true + else if key == `false then pure $ opts.setBool key false + else throw $ IO.userError s!"invalid Bool option value '{val}'" + | DataValue.ofName v => pure $ opts.setName key val.toName + | DataValue.ofNat v => + match val.toNat? with + | none => throw (IO.userError s!"invalid Nat option value '{val}'") + | some v => pure $ opts.setNat key v + | DataValue.ofInt v => + match val.toInt? with + | none => throw (IO.userError s!"invalid Int option value '{val}'") + | some v => pure $ opts.setInt key v -builtin_initialize registerOption `verbose { defValue := true, group := "", descr := "disable/enable verbose messages" } - -builtin_initialize registerOption `timeout { defValue := DataValue.ofNat 0, group := "", descr := "the (deterministic) timeout is measured as the maximum of memory allocations (in thousands) per task, the default is unbounded" } - -builtin_initialize registerOption `maxMemory { defValue := DataValue.ofNat 2048, group := "", descr := "maximum amount of memory available for Lean in megabytes" } +builtin_initialize + registerOption `verbose { + defValue := true, + group := "", + descr := "disable/enable verbose messages" + } + registerOption `timeout { + defValue := DataValue.ofNat 0, + group := "", + descr := "the (deterministic) timeout is measured as the maximum of memory allocations (in thousands) per task, the default is unbounded" + } + registerOption `maxMemory { + defValue := DataValue.ofNat 2048, + group := "", + descr := "maximum amount of memory available for Lean in megabytes" + } class MonadOptions (m : Type → Type) := -(getOptions : m Options) + (getOptions : m Options) export MonadOptions (getOptions) -instance monadOptsFromLift (m n) [MonadOptions m] [MonadLift m n] : MonadOptions n := -{ getOptions := liftM (getOptions : m _) } +instance (m n) [MonadOptions m] [MonadLift m n] : MonadOptions n := + { getOptions := liftM (getOptions : m _) } -section Methods - -variables {m : Type → Type} [Monad m] [MonadOptions m] +variables {m} [Monad m] [MonadOptions m] def getBoolOption (k : Name) (defValue := false) : m Bool := do -let opts ← getOptions -pure $ opts.getBool k defValue + let opts ← getOptions + pure $ opts.getBool k defValue def getNatOption (k : Name) (defValue := 0) : m Nat := do -let opts ← getOptions -pure $ opts.getNat k defValue + let opts ← getOptions + pure $ opts.getNat k defValue -end Methods end Lean diff --git a/src/Lean/Declaration.lean b/src/Lean/Declaration.lean index 8a8b75ad16..c9d82c33d8 100644 --- a/src/Lean/Declaration.lean +++ b/src/Lean/Declaration.lean @@ -32,108 +32,140 @@ Remark: the ReducibilityHints are not related to the attributes: reducible/irrel These attributes are used by the Elaborator. The ReducibilityHints are used by the kernel (and Elaborator). Moreover, the ReducibilityHints cannot be changed after a declaration is added to the kernel. -/ inductive ReducibilityHints -| opaque : ReducibilityHints -| «abbrev» : ReducibilityHints -| regular : UInt32 → ReducibilityHints + | opaque : ReducibilityHints + | «abbrev» : ReducibilityHints + | regular : UInt32 → ReducibilityHints @[export lean_mk_reducibility_hints_regular] -def mkReducibilityHintsRegularEx (h : UInt32) : ReducibilityHints := ReducibilityHints.regular h +def mkReducibilityHintsRegularEx (h : UInt32) : ReducibilityHints := + ReducibilityHints.regular h + @[export lean_reducibility_hints_get_height] def ReducibilityHints.getHeightEx (h : ReducibilityHints) : UInt32 := -match h with -| ReducibilityHints.regular h => h -| _ => 0 + match h with + | ReducibilityHints.regular h => h + | _ => 0 namespace ReducibilityHints instance : Inhabited ReducibilityHints := ⟨opaque⟩ def lt : ReducibilityHints → ReducibilityHints → Bool -| «abbrev», «abbrev» => false -| «abbrev», _ => true -| regular d₁, regular d₂ => d₁ < d₂ -| regular _, opaque => true -| _, _ => false + | «abbrev», «abbrev» => false + | «abbrev», _ => true + | regular d₁, regular d₂ => d₁ < d₂ + | regular _, opaque => true + | _, _ => false end ReducibilityHints /-- Base structure for `AxiomVal`, `DefinitionVal`, `TheoremVal`, `InductiveVal`, `ConstructorVal`, `RecursorVal` and `QuotVal`. -/ structure ConstantVal := -(name : Name) (lparams : List Name) (type : Expr) + (name : Name) + (lparams : List Name) + (type : Expr) -instance ConstantVal.inhabited : Inhabited ConstantVal := ⟨{ name := arbitrary _, lparams := arbitrary _, type := arbitrary _ }⟩ +instance : Inhabited ConstantVal := ⟨{ name := arbitrary _, lparams := arbitrary _, type := arbitrary _ }⟩ structure AxiomVal extends ConstantVal := -(isUnsafe : Bool) + (isUnsafe : Bool) @[export lean_mk_axiom_val] -def mkAxiomValEx (name : Name) (lparams : List Name) (type : Expr) (isUnsafe : Bool) : AxiomVal := -{ name := name, lparams := lparams, type := type, isUnsafe := isUnsafe } -@[export lean_axiom_val_is_unsafe] def AxiomVal.isUnsafeEx (v : AxiomVal) : Bool := v.isUnsafe +def mkAxiomValEx (name : Name) (lparams : List Name) (type : Expr) (isUnsafe : Bool) : AxiomVal := { + name := name, + lparams := lparams, + type := type, + isUnsafe := isUnsafe +} + +@[export lean_axiom_val_is_unsafe] def AxiomVal.isUnsafeEx (v : AxiomVal) : Bool := + v.isUnsafe structure DefinitionVal extends ConstantVal := -(value : Expr) (hints : ReducibilityHints) (isUnsafe : Bool) + (value : Expr) + (hints : ReducibilityHints) + (isUnsafe : Bool) @[export lean_mk_definition_val] -def mkDefinitionValEx (name : Name) (lparams : List Name) (type : Expr) (val : Expr) (hints : ReducibilityHints) (isUnsafe : Bool) : DefinitionVal := -{ name := name, lparams := lparams, type := type, value := val, hints := hints, isUnsafe := isUnsafe } -@[export lean_definition_val_is_unsafe] def DefinitionVal.isUnsafeEx (v : DefinitionVal) : Bool := v.isUnsafe +def mkDefinitionValEx (name : Name) (lparams : List Name) (type : Expr) (val : Expr) (hints : ReducibilityHints) (isUnsafe : Bool) : DefinitionVal := { + name := name, + lparams := lparams, + type := type, + value := val, + hints := hints, + isUnsafe := isUnsafe +} + +@[export lean_definition_val_is_unsafe] def DefinitionVal.isUnsafeEx (v : DefinitionVal) : Bool := + v.isUnsafe structure TheoremVal extends ConstantVal := -(value : Expr) + (value : Expr) /- Value for an opaque constant declaration `constant x : t := e` -/ structure OpaqueVal extends ConstantVal := -(value : Expr) (isUnsafe : Bool) + (value : Expr) + (isUnsafe : Bool) @[export lean_mk_opaque_val] -def mkOpaqueValEx (name : Name) (lparams : List Name) (type : Expr) (val : Expr) (isUnsafe : Bool) : OpaqueVal := -{ name := name, lparams := lparams, type := type, value := val, isUnsafe := isUnsafe } -@[export lean_opaque_val_is_unsafe] def OpaqueVal.isUnsafeEx (v : OpaqueVal) : Bool := v.isUnsafe +def mkOpaqueValEx (name : Name) (lparams : List Name) (type : Expr) (val : Expr) (isUnsafe : Bool) : OpaqueVal := { + name := name, + lparams := lparams, + type := type, + value := val, + isUnsafe := isUnsafe +} + +@[export lean_opaque_val_is_unsafe] def OpaqueVal.isUnsafeEx (v : OpaqueVal) : Bool := + v.isUnsafe structure Constructor := -(name : Name) (type : Expr) + (name : Name) + (type : Expr) structure InductiveType := -(name : Name) (type : Expr) (ctors : List Constructor) + (name : Name) + (type : Expr) + (ctors : List Constructor) /-- Declaration object that can be sent to the kernel. -/ inductive Declaration -| axiomDecl (val : AxiomVal) -| defnDecl (val : DefinitionVal) -| thmDecl (val : TheoremVal) -| opaqueDecl (val : OpaqueVal) -| quotDecl -| mutualDefnDecl (defns : List DefinitionVal) -- All definitions must be marked as `unsafe` -| inductDecl (lparams : List Name) (nparams : Nat) (types : List InductiveType) (isUnsafe : Bool) + | axiomDecl (val : AxiomVal) + | defnDecl (val : DefinitionVal) + | thmDecl (val : TheoremVal) + | opaqueDecl (val : OpaqueVal) + | quotDecl + | mutualDefnDecl (defns : List DefinitionVal) -- All definitions must be marked as `unsafe` + | inductDecl (lparams : List Name) (nparams : Nat) (types : List InductiveType) (isUnsafe : Bool) -instance Declaration.inhabited : Inhabited Declaration := ⟨Declaration.quotDecl⟩ +instance : Inhabited Declaration := ⟨Declaration.quotDecl⟩ @[export lean_mk_inductive_decl] def mkInductiveDeclEs (lparams : List Name) (nparams : Nat) (types : List InductiveType) (isUnsafe : Bool) : Declaration := -Declaration.inductDecl lparams nparams types isUnsafe + Declaration.inductDecl lparams nparams types isUnsafe + @[export lean_is_unsafe_inductive_decl] def Declaration.isUnsafeInductiveDeclEx : Declaration → Bool -| Declaration.inductDecl _ _ _ isUnsafe => isUnsafe -| _ => false + | Declaration.inductDecl _ _ _ isUnsafe => isUnsafe + | _ => false @[specialize] def Declaration.foldExprM {α} {m : Type → Type} [Monad m] (d : Declaration) (f : α → Expr → m α) (a : α) : m α := -match d with -| Declaration.quotDecl => pure a -| Declaration.axiomDecl { type := type, .. } => f a type -| Declaration.defnDecl { type := type, value := value, .. } => do let a ← f a type; f a value -| Declaration.opaqueDecl { type := type, value := value, .. } => do let a ← f a type; f a value -| Declaration.thmDecl { type := type, value := value, .. } => do let a ← f a type; f a value -| Declaration.mutualDefnDecl vals => vals.foldlM (fun a v => do let a ← f a v.type; f a v.value) a -| Declaration.inductDecl _ _ inductTypes _ => - inductTypes.foldlM - (fun a inductType => do - let a ← f a inductType.type - inductType.ctors.foldlM (fun a ctor => f a ctor.type) a) - a + match d with + | Declaration.quotDecl => pure a + | Declaration.axiomDecl { type := type, .. } => f a type + | Declaration.defnDecl { type := type, value := value, .. } => do let a ← f a type; f a value + | Declaration.opaqueDecl { type := type, value := value, .. } => do let a ← f a type; f a value + | Declaration.thmDecl { type := type, value := value, .. } => do let a ← f a type; f a value + | Declaration.mutualDefnDecl vals => vals.foldlM (fun a v => do let a ← f a v.type; f a v.value) a + | Declaration.inductDecl _ _ inductTypes _ => + inductTypes.foldlM + (fun a inductType => do + let a ← f a inductType.type + inductType.ctors.foldlM (fun a ctor => f a ctor.type) a) + a @[inline] def Declaration.forExprM {m : Type → Type} [Monad m] (d : Declaration) (f : Expr → m Unit) : m Unit := -d.foldExprM (fun _ a => f a) () + d.foldExprM (fun _ a => f a) () /-- The kernel compiles (mutual) inductive declarations (see `inductiveDecls`) into a set of - `Declaration.inductDecl` (for each inductive datatype in the mutual Declaration), @@ -146,163 +178,181 @@ d.foldExprM (fun _ a => f a) () A series of checks are performed by the kernel to check whether a `inductiveDecls` is valid or not. -/ structure InductiveVal extends ConstantVal := -(nparams : Nat) -- Number of parameters -(nindices : Nat) -- Number of indices -(all : List Name) -- List of all (including this one) inductive datatypes in the mutual declaration containing this one -(ctors : List Name) -- List of all constructors for this inductive datatype -(isRec : Bool) -- `true` Iff it is recursive -(isUnsafe : Bool) -(isReflexive : Bool) + (nparams : Nat) -- Number of parameters + (nindices : Nat) -- Number of indices + (all : List Name) -- List of all (including this one) inductive datatypes in the mutual declaration containing this one + (ctors : List Name) -- List of all constructors for this inductive datatype + (isRec : Bool) -- `true` Iff it is recursive + (isUnsafe : Bool) + (isReflexive : Bool) @[export lean_mk_inductive_val] def mkInductiveValEx (name : Name) (lparams : List Name) (type : Expr) (nparams nindices : Nat) - (all ctors : List Name) (isRec isUnsafe isReflexive : Bool) : InductiveVal := -{ name := name, lparams := lparams, type := type, nparams := nparams, nindices := nindices, all := all, ctors := ctors, - isRec := isRec, isUnsafe := isUnsafe, isReflexive := isReflexive } + (all ctors : List Name) (isRec isUnsafe isReflexive : Bool) : InductiveVal := { + name := name, + lparams := lparams, + type := type, + nparams := nparams, + nindices := nindices, + all := all, + ctors := ctors, + isRec := isRec, + isUnsafe := isUnsafe, + isReflexive := isReflexive +} + @[export lean_inductive_val_is_rec] def InductiveVal.isRecEx (v : InductiveVal) : Bool := v.isRec @[export lean_inductive_val_is_unsafe] def InductiveVal.isUnsafeEx (v : InductiveVal) : Bool := v.isUnsafe @[export lean_inductive_val_is_reflexive] def InductiveVal.isReflexiveEx (v : InductiveVal) : Bool := v.isReflexive -namespace InductiveVal -def nctors (v : InductiveVal) : Nat := v.ctors.length -end InductiveVal +def InductiveVal.nctors (v : InductiveVal) : Nat := v.ctors.length structure ConstructorVal extends ConstantVal := -(induct : Name) -- Inductive Type this Constructor is a member of -(cidx : Nat) -- Constructor index (i.e., Position in the inductive declaration) -(nparams : Nat) -- Number of parameters in inductive datatype `induct` -(nfields : Nat) -- Number of fields (i.e., arity - nparams) -(isUnsafe : Bool) + (induct : Name) -- Inductive Type this Constructor is a member of + (cidx : Nat) -- Constructor index (i.e., Position in the inductive declaration) + (nparams : Nat) -- Number of parameters in inductive datatype `induct` + (nfields : Nat) -- Number of fields (i.e., arity - nparams) + (isUnsafe : Bool) @[export lean_mk_constructor_val] -def mkConstructorValEx (name : Name) (lparams : List Name) (type : Expr) (induct : Name) (cidx nparams nfields : Nat) (isUnsafe : Bool) : ConstructorVal := -{ name := name, lparams := lparams, type := type, induct := induct, cidx := cidx, nparams := nparams, nfields := nfields, isUnsafe := isUnsafe } +def mkConstructorValEx (name : Name) (lparams : List Name) (type : Expr) (induct : Name) (cidx nparams nfields : Nat) (isUnsafe : Bool) : ConstructorVal := { + name := name, + lparams := lparams, + type := type, + induct := induct, + cidx := cidx, + nparams := nparams, + nfields := nfields, + isUnsafe := isUnsafe +} + @[export lean_constructor_val_is_unsafe] def ConstructorVal.isUnsafeEx (v : ConstructorVal) : Bool := v.isUnsafe -instance ConstructorVal.inhabited : Inhabited ConstructorVal := -⟨{ toConstantVal := arbitrary _, induct := arbitrary _, cidx := 0, nparams := 0, nfields := 0, isUnsafe := true }⟩ +instance : Inhabited ConstructorVal := + ⟨{ toConstantVal := arbitrary _, induct := arbitrary _, cidx := 0, nparams := 0, nfields := 0, isUnsafe := true }⟩ /-- Information for reducing a recursor -/ structure RecursorRule := -(ctor : Name) -- Reduction rule for this Constructor -(nfields : Nat) -- Number of fields (i.e., without counting inductive datatype parameters) -(rhs : Expr) -- Right hand side of the reduction rule + (ctor : Name) -- Reduction rule for this Constructor + (nfields : Nat) -- Number of fields (i.e., without counting inductive datatype parameters) + (rhs : Expr) -- Right hand side of the reduction rule structure RecursorVal extends ConstantVal := -(all : List Name) -- List of all inductive datatypes in the mutual declaration that generated this recursor -(nparams : Nat) -- Number of parameters -(nindices : Nat) -- Number of indices -(nmotives : Nat) -- Number of motives -(nminors : Nat) -- Number of minor premises -(rules : List RecursorRule) -- A reduction for each Constructor -(k : Bool) -- It supports K-like reduction -(isUnsafe : Bool) + (all : List Name) -- List of all inductive datatypes in the mutual declaration that generated this recursor + (nparams : Nat) -- Number of parameters + (nindices : Nat) -- Number of indices + (nmotives : Nat) -- Number of motives + (nminors : Nat) -- Number of minor premises + (rules : List RecursorRule) -- A reduction for each Constructor + (k : Bool) -- It supports K-like reduction + (isUnsafe : Bool) @[export lean_mk_recursor_val] def mkRecursorValEx (name : Name) (lparams : List Name) (type : Expr) (all : List Name) (nparams nindices nmotives nminors : Nat) - (rules : List RecursorRule) (k isUnsafe : Bool) : RecursorVal := -{ name := name, lparams := lparams, type := type, all := all, nparams := nparams, nindices := nindices, - nmotives := nmotives, nminors := nminors, rules := rules, k := k, isUnsafe := isUnsafe } + (rules : List RecursorRule) (k isUnsafe : Bool) : RecursorVal := { + name := name, lparams := lparams, type := type, all := all, nparams := nparams, nindices := nindices, + nmotives := nmotives, nminors := nminors, rules := rules, k := k, isUnsafe := isUnsafe +} + @[export lean_recursor_k] def RecursorVal.kEx (v : RecursorVal) : Bool := v.k @[export lean_recursor_is_unsafe] def RecursorVal.isUnsafeEx (v : RecursorVal) : Bool := v.isUnsafe -namespace RecursorVal -def getMajorIdx (v : RecursorVal) : Nat := -v.nparams + v.nmotives + v.nminors + v.nindices +def RecursorVal.getMajorIdx (v : RecursorVal) : Nat := + v.nparams + v.nmotives + v.nminors + v.nindices -def getInduct (v : RecursorVal) : Name := -v.name.getPrefix - -end RecursorVal +def RecursorVal.getInduct (v : RecursorVal) : Name := + v.name.getPrefix inductive QuotKind -| type -- `Quot` -| ctor -- `Quot.mk` -| lift -- `Quot.lift` -| ind -- `Quot.ind` + | type -- `Quot` + | ctor -- `Quot.mk` + | lift -- `Quot.lift` + | ind -- `Quot.ind` structure QuotVal extends ConstantVal := -(kind : QuotKind) + (kind : QuotKind) @[export lean_mk_quot_val] -def mkQuotValEx (name : Name) (lparams : List Name) (type : Expr) (kind : QuotKind) : QuotVal := -{ name := name, lparams := lparams, type := type, kind := kind } +def mkQuotValEx (name : Name) (lparams : List Name) (type : Expr) (kind : QuotKind) : QuotVal := { + name := name, lparams := lparams, type := type, kind := kind +} + @[export lean_quot_val_kind] def QuotVal.kindEx (v : QuotVal) : QuotKind := v.kind /-- Information associated with constant declarations. -/ inductive ConstantInfo -| axiomInfo (val : AxiomVal) -| defnInfo (val : DefinitionVal) -| thmInfo (val : TheoremVal) -| opaqueInfo (val : OpaqueVal) -| quotInfo (val : QuotVal) -| inductInfo (val : InductiveVal) -| ctorInfo (val : ConstructorVal) -| recInfo (val : RecursorVal) + | axiomInfo (val : AxiomVal) + | defnInfo (val : DefinitionVal) + | thmInfo (val : TheoremVal) + | opaqueInfo (val : OpaqueVal) + | quotInfo (val : QuotVal) + | inductInfo (val : InductiveVal) + | ctorInfo (val : ConstructorVal) + | recInfo (val : RecursorVal) namespace ConstantInfo def toConstantVal : ConstantInfo → ConstantVal -| defnInfo {toConstantVal := d, ..} => d -| axiomInfo {toConstantVal := d, ..} => d -| thmInfo {toConstantVal := d, ..} => d -| opaqueInfo {toConstantVal := d, ..} => d -| quotInfo {toConstantVal := d, ..} => d -| inductInfo {toConstantVal := d, ..} => d -| ctorInfo {toConstantVal := d, ..} => d -| recInfo {toConstantVal := d, ..} => d + | defnInfo {toConstantVal := d, ..} => d + | axiomInfo {toConstantVal := d, ..} => d + | thmInfo {toConstantVal := d, ..} => d + | opaqueInfo {toConstantVal := d, ..} => d + | quotInfo {toConstantVal := d, ..} => d + | inductInfo {toConstantVal := d, ..} => d + | ctorInfo {toConstantVal := d, ..} => d + | recInfo {toConstantVal := d, ..} => d def isUnsafe : ConstantInfo → Bool -| defnInfo v => v.isUnsafe -| axiomInfo v => v.isUnsafe -| thmInfo v => false -| opaqueInfo v => v.isUnsafe -| quotInfo v => false -| inductInfo v => v.isUnsafe -| ctorInfo v => v.isUnsafe -| recInfo v => v.isUnsafe + | defnInfo v => v.isUnsafe + | axiomInfo v => v.isUnsafe + | thmInfo v => false + | opaqueInfo v => v.isUnsafe + | quotInfo v => false + | inductInfo v => v.isUnsafe + | ctorInfo v => v.isUnsafe + | recInfo v => v.isUnsafe def name (d : ConstantInfo) : Name := -d.toConstantVal.name + d.toConstantVal.name def lparams (d : ConstantInfo) : List Name := -d.toConstantVal.lparams + d.toConstantVal.lparams def type (d : ConstantInfo) : Expr := -d.toConstantVal.type + d.toConstantVal.type def value? : ConstantInfo → Option Expr -| defnInfo {value := r, ..} => some r -| thmInfo {value := r, ..} => some r -| _ => none + | defnInfo {value := r, ..} => some r + | thmInfo {value := r, ..} => some r + | _ => none def hasValue : ConstantInfo → Bool -| defnInfo {value := r, ..} => true -| thmInfo {value := r, ..} => true -| _ => false + | defnInfo {value := r, ..} => true + | thmInfo {value := r, ..} => true + | _ => false def value! : ConstantInfo → Expr -| defnInfo {value := r, ..} => r -| thmInfo {value := r, ..} => r -| _ => panic! "declaration with value expected" + | defnInfo {value := r, ..} => r + | thmInfo {value := r, ..} => r + | _ => panic! "declaration with value expected" def hints : ConstantInfo → ReducibilityHints -| defnInfo {hints := r, ..} => r -| _ => ReducibilityHints.opaque + | defnInfo {hints := r, ..} => r + | _ => ReducibilityHints.opaque def isCtor : ConstantInfo → Bool -| ctorInfo _ => true -| _ => false + | ctorInfo _ => true + | _ => false @[extern "lean_instantiate_type_lparams"] -constant instantiateTypeLevelParams (c : @& ConstantInfo) (ls : @& List Level) : Expr := arbitrary _ +constant instantiateTypeLevelParams (c : @& ConstantInfo) (ls : @& List Level) : Expr @[extern "lean_instantiate_value_lparams"] -constant instantiateValueLevelParams (c : @& ConstantInfo) (ls : @& List Level) : Expr := arbitrary _ +constant instantiateValueLevelParams (c : @& ConstantInfo) (ls : @& List Level) : Expr end ConstantInfo def mkRecFor (declName : Name) : Name := -mkNameStr declName "rec" + mkNameStr declName "rec" end Lean diff --git a/src/Lean/Elab/App.lean b/src/Lean/Elab/App.lean index cd0d714640..c827112ae7 100644 --- a/src/Lean/Elab/App.lean +++ b/src/Lean/Elab/App.lean @@ -19,10 +19,9 @@ inductive Arg | stx (val : Syntax) | expr (val : Expr) -instance Arg.inhabited : Inhabited Arg := ⟨Arg.stx (arbitrary _)⟩ +instance : Inhabited Arg := ⟨Arg.stx (arbitrary _)⟩ -instance Arg.hasToString : HasToString Arg := -⟨fun +instance : HasToString Arg := ⟨fun | Arg.stx val => toString val | Arg.expr val => toString val⟩ @@ -30,10 +29,10 @@ instance Arg.hasToString : HasToString Arg := structure NamedArg := (name : Name) (val : Arg) -instance NamedArg.hasToString : HasToString NamedArg := +instance : HasToString NamedArg := ⟨fun s => "(" ++ toString s.name ++ " := " ++ toString s.val ++ ")"⟩ -instance NamedArg.inhabited : Inhabited NamedArg := ⟨{ name := arbitrary _, val := arbitrary _ }⟩ +instance : Inhabited NamedArg := ⟨{ name := arbitrary _, val := arbitrary _ }⟩ /-- Add a new named argument to `namedArgs`, and throw an error if it already contains a named argument diff --git a/src/Lean/Elab/Attributes.lean b/src/Lean/Elab/Attributes.lean index 3afd637f2c..3adf6967f7 100644 --- a/src/Lean/Elab/Attributes.lean +++ b/src/Lean/Elab/Attributes.lean @@ -12,34 +12,34 @@ namespace Lean.Elab structure Attribute := (name : Name) (args : Syntax := Syntax.missing) -instance Attribute.hasFormat : HasFormat Attribute := -⟨fun attr => Format.bracket "@[" (toString attr.name ++ (if attr.args.isMissing then "" else toString attr.args)) "]"⟩ +instance : HasFormat Attribute := ⟨fun attr => + Format.bracket "@[" (toString attr.name ++ (if attr.args.isMissing then "" else toString attr.args)) "]"⟩ -instance Attribute.inhabited : Inhabited Attribute := ⟨{ name := arbitrary _ }⟩ +instance : Inhabited Attribute := ⟨{ name := arbitrary _ }⟩ def elabAttr {m} [Monad m] [MonadEnv m] [MonadExceptOf Exception m] [Ref m] [AddErrorMessageContext m] (stx : Syntax) : m Attribute := do --- rawIdent >> many attrArg -let nameStx := stx[0] -let attrName ← match nameStx.isIdOrAtom? with - | none => withRef nameStx $ throwError "identifier expected" - | some str => pure $ mkNameSimple str -unless isAttribute (← getEnv) attrName do - throwError! "unknown attribute [{attrName}]" -let args := stx[1] --- the old frontend passes Syntax.missing for empty args, for reasons -if args.getNumArgs == 0 then - args := Syntax.missing -pure { name := attrName, args := args } + -- rawIdent >> many attrArg + let nameStx := stx[0] + let attrName ← match nameStx.isIdOrAtom? with + | none => withRef nameStx $ throwError "identifier expected" + | some str => pure $ mkNameSimple str + unless isAttribute (← getEnv) attrName do + throwError! "unknown attribute [{attrName}]" + let args := stx[1] + -- the old frontend passes Syntax.missing for empty args, for reasons + if args.getNumArgs == 0 then + args := Syntax.missing + pure { name := attrName, args := args } -- sepBy1 attrInstance ", " def elabAttrs {m} [Monad m] [MonadEnv m] [MonadExceptOf Exception m] [Ref m] [AddErrorMessageContext m] (stx : Syntax) : m (Array Attribute) := do -let attrs := #[] -for attr in stx.getSepArgs do - attrs := attrs.push (← elabAttr attr) -return attrs + let attrs := #[] + for attr in stx.getSepArgs do + attrs := attrs.push (← elabAttr attr) + return attrs -- parser! "@[" >> sepBy1 attrInstance ", " >> "]" def elabDeclAttrs {m} [Monad m] [MonadEnv m] [MonadExceptOf Exception m] [Ref m] [AddErrorMessageContext m] (stx : Syntax) : m (Array Attribute) := -elabAttrs stx[1] + elabAttrs stx[1] end Lean.Elab diff --git a/src/Lean/Elab/BuiltinNotation.lean b/src/Lean/Elab/BuiltinNotation.lean index 6d0339cd39..ea5c3711e4 100644 --- a/src/Lean/Elab/BuiltinNotation.lean +++ b/src/Lean/Elab/BuiltinNotation.lean @@ -13,128 +13,126 @@ import Lean.Elab.SyntheticMVars namespace Lean.Elab.Term open Meta -@[builtinMacro Lean.Parser.Term.dollar] def expandDollar : Macro := -fun stx => match_syntax stx with -| `($f $args* $ $a) => let args := args.push a; `($f $args*) -| `($f $ $a) => `($f $a) -| _ => Macro.throwUnsupported +@[builtinMacro Lean.Parser.Term.dollar] def expandDollar : Macro := fun stx => + match_syntax stx with + | `($f $args* $ $a) => let args := args.push a; `($f $args*) + | `($f $ $a) => `($f $a) + | _ => Macro.throwUnsupported -@[builtinMacro Lean.Parser.Term.if] def expandIf : Macro := -fun stx => match_syntax stx with -| `(if $h : $cond then $t else $e) => `(dite $cond (fun $h:ident => $t) (fun $h:ident => $e)) -| `(if $cond then $t else $e) => `(ite $cond $t $e) -| _ => Macro.throwUnsupported +@[builtinMacro Lean.Parser.Term.if] def expandIf : Macro := fun stx => + match_syntax stx with + | `(if $h : $cond then $t else $e) => `(dite $cond (fun $h:ident => $t) (fun $h:ident => $e)) + | `(if $cond then $t else $e) => `(ite $cond $t $e) + | _ => Macro.throwUnsupported -@[builtinMacro Lean.Parser.Term.subtype] def expandSubtype : Macro := -fun stx => match_syntax stx with -| `({ $x : $type // $p }) => `(Subtype (fun ($x:ident : $type) => $p)) -| `({ $x // $p }) => `(Subtype (fun ($x:ident : _) => $p)) -| _ => Macro.throwUnsupported +@[builtinMacro Lean.Parser.Term.subtype] def expandSubtype : Macro := fun stx => + match_syntax stx with + | `({ $x : $type // $p }) => `(Subtype (fun ($x:ident : $type) => $p)) + | `({ $x // $p }) => `(Subtype (fun ($x:ident : _) => $p)) + | _ => Macro.throwUnsupported -@[builtinTermElab anonymousCtor] def elabAnonymousCtor : TermElab := -fun stx expectedType? => match_syntax stx with -| `(⟨$args*⟩) => do - tryPostponeIfNoneOrMVar expectedType? - match expectedType? with - | some expectedType => - let expectedType ← whnf expectedType - matchConstInduct expectedType.getAppFn - (fun _ => throwError! "invalid constructor ⟨...⟩, expected type must be an inductive type {indentExpr expectedType}") - (fun ival us => do - match ival.ctors with - | [ctor] => - let newStx ← `($(mkCIdentFrom stx ctor) $(args.getSepElems)*) - withMacroExpansion stx newStx $ elabTerm newStx expectedType? - | _ => throwError! "invalid constructor ⟨...⟩, expected type must be an inductive type with only one constructor {indentExpr expectedType}") - | none => throwError "invalid constructor ⟨...⟩, expected type must be known" -| _ => throwUnsupportedSyntax +@[builtinTermElab anonymousCtor] def elabAnonymousCtor : TermElab := fun stx expectedType? => + match_syntax stx with + | `(⟨$args*⟩) => do + tryPostponeIfNoneOrMVar expectedType? + match expectedType? with + | some expectedType => + let expectedType ← whnf expectedType + matchConstInduct expectedType.getAppFn + (fun _ => throwError! "invalid constructor ⟨...⟩, expected type must be an inductive type {indentExpr expectedType}") + (fun ival us => do + match ival.ctors with + | [ctor] => + let newStx ← `($(mkCIdentFrom stx ctor) $(args.getSepElems)*) + withMacroExpansion stx newStx $ elabTerm newStx expectedType? + | _ => throwError! "invalid constructor ⟨...⟩, expected type must be an inductive type with only one constructor {indentExpr expectedType}") + | none => throwError "invalid constructor ⟨...⟩, expected type must be known" + | _ => throwUnsupportedSyntax -@[builtinTermElab borrowed] def elabBorrowed : TermElab := -fun stx expectedType? => match_syntax stx with +@[builtinTermElab borrowed] def elabBorrowed : TermElab := fun stx expectedType? => + match_syntax stx with | `(@& $e) => do return markBorrowed (← elabTerm e expectedType?) | _ => throwUnsupportedSyntax -@[builtinMacro Lean.Parser.Term.show] def expandShow : Macro := -fun stx => match_syntax stx with -| `(show $type from $val) => let thisId := mkIdentFrom stx `this; `(let! $thisId : $type := $val; $thisId) -| `(show $type by $tac:tacticSeq) => `(show $type from by $tac:tacticSeq) -| _ => Macro.throwUnsupported +@[builtinMacro Lean.Parser.Term.show] def expandShow : Macro := fun stx => + match_syntax stx with + | `(show $type from $val) => let thisId := mkIdentFrom stx `this; `(let! $thisId : $type := $val; $thisId) + | `(show $type by $tac:tacticSeq) => `(show $type from by $tac:tacticSeq) + | _ => Macro.throwUnsupported -@[builtinMacro Lean.Parser.Term.have] def expandHave : Macro := -fun stx => -let stx := stx.setArg 4 (mkNullNode #[mkAtomFrom stx ";"]) -- HACK -match_syntax stx with -| `(have $type from $val; $body) => let thisId := mkIdentFrom stx `this; `(let! $thisId : $type := $val; $body) -| `(have $type by $tac:tacticSeq; $body) => `(have $type from by $tac:tacticSeq; $body) -| `(have $type := $val; $body) => let thisId := mkIdentFrom stx `this; `(let! $thisId : $type := $val; $body) -| `(have $x : $type from $val; $body) => `(let! $x:ident : $type := $val; $body) -| `(have $x : $type by $tac:tacticSeq; $body) => `(have $x : $type from by $tac:tacticSeq; $body) -| `(have $x : $type := $val; $body) => `(let! $x:ident : $type := $val; $body) -| _ => Macro.throwUnsupported +@[builtinMacro Lean.Parser.Term.have] def expandHave : Macro := fun stx => + let stx := stx.setArg 4 (mkNullNode #[mkAtomFrom stx ";"]) -- HACK + match_syntax stx with + | `(have $type from $val; $body) => let thisId := mkIdentFrom stx `this; `(let! $thisId : $type := $val; $body) + | `(have $type by $tac:tacticSeq; $body) => `(have $type from by $tac:tacticSeq; $body) + | `(have $type := $val; $body) => let thisId := mkIdentFrom stx `this; `(let! $thisId : $type := $val; $body) + | `(have $x : $type from $val; $body) => `(let! $x:ident : $type := $val; $body) + | `(have $x : $type by $tac:tacticSeq; $body) => `(have $x : $type from by $tac:tacticSeq; $body) + | `(have $x : $type := $val; $body) => `(let! $x:ident : $type := $val; $body) + | _ => Macro.throwUnsupported -@[builtinMacro Lean.Parser.Term.where] def expandWhere : Macro := -fun stx => match_syntax stx with -| `($body where $decls:letDecl*) => do - let decls := decls.getEvenElems - decls.foldrM - (fun decl body => `(let $decl:letDecl; $body)) - body -| _ => Macro.throwUnsupported +@[builtinMacro Lean.Parser.Term.where] def expandWhere : Macro := fun stx => + match_syntax stx with + | `($body where $decls:letDecl*) => do + let decls := decls.getEvenElems + decls.foldrM + (fun decl body => `(let $decl:letDecl; $body)) + body + | _ => Macro.throwUnsupported private def elabParserMacroAux (prec : Syntax) (e : Syntax) : TermElabM Syntax := do -let (some declName) ← getDeclName? - | throwError "invalid `parser!` macro, it must be used in definitions" -match extractMacroScopes declName with -| { name := Name.str _ s _, scopes := scps, .. } => - let kind := quote declName - let s := quote s - let p ← `(Lean.Parser.leadingNode $kind $prec $e) - if scps == [] then - -- TODO simplify the following quotation as soon as we have coercions - `(HasOrelse.orelse (Lean.Parser.mkAntiquot $s (some $kind)) $p) - else - -- if the parser decl is hidden by hygiene, it doesn't make sense to provide an antiquotation kind - `(HasOrelse.orelse (Lean.Parser.mkAntiquot $s none) $p) -| _ => throwError "invalid `parser!` macro, unexpected declaration name" + let (some declName) ← getDeclName? + | throwError "invalid `parser!` macro, it must be used in definitions" + match extractMacroScopes declName with + | { name := Name.str _ s _, scopes := scps, .. } => + let kind := quote declName + let s := quote s + let p ← `(Lean.Parser.leadingNode $kind $prec $e) + if scps == [] then + -- TODO simplify the following quotation as soon as we have coercions + `(HasOrelse.orelse (Lean.Parser.mkAntiquot $s (some $kind)) $p) + else + -- if the parser decl is hidden by hygiene, it doesn't make sense to provide an antiquotation kind + `(HasOrelse.orelse (Lean.Parser.mkAntiquot $s none) $p) + | _ => throwError "invalid `parser!` macro, unexpected declaration name" @[builtinTermElab «parser!»] def elabParserMacro : TermElab := -adaptExpander $ fun stx => match_syntax stx with -| `(parser! $e) => elabParserMacroAux (quote Parser.maxPrec) e -| `(parser! : $prec $e) => elabParserMacroAux prec e -| _ => throwUnsupportedSyntax + adaptExpander fun stx => match_syntax stx with + | `(parser! $e) => elabParserMacroAux (quote Parser.maxPrec) e + | `(parser! : $prec $e) => elabParserMacroAux prec e + | _ => throwUnsupportedSyntax private def elabTParserMacroAux (prec : Syntax) (e : Syntax) : TermElabM Syntax := do -let declName? ← getDeclName? -match declName? with -| some declName => let kind := quote declName; `(Lean.Parser.trailingNode $kind $prec $e) -| none => throwError "invalid `tparser!` macro, it must be used in definitions" + let declName? ← getDeclName? + match declName? with + | some declName => let kind := quote declName; `(Lean.Parser.trailingNode $kind $prec $e) + | none => throwError "invalid `tparser!` macro, it must be used in definitions" @[builtinTermElab «tparser!»] def elabTParserMacro : TermElab := -adaptExpander $ fun stx => match_syntax stx with -| `(tparser! $e) => elabTParserMacroAux (quote Parser.maxPrec) e -| `(tparser! : $prec $e) => elabTParserMacroAux prec e -| _ => throwUnsupportedSyntax + adaptExpander fun stx => match_syntax stx with + | `(tparser! $e) => elabTParserMacroAux (quote Parser.maxPrec) e + | `(tparser! : $prec $e) => elabTParserMacroAux prec e + | _ => throwUnsupportedSyntax private def mkNativeReflAuxDecl (type val : Expr) : TermElabM Name := do -let auxName ← mkAuxName `_nativeRefl -let decl := Declaration.defnDecl { - name := auxName, lparams := [], type := type, value := val, - hints := ReducibilityHints.abbrev, - isUnsafe := false } -addDecl decl -compileDecl decl -pure auxName + let auxName ← mkAuxName `_nativeRefl + let decl := Declaration.defnDecl { + name := auxName, lparams := [], type := type, value := val, + hints := ReducibilityHints.abbrev, + isUnsafe := false } + addDecl decl + compileDecl decl + pure auxName private def elabClosedTerm (stx : Syntax) (expectedType? : Option Expr) : TermElabM Expr := do -let e ← elabTermAndSynthesize stx expectedType? -if e.hasMVar then - throwError! "invalid macro application, term contains metavariables{indentExpr e}" -if e.hasFVar then - throwError! "invalid macro application, term contains free variables{indentExpr e}" -pure e + let e ← elabTermAndSynthesize stx expectedType? + if e.hasMVar then + throwError! "invalid macro application, term contains metavariables{indentExpr e}" + if e.hasFVar then + throwError! "invalid macro application, term contains free variables{indentExpr e}" + pure e -@[builtinTermElab «nativeRefl»] def elabNativeRefl : TermElab := -fun stx _ => do +@[builtinTermElab «nativeRefl»] def elabNativeRefl : TermElab := fun stx _ => do let arg := stx[1] let e ← elabClosedTerm arg none let type ← inferType e @@ -157,18 +155,17 @@ fun stx _ => do mkExpectedTypeHint r eq private def getPropToDecide (expectedType? : Option Expr) : TermElabM Expr := do -tryPostponeIfNoneOrMVar expectedType? -match expectedType? with -| none => throwError "invalid macro, expected type is not available" -| some expectedType => - synthesizeSyntheticMVars - let expectedType ← instantiateMVars expectedType - if expectedType.hasFVar || expectedType.hasMVar then - throwError! "expected type must not contain free or meta variables{indentExpr expectedType}" - pure expectedType + tryPostponeIfNoneOrMVar expectedType? + match expectedType? with + | none => throwError "invalid macro, expected type is not available" + | some expectedType => + synthesizeSyntheticMVars + let expectedType ← instantiateMVars expectedType + if expectedType.hasFVar || expectedType.hasMVar then + throwError! "expected type must not contain free or meta variables{indentExpr expectedType}" + pure expectedType -@[builtinTermElab «nativeDecide»] def elabNativeDecide : TermElab := -fun stx expectedType? => do +@[builtinTermElab «nativeDecide»] def elabNativeDecide : TermElab := fun stx expectedType? => do let p ← getPropToDecide expectedType? let d ← mkAppM `Decidable.decide #[p] let auxDeclName ← mkNativeReflAuxDecl (Lean.mkConst `Bool) d @@ -176,8 +173,7 @@ fun stx expectedType? => do let r := mkApp3 (Lean.mkConst `Lean.ofReduceBool) (Lean.mkConst auxDeclName) (toExpr true) rflPrf mkExpectedTypeHint r p -@[builtinTermElab Lean.Parser.Term.decide] def elabDecide : TermElab := -fun stx expectedType? => do +@[builtinTermElab Lean.Parser.Term.decide] def elabDecide : TermElab := fun stx expectedType? => do let p ← getPropToDecide expectedType? let d ← mkAppM `Decidable.decide #[p] let d ← instantiateMVars d @@ -185,18 +181,16 @@ fun stx expectedType? => do let rflPrf ← mkEqRefl (toExpr true) pure $ mkApp3 (Lean.mkConst `ofDecideEqTrue) p s rflPrf -def expandInfix (f : Syntax) : Macro := -fun stx => do +def expandInfix (f : Syntax) : Macro := fun stx => do -- term `op` term let a := stx[0] let b := stx[2] pure (mkAppStx f #[a, b]) -def expandInfixOp (op : Name) : Macro := -fun stx => expandInfix (mkCIdentFrom stx[1] op) stx +def expandInfixOp (op : Name) : Macro := fun stx => + expandInfix (mkCIdentFrom stx[1] op) stx -def expandPrefixOp (op : Name) : Macro := -fun stx => do +def expandPrefixOp (op : Name) : Macro := fun stx => do -- `op` term let a := stx[1] pure (mkAppStx (mkCIdentFrom stx[0] op) #[a]) @@ -251,8 +245,7 @@ fun stx => do @[builtinMacro Lean.Parser.Term.not] def expandNot : Macro := expandPrefixOp `Not @[builtinMacro Lean.Parser.Term.bnot] def expandBNot : Macro := expandPrefixOp `not -@[builtinTermElab panic] def elabPanic : TermElab := -fun stx expectedType? => do +@[builtinTermElab panic] def elabPanic : TermElab := fun stx expectedType? => do let arg := stx[1] let pos ← getRefPosition let env ← getEnv @@ -261,11 +254,10 @@ fun stx expectedType? => do | none => `(panicWithPos $(quote (toString env.mainModule)) $(quote pos.line) $(quote pos.column) $arg) withMacroExpansion stx stxNew $ elabTerm stxNew expectedType? -@[builtinMacro Lean.Parser.Term.unreachable] def expandUnreachable : Macro := -fun stx => `(panic! "unreachable code has been reached") +@[builtinMacro Lean.Parser.Term.unreachable] def expandUnreachable : Macro := fun stx => + `(panic! "unreachable code has been reached") -@[builtinMacro Lean.Parser.Term.assert] def expandAssert : Macro := -fun stx => +@[builtinMacro Lean.Parser.Term.assert] def expandAssert : Macro := fun stx => -- TODO: support for disabling runtime assertions let cond := stx[1] let body := stx[3] @@ -273,8 +265,7 @@ fun stx => | some code => `(if $cond then $body else panic! ("assertion violation: " ++ $(quote code))) | none => `(if $cond then $body else panic! ("assertion violation")) -@[builtinMacro Lean.Parser.Term.dbgTrace] def expandDbgTrace : Macro := -fun stx => +@[builtinMacro Lean.Parser.Term.dbgTrace] def expandDbgTrace : Macro := fun stx => let arg := stx[1] let body := stx[3] if arg.getKind == interpolatedStrKind then @@ -282,25 +273,24 @@ fun stx => else `(dbgTrace (toString $arg) fun _ => $body) -@[builtinMacro Lean.Parser.Term.«sorry»] def expandSorry : Macro := -fun _ => `(sorryAx _ false) +@[builtinMacro Lean.Parser.Term.«sorry»] def expandSorry : Macro := fun _ => + `(sorryAx _ false) -@[builtinTermElab emptyC] def expandEmptyC : TermElab := -fun stx expectedType? => do +@[builtinTermElab emptyC] def expandEmptyC : TermElab := fun stx expectedType? => do let stxNew ← `(HasEmptyc.emptyc) withMacroExpansion stx stxNew $ elabTerm stxNew expectedType? /-- Return syntax `Prod.mk elems[0] (Prod.mk elems[1] ... (Prod.mk elems[elems.size - 2] elems[elems.size - 1])))` -/ partial def mkPairs (elems : Array Syntax) : MacroM Syntax := -let rec loop (i : Nat) (acc : Syntax) := do - if i > 0 then - let i := i - 1 - let elem := elems[i] - let acc ← `(Prod.mk $elem $acc) - loop i acc - else - pure acc -loop (elems.size - 1) elems.back + let rec loop (i : Nat) (acc : Syntax) := do + if i > 0 then + let i := i - 1 + let elem := elems[i] + let acc ← `(Prod.mk $elem $acc) + loop i acc + else + pure acc + loop (elems.size - 1) elems.back /-- Try to expand `·` notation, and if successful elaborate result. @@ -312,12 +302,11 @@ loop (elems.size - 1) elems.back - `(· + ·)` - `(f · a b)` -/ private def elabCDot (stx : Syntax) (expectedType? : Option Expr) : TermElabM Expr := do -match (← liftMacroM $ expandCDot? stx) with -| some stx' => withMacroExpansion stx stx' (elabTerm stx' expectedType?) -| none => elabTerm stx expectedType? + match (← liftMacroM $ expandCDot? stx) with + | some stx' => withMacroExpansion stx stx' (elabTerm stx' expectedType?) + | none => elabTerm stx expectedType? -@[builtinTermElab paren] def elabParen : TermElab := -fun stx expectedType? => +@[builtinTermElab paren] def elabParen : TermElab := fun stx expectedType? => match_syntax stx with | `(()) => pure $ Lean.mkConst `Unit.unit | `(($e : $type)) => do diff --git a/src/Lean/Elab/Command.lean b/src/Lean/Elab/Command.lean index 60bf9c3c40..e2dc009421 100644 --- a/src/Lean/Elab/Command.lean +++ b/src/Lean/Elab/Command.lean @@ -14,135 +14,144 @@ import Lean.Elab.DeclModifiers namespace Lean.Elab.Command structure Scope := -(kind : String) -(header : String) -(opts : Options := {}) -(currNamespace : Name := Name.anonymous) -(openDecls : List OpenDecl := []) -(levelNames : List Name := []) -(varDecls : Array Syntax := #[]) + (kind : String) + (header : String) + (opts : Options := {}) + (currNamespace : Name := Name.anonymous) + (openDecls : List OpenDecl := []) + (levelNames : List Name := []) + (varDecls : Array Syntax := #[]) -instance Scope.inhabited : Inhabited Scope := ⟨{ kind := "", header := "" }⟩ +instance : Inhabited Scope := ⟨{ kind := "", header := "" }⟩ structure State := -(env : Environment) -(messages : MessageLog := {}) -(scopes : List Scope := [{ kind := "root", header := "" }]) -(nextMacroScope : Nat := firstFrontendMacroScope + 1) -(maxRecDepth : Nat) -(nextInstIdx : Nat := 1) -- for generating anonymous instance names -(ngen : NameGenerator := {}) + (env : Environment) + (messages : MessageLog := {}) + (scopes : List Scope := [{ kind := "root", header := "" }]) + (nextMacroScope : Nat := firstFrontendMacroScope + 1) + (maxRecDepth : Nat) + (nextInstIdx : Nat := 1) -- for generating anonymous instance names + (ngen : NameGenerator := {}) -instance State.inhabited : Inhabited State := ⟨{ env := arbitrary _, maxRecDepth := 0 }⟩ +instance : Inhabited State := ⟨{ env := arbitrary _, maxRecDepth := 0 }⟩ -def mkState (env : Environment) (messages : MessageLog := {}) (opts : Options := {}) : State := -{ env := env, messages := messages, scopes := [{ kind := "root", header := "", opts := opts }], maxRecDepth := getMaxRecDepth opts } +def mkState (env : Environment) (messages : MessageLog := {}) (opts : Options := {}) : State := { + env := env, + messages := messages, + scopes := [{ kind := "root", header := "", opts := opts }], + maxRecDepth := getMaxRecDepth opts +} structure Context := -(fileName : String) -(fileMap : FileMap) -(currRecDepth : Nat := 0) -(cmdPos : String.Pos := 0) -(macroStack : MacroStack := []) -(currMacroScope : MacroScope := firstFrontendMacroScope) -(ref : Syntax := Syntax.missing) + (fileName : String) + (fileMap : FileMap) + (currRecDepth : Nat := 0) + (cmdPos : String.Pos := 0) + (macroStack : MacroStack := []) + (currMacroScope : MacroScope := firstFrontendMacroScope) + (ref : Syntax := Syntax.missing) abbrev CommandElabCoreM (ε) := ReaderT Context $ StateRefT State $ EIO ε abbrev CommandElabM := CommandElabCoreM Exception abbrev CommandElab := Syntax → CommandElabM Unit -instance : MonadEnv CommandElabM := -{ getEnv := do pure (← get).env, - modifyEnv := fun f => modify fun s => { s with env := f s.env } } +instance : MonadEnv CommandElabM := { + getEnv := do pure (← get).env, + modifyEnv := fun f => modify fun s => { s with env := f s.env } +} -instance : MonadOptions CommandElabM := -{ getOptions := do pure (← get).scopes.head!.opts } +instance : MonadOptions CommandElabM := { + getOptions := do pure (← get).scopes.head!.opts +} -protected def getRef : CommandElabM Syntax := -do pure (← read).ref +protected def getRef : CommandElabM Syntax := do + pure (← read).ref -instance : AddMessageContext CommandElabM := -{ addMessageContext := addMessageContextPartial } +instance : AddMessageContext CommandElabM := { + addMessageContext := addMessageContextPartial +} -instance : Ref CommandElabM := -{ getRef := Command.getRef, - withRef := fun ref x => withReader (fun ctx => { ctx with ref := ref }) x } +instance : Ref CommandElabM := { + getRef := Command.getRef, + withRef := fun ref x => withReader (fun ctx => { ctx with ref := ref }) x +} -instance : AddErrorMessageContext CommandElabM := -{ add := fun ref msg => do - let ctx ← read - let ref := getBetterRef ref ctx.macroStack - let msg ← addMessageContext msg - let msg ← addMacroStack msg ctx.macroStack - pure (ref, msg) } +instance : AddErrorMessageContext CommandElabM := { + add := fun ref msg => do + let ctx ← read + let ref := getBetterRef ref ctx.macroStack + let msg ← addMessageContext msg + let msg ← addMacroStack msg ctx.macroStack + pure (ref, msg) +} def mkMessageAux (ctx : Context) (ref : Syntax) (msgData : MessageData) (severity : MessageSeverity) : Message := -mkMessageCore ctx.fileName ctx.fileMap msgData severity (ref.getPos.getD ctx.cmdPos) + mkMessageCore ctx.fileName ctx.fileMap msgData severity (ref.getPos.getD ctx.cmdPos) private def mkCoreContext (ctx : Context) (s : State) : Core.Context := -let scope := s.scopes.head!; -{ options := scope.opts, - currRecDepth := ctx.currRecDepth, - maxRecDepth := s.maxRecDepth, - ref := ctx.ref } + let scope := s.scopes.head!; + { options := scope.opts, + currRecDepth := ctx.currRecDepth, + maxRecDepth := s.maxRecDepth, + ref := ctx.ref } def liftCoreM {α} (x : CoreM α) : CommandElabM α := do -let s ← get -let ctx ← read -let Eα := Except Exception α -let x : CoreM Eα := do try let a ← x; pure $ Except.ok a catch ex => pure $ Except.error ex -let x : EIO Exception (Eα × Core.State) := (ReaderT.run x (mkCoreContext ctx s)).run { env := s.env, ngen := s.ngen } -let (ea, coreS) ← liftM x -modify fun s => { s with env := coreS.env, ngen := coreS.ngen } -match ea with -| Except.ok a => pure a -| Except.error e => throw e + let s ← get + let ctx ← read + let Eα := Except Exception α + let x : CoreM Eα := do try let a ← x; pure $ Except.ok a catch ex => pure $ Except.error ex + let x : EIO Exception (Eα × Core.State) := (ReaderT.run x (mkCoreContext ctx s)).run { env := s.env, ngen := s.ngen } + let (ea, coreS) ← liftM x + modify fun s => { s with env := coreS.env, ngen := coreS.ngen } + match ea with + | Except.ok a => pure a + | Except.error e => throw e private def ioErrorToMessage (ctx : Context) (ref : Syntax) (err : IO.Error) : Message := -let ref := getBetterRef ref ctx.macroStack -mkMessageAux ctx ref (toString err) MessageSeverity.error + let ref := getBetterRef ref ctx.macroStack + mkMessageAux ctx ref (toString err) MessageSeverity.error -@[inline] def liftEIO {α} (x : EIO Exception α) : CommandElabM α := -liftM x +@[inline] def liftEIO {α} (x : EIO Exception α) : CommandElabM α := liftM x @[inline] def liftIO {α} (x : IO α) : CommandElabM α := do -let ctx ← read -liftEIO $ adaptExcept (fun (ex : IO.Error) => Exception.error ctx.ref ex.toString) x + let ctx ← read + liftEIO $ adaptExcept (fun (ex : IO.Error) => Exception.error ctx.ref ex.toString) x -instance : MonadIO CommandElabM := -{ liftIO := liftIO } +instance : MonadIO CommandElabM := { liftIO := liftIO } def getScope : CommandElabM Scope := do pure (← get).scopes.head! -instance : MonadResolveName CommandElabM := -{ getCurrNamespace := do pure (← getScope).currNamespace, - getOpenDecls := do pure (← getScope).openDecls } +instance : MonadResolveName CommandElabM := { + getCurrNamespace := do pure (← getScope).currNamespace, + getOpenDecls := do pure (← getScope).openDecls +} -instance CommandElabM.monadLog : MonadLog CommandElabM := -{ getRef := getRef, +instance : MonadLog CommandElabM := { + getRef := getRef, getFileMap := do pure (← read).fileMap, getFileName := do pure (← read).fileName, logMessage := fun msg => do let currNamespace ← getCurrNamespace let openDecls ← getOpenDecls let msg := { msg with data := MessageData.withNamingContext { currNamespace := currNamespace, openDecls := openDecls } msg.data } - modify fun s => { s with messages := s.messages.add msg } } + modify fun s => { s with messages := s.messages.add msg } +} protected def getCurrMacroScope : CommandElabM Nat := do pure (← read).currMacroScope protected def getMainModule : CommandElabM Name := do pure (← getEnv).mainModule @[inline] protected def withFreshMacroScope {α} (x : CommandElabM α) : CommandElabM α := do -let fresh ← modifyGet (fun st => (st.nextMacroScope, { st with nextMacroScope := st.nextMacroScope + 1 })) -withReader (fun ctx => { ctx with currMacroScope := fresh }) x + let fresh ← modifyGet (fun st => (st.nextMacroScope, { st with nextMacroScope := st.nextMacroScope + 1 })) + withReader (fun ctx => { ctx with currMacroScope := fresh }) x -instance CommandElabM.MonadQuotation : MonadQuotation CommandElabM := { +instance : MonadQuotation CommandElabM := { getCurrMacroScope := Command.getCurrMacroScope, getMainModule := Command.getMainModule, withFreshMacroScope := @Command.withFreshMacroScope } unsafe def mkCommandElabAttributeUnsafe : IO (KeyedDeclsAttribute CommandElab) := -mkElabAttribute CommandElab `Lean.Elab.Command.commandElabAttribute `builtinCommandElab `commandElab `Lean.Parser.Command `Lean.Elab.Command.CommandElab "command" + mkElabAttribute CommandElab `Lean.Elab.Command.commandElabAttribute `builtinCommandElab `commandElab `Lean.Parser.Command `Lean.Elab.Command.CommandElab "command" @[implementedBy mkCommandElabAttributeUnsafe] constant mkCommandElabAttribute : IO (KeyedDeclsAttribute CommandElab) @@ -150,171 +159,169 @@ constant mkCommandElabAttribute : IO (KeyedDeclsAttribute CommandElab) builtin_initialize commandElabAttribute : KeyedDeclsAttribute CommandElab ← mkCommandElabAttribute private def elabCommandUsing (s : State) (stx : Syntax) : List CommandElab → CommandElabM Unit -| [] => throwError! "unexpected syntax{indentD stx}" -| (elabFn::elabFns) => - catchInternalId unsupportedSyntaxExceptionId - (elabFn stx) - (fun _ => do set s; elabCommandUsing s stx elabFns) + | [] => throwError! "unexpected syntax{indentD stx}" + | (elabFn::elabFns) => + catchInternalId unsupportedSyntaxExceptionId + (elabFn stx) + (fun _ => do set s; elabCommandUsing s stx elabFns) /- Elaborate `x` with `stx` on the macro stack -/ @[inline] def withMacroExpansion {α} (beforeStx afterStx : Syntax) (x : CommandElabM α) : CommandElabM α := -withReader (fun ctx => { ctx with macroStack := { before := beforeStx, after := afterStx } :: ctx.macroStack }) x + withReader (fun ctx => { ctx with macroStack := { before := beforeStx, after := afterStx } :: ctx.macroStack }) x -instance : MonadMacroAdapter CommandElabM := -{ getCurrMacroScope := getCurrMacroScope, +instance : MonadMacroAdapter CommandElabM := { + getCurrMacroScope := getCurrMacroScope, getNextMacroScope := do pure (← get).nextMacroScope, setNextMacroScope := fun next => modify fun s => { s with nextMacroScope := next } } -instance : MonadRecDepth CommandElabM := -{ withRecDepth := fun d x => withReader (fun ctx => { ctx with currRecDepth := d }) x, +instance : MonadRecDepth CommandElabM := { + withRecDepth := fun d x => withReader (fun ctx => { ctx with currRecDepth := d }) x, getRecDepth := do pure (← read).currRecDepth, getMaxRecDepth := do pure (← get).maxRecDepth } @[inline] def withLogging (x : CommandElabM Unit) : CommandElabM Unit := do -try - x -catch ex => match ex with - | Exception.error _ _ => logException ex - | Exception.internal id => - if id == abortExceptionId then - pure () - else - let idName ← liftIO $ id.getName; - logError msg!"internal exception {idName}" + try + x + catch ex => match ex with + | Exception.error _ _ => logException ex + | Exception.internal id => + if id == abortExceptionId then + pure () + else + let idName ← liftIO $ id.getName; + logError msg!"internal exception {idName}" builtin_initialize registerTraceClass `Elab.command partial def elabCommand : Syntax → CommandElabM Unit -| stx => withLogging $ withRef stx $ withIncRecDepth $ withFreshMacroScope $ match stx with - | Syntax.node k args => - if k == nullKind then - -- list of commands => elaborate in order - -- The parser will only ever return a single command at a time, but syntax quotations can return multiple ones - args.forM elabCommand - else do - trace `Elab.command fun _ => stx; - let s ← get - let stxNew? ← catchInternalId unsupportedSyntaxExceptionId - (do let newStx ← adaptMacro (getMacros s.env) stx; pure (some newStx)) - (fun ex => pure none) - match stxNew? with - | some stxNew => withMacroExpansion stx stxNew $ elabCommand stxNew - | _ => - let table := (commandElabAttribute.ext.getState s.env).table; - let k := stx.getKind; - match table.find? k with - | some elabFns => elabCommandUsing s stx elabFns - | none => throwError ("elaboration function for '" ++ toString k ++ "' has not been implemented") - | _ => throwError "unexpected command" + | stx => withLogging $ withRef stx $ withIncRecDepth $ withFreshMacroScope $ match stx with + | Syntax.node k args => + if k == nullKind then + -- list of commands => elaborate in order + -- The parser will only ever return a single command at a time, but syntax quotations can return multiple ones + args.forM elabCommand + else do + trace `Elab.command fun _ => stx; + let s ← get + let stxNew? ← catchInternalId unsupportedSyntaxExceptionId + (do let newStx ← adaptMacro (getMacros s.env) stx; pure (some newStx)) + (fun ex => pure none) + match stxNew? with + | some stxNew => withMacroExpansion stx stxNew $ elabCommand stxNew + | _ => + let table := (commandElabAttribute.ext.getState s.env).table; + let k := stx.getKind; + match table.find? k with + | some elabFns => elabCommandUsing s stx elabFns + | none => throwError ("elaboration function for '" ++ toString k ++ "' has not been implemented") + | _ => throwError "unexpected command" /-- Adapt a syntax transformation to a regular, command-producing elaborator. -/ -def adaptExpander (exp : Syntax → CommandElabM Syntax) : CommandElab := -fun stx => do +def adaptExpander (exp : Syntax → CommandElabM Syntax) : CommandElab := fun stx => do let stx' ← exp stx withMacroExpansion stx stx' $ elabCommand stx' private def getVarDecls (s : State) : Array Syntax := -s.scopes.head!.varDecls + s.scopes.head!.varDecls -instance CommandElabM.inhabited {α} : Inhabited (CommandElabM α) := -⟨throw $ arbitrary _⟩ +instance {α} : Inhabited (CommandElabM α) := ⟨throw $ arbitrary _⟩ -private def mkMetaContext : Meta.Context := -{ config := { foApprox := true, ctxApprox := true, quasiPatternApprox := true } } +private def mkMetaContext : Meta.Context := { + config := { foApprox := true, ctxApprox := true, quasiPatternApprox := true } +} private def mkTermContext (ctx : Context) (s : State) (declName? : Option Name) : Term.Context := -let scope := s.scopes.head!; -{ macroStack := ctx.macroStack, - fileName := ctx.fileName, - fileMap := ctx.fileMap, - currMacroScope := ctx.currMacroScope, - currNamespace := scope.currNamespace, - levelNames := scope.levelNames, - openDecls := scope.openDecls, - declName? := declName? } + let scope := s.scopes.head!; + { macroStack := ctx.macroStack, + fileName := ctx.fileName, + fileMap := ctx.fileMap, + currMacroScope := ctx.currMacroScope, + currNamespace := scope.currNamespace, + levelNames := scope.levelNames, + openDecls := scope.openDecls, + declName? := declName? } private def addTraceAsMessages (ctx : Context) (log : MessageLog) (traceState : TraceState) : MessageLog := -traceState.traces.foldl - (fun (log : MessageLog) traceElem => - let ref := replaceRef traceElem.ref ctx.ref; - let pos := ref.getPos.getD 0; - log.add (mkMessageCore ctx.fileName ctx.fileMap traceElem.msg MessageSeverity.information pos)) - log + traceState.traces.foldl + (fun (log : MessageLog) traceElem => + let ref := replaceRef traceElem.ref ctx.ref; + let pos := ref.getPos.getD 0; + log.add (mkMessageCore ctx.fileName ctx.fileMap traceElem.msg MessageSeverity.information pos)) + log def liftTermElabM {α} (declName? : Option Name) (x : TermElabM α) : CommandElabM α := do -let ctx ← read -let s ← get -let scope := s.scopes.head! --- We execute `x` with an empty message log. Thus, `x` cannot modify/view messages produced by previous commands. --- This is useful for implementing `runTermElabM` where we use `Term.resetMessageLog` -let messages := s.messages -let x : MetaM _ := (observing x).run (mkTermContext ctx s declName?) { messages := {} } -let x : CoreM _ := x.run mkMetaContext {} -let x : EIO _ _ := x.run (mkCoreContext ctx s) { env := s.env, ngen := s.ngen, nextMacroScope := s.nextMacroScope } -let (((ea, termS), _), coreS) ← liftEIO x -modify fun s => { s with - env := coreS.env, - messages := addTraceAsMessages ctx (messages ++ termS.messages) coreS.traceState, - nextMacroScope := coreS.nextMacroScope, - ngen := coreS.ngen -} -match ea with -| Except.ok a => pure a -| Except.error ex => throw ex + let ctx ← read + let s ← get + let scope := s.scopes.head! + -- We execute `x` with an empty message log. Thus, `x` cannot modify/view messages produced by previous commands. + -- This is useful for implementing `runTermElabM` where we use `Term.resetMessageLog` + let messages := s.messages + let x : MetaM _ := (observing x).run (mkTermContext ctx s declName?) { messages := {} } + let x : CoreM _ := x.run mkMetaContext {} + let x : EIO _ _ := x.run (mkCoreContext ctx s) { env := s.env, ngen := s.ngen, nextMacroScope := s.nextMacroScope } + let (((ea, termS), _), coreS) ← liftEIO x + modify fun s => { s with + env := coreS.env, + messages := addTraceAsMessages ctx (messages ++ termS.messages) coreS.traceState, + nextMacroScope := coreS.nextMacroScope, + ngen := coreS.ngen + } + match ea with + | Except.ok a => pure a + | Except.error ex => throw ex @[inline] def runTermElabM {α} (declName? : Option Name) (elabFn : Array Expr → TermElabM α) : CommandElabM α := do -let s ← get -liftTermElabM declName? - -- We don't want to store messages produced when elaborating `(getVarDecls s)` because they have already been saved when we elaborated the `variable`(s) command. - -- So, we use `Term.resetMessageLog`. - (Term.elabBinders (getVarDecls s) (fun xs => do Term.resetMessageLog; elabFn xs)) + let s ← get + liftTermElabM declName? + -- We don't want to store messages produced when elaborating `(getVarDecls s)` because they have already been saved when we elaborated the `variable`(s) command. + -- So, we use `Term.resetMessageLog`. + (Term.elabBinders (getVarDecls s) (fun xs => do Term.resetMessageLog; elabFn xs)) -@[inline] def catchExceptions (x : CommandElabM Unit) : CommandElabCoreM Empty Unit := -fun ctx ref => EIO.catchExceptions (withLogging x ctx ref) (fun _ => pure ()) +@[inline] def catchExceptions (x : CommandElabM Unit) : CommandElabCoreM Empty Unit := fun ctx ref => + EIO.catchExceptions (withLogging x ctx ref) (fun _ => pure ()) private def addScope (kind : String) (header : String) (newNamespace : Name) : CommandElabM Unit := -modify fun s => { - s with - env := s.env.registerNamespace newNamespace, - scopes := { s.scopes.head! with kind := kind, header := header, currNamespace := newNamespace } :: s.scopes -} + modify fun s => { + s with + env := s.env.registerNamespace newNamespace, + scopes := { s.scopes.head! with kind := kind, header := header, currNamespace := newNamespace } :: s.scopes + } private def addScopes (kind : String) (updateNamespace : Bool) : Name → CommandElabM Unit -| Name.anonymous => pure () -| Name.str p header _ => do - addScopes kind updateNamespace p - let currNamespace ← getCurrNamespace - addScope kind header (if updateNamespace then mkNameStr currNamespace header else currNamespace) -| _ => throwError "invalid scope" + | Name.anonymous => pure () + | Name.str p header _ => do + addScopes kind updateNamespace p + let currNamespace ← getCurrNamespace + addScope kind header (if updateNamespace then mkNameStr currNamespace header else currNamespace) + | _ => throwError "invalid scope" private def addNamespace (header : Name) : CommandElabM Unit := -addScopes "namespace" true header + addScopes "namespace" true header -@[builtinCommandElab «namespace»] def elabNamespace : CommandElab := -fun stx => match_syntax stx with +@[builtinCommandElab «namespace»] def elabNamespace : CommandElab := fun stx => + match_syntax stx with | `(namespace $n) => addNamespace n.getId | _ => throwUnsupportedSyntax -@[builtinCommandElab «section»] def elabSection : CommandElab := -fun stx => match_syntax stx with +@[builtinCommandElab «section»] def elabSection : CommandElab := fun stx => + match_syntax stx with | `(section $header:ident) => addScopes "section" false header.getId | `(section) => do let currNamespace ← getCurrNamespace; addScope "section" "" currNamespace | _ => throwUnsupportedSyntax def getScopes : CommandElabM (List Scope) := do -pure (← get).scopes + pure (← get).scopes private def checkAnonymousScope : List Scope → Bool -| { header := "", .. } :: _ => true -| _ => false + | { header := "", .. } :: _ => true + | _ => false private def checkEndHeader : Name → List Scope → Bool -| Name.anonymous, _ => true -| Name.str p s _, { header := h, .. } :: scopes => h == s && checkEndHeader p scopes -| _, _ => false + | Name.anonymous, _ => true + | Name.str p s _, { header := h, .. } :: scopes => h == s && checkEndHeader p scopes + | _, _ => false -@[builtinCommandElab «end»] def elabEnd : CommandElab := -fun stx => do +@[builtinCommandElab «end»] def elabEnd : CommandElab := fun stx => do let header? := (stx.getArg 1).getOptionalIdent?; let endSize := match header? with | none => 1 @@ -330,32 +337,30 @@ fun stx => do | some header => unless checkEndHeader header scopes do throwError "invalid 'end', name mismatch" @[inline] def withNamespace {α} (ns : Name) (elabFn : CommandElabM α) : CommandElabM α := do -addNamespace ns -let a ← elabFn -modify fun s => { s with scopes := s.scopes.drop ns.getNumParts } -pure a + addNamespace ns + let a ← elabFn + modify fun s => { s with scopes := s.scopes.drop ns.getNumParts } + pure a @[specialize] def modifyScope (f : Scope → Scope) : CommandElabM Unit := -modify fun s => - { s with - scopes := match s.scopes with - | h::t => f h :: t - | [] => unreachable! } + modify fun s => + { s with + scopes := match s.scopes with + | h::t => f h :: t + | [] => unreachable! } def getLevelNames : CommandElabM (List Name) := do -pure (← getScope).levelNames + pure (← getScope).levelNames -def addUnivLevel (idStx : Syntax) : CommandElabM Unit := -withRef idStx do -let id := idStx.getId -let levelNames ← getLevelNames -if levelNames.elem id then - throwAlreadyDeclaredUniverseLevel id -else - modifyScope fun scope => { scope with levelNames := id :: scope.levelNames } +def addUnivLevel (idStx : Syntax) : CommandElabM Unit := withRef idStx do + let id := idStx.getId + let levelNames ← getLevelNames + if levelNames.elem id then + throwAlreadyDeclaredUniverseLevel id + else + modifyScope fun scope => { scope with levelNames := id :: scope.levelNames } -partial def elabChoiceAux (cmds : Array Syntax) : Nat → CommandElabM Unit -| i => +partial def elabChoiceAux (cmds : Array Syntax) (i : Nat) : CommandElabM Unit := if h : i < cmds.size then let cmd := cmds.get ⟨i, h⟩; catchInternalId unsupportedSyntaxExceptionId @@ -364,20 +369,17 @@ partial def elabChoiceAux (cmds : Array Syntax) : Nat → CommandElabM Unit else throwUnsupportedSyntax -@[builtinCommandElab choice] def elbChoice : CommandElab := -fun stx => elabChoiceAux stx.getArgs 0 +@[builtinCommandElab choice] def elbChoice : CommandElab := fun stx => + elabChoiceAux stx.getArgs 0 -@[builtinCommandElab «universe»] def elabUniverse : CommandElab := -fun n => do +@[builtinCommandElab «universe»] def elabUniverse : CommandElab := fun n => do addUnivLevel n[1] -@[builtinCommandElab «universes»] def elabUniverses : CommandElab := -fun n => do +@[builtinCommandElab «universes»] def elabUniverses : CommandElab := fun n => do let idsStx := n[1] idsStx.forArgsM addUnivLevel -@[builtinCommandElab «init_quot»] def elabInitQuot : CommandElab := -fun stx => do +@[builtinCommandElab «init_quot»] def elabInitQuot : CommandElab := fun stx => do let env ← getEnv match env.addDecl Declaration.quotDecl with | Except.ok env => setEnv env @@ -386,10 +388,9 @@ fun stx => do throwError (ex.toMessageData opts) def logUnknownDecl (declName : Name) : CommandElabM Unit := -logError msg!"unknown declaration '{declName}'" + logError msg!"unknown declaration '{declName}'" -@[builtinCommandElab «export»] def elabExport : CommandElab := -fun stx => do +@[builtinCommandElab «export»] def elabExport : CommandElab := fun stx => do -- `stx` is of the form (Command.export "export" "(" (null *) ")") let id := stx[1].getId let ns ← resolveNamespace id @@ -408,63 +409,62 @@ fun stx => do modify fun s => { s with env := aliases.foldl (init := s.env) fun env p => addAlias env p.1 p.2 } def addOpenDecl (d : OpenDecl) : CommandElabM Unit := -modifyScope fun scope => { scope with openDecls := d :: scope.openDecls } + modifyScope fun scope => { scope with openDecls := d :: scope.openDecls } def elabOpenSimple (n : SyntaxNode) : CommandElabM Unit := --- `open` id+ -let nss := n.getArg 0 -nss.forArgsM fun ns => do - let ns ← resolveNamespace ns.getId - addOpenDecl (OpenDecl.simple ns []) + -- `open` id+ + let nss := n.getArg 0 + nss.forArgsM fun ns => do + let ns ← resolveNamespace ns.getId + addOpenDecl (OpenDecl.simple ns []) -- `open` id `(` id+ `)` def elabOpenOnly (n : SyntaxNode) : CommandElabM Unit := do -let ns := n.getIdAt 0 -let ns ← resolveNamespace ns -let ids := n.getArg 2 -ids.forArgsM fun idStx => do - let id := idStx.getId - let declName := ns ++ id - let env ← getEnv - if env.contains declName then - addOpenDecl (OpenDecl.explicit id declName) - else - withRef idStx $ logUnknownDecl declName + let ns := n.getIdAt 0 + let ns ← resolveNamespace ns + let ids := n.getArg 2 + ids.forArgsM fun idStx => do + let id := idStx.getId + let declName := ns ++ id + let env ← getEnv + if env.contains declName then + addOpenDecl (OpenDecl.explicit id declName) + else + withRef idStx $ logUnknownDecl declName -- `open` id `hiding` id+ def elabOpenHiding (n : SyntaxNode) : CommandElabM Unit := do -let ns := n.getIdAt 0 -let ns ← resolveNamespace ns -let idsStx := n.getArg 2 -let env ← getEnv -let ids : List Name ← idsStx.foldArgsM (fun idStx ids => do - let id := idStx.getId - let declName := ns ++ id - if env.contains declName then - pure (id::ids) - else do - withRef idStx $ logUnknownDecl declName - pure ids) - [] -addOpenDecl (OpenDecl.simple ns ids) + let ns := n.getIdAt 0 + let ns ← resolveNamespace ns + let idsStx := n.getArg 2 + let env ← getEnv + let ids : List Name ← idsStx.foldArgsM (fun idStx ids => do + let id := idStx.getId + let declName := ns ++ id + if env.contains declName then + pure (id::ids) + else do + withRef idStx $ logUnknownDecl declName + pure ids) + [] + addOpenDecl (OpenDecl.simple ns ids) -- `open` id `renaming` sepBy (id `->` id) `,` def elabOpenRenaming (n : SyntaxNode) : CommandElabM Unit := do -let ns := n.getIdAt 0 -let ns ← resolveNamespace ns -let rs := (n.getArg 2) -rs.forSepArgsM $ fun stx => do - let fromId := stx.getIdAt 0 - let toId := stx.getIdAt 2 - let declName := ns ++ fromId - let env ← getEnv - if env.contains declName then - addOpenDecl (OpenDecl.explicit toId declName) - else - withRef stx $ logUnknownDecl declName + let ns := n.getIdAt 0 + let ns ← resolveNamespace ns + let rs := (n.getArg 2) + rs.forSepArgsM $ fun stx => do + let fromId := stx.getIdAt 0 + let toId := stx.getIdAt 2 + let declName := ns ++ fromId + let env ← getEnv + if env.contains declName then + addOpenDecl (OpenDecl.explicit toId declName) + else + withRef stx $ logUnknownDecl declName -@[builtinCommandElab «open»] def elabOpen : CommandElab := -fun n => do +@[builtinCommandElab «open»] def elabOpen : CommandElab := fun n => do let body := (n.getArg 1).asNode let k := body.getKind; if k == `Lean.Parser.Command.openSimple then @@ -476,16 +476,14 @@ fun n => do else elabOpenRenaming body -@[builtinCommandElab «variable»] def elabVariable : CommandElab := -fun n => do +@[builtinCommandElab «variable»] def elabVariable : CommandElab := fun n => do -- `variable` bracketedBinder let binder := n[1] -- Try to elaborate `binder` for sanity checking runTermElabM none fun _ => Term.elabBinder binder fun _ => pure () modifyScope fun scope => { scope with varDecls := scope.varDecls.push binder } -@[builtinCommandElab «variables»] def elabVariables : CommandElab := -fun n => do +@[builtinCommandElab «variables»] def elabVariables : CommandElab := fun n => do -- `variables` bracketedBinder+ let binders := n[1].getArgs -- Try to elaborate `binders` for sanity checking @@ -494,8 +492,7 @@ fun n => do open Meta -@[builtinCommandElab Lean.Parser.Command.check] def elabCheck : CommandElab := -fun stx => do +@[builtinCommandElab Lean.Parser.Command.check] def elabCheck : CommandElab := fun stx => do let term := stx[1] withoutModifyingEnv $ runTermElabM (some `_check) $ fun _ => do let e ← Term.elabTerm term none @@ -504,34 +501,33 @@ fun stx => do logInfo msg!"{e} : {type}" def hasNoErrorMessages : CommandElabM Bool := do -return !(← get).messages.hasErrors + return !(← get).messages.hasErrors def failIfSucceeds (x : CommandElabM Unit) : CommandElabM Unit := do -let resetMessages : CommandElabM MessageLog := do - let s ← get - let messages := s.messages; - modify fun s => { s with messages := {} }; - pure messages -let restoreMessages (prevMessages : MessageLog) : CommandElabM Unit := do - modify fun s => { s with messages := prevMessages ++ s.messages.errorsToWarnings } -let prevMessages ← resetMessages -let succeeded ← - try - x - hasNoErrorMessages - catch - | ex@(Exception.error _ _) => do logException ex; pure false - | Exception.internal id => do logError "internal"; pure false -- TODO: improve `logError "internal"` - finally - restoreMessages prevMessages -if succeeded then - throwError "unexpected success" + let resetMessages : CommandElabM MessageLog := do + let s ← get + let messages := s.messages; + modify fun s => { s with messages := {} }; + pure messages + let restoreMessages (prevMessages : MessageLog) : CommandElabM Unit := do + modify fun s => { s with messages := prevMessages ++ s.messages.errorsToWarnings } + let prevMessages ← resetMessages + let succeeded ← + try + x + hasNoErrorMessages + catch + | ex@(Exception.error _ _) => do logException ex; pure false + | Exception.internal id => do logError "internal"; pure false -- TODO: improve `logError "internal"` + finally + restoreMessages prevMessages + if succeeded then + throwError "unexpected success" -@[builtinCommandElab «check_failure»] def elabCheckFailure : CommandElab := -fun stx => failIfSucceeds $ elabCheck stx +@[builtinCommandElab «check_failure»] def elabCheckFailure : CommandElab := fun stx => + failIfSucceeds $ elabCheck stx -unsafe def elabEvalUnsafe : CommandElab := -fun stx => do +unsafe def elabEvalUnsafe : CommandElab := fun stx => do let ref := stx let term := stx[1] let n := `_eval @@ -580,8 +576,7 @@ fun stx => do @[builtinCommandElab «eval», implementedBy elabEvalUnsafe] constant elabEval : CommandElab -@[builtinCommandElab «synth»] def elabSynth : CommandElab := -fun stx => do +@[builtinCommandElab «synth»] def elabSynth : CommandElab := fun stx => do let term := stx[1] withoutModifyingEnv $ runTermElabM `_synth_cmd fun _ => do let inst ← Term.elabTerm term none @@ -592,15 +587,14 @@ fun stx => do pure () def setOption (optionName : Name) (val : DataValue) : CommandElabM Unit := do -let decl ← liftIO $ getOptionDecl optionName -unless decl.defValue.sameCtor val do throwError "type mismatch at set_option" -modifyScope fun scope => { scope with opts := scope.opts.insert optionName val } -match optionName, val with -| `maxRecDepth, DataValue.ofNat max => modify fun s => { s with maxRecDepth := max } -| _, _ => pure () + let decl ← liftIO $ getOptionDecl optionName + unless decl.defValue.sameCtor val do throwError "type mismatch at set_option" + modifyScope fun scope => { scope with opts := scope.opts.insert optionName val } + match optionName, val with + | `maxRecDepth, DataValue.ofNat max => modify fun s => { s with maxRecDepth := max } + | _, _ => pure () -@[builtinCommandElab «set_option»] def elabSetOption : CommandElab := -fun stx => do +@[builtinCommandElab «set_option»] def elabSetOption : CommandElab := fun stx => do let optionName := stx[1].getId let val := stx[2] match val.isStrLit? with @@ -614,15 +608,14 @@ fun stx => do | Syntax.atom _ "false" => setOption optionName (DataValue.ofBool false) | _ => logErrorAt val msg!"unexpected set_option value {val}" -@[builtinMacro Lean.Parser.Command.«in»] def expandInCmd : Macro := -fun stx => do +@[builtinMacro Lean.Parser.Command.«in»] def expandInCmd : Macro := fun stx => do let cmd₁ := stx[0] let cmd₂ := stx[2] `(section $cmd₁:command $cmd₂:command end) def expandDeclId (declId : Syntax) (modifiers : Modifiers) : CommandElabM ExpandDeclIdResult := do -let currNamespace ← getCurrNamespace -let currLevelNames ← getLevelNames -Lean.Elab.expandDeclId currNamespace currLevelNames declId modifiers + let currNamespace ← getCurrNamespace + let currLevelNames ← getLevelNames + Lean.Elab.expandDeclId currNamespace currLevelNames declId modifiers end Lean.Elab.Command diff --git a/src/Lean/Elab/DeclModifiers.lean b/src/Lean/Elab/DeclModifiers.lean index 863eba2593..7cb59825ef 100644 --- a/src/Lean/Elab/DeclModifiers.lean +++ b/src/Lean/Elab/DeclModifiers.lean @@ -11,51 +11,48 @@ import Lean.Elab.DeclUtil namespace Lean.Elab -def checkNotAlreadyDeclared {m} [Monad m] [MonadEnv m] [MonadExceptOf Exception m] [Ref m] [AddErrorMessageContext m] - (declName : Name) : m Unit := do -let env ← getEnv -if env.contains declName then - match privateToUserName? declName with - | none => throwError! "'{declName}' has already been declared" - | some declName => throwError! "private declaration '{declName}' has already been declared" -if env.contains (mkPrivateName env declName) then - throwError! "a private declaration '{declName}' has already been declared" -match privateToUserName? declName with -| none => pure () -| some declName => +def checkNotAlreadyDeclared {m} [Monad m] [MonadEnv m] [MonadExceptOf Exception m] [Ref m] [AddErrorMessageContext m] (declName : Name) : m Unit := do + let env ← getEnv if env.contains declName then - throwError! "a non-private declaration '{declName}' has already been declared" + match privateToUserName? declName with + | none => throwError! "'{declName}' has already been declared" + | some declName => throwError! "private declaration '{declName}' has already been declared" + if env.contains (mkPrivateName env declName) then + throwError! "a private declaration '{declName}' has already been declared" + match privateToUserName? declName with + | none => pure () + | some declName => + if env.contains declName then + throwError! "a non-private declaration '{declName}' has already been declared" inductive Visibility -| regular | «protected» | «private» + | regular | «protected» | «private» -instance Visibility.hasToString : HasToString Visibility := -⟨fun - | regular => "regular" - | «private» => "private" - | «protected» => "protected"⟩ +instance : HasToString Visibility := ⟨fun + | Visibility.regular => "regular" + | Visibility.«private» => "private" + | Visibility.«protected» => "protected"⟩ structure Modifiers := -(docString : Option String := none) -(visibility : Visibility := Visibility.regular) -(isNoncomputable : Bool := false) -(isPartial : Bool := false) -(isUnsafe : Bool := false) -(attrs : Array Attribute := #[]) + (docString : Option String := none) + (visibility : Visibility := Visibility.regular) + (isNoncomputable : Bool := false) + (isPartial : Bool := false) + (isUnsafe : Bool := false) + (attrs : Array Attribute := #[]) def Modifiers.isPrivate : Modifiers → Bool -| { visibility := Visibility.private, .. } => true -| _ => false + | { visibility := Visibility.private, .. } => true + | _ => false def Modifiers.isProtected : Modifiers → Bool -| { visibility := Visibility.protected, .. } => true -| _ => false + | { visibility := Visibility.protected, .. } => true + | _ => false def Modifiers.addAttribute (modifiers : Modifiers) (attr : Attribute) : Modifiers := -{ modifiers with attrs := modifiers.attrs.push attr } + { modifiers with attrs := modifiers.attrs.push attr } -instance Modifiers.hasFormat : HasFormat Modifiers := -⟨fun m => +instance : HasFormat Modifiers := ⟨fun m => let components : List Format := (match m.docString with | some str => ["/--" ++ str ++ "-/"] @@ -70,72 +67,72 @@ instance Modifiers.hasFormat : HasFormat Modifiers := ++ m.attrs.toList.map (fun attr => fmt attr) Format.bracket "{" (Format.joinSep components ("," ++ Format.line)) "}"⟩ -instance Modifiers.hasToString : HasToString Modifiers := ⟨toString ∘ format⟩ +instance : HasToString Modifiers := ⟨toString ∘ format⟩ section Methods variables {m : Type → Type} [Monad m] [MonadEnv m] [MonadExceptOf Exception m] [Ref m] [AddErrorMessageContext m] def elabModifiers (stx : Syntax) : m Modifiers := do -let docCommentStx := stx[0] -let attrsStx := stx[1] -let visibilityStx := stx[2] -let noncompStx := stx[3] -let unsafeStx := stx[4] -let partialStx := stx[5] -let docString ← match docCommentStx.getOptional? with - | none => pure none - | some s => match s[1] with - | Syntax.atom _ val => pure (some (val.extract 0 (val.bsize - 2))) - | _ => throwErrorAt! s "unexpected doc string {s[1]}" -let visibility ← match visibilityStx.getOptional? with - | none => pure Visibility.regular - | some v => - let kind := v.getKind - if kind == `Lean.Parser.Command.private then pure Visibility.private - else if kind == `Lean.Parser.Command.protected then pure Visibility.protected - else throwErrorAt v "unexpected visibility modifier" -let attrs ← match attrsStx.getOptional? with - | none => pure #[] - | some attrs => elabDeclAttrs attrs -pure { - docString := docString, - visibility := visibility, - isPartial := !partialStx.isNone, - isUnsafe := !unsafeStx.isNone, - isNoncomputable := !noncompStx.isNone, - attrs := attrs -} + let docCommentStx := stx[0] + let attrsStx := stx[1] + let visibilityStx := stx[2] + let noncompStx := stx[3] + let unsafeStx := stx[4] + let partialStx := stx[5] + let docString ← match docCommentStx.getOptional? with + | none => pure none + | some s => match s[1] with + | Syntax.atom _ val => pure (some (val.extract 0 (val.bsize - 2))) + | _ => throwErrorAt! s "unexpected doc string {s[1]}" + let visibility ← match visibilityStx.getOptional? with + | none => pure Visibility.regular + | some v => + let kind := v.getKind + if kind == `Lean.Parser.Command.private then pure Visibility.private + else if kind == `Lean.Parser.Command.protected then pure Visibility.protected + else throwErrorAt v "unexpected visibility modifier" + let attrs ← match attrsStx.getOptional? with + | none => pure #[] + | some attrs => elabDeclAttrs attrs + pure { + docString := docString, + visibility := visibility, + isPartial := !partialStx.isNone, + isUnsafe := !unsafeStx.isNone, + isNoncomputable := !noncompStx.isNone, + attrs := attrs + } def applyVisibility (visibility : Visibility) (declName : Name) : m Name := do -match visibility with -| Visibility.private => - let env ← getEnv - let declName := mkPrivateName env declName - checkNotAlreadyDeclared declName - pure declName -| Visibility.protected => - checkNotAlreadyDeclared declName - let env ← getEnv - let env := addProtected env declName - setEnv env - pure declName -| _ => - checkNotAlreadyDeclared declName - pure declName + match visibility with + | Visibility.private => + let env ← getEnv + let declName := mkPrivateName env declName + checkNotAlreadyDeclared declName + pure declName + | Visibility.protected => + checkNotAlreadyDeclared declName + let env ← getEnv + let env := addProtected env declName + setEnv env + pure declName + | _ => + checkNotAlreadyDeclared declName + pure declName def mkDeclName (currNamespace : Name) (modifiers : Modifiers) (shortName : Name) : m (Name × Name) := do -let name := (extractMacroScopes shortName).name -unless name.isAtomic || isFreshInstanceName name do - throwError! "atomic identifier expected '{shortName}'" -let declName := currNamespace ++ shortName -let declName ← applyVisibility modifiers.visibility declName -match modifiers.visibility with -| Visibility.protected => - match currNamespace with - | Name.str _ s _ => pure (declName, mkNameSimple s ++ shortName) - | _ => throwError "protected declarations must be in a namespace" -| _ => pure (declName, shortName) + let name := (extractMacroScopes shortName).name + unless name.isAtomic || isFreshInstanceName name do + throwError! "atomic identifier expected '{shortName}'" + let declName := currNamespace ++ shortName + let declName ← applyVisibility modifiers.visibility declName + match modifiers.visibility with + | Visibility.protected => + match currNamespace with + | Name.str _ s _ => pure (declName, mkNameSimple s ++ shortName) + | _ => throwError "protected declarations must be in a namespace" + | _ => pure (declName, shortName) /- `declId` is of the form @@ -145,36 +142,36 @@ match modifiers.visibility with but we also accept a single identifier to users to make macro writing more convenient . -/ def expandDeclIdCore (declId : Syntax) : Name × Syntax := -if declId.isIdent then - (declId.getId, mkNullNode) -else - let id := declId[0].getId - let optUnivDeclStx := declId[1] - (id, optUnivDeclStx) + if declId.isIdent then + (declId.getId, mkNullNode) + else + let id := declId[0].getId + let optUnivDeclStx := declId[1] + (id, optUnivDeclStx) structure ExpandDeclIdResult := -(shortName : Name) -(declName : Name) -(levelNames : List Name) + (shortName : Name) + (declName : Name) + (levelNames : List Name) def expandDeclId (currNamespace : Name) (currLevelNames : List Name) (declId : Syntax) (modifiers : Modifiers) : m ExpandDeclIdResult := do --- ident >> optional (".{" >> sepBy1 ident ", " >> "}") -let (shortName, optUnivDeclStx) := expandDeclIdCore declId -let levelNames ← - if optUnivDeclStx.isNone then - pure currLevelNames - else - let extraLevels := optUnivDeclStx[1].getArgs.getEvenElems - extraLevels.foldlM - (fun levelNames idStx => - let id := idStx.getId - if levelNames.elem id then - withRef idStx $ throwAlreadyDeclaredUniverseLevel id - else - pure (id :: levelNames)) - currLevelNames -let (declName, shortName) ← withRef declId $ mkDeclName currNamespace modifiers shortName -pure { shortName := shortName, declName := declName, levelNames := levelNames } + -- ident >> optional (".{" >> sepBy1 ident ", " >> "}") + let (shortName, optUnivDeclStx) := expandDeclIdCore declId + let levelNames ← + if optUnivDeclStx.isNone then + pure currLevelNames + else + let extraLevels := optUnivDeclStx[1].getArgs.getEvenElems + extraLevels.foldlM + (fun levelNames idStx => + let id := idStx.getId + if levelNames.elem id then + withRef idStx $ throwAlreadyDeclaredUniverseLevel id + else + pure (id :: levelNames)) + currLevelNames + let (declName, shortName) ← withRef declId $ mkDeclName currNamespace modifiers shortName + pure { shortName := shortName, declName := declName, levelNames := levelNames } end Methods diff --git a/src/Lean/Elab/Do.lean b/src/Lean/Elab/Do.lean index 1ea0df2f9b..8ed83de00e 100644 --- a/src/Lean/Elab/Do.lean +++ b/src/Lean/Elab/Do.lean @@ -13,62 +13,64 @@ namespace Lean.Elab.Term open Meta private def getDoSeqElems (doSeq : Syntax) : List Syntax := -if doSeq.getKind == `Lean.Parser.Term.doSeqBracketed then - doSeq[1].getArgs.toList.map fun arg => arg[0] -else if doSeq.getKind == `Lean.Parser.Term.doSeqIndent then - doSeq[0].getArgs.toList.map fun arg => arg[0] -else - [] + if doSeq.getKind == `Lean.Parser.Term.doSeqBracketed then + doSeq[1].getArgs.toList.map fun arg => arg[0] + else if doSeq.getKind == `Lean.Parser.Term.doSeqIndent then + doSeq[0].getArgs.toList.map fun arg => arg[0] + else + [] private def getDoSeq (doStx : Syntax) : Syntax := -doStx[1] + doStx[1] -@[builtinTermElab liftMethod] def elabLiftMethod : TermElab := -fun stx _ => +@[builtinTermElab liftMethod] def elabLiftMethod : TermElab := fun stx _ => throwErrorAt stx "invalid use of `(<- ...)`, must be nested inside a 'do' expression" private partial def hasLiftMethod : Syntax → Bool -| Syntax.node k args => - if k == `Lean.Parser.Term.do then false - else if k == `Lean.Parser.Term.doSeqIndent then false - else if k == `Lean.Parser.Term.doSeqBracketed then false - else if k == `Lean.Parser.Term.quot then false - else if k == `Lean.Parser.Term.liftMethod then true - else args.any hasLiftMethod -| _ => false + | Syntax.node k args => + if k == `Lean.Parser.Term.do then false + else if k == `Lean.Parser.Term.doSeqIndent then false + else if k == `Lean.Parser.Term.doSeqBracketed then false + else if k == `Lean.Parser.Term.quot then false + else if k == `Lean.Parser.Term.liftMethod then true + else args.any hasLiftMethod + | _ => false structure ExtractMonadResult := -(m : Expr) -(α : Expr) -(hasBindInst : Expr) + (m : Expr) + (α : Expr) + (hasBindInst : Expr) private def mkIdBindFor (type : Expr) : TermElabM ExtractMonadResult := do -let u ← getDecLevel type -let id := Lean.mkConst `Id [u] -let idBindVal := Lean.mkConst `Id.hasBind [u] -pure { m := id, hasBindInst := idBindVal, α := type } + let u ← getDecLevel type + let id := Lean.mkConst `Id [u] + let idBindVal := Lean.mkConst `Id.hasBind [u] + pure { m := id, hasBindInst := idBindVal, α := type } private def extractBind (expectedType? : Option Expr) : TermElabM ExtractMonadResult := do -match expectedType? with -| none => throwError "invalid do notation, expected type is not available" -| some expectedType => - let type ← withReducible $ whnf expectedType - if type.getAppFn.isMVar then throwError "invalid do notation, expected type is not available" - match type with - | Expr.app m α _ => - try - let bindInstType ← mkAppM `HasBind #[m] - let bindInstVal ← synthesizeInst bindInstType - pure { m := m, hasBindInst := bindInstVal, α := α } - catch _ => - mkIdBindFor type - | _ => mkIdBindFor type + match expectedType? with + | none => throwError "invalid do notation, expected type is not available" + | some expectedType => + let type ← withReducible $ whnf expectedType + if type.getAppFn.isMVar then throwError "invalid do notation, expected type is not available" + match type with + | Expr.app m α _ => + try + let bindInstType ← mkAppM `HasBind #[m] + let bindInstVal ← synthesizeInst bindInstType + pure { m := m, hasBindInst := bindInstVal, α := α } + catch _ => + mkIdBindFor type + | _ => mkIdBindFor type namespace Do /- A `doMatch` alternative. `vars` is the array of variables declared by `patterns`. -/ structure Alt (σ : Type) := -(ref : Syntax) (vars : Array Name) (patterns : Syntax) (rhs : σ) + (ref : Syntax) + (vars : Array Name) + (patterns : Syntax) + (rhs : σ) /- Auxiliary datastructure for representing a `do` code block, and compiling "reassignments" (e.g., `x := x + 1`). @@ -119,204 +121,206 @@ structure Alt (σ : Type) := - For every `jmp ref j as` in `C`, there is a `joinpoint j ps b k` and `jmp ref j as` is in `k`, and `ps.size == as.size` -/ inductive Code -| decl (xs : Array Name) (doElem : Syntax) (k : Code) -| reassign (xs : Array Name) (doElem : Syntax) (k : Code) -/- The Boolean value in `params` indicates whether we should use `(x : typeof! x)` when generating term Syntax or not -/ -| joinpoint (name : Name) (params : Array (Name × Bool)) (body : Code) (k : Code) -| seq (action : Syntax) (k : Code) -| action (action : Syntax) -| «break» (ref : Syntax) -| «continue» (ref : Syntax) -| «return» (ref : Syntax) (val : Syntax) -/- Recall that an if-then-else may declare a variable using `optIdent` for the branches `thenBranch` and `elseBranch`. We store the variable name at `var?`. -/ -| ite (ref : Syntax) (h? : Option Name) (optIdent : Syntax) (cond : Syntax) (thenBranch : Code) (elseBranch : Code) -| «match» (ref : Syntax) (discrs : Syntax) (optType : Syntax) (alts : Array (Alt Code)) -| jmp (ref : Syntax) (jpName : Name) (args : Array Syntax) + | decl (xs : Array Name) (doElem : Syntax) (k : Code) + | reassign (xs : Array Name) (doElem : Syntax) (k : Code) + /- The Boolean value in `params` indicates whether we should use `(x : typeof! x)` when generating term Syntax or not -/ + | joinpoint (name : Name) (params : Array (Name × Bool)) (body : Code) (k : Code) + | seq (action : Syntax) (k : Code) + | action (action : Syntax) + | «break» (ref : Syntax) + | «continue» (ref : Syntax) + | «return» (ref : Syntax) (val : Syntax) + /- Recall that an if-then-else may declare a variable using `optIdent` for the branches `thenBranch` and `elseBranch`. We store the variable name at `var?`. -/ + | ite (ref : Syntax) (h? : Option Name) (optIdent : Syntax) (cond : Syntax) (thenBranch : Code) (elseBranch : Code) + | «match» (ref : Syntax) (discrs : Syntax) (optType : Syntax) (alts : Array (Alt Code)) + | jmp (ref : Syntax) (jpName : Name) (args : Array Syntax) -instance Code.inhabited : Inhabited Code := -⟨«break» (arbitrary _)⟩ +instance : Inhabited Code := + ⟨Code.«break» (arbitrary _)⟩ -instance Alt.inhabited : Inhabited (Alt Code) := -⟨{ ref := arbitrary _, vars := #[], patterns := arbitrary _, rhs := arbitrary _ }⟩ +instance : Inhabited (Alt Code) := + ⟨{ ref := arbitrary _, vars := #[], patterns := arbitrary _, rhs := arbitrary _ }⟩ /- A code block, and the collection of variables updated by it. -/ structure CodeBlock := -(code : Code) -(uvars : NameSet := {}) -- set of variables updated by `code` + (code : Code) + (uvars : NameSet := {}) -- set of variables updated by `code` private def nameSetToArray (s : NameSet) : Array Name := -s.fold (fun (xs : Array Name) x => xs.push x) #[] + s.fold (fun (xs : Array Name) x => xs.push x) #[] private def varsToMessageData (vars : Array Name) : MessageData := -MessageData.joinSep (vars.toList.map fun n => MessageData.ofName (n.simpMacroScopes)) " " + MessageData.joinSep (vars.toList.map fun n => MessageData.ofName (n.simpMacroScopes)) " " partial def CodeBlocl.toMessageData (codeBlock : CodeBlock) : MessageData := -let us := MessageData.ofList $ (nameSetToArray codeBlock.uvars).toList.map MessageData.ofName -let rec loop : Code → MessageData - | Code.decl xs _ k => msg!"let {varsToMessageData xs} := ...\n{loop k}" - | Code.reassign xs _ k => msg!"{varsToMessageData xs} := ...\n{loop k}" - | Code.joinpoint n ps body k => msg!"let {n.simpMacroScopes} {varsToMessageData (ps.map Prod.fst)} := {indentD (loop body)}\n{loop k}" - | Code.seq e k => msg!"{e}\n{loop k}" - | Code.action e => e - | Code.ite _ _ _ c t e => msg!"if {c} then {indentD (loop t)}\nelse{loop e}" - | Code.jmp _ j xs => msg!"jmp {j.simpMacroScopes} {xs.toList}" - | Code.«break» _ => msg!"break {us}" - | Code.«continue» _ => msg!"continue {us}" - | Code.«return» _ v => msg!"return {v} {us}" - | Code.«match» _ ds t alts => - msg!"match {ds} with" - ++ alts.foldl (init := "") fun acc alt => acc ++ msg!"\n| {alt.patterns} => {loop alt.rhs}" -loop codeBlock.code + let us := MessageData.ofList $ (nameSetToArray codeBlock.uvars).toList.map MessageData.ofName + let rec loop : Code → MessageData + | Code.decl xs _ k => msg!"let {varsToMessageData xs} := ...\n{loop k}" + | Code.reassign xs _ k => msg!"{varsToMessageData xs} := ...\n{loop k}" + | Code.joinpoint n ps body k => msg!"let {n.simpMacroScopes} {varsToMessageData (ps.map Prod.fst)} := {indentD (loop body)}\n{loop k}" + | Code.seq e k => msg!"{e}\n{loop k}" + | Code.action e => e + | Code.ite _ _ _ c t e => msg!"if {c} then {indentD (loop t)}\nelse{loop e}" + | Code.jmp _ j xs => msg!"jmp {j.simpMacroScopes} {xs.toList}" + | Code.«break» _ => msg!"break {us}" + | Code.«continue» _ => msg!"continue {us}" + | Code.«return» _ v => msg!"return {v} {us}" + | Code.«match» _ ds t alts => + msg!"match {ds} with" + ++ alts.foldl (init := "") fun acc alt => acc ++ msg!"\n| {alt.patterns} => {loop alt.rhs}" + loop codeBlock.code /- Return true if the give code contains an exit point that satisfies `p` -/ @[inline] partial def hasExitPointPred (c : Code) (p : Code → Bool) : Bool := -let rec @[specialize] loop : Code → Bool - | Code.decl _ _ k => loop k - | Code.reassign _ _ k => loop k - | Code.joinpoint _ _ b k => loop b || loop k - | Code.seq _ k => loop k - | Code.ite _ _ _ _ t e => loop t || loop e - | Code.«match» _ _ _ alts => alts.any (loop ·.rhs) - | Code.jmp _ _ _ => false - | c => p c -loop c + let rec @[specialize] loop : Code → Bool + | Code.decl _ _ k => loop k + | Code.reassign _ _ k => loop k + | Code.joinpoint _ _ b k => loop b || loop k + | Code.seq _ k => loop k + | Code.ite _ _ _ _ t e => loop t || loop e + | Code.«match» _ _ _ alts => alts.any (loop ·.rhs) + | Code.jmp _ _ _ => false + | c => p c + loop c def hasExitPoint (c : Code) : Bool := -hasExitPointPred c fun c => true + hasExitPointPred c fun c => true def hasReturn (c : Code) : Bool := -hasExitPointPred c fun - | Code.«return» _ _ => true - | _ => false + hasExitPointPred c fun + | Code.«return» _ _ => true + | _ => false def hasTerminalAction (c : Code) : Bool := -hasExitPointPred c fun - | Code.«action» _ => true - | _ => false + hasExitPointPred c fun + | Code.«action» _ => true + | _ => false def hasBreakContinue (c : Code) : Bool := -hasExitPointPred c fun - | Code.«break» _ => true - | Code.«continue» _ => true - | _ => false + hasExitPointPred c fun + | Code.«break» _ => true + | Code.«continue» _ => true + | _ => false def hasBreakContinueReturn (c : Code) : Bool := -hasExitPointPred c fun - | Code.«break» _ => true - | Code.«continue» _ => true - | Code.«return» _ _ => true - | _ => false + hasExitPointPred c fun + | Code.«break» _ => true + | Code.«continue» _ => true + | Code.«return» _ _ => true + | _ => false def mkAuxDeclFor {m} [Monad m] [MonadQuotation m] (e : Syntax) (mkCont : Syntax → m Code) : m Code := withFreshMacroScope do -let y ← `(y) -let yName := y.getId -let doElem ← `(doElem| let y ← $e:term) --- Add elaboration hint for producing sane error message -let y ← `(ensureExpectedType! "type mismatch, result value" $y) -let k ← mkCont y -pure $ Code.decl #[yName] doElem k + let y ← `(y) + let yName := y.getId + let doElem ← `(doElem| let y ← $e:term) + -- Add elaboration hint for producing sane error message + let y ← `(ensureExpectedType! "type mismatch, result value" $y) + let k ← mkCont y + pure $ Code.decl #[yName] doElem k /- Convert `action _ e` instructions in `c` into `let y ← e; jmp _ jp (xs y)`. -/ partial def convertTerminalActionIntoJmp (code : Code) (jp : Name) (xs : Array Name) : MacroM Code := -let rec loop : Code → MacroM Code - | Code.decl xs stx k => do Code.decl xs stx (← loop k) - | Code.reassign xs stx k => do Code.reassign xs stx (← loop k) - | Code.joinpoint n ps b k => do Code.joinpoint n ps (← loop b) (← loop k) - | Code.seq e k => do Code.seq e (← loop k) - | Code.ite ref x? h c t e => do Code.ite ref x? h c (← loop t) (← loop e) - | Code.«match» ref ds t alts => do Code.«match» ref ds t (← alts.mapM fun alt => do pure { alt with rhs := (← loop alt.rhs) }) - | Code.action e => mkAuxDeclFor e fun y => - let ref := e - -- We jump to `jp` with xs **and** y - let jmpArgs := xs.map $ mkIdentFrom ref - let jmpArgs := jmpArgs.push y - pure $ Code.jmp ref jp jmpArgs - | c => pure c -loop code + let rec loop : Code → MacroM Code + | Code.decl xs stx k => do Code.decl xs stx (← loop k) + | Code.reassign xs stx k => do Code.reassign xs stx (← loop k) + | Code.joinpoint n ps b k => do Code.joinpoint n ps (← loop b) (← loop k) + | Code.seq e k => do Code.seq e (← loop k) + | Code.ite ref x? h c t e => do Code.ite ref x? h c (← loop t) (← loop e) + | Code.«match» ref ds t alts => do Code.«match» ref ds t (← alts.mapM fun alt => do pure { alt with rhs := (← loop alt.rhs) }) + | Code.action e => mkAuxDeclFor e fun y => + let ref := e + -- We jump to `jp` with xs **and** y + let jmpArgs := xs.map $ mkIdentFrom ref + let jmpArgs := jmpArgs.push y + pure $ Code.jmp ref jp jmpArgs + | c => pure c + loop code structure JPDecl := -(name : Name) (params : Array (Name × Bool)) (body : Code) + (name : Name) + (params : Array (Name × Bool)) + (body : Code) def attachJP (jpDecl : JPDecl) (k : Code) : Code := -Code.joinpoint jpDecl.name jpDecl.params jpDecl.body k + Code.joinpoint jpDecl.name jpDecl.params jpDecl.body k def attachJPs (jpDecls : Array JPDecl) (k : Code) : Code := -jpDecls.foldr attachJP k + jpDecls.foldr attachJP k def mkFreshJP (ps : Array (Name × Bool)) (body : Code) : TermElabM JPDecl := do -let ps ← - if ps.isEmpty then - let y ← mkFreshUserName `y - pure #[(y, false)] - else - pure ps --- Remark: the compiler frontend implemented in C++ currently detects jointpoints created by --- the "do" notation by testing the name. See hack at method `visit_let` at `lcnf.cpp` --- We will remove this hack when we re-implement the compiler frontend in Lean. -let name ← mkFreshUserName `_do_jp -pure { name := name, params := ps, body := body } + let ps ← + if ps.isEmpty then + let y ← mkFreshUserName `y + pure #[(y, false)] + else + pure ps + -- Remark: the compiler frontend implemented in C++ currently detects jointpoints created by + -- the "do" notation by testing the name. See hack at method `visit_let` at `lcnf.cpp` + -- We will remove this hack when we re-implement the compiler frontend in Lean. + let name ← mkFreshUserName `_do_jp + pure { name := name, params := ps, body := body } def mkFreshJP' (xs : Array Name) (body : Code) : TermElabM JPDecl := -mkFreshJP (xs.map fun x => (x, true)) body + mkFreshJP (xs.map fun x => (x, true)) body def addFreshJP (ps : Array (Name × Bool)) (body : Code) : StateRefT (Array JPDecl) TermElabM Name := do -let jp ← mkFreshJP ps body -modify fun (jps : Array JPDecl) => jps.push jp -pure jp.name + let jp ← mkFreshJP ps body + modify fun (jps : Array JPDecl) => jps.push jp + pure jp.name def insertVars (rs : NameSet) (xs : Array Name) : NameSet := -xs.foldl (·.insert ·) rs + xs.foldl (·.insert ·) rs def eraseVars (rs : NameSet) (xs : Array Name) : NameSet := -xs.foldl (·.erase ·) rs + xs.foldl (·.erase ·) rs def eraseOptVar (rs : NameSet) (x? : Option Name) : NameSet := -match x? with -| none => rs -| some x => rs.insert x + match x? with + | none => rs + | some x => rs.insert x /- Create a new jointpoint for `c`, and jump to it with the variables `rs` -/ def mkSimpleJmp (ref : Syntax) (rs : NameSet) (c : Code) : StateRefT (Array JPDecl) TermElabM Code := do -let xs := nameSetToArray rs -let jp ← addFreshJP (xs.map fun x => (x, true)) c -if xs.isEmpty then - let unit ← `(Unit.unit) - pure $ Code.jmp ref jp #[unit] -else - pure $ Code.jmp ref jp (xs.map $ mkIdentFrom ref) + let xs := nameSetToArray rs + let jp ← addFreshJP (xs.map fun x => (x, true)) c + if xs.isEmpty then + let unit ← `(Unit.unit) + pure $ Code.jmp ref jp #[unit] + else + pure $ Code.jmp ref jp (xs.map $ mkIdentFrom ref) /- Create a new joinpoint that takes `rs` and `val` as arguments. `val` must be syntax representing a pure value. The body of the joinpoint is created using `mkJPBody yFresh`, where `yFresh` is a fresh variable created by this method. -/ def mkJmp (ref : Syntax) (rs : NameSet) (val : Syntax) (mkJPBody : Syntax → MacroM Code) : StateRefT (Array JPDecl) TermElabM Code := do -let xs := nameSetToArray rs -let args := xs.map $ mkIdentFrom ref -let args := args.push val -let yFresh ← mkFreshUserName `y -let ps := xs.map fun x => (x, true) -let ps := ps.push (yFresh, false) -let jpBody ← liftMacroM $ mkJPBody (mkIdentFrom ref yFresh) -let jp ← addFreshJP ps jpBody -pure $ Code.jmp ref jp args + let xs := nameSetToArray rs + let args := xs.map $ mkIdentFrom ref + let args := args.push val + let yFresh ← mkFreshUserName `y + let ps := xs.map fun x => (x, true) + let ps := ps.push (yFresh, false) + let jpBody ← liftMacroM $ mkJPBody (mkIdentFrom ref yFresh) + let jp ← addFreshJP ps jpBody + pure $ Code.jmp ref jp args /- `pullExitPointsAux rs c` auxiliary method for `pullExitPoints`, `rs` is the set of update variable in the current path. -/ partial def pullExitPointsAux : NameSet → Code → StateRefT (Array JPDecl) TermElabM Code -| rs, Code.decl xs stx k => do Code.decl xs stx (← pullExitPointsAux (eraseVars rs xs) k) -| rs, Code.reassign xs stx k => do Code.reassign xs stx (← pullExitPointsAux (insertVars rs xs) k) -| rs, Code.joinpoint j ps b k => do Code.joinpoint j ps (← pullExitPointsAux rs b) (← pullExitPointsAux rs k) -| rs, Code.seq e k => do Code.seq e (← pullExitPointsAux rs k) -| rs, Code.ite ref x? o c t e => do Code.ite ref x? o c (← pullExitPointsAux (eraseOptVar rs x?) t) (← pullExitPointsAux (eraseOptVar rs x?) e) -| rs, Code.«match» ref ds t alts => do - Code.«match» ref ds t (← alts.mapM fun alt => do pure { alt with rhs := (← pullExitPointsAux (eraseVars rs alt.vars) alt.rhs) }) -| rs, c@(Code.jmp _ _ _) => pure c -| rs, Code.«break» ref => mkSimpleJmp ref rs (Code.«break» ref) -| rs, Code.«continue» ref => mkSimpleJmp ref rs (Code.«continue» ref) -| rs, Code.«return» ref val => mkJmp ref rs val (fun y => pure $ Code.«return» ref y) -| rs, Code.action e => - -- We use `mkAuxDeclFor` because `e` is not pure. - mkAuxDeclFor e fun y => - let ref := e - mkJmp ref rs y (fun yFresh => do pure $ Code.action (← `(HasPure.pure $yFresh))) + | rs, Code.decl xs stx k => do Code.decl xs stx (← pullExitPointsAux (eraseVars rs xs) k) + | rs, Code.reassign xs stx k => do Code.reassign xs stx (← pullExitPointsAux (insertVars rs xs) k) + | rs, Code.joinpoint j ps b k => do Code.joinpoint j ps (← pullExitPointsAux rs b) (← pullExitPointsAux rs k) + | rs, Code.seq e k => do Code.seq e (← pullExitPointsAux rs k) + | rs, Code.ite ref x? o c t e => do Code.ite ref x? o c (← pullExitPointsAux (eraseOptVar rs x?) t) (← pullExitPointsAux (eraseOptVar rs x?) e) + | rs, Code.«match» ref ds t alts => do + Code.«match» ref ds t (← alts.mapM fun alt => do pure { alt with rhs := (← pullExitPointsAux (eraseVars rs alt.vars) alt.rhs) }) + | rs, c@(Code.jmp _ _ _) => pure c + | rs, Code.«break» ref => mkSimpleJmp ref rs (Code.«break» ref) + | rs, Code.«continue» ref => mkSimpleJmp ref rs (Code.«continue» ref) + | rs, Code.«return» ref val => mkJmp ref rs val (fun y => pure $ Code.«return» ref y) + | rs, Code.action e => + -- We use `mkAuxDeclFor` because `e` is not pure. + mkAuxDeclFor e fun y => + let ref := e + mkJmp ref rs y (fun yFresh => do pure $ Code.action (← `(HasPure.pure $yFresh))) /- Auxiliary operation for adding new variables to the collection of updated variables in a CodeBlock. @@ -368,38 +372,38 @@ We implement the method as follows. Let `us` be `c.uvars`, then 3- Same as 2 for `continue`. -/ def pullExitPoints (c : Code) : TermElabM Code := do -if hasExitPoint c then - let (c, jpDecls) ← (pullExitPointsAux {} c).run #[] - pure $ attachJPs jpDecls c -else - pure c + if hasExitPoint c then + let (c, jpDecls) ← (pullExitPointsAux {} c).run #[] + pure $ attachJPs jpDecls c + else + pure c partial def extendUpdatedVarsAux (c : Code) (ws : NameSet) : TermElabM Code := -let rec update : Code → TermElabM Code - | Code.joinpoint j ps b k => do Code.joinpoint j ps (← update b) (← update k) - | Code.seq e k => do Code.seq e (← update k) - | c@(Code.«match» ref ds t alts) => do - if alts.any fun alt => alt.vars.any fun x => ws.contains x then - -- If a pattern variable is shadowing a variable in ws, we `pullExitPoints` - pullExitPoints c - else - Code.«match» ref ds t (← alts.mapM fun alt => do pure { alt with rhs := (← update alt.rhs) }) - | Code.ite ref none o c t e => do Code.ite ref none o c (← update t) (← update e) - | c@(Code.ite ref (some h) o cond t e) => do - if ws.contains h then - -- if the `h` at `if h:c then t else e` shadows a variable in `ws`, we `pullExitPoints` - pullExitPoints c - else - Code.ite ref (some h) o cond (← update t) (← update e) - | Code.reassign xs stx k => do Code.reassign xs stx (← update k) - | c@(Code.decl xs stx k) => do - if xs.any fun x => ws.contains x then - -- One the declared variables is shadowing a variable in `ws` - pullExitPoints c - else - Code.decl xs stx (← update k) - | c => pure c -update c + let rec update : Code → TermElabM Code + | Code.joinpoint j ps b k => do Code.joinpoint j ps (← update b) (← update k) + | Code.seq e k => do Code.seq e (← update k) + | c@(Code.«match» ref ds t alts) => do + if alts.any fun alt => alt.vars.any fun x => ws.contains x then + -- If a pattern variable is shadowing a variable in ws, we `pullExitPoints` + pullExitPoints c + else + Code.«match» ref ds t (← alts.mapM fun alt => do pure { alt with rhs := (← update alt.rhs) }) + | Code.ite ref none o c t e => do Code.ite ref none o c (← update t) (← update e) + | c@(Code.ite ref (some h) o cond t e) => do + if ws.contains h then + -- if the `h` at `if h:c then t else e` shadows a variable in `ws`, we `pullExitPoints` + pullExitPoints c + else + Code.ite ref (some h) o cond (← update t) (← update e) + | Code.reassign xs stx k => do Code.reassign xs stx (← update k) + | c@(Code.decl xs stx k) => do + if xs.any fun x => ws.contains x then + -- One the declared variables is shadowing a variable in `ws` + pullExitPoints c + else + Code.decl xs stx (← update k) + | c => pure c + update c /- Extend the set of updated variables. It assumes `ws` is a super set of `c.uvars`. @@ -407,14 +411,14 @@ We **cannot** simply update the field `c.uvars`, because `c` may have shadowed s See discussion at `pullExitPoints`. -/ partial def extendUpdatedVars (c : CodeBlock) (ws : NameSet) : TermElabM CodeBlock := do -if ws.any fun x => !c.uvars.contains x then - -- `ws` contains a variable that is not in `c.uvars`, but in `c.dvars` (i.e., it has been shadowed) - pure { code := (← extendUpdatedVarsAux c.code ws), uvars := ws } -else - pure { c with uvars := ws } + if ws.any fun x => !c.uvars.contains x then + -- `ws` contains a variable that is not in `c.uvars`, but in `c.dvars` (i.e., it has been shadowed) + pure { code := (← extendUpdatedVarsAux c.code ws), uvars := ws } + else + pure { c with uvars := ws } private def union (s₁ s₂ : NameSet) : NameSet := -s₁.fold (·.insert ·) s₂ + s₁.fold (·.insert ·) s₂ /- Given two code blocks `c₁` and `c₂`, make sure they have the same set of updated variables. @@ -422,10 +426,10 @@ Let `ws` the union of the updated variables in `c₁‵ and ‵c₂`. We use `extendUpdatedVars c₁ ws` and `extendUpdatedVars c₂ ws` -/ def homogenize (c₁ c₂ : CodeBlock) : TermElabM (CodeBlock × CodeBlock) := do -let ws := union c₁.uvars c₂.uvars -let c₁ ← extendUpdatedVars c₁ ws -let c₂ ← extendUpdatedVars c₂ ws -pure (c₁, c₂) + let ws := union c₁.uvars c₂.uvars + let c₁ ← extendUpdatedVars c₁ ws + let c₂ ← extendUpdatedVars c₂ ws + pure (c₁, c₂) /- Extending code blocks with variable declarations: `let x : t := v` and `let x : t ← v`. @@ -434,8 +438,10 @@ Remark: `stx` is the syntax for the declaration (e.g., `letDecl`), and `xs` are declared by it. It is an array because we have let-declarations that declare multiple variables. Example: `let (x, y) := t` -/ -def mkVarDeclCore (xs : Array Name) (stx : Syntax) (c : CodeBlock) : CodeBlock := -{ code := Code.decl xs stx c.code, uvars := eraseVars c.uvars xs } +def mkVarDeclCore (xs : Array Name) (stx : Syntax) (c : CodeBlock) : CodeBlock := { + code := Code.decl xs stx c.code, + uvars := eraseVars c.uvars xs +} /- Extending code blocks with reassignments: `x : t := v` and `x : t ← v`. @@ -444,167 +450,167 @@ declared by it. It is an array because we have let-declarations that declare mul Example: `(x, y) ← t` -/ def mkReassignCore (xs : Array Name) (stx : Syntax) (c : CodeBlock) : TermElabM CodeBlock := do -let us := c.uvars -let ws := insertVars us xs --- If `xs` contains a new updated variable, then we must use `extendUpdatedVars`. --- See discussion at `pullExitPoints` -let code ← if xs.any fun x => !us.contains x then extendUpdatedVarsAux c.code ws else pure c.code -pure { code := Code.reassign xs stx code, uvars := ws } + let us := c.uvars + let ws := insertVars us xs + -- If `xs` contains a new updated variable, then we must use `extendUpdatedVars`. + -- See discussion at `pullExitPoints` + let code ← if xs.any fun x => !us.contains x then extendUpdatedVarsAux c.code ws else pure c.code + pure { code := Code.reassign xs stx code, uvars := ws } def mkSeq (action : Syntax) (c : CodeBlock) : CodeBlock := -{ c with code := Code.seq action c.code } + { c with code := Code.seq action c.code } def mkTerminalAction (action : Syntax) : CodeBlock := -{ code := Code.action action } + { code := Code.action action } def mkReturn (ref : Syntax) (val : Syntax) : CodeBlock := -{ code := Code.«return» ref val } + { code := Code.«return» ref val } def mkBreak (ref : Syntax) : CodeBlock := -{ code := Code.«break» ref } + { code := Code.«break» ref } def mkContinue (ref : Syntax) : CodeBlock := -{ code := Code.«continue» ref } + { code := Code.«continue» ref } def mkIte (ref : Syntax) (optIdent : Syntax) (cond : Syntax) (thenBranch : CodeBlock) (elseBranch : CodeBlock) : TermElabM CodeBlock := do -let x? := if optIdent.isNone then none else some optIdent[0].getId -let (thenBranch, elseBranch) ← homogenize thenBranch elseBranch -pure { - code := Code.ite ref x? optIdent cond thenBranch.code elseBranch.code, - uvars := thenBranch.uvars, -} + let x? := if optIdent.isNone then none else some optIdent[0].getId + let (thenBranch, elseBranch) ← homogenize thenBranch elseBranch + pure { + code := Code.ite ref x? optIdent cond thenBranch.code elseBranch.code, + uvars := thenBranch.uvars, + } private def mkUnit (ref : Syntax) : MacroM Syntax := do -let unit ← `(PUnit.unit) -pure $ unit.copyInfo ref + let unit ← `(PUnit.unit) + pure $ unit.copyInfo ref private def mkPureUnit (ref : Syntax) : MacroM Syntax := do -let unit ← mkUnit ref -let pureUnit ← `(HasPure.pure $(unit.copyInfo ref)) -pure $ pureUnit.copyInfo ref + let unit ← mkUnit ref + let pureUnit ← `(HasPure.pure $(unit.copyInfo ref)) + pure $ pureUnit.copyInfo ref def mkPureUnitAction (ref : Syntax) : MacroM CodeBlock := do -mkTerminalAction (← mkPureUnit ref) + mkTerminalAction (← mkPureUnit ref) def mkUnless (ref : Syntax) (cond : Syntax) (c : CodeBlock) : MacroM CodeBlock := do -let thenBranch ← mkPureUnitAction ref -pure { c with code := Code.ite ref none mkNullNode cond thenBranch.code c.code } + let thenBranch ← mkPureUnitAction ref + pure { c with code := Code.ite ref none mkNullNode cond thenBranch.code c.code } def mkMatch (ref : Syntax) (discrs : Syntax) (optType : Syntax) (alts : Array (Alt CodeBlock)) : TermElabM CodeBlock := do --- nary version of homogenize -let ws := alts.foldl (union · ·.rhs.uvars) {} -let alts ← alts.mapM fun alt => do - let rhs ← extendUpdatedVars alt.rhs ws - pure { ref := alt.ref, vars := alt.vars, patterns := alt.patterns, rhs := rhs.code : Alt Code } -pure { code := Code.«match» ref discrs optType alts, uvars := ws } + -- nary version of homogenize + let ws := alts.foldl (union · ·.rhs.uvars) {} + let alts ← alts.mapM fun alt => do + let rhs ← extendUpdatedVars alt.rhs ws + pure { ref := alt.ref, vars := alt.vars, patterns := alt.patterns, rhs := rhs.code : Alt Code } + pure { code := Code.«match» ref discrs optType alts, uvars := ws } /- Return a code block that executes `terminal` and then `k` with the value produced by `terminal`. This method assumes `terminal` is a terminal -/ def concat (terminal : CodeBlock) (kRef : Syntax) (y? : Option Name) (k : CodeBlock) : TermElabM CodeBlock := do -unless hasTerminalAction terminal.code do - throwErrorAt kRef "'do' element is unreachable" -let (terminal, k) ← homogenize terminal k -let xs := nameSetToArray k.uvars -let y ← match y? with | some y => pure y | none => mkFreshUserName `y -let ps := xs.map fun x => (x, true) -let ps := ps.push (y, false) -let jpDecl ← mkFreshJP ps k.code -let jp := jpDecl.name -let terminal ← liftMacroM $ convertTerminalActionIntoJmp terminal.code jp xs -pure { code := attachJP jpDecl terminal, uvars := k.uvars } + unless hasTerminalAction terminal.code do + throwErrorAt kRef "'do' element is unreachable" + let (terminal, k) ← homogenize terminal k + let xs := nameSetToArray k.uvars + let y ← match y? with | some y => pure y | none => mkFreshUserName `y + let ps := xs.map fun x => (x, true) + let ps := ps.push (y, false) + let jpDecl ← mkFreshJP ps k.code + let jp := jpDecl.name + let terminal ← liftMacroM $ convertTerminalActionIntoJmp terminal.code jp xs + pure { code := attachJP jpDecl terminal, uvars := k.uvars } def getLetIdDeclVar (letIdDecl : Syntax) : Name := -letIdDecl[0].getId + letIdDecl[0].getId def getLetPatDeclVars (letPatDecl : Syntax) : TermElabM (Array Name) := do -let pattern := letPatDecl[0] -let patternVars ← getPatternVars pattern -pure $ patternVars.filterMap fun - | PatternVar.localVar x => some x - | _ => none + let pattern := letPatDecl[0] + let patternVars ← getPatternVars pattern + pure $ patternVars.filterMap fun + | PatternVar.localVar x => some x + | _ => none def getLetEqnsDeclVar (letEqnsDecl : Syntax) : Name := -letEqnsDecl[0].getId + letEqnsDecl[0].getId def getLetDeclVars (letDecl : Syntax) : TermElabM (Array Name) := do -let arg := letDecl[0] -if arg.getKind == `Lean.Parser.Term.letIdDecl then - pure #[getLetIdDeclVar arg] -else if arg.getKind == `Lean.Parser.Term.letPatDecl then - getLetPatDeclVars arg -else if arg.getKind == `Lean.Parser.Term.letEqnsDecl then - pure #[getLetEqnsDeclVar arg] -else - throwError "unexpected kind of let declaration" + let arg := letDecl[0] + if arg.getKind == `Lean.Parser.Term.letIdDecl then + pure #[getLetIdDeclVar arg] + else if arg.getKind == `Lean.Parser.Term.letPatDecl then + getLetPatDeclVars arg + else if arg.getKind == `Lean.Parser.Term.letEqnsDecl then + pure #[getLetEqnsDeclVar arg] + else + throwError "unexpected kind of let declaration" def getDoLetVars (doLet : Syntax) : TermElabM (Array Name) := --- parser! "let " >> letDecl -getLetDeclVars doLet[1] + -- parser! "let " >> letDecl + getLetDeclVars doLet[1] def getDoHaveVar (doHave : Syntax) : Name := -/- - `parser! "have " >> Term.haveDecl` - where - ``` - haveDecl := optIdent >> termParser >> (haveAssign <|> fromTerm <|> byTactic) - optIdent := optional (try (ident >> " : ")) + /- + `parser! "have " >> Term.haveDecl` + where + ``` + haveDecl := optIdent >> termParser >> (haveAssign <|> fromTerm <|> byTactic) + optIdent := optional (try (ident >> " : ")) - ``` -/ -let optIdent := doHave[1] -if optIdent.isNone then - `this -else - optIdent[0].getId + ``` -/ + let optIdent := doHave[1] + if optIdent.isNone then + `this + else + optIdent[0].getId def getDoLetRecVars (doLetRec : Syntax) : TermElabM (Array Name) := do --- letRecDecls is an array of `(group (optional attributes >> letDecl))` -let letRecDecls := doLetRec[1].getSepArgs -let letDecls := letRecDecls.map fun p => p[1] -let allVars := #[] -for letDecl in letDecls do - let vars ← getLetDeclVars letDecl - allVars := allVars ++ vars -pure allVars + -- letRecDecls is an array of `(group (optional attributes >> letDecl))` + let letRecDecls := doLetRec[1].getSepArgs + let letDecls := letRecDecls.map fun p => p[1] + let allVars := #[] + for letDecl in letDecls do + let vars ← getLetDeclVars letDecl + allVars := allVars ++ vars + pure allVars -- ident >> optType >> leftArrow >> termParser def getDoIdDeclVar (doIdDecl : Syntax) : Name := -doIdDecl[0].getId + doIdDecl[0].getId def getPatternVarNames (pvars : Array PatternVar) : Array Name := -pvars.filterMap fun - | PatternVar.localVar x => some x - | _ => none + pvars.filterMap fun + | PatternVar.localVar x => some x + | _ => none -- termParser >> leftArrow >> termParser >> optional (" | " >> termParser) def getDoPatDeclVars (doPatDecl : Syntax) : TermElabM (Array Name) := do -let pattern := doPatDecl[0] -let patternVars ← getPatternVars pattern -pure $ getPatternVarNames patternVars + let pattern := doPatDecl[0] + let patternVars ← getPatternVars pattern + pure $ getPatternVarNames patternVars -- parser! "let " >> (doIdDecl <|> doPatDecl) def getDoLetArrowVars (doLetArrow : Syntax) : TermElabM (Array Name) := do -let decl := doLetArrow[1] -if decl.getKind == `Lean.Parser.Term.doIdDecl then - pure #[getDoIdDeclVar decl] -else if decl.getKind == `Lean.Parser.Term.doPatDecl then - getDoPatDeclVars decl -else - throwError "unexpected kind of 'do' declaration" + let decl := doLetArrow[1] + if decl.getKind == `Lean.Parser.Term.doIdDecl then + pure #[getDoIdDeclVar decl] + else if decl.getKind == `Lean.Parser.Term.doPatDecl then + getDoPatDeclVars decl + else + throwError "unexpected kind of 'do' declaration" def getDoReassignVars (doReassign : Syntax) : TermElabM (Array Name) := do -let arg := doReassign[0] -if arg.getKind == `Lean.Parser.Term.letIdDecl then - pure #[getLetIdDeclVar arg] -else if arg.getKind == `Lean.Parser.Term.letPatDecl then - getLetPatDeclVars arg -else - throwError "unexpected kind of reassignment" + let arg := doReassign[0] + if arg.getKind == `Lean.Parser.Term.letIdDecl then + pure #[getLetIdDeclVar arg] + else if arg.getKind == `Lean.Parser.Term.letPatDecl then + getLetPatDeclVars arg + else + throwError "unexpected kind of reassignment" def mkDoSeq (doElems : Array Syntax) : Syntax := -mkNode `Lean.Parser.Term.doSeqIndent #[mkNullNode $ doElems.map fun doElem => mkNullNode #[doElem, mkNullNode]] + mkNode `Lean.Parser.Term.doSeqIndent #[mkNullNode $ doElems.map fun doElem => mkNullNode #[doElem, mkNullNode]] def mkSingletonDoSeq (doElem : Syntax) : Syntax := -mkDoSeq #[doElem] + mkDoSeq #[doElem] /- Recall that the `doIf` syntax is of the form @@ -615,50 +621,51 @@ mkDoSeq #[doElem] ``` If the given syntax is a `doIf`, return an equivalente `doIf` that has no `else if`s and the `else` is not none. -/ private def expandDoIf? (stx : Syntax) : MacroM (Option Syntax) := do -if stx.getKind != `Lean.Parser.Term.doIf then pure none else -let doIf := stx -let ref := stx -let doElseIfs := doIf[5].getArgs -let doElse := doIf[6] -if doElseIfs.isEmpty && !doElse.isNone then - pure none -else - let doElse ← - if doElse.isNone then - let pureUnit ← mkPureUnit ref - pure $ mkNullNode #[ - mkAtomFrom ref "else", - mkSingletonDoSeq (mkNode `Lean.Parser.Term.doExpr #[pureUnit]) - ] - else - pure doElse - let doElse := doElseIfs.foldr - (fun doElseIf doElse => - let ifAtom := doElseIf[0][1] - let doIfArgs := (doElseIf.getArgs).set! 0 ifAtom - let doIfArgs := doIfArgs.push mkNullNode - let doIfArgs := doIfArgs.push doElse - mkNullNode #[mkAtomFrom doElseIf "else", - mkSingletonDoSeq $ mkNode `Lean.Parser.Term.doIf doIfArgs]) - doElse - let doIf := doIf.setArg 6 doElse - pure $ some $ doIf.setArg 5 mkNullNode -- remove else-ifs + if stx.getKind != `Lean.Parser.Term.doIf then pure none else + let doIf := stx + let ref := stx + let doElseIfs := doIf[5].getArgs + let doElse := doIf[6] + if doElseIfs.isEmpty && !doElse.isNone then + pure none + else + let doElse ← + if doElse.isNone then + let pureUnit ← mkPureUnit ref + pure $ mkNullNode #[ + mkAtomFrom ref "else", + mkSingletonDoSeq (mkNode `Lean.Parser.Term.doExpr #[pureUnit]) + ] + else + pure doElse + let doElse := doElseIfs.foldr + (fun doElseIf doElse => + let ifAtom := doElseIf[0][1] + let doIfArgs := (doElseIf.getArgs).set! 0 ifAtom + let doIfArgs := doIfArgs.push mkNullNode + let doIfArgs := doIfArgs.push doElse + mkNullNode #[mkAtomFrom doElseIf "else", + mkSingletonDoSeq $ mkNode `Lean.Parser.Term.doIf doIfArgs]) + doElse + let doIf := doIf.setArg 6 doElse + pure $ some $ doIf.setArg 5 mkNullNode -- remove else-ifs structure DoIfView := -(ref : Syntax) -(optIdent : Syntax) -(cond : Syntax) -(thenBranch : Syntax) -(elseBranch : Syntax) + (ref : Syntax) + (optIdent : Syntax) + (cond : Syntax) + (thenBranch : Syntax) + (elseBranch : Syntax) /- This method assumes `expandDoIf?` is not applicable. -/ private def mkDoIfView (doIf : Syntax) : MacroM DoIfView := do -pure { - ref := doIf, - optIdent := doIf[1], - cond := doIf[2], - thenBranch := doIf[4], - elseBranch := doIf[6][1] } + pure { + ref := doIf, + optIdent := doIf[1], + cond := doIf[2], + thenBranch := doIf[4], + elseBranch := doIf[6][1] + } /- We use `MProd` instead of `Prod` to group values when expanding the @@ -669,25 +676,24 @@ Note that we are not restricting the macro power since the `HasBind.bind` combinator already forces values computed by monadic actions to be in the same universe. -/ - private def mkTuple (ref : Syntax) (elems : Array Syntax) : MacroM Syntax := do -if elems.size == 0 then - mkUnit ref -else if elems.size == 1 then - pure elems[0] -else - (elems.extract 0 (elems.size - 1)).foldrM - (fun elem tuple => do - let tuple ← `(MProd.mk $elem $tuple) - pure $ tuple.copyInfo ref) - (elems.back) + if elems.size == 0 then + mkUnit ref + else if elems.size == 1 then + pure elems[0] + else + (elems.extract 0 (elems.size - 1)).foldrM + (fun elem tuple => do + let tuple ← `(MProd.mk $elem $tuple) + pure $ tuple.copyInfo ref) + (elems.back) /- Return `some action` if `doElem` is a `doExpr `-/ def isDoExpr? (doElem : Syntax) : Option Syntax := -if doElem.getKind == `Lean.Parser.Term.doExpr then - some doElem[0] -else - none + if doElem.getKind == `Lean.Parser.Term.doExpr then + some doElem[0] + else + none /- The procedure `ToTerm.run` converts a `CodeBlock` into a `Syntax` term. @@ -773,241 +779,241 @@ Example: suppose we want to support `repeat doSeq`. Assuming we have `repeat : m namespace ToTerm inductive Kind -| regular -| forIn -| forInWithReturn -| nestedBC -| nestedPR -| nestedSBC -| nestedPRBC + | regular + | forIn + | forInWithReturn + | nestedBC + | nestedPR + | nestedSBC + | nestedPRBC -instance Kind.inhabited : Inhabited Kind := ⟨Kind.regular⟩ +instance : Inhabited Kind := ⟨Kind.regular⟩ def Kind.isRegular : Kind → Bool -| Kind.regular => true -| _ => false + | Kind.regular => true + | _ => false structure Context := -(m : Syntax) -- Syntax to reference the monad associated with the do notation. -(uvars : Array Name) -(kind : Kind) + (m : Syntax) -- Syntax to reference the monad associated with the do notation. + (uvars : Array Name) + (kind : Kind) abbrev M := ReaderT Context MacroM def mkUVarTuple (ref : Syntax) : M Syntax := do -let ctx ← read -let uvarIdents := ctx.uvars.map fun x => mkIdentFrom ref x -mkTuple ref uvarIdents + let ctx ← read + let uvarIdents := ctx.uvars.map fun x => mkIdentFrom ref x + mkTuple ref uvarIdents def returnToTermCore (ref : Syntax) (val : Syntax) : M Syntax := do -let ctx ← read -let u ← mkUVarTuple ref -match ctx.kind with -| Kind.regular => if ctx.uvars.isEmpty then `(HasPure.pure $val) else `(HasPure.pure (MProd.mk $val $u)) -| Kind.forIn => `(HasPure.pure (ForInStep.done $u)) -| Kind.forInWithReturn => `(HasPure.pure (ForInStep.done (MProd.mk (some $val) $u))) -| Kind.nestedBC => unreachable! -| Kind.nestedPR => `(HasPure.pure (DoResultPR.«return» $val $u)) -| Kind.nestedSBC => `(HasPure.pure (DoResultSBC.«pureReturn» $val $u)) -| Kind.nestedPRBC => `(HasPure.pure (DoResultPRBC.«return» $val $u)) + let ctx ← read + let u ← mkUVarTuple ref + match ctx.kind with + | Kind.regular => if ctx.uvars.isEmpty then `(HasPure.pure $val) else `(HasPure.pure (MProd.mk $val $u)) + | Kind.forIn => `(HasPure.pure (ForInStep.done $u)) + | Kind.forInWithReturn => `(HasPure.pure (ForInStep.done (MProd.mk (some $val) $u))) + | Kind.nestedBC => unreachable! + | Kind.nestedPR => `(HasPure.pure (DoResultPR.«return» $val $u)) + | Kind.nestedSBC => `(HasPure.pure (DoResultSBC.«pureReturn» $val $u)) + | Kind.nestedPRBC => `(HasPure.pure (DoResultPRBC.«return» $val $u)) def returnToTerm (ref : Syntax) (val : Syntax) : M Syntax := do -let r ← returnToTermCore ref val -pure $ r.copyInfo ref + let r ← returnToTermCore ref val + pure $ r.copyInfo ref def continueToTermCore (ref : Syntax) : M Syntax := do -let ctx ← read -let u ← mkUVarTuple ref -match ctx.kind with -| Kind.regular => unreachable! -| Kind.forIn => `(HasPure.pure (ForInStep.yield $u)) -| Kind.forInWithReturn => `(HasPure.pure (ForInStep.yield (MProd.mk none $u))) -| Kind.nestedBC => `(HasPure.pure (DoResultBC.«continue» $u)) -| Kind.nestedPR => unreachable! -| Kind.nestedSBC => `(HasPure.pure (DoResultSBC.«continue» $u)) -| Kind.nestedPRBC => `(HasPure.pure (DoResultPRBC.«continue» $u)) + let ctx ← read + let u ← mkUVarTuple ref + match ctx.kind with + | Kind.regular => unreachable! + | Kind.forIn => `(HasPure.pure (ForInStep.yield $u)) + | Kind.forInWithReturn => `(HasPure.pure (ForInStep.yield (MProd.mk none $u))) + | Kind.nestedBC => `(HasPure.pure (DoResultBC.«continue» $u)) + | Kind.nestedPR => unreachable! + | Kind.nestedSBC => `(HasPure.pure (DoResultSBC.«continue» $u)) + | Kind.nestedPRBC => `(HasPure.pure (DoResultPRBC.«continue» $u)) def continueToTerm (ref : Syntax) : M Syntax := do -let r ← continueToTermCore ref -pure $ r.copyInfo ref + let r ← continueToTermCore ref + pure $ r.copyInfo ref def breakToTermCore (ref : Syntax) : M Syntax := do -let ctx ← read -let u ← mkUVarTuple ref -match ctx.kind with -| Kind.regular => unreachable! -| Kind.forIn => `(HasPure.pure (ForInStep.done $u)) -| Kind.forInWithReturn => `(HasPure.pure (ForInStep.done (MProd.mk none $u))) -| Kind.nestedBC => `(HasPure.pure (DoResultBC.«break» $u)) -| Kind.nestedPR => unreachable! -| Kind.nestedSBC => `(HasPure.pure (DoResultSBC.«break» $u)) -| Kind.nestedPRBC => `(HasPure.pure (DoResultPRBC.«break» $u)) + let ctx ← read + let u ← mkUVarTuple ref + match ctx.kind with + | Kind.regular => unreachable! + | Kind.forIn => `(HasPure.pure (ForInStep.done $u)) + | Kind.forInWithReturn => `(HasPure.pure (ForInStep.done (MProd.mk none $u))) + | Kind.nestedBC => `(HasPure.pure (DoResultBC.«break» $u)) + | Kind.nestedPR => unreachable! + | Kind.nestedSBC => `(HasPure.pure (DoResultSBC.«break» $u)) + | Kind.nestedPRBC => `(HasPure.pure (DoResultPRBC.«break» $u)) def breakToTerm (ref : Syntax) : M Syntax := do -let r ← breakToTermCore ref -pure $ r.copyInfo ref + let r ← breakToTermCore ref + pure $ r.copyInfo ref def actionTerminalToTermCore (action : Syntax) : M Syntax := withFreshMacroScope do -let ref := action -let ctx ← read -let u ← mkUVarTuple ref -match ctx.kind with -| Kind.regular => if ctx.uvars.isEmpty then pure action else `(HasBind.bind $action fun y => HasPure.pure (MProd.mk y $u)) -| Kind.forIn => `(HasBind.bind $action fun (_ : PUnit) => HasPure.pure (ForInStep.yield $u)) -| Kind.forInWithReturn => `(HasBind.bind $action fun (_ : PUnit) => HasPure.pure (ForInStep.yield (MProd.mk none $u))) -| Kind.nestedBC => unreachable! -| Kind.nestedPR => `(HasBind.bind $action fun y => (HasPure.pure (DoResultPR.«pure» y $u))) -| Kind.nestedSBC => `(HasBind.bind $action fun y => (HasPure.pure (DoResultSBC.«pureReturn» y $u))) -| Kind.nestedPRBC => `(HasBind.bind $action fun y => (HasPure.pure (DoResultPRBC.«pure» y $u))) + let ref := action + let ctx ← read + let u ← mkUVarTuple ref + match ctx.kind with + | Kind.regular => if ctx.uvars.isEmpty then pure action else `(HasBind.bind $action fun y => HasPure.pure (MProd.mk y $u)) + | Kind.forIn => `(HasBind.bind $action fun (_ : PUnit) => HasPure.pure (ForInStep.yield $u)) + | Kind.forInWithReturn => `(HasBind.bind $action fun (_ : PUnit) => HasPure.pure (ForInStep.yield (MProd.mk none $u))) + | Kind.nestedBC => unreachable! + | Kind.nestedPR => `(HasBind.bind $action fun y => (HasPure.pure (DoResultPR.«pure» y $u))) + | Kind.nestedSBC => `(HasBind.bind $action fun y => (HasPure.pure (DoResultSBC.«pureReturn» y $u))) + | Kind.nestedPRBC => `(HasBind.bind $action fun y => (HasPure.pure (DoResultPRBC.«pure» y $u))) def actionTerminalToTerm (action : Syntax) : M Syntax := do -let ref := action -let r ← actionTerminalToTermCore action -pure $ r.copyInfo ref + let ref := action + let r ← actionTerminalToTermCore action + pure $ r.copyInfo ref def seqToTermCore (action : Syntax) (k : Syntax) : MacroM Syntax := withFreshMacroScope do -if action.getKind == `Lean.Parser.Term.doDbgTrace then - let msg := action[1] - `(dbgTrace! $msg; $k) -else if action.getKind == `Lean.Parser.Term.doAssert then - let cond := action[1] - `(assert! $cond; $k) -else - `(HasBind.bind $action (fun _ => $k)) + if action.getKind == `Lean.Parser.Term.doDbgTrace then + let msg := action[1] + `(dbgTrace! $msg; $k) + else if action.getKind == `Lean.Parser.Term.doAssert then + let cond := action[1] + `(assert! $cond; $k) + else + `(HasBind.bind $action (fun _ => $k)) def seqToTerm (action : Syntax) (k : Syntax) : MacroM Syntax := do -let r ← seqToTermCore action k -pure $ r.copyInfo action + let r ← seqToTermCore action k + pure $ r.copyInfo action def declToTermCore (decl : Syntax) (k : Syntax) : M Syntax := withFreshMacroScope do -let kind := decl.getKind -if kind == `Lean.Parser.Term.doLet then - let letDecl := decl[1] - `(let $letDecl:letDecl; $k) -else if kind == `Lean.Parser.Term.doLetRec then - let letRecToken := decl[0] - let letRecDecls := decl[1] - pure $ mkNode `Lean.Parser.Term.letrec #[letRecToken, letRecDecls, mkNullNode, k] -else if kind == `Lean.Parser.Term.doLetArrow then - let arg := decl[1] - let ref := arg - if arg.getKind == `Lean.Parser.Term.doIdDecl then - let id := arg[0] - let type := expandOptType ref arg[1] - let doElem := arg[3] - -- `doElem` must be a `doExpr action`. See `doLetArrowToCode` - match isDoExpr? doElem with - | some action => `(HasBind.bind $action (fun ($id:ident : $type) => $k)) - | none => liftM $ Macro.throwError decl "unexpected kind of 'do' declaration" + let kind := decl.getKind + if kind == `Lean.Parser.Term.doLet then + let letDecl := decl[1] + `(let $letDecl:letDecl; $k) + else if kind == `Lean.Parser.Term.doLetRec then + let letRecToken := decl[0] + let letRecDecls := decl[1] + pure $ mkNode `Lean.Parser.Term.letrec #[letRecToken, letRecDecls, mkNullNode, k] + else if kind == `Lean.Parser.Term.doLetArrow then + let arg := decl[1] + let ref := arg + if arg.getKind == `Lean.Parser.Term.doIdDecl then + let id := arg[0] + let type := expandOptType ref arg[1] + let doElem := arg[3] + -- `doElem` must be a `doExpr action`. See `doLetArrowToCode` + match isDoExpr? doElem with + | some action => `(HasBind.bind $action (fun ($id:ident : $type) => $k)) + | none => Macro.throwError decl "unexpected kind of 'do' declaration" + else + Macro.throwError decl "unexpected kind of 'do' declaration" + else if kind == `Lean.Parser.Term.doHave then + -- The `have` term is of the form `"have " >> haveDecl >> optSemicolon termParser` + let args := decl.getArgs + let args := args ++ #[mkNullNode /- optional ';' -/, k] + pure $ mkNode `Lean.Parser.Term.«have» args else - liftM $ Macro.throwError decl "unexpected kind of 'do' declaration" -else if kind == `Lean.Parser.Term.doHave then - -- The `have` term is of the form `"have " >> haveDecl >> optSemicolon termParser` - let args := decl.getArgs - let args := args ++ #[mkNullNode /- optional ';' -/, k] - pure $ mkNode `Lean.Parser.Term.«have» args -else - liftM $ Macro.throwError decl "unexpected kind of 'do' declaration" + Macro.throwError decl "unexpected kind of 'do' declaration" def declToTerm (decl : Syntax) (k : Syntax) : M Syntax := do -let r ← declToTermCore decl k -pure $ r.copyInfo decl + let r ← declToTermCore decl k + pure $ r.copyInfo decl def reassignToTermCore (reassign : Syntax) (k : Syntax) : MacroM Syntax := withFreshMacroScope do -let kind := reassign.getKind -if kind == `Lean.Parser.Term.doReassign then - -- doReassign := parser! (letIdDecl <|> letPatDecl) - let arg := reassign[0] - if arg.getKind == `Lean.Parser.Term.letIdDecl then - -- letIdDecl := parser! ident >> many (ppSpace >> bracketedBinder) >> optType >> " := " >> termParser - let x := arg[0] - let val := arg[4] - let newVal ← `(ensureTypeOf! $x $(quote "invalid reassignment, value") $val) - let arg := arg.setArg 4 newVal - let letDecl := mkNode `Lean.Parser.Term.letDecl #[arg] - `(let $letDecl:letDecl; $k) + let kind := reassign.getKind + if kind == `Lean.Parser.Term.doReassign then + -- doReassign := parser! (letIdDecl <|> letPatDecl) + let arg := reassign[0] + if arg.getKind == `Lean.Parser.Term.letIdDecl then + -- letIdDecl := parser! ident >> many (ppSpace >> bracketedBinder) >> optType >> " := " >> termParser + let x := arg[0] + let val := arg[4] + let newVal ← `(ensureTypeOf! $x $(quote "invalid reassignment, value") $val) + let arg := arg.setArg 4 newVal + let letDecl := mkNode `Lean.Parser.Term.letDecl #[arg] + `(let $letDecl:letDecl; $k) + else + -- TODO: ensure the types did not change + let letDecl := mkNode `Lean.Parser.Term.letDecl #[arg] + `(let $letDecl:letDecl; $k) else - -- TODO: ensure the types did not change - let letDecl := mkNode `Lean.Parser.Term.letDecl #[arg] - `(let $letDecl:letDecl; $k) -else - -- Note that `doReassignArrow` is expanded by `doReassignArrowToCode - Macro.throwError reassign "unexpected kind of 'do' reassignment" + -- Note that `doReassignArrow` is expanded by `doReassignArrowToCode + Macro.throwError reassign "unexpected kind of 'do' reassignment" def reassignToTerm (reassign : Syntax) (k : Syntax) : MacroM Syntax := do -let r ← reassignToTermCore reassign k -pure $ r.copyInfo reassign + let r ← reassignToTermCore reassign k + pure $ r.copyInfo reassign def mkIte (ref : Syntax) (optIdent : Syntax) (cond : Syntax) (thenBranch : Syntax) (elseBranch : Syntax) : Syntax := -mkNode `Lean.Parser.Term.«if» #[mkAtomFrom ref "if", optIdent, cond, mkAtomFrom ref "then", thenBranch, mkAtomFrom ref "else", elseBranch] + mkNode `Lean.Parser.Term.«if» #[mkAtomFrom ref "if", optIdent, cond, mkAtomFrom ref "then", thenBranch, mkAtomFrom ref "else", elseBranch] def mkJoinPointCore (j : Name) (ps : Array (Name × Bool)) (body : Syntax) (k : Syntax) : M Syntax := withFreshMacroScope do -let ref := body -let binders ← ps.mapM fun ⟨id, useTypeOf⟩ => do - let type ← if useTypeOf then `(typeOf! $(mkIdentFrom ref id)) else `(_) - let binderType := mkNullNode #[mkAtomFrom ref ":", type] - pure $ mkNode `Lean.Parser.Term.explicitBinder #[mkAtomFrom ref "(", mkNullNode #[mkIdentFrom ref id], binderType, mkNullNode, mkAtomFrom ref ")"] -let m := (← read).m -let type ← `($m _) -/- -We use `let*` instead of `let` for joinpoints to make sure `$k` is elaborated before `$body`. -By elaborating `$k` first, we "learn" more about `$body`'s type. -For example, consider the following example `do` expression -``` -def f (x : Nat) : IO Unit := do -if x > 0 then - IO.println "x is not zero" -- Error is here -IO.mkRef true -``` -it is expanded into -``` -def f (x : Nat) : IO Unit := do -let jp (u : Unit) : IO _ := - IO.mkRef true; -if x > 0 then - IO.println "not zero" - jp () -else - jp () -``` -If we use the regular `let` instead of `let*`, the joinpoint `jp` will be elaborated and its type will be inferred to be `Unit → IO (IO.Ref Bool)`. -Then, we get a typing error at `jp ()`. By using `let*`, we first elaborate `if x > 0 ...` and learn that `jp` has type `Unit → IO Unit`. -Then, we get the expected type mismatch error at `IO.mkRef true`. -/ -`(let* $(mkIdentFrom ref j):ident $binders:explicitBinder* : $type := $body; $k) + let ref := body + let binders ← ps.mapM fun ⟨id, useTypeOf⟩ => do + let type ← if useTypeOf then `(typeOf! $(mkIdentFrom ref id)) else `(_) + let binderType := mkNullNode #[mkAtomFrom ref ":", type] + pure $ mkNode `Lean.Parser.Term.explicitBinder #[mkAtomFrom ref "(", mkNullNode #[mkIdentFrom ref id], binderType, mkNullNode, mkAtomFrom ref ")"] + let m := (← read).m + let type ← `($m _) + /- + We use `let*` instead of `let` for joinpoints to make sure `$k` is elaborated before `$body`. + By elaborating `$k` first, we "learn" more about `$body`'s type. + For example, consider the following example `do` expression + ``` + def f (x : Nat) : IO Unit := do + if x > 0 then + IO.println "x is not zero" -- Error is here + IO.mkRef true + ``` + it is expanded into + ``` + def f (x : Nat) : IO Unit := do + let jp (u : Unit) : IO _ := + IO.mkRef true; + if x > 0 then + IO.println "not zero" + jp () + else + jp () + ``` + If we use the regular `let` instead of `let*`, the joinpoint `jp` will be elaborated and its type will be inferred to be `Unit → IO (IO.Ref Bool)`. + Then, we get a typing error at `jp ()`. By using `let*`, we first elaborate `if x > 0 ...` and learn that `jp` has type `Unit → IO Unit`. + Then, we get the expected type mismatch error at `IO.mkRef true`. -/ + `(let* $(mkIdentFrom ref j):ident $binders:explicitBinder* : $type := $body; $k) def mkJoinPoint (j : Name) (ps : Array (Name × Bool)) (body : Syntax) (k : Syntax) : M Syntax := do -let r ← mkJoinPointCore j ps body k -pure $ r.copyInfo body + let r ← mkJoinPointCore j ps body k + pure $ r.copyInfo body def mkJmp (ref : Syntax) (j : Name) (args : Array Syntax) : Syntax := -mkAppStx (mkIdentFrom ref j) args + mkAppStx (mkIdentFrom ref j) args partial def toTerm : Code → M Syntax -| Code.«return» ref val => returnToTerm ref val -| Code.«continue» ref => continueToTerm ref -| Code.«break» ref => breakToTerm ref -| Code.action e => actionTerminalToTerm e -| Code.joinpoint j ps b k => do mkJoinPoint j ps (← toTerm b) (← toTerm k) -| Code.jmp ref j args => pure $ mkJmp ref j args -| Code.decl _ stx k => do declToTerm stx (← toTerm k) -| Code.reassign _ stx k => do reassignToTerm stx (← toTerm k) -| Code.seq stx k => do seqToTerm stx (← toTerm k) -| Code.ite ref _ o c t e => do pure $ mkIte ref o c (← toTerm t) (← toTerm e) -| Code.«match» ref discrs optType alts => do - let termSepAlts := #[] - for alt in alts do - termSepAlts := termSepAlts.push $ mkAtomFrom alt.ref "|" - let rhs ← toTerm alt.rhs - let termAlt := mkNode `Lean.Parser.Term.matchAlt #[alt.patterns, mkAtomFrom alt.ref "=>", rhs] - termSepAlts := termSepAlts.push termAlt - let firstVBar := termSepAlts[0] - let termSepAlts := mkNullNode termSepAlts[1:termSepAlts.size] - let termMatchAlts := mkNode `Lean.Parser.Term.matchAlts #[mkNullNode #[firstVBar], termSepAlts] - pure $ mkNode `Lean.Parser.Term.«match» #[mkAtomFrom ref "match", discrs, optType, mkAtomFrom ref "with", termMatchAlts] + | Code.«return» ref val => returnToTerm ref val + | Code.«continue» ref => continueToTerm ref + | Code.«break» ref => breakToTerm ref + | Code.action e => actionTerminalToTerm e + | Code.joinpoint j ps b k => do mkJoinPoint j ps (← toTerm b) (← toTerm k) + | Code.jmp ref j args => pure $ mkJmp ref j args + | Code.decl _ stx k => do declToTerm stx (← toTerm k) + | Code.reassign _ stx k => do reassignToTerm stx (← toTerm k) + | Code.seq stx k => do seqToTerm stx (← toTerm k) + | Code.ite ref _ o c t e => do pure $ mkIte ref o c (← toTerm t) (← toTerm e) + | Code.«match» ref discrs optType alts => do + let termSepAlts := #[] + for alt in alts do + termSepAlts := termSepAlts.push $ mkAtomFrom alt.ref "|" + let rhs ← toTerm alt.rhs + let termAlt := mkNode `Lean.Parser.Term.matchAlt #[alt.patterns, mkAtomFrom alt.ref "=>", rhs] + termSepAlts := termSepAlts.push termAlt + let firstVBar := termSepAlts[0] + let termSepAlts := mkNullNode termSepAlts[1:termSepAlts.size] + let termMatchAlts := mkNode `Lean.Parser.Term.matchAlts #[mkNullNode #[firstVBar], termSepAlts] + pure $ mkNode `Lean.Parser.Term.«match» #[mkAtomFrom ref "match", discrs, optType, mkAtomFrom ref "with", termMatchAlts] def run (code : Code) (m : Syntax) (uvars : Array Name := #[]) (kind := Kind.regular) : MacroM Syntax := do -let term ← toTerm code { m := m, kind := kind, uvars := uvars } -pure term + let term ← toTerm code { m := m, kind := kind, uvars := uvars } + pure term /- Given - `a` is true if the code block has a `Code.action _` exit point @@ -1016,18 +1022,18 @@ pure term generate Kind. See comment at the beginning of the `ToTerm` namespace. -/ def mkNestedKind (a r bc : Bool) : Kind := -match a, r, bc with -| true, false, false => Kind.regular -| false, true, false => Kind.regular -| false, false, true => Kind.nestedBC -| true, true, false => Kind.nestedPR -| true, false, true => Kind.nestedSBC -| false, true, true => Kind.nestedSBC -| true, true, true => Kind.nestedPRBC -| false, false, false => unreachable! + match a, r, bc with + | true, false, false => Kind.regular + | false, true, false => Kind.regular + | false, false, true => Kind.nestedBC + | true, true, false => Kind.nestedPR + | true, false, true => Kind.nestedSBC + | false, true, true => Kind.nestedSBC + | true, true, true => Kind.nestedPRBC + | false, false, false => unreachable! def mkNestedTerm (code : Code) (m : Syntax) (uvars : Array Name) (a r bc : Bool) : MacroM Syntax := do -ToTerm.run code m uvars (mkNestedKind a r bc) + ToTerm.run code m uvars (mkNestedKind a r bc) /- Given a term `term` produced by `ToTerm.run`, pattern match on its result. See comment at the beginning of the `ToTerm` namespace. @@ -1038,137 +1044,137 @@ ToTerm.run code m uvars (mkNestedKind a r bc) The result is a sequence of `doElem` -/ def matchNestedTermResult (ref : Syntax) (term : Syntax) (uvars : Array Name) (a r bc : Bool) : MacroM (List Syntax) := do -let toDoElems (auxDo : Syntax) : List Syntax := getDoSeqElems (getDoSeq auxDo) -let u ← mkTuple ref (uvars.map (mkIdentFrom ref)) -match a, r, bc with -| true, false, false => - if uvars.isEmpty then - toDoElems (← `(do $term:term)) - else - toDoElems (← `(do let r ← $term:term; $u:term := r.2; pure r.1)) -| false, true, false => - if uvars.isEmpty then - toDoElems (← `(do let r ← $term:term; return r)) - else - toDoElems (← `(do let r ← $term:term; $u:term := r.2; return r.1)) -| false, false, true => toDoElems <$> - `(do let r ← $term:term; - match r with - | DoResultBC.«break» u => $u:term := u; break - | DoResultBC.«continue» u => $u:term := u; continue) -| true, true, false => toDoElems <$> - `(do let r ← $term:term; - match r with - | DoResultPR.«pure» a u => $u:term := u; pure a - | DoResultPR.«return» b u => $u:term := u; return b) -| true, false, true => toDoElems <$> - `(do let r ← $term:term; - match r with - | DoResultSBC.«pureReturn» a u => $u:term := u; pure a - | DoResultSBC.«break» u => $u:term := u; break - | DoResultSBC.«continue» u => $u:term := u; continue) -| false, true, true => toDoElems <$> - `(do let r ← $term:term; - match r with - | DoResultSBC.«pureReturn» a u => $u:term := u; return a - | DoResultSBC.«break» u => $u:term := u; break - | DoResultSBC.«continue» u => $u:term := u; continue) -| true, true, true => toDoElems <$> - `(do let r ← $term:term; - match r with - | DoResultPRBC.«pure» a u => $u:term := u; pure a - | DoResultPRBC.«return» a u => $u:term := u; return a - | DoResultPRBC.«break» u => $u:term := u; break - | DoResultPRBC.«continue» u => $u:term := u; continue) -| false, false, false => unreachable! + let toDoElems (auxDo : Syntax) : List Syntax := getDoSeqElems (getDoSeq auxDo) + let u ← mkTuple ref (uvars.map (mkIdentFrom ref)) + match a, r, bc with + | true, false, false => + if uvars.isEmpty then + toDoElems (← `(do $term:term)) + else + toDoElems (← `(do let r ← $term:term; $u:term := r.2; pure r.1)) + | false, true, false => + if uvars.isEmpty then + toDoElems (← `(do let r ← $term:term; return r)) + else + toDoElems (← `(do let r ← $term:term; $u:term := r.2; return r.1)) + | false, false, true => toDoElems <$> + `(do let r ← $term:term; + match r with + | DoResultBC.«break» u => $u:term := u; break + | DoResultBC.«continue» u => $u:term := u; continue) + | true, true, false => toDoElems <$> + `(do let r ← $term:term; + match r with + | DoResultPR.«pure» a u => $u:term := u; pure a + | DoResultPR.«return» b u => $u:term := u; return b) + | true, false, true => toDoElems <$> + `(do let r ← $term:term; + match r with + | DoResultSBC.«pureReturn» a u => $u:term := u; pure a + | DoResultSBC.«break» u => $u:term := u; break + | DoResultSBC.«continue» u => $u:term := u; continue) + | false, true, true => toDoElems <$> + `(do let r ← $term:term; + match r with + | DoResultSBC.«pureReturn» a u => $u:term := u; return a + | DoResultSBC.«break» u => $u:term := u; break + | DoResultSBC.«continue» u => $u:term := u; continue) + | true, true, true => toDoElems <$> + `(do let r ← $term:term; + match r with + | DoResultPRBC.«pure» a u => $u:term := u; pure a + | DoResultPRBC.«return» a u => $u:term := u; return a + | DoResultPRBC.«break» u => $u:term := u; break + | DoResultPRBC.«continue» u => $u:term := u; continue) + | false, false, false => unreachable! end ToTerm namespace ToCodeBlock structure Context := -(ref : Syntax) -(m : Syntax) -- Syntax representing the monad associated with the do notation. -(varSet : NameSet := {}) -(insideFor : Bool := false) + (ref : Syntax) + (m : Syntax) -- Syntax representing the monad associated with the do notation. + (varSet : NameSet := {}) + (insideFor : Bool := false) abbrev M := ReaderT Context TermElabM @[inline] def withNewVars {α} (newVars : Array Name) (x : M α) : M α := -withReader (fun ctx => { ctx with varSet := insertVars ctx.varSet newVars }) x + withReader (fun ctx => { ctx with varSet := insertVars ctx.varSet newVars }) x def checkReassignable (xs : Array Name) : M Unit := do -let ctx ← read -for x in xs do - unless ctx.varSet.contains x do - match (← resolveLocalName x) with - | some (_, []) => pure () - | _ => throwError! "'{x.simpMacroScopes}' cannot be reassigned" + let ctx ← read + for x in xs do + unless ctx.varSet.contains x do + match (← resolveLocalName x) with + | some (_, []) => pure () + | _ => throwError! "'{x.simpMacroScopes}' cannot be reassigned" @[inline] def withFor {α} (x : M α) : M α := -withReader (fun ctx => { ctx with insideFor := true }) x + withReader (fun ctx => { ctx with insideFor := true }) x structure ToForInTermResult := -(uvars : Array Name) -(term : Syntax) + (uvars : Array Name) + (term : Syntax) def mkForInBody (x : Syntax) (forInBody : CodeBlock) : M ToForInTermResult := do -let ctx ← read -let uvars := forInBody.uvars -let uvars := nameSetToArray uvars -let term ← liftMacroM $ ToTerm.run forInBody.code ctx.m uvars (if hasReturn forInBody.code then ToTerm.Kind.forInWithReturn else ToTerm.Kind.forIn) -pure ⟨uvars, term⟩ + let ctx ← read + let uvars := forInBody.uvars + let uvars := nameSetToArray uvars + let term ← liftMacroM $ ToTerm.run forInBody.code ctx.m uvars (if hasReturn forInBody.code then ToTerm.Kind.forInWithReturn else ToTerm.Kind.forIn) + pure ⟨uvars, term⟩ def ensureInsideFor : M Unit := do -let ctx ← read -unless ctx.insideFor do - throwError "invalid 'do' element, it must be inside 'for'" + let ctx ← read + unless ctx.insideFor do + throwError "invalid 'do' element, it must be inside 'for'" def ensureEOS (doElems : List Syntax) : M Unit := do -unless doElems.isEmpty do - throwError "must be last element in a 'do' sequence" + unless doElems.isEmpty do + throwError "must be last element in a 'do' sequence" private partial def expandLiftMethodAux : Syntax → StateT (List Syntax) MacroM Syntax -| stx@(Syntax.node k args) => - if k == `Lean.Parser.Term.do then pure stx - else if k == `Lean.Parser.Term.doSeqIndent then pure stx - else if k == `Lean.Parser.Term.doSeqBracketed then pure stx - else if k == `Lean.Parser.Term.quot then pure stx - else if k == `Lean.Parser.Term.liftMethod then withFreshMacroScope do - let term := args[1] - let term ← expandLiftMethodAux term - let auxDoElem ← `(doElem| let a ← $term:term) - modify fun s => s ++ [auxDoElem] - `(a) - else do - let args ← args.mapM expandLiftMethodAux - pure $ Syntax.node k args -| stx => pure stx + | stx@(Syntax.node k args) => + if k == `Lean.Parser.Term.do then pure stx + else if k == `Lean.Parser.Term.doSeqIndent then pure stx + else if k == `Lean.Parser.Term.doSeqBracketed then pure stx + else if k == `Lean.Parser.Term.quot then pure stx + else if k == `Lean.Parser.Term.liftMethod then withFreshMacroScope do + let term := args[1] + let term ← expandLiftMethodAux term + let auxDoElem ← `(doElem| let a ← $term:term) + modify fun s => s ++ [auxDoElem] + `(a) + else do + let args ← args.mapM expandLiftMethodAux + pure $ Syntax.node k args + | stx => pure stx def expandLiftMethod (doElem : Syntax) : MacroM (List Syntax × Syntax) := -if !hasLiftMethod doElem then pure ([], doElem) -else do - let (doElem, doElemsNew) ← (expandLiftMethodAux doElem).run [] - pure (doElemsNew, doElem) + if !hasLiftMethod doElem then pure ([], doElem) + else do + let (doElem, doElemsNew) ← (expandLiftMethodAux doElem).run [] + pure (doElemsNew, doElem) /- "Concatenate" `c` with `doSeqToCode doElems` -/ def concatWith (doSeqToCode : List Syntax → M CodeBlock) (c : CodeBlock) (doElems : List Syntax) : M CodeBlock := -match doElems with -| [] => pure c -| nextDoElem :: _ => do - let k ← doSeqToCode doElems - let ref := nextDoElem - liftM $ concat c ref none k + match doElems with + | [] => pure c + | nextDoElem :: _ => do + let k ← doSeqToCode doElems + let ref := nextDoElem + liftM $ concat c ref none k def checkLetArrowRHS (doElem : Syntax) : M Unit := do -let kind := doElem.getKind -if kind == `Lean.Parser.Term.doLetArrow || - kind == `Lean.Parser.Term.doLet || - kind == `Lean.Parser.Term.doLetRec || - kind == `Lean.Parser.Term.doHave || - kind == `Lean.Parser.Term.doReassign || - kind == `Lean.Parser.Term.doReassignArrow then - throwErrorAt! doElem "invalid kind of value '{kind}' in an assignment" + let kind := doElem.getKind + if kind == `Lean.Parser.Term.doLetArrow || + kind == `Lean.Parser.Term.doLet || + kind == `Lean.Parser.Term.doLetRec || + kind == `Lean.Parser.Term.doHave || + kind == `Lean.Parser.Term.doReassign || + kind == `Lean.Parser.Term.doReassignArrow then + throwErrorAt! doElem "invalid kind of value '{kind}' in an assignment" /- Generate `CodeBlock` for `doLetArrow; doElems` `doLetArrow` is of the form @@ -1181,34 +1187,34 @@ if kind == `Lean.Parser.Term.doLetArrow || def doPatDecl := parser! termParser >> leftArrow >> doElemParser >> optional (" | " >> doElemParser) ``` -/ def doLetArrowToCode (doSeqToCode : List Syntax → M CodeBlock) (doLetArrow : Syntax) (doElems : List Syntax) : M CodeBlock := do -let ref := doLetArrow -let decl := doLetArrow[1] -if decl.getKind == `Lean.Parser.Term.doIdDecl then - let y := decl[0].getId - let doElem := decl[3] - let k ← withNewVars #[y] (doSeqToCode doElems) - match isDoExpr? doElem with - | some action => pure $ mkVarDeclCore #[y] doLetArrow k - | none => - checkLetArrowRHS doElem - let c ← doSeqToCode [doElem] - match doElems with - | [] => pure c - | kRef::_ => liftM $ concat c kRef y k -else if decl.getKind == `Lean.Parser.Term.doPatDecl then - let pattern := decl[0] - let doElem := decl[2] - let optElse := decl[3] - if optElse.isNone then withFreshMacroScope do - let auxDo ← `(do let discr ← $doElem; let $pattern:term := discr) - doSeqToCode $ getDoSeqElems (getDoSeq auxDo) ++ doElems + let ref := doLetArrow + let decl := doLetArrow[1] + if decl.getKind == `Lean.Parser.Term.doIdDecl then + let y := decl[0].getId + let doElem := decl[3] + let k ← withNewVars #[y] (doSeqToCode doElems) + match isDoExpr? doElem with + | some action => pure $ mkVarDeclCore #[y] doLetArrow k + | none => + checkLetArrowRHS doElem + let c ← doSeqToCode [doElem] + match doElems with + | [] => pure c + | kRef::_ => liftM $ concat c kRef y k + else if decl.getKind == `Lean.Parser.Term.doPatDecl then + let pattern := decl[0] + let doElem := decl[2] + let optElse := decl[3] + if optElse.isNone then withFreshMacroScope do + let auxDo ← `(do let discr ← $doElem; let $pattern:term := discr) + doSeqToCode $ getDoSeqElems (getDoSeq auxDo) ++ doElems + else + let contSeq := mkDoSeq doElems.toArray + let elseSeq := mkSingletonDoSeq optElse[1] + let auxDo ← `(do let discr ← $doElem; match discr with | $pattern:term => $contSeq | _ => $elseSeq) + doSeqToCode $ getDoSeqElems (getDoSeq auxDo) else - let contSeq := mkDoSeq doElems.toArray - let elseSeq := mkSingletonDoSeq optElse[1] - let auxDo ← `(do let discr ← $doElem; match discr with | $pattern:term => $contSeq | _ => $elseSeq) - doSeqToCode $ getDoSeqElems (getDoSeq auxDo) -else - throwError "unexpected kind of 'do' declaration" + throwError "unexpected kind of 'do' declaration" /- Generate `CodeBlock` for `doReassignArrow; doElems` @@ -1217,24 +1223,24 @@ else (doIdDecl <|> doPatDecl) ``` -/ def doReassignArrowToCode (doSeqToCode : List Syntax → M CodeBlock) (doReassignArrow : Syntax) (doElems : List Syntax) : M CodeBlock := do -let ref := doReassignArrow -let decl := doReassignArrow[0] -if decl.getKind == `Lean.Parser.Term.doIdDecl then - let doElem := decl[3] - let y := decl[0] - let auxDo ← `(do let r ← $doElem; $y:ident := r) - doSeqToCode $ getDoSeqElems (getDoSeq auxDo) ++ doElems -else if decl.getKind == `Lean.Parser.Term.doPatDecl then - let pattern := decl[0] - let doElem := decl[2] - let optElse := decl[3] - if optElse.isNone then withFreshMacroScope do - let auxDo ← `(do let discr ← $doElem; $pattern:term := discr) + let ref := doReassignArrow + let decl := doReassignArrow[0] + if decl.getKind == `Lean.Parser.Term.doIdDecl then + let doElem := decl[3] + let y := decl[0] + let auxDo ← `(do let r ← $doElem; $y:ident := r) doSeqToCode $ getDoSeqElems (getDoSeq auxDo) ++ doElems + else if decl.getKind == `Lean.Parser.Term.doPatDecl then + let pattern := decl[0] + let doElem := decl[2] + let optElse := decl[3] + if optElse.isNone then withFreshMacroScope do + let auxDo ← `(do let discr ← $doElem; $pattern:term := discr) + doSeqToCode $ getDoSeqElems (getDoSeq auxDo) ++ doElems + else + throwError "reassignment with `|` (i.e., \"else clause\") is not currently supported" else - throwError "reassignment with `|` (i.e., \"else clause\") is not currently supported" -else - throwError "unexpected kind of 'do' reassignment" + throwError "unexpected kind of 'do' reassignment" /- Generate `CodeBlock` for `doIf; doElems` `doIf` is of the form @@ -1244,11 +1250,11 @@ else >> optional (" else " >> doSeq) ``` -/ def doIfToCode (doSeqToCode : List Syntax → M CodeBlock) (doIf : Syntax) (doElems : List Syntax) : M CodeBlock := do -let view ← liftMacroM $ mkDoIfView doIf -let thenBranch ← doSeqToCode (getDoSeqElems view.thenBranch) -let elseBranch ← doSeqToCode (getDoSeqElems view.elseBranch) -let ite ← mkIte view.ref view.optIdent view.cond thenBranch elseBranch -concatWith doSeqToCode ite doElems + let view ← liftMacroM $ mkDoIfView doIf + let thenBranch ← doSeqToCode (getDoSeqElems view.thenBranch) + let elseBranch ← doSeqToCode (getDoSeqElems view.elseBranch) + let ite ← mkIte view.ref view.optIdent view.cond thenBranch elseBranch + concatWith doSeqToCode ite doElems /- Generate `CodeBlock` for `doUnless; doElems` `doUnless` is of the form @@ -1256,12 +1262,12 @@ concatWith doSeqToCode ite doElems "unless " >> termParser >> "do " >> doSeq ``` -/ def doUnlessToCode (doSeqToCode : List Syntax → M CodeBlock) (doUnless : Syntax) (doElems : List Syntax) : M CodeBlock := do -let ref := doUnless -let cond := doUnless[1] -let doSeq := doUnless[3] -let body ← doSeqToCode (getDoSeqElems doSeq) -let unlessCode ← liftMacroM $ mkUnless ref cond body -concatWith doSeqToCode unlessCode doElems + let ref := doUnless + let cond := doUnless[1] + let doSeq := doUnless[3] + let body ← doSeqToCode (getDoSeqElems doSeq) + let unlessCode ← liftMacroM $ mkUnless ref cond body + concatWith doSeqToCode unlessCode doElems /- Generate `CodeBlock` for `doFor; doElems` `doFor` is of the form @@ -1269,32 +1275,32 @@ concatWith doSeqToCode unlessCode doElems for " >> termParser >> " in " >> termParser >> "do " >> doSeq ``` -/ def doForToCode (doSeqToCode : List Syntax → M CodeBlock) (doFor : Syntax) (doElems : List Syntax) : M CodeBlock := do -let ref := doFor -let x := doFor[1] -let xs := doFor[3] -let forElems := getDoSeqElems doFor[5] -let newVars := if x.isIdent then #[x.getId] else #[] -let forInBodyCodeBlock ← withNewVars newVars $ withFor (doSeqToCode forElems) -let ⟨uvars, forInBody⟩ ← mkForInBody x forInBodyCodeBlock -let uvarsTuple ← liftMacroM $ mkTuple ref (uvars.map (mkIdentFrom ref)) -if hasReturn forInBodyCodeBlock.code then - let forInTerm ← `($(xs).forIn (MProd.mk none $uvarsTuple) fun $x (MProd.mk _ $uvarsTuple) => $forInBody) - let auxDo ← `(do let r ← $forInTerm:term; - $uvarsTuple:term := r.2; - match r.1 with - | none => HasPure.pure (ensureExpectedType! "type mismatch, 'for'" PUnit.unit) - | some a => return ensureExpectedType! "type mismatch, 'for'" a) - doSeqToCode (getDoSeqElems (getDoSeq auxDo) ++ doElems) -else - let forInTerm ← `($(xs).forIn $uvarsTuple fun $x $uvarsTuple => $forInBody) - if doElems.isEmpty then + let ref := doFor + let x := doFor[1] + let xs := doFor[3] + let forElems := getDoSeqElems doFor[5] + let newVars := if x.isIdent then #[x.getId] else #[] + let forInBodyCodeBlock ← withNewVars newVars $ withFor (doSeqToCode forElems) + let ⟨uvars, forInBody⟩ ← mkForInBody x forInBodyCodeBlock + let uvarsTuple ← liftMacroM $ mkTuple ref (uvars.map (mkIdentFrom ref)) + if hasReturn forInBodyCodeBlock.code then + let forInTerm ← `($(xs).forIn (MProd.mk none $uvarsTuple) fun $x (MProd.mk _ $uvarsTuple) => $forInBody) let auxDo ← `(do let r ← $forInTerm:term; - $uvarsTuple:term := r; - HasPure.pure (ensureExpectedType! "type mismatch, 'for'" PUnit.unit)) - doSeqToCode $ getDoSeqElems (getDoSeq auxDo) - else - let auxDo ← `(do let r ← $forInTerm:term; $uvarsTuple:term := r) + $uvarsTuple:term := r.2; + match r.1 with + | none => HasPure.pure (ensureExpectedType! "type mismatch, 'for'" PUnit.unit) + | some a => return ensureExpectedType! "type mismatch, 'for'" a) doSeqToCode (getDoSeqElems (getDoSeq auxDo) ++ doElems) + else + let forInTerm ← `($(xs).forIn $uvarsTuple fun $x $uvarsTuple => $forInBody) + if doElems.isEmpty then + let auxDo ← `(do let r ← $forInTerm:term; + $uvarsTuple:term := r; + HasPure.pure (ensureExpectedType! "type mismatch, 'for'" PUnit.unit)) + doSeqToCode $ getDoSeqElems (getDoSeq auxDo) + else + let auxDo ← `(do let r ← $forInTerm:term; $uvarsTuple:term := r) + doSeqToCode (getDoSeqElems (getDoSeq auxDo) ++ doElems) /-- Generate `CodeBlock` for `doMatch; doElems` @@ -1304,39 +1310,39 @@ else def doMatch := parser! "match " >> sepBy1 matchDiscr ", " >> optType >> " with " >> doMatchAlts ``` -/ def doMatchToCode (doSeqToCode : List Syntax → M CodeBlock) (doMatch : Syntax) (doElems: List Syntax) : M CodeBlock := do -let ref := doMatch -let discrs := doMatch[1] -let optType := doMatch[2] -let matchAlts := doMatch[4][1].getSepArgs -- Array of `doMatchAlt` -let alts ← matchAlts.mapM fun matchAlt => do - let patterns := matchAlt[0] - let pvars ← getPatternsVars patterns.getSepArgs - let vars := getPatternVarNames pvars - let rhs := matchAlt[2] - let rhs ← withNewVars vars $ doSeqToCode (getDoSeqElems rhs) - pure { ref := matchAlt, vars := vars, patterns := patterns, rhs := rhs : Alt CodeBlock } -let matchCode ← mkMatch ref discrs optType alts -concatWith doSeqToCode matchCode doElems + let ref := doMatch + let discrs := doMatch[1] + let optType := doMatch[2] + let matchAlts := doMatch[4][1].getSepArgs -- Array of `doMatchAlt` + let alts ← matchAlts.mapM fun matchAlt => do + let patterns := matchAlt[0] + let pvars ← getPatternsVars patterns.getSepArgs + let vars := getPatternVarNames pvars + let rhs := matchAlt[2] + let rhs ← withNewVars vars $ doSeqToCode (getDoSeqElems rhs) + pure { ref := matchAlt, vars := vars, patterns := patterns, rhs := rhs : Alt CodeBlock } + let matchCode ← mkMatch ref discrs optType alts + concatWith doSeqToCode matchCode doElems structure Catch := -(x : Syntax) -(optType : Syntax) -(codeBlock : CodeBlock) + (x : Syntax) + (optType : Syntax) + (codeBlock : CodeBlock) def getTryCatchUpdatedVars (tryCode : CodeBlock) (catches : Array Catch) (finallyCode? : Option CodeBlock) : NameSet := -let ws := tryCode.uvars -let ws := catches.foldl (fun ws alt => union alt.codeBlock.uvars ws) ws -let ws := match finallyCode? with - | none => ws - | some c => union c.uvars ws -ws + let ws := tryCode.uvars + let ws := catches.foldl (fun ws alt => union alt.codeBlock.uvars ws) ws + let ws := match finallyCode? with + | none => ws + | some c => union c.uvars ws + ws def tryCatchPred (tryCode : CodeBlock) (catches : Array Catch) (finallyCode? : Option CodeBlock) (p : Code → Bool) : Bool := -p tryCode.code || -catches.any (fun «catch» => p «catch».codeBlock.code) || -match finallyCode? with -| none => false -| some finallyCode => p finallyCode.code + p tryCode.code || + catches.any (fun «catch» => p «catch».codeBlock.code) || + match finallyCode? with + | none => false + | some finallyCode => p finallyCode.code /-- Generate `CodeBlock` for `doTry; doElems` @@ -1347,56 +1353,56 @@ match finallyCode? with def doFinally := parser! "finally " >> doSeq ``` -/ def doTryToCode (doSeqToCode : List Syntax → M CodeBlock) (doTry : Syntax) (doElems: List Syntax) : M CodeBlock := do -let ref := doTry -let tryCode ← doSeqToCode (getDoSeqElems doTry[1]) -let optFinally := doTry[3] -let catches ← doTry[2].getArgs.mapM fun catchStx => do - if catchStx.getKind == `Lean.Parser.Term.doCatch then - let x := catchStx[1] - let optType := catchStx[2] - let c ← doSeqToCode (getDoSeqElems catchStx[4]) - pure { x := x, optType := optType, codeBlock := c : Catch } - else if catchStx.getKind == `Lean.Parser.Term.doCatchMatch then - let matchAlts := catchStx[1] - let x ← `(ex) - let auxDo ← `(do match ex with $matchAlts) - let c ← doSeqToCode (getDoSeqElems (getDoSeq auxDo)) - pure { x := x, codeBlock := c, optType := mkNullNode : Catch } - else - throwError "unexpected kind of 'catch'" -let finallyCode? ← if optFinally.isNone then pure none else some <$> doSeqToCode (getDoSeqElems optFinally[0][1]) -if catches.isEmpty && finallyCode?.isNone then - throwError "invalid 'try', it must have a 'catch' or 'finally'" -let ctx ← read -let ws := getTryCatchUpdatedVars tryCode catches finallyCode? -let uvars := nameSetToArray ws -let a := tryCatchPred tryCode catches finallyCode? hasTerminalAction -let r := tryCatchPred tryCode catches finallyCode? hasReturn -let bc := tryCatchPred tryCode catches finallyCode? hasBreakContinue -let toTerm (codeBlock : CodeBlock) : M Syntax := do - codeBlock ← liftM $ extendUpdatedVars codeBlock ws - liftMacroM $ ToTerm.mkNestedTerm codeBlock.code ctx.m uvars a r bc -let term ← toTerm tryCode -let term ← catches.foldlM - (fun term «catch» => do - let catchTerm ← toTerm «catch».codeBlock - if catch.optType.isNone then - `(MonadExcept.«catch» $term (fun $(«catch».x):ident => $catchTerm)) + let ref := doTry + let tryCode ← doSeqToCode (getDoSeqElems doTry[1]) + let optFinally := doTry[3] + let catches ← doTry[2].getArgs.mapM fun catchStx => do + if catchStx.getKind == `Lean.Parser.Term.doCatch then + let x := catchStx[1] + let optType := catchStx[2] + let c ← doSeqToCode (getDoSeqElems catchStx[4]) + pure { x := x, optType := optType, codeBlock := c : Catch } + else if catchStx.getKind == `Lean.Parser.Term.doCatchMatch then + let matchAlts := catchStx[1] + let x ← `(ex) + let auxDo ← `(do match ex with $matchAlts) + let c ← doSeqToCode (getDoSeqElems (getDoSeq auxDo)) + pure { x := x, codeBlock := c, optType := mkNullNode : Catch } else - let type := «catch».optType[1] - `(catchThe $type $term (fun $(«catch».x):ident => $catchTerm))) - term -let term ← match finallyCode? with - | none => pure term - | some finallyCode => withRef optFinally do - unless finallyCode.uvars.isEmpty do - throwError "'finally' currently does not support reassignments" - if hasBreakContinueReturn finallyCode.code then - throwError "'finally' currently does 'return', 'break', nor 'continue'" - let finallyTerm ← liftMacroM $ ToTerm.run finallyCode.code ctx.m {} ToTerm.Kind.regular - `(«finally» $term $finallyTerm) -let doElemsNew ← liftMacroM $ ToTerm.matchNestedTermResult ref term uvars a r bc -doSeqToCode (doElemsNew ++ doElems) + throwError "unexpected kind of 'catch'" + let finallyCode? ← if optFinally.isNone then pure none else some <$> doSeqToCode (getDoSeqElems optFinally[0][1]) + if catches.isEmpty && finallyCode?.isNone then + throwError "invalid 'try', it must have a 'catch' or 'finally'" + let ctx ← read + let ws := getTryCatchUpdatedVars tryCode catches finallyCode? + let uvars := nameSetToArray ws + let a := tryCatchPred tryCode catches finallyCode? hasTerminalAction + let r := tryCatchPred tryCode catches finallyCode? hasReturn + let bc := tryCatchPred tryCode catches finallyCode? hasBreakContinue + let toTerm (codeBlock : CodeBlock) : M Syntax := do + codeBlock ← liftM $ extendUpdatedVars codeBlock ws + liftMacroM $ ToTerm.mkNestedTerm codeBlock.code ctx.m uvars a r bc + let term ← toTerm tryCode + let term ← catches.foldlM + (fun term «catch» => do + let catchTerm ← toTerm «catch».codeBlock + if catch.optType.isNone then + `(MonadExcept.«catch» $term (fun $(«catch».x):ident => $catchTerm)) + else + let type := «catch».optType[1] + `(catchThe $type $term (fun $(«catch».x):ident => $catchTerm))) + term + let term ← match finallyCode? with + | none => pure term + | some finallyCode => withRef optFinally do + unless finallyCode.uvars.isEmpty do + throwError "'finally' currently does not support reassignments" + if hasBreakContinueReturn finallyCode.code then + throwError "'finally' currently does 'return', 'break', nor 'continue'" + let finallyTerm ← liftMacroM $ ToTerm.run finallyCode.code ctx.m {} ToTerm.Kind.regular + `(«finally» $term $finallyTerm) + let doElemsNew ← liftMacroM $ ToTerm.matchNestedTermResult ref term uvars a r bc + doSeqToCode (doElemsNew ++ doElems) /- Generate `CodeBlock` for `doReturn` which is of the form ``` @@ -1404,99 +1410,98 @@ doSeqToCode (doElemsNew ++ doElems) ``` `doElems` is only used for sanity checking. -/ def doReturnToCode (doReturn : Syntax) (doElems: List Syntax) : M CodeBlock := do -let ref := doReturn -ensureEOS doElems -let argOpt := doReturn[1] -let arg ← if argOpt.isNone then liftMacroM $ mkUnit ref else pure argOpt[0] -pure $ mkReturn ref arg + let ref := doReturn + ensureEOS doElems + let argOpt := doReturn[1] + let arg ← if argOpt.isNone then liftMacroM $ mkUnit ref else pure argOpt[0] + pure $ mkReturn ref arg partial def doSeqToCode : List Syntax → M CodeBlock -| [] => do let ctx ← read; liftMacroM $ mkPureUnitAction ctx.ref -| doElem::doElems => withRef doElem do - match (← liftMacroM $ expandMacro? doElem) with - | some doElem => doSeqToCode (doElem::doElems) - | none => - match (← liftMacroM $ expandDoIf? doElem) with - | some doElem => doSeqToCode (doElem::doElems) - | none => - let (liftedDoElems, doElem) ← liftM (liftMacroM $ expandLiftMethod doElem : TermElabM _) - if !liftedDoElems.isEmpty then - doSeqToCode (liftedDoElems ++ [doElem] ++ doElems) - else - let ref := doElem - let concatWithRest (c : CodeBlock) : M CodeBlock := concatWith doSeqToCode c doElems - let k := doElem.getKind - if k == `Lean.Parser.Term.doLet then - let vars ← getDoLetVars doElem - mkVarDeclCore vars doElem <$> withNewVars vars (doSeqToCode doElems) - else if k == `Lean.Parser.Term.doHave then - let var := getDoHaveVar doElem - mkVarDeclCore #[var] doElem <$> withNewVars #[var] (doSeqToCode doElems) - else if k == `Lean.Parser.Term.doLetRec then - let vars ← getDoLetRecVars doElem - mkVarDeclCore vars doElem <$> withNewVars vars (doSeqToCode doElems) - else if k == `Lean.Parser.Term.doReassign then - let vars ← liftM $ getDoReassignVars doElem - checkReassignable vars - let k ← doSeqToCode doElems - mkReassignCore vars doElem k - else if k == `Lean.Parser.Term.doLetArrow then - doLetArrowToCode doSeqToCode doElem doElems - else if k == `Lean.Parser.Term.doReassignArrow then - doReassignArrowToCode doSeqToCode doElem doElems - else if k == `Lean.Parser.Term.doIf then - doIfToCode doSeqToCode doElem doElems - else if k == `Lean.Parser.Term.doUnless then - doUnlessToCode doSeqToCode doElem doElems - else if k == `Lean.Parser.Term.doFor then withFreshMacroScope do - doForToCode doSeqToCode doElem doElems - else if k == `Lean.Parser.Term.doMatch then - doMatchToCode doSeqToCode doElem doElems - else if k == `Lean.Parser.Term.doTry then - doTryToCode doSeqToCode doElem doElems - else if k == `Lean.Parser.Term.doBreak then - ensureInsideFor - ensureEOS doElems - pure $ mkBreak ref - else if k == `Lean.Parser.Term.doContinue then - ensureInsideFor - ensureEOS doElems - pure $ mkContinue ref - else if k == `Lean.Parser.Term.doReturn then - doReturnToCode doElem doElems - else if k == `Lean.Parser.Term.doDbgTrace then - mkSeq doElem <$> doSeqToCode doElems - else if k == `Lean.Parser.Term.doAssert then - mkSeq doElem <$> doSeqToCode doElems - else if k == `Lean.Parser.Term.doNested then - let nestedDoSeq := doElem[1] - doSeqToCode (getDoSeqElems nestedDoSeq ++ doElems) - else if k == `Lean.Parser.Term.doExpr then - let term := doElem[0] - if doElems.isEmpty then - pure $ mkTerminalAction term - else - mkSeq term <$> doSeqToCode doElems + | [] => do let ctx ← read; liftMacroM $ mkPureUnitAction ctx.ref + | doElem::doElems => withRef doElem do + match (← liftMacroM $ expandMacro? doElem) with + | some doElem => doSeqToCode (doElem::doElems) + | none => + match (← liftMacroM $ expandDoIf? doElem) with + | some doElem => doSeqToCode (doElem::doElems) + | none => + let (liftedDoElems, doElem) ← liftM (liftMacroM $ expandLiftMethod doElem : TermElabM _) + if !liftedDoElems.isEmpty then + doSeqToCode (liftedDoElems ++ [doElem] ++ doElems) else - throwError! "unexpected do-element\n{doElem}" + let ref := doElem + let concatWithRest (c : CodeBlock) : M CodeBlock := concatWith doSeqToCode c doElems + let k := doElem.getKind + if k == `Lean.Parser.Term.doLet then + let vars ← getDoLetVars doElem + mkVarDeclCore vars doElem <$> withNewVars vars (doSeqToCode doElems) + else if k == `Lean.Parser.Term.doHave then + let var := getDoHaveVar doElem + mkVarDeclCore #[var] doElem <$> withNewVars #[var] (doSeqToCode doElems) + else if k == `Lean.Parser.Term.doLetRec then + let vars ← getDoLetRecVars doElem + mkVarDeclCore vars doElem <$> withNewVars vars (doSeqToCode doElems) + else if k == `Lean.Parser.Term.doReassign then + let vars ← liftM $ getDoReassignVars doElem + checkReassignable vars + let k ← doSeqToCode doElems + mkReassignCore vars doElem k + else if k == `Lean.Parser.Term.doLetArrow then + doLetArrowToCode doSeqToCode doElem doElems + else if k == `Lean.Parser.Term.doReassignArrow then + doReassignArrowToCode doSeqToCode doElem doElems + else if k == `Lean.Parser.Term.doIf then + doIfToCode doSeqToCode doElem doElems + else if k == `Lean.Parser.Term.doUnless then + doUnlessToCode doSeqToCode doElem doElems + else if k == `Lean.Parser.Term.doFor then withFreshMacroScope do + doForToCode doSeqToCode doElem doElems + else if k == `Lean.Parser.Term.doMatch then + doMatchToCode doSeqToCode doElem doElems + else if k == `Lean.Parser.Term.doTry then + doTryToCode doSeqToCode doElem doElems + else if k == `Lean.Parser.Term.doBreak then + ensureInsideFor + ensureEOS doElems + pure $ mkBreak ref + else if k == `Lean.Parser.Term.doContinue then + ensureInsideFor + ensureEOS doElems + pure $ mkContinue ref + else if k == `Lean.Parser.Term.doReturn then + doReturnToCode doElem doElems + else if k == `Lean.Parser.Term.doDbgTrace then + mkSeq doElem <$> doSeqToCode doElems + else if k == `Lean.Parser.Term.doAssert then + mkSeq doElem <$> doSeqToCode doElems + else if k == `Lean.Parser.Term.doNested then + let nestedDoSeq := doElem[1] + doSeqToCode (getDoSeqElems nestedDoSeq ++ doElems) + else if k == `Lean.Parser.Term.doExpr then + let term := doElem[0] + if doElems.isEmpty then + pure $ mkTerminalAction term + else + mkSeq term <$> doSeqToCode doElems + else + throwError! "unexpected do-element\n{doElem}" def run (doStx : Syntax) (m : Syntax) : TermElabM CodeBlock := -(doSeqToCode $ getDoSeqElems $ getDoSeq doStx).run { ref := doStx, m := m } + (doSeqToCode $ getDoSeqElems $ getDoSeq doStx).run { ref := doStx, m := m } end ToCodeBlock /- Create a synthetic metavariable `?m` and assign `m` to it. We use `?m` to refer to `m` when expanding the `do` notation. -/ private def mkMonadAlias (m : Expr) : TermElabM Syntax := do -let result ← `(?m) -let mType ← inferType m -let mvar ← elabTerm result mType -assignExprMVar mvar.mvarId! m -pure result + let result ← `(?m) + let mType ← inferType m + let mvar ← elabTerm result mType + assignExprMVar mvar.mvarId! m + pure result @[builtinTermElab «do»] -def elabDo : TermElab := -fun stx expectedType? => do +def elabDo : TermElab := fun stx expectedType? => do tryPostponeIfNoneOrMVar expectedType? let bindInfo ← extractBind expectedType? let m ← mkMonadAlias bindInfo.m diff --git a/src/Lean/Elab/Inductive.lean b/src/Lean/Elab/Inductive.lean index 737e384074..15878b1fcd 100644 --- a/src/Lean/Elab/Inductive.lean +++ b/src/Lean/Elab/Inductive.lean @@ -42,7 +42,7 @@ structure CtorView := (binders : Syntax) (type? : Option Syntax) -instance CtorView.inhabited : Inhabited CtorView := +instance : Inhabited CtorView := ⟨{ ref := arbitrary _, modifiers := {}, inferMod := false, declName := arbitrary _, binders := arbitrary _, type? := none }⟩ structure InductiveView := @@ -55,7 +55,7 @@ structure InductiveView := (type? : Option Syntax) (ctors : Array CtorView) -instance InductiveView.inhabited : Inhabited InductiveView := +instance : Inhabited InductiveView := ⟨{ ref := arbitrary _, modifiers := {}, shortDeclName := arbitrary _, declName := arbitrary _, levelNames := [], binders := arbitrary _, type? := none, ctors := #[] }⟩ @@ -66,7 +66,7 @@ structure ElabHeaderResult := (params : Array Expr) (type : Expr) -instance ElabHeaderResult.inhabited : Inhabited ElabHeaderResult := +instance : Inhabited ElabHeaderResult := ⟨{ view := arbitrary _, lctx := arbitrary _, localInsts := arbitrary _, params := #[], type := arbitrary _ }⟩ private partial def elabHeaderAux (views : Array InductiveView) diff --git a/src/Lean/Elab/Log.lean b/src/Lean/Elab/Log.lean index 295665b878..26974e8cd6 100644 --- a/src/Lean/Elab/Log.lean +++ b/src/Lean/Elab/Log.lean @@ -15,7 +15,7 @@ class MonadLog (m : Type → Type) := (getFileName : m String) (logMessage : Message → m Unit) -instance monadLogTrans (m n) [MonadLog m] [MonadLift m n] : MonadLog n := +instance (m n) [MonadLog m] [MonadLift m n] : MonadLog n := { getRef := liftM (MonadLog.getRef : m _), getFileMap := liftM (MonadLog.getFileMap : m _), getFileName := liftM (MonadLog.getFileName : m _), diff --git a/src/Lean/Elab/Match.lean b/src/Lean/Elab/Match.lean index 75085afe67..70e4a6be0b 100644 --- a/src/Lean/Elab/Match.lean +++ b/src/Lean/Elab/Match.lean @@ -156,7 +156,7 @@ inductive PatternVar -- anonymous variables (`_`) are encoded using metavariables | anonymousVar (mvarId : MVarId) -instance PatternVar.hasToString : HasToString PatternVar := +instance : HasToString PatternVar := ⟨fun | PatternVar.localVar x => toString x | PatternVar.anonymousVar mvarId => s!"?m{mvarId}"⟩ @@ -269,7 +269,7 @@ structure Context := (args : List Arg) (newArgs : Array Syntax := #[]) -instance Context.inhabited : Inhabited Context := +instance : Inhabited Context := ⟨⟨arbitrary _, none, false, false, #[], 0, #[], [], #[]⟩⟩ private def isDone (ctx : Context) : Bool := diff --git a/src/Lean/Elab/MutualDef.lean b/src/Lean/Elab/MutualDef.lean index 8374f49178..83db728e1a 100644 --- a/src/Lean/Elab/MutualDef.lean +++ b/src/Lean/Elab/MutualDef.lean @@ -24,7 +24,7 @@ structure DefViewElabHeader := (type : Expr) -- including the parameters (valueStx : Syntax) -instance DefViewElabHeader.inhabited : Inhabited DefViewElabHeader := +instance : Inhabited DefViewElabHeader := ⟨⟨arbitrary _, {}, DefKind.«def», arbitrary _, arbitrary _, [], 0, arbitrary _, arbitrary _⟩⟩ namespace Term diff --git a/src/Lean/Elab/Quotation.lean b/src/Lean/Elab/Quotation.lean index 4717130b4e..006edde618 100644 --- a/src/Lean/Elab/Quotation.lean +++ b/src/Lean/Elab/Quotation.lean @@ -154,7 +154,7 @@ structure HeadInfo := -- bind pattern variables. (rhsFn : Syntax → TermElabM Syntax := pure) -instance HeadInfo.Inhabited : Inhabited HeadInfo := ⟨{}⟩ +instance : Inhabited HeadInfo := ⟨{}⟩ /-- `h1.generalizes h2` iff h1 is equal to or more general than h2, i.e. it matches all nodes h2 matches. This induces a partial ordering. -/ diff --git a/src/Lean/Elab/StructInst.lean b/src/Lean/Elab/StructInst.lean index 8cb21a277f..d8208df34c 100644 --- a/src/Lean/Elab/StructInst.lean +++ b/src/Lean/Elab/StructInst.lean @@ -16,8 +16,7 @@ open Meta /- parser! "{" >> optional (try (termParser >> "with")) >> sepBy structInstField ", " true >> optional ".." >> optional (" : " >> termParser) >> "}" -/ -@[builtinMacro Lean.Parser.Term.structInst] def expandStructInstExpectedType : Macro := -fun stx => +@[builtinMacro Lean.Parser.Term.structInst] def expandStructInstExpectedType : Macro := fun stx => let expectedArg := stx[4] if expectedArg.isNone then Macro.throwUnsupported @@ -32,51 +31,51 @@ If `stx` is of the form `{ s with ... }` and `s` is not a local variable, expand Note that this one is not a `Macro` because we need to access the local context. -/ private def expandNonAtomicExplicitSource (stx : Syntax) : TermElabM (Option Syntax) := -withFreshMacroScope do - let sourceOpt := stx[1] - if sourceOpt.isNone then - pure none - else - let source := sourceOpt[0] - match (← isLocalIdent? source) with - | some _ => pure none - | none => - let src ← `(src) - let sourceOpt := sourceOpt.setArg 0 src - let stxNew := stx.setArg 1 sourceOpt - `(let src := $source; $stxNew) + withFreshMacroScope do + let sourceOpt := stx[1] + if sourceOpt.isNone then + pure none + else + let source := sourceOpt[0] + match (← isLocalIdent? source) with + | some _ => pure none + | none => + let src ← `(src) + let sourceOpt := sourceOpt.setArg 0 src + let stxNew := stx.setArg 1 sourceOpt + `(let src := $source; $stxNew) inductive Source -| none -- structure instance source has not been provieded -| implicit (stx : Syntax) -- `..` -| explicit (stx : Syntax) (src : Expr) -- `src with` + | none -- structure instance source has not been provieded + | implicit (stx : Syntax) -- `..` + | explicit (stx : Syntax) (src : Expr) -- `src with` -instance Source.inhabited : Inhabited Source := ⟨Source.none⟩ +instance : Inhabited Source := ⟨Source.none⟩ def Source.isNone : Source → Bool -| Source.none => true -| _ => false + | Source.none => true + | _ => false def setStructSourceSyntax (structStx : Syntax) : Source → Syntax -| Source.none => (structStx.setArg 1 mkNullNode).setArg 3 mkNullNode -| Source.implicit stx => (structStx.setArg 1 mkNullNode).setArg 3 stx -| Source.explicit stx _ => (structStx.setArg 1 stx).setArg 3 mkNullNode + | Source.none => (structStx.setArg 1 mkNullNode).setArg 3 mkNullNode + | Source.implicit stx => (structStx.setArg 1 mkNullNode).setArg 3 stx + | Source.explicit stx _ => (structStx.setArg 1 stx).setArg 3 mkNullNode private def getStructSource (stx : Syntax) : TermElabM Source := -withRef stx do -let explicitSource := stx[1] -let implicitSource := stx[3] -if explicitSource.isNone && implicitSource.isNone then - pure Source.none -else if explicitSource.isNone then - pure $ Source.implicit implicitSource -else if implicitSource.isNone then - let fvar? ← isLocalIdent? explicitSource[0] - match fvar? with - | none => unreachable! -- expandNonAtomicExplicitSource must have been used when we get here - | some src => pure $ Source.explicit explicitSource src -else - throwError "invalid structure instance `with` and `..` cannot be used together" + withRef stx do + let explicitSource := stx[1] + let implicitSource := stx[3] + if explicitSource.isNone && implicitSource.isNone then + pure Source.none + else if explicitSource.isNone then + pure $ Source.implicit implicitSource + else if implicitSource.isNone then + let fvar? ← isLocalIdent? explicitSource[0] + match fvar? with + | none => unreachable! -- expandNonAtomicExplicitSource must have been used when we get here + | some src => pure $ Source.explicit explicitSource src + else + throwError "invalid structure instance `with` and `..` cannot be used together" /- We say a `{ ... }` notation is a `modifyOp` if it contains only one @@ -84,157 +83,160 @@ else def structInstArrayRef := parser! "[" >> termParser >>"]" ``` -/ private def isModifyOp? (stx : Syntax) : TermElabM (Option Syntax) := do -let args := stx[2].getArgs -let s? ← args.foldSepByM - (fun arg s? => - /- Remark: the syntax for `structInstField` is - ``` - def structInstLVal := (ident <|> numLit <|> structInstArrayRef) >> many (group ("." >> (ident <|> numLit)) <|> structInstArrayRef) - def structInstField := parser! structInstLVal >> " := " >> termParser - ``` -/ - let lval := arg[0] - let k := lval.getKind - if k == `Lean.Parser.Term.structInstArrayRef then - match s? with - | none => pure (some arg) - | some s => - if s.getKind == `Lean.Parser.Term.structInstArrayRef then - throwErrorAt arg "invalid {...} notation, at most one `[..]` at a given level" - else - throwErrorAt arg "invalid {...} notation, can't mix field and `[..]` at a given level" - else - match s? with - | none => pure (some arg) - | some s => - if s.getKind == `Lean.Parser.Term.structInstArrayRef then - throwErrorAt arg "invalid {...} notation, can't mix field and `[..]` at a given level" - else - pure s?) - none -match s? with -| none => pure none -| some s => if s[0].getKind == `Lean.Parser.Term.structInstArrayRef then pure s? else pure none + let args := stx[2].getArgs + let s? ← args.foldSepByM + (fun arg s? => + /- Remark: the syntax for `structInstField` is + ``` + def structInstLVal := (ident <|> numLit <|> structInstArrayRef) >> many (group ("." >> (ident <|> numLit)) <|> structInstArrayRef) + def structInstField := parser! structInstLVal >> " := " >> termParser + ``` -/ + let lval := arg[0] + let k := lval.getKind + if k == `Lean.Parser.Term.structInstArrayRef then + match s? with + | none => pure (some arg) + | some s => + if s.getKind == `Lean.Parser.Term.structInstArrayRef then + throwErrorAt arg "invalid {...} notation, at most one `[..]` at a given level" + else + throwErrorAt arg "invalid {...} notation, can't mix field and `[..]` at a given level" + else + match s? with + | none => pure (some arg) + | some s => + if s.getKind == `Lean.Parser.Term.structInstArrayRef then + throwErrorAt arg "invalid {...} notation, can't mix field and `[..]` at a given level" + else + pure s?) + none + match s? with + | none => pure none + | some s => if s[0].getKind == `Lean.Parser.Term.structInstArrayRef then pure s? else pure none private def elabModifyOp (stx modifyOp source : Syntax) (expectedType? : Option Expr) : TermElabM Expr := do -let cont (val : Syntax) : TermElabM Expr := do - let lval := modifyOp[0] - let idx := lval[1] - let self := source[0] - let stxNew ← `($(self).modifyOp (idx := $idx) (fun s => $val)) - trace[Elab.struct.modifyOp]! "{stx}\n===>\n{stxNew}" - withMacroExpansion stx stxNew $ elabTerm stxNew expectedType? -trace[Elab.struct.modifyOp]! "{modifyOp}\nSource: {source}" -let rest := modifyOp[1] -if rest.isNone then - cont modifyOp[3] -else - let s ← `(s) - let valFirst := rest[0] - let valFirst := if valFirst.getKind == `Lean.Parser.Term.structInstArrayRef then valFirst else valFirst[1] - let restArgs := rest.getArgs - let valRest := mkNullNode restArgs[1:restArgs.size] - let valField := modifyOp.setArg 0 valFirst - let valField := valField.setArg 1 valRest - let valSource := source.modifyArg 0 fun _ => s - let val := stx.setArg 1 valSource - let val := val.setArg 2 $ mkNullNode #[valField] - trace[Elab.struct.modifyOp]! "{stx}\nval: {val}" - cont val + let cont (val : Syntax) : TermElabM Expr := do + let lval := modifyOp[0] + let idx := lval[1] + let self := source[0] + let stxNew ← `($(self).modifyOp (idx := $idx) (fun s => $val)) + trace[Elab.struct.modifyOp]! "{stx}\n===>\n{stxNew}" + withMacroExpansion stx stxNew $ elabTerm stxNew expectedType? + trace[Elab.struct.modifyOp]! "{modifyOp}\nSource: {source}" + let rest := modifyOp[1] + if rest.isNone then + cont modifyOp[3] + else + let s ← `(s) + let valFirst := rest[0] + let valFirst := if valFirst.getKind == `Lean.Parser.Term.structInstArrayRef then valFirst else valFirst[1] + let restArgs := rest.getArgs + let valRest := mkNullNode restArgs[1:restArgs.size] + let valField := modifyOp.setArg 0 valFirst + let valField := valField.setArg 1 valRest + let valSource := source.modifyArg 0 fun _ => s + let val := stx.setArg 1 valSource + let val := val.setArg 2 $ mkNullNode #[valField] + trace[Elab.struct.modifyOp]! "{stx}\nval: {val}" + cont val /- Get structure name and elaborate explicit source (if available) -/ private def getStructName (stx : Syntax) (expectedType? : Option Expr) (sourceView : Source) : TermElabM Name := do -tryPostponeIfNoneOrMVar expectedType? -let useSource : Unit → TermElabM Name := fun _ => - match sourceView, expectedType? with - | Source.explicit _ src, _ => do - let srcType ← inferType src - let srcType ← whnf srcType - tryPostponeIfMVar srcType - match srcType.getAppFn with + tryPostponeIfNoneOrMVar expectedType? + let useSource : Unit → TermElabM Name := fun _ => + match sourceView, expectedType? with + | Source.explicit _ src, _ => do + let srcType ← inferType src + let srcType ← whnf srcType + tryPostponeIfMVar srcType + match srcType.getAppFn with + | Expr.const constName _ _ => pure constName + | _ => throwError! "invalid \{...} notation, source type is not of the form (C ...){indentExpr srcType}" + | _, some expectedType => throwError! "invalid \{...} notation, expected type is not of the form (C ...){indentExpr expectedType}" + | _, none => throwError! "invalid \{...} notation, expected type must be known" + match expectedType? with + | none => useSource () + | some expectedType => + let expectedType ← whnf expectedType + match expectedType.getAppFn with | Expr.const constName _ _ => pure constName - | _ => throwError! "invalid \{...} notation, source type is not of the form (C ...){indentExpr srcType}" - | _, some expectedType => throwError! "invalid \{...} notation, expected type is not of the form (C ...){indentExpr expectedType}" - | _, none => throwError! "invalid \{...} notation, expected type must be known" -match expectedType? with -| none => useSource () -| some expectedType => - let expectedType ← whnf expectedType - match expectedType.getAppFn with - | Expr.const constName _ _ => pure constName - | _ => useSource () + | _ => useSource () inductive FieldLHS -| fieldName (ref : Syntax) (name : Name) -| fieldIndex (ref : Syntax) (idx : Nat) -| modifyOp (ref : Syntax) (index : Syntax) + | fieldName (ref : Syntax) (name : Name) + | fieldIndex (ref : Syntax) (idx : Nat) + | modifyOp (ref : Syntax) (index : Syntax) -instance FieldLHS.inhabited : Inhabited FieldLHS := ⟨FieldLHS.fieldName (arbitrary _) (arbitrary _)⟩ -instance FieldLHS.hasFormat : HasFormat FieldLHS := -⟨fun lhs => match lhs with +instance : Inhabited FieldLHS := ⟨FieldLHS.fieldName (arbitrary _) (arbitrary _)⟩ +instance : HasFormat FieldLHS := ⟨fun lhs => + match lhs with | FieldLHS.fieldName _ n => fmt n | FieldLHS.fieldIndex _ i => fmt i | FieldLHS.modifyOp _ i => "[" ++ i.prettyPrint ++ "]"⟩ inductive FieldVal (σ : Type) -| term (stx : Syntax) : FieldVal σ -| nested (s : σ) : FieldVal σ -| default : FieldVal σ -- mark that field must be synthesized using default value + | term (stx : Syntax) : FieldVal σ + | nested (s : σ) : FieldVal σ + | default : FieldVal σ -- mark that field must be synthesized using default value structure Field (σ : Type) := -(ref : Syntax) (lhs : List FieldLHS) (val : FieldVal σ) (expr? : Option Expr := none) + (ref : Syntax) + (lhs : List FieldLHS) + (val : FieldVal σ) + (expr? : Option Expr := none) -instance Field.inhabited {σ} : Inhabited (Field σ) := ⟨⟨arbitrary _, [], FieldVal.term (arbitrary _), arbitrary _⟩⟩ +instance {σ} : Inhabited (Field σ) := ⟨⟨arbitrary _, [], FieldVal.term (arbitrary _), arbitrary _⟩⟩ def Field.isSimple {σ} : Field σ → Bool -| { lhs := [_], .. } => true -| _ => false + | { lhs := [_], .. } => true + | _ => false inductive Struct -| mk (ref : Syntax) (structName : Name) (fields : List (Field Struct)) (source : Source) + | mk (ref : Syntax) (structName : Name) (fields : List (Field Struct)) (source : Source) -instance Struct.inhabited : Inhabited Struct := ⟨⟨arbitrary _, arbitrary _, [], arbitrary _⟩⟩ +instance : Inhabited Struct := ⟨⟨arbitrary _, arbitrary _, [], arbitrary _⟩⟩ abbrev Fields := List (Field Struct) /- true if all fields of the given structure are marked as `default` -/ partial def Struct.allDefault : Struct → Bool -| ⟨_, _, fields, _⟩ => fields.all fun ⟨_, _, val, _⟩ => match val with - | FieldVal.term _ => false - | FieldVal.default => true - | FieldVal.nested s => allDefault s + | ⟨_, _, fields, _⟩ => fields.all fun ⟨_, _, val, _⟩ => match val with + | FieldVal.term _ => false + | FieldVal.default => true + | FieldVal.nested s => allDefault s def Struct.ref : Struct → Syntax -| ⟨ref, _, _, _⟩ => ref + | ⟨ref, _, _, _⟩ => ref def Struct.structName : Struct → Name -| ⟨_, structName, _, _⟩ => structName + | ⟨_, structName, _, _⟩ => structName def Struct.fields : Struct → Fields -| ⟨_, _, fields, _⟩ => fields + | ⟨_, _, fields, _⟩ => fields def Struct.source : Struct → Source -| ⟨_, _, _, s⟩ => s + | ⟨_, _, _, s⟩ => s def formatField (formatStruct : Struct → Format) (field : Field Struct) : Format := -Format.joinSep field.lhs " . " ++ " := " ++ - match field.val with - | FieldVal.term v => v.prettyPrint - | FieldVal.nested s => formatStruct s - | FieldVal.default => "" + Format.joinSep field.lhs " . " ++ " := " ++ + match field.val with + | FieldVal.term v => v.prettyPrint + | FieldVal.nested s => formatStruct s + | FieldVal.default => "" partial def formatStruct : Struct → Format -| ⟨_, structName, fields, source⟩ => - let fieldsFmt := Format.joinSep (fields.map (formatField formatStruct)) ", " - match source with - | Source.none => "{" ++ fieldsFmt ++ "}" - | Source.implicit _ => "{" ++ fieldsFmt ++ " .. }" - | Source.explicit _ src => "{" ++ format src ++ " with " ++ fieldsFmt ++ "}" + | ⟨_, structName, fields, source⟩ => + let fieldsFmt := Format.joinSep (fields.map (formatField formatStruct)) ", " + match source with + | Source.none => "{" ++ fieldsFmt ++ "}" + | Source.implicit _ => "{" ++ fieldsFmt ++ " .. }" + | Source.explicit _ src => "{" ++ format src ++ " with " ++ fieldsFmt ++ "}" -instance Struct.hasFormat : HasFormat Struct := ⟨formatStruct⟩ -instance Struct.hasToString : HasToString Struct := ⟨toString ∘ format⟩ +instance : HasFormat Struct := ⟨formatStruct⟩ +instance : HasToString Struct := ⟨toString ∘ format⟩ -instance Field.hasFormat : HasFormat (Field Struct) := ⟨formatField formatStruct⟩ -instance Field.hasToString : HasToString (Field Struct) := ⟨toString ∘ format⟩ +instance : HasFormat (Field Struct) := ⟨formatField formatStruct⟩ +instance : HasToString (Field Struct) := ⟨toString ∘ format⟩ /- Recall that `structInstField` elements have the form @@ -246,74 +248,74 @@ Recall that `structInstField` elements have the form -- Remark: this code relies on the fact that `expandStruct` only transforms `fieldLHS.fieldName` def FieldLHS.toSyntax (first : Bool) : FieldLHS → Syntax -| FieldLHS.modifyOp stx _ => stx -| FieldLHS.fieldName stx name => if first then mkIdentFrom stx name else mkNullNode #[mkAtomFrom stx ".", mkIdentFrom stx name] -| FieldLHS.fieldIndex stx _ => if first then stx else mkNullNode #[mkAtomFrom stx ".", stx] + | FieldLHS.modifyOp stx _ => stx + | FieldLHS.fieldName stx name => if first then mkIdentFrom stx name else mkNullNode #[mkAtomFrom stx ".", mkIdentFrom stx name] + | FieldLHS.fieldIndex stx _ => if first then stx else mkNullNode #[mkAtomFrom stx ".", stx] def FieldVal.toSyntax : FieldVal Struct → Syntax -| FieldVal.term stx => stx -| _ => unreachable! + | FieldVal.term stx => stx + | _ => unreachable! def Field.toSyntax : Field Struct → Syntax -| field => - let stx := field.ref - let stx := stx.setArg 3 field.val.toSyntax - match field.lhs with - | first::rest => - let stx := stx.setArg 0 $ first.toSyntax true - let stx := stx.setArg 1 $ mkNullNode $ rest.toArray.map (FieldLHS.toSyntax false) - stx - | _ => unreachable! + | field => + let stx := field.ref + let stx := stx.setArg 3 field.val.toSyntax + match field.lhs with + | first::rest => + let stx := stx.setArg 0 $ first.toSyntax true + let stx := stx.setArg 1 $ mkNullNode $ rest.toArray.map (FieldLHS.toSyntax false) + stx + | _ => unreachable! private def toFieldLHS (stx : Syntax) : Except String FieldLHS := -if stx.getKind == `Lean.Parser.Term.structInstArrayRef then - pure $ FieldLHS.modifyOp stx stx[1] -else - -- Note that the representation of the first field is different. - let stx := if stx.getKind == nullKind then stx[1] else stx - if stx.isIdent then pure $ FieldLHS.fieldName stx stx.getId.eraseMacroScopes - else match stx.isFieldIdx? with - | some idx => pure $ FieldLHS.fieldIndex stx idx - | none => throw "unexpected structure syntax" + if stx.getKind == `Lean.Parser.Term.structInstArrayRef then + pure $ FieldLHS.modifyOp stx stx[1] + else + -- Note that the representation of the first field is different. + let stx := if stx.getKind == nullKind then stx[1] else stx + if stx.isIdent then pure $ FieldLHS.fieldName stx stx.getId.eraseMacroScopes + else match stx.isFieldIdx? with + | some idx => pure $ FieldLHS.fieldIndex stx idx + | none => throw "unexpected structure syntax" private def mkStructView (stx : Syntax) (structName : Name) (source : Source) : Except String Struct := do -let args := stx[2].getArgs -let fieldsStx := args.filter $ fun arg => arg.getKind == `Lean.Parser.Term.structInstField -let fields ← fieldsStx.toList.mapM fun fieldStx => do - let val := fieldStx[3] - let first ← toFieldLHS fieldStx[0] - let rest ← fieldStx[1].getArgs.toList.mapM toFieldLHS - pure $ ({ref := fieldStx, lhs := first :: rest, val := FieldVal.term val } : Field Struct) -pure ⟨stx, structName, fields, source⟩ + let args := stx[2].getArgs + let fieldsStx := args.filter $ fun arg => arg.getKind == `Lean.Parser.Term.structInstField + let fields ← fieldsStx.toList.mapM fun fieldStx => do + let val := fieldStx[3] + let first ← toFieldLHS fieldStx[0] + let rest ← fieldStx[1].getArgs.toList.mapM toFieldLHS + pure $ ({ref := fieldStx, lhs := first :: rest, val := FieldVal.term val } : Field Struct) + pure ⟨stx, structName, fields, source⟩ def Struct.modifyFieldsM {m : Type → Type} [Monad m] (s : Struct) (f : Fields → m Fields) : m Struct := -match s with -| ⟨ref, structName, fields, source⟩ => do fields ← f fields; pure ⟨ref, structName, fields, source⟩ + match s with + | ⟨ref, structName, fields, source⟩ => do fields ← f fields; pure ⟨ref, structName, fields, source⟩ @[inline] def Struct.modifyFields (s : Struct) (f : Fields → Fields) : Struct := -Id.run $ s.modifyFieldsM f + Id.run $ s.modifyFieldsM f def Struct.setFields (s : Struct) (fields : Fields) : Struct := -s.modifyFields fun _ => fields + s.modifyFields fun _ => fields private def expandCompositeFields (s : Struct) : Struct := -s.modifyFields $ fun fields => fields.map $ fun field => match field with - | { lhs := FieldLHS.fieldName ref (Name.str Name.anonymous _ _) :: rest, .. } => field - | { lhs := FieldLHS.fieldName ref n@(Name.str _ _ _) :: rest, .. } => - let newEntries := n.components.map $ FieldLHS.fieldName ref - { field with lhs := newEntries ++ rest } - | _ => field + s.modifyFields fun fields => fields.map $ fun field => match field with + | { lhs := FieldLHS.fieldName ref (Name.str Name.anonymous _ _) :: rest, .. } => field + | { lhs := FieldLHS.fieldName ref n@(Name.str _ _ _) :: rest, .. } => + let newEntries := n.components.map $ FieldLHS.fieldName ref + { field with lhs := newEntries ++ rest } + | _ => field private def expandNumLitFields (s : Struct) : TermElabM Struct := -s.modifyFieldsM fun fields => do - let env ← getEnv - let fieldNames := getStructureFields env s.structName - fields.mapM fun field => match field with - | { lhs := FieldLHS.fieldIndex ref idx :: rest, .. } => - if idx == 0 then throwErrorAt ref "invalid field index, index must be greater than 0" - else if idx > fieldNames.size then throwErrorAt! ref "invalid field index, structure has only #{fieldNames.size} fields" - else pure { field with lhs := FieldLHS.fieldName ref fieldNames[idx - 1] :: rest } - | _ => pure field + s.modifyFieldsM fun fields => do + let env ← getEnv + let fieldNames := getStructureFields env s.structName + fields.mapM fun field => match field with + | { lhs := FieldLHS.fieldIndex ref idx :: rest, .. } => + if idx == 0 then throwErrorAt ref "invalid field index, index must be greater than 0" + else if idx > fieldNames.size then throwErrorAt! ref "invalid field index, structure has only #{fieldNames.size} fields" + else pure { field with lhs := FieldLHS.fieldName ref fieldNames[idx - 1] :: rest } + | _ => pure field /- For example, consider the following structures: ``` @@ -331,120 +333,119 @@ s.modifyFieldsM fun fields => do { toB.toA.x := 0, toB.y := 0, z := true : C } ``` -/ private def expandParentFields (s : Struct) : TermElabM Struct := do -let env ← getEnv -s.modifyFieldsM fun fields => fields.mapM fun field => match field with - | { lhs := FieldLHS.fieldName ref fieldName :: rest, .. } => - match findField? env s.structName fieldName with - | none => throwErrorAt! ref "'{fieldName}' is not a field of structure '{s.structName}'" - | some baseStructName => - if baseStructName == s.structName then pure field - else match getPathToBaseStructure? env baseStructName s.structName with - | some path => do - let path := path.map $ fun funName => match funName with - | Name.str _ s _ => FieldLHS.fieldName ref (mkNameSimple s) - | _ => unreachable! - pure { field with lhs := path ++ field.lhs } - | _ => throwErrorAt! ref "failed to access field '{fieldName}' in parent structure" - | _ => pure field + let env ← getEnv + s.modifyFieldsM fun fields => fields.mapM fun field => match field with + | { lhs := FieldLHS.fieldName ref fieldName :: rest, .. } => + match findField? env s.structName fieldName with + | none => throwErrorAt! ref "'{fieldName}' is not a field of structure '{s.structName}'" + | some baseStructName => + if baseStructName == s.structName then pure field + else match getPathToBaseStructure? env baseStructName s.structName with + | some path => do + let path := path.map $ fun funName => match funName with + | Name.str _ s _ => FieldLHS.fieldName ref (mkNameSimple s) + | _ => unreachable! + pure { field with lhs := path ++ field.lhs } + | _ => throwErrorAt! ref "failed to access field '{fieldName}' in parent structure" + | _ => pure field private abbrev FieldMap := HashMap Name Fields private def mkFieldMap (fields : Fields) : TermElabM FieldMap := -fields.foldlM (init := {}) fun fieldMap field => - match field.lhs with - | FieldLHS.fieldName _ fieldName :: rest => - match fieldMap.find? fieldName with - | some (prevField::restFields) => - if field.isSimple || prevField.isSimple then - throwErrorAt! field.ref "field '{fieldName}' has already beed specified" - else - pure $ fieldMap.insert fieldName (field::prevField::restFields) - | _ => pure $ fieldMap.insert fieldName [field] - | _ => unreachable! + fields.foldlM (init := {}) fun fieldMap field => + match field.lhs with + | FieldLHS.fieldName _ fieldName :: rest => + match fieldMap.find? fieldName with + | some (prevField::restFields) => + if field.isSimple || prevField.isSimple then + throwErrorAt! field.ref "field '{fieldName}' has already beed specified" + else + pure $ fieldMap.insert fieldName (field::prevField::restFields) + | _ => pure $ fieldMap.insert fieldName [field] + | _ => unreachable! private def isSimpleField? : Fields → Option (Field Struct) -| [field] => if field.isSimple then some field else none -| _ => none + | [field] => if field.isSimple then some field else none + | _ => none private def getFieldIdx (structName : Name) (fieldNames : Array Name) (fieldName : Name) : TermElabM Nat := do -match fieldNames.findIdx? $ fun n => n == fieldName with -| some idx => pure idx -| none => throwError! "field '{fieldName}' is not a valid field of '{structName}'" + match fieldNames.findIdx? $ fun n => n == fieldName with + | some idx => pure idx + | none => throwError! "field '{fieldName}' is not a valid field of '{structName}'" private def mkProjStx (s : Syntax) (fieldName : Name) : Syntax := -Syntax.node `Lean.Parser.Term.proj #[s, mkAtomFrom s ".", mkIdentFrom s fieldName] + Syntax.node `Lean.Parser.Term.proj #[s, mkAtomFrom s ".", mkIdentFrom s fieldName] private def mkSubstructSource (structName : Name) (fieldNames : Array Name) (fieldName : Name) (src : Source) : TermElabM Source := -match src with -| Source.explicit stx src => do - let idx ← getFieldIdx structName fieldNames fieldName - let stx := stx.modifyArg 0 fun stx => mkProjStx stx fieldName - pure $ Source.explicit stx (mkProj structName idx src) -| s => pure s + match src with + | Source.explicit stx src => do + let idx ← getFieldIdx structName fieldNames fieldName + let stx := stx.modifyArg 0 fun stx => mkProjStx stx fieldName + pure $ Source.explicit stx (mkProj structName idx src) + | s => pure s @[specialize] private def groupFields (expandStruct : Struct → TermElabM Struct) (s : Struct) : TermElabM Struct := do -let env ← getEnv -let fieldNames := getStructureFields env s.structName -withRef s.ref do -s.modifyFieldsM fun fields => do - let fieldMap ← mkFieldMap fields - fieldMap.toList.mapM fun ⟨fieldName, fields⟩ => do - match isSimpleField? fields with - | some field => pure field - | none => - let substructFields := fields.map fun field => { field with lhs := field.lhs.tail! } - let substructSource ← mkSubstructSource s.structName fieldNames fieldName s.source - let field := fields.head! - match Lean.isSubobjectField? env s.structName fieldName with - | some substructName => - let substruct := Struct.mk s.ref substructName substructFields substructSource - let substruct ← expandStruct substruct - pure { field with lhs := [field.lhs.head!], val := FieldVal.nested substruct } - | none => do - -- It is not a substructure field. Thus, we wrap fields using `Syntax`, and use `elabTerm` to process them. - let valStx := s.ref -- construct substructure syntax using s.ref as template - let valStx := valStx.setArg 4 mkNullNode -- erase optional expected type - let args := substructFields.toArray.map Field.toSyntax - let valStx := valStx.setArg 2 (mkSepStx args (mkAtomFrom s.ref ",")) - let valStx := setStructSourceSyntax valStx substructSource - pure { field with lhs := [field.lhs.head!], val := FieldVal.term valStx } + let env ← getEnv + let fieldNames := getStructureFields env s.structName + withRef s.ref do + s.modifyFieldsM fun fields => do + let fieldMap ← mkFieldMap fields + fieldMap.toList.mapM fun ⟨fieldName, fields⟩ => do + match isSimpleField? fields with + | some field => pure field + | none => + let substructFields := fields.map fun field => { field with lhs := field.lhs.tail! } + let substructSource ← mkSubstructSource s.structName fieldNames fieldName s.source + let field := fields.head! + match Lean.isSubobjectField? env s.structName fieldName with + | some substructName => + let substruct := Struct.mk s.ref substructName substructFields substructSource + let substruct ← expandStruct substruct + pure { field with lhs := [field.lhs.head!], val := FieldVal.nested substruct } + | none => do + -- It is not a substructure field. Thus, we wrap fields using `Syntax`, and use `elabTerm` to process them. + let valStx := s.ref -- construct substructure syntax using s.ref as template + let valStx := valStx.setArg 4 mkNullNode -- erase optional expected type + let args := substructFields.toArray.map Field.toSyntax + let valStx := valStx.setArg 2 (mkSepStx args (mkAtomFrom s.ref ",")) + let valStx := setStructSourceSyntax valStx substructSource + pure { field with lhs := [field.lhs.head!], val := FieldVal.term valStx } def findField? (fields : Fields) (fieldName : Name) : Option (Field Struct) := -fields.find? fun field => - match field.lhs with - | [FieldLHS.fieldName _ n] => n == fieldName - | _ => false + fields.find? fun field => + match field.lhs with + | [FieldLHS.fieldName _ n] => n == fieldName + | _ => false @[specialize] private def addMissingFields (expandStruct : Struct → TermElabM Struct) (s : Struct) : TermElabM Struct := do -let env ← getEnv -let fieldNames := getStructureFields env s.structName -let ref := s.ref -withRef ref do -let fields ← fieldNames.foldlM (init := []) fun fields fieldName => do - match findField? s.fields fieldName with - | some field => pure $ field::fields - | none => - let addField (val : FieldVal Struct) : TermElabM Fields := do - pure $ { ref := s.ref, lhs := [FieldLHS.fieldName s.ref fieldName], val := val } :: fields - match Lean.isSubobjectField? env s.structName fieldName with - | some substructName => do - let substructSource ← mkSubstructSource s.structName fieldNames fieldName s.source - let substruct := Struct.mk s.ref substructName [] substructSource - let substruct ← expandStruct substruct - addField (FieldVal.nested substruct) - | none => - match s.source with - | Source.none => addField FieldVal.default - | Source.implicit _ => addField (FieldVal.term (mkHole s.ref)) - | Source.explicit stx _ => - -- stx is of the form `optional (try (termParser >> "with"))` - let src := stx[0] - let val := mkProjStx src fieldName - addField (FieldVal.term val) -pure $ s.setFields fields.reverse + let env ← getEnv + let fieldNames := getStructureFields env s.structName + let ref := s.ref + withRef ref do + let fields ← fieldNames.foldlM (init := []) fun fields fieldName => do + match findField? s.fields fieldName with + | some field => pure $ field::fields + | none => + let addField (val : FieldVal Struct) : TermElabM Fields := do + pure $ { ref := s.ref, lhs := [FieldLHS.fieldName s.ref fieldName], val := val } :: fields + match Lean.isSubobjectField? env s.structName fieldName with + | some substructName => do + let substructSource ← mkSubstructSource s.structName fieldNames fieldName s.source + let substruct := Struct.mk s.ref substructName [] substructSource + let substruct ← expandStruct substruct + addField (FieldVal.nested substruct) + | none => + match s.source with + | Source.none => addField FieldVal.default + | Source.implicit _ => addField (FieldVal.term (mkHole s.ref)) + | Source.explicit stx _ => + -- stx is of the form `optional (try (termParser >> "with"))` + let src := stx[0] + let val := mkProjStx src fieldName + addField (FieldVal.term val) + pure $ s.setFields fields.reverse -private partial def expandStruct : Struct → TermElabM Struct -| s => do +private partial def expandStruct (s : Struct) : TermElabM Struct := do let s := expandCompositeFields s let s ← expandNumLitFields s let s ← expandParentFields s @@ -452,67 +453,66 @@ private partial def expandStruct : Struct → TermElabM Struct addMissingFields expandStruct s structure CtorHeaderResult := -(ctorFn : Expr) -(ctorFnType : Expr) -(instMVars : Array MVarId := #[]) + (ctorFn : Expr) + (ctorFnType : Expr) + (instMVars : Array MVarId := #[]) private def mkCtorHeaderAux : Nat → Expr → Expr → Array MVarId → TermElabM CtorHeaderResult -| 0, type, ctorFn, instMVars => pure { ctorFn := ctorFn, ctorFnType := type, instMVars := instMVars } -| n+1, type, ctorFn, instMVars => do - let type ← whnfForall type - match type with - | Expr.forallE _ d b c => - match c.binderInfo with - | BinderInfo.instImplicit => - let a ← mkFreshExprMVar d MetavarKind.synthetic - mkCtorHeaderAux n (b.instantiate1 a) (mkApp ctorFn a) (instMVars.push a.mvarId!) - | _ => - let a ← mkFreshExprMVar d - mkCtorHeaderAux n (b.instantiate1 a) (mkApp ctorFn a) instMVars - | _ => throwError "unexpected constructor type" + | 0, type, ctorFn, instMVars => pure { ctorFn := ctorFn, ctorFnType := type, instMVars := instMVars } + | n+1, type, ctorFn, instMVars => do + let type ← whnfForall type + match type with + | Expr.forallE _ d b c => + match c.binderInfo with + | BinderInfo.instImplicit => + let a ← mkFreshExprMVar d MetavarKind.synthetic + mkCtorHeaderAux n (b.instantiate1 a) (mkApp ctorFn a) (instMVars.push a.mvarId!) + | _ => + let a ← mkFreshExprMVar d + mkCtorHeaderAux n (b.instantiate1 a) (mkApp ctorFn a) instMVars + | _ => throwError "unexpected constructor type" private partial def getForallBody : Nat → Expr → Option Expr -| i+1, Expr.forallE _ _ b _ => getForallBody i b -| i+1, _ => none -| 0, type => type + | i+1, Expr.forallE _ _ b _ => getForallBody i b + | i+1, _ => none + | 0, type => type private def propagateExpectedType (type : Expr) (numFields : Nat) (expectedType? : Option Expr) : TermElabM Unit := -match expectedType? with -| none => pure () -| some expectedType => do - match getForallBody numFields type with - | none => pure () - | some typeBody => - unless typeBody.hasLooseBVars do - isDefEq expectedType typeBody - pure () + match expectedType? with + | none => pure () + | some expectedType => do + match getForallBody numFields type with + | none => pure () + | some typeBody => + unless typeBody.hasLooseBVars do + isDefEq expectedType typeBody + pure () private def mkCtorHeader (ctorVal : ConstructorVal) (expectedType? : Option Expr) : TermElabM CtorHeaderResult := do -let lvls ← ctorVal.lparams.mapM fun _ => mkFreshLevelMVar -let val := Lean.mkConst ctorVal.name lvls -let type := (ConstantInfo.ctorInfo ctorVal).instantiateTypeLevelParams lvls -let r ← mkCtorHeaderAux ctorVal.nparams type val #[] -propagateExpectedType r.ctorFnType ctorVal.nfields expectedType? -synthesizeAppInstMVars r.instMVars -pure r + let lvls ← ctorVal.lparams.mapM fun _ => mkFreshLevelMVar + let val := Lean.mkConst ctorVal.name lvls + let type := (ConstantInfo.ctorInfo ctorVal).instantiateTypeLevelParams lvls + let r ← mkCtorHeaderAux ctorVal.nparams type val #[] + propagateExpectedType r.ctorFnType ctorVal.nfields expectedType? + synthesizeAppInstMVars r.instMVars + pure r def markDefaultMissing (e : Expr) : Expr := -mkAnnotation `structInstDefault e + mkAnnotation `structInstDefault e def defaultMissing? (e : Expr) : Option Expr := -annotation? `structInstDefault e + annotation? `structInstDefault e def throwFailedToElabField {α} (fieldName : Name) (structName : Name) (msgData : MessageData) : TermElabM α := -throwError! "failed to elaborate field '{fieldName}' of '{structName}, {msgData}" + throwError! "failed to elaborate field '{fieldName}' of '{structName}, {msgData}" def trySynthStructInstance? (s : Struct) (expectedType : Expr) : TermElabM (Option Expr) := do -if !s.allDefault then - pure none -else - try synthInstance? expectedType catch _ => pure none + if !s.allDefault then + pure none + else + try synthInstance? expectedType catch _ => pure none -private partial def elabStruct : Struct → Option Expr → TermElabM (Expr × Struct) -| s, expectedType? => withRef s.ref do +private partial def elabStruct (s : Struct) (expectedType? : Option Expr) : TermElabM (Expr × Struct) := withRef s.ref do let env ← getEnv let ctorVal := getStructureCtor env s.structName let { ctorFn := ctorFn, ctorFnType := ctorFnType, .. } ← mkCtorHeader ctorVal expectedType? @@ -542,55 +542,52 @@ private partial def elabStruct : Struct → Option Expr → TermElabM (Expr × S namespace DefaultFields structure Context := --- We must search for default values overriden in derived structures -(structs : Array Struct := #[]) -(allStructNames : Array Name := #[]) -/- -Consider the following example: -``` -structure A := -(x : Nat := 1) + -- We must search for default values overriden in derived structures + (structs : Array Struct := #[]) + (allStructNames : Array Name := #[]) + /- + Consider the following example: + ``` + structure A := + (x : Nat := 1) -structure B extends A := -(y : Nat := x + 1) (x := y + 1) + structure B extends A := + (y : Nat := x + 1) (x := y + 1) -structure C extends B := -(z : Nat := 2*y) (x := z + 3) -``` -And we are trying to elaborate a structure instance for `C`. There are default values for `x` at `A`, `B`, and `C`. -We say the default value at `C` has distance 0, the one at `B` distance 1, and the one at `A` distance 2. -The field `maxDistance` specifies the maximum distance considered in a round of Default field computation. -Remark: since `C` does not set a default value of `y`, the default value at `B` is at distance 0. + structure C extends B := + (z : Nat := 2*y) (x := z + 3) + ``` + And we are trying to elaborate a structure instance for `C`. There are default values for `x` at `A`, `B`, and `C`. + We say the default value at `C` has distance 0, the one at `B` distance 1, and the one at `A` distance 2. + The field `maxDistance` specifies the maximum distance considered in a round of Default field computation. + Remark: since `C` does not set a default value of `y`, the default value at `B` is at distance 0. -The fixpoint for setting default values works in the following way. -- Keep computing default values using `maxDistance == 0`. -- We increase `maxDistance` whenever we failed to compute a new default value in a round. -- If `maxDistance > 0`, then we interrupt a round as soon as we compute some default value. - We use depth-first search. -- We sign an error if no progress is made when `maxDistance` == structure hierarchy depth (2 in the example above). --/ -(maxDistance : Nat := 0) + The fixpoint for setting default values works in the following way. + - Keep computing default values using `maxDistance == 0`. + - We increase `maxDistance` whenever we failed to compute a new default value in a round. + - If `maxDistance > 0`, then we interrupt a round as soon as we compute some default value. + We use depth-first search. + - We sign an error if no progress is made when `maxDistance` == structure hierarchy depth (2 in the example above). + -/ + (maxDistance : Nat := 0) structure State := -(progress : Bool := false) + (progress : Bool := false) -partial def collectStructNames : Struct → Array Name → Array Name -| struct, names => +partial def collectStructNames (struct : Struct) (names : Array Name) : Array Name := let names := names.push struct.structName struct.fields.foldl (init := names) fun names field => match field.val with | FieldVal.nested struct => collectStructNames struct names | _ => names -partial def getHierarchyDepth : Struct → Nat -| struct => +partial def getHierarchyDepth (struct : Struct) : Nat := struct.fields.foldl (init := 0) fun max field => match field.val with | FieldVal.nested struct => Nat.max max (getHierarchyDepth struct + 1) | _ => max -partial def findDefaultMissing? (mctx : MetavarContext) : Struct → Option (Field Struct) -| struct => +partial def findDefaultMissing? (mctx : MetavarContext) (struct : Struct) : Option (Field Struct) := struct.fields.findSome? fun field => match field.val with | FieldVal.nested struct => findDefaultMissing? mctx struct @@ -601,141 +598,140 @@ partial def findDefaultMissing? (mctx : MetavarContext) : Struct → Option (Fie | _ => none def getFieldName (field : Field Struct) : Name := -match field.lhs with -| [FieldLHS.fieldName _ fieldName] => fieldName -| _ => unreachable! + match field.lhs with + | [FieldLHS.fieldName _ fieldName] => fieldName + | _ => unreachable! abbrev M := ReaderT Context (StateRefT State TermElabM) def isRoundDone : M Bool := do -return (← get).progress && (← read).maxDistance > 0 + return (← get).progress && (← read).maxDistance > 0 def getFieldValue? (struct : Struct) (fieldName : Name) : Option Expr := -struct.fields.findSome? fun field => - if getFieldName field == fieldName then - field.expr? - else - none + struct.fields.findSome? fun field => + if getFieldName field == fieldName then + field.expr? + else + none partial def mkDefaultValueAux? (struct : Struct) : Expr → TermElabM (Option Expr) -| Expr.lam n d b c => withRef struct.ref do - if c.binderInfo.isExplicit then - let fieldName := n - match getFieldValue? struct fieldName with - | none => pure none - | some val => - let valType ← inferType val - if (← isDefEq valType d) then - mkDefaultValueAux? struct (b.instantiate1 val) - else - pure none - else - let arg ← mkFreshExprMVar d - mkDefaultValueAux? struct (b.instantiate1 arg) -| e => - if e.isAppOfArity `id 2 then - pure (some e.appArg!) - else - pure (some e) + | Expr.lam n d b c => withRef struct.ref do + if c.binderInfo.isExplicit then + let fieldName := n + match getFieldValue? struct fieldName with + | none => pure none + | some val => + let valType ← inferType val + if (← isDefEq valType d) then + mkDefaultValueAux? struct (b.instantiate1 val) + else + pure none + else + let arg ← mkFreshExprMVar d + mkDefaultValueAux? struct (b.instantiate1 arg) + | e => + if e.isAppOfArity `id 2 then + pure (some e.appArg!) + else + pure (some e) def mkDefaultValue? (struct : Struct) (cinfo : ConstantInfo) : TermElabM (Option Expr) := -withRef struct.ref do -let us ← cinfo.lparams.mapM fun _ => mkFreshLevelMVar -mkDefaultValueAux? struct (cinfo.instantiateValueLevelParams us) + withRef struct.ref do + let us ← cinfo.lparams.mapM fun _ => mkFreshLevelMVar + mkDefaultValueAux? struct (cinfo.instantiateValueLevelParams us) /-- If `e` is a projection function of one of the given structures, then reduce it -/ def reduceProjOf? (structNames : Array Name) (e : Expr) : MetaM (Option Expr) := do -if !e.isApp then pure none -else match e.getAppFn with - | Expr.const name _ _ => do - let env ← getEnv - match env.getProjectionStructureName? name with - | some structName => - if structNames.contains structName then - Meta.unfoldDefinition? e - else - pure none - | none => pure none - | _ => pure none + if !e.isApp then pure none + else match e.getAppFn with + | Expr.const name _ _ => do + let env ← getEnv + match env.getProjectionStructureName? name with + | some structName => + if structNames.contains structName then + Meta.unfoldDefinition? e + else + pure none + | none => pure none + | _ => pure none /-- Reduce default value. It performs beta reduction and projections of the given structures. -/ partial def reduce (structNames : Array Name) : Expr → MetaM Expr -| e@(Expr.lam _ _ _ _) => lambdaLetTelescope e fun xs b => do mkLambdaFVars xs (← reduce structNames b) -| e@(Expr.forallE _ _ _ _) => forallTelescope e fun xs b => do mkForallFVars xs (← reduce structNames b) -| e@(Expr.letE _ _ _ _ _) => lambdaLetTelescope e fun xs b => do mkLetFVars xs (← reduce structNames b) -| e@(Expr.proj _ i b _) => do - match (← Meta.reduceProj? b i) with - | some r => reduce structNames r - | none => pure $ e.updateProj! (← reduce structNames b) -| e@(Expr.app f _ _) => do - match (← reduceProjOf? structNames e) with - | some r => reduce structNames r - | none => - let f := f.getAppFn - let f' ← reduce structNames f - if f'.isLambda then - let revArgs := e.getAppRevArgs - reduce structNames (f'.betaRev revArgs) + | e@(Expr.lam _ _ _ _) => lambdaLetTelescope e fun xs b => do mkLambdaFVars xs (← reduce structNames b) + | e@(Expr.forallE _ _ _ _) => forallTelescope e fun xs b => do mkForallFVars xs (← reduce structNames b) + | e@(Expr.letE _ _ _ _ _) => lambdaLetTelescope e fun xs b => do mkLetFVars xs (← reduce structNames b) + | e@(Expr.proj _ i b _) => do + match (← Meta.reduceProj? b i) with + | some r => reduce structNames r + | none => pure $ e.updateProj! (← reduce structNames b) + | e@(Expr.app f _ _) => do + match (← reduceProjOf? structNames e) with + | some r => reduce structNames r + | none => + let f := f.getAppFn + let f' ← reduce structNames f + if f'.isLambda then + let revArgs := e.getAppRevArgs + reduce structNames (f'.betaRev revArgs) + else + let args ← e.getAppArgs.mapM (reduce structNames) + pure (mkAppN f' args) + | e@(Expr.mdata _ b _) => do + let b ← reduce structNames b + if (defaultMissing? e).isSome && !b.isMVar then + pure b else - let args ← e.getAppArgs.mapM (reduce structNames) - pure (mkAppN f' args) -| e@(Expr.mdata _ b _) => do - let b ← reduce structNames b - if (defaultMissing? e).isSome && !b.isMVar then - pure b - else - pure $ e.updateMData! b -| e@(Expr.mvar mvarId _) => do - match (← getExprMVarAssignment? mvarId) with - | some val => if val.isMVar then reduce structNames val else pure val - | none => pure e -| e => pure e + pure $ e.updateMData! b + | e@(Expr.mvar mvarId _) => do + match (← getExprMVarAssignment? mvarId) with + | some val => if val.isMVar then reduce structNames val else pure val + | none => pure e + | e => pure e partial def tryToSynthesizeDefault (structs : Array Struct) (allStructNames : Array Name) (maxDistance : Nat) (fieldName : Name) (mvarId : MVarId) : TermElabM Bool := -let rec loop (i : Nat) (dist : Nat) := do - if dist > maxDistance then - pure false - else if h : i < structs.size then do - let struct := structs.get ⟨i, h⟩ - let defaultName := struct.structName ++ fieldName ++ `_default - let env ← getEnv - match env.find? defaultName with - | some cinfo@(ConstantInfo.defnInfo defVal) => do - let mctx ← getMCtx - let val? ← mkDefaultValue? struct cinfo - match val? with - | none => do setMCtx mctx; loop (i+1) (dist+1) - | some val => do - let val ← liftMetaM $ reduce allStructNames val - match val.find? fun e => (defaultMissing? e).isSome with - | some _ => setMCtx mctx; loop (i+1) (dist+1) - | none => - let mvarDecl ← getMVarDecl mvarId - let val ← ensureHasType mvarDecl.type val - assignExprMVar mvarId val - pure true - | _ => loop (i+1) dist - else - pure false -loop 0 0 + let rec loop (i : Nat) (dist : Nat) := do + if dist > maxDistance then + pure false + else if h : i < structs.size then do + let struct := structs.get ⟨i, h⟩ + let defaultName := struct.structName ++ fieldName ++ `_default + let env ← getEnv + match env.find? defaultName with + | some cinfo@(ConstantInfo.defnInfo defVal) => do + let mctx ← getMCtx + let val? ← mkDefaultValue? struct cinfo + match val? with + | none => do setMCtx mctx; loop (i+1) (dist+1) + | some val => do + let val ← liftMetaM $ reduce allStructNames val + match val.find? fun e => (defaultMissing? e).isSome with + | some _ => setMCtx mctx; loop (i+1) (dist+1) + | none => + let mvarDecl ← getMVarDecl mvarId + let val ← ensureHasType mvarDecl.type val + assignExprMVar mvarId val + pure true + | _ => loop (i+1) dist + else + pure false + loop 0 0 -partial def step : Struct → M Unit -| struct => unlessM isRoundDone $ withReader (fun ctx => { ctx with structs := ctx.structs.push struct }) do - struct.fields.forM fun field => do - match field.val with - | FieldVal.nested struct => step struct - | _ => match field.expr? with - | none => unreachable! - | some expr => match defaultMissing? expr with - | some (Expr.mvar mvarId _) => - unless (← isExprMVarAssigned mvarId) do - let ctx ← read - if (← withRef field.ref $ tryToSynthesizeDefault ctx.structs ctx.allStructNames ctx.maxDistance (getFieldName field) mvarId) then - modify fun s => { s with progress := true } - | _ => pure () +partial def step (struct : Struct) : M Unit := do + unlessM isRoundDone $ withReader (fun ctx => { ctx with structs := ctx.structs.push struct }) do + struct.fields.forM fun field => do + match field.val with + | FieldVal.nested struct => step struct + | _ => match field.expr? with + | none => unreachable! + | some expr => match defaultMissing? expr with + | some (Expr.mvar mvarId _) => + unless (← isExprMVarAssigned mvarId) do + let ctx ← read + if (← withRef field.ref $ tryToSynthesizeDefault ctx.structs ctx.allStructNames ctx.maxDistance (getFieldName field) mvarId) then + modify fun s => { s with progress := true } + | _ => pure () -partial def propagateLoop (hierarchyDepth : Nat) : Nat → Struct → M Unit -| d, struct => do +partial def propagateLoop (hierarchyDepth : Nat) (d : Nat) (struct : Struct) : M Unit := do match findDefaultMissing? (← getMCtx) struct with | none => pure () -- Done | some field => @@ -750,27 +746,26 @@ partial def propagateLoop (hierarchyDepth : Nat) : Nat → Struct → M Unit propagateLoop hierarchyDepth (d+1) struct def propagate (struct : Struct) : TermElabM Unit := -let hierarchyDepth := getHierarchyDepth struct -let structNames := collectStructNames struct #[] -(propagateLoop hierarchyDepth 0 struct { allStructNames := structNames }).run' {} + let hierarchyDepth := getHierarchyDepth struct + let structNames := collectStructNames struct #[] + (propagateLoop hierarchyDepth 0 struct { allStructNames := structNames }).run' {} end DefaultFields private def elabStructInstAux (stx : Syntax) (expectedType? : Option Expr) (source : Source) : TermElabM Expr := do -let structName ← getStructName stx expectedType? source -unless isStructureLike (← getEnv) structName do - throwError! "invalid \{...} notation, '{structName}' is not a structure" -match mkStructView stx structName source with -| Except.error ex => throwError ex -| Except.ok struct => - let struct ← expandStruct struct - trace[Elab.struct]! "{struct}" - let (r, struct) ← elabStruct struct expectedType? - DefaultFields.propagate struct - pure r + let structName ← getStructName stx expectedType? source + unless isStructureLike (← getEnv) structName do + throwError! "invalid \{...} notation, '{structName}' is not a structure" + match mkStructView stx structName source with + | Except.error ex => throwError ex + | Except.ok struct => + let struct ← expandStruct struct + trace[Elab.struct]! "{struct}" + let (r, struct) ← elabStruct struct expectedType? + DefaultFields.propagate struct + pure r -@[builtinTermElab structInst] def elabStructInst : TermElab := -fun stx expectedType? => do +@[builtinTermElab structInst] def elabStructInst : TermElab := fun stx expectedType? => do match (← expandNonAtomicExplicitSource stx) with | some stxNew => withMacroExpansion stx stxNew $ elabTerm stxNew expectedType? | none => diff --git a/src/Lean/Elab/Structure.lean b/src/Lean/Elab/Structure.lean index 24fe8b0273..8d58749f2e 100644 --- a/src/Lean/Elab/Structure.lean +++ b/src/Lean/Elab/Structure.lean @@ -21,74 +21,74 @@ parser! (structureTk <|> classTk) >> declId >> many Term.bracketedBinder >> opti -/ structure StructCtorView := -(ref : Syntax) -(modifiers : Modifiers) -(inferMod : Bool) -- true if `{}` is used in the constructor declaration -(name : Name) -(declName : Name) + (ref : Syntax) + (modifiers : Modifiers) + (inferMod : Bool) -- true if `{}` is used in the constructor declaration + (name : Name) + (declName : Name) structure StructFieldView := -(ref : Syntax) -(modifiers : Modifiers) -(binderInfo : BinderInfo) -(inferMod : Bool) -(declName : Name) -(name : Name) -(binders : Syntax) -(type? : Option Syntax) -(value? : Option Syntax) + (ref : Syntax) + (modifiers : Modifiers) + (binderInfo : BinderInfo) + (inferMod : Bool) + (declName : Name) + (name : Name) + (binders : Syntax) + (type? : Option Syntax) + (value? : Option Syntax) structure StructView := -(ref : Syntax) -(modifiers : Modifiers) -(scopeLevelNames : List Name) -- All `universe` declarations in the current scope -(allUserLevelNames : List Name) -- `scopeLevelNames` ++ explicit universe parameters provided in the `structure` command -(isClass : Bool) -(declName : Name) -(scopeVars : Array Expr) -- All `variable` declaration in the current scope -(params : Array Expr) -- Explicit parameters provided in the `structure` command -(parents : Array Syntax) -(type : Syntax) -(ctor : StructCtorView) -(fields : Array StructFieldView) + (ref : Syntax) + (modifiers : Modifiers) + (scopeLevelNames : List Name) -- All `universe` declarations in the current scope + (allUserLevelNames : List Name) -- `scopeLevelNames` ++ explicit universe parameters provided in the `structure` command + (isClass : Bool) + (declName : Name) + (scopeVars : Array Expr) -- All `variable` declaration in the current scope + (params : Array Expr) -- Explicit parameters provided in the `structure` command + (parents : Array Syntax) + (type : Syntax) + (ctor : StructCtorView) + (fields : Array StructFieldView) inductive StructFieldKind -| newField | fromParent | subobject + | newField | fromParent | subobject structure StructFieldInfo := -(name : Name) -(declName : Name) -- Remark: this field value doesn't matter for fromParent fields. -(fvar : Expr) -(kind : StructFieldKind) -(inferMod : Bool := false) -(value? : Option Expr := none) + (name : Name) + (declName : Name) -- Remark: this field value doesn't matter for fromParent fields. + (fvar : Expr) + (kind : StructFieldKind) + (inferMod : Bool := false) + (value? : Option Expr := none) -instance StructFieldInfo.inhabited : Inhabited StructFieldInfo := -⟨{ name := arbitrary _, declName := arbitrary _, fvar := arbitrary _, kind := StructFieldKind.newField }⟩ +instance : Inhabited StructFieldInfo := + ⟨{ name := arbitrary _, declName := arbitrary _, fvar := arbitrary _, kind := StructFieldKind.newField }⟩ def StructFieldInfo.isFromParent (info : StructFieldInfo) : Bool := -match info.kind with -| StructFieldKind.fromParent => true -| _ => false + match info.kind with + | StructFieldKind.fromParent => true + | _ => false def StructFieldInfo.isSubobject (info : StructFieldInfo) : Bool := -match info.kind with -| StructFieldKind.subobject => true -| _ => false + match info.kind with + | StructFieldKind.subobject => true + | _ => false /- Auxiliary declaration for `mkProjections` -/ structure ProjectionInfo := -(declName : Name) -(inferMod : Bool) + (declName : Name) + (inferMod : Bool) structure ElabStructResult := -(decl : Declaration) -(projInfos : List ProjectionInfo) -(projInstances : List Name) -- projections (to parent classes) that must be marked as instances. -(mctx : MetavarContext) -(lctx : LocalContext) -(localInsts : LocalInstances) -(defaultAuxDecls : Array (Name × Expr × Expr)) + (decl : Declaration) + (projInfos : List ProjectionInfo) + (projInstances : List Name) -- projections (to parent classes) that must be marked as instances. + (mctx : MetavarContext) + (lctx : LocalContext) + (localInsts : LocalInstances) + (defaultAuxDecls : Array (Name × Expr × Expr)) private def defaultCtorName := `mk @@ -99,35 +99,35 @@ parser! try (declModifiers >> ident >> optional inferMod >> " :: ") ``` -/ private def expandCtor (structStx : Syntax) (structModifiers : Modifiers) (structDeclName : Name) : CommandElabM StructCtorView := -let optCtor := structStx[6] -if optCtor.isNone then - pure { ref := structStx, modifiers := {}, inferMod := false, name := defaultCtorName, declName := structDeclName ++ defaultCtorName } -else - let ctor := optCtor[0] - withRef ctor do - let ctorModifiers ← elabModifiers ctor[0] - checkValidCtorModifier ctorModifiers - if ctorModifiers.isPrivate && structModifiers.isPrivate then - throwError "invalid 'private' constructor in a 'private' structure" - if ctorModifiers.isProtected && structModifiers.isPrivate then - throwError "invalid 'protected' constructor in a 'private' structure" - let inferMod := !ctor[2].isNone - let name := ctor[1].getId - let declName := structDeclName ++ name - let declName ← applyVisibility ctorModifiers.visibility declName - pure { ref := ctor, name := name, modifiers := ctorModifiers, inferMod := inferMod, declName := declName } + let optCtor := structStx[6] + if optCtor.isNone then + pure { ref := structStx, modifiers := {}, inferMod := false, name := defaultCtorName, declName := structDeclName ++ defaultCtorName } + else + let ctor := optCtor[0] + withRef ctor do + let ctorModifiers ← elabModifiers ctor[0] + checkValidCtorModifier ctorModifiers + if ctorModifiers.isPrivate && structModifiers.isPrivate then + throwError "invalid 'private' constructor in a 'private' structure" + if ctorModifiers.isProtected && structModifiers.isPrivate then + throwError "invalid 'protected' constructor in a 'private' structure" + let inferMod := !ctor[2].isNone + let name := ctor[1].getId + let declName := structDeclName ++ name + let declName ← applyVisibility ctorModifiers.visibility declName + pure { ref := ctor, name := name, modifiers := ctorModifiers, inferMod := inferMod, declName := declName } def checkValidFieldModifier (modifiers : Modifiers) : CommandElabM Unit := do -if modifiers.isNoncomputable then - throwError "invalid use of 'noncomputable' in field declaration" -if modifiers.isPartial then - throwError "invalid use of 'partial' in field declaration" -if modifiers.isUnsafe then - throwError "invalid use of 'unsafe' in field declaration" -if modifiers.attrs.size != 0 then - throwError "invalid use of attributes in field declaration" -if modifiers.isPrivate then - throwError "private fields are not supported yet" + if modifiers.isNoncomputable then + throwError "invalid use of 'noncomputable' in field declaration" + if modifiers.isPartial then + throwError "invalid use of 'partial' in field declaration" + if modifiers.isUnsafe then + throwError "invalid use of 'unsafe' in field declaration" + if modifiers.attrs.size != 0 then + throwError "invalid use of attributes in field declaration" + if modifiers.isPrivate then + throwError "private fields are not supported yet" /- ``` @@ -138,94 +138,92 @@ def structFields := parser! many (structExplicitBinder <|> structImplici ``` -/ private def expandFields (structStx : Syntax) (structModifiers : Modifiers) (structDeclName : Name) : CommandElabM (Array StructFieldView) := -let fieldBinders := structStx[7][0].getArgs -fieldBinders.foldlM (init := #[]) fun (views : Array StructFieldView) fieldBinder => withRef fieldBinder do - let k := fieldBinder.getKind - let binfo ← - if k == `Lean.Parser.Command.structExplicitBinder then pure BinderInfo.default - else if k == `Lean.Parser.Command.structImplicitBinder then pure BinderInfo.implicit - else if k == `Lean.Parser.Command.structInstBinder then pure BinderInfo.instImplicit - else throwError "unexpected kind of structure field" - let fieldModifiers ← elabModifiers fieldBinder[0] - checkValidFieldModifier fieldModifiers - if fieldModifiers.isPrivate && structModifiers.isPrivate then - throwError "invalid 'private' field in a 'private' structure" - if fieldModifiers.isProtected && structModifiers.isPrivate then - throwError "invalid 'protected' field in a 'private' structure" - let inferMod := !fieldBinder[3].isNone - let (binders, type?) := - if binfo == BinderInfo.default then - expandOptDeclSig fieldBinder[4] - else - let (binders, type) := expandDeclSig fieldBinder[4] - (binders, some type) - let value? := - if binfo != BinderInfo.default then none - else - let optBinderDefault := fieldBinder[5] - if optBinderDefault.isNone then none + let fieldBinders := structStx[7][0].getArgs + fieldBinders.foldlM (init := #[]) fun (views : Array StructFieldView) fieldBinder => withRef fieldBinder do + let k := fieldBinder.getKind + let binfo ← + if k == `Lean.Parser.Command.structExplicitBinder then pure BinderInfo.default + else if k == `Lean.Parser.Command.structImplicitBinder then pure BinderInfo.implicit + else if k == `Lean.Parser.Command.structInstBinder then pure BinderInfo.instImplicit + else throwError "unexpected kind of structure field" + let fieldModifiers ← elabModifiers fieldBinder[0] + checkValidFieldModifier fieldModifiers + if fieldModifiers.isPrivate && structModifiers.isPrivate then + throwError "invalid 'private' field in a 'private' structure" + if fieldModifiers.isProtected && structModifiers.isPrivate then + throwError "invalid 'protected' field in a 'private' structure" + let inferMod := !fieldBinder[3].isNone + let (binders, type?) := + if binfo == BinderInfo.default then + expandOptDeclSig fieldBinder[4] else - -- binderDefault := parser! " := " >> termParser - some optBinderDefault[0][1] - let idents := fieldBinder[2].getArgs - idents.foldlM (init := views) fun (views : Array StructFieldView) ident => withRef ident do - let name := ident.getId - if isInternalSubobjectFieldName name then - throwError! "invalid field name '{name}', identifiers starting with '_' are reserved to the system" - let declName := structDeclName ++ name - let declName ← applyVisibility fieldModifiers.visibility declName - pure $ views.push { - ref := ident, - modifiers := fieldModifiers, - binderInfo := binfo, - inferMod := inferMod, - declName := declName, - name := name, - binders := binders, - type? := type?, - value? := value? } - + let (binders, type) := expandDeclSig fieldBinder[4] + (binders, some type) + let value? := + if binfo != BinderInfo.default then none + else + let optBinderDefault := fieldBinder[5] + if optBinderDefault.isNone then none + else + -- binderDefault := parser! " := " >> termParser + some optBinderDefault[0][1] + let idents := fieldBinder[2].getArgs + idents.foldlM (init := views) fun (views : Array StructFieldView) ident => withRef ident do + let name := ident.getId + if isInternalSubobjectFieldName name then + throwError! "invalid field name '{name}', identifiers starting with '_' are reserved to the system" + let declName := structDeclName ++ name + let declName ← applyVisibility fieldModifiers.visibility declName + pure $ views.push { + ref := ident, + modifiers := fieldModifiers, + binderInfo := binfo, + inferMod := inferMod, + declName := declName, + name := name, + binders := binders, + type? := type?, + value? := value? } private def validStructType (type : Expr) : Bool := -match type with -| Expr.sort (Level.succ _ _) _ => true -| _ => false + match type with + | Expr.sort (Level.succ _ _) _ => true + | _ => false private def checkParentIsStructure (parent : Expr) : TermElabM Name := -match parent.getAppFn with -| Expr.const c _ _ => do - unless isStructure (← getEnv) c do - throwError! "'{c}' is not a structure" - pure c -| _ => throwError "expected structure" + match parent.getAppFn with + | Expr.const c _ _ => do + unless isStructure (← getEnv) c do + throwError! "'{c}' is not a structure" + pure c + | _ => throwError "expected structure" private def findFieldInfo? (infos : Array StructFieldInfo) (fieldName : Name) : Option StructFieldInfo := -infos.find? fun info => info.name == fieldName + infos.find? fun info => info.name == fieldName private def containsFieldName (infos : Array StructFieldInfo) (fieldName : Name) : Bool := -(findFieldInfo? infos fieldName).isSome + (findFieldInfo? infos fieldName).isSome private partial def processSubfields {α} (structDeclName : Name) (parentFVar : Expr) (parentStructName : Name) (subfieldNames : Array Name) (infos : Array StructFieldInfo) (k : Array StructFieldInfo → TermElabM α) : TermElabM α := -let rec loop (i : Nat) (infos : Array StructFieldInfo) := do - if h : i < subfieldNames.size then - let subfieldName := subfieldNames.get ⟨i, h⟩ - if containsFieldName infos subfieldName then - throwError! "field '{subfieldName}' from '{parentStructName}' has already been declared" - let val ← mkProjection parentFVar subfieldName - let type ← inferType val - withLetDecl subfieldName type val fun subfieldFVar => - /- The following `declName` is only used for creating the `_default` auxiliary declaration name when - its default value is overwritten in the structure. -/ - let declName := structDeclName ++ subfieldName - let infos := infos.push { name := subfieldName, declName := declName, fvar := subfieldFVar, kind := StructFieldKind.fromParent } - loop (i+1) infos - else - k infos -loop 0 infos + let rec loop (i : Nat) (infos : Array StructFieldInfo) := do + if h : i < subfieldNames.size then + let subfieldName := subfieldNames.get ⟨i, h⟩ + if containsFieldName infos subfieldName then + throwError! "field '{subfieldName}' from '{parentStructName}' has already been declared" + let val ← mkProjection parentFVar subfieldName + let type ← inferType val + withLetDecl subfieldName type val fun subfieldFVar => + /- The following `declName` is only used for creating the `_default` auxiliary declaration name when + its default value is overwritten in the structure. -/ + let declName := structDeclName ++ subfieldName + let infos := infos.push { name := subfieldName, declName := declName, fvar := subfieldFVar, kind := StructFieldKind.fromParent } + loop (i+1) infos + else + k infos + loop 0 infos -private partial def withParents {α} (view : StructView) : Nat → Array StructFieldInfo → (Array StructFieldInfo → TermElabM α) → TermElabM α -| i, infos, k => +private partial def withParents {α} (view : StructView) (i : Nat) (infos : Array StructFieldInfo) (k : Array StructFieldInfo → TermElabM α) : TermElabM α := do if h : i < view.parents.size then let parentStx := view.parents.get ⟨i, h⟩ withRef parentStx do @@ -244,29 +242,29 @@ private partial def withParents {α} (view : StructView) : Nat → Array StructF k infos private def elabFieldTypeValue (view : StructFieldView) (params : Array Expr) : TermElabM (Option Expr × Option Expr) := do -match view.type? with -| none => - match view.value? with - | none => pure (none, none) - | some valStx => - let value ← Term.elabTerm valStx none - let value ← mkLambdaFVars params value - pure (none, value) -| some typeStx => - let type ← Term.elabType typeStx - match view.value? with - | none => - let type ← mkForallFVars params type - pure (type, none) - | some valStx => - let value ← Term.elabTermEnsuringType valStx type - let type ← mkForallFVars params type - let value ← mkLambdaFVars params value - pure (type, value) + match view.type? with + | none => + match view.value? with + | none => pure (none, none) + | some valStx => + let value ← Term.elabTerm valStx none + let value ← mkLambdaFVars params value + pure (none, value) + | some typeStx => + let type ← Term.elabType typeStx + match view.value? with + | none => + let type ← mkForallFVars params type + pure (type, none) + | some valStx => + let value ← Term.elabTermEnsuringType valStx type + let type ← mkForallFVars params type + let value ← mkLambdaFVars params value + pure (type, value) -private partial def withFields {α} (views : Array StructFieldView) : Nat → Array StructFieldInfo → (Array StructFieldInfo → TermElabM α) → TermElabM α -| i, infos, k => - if h : i < views.size then do +private partial def withFields {α} + (views : Array StructFieldView) (i : Nat) (infos : Array StructFieldInfo) (k : Array StructFieldInfo → TermElabM α) : TermElabM α := do + if h : i < views.size then let view := views.get ⟨i, h⟩ withRef view.ref $ match findFieldInfo? infos view.name with @@ -302,194 +300,194 @@ private partial def withFields {α} (views : Array StructFieldView) : Nat → Ar k infos private def getResultUniverse (type : Expr) : TermElabM Level := do -let type ← whnf type -match type with -| Expr.sort u _ => pure u -| _ => throwError "unexpected structure resulting type" + let type ← whnf type + match type with + | Expr.sort u _ => pure u + | _ => throwError "unexpected structure resulting type" private def collectUsed (params : Array Expr) (fieldInfos : Array StructFieldInfo) : StateRefT CollectFVars.State TermElabM Unit := do -params.forM fun p => do - let type ← inferType p - Term.collectUsedFVars type -fieldInfos.forM fun info => do - let fvarType ← inferType info.fvar - Term.collectUsedFVars fvarType - match info.value? with - | none => pure () - | some value => Term.collectUsedFVars value + params.forM fun p => do + let type ← inferType p + Term.collectUsedFVars type + fieldInfos.forM fun info => do + let fvarType ← inferType info.fvar + Term.collectUsedFVars fvarType + match info.value? with + | none => pure () + | some value => Term.collectUsedFVars value private def removeUnused (scopeVars : Array Expr) (params : Array Expr) (fieldInfos : Array StructFieldInfo) : TermElabM (LocalContext × LocalInstances × Array Expr) := do -let (_, used) ← (collectUsed params fieldInfos).run {} -Term.removeUnused scopeVars used + let (_, used) ← (collectUsed params fieldInfos).run {} + Term.removeUnused scopeVars used private def withUsed {α} (scopeVars : Array Expr) (params : Array Expr) (fieldInfos : Array StructFieldInfo) (k : Array Expr → TermElabM α) : TermElabM α := do -let (lctx, localInsts, vars) ← removeUnused scopeVars params fieldInfos -withLCtx lctx localInsts $ k vars + let (lctx, localInsts, vars) ← removeUnused scopeVars params fieldInfos + withLCtx lctx localInsts $ k vars private def levelMVarToParamFVar (fvar : Expr) : StateRefT Nat TermElabM Unit := do -let type ← inferType fvar -Term.levelMVarToParam' type -pure () + let type ← inferType fvar + Term.levelMVarToParam' type + pure () private def levelMVarToParamFVars (fvars : Array Expr) : StateRefT Nat TermElabM Unit := -fvars.forM levelMVarToParamFVar + fvars.forM levelMVarToParamFVar private def levelMVarToParamAux (scopeVars : Array Expr) (params : Array Expr) (fieldInfos : Array StructFieldInfo) : StateRefT Nat TermElabM (Array StructFieldInfo) := do -levelMVarToParamFVars scopeVars -levelMVarToParamFVars params -fieldInfos.mapM fun info => do - levelMVarToParamFVar info.fvar - match info.value? with - | none => pure info - | some value => - let value ← Term.levelMVarToParam' value - pure { info with value? := value } + levelMVarToParamFVars scopeVars + levelMVarToParamFVars params + fieldInfos.mapM fun info => do + levelMVarToParamFVar info.fvar + match info.value? with + | none => pure info + | some value => + let value ← Term.levelMVarToParam' value + pure { info with value? := value } private def levelMVarToParam (scopeVars : Array Expr) (params : Array Expr) (fieldInfos : Array StructFieldInfo) : TermElabM (Array StructFieldInfo) := -(levelMVarToParamAux scopeVars params fieldInfos).run' 1 + (levelMVarToParamAux scopeVars params fieldInfos).run' 1 private partial def collectUniversesFromFields (r : Level) (rOffset : Nat) (fieldInfos : Array StructFieldInfo) : TermElabM (Array Level) := do -fieldInfos.foldlM (init := #[]) fun (us : Array Level) (info : StructFieldInfo) => do - let type ← inferType info.fvar - let u ← getLevel type - let u ← instantiateLevelMVars u - match accLevelAtCtor u r rOffset us with - | Except.error msg => throwError msg - | Except.ok us => pure us + fieldInfos.foldlM (init := #[]) fun (us : Array Level) (info : StructFieldInfo) => do + let type ← inferType info.fvar + let u ← getLevel type + let u ← instantiateLevelMVars u + match accLevelAtCtor u r rOffset us with + | Except.error msg => throwError msg + | Except.ok us => pure us private def updateResultingUniverse (fieldInfos : Array StructFieldInfo) (type : Expr) : TermElabM Expr := do -let r ← getResultUniverse type -let rOffset : Nat := r.getOffset -let r : Level := r.getLevelOffset -match r with -| Level.mvar mvarId _ => - let us ← collectUniversesFromFields r rOffset fieldInfos - let rNew := Level.mkNaryMax us.toList - assignLevelMVar mvarId rNew - instantiateMVars type -| _ => throwError "failed to compute resulting universe level of structure, provide universe explicitly" + let r ← getResultUniverse type + let rOffset : Nat := r.getOffset + let r : Level := r.getLevelOffset + match r with + | Level.mvar mvarId _ => + let us ← collectUniversesFromFields r rOffset fieldInfos + let rNew := Level.mkNaryMax us.toList + assignLevelMVar mvarId rNew + instantiateMVars type + | _ => throwError "failed to compute resulting universe level of structure, provide universe explicitly" private def collectLevelParamsInFVar (s : CollectLevelParams.State) (fvar : Expr) : TermElabM CollectLevelParams.State := do -let type ← inferType fvar -let type ← instantiateMVars type -pure $ collectLevelParams s type + let type ← inferType fvar + let type ← instantiateMVars type + pure $ collectLevelParams s type private def collectLevelParamsInFVars (fvars : Array Expr) (s : CollectLevelParams.State) : TermElabM CollectLevelParams.State := -fvars.foldlM collectLevelParamsInFVar s + fvars.foldlM collectLevelParamsInFVar s private def collectLevelParamsInStructure (scopeVars : Array Expr) (params : Array Expr) (fieldInfos : Array StructFieldInfo) : TermElabM (Array Name) := do -let s ← collectLevelParamsInFVars scopeVars {} -let s ← collectLevelParamsInFVars params s -let s ← fieldInfos.foldlM (fun (s : CollectLevelParams.State) info => collectLevelParamsInFVar s info.fvar) s -pure s.params + let s ← collectLevelParamsInFVars scopeVars {} + let s ← collectLevelParamsInFVars params s + let s ← fieldInfos.foldlM (fun (s : CollectLevelParams.State) info => collectLevelParamsInFVar s info.fvar) s + pure s.params private def addCtorFields (fieldInfos : Array StructFieldInfo) : Nat → Expr → TermElabM Expr -| 0, type => pure type -| i+1, type => do - let info := fieldInfos[i] - let decl ← Term.getFVarLocalDecl! info.fvar - let type ← instantiateMVars type - let type := type.abstract #[info.fvar] - match info.kind with - | StructFieldKind.fromParent => - let val := decl.value - addCtorFields fieldInfos i (type.instantiate1 val) - | StructFieldKind.subobject => - let n := mkInternalSubobjectFieldName $ decl.userName - addCtorFields fieldInfos i (mkForall n decl.binderInfo decl.type type) - | StructFieldKind.newField => - addCtorFields fieldInfos i (mkForall decl.userName decl.binderInfo decl.type type) + | 0, type => pure type + | i+1, type => do + let info := fieldInfos[i] + let decl ← Term.getFVarLocalDecl! info.fvar + let type ← instantiateMVars type + let type := type.abstract #[info.fvar] + match info.kind with + | StructFieldKind.fromParent => + let val := decl.value + addCtorFields fieldInfos i (type.instantiate1 val) + | StructFieldKind.subobject => + let n := mkInternalSubobjectFieldName $ decl.userName + addCtorFields fieldInfos i (mkForall n decl.binderInfo decl.type type) + | StructFieldKind.newField => + addCtorFields fieldInfos i (mkForall decl.userName decl.binderInfo decl.type type) private def mkCtor (view : StructView) (levelParams : List Name) (params : Array Expr) (fieldInfos : Array StructFieldInfo) : TermElabM Constructor := -withRef view.ref do -let type := mkAppN (mkConst view.declName (levelParams.map mkLevelParam)) params -let type ← addCtorFields fieldInfos fieldInfos.size type -let type ← mkForallFVars params type -let type ← instantiateMVars type -let type := type.inferImplicit params.size !view.ctor.inferMod -pure { name := view.ctor.declName, type := type } + withRef view.ref do + let type := mkAppN (mkConst view.declName (levelParams.map mkLevelParam)) params + let type ← addCtorFields fieldInfos fieldInfos.size type + let type ← mkForallFVars params type + let type ← instantiateMVars type + let type := type.inferImplicit params.size !view.ctor.inferMod + pure { name := view.ctor.declName, type := type } @[extern "lean_mk_projections"] -private constant mkProjections (env : Environment) (structName : Name) (projs : List ProjectionInfo) (isClass : Bool) : Except String Environment := arbitrary _ +private constant mkProjections (env : Environment) (structName : Name) (projs : List ProjectionInfo) (isClass : Bool) : Except String Environment private def addProjections (structName : Name) (projs : List ProjectionInfo) (isClass : Bool) : TermElabM Unit := do -let env ← getEnv -match mkProjections env structName projs isClass with -| Except.ok env => setEnv env -| Except.error msg => throwError msg + let env ← getEnv + match mkProjections env structName projs isClass with + | Except.ok env => setEnv env + | Except.error msg => throwError msg private def mkAuxConstructions (declName : Name) : TermElabM Unit := do -let env ← getEnv -let hasUnit := env.contains `PUnit -let hasEq := env.contains `Eq -let hasHEq := env.contains `HEq -modifyEnv fun env => mkRecOn env declName -if hasUnit then modifyEnv fun env => mkCasesOn env declName -if hasUnit && hasEq && hasHEq then modifyEnv fun env => mkNoConfusion env declName + let env ← getEnv + let hasUnit := env.contains `PUnit + let hasEq := env.contains `Eq + let hasHEq := env.contains `HEq + modifyEnv fun env => mkRecOn env declName + if hasUnit then modifyEnv fun env => mkCasesOn env declName + if hasUnit && hasEq && hasHEq then modifyEnv fun env => mkNoConfusion env declName private def addDefaults (lctx : LocalContext) (defaultAuxDecls : Array (Name × Expr × Expr)) : TermElabM Unit := do -let localInsts ← getLocalInstances -withLCtx lctx localInsts do - defaultAuxDecls.forM fun (declName, type, value) => do - /- The identity function is used as "marker". -/ - let value ← mkId value - mkAuxDefinition declName type value (zeta := true) - modifyEnv fun env => setReducibilityStatus env declName ReducibilityStatus.reducible + let localInsts ← getLocalInstances + withLCtx lctx localInsts do + defaultAuxDecls.forM fun (declName, type, value) => do + /- The identity function is used as "marker". -/ + let value ← mkId value + mkAuxDefinition declName type value (zeta := true) + modifyEnv fun env => setReducibilityStatus env declName ReducibilityStatus.reducible private def elabStructureView (view : StructView) : TermElabM Unit := do -let numExplicitParams := view.params.size -let type ← Term.elabType view.type -unless validStructType type do throwErrorAt view.type "expected Type" -withRef view.ref do -withParents view 0 #[] fun fieldInfos => -withFields view.fields 0 fieldInfos fun fieldInfos => do - Term.synthesizeSyntheticMVarsNoPostponing - let u ← getResultUniverse type - let inferLevel ← shouldInferResultUniverse u - withUsed view.scopeVars view.params fieldInfos $ fun scopeVars => do - let numParams := scopeVars.size + numExplicitParams - let fieldInfos ← levelMVarToParam scopeVars view.params fieldInfos - let type ← if inferLevel then updateResultingUniverse fieldInfos type else pure type - let usedLevelNames ← collectLevelParamsInStructure scopeVars view.params fieldInfos - match sortDeclLevelParams view.scopeLevelNames view.allUserLevelNames usedLevelNames with - | Except.error msg => throwError msg - | Except.ok levelParams => - let params := scopeVars ++ view.params - let ctor ← mkCtor view levelParams params fieldInfos - let type ← mkForallFVars params type - let type ← instantiateMVars type - let indType := { name := view.declName, type := type, ctors := [ctor] : InductiveType } - let decl := Declaration.inductDecl levelParams params.size [indType] view.modifiers.isUnsafe - Term.ensureNoUnassignedMVars decl - addDecl decl - let projInfos := (fieldInfos.filter fun (info : StructFieldInfo) => !info.isFromParent).toList.map fun (info : StructFieldInfo) => - { declName := info.declName, inferMod := info.inferMod : ProjectionInfo } - addProjections view.declName projInfos view.isClass - mkAuxConstructions view.declName - let instParents ← fieldInfos.filterM fun info => do - let decl ← Term.getFVarLocalDecl! info.fvar - pure (info.isSubobject && decl.binderInfo.isInstImplicit) - let projInstances := instParents.toList.map fun info => info.declName - Term.applyAttributesAt view.declName view.modifiers.attrs AttributeApplicationTime.afterTypeChecking - projInstances.forM addGlobalInstance - let lctx ← getLCtx - let fieldsWithDefault := fieldInfos.filter fun info => info.value?.isSome - let defaultAuxDecls ← fieldsWithDefault.mapM fun info => do - let type ← inferType info.fvar - pure (info.declName ++ `_default, type, info.value?.get!) - /- The `lctx` and `defaultAuxDecls` are used to create the auxiliary `_default` declarations - The parameters `params` for these definitions must be marked as implicit, and all others as explicit. -/ - let lctx := - params.foldl (init := lctx) fun (lctx : LocalContext) (p : Expr) => - lctx.updateBinderInfo p.fvarId! BinderInfo.implicit - let lctx := - fieldInfos.foldl (init := lctx) fun (lctx : LocalContext) (info : StructFieldInfo) => - if info.isFromParent then lctx -- `fromParent` fields are elaborated as let-decls, and are zeta-expanded when creating `_default`. - else lctx.updateBinderInfo info.fvar.fvarId! BinderInfo.default - addDefaults lctx defaultAuxDecls + let numExplicitParams := view.params.size + let type ← Term.elabType view.type + unless validStructType type do throwErrorAt view.type "expected Type" + withRef view.ref do + withParents view 0 #[] fun fieldInfos => + withFields view.fields 0 fieldInfos fun fieldInfos => do + Term.synthesizeSyntheticMVarsNoPostponing + let u ← getResultUniverse type + let inferLevel ← shouldInferResultUniverse u + withUsed view.scopeVars view.params fieldInfos $ fun scopeVars => do + let numParams := scopeVars.size + numExplicitParams + let fieldInfos ← levelMVarToParam scopeVars view.params fieldInfos + let type ← if inferLevel then updateResultingUniverse fieldInfos type else pure type + let usedLevelNames ← collectLevelParamsInStructure scopeVars view.params fieldInfos + match sortDeclLevelParams view.scopeLevelNames view.allUserLevelNames usedLevelNames with + | Except.error msg => throwError msg + | Except.ok levelParams => + let params := scopeVars ++ view.params + let ctor ← mkCtor view levelParams params fieldInfos + let type ← mkForallFVars params type + let type ← instantiateMVars type + let indType := { name := view.declName, type := type, ctors := [ctor] : InductiveType } + let decl := Declaration.inductDecl levelParams params.size [indType] view.modifiers.isUnsafe + Term.ensureNoUnassignedMVars decl + addDecl decl + let projInfos := (fieldInfos.filter fun (info : StructFieldInfo) => !info.isFromParent).toList.map fun (info : StructFieldInfo) => + { declName := info.declName, inferMod := info.inferMod : ProjectionInfo } + addProjections view.declName projInfos view.isClass + mkAuxConstructions view.declName + let instParents ← fieldInfos.filterM fun info => do + let decl ← Term.getFVarLocalDecl! info.fvar + pure (info.isSubobject && decl.binderInfo.isInstImplicit) + let projInstances := instParents.toList.map fun info => info.declName + Term.applyAttributesAt view.declName view.modifiers.attrs AttributeApplicationTime.afterTypeChecking + projInstances.forM addGlobalInstance + let lctx ← getLCtx + let fieldsWithDefault := fieldInfos.filter fun info => info.value?.isSome + let defaultAuxDecls ← fieldsWithDefault.mapM fun info => do + let type ← inferType info.fvar + pure (info.declName ++ `_default, type, info.value?.get!) + /- The `lctx` and `defaultAuxDecls` are used to create the auxiliary `_default` declarations + The parameters `params` for these definitions must be marked as implicit, and all others as explicit. -/ + let lctx := + params.foldl (init := lctx) fun (lctx : LocalContext) (p : Expr) => + lctx.updateBinderInfo p.fvarId! BinderInfo.implicit + let lctx := + fieldInfos.foldl (init := lctx) fun (lctx : LocalContext) (info : StructFieldInfo) => + if info.isFromParent then lctx -- `fromParent` fields are elaborated as let-decls, and are zeta-expanded when creating `_default`. + else lctx.updateBinderInfo info.fvar.fvarId! BinderInfo.default + addDefaults lctx defaultAuxDecls /- parser! (structureTk <|> classTk) >> declId >> many Term.bracketedBinder >> optional «extends» >> Term.optType >> " := " >> optional structCtor >> structFields @@ -504,32 +502,32 @@ def structCtor := parser! try (declModifiers >> ident >> optional infe -/ def elabStructure (modifiers : Modifiers) (stx : Syntax) : CommandElabM Unit := do -checkValidInductiveModifier modifiers -let isClass := stx[0].getKind == `Lean.Parser.Command.classTk -let modifiers := if isClass then modifiers.addAttribute { name := `class } else modifiers -let declId := stx[1] -let params := stx[2].getArgs -let exts := stx[3] -let parents := if exts.isNone then #[] else exts[0][1].getSepArgs -let optType := stx[4] -let type ← if optType.isNone then `(Type _) else pure optType[0][1] -let scopeLevelNames ← getLevelNames -let ⟨name, declName, allUserLevelNames⟩ ← expandDeclId declId modifiers -let ctor ← expandCtor stx modifiers declName -let fields ← expandFields stx modifiers declName -runTermElabM declName $ fun scopeVars => Term.withLevelNames allUserLevelNames $ Term.elabBinders params fun params => elabStructureView { - ref := stx, - modifiers := modifiers, - scopeLevelNames := scopeLevelNames, - allUserLevelNames := allUserLevelNames, - declName := declName, - isClass := isClass, - scopeVars := scopeVars, - params := params, - parents := parents, - type := type, - ctor := ctor, - fields := fields -} + checkValidInductiveModifier modifiers + let isClass := stx[0].getKind == `Lean.Parser.Command.classTk + let modifiers := if isClass then modifiers.addAttribute { name := `class } else modifiers + let declId := stx[1] + let params := stx[2].getArgs + let exts := stx[3] + let parents := if exts.isNone then #[] else exts[0][1].getSepArgs + let optType := stx[4] + let type ← if optType.isNone then `(Type _) else pure optType[0][1] + let scopeLevelNames ← getLevelNames + let ⟨name, declName, allUserLevelNames⟩ ← expandDeclId declId modifiers + let ctor ← expandCtor stx modifiers declName + let fields ← expandFields stx modifiers declName + runTermElabM declName $ fun scopeVars => Term.withLevelNames allUserLevelNames $ Term.elabBinders params fun params => elabStructureView { + ref := stx, + modifiers := modifiers, + scopeLevelNames := scopeLevelNames, + allUserLevelNames := allUserLevelNames, + declName := declName, + isClass := isClass, + scopeVars := scopeVars, + params := params, + parents := parents, + type := type, + ctor := ctor, + fields := fields + } end Lean.Elab.Command diff --git a/src/Lean/Elab/Tactic/Basic.lean b/src/Lean/Elab/Tactic/Basic.lean index 498bef5125..88f1bdd252 100644 --- a/src/Lean/Elab/Tactic/Basic.lean +++ b/src/Lean/Elab/Tactic/Basic.lean @@ -17,69 +17,68 @@ namespace Lean.Elab open Meta def goalsToMessageData (goals : List MVarId) : MessageData := -MessageData.joinSep (goals.map $ MessageData.ofGoal) (Format.line ++ Format.line) + MessageData.joinSep (goals.map $ MessageData.ofGoal) (Format.line ++ Format.line) def Term.reportUnsolvedGoals (goals : List MVarId) : TermElabM Unit := do -throwError! "unsolved goals\n{goalsToMessageData goals}" + throwError! "unsolved goals\n{goalsToMessageData goals}" namespace Tactic structure Context := -(main : MVarId) + (main : MVarId) structure State := -(goals : List MVarId) + (goals : List MVarId) -instance State.inhabited : Inhabited State := ⟨{ goals := [] }⟩ +instance : Inhabited State := ⟨{ goals := [] }⟩ structure BacktrackableState := -(env : Environment) -(mctx : MetavarContext) -(goals : List MVarId) + (env : Environment) + (mctx : MetavarContext) + (goals : List MVarId) abbrev TacticM := ReaderT Context $ StateRefT State $ TermElabM abbrev Tactic := Syntax → TacticM Unit def saveBacktrackableState : TacticM BacktrackableState := do -pure { env := (← getEnv), mctx := (← getMCtx), goals := (← get).goals } + pure { env := (← getEnv), mctx := (← getMCtx), goals := (← get).goals } def BacktrackableState.restore (b : BacktrackableState) : TacticM Unit := do -setEnv b.env; -setMCtx b.mctx; -modify fun s => { s with goals := b.goals } + setEnv b.env + setMCtx b.mctx + modify fun s => { s with goals := b.goals } @[inline] protected def «catch» {α} (x : TacticM α) (h : Exception → TacticM α) : TacticM α := do -let b ← saveBacktrackableState; -try x catch ex => b.restore; h ex + let b ← saveBacktrackableState + try x catch ex => b.restore; h ex @[inline] protected def orelse {α} (x y : TacticM α) : TacticM α := do -try x catch _ => y + try x catch _ => y -instance monadExcept : MonadExcept Exception TacticM := -{ throw := throw, - «catch» := Tactic.«catch» } +instance : MonadExcept Exception TacticM := { + throw := throw, + «catch» := Tactic.«catch» +} -instance hasOrElse {α} : HasOrelse (TacticM α) := ⟨Tactic.orelse⟩ +instance {α} : HasOrelse (TacticM α) := ⟨Tactic.orelse⟩ structure SavedState := -(core : Core.State) -(meta : Meta.State) -(term : Term.State) -(tactic : State) + (core : Core.State) + (meta : Meta.State) + (term : Term.State) + (tactic : State) -instance SavedState.inhabited : Inhabited SavedState := ⟨⟨arbitrary _, arbitrary _, arbitrary _, arbitrary _⟩⟩ +instance : Inhabited SavedState := ⟨⟨arbitrary _, arbitrary _, arbitrary _, arbitrary _⟩⟩ def saveAllState : TacticM SavedState := do -pure { core := (← getThe Core.State), meta := (← getThe Meta.State), term := (← getThe Term.State), tactic := (← get) } + pure { core := (← getThe Core.State), meta := (← getThe Meta.State), term := (← getThe Term.State), tactic := (← get) } def SavedState.restore (s : SavedState) : TacticM Unit := do -set s.core; set s.meta; set s.term; set s.tactic + set s.core; set s.meta; set s.term; set s.tactic -@[inline] def liftTermElabM {α} (x : TermElabM α) : TacticM α := -liftM x +@[inline] def liftTermElabM {α} (x : TermElabM α) : TacticM α := liftM x -@[inline] def liftMetaM {α} (x : MetaM α) : TacticM α := -liftTermElabM $ Term.liftMetaM x +@[inline] def liftMetaM {α} (x : MetaM α) : TacticM α := liftTermElabM $ Term.liftMetaM x def ensureHasType (expectedType? : Option Expr) (e : Expr) : TacticM Expr := liftTermElabM $ Term.ensureHasType expectedType? e def reportUnsolvedGoals (goals : List MVarId) : TacticM Unit := liftTermElabM $ Term.reportUnsolvedGoals goals @@ -88,186 +87,191 @@ protected def getCurrMacroScope : TacticM MacroScope := do pure (← readThe Ter protected def getMainModule : TacticM Name := do pure (← getEnv).mainModule unsafe def mkTacticAttribute : IO (KeyedDeclsAttribute Tactic) := -mkElabAttribute Tactic `Lean.Elab.Tactic.tacticElabAttribute `builtinTactic `tactic `Lean.Parser.Tactic `Lean.Elab.Tactic.Tactic "tactic" -@[builtinInit mkTacticAttribute] constant tacticElabAttribute : KeyedDeclsAttribute Tactic := arbitrary _ + mkElabAttribute Tactic `Lean.Elab.Tactic.tacticElabAttribute `builtinTactic `tactic `Lean.Parser.Tactic `Lean.Elab.Tactic.Tactic "tactic" + +@[builtinInit mkTacticAttribute] constant tacticElabAttribute : KeyedDeclsAttribute Tactic private def evalTacticUsing (s : SavedState) (stx : Syntax) (tactics : List Tactic) : TacticM Unit := do -let rec loop : List Tactic → TacticM Unit - | [] => throwErrorAt! stx "unexpected syntax {indentD stx}" - | evalFn::evalFns => do - try - evalFn stx - catch - | ex@(Exception.error _ _) => - match evalFns with - | [] => throw ex - | _ => s.restore; loop evalFns - | ex@(Exception.internal id) => - if id == unsupportedSyntaxExceptionId then - s.restore; loop evalFns - else - throw ex -loop tactics + let rec loop : List Tactic → TacticM Unit + | [] => throwErrorAt! stx "unexpected syntax {indentD stx}" + | evalFn::evalFns => do + try + evalFn stx + catch + | ex@(Exception.error _ _) => + match evalFns with + | [] => throw ex + | _ => s.restore; loop evalFns + | ex@(Exception.internal id) => + if id == unsupportedSyntaxExceptionId then + s.restore; loop evalFns + else + throw ex + loop tactics /- Elaborate `x` with `stx` on the macro stack -/ -@[inline] def withMacroExpansion {α} (beforeStx afterStx : Syntax) (x : TacticM α) : TacticM α := -withTheReader Term.Context (fun ctx => { ctx with macroStack := { before := beforeStx, after := afterStx } :: ctx.macroStack }) x +@[inline] +def withMacroExpansion {α} (beforeStx afterStx : Syntax) (x : TacticM α) : TacticM α := + withTheReader Term.Context (fun ctx => { ctx with macroStack := { before := beforeStx, after := afterStx } :: ctx.macroStack }) x mutual partial def expandTacticMacroFns (stx : Syntax) (macros : List Macro) : TacticM Unit := -let rec loop : List Macro → TacticM Unit - | [] => throwErrorAt! stx "tactic '{stx.getKind}' has not been implemented" - | m::ms => do - let scp ← getCurrMacroScope - try - let stx' ← adaptMacro m stx - evalTactic stx' - catch ex => - if ms.isEmpty then throw ex - loop ms -loop macros + let rec loop : List Macro → TacticM Unit + | [] => throwErrorAt! stx "tactic '{stx.getKind}' has not been implemented" + | m::ms => do + let scp ← getCurrMacroScope + try + let stx' ← adaptMacro m stx + evalTactic stx' + catch ex => + if ms.isEmpty then throw ex + loop ms + loop macros partial def expandTacticMacro (stx : Syntax) : TacticM Unit := do -let k := stx.getKind -let table := (macroAttribute.ext.getState (← getEnv)).table -let macroFns := (table.find? k).getD [] -expandTacticMacroFns stx macroFns + let k := stx.getKind + let table := (macroAttribute.ext.getState (← getEnv)).table + let macroFns := (table.find? k).getD [] + expandTacticMacroFns stx macroFns partial def evalTactic : Syntax → TacticM Unit -| stx => withRef stx $ withIncRecDepth $ withFreshMacroScope $ match stx with - | Syntax.node k args => - if k == nullKind then - -- Macro writers create a sequence of tactics `t₁ ... tₙ` using `mkNullNode #[t₁, ..., tₙ]` - stx.getArgs.forM evalTactic - else do - trace `Elab.step fun _ => stx - let env ← getEnv - let s ← saveAllState - let table := (tacticElabAttribute.ext.getState env).table - let k := stx.getKind - match table.find? k with - | some evalFns => evalTacticUsing s stx evalFns - | none => expandTacticMacro stx - | _ => throwError "unexpected command" + | stx => withRef stx $ withIncRecDepth $ withFreshMacroScope $ match stx with + | Syntax.node k args => + if k == nullKind then + -- Macro writers create a sequence of tactics `t₁ ... tₙ` using `mkNullNode #[t₁, ..., tₙ]` + stx.getArgs.forM evalTactic + else do + trace `Elab.step fun _ => stx + let env ← getEnv + let s ← saveAllState + let table := (tacticElabAttribute.ext.getState env).table + let k := stx.getKind + match table.find? k with + | some evalFns => evalTacticUsing s stx evalFns + | none => expandTacticMacro stx + | _ => throwError "unexpected command" end /-- Adapt a syntax transformation to a regular tactic evaluator. -/ -def adaptExpander (exp : Syntax → TacticM Syntax) : Tactic := -fun stx => do +def adaptExpander (exp : Syntax → TacticM Syntax) : Tactic := fun stx => do let stx' ← exp stx withMacroExpansion stx stx' $ evalTactic stx' def getGoals : TacticM (List MVarId) := do pure (← get).goals + def setGoals (gs : List MVarId) : TacticM Unit := modify $ fun s => { s with goals := gs } + def appendGoals (gs : List MVarId) : TacticM Unit := modify $ fun s => { s with goals := s.goals ++ gs } + def pruneSolvedGoals : TacticM Unit := do -let gs ← getGoals -let gs ← gs.filterM fun g => not <$> isExprMVarAssigned g -setGoals gs + let gs ← getGoals + let gs ← gs.filterM fun g => not <$> isExprMVarAssigned g + setGoals gs + def getUnsolvedGoals : TacticM (List MVarId) := do pruneSolvedGoals; getGoals + def getMainGoal : TacticM (MVarId × List MVarId) := do let (g::gs) ← getUnsolvedGoals | throwError "no goals to be solved"; pure (g, gs) + def getMainTag : TacticM Name := do -let (g, _) ← getMainGoal -pure (← getMVarDecl g).userName + let (g, _) ← getMainGoal + pure (← getMVarDecl g).userName def ensureHasNoMVars (e : Expr) : TacticM Unit := do -let e ← instantiateMVars e -let pendingMVars ← getMVars e -Term.logUnassignedUsingErrorInfos pendingMVars -if e.hasExprMVar then - throwError! "tactic failed, resulting expression contains metavariables{indentExpr e}" + let e ← instantiateMVars e + let pendingMVars ← getMVars e + Term.logUnassignedUsingErrorInfos pendingMVars + if e.hasExprMVar then + throwError! "tactic failed, resulting expression contains metavariables{indentExpr e}" def withMainMVarContext {α} (x : TacticM α) : TacticM α := do -let (mvarId, _) ← getMainGoal -withMVarContext mvarId x + let (mvarId, _) ← getMainGoal + withMVarContext mvarId x @[inline] def liftMetaMAtMain {α} (x : MVarId → MetaM α) : TacticM α := do -let (g, _) ← getMainGoal -withMVarContext g $ liftMetaM $ x g + let (g, _) ← getMainGoal + withMVarContext g $ liftMetaM $ x g @[inline] def liftMetaTacticAux {α} (tactic : MVarId → MetaM (α × List MVarId)) : TacticM α := do -let (g, gs) ← getMainGoal -withMVarContext g do - let (a, gs') ← tactic g + let (g, gs) ← getMainGoal + withMVarContext g do + let (a, gs') ← tactic g + setGoals (gs' ++ gs) + pure a + +@[inline] def liftMetaTactic (tactic : MVarId → MetaM (List MVarId)) : TacticM Unit := + liftMetaTacticAux fun mvarId => do + let gs ← tactic mvarId + pure ((), gs) + +def done : TacticM Unit := do + let gs ← getUnsolvedGoals; + unless gs.isEmpty do + reportUnsolvedGoals gs + +@[builtinTactic Lean.Parser.Tactic.«done»] def evalDone : Tactic := fun _ => done + +def focusAux {α} (tactic : TacticM α) : TacticM α := do + let (g, gs) ← getMainGoal + setGoals [g] + let a ← tactic + let gs' ← getGoals setGoals (gs' ++ gs) pure a -@[inline] def liftMetaTactic (tactic : MVarId → MetaM (List MVarId)) : TacticM Unit := -liftMetaTacticAux fun mvarId => do - let gs ← tactic mvarId - pure ((), gs) - -def done : TacticM Unit := do -let gs ← getUnsolvedGoals; -unless gs.isEmpty do - reportUnsolvedGoals gs - -@[builtinTactic Lean.Parser.Tactic.«done»] def evalDone : Tactic := -fun _ => done - -def focusAux {α} (tactic : TacticM α) : TacticM α := do -let (g, gs) ← getMainGoal -setGoals [g] -let a ← tactic -let gs' ← getGoals -setGoals (gs' ++ gs) -pure a - def focus {α} (tactic : TacticM α) : TacticM α := -focusAux do let a ← tactic; done; pure a + focusAux do let a ← tactic; done; pure a def try? {α} (tactic : TacticM α) : TacticM (Option α) := do -try pure (some (← tactic)) -catch _ => pure none + try pure (some (← tactic)) + catch _ => pure none --- TODO: rename +-- TODO: rename? def «try» {α} (tactic : TacticM α) : TacticM Bool := do -try tactic; pure true -catch _ => pure false + try tactic; pure true + catch _ => pure false /-- Use `parentTag` to tag untagged goals at `newGoals`. If there are multiple new untagged goals, they are named using `._` where `idx > 0`. If there is only one new untagged goal, then we just use `parentTag` -/ def tagUntaggedGoals (parentTag : Name) (newSuffix : Name) (newGoals : List MVarId) : TacticM Unit := do -let mctx ← getMCtx -let numAnonymous := 0 -for g in newGoals do - if mctx.isAnonymousMVar g then - numAnonymous := numAnonymous + 1 -modifyMCtx fun mctx => do - let idx := 1 + let mctx ← getMCtx + let numAnonymous := 0 for g in newGoals do if mctx.isAnonymousMVar g then - if numAnonymous == 1 then - mctx := mctx.renameMVar g parentTag - else - mctx := mctx.renameMVar g (parentTag ++ newSuffix.appendIndexAfter idx) - idx := idx + 1 - pure mctx + numAnonymous := numAnonymous + 1 + modifyMCtx fun mctx => do + let idx := 1 + for g in newGoals do + if mctx.isAnonymousMVar g then + if numAnonymous == 1 then + mctx := mctx.renameMVar g parentTag + else + mctx := mctx.renameMVar g (parentTag ++ newSuffix.appendIndexAfter idx) + idx := idx + 1 + pure mctx -@[builtinTactic seq1] def evalSeq1 : Tactic := -fun stx => stx[0].forSepArgsM evalTactic +@[builtinTactic seq1] def evalSeq1 : Tactic := fun stx => + stx[0].forSepArgsM evalTactic -@[builtinTactic paren] def evalParen : Tactic := -fun stx => evalSeq1 stx[1] +@[builtinTactic paren] def evalParen : Tactic := fun stx => + evalSeq1 stx[1] -@[builtinTactic tacticSeq1Indented] def evalTacticSeq1Indented : Tactic := -fun stx => stx[0].forArgsM fun seqElem => evalTactic seqElem[0] +@[builtinTactic tacticSeq1Indented] def evalTacticSeq1Indented : Tactic := fun stx => + stx[0].forArgsM fun seqElem => evalTactic seqElem[0] -@[builtinTactic tacticSeqBracketed] def evalTacticSeqBracketed : Tactic := -fun stx => withRef stx[2] $ focus $ stx[1].forArgsM fun seqElem => evalTactic seqElem[0] +@[builtinTactic tacticSeqBracketed] def evalTacticSeqBracketed : Tactic := fun stx => + withRef stx[2] $ focus $ stx[1].forArgsM fun seqElem => evalTactic seqElem[0] -@[builtinTactic Parser.Tactic.focus] def evalFocus : Tactic := -fun stx => focus $ evalTactic stx[1] +@[builtinTactic Parser.Tactic.focus] def evalFocus : Tactic := fun stx => + focus $ evalTactic stx[1] -@[builtinTactic tacticSeq] def evalTacticSeq : Tactic := -fun stx => evalTactic stx[0] +@[builtinTactic tacticSeq] def evalTacticSeq : Tactic := fun stx => + evalTactic stx[0] -partial def evalChoiceAux (tactics : Array Syntax) : Nat → TacticM Unit -| i => +partial def evalChoiceAux (tactics : Array Syntax) (i : Nat) : TacticM Unit := if h : i < tactics.size then let tactic := tactics.get ⟨i, h⟩ catchInternalId unsupportedSyntaxExceptionId @@ -276,33 +280,30 @@ partial def evalChoiceAux (tactics : Array Syntax) : Nat → TacticM Unit else throwUnsupportedSyntax -@[builtinTactic choice] def evalChoice : Tactic := -fun stx => evalChoiceAux stx.getArgs 0 +@[builtinTactic choice] def evalChoice : Tactic := fun stx => + evalChoiceAux stx.getArgs 0 -@[builtinTactic skip] def evalSkip : Tactic := -fun stx => pure () +@[builtinTactic skip] def evalSkip : Tactic := fun stx => pure () -@[builtinTactic failIfSuccess] def evalFailIfSuccess : Tactic := -fun stx => do +@[builtinTactic failIfSuccess] def evalFailIfSuccess : Tactic := fun stx => do let tactic := stx[1] if (← do try evalTactic tactic; pure true catch _ => pure false) then throwError "tactic succeeded" -@[builtinTactic traceState] def evalTraceState : Tactic := -fun stx => do +@[builtinTactic traceState] def evalTraceState : Tactic := fun stx => do let gs ← getUnsolvedGoals; logInfo (goalsToMessageData gs) -@[builtinTactic Lean.Parser.Tactic.assumption] def evalAssumption : Tactic := -fun stx => liftMetaTactic fun mvarId => do Meta.assumption mvarId; pure [] +@[builtinTactic Lean.Parser.Tactic.assumption] def evalAssumption : Tactic := fun stx => + liftMetaTactic fun mvarId => do Meta.assumption mvarId; pure [] private def introStep (n : Name) : TacticM Unit := -liftMetaTactic fun mvarId => do - let (_, mvarId) ← Meta.intro mvarId n - pure [mvarId] + liftMetaTactic fun mvarId => do + let (_, mvarId) ← Meta.intro mvarId n + pure [mvarId] -@[builtinTactic Lean.Parser.Tactic.intro] def evalIntro : Tactic := -fun stx => match_syntax stx with +@[builtinTactic Lean.Parser.Tactic.intro] def evalIntro : Tactic := fun stx => + match_syntax stx with | `(tactic| intro) => liftMetaTactic fun mvarId => do (_, mvarId) ← Meta.intro1 mvarId; pure [mvarId] | `(tactic| intro $h:ident) => introStep h.getId | `(tactic| intro _) => introStep `_ @@ -316,19 +317,18 @@ fun stx => match_syntax stx with withMacroExpansion stx stxNew $ evalTactic stxNew | _ => throwUnsupportedSyntax -@[builtinTactic Lean.Parser.Tactic.introMatch] def evalIntroMatch : Tactic := -fun stx => do +@[builtinTactic Lean.Parser.Tactic.introMatch] def evalIntroMatch : Tactic := fun stx => do let matchAlts := stx[1] let stxNew ← liftMacroM $ Term.expandMatchAltsIntoMatchTactic stx matchAlts withMacroExpansion stx stxNew $ evalTactic stxNew private def getIntrosSize : Expr → Nat -| Expr.forallE _ _ b _ => getIntrosSize b + 1 -| Expr.letE _ _ _ b _ => getIntrosSize b + 1 -| _ => 0 + | Expr.forallE _ _ b _ => getIntrosSize b + 1 + | Expr.letE _ _ _ b _ => getIntrosSize b + 1 + | _ => 0 -@[builtinTactic «intros»] def evalIntros : Tactic := -fun stx => match_syntax stx with +@[builtinTactic «intros»] def evalIntros : Tactic := fun stx => + match_syntax stx with | `(tactic| intros) => liftMetaTactic fun mvarId => do let type ← Meta.getMVarType mvarId let type ← instantiateMVars type @@ -340,18 +340,17 @@ fun stx => match_syntax stx with pure [mvarId] | _ => throwUnsupportedSyntax -def getFVarId (id : Syntax) : TacticM FVarId := -withRef id do -let fvar? ← liftTermElabM $ Term.isLocalIdent? id; -match fvar? with -| some fvar => pure fvar.fvarId! -| none => throwError! "unknown variable '{id.getId}'" +def getFVarId (id : Syntax) : TacticM FVarId := withRef id do + let fvar? ← liftTermElabM $ Term.isLocalIdent? id; + match fvar? with + | some fvar => pure fvar.fvarId! + | none => throwError! "unknown variable '{id.getId}'" def getFVarIds (ids : Array Syntax) : TacticM (Array FVarId) := do -withMainMVarContext $ ids.mapM getFVarId + withMainMVarContext $ ids.mapM getFVarId -@[builtinTactic Lean.Parser.Tactic.revert] def evalRevert : Tactic := -fun stx => match_syntax stx with +@[builtinTactic Lean.Parser.Tactic.revert] def evalRevert : Tactic := fun stx => + match_syntax stx with | `(tactic| revert $hs*) => do let (g, gs) ← getMainGoal let fvarIds ← getFVarIds hs @@ -361,17 +360,17 @@ fun stx => match_syntax stx with /- Sort free variables using an order `x < y` iff `x` was defined after `y` -/ private def sortFVarIds (fvarIds : Array FVarId) : TacticM (Array FVarId) := -withMainMVarContext do - let lctx ← getLCtx - pure $ fvarIds.qsort fun fvarId₁ fvarId₂ => - match lctx.find? fvarId₁, lctx.find? fvarId₂ with - | some d₁, some d₂ => d₁.index > d₂.index - | some _, none => false - | none, some _ => true - | none, none => Name.quickLt fvarId₁ fvarId₂ + withMainMVarContext do + let lctx ← getLCtx + pure $ fvarIds.qsort fun fvarId₁ fvarId₂ => + match lctx.find? fvarId₁, lctx.find? fvarId₂ with + | some d₁, some d₂ => d₁.index > d₂.index + | some _, none => false + | none, some _ => true + | none, none => Name.quickLt fvarId₁ fvarId₂ -@[builtinTactic Lean.Parser.Tactic.clear] def evalClear : Tactic := -fun stx => match_syntax stx with +@[builtinTactic Lean.Parser.Tactic.clear] def evalClear : Tactic := fun stx => + match_syntax stx with | `(tactic| clear $hs*) => do let fvarIds ← getFVarIds hs let fvarIds ← sortFVarIds fvarIds @@ -383,15 +382,15 @@ fun stx => match_syntax stx with | _ => throwUnsupportedSyntax def forEachVar (hs : Array Syntax) (tac : MVarId → FVarId → MetaM MVarId) : TacticM Unit := do -for h in hs do - let (g, gs) ← getMainGoal; - withMVarContext g do - let fvarId ← getFVarId h - let g ← tac g fvarId - setGoals (g :: gs) + for h in hs do + let (g, gs) ← getMainGoal; + withMVarContext g do + let fvarId ← getFVarId h + let g ← tac g fvarId + setGoals (g :: gs) -@[builtinTactic Lean.Parser.Tactic.subst] def evalSubst : Tactic := -fun stx => match_syntax stx with +@[builtinTactic Lean.Parser.Tactic.subst] def evalSubst : Tactic := fun stx => + match_syntax stx with | `(tactic| subst $hs*) => forEachVar hs Meta.subst | _ => throwUnsupportedSyntax @@ -399,13 +398,12 @@ fun stx => match_syntax stx with First method searches for a metavariable `g` s.t. `tag` is a suffix of its name. If none is found, then it searches for a metavariable `g` s.t. `tag` is a prefix of its name. -/ private def findTag? (gs : List MVarId) (tag : Name) : TacticM (Option MVarId) := do -let g? ← gs.findM? (fun g => do pure $ tag.isSuffixOf (← getMVarDecl g).userName); -match g? with -| some g => pure g -| none => gs.findM? (fun g => do pure $ tag.isPrefixOf (← getMVarDecl g).userName) + let g? ← gs.findM? (fun g => do pure $ tag.isSuffixOf (← getMVarDecl g).userName); + match g? with + | some g => pure g + | none => gs.findM? (fun g => do pure $ tag.isPrefixOf (← getMVarDecl g).userName) -@[builtinTactic «case»] def evalCase : Tactic := -fun stx => +@[builtinTactic «case»] def evalCase : Tactic := fun stx => match_syntax stx with | `(tactic| case $tag => $tac:tacticSeq) => do let tag := tag.getId @@ -423,19 +421,17 @@ fun stx => setGoals gs | _ => throwUnsupportedSyntax -@[builtinTactic «orelse»] def evalOrelse : Tactic := -fun stx => match_syntax stx with +@[builtinTactic «orelse»] def evalOrelse : Tactic := fun stx => + match_syntax stx with | `(tactic| $tac1 <|> $tac2) => evalTactic tac1 <|> evalTactic tac2 | _ => throwUnsupportedSyntax -@[builtinInit] private def regTraceClasses : IO Unit := do -registerTraceClass `Elab.tactic; -pure () +builtin_initialize registerTraceClass `Elab.tactic @[inline] def TacticM.run {α} (x : TacticM α) (ctx : Context) (s : State) : TermElabM (α × State) := -x ctx $.run s + x ctx $.run s @[inline] def TacticM.run' {α} (x : TacticM α) (ctx : Context) (s : State) : TermElabM α := -Prod.fst <$> x.run ctx s + Prod.fst <$> x.run ctx s end Lean.Elab.Tactic diff --git a/src/Lean/Elab/Tactic/Rewrite.lean b/src/Lean/Elab/Tactic/Rewrite.lean index 4ceb5f89d5..6c717be7e9 100644 --- a/src/Lean/Elab/Tactic/Rewrite.lean +++ b/src/Lean/Elab/Tactic/Rewrite.lean @@ -12,44 +12,43 @@ import Lean.Elab.Tactic.Location namespace Lean.Elab.Tactic open Meta -@[builtinMacro Lean.Parser.Tactic.rewriteSeq] def expandRewriteTactic : Macro := -fun stx => +@[builtinMacro Lean.Parser.Tactic.rewriteSeq] def expandRewriteTactic : Macro := fun stx => let seq := stx[1][1].getSepArgs let loc := stx[2] pure $ mkNullNode $ seq.map fun rwRule => Syntax.node `Lean.Parser.Tactic.rewrite #[mkAtomFrom rwRule "rewrite ", rwRule, loc] def rewriteTarget (stx : Syntax) (symm : Bool) : TacticM Unit := do -let (g, gs) ← getMainGoal -withMVarContext g do - let e ← elabTerm stx none true - let target ← instantiateMVars (← getMVarDecl g).type - let r ← rewrite g target e symm - let g' ← replaceTargetEq g r.eNew r.eqProof - setGoals (g' :: r.mvarIds ++ gs) + let (g, gs) ← getMainGoal + withMVarContext g do + let e ← elabTerm stx none true + let target ← instantiateMVars (← getMVarDecl g).type + let r ← rewrite g target e symm + let g' ← replaceTargetEq g r.eNew r.eqProof + setGoals (g' :: r.mvarIds ++ gs) def rewriteLocalDeclFVarId (stx : Syntax) (symm : Bool) (fvarId : FVarId) : TacticM Unit := do -let (g, gs) ← getMainGoal -withMVarContext g do - let e ← elabTerm stx none true - let localDecl ← getLocalDecl fvarId - let rwResult ← rewrite g localDecl.type e symm - let replaceResult ← replaceLocalDecl g fvarId rwResult.eNew rwResult.eqProof - setGoals (replaceResult.mvarId :: rwResult.mvarIds ++ gs) + let (g, gs) ← getMainGoal + withMVarContext g do + let e ← elabTerm stx none true + let localDecl ← getLocalDecl fvarId + let rwResult ← rewrite g localDecl.type e symm + let replaceResult ← replaceLocalDecl g fvarId rwResult.eNew rwResult.eqProof + setGoals (replaceResult.mvarId :: rwResult.mvarIds ++ gs) -def rewriteLocalDecl (stx : Syntax) (symm : Bool) (userName : Name) : TacticM Unit := do -withMainMVarContext do - let localDecl ← getLocalDeclFromUserName userName - rewriteLocalDeclFVarId stx symm localDecl.fvarId +def rewriteLocalDecl (stx : Syntax) (symm : Bool) (userName : Name) : TacticM Unit := + withMainMVarContext do + let localDecl ← getLocalDeclFromUserName userName + rewriteLocalDeclFVarId stx symm localDecl.fvarId def rewriteAll (stx : Syntax) (symm : Bool) : TacticM Unit := do -let worked ← «try» $ rewriteTarget stx symm -withMainMVarContext do - -- We must traverse backwards because `replaceLocalDecl` uses the revert/intro idiom - for fvarId in (← getLCtx).getFVarIds.reverse do - worked := worked || (← «try» $ rewriteLocalDeclFVarId stx symm fvarId) - unless worked do - let (mvarId, _) ← getMainGoal - throwTacticEx `rewrite mvarId "did not find instance of the pattern in the current goal" + let worked ← «try» $ rewriteTarget stx symm + withMainMVarContext do + -- We must traverse backwards because `replaceLocalDecl` uses the revert/intro idiom + for fvarId in (← getLCtx).getFVarIds.reverse do + worked := worked || (← «try» $ rewriteLocalDeclFVarId stx symm fvarId) + unless worked do + let (mvarId, _) ← getMainGoal + throwTacticEx `rewrite mvarId "did not find instance of the pattern in the current goal" /- ``` @@ -57,8 +56,7 @@ def rwRule := parser! optional (unicodeSymbol "←" "<-") >> termParser def «rewrite» := parser! "rewrite" >> rwRule >> optional location ``` -/ -@[builtinTactic Lean.Parser.Tactic.rewrite] def evalRewrite : Tactic := -fun stx => do +@[builtinTactic Lean.Parser.Tactic.rewrite] def evalRewrite : Tactic := fun stx => do let rule := stx[1] let symm := !rule[0].isNone let term := rule[1] diff --git a/src/Lean/Elab/Term.lean b/src/Lean/Elab/Term.lean index a6bcd10c89..9c363e8bac 100644 --- a/src/Lean/Elab/Term.lean +++ b/src/Lean/Elab/Term.lean @@ -143,8 +143,7 @@ abbrev TermElab := Syntax → Option Expr → TermElabM Expr open Meta -instance TermElabM.inhabited {α} : Inhabited (TermElabM α) := - ⟨throw $ arbitrary _⟩ +instance {α} : Inhabited (TermElabM α) := ⟨throw $ arbitrary _⟩ structure SavedState := (core : Core.State) @@ -237,7 +236,7 @@ instance : AddErrorMessageContext TermElabM := { pure (ref, msg) } -instance monadLog : MonadLog TermElabM := { +instance : MonadLog TermElabM := { getRef := getRef, getFileMap := do pure (← read).fileMap, getFileName := do pure (← read).fileName, @@ -254,7 +253,7 @@ protected def getMainModule : TermElabM Name := do pure (← getEnv).mainMod let fresh ← modifyGetThe Core.State (fun st => (st.nextMacroScope, { st with nextMacroScope := st.nextMacroScope + 1 })) withReader (fun ctx => { ctx with currMacroScope := fresh }) x -instance monadQuotation : MonadQuotation TermElabM := { +instance : MonadQuotation TermElabM := { getCurrMacroScope := Term.getCurrMacroScope, getMainModule := Term.getMainModule, withFreshMacroScope := Term.withFreshMacroScope @@ -279,8 +278,10 @@ inductive LVal | fieldName (name : String) | getOp (idx : Syntax) -instance LVal.hasToString : HasToString LVal := - ⟨fun p => match p with | LVal.fieldIdx i => toString i | LVal.fieldName n => n | LVal.getOp idx => "[" ++ toString idx ++ "]"⟩ +instance : HasToString LVal := ⟨fun + | LVal.fieldIdx i => toString i + | LVal.fieldName n => n + | LVal.getOp idx => "[" ++ toString idx ++ "]"⟩ instance : MonadResolveName TermElabM := { getCurrNamespace := do pure (← read).currNamespace, diff --git a/src/Lean/Elab/Util.lean b/src/Lean/Elab/Util.lean index fc4fd19906..336f17d3e3 100644 --- a/src/Lean/Elab/Util.lean +++ b/src/Lean/Elab/Util.lean @@ -12,125 +12,125 @@ import Lean.Elab.Exception namespace Lean def Syntax.prettyPrint (stx : Syntax) : Format := -match stx.unsetTrailing.reprint with -- TODO use syntax pretty printer -| some str => format str.toFormat -| none => format stx + match stx.unsetTrailing.reprint with -- TODO use syntax pretty printer + | some str => format str.toFormat + | none => format stx def MacroScopesView.format (view : MacroScopesView) (mainModule : Name) : Format := -fmt $ - if view.scopes.isEmpty then - view.name - else if view.mainModule == mainModule then - view.scopes.foldl mkNameNum (view.name ++ view.imported) - else - view.scopes.foldl mkNameNum (view.name ++ view.imported ++ view.mainModule) + fmt $ + if view.scopes.isEmpty then + view.name + else if view.mainModule == mainModule then + view.scopes.foldl mkNameNum (view.name ++ view.imported) + else + view.scopes.foldl mkNameNum (view.name ++ view.imported ++ view.mainModule) namespace Elab structure MacroStackElem := -(before : Syntax) (after : Syntax) + (before : Syntax) (after : Syntax) abbrev MacroStack := List MacroStackElem /- If `ref` does not have position information, then try to use macroStack -/ def getBetterRef (ref : Syntax) (macroStack : MacroStack) : Syntax := -match ref.getPos with -| some _ => ref -| none => - match macroStack.find? (·.before.getPos != none) with - | some elem => elem.before - | none => ref + match ref.getPos with + | some _ => ref + | none => + match macroStack.find? (·.before.getPos != none) with + | some elem => elem.before + | none => ref def ppMacroStackDefault := false def getMacroStackOption (o : Options) : Bool:= o.get `pp.macroStack ppMacroStackDefault def setMacroStackOption (o : Options) (flag : Bool) : Options := o.setBool `pp.macroStack flag + builtin_initialize registerOption `pp.macroStack { defValue := ppMacroStackDefault, group := "pp", descr := "dispaly macro expansion stack" } def addMacroStack {m} [Monad m] [MonadOptions m] (msgData : MessageData) (macroStack : MacroStack) : m MessageData := do -if !getMacroStackOption (← getOptions) then pure msgData else -match macroStack with -| [] => pure msgData -| stack@(top::_) => - let msgData := msgData ++ Format.line ++ "with resulting expansion" ++ MessageData.nest 2 (Format.line ++ top.after) - pure $ stack.foldl - (fun (msgData : MessageData) (elem : MacroStackElem) => - msgData ++ Format.line ++ "while expanding" ++ MessageData.nest 2 (Format.line ++ elem.before)) - msgData + if !getMacroStackOption (← getOptions) then pure msgData else + match macroStack with + | [] => pure msgData + | stack@(top::_) => + let msgData := msgData ++ Format.line ++ "with resulting expansion" ++ MessageData.nest 2 (Format.line ++ top.after) + pure $ stack.foldl + (fun (msgData : MessageData) (elem : MacroStackElem) => + msgData ++ Format.line ++ "while expanding" ++ MessageData.nest 2 (Format.line ++ elem.before)) + msgData def checkSyntaxNodeKind (k : Name) : AttrM Name := do -if Parser.isValidSyntaxNodeKind (← getEnv) k then pure k -else throwError "failed" + if Parser.isValidSyntaxNodeKind (← getEnv) k then pure k + else throwError "failed" namespace OldFrontend -- TODO: delete private def checkSyntaxNodeKindAtNamespacesAux (k : Name) : List Name → AttrM Name -| [] => throwError "failed" -| n::ns => checkSyntaxNodeKind (n ++ k) <|> checkSyntaxNodeKindAtNamespacesAux k ns + | [] => throwError "failed" + | n::ns => checkSyntaxNodeKind (n ++ k) <|> checkSyntaxNodeKindAtNamespacesAux k ns def checkSyntaxNodeKindAtNamespaces (k : Name) : AttrM Name := do -let env ← getEnv -checkSyntaxNodeKindAtNamespacesAux k (Lean.TODELETE.getNamespaces env) + let env ← getEnv + checkSyntaxNodeKindAtNamespacesAux k (Lean.TODELETE.getNamespaces env) end OldFrontend def checkSyntaxNodeKindAtNamespacesAux (k : Name) : Name → AttrM Name -| n@(Name.str p _ _) => checkSyntaxNodeKind (n ++ k) <|> checkSyntaxNodeKindAtNamespacesAux k p -| _ => throwError "failed" + | n@(Name.str p _ _) => checkSyntaxNodeKind (n ++ k) <|> checkSyntaxNodeKindAtNamespacesAux k p + | _ => throwError "failed" def checkSyntaxNodeKindAtNamespaces (k : Name) : AttrM Name := do -let ctx ← read -checkSyntaxNodeKindAtNamespacesAux k ctx.currNamespace + let ctx ← read + checkSyntaxNodeKindAtNamespacesAux k ctx.currNamespace def syntaxNodeKindOfAttrParam (defaultParserNamespace : Name) (arg : Syntax) : AttrM SyntaxNodeKind := -match attrParamSyntaxToIdentifier arg with -| some k => - checkSyntaxNodeKind k - <|> - checkSyntaxNodeKindAtNamespaces k - <|> - OldFrontend.checkSyntaxNodeKindAtNamespaces k -- TODO: delete the following old frontend support code - <|> - checkSyntaxNodeKind (defaultParserNamespace ++ k) - <|> - throwError! "invalid syntax node kind '{k}'" -| none => throwError "syntax node kind is missing" + match attrParamSyntaxToIdentifier arg with + | some k => + checkSyntaxNodeKind k + <|> + checkSyntaxNodeKindAtNamespaces k + <|> + OldFrontend.checkSyntaxNodeKindAtNamespaces k -- TODO: delete the following old frontend support code + <|> + checkSyntaxNodeKind (defaultParserNamespace ++ k) + <|> + throwError! "invalid syntax node kind '{k}'" + | none => throwError "syntax node kind is missing" private unsafe def evalSyntaxConstantUnsafe (env : Environment) (opts : Options) (constName : Name) : ExceptT String Id Syntax := -env.evalConstCheck Syntax opts `Lean.Syntax constName + env.evalConstCheck Syntax opts `Lean.Syntax constName @[implementedBy evalSyntaxConstantUnsafe] constant evalSyntaxConstant (env : Environment) (opts : Options) (constName : Name) : ExceptT String Id Syntax := throw "" unsafe def mkElabAttribute (γ) (attrDeclName attrBuiltinName attrName : Name) (parserNamespace : Name) (typeName : Name) (kind : String) : IO (KeyedDeclsAttribute γ) := -KeyedDeclsAttribute.init { - builtinName := attrBuiltinName, - name := attrName, - descr := kind ++ " elaborator", - valueTypeName := typeName, - evalKey := fun _ arg => syntaxNodeKindOfAttrParam parserNamespace arg, -} attrDeclName + KeyedDeclsAttribute.init { + builtinName := attrBuiltinName, + name := attrName, + descr := kind ++ " elaborator", + valueTypeName := typeName, + evalKey := fun _ arg => syntaxNodeKindOfAttrParam parserNamespace arg, + } attrDeclName unsafe def mkMacroAttributeUnsafe : IO (KeyedDeclsAttribute Macro) := -mkElabAttribute Macro `Lean.Elab.macroAttribute `builtinMacro `macro Name.anonymous `Lean.Macro "macro" + mkElabAttribute Macro `Lean.Elab.macroAttribute `builtinMacro `macro Name.anonymous `Lean.Macro "macro" @[implementedBy mkMacroAttributeUnsafe] -constant mkMacroAttribute : IO (KeyedDeclsAttribute Macro) := arbitrary _ +constant mkMacroAttribute : IO (KeyedDeclsAttribute Macro) builtin_initialize macroAttribute : KeyedDeclsAttribute Macro ← mkMacroAttribute private def expandMacroFns (stx : Syntax) : List Macro → MacroM Syntax -| [] => throw Macro.Exception.unsupportedSyntax -| m::ms => do - try - m stx - catch - | Macro.Exception.unsupportedSyntax => expandMacroFns stx ms - | ex => throw ex + | [] => throw Macro.Exception.unsupportedSyntax + | m::ms => do + try + m stx + catch + | Macro.Exception.unsupportedSyntax => expandMacroFns stx ms + | ex => throw ex -def getMacros (env : Environment) : Macro := -fun stx => +def getMacros (env : Environment) : Macro := fun stx => let k := stx.getKind let table := (macroAttribute.ext.getState env).table match table.find? k with @@ -138,38 +138,39 @@ fun stx => | none => throw Macro.Exception.unsupportedSyntax class MonadMacroAdapter (m : Type → Type) := -(getCurrMacroScope : m MacroScope) -(getNextMacroScope : m MacroScope) -(setNextMacroScope : MacroScope → m Unit) + (getCurrMacroScope : m MacroScope) + (getNextMacroScope : m MacroScope) + (setNextMacroScope : MacroScope → m Unit) -instance monadMacroAdapterTrans (m n) [MonadMacroAdapter m] [MonadLift m n] : MonadMacroAdapter n := -{ getCurrMacroScope := liftM (MonadMacroAdapter.getCurrMacroScope : m _), +instance (m n) [MonadMacroAdapter m] [MonadLift m n] : MonadMacroAdapter n := { + getCurrMacroScope := liftM (MonadMacroAdapter.getCurrMacroScope : m _), getNextMacroScope := liftM (MonadMacroAdapter.getNextMacroScope : m _), - setNextMacroScope := fun s => liftM (MonadMacroAdapter.setNextMacroScope s : m _) } + setNextMacroScope := fun s => liftM (MonadMacroAdapter.setNextMacroScope s : m _) +} private def expandMacro? (env : Environment) (stx : Syntax) : MacroM (Option Syntax) := do -try - let newStx ← getMacros env stx - pure (some newStx) -catch - | Macro.Exception.unsupportedSyntax => pure none - | ex => throw ex + try + let newStx ← getMacros env stx + pure (some newStx) + catch + | Macro.Exception.unsupportedSyntax => pure none + | ex => throw ex @[inline] def liftMacroM {α} {m : Type → Type} [Monad m] [MonadMacroAdapter m] [MonadEnv m] [MonadRecDepth m] [MonadExceptOf Exception m] [Ref m] [AddErrorMessageContext m] (x : MacroM α) : m α := do -let env ← getEnv -match x { macroEnv := Macro.mkMacroEnv (expandMacro? env), - currMacroScope := ← MonadMacroAdapter.getCurrMacroScope, - mainModule := env.mainModule, - currRecDepth := ← MonadRecDepth.getRecDepth, - maxRecDepth := ← MonadRecDepth.getMaxRecDepth } (← MonadMacroAdapter.getNextMacroScope) with -| EStateM.Result.error Macro.Exception.unsupportedSyntax _ => throwUnsupportedSyntax -| EStateM.Result.error (Macro.Exception.error ref msg) _ => throwErrorAt ref msg -| EStateM.Result.ok a nextMacroScope => MonadMacroAdapter.setNextMacroScope nextMacroScope; pure a + let env ← getEnv + match x { macroEnv := Macro.mkMacroEnv (expandMacro? env), + currMacroScope := ← MonadMacroAdapter.getCurrMacroScope, + mainModule := env.mainModule, + currRecDepth := ← MonadRecDepth.getRecDepth, + maxRecDepth := ← MonadRecDepth.getMaxRecDepth } (← MonadMacroAdapter.getNextMacroScope) with + | EStateM.Result.error Macro.Exception.unsupportedSyntax _ => throwUnsupportedSyntax + | EStateM.Result.error (Macro.Exception.error ref msg) _ => throwErrorAt ref msg + | EStateM.Result.ok a nextMacroScope => MonadMacroAdapter.setNextMacroScope nextMacroScope; pure a @[inline] def adaptMacro {m : Type → Type} [Monad m] [MonadMacroAdapter m] [MonadEnv m] [MonadRecDepth m] [MonadExceptOf Exception m] [Ref m] [AddErrorMessageContext m] (x : Macro) (stx : Syntax) : m Syntax := -liftMacroM (x stx) + liftMacroM (x stx) builtin_initialize registerTraceClass `Elab diff --git a/src/Lean/Environment.lean b/src/Lean/Environment.lean index 6695da4e04..0061316c04 100644 --- a/src/Lean/Environment.lean +++ b/src/Lean/Environment.lean @@ -712,7 +712,7 @@ class MonadEnv (m : Type → Type) := export MonadEnv (getEnv modifyEnv) -instance monadEnvFromLift (m n) [MonadEnv m] [MonadLift m n] : MonadEnv n := { +instance (m n) [MonadEnv m] [MonadLift m n] : MonadEnv n := { getEnv := liftM (getEnv : m Environment), modifyEnv := fun f => liftM (modifyEnv f : m Unit) } diff --git a/src/Lean/Eval.lean b/src/Lean/Eval.lean index e1d6f20cd1..037b29a13f 100644 --- a/src/Lean/Eval.lean +++ b/src/Lean/Eval.lean @@ -16,7 +16,7 @@ universe u class MetaHasEval (α : Type u) := (eval : Environment → Options → α → (hideUnit : Bool) → IO Environment) -instance metaHasEvalOfHasEval {α : Type u} [HasEval α] : MetaHasEval α := +instance {α : Type u} [HasEval α] : MetaHasEval α := ⟨fun env opts a hideUnit => do HasEval.eval (fun _ => a) hideUnit; pure env⟩ def runMetaEval {α : Type u} [MetaHasEval α] (env : Environment) (opts : Options) (a : α) : IO (String × Except IO.Error Environment) := diff --git a/src/Lean/Exception.lean b/src/Lean/Exception.lean index c77a7187e9..105bccf5f9 100644 --- a/src/Lean/Exception.lean +++ b/src/Lean/Exception.lean @@ -53,7 +53,7 @@ def replaceRef (ref : Syntax) (oldRef : Syntax) : Syntax := class AddErrorMessageContext (m : Type → Type) := (add : Syntax → MessageData → m (Syntax × MessageData)) -instance addErrorMessageContextDefault (m : Type → Type) [AddMessageContext m] [Monad m] : AddErrorMessageContext m := { +instance (m : Type → Type) [AddMessageContext m] [Monad m] : AddErrorMessageContext m := { add := fun ref msg => do msg ← addMessageContext msg pure (ref, msg) diff --git a/src/Lean/Expr.lean b/src/Lean/Expr.lean index 69d2e7ae6a..dd1fc11691 100644 --- a/src/Lean/Expr.lean +++ b/src/Lean/Expr.lean @@ -13,20 +13,20 @@ inductive Literal | natVal (val : Nat) | strVal (val : String) -instance Literal.inhabited : Inhabited Literal := ⟨Literal.natVal 0⟩ +instance : Inhabited Literal := ⟨Literal.natVal 0⟩ protected def Literal.hash : Literal → USize | Literal.natVal v => hash v | Literal.strVal v => hash v -instance Literal.hashable : Hashable Literal := ⟨Literal.hash⟩ +instance : Hashable Literal := ⟨Literal.hash⟩ def Literal.beq : Literal → Literal → Bool | Literal.natVal v₁, Literal.natVal v₂ => v₁ == v₂ | Literal.strVal v₁, Literal.strVal v₂ => v₁ == v₂ | _, _ => false -instance Literal.hasBeq : HasBeq Literal := ⟨Literal.beq⟩ +instance : HasBeq Literal := ⟨Literal.beq⟩ def Literal.lt : Literal → Literal → Bool | Literal.natVal _, Literal.strVal _ => true @@ -34,9 +34,9 @@ def Literal.lt : Literal → Literal → Bool | Literal.strVal v₁, Literal.strVal v₂ => v₁ < v₂ | _, _ => false -instance Literal.hasLess : HasLess Literal := ⟨fun a b => a.lt b⟩ +instance : HasLess Literal := ⟨fun a b => a.lt b⟩ -instance Literal.decLess (a b : Literal) : Decidable (a < b) := +instance (a b : Literal) : Decidable (a < b) := inferInstanceAs (Decidable (a.lt b)) inductive BinderInfo @@ -55,9 +55,9 @@ def BinderInfo.isExplicit : BinderInfo → Bool | BinderInfo.instImplicit => false | _ => true -instance BinderInfo.hashable : Hashable BinderInfo := ⟨BinderInfo.hash⟩ +instance : Hashable BinderInfo := ⟨BinderInfo.hash⟩ -instance BinderInfo.inhabited : Inhabited BinderInfo := ⟨BinderInfo.default⟩ +instance : Inhabited BinderInfo := ⟨BinderInfo.default⟩ def BinderInfo.isInstImplicit : BinderInfo → Bool | BinderInfo.instImplicit => true @@ -75,7 +75,7 @@ protected def BinderInfo.beq : BinderInfo → BinderInfo → Bool | BinderInfo.auxDecl, BinderInfo.auxDecl => true | _, _ => false -instance BinderInfo.hasBeq : HasBeq BinderInfo := ⟨BinderInfo.beq⟩ +instance : HasBeq BinderInfo := ⟨BinderInfo.beq⟩ abbrev MData := KVMap abbrev MData.empty : MData := {} @@ -92,13 +92,13 @@ abbrev MData.empty : MData := {} looseBVarRange : 24-bits -/ def Expr.Data := UInt64 -instance Expr.Data.inhabited : Inhabited Expr.Data := +instance: Inhabited Expr.Data := inferInstanceAs (Inhabited UInt64) def Expr.Data.hash (c : Expr.Data) : USize := c.toUInt32.toUSize -instance Expr.Data.hasBeq : HasBeq Expr.Data := +instance : HasBeq Expr.Data := ⟨fun (a b : UInt64) => a == b⟩ def Expr.Data.looseBVarRange (c : Expr.Data) : UInt32 := diff --git a/src/Lean/Hygiene.lean b/src/Lean/Hygiene.lean index 2c1604994f..f03bf51cb6 100644 --- a/src/Lean/Hygiene.lean +++ b/src/Lean/Hygiene.lean @@ -24,7 +24,7 @@ namespace Lean corresponding to `withFreshMacroScope` calls. -/ abbrev Unhygienic := ReaderT MacroScope $ StateM MacroScope namespace Unhygienic -instance MonadQuotation : MonadQuotation Unhygienic := { +instance : MonadQuotation Unhygienic := { getCurrMacroScope := read, getMainModule := pure `UnhygienicMain, withFreshMacroScope := fun x => do diff --git a/src/Lean/KeyedDeclsAttribute.lean b/src/Lean/KeyedDeclsAttribute.lean index f953c5819d..e1a5806c1a 100644 --- a/src/Lean/KeyedDeclsAttribute.lean +++ b/src/Lean/KeyedDeclsAttribute.lean @@ -39,7 +39,7 @@ structure Def (γ : Type) := | some id => pure id | none => throwError "invalid attribute argument, expected identifier") -instance Def.inhabited {γ} : Inhabited (Def γ) := +instance {γ} : Inhabited (Def γ) := ⟨{ builtinName := arbitrary _, name := arbitrary _, descr := arbitrary _, valueTypeName := arbitrary _ }⟩ structure OLeanEntry := @@ -75,10 +75,9 @@ match table.find? k with | some vs => SMap.insert table k (v::vs) | none => SMap.insert table k [v] -instance ExtensionState.inhabited {γ} : Inhabited (ExtensionState γ) := -⟨{}⟩ +instance {γ} : Inhabited (ExtensionState γ) := ⟨{}⟩ -instance KeyedDeclsAttribute.inhabited {γ} : Inhabited (KeyedDeclsAttribute γ) := +instance {γ} : Inhabited (KeyedDeclsAttribute γ) := ⟨{ defn := arbitrary _, tableRef := arbitrary _, ext := arbitrary _ }⟩ private def mkInitial {γ} (tableRef : IO.Ref (Table γ)) : IO (ExtensionState γ) := do diff --git a/src/Lean/Meta/Basic.lean b/src/Lean/Meta/Basic.lean index 3aa3ec8f8f..deaa0f5a34 100644 --- a/src/Lean/Meta/Basic.lean +++ b/src/Lean/Meta/Basic.lean @@ -61,7 +61,7 @@ structure ParamInfo := (hasFwdDeps : Bool := false) (backDeps : Array Nat := #[]) -instance ParamInfo.inhabited : Inhabited ParamInfo := ⟨{}⟩ +instance : Inhabited ParamInfo := ⟨{}⟩ def ParamInfo.isExplicit (p : ParamInfo) : Bool := !p.implicit && p.instImplicit diff --git a/src/Lean/Meta/Closure.lean b/src/Lean/Meta/Closure.lean index 43fcf7ae40..d657bc2f26 100644 --- a/src/Lean/Meta/Closure.lean +++ b/src/Lean/Meta/Closure.lean @@ -101,7 +101,7 @@ namespace Closure structure ToProcessElement := (fvarId : FVarId) (newFVarId : FVarId) -instance ToProcessElement.inhabited : Inhabited ToProcessElement := +instance : Inhabited ToProcessElement := ⟨⟨arbitrary _, arbitrary _⟩⟩ structure Context := diff --git a/src/Lean/Meta/SynthInstance.lean b/src/Lean/Meta/SynthInstance.lean index 5d78e2884e..dcd0864ffc 100644 --- a/src/Lean/Meta/SynthInstance.lean +++ b/src/Lean/Meta/SynthInstance.lean @@ -161,7 +161,7 @@ abbrev SynthM := StateRefT State MetaM @[inline] def mapMetaM (f : forall {α}, MetaM α → MetaM α) {α} : SynthM α → SynthM α := monadMap @f -instance SynthM.inhabited {α} : Inhabited (SynthM α) := +instance {α} : Inhabited (SynthM α) := ⟨fun _ => arbitrary _⟩ /-- Return globals and locals instances that may unify with `type` -/ diff --git a/src/Lean/ToExpr.lean b/src/Lean/ToExpr.lean index 5d9dd5e130..f584441e9a 100644 --- a/src/Lean/ToExpr.lean +++ b/src/Lean/ToExpr.lean @@ -15,66 +15,73 @@ class ToExpr (α : Type u) := export ToExpr (toExpr toTypeExpr) -instance : ToExpr Expr := -{ toExpr := id, - toTypeExpr := mkConst `Expr } +instance : ToExpr Expr := { + toExpr := id, + toTypeExpr := mkConst `Expr +} -instance : ToExpr Nat := -{ toExpr := mkNatLit, - toTypeExpr := mkConst `Nat } +instance : ToExpr Nat := { + toExpr := mkNatLit, + toTypeExpr := mkConst `Nat +} -instance : ToExpr Bool := -{ toExpr := fun b => if b then mkConst `Bool.true else mkConst `Bool.false, - toTypeExpr := mkConst `Bool } +instance : ToExpr Bool := { + toExpr := fun b => if b then mkConst `Bool.true else mkConst `Bool.false, + toTypeExpr := mkConst `Bool +} -instance : ToExpr Char := -{ toExpr := fun c => mkApp (mkConst `Char.ofNat) (toExpr c.toNat), - toTypeExpr := mkConst `Char } +instance : ToExpr Char := { + toExpr := fun c => mkApp (mkConst `Char.ofNat) (toExpr c.toNat), + toTypeExpr := mkConst `Char +} -instance : ToExpr String := -{ toExpr := mkStrLit, - toTypeExpr := mkConst `String } +instance : ToExpr String := { + toExpr := mkStrLit, + toTypeExpr := mkConst `String +} -instance : ToExpr Unit := -{ toExpr := fun _ => mkConst `Unit.unit, - toTypeExpr := mkConst `Unit } +instance : ToExpr Unit := { + toExpr := fun _ => mkConst `Unit.unit, + toTypeExpr := mkConst `Unit +} def Name.toExprAux : Name → Expr -| Name.anonymous => mkConst `Lean.Name.anonymous -| Name.str p s _ => mkAppB (mkConst `Lean.mkNameStr) (toExprAux p) (toExpr s) -| Name.num p n _ => mkAppB (mkConst `Lean.mkNameNum) (toExprAux p) (toExpr n) + | Name.anonymous => mkConst `Lean.Name.anonymous + | Name.str p s _ => mkAppB (mkConst `Lean.mkNameStr) (toExprAux p) (toExpr s) + | Name.num p n _ => mkAppB (mkConst `Lean.mkNameNum) (toExprAux p) (toExpr n) -instance nameToExpr : ToExpr Name := -{ toExpr := Name.toExprAux, - toTypeExpr := mkConst `Name } +instance : ToExpr Name := { + toExpr := Name.toExprAux, + toTypeExpr := mkConst `Name +} -instance optionToExpr {α : Type} [ToExpr α] : ToExpr (Option α) := -let type := toTypeExpr α -{ toExpr := fun o => match o with - | none => mkApp (mkConst `Option.none [levelZero]) type - | some a => mkApp2 (mkConst `Option.cons [levelZero]) type (toExpr a), - toTypeExpr := mkApp (mkConst `Option [levelZero]) type } +instance {α : Type} [ToExpr α] : ToExpr (Option α) := + let type := toTypeExpr α + { toExpr := fun o => match o with + | none => mkApp (mkConst `Option.none [levelZero]) type + | some a => mkApp2 (mkConst `Option.cons [levelZero]) type (toExpr a), + toTypeExpr := mkApp (mkConst `Option [levelZero]) type } def List.toExprAux {α} [ToExpr α] (nilFn : Expr) (consFn : Expr) : List α → Expr -| [] => nilFn -| a::as => mkApp2 consFn (toExpr a) (toExprAux nilFn consFn as) + | [] => nilFn + | a::as => mkApp2 consFn (toExpr a) (toExprAux nilFn consFn as) -instance listToExpr {α : Type} [ToExpr α] : ToExpr (List α) := -let type := toTypeExpr α -let nil := mkApp (mkConst `List.nil [levelZero]) type -let cons := mkApp (mkConst `List.cons [levelZero]) type -{ toExpr := List.toExprAux nil cons, - toTypeExpr := mkApp (mkConst `List [levelZero]) type } +instance {α : Type} [ToExpr α] : ToExpr (List α) := + let type := toTypeExpr α + let nil := mkApp (mkConst `List.nil [levelZero]) type + let cons := mkApp (mkConst `List.cons [levelZero]) type + { toExpr := List.toExprAux nil cons, + toTypeExpr := mkApp (mkConst `List [levelZero]) type } -instance arrayToExpr {α : Type} [ToExpr α] : ToExpr (Array α) := -let type := toTypeExpr α -{ toExpr := fun as => mkApp2 (mkConst `List.toArray [levelZero]) type (toExpr as.toList), - toTypeExpr := mkApp (mkConst `Array [levelZero]) type } +instance {α : Type} [ToExpr α] : ToExpr (Array α) := + let type := toTypeExpr α + { toExpr := fun as => mkApp2 (mkConst `List.toArray [levelZero]) type (toExpr as.toList), + toTypeExpr := mkApp (mkConst `Array [levelZero]) type } -instance prodToExpr {α : Type} {β : Type} [ToExpr α] [ToExpr β] : ToExpr (α × β) := -let αType := toTypeExpr α -let βType := toTypeExpr β -{ toExpr := fun ⟨a, b⟩ => mkApp4 (mkConst `Prod.mk [levelZero, levelZero]) αType βType (toExpr a) (toExpr b), - toTypeExpr := mkApp2 (mkConst `Prod [levelZero, levelZero]) αType βType } +instance {α : Type} {β : Type} [ToExpr α] [ToExpr β] : ToExpr (α × β) := + let αType := toTypeExpr α + let βType := toTypeExpr β + { toExpr := fun ⟨a, b⟩ => mkApp4 (mkConst `Prod.mk [levelZero, levelZero]) αType βType (toExpr a) (toExpr b), + toTypeExpr := mkApp2 (mkConst `Prod [levelZero, levelZero]) αType βType } end Lean