feat: scientific notation

This commit is contained in:
Leonardo de Moura 2020-12-03 07:49:20 -08:00
parent 962cffbaaa
commit d1f4d4f57e
12 changed files with 133 additions and 42 deletions

View file

@ -98,4 +98,4 @@ abbrev Nat.toFloat (n : Nat) : Float :=
instance : Pow Float := ⟨Float.pow⟩
@[extern "lean_float_of_decimal"] constant Float.ofDecimal (m : Nat) (e : Nat) : Float
@[extern "lean_float_of_decimal"] constant Float.ofDecimal (m : Nat) (s : Bool) (e : Nat) : Float

View file

@ -8,10 +8,13 @@ import Init.Data.Float
import Init.Data.Nat
/- For decimal numbers (e.g., `1.23`).
The Lean frontend uses `OfDecimal.ofDecimal 123 2` to represent `1.23` -/
Examples:
- `OfDecimal.ofDecimal 123 true 2` represents `1.23`
- `Ofdecimal.ofdecimal 121 false 100` represents `121e100`
-/
class OfDecimal (α : Type u) where
ofDecimal : Nat → Nat → α
ofDecimal : Nat → Bool → Nat → α
@[defaultInstance]
instance : OfDecimal Float where
ofDecimal m e := Float.ofDecimal m e
ofDecimal m s e := Float.ofDecimal m s e

View file

@ -288,9 +288,6 @@ def mkNumLit (val : String) (info : SourceInfo := {}) : Syntax :=
def mkDecimalLit (val : String) (info : SourceInfo := {}) : Syntax :=
mkLit decimalLitKind val info
def mkScientificLit (val : String) (info : SourceInfo := {}) : Syntax :=
mkLit scientificLitKind val info
/- Recall that we don't have special Syntax constructors for storing numeric and string atoms.
The idea is to have an extensible approach where embedded DSLs may have new kind of atoms and/or
different ways of representing them. So, our atoms contain just the parsed string.
@ -373,28 +370,64 @@ def isNatLit? (s : Syntax) : Option Nat :=
def isFieldIdx? (s : Syntax) : Option Nat :=
isNatLitAux fieldIdxKind s
partial def decodeDecimalLitVal? (s : String) : Option (Nat × Nat) :=
partial def decodeDecimalLitVal? (s : String) : Option (Nat × Bool × Nat) :=
let len := s.length
if len == 0 then none
else
let c := s.get 0
if c.isDigit then
decode 0 0 0 false
decode 0 0
else none
where
decode (i : String.Pos) (val : Nat) (e : Nat) (foundDot : Bool) : Option (Nat × Nat) :=
decodeAfterExp (i : String.Pos) (val : Nat) (e : Nat) (sign : Bool) (exp : Nat) : Option (Nat × Bool × Nat) :=
if s.atEnd i then
if foundDot then some (val, e) else none
if sign then
some (val, sign, exp + e)
else if exp >= e then
some (val, sign, exp - e)
else
some (val, true, e - exp)
else
let c := s.get i
if '0' ≤ c && c ≤ '9' then
decode (s.next i) (10*val + c.toNat - '0'.toNat) (if foundDot then e+1 else e) foundDot
else if c == '.' && !foundDot then
decode (s.next i) val e true
decodeAfterExp (s.next i) val e sign (10*exp + c.toNat - '0'.toNat)
else
none
def isDecimalLit? (stx : Syntax) : Option (Nat × Nat) :=
decodeExp (i : String.Pos) (val : Nat) (e : Nat) : Option (Nat × Bool × Nat) :=
let c := s.get i
if c == '-' then
decodeAfterExp (s.next i) val e true 0
else
decodeAfterExp i val e false 0
decodeAfterDot (i : String.Pos) (val : Nat) (e : Nat) : Option (Nat × Bool × Nat) :=
if s.atEnd i then
some (val, true, e)
else
let c := s.get i
if '0' ≤ c && c ≤ '9' then
decodeAfterDot (s.next i) (10*val + c.toNat - '0'.toNat) (e+1)
else if c == 'e' || c == 'E' then
decodeExp (s.next i) val e
else
none
decode (i : String.Pos) (val : Nat) : Option (Nat × Bool × Nat) :=
if s.atEnd i then
none
else
let c := s.get i
if '0' ≤ c && c ≤ '9' then
decode (s.next i) (10*val + c.toNat - '0'.toNat)
else if c == '.' then
decodeAfterDot (s.next i) val 0
else if c == 'e' || c == 'E' then
decodeExp (s.next i) val 0
else
none
def isDecimalLit? (stx : Syntax) : Option (Nat × Bool × Nat) :=
match isLit? decimalLitKind stx with
| some val => decodeDecimalLitVal? val
| _ => none

View file

@ -1596,7 +1596,6 @@ def strLitKind : SyntaxNodeKind := `strLit
def charLitKind : SyntaxNodeKind := `charLit
def numLitKind : SyntaxNodeKind := `numLit
def decimalLitKind : SyntaxNodeKind := `decimalLit
def scientificLitKind : SyntaxNodeKind := `scientificLit
def nameLitKind : SyntaxNodeKind := `nameLit
def fieldIdxKind : SyntaxNodeKind := `fieldIdx
def interpolatedStrLitKind : SyntaxNodeKind := `interpolatedStrLitKind

View file

@ -483,8 +483,6 @@ partial def collect : Syntax → M Syntax
pure stx
else if k == decimalLitKind then
pure stx
else if k == scientificLitKind then
pure stx
else if k == charLitKind then
pure stx
else if k == `Lean.Parser.Term.quotedName then

View file

@ -1277,11 +1277,11 @@ private def mkFreshTypeMVarFor (expectedType? : Option Expr) : TermElabM Expr :=
@[builtinTermElab decimalLit] def elabDecimalLit : TermElab := fun stx expectedType? => do
match stx.isDecimalLit? with
| none => throwIllFormedSyntax
| some (m, e) =>
| some (m, sign, e) =>
let typeMVar ← mkFreshTypeMVarFor expectedType?
let u ← getDecLevel typeMVar
let mvar ← mkInstMVar (mkApp (Lean.mkConst `OfDecimal [u]) typeMVar)
return mkApp4 (Lean.mkConst `OfDecimal.ofDecimal [u]) typeMVar mvar (mkNatLit m) (mkNatLit e)
return mkApp5 (Lean.mkConst `OfDecimal.ofDecimal [u]) typeMVar mvar (mkNatLit m) (toExpr sign) (mkNatLit e)
@[builtinTermElab charLit] def elabCharLit : TermElab := fun stx _ => do
match stx.isCharLit? with

View file

@ -68,7 +68,7 @@ namespace Lean
namespace Parser
def isLitKind (k : SyntaxNodeKind) : Bool :=
k == strLitKind || k == numLitKind || k == charLitKind || k == nameLitKind || k == decimalLitKind || k == scientificLitKind
k == strLitKind || k == numLitKind || k == charLitKind || k == nameLitKind || k == decimalLitKind
abbrev mkAtom (info : SourceInfo) (val : String) : Syntax :=
Syntax.atom info val
@ -807,22 +807,46 @@ partial def strLitFnAux (startPos : Nat) : ParserFn := fun c s =>
else if curr == '\\' then andthenFn quotedCharFn (strLitFnAux startPos) c s
else strLitFnAux startPos c s
def decimalNumberFn (startPos : Nat) : ParserFn := fun c s =>
def decimalNumberFn (startPos : Nat) (c : ParserContext) : ParserState → ParserState := fun s =>
let s := takeWhileFn (fun c => c.isDigit) c s
let input := c.input
let i := s.pos
let curr := input.get i
if curr == '.' then
let i := input.next i
let curr := input.get i
let s :=
if curr == '.' || curr == 'e' || curr == 'E' then
let s := parseOptDot s
let s := parseOptExp s
mkNodeToken decimalLitKind startPos c s
else
mkNodeToken numLitKind startPos c s
where
parseOptDot s :=
let input := c.input
let i := s.pos
let curr := input.get i
if curr == '.' then
let i := input.next i
let curr := input.get i
if curr.isDigit then
takeWhileFn (fun c => c.isDigit) c (s.setPos i)
else
s.setPos i
mkNodeToken decimalLitKind startPos c s
else
mkNodeToken numLitKind startPos c s
else
s
parseOptExp s :=
let input := c.input
let i := s.pos
let curr := input.get i
if curr == 'e' || curr == 'E' then
let i := input.next i
let i := if input.get i == '-' then input.next i else i
let curr := input.get i
if curr.isDigit then
takeWhileFn (fun c => c.isDigit) c (s.setPos i)
else
s.setPos i
else
s
def binNumberFn (startPos : Nat) : ParserFn := fun c s =>
let s := takeWhile1Fn (fun c => c == '0' || c == '1') "binary number" c s

View file

@ -27,7 +27,6 @@ builtin_initialize
registerBuiltinNodeKind strLitKind
registerBuiltinNodeKind numLitKind
registerBuiltinNodeKind decimalLitKind
registerBuiltinNodeKind scientificLitKind
registerBuiltinNodeKind charLitKind
registerBuiltinNodeKind nameLitKind

View file

@ -368,14 +368,26 @@ def delabOfNat : Delab := whenPPOption getPPCoercions do
let (Expr.app (Expr.app _ (Expr.lit (Literal.natVal n) _) _) _ _) ← getExpr | failure
return quote n
-- `@OfDecimal.ofDecimal _ _ m e` ~> `m*10^(-e)`
-- `@OfDecimal.ofDecimal _ _ m s e` ~> `m*10^(sign * e)` where `sign == 1` if `s = false` and `sign = -1` if `s = true`
@[builtinDelab app.OfDecimal.ofDecimal]
def delabOfDecimal : Delab := whenPPOption getPPCoercions do
let (Expr.app (Expr.app _ (Expr.lit (Literal.natVal m) _) _) (Expr.lit (Literal.natVal e) _) _) ← getExpr | failure
let expr ← getExpr
guard <| expr.getAppNumArgs == 5
let Expr.lit (Literal.natVal m) _ ← pure (expr.getArg! 2) | failure
let Expr.lit (Literal.natVal e) _ ← pure (expr.getArg! 4) | failure
let s ← match expr.getArg! 3 with
| Expr.const `Bool.true _ _ => pure true
| Expr.const `Bool.false _ _ => pure false
| _ => failure
let str := toString m
let mStr := str.extract 0 (str.length - e)
let eStr := str.extract (str.length - e) str.length
return Syntax.mkDecimalLit (mStr ++ "." ++ eStr)
if s && e == str.length then
return Syntax.mkDecimalLit ("0." ++ str)
else if s && e < str.length then
let mStr := str.extract 0 (str.length - e)
let eStr := str.extract (str.length - e) str.length
return Syntax.mkDecimalLit (mStr ++ "." ++ eStr)
else
return Syntax.mkDecimalLit (str ++ "e" ++ (if s then "-" else "") ++ toString e)
/--
Delaborate a projection primitive. These do not usually occur in

View file

@ -1464,17 +1464,25 @@ extern "C" lean_obj_res lean_float_to_string(double a) {
return mk_string(std::to_string(a));
}
static double of_decimal(mpz const & m, size_t e) {
return (mpq(m)/mpz(10).pow(e)).get_double();
static double of_decimal(mpz const & m, bool sign, size_t e) {
if (sign)
return (mpq(m)/mpz(10).pow(e)).get_double();
else
return (mpq(m)*mpz(10).pow(e)).get_double();
}
extern "C" double lean_float_of_decimal(b_lean_obj_arg m, b_lean_obj_arg e) {
if (!lean_is_scalar(e))
return 0.0;
extern "C" double lean_float_of_decimal(b_lean_obj_arg m, uint8 esign, b_lean_obj_arg e) {
if (!lean_is_scalar(e)) {
if (esign) {
return 0.0;
} else {
return std::numeric_limits<double>::infinity();
}
}
if (lean_is_scalar(m)) {
return of_decimal(mpz::of_size_t(lean_unbox(m)), lean_unbox(e));
return of_decimal(mpz::of_size_t(lean_unbox(m)), esign, lean_unbox(e));
} else {
return of_decimal(mpz_value(m), lean_unbox(e));
return of_decimal(mpz_value(m), esign, lean_unbox(e));
}
}

View file

@ -4,3 +4,13 @@
#eval 1.2 + 2.3
#check 1.
#check 3.1416
theorem ex : 31416e-4 = 3.1416 :=
rfl
#eval 3.4e-100 * 1e98
#eval 12.3e90 * 1e-90
#eval 3.00e-100 * 1e100
#eval 3.00e-100 * 1.e100
#eval 3.00e-100 * 1.0e100

View file

@ -4,3 +4,8 @@
3.500000
1. : Float
3.1416 : Float
0.034000
12.300000
3.000000
3.000000
3.000000