feat: add MatcherInfo extension
This commit is contained in:
parent
fded18d114
commit
aed9b16dc8
2 changed files with 70 additions and 15 deletions
|
|
@ -401,7 +401,7 @@ patternVarDecls.foldlM
|
|||
| _ => pure decls)
|
||||
#[]
|
||||
|
||||
open Meta.Match (Pattern Pattern.var Pattern.inaccessible Pattern.ctor Pattern.as Pattern.val Pattern.arrayLit AltLHS ElimResult)
|
||||
open Meta.Match (Pattern Pattern.var Pattern.inaccessible Pattern.ctor Pattern.as Pattern.val Pattern.arrayLit AltLHS MatcherResult)
|
||||
|
||||
namespace ToDepElimPattern
|
||||
|
||||
|
|
@ -555,10 +555,10 @@ forallBoundedTelescope matchType numDiscrs fun xs matchType => do
|
|||
u ← getLevel matchType;
|
||||
mkForallFVars xs (mkSort u)
|
||||
|
||||
def mkMatcher (elimName : Name) (motiveType : Expr) (numDiscrs : Nat) (lhss : List AltLHS) : TermElabM ElimResult :=
|
||||
def mkMatcher (elimName : Name) (motiveType : Expr) (numDiscrs : Nat) (lhss : List AltLHS) : TermElabM MatcherResult :=
|
||||
liftMetaM $ Meta.Match.mkMatcher elimName motiveType numDiscrs lhss
|
||||
|
||||
def reportElimResultErrors (result : ElimResult) : TermElabM Unit := do
|
||||
def reportMatcherResultErrors (result : MatcherResult) : TermElabM Unit := do
|
||||
-- TODO: improve error messages
|
||||
unless result.counterExamples.isEmpty $
|
||||
throwError ("missing cases:" ++ Format.line ++ Meta.Match.counterExamplesToMessageData result.counterExamples);
|
||||
|
|
@ -577,10 +577,10 @@ let altLHSS := alts.map Prod.fst;
|
|||
let numDiscrs := discrs.size;
|
||||
motiveType ← mkMotiveType matchType numDiscrs;
|
||||
motive ← forallBoundedTelescope matchType numDiscrs fun xs matchType => mkLambdaFVars xs matchType;
|
||||
elimName ← mkAuxName `match;
|
||||
elimResult ← mkMatcher elimName motiveType numDiscrs altLHSS.toList;
|
||||
reportElimResultErrors elimResult;
|
||||
let r := mkApp elimResult.elim motive;
|
||||
matcherName ← mkAuxName `match;
|
||||
matcherResult ← mkMatcher matcherName motiveType numDiscrs altLHSS.toList;
|
||||
reportMatcherResultErrors matcherResult;
|
||||
let r := mkApp matcherResult.matcher motive;
|
||||
let r := mkAppN r discrs;
|
||||
let r := mkAppN r rhss;
|
||||
trace `Elab.match fun _ => "result: " ++ r;
|
||||
|
|
|
|||
|
|
@ -179,8 +179,8 @@ examplesToMessageData cex
|
|||
def counterExamplesToMessageData (cexs : List CounterExample) : MessageData :=
|
||||
MessageData.joinSep (cexs.map counterExampleToMessageData) Format.line
|
||||
|
||||
structure ElimResult :=
|
||||
(elim : Expr) -- The eliminator. It is not just `Expr.const elimName` because the type of the major premises may contain free variables.
|
||||
structure MatcherResult :=
|
||||
(matcher : Expr) -- The matcher. It is not just `Expr.const matcherName` because the type of the major premises may contain free variables.
|
||||
(counterExamples : List CounterExample)
|
||||
(unusedAltIdxs : List Nat)
|
||||
|
||||
|
|
@ -662,7 +662,61 @@ private partial def process : Problem → StateRefT State MetaM Unit
|
|||
else
|
||||
liftM $ throwNonSupported p
|
||||
|
||||
def mkMatcher (elimName : Name) (motiveType : Expr) (numDiscrs : Nat) (lhss : List AltLHS) : MetaM ElimResult :=
|
||||
/--
|
||||
A "matcher" auxiliary declaration has the following structure:
|
||||
- `numParams` parameters
|
||||
- motive
|
||||
- `numDiscrs` discriminators (aka major premises)
|
||||
- `numAlts` alternatives (aka minor premises)
|
||||
-/
|
||||
structure MatcherInfo :=
|
||||
(numParams : Nat) (numDiscrs : Nat) (numAlts : Nat)
|
||||
|
||||
namespace Extension
|
||||
|
||||
structure Entry :=
|
||||
(name : Name) (info : MatcherInfo)
|
||||
|
||||
structure State :=
|
||||
(map : SMap Name MatcherInfo := {})
|
||||
|
||||
instance State.inhabited : Inhabited State :=
|
||||
⟨{}⟩
|
||||
|
||||
def State.addEntry (s : State) (e : Entry) : State := { s with map := s.map.insert e.name e.info }
|
||||
def State.switch (s : State) : State := { s with map := s.map.switch }
|
||||
|
||||
def mkExtension : IO (SimplePersistentEnvExtension Entry State) :=
|
||||
registerSimplePersistentEnvExtension {
|
||||
name := `matcher,
|
||||
addEntryFn := State.addEntry,
|
||||
addImportedFn := fun es => (mkStateFromImportedEntries State.addEntry {} es).switch
|
||||
}
|
||||
|
||||
@[init mkExtension]
|
||||
constant extension : SimplePersistentEnvExtension Entry State :=
|
||||
arbitrary _
|
||||
|
||||
def addMatcherInfo (env : Environment) (matcherName : Name) (info : MatcherInfo) : Environment :=
|
||||
extension.addEntry env { name := matcherName, info := info }
|
||||
|
||||
def getMatcherInfo? (env : Environment) (declName : Name) : Option MatcherInfo :=
|
||||
(extension.getState env).map.find? declName
|
||||
|
||||
end Extension
|
||||
|
||||
def addMatcherInfo (matcherName : Name) (info : MatcherInfo) : MetaM Unit :=
|
||||
modifyEnv fun env => Extension.addMatcherInfo env matcherName info
|
||||
|
||||
def getMatcherInfo? (declName : Name) : MetaM (Option MatcherInfo) := do
|
||||
env ← getEnv;
|
||||
pure $ Extension.getMatcherInfo? env declName
|
||||
|
||||
def isMatcher (declName : Name) : MetaM Bool := do
|
||||
info? ← getMatcherInfo? declName;
|
||||
pure info?.isSome
|
||||
|
||||
def mkMatcher (matcherName : Name) (motiveType : Expr) (numDiscrs : Nat) (lhss : List AltLHS) : MetaM MatcherResult :=
|
||||
withLocalDeclD `motive motiveType fun motive => do
|
||||
trace! `Meta.Match.debug ("motiveType: " ++ motiveType);
|
||||
forallBoundedTelescope motiveType numDiscrs fun majors _ => do
|
||||
|
|
@ -676,14 +730,15 @@ withAlts motive lhss fun alts minors => do
|
|||
let args := #[motive] ++ majors ++ minors;
|
||||
type ← mkForallFVars args mvarType;
|
||||
val ← mkLambdaFVars args mvar;
|
||||
trace! `Meta.Match.debug ("eliminator value: " ++ val ++ "\ntype: " ++ type);
|
||||
elim ← mkAuxDefinition elimName type val;
|
||||
setInlineAttribute elimName;
|
||||
trace! `Meta.Match.debug ("eliminator: " ++ elim);
|
||||
trace! `Meta.Match.debug ("matcher value: " ++ val ++ "\ntype: " ++ type);
|
||||
matcher ← mkAuxDefinition matcherName type val;
|
||||
addMatcherInfo matcherName { numParams := matcher.getAppNumArgs, numDiscrs := majors.size, numAlts := minors.size };
|
||||
setInlineAttribute matcherName;
|
||||
trace! `Meta.Match.debug ("matcher: " ++ matcher);
|
||||
let unusedAltIdxs : List Nat := lhss.length.fold
|
||||
(fun i r => if s.used.contains i then r else i::r)
|
||||
[];
|
||||
pure { elim := elim, counterExamples := s.counterExamples, unusedAltIdxs := unusedAltIdxs.reverse }
|
||||
pure { matcher := matcher, counterExamples := s.counterExamples, unusedAltIdxs := unusedAltIdxs.reverse }
|
||||
|
||||
@[init] private def regTraceClasses : IO Unit := do
|
||||
registerTraceClass `Meta.Match.match;
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue