chore: move to new frontend
This commit is contained in:
parent
34cddb334e
commit
e02a06ad1c
3 changed files with 294 additions and 326 deletions
|
|
@ -108,14 +108,14 @@ def filterMapM {m : Type u → Type v} [Monad m] {α β : Type u} (f : α → m
|
|||
filterMapMAux f as.reverse []
|
||||
|
||||
@[specialize]
|
||||
def foldlM {m : Type u → Type v} [Monad m] {s : Type u} {α : Type w} : (s → α → m s) → s → List α → m s
|
||||
def foldlM {m : Type u → Type v} [Monad m] {s : Type u} {α : Type w} : forall (f : s → α → m s) (init : s), List α → m s
|
||||
| f, s, [] => pure s
|
||||
| f, s, h :: r => do
|
||||
s' ← f s h;
|
||||
foldlM f s' r
|
||||
|
||||
@[specialize]
|
||||
def foldrM {m : Type u → Type v} [Monad m] {s : Type u} {α : Type w} : (α → s → m s) → s → List α → m s
|
||||
def foldrM {m : Type u → Type v} [Monad m] {s : Type u} {α : Type w} : forall (f : α → s → m s) (init : s), List α → m s
|
||||
| f, s, [] => pure s
|
||||
| f, s, h :: r => do
|
||||
s' ← foldrM f s r;
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
#lang lean4
|
||||
/-
|
||||
Copyright (c) 2020 Microsoft Corporation. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
|
|
@ -8,10 +9,7 @@ import Lean.Elab.App
|
|||
import Lean.Elab.Binders
|
||||
import Lean.Elab.Quotation
|
||||
|
||||
namespace Lean
|
||||
namespace Elab
|
||||
namespace Term
|
||||
namespace StructInst
|
||||
namespace Lean.Elab.Term.StructInst
|
||||
|
||||
open Std (HashMap)
|
||||
open Meta
|
||||
|
|
@ -20,12 +18,12 @@ open Meta
|
|||
|
||||
@[builtinMacro Lean.Parser.Term.structInst] def expandStructInstExpectedType : Macro :=
|
||||
fun stx =>
|
||||
let expectedArg := stx.getArg 4;
|
||||
let expectedArg := stx[4]
|
||||
if expectedArg.isNone
|
||||
then Macro.throwUnsupported
|
||||
else
|
||||
let expected := expectedArg.getArg 1;
|
||||
let stxNew := stx.setArg 4 mkNullNode;
|
||||
let expected := expectedArg[1]
|
||||
let stxNew := stx.setArg 4 mkNullNode
|
||||
`(($stxNew : $expected))
|
||||
|
||||
/-
|
||||
|
|
@ -34,17 +32,18 @@ If `stx` is of the form `{ s with ... }` and `s` is not a local variable, expand
|
|||
Note that this one is not a `Macro` because we need to access the local context.
|
||||
-/
|
||||
private def expandNonAtomicExplicitSource (stx : Syntax) : TermElabM (Option Syntax) :=
|
||||
withFreshMacroScope $
|
||||
let sourceOpt := stx.getArg 1;
|
||||
if sourceOpt.isNone then pure none else do
|
||||
let source := sourceOpt.getArg 0;
|
||||
fvar? ← isLocalIdent? source;
|
||||
match fvar? with
|
||||
withFreshMacroScope do
|
||||
let sourceOpt := stx[1]
|
||||
if sourceOpt.isNone then
|
||||
pure none
|
||||
else
|
||||
let source := sourceOpt[0]
|
||||
match (← isLocalIdent? source) with
|
||||
| some _ => pure none
|
||||
| none => do
|
||||
src ← `(src);
|
||||
let sourceOpt := sourceOpt.setArg 0 src;
|
||||
let stxNew := stx.setArg 1 sourceOpt;
|
||||
| none =>
|
||||
let src ← `(src)
|
||||
let sourceOpt := sourceOpt.setArg 0 src
|
||||
let stxNew := stx.setArg 1 sourceOpt
|
||||
`(let src := $source; $stxNew)
|
||||
|
||||
inductive Source
|
||||
|
|
@ -64,15 +63,15 @@ def setStructSourceSyntax (structStx : Syntax) : Source → Syntax
|
|||
| Source.explicit stx _ => (structStx.setArg 1 stx).setArg 3 mkNullNode
|
||||
|
||||
private def getStructSource (stx : Syntax) : TermElabM Source :=
|
||||
withRef stx $
|
||||
let explicitSource := stx.getArg 1;
|
||||
let implicitSource := stx.getArg 3;
|
||||
withRef stx do
|
||||
let explicitSource := stx[1]
|
||||
let implicitSource := stx[3]
|
||||
if explicitSource.isNone && implicitSource.isNone then
|
||||
pure Source.none
|
||||
else if explicitSource.isNone then
|
||||
pure $ Source.implicit implicitSource
|
||||
else if implicitSource.isNone then do
|
||||
fvar? ← isLocalIdent? (explicitSource.getArg 0);
|
||||
else if implicitSource.isNone then
|
||||
let fvar? ← isLocalIdent? explicitSource[0]
|
||||
match fvar? with
|
||||
| none => unreachable! -- expandNonAtomicExplicitSource must have been used when we get here
|
||||
| some src => pure $ Source.explicit explicitSource src
|
||||
|
|
@ -85,16 +84,16 @@ else
|
|||
def structInstArrayRef := parser! "[" >> termParser >>"]"
|
||||
``` -/
|
||||
private def isModifyOp? (stx : Syntax) : TermElabM (Option Syntax) := do
|
||||
let args := (stx.getArg 2).getArgs;
|
||||
s? ← args.foldSepByM
|
||||
let args := stx[2].getArgs
|
||||
let s? ← args.foldSepByM
|
||||
(fun arg s? =>
|
||||
/- Remark: the syntax for `structInstField` is
|
||||
```
|
||||
def structInstLVal := (ident <|> numLit <|> structInstArrayRef) >> many (group ("." >> (ident <|> numLit)) <|> structInstArrayRef)
|
||||
def structInstField := parser! structInstLVal >> " := " >> termParser
|
||||
``` -/
|
||||
let lval := arg.getArg 0;
|
||||
let k := lval.getKind;
|
||||
let lval := arg[0]
|
||||
let k := lval.getKind
|
||||
if k == `Lean.Parser.Term.structInstArrayRef then
|
||||
match s? with
|
||||
| none => pure (some arg)
|
||||
|
|
@ -111,56 +110,55 @@ s? ← args.foldSepByM
|
|||
throwErrorAt arg "invalid {...} notation, can't mix field and `[..]` at a given level"
|
||||
else
|
||||
pure s?)
|
||||
none;
|
||||
none
|
||||
match s? with
|
||||
| none => pure none
|
||||
| some s => if (s.getArg 0).getKind == `Lean.Parser.Term.structInstArrayRef then pure s? else pure none
|
||||
| some s => if s[0].getKind == `Lean.Parser.Term.structInstArrayRef then pure s? else pure none
|
||||
|
||||
private def elabModifyOp (stx modifyOp source : Syntax) (expectedType? : Option Expr) : TermElabM Expr :=
|
||||
let continue (val : Syntax) : TermElabM Expr := do {
|
||||
let lval := modifyOp.getArg 0;
|
||||
let idx := lval.getArg 1;
|
||||
let self := source.getArg 0;
|
||||
stxNew ← `($(self).modifyOp (idx := $idx) (fun s => $val));
|
||||
trace `Elab.struct.modifyOp fun _ => stx ++ "\n===>\n" ++ stxNew;
|
||||
private def elabModifyOp (stx modifyOp source : Syntax) (expectedType? : Option Expr) : TermElabM Expr := do
|
||||
let cont (val : Syntax) : TermElabM Expr := do
|
||||
let lval := modifyOp[0]
|
||||
let idx := lval[1]
|
||||
let self := source[0]
|
||||
let stxNew ← `($(self).modifyOp (idx := $idx) (fun s => $val))
|
||||
trace[Elab.struct.modifyOp]! "{stx}\n===>\n{stxNew}"
|
||||
withMacroExpansion stx stxNew $ elabTerm stxNew expectedType?
|
||||
}; do
|
||||
trace `Elab.struct.modifyOp fun _ => modifyOp ++ "\nSource: " ++ source;
|
||||
let rest := modifyOp.getArg 1;
|
||||
if rest.isNone then do
|
||||
continue (modifyOp.getArg 3)
|
||||
else do
|
||||
s ← `(s);
|
||||
let valFirst := rest.getArg 0;
|
||||
let valFirst := if valFirst.getKind == `Lean.Parser.Term.structInstArrayRef then valFirst else valFirst.getArg 1;
|
||||
let restArgs := rest.getArgs;
|
||||
let valRest := mkNullNode (restArgs.extract 1 restArgs.size);
|
||||
let valField := modifyOp.setArg 0 valFirst;
|
||||
let valField := valField.setArg 1 valRest;
|
||||
let valSource := source.modifyArg 0 $ fun _ => s;
|
||||
let val := stx.setArg 1 valSource;
|
||||
let val := val.setArg 2 $ mkNullNode #[valField];
|
||||
trace `Elab.struct.modifyOp fun _ => stx ++ "\nval: " ++ val;
|
||||
continue val
|
||||
trace[Elab.struct.modifyOp]! "{modifyOp}\nSource: {source}"
|
||||
let rest := modifyOp[1]
|
||||
if rest.isNone then
|
||||
cont modifyOp[3]
|
||||
else
|
||||
let s ← `(s)
|
||||
let valFirst := rest[0]
|
||||
let valFirst := if valFirst.getKind == `Lean.Parser.Term.structInstArrayRef then valFirst else valFirst[1]
|
||||
let restArgs := rest.getArgs
|
||||
let valRest := mkNullNode restArgs[1:restArgs.size]
|
||||
let valField := modifyOp.setArg 0 valFirst
|
||||
let valField := valField.setArg 1 valRest
|
||||
let valSource := source.modifyArg 0 fun _ => s
|
||||
let val := stx.setArg 1 valSource
|
||||
let val := val.setArg 2 $ mkNullNode #[valField]
|
||||
trace[Elab.struct.modifyOp]! "{stx}\nval: {val}"
|
||||
cont val
|
||||
|
||||
/- Get structure name and elaborate explicit source (if available) -/
|
||||
private def getStructName (stx : Syntax) (expectedType? : Option Expr) (sourceView : Source) : TermElabM Name := do
|
||||
tryPostponeIfNoneOrMVar expectedType?;
|
||||
tryPostponeIfNoneOrMVar expectedType?
|
||||
let useSource : Unit → TermElabM Name := fun _ =>
|
||||
match sourceView, expectedType? with
|
||||
| Source.explicit _ src, _ => do
|
||||
srcType ← inferType src;
|
||||
srcType ← whnf srcType;
|
||||
tryPostponeIfMVar srcType;
|
||||
let srcType ← inferType src
|
||||
let srcType ← whnf srcType
|
||||
tryPostponeIfMVar srcType
|
||||
match srcType.getAppFn with
|
||||
| Expr.const constName _ _ => pure constName
|
||||
| _ => throwError ("invalid {...} notation, source type is not of the form (C ...)" ++ indentExpr srcType)
|
||||
| _, some expectedType => throwError ("invalid {...} notation, expected type is not of the form (C ...)" ++ indentExpr expectedType)
|
||||
| _, none => throwError ("invalid {...} notation, expected type must be known");
|
||||
| _ => throwError! "invalid \{...} notation, source type is not of the form (C ...){indentExpr srcType}"
|
||||
| _, some expectedType => throwError! "invalid \{...} notation, expected type is not of the form (C ...){indentExpr expectedType}"
|
||||
| _, none => throwError! "invalid \{...} notation, expected type must be known"
|
||||
match expectedType? with
|
||||
| none => useSource ()
|
||||
| some expectedType => do
|
||||
expectedType ← whnf expectedType;
|
||||
| some expectedType =>
|
||||
let expectedType ← whnf expectedType
|
||||
match expectedType.getAppFn with
|
||||
| Expr.const constName _ _ => pure constName
|
||||
| _ => useSource ()
|
||||
|
|
@ -178,9 +176,9 @@ instance FieldLHS.hasFormat : HasFormat FieldLHS :=
|
|||
| FieldLHS.modifyOp _ i => "[" ++ i.prettyPrint ++ "]"⟩
|
||||
|
||||
inductive FieldVal (σ : Type)
|
||||
| term (stx : Syntax) : FieldVal
|
||||
| nested (s : σ) : FieldVal
|
||||
| default : FieldVal -- mark that field must be synthesized using default value
|
||||
| term (stx : Syntax) : FieldVal σ
|
||||
| nested (s : σ) : FieldVal σ
|
||||
| default : FieldVal σ -- mark that field must be synthesized using default value
|
||||
|
||||
structure Field (σ : Type) :=
|
||||
(ref : Syntax) (lhs : List FieldLHS) (val : FieldVal σ) (expr? : Option Expr := none)
|
||||
|
|
@ -203,7 +201,7 @@ partial def Struct.allDefault : Struct → Bool
|
|||
| ⟨_, _, fields, _⟩ => fields.all fun ⟨_, _, val, _⟩ => match val with
|
||||
| FieldVal.term _ => false
|
||||
| FieldVal.default => true
|
||||
| FieldVal.nested s => Struct.allDefault s
|
||||
| FieldVal.nested s => allDefault s
|
||||
|
||||
def Struct.ref : Struct → Syntax
|
||||
| ⟨ref, _, _, _⟩ => ref
|
||||
|
|
@ -226,7 +224,7 @@ Format.joinSep field.lhs " . " ++ " := " ++
|
|||
|
||||
partial def formatStruct : Struct → Format
|
||||
| ⟨_, structName, fields, source⟩ =>
|
||||
let fieldsFmt := Format.joinSep (fields.map (formatField formatStruct)) ", ";
|
||||
let fieldsFmt := Format.joinSep (fields.map (formatField formatStruct)) ", "
|
||||
match source with
|
||||
| Source.none => "{" ++ fieldsFmt ++ "}"
|
||||
| Source.implicit _ => "{" ++ fieldsFmt ++ " .. }"
|
||||
|
|
@ -258,35 +256,34 @@ def FieldVal.toSyntax : FieldVal Struct → Syntax
|
|||
|
||||
def Field.toSyntax : Field Struct → Syntax
|
||||
| field =>
|
||||
let stx := field.ref;
|
||||
let stx := stx.setArg 3 field.val.toSyntax;
|
||||
let stx := field.ref
|
||||
let stx := stx.setArg 3 field.val.toSyntax
|
||||
match field.lhs with
|
||||
| first::rest =>
|
||||
let stx := stx.setArg 0 $ first.toSyntax true;
|
||||
let stx := stx.setArg 1 $ mkNullNode $ rest.toArray.map (FieldLHS.toSyntax false);
|
||||
let stx := stx.setArg 0 $ first.toSyntax true
|
||||
let stx := stx.setArg 1 $ mkNullNode $ rest.toArray.map (FieldLHS.toSyntax false)
|
||||
stx
|
||||
| _ => unreachable!
|
||||
|
||||
private def toFieldLHS (stx : Syntax) : Except String FieldLHS :=
|
||||
if stx.getKind == `Lean.Parser.Term.structInstArrayRef then
|
||||
pure $ FieldLHS.modifyOp stx (stx.getArg 1)
|
||||
pure $ FieldLHS.modifyOp stx stx[1]
|
||||
else
|
||||
-- Note that the representation of the first field is different.
|
||||
let stx := if stx.getKind == nullKind then stx.getArg 1 else stx;
|
||||
let stx := if stx.getKind == nullKind then stx[1] else stx
|
||||
if stx.isIdent then pure $ FieldLHS.fieldName stx stx.getId.eraseMacroScopes
|
||||
else match stx.isFieldIdx? with
|
||||
| some idx => pure $ FieldLHS.fieldIndex stx idx
|
||||
| none => throw "unexpected structure syntax"
|
||||
|
||||
private def mkStructView (stx : Syntax) (structName : Name) (source : Source) : Except String Struct := do
|
||||
let args := (stx.getArg 2).getArgs;
|
||||
let fieldsStx := args.filter $ fun arg => arg.getKind == `Lean.Parser.Term.structInstField;
|
||||
fields ← fieldsStx.toList.mapM $ fun fieldStx => do {
|
||||
let val := fieldStx.getArg 3;
|
||||
first ← toFieldLHS (fieldStx.getArg 0);
|
||||
rest ← (fieldStx.getArg 1).getArgs.toList.mapM toFieldLHS;
|
||||
let args := stx[2].getArgs
|
||||
let fieldsStx := args.filter $ fun arg => arg.getKind == `Lean.Parser.Term.structInstField
|
||||
let fields ← fieldsStx.toList.mapM fun fieldStx => do
|
||||
let val := fieldStx[3]
|
||||
let first ← toFieldLHS fieldStx[0]
|
||||
let rest ← fieldStx[1].getArgs.toList.mapM toFieldLHS
|
||||
pure $ ({ref := fieldStx, lhs := first :: rest, val := FieldVal.term val } : Field Struct)
|
||||
};
|
||||
pure ⟨stx, structName, fields, source⟩
|
||||
|
||||
def Struct.modifyFieldsM {m : Type → Type} [Monad m] (s : Struct) (f : Fields → m Fields) : m Struct :=
|
||||
|
|
@ -297,25 +294,25 @@ match s with
|
|||
Id.run $ s.modifyFieldsM f
|
||||
|
||||
def Struct.setFields (s : Struct) (fields : Fields) : Struct :=
|
||||
s.modifyFields $ fun _ => fields
|
||||
s.modifyFields fun _ => fields
|
||||
|
||||
private def expandCompositeFields (s : Struct) : Struct :=
|
||||
s.modifyFields $ fun fields => fields.map $ fun field => match field with
|
||||
| { lhs := FieldLHS.fieldName ref (Name.str Name.anonymous _ _) :: rest, .. } => field
|
||||
| { lhs := FieldLHS.fieldName ref n@(Name.str _ _ _) :: rest, .. } =>
|
||||
let newEntries := n.components.map $ FieldLHS.fieldName ref;
|
||||
let newEntries := n.components.map $ FieldLHS.fieldName ref
|
||||
{ field with lhs := newEntries ++ rest }
|
||||
| _ => field
|
||||
|
||||
private def expandNumLitFields (s : Struct) : TermElabM Struct :=
|
||||
s.modifyFieldsM $ fun fields => do
|
||||
env ← getEnv;
|
||||
let fieldNames := getStructureFields env s.structName;
|
||||
fields.mapM $ fun field => match field with
|
||||
s.modifyFieldsM fun fields => do
|
||||
let env ← getEnv
|
||||
let fieldNames := getStructureFields env s.structName
|
||||
fields.mapM fun field => match field with
|
||||
| { lhs := FieldLHS.fieldIndex ref idx :: rest, .. } =>
|
||||
if idx == 0 then throwErrorAt ref "invalid field index, index must be greater than 0"
|
||||
else if idx > fieldNames.size then throwErrorAt ref ("invalid field index, structure has only #" ++ toString fieldNames.size ++ " fields")
|
||||
else pure { field with lhs := FieldLHS.fieldName ref (fieldNames.get! $ idx - 1) :: rest }
|
||||
else if idx > fieldNames.size then throwErrorAt! ref "invalid field index, structure has only #{fieldNames.size} fields"
|
||||
else pure { field with lhs := FieldLHS.fieldName ref fieldNames[idx - 1] :: rest }
|
||||
| _ => pure field
|
||||
|
||||
/- For example, consider the following structures:
|
||||
|
|
@ -334,38 +331,36 @@ s.modifyFieldsM $ fun fields => do
|
|||
{ toB.toA.x := 0, toB.y := 0, z := true : C }
|
||||
``` -/
|
||||
private def expandParentFields (s : Struct) : TermElabM Struct := do
|
||||
env ← getEnv;
|
||||
s.modifyFieldsM $ fun fields => fields.mapM $ fun field => match field with
|
||||
let env ← getEnv
|
||||
s.modifyFieldsM fun fields => fields.mapM fun field => match field with
|
||||
| { lhs := FieldLHS.fieldName ref fieldName :: rest, .. } =>
|
||||
match findField? env s.structName fieldName with
|
||||
| none => throwErrorAt ref ("'" ++ fieldName ++ "' is not a field of structure '" ++ s.structName ++ "'")
|
||||
| none => throwErrorAt! ref "'{fieldName}' is not a field of structure '{s.structName}'"
|
||||
| some baseStructName =>
|
||||
if baseStructName == s.structName then pure field
|
||||
else match getPathToBaseStructure? env baseStructName s.structName with
|
||||
| some path => do
|
||||
let path := path.map $ fun funName => match funName with
|
||||
| Name.str _ s _ => FieldLHS.fieldName ref (mkNameSimple s)
|
||||
| _ => unreachable!;
|
||||
| _ => unreachable!
|
||||
pure { field with lhs := path ++ field.lhs }
|
||||
| _ => throwErrorAt ref ("failed to access field '" ++ fieldName ++ "' in parent structure")
|
||||
| _ => throwErrorAt! ref "failed to access field '{fieldName}' in parent structure"
|
||||
| _ => pure field
|
||||
|
||||
private abbrev FieldMap := HashMap Name Fields
|
||||
|
||||
private def mkFieldMap (fields : Fields) : TermElabM FieldMap :=
|
||||
fields.foldlM
|
||||
(fun fieldMap field =>
|
||||
match field.lhs with
|
||||
| FieldLHS.fieldName _ fieldName :: rest =>
|
||||
match fieldMap.find? fieldName with
|
||||
| some (prevField::restFields) =>
|
||||
if field.isSimple || prevField.isSimple then
|
||||
throwErrorAt field.ref ("field '" ++ fieldName ++ "' has already beed specified")
|
||||
else
|
||||
pure $ fieldMap.insert fieldName (field::prevField::restFields)
|
||||
| _ => pure $ fieldMap.insert fieldName [field]
|
||||
| _ => unreachable!)
|
||||
{}
|
||||
fields.foldlM (init := {}) fun fieldMap field =>
|
||||
match field.lhs with
|
||||
| FieldLHS.fieldName _ fieldName :: rest =>
|
||||
match fieldMap.find? fieldName with
|
||||
| some (prevField::restFields) =>
|
||||
if field.isSimple || prevField.isSimple then
|
||||
throwErrorAt! field.ref "field '{fieldName}' has already beed specified"
|
||||
else
|
||||
pure $ fieldMap.insert fieldName (field::prevField::restFields)
|
||||
| _ => pure $ fieldMap.insert fieldName [field]
|
||||
| _ => unreachable!
|
||||
|
||||
private def isSimpleField? : Fields → Option (Field Struct)
|
||||
| [field] => if field.isSimple then some field else none
|
||||
|
|
@ -374,7 +369,7 @@ private def isSimpleField? : Fields → Option (Field Struct)
|
|||
private def getFieldIdx (structName : Name) (fieldNames : Array Name) (fieldName : Name) : TermElabM Nat := do
|
||||
match fieldNames.findIdx? $ fun n => n == fieldName with
|
||||
| some idx => pure idx
|
||||
| none => throwError ("field '" ++ fieldName ++ "' is not a valid field of '" ++ structName ++ "'")
|
||||
| none => throwError! "field '{fieldName}' is not a valid field of '{structName}'"
|
||||
|
||||
private def mkProjStx (s : Syntax) (fieldName : Name) : Syntax :=
|
||||
Syntax.node `Lean.Parser.Term.proj #[s, mkAtomFrom s ".", mkIdentFrom s fieldName]
|
||||
|
|
@ -382,81 +377,78 @@ Syntax.node `Lean.Parser.Term.proj #[s, mkAtomFrom s ".", mkIdentFrom s fieldNam
|
|||
private def mkSubstructSource (structName : Name) (fieldNames : Array Name) (fieldName : Name) (src : Source) : TermElabM Source :=
|
||||
match src with
|
||||
| Source.explicit stx src => do
|
||||
idx ← getFieldIdx structName fieldNames fieldName;
|
||||
let stx := stx.modifyArg 0 $ fun stx => mkProjStx stx fieldName;
|
||||
let idx ← getFieldIdx structName fieldNames fieldName
|
||||
let stx := stx.modifyArg 0 fun stx => mkProjStx stx fieldName
|
||||
pure $ Source.explicit stx (mkProj structName idx src)
|
||||
| s => pure s
|
||||
|
||||
@[specialize] private def groupFields (expandStruct : Struct → TermElabM Struct) (s : Struct) : TermElabM Struct := do
|
||||
env ← getEnv;
|
||||
let fieldNames := getStructureFields env s.structName;
|
||||
withRef s.ref $
|
||||
s.modifyFieldsM $ fun fields => do
|
||||
fieldMap ← mkFieldMap fields;
|
||||
fieldMap.toList.mapM $ fun ⟨fieldName, fields⟩ =>
|
||||
let env ← getEnv
|
||||
let fieldNames := getStructureFields env s.structName
|
||||
withRef s.ref do
|
||||
s.modifyFieldsM fun fields => do
|
||||
let fieldMap ← mkFieldMap fields
|
||||
fieldMap.toList.mapM fun ⟨fieldName, fields⟩ => do
|
||||
match isSimpleField? fields with
|
||||
| some field => pure field
|
||||
| none => do
|
||||
let substructFields := fields.map $ fun field => { field with lhs := field.lhs.tail! };
|
||||
substructSource ← mkSubstructSource s.structName fieldNames fieldName s.source;
|
||||
let field := fields.head!;
|
||||
| none =>
|
||||
let substructFields := fields.map fun field => { field with lhs := field.lhs.tail! }
|
||||
let substructSource ← mkSubstructSource s.structName fieldNames fieldName s.source
|
||||
let field := fields.head!
|
||||
match Lean.isSubobjectField? env s.structName fieldName with
|
||||
| some substructName => do
|
||||
let substruct := Struct.mk s.ref substructName substructFields substructSource;
|
||||
substruct ← expandStruct substruct;
|
||||
| some substructName =>
|
||||
let substruct := Struct.mk s.ref substructName substructFields substructSource
|
||||
let substruct ← expandStruct substruct
|
||||
pure { field with lhs := [field.lhs.head!], val := FieldVal.nested substruct }
|
||||
| none => do
|
||||
-- It is not a substructure field. Thus, we wrap fields using `Syntax`, and use `elabTerm` to process them.
|
||||
let valStx := s.ref; -- construct substructure syntax using s.ref as template
|
||||
let valStx := valStx.setArg 4 mkNullNode; -- erase optional expected type
|
||||
let args := substructFields.toArray.map $ Field.toSyntax;
|
||||
let valStx := valStx.setArg 2 (mkSepStx args (mkAtomFrom s.ref ","));
|
||||
let valStx := setStructSourceSyntax valStx substructSource;
|
||||
let valStx := s.ref -- construct substructure syntax using s.ref as template
|
||||
let valStx := valStx.setArg 4 mkNullNode -- erase optional expected type
|
||||
let args := substructFields.toArray.map Field.toSyntax
|
||||
let valStx := valStx.setArg 2 (mkSepStx args (mkAtomFrom s.ref ","))
|
||||
let valStx := setStructSourceSyntax valStx substructSource
|
||||
pure { field with lhs := [field.lhs.head!], val := FieldVal.term valStx }
|
||||
|
||||
def findField? (fields : Fields) (fieldName : Name) : Option (Field Struct) :=
|
||||
fields.find? $ fun field =>
|
||||
fields.find? fun field =>
|
||||
match field.lhs with
|
||||
| [FieldLHS.fieldName _ n] => n == fieldName
|
||||
| _ => false
|
||||
|
||||
@[specialize] private def addMissingFields (expandStruct : Struct → TermElabM Struct) (s : Struct) : TermElabM Struct := do
|
||||
env ← getEnv;
|
||||
let fieldNames := getStructureFields env s.structName;
|
||||
let ref := s.ref;
|
||||
let env ← getEnv
|
||||
let fieldNames := getStructureFields env s.structName
|
||||
let ref := s.ref
|
||||
withRef ref do
|
||||
fields ← fieldNames.foldlM
|
||||
(fun fields fieldName => do
|
||||
match findField? s.fields fieldName with
|
||||
| some field => pure $ field::fields
|
||||
| none =>
|
||||
let addField (val : FieldVal Struct) : TermElabM Fields := do {
|
||||
pure $ { ref := s.ref, lhs := [FieldLHS.fieldName s.ref fieldName], val := val } :: fields
|
||||
};
|
||||
match Lean.isSubobjectField? env s.structName fieldName with
|
||||
| some substructName => do
|
||||
substructSource ← mkSubstructSource s.structName fieldNames fieldName s.source;
|
||||
let substruct := Struct.mk s.ref substructName [] substructSource;
|
||||
substruct ← expandStruct substruct;
|
||||
addField (FieldVal.nested substruct)
|
||||
| none =>
|
||||
match s.source with
|
||||
| Source.none => addField FieldVal.default
|
||||
| Source.implicit _ => addField (FieldVal.term (mkHole s.ref))
|
||||
| Source.explicit stx _ =>
|
||||
-- stx is of the form `optional (try (termParser >> "with"))`
|
||||
let src := stx.getArg 0;
|
||||
let val := mkProjStx src fieldName;
|
||||
addField (FieldVal.term val))
|
||||
[];
|
||||
let fields ← fieldNames.foldlM (init := []) fun fields fieldName => do
|
||||
match findField? s.fields fieldName with
|
||||
| some field => pure $ field::fields
|
||||
| none =>
|
||||
let addField (val : FieldVal Struct) : TermElabM Fields := do
|
||||
pure $ { ref := s.ref, lhs := [FieldLHS.fieldName s.ref fieldName], val := val } :: fields
|
||||
match Lean.isSubobjectField? env s.structName fieldName with
|
||||
| some substructName => do
|
||||
let substructSource ← mkSubstructSource s.structName fieldNames fieldName s.source
|
||||
let substruct := Struct.mk s.ref substructName [] substructSource
|
||||
let substruct ← expandStruct substruct
|
||||
addField (FieldVal.nested substruct)
|
||||
| none =>
|
||||
match s.source with
|
||||
| Source.none => addField FieldVal.default
|
||||
| Source.implicit _ => addField (FieldVal.term (mkHole s.ref))
|
||||
| Source.explicit stx _ =>
|
||||
-- stx is of the form `optional (try (termParser >> "with"))`
|
||||
let src := stx[0]
|
||||
let val := mkProjStx src fieldName
|
||||
addField (FieldVal.term val)
|
||||
pure $ s.setFields fields.reverse
|
||||
|
||||
private partial def expandStruct : Struct → TermElabM Struct
|
||||
| s => do
|
||||
let s := expandCompositeFields s;
|
||||
s ← expandNumLitFields s;
|
||||
s ← expandParentFields s;
|
||||
s ← groupFields expandStruct s;
|
||||
let s := expandCompositeFields s
|
||||
let s ← expandNumLitFields s
|
||||
let s ← expandParentFields s
|
||||
let s ← groupFields expandStruct s
|
||||
addMissingFields expandStruct s
|
||||
|
||||
structure CtorHeaderResult :=
|
||||
|
|
@ -467,15 +459,15 @@ structure CtorHeaderResult :=
|
|||
private def mkCtorHeaderAux : Nat → Expr → Expr → Array MVarId → TermElabM CtorHeaderResult
|
||||
| 0, type, ctorFn, instMVars => pure { ctorFn := ctorFn, ctorFnType := type, instMVars := instMVars }
|
||||
| n+1, type, ctorFn, instMVars => do
|
||||
type ← whnfForall type;
|
||||
let type ← whnfForall type
|
||||
match type with
|
||||
| Expr.forallE _ d b c =>
|
||||
match c.binderInfo with
|
||||
| BinderInfo.instImplicit => do
|
||||
a ← mkFreshExprMVar d MetavarKind.synthetic;
|
||||
| BinderInfo.instImplicit =>
|
||||
let a ← mkFreshExprMVar d MetavarKind.synthetic
|
||||
mkCtorHeaderAux n (b.instantiate1 a) (mkApp ctorFn a) (instMVars.push a.mvarId!)
|
||||
| _ => do
|
||||
a ← mkFreshExprMVar d;
|
||||
| _ =>
|
||||
let a ← mkFreshExprMVar d
|
||||
mkCtorHeaderAux n (b.instantiate1 a) (mkApp ctorFn a) instMVars
|
||||
| _ => throwError "unexpected constructor type"
|
||||
|
||||
|
|
@ -487,21 +479,21 @@ private partial def getForallBody : Nat → Expr → Option Expr
|
|||
private def propagateExpectedType (type : Expr) (numFields : Nat) (expectedType? : Option Expr) : TermElabM Unit :=
|
||||
match expectedType? with
|
||||
| none => pure ()
|
||||
| some expectedType =>
|
||||
| some expectedType => do
|
||||
match getForallBody numFields type with
|
||||
| none => pure ()
|
||||
| some typeBody =>
|
||||
unless typeBody.hasLooseBVars $ do
|
||||
_ ← isDefEq expectedType typeBody;
|
||||
unless typeBody.hasLooseBVars do
|
||||
isDefEq expectedType typeBody
|
||||
pure ()
|
||||
|
||||
private def mkCtorHeader (ctorVal : ConstructorVal) (expectedType? : Option Expr) : TermElabM CtorHeaderResult := do
|
||||
lvls ← ctorVal.lparams.mapM $ fun _ => mkFreshLevelMVar;
|
||||
let val := Lean.mkConst ctorVal.name lvls;
|
||||
let type := (ConstantInfo.ctorInfo ctorVal).instantiateTypeLevelParams lvls;
|
||||
r ← mkCtorHeaderAux ctorVal.nparams type val #[];
|
||||
propagateExpectedType r.ctorFnType ctorVal.nfields expectedType?;
|
||||
synthesizeAppInstMVars r.instMVars;
|
||||
let lvls ← ctorVal.lparams.mapM fun _ => mkFreshLevelMVar
|
||||
let val := Lean.mkConst ctorVal.name lvls
|
||||
let type := (ConstantInfo.ctorInfo ctorVal).instantiateTypeLevelParams lvls
|
||||
let r ← mkCtorHeaderAux ctorVal.nparams type val #[]
|
||||
propagateExpectedType r.ctorFnType ctorVal.nfields expectedType?
|
||||
synthesizeAppInstMVars r.instMVars
|
||||
pure r
|
||||
|
||||
def markDefaultMissing (e : Expr) : Expr :=
|
||||
|
|
@ -511,43 +503,40 @@ def defaultMissing? (e : Expr) : Option Expr :=
|
|||
annotation? `structInstDefault e
|
||||
|
||||
def throwFailedToElabField {α} (fieldName : Name) (structName : Name) (msgData : MessageData) : TermElabM α :=
|
||||
throwError ("failed to elaborate field '" ++ fieldName ++ "' of '" ++ structName ++ ", " ++ msgData)
|
||||
throwError! "failed to elaborate field '{fieldName}' of '{structName}, {msgData}"
|
||||
|
||||
def trySynthStructInstance? (s : Struct) (expectedType : Expr) : TermElabM (Option Expr) :=
|
||||
if !s.allDefault then pure none
|
||||
def trySynthStructInstance? (s : Struct) (expectedType : Expr) : TermElabM (Option Expr) := do
|
||||
if !s.allDefault then
|
||||
pure none
|
||||
else
|
||||
catch (synthInstance? expectedType) (fun _ => pure none)
|
||||
try synthInstance? expectedType catch _ => pure none
|
||||
|
||||
private partial def elabStruct : Struct → Option Expr → TermElabM (Expr × Struct)
|
||||
| s, expectedType? => withRef s.ref do
|
||||
env ← getEnv;
|
||||
let ctorVal := getStructureCtor env s.structName;
|
||||
{ ctorFn := ctorFn, ctorFnType := ctorFnType, .. } ← mkCtorHeader ctorVal expectedType?;
|
||||
(e, _, fields) ← s.fields.foldlM
|
||||
(fun (acc : Expr × Expr × Fields) field =>
|
||||
let (e, type, fields) := acc;
|
||||
match field.lhs with
|
||||
| [FieldLHS.fieldName ref fieldName] => do
|
||||
type ← whnfForall type;
|
||||
match type with
|
||||
| Expr.forallE _ d b c =>
|
||||
let continue (val : Expr) (field : Field Struct) : TermElabM (Expr × Expr × Fields) := do {
|
||||
let e := mkApp e val;
|
||||
let type := b.instantiate1 val;
|
||||
let field := { field with expr? := some val };
|
||||
pure (e, type, field::fields)
|
||||
};
|
||||
match field.val with
|
||||
| FieldVal.term stx => do val ← elabTermEnsuringType stx d; continue val field
|
||||
| FieldVal.nested s => do
|
||||
val? ← trySynthStructInstance? s d; -- if all fields of `s` are marked as `default`, then try to synthesize instance
|
||||
match val? with
|
||||
| some val => continue val { field with val := FieldVal.term (mkHole field.ref) }
|
||||
| none => do(val, sNew) ← elabStruct s (some d); val ← ensureHasType d val; continue val { field with val := FieldVal.nested sNew }
|
||||
| FieldVal.default => do val ← withRef field.ref $ mkFreshExprMVar (some d); continue (markDefaultMissing val) field
|
||||
| _ => withRef field.ref $ throwFailedToElabField fieldName s.structName ("unexpected constructor type" ++ indentExpr type)
|
||||
| _ => throwErrorAt field.ref "unexpected unexpanded structure field")
|
||||
(ctorFn, ctorFnType, []);
|
||||
let env ← getEnv
|
||||
let ctorVal := getStructureCtor env s.structName
|
||||
let { ctorFn := ctorFn, ctorFnType := ctorFnType, .. } ← mkCtorHeader ctorVal expectedType?
|
||||
let (e, _, fields) ← s.fields.foldlM (init := (ctorFn, ctorFnType, [])) fun (e, type, fields) field =>
|
||||
match field.lhs with
|
||||
| [FieldLHS.fieldName ref fieldName] => do
|
||||
let type ← whnfForall type
|
||||
match type with
|
||||
| Expr.forallE _ d b c =>
|
||||
let cont (val : Expr) (field : Field Struct) : TermElabM (Expr × Expr × Fields) := do
|
||||
let e := mkApp e val
|
||||
let type := b.instantiate1 val
|
||||
let field := { field with expr? := some val }
|
||||
pure (e, type, field::fields)
|
||||
match field.val with
|
||||
| FieldVal.term stx => cont (← elabTermEnsuringType stx d) field
|
||||
| FieldVal.nested s => do
|
||||
-- if all fields of `s` are marked as `default`, then try to synthesize instance
|
||||
match (← trySynthStructInstance? s d) with
|
||||
| some val => cont val { field with val := FieldVal.term (mkHole field.ref) }
|
||||
| none => do let (val, sNew) ← elabStruct s (some d); val ← ensureHasType d val; cont val { field with val := FieldVal.nested sNew }
|
||||
| FieldVal.default => do let val ← withRef field.ref $ mkFreshExprMVar (some d); cont (markDefaultMissing val) field
|
||||
| _ => withRef field.ref $ throwFailedToElabField fieldName s.structName msg!"unexpected constructor type{indentExpr type}"
|
||||
| _ => throwErrorAt field.ref "unexpected unexpanded structure field"
|
||||
pure (e, s.setFields fields.reverse)
|
||||
|
||||
namespace DefaultFields
|
||||
|
|
@ -587,28 +576,24 @@ structure State :=
|
|||
|
||||
partial def collectStructNames : Struct → Array Name → Array Name
|
||||
| struct, names =>
|
||||
let names := names.push struct.structName;
|
||||
struct.fields.foldl
|
||||
(fun names field =>
|
||||
match field.val with
|
||||
| FieldVal.nested struct => collectStructNames struct names
|
||||
| _ => names)
|
||||
names
|
||||
let names := names.push struct.structName
|
||||
struct.fields.foldl (init := names) fun names field =>
|
||||
match field.val with
|
||||
| FieldVal.nested struct => collectStructNames struct names
|
||||
| _ => names
|
||||
|
||||
partial def getHierarchyDepth : Struct → Nat
|
||||
| struct =>
|
||||
struct.fields.foldl
|
||||
(fun max field =>
|
||||
match field.val with
|
||||
| FieldVal.nested struct => Nat.max max (getHierarchyDepth struct + 1)
|
||||
| _ => max)
|
||||
0
|
||||
struct.fields.foldl (init := 0) fun max field =>
|
||||
match field.val with
|
||||
| FieldVal.nested struct => Nat.max max (getHierarchyDepth struct + 1)
|
||||
| _ => max
|
||||
|
||||
partial def findDefaultMissing? (mctx : MetavarContext) : Struct → Option (Field Struct)
|
||||
| struct =>
|
||||
struct.fields.findSome? $ fun field =>
|
||||
struct.fields.findSome? fun field =>
|
||||
match field.val with
|
||||
| FieldVal.nested struct => findDefaultMissing? struct
|
||||
| FieldVal.nested struct => findDefaultMissing? mctx struct
|
||||
| _ => match field.expr? with
|
||||
| none => unreachable!
|
||||
| some expr => match defaultMissing? expr with
|
||||
|
|
@ -623,31 +608,30 @@ match field.lhs with
|
|||
abbrev M := ReaderT Context (StateRefT State TermElabM)
|
||||
|
||||
def isRoundDone : M Bool := do
|
||||
ctx ← read;
|
||||
s ← get;
|
||||
pure (s.progress && ctx.maxDistance > 0)
|
||||
return (← get).progress && (← read).maxDistance > 0
|
||||
|
||||
def getFieldValue? (struct : Struct) (fieldName : Name) : Option Expr :=
|
||||
struct.fields.findSome? $ fun field =>
|
||||
struct.fields.findSome? fun field =>
|
||||
if getFieldName field == fieldName then
|
||||
field.expr?
|
||||
else
|
||||
none
|
||||
|
||||
partial def mkDefaultValueAux? (struct : Struct) : Expr → TermElabM (Option Expr)
|
||||
| Expr.lam n d b c => withRef struct.ref $
|
||||
| Expr.lam n d b c => withRef struct.ref do
|
||||
if c.binderInfo.isExplicit then
|
||||
let fieldName := n;
|
||||
let fieldName := n
|
||||
match getFieldValue? struct fieldName with
|
||||
| none => pure none
|
||||
| some val => do
|
||||
valType ← inferType val;
|
||||
condM (isDefEq valType d)
|
||||
(mkDefaultValueAux? (b.instantiate1 val))
|
||||
(pure none)
|
||||
else do
|
||||
arg ← mkFreshExprMVar d;
|
||||
mkDefaultValueAux? (b.instantiate1 arg)
|
||||
| some val =>
|
||||
let valType ← inferType val
|
||||
if (← isDefEq valType d) then
|
||||
mkDefaultValueAux? struct (b.instantiate1 val)
|
||||
else
|
||||
pure none
|
||||
else
|
||||
let arg ← mkFreshExprMVar d
|
||||
mkDefaultValueAux? struct (b.instantiate1 arg)
|
||||
| e =>
|
||||
if e.isAppOfArity `id 2 then
|
||||
pure (some e.appArg!)
|
||||
|
|
@ -656,7 +640,7 @@ partial def mkDefaultValueAux? (struct : Struct) : Expr → TermElabM (Option Ex
|
|||
|
||||
def mkDefaultValue? (struct : Struct) (cinfo : ConstantInfo) : TermElabM (Option Expr) :=
|
||||
withRef struct.ref do
|
||||
us ← cinfo.lparams.mapM $ fun _ => mkFreshLevelMVar;
|
||||
let us ← cinfo.lparams.mapM fun _ => mkFreshLevelMVar
|
||||
mkDefaultValueAux? struct (cinfo.instantiateValueLevelParams us)
|
||||
|
||||
/-- If `e` is a projection function of one of the given structures, then reduce it -/
|
||||
|
|
@ -664,7 +648,7 @@ def reduceProjOf? (structNames : Array Name) (e : Expr) : MetaM (Option Expr) :=
|
|||
if !e.isApp then pure none
|
||||
else match e.getAppFn with
|
||||
| Expr.const name _ _ => do
|
||||
env ← getEnv;
|
||||
let env ← getEnv
|
||||
match env.getProjectionStructureName? name with
|
||||
| some structName =>
|
||||
if structNames.contains structName then
|
||||
|
|
@ -676,142 +660,126 @@ else match e.getAppFn with
|
|||
|
||||
/-- Reduce default value. It performs beta reduction and projections of the given structures. -/
|
||||
partial def reduce (structNames : Array Name) : Expr → MetaM Expr
|
||||
| e@(Expr.lam _ _ _ _) => lambdaLetTelescope e $ fun xs b => do b ← reduce b; mkLambdaFVars xs b
|
||||
| e@(Expr.forallE _ _ _ _) => forallTelescope e $ fun xs b => do b ← reduce b; mkForallFVars xs b
|
||||
| e@(Expr.letE _ _ _ _ _) => lambdaLetTelescope e $ fun xs b => do b ← reduce b; mkLetFVars xs b
|
||||
| e@(Expr.lam _ _ _ _) => lambdaLetTelescope e fun xs b => do mkLambdaFVars xs (← reduce structNames b)
|
||||
| e@(Expr.forallE _ _ _ _) => forallTelescope e fun xs b => do mkForallFVars xs (← reduce structNames b)
|
||||
| e@(Expr.letE _ _ _ _ _) => lambdaLetTelescope e fun xs b => do mkLetFVars xs (← reduce structNames b)
|
||||
| e@(Expr.proj _ i b _) => do
|
||||
r? ← Meta.reduceProj? b i;
|
||||
match r? with
|
||||
| some r => reduce r
|
||||
| none => do b ← reduce b; pure $ e.updateProj! b
|
||||
match (← Meta.reduceProj? b i) with
|
||||
| some r => reduce structNames r
|
||||
| none => pure $ e.updateProj! (← reduce structNames b)
|
||||
| e@(Expr.app f _ _) => do
|
||||
r? ← reduceProjOf? structNames e;
|
||||
match r? with
|
||||
| some r => reduce r
|
||||
| none => do
|
||||
let f := f.getAppFn;
|
||||
f' ← reduce f;
|
||||
match (← reduceProjOf? structNames e) with
|
||||
| some r => reduce structNames r
|
||||
| none =>
|
||||
let f := f.getAppFn
|
||||
let f' ← reduce structNames f
|
||||
if f'.isLambda then
|
||||
let revArgs := e.getAppRevArgs;
|
||||
reduce $ f'.betaRev revArgs
|
||||
else do
|
||||
args ← e.getAppArgs.mapM reduce;
|
||||
let revArgs := e.getAppRevArgs
|
||||
reduce structNames (f'.betaRev revArgs)
|
||||
else
|
||||
let args ← e.getAppArgs.mapM (reduce structNames)
|
||||
pure (mkAppN f' args)
|
||||
| e@(Expr.mdata _ b _) => do
|
||||
b ← reduce b;
|
||||
let b ← reduce structNames b
|
||||
if (defaultMissing? e).isSome && !b.isMVar then
|
||||
pure b
|
||||
else
|
||||
pure $ e.updateMData! b
|
||||
| e@(Expr.mvar mvarId _) => do
|
||||
val? ← getExprMVarAssignment? mvarId;
|
||||
match val? with
|
||||
| some val => if val.isMVar then reduce val else pure val
|
||||
match (← getExprMVarAssignment? mvarId) with
|
||||
| some val => if val.isMVar then reduce structNames val else pure val
|
||||
| none => pure e
|
||||
| e => pure e
|
||||
|
||||
partial def tryToSynthesizeDefaultAux (structs : Array Struct) (allStructNames : Array Name) (maxDistance : Nat)
|
||||
(fieldName : Name) (mvarId : MVarId) : Nat → Nat → TermElabM Bool
|
||||
| i, dist =>
|
||||
if dist > maxDistance then pure false
|
||||
partial def tryToSynthesizeDefault (structs : Array Struct) (allStructNames : Array Name) (maxDistance : Nat) (fieldName : Name) (mvarId : MVarId) : TermElabM Bool :=
|
||||
let rec loop (i : Nat) (dist : Nat) := do
|
||||
if dist > maxDistance then
|
||||
pure false
|
||||
else if h : i < structs.size then do
|
||||
let struct := structs.get ⟨i, h⟩;
|
||||
let defaultName := struct.structName ++ fieldName ++ `_default;
|
||||
env ← getEnv;
|
||||
let struct := structs.get ⟨i, h⟩
|
||||
let defaultName := struct.structName ++ fieldName ++ `_default
|
||||
let env ← getEnv
|
||||
match env.find? defaultName with
|
||||
| some cinfo@(ConstantInfo.defnInfo defVal) => do
|
||||
mctx ← getMCtx;
|
||||
val? ← mkDefaultValue? struct cinfo;
|
||||
let mctx ← getMCtx
|
||||
let val? ← mkDefaultValue? struct cinfo
|
||||
match val? with
|
||||
| none => do setMCtx mctx; tryToSynthesizeDefaultAux (i+1) (dist+1)
|
||||
| none => do setMCtx mctx; loop (i+1) (dist+1)
|
||||
| some val => do
|
||||
val ← liftMetaM $ reduce allStructNames val;
|
||||
match val.find? $ fun e => (defaultMissing? e).isSome with
|
||||
| some _ => do setMCtx mctx; tryToSynthesizeDefaultAux (i+1) (dist+1)
|
||||
| none => do
|
||||
mvarDecl ← getMVarDecl mvarId;
|
||||
val ← ensureHasType mvarDecl.type val;
|
||||
assignExprMVar mvarId val;
|
||||
let val ← liftMetaM $ reduce allStructNames val
|
||||
match val.find? fun e => (defaultMissing? e).isSome with
|
||||
| some _ => setMCtx mctx; loop (i+1) (dist+1)
|
||||
| none =>
|
||||
let mvarDecl ← getMVarDecl mvarId
|
||||
let val ← ensureHasType mvarDecl.type val
|
||||
assignExprMVar mvarId val
|
||||
pure true
|
||||
| _ => tryToSynthesizeDefaultAux (i+1) dist
|
||||
| _ => loop (i+1) dist
|
||||
else
|
||||
pure false
|
||||
|
||||
def tryToSynthesizeDefault (structs : Array Struct) (allStructNames : Array Name)
|
||||
(maxDistance : Nat) (fieldName : Name) (mvarId : MVarId) : TermElabM Bool :=
|
||||
tryToSynthesizeDefaultAux structs allStructNames maxDistance fieldName mvarId 0 0
|
||||
loop 0 0
|
||||
|
||||
partial def step : Struct → M Unit
|
||||
| struct => unlessM isRoundDone $ withReader (fun ctx => { ctx with structs := ctx.structs.push struct }) $ do
|
||||
struct.fields.forM $ fun field =>
|
||||
| struct => unlessM isRoundDone $ withReader (fun ctx => { ctx with structs := ctx.structs.push struct }) do
|
||||
struct.fields.forM fun field => do
|
||||
match field.val with
|
||||
| FieldVal.nested struct => step struct
|
||||
| _ => match field.expr? with
|
||||
| none => unreachable!
|
||||
| some expr => match defaultMissing? expr with
|
||||
| some (Expr.mvar mvarId _) =>
|
||||
unlessM (liftM $ isExprMVarAssigned mvarId) $ do
|
||||
ctx ← read;
|
||||
whenM (liftM $ withRef field.ref $ tryToSynthesizeDefault ctx.structs ctx.allStructNames ctx.maxDistance (getFieldName field) mvarId) $ do
|
||||
modify $ fun s => { s with progress := true }
|
||||
unless (← isExprMVarAssigned mvarId) do
|
||||
let ctx ← read
|
||||
if (← withRef field.ref $ tryToSynthesizeDefault ctx.structs ctx.allStructNames ctx.maxDistance (getFieldName field) mvarId) then
|
||||
modify fun s => { s with progress := true }
|
||||
| _ => pure ()
|
||||
|
||||
partial def propagateLoop (hierarchyDepth : Nat) : Nat → Struct → M Unit
|
||||
| d, struct => do
|
||||
mctx ← getMCtx;
|
||||
match findDefaultMissing? mctx struct with
|
||||
match findDefaultMissing? (← getMCtx) struct with
|
||||
| none => pure () -- Done
|
||||
| some field =>
|
||||
if d > hierarchyDepth then
|
||||
throwErrorAt field.ref ("field '" ++ getFieldName field ++ "' is missing")
|
||||
else withReader (fun ctx => { ctx with maxDistance := d }) $ do
|
||||
modify $ fun (s : State) => { s with progress := false };
|
||||
step struct;
|
||||
s ← get;
|
||||
if s.progress then do
|
||||
propagateLoop 0 struct
|
||||
throwErrorAt! field.ref "field '{getFieldName field}' is missing"
|
||||
else withReader (fun ctx => { ctx with maxDistance := d }) do
|
||||
modify fun s => { s with progress := false }
|
||||
step struct
|
||||
if (← get).progress then do
|
||||
propagateLoop hierarchyDepth 0 struct
|
||||
else
|
||||
propagateLoop (d+1) struct
|
||||
propagateLoop hierarchyDepth (d+1) struct
|
||||
|
||||
def propagate (struct : Struct) : TermElabM Unit :=
|
||||
let hierarchyDepth := getHierarchyDepth struct;
|
||||
let structNames := collectStructNames struct #[];
|
||||
let hierarchyDepth := getHierarchyDepth struct
|
||||
let structNames := collectStructNames struct #[]
|
||||
(propagateLoop hierarchyDepth 0 struct { allStructNames := structNames }).run' {}
|
||||
|
||||
end DefaultFields
|
||||
|
||||
private def elabStructInstAux (stx : Syntax) (expectedType? : Option Expr) (source : Source) : TermElabM Expr := do
|
||||
structName ← getStructName stx expectedType? source;
|
||||
env ← getEnv;
|
||||
unless (isStructureLike env structName) $
|
||||
throwError ("invalid {...} notation, '" ++ structName ++ "' is not a structure");
|
||||
let structName ← getStructName stx expectedType? source
|
||||
unless isStructureLike (← getEnv) structName do
|
||||
throwError! "invalid \{...} notation, '{structName}' is not a structure"
|
||||
match mkStructView stx structName source with
|
||||
| Except.error ex => throwError ex
|
||||
| Except.ok struct => do
|
||||
struct ← expandStruct struct;
|
||||
trace `Elab.struct fun _ => toString struct;
|
||||
(r, struct) ← elabStruct struct expectedType?;
|
||||
DefaultFields.propagate struct;
|
||||
| Except.ok struct =>
|
||||
let struct ← expandStruct struct
|
||||
trace[Elab.struct]! "{struct}"
|
||||
let (r, struct) ← elabStruct struct expectedType?
|
||||
DefaultFields.propagate struct
|
||||
pure r
|
||||
|
||||
@[builtinTermElab structInst] def elabStructInst : TermElab :=
|
||||
fun stx expectedType? => do
|
||||
stxNew? ← expandNonAtomicExplicitSource stx;
|
||||
match stxNew? with
|
||||
match (← expandNonAtomicExplicitSource stx) with
|
||||
| some stxNew => withMacroExpansion stx stxNew $ elabTerm stxNew expectedType?
|
||||
| none => do
|
||||
sourceView ← getStructSource stx;
|
||||
modifyOp? ← isModifyOp? stx;
|
||||
match modifyOp?, sourceView with
|
||||
| none =>
|
||||
let sourceView ← getStructSource stx
|
||||
match (← isModifyOp? stx), sourceView with
|
||||
| some modifyOp, Source.explicit source _ => elabModifyOp stx modifyOp source expectedType?
|
||||
| some _, _ => throwError ("invalid {...} notation, explicit source is required when using '[<index>] := <value>'")
|
||||
| some _, _ => throwError "invalid {...} notation, explicit source is required when using '[<index>] := <value>'"
|
||||
| _, _ => elabStructInstAux stx expectedType? sourceView
|
||||
|
||||
@[init] private def regTraceClasses : IO Unit := do
|
||||
registerTraceClass `Elab.struct;
|
||||
pure ()
|
||||
initialize registerTraceClass `Elab.struct
|
||||
|
||||
end StructInst
|
||||
end Term
|
||||
end Elab
|
||||
end Lean
|
||||
end Lean.Elab.Term.StructInst
|
||||
|
|
|
|||
|
|
@ -168,13 +168,13 @@ self.find? idx
|
|||
match m with
|
||||
| ⟨ m, _ ⟩ => m.contains a
|
||||
|
||||
@[inline] def foldM {δ : Type w} {m : Type w → Type w} [Monad m] (f : δ → α → β → m δ) (d : δ) (h : HashMap α β) : m δ :=
|
||||
@[inline] def foldM {δ : Type w} {m : Type w → Type w} [Monad m] (f : δ → α → β → m δ) (init : δ) (h : HashMap α β) : m δ :=
|
||||
match h with
|
||||
| ⟨ h, _ ⟩ => h.foldM f d
|
||||
| ⟨ h, _ ⟩ => h.foldM f init
|
||||
|
||||
@[inline] def fold {δ : Type w} (f : δ → α → β → δ) (d : δ) (m : HashMap α β) : δ :=
|
||||
@[inline] def fold {δ : Type w} (f : δ → α → β → δ) (init : δ) (m : HashMap α β) : δ :=
|
||||
match m with
|
||||
| ⟨ m, _ ⟩ => m.fold f d
|
||||
| ⟨ m, _ ⟩ => m.fold f init
|
||||
|
||||
@[inline] def size (m : HashMap α β) : Nat :=
|
||||
match m with
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue