diff --git a/src/Lean/Meta/EqnCompiler/DepElim.lean b/src/Lean/Meta/EqnCompiler/DepElim.lean index a163c5c58b..dc99af330c 100644 --- a/src/Lean/Meta/EqnCompiler/DepElim.lean +++ b/src/Lean/Meta/EqnCompiler/DepElim.lean @@ -18,99 +18,91 @@ namespace DepElim abbrev VarId := Name inductive Pattern (internal : Bool := false) : Type -| inaccessible (ref : Syntax) (e : Expr) : Pattern -| var (ref : Syntax) (varId : VarId) : Pattern -| ctor (ref : Syntax) (ctorName : Name) (us : List Level) (params : List Expr) (fields : List Pattern) : Pattern -| val (ref : Syntax) (e : Expr) : Pattern -| arrayLit (ref : Syntax) (type : Expr) (xs : List Pattern) : Pattern -| as (ref : Syntax) (varId : VarId) (p : Pattern) : Pattern +| inaccessible (e : Expr) : Pattern +| var (varId : VarId) : Pattern +| ctor (ctorName : Name) (us : List Level) (params : List Expr) (fields : List Pattern) : Pattern +| val (e : Expr) : Pattern +| arrayLit (type : Expr) (xs : List Pattern) : Pattern +| as (varId : VarId) (p : Pattern) : Pattern abbrev IPattern := Pattern true namespace Pattern -instance {b} : Inhabited (Pattern b) := ⟨Pattern.inaccessible Syntax.missing (arbitrary _)⟩ - -def ref {b : Bool} : Pattern b → Syntax -| inaccessible r _ => r -| var r _ => r -| ctor r _ _ _ _ => r -| val r _ => r -| arrayLit r _ _ => r -| as r _ _ => r +instance {b} : Inhabited (Pattern b) := ⟨Pattern.inaccessible (arbitrary _)⟩ partial def toMessageData {b : Bool} : Pattern b → MessageData -| inaccessible _ e => ".(" ++ e ++ ")" -| var _ varId => if b then mkMVar varId else mkFVar varId -| ctor _ ctorName _ _ [] => ctorName -| ctor _ ctorName _ _ pats => "(" ++ ctorName ++ pats.foldl (fun (msg : MessageData) pat => msg ++ " " ++ toMessageData pat) Format.nil ++ ")" -| val _ e => "val!(" ++ e ++ ")" -| arrayLit _ _ pats => "#[" ++ MessageData.joinSep (pats.map toMessageData) ", " ++ "]" -| as _ varId p => (if b then mkMVar varId else mkFVar varId) ++ "@" ++toMessageData p +| inaccessible e => ".(" ++ e ++ ")" +| var varId => if b then mkMVar varId else mkFVar varId +| ctor ctorName _ _ [] => ctorName +| ctor ctorName _ _ pats => "(" ++ ctorName ++ pats.foldl (fun (msg : MessageData) pat => msg ++ " " ++ toMessageData pat) Format.nil ++ ")" +| val e => "val!(" ++ e ++ ")" +| arrayLit _ pats => "#[" ++ MessageData.joinSep (pats.map toMessageData) ", " ++ "]" +| as varId p => (if b then mkMVar varId else mkFVar varId) ++ "@" ++toMessageData p partial def toExpr {b} : Pattern b → MetaM Expr -| inaccessible _ e => pure e -| var _ varId => if b then pure (mkMVar varId) else pure (mkFVar varId) -| val _ e => pure e -| as _ _ p => toExpr p -| arrayLit _ type xs => do +| inaccessible e => pure e +| var varId => if b then pure (mkMVar varId) else pure (mkFVar varId) +| val e => pure e +| as _ p => toExpr p +| arrayLit type xs => do xs ← xs.mapM toExpr; mkArrayLit type xs -| ctor _ ctorName us params fields => do +| ctor ctorName us params fields => do fields ← fields.mapM toExpr; pure $ mkAppN (mkConst ctorName us) (params ++ fields).toArray /- Apply the free variable substitution `s` to the given (internal) pattern -/ partial def applyFVarSubst (s : FVarSubst) : Pattern true → IPattern -| inaccessible r e => inaccessible r $ s.apply e -| ctor r n us ps fs => ctor r n us (ps.map s.apply) $ fs.map applyFVarSubst -| val r e => val r $ s.apply e -| arrayLit r t xs => arrayLit r (s.apply t) $ xs.map applyFVarSubst -| var r id => var r id -| as r v p => as r v $ applyFVarSubst p +| inaccessible e => inaccessible $ s.apply e +| ctor n us ps fs => ctor n us (ps.map s.apply) $ fs.map applyFVarSubst +| val e => val $ s.apply e +| arrayLit t xs => arrayLit (s.apply t) $ xs.map applyFVarSubst +| var id => var id +| as v p => as v $ applyFVarSubst p partial def instantiateMVars : IPattern → MetaM IPattern -| inaccessible r e => inaccessible r <$> Meta.instantiateMVars e -| ctor r n us ps fs => ctor r n us <$> ps.mapM Meta.instantiateMVars <*> fs.mapM instantiateMVars -| val r e => val r <$> Meta.instantiateMVars e -| arrayLit r t xs => arrayLit r <$> Meta.instantiateMVars t <*> xs.mapM instantiateMVars -| var ref mvarId => do +| inaccessible e => inaccessible <$> Meta.instantiateMVars e +| ctor n us ps fs => ctor n us <$> ps.mapM Meta.instantiateMVars <*> fs.mapM instantiateMVars +| val e => val <$> Meta.instantiateMVars e +| arrayLit t xs => arrayLit <$> Meta.instantiateMVars t <*> xs.mapM instantiateMVars +| var mvarId => do mctx ← getMCtx; match mctx.getExprAssignment? mvarId with - | some v => inaccessible ref <$> Meta.instantiateMVars v - | none => pure (var ref mvarId) -| as ref mvarId p => do + | some v => inaccessible <$> Meta.instantiateMVars v + | none => pure (var mvarId) +| as mvarId p => do mctx ← getMCtx; match mctx.getExprAssignment? mvarId with | some v => instantiateMVars p - | none => as ref mvarId <$> instantiateMVars p + | none => as mvarId <$> instantiateMVars p partial def applyMVarRenaming (m : MVarRenaming) : Pattern true → IPattern -| inaccessible r e => inaccessible r $ m.apply e -| ctor r n us ps fs => ctor r n us (ps.map m.apply) $ fs.map applyMVarRenaming -| val r e => val r $ m.apply e -| arrayLit r t xs => arrayLit r (m.apply t) $ xs.map applyMVarRenaming -| var ref mvarId => +| inaccessible e => inaccessible $ m.apply e +| ctor n us ps fs => ctor n us (ps.map m.apply) $ fs.map applyMVarRenaming +| val e => val $ m.apply e +| arrayLit t xs => arrayLit (m.apply t) $ xs.map applyMVarRenaming +| var mvarId => match m.find? mvarId with - | some newMVarId => var ref newMVarId - | none => var ref mvarId -| as ref mvarId p => + | some newMVarId => var newMVarId + | none => var mvarId +| as mvarId p => match m.find? mvarId with - | some newMVarId => as ref newMVarId $ applyMVarRenaming p - | none => as ref mvarId $ applyMVarRenaming p + | some newMVarId => as newMVarId $ applyMVarRenaming p + | none => as mvarId $ applyMVarRenaming p partial def toIPattern (s : FVarSubst) : Pattern → IPattern -| inaccessible r e => inaccessible r $ s.apply e -| ctor r n us ps fs => ctor r n us (ps.map s.apply) $ fs.map toIPattern -| val r e => val r $ s.apply e -| arrayLit r t xs => arrayLit r (s.apply t) $ xs.map toIPattern -| var ref fvarId => +| inaccessible e => inaccessible $ s.apply e +| ctor n us ps fs => ctor n us (ps.map s.apply) $ fs.map toIPattern +| val e => val $ s.apply e +| arrayLit t xs => arrayLit (s.apply t) $ xs.map toIPattern +| var fvarId => match s.get fvarId with - | Expr.mvar mvarId _ => Pattern.var ref mvarId + | Expr.mvar mvarId _ => Pattern.var mvarId | _ => unreachable! -| as ref fvarId p => +| as fvarId p => match s.get fvarId with - | Expr.mvar mvarId _ => Pattern.as ref mvarId $ toIPattern p + | Expr.mvar mvarId _ => Pattern.as mvarId $ toIPattern p | _ => unreachable! end Pattern @@ -340,21 +332,21 @@ match p.vars with private def hasAsPattern (p : Problem) : Bool := p.alts.any fun alt => match alt.patterns with - | Pattern.as _ _ _ :: _ => true - | _ => false + | Pattern.as _ _ :: _ => true + | _ => false /- Return true if the next pattern of each remaining alternative is an inaccessible term or a variable -/ private def isVariableTransition (p : Problem) : Bool := p.alts.all fun alt => match alt.patterns with - | Pattern.inaccessible _ _ :: _ => true - | Pattern.var _ _ :: _ => true - | _ => false + | Pattern.inaccessible _ :: _ => true + | Pattern.var _ :: _ => true + | _ => false /- Return true if the next pattern of each remaining alternative is a constructor application -/ private def isConstructorTransition (p : Problem) : Bool := p.alts.all fun alt => match alt.patterns with - | Pattern.ctor _ _ _ _ _ :: _ => true - | _ => false + | Pattern.ctor _ _ _ _ :: _ => true + | _ => false /- Return true if the next pattern of the remaining alternatives contain variables AND constructors. -/ private def isCompleteTransition (p : Problem) : Bool := @@ -362,9 +354,9 @@ let (ok, hasVar, hasCtor) := p.alts.foldl (fun (acc : Bool × Bool × Bool) (alt : Alt) => let (ok, hasVar, hasCtor) := acc; match alt.patterns with - | Pattern.ctor _ _ _ _ _ :: _ => (ok, hasVar, true) - | Pattern.var _ _ :: _ => (ok, true, hasCtor) - | _ => (false, hasVar, hasCtor)) + | Pattern.ctor _ _ _ _ :: _ => (ok, hasVar, true) + | Pattern.var _ :: _ => (ok, true, hasCtor) + | _ => (false, hasVar, hasCtor)) (true, false, false); ok && hasVar && hasCtor @@ -374,9 +366,9 @@ let (ok, hasVar, hasVal) := p.alts.foldl (fun (acc : Bool × Bool × Bool) (alt : Alt) => let (ok, hasVar, hasVal) := acc; match alt.patterns with - | Pattern.val _ _ :: _ => (ok, hasVar, true) - | Pattern.var _ _ :: _ => (ok, true, hasVal) - | _ => (false, hasVar, hasVal)) + | Pattern.val _ :: _ => (ok, hasVar, true) + | Pattern.var _ :: _ => (ok, true, hasVal) + | _ => (false, hasVar, hasVal)) (true, false, false); ok && hasVar && hasVal @@ -386,9 +378,9 @@ let (ok, hasVar, hasArray) := p.alts.foldl (fun (acc : Bool × Bool × Bool) (alt : Alt) => let (ok, hasVar, hasArray) := acc; match alt.patterns with - | Pattern.arrayLit _ _ _ :: _ => (ok, hasVar, true) - | Pattern.var _ _ :: _ => (ok, true, hasArray) - | _ => (false, hasVar, hasArray)) + | Pattern.arrayLit _ _ :: _ => (ok, hasVar, true) + | Pattern.var _ :: _ => (ok, true, hasArray) + | _ => (false, hasVar, hasArray)) (true, false, false); ok && hasVar && hasArray @@ -418,7 +410,7 @@ match p.vars with | [] => unreachable! | x :: xs => do alts ← p.alts.mapM fun alt => match alt.patterns with - | Pattern.as _ mvarId p :: ps => do + | Pattern.as mvarId p :: ps => do assignExprMVar mvarId x; rhs ← instantiateMVars alt.rhs; let mvars := alt.mvars.erase mvarId; @@ -434,8 +426,8 @@ match p.vars with | [] => unreachable! | x :: xs => do alts ← p.alts.mapM fun alt => match alt.patterns with - | Pattern.inaccessible _ _ :: ps => pure { alt with patterns := ps } - | Pattern.var _ mvarId :: ps => do + | Pattern.inaccessible _ :: ps => pure { alt with patterns := ps } + | Pattern.var mvarId :: ps => do -- trace! `Meta.EqnCompiler.matchDebug (">> assign " ++ mkMVar mvarId ++ " := " ++ x); assignExprMVar mvarId x; rhs ← instantiateMVars alt.rhs; @@ -449,8 +441,8 @@ match p.vars with private def isFirstPatternCtor (ctorName : Name) (alt : Alt) : Bool := match alt.patterns with -| Pattern.ctor _ n _ _ _ :: _ => n == ctorName -| _ => false +| Pattern.ctor n _ _ _ :: _ => n == ctorName +| _ => false private def processConstructor (process : Problem → State → MetaM State) (p : Problem) (s : State) : MetaM State := do trace! `Meta.EqnCompiler.match ("constructor step"); @@ -471,8 +463,8 @@ match p.vars with let examples := examples.map $ Example.applyFVarSubst subst; let newAlts := p.alts.filter $ isFirstPatternCtor subgoal.ctorName; let newAlts := newAlts.map fun alt => match alt.patterns with - | Pattern.ctor _ _ _ _ fields :: ps => { alt with patterns := fields ++ ps } - | _ => unreachable!; + | Pattern.ctor _ _ _ fields :: ps => { alt with patterns := fields ++ ps } + | _ => unreachable!; newAlts ← newAlts.mapM fun alt => alt.applyFVarSubst subst; newAlts ← newAlts.mapM fun alt => alt.copy; process { mvarId := subgoal.mvarId, vars := newVars, alts := newAlts, examples := examples } s) @@ -481,17 +473,17 @@ match p.vars with private def throwInductiveTypeExpected {α} (type : Expr) : MetaM α := do throwOther ("failed to compile pattern matching, inductive type expected" ++ indentExpr type) -private partial def tryConstructorAux (alt : Alt) (ref : Syntax) (mvarId : MVarId) (ctorName : Name) (us : List Level) (params : Array Expr) (mvars : Array Expr) +private partial def tryConstructorAux (alt : Alt) (mvarId : MVarId) (ctorName : Name) (us : List Level) (params : Array Expr) (mvars : Array Expr) : Nat → Array MVarId → Array IPattern → MetaM Alt | i, newMVars, fields => do if h : i < mvars.size then do let mvar := mvars.get ⟨i, h⟩; e ← instantiateMVars mvar; match e with - | Expr.mvar mvarId _ => tryConstructorAux (i+1) (newMVars.push mvarId) (fields.push (Pattern.var ref mvarId)) - | _ => tryConstructorAux (i+1) newMVars (fields.push (Pattern.inaccessible ref e)) + | Expr.mvar mvarId _ => tryConstructorAux (i+1) (newMVars.push mvarId) (fields.push (Pattern.var mvarId)) + | _ => tryConstructorAux (i+1) newMVars (fields.push (Pattern.inaccessible e)) else do - let p := Pattern.ctor ref ctorName us params.toList fields.toList; + let p := Pattern.ctor ctorName us params.toList fields.toList; e ← p.toExpr; assignExprMVar mvarId e; ps ← alt.patterns.mapM Pattern.instantiateMVars; @@ -503,17 +495,17 @@ private partial def tryConstructorAux (alt : Alt) (ref : Syntax) (mvarId : MVarI mvars ← mvars.filterM fun mvarId => not <$> isExprMVarAssigned mvarId; pure { alt with rhs := rhs, mvars := mvars, patterns := ps } -private def tryConstructor? (alt : Alt) (ref : Syntax) (mvarId : MVarId) (ctorName : Name) (us : List Level) (params : Array Expr) (expectedType : Expr) +private def tryConstructor? (alt : Alt) (mvarId : MVarId) (ctorName : Name) (us : List Level) (params : Array Expr) (expectedType : Expr) : MetaM (Option Alt) := do let ctor := mkAppN (mkConst ctorName us) params; ctorType ← inferType ctor; (mvars, _, resultType) ← forallMetaTelescopeReducing ctorType; trace! `Meta.EqnCompiler.matchDebug ("ctorName: " ++ ctorName ++ ", resultType: " ++ resultType ++ ", expectedType: " ++ expectedType); condM (isDefEq resultType expectedType) - (Option.some <$> tryConstructorAux alt ref mvarId ctorName us params mvars 0 #[] #[]) + (Option.some <$> tryConstructorAux alt mvarId ctorName us params mvars 0 #[] #[]) (pure none) -private def expandAlt (alt : Alt) (ref : Syntax) (mvarId : MVarId) : MetaM (List Alt) := do +private def expandAlt (alt : Alt) (mvarId : MVarId) : MetaM (List Alt) := do env ← getEnv; mvarDecl ← getMVarDecl mvarId; let expectedType := mvarDecl.type; @@ -531,7 +523,7 @@ matchConst env expectedType.getAppFn (fun _ => throwInductiveTypeExpected expect let I := expectedType.getAppFn; let Iargs := expectedType.getAppArgs; let params := Iargs.extract 0 val.nparams; - alt? ← tryConstructor? alt ref mvarId ctor us params expectedType; + alt? ← tryConstructor? alt mvarId ctor us params expectedType; match alt? with | none => pure result | some alt => pure (alt :: result)) @@ -545,10 +537,10 @@ env ← getEnv; newAlts ← p.alts.foldlM (fun (newAlts : List Alt) alt => match alt.patterns with - | Pattern.ctor _ _ _ _ _ :: ps => pure (alt :: newAlts) - | p@(Pattern.var ref mvarId) :: ps => do + | Pattern.ctor _ _ _ _ :: ps => pure (alt :: newAlts) + | p@(Pattern.var mvarId) :: ps => do let alt := { alt with patterns := ps }; - alts ← expandAlt alt ref mvarId; + alts ← expandAlt alt mvarId; pure (alts ++ newAlts) | _ => unreachable!) []; @@ -558,14 +550,14 @@ private def collectValues (p : Problem) : Array Expr := p.alts.foldl (fun (values : Array Expr) alt => match alt.patterns with - | Pattern.val _ v :: _ => if values.contains v then values else values.push v - | _ => values) + | Pattern.val v :: _ => if values.contains v then values else values.push v + | _ => values) #[] private def isFirstPatternVar (alt : Alt) : Bool := match alt.patterns with -| Pattern.var _ _ :: _ => true -| _ => false +| Pattern.var _ :: _ => true +| _ => false private def processValue (process : Problem → State → MetaM State) (p : Problem) (s : State) : MetaM State := do trace! `Meta.EqnCompiler.match ("value step"); @@ -584,14 +576,14 @@ match p.vars with let examples := p.examples.map $ Example.replaceFVarId x.fvarId! (Example.val value); let examples := examples.map $ Example.applyFVarSubst subst; let newAlts := p.alts.filter fun alt => match alt.patterns with - | Pattern.val _ v :: _ => v == value - | Pattern.var _ _ :: _ => true + | Pattern.val v :: _ => v == value + | Pattern.var _ :: _ => true | _ => false; newAlts ← newAlts.mapM fun alt => alt.applyFVarSubst subst; newAlts ← newAlts.mapM fun alt => alt.copy; newAlts ← newAlts.mapM fun alt => match alt.patterns with - | Pattern.val _ _ :: ps => pure { alt with patterns := ps } - | Pattern.var _ mvarId :: ps => do + | Pattern.val _ :: ps => pure { alt with patterns := ps } + | Pattern.var mvarId :: ps => do assignExprMVar mvarId value; ps ← ps.mapM Pattern.instantiateMVars; rhs ← instantiateMVars alt.rhs; @@ -611,8 +603,8 @@ private def collectArraySizes (p : Problem) : Array Nat := p.alts.foldl (fun (sizes : Array Nat) alt => match alt.patterns with - | Pattern.arrayLit _ _ ps :: _ => let sz := ps.length; if sizes.contains sz then sizes else sizes.push sz - | _ => sizes) + | Pattern.arrayLit _ ps :: _ => let sz := ps.length; if sizes.contains sz then sizes else sizes.push sz + | _ => sizes) #[] private def processArrayLit (process : Problem → State → MetaM State) (p : Problem) (s : State) : MetaM State := do @@ -635,14 +627,14 @@ match p.vars with let examples := p.examples.map $ Example.replaceFVarId x.fvarId! subex; let examples := examples.map $ Example.applyFVarSubst subst; let newAlts := p.alts.filter fun alt => match alt.patterns with - | Pattern.arrayLit _ _ ps :: _ => ps.length == size - | Pattern.var _ _ :: _ => true - | _ => false; + | Pattern.arrayLit _ ps :: _ => ps.length == size + | Pattern.var _ :: _ => true + | _ => false; newAlts ← newAlts.mapM fun alt => alt.applyFVarSubst subst; newAlts ← newAlts.mapM fun alt => alt.copy; newAlts ← newAlts.mapM fun alt => match alt.patterns with - | Pattern.arrayLit _ _ pats :: ps => pure { alt with patterns := pats ++ ps } - | Pattern.var ref mvarId :: ps => do + | Pattern.arrayLit _ pats :: ps => pure { alt with patterns := pats ++ ps } + | Pattern.var mvarId :: ps => do α ← getArrayArgType x; newMVars ← size.foldM (fun _ (newMVars : List Expr) => do @@ -655,7 +647,7 @@ match p.vars with rhs ← instantiateMVars alt.rhs; let mvars := alt.mvars.erase mvarId; let mvars := newMVars.map Expr.mvarId! ++ mvars; - let ps := newMVars.map (fun mvar => Pattern.var ref mvar.mvarId!) ++ ps; + let ps := newMVars.map (fun mvar => Pattern.var mvar.mvarId!) ++ ps; pure { alt with rhs := rhs, mvars := mvars, patterns := ps } | _ => unreachable!; process { mvarId := subgoal.mvarId, vars := newVars, alts := newAlts, examples := examples } s diff --git a/tests/lean/run/depElim1.lean b/tests/lean/run/depElim1.lean index c10727b473..dc95dd3b4a 100644 --- a/tests/lean/run/depElim1.lean +++ b/tests/lean/run/depElim1.lean @@ -45,16 +45,16 @@ match e with partial def mkPattern : Expr → MetaM Pattern | e => if e.isAppOfArity `val 2 then - pure $ Pattern.val Syntax.missing e.appArg! + pure $ Pattern.val e.appArg! else if e.isAppOfArity `inaccessible 2 then - pure $ Pattern.inaccessible Syntax.missing e.appArg! + pure $ Pattern.inaccessible e.appArg! else if e.isFVar then - pure $ Pattern.var Syntax.missing e.fvarId! + pure $ Pattern.var e.fvarId! else if e.isAppOfArity `As 3 && (e.getArg! 1).isFVar then do let v := e.getArg! 1; let p := e.getArg! 2; p ← mkPattern p; - pure $ Pattern.as Syntax.missing v.fvarId! p + pure $ Pattern.as v.fvarId! p else if e.isAppOfArity `ArrayLit0 1 || e.isAppOfArity `ArrayLit1 2 || e.isAppOfArity `ArrayLit2 3 || @@ -64,14 +64,14 @@ partial def mkPattern : Expr → MetaM Pattern let type := args.get! 0; let ps := args.extract 1 args.size; ps ← ps.toList.mapM mkPattern; - pure $ Pattern.arrayLit Syntax.missing type ps + pure $ Pattern.arrayLit type ps else match e.arrayLit? with | some es => do pats ← es.mapM mkPattern; type ← inferType e; type ← whnfD type; let elemType := type.appArg!; - pure $ Pattern.arrayLit Syntax.missing elemType pats + pure $ Pattern.arrayLit elemType pats | none => do e ← whnfD e; r? ← constructorApp? e; @@ -81,7 +81,7 @@ partial def mkPattern : Expr → MetaM Pattern let params := args.extract 0 cval.nparams; let fields := args.extract cval.nparams args.size; pats ← fields.toList.mapM mkPattern; - pure $ Pattern.ctor Syntax.missing cval.name fn.constLevels! params.toList pats + pure $ Pattern.ctor cval.name fn.constLevels! params.toList pats partial def decodePats : Expr → MetaM (List Pattern) | e =>