diff --git a/fincal/core.py b/fincal/core.py index 5ea8893..67c7416 100644 --- a/fincal/core.py +++ b/fincal/core.py @@ -3,7 +3,7 @@ import inspect from collections import UserDict, UserList from dataclasses import dataclass from numbers import Number -from typing import Iterable, List, Literal, Sequence, Tuple +from typing import Iterable, List, Literal, Mapping, Sequence, Union from .utils import _parse_date, _preprocess_timeseries @@ -32,7 +32,7 @@ def date_parser(*pos): ... return diff.days ... return diff ... - >>> calculate_difference(date1='2019-01-01'm date2='2020-01-01') + >>> calculate_difference(date1='2019-01-01', date2='2020-01-01') datetime.timedelta(365) Each of the dates is automatically parsed into a datetime.datetime object from string. @@ -83,7 +83,7 @@ class AllFrequencies: class _IndexSlicer: """Class to create a slice using iloc in TimeSeriesCore""" - def __init__(self, parent_obj): + def __init__(self, parent_obj: object): self.parent = parent_obj def __getitem__(self, n): @@ -95,7 +95,7 @@ class _IndexSlicer: if len(item) == 1: return item[0] - return item + return self.parent.__class__(item, self.parent.frequency.symbol) class Series(UserList): @@ -319,38 +319,46 @@ class TimeSeriesCore(UserDict): printable_str = "[{}]".format(",\n ".join([str(i) for i in self.data.items()])) return printable_str - def __getitem__(self, key): - if isinstance(key, Series): - if not key.dtype == bool: - raise ValueError(f"Cannot slice {self.__class__.__name__} using a Series of {key.dtype.__name__}") - elif len(key) != len(self.dates): - raise Exception(f"Length of Series: {len(key)} did not match length of object: {len(self.dates)}") - else: - dates = self.dates - dates_to_return = [dates[i] for i, j in enumerate(key) if j] - data_to_return = [(key, self.data[key]) for key in dates_to_return] - return self.__class__(data_to_return, frequency=self.frequency.symbol) + @date_parser(1) + def _get_item_from_date(self, date: Union[str, datetime.datetime]): + return date, self.data[date] + def _get_item_from_key(self, key: Union[str, datetime.datetime]): if isinstance(key, int): - raise KeyError(f"{key}. For index based slicing, use .iloc[{key}]") - elif isinstance(key, (datetime.datetime, datetime.date)): - key = _parse_date(key) - item = (key, self.data[key]) - elif isinstance(key, str): - if key == "dates": - return self.dates - elif key == "values": - return self.values + raise KeyError(f"{key}. \nHint: use .iloc[{key}] for index based slicing.") - dt_key = _parse_date(key) - item = (dt_key, self.data[dt_key]) + if key in ["dates", "values"]: + return getattr(self, key) - elif isinstance(key, Sequence): - keys = [_parse_date(i) for i in key] - item = [(k, self.data[k]) for k in keys] + return self._get_item_from_date(key) + + def _get_item_from_list(self, date_list: Sequence[Union[str, datetime.datetime]]): + data_to_return = [self._get_item_from_key(key) for key in date_list] + return self.__class__(data_to_return, frequency=self.frequency.symbol) + + def _get_item_from_series(self, series: Series): + if series.dtype == bool: + if len(series) != len(self.dates): + raise ValueError(f"Length of Series: {len(series)} did not match length of object: {len(self.dates)}") + dates_to_return = [self.dates[i] for i, j in enumerate(series) if j] + elif series.dtype == datetime.datetime: + dates_to_return = list(series) else: - raise TypeError(f"Invalid type {repr(type(key).__name__)} for slicing.") - return item + raise TypeError(f"Cannot slice {self.__class__.__name__} using a Series of {series.dtype.__name__}") + + return self._get_item_from_list(dates_to_return) + + def __getitem__(self, key): + if isinstance(key, (int, str, datetime.datetime, datetime.date)): + return self._get_item_from_key(key) + + if isinstance(key, Series): + return self._get_item_from_series(key) + + if isinstance(key, Sequence): + return self._get_item_from_list(key) + + raise TypeError(f"Invalid type {repr(type(key).__name__)} for slicing.") def __iter__(self): self.n = 0 @@ -364,31 +372,12 @@ class TimeSeriesCore(UserDict): self.n += 1 return key, self.data[key] + @date_parser(1) def __contains__(self, key: object) -> bool: - key = _parse_date(key) return super().__contains__(key) - def head(self, n: int = 6): - """Returns the first n items of the TimeSeries object""" - - keys = list(self.data.keys()) - keys = keys[:n] - result = [(key, self.data[key]) for key in keys] - return result - - def tail(self, n: int = 6): - """Returns the last n items of the TimeSeries object""" - - keys = list(self.data.keys()) - keys = keys[-n:] - result = [(key, self.data[key]) for key in keys] - return result - - def items(self): - return self.data.items() - @property - def iloc(self) -> List[Tuple[datetime.datetime, float]]: + def iloc(self) -> Mapping: """Returns an item or a set of items based on index supports slicing using numerical index. @@ -406,3 +395,16 @@ class TimeSeriesCore(UserDict): """ return _IndexSlicer(self) + + def head(self, n: int = 6): + """Returns the first n items of the TimeSeries object""" + + return self.iloc[:n] + + def tail(self, n: int = 6): + """Returns the last n items of the TimeSeries object""" + + return self.iloc[-n:] + + def items(self): + return self.data.items()