From 698d7e6dd1d4701d80796f4af77d56cafcccdbcd Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Mon, 3 Aug 2020 17:49:30 -0700 Subject: [PATCH] fix: bug at `processComplete` --- tmp/eqns/prototype.lean | 55 +++++++++++++++++++++++++++-------------- 1 file changed, 37 insertions(+), 18 deletions(-) diff --git a/tmp/eqns/prototype.lean b/tmp/eqns/prototype.lean index 9b07f0b1e2..1f3edf1002 100644 --- a/tmp/eqns/prototype.lean +++ b/tmp/eqns/prototype.lean @@ -205,7 +205,12 @@ match p.vars with | x :: xs => let alts := p.alts.map fun alt => match alt.patterns with | Pattern.inaccessible _ _ :: ps => { alt with patterns := ps } - | Pattern.var _ fvarId :: ps => { alt with patterns := ps.map (fun p => p.substFVarId fvarId x.fvarId!), rhs := alt.rhs.replaceFVarId fvarId x } + | Pattern.var _ fvarId :: ps => + let patterns := ps.map (fun p => p.substFVarId fvarId x.fvarId!); + let rhs := alt.rhs.replaceFVarId fvarId x; + /- We eliminate the LocalDecl for fvarId since it was substituted. -/ + let fvarDecls := alt.fvarDecls.filter fun d => d.fvarId != fvarId; + { alt with patterns := patterns, rhs := rhs, fvarDecls := fvarDecls } | _ => unreachable!; process { p with alts := alts, vars := xs } s | _ => unreachable! @@ -241,8 +246,9 @@ private def throwInductiveTypeExpected {α} (e : Expr) : MetaM α := do t ← inferType e; throwOther ("failed to compile pattern matching, inductive type expected" ++ indentExpr e ++ Format.line ++ "has type" ++ indentExpr t) -private partial def mkCompatibleCtorPattern (ref : Syntax) (ctorName : Name) (us : List Level) (params : Array Expr) (mvars : Array Expr) (varNamePrefix : Name) - : Nat → Array LocalDecl → Array Pattern → MetaM (List LocalDecl × Pattern) +/- Auxiliary method for `processComplete` -/ +private partial def mkCompatibleCtorPattern (ref : Syntax) (fvarDecls : List LocalDecl) (ctorName : Name) (us : List Level) (params : Array Expr) + (mvars : Array Expr) (varNamePrefix : Name) : Nat → Array LocalDecl → Array Pattern → MetaM (List LocalDecl × Pattern) | i, newDecls, fields => if h : i < mvars.size then do let mvar := mvars.get ⟨i, h⟩; @@ -253,20 +259,28 @@ private partial def mkCompatibleCtorPattern (ref : Syntax) (ctorName : Name) (us withLocalDecl (varNamePrefix.appendIndexAfter i) type BinderInfo.default fun x => do decl ← getLocalDecl x.fvarId!; mkCompatibleCtorPattern (i+1) (newDecls.push decl) (fields.push (Pattern.var ref decl.fvarId)) + | Expr.fvar fvarId _ => + if fvarDecls.any fun d => d.fvarId == fvarId then + mkCompatibleCtorPattern (i+1) newDecls (fields.push (Pattern.var ref fvarId)) + else + mkCompatibleCtorPattern (i+1) newDecls (fields.push (Pattern.inaccessible ref e)) | _ => mkCompatibleCtorPattern (i+1) newDecls (fields.push (Pattern.inaccessible ref e)) else pure (newDecls.toList, Pattern.ctor ref ctorName us params.toList fields.toList) -private partial def compatibleConstructor? (ref : Syntax) (ctorName : Name) (us : List Level) (params : Array Expr) (expectedType : Expr) +/- Auxiliary method for `processComplete` -/ +private partial def compatibleConstructor? (ref : Syntax) (fvarDecls : List LocalDecl) (ctorName : Name) (us : List Level) (params : Array Expr) (expectedType : Expr) (varNamePrefix : Name) : MetaM (Option (List LocalDecl × Pattern)) := do let ctor := mkAppN (mkConst ctorName us) params; ctorType ← inferType ctor; (mvars, _, resultType) ← forallMetaTelescopeReducing ctorType; +trace! `Meta.debug ("ctorName: " ++ ctorName ++ ", resultType: " ++ resultType ++ ", expectedType: " ++ expectedType); condM (isDefEq resultType expectedType) - (Option.some <$> mkCompatibleCtorPattern ref ctorName us params mvars varNamePrefix 0 #[] #[]) + (Option.some <$> mkCompatibleCtorPattern ref fvarDecls ctorName us params mvars varNamePrefix 0 #[] #[]) (pure none) -private def getCompatibleConstructors (ref : Syntax) (e : Expr) (varNamePrefix : Name) : MetaM (List (List LocalDecl × Pattern)) := do +/- Auxiliary method for `processComplete` -/ +private def getCompatibleConstructors (ref : Syntax) (fvarDecls : List LocalDecl) (e : Expr) (varNamePrefix : Name) : MetaM (List (List LocalDecl × Pattern)) := do env ← getEnv; expectedType ← inferType e; expectedType ← whnfD expectedType; @@ -278,7 +292,7 @@ match info with let params := Iargs.extract 0 val.nparams; val.ctors.foldlM (fun (result : List (List LocalDecl × Pattern)) ctor => do - entry? ← withNewMCtxDepth $ compatibleConstructor? ref ctor us params expectedType varNamePrefix; + entry? ← withNewMCtxDepth $ compatibleConstructor? ref fvarDecls ctor us params expectedType varNamePrefix; match entry? with | none => pure result | some entry => pure (entry :: result)) @@ -291,6 +305,7 @@ else match d with | LocalDecl.cdecl idx id n type bi => LocalDecl.cdecl idx id n (type.replaceFVarId fvarId e) bi | LocalDecl.ldecl idx id n type val => LocalDecl.ldecl idx id n (type.replaceFVarId fvarId e) (val.replaceFVarId fvarId e) +/- Auxiliary method for `processComplete` -/ private def processComplete (process : Problem → State → MetaM State) (p : Problem) (s : State) : MetaM State := withGoalOf p do env ← getEnv; @@ -301,7 +316,7 @@ newAlts ← p.alts.foldlM | p@(Pattern.var ref fvarId) :: ps => withExistingLocalDecls alt.fvarDecls do ldecl ← getLocalDecl fvarId; - dps ← getCompatibleConstructors p.ref (mkFVar fvarId) ldecl.userName; + dps ← getCompatibleConstructors p.ref alt.fvarDecls (mkFVar fvarId) ldecl.userName; expandedAlts ← dps.mapM fun ⟨newLocalDecls, p⟩ => do { e ← p.toExpr; let ps := ps.map fun p => p.replaceFVarId fvarId e; @@ -499,7 +514,7 @@ def ex2 (α : Type u) (n : Nat) (xs : Vec α n) (ys : Vec α n) : arbitrary _ #eval test `ex2 3 `elimTest2 -#check elimTest2 +#print elimTest2 def ex3 (α : Type u) (β : Type v) (n : Nat) (x : List α) (y : List β) : LHS (Pat ([] : List α) × Pat ([] : List β)) @@ -511,15 +526,19 @@ def ex3 (α : Type u) (β : Type v) (n : Nat) (x : List α) (y : List β) : #eval test `ex3 2 `elimTest3 #print elimTest3 -#exit - - -def ex4 (α : Type u) (n : Nat) (xs : Vec α n) (ys : Vec α n) : - LHS (Pat (inaccessible 0) × Pat (Vec.nil : Vec α 0) × Pat (Vec.nil : Vec α 0)) -× LHS (forall (n : Nat) (xs : Vec α n) (ys : Vec α n), Pat (inaccessible n) × Pat xs × Pat ys) := +def ex4 (α : Type u) (n : Nat) (xs : Vec α n) : + LHS (Pat (inaccessible 0) × Pat (Vec.nil : Vec α 0)) +× LHS (forall (n : Nat) (xs : Vec α (n+1)), Pat (inaccessible (n+1)) × Pat xs) := arbitrary _ -set_option trace.Meta.debug true - -#eval test `ex4 3 `elimTest4 +#eval test `ex4 2 `elimTest4 #check elimTest4 +#print elimTest4 + +def ex5 (α : Type u) (n : Nat) (xs : Vec α n) : + LHS (Pat Nat.zero × Pat (Vec.nil : Vec α 0)) +× LHS (forall (n : Nat) (xs : Vec α (n+1)), Pat (Nat.succ n) × Pat xs) := +arbitrary _ + +#eval test `ex5 2 `elimTest5 +#print elimTest5