From a6529a795bdcdd03ed76ea288cfea162693c3647 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Thu, 19 Aug 2021 10:57:12 -0700 Subject: [PATCH] feat: add `casesOnStuckLHS` --- src/Lean/Meta/Match/MatchEqs.lean | 25 +++++++++++++++++++++---- tests/playground/matchEqs.lean | 2 +- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/src/Lean/Meta/Match/MatchEqs.lean b/src/Lean/Meta/Match/MatchEqs.lean index 6936d84114..d79b3abca9 100644 --- a/src/Lean/Meta/Match/MatchEqs.lean +++ b/src/Lean/Meta/Match/MatchEqs.lean @@ -13,6 +13,23 @@ namespace Lean.Meta.Match private def isMatchValue (e : Expr) : Bool := e.isNatLit || e.isCharLit || e.isStringLit +private def casesOnStuckLHS (mvarId : MVarId) : MetaM (Array MVarId) := do + let target ← getMVarType mvarId + if let some (_, lhs, rhs) ← matchEq? target then + matchConstRec lhs.getAppFn (fun _ => failed) fun recVal _ => do + let args := lhs.getAppArgs + if recVal.getMajorIdx >= args.size then failed + let mut major := args[recVal.getMajorIdx] + if major.isAppOfArity ``Eq.symm 4 then + /- This is needed for supporting `CasesArraySizes.lean` used in the implementation of array literal matching. -/ + major := major.appArg! + unless major.isFVar do failed + return (← cases mvarId major.fvarId!).map fun s => s.mvarId + else + failed +where + failed {α} : MetaM α := throwError "'casesOnStuckLHS' failed" + partial def mkEquationsFor (matchDeclName : Name) : MetaM Unit := do let constInfo ← getConstInfo matchDeclName let us := constInfo.levelParams.map mkLevelParam @@ -101,8 +118,6 @@ where else none - failed : MetaM Unit := throwError "" -- TODO - proveLoop (mvarId : MVarId) (depth : Nat) : MetaM Unit := withIncRecDepth do let mvarId ← modifyTargetEqLHS mvarId whnfCore trace[Meta.debug] "proveLoop\n{MessageData.ofGoal mvarId}" @@ -110,9 +125,11 @@ where <|> (contradiction mvarId) <|> + (do (← casesOnStuckLHS mvarId).forM (proveLoop . (depth + 1))) + <|> (do let mvarId' ← simpIfTarget mvarId (useDecide := true) trace[Meta.debug] "simpIfTarget\n{MessageData.ofGoal mvarId'}" - if mvarId' == mvarId then failed + if mvarId' == mvarId then throwError "simpIf failed" proveLoop mvarId' (depth+1)) <|> (do if let some (s₁, s₂) ← splitIfTarget? mvarId then @@ -120,7 +137,7 @@ where proveLoop mvarId₁ (depth+1) proveLoop s₂.mvarId (depth+1) else - failed) + throwError "spliIf failed") <|> (do trace[Meta.debug] "TODO\n{← ppGoal mvarId}" -- TODO diff --git a/tests/playground/matchEqs.lean b/tests/playground/matchEqs.lean index c143e1e768..92d4c2b9b9 100644 --- a/tests/playground/matchEqs.lean +++ b/tests/playground/matchEqs.lean @@ -41,7 +41,7 @@ def g (xs ys : Array Nat) : Nat := | _, _ => 3 set_option trace.Meta.debug true - +set_option pp.proofs true -- set_option trace.Meta.debug true test% f.match_1 test% h.match_1