diff --git a/src/Lean/Meta/EqnCompiler.lean b/src/Lean/Meta/EqnCompiler.lean index d1e1e2171c..ae15f4417e 100644 --- a/src/Lean/Meta/EqnCompiler.lean +++ b/src/Lean/Meta/EqnCompiler.lean @@ -5,3 +5,10 @@ Authors: Leonardo de Moura -/ import Lean.Meta.EqnCompiler.MatchPattern import Lean.Meta.EqnCompiler.DepElim + +namespace Lean + +@[init] private def regTraceClasses : IO Unit := +registerTraceClass `Meta.EqnCompiler + +end Lean diff --git a/src/Lean/Meta/EqnCompiler/DepElim.lean b/src/Lean/Meta/EqnCompiler/DepElim.lean index 05119c42cb..d5857cac39 100644 --- a/src/Lean/Meta/EqnCompiler/DepElim.lean +++ b/src/Lean/Meta/EqnCompiler/DepElim.lean @@ -199,6 +199,9 @@ partial def toMessageData : Example → MessageData end Example +def examplesToMessageData (cex : List Example) : MessageData := +MessageData.joinSep (cex.map (Example.toMessageData ∘ Example.varsToUnderscore)) ", " + structure Problem := (goal : Expr) (vars : List Expr) @@ -218,13 +221,13 @@ withGoalOf p do pure $ "vars " ++ p.vars.toArray -- ++ Format.line ++ "var ids " ++ toString (p.vars.map (fun x => match x with | Expr.fvar id _ => toString id | _ => "[nonvar]")) ++ Format.line ++ MessageData.joinSep alts Format.line - + ++ Format.line ++ "examples: " ++ examplesToMessageData p.examples end Problem abbrev CounterExample := List Example def counterExampleToMessageData (cex : CounterExample) : MessageData := -MessageData.joinSep (cex.map (Example.toMessageData ∘ Example.varsToUnderscore)) ", " +examplesToMessageData cex def counterExamplesToMessageData (cexs : List CounterExample) : MessageData := MessageData.joinSep (cexs.map counterExampleToMessageData) Format.line @@ -245,7 +248,7 @@ when (lhss.any (fun lhs => lhs.patterns.length != num)) $ `forall (x_1 : A_1) (x_2 : A_2[x_1]) ... (x_n : A_n[x_1, x_2, ...]), sortv` -/ private def withMotive {α} (majors : Array Expr) (sortv : Expr) (k : Expr → MetaM α) : MetaM α := do type ← mkForall majors sortv; -trace! `Meta.debug ("motive: " ++ type); +trace! `Meta.EqnCompiler.matchDebug ("motive: " ++ type); withLocalDecl `motive type BinderInfo.default k private def localDeclsToMVarsAux : List LocalDecl → List MVarId → FVarSubst → MetaM (List MVarId × FVarSubst) @@ -276,7 +279,7 @@ private partial def withAltsAux {α} (motive : Expr) : List AltLHS → List Alt }; let idx := alts.length; let minorName := (`h).appendIndexAfter (idx+1); - trace! `Meta.debug ("minor premise " ++ minorName ++ " : " ++ minorType); + trace! `Meta.EqnCompiler.matchDebug ("minor premise " ++ minorName ++ " : " ++ minorType); withLocalDecl minorName minorType BinderInfo.default fun minor => do let rhs := mkAppN minor xs; let minors := minors.push minor; @@ -332,8 +335,20 @@ let (ok, hasVar, hasCtor) := p.alts.foldl (true, false, false); ok && hasVar && hasCtor +/- Return true if the next pattern of the remaining alternatives contain variables AND values. -/ +private def isValueTransition (p : Problem) : Bool := +let (ok, hasVar, hasVal) := p.alts.foldl + (fun (acc : Bool × Bool × Bool) (alt : Alt) => + let (ok, hasVar, hasVal) := acc; + match alt.patterns with + | Pattern.val _ _ :: _ => (ok, hasVar, true) + | Pattern.var _ _ :: _ => (ok, true, hasVal) + | _ => (false, hasVar, hasVal)) + (true, false, false); +ok && hasVar && hasVal + private def processNonVariable (process : Problem → State → MetaM State) (p : Problem) (s : State) : MetaM State := do -trace! `Meta.debug ("process non variable"); +trace! `Meta.EqnCompiler.match ("process non variable"); match p.vars with | x :: xs => let alts := p.alts.map fun alt => match alt.patterns with @@ -353,19 +368,19 @@ match p.alts with pure { s with used := s.used.insert alt.idx } private def processVariable (process : Problem → State → MetaM State) (p : Problem) (s : State) : MetaM State := do -trace! `Meta.debug ("process variable"); +trace! `Meta.EqnCompiler.match ("process variable"); match p.vars with | x :: xs => do alts ← p.alts.mapM fun alt => match alt.patterns with | Pattern.inaccessible _ _ :: ps => pure { alt with patterns := ps } | Pattern.var _ mvarId :: ps => do - -- trace! `Meta.debug (">> assign " ++ mkMVar mvarId ++ " := " ++ x); + -- trace! `Meta.EqnCompiler.matchDebug (">> assign " ++ mkMVar mvarId ++ " := " ++ x); assignExprMVar mvarId x; rhs ← instantiateMVars alt.rhs; let mvars := alt.mvars.erase mvarId; - -- trace! `Meta.debug (">> patterns before assignment: " ++ MessageData.ofList (ps.map Pattern.toMessageData)); + -- trace! `Meta.EqnCompiler.matchDebug (">> patterns before assignment: " ++ MessageData.ofList (ps.map Pattern.toMessageData)); patterns ← ps.mapM fun p => p.instantiateMVars; - -- trace! `Meta.debug (">> patterns after assignment: " ++ MessageData.ofList (patterns.map Pattern.toMessageData)); + -- trace! `Meta.EqnCompiler.matchDebug (">> patterns after assignment: " ++ MessageData.ofList (patterns.map Pattern.toMessageData)); pure { alt with patterns := patterns, rhs := rhs, mvars := mvars } | _ => unreachable!; process { p with alts := alts, vars := xs } s @@ -377,7 +392,7 @@ match alt.patterns with | _ => false private def processConstructor (process : Problem → State → MetaM State) (p : Problem) (s : State) : MetaM State := do -trace! `Meta.debug ("process constructor"); +trace! `Meta.EqnCompiler.match ("process constructor"); match p.vars with | x :: xs => do subgoals ← cases p.goal.mvarId! x.fvarId!; @@ -432,7 +447,7 @@ private def tryConstructor? (alt : Alt) (ref : Syntax) (mvarId : MVarId) (ctorNa let ctor := mkAppN (mkConst ctorName us) params; ctorType ← inferType ctor; (mvars, _, resultType) ← forallMetaTelescopeReducing ctorType; -trace! `Meta.debug ("ctorName: " ++ ctorName ++ ", resultType: " ++ resultType ++ ", expectedType: " ++ expectedType); +trace! `Meta.EqnCompiler.matchDebug ("ctorName: " ++ ctorName ++ ", resultType: " ++ resultType ++ ", expectedType: " ++ expectedType); condM (isDefEq resultType expectedType) (Option.some <$> tryConstructorAux alt ref mvarId ctorName us params mvars 0 #[] #[]) (pure none) @@ -462,9 +477,8 @@ matchConst env expectedType.getAppFn (fun _ => throwInductiveTypeExpected expect [] | _ => throwInductiveTypeExpected expectedType -/- Auxiliary method for `processComplete` -/ private def processComplete (process : Problem → State → MetaM State) (p : Problem) (s : State) : MetaM State := do -trace! `Meta.debug ("process complete"); +trace! `Meta.EqnCompiler.match ("process complete"); withGoalOf p do env ← getEnv; newAlts ← p.alts.foldlM @@ -479,9 +493,12 @@ newAlts ← p.alts.foldlM []; process { p with alts := newAlts.reverse } s +private def processValue (process : Problem → State → MetaM State) (p : Problem) (s : State) : MetaM State := do +throwOther "WIP" + private partial def process : Problem → State → MetaM State | p, s => withIncRecDepth do - withGoalOf p (traceM `Meta.debug p.toMessageData); + withGoalOf p (traceM `Meta.EqnCompiler.match p.toMessageData); if isDone p then processLeaf p s else if !isNextVar p then @@ -492,6 +509,8 @@ private partial def process : Problem → State → MetaM State processConstructor process p s else if isCompleteTransition p then processComplete process p s + else if isValueTransition p then + processValue process p s else do msg ← p.toMessageData; -- TODO: remaining cases @@ -523,7 +542,7 @@ sortv ← mkElimSort majors lhss inProp; generalizeTelescope majors.toArray `_d fun majors => do withMotive majors sortv fun motive => do let target := mkAppN motive majors; -trace! `Meta.debug ("target: " ++ target); +trace! `Meta.EqnCompiler.matchDebug ("target: " ++ target); withAlts motive lhss fun alts minors => do goal ← mkFreshExprMVar target; let examples := majors.toList.map fun major => Example.var major.fvarId!; @@ -531,14 +550,18 @@ withAlts motive lhss fun alts minors => do let args := #[motive] ++ majors ++ minors; type ← mkForall args target; val ← mkLambda args goal; - trace! `Meta.debug ("eliminator value: " ++ val ++ "\ntype: " ++ type); + trace! `Meta.EqnCompiler.matchDebug ("eliminator value: " ++ val ++ "\ntype: " ++ type); elim ← mkAuxDefinition elimName type val; - trace! `Meta.debug ("eliminator: " ++ elim); + trace! `Meta.EqnCompiler.matchDebug ("eliminator: " ++ elim); let unusedAltIdxs : List Nat := lhss.length.fold (fun i r => if s.used.contains i then r else i::r) []; pure { elim := elim, counterExamples := s.counterExamples, unusedAltIdxs := unusedAltIdxs.reverse } +@[init] private def regTraceClasses : IO Unit := do +registerTraceClass `Meta.EqnCompiler.match; +registerTraceClass `Meta.EqnCompiler.matchDebug + end DepElim end Meta end Lean