diff --git a/fincal/core.py b/fincal/core.py index 34433ea..698211e 100644 --- a/fincal/core.py +++ b/fincal/core.py @@ -2,7 +2,7 @@ import datetime from collections import UserDict, UserList from dataclasses import dataclass from numbers import Number -from typing import Iterable, List, Literal, Mapping, Sequence, Tuple, Type, Union +from typing import Iterable, List, Literal, Mapping, Sequence, Tuple, Union @dataclass @@ -138,20 +138,36 @@ class Series(UserList): def __init__( self, data, - data_type: Union[Type[bool], Type[float], Type[str], Type[datetime.datetime]], + data_type: Literal['date', 'number', 'bool'], date_format: str = None, ): - self.dtype = data_type + types_dict = { + 'date': datetime.datetime, + 'datetime': datetime.datetime, + 'datetime.datetime': datetime.datetime, + 'float': float, + 'int': float, + 'number': float, + 'bool': bool + } + + if data_type not in types_dict.keys(): + raise ValueError("Unsupported value for data type") + if not isinstance(data, Sequence): raise TypeError("Series object can only be created using Sequence types") - for i in data: - if not isinstance(i, data_type): - raise Exception("All arguments must be of the same type") - - if data_type == str: + if data_type in ['date', 'datetime', 'datetime.datetime']: data = [_parse_date(i, date_format) for i in data] + else: + func = types_dict[data_type] + data = [func(i) for i in data] + # elif data_type == 'number': + # data = [float(i) for i in data] + # elif data_type == 'boolean': + # data = [bool(i) for i in data] + self.dtype = types_dict[data_type] self.data = data def __repr__(self): @@ -159,7 +175,7 @@ class Series(UserList): def __getitem__(self, i): if isinstance(i, slice): - return self.__class__(self.data[i], self.dtype) + return self.__class__(self.data[i], str(self.dtype.__name__)) else: return self.data[i] @@ -171,25 +187,58 @@ class Series(UserList): other = _parse_date(other) if self.dtype == float and isinstance(other, Number) or isinstance(other, self.dtype): - gt = Series([i > other for i in self.data], bool) + gt = Series([i > other for i in self.data], 'bool') else: raise Exception(f"Cannot compare type {self.dtype.__name__} to {type(other).__name__}") return gt + def __ge__(self, other): + if self.dtype == bool: + raise TypeError(">= not supported for boolean series") + + if isinstance(other, (str, datetime.datetime, datetime.date)): + other = _parse_date(other) + + if self.dtype == float and isinstance(other, Number) or isinstance(other, self.dtype): + ge = Series([i >= other for i in self.data], 'bool') + else: + raise Exception(f"Cannot compare type {self.dtype.__name__} to {type(other).__name__}") + + return ge + def __lt__(self, other): if self.dtype == bool: raise TypeError("< not supported for boolean series") + if isinstance(other, (str, datetime.datetime, datetime.date)): + other = _parse_date(other) + if self.dtype == float and isinstance(other, Number) or isinstance(other, self.dtype): - lt = Series([i < other for i in self.data], bool) + lt = Series([i < other for i in self.data], 'bool') else: raise Exception(f"Cannot compare type {self.dtype.__name__} to {type(other).__name__}") return lt - def __eq__(self, other): + def __le__(self, other): + if self.dtype == bool: + raise TypeError("<= not supported for boolean series") + + if isinstance(other, (str, datetime.datetime, datetime.date)): + other = _parse_date(other) + if self.dtype == float and isinstance(other, Number) or isinstance(other, self.dtype): - eq = Series([i == other for i in self.data], bool) + le = Series([i <= other for i in self.data], 'bool') + else: + raise Exception(f"Cannot compare type {self.dtype.__name__} to {type(other).__name__}") + return le + + def __eq__(self, other): + if isinstance(other, (str, datetime.datetime, datetime.date)): + other = _parse_date(other) + + if self.dtype == float and isinstance(other, Number) or isinstance(other, self.dtype): + eq = Series([i == other for i in self.data], 'bool') else: raise Exception(f"Cannot compare type {self.dtype.__name__} to {type(other).__name__}") return eq @@ -237,14 +286,14 @@ class TimeSeriesCore(UserDict): if self._dates is None or len(self._dates) != len(self.data): self._dates = list(self.data.keys()) - return Series(self._dates, datetime.datetime) + return Series(self._dates, 'date') @property def values(self): if self._values is None or len(self._values) != len(self.data): self._values = list(self.data.values()) - return Series(self._values, float) + return Series(self._values, 'number') @property def start_date(self): @@ -304,7 +353,8 @@ class TimeSeriesCore(UserDict): elif len(key) != len(self.dates): raise Exception(f"Length of Series: {len(key)} did not match length of object: {len(self.dates)}") else: - dates_to_return = [self.dates[i] for i, j in enumerate(key) if j] + dates = self.dates + dates_to_return = [dates[i] for i, j in enumerate(key) if j] data_to_return = [(key, self.data[key]) for key in dates_to_return] return self.__class__(data_to_return, frequency=self.frequency.symbol)