- Add Twitter integration
- Refactor a lot of existing code in collection app
- Update webpack font configuration
This commit is contained in:
Sonny Bakker 2020-09-27 16:19:32 +02:00
parent 805321f66d
commit d4a41a62da
118 changed files with 11060 additions and 5515 deletions

View file

@ -1,4 +1,4 @@
version: '3' version: "3"
volumes: volumes:
postgres-data: postgres-data:
static-files: static-files:
@ -16,7 +16,7 @@ services:
rabbitmq: rabbitmq:
image: rabbitmq:3.7 image: rabbitmq:3.7
memcached: memcached:
image: memcached:1.5.22 image: memcached:1.6
ports: ports:
- "11211:11211" - "11211:11211"
entrypoint: entrypoint:
@ -31,6 +31,7 @@ services:
- DJANGO_SETTINGS_MODULE=newsreader.conf.docker - DJANGO_SETTINGS_MODULE=newsreader.conf.docker
depends_on: depends_on:
- rabbitmq - rabbitmq
- memcached
volumes: volumes:
- .:/app - .:/app
django: django:
@ -41,9 +42,10 @@ services:
environment: environment:
- DJANGO_SETTINGS_MODULE=newsreader.conf.docker - DJANGO_SETTINGS_MODULE=newsreader.conf.docker
ports: ports:
- '8000:8000' - "8000:8000"
depends_on: depends_on:
- db - db
- memcached
volumes: volumes:
- .:/app - .:/app
- static-files:/app/src/newsreader/static - static-files:/app/src/newsreader/static

609
poetry.lock generated

File diff suppressed because it is too large Load diff

View file

@ -25,6 +25,8 @@ gunicorn = "^20.0.4"
python-dotenv = "^0.12.0" python-dotenv = "^0.12.0"
django = ">=3.0.7" django = ">=3.0.7"
sentry-sdk = "^0.15.1" sentry-sdk = "^0.15.1"
ftfy = "^5.8"
requests_oauthlib = "^1.3.0"
[tool.poetry.dev-dependencies] [tool.poetry.dev-dependencies]
factory-boy = "^2.12.0" factory-boy = "^2.12.0"

View file

@ -11,8 +11,18 @@ class UserAdminForm(UserChangeForm):
class Meta: class Meta:
widgets = { widgets = {
"email": forms.EmailInput(attrs={"size": "50"}), "email": forms.EmailInput(attrs={"size": "50"}),
"reddit_access_token": forms.TextInput(attrs={"size": "90"}), "reddit_access_token": forms.PasswordInput(
"reddit_refresh_token": forms.TextInput(attrs={"size": "90"}), attrs={"size": "90"}, render_value=True
),
"reddit_refresh_token": forms.PasswordInput(
attrs={"size": "90"}, render_value=True
),
"twitter_oauth_token": forms.PasswordInput(
attrs={"size": "90"}, render_value=True
),
"twitter_oauth_token_secret": forms.PasswordInput(
attrs={"size": "90"}, render_value=True
),
} }
@ -34,6 +44,10 @@ class UserAdmin(DjangoUserAdmin):
_("Reddit settings"), _("Reddit settings"),
{"fields": ("reddit_access_token", "reddit_refresh_token")}, {"fields": ("reddit_access_token", "reddit_refresh_token")},
), ),
(
_("Twitter settings"),
{"fields": ("twitter_oauth_token", "twitter_oauth_token_secret")},
),
( (
_("Permission settings"), _("Permission settings"),
{"classes": ("collapse",), "fields": ("is_staff", "is_superuser")}, {"classes": ("collapse",), "fields": ("is_staff", "is_superuser")},

View file

@ -0,0 +1,21 @@
# Generated by Django 3.0.7 on 2020-09-13 19:01
from django.db import migrations, models
class Migration(migrations.Migration):
    """Add the two nullable Twitter OAuth credential columns to ``user``."""

    dependencies = [("accounts", "0010_auto_20200603_2230")]

    # Both columns share one definition, so emit one AddField per column name.
    operations = [
        migrations.AddField(
            model_name="user",
            name=column,
            field=models.CharField(blank=True, max_length=255, null=True),
        )
        for column in ("twitter_oauth_token", "twitter_oauth_token_secret")
    ]

View file

@ -0,0 +1,10 @@
# Generated by Django 3.0.7 on 2020-09-26 15:34
from django.db import migrations
class Migration(migrations.Migration):
    """Remove the ``task`` field from the ``user`` model."""

    dependencies = [
        ("accounts", "0011_auto_20200913_2101"),
    ]

    operations = [
        migrations.RemoveField(
            model_name="user",
            name="task",
        ),
    ]

View file

@ -1,11 +1,9 @@
import json
from django.contrib.auth.models import AbstractUser from django.contrib.auth.models import AbstractUser
from django.contrib.auth.models import UserManager as DjangoUserManager from django.contrib.auth.models import UserManager as DjangoUserManager
from django.db import models from django.db import models
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _
from django_celery_beat.models import IntervalSchedule, PeriodicTask from django_celery_beat.models import PeriodicTask
class UserManager(DjangoUserManager): class UserManager(DjangoUserManager):
@ -41,18 +39,12 @@ class UserManager(DjangoUserManager):
class User(AbstractUser): class User(AbstractUser):
email = models.EmailField(_("email address"), unique=True) email = models.EmailField(_("email address"), unique=True)
task = models.OneToOneField(
PeriodicTask,
on_delete=models.CASCADE,
null=True,
blank=True,
editable=False,
verbose_name="collection task",
)
reddit_refresh_token = models.CharField(max_length=255, blank=True, null=True) reddit_refresh_token = models.CharField(max_length=255, blank=True, null=True)
reddit_access_token = models.CharField(max_length=255, blank=True, null=True) reddit_access_token = models.CharField(max_length=255, blank=True, null=True)
twitter_oauth_token = models.CharField(max_length=255, blank=True, null=True)
twitter_oauth_token_secret = models.CharField(max_length=255, blank=True, null=True)
username = None username = None
objects = UserManager() objects = UserManager()
@ -60,24 +52,12 @@ class User(AbstractUser):
USERNAME_FIELD = "email" USERNAME_FIELD = "email"
REQUIRED_FIELDS = [] REQUIRED_FIELDS = []
def save(self, *args, **kwargs):
super().save(*args, **kwargs)
if not self.task:
task_interval, _ = IntervalSchedule.objects.get_or_create(
every=1, period=IntervalSchedule.HOURS
)
self.task, _ = PeriodicTask.objects.get_or_create(
enabled=True,
interval=task_interval,
name=f"{self.email}-collection-task",
task="FeedTask",
args=json.dumps([self.pk]),
)
self.save()
def delete(self, *args, **kwargs): def delete(self, *args, **kwargs):
self.task.delete() tasks = PeriodicTask.objects.filter(name__contains=self.email)
tasks.delete()
return super().delete(*args, **kwargs) return super().delete(*args, **kwargs)
@property
def has_twitter_auth(self):
    """Tell whether both Twitter OAuth credentials are stored for this user.

    Returns:
        bool: ``True`` only when ``twitter_oauth_token`` and
        ``twitter_oauth_token_secret`` are both non-empty.
    """
    # A bare ``a and b`` would hand back the raw token secret (or None / "");
    # coerce to bool so this flag never exposes the credential itself.
    return bool(self.twitter_oauth_token and self.twitter_oauth_token_secret)

View file

@ -3,28 +3,15 @@
{% block actions %} {% block actions %}
<section class="section form__section--last"> <section class="section form__section--last">
<fieldset class="fieldset form__fieldset">
{% include "components/form/cancel-button.html" %}
</fieldset>
<fieldset class="fieldset form__fieldset"> <fieldset class="fieldset form__fieldset">
<a class="link button button--primary" href="{% url 'accounts:password-change' %}"> <a class="link button button--primary" href="{% url 'accounts:password-change' %}">
{% trans "Change password" %} {% trans "Change password" %}
</a> </a>
<a class="link button button--primary" href="{% url 'accounts:integrations' %}">
{% trans "Third party integrations" %}
</a>
{% include "components/form/confirm-button.html" %} {% include "components/form/confirm-button.html" %}
{% if reddit_authorization_url %}
<a class="link button button--reddit" href="{{ reddit_authorization_url }}">
{% trans "Authorize Reddit account" %}
</a>
{% endif %}
{% if reddit_refresh_url %}
<a class="link button button--reddit" href="{{ reddit_refresh_url }}">
{% trans "Refresh Reddit access token" %}
</a>
{% endif %}
</fieldset> </fieldset>
</section> </section>
{% endblock actions %} {% endblock actions %}

View file

@ -0,0 +1,70 @@
{% extends "base.html" %}

{% load i18n %}

{% block content %}
  {# Translate the page title first so it can be passed to the header include. #}
  {% trans "Integrations" as page_title %}
  <main id="integrations--page" class="main">
    <section class="section">
      {% include "components/header/header.html" with title=page_title only %}

      {# Each control is a link when its URL is in the context, otherwise a disabled button. #}
      <div class="integrations">
        <h3 class="integrations__title">Reddit</h3>
        <div class="integrations__controls">
          {% if reddit_authorization_url %}
            <a class="link button button--reddit" href="{{ reddit_authorization_url }}">
              {% trans "Authorize account" %}
            </a>
          {% else %}
            <button class="button button--reddit button--disabled" disabled>
              {% trans "Authorize account" %}
            </button>
          {% endif %}

          {% if reddit_refresh_url %}
            <a class="link button button--reddit" href="{{ reddit_refresh_url }}">
              {% trans "Refresh token" %}
            </a>
          {% else %}
            <button class="button button--reddit button--disabled" disabled>
              {% trans "Refresh token" %}
            </button>
          {% endif %}

          {% if reddit_revoke_url %}
            <a class="link button button--reddit" href="{{ reddit_revoke_url }}">
              {% trans "Deauthorize account" %}
            </a>
          {% else %}
            <button class="button button--reddit button--disabled" disabled>
              {% trans "Deauthorize account" %}
            </button>
          {% endif %}
        </div>
      </div>

      <div class="integrations">
        <h3 class="integrations__title">Twitter</h3>
        <div class="integrations__controls">
          {% if twitter_auth_url %}
            <a class="link button button--twitter" href="{{ twitter_auth_url }}">
              {% trans "Authorize account" %}
            </a>
          {% else %}
            <button class="button button--twitter button--disabled" disabled>
              {% trans "Authorize account" %}
            </button>
          {% endif %}

          {% if twitter_revoke_url %}
            <a class="link button button--twitter" href="{{ twitter_revoke_url }}">
              {% trans "Deauthorize account" %}
            </a>
          {% else %}
            <button class="button button--twitter button--disabled" disabled>
              {% trans "Deauthorize account" %}
            </button>
          {% endif %}
        </div>
      </div>
    </section>
  </main>
{% endblock %}

View file

@ -1,17 +1,20 @@
{% extends "base.html" %} {% extends "base.html" %}
{% load i18n %}
{% block content %} {% block content %}
<main id="settings--page" class="main"> <main id="reddit--page" class="main">
<section class="section text-section"> <section class="section text-section">
{% if error %} {% if error %}
<h1 class="h1">Reddit authorization failed</h1> <h1 class="h1">{% trans "Reddit authorization failed" %}</h1>
<p>{{ error }}</p> <p>{{ error }}</p>
{% elif access_token and refresh_token %} {% elif access_token and refresh_token %}
<h1 class="h1">Reddit account is linked</h1> <h1 class="h1">{% trans "Reddit account is linked" %}</h1>
<p>Your reddit account was successfully linked.</p> <p>{% trans "Your reddit account was successfully linked." %}</p>
{% endif %} {% endif %}
<p><a href="{% url 'accounts:settings' %}">Return to settings page</a></p> <p>
<a class="link" href="{% url 'accounts:integrations' %}">{% trans "Return to integrations page" %}</a>
</p>
</section> </section>
</main> </main>
{% endblock %} {% endblock %}

View file

@ -0,0 +1,20 @@
{% extends "base.html" %}
{% load i18n %}
{# Twitter authorization result page. #}
{# Shows the failure heading plus the `error` text when `error` is set, #}
{# the "account is linked" message when `authorized` is set, and in every #}
{# case a link back to the integrations page. #}
{% block content %}
<main id="twitter--page" class="main">
<section class="section text-section">
{% if error %}
<h1 class="h1">{% trans "Twitter authorization failed" %}</h1>
<p>{{ error }}</p>
{% elif authorized %}
<h1 class="h1">{% trans "Twitter account is linked" %}</h1>
<p>{% trans "Your Twitter account was successfully linked." %}</p>
{% endif %}
<p>
<a class="link" href="{% url 'accounts:integrations' %}">{% trans "Return to integrations page" %}</a>
</p>
</section>
</main>
{% endblock %}

View file

@ -0,0 +1,537 @@
from unittest.mock import Mock, patch
from urllib.parse import urlencode
from uuid import uuid4
from django.core.cache import cache
from django.test import TestCase
from django.urls import reverse
from django.utils.translation import gettext as _
from bs4 import BeautifulSoup
from newsreader.accounts.tests.factories import UserFactory
from newsreader.news.collection.exceptions import (
StreamException,
StreamTooManyException,
)
from newsreader.news.collection.twitter import TWITTER_AUTH_URL
class IntegrationsViewTestCase(TestCase):
    """Common fixture: a logged-in user and the integrations page URL."""

    def setUp(self):
        # Resolve the page under test, then authenticate the test client.
        self.url = reverse("accounts:integrations")

        account = UserFactory(email="test@test.nl", password="test")
        self.user = account
        self.client.force_login(account)
class RedditIntegrationsTestCase(IntegrationsViewTestCase):
    """Checks which Reddit controls the integrations page renders as links."""

    # CSS classes carried by an enabled (anchor) Reddit control.
    link_class = "link button button--reddit"

    def _get_soup(self):
        """Request the integrations page and return it parsed."""
        response = self.client.get(self.url)
        return BeautifulSoup(response.content, features="lxml")

    def test_reddit_authorization(self):
        self.user.reddit_refresh_token = None
        self.user.save()

        button = self._get_soup().find("a", class_=self.link_class)

        # assertEqual: the assertEquals alias was removed in Python 3.12.
        self.assertEqual(button.text.strip(), "Authorize account")

    def test_reddit_refresh_token(self):
        self.user.reddit_refresh_token = "jadajadajada"
        self.user.reddit_access_token = None
        self.user.save()

        button = self._get_soup().find("a", class_=self.link_class)

        self.assertEqual(button.text.strip(), "Refresh token")

    def test_reddit_revoke(self):
        self.user.reddit_refresh_token = "jadajadajada"
        self.user.reddit_access_token = None
        self.user.save()

        buttons = self._get_soup().find_all("a", class_=self.link_class)

        self.assertIn(
            "Deauthorize account", [button.text.strip() for button in buttons]
        )
class RedditTemplateViewTestCase(TestCase):
    """Exercises the Reddit callback page (``accounts:reddit-template``)."""

    def setUp(self):
        self.user = UserFactory(email="test@test.nl", password="test")
        self.client.force_login(self.user)

        self.base_url = reverse("accounts:reddit-template")
        self.state = str(uuid4())

        self.patch = patch("newsreader.news.collection.reddit.post")
        self.mocked_post = self.patch.start()

    def tearDown(self):
        patch.stopall()
        # Several tests store an auth state in the cache; clear it so the
        # value cannot leak into whichever test runs next.
        cache.clear()

    def test_simple(self):
        response = self.client.get(self.base_url)

        self.assertEqual(response.status_code, 200)
        self.assertContains(response, "Return to integrations page")

    def test_successful_authorization(self):
        self.mocked_post.return_value.json.return_value = {
            "access_token": "1001010412",
            "refresh_token": "134510143",
        }

        cache.set(f"{self.user.email}-reddit-auth", self.state)

        params = {"state": self.state, "code": "Valid code"}
        url = f"{self.base_url}?{urlencode(params)}"

        response = self.client.get(url)

        self.mocked_post.assert_called_once()

        self.assertEqual(response.status_code, 200)
        self.assertContains(response, "Your reddit account was successfully linked.")

        self.user.refresh_from_db()

        self.assertEqual(self.user.reddit_access_token, "1001010412")
        self.assertEqual(self.user.reddit_refresh_token, "134510143")

        self.assertIsNone(cache.get(f"{self.user.email}-reddit-auth"))

    def test_error(self):
        params = {"error": "Denied authorization"}
        url = f"{self.base_url}?{urlencode(params)}"

        response = self.client.get(url)

        self.assertEqual(response.status_code, 200)
        self.assertContains(response, "Denied authorization")

    def test_invalid_state(self):
        cache.set(f"{self.user.email}-reddit-auth", str(uuid4()))

        params = {"code": "Valid code", "state": "Invalid state"}
        url = f"{self.base_url}?{urlencode(params)}"

        response = self.client.get(url)

        self.assertEqual(response.status_code, 200)
        self.assertContains(
            response, "The saved state for Reddit authorization did not match"
        )

    def test_stream_error(self):
        self.mocked_post.side_effect = StreamTooManyException

        cache.set(f"{self.user.email}-reddit-auth", self.state)

        params = {"state": self.state, "code": "Valid code"}
        url = f"{self.base_url}?{urlencode(params)}"

        response = self.client.get(url)

        self.mocked_post.assert_called_once()

        self.assertEqual(response.status_code, 200)
        self.assertContains(response, "Too many requests")

        self.user.refresh_from_db()

        self.assertIsNone(self.user.reddit_access_token)
        self.assertIsNone(self.user.reddit_refresh_token)

        # The saved state must survive a failed exchange.
        self.assertEqual(cache.get(f"{self.user.email}-reddit-auth"), self.state)

    def test_unexpected_json(self):
        self.mocked_post.return_value.json.return_value = {"message": "Happy eastern"}

        cache.set(f"{self.user.email}-reddit-auth", self.state)

        params = {"state": self.state, "code": "Valid code"}
        url = f"{self.base_url}?{urlencode(params)}"

        response = self.client.get(url)

        self.mocked_post.assert_called_once()

        self.assertEqual(response.status_code, 200)
        self.assertContains(response, "Access and refresh token not found in response")

        self.user.refresh_from_db()

        self.assertIsNone(self.user.reddit_access_token)
        self.assertIsNone(self.user.reddit_refresh_token)

        self.assertEqual(cache.get(f"{self.user.email}-reddit-auth"), self.state)
class RedditTokenRedirectViewTestCase(TestCase):
    """Exercises the Reddit token refresh redirect (``accounts:reddit-refresh``)."""

    def setUp(self):
        self.user = UserFactory(email="test@test.nl", password="test")
        self.client.force_login(self.user)

        self.patch = patch("newsreader.accounts.views.integrations.RedditTokenTask")
        self.mocked_task = self.patch.start()

    def tearDown(self):
        # Stop the patch started in setUp so it does not stay active for
        # tests that run afterwards.
        patch.stopall()
        cache.clear()

    def test_simple(self):
        response = self.client.get(reverse("accounts:reddit-refresh"))

        self.assertRedirects(response, reverse("accounts:integrations"))

        self.mocked_task.delay.assert_called_once_with(self.user.pk)

        self.assertEqual(1, cache.get(f"{self.user.email}-reddit-refresh"))

    def test_not_active(self):
        # A cached marker for this user means no new task may be queued.
        cache.set(f"{self.user.email}-reddit-refresh", 1)

        response = self.client.get(reverse("accounts:reddit-refresh"))

        self.assertRedirects(response, reverse("accounts:integrations"))

        self.mocked_task.delay.assert_not_called()
class RedditRevokeRedirectViewTestCase(TestCase):
    """Exercises the Reddit deauthorization redirect (``accounts:reddit-revoke``)."""

    def setUp(self):
        self.user = UserFactory(email="test@test.nl", password="test")
        self.client.force_login(self.user)

        self.patch = patch("newsreader.accounts.views.integrations.revoke_reddit_token")
        self.mocked_revoke = self.patch.start()

    def tearDown(self):
        # Stop the patch started in setUp so it does not stay active for
        # tests that run afterwards.
        patch.stopall()

    def _store_tokens(self):
        """Give the user a Reddit access and refresh token."""
        self.user.reddit_access_token = "jadajadajada"
        self.user.reddit_refresh_token = "jadajadajada"
        self.user.save()

    def test_simple(self):
        self._store_tokens()

        self.mocked_revoke.return_value = True

        response = self.client.get(reverse("accounts:reddit-revoke"))

        self.assertRedirects(response, reverse("accounts:integrations"))

        self.mocked_revoke.assert_called_once_with(self.user)

        self.user.refresh_from_db()

        self.assertIsNone(self.user.reddit_access_token)
        self.assertIsNone(self.user.reddit_refresh_token)

    def test_no_refresh_token(self):
        self.user.reddit_refresh_token = None
        self.user.save()

        response = self.client.get(reverse("accounts:reddit-revoke"))

        self.assertRedirects(response, reverse("accounts:integrations"))

        self.mocked_revoke.assert_not_called()

    def test_unsuccessful_response(self):
        self._store_tokens()

        self.mocked_revoke.return_value = False

        response = self.client.get(reverse("accounts:reddit-revoke"))

        self.assertRedirects(response, reverse("accounts:integrations"))

        self.user.refresh_from_db()

        # A failed revoke must leave the stored tokens untouched.
        self.assertEqual(self.user.reddit_access_token, "jadajadajada")
        self.assertEqual(self.user.reddit_refresh_token, "jadajadajada")

    def test_stream_exception(self):
        self._store_tokens()

        self.mocked_revoke.side_effect = StreamException

        response = self.client.get(reverse("accounts:reddit-revoke"))

        self.assertRedirects(response, reverse("accounts:integrations"))

        self.user.refresh_from_db()

        self.assertEqual(self.user.reddit_access_token, "jadajadajada")
        self.assertEqual(self.user.reddit_refresh_token, "jadajadajada")
class TwitterRevokeRedirectView(TestCase):
    """Exercises the Twitter deauthorization redirect (``accounts:twitter-revoke``)."""

    def setUp(self):
        self.user = UserFactory(email="test@test.nl", password="test")
        self.client.force_login(self.user)

        self.patch = patch("newsreader.accounts.views.integrations.post")
        self.mocked_post = self.patch.start()

    def tearDown(self):
        patch.stopall()

    def test_simple(self):
        self.user.twitter_oauth_token = "jadajadajada"
        self.user.twitter_oauth_token_secret = "jadajadajada"
        self.user.save()

        response = self.client.get(reverse("accounts:twitter-revoke"))

        self.assertRedirects(response, reverse("accounts:integrations"))

        self.user.refresh_from_db()

        self.assertIsNone(self.user.twitter_oauth_token)
        self.assertIsNone(self.user.twitter_oauth_token_secret)

    def test_no_authorized_account(self):
        self.user.twitter_oauth_token = None
        self.user.twitter_oauth_token_secret = None
        self.user.save()

        response = self.client.get(reverse("accounts:twitter-revoke"))

        self.assertRedirects(response, reverse("accounts:integrations"))

        self.mocked_post.assert_not_called()

    def test_stream_exception(self):
        self.user.twitter_oauth_token = "jadajadajada"
        self.user.twitter_oauth_token_secret = "jadajadajada"
        self.user.save()

        self.mocked_post.side_effect = StreamException

        response = self.client.get(reverse("accounts:twitter-revoke"))

        self.assertRedirects(response, reverse("accounts:integrations"))

        self.user.refresh_from_db()

        # assertEqual: the assertEquals alias was removed in Python 3.12.
        self.assertEqual(self.user.twitter_oauth_token, "jadajadajada")
        self.assertEqual(self.user.twitter_oauth_token_secret, "jadajadajada")
class TwitterAuthRedirectViewTestCase(TestCase):
    """Exercises the Twitter authorization redirect (``accounts:twitter-auth``)."""

    def setUp(self):
        self.user = UserFactory(email="test@test.nl", password="test")
        self.client.force_login(self.user)

        self.patch = patch("newsreader.accounts.views.integrations.post")
        self.mocked_post = self.patch.start()

    def tearDown(self):
        # Stop the patch started in setUp so it does not stay active for
        # tests that run afterwards.
        patch.stopall()
        cache.clear()

    def _assert_nothing_cached(self):
        """Assert that no token or secret was cached for the user."""
        self.assertIsNone(cache.get(f"twitter-{self.user.email}-token"))
        self.assertIsNone(cache.get(f"twitter-{self.user.email}-secret"))

    def test_simple(self):
        self.mocked_post.return_value = Mock(
            text="oauth_token=foo&oauth_token_secret=bar"
        )

        response = self.client.get(reverse("accounts:twitter-auth"))

        self.assertRedirects(
            response,
            f"{TWITTER_AUTH_URL}/?oauth_token=foo",
            fetch_redirect_response=False,
        )

        cached_token = cache.get(f"twitter-{self.user.email}-token")
        cached_secret = cache.get(f"twitter-{self.user.email}-secret")

        self.assertEqual(cached_token, "foo")
        self.assertEqual(cached_secret, "bar")

    def test_stream_exception(self):
        self.mocked_post.side_effect = StreamException

        response = self.client.get(reverse("accounts:twitter-auth"))

        self.assertRedirects(response, reverse("accounts:integrations"))

        self._assert_nothing_cached()

    def test_unexpected_contents(self):
        self.mocked_post.return_value = Mock(text="foo=bar&oauth_token_secret=bar")

        response = self.client.get(reverse("accounts:twitter-auth"))

        self.assertRedirects(response, reverse("accounts:integrations"))

        self._assert_nothing_cached()
class TwitterTemplateViewTestCase(TestCase):
    """Exercises the Twitter callback page (``accounts:twitter-template``)."""

    def setUp(self):
        self.user = UserFactory(email="test@test.nl", password="test")
        self.client.force_login(self.user)

        self.patch = patch("newsreader.accounts.views.integrations.post")
        self.mocked_post = self.patch.start()

    def tearDown(self):
        # Stop the patch started in setUp so it does not stay active for
        # tests that run afterwards.
        patch.stopall()
        cache.clear()

    def _cache_tokens(self):
        """Cache a token/secret pair ("foo"/"bar") for the user."""
        cache.set_many(
            {
                f"twitter-{self.user.email}-token": "foo",
                f"twitter-{self.user.email}-secret": "bar",
            }
        )

    def _get_callback(self, params):
        """GET the callback page with ``params`` as its query string."""
        return self.client.get(
            f"{reverse('accounts:twitter-template')}?{urlencode(params)}"
        )

    def test_simple(self):
        self._cache_tokens()

        params = {"denied": "", "oauth_token": "foo", "oauth_verifier": "barfoo"}

        self.mocked_post.return_value = Mock(
            text="oauth_token=realtoken&oauth_token_secret=realsecret"
        )

        response = self._get_callback(params)

        self.assertContains(response, _("Twitter account is linked"))

        self.user.refresh_from_db()

        # assertEqual: the assertEquals alias was removed in Python 3.12.
        self.assertEqual(self.user.twitter_oauth_token, "realtoken")
        self.assertEqual(self.user.twitter_oauth_token_secret, "realsecret")

        self.assertIsNone(cache.get(f"twitter-{self.user.email}-token"))
        self.assertIsNone(cache.get(f"twitter-{self.user.email}-secret"))

    def test_denied(self):
        params = {"denied": "true", "oauth_token": "foo", "oauth_verifier": "barfoo"}

        response = self._get_callback(params)

        self.assertContains(response, _("Twitter authorization failed"))

        self.user.refresh_from_db()

        self.assertIsNone(self.user.twitter_oauth_token)
        self.assertIsNone(self.user.twitter_oauth_token_secret)

        self.mocked_post.assert_not_called()

    def test_mismatched_token(self):
        self._cache_tokens()

        params = {"denied": "", "oauth_token": "boo", "oauth_verifier": "barfoo"}

        response = self._get_callback(params)

        self.assertContains(response, _("OAuth tokens failed to match"))

        self.user.refresh_from_db()

        self.assertIsNone(self.user.twitter_oauth_token)
        self.assertIsNone(self.user.twitter_oauth_token_secret)

        self.mocked_post.assert_not_called()

    def test_missing_secret(self):
        # Only the token is cached here; the secret is deliberately absent.
        cache.set_many({f"twitter-{self.user.email}-token": "foo"})

        params = {"denied": "", "oauth_token": "foo", "oauth_verifier": "barfoo"}

        response = self._get_callback(params)

        self.assertContains(response, _("No matching tokens found for this user"))

        self.user.refresh_from_db()

        self.assertIsNone(self.user.twitter_oauth_token_secret)

        self.mocked_post.assert_not_called()

    def test_stream_exception(self):
        self._cache_tokens()

        params = {"denied": "", "oauth_token": "foo", "oauth_verifier": "barfoo"}

        self.mocked_post.side_effect = StreamException

        response = self._get_callback(params)

        self.assertContains(response, _("Failed requesting access token"))

        self.user.refresh_from_db()

        self.assertIsNone(self.user.twitter_oauth_token)
        self.assertIsNone(self.user.twitter_oauth_token_secret)

        # The cached pair must survive a failed exchange.
        self.assertIsNotNone(cache.get(f"twitter-{self.user.email}-token"))
        self.assertIsNotNone(cache.get(f"twitter-{self.user.email}-secret"))

    def test_unexpected_contents(self):
        self._cache_tokens()

        params = {"denied": "", "oauth_token": "foo", "oauth_verifier": "barfoo"}

        self.mocked_post.return_value = Mock(
            text="foobar=boo&oauth_token_secret=realsecret"
        )

        response = self._get_callback(params)

        self.assertContains(response, _("No credentials found in Twitter response"))

        self.user.refresh_from_db()

        self.assertIsNone(self.user.twitter_oauth_token)
        self.assertIsNone(self.user.twitter_oauth_token_secret)

        self.assertIsNotNone(cache.get(f"twitter-{self.user.email}-token"))
        self.assertIsNotNone(cache.get(f"twitter-{self.user.email}-secret"))

View file

@ -1,14 +1,8 @@
from unittest.mock import patch
from urllib.parse import urlencode
from uuid import uuid4
from django.core.cache import cache
from django.test import TestCase from django.test import TestCase
from django.urls import reverse from django.urls import reverse
from newsreader.accounts.models import User from newsreader.accounts.models import User
from newsreader.accounts.tests.factories import UserFactory from newsreader.accounts.tests.factories import UserFactory
from newsreader.news.collection.exceptions import StreamTooManyException
class SettingsViewTestCase(TestCase): class SettingsViewTestCase(TestCase):
@ -22,7 +16,6 @@ class SettingsViewTestCase(TestCase):
response = self.client.get(self.url) response = self.client.get(self.url)
self.assertEquals(response.status_code, 200) self.assertEquals(response.status_code, 200)
self.assertContains(response, "Authorize Reddit account")
def test_user_credential_change(self): def test_user_credential_change(self):
response = self.client.post( response = self.client.post(
@ -36,126 +29,3 @@ class SettingsViewTestCase(TestCase):
self.assertEquals(user.first_name, "First name") self.assertEquals(user.first_name, "First name")
self.assertEquals(user.last_name, "Last name") self.assertEquals(user.last_name, "Last name")
def test_linked_reddit_account(self):
self.user.reddit_refresh_token = "test"
self.user.save()
response = self.client.get(self.url)
self.assertEquals(response.status_code, 200)
self.assertNotContains(response, "Authorize Reddit account")
class RedditTemplateViewTestCase(TestCase):
def setUp(self):
self.user = UserFactory(email="test@test.nl", password="test")
self.client.force_login(self.user)
self.base_url = reverse("accounts:reddit-template")
self.state = str(uuid4())
self.patch = patch("newsreader.news.collection.reddit.post")
self.mocked_post = self.patch.start()
def tearDown(self):
patch.stopall()
def test_simple(self):
response = self.client.get(self.base_url)
self.assertEquals(response.status_code, 200)
self.assertContains(response, "Return to settings page")
def test_successful_authorization(self):
self.mocked_post.return_value.json.return_value = {
"access_token": "1001010412",
"refresh_token": "134510143",
}
cache.set(f"{self.user.email}-reddit-auth", self.state)
params = {"state": self.state, "code": "Valid code"}
url = f"{self.base_url}?{urlencode(params)}"
response = self.client.get(url)
self.mocked_post.assert_called_once()
self.assertEquals(response.status_code, 200)
self.assertContains(response, "Your reddit account was successfully linked.")
self.user.refresh_from_db()
self.assertEquals(self.user.reddit_access_token, "1001010412")
self.assertEquals(self.user.reddit_refresh_token, "134510143")
self.assertEquals(cache.get(f"{self.user.email}-reddit-auth"), None)
def test_error(self):
params = {"error": "Denied authorization"}
url = f"{self.base_url}?{urlencode(params)}"
response = self.client.get(url)
self.assertEquals(response.status_code, 200)
self.assertContains(response, "Denied authorization")
def test_invalid_state(self):
cache.set(f"{self.user.email}-reddit-auth", str(uuid4()))
params = {"code": "Valid code", "state": "Invalid state"}
url = f"{self.base_url}?{urlencode(params)}"
response = self.client.get(url)
self.assertEquals(response.status_code, 200)
self.assertContains(
response, "The saved state for Reddit authorization did not match"
)
def test_stream_error(self):
self.mocked_post.side_effect = StreamTooManyException
cache.set(f"{self.user.email}-reddit-auth", self.state)
params = {"state": self.state, "code": "Valid code"}
url = f"{self.base_url}?{urlencode(params)}"
response = self.client.get(url)
self.mocked_post.assert_called_once()
self.assertEquals(response.status_code, 200)
self.assertContains(response, "Too many requests")
self.user.refresh_from_db()
self.assertEquals(self.user.reddit_access_token, None)
self.assertEquals(self.user.reddit_refresh_token, None)
self.assertEquals(cache.get(f"{self.user.email}-reddit-auth"), self.state)
def test_unexpected_json(self):
self.mocked_post.return_value.json.return_value = {"message": "Happy eastern"}
cache.set(f"{self.user.email}-reddit-auth", self.state)
params = {"state": self.state, "code": "Valid code"}
url = f"{self.base_url}?{urlencode(params)}"
response = self.client.get(url)
self.mocked_post.assert_called_once()
self.assertEquals(response.status_code, 200)
self.assertContains(response, "Access and refresh token not found in response")
self.user.refresh_from_db()
self.assertEquals(self.user.reddit_access_token, None)
self.assertEquals(self.user.reddit_refresh_token, None)
self.assertEquals(cache.get(f"{self.user.email}-reddit-auth"), self.state)

View file

@ -1,22 +1,24 @@
from django.test import TestCase from django.test import TestCase
from django_celery_beat.models import PeriodicTask from django_celery_beat.models import IntervalSchedule, PeriodicTask
from newsreader.accounts.models import User from newsreader.accounts.tests.factories import UserFactory
class UserTestCase(TestCase): class UserTestCase(TestCase):
def test_task_is_created(self):
user = User.objects.create(email="durp@burp.nl", task=None)
task = PeriodicTask.objects.get(name=f"{user.email}-collection-task")
user.refresh_from_db()
self.assertEquals(task, user.task)
self.assertEquals(PeriodicTask.objects.count(), 1)
def test_task_is_deleted(self): def test_task_is_deleted(self):
user = User.objects.create(email="durp@burp.nl", task=None) user = UserFactory(email="durp@burp.nl")
interval = IntervalSchedule.objects.create(
every=1, period=IntervalSchedule.HOURS
)
PeriodicTask.objects.create(
name=f"{user.email}-feed", task="FeedTask", interval=interval
)
PeriodicTask.objects.create(
name=f"{user.email}-timeline", task="TwitterTimelineTask", interval=interval
)
user.delete() user.delete()
self.assertEquals(PeriodicTask.objects.count(), 0) self.assertEquals(PeriodicTask.objects.count(), 0)

View file

@ -5,6 +5,7 @@ from newsreader.accounts.views import (
ActivationCompleteView, ActivationCompleteView,
ActivationResendView, ActivationResendView,
ActivationView, ActivationView,
IntegrationsView,
LoginView, LoginView,
LogoutView, LogoutView,
PasswordChangeView, PasswordChangeView,
@ -12,18 +13,24 @@ from newsreader.accounts.views import (
PasswordResetConfirmView, PasswordResetConfirmView,
PasswordResetDoneView, PasswordResetDoneView,
PasswordResetView, PasswordResetView,
RedditRevokeRedirectView,
RedditTemplateView, RedditTemplateView,
RedditTokenRedirectView, RedditTokenRedirectView,
RegistrationClosedView, RegistrationClosedView,
RegistrationCompleteView, RegistrationCompleteView,
RegistrationView, RegistrationView,
SettingsView, SettingsView,
TwitterAuthRedirectView,
TwitterRevokeRedirectView,
TwitterTemplateView,
) )
urlpatterns = [ urlpatterns = [
# Auth
path("login/", LoginView.as_view(), name="login"), path("login/", LoginView.as_view(), name="login"),
path("logout/", LogoutView.as_view(), name="logout"), path("logout/", LogoutView.as_view(), name="logout"),
# Register
path("register/", RegistrationView.as_view(), name="register"), path("register/", RegistrationView.as_view(), name="register"),
path( path(
"register/complete/", "register/complete/",
@ -41,6 +48,7 @@ urlpatterns = [
ActivationView.as_view(), ActivationView.as_view(),
name="activate", name="activate",
), ),
# Password
path("password-reset/", PasswordResetView.as_view(), name="password-reset"), path("password-reset/", PasswordResetView.as_view(), name="password-reset"),
path( path(
"password-reset/done/", "password-reset/done/",
@ -62,15 +70,42 @@ urlpatterns = [
login_required(PasswordChangeView.as_view()), login_required(PasswordChangeView.as_view()),
name="password-change", name="password-change",
), ),
path("settings/", login_required(SettingsView.as_view()), name="settings"), # Integrations
path( path(
"settings/reddit/callback/", "settings/integrations/reddit/callback/",
login_required(RedditTemplateView.as_view()), login_required(RedditTemplateView.as_view()),
name="reddit-template", name="reddit-template",
), ),
path( path(
"settings/reddit/refresh/", "settings/integrations/reddit/refresh/",
login_required(RedditTokenRedirectView.as_view()), login_required(RedditTokenRedirectView.as_view()),
name="reddit-refresh", name="reddit-refresh",
), ),
path(
"settings/integrations/reddit/revoke/",
login_required(RedditRevokeRedirectView.as_view()),
name="reddit-revoke",
),
path(
"settings/integrations/twitter/auth/",
login_required(TwitterAuthRedirectView.as_view()),
name="twitter-auth",
),
path(
"settings/integrations/twitter/callback/",
login_required(TwitterTemplateView.as_view()),
name="twitter-template",
),
path(
"settings/integrations/twitter/revoke/",
login_required(TwitterRevokeRedirectView.as_view()),
name="twitter-revoke",
),
path(
    "settings/integrations/",
    login_required(IntegrationsView.as_view()),
    name="integrations",
),
# Settings
path("settings/", login_required(SettingsView.as_view()), name="settings"),
] ]

View file

@ -1,210 +0,0 @@
from django.contrib import messages
from django.contrib.auth import views as django_views
from django.core.cache import cache
from django.shortcuts import render
from django.urls import reverse_lazy
from django.utils.translation import gettext as _
from django.views.generic import RedirectView, TemplateView
from django.views.generic.edit import FormView, ModelFormMixin
from registration.backends.default import views as registration_views
from newsreader.accounts.forms import UserSettingsForm
from newsreader.accounts.models import User
from newsreader.news.collection.exceptions import StreamException
from newsreader.news.collection.reddit import (
get_reddit_access_token,
get_reddit_authorization_url,
)
from newsreader.news.collection.tasks import RedditTokenTask
class LoginView(django_views.LoginView):
    """Login page rendered with the ``accounts/views/login.html`` template."""

    # NOTE(review): Django's LoginView appears to resolve its redirect via
    # ``get_success_url``/``LOGIN_REDIRECT_URL`` rather than ``success_url``;
    # confirm this attribute actually has an effect.
    success_url = reverse_lazy("index")
    template_name = "accounts/views/login.html"
class LogoutView(django_views.LogoutView):
    """Logout view whose ``next_page`` points at the ``accounts:login`` route."""

    next_page = reverse_lazy("accounts:login")
# RegistrationView shows a registration form and sends the email
# RegistrationCompleteView shows after filling in the registration form
# ActivationView is sent within the activation email and activates the account
# ActivationCompleteView shows the success screen when activation was successful
# ActivationResendView can be used when activation links are expired
# RegistrationClosedView shows when registration is disabled
class RegistrationView(registration_views.RegistrationView):
    """Registration form view.

    Points ``success_url`` at the ``accounts:register-complete`` route and
    ``disallowed_url`` at the ``accounts:register-closed`` route.
    """

    template_name = "registration/registration_form.html"
    success_url = reverse_lazy("accounts:register-complete")
    disallowed_url = reverse_lazy("accounts:register-closed")
class RegistrationCompleteView(TemplateView):
    """Static page rendered from ``registration/registration_complete.html``."""

    template_name = "registration/registration_complete.html"
class RegistrationClosedView(TemplateView):
    """Static page rendered from ``registration/registration_closed.html``."""

    template_name = "registration/registration_closed.html"
# Redirects or renders failed activation template
class ActivationView(registration_views.ActivationView):
    """Activation view; uses the activation-failure template for rendering."""

    template_name = "registration/activation_failure.html"

    def get_success_url(self, user):
        """Return the ``(route name, args, kwargs)`` triple for the success page."""
        route_name = "accounts:activate-complete"
        return (route_name, (), {})
class ActivationCompleteView(TemplateView):
    """Static page rendered from ``registration/activation_complete.html``."""

    template_name = "registration/activation_complete.html"
# Renders activation form resend or resend_activation_complete
class ActivationResendView(registration_views.ResendActivationView):
    """Form view for requesting a new activation email."""

    template_name = "registration/activation_resend_form.html"

    def render_form_submitted_template(self, form):
        """
        Render the resend-complete template, exposing the submitted email.
        """
        return render(
            self.request,
            "registration/activation_resend_complete.html",
            {"email": form.cleaned_data["email"]},
        )
# PasswordResetView sends the mail
# PasswordResetDoneView shows a success message for the above
# PasswordResetConfirmView checks the link the user clicked and
# prompts for a new password
# PasswordResetCompleteView shows a success message for the above
class PasswordResetView(django_views.PasswordResetView):
    """Password reset request form; redirects to the reset-done route."""

    success_url = reverse_lazy("accounts:password-reset-done")
    template_name = "password-reset/password-reset.html"
    email_template_name = "password-reset/password-reset-email.html"
    subject_template_name = "password-reset/password-reset-subject.txt"
class PasswordResetDoneView(django_views.PasswordResetDoneView):
    """Page rendered from ``password-reset/password-reset-done.html``."""

    template_name = "password-reset/password-reset-done.html"
class PasswordResetConfirmView(django_views.PasswordResetConfirmView):
    """New-password form; redirects to the reset-complete route on success."""

    success_url = reverse_lazy("accounts:password-reset-complete")
    template_name = "password-reset/password-reset-confirm.html"
class PasswordResetCompleteView(django_views.PasswordResetCompleteView):
    """Page rendered from ``password-reset/password-reset-complete.html``."""

    template_name = "password-reset/password-reset-complete.html"
class PasswordChangeView(django_views.PasswordChangeView):
    """Password change form; redirects to the settings route on success."""

    success_url = reverse_lazy("accounts:settings")
    template_name = "accounts/views/password-change.html"
class SettingsView(ModelFormMixin, FormView):
    """Settings form bound to the requesting user.

    The edited object is always ``request.user``. The template context also
    carries the Reddit authorization and refresh URLs (each may be ``None``).
    """

    template_name = "accounts/views/settings.html"
    success_url = reverse_lazy("accounts:settings")
    form_class = UserSettingsForm
    model = User

    def get(self, request, *args, **kwargs):
        """Bind ``self.object`` before rendering the form."""
        self.object = self.get_object()
        return super().get(request, *args, **kwargs)

    def post(self, request, *args, **kwargs):
        """Bind ``self.object`` before processing the submitted form.

        ``SingleObjectMixin.get_context_data`` reads ``self.object``; without
        this an invalid submission fails with ``AttributeError`` while the
        form is being re-rendered.
        """
        self.object = self.get_object()
        return super().post(request, *args, **kwargs)

    def get_object(self, **kwargs):
        """Return the user making the request."""
        return self.request.user

    def get_context_data(self, **kwargs):
        """Extend the context with the Reddit authorization/refresh URLs."""
        user = self.request.user
        reddit_authorization_url = None
        reddit_refresh_url = None
        reddit_task_active = cache.get(f"{user.email}-reddit-refresh")

        # Offer a refresh only when a refresh token exists, no access token
        # is stored and no refresh task is flagged in the cache.
        if (
            user.reddit_refresh_token
            and not user.reddit_access_token
            and not reddit_task_active
        ):
            reddit_refresh_url = reverse_lazy("accounts:reddit-refresh")

        if not user.reddit_refresh_token:
            reddit_authorization_url = get_reddit_authorization_url(user)

        return {
            **super().get_context_data(**kwargs),
            "reddit_authorization_url": reddit_authorization_url,
            "reddit_refresh_url": reddit_refresh_url,
        }

    def get_form_kwargs(self):
        """Pass the requesting user as the form's ``instance``."""
        return {**super().get_form_kwargs(), "instance": self.request.user}
class RedditTemplateView(TemplateView):
    """Reddit authorization callback page.

    Reads ``error``, ``state`` and ``code`` from the query string, checks the
    state against the cached value and exchanges the code for tokens.
    """

    template_name = "accounts/views/reddit.html"

    def get(self, request, *args, **kwargs):
        """Render the callback template with either tokens or an error."""
        context = self.get_context_data(**kwargs)
        query = request.GET

        error = query.get("error")
        state = query.get("state")
        code = query.get("code")

        if error:
            return self.render_to_response({**context, "error": error})

        # Nothing to exchange without both parameters.
        if not (code and state):
            return self.render_to_response(context)

        if state != cache.get(f"{request.user.email}-reddit-auth"):
            mismatch = "The saved state for Reddit authorization did not match"
            return self.render_to_response({**context, "error": mismatch})

        try:
            access_token, refresh_token = get_reddit_access_token(code, request.user)

            tokens = {"access_token": access_token, "refresh_token": refresh_token}
            return self.render_to_response({**context, **tokens})
        except StreamException as e:
            return self.render_to_response({**context, "error": str(e)})
        except KeyError:
            missing = "Access and refresh token not found in response"
            return self.render_to_response({**context, "error": missing})
class RedditTokenRedirectView(RedirectView):
    """Queue a Reddit token refresh task, then redirect to the settings page."""

    url = reverse_lazy("accounts:settings")

    def get(self, request, *args, **kwargs):
        """Start ``RedditTokenTask`` unless the cache flags one as active."""
        response = super().get(request, *args, **kwargs)
        user = request.user
        flag_key = f"{user.email}-reddit-refresh"

        if cache.get(flag_key):
            messages.error(request, _("Unable to retrieve token"))
            return response

        RedditTokenTask.delay(user.pk)
        messages.success(request, _("Access token is being retrieved"))
        # Flag the task as active for 300 seconds.
        cache.set(flag_key, 1, 300)
        return response

View file

@ -0,0 +1,26 @@
from newsreader.accounts.views.auth import LoginView, LogoutView
from newsreader.accounts.views.integrations import (
IntegrationsView,
RedditRevokeRedirectView,
RedditTemplateView,
RedditTokenRedirectView,
TwitterAuthRedirectView,
TwitterRevokeRedirectView,
TwitterTemplateView,
)
from newsreader.accounts.views.password import (
PasswordChangeView,
PasswordResetCompleteView,
PasswordResetConfirmView,
PasswordResetDoneView,
PasswordResetView,
)
from newsreader.accounts.views.registration import (
ActivationCompleteView,
ActivationResendView,
ActivationView,
RegistrationClosedView,
RegistrationCompleteView,
RegistrationView,
)
from newsreader.accounts.views.settings import SettingsView

View file

@ -0,0 +1,11 @@
from django.contrib.auth import views as django_views
from django.urls import reverse_lazy
class LoginView(django_views.LoginView):
    """Login page rendered with the ``accounts/views/login.html`` template."""

    # NOTE(review): Django's LoginView appears to resolve its redirect via
    # ``get_success_url``/``LOGIN_REDIRECT_URL`` rather than ``success_url``;
    # confirm this attribute actually has an effect.
    success_url = reverse_lazy("index")
    template_name = "accounts/views/login.html"
class LogoutView(django_views.LogoutView):
    """Logout view whose ``next_page`` points at the ``accounts:login`` route."""

    next_page = reverse_lazy("accounts:login")

View file

@ -0,0 +1,343 @@
import logging
from urllib.parse import parse_qs, urlencode
from django.conf import settings
from django.contrib import messages
from django.core.cache import cache
from django.shortcuts import redirect
from django.urls import reverse_lazy
from django.utils.translation import gettext as _
from django.views.generic import RedirectView, TemplateView
from requests_oauthlib import OAuth1 as OAuth
from newsreader.news.collection.exceptions import StreamException
from newsreader.news.collection.reddit import (
get_reddit_access_token,
get_reddit_authorization_url,
revoke_reddit_token,
)
from newsreader.news.collection.tasks import RedditTokenTask
from newsreader.news.collection.twitter import (
TWITTER_ACCESS_TOKEN_URL,
TWITTER_AUTH_URL,
TWITTER_REQUEST_TOKEN_URL,
TWITTER_REVOKE_URL,
)
from newsreader.news.collection.utils import post
logger = logging.getLogger(__name__)
class IntegrationsView(TemplateView):
    """Overview page exposing the Reddit and Twitter integration URLs."""

    template_name = "accounts/views/integrations.html"

    def get_context_data(self, **kwargs):
        """Merge the base context with the Reddit and Twitter contexts."""
        base = super().get_context_data(**kwargs)
        reddit = self.get_reddit_context(**kwargs)
        twitter = self.get_twitter_context(**kwargs)
        return {**base, **reddit, **twitter}

    def get_reddit_context(self, **kwargs):
        """Return the Reddit authorization, refresh and revoke URLs.

        Each value is ``None`` when the corresponding action is not offered.
        """
        user = self.request.user
        task_active = cache.get(f"{user.email}-reddit-refresh")

        # Refreshing is offered only with a refresh token, without an access
        # token, and while no refresh task is flagged in the cache.
        can_refresh = (
            user.reddit_refresh_token
            and not user.reddit_access_token
            and not task_active
        )
        refresh_url = reverse_lazy("accounts:reddit-refresh") if can_refresh else None

        authorization_url = None
        if not user.reddit_refresh_token:
            authorization_url = get_reddit_authorization_url(user)

        # Revoking is offered exactly when authorizing is not.
        revoke_url = (
            None if authorization_url else reverse_lazy("accounts:reddit-revoke")
        )

        return {
            "reddit_authorization_url": authorization_url,
            "reddit_refresh_url": refresh_url,
            "reddit_revoke_url": revoke_url,
        }

    def get_twitter_context(self, **kwargs):
        """Return the Twitter auth URL and, for linked users, the revoke URL."""
        has_auth = self.request.user.has_twitter_auth
        revoke_url = reverse_lazy("accounts:twitter-revoke") if has_auth else None

        return {
            "twitter_auth_url": reverse_lazy("accounts:twitter-auth"),
            "twitter_revoke_url": revoke_url,
        }
class RedditTemplateView(TemplateView):
    """Reddit authorization callback page.

    Reads ``error``, ``state`` and ``code`` from the query string, checks the
    state against the cached value and exchanges the code for tokens.
    """

    template_name = "accounts/views/reddit.html"

    def get(self, request, *args, **kwargs):
        """Render the callback template with either tokens or an error."""
        context = self.get_context_data(**kwargs)
        query = request.GET

        error = query.get("error")
        state = query.get("state")
        code = query.get("code")

        if error:
            return self.render_to_response({**context, "error": error})

        # Nothing to exchange without both parameters.
        if not (code and state):
            return self.render_to_response(context)

        if state != cache.get(f"{request.user.email}-reddit-auth"):
            mismatch = _("The saved state for Reddit authorization did not match")
            return self.render_to_response({**context, "error": mismatch})

        try:
            access_token, refresh_token = get_reddit_access_token(code, request.user)

            tokens = {"access_token": access_token, "refresh_token": refresh_token}
            return self.render_to_response({**context, **tokens})
        except StreamException as e:
            return self.render_to_response({**context, "error": str(e)})
        except KeyError:
            missing = _("Access and refresh token not found in response")
            return self.render_to_response({**context, "error": missing})
class RedditTokenRedirectView(RedirectView):
    """Queue a Reddit token refresh task, then redirect to the integrations page."""

    url = reverse_lazy("accounts:integrations")

    def get(self, request, *args, **kwargs):
        """Start ``RedditTokenTask`` unless the cache flags one as active."""
        response = super().get(request, *args, **kwargs)
        user = request.user
        flag_key = f"{user.email}-reddit-refresh"

        if cache.get(flag_key):
            messages.error(request, _("Unable to retrieve token"))
            return response

        RedditTokenTask.delay(user.pk)
        messages.success(request, _("Access token is being retrieved"))
        # Flag the task as active for 300 seconds.
        cache.set(flag_key, 1, 300)
        return response
class RedditRevokeRedirectView(RedirectView):
    """Revoke the user's Reddit token, then redirect to the integrations page.

    The outcome is reported through the messages framework; the stored
    tokens are cleared only when ``revoke_reddit_token`` reports success.
    """

    url = reverse_lazy("accounts:integrations")

    def get(self, request, *args, **kwargs):
        """Attempt the revocation and return the redirect response."""
        response = super().get(request, *args, **kwargs)
        user = request.user

        if not user.reddit_refresh_token:
            messages.error(request, _("No reddit account is linked to this account"))
            return response

        try:
            is_revoked = revoke_reddit_token(user)
        except StreamException:
            # Lazy %-style arguments leave the formatting to the logger.
            logger.exception("Unable to revoke reddit token for %s", user.pk)
            messages.error(request, _("Unable to revoke reddit token"))
            return response

        if not is_revoked:
            messages.error(request, _("Unable to revoke reddit token"))
            return response

        user.reddit_access_token = None
        user.reddit_refresh_token = None
        user.save()

        messages.success(request, _("Reddit account deauthorized"))
        return response
class TwitterRevokeRedirectView(RedirectView):
    """Revoke the user's Twitter credentials, then redirect to the integrations page.

    The outcome is reported through the messages framework; the stored
    credentials are cleared only when the revoke request does not raise.
    """

    url = reverse_lazy("accounts:integrations")

    def get(self, request, *args, **kwargs):
        """Post to the Twitter revoke endpoint and return the redirect response."""
        user = request.user

        if not user.has_twitter_auth:
            messages.error(request, _("No twitter credentials found"))
            return super().get(request, *args, **kwargs)

        # Sign the revoke request with the credentials stored on the user.
        oauth = OAuth(
            settings.TWITTER_CONSUMER_ID,
            client_secret=settings.TWITTER_CONSUMER_SECRET,
            resource_owner_key=user.twitter_oauth_token,
            resource_owner_secret=user.twitter_oauth_token_secret,
        )

        try:
            post(TWITTER_REVOKE_URL, auth=oauth)
        except StreamException:
            logger.exception("Failed revoking Twitter account")
            messages.error(request, _("Unable to revoke Twitter account"))
            return super().get(request, *args, **kwargs)

        user.twitter_oauth_token = None
        user.twitter_oauth_token_secret = None
        user.save()

        messages.success(request, _("Twitter account revoked"))
        return super().get(request, *args, **kwargs)
class TwitterAuthRedirectView(RedirectView):
    """Request a Twitter request token and send the user to Twitter to authorize.

    On failure the user is redirected to the integrations page with an error
    message instead.
    """

    url = reverse_lazy("accounts:integrations")

    def get(self, request, *args, **kwargs):
        """Fetch the request token, cache it per user and redirect to Twitter."""
        oauth = OAuth(
            settings.TWITTER_CONSUMER_ID,
            client_secret=settings.TWITTER_CONSUMER_SECRET,
            callback_uri=settings.TWITTER_REDIRECT_URL,
        )

        try:
            response = post(TWITTER_REQUEST_TOKEN_URL, auth=oauth)
        except StreamException:
            logger.exception("Failed requesting Twitter authentication token")
            messages.error(request, _("Unable to retrieve initial Twitter token"))
            return super().get(request, *args, **kwargs)

        # The response body is parsed as a query string.
        credentials = parse_qs(response.text)

        try:
            token, secret = (
                credentials[name][0] for name in ("oauth_token", "oauth_token_secret")
            )
        except KeyError:
            logger.exception("No credentials found in response")
            messages.error(request, _("Unable to retrieve initial Twitter token"))
            return super().get(request, *args, **kwargs)

        # Stored under per-user cache keys.
        cache_prefix = f"twitter-{request.user.email}"
        cache.set_many(
            {f"{cache_prefix}-token": token, f"{cache_prefix}-secret": secret}
        )

        query = urlencode({"oauth_token": token})
        return redirect(f"{TWITTER_AUTH_URL}/?{query}")
class TwitterTemplateView(TemplateView):
    """Twitter authorization callback page.

    Checks the returned token against the cached request token, exchanges it
    for access credentials and stores those on the user.
    """

    template_name = "accounts/views/twitter.html"

    def _render_failure(self, context, message):
        """Render the template with ``message`` as error and ``authorized`` False."""
        return self.render_to_response(
            {**context, "error": message, "authorized": False}
        )

    def get(self, request, *args, **kwargs):
        """Handle the callback and render the authorization outcome."""
        context = self.get_context_data(**kwargs)
        query = request.GET

        denied = query.get("denied", False)
        oauth_token = query.get("oauth_token")
        oauth_verifier = query.get("oauth_verifier")

        if denied:
            return self._render_failure(context, _("Twitter authorization failed"))

        token_key = f"twitter-{request.user.email}-token"
        secret_key = f"twitter-{request.user.email}-secret"

        cached_token = cache.get(token_key)

        if oauth_token != cached_token:
            return self._render_failure(context, _("OAuth tokens failed to match"))

        cached_secret = cache.get(secret_key)

        if not (cached_token and cached_secret):
            return self._render_failure(
                context, _("No matching tokens found for this user")
            )

        # Sign the exchange with the cached request token and the verifier.
        oauth = OAuth(
            settings.TWITTER_CONSUMER_ID,
            client_secret=settings.TWITTER_CONSUMER_SECRET,
            resource_owner_key=cached_token,
            resource_owner_secret=cached_secret,
            verifier=oauth_verifier,
        )

        try:
            response = post(TWITTER_ACCESS_TOKEN_URL, auth=oauth)
        except StreamException:
            logger.exception("Failed requesting Twitter access token")
            return self._render_failure(context, _("Failed requesting access token"))

        # The response body is parsed as a query string.
        credentials = parse_qs(response.text)

        try:
            access_token = credentials["oauth_token"][0]
            access_secret = credentials["oauth_token_secret"][0]
        except KeyError:
            logger.exception("No credentials in Twitter response")
            return self._render_failure(
                context, _("No credentials found in Twitter response")
            )

        user = request.user
        user.twitter_oauth_token = access_token
        user.twitter_oauth_token_secret = access_secret
        user.save()

        # Drop the cached request token and secret now that they were used.
        cache.delete_many([token_key, secret_key])

        return self.render_to_response({**context, "error": None, "authorized": True})

View file

@ -0,0 +1,37 @@
from django.contrib.auth import views as django_views
from django.urls import reverse_lazy
from newsreader.news.collection.reddit import (
get_reddit_access_token,
get_reddit_authorization_url,
)
# PasswordResetView sends the mail
# PasswordResetDoneView shows a success message for the above
# PasswordResetConfirmView checks the link the user clicked and
# prompts for a new password
# PasswordResetCompleteView shows a success message for the above
class PasswordResetView(django_views.PasswordResetView):
    """Password reset request form; redirects to the reset-done route."""

    success_url = reverse_lazy("accounts:password-reset-done")
    template_name = "password-reset/password-reset.html"
    email_template_name = "password-reset/password-reset-email.html"
    subject_template_name = "password-reset/password-reset-subject.txt"
class PasswordResetDoneView(django_views.PasswordResetDoneView):
    """Page rendered from ``password-reset/password-reset-done.html``."""

    template_name = "password-reset/password-reset-done.html"
class PasswordResetConfirmView(django_views.PasswordResetConfirmView):
    """New-password form; redirects to the reset-complete route on success."""

    success_url = reverse_lazy("accounts:password-reset-complete")
    template_name = "password-reset/password-reset-confirm.html"
class PasswordResetCompleteView(django_views.PasswordResetCompleteView):
    """Page rendered from ``password-reset/password-reset-complete.html``."""

    template_name = "password-reset/password-reset-complete.html"
class PasswordChangeView(django_views.PasswordChangeView):
    """Password change form; redirects to the settings route on success."""

    success_url = reverse_lazy("accounts:settings")
    template_name = "accounts/views/password-change.html"

View file

@ -0,0 +1,59 @@
from django.shortcuts import render
from django.urls import reverse_lazy
from django.views.generic import TemplateView
from registration.backends.default import views as registration_views
from newsreader.news.collection.reddit import (
get_reddit_access_token,
get_reddit_authorization_url,
)
# RegistrationView shows a registration form and sends the email
# RegistrationCompleteView shows after filling in the registration form
# ActivationView is sent within the activation email and activates the account
# ActivationCompleteView shows the success screen when activation was successful
# ActivationResendView can be used when activation links are expired
# RegistrationClosedView shows when registration is disabled
class RegistrationView(registration_views.RegistrationView):
    """Registration form view.

    Points ``success_url`` at the ``accounts:register-complete`` route and
    ``disallowed_url`` at the ``accounts:register-closed`` route.
    """

    template_name = "registration/registration_form.html"
    success_url = reverse_lazy("accounts:register-complete")
    disallowed_url = reverse_lazy("accounts:register-closed")
class RegistrationCompleteView(TemplateView):
    """Static page rendered from ``registration/registration_complete.html``."""

    template_name = "registration/registration_complete.html"
class RegistrationClosedView(TemplateView):
    """Static page rendered from ``registration/registration_closed.html``."""

    template_name = "registration/registration_closed.html"
# Redirects or renders failed activation template
class ActivationView(registration_views.ActivationView):
    """Activation view; uses the activation-failure template for rendering."""

    template_name = "registration/activation_failure.html"

    def get_success_url(self, user):
        """Return the ``(route name, args, kwargs)`` triple for the success page."""
        route_name = "accounts:activate-complete"
        return (route_name, (), {})
class ActivationCompleteView(TemplateView):
    """Static page rendered from ``registration/activation_complete.html``."""

    template_name = "registration/activation_complete.html"
# Renders activation form resend or resend_activation_complete
class ActivationResendView(registration_views.ResendActivationView):
    """Form view for requesting a new activation email."""

    template_name = "registration/activation_resend_form.html"

    def render_form_submitted_template(self, form):
        """
        Render the resend-complete template, exposing the submitted email.
        """
        return render(
            self.request,
            "registration/activation_resend_complete.html",
            {"email": form.cleaned_data["email"]},
        )

View file

@ -0,0 +1,26 @@
from django.urls import reverse_lazy
from django.views.generic.edit import FormView, ModelFormMixin
from newsreader.accounts.forms import UserSettingsForm
from newsreader.accounts.models import User
from newsreader.news.collection.reddit import (
get_reddit_access_token,
get_reddit_authorization_url,
)
class SettingsView(ModelFormMixin, FormView):
    """Settings form bound to the requesting user.

    The edited object is always ``request.user``; it is never looked up
    from URL arguments.
    """

    template_name = "accounts/views/settings.html"
    success_url = reverse_lazy("accounts:settings")
    form_class = UserSettingsForm
    model = User

    def get(self, request, *args, **kwargs):
        """Bind ``self.object`` before rendering the form."""
        self.object = self.get_object()
        return super().get(request, *args, **kwargs)

    def post(self, request, *args, **kwargs):
        """Bind ``self.object`` before processing the submitted form.

        ``SingleObjectMixin.get_context_data`` reads ``self.object``; without
        this an invalid submission fails with ``AttributeError`` while the
        form is being re-rendered.
        """
        self.object = self.get_object()
        return super().post(request, *args, **kwargs)

    def get_object(self, **kwargs):
        """Return the user making the request."""
        return self.request.user

    def get_form_kwargs(self):
        """Pass the requesting user as the form's ``instance``."""
        return {**super().get_form_kwargs(), "instance": self.request.user}

View file

@ -129,19 +129,14 @@ LOGGING = {
"class": "logging.StreamHandler", "class": "logging.StreamHandler",
"formatter": "timestamped", "formatter": "timestamped",
}, },
"mail_admins": { "celery": {
"level": "ERROR",
"filters": ["require_debug_false"],
"class": "django.utils.log.AdminEmailHandler",
},
"syslog": {
"level": "INFO", "level": "INFO",
"filters": ["require_debug_false"], "filters": ["require_debug_false"],
"class": "logging.handlers.SysLogHandler", "class": "logging.handlers.SysLogHandler",
"formatter": "syslog", "formatter": "syslog",
"address": "/dev/log", "address": "/dev/log",
}, },
"syslog_errors": { "syslog": {
"level": "ERROR", "level": "ERROR",
"filters": ["require_debug_false"], "filters": ["require_debug_false"],
"class": "logging.handlers.SysLogHandler", "class": "logging.handlers.SysLogHandler",
@ -150,26 +145,13 @@ LOGGING = {
}, },
}, },
"loggers": { "loggers": {
"django": { "django": {"handlers": ["console", "syslog"], "level": "INFO"},
"handlers": ["console", "mail_admins", "syslog_errors"],
"level": "WARNING",
},
"django.server": { "django.server": {
"handlers": ["console", "syslog_errors"], "handlers": ["console", "syslog"],
"level": "INFO",
"propagate": False,
},
"django.request": {
"handlers": ["console", "syslog_errors"],
"level": "INFO",
"propagate": False,
},
"celery": {"handlers": ["syslog", "console"], "level": "INFO"},
"celery.task": {
"handlers": ["syslog", "console"],
"level": "INFO", "level": "INFO",
"propagate": False, "propagate": False,
}, },
"celery": {"handlers": ["celery", "console"], "level": "INFO"},
"newsreader": {"handlers": ["syslog", "console"], "level": "INFO"}, "newsreader": {"handlers": ["syslog", "console"], "level": "INFO"},
}, },
} }
@ -219,7 +201,16 @@ VERSION = get_current_version()
# Reddit integration # Reddit integration
REDDIT_CLIENT_ID = "CLIENT_ID" REDDIT_CLIENT_ID = "CLIENT_ID"
REDDIT_CLIENT_SECRET = "CLIENT_SECRET" REDDIT_CLIENT_SECRET = "CLIENT_SECRET"
REDDIT_REDIRECT_URL = "http://127.0.0.1:8000/accounts/settings/reddit/callback/" REDDIT_REDIRECT_URL = (
"http://127.0.0.1:8000/accounts/settings/integrations/reddit/callback/"
)
# Twitter integration
TWITTER_CONSUMER_ID = "CONSUMER_ID"
TWITTER_CONSUMER_SECRET = "CONSUMER_SECRET"
TWITTER_REDIRECT_URL = (
"http://127.0.0.1:8000/accounts/settings/integrations/twitter/callback/"
)
# Third party settings # Third party settings
AXES_HANDLER = "axes.handlers.cache.AxesCacheHandler" AXES_HANDLER = "axes.handlers.cache.AxesCacheHandler"

View file

@ -46,9 +46,14 @@ TEMPLATES = [
] ]
# Reddit integration # Reddit integration
REDDIT_CLIENT_ID = os.environ["REDDIT_CLIENT_ID"] REDDIT_CLIENT_ID = os.environ.get("REDDIT_CLIENT_ID", "")
REDDIT_CLIENT_SECRET = os.environ["REDDIT_CLIENT_SECRET"] REDDIT_CLIENT_SECRET = os.environ.get("REDDIT_CLIENT_SECRET", "")
REDDIT_REDIRECT_URL = os.environ["REDDIT_CALLBACK_URL"] REDDIT_REDIRECT_URL = os.environ.get("REDDIT_CALLBACK_URL", "")
# Twitter integration
TWITTER_CONSUMER_ID = os.environ.get("TWITTER_CONSUMER_ID", "")
TWITTER_CONSUMER_SECRET = os.environ.get("TWITTER_CONSUMER_SECRET", "")
TWITTER_REDIRECT_URL = os.environ.get("TWITTER_REDIRECT_URL", "")
# Third party settings # Third party settings
AXES_HANDLER = "axes.handlers.database.AxesDatabaseHandler" AXES_HANDLER = "axes.handlers.database.AxesDatabaseHandler"

File diff suppressed because it is too large Load diff

View file

@ -47,7 +47,7 @@
"user" : 2, "user" : 2,
"succeeded" : true, "succeeded" : true,
"modified" : "2019-07-20T11:28:16.473Z", "modified" : "2019-07-20T11:28:16.473Z",
"last_suceeded" : "2019-07-20T11:28:16.316Z", "last_run" : "2019-07-20T11:28:16.316Z",
"name" : "Hackers News", "name" : "Hackers News",
"website_url" : null, "website_url" : null,
"created" : "2019-07-14T13:08:10.374Z", "created" : "2019-07-14T13:08:10.374Z",
@ -65,7 +65,7 @@
"error" : null, "error" : null,
"user" : 2, "user" : 2,
"succeeded" : true, "succeeded" : true,
"last_suceeded" : "2019-07-20T11:28:15.691Z", "last_run" : "2019-07-20T11:28:15.691Z",
"name" : "BBC", "name" : "BBC",
"modified" : "2019-07-20T12:07:49.164Z", "modified" : "2019-07-20T12:07:49.164Z",
"timezone" : "UTC", "timezone" : "UTC",
@ -85,7 +85,7 @@
"website_url" : null, "website_url" : null,
"name" : "Ars Technica", "name" : "Ars Technica",
"succeeded" : true, "succeeded" : true,
"last_suceeded" : "2019-07-20T11:28:15.986Z", "last_run" : "2019-07-20T11:28:15.986Z",
"modified" : "2019-07-20T11:28:16.033Z", "modified" : "2019-07-20T11:28:16.033Z",
"user" : 2 "user" : 2
}, },
@ -102,7 +102,7 @@
"user" : 2, "user" : 2,
"name" : "The Guardian", "name" : "The Guardian",
"succeeded" : true, "succeeded" : true,
"last_suceeded" : "2019-07-20T11:28:16.078Z", "last_run" : "2019-07-20T11:28:16.078Z",
"modified" : "2019-07-20T12:07:44.292Z", "modified" : "2019-07-20T12:07:44.292Z",
"created" : "2019-07-20T11:25:02.089Z", "created" : "2019-07-20T11:25:02.089Z",
"website_url" : null, "website_url" : null,
@ -119,7 +119,7 @@
"website_url" : null, "website_url" : null,
"created" : "2019-07-20T11:25:30.121Z", "created" : "2019-07-20T11:25:30.121Z",
"user" : 2, "user" : 2,
"last_suceeded" : "2019-07-20T11:28:15.860Z", "last_run" : "2019-07-20T11:28:15.860Z",
"succeeded" : true, "succeeded" : true,
"modified" : "2019-07-20T12:07:28.473Z", "modified" : "2019-07-20T12:07:28.473Z",
"name" : "Tweakers" "name" : "Tweakers"
@ -139,7 +139,7 @@
"website_url" : null, "website_url" : null,
"timezone" : "UTC", "timezone" : "UTC",
"user" : 2, "user" : 2,
"last_suceeded" : "2019-07-20T11:28:16.034Z", "last_run" : "2019-07-20T11:28:16.034Z",
"succeeded" : true, "succeeded" : true,
"modified" : "2019-07-20T12:07:21.704Z", "modified" : "2019-07-20T12:07:21.704Z",
"name" : "The Verge" "name" : "The Verge"

View file

@ -69,6 +69,7 @@ class App extends React.Component {
key={category.pk} key={category.pk}
category={category} category={category}
showDialog={this.selectCategory} showDialog={this.selectCategory}
updateUrl={this.props.updateUrl}
/> />
); );
}); });
@ -80,7 +81,7 @@ class App extends React.Component {
const pageHeader = ( const pageHeader = (
<> <>
<h1 className="h1">Categories</h1> <h1 className="h1">Categories</h1>
<a className="link button button--confirm" href="/core/categories/create/"> <a className="link button button--confirm" href={`${this.props.createUrl}/`}>
Create category Create category
</a> </a>
</> </>

View file

@ -33,7 +33,7 @@ const CategoryCard = props => {
<> <>
<a <a
className="link button button--primary" className="link button button--primary"
href={`/core/categories/${category.pk}/`} href={`${props.updateUrl}/${category.pk}/`}
> >
Edit Edit
</a> </a>

View file

@ -9,5 +9,15 @@ if (page) {
const dataScript = document.getElementById('categories-data'); const dataScript = document.getElementById('categories-data');
const categories = JSON.parse(dataScript.textContent); const categories = JSON.parse(dataScript.textContent);
ReactDOM.render(<App categories={categories} />, page); let createUrl = document.getElementById('createUrl').textContent;
let updateUrl = document.getElementById('updateUrl').textContent;
ReactDOM.render(
<App
categories={categories}
createUrl={createUrl.substring(1, createUrl.length - 2)}
updateUrl={updateUrl.substring(1, updateUrl.length - 4)}
/>,
page
);
} }

View file

@ -19,7 +19,11 @@ class App extends React.Component {
return ( return (
<> <>
<Sidebar /> <Sidebar />
<PostList /> <PostList
feedUrl={this.props.feedUrl}
subredditUrl={this.props.subredditUrl}
timelineUrl={this.props.timelineUrl}
/>
{this.props.error && ( {this.props.error && (
<Messages messages={[{ type: 'error', text: this.props.error.message }]} /> <Messages messages={[{ type: 'error', text: this.props.error.message }]} />
@ -30,6 +34,10 @@ class App extends React.Component {
post={this.props.post} post={this.props.post}
rule={this.props.rule} rule={this.props.rule}
category={this.props.category} category={this.props.category}
feedUrl={this.props.feedUrl}
subredditUrl={this.props.subredditUrl}
timelineUrl={this.props.timelineUrl}
categoriesUrl={this.props.categoriesUrl}
/> />
)} )}
</> </>

View file

@ -3,7 +3,13 @@ import { connect } from 'react-redux';
import Cookies from 'js-cookie'; import Cookies from 'js-cookie';
import { unSelectPost, markPostRead } from '../actions/posts.js'; import { unSelectPost, markPostRead } from '../actions/posts.js';
import { CATEGORY_TYPE, RULE_TYPE, FEED, SUBREDDIT } from '../constants.js'; import {
CATEGORY_TYPE,
RULE_TYPE,
FEED,
SUBREDDIT,
TWITTER_TIMELINE,
} from '../constants.js';
import { formatDatetime } from '../../../utils.js'; import { formatDatetime } from '../../../utils.js';
class PostModal extends React.Component { class PostModal extends React.Component {
@ -44,10 +50,15 @@ class PostModal extends React.Component {
const post = this.props.post; const post = this.props.post;
const publicationDate = formatDatetime(post.publicationDate); const publicationDate = formatDatetime(post.publicationDate);
const titleClassName = post.read ? 'post__title post__title--read' : 'post__title'; const titleClassName = post.read ? 'post__title post__title--read' : 'post__title';
const ruleUrl = let ruleUrl = '';
this.props.rule.type === FEED
? `/collection/rules/${this.props.rule.id}/` if (this.props.rule.type === SUBREDDIT) {
: `/collection/rules/subreddits/${this.props.rule.id}/`; ruleUrl = `${this.props.subredditUrl}/${this.props.rule.id}/`;
} else if (this.props.rule.type === TWITTER_TIMELINE) {
ruleUrl = `${this.props.timelineUrl}/${this.props.rule.id}/`;
} else {
ruleUrl = `${this.props.feedUrl}/${this.props.rule.id}/`;
}
return ( return (
<div className="modal post-modal"> <div className="modal post-modal">
@ -66,7 +77,7 @@ class PostModal extends React.Component {
{this.props.category && ( {this.props.category && (
<span className="badge post__category" title={this.props.category.name}> <span className="badge post__category" title={this.props.category.name}>
<a <a
href={`/core/categories/${this.props.category.id}/`} href={`${this.props.categoriesUrl}/${this.props.category.id}/`}
target="_blank" target="_blank"
rel="noopener noreferrer" rel="noopener noreferrer"
> >

View file

@ -1,7 +1,13 @@
import React from 'react'; import React from 'react';
import { connect } from 'react-redux'; import { connect } from 'react-redux';
import { CATEGORY_TYPE, RULE_TYPE, FEED, SUBREDDIT } from '../../constants.js'; import {
CATEGORY_TYPE,
RULE_TYPE,
FEED,
SUBREDDIT,
TWITTER_TIMELINE,
} from '../../constants.js';
import { selectPost } from '../../actions/posts.js'; import { selectPost } from '../../actions/posts.js';
import { formatDatetime } from '../../../../utils.js'; import { formatDatetime } from '../../../../utils.js';
@ -13,11 +19,15 @@ class PostItem extends React.Component {
const titleClassName = post.read const titleClassName = post.read
? 'posts__header posts__header--read' ? 'posts__header posts__header--read'
: 'posts__header'; : 'posts__header';
let ruleUrl = '';
const ruleUrl = if (rule.type === SUBREDDIT) {
rule.type === FEED ruleUrl = `${this.props.subredditUrl}/${rule.id}/`;
? `/collection/rules/${rule.id}/` } else if (rule.type === TWITTER_TIMELINE) {
: `/collection/rules/subreddits/${rule.id}/`; ruleUrl = `${this.props.timelineUrl}/${rule.id}/`;
} else {
ruleUrl = `${this.props.feedUrl}/${rule.id}/`;
}
return ( return (
<li className="posts__item"> <li className="posts__item">

View file

@ -38,7 +38,16 @@ class PostList extends React.Component {
render() { render() {
const postItems = this.props.postsBySection.map((item, index) => { const postItems = this.props.postsBySection.map((item, index) => {
return <PostItem key={index} post={item} selected={this.props.selected} />; return (
<PostItem
key={index}
post={item}
selected={this.props.selected}
feedUrl={this.props.feedUrl}
subredditUrl={this.props.subredditUrl}
timelineUrl={this.props.timelineUrl}
/>
);
}); });
if (isEqual(this.props.selected, {})) { if (isEqual(this.props.selected, {})) {

View file

@ -3,3 +3,4 @@ export const CATEGORY_TYPE = 'CATEGORY';
export const SUBREDDIT = 'subreddit'; export const SUBREDDIT = 'subreddit';
export const FEED = 'feed'; export const FEED = 'feed';
export const TWITTER_TIMELINE = 'twitter_timeline';

View file

@ -11,9 +11,19 @@ const page = document.getElementById('homepage--page');
if (page) { if (page) {
const store = configureStore(); const store = configureStore();
let feedUrl = document.getElementById('feedUrl').textContent;
let subredditUrl = document.getElementById('subredditUrl').textContent;
let timelineUrl = document.getElementById('timelineUrl').textContent;
let categoriesUrl = document.getElementById('categoriesUrl').textContent;
ReactDOM.render( ReactDOM.render(
<Provider store={store}> <Provider store={store}>
<App /> <App
feedUrl={feedUrl.substring(1, feedUrl.length - 4)}
subredditUrl={subredditUrl.substring(1, subredditUrl.length - 4)}
timelineUrl={timelineUrl.substring(1, timelineUrl.length - 4)}
categoriesUrl={categoriesUrl.substring(1, categoriesUrl.length - 4)}
/>
</Provider>, </Provider>,
page page
); );

View file

@ -6,14 +6,7 @@ from newsreader.news.collection.models import CollectionRule
class CollectionRuleAdmin(admin.ModelAdmin): class CollectionRuleAdmin(admin.ModelAdmin):
fields = ("url", "name", "timezone", "category", "favicon", "user") fields = ("url", "name", "timezone", "category", "favicon", "user")
list_display = ( list_display = ("name", "type_display", "category", "url", "last_run", "succeeded")
"name",
"type_display",
"category",
"url",
"last_suceeded",
"succeeded",
)
list_filter = ("user",) list_filter = ("user",)
def save_model(self, request, obj, form, change): def save_model(self, request, obj, form, change):

View file

@ -1,7 +1,10 @@
from bs4 import BeautifulSoup import bleach
from newsreader.news.collection.exceptions import StreamParseException from newsreader.news.collection.constants import (
from newsreader.news.collection.utils import fetch WHITELISTED_ATTRIBUTES,
WHITELISTED_TAGS,
)
from newsreader.news.core.models import Post
class Stream: class Stream:
@ -20,19 +23,16 @@ class Stream:
def parse(self, response): def parse(self, response):
raise NotImplementedError raise NotImplementedError
class Meta:
abstract = True
class Client: class Client:
""" """
Retrieves the data with streams Retrieves the data through streams
""" """
stream = Stream stream = Stream
def __init__(self, rules=[]): def __init__(self, rules=[]):
self.rules = rules if rules else CollectionRule.objects.enabled() self.rules = rules
def __enter__(self): def __enter__(self):
for rule in self.rules: for rule in self.rules:
@ -43,36 +43,40 @@ class Client:
def __exit__(self, *args, **kwargs): def __exit__(self, *args, **kwargs):
pass pass
class Meta:
abstract = True
class Builder: class Builder:
""" """
Creates the collected posts Builds instances of various types
""" """
instances = [] instances = []
stream = None stream = None
payload = None
def __init__(self, stream): def __init__(self, payload, stream):
self.payload = payload
self.stream = stream self.stream = stream
def __enter__(self): def __enter__(self):
self.create_posts(self.stream)
return self return self
def __exit__(self, *args, **kwargs): def __exit__(self, *args, **kwargs):
pass pass
def create_posts(self, stream): def build(self):
pass raise NotImplementedError
def save(self): def sanitize_fragment(self, fragment):
pass if not fragment:
return ""
class Meta: return bleach.clean(
abstract = True fragment,
tags=WHITELISTED_TAGS,
attributes=WHITELISTED_ATTRIBUTES,
strip=True,
strip_comments=True,
)
class Collector: class Collector:
@ -88,46 +92,54 @@ class Collector:
self.builder = builder if builder else self.builder self.builder = builder if builder else self.builder
def collect(self, rules=None): def collect(self, rules=None):
with self.client(rules=rules) as client: raise NotImplementedError
for data, stream in client:
with self.builder((data, stream)) as builder:
builder.save()
class Meta:
abstract = True
class WebsiteStream(Stream): class Scheduler:
def __init__(self, url): """
self.url = url Schedules rules according to certain ratelimitting
"""
def read(self): def get_scheduled_rules(self):
response = fetch(self.url) raise NotImplementedError
return (self.parse(response.content), self)
def parse(self, payload):
try:
return BeautifulSoup(payload, "lxml")
except TypeError:
raise StreamParseException("Could not parse given HTML")
class URLBuilder(Builder): class PostBuilder(Builder):
rule_type = None
def __enter__(self): def __enter__(self):
return self self.existing_posts = {
post.remote_identifier: post
for post in Post.objects.filter(
rule=self.stream.rule, rule__type=self.rule_type
)
}
def build(self): return super().__enter__()
data, stream = self.stream
rule = stream.rule
try: def save(self):
url = data["feed"]["link"] for post in self.instances:
except (KeyError, TypeError): post.save()
url = None
if url:
rule.website_url = url
rule.save()
return rule, url class PostStream(Stream):
rule_type = None
class PostClient(Client):
stream = PostStream
def set_rule_error(self, rule, exception):
length = rule._meta.get_field("error").max_length
rule.error = exception.message[-length:]
rule.succeeded = False
class PostCollector(Collector):
    """
    Collector that runs its client and hands every payload to its builder
    """

    def collect(self, rules=None):
        """
        Build and save the instances for every (payload, stream) pair
        yielded by the client.

        ``rules`` is passed to the client as-is; ``None`` is treated as an
        empty list of rules.
        """
        # Use a fresh list per call instead of a shared mutable default, and
        # keep ``None`` (the default of ``Collector.collect``) from reaching
        # the client, which iterates over the rules it receives.
        rules = [] if rules is None else rules

        with self.client(rules=rules) as client:
            for payload, stream in client:
                with self.builder(payload, stream) as builder:
                    builder.build()
                    builder.save()

View file

@ -5,3 +5,10 @@ from django.utils.translation import gettext as _
class RuleTypeChoices(TextChoices): class RuleTypeChoices(TextChoices):
feed = "feed", _("Feed") feed = "feed", _("Feed")
subreddit = "subreddit", _("Subreddit") subreddit = "subreddit", _("Subreddit")
twitter_timeline = "twitter_timeline", _("Twitter timeline")
class TwitterPostTypeChoices(TextChoices):
photo = "photo", _("Photo")
video = "video", _("Video")
animated_gif = "animated_gif", _("GIF")

View file

@ -23,6 +23,7 @@ WHITELISTED_TAGS = (
WHITELISTED_ATTRIBUTES = { WHITELISTED_ATTRIBUTES = {
**BLEACH_ATTRIBUTES, **BLEACH_ATTRIBUTES,
"a": ["href", "rel"], "a": ["href", "rel"],
"img": ["alt", "src"], "img": ["alt", "src", "loading"],
"source": ["srcset", "media", "src", "type"], "video": ["controls", "muted"],
"source": ["srcset", "src", "media", "type"],
} }

View file

@ -1,16 +1,12 @@
from concurrent.futures import ThreadPoolExecutor, as_completed from concurrent.futures import ThreadPoolExecutor, as_completed
from urllib.parse import urljoin, urlparse from urllib.parse import urljoin, urlparse
from newsreader.news.collection.base import ( from bs4 import BeautifulSoup
Builder,
Client, from newsreader.news.collection.base import Builder, Client, Collector, Stream
Collector, from newsreader.news.collection.exceptions import StreamException, StreamParseException
Stream,
URLBuilder,
WebsiteStream,
)
from newsreader.news.collection.exceptions import StreamException
from newsreader.news.collection.feed import FeedClient from newsreader.news.collection.feed import FeedClient
from newsreader.news.collection.utils import fetch
LINK_RELS = [ LINK_RELS = [
@ -21,17 +17,45 @@ LINK_RELS = [
] ]
class WebsiteStream(Stream):
    """
    Stream that fetches the page at the rule's ``website_url`` and parses it
    """

    def read(self):
        """
        Fetch ``self.rule.website_url`` and return a tuple of the parsed
        response content and this stream.
        """
        response = fetch(self.rule.website_url)
        return self.parse(response.content), self

    def parse(self, payload):
        """
        Parse ``payload`` with BeautifulSoup using the lxml parser.

        Raises ``StreamParseException`` when BeautifulSoup raises a
        ``TypeError`` for the given payload.
        """
        try:
            return BeautifulSoup(payload, features="lxml")
        except TypeError as e:
            # Chain explicitly so the original error stays attached as the cause
            raise StreamParseException("Could not parse given HTML") from e
class WebsiteURLBuilder(Builder):
    """
    Reads the website link from a feed payload and stores it on the rule
    """

    def build(self):
        """
        Collect a single (stream, url) pair when the payload has a truthy
        ``feed.link`` value, otherwise collect nothing.
        """
        try:
            link = self.payload["feed"]["link"]
        except (KeyError, TypeError):
            # Payload has no "feed"/"link" entry or is not subscriptable
            link = None

        if link:
            self.instances = [(self.stream, link)]
        else:
            self.instances = []

    def save(self):
        """
        Write every collected url to the ``website_url`` of its rule.
        """
        for collected_stream, link in self.instances:
            collected_stream.rule.website_url = link
            collected_stream.rule.save()
class FaviconBuilder(Builder): class FaviconBuilder(Builder):
def build(self): def build(self):
rule, soup = self.stream rule = self.stream.rule
url = self.parse(soup, rule.website_url) url = self.parse()
if url: self.instances = [(rule, url)] if url else []
rule.favicon = url
rule.save() def parse(self):
soup = self.payload
def parse(self, soup, website_url):
if not soup.head: if not soup.head:
return return
@ -44,9 +68,9 @@ class FaviconBuilder(Builder):
parsed_url = urlparse(url) parsed_url = urlparse(url)
if not parsed_url.scheme and not parsed_url.netloc: if not parsed_url.scheme and not parsed_url.netloc:
if not website_url: if not self.stream.rule.website_url:
return return
return urljoin(website_url, url) return urljoin(self.stream.rule.website_url, url)
elif not parsed_url.scheme: elif not parsed_url.scheme:
return urljoin(f"https://{parsed_url.netloc}", parsed_url.path) return urljoin(f"https://{parsed_url.netloc}", parsed_url.path)
@ -73,6 +97,11 @@ class FaviconBuilder(Builder):
elif icons: elif icons:
return icons.pop() return icons.pop()
def save(self):
for rule, favicon_url in self.instances:
rule.favicon = favicon_url
rule.save()
class FaviconClient(Client): class FaviconClient(Client):
stream = WebsiteStream stream = WebsiteStream
@ -82,39 +111,35 @@ class FaviconClient(Client):
def __enter__(self): def __enter__(self):
with ThreadPoolExecutor(max_workers=10) as executor: with ThreadPoolExecutor(max_workers=10) as executor:
futures = { futures = [executor.submit(stream.read) for stream in self.streams]
executor.submit(stream.read): rule for rule, stream in self.streams
}
for future in as_completed(futures): for future in as_completed(futures):
rule = futures[future]
try: try:
response_data, stream = future.result() payload, stream = future.result()
except StreamException: except StreamException:
continue continue
yield (rule, response_data) yield payload, stream
class FaviconCollector(Collector): class FaviconCollector(Collector):
feed_client, favicon_client = (FeedClient, FaviconClient) feed_client, favicon_client = (FeedClient, FaviconClient)
url_builder, favicon_builder = (URLBuilder, FaviconBuilder) url_builder, favicon_builder = (WebsiteURLBuilder, FaviconBuilder)
def collect(self, rules=None): def collect(self, rules=None):
streams = [] streams = []
with self.feed_client(rules=rules) as client: with self.feed_client(rules=rules) as client:
for data, stream in client: for payload, stream in client:
with self.url_builder((data, stream)) as builder: with self.url_builder(payload, stream) as builder:
rule, url = builder.build() builder.build()
builder.save()
if not url: if builder.instances:
continue streams.append(WebsiteStream(stream.rule))
streams.append((rule, WebsiteStream(url)))
with self.favicon_client(streams) as client: with self.favicon_client(streams) as client:
for rule, data in client: for payload, stream in client:
with self.favicon_builder((rule, data)) as builder: with self.favicon_builder(payload, stream) as builder:
builder.build() builder.build()
builder.save()

View file

@ -6,17 +6,17 @@ from datetime import timedelta
from django.core.exceptions import MultipleObjectsReturned, ObjectDoesNotExist from django.core.exceptions import MultipleObjectsReturned, ObjectDoesNotExist
from django.utils import timezone from django.utils import timezone
import bleach
import pytz import pytz
from feedparser import parse from feedparser import parse
from newsreader.news.collection.base import Builder, Client, Collector, Stream from newsreader.news.collection.base import (
from newsreader.news.collection.choices import RuleTypeChoices PostBuilder,
from newsreader.news.collection.constants import ( PostClient,
WHITELISTED_ATTRIBUTES, PostCollector,
WHITELISTED_TAGS, PostStream,
) )
from newsreader.news.collection.choices import RuleTypeChoices
from newsreader.news.collection.exceptions import ( from newsreader.news.collection.exceptions import (
StreamDeniedException, StreamDeniedException,
StreamException, StreamException,
@ -24,7 +24,6 @@ from newsreader.news.collection.exceptions import (
StreamParseException, StreamParseException,
StreamTimeOutException, StreamTimeOutException,
) )
from newsreader.news.collection.models import CollectionRule
from newsreader.news.collection.utils import ( from newsreader.news.collection.utils import (
build_publication_date, build_publication_date,
fetch, fetch,
@ -36,32 +35,10 @@ from newsreader.news.core.models import Post
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class FeedBuilder(Builder): class FeedBuilder(PostBuilder):
instances = [] rule__type = RuleTypeChoices.feed
def __enter__(self): def build(self):
_, stream = self.stream
self.instances = []
self.existing_posts = {
post.remote_identifier: post
for post in Post.objects.filter(
rule=stream.rule, rule__type=RuleTypeChoices.feed
)
}
return super().__enter__()
def create_posts(self, stream):
data, stream = stream
with FeedDuplicateHandler(stream.rule) as duplicate_handler:
entries = data.get("entries", [])
instances = self.build(entries, stream.rule)
self.instances = duplicate_handler.check(instances)
def build(self, entries, rule):
field_mapping = { field_mapping = {
"id": "remote_identifier", "id": "remote_identifier",
"title": "title", "title": "title",
@ -70,56 +47,47 @@ class FeedBuilder(Builder):
"published_parsed": "publication_date", "published_parsed": "publication_date",
"author": "author", "author": "author",
} }
tz = pytz.timezone(self.stream.rule.timezone)
instances = []
tz = pytz.timezone(rule.timezone) with FeedDuplicateHandler(self.stream.rule) as duplicate_handler:
entries = self.payload.get("entries", [])
for entry in entries: for entry in entries:
data = {"rule_id": rule.pk} data = {"rule_id": self.stream.rule.pk}
for field, model_field in field_mapping.items(): for field, model_field in field_mapping.items():
if not field in entry: if not field in entry:
continue continue
value = truncate_text(Post, model_field, entry[field]) value = truncate_text(Post, model_field, entry[field])
if field == "published_parsed": if field == "published_parsed":
data[model_field] = build_publication_date(value, tz) data[model_field] = build_publication_date(value, tz)
elif field == "summary": elif field == "summary":
data[model_field] = self.sanitize_fragment(value) data[model_field] = self.sanitize_fragment(value)
else: else:
data[model_field] = value data[model_field] = value
if "content" in entry: if "content" in entry:
content = self.get_content(entry["content"]) content = self.get_content(entry["content"])
body = data.get("body", "") body = data.get("body", "")
if not body or len(body) < len(content): if not body or len(body) < len(content):
data["body"] = content data["body"] = content
yield Post(**data) instances.append(Post(**data))
def sanitize_fragment(self, fragment): self.instances = duplicate_handler.check(instances)
if not fragment:
return ""
return bleach.clean(
fragment,
tags=WHITELISTED_TAGS,
attributes=WHITELISTED_ATTRIBUTES,
strip=True,
strip_comments=True,
)
def get_content(self, items): def get_content(self, items):
content = "\n ".join([item.get("value") for item in items]) content = "\n ".join([item.get("value") for item in items])
return self.sanitize_fragment(content) return self.sanitize_fragment(content)
def save(self):
for post in self.instances:
post.save()
class FeedStream(PostStream):
rule_type = RuleTypeChoices.feed
class FeedStream(Stream):
def read(self): def read(self):
response = fetch(self.rule.url) response = fetch(self.rule.url)
@ -133,17 +101,9 @@ class FeedStream(Stream):
raise StreamParseException(response=response, message=message) from e raise StreamParseException(response=response, message=message) from e
class FeedClient(Client): class FeedClient(PostClient):
stream = FeedStream stream = FeedStream
def __init__(self, rules=[]):
if rules:
self.rules = rules
else:
self.rules = CollectionRule.objects.filter(
enabled=True, type=RuleTypeChoices.feed
)
def __enter__(self): def __enter__(self):
streams = [self.stream(rule) for rule in self.rules] streams = [self.stream(rule) for rule in self.rules]
@ -154,13 +114,12 @@ class FeedClient(Client):
stream = futures[future] stream = futures[future]
try: try:
response_data = future.result() payload = future.result()
stream.rule.error = None stream.rule.error = None
stream.rule.succeeded = True stream.rule.succeeded = True
stream.rule.last_suceeded = timezone.now()
yield response_data yield payload
except (StreamNotFoundException, StreamTimeOutException) as e: except (StreamNotFoundException, StreamTimeOutException) as e:
logger.warning(f"Request failed for {stream.rule.url}") logger.warning(f"Request failed for {stream.rule.url}")
@ -174,16 +133,11 @@ class FeedClient(Client):
continue continue
finally: finally:
stream.rule.last_run = timezone.now()
stream.rule.save() stream.rule.save()
def set_rule_error(self, rule, exception):
length = rule._meta.get_field("error").max_length
rule.error = exception.message[-length:] class FeedCollector(PostCollector):
rule.succeeded = False
class FeedCollector(Collector):
builder = FeedBuilder builder = FeedBuilder
client = FeedClient client = FeedClient

View file

@ -1,101 +0,0 @@
from django import forms
from django.core.exceptions import ValidationError
from django.utils.safestring import mark_safe
from django.utils.translation import gettext_lazy as _
import pytz
from newsreader.core.forms import CheckboxInput
from newsreader.news.collection.choices import RuleTypeChoices
from newsreader.news.collection.models import CollectionRule
from newsreader.news.collection.reddit import REDDIT_API_URL
from newsreader.news.core.models import Category
def get_reddit_help_text():
return mark_safe(
"Only subreddits are supported"
" see the 'listings' section in <a className='link' target='_blank' rel='noopener noreferrer'"
" href='https://www.reddit.com/dev/api#section_listings'>the reddit API docs</a>."
" For example: <a className='link' target='_blank' rel='noopener noreferrer'"
" href='https://oauth.reddit.com/r/aww'>https://oauth.reddit.com/r/aww</a>"
)
class CollectionRuleForm(forms.ModelForm):
category = forms.ModelChoiceField(required=False, queryset=Category.objects.all())
timezone = forms.ChoiceField(
widget=forms.Select(attrs={"size": len(pytz.all_timezones)}),
choices=((timezone, timezone) for timezone in pytz.all_timezones),
help_text=_("The timezone which the feed uses"),
initial=pytz.utc,
)
def __init__(self, *args, **kwargs):
self.user = kwargs.pop("user")
super().__init__(*args, **kwargs)
self.fields["category"].queryset = Category.objects.filter(user=self.user)
def save(self, commit=True):
instance = super().save(commit=False)
instance.user = self.user
if commit:
instance.save()
self.save_m2m()
return instance
class Meta:
model = CollectionRule
fields = ("name", "url", "timezone", "favicon", "category")
class CollectionRuleBulkForm(forms.Form):
rules = forms.ModelMultipleChoiceField(queryset=CollectionRule.objects.none())
def __init__(self, user, *args, **kwargs):
self.user = user
super().__init__(*args, **kwargs)
self.fields["rules"].queryset = CollectionRule.objects.filter(user=user)
class SubRedditRuleForm(CollectionRuleForm):
url = forms.URLField(max_length=1024, help_text=get_reddit_help_text)
timezone = None
def clean_url(self):
url = self.cleaned_data["url"]
if not url.startswith(REDDIT_API_URL):
raise ValidationError(_("This does not look like an Reddit API URL"))
return url
def save(self, commit=True):
instance = super().save(commit=False)
instance.type = RuleTypeChoices.subreddit
instance.timezone = str(pytz.utc)
if commit:
instance.save()
self.save_m2m()
return instance
class Meta:
model = CollectionRule
fields = ("name", "url", "favicon", "category")
class OPMLImportForm(forms.Form):
file = forms.FileField(allow_empty_file=False)
skip_existing = forms.BooleanField(
initial=False, required=False, widget=CheckboxInput
)

View file

@ -0,0 +1,4 @@
from newsreader.news.collection.forms.feed import FeedForm, OPMLImportForm
from newsreader.news.collection.forms.reddit import SubRedditForm
from newsreader.news.collection.forms.rules import CollectionRuleBulkForm
from newsreader.news.collection.forms.twitter import TwitterTimelineForm

View file

@ -0,0 +1,29 @@
from django import forms
from newsreader.news.collection.models import CollectionRule
from newsreader.news.core.models import Category
class CollectionRuleForm(forms.ModelForm):
    """
    Base model form for collection rules that are owned by a user.

    Requires a ``user`` keyword argument, the category choices are limited
    to the categories of that user.
    """

    category = forms.ModelChoiceField(queryset=Category.objects.all(), required=False)

    def __init__(self, *args, **kwargs):
        # ``user`` is not a regular form argument, take it out before
        # delegating to the ModelForm constructor (KeyError when missing).
        self.user = kwargs.pop("user")

        super().__init__(*args, **kwargs)

        category_field = self.fields["category"]
        category_field.queryset = Category.objects.filter(user=self.user)

    def save(self, commit=True):
        """
        Assign the form's user to the rule and return the rule, the rule
        and its many-to-many data are only saved when ``commit`` is truthy.
        """
        rule = super().save(commit=False)
        rule.user = self.user

        if not commit:
            return rule

        rule.save()
        self.save_m2m()
        return rule

    class Meta:
        model = CollectionRule
        fields = "__all__"

View file

@ -0,0 +1,28 @@
from django import forms
from django.utils.translation import gettext_lazy as _
import pytz
from newsreader.core.forms import CheckboxInput
from newsreader.news.collection.forms.base import CollectionRuleForm
from newsreader.news.collection.models import CollectionRule
class FeedForm(CollectionRuleForm):
    """
    Collection rule form for feeds, adds a timezone choice field that
    lists every name in ``pytz.all_timezones``.
    """

    timezone = forms.ChoiceField(
        choices=[(name, name) for name in pytz.all_timezones],
        widget=forms.Select(attrs={"size": len(pytz.all_timezones)}),
        initial=pytz.utc,
        help_text=_("The timezone which the feed uses"),
    )

    class Meta:
        model = CollectionRule
        fields = ("name", "url", "timezone", "favicon", "category")
class OPMLImportForm(forms.Form):
    """
    Upload form with a required, non-empty file and an optional
    ``skip_existing`` checkbox that is unchecked by default.
    """

    file = forms.FileField(allow_empty_file=False)
    skip_existing = forms.BooleanField(
        widget=CheckboxInput, required=False, initial=False
    )

View file

@ -0,0 +1,49 @@
from django import forms
from django.core.exceptions import ValidationError
from django.utils.safestring import mark_safe
from django.utils.translation import gettext_lazy as _
import pytz
from newsreader.news.collection.choices import RuleTypeChoices
from newsreader.news.collection.forms.base import CollectionRuleForm
from newsreader.news.collection.models import CollectionRule
from newsreader.news.collection.reddit import REDDIT_API_URL
def get_reddit_help_text():
    """
    Return the HTML help text for subreddit URLs.

    The text is wrapped in ``mark_safe`` so the anchors are rendered as
    links instead of being escaped.
    """
    # HTML uses ``class``; ``className`` is the JSX spelling and is not a
    # recognised attribute in plain markup, so the class was never applied.
    return mark_safe(
        "Only subreddits are supported"
        " see the 'listings' section in <a class='link' target='_blank' rel='noopener noreferrer'"
        " href='https://www.reddit.com/dev/api#section_listings'>the reddit API docs</a>."
        " For example: <a class='link' target='_blank' rel='noopener noreferrer'"
        " href='https://oauth.reddit.com/r/aww'>https://oauth.reddit.com/r/aww</a>"
    )
class SubRedditForm(CollectionRuleForm):
    """
    Collection rule form for subreddits, only accepts URLs that start with
    ``REDDIT_API_URL``.
    """

    url = forms.URLField(help_text=get_reddit_help_text, max_length=1024)

    def clean_url(self):
        """
        Return the cleaned URL, raises a ``ValidationError`` when it does
        not start with ``REDDIT_API_URL``.
        """
        url = self.cleaned_data["url"]

        if url.startswith(REDDIT_API_URL):
            return url

        raise ValidationError(_("This does not look like an Reddit API URL"))

    def save(self, commit=True):
        """
        Mark the rule as a subreddit rule with the UTC timezone and return
        it, the rule is only saved when ``commit`` is truthy.
        """
        rule = super().save(commit=False)

        rule.type = RuleTypeChoices.subreddit
        rule.timezone = str(pytz.utc)

        if not commit:
            return rule

        rule.save()
        self.save_m2m()
        return rule

    class Meta:
        model = CollectionRule
        fields = ("name", "url", "favicon", "category")

View file

@ -0,0 +1,14 @@
from django import forms
from newsreader.news.collection.models import CollectionRule
class CollectionRuleBulkForm(forms.Form):
    """
    Form to select multiple collection rules, the selectable rules are
    limited to the rules of the given user.
    """

    rules = forms.ModelMultipleChoiceField(queryset=CollectionRule.objects.none())

    def __init__(self, user, *args, **kwargs):
        self.user = user

        super().__init__(*args, **kwargs)

        # The field is declared with an empty queryset, the choices are
        # filled in per instance with the rules of the user.
        rules_field = self.fields["rules"]
        rules_field.queryset = CollectionRule.objects.filter(user=user)

View file

@ -0,0 +1,35 @@
from django import forms
from django.utils.translation import gettext_lazy as _
import pytz
from newsreader.news.collection.choices import RuleTypeChoices
from newsreader.news.collection.forms.base import CollectionRuleForm
from newsreader.news.collection.models import CollectionRule
from newsreader.news.collection.twitter import TWITTER_API_URL
class TwitterTimelineForm(CollectionRuleForm):
    """
    Collection rule form for Twitter timelines, the rule URL is derived
    from the submitted ``screen_name``.
    """

    screen_name = forms.CharField(
        max_length=255,
        label=_("Twitter profile name"),
        help_text=_("Profile name without hashtags"),
        required=True,
    )

    def save(self, commit=True):
        """
        Mark the rule as a Twitter timeline rule with the UTC timezone, set
        its ``user_timeline`` URL and return it. The rule is only saved when
        ``commit`` is truthy.
        """
        # Local import, keeps this block independent of the module imports
        from urllib.parse import urlencode

        instance = super().save(commit=False)

        instance.type = RuleTypeChoices.twitter_timeline
        instance.timezone = str(pytz.utc)

        # The screen name is user input: encode it so characters such as
        # "&", "#" or spaces cannot alter the query string. Plain profile
        # names produce the same URL as before.
        query = urlencode(
            {"screen_name": instance.screen_name, "tweet_mode": "extended"}
        )
        instance.url = f"{TWITTER_API_URL}/statuses/user_timeline.json?{query}"

        if commit:
            instance.save()
            self.save_m2m()

        return instance

    class Meta:
        model = CollectionRule
        fields = ("name", "screen_name", "favicon", "category")

View file

@ -0,0 +1,29 @@
# Generated by Django 3.0.7 on 2020-08-07 18:30
from django.db import migrations, models
# Adds the optional ``screen_name`` field to the ``collectionrule`` model and
# alters its ``type`` field to offer the feed, subreddit and twitter choices.
class Migration(migrations.Migration):
# Applied after migration 0008_collectionrule_type of the collection app
dependencies = [("collection", "0008_collectionrule_type")]
operations = [
# New character field (max 255), may be blank and null
migrations.AddField(
model_name="collectionrule",
name="screen_name",
field=models.CharField(blank=True, max_length=255, null=True),
),
# Redefines ``type`` with "twitter" as additional choice, default stays "feed"
migrations.AlterField(
model_name="collectionrule",
name="type",
field=models.CharField(
choices=[
("feed", "Feed"),
("subreddit", "Subreddit"),
("twitter", "Twitter"),
],
default="feed",
max_length=20,
),
),
]

View file

@ -0,0 +1,24 @@
# Generated by Django 3.0.7 on 2020-09-13 19:01
from django.db import migrations, models
# Alters the ``type`` field of the ``collectionrule`` model so its choices are
# feed, subreddit and twitter_timeline.
class Migration(migrations.Migration):
# Applied after migration 0009_auto_20200807_2030 of the collection app
dependencies = [("collection", "0009_auto_20200807_2030")]
operations = [
# Only the field definition is altered here, this migration contains no
# operation that rewrites values of existing rows.
migrations.AlterField(
model_name="collectionrule",
name="type",
field=models.CharField(
choices=[
("feed", "Feed"),
("subreddit", "Subreddit"),
("twitter_timeline", "Twitter timeline"),
],
default="feed",
max_length=20,
),
)
]

View file

@ -0,0 +1,14 @@
# Generated by Django 3.0.7 on 2020-09-13 19:57
from django.db import migrations
# Renames the ``last_suceeded`` field of the ``collectionrule`` model to
# ``last_run``.
class Migration(migrations.Migration):
# Applied after migration 0010_auto_20200913_2101 of the collection app
dependencies = [("collection", "0010_auto_20200913_2101")]
operations = [
migrations.RenameField(
model_name="collectionrule", old_name="last_suceeded", new_name="last_run"
)
]

View file

@ -41,9 +41,8 @@ class CollectionRule(TimeStampedModel):
on_delete=models.SET_NULL, on_delete=models.SET_NULL,
) )
last_suceeded = models.DateTimeField(blank=True, null=True) last_run = models.DateTimeField(blank=True, null=True)
succeeded = models.BooleanField(default=False) succeeded = models.BooleanField(default=False)
error = models.CharField(max_length=1024, blank=True, null=True) error = models.CharField(max_length=1024, blank=True, null=True)
enabled = models.BooleanField( enabled = models.BooleanField(
@ -57,6 +56,9 @@ class CollectionRule(TimeStampedModel):
on_delete=models.CASCADE, on_delete=models.CASCADE,
) )
# Twitter
screen_name = models.CharField(max_length=255, blank=True, null=True)
objects = CollectionRuleQuerySet.as_manager() objects = CollectionRuleQuerySet.as_manager()
def __str__(self): def __str__(self):
@ -66,5 +68,13 @@ class CollectionRule(TimeStampedModel):
def update_url(self): def update_url(self):
if self.type == RuleTypeChoices.subreddit: if self.type == RuleTypeChoices.subreddit:
return reverse("news:collection:subreddit-update", kwargs={"pk": self.pk}) return reverse("news:collection:subreddit-update", kwargs={"pk": self.pk})
elif self.type == RuleTypeChoices.twitter_timeline:
return reverse(
"news:collection:twitter-timeline-update", kwargs={"pk": self.pk}
)
return reverse("news:collection:rule-update", kwargs={"pk": self.pk}) return reverse("news:collection:feed-update", kwargs={"pk": self.pk})
@property
def failed(self):
return not self.succeeded and self.last_run

View file

@ -12,11 +12,16 @@ from django.core.cache import cache
from django.utils import timezone from django.utils import timezone
from django.utils.html import format_html from django.utils.html import format_html
import bleach
import pytz import pytz
import requests import requests
from newsreader.news.collection.base import Builder, Client, Collector, Stream from newsreader.news.collection.base import (
PostBuilder,
PostClient,
PostCollector,
PostStream,
Scheduler,
)
from newsreader.news.collection.choices import RuleTypeChoices from newsreader.news.collection.choices import RuleTypeChoices
from newsreader.news.collection.constants import ( from newsreader.news.collection.constants import (
WHITELISTED_ATTRIBUTES, WHITELISTED_ATTRIBUTES,
@ -93,32 +98,32 @@ def get_reddit_access_token(code, user):
return response_data["access_token"], response_data["refresh_token"] return response_data["access_token"], response_data["refresh_token"]
class RedditBuilder(Builder): # Note that the API always returns 204's with correct basic auth headers
def __enter__(self): def revoke_reddit_token(user):
_, stream = self.stream client_auth = requests.auth.HTTPBasicAuth(
settings.REDDIT_CLIENT_ID, settings.REDDIT_CLIENT_SECRET
)
self.instances = [] response = post(
self.existing_posts = { f"{REDDIT_URL}/api/v1/revoke_token",
post.remote_identifier: post data={"token": user.reddit_refresh_token, "token_type_hint": "refresh_token"},
for post in Post.objects.filter( auth=client_auth,
rule=stream.rule, rule__type=RuleTypeChoices.subreddit )
)
}
return super().__enter__() return response.status_code == 204
def create_posts(self, stream):
data, stream = stream
posts = []
if not "data" in data or not "children" in data["data"]: class RedditBuilder(PostBuilder):
rule_type = RuleTypeChoices.subreddit
def build(self):
results = {}
if not "data" in self.payload or not "children" in self.payload["data"]:
return return
posts = data["data"]["children"] posts = self.payload["data"]["children"]
self.instances = self.build(posts, stream.rule) rule = self.stream.rule
def build(self, posts, rule):
results = {}
for post in posts: for post in posts:
if not "data" in post or post["kind"] != REDDIT_POST: if not "data" in post or post["kind"] != REDDIT_POST:
@ -139,17 +144,7 @@ class RedditBuilder(Builder):
if is_text_post: if is_text_post:
uncleaned_body = data["selftext_html"] uncleaned_body = data["selftext_html"]
unescaped_body = unescape(uncleaned_body) if uncleaned_body else "" unescaped_body = unescape(uncleaned_body) if uncleaned_body else ""
body = ( body = self.sanitize_fragment(unescaped_body) if unescaped_body else ""
bleach.clean(
unescaped_body,
tags=WHITELISTED_TAGS,
attributes=WHITELISTED_ATTRIBUTES,
strip=True,
strip_comments=True,
)
if unescaped_body
else ""
)
elif direct_url.endswith(REDDIT_IMAGE_EXTENSIONS): elif direct_url.endswith(REDDIT_IMAGE_EXTENSIONS):
body = format_html( body = format_html(
"<div><img alt='{title}' src='{url}' loading='lazy' /></div>", "<div><img alt='{title}' src='{url}' loading='lazy' /></div>",
@ -192,7 +187,9 @@ class RedditBuilder(Builder):
parsed_date = datetime.fromtimestamp(post["data"]["created_utc"]) parsed_date = datetime.fromtimestamp(post["data"]["created_utc"])
created_date = pytz.utc.localize(parsed_date) created_date = pytz.utc.localize(parsed_date)
except (OverflowError, OSError): except (OverflowError, OSError):
logging.warning(f"Failed parsing timestamp from {url_fragment}") logging.warning(
f"Failed parsing timestamp from {REDDIT_URL}{post_url_fragment}"
)
created_date = timezone.now() created_date = timezone.now()
post_data = { post_data = {
@ -216,14 +213,98 @@ class RedditBuilder(Builder):
results[remote_identifier] = Post(**post_data) results[remote_identifier] = Post(**post_data)
return results.values() self.instances = results.values()
def save(self):
for post in self.instances:
post.save()
class RedditScheduler: class RedditStream(PostStream):
rule_type = RuleTypeChoices.subreddit
headers = {}
def __init__(self, rule):
    """Bind the stream to `rule` and prepare its authorization header.

    The header carries the Reddit access token of the user owning the
    rule (`rule.user.reddit_access_token`) as a bearer token; `read`
    sends it along with the request.
    """
    super().__init__(rule)

    # The header name is a constant, so it is a plain string literal; only
    # the value needs interpolation.
    access_token = self.rule.user.reddit_access_token
    self.headers = {"Authorization": f"bearer {access_token}"}
def read(self):
    """Fetch the rule's url and return a `(parsed payload, stream)` pair.

    The request is made with the authorization headers prepared in
    `__init__`; the response body is decoded by `parse`.
    """
    payload = self.parse(fetch(self.rule.url, headers=self.headers))

    return payload, self
def parse(self, response):
    """Decode `response` as JSON and return the result.

    Raises:
        StreamParseException: when the body is not valid JSON; the
            original `JSONDecodeError` is chained as the cause.
    """
    try:
        return response.json()
    except JSONDecodeError as e:
        # Wrap the decode error and pass the offending response along with it.
        raise StreamParseException(
            response=response, message="Failed parsing json"
        ) from e
class RedditClient(PostClient):
    """Client that reads batches of subreddit rules concurrently.

    `self.rules` is iterated as a sequence of batches, each batch being an
    iterable of rules. Every rule is wrapped in a `RedditStream` and read on
    a thread pool; each successful read is yielded as whatever
    `stream.read()` returns.
    """

    stream = RedditStream

    def __enter__(self):
        """Yield the result of every stream that could be read.

        Per-rule bookkeeping (`error`, `succeeded`, `last_run`) is updated
        and saved for every future that is consumed. Error handling:

        * `StreamDeniedException`: the user's access token is cleared, a
          `RedditTokenTask` is queued for that user and the rest of the
          current batch is skipped.
        * `StreamTooManyException`: the rest of the current batch is skipped
          and no further batches are started.
        * any other `StreamException`: the error is recorded on the rule and
          the remaining streams of the batch are still processed.
        """
        streams = [[self.stream(rule) for rule in batch] for batch in self.rules]
        rate_limited = False

        with ThreadPoolExecutor(max_workers=10) as executor:
            for batch in streams:
                # Check *before* submitting: once the ratelimit has been hit
                # no request of a following batch may be sent. Checking after
                # submission would still fire a whole extra batch of requests
                # whose results are then thrown away.
                if rate_limited:
                    logger.warning("Aborting requests, ratelimit hit")
                    break

                futures = {executor.submit(stream.read): stream for stream in batch}

                for future in as_completed(futures):
                    stream = futures[future]

                    try:
                        response_data = future.result()

                        stream.rule.error = None
                        stream.rule.succeeded = True

                        yield response_data
                    except StreamDeniedException as e:
                        logger.warning(
                            f"Access token expired for user {stream.rule.user.pk}"
                        )

                        # Drop the rejected token and queue a token task for
                        # this user.
                        stream.rule.user.reddit_access_token = None
                        stream.rule.user.save()

                        self.set_rule_error(stream.rule, e)

                        RedditTokenTask.delay(stream.rule.user.pk)

                        break
                    except StreamTooManyException as e:
                        logger.exception("Ratelimit hit, aborting batched subreddits")

                        self.set_rule_error(stream.rule, e)

                        rate_limited = True
                        break
                    except StreamException as e:
                        logger.exception(
                            f"Stream failed reading content from {stream.rule.url}"
                        )

                        self.set_rule_error(stream.rule, e)

                        continue
                    finally:
                        # Runs for every consumed future, also on break and
                        # continue, so the rule always records its last run.
                        stream.rule.last_run = timezone.now()
                        stream.rule.save()
class RedditCollector(PostCollector):
    # Collector wiring for subreddits: payloads are read through
    # `RedditClient` and turned into posts by `RedditBuilder`.
    builder = RedditBuilder
    client = RedditClient
class RedditScheduler(Scheduler):
max_amount = RATE_LIMIT max_amount = RATE_LIMIT
max_user_amount = RATE_LIMIT / 4 max_user_amount = RATE_LIMIT / 4
@ -234,7 +315,7 @@ class RedditScheduler:
user__reddit_access_token__isnull=False, user__reddit_access_token__isnull=False,
user__reddit_refresh_token__isnull=False, user__reddit_refresh_token__isnull=False,
enabled=True, enabled=True,
).order_by("last_suceeded")[:200] ).order_by("last_run")[:200]
else: else:
self.subreddits = subreddits self.subreddits = subreddits
@ -263,100 +344,3 @@ class RedditScheduler:
current_amount += 1 current_amount += 1
return list(rule_mapping.values()) return list(rule_mapping.values())
class RedditStream(Stream):
headers = {}
user = None
def __init__(self, rule):
super().__init__(rule)
self.user = self.rule.user
self.headers = {
f"Authorization": f"bearer {self.rule.user.reddit_access_token}"
}
def read(self):
response = fetch(self.rule.url, headers=self.headers)
return self.parse(response), self
def parse(self, response):
try:
return response.json()
except JSONDecodeError as e:
raise StreamParseException(
response=response, message=f"Failed parsing json"
) from e
class RedditClient(Client):
stream = RedditStream
def __init__(self, rules=[]):
self.rules = rules
def __enter__(self):
streams = [[self.stream(rule) for rule in batch] for batch in self.rules]
rate_limitted = False
with ThreadPoolExecutor(max_workers=10) as executor:
for batch in streams:
futures = {executor.submit(stream.read): stream for stream in batch}
if rate_limitted:
break
for future in as_completed(futures):
stream = futures[future]
try:
response_data = future.result()
stream.rule.error = None
stream.rule.succeeded = True
stream.rule.last_suceeded = timezone.now()
yield response_data
except StreamDeniedException as e:
logger.warning(
f"Access token expired for user {stream.user.pk}"
)
stream.rule.user.reddit_access_token = None
stream.rule.user.save()
self.set_rule_error(stream.rule, e)
RedditTokenTask.delay(stream.rule.user.pk)
break
except StreamTooManyException as e:
logger.exception("Ratelimit hit, aborting batched subreddits")
self.set_rule_error(stream.rule, e)
rate_limitted = True
break
except StreamException as e:
logger.exception(
"Stream failed reading content from " f"{stream.rule.url}"
)
self.set_rule_error(stream.rule, e)
continue
finally:
stream.rule.save()
def set_rule_error(self, rule, exception):
length = rule._meta.get_field("error").max_length
rule.error = exception.message[-length:]
rule.succeeded = False
class RedditCollector(Collector):
builder = RedditBuilder
client = RedditClient

View file

@ -114,6 +114,40 @@ class RedditTokenTask(app.Task):
user.save() user.save()
class TwitterTimelineTask(app.Task):
    """Task that collects the scheduled Twitter timelines of one user."""

    name = "TwitterTimelineTask"
    ignore_result = True

    def run(self, user_pk):
        """Collect the scheduled Twitter timeline rules of user `user_pk`.

        The work is guarded by a per-user `MemCacheLock` so that only one
        timeline task per user runs at a time.

        Raises:
            Reject: (without requeue) when the user does not exist or when
                the lock for this user is already held.
        """
        # Imported inside the method, as in the original; presumably to
        # avoid a circular import at module load time -- TODO confirm.
        from newsreader.news.collection.twitter import (
            TwitterCollector,
            TwitterTimeLineScheduler,
        )

        try:
            user = User.objects.get(pk=user_pk)
        except ObjectDoesNotExist:
            message = f"User {user_pk} does not exist"
            logger.exception(message)

            raise Reject(reason=message, requeue=False)

        # The lock key must be interpolated per user. The previous literal
        # "f{user.email}-timeline-task" had the f prefix inside the quotes,
        # which made every user share one single lock.
        with MemCacheLock(f"{user.email}-timeline-task", self.app.oid) as acquired:
            if acquired:
                logger.info(f"Running twitter timeline task for user {user_pk}")

                scheduler = TwitterTimeLineScheduler(user)
                timelines = scheduler.get_scheduled_rules()

                collector = TwitterCollector()
                collector.collect(rules=timelines)
            else:
                logger.warning("Cancelling task due to existing lock")

                raise Reject(reason="Task already running", requeue=False)
FeedTask = app.register_task(FeedTask()) FeedTask = app.register_task(FeedTask())
RedditTask = app.register_task(RedditTask()) RedditTask = app.register_task(RedditTask())
RedditTokenTask = app.register_task(RedditTokenTask()) RedditTokenTask = app.register_task(RedditTokenTask())
TwitterTimelineTask = app.register_task(TwitterTimelineTask())

View file

@ -4,6 +4,6 @@
{% block content %} {% block content %}
<main id="rule--page" class="main"> <main id="rule--page" class="main">
{% url "news:collection:rules" as cancel_url %} {% url "news:collection:rules" as cancel_url %}
{% include "components/form/form.html" with form=form title="Create rule" cancel_url=cancel_url confirm_text="Create rule" %} {% include "components/form/form.html" with form=form title="Add a feed" cancel_url=cancel_url confirm_text="Add feed" %}
</main> </main>
{% endblock %} {% endblock %}

View file

@ -3,12 +3,12 @@
{% block content %} {% block content %}
<main id="rule--page" class="main"> <main id="rule--page" class="main">
{% if rule.error %} {% if feed.error %}
{% trans "Failed to retrieve posts" as title %} {% trans "Failed to retrieve posts" as title %}
{% include "components/textbox/textbox.html" with title=title body=rule.error class="text-section--error" only %} {% include "components/textbox/textbox.html" with title=title body=feed.error class="text-section--error" only %}
{% endif %} {% endif %}
{% url "news:collection:rules" as cancel_url %} {% url "news:collection:rules" as cancel_url %}
{% include "components/form/form.html" with form=form title="Update rule" cancel_url=cancel_url confirm_text="Save rule" only %} {% include "components/form/form.html" with form=form title="Update feed" cancel_url=cancel_url confirm_text="Save feed" only %}
</main> </main>
{% endblock %} {% endblock %}

View file

@ -4,6 +4,6 @@
{% block content %} {% block content %}
<main id="import--page" class="main"> <main id="import--page" class="main">
{% url "news:collection:rules" as cancel_url %} {% url "news:collection:rules" as cancel_url %}
{% include "components/form/form.html" with form=form title="Import an OPML file" cancel_url=cancel_url confirm_text="Import rules" %} {% include "components/form/form.html" with form=form title="Import an OPML file" cancel_url=cancel_url confirm_text="Import feeds" %}
</main> </main>
{% endblock %} {% endblock %}

View file

@ -14,8 +14,9 @@
</fieldset> </fieldset>
<div class="form__actions"> <div class="form__actions">
<a class="link button button--confirm" href="{% url "news:collection:rule-create" %}">{% trans "Add a rule" %}</a> <a class="link button button--confirm" href="{% url "news:collection:feed-create" %}">{% trans "Add a feed" %}</a>
<a class="link button button--confirm" href="{% url "news:collection:subreddit-create" %}">{% trans "Add a subreddit" %}</a> <a class="link button button--reddit" href="{% url "news:collection:subreddit-create" %}">{% trans "Add a subreddit" %}</a>
<a class="link button button--twitter" href="{% url "news:collection:twitter-timeline-create" %}">{% trans "Add a Twitter profile" %}</a>
<a class="link button button--confirm" href="{% url "news:collection:import" %}">{% trans "Import rules" %}</a> <a class="link button button--confirm" href="{% url "news:collection:import" %}">{% trans "Import rules" %}</a>
</div> </div>
</section> </section>
@ -36,7 +37,7 @@
</thead> </thead>
<tbody class="table__body"> <tbody class="table__body">
{% for rule in rules %} {% for rule in rules %}
<tr class="table__row {% if not rule.succeeded %}table__row--error {% endif %}rules-table__row"> <tr class="table__row {% if rule.failed %}table__row--error {% endif %}rules-table__row">
<td class="table__item rules-table__item"> <td class="table__item rules-table__item">
{% with rule|id_for_label:"rules" as id_for_label %} {% with rule|id_for_label:"rules" as id_for_label %}
{% include "components/form/checkbox.html" with name="rules" value=rule.pk id=id_for_label id_for_label=id_for_label %} {% include "components/form/checkbox.html" with name="rules" value=rule.pk id=id_for_label id_for_label=id_for_label %}
@ -54,10 +55,10 @@
<a class="link" href="{{ rule.url }}" target="_blank" rel="noopener noreferrer">{{ rule.url }}</a> <a class="link" href="{{ rule.url }}" target="_blank" rel="noopener noreferrer">{{ rule.url }}</a>
</td> </td>
<td class="table__item rules-table__item"> <td class="table__item rules-table__item">
{% if rule.succeeded %} {% if rule.failed %}
<i class="gg-check"></i>
{% else %}
<i class="gg-danger"></i> <i class="gg-danger"></i>
{% else %}
<i class="gg-check"></i>
{% endif %} {% endif %}
</td> </td>
<td class="table__item rules-table__item"> <td class="table__item rules-table__item">

View file

@ -0,0 +1,9 @@
{% extends "base.html" %}
{% load static %}
{% block content %}
<main id="twitter--page" class="main">
{% url "news:collection:rules" as cancel_url %}
{% include "components/form/form.html" with form=form title="Add a Twitter profile" cancel_url=cancel_url confirm_text="Add profile" %}
</main>
{% endblock %}

View file

@ -0,0 +1,14 @@
{% extends "base.html" %}
{% load static i18n %}
{% block content %}
<main id="twitter--page" class="main">
{% if timeline.error %}
{% trans "Failed to retrieve posts" as title %}
{% include "components/textbox/textbox.html" with title=title body=timeline.error class="text-section--error" only %}
{% endif %}
{% url "news:collection:rules" as cancel_url %}
{% include "components/form/form.html" with form=form title="Update profile" cancel_url=cancel_url confirm_text="Save profile" %}
</main>
{% endblock %}

View file

@ -28,3 +28,8 @@ class FeedFactory(CollectionRuleFactory):
class SubredditFactory(CollectionRuleFactory): class SubredditFactory(CollectionRuleFactory):
type = RuleTypeChoices.subreddit type = RuleTypeChoices.subreddit
website_url = REDDIT_URL website_url = REDDIT_URL
class TwitterTimelineFactory(CollectionRuleFactory):
    # Test factory for collection rules of the twitter timeline type; the
    # screen name is generated with Faker's "user_name" provider.
    type = RuleTypeChoices.twitter_timeline
    screen_name = factory.Faker("user_name")

View file

@ -1,3 +1,5 @@
from unittest.mock import Mock
from django.test import TestCase from django.test import TestCase
from newsreader.news.collection.favicon import FaviconBuilder from newsreader.news.collection.favicon import FaviconBuilder
@ -12,8 +14,11 @@ class FaviconBuilderTestCase(TestCase):
def test_simple(self): def test_simple(self):
rule = CollectionRuleFactory(favicon=None) rule = CollectionRuleFactory(favicon=None)
with FaviconBuilder((rule, simple_mock)) as builder: with FaviconBuilder(simple_mock, Mock(rule=rule)) as builder:
builder.build() builder.build()
builder.save()
rule.refresh_from_db()
self.assertEquals(rule.favicon, "https://www.bbc.com/favicon.ico") self.assertEquals(rule.favicon, "https://www.bbc.com/favicon.ico")
@ -22,24 +27,33 @@ class FaviconBuilderTestCase(TestCase):
website_url="https://www.theguardian.com/", favicon=None website_url="https://www.theguardian.com/", favicon=None
) )
with FaviconBuilder((rule, mock_without_url)) as builder: with FaviconBuilder(mock_without_url, Mock(rule=rule)) as builder:
builder.build() builder.build()
builder.save()
rule.refresh_from_db()
self.assertEquals(rule.favicon, "https://www.theguardian.com/favicon.ico") self.assertEquals(rule.favicon, "https://www.theguardian.com/favicon.ico")
def test_without_header(self): def test_without_header(self):
rule = CollectionRuleFactory(favicon=None) rule = CollectionRuleFactory(favicon=None)
with FaviconBuilder((rule, mock_without_header)) as builder: with FaviconBuilder(mock_without_header, Mock(rule=rule)) as builder:
builder.build() builder.build()
builder.save()
rule.refresh_from_db()
self.assertEquals(rule.favicon, None) self.assertEquals(rule.favicon, None)
def test_weird_path(self): def test_weird_path(self):
rule = CollectionRuleFactory(favicon=None) rule = CollectionRuleFactory(favicon=None)
with FaviconBuilder((rule, mock_with_weird_path)) as builder: with FaviconBuilder(mock_with_weird_path, Mock(rule=rule)) as builder:
builder.build() builder.build()
builder.save()
rule.refresh_from_db()
self.assertEquals( self.assertEquals(
rule.favicon, "https://www.theguardian.com/jabadaba/doe/favicon.ico" rule.favicon, "https://www.theguardian.com/jabadaba/doe/favicon.ico"
@ -48,15 +62,21 @@ class FaviconBuilderTestCase(TestCase):
def test_other_url(self): def test_other_url(self):
rule = CollectionRuleFactory(favicon=None) rule = CollectionRuleFactory(favicon=None)
with FaviconBuilder((rule, mock_with_other_url)) as builder: with FaviconBuilder(mock_with_other_url, Mock(rule=rule)) as builder:
builder.build() builder.build()
builder.save()
rule.refresh_from_db()
self.assertEquals(rule.favicon, "https://www.theguardian.com/icon.png") self.assertEquals(rule.favicon, "https://www.theguardian.com/icon.png")
def test_url_with_favicon_takes_precedence(self): def test_url_with_favicon_takes_precedence(self):
rule = CollectionRuleFactory(favicon=None) rule = CollectionRuleFactory(favicon=None)
with FaviconBuilder((rule, mock_with_multiple_icons)) as builder: with FaviconBuilder(mock_with_multiple_icons, Mock(rule=rule)) as builder:
builder.build() builder.build()
builder.save()
rule.refresh_from_db()
self.assertEquals(rule.favicon, "https://www.bbc.com/favicon.ico") self.assertEquals(rule.favicon, "https://www.bbc.com/favicon.ico")

View file

@ -1,4 +1,4 @@
from unittest.mock import MagicMock from unittest.mock import Mock
from django.test import TestCase from django.test import TestCase
@ -19,22 +19,22 @@ class FaviconClientTestCase(TestCase):
def test_simple(self): def test_simple(self):
rule = CollectionRuleFactory() rule = CollectionRuleFactory()
stream = MagicMock(url="https://www.bbc.com") stream = Mock(url="https://www.bbc.com", rule=rule)
stream.read.return_value = (simple_mock, stream) stream.read.return_value = (simple_mock, stream)
with FaviconClient([(rule, stream)]) as client: with FaviconClient([stream]) as client:
for rule, data in client: for payload, stream in client:
self.assertEquals(rule.pk, rule.pk) self.assertEquals(stream.rule.pk, rule.pk)
self.assertEquals(data, simple_mock) self.assertEquals(payload, simple_mock)
stream.read.assert_called_once_with() stream.read.assert_called_once_with()
def test_client_catches_stream_exception(self): def test_client_catches_stream_exception(self):
rule = CollectionRuleFactory(error=None, succeeded=True) rule = CollectionRuleFactory(error=None, succeeded=True)
stream = MagicMock(url="https://www.bbc.com") stream = Mock(url="https://www.bbc.com", rule=rule)
stream.read.side_effect = StreamException stream.read.side_effect = StreamException
with FaviconClient([(rule, stream)]) as client: with FaviconClient([stream]) as client:
for rule, data in client: for rule, data in client:
pass pass
@ -46,10 +46,10 @@ class FaviconClientTestCase(TestCase):
def test_client_catches_stream_not_found_exception(self): def test_client_catches_stream_not_found_exception(self):
rule = CollectionRuleFactory(error=None, succeeded=True) rule = CollectionRuleFactory(error=None, succeeded=True)
stream = MagicMock(url="https://www.bbc.com") stream = Mock(url="https://www.bbc.com", rule=rule)
stream.read.side_effect = StreamNotFoundException stream.read.side_effect = StreamNotFoundException
with FaviconClient([(rule, stream)]) as client: with FaviconClient([stream]) as client:
for rule, data in client: for rule, data in client:
pass pass
@ -61,10 +61,10 @@ class FaviconClientTestCase(TestCase):
def test_client_catches_stream_denied_exception(self): def test_client_catches_stream_denied_exception(self):
rule = CollectionRuleFactory(error=None, succeeded=True) rule = CollectionRuleFactory(error=None, succeeded=True)
stream = MagicMock(url="https://www.bbc.com") stream = Mock(url="https://www.bbc.com", rule=rule)
stream.read.side_effect = StreamDeniedException stream.read.side_effect = StreamDeniedException
with FaviconClient([(rule, stream)]) as client: with FaviconClient([stream]) as client:
for rule, data in client: for rule, data in client:
pass pass
@ -76,10 +76,10 @@ class FaviconClientTestCase(TestCase):
def test_client_catches_stream_timed_out(self): def test_client_catches_stream_timed_out(self):
rule = CollectionRuleFactory(error=None, succeeded=True) rule = CollectionRuleFactory(error=None, succeeded=True)
stream = MagicMock(url="https://www.bbc.com") stream = Mock(url="https://www.bbc.com", rule=rule)
stream.read.side_effect = StreamTimeOutException stream.read.side_effect = StreamTimeOutException
with FaviconClient([(rule, stream)]) as client: with FaviconClient([stream]) as client:
for rule, data in client: for rule, data in client:
pass pass

View file

@ -1,4 +1,4 @@
from unittest.mock import MagicMock, patch from unittest.mock import Mock, patch
from django.test import TestCase from django.test import TestCase
@ -38,8 +38,8 @@ class FaviconCollectorTestCase(TestCase):
def test_simple(self): def test_simple(self):
rule = CollectionRuleFactory(succeeded=True, error=None) rule = CollectionRuleFactory(succeeded=True, error=None)
self.mocked_feed_client.return_value = [(feed_mock, MagicMock(rule=rule))] self.mocked_feed_client.return_value = [(feed_mock, Mock(rule=rule))]
self.mocked_website_read.return_value = (website_mock, MagicMock()) self.mocked_website_read.return_value = (website_mock, Mock(rule=rule))
collector = FaviconCollector() collector = FaviconCollector()
collector.collect() collector.collect()
@ -54,8 +54,11 @@ class FaviconCollectorTestCase(TestCase):
def test_empty_stream(self): def test_empty_stream(self):
rule = CollectionRuleFactory(succeeded=True, error=None) rule = CollectionRuleFactory(succeeded=True, error=None)
self.mocked_feed_client.return_value = [(feed_mock, MagicMock(rule=rule))] self.mocked_feed_client.return_value = [(feed_mock, Mock(rule=rule))]
self.mocked_website_read.return_value = (BeautifulSoup("", "lxml"), MagicMock()) self.mocked_website_read.return_value = (
BeautifulSoup("", "lxml"),
Mock(rule=rule),
)
collector = FaviconCollector() collector = FaviconCollector()
collector.collect() collector.collect()
@ -70,7 +73,7 @@ class FaviconCollectorTestCase(TestCase):
def test_not_found(self): def test_not_found(self):
rule = CollectionRuleFactory(succeeded=True, error=None) rule = CollectionRuleFactory(succeeded=True, error=None)
self.mocked_feed_client.return_value = [(feed_mock, MagicMock(rule=rule))] self.mocked_feed_client.return_value = [(feed_mock, Mock(rule=rule))]
self.mocked_website_read.side_effect = StreamNotFoundException self.mocked_website_read.side_effect = StreamNotFoundException
collector = FaviconCollector() collector = FaviconCollector()
@ -86,7 +89,7 @@ class FaviconCollectorTestCase(TestCase):
def test_denied(self): def test_denied(self):
rule = CollectionRuleFactory(succeeded=True, error=None) rule = CollectionRuleFactory(succeeded=True, error=None)
self.mocked_feed_client.return_value = [(feed_mock, MagicMock(rule=rule))] self.mocked_feed_client.return_value = [(feed_mock, Mock(rule=rule))]
self.mocked_website_read.side_effect = StreamDeniedException self.mocked_website_read.side_effect = StreamDeniedException
collector = FaviconCollector() collector = FaviconCollector()
@ -102,7 +105,7 @@ class FaviconCollectorTestCase(TestCase):
def test_forbidden(self): def test_forbidden(self):
rule = CollectionRuleFactory(succeeded=True, error=None) rule = CollectionRuleFactory(succeeded=True, error=None)
self.mocked_feed_client.return_value = [(feed_mock, MagicMock(rule=rule))] self.mocked_feed_client.return_value = [(feed_mock, Mock(rule=rule))]
self.mocked_website_read.side_effect = StreamForbiddenException self.mocked_website_read.side_effect = StreamForbiddenException
collector = FaviconCollector() collector = FaviconCollector()
@ -118,7 +121,7 @@ class FaviconCollectorTestCase(TestCase):
def test_timed_out(self): def test_timed_out(self):
rule = CollectionRuleFactory(succeeded=True, error=None) rule = CollectionRuleFactory(succeeded=True, error=None)
self.mocked_feed_client.return_value = [(feed_mock, MagicMock(rule=rule))] self.mocked_feed_client.return_value = [(feed_mock, Mock(rule=rule))]
self.mocked_website_read.side_effect = StreamTimeOutException self.mocked_website_read.side_effect = StreamTimeOutException
collector = FaviconCollector() collector = FaviconCollector()
@ -134,7 +137,7 @@ class FaviconCollectorTestCase(TestCase):
def test_wrong_stream_content_type(self): def test_wrong_stream_content_type(self):
rule = CollectionRuleFactory(succeeded=True, error=None) rule = CollectionRuleFactory(succeeded=True, error=None)
self.mocked_feed_client.return_value = [(feed_mock, MagicMock(rule=rule))] self.mocked_feed_client.return_value = [(feed_mock, Mock(rule=rule))]
self.mocked_website_read.side_effect = StreamParseException self.mocked_website_read.side_effect = StreamParseException
collector = FaviconCollector() collector = FaviconCollector()

View file

@ -1,5 +1,5 @@
from datetime import date, datetime, time from datetime import date, datetime, time
from unittest.mock import MagicMock from unittest.mock import Mock
from django.test import TestCase from django.test import TestCase
from django.utils import timezone from django.utils import timezone
@ -24,9 +24,10 @@ class FeedBuilderTestCase(TestCase):
def test_basic_entry(self): def test_basic_entry(self):
builder = FeedBuilder builder = FeedBuilder
rule = FeedFactory() rule = FeedFactory()
mock_stream = MagicMock(rule=rule) mock_stream = Mock(rule=rule)
with builder((simple_mock, mock_stream)) as builder: with builder(simple_mock, mock_stream) as builder:
builder.build()
builder.save() builder.save()
post = Post.objects.get() post = Post.objects.get()
@ -55,9 +56,10 @@ class FeedBuilderTestCase(TestCase):
def test_multiple_entries(self): def test_multiple_entries(self):
builder = FeedBuilder builder = FeedBuilder
rule = FeedFactory() rule = FeedFactory()
mock_stream = MagicMock(rule=rule) mock_stream = Mock(rule=rule)
with builder((multiple_mock, mock_stream)) as builder: with builder(multiple_mock, mock_stream) as builder:
builder.build()
builder.save() builder.save()
posts = Post.objects.order_by("-publication_date") posts = Post.objects.order_by("-publication_date")
@ -116,9 +118,10 @@ class FeedBuilderTestCase(TestCase):
def test_entries_without_remote_identifier(self): def test_entries_without_remote_identifier(self):
builder = FeedBuilder builder = FeedBuilder
rule = FeedFactory() rule = FeedFactory()
mock_stream = MagicMock(rule=rule) mock_stream = Mock(rule=rule)
with builder((mock_without_identifier, mock_stream)) as builder: with builder(mock_without_identifier, mock_stream) as builder:
builder.build()
builder.save() builder.save()
posts = Post.objects.order_by("-publication_date") posts = Post.objects.order_by("-publication_date")
@ -155,9 +158,10 @@ class FeedBuilderTestCase(TestCase):
def test_entry_without_publication_date(self): def test_entry_without_publication_date(self):
builder = FeedBuilder builder = FeedBuilder
rule = FeedFactory() rule = FeedFactory()
mock_stream = MagicMock(rule=rule) mock_stream = Mock(rule=rule)
with builder((mock_without_publish_date, mock_stream)) as builder: with builder(mock_without_publish_date, mock_stream) as builder:
builder.build()
builder.save() builder.save()
posts = Post.objects.order_by("-publication_date") posts = Post.objects.order_by("-publication_date")
@ -187,9 +191,10 @@ class FeedBuilderTestCase(TestCase):
def test_entry_without_url(self): def test_entry_without_url(self):
builder = FeedBuilder builder = FeedBuilder
rule = FeedFactory() rule = FeedFactory()
mock_stream = MagicMock(rule=rule) mock_stream = Mock(rule=rule)
with builder((mock_without_url, mock_stream)) as builder: with builder(mock_without_url, mock_stream) as builder:
builder.build()
builder.save() builder.save()
posts = Post.objects.order_by("-publication_date") posts = Post.objects.order_by("-publication_date")
@ -213,9 +218,10 @@ class FeedBuilderTestCase(TestCase):
def test_entry_without_body(self): def test_entry_without_body(self):
builder = FeedBuilder builder = FeedBuilder
rule = FeedFactory() rule = FeedFactory()
mock_stream = MagicMock(rule=rule) mock_stream = Mock(rule=rule)
with builder((mock_without_body, mock_stream)) as builder: with builder(mock_without_body, mock_stream) as builder:
builder.build()
builder.save() builder.save()
posts = Post.objects.order_by("-publication_date") posts = Post.objects.order_by("-publication_date")
@ -247,9 +253,10 @@ class FeedBuilderTestCase(TestCase):
def test_entry_without_author(self): def test_entry_without_author(self):
builder = FeedBuilder builder = FeedBuilder
rule = FeedFactory() rule = FeedFactory()
mock_stream = MagicMock(rule=rule) mock_stream = Mock(rule=rule)
with builder((mock_without_author, mock_stream)) as builder: with builder(mock_without_author, mock_stream) as builder:
builder.build()
builder.save() builder.save()
posts = Post.objects.order_by("-publication_date") posts = Post.objects.order_by("-publication_date")
@ -275,9 +282,10 @@ class FeedBuilderTestCase(TestCase):
def test_empty_entries(self): def test_empty_entries(self):
builder = FeedBuilder builder = FeedBuilder
rule = FeedFactory() rule = FeedFactory()
mock_stream = MagicMock(rule=rule) mock_stream = Mock(rule=rule)
with builder((mock_without_entries, mock_stream)) as builder: with builder(mock_without_entries, mock_stream) as builder:
builder.build()
builder.save() builder.save()
self.assertEquals(Post.objects.count(), 0) self.assertEquals(Post.objects.count(), 0)
@ -285,7 +293,7 @@ class FeedBuilderTestCase(TestCase):
def test_update_entries(self): def test_update_entries(self):
builder = FeedBuilder builder = FeedBuilder
rule = FeedFactory() rule = FeedFactory()
mock_stream = MagicMock(rule=rule) mock_stream = Mock(rule=rule)
existing_first_post = FeedPostFactory.create( existing_first_post = FeedPostFactory.create(
remote_identifier="28f79ae4-8f9a-11e9-b143-00163ef6bee7", rule=rule remote_identifier="28f79ae4-8f9a-11e9-b143-00163ef6bee7", rule=rule
@ -295,7 +303,8 @@ class FeedBuilderTestCase(TestCase):
remote_identifier="a5479c66-8fae-11e9-8422-00163ef6bee7", rule=rule remote_identifier="a5479c66-8fae-11e9-8422-00163ef6bee7", rule=rule
) )
with builder((mock_with_update_entries, mock_stream)) as builder: with builder(mock_with_update_entries, mock_stream) as builder:
builder.build()
builder.save() builder.save()
self.assertEquals(Post.objects.count(), 3) self.assertEquals(Post.objects.count(), 3)
@ -315,9 +324,10 @@ class FeedBuilderTestCase(TestCase):
def test_html_sanitizing(self): def test_html_sanitizing(self):
builder = FeedBuilder builder = FeedBuilder
rule = FeedFactory() rule = FeedFactory()
mock_stream = MagicMock(rule=rule) mock_stream = Mock(rule=rule)
with builder((mock_with_html, mock_stream)) as builder: with builder(mock_with_html, mock_stream) as builder:
builder.build()
builder.save() builder.save()
post = Post.objects.get() post = Post.objects.get()
@ -337,9 +347,10 @@ class FeedBuilderTestCase(TestCase):
def test_long_author_text_is_truncated(self): def test_long_author_text_is_truncated(self):
builder = FeedBuilder builder = FeedBuilder
rule = FeedFactory() rule = FeedFactory()
mock_stream = MagicMock(rule=rule) mock_stream = Mock(rule=rule)
with builder((mock_with_long_author, mock_stream)) as builder: with builder(mock_with_long_author, mock_stream) as builder:
builder.build()
builder.save() builder.save()
post = Post.objects.get() post = Post.objects.get()
@ -351,9 +362,10 @@ class FeedBuilderTestCase(TestCase):
def test_long_title_text_is_truncated(self): def test_long_title_text_is_truncated(self):
builder = FeedBuilder builder = FeedBuilder
rule = FeedFactory() rule = FeedFactory()
mock_stream = MagicMock(rule=rule) mock_stream = Mock(rule=rule)
with builder((mock_with_long_title, mock_stream)) as builder: with builder(mock_with_long_title, mock_stream) as builder:
builder.build()
builder.save() builder.save()
post = Post.objects.get() post = Post.objects.get()
@ -366,9 +378,10 @@ class FeedBuilderTestCase(TestCase):
def test_long_title_exotic_title(self): def test_long_title_exotic_title(self):
builder = FeedBuilder builder = FeedBuilder
rule = FeedFactory() rule = FeedFactory()
mock_stream = MagicMock(rule=rule) mock_stream = Mock(rule=rule)
with builder((mock_with_long_exotic_title, mock_stream)) as builder: with builder(mock_with_long_exotic_title, mock_stream) as builder:
builder.build()
builder.save() builder.save()
post = Post.objects.get() post = Post.objects.get()
@ -381,9 +394,10 @@ class FeedBuilderTestCase(TestCase):
def test_content_detail_is_prioritized_if_longer(self): def test_content_detail_is_prioritized_if_longer(self):
builder = FeedBuilder builder = FeedBuilder
rule = FeedFactory() rule = FeedFactory()
mock_stream = MagicMock(rule=rule) mock_stream = Mock(rule=rule)
with builder((mock_with_longer_content_detail, mock_stream)) as builder: with builder(mock_with_longer_content_detail, mock_stream) as builder:
builder.build()
builder.save() builder.save()
post = Post.objects.get() post = Post.objects.get()
@ -398,9 +412,10 @@ class FeedBuilderTestCase(TestCase):
def test_content_detail_is_not_prioritized_if_shorter(self): def test_content_detail_is_not_prioritized_if_shorter(self):
builder = FeedBuilder builder = FeedBuilder
rule = FeedFactory() rule = FeedFactory()
mock_stream = MagicMock(rule=rule) mock_stream = Mock(rule=rule)
with builder((mock_with_shorter_content_detail, mock_stream)) as builder: with builder(mock_with_shorter_content_detail, mock_stream) as builder:
builder.build()
builder.save() builder.save()
post = Post.objects.get() post = Post.objects.get()
@ -414,9 +429,10 @@ class FeedBuilderTestCase(TestCase):
def test_content_detail_is_concatinated(self): def test_content_detail_is_concatinated(self):
builder = FeedBuilder builder = FeedBuilder
rule = FeedFactory() rule = FeedFactory()
mock_stream = MagicMock(rule=rule) mock_stream = Mock(rule=rule)
with builder((mock_with_multiple_content_detail, mock_stream)) as builder: with builder(mock_with_multiple_content_detail, mock_stream) as builder:
builder.build()
builder.save() builder.save()
post = Post.objects.get() post = Post.objects.get()

View file

@ -1,4 +1,4 @@
from unittest.mock import MagicMock, patch from unittest.mock import Mock, patch
from django.test import TestCase from django.test import TestCase
from django.utils.lorem_ipsum import words from django.utils.lorem_ipsum import words
@ -28,7 +28,7 @@ class FeedClientTestCase(TestCase):
def test_client_retrieves_single_rules(self): def test_client_retrieves_single_rules(self):
rule = FeedFactory.create() rule = FeedFactory.create()
mock_stream = MagicMock(rule=rule) mock_stream = Mock(rule=rule)
self.mocked_read.return_value = (simple_mock, mock_stream) self.mocked_read.return_value = (simple_mock, mock_stream)

View file

@ -1,6 +1,6 @@
from datetime import date, datetime, time from datetime import date, datetime, time
from time import struct_time from time import struct_time
from unittest.mock import MagicMock, patch from unittest.mock import Mock, patch
from django.test import TestCase from django.test import TestCase
from django.utils import timezone from django.utils import timezone
@ -26,6 +26,7 @@ from newsreader.news.core.tests.factories import FeedPostFactory
from .mocks import duplicate_mock, empty_mock, multiple_mock, multiple_update_mock from .mocks import duplicate_mock, empty_mock, multiple_mock, multiple_update_mock
@freeze_time("2019-10-30 12:30:00")
class FeedCollectorTestCase(TestCase): class FeedCollectorTestCase(TestCase):
def setUp(self): def setUp(self):
self.maxDiff = None self.maxDiff = None
@ -39,43 +40,42 @@ class FeedCollectorTestCase(TestCase):
def tearDown(self):
    # Undo every patch that was activated through ``start()`` during the test.
    patch.stopall()
@freeze_time("2019-10-30 12:30:00")
def test_simple_batch(self):
    """Collect one rule whose parsed payload is ``multiple_mock``.

    Expects three stored posts and a rule marked as succeeded, with no
    error and ``last_run`` equal to the current time.
    """
    self.mocked_parse.return_value = multiple_mock

    rule = FeedFactory()
    collector = FeedCollector()
    collector.collect(rules=[rule])

    rule.refresh_from_db()

    self.assertEqual(Post.objects.count(), 3)
    self.assertEqual(rule.succeeded, True)
    # Comparing against ``timezone.now()`` only works while time is frozen.
    self.assertEqual(rule.last_run, timezone.now())
    self.assertIsNone(rule.error)
@freeze_time("2019-10-30 12:30:00")
def test_emtpy_batch(self):
    """Collect one rule whose parsed payload is ``empty_mock``.

    Expects no stored posts while the rule is still marked as succeeded,
    with no error and ``last_run`` equal to the current time.
    """
    self.mocked_fetch.return_value = Mock()
    self.mocked_parse.return_value = empty_mock

    rule = FeedFactory()
    collector = FeedCollector()
    collector.collect(rules=[rule])

    rule.refresh_from_db()

    self.assertEqual(Post.objects.count(), 0)
    self.assertEqual(rule.succeeded, True)
    self.assertIsNone(rule.error)
    # Comparing against ``timezone.now()`` only works while time is frozen.
    self.assertEqual(rule.last_run, timezone.now())
def test_not_found(self): def test_not_found(self):
self.mocked_fetch.side_effect = StreamNotFoundException self.mocked_fetch.side_effect = StreamNotFoundException
rule = FeedFactory()
rule = FeedFactory()
collector = FeedCollector() collector = FeedCollector()
collector.collect() collector.collect(rules=[rule])
rule.refresh_from_db() rule.refresh_from_db()
@ -85,58 +85,59 @@ class FeedCollectorTestCase(TestCase):
def test_denied(self):
    """Collect one rule while ``fetch`` raises ``StreamDeniedException``.

    Expects no stored posts, ``succeeded`` set to ``False`` and the
    permission error message recorded on the rule.
    """
    self.mocked_fetch.side_effect = StreamDeniedException

    # NOTE(review): this seed uses the same wall-clock values as the
    # ``freeze_time`` decorator shown above the class. If both resolve to
    # the same instant, the ``last_run`` assertion below cannot tell an
    # updated value from an untouched one — confirm and consider seeding
    # an earlier time.
    old_run = timezone.make_aware(datetime(2019, 10, 30, 12, 30))
    rule = FeedFactory(last_run=old_run)

    collector = FeedCollector()
    collector.collect(rules=[rule])

    rule.refresh_from_db()

    self.assertEqual(Post.objects.count(), 0)
    self.assertEqual(rule.succeeded, False)
    self.assertEqual(rule.error, "Stream does not have sufficient permissions")
    self.assertEqual(rule.last_run, timezone.now())
def test_forbidden(self):
    """Collect one rule while ``fetch`` raises ``StreamForbiddenException``.

    Expects no stored posts, ``succeeded`` set to ``False`` and the
    "Stream forbidden" message recorded on the rule.
    """
    self.mocked_fetch.side_effect = StreamForbiddenException

    # NOTE(review): this seed uses the same UTC wall-clock values as the
    # ``freeze_time`` decorator shown above the class, so the ``last_run``
    # assertion below presumably passes even if the collector never
    # touches the field — confirm and consider seeding an earlier time.
    old_run = pytz.utc.localize(datetime(2019, 10, 30, 12, 30))
    rule = FeedFactory(last_run=old_run)

    collector = FeedCollector()
    collector.collect(rules=[rule])

    rule.refresh_from_db()

    self.assertEqual(Post.objects.count(), 0)
    self.assertEqual(rule.succeeded, False)
    self.assertEqual(rule.error, "Stream forbidden")
    self.assertEqual(rule.last_run, timezone.now())
def test_timed_out(self):
    """Collect one rule while ``fetch`` raises ``StreamTimeOutException``.

    Expects no stored posts, ``succeeded`` set to ``False`` and the
    "Stream timed out" message recorded on the rule.
    """
    self.mocked_fetch.side_effect = StreamTimeOutException

    # NOTE(review): the seed and the expected value in the final assertion
    # are built from the same date and time (2019-10-30 12:30), so that
    # assertion may hold whether or not the collector updates ``last_run``
    # — confirm the intent.
    last_run = timezone.make_aware(
        datetime.combine(date=date(2019, 10, 30), time=time(12, 30))
    )
    rule = FeedFactory(last_run=last_run)

    collector = FeedCollector()
    collector.collect(rules=[rule])

    rule.refresh_from_db()

    self.assertEqual(Post.objects.count(), 0)
    self.assertEqual(rule.succeeded, False)
    self.assertEqual(rule.error, "Stream timed out")
    self.assertEqual(
        rule.last_run, pytz.utc.localize(datetime(2019, 10, 30, 12, 30))
    )
@freeze_time("2019-10-30 12:30:00")
def test_duplicates(self): def test_duplicates(self):
self.mocked_parse.return_value = duplicate_mock self.mocked_parse.return_value = duplicate_mock
rule = FeedFactory() rule = FeedFactory()
aware_datetime = build_publication_date( aware_datetime = build_publication_date(
@ -186,10 +187,9 @@ class FeedCollectorTestCase(TestCase):
self.assertEquals(Post.objects.count(), 3) self.assertEquals(Post.objects.count(), 3)
self.assertEquals(rule.succeeded, True) self.assertEquals(rule.succeeded, True)
self.assertEquals(rule.last_suceeded, timezone.now()) self.assertEquals(rule.last_run, timezone.now())
self.assertEquals(rule.error, None) self.assertEquals(rule.error, None)
@freeze_time("2019-02-22 12:30:00")
def test_items_with_identifiers_get_updated(self): def test_items_with_identifiers_get_updated(self):
self.mocked_parse.return_value = multiple_update_mock self.mocked_parse.return_value = multiple_update_mock
rule = FeedFactory() rule = FeedFactory()
@ -231,7 +231,7 @@ class FeedCollectorTestCase(TestCase):
self.assertEquals(Post.objects.count(), 3) self.assertEquals(Post.objects.count(), 3)
self.assertEquals(rule.succeeded, True) self.assertEquals(rule.succeeded, True)
self.assertEquals(rule.last_suceeded, timezone.now()) self.assertEquals(rule.last_run, timezone.now())
self.assertEquals(rule.error, None) self.assertEquals(rule.error, None)
self.assertEquals( self.assertEquals(
@ -245,23 +245,3 @@ class FeedCollectorTestCase(TestCase):
self.assertEquals( self.assertEquals(
third_post.title, "Birmingham head teacher threatened over LGBT lessons" third_post.title, "Birmingham head teacher threatened over LGBT lessons"
) )
@freeze_time("2019-02-22 12:30:00")
def test_disabled_rules(self):
    """Collect with one disabled and one enabled rule present.

    Expects three stored posts in total. The enabled rule is marked as
    succeeded with no error, while the disabled rule keeps
    ``last_suceeded`` unset and ``succeeded`` equal to ``False``.
    """
    rules = (FeedFactory(enabled=False), FeedFactory(enabled=True))
    disabled_rule, enabled_rule = rules
    self.mocked_parse.return_value = multiple_mock

    collector = FeedCollector()
    collector.collect()

    for rule in rules:
        rule.refresh_from_db()

    self.assertEqual(Post.objects.count(), 3)

    self.assertEqual(enabled_rule.succeeded, True)
    # Comparing against ``timezone.now()`` only works while time is frozen.
    self.assertEqual(enabled_rule.last_suceeded, timezone.now())
    self.assertIsNone(enabled_rule.error)

    self.assertIsNone(disabled_rule.last_suceeded)
    self.assertEqual(disabled_rule.succeeded, False)

View file

@ -1,4 +1,4 @@
from unittest.mock import MagicMock, patch from unittest.mock import Mock, patch
from django.test import TestCase from django.test import TestCase
@ -27,7 +27,7 @@ class FeedStreamTestCase(TestCase):
patch.stopall() patch.stopall()
def test_simple_stream(self): def test_simple_stream(self):
self.mocked_fetch.return_value = MagicMock(content=simple_mock) self.mocked_fetch.return_value = Mock(content=simple_mock)
rule = FeedFactory() rule = FeedFactory()
stream = FeedStream(rule) stream = FeedStream(rule)
@ -95,7 +95,7 @@ class FeedStreamTestCase(TestCase):
@patch("newsreader.news.collection.feed.parse") @patch("newsreader.news.collection.feed.parse")
def test_stream_raises_parse_exception(self, mocked_parse): def test_stream_raises_parse_exception(self, mocked_parse):
self.mocked_fetch.return_value = MagicMock() self.mocked_fetch.return_value = Mock()
mocked_parse.side_effect = TypeError mocked_parse.side_effect = TypeError
rule = FeedFactory() rule = FeedFactory()

View file

@ -1,5 +1,5 @@
from datetime import datetime from datetime import datetime
from unittest.mock import MagicMock from unittest.mock import Mock
from django.test import TestCase from django.test import TestCase
@ -20,9 +20,10 @@ class RedditBuilderTestCase(TestCase):
builder = RedditBuilder builder = RedditBuilder
subreddit = SubredditFactory() subreddit = SubredditFactory()
mock_stream = MagicMock(rule=subreddit) mock_stream = Mock(rule=subreddit)
with builder((simple_mock, mock_stream)) as builder: with builder(simple_mock, mock_stream) as builder:
builder.build()
builder.save() builder.save()
posts = {post.remote_identifier: post for post in Post.objects.all()} posts = {post.remote_identifier: post for post in Post.objects.all()}
@ -65,9 +66,10 @@ class RedditBuilderTestCase(TestCase):
builder = RedditBuilder builder = RedditBuilder
subreddit = SubredditFactory() subreddit = SubredditFactory()
mock_stream = MagicMock(rule=subreddit) mock_stream = Mock(rule=subreddit)
with builder((empty_mock, mock_stream)) as builder: with builder(empty_mock, mock_stream) as builder:
builder.build()
builder.save() builder.save()
self.assertEquals(Post.objects.count(), 0) self.assertEquals(Post.objects.count(), 0)
@ -76,9 +78,10 @@ class RedditBuilderTestCase(TestCase):
builder = RedditBuilder builder = RedditBuilder
subreddit = SubredditFactory() subreddit = SubredditFactory()
mock_stream = MagicMock(rule=subreddit) mock_stream = Mock(rule=subreddit)
with builder((unknown_mock, mock_stream)) as builder: with builder(unknown_mock, mock_stream) as builder:
builder.build()
builder.save() builder.save()
self.assertEquals(Post.objects.count(), 0) self.assertEquals(Post.objects.count(), 0)
@ -95,9 +98,10 @@ class RedditBuilderTestCase(TestCase):
) )
builder = RedditBuilder builder = RedditBuilder
mock_stream = MagicMock(rule=subreddit) mock_stream = Mock(rule=subreddit)
with builder((simple_mock, mock_stream)) as builder: with builder(simple_mock, mock_stream) as builder:
builder.build()
builder.save() builder.save()
posts = {post.remote_identifier: post for post in Post.objects.all()} posts = {post.remote_identifier: post for post in Post.objects.all()}
@ -132,9 +136,10 @@ class RedditBuilderTestCase(TestCase):
builder = RedditBuilder builder = RedditBuilder
subreddit = SubredditFactory() subreddit = SubredditFactory()
mock_stream = MagicMock(rule=subreddit) mock_stream = Mock(rule=subreddit)
with builder((unsanitized_mock, mock_stream)) as builder: with builder(unsanitized_mock, mock_stream) as builder:
builder.build()
builder.save() builder.save()
posts = {post.remote_identifier: post for post in Post.objects.all()} posts = {post.remote_identifier: post for post in Post.objects.all()}
@ -149,9 +154,10 @@ class RedditBuilderTestCase(TestCase):
builder = RedditBuilder builder = RedditBuilder
subreddit = SubredditFactory() subreddit = SubredditFactory()
mock_stream = MagicMock(rule=subreddit) mock_stream = Mock(rule=subreddit)
with builder((author_mock, mock_stream)) as builder: with builder(author_mock, mock_stream) as builder:
builder.build()
builder.save() builder.save()
posts = {post.remote_identifier: post for post in Post.objects.all()} posts = {post.remote_identifier: post for post in Post.objects.all()}
@ -166,9 +172,10 @@ class RedditBuilderTestCase(TestCase):
builder = RedditBuilder builder = RedditBuilder
subreddit = SubredditFactory() subreddit = SubredditFactory()
mock_stream = MagicMock(rule=subreddit) mock_stream = Mock(rule=subreddit)
with builder((title_mock, mock_stream)) as builder: with builder(title_mock, mock_stream) as builder:
builder.build()
builder.save() builder.save()
posts = {post.remote_identifier: post for post in Post.objects.all()} posts = {post.remote_identifier: post for post in Post.objects.all()}
@ -186,9 +193,10 @@ class RedditBuilderTestCase(TestCase):
builder = RedditBuilder builder = RedditBuilder
subreddit = SubredditFactory() subreddit = SubredditFactory()
mock_stream = MagicMock(rule=subreddit) mock_stream = Mock(rule=subreddit)
with builder((duplicate_mock, mock_stream)) as builder: with builder(duplicate_mock, mock_stream) as builder:
builder.build()
builder.save() builder.save()
posts = {post.remote_identifier: post for post in Post.objects.all()} posts = {post.remote_identifier: post for post in Post.objects.all()}
@ -200,13 +208,14 @@ class RedditBuilderTestCase(TestCase):
builder = RedditBuilder builder = RedditBuilder
subreddit = SubredditFactory() subreddit = SubredditFactory()
mock_stream = MagicMock(rule=subreddit) mock_stream = Mock(rule=subreddit)
duplicate_post = RedditPostFactory( duplicate_post = RedditPostFactory(
remote_identifier="hm0qct", rule=subreddit, title="foo" remote_identifier="hm0qct", rule=subreddit, title="foo"
) )
with builder((simple_mock, mock_stream)) as builder: with builder(simple_mock, mock_stream) as builder:
builder.build()
builder.save() builder.save()
posts = {post.remote_identifier: post for post in Post.objects.all()} posts = {post.remote_identifier: post for post in Post.objects.all()}
@ -231,9 +240,10 @@ class RedditBuilderTestCase(TestCase):
builder = RedditBuilder builder = RedditBuilder
subreddit = SubredditFactory() subreddit = SubredditFactory()
mock_stream = MagicMock(rule=subreddit) mock_stream = Mock(rule=subreddit)
with builder((image_mock, mock_stream)) as builder: with builder(image_mock, mock_stream) as builder:
builder.build()
builder.save() builder.save()
posts = {post.remote_identifier: post for post in Post.objects.all()} posts = {post.remote_identifier: post for post in Post.objects.all()}
@ -262,9 +272,10 @@ class RedditBuilderTestCase(TestCase):
builder = RedditBuilder builder = RedditBuilder
subreddit = SubredditFactory() subreddit = SubredditFactory()
mock_stream = MagicMock(rule=subreddit) mock_stream = Mock(rule=subreddit)
with builder((external_image_mock, mock_stream)) as builder: with builder(external_image_mock, mock_stream) as builder:
builder.build()
builder.save() builder.save()
posts = {post.remote_identifier: post for post in Post.objects.all()} posts = {post.remote_identifier: post for post in Post.objects.all()}
@ -302,9 +313,10 @@ class RedditBuilderTestCase(TestCase):
builder = RedditBuilder builder = RedditBuilder
subreddit = SubredditFactory() subreddit = SubredditFactory()
mock_stream = MagicMock(rule=subreddit) mock_stream = Mock(rule=subreddit)
with builder((video_mock, mock_stream)) as builder: with builder(video_mock, mock_stream) as builder:
builder.build()
builder.save() builder.save()
posts = {post.remote_identifier: post for post in Post.objects.all()} posts = {post.remote_identifier: post for post in Post.objects.all()}
@ -328,9 +340,10 @@ class RedditBuilderTestCase(TestCase):
builder = RedditBuilder builder = RedditBuilder
subreddit = SubredditFactory() subreddit = SubredditFactory()
mock_stream = MagicMock(rule=subreddit) mock_stream = Mock(rule=subreddit)
with builder((external_video_mock, mock_stream)) as builder: with builder(external_video_mock, mock_stream) as builder:
builder.build()
builder.save() builder.save()
post = Post.objects.get() post = Post.objects.get()
@ -354,9 +367,10 @@ class RedditBuilderTestCase(TestCase):
builder = RedditBuilder builder = RedditBuilder
subreddit = SubredditFactory() subreddit = SubredditFactory()
mock_stream = MagicMock(rule=subreddit) mock_stream = Mock(rule=subreddit)
with builder((external_gifv_mock, mock_stream)) as builder: with builder(external_gifv_mock, mock_stream) as builder:
builder.build()
builder.save() builder.save()
post = Post.objects.get() post = Post.objects.get()
@ -376,9 +390,10 @@ class RedditBuilderTestCase(TestCase):
builder = RedditBuilder builder = RedditBuilder
subreddit = SubredditFactory() subreddit = SubredditFactory()
mock_stream = MagicMock(rule=subreddit) mock_stream = Mock(rule=subreddit)
with builder((simple_mock, mock_stream)) as builder: with builder(simple_mock, mock_stream) as builder:
builder.build()
builder.save() builder.save()
post = Post.objects.get(remote_identifier="hngsj8") post = Post.objects.get(remote_identifier="hngsj8")
@ -400,9 +415,10 @@ class RedditBuilderTestCase(TestCase):
builder = RedditBuilder builder = RedditBuilder
subreddit = SubredditFactory() subreddit = SubredditFactory()
mock_stream = MagicMock(rule=subreddit) mock_stream = Mock(rule=subreddit)
with builder((unknown_mock, mock_stream)) as builder: with builder(unknown_mock, mock_stream) as builder:
builder.build()
builder.save() builder.save()
self.assertEquals(Post.objects.count(), 0) self.assertEquals(Post.objects.count(), 0)

View file

@ -1,4 +1,4 @@
from unittest.mock import MagicMock, patch from unittest.mock import Mock, patch
from uuid import uuid4 from uuid import uuid4
from django.test import TestCase from django.test import TestCase
@ -31,7 +31,7 @@ class RedditClientTestCase(TestCase):
def test_client_retrieves_single_rules(self): def test_client_retrieves_single_rules(self):
subreddit = SubredditFactory() subreddit = SubredditFactory()
mock_stream = MagicMock(rule=subreddit) mock_stream = Mock(rule=subreddit)
self.mocked_read.return_value = (simple_mock, mock_stream) self.mocked_read.return_value = (simple_mock, mock_stream)
@ -150,7 +150,7 @@ class RedditClientTestCase(TestCase):
def test_client_catches_long_exception_text(self): def test_client_catches_long_exception_text(self):
subreddit = SubredditFactory() subreddit = SubredditFactory()
mock_stream = MagicMock(rule=subreddit) mock_stream = Mock(rule=subreddit)
self.mocked_read.side_effect = StreamParseException(message=words(1000)) self.mocked_read.side_effect = StreamParseException(message=words(1000))

View file

@ -74,7 +74,7 @@ class RedditCollectorTestCase(TestCase):
for subreddit in rules: for subreddit in rules:
with self.subTest(subreddit=subreddit): with self.subTest(subreddit=subreddit):
self.assertEquals(subreddit.succeeded, True) self.assertEquals(subreddit.succeeded, True)
self.assertEquals(subreddit.last_suceeded, timezone.now()) self.assertEquals(subreddit.last_run, timezone.now())
self.assertEquals(subreddit.error, None) self.assertEquals(subreddit.error, None)
post = Post.objects.get( post = Post.objects.get(
@ -133,7 +133,7 @@ class RedditCollectorTestCase(TestCase):
for subreddit in rules: for subreddit in rules:
with self.subTest(subreddit=subreddit): with self.subTest(subreddit=subreddit):
self.assertEquals(subreddit.succeeded, True) self.assertEquals(subreddit.succeeded, True)
self.assertEquals(subreddit.last_suceeded, timezone.now()) self.assertEquals(subreddit.last_run, timezone.now())
self.assertEquals(subreddit.error, None) self.assertEquals(subreddit.error, None)
def test_not_found(self): def test_not_found(self):

View file

@ -25,19 +25,19 @@ class RedditSchedulerTestCase(TestCase):
CollectionRuleFactory( CollectionRuleFactory(
user=user_1, user=user_1,
type=RuleTypeChoices.subreddit, type=RuleTypeChoices.subreddit,
last_suceeded=timezone.now() - timedelta(days=4), last_run=timezone.now() - timedelta(days=4),
enabled=True, enabled=True,
), ),
CollectionRuleFactory( CollectionRuleFactory(
user=user_1, user=user_1,
type=RuleTypeChoices.subreddit, type=RuleTypeChoices.subreddit,
last_suceeded=timezone.now() - timedelta(days=3), last_run=timezone.now() - timedelta(days=3),
enabled=True, enabled=True,
), ),
CollectionRuleFactory( CollectionRuleFactory(
user=user_1, user=user_1,
type=RuleTypeChoices.subreddit, type=RuleTypeChoices.subreddit,
last_suceeded=timezone.now() - timedelta(days=2), last_run=timezone.now() - timedelta(days=2),
enabled=True, enabled=True,
), ),
] ]
@ -46,19 +46,19 @@ class RedditSchedulerTestCase(TestCase):
CollectionRuleFactory( CollectionRuleFactory(
user=user_2, user=user_2,
type=RuleTypeChoices.subreddit, type=RuleTypeChoices.subreddit,
last_suceeded=timezone.now() - timedelta(days=4), last_run=timezone.now() - timedelta(days=4),
enabled=True, enabled=True,
), ),
CollectionRuleFactory( CollectionRuleFactory(
user=user_2, user=user_2,
type=RuleTypeChoices.subreddit, type=RuleTypeChoices.subreddit,
last_suceeded=timezone.now() - timedelta(days=3), last_run=timezone.now() - timedelta(days=3),
enabled=True, enabled=True,
), ),
CollectionRuleFactory( CollectionRuleFactory(
user=user_2, user=user_2,
type=RuleTypeChoices.subreddit, type=RuleTypeChoices.subreddit,
last_suceeded=timezone.now() - timedelta(days=2), last_run=timezone.now() - timedelta(days=2),
enabled=True, enabled=True,
), ),
] ]
@ -87,7 +87,7 @@ class RedditSchedulerTestCase(TestCase):
CollectionRuleFactory.create_batch( CollectionRuleFactory.create_batch(
name=f"rule-{index}", name=f"rule-{index}",
type=RuleTypeChoices.subreddit, type=RuleTypeChoices.subreddit,
last_suceeded=timezone.now() - timedelta(seconds=index), last_run=timezone.now() - timedelta(seconds=index),
enabled=True, enabled=True,
user=user, user=user,
size=15, size=15,
@ -121,7 +121,7 @@ class RedditSchedulerTestCase(TestCase):
CollectionRuleFactory( CollectionRuleFactory(
name=f"rule-{index}", name=f"rule-{index}",
type=RuleTypeChoices.subreddit, type=RuleTypeChoices.subreddit,
last_suceeded=timezone.now() - timedelta(seconds=index), last_run=timezone.now() - timedelta(seconds=index),
enabled=True, enabled=True,
user=user, user=user,
) )

View file

@ -1,10 +1,9 @@
from unittest.mock import MagicMock, patch from unittest.mock import Mock, patch
from django.test import TestCase from django.test import TestCase
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
from newsreader.news.collection.base import URLBuilder, WebsiteStream
from newsreader.news.collection.exceptions import ( from newsreader.news.collection.exceptions import (
StreamDeniedException, StreamDeniedException,
StreamException, StreamException,
@ -13,6 +12,7 @@ from newsreader.news.collection.exceptions import (
StreamParseException, StreamParseException,
StreamTimeOutException, StreamTimeOutException,
) )
from newsreader.news.collection.favicon import WebsiteStream, WebsiteURLBuilder
from newsreader.news.collection.tests.factories import CollectionRuleFactory from newsreader.news.collection.tests.factories import CollectionRuleFactory
from .mocks import feed_mock_without_link, simple_feed_mock, simple_mock from .mocks import feed_mock_without_link, simple_feed_mock, simple_mock
@ -20,117 +20,125 @@ from .mocks import feed_mock_without_link, simple_feed_mock, simple_mock
class WebsiteStreamTestCase(TestCase):
    """Tests for ``WebsiteStream.read`` with the favicon module's ``fetch`` mocked."""

    # Website URL assigned to every rule; ``fetch`` must be called with it.
    website_url = "https://www.bbc.co.uk/news/"

    def setUp(self):
        # Replace ``fetch`` as referenced by the favicon module for each test.
        self.patched_fetch = patch("newsreader.news.collection.favicon.fetch")
        self.mocked_fetch = self.patched_fetch.start()

    def tearDown(self):
        patch.stopall()

    def _build_stream(self):
        """Return a ``WebsiteStream`` for a new rule pointing at ``website_url``."""
        rule = CollectionRuleFactory(website_url=self.website_url)
        return WebsiteStream(rule)

    def _assert_read_raises(self, exception):
        """Make ``fetch`` raise ``exception`` and assert ``read`` raises it too.

        Also verifies that ``fetch`` was called exactly once with the
        rule's website URL.
        """
        self.mocked_fetch.side_effect = exception

        stream = self._build_stream()

        with self.assertRaises(exception):
            stream.read()

        self.mocked_fetch.assert_called_once_with(self.website_url)

    def test_simple(self):
        """``read`` returns the parsed response content together with the stream."""
        self.mocked_fetch.return_value = Mock(content=simple_mock)

        stream = self._build_stream()
        return_value = stream.read()

        self.mocked_fetch.assert_called_once_with(self.website_url)
        self.assertEqual(
            return_value, (BeautifulSoup(simple_mock, features="lxml"), stream)
        )

    def test_raises_exception(self):
        self._assert_read_raises(StreamException)

    def test_raises_denied_exception(self):
        self._assert_read_raises(StreamDeniedException)

    def test_raises_stream_not_found_exception(self):
        self._assert_read_raises(StreamNotFoundException)

    def test_stream_raises_time_out_exception(self):
        self._assert_read_raises(StreamTimeOutException)

    def test_stream_raises_forbidden_exception(self):
        self._assert_read_raises(StreamForbiddenException)

    @patch("newsreader.news.collection.favicon.WebsiteStream.parse")
    def test_stream_raises_parse_exception(self, mocked_parse):
        """``read`` raises ``StreamParseException`` when ``parse`` raises it."""
        self.mocked_fetch.return_value = Mock()
        mocked_parse.side_effect = StreamParseException

        stream = self._build_stream()

        with self.assertRaises(StreamParseException):
            stream.read()

        self.mocked_fetch.assert_called_once_with(self.website_url)
class WebsiteURLBuilderTestCase(TestCase):
    """Tests for the ``website_url`` that ``WebsiteURLBuilder`` saves on a rule."""

    def test_simple(self):
        """Building from ``simple_feed_mock`` stores the BBC news URL on the rule."""
        initial_rule = CollectionRuleFactory()

        with WebsiteURLBuilder(simple_feed_mock, Mock(rule=initial_rule)) as builder:
            builder.build()
            builder.save()

        initial_rule.refresh_from_db()

        self.assertEqual(initial_rule.website_url, "https://www.bbc.co.uk/news/")

    def test_no_link(self):
        """Building from ``feed_mock_without_link`` leaves ``website_url`` unset."""
        initial_rule = CollectionRuleFactory(website_url=None)

        with WebsiteURLBuilder(
            feed_mock_without_link, Mock(rule=initial_rule)
        ) as builder:
            builder.build()
            builder.save()

        initial_rule.refresh_from_db()

        self.assertIsNone(initial_rule.website_url)

    def test_no_data(self):
        """Building from a ``None`` payload leaves ``website_url`` unset."""
        initial_rule = CollectionRuleFactory(website_url=None)

        with WebsiteURLBuilder(None, Mock(rule=initial_rule)) as builder:
            builder.build()
            builder.save()

        initial_rule.refresh_from_db()

        self.assertIsNone(initial_rule.website_url)

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,412 @@
from datetime import datetime
from unittest.mock import Mock
from django.test import TestCase
from django.utils.safestring import mark_safe
import pytz
from ftfy import fix_text
from newsreader.news.collection.tests.factories import TwitterTimelineFactory
from newsreader.news.collection.tests.twitter.builder.mocks import (
gif_mock,
image_mock,
quoted_mock,
retweet_mock,
simple_mock,
unsanitized_mock,
video_mock,
video_without_bitrate_mock,
)
from newsreader.news.collection.twitter import TWITTER_URL, TwitterBuilder
from newsreader.news.collection.utils import truncate_text
from newsreader.news.core.models import Post
from newsreader.news.core.tests.factories import PostFactory
class TwitterBuilderTestCase(TestCase):
    """Tests that ``TwitterBuilder`` turns timeline payloads into saved ``Post`` rows.

    Each test hands a canned payload to the builder together with a mocked
    stream whose ``rule`` is a Twitter timeline, then inspects the posts that
    were written to the database.
    """

    def setUp(self):
        # The compared titles/bodies are long; show full diffs on failure.
        self.maxDiff = None

    def _build(self, payload, profile=None):
        """Run ``TwitterBuilder`` over ``payload`` and return ``(profile, posts)``.

        ``profile`` defaults to a fresh ``RobertsSpaceInd`` timeline. ``posts``
        maps ``remote_identifier`` to every ``Post`` present after saving.
        """
        if profile is None:
            profile = TwitterTimelineFactory(screen_name="RobertsSpaceInd")

        with TwitterBuilder(payload, Mock(rule=profile)) as builder:
            builder.build()
            builder.save()

        posts = {post.remote_identifier: post for post in Post.objects.all()}

        return profile, posts

    def test_simple_post(self):
        """Plain tweets are stored with title, linked body, author, url and date."""
        profile, posts = self._build(simple_mock)

        self.assertCountEqual(
            ("1291528756373286914", "1288550304095416320"), posts.keys()
        )

        post = posts["1291528756373286914"]

        full_text = (
            "@ArieNeoSC Here you go, goodnight!\n\n"
            """<a href="https://t.co/trAcIxBMlX" rel="nofollow">https://t.co/trAcIxBMlX</a>"""
        )

        self.assertEqual(post.rule, profile)
        self.assertEqual(
            post.title,
            truncate_text(
                Post,
                "title",
                "@ArieNeoSC Here you go, goodnight!\n\nhttps://t.co/trAcIxBMlX",
            ),
        )
        self.assertEqual(post.body, mark_safe(full_text))
        self.assertEqual(post.author, "RobertsSpaceInd")
        self.assertEqual(
            post.url, f"{TWITTER_URL}/RobertsSpaceInd/status/1291528756373286914"
        )
        self.assertEqual(
            post.publication_date, pytz.utc.localize(datetime(2020, 8, 7, 0, 17, 5))
        )

        post = posts["1288550304095416320"]

        full_text = "@RelicCcb Hi Christoper, we have checked the status of your investigation and it is still ongoing."

        self.assertEqual(post.rule, profile)
        self.assertEqual(post.title, truncate_text(Post, "title", full_text))
        self.assertEqual(post.body, mark_safe(full_text))
        self.assertEqual(post.author, "RobertsSpaceInd")
        self.assertEqual(
            post.url, f"{TWITTER_URL}/RobertsSpaceInd/status/1288550304095416320"
        )
        self.assertEqual(
            post.publication_date, pytz.utc.localize(datetime(2020, 7, 29, 19, 1, 47))
        )

    # note that only one media type can be uploaded to a Tweet
    # see https://developer.twitter.com/en/docs/tweets/data-dictionary/overview/extended-entities-object
    def test_images_in_post(self):
        """Every attached photo is rendered as a lazily loaded ``<img>`` in the body."""
        profile, posts = self._build(image_mock)

        self.assertCountEqual(("1269039237166321664",), posts.keys())

        post = posts["1269039237166321664"]

        self.assertEqual(post.rule, profile)
        self.assertEqual(post.title, "_ https://t.co/VjEeDrL1iA")
        self.assertEqual(post.author, "RobertsSpaceInd")
        self.assertEqual(
            post.url, f"{TWITTER_URL}/RobertsSpaceInd/status/1269039237166321664"
        )
        self.assertEqual(
            post.publication_date, pytz.utc.localize(datetime(2020, 6, 5, 22, 51, 46))
        )

        self.assertInHTML(
            """<a href="https://t.co/VjEeDrL1iA" rel="nofollow">https://t.co/VjEeDrL1iA</a>""",
            post.body,
            count=1,
        )
        self.assertInHTML(
            """<div><img alt="1269039233072689152" src="https://pbs.twimg.com/media/EZyIdXVU8AACPCz.jpg" loading="lazy"></div>""",
            post.body,
            count=1,
        )
        self.assertInHTML(
            """<div><img alt="1269039233068527618" src="https://pbs.twimg.com/media/EZyIdXUVcAI3Cju.jpg" loading="lazy"></div>""",
            post.body,
            count=1,
        )

    def test_videos_in_post(self):
        """An attached video is rendered as a ``<video>`` element with an mp4 source."""
        profile, posts = self._build(video_mock)

        self.assertCountEqual(
            ("1291080532361527296", "1291079386821582849"), posts.keys()
        )

        post = posts["1291080532361527296"]

        full_text = fix_text(
            "Small enough to access hard-to-reach ore deposits, but with enough"
            " power to get through the tough jobs, Greycat\u2019s ROC perfectly"
            " complements any mining operation. \n\nDetails:"
            """ <a href="https://t.co/2aH7qdOfSk" rel="nofollow">https://t.co/2aH7qdOfSk</a>"""
            """ <a href="https://t.co/mZ8CAuq3SH" rel="nofollow">https://t.co/mZ8CAuq3SH</a>"""
        )

        self.assertEqual(post.rule, profile)
        self.assertEqual(
            post.title,
            truncate_text(
                Post,
                "title",
                fix_text(
                    "Small enough to access hard-to-reach ore deposits, but with enough"
                    " power to get through the tough jobs, Greycat\u2019s ROC perfectly"
                    " complements any mining operation. \n\nDetails:"
                    " https://t.co/2aH7qdOfSk https://t.co/mZ8CAuq3SH"
                ),
            ),
        )
        self.assertEqual(post.author, "RobertsSpaceInd")
        self.assertEqual(
            post.url, f"{TWITTER_URL}/RobertsSpaceInd/status/1291080532361527296"
        )
        self.assertEqual(
            post.publication_date, pytz.utc.localize(datetime(2020, 8, 5, 18, 36, 0))
        )

        self.assertIn(full_text, post.body)
        self.assertInHTML(
            """<div><video controls muted=""><source src="https://video.twimg.com/amplify_video/1291074294747770880/vid/1280x720/J05_p6q74ZUN4csg.mp4?tag=13" type="video/mp4" /></video></div>""",
            post.body,
            count=1,
        )

    def test_video_without_bitrate(self):
        """With the bitrate-less payload the m3u8 playlist is used as video source."""
        _, posts = self._build(video_without_bitrate_mock)

        self.assertCountEqual(("1291080532361527296",), posts.keys())

        post = posts["1291080532361527296"]

        self.assertInHTML(
            """<div><video controls muted=""><source src="https://video.twimg.com/amplify_video/1291074294747770880/pl/kMYgFEoRyoW99o-i.m3u8?tag=13" type="application/x-mpegURL"></video></div>""",
            post.body,
            count=1,
        )

    def test_GIFs_in_post(self):
        """An animated GIF is rendered as a ``<video>`` element with an mp4 source."""
        _, posts = self._build(gif_mock)

        self.assertCountEqual(
            ("1289337776140296193", "1288965215648849920"), posts.keys()
        )

        post = posts["1289337776140296193"]

        self.assertInHTML(
            """<div><video controls muted=""><source src="https://video.twimg.com/tweet_video/EeSl3sPUcAAyE4J.mp4" type="video/mp4"></video></div>""",
            post.body,
            count=1,
        )
        self.assertIn(
            """@Xenosystems <a href="https://t.co/wxvioLCJ6h" rel="nofollow">https://t.co/wxvioLCJ6h</a>""",
            post.body,
        )

    def test_retweet_post(self):
        """A retweet's body holds the retweet text and the ``Original tweet:`` text."""
        _, posts = self._build(retweet_mock)

        self.assertCountEqual(
            ("1291117030486106112", "1288825524878336000"), posts.keys()
        )

        post = posts["1291117030486106112"]

        self.assertIn(
            fix_text(
                "RT @Narayan_N7: New video! #StarCitizen 3.9 vs. 3.10 comparison!\nSo,"
                " the patch 3.10 came out, which brought us quite a lot of changes!\ud83d\ude42\nPle\u2026"
            ),
            post.body,
        )
        self.assertIn(
            fix_text(
                "Original tweet: New video! #StarCitizen 3.9 vs. 3.10 comparison!\nSo, the patch"
                " 3.10 came out, which brought us quite a lot of changes!\ud83d\ude42\nPlease,"
                " share it with your friends!\ud83d\ude4f\n\nEnjoy watching and stay safe!"
                " \u2764\ufe0f\u263a\ufe0f\n@RobertsSpaceInd\n\n@CloudImperium\n\n"
                """<a href="https://t.co/j4QahHzbw4" rel="nofollow">https://t.co/j4QahHzbw4</a>"""
            ),
            post.body,
        )

    def test_quoted_post(self):
        """A quote tweet's body holds the comment and the ``Quoted tweet:`` text."""
        _, posts = self._build(quoted_mock)

        self.assertCountEqual(
            ("1290801039075979264", "1289320160021495809"), posts.keys()
        )

        post = posts["1290801039075979264"]

        self.assertIn(
            fix_text(
                "Bonne nuit \ud83c\udf3a\ud83d\udeeb"
                """ <a href="https://t.co/WyznJwCJLp" rel="nofollow">https://t.co/WyznJwCJLp</a>"""
            ),
            post.body,
        )
        self.assertIn(
            fix_text(
                "Quoted tweet: #Starcitizen Le jeu est beau. Bonne nuit"
                """ @RobertsSpaceInd <a href="https://t.co/xCXun68V3r" rel="nofollow">https://t.co/xCXun68V3r</a>"""
            ),
            post.body,
        )

    def test_empty_data(self):
        """An empty payload creates no posts."""
        self._build([])

        self.assertEqual(Post.objects.count(), 0)

    def test_html_sanitizing(self):
        """``<script>`` is stripped from title and body while ``<article>`` survives."""
        profile, posts = self._build(unsanitized_mock)

        self.assertCountEqual(("1291528756373286914",), posts.keys())

        post = posts["1291528756373286914"]

        full_text = (
            "@ArieNeoSC Here you go, goodnight!\n\n"
            """<a href="https://t.co/trAcIxBMlX" rel="nofollow">https://t.co/trAcIxBMlX</a>"""
            " <article></article>"
        )

        self.assertEqual(post.rule, profile)
        self.assertEqual(
            post.title,
            truncate_text(
                Post,
                "title",
                "@ArieNeoSC Here you go, goodnight!\n\nhttps://t.co/trAcIxBMlX"
                " <article></article>",
            ),
        )
        self.assertEqual(post.body, mark_safe(full_text))

        self.assertInHTML("<script></script>", post.body, count=0)
        self.assertInHTML("<article></article>", post.body, count=1)

        self.assertInHTML("<script></script>", post.title, count=0)
        self.assertInHTML("<article></article>", post.title, count=1)

    def test_urlize_on_urls(self):
        """URLs become ``nofollow`` anchors in the body but stay plain text in the title."""
        profile, posts = self._build(simple_mock)

        self.assertCountEqual(
            ("1291528756373286914", "1288550304095416320"), posts.keys()
        )

        post = posts["1291528756373286914"]

        full_text = (
            "@ArieNeoSC Here you go, goodnight!\n\n"
            """<a href="https://t.co/trAcIxBMlX" rel="nofollow">https://t.co/trAcIxBMlX</a>"""
        )

        self.assertEqual(post.rule, profile)
        self.assertEqual(
            post.title,
            truncate_text(
                Post,
                "title",
                "@ArieNeoSC Here you go, goodnight!\n\nhttps://t.co/trAcIxBMlX",
            ),
        )
        self.assertEqual(post.body, mark_safe(full_text))

    def test_existing_posts(self):
        """Tweets whose identifiers are already stored do not create duplicate posts."""
        profile = TwitterTimelineFactory(screen_name="RobertsSpaceInd")

        PostFactory(rule=profile, remote_identifier="1291528756373286914")
        PostFactory(rule=profile, remote_identifier="1288550304095416320")

        self._build(simple_mock, profile=profile)

        self.assertEqual(Post.objects.count(), 2)

View file

@ -0,0 +1,225 @@
# Sample user_timeline payload (two tweets from RobertsSpaceInd, extended
# tweet mode), retrieved with:
# curl -X GET -H "Authorization: Bearer <TOKEN>" "https://api.twitter.com/1.1/statuses/user_timeline.json?screen_name=RobertsSpaceInd&tweet_mode=extended" | python3 -m json.tool --sort-keys
simple_mock = [
{
"contributors": None,
"coordinates": None,
"created_at": "Fri Sep 18 20:32:22 +0000 2020",
"display_text_range": [0, 111],
"entities": {
"hashtags": [{"indices": [26, 41], "text": "SCShipShowdown"}],
"symbols": [],
"urls": [],
"user_mentions": [],
},
"favorite_count": 54,
"favorited": False,
"full_text": "It's a close match-up for #SCShipShowdown today! Which Aegis ship do you think will make it to the Semi-Finals?",
"geo": None,
"id": 1307054882210435074,
"id_str": "1307054882210435074",
"in_reply_to_screen_name": None,
"in_reply_to_status_id": None,
"in_reply_to_status_id_str": None,
"in_reply_to_user_id": None,
"in_reply_to_user_id_str": None,
"is_quote_status": False,
"lang": "en",
"place": None,
"retweet_count": 9,
"retweeted": False,
"source": '<a href="https://mobile.twitter.com" rel="nofollow">Twitter Web App</a>',
"truncated": False,
"user": {
"contributors_enabled": False,
"created_at": "Wed Sep 05 00:58:11 +0000 2012",
"default_profile": False,
"default_profile_image": False,
"description": "The official Twitter profile for #StarCitizen and Roberts Space Industries.",
"entities": {
"description": {"urls": []},
"url": {
"urls": [
{
"display_url": "robertsspaceindustries.com",
"expanded_url": "http://www.robertsspaceindustries.com",
"indices": [0, 23],
"url": "https://t.co/iqO6apof3y",
}
]
},
},
"favourites_count": 4831,
"follow_request_sent": None,
"followers_count": 106971,
"following": None,
"friends_count": 204,
"geo_enabled": False,
"has_extended_profile": False,
"id": 803542770,
"id_str": "803542770",
"is_translation_enabled": False,
"is_translator": False,
"lang": None,
"listed_count": 893,
"location": "Roberts Space Industries",
"name": "Star Citizen",
"notifications": None,
"profile_background_color": "131516",
"profile_background_image_url": "http://abs.twimg.com/images/themes/theme14/bg.gif",
"profile_background_image_url_https": "https://abs.twimg.com/images/themes/theme14/bg.gif",
"profile_background_tile": False,
"profile_banner_url": "https://pbs.twimg.com/profile_banners/803542770/1596651186",
"profile_image_url": "http://pbs.twimg.com/profile_images/963109950103814144/ysnj_Asy_normal.jpg",
"profile_image_url_https": "https://pbs.twimg.com/profile_images/963109950103814144/ysnj_Asy_normal.jpg",
"profile_link_color": "0A5485",
"profile_sidebar_border_color": "FFFFFF",
"profile_sidebar_fill_color": "EFEFEF",
"profile_text_color": "333333",
"profile_use_background_image": True,
"protected": False,
"screen_name": "RobertsSpaceInd",
"statuses_count": 6368,
"time_zone": None,
"translator_type": "none",
"url": "https://t.co/iqO6apof3y",
"utc_offset": None,
"verified": True,
},
},
{
"contributors": None,
"coordinates": None,
"created_at": "Fri Sep 18 18:50:11 +0000 2020",
"display_text_range": [0, 271],
"entities": {
"hashtags": [{"indices": [211, 218], "text": "Twitch"}],
"media": [
{
"display_url": "pic.twitter.com/Cey5JpR1i9",
"expanded_url": "https://twitter.com/RobertsSpaceInd/status/1307029168941461504/photo/1",
"id": 1307028141697765376,
"id_str": "1307028141697765376",
"indices": [272, 295],
"media_url": "http://pbs.twimg.com/media/EiN_K4FVkAAGBcr.jpg",
"media_url_https": "https://pbs.twimg.com/media/EiN_K4FVkAAGBcr.jpg",
"sizes": {
"large": {"h": 1090, "resize": "fit", "w": 1920},
"medium": {"h": 681, "resize": "fit", "w": 1200},
"small": {"h": 386, "resize": "fit", "w": 680},
"thumb": {"h": 150, "resize": "crop", "w": 150},
},
"type": "photo",
"url": "https://t.co/Cey5JpR1i9",
}
],
"symbols": [],
"urls": [
{
"display_url": "twitch.tv/starcitizen",
"expanded_url": "http://twitch.tv/starcitizen",
"indices": [248, 271],
"url": "https://t.co/2AdNovhpFW",
}
],
"user_mentions": [],
},
"extended_entities": {
"media": [
{
"display_url": "pic.twitter.com/Cey5JpR1i9",
"expanded_url": "https://twitter.com/RobertsSpaceInd/status/1307029168941461504/photo/1",
"id": 1307028141697765376,
"id_str": "1307028141697765376",
"indices": [272, 295],
"media_url": "http://pbs.twimg.com/media/EiN_K4FVkAAGBcr.jpg",
"media_url_https": "https://pbs.twimg.com/media/EiN_K4FVkAAGBcr.jpg",
"sizes": {
"large": {"h": 1090, "resize": "fit", "w": 1920},
"medium": {"h": 681, "resize": "fit", "w": 1200},
"small": {"h": 386, "resize": "fit", "w": 680},
"thumb": {"h": 150, "resize": "crop", "w": 150},
},
"type": "photo",
"url": "https://t.co/Cey5JpR1i9",
}
]
},
"favorite_count": 90,
"favorited": False,
"full_text": "We\u2019re welcoming members of our Builds, Publishes and Platform teams on Star Citizen Live to talk about the process involved in bringing everyone\u2019s work together and getting it out into your hands. Going live on #Twitch in 10 minutes. \ud83c\udfa5\ud83d\udd34 \n\nTune in: https://t.co/2AdNovhpFW https://t.co/Cey5JpR1i9",
"geo": None,
"id": 1307029168941461504,
"id_str": "1307029168941461504",
"in_reply_to_screen_name": None,
"in_reply_to_status_id": None,
"in_reply_to_status_id_str": None,
"in_reply_to_user_id": None,
"in_reply_to_user_id_str": None,
"is_quote_status": False,
"lang": "en",
"place": None,
"possibly_sensitive": False,
"retweet_count": 13,
"retweeted": False,
"source": '<a href="https://mobile.twitter.com" rel="nofollow">Twitter Web App</a>',
"truncated": False,
"user": {
"contributors_enabled": False,
"created_at": "Wed Sep 05 00:58:11 +0000 2012",
"default_profile": False,
"default_profile_image": False,
"description": "The official Twitter profile for #StarCitizen and Roberts Space Industries.",
"entities": {
"description": {"urls": []},
"url": {
"urls": [
{
"display_url": "robertsspaceindustries.com",
"expanded_url": "http://www.robertsspaceindustries.com",
"indices": [0, 23],
"url": "https://t.co/iqO6apof3y",
}
]
},
},
"favourites_count": 4831,
"follow_request_sent": None,
"followers_count": 106971,
"following": None,
"friends_count": 204,
"geo_enabled": False,
"has_extended_profile": False,
"id": 803542770,
"id_str": "803542770",
"is_translation_enabled": False,
"is_translator": False,
"lang": None,
"listed_count": 893,
"location": "Roberts Space Industries",
"name": "Star Citizen",
"notifications": None,
"profile_background_color": "131516",
"profile_background_image_url": "http://abs.twimg.com/images/themes/theme14/bg.gif",
"profile_background_image_url_https": "https://abs.twimg.com/images/themes/theme14/bg.gif",
"profile_background_tile": False,
"profile_banner_url": "https://pbs.twimg.com/profile_banners/803542770/1596651186",
"profile_image_url": "http://pbs.twimg.com/profile_images/963109950103814144/ysnj_Asy_normal.jpg",
"profile_image_url_https": "https://pbs.twimg.com/profile_images/963109950103814144/ysnj_Asy_normal.jpg",
"profile_link_color": "0A5485",
"profile_sidebar_border_color": "FFFFFF",
"profile_sidebar_fill_color": "EFEFEF",
"profile_text_color": "333333",
"profile_use_background_image": True,
"protected": False,
"screen_name": "RobertsSpaceInd",
"statuses_count": 6368,
"time_zone": None,
"translator_type": "none",
"url": "https://t.co/iqO6apof3y",
"utc_offset": None,
"verified": True,
},
},
]

View file

@ -0,0 +1,162 @@
from unittest.mock import Mock, patch
from uuid import uuid4
from django.test import TestCase
from django.utils.lorem_ipsum import words
from newsreader.accounts.tests.factories import UserFactory
from newsreader.news.collection.exceptions import (
StreamDeniedException,
StreamException,
StreamNotFoundException,
StreamParseException,
StreamTimeOutException,
StreamTooManyException,
)
from newsreader.news.collection.tests.factories import TwitterTimelineFactory
from newsreader.news.collection.twitter import TwitterClient
from .mocks import simple_mock
class TwitterClientTestCase(TestCase):
    """Tests for ``TwitterClient`` with ``TwitterStream.read`` patched out.

    ``read`` is replaced by a mock for every test, so each test decides what
    the stream returns or which stream exception it raises.
    """

    def setUp(self):
        patched_read = patch("newsreader.news.collection.twitter.TwitterStream.read")
        self.mocked_read = patched_read.start()

    def tearDown(self):
        patch.stopall()

    def _assert_failed_read(self, timeline, error):
        """Iterate the client for ``timeline`` and check the failure assertions.

        NOTE(review): the loop body first asserts that ``stream`` is ``None``
        and then dereferences ``stream.rule``; both can only hold if the client
        yields nothing when ``read`` raises, in which case the body never runs
        and only ``assert_called`` is effective. Confirm against the client
        implementation and assert on the saved rule instead if so.
        """
        with TwitterClient([timeline]) as client:
            for data, stream in client:
                with self.subTest(data=data, stream=stream):
                    self.assertIsNone(data)
                    self.assertIsNone(stream)

                    self.assertEqual(stream.rule.error, error)
                    self.assertEqual(stream.rule.succeeded, False)

        self.mocked_read.assert_called()

    def test_simple(self):
        """A successful read is passed through unchanged as ``(data, stream)``."""
        timeline = TwitterTimelineFactory()
        mock_stream = Mock(rule=timeline)

        self.mocked_read.return_value = (simple_mock, mock_stream)

        with TwitterClient([timeline]) as client:
            for data, stream in client:
                with self.subTest(data=data, stream=stream):
                    self.assertEqual(data, simple_mock)
                    self.assertEqual(stream, mock_stream)

        self.mocked_read.assert_called()

    def test_client_catches_stream_exception(self):
        timeline = TwitterTimelineFactory()

        self.mocked_read.side_effect = StreamException(message="Stream exception")

        self._assert_failed_read(timeline, "Stream exception")

    def test_client_catches_stream_not_found_exception(self):
        timeline = TwitterTimelineFactory.create()

        self.mocked_read.side_effect = StreamNotFoundException(
            message="Stream not found"
        )

        self._assert_failed_read(timeline, "Stream not found")

    def test_client_catches_stream_denied_exception(self):
        """A denied stream additionally clears the user's stored Twitter tokens."""
        user = UserFactory(
            twitter_oauth_token=str(uuid4()), twitter_oauth_token_secret=str(uuid4())
        )
        timeline = TwitterTimelineFactory(user=user)

        self.mocked_read.side_effect = StreamDeniedException(message="Token expired")

        self._assert_failed_read(timeline, "Token expired")

        user.refresh_from_db()
        timeline.refresh_from_db()

        self.assertIsNone(user.twitter_oauth_token)
        self.assertIsNone(user.twitter_oauth_token_secret)

    def test_client_catches_stream_timed_out_exception(self):
        timeline = TwitterTimelineFactory()

        self.mocked_read.side_effect = StreamTimeOutException(
            message="Stream timed out"
        )

        self._assert_failed_read(timeline, "Stream timed out")

    def test_client_catches_stream_too_many_exception(self):
        timeline = TwitterTimelineFactory()

        # The exception class itself is the side effect, so it is raised
        # without an explicit message.
        self.mocked_read.side_effect = StreamTooManyException

        self._assert_failed_read(timeline, "Too many requests")

    def test_client_catches_stream_parse_exception(self):
        timeline = TwitterTimelineFactory()

        self.mocked_read.side_effect = StreamParseException(
            message="Stream could not be parsed"
        )

        self._assert_failed_read(timeline, "Stream could not be parsed")

    def test_client_catches_long_exception_text(self):
        """An over-long exception message is expected to be stored as 1024 characters."""
        timeline = TwitterTimelineFactory()

        self.mocked_read.side_effect = StreamParseException(message=words(1000))

        # NOTE(review): same caveat as in ``_assert_failed_read`` -- these
        # assertions only run if the client yields on failure.
        with TwitterClient([timeline]) as client:
            for data, stream in client:
                self.assertIsNone(data)
                self.assertIsNone(stream)

                self.assertEqual(len(stream.rule.error), 1024)
                self.assertEqual(stream.rule.succeeded, False)

        self.mocked_read.assert_called()

View file

@ -0,0 +1,227 @@
# Sample user_timeline payload (two tweets from RobertsSpaceInd, extended
# tweet mode), retrieved with:
# curl -X GET -H "Authorization: Bearer <TOKEN>" "https://api.twitter.com/1.1/statuses/user_timeline.json?screen_name=RobertsSpaceInd&tweet_mode=extended" | python3 -m json.tool --sort-keys
simple_mock = [
{
"contributors": None,
"coordinates": None,
"created_at": "Fri Sep 18 20:32:22 +0000 2020",
"display_text_range": [0, 111],
"entities": {
"hashtags": [{"indices": [26, 41], "text": "SCShipShowdown"}],
"symbols": [],
"urls": [],
"user_mentions": [],
},
"favorite_count": 54,
"favorited": False,
"full_text": "It's a close match-up for #SCShipShowdown today! Which Aegis ship do you think will make it to the Semi-Finals?",
"geo": None,
"id": 1307054882210435074,
"id_str": "1307054882210435074",
"in_reply_to_screen_name": None,
"in_reply_to_status_id": None,
"in_reply_to_status_id_str": None,
"in_reply_to_user_id": None,
"in_reply_to_user_id_str": None,
"is_quote_status": False,
"lang": "en",
"place": None,
"retweet_count": 9,
"retweeted": False,
"source": '<a href="https://mobile.twitter.com" rel="nofollow">Twitter Web App</a>',
"truncated": False,
"user": {
"contributors_enabled": False,
"created_at": "Wed Sep 05 00:58:11 +0000 2012",
"default_profile": False,
"default_profile_image": False,
"description": "The official Twitter profile for #StarCitizen and Roberts Space Industries.",
"entities": {
"description": {"urls": []},
"url": {
"urls": [
{
"display_url": "robertsspaceindustries.com",
"expanded_url": "http://www.robertsspaceindustries.com",
"indices": [0, 23],
"url": "https://t.co/iqO6apof3y",
}
]
},
},
"favourites_count": 4831,
"follow_request_sent": None,
"followers_count": 106971,
"following": None,
"friends_count": 204,
"geo_enabled": False,
"has_extended_profile": False,
"id": 803542770,
"id_str": "803542770",
"is_translation_enabled": False,
"is_translator": False,
"lang": None,
"listed_count": 893,
"location": "Roberts Space Industries",
"name": "Star Citizen",
"notifications": None,
"profile_background_color": "131516",
"profile_background_image_url": "http://abs.twimg.com/images/themes/theme14/bg.gif",
"profile_background_image_url_https": "https://abs.twimg.com/images/themes/theme14/bg.gif",
"profile_background_tile": False,
"profile_banner_url": "https://pbs.twimg.com/profile_banners/803542770/1596651186",
"profile_image_url": "http://pbs.twimg.com/profile_images/963109950103814144/ysnj_Asy_normal.jpg",
"profile_image_url_https": "https://pbs.twimg.com/profile_images/963109950103814144/ysnj_Asy_normal.jpg",
"profile_link_color": "0A5485",
"profile_sidebar_border_color": "FFFFFF",
"profile_sidebar_fill_color": "EFEFEF",
"profile_text_color": "333333",
"profile_use_background_image": True,
"protected": False,
"screen_name": "RobertsSpaceInd",
"statuses_count": 6368,
"time_zone": None,
"translator_type": "none",
"url": "https://t.co/iqO6apof3y",
"utc_offset": None,
"verified": True,
},
},
{
"contributors": None,
"coordinates": None,
"created_at": "Fri Sep 18 18:50:11 +0000 2020",
"display_text_range": [0, 271],
"entities": {
"hashtags": [{"indices": [211, 218], "text": "Twitch"}],
"media": [
{
"display_url": "pic.twitter.com/Cey5JpR1i9",
"expanded_url": "https://twitter.com/RobertsSpaceInd/status/1307029168941461504/photo/1",
"id": 1307028141697765376,
"id_str": "1307028141697765376",
"indices": [272, 295],
"media_url": "http://pbs.twimg.com/media/EiN_K4FVkAAGBcr.jpg",
"media_url_https": "https://pbs.twimg.com/media/EiN_K4FVkAAGBcr.jpg",
"sizes": {
"large": {"h": 1090, "resize": "fit", "w": 1920},
"medium": {"h": 681, "resize": "fit", "w": 1200},
"small": {"h": 386, "resize": "fit", "w": 680},
"thumb": {"h": 150, "resize": "crop", "w": 150},
},
"type": "photo",
"url": "https://t.co/Cey5JpR1i9",
}
],
"symbols": [],
"urls": [
{
"display_url": "twitch.tv/starcitizen",
"expanded_url": "http://twitch.tv/starcitizen",
"indices": [248, 271],
"url": "https://t.co/2AdNovhpFW",
}
],
"user_mentions": [],
},
"extended_entities": {
"media": [
{
"display_url": "pic.twitter.com/Cey5JpR1i9",
"expanded_url": "https://twitter.com/RobertsSpaceInd/status/1307029168941461504/photo/1",
"id": 1307028141697765376,
"id_str": "1307028141697765376",
"indices": [272, 295],
"media_url": "http://pbs.twimg.com/media/EiN_K4FVkAAGBcr.jpg",
"media_url_https": "https://pbs.twimg.com/media/EiN_K4FVkAAGBcr.jpg",
"sizes": {
"large": {"h": 1090, "resize": "fit", "w": 1920},
"medium": {"h": 681, "resize": "fit", "w": 1200},
"small": {"h": 386, "resize": "fit", "w": 680},
"thumb": {"h": 150, "resize": "crop", "w": 150},
},
"type": "photo",
"url": "https://t.co/Cey5JpR1i9",
}
]
},
"favorite_count": 90,
"favorited": False,
"full_text": "We\u2019re welcoming members of our Builds, Publishes and Platform teams on Star Citizen Live to talk about the process involved in bringing everyone\u2019s work together and getting it out into your hands. Going live on #Twitch in 10 minutes. \ud83c\udfa5\ud83d\udd34 \n\nTune in: https://t.co/2AdNovhpFW https://t.co/Cey5JpR1i9",
"geo": None,
"id": 1307029168941461504,
"id_str": "1307029168941461504",
"in_reply_to_screen_name": None,
"in_reply_to_status_id": None,
"in_reply_to_status_id_str": None,
"in_reply_to_user_id": None,
"in_reply_to_user_id_str": None,
"is_quote_status": False,
"lang": "en",
"place": None,
"possibly_sensitive": False,
"retweet_count": 13,
"retweeted": False,
"source": '<a href="https://mobile.twitter.com" rel="nofollow">Twitter Web App</a>',
"truncated": False,
"user": {
"contributors_enabled": False,
"created_at": "Wed Sep 05 00:58:11 +0000 2012",
"default_profile": False,
"default_profile_image": False,
"description": "The official Twitter profile for #StarCitizen and Roberts Space Industries.",
"entities": {
"description": {"urls": []},
"url": {
"urls": [
{
"display_url": "robertsspaceindustries.com",
"expanded_url": "http://www.robertsspaceindustries.com",
"indices": [0, 23],
"url": "https://t.co/iqO6apof3y",
}
]
},
},
"favourites_count": 4831,
"follow_request_sent": None,
"followers_count": 106971,
"following": None,
"friends_count": 204,
"geo_enabled": False,
"has_extended_profile": False,
"id": 803542770,
"id_str": "803542770",
"is_translation_enabled": False,
"is_translator": False,
"lang": None,
"listed_count": 893,
"location": "Roberts Space Industries",
"name": "Star Citizen",
"notifications": None,
"profile_background_color": "131516",
"profile_background_image_url": "http://abs.twimg.com/images/themes/theme14/bg.gif",
"profile_background_image_url_https": "https://abs.twimg.com/images/themes/theme14/bg.gif",
"profile_background_tile": False,
"profile_banner_url": "https://pbs.twimg.com/profile_banners/803542770/1596651186",
"profile_image_url": "http://pbs.twimg.com/profile_images/963109950103814144/ysnj_Asy_normal.jpg",
"profile_image_url_https": "https://pbs.twimg.com/profile_images/963109950103814144/ysnj_Asy_normal.jpg",
"profile_link_color": "0A5485",
"profile_sidebar_border_color": "FFFFFF",
"profile_sidebar_fill_color": "EFEFEF",
"profile_text_color": "333333",
"profile_use_background_image": True,
"protected": False,
"screen_name": "RobertsSpaceInd",
"statuses_count": 6368,
"time_zone": None,
"translator_type": "none",
"url": "https://t.co/iqO6apof3y",
"utc_offset": None,
"verified": True,
},
},
]
# Timeline payload that contains no tweets.
empty_mock = []

View file

@ -0,0 +1,180 @@
from datetime import datetime
from unittest.mock import patch
from uuid import uuid4
from django.test import TestCase
from django.utils import timezone
import pytz
from freezegun import freeze_time
from ftfy import fix_text
from newsreader.news.collection.choices import RuleTypeChoices
from newsreader.news.collection.exceptions import (
StreamDeniedException,
StreamForbiddenException,
StreamNotFoundException,
StreamTimeOutException,
)
from newsreader.news.collection.tests.factories import TwitterTimelineFactory
from newsreader.news.collection.tests.twitter.collector.mocks import (
empty_mock,
simple_mock,
)
from newsreader.news.collection.twitter import TWITTER_URL, TwitterCollector
from newsreader.news.collection.utils import truncate_text
from newsreader.news.core.models import Post
@freeze_time("2020-09-26 14:40:00")
class TwitterCollectorTestCase(TestCase):
def setUp(self):
    """Patch ``fetch`` and ``TwitterStream.parse`` in the twitter module.

    Both mocks are kept on the test case so individual tests can configure
    them; they are started here and stopped again in ``tearDown``.
    """
    self.mocked_fetch = patch("newsreader.news.collection.twitter.fetch").start()
    self.mocked_parse = patch(
        "newsreader.news.collection.twitter.TwitterStream.parse"
    ).start()
def tearDown(self):
    """Stop every patch that was started with ``start()`` in ``setUp``."""
    patch.stopall()
def test_simple_batch(self):
self.mocked_parse.return_value = simple_mock
timeline = TwitterTimelineFactory(
user__twitter_oauth_token=str(uuid4()),
user__twitter_oauth_token_secret=str(uuid4()),
screen_name="RobertsSpaceInd",
enabled=True,
)
collector = TwitterCollector()
collector.collect(rules=[timeline])
self.assertCountEqual(
Post.objects.values_list("remote_identifier", flat=True),
("1307054882210435074", "1307029168941461504"),
)
self.assertEquals(timeline.succeeded, True)
self.assertEquals(timeline.last_run, timezone.now())
self.assertIsNone(timeline.error)
post = Post.objects.get(
remote_identifier="1307054882210435074",
rule__type=RuleTypeChoices.twitter_timeline,
)
self.assertEquals(
post.publication_date, pytz.utc.localize(datetime(2020, 9, 18, 20, 32, 22))
)
title = truncate_text(
Post,
"title",
"It's a close match-up for #SCShipShowdown today! Which Aegis ship"
" do you think will make it to the Semi-Finals?",
)
self.assertEquals(post.author, "RobertsSpaceInd")
self.assertEquals(post.title, title)
self.assertEquals(
post.url, f"{TWITTER_URL}/RobertsSpaceInd/status/1307054882210435074"
)
post = Post.objects.get(
remote_identifier="1307029168941461504",
rule__type=RuleTypeChoices.twitter_timeline,
)
self.assertEquals(
post.publication_date, pytz.utc.localize(datetime(2020, 9, 18, 18, 50, 11))
)
body = fix_text(
"We\u2019re welcoming members of our Builds, Publishes and Platform"
" teams on Star Citizen Live to talk about the process involved in"
" bringing everyone\u2019s work together and getting it out into your"
" hands. Going live on #Twitch in 10 minutes."
" \ud83c\udfa5\ud83d\udd34 \n\nTune in:"
" https://t.co/2AdNovhpFW https://t.co/Cey5JpR1i9"
)
title = truncate_text(Post, "title", body)
self.assertEquals(post.author, "RobertsSpaceInd")
self.assertEquals(post.title, title)
self.assertEquals(
post.url, f"{TWITTER_URL}/RobertsSpaceInd/status/1307029168941461504"
)
def test_empty_batch(self):
self.mocked_parse.return_value = empty_mock
timeline = TwitterTimelineFactory()
collector = TwitterCollector()
collector.collect(rules=[timeline])
self.assertEquals(Post.objects.count(), 0)
self.assertEquals(timeline.succeeded, True)
self.assertEquals(timeline.last_run, timezone.now())
self.assertIsNone(timeline.error)
def test_not_found(self):
self.mocked_fetch.side_effect = StreamNotFoundException
timeline = TwitterTimelineFactory()
collector = TwitterCollector()
collector.collect(rules=[timeline])
self.assertEquals(Post.objects.count(), 0)
self.assertEquals(timeline.succeeded, False)
self.assertEquals(timeline.error, "Stream not found")
def test_denied(self):
self.mocked_fetch.side_effect = StreamDeniedException
timeline = TwitterTimelineFactory(
user__twitter_oauth_token=str(uuid4()),
user__twitter_oauth_token_secret=str(uuid4()),
)
collector = TwitterCollector()
collector.collect(rules=[timeline])
self.assertEquals(Post.objects.count(), 0)
self.assertEquals(timeline.succeeded, False)
self.assertEquals(timeline.error, "Stream does not have sufficient permissions")
user = timeline.user
self.assertIsNone(user.twitter_oauth_token)
self.assertIsNone(user.twitter_oauth_token_secret)
def test_forbidden(self):
self.mocked_fetch.side_effect = StreamForbiddenException
timeline = TwitterTimelineFactory()
collector = TwitterCollector()
collector.collect(rules=[timeline])
self.assertEquals(Post.objects.count(), 0)
self.assertEquals(timeline.succeeded, False)
self.assertEquals(timeline.error, "Stream forbidden")
def test_timed_out(self):
self.mocked_fetch.side_effect = StreamTimeOutException
timeline = TwitterTimelineFactory()
collector = TwitterCollector()
collector.collect(rules=[timeline])
self.assertEquals(Post.objects.count(), 0)
self.assertEquals(timeline.succeeded, False)
self.assertEquals(timeline.error, "Stream timed out")

View file

@ -0,0 +1,225 @@
# retrieved with:
# curl -X GET -H "Authorization: Bearer <TOKEN>" "https://api.twitter.com/1.1/statuses/user_timeline.json?screen_name=RobertsSpaceInd&tweet_mode=extended" | python3 -m json.tool --sort-keys
simple_mock = [
{
"contributors": None,
"coordinates": None,
"created_at": "Fri Sep 18 20:32:22 +0000 2020",
"display_text_range": [0, 111],
"entities": {
"hashtags": [{"indices": [26, 41], "text": "SCShipShowdown"}],
"symbols": [],
"urls": [],
"user_mentions": [],
},
"favorite_count": 54,
"favorited": False,
"full_text": "It's a close match-up for #SCShipShowdown today! Which Aegis ship do you think will make it to the Semi-Finals?",
"geo": None,
"id": 1307054882210435074,
"id_str": "1307054882210435074",
"in_reply_to_screen_name": None,
"in_reply_to_status_id": None,
"in_reply_to_status_id_str": None,
"in_reply_to_user_id": None,
"in_reply_to_user_id_str": None,
"is_quote_status": False,
"lang": "en",
"place": None,
"retweet_count": 9,
"retweeted": False,
"source": '<a href="https://mobile.twitter.com" rel="nofollow">Twitter Web App</a>',
"truncated": False,
"user": {
"contributors_enabled": False,
"created_at": "Wed Sep 05 00:58:11 +0000 2012",
"default_profile": False,
"default_profile_image": False,
"description": "The official Twitter profile for #StarCitizen and Roberts Space Industries.",
"entities": {
"description": {"urls": []},
"url": {
"urls": [
{
"display_url": "robertsspaceindustries.com",
"expanded_url": "http://www.robertsspaceindustries.com",
"indices": [0, 23],
"url": "https://t.co/iqO6apof3y",
}
]
},
},
"favourites_count": 4831,
"follow_request_sent": None,
"followers_count": 106971,
"following": None,
"friends_count": 204,
"geo_enabled": False,
"has_extended_profile": False,
"id": 803542770,
"id_str": "803542770",
"is_translation_enabled": False,
"is_translator": False,
"lang": None,
"listed_count": 893,
"location": "Roberts Space Industries",
"name": "Star Citizen",
"notifications": None,
"profile_background_color": "131516",
"profile_background_image_url": "http://abs.twimg.com/images/themes/theme14/bg.gif",
"profile_background_image_url_https": "https://abs.twimg.com/images/themes/theme14/bg.gif",
"profile_background_tile": False,
"profile_banner_url": "https://pbs.twimg.com/profile_banners/803542770/1596651186",
"profile_image_url": "http://pbs.twimg.com/profile_images/963109950103814144/ysnj_Asy_normal.jpg",
"profile_image_url_https": "https://pbs.twimg.com/profile_images/963109950103814144/ysnj_Asy_normal.jpg",
"profile_link_color": "0A5485",
"profile_sidebar_border_color": "FFFFFF",
"profile_sidebar_fill_color": "EFEFEF",
"profile_text_color": "333333",
"profile_use_background_image": True,
"protected": False,
"screen_name": "RobertsSpaceInd",
"statuses_count": 6368,
"time_zone": None,
"translator_type": "none",
"url": "https://t.co/iqO6apof3y",
"utc_offset": None,
"verified": True,
},
},
{
"contributors": None,
"coordinates": None,
"created_at": "Fri Sep 18 18:50:11 +0000 2020",
"display_text_range": [0, 271],
"entities": {
"hashtags": [{"indices": [211, 218], "text": "Twitch"}],
"media": [
{
"display_url": "pic.twitter.com/Cey5JpR1i9",
"expanded_url": "https://twitter.com/RobertsSpaceInd/status/1307029168941461504/photo/1",
"id": 1307028141697765376,
"id_str": "1307028141697765376",
"indices": [272, 295],
"media_url": "http://pbs.twimg.com/media/EiN_K4FVkAAGBcr.jpg",
"media_url_https": "https://pbs.twimg.com/media/EiN_K4FVkAAGBcr.jpg",
"sizes": {
"large": {"h": 1090, "resize": "fit", "w": 1920},
"medium": {"h": 681, "resize": "fit", "w": 1200},
"small": {"h": 386, "resize": "fit", "w": 680},
"thumb": {"h": 150, "resize": "crop", "w": 150},
},
"type": "photo",
"url": "https://t.co/Cey5JpR1i9",
}
],
"symbols": [],
"urls": [
{
"display_url": "twitch.tv/starcitizen",
"expanded_url": "http://twitch.tv/starcitizen",
"indices": [248, 271],
"url": "https://t.co/2AdNovhpFW",
}
],
"user_mentions": [],
},
"extended_entities": {
"media": [
{
"display_url": "pic.twitter.com/Cey5JpR1i9",
"expanded_url": "https://twitter.com/RobertsSpaceInd/status/1307029168941461504/photo/1",
"id": 1307028141697765376,
"id_str": "1307028141697765376",
"indices": [272, 295],
"media_url": "http://pbs.twimg.com/media/EiN_K4FVkAAGBcr.jpg",
"media_url_https": "https://pbs.twimg.com/media/EiN_K4FVkAAGBcr.jpg",
"sizes": {
"large": {"h": 1090, "resize": "fit", "w": 1920},
"medium": {"h": 681, "resize": "fit", "w": 1200},
"small": {"h": 386, "resize": "fit", "w": 680},
"thumb": {"h": 150, "resize": "crop", "w": 150},
},
"type": "photo",
"url": "https://t.co/Cey5JpR1i9",
}
]
},
"favorite_count": 90,
"favorited": False,
"full_text": "We\u2019re welcoming members of our Builds, Publishes and Platform teams on Star Citizen Live to talk about the process involved in bringing everyone\u2019s work together and getting it out into your hands. Going live on #Twitch in 10 minutes. \ud83c\udfa5\ud83d\udd34 \n\nTune in: https://t.co/2AdNovhpFW https://t.co/Cey5JpR1i9",
"geo": None,
"id": 1307029168941461504,
"id_str": "1307029168941461504",
"in_reply_to_screen_name": None,
"in_reply_to_status_id": None,
"in_reply_to_status_id_str": None,
"in_reply_to_user_id": None,
"in_reply_to_user_id_str": None,
"is_quote_status": False,
"lang": "en",
"place": None,
"possibly_sensitive": False,
"retweet_count": 13,
"retweeted": False,
"source": '<a href="https://mobile.twitter.com" rel="nofollow">Twitter Web App</a>',
"truncated": False,
"user": {
"contributors_enabled": False,
"created_at": "Wed Sep 05 00:58:11 +0000 2012",
"default_profile": False,
"default_profile_image": False,
"description": "The official Twitter profile for #StarCitizen and Roberts Space Industries.",
"entities": {
"description": {"urls": []},
"url": {
"urls": [
{
"display_url": "robertsspaceindustries.com",
"expanded_url": "http://www.robertsspaceindustries.com",
"indices": [0, 23],
"url": "https://t.co/iqO6apof3y",
}
]
},
},
"favourites_count": 4831,
"follow_request_sent": None,
"followers_count": 106971,
"following": None,
"friends_count": 204,
"geo_enabled": False,
"has_extended_profile": False,
"id": 803542770,
"id_str": "803542770",
"is_translation_enabled": False,
"is_translator": False,
"lang": None,
"listed_count": 893,
"location": "Roberts Space Industries",
"name": "Star Citizen",
"notifications": None,
"profile_background_color": "131516",
"profile_background_image_url": "http://abs.twimg.com/images/themes/theme14/bg.gif",
"profile_background_image_url_https": "https://abs.twimg.com/images/themes/theme14/bg.gif",
"profile_background_tile": False,
"profile_banner_url": "https://pbs.twimg.com/profile_banners/803542770/1596651186",
"profile_image_url": "http://pbs.twimg.com/profile_images/963109950103814144/ysnj_Asy_normal.jpg",
"profile_image_url_https": "https://pbs.twimg.com/profile_images/963109950103814144/ysnj_Asy_normal.jpg",
"profile_link_color": "0A5485",
"profile_sidebar_border_color": "FFFFFF",
"profile_sidebar_fill_color": "EFEFEF",
"profile_text_color": "333333",
"profile_use_background_image": True,
"protected": False,
"screen_name": "RobertsSpaceInd",
"statuses_count": 6368,
"time_zone": None,
"translator_type": "none",
"url": "https://t.co/iqO6apof3y",
"utc_offset": None,
"verified": True,
},
},
]

View file

@ -0,0 +1,107 @@
from json import JSONDecodeError
from unittest.mock import patch
from django.test import TestCase
from newsreader.news.collection.exceptions import (
StreamDeniedException,
StreamException,
StreamForbiddenException,
StreamNotFoundException,
StreamParseException,
StreamTimeOutException,
StreamTooManyException,
)
from newsreader.news.collection.tests.factories import TwitterTimelineFactory
from newsreader.news.collection.tests.twitter.stream.mocks import simple_mock
from newsreader.news.collection.twitter import TwitterStream
class TwitterStreamTestCase(TestCase):
    """Tests for ``TwitterStream.read`` with the module-level ``fetch`` patched."""

    def setUp(self):
        self.patched_fetch = patch("newsreader.news.collection.twitter.fetch")
        self.mocked_fetch = self.patched_fetch.start()

    def tearDown(self):
        patch.stopall()

    def test_simple_stream(self):
        self.mocked_fetch.return_value.json.return_value = simple_mock

        timeline = TwitterTimelineFactory()
        stream = TwitterStream(timeline)

        # Keep the returned stream under its own name: rebinding ``stream``
        # here would turn the identity check below into a tautology.
        data, returned_stream = stream.read()

        self.assertEqual(data, simple_mock)
        self.assertIs(returned_stream, stream)
        self.mocked_fetch.assert_called()

    def test_stream_raises_exception(self):
        self.mocked_fetch.side_effect = StreamException

        timeline = TwitterTimelineFactory()
        stream = TwitterStream(timeline)

        with self.assertRaises(StreamException):
            stream.read()

        self.mocked_fetch.assert_called()

    def test_stream_raises_denied_exception(self):
        self.mocked_fetch.side_effect = StreamDeniedException

        timeline = TwitterTimelineFactory()
        stream = TwitterStream(timeline)

        with self.assertRaises(StreamDeniedException):
            stream.read()

        self.mocked_fetch.assert_called()

    def test_stream_raises_not_found_exception(self):
        self.mocked_fetch.side_effect = StreamNotFoundException

        timeline = TwitterTimelineFactory()
        stream = TwitterStream(timeline)

        with self.assertRaises(StreamNotFoundException):
            stream.read()

        self.mocked_fetch.assert_called()

    def test_stream_raises_time_out_exception(self):
        self.mocked_fetch.side_effect = StreamTimeOutException

        timeline = TwitterTimelineFactory()
        stream = TwitterStream(timeline)

        with self.assertRaises(StreamTimeOutException):
            stream.read()

        self.mocked_fetch.assert_called()

    def test_stream_raises_forbidden_exception(self):
        self.mocked_fetch.side_effect = StreamForbiddenException

        timeline = TwitterTimelineFactory()
        stream = TwitterStream(timeline)

        with self.assertRaises(StreamForbiddenException):
            stream.read()

        self.mocked_fetch.assert_called()

    def test_stream_raises_parse_exception(self):
        # A response whose body is not JSON must surface as StreamParseException.
        self.mocked_fetch.return_value.json.side_effect = JSONDecodeError(
            "No json found", "{}", 5
        )

        timeline = TwitterTimelineFactory()
        stream = TwitterStream(timeline)

        with self.assertRaises(StreamParseException):
            stream.read()

        self.mocked_fetch.assert_called()

View file

@ -0,0 +1,63 @@
from json import JSONDecodeError
from unittest.mock import patch
from django.test import TestCase
from newsreader.accounts.tests.factories import UserFactory
from newsreader.news.collection.exceptions import StreamException
from newsreader.news.collection.twitter import TwitterTimeLineScheduler
class TwitterTimeLineSchedulerTestCase(TestCase):
    """Tests for ``TwitterTimeLineScheduler.get_current_ratelimit``."""

    def setUp(self):
        patched_fetch = patch("newsreader.news.collection.twitter.fetch")
        self.mocked_fetch = patched_fetch.start()
        # Undo the patch after every test so the mocked ``fetch`` cannot leak
        # into other test cases.
        self.addCleanup(patched_fetch.stop)

    def test_simple(self):
        user = UserFactory(twitter_oauth_token="foo", twitter_oauth_token_secret="bar")

        self.mocked_fetch.return_value.json.return_value = {
            "rate_limit_context": {"application": "dummykey"},
            "resources": {
                "statuses": {
                    "/statuses/user_timeline": {
                        "limit": 1500,
                        "remaining": 1500,
                        "reset": 1601141386,
                    }
                }
            },
        }

        scheduler = TwitterTimeLineScheduler(user)

        self.assertEqual(scheduler.get_current_ratelimit(), 1500)

    def test_stream_exception(self):
        user = UserFactory(twitter_oauth_token=None, twitter_oauth_token_secret=None)

        self.mocked_fetch.side_effect = StreamException

        scheduler = TwitterTimeLineScheduler(user)

        self.assertIsNone(scheduler.get_current_ratelimit())

    def test_json_decode_error(self):
        user = UserFactory(twitter_oauth_token="foo", twitter_oauth_token_secret="bar")

        self.mocked_fetch.return_value.json.side_effect = JSONDecodeError(
            "foo", "bar", 10
        )

        scheduler = TwitterTimeLineScheduler(user)

        self.assertIsNone(scheduler.get_current_ratelimit())

    def test_unexpected_contents(self):
        # A payload without the expected nesting yields no rate limit.
        user = UserFactory(twitter_oauth_token="foo", twitter_oauth_token_secret="bar")

        self.mocked_fetch.return_value.json.return_value = {"foo": "bar"}

        scheduler = TwitterTimeLineScheduler(user)

        self.assertIsNone(scheduler.get_current_ratelimit())

View file

@ -1,4 +1,4 @@
from unittest.mock import MagicMock, patch from unittest.mock import Mock, patch
from django.test import TestCase from django.test import TestCase
@ -19,7 +19,7 @@ from newsreader.news.collection.utils import fetch, post
class HelperFunctionTestCase: class HelperFunctionTestCase:
def test_simple(self): def test_simple(self):
self.mocked_method.return_value = MagicMock(status_code=200, content="content") self.mocked_method.return_value = Mock(status_code=200, content="content")
url = "https://www.bbc.co.uk/news" url = "https://www.bbc.co.uk/news"
response = self.method(url) response = self.method(url)
@ -27,7 +27,7 @@ class HelperFunctionTestCase:
self.assertEquals(response.content, "content") self.assertEquals(response.content, "content")
def test_raises_not_found(self): def test_raises_not_found(self):
self.mocked_method.return_value = MagicMock(status_code=404) self.mocked_method.return_value = Mock(status_code=404)
url = "https://www.bbc.co.uk/news" url = "https://www.bbc.co.uk/news"
@ -35,7 +35,7 @@ class HelperFunctionTestCase:
self.method(url) self.method(url)
def test_raises_denied(self): def test_raises_denied(self):
self.mocked_method.return_value = MagicMock(status_code=401) self.mocked_method.return_value = Mock(status_code=401)
url = "https://www.bbc.co.uk/news" url = "https://www.bbc.co.uk/news"
@ -43,7 +43,7 @@ class HelperFunctionTestCase:
self.method(url) self.method(url)
def test_raises_forbidden(self): def test_raises_forbidden(self):
self.mocked_method.return_value = MagicMock(status_code=403) self.mocked_method.return_value = Mock(status_code=403)
url = "https://www.bbc.co.uk/news" url = "https://www.bbc.co.uk/news"
@ -51,7 +51,7 @@ class HelperFunctionTestCase:
self.method(url) self.method(url)
def test_raises_timed_out(self): def test_raises_timed_out(self):
self.mocked_method.return_value = MagicMock(status_code=408) self.mocked_method.return_value = Mock(status_code=408)
url = "https://www.bbc.co.uk/news" url = "https://www.bbc.co.uk/news"
@ -99,7 +99,7 @@ class HelperFunctionTestCase:
self.method(url) self.method(url)
def test_raises_stream_error_on_too_many_requests(self): def test_raises_stream_error_on_too_many_requests(self):
self.mocked_method.return_value = MagicMock(status_code=429) self.mocked_method.return_value = Mock(status_code=429)
url = "https://www.bbc.co.uk/news" url = "https://www.bbc.co.uk/news"

View file

@ -49,7 +49,7 @@ class CollectionRuleViewTestCase:
timezone=other_rule.timezone, timezone=other_rule.timezone,
) )
other_url = reverse("news:collection:rule-update", args=[other_rule.pk]) other_url = reverse("news:collection:feed-update", args=[other_rule.pk])
response = self.client.post(other_url, self.form_data) response = self.client.post(other_url, self.form_data)
self.assertEquals(response.status_code, 404) self.assertEquals(response.status_code, 404)

View file

@ -3,6 +3,8 @@ from django.urls import reverse
import pytz import pytz
from django_celery_beat.models import PeriodicTask
from newsreader.news.collection.choices import RuleTypeChoices from newsreader.news.collection.choices import RuleTypeChoices
from newsreader.news.collection.models import CollectionRule from newsreader.news.collection.models import CollectionRule
from newsreader.news.collection.tests.factories import FeedFactory from newsreader.news.collection.tests.factories import FeedFactory
@ -10,11 +12,11 @@ from newsreader.news.collection.tests.views.base import CollectionRuleViewTestCa
from newsreader.news.core.tests.factories import CategoryFactory from newsreader.news.core.tests.factories import CategoryFactory
class CollectionRuleCreateViewTestCase(CollectionRuleViewTestCase, TestCase): class FeedCreateViewTestCase(CollectionRuleViewTestCase, TestCase):
def setUp(self): def setUp(self):
super().setUp() super().setUp()
self.url = reverse("news:collection:rule-create") self.url = reverse("news:collection:feed-create")
self.form_data.update( self.form_data.update(
name="new rule", name="new rule",
@ -37,15 +39,21 @@ class CollectionRuleCreateViewTestCase(CollectionRuleViewTestCase, TestCase):
self.assertEquals(rule.category.pk, self.category.pk) self.assertEquals(rule.category.pk, self.category.pk)
self.assertEquals(rule.user.pk, self.user.pk) self.assertEquals(rule.user.pk, self.user.pk)
self.assertTrue(
PeriodicTask.objects.get(
name=f"{self.user.email}-feed", task="FeedTask", enabled=True
)
)
class CollectionRuleUpdateViewTestCase(CollectionRuleViewTestCase, TestCase):
class FeedUpdateViewTestCase(CollectionRuleViewTestCase, TestCase):
def setUp(self): def setUp(self):
super().setUp() super().setUp()
self.rule = FeedFactory( self.rule = FeedFactory(
name="collection rule", user=self.user, category=self.category name="collection rule", user=self.user, category=self.category
) )
self.url = reverse("news:collection:rule-update", kwargs={"pk": self.rule.pk}) self.url = reverse("news:collection:feed-update", kwargs={"pk": self.rule.pk})
self.form_data.update( self.form_data.update(
name=self.rule.name, name=self.rule.name,
@ -94,7 +102,7 @@ class CollectionRuleUpdateViewTestCase(CollectionRuleViewTestCase, TestCase):
category=self.category, category=self.category,
type=RuleTypeChoices.subreddit, type=RuleTypeChoices.subreddit,
) )
url = reverse("news:collection:rule-update", kwargs={"pk": rule.pk}) url = reverse("news:collection:feed-update", kwargs={"pk": rule.pk})
response = self.client.get(url) response = self.client.get(url)

View file

@ -84,7 +84,7 @@ class OPMLImportTestCase(TestCase):
rules = CollectionRule.objects.all() rules = CollectionRule.objects.all()
self.assertEquals(len(rules), 0) self.assertEquals(len(rules), 0)
self.assertFormError(response, "form", "file", _("No (new) rules found")) self.assertFormError(response, "form", "file", _("No (new) feeds found"))
def test_invalid_feeds(self): def test_invalid_feeds(self):
file_path = self._get_file_path("invalid-url-feeds.opml") file_path = self._get_file_path("invalid-url-feeds.opml")
@ -99,7 +99,7 @@ class OPMLImportTestCase(TestCase):
rules = CollectionRule.objects.all() rules = CollectionRule.objects.all()
self.assertEquals(len(rules), 0) self.assertEquals(len(rules), 0)
self.assertFormError(response, "form", "file", _("No (new) rules found")) self.assertFormError(response, "form", "file", _("No (new) feeds found"))
def test_invalid_file(self): def test_invalid_file(self):
file_path = self._get_file_path("test.png") file_path = self._get_file_path("test.png")

View file

@ -0,0 +1,129 @@
from django.test import TestCase
from django.urls import reverse
import pytz
from django_celery_beat.models import PeriodicTask
from newsreader.news.collection.choices import RuleTypeChoices
from newsreader.news.collection.models import CollectionRule
from newsreader.news.collection.tests.factories import TwitterTimelineFactory
from newsreader.news.collection.tests.views.base import CollectionRuleViewTestCase
from newsreader.news.collection.twitter import TWITTER_API_URL
from newsreader.news.core.tests.factories import CategoryFactory
class TwitterTimelineCreateViewTestCase(CollectionRuleViewTestCase, TestCase):
    """Tests for the ``twitter-timeline-create`` view."""

    def setUp(self):
        super().setUp()

        self.form_data = {
            "name": "new rule",
            "screen_name": "RobertsSpaceInd",
            "category": str(self.category.pk),
        }

        self.url = reverse("news:collection:twitter-timeline-create")

    def test_creation(self):
        response = self.client.post(self.url, self.form_data)

        self.assertEqual(response.status_code, 302)

        rule = CollectionRule.objects.get(name="new rule")

        self.assertEqual(rule.type, RuleTypeChoices.twitter_timeline)
        self.assertEqual(
            rule.url,
            f"{TWITTER_API_URL}/statuses/user_timeline.json?screen_name=RobertsSpaceInd&tweet_mode=extended",
        )
        self.assertEqual(rule.timezone, str(pytz.utc))
        self.assertIsNone(rule.favicon)
        self.assertEqual(rule.category.pk, self.category.pk)
        self.assertEqual(rule.user.pk, self.user.pk)

        # ``get`` raises DoesNotExist when no matching periodic task was
        # created, so this doubles as an existence check.
        self.assertTrue(
            PeriodicTask.objects.get(
                name=f"{self.user.email}-timeline",
                task="TwitterTimelineTask",
                enabled=True,
            )
        )
class TwitterTimelineUpdateViewTestCase(CollectionRuleViewTestCase, TestCase):
    """Tests for the ``twitter-timeline-update`` view."""

    def setUp(self):
        super().setUp()

        self.rule = TwitterTimelineFactory(
            name="Star citizen",
            screen_name="RobertsSpaceInd",
            user=self.user,
            category=self.category,
            type=RuleTypeChoices.twitter_timeline,
        )
        self.url = reverse(
            "news:collection:twitter-timeline-update", kwargs={"pk": self.rule.pk}
        )

        self.form_data = {
            "name": self.rule.name,
            "screen_name": self.rule.screen_name,
            "category": str(self.category.pk),
            "timezone": pytz.utc,
        }

    def test_name_change(self):
        self.form_data.update(name="Star citizen Twitter")

        response = self.client.post(self.url, self.form_data)
        self.assertEqual(response.status_code, 302)

        self.rule.refresh_from_db()

        self.assertEqual(self.rule.name, "Star citizen Twitter")

    def test_category_change(self):
        new_category = CategoryFactory(user=self.user)

        self.form_data.update(category=new_category.pk)

        response = self.client.post(self.url, self.form_data)
        self.assertEqual(response.status_code, 302)

        self.rule.refresh_from_db()

        self.assertEqual(self.rule.category.pk, new_category.pk)

    def test_twitter_timelines_only(self):
        # A rule of type ``feed`` must not be reachable through the
        # twitter-timeline update view, even if it points at twitter.com.
        rule = TwitterTimelineFactory(
            name="Fake twitter",
            user=self.user,
            category=self.category,
            type=RuleTypeChoices.feed,
            url="https://twitter.com/RobertsSpaceInd",
        )
        url = reverse("news:collection:twitter-timeline-update", kwargs={"pk": rule.pk})

        response = self.client.get(url)

        self.assertEqual(response.status_code, 404)

    def test_screen_name_change(self):
        self.form_data.update(screen_name="CyberpunkGame")

        response = self.client.post(self.url, self.form_data)

        self.assertEqual(response.status_code, 302)

        self.rule.refresh_from_db()

        self.assertEqual(self.rule.type, RuleTypeChoices.twitter_timeline)
        self.assertEqual(
            self.rule.url,
            f"{TWITTER_API_URL}/statuses/user_timeline.json?screen_name=CyberpunkGame&tweet_mode=extended",
        )
        self.assertEqual(self.rule.timezone, str(pytz.utc))
        self.assertIsNone(self.rule.favicon)
        self.assertEqual(self.rule.category.pk, self.category.pk)
        self.assertEqual(self.rule.user.pk, self.user.pk)

View file

@ -0,0 +1,281 @@
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
from json import JSONDecodeError
from django.conf import settings
from django.utils import timezone
from django.utils.html import format_html, urlize
import pytz
from ftfy import fix_text
from requests_oauthlib import OAuth1 as OAuth
from newsreader.news.collection.base import (
PostBuilder,
PostClient,
PostCollector,
PostStream,
Scheduler,
)
from newsreader.news.collection.choices import RuleTypeChoices, TwitterPostTypeChoices
from newsreader.news.collection.exceptions import (
StreamDeniedException,
StreamException,
StreamNotFoundException,
StreamParseException,
StreamTimeOutException,
StreamTooManyException,
)
from newsreader.news.collection.utils import fetch, truncate_text
from newsreader.news.core.models import Post
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Base URL of the Twitter website; used below to build links to single tweets.
TWITTER_URL = "https://twitter.com"
# Base URL of the versioned (1.1) REST API; endpoints are appended to it.
TWITTER_API_URL = "https://api.twitter.com/1.1"

# OAuth endpoints. They are not referenced in this part of the module;
# presumably consumed by the account/OAuth views elsewhere - TODO confirm.
TWITTER_REQUEST_TOKEN_URL = "https://api.twitter.com/oauth/request_token"
TWITTER_AUTH_URL = "https://api.twitter.com/oauth/authorize"
TWITTER_ACCESS_TOKEN_URL = "https://api.twitter.com/oauth/access_token"
TWITTER_REVOKE_URL = f"{TWITTER_API_URL}/oauth/invalidate_token"
class TwitterBuilder(PostBuilder):
    """Turn a Twitter timeline payload into unsaved ``Post`` instances."""

    rule_type = RuleTypeChoices.twitter_timeline

    def build(self):
        """Populate ``self.instances`` with one ``Post`` per new tweet.

        Tweets whose ``id_str`` is already in ``self.existing_posts`` are
        skipped. The post body is the linkified tweet text, followed by any
        media markup and the text of a retweeted or quoted tweet, and is
        sanitized before being stored on the ``Post``.
        """
        results = {}
        rule = self.stream.rule

        for post in self.payload:
            remote_identifier = post["id_str"]

            if remote_identifier in self.existing_posts:
                continue

            url = f"{TWITTER_URL}/{rule.screen_name}/status/{remote_identifier}"
            body = urlize(post["full_text"], nofollow=True)
            title = truncate_text(
                Post, "title", self.sanitize_fragment(post["full_text"])
            )

            # The format hard-codes the "+0000" offset, so the parsed value is
            # naive and is explicitly marked as UTC afterwards.
            publication_date = pytz.utc.localize(
                datetime.strptime(post["created_at"], "%a %b %d %H:%M:%S +0000 %Y")
            )

            if "extended_entities" in post:
                try:
                    media_entities = self.get_media_entities(post)
                    body += media_entities
                except KeyError:
                    # Best effort: a malformed media entity must not drop the
                    # whole tweet, so log it and keep the text-only body.
                    logger.exception("Failed parsing media_entities for %s", url)

            if "retweeted_status" in post:
                original_post = post["retweeted_status"]
                original_tweet = urlize(original_post["full_text"], nofollow=True)
                body = f"{body} <br><div>Original tweet: {original_tweet}</div>"

            if "quoted_status" in post:
                original_post = post["quoted_status"]
                original_tweet = urlize(original_post["full_text"], nofollow=True)
                body = f"{body} <br><div>Quoted tweet: {original_tweet}</div>"

            body = self.sanitize_fragment(body)

            data = {
                "remote_identifier": remote_identifier,
                "title": fix_text(title),
                "body": fix_text(body),
                "author": rule.screen_name,
                "publication_date": publication_date,
                "url": url,
                "rule": rule,
            }

            results[remote_identifier] = Post(**data)

        self.instances = results.values()

    def get_media_entities(self, post):
        """Return HTML markup for the media attached to ``post``.

        Photos become ``<img>`` fragments; videos and animated GIFs become a
        ``<video>`` fragment pointing at the variant with the highest
        bitrate. Media of any other type, and videos without variants, are
        skipped.

        Raises:
            KeyError: when an expected key is missing from the payload.
        """
        fragments = []

        for media_entity in post["extended_entities"]["media"]:
            media_type = media_entity["type"]
            media_url = media_entity["media_url_https"]
            title = media_entity["id_str"]

            if media_type == TwitterPostTypeChoices.photo:
                fragments.append(
                    format_html(
                        """<br /><div><img alt="{title}" src="{media_url}" loading="lazy" /></div>""",
                        title=title,
                        media_url=media_url,
                    )
                )
            elif media_type in (
                TwitterPostTypeChoices.video,
                TwitterPostTypeChoices.animated_gif,
            ):
                variants = media_entity["video_info"]["variants"]

                if not variants:
                    continue

                # Variants without a bitrate count as 0; on ties ``max``
                # returns the first candidate, matching a stable reverse sort.
                video = max(variants, key=lambda variant: variant.get("bitrate", 0))

                fragments.append(
                    format_html(
                        """<br /><div><video controls muted><source src="{url}" type="{content_type}" /></video></div> """,
                        url=video["url"],
                        content_type=video["content_type"],
                    )
                )

        return "".join(fragments)
class TwitterStream(PostStream):
    """Stream that fetches a single Twitter timeline rule over OAuth."""

    rule_type = RuleTypeChoices.twitter_timeline

    def read(self):
        """Fetch the rule's URL and return ``(parsed_payload, self)``."""
        owner = self.rule.user
        credentials = OAuth(
            settings.TWITTER_CONSUMER_ID,
            client_secret=settings.TWITTER_CONSUMER_SECRET,
            resource_owner_key=owner.twitter_oauth_token,
            resource_owner_secret=owner.twitter_oauth_token_secret,
        )

        response = fetch(self.rule.url, auth=credentials)
        payload = self.parse(response)

        return payload, self

    def parse(self, response):
        """Decode the response body as JSON.

        Raises:
            StreamParseException: when the body is not valid JSON.
        """
        try:
            payload = response.json()
        except JSONDecodeError as error:
            raise StreamParseException(
                response=response, message="Failed parsing json"
            ) from error

        return payload
class TwitterClient(PostClient):
    """Read all timeline rules concurrently and yield their payloads."""

    # Stream class instantiated once per rule in ``__enter__``.
    stream = TwitterStream

    def __enter__(self):
        # NOTE: because of the ``yield`` below this method is a generator;
        # each item is whatever ``stream.read`` returned for one rule.
        streams = [self.stream(timeline) for timeline in self.rules]

        # At most 10 timelines are read in parallel.
        with ThreadPoolExecutor(max_workers=10) as executor:
            # Map each future back to its stream so the rule can be updated.
            futures = {executor.submit(stream.read): stream for stream in streams}

            for future in as_completed(futures):
                stream = futures[future]

                try:
                    # ``result()`` re-raises any exception raised by ``read``.
                    payload = future.result()

                    stream.rule.error = None
                    stream.rule.succeeded = True

                    yield payload
                except StreamTooManyException as e:
                    # Rate limited: stop processing the remaining futures.
                    logger.exception("Ratelimit hit, aborting twitter calls")

                    self.set_rule_error(stream.rule, e)

                    break
                except StreamDeniedException as e:
                    logger.warning(
                        f"Access token expired for user {stream.rule.user.pk}"
                    )

                    # Drop the stored credentials of the rule's user, then
                    # stop processing the remaining futures.
                    stream.rule.user.twitter_oauth_token = None
                    stream.rule.user.twitter_oauth_token_secret = None
                    stream.rule.user.save()

                    self.set_rule_error(stream.rule, e)

                    break
                except (StreamNotFoundException, StreamTimeOutException) as e:
                    # Logged as a warning only; move on to the next timeline.
                    logger.warning(f"Request failed for {stream.rule.screen_name}")

                    self.set_rule_error(stream.rule, e)

                    continue
                except StreamException as e:
                    # Any other stream failure: log with traceback, continue.
                    logger.exception(f"Request failed for {stream.rule.screen_name}")

                    self.set_rule_error(stream.rule, e)

                    continue
                finally:
                    # Runs for every processed future, including the ones
                    # that end in ``break`` or ``continue`` above.
                    stream.rule.last_run = timezone.now()
                    stream.rule.save()
class TwitterCollector(PostCollector):
    """Collector that pairs ``TwitterBuilder`` with ``TwitterClient``."""

    builder = TwitterBuilder
    client = TwitterClient
# see https://developer.twitter.com/en/docs/twitter-api/v1/rate-limits
class TwitterTimeLineScheduler(Scheduler):
    """Select the Twitter timeline rules of one user to collect next.

    Args:
        user: owner of the timelines; the stored OAuth token and secret are
            used to query the rate-limit endpoint.
        timelines: optional explicit collection of timeline rules. When
            omitted or empty, up to 200 of the user's enabled timeline
            rules are used, ordered by ``last_run``.
    """

    def __init__(self, user, timelines=None):
        self.user = user

        if not timelines:
            self.timelines = (
                user.rules.enabled()
                .filter(type=RuleTypeChoices.twitter_timeline)
                .order_by("last_run")[:200]
            )
        else:
            self.timelines = timelines

    def get_scheduled_rules(self):
        """Return at most as many timelines as the rate limit allows.

        An empty list is returned when the rate limit is unknown or zero.
        """
        max_amount = self.get_current_ratelimit()
        return self.timelines[:max_amount] if max_amount else []

    def get_current_ratelimit(self):
        """Return the user-timeline rate limit reported by Twitter.

        Returns ``None`` when the user has no OAuth credentials, the request
        fails, the response is not JSON, or the payload lacks the expected
        structure.
        """
        endpoint = "application/rate_limit_status.json?resources=statuses"

        if (
            not self.user.twitter_oauth_token
            or not self.user.twitter_oauth_token_secret
        ):
            return None

        oauth = OAuth(
            settings.TWITTER_CONSUMER_ID,
            client_secret=settings.TWITTER_CONSUMER_SECRET,
            resource_owner_key=self.user.twitter_oauth_token,
            resource_owner_secret=self.user.twitter_oauth_token_secret,
        )

        try:
            response = fetch(f"{TWITTER_API_URL}/{endpoint}", auth=oauth)
        except StreamException:
            logger.exception(f"Unable to retrieve current ratelimit for {self.user.pk}")
            return None

        try:
            payload = response.json()
        except JSONDecodeError:
            logger.exception(f"Unable to parse ratelimit request for {self.user.pk}")
            return None

        # NOTE(review): this reads "limit" (the window's ceiling), not
        # "remaining" (the calls still available) - confirm this is intended.
        try:
            return payload["resources"]["statuses"]["/statuses/user_timeline"]["limit"]
        except (KeyError, TypeError):
            # TypeError covers payloads that are not nested dicts (for
            # example a JSON list or a null value part-way down the path).
            return None

View file

@ -11,12 +11,14 @@ from newsreader.news.collection.views import (
CollectionRuleBulkDeleteView, CollectionRuleBulkDeleteView,
CollectionRuleBulkDisableView, CollectionRuleBulkDisableView,
CollectionRuleBulkEnableView, CollectionRuleBulkEnableView,
CollectionRuleCreateView,
CollectionRuleListView, CollectionRuleListView,
CollectionRuleUpdateView, FeedCreateView,
FeedUpdateView,
OPMLImportView, OPMLImportView,
SubRedditCreateView, SubRedditCreateView,
SubRedditUpdateView, SubRedditUpdateView,
TwitterTimelineCreateView,
TwitterTimelineUpdateView,
) )
@ -28,17 +30,13 @@ endpoints = [
] ]
urlpatterns = [ urlpatterns = [
# Feeds
path(
"feeds/<int:pk>/", login_required(FeedUpdateView.as_view()), name="feed-update"
),
path("feeds/create/", login_required(FeedCreateView.as_view()), name="feed-create"),
# Generic rules
path("rules/", login_required(CollectionRuleListView.as_view()), name="rules"), path("rules/", login_required(CollectionRuleListView.as_view()), name="rules"),
path(
"rules/<int:pk>/",
login_required(CollectionRuleUpdateView.as_view()),
name="rule-update",
),
path(
"rules/create/",
login_required(CollectionRuleCreateView.as_view()),
name="rule-create",
),
path( path(
"rules/delete/", "rules/delete/",
login_required(CollectionRuleBulkDeleteView.as_view()), login_required(CollectionRuleBulkDeleteView.as_view()),
@ -54,15 +52,27 @@ urlpatterns = [
login_required(CollectionRuleBulkDisableView.as_view()), login_required(CollectionRuleBulkDisableView.as_view()),
name="rules-disable", name="rules-disable",
), ),
path("rules/import/", login_required(OPMLImportView.as_view()), name="import"),
# Reddit
path( path(
"rules/subreddits/create/", "subreddits/create/",
login_required(SubRedditCreateView.as_view()), login_required(SubRedditCreateView.as_view()),
name="subreddit-create", name="subreddit-create",
), ),
path( path(
"rules/subreddits/<int:pk>/", "subreddits/<int:pk>/",
login_required(SubRedditUpdateView.as_view()), login_required(SubRedditUpdateView.as_view()),
name="subreddit-update", name="subreddit-update",
), ),
path("rules/import/", login_required(OPMLImportView.as_view()), name="import"), # Twitter
path(
"twitter/timelines/create/",
login_required(TwitterTimelineCreateView.as_view()),
name="twitter-timeline-create",
),
path(
"twitter/timelines/<int:pk>/",
login_required(TwitterTimelineUpdateView.as_view()),
name="twitter-timeline-update",
),
] ]

View file

@ -25,12 +25,12 @@ def build_publication_date(dt, tz):
return published_parsed.astimezone(pytz.utc) return published_parsed.astimezone(pytz.utc)
def fetch(url, headers={}): def fetch(url, auth=None, headers={}):
headers = {**DEFAULT_HEADERS, **headers} headers = {**DEFAULT_HEADERS, **headers}
with ResponseHandler() as response_handler: with ResponseHandler() as response_handler:
try: try:
response = requests.get(url, headers=headers) response = requests.get(url, auth=auth, headers=headers)
response_handler.handle_response(response) response_handler.handle_response(response)
except RequestException as exception: except RequestException as exception:
response_handler.map_exception(exception) response_handler.map_exception(exception)

View file

@ -1,3 +1,8 @@
from newsreader.news.collection.views.feed import (
FeedCreateView,
FeedUpdateView,
OPMLImportView,
)
from newsreader.news.collection.views.reddit import ( from newsreader.news.collection.views.reddit import (
SubRedditCreateView, SubRedditCreateView,
SubRedditUpdateView, SubRedditUpdateView,
@ -6,8 +11,9 @@ from newsreader.news.collection.views.rules import (
CollectionRuleBulkDeleteView, CollectionRuleBulkDeleteView,
CollectionRuleBulkDisableView, CollectionRuleBulkDisableView,
CollectionRuleBulkEnableView, CollectionRuleBulkEnableView,
CollectionRuleCreateView,
CollectionRuleListView, CollectionRuleListView,
CollectionRuleUpdateView, )
OPMLImportView, from newsreader.news.collection.views.twitter import (
TwitterTimelineCreateView,
TwitterTimelineUpdateView,
) )

View file

@ -1,8 +1,11 @@
import json
from django.urls import reverse_lazy from django.urls import reverse_lazy
import pytz import pytz
from newsreader.news.collection.forms import CollectionRuleForm from django_celery_beat.models import IntervalSchedule, PeriodicTask
from newsreader.news.collection.models import CollectionRule from newsreader.news.collection.models import CollectionRule
from newsreader.news.core.models import Category from newsreader.news.core.models import Category
@ -17,7 +20,6 @@ class CollectionRuleViewMixin:
class CollectionRuleDetailMixin: class CollectionRuleDetailMixin:
success_url = reverse_lazy("news:collection:rules") success_url = reverse_lazy("news:collection:rules")
form_class = CollectionRuleForm
def get_context_data(self, **kwargs): def get_context_data(self, **kwargs):
context_data = super().get_context_data(**kwargs) context_data = super().get_context_data(**kwargs)
@ -34,3 +36,25 @@ class CollectionRuleDetailMixin:
kwargs = super().get_form_kwargs() kwargs = super().get_form_kwargs()
kwargs["user"] = self.request.user kwargs["user"] = self.request.user
return kwargs return kwargs
class TaskCreationMixin:
    """Ensure a periodic task exists for the user once a form is valid.

    Classes using this mixin provide ``task_interval`` (an
    ``(every, period)`` pair), ``task_name`` and ``task_type``.
    """

    def form_valid(self, form):
        """Run the regular ``form_valid`` and get or create the user's task."""
        result = super().form_valid(form)

        every, period = self.task_interval
        schedule = IntervalSchedule.objects.get_or_create(
            every=every, period=period
        )[0]

        user = self.request.user
        task_defaults = {
            "args": json.dumps([user.pk]),
            "interval": schedule,
            "enabled": True,
        }

        # The defaults are only applied when the task does not exist yet.
        PeriodicTask.objects.get_or_create(
            name=f"{user.email}-{self.task_name}",
            task=self.task_type,
            defaults=task_defaults,
        )

        return result

View file

@ -0,0 +1,70 @@
from django.contrib import messages
from django.urls import reverse
from django.utils.translation import gettext as _
from django.views.generic.edit import CreateView, FormView, UpdateView
from django_celery_beat.models import IntervalSchedule
from newsreader.news.collection.choices import RuleTypeChoices
from newsreader.news.collection.forms import (
CollectionRuleBulkForm,
FeedForm,
OPMLImportForm,
)
from newsreader.news.collection.models import CollectionRule
from newsreader.news.collection.views.base import (
CollectionRuleDetailMixin,
CollectionRuleViewMixin,
TaskCreationMixin,
)
from newsreader.utils.opml import parse_opml
class FeedUpdateView(CollectionRuleViewMixin, CollectionRuleDetailMixin, UpdateView):
    """Update view for collection rules of the feed type."""

    form_class = FeedForm
    context_object_name = "feed"
    template_name = "news/collection/views/feed-update.html"

    def get_queryset(self):
        """Restrict the inherited queryset to feed rules."""
        return super().get_queryset().filter(type=RuleTypeChoices.feed)
class FeedCreateView(
    CollectionRuleViewMixin, CollectionRuleDetailMixin, TaskCreationMixin, CreateView
):
    """Create view for feeds, using ``FeedForm``.

    The ``task_*`` attributes configure the periodic task handled by
    ``TaskCreationMixin``: an hourly ``FeedTask`` named after the user.
    """

    form_class = FeedForm
    template_name = "news/collection/views/feed-create.html"

    task_type = "FeedTask"
    task_name = "feed"
    task_interval = (1, IntervalSchedule.HOURS)
class OPMLImportView(FormView):
    """Create collection rules for the feeds found in an uploaded OPML file."""

    form_class = OPMLImportForm
    template_name = "news/collection/views/import.html"

    def form_valid(self, form):
        """Import the uploaded file and report the number of created feeds.

        The form is re-rendered with an error on the ``file`` field when the
        import raises ``IOError`` or when it does not create any feeds.
        """
        user = self.request.user
        file = form.cleaned_data["file"]
        skip_existing = form.cleaned_data["skip_existing"]

        instances = parse_opml(file, user, skip_existing=skip_existing)

        # NOTE(review): the IOError is expected while bulk_create consumes
        # ``instances``; presumably parse_opml reads the file lazily. Confirm.
        try:
            feeds = CollectionRule.objects.bulk_create(instances)
        except IOError:
            form.add_error("file", _("Invalid OPML file"))
            return self.form_invalid(form)

        if not feeds:
            form.add_error("file", _("No (new) feeds found"))
            return self.form_invalid(form)

        # Translate a constant msgid and interpolate afterwards. An f-string
        # is formatted before the lookup, so it can never match a catalog entry.
        message = _("%(count)d new feeds created") % {"count": len(feeds)}
        messages.success(self.request, message)

        return super().form_valid(form)

    def get_success_url(self):
        """Return the URL of the rule list."""
        return reverse("news:collection:rules")

View file

@ -1,7 +1,7 @@
from django.views.generic.edit import CreateView, UpdateView from django.views.generic.edit import CreateView, UpdateView
from newsreader.news.collection.choices import RuleTypeChoices from newsreader.news.collection.choices import RuleTypeChoices
from newsreader.news.collection.forms import SubRedditRuleForm from newsreader.news.collection.forms import SubRedditForm
from newsreader.news.collection.views.base import ( from newsreader.news.collection.views.base import (
CollectionRuleDetailMixin, CollectionRuleDetailMixin,
CollectionRuleViewMixin, CollectionRuleViewMixin,
@ -11,14 +11,14 @@ from newsreader.news.collection.views.base import (
class SubRedditCreateView( class SubRedditCreateView(
CollectionRuleViewMixin, CollectionRuleDetailMixin, CreateView CollectionRuleViewMixin, CollectionRuleDetailMixin, CreateView
): ):
form_class = SubRedditRuleForm form_class = SubRedditForm
template_name = "news/collection/views/subreddit-create.html" template_name = "news/collection/views/subreddit-create.html"
class SubRedditUpdateView( class SubRedditUpdateView(
CollectionRuleViewMixin, CollectionRuleDetailMixin, UpdateView CollectionRuleViewMixin, CollectionRuleDetailMixin, UpdateView
): ):
form_class = SubRedditRuleForm form_class = SubRedditForm
template_name = "news/collection/views/subreddit-update.html" template_name = "news/collection/views/subreddit-update.html"
context_object_name = "subreddit" context_object_name = "subreddit"

View file

@ -2,17 +2,14 @@ from django.contrib import messages
from django.shortcuts import redirect from django.shortcuts import redirect
from django.urls import reverse from django.urls import reverse
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _
from django.views.generic.edit import CreateView, FormView, UpdateView from django.views.generic.edit import FormView
from django.views.generic.list import ListView from django.views.generic.list import ListView
from newsreader.news.collection.choices import RuleTypeChoices from newsreader.news.collection.forms import CollectionRuleBulkForm
from newsreader.news.collection.forms import CollectionRuleBulkForm, OPMLImportForm
from newsreader.news.collection.models import CollectionRule
from newsreader.news.collection.views.base import ( from newsreader.news.collection.views.base import (
CollectionRuleDetailMixin, CollectionRuleDetailMixin,
CollectionRuleViewMixin, CollectionRuleViewMixin,
) )
from newsreader.utils.opml import parse_opml
class CollectionRuleListView(CollectionRuleViewMixin, ListView): class CollectionRuleListView(CollectionRuleViewMixin, ListView):
@ -21,23 +18,6 @@ class CollectionRuleListView(CollectionRuleViewMixin, ListView):
context_object_name = "rules" context_object_name = "rules"
class CollectionRuleUpdateView(
    CollectionRuleViewMixin, CollectionRuleDetailMixin, UpdateView
):
    """Update view for collection rules of the feed type."""

    context_object_name = "rule"
    template_name = "news/collection/views/rule-update.html"

    def get_queryset(self):
        """Restrict the inherited queryset to feed rules."""
        return super().get_queryset().filter(type=RuleTypeChoices.feed)
class CollectionRuleCreateView(
    CollectionRuleViewMixin, CollectionRuleDetailMixin, CreateView
):
    """Create view for collection rules, rendered with the rule-create template."""

    template_name = "news/collection/views/rule-create.html"
class CollectionRuleBulkView(FormView): class CollectionRuleBulkView(FormView):
form_class = CollectionRuleBulkForm form_class = CollectionRuleBulkForm
@ -90,33 +70,3 @@ class CollectionRuleBulkDeleteView(CollectionRuleBulkView):
rule.delete() rule.delete()
return response return response
class OPMLImportView(FormView):
    """Create collection rules from the entries of an uploaded OPML file."""

    form_class = OPMLImportForm
    template_name = "news/collection/views/import.html"

    def form_valid(self, form):
        """Import the uploaded file and report the number of created rules.

        The form is re-rendered with an error on the ``file`` field when the
        import raises ``IOError`` or when it does not create any rules.
        """
        user = self.request.user
        file = form.cleaned_data["file"]
        skip_existing = form.cleaned_data["skip_existing"]

        instances = parse_opml(file, user, skip_existing=skip_existing)

        # NOTE(review): the IOError is expected while bulk_create consumes
        # ``instances``; presumably parse_opml reads the file lazily. Confirm.
        try:
            rules = CollectionRule.objects.bulk_create(instances)
        except IOError:
            form.add_error("file", _("Invalid OPML file"))
            return self.form_invalid(form)

        if not rules:
            form.add_error("file", _("No (new) rules found"))
            return self.form_invalid(form)

        # Translate a constant msgid and interpolate afterwards. An f-string
        # is formatted before the lookup, so it can never match a catalog entry.
        message = _("%(count)d new rules created") % {"count": len(rules)}
        messages.success(self.request, message)

        return super().form_valid(form)

    def get_success_url(self):
        """Return the URL of the rule list."""
        return reverse("news:collection:rules")

Some files were not shown because too many files have changed in this diff Show more