From 83eaa47e0aa4896ec9dd6a5c0752cbbc95323cb9 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Tue, 17 Aug 2021 21:30:35 -0700 Subject: [PATCH] chore: move `MatchEqs` --- src/Lean/Meta/Match.lean | 1 + src/Lean/Meta/Match/MatchEqs.lean | 129 ++++++++++++++++++++++++++++ tests/playground/matchEqs.lean | 138 +++++------------------------- 3 files changed, 150 insertions(+), 118 deletions(-) create mode 100644 src/Lean/Meta/Match/MatchEqs.lean diff --git a/src/Lean/Meta/Match.lean b/src/Lean/Meta/Match.lean index d6d356dc47..676da5b6b7 100644 --- a/src/Lean/Meta/Match.lean +++ b/src/Lean/Meta/Match.lean @@ -7,6 +7,7 @@ import Lean.Meta.Match.MatchPatternAttr import Lean.Meta.Match.Match import Lean.Meta.Match.CaseValues import Lean.Meta.Match.CaseArraySizes +import Lean.Meta.Match.MatchEqs namespace Lean diff --git a/src/Lean/Meta/Match/MatchEqs.lean b/src/Lean/Meta/Match/MatchEqs.lean new file mode 100644 index 0000000000..2a5f0b0e2b --- /dev/null +++ b/src/Lean/Meta/Match/MatchEqs.lean @@ -0,0 +1,129 @@ +/- +Copyright (c) 2021 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +import Lean.Meta.Match.Match +import Lean.Meta.Tactic.Apply +import Lean.Meta.Tactic.Delta +import Lean.Meta.Tactic.SplitIf + +namespace Lean.Meta.Match + +private def isMatchValue (e : Expr) : Bool := + e.isNatLit || e.isCharLit || e.isStringLit + +partial def mkEquationsFor (matchDeclName : Name) : MetaM Unit := do + let constInfo ← getConstInfo matchDeclName + let us := constInfo.levelParams.map mkLevelParam + let some matchInfo ← getMatcherInfo? matchDeclName | throwError "'{matchDeclName}' is not a matcher function" + forallTelescopeReducing constInfo.type fun xs _ => do + let params := xs[:matchInfo.numParams] + let motive := xs[matchInfo.getMotivePos] + let alts := xs[xs.size - matchInfo.numAlts:] + let firstDiscrIdx := matchInfo.numParams + 1 + let discrs := xs[firstDiscrIdx : firstDiscrIdx + matchInfo.numDiscrs] + let mut notAlts := #[] + for alt in alts do + let altType ← inferType alt + trace[Meta.debug] ">> {altType}" + notAlts ← forallTelescopeReducing altType fun ys altResultType => do + let (ys, rhsArgs) ← toFVarsRHSArgs ys + let patterns := altResultType.getAppArgs + let mut hs := #[] + for notAlt in notAlts do + hs := hs.push (← instantiateForall notAlt patterns) + hs ← simpHs hs patterns.size + trace[Meta.debug] "hs: {hs}" + -- Create a proposition for representing terms that do not match `patterns` + let mut notAlt := mkConst ``False + for discr in discrs.toArray.reverse, pattern in patterns.reverse do + notAlt ← mkArrow (← mkEq discr pattern) notAlt + notAlt ← mkForallFVars (discrs ++ ys) notAlt + trace[Meta.debug] "notAlt: {notAlt}" + let lhs := mkAppN (mkConst constInfo.name us) (params ++ #[motive] ++ patterns ++ alts) + let rhs := mkAppN alt rhsArgs + let thmType ← mkEq lhs rhs + let thmType ← hs.foldrM (init := thmType) mkArrow + let thmType ← mkForallFVars (params ++ #[motive] ++ alts ++ ys) thmType + let thmVal ← prove thmType + trace[Meta.debug] "thmVal: {thmVal}" + -- check thmVal -- TODO remove + return notAlts.push notAlt +where + toFVarsRHSArgs (ys : Array Expr) : MetaM (Array Expr × Array Expr) := do + if ys.size == 1 && (← inferType ys[0]).isConstOf ``Unit then + return (#[], #[mkConst ``Unit.unit]) + else + return (ys, ys) + + simpEq (lhs : Expr) (rhs : Expr) : OptionT (StateRefT (Array Expr) MetaM) Unit := do + if isMatchValue lhs && isMatchValue rhs then + unless (← isDefEq lhs rhs) do + failure + else if rhs.isFVar then + -- Ignore case since it matches anything + pure () + else match lhs.arrayLit?, rhs.arrayLit? with + | some (_, lhsArgs), some (_, rhsArgs) => + if lhsArgs.length != rhsArgs.length then + failure + else + for lhsArg in lhsArgs, rhsArg in rhsArgs do + simpEq lhsArg rhsArg + | _, _ => + match toCtorIfLit lhs |>.constructorApp? (← getEnv), toCtorIfLit rhs |>.constructorApp? (← getEnv) with + | some (lhsCtor, lhsArgs), some (rhsCtor, rhsArgs) => + if lhsCtor.name == rhsCtor.name then + for lhsArg in lhsArgs[lhsCtor.numParams:], rhsArg in rhsArgs[lhsCtor.numParams:] do + simpEq lhsArg rhsArg + else + failure + | _, _ => + let newEq ← mkEq lhs rhs + modify fun eqs => eqs.push newEq + + simpEqs (eqs : Array Expr) : OptionT (StateRefT (Array Expr) MetaM) Unit := do + eqs.forM fun eq => + match eq.eq? with + | some (_, lhs, rhs) => simpEq lhs rhs + | _ => throwError "failed to generate equality theorems for 'match', equality expected{indentExpr eq}" + + simpHs (hs : Array Expr) (numPatterns : Nat) : MetaM (Array Expr) := + hs.filterMapM fun h => forallTelescope h fun ys _ => do + trace[Meta.debug] "ys: {ys}" + let xs := ys[:ys.size - numPatterns].toArray + let eqs ← ys[ys.size - numPatterns : ys.size].toArray.mapM inferType + if let some eqsNew ← simpEqs eqs *> get |>.run |>.run' #[] then + let newH ← eqsNew.foldrM (init := mkConst ``False) mkArrow + let xs ← xs.filterM fun x => dependsOn newH x.fvarId! + return some (← mkForallFVars xs newH) + else + none + + failed : MetaM Unit := throwError "" -- TODO + + proveLoop (mvarId : MVarId) (depth : Nat) : MetaM Unit := withIncRecDepth do + let mvarId ← modifyTargetEqLHS mvarId whnfCore + (applyRefl mvarId) + <|> + (contradiction mvarId) + <|> + (commitIfNoEx do + let s::ss ← splitIfGoal mvarId | failed + if ss.isEmpty && s.mvarId == mvarId then failed + (s::ss).forM fun s => proveLoop s.mvarId (depth + 1)) + <|> + (do + trace[Meta.debug] "TODO\n{← ppGoal mvarId}" + -- TODO + admit mvarId) + + prove (type : Expr) : MetaM Expr := + withLCtx {} {} <| forallTelescope type fun ys target => do + let mvar0 ← mkFreshExprSyntheticOpaqueMVar target + let mvarId ← deltaTarget mvar0.mvarId! (. == matchDeclName) + proveLoop mvarId 0 + mkLambdaFVars ys (← instantiateMVars mvar0) + +end Lean.Meta.Match diff --git a/tests/playground/matchEqs.lean b/tests/playground/matchEqs.lean index a96ddc14d4..6afa8ba4a2 100644 --- a/tests/playground/matchEqs.lean +++ b/tests/playground/matchEqs.lean @@ -1,113 +1,14 @@ import Lean +syntax (name := test) "test%" ident : command -namespace Lean.Meta.Match +open Lean.Elab +open Lean.Elab.Command -private def isMatchValue (e : Expr) : Bool := - e.isNatLit || e.isCharLit || e.isStringLit - -partial def mkEquationsFor (matchDeclName : Name) : MetaM Unit := do - let constInfo ← getConstInfo matchDeclName - let us := constInfo.levelParams.map mkLevelParam - let some matchInfo ← getMatcherInfo? matchDeclName | throwError "'{matchDeclName}' is not a matcher function" - forallTelescopeReducing constInfo.type fun xs _ => do - let params := xs[:matchInfo.numParams] - let motive := xs[matchInfo.getMotivePos] - let alts := xs[xs.size - matchInfo.numAlts:] - let firstDiscrIdx := matchInfo.numParams + 1 - let discrs := xs[firstDiscrIdx : firstDiscrIdx + matchInfo.numDiscrs] - let mut notAlts := #[] - for alt in alts do - let altType ← inferType alt - trace[Meta.debug] ">> {altType}" - notAlts ← forallTelescopeReducing altType fun ys altResultType => do - let (ys, rhsArgs) ← toFVarsRHSArgs ys - let patterns := altResultType.getAppArgs - let mut hs := #[] - for notAlt in notAlts do - hs := hs.push (← instantiateForall notAlt patterns) - hs ← simpHs hs patterns.size - trace[Meta.debug] "hs: {hs}" - -- Create a proposition for representing terms that do not match `patterns` - let mut notAlt := mkConst ``False - for discr in discrs.toArray.reverse, pattern in patterns.reverse do - notAlt ← mkArrow (← mkEq discr pattern) notAlt - notAlt ← mkForallFVars (discrs ++ ys) notAlt - trace[Meta.debug] "notAlt: {notAlt}" - let lhs := mkAppN (mkConst constInfo.name us) (params ++ #[motive] ++ patterns ++ alts) - let rhs := mkAppN alt rhsArgs - let thmType ← mkEq lhs rhs - let thmType ← hs.foldrM (init := thmType) mkArrow - let thmType ← mkForallFVars (params ++ #[motive] ++ alts ++ ys) thmType - let thmVal ← prove thmType - trace[Meta.debug] ">> {thmType}" - return notAlts.push notAlt -where - toFVarsRHSArgs (ys : Array Expr) : MetaM (Array Expr × Array Expr) := do - if ys.size == 1 && (← inferType ys[0]).isConstOf ``Unit then - return (#[], #[mkConst ``Unit.unit]) - else - return (ys, ys) - - simpEq (lhs : Expr) (rhs : Expr) : OptionT (StateRefT (Array Expr) MetaM) Unit := do - if isMatchValue lhs && isMatchValue rhs then - unless (← isDefEq lhs rhs) do - failure - else if rhs.isFVar then - -- Ignore case since it matches anything - pure () - else match lhs.arrayLit?, rhs.arrayLit? with - | some (_, lhsArgs), some (_, rhsArgs) => - if lhsArgs.length != rhsArgs.length then - failure - else - for lhsArg in lhsArgs, rhsArg in rhsArgs do - simpEq lhsArg rhsArg - | _, _ => - match toCtorIfLit lhs |>.constructorApp? (← getEnv), toCtorIfLit rhs |>.constructorApp? (← getEnv) with - | some (lhsCtor, lhsArgs), some (rhsCtor, rhsArgs) => - if lhsCtor.name == rhsCtor.name then - for lhsArg in lhsArgs[lhsCtor.numParams:], rhsArg in rhsArgs[lhsCtor.numParams:] do - simpEq lhsArg rhsArg - else - failure - | _, _ => - let newEq ← mkEq lhs rhs - modify fun eqs => eqs.push newEq - - simpEqs (eqs : Array Expr) : OptionT (StateRefT (Array Expr) MetaM) Unit := do - eqs.forM fun eq => - match eq.eq? with - | some (_, lhs, rhs) => simpEq lhs rhs - | _ => throwError "failed to generate equality theorems for 'match', equality expected{indentExpr eq}" - - simpHs (hs : Array Expr) (numPatterns : Nat) : MetaM (Array Expr) := - hs.filterMapM fun h => forallTelescope h fun ys _ => do - trace[Meta.debug] "ys: {ys}" - let xs := ys[:ys.size - numPatterns].toArray - let eqs ← ys[ys.size - numPatterns : ys.size].toArray.mapM inferType - if let some eqsNew ← simpEqs eqs *> get |>.run |>.run' #[] then - let newH ← eqsNew.foldrM (init := mkConst ``False) mkArrow - let xs ← xs.filterM fun x => dependsOn newH x.fvarId! - return some (← mkForallFVars xs newH) - else - none - - proveLoop (mvarId : MVarId) : MetaM Unit := do - let mvarId ← modifyTargetEqLHS mvarId whnfCore - (applyRefl mvarId) - <|> - (do trace[Meta.debug] "TODO{indentD <| MessageData.ofGoal mvarId}" - -- TODO - admit mvarId) - - prove (type : Expr) : MetaM Expr := - withLCtx {} {} <| forallTelescope type fun ys target => do - let mvar0 ← mkFreshExprSyntheticOpaqueMVar target - let mvarId ← deltaTarget mvar0.mvarId! (. == matchDeclName) - proveLoop mvarId - mkLambdaFVars ys (← instantiateMVars mvar0) - -end Lean.Meta.Match +@[commandElab test] def elabTest : CommandElab := fun stx => do + let id ← resolveGlobalConstNoOverloadWithInfo stx[1] + liftTermElabM none do + Lean.Meta.Match.mkEquationsFor id + return () def f (xs ys : List String) : Nat := match xs, ys with @@ -116,20 +17,21 @@ def f (xs ys : List String) : Nat := | _, x::xs => xs.length | _, _ => 2 + def h (x y : Nat) : Nat := match x, y with | 10000, _ => 0 | 10001, _ => 5 | _, 20000 => 4 - | x+1, _ => 3 - | Nat.zero, y+1 => 44 +-- | x+1, _ => 3 +-- | Nat.zero, y+1 => 44 | _, _ => 1 -theorem ex1 : h 10000 1 = 0 := - rfl +-- theorem ex1 : h 10000 1 = 0 := +-- rfl -theorem ex2 : h 10002 1 = 3 := - rfl +-- theorem ex2 : h 10002 1 = 3 := +-- rfl def g (xs ys : Array Nat) : Nat := match xs, ys with @@ -138,9 +40,9 @@ def g (xs ys : Array Nat) : Nat := | _, #[x, y] => 2 | _, _ => 3 --- #print f.match_1 - set_option trace.Meta.debug true -#eval Lean.Meta.Match.mkEquationsFor ``f.match_1 -#eval Lean.Meta.Match.mkEquationsFor ``h.match_1 -#eval Lean.Meta.Match.mkEquationsFor ``g.match_1 + +-- set_option trace.Meta.debug true +test% f.match_1 +test% h.match_1 +test% g.match_1