feat: elaborate optional deriving after def
This commit is contained in:
parent
d682d60025
commit
bbb74bfd9a
7 changed files with 92 additions and 53 deletions
|
|
@ -40,6 +40,7 @@ structure DefView where
|
|||
binders : Syntax
|
||||
type? : Option Syntax
|
||||
value : Syntax
|
||||
deriving? : Option (Array Syntax) := none
|
||||
deriving Inhabited
|
||||
|
||||
namespace Command
|
||||
|
|
@ -51,20 +52,21 @@ def mkDefViewOfAbbrev (modifiers : Modifiers) (stx : Syntax) : DefView :=
|
|||
let (binders, type) := expandOptDeclSig stx[2]
|
||||
let modifiers := modifiers.addAttribute { name := `inline }
|
||||
let modifiers := modifiers.addAttribute { name := `reducible }
|
||||
{ ref := stx, kind := DefKind.abbrev, modifiers := modifiers,
|
||||
declId := stx[1], binders := binders, type? := type, value := stx[3] }
|
||||
{ ref := stx, kind := DefKind.abbrev, modifiers,
|
||||
declId := stx[1], binders, type? := type, value := stx[3] }
|
||||
|
||||
def mkDefViewOfDef (modifiers : Modifiers) (stx : Syntax) : DefView :=
|
||||
-- leading_parser "def " >> declId >> optDeclSig >> declVal
|
||||
-- leading_parser "def " >> declId >> optDeclSig >> declVal >> optDefDeriving
|
||||
let (binders, type) := expandOptDeclSig stx[2]
|
||||
{ ref := stx, kind := DefKind.def, modifiers := modifiers,
|
||||
declId := stx[1], binders := binders, type? := type, value := stx[3] }
|
||||
let deriving? := if stx[4].isNone then none else some stx[4][1].getSepArgs
|
||||
{ ref := stx, kind := DefKind.def, modifiers,
|
||||
declId := stx[1], binders, type? := type, value := stx[3], deriving? }
|
||||
|
||||
def mkDefViewOfTheorem (modifiers : Modifiers) (stx : Syntax) : DefView :=
|
||||
-- leading_parser "theorem " >> declId >> declSig >> declVal
|
||||
let (binders, type) := expandDeclSig stx[2]
|
||||
{ ref := stx, kind := DefKind.theorem, modifiers := modifiers,
|
||||
declId := stx[1], binders := binders, type? := some type, value := stx[3] }
|
||||
{ ref := stx, kind := DefKind.theorem, modifiers,
|
||||
declId := stx[1], binders, type? := some type, value := stx[3] }
|
||||
|
||||
namespace MkInstanceName
|
||||
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE.
|
|||
Authors: Leonardo de Moura, Wojciech Nawrocki
|
||||
-/
|
||||
import Lean.Elab.Command
|
||||
import Lean.Elab.MutualDef
|
||||
|
||||
namespace Lean.Elab
|
||||
open Command
|
||||
|
|
@ -35,29 +36,8 @@ def applyDerivingHandlers (className : Name) (typeNames : Array Name) (args? : O
|
|||
| none => defaultHandler className typeNames
|
||||
|
||||
private def tryApplyDefHandler (className : Name) (declName : Name) : CommandElabM Bool :=
|
||||
open Meta in
|
||||
liftTermElabM none do
|
||||
let ConstantInfo.defnInfo info ← getConstInfo declName | return false
|
||||
forallTelescopeReducing info.type fun xs type => do
|
||||
try
|
||||
let instType ← mkAppM className #[mkAppN (Lean.mkConst declName (info.levelParams.map mkLevelParam)) xs]
|
||||
check instType
|
||||
let oldInstType ← mkAppM className #[(mkAppN info.value xs).headBeta]
|
||||
check oldInstType
|
||||
let instVal ← synthInstance oldInstType
|
||||
let instName ← liftMacroM <| mkUnusedBaseName (declName.appendBefore "inst" |>.appendAfter className.getString!)
|
||||
addAndCompile <| Declaration.defnDecl {
|
||||
name := instName
|
||||
levelParams := info.levelParams
|
||||
type := instType
|
||||
value := instVal
|
||||
hints := info.hints
|
||||
safety := info.safety
|
||||
}
|
||||
addInstance instName AttributeKind.global (eval_prio default)
|
||||
return true
|
||||
catch _ =>
|
||||
return false
|
||||
Term.processDefDeriving className declName
|
||||
|
||||
@[builtinCommandElab «deriving»] def elabDeriving : CommandElab
|
||||
| `(deriving instance $[$classes $[with $argss?]?],* for $[$declNames],*) => do
|
||||
|
|
|
|||
|
|
@ -638,30 +638,63 @@ private def levelMVarToParamHeaders (views : Array DefView) (headers : Array Def
|
|||
let newHeaders ← process.run' 1
|
||||
newHeaders.mapM fun header => return { header with type := (← instantiateMVars header.type) }
|
||||
|
||||
def processDefDeriving (className : Name) (declName : Name) : TermElabM Bool := do
|
||||
try
|
||||
let ConstantInfo.defnInfo info ← getConstInfo declName | return false
|
||||
let instType ← mkAppM className #[Lean.mkConst declName (info.levelParams.map mkLevelParam)]
|
||||
Meta.check instType
|
||||
let oldInstType ← mkAppM className #[info.value]
|
||||
Meta.check oldInstType
|
||||
let instVal ← synthInstance oldInstType
|
||||
let instName ← liftMacroM <| mkUnusedBaseName (declName.appendBefore "inst" |>.appendAfter className.getString!)
|
||||
addAndCompile <| Declaration.defnDecl {
|
||||
name := instName
|
||||
levelParams := info.levelParams
|
||||
type := instType
|
||||
value := instVal
|
||||
hints := info.hints
|
||||
safety := info.safety
|
||||
}
|
||||
addInstance instName AttributeKind.global (eval_prio default)
|
||||
return true
|
||||
catch _ =>
|
||||
return false
|
||||
|
||||
def elabMutualDef (vars : Array Expr) (views : Array DefView) : TermElabM Unit :=
|
||||
if isExample views then
|
||||
withoutModifyingEnv go
|
||||
else
|
||||
go
|
||||
where go := do
|
||||
let scopeLevelNames ← getLevelNames
|
||||
let headers ← elabHeaders views
|
||||
let headers ← levelMVarToParamHeaders views headers
|
||||
let allUserLevelNames := getAllUserLevelNames headers
|
||||
withFunLocalDecls headers fun funFVars => do
|
||||
let values ← elabFunValues headers
|
||||
Term.synthesizeSyntheticMVarsNoPostponing
|
||||
let values ← values.mapM (instantiateMVars ·)
|
||||
let headers ← headers.mapM instantiateMVarsAtHeader
|
||||
let letRecsToLift ← getLetRecsToLift
|
||||
let letRecsToLift ← letRecsToLift.mapM instantiateMVarsAtLetRecToLift
|
||||
checkLetRecsToLiftTypes funFVars letRecsToLift
|
||||
withUsed vars headers values letRecsToLift fun vars => do
|
||||
let preDefs ← MutualClosure.main vars headers funFVars values letRecsToLift
|
||||
let preDefs ← levelMVarToParamPreDecls preDefs
|
||||
let preDefs ← instantiateMVarsAtPreDecls preDefs
|
||||
let preDefs ← fixLevelParams preDefs scopeLevelNames allUserLevelNames
|
||||
addPreDefinitions preDefs
|
||||
where
|
||||
go := do
|
||||
let scopeLevelNames ← getLevelNames
|
||||
let headers ← elabHeaders views
|
||||
let headers ← levelMVarToParamHeaders views headers
|
||||
let allUserLevelNames := getAllUserLevelNames headers
|
||||
withFunLocalDecls headers fun funFVars => do
|
||||
let values ← elabFunValues headers
|
||||
Term.synthesizeSyntheticMVarsNoPostponing
|
||||
let values ← values.mapM (instantiateMVars ·)
|
||||
let headers ← headers.mapM instantiateMVarsAtHeader
|
||||
let letRecsToLift ← getLetRecsToLift
|
||||
let letRecsToLift ← letRecsToLift.mapM instantiateMVarsAtLetRecToLift
|
||||
checkLetRecsToLiftTypes funFVars letRecsToLift
|
||||
withUsed vars headers values letRecsToLift fun vars => do
|
||||
let preDefs ← MutualClosure.main vars headers funFVars values letRecsToLift
|
||||
let preDefs ← levelMVarToParamPreDecls preDefs
|
||||
let preDefs ← instantiateMVarsAtPreDecls preDefs
|
||||
let preDefs ← fixLevelParams preDefs scopeLevelNames allUserLevelNames
|
||||
addPreDefinitions preDefs
|
||||
processDeriving headers
|
||||
|
||||
processDeriving (headers : Array DefViewElabHeader) := do
|
||||
for header in headers, view in views do
|
||||
if let some classNamesStx := view.deriving? then
|
||||
for classNameStx in classNamesStx do
|
||||
let className ← resolveGlobalConstNoOverload classNameStx
|
||||
withRef classNameStx do
|
||||
unless (← processDefDeriving className header.declName) do
|
||||
throwError "failed to synthesize instance '{className}' for '{header.declName}'"
|
||||
|
||||
end Term
|
||||
namespace Command
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@ def declValSimple := leading_parser " :=\n" >> termParser >> optional Term.wh
|
|||
def declValEqns := leading_parser Term.matchAltsWhereDecls
|
||||
def declVal := declValSimple <|> declValEqns <|> Term.whereDecls
|
||||
def «abbrev» := leading_parser "abbrev " >> declId >> optDeclSig >> declVal
|
||||
def optDefDeriving := leading_parser optional (atomic ("deriving " >> notSymbol "instance") >> sepBy1 ident ", ")
|
||||
def optDefDeriving := optional (atomic ("deriving " >> notSymbol "instance") >> sepBy1 ident ", ")
|
||||
def «def» := leading_parser "def " >> declId >> optDeclSig >> declVal >> optDefDeriving
|
||||
def «theorem» := leading_parser "theorem " >> declId >> declSig >> declVal
|
||||
def «constant» := leading_parser "constant " >> declId >> declSig >> optional declValSimple
|
||||
|
|
|
|||
|
|
@ -8,8 +8,8 @@ StxQuot.lean:8:12: error: expected command, identifier or term
|
|||
"(«term_+_» (numLit \"1\") \"+\" (numLit \"1\"))"
|
||||
StxQuot.lean:18:15: error: expected term
|
||||
"(Term.fun \"fun\" (Term.basicFun [`a._@.UnhygienicMain._hyg.1] \"=>\" `a._@.UnhygienicMain._hyg.1))"
|
||||
"(Command.declaration\n (Command.declModifiers [] [] [] [] [] [])\n (Command.def\n \"def\"\n (Command.declId `foo._@.UnhygienicMain._hyg.1 [])\n (Command.optDeclSig [] [])\n (Command.declValSimple \":=\" (numLit \"1\") [])\n (Command.optDefDeriving [])))"
|
||||
"[(Command.declaration\n (Command.declModifiers [] [] [] [] [] [])\n (Command.def\n \"def\"\n (Command.declId `foo._@.UnhygienicMain._hyg.1 [])\n (Command.optDeclSig [] [])\n (Command.declValSimple \":=\" (numLit \"1\") [])\n (Command.optDefDeriving [])))\n (Command.declaration\n (Command.declModifiers [] [] [] [] [] [])\n (Command.def\n \"def\"\n (Command.declId `bar._@.UnhygienicMain._hyg.1 [])\n (Command.optDeclSig [] [])\n (Command.declValSimple \":=\" (numLit \"2\") [])\n (Command.optDefDeriving [])))]"
|
||||
"(Command.declaration\n (Command.declModifiers [] [] [] [] [] [])\n (Command.def\n \"def\"\n (Command.declId `foo._@.UnhygienicMain._hyg.1 [])\n (Command.optDeclSig [] [])\n (Command.declValSimple \":=\" (numLit \"1\") [])\n []))"
|
||||
"[(Command.declaration\n (Command.declModifiers [] [] [] [] [] [])\n (Command.def\n \"def\"\n (Command.declId `foo._@.UnhygienicMain._hyg.1 [])\n (Command.optDeclSig [] [])\n (Command.declValSimple \":=\" (numLit \"1\") [])\n []))\n (Command.declaration\n (Command.declModifiers [] [] [] [] [] [])\n (Command.def\n \"def\"\n (Command.declId `bar._@.UnhygienicMain._hyg.1 [])\n (Command.optDeclSig [] [])\n (Command.declValSimple \":=\" (numLit \"2\") [])\n []))]"
|
||||
"`Nat.one._@.UnhygienicMain._hyg.1"
|
||||
"`Nat.one._@.UnhygienicMain._hyg.1"
|
||||
"(Term.app `f._@.UnhygienicMain._hyg.1 [`Nat.one._@.UnhygienicMain._hyg.1 `Nat.one._@.UnhygienicMain._hyg.1])"
|
||||
|
|
@ -18,8 +18,8 @@ StxQuot.lean:18:15: error: expected term
|
|||
"(Term.proj `Nat.one._@.UnhygienicMain._hyg.1 \".\" `b._@.UnhygienicMain._hyg.1)"
|
||||
"(«term_+_» (numLit \"2\") \"+\" (numLit \"1\"))"
|
||||
"(«term_+_» («term_+_» (numLit \"1\") \"+\" (numLit \"2\")) \"+\" (numLit \"1\"))"
|
||||
"(Command.declaration\n (Command.declModifiers [] [] [] [] [] [])\n (Command.def\n \"def\"\n (Command.declId `foo._@.UnhygienicMain._hyg.1 [])\n (Command.optDeclSig [] [])\n (Command.declValSimple \":=\" (numLit \"1\") [])\n (Command.optDefDeriving [])))"
|
||||
"[(Command.declaration\n (Command.declModifiers [] [] [] [] [] [])\n (Command.def\n \"def\"\n (Command.declId `bar._@.UnhygienicMain._hyg.1 [])\n (Command.optDeclSig [] [])\n (Command.declValSimple \":=\" (numLit \"2\") [])\n (Command.optDefDeriving [])))\n (Command.declaration\n (Command.declModifiers [] [] [] [] [] [])\n (Command.def\n \"def\"\n (Command.declId `foo._@.UnhygienicMain._hyg.1 [])\n (Command.optDeclSig [] [])\n (Command.declValSimple \":=\" (numLit \"1\") [])\n (Command.optDefDeriving [])))]"
|
||||
"(Command.declaration\n (Command.declModifiers [] [] [] [] [] [])\n (Command.def\n \"def\"\n (Command.declId `foo._@.UnhygienicMain._hyg.1 [])\n (Command.optDeclSig [] [])\n (Command.declValSimple \":=\" (numLit \"1\") [])\n []))"
|
||||
"[(Command.declaration\n (Command.declModifiers [] [] [] [] [] [])\n (Command.def\n \"def\"\n (Command.declId `bar._@.UnhygienicMain._hyg.1 [])\n (Command.optDeclSig [] [])\n (Command.declValSimple \":=\" (numLit \"2\") [])\n []))\n (Command.declaration\n (Command.declModifiers [] [] [] [] [] [])\n (Command.def\n \"def\"\n (Command.declId `foo._@.UnhygienicMain._hyg.1 [])\n (Command.optDeclSig [] [])\n (Command.declValSimple \":=\" (numLit \"1\") [])\n []))]"
|
||||
"0"
|
||||
0
|
||||
1
|
||||
|
|
|
|||
|
|
@ -12,3 +12,22 @@ deriving instance BEq, Repr for Foo
|
|||
#eval test 4
|
||||
|
||||
#check fun (x y : Foo) => x == y
|
||||
|
||||
def Boo := List (String × String)
|
||||
deriving BEq, Repr
|
||||
|
||||
def mkBoo (s : String) : Boo :=
|
||||
[(s, s)]
|
||||
|
||||
#eval mkBoo "hello"
|
||||
|
||||
#eval mkBoo "hell" == mkBoo "hello"
|
||||
#eval mkBoo "hello" == mkBoo "hello"
|
||||
|
||||
def M := ReaderT String (StateT Nat IO)
|
||||
deriving Monad
|
||||
|
||||
#print instMMonad
|
||||
|
||||
def action : M Unit := do
|
||||
pure ()
|
||||
|
|
|
|||
|
|
@ -5,3 +5,8 @@ defInst.lean:8:26-8:32: error: failed to synthesize instance
|
|||
fun x y => sorry : (x y : Foo) → ?m x y
|
||||
[4, 5, 6]
|
||||
fun x y => x == y : Foo → Foo → Bool
|
||||
[("hello", "hello")]
|
||||
false
|
||||
true
|
||||
def instMMonad : Monad M :=
|
||||
ReaderT.instMonadReaderT
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue