feat: expand for x in xs notation

This commit is contained in:
Leonardo de Moura 2020-10-05 15:37:34 -07:00
parent ac16393ae9
commit eafd9bc0ad
2 changed files with 106 additions and 25 deletions

View file

@ -588,7 +588,10 @@ else
namespace ToTerm
inductive Kind
| regular | forInNestedTerm | forIn | forInMap
| regular
| forInNestedTerm
| forIn
| forInMap (x : Name)
structure Context :=
(m : Syntax) -- Syntax to reference the monad associated with the do notation.
@ -614,12 +617,22 @@ else do
uvars ← mkUVarTuple ref;
liftM $ mkTuple ref #[val, uvars]
private def mkForInYield (ref : Syntax) : M Syntax := do
u ← mkUVarTuple ref;
`(HasPure.pure (ForInStep.yield $u))
private def mkForInMapYield (ref : Syntax) (x : Name) : M Syntax := do
u ← mkUVarTuple ref;
r ← liftM $ mkTuple ref #[mkIdentFrom ref x, u];
`(HasPure.pure (ForInStep.yield $r))
def returnToTermCore (ref : Syntax) (val : Syntax) : M Syntax := do
ctx ← read;
match ctx.kind with
| Kind.forInNestedTerm => do u ← mkUVarTuple ref; `(HasPure.pure (DoResult.«return» $u))
| Kind.regular => do r ← mkResultUVarTuple ref val; `(HasPure.pure $r)
| _ => do u ← mkUVarTuple ref; `(HasPure.pure (ForInStep.yield $u))
| Kind.forIn => mkForInYield ref
| Kind.forInMap x => mkForInMapYield ref x
def returnToTerm (ref : Syntax) (val : Syntax) : M Syntax := do
r ← returnToTermCore ref val;
@ -630,7 +643,8 @@ ctx ← read;
match ctx.kind with
| Kind.regular => unreachable!
| Kind.forInNestedTerm => do u ← mkUVarTuple ref; `(HasPure.pure (DoResult.«continue» $u))
| _ => do u ← mkUVarTuple ref; `(HasPure.pure (ForInStep.yield $u))
| Kind.forIn => mkForInYield ref
| Kind.forInMap x => mkForInMapYield ref x
def continueToTerm (ref : Syntax) : M Syntax := do
r ← continueToTermCore ref;
@ -641,7 +655,8 @@ ctx ← read;
match ctx.kind with
| Kind.regular => unreachable!
| Kind.forInNestedTerm => do u ← mkUVarTuple ref; `(HasPure.pure (DoResult.«break» $u))
| _ => do u ← mkUVarTuple ref; `(HasPure.pure (ForInStep.done $u))
| Kind.forIn => do u ← mkUVarTuple ref; `(HasPure.pure (ForInStep.done $u))
| Kind.forInMap x => do u ← mkUVarTuple ref; r ← liftM $ mkTuple ref #[mkIdentFrom ref x, u]; `(HasPure.pure (ForInStep.done $r))
def breakToTerm (ref : Syntax) : M Syntax := do
r ← breakToTermCore ref;
@ -748,25 +763,9 @@ partial def toTerm : Code → M Syntax
| Code.ite ref _ o c t e => do t ← toTerm t; e ← toTerm e; pure $ mkIte ref o c t e
| _ => liftM $ Macro.throwError Syntax.missing "WIP"
private def getKindUVars (c : CodeBlock) (forInVar? : Option Name) : Kind × Array Name :=
match forInVar? with
| none =>
if hasContinueBreak c.code then
(Kind.forInNestedTerm, nameSetToArray c.uvars)
else
(Kind.regular, nameSetToArray c.uvars)
| some forInVar =>
if c.uvars.contains forInVar then
let uvars := #[forInVar] ++ nameSetToArray (c.uvars.erase forInVar);
(Kind.forInMap, uvars)
else
(Kind.forIn, nameSetToArray c.uvars)
def run (c : CodeBlock) (m : Syntax) (forInVar? : Option Name := none) : MacroM (Array Name × Syntax) := do
let code := c.code;
let (kind, uvars) := getKindUVars c forInVar?;
def run (code : Code) (m : Syntax) (uvars : Array Name := #[]) (kind := Kind.regular) : MacroM Syntax := do
term ← toTerm code { m := m, kind := kind, uvars := uvars };
pure (uvars, term)
pure term
end ToTerm
@ -779,6 +778,27 @@ structure Context :=
abbrev M := ReaderT Context TermElabM
@[inline] def withFor {α} (x : M α) : M α :=
adaptReader (fun (ctx : Context) => { ctx with insideFor := true }) x
structure ToForInTermResult :=
(isForInMap : Bool)
(uvars : Array Name)
(term : Syntax)
def toForInTerm (x : Syntax) (forCodeBlock : CodeBlock) : M ToForInTermResult := do
ctx ← read;
let uvars := forCodeBlock.uvars;
if x.isIdent && uvars.contains x.getId then do
-- It is a forInMap
let uvars := nameSetToArray (uvars.erase x.getId);
term ← liftMacroM $ ToTerm.run forCodeBlock.code ctx.m uvars (ToTerm.Kind.forInMap x.getId);
pure ⟨true, uvars, term⟩
else do
let uvars := nameSetToArray uvars;
term ← liftMacroM $ ToTerm.run forCodeBlock.code ctx.m uvars ToTerm.Kind.forIn;
pure ⟨false, uvars, term⟩
def ensureInsideFor : M Unit := do
ctx ← read;
unless ctx.insideFor $
@ -869,8 +889,23 @@ partial def doSeqToCode : List Syntax → M CodeBlock
body ← doSeqToCode (getDoSeqElems doSeq);
unless ← liftM $ mkUnless ref cond body;
concatWithRest unless
else if k == `Lean.Parser.Term.doFor then
throwError "WIP"
else if k == `Lean.Parser.Term.doFor then withFreshMacroScope do
let ref := doElem;
let x := doElem.getArg 1;
let xs := doElem.getArg 3;
let forElems := getDoSeqElems (doElem.getArg 5);
forCodeBlock ← withFor (doSeqToCode forElems);
⟨isForInMap, uvars, forInBody⟩ ← toForInTerm x forCodeBlock;
uvarsTuple ← liftMacroM $ mkTuple ref (uvars.map (mkIdentFrom ref));
auxDo ← if isForInMap then do
forInTerm ← `($(xs).forInMap $uvarsTuple fun $x $uvarsTuple => $forInBody);
`(do let r ← $forInTerm; $uvarsTuple:term := r.2; return r.1)
else do {
forInTerm ← `($(xs).forIn $uvarsTuple fun $x $uvarsTuple => $forInBody);
`(do let r ← $forInTerm; $uvarsTuple:term := r)
};
let doElemsNew := getDoSeqElems (getDoSeq auxDo);
doSeqToCode (doElemsNew ++ doElems)
else if k == `Lean.Parser.Term.doMatch then
throwError "WIP"
else if k == `Lean.Parser.Term.doTry then
@ -927,7 +962,7 @@ fun stx expectedType? => do
m ← mkMonadAlias bindInfo.m;
codeBlock ← ToCodeBlock.run stx m;
-- trace! `Elab.do ("codeBlock: " ++ Format.line ++ codeBlock.toMessageData);
(_, stxNew) ← liftMacroM $ ToTerm.run { codeBlock with uvars := {} } m;
stxNew ← liftMacroM $ ToTerm.run codeBlock.code m;
trace! `Elab.do stxNew;
let expectedType := mkApp bindInfo.m bindInfo.α;
withMacroExpansion stx stxNew $ elabTermEnsuringType stxNew expectedType

View file

@ -44,3 +44,49 @@ rfl
theorem ex4 (y : Nat) : h 1 y = (1 + 1) + y :=
rfl
def sumOdd (xs : List Nat) (threshold : Nat) : Nat := do
let sum := 0
for x in xs do
if x % 2 == 1 then
sum := sum + x
if sum > threshold then
break
unless x % 2 == 1 do
continue
dbgTrace! ">> x: " ++ toString x
return sum
#eval sumOdd [1, 2, 3, 4, 5, 6, 7, 9, 11, 101] 10
theorem ex5 : sumOdd [1, 2, 3, 4, 5, 6, 7, 9, 11, 101] 10 = 16 :=
rfl
def mapOdd (f : Nat → Nat) (xs : List Nat) : List Nat := do
for x in xs do
if x % 2 == 1 then
x := f x
dbgTrace! ">> mapOdd x: " ++ toString x
#eval mapOdd (·+10) [1, 2, 3, 4, 5, 6, 7, 9]
theorem ex6 : mapOdd (·+10) [1, 2, 3, 4, 5, 6, 7, 9] = [11, 2, 13, 4, 15, 6, 17, 19] :=
rfl
-- We need `Id.run` because we still have `Monad Option`
def find? (xs : List Nat) (p : Nat → Bool) : Option Nat := Id.run do
let result := none
for x in xs do
if p x then
result := x
break
return result
def sumDiff (ps : List (Nat × Nat)) : Nat := do
let sum := 0
for (x, y) in ps do
sum := sum + x - y
return sum
theorem ex7 : sumDiff [(2, 1), (10, 5)] = 6 :=
rfl