expanded comparison for series, implemented gt in TSC

This commit is contained in:
Gourav Kumar 2022-04-10 23:52:53 +05:30
parent 09365c7957
commit b246709603

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import datetime import datetime
import inspect import inspect
import warnings
from collections import UserList from collections import UserList
from dataclasses import dataclass from dataclasses import dataclass
from numbers import Number from numbers import Number
@ -99,6 +100,11 @@ class _IndexSlicer:
return self.parent.__class__(item, self.parent.frequency.symbol) return self.parent.__class__(item, self.parent.frequency.symbol)
def __setitem__(self, key, value):
raise NotImplementedError(
"iloc cannot be used for setting a value as value will always be inserted in order of date"
)
class Series(UserList): class Series(UserList):
"""Container for a series of objects, all objects must be of the same type""" """Container for a series of objects, all objects must be of the same type"""
@ -150,7 +156,13 @@ class Series(UserList):
if isinstance(other, (str, datetime.datetime, datetime.date)): if isinstance(other, (str, datetime.datetime, datetime.date)):
other = _parse_date(other) other = _parse_date(other)
if self.dtype == float and isinstance(other, Number) or isinstance(other, self.dtype): if isinstance(other, Series):
if len(self) != len(other):
raise ValueError("Length of Series must be same for comparison")
gt = Series([j > other[i] for i, j in enumerate(self)], "bool")
elif self.dtype == float and isinstance(other, Number) or isinstance(other, self.dtype):
gt = Series([i > other for i in self.data], "bool") gt = Series([i > other for i in self.data], "bool")
else: else:
raise Exception(f"Cannot compare type {self.dtype.__name__} to {type(other).__name__}") raise Exception(f"Cannot compare type {self.dtype.__name__} to {type(other).__name__}")
@ -164,7 +176,13 @@ class Series(UserList):
if isinstance(other, (str, datetime.datetime, datetime.date)): if isinstance(other, (str, datetime.datetime, datetime.date)):
other = _parse_date(other) other = _parse_date(other)
if self.dtype == float and isinstance(other, Number) or isinstance(other, self.dtype): if isinstance(other, Series):
if len(self) != len(other):
raise ValueError("Length of Series must be same for comparison")
ge = Series([j >= other[i] for i, j in enumerate(self)], "bool")
elif self.dtype == float and isinstance(other, Number) or isinstance(other, self.dtype):
ge = Series([i >= other for i in self.data], "bool") ge = Series([i >= other for i in self.data], "bool")
else: else:
raise Exception(f"Cannot compare type {self.dtype.__name__} to {type(other).__name__}") raise Exception(f"Cannot compare type {self.dtype.__name__} to {type(other).__name__}")
@ -178,7 +196,13 @@ class Series(UserList):
if isinstance(other, (str, datetime.datetime, datetime.date)): if isinstance(other, (str, datetime.datetime, datetime.date)):
other = _parse_date(other) other = _parse_date(other)
if self.dtype == float and isinstance(other, Number) or isinstance(other, self.dtype): if isinstance(other, Series):
if len(self) != len(other):
raise ValueError("Length of Series must be same for comparison")
lt = Series([j < other[i] for i, j in enumerate(self)], "bool")
elif self.dtype == float and isinstance(other, Number) or isinstance(other, self.dtype):
lt = Series([i < other for i in self.data], "bool") lt = Series([i < other for i in self.data], "bool")
else: else:
raise Exception(f"Cannot compare type {self.dtype.__name__} to {type(other).__name__}") raise Exception(f"Cannot compare type {self.dtype.__name__} to {type(other).__name__}")
@ -191,7 +215,13 @@ class Series(UserList):
if isinstance(other, (str, datetime.datetime, datetime.date)): if isinstance(other, (str, datetime.datetime, datetime.date)):
other = _parse_date(other) other = _parse_date(other)
if self.dtype == float and isinstance(other, Number) or isinstance(other, self.dtype): if isinstance(other, Series):
if len(self) != len(other):
raise ValueError("Length of Series must be same for comparison")
le = Series([j <= other[i] for i, j in enumerate(self)], "bool")
elif self.dtype == float and isinstance(other, Number) or isinstance(other, self.dtype):
le = Series([i <= other for i in self.data], "bool") le = Series([i <= other for i in self.data], "bool")
else: else:
raise Exception(f"Cannot compare type {self.dtype.__name__} to {type(other).__name__}") raise Exception(f"Cannot compare type {self.dtype.__name__} to {type(other).__name__}")
@ -201,13 +231,20 @@ class Series(UserList):
if isinstance(other, (str, datetime.datetime, datetime.date)): if isinstance(other, (str, datetime.datetime, datetime.date)):
other = _parse_date(other) other = _parse_date(other)
if self.dtype == float and isinstance(other, Number) or isinstance(other, self.dtype): if isinstance(other, Series):
if len(self) != len(other):
raise ValueError("Length of Series must be same for comparison")
eq = Series([j == other[i] for i, j in enumerate(self)], "bool")
elif self.dtype == float and isinstance(other, Number) or isinstance(other, self.dtype):
eq = Series([i == other for i in self.data], "bool") eq = Series([i == other for i in self.data], "bool")
else: else:
raise Exception(f"Cannot compare type {self.dtype.__name__} to {type(other).__name__}") raise Exception(f"Cannot compare type {self.dtype.__name__} to {type(other).__name__}")
return eq return eq
@Mapping.register
class TimeSeriesCore: class TimeSeriesCore:
"""Defines the core building blocks of a TimeSeries object""" """Defines the core building blocks of a TimeSeries object"""
@ -239,7 +276,7 @@ class TimeSeriesCore:
self.data = dict(ts_data) self.data = dict(ts_data)
if len(self.data) != len(ts_data): if len(self.data) != len(ts_data):
print("Warning: The input data contains duplicate dates which have been ignored.") warnings.warn("The input data contains duplicate dates which have been ignored.")
self.frequency: Frequency = getattr(AllFrequencies, frequency) self.frequency: Frequency = getattr(AllFrequencies, frequency)
self.iter_num: int = -1 self.iter_num: int = -1
self._dates: list = None self._dates: list = None
@ -364,8 +401,38 @@ class TimeSeriesCore:
raise TypeError(f"Invalid type {repr(type(key).__name__)} for slicing.") raise TypeError(f"Invalid type {repr(type(key).__name__)} for slicing.")
def __gt__(self, other):
if isinstance(other, Number):
data = {k: v > other for k, v in self.data.items()}
if isinstance(other, TimeSeriesCore):
if self.dates != other.dates:
raise ValueError(
"Only objects with same set of dates can be compared.\n"
"Hint: use TimeSeries.sync() method to sync dates of two TimeSeries objects."
)
data = {dt: val > other[dt][1] for dt, val in self.data.items()}
if isinstance(other, Series):
if Series.dtype != float:
raise TypeError("Only Series of type float can be used for comparison")
if len(self) != len(other):
raise ValueError("Length of series does not match length of object")
data = {dt: val > other[i] for i, (dt, val) in enumerate(self.data.items())}
return self.__class__(data, frequency=self.frequency.symbol)
@date_parser(1)
def __setitem__(self, key: str | datetime.datetime, value: Number) -> None: def __setitem__(self, key: str | datetime.datetime, value: Number) -> None:
key = _parse_date(key) if not isinstance(value, Number):
raise TypeError("Only numerical values can be stored in TimeSeries")
if key in self.data:
self.data[key] = value
else:
self.data.update({key: value}) self.data.update({key: value})
self.data = dict(sorted(self.data.items())) self.data = dict(sorted(self.data.items()))