feat: add MatcherInfo extension

This commit is contained in:
Leonardo de Moura 2020-09-07 15:08:07 -07:00
parent fded18d114
commit aed9b16dc8
2 changed files with 70 additions and 15 deletions

View file

@ -401,7 +401,7 @@ patternVarDecls.foldlM
| _ => pure decls)
#[]
open Meta.Match (Pattern Pattern.var Pattern.inaccessible Pattern.ctor Pattern.as Pattern.val Pattern.arrayLit AltLHS ElimResult)
open Meta.Match (Pattern Pattern.var Pattern.inaccessible Pattern.ctor Pattern.as Pattern.val Pattern.arrayLit AltLHS MatcherResult)
namespace ToDepElimPattern
@ -555,10 +555,10 @@ forallBoundedTelescope matchType numDiscrs fun xs matchType => do
u ← getLevel matchType;
mkForallFVars xs (mkSort u)
def mkMatcher (elimName : Name) (motiveType : Expr) (numDiscrs : Nat) (lhss : List AltLHS) : TermElabM ElimResult :=
def mkMatcher (elimName : Name) (motiveType : Expr) (numDiscrs : Nat) (lhss : List AltLHS) : TermElabM MatcherResult :=
liftMetaM $ Meta.Match.mkMatcher elimName motiveType numDiscrs lhss
def reportElimResultErrors (result : ElimResult) : TermElabM Unit := do
def reportMatcherResultErrors (result : MatcherResult) : TermElabM Unit := do
-- TODO: improve error messages
unless result.counterExamples.isEmpty $
throwError ("missing cases:" ++ Format.line ++ Meta.Match.counterExamplesToMessageData result.counterExamples);
@ -577,10 +577,10 @@ let altLHSS := alts.map Prod.fst;
let numDiscrs := discrs.size;
motiveType ← mkMotiveType matchType numDiscrs;
motive ← forallBoundedTelescope matchType numDiscrs fun xs matchType => mkLambdaFVars xs matchType;
elimName ← mkAuxName `match;
elimResult ← mkMatcher elimName motiveType numDiscrs altLHSS.toList;
reportElimResultErrors elimResult;
let r := mkApp elimResult.elim motive;
matcherName ← mkAuxName `match;
matcherResult ← mkMatcher matcherName motiveType numDiscrs altLHSS.toList;
reportMatcherResultErrors matcherResult;
let r := mkApp matcherResult.matcher motive;
let r := mkAppN r discrs;
let r := mkAppN r rhss;
trace `Elab.match fun _ => "result: " ++ r;

View file

@ -179,8 +179,8 @@ examplesToMessageData cex
def counterExamplesToMessageData (cexs : List CounterExample) : MessageData :=
MessageData.joinSep (cexs.map counterExampleToMessageData) Format.line
structure ElimResult :=
(elim : Expr) -- The eliminator. It is not just `Expr.const elimName` because the type of the major premises may contain free variables.
structure MatcherResult :=
(matcher : Expr) -- The matcher. It is not just `Expr.const matcherName` because the type of the major premises may contain free variables.
(counterExamples : List CounterExample)
(unusedAltIdxs : List Nat)
@ -662,7 +662,61 @@ private partial def process : Problem → StateRefT State MetaM Unit
else
liftM $ throwNonSupported p
def mkMatcher (elimName : Name) (motiveType : Expr) (numDiscrs : Nat) (lhss : List AltLHS) : MetaM ElimResult :=
/--
A "matcher" auxiliary declaration has the following structure:
- `numParams` parameters
- motive
- `numDiscrs` discriminators (aka major premises)
- `numAlts` alternatives (aka minor premises)
-/
structure MatcherInfo :=
(numParams : Nat) (numDiscrs : Nat) (numAlts : Nat)
namespace Extension
structure Entry :=
(name : Name) (info : MatcherInfo)
structure State :=
(map : SMap Name MatcherInfo := {})
instance State.inhabited : Inhabited State :=
⟨{}⟩
def State.addEntry (s : State) (e : Entry) : State := { s with map := s.map.insert e.name e.info }
def State.switch (s : State) : State := { s with map := s.map.switch }
def mkExtension : IO (SimplePersistentEnvExtension Entry State) :=
registerSimplePersistentEnvExtension {
name := `matcher,
addEntryFn := State.addEntry,
addImportedFn := fun es => (mkStateFromImportedEntries State.addEntry {} es).switch
}
@[init mkExtension]
constant extension : SimplePersistentEnvExtension Entry State :=
arbitrary _
def addMatcherInfo (env : Environment) (matcherName : Name) (info : MatcherInfo) : Environment :=
extension.addEntry env { name := matcherName, info := info }
def getMatcherInfo? (env : Environment) (declName : Name) : Option MatcherInfo :=
(extension.getState env).map.find? declName
end Extension
def addMatcherInfo (matcherName : Name) (info : MatcherInfo) : MetaM Unit :=
modifyEnv fun env => Extension.addMatcherInfo env matcherName info
def getMatcherInfo? (declName : Name) : MetaM (Option MatcherInfo) := do
env ← getEnv;
pure $ Extension.getMatcherInfo? env declName
def isMatcher (declName : Name) : MetaM Bool := do
info? ← getMatcherInfo? declName;
pure info?.isSome
def mkMatcher (matcherName : Name) (motiveType : Expr) (numDiscrs : Nat) (lhss : List AltLHS) : MetaM MatcherResult :=
withLocalDeclD `motive motiveType fun motive => do
trace! `Meta.Match.debug ("motiveType: " ++ motiveType);
forallBoundedTelescope motiveType numDiscrs fun majors _ => do
@ -676,14 +730,15 @@ withAlts motive lhss fun alts minors => do
let args := #[motive] ++ majors ++ minors;
type ← mkForallFVars args mvarType;
val ← mkLambdaFVars args mvar;
trace! `Meta.Match.debug ("eliminator value: " ++ val ++ "\ntype: " ++ type);
elim ← mkAuxDefinition elimName type val;
setInlineAttribute elimName;
trace! `Meta.Match.debug ("eliminator: " ++ elim);
trace! `Meta.Match.debug ("matcher value: " ++ val ++ "\ntype: " ++ type);
matcher ← mkAuxDefinition matcherName type val;
addMatcherInfo matcherName { numParams := matcher.getAppNumArgs, numDiscrs := majors.size, numAlts := minors.size };
setInlineAttribute matcherName;
trace! `Meta.Match.debug ("matcher: " ++ matcher);
let unusedAltIdxs : List Nat := lhss.length.fold
(fun i r => if s.used.contains i then r else i::r)
[];
pure { elim := elim, counterExamples := s.counterExamples, unusedAltIdxs := unusedAltIdxs.reverse }
pure { matcher := matcher, counterExamples := s.counterExamples, unusedAltIdxs := unusedAltIdxs.reverse }
@[init] private def regTraceClasses : IO Unit := do
registerTraceClass `Meta.Match.match;