lean4-htt/src/Lean/Elab/StructInst.lean
Leonardo de Moura 065a845bfd chore: factor out nested macro at structInst elaborator
@Kha The new macro is another example where it would be super nice to
have better support for `optional` at `match_syntax`.
2020-09-19 11:56:20 -07:00

817 lines
33 KiB
Text
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/-
Copyright (c) 2020 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
import Lean.Util.FindExpr
import Lean.Elab.App
import Lean.Elab.Binders
import Lean.Elab.Quotation
namespace Lean
namespace Elab
namespace Term
namespace StructInst
open Std (HashMap)
open Meta
/- parser! "{" >> optional (try (termParser >> "with")) >> sepBy structInstField ", " true >> optional ".." >> optional (" : " >> termParser) >> "}" -/
@[builtinMacro Lean.Parser.Term.structInst] def expandStructInstExpectedType : Macro :=
fun stx =>
let expectedArg := stx.getArg 4;
if expectedArg.isNone
then Macro.throwUnsupported
else
let expected := expectedArg.getArg 1;
let stxNew := stx.setArg 4 mkNullNode;
`(($stxNew : $expected))
/-
If `stx` is of the form `{ s with ... }` and `s` is not a local variable, expand into `let src := s; { src with ... }`.
Note that this one is not a `Macro` because we need to access the local context.
-/
private def expandNonAtomicExplicitSource (stx : Syntax) : TermElabM (Option Syntax) :=
withFreshMacroScope $
let sourceOpt := stx.getArg 1;
if sourceOpt.isNone then pure none else do
let source := sourceOpt.getArg 0;
fvar? ← isLocalIdent? source;
match fvar? with
| some _ => pure none
| none => do
src ← `(src);
let sourceOpt := sourceOpt.setArg 0 src;
let stxNew := stx.setArg 1 sourceOpt;
`(let src := $source; $stxNew)
inductive Source
| none -- structure instance source has not been provieded
| implicit (stx : Syntax) -- `..`
| explicit (stx : Syntax) (src : Expr) -- `src with`
instance Source.inhabited : Inhabited Source := ⟨Source.none⟩
def Source.isNone : Source → Bool
| Source.none => true
| _ => false
def setStructSourceSyntax (structStx : Syntax) : Source → Syntax
| Source.none => (structStx.setArg 1 mkNullNode).setArg 3 mkNullNode
| Source.implicit stx => (structStx.setArg 1 mkNullNode).setArg 3 stx
| Source.explicit stx _ => (structStx.setArg 1 stx).setArg 3 mkNullNode
private def getStructSource (stx : Syntax) : TermElabM Source :=
withRef stx $
let explicitSource := stx.getArg 1;
let implicitSource := stx.getArg 3;
if explicitSource.isNone && implicitSource.isNone then
pure Source.none
else if explicitSource.isNone then
pure $ Source.implicit implicitSource
else if implicitSource.isNone then do
fvar? ← isLocalIdent? (explicitSource.getArg 0);
match fvar? with
| none => unreachable! -- expandNonAtomicExplicitSource must have been used when we get here
| some src => pure $ Source.explicit explicitSource src
else
throwError "invalid structure instance `with` and `..` cannot be used together"
/-
We say a `{ ... }` notation is a `modifyOp` if it contains only one
```
def structInstArrayRef := parser! "[" >> termParser >>"]"
``` -/
private def isModifyOp? (stx : Syntax) : TermElabM (Option Syntax) := do
let args := (stx.getArg 2).getArgs;
s? ← args.foldSepByM
(fun arg s? =>
/- Remark: the syntax for `structInstField` is
```
def structInstLVal := (ident <|> numLit <|> structInstArrayRef) >> many (group ("." >> (ident <|> numLit)) <|> structInstArrayRef)
def structInstField := parser! structInstLVal >> " := " >> termParser
``` -/
let lval := arg.getArg 0;
let k := lval.getKind;
if k == `Lean.Parser.Term.structInstArrayRef then
match s? with
| none => pure (some arg)
| some s =>
if s.getKind == `Lean.Parser.Term.structInstArrayRef then
throwErrorAt arg "invalid {...} notation, at most one `[..]` at a given level"
else
throwErrorAt arg "invalid {...} notation, can't mix field and `[..]` at a given level"
else
match s? with
| none => pure (some arg)
| some s =>
if s.getKind == `Lean.Parser.Term.structInstArrayRef then
throwErrorAt arg "invalid {...} notation, can't mix field and `[..]` at a given level"
else
pure s?)
none;
match s? with
| none => pure none
| some s => if (s.getArg 0).getKind == `Lean.Parser.Term.structInstArrayRef then pure s? else pure none
private def elabModifyOp (stx modifyOp source : Syntax) (expectedType? : Option Expr) : TermElabM Expr :=
let continue (val : Syntax) : TermElabM Expr := do {
let lval := modifyOp.getArg 0;
let idx := lval.getArg 1;
let self := source.getArg 0;
stxNew ← `($(self).modifyOp (idx := $idx) (fun s => $val));
trace `Elab.struct.modifyOp fun _ => stx ++ "\n===>\n" ++ stxNew;
withMacroExpansion stx stxNew $ elabTerm stxNew expectedType?
}; do
trace `Elab.struct.modifyOp fun _ => modifyOp ++ "\nSource: " ++ source;
let rest := modifyOp.getArg 1;
if rest.isNone then do
continue (modifyOp.getArg 3)
else do
s ← `(s);
let valFirst := rest.getArg 0;
let valFirst := if valFirst.getKind == `Lean.Parser.Term.structInstArrayRef then valFirst else valFirst.getArg 1;
let restArgs := rest.getArgs;
let valRest := mkNullNode (restArgs.extract 1 restArgs.size);
let valField := modifyOp.setArg 0 valFirst;
let valField := valField.setArg 1 valRest;
let valSource := source.modifyArg 0 $ fun _ => s;
let val := stx.setArg 1 valSource;
let val := val.setArg 2 $ mkNullNode #[valField];
trace `Elab.struct.modifyOp fun _ => stx ++ "\nval: " ++ val;
continue val
/- Get structure name and elaborate explicit source (if available) -/
private def getStructName (stx : Syntax) (expectedType? : Option Expr) (sourceView : Source) : TermElabM Name := do
tryPostponeIfNoneOrMVar expectedType?;
let useSource : Unit → TermElabM Name := fun _ =>
match sourceView, expectedType? with
| Source.explicit _ src, _ => do
srcType ← inferType src;
srcType ← whnf srcType;
tryPostponeIfMVar srcType;
match srcType.getAppFn with
| Expr.const constName _ _ => pure constName
| _ => throwError ("invalid {...} notation, source type is not of the form (C ...)" ++ indentExpr srcType)
| _, some expectedType => throwError ("invalid {...} notation, expected type is not of the form (C ...)" ++ indentExpr expectedType)
| _, none => throwError ("invalid {...} notation, expected type must be known");
match expectedType? with
| none => useSource ()
| some expectedType => do
expectedType ← whnf expectedType;
match expectedType.getAppFn with
| Expr.const constName _ _ => pure constName
| _ => useSource ()
inductive FieldLHS
| fieldName (ref : Syntax) (name : Name)
| fieldIndex (ref : Syntax) (idx : Nat)
| modifyOp (ref : Syntax) (index : Syntax)
instance FieldLHS.inhabited : Inhabited FieldLHS := ⟨FieldLHS.fieldName (arbitrary _) (arbitrary _)⟩
instance FieldLHS.hasFormat : HasFormat FieldLHS :=
⟨fun lhs => match lhs with
| FieldLHS.fieldName _ n => fmt n
| FieldLHS.fieldIndex _ i => fmt i
| FieldLHS.modifyOp _ i => "[" ++ i.prettyPrint ++ "]"⟩
inductive FieldVal (σ : Type)
| term (stx : Syntax) : FieldVal
| nested (s : σ) : FieldVal
| default : FieldVal -- mark that field must be synthesized using default value
structure Field (σ : Type) :=
(ref : Syntax) (lhs : List FieldLHS) (val : FieldVal σ) (expr? : Option Expr := none)
instance Field.inhabited {σ} : Inhabited (Field σ) := ⟨⟨arbitrary _, [], FieldVal.term (arbitrary _), arbitrary _⟩⟩
def Field.isSimple {σ} : Field σ → Bool
| { lhs := [_], .. } => true
| _ => false
inductive Struct
| mk (ref : Syntax) (structName : Name) (fields : List (Field Struct)) (source : Source)
instance Struct.inhabited : Inhabited Struct := ⟨⟨arbitrary _, arbitrary _, [], arbitrary _⟩⟩
abbrev Fields := List (Field Struct)
/- true if all fields of the given structure are marked as `default` -/
partial def Struct.allDefault : Struct → Bool
| ⟨_, _, fields, _⟩ => fields.all fun ⟨_, _, val, _⟩ => match val with
| FieldVal.term _ => false
| FieldVal.default => true
| FieldVal.nested s => Struct.allDefault s
def Struct.ref : Struct → Syntax
| ⟨ref, _, _, _⟩ => ref
def Struct.structName : Struct → Name
| ⟨_, structName, _, _⟩ => structName
def Struct.fields : Struct → Fields
| ⟨_, _, fields, _⟩ => fields
def Struct.source : Struct → Source
| ⟨_, _, _, s⟩ => s
def formatField (formatStruct : Struct → Format) (field : Field Struct) : Format :=
Format.joinSep field.lhs " . " ++ " := " ++
match field.val with
| FieldVal.term v => v.prettyPrint
| FieldVal.nested s => formatStruct s
| FieldVal.default => "<default>"
partial def formatStruct : Struct → Format
| ⟨_, structName, fields, source⟩ =>
let fieldsFmt := Format.joinSep (fields.map (formatField formatStruct)) ", ";
match source with
| Source.none => "{" ++ fieldsFmt ++ "}"
| Source.implicit _ => "{" ++ fieldsFmt ++ " .. }"
| Source.explicit _ src => "{" ++ format src ++ " with " ++ fieldsFmt ++ "}"
instance Struct.hasFormat : HasFormat Struct := ⟨formatStruct⟩
instance Struct.hasToString : HasToString Struct := ⟨toString ∘ format⟩
instance Field.hasFormat : HasFormat (Field Struct) := ⟨formatField formatStruct⟩
instance Field.hasToString : HasToString (Field Struct) := ⟨toString ∘ format⟩
/-
Recall that `structInstField` elements have the form
```
def structInstField := parser! structInstLVal >> " := " >> termParser
def structInstLVal := (ident <|> numLit <|> structInstArrayRef) >> many (("." >> (ident <|> numLit)) <|> structInstArrayRef)
def structInstArrayRef := parser! "[" >> termParser >>"]"
-/
-- Remark: this code relies on the fact that `expandStruct` only transforms `fieldLHS.fieldName`
def FieldLHS.toSyntax (first : Bool) : FieldLHS → Syntax
| FieldLHS.modifyOp stx _ => stx
| FieldLHS.fieldName stx name => if first then mkIdentFrom stx name else mkNullNode #[mkAtomFrom stx ".", mkIdentFrom stx name]
| FieldLHS.fieldIndex stx _ => if first then stx else mkNullNode #[mkAtomFrom stx ".", stx]
def FieldVal.toSyntax : FieldVal Struct → Syntax
| FieldVal.term stx => stx
| _ => unreachable!
def Field.toSyntax : Field Struct → Syntax
| field =>
let stx := field.ref;
let stx := stx.setArg 3 field.val.toSyntax;
match field.lhs with
| first::rest =>
let stx := stx.setArg 0 $ first.toSyntax true;
let stx := stx.setArg 1 $ mkNullNode $ rest.toArray.map (FieldLHS.toSyntax false);
stx
| _ => unreachable!
private def toFieldLHS (stx : Syntax) : Except String FieldLHS :=
if stx.getKind == `Lean.Parser.Term.structInstArrayRef then
pure $ FieldLHS.modifyOp stx (stx.getArg 1)
else
-- Note that the representation of the first field is different.
let stx := if stx.getKind == nullKind then stx.getArg 1 else stx;
if stx.isIdent then pure $ FieldLHS.fieldName stx stx.getId.eraseMacroScopes
else match stx.isFieldIdx? with
| some idx => pure $ FieldLHS.fieldIndex stx idx
| none => throw "unexpected structure syntax"
private def mkStructView (stx : Syntax) (structName : Name) (source : Source) : Except String Struct := do
let args := (stx.getArg 2).getArgs;
let fieldsStx := args.filter $ fun arg => arg.getKind == `Lean.Parser.Term.structInstField;
fields ← fieldsStx.toList.mapM $ fun fieldStx => do {
let val := fieldStx.getArg 3;
first ← toFieldLHS (fieldStx.getArg 0);
rest ← (fieldStx.getArg 1).getArgs.toList.mapM toFieldLHS;
pure $ ({ref := fieldStx, lhs := first :: rest, val := FieldVal.term val } : Field Struct)
};
pure ⟨stx, structName, fields, source⟩
def Struct.modifyFieldsM {m : Type → Type} [Monad m] (s : Struct) (f : Fields → m Fields) : m Struct :=
match s with
| ⟨ref, structName, fields, source⟩ => do fields ← f fields; pure ⟨ref, structName, fields, source⟩
@[inline] def Struct.modifyFields (s : Struct) (f : Fields → Fields) : Struct :=
Id.run $ s.modifyFieldsM f
def Struct.setFields (s : Struct) (fields : Fields) : Struct :=
s.modifyFields $ fun _ => fields
private def expandCompositeFields (s : Struct) : Struct :=
s.modifyFields $ fun fields => fields.map $ fun field => match field with
| { lhs := FieldLHS.fieldName ref (Name.str Name.anonymous _ _) :: rest, .. } => field
| { lhs := FieldLHS.fieldName ref n@(Name.str _ _ _) :: rest, .. } =>
let newEntries := n.components.map $ FieldLHS.fieldName ref;
{ field with lhs := newEntries ++ rest }
| _ => field
private def expandNumLitFields (s : Struct) : TermElabM Struct :=
s.modifyFieldsM $ fun fields => do
env ← getEnv;
let fieldNames := getStructureFields env s.structName;
fields.mapM $ fun field => match field with
| { lhs := FieldLHS.fieldIndex ref idx :: rest, .. } =>
if idx == 0 then throwErrorAt ref "invalid field index, index must be greater than 0"
else if idx > fieldNames.size then throwErrorAt ref ("invalid field index, structure has only #" ++ toString fieldNames.size ++ " fields")
else pure { field with lhs := FieldLHS.fieldName ref (fieldNames.get! $ idx - 1) :: rest }
| _ => pure field
/- For example, consider the following structures:
```
structure A := (x : Nat)
structure B extends A := (y : Nat)
structure C extends B := (z : Bool)
```
This method expands parent structure fields using the path to the parent structure.
For example,
```
{ x := 0, y := 0, z := true : C }
```
is expanded into
```
{ toB.toA.x := 0, toB.y := 0, z := true : C }
``` -/
private def expandParentFields (s : Struct) : TermElabM Struct := do
env ← getEnv;
s.modifyFieldsM $ fun fields => fields.mapM $ fun field => match field with
| { lhs := FieldLHS.fieldName ref fieldName :: rest, .. } =>
match findField? env s.structName fieldName with
| none => throwErrorAt ref ("'" ++ fieldName ++ "' is not a field of structure '" ++ s.structName ++ "'")
| some baseStructName =>
if baseStructName == s.structName then pure field
else match getPathToBaseStructure? env baseStructName s.structName with
| some path => do
let path := path.map $ fun funName => match funName with
| Name.str _ s _ => FieldLHS.fieldName ref (mkNameSimple s)
| _ => unreachable!;
pure { field with lhs := path ++ field.lhs }
| _ => throwErrorAt ref ("failed to access field '" ++ fieldName ++ "' in parent structure")
| _ => pure field
private abbrev FieldMap := HashMap Name Fields
private def mkFieldMap (fields : Fields) : TermElabM FieldMap :=
fields.foldlM
(fun fieldMap field =>
match field.lhs with
| FieldLHS.fieldName _ fieldName :: rest =>
match fieldMap.find? fieldName with
| some (prevField::restFields) =>
if field.isSimple || prevField.isSimple then
throwErrorAt field.ref ("field '" ++ fieldName ++ "' has already beed specified")
else
pure $ fieldMap.insert fieldName (field::prevField::restFields)
| _ => pure $ fieldMap.insert fieldName [field]
| _ => unreachable!)
{}
private def isSimpleField? : Fields → Option (Field Struct)
| [field] => if field.isSimple then some field else none
| _ => none
private def getFieldIdx (structName : Name) (fieldNames : Array Name) (fieldName : Name) : TermElabM Nat := do
match fieldNames.findIdx? $ fun n => n == fieldName with
| some idx => pure idx
| none => throwError ("field '" ++ fieldName ++ "' is not a valid field of '" ++ structName ++ "'")
private def mkProjStx (s : Syntax) (fieldName : Name) : Syntax :=
Syntax.node `Lean.Parser.Term.proj #[s, mkAtomFrom s ".", mkIdentFrom s fieldName]
private def mkSubstructSource (structName : Name) (fieldNames : Array Name) (fieldName : Name) (src : Source) : TermElabM Source :=
match src with
| Source.explicit stx src => do
idx ← getFieldIdx structName fieldNames fieldName;
let stx := stx.modifyArg 0 $ fun stx => mkProjStx stx fieldName;
pure $ Source.explicit stx (mkProj structName idx src)
| s => pure s
@[specialize] private def groupFields (expandStruct : Struct → TermElabM Struct) (s : Struct) : TermElabM Struct := do
env ← getEnv;
let fieldNames := getStructureFields env s.structName;
withRef s.ref $
s.modifyFieldsM $ fun fields => do
fieldMap ← mkFieldMap fields;
fieldMap.toList.mapM $ fun ⟨fieldName, fields⟩ =>
match isSimpleField? fields with
| some field => pure field
| none => do
let substructFields := fields.map $ fun field => { field with lhs := field.lhs.tail! };
substructSource ← mkSubstructSource s.structName fieldNames fieldName s.source;
let field := fields.head!;
match Lean.isSubobjectField? env s.structName fieldName with
| some substructName => do
let substruct := Struct.mk s.ref substructName substructFields substructSource;
substruct ← expandStruct substruct;
pure { field with lhs := [field.lhs.head!], val := FieldVal.nested substruct }
| none => do
-- It is not a substructure field. Thus, we wrap fields using `Syntax`, and use `elabTerm` to process them.
let valStx := s.ref; -- construct substructure syntax using s.ref as template
let valStx := valStx.setArg 4 mkNullNode; -- erase optional expected type
let args := substructFields.toArray.map $ Field.toSyntax;
let valStx := valStx.setArg 2 (mkSepStx args (mkAtomFrom s.ref ","));
let valStx := setStructSourceSyntax valStx substructSource;
pure { field with lhs := [field.lhs.head!], val := FieldVal.term valStx }
def findField? (fields : Fields) (fieldName : Name) : Option (Field Struct) :=
fields.find? $ fun field =>
match field.lhs with
| [FieldLHS.fieldName _ n] => n == fieldName
| _ => false
@[specialize] private def addMissingFields (expandStruct : Struct → TermElabM Struct) (s : Struct) : TermElabM Struct := do
env ← getEnv;
let fieldNames := getStructureFields env s.structName;
let ref := s.ref;
withRef ref do
fields ← fieldNames.foldlM
(fun fields fieldName => do
match findField? s.fields fieldName with
| some field => pure $ field::fields
| none =>
let addField (val : FieldVal Struct) : TermElabM Fields := do {
pure $ { ref := s.ref, lhs := [FieldLHS.fieldName s.ref fieldName], val := val } :: fields
};
match Lean.isSubobjectField? env s.structName fieldName with
| some substructName => do
substructSource ← mkSubstructSource s.structName fieldNames fieldName s.source;
let substruct := Struct.mk s.ref substructName [] substructSource;
substruct ← expandStruct substruct;
addField (FieldVal.nested substruct)
| none =>
match s.source with
| Source.none => addField FieldVal.default
| Source.implicit _ => addField (FieldVal.term (mkHole s.ref))
| Source.explicit stx _ =>
-- stx is of the form `optional (try (termParser >> "with"))`
let src := stx.getArg 0;
let val := mkProjStx src fieldName;
addField (FieldVal.term val))
[];
pure $ s.setFields fields.reverse
private partial def expandStruct : Struct → TermElabM Struct
| s => do
let s := expandCompositeFields s;
s ← expandNumLitFields s;
s ← expandParentFields s;
s ← groupFields expandStruct s;
addMissingFields expandStruct s
structure CtorHeaderResult :=
(ctorFn : Expr)
(ctorFnType : Expr)
(instMVars : Array MVarId := #[])
private def mkCtorHeaderAux : Nat → Expr → Expr → Array MVarId → TermElabM CtorHeaderResult
| 0, type, ctorFn, instMVars => pure { ctorFn := ctorFn, ctorFnType := type, instMVars := instMVars }
| n+1, type, ctorFn, instMVars => do
type ← whnfForall type;
match type with
| Expr.forallE _ d b c =>
match c.binderInfo with
| BinderInfo.instImplicit => do
a ← mkFreshExprMVar d MetavarKind.synthetic;
mkCtorHeaderAux n (b.instantiate1 a) (mkApp ctorFn a) (instMVars.push a.mvarId!)
| _ => do
a ← mkFreshExprMVar d;
mkCtorHeaderAux n (b.instantiate1 a) (mkApp ctorFn a) instMVars
| _ => throwError "unexpected constructor type"
private partial def getForallBody : Nat → Expr → Option Expr
| i+1, Expr.forallE _ _ b _ => getForallBody i b
| i+1, _ => none
| 0, type => type
private def propagateExpectedType (type : Expr) (numFields : Nat) (expectedType? : Option Expr) : TermElabM Unit :=
match expectedType? with
| none => pure ()
| some expectedType =>
match getForallBody numFields type with
| none => pure ()
| some typeBody =>
unless typeBody.hasLooseBVars $ do
_ ← isDefEq expectedType typeBody;
pure ()
private def mkCtorHeader (ctorVal : ConstructorVal) (expectedType? : Option Expr) : TermElabM CtorHeaderResult := do
lvls ← ctorVal.lparams.mapM $ fun _ => mkFreshLevelMVar;
let val := Lean.mkConst ctorVal.name lvls;
let type := (ConstantInfo.ctorInfo ctorVal).instantiateTypeLevelParams lvls;
r ← mkCtorHeaderAux ctorVal.nparams type val #[];
propagateExpectedType r.ctorFnType ctorVal.nfields expectedType?;
synthesizeAppInstMVars r.instMVars;
pure r
def markDefaultMissing (e : Expr) : Expr :=
mkAnnotation `structInstDefault e
def defaultMissing? (e : Expr) : Option Expr :=
annotation? `structInstDefault e
def throwFailedToElabField {α} (fieldName : Name) (structName : Name) (msgData : MessageData) : TermElabM α :=
throwError ("failed to elaborate field '" ++ fieldName ++ "' of '" ++ structName ++ ", " ++ msgData)
def trySynthStructInstance? (s : Struct) (expectedType : Expr) : TermElabM (Option Expr) :=
if !s.allDefault then pure none
else
catch (synthInstance? expectedType) (fun _ => pure none)
private partial def elabStruct : Struct → Option Expr → TermElabM (Expr × Struct)
| s, expectedType? => withRef s.ref do
env ← getEnv;
let ctorVal := getStructureCtor env s.structName;
{ ctorFn := ctorFn, ctorFnType := ctorFnType, .. } ← mkCtorHeader ctorVal expectedType?;
(e, _, fields) ← s.fields.foldlM
(fun (acc : Expr × Expr × Fields) field =>
let (e, type, fields) := acc;
match field.lhs with
| [FieldLHS.fieldName ref fieldName] => do
type ← whnfForall type;
match type with
| Expr.forallE _ d b c =>
let continue (val : Expr) (field : Field Struct) : TermElabM (Expr × Expr × Fields) := do {
let e := mkApp e val;
let type := b.instantiate1 val;
let field := { field with expr? := some val };
pure (e, type, field::fields)
};
match field.val with
| FieldVal.term stx => do val ← elabTermEnsuringType stx d; continue val field
| FieldVal.nested s => do
val? ← trySynthStructInstance? s d; -- if all fields of `s` are marked as `default`, then try to synthesize instance
match val? with
| some val => continue val { field with val := FieldVal.term (mkHole field.ref) }
| none => do(val, sNew) ← elabStruct s (some d); val ← ensureHasType d val; continue val { field with val := FieldVal.nested sNew }
| FieldVal.default => do val ← withRef field.ref $ mkFreshExprMVar (some d); continue (markDefaultMissing val) field
| _ => withRef field.ref $ throwFailedToElabField fieldName s.structName ("unexpected constructor type" ++ indentExpr type)
| _ => throwErrorAt field.ref "unexpected unexpanded structure field")
(ctorFn, ctorFnType, []);
pure (e, s.setFields fields.reverse)
namespace DefaultFields
structure Context :=
-- We must search for default values overriden in derived structures
(structs : Array Struct := #[])
(allStructNames : Array Name := #[])
/-
Consider the following example:
```
structure A :=
(x : Nat := 1)
structure B extends A :=
(y : Nat := x + 1) (x := y + 1)
structure C extends B :=
(z : Nat := 2*y) (x := z + 3)
```
And we are trying to elaborate a structure instance for `C`. There are default values for `x` at `A`, `B`, and `C`.
We say the default value at `C` has distance 0, the one at `B` distance 1, and the one at `A` distance 2.
The field `maxDistance` specifies the maximum distance considered in a round of Default field computation.
Remark: since `C` does not set a default value of `y`, the default value at `B` is at distance 0.
The fixpoint for setting default values works in the following way.
- Keep computing default values using `maxDistance == 0`.
- We increase `maxDistance` whenever we failed to compute a new default value in a round.
- If `maxDistance > 0`, then we interrupt a round as soon as we compute some default value.
We use depth-first search.
- We sign an error if no progress is made when `maxDistance` == structure hierarchy depth (2 in the example above).
-/
(maxDistance : Nat := 0)
structure State :=
(progress : Bool := false)
partial def collectStructNames : Struct → Array Name → Array Name
| struct, names =>
let names := names.push struct.structName;
struct.fields.foldl
(fun names field =>
match field.val with
| FieldVal.nested struct => collectStructNames struct names
| _ => names)
names
partial def getHierarchyDepth : Struct → Nat
| struct =>
struct.fields.foldl
(fun max field =>
match field.val with
| FieldVal.nested struct => Nat.max max (getHierarchyDepth struct + 1)
| _ => max)
0
partial def findDefaultMissing? (mctx : MetavarContext) : Struct → Option (Field Struct)
| struct =>
struct.fields.findSome? $ fun field =>
match field.val with
| FieldVal.nested struct => findDefaultMissing? struct
| _ => match field.expr? with
| none => unreachable!
| some expr => match defaultMissing? expr with
| some (Expr.mvar mvarId _) => if mctx.isExprAssigned mvarId then none else some field
| _ => none
def getFieldName (field : Field Struct) : Name :=
match field.lhs with
| [FieldLHS.fieldName _ fieldName] => fieldName
| _ => unreachable!
abbrev M := ReaderT Context (StateRefT State TermElabM)
def isRoundDone : M Bool := do
ctx ← read;
s ← get;
pure (s.progress && ctx.maxDistance > 0)
def getFieldValue? (struct : Struct) (fieldName : Name) : Option Expr :=
struct.fields.findSome? $ fun field =>
if getFieldName field == fieldName then
field.expr?
else
none
partial def mkDefaultValueAux? (struct : Struct) : Expr → TermElabM (Option Expr)
| Expr.lam n d b c => withRef struct.ref $
if c.binderInfo.isExplicit then
let fieldName := n;
match getFieldValue? struct fieldName with
| none => pure none
| some val => do
valType ← inferType val;
condM (isDefEq valType d)
(mkDefaultValueAux? (b.instantiate1 val))
(pure none)
else do
arg ← mkFreshExprMVar d;
mkDefaultValueAux? (b.instantiate1 arg)
| e =>
if e.isAppOfArity `id 2 then
pure (some e.appArg!)
else
pure (some e)
def mkDefaultValue? (struct : Struct) (cinfo : ConstantInfo) : TermElabM (Option Expr) :=
withRef struct.ref do
us ← cinfo.lparams.mapM $ fun _ => mkFreshLevelMVar;
mkDefaultValueAux? struct (cinfo.instantiateValueLevelParams us)
/-- If `e` is a projection function of one of the given structures, then reduce it -/
def reduceProjOf? (structNames : Array Name) (e : Expr) : MetaM (Option Expr) := do
if !e.isApp then pure none
else match e.getAppFn with
| Expr.const name _ _ => do
env ← getEnv;
match env.getProjectionStructureName? name with
| some structName =>
if structNames.contains structName then
Meta.unfoldDefinition? e
else
pure none
| none => pure none
| _ => pure none
/-- Reduce default value. It performs beta reduction and projections of the given structures. -/
partial def reduce (structNames : Array Name) : Expr → MetaM Expr
| e@(Expr.lam _ _ _ _) => lambdaLetTelescope e $ fun xs b => do b ← reduce b; mkLambdaFVars xs b
| e@(Expr.forallE _ _ _ _) => forallTelescope e $ fun xs b => do b ← reduce b; mkForallFVars xs b
| e@(Expr.letE _ _ _ _ _) => lambdaLetTelescope e $ fun xs b => do b ← reduce b; mkLetFVars xs b
| e@(Expr.proj _ i b _) => do
r? ← Meta.reduceProj? b i;
match r? with
| some r => reduce r
| none => do b ← reduce b; pure $ e.updateProj! b
| e@(Expr.app f _ _) => do
r? ← reduceProjOf? structNames e;
match r? with
| some r => reduce r
| none => do
let f := f.getAppFn;
f' ← reduce f;
if f'.isLambda then
let revArgs := e.getAppRevArgs;
reduce $ f'.betaRev revArgs
else do
args ← e.getAppArgs.mapM reduce;
pure (mkAppN f' args)
| e@(Expr.mdata _ b _) => do
b ← reduce b;
if (defaultMissing? e).isSome && !b.isMVar then
pure b
else
pure $ e.updateMData! b
| e@(Expr.mvar mvarId _) => do
val? ← getExprMVarAssignment? mvarId;
match val? with
| some val => if val.isMVar then reduce val else pure val
| none => pure e
| e => pure e
partial def tryToSynthesizeDefaultAux (structs : Array Struct) (allStructNames : Array Name) (maxDistance : Nat)
(fieldName : Name) (mvarId : MVarId) : Nat → Nat → TermElabM Bool
| i, dist =>
if dist > maxDistance then pure false
else if h : i < structs.size then do
let struct := structs.get ⟨i, h⟩;
let defaultName := struct.structName ++ fieldName ++ `_default;
env ← getEnv;
match env.find? defaultName with
| some cinfo@(ConstantInfo.defnInfo defVal) => do
mctx ← getMCtx;
val? ← mkDefaultValue? struct cinfo;
match val? with
| none => do setMCtx mctx; tryToSynthesizeDefaultAux (i+1) (dist+1)
| some val => do
val ← liftMetaM $ reduce allStructNames val;
match val.find? $ fun e => (defaultMissing? e).isSome with
| some _ => do setMCtx mctx; tryToSynthesizeDefaultAux (i+1) (dist+1)
| none => do
mvarDecl ← getMVarDecl mvarId;
val ← ensureHasType mvarDecl.type val;
assignExprMVar mvarId val;
pure true
| _ => tryToSynthesizeDefaultAux (i+1) dist
else
pure false
def tryToSynthesizeDefault (structs : Array Struct) (allStructNames : Array Name)
(maxDistance : Nat) (fieldName : Name) (mvarId : MVarId) : TermElabM Bool :=
tryToSynthesizeDefaultAux structs allStructNames maxDistance fieldName mvarId 0 0
partial def step : Struct → M Unit
| struct => unlessM isRoundDone $ adaptReader (fun (ctx : Context) => { ctx with structs := ctx.structs.push struct }) $ do
struct.fields.forM $ fun field =>
match field.val with
| FieldVal.nested struct => step struct
| _ => match field.expr? with
| none => unreachable!
| some expr => match defaultMissing? expr with
| some (Expr.mvar mvarId _) =>
unlessM (liftM $ isExprMVarAssigned mvarId) $ do
ctx ← read;
whenM (liftM $ withRef field.ref $ tryToSynthesizeDefault ctx.structs ctx.allStructNames ctx.maxDistance (getFieldName field) mvarId) $ do
modify $ fun s => { s with progress := true }
| _ => pure ()
partial def propagateLoop (hierarchyDepth : Nat) : Nat → Struct → M Unit
| d, struct => do
mctx ← getMCtx;
match findDefaultMissing? mctx struct with
| none => pure () -- Done
| some field =>
if d > hierarchyDepth then
throwErrorAt field.ref ("field '" ++ getFieldName field ++ "' is missing")
else adaptReader (fun (ctx : Context) => { ctx with maxDistance := d }) $ do
modify $ fun (s : State) => { s with progress := false };
step struct;
s ← get;
if s.progress then do
propagateLoop 0 struct
else
propagateLoop (d+1) struct
def propagate (struct : Struct) : TermElabM Unit :=
let hierarchyDepth := getHierarchyDepth struct;
let structNames := collectStructNames struct #[];
(propagateLoop hierarchyDepth 0 struct { allStructNames := structNames }).run' {}
end DefaultFields
private def elabStructInstAux (stx : Syntax) (expectedType? : Option Expr) (source : Source) : TermElabM Expr := do
structName ← getStructName stx expectedType? source;
env ← getEnv;
unless (isStructureLike env structName) $
throwError ("invalid {...} notation, '" ++ structName ++ "' is not a structure");
match mkStructView stx structName source with
| Except.error ex => throwError ex
| Except.ok struct => do
struct ← expandStruct struct;
trace `Elab.struct fun _ => toString struct;
(r, struct) ← elabStruct struct expectedType?;
DefaultFields.propagate struct;
pure r
@[builtinTermElab structInst] def elabStructInst : TermElab :=
fun stx expectedType? => do
stxNew? ← expandNonAtomicExplicitSource stx;
match stxNew? with
| some stxNew => withMacroExpansion stx stxNew $ elabTerm stxNew expectedType?
| none => do
sourceView ← getStructSource stx;
modifyOp? ← isModifyOp? stx;
match modifyOp?, sourceView with
| some modifyOp, Source.explicit source _ => elabModifyOp stx modifyOp source expectedType?
| some _, _ => throwError ("invalid {...} notation, explicit source is required when using '[<index>] := <value>'")
| _, _ => elabStructInstAux stx expectedType? sourceView
@[init] private def regTraceClasses : IO Unit := do
registerTraceClass `Elab.struct;
pure ()
end StructInst
end Term
end Elab
end Lean