feat: abstract proofs occurring in binders

This commit is contained in:
Leonardo de Moura 2020-09-08 12:28:04 -07:00
parent 0a853b2c44
commit 9151fef49d
4 changed files with 49 additions and 7 deletions

View file

@ -51,6 +51,10 @@ def type : LocalDecl → Expr
| cdecl _ _ _ t _ => t
| ldecl _ _ _ t _ _ => t
def setType : LocalDecl → Expr → LocalDecl
| cdecl idx id n _ bi, t => cdecl idx id n t bi
| ldecl idx id n _ v nd, t => ldecl idx id n t v nd
def binderInfo : LocalDecl → BinderInfo
| cdecl _ _ _ _ bi => bi
| ldecl _ _ _ _ _ _ => BinderInfo.default
@ -63,6 +67,10 @@ def value : LocalDecl → Expr
| cdecl _ _ _ _ _ => panic! "let declaration expected"
| ldecl _ _ _ _ v _ => v
def setValue : LocalDecl → Expr → LocalDecl
| ldecl idx id n t _ nd, v => ldecl idx id n t v nd
| d, _ => d
def updateUserName : LocalDecl → Name → LocalDecl
| cdecl index id _ type bi, userName => cdecl index id userName type bi
| ldecl index id _ type val nd, userName => ldecl index id userName type val nd
@ -214,16 +222,23 @@ match lctx with
{ fvarIdToDecl := map.insert decl.fvarId decl,
decls := decls.set decl.index decl }
def updateBinderInfo (lctx : LocalContext) (fvarId : FVarId) (bi : BinderInfo) : LocalContext :=
/--
Low-level function for updating the local context.
Assumptions about `f`, the resulting nested expressions must be definitionally equal to their original values,
the `index` nor `fvarId` are modified. -/
@[inline] def modifyLocalDecl (lctx : LocalContext) (fvarId : FVarId) (f : LocalDecl → LocalDecl) : LocalContext :=
match lctx with
| { fvarIdToDecl := map, decls := decls } =>
match lctx.find? fvarId with
| none => lctx
| some decl =>
let decl := decl.updateBinderInfo bi;
let decl := f decl;
{ fvarIdToDecl := map.insert decl.fvarId decl,
decls := decls.set decl.index decl }
def updateBinderInfo (lctx : LocalContext) (fvarId : FVarId) (bi : BinderInfo) : LocalContext :=
modifyLocalDecl lctx fvarId fun decl => decl.updateBinderInfo bi
@[export lean_local_ctx_num_indices]
def numIndices (lctx : LocalContext) : Nat :=
lctx.decls.size

View file

@ -31,10 +31,27 @@ mkAuxDefinitionFor lemmaName e
partial def visit : Expr → M Expr
| e =>
if e.isAtomic then pure e
else checkCache e fun e => condM (liftM $ isNonTrivialProof e) (mkAuxLemma e) $ match e with
| Expr.lam _ _ _ _ => lambdaLetTelescope e fun xs b => do b ← visit b; mkLambdaFVars xs b
| Expr.letE _ _ _ _ _ => lambdaLetTelescope e fun xs b => do b ← visit b; mkLambdaFVars xs b
| Expr.forallE _ _ _ _ => forallTelescope e fun xs b => do b ← visit b; mkForallFVars xs b
else do
let visitBinders (xs : Array Expr) (k : M Expr) : M Expr := do {
localInstances ← getLocalInstances;
lctx ← getLCtx;
lctx ← xs.foldlM
(fun (lctx : LocalContext) x => do
let xFVarId := x.fvarId!;
localDecl ← getLocalDecl xFVarId;
type ← visit localDecl.type;
let localDecl := localDecl.setType type;
localDecl ← match localDecl.value? with
| some value => do value ← visit value; pure $ localDecl.setValue value
| none => pure localDecl;
pure $ lctx.modifyLocalDecl xFVarId fun _ => localDecl)
lctx;
withLCtx lctx localInstances k
};
checkCache e fun e => condM (liftM $ isNonTrivialProof e) (mkAuxLemma e) $ match e with
| Expr.lam _ _ _ _ => lambdaLetTelescope e fun xs b => visitBinders xs do b ← visit b; mkLambdaFVars xs b
| Expr.letE _ _ _ _ _ => lambdaLetTelescope e fun xs b => visitBinders xs do b ← visit b; mkLambdaFVars xs b
| Expr.forallE _ _ _ _ => forallTelescope e fun xs b => visitBinders xs do b ← visit b; mkForallFVars xs b
| Expr.mdata _ b _ => do b ← visit b; pure $ e.updateMData! b
| Expr.proj _ _ b _ => do b ← visit b; pure $ e.updateProj! b
| Expr.app _ _ _ => e.withApp fun f args => do args ← args.mapM visit; pure $ mkAppN f args

View file

@ -3,7 +3,7 @@ id (fun (x : ?m.4) => x) : ?m.4 → ?m.4
f 1 (fun (x : Nat) => x) : Nat
0 : Nat
f 1 (fun (x : Nat) => x) : Nat
id : ?m.90 → ?m.90
id : ?m.91 → ?m.91
precissues.lean:15:10: error: expected command
1 : Nat
id ((fun (this : True) => this) True.intro) : True

View file

@ -14,3 +14,13 @@ by {
intro;
assumption
}
def g (i j k : Nat) (a : Array Nat) (h₁ : i < k) (h₂ : k < j) (h₃ : j < a.size) : Nat :=
let vj := a.get ⟨j, h₃⟩;
let vi := a.get ⟨i, Nat.ltTrans h₁ (Nat.ltTrans h₂ h₃)⟩;
vi + vj
set_option pp.all true in
#print g
#check g.proof_1