lean4-htt/tmp/eqns/prototype.lean
2020-04-03 15:20:54 -07:00

226 lines
7.4 KiB
Text
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

prelude
import Init.Lean.Meta.Check
import Init.Lean.Meta.Tactic.Cases
import Init.Lean.Meta.GeneralizeTelescope
namespace Lean
namespace Meta
namespace DepElim
/- A pattern occurring on the left-hand side of a match alternative.
   Every constructor carries a `ref : Syntax` (presumably the syntax the pattern originated
   from -- TODO confirm); the test helpers at the end of this file always pass `Syntax.missing`. -/
inductive Pattern
-- An inaccessible term; printed as `.(e)` by `Pattern.toMessageData`.
| inaccessible (ref : Syntax) (e : Expr)
-- A pattern variable, identified by the id of a free variable.
| var (ref : Syntax) (fvarId : FVarId)
-- The constructor named `ctorName` applied to the sub-patterns `fields`.
| ctor (ref : Syntax) (ctorName : Name) (fields : List Pattern)
-- A value pattern; printed as `val!(e)` by `Pattern.toMessageData`.
| val (ref : Syntax) (e : Expr)
-- An array literal whose elements are the sub-patterns `xs`.
| arrayLit (ref : Syntax) (xs : List Pattern)
namespace Pattern
-- Default `Pattern`: an array literal with no elements and no source syntax.
instance : Inhabited Pattern := Inhabited.mk (Pattern.arrayLit Syntax.missing [])
/- Render a pattern as `MessageData`.
   * inaccessible term  ↦ `.(e)`
   * variable           ↦ the free variable itself
   * constructor        ↦ `C` when it has no fields, `(C p₁ … pₙ)` otherwise
   * value              ↦ `val!(e)`
   * array literal      ↦ `#[p₁, …, pₙ]`
   Marked `partial` because it recurses through the nested `List Pattern` fields. -/
partial def toMessageData : Pattern → MessageData
| inaccessible _ tm  => ".(" ++ tm ++ ")"
| var _ fv           => mkFVar fv
| ctor _ cname []    => cname
| ctor _ cname args  => "(" ++ cname ++ args.foldl (fun (acc : MessageData) sub => acc ++ " " ++ toMessageData sub) Format.nil ++ ")"
| val _ tm           => "val!(" ++ tm ++ ")"
| arrayLit _ elems   => "#[" ++ MessageData.joinSep (elems.map toMessageData) ", " ++ "]"
end Pattern
/- Left-hand side of a single match alternative, as passed to `mkElim`. -/
structure AltLHS :=
(fvarDecls : List LocalDecl) -- Declarations of the free variables used in the patterns.
(patterns : List Pattern) -- A `List Pattern` because match-expressions are n-ary; `checkNumPatterns` requires one pattern per major premise.
/- Value type of `AltToMinorsMap`.
   NOTE(review): presumably describes the contiguous block of minor premises generated for one
   alternative (starting position and count) -- nothing in this file populates it yet, confirm. -/
structure MinorsRange :=
(firstMinorPos : Nat)
(numMinors : Nat)
-- Map from a `Nat` key to a `MinorsRange`. NOTE(review): the key is presumably the alternative index (`Alt.idx`) -- confirm.
abbrev AltToMinorsMap := PersistentHashMap Nat MinorsRange
/- A match alternative as tracked inside a `Problem`.
   `mkAlts` builds one `Alt` per `AltLHS`, copying `fvarDecls`/`patterns` and numbering them from 0. -/
structure Alt :=
(idx : Nat) -- Position of the alternative in the original list; used for generating error messages.
(fvarDecls : List LocalDecl) -- Declarations of the free variables used in `patterns`.
(patterns : List Pattern) -- The patterns of this alternative.
namespace Alt
-- Default `Alt`: index 0 with no free variables and no patterns.
instance : Inhabited Alt := ⟨{ idx := 0, fvarDecls := [], patterns := [] }⟩
/- Render the patterns of `alt` as `⟦p₁, …, pₙ⟧`.
   The free-variable declarations of the alternative are added to the current local context
   before the message is built, and the message is wrapped with `addContext` under that
   extended context, so that the pattern variables can be displayed.
   This definition is not recursive, so it does not need to be `partial`. -/
def toMessageData (alt : Alt) : MetaM MessageData := do
  lctx ← getLCtx;
  localInsts ← getLocalInstances;
  -- Extend the ambient local context with the variables bound by this alternative.
  let lctx := alt.fvarDecls.foldl (fun (lctx : LocalContext) decl => lctx.addDecl decl) lctx;
  withLocalContext lctx localInsts $ do
    let msg : MessageData := "⟦" ++ MessageData.joinSep (alt.patterns.map Pattern.toMessageData) ", " ++ "⟧";
    addContext msg
end Alt
/- A dependent-elimination problem handed to `process`. -/
structure Problem :=
(goal : Expr) -- In `mkElim`: a fresh metavariable whose type is the motive applied to the majors.
(vars : List Expr) -- In `mkElim`: the (generalized) major premises.
(alts : List Alt) -- The alternatives still under consideration.
namespace Problem
-- Default `Problem`: an arbitrary goal with no variables and no alternatives.
instance : Inhabited Problem := ⟨{ goal := arbitrary _, vars := [], alts := [] }⟩
/- Render a problem for tracing: a `vars …` line listing the variables, followed by one
   line per alternative (each rendered with `Alt.toMessageData`). -/
def toMessageData (p : Problem) : MetaM MessageData := do
  altMsgs ← p.alts.mapM Alt.toMessageData;
  pure ("vars " ++ p.vars.toArray ++ Format.line ++ MessageData.joinSep altMsgs Format.line)
end Problem
/- Result returned by `mkElim`.
   NOTE(review): `mkElim` currently fills every field with a placeholder (see its TODO).
   NOTE(review): the `numMinors`/`numEqs` comments below are kept from the original; how "one minor
   per alternative" fits with `altMap` expanding an alternative into several minors should be confirmed. -/
structure ElimResult :=
(numMinors : Nat) -- It is the number of alternatives (Reason: support for overlapping equations)
(numEqs : Nat) -- It is the number of minors (Reason: users may want equations that hold definitionally)
(elim : Expr) -- The eliminator. It is not just `Expr.const elimName` because the type of the major premises may contain free variables.
(altMap : AltToMinorsMap) -- Each alternative may be "expanded" into multiple minor premises.
/- Fail with "incorrect number of patterns" unless every alternative in `lhss` has exactly
   one pattern per major premise. -/
private def checkNumPatterns (majors : List Expr) (lhss : List AltLHS) : MetaM Unit :=
  let expected := majors.length;
  if lhss.any (fun lhs => lhs.patterns.length != expected) then
    throw (Exception.other "incorrect number of patterns")
  else
    pure ()
/- Sort used as the codomain of the motive: `Sort 0` when `inProp` is set, otherwise
   `Sort v` where `v` is a level parameter named by a freshly generated id. -/
private def mkElimSort (inProp : Bool) : MetaM Expr :=
  match inProp with
  | true  => pure (mkSort levelZero)
  | false => do
    paramName ← mkFreshId;
    pure (mkSort (mkLevelParam paramName))
/- Run `k` with a new local declaration `motive` whose type is the `mkForall` abstraction of
   `sortv` over `majors`. The motive type is traced under `Meta.debug` first. -/
private def withMotive {α} (majors : Array Expr) (sortv : Expr) (k : Expr → MetaM α) : MetaM α := do
  motiveType ← mkForall majors sortv;
  trace! `Meta.debug motiveType;
  withLocalDecl `motive motiveType BinderInfo.default k
/- Convert each `AltLHS` into an `Alt`, numbering the alternatives 0, 1, 2, … in input order.
   The list is accumulated in reverse (so the next index is the accumulator's length) and
   reversed once at the end. -/
private def mkAlts (lhss : List AltLHS) : List Alt :=
  let revAlts : List Alt := lhss.foldl
    (fun acc lhs => { idx := acc.length, fvarDecls := lhs.fvarDecls, patterns := lhs.patterns } :: acc)
    [];
  revAlts.reverse
/- Process a dependent-elimination problem.
   At present this only traces the problem under `Meta.debug` (inside `withIncRecDepth`)
   and returns. -/
private def process (p : Problem) : MetaM Unit :=
  withIncRecDepth (do
    traceM `Meta.debug p.toMessageData;
    pure ())
/- Entry point of the prototype.
   Steps: check that every alternative has one pattern per major, generalize the majors with
   `generalizeTelescope`, pick the motive's sort (`mkElimSort`), introduce the motive
   (`withMotive`), create a goal metavariable of type `motive majors`, and hand the resulting
   `Problem` to `process`.
   `elimName` is currently unused and the returned `ElimResult` is a placeholder. -/
def mkElim (elimName : Name) (majors : List Expr) (lhss : List AltLHS) (inProp : Bool := false) : MetaM ElimResult := do
  checkNumPatterns majors lhss;
  generalizeTelescope majors.toArray `_d (fun majorVars => do
    sortv ← mkElimSort inProp;
    withMotive majorVars sortv (fun motive => do
      goal ← mkFreshExprMVar (mkAppN motive majorVars);
      let problem := { Problem . goal := goal, vars := majorVars.toList, alts := mkAlts lhss };
      process problem;
      -- TODO: placeholder result
      pure { numMinors := 0, numEqs := 0, elim := arbitrary _, altMap := {} }))
end DepElim
end Meta
end Lean
open Lean
open Lean.Meta
open Lean.Meta.DepElim
/- Infrastructure for testing -/
universes u v
-- Identity function used as a marker in test inputs: `mkPattern` turns an application of `inaccessible` (arity 2: the implicit type and the term) into `Pattern.inaccessible`.
def inaccessible {α : Sort u} (a : α) : α := a
-- Identity function used as a marker in test inputs: `mkPattern` turns an application of `val` (arity 2: the implicit type and the term) into `Pattern.val`.
def val {α : Sort u} (a : α) : α := a
/- Convert an expression that uses the auxiliary hints `inaccessible` and `val` into a `Pattern`.
   Cases are tried in this order:
   1. `val e`           ↦ `Pattern.val`
   2. `inaccessible e`  ↦ `Pattern.inaccessible`
   3. a free variable   ↦ `Pattern.var`
   4. an array literal  ↦ `Pattern.arrayLit` (elements converted recursively)
   5. a constructor application ↦ `Pattern.ctor`, where only the arguments after the
      constructor's parameters are converted (recursively) into field patterns
   Anything else fails with "unexpected pattern". All patterns get `Syntax.missing` as `ref`. -/
partial def mkPattern : Expr → MetaM Pattern
| e =>
  if e.isAppOfArity `val 2 then
    pure (Pattern.val Syntax.missing e.appArg!)
  else if e.isAppOfArity `inaccessible 2 then
    pure (Pattern.inaccessible Syntax.missing e.appArg!)
  else if e.isFVar then
    pure (Pattern.var Syntax.missing e.fvarId!)
  else match e.arrayLit? with
    | some elems => do
      elemPats ← elems.mapM mkPattern;
      pure (Pattern.arrayLit Syntax.missing elemPats)
    | none => do
      ctorVal? ← constructorApp? e;
      match ctorVal? with
      | none => throw (Exception.other "unexpected pattern")
      | some ctorVal => do
        let allArgs := e.getAppArgs;
        -- Skip the inductive type's parameters; only the constructor fields are patterns.
        let fieldArgs := allArgs.extract ctorVal.nparams allArgs.size;
        fieldPats ← fieldArgs.toList.mapM mkPattern;
        pure (Pattern.ctor Syntax.missing ctorVal.name fieldPats)
/- Decode the pattern tuple of one alternative.
   `Pat p` yields the single pattern `mkPattern p`; a product `A × B` yields the patterns
   decoded from `A` followed by those decoded from `B`. Anything else fails with
   "unexpected pattern". -/
partial def decodePats : Expr → MetaM (List Pattern)
| e =>
  match e.app2? `Pat with
  | some (_, arg) => do p ← mkPattern arg; pure [p]
  | none =>
    match e.prod? with
    | none => throw (Exception.other "unexpected pattern")
    | some (leftTy, rightTy) => do
      leftPats ← decodePats leftTy;
      rightPats ← decodePats rightTy;
      pure (leftPats ++ rightPats)
/- Decode one alternative written as `forall (x₁ …) …, Pat p₁ × … × Pat pₙ`.
   The binders introduced by `forallTelescopeReducing` become the alternative's free-variable
   declarations, and the body is decoded with `decodePats`.
   This definition is not recursive, so it does not need to be `partial`. -/
def decodeAltLHS (e : Expr) : MetaM AltLHS :=
forallTelescopeReducing e $ fun args body => do
  decls ← args.toList.mapM (fun arg => getLocalDecl arg.fvarId!);
  pats ← decodePats body;
  pure { fvarDecls := decls, patterns := pats }
/- Decode a tuple of alternatives.
   `LHS t` yields the single alternative `decodeAltLHS t`; a product `A × B` yields the
   alternatives decoded from `A` followed by those decoded from `B`. Anything else fails with
   "unexpected LHS". -/
partial def decodeAltLHSs : Expr → MetaM (List AltLHS)
| e =>
  match e.app2? `LHS with
  | some (_, arg) => do alt ← decodeAltLHS arg; pure [alt]
  | none =>
    match e.prod? with
    | none => throw (Exception.other "unexpected LHS")
    | some (leftTy, rightTy) => do
      leftAlts ← decodeAltLHSs leftTy;
      rightAlts ← decodeAltLHSs rightTy;
      pure (leftAlts ++ rightAlts)
/- Read a test problem from the type of the declaration `declName`.
   The declaration's type is opened with `forallTelescopeReducing`; its last `numPats`
   parameters are taken as the major premises and its result type is decoded into the list of
   alternatives, both of which are passed to `k`.
   Fails with "insufficient number of parameters" when there are fewer than `numPats` parameters. -/
def withDepElimFrom {α} (declName : Name) (numPats : Nat) (k : List FVarId → List AltLHS → MetaM α) : MetaM α := do
  cinfo ← getConstInfo declName;
  forallTelescopeReducing cinfo.type (fun params body =>
    if params.size < numPats then
      throw (Exception.other "insufficient number of parameters")
    else do
      let firstMajor := params.size - numPats;
      let majorIds := (params.extract firstMajor params.size).toList.map Expr.fvarId!;
      lhss ← decodeAltLHSs body;
      k majorIds lhss)
/- Marker type for test inputs: `Pat p` wraps the term `p` that should be read as a pattern.
   `decodePats` recognizes applications of `Pat` and converts the wrapped term with `mkPattern`. -/
inductive Pat {α : Sort u} (a : α) : Type u
| mk {} : Pat
/- Marker type for test inputs: `LHS t` wraps the encoding `t` of one alternative.
   `decodeAltLHSs` recognizes applications of `LHS` and decodes the wrapped term with `decodeAltLHS`. -/
inductive LHS {α : Sort u} (a : α) : Type u
| mk {} : LHS
-- Every `LHS a` is inhabited (by `LHS.mk`).
instance LHS.inhabited {α} (a : α) : Inhabited (LHS a) := ⟨LHS.mk⟩
/- Test problem: a match on the two majors `x : List α` and `y : List β` (the last two
   parameters) with four alternatives:
     1. `[]`, `[]`
     2. `[a]`, `[b]`
     3. `a₁::a₂::as`, `b₁::b₂::bs`
     4. `as`, `bs`
   Only the type matters (it is decoded by `withDepElimFrom`); the value is `arbitrary _`.
   The second pattern of every alternative is typed over `β`, matching the type of `y`. -/
def ex1 (α : Type u) (β : Type v) (n : Nat) (x : List α) (y : List β) :
  LHS (Pat ([] : List α) × Pat ([] : List β))
× LHS (forall (a : α) (b : β), Pat [a] × Pat [b])
× LHS (forall (a₁ a₂ : α) (as : List α) (b₁ b₂ : β) (bs : List β), Pat (a₁::a₂::as) × Pat (b₁::b₂::bs))
× LHS (forall (as : List α) (bs : List β), Pat as × Pat bs)
:= arbitrary _
-- Initialization action (`@[init]`) that registers the trace class `Meta.mkElim`.
-- NOTE(review): the traces emitted in this file use the class `Meta.debug`, not `Meta.mkElim` -- confirm which one is intended.
@[init] def register : IO Unit :=
registerTraceClass `Meta.mkElim
set_option trace.Meta.debug true
/- Smoke test: decode the problem encoded by `ex1` (last 2 parameters are the majors),
   trace the majors under `Meta.debug`, and run `mkElim` on it, discarding the result. -/
def tst1 : MetaM Unit :=
  withDepElimFrom `ex1 2 (fun majorIds lhss => do
    let majorExprs := majorIds.map mkFVar;
    trace! `Meta.debug majorExprs.toArray;
    mkElim `test majorExprs lhss;
    pure ())
#eval tst1