diff --git a/fincal/core.py b/fincal/core.py index bdbd5dd..8c187fc 100644 --- a/fincal/core.py +++ b/fincal/core.py @@ -149,99 +149,62 @@ class Series(UserList): else: return self.data[i] - def __gt__(self, other): - if self.dtype == bool: - raise TypeError("> not supported for boolean series") + def _comparison_validator(self, other): + """Validates other before making comparison""" if isinstance(other, (str, datetime.datetime, datetime.date)): other = _parse_date(other) + return other - if isinstance(other, Series): + if self.dtype == bool: + raise TypeError("Comparison operation not supported for boolean series") + + elif isinstance(other, Series): if len(self) != len(other): raise ValueError("Length of Series must be same for comparison") - gt = Series([j > other[i] for i, j in enumerate(self)], "bool") - - elif self.dtype == float and isinstance(other, Number) or isinstance(other, self.dtype): - gt = Series([i > other for i in self.data], "bool") - else: + elif (self.dtype != float and isinstance(other, Number)) or not isinstance(other, self.dtype): raise Exception(f"Cannot compare type {self.dtype.__name__} to {type(other).__name__}") - return gt + def __gt__(self, other): + other = self._comparison_validator(other) + + if isinstance(other, Series): + return Series([j > other[i] for i, j in enumerate(self)], "bool") + + return Series([i > other for i in self.data], "bool") def __ge__(self, other): - if self.dtype == bool: - raise TypeError(">= not supported for boolean series") - - if isinstance(other, (str, datetime.datetime, datetime.date)): - other = _parse_date(other) + other = self._comparison_validator(other) if isinstance(other, Series): - if len(self) != len(other): - raise ValueError("Length of Series must be same for comparison") + return Series([j >= other[i] for i, j in enumerate(self)], "bool") - ge = Series([j >= other[i] for i, j in enumerate(self)], "bool") - - elif self.dtype == float and isinstance(other, Number) or isinstance(other, self.dtype): - ge = Series([i >= other for i in self.data], "bool") - else: - raise Exception(f"Cannot compare type {self.dtype.__name__} to {type(other).__name__}") - - return ge + return Series([i >= other for i in self.data], "bool") def __lt__(self, other): - if self.dtype == bool: - raise TypeError("< not supported for boolean series") - - if isinstance(other, (str, datetime.datetime, datetime.date)): - other = _parse_date(other) + other = self._comparison_validator(other) if isinstance(other, Series): - if len(self) != len(other): - raise ValueError("Length of Series must be same for comparison") + return Series([j < other[i] for i, j in enumerate(self)], "bool") - lt = Series([j < other[i] for i, j in enumerate(self)], "bool") - - elif self.dtype == float and isinstance(other, Number) or isinstance(other, self.dtype): - lt = Series([i < other for i in self.data], "bool") - else: - raise Exception(f"Cannot compare type {self.dtype.__name__} to {type(other).__name__}") - return lt + return Series([i < other for i in self.data], "bool") def __le__(self, other): - if self.dtype == bool: - raise TypeError("<= not supported for boolean series") - - if isinstance(other, (str, datetime.datetime, datetime.date)): - other = _parse_date(other) + other = self._comparison_validator(other) if isinstance(other, Series): - if len(self) != len(other): - raise ValueError("Length of Series must be same for comparison") + return Series([j <= other[i] for i, j in enumerate(self)], "bool") - le = Series([j <= other[i] for i, j in enumerate(self)], "bool") - - elif self.dtype == float and isinstance(other, Number) or isinstance(other, self.dtype): - le = Series([i <= other for i in self.data], "bool") - else: - raise Exception(f"Cannot compare type {self.dtype.__name__} to {type(other).__name__}") - return le + return Series([i <= other for i in self.data], "bool") def __eq__(self, other): - if isinstance(other, (str, datetime.datetime, datetime.date)): - other = _parse_date(other) + other = self._comparison_validator(other) if isinstance(other, Series): - if len(self) != len(other): - raise ValueError("Length of Series must be same for comparison") + return Series([j == other[i] for i, j in enumerate(self)], "bool") - eq = Series([j == other[i] for i, j in enumerate(self)], "bool") - - elif self.dtype == float and isinstance(other, Number) or isinstance(other, self.dtype): - eq = Series([i == other for i in self.data], "bool") - else: - raise Exception(f"Cannot compare type {self.dtype.__name__} to {type(other).__name__}") - return eq + return Series([i == other for i in self.data], "bool") @Mapping.register @@ -401,30 +364,6 @@ class TimeSeriesCore: raise TypeError(f"Invalid type {repr(type(key).__name__)} for slicing.") - def __gt__(self, other): - if isinstance(other, Number): - data = {k: v > other for k, v in self.data.items()} - - if isinstance(other, TimeSeriesCore): - if self.dates != other.dates: - raise ValueError( - "Only objects with same set of dates can be compared.\n" - "Hint: use TimeSeries.sync() method to sync dates of two TimeSeries objects." - ) - - data = {dt: val > other[dt][1] for dt, val in self.data.items()} - - if isinstance(other, Series): - if Series.dtype != float: - raise TypeError("Only Series of type float can be used for comparison") - - if len(self) != len(other): - raise ValueError("Length of series does not match length of object") - - data = {dt: val > other[i] for i, (dt, val) in enumerate(self.data.items())} - - return self.__class__(data, frequency=self.frequency.symbol) - @date_parser(1) def __setitem__(self, key: str | datetime.datetime, value: Number) -> None: if not isinstance(value, Number): @@ -436,6 +375,102 @@ class TimeSeriesCore: self.data.update({key: value}) self.data = dict(sorted(self.data.items())) + @date_parser(1) + def __delitem__(self, key): + del self.data[key] + + def _comparison_validator(self, other): + """Validates the data before comparison is performed""" + + if not isinstance(other, (Number, Series, TimeSeriesCore)): + raise TypeError( + f"Comparison cannot be performed between '{self.__class__.__name__}' and '{other.__class__.__name__}'" + ) + + if isinstance(other, TimeSeriesCore): + if self.dates != other.dates: + raise ValueError( + "Only objects with same set of dates can be compared.\n" + "Hint: use TimeSeries.sync() method to sync dates of two TimeSeries objects." + ) + + if isinstance(other, Series): + if other.dtype != float: + raise TypeError("Only Series of type float can be used for comparison") + + if len(self) != len(other): + raise ValueError("Length of series does not match length of object") + + def __gt__(self, other): + self._comparison_validator(other) + + if isinstance(other, Number): + data = {k: v > other for k, v in self.data.items()} + + if isinstance(other, TimeSeriesCore): + data = {dt: val > other[dt][1] for dt, val in self.data.items()} + + if isinstance(other, Series): + data = {dt: val > other[i] for i, (dt, val) in enumerate(self.data.items())} + + return self.__class__(data, frequency=self.frequency.symbol) + + def __ge__(self, other): + self._comparison_validator(other) + + if isinstance(other, Number): + data = {k: v >= other for k, v in self.data.items()} + + if isinstance(other, TimeSeriesCore): + data = {dt: val >= other[dt][1] for dt, val in self.data.items()} + + if isinstance(other, Series): + data = {dt: val >= other[i] for i, (dt, val) in enumerate(self.data.items())} + + return self.__class__(data, frequency=self.frequency.symbol) + + def __lt__(self, other): + self._comparison_validator(other) + + if isinstance(other, Number): + data = {k: v < other for k, v in self.data.items()} + + if isinstance(other, TimeSeriesCore): + data = {dt: val < other[dt][1] for dt, val in self.data.items()} + + if isinstance(other, Series): + data = {dt: val < other[i] for i, (dt, val) in enumerate(self.data.items())} + + return self.__class__(data, frequency=self.frequency.symbol) + + def __le__(self, other): + self._comparison_validator(other) + + if isinstance(other, Number): + data = {k: v <= other for k, v in self.data.items()} + + if isinstance(other, TimeSeriesCore): + data = {dt: val <= other[dt][1] for dt, val in self.data.items()} + + if isinstance(other, Series): + data = {dt: val <= other[i] for i, (dt, val) in enumerate(self.data.items())} + + return self.__class__(data, frequency=self.frequency.symbol) + + def __eq__(self, other): + self._comparison_validator(other) + + if isinstance(other, Number): + data = {k: v == other for k, v in self.data.items()} + + if isinstance(other, TimeSeriesCore): + data = {dt: val == other[dt][1] for dt, val in self.data.items()} + + if isinstance(other, Series): + data = {dt: val == other[i] for i, (dt, val) in enumerate(self.data.items())} + + return self.__class__(data, frequency=self.frequency.symbol) + def __iter__(self): self.n = 0 return self