From 24025b96c5a67420c46668f80859b0456a24d41a Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Thu, 13 Aug 2020 15:19:40 -0700 Subject: [PATCH] feat: elaborate equation RHS --- src/Lean/Elab/Match.lean | 19 +++++++++++-------- src/Lean/Expr.lean | 8 ++++++++ src/Lean/Meta/EqnCompiler/DepElim.lean | 5 +---- 3 files changed, 20 insertions(+), 12 deletions(-) diff --git a/src/Lean/Elab/Match.lean b/src/Lean/Elab/Match.lean index 0c29208770..f82295995f 100644 --- a/src/Lean/Elab/Match.lean +++ b/src/Lean/Elab/Match.lean @@ -369,7 +369,7 @@ private partial def withPatternVarsAux {α} (pVars : Array PatternVar) (k : Arra private def withPatternVars {α} (pVars : Array PatternVar) (k : Array PatternVarDecl → TermElabM α) : TermElabM α := withPatternVarsAux pVars k 0 #[] -private partial def elabPatternsAux (patternStxs : Array Syntax) : Nat → Expr → Array Expr → TermElabM (Array Expr) +private partial def elabPatternsAux (patternStxs : Array Syntax) : Nat → Expr → Array Expr → TermElabM (Array Expr × Expr) | i, matchType, patterns => if h : i < patternStxs.size then do matchType ← whnf matchType; @@ -381,7 +381,7 @@ private partial def elabPatternsAux (patternStxs : Array Syntax) : Nat → Expr elabPatternsAux (i+1) (b.instantiate1 pattern) (patterns.push pattern) | _ => throwError "unexpected match type" else - pure patterns + pure (patterns, matchType) def finalizePatternDecls (patternVarDecls : Array PatternVarDecl) : TermElabM (Array LocalDecl) := patternVarDecls.foldlM @@ -496,24 +496,27 @@ end ToDepElimPattern def toDepElimPattern (localDecls : Array LocalDecl) (e : Expr) : TermElabM Meta.DepElim.Pattern := (ToDepElimPattern.main localDecls e).run' {} -private def elabPatterns (patternVarDecls : Array PatternVarDecl) (patternStxs : Array Syntax) (matchType : Expr) : TermElabM Meta.DepElim.AltLHS := do -patterns ← withSynthesize $ elabPatternsAux patternStxs 0 matchType #[]; +private def elabPatterns (patternVarDecls : Array PatternVarDecl) (patternStxs : Array Syntax) (matchType : Expr) : TermElabM (Meta.DepElim.AltLHS × Expr) := do +(patterns, matchType) ← withSynthesize $ elabPatternsAux patternStxs 0 matchType #[]; localDecls ← finalizePatternDecls patternVarDecls; patterns ← patterns.mapM instantiateMVars; trace `Elab.match fun _ => MessageData.ofArray $ localDecls.map fun (d : LocalDecl) => (d.userName ++ " : " ++ d.type : MessageData); patterns.forM $ fun pattern => when pattern.hasExprMVar $ throwError ("pattern contains metavariables " ++ indentExpr pattern); patterns ← patterns.mapM $ toDepElimPattern localDecls; trace `Elab.match fun _ => "patterns: " ++ MessageData.ofArray (patterns.map fun (p : Meta.DepElim.Pattern) => p.toMessageData); -pure { localDecls := localDecls.toList, patterns := patterns.toList } +pure ({ localDecls := localDecls.toList, patterns := patterns.toList }, matchType) def elabMatchAltView (alt : MatchAltView) (matchType : Expr) : TermElabM (Meta.DepElim.AltLHS × Expr) := withRef alt.ref do (patternVars, alt) ← collectPatternVars alt; trace `Elab.match fun _ => "patternVars: " ++ toString patternVars; withPatternVars patternVars fun patternVarDecls => do - ps ← elabPatterns patternVarDecls alt.patterns matchType; - -- TODO - pure (⟨[], []⟩, arbitrary _) + (altLHS, matchType) ← elabPatterns patternVarDecls alt.patterns matchType; + rhs ← elabTerm alt.rhs matchType; + let xs := altLHS.localDecls.toArray.map LocalDecl.toExpr; + rhs ← if xs.isEmpty then pure $ mkThunk rhs else mkLambda xs rhs; + trace `Elab.match fun _ => "rhs: " ++ rhs; + pure (altLHS, rhs) /- ``` diff --git a/src/Lean/Expr.lean b/src/Lean/Expr.lean index 4ba294cb21..f8a7f6ce0b 100644 --- a/src/Lean/Expr.lean +++ b/src/Lean/Expr.lean @@ -321,6 +321,14 @@ Expr.forallE x t b $ mkDataForBinder (mixHash 37 $ mixHash (hash t) (hash b)) (t.hasLevelParam || b.hasLevelParam) bi +/- Return `Unit -> type` -/ +def mkThunkType (type : Expr) : Expr := +mkForall Name.anonymous BinderInfo.default (Lean.mkConst `Unit) type + +/- Return `fun (_ : Unit), e` -/ +def mkThunk (type : Expr) : Expr := +mkLambda `_ BinderInfo.default (Lean.mkConst `Unit) type + def mkLet (x : Name) (t : Expr) (v : Expr) (b : Expr) (nonDep : Bool := false) : Expr := let x := x.eraseMacroScopes; Expr.letE x t v b $ mkDataForLet (mixHash 41 $ mixHash (hash t) $ mixHash (hash v) (hash b)) diff --git a/src/Lean/Meta/EqnCompiler/DepElim.lean b/src/Lean/Meta/EqnCompiler/DepElim.lean index e4050c8c03..04a452e307 100644 --- a/src/Lean/Meta/EqnCompiler/DepElim.lean +++ b/src/Lean/Meta/EqnCompiler/DepElim.lean @@ -284,9 +284,6 @@ private def localDeclsToMVarsAux : List LocalDecl → List MVarId → FVarSubst private def localDeclsToMVars (localDecls : List LocalDecl) : MetaM (List MVarId × FVarSubst) := localDeclsToMVarsAux localDecls [] {} -private def mkThunk (type : Expr) : Expr := -Lean.mkForall `u BinderInfo.default (Lean.mkConst `Unit) type - private partial def withAltsAux {α} (motive : Expr) : List AltLHS → List Alt → Array Expr → (List Alt → Array Expr → MetaM α) → MetaM α | [], alts, minors, k => k alts.reverse minors | lhs::lhss, alts, minors, k => do @@ -296,7 +293,7 @@ private partial def withAltsAux {α} (motive : Expr) : List AltLHS → List Alt let minorType := mkAppN motive args; mkForall xs minorType }; - let minorType := if minorType.isForall then minorType else mkThunk minorType; + let minorType := if minorType.isForall then minorType else mkThunkType minorType; let idx := alts.length; let minorName := (`h).appendIndexAfter (idx+1); trace! `Meta.EqnCompiler.matchDebug ("minor premise " ++ minorName ++ " : " ++ minorType);