feat: add Tomas Skrivan's TC resolution improvement

This commit implements the TC resolution improvement suggested by
Tomas at #815.

Closes #815.
This commit is contained in:
Leonardo de Moura 2021-12-06 17:41:51 -08:00
parent c97487fd65
commit b0fe1e5d10
2 changed files with 120 additions and 1 deletions

View file

@ -422,6 +422,58 @@ def addAnswer (cNode : ConsumerNode) : SynthM Unit := do
modify fun s => { s with tableEntries := s.tableEntries.insert key newEntry }
entry.waiters.forM (wakeUp answer)
/--
Return `true` if a type of the form `(a_1 : A_1) → ... → (a_n : A_n) → B` has an unused argument `a_i`.
Remark: This is syntactic check and no reduction is performed.
-/
private def hasUnusedArguments : Expr → Bool
| Expr.forallE _ d b _ => b.hasLooseBVar 0 || hasUnusedArguments b
| _ => false
/--
If the type of the metavariable `mvar` has unused argument, return a pair `(α, transformer)`
where `α` is a new type without the unused arguments and the `transformer` is a function for coverting a
solution with type `α` into a value that can be assigned to `mvar`.
Example: suppose `mvar` has type `(a : A) → (b : B a) → (c : C a) → D a c`, the result is the pair
```
((a : A) → (c : C a) → D a c,
fun (f : (a : A) → (c : C a) → D a c) (a : A) (b : B a) (c : C a) => f a c
)
```
This method is used to improve the effectiveness of the TC resolution procedure. It was suggested and prototyped by
Tomas Skrivan. It improves the support for instances of type `a : A → C` where `a` does not appear in class `C`.
When we look for such an instance it is enough to look for an instance `c : C` and then return `fun _ => c`.
Tomas' approach makes sure that instance of a type like `a : A → C` never gets tabled/cached. More on that later.
At the core is the this methos. it takes an expression E and does two things:
The modification to TC resolution works this way: We are looking for an instance of `E`, if it is tabled
just get it as normal, but if not first remove all unused arguments producing `E'`. Now we look up the table again but
for `E'`. If it exists, use the transforme to create E. If it does not exists, create a new goal `E'`.
-/
private def removeUnusedArguments? (mctx : MetavarContext) (mvar : Expr) : MetaM (Option (Expr × Expr)) :=
withMCtx mctx do
let mvarType ← instantiateMVars (← inferType mvar)
if !hasUnusedArguments mvarType then
return none
else
forallTelescope mvarType fun xs body => do
let ys ← xs.foldrM (init := []) fun x ys => do
if body.containsFVar x.fvarId! then
return x :: ys
else if (← ys.anyM fun y => return (← inferType y).containsFVar x.fvarId!) then
return x :: ys
else
return ys
let ys := ys.toArray
let mvarType' ← mkForallFVars ys body
withLocalDeclD `redf mvarType' fun f => do
let transformer ← mkLambdaFVars #[f] (← mkLambdaFVars xs (mkAppN f ys))
trace[Meta.synthInstance.unusedArgs] "{mvarType}\nhas unused arguments, reduced type{indentExpr mvarType'}\nTransformer{indentExpr transformer}"
return some (mvarType', transformer)
/-- Process the next subgoal in the given consumer node. -/
def consume (cNode : ConsumerNode) : SynthM Unit :=
match cNode.subgoals with
@ -431,7 +483,27 @@ def consume (cNode : ConsumerNode) : SynthM Unit :=
let key ← mkTableKeyFor cNode.mctx mvar
let entry? ← findEntry? key
match entry? with
| none => newSubgoal cNode.mctx key mvar waiter
| none =>
-- Remove unused arguments and try again, see comment at `removeUnusedArguments?`
match (← removeUnusedArguments? cNode.mctx mvar) with
| none => newSubgoal cNode.mctx key mvar waiter
| some (mvarType', transformer) =>
let key' ← mkTableKey cNode.mctx mvarType'
match (← findEntry? key') with
| none => do
let (mctx', mvar') ← withMCtx cNode.mctx do
let mvar' ← mkFreshExprMVar mvarType'
return (← getMCtx, mvar')
newSubgoal mctx' key' mvar' (Waiter.consumerNode { cNode with mctx := mctx', subgoals := mvar'::cNode.subgoals })
| some entry' => do
let answers' ← entry'.answers.mapM fun a => withMCtx cNode.mctx do
let trAnswr := Expr.betaRev transformer #[← instantiateMVars a.result.expr]
let trAnswrType ← inferType trAnswr
{ a with result.expr := trAnswr, resultType := trAnswrType }
modify fun s =>
{ s with
resumeStack := answers'.foldl (fun s answer => s.push (cNode, answer)) s.resumeStack,
tableEntries := s.tableEntries.insert key' { entry' with waiters := entry'.waiters.push waiter } }
| some entry => modify fun s =>
{ s with
resumeStack := entry.answers.foldl (fun s answer => s.push (cNode, answer)) s.resumeStack,

47
tests/lean/run/815.lean Normal file
View file

@ -0,0 +1,47 @@
def is_smooth {α β} (f : α → β) : Prop := sorry
class IsSmooth {α β} (f : α → β) : Prop where
(proof : is_smooth f)
instance identity : IsSmooth fun a : α => a := sorry
instance const (b : β) : IsSmooth fun a : α => b := sorry
instance swap (f : α → β → γ) [∀ a, IsSmooth (f a)] : IsSmooth (λ b a => f a b) := sorry
instance parm (f : α → β → γ) [IsSmooth f] (b : β) : IsSmooth (λ a => f a b) := sorry
instance comp (f : β → γ) (g : α → β) [IsSmooth f] [IsSmooth g] : IsSmooth (fun a => f (g a)) := sorry
instance diag (f : β → δ → γ) (g : α → β) (h : α → δ) [IsSmooth f] [∀ b, IsSmooth (f b)] [IsSmooth g] [IsSmooth h] : IsSmooth (λ a => f (g a) (h a)) := sorry
example (f : β → δ → γ) [IsSmooth f] (g : α → β) [IsSmooth g] (d : δ) : IsSmooth (λ a => f (g a) d) := by infer_instance
example (f : β → δ → γ) [IsSmooth f] (g : α → β) [IsSmooth g] : IsSmooth (λ a d => f (g a) d) := by infer_instance
example (f : β → δ → γ) [IsSmooth f] (g : α → β) [IsSmooth g] (h : αα) [IsSmooth h] (d : δ) : IsSmooth (λ a => f (g (h a)) d) := by infer_instance
example (f : α → β → γ) [∀ a, IsSmooth (f a)] : IsSmooth (λ b a => f a b) := by infer_instance
example (f : α → β → γ → δ) [∀ a b, IsSmooth (f a b)] : IsSmooth (λ c b a => f a b c) := by infer_instance
example (f : α → β → γ → δ) [∀ a b, IsSmooth (f a b)] : IsSmooth (λ c a b => f a b c) := by infer_instance
example (f : α → β → γ → δ → ε) [∀ a b c, IsSmooth (f a b c)] : IsSmooth (λ d a b c => f a b c d) := by infer_instance
example (f : α → β → γ) [IsSmooth f] (b : β) : IsSmooth (λ a => f a b) := by infer_instance
example (f : α → β → γ → δ) [IsSmooth f] (b : β) (c : γ) : IsSmooth (λ a => f a b c) := by infer_instance
example (f : α → β → γ → δ) [IsSmooth f] (b : β) : IsSmooth (λ a c => f a b c) := by infer_instance
example (f : α → β → γ → δ) [IsSmooth f] (c : γ) : IsSmooth (λ a b => f a b c) := by infer_instance
example (f : α → β → γ → δ) (b : β) [IsSmooth (λ a => f a b)] : IsSmooth (λ a c => f a b c) := by infer_instance
example (f : α → β → γ) (g : δ → ε → α) (h : δ → ε → β) [IsSmooth f] [∀ a, IsSmooth (f a)] [IsSmooth g] [IsSmooth h] : IsSmooth (λ x y => f (g x y) (h x y)) := by infer_instance
example (f : β → δ → γ) (g : α → β) [IsSmooth f] [∀ b, IsSmooth (f b)] [IsSmooth g] (a : α): IsSmooth (λ (h : α → δ) => f (g a) (h a)) := by infer_instance
example (f : β → δ → γ) (h : α → δ) [IsSmooth f] : IsSmooth (λ (g : α → β) a => f (g a) (h a)) := by infer_instance
example (f : β → δ → γ) [IsSmooth f] (d : δ) : IsSmooth (λ (g : α → β) a => f (g a) d) := by infer_instance
example (f : β → γ) (g : β → β) [IsSmooth f] [IsSmooth g] : IsSmooth (fun x => f (g (g x))) := by infer_instance
example (f : α → β → γ) [∀ a, IsSmooth (f a)] : IsSmooth (λ b a => f a b) := by infer_instance
example (f : α → β → γ → δ) [∀ a b, IsSmooth (f a b)] : IsSmooth (λ c a b => f a b c) := by infer_instance
example (f : α → β → γ → δ → ε) [∀ a b c, IsSmooth (f a b c)] : IsSmooth (λ d a b c => f a b c d) := by infer_instance
example (f : β → δ → γ) [IsSmooth f] (g : α → β) [IsSmooth g] (d : δ) : IsSmooth (λ a => f (g a) d) := by infer_instance
example (f : β → δ → γ) [IsSmooth f] (g : α → β) [IsSmooth g] : IsSmooth (λ a d => f (g a) d) := by infer_instance
example (f : δ → β → γ) [∀ d, IsSmooth (f d)] (g : α → β) [IsSmooth g] : IsSmooth (λ a d => (f d (g a))) := by infer_instance
-- Recall Function.comp is not reducible anymore
instance (f : β → γ) (g : α → β) [IsSmooth f] [IsSmooth g] : IsSmooth (f ∘ g) := by
delta Function.comp
infer_instance
example (f : β → γ) (g : α → β) [IsSmooth f] [IsSmooth g] : IsSmooth (f ∘ g) := by infer_instance
example (f : β → γ) [IsSmooth f] : IsSmooth λ (g : α → β) => (f ∘ g) := by
delta Function.comp
infer_instance