diff --git a/pyfacts/pyfacts.py b/pyfacts/pyfacts.py index d823c9a..21257e8 100644 --- a/pyfacts/pyfacts.py +++ b/pyfacts/pyfacts.py @@ -81,6 +81,9 @@ def create_date_series( extend_by_days = 7 - end_date.weekday() end_date += relativedelta(days=extend_by_days) + else: + end_date += relativedelta(days=frequency.days) + # TODO: Add code to ensure coverage for other frequencies as well datediff = (end_date - start_date).days / frequency.days + 1 @@ -91,7 +94,7 @@ def create_date_series( date = start_date + relativedelta(**diff) if eomonth: - replacement = {"month": date.month + 1} if date.month < 12 else {"year": date.year + 1} + replacement = {"month": date.month + 1} if date.month < 12 else {"year": date.year + 1, "month": 1} date = date.replace(day=1).replace(**replacement) - relativedelta(days=1) if date <= end_date: @@ -846,8 +849,7 @@ class TimeSeries(TimeSeriesCore): dates = create_date_series( self.start_date, - self.end_date - + datetime.timedelta(to_frequency.days), # need extra date at the end for calculation of last value + self.end_date, to_frequency.symbol, ensure_coverage=True, ) diff --git a/tests/conftest.py b/tests/conftest.py index faf5b65..e31da75 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,10 +3,11 @@ import math import random from typing import List -import pyfacts as pft import pytest from dateutil.relativedelta import relativedelta +import pyfacts as pft + def conf_add(n1, n2): return n1 + n2 @@ -95,7 +96,9 @@ def sample_data_generator( ) } end_date = start_date + relativedelta(**timedelta_dict) - dates = pft.create_date_series(start_date, end_date, frequency.symbol, skip_weekends=skip_weekends, eomonth=eomonth) + dates = pft.create_date_series( + start_date, end_date, frequency.symbol, skip_weekends=skip_weekends, eomonth=eomonth, ensure_coverage=False + ) if dates_as_string: dates = [dt.strftime("%Y-%m-%d") for dt in dates] values = create_prices(1000, mu, sigma, num) diff --git a/tests/test_pyfacts.py b/tests/test_pyfacts.py index b06fd9b..aab530a 100644 --- a/tests/test_pyfacts.py +++ b/tests/test_pyfacts.py @@ -1,6 +1,7 @@ import datetime import pytest + from pyfacts import ( AllFrequencies, Frequency, @@ -340,8 +341,23 @@ class TestExpand: class TestShrink: - # TODO - pass + def test_daily_to_smaller(self, create_test_data): + ts_data = create_test_data(AllFrequencies.D, num=1000) + ts = TimeSeries(ts_data, "D") + shrunk_ts_w = ts.shrink("W", "ffill") + shrunk_ts_m = ts.shrink("M", "ffill") + assert len(shrunk_ts_w) == 143 + assert len(shrunk_ts_m) == 34 + + def test_weekly_to_smaller(self, create_test_data): + ts_data = create_test_data(AllFrequencies.W, num=300) + ts = TimeSeries(ts_data, "W") + tsm = ts.shrink("M", "ffill") + assert len(tsm) == 70 + tsmeo = ts.shrink("M", "ffill", eomonth=True) + assert len(tsmeo) == 69 + with pytest.raises(ValueError): + ts.shrink("D", "ffill") class TestMeanReturns: @@ -360,7 +376,7 @@ class TestTransform: ts = TimeSeries(ts_data, "D") tst = ts.transform("W", "mean") assert isinstance(tst, TimeSeries) - assert len(tst) == 157 + assert len(tst) == 156 assert "2017-01-30" in tst assert tst.iloc[4] == (datetime.datetime(2017, 1, 30), 1021.19)