diff --git a/src/Lean/Expr.lean b/src/Lean/Expr.lean index 2744d3df7d..06c9bc3889 100644 --- a/src/Lean/Expr.lean +++ b/src/Lean/Expr.lean @@ -635,11 +635,13 @@ constant abstract (e : @& Expr) (xs : @& Array Expr) : Expr := arbitrary _ @[extern "lean_expr_abstract_range"] constant abstractRange (e : @& Expr) (n : @& Nat) (xs : @& Array Expr) : Expr := arbitrary _ +/-- Replace occurrences of the free variable `fvar` in `e` with `v` -/ def replaceFVar (e : Expr) (fvar : Expr) (v : Expr) : Expr := (e.abstract #[fvar]).instantiate1 v -def replaceFVarId (e : Expr) (fvarId : FVarId) (newFVarId : FVarId) : Expr := -replaceFVar e (mkFVar fvarId) (mkFVar newFVarId) +/-- Replace occurrences of the free variable `fvarId` in `e` with `v` -/ +def replaceFVarId (e : Expr) (fvarId : FVarId) (v : Expr) : Expr := +replaceFVar e (mkFVar fvarId) v instance : HasToString Expr := ⟨Expr.dbgToString⟩ diff --git a/src/Lean/LocalContext.lean b/src/Lean/LocalContext.lean index c634bcb296..752cc036b4 100644 --- a/src/Lean/LocalContext.lean +++ b/src/Lean/LocalContext.lean @@ -35,6 +35,10 @@ def index : LocalDecl → Nat | cdecl idx _ _ _ _ => idx | ldecl idx _ _ _ _ => idx +def setIndex : LocalDecl → Nat → LocalDecl +| cdecl _ id n t bi, idx => cdecl idx id n t bi +| ldecl _ id n t v, idx => ldecl idx id n t v + def fvarId : LocalDecl → FVarId | cdecl _ id _ _ _ => id | ldecl _ id _ _ _ => id @@ -112,7 +116,10 @@ match lctx with /- Low level API -/ def addDecl (lctx : LocalContext) (newDecl : LocalDecl) : LocalContext := match lctx with -| { fvarIdToDecl := map, decls := decls } => { fvarIdToDecl := map.insert newDecl.fvarId newDecl, decls := decls.set newDecl.index newDecl } +| { fvarIdToDecl := map, decls := decls } => + let idx := decls.size; + let newDecl := newDecl.setIndex idx; + { fvarIdToDecl := map.insert newDecl.fvarId newDecl, decls := decls.push newDecl } @[export lean_local_ctx_find] def find? (lctx : LocalContext) (fvarId : FVarId) : Option LocalDecl := diff --git a/tmp/eqns/prototype.lean b/tmp/eqns/prototype.lean index bf4f411d7a..9b07f0b1e2 100644 --- a/tmp/eqns/prototype.lean +++ b/tmp/eqns/prototype.lean @@ -18,6 +18,13 @@ namespace Pattern instance : Inhabited Pattern := ⟨Pattern.inaccessible Syntax.missing (arbitrary _)⟩ +def ref : Pattern → Syntax +| inaccessible r _ => r +| var r _ => r +| ctor r _ _ _ _ => r +| val r _ => r +| arrayLit r _ _ => r + partial def toMessageData : Pattern → MessageData | inaccessible _ e => ".(" ++ e ++ ")" | var _ fvarId => mkFVar fvarId @@ -37,13 +44,7 @@ partial def toExpr : Pattern → MetaM Expr fields ← fields.mapM toExpr; pure $ mkAppN (mkConst ctorName us) (params ++ fields).toArray -partial def replaceFVarId (fvarId : FVarId) (newFVarId : FVarId) : Pattern → Pattern -| inaccessible r e => inaccessible r $ e.replaceFVarId fvarId newFVarId -| val r e => val r $ e.replaceFVarId fvarId newFVarId -| ctor r n us ps fs => ctor r n us (ps.map fun p => p.replaceFVarId fvarId newFVarId) (fs.map replaceFVarId) -| arrayLit r t xs => arrayLit r (t.replaceFVarId fvarId newFVarId) (xs.map replaceFVarId) -| p@(var r id) => if fvarId == id then var r newFVarId else p - +/- Apply the free variable substitution `s` to the given pattern -/ partial def applyFVarSubst (s : FVarSubst) : Pattern → Pattern | inaccessible r e => inaccessible r $ e.applyFVarSubst s | ctor r n us ps fs => ctor r n us (ps.map fun p => p.applyFVarSubst s) (fs.map applyFVarSubst) @@ -51,6 +52,22 @@ partial def applyFVarSubst (s : FVarSubst) : Pattern → Pattern | arrayLit r t xs => arrayLit r (t.applyFVarSubst s) (xs.map applyFVarSubst) | var r fvarId => var r $ s.get fvarId +/- Shorthand for applying an unary variable substitution -/ +def substFVarId (fvarId : FVarId) (newFVarId : FVarId) (p : Pattern) : Pattern := +let s : FVarSubst := {}; +p.applyFVarSubst (s.insert fvarId newFVarId) + +/- + Replace occurrences of the free variable `fvarId` with `e` in the pattern `p`. + Remark: the nested pattern `var _ fvarId` are replaced with `inaccessible _ e`. + This function is used by `processComplete`. -/ +partial def replaceFVarId (fvarId : FVarId) (e : Expr) : Pattern → Pattern +| inaccessible r e => inaccessible r $ e.replaceFVarId fvarId e +| val r e => val r $ e.replaceFVarId fvarId e +| ctor r n us ps fs => ctor r n us (ps.map fun p => p.replaceFVarId fvarId e) (fs.map replaceFVarId) +| arrayLit r t xs => arrayLit r (t.replaceFVarId fvarId e) (xs.map replaceFVarId) +| p@(var r id) => if fvarId == id then inaccessible r e else p + end Pattern structure AltLHS := @@ -107,7 +124,6 @@ end Problem structure ElimResult := (elim : Expr) -- The eliminator. It is not just `Expr.const elimName` because the type of the major premises may contain free variables. - /- The number of patterns in each AltLHS must be equal to majors.length -/ private def checkNumPatterns (majors : List Expr) (lhss : List AltLHS) : MetaM Unit := let num := majors.length; @@ -167,6 +183,18 @@ p.alts.all fun alt => match alt.patterns with | Pattern.ctor _ _ _ _ _ :: _ => true | _ => false +/- Return true if the next pattern of the remaining alternatives contain variables AND constructors. -/ +private def isCompleteTransition (p : Problem) : Bool := +let (ok, hasVar, hasCtor) := p.alts.foldl + (fun (acc : Bool × Bool × Bool) (alt : Alt) => + let (ok, hasVar, hasCtor) := acc; + match alt.patterns with + | Pattern.ctor _ _ _ _ _ :: _ => (ok, hasVar, true) + | Pattern.var _ _ :: _ => (ok, true, hasCtor) + | _ => (false, hasVar, hasCtor)) + (true, false, false); +ok && hasVar && hasCtor + private def processLeaf (p : Problem) (s : State) : MetaM State := do let alt := p.alts.head!; assignGoalOf p alt.rhs; @@ -177,7 +205,7 @@ match p.vars with | x :: xs => let alts := p.alts.map fun alt => match alt.patterns with | Pattern.inaccessible _ _ :: ps => { alt with patterns := ps } - | Pattern.var _ fvarId :: ps => { alt with patterns := ps.map (fun p => p.replaceFVarId fvarId x.fvarId!), rhs := alt.rhs.replaceFVarId fvarId x.fvarId! } + | Pattern.var _ fvarId :: ps => { alt with patterns := ps.map (fun p => p.substFVarId fvarId x.fvarId!), rhs := alt.rhs.replaceFVarId fvarId x } | _ => unreachable!; process { p with alts := alts, vars := xs } s | _ => unreachable! @@ -209,6 +237,83 @@ match p.vars with s | _ => unreachable! +private def throwInductiveTypeExpected {α} (e : Expr) : MetaM α := do +t ← inferType e; +throwOther ("failed to compile pattern matching, inductive type expected" ++ indentExpr e ++ Format.line ++ "has type" ++ indentExpr t) + +private partial def mkCompatibleCtorPattern (ref : Syntax) (ctorName : Name) (us : List Level) (params : Array Expr) (mvars : Array Expr) (varNamePrefix : Name) + : Nat → Array LocalDecl → Array Pattern → MetaM (List LocalDecl × Pattern) +| i, newDecls, fields => + if h : i < mvars.size then do + let mvar := mvars.get ⟨i, h⟩; + e ← instantiateMVars mvar; + match e with + | Expr.mvar _ _ => do + type ← inferType e; + withLocalDecl (varNamePrefix.appendIndexAfter i) type BinderInfo.default fun x => do + decl ← getLocalDecl x.fvarId!; + mkCompatibleCtorPattern (i+1) (newDecls.push decl) (fields.push (Pattern.var ref decl.fvarId)) + | _ => mkCompatibleCtorPattern (i+1) newDecls (fields.push (Pattern.inaccessible ref e)) + else + pure (newDecls.toList, Pattern.ctor ref ctorName us params.toList fields.toList) + +private partial def compatibleConstructor? (ref : Syntax) (ctorName : Name) (us : List Level) (params : Array Expr) (expectedType : Expr) + (varNamePrefix : Name) : MetaM (Option (List LocalDecl × Pattern)) := do +let ctor := mkAppN (mkConst ctorName us) params; +ctorType ← inferType ctor; +(mvars, _, resultType) ← forallMetaTelescopeReducing ctorType; +condM (isDefEq resultType expectedType) + (Option.some <$> mkCompatibleCtorPattern ref ctorName us params mvars varNamePrefix 0 #[] #[]) + (pure none) + +private def getCompatibleConstructors (ref : Syntax) (e : Expr) (varNamePrefix : Name) : MetaM (List (List LocalDecl × Pattern)) := do +env ← getEnv; +expectedType ← inferType e; +expectedType ← whnfD expectedType; +let I := expectedType.getAppFn; +let Iargs := expectedType.getAppArgs; +matchConst env I (fun _ => throwInductiveTypeExpected e) fun info us => +match info with +| ConstantInfo.inductInfo val => + let params := Iargs.extract 0 val.nparams; + val.ctors.foldlM + (fun (result : List (List LocalDecl × Pattern)) ctor => do + entry? ← withNewMCtxDepth $ compatibleConstructor? ref ctor us params expectedType varNamePrefix; + match entry? with + | none => pure result + | some entry => pure (entry :: result)) + [] +| _ => throwInductiveTypeExpected e + +private def replaceFVarIdAtLocalDecl (fvarId : FVarId) (e : Expr) (d : LocalDecl) : LocalDecl := +if d.fvarId == fvarId then d +else match d with + | LocalDecl.cdecl idx id n type bi => LocalDecl.cdecl idx id n (type.replaceFVarId fvarId e) bi + | LocalDecl.ldecl idx id n type val => LocalDecl.ldecl idx id n (type.replaceFVarId fvarId e) (val.replaceFVarId fvarId e) + +private def processComplete (process : Problem → State → MetaM State) (p : Problem) (s : State) : MetaM State := +withGoalOf p do +env ← getEnv; +newAlts ← p.alts.foldlM + (fun (newAlts : List Alt) alt => + match alt.patterns with + | Pattern.ctor _ _ _ _ _ :: ps => pure (alt :: newAlts) + | p@(Pattern.var ref fvarId) :: ps => + withExistingLocalDecls alt.fvarDecls do + ldecl ← getLocalDecl fvarId; + dps ← getCompatibleConstructors p.ref (mkFVar fvarId) ldecl.userName; + expandedAlts ← dps.mapM fun ⟨newLocalDecls, p⟩ => do { + e ← p.toExpr; + let ps := ps.map fun p => p.replaceFVarId fvarId e; + let rhs := alt.rhs.replaceFVarId fvarId e; + let fvarDecls := alt.fvarDecls.map (replaceFVarIdAtLocalDecl fvarId e); + pure { alt with patterns := p :: ps, fvarDecls := fvarDecls ++ newLocalDecls, rhs := rhs } + }; + pure (expandedAlts ++ newAlts) + | _ => unreachable!) + []; +process { p with alts := newAlts.reverse } s + private partial def process : Problem → State → MetaM State | p, s => withIncRecDepth do withGoalOf p (traceM `Meta.debug p.toMessageData); @@ -218,6 +323,8 @@ private partial def process : Problem → State → MetaM State processVariable process p s else if isConstructorTransition p then processConstructor process p s + else if isCompleteTransition p then + processComplete process p s else do msg ← p.toMessageData; -- TODO: remaining cases @@ -392,17 +499,27 @@ def ex2 (α : Type u) (n : Nat) (xs : Vec α n) (ys : Vec α n) : arbitrary _ #eval test `ex2 3 `elimTest2 -#print elimTest2 +#check elimTest2 -#exit - -def ex2 (α : Type u) (β : Type v) (n : Nat) (x : List α) (y : List β) : +def ex3 (α : Type u) (β : Type v) (n : Nat) (x : List α) (y : List β) : LHS (Pat ([] : List α) × Pat ([] : List β)) -× LHS (forall (a : α) (b : α), Pat [a] × Pat [b]) +× LHS (forall (a : α) (b : β), Pat [a] × Pat [b]) × LHS (forall (a₁ a₂ : α) (as : List α) (b₁ b₂ : β) (bs : List β), Pat (a₁::a₂::as) × Pat (b₁::b₂::bs)) × LHS (forall (as : List α) (bs : List β), Pat as × Pat bs) := arbitrary _ -#eval test `ex2 2 `elimTest2 +#eval test `ex3 2 `elimTest3 +#print elimTest3 -#print elimTest2 +#exit + + +def ex4 (α : Type u) (n : Nat) (xs : Vec α n) (ys : Vec α n) : + LHS (Pat (inaccessible 0) × Pat (Vec.nil : Vec α 0) × Pat (Vec.nil : Vec α 0)) +× LHS (forall (n : Nat) (xs : Vec α n) (ys : Vec α n), Pat (inaccessible n) × Pat xs × Pat ys) := +arbitrary _ + +set_option trace.Meta.debug true + +#eval test `ex4 3 `elimTest4 +#check elimTest4