feat: only allow variables declared with mut to be reassigned

This commit is contained in:
Leonardo de Moura 2020-11-07 16:57:29 -08:00
parent f484be1409
commit 6c6595cd9b
13 changed files with 85 additions and 92 deletions

View file

@ -17,7 +17,7 @@
"end" "this" "using" "using_well_founded" "namespace" "section"
"attribute" "local" "set_option" "extends" "include" "class"
"attributes" "raw" "have" "show" "suffices" "by" "in" "at" "do" "let" "for" "unless" "break" "continue"
"try" "catch" "finally" "where" "rec" "forall" "fun"
"try" "catch" "finally" "where" "rec" "mut" "forall" "fun"
"exists" "if" "then" "else" "from" "init_quot" "return"
"mutual" "def" "run_cmd" "declare_syntax_cat" "syntax" "macro_rules" "macro"
"initialize" "builtin_initialize")

View file

@ -1097,37 +1097,32 @@ def matchNestedTermResult (ref : Syntax) (term : Syntax) (uvars : Array Name) (a
end ToTerm
def isMutableLet (doElem : Syntax) : Bool :=
let kind := doElem.getKind
(kind == `Lean.Parser.Term.doLetArrow || kind == `Lean.Parser.Term.doLet)
&&
!doElem[1].isNone
namespace ToCodeBlock
structure Context :=
(ref : Syntax)
(m : Syntax) -- Syntax representing the monad associated with the do notation.
(varSet : NameSet := {})
(insideFor : Bool := false)
(ref : Syntax)
(m : Syntax) -- Syntax representing the monad associated with the do notation.
(mutableVars : NameSet := {})
(insideFor : Bool := false)
abbrev M := ReaderT Context TermElabM
@[inline] def withNewVars {α} (newVars : Array Name) (x : M α) : M α :=
withReader (fun ctx => { ctx with varSet := insertVars ctx.varSet newVars }) x
builtin_initialize
registerOption `relaxedReassignments { defValue := false, group := "do", descr := "if set to true, then any variable in the local context may be reassigned" }
def getRelaxedReassigments : M Bool := do
return (← getOptions).get `relaxedReassignments false
@[inline] def withNewMutableVars {α} (newVars : Array Name) (mutable : Bool) (x : M α) : M α :=
withReader (fun ctx => if mutable then { ctx with mutableVars := insertVars ctx.mutableVars newVars } else ctx) x
def checkReassignable (xs : Array Name) : M Unit := do
let throwInvalidReassignment (x : Name) : M Unit :=
throwError! "'{x.simpMacroScopes}' cannot be reassigned"
let ctx ← read
for x in xs do
unless ctx.varSet.contains x do
if (← getRelaxedReassigments) then
match (← resolveLocalName x) with
| some (_, []) => pure ()
| _ => throwInvalidReassignment x
else
throwInvalidReassignment x
unless ctx.mutableVars.contains x do
throwInvalidReassignment x
@[inline] def withFor {α} (x : M α) : M α :=
withReader (fun ctx => { ctx with insideFor := true }) x
@ -1203,12 +1198,12 @@ def checkLetArrowRHS (doElem : Syntax) : M Unit := do
def doPatDecl := parser! termParser >> leftArrow >> doElemParser >> optional (" | " >> doElemParser)
``` -/
def doLetArrowToCode (doSeqToCode : List Syntax → M CodeBlock) (doLetArrow : Syntax) (doElems : List Syntax) : M CodeBlock := do
let ref := doLetArrow
let decl := doLetArrow[2]
let ref := doLetArrow
let decl := doLetArrow[2]
if decl.getKind == `Lean.Parser.Term.doIdDecl then
let y := decl[0].getId
let doElem := decl[3]
let k ← withNewVars #[y] (doSeqToCode doElems)
let k ← withNewMutableVars #[y] (isMutableLet doLetArrow) (doSeqToCode doElems)
match isDoExpr? doElem with
| some action => pure $ mkVarDeclCore #[y] doLetArrow k
| none =>
@ -1222,9 +1217,15 @@ def doLetArrowToCode (doSeqToCode : List Syntax → M CodeBlock) (doLetArrow : S
let doElem := decl[2]
let optElse := decl[3]
if optElse.isNone then withFreshMacroScope do
let auxDo ← `(do let discr ← $doElem; let $pattern:term := discr)
let auxDo ←
if isMutableLet doLetArrow then
`(do let discr ← $doElem; let mut $pattern:term := discr)
else
`(do let discr ← $doElem; let $pattern:term := discr)
doSeqToCode $ getDoSeqElems (getDoSeq auxDo) ++ doElems
else
if isMutableLet doLetArrow then
throwError! "'mut' is currently not supported in let-decls with 'else' case"
let contSeq := mkDoSeq doElems.toArray
let elseSeq := mkSingletonDoSeq optElse[1]
let auxDo ← `(do let discr ← $doElem; match discr with | $pattern:term => $contSeq | _ => $elseSeq)
@ -1295,8 +1296,7 @@ def doForToCode (doSeqToCode : List Syntax → M CodeBlock) (doFor : Syntax) (do
let x := doFor[1]
let xs := doFor[3]
let forElems := getDoSeqElems doFor[5]
let newVars := if x.isIdent then #[x.getId] else #[]
let forInBodyCodeBlock ← withNewVars newVars $ withFor (doSeqToCode forElems)
let forInBodyCodeBlock ← withFor (doSeqToCode forElems)
let ⟨uvars, forInBody⟩ ← mkForInBody x forInBodyCodeBlock
let uvarsTuple ← liftMacroM $ mkTuple ref (uvars.map (mkIdentFrom ref))
if hasReturn forInBodyCodeBlock.code then
@ -1335,7 +1335,7 @@ def doMatchToCode (doSeqToCode : List Syntax → M CodeBlock) (doMatch : Syntax)
let pvars ← getPatternsVars patterns.getSepArgs
let vars := getPatternVarNames pvars
let rhs := matchAlt[2]
let rhs ← withNewVars vars $ doSeqToCode (getDoSeqElems rhs)
let rhs ← doSeqToCode (getDoSeqElems rhs)
pure { ref := matchAlt, vars := vars, patterns := patterns, rhs := rhs : Alt CodeBlock }
let matchCode ← mkMatch ref discrs optType alts
concatWith doSeqToCode matchCode doElems
@ -1450,13 +1450,13 @@ partial def doSeqToCode : List Syntax → M CodeBlock
let k := doElem.getKind
if k == `Lean.Parser.Term.doLet then
let vars ← getDoLetVars doElem
mkVarDeclCore vars doElem <$> withNewVars vars (doSeqToCode doElems)
mkVarDeclCore vars doElem <$> withNewMutableVars vars (isMutableLet doElem) (doSeqToCode doElems)
else if k == `Lean.Parser.Term.doHave then
let var := getDoHaveVar doElem
mkVarDeclCore #[var] doElem <$> withNewVars #[var] (doSeqToCode doElems)
mkVarDeclCore #[var] doElem <$> (doSeqToCode doElems)
else if k == `Lean.Parser.Term.doLetRec then
let vars ← getDoLetRecVars doElem
mkVarDeclCore vars doElem <$> withNewVars vars (doSeqToCode doElems)
mkVarDeclCore vars doElem <$> (doSeqToCode doElems)
else if k == `Lean.Parser.Term.doReassign then
let vars ← liftM $ getDoReassignVars doElem
checkReassignable vars

View file

@ -53,11 +53,11 @@ def mkRandMap (max : Nat) : Nat → Map → Array (Nat × Nat) → IO (Map × Ar
def tst3 (seed : Nat) (n : Nat) (max : Nat) : IO Unit :=
do IO.setRandSeed seed
let (m, a) ← mkRandMap max n {} Array.empty
let mut (m, a) ← mkRandMap max n {} Array.empty
check (sz m == a.size)
check (a.all (fun ⟨k, v⟩ => m.find? k == some v))
IO.println ("tst3 size: " ++ toString a.size)
let i := 0
let mut i := 0
for (k, b) in a do
if i % 2 == 0 then
m := m.erase k

View file

@ -14,13 +14,13 @@ for p in xs do
inductive Vector (α : Type) : Nat → Type
| nil : Vector α 0
| cons : α → {n : Nat} → Vector α n → Vector α (n+1)
set_option relaxedReassignments true in
def f4 (b : Bool) (n : Nat) (v : Vector Nat n) : Vector Nat (n+1) := do
let mut v := v
if b then
v := Vector.cons 1 v
Vector.cons 1 v
set_option relaxedReassignments true in
def f5 (y : Nat) (xs : List Nat) : List Bool := do
let mut y := y
for x in xs do
y := true -- invalid reassigned
@ -49,8 +49,8 @@ def f11 (x : Nat) : IO Unit := do
if x > 0 then
IO.println "x is not zero"
IO.mkRef true -- error here as expected
set_option relaxedReassignments true in
def f12 (x : Nat) : IO Unit := do
let mut x := x
if x > 0 then
pure true
else

View file

@ -1,6 +1,6 @@
doNotation1.lean:4:0: error: 'y' cannot be reassigned
doNotation1.lean:8:2: error: 'y' cannot be reassigned
doNotation1.lean:11:0: error: 'p' cannot be reassigned
doNotation1.lean:12:2: error: 'p' cannot be reassigned
doNotation1.lean:20:7: error: invalid reassignment, value has type
Vector Nat (n + 1)
but is expected to have type

View file

@ -139,7 +139,7 @@ registerTraceClass `Meta.mkElim
/- Helper methods for testins mkElim -/
private def getUnusedLevelParam (majors : List Expr) (lhss : List AltLHS) : MetaM Level := do
let s := {}
let mut s := {}
for major in majors do
let major ← instantiateMVars major
let majorType ← inferType major

View file

@ -1,4 +1,3 @@
open Lean
def f : IO Nat :=
@ -122,8 +121,8 @@ else
def f1 (x : Nat) : StateT Nat IO Nat := do
IO.println "hello"
let z := x
let y := x
let mut z := x
let mut y := x
modify (· + 10)
if x > 0 then
y := 3*y

View file

@ -15,7 +15,7 @@ aux x x;
#eval f 10
def g (xs : List Nat) : StateT Nat Id Nat := do
let xs := xs
let mut xs := xs
if xs.isEmpty then
xs := [← get]
dbgTrace! ">>> xs: {xs}"
@ -31,8 +31,8 @@ theorem ex2 : (g [] $.run' 0) = 1 :=
rfl
def h (x : Nat) (y : Nat) : Nat := do
let x := x
let y := y
let mut x := x
let mut y := y
if x > 0 then
let y := x + 1 -- this is a new `y` that shadows the one above
x := y
@ -47,7 +47,7 @@ theorem ex4 (y : Nat) : h 1 y = (1 + 1) + y :=
rfl
def sumOdd (xs : List Nat) (threshold : Nat) : Nat := do
let sum := 0
let mut sum := 0
for x in xs do
if x % 2 == 1 then
sum := sum + x
@ -65,7 +65,7 @@ rfl
-- We need `Id.run` because we still have `Monad Option`
def find? (xs : List Nat) (p : Nat → Bool) : Option Nat := Id.run do
let result := none
let mut result := none
for x in xs do
if p x then
result := x
@ -73,7 +73,7 @@ for x in xs do
return result
def sumDiff (ps : List (Nat × Nat)) : Nat := do
let sum := 0
let mut sum := 0
for (x, y) in ps do
sum := sum + x - y
return sum
@ -103,8 +103,8 @@ IO.println ("isOdd(" ++ toString x ++ "): " ++ toString (isOdd x))
#eval f2 10
def split (xs : List Nat) : List Nat × List Nat := do
let evens := []
let odds := []
let mut evens := []
let mut odds := []
for x in xs.reverse do
if x % 2 == 0 then
evens := x :: evens
@ -119,12 +119,12 @@ def f3 (x : Nat) : IO Bool := do
let y ← cond (x == 0) (do IO.println "hello"; true) false;
!y
set_option relaxedReassignments true in
def f4 (x y : Nat) : Nat × Nat := do
match x with
| 0 => y := y + 1
| _ => x := x + y
return (x, y)
let mut (x, y) := (x, y)
match x with
| 0 => y := y + 1
| _ => x := x + y
return (x, y)
#eval f4 0 10
#eval f4 5 10
@ -135,12 +135,12 @@ rfl
theorem ex10 (x y : Nat) : f4 (x+1) y = ((x+1)+y, y) :=
rfl
set_option relaxedReassignments true in
def f5 (x y : Nat) : Nat × Nat := do
match x with
| 0 => y := y + 1
| z+1 => dbgTrace! "z: {z}"; x := x + y
return (x, y)
let mut (x, y) := (x, y)
match x with
| 0 => y := y + 1
| z+1 => dbgTrace! "z: {z}"; x := x + y
return (x, y)
#eval f5 5 6
@ -148,11 +148,11 @@ theorem ex11 (x y : Nat) : f5 (x+1) y = ((x+1)+y, y) :=
rfl
def f6 (x : Nat) : Nat := do
let x := x
if x > 10 then
return 0
x := x + 1
return x
let mut x := x
if x > 10 then
return 0
x := x + 1
return x
theorem ex12 : f6 11 = 0 :=
rfl

View file

@ -33,7 +33,7 @@ let rec loop (i : Nat) (h : i ≤ as.size) (b : β) : m β := do
loop as.size (Nat.leRefl _) b
def f (x : Nat) (ref : IO.Ref Nat) : IO Nat := do
let x := x
let mut x := x
if x == 0 then
x ← ref.get
IO.println x
@ -43,12 +43,12 @@ def fTest : IO Unit := do
unless (← f 0 (← IO.mkRef 10)) == 11 do throw $ IO.userError "unexpected"
unless (← f 1 (← IO.mkRef 10)) == 2 do throw $ IO.userError "unexpected"
set_option relaxedReassignments true in
def g (x y : Nat) (ref : IO.Ref (Nat × Nat)) : IO (Nat × Nat) := do
if x == 0 then
(x, y) ← ref.get
IO.println ("x: " ++ toString x ++ ", y: " ++ toString y)
return (x, y)
let mut (x, y) := (x, y)
if x == 0 then
(x, y) ← ref.get
IO.println ("x: " ++ toString x ++ ", y: " ++ toString y)
return (x, y)
def gTest : IO Unit := do
unless (← g 2 1 (← IO.mkRef (10, 20))) == (2, 1) do throw $ IO.userError "unexpected"
@ -59,12 +59,12 @@ return ()
macro "ret!" x:term : doElem => `(return $x)
set_option relaxedReassignments true in
def f1 (x : Nat) : Nat := do
if x == 0 then
ret! 100
x := x + 1
ret! x
let mut x := x
if x == 0 then
ret! 100
x := x + 1
ret! x
theorem ex1 : f1 0 = 100 := rfl
theorem ex2 : f1 1 = 2 := rfl
@ -75,10 +75,10 @@ syntax "inc!" ident : doElem
macro_rules
| `(doElem| inc! $x) => `(doElem| $x:ident := $x + 1)
set_option relaxedReassignments true in
def f2 (x : Nat) : Nat := do
inc! x
ret! x
let mut x := x
inc! x
ret! x
theorem ex4 : f2 0 = 1 := rfl
theorem ex5 : f2 3 = 4 := rfl

View file

@ -1,5 +1,3 @@
abbrev M := StateRefT Nat IO
def testM {α} [ToString α] [BEq α] (init : Nat) (expected : α) (x : M α): IO Unit := do
@ -52,7 +50,7 @@ catch
#eval testM 0 2000 $ f3 10
def f4 (xs : List Nat) : M Nat := do
let y := 0
let mut y := 0
for x in xs do
IO.println s!"x: {x}"
try
@ -67,7 +65,7 @@ get
#eval testM 40 19 $ f4 [1, 2, 3, 4, 5, 6]
def f5 (xs : List Nat) : M Nat := do
let y := 0
let mut y := 0
for x in xs do
IO.println s!"x: {x}"
try

View file

@ -1,5 +1,3 @@
abbrev M := StateRefT Nat IO
def testM {α} [ToString α] [BEq α] (init : Nat) (expected : α) (x : M α): IO Unit := do
@ -22,7 +20,7 @@ let v ←
return 1
def f2 (xs : List Nat) : M Nat := do
let sum := 0
let mut sum := 0
for x in xs do
try
dec x
@ -39,7 +37,7 @@ return sum
#eval testM 1 1 $ f2 [1, 100, 200, 300]
def f3 (xs : List Nat) : M Nat := do
let sum := 0
let mut sum := 0
for x in xs do
try
dec x
@ -56,7 +54,7 @@ return sum
#eval testM 1 1 $ f3 [1, 100, 200, 300]
def f4 (xs : Array Nat) : IO Nat := do
let sum := 0
let mut sum := 0
for x in xs do
sum := sum + x
IO.println x
@ -65,7 +63,7 @@ return sum
#eval f4 #[1, 2, 3]
def f5 (xs : Array Nat) : IO Nat := do
let sum := 0
let mut sum := 0
for x in xs[1 : xs.size - 1] do
sum := sum + x
IO.println x

View file

@ -6,7 +6,7 @@ unless (← x) == (← expected) do
throw $ IO.userError "unexpected result"
def f1 (xs : Std.PArray Nat) (top : Nat) : IO Nat := do
let sum := 0
let mut sum := 0
for x in xs do
if x % 2 == 0 then
IO.println s!"x: {x}"
@ -21,7 +21,7 @@ return sum
#eval check (f1 [1, 2, 3, 4, 5, 10, 20].toPersistentArray 10) (pure 16)
def f2 (xs : Std.PArray Nat) (top : Nat) : IO Nat := do
let sum := 0
let mut sum := 0
for x in xs do
if x % 2 == 0 then
IO.println s!"x: {x}"

View file

@ -1,5 +1,3 @@
-- Macro for the `syntax` category
macro "many " x:stx : stx => `(stx| ($x)*)
@ -7,7 +5,7 @@ syntax "sum! " (many term:max) : term
macro_rules
| `(sum! $xs*) => do
let r ← `(0)
let mut r ← `(0)
for x in xs do
r ← `($r + $x)
return r