diff --git a/src/newsreader/news/collection/reddit.py b/src/newsreader/news/collection/reddit.py index 2bb7bd9..1e2837b 100644 --- a/src/newsreader/news/collection/reddit.py +++ b/src/newsreader/news/collection/reddit.py @@ -111,6 +111,8 @@ class RedditBuilder(Builder): self.instances = self.build(posts, stream.rule) def build(self, posts, rule): + results = {} + for post in posts: if not "data" in post: continue @@ -120,6 +122,9 @@ class RedditBuilder(Builder): author = truncate_text(Post, "author", post["data"]["author"]) url_fragment = f"{post['data']['permalink']}" + if remote_identifier in results: + continue + uncleaned_body = post["data"]["selftext_html"] unescaped_body = unescape(uncleaned_body) if uncleaned_body else "" body = ( @@ -154,14 +159,15 @@ class RedditBuilder(Builder): if remote_identifier in self.existing_posts: existing_post = self.existing_posts[remote_identifier] - if created_date > existing_post.publication_date: - for key, value in data.items(): - setattr(existing_post, key, value) + for key, value in data.items(): + setattr(existing_post, key, value) - yield existing_post - continue + results[existing_post.remote_identifier] = existing_post + continue - yield Post(**data) + results[remote_identifier] = Post(**data) + + return results.values() def save(self): for post in self.instances: diff --git a/src/newsreader/news/collection/tests/reddit/builder/tests.py b/src/newsreader/news/collection/tests/reddit/builder/tests.py index e1a6770..eb8182a 100644 --- a/src/newsreader/news/collection/tests/reddit/builder/tests.py +++ b/src/newsreader/news/collection/tests/reddit/builder/tests.py @@ -85,10 +85,8 @@ class RedditBuilderTestCase(TestCase): def test_update_posts(self): subreddit = SubredditFactory() - existing_publication_date = pytz.utc.localize(datetime(2020, 7, 8, 14, 0, 0)) existing_post = RedditPostFactory( remote_identifier="hngsj8", - publication_date=existing_publication_date, author="Old author", title="Old title", body="Old body", @@ -198,9 +196,7 @@ class RedditBuilderTestCase(TestCase): mock_stream = MagicMock(rule=subreddit) duplicate_post = RedditPostFactory( - publication_date=pytz.utc.localize(datetime(2020, 7, 1, 9, 20, 22)), - remote_identifier="hm0qct", - title="foo", + remote_identifier="hm0qct", rule=subreddit, title="foo" ) with builder((simple_mock, mock_stream)) as builder: @@ -217,7 +213,7 @@ class RedditBuilderTestCase(TestCase): self.assertEquals( duplicate_post.publication_date, - pytz.utc.localize(datetime(2020, 7, 6, 14, 11, 22)), + pytz.utc.localize(datetime(2020, 7, 6, 6, 11, 22)), ) self.assertEquals( duplicate_post.title,