diff --git a/fincal/fincal.py b/fincal/fincal.py index fb227c9..5b38d26 100644 --- a/fincal/fincal.py +++ b/fincal/fincal.py @@ -5,7 +5,7 @@ from typing import List, Union from dateutil.relativedelta import relativedelta -from .core import AllFrequencies, TimeSeriesCore, _preprocess_match_options +from .core import AllFrequencies, TimeSeriesCore, _parse_date, _preprocess_match_options def create_date_series( @@ -113,12 +113,13 @@ class TimeSeries(TimeSeriesCore): def calculate_returns( self, - as_on: datetime.datetime, + as_on: Union[str, datetime.datetime], as_on_match: str = "closest", prior_match: str = "closest", closest: str = "previous", compounding: bool = True, years: int = 1, + date_format: str = None ) -> float: """Method to calculate returns for a certain time-period as on a particular date @@ -158,6 +159,7 @@ class TimeSeries(TimeSeriesCore): >>> calculate_returns(datetime.date(2020, 1, 1), years=1) """ + as_on = _parse_date(as_on, date_format) as_on_delta, prior_delta = _preprocess_match_options(as_on_match, prior_match, closest) while True: @@ -184,21 +186,28 @@ class TimeSeries(TimeSeriesCore): def calculate_rolling_returns( self, - from_date: datetime.date, - to_date: datetime.date, - frequency: str = "D", + from_date: Union[datetime.date, str], + to_date: Union[datetime.date, str], + frequency: str = None, as_on_match: str = "closest", prior_match: str = "closest", closest: str = "previous", compounding: bool = True, years: int = 1, + date_format: str = None ) -> List[tuple]: """Calculates the rolling return""" - try: - frequency = getattr(AllFrequencies, frequency) - except AttributeError: - raise ValueError(f"Invalid argument for frequency {frequency}") + from_date = _parse_date(from_date, date_format) + to_date = _parse_date(to_date, date_format) + + if frequency is None: + frequency = self.frequency + else: + try: + frequency = getattr(AllFrequencies, frequency) + except AttributeError: + raise ValueError(f"Invalid argument for frequency {frequency}") dates = create_date_series(from_date, to_date, frequency.symbol) if frequency == AllFrequencies.D: