From bd951b40ce9552fa8ca1541f145a87ff96b90f7b Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Mon, 17 Feb 2020 11:43:55 -0800 Subject: [PATCH] refactor: remove unnecessary complexity --- src/Init/Lean/Elab/StructInst.lean | 54 ++++++++++++++---------------- 1 file changed, 26 insertions(+), 28 deletions(-) diff --git a/src/Init/Lean/Elab/StructInst.lean b/src/Init/Lean/Elab/StructInst.lean index 54c7546d6e..32c8d81a2b 100644 --- a/src/Init/Lean/Elab/StructInst.lean +++ b/src/Init/Lean/Elab/StructInst.lean @@ -454,28 +454,25 @@ private partial def expandStruct : Struct → TermElabM Struct s ← groupFields expandStruct s; addMissingFields expandStruct s -structure State := -(instMVars : Array MVarId := #[]) - structure CtorHeaderResult := (ctorFn : Expr) (ctorFnType : Expr) +(instMVars : Array MVarId := #[]) -private def mkCtorHeaderAux (ref : Syntax) : Nat → Expr → Expr → StateT State TermElabM CtorHeaderResult -| 0, type, ctorFn => pure { ctorFn := ctorFn, ctorFnType := type } -| n+1, type, ctorFn => do - type ← liftM $ whnfForall ref type; +private def mkCtorHeaderAux (ref : Syntax) : Nat → Expr → Expr → Array MVarId → TermElabM CtorHeaderResult +| 0, type, ctorFn, instMVars => pure { ctorFn := ctorFn, ctorFnType := type, instMVars := instMVars } +| n+1, type, ctorFn, instMVars => do + type ← whnfForall ref type; match type with | Expr.forallE _ d b c => match c.binderInfo with | BinderInfo.instImplicit => do - a ← liftM $ mkFreshExprMVar ref d MetavarKind.synthetic; - modify $ fun s => { instMVars := s.instMVars.push a.mvarId!, .. s }; - mkCtorHeaderAux n (b.instantiate1 a) (mkApp ctorFn a) + a ← mkFreshExprMVar ref d MetavarKind.synthetic; + mkCtorHeaderAux n (b.instantiate1 a) (mkApp ctorFn a) (instMVars.push a.mvarId!) | _ => do - a ← liftM $ mkFreshExprMVar ref d; - mkCtorHeaderAux n (b.instantiate1 a) (mkApp ctorFn a) - | _ => liftM $ throwError ref "unexpected constructor type" + a ← mkFreshExprMVar ref d; + mkCtorHeaderAux n (b.instantiate1 a) (mkApp ctorFn a) instMVars + | _ => throwError ref "unexpected constructor type" private partial def getForallBody : Nat → Expr → Option Expr | i+1, Expr.forallE _ _ b _ => getForallBody i b @@ -493,45 +490,46 @@ match expectedType? with isDefEq ref expectedType typeBody; pure () -private def mkCtorHeader (ref : Syntax) (ctorVal : ConstructorVal) (expectedType? : Option Expr) : StateT State TermElabM CtorHeaderResult := do -lvls ← ctorVal.lparams.mapM $ fun _ => liftM $ mkFreshLevelMVar ref; +private def mkCtorHeader (ref : Syntax) (ctorVal : ConstructorVal) (expectedType? : Option Expr) : TermElabM CtorHeaderResult := do +lvls ← ctorVal.lparams.mapM $ fun _ => mkFreshLevelMVar ref; let val := Lean.mkConst ctorVal.name lvls; let type := (ConstantInfo.ctorInfo ctorVal).instantiateTypeLevelParams lvls; -r ← mkCtorHeaderAux ref ctorVal.nparams type val; -liftM $ propagateExpectedType ref r.ctorFnType ctorVal.nfields expectedType?; +r ← mkCtorHeaderAux ref ctorVal.nparams type val #[]; +propagateExpectedType ref r.ctorFnType ctorVal.nfields expectedType?; +synthesizeAppInstMVars ref r.instMVars; pure r def markDefaultMissing (e : Expr) : Expr := mkMData (KVMap.empty.insert `structInstDefault (DataValue.ofBool true)) e -def throwFailedToElabField {α} (ref : Syntax) (fieldName : Name) (structName : Name) (msgData : MessageData) : StateT State TermElabM α := -liftM $ throwError ref ("failed to elaborate field '" ++ fieldName ++ "' of '" ++ structName ++ ", " ++ msgData) +def throwFailedToElabField {α} (ref : Syntax) (fieldName : Name) (structName : Name) (msgData : MessageData) : TermElabM α := +throwError ref ("failed to elaborate field '" ++ fieldName ++ "' of '" ++ structName ++ ", " ++ msgData) -private partial def elabStruct : Struct → Option Expr → StateT State TermElabM (Expr × Struct) +private partial def elabStruct : Struct → Option Expr → TermElabM (Expr × Struct) | s, expectedType? => do - env ← liftM $ getEnv; + env ← getEnv; let ctorVal := getStructureCtor env s.structName; - { ctorFn := ctorFn, ctorFnType := ctorFnType } ← mkCtorHeader s.ref ctorVal expectedType?; + { ctorFn := ctorFn, ctorFnType := ctorFnType, .. } ← mkCtorHeader s.ref ctorVal expectedType?; (e, _, fields) ← s.fields.foldlM (fun (acc : Expr × Expr × Fields) field => let (e, type, fields) := acc; match field.lhs with | [FieldLHS.fieldName ref fieldName] => do - type ← liftM $ whnfForall field.ref type; + type ← whnfForall field.ref type; match type with | Expr.forallE _ d b c => - let continue (val : Expr) (field : Field Struct) : StateT State TermElabM (Expr × Expr × Fields) := do { + let continue (val : Expr) (field : Field Struct) : TermElabM (Expr × Expr × Fields) := do { let e := mkApp e val; let type := b.instantiate1 val; let field := { expr := some val, .. field }; pure (e, type, field::fields) }; match field.val with - | FieldVal.term stx => do val ← liftM $ elabTerm stx (some d); continue val field + | FieldVal.term stx => do val ← elabTerm stx (some d); continue val field | FieldVal.nested s => do (val, sNew) ← elabStruct s (some d); continue val { val := FieldVal.nested sNew, .. field } - | FieldVal.default => do val ← liftM $ mkFreshExprMVar field.ref (some d); continue (markDefaultMissing val) field + | FieldVal.default => do val ← mkFreshExprMVar field.ref (some d); continue (markDefaultMissing val) field | _ => throwFailedToElabField field.ref fieldName s.structName ("unexpected constructor type" ++ indentExpr type) - | _ => liftM $ throwError field.ref "unexpected unexpanded structure field") + | _ => throwError field.ref "unexpected unexpanded structure field") (ctorFn, ctorFnType, []); pure (e, s.setFields fields.reverse) @@ -542,7 +540,7 @@ unless (isStructureLike env structName) $ throwError stx ("invalid {...} notation, '" ++ structName ++ "' is not a structure"); struct ← expandStruct $ mkStructView stx structName source; trace `Elab.struct stx $ fun _ => toString struct; -((r, struct), s) ← (elabStruct struct expectedType?).run {}; +(r, struct) ← elabStruct struct expectedType?; -- TODO: resolve missing default pure r