diff --git a/src/Lean/LocalContext.lean b/src/Lean/LocalContext.lean index c7d0caeb77..ddf0ea9a4a 100644 --- a/src/Lean/LocalContext.lean +++ b/src/Lean/LocalContext.lean @@ -141,6 +141,15 @@ def hasExprMVar : LocalDecl → Bool | cdecl (type := t) .. => t.hasExprMVar | ldecl (type := t) (value := v) .. => t.hasExprMVar || v.hasExprMVar +/-- +Set the kind of a `LocalDecl`. +-/ +def setKind : LocalDecl → LocalDeclKind → LocalDecl + | cdecl index fvarId userName type bi _, kind => + cdecl index fvarId userName type bi kind + | ldecl index fvarId userName type value nonDep _, kind => + ldecl index fvarId userName type value nonDep kind + end LocalDecl /-- A LocalContext is an ordered set of local variable declarations. @@ -311,6 +320,13 @@ def renameUserName (lctx : LocalContext) (fromName : Name) (toName : Name) : Loc { fvarIdToDecl := map.insert decl.fvarId decl decls := decls.set decl.index decl } +/-- +Set the kind of the given fvar. +-/ +def setKind (lctx : LocalContext) (fvarId : FVarId) + (kind : LocalDeclKind) : LocalContext := + lctx.modifyLocalDecl fvarId (·.setKind kind) + def setBinderInfo (lctx : LocalContext) (fvarId : FVarId) (bi : BinderInfo) : LocalContext := modifyLocalDecl lctx fvarId fun decl => decl.setBinderInfo bi @@ -451,6 +467,27 @@ def sanitizeNames (lctx : LocalContext) : StateM NameSanitizerState LocalContext modify fun s => s.insert decl.userName pure lctx +/-- +Given an `FVarId`, this function returns the corresponding user name, +but only if the name can be used to recover the original FVarId. +-/ +def getRoundtrippingUserName? (lctx : LocalContext) (fvarId : FVarId) : Option Name := do + let ldecl₁ ← lctx.find? fvarId + let ldecl₂ ← lctx.findFromUserName? ldecl₁.userName + guard <| ldecl₁.fvarId == ldecl₂.fvarId + some ldecl₁.userName + +/-- +Sort the given `FVarId`s by the order in which they appear in `lctx`. If any of +the `FVarId`s do not appear in `lctx`, the result is unspecified. +-/ +def sortFVarsByContextOrder (lctx : LocalContext) (hyps : Array FVarId) : Array FVarId := + let hyps := hyps.map fun fvarId => + match lctx.fvarIdToDecl.find? fvarId with + | none => (0, fvarId) + | some ldecl => (ldecl.index, fvarId) + hyps.qsort (fun h i => h.fst < i.fst) |>.map (·.snd) + end LocalContext /-- Class used to denote that `m` has a local context. -/