improved getitem, head, and tail functions

This commit is contained in:
Gourav Kumar 2022-03-12 10:24:40 +05:30
parent ce5540e26b
commit d7b06fbe24

View File

@ -3,7 +3,7 @@ import inspect
from collections import UserDict, UserList
from dataclasses import dataclass
from numbers import Number
from typing import Iterable, List, Literal, Sequence, Tuple
from typing import Iterable, List, Literal, Mapping, Sequence, Union
from .utils import _parse_date, _preprocess_timeseries
@ -32,7 +32,7 @@ def date_parser(*pos):
... return diff.days
... return diff
...
>>> calculate_difference(date1='2019-01-01'm date2='2020-01-01')
>>> calculate_difference(date1='2019-01-01', date2='2020-01-01')
datetime.timedelta(365)
Each of the dates is automatically parsed into a datetime.datetime object from string.
@ -83,7 +83,7 @@ class AllFrequencies:
class _IndexSlicer:
"""Class to create a slice using iloc in TimeSeriesCore"""
def __init__(self, parent_obj):
def __init__(self, parent_obj: object):
self.parent = parent_obj
def __getitem__(self, n):
@ -95,7 +95,7 @@ class _IndexSlicer:
if len(item) == 1:
return item[0]
return item
return self.parent.__class__(item, self.parent.frequency.symbol)
class Series(UserList):
@ -319,38 +319,46 @@ class TimeSeriesCore(UserDict):
printable_str = "[{}]".format(",\n ".join([str(i) for i in self.data.items()]))
return printable_str
def __getitem__(self, key):
if isinstance(key, Series):
if not key.dtype == bool:
raise ValueError(f"Cannot slice {self.__class__.__name__} using a Series of {key.dtype.__name__}")
elif len(key) != len(self.dates):
raise Exception(f"Length of Series: {len(key)} did not match length of object: {len(self.dates)}")
else:
dates = self.dates
dates_to_return = [dates[i] for i, j in enumerate(key) if j]
data_to_return = [(key, self.data[key]) for key in dates_to_return]
return self.__class__(data_to_return, frequency=self.frequency.symbol)
@date_parser(1)
def _get_item_from_date(self, date: Union[str, datetime.datetime]):
return date, self.data[date]
def _get_item_from_key(self, key: Union[str, datetime.datetime]):
if isinstance(key, int):
raise KeyError(f"{key}. For index based slicing, use .iloc[{key}]")
elif isinstance(key, (datetime.datetime, datetime.date)):
key = _parse_date(key)
item = (key, self.data[key])
elif isinstance(key, str):
if key == "dates":
return self.dates
elif key == "values":
return self.values
raise KeyError(f"{key}. \nHint: use .iloc[{key}] for index based slicing.")
dt_key = _parse_date(key)
item = (dt_key, self.data[dt_key])
if key in ["dates", "values"]:
return getattr(self, key)
elif isinstance(key, Sequence):
keys = [_parse_date(i) for i in key]
item = [(k, self.data[k]) for k in keys]
return self._get_item_from_date(key)
def _get_item_from_list(self, date_list: Sequence[Union[str, datetime.datetime]]):
data_to_return = [self._get_item_from_key(key) for key in date_list]
return self.__class__(data_to_return, frequency=self.frequency.symbol)
def _get_item_from_series(self, series: Series):
if series.dtype == bool:
if len(series) != len(self.dates):
raise ValueError(f"Length of Series: {len(series)} did not match length of object: {len(self.dates)}")
dates_to_return = [self.dates[i] for i, j in enumerate(series) if j]
elif series.dtype == datetime.datetime:
dates_to_return = list(series)
else:
raise TypeError(f"Invalid type {repr(type(key).__name__)} for slicing.")
return item
raise TypeError(f"Cannot slice {self.__class__.__name__} using a Series of {series.dtype.__name__}")
return self._get_item_from_list(dates_to_return)
def __getitem__(self, key):
if isinstance(key, (int, str, datetime.datetime, datetime.date)):
return self._get_item_from_key(key)
if isinstance(key, Series):
return self._get_item_from_series(key)
if isinstance(key, Sequence):
return self._get_item_from_list(key)
raise TypeError(f"Invalid type {repr(type(key).__name__)} for slicing.")
def __iter__(self):
self.n = 0
@ -364,31 +372,12 @@ class TimeSeriesCore(UserDict):
self.n += 1
return key, self.data[key]
@date_parser(1)
def __contains__(self, key: object) -> bool:
key = _parse_date(key)
return super().__contains__(key)
def head(self, n: int = 6):
"""Returns the first n items of the TimeSeries object"""
keys = list(self.data.keys())
keys = keys[:n]
result = [(key, self.data[key]) for key in keys]
return result
def tail(self, n: int = 6):
"""Returns the last n items of the TimeSeries object"""
keys = list(self.data.keys())
keys = keys[-n:]
result = [(key, self.data[key]) for key in keys]
return result
def items(self):
return self.data.items()
@property
def iloc(self) -> List[Tuple[datetime.datetime, float]]:
def iloc(self) -> Mapping:
"""Returns an item or a set of items based on index
supports slicing using numerical index.
@ -406,3 +395,16 @@ class TimeSeriesCore(UserDict):
"""
return _IndexSlicer(self)
def head(self, n: int = 6):
"""Returns the first n items of the TimeSeries object"""
return self.iloc[:n]
def tail(self, n: int = 6):
"""Returns the last n items of the TimeSeries object"""
return self.iloc[-n:]
def items(self):
return self.data.items()