test: deriving experiment
This commit is contained in:
parent
5249fdc24d
commit
e9a1c3ac44
1 changed files with 62 additions and 16 deletions
|
|
@ -108,14 +108,25 @@ structure AuxFun where
|
|||
structure State where
|
||||
auxFns : NameMap AuxFun -- type name to function
|
||||
|
||||
structure Context where
|
||||
structure ContextCore where
|
||||
classInfo : ConstantInfo
|
||||
typeInfos : Array InductiveVal
|
||||
auxFunNames : Array Name
|
||||
usePartial : Bool
|
||||
resultType : Syntax
|
||||
|
||||
def mkContext (className : Name) (typeName : Name) (resultType : Syntax) : TermElabM Context := do
|
||||
structure Header where
|
||||
binders : Array Syntax
|
||||
argNames : Array Name
|
||||
targetName : Name
|
||||
|
||||
abbrev MkAltRhs :=
|
||||
ContextCore → (ctorName : Name) → (ctorArgs : Array (Syntax × Expr)) → TermElabM Syntax
|
||||
|
||||
structure Context extends ContextCore where
|
||||
mkAltRhs : MkAltRhs
|
||||
|
||||
def mkContext (className : Name) (typeName : Name) (resultType : Syntax) (mkAltRhs : MkAltRhs) : TermElabM Context := do
|
||||
let indVal ← getConstInfoInduct typeName
|
||||
let mut typeInfos := #[]
|
||||
for typeName in indVal.all do
|
||||
|
|
@ -134,6 +145,7 @@ def mkContext (className : Name) (typeName : Name) (resultType : Syntax) : TermE
|
|||
auxFunNames := auxFunNames
|
||||
usePartial := usePartial
|
||||
resultType := resultType
|
||||
mkAltRhs := mkAltRhs
|
||||
}
|
||||
|
||||
def mkInductArgNames (indVal : InductiveVal) : TermElabM (Array Name) := do
|
||||
|
|
@ -174,11 +186,6 @@ def mkInstImplicitBinders (ctx : Context) (indVal : InductiveVal) (argNames : Ar
|
|||
pure ()
|
||||
return binders
|
||||
|
||||
structure Header where
|
||||
binders : Array Syntax
|
||||
argNames : Array Name
|
||||
targetName : Name
|
||||
|
||||
def mkHeader (ctx : Context) (indVal : InductiveVal) : TermElabM Header := do
|
||||
let argNames ← mkInductArgNames indVal
|
||||
let binders ← mkImplicitBinders argNames
|
||||
|
|
@ -215,11 +222,49 @@ def mkLet (letDecls : Array Syntax) (body : Syntax) : TermElabM Syntax :=
|
|||
letDecls.foldrM (init := body) fun letDecl body =>
|
||||
`(let $letDecl:letDecl; $body)
|
||||
|
||||
def matchAltExpr := Parser.Term.matchAlt
|
||||
|
||||
def mkMatch (ctx : Context) (header : Header) (indVal : InductiveVal) (argNames : Array Name) : TermElabM Syntax := do
|
||||
let discrs ← mkDiscrs
|
||||
let alts ← mkAlts
|
||||
`(match $[$discrs],* with | $[$alts:matchAlt]|*)
|
||||
where
|
||||
mkDiscr (varName : Name) : TermElabM Syntax :=
|
||||
`(Parser.Term.matchDiscr| $(mkIdent varName):term)
|
||||
|
||||
mkDiscrs : TermElabM (Array Syntax) := do
|
||||
let mut discrs := #[]
|
||||
-- add indices
|
||||
for argName in argNames[indVal.nparams:] do
|
||||
discrs := discrs.push (← mkDiscr argName)
|
||||
return discrs.push (← mkDiscr header.targetName)
|
||||
|
||||
mkAlts : TermElabM (Array Syntax) := do
|
||||
let mut alts := #[]
|
||||
for ctorName in indVal.ctors do
|
||||
let mut patterns := #[]
|
||||
-- add `_` pattern for indices
|
||||
for i in [:indVal.nindices] do
|
||||
patterns := patterns.push (← `(_))
|
||||
let ctorInfo ← getConstInfoCtor ctorName
|
||||
let mut ctorArgs := #[]
|
||||
-- add `_` for inductive parameters, they are inaccessible
|
||||
for i in [:indVal.nparams] do
|
||||
ctorArgs := ctorArgs.push (← `(_))
|
||||
for i in [:ctorInfo.nfields] do
|
||||
ctorArgs := ctorArgs.push (mkIdent (← mkFreshUserName `y))
|
||||
patterns := patterns.push (← `(@$(mkIdent ctorName):ident $ctorArgs:term*))
|
||||
let altRhs ← forallTelescopeReducing ctorInfo.type fun xs _ =>
|
||||
ctx.mkAltRhs ctx.toContextCore ctorName (Array.zip ctorArgs xs)
|
||||
let alt ← `(matchAltExpr| $[$patterns:term],* => $altRhs:term)
|
||||
alts := alts.push alt
|
||||
return alts
|
||||
|
||||
def mkAuxFunction (ctx : Context) (i : Nat) : TermElabM Syntax := do
|
||||
let auxFunName ← ctx.auxFunNames[i]
|
||||
let indVal ← ctx.typeInfos[i]
|
||||
let header ← mkHeader ctx indVal
|
||||
let mut body ← `("testing") -- TODO
|
||||
let mut body ← mkMatch ctx header indVal header.argNames
|
||||
if ctx.usePartial then
|
||||
let letDecls ← mkLocalInstanceLetDecls ctx header.argNames
|
||||
body ← mkLet letDecls body
|
||||
|
|
@ -251,23 +296,24 @@ def mkInstanceCmds (ctx : Context) : TermElabM (Array Syntax) := do
|
|||
instances := instances.push instCmd
|
||||
return instances
|
||||
|
||||
open Command in
|
||||
def tst (className : Name) (typeName : Name) (resultTypeName : Name) : CommandElabM Unit := do
|
||||
open Command
|
||||
|
||||
def mkDeriving (className : Name) (typeName : Name) (resultType : Syntax) (mkAltRhs : MkAltRhs) : CommandElabM Unit := do
|
||||
let cmds ← liftTermElabM none do
|
||||
let resultType := mkIdent resultTypeName
|
||||
let ctx ← mkContext className typeName resultType
|
||||
let ctx ← mkContext className typeName resultType mkAltRhs
|
||||
let block ← mkMutualBlock ctx
|
||||
trace[Meta.debug]! "\n{block}"
|
||||
return #[block] ++ (← mkInstanceCmds ctx)
|
||||
cmds.forM elabCommand
|
||||
|
||||
def mkDerivingToString (typeName : Name) : CommandElabM Unit := do
|
||||
mkDeriving `ToString typeName (← `(String)) fun ctx ctorName ctorArgs =>
|
||||
quote (toString ctorName) -- TODO
|
||||
|
||||
syntax[runTstKind] "runTst" : command
|
||||
|
||||
open Command in
|
||||
@[commandElab runTstKind] def elabTst : CommandElab := fun stx =>
|
||||
tst `ToString `Vect `String
|
||||
mkDerivingToString `Test.Foo
|
||||
|
||||
set_option trace.Meta.debug true
|
||||
runTst
|
||||
|
||||
#eval (Vect.nil : Vect Nat Nat _ _)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue