feat: scientific notation
This commit is contained in:
parent
962cffbaaa
commit
d1f4d4f57e
12 changed files with 133 additions and 42 deletions
|
|
@ -98,4 +98,4 @@ abbrev Nat.toFloat (n : Nat) : Float :=
|
|||
|
||||
instance : Pow Float := ⟨Float.pow⟩
|
||||
|
||||
@[extern "lean_float_of_decimal"] constant Float.ofDecimal (m : Nat) (e : Nat) : Float
|
||||
@[extern "lean_float_of_decimal"] constant Float.ofDecimal (m : Nat) (s : Bool) (e : Nat) : Float
|
||||
|
|
|
|||
|
|
@ -8,10 +8,13 @@ import Init.Data.Float
|
|||
import Init.Data.Nat
|
||||
|
||||
/- For decimal numbers (e.g., `1.23`).
|
||||
The Lean frontend uses `OfDecimal.ofDecimal 123 2` to represent `1.23` -/
|
||||
Examples:
|
||||
- `OfDecimal.ofDecimal 123 true 2` represents `1.23`
|
||||
- `Ofdecimal.ofdecimal 121 false 100` represents `121e100`
|
||||
-/
|
||||
class OfDecimal (α : Type u) where
|
||||
ofDecimal : Nat → Nat → α
|
||||
ofDecimal : Nat → Bool → Nat → α
|
||||
|
||||
@[defaultInstance]
|
||||
instance : OfDecimal Float where
|
||||
ofDecimal m e := Float.ofDecimal m e
|
||||
ofDecimal m s e := Float.ofDecimal m s e
|
||||
|
|
|
|||
|
|
@ -288,9 +288,6 @@ def mkNumLit (val : String) (info : SourceInfo := {}) : Syntax :=
|
|||
def mkDecimalLit (val : String) (info : SourceInfo := {}) : Syntax :=
|
||||
mkLit decimalLitKind val info
|
||||
|
||||
def mkScientificLit (val : String) (info : SourceInfo := {}) : Syntax :=
|
||||
mkLit scientificLitKind val info
|
||||
|
||||
/- Recall that we don't have special Syntax constructors for storing numeric and string atoms.
|
||||
The idea is to have an extensible approach where embedded DSLs may have new kind of atoms and/or
|
||||
different ways of representing them. So, our atoms contain just the parsed string.
|
||||
|
|
@ -373,28 +370,64 @@ def isNatLit? (s : Syntax) : Option Nat :=
|
|||
def isFieldIdx? (s : Syntax) : Option Nat :=
|
||||
isNatLitAux fieldIdxKind s
|
||||
|
||||
partial def decodeDecimalLitVal? (s : String) : Option (Nat × Nat) :=
|
||||
partial def decodeDecimalLitVal? (s : String) : Option (Nat × Bool × Nat) :=
|
||||
let len := s.length
|
||||
if len == 0 then none
|
||||
else
|
||||
let c := s.get 0
|
||||
if c.isDigit then
|
||||
decode 0 0 0 false
|
||||
decode 0 0
|
||||
else none
|
||||
where
|
||||
decode (i : String.Pos) (val : Nat) (e : Nat) (foundDot : Bool) : Option (Nat × Nat) :=
|
||||
decodeAfterExp (i : String.Pos) (val : Nat) (e : Nat) (sign : Bool) (exp : Nat) : Option (Nat × Bool × Nat) :=
|
||||
if s.atEnd i then
|
||||
if foundDot then some (val, e) else none
|
||||
if sign then
|
||||
some (val, sign, exp + e)
|
||||
else if exp >= e then
|
||||
some (val, sign, exp - e)
|
||||
else
|
||||
some (val, true, e - exp)
|
||||
else
|
||||
let c := s.get i
|
||||
if '0' ≤ c && c ≤ '9' then
|
||||
decode (s.next i) (10*val + c.toNat - '0'.toNat) (if foundDot then e+1 else e) foundDot
|
||||
else if c == '.' && !foundDot then
|
||||
decode (s.next i) val e true
|
||||
decodeAfterExp (s.next i) val e sign (10*exp + c.toNat - '0'.toNat)
|
||||
else
|
||||
none
|
||||
|
||||
def isDecimalLit? (stx : Syntax) : Option (Nat × Nat) :=
|
||||
decodeExp (i : String.Pos) (val : Nat) (e : Nat) : Option (Nat × Bool × Nat) :=
|
||||
let c := s.get i
|
||||
if c == '-' then
|
||||
decodeAfterExp (s.next i) val e true 0
|
||||
else
|
||||
decodeAfterExp i val e false 0
|
||||
|
||||
decodeAfterDot (i : String.Pos) (val : Nat) (e : Nat) : Option (Nat × Bool × Nat) :=
|
||||
if s.atEnd i then
|
||||
some (val, true, e)
|
||||
else
|
||||
let c := s.get i
|
||||
if '0' ≤ c && c ≤ '9' then
|
||||
decodeAfterDot (s.next i) (10*val + c.toNat - '0'.toNat) (e+1)
|
||||
else if c == 'e' || c == 'E' then
|
||||
decodeExp (s.next i) val e
|
||||
else
|
||||
none
|
||||
|
||||
decode (i : String.Pos) (val : Nat) : Option (Nat × Bool × Nat) :=
|
||||
if s.atEnd i then
|
||||
none
|
||||
else
|
||||
let c := s.get i
|
||||
if '0' ≤ c && c ≤ '9' then
|
||||
decode (s.next i) (10*val + c.toNat - '0'.toNat)
|
||||
else if c == '.' then
|
||||
decodeAfterDot (s.next i) val 0
|
||||
else if c == 'e' || c == 'E' then
|
||||
decodeExp (s.next i) val 0
|
||||
else
|
||||
none
|
||||
|
||||
def isDecimalLit? (stx : Syntax) : Option (Nat × Bool × Nat) :=
|
||||
match isLit? decimalLitKind stx with
|
||||
| some val => decodeDecimalLitVal? val
|
||||
| _ => none
|
||||
|
|
|
|||
|
|
@ -1596,7 +1596,6 @@ def strLitKind : SyntaxNodeKind := `strLit
|
|||
def charLitKind : SyntaxNodeKind := `charLit
|
||||
def numLitKind : SyntaxNodeKind := `numLit
|
||||
def decimalLitKind : SyntaxNodeKind := `decimalLit
|
||||
def scientificLitKind : SyntaxNodeKind := `scientificLit
|
||||
def nameLitKind : SyntaxNodeKind := `nameLit
|
||||
def fieldIdxKind : SyntaxNodeKind := `fieldIdx
|
||||
def interpolatedStrLitKind : SyntaxNodeKind := `interpolatedStrLitKind
|
||||
|
|
|
|||
|
|
@ -483,8 +483,6 @@ partial def collect : Syntax → M Syntax
|
|||
pure stx
|
||||
else if k == decimalLitKind then
|
||||
pure stx
|
||||
else if k == scientificLitKind then
|
||||
pure stx
|
||||
else if k == charLitKind then
|
||||
pure stx
|
||||
else if k == `Lean.Parser.Term.quotedName then
|
||||
|
|
|
|||
|
|
@ -1277,11 +1277,11 @@ private def mkFreshTypeMVarFor (expectedType? : Option Expr) : TermElabM Expr :=
|
|||
@[builtinTermElab decimalLit] def elabDecimalLit : TermElab := fun stx expectedType? => do
|
||||
match stx.isDecimalLit? with
|
||||
| none => throwIllFormedSyntax
|
||||
| some (m, e) =>
|
||||
| some (m, sign, e) =>
|
||||
let typeMVar ← mkFreshTypeMVarFor expectedType?
|
||||
let u ← getDecLevel typeMVar
|
||||
let mvar ← mkInstMVar (mkApp (Lean.mkConst `OfDecimal [u]) typeMVar)
|
||||
return mkApp4 (Lean.mkConst `OfDecimal.ofDecimal [u]) typeMVar mvar (mkNatLit m) (mkNatLit e)
|
||||
return mkApp5 (Lean.mkConst `OfDecimal.ofDecimal [u]) typeMVar mvar (mkNatLit m) (toExpr sign) (mkNatLit e)
|
||||
|
||||
@[builtinTermElab charLit] def elabCharLit : TermElab := fun stx _ => do
|
||||
match stx.isCharLit? with
|
||||
|
|
|
|||
|
|
@ -68,7 +68,7 @@ namespace Lean
|
|||
namespace Parser
|
||||
|
||||
def isLitKind (k : SyntaxNodeKind) : Bool :=
|
||||
k == strLitKind || k == numLitKind || k == charLitKind || k == nameLitKind || k == decimalLitKind || k == scientificLitKind
|
||||
k == strLitKind || k == numLitKind || k == charLitKind || k == nameLitKind || k == decimalLitKind
|
||||
|
||||
abbrev mkAtom (info : SourceInfo) (val : String) : Syntax :=
|
||||
Syntax.atom info val
|
||||
|
|
@ -807,22 +807,46 @@ partial def strLitFnAux (startPos : Nat) : ParserFn := fun c s =>
|
|||
else if curr == '\\' then andthenFn quotedCharFn (strLitFnAux startPos) c s
|
||||
else strLitFnAux startPos c s
|
||||
|
||||
def decimalNumberFn (startPos : Nat) : ParserFn := fun c s =>
|
||||
def decimalNumberFn (startPos : Nat) (c : ParserContext) : ParserState → ParserState := fun s =>
|
||||
let s := takeWhileFn (fun c => c.isDigit) c s
|
||||
let input := c.input
|
||||
let i := s.pos
|
||||
let curr := input.get i
|
||||
if curr == '.' then
|
||||
let i := input.next i
|
||||
let curr := input.get i
|
||||
let s :=
|
||||
if curr == '.' || curr == 'e' || curr == 'E' then
|
||||
let s := parseOptDot s
|
||||
let s := parseOptExp s
|
||||
mkNodeToken decimalLitKind startPos c s
|
||||
else
|
||||
mkNodeToken numLitKind startPos c s
|
||||
where
|
||||
parseOptDot s :=
|
||||
let input := c.input
|
||||
let i := s.pos
|
||||
let curr := input.get i
|
||||
if curr == '.' then
|
||||
let i := input.next i
|
||||
let curr := input.get i
|
||||
if curr.isDigit then
|
||||
takeWhileFn (fun c => c.isDigit) c (s.setPos i)
|
||||
else
|
||||
s.setPos i
|
||||
mkNodeToken decimalLitKind startPos c s
|
||||
else
|
||||
mkNodeToken numLitKind startPos c s
|
||||
else
|
||||
s
|
||||
|
||||
parseOptExp s :=
|
||||
let input := c.input
|
||||
let i := s.pos
|
||||
let curr := input.get i
|
||||
if curr == 'e' || curr == 'E' then
|
||||
let i := input.next i
|
||||
let i := if input.get i == '-' then input.next i else i
|
||||
let curr := input.get i
|
||||
if curr.isDigit then
|
||||
takeWhileFn (fun c => c.isDigit) c (s.setPos i)
|
||||
else
|
||||
s.setPos i
|
||||
else
|
||||
s
|
||||
|
||||
def binNumberFn (startPos : Nat) : ParserFn := fun c s =>
|
||||
let s := takeWhile1Fn (fun c => c == '0' || c == '1') "binary number" c s
|
||||
|
|
|
|||
|
|
@ -27,7 +27,6 @@ builtin_initialize
|
|||
registerBuiltinNodeKind strLitKind
|
||||
registerBuiltinNodeKind numLitKind
|
||||
registerBuiltinNodeKind decimalLitKind
|
||||
registerBuiltinNodeKind scientificLitKind
|
||||
registerBuiltinNodeKind charLitKind
|
||||
registerBuiltinNodeKind nameLitKind
|
||||
|
||||
|
|
|
|||
|
|
@ -368,14 +368,26 @@ def delabOfNat : Delab := whenPPOption getPPCoercions do
|
|||
let (Expr.app (Expr.app _ (Expr.lit (Literal.natVal n) _) _) _ _) ← getExpr | failure
|
||||
return quote n
|
||||
|
||||
-- `@OfDecimal.ofDecimal _ _ m e` ~> `m*10^(-e)`
|
||||
-- `@OfDecimal.ofDecimal _ _ m s e` ~> `m*10^(sign * e)` where `sign == 1` if `s = false` and `sign = -1` if `s = true`
|
||||
@[builtinDelab app.OfDecimal.ofDecimal]
|
||||
def delabOfDecimal : Delab := whenPPOption getPPCoercions do
|
||||
let (Expr.app (Expr.app _ (Expr.lit (Literal.natVal m) _) _) (Expr.lit (Literal.natVal e) _) _) ← getExpr | failure
|
||||
let expr ← getExpr
|
||||
guard <| expr.getAppNumArgs == 5
|
||||
let Expr.lit (Literal.natVal m) _ ← pure (expr.getArg! 2) | failure
|
||||
let Expr.lit (Literal.natVal e) _ ← pure (expr.getArg! 4) | failure
|
||||
let s ← match expr.getArg! 3 with
|
||||
| Expr.const `Bool.true _ _ => pure true
|
||||
| Expr.const `Bool.false _ _ => pure false
|
||||
| _ => failure
|
||||
let str := toString m
|
||||
let mStr := str.extract 0 (str.length - e)
|
||||
let eStr := str.extract (str.length - e) str.length
|
||||
return Syntax.mkDecimalLit (mStr ++ "." ++ eStr)
|
||||
if s && e == str.length then
|
||||
return Syntax.mkDecimalLit ("0." ++ str)
|
||||
else if s && e < str.length then
|
||||
let mStr := str.extract 0 (str.length - e)
|
||||
let eStr := str.extract (str.length - e) str.length
|
||||
return Syntax.mkDecimalLit (mStr ++ "." ++ eStr)
|
||||
else
|
||||
return Syntax.mkDecimalLit (str ++ "e" ++ (if s then "-" else "") ++ toString e)
|
||||
|
||||
/--
|
||||
Delaborate a projection primitive. These do not usually occur in
|
||||
|
|
|
|||
|
|
@ -1464,17 +1464,25 @@ extern "C" lean_obj_res lean_float_to_string(double a) {
|
|||
return mk_string(std::to_string(a));
|
||||
}
|
||||
|
||||
static double of_decimal(mpz const & m, size_t e) {
|
||||
return (mpq(m)/mpz(10).pow(e)).get_double();
|
||||
static double of_decimal(mpz const & m, bool sign, size_t e) {
|
||||
if (sign)
|
||||
return (mpq(m)/mpz(10).pow(e)).get_double();
|
||||
else
|
||||
return (mpq(m)*mpz(10).pow(e)).get_double();
|
||||
}
|
||||
|
||||
extern "C" double lean_float_of_decimal(b_lean_obj_arg m, b_lean_obj_arg e) {
|
||||
if (!lean_is_scalar(e))
|
||||
return 0.0;
|
||||
extern "C" double lean_float_of_decimal(b_lean_obj_arg m, uint8 esign, b_lean_obj_arg e) {
|
||||
if (!lean_is_scalar(e)) {
|
||||
if (esign) {
|
||||
return 0.0;
|
||||
} else {
|
||||
return std::numeric_limits<double>::infinity();
|
||||
}
|
||||
}
|
||||
if (lean_is_scalar(m)) {
|
||||
return of_decimal(mpz::of_size_t(lean_unbox(m)), lean_unbox(e));
|
||||
return of_decimal(mpz::of_size_t(lean_unbox(m)), esign, lean_unbox(e));
|
||||
} else {
|
||||
return of_decimal(mpz_value(m), lean_unbox(e));
|
||||
return of_decimal(mpz_value(m), esign, lean_unbox(e));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -4,3 +4,13 @@
|
|||
#eval 1.2 + 2.3
|
||||
#check 1.
|
||||
#check 3.1416
|
||||
|
||||
theorem ex : 31416e-4 = 3.1416 :=
|
||||
rfl
|
||||
|
||||
#eval 3.4e-100 * 1e98
|
||||
|
||||
#eval 12.3e90 * 1e-90
|
||||
#eval 3.00e-100 * 1e100
|
||||
#eval 3.00e-100 * 1.e100
|
||||
#eval 3.00e-100 * 1.0e100
|
||||
|
|
|
|||
|
|
@ -4,3 +4,8 @@
|
|||
3.500000
|
||||
1. : Float
|
||||
3.1416 : Float
|
||||
0.034000
|
||||
12.300000
|
||||
3.000000
|
||||
3.000000
|
||||
3.000000
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue