diff --git a/src/Init/Lean/Elab/Term.lean b/src/Init/Lean/Elab/Term.lean index 31c939cf73..16288ec600 100644 --- a/src/Init/Lean/Elab/Term.lean +++ b/src/Init/Lean/Elab/Term.lean @@ -450,15 +450,24 @@ fun stx expectedType? => do def elabExplicitUniv (stx : Syntax) : TermElabM (List Level) := pure [] -- TODO +inductive Arg +| stx (val : Syntax) +| expr (val : Expr) + +instance Arg.inhabited : Inhabited Arg := ⟨Arg.stx (arbitrary _)⟩ + +instance Arg.hasToString : HasToString Arg := +⟨fun arg => match arg with + | Arg.stx val => toString val + | Arg.expr val => toString val⟩ + structure NamedArg := -(name : Name) -(val : Syntax) -(stx : Syntax) +(name : Name) (val : Arg) instance NamedArg.hasToString : HasToString NamedArg := ⟨fun s => "(" ++ toString s.name ++ " := " ++ toString s.val ++ ")"⟩ -instance NamedArg.inhabited : Inhabited NamedArg := ⟨{ name := arbitrary _, val := arbitrary _, stx := arbitrary _ }⟩ +instance NamedArg.inhabited : Inhabited NamedArg := ⟨{ name := arbitrary _, val := arbitrary _ }⟩ def addNamedArg (namedArgs : Array NamedArg) (namedArg : NamedArg) (ref : Syntax) : TermElabM (Array NamedArg) := do when (namedArgs.any $ fun namedArg' => namedArg.name == namedArg'.name) $ @@ -479,20 +488,28 @@ pure $ resolveLocalNameAux lctx n [] private def mkFreshLevelMVars (ref : Syntax) (num : Nat) : TermElabM (List Level) := num.foldM (fun _ us => do u ← mkFreshLevelMVar ref; pure $ u::us) [] +def mkConst (ref : Syntax) (constName : Name) (explicitLevels : List Level := []) : TermElabM Expr := do +env ← getEnv; +match env.find constName with +| none => throwError ref ("unknown constant '" ++ constName ++ "'") +| some cinfo => + if explicitLevels.length > cinfo.lparams.length then + throwError ref ("too many explicit universe levels") + else do + let numMissingLevels := cinfo.lparams.length - explicitLevels.length; + us ← mkFreshLevelMVars ref numMissingLevels; + pure $ Lean.mkConst constName (explicitLevels ++ us) + private def mkConsts (ref : Syntax) (candidates : List (Name × List String)) (explicitLevels : List Level) : TermElabM (List (Expr × List String)) := do env ← getEnv; candidates.foldlM (fun result ⟨constName, projs⟩ => - match env.find constName with - | none => unreachable! - | some cinfo => - if explicitLevels.length > cinfo.lparams.length then - -- Remark: we discard candidate because of the number of explicit universe levels provided. - pure result - else do - let numMissingLevels := cinfo.lparams.length - explicitLevels.length; - us ← mkFreshLevelMVars ref numMissingLevels; - pure $ (mkConst constName (explicitLevels ++ us), projs) :: result) + catch + (do const ← mkConst ref constName explicitLevels; + pure $ (const, projs) :: result) + (fun _ => + -- Remark: we discard candidates based on the number of explicit universe levels provided. + pure result)) [] def resolveName (n : Name) (preresolved : List (Name × List String)) (explicitLevels : List Level) (ref : Syntax) : TermElabM (List (Expr × List String)) := do @@ -556,7 +573,15 @@ condM (isExprMVarAssigned instMVar) (pure ()) $ do def synthesizeInstMVars (ref : Syntax) (instMVars : Array MVarId) : TermElabM Unit := instMVars.forM $ synthesizeInstMVar ref -private partial def elabAppArgsAux (ref : Syntax) (args : Array Syntax) (expectedType? : Option Expr) (explicit : Bool) +private def elabArg (ref : Syntax) (arg : Arg) (expectedType : Expr) : TermElabM Expr := +match arg with +| Arg.expr val => do + valType ← inferType ref val; + ensureHasType ref expectedType valType val +| Arg.stx val => + elabTerm val expectedType + +private partial def elabAppArgsAux (ref : Syntax) (args : Array Arg) (expectedType? : Option Expr) (explicit : Bool) : Nat → Array NamedArg → Array MVarId → Expr → Expr → TermElabM Expr | argIdx, namedArgs, instMVars, eType, e => do let finalize : Unit → TermElabM Expr := fun _ => do { @@ -571,15 +596,15 @@ private partial def elabAppArgsAux (ref : Syntax) (args : Array Syntax) (expecte | Expr.forallE n d b c => match namedArgs.findIdx? (fun namedArg => namedArg.name == n) with | some idx => do - let arg := namedArgs.get! idx; + let arg := namedArgs.get! idx; let namedArgs := namedArgs.eraseIdx idx; - a ← elabTerm arg.val d; - elabAppArgsAux argIdx namedArgs instMVars (b.instantiate1 a) (mkApp e a) + argElab ← elabArg ref arg.val d; + elabAppArgsAux argIdx namedArgs instMVars (b.instantiate1 argElab) (mkApp e argElab) | none => let processExplictArg : Unit → TermElabM Expr := fun _ => do { if h : argIdx < args.size then do - a ← elabTerm (args.get ⟨argIdx, h⟩) d; - elabAppArgsAux (argIdx + 1) namedArgs instMVars (b.instantiate1 a) (mkApp e a) + argElab ← elabArg ref (args.get ⟨argIdx, h⟩) d; + elabAppArgsAux (argIdx + 1) namedArgs instMVars (b.instantiate1 argElab) (mkApp e argElab) else if namedArgs.isEmpty then finalize () else @@ -605,7 +630,7 @@ private partial def elabAppArgsAux (ref : Syntax) (args : Array Syntax) (expecte -- TODO: try `HasCoeToFun` throwError ref "too many arguments" -private def elabAppArgs (ref : Syntax) (f : Expr) (namedArgs : Array NamedArg) (args : Array Syntax) +private def elabAppArgs (ref : Syntax) (f : Expr) (namedArgs : Array NamedArg) (args : Array Arg) (expectedType? : Option Expr) (explicit : Bool) : TermElabM Expr := do fType ← inferType ref f; let argIdx := 0; @@ -613,7 +638,7 @@ let instMVars := #[]; elabAppArgsAux ref args expectedType? explicit argIdx namedArgs instMVars fType f inductive FieldResolution -| projFn (fieldName : Name) (baseStructName : Name) (structName : Name) +| projFn (baseStructName : Name) (structName : Name) (fieldName : Name) | projIdx (structName : Name) (idx : Nat) | const (constName : Name) | localRec (fvar : Expr) @@ -632,7 +657,7 @@ match eType.getAppFn, field with let fieldNames := getStructureFields env structName; if h : idx - 1 < fieldNames.size then if isStructure env structName then - pure $ FieldResolution.projFn (fieldNames.get ⟨idx - 1, h⟩) structName structName + pure $ FieldResolution.projFn structName structName (fieldNames.get ⟨idx - 1, h⟩) else /- `structName` was declared using `inductive` command. So, we don't projection functions for it. Thus, we use `Expr.proj` -/ @@ -663,7 +688,7 @@ match eType.getAppFn, field with }; if isStructure env structName then match findField? env structName fieldName with - | some baseStructName => pure $ FieldResolution.projFn fieldName baseStructName structName + | some baseStructName => pure $ FieldResolution.projFn baseStructName structName fieldName | none => searchLCtx () else searchLCtx () @@ -685,20 +710,45 @@ private def resolveField (ref : Syntax) (e : Expr) (field : Field) : TermElabM F eType ← inferType ref e; resolveFieldLoop ref e field eType #[] -private def elabAppFieldsAux (ref : Syntax) (namedArgs : Array NamedArg) (args : Array Syntax) (expectedType? : Option Expr) (explicit : Bool) +private partial def mkBaseProjections (ref : Syntax) (baseStructName : Name) (structName : Name) (e : Expr) : TermElabM Expr := do +env ← getEnv; +match getPathToBaseStructure? env baseStructName structName with +| none => throwError ref "failed to access field in parent structure" +| some path => + path.foldlM + (fun e projFunName => do + projFn ← mkConst ref projFunName; + elabAppArgs ref projFn #[{ name := `self, val := Arg.expr e }] #[] none false) + e + +private def elabAppFieldsAux (ref : Syntax) (namedArgs : Array NamedArg) (args : Array Arg) (expectedType? : Option Expr) (explicit : Bool) : Expr → List Field → TermElabM Expr | f, [] => elabAppArgs ref f namedArgs args expectedType? explicit | f, field::fields => do - fType ← inferType ref f; - -- TODO - elabAppArgs ref f namedArgs args expectedType? explicit + fieldRes ← resolveField ref f field; + match fieldRes with + | FieldResolution.projIdx structName idx => + let f := mkProj structName idx f; + elabAppFieldsAux f fields + | FieldResolution.projFn baseStructName structName fieldName => do + f ← mkBaseProjections ref baseStructName structName f; + projFn ← mkConst ref (baseStructName ++ fieldName); + if fields.isEmpty then do + namedArgs ← addNamedArg namedArgs { name := `self, val := Arg.expr f } ref; + elabAppArgs ref projFn namedArgs args expectedType? explicit + else do + f ← elabAppArgs ref projFn #[{ name := `self, val := Arg.expr f }] #[] none false; + elabAppFieldsAux f fields + | _ => + -- TODO + elabAppArgs ref f namedArgs args expectedType? explicit -private def elabAppFields (ref : Syntax) (f : Expr) (fields : List Field) (namedArgs : Array NamedArg) (args : Array Syntax) +private def elabAppFields (ref : Syntax) (f : Expr) (fields : List Field) (namedArgs : Array NamedArg) (args : Array Arg) (expectedType? : Option Expr) (explicit : Bool) : TermElabM Expr := do -when (!fields.isEmpty && explicit) $ throwError ref "invalid use of projection notation with `@` modifier"; +when (!fields.isEmpty && explicit) $ throwError ref "invalid use of field notation with `@` modifier"; elabAppFieldsAux ref namedArgs args expectedType? explicit f fields -private partial def elabAppFn (ref : Syntax) : Syntax → List Field → Array NamedArg → Array Syntax → Option Expr → Bool → Array TermElabResult → TermElabM (Array TermElabResult) +private partial def elabAppFn (ref : Syntax) : Syntax → List Field → Array NamedArg → Array Arg → Option Expr → Bool → Array TermElabResult → TermElabM (Array TermElabResult) | f, fields, namedArgs, args, expectedType?, explicit, acc => let k := f.getKind; if k == `Lean.Parser.Term.explicit then @@ -710,8 +760,8 @@ private partial def elabAppFn (ref : Syntax) : Syntax → List Field → Array N -- term `.` (fieldIdx <|> ident) let field := f.getArg 2; match field.isFieldIdx?, field with - | some idx, _ => elabAppFn (f.getArg 0) (Field.num idx :: fields) namedArgs args expectedType? true acc - | _, Syntax.ident _ val _ _ => elabAppFn (f.getArg 0) (Field.str val.toString :: fields) namedArgs args expectedType? true acc + | some idx, _ => elabAppFn (f.getArg 0) (Field.num idx :: fields) namedArgs args expectedType? explicit acc + | _, Syntax.ident _ val _ _ => elabAppFn (f.getArg 0) (Field.str val.toString :: fields) namedArgs args expectedType? explicit acc | _, _ => throwError field "unexpected kind of field access" else if k == `Lean.Parser.Term.id then -- ident (explicitUniv | namedPattern)? @@ -752,7 +802,7 @@ msgs ← failures.mapM $ fun failure => | EStateM.Result.error ex s => toMessageData ex stx; throwError stx ("overloaded, errors " ++ MessageData.ofArray msgs) -private def elabAppAux (ref : Syntax) (f : Syntax) (namedArgs : Array NamedArg) (args : Array Syntax) (expectedType? : Option Expr) : TermElabM Expr := do +private def elabAppAux (ref : Syntax) (f : Syntax) (namedArgs : Array NamedArg) (args : Array Arg) (expectedType? : Option Expr) : TermElabM Expr := do /- TODO: if `f` contains `choice` or overloaded symbols, `mayPostpone == true`, and `expectedType? == some ?m` where `?m` is not assigned, then we should postpone until `?m` is assigned. Another (more expensive) option is: execute, and if successes > 1, `mayPostpone == true`, and `expectedType? == some ?m` where `?m` is not assigned, @@ -782,18 +832,18 @@ private partial def expandAppAux : Syntax → Array Syntax → Syntax × Array S expandAppAux fn (args.push arg)) (fun _ => (stx, args.reverse)) -private def expandApp (stx : Syntax) : TermElabM (Syntax × Array NamedArg × Array Syntax) := do +private def expandApp (stx : Syntax) : TermElabM (Syntax × Array NamedArg × Array Arg) := do let (f, args) := expandAppAux stx #[]; (namedArgs, args) ← args.foldlM - (fun (acc : Array NamedArg × Array Syntax) arg => + (fun (acc : Array NamedArg × Array Arg) arg => let (namedArgs, args) := acc; arg.ifNodeKind `Lean.Parser.Term.namedArgument (fun argNode => do -- `(` ident `:=` term `)` - namedArgs ← addNamedArg acc.1 { name := argNode.getIdAt 1, val := argNode.getArg 3, stx := arg } arg; + namedArgs ← addNamedArg acc.1 { name := argNode.getIdAt 1, val := Arg.stx $ argNode.getArg 3 } arg; pure (namedArgs, args)) (fun _ => - pure (namedArgs, args.push arg))) + pure (namedArgs, args.push $ Arg.stx arg))) (#[], #[]); pure (f, namedArgs, args) diff --git a/tests/lean/run/frontend1.lean b/tests/lean/run/frontend1.lean index 176ca5893a..fb37c14e9b 100644 --- a/tests/lean/run/frontend1.lean +++ b/tests/lean/run/frontend1.lean @@ -63,3 +63,37 @@ pure "hello" #check () #check run end" + +structure S1 := +(x y : Nat := 0) + +structure S2 extends S1 := +(z : Nat := 0) + +structure S3 := +(w : Nat := 0) + +structure S4 extends S2, S3 := +(s : Nat := 0) + +def s4 : S4 := {} + +structure S (α : Type) := +(field1 : S4 := {}) +(field2 : S4 × S4 := ({}, {})) +(field3 : α) + +inductive D (α : Type) +| mk (a : α) (s : S4) : D + +def s : S Nat := { field3 := 0 } +def d : D Nat := D.mk 10 {} + +#eval run "#check s4.x" +#eval run "#check s.field1.x" +#eval run "#check s.field2.fst" +#eval run "#check s.field2.fst.w" +#eval run "#check s.1.x" +#eval run "#check s.2.1.x" +#eval run "#check d.1" +#eval run "#check d.2.x"