Add type hinting
This commit is contained in:
parent
69e4e7b269
commit
c75de1c469
5 changed files with 92 additions and 97 deletions
|
|
@ -1,3 +1,5 @@
|
|||
from typing import ContextManager, Dict, List, Optional, Tuple
|
||||
|
||||
import requests
|
||||
|
||||
from django.utils import timezone
|
||||
|
|
@ -5,23 +7,58 @@ from django.utils import timezone
|
|||
from newsreader.news.collection.models import CollectionRule
|
||||
|
||||
|
||||
class Stream:
    """Abstract reader that fetches and parses one collection rule's remote feed.

    Subclasses must implement ``parse``; ``Meta.abstract`` marks the base as
    non-concrete by project convention.
    """

    def __init__(self, rule: "CollectionRule") -> None:
        # The rule supplies the URL (and, in subclasses, parsing context).
        self.rule = rule

    def read(self) -> Tuple:
        """Fetch the rule's URL and return ``(parsed_payload, self)``."""
        url = self.rule.url
        response = requests.get(url)
        return (self.parse(response.content), self)

    def parse(self, payload: bytes) -> Dict:
        """Convert the raw response body into a dict; subclasses must override."""
        raise NotImplementedError

    class Meta:
        abstract = True
|
||||
|
||||
|
||||
class Client:
    """Abstract context manager that reads every rule's stream in turn."""

    # Stream subclass used to read each rule; concrete clients override this.
    stream = Stream

    def __init__(self, rules: Optional[List] = None) -> None:
        # ``rules`` is a collection of CollectionRule objects, not a single
        # rule (fixed annotation: matches Collector.collect's Optional[List]).
        # Fall back to every stored rule when no subset is given.
        self.rules = rules if rules else CollectionRule.objects.all()

    def __enter__(self) -> ContextManager:
        # NOTE(review): this is a generator function, so ``with Client() as c``
        # binds a generator of (data, stream) tuples that callers iterate —
        # annotation kept for file-wide consistency; confirm intent.
        for rule in self.rules:
            stream = self.stream(rule)

            yield stream.read()

    def __exit__(self, *args, **kwargs) -> None:
        pass

    class Meta:
        abstract = True
|
||||
|
||||
|
||||
class Builder:
    """Abstract context manager that builds model instances from a stream tuple."""

    # NOTE(review): class-level mutable default is shared across instances;
    # subclasses rebind it per instance (see FeedBuilder.__enter__) — confirm.
    instances = []

    def __init__(self, stream: "Stream") -> None:
        self.stream = stream

    def __enter__(self) -> ContextManager:
        # Build immediately on entry so the caller can save inside the block.
        self.create_posts(self.stream)
        return self

    def __exit__(self, *args, **kwargs) -> None:
        pass

    def create_posts(self, stream: Tuple) -> None:
        """Hook: populate ``self.instances`` from a ``(data, stream)`` tuple."""
        pass

    def save(self) -> None:
        """Hook: persist the built instances."""
        pass

    class Meta:
        abstract = True
|
@ -32,11 +69,11 @@ class Collector:
|
|||
client = None
|
||||
builder = None
|
||||
|
||||
def __init__(self, client: Optional["Client"] = None, builder: Optional["Builder"] = None) -> None:
    """Allow per-instance override of the class-level client and builder."""
    # Fall back to the class attributes when no override is supplied.
    self.client = client if client else self.client
    self.builder = builder if builder else self.builder
|
||||
|
||||
def collect(self, rules=None):
|
||||
def collect(self, rules: Optional[List] = None) -> None:
|
||||
with self.client(rules=rules) as client:
|
||||
for data, stream in client:
|
||||
with self.builder((data, stream)) as builder:
|
||||
|
|
@ -44,38 +81,3 @@ class Collector:
|
|||
|
||||
class Meta:
|
||||
abstract = True
|
||||
|
||||
|
||||
class Stream:
|
||||
def __init__(self, rule):
|
||||
self.rule = rule
|
||||
|
||||
def read(self):
|
||||
url = self.rule.url
|
||||
response = requests.get(url)
|
||||
return (self.parse(response.content), self)
|
||||
|
||||
def parse(self, payload):
|
||||
raise NotImplementedError
|
||||
|
||||
class Meta:
|
||||
abstract = True
|
||||
|
||||
|
||||
class Client:
|
||||
stream = Stream
|
||||
|
||||
def __init__(self, rules=None):
|
||||
self.rules = rules if rules else CollectionRule.objects.all()
|
||||
|
||||
def __enter__(self):
|
||||
for rule in self.rules:
|
||||
stream = self.stream(rule)
|
||||
|
||||
yield stream.read()
|
||||
|
||||
def __exit__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
class Meta:
|
||||
abstract = True
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from typing import ContextManager, Dict, Generator, List, Optional, Tuple
|
||||
|
||||
import bleach
|
||||
import pytz
|
||||
|
|
@ -16,6 +17,7 @@ from newsreader.news.collection.exceptions import (
|
|||
StreamParseException,
|
||||
StreamTimeOutException,
|
||||
)
|
||||
from newsreader.news.collection.models import CollectionRule
|
||||
from newsreader.news.collection.response_handler import ResponseHandler
|
||||
from newsreader.news.collection.utils import build_publication_date
|
||||
from newsreader.news.posts.models import Post
|
||||
|
|
@ -24,17 +26,16 @@ from newsreader.news.posts.models import Post
|
|||
class FeedBuilder(Builder):
|
||||
instances = []
|
||||
|
||||
def __enter__(self) -> ContextManager:
    """Index already-persisted posts for this stream's rule, then build."""
    _, stream = self.stream
    # Fresh per-instance list: the Builder class attribute is shared.
    self.instances = []
    # Map remote identifiers to existing posts so duplicates can be detected
    # while building.
    self.existing_posts = {
        post.remote_identifier: post for post in Post.objects.filter(rule=stream.rule)
    }

    return super().__enter__()
|
||||
|
||||
def create_posts(self, stream):
|
||||
def create_posts(self, stream: Tuple) -> None:
|
||||
data, stream = stream
|
||||
entries = []
|
||||
|
||||
|
|
@ -49,30 +50,25 @@ class FeedBuilder(Builder):
|
|||
|
||||
self.instances = [post for post in posts]
|
||||
|
||||
def build(self, entries, rule):
|
||||
def build(self, entries: List, rule: CollectionRule) -> Generator[Post, None, None]:
|
||||
field_mapping = {
|
||||
"id": "remote_identifier",
|
||||
"title": "title",
|
||||
"summary": "body",
|
||||
"link": "url",
|
||||
"published_parsed": "publication_date",
|
||||
"author": "author"
|
||||
"author": "author",
|
||||
}
|
||||
|
||||
tz = pytz.timezone(rule.timezone)
|
||||
|
||||
for entry in entries:
|
||||
data = {
|
||||
"rule_id": rule.pk,
|
||||
"category": rule.category
|
||||
}
|
||||
data = {"rule_id": rule.pk, "category": rule.category}
|
||||
|
||||
for field, value in field_mapping.items():
|
||||
if field in entry:
|
||||
if field == "published_parsed":
|
||||
created, aware_datetime = build_publication_date(
|
||||
entry[field], tz
|
||||
)
|
||||
created, aware_datetime = build_publication_date(entry[field], tz)
|
||||
data[value] = aware_datetime if created else None
|
||||
elif field == "summary":
|
||||
summary = self.sanitize_summary(entry[field])
|
||||
|
|
@ -82,19 +78,19 @@ class FeedBuilder(Builder):
|
|||
|
||||
yield Post(**data)
|
||||
|
||||
def sanitize_summary(self, summary: Optional[str]) -> Optional[str]:
    """Strip unsafe HTML from a feed summary.

    Falsy summaries (``None`` or ``""``) are returned as ``None`` without
    invoking bleach — hence the ``Optional[str]`` parameter annotation
    (the original ``str`` annotation did not reflect this accepted input).
    """
    attrs = {"a": ["href", "rel"], "img": ["alt", "src"]}
    tags = ["a", "img", "p"]

    return bleach.clean(summary, tags=tags, attributes=attrs) if summary else None
|
||||
|
||||
def save(self) -> None:
    """Persist every built post individually."""
    for post in self.instances:
        post.save()
|
||||
|
||||
|
||||
class FeedStream(Stream):
|
||||
def read(self):
|
||||
def read(self) -> Tuple:
|
||||
url = self.rule.url
|
||||
response = requests.get(url)
|
||||
|
||||
|
|
@ -103,7 +99,7 @@ class FeedStream(Stream):
|
|||
|
||||
return (self.parse(response.content), self)
|
||||
|
||||
def parse(self, payload):
|
||||
def parse(self, payload: bytes) -> Dict:
|
||||
try:
|
||||
return parse(payload)
|
||||
except TypeError as e:
|
||||
|
|
@ -113,14 +109,11 @@ class FeedStream(Stream):
|
|||
class FeedClient(Client):
|
||||
stream = FeedStream
|
||||
|
||||
def __enter__(self):
|
||||
def __enter__(self) -> ContextManager:
|
||||
streams = [self.stream(rule) for rule in self.rules]
|
||||
|
||||
with ThreadPoolExecutor(max_workers=10) as executor:
|
||||
futures = {
|
||||
executor.submit(stream.read): stream
|
||||
for stream in streams
|
||||
}
|
||||
futures = {executor.submit(stream.read): stream for stream in streams}
|
||||
|
||||
for future in as_completed(futures):
|
||||
stream = futures[future]
|
||||
|
|
@ -148,19 +141,19 @@ class FeedCollector(Collector):
|
|||
|
||||
|
||||
class FeedDuplicateHandler:
|
||||
def __init__(self, rule: "CollectionRule") -> None:
    """Scope duplicate detection to the posts belonging to *rule*."""
    self.queryset = rule.post_set.all()
|
||||
|
||||
def __enter__(self) -> ContextManager:
    """Cache remote identifiers of existing posts for fast duplicate checks."""
    # Only posts that actually carry a remote identifier are comparable.
    self.existing_identifiers = self.queryset.filter(
        remote_identifier__isnull=False
    ).values_list("remote_identifier", flat=True)
    return self
|
||||
|
||||
def __exit__(self, *args, **kwargs) -> None:
    # Nothing to release; the cached identifiers die with the instance.
    pass
|
||||
|
||||
def check(self, instances):
|
||||
def check(self, instances: List) -> Generator[Post, None, None]:
|
||||
for instance in instances:
|
||||
if instance.remote_identifier in self.existing_identifiers:
|
||||
existing_post = self.handle_duplicate(instance)
|
||||
|
|
@ -173,31 +166,29 @@ class FeedDuplicateHandler:
|
|||
|
||||
yield instance
|
||||
|
||||
def in_database(self, post: "Post") -> Optional[bool]:
    """Return ``True`` when a recent stored post matches *post* on all fields.

    Returns ``None`` (falsy) when no match is found; only the 50 most recent
    posts are compared to bound the cost of each check.
    """
    values = {
        "url": post.url,
        "title": post.title,
        "body": post.body,
        "publication_date": post.publication_date,
    }

    for existing_post in self.queryset.order_by("-publication_date")[:50]:
        if self.is_duplicate(existing_post, values):
            return True
|
||||
|
||||
def is_duplicate(self, existing_post: "Post", values: Dict) -> bool:
    """Return ``True`` iff *existing_post* matches every field in *values*."""
    for key, value in values.items():
        # object() sentinel: a missing attribute can never equal a field value.
        existing_value = getattr(existing_post, key, object())
        if existing_value != value:
            return False

    return True
|
||||
|
||||
def handle_duplicate(self, instance):
|
||||
def handle_duplicate(self, instance: Post) -> Optional[Post]:
|
||||
try:
|
||||
existing_instance = self.queryset.get(
|
||||
remote_identifier=instance.remote_identifier,
|
||||
)
|
||||
existing_instance = self.queryset.get(remote_identifier=instance.remote_identifier)
|
||||
except ObjectDoesNotExist:
|
||||
return
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,7 @@
|
|||
from typing import ContextManager
|
||||
|
||||
from requests import Response
|
||||
|
||||
from newsreader.news.collection.exceptions import (
|
||||
StreamDeniedException,
|
||||
StreamForbiddenException,
|
||||
|
|
@ -14,17 +18,17 @@ class ResponseHandler:
|
|||
408: StreamTimeOutException,
|
||||
}
|
||||
|
||||
def __init__(self, response: "Response") -> None:
    """Wrap an HTTP response for status-code-to-exception handling."""
    self.response = response
|
||||
|
||||
def __enter__(self) -> ContextManager:
    # The handler itself is the context object; callers invoke handle_response.
    return self
|
||||
|
||||
def handle_response(self) -> None:
    """Raise the mapped stream exception for known error statuses; else no-op."""
    status_code = self.response.status_code

    if status_code in self.message_mapping:
        raise self.message_mapping[status_code]
|
||||
|
||||
def __exit__(self, *args, **kwargs) -> None:
    # Drop the response reference so the handler cannot be reused stale.
    self.response = None
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@ from django.utils import timezone
|
|||
from newsreader.news.collection.exceptions import (
|
||||
StreamDeniedException,
|
||||
StreamException,
|
||||
StreamFieldException,
|
||||
StreamNotFoundException,
|
||||
StreamTimeOutException,
|
||||
)
|
||||
|
|
@ -17,9 +16,7 @@ from newsreader.news.collection.tests.feed.client.mocks import simple_mock
|
|||
|
||||
class FeedClientTestCase(TestCase):
|
||||
def setUp(self):
    """Patch FeedStream.read so tests never hit the network."""
    self.patched_read = patch("newsreader.news.collection.feed.FeedStream.read")
    self.mocked_read = self.patched_read.start()
|
||||
|
||||
def tearDown(self):
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
from datetime import datetime
|
||||
from time import mktime
|
||||
from datetime import datetime, tzinfo
|
||||
from time import mktime, struct_time
|
||||
from typing import Tuple
|
||||
|
||||
from django.utils import timezone
|
||||
|
||||
|
||||
def build_publication_date(dt, tz):
|
||||
def build_publication_date(dt: struct_time, tz: tzinfo) -> Tuple:
|
||||
try:
|
||||
naive_datetime = datetime.fromtimestamp(mktime(dt))
|
||||
published_parsed = timezone.make_aware(naive_datetime, timezone=tz)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue