feat: basic support for decimal numbers
This commit is contained in:
parent
133ecb111b
commit
facb28d080
11 changed files with 99 additions and 13 deletions
|
|
@ -21,3 +21,4 @@ import Init.Data.Random
|
|||
import Init.Data.ToString
|
||||
import Init.Data.Range
|
||||
import Init.Data.Hashable
|
||||
import Init.Data.OfDecimal
|
||||
|
|
|
|||
25
src/Init/Data/OfDecimal.lean
Normal file
25
src/Init/Data/OfDecimal.lean
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
/-
|
||||
Copyright (c) 2020 Microsoft Corporation. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
prelude
|
||||
import Init.Data.Float
|
||||
import Init.Data.Nat
|
||||
|
||||
/- For decimal numbers (e.g., `1.23`).
|
||||
The Lean frontend uses `OfDecimal.ofDecimal 123 2` to represent `1.23` -/
|
||||
class OfDecimal (α : Type u) where
|
||||
ofDecimal : Nat → Nat → α
|
||||
|
||||
def Float.fromDecimal (m : Nat) (e : Nat) : Float :=
|
||||
fromDec (Float.ofNat m) e
|
||||
where
|
||||
fromDec (m : Float) (e : Nat) : Float :=
|
||||
match e with
|
||||
| 0 => m
|
||||
| e+1 => fromDec (m/10) e
|
||||
|
||||
@[defaultInstance]
|
||||
instance : OfDecimal Float where
|
||||
ofDecimal m e := Float.fromDecimal m e
|
||||
|
|
@ -285,6 +285,12 @@ def mkStrLit (val : String) (info : SourceInfo := {}) : Syntax :=
|
|||
def mkNumLit (val : String) (info : SourceInfo := {}) : Syntax :=
|
||||
mkLit numLitKind val info
|
||||
|
||||
def mkDecimalLit (val : String) (info : SourceInfo := {}) : Syntax :=
|
||||
mkLit decimalLitKind val info
|
||||
|
||||
def mkScientificLit (val : String) (info : SourceInfo := {}) : Syntax :=
|
||||
mkLit scientificLitKind val info
|
||||
|
||||
/- Recall that we don't have special Syntax constructors for storing numeric and string atoms.
|
||||
The idea is to have an extensible approach where embedded DSLs may have new kind of atoms and/or
|
||||
different ways of representing them. So, our atoms contain just the parsed string.
|
||||
|
|
@ -328,7 +334,7 @@ private partial def decodeDecimalLitAux (s : String) (i : String.Pos) (val : Nat
|
|||
if '0' ≤ c && c ≤ '9' then decodeDecimalLitAux s (s.next i) (10*val + c.toNat - '0'.toNat)
|
||||
else none
|
||||
|
||||
def decodeNatLitVal (s : String) : Option Nat :=
|
||||
def decodeNatLitVal? (s : String) : Option Nat :=
|
||||
let len := s.length
|
||||
if len == 0 then none
|
||||
else
|
||||
|
|
@ -356,9 +362,9 @@ def isLit? (litKind : SyntaxNodeKind) (stx : Syntax) : Option String :=
|
|||
none
|
||||
| _ => none
|
||||
|
||||
def isNatLitAux (litKind : SyntaxNodeKind) (stx : Syntax) : Option Nat :=
|
||||
private def isNatLitAux (litKind : SyntaxNodeKind) (stx : Syntax) : Option Nat :=
|
||||
match isLit? litKind stx with
|
||||
| some val => decodeNatLitVal val
|
||||
| some val => decodeNatLitVal? val
|
||||
| _ => none
|
||||
|
||||
def isNatLit? (s : Syntax) : Option Nat :=
|
||||
|
|
@ -367,6 +373,32 @@ def isNatLit? (s : Syntax) : Option Nat :=
|
|||
def isFieldIdx? (s : Syntax) : Option Nat :=
|
||||
isNatLitAux fieldIdxKind s
|
||||
|
||||
partial def decodeDecimalLitVal? (s : String) : Option (Nat × Nat) :=
|
||||
let len := s.length
|
||||
if len == 0 then none
|
||||
else
|
||||
let c := s.get 0
|
||||
if c.isDigit then
|
||||
decode 0 0 0 false
|
||||
else none
|
||||
where
|
||||
decode (i : String.Pos) (val : Nat) (e : Nat) (foundDot : Bool) : Option (Nat × Nat) :=
|
||||
if s.atEnd i then
|
||||
if foundDot then some (val, e) else none
|
||||
else
|
||||
let c := s.get i
|
||||
if '0' ≤ c && c ≤ '9' then
|
||||
decode (s.next i) (10*val + c.toNat - '0'.toNat) (if foundDot then e+1 else e) foundDot
|
||||
else if c == '.' && !foundDot then
|
||||
decode (s.next i) val e true
|
||||
else
|
||||
none
|
||||
|
||||
def isDecimalLit? (stx : Syntax) : Option (Nat × Nat) :=
|
||||
match isLit? decimalLitKind stx with
|
||||
| some val => decodeDecimalLitVal? val
|
||||
| _ => none
|
||||
|
||||
def isIdOrAtom? : Syntax → Option String
|
||||
| Syntax.atom _ val => some val
|
||||
| Syntax.ident _ rawVal _ _ => some rawVal.toString
|
||||
|
|
|
|||
|
|
@ -1595,6 +1595,8 @@ def identKind : SyntaxNodeKind := `ident
|
|||
def strLitKind : SyntaxNodeKind := `strLit
|
||||
def charLitKind : SyntaxNodeKind := `charLit
|
||||
def numLitKind : SyntaxNodeKind := `numLit
|
||||
def decimalLitKind : SyntaxNodeKind := `decimalLit
|
||||
def scientificLitKind : SyntaxNodeKind := `scientificLit
|
||||
def nameLitKind : SyntaxNodeKind := `nameLit
|
||||
def fieldIdxKind : SyntaxNodeKind := `fieldIdx
|
||||
def interpolatedStrLitKind : SyntaxNodeKind := `interpolatedStrLitKind
|
||||
|
|
|
|||
|
|
@ -481,6 +481,10 @@ partial def collect : Syntax → M Syntax
|
|||
pure stx
|
||||
else if k == numLitKind then
|
||||
pure stx
|
||||
else if k == decimalLitKind then
|
||||
pure stx
|
||||
else if k == scientificLitKind then
|
||||
pure stx
|
||||
else if k == charLitKind then
|
||||
pure stx
|
||||
else if k == `Lean.Parser.Term.quotedName then
|
||||
|
|
|
|||
|
|
@ -68,7 +68,7 @@ namespace Lean
|
|||
namespace Parser
|
||||
|
||||
def isLitKind (k : SyntaxNodeKind) : Bool :=
|
||||
k == strLitKind || k == numLitKind || k == charLitKind || k == nameLitKind
|
||||
k == strLitKind || k == numLitKind || k == charLitKind || k == nameLitKind || k == decimalLitKind || k == scientificLitKind
|
||||
|
||||
abbrev mkAtom (info : SourceInfo) (val : String) : Syntax :=
|
||||
Syntax.atom info val
|
||||
|
|
@ -812,16 +812,17 @@ def decimalNumberFn (startPos : Nat) : ParserFn := fun c s =>
|
|||
let input := c.input
|
||||
let i := s.pos
|
||||
let curr := input.get i
|
||||
let s :=
|
||||
/- TODO(Leo): should we use a different kind for numerals containing decimal points? -/
|
||||
if curr == '.' then
|
||||
let i := input.next i
|
||||
let curr := input.get i
|
||||
if curr == '.' then
|
||||
let i := input.next i
|
||||
let curr := input.get i
|
||||
let s :=
|
||||
if curr.isDigit then
|
||||
takeWhileFn (fun c => c.isDigit) c (s.setPos i)
|
||||
else s
|
||||
else s
|
||||
mkNodeToken numLitKind startPos c s
|
||||
else
|
||||
s
|
||||
mkNodeToken decimalLitKind startPos c s
|
||||
else
|
||||
mkNodeToken numLitKind startPos c s
|
||||
|
||||
def binNumberFn (startPos : Nat) : ParserFn := fun c s =>
|
||||
let s := takeWhile1Fn (fun c => c == '0' || c == '1') "binary number" c s
|
||||
|
|
@ -1161,6 +1162,17 @@ def numLitFn : ParserFn :=
|
|||
info := mkAtomicInfo "numLit"
|
||||
}
|
||||
|
||||
def decimalLitFn : ParserFn :=
|
||||
fun c s =>
|
||||
let iniPos := s.pos
|
||||
let s := tokenFn c s
|
||||
if s.hasError || !(s.stxStack.back.isOfKind decimalLitKind) then s.mkErrorAt "decimal number" iniPos else s
|
||||
|
||||
@[inline] def decimalLitNoAntiquot : Parser := {
|
||||
fn := decimalLitFn,
|
||||
info := mkAtomicInfo "decimalLit"
|
||||
}
|
||||
|
||||
def strLitFn : ParserFn := fun c s =>
|
||||
let iniPos := s.pos
|
||||
let s := tokenFn c s
|
||||
|
|
@ -1594,6 +1606,9 @@ def rawIdent : Parser :=
|
|||
def numLit : Parser :=
|
||||
withAntiquot (mkAntiquot "numLit" numLitKind) numLitNoAntiquot
|
||||
|
||||
def decimalLit : Parser :=
|
||||
withAntiquot (mkAntiquot "decimalLit" decimalLitKind) decimalLitNoAntiquot
|
||||
|
||||
def strLit : Parser :=
|
||||
withAntiquot (mkAntiquot "strLit" strLitKind) strLitNoAntiquot
|
||||
|
||||
|
|
|
|||
|
|
@ -26,6 +26,8 @@ builtin_initialize
|
|||
registerBuiltinNodeKind identKind
|
||||
registerBuiltinNodeKind strLitKind
|
||||
registerBuiltinNodeKind numLitKind
|
||||
registerBuiltinNodeKind decimalLitKind
|
||||
registerBuiltinNodeKind scientificLitKind
|
||||
registerBuiltinNodeKind charLitKind
|
||||
registerBuiltinNodeKind nameLitKind
|
||||
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ namespace Parser
|
|||
-- (because `Parser.Extension` depends on them)
|
||||
attribute [runBuiltinParserAttributeHooks]
|
||||
leadingNode termParser commandParser antiquotNestedExpr antiquotExpr mkAntiquot nodeWithAntiquot
|
||||
ident numLit charLit strLit nameLit
|
||||
ident numLit decimalLit charLit strLit nameLit
|
||||
|
||||
@[runBuiltinParserAttributeHooks, inline] def group (p : Parser) : Parser :=
|
||||
node nullKind p
|
||||
|
|
|
|||
|
|
@ -44,6 +44,7 @@ def optSemicolon (p : Parser) : Parser := ppDedent $ optional ";" >> ppLine >> p
|
|||
-- `checkPrec` necessary for the pretty printer
|
||||
@[builtinTermParser] def ident := checkPrec maxPrec >> Parser.ident
|
||||
@[builtinTermParser] def num : Parser := checkPrec maxPrec >> numLit
|
||||
@[builtinTermParser] def decimal : Parser := checkPrec maxPrec >> decimalLit
|
||||
@[builtinTermParser] def str : Parser := checkPrec maxPrec >> strLit
|
||||
@[builtinTermParser] def char : Parser := checkPrec maxPrec >> charLit
|
||||
@[builtinTermParser] def type := parser! "Type" >> optional (checkWsBefore "" >> checkPrec leadPrec >> checkColGt >> levelParser maxPrec)
|
||||
|
|
|
|||
|
|
@ -353,6 +353,7 @@ def visitAtom (k : SyntaxNodeKind) : Formatter := do
|
|||
@[combinatorFormatter Lean.Parser.strLitNoAntiquot] def strLitNoAntiquot.formatter := visitAtom strLitKind
|
||||
@[combinatorFormatter Lean.Parser.nameLitNoAntiquot] def nameLitNoAntiquot.formatter := visitAtom nameLitKind
|
||||
@[combinatorFormatter Lean.Parser.numLitNoAntiquot] def numLitNoAntiquot.formatter := visitAtom numLitKind
|
||||
@[combinatorFormatter Lean.Parser.decimalLitNoAntiquot] def decimalLitNoAntiquot.formatter := visitAtom decimalLitKind
|
||||
@[combinatorFormatter Lean.Parser.fieldIdx] def fieldIdx.formatter := visitAtom fieldIdxKind
|
||||
|
||||
@[combinatorFormatter Lean.Parser.many]
|
||||
|
|
@ -439,6 +440,7 @@ builtin_initialize
|
|||
registerAlias "ws" checkWsBefore.formatter
|
||||
registerAlias "noWs" checkNoWsBefore.formatter
|
||||
registerAlias "num" (withAntiquot.formatter (mkAntiquot.formatter' "numLit" `numLit) numLitNoAntiquot.formatter)
|
||||
registerAlias "decimal" (withAntiquot.formatter (mkAntiquot.formatter' "decimalLit" `decimalLit) decimalLitNoAntiquot.formatter)
|
||||
registerAlias "str" (withAntiquot.formatter (mkAntiquot.formatter' "strLit" `strLit) strLitNoAntiquot.formatter)
|
||||
registerAlias "char" (withAntiquot.formatter (mkAntiquot.formatter' "charLit" `charLit) charLitNoAntiquot.formatter)
|
||||
registerAlias "name" (withAntiquot.formatter (mkAntiquot.formatter' "nameLit" `nameLit) nameLitNoAntiquot.formatter)
|
||||
|
|
|
|||
|
|
@ -407,6 +407,7 @@ def trailingNode.parenthesizer (k : SyntaxNodeKind) (prec : Nat) (p : Parenthesi
|
|||
@[combinatorParenthesizer Lean.Parser.strLitNoAntiquot] def strLitNoAntiquot.parenthesizer := visitToken
|
||||
@[combinatorParenthesizer Lean.Parser.nameLitNoAntiquot] def nameLitNoAntiquot.parenthesizer := visitToken
|
||||
@[combinatorParenthesizer Lean.Parser.numLitNoAntiquot] def numLitNoAntiquot.parenthesizer := visitToken
|
||||
@[combinatorParenthesizer Lean.Parser.decimalLitNoAntiquot] def decimalLitNoAntiquot.parenthesizer := visitToken
|
||||
@[combinatorParenthesizer Lean.Parser.fieldIdx] def fieldIdx.parenthesizer := visitToken
|
||||
|
||||
@[combinatorParenthesizer Lean.Parser.many]
|
||||
|
|
@ -494,6 +495,7 @@ builtin_initialize
|
|||
registerAlias "ws" checkWsBefore.parenthesizer
|
||||
registerAlias "noWs" checkNoWsBefore.parenthesizer
|
||||
registerAlias "num" (withAntiquot.parenthesizer (mkAntiquot.parenthesizer' "numLit" `numLit) numLitNoAntiquot.parenthesizer)
|
||||
registerAlias "decimal" (withAntiquot.parenthesizer (mkAntiquot.parenthesizer' "decimalLit" `decimalLit) decimalLitNoAntiquot.parenthesizer)
|
||||
registerAlias "str" (withAntiquot.parenthesizer (mkAntiquot.parenthesizer' "strLit" `strLit) strLitNoAntiquot.parenthesizer)
|
||||
registerAlias "char" (withAntiquot.parenthesizer (mkAntiquot.parenthesizer' "charLit" `charLit) charLitNoAntiquot.parenthesizer)
|
||||
registerAlias "name" (withAntiquot.parenthesizer (mkAntiquot.parenthesizer' "nameLit" `nameLit) nameLitNoAntiquot.parenthesizer)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue