Added rmath dunder, create Series without specifying dtype

Added Math validator to Series, math dunders pending
This commit is contained in:
Gourav Kumar 2022-04-12 11:42:51 +05:30
parent 03a8045400
commit b38a317b82

View File

@ -8,6 +8,8 @@ from dataclasses import dataclass
from numbers import Number from numbers import Number
from typing import Callable, Iterable, List, Literal, Mapping, Sequence, Type from typing import Callable, Iterable, List, Literal, Mapping, Sequence, Type
from dateutil.relativedelta import relativedelta
from .utils import FincalOptions, _parse_date, _preprocess_timeseries from .utils import FincalOptions, _parse_date, _preprocess_timeseries
@ -112,7 +114,7 @@ class Series(UserList):
def __init__( def __init__(
self, self,
data: Sequence, data: Sequence,
data_type: Literal["date", "number", "bool"], dtype: Literal["date", "number", "bool"] = None,
date_format: str = None, date_format: str = None,
): ):
types_dict: dict = { types_dict: dict = {
@ -123,21 +125,26 @@ class Series(UserList):
"int": float, "int": float,
"number": float, "number": float,
"bool": bool, "bool": bool,
"Decimal": bool,
} }
if data_type not in types_dict.keys():
raise ValueError("Unsupported value for data type")
if not isinstance(data, Sequence): if not isinstance(data, Sequence):
raise TypeError("Series object can only be created using Sequence types") raise TypeError("Series object can only be created using Sequence types")
if data_type in ["date", "datetime", "datetime.datetime"]: if dtype is None:
if isinstance(data[0], (Number, datetime.datetime, datetime.date, bool)):
dtype = data[0].__class__.__name__.lower()
if dtype not in types_dict.keys():
raise ValueError("Unsupported value for data type")
if dtype in ["date", "datetime", "datetime.datetime"]:
data = [_parse_date(i, date_format) for i in data] data = [_parse_date(i, date_format) for i in data]
else: else:
func: Callable = types_dict[data_type] func: Callable = types_dict[dtype]
data: list = [func(i) for i in data] data: list = [func(i) for i in data]
self.dtype: Type = types_dict[data_type] self.dtype: Type = types_dict[dtype]
self.data: Sequence = data self.data: Sequence = data
def __repr__(self): def __repr__(self):
@ -216,6 +223,55 @@ class Series(UserList):
return Series([i == other for i in self.data], "bool") return Series([i == other for i in self.data], "bool")
def _math_validator(self, other):
if not isinstance(other, (Series, Number, datetime.timedelta, relativedelta, datetime.datetime, datetime.date)):
return NotImplemented
if isinstance(other, Series):
if len(self) != len(other):
raise ValueError("Arithmatic operations cannot be performed on objects of different lengths.")
if self.dtype == bool or other.dtype == bool:
raise TypeError("Arithmatic operations cannot be performed on boolean series.")
if self.dtype == float and not other.dtype == float:
raise TypeError(
"Arithmatic operation cannot be performed between "
f"'{self.dtype.__name__}' and '{other.dtype.__name__}'"
)
if self.dtype == datetime.datetime:
raise TypeError(
"Arithmatic operation cannot be performed between '"
f"'{self.dtype.__name__}' and '{other.dtype.__name__}'"
)
return
elif self.dtype == float and not isinstance(other, Number):
raise TypeError(
f"Arithmatic operation cannot be performed between '{self.dtype}' and '{other.__class__.__name__}'"
)
elif self.dtype == datetime.datetime and not isinstance(other, (datetime.timedelta, relativedelta)):
raise TypeError(
f"Arithmatic operation cannot be performed between '{self.dtype.__name__}' and "
f"'{other.__class__.__name__}'\nHint: Try using timedelta or relativedelta objects."
)
return other
def __add__(self, other):
if self._math_validator(other) == NotImplemented:
return NotImplemented
if isinstance(other, Series):
return self.__class__([j + other[i] for i, j in enumerate(self)], self.dtype.__name__)
if isinstance(other, (Number, datetime.timedelta, relativedelta)):
return self.__class__([i + other for i in self], self.dtype.__name__)
@Mapping.register @Mapping.register
class TimeSeriesCore: class TimeSeriesCore:
@ -640,6 +696,74 @@ class TimeSeriesCore:
return self.__class__(data, self.frequency.symbol) return self.__class__(data, self.frequency.symbol)
def __radd__(self, other):
self._arithmatic_validator(other)
if isinstance(other, TimeSeriesCore):
other = other.values
if isinstance(other, Series):
data = {dt: val + other[i] for i, (dt, val) in enumerate(self.data.items())}
elif isinstance(other, Number):
data = {dt: val + other for dt, val in self.data.items()}
return self.__class__(data, self.frequency.symbol)
def __rsub__(self, other):
self._arithmatic_validator(other)
if isinstance(other, TimeSeriesCore):
other = other.values
if isinstance(other, Series):
data = {dt: other[i] - val for i, (dt, val) in enumerate(self.data.items())}
elif isinstance(other, Number):
data = {dt: other - val for dt, val in self.data.items()}
return self.__class__(data, self.frequency.symbol)
def __rtruediv__(self, other):
self._arithmatic_validator(other)
if isinstance(other, TimeSeriesCore):
other = other.values
if isinstance(other, Series):
data = {dt: other[i] / val for i, (dt, val) in enumerate(self.data.items())}
elif isinstance(other, Number):
data = {dt: other / val for dt, val in self.data.items()}
return self.__class__(data, self.frequency.symbol)
def __rfloordiv__(self, other):
self._arithmatic_validator(other)
if isinstance(other, TimeSeriesCore):
other = other.values
if isinstance(other, Series):
data = {dt: other[i] // val for i, (dt, val) in enumerate(self.data.items())}
elif isinstance(other, Number):
data = {dt: other // val for dt, val in self.data.items()}
return self.__class__(data, self.frequency.symbol)
def __rmul__(self, other):
self._arithmatic_validator(other)
if isinstance(other, TimeSeriesCore):
other = other.values
if isinstance(other, Series):
data = {dt: val * other[i] for i, (dt, val) in enumerate(self.data.items())}
elif isinstance(other, Number):
data = {dt: val * other for dt, val in self.data.items()}
return self.__class__(data, self.frequency.symbol)
def __rpow__(self, _):
raise NotImplementedError("This operation is not supported.")
@date_parser(1) @date_parser(1)
def get(self, date: str | datetime.datetime, default=None, closest=None): def get(self, date: str | datetime.datetime, default=None, closest=None):
@ -699,3 +823,7 @@ class TimeSeriesCore:
def items(self): def items(self):
return self.data.items() return self.data.items()
def update(self, items: dict):
for k, v in items.items():
self[k] = v