From b0fe1e5d10aed9e62949511eecc20b0c6e9fd97f Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Mon, 6 Dec 2021 17:41:51 -0800 Subject: [PATCH] feat: add Tomas Skrivan's TC resolution improvement This commit implements the TC resolution improvement suggested by Tomas at #815. Closes #815. --- src/Lean/Meta/SynthInstance.lean | 74 +++++++++++++++++++++++++++++++- tests/lean/run/815.lean | 47 ++++++++++++++++++++ 2 files changed, 120 insertions(+), 1 deletion(-) create mode 100644 tests/lean/run/815.lean diff --git a/src/Lean/Meta/SynthInstance.lean b/src/Lean/Meta/SynthInstance.lean index 6203e203a0..46ccdeb75a 100644 --- a/src/Lean/Meta/SynthInstance.lean +++ b/src/Lean/Meta/SynthInstance.lean @@ -422,6 +422,58 @@ def addAnswer (cNode : ConsumerNode) : SynthM Unit := do modify fun s => { s with tableEntries := s.tableEntries.insert key newEntry } entry.waiters.forM (wakeUp answer) +/-- + Return `true` if a type of the form `(a_1 : A_1) → ... → (a_n : A_n) → B` has an unused argument `a_i`. + + Remark: This is syntactic check and no reduction is performed. +-/ +private def hasUnusedArguments : Expr → Bool + | Expr.forallE _ d b _ => b.hasLooseBVar 0 || hasUnusedArguments b + | _ => false + +/-- + If the type of the metavariable `mvar` has unused argument, return a pair `(α, transformer)` + where `α` is a new type without the unused arguments and the `transformer` is a function for coverting a + solution with type `α` into a value that can be assigned to `mvar`. + Example: suppose `mvar` has type `(a : A) → (b : B a) → (c : C a) → D a c`, the result is the pair + ``` + ((a : A) → (c : C a) → D a c, + fun (f : (a : A) → (c : C a) → D a c) (a : A) (b : B a) (c : C a) => f a c + ) + ``` + + This method is used to improve the effectiveness of the TC resolution procedure. It was suggested and prototyped by + Tomas Skrivan. It improves the support for instances of type `a : A → C` where `a` does not appear in class `C`. + When we look for such an instance it is enough to look for an instance `c : C` and then return `fun _ => c`. + + Tomas' approach makes sure that instance of a type like `a : A → C` never gets tabled/cached. More on that later. + At the core is the this methos. it takes an expression E and does two things: + + The modification to TC resolution works this way: We are looking for an instance of `E`, if it is tabled + just get it as normal, but if not first remove all unused arguments producing `E'`. Now we look up the table again but + for `E'`. If it exists, use the transforme to create E. If it does not exists, create a new goal `E'`. +-/ +private def removeUnusedArguments? (mctx : MetavarContext) (mvar : Expr) : MetaM (Option (Expr × Expr)) := + withMCtx mctx do + let mvarType ← instantiateMVars (← inferType mvar) + if !hasUnusedArguments mvarType then + return none + else + forallTelescope mvarType fun xs body => do + let ys ← xs.foldrM (init := []) fun x ys => do + if body.containsFVar x.fvarId! then + return x :: ys + else if (← ys.anyM fun y => return (← inferType y).containsFVar x.fvarId!) then + return x :: ys + else + return ys + let ys := ys.toArray + let mvarType' ← mkForallFVars ys body + withLocalDeclD `redf mvarType' fun f => do + let transformer ← mkLambdaFVars #[f] (← mkLambdaFVars xs (mkAppN f ys)) + trace[Meta.synthInstance.unusedArgs] "{mvarType}\nhas unused arguments, reduced type{indentExpr mvarType'}\nTransformer{indentExpr transformer}" + return some (mvarType', transformer) + /-- Process the next subgoal in the given consumer node. -/ def consume (cNode : ConsumerNode) : SynthM Unit := match cNode.subgoals with @@ -431,7 +483,27 @@ def consume (cNode : ConsumerNode) : SynthM Unit := let key ← mkTableKeyFor cNode.mctx mvar let entry? ← findEntry? key match entry? with - | none => newSubgoal cNode.mctx key mvar waiter + | none => + -- Remove unused arguments and try again, see comment at `removeUnusedArguments?` + match (← removeUnusedArguments? cNode.mctx mvar) with + | none => newSubgoal cNode.mctx key mvar waiter + | some (mvarType', transformer) => + let key' ← mkTableKey cNode.mctx mvarType' + match (← findEntry? key') with + | none => do + let (mctx', mvar') ← withMCtx cNode.mctx do + let mvar' ← mkFreshExprMVar mvarType' + return (← getMCtx, mvar') + newSubgoal mctx' key' mvar' (Waiter.consumerNode { cNode with mctx := mctx', subgoals := mvar'::cNode.subgoals }) + | some entry' => do + let answers' ← entry'.answers.mapM fun a => withMCtx cNode.mctx do + let trAnswr := Expr.betaRev transformer #[← instantiateMVars a.result.expr] + let trAnswrType ← inferType trAnswr + { a with result.expr := trAnswr, resultType := trAnswrType } + modify fun s => + { s with + resumeStack := answers'.foldl (fun s answer => s.push (cNode, answer)) s.resumeStack, + tableEntries := s.tableEntries.insert key' { entry' with waiters := entry'.waiters.push waiter } } | some entry => modify fun s => { s with resumeStack := entry.answers.foldl (fun s answer => s.push (cNode, answer)) s.resumeStack, diff --git a/tests/lean/run/815.lean b/tests/lean/run/815.lean new file mode 100644 index 0000000000..0105228f4a --- /dev/null +++ b/tests/lean/run/815.lean @@ -0,0 +1,47 @@ +def is_smooth {α β} (f : α → β) : Prop := sorry + +class IsSmooth {α β} (f : α → β) : Prop where + (proof : is_smooth f) + +instance identity : IsSmooth fun a : α => a := sorry +instance const (b : β) : IsSmooth fun a : α => b := sorry +instance swap (f : α → β → γ) [∀ a, IsSmooth (f a)] : IsSmooth (λ b a => f a b) := sorry +instance parm (f : α → β → γ) [IsSmooth f] (b : β) : IsSmooth (λ a => f a b) := sorry +instance comp (f : β → γ) (g : α → β) [IsSmooth f] [IsSmooth g] : IsSmooth (fun a => f (g a)) := sorry +instance diag (f : β → δ → γ) (g : α → β) (h : α → δ) [IsSmooth f] [∀ b, IsSmooth (f b)] [IsSmooth g] [IsSmooth h] : IsSmooth (λ a => f (g a) (h a)) := sorry + +example (f : β → δ → γ) [IsSmooth f] (g : α → β) [IsSmooth g] (d : δ) : IsSmooth (λ a => f (g a) d) := by infer_instance +example (f : β → δ → γ) [IsSmooth f] (g : α → β) [IsSmooth g] : IsSmooth (λ a d => f (g a) d) := by infer_instance +example (f : β → δ → γ) [IsSmooth f] (g : α → β) [IsSmooth g] (h : α → α) [IsSmooth h] (d : δ) : IsSmooth (λ a => f (g (h a)) d) := by infer_instance +example (f : α → β → γ) [∀ a, IsSmooth (f a)] : IsSmooth (λ b a => f a b) := by infer_instance +example (f : α → β → γ → δ) [∀ a b, IsSmooth (f a b)] : IsSmooth (λ c b a => f a b c) := by infer_instance +example (f : α → β → γ → δ) [∀ a b, IsSmooth (f a b)] : IsSmooth (λ c a b => f a b c) := by infer_instance +example (f : α → β → γ → δ → ε) [∀ a b c, IsSmooth (f a b c)] : IsSmooth (λ d a b c => f a b c d) := by infer_instance +example (f : α → β → γ) [IsSmooth f] (b : β) : IsSmooth (λ a => f a b) := by infer_instance +example (f : α → β → γ → δ) [IsSmooth f] (b : β) (c : γ) : IsSmooth (λ a => f a b c) := by infer_instance +example (f : α → β → γ → δ) [IsSmooth f] (b : β) : IsSmooth (λ a c => f a b c) := by infer_instance +example (f : α → β → γ → δ) [IsSmooth f] (c : γ) : IsSmooth (λ a b => f a b c) := by infer_instance +example (f : α → β → γ → δ) (b : β) [IsSmooth (λ a => f a b)] : IsSmooth (λ a c => f a b c) := by infer_instance +example (f : α → β → γ) (g : δ → ε → α) (h : δ → ε → β) [IsSmooth f] [∀ a, IsSmooth (f a)] [IsSmooth g] [IsSmooth h] : IsSmooth (λ x y => f (g x y) (h x y)) := by infer_instance +example (f : β → δ → γ) (g : α → β) [IsSmooth f] [∀ b, IsSmooth (f b)] [IsSmooth g] (a : α): IsSmooth (λ (h : α → δ) => f (g a) (h a)) := by infer_instance +example (f : β → δ → γ) (h : α → δ) [IsSmooth f] : IsSmooth (λ (g : α → β) a => f (g a) (h a)) := by infer_instance +example (f : β → δ → γ) [IsSmooth f] (d : δ) : IsSmooth (λ (g : α → β) a => f (g a) d) := by infer_instance +example (f : β → γ) (g : β → β) [IsSmooth f] [IsSmooth g] : IsSmooth (fun x => f (g (g x))) := by infer_instance +example (f : α → β → γ) [∀ a, IsSmooth (f a)] : IsSmooth (λ b a => f a b) := by infer_instance +example (f : α → β → γ → δ) [∀ a b, IsSmooth (f a b)] : IsSmooth (λ c a b => f a b c) := by infer_instance +example (f : α → β → γ → δ → ε) [∀ a b c, IsSmooth (f a b c)] : IsSmooth (λ d a b c => f a b c d) := by infer_instance +example (f : β → δ → γ) [IsSmooth f] (g : α → β) [IsSmooth g] (d : δ) : IsSmooth (λ a => f (g a) d) := by infer_instance +example (f : β → δ → γ) [IsSmooth f] (g : α → β) [IsSmooth g] : IsSmooth (λ a d => f (g a) d) := by infer_instance +example (f : δ → β → γ) [∀ d, IsSmooth (f d)] (g : α → β) [IsSmooth g] : IsSmooth (λ a d => (f d (g a))) := by infer_instance + + +-- Recall Function.comp is not reducible anymore +instance (f : β → γ) (g : α → β) [IsSmooth f] [IsSmooth g] : IsSmooth (f ∘ g) := by + delta Function.comp + infer_instance + +example (f : β → γ) (g : α → β) [IsSmooth f] [IsSmooth g] : IsSmooth (f ∘ g) := by infer_instance + +example (f : β → γ) [IsSmooth f] : IsSmooth λ (g : α → β) => (f ∘ g) := by + delta Function.comp + infer_instance