lean4-htt/src/Lean/Meta/LetToHave.lean
Jason Yuen facc356a0a
chore: fix spelling errors (#10042)
Typos were found with
```
pip install codespell --upgrade
codespell --summary --ignore-words-list enew,forin,fro,happend,hge,ihs,iterm,spred --skip stage0 --check-filenames
codespell --summary --ignore-words-list enew,forin,fro,happend,hge,ihs,iterm,spred --skip stage0 --check-filenames --regex '[A-Z][a-z]*'
codespell --summary --ignore-words-list enew,forin,fro,happend,hge,ihs,iterm,spred --skip stage0 --check-filenames --regex "\b[a-z']*"
```
2025-08-22 07:23:12 +00:00

447 lines
18 KiB
Text
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/-
Copyright (c) 2025 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Kyle Miller
-/
module
prelude
public import Lean.Meta.Check
public import Lean.ReservedNameAction
public import Lean.AddDecl
public import Lean.Meta.Transform
public import Lean.Util.CollectFVars
public import Lean.Util.CollectMVars
public section
/-!
# Transforming nondependent `let`s into `have`s
A `let` expression `let x : t := v; b` is *nondependent* if `fun x : t => b` is type correct.
Nondependent `let`s are those that can be transformed into `have x := v; b`.
This module has a procedure that detects which `let`s are nondependent and does the transformation,
attempting to do so efficiently.
Dependence checking is approximated using the `withTrackingZetaDelta` technique:
given `let x := v; b`, we add a `x := v` declaration to the local context,
and then type check `b` with `withTrackingZetaDelta` enabled to record whether `x` is unfolded.
If `x` is not unfolded, then we know that `b` does not depend on `v`.
This is a conservative check, since `isDefEq` may unfold local definitions unnecessarily.
We do not use `Lean.Meta.check` directly. A naive algorithm would be to do `Meta.check (b.instantiate1 x)`
for each `let` body, which would involve rechecking subexpressions multiple times when there are nested `let`s,
and furthermore we do not need to fully typecheck the body when evaluating dependence.
Instead, we re-implement a type checking algorithm here to be able to interleave checking and transformation.
The trace class `Meta.letToHave` (enabled with the `trace.Meta.letToHave` option) reports statistics.
The transformation has very limited support for metavariables.
Any `let` that contains a metavariable remains a `let` for now.
Optimizations, present and future:
- We can avoid doing the transformation if the expression has no `let`s.
- We can also avoid doing the transformation to `let`-free subexpressions that are not inside a `let`,
however checking for `let`s is O(n), so we only try this for expressions with a small `approxDepth`.
(We can consider precomputing this somehow.)
- The cache is currently responsible for the check.
- We also do it before entering telescopes, to avoid unnecessary fvar overhead.
- If we are not currently inside a `let`, then we do not need to do full typechecking.
- We try to reuse Exprs to promote subexpression sharing.
- We might consider not transforming lets to haves if we are in a proof that is not inside a `let`.
For now we assume `abstractNestedProofs` has already been applied.
-/
namespace Lean.Meta
namespace LetToHave
/--
Detects whether `e` contains at least one `letE (nondep := false)` subexpression.

* `true`: the subexpression must be visited.
* `false`: the subexpression might still need to be visited, but it is safe to skip it
  when we are not currently under any of the `let`s tracked in `Context.letFVars`.
-/
private def hasDepLet (e : Expr) : Bool :=
  (e.find? fun sub => sub matches .letE (nondep := false) ..).isSome
/--
Heuristic deciding whether a subexpression can be left untouched. A `true` answer means it is definitely skippable:
the expression has no free variables, no expression metavariables, an `approxDepth` of at most `maxDepth`,
and no `letE (nondep := false)` subexpression.
The default max depth of `5` was not experimentally optimized, except to see that it was faster than `0`.
-/
private def canSkip (e : Expr) (maxDepth : UInt32 := 5) : Bool :=
  if e.hasFVar || e.hasExprMVar then
    false
  else
    -- The depth test comes first so that the O(n) `hasDepLet` scan only runs on small expressions.
    e.approxDepth ≤ maxDepth && !hasDepLet e
/-- The outcome of visiting an expression: the rewritten expression together with its type, when known. -/
private structure Result where
/-- The transformed expression. -/
expr : Expr
/-- The type of `expr`, if it has been computed. `none` means it has not been computed (yet);
`Result.type` infers it on demand. -/
type? : Option Expr
deriving Inhabited
/-- Allows a `Result` to be used wherever an `Expr` is expected, by projecting out `Result.expr`. -/
private local instance : Coe Result Expr :=
  ⟨Result.expr⟩
/-- Read-only data threaded through the traversal (the reader part of the monad `M`). -/
private structure Context where
/-- The dependent lets we are currently under.
If this list is nonempty, then full typechecking is necessary (see `Context.check`). -/
letFVars : List FVarId := []
/-- Mutable data threaded through the traversal (the state part of the monad `M`). -/
private structure State where
/-- The number of transformed `let` expressions. See `incCount`. -/
count : Nat := 0
/-- Cached results for `visit`, keyed by the visited expression up to structural equality. -/
results : Std.HashMap ExprStructEq Result := {}
/-- The monad of the transformation: `MetaM` extended with a mutable `State` and a read-only `Context`. -/
private abbrev M := ReaderT Context <| StateRefT State MetaM
/--
Returns the type of `r.expr`.
If `r.type?` is already available it is returned as is; otherwise the type is inferred with `inferType`
and the cache entry for `r.expr` is overwritten with a result that carries the type.
-/
private def Result.type (r : Result) : M Expr := do
  match r.type? with
  | some knownType => return knownType
  | none =>
    let inferred ← inferType r.expr
    let updated : Result := { r with type? := inferred }
    modify fun st => { st with results := st.results.insert updated.expr updated }
    return inferred
/-- Tells whether full typechecking is required, which is the case exactly when `letFVars` is nonempty. -/
private def Context.check (ctx : Context) : Bool :=
  match ctx.letFVars with
  | [] => false
  | _ :: _ => true
/--
Runs `m` when full typechecking is required.
Otherwise returns `e` unchanged as a result without a type.
-/
private def whenCheck (e : Expr) (m : M Result) : M Result := do
  unless (← read).check do
    return { expr := e, type? := none }
  m
/-- Runs `m` with the `letFVars` field of the context replaced by `fvars`. -/
private def withLetFVars (fvars : List FVarId) (m : M α) : M α :=
  let setLetFVars (ctx : Context) : Context := { ctx with letFVars := fvars }
  withReader setLetFVars m
/-- Bumps the counter of `let`s that have been transformed into `have`s by one. -/
private def incCount : M Unit :=
  modify fun st => { st with count := st.count.succ }
/--
Looks up `e` in the result cache.
Note that a cached result might have no type, which happens for example if it was visited when `check` is false.
-/
private def findCache? (e : Expr) : M (Option Result) := do
  let key : ExprStructEq := e
  let st ← get
  return st.results[key]?
/--
Returns the cached result for `e` if there is one.
Otherwise applies a cheap `canSkip` check: if `e` can be skipped it is returned unchanged (without a type),
and if not, `m` is run. In both of these cases the outcome is stored in the cache under `e`.
-/
private def checkCache (e : Expr) (m : M Result) : M Result := do
  match ← findCache? e with
  | some cached => return cached
  | none =>
    -- `2` was not experimentally optimized
    let fresh : Result ←
      if canSkip e (maxDepth := 2) then
        pure { expr := e, type? := none }
      else
        m
    modify fun st => { st with results := st.results.insert e fresh }
    return fresh
/-- Like `findCache?`, but answers `none` without consulting the cache when `e` has loose bvars. -/
private def findCacheNoBVars? (e : Expr) : M (Option Result) := do
  if e.hasLooseBVars then
    return none
  findCache? e
/--
Visits a free variable: the expression is unchanged and its type is read from its local declaration.
Throws if the free variable has no declaration.
-/
private def visitFVar (e : Expr) : MetaM Result := do
  let fvarId := e.fvarId!
  match ← fvarId.findDecl? with
  | some decl => return { expr := e, type? := decl.type }
  | none => fvarId.throwUnknown
/--
Given an expression `e` whose definition may be used in an unknown manner (for example, through a metavariable),
marks all fvars in `e` (or accessible through `e`) that can potentially be unfolded.
Assumption: while there may be metavariables in `e` (or in types of fvars present in `e`),
they have already been processed by `checkMVar` or will be processed by it.
-/
private def visitDepExpr (e : Expr) : M Unit := do
-- `visited` prevents processing the same fvar twice; `worklist` holds the expressions still to be scanned.
let mut visited : FVarIdSet := {}
let mut worklist := #[e]
while !worklist.isEmpty do
-- Take the last expression off the worklist, instantiating assigned metavariables before scanning it.
let e ← instantiateMVars worklist.back!
worklist := worklist.pop
for fvarId in (collectFVars {} e).fvarIds do
unless visited.contains fvarId do
visited := visited.insert fvarId
-- Let-variables are recorded with `addZetaDeltaFVarId`; `visitLambdaLet` consults that set
-- (`getZetaDeltaFVarIds`) when deciding whether a `let` may become a `have`.
if ← fvarId.isLetVar then
addZetaDeltaFVarId fvarId
-- The fvar's type is queued so that the fvars occurring in it are scanned too.
-- NOTE(review): indentation is lost in this copy; presumably this runs for every newly visited fvar,
-- not only for let-variables — confirm against upstream.
worklist := worklist.push (← fvarId.getType)
/--
Checks whether the mvar creates a dependency on any let fvars.
Note: the local context of `mvarId` cannot depend on `letFVars`, since it was created outside these `let`s.
The only consideration is delayed assignments and which variables they depend on;
if the fvar is not passed among the `args`, the mvar cannot depend on it.

`args` are the arguments the metavariable is applied to (`visitMVar` passes `#[]` for an unapplied one).
-/
private def checkMVar (mvarId : MVarId) (args : Array Expr) : M Unit := do
-- Only delayed-assigned metavariables need any work; for all others this function does nothing.
if let some { fvars, mvarIdPending } ← getDelayedMVarAssignment? mvarId then
let letFVars := (← read).letFVars
unless fvars.size ≤ args.size do
-- This is an invalid delayed assignment. Mark all `letFVars` to inhibit transformation.
letFVars.forM (addZetaDeltaFVarId ·)
return
-- The fvars of the delayed assignment are looked up in the local context of the pending mvar.
let pendingDecl ← mvarIdPending.getDecl
-- Each abstracted fvar is paired with the argument supplied for it; surplus arguments are not examined here.
for fvar in fvars, arg in args do
let fvarDecl := pendingDecl.lctx.getFVar! fvar
-- When the abstracted fvar is a let-variable, conservatively mark everything reachable
-- from the corresponding argument (see `visitDepExpr`).
if fvarDecl.isLet then
visitDepExpr arg
/--
Visits a metavariable: the expression is unchanged and its type is read from the metavariable's declaration.
When full typechecking is required, the metavariable is also passed to `checkMVar` (with no arguments).
Throws if the metavariable is unknown.
-/
private def visitMVar (e : Expr) : M Result := do
  let mvarId := e.mvarId!
  match ← mvarId.findDecl? with
  | none => throwUnknownMVar mvarId
  | some decl =>
    if (← read).check then
      checkMVar mvarId #[]
    return { expr := e, type? := decl.type }
/--
Visits a constant. The expression is never changed.
When full typechecking is required, checks that the number of universe levels matches the declaration
and computes the type by instantiating the declaration's level parameters.
-/
private def visitConst (e : Expr) : M Result :=
  whenCheck e do
    let .const constName us := e | unreachable!
    let cinfo ← getConstVal constName
    unless cinfo.levelParams.length == us.length do
      throwIncorrectNumberOfLevels constName us
    let type ← instantiateTypeLevelParams cinfo us
    return { expr := e, type? := type }
/--
When checking, makes sure that `r.type?` is of the form `Expr.sort _`.
If the type of `r.expr` is not syntactically a sort, it is put into weak head normal form;
a resulting sort is stored in the result (and in the cache), anything else is a "type expected" error.
When not checking, `r` is returned unchanged.
-/
private def ensureType (r : Result) : M Result := do
  unless (← read).check do
    return r
  let type ← r.type
  let typed : Result := { r with type? := type }
  if type.isSort then
    return typed
  match ← whnf type with
  | .sort u =>
    let normalized : Result := { typed with type? := Expr.sort u }
    modify fun st => { st with results := st.results.insert normalized.expr normalized }
    return normalized
  | _ => throwTypeExpected typed
/--
Rebuilds the application `e` from the visited function `f` and the visited argument `a`.
When checking, the type of `f` must be (or reduce with `whnf` to) a `forallE` whose domain is defeq to
the type of `a`; the resulting type is the codomain instantiated with `a`.
Note: We want to cache all prefixes of each application, hence no `instantiateRevRange`-type logic here.
-/
private def visitApp (e : Expr) (f a : Result) : M Result := do
  unless (← read).check do
    return { expr := e.updateApp! f a, type? := none }
  let fType ← f.type
  -- Only reduce the function type when it is not already syntactically a `forallE`.
  let fType ← if fType.isForall then pure fType else whnf fType
  match fType with
  | .forallE _ d b _ =>
    unless (← isDefEq d (← a.type)) do
      throwAppTypeMismatch f a
    return { expr := e.updateApp! f a, type? := b.instantiate1 a }
  | _ => throwFunctionExpected (mkApp f a)
mutual
/-- Visits `e` and then passes the result through `ensureType`. -/
private partial def visitType (e : Expr) : M Result := do
  ensureType (← visit e)
/--
Visits an application `e` by visiting its head and each of its arguments.
When checking, a metavariable head is first passed to `checkMVar` together with the arguments,
and every application prefix goes through `checkCache` and `visitApp`.
When not checking, the application is only rebuilt and no type is recorded.
-/
private partial def visitAppArgs (e : Expr) : M Result := do
if (← read).check then
if let .mvar mvarId := e.getAppFn then
checkMVar mvarId e.getAppArgs
-- `go` peels off one argument at a time; the function part `f` is looked up in (or added to) the cache.
let rec go (e : Expr) : M Result := do
let Expr.app f a := e | visit e
visitApp e (← checkCache f <| go f) (← visit a)
go e
else
-- If not checking, skip caching each prefix.
let rec go' (e : Expr) : M Expr := do
let Expr.app f a := e | visit e
return e.updateApp! (← go' f) (← visit a)
return { expr := ← go' e, type? := none }
/--
Visits a `forallE` telescope, entering all binders at once.
If `canSkip e` holds, `e` is returned unchanged without a type.
When checking, the result type is the sort whose level combines the levels of the domains and of the body.
-/
private partial def visitForall (e : Expr) : M Result := do
if canSkip e then
return { expr := e, type? := none }
else
go (← getLCtx) #[] #[] e
where
-- `lctx` is the local context extended so far, `fvars` are the fresh fvars of the binders entered,
-- `doms` are the visited binder domains, and `e` is the remaining body (its loose bvars refer to `fvars`).
go (lctx : LocalContext) (fvars : Array Expr) (doms : Array Result) (e : Expr) : M Result := do
-- A cache hit is only possible when `e` has no loose bvars (see `findCacheNoBVars?`).
if let some e' ← findCacheNoBVars? e then
return ← withLCtx lctx {} do finalize fvars doms e'
else
match e with
| .forallE n t b bi =>
-- The domain is visited in the extended context and must be a type.
let t ← withLCtx lctx {} do visitType (t.instantiateRev fvars)
let fvarId ← mkFreshFVarId
let lctx := lctx.mkLocalDecl fvarId n t.expr bi
go lctx (fvars.push (.fvar fvarId)) (doms.push t) b
| _ =>
-- End of the telescope: visit the body with the binders instantiated.
withLCtx lctx {} do
let e' ← visit (e.instantiateRev fvars)
finalize fvars doms e'
-- Rebuilds the `forallE` telescope over `body`; when checking, also computes its sort.
finalize (fvars : Array Expr) (doms : Array Result) (body : Result) : M Result := do
let e' := (← getLCtx).mkForall fvars body
if (← read).check then
-- The body must be a type; its level is folded with the domain levels, innermost domain first.
let bodyLevel := (← ensureType body).type?.get!.sortLevel!
let u ← doms.foldrM (init := bodyLevel) fun dom u =>
return mkLevelIMax' (← dom.type).sortLevel! u
return { expr := e', type? := Expr.sort u }
else
return { expr := e', type? := none }
/--
Visits lambdas, lets, and haves.
Enters the entire telescope at once.
We do not check the cache at each step of the telescope since we assume that there are no unused variables.
If `canSkip e` holds, `e` is returned unchanged without a type.
-/
private partial def visitLambdaLet (e : Expr) : M Result := do
if canSkip e then
return { expr := e, type? := none }
else
go (← getLCtx) #[] e (← read).letFVars
where
/--
Enters a lambda/let/have telescope, checking that each domain type is a type.
For let/have, checks that each value has a type defeq to the domain type
(this check is only performed while `letFVars` is nonempty).
Calls `finalize` once the telescope is constructed.
-/
go (lctx : LocalContext) (fvars : Array Expr) (e : Expr) (letFVars : List FVarId) : M Result := do
-- `inCtx v e` runs the visitor `v` on `e` (with the telescope's fvars instantiated)
-- in the accumulated local context and with the accumulated `letFVars`.
let inCtx (v : Expr → M Result) (e : Expr) : M Result :=
withLCtx lctx {} <| withLetFVars letFVars <| v (e.instantiateRev fvars)
match e with
| .lam n t b bi =>
let t ← inCtx visitType t
let fvarId ← mkFreshFVarId
let lctx := lctx.mkLocalDecl fvarId n t.expr bi
go lctx (fvars.push (.fvar fvarId)) b letFVars
| .letE n t v b nondep =>
let t ← inCtx visitType t
let v ← inCtx visit v
-- The value's type is only compared with the declared type when we are already under a tracked `let`.
unless letFVars.isEmpty do withLCtx' lctx do
let vType ← v.type
unless (← isDefEq t vType) do
throwError "invalid let declaration, term{indentExpr v}{← mkHasTypeButIsExpectedMsg vType t}"
let fvarId ← mkFreshFVarId
let lctx := lctx.mkLetDecl fvarId n t v nondep
-- Only `let`s (`nondep = false`) are added to `letFVars`; existing `have`s are not tracked.
let letFVars := if nondep then letFVars else fvarId :: letFVars
go lctx (fvars.push (.fvar fvarId)) b letFVars
| _ =>
-- End of the telescope: visit the body, then rebuild the binders around it.
inCtx (finalize fvars <=< visit) e
/--
This function rebuilds the expression and converts `let`s to `have`s when possible.
-/
finalize (fvars : Array Expr) (body : Result) : M Result := do
trace[Meta.letToHave.debug] "finalize {fvars},{indentD m!"{body.expr} : {body.type?}"}"
-- Abstract all telescope fvars in the body (and in its type, if known) before re-adding binders.
let body' := {
expr := body.expr.abstract fvars
type? := body.type?.map (·.abstract fvars)
}
-- Binders are re-added via `Nat.foldRevM`; each step wraps the result accumulated so far in `res`.
Nat.foldRevM fvars.size (init := body') fun i _ res => do
let .fvar fvarId := fvars[i] | unreachable!
let some decl ← fvarId.findDecl? | fvarId.throwUnknown
match decl with
| .cdecl _ _ n t bi _ => do
let t := t.abstractRange i fvars
pure {
expr := Expr.lam n t res.expr bi
type? := res.type?.map fun type => Expr.forallE n t type bi
}
| .ldecl _ _ n t v nondep _ => do
-- A `have` stays a `have`. A `let` becomes a `have` (and is counted) exactly when its fvar
-- was not recorded in the zeta-delta set; otherwise it stays a `let`.
let nondep ←
if nondep then pure true
else if !(← getZetaDeltaFVarIds).contains fvarId then incCount; pure true
else pure false
let t := t.abstractRange i fvars
let v := v.abstractRange i fvars
pure {
expr := Expr.letE n t v res.expr nondep
-- If the body's type mentions the bound variable it is wrapped in the same binder;
-- otherwise the loose bvar indices are lowered past the binder.
type? := res.type?.map fun type =>
if type.hasLooseBVar 0 then
Expr.letE n t v type nondep
else
type.lowerLooseBVars 1 1
}
/--
Rebuilds the projection `e` (field `idx` of structure `structName`) from the visited `struct`.
When not checking, no type is computed.
When checking, the type of `struct` is reduced with `whnf` and must be an application of the structure
`structName` to the expected number of arguments; the field type is then read off the constructor type.
Any mismatch results in an "invalid projection" error.
-/
private partial def visitProj (e : Expr) (structName : Name) (idx : Nat) (struct : Result) : M Result := do
unless (← read).check do
return { expr := e.updateProj! struct, type? := none }
let structType ← whnf (← struct.type)
-- `prop` records whether the structure type is a proposition; it restricts the fields checked in the loop below.
let prop ← isProp structType
let failed {α} (_ : Unit) : M α := throwError "invalid projection{indentExpr (mkProj structName idx struct)}\nfrom type{indentExpr structType}"
matchConstStructure structType.getAppFn failed fun structVal structLvls ctorVal => do
unless structVal.name == structName do failed ()
let structTypeArgs := structType.getAppArgs
if structVal.numParams + structVal.numIndices != structTypeArgs.size then
failed ()
-- Start from the type of the constructor applied to the structure's parameters.
let mut ctorType ← inferType <| mkAppN (mkConst ctorVal.name structLvls) structTypeArgs[*...structVal.numParams]
-- `args` collects the projections substituted for earlier fields; `j` is the index from which
-- `args` still has to be instantiated into `ctorType`.
let mut args := #[]
let mut j := 0
let mut lastFieldTy : Expr := default
-- Walk the constructor's field binders up to and including field `idx`.
for i in *...=idx do
unless ctorType.isForall do
ctorType ← whnf <| ctorType.instantiateRevRange j i args
j := i
let .forallE _ dom body _ := ctorType | failed ()
let dom := dom.instantiateRevRange j i args
-- For a proposition-valued structure, every field visited here must itself be a proposition.
if prop then unless ← isProp dom do failed ()
args := args.push <| Expr.proj structName i struct
ctorType := body
lastFieldTy := dom
lastFieldTy := lastFieldTy.cleanupAnnotations
return { expr := e.updateProj! struct, type? := lastFieldTy }
/--
Entry point of the traversal: dispatches on the head constructor of `e`.
Applications, foralls, lambdas/lets and projections go through `checkCache`;
the other constructors are handled directly.
Each call is wrapped in a `Meta.letToHave.debug` trace node.
-/
private partial def visit (e : Expr) : M Result := do
withTraceNode `Meta.letToHave.debug (fun res =>
return m!"{if res.isOk then checkEmoji else crossEmoji} visit (check := {(← read).check}){indentExpr e}") do
match e with
-- Binders are always entered by instantiating with fvars, so a loose bvar here is an error.
| .bvar .. => throwError "unexpected bound variable {e}"
| .fvar .. => visitFVar e
| .mvar .. => visitMVar e
| .sort u => return { expr := e, type? := Expr.sort u.succ }
| .const .. => visitConst e
| .app .. => checkCache e do visitAppArgs e
| .forallE .. => checkCache e do visitForall e
| .lam .. | .letE .. => checkCache e do visitLambdaLet e
| .lit v => return { expr := e, type? := v.type }
-- Metadata is transparent: the wrapped expression is visited and the annotation is put back around it.
| .mdata _ e' => let e' ← visit e'; return { e' with expr := e.updateMData! e' }
| .proj structName idx struct => checkCache e do visitProj e structName idx (← visit struct)
end
/--
Runs the transformation on `e` inside a `Meta.letToHave` trace node.
If `e` has no `letE (nondep := false)` subexpression it is returned as is.
Otherwise `e` is visited with zeta-delta tracking enabled, at transparency `all`,
and the number of transformed `let`s is reported in the trace message.
-/
private def main (e : Expr) : MetaM Expr := do
  Prod.fst <$> withTraceNode `Meta.letToHave (fun
      | .ok (_, msg) => pure m!"{checkEmoji} {msg}"
      | .error ex => pure m!"{crossEmoji} {ex.toMessageData}") do
    if hasDepLet e then
      withTrackingZetaDelta <| withTransparency TransparencyMode.all <| withInferTypeConfig do
        -- Start from an empty `Context` and an empty `State`.
        let (r, s) ← (visit e).run {} |>.run {}
        if s.count != 0 then
          trace[Meta.letToHave] "result:{indentExpr r.expr}"
        else
          trace[Meta.letToHave] "result: (no change)"
        return (r.expr, m!"transformed {s.count} `let` expressions into `have` expressions")
    else
      return (e, "no `let` expressions")
end LetToHave
/--
Transforms nondependent `let` expressions into `have` expressions.
Metavariables in `e` are instantiated first, and the transformation itself runs under `withoutExporting`,
all inside a "let-to-have transformation" profiling section.
The `Meta.letToHave` trace class logs errors and messages.
-/
def letToHave (e : Expr) : MetaM Expr := do
  -- NOTE(review): earlier documentation stated "If `e` is not type correct, returns `e`", but no
  -- exception handler is visible here or in `LetToHave.main` — confirm whether errors propagate.
  let opts ← getOptions
  profileitM Exception "let-to-have transformation" opts do
    withoutExporting (LetToHave.main (← instantiateMVars e))
-- Trace class for the summary messages emitted by `LetToHave.main`.
builtin_initialize registerTraceClass `Meta.letToHave
-- Trace class for the per-node messages emitted by `visit` and `finalize`.
builtin_initialize registerTraceClass `Meta.letToHave.debug
end Lean.Meta