Refactored with improved type hints

This commit is contained in:
Gourav Kumar 2022-04-05 10:43:53 +05:30
parent e06626dbca
commit ce6326f4b5
3 changed files with 61 additions and 61 deletions

View File

@ -5,7 +5,7 @@ import inspect
from collections import UserDict, UserList
from dataclasses import dataclass
from numbers import Number
from typing import Iterable, List, Literal, Mapping, Sequence, Union
from typing import Callable, Iterable, List, Literal, Mapping, Sequence, Type
from .utils import FincalOptions, _parse_date, _preprocess_timeseries
@ -42,15 +42,15 @@ def date_parser(*pos):
def parse_dates(func):
def wrapper_func(*args, **kwargs):
date_format = kwargs.get("date_format", None)
args = list(args)
sig = inspect.signature(func)
params = [i[0] for i in sig.parameters.items()]
date_format: str = kwargs.get("date_format", None)
args: list = list(args)
sig: inspect.Signature = inspect.signature(func)
params: list = [i[0] for i in sig.parameters.items()]
for j in pos:
kwarg = params[j]
kwarg: str = params[j]
date = kwargs.get(kwarg, None)
in_args = False
in_args: bool = False
if date is None:
try:
date = args[j]
@ -61,7 +61,7 @@ def date_parser(*pos):
if date is None:
continue
parsed_date = _parse_date(date, date_format)
parsed_date: datetime.datetime = _parse_date(date, date_format)
if not in_args:
kwargs[kwarg] = parsed_date
else:
@ -90,9 +90,9 @@ class _IndexSlicer:
def __getitem__(self, n):
if isinstance(n, int):
keys = [self.parent.dates[n]]
keys: list = [self.parent.dates[n]]
else:
keys = self.parent.dates[n]
keys: list = self.parent.dates[n]
item = [(key, self.parent.data[key]) for key in keys]
if len(item) == 1:
return item[0]
@ -105,11 +105,11 @@ class Series(UserList):
def __init__(
self,
data,
data: Sequence,
data_type: Literal["date", "number", "bool"],
date_format: str = None,
):
types_dict = {
types_dict: dict = {
"date": datetime.datetime,
"datetime": datetime.datetime,
"datetime.datetime": datetime.datetime,
@ -128,11 +128,11 @@ class Series(UserList):
if data_type in ["date", "datetime", "datetime.datetime"]:
data = [_parse_date(i, date_format) for i in data]
else:
func = types_dict[data_type]
data = [func(i) for i in data]
func: Callable = types_dict[data_type]
data: list = [func(i) for i in data]
self.dtype = types_dict[data_type]
self.data = data
self.dtype: Type = types_dict[data_type]
self.data: Sequence = data
def __repr__(self):
return f"{self.__class__.__name__}({self.data}, data_type='{self.dtype.__name__}')"
@ -212,25 +212,27 @@ class TimeSeriesCore(UserDict):
"""Defines the core building blocks of a TimeSeries object"""
def __init__(
self, data: List[Iterable], frequency: Literal["D", "W", "M", "Q", "H", "Y"], date_format: str = "%Y-%m-%d"
self,
data: List[Iterable] | Mapping,
frequency: Literal["D", "W", "M", "Q", "H", "Y"],
date_format: str = "%Y-%m-%d",
):
"""Instantiate a TimeSeriesCore object
Parameters
----------
data : List[tuple]
Time Series data in the form of list of tuples.
data : List[Iterable] | Mapping
Time Series data in the form of list of tuples or dictionary.
The first element of each tuple should be a date and second element should be a value.
In case of dictionary, the key should be the date.
frequency : str
The frequency of the time series.
Valid values are {D, W, M, Q, H, Y}
date_format : str, optional, default "%Y-%m-%d"
Specify the format of the date
Required only if the first argument of tuples is a string. Otherwise ignored.
frequency : str, optional, default "infer"
The frequency of the time series. Default is infer.
The class will try to infer the frequency automatically and adjust to the closest member.
Note that inferring frequencies can fail if the data is too irregular.
Valid values are {D, W, M, Q, H, Y}
"""
data = _preprocess_timeseries(data, date_format=date_format)
@ -322,10 +324,10 @@ class TimeSeriesCore(UserDict):
return printable_str
@date_parser(1)
def _get_item_from_date(self, date: str | datetime.datetime):
    """Return the ``(date, value)`` pair stored under ``date``

    Parameters
    ----------
    date : str | datetime.datetime
        Key to look up. The ``date_parser(1)`` decorator passes the
        argument at position 1 (this one) through ``_parse_date``
        before the body runs.

    Returns
    -------
    tuple
        ``(date, self.data[date])``

    Notes
    -----
    A missing key is not handled here; whatever ``self.data[date]``
    raises propagates to the caller.
    """

    return date, self.data[date]
def _get_item_from_key(self, key: Union[str, datetime.datetime]):
def _get_item_from_key(self, key: str | datetime.datetime):
if isinstance(key, int):
raise KeyError(f"{key}. \nHint: use .iloc[{key}] for index based slicing.")
@ -334,7 +336,7 @@ class TimeSeriesCore(UserDict):
return self._get_item_from_date(key)
def _get_item_from_list(self, date_list: Sequence[Union[str, datetime.datetime]]):
def _get_item_from_list(self, date_list: Sequence[str | datetime.datetime]):
data_to_return = [self._get_item_from_key(key) for key in date_list]
return self.__class__(data_to_return, frequency=self.frequency.symbol)
@ -379,7 +381,7 @@ class TimeSeriesCore(UserDict):
return super().__contains__(key)
@date_parser(1)
def get(self, date: Union[str, datetime.datetime], default=None, closest=None):
def get(self, date: str | datetime.datetime, default=None, closest=None):
if closest is None:
closest = FincalOptions.get_closest

View File

@ -5,7 +5,7 @@ import datetime
import math
import pathlib
import statistics
from typing import Iterable, List, Literal, Mapping, Tuple, TypedDict, Union
from typing import Iterable, List, Literal, Mapping, Tuple, TypedDict
from dateutil.relativedelta import relativedelta
@ -26,8 +26,8 @@ class MaxDrawdown(TypedDict):
@date_parser(0, 1)
def create_date_series(
start_date: Union[str, datetime.datetime],
end_date: Union[str, datetime.datetime],
start_date: str | datetime.datetime,
end_date: str | datetime.datetime,
frequency: Literal["D", "W", "M", "Q", "H", "Y"],
eomonth: bool = False,
skip_weekends: bool = False,
@ -130,7 +130,7 @@ class TimeSeries(TimeSeriesCore):
def __init__(
self,
data: Union[List[Iterable], Mapping],
data: List[Iterable] | Mapping,
frequency: Literal["D", "W", "M", "Q", "H", "Y"],
date_format: str = "%Y-%m-%d",
):
@ -145,7 +145,7 @@ class TimeSeries(TimeSeriesCore):
res_string: str = "First date: {}\nLast date: {}\nNumber of rows: {}"
return res_string.format(self.start_date, self.end_date, total_dates)
def ffill(self, inplace: bool = False, limit: int = None, skip_weekends: bool = False) -> Union[TimeSeries, None]:
def ffill(self, inplace: bool = False, limit: int = None, skip_weekends: bool = False) -> TimeSeries | None:
"""Forward fill missing dates in the time series
Parameters
@ -183,7 +183,7 @@ class TimeSeries(TimeSeriesCore):
return self.__class__(new_ts, frequency=self.frequency.symbol)
def bfill(self, inplace: bool = False, limit: int = None, skip_weekends: bool = False) -> Union[TimeSeries, None]:
def bfill(self, inplace: bool = False, limit: int = None, skip_weekends: bool = False) -> TimeSeries | None:
"""Backward fill missing dates in the time series
Parameters
@ -225,7 +225,7 @@ class TimeSeries(TimeSeriesCore):
@date_parser(1)
def calculate_returns(
self,
as_on: Union[str, datetime.datetime],
as_on: str | datetime.datetime,
return_actual_date: bool = True,
as_on_match: str = "closest",
prior_match: str = "closest",
@ -318,8 +318,9 @@ class TimeSeries(TimeSeriesCore):
@date_parser(1, 2)
def calculate_rolling_returns(
self,
from_date: Union[datetime.date, str],
to_date: Union[datetime.date, str],
from_date: datetime.date | str,
to_date: datetime.date,
str,
frequency: Literal["D", "W", "M", "Q", "H", "Y"] = None,
as_on_match: str = "closest",
prior_match: str = "closest",
@ -427,8 +428,8 @@ class TimeSeries(TimeSeriesCore):
@date_parser(1, 2)
def volatility(
self,
from_date: Union[datetime.date, str] = None,
to_date: Union[datetime.date, str] = None,
from_date: datetime.date | str = None,
to_date: datetime.date | str = None,
annualize_volatility: bool = True,
traded_days: int = None,
frequency: Literal["D", "W", "M", "Q", "H", "Y"] = None,
@ -600,7 +601,7 @@ class TimeSeries(TimeSeriesCore):
ensure_coverage=True,
)
closest = "previous" if method == "ffill" else "next"
closest: str = "previous" if method == "ffill" else "next"
new_ts: dict = {dt: self.get(dt, closest=closest)[1] for dt in new_dates}
output_ts: TimeSeries = TimeSeries(new_ts, frequency=to_frequency.symbol)
@ -617,8 +618,8 @@ def _preprocess_csv(file_path: str | pathlib.Path, delimiter: str = ",", encodin
raise ValueError("File not found. Check the file path")
with open(file_path, "r", encoding=encoding) as file:
reader = csv.reader(file, delimiter=delimiter)
csv_data = list(reader)
reader: csv.reader = csv.reader(file, delimiter=delimiter)
csv_data: list = list(reader)
csv_data = [i for i in csv_data if i] # remove blank rows
if not csv_data:

View File

@ -1,6 +1,6 @@
import datetime
from dataclasses import dataclass
from typing import Iterable, List, Literal, Mapping, Sequence, Tuple, Union
from typing import Iterable, List, Literal, Mapping, Sequence, Tuple
from .exceptions import DateNotFoundError, DateOutOfRangeError
@ -32,18 +32,15 @@ def _parse_date(date: str, date_format: str = None):
def _preprocess_timeseries(
data: Union[
Sequence[Iterable[Union[str, datetime.datetime, float]]],
Sequence[Mapping[str, Union[float, datetime.datetime]]],
Sequence[Mapping[Union[str, datetime.datetime], float]],
Mapping[Union[str, datetime.datetime], float],
],
data: Sequence[Iterable[str | datetime.datetime, float]]
| Sequence[Mapping[str | datetime.datetime, float]]
| Mapping[str | datetime.datetime, float],
date_format: str,
) -> List[Tuple[datetime.datetime, float]]:
"""Converts any type of list to the correct type"""
if isinstance(data, Mapping):
current_data = [(k, v) for k, v in data.items()]
current_data: List[tuple] = [(k, v) for k, v in data.items()]
return _preprocess_timeseries(current_data, date_format)
if not isinstance(data, Sequence):
@ -56,31 +53,31 @@ def _preprocess_timeseries(
raise TypeError("Could not parse the data")
if len(data[0]) == 1:
current_data = [tuple(*i.items()) for i in data]
current_data: List[tuple] = [tuple(*i.items()) for i in data]
elif len(data[0]) == 2:
current_data = [tuple(i.values()) for i in data]
current_data: List[tuple] = [tuple(i.values()) for i in data]
else:
raise TypeError("Could not parse the data")
return _preprocess_timeseries(current_data, date_format)
def _preprocess_match_options(as_on_match: str, prior_match: str, closest: str) -> datetime.timedelta:
def _preprocess_match_options(as_on_match: str, prior_match: str, closest: str) -> Tuple[datetime.timedelta]:
"""Checks the arguments and returns appropriate timedelta objects"""
deltas = {"exact": 0, "previous": -1, "next": 1}
if closest not in deltas.keys():
raise ValueError(f"Invalid argument for closest: {closest}")
as_on_match = closest if as_on_match == "closest" else as_on_match
prior_match = closest if prior_match == "closest" else prior_match
as_on_match: str = closest if as_on_match == "closest" else as_on_match
prior_match: str = closest if prior_match == "closest" else prior_match
if as_on_match in deltas.keys():
as_on_delta = datetime.timedelta(days=deltas[as_on_match])
as_on_delta: datetime.timedelta = datetime.timedelta(days=deltas[as_on_match])
else:
raise ValueError(f"Invalid as_on_match argument: {as_on_match}")
if prior_match in deltas.keys():
prior_delta = datetime.timedelta(days=deltas[prior_match])
prior_delta: datetime.timedelta = datetime.timedelta(days=deltas[prior_match])
else:
raise ValueError(f"Invalid prior_match argument: {prior_match}")
@ -101,7 +98,7 @@ def _find_closest_date(
if delta.days > 0 and date > max(data):
raise DateOutOfRangeError(date, "max")
row = data.get(date, None)
row: tuple = data.get(date, None)
if row is not None:
return date, row
@ -119,6 +116,6 @@ def _find_closest_date(
def _interval_to_years(interval_type: Literal["years", "months", "day"], interval_value: int) -> float:
"""Converts any time period to years for use with compounding functions"""
year_conversion_factor = {"years": 1, "months": 12, "days": 365}
years = interval_value / year_conversion_factor[interval_type]
year_conversion_factor: dict = {"years": 1, "months": 12, "days": 365}
years: float = interval_value / year_conversion_factor[interval_type]
return years