Refactored with improved type hints

This commit is contained in:
Gourav Kumar 2022-04-05 10:43:53 +05:30
parent e06626dbca
commit ce6326f4b5
3 changed files with 61 additions and 61 deletions

View File

@ -5,7 +5,7 @@ import inspect
from collections import UserDict, UserList from collections import UserDict, UserList
from dataclasses import dataclass from dataclasses import dataclass
from numbers import Number from numbers import Number
from typing import Iterable, List, Literal, Mapping, Sequence, Union from typing import Callable, Iterable, List, Literal, Mapping, Sequence, Type
from .utils import FincalOptions, _parse_date, _preprocess_timeseries from .utils import FincalOptions, _parse_date, _preprocess_timeseries
@ -42,15 +42,15 @@ def date_parser(*pos):
def parse_dates(func): def parse_dates(func):
def wrapper_func(*args, **kwargs): def wrapper_func(*args, **kwargs):
date_format = kwargs.get("date_format", None) date_format: str = kwargs.get("date_format", None)
args = list(args) args: list = list(args)
sig = inspect.signature(func) sig: inspect.Signature = inspect.signature(func)
params = [i[0] for i in sig.parameters.items()] params: list = [i[0] for i in sig.parameters.items()]
for j in pos: for j in pos:
kwarg = params[j] kwarg: str = params[j]
date = kwargs.get(kwarg, None) date = kwargs.get(kwarg, None)
in_args = False in_args: bool = False
if date is None: if date is None:
try: try:
date = args[j] date = args[j]
@ -61,7 +61,7 @@ def date_parser(*pos):
if date is None: if date is None:
continue continue
parsed_date = _parse_date(date, date_format) parsed_date: datetime.datetime = _parse_date(date, date_format)
if not in_args: if not in_args:
kwargs[kwarg] = parsed_date kwargs[kwarg] = parsed_date
else: else:
@ -90,9 +90,9 @@ class _IndexSlicer:
def __getitem__(self, n): def __getitem__(self, n):
if isinstance(n, int): if isinstance(n, int):
keys = [self.parent.dates[n]] keys: list = [self.parent.dates[n]]
else: else:
keys = self.parent.dates[n] keys: list = self.parent.dates[n]
item = [(key, self.parent.data[key]) for key in keys] item = [(key, self.parent.data[key]) for key in keys]
if len(item) == 1: if len(item) == 1:
return item[0] return item[0]
@ -105,11 +105,11 @@ class Series(UserList):
def __init__( def __init__(
self, self,
data, data: Sequence,
data_type: Literal["date", "number", "bool"], data_type: Literal["date", "number", "bool"],
date_format: str = None, date_format: str = None,
): ):
types_dict = { types_dict: dict = {
"date": datetime.datetime, "date": datetime.datetime,
"datetime": datetime.datetime, "datetime": datetime.datetime,
"datetime.datetime": datetime.datetime, "datetime.datetime": datetime.datetime,
@ -128,11 +128,11 @@ class Series(UserList):
if data_type in ["date", "datetime", "datetime.datetime"]: if data_type in ["date", "datetime", "datetime.datetime"]:
data = [_parse_date(i, date_format) for i in data] data = [_parse_date(i, date_format) for i in data]
else: else:
func = types_dict[data_type] func: Callable = types_dict[data_type]
data = [func(i) for i in data] data: list = [func(i) for i in data]
self.dtype = types_dict[data_type] self.dtype: Type = types_dict[data_type]
self.data = data self.data: Sequence = data
def __repr__(self): def __repr__(self):
return f"{self.__class__.__name__}({self.data}, data_type='{self.dtype.__name__}')" return f"{self.__class__.__name__}({self.data}, data_type='{self.dtype.__name__}')"
@ -212,25 +212,27 @@ class TimeSeriesCore(UserDict):
"""Defines the core building blocks of a TimeSeries object""" """Defines the core building blocks of a TimeSeries object"""
def __init__( def __init__(
self, data: List[Iterable], frequency: Literal["D", "W", "M", "Q", "H", "Y"], date_format: str = "%Y-%m-%d" self,
data: List[Iterable] | Mapping,
frequency: Literal["D", "W", "M", "Q", "H", "Y"],
date_format: str = "%Y-%m-%d",
): ):
"""Instantiate a TimeSeriesCore object """Instantiate a TimeSeriesCore object
Parameters Parameters
---------- ----------
data : List[tuple] data : List[Iterable] | Mapping
Time Series data in the form of list of tuples. Time Series data in the form of list of tuples or dictionary.
The first element of each tuple should be a date and second element should be a value. The first element of each tuple should be a date and second element should be a value.
In case of dictionary, the key should be the date.
frequency : str
The frequency of the time series.
Valid values are {D, W, M, Q, H, Y}
date_format : str, optional, default "%Y-%m-%d" date_format : str, optional, default "%Y-%m-%d"
Specify the format of the date Specify the format of the date
Required only if the first argument of tuples is a string. Otherwise ignored. Required only if the first argument of tuples is a string. Otherwise ignored.
frequency : str, optional, default "infer"
The frequency of the time series. Default is infer.
The class will try to infer the frequency automatically and adjust to the closest member.
Note that inferring frequencies can fail if the data is too irregular.
Valid values are {D, W, M, Q, H, Y}
""" """
data = _preprocess_timeseries(data, date_format=date_format) data = _preprocess_timeseries(data, date_format=date_format)
@ -322,10 +324,10 @@ class TimeSeriesCore(UserDict):
return printable_str return printable_str
@date_parser(1) @date_parser(1)
def _get_item_from_date(self, date: Union[str, datetime.datetime]): def _get_item_from_date(self, date: str | datetime.datetime):
return date, self.data[date] return date, self.data[date]
def _get_item_from_key(self, key: Union[str, datetime.datetime]): def _get_item_from_key(self, key: str | datetime.datetime):
if isinstance(key, int): if isinstance(key, int):
raise KeyError(f"{key}. \nHint: use .iloc[{key}] for index based slicing.") raise KeyError(f"{key}. \nHint: use .iloc[{key}] for index based slicing.")
@ -334,7 +336,7 @@ class TimeSeriesCore(UserDict):
return self._get_item_from_date(key) return self._get_item_from_date(key)
def _get_item_from_list(self, date_list: Sequence[Union[str, datetime.datetime]]): def _get_item_from_list(self, date_list: Sequence[str | datetime.datetime]):
data_to_return = [self._get_item_from_key(key) for key in date_list] data_to_return = [self._get_item_from_key(key) for key in date_list]
return self.__class__(data_to_return, frequency=self.frequency.symbol) return self.__class__(data_to_return, frequency=self.frequency.symbol)
@ -379,7 +381,7 @@ class TimeSeriesCore(UserDict):
return super().__contains__(key) return super().__contains__(key)
@date_parser(1) @date_parser(1)
def get(self, date: Union[str, datetime.datetime], default=None, closest=None): def get(self, date: str | datetime.datetime, default=None, closest=None):
if closest is None: if closest is None:
closest = FincalOptions.get_closest closest = FincalOptions.get_closest

View File

@ -5,7 +5,7 @@ import datetime
import math import math
import pathlib import pathlib
import statistics import statistics
from typing import Iterable, List, Literal, Mapping, Tuple, TypedDict, Union from typing import Iterable, List, Literal, Mapping, Tuple, TypedDict
from dateutil.relativedelta import relativedelta from dateutil.relativedelta import relativedelta
@ -26,8 +26,8 @@ class MaxDrawdown(TypedDict):
@date_parser(0, 1) @date_parser(0, 1)
def create_date_series( def create_date_series(
start_date: Union[str, datetime.datetime], start_date: str | datetime.datetime,
end_date: Union[str, datetime.datetime], end_date: str | datetime.datetime,
frequency: Literal["D", "W", "M", "Q", "H", "Y"], frequency: Literal["D", "W", "M", "Q", "H", "Y"],
eomonth: bool = False, eomonth: bool = False,
skip_weekends: bool = False, skip_weekends: bool = False,
@ -130,7 +130,7 @@ class TimeSeries(TimeSeriesCore):
def __init__( def __init__(
self, self,
data: Union[List[Iterable], Mapping], data: List[Iterable] | Mapping,
frequency: Literal["D", "W", "M", "Q", "H", "Y"], frequency: Literal["D", "W", "M", "Q", "H", "Y"],
date_format: str = "%Y-%m-%d", date_format: str = "%Y-%m-%d",
): ):
@ -145,7 +145,7 @@ class TimeSeries(TimeSeriesCore):
res_string: str = "First date: {}\nLast date: {}\nNumber of rows: {}" res_string: str = "First date: {}\nLast date: {}\nNumber of rows: {}"
return res_string.format(self.start_date, self.end_date, total_dates) return res_string.format(self.start_date, self.end_date, total_dates)
def ffill(self, inplace: bool = False, limit: int = None, skip_weekends: bool = False) -> Union[TimeSeries, None]: def ffill(self, inplace: bool = False, limit: int = None, skip_weekends: bool = False) -> TimeSeries | None:
"""Forward fill missing dates in the time series """Forward fill missing dates in the time series
Parameters Parameters
@ -183,7 +183,7 @@ class TimeSeries(TimeSeriesCore):
return self.__class__(new_ts, frequency=self.frequency.symbol) return self.__class__(new_ts, frequency=self.frequency.symbol)
def bfill(self, inplace: bool = False, limit: int = None, skip_weekends: bool = False) -> Union[TimeSeries, None]: def bfill(self, inplace: bool = False, limit: int = None, skip_weekends: bool = False) -> TimeSeries | None:
"""Backward fill missing dates in the time series """Backward fill missing dates in the time series
Parameters Parameters
@ -225,7 +225,7 @@ class TimeSeries(TimeSeriesCore):
@date_parser(1) @date_parser(1)
def calculate_returns( def calculate_returns(
self, self,
as_on: Union[str, datetime.datetime], as_on: str | datetime.datetime,
return_actual_date: bool = True, return_actual_date: bool = True,
as_on_match: str = "closest", as_on_match: str = "closest",
prior_match: str = "closest", prior_match: str = "closest",
@ -318,8 +318,9 @@ class TimeSeries(TimeSeriesCore):
@date_parser(1, 2) @date_parser(1, 2)
def calculate_rolling_returns( def calculate_rolling_returns(
self, self,
from_date: Union[datetime.date, str], from_date: datetime.date | str,
to_date: Union[datetime.date, str], to_date: datetime.date,
str,
frequency: Literal["D", "W", "M", "Q", "H", "Y"] = None, frequency: Literal["D", "W", "M", "Q", "H", "Y"] = None,
as_on_match: str = "closest", as_on_match: str = "closest",
prior_match: str = "closest", prior_match: str = "closest",
@ -427,8 +428,8 @@ class TimeSeries(TimeSeriesCore):
@date_parser(1, 2) @date_parser(1, 2)
def volatility( def volatility(
self, self,
from_date: Union[datetime.date, str] = None, from_date: datetime.date | str = None,
to_date: Union[datetime.date, str] = None, to_date: datetime.date | str = None,
annualize_volatility: bool = True, annualize_volatility: bool = True,
traded_days: int = None, traded_days: int = None,
frequency: Literal["D", "W", "M", "Q", "H", "Y"] = None, frequency: Literal["D", "W", "M", "Q", "H", "Y"] = None,
@ -600,7 +601,7 @@ class TimeSeries(TimeSeriesCore):
ensure_coverage=True, ensure_coverage=True,
) )
closest = "previous" if method == "ffill" else "next" closest: str = "previous" if method == "ffill" else "next"
new_ts: dict = {dt: self.get(dt, closest=closest)[1] for dt in new_dates} new_ts: dict = {dt: self.get(dt, closest=closest)[1] for dt in new_dates}
output_ts: TimeSeries = TimeSeries(new_ts, frequency=to_frequency.symbol) output_ts: TimeSeries = TimeSeries(new_ts, frequency=to_frequency.symbol)
@ -617,8 +618,8 @@ def _preprocess_csv(file_path: str | pathlib.Path, delimiter: str = ",", encodin
raise ValueError("File not found. Check the file path") raise ValueError("File not found. Check the file path")
with open(file_path, "r", encoding=encoding) as file: with open(file_path, "r", encoding=encoding) as file:
reader = csv.reader(file, delimiter=delimiter) reader: csv.reader = csv.reader(file, delimiter=delimiter)
csv_data = list(reader) csv_data: list = list(reader)
csv_data = [i for i in csv_data if i] # remove blank rows csv_data = [i for i in csv_data if i] # remove blank rows
if not csv_data: if not csv_data:

View File

@ -1,6 +1,6 @@
import datetime import datetime
from dataclasses import dataclass from dataclasses import dataclass
from typing import Iterable, List, Literal, Mapping, Sequence, Tuple, Union from typing import Iterable, List, Literal, Mapping, Sequence, Tuple
from .exceptions import DateNotFoundError, DateOutOfRangeError from .exceptions import DateNotFoundError, DateOutOfRangeError
@ -32,18 +32,15 @@ def _parse_date(date: str, date_format: str = None):
def _preprocess_timeseries( def _preprocess_timeseries(
data: Union[ data: Sequence[Iterable[str | datetime.datetime, float]]
Sequence[Iterable[Union[str, datetime.datetime, float]]], | Sequence[Mapping[str | datetime.datetime, float]]
Sequence[Mapping[str, Union[float, datetime.datetime]]], | Mapping[str | datetime.datetime, float],
Sequence[Mapping[Union[str, datetime.datetime], float]],
Mapping[Union[str, datetime.datetime], float],
],
date_format: str, date_format: str,
) -> List[Tuple[datetime.datetime, float]]: ) -> List[Tuple[datetime.datetime, float]]:
"""Converts any type of list to the correct type""" """Converts any type of list to the correct type"""
if isinstance(data, Mapping): if isinstance(data, Mapping):
current_data = [(k, v) for k, v in data.items()] current_data: List[tuple] = [(k, v) for k, v in data.items()]
return _preprocess_timeseries(current_data, date_format) return _preprocess_timeseries(current_data, date_format)
if not isinstance(data, Sequence): if not isinstance(data, Sequence):
@ -56,31 +53,31 @@ def _preprocess_timeseries(
raise TypeError("Could not parse the data") raise TypeError("Could not parse the data")
if len(data[0]) == 1: if len(data[0]) == 1:
current_data = [tuple(*i.items()) for i in data] current_data: List[tuple] = [tuple(*i.items()) for i in data]
elif len(data[0]) == 2: elif len(data[0]) == 2:
current_data = [tuple(i.values()) for i in data] current_data: List[tuple] = [tuple(i.values()) for i in data]
else: else:
raise TypeError("Could not parse the data") raise TypeError("Could not parse the data")
return _preprocess_timeseries(current_data, date_format) return _preprocess_timeseries(current_data, date_format)
def _preprocess_match_options(as_on_match: str, prior_match: str, closest: str) -> datetime.timedelta: def _preprocess_match_options(as_on_match: str, prior_match: str, closest: str) -> Tuple[datetime.timedelta]:
"""Checks the arguments and returns appropriate timedelta objects""" """Checks the arguments and returns appropriate timedelta objects"""
deltas = {"exact": 0, "previous": -1, "next": 1} deltas = {"exact": 0, "previous": -1, "next": 1}
if closest not in deltas.keys(): if closest not in deltas.keys():
raise ValueError(f"Invalid argument for closest: {closest}") raise ValueError(f"Invalid argument for closest: {closest}")
as_on_match = closest if as_on_match == "closest" else as_on_match as_on_match: str = closest if as_on_match == "closest" else as_on_match
prior_match = closest if prior_match == "closest" else prior_match prior_match: str = closest if prior_match == "closest" else prior_match
if as_on_match in deltas.keys(): if as_on_match in deltas.keys():
as_on_delta = datetime.timedelta(days=deltas[as_on_match]) as_on_delta: datetime.timedelta = datetime.timedelta(days=deltas[as_on_match])
else: else:
raise ValueError(f"Invalid as_on_match argument: {as_on_match}") raise ValueError(f"Invalid as_on_match argument: {as_on_match}")
if prior_match in deltas.keys(): if prior_match in deltas.keys():
prior_delta = datetime.timedelta(days=deltas[prior_match]) prior_delta: datetime.timedelta = datetime.timedelta(days=deltas[prior_match])
else: else:
raise ValueError(f"Invalid prior_match argument: {prior_match}") raise ValueError(f"Invalid prior_match argument: {prior_match}")
@ -101,7 +98,7 @@ def _find_closest_date(
if delta.days > 0 and date > max(data): if delta.days > 0 and date > max(data):
raise DateOutOfRangeError(date, "max") raise DateOutOfRangeError(date, "max")
row = data.get(date, None) row: tuple = data.get(date, None)
if row is not None: if row is not None:
return date, row return date, row
@ -119,6 +116,6 @@ def _find_closest_date(
def _interval_to_years(interval_type: Literal["years", "months", "day"], interval_value: int) -> float: def _interval_to_years(interval_type: Literal["years", "months", "day"], interval_value: int) -> float:
"""Converts any time period to years for use with compounding functions""" """Converts any time period to years for use with compounding functions"""
year_conversion_factor = {"years": 1, "months": 12, "days": 365} year_conversion_factor: dict = {"years": 1, "months": 12, "days": 365}
years = interval_value / year_conversion_factor[interval_type] years: float = interval_value / year_conversion_factor[interval_type]
return years return years