From eafd9bc0ad8eeb9fa1d5b1dbf55fe060c53a475f Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Mon, 5 Oct 2020 15:37:34 -0700 Subject: [PATCH] feat: expand `for x in xs` notation --- src/Lean/Elab/Do.lean | 85 +++++++++++++++++++++++---------- tests/lean/run/doNotation2.lean | 46 ++++++++++++++++++ 2 files changed, 106 insertions(+), 25 deletions(-) diff --git a/src/Lean/Elab/Do.lean b/src/Lean/Elab/Do.lean index 5a4af2fe71..78cd4b89b7 100644 --- a/src/Lean/Elab/Do.lean +++ b/src/Lean/Elab/Do.lean @@ -588,7 +588,10 @@ else namespace ToTerm inductive Kind -| regular | forInNestedTerm | forIn | forInMap +| regular +| forInNestedTerm +| forIn +| forInMap (x : Name) structure Context := (m : Syntax) -- Syntax to reference the monad associated with the do notation. @@ -614,12 +617,22 @@ else do uvars ← mkUVarTuple ref; liftM $ mkTuple ref #[val, uvars] +private def mkForInYield (ref : Syntax) : M Syntax := do +u ← mkUVarTuple ref; +`(HasPure.pure (ForInStep.yield $u)) + +private def mkForInMapYield (ref : Syntax) (x : Name) : M Syntax := do +u ← mkUVarTuple ref; +r ← liftM $ mkTuple ref #[mkIdentFrom ref x, u]; +`(HasPure.pure (ForInStep.yield $r)) + def returnToTermCore (ref : Syntax) (val : Syntax) : M Syntax := do ctx ← read; match ctx.kind with | Kind.forInNestedTerm => do u ← mkUVarTuple ref; `(HasPure.pure (DoResult.«return» $u)) | Kind.regular => do r ← mkResultUVarTuple ref val; `(HasPure.pure $r) -| _ => do u ← mkUVarTuple ref; `(HasPure.pure (ForInStep.yield $u)) +| Kind.forIn => mkForInYield ref +| Kind.forInMap x => mkForInMapYield ref x def returnToTerm (ref : Syntax) (val : Syntax) : M Syntax := do r ← returnToTermCore ref val; @@ -630,7 +643,8 @@ ctx ← read; match ctx.kind with | Kind.regular => unreachable! | Kind.forInNestedTerm => do u ← mkUVarTuple ref; `(HasPure.pure (DoResult.«continue» $u)) -| _ => do u ← mkUVarTuple ref; `(HasPure.pure (ForInStep.yield $u)) +| Kind.forIn => mkForInYield ref +| Kind.forInMap x => mkForInMapYield ref x def continueToTerm (ref : Syntax) : M Syntax := do r ← continueToTermCore ref; @@ -641,7 +655,8 @@ ctx ← read; match ctx.kind with | Kind.regular => unreachable! | Kind.forInNestedTerm => do u ← mkUVarTuple ref; `(HasPure.pure (DoResult.«break» $u)) -| _ => do u ← mkUVarTuple ref; `(HasPure.pure (ForInStep.done $u)) +| Kind.forIn => do u ← mkUVarTuple ref; `(HasPure.pure (ForInStep.done $u)) +| Kind.forInMap x => do u ← mkUVarTuple ref; r ← liftM $ mkTuple ref #[mkIdentFrom ref x, u]; `(HasPure.pure (ForInStep.done $r)) def breakToTerm (ref : Syntax) : M Syntax := do r ← breakToTermCore ref; @@ -748,25 +763,9 @@ partial def toTerm : Code → M Syntax | Code.ite ref _ o c t e => do t ← toTerm t; e ← toTerm e; pure $ mkIte ref o c t e | _ => liftM $ Macro.throwError Syntax.missing "WIP" -private def getKindUVars (c : CodeBlock) (forInVar? : Option Name) : Kind × Array Name := -match forInVar? with -| none => - if hasContinueBreak c.code then - (Kind.forInNestedTerm, nameSetToArray c.uvars) - else - (Kind.regular, nameSetToArray c.uvars) -| some forInVar => - if c.uvars.contains forInVar then - let uvars := #[forInVar] ++ nameSetToArray (c.uvars.erase forInVar); - (Kind.forInMap, uvars) - else - (Kind.forIn, nameSetToArray c.uvars) - -def run (c : CodeBlock) (m : Syntax) (forInVar? : Option Name := none) : MacroM (Array Name × Syntax) := do -let code := c.code; -let (kind, uvars) := getKindUVars c forInVar?; +def run (code : Code) (m : Syntax) (uvars : Array Name := #[]) (kind := Kind.regular) : MacroM Syntax := do term ← toTerm code { m := m, kind := kind, uvars := uvars }; -pure (uvars, term) +pure term end ToTerm @@ -779,6 +778,27 @@ structure Context := abbrev M := ReaderT Context TermElabM +@[inline] def withFor {α} (x : M α) : M α := +adaptReader (fun (ctx : Context) => { ctx with insideFor := true }) x + +structure ToForInTermResult := +(isForInMap : Bool) +(uvars : Array Name) +(term : Syntax) + +def toForInTerm (x : Syntax) (forCodeBlock : CodeBlock) : M ToForInTermResult := do +ctx ← read; +let uvars := forCodeBlock.uvars; +if x.isIdent && uvars.contains x.getId then do + -- It is a forInMap + let uvars := nameSetToArray (uvars.erase x.getId); + term ← liftMacroM $ ToTerm.run forCodeBlock.code ctx.m uvars (ToTerm.Kind.forInMap x.getId); + pure ⟨true, uvars, term⟩ +else do + let uvars := nameSetToArray uvars; + term ← liftMacroM $ ToTerm.run forCodeBlock.code ctx.m uvars ToTerm.Kind.forIn; + pure ⟨false, uvars, term⟩ + def ensureInsideFor : M Unit := do ctx ← read; unless ctx.insideFor $ @@ -869,8 +889,23 @@ partial def doSeqToCode : List Syntax → M CodeBlock body ← doSeqToCode (getDoSeqElems doSeq); unless ← liftM $ mkUnless ref cond body; concatWithRest unless - else if k == `Lean.Parser.Term.doFor then - throwError "WIP" + else if k == `Lean.Parser.Term.doFor then withFreshMacroScope do + let ref := doElem; + let x := doElem.getArg 1; + let xs := doElem.getArg 3; + let forElems := getDoSeqElems (doElem.getArg 5); + forCodeBlock ← withFor (doSeqToCode forElems); + ⟨isForInMap, uvars, forInBody⟩ ← toForInTerm x forCodeBlock; + uvarsTuple ← liftMacroM $ mkTuple ref (uvars.map (mkIdentFrom ref)); + auxDo ← if isForInMap then do + forInTerm ← `($(xs).forInMap $uvarsTuple fun $x $uvarsTuple => $forInBody); + `(do let r ← $forInTerm; $uvarsTuple:term := r.2; return r.1) + else do { + forInTerm ← `($(xs).forIn $uvarsTuple fun $x $uvarsTuple => $forInBody); + `(do let r ← $forInTerm; $uvarsTuple:term := r) + }; + let doElemsNew := getDoSeqElems (getDoSeq auxDo); + doSeqToCode (doElemsNew ++ doElems) else if k == `Lean.Parser.Term.doMatch then throwError "WIP" else if k == `Lean.Parser.Term.doTry then @@ -927,7 +962,7 @@ fun stx expectedType? => do m ← mkMonadAlias bindInfo.m; codeBlock ← ToCodeBlock.run stx m; -- trace! `Elab.do ("codeBlock: " ++ Format.line ++ codeBlock.toMessageData); - (_, stxNew) ← liftMacroM $ ToTerm.run { codeBlock with uvars := {} } m; + stxNew ← liftMacroM $ ToTerm.run codeBlock.code m; trace! `Elab.do stxNew; let expectedType := mkApp bindInfo.m bindInfo.α; withMacroExpansion stx stxNew $ elabTermEnsuringType stxNew expectedType diff --git a/tests/lean/run/doNotation2.lean b/tests/lean/run/doNotation2.lean index 9b168186e9..f098312ae5 100644 --- a/tests/lean/run/doNotation2.lean +++ b/tests/lean/run/doNotation2.lean @@ -44,3 +44,49 @@ rfl theorem ex4 (y : Nat) : h 1 y = (1 + 1) + y := rfl + +def sumOdd (xs : List Nat) (threshold : Nat) : Nat := do +let sum := 0 +for x in xs do + if x % 2 == 1 then + sum := sum + x + if sum > threshold then + break + unless x % 2 == 1 do + continue + dbgTrace! ">> x: " ++ toString x +return sum + +#eval sumOdd [1, 2, 3, 4, 5, 6, 7, 9, 11, 101] 10 + +theorem ex5 : sumOdd [1, 2, 3, 4, 5, 6, 7, 9, 11, 101] 10 = 16 := +rfl + +def mapOdd (f : Nat → Nat) (xs : List Nat) : List Nat := do +for x in xs do + if x % 2 == 1 then + x := f x + dbgTrace! ">> mapOdd x: " ++ toString x + +#eval mapOdd (·+10) [1, 2, 3, 4, 5, 6, 7, 9] + +theorem ex6 : mapOdd (·+10) [1, 2, 3, 4, 5, 6, 7, 9] = [11, 2, 13, 4, 15, 6, 17, 19] := +rfl + +-- We need `Id.run` because we still have `Monad Option` +def find? (xs : List Nat) (p : Nat → Bool) : Option Nat := Id.run do +let result := none +for x in xs do + if p x then + result := x + break +return result + +def sumDiff (ps : List (Nat × Nat)) : Nat := do +let sum := 0 +for (x, y) in ps do + sum := sum + x - y +return sum + +theorem ex7 : sumDiff [(2, 1), (10, 5)] = 6 := +rfl