implemented comparison in TSC, improved comparisons in Series
This commit is contained in:
parent
b246709603
commit
e8dbc16157
211
fincal/core.py
211
fincal/core.py
@ -149,99 +149,62 @@ class Series(UserList):
|
||||
else:
|
||||
return self.data[i]
|
||||
|
||||
def __gt__(self, other):
|
||||
if self.dtype == bool:
|
||||
raise TypeError("> not supported for boolean series")
|
||||
def _comparison_validator(self, other):
|
||||
"""Validates other before making comparison"""
|
||||
|
||||
if isinstance(other, (str, datetime.datetime, datetime.date)):
|
||||
other = _parse_date(other)
|
||||
return other
|
||||
|
||||
if isinstance(other, Series):
|
||||
if self.dtype == bool:
|
||||
raise TypeError("Comparison operation not supported for boolean series")
|
||||
|
||||
elif isinstance(other, Series):
|
||||
if len(self) != len(other):
|
||||
raise ValueError("Length of Series must be same for comparison")
|
||||
|
||||
gt = Series([j > other[i] for i, j in enumerate(self)], "bool")
|
||||
|
||||
elif self.dtype == float and isinstance(other, Number) or isinstance(other, self.dtype):
|
||||
gt = Series([i > other for i in self.data], "bool")
|
||||
else:
|
||||
elif (self.dtype != float and isinstance(other, Number)) or not isinstance(other, self.dtype):
|
||||
raise Exception(f"Cannot compare type {self.dtype.__name__} to {type(other).__name__}")
|
||||
|
||||
return gt
|
||||
def __gt__(self, other):
|
||||
other = self._comparison_validator(other)
|
||||
|
||||
if isinstance(other, Series):
|
||||
return Series([j > other[i] for i, j in enumerate(self)], "bool")
|
||||
|
||||
return Series([i > other for i in self.data], "bool")
|
||||
|
||||
def __ge__(self, other):
|
||||
if self.dtype == bool:
|
||||
raise TypeError(">= not supported for boolean series")
|
||||
|
||||
if isinstance(other, (str, datetime.datetime, datetime.date)):
|
||||
other = _parse_date(other)
|
||||
other = self._comparison_validator(other)
|
||||
|
||||
if isinstance(other, Series):
|
||||
if len(self) != len(other):
|
||||
raise ValueError("Length of Series must be same for comparison")
|
||||
return Series([j >= other[i] for i, j in enumerate(self)], "bool")
|
||||
|
||||
ge = Series([j >= other[i] for i, j in enumerate(self)], "bool")
|
||||
|
||||
elif self.dtype == float and isinstance(other, Number) or isinstance(other, self.dtype):
|
||||
ge = Series([i >= other for i in self.data], "bool")
|
||||
else:
|
||||
raise Exception(f"Cannot compare type {self.dtype.__name__} to {type(other).__name__}")
|
||||
|
||||
return ge
|
||||
return Series([i >= other for i in self.data], "bool")
|
||||
|
||||
def __lt__(self, other):
|
||||
if self.dtype == bool:
|
||||
raise TypeError("< not supported for boolean series")
|
||||
|
||||
if isinstance(other, (str, datetime.datetime, datetime.date)):
|
||||
other = _parse_date(other)
|
||||
other = self._comparison_validator(other)
|
||||
|
||||
if isinstance(other, Series):
|
||||
if len(self) != len(other):
|
||||
raise ValueError("Length of Series must be same for comparison")
|
||||
return Series([j < other[i] for i, j in enumerate(self)], "bool")
|
||||
|
||||
lt = Series([j < other[i] for i, j in enumerate(self)], "bool")
|
||||
|
||||
elif self.dtype == float and isinstance(other, Number) or isinstance(other, self.dtype):
|
||||
lt = Series([i < other for i in self.data], "bool")
|
||||
else:
|
||||
raise Exception(f"Cannot compare type {self.dtype.__name__} to {type(other).__name__}")
|
||||
return lt
|
||||
return Series([i < other for i in self.data], "bool")
|
||||
|
||||
def __le__(self, other):
|
||||
if self.dtype == bool:
|
||||
raise TypeError("<= not supported for boolean series")
|
||||
|
||||
if isinstance(other, (str, datetime.datetime, datetime.date)):
|
||||
other = _parse_date(other)
|
||||
other = self._comparison_validator(other)
|
||||
|
||||
if isinstance(other, Series):
|
||||
if len(self) != len(other):
|
||||
raise ValueError("Length of Series must be same for comparison")
|
||||
return Series([j <= other[i] for i, j in enumerate(self)], "bool")
|
||||
|
||||
le = Series([j <= other[i] for i, j in enumerate(self)], "bool")
|
||||
|
||||
elif self.dtype == float and isinstance(other, Number) or isinstance(other, self.dtype):
|
||||
le = Series([i <= other for i in self.data], "bool")
|
||||
else:
|
||||
raise Exception(f"Cannot compare type {self.dtype.__name__} to {type(other).__name__}")
|
||||
return le
|
||||
return Series([i <= other for i in self.data], "bool")
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, (str, datetime.datetime, datetime.date)):
|
||||
other = _parse_date(other)
|
||||
other = self._comparison_validator(other)
|
||||
|
||||
if isinstance(other, Series):
|
||||
if len(self) != len(other):
|
||||
raise ValueError("Length of Series must be same for comparison")
|
||||
return Series([j == other[i] for i, j in enumerate(self)], "bool")
|
||||
|
||||
eq = Series([j == other[i] for i, j in enumerate(self)], "bool")
|
||||
|
||||
elif self.dtype == float and isinstance(other, Number) or isinstance(other, self.dtype):
|
||||
eq = Series([i == other for i in self.data], "bool")
|
||||
else:
|
||||
raise Exception(f"Cannot compare type {self.dtype.__name__} to {type(other).__name__}")
|
||||
return eq
|
||||
return Series([i == other for i in self.data], "bool")
|
||||
|
||||
|
||||
@Mapping.register
|
||||
@ -401,30 +364,6 @@ class TimeSeriesCore:
|
||||
|
||||
raise TypeError(f"Invalid type {repr(type(key).__name__)} for slicing.")
|
||||
|
||||
def __gt__(self, other):
|
||||
if isinstance(other, Number):
|
||||
data = {k: v > other for k, v in self.data.items()}
|
||||
|
||||
if isinstance(other, TimeSeriesCore):
|
||||
if self.dates != other.dates:
|
||||
raise ValueError(
|
||||
"Only objects with same set of dates can be compared.\n"
|
||||
"Hint: use TimeSeries.sync() method to sync dates of two TimeSeries objects."
|
||||
)
|
||||
|
||||
data = {dt: val > other[dt][1] for dt, val in self.data.items()}
|
||||
|
||||
if isinstance(other, Series):
|
||||
if Series.dtype != float:
|
||||
raise TypeError("Only Series of type float can be used for comparison")
|
||||
|
||||
if len(self) != len(other):
|
||||
raise ValueError("Length of series does not match length of object")
|
||||
|
||||
data = {dt: val > other[i] for i, (dt, val) in enumerate(self.data.items())}
|
||||
|
||||
return self.__class__(data, frequency=self.frequency.symbol)
|
||||
|
||||
@date_parser(1)
|
||||
def __setitem__(self, key: str | datetime.datetime, value: Number) -> None:
|
||||
if not isinstance(value, Number):
|
||||
@ -436,6 +375,102 @@ class TimeSeriesCore:
|
||||
self.data.update({key: value})
|
||||
self.data = dict(sorted(self.data.items()))
|
||||
|
||||
@date_parser(1)
|
||||
def __delitem__(self, key):
|
||||
del self.data[key]
|
||||
|
||||
def _comparison_validator(self, other):
|
||||
"""Validates the data before comparison is performed"""
|
||||
|
||||
if not isinstance(other, (Number, Series, TimeSeriesCore)):
|
||||
raise TypeError(
|
||||
f"Comparison cannot be performed between '{self.__class__.__name__}' and '{other.__class__.__name__}'"
|
||||
)
|
||||
|
||||
if isinstance(other, TimeSeriesCore):
|
||||
if self.dates != other.dates:
|
||||
raise ValueError(
|
||||
"Only objects with same set of dates can be compared.\n"
|
||||
"Hint: use TimeSeries.sync() method to sync dates of two TimeSeries objects."
|
||||
)
|
||||
|
||||
if isinstance(other, Series):
|
||||
if other.dtype != float:
|
||||
raise TypeError("Only Series of type float can be used for comparison")
|
||||
|
||||
if len(self) != len(other):
|
||||
raise ValueError("Length of series does not match length of object")
|
||||
|
||||
def __gt__(self, other):
|
||||
self._comparison_validator(other)
|
||||
|
||||
if isinstance(other, Number):
|
||||
data = {k: v > other for k, v in self.data.items()}
|
||||
|
||||
if isinstance(other, TimeSeriesCore):
|
||||
data = {dt: val > other[dt][1] for dt, val in self.data.items()}
|
||||
|
||||
if isinstance(other, Series):
|
||||
data = {dt: val > other[i] for i, (dt, val) in enumerate(self.data.items())}
|
||||
|
||||
return self.__class__(data, frequency=self.frequency.symbol)
|
||||
|
||||
def __ge__(self, other):
|
||||
self._comparison_validator(other)
|
||||
|
||||
if isinstance(other, Number):
|
||||
data = {k: v >= other for k, v in self.data.items()}
|
||||
|
||||
if isinstance(other, TimeSeriesCore):
|
||||
data = {dt: val >= other[dt][1] for dt, val in self.data.items()}
|
||||
|
||||
if isinstance(other, Series):
|
||||
data = {dt: val >= other[i] for i, (dt, val) in enumerate(self.data.items())}
|
||||
|
||||
return self.__class__(data, frequency=self.frequency.symbol)
|
||||
|
||||
def __lt__(self, other):
|
||||
self._comparison_validator(other)
|
||||
|
||||
if isinstance(other, Number):
|
||||
data = {k: v < other for k, v in self.data.items()}
|
||||
|
||||
if isinstance(other, TimeSeriesCore):
|
||||
data = {dt: val < other[dt][1] for dt, val in self.data.items()}
|
||||
|
||||
if isinstance(other, Series):
|
||||
data = {dt: val < other[i] for i, (dt, val) in enumerate(self.data.items())}
|
||||
|
||||
return self.__class__(data, frequency=self.frequency.symbol)
|
||||
|
||||
def __le__(self, other):
|
||||
self._comparison_validator(other)
|
||||
|
||||
if isinstance(other, Number):
|
||||
data = {k: v <= other for k, v in self.data.items()}
|
||||
|
||||
if isinstance(other, TimeSeriesCore):
|
||||
data = {dt: val <= other[dt][1] for dt, val in self.data.items()}
|
||||
|
||||
if isinstance(other, Series):
|
||||
data = {dt: val <= other[i] for i, (dt, val) in enumerate(self.data.items())}
|
||||
|
||||
return self.__class__(data, frequency=self.frequency.symbol)
|
||||
|
||||
def __eq__(self, other):
|
||||
self._comparison_validator(other)
|
||||
|
||||
if isinstance(other, Number):
|
||||
data = {k: v == other for k, v in self.data.items()}
|
||||
|
||||
if isinstance(other, TimeSeriesCore):
|
||||
data = {dt: val == other[dt][1] for dt, val in self.data.items()}
|
||||
|
||||
if isinstance(other, Series):
|
||||
data = {dt: val == other[i] for i, (dt, val) in enumerate(self.data.items())}
|
||||
|
||||
return self.__class__(data, frequency=self.frequency.symbol)
|
||||
|
||||
def __iter__(self):
|
||||
self.n = 0
|
||||
return self
|
||||
|
Loading…
Reference in New Issue
Block a user