Compare commits

...

4 Commits

Author SHA1 Message Date
65f2e8434c More arithmatic tests 2022-04-12 22:40:06 +05:30
e8be7e9efa Math tests and series dtype param name change 2022-04-12 11:43:52 +05:30
49604a5ae9 Series parameter name change 2022-04-12 11:43:11 +05:30
b38a317b82 Added rmath dunder, create Series without specifying dtype
Added Math validator to Series, math dunders pending
2022-04-12 11:42:51 +05:30
3 changed files with 224 additions and 13 deletions

View File

@ -8,6 +8,8 @@ from dataclasses import dataclass
from numbers import Number from numbers import Number
from typing import Callable, Iterable, List, Literal, Mapping, Sequence, Type from typing import Callable, Iterable, List, Literal, Mapping, Sequence, Type
from dateutil.relativedelta import relativedelta
from .utils import FincalOptions, _parse_date, _preprocess_timeseries from .utils import FincalOptions, _parse_date, _preprocess_timeseries
@ -112,7 +114,7 @@ class Series(UserList):
def __init__( def __init__(
self, self,
data: Sequence, data: Sequence,
data_type: Literal["date", "number", "bool"], dtype: Literal["date", "number", "bool"] = None,
date_format: str = None, date_format: str = None,
): ):
types_dict: dict = { types_dict: dict = {
@ -123,21 +125,26 @@ class Series(UserList):
"int": float, "int": float,
"number": float, "number": float,
"bool": bool, "bool": bool,
"Decimal": bool,
} }
if data_type not in types_dict.keys():
raise ValueError("Unsupported value for data type")
if not isinstance(data, Sequence): if not isinstance(data, Sequence):
raise TypeError("Series object can only be created using Sequence types") raise TypeError("Series object can only be created using Sequence types")
if data_type in ["date", "datetime", "datetime.datetime"]: if dtype is None:
if isinstance(data[0], (Number, datetime.datetime, datetime.date, bool)):
dtype = data[0].__class__.__name__.lower()
if dtype not in types_dict.keys():
raise ValueError("Unsupported value for data type")
if dtype in ["date", "datetime", "datetime.datetime"]:
data = [_parse_date(i, date_format) for i in data] data = [_parse_date(i, date_format) for i in data]
else: else:
func: Callable = types_dict[data_type] func: Callable = types_dict[dtype]
data: list = [func(i) for i in data] data: list = [func(i) for i in data]
self.dtype: Type = types_dict[data_type] self.dtype: Type = types_dict[dtype]
self.data: Sequence = data self.data: Sequence = data
def __repr__(self): def __repr__(self):
@ -216,6 +223,55 @@ class Series(UserList):
return Series([i == other for i in self.data], "bool") return Series([i == other for i in self.data], "bool")
def _math_validator(self, other):
    """Check whether ``other`` may take part in arithmetic with this Series.

    Args:
        other: The second operand of the arithmetic operation.

    Returns:
        ``NotImplemented`` if ``other`` is of a type this class does not
        handle; ``None`` if ``other`` is a compatible Series; ``other``
        itself if it is a compatible scalar.

    Raises:
        ValueError: ``other`` is a Series of a different length.
        TypeError: the dtypes (or the dtype and the scalar type) cannot be
            combined.
    """
    supported_types = (
        Series,
        Number,
        datetime.timedelta,
        relativedelta,
        datetime.datetime,
        datetime.date,
    )
    if not isinstance(other, supported_types):
        return NotImplemented

    if isinstance(other, Series):
        if len(self) != len(other):
            raise ValueError("Arithmatic operations cannot be performed on objects of different lengths.")
        if self.dtype == bool or other.dtype == bool:
            raise TypeError("Arithmatic operations cannot be performed on boolean series.")
        if self.dtype == float and not other.dtype == float:
            raise TypeError(
                "Arithmatic operation cannot be performed between "
                f"'{self.dtype.__name__}' and '{other.dtype.__name__}'"
            )
        if self.dtype == datetime.datetime:
            # A date series cannot be combined with any other Series.
            raise TypeError(
                "Arithmatic operation cannot be performed between "
                f"'{self.dtype.__name__}' and '{other.dtype.__name__}'"
            )
        # Deliberately None rather than ``other``: callers only test the
        # result against NotImplemented.
        return

    if self.dtype == float and not isinstance(other, Number):
        raise TypeError(
            f"Arithmatic operation cannot be performed between '{self.dtype.__name__}' and "
            f"'{other.__class__.__name__}'"
        )
    if self.dtype == datetime.datetime and not isinstance(other, (datetime.timedelta, relativedelta)):
        raise TypeError(
            f"Arithmatic operation cannot be performed between '{self.dtype.__name__}' and "
            f"'{other.__class__.__name__}'\nHint: Try using timedelta or relativedelta objects."
        )

    return other
def __add__(self, other):
    """Return a new Series holding the element-wise sum ``self + other``.

    Args:
        other: A Series (added position by position) or a scalar
            (a number, ``timedelta`` or ``relativedelta``) added to
            every element.

    Returns:
        A new object of the same class and dtype as ``self``, or
        ``NotImplemented`` when ``other`` cannot be added.

    Raises:
        ValueError, TypeError: propagated from ``_math_validator``.
    """
    # Identity test: ``==`` would dispatch to the operand's own ``__eq__``,
    # which need not return a plain boolean.
    if self._math_validator(other) is NotImplemented:
        return NotImplemented

    dtype_name = self.dtype.__name__
    if isinstance(other, Series):
        return self.__class__([j + other[i] for i, j in enumerate(self)], dtype_name)
    if isinstance(other, (Number, datetime.timedelta, relativedelta)):
        return self.__class__([i + other for i in self], dtype_name)

    # The operand passed validation but is not addable here (for example a
    # date added to a boolean series); let Python try the reflected
    # operation instead of silently returning None.
    return NotImplemented
@Mapping.register @Mapping.register
class TimeSeriesCore: class TimeSeriesCore:
@ -640,6 +696,74 @@ class TimeSeriesCore:
return self.__class__(data, self.frequency.symbol) return self.__class__(data, self.frequency.symbol)
def __radd__(self, other):
    """Handle ``other + self`` by adding ``other`` to every value.

    Args:
        other: A number (added to each value), or a Series /
            TimeSeriesCore (added position by position; only the
            ``values`` of a TimeSeriesCore are used).

    Returns:
        A new object of the same class built from the resulting
        ``{date: value}`` mapping and this object's frequency symbol, or
        ``NotImplemented`` when ``other`` is of a type that cannot be
        combined here.
    """
    # NOTE(review): presumably raises for incompatible operands -- confirm
    # against _arithmatic_validator; its return value is not inspected.
    self._arithmatic_validator(other)

    if isinstance(other, TimeSeriesCore):
        other = other.values

    if isinstance(other, Series):
        data = {dt: val + other[i] for i, (dt, val) in enumerate(self.data.items())}
    elif isinstance(other, Number):
        data = {dt: val + other for dt, val in self.data.items()}
    else:
        # Without this branch ``data`` would be unbound below.
        return NotImplemented

    return self.__class__(data, self.frequency.symbol)
def __rsub__(self, other):
    """Handle ``other - self`` by subtracting every value from ``other``.

    Args:
        other: A number (each value is subtracted from it), or a Series /
            TimeSeriesCore (subtracted position by position; only the
            ``values`` of a TimeSeriesCore are used).

    Returns:
        A new object of the same class built from the resulting
        ``{date: value}`` mapping and this object's frequency symbol, or
        ``NotImplemented`` when ``other`` is of a type that cannot be
        combined here.
    """
    # NOTE(review): presumably raises for incompatible operands -- confirm
    # against _arithmatic_validator; its return value is not inspected.
    self._arithmatic_validator(other)

    if isinstance(other, TimeSeriesCore):
        other = other.values

    if isinstance(other, Series):
        data = {dt: other[i] - val for i, (dt, val) in enumerate(self.data.items())}
    elif isinstance(other, Number):
        data = {dt: other - val for dt, val in self.data.items()}
    else:
        # Without this branch ``data`` would be unbound below.
        return NotImplemented

    return self.__class__(data, self.frequency.symbol)
def __rtruediv__(self, other):
    """Handle ``other / self`` by dividing ``other`` by every value.

    Args:
        other: A number (divided by each value), or a Series /
            TimeSeriesCore (divided position by position; only the
            ``values`` of a TimeSeriesCore are used).

    Returns:
        A new object of the same class built from the resulting
        ``{date: value}`` mapping and this object's frequency symbol, or
        ``NotImplemented`` when ``other`` is of a type that cannot be
        combined here.
    """
    # NOTE(review): presumably raises for incompatible operands -- confirm
    # against _arithmatic_validator; its return value is not inspected.
    self._arithmatic_validator(other)

    if isinstance(other, TimeSeriesCore):
        other = other.values

    if isinstance(other, Series):
        data = {dt: other[i] / val for i, (dt, val) in enumerate(self.data.items())}
    elif isinstance(other, Number):
        data = {dt: other / val for dt, val in self.data.items()}
    else:
        # Without this branch ``data`` would be unbound below.
        return NotImplemented

    return self.__class__(data, self.frequency.symbol)
def __rfloordiv__(self, other):
    """Handle ``other // self`` by floor-dividing ``other`` by every value.

    Args:
        other: A number (floor-divided by each value), or a Series /
            TimeSeriesCore (floor-divided position by position; only the
            ``values`` of a TimeSeriesCore are used).

    Returns:
        A new object of the same class built from the resulting
        ``{date: value}`` mapping and this object's frequency symbol, or
        ``NotImplemented`` when ``other`` is of a type that cannot be
        combined here.
    """
    # NOTE(review): presumably raises for incompatible operands -- confirm
    # against _arithmatic_validator; its return value is not inspected.
    self._arithmatic_validator(other)

    if isinstance(other, TimeSeriesCore):
        other = other.values

    if isinstance(other, Series):
        data = {dt: other[i] // val for i, (dt, val) in enumerate(self.data.items())}
    elif isinstance(other, Number):
        data = {dt: other // val for dt, val in self.data.items()}
    else:
        # Without this branch ``data`` would be unbound below.
        return NotImplemented

    return self.__class__(data, self.frequency.symbol)
def __rmul__(self, other):
    """Handle ``other * self`` by multiplying every value by ``other``.

    Args:
        other: A number (each value is multiplied by it), or a Series /
            TimeSeriesCore (multiplied position by position; only the
            ``values`` of a TimeSeriesCore are used).

    Returns:
        A new object of the same class built from the resulting
        ``{date: value}`` mapping and this object's frequency symbol, or
        ``NotImplemented`` when ``other`` is of a type that cannot be
        combined here.
    """
    # NOTE(review): presumably raises for incompatible operands -- confirm
    # against _arithmatic_validator; its return value is not inspected.
    self._arithmatic_validator(other)

    if isinstance(other, TimeSeriesCore):
        other = other.values

    if isinstance(other, Series):
        data = {dt: val * other[i] for i, (dt, val) in enumerate(self.data.items())}
    elif isinstance(other, Number):
        data = {dt: val * other for dt, val in self.data.items()}
    else:
        # Without this branch ``data`` would be unbound below.
        return NotImplemented

    return self.__class__(data, self.frequency.symbol)
def __rpow__(self, _):
raise NotImplementedError("This operation is not supported.")
@date_parser(1) @date_parser(1)
def get(self, date: str | datetime.datetime, default=None, closest=None): def get(self, date: str | datetime.datetime, default=None, closest=None):
@ -699,3 +823,7 @@ class TimeSeriesCore:
def items(self): def items(self):
return self.data.items() return self.data.items()
def update(self, items: dict):
    """Assign every ``key: value`` pair of ``items`` onto this object.

    Each pair goes through item assignment (``self[key] = value``), in
    the iteration order of ``items``. Nothing is returned.
    """
    for key, value in items.items():
        self[key] = value

View File

@ -99,7 +99,7 @@ def create_date_series(
elif date.weekday() < 5: elif date.weekday() < 5:
dates.append(date) dates.append(date)
return Series(dates, data_type="date") return Series(dates, dtype="date")
class TimeSeries(TimeSeriesCore): class TimeSeries(TimeSeriesCore):

View File

@ -86,12 +86,12 @@ class TestAllFrequencies:
class TestSeries: class TestSeries:
def test_creation(self): def test_creation(self):
series = Series([1, 2, 3, 4, 5, 6, 7], data_type="number") series = Series([1, 2, 3, 4, 5, 6, 7], dtype="number")
assert series.dtype == float assert series.dtype == float
assert series[2] == 3 assert series[2] == 3
dates = create_date_series("2021-01-01", "2021-01-31", frequency="D") dates = create_date_series("2021-01-01", "2021-01-31", frequency="D")
series = Series(dates, data_type="date") series = Series(dates, dtype="date")
assert series.dtype == datetime.datetime assert series.dtype == datetime.datetime
@ -292,7 +292,7 @@ class TestTimeSeriesComparisons:
def test_series_comparison(self): def test_series_comparison(self):
ts1 = TimeSeriesCore(self.data1, "M") ts1 = TimeSeriesCore(self.data1, "M")
ser = Series([240, 210, 240, 270], data_type="int") ser = Series([240, 210, 240, 270], dtype="int")
assert (ts1 > ser).values == Series([0.0, 1.0, 0.0, 0.0], "float") assert (ts1 > ser).values == Series([0.0, 1.0, 0.0, 0.0], "float")
assert (ts1 >= ser).values == Series([0.0, 1.0, 1.0, 0.0], "float") assert (ts1 >= ser).values == Series([0.0, 1.0, 1.0, 0.0], "float")
@ -315,8 +315,8 @@ class TestTimeSeriesComparisons:
def test_errors(self): def test_errors(self):
ts1 = TimeSeriesCore(self.data1, "M") ts1 = TimeSeriesCore(self.data1, "M")
ts2 = TimeSeriesCore(self.data2, "M") ts2 = TimeSeriesCore(self.data2, "M")
ser = Series([240, 210, 240], data_type="int") ser = Series([240, 210, 240], dtype="int")
ser2 = Series(["2021-01-01", "2021-02-01", "2021-03-01", "2021-04-01"], data_type="date") ser2 = Series(["2021-01-01", "2021-02-01", "2021-03-01", "2021-04-01"], dtype="date")
del ts2["2021-04-01"] del ts2["2021-04-01"]
@ -334,3 +334,86 @@ class TestTimeSeriesComparisons:
with pytest.raises(TypeError): with pytest.raises(TypeError):
ts2 < [23, 24, 25, 26] ts2 < [23, 24, 25, 26]
class TestTimeSeriesArithmatic:
    """Arithmetic operators of TimeSeriesCore.

    Each test combines a monthly series with a number (in both operand
    orders), a Series and another TimeSeriesCore, then checks the value at
    the first and/or last date.
    """

    # Four monthly observations shared by every test.
    data = [
        ("2021-01-01", 220),
        ("2021-02-01", 230),
        ("2021-03-01", 240),
        ("2021-04-01", 250),
    ]

    def test_add(self):
        """Addition with a number, a Series and a TimeSeriesCore."""
        ts = TimeSeriesCore(self.data, "M")
        ser = ts.values

        # Index [1] of a looked-up item is read as the value throughout.
        num_add_ts = ts + 40
        assert num_add_ts["2021-01-01"][1] == 260
        assert num_add_ts["2021-04-01"][1] == 290

        num_radd_ts = 40 + ts
        assert num_radd_ts["2021-01-01"][1] == 260
        assert num_radd_ts["2021-04-01"][1] == 290

        ser_add_ts = ts + ser
        assert ser_add_ts["2021-01-01"][1] == 440
        assert ser_add_ts["2021-04-01"][1] == 500

        ts_add_ts = ts + num_add_ts
        assert ts_add_ts["2021-01-01"][1] == 480
        assert ts_add_ts["2021-04-01"][1] == 540

    def test_sub(self):
        """Subtraction with a number, a Series and a TimeSeriesCore."""
        ts = TimeSeriesCore(self.data, "M")
        ser = Series([20, 30, 40, 50], "number")

        num_sub_ts = ts - 40
        assert num_sub_ts["2021-01-01"][1] == 180
        assert num_sub_ts["2021-04-01"][1] == 210

        num_rsub_ts = 240 - ts
        assert num_rsub_ts["2021-01-01"][1] == 20
        assert num_rsub_ts["2021-04-01"][1] == -10

        ser_sub_ts = ts - ser
        assert ser_sub_ts["2021-01-01"][1] == 200
        assert ser_sub_ts["2021-04-01"][1] == 200

        ts_sub_ts = ts - num_sub_ts
        assert ts_sub_ts["2021-01-01"][1] == 40
        assert ts_sub_ts["2021-04-01"][1] == 40

    def test_truediv(self):
        """True division with a number, a Series and a TimeSeriesCore."""
        ts = TimeSeriesCore(self.data, "M")
        # Divisors are one tenth of the data so every quotient is exactly 10.
        ser = Series([22, 23, 24, 25], "number")

        num_div_ts = ts / 10
        assert num_div_ts["2021-01-01"][1] == 22
        assert num_div_ts["2021-04-01"][1] == 25

        num_rdiv_ts = 1000 / ts
        assert num_rdiv_ts["2021-04-01"][1] == 4

        ser_div_ts = ts / ser
        assert ser_div_ts["2021-01-01"][1] == 10
        assert ser_div_ts["2021-04-01"][1] == 10

        ts_div_ts = ts / num_div_ts
        assert ts_div_ts["2021-01-01"][1] == 10
        assert ts_div_ts["2021-04-01"][1] == 10

    def test_floordiv(self):
        """Floor division with a number (both operand orders) and a Series."""
        ts = TimeSeriesCore(self.data, "M")
        ser = Series([22, 23, 24, 25], "number")

        num_div_ts = ts // 11
        assert num_div_ts["2021-02-01"][1] == 20
        assert num_div_ts["2021-04-01"][1] == 22

        num_rdiv_ts = 1000 // ts
        assert num_rdiv_ts["2021-01-01"][1] == 4

        ser_div_ts = ts // ser
        assert ser_div_ts["2021-01-01"][1] == 10
        assert ser_div_ts["2021-04-01"][1] == 10