From 9a24db4e86af9b39db7df3094c7aab05bd5c9a64 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Wed, 15 Dec 2021 16:16:42 -0800 Subject: [PATCH] =?UTF-8?q?fix:=20check=20generated=20motives=20at=20`?= =?UTF-8?q?=E2=96=B8`=20notation?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit also improves the `▸` notation a bit. It now tries `subst` (if applicable) before failing. --- src/Lean/Elab/BuiltinNotation.lean | 48 ++++++++++++++++++--- tests/lean/substBadMotive.lean | 39 +++++++++++++++++ tests/lean/substBadMotive.lean.expected.out | 8 ++++ 3 files changed, 88 insertions(+), 7 deletions(-) create mode 100644 tests/lean/substBadMotive.lean create mode 100644 tests/lean/substBadMotive.lean.expected.out diff --git a/src/Lean/Elab/BuiltinNotation.lean b/src/Lean/Elab/BuiltinNotation.lean index 429c9975ac..21e63e0aba 100644 --- a/src/Lean/Elab/BuiltinNotation.lean +++ b/src/Lean/Elab/BuiltinNotation.lean @@ -267,11 +267,30 @@ where ensureHasType type e | _ => throwUnsupportedSyntax +/-- Return `true` if `lhs` is a free variable and `rhs` does not depend on it. -/ +private def isSubstCandidate (lhs rhs : Expr) : MetaM Bool := + if lhs.isFVar then + return !(← dependsOn rhs lhs.fvarId!) + else + return false + +/-- + Given an expression `e` that is the elaboration of `stx`, if `e` is a free variable, then return `k stx`. + Otherwise, return `(fun x => k x) e` +-/ +private def withLocalIdentFor (stx : Syntax) (e : Expr) (k : Syntax → TermElabM Expr) : TermElabM Expr := do + if e.isFVar then + k stx + else + let id ← mkFreshUserName `h + let aux ← withLocalDeclD id (← inferType e) fun x => do mkLambdaFVars #[x] (← k (mkIdentFrom stx id)) + return mkApp aux e + @[builtinTermElab subst] def elabSubst : TermElab := fun stx expectedType? => do let expectedType ← tryPostponeIfHasMVars expectedType? "invalid `▸` notation" match stx with - | `($heq ▸ $h) => do - let mut heq ← elabTerm heq none + | `($heqStx ▸ $hStx) => do + let mut heq ← elabTerm heqStx none let heqType ← inferType heq let heqType ← instantiateMVars heqType match (← Meta.matchEq? heqType) with @@ -290,10 +309,10 @@ where heq ← mkEqSymm heq (lhs, rhs) := (rhs, lhs) let hExpectedType := expectedAbst.instantiate1 lhs - let h ← withRef h do - let h ← elabTerm h hExpectedType + let (h, badMotive?) ← withRef hStx do + let h ← elabTerm hStx hExpectedType try - ensureHasType hExpectedType h + return (← ensureHasType hExpectedType h, none) catch ex => -- if `rhs` occurs in `hType`, we try to apply `heq` to `h` too let hType ← inferType h @@ -303,8 +322,23 @@ where let hTypeNew := hTypeAbst.instantiate1 lhs unless (← isDefEq hExpectedType hTypeNew) do throw ex - mkEqNDRec (← mkMotive hTypeAbst) h (← mkEqSymm heq) - mkEqNDRec (← mkMotive expectedAbst) h heq + let motive ← mkMotive hTypeAbst + if !(← isTypeCorrect motive) then + return (h, some motive) + else + return (← mkEqNDRec motive h (← mkEqSymm heq), none) + let motive ← mkMotive expectedAbst + if badMotive?.isSome || !(← isTypeCorrect motive) then + -- Before failing try tos use `subst` + if ← (isSubstCandidate lhs rhs <||> isSubstCandidate rhs lhs) then + withLocalIdentFor heqStx heq fun heqStx => + withLocalIdentFor hStx h fun hStx => do + let stxNew ← `(by subst $heqStx; exact $hStx) + withMacroExpansion stx stxNew (elabTerm stxNew expectedType) + else + throwError "invalid `▸` notation, failed to compute motive for the substitution" + else + mkEqNDRec motive h heq | _ => throwUnsupportedSyntax @[builtinTermElab stateRefT] def elabStateRefT : TermElab := fun stx _ => do diff --git a/tests/lean/substBadMotive.lean b/tests/lean/substBadMotive.lean new file mode 100644 index 0000000000..30ccdfae5b --- /dev/null +++ b/tests/lean/substBadMotive.lean @@ -0,0 +1,39 @@ +namespace Ex1 + variable (a : Nat) (i : Fin a) (h : 1 = a) + example : i < a := h ▸ i.2 -- `▸` uses `subst` here +end Ex1 + +namespace Ex2 +def heapifyDown' (a : Array α) (i : Fin a.size) : Array α := sorry +def heapifyDown (a : Array α) (i : Fin a.size) : Array α := + heapifyDown' a ⟨i.1, a.size_swap i i ▸ i.2⟩ -- Error, failed to compute motive, `subst` is not applicable here +end Ex2 + +namespace Ex3 +def heapifyDown (a : Array α) (i : Fin a.size) : Array α := + have : i < i := sorry + heapifyDown a ⟨i.1, a.size_swap i i ▸ i.2⟩ -- Error, failed to compute motive, `subst` is not applicable here +termination_by measure fun ⟨_, a, i⟩ => i.1 +decreasing_by assumption +end Ex3 + +namespace Ex4 +def heapifyDown (lt : α → α → Bool) (a : Array α) (i : Fin a.size) : Array α := + let left := 2 * i.1 + 1 + let right := left + 1 + have left_le : i ≤ left := sorry + have right_le : i ≤ right := sorry + have i_le : i ≤ i := Nat.le_refl _ + have j : {j : Fin a.size // i ≤ j} := if h : left < a.size then + if lt (a.get i) (a.get ⟨left, h⟩) then ⟨⟨left, h⟩, left_le⟩ else ⟨i, i_le⟩ else ⟨i, i_le⟩ + have j := if h : right < a.size then + if lt (a.get j) (a.get ⟨right, h⟩) then ⟨⟨right, h⟩, right_le⟩ else j else j + if h : i ≠ j then + let a' := a.swap i j + have : a'.size - j < a.size - i := sorry + heapifyDown lt a' ⟨j.1.1, a.size_swap i j ▸ j.1.2⟩ -- Error, failed to compute motive, `subst` is not applicable here + else + a +termination_by measure fun ⟨_, _, a, i⟩ => a.size - i.1 +decreasing_by assumption +end Ex4 diff --git a/tests/lean/substBadMotive.lean.expected.out b/tests/lean/substBadMotive.lean.expected.out new file mode 100644 index 0000000000..9737a97463 --- /dev/null +++ b/tests/lean/substBadMotive.lean.expected.out @@ -0,0 +1,8 @@ +substBadMotive.lean:7:61-7:66: warning: declaration uses 'sorry' +substBadMotive.lean:9:23-9:44: error: invalid `▸` notation, failed to compute motive for the substitution +substBadMotive.lean:14:18-14:23: warning: declaration uses 'sorry' +substBadMotive.lean:15:22-15:43: error: invalid `▸` notation, failed to compute motive for the substitution +substBadMotive.lean:24:29-24:34: warning: declaration uses 'sorry' +substBadMotive.lean:25:31-25:36: warning: declaration uses 'sorry' +substBadMotive.lean:33:39-33:44: warning: declaration uses 'sorry' +substBadMotive.lean:34:30-34:53: error: invalid `▸` notation, failed to compute motive for the substitution