feat: field projections
This commit is contained in:
parent
32cebc3e76
commit
e25bd36dc5
2 changed files with 122 additions and 38 deletions
|
|
@ -450,15 +450,24 @@ fun stx expectedType? => do
|
|||
def elabExplicitUniv (stx : Syntax) : TermElabM (List Level) :=
|
||||
pure [] -- TODO
|
||||
|
||||
inductive Arg
|
||||
| stx (val : Syntax)
|
||||
| expr (val : Expr)
|
||||
|
||||
instance Arg.inhabited : Inhabited Arg := ⟨Arg.stx (arbitrary _)⟩
|
||||
|
||||
instance Arg.hasToString : HasToString Arg :=
|
||||
⟨fun arg => match arg with
|
||||
| Arg.stx val => toString val
|
||||
| Arg.expr val => toString val⟩
|
||||
|
||||
structure NamedArg :=
|
||||
(name : Name)
|
||||
(val : Syntax)
|
||||
(stx : Syntax)
|
||||
(name : Name) (val : Arg)
|
||||
|
||||
instance NamedArg.hasToString : HasToString NamedArg :=
|
||||
⟨fun s => "(" ++ toString s.name ++ " := " ++ toString s.val ++ ")"⟩
|
||||
|
||||
instance NamedArg.inhabited : Inhabited NamedArg := ⟨{ name := arbitrary _, val := arbitrary _, stx := arbitrary _ }⟩
|
||||
instance NamedArg.inhabited : Inhabited NamedArg := ⟨{ name := arbitrary _, val := arbitrary _ }⟩
|
||||
|
||||
def addNamedArg (namedArgs : Array NamedArg) (namedArg : NamedArg) (ref : Syntax) : TermElabM (Array NamedArg) := do
|
||||
when (namedArgs.any $ fun namedArg' => namedArg.name == namedArg'.name) $
|
||||
|
|
@ -479,20 +488,28 @@ pure $ resolveLocalNameAux lctx n []
|
|||
private def mkFreshLevelMVars (ref : Syntax) (num : Nat) : TermElabM (List Level) :=
|
||||
num.foldM (fun _ us => do u ← mkFreshLevelMVar ref; pure $ u::us) []
|
||||
|
||||
def mkConst (ref : Syntax) (constName : Name) (explicitLevels : List Level := []) : TermElabM Expr := do
|
||||
env ← getEnv;
|
||||
match env.find constName with
|
||||
| none => throwError ref ("unknown constant '" ++ constName ++ "'")
|
||||
| some cinfo =>
|
||||
if explicitLevels.length > cinfo.lparams.length then
|
||||
throwError ref ("too many explicit universe levels")
|
||||
else do
|
||||
let numMissingLevels := cinfo.lparams.length - explicitLevels.length;
|
||||
us ← mkFreshLevelMVars ref numMissingLevels;
|
||||
pure $ Lean.mkConst constName (explicitLevels ++ us)
|
||||
|
||||
private def mkConsts (ref : Syntax) (candidates : List (Name × List String)) (explicitLevels : List Level) : TermElabM (List (Expr × List String)) := do
|
||||
env ← getEnv;
|
||||
candidates.foldlM
|
||||
(fun result ⟨constName, projs⟩ =>
|
||||
match env.find constName with
|
||||
| none => unreachable!
|
||||
| some cinfo =>
|
||||
if explicitLevels.length > cinfo.lparams.length then
|
||||
-- Remark: we discard candidate because of the number of explicit universe levels provided.
|
||||
pure result
|
||||
else do
|
||||
let numMissingLevels := cinfo.lparams.length - explicitLevels.length;
|
||||
us ← mkFreshLevelMVars ref numMissingLevels;
|
||||
pure $ (mkConst constName (explicitLevels ++ us), projs) :: result)
|
||||
catch
|
||||
(do const ← mkConst ref constName explicitLevels;
|
||||
pure $ (const, projs) :: result)
|
||||
(fun _ =>
|
||||
-- Remark: we discard candidates based on the number of explicit universe levels provided.
|
||||
pure result))
|
||||
[]
|
||||
|
||||
def resolveName (n : Name) (preresolved : List (Name × List String)) (explicitLevels : List Level) (ref : Syntax) : TermElabM (List (Expr × List String)) := do
|
||||
|
|
@ -556,7 +573,15 @@ condM (isExprMVarAssigned instMVar) (pure ()) $ do
|
|||
def synthesizeInstMVars (ref : Syntax) (instMVars : Array MVarId) : TermElabM Unit :=
|
||||
instMVars.forM $ synthesizeInstMVar ref
|
||||
|
||||
private partial def elabAppArgsAux (ref : Syntax) (args : Array Syntax) (expectedType? : Option Expr) (explicit : Bool)
|
||||
private def elabArg (ref : Syntax) (arg : Arg) (expectedType : Expr) : TermElabM Expr :=
|
||||
match arg with
|
||||
| Arg.expr val => do
|
||||
valType ← inferType ref val;
|
||||
ensureHasType ref expectedType valType val
|
||||
| Arg.stx val =>
|
||||
elabTerm val expectedType
|
||||
|
||||
private partial def elabAppArgsAux (ref : Syntax) (args : Array Arg) (expectedType? : Option Expr) (explicit : Bool)
|
||||
: Nat → Array NamedArg → Array MVarId → Expr → Expr → TermElabM Expr
|
||||
| argIdx, namedArgs, instMVars, eType, e => do
|
||||
let finalize : Unit → TermElabM Expr := fun _ => do {
|
||||
|
|
@ -571,15 +596,15 @@ private partial def elabAppArgsAux (ref : Syntax) (args : Array Syntax) (expecte
|
|||
| Expr.forallE n d b c =>
|
||||
match namedArgs.findIdx? (fun namedArg => namedArg.name == n) with
|
||||
| some idx => do
|
||||
let arg := namedArgs.get! idx;
|
||||
let arg := namedArgs.get! idx;
|
||||
let namedArgs := namedArgs.eraseIdx idx;
|
||||
a ← elabTerm arg.val d;
|
||||
elabAppArgsAux argIdx namedArgs instMVars (b.instantiate1 a) (mkApp e a)
|
||||
argElab ← elabArg ref arg.val d;
|
||||
elabAppArgsAux argIdx namedArgs instMVars (b.instantiate1 argElab) (mkApp e argElab)
|
||||
| none =>
|
||||
let processExplictArg : Unit → TermElabM Expr := fun _ => do {
|
||||
if h : argIdx < args.size then do
|
||||
a ← elabTerm (args.get ⟨argIdx, h⟩) d;
|
||||
elabAppArgsAux (argIdx + 1) namedArgs instMVars (b.instantiate1 a) (mkApp e a)
|
||||
argElab ← elabArg ref (args.get ⟨argIdx, h⟩) d;
|
||||
elabAppArgsAux (argIdx + 1) namedArgs instMVars (b.instantiate1 argElab) (mkApp e argElab)
|
||||
else if namedArgs.isEmpty then
|
||||
finalize ()
|
||||
else
|
||||
|
|
@ -605,7 +630,7 @@ private partial def elabAppArgsAux (ref : Syntax) (args : Array Syntax) (expecte
|
|||
-- TODO: try `HasCoeToFun`
|
||||
throwError ref "too many arguments"
|
||||
|
||||
private def elabAppArgs (ref : Syntax) (f : Expr) (namedArgs : Array NamedArg) (args : Array Syntax)
|
||||
private def elabAppArgs (ref : Syntax) (f : Expr) (namedArgs : Array NamedArg) (args : Array Arg)
|
||||
(expectedType? : Option Expr) (explicit : Bool) : TermElabM Expr := do
|
||||
fType ← inferType ref f;
|
||||
let argIdx := 0;
|
||||
|
|
@ -613,7 +638,7 @@ let instMVars := #[];
|
|||
elabAppArgsAux ref args expectedType? explicit argIdx namedArgs instMVars fType f
|
||||
|
||||
inductive FieldResolution
|
||||
| projFn (fieldName : Name) (baseStructName : Name) (structName : Name)
|
||||
| projFn (baseStructName : Name) (structName : Name) (fieldName : Name)
|
||||
| projIdx (structName : Name) (idx : Nat)
|
||||
| const (constName : Name)
|
||||
| localRec (fvar : Expr)
|
||||
|
|
@ -632,7 +657,7 @@ match eType.getAppFn, field with
|
|||
let fieldNames := getStructureFields env structName;
|
||||
if h : idx - 1 < fieldNames.size then
|
||||
if isStructure env structName then
|
||||
pure $ FieldResolution.projFn (fieldNames.get ⟨idx - 1, h⟩) structName structName
|
||||
pure $ FieldResolution.projFn structName structName (fieldNames.get ⟨idx - 1, h⟩)
|
||||
else
|
||||
/- `structName` was declared using `inductive` command.
|
||||
So, we don't projection functions for it. Thus, we use `Expr.proj` -/
|
||||
|
|
@ -663,7 +688,7 @@ match eType.getAppFn, field with
|
|||
};
|
||||
if isStructure env structName then
|
||||
match findField? env structName fieldName with
|
||||
| some baseStructName => pure $ FieldResolution.projFn fieldName baseStructName structName
|
||||
| some baseStructName => pure $ FieldResolution.projFn baseStructName structName fieldName
|
||||
| none => searchLCtx ()
|
||||
else
|
||||
searchLCtx ()
|
||||
|
|
@ -685,20 +710,45 @@ private def resolveField (ref : Syntax) (e : Expr) (field : Field) : TermElabM F
|
|||
eType ← inferType ref e;
|
||||
resolveFieldLoop ref e field eType #[]
|
||||
|
||||
private def elabAppFieldsAux (ref : Syntax) (namedArgs : Array NamedArg) (args : Array Syntax) (expectedType? : Option Expr) (explicit : Bool)
|
||||
private partial def mkBaseProjections (ref : Syntax) (baseStructName : Name) (structName : Name) (e : Expr) : TermElabM Expr := do
|
||||
env ← getEnv;
|
||||
match getPathToBaseStructure? env baseStructName structName with
|
||||
| none => throwError ref "failed to access field in parent structure"
|
||||
| some path =>
|
||||
path.foldlM
|
||||
(fun e projFunName => do
|
||||
projFn ← mkConst ref projFunName;
|
||||
elabAppArgs ref projFn #[{ name := `self, val := Arg.expr e }] #[] none false)
|
||||
e
|
||||
|
||||
private def elabAppFieldsAux (ref : Syntax) (namedArgs : Array NamedArg) (args : Array Arg) (expectedType? : Option Expr) (explicit : Bool)
|
||||
: Expr → List Field → TermElabM Expr
|
||||
| f, [] => elabAppArgs ref f namedArgs args expectedType? explicit
|
||||
| f, field::fields => do
|
||||
fType ← inferType ref f;
|
||||
-- TODO
|
||||
elabAppArgs ref f namedArgs args expectedType? explicit
|
||||
fieldRes ← resolveField ref f field;
|
||||
match fieldRes with
|
||||
| FieldResolution.projIdx structName idx =>
|
||||
let f := mkProj structName idx f;
|
||||
elabAppFieldsAux f fields
|
||||
| FieldResolution.projFn baseStructName structName fieldName => do
|
||||
f ← mkBaseProjections ref baseStructName structName f;
|
||||
projFn ← mkConst ref (baseStructName ++ fieldName);
|
||||
if fields.isEmpty then do
|
||||
namedArgs ← addNamedArg namedArgs { name := `self, val := Arg.expr f } ref;
|
||||
elabAppArgs ref projFn namedArgs args expectedType? explicit
|
||||
else do
|
||||
f ← elabAppArgs ref projFn #[{ name := `self, val := Arg.expr f }] #[] none false;
|
||||
elabAppFieldsAux f fields
|
||||
| _ =>
|
||||
-- TODO
|
||||
elabAppArgs ref f namedArgs args expectedType? explicit
|
||||
|
||||
private def elabAppFields (ref : Syntax) (f : Expr) (fields : List Field) (namedArgs : Array NamedArg) (args : Array Syntax)
|
||||
private def elabAppFields (ref : Syntax) (f : Expr) (fields : List Field) (namedArgs : Array NamedArg) (args : Array Arg)
|
||||
(expectedType? : Option Expr) (explicit : Bool) : TermElabM Expr := do
|
||||
when (!fields.isEmpty && explicit) $ throwError ref "invalid use of projection notation with `@` modifier";
|
||||
when (!fields.isEmpty && explicit) $ throwError ref "invalid use of field notation with `@` modifier";
|
||||
elabAppFieldsAux ref namedArgs args expectedType? explicit f fields
|
||||
|
||||
private partial def elabAppFn (ref : Syntax) : Syntax → List Field → Array NamedArg → Array Syntax → Option Expr → Bool → Array TermElabResult → TermElabM (Array TermElabResult)
|
||||
private partial def elabAppFn (ref : Syntax) : Syntax → List Field → Array NamedArg → Array Arg → Option Expr → Bool → Array TermElabResult → TermElabM (Array TermElabResult)
|
||||
| f, fields, namedArgs, args, expectedType?, explicit, acc =>
|
||||
let k := f.getKind;
|
||||
if k == `Lean.Parser.Term.explicit then
|
||||
|
|
@ -710,8 +760,8 @@ private partial def elabAppFn (ref : Syntax) : Syntax → List Field → Array N
|
|||
-- term `.` (fieldIdx <|> ident)
|
||||
let field := f.getArg 2;
|
||||
match field.isFieldIdx?, field with
|
||||
| some idx, _ => elabAppFn (f.getArg 0) (Field.num idx :: fields) namedArgs args expectedType? true acc
|
||||
| _, Syntax.ident _ val _ _ => elabAppFn (f.getArg 0) (Field.str val.toString :: fields) namedArgs args expectedType? true acc
|
||||
| some idx, _ => elabAppFn (f.getArg 0) (Field.num idx :: fields) namedArgs args expectedType? explicit acc
|
||||
| _, Syntax.ident _ val _ _ => elabAppFn (f.getArg 0) (Field.str val.toString :: fields) namedArgs args expectedType? explicit acc
|
||||
| _, _ => throwError field "unexpected kind of field access"
|
||||
else if k == `Lean.Parser.Term.id then
|
||||
-- ident (explicitUniv | namedPattern)?
|
||||
|
|
@ -752,7 +802,7 @@ msgs ← failures.mapM $ fun failure =>
|
|||
| EStateM.Result.error ex s => toMessageData ex stx;
|
||||
throwError stx ("overloaded, errors " ++ MessageData.ofArray msgs)
|
||||
|
||||
private def elabAppAux (ref : Syntax) (f : Syntax) (namedArgs : Array NamedArg) (args : Array Syntax) (expectedType? : Option Expr) : TermElabM Expr := do
|
||||
private def elabAppAux (ref : Syntax) (f : Syntax) (namedArgs : Array NamedArg) (args : Array Arg) (expectedType? : Option Expr) : TermElabM Expr := do
|
||||
/- TODO: if `f` contains `choice` or overloaded symbols, `mayPostpone == true`, and `expectedType? == some ?m` where `?m` is not assigned,
|
||||
then we should postpone until `?m` is assigned.
|
||||
Another (more expensive) option is: execute, and if successes > 1, `mayPostpone == true`, and `expectedType? == some ?m` where `?m` is not assigned,
|
||||
|
|
@ -782,18 +832,18 @@ private partial def expandAppAux : Syntax → Array Syntax → Syntax × Array S
|
|||
expandAppAux fn (args.push arg))
|
||||
(fun _ => (stx, args.reverse))
|
||||
|
||||
private def expandApp (stx : Syntax) : TermElabM (Syntax × Array NamedArg × Array Syntax) := do
|
||||
private def expandApp (stx : Syntax) : TermElabM (Syntax × Array NamedArg × Array Arg) := do
|
||||
let (f, args) := expandAppAux stx #[];
|
||||
(namedArgs, args) ← args.foldlM
|
||||
(fun (acc : Array NamedArg × Array Syntax) arg =>
|
||||
(fun (acc : Array NamedArg × Array Arg) arg =>
|
||||
let (namedArgs, args) := acc;
|
||||
arg.ifNodeKind `Lean.Parser.Term.namedArgument
|
||||
(fun argNode => do
|
||||
-- `(` ident `:=` term `)`
|
||||
namedArgs ← addNamedArg acc.1 { name := argNode.getIdAt 1, val := argNode.getArg 3, stx := arg } arg;
|
||||
namedArgs ← addNamedArg acc.1 { name := argNode.getIdAt 1, val := Arg.stx $ argNode.getArg 3 } arg;
|
||||
pure (namedArgs, args))
|
||||
(fun _ =>
|
||||
pure (namedArgs, args.push arg)))
|
||||
pure (namedArgs, args.push $ Arg.stx arg)))
|
||||
(#[], #[]);
|
||||
pure (f, namedArgs, args)
|
||||
|
||||
|
|
|
|||
|
|
@ -63,3 +63,37 @@ pure "hello"
|
|||
#check ()
|
||||
#check run
|
||||
end"
|
||||
|
||||
structure S1 :=
|
||||
(x y : Nat := 0)
|
||||
|
||||
structure S2 extends S1 :=
|
||||
(z : Nat := 0)
|
||||
|
||||
structure S3 :=
|
||||
(w : Nat := 0)
|
||||
|
||||
structure S4 extends S2, S3 :=
|
||||
(s : Nat := 0)
|
||||
|
||||
def s4 : S4 := {}
|
||||
|
||||
structure S (α : Type) :=
|
||||
(field1 : S4 := {})
|
||||
(field2 : S4 × S4 := ({}, {}))
|
||||
(field3 : α)
|
||||
|
||||
inductive D (α : Type)
|
||||
| mk (a : α) (s : S4) : D
|
||||
|
||||
def s : S Nat := { field3 := 0 }
|
||||
def d : D Nat := D.mk 10 {}
|
||||
|
||||
#eval run "#check s4.x"
|
||||
#eval run "#check s.field1.x"
|
||||
#eval run "#check s.field2.fst"
|
||||
#eval run "#check s.field2.fst.w"
|
||||
#eval run "#check s.1.x"
|
||||
#eval run "#check s.2.1.x"
|
||||
#eval run "#check d.1"
|
||||
#eval run "#check d.2.x"
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue