feat: basic support for decimal numbers

This commit is contained in:
Leonardo de Moura 2020-12-02 14:47:27 -08:00
parent 133ecb111b
commit facb28d080
11 changed files with 99 additions and 13 deletions

View file

@ -21,3 +21,4 @@ import Init.Data.Random
import Init.Data.ToString
import Init.Data.Range
import Init.Data.Hashable
import Init.Data.OfDecimal

View file

@ -0,0 +1,25 @@
/-
Copyright (c) 2020 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Init.Data.Float
import Init.Data.Nat
/- For decimal numbers (e.g., `1.23`).
The Lean frontend uses `OfDecimal.ofDecimal 123 2` to represent `1.23` -/
class OfDecimal (α : Type u) where
ofDecimal : Nat → Nat → α
def Float.fromDecimal (m : Nat) (e : Nat) : Float :=
fromDec (Float.ofNat m) e
where
fromDec (m : Float) (e : Nat) : Float :=
match e with
| 0 => m
| e+1 => fromDec (m/10) e
@[defaultInstance]
instance : OfDecimal Float where
ofDecimal m e := Float.fromDecimal m e

View file

@ -285,6 +285,12 @@ def mkStrLit (val : String) (info : SourceInfo := {}) : Syntax :=
def mkNumLit (val : String) (info : SourceInfo := {}) : Syntax :=
mkLit numLitKind val info
def mkDecimalLit (val : String) (info : SourceInfo := {}) : Syntax :=
mkLit decimalLitKind val info
def mkScientificLit (val : String) (info : SourceInfo := {}) : Syntax :=
mkLit scientificLitKind val info
/- Recall that we don't have special Syntax constructors for storing numeric and string atoms.
The idea is to have an extensible approach where embedded DSLs may have new kind of atoms and/or
different ways of representing them. So, our atoms contain just the parsed string.
@ -328,7 +334,7 @@ private partial def decodeDecimalLitAux (s : String) (i : String.Pos) (val : Nat
if '0' ≤ c && c ≤ '9' then decodeDecimalLitAux s (s.next i) (10*val + c.toNat - '0'.toNat)
else none
def decodeNatLitVal (s : String) : Option Nat :=
def decodeNatLitVal? (s : String) : Option Nat :=
let len := s.length
if len == 0 then none
else
@ -356,9 +362,9 @@ def isLit? (litKind : SyntaxNodeKind) (stx : Syntax) : Option String :=
none
| _ => none
def isNatLitAux (litKind : SyntaxNodeKind) (stx : Syntax) : Option Nat :=
private def isNatLitAux (litKind : SyntaxNodeKind) (stx : Syntax) : Option Nat :=
match isLit? litKind stx with
| some val => decodeNatLitVal val
| some val => decodeNatLitVal? val
| _ => none
def isNatLit? (s : Syntax) : Option Nat :=
@ -367,6 +373,32 @@ def isNatLit? (s : Syntax) : Option Nat :=
def isFieldIdx? (s : Syntax) : Option Nat :=
isNatLitAux fieldIdxKind s
partial def decodeDecimalLitVal? (s : String) : Option (Nat × Nat) :=
let len := s.length
if len == 0 then none
else
let c := s.get 0
if c.isDigit then
decode 0 0 0 false
else none
where
decode (i : String.Pos) (val : Nat) (e : Nat) (foundDot : Bool) : Option (Nat × Nat) :=
if s.atEnd i then
if foundDot then some (val, e) else none
else
let c := s.get i
if '0' ≤ c && c ≤ '9' then
decode (s.next i) (10*val + c.toNat - '0'.toNat) (if foundDot then e+1 else e) foundDot
else if c == '.' && !foundDot then
decode (s.next i) val e true
else
none
def isDecimalLit? (stx : Syntax) : Option (Nat × Nat) :=
match isLit? decimalLitKind stx with
| some val => decodeDecimalLitVal? val
| _ => none
def isIdOrAtom? : Syntax → Option String
| Syntax.atom _ val => some val
| Syntax.ident _ rawVal _ _ => some rawVal.toString

View file

@ -1595,6 +1595,8 @@ def identKind : SyntaxNodeKind := `ident
def strLitKind : SyntaxNodeKind := `strLit
def charLitKind : SyntaxNodeKind := `charLit
def numLitKind : SyntaxNodeKind := `numLit
def decimalLitKind : SyntaxNodeKind := `decimalLit
def scientificLitKind : SyntaxNodeKind := `scientificLit
def nameLitKind : SyntaxNodeKind := `nameLit
def fieldIdxKind : SyntaxNodeKind := `fieldIdx
def interpolatedStrLitKind : SyntaxNodeKind := `interpolatedStrLitKind

View file

@ -481,6 +481,10 @@ partial def collect : Syntax → M Syntax
pure stx
else if k == numLitKind then
pure stx
else if k == decimalLitKind then
pure stx
else if k == scientificLitKind then
pure stx
else if k == charLitKind then
pure stx
else if k == `Lean.Parser.Term.quotedName then

View file

@ -68,7 +68,7 @@ namespace Lean
namespace Parser
def isLitKind (k : SyntaxNodeKind) : Bool :=
k == strLitKind || k == numLitKind || k == charLitKind || k == nameLitKind
k == strLitKind || k == numLitKind || k == charLitKind || k == nameLitKind || k == decimalLitKind || k == scientificLitKind
abbrev mkAtom (info : SourceInfo) (val : String) : Syntax :=
Syntax.atom info val
@ -812,16 +812,17 @@ def decimalNumberFn (startPos : Nat) : ParserFn := fun c s =>
let input := c.input
let i := s.pos
let curr := input.get i
let s :=
/- TODO(Leo): should we use a different kind for numerals containing decimal points? -/
if curr == '.' then
let i := input.next i
let curr := input.get i
if curr == '.' then
let i := input.next i
let curr := input.get i
let s :=
if curr.isDigit then
takeWhileFn (fun c => c.isDigit) c (s.setPos i)
else s
else s
mkNodeToken numLitKind startPos c s
else
s
mkNodeToken decimalLitKind startPos c s
else
mkNodeToken numLitKind startPos c s
def binNumberFn (startPos : Nat) : ParserFn := fun c s =>
let s := takeWhile1Fn (fun c => c == '0' || c == '1') "binary number" c s
@ -1161,6 +1162,17 @@ def numLitFn : ParserFn :=
info := mkAtomicInfo "numLit"
}
def decimalLitFn : ParserFn :=
fun c s =>
let iniPos := s.pos
let s := tokenFn c s
if s.hasError || !(s.stxStack.back.isOfKind decimalLitKind) then s.mkErrorAt "decimal number" iniPos else s
@[inline] def decimalLitNoAntiquot : Parser := {
fn := decimalLitFn,
info := mkAtomicInfo "decimalLit"
}
def strLitFn : ParserFn := fun c s =>
let iniPos := s.pos
let s := tokenFn c s
@ -1594,6 +1606,9 @@ def rawIdent : Parser :=
def numLit : Parser :=
withAntiquot (mkAntiquot "numLit" numLitKind) numLitNoAntiquot
def decimalLit : Parser :=
withAntiquot (mkAntiquot "decimalLit" decimalLitKind) decimalLitNoAntiquot
def strLit : Parser :=
withAntiquot (mkAntiquot "strLit" strLitKind) strLitNoAntiquot

View file

@ -26,6 +26,8 @@ builtin_initialize
registerBuiltinNodeKind identKind
registerBuiltinNodeKind strLitKind
registerBuiltinNodeKind numLitKind
registerBuiltinNodeKind decimalLitKind
registerBuiltinNodeKind scientificLitKind
registerBuiltinNodeKind charLitKind
registerBuiltinNodeKind nameLitKind

View file

@ -15,7 +15,7 @@ namespace Parser
-- (because `Parser.Extension` depends on them)
attribute [runBuiltinParserAttributeHooks]
leadingNode termParser commandParser antiquotNestedExpr antiquotExpr mkAntiquot nodeWithAntiquot
ident numLit charLit strLit nameLit
ident numLit decimalLit charLit strLit nameLit
@[runBuiltinParserAttributeHooks, inline] def group (p : Parser) : Parser :=
node nullKind p

View file

@ -44,6 +44,7 @@ def optSemicolon (p : Parser) : Parser := ppDedent $ optional ";" >> ppLine >> p
-- `checkPrec` necessary for the pretty printer
@[builtinTermParser] def ident := checkPrec maxPrec >> Parser.ident
@[builtinTermParser] def num : Parser := checkPrec maxPrec >> numLit
@[builtinTermParser] def decimal : Parser := checkPrec maxPrec >> decimalLit
@[builtinTermParser] def str : Parser := checkPrec maxPrec >> strLit
@[builtinTermParser] def char : Parser := checkPrec maxPrec >> charLit
@[builtinTermParser] def type := parser! "Type" >> optional (checkWsBefore "" >> checkPrec leadPrec >> checkColGt >> levelParser maxPrec)

View file

@ -353,6 +353,7 @@ def visitAtom (k : SyntaxNodeKind) : Formatter := do
@[combinatorFormatter Lean.Parser.strLitNoAntiquot] def strLitNoAntiquot.formatter := visitAtom strLitKind
@[combinatorFormatter Lean.Parser.nameLitNoAntiquot] def nameLitNoAntiquot.formatter := visitAtom nameLitKind
@[combinatorFormatter Lean.Parser.numLitNoAntiquot] def numLitNoAntiquot.formatter := visitAtom numLitKind
@[combinatorFormatter Lean.Parser.decimalLitNoAntiquot] def decimalLitNoAntiquot.formatter := visitAtom decimalLitKind
@[combinatorFormatter Lean.Parser.fieldIdx] def fieldIdx.formatter := visitAtom fieldIdxKind
@[combinatorFormatter Lean.Parser.many]
@ -439,6 +440,7 @@ builtin_initialize
registerAlias "ws" checkWsBefore.formatter
registerAlias "noWs" checkNoWsBefore.formatter
registerAlias "num" (withAntiquot.formatter (mkAntiquot.formatter' "numLit" `numLit) numLitNoAntiquot.formatter)
registerAlias "decimal" (withAntiquot.formatter (mkAntiquot.formatter' "decimalLit" `decimalLit) decimalLitNoAntiquot.formatter)
registerAlias "str" (withAntiquot.formatter (mkAntiquot.formatter' "strLit" `strLit) strLitNoAntiquot.formatter)
registerAlias "char" (withAntiquot.formatter (mkAntiquot.formatter' "charLit" `charLit) charLitNoAntiquot.formatter)
registerAlias "name" (withAntiquot.formatter (mkAntiquot.formatter' "nameLit" `nameLit) nameLitNoAntiquot.formatter)

View file

@ -407,6 +407,7 @@ def trailingNode.parenthesizer (k : SyntaxNodeKind) (prec : Nat) (p : Parenthesi
@[combinatorParenthesizer Lean.Parser.strLitNoAntiquot] def strLitNoAntiquot.parenthesizer := visitToken
@[combinatorParenthesizer Lean.Parser.nameLitNoAntiquot] def nameLitNoAntiquot.parenthesizer := visitToken
@[combinatorParenthesizer Lean.Parser.numLitNoAntiquot] def numLitNoAntiquot.parenthesizer := visitToken
@[combinatorParenthesizer Lean.Parser.decimalLitNoAntiquot] def decimalLitNoAntiquot.parenthesizer := visitToken
@[combinatorParenthesizer Lean.Parser.fieldIdx] def fieldIdx.parenthesizer := visitToken
@[combinatorParenthesizer Lean.Parser.many]
@ -494,6 +495,7 @@ builtin_initialize
registerAlias "ws" checkWsBefore.parenthesizer
registerAlias "noWs" checkNoWsBefore.parenthesizer
registerAlias "num" (withAntiquot.parenthesizer (mkAntiquot.parenthesizer' "numLit" `numLit) numLitNoAntiquot.parenthesizer)
registerAlias "decimal" (withAntiquot.parenthesizer (mkAntiquot.parenthesizer' "decimalLit" `decimalLit) decimalLitNoAntiquot.parenthesizer)
registerAlias "str" (withAntiquot.parenthesizer (mkAntiquot.parenthesizer' "strLit" `strLit) strLitNoAntiquot.parenthesizer)
registerAlias "char" (withAntiquot.parenthesizer (mkAntiquot.parenthesizer' "charLit" `charLit) charLitNoAntiquot.parenthesizer)
registerAlias "name" (withAntiquot.parenthesizer (mkAntiquot.parenthesizer' "nameLit" `nameLit) nameLitNoAntiquot.parenthesizer)