From bdaabd4e7b4b463dc4cebd7adf44a143199bbecb Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Fri, 8 Jul 2022 17:29:30 -0700 Subject: [PATCH] feat: propagate return type to `for-in` block --- src/Lean/Elab/Do.lean | 52 +++++++++++-------- tests/lean/217.lean.expected.out | 6 +-- .../linterUnusedVariables.lean.expected.out | 2 +- tests/lean/run/forInReturnPropagation.lean | 8 +++ 4 files changed, 42 insertions(+), 26 deletions(-) create mode 100644 tests/lean/run/forInReturnPropagation.lean diff --git a/src/Lean/Elab/Do.lean b/src/Lean/Elab/Do.lean index cad0617977..e9a7f709a4 100644 --- a/src/Lean/Elab/Do.lean +++ b/src/Lean/Elab/Do.lean @@ -87,24 +87,24 @@ private partial def hasLiftMethod : Syntax → Bool structure ExtractMonadResult where m : Expr - α : Expr + returnType : Expr expectedType : Expr private def mkUnknownMonadResult : MetaM ExtractMonadResult := do let u ← mkFreshLevelMVar let v ← mkFreshLevelMVar let m ← mkFreshExprMVar (← mkArrow (mkSort (mkLevelSucc u)) (mkSort (mkLevelSucc v))) - let α ← mkFreshExprMVar (mkSort (mkLevelSucc u)) - return { m, α, expectedType := mkApp m α } + let returnType ← mkFreshExprMVar (mkSort (mkLevelSucc u)) + return { m, returnType, expectedType := mkApp m returnType } private partial def extractBind (expectedType? : Option Expr) : TermElabM ExtractMonadResult := do let some expectedType := expectedType? | mkUnknownMonadResult let extractStep? (type : Expr) : MetaM (Option ExtractMonadResult) := do - let .app m α _ := type | return none + let .app m returnType _ := type | return none try let bindInstType ← mkAppM ``Bind #[m] discard <| Meta.synthInstance bindInstType - return some { m, α, expectedType } + return some { m, returnType, expectedType } catch _ => return none let rec extract? (type : Expr) : MetaM (Option ExtractMonadResult) := do @@ -863,9 +863,12 @@ def Kind.isRegular : Kind → Bool | _ => false structure Context where - m : Syntax -- Syntax to reference the monad associated with the do notation. - uvars : Array Var - kind : Kind + /-- Syntax to reference the monad associated with the do notation. -/ + m : Syntax + /-- Syntax to reference the result of the monadic computation performed by the do notation. -/ + returnType : Syntax + uvars : Array Var + kind : Kind abbrev M := ReaderT Context MacroM @@ -1031,8 +1034,8 @@ partial def toTerm (c : Code) : M Syntax := do let termMatchAlts := mkNode `Lean.Parser.Term.matchAlts #[mkNullNode termAlts] return mkNode `Lean.Parser.Term.«match» #[mkAtomFrom ref "match", genParam, optMotive, discrs, mkAtomFrom ref "with", termMatchAlts] -def run (code : Code) (m : Syntax) (uvars : Array Var := #[]) (kind := Kind.regular) : MacroM Syntax := - toTerm code { m := m, kind := kind, uvars := uvars } +def run (code : Code) (m : Syntax) (returnType : Syntax) (uvars : Array Var := #[]) (kind := Kind.regular) : MacroM Syntax := + toTerm code { m, returnType, kind, uvars } /- Given - `a` is true if the code block has a `Code.action _` exit point @@ -1051,8 +1054,8 @@ def mkNestedKind (a r bc : Bool) : Kind := | true, true, true => .nestedPRBC | false, false, false => unreachable! -def mkNestedTerm (code : Code) (m : Syntax) (uvars : Array Var) (a r bc : Bool) : MacroM Syntax := do - ToTerm.run code m uvars (mkNestedKind a r bc) +def mkNestedTerm (code : Code) (m : Syntax) (returnType : Syntax) (uvars : Array Var) (a r bc : Bool) : MacroM Syntax := do + ToTerm.run code m returnType uvars (mkNestedKind a r bc) /- Given a term `term` produced by `ToTerm.run`, pattern match on its result. See comment at the beginning of the `ToTerm` namespace. @@ -1119,7 +1122,10 @@ namespace ToCodeBlock structure Context where ref : Syntax - m : Syntax -- Syntax representing the monad associated with the do notation. + /-- Syntax representing the monad associated with the do notation. -/ + m : Syntax + /-- Syntax to reference the result of the monadic computation performed by the do notation. -/ + returnType : Syntax mutableVars : VarSet := {} insideFor : Bool := false @@ -1155,7 +1161,7 @@ def mkForInBody (_ : Syntax) (forInBody : CodeBlock) : M ToForInTermResult := d let ctx ← read let uvars := forInBody.uvars let uvars := varSetToArray uvars - let term ← liftMacroM <| ToTerm.run forInBody.code ctx.m uvars (if hasReturn forInBody.code then ToTerm.Kind.forInWithReturn else ToTerm.Kind.forIn) + let term ← liftMacroM <| ToTerm.run forInBody.code ctx.m ctx.returnType uvars (if hasReturn forInBody.code then ToTerm.Kind.forInWithReturn else ToTerm.Kind.forIn) return ⟨uvars, term⟩ def ensureInsideFor : M Unit := @@ -1413,10 +1419,11 @@ mutual let uvarsTuple ← liftMacroM do mkTuple uvars if hasReturn forInBodyCodeBlock.code then let forInBody ← liftMacroM <| destructTuple uvars (← `(r)) forInBody + let optType ← `(Option $((← read).returnType)) let forInTerm ← if let some h := h? then - `(for_in'% $(xs) (MProd.mk none $uvarsTuple) fun $x $h r => let r := r.2; $forInBody) + `(for_in'% $(xs) (MProd.mk (none : $optType) $uvarsTuple) fun $x $h (r : MProd $optType _) => let r := r.2; $forInBody) else - `(for_in% $(xs) (MProd.mk none $uvarsTuple) fun $x r => let r := r.2; $forInBody) + `(for_in% $(xs) (MProd.mk (none : $optType) $uvarsTuple) fun $x (r : MProd $optType _) => let r := r.2; $forInBody) let auxDo ← `(do let r ← $forInTerm:term; $uvarsTuple:term := r.2; match r.1 with @@ -1495,7 +1502,7 @@ mutual let bc := tryCatchPred tryCode catches finallyCode? hasBreakContinue let toTerm (codeBlock : CodeBlock) : M Syntax := do let codeBlock ← liftM $ extendUpdatedVars codeBlock ws - liftMacroM <| ToTerm.mkNestedTerm codeBlock.code ctx.m uvars a r bc + liftMacroM <| ToTerm.mkNestedTerm codeBlock.code ctx.m ctx.returnType uvars a r bc let term ← toTerm tryCode let term ← catches.foldlM (init := term) fun term «catch» => do let catchTerm ← toTerm «catch».codeBlock @@ -1511,7 +1518,7 @@ mutual throwError "'finally' currently does not support reassignments" if hasBreakContinueReturn finallyCode.code then throwError "'finally' currently does 'return', 'break', nor 'continue'" - let finallyTerm ← liftMacroM <| ToTerm.run finallyCode.code ctx.m {} ToTerm.Kind.regular + let finallyTerm ← liftMacroM <| ToTerm.run finallyCode.code ctx.m ctx.returnType {} ToTerm.Kind.regular ``(tryFinally $term $finallyTerm) let doElemsNew ← liftMacroM <| ToTerm.matchNestedTermResult term uvars a r bc doSeqToCode (doElemsNew ++ doElems) @@ -1592,8 +1599,8 @@ mutual throwError "unexpected do-element of kind {doElem.getKind}:\n{doElem}" end -def run (doStx : Syntax) (m : Syntax) : TermElabM CodeBlock := - (doSeqToCode <| getDoSeqElems <| getDoSeq doStx).run { ref := doStx, m } +def run (doStx : Syntax) (m : Syntax) (returnType : Syntax) : TermElabM CodeBlock := + (doSeqToCode <| getDoSeqElems <| getDoSeq doStx).run { ref := doStx, m, returnType } end ToCodeBlock @@ -1601,8 +1608,9 @@ end ToCodeBlock tryPostponeIfNoneOrMVar expectedType? let bindInfo ← extractBind expectedType? let m ← Term.exprToSyntax bindInfo.m - let codeBlock ← ToCodeBlock.run stx m - let stxNew ← liftMacroM <| ToTerm.run codeBlock.code m + let returnType ← Term.exprToSyntax bindInfo.returnType + let codeBlock ← ToCodeBlock.run stx m returnType + let stxNew ← liftMacroM <| ToTerm.run codeBlock.code m returnType trace[Elab.do] stxNew withMacroExpansion stx stxNew <| elabTermEnsuringType stxNew bindInfo.expectedType diff --git a/tests/lean/217.lean.expected.out b/tests/lean/217.lean.expected.out index 700f278c24..d9a84d3f89 100644 --- a/tests/lean/217.lean.expected.out +++ b/tests/lean/217.lean.expected.out @@ -1,6 +1,6 @@ -217.lean:5:28-5:29: error: don't know how to synthesize placeholder for argument 'f' -context: -⊢ CoreM Unit → Name → ConstantInfo → CoreM Unit 217.lean:5:30-5:31: error: don't know how to synthesize placeholder for argument 'init' context: ⊢ CoreM Unit +217.lean:5:28-5:29: error: don't know how to synthesize placeholder for argument 'f' +context: +⊢ CoreM Unit → Name → ConstantInfo → CoreM Unit diff --git a/tests/lean/linterUnusedVariables.lean.expected.out b/tests/lean/linterUnusedVariables.lean.expected.out index bf904d9a63..2d1c491c37 100644 --- a/tests/lean/linterUnusedVariables.lean.expected.out +++ b/tests/lean/linterUnusedVariables.lean.expected.out @@ -12,8 +12,8 @@ linterUnusedVariables.lean:50:11-50:12: warning: unused variable `z` linterUnusedVariables.lean:55:14-55:15: warning: unused variable `y` linterUnusedVariables.lean:61:20-61:21: warning: unused variable `y` linterUnusedVariables.lean:66:34-66:38: warning: unused variable `inst` -linterUnusedVariables.lean:108:6-108:7: warning: unused variable `y` linterUnusedVariables.lean:107:25-107:26: warning: unused variable `x` +linterUnusedVariables.lean:108:6-108:7: warning: unused variable `y` linterUnusedVariables.lean:114:6-114:7: warning: unused variable `a` linterUnusedVariables.lean:124:26-124:27: warning: unused variable `z` linterUnusedVariables.lean:132:9-132:10: warning: unused variable `h` diff --git a/tests/lean/run/forInReturnPropagation.lean b/tests/lean/run/forInReturnPropagation.lean new file mode 100644 index 0000000000..3b2f9acc6b --- /dev/null +++ b/tests/lean/run/forInReturnPropagation.lean @@ -0,0 +1,8 @@ +def main (args : List String) : IO UInt32 := do + for (arg : String) in args do + match arg with + | "--print-cflags" => + return 1 + | _ => + pure () + return 0