TimeSeriesCore is now a subclass of UserDict

This commit is contained in:
Gourav Kumar 2022-02-21 15:23:20 +05:30
parent 2f0f1e0e47
commit 59c9c658ca

View File

@@ -1,5 +1,5 @@
import datetime
from collections import UserList
from collections import UserDict, UserList
from dataclasses import dataclass
from numbers import Number
from typing import Iterable, List, Literal, Mapping, Sequence, Tuple, Union
@@ -192,7 +192,7 @@ class Series(UserList):
return eq
class TimeSeriesCore:
class TimeSeriesCore(UserDict):
"""Defines the core building blocks of a TimeSeries object"""
def __init__(
@@ -219,8 +219,8 @@ class TimeSeriesCore:
data = _preprocess_timeseries(data, date_format=date_format)
self.time_series = dict(data)
if len(self.time_series) != len(data):
self.data = dict(data)
if len(self.data) != len(data):
print("Warning: The input data contains duplicate dates which have been ignored.")
self.frequency = getattr(AllFrequencies, frequency)
self.iter_num = -1
@@ -231,15 +231,15 @@ class TimeSeriesCore:
@property
def dates(self):
if self._dates is None or len(self._dates) != len(self.time_series):
self._dates = list(self.time_series.keys())
if self._dates is None or len(self._dates) != len(self.data):
self._dates = list(self.data.keys())
return Series(self._dates)
@property
def values(self):
if self._values is None or len(self._values) != len(self.time_series):
self._values = list(self.time_series.values())
if self._values is None or len(self._values) != len(self.data):
self._values = list(self.data.values())
return Series(self._values)
@@ -255,19 +255,19 @@ class TimeSeriesCore:
"""Returns a slice of the dataframe from beginning and end"""
printable = {}
iter_f = iter(self.time_series)
iter_f = iter(self.data)
first_n = [next(iter_f) for i in range(n // 2)]
iter_b = reversed(self.time_series)
iter_b = reversed(self.data)
last_n = [next(iter_b) for i in range(n // 2)]
last_n.sort()
printable["start"] = [str((i, self.time_series[i])) for i in first_n]
printable["end"] = [str((i, self.time_series[i])) for i in last_n]
printable["start"] = [str((i, self.data[i])) for i in first_n]
printable["end"] = [str((i, self.data[i])) for i in last_n]
return printable
def __repr__(self):
if len(self.time_series) > 6:
if len(self.data) > 6:
printable = self._get_printable_slice(6)
printable_str = "{}([{}\n\t ...\n\t {}], frequency={})".format(
self.__class__.__name__,
@@ -278,20 +278,20 @@ class TimeSeriesCore:
else:
printable_str = "{}([{}], frequency={})".format(
self.__class__.__name__,
",\n\t".join([str(i) for i in self.time_series.items()]),
",\n\t".join([str(i) for i in self.data.items()]),
repr(self.frequency.symbol),
)
return printable_str
def __str__(self):
if len(self.time_series) > 6:
if len(self.data) > 6:
printable = self._get_printable_slice(6)
printable_str = "[{}\n ...\n {}]".format(
",\n ".join(printable["start"]),
",\n ".join(printable["end"]),
)
else:
printable_str = "[{}]".format(",\n ".join([str(i) for i in self.time_series.items()]))
printable_str = "[{}]".format(",\n ".join([str(i) for i in self.data.items()]))
return printable_str
def __getitem__(self, key):
@@ -302,14 +302,14 @@ class TimeSeriesCore:
raise Exception(f"Length of Series: {len(key)} did not match length of object: {len(self.dates)}")
else:
dates_to_return = [self.dates[i] for i, j in enumerate(key) if j]
data_to_return = [(key, self.time_series[key]) for key in dates_to_return]
data_to_return = [(key, self.data[key]) for key in dates_to_return]
return TimeSeriesCore(data_to_return, frequency=self.frequency.symbol)
if isinstance(key, int):
raise KeyError(f"{key}. For index based slicing, use .iloc[{key}]")
elif isinstance(key, (datetime.datetime, datetime.date)):
key = _parse_date(key)
item = (key, self.time_series[key])
item = (key, self.data[key])
elif isinstance(key, str):
if key == "dates":
return self.dates
@@ -317,17 +317,17 @@ class TimeSeriesCore:
return self.values
dt_key = _parse_date(key)
item = (dt_key, self.time_series[dt_key])
item = (dt_key, self.data[dt_key])
elif isinstance(key, Sequence):
keys = [_parse_date(i) for i in key]
item = [(k, self.time_series[k]) for k in keys]
item = [(k, self.data[k]) for k in keys]
else:
raise TypeError(f"Invalid type {repr(type(key).__name__)} for slicing.")
return item
def __len__(self):
return len(self.time_series)
return len(self.data)
def __iter__(self):
self.n = 0
@@ -339,22 +339,22 @@ class TimeSeriesCore:
else:
key = self.dates[self.n]
self.n += 1
return key, self.time_series[key]
return key, self.data[key]
def head(self, n: int = 6):
"""Returns the first n items of the TimeSeries object"""
keys = list(self.time_series.keys())
keys = list(self.data.keys())
keys = keys[:n]
result = [(key, self.time_series[key]) for key in keys]
result = [(key, self.data[key]) for key in keys]
return result
def tail(self, n: int = 6):
"""Returns the last n items of the TimeSeries object"""
keys = list(self.time_series.keys())
keys = list(self.data.keys())
keys = keys[-n:]
result = [(key, self.time_series[key]) for key in keys]
result = [(key, self.data[key]) for key in keys]
return result
@property