From aed9b16dc8b0a005cc131cd3d62b66c9f0be3b93 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Mon, 7 Sep 2020 15:08:07 -0700 Subject: [PATCH] feat: add `MatcherInfo` extension --- src/Lean/Elab/Match.lean | 14 +++---- src/Lean/Meta/Match/Match.lean | 71 ++++++++++++++++++++++++++++++---- 2 files changed, 70 insertions(+), 15 deletions(-) diff --git a/src/Lean/Elab/Match.lean b/src/Lean/Elab/Match.lean index 40bf85f859..5c8ef92802 100644 --- a/src/Lean/Elab/Match.lean +++ b/src/Lean/Elab/Match.lean @@ -401,7 +401,7 @@ patternVarDecls.foldlM | _ => pure decls) #[] -open Meta.Match (Pattern Pattern.var Pattern.inaccessible Pattern.ctor Pattern.as Pattern.val Pattern.arrayLit AltLHS ElimResult) +open Meta.Match (Pattern Pattern.var Pattern.inaccessible Pattern.ctor Pattern.as Pattern.val Pattern.arrayLit AltLHS MatcherResult) namespace ToDepElimPattern @@ -555,10 +555,10 @@ forallBoundedTelescope matchType numDiscrs fun xs matchType => do u ← getLevel matchType; mkForallFVars xs (mkSort u) -def mkMatcher (elimName : Name) (motiveType : Expr) (numDiscrs : Nat) (lhss : List AltLHS) : TermElabM ElimResult := +def mkMatcher (elimName : Name) (motiveType : Expr) (numDiscrs : Nat) (lhss : List AltLHS) : TermElabM MatcherResult := liftMetaM $ Meta.Match.mkMatcher elimName motiveType numDiscrs lhss -def reportElimResultErrors (result : ElimResult) : TermElabM Unit := do +def reportMatcherResultErrors (result : MatcherResult) : TermElabM Unit := do -- TODO: improve error messages unless result.counterExamples.isEmpty $ throwError ("missing cases:" ++ Format.line ++ Meta.Match.counterExamplesToMessageData result.counterExamples); @@ -577,10 +577,10 @@ let altLHSS := alts.map Prod.fst; let numDiscrs := discrs.size; motiveType ← mkMotiveType matchType numDiscrs; motive ← forallBoundedTelescope matchType numDiscrs fun xs matchType => mkLambdaFVars xs matchType; -elimName ← mkAuxName `match; -elimResult ← mkMatcher elimName motiveType numDiscrs altLHSS.toList; -reportElimResultErrors elimResult; -let r := mkApp elimResult.elim motive; +matcherName ← mkAuxName `match; +matcherResult ← mkMatcher matcherName motiveType numDiscrs altLHSS.toList; +reportMatcherResultErrors matcherResult; +let r := mkApp matcherResult.matcher motive; let r := mkAppN r discrs; let r := mkAppN r rhss; trace `Elab.match fun _ => "result: " ++ r; diff --git a/src/Lean/Meta/Match/Match.lean b/src/Lean/Meta/Match/Match.lean index 82008d4cae..559b252eed 100644 --- a/src/Lean/Meta/Match/Match.lean +++ b/src/Lean/Meta/Match/Match.lean @@ -179,8 +179,8 @@ examplesToMessageData cex def counterExamplesToMessageData (cexs : List CounterExample) : MessageData := MessageData.joinSep (cexs.map counterExampleToMessageData) Format.line -structure ElimResult := -(elim : Expr) -- The eliminator. It is not just `Expr.const elimName` because the type of the major premises may contain free variables. +structure MatcherResult := +(matcher : Expr) -- The matcher. It is not just `Expr.const matcherName` because the type of the major premises may contain free variables. (counterExamples : List CounterExample) (unusedAltIdxs : List Nat) @@ -662,7 +662,61 @@ private partial def process : Problem → StateRefT State MetaM Unit else liftM $ throwNonSupported p -def mkMatcher (elimName : Name) (motiveType : Expr) (numDiscrs : Nat) (lhss : List AltLHS) : MetaM ElimResult := +/-- +A "matcher" auxiliary declaration has the following structure: +- `numParams` parameters +- motive +- `numDiscrs` discriminators (aka major premises) +- `numAlts` alternatives (aka minor premises) +-/ +structure MatcherInfo := +(numParams : Nat) (numDiscrs : Nat) (numAlts : Nat) + +namespace Extension + +structure Entry := +(name : Name) (info : MatcherInfo) + +structure State := +(map : SMap Name MatcherInfo := {}) + +instance State.inhabited : Inhabited State := +⟨{}⟩ + +def State.addEntry (s : State) (e : Entry) : State := { s with map := s.map.insert e.name e.info } +def State.switch (s : State) : State := { s with map := s.map.switch } + +def mkExtension : IO (SimplePersistentEnvExtension Entry State) := +registerSimplePersistentEnvExtension { + name := `matcher, + addEntryFn := State.addEntry, + addImportedFn := fun es => (mkStateFromImportedEntries State.addEntry {} es).switch +} + +@[init mkExtension] +constant extension : SimplePersistentEnvExtension Entry State := +arbitrary _ + +def addMatcherInfo (env : Environment) (matcherName : Name) (info : MatcherInfo) : Environment := +extension.addEntry env { name := matcherName, info := info } + +def getMatcherInfo? (env : Environment) (declName : Name) : Option MatcherInfo := +(extension.getState env).map.find? declName + +end Extension + +def addMatcherInfo (matcherName : Name) (info : MatcherInfo) : MetaM Unit := +modifyEnv fun env => Extension.addMatcherInfo env matcherName info + +def getMatcherInfo? (declName : Name) : MetaM (Option MatcherInfo) := do +env ← getEnv; +pure $ Extension.getMatcherInfo? env declName + +def isMatcher (declName : Name) : MetaM Bool := do +info? ← getMatcherInfo? declName; +pure info?.isSome + +def mkMatcher (matcherName : Name) (motiveType : Expr) (numDiscrs : Nat) (lhss : List AltLHS) : MetaM MatcherResult := withLocalDeclD `motive motiveType fun motive => do trace! `Meta.Match.debug ("motiveType: " ++ motiveType); forallBoundedTelescope motiveType numDiscrs fun majors _ => do @@ -676,14 +730,15 @@ withAlts motive lhss fun alts minors => do let args := #[motive] ++ majors ++ minors; type ← mkForallFVars args mvarType; val ← mkLambdaFVars args mvar; - trace! `Meta.Match.debug ("eliminator value: " ++ val ++ "\ntype: " ++ type); - elim ← mkAuxDefinition elimName type val; - setInlineAttribute elimName; - trace! `Meta.Match.debug ("eliminator: " ++ elim); + trace! `Meta.Match.debug ("matcher value: " ++ val ++ "\ntype: " ++ type); + matcher ← mkAuxDefinition matcherName type val; + addMatcherInfo matcherName { numParams := matcher.getAppNumArgs, numDiscrs := majors.size, numAlts := minors.size }; + setInlineAttribute matcherName; + trace! `Meta.Match.debug ("matcher: " ++ matcher); let unusedAltIdxs : List Nat := lhss.length.fold (fun i r => if s.used.contains i then r else i::r) []; - pure { elim := elim, counterExamples := s.counterExamples, unusedAltIdxs := unusedAltIdxs.reverse } + pure { matcher := matcher, counterExamples := s.counterExamples, unusedAltIdxs := unusedAltIdxs.reverse } @[init] private def regTraceClasses : IO Unit := do registerTraceClass `Meta.Match.match;