feat: propagate return type to for-in block

This commit is contained in:
Leonardo de Moura 2022-07-08 17:29:30 -07:00
parent 71edf731f9
commit bdaabd4e7b
4 changed files with 42 additions and 26 deletions

View file

@ -87,24 +87,24 @@ private partial def hasLiftMethod : Syntax → Bool
structure ExtractMonadResult where
m : Expr
α : Expr
returnType : Expr
expectedType : Expr
private def mkUnknownMonadResult : MetaM ExtractMonadResult := do
let u ← mkFreshLevelMVar
let v ← mkFreshLevelMVar
let m ← mkFreshExprMVar (← mkArrow (mkSort (mkLevelSucc u)) (mkSort (mkLevelSucc v)))
let α ← mkFreshExprMVar (mkSort (mkLevelSucc u))
return { m, α, expectedType := mkApp m α }
let returnType ← mkFreshExprMVar (mkSort (mkLevelSucc u))
return { m, returnType, expectedType := mkApp m returnType }
private partial def extractBind (expectedType? : Option Expr) : TermElabM ExtractMonadResult := do
let some expectedType := expectedType? | mkUnknownMonadResult
let extractStep? (type : Expr) : MetaM (Option ExtractMonadResult) := do
let .app m α _ := type | return none
let .app m returnType _ := type | return none
try
let bindInstType ← mkAppM ``Bind #[m]
discard <| Meta.synthInstance bindInstType
return some { m, α, expectedType }
return some { m, returnType, expectedType }
catch _ =>
return none
let rec extract? (type : Expr) : MetaM (Option ExtractMonadResult) := do
@ -863,9 +863,12 @@ def Kind.isRegular : Kind → Bool
| _ => false
structure Context where
m : Syntax -- Syntax to reference the monad associated with the do notation.
uvars : Array Var
kind : Kind
/-- Syntax to reference the monad associated with the do notation. -/
m : Syntax
/-- Syntax to reference the result of the monadic computation performed by the do notation. -/
returnType : Syntax
uvars : Array Var
kind : Kind
abbrev M := ReaderT Context MacroM
@ -1031,8 +1034,8 @@ partial def toTerm (c : Code) : M Syntax := do
let termMatchAlts := mkNode `Lean.Parser.Term.matchAlts #[mkNullNode termAlts]
return mkNode `Lean.Parser.Term.«match» #[mkAtomFrom ref "match", genParam, optMotive, discrs, mkAtomFrom ref "with", termMatchAlts]
def run (code : Code) (m : Syntax) (uvars : Array Var := #[]) (kind := Kind.regular) : MacroM Syntax :=
toTerm code { m := m, kind := kind, uvars := uvars }
def run (code : Code) (m : Syntax) (returnType : Syntax) (uvars : Array Var := #[]) (kind := Kind.regular) : MacroM Syntax :=
toTerm code { m, returnType, kind, uvars }
/- Given
- `a` is true if the code block has a `Code.action _` exit point
@ -1051,8 +1054,8 @@ def mkNestedKind (a r bc : Bool) : Kind :=
| true, true, true => .nestedPRBC
| false, false, false => unreachable!
def mkNestedTerm (code : Code) (m : Syntax) (uvars : Array Var) (a r bc : Bool) : MacroM Syntax := do
ToTerm.run code m uvars (mkNestedKind a r bc)
def mkNestedTerm (code : Code) (m : Syntax) (returnType : Syntax) (uvars : Array Var) (a r bc : Bool) : MacroM Syntax := do
ToTerm.run code m returnType uvars (mkNestedKind a r bc)
/- Given a term `term` produced by `ToTerm.run`, pattern match on its result.
See comment at the beginning of the `ToTerm` namespace.
@ -1119,7 +1122,10 @@ namespace ToCodeBlock
structure Context where
ref : Syntax
m : Syntax -- Syntax representing the monad associated with the do notation.
/-- Syntax representing the monad associated with the do notation. -/
m : Syntax
/-- Syntax to reference the result of the monadic computation performed by the do notation. -/
returnType : Syntax
mutableVars : VarSet := {}
insideFor : Bool := false
@ -1155,7 +1161,7 @@ def mkForInBody (_ : Syntax) (forInBody : CodeBlock) : M ToForInTermResult := d
let ctx ← read
let uvars := forInBody.uvars
let uvars := varSetToArray uvars
let term ← liftMacroM <| ToTerm.run forInBody.code ctx.m uvars (if hasReturn forInBody.code then ToTerm.Kind.forInWithReturn else ToTerm.Kind.forIn)
let term ← liftMacroM <| ToTerm.run forInBody.code ctx.m ctx.returnType uvars (if hasReturn forInBody.code then ToTerm.Kind.forInWithReturn else ToTerm.Kind.forIn)
return ⟨uvars, term⟩
def ensureInsideFor : M Unit :=
@ -1413,10 +1419,11 @@ mutual
let uvarsTuple ← liftMacroM do mkTuple uvars
if hasReturn forInBodyCodeBlock.code then
let forInBody ← liftMacroM <| destructTuple uvars (← `(r)) forInBody
let optType ← `(Option $((← read).returnType))
let forInTerm ← if let some h := h? then
`(for_in'% $(xs) (MProd.mk none $uvarsTuple) fun $x $h r => let r := r.2; $forInBody)
`(for_in'% $(xs) (MProd.mk (none : $optType) $uvarsTuple) fun $x $h (r : MProd $optType _) => let r := r.2; $forInBody)
else
`(for_in% $(xs) (MProd.mk none $uvarsTuple) fun $x r => let r := r.2; $forInBody)
`(for_in% $(xs) (MProd.mk (none : $optType) $uvarsTuple) fun $x (r : MProd $optType _) => let r := r.2; $forInBody)
let auxDo ← `(do let r ← $forInTerm:term;
$uvarsTuple:term := r.2;
match r.1 with
@ -1495,7 +1502,7 @@ mutual
let bc := tryCatchPred tryCode catches finallyCode? hasBreakContinue
let toTerm (codeBlock : CodeBlock) : M Syntax := do
let codeBlock ← liftM $ extendUpdatedVars codeBlock ws
liftMacroM <| ToTerm.mkNestedTerm codeBlock.code ctx.m uvars a r bc
liftMacroM <| ToTerm.mkNestedTerm codeBlock.code ctx.m ctx.returnType uvars a r bc
let term ← toTerm tryCode
let term ← catches.foldlM (init := term) fun term «catch» => do
let catchTerm ← toTerm «catch».codeBlock
@ -1511,7 +1518,7 @@ mutual
throwError "'finally' currently does not support reassignments"
if hasBreakContinueReturn finallyCode.code then
throwError "'finally' currently does 'return', 'break', nor 'continue'"
let finallyTerm ← liftMacroM <| ToTerm.run finallyCode.code ctx.m {} ToTerm.Kind.regular
let finallyTerm ← liftMacroM <| ToTerm.run finallyCode.code ctx.m ctx.returnType {} ToTerm.Kind.regular
``(tryFinally $term $finallyTerm)
let doElemsNew ← liftMacroM <| ToTerm.matchNestedTermResult term uvars a r bc
doSeqToCode (doElemsNew ++ doElems)
@ -1592,8 +1599,8 @@ mutual
throwError "unexpected do-element of kind {doElem.getKind}:\n{doElem}"
end
def run (doStx : Syntax) (m : Syntax) : TermElabM CodeBlock :=
(doSeqToCode <| getDoSeqElems <| getDoSeq doStx).run { ref := doStx, m }
def run (doStx : Syntax) (m : Syntax) (returnType : Syntax) : TermElabM CodeBlock :=
(doSeqToCode <| getDoSeqElems <| getDoSeq doStx).run { ref := doStx, m, returnType }
end ToCodeBlock
@ -1601,8 +1608,9 @@ end ToCodeBlock
tryPostponeIfNoneOrMVar expectedType?
let bindInfo ← extractBind expectedType?
let m ← Term.exprToSyntax bindInfo.m
let codeBlock ← ToCodeBlock.run stx m
let stxNew ← liftMacroM <| ToTerm.run codeBlock.code m
let returnType ← Term.exprToSyntax bindInfo.returnType
let codeBlock ← ToCodeBlock.run stx m returnType
let stxNew ← liftMacroM <| ToTerm.run codeBlock.code m returnType
trace[Elab.do] stxNew
withMacroExpansion stx stxNew <| elabTermEnsuringType stxNew bindInfo.expectedType

View file

@ -1,6 +1,6 @@
217.lean:5:28-5:29: error: don't know how to synthesize placeholder for argument 'f'
context:
⊢ CoreM Unit → Name → ConstantInfo → CoreM Unit
217.lean:5:30-5:31: error: don't know how to synthesize placeholder for argument 'init'
context:
⊢ CoreM Unit
217.lean:5:28-5:29: error: don't know how to synthesize placeholder for argument 'f'
context:
⊢ CoreM Unit → Name → ConstantInfo → CoreM Unit

View file

@ -12,8 +12,8 @@ linterUnusedVariables.lean:50:11-50:12: warning: unused variable `z`
linterUnusedVariables.lean:55:14-55:15: warning: unused variable `y`
linterUnusedVariables.lean:61:20-61:21: warning: unused variable `y`
linterUnusedVariables.lean:66:34-66:38: warning: unused variable `inst`
linterUnusedVariables.lean:108:6-108:7: warning: unused variable `y`
linterUnusedVariables.lean:107:25-107:26: warning: unused variable `x`
linterUnusedVariables.lean:108:6-108:7: warning: unused variable `y`
linterUnusedVariables.lean:114:6-114:7: warning: unused variable `a`
linterUnusedVariables.lean:124:26-124:27: warning: unused variable `z`
linterUnusedVariables.lean:132:9-132:10: warning: unused variable `h`

View file

@ -0,0 +1,8 @@
def main (args : List String) : IO UInt32 := do
for (arg : String) in args do
match arg with
| "--print-cflags" =>
return 1
| _ =>
pure ()
return 0