diff --git a/src/Lean/Meta/Match/MatchEqs.lean b/src/Lean/Meta/Match/MatchEqs.lean index 65f5eae573..0575aad6ab 100644 --- a/src/Lean/Meta/Match/MatchEqs.lean +++ b/src/Lean/Meta/Match/MatchEqs.lean @@ -397,7 +397,8 @@ where /-- Create conditional equations and splitter for the given match auxiliary declaration. -/ -private partial def mkEquationsFor (matchDeclName : Name) : MetaM MatchEqns := +private partial def mkEquationsFor (matchDeclName : Name) : MetaM MatchEqns := do + trace[Meta.Match.matchEqs] "mkEquationsFor '{matchDeclName}'" withConfig (fun c => { c with etaStruct := false }) do let baseName := mkPrivateName (← getEnv) matchDeclName let constInfo ← getConstInfo matchDeclName diff --git a/src/Lean/Meta/Tactic/Contradiction.lean b/src/Lean/Meta/Tactic/Contradiction.lean index 092f5cbf60..b72a4ffbef 100644 --- a/src/Lean/Meta/Tactic/Contradiction.lean +++ b/src/Lean/Meta/Tactic/Contradiction.lean @@ -90,6 +90,20 @@ private def isGenDiseq (e : Expr) : Bool := | Expr.forallE _ d b _ => (d.isEq || b.hasLooseBVar 0) && isGenDiseq b | _ => e.isConstOf ``False +/-- + Given `e` s.t. `isGenDiseq e`, generate a bit-mask `mask` s.t. `mask[i] = true` iff + the `i`-th binder is an equality without forward dependencies. + + See `processGenDiseq` +-/ +private def mkGenDiseqMask (e : Expr) : Array Bool := + go e #[] +where + go (e : Expr) (acc : Array Bool) : Array Bool := + match e with + | Expr.forallE _ d b _ => go b (acc.push (!b.hasLooseBVar 0 && d.isEq)) + | _ => acc + /-- Close goal if `localDecl` is a "generalized disequality". Example: ``` @@ -101,13 +115,24 @@ private def processGenDiseq (mvarId : MVarId) (localDecl : LocalDecl) : MetaM Bo assert! isGenDiseq localDecl.type let val? ← withNewMCtxDepth do let (args, _, _) ← forallMetaTelescope localDecl.type - for arg in args do - let argType ← inferType arg - if let some (_, lhs, rhs) ← matchEq? argType then - unless (← isDefEq lhs rhs) do - return none - unless (← isDefEq arg (← mkEqRefl lhs)) do - return none + let mask := mkGenDiseqMask localDecl.type + for arg in args, useRefl in mask do + if useRefl then + /- Remark: we should not try to use `refl` for equalities that have forward dependencies because + they correspond to constructor fields. We did not use to have this extra test, and this method failed + to close the following goal. + ``` + ... + ns' : NEList String + h' : NEList.notUno ns' = true + : ∀ (ns : NEList String) (h : NEList.notUno ns = true), Value.lam (Lambda.mk ns' h') = Value.lam (Lambda.mk ns h) → False + ⊢ h_1 l a = h_2 v + + ``` + -/ + if let some (_, lhs, _) ← matchEq? (← inferType arg) then + unless (← isDefEq arg (← mkEqRefl lhs)) do + return none let falseProof ← instantiateMVars (mkAppN localDecl.toExpr args) if (← hasAssignableMVar falseProof) then return none diff --git a/tests/lean/run/arthur1.lean b/tests/lean/run/arthur1.lean new file mode 100644 index 0000000000..f0d804e978 --- /dev/null +++ b/tests/lean/run/arthur1.lean @@ -0,0 +1,398 @@ +import Std + +inductive NEList (α : Type) + | uno : α → NEList α + | cons : α → NEList α → NEList α + +def NEList.contains [BEq α] : NEList α → α → Bool + | uno a, x => a == x + | cons a as, x => a == x || as.contains x + +def NEList.noDup [BEq α] : NEList α → Bool + | uno a => true + | cons a as => ¬as.contains a && as.noDup + +@[specialize] +def NEList.foldl (f : α → β → α) : (init : α) → NEList β → α + | a, uno b => f a b + | a, cons b l => foldl f (f a b) l + +@[specialize] +def NEList.map (f : α → β) : NEList α → NEList β + | uno a => uno (f a) + | cons a as => cons (f a) (map f as) + +inductive Literal + | bool : Bool → Literal + | int : Int → Literal + | float : Float → Literal + | str : String → Literal + +inductive BinOp + | add | mul | eq | ne | lt | le | gt | ge + +inductive UnOp + | not + +mutual + + inductive Lambda + | mk : (l : NEList String) → l.noDup → Program → Lambda + + inductive Expression + | lit : Literal → Expression + | var : String → Expression + | lam : Lambda → Expression + | list : List Literal → Expression + | app : Expression → NEList Expression → Expression + | unOp : UnOp → Expression → Expression + | binOp : BinOp → Expression → Expression → Expression + + inductive Program + | skip : Program + | eval : Expression → Program + | decl : String → Program → Program + | seq : Program → Program → Program + | fork : Expression → Program → Program → Program + | loop : Expression → Program → Program + | print : Expression → Program + deriving Inhabited + +end + +inductive Value + | nil : Value + | lit : Literal → Value + | list : List Literal → Value + | lam : Lambda → Value + deriving Inhabited + +abbrev Context := Std.HashMap String Value + +inductive ErrorType + | name | type | runTime + +def Literal.typeStr : Literal → String + | bool _ => "bool" + | int _ => "int" + | float _ => "float" + | str _ => "str" + +def removeRightmostZeros (s : String) : String := + let rec aux (buff res : List Char) : List Char → List Char + | [] => res.reverse + | a :: as => + if a != '0' + then aux [] (a :: (buff ++ res)) as + else aux (a :: buff) res as + ⟨aux [] [] s.data⟩ + +protected def Literal.toString : Literal → String + | bool b => toString b + | int i => toString i + | float f => removeRightmostZeros $ toString f + | str s => s + +def Lambda.typeStr : Lambda → String + | mk l .. => (l.foldl (init := "") fun acc _ => acc ++ "_ → ") ++ "_" + +def Value.typeStr : Value → String + | nil => "nil" + | lit l => l.typeStr + | list _ => "list" + | lam l => l.typeStr + +def Literal.eq : Literal → Literal → Bool + | bool bₗ, bool bᵣ => bₗ == bᵣ + | int iₗ, int iᵣ => iₗ == iᵣ + | float fₗ, float fᵣ => fₗ == fᵣ + | int iₗ, float fᵣ => (.ofInt iₗ) == fᵣ + | float fₗ, int iᵣ => fₗ == (.ofInt iᵣ) + | str sₗ, str sᵣ => sₗ == sᵣ + | _ , _ => false + +def listLiteralEq : List Literal → List Literal → Bool + | [], [] => true + | a :: a' :: as, b :: b' :: bs => + a.eq b && listLiteralEq (a' :: as) (b' :: bs) + | _, _ => false + +def opError (app l r : String) : String := + s!"I can't perform a '{app}' operation between '{l}' and '{r}'" + +def opError1 (app v : String) : String := + s!"I can't perform a '{app}' operation on '{v}'" + +def Value.not : Value → Except String Value + | lit $ .bool b => return lit $ .bool !b + | v => throw $ opError1 "!" v.typeStr + +def Value.add : Value → Value → Except String Value + | lit $ .bool bₗ, lit $ .bool bᵣ => return lit $ .bool $ bₗ || bᵣ + | lit $ .int iₗ, lit $ .int iᵣ => return lit $ .int $ iₗ + iᵣ + | lit $ .float fₗ, lit $ .float fᵣ => return lit $ .float $ fₗ + fᵣ + | lit $ .int iₗ, lit $ .float fᵣ => return lit $ .float $ (.ofInt iₗ) + fᵣ + | lit $ .float fₗ, lit $ .int iᵣ => return lit $ .float $ fₗ + (.ofInt iᵣ) + | lit $ .str sₗ, lit $ .str sᵣ => return lit $ .str $ sₗ ++ sᵣ + | list lₗ, list lᵣ => return list $ lₗ ++ lᵣ + | list l, lit r => return list $ l.concat r + | l, r => throw $ opError "+" l.typeStr r.typeStr + +def Value.mul : Value → Value → Except String Value + | lit $ .bool bₗ, lit $ .bool bᵣ => return .lit $ .bool $ bₗ && bᵣ + | lit $ .int iₗ, lit $ .int iᵣ => return .lit $ .int $ iₗ * iᵣ + | lit $ .float fₗ, lit $ .float fᵣ => return .lit $ .float $ fₗ * fᵣ + | lit $ .int iₗ, lit $ .float fᵣ => return .lit $ .float $ (.ofInt iₗ) * fᵣ + | lit $ .float fₗ, lit $ .int iᵣ => return .lit $ .float $ fₗ * (.ofInt iᵣ) + | l, r => throw $ opError "*" l.typeStr r.typeStr + +def Bool.toNat : Bool → Nat + | false => 0 + | true => 1 + +def Value.lt : Value → Value → Except String Value + | lit $ .bool bₗ, lit $ .bool bᵣ => return lit $ .bool $ bₗ.toNat < bᵣ.toNat + | lit $ .int iₗ, lit $ .int iᵣ => return lit $ .bool $ iₗ < iᵣ + | lit $ .float fₗ, lit $ .float fᵣ => return lit $ .bool $ fₗ < fᵣ + | lit $ .int iₗ, lit $ .float fᵣ => return lit $ .bool $ (.ofInt iₗ) < fᵣ + | lit $ .float fₗ, lit $ .int iᵣ => return lit $ .bool $ fₗ < (.ofInt iᵣ) + | lit $ .str sₗ, lit $ .str sᵣ => return lit $ .bool $ sₗ < sᵣ + | list lₗ, list lᵣ => return lit $ .bool $ lₗ.length < lᵣ.length + | l, r => throw $ opError "<" l.typeStr r.typeStr + +def Value.le : Value → Value → Except String Value + | lit $ .bool bₗ, lit $ .bool bᵣ => return lit $ .bool $ bₗ.toNat ≤ bᵣ.toNat + | lit $ .int iₗ, lit $ .int iᵣ => return lit $ .bool $ iₗ ≤ iᵣ + | lit $ .float fₗ, lit $ .float fᵣ => return lit $ .bool $ fₗ ≤ fᵣ + | lit $ .int iₗ, lit $ .float fᵣ => return lit $ .bool $ (.ofInt iₗ) ≤ fᵣ + | lit $ .float fₗ, lit $ .int iᵣ => return lit $ .bool $ fₗ ≤ (.ofInt iᵣ) + | lit $ .str sₗ, lit $ .str sᵣ => return lit $ .bool $ sₗ < sᵣ || sₗ == sᵣ + | list lₗ, list lᵣ => return lit $ .bool $ lₗ.length ≤ lᵣ.length + | l, r => throw $ opError "<=" l.typeStr r.typeStr + +def Value.gt : Value → Value → Except String Value + | lit $ .bool bₗ, lit $ .bool bᵣ => return lit $ .bool $ bₗ.toNat > bᵣ.toNat + | lit $ .int iₗ, lit $ .int iᵣ => return lit $ .bool $ iₗ > iᵣ + | lit $ .float fₗ, lit $ .float fᵣ => return lit $ .bool $ fₗ > fᵣ + | lit $ .int iₗ, lit $ .float fᵣ => return lit $ .bool $ (.ofInt iₗ) > fᵣ + | lit $ .float fₗ, lit $ .int iᵣ => return lit $ .bool $ fₗ > (.ofInt iᵣ) + | lit $ .str sₗ, lit $ .str sᵣ => return lit $ .bool $ sₗ > sᵣ + | list lₗ, list lᵣ => return lit $ .bool $ lₗ.length > lᵣ.length + | l, r => throw $ opError ">" l.typeStr r.typeStr + +def Value.ge : Value → Value → Except String Value + | lit $ .bool bₗ, lit $ .bool bᵣ => return lit $ .bool $ bₗ.toNat ≥ bᵣ.toNat + | lit $ .int iₗ, lit $ .int iᵣ => return lit $ .bool $ iₗ ≥ iᵣ + | lit $ .float fₗ, lit $ .float fᵣ => return lit $ .bool $ fₗ ≥ fᵣ + | lit $ .int iₗ, lit $ .float fᵣ => return lit $ .bool $ (.ofInt iₗ) ≥ fᵣ + | lit $ .float fₗ, lit $ .int iᵣ => return lit $ .bool $ fₗ ≥ (.ofInt iᵣ) + | lit $ .str sₗ, lit $ .str sᵣ => return lit $ .bool $ sₗ > sᵣ || sₗ == sᵣ + | list lₗ, list lᵣ => return lit $ .bool $ lₗ.length ≥ lᵣ.length + | l, r => throw $ opError ">=" l.typeStr r.typeStr + +def Value.eq : Value → Value → Except String Value + | nil, nil => return lit $ .bool true + | lit lₗ, lit lᵣ => return lit $ .bool $ lₗ.eq lᵣ + | list lₗ, list lᵣ => return lit $ .bool (listLiteralEq lₗ lᵣ) + | lam .. , lam .. => throw "I can't compare functions" + | _, _ => return lit $ .bool false + +def Value.ne : Value → Value → Except String Value + | nil, nil => return lit $ .bool false + | lit lₗ, lit lᵣ => return lit $ .bool $ !(lₗ.eq lᵣ) + | list lₗ, list lᵣ => return lit $ .bool !(listLiteralEq lₗ lᵣ) + | lam .., lam .. => throw "I can't compare functions" + | _, _ => return lit $ .bool true + +def Value.unOp : Value → UnOp → Except String Value + | v, .not => v.not + +def Value.binOp : Value → Value → BinOp → Except String Value + | l, r, .add => l.add r + | l, r, .mul => l.mul r + | l, r, .lt => l.lt r + | l, r, .le => l.le r + | l, r, .gt => l.gt r + | l, r, .ge => l.ge r + | l, r, .eq => l.eq r + | l, r, .ne => l.ne r + +def NEList.unfoldStrings (l : NEList String) : String := + l.foldl (init := "") $ fun acc a => acc ++ s!" {a}" |>.trimLeft + +mutual + + partial def unfoldExpressions (es : NEList Expression) : String := + (es.map exprToString).unfoldStrings + + partial def exprToString : Expression → String + | .var n => n + | .lit l => l.toString + | .list l => toString $ l.map Literal.toString + | .lam _ => "«function»" + | .app e es => s!"({exprToString e} {unfoldExpressions es})" + | .unOp .not e => s!"!{exprToString e}" + | .binOp .add l r => s!"({exprToString l} + {exprToString r})" + | .binOp .mul l r => s!"({exprToString l} * {exprToString r})" + | .binOp .eq l r => s!"({exprToString l} = {exprToString r})" + | .binOp .ne l r => s!"({exprToString l} != {exprToString r})" + | .binOp .lt l r => s!"({exprToString l} < {exprToString r})" + | .binOp .le l r => s!"({exprToString l} <= {exprToString r})" + | .binOp .gt l r => s!"({exprToString l} > {exprToString r})" + | .binOp .ge l r => s!"({exprToString l} >= {exprToString r})" + +end + +instance : ToString Expression := ⟨exprToString⟩ + +def valToString : Value → String + | .nil => "«nil»" + | .lit l => l.toString + | .list l => toString $ l.map Literal.toString + | .lam _ => "«function»" + +instance : ToString Value := ⟨valToString⟩ + +def consume (p : Program) : + NEList String → NEList Expression → + Option ((Option (NEList String)) × Program) + | .cons n ns, .cons e es => consume (.seq (.decl n (.eval e)) p) ns es + | .cons n ns, .uno e => some (some ns, .seq (.decl n (.eval e)) p) + | .uno n, .uno e => some (none, .seq (.decl n (.eval e)) p) + | .uno _, .cons .. => none + +theorem noDupOfConsumeNoDup + (h : ns.noDup) (h' : consume p' ns es = some (some l, p)) : + l.noDup = true := by + induction ns generalizing p' es with + | uno _ => cases es <;> cases h' + | cons _ _ hi => + simp [NEList.noDup] at h + cases es with + | uno _ => simp [consume] at h'; simp only [h.2, ← h'.1] + | cons _ _ => exact hi h.2 h' + +inductive Continuation + | exit : Continuation + | seq : Program → Continuation → Continuation + | decl : String → Continuation → Continuation + | fork : Expression → Program → Program → Continuation → Continuation + | loop : Expression → Program → Continuation → Continuation + | unOp : UnOp → Expression → Continuation → Continuation + | binOp₁ : BinOp → Expression → Continuation → Continuation + | binOp₂ : BinOp → Value → Continuation → Continuation + | app : Expression → NEList Expression → Continuation → Continuation + | block : Context → Continuation → Continuation + | print : Continuation → Continuation + +inductive State + | ret : Value → Context → Continuation → State + | prog : Program → Context → Continuation → State + | expr : Expression → Context → Continuation → State + | error : ErrorType → Context → String → State + | done : Value → Context → State + +def cantEvalAsBool (e : Expression) (v : Value) : String := + s!"I can't evaluate '{e}' as a 'bool' because it reduces to '{v}', of " ++ + s!"type '{v.typeStr}'" + +def notFound (n : String) : String := + s!"I can't find the definition of '{n}'" + +def notAFunction (e : Expression) (v : Value) : String := + s!"I can't apply arguments to '{e}' because it evaluates to '{v}', of " ++ + s!"type '{v.typeStr}'" + +def wrongNParameters (e : Expression) (allowed provided : Nat) : String := + s!"I can't apply {provided} arguments to '{e}' because the maximum " ++ + s!"allowed is {allowed}" + +def NEList.length : NEList α → Nat + | uno _ => 1 + | cons _ l => 1 + l.length + +def State.step : State → State + | prog .skip ctx k => ret .nil ctx k + | prog (.eval e) ctx k => expr e ctx k + | prog (.seq p₁ p₂) ctx k => prog p₁ ctx (.seq p₂ k) + | prog (.decl n p) ctx k => prog p ctx $ .block ctx (.decl n k) + | prog (.fork e pT pF) ctx k => expr e ctx (.fork e pT pF k) + | prog (.loop e p) ctx k => expr e ctx (.loop e p k) + | prog (.print e) ctx k => expr e ctx (.print k) + + | expr (.lit l) ctx k => ret (.lit l) ctx k + | expr (.list l) ctx k => ret (.list l) ctx k + | expr (.var n) ctx k => match ctx[n] with + | none => error .name ctx $ notFound n + | some v => ret v ctx k + | expr (.lam l) ctx k => ret (.lam l) ctx k + | expr (.app e es) ctx k => expr e ctx (.app e es k) + | expr (.unOp o e) ctx k => expr e ctx (.unOp o e k) + | expr (.binOp o e₁ e₂) ctx k => expr e₁ ctx (.binOp₁ o e₂ k) + + | ret v ctx .exit => done v ctx + | ret v ctx (.print k) => dbg_trace v; ret .nil ctx k + | ret _ ctx (.seq p k) => prog p ctx k + + | ret v _ (.block ctx k) => ret v ctx k + + | ret v ctx (.app e es k) => match v with + | .lam $ .mk ns h p => match h' : consume p ns es with + | some (some l, p) => + ret (.lam $ .mk l (noDupOfConsumeNoDup h h') p) ctx k + | some (none, p) => prog p ctx (.block ctx k) + | none => error .runTime ctx $ wrongNParameters e ns.length es.length + | v => error .type ctx $ notAFunction e v + + | ret (.lit $ .bool true) ctx (.fork _ pT _ k) => prog pT ctx k + | ret (.lit $ .bool false) ctx (.fork _ _ pF k) => prog pF ctx k + | ret v ctx (.fork e ..) => error .type ctx $ cantEvalAsBool e v + + | ret (.lit $ .bool true) ctx (.loop e p k) => prog (.seq p (.loop e p)) ctx k + | ret (.lit $ .bool false) ctx (.loop _ _ k) => ret .nil ctx k + | ret v ctx (.loop e ..) => error .type ctx $ cantEvalAsBool e v + + | ret v ctx (.decl n k) => ret .nil (ctx.insert n v) k + + | ret v ctx (.unOp o e k) => match v.unOp o with + | .error m => error .type ctx m + | .ok v => ret v ctx k + | ret v1 ctx (.binOp₁ o e2 k) => expr e2 ctx (.binOp₂ o v1 k) + | ret v2 ctx (.binOp₂ o v1 k) => match v1.binOp v2 o with + | .error m => error .type ctx m + | .ok v => ret v ctx k + + | s@(error ..) => s + | s@(done ..) => s + +def Context.equiv (cₗ cᵣ : Context) : Prop := + ∀ n, cₗ[n] = cᵣ[n] + +def State.stepN : State → Nat → State + | s, 0 => s + | s, n + 1 => s.step.stepN n + +def State.reaches (s₁ s₂ : State) : Prop := + ∃ n, s₁.stepN n = s₂ + +notation cₗ " ≃ " cᵣ:21 => Context.equiv cₗ cᵣ +notation s₁ " ↠ " s₂ => State.reaches s₁ s₂ + +def State.ctx : State → Context + | ret _ c _ => c + | prog _ c _ => c + | expr _ c _ => c + | error _ c _ => c + | done _ c => c + +theorem Context.equivSelf {c : Context} : c ≃ c := + fun _ => rfl + +/- +theorem State.skipStep (h : s = (prog .skip c k).step) : s.ctx ≃ c := by + have : s.ctx = c := by rw [h, step, ctx] + simp only [this, Context.equivSelf] +-/ + +theorem State.skipClean : (prog .skip c .exit) ↠ (done .nil c) := + ⟨2 , by simp only [stepN, step]⟩ diff --git a/tests/lean/run/processGenDiseqBug.lean b/tests/lean/run/processGenDiseqBug.lean new file mode 100644 index 0000000000..b5da7000a8 --- /dev/null +++ b/tests/lean/run/processGenDiseqBug.lean @@ -0,0 +1,26 @@ +import Lean + +inductive NEList (α : Type) + | uno : α → NEList α + | cons : α → NEList α → NEList α + +def NEList.notUno : NEList α → Bool + | uno a => true + | cons a as => false + +inductive Lambda + | mk : (l : NEList String) → l.notUno = true → Lambda + +inductive Value + | lam : Lambda → Value + | nil : Value + +def State.aux (v : Value) : Bool := + match v with + | .lam (.mk ns h) => true + | v => false + +def gen : Lean.MetaM Unit := do + discard <| Lean.Meta.Match.getEquationsForImpl ``State.aux.match_1 + +#eval gen