diff --git a/src/newsreader/accounts/models.py b/src/newsreader/accounts/models.py index 0c29801..58b7e66 100644 --- a/src/newsreader/accounts/models.py +++ b/src/newsreader/accounts/models.py @@ -61,24 +61,29 @@ class User(AbstractUser): def save(self, *args, **kwargs): if self._original_interval != self.task_interval: - if self.task_interval and self.task: self.task.interval = self.task_interval self.task.enabled = True self.task.save() - elif self.task_interval and not self.task: - self.task = PeriodicTask.objects.create( - enabled=True, - interval=self.task_interval, - name=f"{self.email}-collection-task", - task="newsreader.news.collection.tasks", - args=json.dumps([self.pk]), - kwargs=None, - ) - elif not self.task_interval and self.task: self.task.enabled = False self.task.save() + if not self.task: + self.task_interval, _ = IntervalSchedule.objects.get_or_create( + every=1, period=IntervalSchedule.HOURS + ) + + self.task = PeriodicTask.objects.create( + enabled=True, + interval=self.task_interval, + name=f"{self.email}-collection-task", + task="newsreader.news.collection.tasks", + args=json.dumps([self.pk]), + kwargs={}, + ) + + self._original_interval = self.task_interval + super().save(*args, **kwargs) diff --git a/src/newsreader/accounts/tests/tests.py b/src/newsreader/accounts/tests/tests.py new file mode 100644 index 0000000..a38aa43 --- /dev/null +++ b/src/newsreader/accounts/tests/tests.py @@ -0,0 +1,41 @@ +from django.test import TestCase + +from django_celery_beat.models import IntervalSchedule, PeriodicTask + +from newsreader.accounts.models import User + + +class UserTestCase(TestCase): + def test_task_is_created(self): + user = User.objects.create(email="durp@burp.nl", task=None, task_interval=None) + + task = PeriodicTask.objects.get(name=f"{user.email}-collection-task") + expected_interval = IntervalSchedule.objects.get( + every=1, period=IntervalSchedule.HOURS + ) + + self.assertEquals(task.interval, expected_interval) + self.assertEquals(PeriodicTask.objects.count(), 1) + + def test_task_is_updated(self): + user = User.objects.create(email="durp@burp.nl", task=None, task_interval=None) + + new_interval = IntervalSchedule.objects.create( + every=2, period=IntervalSchedule.HOURS + ) + user.task_interval = new_interval + user.save() + + task = PeriodicTask.objects.get(name=f"{user.email}-collection-task") + + self.assertEquals(task.interval, new_interval) + + def test_task_is_disabled(self): + user = User.objects.create(email="durp@burp.nl", task=None, task_interval=None) + + user.task_interval = None + user.save() + + task = PeriodicTask.objects.get(name=f"{user.email}-collection-task") + + self.assertEquals(task.enabled, False)