feat: field projections

This commit is contained in:
Leonardo de Moura 2019-12-14 13:29:14 -08:00
parent 32cebc3e76
commit e25bd36dc5
2 changed files with 122 additions and 38 deletions

View file

@ -450,15 +450,24 @@ fun stx expectedType? => do
def elabExplicitUniv (stx : Syntax) : TermElabM (List Level) :=
pure [] -- TODO
inductive Arg
| stx (val : Syntax)
| expr (val : Expr)
instance Arg.inhabited : Inhabited Arg := ⟨Arg.stx (arbitrary _)⟩
instance Arg.hasToString : HasToString Arg :=
⟨fun arg => match arg with
| Arg.stx val => toString val
| Arg.expr val => toString val⟩
structure NamedArg :=
(name : Name)
(val : Syntax)
(stx : Syntax)
(name : Name) (val : Arg)
instance NamedArg.hasToString : HasToString NamedArg :=
⟨fun s => "(" ++ toString s.name ++ " := " ++ toString s.val ++ ")"⟩
instance NamedArg.inhabited : Inhabited NamedArg := ⟨{ name := arbitrary _, val := arbitrary _, stx := arbitrary _ }⟩
instance NamedArg.inhabited : Inhabited NamedArg := ⟨{ name := arbitrary _, val := arbitrary _ }⟩
def addNamedArg (namedArgs : Array NamedArg) (namedArg : NamedArg) (ref : Syntax) : TermElabM (Array NamedArg) := do
when (namedArgs.any $ fun namedArg' => namedArg.name == namedArg'.name) $
@ -479,20 +488,28 @@ pure $ resolveLocalNameAux lctx n []
private def mkFreshLevelMVars (ref : Syntax) (num : Nat) : TermElabM (List Level) :=
num.foldM (fun _ us => do u ← mkFreshLevelMVar ref; pure $ u::us) []
def mkConst (ref : Syntax) (constName : Name) (explicitLevels : List Level := []) : TermElabM Expr := do
env ← getEnv;
match env.find constName with
| none => throwError ref ("unknown constant '" ++ constName ++ "'")
| some cinfo =>
if explicitLevels.length > cinfo.lparams.length then
throwError ref ("too many explicit universe levels")
else do
let numMissingLevels := cinfo.lparams.length - explicitLevels.length;
us ← mkFreshLevelMVars ref numMissingLevels;
pure $ Lean.mkConst constName (explicitLevels ++ us)
private def mkConsts (ref : Syntax) (candidates : List (Name × List String)) (explicitLevels : List Level) : TermElabM (List (Expr × List String)) := do
env ← getEnv;
candidates.foldlM
(fun result ⟨constName, projs⟩ =>
match env.find constName with
| none => unreachable!
| some cinfo =>
if explicitLevels.length > cinfo.lparams.length then
-- Remark: we discard candidate because of the number of explicit universe levels provided.
pure result
else do
let numMissingLevels := cinfo.lparams.length - explicitLevels.length;
us ← mkFreshLevelMVars ref numMissingLevels;
pure $ (mkConst constName (explicitLevels ++ us), projs) :: result)
catch
(do const ← mkConst ref constName explicitLevels;
pure $ (const, projs) :: result)
(fun _ =>
-- Remark: we discard candidates based on the number of explicit universe levels provided.
pure result))
[]
def resolveName (n : Name) (preresolved : List (Name × List String)) (explicitLevels : List Level) (ref : Syntax) : TermElabM (List (Expr × List String)) := do
@ -556,7 +573,15 @@ condM (isExprMVarAssigned instMVar) (pure ()) $ do
def synthesizeInstMVars (ref : Syntax) (instMVars : Array MVarId) : TermElabM Unit :=
instMVars.forM $ synthesizeInstMVar ref
private partial def elabAppArgsAux (ref : Syntax) (args : Array Syntax) (expectedType? : Option Expr) (explicit : Bool)
private def elabArg (ref : Syntax) (arg : Arg) (expectedType : Expr) : TermElabM Expr :=
match arg with
| Arg.expr val => do
valType ← inferType ref val;
ensureHasType ref expectedType valType val
| Arg.stx val =>
elabTerm val expectedType
private partial def elabAppArgsAux (ref : Syntax) (args : Array Arg) (expectedType? : Option Expr) (explicit : Bool)
: Nat → Array NamedArg → Array MVarId → Expr → Expr → TermElabM Expr
| argIdx, namedArgs, instMVars, eType, e => do
let finalize : Unit → TermElabM Expr := fun _ => do {
@ -571,15 +596,15 @@ private partial def elabAppArgsAux (ref : Syntax) (args : Array Syntax) (expecte
| Expr.forallE n d b c =>
match namedArgs.findIdx? (fun namedArg => namedArg.name == n) with
| some idx => do
let arg := namedArgs.get! idx;
let arg := namedArgs.get! idx;
let namedArgs := namedArgs.eraseIdx idx;
a ← elabTerm arg.val d;
elabAppArgsAux argIdx namedArgs instMVars (b.instantiate1 a) (mkApp e a)
argElab ← elabArg ref arg.val d;
elabAppArgsAux argIdx namedArgs instMVars (b.instantiate1 argElab) (mkApp e argElab)
| none =>
let processExplictArg : Unit → TermElabM Expr := fun _ => do {
if h : argIdx < args.size then do
a ← elabTerm (args.get ⟨argIdx, h⟩) d;
elabAppArgsAux (argIdx + 1) namedArgs instMVars (b.instantiate1 a) (mkApp e a)
argElab ← elabArg ref (args.get ⟨argIdx, h⟩) d;
elabAppArgsAux (argIdx + 1) namedArgs instMVars (b.instantiate1 argElab) (mkApp e argElab)
else if namedArgs.isEmpty then
finalize ()
else
@ -605,7 +630,7 @@ private partial def elabAppArgsAux (ref : Syntax) (args : Array Syntax) (expecte
-- TODO: try `HasCoeToFun`
throwError ref "too many arguments"
private def elabAppArgs (ref : Syntax) (f : Expr) (namedArgs : Array NamedArg) (args : Array Syntax)
private def elabAppArgs (ref : Syntax) (f : Expr) (namedArgs : Array NamedArg) (args : Array Arg)
(expectedType? : Option Expr) (explicit : Bool) : TermElabM Expr := do
fType ← inferType ref f;
let argIdx := 0;
@ -613,7 +638,7 @@ let instMVars := #[];
elabAppArgsAux ref args expectedType? explicit argIdx namedArgs instMVars fType f
inductive FieldResolution
| projFn (fieldName : Name) (baseStructName : Name) (structName : Name)
| projFn (baseStructName : Name) (structName : Name) (fieldName : Name)
| projIdx (structName : Name) (idx : Nat)
| const (constName : Name)
| localRec (fvar : Expr)
@ -632,7 +657,7 @@ match eType.getAppFn, field with
let fieldNames := getStructureFields env structName;
if h : idx - 1 < fieldNames.size then
if isStructure env structName then
pure $ FieldResolution.projFn (fieldNames.get ⟨idx - 1, h⟩) structName structName
pure $ FieldResolution.projFn structName structName (fieldNames.get ⟨idx - 1, h⟩)
else
/- `structName` was declared using `inductive` command.
So, we don't projection functions for it. Thus, we use `Expr.proj` -/
@ -663,7 +688,7 @@ match eType.getAppFn, field with
};
if isStructure env structName then
match findField? env structName fieldName with
| some baseStructName => pure $ FieldResolution.projFn fieldName baseStructName structName
| some baseStructName => pure $ FieldResolution.projFn baseStructName structName fieldName
| none => searchLCtx ()
else
searchLCtx ()
@ -685,20 +710,45 @@ private def resolveField (ref : Syntax) (e : Expr) (field : Field) : TermElabM F
eType ← inferType ref e;
resolveFieldLoop ref e field eType #[]
private def elabAppFieldsAux (ref : Syntax) (namedArgs : Array NamedArg) (args : Array Syntax) (expectedType? : Option Expr) (explicit : Bool)
private partial def mkBaseProjections (ref : Syntax) (baseStructName : Name) (structName : Name) (e : Expr) : TermElabM Expr := do
env ← getEnv;
match getPathToBaseStructure? env baseStructName structName with
| none => throwError ref "failed to access field in parent structure"
| some path =>
path.foldlM
(fun e projFunName => do
projFn ← mkConst ref projFunName;
elabAppArgs ref projFn #[{ name := `self, val := Arg.expr e }] #[] none false)
e
private def elabAppFieldsAux (ref : Syntax) (namedArgs : Array NamedArg) (args : Array Arg) (expectedType? : Option Expr) (explicit : Bool)
: Expr → List Field → TermElabM Expr
| f, [] => elabAppArgs ref f namedArgs args expectedType? explicit
| f, field::fields => do
fType ← inferType ref f;
-- TODO
elabAppArgs ref f namedArgs args expectedType? explicit
fieldRes ← resolveField ref f field;
match fieldRes with
| FieldResolution.projIdx structName idx =>
let f := mkProj structName idx f;
elabAppFieldsAux f fields
| FieldResolution.projFn baseStructName structName fieldName => do
f ← mkBaseProjections ref baseStructName structName f;
projFn ← mkConst ref (baseStructName ++ fieldName);
if fields.isEmpty then do
namedArgs ← addNamedArg namedArgs { name := `self, val := Arg.expr f } ref;
elabAppArgs ref projFn namedArgs args expectedType? explicit
else do
f ← elabAppArgs ref projFn #[{ name := `self, val := Arg.expr f }] #[] none false;
elabAppFieldsAux f fields
| _ =>
-- TODO
elabAppArgs ref f namedArgs args expectedType? explicit
private def elabAppFields (ref : Syntax) (f : Expr) (fields : List Field) (namedArgs : Array NamedArg) (args : Array Syntax)
private def elabAppFields (ref : Syntax) (f : Expr) (fields : List Field) (namedArgs : Array NamedArg) (args : Array Arg)
(expectedType? : Option Expr) (explicit : Bool) : TermElabM Expr := do
when (!fields.isEmpty && explicit) $ throwError ref "invalid use of projection notation with `@` modifier";
when (!fields.isEmpty && explicit) $ throwError ref "invalid use of field notation with `@` modifier";
elabAppFieldsAux ref namedArgs args expectedType? explicit f fields
private partial def elabAppFn (ref : Syntax) : Syntax → List Field → Array NamedArg → Array Syntax → Option Expr → Bool → Array TermElabResult → TermElabM (Array TermElabResult)
private partial def elabAppFn (ref : Syntax) : Syntax → List Field → Array NamedArg → Array Arg → Option Expr → Bool → Array TermElabResult → TermElabM (Array TermElabResult)
| f, fields, namedArgs, args, expectedType?, explicit, acc =>
let k := f.getKind;
if k == `Lean.Parser.Term.explicit then
@ -710,8 +760,8 @@ private partial def elabAppFn (ref : Syntax) : Syntax → List Field → Array N
-- term `.` (fieldIdx <|> ident)
let field := f.getArg 2;
match field.isFieldIdx?, field with
| some idx, _ => elabAppFn (f.getArg 0) (Field.num idx :: fields) namedArgs args expectedType? true acc
| _, Syntax.ident _ val _ _ => elabAppFn (f.getArg 0) (Field.str val.toString :: fields) namedArgs args expectedType? true acc
| some idx, _ => elabAppFn (f.getArg 0) (Field.num idx :: fields) namedArgs args expectedType? explicit acc
| _, Syntax.ident _ val _ _ => elabAppFn (f.getArg 0) (Field.str val.toString :: fields) namedArgs args expectedType? explicit acc
| _, _ => throwError field "unexpected kind of field access"
else if k == `Lean.Parser.Term.id then
-- ident (explicitUniv | namedPattern)?
@ -752,7 +802,7 @@ msgs ← failures.mapM $ fun failure =>
| EStateM.Result.error ex s => toMessageData ex stx;
throwError stx ("overloaded, errors " ++ MessageData.ofArray msgs)
private def elabAppAux (ref : Syntax) (f : Syntax) (namedArgs : Array NamedArg) (args : Array Syntax) (expectedType? : Option Expr) : TermElabM Expr := do
private def elabAppAux (ref : Syntax) (f : Syntax) (namedArgs : Array NamedArg) (args : Array Arg) (expectedType? : Option Expr) : TermElabM Expr := do
/- TODO: if `f` contains `choice` or overloaded symbols, `mayPostpone == true`, and `expectedType? == some ?m` where `?m` is not assigned,
then we should postpone until `?m` is assigned.
Another (more expensive) option is: execute, and if successes > 1, `mayPostpone == true`, and `expectedType? == some ?m` where `?m` is not assigned,
@ -782,18 +832,18 @@ private partial def expandAppAux : Syntax → Array Syntax → Syntax × Array S
expandAppAux fn (args.push arg))
(fun _ => (stx, args.reverse))
private def expandApp (stx : Syntax) : TermElabM (Syntax × Array NamedArg × Array Syntax) := do
private def expandApp (stx : Syntax) : TermElabM (Syntax × Array NamedArg × Array Arg) := do
let (f, args) := expandAppAux stx #[];
(namedArgs, args) ← args.foldlM
(fun (acc : Array NamedArg × Array Syntax) arg =>
(fun (acc : Array NamedArg × Array Arg) arg =>
let (namedArgs, args) := acc;
arg.ifNodeKind `Lean.Parser.Term.namedArgument
(fun argNode => do
-- `(` ident `:=` term `)`
namedArgs ← addNamedArg acc.1 { name := argNode.getIdAt 1, val := argNode.getArg 3, stx := arg } arg;
namedArgs ← addNamedArg acc.1 { name := argNode.getIdAt 1, val := Arg.stx $ argNode.getArg 3 } arg;
pure (namedArgs, args))
(fun _ =>
pure (namedArgs, args.push arg)))
pure (namedArgs, args.push $ Arg.stx arg)))
(#[], #[]);
pure (f, namedArgs, args)

View file

@ -63,3 +63,37 @@ pure "hello"
#check ()
#check run
end"
structure S1 :=
(x y : Nat := 0)
structure S2 extends S1 :=
(z : Nat := 0)
structure S3 :=
(w : Nat := 0)
structure S4 extends S2, S3 :=
(s : Nat := 0)
def s4 : S4 := {}
structure S (α : Type) :=
(field1 : S4 := {})
(field2 : S4 × S4 := ({}, {}))
(field3 : α)
inductive D (α : Type)
| mk (a : α) (s : S4) : D
def s : S Nat := { field3 := 0 }
def d : D Nat := D.mk 10 {}
#eval run "#check s4.x"
#eval run "#check s.field1.x"
#eval run "#check s.field2.fst"
#eval run "#check s.field2.fst.w"
#eval run "#check s.1.x"
#eval run "#check s.2.1.x"
#eval run "#check d.1"
#eval run "#check d.2.x"