diff --git a/src/Init/Lean/Elab/Declaration.lean b/src/Init/Lean/Elab/Declaration.lean index 25464a78ca..47189daebf 100644 --- a/src/Init/Lean/Elab/Declaration.lean +++ b/src/Init/Lean/Elab/Declaration.lean @@ -101,7 +101,7 @@ withDeclId declId $ fun name => do type ← Term.mkForall typeStx xs type; (type, _) ← Term.mkForallUsedOnly typeStx vars type; type ← Term.levelMVarToParam type; - let usedParams := collectLevelParams type; + let usedParams := (collectLevelParams {} type).params; let levelParams := sortDeclLevelParams explictLevelNames usedParams; pure $ Declaration.axiomDecl { name := declName, diff --git a/src/Init/Lean/Util/CollectLevelParams.lean b/src/Init/Lean/Util/CollectLevelParams.lean index 53c829e0b2..737d89453a 100644 --- a/src/Init/Lean/Util/CollectLevelParams.lean +++ b/src/Init/Lean/Util/CollectLevelParams.lean @@ -15,47 +15,42 @@ structure State := (visitedExpr : ExprSet := {}) (params : Array Name := #[]) -abbrev M := StateM State +instance State.inhabited : Inhabited State := ⟨{}⟩ -@[inline] def visitLevel (f : Level → M Unit) (u : Level) : M Unit := -if !u.hasParam then pure () -else do - s ← get; - if s.visitedLevel.contains u then pure () - else do - modify $ fun s => { visitedLevel := s.visitedLevel.insert u, .. s }; - f u +abbrev Visitor := State → State -partial def collect : Level → M Unit +@[inline] def visitLevel (f : Level → Visitor) (u : Level) : Visitor := +fun s => + if !u.hasParam || s.visitedLevel.contains u then s + else f u { visitedLevel := s.visitedLevel.insert u, .. s } + +partial def collect : Level → Visitor | Level.succ v _ => visitLevel collect v -| Level.max u v _ => do visitLevel collect u; visitLevel collect v -| Level.imax u v _ => do visitLevel collect u; visitLevel collect v -| Level.param n _ => modify $ fun s => { params := s.params.push n, .. s } -| _ => pure () +| Level.max u v _ => visitLevel collect v ∘ visitLevel collect u +| Level.imax u v _ => visitLevel collect v ∘ visitLevel collect u +| Level.param n _ => fun s => { params := s.params.push n, .. s } +| _ => id -@[inline] def visitExpr (f : Expr → M Unit) (e : Expr) : M Unit := -if !e.hasLevelParam then pure () -else do - s ← get; - if s.visitedExpr.contains e then pure () - else do - modify $ fun s => { visitedExpr := s.visitedExpr.insert e, .. s }; - f e +@[inline] def visitExpr (f : Expr → Visitor) (e : Expr) : Visitor := +fun s => + if !e.hasLevelParam then s + else if s.visitedExpr.contains e then s + else f e { visitedExpr := s.visitedExpr.insert e, .. s } -partial def main : Expr → M Unit +partial def main : Expr → Visitor | Expr.proj _ _ s _ => visitExpr main s -| Expr.forallE _ d b _ => do visitExpr main d; visitExpr main b -| Expr.lam _ d b _ => do visitExpr main d; visitExpr main b -| Expr.letE _ t v b _ => do visitExpr main t; visitExpr main v; visitExpr main b -| Expr.app f a _ => do visitExpr main f; visitExpr main a +| Expr.forallE _ d b _ => visitExpr main b ∘ visitExpr main d +| Expr.lam _ d b _ => visitExpr main b ∘ visitExpr main d +| Expr.letE _ t v b _ => visitExpr main b ∘ visitExpr main v ∘ visitExpr main t +| Expr.app f a _ => visitExpr main a ∘ visitExpr main f | Expr.mdata _ b _ => visitExpr main b -| Expr.const _ us _ => us.forM (visitLevel collect) +| Expr.const _ us _ => fun s => us.foldl (fun s u => visitLevel collect u s) s | Expr.sort u _ => visitLevel collect u -| _ => pure () +| _ => id end CollectLevelParams -def collectLevelParams (e : Expr) : Array Name := -(CollectLevelParams.main e {}).2.params +def collectLevelParams (s : CollectLevelParams.State) (e : Expr) : CollectLevelParams.State := +CollectLevelParams.main e s end Lean