feat: add resolveField

This commit is contained in:
Leonardo de Moura 2019-12-13 09:41:06 -08:00
parent be50f24d64
commit 1b701dae2f
2 changed files with 85 additions and 26 deletions

View file

@ -5,6 +5,7 @@ Authors: Leonardo de Moura
-/
prelude
import Init.Lean.Util.Sorry
import Init.Lean.Structure
import Init.Lean.Meta
import Init.Lean.Elab.Log
import Init.Lean.Elab.Alias
@ -51,12 +52,12 @@ instance TermElabM.inhabited {α} : Inhabited (TermElabM α) :=
instance TermElabResult.inhabited : Inhabited TermElabResult := ⟨EStateM.Result.ok (arbitrary _) (arbitrary _)⟩
inductive Projection
inductive Field
| num (fieldIdx : Nat)
| str (fieldName : String)
instance Projection.hasToString : HasToString Projection :=
⟨fun p => match p with | Projection.num n => toString n | Projection.str s => s⟩
instance Field.hasToString : HasToString Field :=
⟨fun p => match p with | Field.num n => toString n | Field.str s => s⟩
/--
Execute `x`, save resulting expression and new state.
@ -466,7 +467,7 @@ pure $ namedArgs.push namedArg
private def resolveLocalNameAux (lctx : LocalContext) : Name → List String → Option (Expr × List String)
| n@(Name.str pre s _), projs =>
match lctx.findFromUserName n with
match lctx.findFromUserName? n with
| some decl => some (decl.toExpr, projs)
| none => resolveLocalNameAux pre (s::projs)
| _, _ => none
@ -611,50 +612,108 @@ let argIdx := 0;
let instMVars := #[];
elabAppArgsAux ref args expectedType? explicit argIdx namedArgs instMVars fType f
private def elabAppProjsAux (ref : Syntax) (namedArgs : Array NamedArg) (args : Array Syntax) (expectedType? : Option Expr) (explicit : Bool)
: Expr → List Projection → TermElabM Expr
inductive FieldResolution
| projFn (fieldName : Name) (baseStructName : Name) (structName : Name)
| projIdx (structName : Name) (idx : Nat)
| const (constName : Name)
| localRec (fvar : Expr)
private def throwFieldError {α} (ref : Syntax) (e : Expr) (eType : Expr) (msg : MessageData) : TermElabM α :=
throwError ref $ msg ++ indentExpr e ++ Format.line ++ "has type" ++ indentExpr eType
private def resolveField (ref : Syntax) (e : Expr) (eType : Expr) (field : Field) : TermElabM FieldResolution :=
match eType.getAppFn, field with
| Expr.const structName _ _, Field.num idx => do
when (idx == 0) $
throwError ref "invalid projection, index must be greater than 0";
env ← getEnv;
unless (isStructureLike env structName) $
throwFieldError ref e eType "invalid projection, structure expected";
let fieldNames := getStructureFields env structName;
if h : idx - 1 < fieldNames.size then
if isStructure env structName then
pure $ FieldResolution.projFn (fieldNames.get ⟨idx - 1, h⟩) structName structName
else
/- `structName` was declared using `inductive` command.
So, we don't projection functions for it. Thus, we use `Expr.proj` -/
pure $ FieldResolution.projIdx structName (idx - 1)
else
throwFieldError ref e eType ("invalid projection, structure has only " ++ toString fieldNames.size ++ " field(s)")
| Expr.const structName _ _, Field.str fieldName => do
env ← getEnv;
let searchEnv (fullName : Name) : TermElabM FieldResolution := do {
match env.find fullName with
| some _ => pure $ FieldResolution.const fullName
| none => throwFieldError ref e eType $
"invalid field notation, '" ++ fieldName ++ "' is not a valid \"field\" because environment does not contain '" ++ fullName ++ "'"
};
let searchLCtx : Unit → TermElabM FieldResolution := fun _ => do {
let fullName := structName ++ fieldName;
currNamespace ← getCurrNamespace;
let localName := fullName.replacePrefix currNamespace Name.anonymous;
lctx ← getLCtx;
match lctx.findFromUserName? localName with
| some localDecl =>
if localDecl.binderInfo == BinderInfo.auxDecl then
/- Field notation is being used to make a "local" recursive call. -/
pure $ FieldResolution.localRec localDecl.toExpr
else
searchEnv fullName
| none => searchEnv fullName
};
if isStructure env structName then
match findField? env structName fieldName with
| some baseStructName => pure $ FieldResolution.projFn fieldName baseStructName structName
| none => searchLCtx ()
else
searchLCtx ()
| _, _ => throwFieldError ref e eType "invalid field notation, type is not of the form (C ...) where C is a constant"
private def elabAppFieldsAux (ref : Syntax) (namedArgs : Array NamedArg) (args : Array Syntax) (expectedType? : Option Expr) (explicit : Bool)
: Expr → List Field → TermElabM Expr
| f, [] => elabAppArgs ref f namedArgs args expectedType? explicit
| f, proj::projs => do
| f, proj::fields => do
fType ← inferType ref f;
-- TODO
elabAppArgs ref f namedArgs args expectedType? explicit
private def elabAppProjs (ref : Syntax) (f : Expr) (projs : List Projection) (namedArgs : Array NamedArg) (args : Array Syntax)
private def elabAppFields (ref : Syntax) (f : Expr) (fields : List Field) (namedArgs : Array NamedArg) (args : Array Syntax)
(expectedType? : Option Expr) (explicit : Bool) : TermElabM Expr := do
when (!projs.isEmpty && explicit) $ throwError ref "invalid use of projection notation with `@` modifier";
elabAppProjsAux ref namedArgs args expectedType? explicit f projs
when (!fields.isEmpty && explicit) $ throwError ref "invalid use of projection notation with `@` modifier";
elabAppFieldsAux ref namedArgs args expectedType? explicit f fields
private partial def elabAppFn (ref : Syntax) : Syntax → List Projection → Array NamedArg → Array Syntax → Option Expr → Bool → Array TermElabResult → TermElabM (Array TermElabResult)
| f, projs, namedArgs, args, expectedType?, explicit, acc =>
private partial def elabAppFn (ref : Syntax) : Syntax → List Field → Array NamedArg → Array Syntax → Option Expr → Bool → Array TermElabResult → TermElabM (Array TermElabResult)
| f, fields, namedArgs, args, expectedType?, explicit, acc =>
let k := f.getKind;
if k == `Lean.Parser.Term.explicit then
-- `f` is of the form `@ id`
elabAppFn (f.getArg 1) projs namedArgs args expectedType? true acc
elabAppFn (f.getArg 1) fields namedArgs args expectedType? true acc
else if k == choiceKind then
f.getArgs.foldlM (fun acc f => elabAppFn f projs namedArgs args expectedType? explicit acc) acc
f.getArgs.foldlM (fun acc f => elabAppFn f fields namedArgs args expectedType? explicit acc) acc
else if k == `Lean.Parser.Term.proj then
-- term `.` (fieldIdx <|> ident)
let field := f.getArg 2;
match field.isFieldIdx?, field with
| some idx, _ => elabAppFn (f.getArg 0) (Projection.num idx :: projs) namedArgs args expectedType? true acc
| _, Syntax.ident _ val _ _ => elabAppFn (f.getArg 0) (Projection.str val.toString :: projs) namedArgs args expectedType? true acc
| some idx, _ => elabAppFn (f.getArg 0) (Field.num idx :: fields) namedArgs args expectedType? true acc
| _, Syntax.ident _ val _ _ => elabAppFn (f.getArg 0) (Field.str val.toString :: fields) namedArgs args expectedType? true acc
| _, _ => throwError field "unexpected kind of field access"
else if k == `Lean.Parser.Term.id then
-- ident (explicitUniv | namedPattern)?
-- Remark: `namedPattern` should already have been expanded
match f.getArg 0 with
| Syntax.ident _ _ n preresolved => do
us ← elabExplicitUniv (f.getArg 1); -- `namedPattern` should already have been expanded
fprojs ← resolveName n preresolved us f;
fprojs.foldlM
(fun acc ⟨f, projs'⟩ => do
let projs' := projs'.map Projection.str;
s ← observing $ elabAppProjs ref f (projs' ++ projs) namedArgs args expectedType? explicit;
us ← elabExplicitUniv (f.getArg 1);
funFields ← resolveName n preresolved us f;
funFields.foldlM
(fun acc ⟨f, fields'⟩ => do
let fields' := fields'.map Field.str;
s ← observing $ elabAppFields ref f (fields' ++ fields) namedArgs args expectedType? explicit;
pure $ acc.push s)
acc
| _ => unreachable!
else do
f ← withoutPostponing $ elabTerm f none;
s ← observing $ elabAppProjs ref f projs namedArgs args expectedType? explicit;
s ← observing $ elabAppFields ref f fields namedArgs args expectedType? explicit;
pure $ acc.push s
private def getSuccess (candidates : Array TermElabResult) : Array TermElabResult :=

View file

@ -131,7 +131,7 @@ match lctx with
| some decl => { fvarIdToDecl := map.erase decl.fvarId, decls := popTailNoneAux decls.pop }
@[export lean_local_ctx_find_from_user_name]
def findFromUserName (lctx : LocalContext) (userName : Name) : Option LocalDecl :=
def findFromUserName? (lctx : LocalContext) (userName : Name) : Option LocalDecl :=
lctx.decls.findRev (fun decl =>
match decl with
| none => none
@ -139,7 +139,7 @@ lctx.decls.findRev (fun decl =>
@[export lean_local_ctx_uses_user_name]
def usesUserName (lctx : LocalContext) (userName : Name) : Bool :=
(lctx.findFromUserName userName).isSome
(lctx.findFromUserName? userName).isSome
partial def getUnusedNameAux (lctx : LocalContext) (suggestion : Name) : Nat → Name × Nat
| i =>
@ -160,7 +160,7 @@ lctx.decls.get! (lctx.decls.size - 1)
def renameUserName (lctx : LocalContext) (fromName : Name) (toName : Name) : LocalContext :=
match lctx with
| { fvarIdToDecl := map, decls := decls } =>
match lctx.findFromUserName fromName with
match lctx.findFromUserName? fromName with
| none => lctx
| some decl =>
let decl := decl.updateUserName toName;