diff --git a/src/Lean/Meta/EqnCompiler/DepElim.lean b/src/Lean/Meta/EqnCompiler/DepElim.lean index f15a14f27a..a163c5c58b 100644 --- a/src/Lean/Meta/EqnCompiler/DepElim.lean +++ b/src/Lean/Meta/EqnCompiler/DepElim.lean @@ -276,19 +276,11 @@ structure ElimResult := (unusedAltIdxs : List Nat) /- The number of patterns in each AltLHS must be equal to majors.length -/ -private def checkNumPatterns (majors : List Expr) (lhss : List AltLHS) : MetaM Unit := -let num := majors.length; +private def checkNumPatterns (majors : Array Expr) (lhss : List AltLHS) : MetaM Unit := +let num := majors.size; when (lhss.any (fun lhs => lhs.patterns.length != num)) $ throwOther "incorrect number of patterns" -/- - Given major premises `(x_1 : A_1) (x_2 : A_2[x_1]) ... (x_n : A_n[x_1, x_2, ...])`, return - `forall (x_1 : A_1) (x_2 : A_2[x_1]) ... (x_n : A_n[x_1, x_2, ...]), sortv` -/ -private def withMotive {α} (majors : Array Expr) (sortv : Expr) (k : Expr → MetaM α) : MetaM α := do -type ← mkForall majors sortv; -trace! `Meta.EqnCompiler.matchDebug ("motive: " ++ type); -withLocalDecl `motive type BinderInfo.default k - private def localDeclsToMVarsAux : List LocalDecl → List MVarId → FVarSubst → MetaM (List MVarId × FVarSubst) | [], mvars, s => pure (mvars.reverse, s) | d::ds, mvars, s => do @@ -710,6 +702,30 @@ s ← majors.foldlM s; pure s.getUnusedLevelParam +def mkElim (elimName : Name) (motiveType : Expr) (lhss : List AltLHS) : MetaM ElimResult := +withLocalDecl `motive motiveType BinderInfo.default fun motive => do +forallTelescopeReducing motiveType fun majors _ => do +checkNumPatterns majors lhss; +let mvarType := mkAppN motive majors; +trace! `Meta.EqnCompiler.matchDebug ("target: " ++ mvarType); +withAlts motive lhss fun alts minors => do + mvar ← mkFreshExprMVar mvarType; + let examples := majors.toList.map fun major => Example.var major.fvarId!; + s ← process { mvarId := mvar.mvarId!, vars := majors.toList, alts := alts, examples := examples } {}; + let args := #[motive] ++ majors ++ minors; + type ← mkForall args mvarType; + val ← mkLambda args mvar; + trace! `Meta.EqnCompiler.matchDebug ("eliminator value: " ++ val ++ "\ntype: " ++ type); + elim ← mkAuxDefinition elimName type val; + setInlineAttribute elimName; + trace! `Meta.EqnCompiler.matchDebug ("eliminator: " ++ elim); + let unusedAltIdxs : List Nat := lhss.length.fold + (fun i r => if s.used.contains i then r else i::r) + []; + pure { elim := elim, counterExamples := s.counterExamples, unusedAltIdxs := unusedAltIdxs.reverse } + +/- Helper methods for testins mkElim -/ + /- Return `Prop` if `inProf == true` and `Sort u` otherwise, where `u` is a fresh universe level parameter. -/ private def mkElimSort (majors : List Expr) (lhss : List AltLHS) (inProp : Bool) : MetaM Expr := if inProp then @@ -718,32 +734,11 @@ else do v ← getUnusedLevelParam majors lhss; pure $ mkSort $ v -def mkElimCore (elimName : Name) (motive : Expr) (majors : List Expr) (lhss : List AltLHS) (inProp : Bool := false) : MetaM ElimResult := do -checkNumPatterns majors lhss; -generalizeTelescope majors.toArray `_d fun majors => do - let mvarType := mkAppN motive majors; - trace! `Meta.EqnCompiler.matchDebug ("target: " ++ mvarType); - withAlts motive lhss fun alts minors => do - mvar ← mkFreshExprMVar mvarType; - let examples := majors.toList.map fun major => Example.var major.fvarId!; - s ← process { mvarId := mvar.mvarId!, vars := majors.toList, alts := alts, examples := examples } {}; - let args := #[motive] ++ majors ++ minors; - type ← mkForall args mvarType; - val ← mkLambda args mvar; - trace! `Meta.EqnCompiler.matchDebug ("eliminator value: " ++ val ++ "\ntype: " ++ type); - elim ← mkAuxDefinition elimName type val; - setInlineAttribute elimName; - trace! `Meta.EqnCompiler.matchDebug ("eliminator: " ++ elim); - let unusedAltIdxs : List Nat := lhss.length.fold - (fun i r => if s.used.contains i then r else i::r) - []; - pure { elim := elim, counterExamples := s.counterExamples, unusedAltIdxs := unusedAltIdxs.reverse } - -def mkElim (elimName : Name) (majors : List Expr) (lhss : List AltLHS) (inProp : Bool := false) : MetaM ElimResult := do +def mkElimTester (elimName : Name) (majors : List Expr) (lhss : List AltLHS) (inProp : Bool := false) : MetaM ElimResult := do sortv ← mkElimSort majors lhss inProp; generalizeTelescope majors.toArray `_d fun majors => do - withMotive majors sortv fun motive => - mkElimCore elimName motive majors.toList lhss inProp + motiveType ← mkForall majors sortv; + mkElim elimName motiveType lhss @[init] private def regTraceClasses : IO Unit := do registerTraceClass `Meta.EqnCompiler.match; diff --git a/tests/lean/run/depElim1.lean b/tests/lean/run/depElim1.lean index a6dfbc0335..c10727b473 100644 --- a/tests/lean/run/depElim1.lean +++ b/tests/lean/run/depElim1.lean @@ -139,7 +139,7 @@ def test (ex : Name) (numPats : Nat) (elimName : Name) (inProp : Bool := false) withDepElimFrom ex numPats fun majors alts => do let majors := majors.map mkFVar; trace! `Meta.debug ("majors: " ++ majors.toArray); - r ← mkElim elimName majors alts inProp; + r ← mkElimTester elimName majors alts inProp; unless r.counterExamples.isEmpty $ throwOther ("missing cases:" ++ Format.line ++ counterExamplesToMessageData r.counterExamples); unless r.unusedAltIdxs.isEmpty $