feat: add Meta.EqnCompiler trace class
This commit is contained in:
parent
7342cab0e5
commit
d09eb82c4c
2 changed files with 47 additions and 17 deletions
|
|
@ -5,3 +5,10 @@ Authors: Leonardo de Moura
|
|||
-/
|
||||
import Lean.Meta.EqnCompiler.MatchPattern
|
||||
import Lean.Meta.EqnCompiler.DepElim
|
||||
|
||||
namespace Lean
|
||||
|
||||
@[init] private def regTraceClasses : IO Unit :=
|
||||
registerTraceClass `Meta.EqnCompiler
|
||||
|
||||
end Lean
|
||||
|
|
|
|||
|
|
@ -199,6 +199,9 @@ partial def toMessageData : Example → MessageData
|
|||
|
||||
end Example
|
||||
|
||||
def examplesToMessageData (cex : List Example) : MessageData :=
|
||||
MessageData.joinSep (cex.map (Example.toMessageData ∘ Example.varsToUnderscore)) ", "
|
||||
|
||||
structure Problem :=
|
||||
(goal : Expr)
|
||||
(vars : List Expr)
|
||||
|
|
@ -218,13 +221,13 @@ withGoalOf p do
|
|||
pure $ "vars " ++ p.vars.toArray
|
||||
-- ++ Format.line ++ "var ids " ++ toString (p.vars.map (fun x => match x with | Expr.fvar id _ => toString id | _ => "[nonvar]"))
|
||||
++ Format.line ++ MessageData.joinSep alts Format.line
|
||||
|
||||
++ Format.line ++ "examples: " ++ examplesToMessageData p.examples
|
||||
end Problem
|
||||
|
||||
abbrev CounterExample := List Example
|
||||
|
||||
def counterExampleToMessageData (cex : CounterExample) : MessageData :=
|
||||
MessageData.joinSep (cex.map (Example.toMessageData ∘ Example.varsToUnderscore)) ", "
|
||||
examplesToMessageData cex
|
||||
|
||||
def counterExamplesToMessageData (cexs : List CounterExample) : MessageData :=
|
||||
MessageData.joinSep (cexs.map counterExampleToMessageData) Format.line
|
||||
|
|
@ -245,7 +248,7 @@ when (lhss.any (fun lhs => lhs.patterns.length != num)) $
|
|||
`forall (x_1 : A_1) (x_2 : A_2[x_1]) ... (x_n : A_n[x_1, x_2, ...]), sortv` -/
|
||||
private def withMotive {α} (majors : Array Expr) (sortv : Expr) (k : Expr → MetaM α) : MetaM α := do
|
||||
type ← mkForall majors sortv;
|
||||
trace! `Meta.debug ("motive: " ++ type);
|
||||
trace! `Meta.EqnCompiler.matchDebug ("motive: " ++ type);
|
||||
withLocalDecl `motive type BinderInfo.default k
|
||||
|
||||
private def localDeclsToMVarsAux : List LocalDecl → List MVarId → FVarSubst → MetaM (List MVarId × FVarSubst)
|
||||
|
|
@ -276,7 +279,7 @@ private partial def withAltsAux {α} (motive : Expr) : List AltLHS → List Alt
|
|||
};
|
||||
let idx := alts.length;
|
||||
let minorName := (`h).appendIndexAfter (idx+1);
|
||||
trace! `Meta.debug ("minor premise " ++ minorName ++ " : " ++ minorType);
|
||||
trace! `Meta.EqnCompiler.matchDebug ("minor premise " ++ minorName ++ " : " ++ minorType);
|
||||
withLocalDecl minorName minorType BinderInfo.default fun minor => do
|
||||
let rhs := mkAppN minor xs;
|
||||
let minors := minors.push minor;
|
||||
|
|
@ -332,8 +335,20 @@ let (ok, hasVar, hasCtor) := p.alts.foldl
|
|||
(true, false, false);
|
||||
ok && hasVar && hasCtor
|
||||
|
||||
/- Return true if the next pattern of the remaining alternatives contain variables AND values. -/
|
||||
private def isValueTransition (p : Problem) : Bool :=
|
||||
let (ok, hasVar, hasVal) := p.alts.foldl
|
||||
(fun (acc : Bool × Bool × Bool) (alt : Alt) =>
|
||||
let (ok, hasVar, hasVal) := acc;
|
||||
match alt.patterns with
|
||||
| Pattern.val _ _ :: _ => (ok, hasVar, true)
|
||||
| Pattern.var _ _ :: _ => (ok, true, hasVal)
|
||||
| _ => (false, hasVar, hasVal))
|
||||
(true, false, false);
|
||||
ok && hasVar && hasVal
|
||||
|
||||
private def processNonVariable (process : Problem → State → MetaM State) (p : Problem) (s : State) : MetaM State := do
|
||||
trace! `Meta.debug ("process non variable");
|
||||
trace! `Meta.EqnCompiler.match ("process non variable");
|
||||
match p.vars with
|
||||
| x :: xs =>
|
||||
let alts := p.alts.map fun alt => match alt.patterns with
|
||||
|
|
@ -353,19 +368,19 @@ match p.alts with
|
|||
pure { s with used := s.used.insert alt.idx }
|
||||
|
||||
private def processVariable (process : Problem → State → MetaM State) (p : Problem) (s : State) : MetaM State := do
|
||||
trace! `Meta.debug ("process variable");
|
||||
trace! `Meta.EqnCompiler.match ("process variable");
|
||||
match p.vars with
|
||||
| x :: xs => do
|
||||
alts ← p.alts.mapM fun alt => match alt.patterns with
|
||||
| Pattern.inaccessible _ _ :: ps => pure { alt with patterns := ps }
|
||||
| Pattern.var _ mvarId :: ps => do
|
||||
-- trace! `Meta.debug (">> assign " ++ mkMVar mvarId ++ " := " ++ x);
|
||||
-- trace! `Meta.EqnCompiler.matchDebug (">> assign " ++ mkMVar mvarId ++ " := " ++ x);
|
||||
assignExprMVar mvarId x;
|
||||
rhs ← instantiateMVars alt.rhs;
|
||||
let mvars := alt.mvars.erase mvarId;
|
||||
-- trace! `Meta.debug (">> patterns before assignment: " ++ MessageData.ofList (ps.map Pattern.toMessageData));
|
||||
-- trace! `Meta.EqnCompiler.matchDebug (">> patterns before assignment: " ++ MessageData.ofList (ps.map Pattern.toMessageData));
|
||||
patterns ← ps.mapM fun p => p.instantiateMVars;
|
||||
-- trace! `Meta.debug (">> patterns after assignment: " ++ MessageData.ofList (patterns.map Pattern.toMessageData));
|
||||
-- trace! `Meta.EqnCompiler.matchDebug (">> patterns after assignment: " ++ MessageData.ofList (patterns.map Pattern.toMessageData));
|
||||
pure { alt with patterns := patterns, rhs := rhs, mvars := mvars }
|
||||
| _ => unreachable!;
|
||||
process { p with alts := alts, vars := xs } s
|
||||
|
|
@ -377,7 +392,7 @@ match alt.patterns with
|
|||
| _ => false
|
||||
|
||||
private def processConstructor (process : Problem → State → MetaM State) (p : Problem) (s : State) : MetaM State := do
|
||||
trace! `Meta.debug ("process constructor");
|
||||
trace! `Meta.EqnCompiler.match ("process constructor");
|
||||
match p.vars with
|
||||
| x :: xs => do
|
||||
subgoals ← cases p.goal.mvarId! x.fvarId!;
|
||||
|
|
@ -432,7 +447,7 @@ private def tryConstructor? (alt : Alt) (ref : Syntax) (mvarId : MVarId) (ctorNa
|
|||
let ctor := mkAppN (mkConst ctorName us) params;
|
||||
ctorType ← inferType ctor;
|
||||
(mvars, _, resultType) ← forallMetaTelescopeReducing ctorType;
|
||||
trace! `Meta.debug ("ctorName: " ++ ctorName ++ ", resultType: " ++ resultType ++ ", expectedType: " ++ expectedType);
|
||||
trace! `Meta.EqnCompiler.matchDebug ("ctorName: " ++ ctorName ++ ", resultType: " ++ resultType ++ ", expectedType: " ++ expectedType);
|
||||
condM (isDefEq resultType expectedType)
|
||||
(Option.some <$> tryConstructorAux alt ref mvarId ctorName us params mvars 0 #[] #[])
|
||||
(pure none)
|
||||
|
|
@ -462,9 +477,8 @@ matchConst env expectedType.getAppFn (fun _ => throwInductiveTypeExpected expect
|
|||
[]
|
||||
| _ => throwInductiveTypeExpected expectedType
|
||||
|
||||
/- Auxiliary method for `processComplete` -/
|
||||
private def processComplete (process : Problem → State → MetaM State) (p : Problem) (s : State) : MetaM State := do
|
||||
trace! `Meta.debug ("process complete");
|
||||
trace! `Meta.EqnCompiler.match ("process complete");
|
||||
withGoalOf p do
|
||||
env ← getEnv;
|
||||
newAlts ← p.alts.foldlM
|
||||
|
|
@ -479,9 +493,12 @@ newAlts ← p.alts.foldlM
|
|||
[];
|
||||
process { p with alts := newAlts.reverse } s
|
||||
|
||||
private def processValue (process : Problem → State → MetaM State) (p : Problem) (s : State) : MetaM State := do
|
||||
throwOther "WIP"
|
||||
|
||||
private partial def process : Problem → State → MetaM State
|
||||
| p, s => withIncRecDepth do
|
||||
withGoalOf p (traceM `Meta.debug p.toMessageData);
|
||||
withGoalOf p (traceM `Meta.EqnCompiler.match p.toMessageData);
|
||||
if isDone p then
|
||||
processLeaf p s
|
||||
else if !isNextVar p then
|
||||
|
|
@ -492,6 +509,8 @@ private partial def process : Problem → State → MetaM State
|
|||
processConstructor process p s
|
||||
else if isCompleteTransition p then
|
||||
processComplete process p s
|
||||
else if isValueTransition p then
|
||||
processValue process p s
|
||||
else do
|
||||
msg ← p.toMessageData;
|
||||
-- TODO: remaining cases
|
||||
|
|
@ -523,7 +542,7 @@ sortv ← mkElimSort majors lhss inProp;
|
|||
generalizeTelescope majors.toArray `_d fun majors => do
|
||||
withMotive majors sortv fun motive => do
|
||||
let target := mkAppN motive majors;
|
||||
trace! `Meta.debug ("target: " ++ target);
|
||||
trace! `Meta.EqnCompiler.matchDebug ("target: " ++ target);
|
||||
withAlts motive lhss fun alts minors => do
|
||||
goal ← mkFreshExprMVar target;
|
||||
let examples := majors.toList.map fun major => Example.var major.fvarId!;
|
||||
|
|
@ -531,14 +550,18 @@ withAlts motive lhss fun alts minors => do
|
|||
let args := #[motive] ++ majors ++ minors;
|
||||
type ← mkForall args target;
|
||||
val ← mkLambda args goal;
|
||||
trace! `Meta.debug ("eliminator value: " ++ val ++ "\ntype: " ++ type);
|
||||
trace! `Meta.EqnCompiler.matchDebug ("eliminator value: " ++ val ++ "\ntype: " ++ type);
|
||||
elim ← mkAuxDefinition elimName type val;
|
||||
trace! `Meta.debug ("eliminator: " ++ elim);
|
||||
trace! `Meta.EqnCompiler.matchDebug ("eliminator: " ++ elim);
|
||||
let unusedAltIdxs : List Nat := lhss.length.fold
|
||||
(fun i r => if s.used.contains i then r else i::r)
|
||||
[];
|
||||
pure { elim := elim, counterExamples := s.counterExamples, unusedAltIdxs := unusedAltIdxs.reverse }
|
||||
|
||||
@[init] private def regTraceClasses : IO Unit := do
|
||||
registerTraceClass `Meta.EqnCompiler.match;
|
||||
registerTraceClass `Meta.EqnCompiler.matchDebug
|
||||
|
||||
end DepElim
|
||||
end Meta
|
||||
end Lean
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue