feat: add processComplete
This commit is contained in:
parent
d36ccaa620
commit
e43b5e27a1
3 changed files with 145 additions and 19 deletions
|
|
@ -635,11 +635,13 @@ constant abstract (e : @& Expr) (xs : @& Array Expr) : Expr := arbitrary _
|
|||
@[extern "lean_expr_abstract_range"]
|
||||
constant abstractRange (e : @& Expr) (n : @& Nat) (xs : @& Array Expr) : Expr := arbitrary _
|
||||
|
||||
/-- Replace occurrences of the free variable `fvar` in `e` with `v` -/
|
||||
def replaceFVar (e : Expr) (fvar : Expr) (v : Expr) : Expr :=
|
||||
(e.abstract #[fvar]).instantiate1 v
|
||||
|
||||
def replaceFVarId (e : Expr) (fvarId : FVarId) (newFVarId : FVarId) : Expr :=
|
||||
replaceFVar e (mkFVar fvarId) (mkFVar newFVarId)
|
||||
/-- Replace occurrences of the free variable `fvarId` in `e` with `v` -/
|
||||
def replaceFVarId (e : Expr) (fvarId : FVarId) (v : Expr) : Expr :=
|
||||
replaceFVar e (mkFVar fvarId) v
|
||||
|
||||
instance : HasToString Expr :=
|
||||
⟨Expr.dbgToString⟩
|
||||
|
|
|
|||
|
|
@ -35,6 +35,10 @@ def index : LocalDecl → Nat
|
|||
| cdecl idx _ _ _ _ => idx
|
||||
| ldecl idx _ _ _ _ => idx
|
||||
|
||||
def setIndex : LocalDecl → Nat → LocalDecl
|
||||
| cdecl _ id n t bi, idx => cdecl idx id n t bi
|
||||
| ldecl _ id n t v, idx => ldecl idx id n t v
|
||||
|
||||
def fvarId : LocalDecl → FVarId
|
||||
| cdecl _ id _ _ _ => id
|
||||
| ldecl _ id _ _ _ => id
|
||||
|
|
@ -112,7 +116,10 @@ match lctx with
|
|||
/- Low level API -/
|
||||
def addDecl (lctx : LocalContext) (newDecl : LocalDecl) : LocalContext :=
|
||||
match lctx with
|
||||
| { fvarIdToDecl := map, decls := decls } => { fvarIdToDecl := map.insert newDecl.fvarId newDecl, decls := decls.set newDecl.index newDecl }
|
||||
| { fvarIdToDecl := map, decls := decls } =>
|
||||
let idx := decls.size;
|
||||
let newDecl := newDecl.setIndex idx;
|
||||
{ fvarIdToDecl := map.insert newDecl.fvarId newDecl, decls := decls.push newDecl }
|
||||
|
||||
@[export lean_local_ctx_find]
|
||||
def find? (lctx : LocalContext) (fvarId : FVarId) : Option LocalDecl :=
|
||||
|
|
|
|||
|
|
@ -18,6 +18,13 @@ namespace Pattern
|
|||
|
||||
instance : Inhabited Pattern := ⟨Pattern.inaccessible Syntax.missing (arbitrary _)⟩
|
||||
|
||||
def ref : Pattern → Syntax
|
||||
| inaccessible r _ => r
|
||||
| var r _ => r
|
||||
| ctor r _ _ _ _ => r
|
||||
| val r _ => r
|
||||
| arrayLit r _ _ => r
|
||||
|
||||
partial def toMessageData : Pattern → MessageData
|
||||
| inaccessible _ e => ".(" ++ e ++ ")"
|
||||
| var _ fvarId => mkFVar fvarId
|
||||
|
|
@ -37,13 +44,7 @@ partial def toExpr : Pattern → MetaM Expr
|
|||
fields ← fields.mapM toExpr;
|
||||
pure $ mkAppN (mkConst ctorName us) (params ++ fields).toArray
|
||||
|
||||
partial def replaceFVarId (fvarId : FVarId) (newFVarId : FVarId) : Pattern → Pattern
|
||||
| inaccessible r e => inaccessible r $ e.replaceFVarId fvarId newFVarId
|
||||
| val r e => val r $ e.replaceFVarId fvarId newFVarId
|
||||
| ctor r n us ps fs => ctor r n us (ps.map fun p => p.replaceFVarId fvarId newFVarId) (fs.map replaceFVarId)
|
||||
| arrayLit r t xs => arrayLit r (t.replaceFVarId fvarId newFVarId) (xs.map replaceFVarId)
|
||||
| p@(var r id) => if fvarId == id then var r newFVarId else p
|
||||
|
||||
/- Apply the free variable substitution `s` to the given pattern -/
|
||||
partial def applyFVarSubst (s : FVarSubst) : Pattern → Pattern
|
||||
| inaccessible r e => inaccessible r $ e.applyFVarSubst s
|
||||
| ctor r n us ps fs => ctor r n us (ps.map fun p => p.applyFVarSubst s) (fs.map applyFVarSubst)
|
||||
|
|
@ -51,6 +52,22 @@ partial def applyFVarSubst (s : FVarSubst) : Pattern → Pattern
|
|||
| arrayLit r t xs => arrayLit r (t.applyFVarSubst s) (xs.map applyFVarSubst)
|
||||
| var r fvarId => var r $ s.get fvarId
|
||||
|
||||
/- Shorthand for applying an unary variable substitution -/
|
||||
def substFVarId (fvarId : FVarId) (newFVarId : FVarId) (p : Pattern) : Pattern :=
|
||||
let s : FVarSubst := {};
|
||||
p.applyFVarSubst (s.insert fvarId newFVarId)
|
||||
|
||||
/-
|
||||
Replace occurrences of the free variable `fvarId` with `e` in the pattern `p`.
|
||||
Remark: the nested pattern `var _ fvarId` are replaced with `inaccessible _ e`.
|
||||
This function is used by `processComplete`. -/
|
||||
partial def replaceFVarId (fvarId : FVarId) (e : Expr) : Pattern → Pattern
|
||||
| inaccessible r e => inaccessible r $ e.replaceFVarId fvarId e
|
||||
| val r e => val r $ e.replaceFVarId fvarId e
|
||||
| ctor r n us ps fs => ctor r n us (ps.map fun p => p.replaceFVarId fvarId e) (fs.map replaceFVarId)
|
||||
| arrayLit r t xs => arrayLit r (t.replaceFVarId fvarId e) (xs.map replaceFVarId)
|
||||
| p@(var r id) => if fvarId == id then inaccessible r e else p
|
||||
|
||||
end Pattern
|
||||
|
||||
structure AltLHS :=
|
||||
|
|
@ -107,7 +124,6 @@ end Problem
|
|||
structure ElimResult :=
|
||||
(elim : Expr) -- The eliminator. It is not just `Expr.const elimName` because the type of the major premises may contain free variables.
|
||||
|
||||
|
||||
/- The number of patterns in each AltLHS must be equal to majors.length -/
|
||||
private def checkNumPatterns (majors : List Expr) (lhss : List AltLHS) : MetaM Unit :=
|
||||
let num := majors.length;
|
||||
|
|
@ -167,6 +183,18 @@ p.alts.all fun alt => match alt.patterns with
|
|||
| Pattern.ctor _ _ _ _ _ :: _ => true
|
||||
| _ => false
|
||||
|
||||
/- Return true if the next pattern of the remaining alternatives contain variables AND constructors. -/
|
||||
private def isCompleteTransition (p : Problem) : Bool :=
|
||||
let (ok, hasVar, hasCtor) := p.alts.foldl
|
||||
(fun (acc : Bool × Bool × Bool) (alt : Alt) =>
|
||||
let (ok, hasVar, hasCtor) := acc;
|
||||
match alt.patterns with
|
||||
| Pattern.ctor _ _ _ _ _ :: _ => (ok, hasVar, true)
|
||||
| Pattern.var _ _ :: _ => (ok, true, hasCtor)
|
||||
| _ => (false, hasVar, hasCtor))
|
||||
(true, false, false);
|
||||
ok && hasVar && hasCtor
|
||||
|
||||
private def processLeaf (p : Problem) (s : State) : MetaM State := do
|
||||
let alt := p.alts.head!;
|
||||
assignGoalOf p alt.rhs;
|
||||
|
|
@ -177,7 +205,7 @@ match p.vars with
|
|||
| x :: xs =>
|
||||
let alts := p.alts.map fun alt => match alt.patterns with
|
||||
| Pattern.inaccessible _ _ :: ps => { alt with patterns := ps }
|
||||
| Pattern.var _ fvarId :: ps => { alt with patterns := ps.map (fun p => p.replaceFVarId fvarId x.fvarId!), rhs := alt.rhs.replaceFVarId fvarId x.fvarId! }
|
||||
| Pattern.var _ fvarId :: ps => { alt with patterns := ps.map (fun p => p.substFVarId fvarId x.fvarId!), rhs := alt.rhs.replaceFVarId fvarId x }
|
||||
| _ => unreachable!;
|
||||
process { p with alts := alts, vars := xs } s
|
||||
| _ => unreachable!
|
||||
|
|
@ -209,6 +237,83 @@ match p.vars with
|
|||
s
|
||||
| _ => unreachable!
|
||||
|
||||
private def throwInductiveTypeExpected {α} (e : Expr) : MetaM α := do
|
||||
t ← inferType e;
|
||||
throwOther ("failed to compile pattern matching, inductive type expected" ++ indentExpr e ++ Format.line ++ "has type" ++ indentExpr t)
|
||||
|
||||
private partial def mkCompatibleCtorPattern (ref : Syntax) (ctorName : Name) (us : List Level) (params : Array Expr) (mvars : Array Expr) (varNamePrefix : Name)
|
||||
: Nat → Array LocalDecl → Array Pattern → MetaM (List LocalDecl × Pattern)
|
||||
| i, newDecls, fields =>
|
||||
if h : i < mvars.size then do
|
||||
let mvar := mvars.get ⟨i, h⟩;
|
||||
e ← instantiateMVars mvar;
|
||||
match e with
|
||||
| Expr.mvar _ _ => do
|
||||
type ← inferType e;
|
||||
withLocalDecl (varNamePrefix.appendIndexAfter i) type BinderInfo.default fun x => do
|
||||
decl ← getLocalDecl x.fvarId!;
|
||||
mkCompatibleCtorPattern (i+1) (newDecls.push decl) (fields.push (Pattern.var ref decl.fvarId))
|
||||
| _ => mkCompatibleCtorPattern (i+1) newDecls (fields.push (Pattern.inaccessible ref e))
|
||||
else
|
||||
pure (newDecls.toList, Pattern.ctor ref ctorName us params.toList fields.toList)
|
||||
|
||||
private partial def compatibleConstructor? (ref : Syntax) (ctorName : Name) (us : List Level) (params : Array Expr) (expectedType : Expr)
|
||||
(varNamePrefix : Name) : MetaM (Option (List LocalDecl × Pattern)) := do
|
||||
let ctor := mkAppN (mkConst ctorName us) params;
|
||||
ctorType ← inferType ctor;
|
||||
(mvars, _, resultType) ← forallMetaTelescopeReducing ctorType;
|
||||
condM (isDefEq resultType expectedType)
|
||||
(Option.some <$> mkCompatibleCtorPattern ref ctorName us params mvars varNamePrefix 0 #[] #[])
|
||||
(pure none)
|
||||
|
||||
private def getCompatibleConstructors (ref : Syntax) (e : Expr) (varNamePrefix : Name) : MetaM (List (List LocalDecl × Pattern)) := do
|
||||
env ← getEnv;
|
||||
expectedType ← inferType e;
|
||||
expectedType ← whnfD expectedType;
|
||||
let I := expectedType.getAppFn;
|
||||
let Iargs := expectedType.getAppArgs;
|
||||
matchConst env I (fun _ => throwInductiveTypeExpected e) fun info us =>
|
||||
match info with
|
||||
| ConstantInfo.inductInfo val =>
|
||||
let params := Iargs.extract 0 val.nparams;
|
||||
val.ctors.foldlM
|
||||
(fun (result : List (List LocalDecl × Pattern)) ctor => do
|
||||
entry? ← withNewMCtxDepth $ compatibleConstructor? ref ctor us params expectedType varNamePrefix;
|
||||
match entry? with
|
||||
| none => pure result
|
||||
| some entry => pure (entry :: result))
|
||||
[]
|
||||
| _ => throwInductiveTypeExpected e
|
||||
|
||||
private def replaceFVarIdAtLocalDecl (fvarId : FVarId) (e : Expr) (d : LocalDecl) : LocalDecl :=
|
||||
if d.fvarId == fvarId then d
|
||||
else match d with
|
||||
| LocalDecl.cdecl idx id n type bi => LocalDecl.cdecl idx id n (type.replaceFVarId fvarId e) bi
|
||||
| LocalDecl.ldecl idx id n type val => LocalDecl.ldecl idx id n (type.replaceFVarId fvarId e) (val.replaceFVarId fvarId e)
|
||||
|
||||
private def processComplete (process : Problem → State → MetaM State) (p : Problem) (s : State) : MetaM State :=
|
||||
withGoalOf p do
|
||||
env ← getEnv;
|
||||
newAlts ← p.alts.foldlM
|
||||
(fun (newAlts : List Alt) alt =>
|
||||
match alt.patterns with
|
||||
| Pattern.ctor _ _ _ _ _ :: ps => pure (alt :: newAlts)
|
||||
| p@(Pattern.var ref fvarId) :: ps =>
|
||||
withExistingLocalDecls alt.fvarDecls do
|
||||
ldecl ← getLocalDecl fvarId;
|
||||
dps ← getCompatibleConstructors p.ref (mkFVar fvarId) ldecl.userName;
|
||||
expandedAlts ← dps.mapM fun ⟨newLocalDecls, p⟩ => do {
|
||||
e ← p.toExpr;
|
||||
let ps := ps.map fun p => p.replaceFVarId fvarId e;
|
||||
let rhs := alt.rhs.replaceFVarId fvarId e;
|
||||
let fvarDecls := alt.fvarDecls.map (replaceFVarIdAtLocalDecl fvarId e);
|
||||
pure { alt with patterns := p :: ps, fvarDecls := fvarDecls ++ newLocalDecls, rhs := rhs }
|
||||
};
|
||||
pure (expandedAlts ++ newAlts)
|
||||
| _ => unreachable!)
|
||||
[];
|
||||
process { p with alts := newAlts.reverse } s
|
||||
|
||||
private partial def process : Problem → State → MetaM State
|
||||
| p, s => withIncRecDepth do
|
||||
withGoalOf p (traceM `Meta.debug p.toMessageData);
|
||||
|
|
@ -218,6 +323,8 @@ private partial def process : Problem → State → MetaM State
|
|||
processVariable process p s
|
||||
else if isConstructorTransition p then
|
||||
processConstructor process p s
|
||||
else if isCompleteTransition p then
|
||||
processComplete process p s
|
||||
else do
|
||||
msg ← p.toMessageData;
|
||||
-- TODO: remaining cases
|
||||
|
|
@ -392,17 +499,27 @@ def ex2 (α : Type u) (n : Nat) (xs : Vec α n) (ys : Vec α n) :
|
|||
arbitrary _
|
||||
|
||||
#eval test `ex2 3 `elimTest2
|
||||
#print elimTest2
|
||||
#check elimTest2
|
||||
|
||||
#exit
|
||||
|
||||
def ex2 (α : Type u) (β : Type v) (n : Nat) (x : List α) (y : List β) :
|
||||
def ex3 (α : Type u) (β : Type v) (n : Nat) (x : List α) (y : List β) :
|
||||
LHS (Pat ([] : List α) × Pat ([] : List β))
|
||||
× LHS (forall (a : α) (b : α), Pat [a] × Pat [b])
|
||||
× LHS (forall (a : α) (b : β), Pat [a] × Pat [b])
|
||||
× LHS (forall (a₁ a₂ : α) (as : List α) (b₁ b₂ : β) (bs : List β), Pat (a₁::a₂::as) × Pat (b₁::b₂::bs))
|
||||
× LHS (forall (as : List α) (bs : List β), Pat as × Pat bs)
|
||||
:= arbitrary _
|
||||
|
||||
#eval test `ex2 2 `elimTest2
|
||||
#eval test `ex3 2 `elimTest3
|
||||
#print elimTest3
|
||||
|
||||
#print elimTest2
|
||||
#exit
|
||||
|
||||
|
||||
def ex4 (α : Type u) (n : Nat) (xs : Vec α n) (ys : Vec α n) :
|
||||
LHS (Pat (inaccessible 0) × Pat (Vec.nil : Vec α 0) × Pat (Vec.nil : Vec α 0))
|
||||
× LHS (forall (n : Nat) (xs : Vec α n) (ys : Vec α n), Pat (inaccessible n) × Pat xs × Pat ys) :=
|
||||
arbitrary _
|
||||
|
||||
set_option trace.Meta.debug true
|
||||
|
||||
#eval test `ex4 3 `elimTest4
|
||||
#check elimTest4
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue