Extended Series comparison operators to accept another Series; implemented __gt__ in TimeSeriesCore

This commit is contained in:
Gourav Kumar 2022-04-10 23:52:53 +05:30
parent 09365c7957
commit b246709603

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import datetime
import inspect
import warnings
from collections import UserList
from dataclasses import dataclass
from numbers import Number
@ -99,6 +100,11 @@ class _IndexSlicer:
return self.parent.__class__(item, self.parent.frequency.symbol)
def __setitem__(self, key, value):
    """Refuse item assignment through this slicer.

    Raises:
        NotImplementedError: always, whatever ``key`` and ``value`` are.
    """
    # The message below states the rationale: values are always inserted in
    # order of date, so a caller-chosen position cannot be honoured.
    reason = "iloc cannot be used for setting a value as value will always be inserted in order of date"
    raise NotImplementedError(reason)
class Series(UserList):
"""Container for a series of objects, all objects must be of the same type"""
@ -150,7 +156,13 @@ class Series(UserList):
if isinstance(other, (str, datetime.datetime, datetime.date)):
other = _parse_date(other)
if self.dtype == float and isinstance(other, Number) or isinstance(other, self.dtype):
if isinstance(other, Series):
if len(self) != len(other):
raise ValueError("Length of Series must be same for comparison")
gt = Series([j > other[i] for i, j in enumerate(self)], "bool")
elif self.dtype == float and isinstance(other, Number) or isinstance(other, self.dtype):
gt = Series([i > other for i in self.data], "bool")
else:
raise Exception(f"Cannot compare type {self.dtype.__name__} to {type(other).__name__}")
@ -164,7 +176,13 @@ class Series(UserList):
if isinstance(other, (str, datetime.datetime, datetime.date)):
other = _parse_date(other)
if self.dtype == float and isinstance(other, Number) or isinstance(other, self.dtype):
if isinstance(other, Series):
if len(self) != len(other):
raise ValueError("Length of Series must be same for comparison")
ge = Series([j >= other[i] for i, j in enumerate(self)], "bool")
elif self.dtype == float and isinstance(other, Number) or isinstance(other, self.dtype):
ge = Series([i >= other for i in self.data], "bool")
else:
raise Exception(f"Cannot compare type {self.dtype.__name__} to {type(other).__name__}")
@ -178,7 +196,13 @@ class Series(UserList):
if isinstance(other, (str, datetime.datetime, datetime.date)):
other = _parse_date(other)
if self.dtype == float and isinstance(other, Number) or isinstance(other, self.dtype):
if isinstance(other, Series):
if len(self) != len(other):
raise ValueError("Length of Series must be same for comparison")
lt = Series([j < other[i] for i, j in enumerate(self)], "bool")
elif self.dtype == float and isinstance(other, Number) or isinstance(other, self.dtype):
lt = Series([i < other for i in self.data], "bool")
else:
raise Exception(f"Cannot compare type {self.dtype.__name__} to {type(other).__name__}")
@ -191,7 +215,13 @@ class Series(UserList):
if isinstance(other, (str, datetime.datetime, datetime.date)):
other = _parse_date(other)
if self.dtype == float and isinstance(other, Number) or isinstance(other, self.dtype):
if isinstance(other, Series):
if len(self) != len(other):
raise ValueError("Length of Series must be same for comparison")
le = Series([j <= other[i] for i, j in enumerate(self)], "bool")
elif self.dtype == float and isinstance(other, Number) or isinstance(other, self.dtype):
le = Series([i <= other for i in self.data], "bool")
else:
raise Exception(f"Cannot compare type {self.dtype.__name__} to {type(other).__name__}")
@ -201,13 +231,20 @@ class Series(UserList):
if isinstance(other, (str, datetime.datetime, datetime.date)):
other = _parse_date(other)
if self.dtype == float and isinstance(other, Number) or isinstance(other, self.dtype):
if isinstance(other, Series):
if len(self) != len(other):
raise ValueError("Length of Series must be same for comparison")
eq = Series([j == other[i] for i, j in enumerate(self)], "bool")
elif self.dtype == float and isinstance(other, Number) or isinstance(other, self.dtype):
eq = Series([i == other for i in self.data], "bool")
else:
raise Exception(f"Cannot compare type {self.dtype.__name__} to {type(other).__name__}")
return eq
@Mapping.register
class TimeSeriesCore:
"""Defines the core building blocks of a TimeSeries object"""
@ -239,7 +276,7 @@ class TimeSeriesCore:
self.data = dict(ts_data)
if len(self.data) != len(ts_data):
print("Warning: The input data contains duplicate dates which have been ignored.")
warnings.warn("The input data contains duplicate dates which have been ignored.")
self.frequency: Frequency = getattr(AllFrequencies, frequency)
self.iter_num: int = -1
self._dates: list = None
@ -364,8 +401,38 @@ class TimeSeriesCore:
raise TypeError(f"Invalid type {repr(type(key).__name__)} for slicing.")
def __gt__(self, other):
    """Element-wise ``self > other``.

    Supported operands:

    * a ``Number`` -- every stored value is compared against it;
    * a ``TimeSeriesCore`` -- both objects must have identical dates and the
      values are compared date by date;
    * a ``Series`` of dtype ``float`` with the same length as ``self`` --
      values are compared positionally, in the iteration order of
      ``self.data``.

    Returns:
        A new object of ``self``'s class, constructed from a dict that maps
        each date to the boolean result, with the same frequency symbol as
        ``self``. ``NotImplemented`` is returned for any other operand type
        so that Python can try the reflected operation and otherwise raise
        ``TypeError`` itself.

    Raises:
        ValueError: if the dates of two time series differ, or if a Series
            operand has a different length than ``self``.
        TypeError: if a Series operand is not of dtype ``float``.
    """
    if isinstance(other, Number):
        data = {dt: val > other for dt, val in self.data.items()}
    elif isinstance(other, TimeSeriesCore):
        if self.dates != other.dates:
            raise ValueError(
                "Only objects with same set of dates can be compared.\n"
                "Hint: use TimeSeries.sync() method to sync dates of two TimeSeries objects."
            )
        # NOTE(review): other[dt] is indexed with [1] to get the value, which
        # presumes date lookup returns a (date, value) pair -- confirm
        # against __getitem__.
        data = {dt: val > other[dt][1] for dt, val in self.data.items()}
    elif isinstance(other, Series):
        # Check the dtype of the operand itself; the previous code inspected
        # the Series class attribute instead of the instance being compared.
        if other.dtype != float:
            raise TypeError("Only Series of type float can be used for comparison")
        if len(self) != len(other):
            raise ValueError("Length of series does not match length of object")
        data = {dt: val > other[i] for i, (dt, val) in enumerate(self.data.items())}
    else:
        # Previously an unsupported operand fell through every branch and
        # crashed with UnboundLocalError on the return statement.
        return NotImplemented

    return self.__class__(data, frequency=self.frequency.symbol)
@date_parser(1)
def __setitem__(self, key: str | datetime.datetime, value: Number) -> None:
key = _parse_date(key)
if not isinstance(value, Number):
raise TypeError("Only numerical values can be stored in TimeSeries")
if key in self.data:
self.data[key] = value
else:
self.data.update({key: value})
self.data = dict(sorted(self.data.items()))