prelude
import Init.Lean.Meta.Check
import Init.Lean.Meta.Tactic.Cases
import Init.Lean.Meta.GeneralizeTelescope

namespace Lean
namespace Meta
namespace DepElim

/- Patterns that may occur on the left-hand side of an alternative.
   Each constructor stores a `Syntax` reference (`ref`) as its first field. -/
inductive Pattern
| inaccessible (ref : Syntax) (e : Expr)
| var          (ref : Syntax) (fvarId : FVarId)
| ctor         (ref : Syntax) (ctorName : Name) (fields : List Pattern)
| val          (ref : Syntax) (e : Expr)
| arrayLit     (ref : Syntax) (xs : List Pattern)

namespace Pattern

-- Default element: an empty array-literal pattern whose syntax reference is missing.
instance : Inhabited Pattern := ⟨arrayLit Syntax.missing []⟩

/- Pretty-print a pattern:
   `.(e)` for inaccessible terms, the free variable itself for `var`,
   `C` for a field-less constructor and `(C p₁ … pₙ)` otherwise,
   `val!(e)` for values, and `#[p₁, …, pₙ]` for array literals. -/
partial def toMessageData : Pattern → MessageData
| inaccessible _ t => ".(" ++ t ++ ")"
| var _ x          => mkFVar x
| ctor _ c []      => c
| ctor _ c ps      =>
  -- every field is preceded by a single space
  let args : MessageData := ps.foldl (fun (acc : MessageData) p => acc ++ " " ++ toMessageData p) Format.nil;
  "(" ++ c ++ args ++ ")"
| val _ t          => "val!(" ++ t ++ ")"
| arrayLit _ ps    => "#[" ++ MessageData.joinSep (ps.map toMessageData) ", " ++ "]"

end Pattern

/- Left-hand side of a single alternative as provided by the caller. -/
structure AltLHS :=
-- Free variables used in the patterns.
(fvarDecls : List LocalDecl)
-- We use `List Pattern` since we have nary match-expressions.
(patterns : List Pattern)
/- A contiguous range of minor premises: it starts at `firstMinorPos` and has `numMinors` entries.
   NOTE(review): presumably the minors generated for one alternative — `AltToMinorsMap`
   below maps a `Nat` key (likely an alternative index) to such a range. -/
structure MinorsRange :=
(firstMinorPos : Nat)
(numMinors : Nat)

abbrev AltToMinorsMap := PersistentHashMap Nat MinorsRange

/- An alternative being processed. `idx` is its position in the caller's list of LHSs
   (assigned by `mkAlts`). -/
structure Alt :=
(idx : Nat) -- for generating error messages
(fvarDecls : List LocalDecl)
(patterns : List Pattern)

namespace Alt

instance : Inhabited Alt := ⟨⟨0, [], []⟩⟩

/- Render the alternative's patterns as `⟦p₁, …, pₙ⟧`. The alternative's pattern variables
   are added to the local context first so that they can be displayed.
   Not recursive, so no `partial` marker is needed. -/
def toMessageData (alt : Alt) : MetaM MessageData := do
  lctx ← getLCtx;
  localInsts ← getLocalInstances;
  let lctx := alt.fvarDecls.foldl (fun (lctx : LocalContext) decl => lctx.addDecl decl) lctx;
  withLocalContext lctx localInsts $ do
    let msg : MessageData := "⟦" ++ MessageData.joinSep (alt.patterns.map Pattern.toMessageData) ", " ++ "⟧";
    addContext msg

end Alt

/- A pattern-matching problem: the goal to solve, the variables being matched,
   and the remaining alternatives. -/
structure Problem :=
(goal : Expr)
(vars : List Expr)
(alts : List Alt)

namespace Problem

instance : Inhabited Problem := ⟨⟨arbitrary _, [], []⟩⟩

/- Render as a `vars …` line followed by one line per alternative. -/
def toMessageData (p : Problem) : MetaM MessageData := do
  alts ← p.alts.mapM Alt.toMessageData;
  pure $ "vars " ++ p.vars.toArray ++ Format.line ++ MessageData.joinSep alts Format.line

end Problem

structure ElimResult :=
(numMinors : Nat)         -- It is the number of alternatives (Reason: support for overlapping equations)
(numEqs : Nat)            -- It is the number of minors (Reason: users may want equations that hold definitionally)
(elim : Expr)             -- The eliminator. It is not just `Expr.const elimName` because the type of the major premises may contain free variables.
(altMap : AltToMinorsMap) -- each alternative may be "expanded" into multiple minor premises

/- Fail with "incorrect number of patterns" if some LHS does not have exactly one pattern per major. -/
private def checkNumPatterns (majors : List Expr) (lhss : List AltLHS) : MetaM Unit :=
  let num := majors.length;
  when (lhss.any (fun lhs => lhs.patterns.length != num)) $
    throw $ Exception.other "incorrect number of patterns"

/- Result sort of the motive: `Sort 0` when `inProp`, otherwise `Sort v` for a fresh level parameter `v`. -/
private def mkElimSort (inProp : Bool) : MetaM Expr :=
  if inProp then
    pure $ mkSort $ levelZero
  else do
    vId ← mkFreshId;
    pure $ mkSort $ mkLevelParam vId

/- Build the motive type `∀ majors, sortv`, trace it on `Meta.debug`, and run `k`
   with a new local `motive` of that type. -/
private def withMotive {α} (majors : Array Expr) (sortv : Expr) (k : Expr → MetaM α) : MetaM α := do
  type ← mkForall majors sortv;
  trace! `Meta.debug type;
  withLocalDecl `motive type BinderInfo.default k

/- Turn the caller's LHSs into `Alt`s, numbering them `0, 1, …` in their original order.
   The running index is carried in the fold accumulator; the original recomputed it with
   `List.length` at every step, which made this quadratic. -/
private def mkAlts (lhss : List AltLHS) : List Alt :=
  let step : Nat × List Alt → AltLHS → Nat × List Alt := fun acc lhs =>
    (acc.1 + 1, Alt.mk acc.1 lhs.fvarDecls lhs.patterns :: acc.2);
  -- alternatives are accumulated in reverse order
  (lhss.foldl step (0, [])).2.reverse

/- Work-in-progress: currently it only traces the problem. -/
private def process : Problem → MetaM Unit
| p => withIncRecDepth $ do
  traceM `Meta.debug p.toMessageData;
  pure ()

/- Entry point.
   1. Check that pattern counts match the majors.
   2. Generalize the majors.
   3. Create the motive and a goal metavariable of type `motive majors`.
   4. Hand the resulting `Problem` to `process`.
   The returned `ElimResult` is still a placeholder (see TODO). -/
def mkElim (elimName : Name) (majors : List Expr) (lhss : List AltLHS) (inProp : Bool := false) : MetaM ElimResult := do
  checkNumPatterns majors lhss;
  generalizeTelescope majors.toArray `_d $ fun majors => do
    sortv ← mkElimSort inProp;
    withMotive majors sortv $ fun motive => do
      let target := mkAppN motive majors;
      goal ← mkFreshExprMVar target;
      let alts := mkAlts lhss;
      let problem := { Problem . goal := goal, vars := majors.toList, alts := alts };
      process problem;
      pure { numMinors := 0, numEqs := 0, elim := arbitrary _, altMap := {} } -- TODO

end DepElim
end Meta
end Lean

open Lean
open Lean.Meta
open Lean.Meta.DepElim

/- Infrastructure for testing -/

universes u v

-- Identity functions used only as markers inside test patterns (recognized by `mkPattern`).
def inaccessible {α : Sort u} (a : α) : α := a
def val {α : Sort u} (a : α) : α := a

/- Convert expression using auxiliary hints `inaccessible` and `val` into a pattern.
   Free variables become `var` patterns, array literals become `arrayLit` patterns,
   and constructor applications become `ctor` patterns (constructor parameters are dropped,
   only the fields are kept). Anything else fails with "unexpected pattern". -/
partial def mkPattern : Expr → MetaM Pattern
| e =>
  if e.isAppOfArity `val 2 then
    pure $ Pattern.val Syntax.missing e.appArg!
  else if e.isAppOfArity `inaccessible 2 then
    pure $ Pattern.inaccessible Syntax.missing e.appArg!
  else if e.isFVar then
    pure $ Pattern.var Syntax.missing e.fvarId!
  else match e.arrayLit? with
    | some es => do
      pats ← es.mapM mkPattern;
      pure $ Pattern.arrayLit Syntax.missing pats
    | none => do
      cval? ← constructorApp? e;
      match cval? with
      | none => throw $ Exception.other "unexpected pattern"
      | some cval => do
        let args := e.getAppArgs;
        let fields := args.extract cval.nparams args.size;
        pats ← fields.toList.mapM mkPattern;
        pure $ Pattern.ctor Syntax.missing cval.name pats

/- Decode `Pat p` into `[p]`, and a product `A × B` into the concatenation of the decodings of `A` and `B`. -/
partial def decodePats : Expr → MetaM (List Pattern)
| e =>
  match e.app2? `Pat with
  | some (_, pat) => do
    pat ← mkPattern pat;
    pure [pat]
  | none =>
    match e.prod? with
    | none => throw $ Exception.other "unexpected pattern"
    | some (pat, pats) => do
      pat ← decodePats pat;
      pats ← decodePats pats;
      pure (pat ++ pats)

/- Decode `∀ xs, Pat p₁ × … × Pat pₙ`: the binders become the pattern variables.
   Not recursive, so no `partial` marker is needed. -/
def decodeAltLHS (e : Expr) : MetaM AltLHS :=
  forallTelescopeReducing e $ fun args body => do
    decls ← args.toList.mapM (fun arg => getLocalDecl arg.fvarId!);
    pats ← decodePats body;
    pure { fvarDecls := decls, patterns := pats }

/- Decode `LHS a` into `[a]`, and a product `A × B` into the concatenation of the decodings of `A` and `B`. -/
partial def decodeAltLHSs : Expr → MetaM (List AltLHS)
| e =>
  match e.app2? `LHS with
  | some (_, lhs) => do
    lhs ← decodeAltLHS lhs;
    pure [lhs]
  | none =>
    match e.prod? with
    | none => throw $ Exception.other "unexpected LHS"
    | some (lhs, lhss) => do
      lhs ← decodeAltLHSs lhs;
      lhss ← decodeAltLHSs lhss;
      pure (lhs ++ lhss)

/- Open the telescope of `declName`'s type. The last `numPats` binders are treated as the
   majors, and the body is decoded as a list of LHSs. -/
def withDepElimFrom {α} (declName : Name) (numPats : Nat) (k : List FVarId → List AltLHS → MetaM α) : MetaM α := do
  cinfo ← getConstInfo declName;
  forallTelescopeReducing cinfo.type $ fun args body =>
    if args.size < numPats then
      throw $ Exception.other "insufficient number of parameters"
    else do
      let xs := (args.extract (args.size - numPats) args.size).toList.map $ Expr.fvarId!;
      alts ← decodeAltLHSs body;
      k xs alts

inductive Pat {α : Sort u} (a : α) : Type u
| mk {} : Pat

inductive LHS {α : Sort u} (a : α) : Type u
| mk {} : LHS

instance LHS.inhabited {α} (a : α) : Inhabited (LHS a) := ⟨LHS.mk⟩

/- Test fixture. The majors are the last two binders, `x : List α` and `y : List β`,
   so every second pattern must have type `List β`. The first two LHSs previously used `α` there. -/
def ex1 (α : Type u) (β : Type v) (n : Nat) (x : List α) (y : List β) :
  LHS (Pat ([] : List α) × Pat ([] : List β)) ×
  LHS (forall (a : α) (b : β), Pat [a] × Pat [b]) ×
  LHS (forall (a₁ a₂ : α) (as : List α) (b₁ b₂ : β) (bs : List β), Pat (a₁::a₂::as) × Pat (b₁::b₂::bs)) ×
  LHS (forall (as : List α) (bs : List β), Pat as × Pat bs) :=
  arbitrary _

-- NOTE(review): registers `Meta.mkElim`, but the code in this file only traces on `Meta.debug`.
@[init] def register : IO Unit :=
  registerTraceClass `Meta.mkElim

set_option trace.Meta.debug true

def tst1 : MetaM Unit :=
  withDepElimFrom `ex1 2 $ fun majors alts => do
    let majors := majors.map mkFVar;
    trace! `Meta.debug majors.toArray;
    mkElim `test majors alts;
    pure ()

#eval tst1