/-
Copyright (c) 2020 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
|
||
import Lean.Util.FindExpr
|
||
import Lean.Elab.App
|
||
import Lean.Elab.Binders
|
||
import Lean.Elab.Quotation
|
||
|
||
namespace Lean.Elab.Term.StructInst
|
||
|
||
open Std (HashMap)
|
||
open Meta
|
||
|
||
/- parser! "{" >> optional (try (termParser >> "with")) >> sepBy structInstField ", " true >> optional ".." >> optional (" : " >> termParser) >> "}" -/
|
||
|
||
/- Macro for structure instances that carry the optional expected-type suffix (argument 4).
   If that argument is absent the macro reports "unsupported"; otherwise it erases the suffix
   from the structure instance and wraps the result in a type ascription `(stxNew : expected)`. -/
@[builtinMacro Lean.Parser.Term.structInst] def expandStructInstExpectedType : Macro := fun stx =>
  let expectedArg := stx[4]
  if expectedArg.isNone then
    Macro.throwUnsupported
  else
    let expected := expectedArg[1]
    let stxNew   := stx.setArg 4 mkNullNode
    `(($stxNew : $expected))
|
||
|
||
/-
If `stx` is of the form `{ s with ... }` and `s` is not a local variable, expand into `let src := s; { src with ... }`.

Note that this one is not a `Macro` because we need to access the local context.

Returns `none` when there is no explicit source (argument 1 is empty) or when the source is
already a local identifier; otherwise returns the `let`-expanded syntax.
-/
private def expandNonAtomicExplicitSource (stx : Syntax) : TermElabM (Option Syntax) :=
  withFreshMacroScope do
    let sourceOpt := stx[1]
    if sourceOpt.isNone then
      pure none
    else
      let source := sourceOpt[0]
      match (← isLocalIdent? source) with
      | some _ => pure none
      | none   =>
        let src ← `(src)
        -- replace the source term by the fresh identifier `src` inside the structure instance
        let sourceOpt := sourceOpt.setArg 0 src
        let stxNew    := stx.setArg 1 sourceOpt
        `(let src := $source; $stxNew)
|
||
|
||
/- Describes where the values of unspecified fields come from. -/
inductive Source :=
  | none     -- structure instance source has not been provided
  | implicit (stx : Syntax) -- `..`
  | explicit (stx : Syntax) (src : Expr) -- `src with`

instance : Inhabited Source := ⟨Source.none⟩
|
||
|
||
/- `true` iff the source is `Source.none`. -/
def Source.isNone : Source → Bool
  | Source.none => true
  | _           => false
|
||
|
||
/- Overwrites the source slots of a structure-instance syntax node according to the given `Source`:
   argument 1 holds the explicit source and argument 3 the implicit one.
   The slot that does not apply is reset to the empty null node. -/
def setStructSourceSyntax (structStx : Syntax) : Source → Syntax
  | Source.none           => (structStx.setArg 1 mkNullNode).setArg 3 mkNullNode
  | Source.implicit stx   => (structStx.setArg 1 mkNullNode).setArg 3 stx
  | Source.explicit stx _ => (structStx.setArg 1 stx).setArg 3 mkNullNode
|
||
|
||
/- Reads the source of a structure instance from arguments 1 (explicit) and 3 (implicit) of `stx`.
   Throws an error when both are present. For an explicit source, the source term is expected
   to be a local identifier already. -/
private def getStructSource (stx : Syntax) : TermElabM Source :=
  withRef stx do
    let explicitSource := stx[1]
    let implicitSource := stx[3]
    if explicitSource.isNone && implicitSource.isNone then
      pure Source.none
    else if explicitSource.isNone then
      pure $ Source.implicit implicitSource
    else if implicitSource.isNone then
      let fvar? ← isLocalIdent? explicitSource[0]
      match fvar? with
      | none     => unreachable! -- expandNonAtomicExplicitSource must have been used when we get here
      | some src => pure $ Source.explicit explicitSource src
    else
      throwError "invalid structure instance `with` and `..` cannot be used together"
|
||
|
||
/-
We say a `{ ... }` notation is a `modifyOp` if it contains only one field, and the left-hand side
of that field starts with a `structInstArrayRef`, i.e.,
```
def structInstArrayRef := parser! "[" >> termParser >>"]"
```

Returns the single field syntax when `stx` is a `modifyOp`, and `none` otherwise.
Throws an error when a `[..]` field is combined with any other field at the same level.
-/
private def isModifyOp? (stx : Syntax) : TermElabM (Option Syntax) := do
  let s? ← stx[2].getSepArgs.foldlM (init := none) fun s? arg =>
    /- Remark: the syntax for `structInstField` is
       ```
       def structInstLVal   := (ident <|> numLit <|> structInstArrayRef) >> many (group ("." >> (ident <|> numLit)) <|> structInstArrayRef)
       def structInstField  := parser! structInstLVal >> " := " >> termParser
       ``` -/
    let lval := arg[0]
    let k    := lval.getKind
    -- `s?` stores a previously visited *field*; its leading lval is `s[0]`, so that is the node
    -- whose kind must be compared (the field node itself is never a `structInstArrayRef`).
    if k == `Lean.Parser.Term.structInstArrayRef then
      match s? with
      | none   => pure (some arg)
      | some s =>
        if s[0].getKind == `Lean.Parser.Term.structInstArrayRef then
          throwErrorAt arg "invalid {...} notation, at most one `[..]` at a given level"
        else
          throwErrorAt arg "invalid {...} notation, can't mix field and `[..]` at a given level"
    else
      match s? with
      | none   => pure (some arg)
      | some s =>
        if s[0].getKind == `Lean.Parser.Term.structInstArrayRef then
          throwErrorAt arg "invalid {...} notation, can't mix field and `[..]` at a given level"
        else
          pure s?
  match s? with
  | none   => pure none
  | some s => if s[0].getKind == `Lean.Parser.Term.structInstArrayRef then pure s? else pure none
|
||
|
||
/- Elaborates a `modifyOp` structure instance `{ src with [idx]... := val }` by rewriting it into
   `src.modifyOp (idx := idx) (fun s => val')` and elaborating the result.
   When the left-hand side has more components after `[idx]`, `val'` is itself a structure
   instance `{ s with <remaining lhs> := val }` built from `stx`. -/
private def elabModifyOp (stx modifyOp source : Syntax) (expectedType? : Option Expr) : TermElabM Expr := do
  let cont (val : Syntax) : TermElabM Expr := do
    let lval := modifyOp[0]
    let idx  := lval[1]
    let self := source[0]
    let stxNew ← `($(self).modifyOp (idx := $idx) (fun s => $val))
    trace[Elab.struct.modifyOp]! "{stx}\n===>\n{stxNew}"
    withMacroExpansion stx stxNew $ elabTerm stxNew expectedType?
  trace[Elab.struct.modifyOp]! "{modifyOp}\nSource: {source}"
  let rest := modifyOp[1]
  if rest.isNone then
    cont modifyOp[3]
  else
    let s ← `(s)
    let valFirst  := rest[0]
    -- a non-`[..]` component is a group `"." >> (ident <|> numLit)`; drop the leading "."
    let valFirst  := if valFirst.getKind == `Lean.Parser.Term.structInstArrayRef then valFirst else valFirst[1]
    let restArgs  := rest.getArgs
    let valRest   := mkNullNode restArgs[1:restArgs.size]
    let valField  := modifyOp.setArg 0 valFirst
    let valField  := valField.setArg 1 valRest
    let valSource := source.modifyArg 0 fun _ => s
    let val       := stx.setArg 1 valSource
    let val       := val.setArg 2 $ mkNullNode #[valField]
    trace[Elab.struct.modifyOp]! "{stx}\nval: {val}"
    cont val
|
||
|
||
/- Get structure name and elaborate explicit source (if available).
   The expected type is tried first: if its weak-head normal form is an application of a constant,
   that constant is the structure name. Otherwise the type of the explicit source is used, and an
   error is thrown when neither provides a constant head. -/
private def getStructName (stx : Syntax) (expectedType? : Option Expr) (sourceView : Source) : TermElabM Name := do
  tryPostponeIfNoneOrMVar expectedType?
  let useSource : Unit → TermElabM Name := fun _ =>
    match sourceView, expectedType? with
    | Source.explicit _ src, _ => do
      let srcType ← inferType src
      let srcType ← whnf srcType
      tryPostponeIfMVar srcType
      match srcType.getAppFn with
      | Expr.const constName _ _ => pure constName
      | _ => throwError! "invalid \{...} notation, source type is not of the form (C ...){indentExpr srcType}"
    | _, some expectedType => throwError! "invalid \{...} notation, expected type is not of the form (C ...){indentExpr expectedType}"
    | _, none              => throwError! "invalid \{...} notation, expected type must be known"
  match expectedType? with
  | none => useSource ()
  | some expectedType =>
    let expectedType ← whnf expectedType
    match expectedType.getAppFn with
    | Expr.const constName _ _ => pure constName
    | _                        => useSource ()
|
||
|
||
/- One component of the left-hand side of a field assignment: a field name, a numeric field
   index, or an array reference `[index]`. `ref` is the syntax the component came from. -/
inductive FieldLHS :=
  | fieldName  (ref : Syntax) (name : Name)
  | fieldIndex (ref : Syntax) (idx : Nat)
  | modifyOp   (ref : Syntax) (index : Syntax)

instance : Inhabited FieldLHS := ⟨FieldLHS.fieldName (arbitrary _) (arbitrary _)⟩

instance : ToFormat FieldLHS := ⟨fun lhs =>
  match lhs with
  | FieldLHS.fieldName _ n  => fmt n
  | FieldLHS.fieldIndex _ i => fmt i
  | FieldLHS.modifyOp _ i   => "[" ++ i.prettyPrint ++ "]"⟩
|
||
|
||
/- Right-hand side of a field: user-provided syntax, a nested structure of type `σ`,
   or a marker for a value that must come from a default. -/
inductive FieldVal (σ : Type) :=
  | term  (stx : Syntax) : FieldVal σ
  | nested (s : σ)       : FieldVal σ
  | default              : FieldVal σ -- mark that field must be synthesized using default value
|
||
|
||
/- A field assignment `lhs := val`. `ref` is the originating syntax and `expr?` stores the
   elaborated value once it is available (it is set by `elabStruct`). -/
structure Field (σ : Type) :=
  (ref   : Syntax)
  (lhs   : List FieldLHS)
  (val   : FieldVal σ)
  (expr? : Option Expr := none)

instance {σ} : Inhabited (Field σ) := ⟨⟨arbitrary _, [], FieldVal.term (arbitrary _), arbitrary _⟩⟩
|
||
|
||
/- `true` iff the left-hand side of the field has exactly one component. -/
def Field.isSimple {σ} : Field σ → Bool
  | { lhs := [_], .. } => true
  | _                  => false
|
||
|
||
/- View of a structure instance: originating syntax, structure name, fields, and source. -/
inductive Struct :=
  | mk (ref : Syntax) (structName : Name) (fields : List (Field Struct)) (source : Source)

instance : Inhabited Struct := ⟨⟨arbitrary _, arbitrary _, [], arbitrary _⟩⟩

abbrev Fields := List (Field Struct)
|
||
|
||
/- true if all fields of the given structure are marked as `default`.
   Nested structures are checked recursively; any `term` field makes the result `false`. -/
partial def Struct.allDefault : Struct → Bool
  | ⟨_, _, fields, _⟩ => fields.all fun ⟨_, _, val, _⟩ => match val with
    | FieldVal.term _   => false
    | FieldVal.default  => true
    | FieldVal.nested s => allDefault s
|
||
|
||
/- Projections for the four components of `Struct.mk`. -/
def Struct.ref : Struct → Syntax
  | ⟨ref, _, _, _⟩ => ref

def Struct.structName : Struct → Name
  | ⟨_, structName, _, _⟩ => structName

def Struct.fields : Struct → Fields
  | ⟨_, _, fields, _⟩ => fields

def Struct.source : Struct → Source
  | ⟨_, _, _, s⟩ => s
|
||
|
||
/- Formats a field as `lhs₁ . lhs₂ ... := value`. Nested structures are rendered with the
   `formatStruct` argument, and default markers are shown as `<default>`. -/
def formatField (formatStruct : Struct → Format) (field : Field Struct) : Format :=
  Format.joinSep field.lhs " . " ++ " := " ++
    match field.val with
    | FieldVal.term v   => v.prettyPrint
    | FieldVal.nested s => formatStruct s
    | FieldVal.default  => "<default>"
|
||
|
||
/- Formats a structure view as `{fields}`, `{fields .. }` or `{src with fields}` depending on its source. -/
partial def formatStruct : Struct → Format
  | ⟨_, structName, fields, source⟩ =>
    let fieldsFmt := Format.joinSep (fields.map (formatField formatStruct)) ", "
    match source with
    | Source.none           => "{" ++ fieldsFmt ++ "}"
    | Source.implicit _     => "{" ++ fieldsFmt ++ " .. }"
    | Source.explicit _ src => "{" ++ format src ++ " with " ++ fieldsFmt ++ "}"
|
||
|
||
/- Formatting and string instances for `Struct` and `Field Struct`, all based on `formatStruct`. -/
instance : ToFormat Struct     := ⟨formatStruct⟩
instance : ToString Struct := ⟨toString ∘ format⟩

instance : ToFormat (Field Struct) := ⟨formatField formatStruct⟩
instance : ToString (Field Struct) := ⟨toString ∘ format⟩
|
||
|
||
/-
Recall that `structInstField` elements have the form
```
   def structInstField  := parser! structInstLVal >> " := " >> termParser
   def structInstLVal   := (ident <|> numLit <|> structInstArrayRef) >> many (("." >> (ident <|> numLit)) <|> structInstArrayRef)
   def structInstArrayRef := parser! "[" >> termParser >>"]"
```
-/
|
||
|
||
-- Remark: this code relies on the fact that `expandStruct` only transforms `fieldLHS.fieldName`
/- Converts a left-hand-side component back to syntax. The first component is emitted bare;
   later name/index components are wrapped in a null node together with a leading "." atom. -/
def FieldLHS.toSyntax (first : Bool) : FieldLHS → Syntax
  | FieldLHS.modifyOp   stx _    => stx
  | FieldLHS.fieldName  stx name => if first then mkIdentFrom stx name else mkNullNode #[mkAtomFrom stx ".", mkIdentFrom stx name]
  | FieldLHS.fieldIndex stx _    => if first then stx else mkNullNode #[mkAtomFrom stx ".", stx]
|
||
|
||
/- Returns the syntax of a `term` value. Any other constructor is treated as unreachable. -/
def FieldVal.toSyntax : FieldVal Struct → Syntax
  | FieldVal.term stx => stx
  | _                 => unreachable!
|
||
|
||
/- Rebuilds the `structInstField` syntax of a field, using `field.ref` as a template:
   argument 3 receives the value, argument 0 the first lhs component and argument 1 the
   remaining components. An empty lhs is treated as unreachable. -/
def Field.toSyntax : Field Struct → Syntax
  | field =>
    let stx := field.ref
    let stx := stx.setArg 3 field.val.toSyntax
    match field.lhs with
    | first::rest =>
      let stx := stx.setArg 0 $ first.toSyntax true
      let stx := stx.setArg 1 $ mkNullNode $ rest.toArray.map (FieldLHS.toSyntax false)
      stx
    | _ => unreachable!
|
||
|
||
/- Converts one lhs syntax component into a `FieldLHS`.
   Fails with "unexpected structure syntax" if the component is neither an array reference,
   an identifier, nor a field index. -/
private def toFieldLHS (stx : Syntax) : Except String FieldLHS :=
  if stx.getKind == `Lean.Parser.Term.structInstArrayRef then
    pure $ FieldLHS.modifyOp stx stx[1]
  else
    -- Note that the representation of the first field is different.
    let stx := if stx.getKind == nullKind then stx[1] else stx
    if stx.isIdent then pure $ FieldLHS.fieldName stx stx.getId.eraseMacroScopes
    else match stx.isFieldIdx? with
      | some idx => pure $ FieldLHS.fieldIndex stx idx
      | none     => throw "unexpected structure syntax"
|
||
|
||
/- Builds the initial `Struct` view from structure-instance syntax: every `structInstField`
   child of argument 2 becomes a `Field` whose value is the raw term syntax. -/
private def mkStructView (stx : Syntax) (structName : Name) (source : Source) : Except String Struct := do
  let args      := stx[2].getArgs
  -- keep only the field nodes (this drops the separators)
  let fieldsStx := args.filter $ fun arg => arg.getKind == `Lean.Parser.Term.structInstField
  let fields ← fieldsStx.toList.mapM fun fieldStx => do
    let val   := fieldStx[3]
    let first ← toFieldLHS fieldStx[0]
    let rest  ← fieldStx[1].getArgs.toList.mapM toFieldLHS
    pure $ ({ref := fieldStx, lhs := first :: rest, val := FieldVal.term val } : Field Struct)
  pure ⟨stx, structName, fields, source⟩
|
||
|
||
/- Applies the monadic function `f` to the field list, keeping all other components of `s`. -/
def Struct.modifyFieldsM {m : Type → Type} [Monad m] (s : Struct) (f : Fields → m Fields) : m Struct :=
  match s with
  | ⟨ref, structName, fields, source⟩ => return ⟨ref, structName, (← f fields), source⟩
|
||
|
||
/- Pure version of `Struct.modifyFieldsM`, run in the `Id` monad. -/
@[inline] def Struct.modifyFields (s : Struct) (f : Fields → Fields) : Struct :=
  Id.run $ s.modifyFieldsM f
|
||
|
||
/- Replaces the field list of `s` with `fields`. -/
def Struct.setFields (s : Struct) (fields : Fields) : Struct :=
  s.modifyFields fun _ => fields
|
||
|
||
/- Splits a leading hierarchical field name into one `fieldName` component per name part,
   e.g. a leading `a.b` becomes the two components `a` and `b`.
   Fields whose leading name has a single component are left unchanged. -/
private def expandCompositeFields (s : Struct) : Struct :=
  s.modifyFields fun fields => fields.map $ fun field => match field with
    | { lhs := FieldLHS.fieldName ref (Name.str Name.anonymous _ _) :: rest, .. } => field
    | { lhs := FieldLHS.fieldName ref n@(Name.str _ _ _) :: rest, .. } =>
      let newEntries := n.components.map $ FieldLHS.fieldName ref
      { field with lhs := newEntries ++ rest }
    | _ => field
|
||
|
||
/- Replaces a leading numeric field index by the corresponding field name of `s.structName`.
   Indices are 1-based: index 0 and indices beyond the number of fields are rejected. -/
private def expandNumLitFields (s : Struct) : TermElabM Struct :=
  s.modifyFieldsM fun fields => do
    let env ← getEnv
    let fieldNames := getStructureFields env s.structName
    fields.mapM fun field => match field with
      | { lhs := FieldLHS.fieldIndex ref idx :: rest, .. } =>
        if idx == 0 then throwErrorAt ref "invalid field index, index must be greater than 0"
        else if idx > fieldNames.size then throwErrorAt! ref "invalid field index, structure has only #{fieldNames.size} fields"
        else pure { field with lhs := FieldLHS.fieldName ref fieldNames[idx - 1] :: rest }
      | _ => pure field
|
||
|
||
/- For example, consider the following structures:
   ```
   structure A := (x : Nat)
   structure B extends A := (y : Nat)
   structure C extends B := (z : Bool)
   ```
   This method expands parent structure fields using the path to the parent structure.
   For example,
   ```
   { x := 0, y := 0, z := true : C }
   ```
   is expanded into
   ```
   { toB.toA.x := 0, toB.y := 0, z := true : C }
   ```
   An error is thrown if the leading field name is not a field of the structure, or if no
   path to the parent structure that declares it can be found. -/
private def expandParentFields (s : Struct) : TermElabM Struct := do
  let env ← getEnv
  s.modifyFieldsM fun fields => fields.mapM fun field => match field with
    | { lhs := FieldLHS.fieldName ref fieldName :: rest, .. } =>
      match findField? env s.structName fieldName with
      | none => throwErrorAt! ref "'{fieldName}' is not a field of structure '{s.structName}'"
      | some baseStructName =>
        if baseStructName == s.structName then pure field
        else match getPathToBaseStructure? env baseStructName s.structName with
          | some path => do
            -- keep only the last component of each function name in the path as a field name
            let path := path.map $ fun funName => match funName with
              | Name.str _ s _ => FieldLHS.fieldName ref (Name.mkSimple s)
              | _              => unreachable!
            pure { field with lhs := path ++ field.lhs }
          | _ => throwErrorAt! ref "failed to access field '{fieldName}' in parent structure"
    | _ => pure field
|
||
|
||
/- Maps the first lhs component (a field name) to all fields that start with it. -/
private abbrev FieldMap := HashMap Name Fields

/- Groups fields by their leading field name. A name may occur more than once only when none of
   the colliding fields is simple (i.e. all of them assign into a sub-component); otherwise an
   error is thrown. Every field is expected to start with a `fieldName` component. -/
private def mkFieldMap (fields : Fields) : TermElabM FieldMap :=
  fields.foldlM (init := {}) fun fieldMap field =>
    match field.lhs with
    | FieldLHS.fieldName _ fieldName :: rest =>
      match fieldMap.find? fieldName with
      | some (prevField::restFields) =>
        if field.isSimple || prevField.isSimple then
          throwErrorAt! field.ref "field '{fieldName}' has already been specified"
        else
          pure $ fieldMap.insert fieldName (field::prevField::restFields)
      | _ => pure $ fieldMap.insert fieldName [field]
    | _ => unreachable!
|
||
|
||
/- Returns the field if the list is a singleton whose element is simple, and `none` otherwise. -/
private def isSimpleField? : Fields → Option (Field Struct)
  | [field] => if field.isSimple then some field else none
  | _       => none
|
||
|
||
/- Position of `fieldName` inside `fieldNames`; throws if the name does not occur. -/
private def getFieldIdx (structName : Name) (fieldNames : Array Name) (fieldName : Name) : TermElabM Nat := do
  match fieldNames.findIdx? $ fun n => n == fieldName with
  | some idx => pure idx
  | none     => throwError! "field '{fieldName}' is not a valid field of '{structName}'"
|
||
|
||
/- Builds the projection syntax node `s.fieldName`, taking source info from `s`. -/
private def mkProjStx (s : Syntax) (fieldName : Name) : Syntax :=
  Syntax.node `Lean.Parser.Term.proj #[s, mkAtomFrom s ".", mkIdentFrom s fieldName]
|
||
|
||
/- Source to use for the sub-structure stored in field `fieldName`.
   An explicit source is projected onto that field, both in its syntax and in its expression;
   `none` and implicit sources are passed through unchanged. -/
private def mkSubstructSource (structName : Name) (fieldNames : Array Name) (fieldName : Name) (src : Source) : TermElabM Source :=
  match src with
  | Source.explicit stx src => do
    let idx ← getFieldIdx structName fieldNames fieldName
    let stx := stx.modifyArg 0 fun stx => mkProjStx stx fieldName
    pure $ Source.explicit stx (mkProj structName idx src)
  | s => pure s
|
||
|
||
/- Groups the fields of `s` by leading field name (see `mkFieldMap`), yielding one field per name.
   A group consisting of a single simple field is kept as is. Otherwise the leading component
   is stripped from every field of the group and the remainder becomes
   - a nested `Struct` (expanded with `expandStruct`) when the field is a subobject field, or
   - a structure-instance term built from `s.ref` when it is not. -/
@[specialize] private def groupFields (expandStruct : Struct → TermElabM Struct) (s : Struct) : TermElabM Struct := do
  let env ← getEnv
  let fieldNames := getStructureFields env s.structName
  withRef s.ref do
    s.modifyFieldsM fun fields => do
      let fieldMap ← mkFieldMap fields
      fieldMap.toList.mapM fun ⟨fieldName, fields⟩ => do
        match isSimpleField? fields with
        | some field => pure field
        | none =>
          let substructFields := fields.map fun field => { field with lhs := field.lhs.tail! }
          let substructSource ← mkSubstructSource s.structName fieldNames fieldName s.source
          let field := fields.head!
          match Lean.isSubobjectField? env s.structName fieldName with
          | some substructName =>
            let substruct := Struct.mk s.ref substructName substructFields substructSource
            let substruct ← expandStruct substruct
            pure { field with lhs := [field.lhs.head!], val := FieldVal.nested substruct }
          | none => do
            -- It is not a substructure field. Thus, we wrap fields using `Syntax`, and use `elabTerm` to process them.
            let valStx := s.ref -- construct substructure syntax using s.ref as template
            let valStx := valStx.setArg 4 mkNullNode -- erase optional expected type
            let args   := substructFields.toArray.map Field.toSyntax
            let valStx := valStx.setArg 2 (Syntax.mkSep args (mkAtomFrom s.ref ","))
            let valStx := setStructSourceSyntax valStx substructSource
            pure { field with lhs := [field.lhs.head!], val := FieldVal.term valStx }
|
||
|
||
/- First field whose left-hand side is exactly the single name `fieldName`, if any. -/
def findField? (fields : Fields) (fieldName : Name) : Option (Field Struct) :=
  fields.find? fun field =>
    match field.lhs with
    | [FieldLHS.fieldName _ n] => n == fieldName
    | _                        => false
|
||
|
||
/- Produces one field for every field name of `s.structName`, in declaration order.
   Fields already present are kept. A missing field becomes
   - a nested, empty sub-structure (expanded with `expandStruct`) if it is a subobject field;
   - `FieldVal.default` if there is no source;
   - a hole if the source is implicit (`..`);
   - the projection `src.fieldName` if the source is explicit. -/
@[specialize] private def addMissingFields (expandStruct : Struct → TermElabM Struct) (s : Struct) : TermElabM Struct := do
  let env ← getEnv
  let fieldNames := getStructureFields env s.structName
  let ref := s.ref
  withRef ref do
    -- fields are accumulated in reverse order and reversed at the end
    let fields ← fieldNames.foldlM (init := []) fun fields fieldName => do
      match findField? s.fields fieldName with
      | some field => pure $ field::fields
      | none       =>
        let addField (val : FieldVal Struct) : TermElabM Fields := do
          pure $ { ref := s.ref, lhs := [FieldLHS.fieldName s.ref fieldName], val := val } :: fields
        match Lean.isSubobjectField? env s.structName fieldName with
        | some substructName => do
          let substructSource ← mkSubstructSource s.structName fieldNames fieldName s.source
          let substruct := Struct.mk s.ref substructName [] substructSource
          let substruct ← expandStruct substruct
          addField (FieldVal.nested substruct)
        | none =>
          match s.source with
          | Source.none           => addField FieldVal.default
          | Source.implicit _     => addField (FieldVal.term (mkHole s.ref))
          | Source.explicit stx _ =>
            -- stx is of the form `optional (try (termParser >> "with"))`
            let src := stx[0]
            let val := mkProjStx src fieldName
            addField (FieldVal.term val)
    pure $ s.setFields fields.reverse
|
||
|
||
/- Normalizes a structure view by running, in order: composite-name expansion, numeric-index
   expansion, parent-field expansion, grouping by field name, and insertion of missing fields.
   The last two steps call back into `expandStruct` for nested structures. -/
private partial def expandStruct (s : Struct) : TermElabM Struct := do
  let s := expandCompositeFields s
  let s ← expandNumLitFields s
  let s ← expandParentFields s
  let s ← groupFields expandStruct s
  addMissingFields expandStruct s
|
||
|
||
/- Result of `mkCtorHeader`: the constructor applied to fresh metavariables for its parameters,
   the type of that application, and the metavariables created for instance-implicit parameters. -/
structure CtorHeaderResult :=
  (ctorFn     : Expr)
  (ctorFnType : Expr)
  (instMVars  : Array MVarId := #[])
|
||
|
||
/- Applies `ctorFn` to `n` fresh metavariables, one per leading binder of `type`.
   Instance-implicit binders get a synthetic metavariable that is also recorded in `instMVars`.
   Throws if `type` runs out of binders before `n` arguments have been supplied. -/
private def mkCtorHeaderAux : Nat → Expr → Expr → Array MVarId → TermElabM CtorHeaderResult
  | 0,   type, ctorFn, instMVars => pure { ctorFn := ctorFn, ctorFnType := type, instMVars := instMVars }
  | n+1, type, ctorFn, instMVars => do
    let type ← whnfForall type
    match type with
    | Expr.forallE _ d b c =>
      match c.binderInfo with
      | BinderInfo.instImplicit =>
        let a ← mkFreshExprMVar d MetavarKind.synthetic
        mkCtorHeaderAux n (b.instantiate1 a) (mkApp ctorFn a) (instMVars.push a.mvarId!)
      | _ =>
        let a ← mkFreshExprMVar d
        mkCtorHeaderAux n (b.instantiate1 a) (mkApp ctorFn a) instMVars
    | _ => throwError "unexpected constructor type"
|
||
|
||
/- Strips the given number of leading `forallE` binders (without instantiating them) and returns
   the remaining body, or `none` if a non-`forallE` expression is reached too early. -/
private partial def getForallBody : Nat → Expr → Option Expr
  | i+1, Expr.forallE _ _ b _ => getForallBody i b
  | i+1, _                    => none
  | 0,   type                 => type
|
||
|
||
/- Unifies the expected type (if any) with the result type of the constructor, obtained by
   skipping `numFields` binders of `type`. Nothing is done when the result type still mentions
   one of the skipped binders. The outcome of `isDefEq` is ignored. -/
private def propagateExpectedType (type : Expr) (numFields : Nat) (expectedType? : Option Expr) : TermElabM Unit :=
  match expectedType? with
  | none              => pure ()
  | some expectedType => do
    match getForallBody numFields type with
    | none          => pure ()
    | some typeBody =>
      unless typeBody.hasLooseBVars do
        isDefEq expectedType typeBody
        pure ()
|
||
|
||
/- Instantiates the structure constructor with fresh universe-level and parameter metavariables,
   propagates the expected type into them, and then synthesizes the instance metavariables. -/
private def mkCtorHeader (ctorVal : ConstructorVal) (expectedType? : Option Expr) : TermElabM CtorHeaderResult := do
  let lvls ← ctorVal.lparams.mapM fun _ => mkFreshLevelMVar
  let val  := Lean.mkConst ctorVal.name lvls
  let type := (ConstantInfo.ctorInfo ctorVal).instantiateTypeLevelParams lvls
  let r ← mkCtorHeaderAux ctorVal.nparams type val #[]
  propagateExpectedType r.ctorFnType ctorVal.nfields expectedType?
  synthesizeAppInstMVars r.instMVars
  pure r
|
||
|
||
/- Wraps `e` in the `structInstDefault` annotation, marking it as a value still to be filled
   in from a default. -/
def markDefaultMissing (e : Expr) : Expr :=
  mkAnnotation `structInstDefault e
|
||
|
||
/- If `e` carries the `structInstDefault` annotation, returns the annotated expression. -/
def defaultMissing? (e : Expr) : Option Expr :=
  annotation? `structInstDefault e
|
||
|
||
/- Throws a "failed to elaborate field" error for `fieldName` of `structName`, followed by `msgData`. -/
def throwFailedToElabField {α} (fieldName : Name) (structName : Name) (msgData : MessageData) : TermElabM α :=
  throwError! "failed to elaborate field '{fieldName}' of '{structName}', {msgData}"
|
||
|
||
/- If every field of `s` is marked as default, tries to synthesize an instance of `expectedType`.
   Returns `none` when some field is not default, or when synthesis throws an exception. -/
def trySynthStructInstance? (s : Struct) (expectedType : Expr) : TermElabM (Option Expr) := do
  if !s.allDefault then
    pure none
  else
    try synthInstance? expectedType catch _ => pure none
|
||
|
||
/- Elaborates an expanded structure view into a constructor application.
   Fields are consumed in order against the binders of the constructor type:
   - `term` values are elaborated against the binder type;
   - `nested` structures are first tried as type-class instances (if all their fields are
     default) and otherwise elaborated recursively;
   - `default` values become fresh metavariables wrapped with `markDefaultMissing`.
   Returns the application together with the view updated with the elaborated values (`expr?`). -/
private partial def elabStruct (s : Struct) (expectedType? : Option Expr) : TermElabM (Expr × Struct) := withRef s.ref do
  let env ← getEnv
  let ctorVal := getStructureCtor env s.structName
  let { ctorFn := ctorFn, ctorFnType := ctorFnType, .. } ← mkCtorHeader ctorVal expectedType?
  let (e, _, fields) ← s.fields.foldlM (init := (ctorFn, ctorFnType, [])) fun (e, type, fields) field =>
    match field.lhs with
    | [FieldLHS.fieldName ref fieldName] => do
      let type ← whnfForall type
      match type with
      | Expr.forallE _ d b c =>
        let cont (val : Expr) (field : Field Struct) : TermElabM (Expr × Expr × Fields) := do
          let e     := mkApp e val
          let type  := b.instantiate1 val
          let field := { field with expr? := some val }
          pure (e, type, field::fields)
        match field.val with
        | FieldVal.term stx => cont (← elabTermEnsuringType stx d) field
        | FieldVal.nested s => do
          -- if all fields of `s` are marked as `default`, then try to synthesize instance
          match (← trySynthStructInstance? s d) with
          | some val => cont val { field with val := FieldVal.term (mkHole field.ref) }
          | none     => do let (val, sNew) ← elabStruct s (some d); let val ← ensureHasType d val; cont val { field with val := FieldVal.nested sNew }
        | FieldVal.default  => do let val ← withRef field.ref $ mkFreshExprMVar (some d); cont (markDefaultMissing val) field
      | _ => withRef field.ref $ throwFailedToElabField fieldName s.structName m!"unexpected constructor type{indentExpr type}"
    | _ => throwErrorAt field.ref "unexpected unexpanded structure field"
  pure (e, s.setFields fields.reverse)
|
||
|
||
namespace DefaultFields
|
||
|
||
/- Read-only state of the default-value propagation. -/
structure Context :=
  -- We must search for default values overridden in derived structures
  (structs : Array Struct := #[])
  (allStructNames : Array Name := #[])
  /-
  Consider the following example:
  ```
  structure A :=
  (x : Nat := 1)

  structure B extends A :=
  (y : Nat := x + 1) (x := y + 1)

  structure C extends B :=
  (z : Nat := 2*y) (x := z + 3)
  ```
  And we are trying to elaborate a structure instance for `C`. There are default values for `x` at `A`, `B`, and `C`.
  We say the default value at `C` has distance 0, the one at `B` distance 1, and the one at `A` distance 2.
  The field `maxDistance` specifies the maximum distance considered in a round of Default field computation.
  Remark: since `C` does not set a default value of `y`, the default value at `B` is at distance 0.

  The fixpoint for setting default values works in the following way.
  - Keep computing default values using `maxDistance == 0`.
  - We increase `maxDistance` whenever we failed to compute a new default value in a round.
  - If `maxDistance > 0`, then we interrupt a round as soon as we compute some default value.
    We use depth-first search.
  - We signal an error if no progress is made when `maxDistance` == structure hierarchy depth (2 in the example above).
  -/
  (maxDistance : Nat := 0)
|
||
|
||
/- Mutable state of the default-value propagation: `progress` is set to `true` by `step`
   when a default value is assigned. -/
structure State :=
  (progress : Bool := false)
|
||
|
||
/- Appends to `names` the name of `struct` followed by the names of all nested structures
   (depth-first, in field order). -/
partial def collectStructNames (struct : Struct) (names : Array Name) : Array Name :=
  let names := names.push struct.structName
  struct.fields.foldl (init := names) fun names field =>
    match field.val with
    | FieldVal.nested struct => collectStructNames struct names
    | _                      => names
|
||
|
||
/- Maximum nesting depth of `FieldVal.nested` values in `struct`; 0 when there are none. -/
partial def getHierarchyDepth (struct : Struct) : Nat :=
  struct.fields.foldl (init := 0) fun max field =>
    match field.val with
    | FieldVal.nested struct => Nat.max max (getHierarchyDepth struct + 1)
    | _                      => max
|
||
|
||
/- Finds (depth-first) a field whose elaborated value is a `structInstDefault`-annotated
   metavariable that is not yet assigned in `mctx`. Non-nested fields are expected to have
   their `expr?` set. -/
partial def findDefaultMissing? (mctx : MetavarContext) (struct : Struct) : Option (Field Struct) :=
  struct.fields.findSome? fun field =>
    match field.val with
    | FieldVal.nested struct => findDefaultMissing? mctx struct
    | _ => match field.expr? with
      | none      => unreachable!
      | some expr => match defaultMissing? expr with
        | some (Expr.mvar mvarId _) => if mctx.isExprAssigned mvarId then none else some field
        | _                         => none
|
||
|
||
/- Name of a field whose left-hand side is a single `fieldName`; any other shape is
   treated as unreachable. -/
def getFieldName (field : Field Struct) : Name :=
  match field.lhs with
  | [FieldLHS.fieldName _ fieldName] => fieldName
  | _ => unreachable!
|
||
|
||
/- Monad used by the default-value propagation: `Context` reader plus `State` on top of `TermElabM`. -/
abbrev M := ReaderT Context (StateRefT State TermElabM)

/- A round is done once progress has been made, but only when `maxDistance > 0`. -/
def isRoundDone : M Bool := do
  return (← get).progress && (← read).maxDistance > 0
|
||
|
||
/- Elaborated value (`expr?`) of the first field named `fieldName` that has one. -/
def getFieldValue? (struct : Struct) (fieldName : Name) : Option Expr :=
  struct.fields.findSome? fun field =>
    if getFieldName field == fieldName then
      field.expr?
    else
      none
|
||
|
||
/- Instantiates the leading lambdas of a default-value expression.
   An explicit binder is treated as a field: its name is looked up in `struct` and the binder is
   instantiated with that field's value, provided the types are definitionally equal (otherwise
   the result is `none`). Any other binder is instantiated with a fresh metavariable.
   When no lambda is left, an outer `id` application with two arguments is stripped. -/
partial def mkDefaultValueAux? (struct : Struct) : Expr → TermElabM (Option Expr)
  | Expr.lam n d b c => withRef struct.ref do
    if c.binderInfo.isExplicit then
      let fieldName := n
      match getFieldValue? struct fieldName with
      | none     => pure none
      | some val =>
        let valType ← inferType val
        if (← isDefEq valType d) then
          mkDefaultValueAux? struct (b.instantiate1 val)
        else
          pure none
    else
      let arg ← mkFreshExprMVar d
      mkDefaultValueAux? struct (b.instantiate1 arg)
  | e =>
    if e.isAppOfArity `id 2 then
      pure (some e.appArg!)
    else
      pure (some e)
|
||
|
||
/- Instantiates the universe parameters of the default-value constant `cinfo` with fresh level
   metavariables and then applies `mkDefaultValueAux?` to its value. -/
def mkDefaultValue? (struct : Struct) (cinfo : ConstantInfo) : TermElabM (Option Expr) :=
  withRef struct.ref do
    let us ← cinfo.lparams.mapM fun _ => mkFreshLevelMVar
    mkDefaultValueAux? struct (cinfo.instantiateValueLevelParams us)
|
||
|
||
/-- If `e` is a projection function of one of the given structures, then reduce it.
    Returns `none` when `e` is not an application, its head is not a constant, the head is not a
    projection function, or the projected structure is not in `structNames`. -/
def reduceProjOf? (structNames : Array Name) (e : Expr) : MetaM (Option Expr) := do
  if !e.isApp then pure none
  else match e.getAppFn with
    | Expr.const name _ _ => do
      let env ← getEnv
      match env.getProjectionStructureName? name with
      | some structName =>
        if structNames.contains structName then
          Meta.unfoldDefinition? e
        else
          pure none
      | none => pure none
    | _ => pure none
|
||
|
||
/-- Reduce default value. It performs beta reduction and projections of the given structures.
    It also replaces assigned metavariables by their values, and drops a `structInstDefault`
    annotation once its body is no longer a metavariable. -/
partial def reduce (structNames : Array Name) : Expr → MetaM Expr
  | e@(Expr.lam _ _ _ _)     => lambdaLetTelescope e fun xs b => do mkLambdaFVars xs (← reduce structNames b)
  | e@(Expr.forallE _ _ _ _) => forallTelescope e fun xs b => do mkForallFVars xs (← reduce structNames b)
  | e@(Expr.letE _ _ _ _ _)  => lambdaLetTelescope e fun xs b => do mkLetFVars xs (← reduce structNames b)
  | e@(Expr.proj _ i b _)    => do
    match (← Meta.reduceProj? b i) with
    | some r => reduce structNames r
    | none   => pure $ e.updateProj! (← reduce structNames b)
  | e@(Expr.app f _ _) => do
    match (← reduceProjOf? structNames e) with
    | some r => reduce structNames r
    | none   =>
      let f := f.getAppFn
      let f' ← reduce structNames f
      if f'.isLambda then
        -- the head became a lambda: beta-reduce with all arguments and keep reducing
        let revArgs := e.getAppRevArgs
        reduce structNames (f'.betaRev revArgs)
      else
        let args ← e.getAppArgs.mapM (reduce structNames)
        pure (mkAppN f' args)
  | e@(Expr.mdata _ b _) => do
    let b ← reduce structNames b
    if (defaultMissing? e).isSome && !b.isMVar then
      pure b
    else
      pure $ e.updateMData! b
  | e@(Expr.mvar mvarId _) => do
    match (← getExprMVarAssignment? mvarId) with
    | some val => if val.isMVar then reduce structNames val else pure val
    | none     => pure e
  | e => pure e
|
||
|
||
/- Tries to assign `mvarId` using a default value for `fieldName`.
   The structures in `structs` are visited in order, looking for a definition named
   `<structName>.<fieldName>._default`. A candidate is rejected (restoring the metavariable
   context) if it cannot be instantiated or if, after `reduce`, it still contains a
   `structInstDefault`-annotated subterm. Every rejected candidate increases the distance, and the
   search stops once the distance exceeds `maxDistance`.
   Returns `true` iff `mvarId` was assigned. -/
partial def tryToSynthesizeDefault (structs : Array Struct) (allStructNames : Array Name) (maxDistance : Nat) (fieldName : Name) (mvarId : MVarId) : TermElabM Bool :=
  let rec loop (i : Nat) (dist : Nat) := do
    if dist > maxDistance then
      pure false
    else if h : i < structs.size then do
      let struct := structs.get ⟨i, h⟩
      let defaultName := struct.structName ++ fieldName ++ `_default
      let env ← getEnv
      match env.find? defaultName with
      | some cinfo@(ConstantInfo.defnInfo defVal) => do
        let mctx ← getMCtx
        let val? ← mkDefaultValue? struct cinfo
        match val? with
        | none     => do setMCtx mctx; loop (i+1) (dist+1)
        | some val => do
          let val ← liftMetaM $ reduce allStructNames val
          match val.find? fun e => (defaultMissing? e).isSome with
          | some _ => setMCtx mctx; loop (i+1) (dist+1)
          | none   =>
            let mvarDecl ← getMVarDecl mvarId
            let val ← ensureHasType mvarDecl.type val
            assignExprMVar mvarId val
            pure true
      | _ => loop (i+1) dist
    else
      pure false
  loop 0 0
|
||
|
||
/- One propagation round over `struct` (depth-first). `struct` is pushed onto `ctx.structs` while
   its fields are visited. For every unassigned `structInstDefault` metavariable,
   `tryToSynthesizeDefault` is attempted, and `progress` is set on success.
   Nothing is done if the round is already done (see `isRoundDone`) when `step` is entered. -/
partial def step (struct : Struct) : M Unit :=
  unless (← isRoundDone) do
    withReader (fun ctx => { ctx with structs := ctx.structs.push struct }) do
      for field in struct.fields do
        match field.val with
        | FieldVal.nested struct => step struct
        | _ => match field.expr? with
          | none      => unreachable!
          | some expr => match defaultMissing? expr with
            | some (Expr.mvar mvarId _) =>
              unless (← isExprMVarAssigned mvarId) do
                let ctx ← read
                if (← withRef field.ref $ tryToSynthesizeDefault ctx.structs ctx.allStructNames ctx.maxDistance (getFieldName field) mvarId) then
                  modify fun s => { s with progress := true }
            | _ => pure ()
|
||
|
||
/- Fixpoint loop for default values. While some default is still missing, runs `step` with
   `maxDistance := d`; on progress it restarts with distance 0, otherwise it retries with `d+1`.
   Throws "field ... is missing" once `d` exceeds `hierarchyDepth`. -/
partial def propagateLoop (hierarchyDepth : Nat) (d : Nat) (struct : Struct) : M Unit := do
  match findDefaultMissing? (← getMCtx) struct with
  | none       => pure () -- Done
  | some field =>
    if d > hierarchyDepth then
      throwErrorAt! field.ref "field '{getFieldName field}' is missing"
    else withReader (fun ctx => { ctx with maxDistance := d }) do
      modify fun s => { s with progress := false }
      step struct
      if (← get).progress then do
        propagateLoop hierarchyDepth 0 struct
      else
        propagateLoop hierarchyDepth (d+1) struct
|
||
|
||
/- Entry point of default-value propagation: computes the hierarchy depth and the names of all
   (nested) structures, then runs `propagateLoop` from distance 0 with a fresh `State`. -/
def propagate (struct : Struct) : TermElabM Unit :=
  let hierarchyDepth := getHierarchyDepth struct
  let structNames := collectStructNames struct #[]
  (propagateLoop hierarchyDepth 0 struct { allStructNames := structNames }).run' {}
|
||
|
||
end DefaultFields
|
||
|
||
/- Main elaboration path for structure instances that are not `modifyOp`s: determines the
   structure name, checks that it is structure-like, builds and expands the view, elaborates it,
   and finally propagates default values. -/
private def elabStructInstAux (stx : Syntax) (expectedType? : Option Expr) (source : Source) : TermElabM Expr := do
  let structName ← getStructName stx expectedType? source
  unless isStructureLike (← getEnv) structName do
    throwError! "invalid \{...} notation, '{structName}' is not a structure"
  match mkStructView stx structName source with
  | Except.error ex  => throwError ex
  | Except.ok struct =>
    let struct ← expandStruct struct
    trace[Elab.struct]! "{struct}"
    let (r, struct) ← elabStruct struct expectedType?
    DefaultFields.propagate struct
    pure r
|
||
|
||
/- Term elaborator for structure instances. A non-atomic explicit source is first let-bound
   (see `expandNonAtomicExplicitSource`). A `modifyOp` instance requires an explicit source and is
   handled by `elabModifyOp`; everything else goes through `elabStructInstAux`. -/
@[builtinTermElab structInst] def elabStructInst : TermElab := fun stx expectedType? => do
  match (← expandNonAtomicExplicitSource stx) with
  | some stxNew => withMacroExpansion stx stxNew $ elabTerm stxNew expectedType?
  | none =>
    let sourceView ← getStructSource stx
    match (← isModifyOp? stx), sourceView with
    | some modifyOp, Source.explicit source _ => elabModifyOp stx modifyOp source expectedType?
    | some _,        _                        => throwError "invalid {...} notation, explicit source is required when using '[<index>] := <value>'"
    | _,             _                        => elabStructInstAux stx expectedType? sourceView
|
||
|
||
-- Registers the `Elab.struct` trace class at initialization time; this file emits messages
-- under `Elab.struct` and `Elab.struct.modifyOp`.
builtin_initialize registerTraceClass `Elab.struct
|
||
|
||
end Lean.Elab.Term.StructInst
|