feat: add Tomas Skrivan's TC resolution improvement
This commit implements the TC resolution improvement suggested by Tomas at #815. Closes #815.
This commit is contained in:
parent
c97487fd65
commit
b0fe1e5d10
2 changed files with 120 additions and 1 deletions
|
|
@ -422,6 +422,58 @@ def addAnswer (cNode : ConsumerNode) : SynthM Unit := do
|
|||
modify fun s => { s with tableEntries := s.tableEntries.insert key newEntry }
|
||||
entry.waiters.forM (wakeUp answer)
|
||||
|
||||
/--
|
||||
Return `true` if a type of the form `(a_1 : A_1) → ... → (a_n : A_n) → B` has an unused argument `a_i`.
|
||||
|
||||
Remark: This is syntactic check and no reduction is performed.
|
||||
-/
|
||||
private def hasUnusedArguments : Expr → Bool
|
||||
| Expr.forallE _ d b _ => b.hasLooseBVar 0 || hasUnusedArguments b
|
||||
| _ => false
|
||||
|
||||
/--
|
||||
If the type of the metavariable `mvar` has unused argument, return a pair `(α, transformer)`
|
||||
where `α` is a new type without the unused arguments and the `transformer` is a function for coverting a
|
||||
solution with type `α` into a value that can be assigned to `mvar`.
|
||||
Example: suppose `mvar` has type `(a : A) → (b : B a) → (c : C a) → D a c`, the result is the pair
|
||||
```
|
||||
((a : A) → (c : C a) → D a c,
|
||||
fun (f : (a : A) → (c : C a) → D a c) (a : A) (b : B a) (c : C a) => f a c
|
||||
)
|
||||
```
|
||||
|
||||
This method is used to improve the effectiveness of the TC resolution procedure. It was suggested and prototyped by
|
||||
Tomas Skrivan. It improves the support for instances of type `a : A → C` where `a` does not appear in class `C`.
|
||||
When we look for such an instance it is enough to look for an instance `c : C` and then return `fun _ => c`.
|
||||
|
||||
Tomas' approach makes sure that instance of a type like `a : A → C` never gets tabled/cached. More on that later.
|
||||
At the core is the this methos. it takes an expression E and does two things:
|
||||
|
||||
The modification to TC resolution works this way: We are looking for an instance of `E`, if it is tabled
|
||||
just get it as normal, but if not first remove all unused arguments producing `E'`. Now we look up the table again but
|
||||
for `E'`. If it exists, use the transforme to create E. If it does not exists, create a new goal `E'`.
|
||||
-/
|
||||
private def removeUnusedArguments? (mctx : MetavarContext) (mvar : Expr) : MetaM (Option (Expr × Expr)) :=
|
||||
withMCtx mctx do
|
||||
let mvarType ← instantiateMVars (← inferType mvar)
|
||||
if !hasUnusedArguments mvarType then
|
||||
return none
|
||||
else
|
||||
forallTelescope mvarType fun xs body => do
|
||||
let ys ← xs.foldrM (init := []) fun x ys => do
|
||||
if body.containsFVar x.fvarId! then
|
||||
return x :: ys
|
||||
else if (← ys.anyM fun y => return (← inferType y).containsFVar x.fvarId!) then
|
||||
return x :: ys
|
||||
else
|
||||
return ys
|
||||
let ys := ys.toArray
|
||||
let mvarType' ← mkForallFVars ys body
|
||||
withLocalDeclD `redf mvarType' fun f => do
|
||||
let transformer ← mkLambdaFVars #[f] (← mkLambdaFVars xs (mkAppN f ys))
|
||||
trace[Meta.synthInstance.unusedArgs] "{mvarType}\nhas unused arguments, reduced type{indentExpr mvarType'}\nTransformer{indentExpr transformer}"
|
||||
return some (mvarType', transformer)
|
||||
|
||||
/-- Process the next subgoal in the given consumer node. -/
|
||||
def consume (cNode : ConsumerNode) : SynthM Unit :=
|
||||
match cNode.subgoals with
|
||||
|
|
@ -431,7 +483,27 @@ def consume (cNode : ConsumerNode) : SynthM Unit :=
|
|||
let key ← mkTableKeyFor cNode.mctx mvar
|
||||
let entry? ← findEntry? key
|
||||
match entry? with
|
||||
| none => newSubgoal cNode.mctx key mvar waiter
|
||||
| none =>
|
||||
-- Remove unused arguments and try again, see comment at `removeUnusedArguments?`
|
||||
match (← removeUnusedArguments? cNode.mctx mvar) with
|
||||
| none => newSubgoal cNode.mctx key mvar waiter
|
||||
| some (mvarType', transformer) =>
|
||||
let key' ← mkTableKey cNode.mctx mvarType'
|
||||
match (← findEntry? key') with
|
||||
| none => do
|
||||
let (mctx', mvar') ← withMCtx cNode.mctx do
|
||||
let mvar' ← mkFreshExprMVar mvarType'
|
||||
return (← getMCtx, mvar')
|
||||
newSubgoal mctx' key' mvar' (Waiter.consumerNode { cNode with mctx := mctx', subgoals := mvar'::cNode.subgoals })
|
||||
| some entry' => do
|
||||
let answers' ← entry'.answers.mapM fun a => withMCtx cNode.mctx do
|
||||
let trAnswr := Expr.betaRev transformer #[← instantiateMVars a.result.expr]
|
||||
let trAnswrType ← inferType trAnswr
|
||||
{ a with result.expr := trAnswr, resultType := trAnswrType }
|
||||
modify fun s =>
|
||||
{ s with
|
||||
resumeStack := answers'.foldl (fun s answer => s.push (cNode, answer)) s.resumeStack,
|
||||
tableEntries := s.tableEntries.insert key' { entry' with waiters := entry'.waiters.push waiter } }
|
||||
| some entry => modify fun s =>
|
||||
{ s with
|
||||
resumeStack := entry.answers.foldl (fun s answer => s.push (cNode, answer)) s.resumeStack,
|
||||
|
|
|
|||
47
tests/lean/run/815.lean
Normal file
47
tests/lean/run/815.lean
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
def is_smooth {α β} (f : α → β) : Prop := sorry
|
||||
|
||||
class IsSmooth {α β} (f : α → β) : Prop where
|
||||
(proof : is_smooth f)
|
||||
|
||||
instance identity : IsSmooth fun a : α => a := sorry
|
||||
instance const (b : β) : IsSmooth fun a : α => b := sorry
|
||||
instance swap (f : α → β → γ) [∀ a, IsSmooth (f a)] : IsSmooth (λ b a => f a b) := sorry
|
||||
instance parm (f : α → β → γ) [IsSmooth f] (b : β) : IsSmooth (λ a => f a b) := sorry
|
||||
instance comp (f : β → γ) (g : α → β) [IsSmooth f] [IsSmooth g] : IsSmooth (fun a => f (g a)) := sorry
|
||||
instance diag (f : β → δ → γ) (g : α → β) (h : α → δ) [IsSmooth f] [∀ b, IsSmooth (f b)] [IsSmooth g] [IsSmooth h] : IsSmooth (λ a => f (g a) (h a)) := sorry
|
||||
|
||||
example (f : β → δ → γ) [IsSmooth f] (g : α → β) [IsSmooth g] (d : δ) : IsSmooth (λ a => f (g a) d) := by infer_instance
|
||||
example (f : β → δ → γ) [IsSmooth f] (g : α → β) [IsSmooth g] : IsSmooth (λ a d => f (g a) d) := by infer_instance
|
||||
example (f : β → δ → γ) [IsSmooth f] (g : α → β) [IsSmooth g] (h : α → α) [IsSmooth h] (d : δ) : IsSmooth (λ a => f (g (h a)) d) := by infer_instance
|
||||
example (f : α → β → γ) [∀ a, IsSmooth (f a)] : IsSmooth (λ b a => f a b) := by infer_instance
|
||||
example (f : α → β → γ → δ) [∀ a b, IsSmooth (f a b)] : IsSmooth (λ c b a => f a b c) := by infer_instance
|
||||
example (f : α → β → γ → δ) [∀ a b, IsSmooth (f a b)] : IsSmooth (λ c a b => f a b c) := by infer_instance
|
||||
example (f : α → β → γ → δ → ε) [∀ a b c, IsSmooth (f a b c)] : IsSmooth (λ d a b c => f a b c d) := by infer_instance
|
||||
example (f : α → β → γ) [IsSmooth f] (b : β) : IsSmooth (λ a => f a b) := by infer_instance
|
||||
example (f : α → β → γ → δ) [IsSmooth f] (b : β) (c : γ) : IsSmooth (λ a => f a b c) := by infer_instance
|
||||
example (f : α → β → γ → δ) [IsSmooth f] (b : β) : IsSmooth (λ a c => f a b c) := by infer_instance
|
||||
example (f : α → β → γ → δ) [IsSmooth f] (c : γ) : IsSmooth (λ a b => f a b c) := by infer_instance
|
||||
example (f : α → β → γ → δ) (b : β) [IsSmooth (λ a => f a b)] : IsSmooth (λ a c => f a b c) := by infer_instance
|
||||
example (f : α → β → γ) (g : δ → ε → α) (h : δ → ε → β) [IsSmooth f] [∀ a, IsSmooth (f a)] [IsSmooth g] [IsSmooth h] : IsSmooth (λ x y => f (g x y) (h x y)) := by infer_instance
|
||||
example (f : β → δ → γ) (g : α → β) [IsSmooth f] [∀ b, IsSmooth (f b)] [IsSmooth g] (a : α): IsSmooth (λ (h : α → δ) => f (g a) (h a)) := by infer_instance
|
||||
example (f : β → δ → γ) (h : α → δ) [IsSmooth f] : IsSmooth (λ (g : α → β) a => f (g a) (h a)) := by infer_instance
|
||||
example (f : β → δ → γ) [IsSmooth f] (d : δ) : IsSmooth (λ (g : α → β) a => f (g a) d) := by infer_instance
|
||||
example (f : β → γ) (g : β → β) [IsSmooth f] [IsSmooth g] : IsSmooth (fun x => f (g (g x))) := by infer_instance
|
||||
example (f : α → β → γ) [∀ a, IsSmooth (f a)] : IsSmooth (λ b a => f a b) := by infer_instance
|
||||
example (f : α → β → γ → δ) [∀ a b, IsSmooth (f a b)] : IsSmooth (λ c a b => f a b c) := by infer_instance
|
||||
example (f : α → β → γ → δ → ε) [∀ a b c, IsSmooth (f a b c)] : IsSmooth (λ d a b c => f a b c d) := by infer_instance
|
||||
example (f : β → δ → γ) [IsSmooth f] (g : α → β) [IsSmooth g] (d : δ) : IsSmooth (λ a => f (g a) d) := by infer_instance
|
||||
example (f : β → δ → γ) [IsSmooth f] (g : α → β) [IsSmooth g] : IsSmooth (λ a d => f (g a) d) := by infer_instance
|
||||
example (f : δ → β → γ) [∀ d, IsSmooth (f d)] (g : α → β) [IsSmooth g] : IsSmooth (λ a d => (f d (g a))) := by infer_instance
|
||||
|
||||
|
||||
-- Recall Function.comp is not reducible anymore
|
||||
instance (f : β → γ) (g : α → β) [IsSmooth f] [IsSmooth g] : IsSmooth (f ∘ g) := by
|
||||
delta Function.comp
|
||||
infer_instance
|
||||
|
||||
example (f : β → γ) (g : α → β) [IsSmooth f] [IsSmooth g] : IsSmooth (f ∘ g) := by infer_instance
|
||||
|
||||
example (f : β → γ) [IsSmooth f] : IsSmooth λ (g : α → β) => (f ∘ g) := by
|
||||
delta Function.comp
|
||||
infer_instance
|
||||
Loading…
Add table
Reference in a new issue