From 469de09280c0ca9af4bc2586703c287b00cea813 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Wed, 2 Dec 2020 13:25:12 -0800 Subject: [PATCH] fix: bug at `isDefEq` The new test contains a minimal example that triggers the bug. --- src/Lean/Meta/ExprDefEq.lean | 131 ++++++++++++++++++++++++++++++- tests/lean/run/isDefEqIssue.lean | 9 +++ 2 files changed, 137 insertions(+), 3 deletions(-) create mode 100644 tests/lean/run/isDefEqIssue.lean diff --git a/src/Lean/Meta/ExprDefEq.lean b/src/Lean/Meta/ExprDefEq.lean index 447dd02937..851c955bbb 100644 --- a/src/Lean/Meta/ExprDefEq.lean +++ b/src/Lean/Meta/ExprDefEq.lean @@ -10,6 +10,7 @@ import Lean.Meta.FunInfo import Lean.Meta.LevelDefEq import Lean.Meta.Check import Lean.Meta.Offset +import Lean.Meta.ForEachExpr import Lean.Meta.UnificationHint namespace Lean.Meta @@ -239,6 +240,130 @@ private def checkTypesAndAssign (mvar : Expr) (v : Expr) : MetaM Bool := trace[Meta.isDefEq.assign.typeMismatch]! "{mvar} : {mvarType} := {v} : {vType}" pure false +/-- + Auxiliary method for solving constraints of the form `?m xs := v`. + It creates a lambda using `mkLambdaFVars ys v`, where `ys` is a superset of `xs`. + `ys` is often equal to `xs`. It is a bigger when there are let-declaration dependencies in `xs`. + For example, suppose we have `xs` of the form `#[a, c]` where + ``` + a : Nat + b : Nat := f a + c : b = a + ``` + In this scenario, the type of `?m` is `(x1 : Nat) -> (x2 : f x1 = x1) -> C[x1, x2]`, + and type of `v` is `C[a, c]`. Note that, `?m a c` is type correct since `f a = a` is definitionally equal + to the type of `c : b = a`, and the type of `?m a c` is equal to the type of `v`. + Note that `fun xs => v` is the term `fun (x1 : Nat) (x2 : b = x1) => v` which has type + `(x1 : Nat) -> (x2 : b = x1) -> C[x1, x2]` which is not definitionally equal to the type of `?m`, + and may not even be type correct. + The issue here is that we are not capturing the `let`-declarations. + + This method collects let-declarations `y` occurring between `xs[0]` and `xs.back` s.t. + some `x` in `xs` depends on `y`. + `ys` is the `xs` with these extra let-declarations included. + + In the example above, `ys` is `#[a, b, c]`, and `mkLambdaFVars ys v` produces + `fun a => let b := f a; fun (c : b = a) => v` which has a type definitionally equal to the type of `?m`. + + Recall that the method `checkAssignment` ensures `v` does not contain offending `let`-declarations. + + This method assumes that for any `xs[i]` and `xs[j]` where `i < j`, we have that `index of xs[i]` < `index of xs[j]`. + where the index is the position in the local context. +-/ +private partial def mkLambdaFVarsWithLetDeps (xs : Array Expr) (v : Expr) : MetaM (Option Expr) := do + if not (← hasLetDeclsInBetween) then + mkLambdaFVars xs v + else + let ys ← addLetDeps + trace[Meta.debug]! "ys: {ys}, v: {v}" + mkLambdaFVars ys v + +where + /- Return true if there are let-declarions between `xs[0]` and `xs[xs.size-1]`. + We use it a quick-check to avoid the more expensive collection procedure. -/ + hasLetDeclsInBetween : MetaM Bool := do + let check (lctx : LocalContext) : Bool := do + let start := lctx.getFVar! xs[0] |>.index + let stop := lctx.getFVar! xs.back |>.index + for i in [start+1:stop] do + match lctx.getAt! i with + | some localDecl => + if localDecl.isLet then + return true + | _ => pure () + return false + if xs.size <= 1 then + pure false + else + check (← getLCtx) + + /- Traverse `e` and stores in the state `NameHashSet` any let-declaration with index greater than `(← read)`. + The context `Nat` is the position of `xs[0]` in the local context. -/ + collectLetDeclsFrom (e : Expr) : ReaderT Nat (StateRefT NameHashSet MetaM) Unit := do + let rec visit (e : Expr) : MonadCacheT Expr Unit (ReaderT Nat (StateRefT NameHashSet MetaM)) Unit := + checkCache e fun e => do + match e with + | Expr.forallE _ d b _ => visit d; visit b + | Expr.lam _ d b _ => visit d; visit b + | Expr.letE _ t v b _ => visit t; visit v; visit b + | Expr.app f a _ => visit f; visit a + | Expr.mdata _ b _ => visit b + | Expr.proj _ _ b _ => visit b + | Expr.fvar fvarId _ => + let localDecl ← getLocalDecl fvarId + if localDecl.isLet && localDecl.index > (← read) then + modify fun s => s.insert localDecl.fvarId + | _ => pure () + visit (← instantiateMVars e) |>.run + + /- + Auxiliary definition for traversing all declarations between `xs[0]` ... `xs.back` backwards. + The `Nat` argument is the current position in the local context being visited, and it is less than + or equal to the position of `xs.back` in the local context. + The `Nat` context `(← read)` is the position of `xs[0]` in the local context. + -/ + collectLetDepsAux : Nat → ReaderT Nat (StateRefT NameHashSet MetaM) Unit + | 0 => return () + | i+1 => do + if i+1 == (← read) then + return () + else + match (← getLCtx).getAt! (i+1) with + | none => collectLetDepsAux i + | some localDecl => + if (← get).contains localDecl.fvarId then + collectLetDeclsFrom localDecl.type + match localDecl.value? with + | some val => collectLetDeclsFrom val + | _ => pure () + collectLetDepsAux i + + /- Computes the set `ys`. It is a set of `FVarId`s, -/ + collectLetDeps : MetaM NameHashSet := do + let lctx ← getLCtx + let start := lctx.getFVar! xs[0] |>.index + let stop := lctx.getFVar! xs.back |>.index + let s := xs.foldl (init := {}) fun s x => s.insert x.fvarId! + let (_, s) ← collectLetDepsAux stop |>.run start |>.run s + return s + + /- Computes the array `ys` containing let-decls between `xs[0]` and `xs.back` that + some `x` in `xs` depends on. -/ + addLetDeps : MetaM (Array Expr) := do + let lctx ← getLCtx + let s ← collectLetDeps + /- Convert `s` into the the array `ys` -/ + let start := lctx.getFVar! xs[0] |>.index + let stop := lctx.getFVar! xs.back |>.index + let mut ys := #[] + for i in [start:stop+1] do + match lctx.getAt! i with + | none => pure () + | some localDecl => + if s.contains localDecl.fvarId then + ys := ys.push localDecl.toExpr + return ys + /- Each metavariable is declared in a particular local context. We use the notation `C |- ?m : t` to denote a metavariable `?m` that @@ -524,7 +649,7 @@ def assignToConstFun (mvar : Expr) (numArgs : Nat) (newMVar : Expr) : MetaM Bool forallBoundedTelescope mvarType numArgs fun xs _ => do if xs.size != numArgs then pure false else - let v ← mkLambdaFVars xs newMVar + let some v ← mkLambdaFVarsWithLetDeps xs newMVar | return false checkTypesAndAssign mvar v partial def check (e : Expr) : CheckAssignmentM Expr := do @@ -709,7 +834,7 @@ private def assignConst (mvar : Expr) (numArgs : Nat) (v : Expr) : MetaM Bool := if xs.size != numArgs then pure false else - let v ← mkLambdaFVars xs v + let some v ← mkLambdaFVarsWithLetDeps xs v | pure false trace[Meta.isDefEq.constApprox]! "{mvar} := {v}" checkTypesAndAssign mvar v @@ -760,7 +885,7 @@ private partial def processAssignment (mvarApp : Expr) (v : Expr) : MetaM Bool : | none => useFOApprox args | some v => do trace[Meta.isDefEq.assign.beforeMkLambda]! "{mvar} {args} := {v}" - let v ← mkLambdaFVars args v + let some v ← mkLambdaFVarsWithLetDeps args v | return false if args.any (fun arg => mvarDecl.lctx.containsFVar arg) then /- We need to type check `v` because abstraction using `mkLambdaFVars` may have produced a type incorrect term. See discussion at A2 -/ diff --git a/tests/lean/run/isDefEqIssue.lean b/tests/lean/run/isDefEqIssue.lean new file mode 100644 index 0000000000..1a7c9e9726 --- /dev/null +++ b/tests/lean/run/isDefEqIssue.lean @@ -0,0 +1,9 @@ +constant getA (s : String) : Array String := #[] + +private def resolveLValAux (s : String) (i : Nat) : Nat := + let s1 := s + let as := getA s1 + if h : i < as.size then + i - 1 + else + i