Improved async/threading attachment support (#1086)

Chris Caron 2024-03-29 14:42:28 -04:00 committed by GitHub
parent 195d0efe3c
commit 81804704da
3 changed files with 210 additions and 122 deletions
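
The heart of the change: several notification plugins, each potentially running in its own thread, can now share a single AttachHTTP object without racing to download the same file. A class-level lock plus a cheap existence check means the first caller performs the fetch while concurrent callers block briefly and then reuse the cached result. Below is a minimal stand-alone sketch of that pattern; it is not Apprise code and all names are illustrative.

import threading

class CachedFetcher:
    # Shared across all instances, mirroring the class-level lock
    # introduced below in AttachHTTP
    _lock = threading.Lock()

    def __init__(self):
        self._cached = None

    def exists(self, retrieve_if_missing=True):
        # Report whether content is already cached; optionally fetch it
        if self._cached is not None:
            return True
        return self.download() if retrieve_if_missing else False

    def download(self):
        with self._lock:
            # Double-check under the lock: another thread may have
            # finished the fetch while we were waiting
            if self.exists(retrieve_if_missing=False):
                return True
            self._cached = self._fetch()  # stand-in for the HTTP GET
            return True

    def _fetch(self):
        return b'payload'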

apprise/attachment/AttachBase.py

@@ -253,7 +253,7 @@ class AttachBase(URLBase):
         return self.detected_mimetype \
             if self.detected_mimetype else self.unknown_mimetype
 
-    def exists(self):
+    def exists(self, retrieve_if_missing=True):
         """
         Simply returns true if the object has downloaded and stored the
         attachment AND the attachment has not expired.
@@ -282,7 +282,7 @@ class AttachBase(URLBase):
                 # The file is not present
                 pass
 
-        return self.download()
+        return False if not retrieve_if_missing else self.download()
 
     def invalidate(self):
         """

apprise/attachment/AttachHTTP.py

@@ -29,6 +29,7 @@
 import re
 import os
 import requests
+import threading
 from tempfile import NamedTemporaryFile
 from .AttachBase import AttachBase
 from ..common import ContentLocation
@@ -56,6 +57,9 @@ class AttachHTTP(AttachBase):
     # Web based requests are remote/external to our current location
     location = ContentLocation.HOSTED
 
+    # thread safe loading
+    _lock = threading.Lock()
+
     def __init__(self, headers=None, **kwargs):
         """
         Initialize HTTP Object
@@ -96,9 +100,6 @@
             # our content is inaccessible
             return False
 
-        # Ensure any existing content set has been invalidated
-        self.invalidate()
-
         # prepare header
         headers = {
             'User-Agent': self.app_id,
@@ -117,134 +118,154 @@
         url += self.fullpath
 
-        self.logger.debug('HTTP POST URL: %s (cert_verify=%r)' % (
-            url, self.verify_certificate,
-        ))
-
         # Where our request object will temporarily live.
         r = None
 
         # Always call throttle before any remote server i/o is made
         self.throttle()
 
-        try:
-            # Make our request
-            with requests.get(
-                    url,
-                    headers=headers,
-                    auth=auth,
-                    params=self.qsd,
-                    verify=self.verify_certificate,
-                    timeout=self.request_timeout,
-                    stream=True) as r:
-
-                # Handle Errors
-                r.raise_for_status()
-
-                # Get our file-size (if known)
-                try:
-                    file_size = int(r.headers.get('Content-Length', '0'))
-                except (TypeError, ValueError):
-                    # Handle edge case where Content-Length is a bad value
-                    file_size = 0
-
-                # Perform a little Q/A on file limitations and restrictions
-                if self.max_file_size > 0 and file_size > self.max_file_size:
-
-                    # The content retrieved is too large
-                    self.logger.error(
-                        'HTTP response exceeds allowable maximum file length '
-                        '({}KB): {}'.format(
-                            int(self.max_file_size / 1024),
-                            self.url(privacy=True)))
-
-                    # Return False (signifying a failure)
-                    return False
-
-                # Detect config format based on mime if the format isn't
-                # already enforced
-                self.detected_mimetype = r.headers.get('Content-Type')
-
-                d = r.headers.get('Content-Disposition', '')
-                result = re.search(
-                    "filename=['\"]?(?P<name>[^'\"]+)['\"]?", d, re.I)
-                if result:
-                    self.detected_name = result.group('name').strip()
-
-                # Create a temporary file to work with
-                self._temp_file = NamedTemporaryFile()
-
-                # Get our chunk size
-                chunk_size = self.chunk_size
-
-                # Track all bytes written to disk
-                bytes_written = 0
-
-                # If we get here, we can now safely write our content to disk
-                for chunk in r.iter_content(chunk_size=chunk_size):
-                    # filter out keep-alive chunks
-                    if chunk:
-                        self._temp_file.write(chunk)
-                        bytes_written = self._temp_file.tell()
-
-                        # Prevent a case where Content-Length isn't provided;
-                        # we don't want to fetch beyond our limits
-                        if self.max_file_size > 0:
-                            if bytes_written > self.max_file_size:
-                                # The content retrieved is too large
-                                self.logger.error(
-                                    'HTTP response exceeds allowable maximum '
-                                    'file length ({}KB): {}'.format(
-                                        int(self.max_file_size / 1024),
-                                        self.url(privacy=True)))
-
-                                # Invalidate any variables previously set
-                                self.invalidate()
-
-                                # Return False (signifying a failure)
-                                return False
-
-                            elif bytes_written + chunk_size \
-                                    > self.max_file_size:
-                                # Adjust our next read to accommodate up to
-                                # our limit +1. This will prevent us from
-                                # reading too much into our memory buffer
-                                chunk_size = \
-                                    self.max_file_size - bytes_written + 1
-
-                # Ensure our content is flushed to disk for post-processing
-                self._temp_file.flush()
-
-            # Set our minimum requirements for a successful download() call
-            self.download_path = self._temp_file.name
-
-            if not self.detected_name:
-                self.detected_name = os.path.basename(self.fullpath)
-
-        except requests.RequestException as e:
-            self.logger.error(
-                'A Connection error occurred retrieving HTTP '
-                'configuration from %s.' % self.host)
-            self.logger.debug('Socket Exception: %s' % str(e))
-
-            # Invalidate any variables previously set
-            self.invalidate()
-
-            # Return False (signifying a failure)
-            return False
-
-        except (IOError, OSError):
-            # IOError is present for backwards compatibility with Python
-            # versions older than 3.3; >= 3.3 throws OSError now.
-
-            # Could not open and/or write the temporary file
-            self.logger.error(
-                'Could not write attachment to disk: {}'.format(
-                    self.url(privacy=True)))
-
-            # Invalidate any variables previously set
-            self.invalidate()
-
-            # Return False (signifying a failure)
-            return False
+        with self._lock:
+            if self.exists(retrieve_if_missing=False):
+                # Due to locking; it's possible a concurrent thread already
+                # handled the retrieval in which case we can safely move on
+                self.logger.trace(
+                    'HTTP Attachment %s already retrieved',
+                    self._temp_file.name)
+                return True
+
+            # Ensure any existing content set has been invalidated
+            self.invalidate()
+
+            self.logger.debug(
+                'HTTP Attachment Fetch URL: %s (cert_verify=%r)' % (
+                    url, self.verify_certificate))
+
+            try:
+                # Make our request
+                with requests.get(
+                        url,
+                        headers=headers,
+                        auth=auth,
+                        params=self.qsd,
+                        verify=self.verify_certificate,
+                        timeout=self.request_timeout,
+                        stream=True) as r:
+
+                    # Handle Errors
+                    r.raise_for_status()
+
+                    # Get our file-size (if known)
+                    try:
+                        file_size = int(r.headers.get('Content-Length', '0'))
+                    except (TypeError, ValueError):
+                        # Handle edge case where Content-Length is a bad
+                        # value
+                        file_size = 0
+
+                    # Perform a little Q/A on file limitations and
+                    # restrictions
+                    if self.max_file_size > 0 and \
+                            file_size > self.max_file_size:
+
+                        # The content retrieved is too large
+                        self.logger.error(
+                            'HTTP response exceeds allowable maximum file '
+                            'length ({}KB): {}'.format(
+                                int(self.max_file_size / 1024),
+                                self.url(privacy=True)))
+
+                        # Return False (signifying a failure)
+                        return False
+
+                    # Detect config format based on mime if the format isn't
+                    # already enforced
+                    self.detected_mimetype = r.headers.get('Content-Type')
+
+                    d = r.headers.get('Content-Disposition', '')
+                    result = re.search(
+                        "filename=['\"]?(?P<name>[^'\"]+)['\"]?", d, re.I)
+                    if result:
+                        self.detected_name = result.group('name').strip()
+
+                    # Create a temporary file to work with; delete must be
+                    # set to False or it isn't compatible with Microsoft
+                    # Windows instances. In lieu of this, __del__ will clean
+                    # up the file for us.
+                    self._temp_file = NamedTemporaryFile(delete=False)
+
+                    # Get our chunk size
+                    chunk_size = self.chunk_size
+
+                    # Track all bytes written to disk
+                    bytes_written = 0
+
+                    # If we get here, we can now safely write our content to
+                    # disk
+                    for chunk in r.iter_content(chunk_size=chunk_size):
+                        # filter out keep-alive chunks
+                        if chunk:
+                            self._temp_file.write(chunk)
+                            bytes_written = self._temp_file.tell()
+
+                            # Prevent a case where Content-Length isn't
+                            # provided. In this case we don't want to fetch
+                            # beyond our limits
+                            if self.max_file_size > 0:
+                                if bytes_written > self.max_file_size:
+                                    # The content retrieved is too large
+                                    self.logger.error(
+                                        'HTTP response exceeds allowable '
+                                        'maximum file length '
+                                        '({}KB): {}'.format(
+                                            int(self.max_file_size / 1024),
+                                            self.url(privacy=True)))
+
+                                    # Invalidate any variables previously set
+                                    self.invalidate()
+
+                                    # Return False (signifying a failure)
+                                    return False
+
+                                elif bytes_written + chunk_size \
+                                        > self.max_file_size:
+                                    # Adjust our next read to accommodate up
+                                    # to our limit +1. This will prevent us
+                                    # from reading too much into our memory
+                                    # buffer
+                                    chunk_size = self.max_file_size \
+                                        - bytes_written + 1
+
+                    # Ensure our content is flushed to disk for
+                    # post-processing
+                    self._temp_file.flush()
+
+                # Set our minimum requirements for a successful download()
+                # call
+                self.download_path = self._temp_file.name
+
+                if not self.detected_name:
+                    self.detected_name = os.path.basename(self.fullpath)
+
+            except requests.RequestException as e:
+                self.logger.error(
+                    'A Connection error occurred retrieving HTTP '
+                    'configuration from %s.' % self.host)
+                self.logger.debug('Socket Exception: %s' % str(e))
+
+                # Invalidate any variables previously set
+                self.invalidate()
+
+                # Return False (signifying a failure)
+                return False
+
+            except (IOError, OSError):
+                # IOError is present for backwards compatibility with Python
+                # versions older than 3.3; >= 3.3 throws OSError now.
+
+                # Could not open and/or write the temporary file
+                self.logger.error(
+                    'Could not write attachment to disk: {}'.format(
+                        self.url(privacy=True)))
+
+                # Invalidate any variables previously set
+                self.invalidate()
+
+                # Return False (signifying a failure)
+                return False
 
         # Return our success
         return True
@@ -254,11 +275,30 @@ class AttachHTTP(AttachBase):
         Close our temporary file
         """
         if self._temp_file:
+            self.logger.trace(
+                'Attachment cleanup of %s', self._temp_file.name)
             self._temp_file.close()
 
+            try:
+                # Ensure our file is removed (if it exists)
+                os.unlink(self._temp_file.name)
+
+            except OSError:
+                pass
+
             # Reset our temporary file to prevent from entering
             # this block again
             self._temp_file = None
 
         super().invalidate()
 
+    def __del__(self):
+        """
+        Tidy memory if open
+        """
+        with self._lock:
+            self.invalidate()
+
     def url(self, privacy=False, *args, **kwargs):
         """
         Returns the URL built dynamically based on specified arguments.
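
A note on the temporary-file handling above: the Python docs state that on Microsoft Windows a NamedTemporaryFile created with the default delete=True cannot be re-opened by name while it is still open, which is what breaks sharing the downloaded attachment across consumers. Switching to delete=False means the file survives close(), so cleanup becomes manual: invalidate() now unlinks the file itself and __del__ acts as a last-resort safety net. A small stand-alone sketch of that lifecycle, independent of Apprise:

import os
from tempfile import NamedTemporaryFile

# delete=False keeps the file on disk after close(), so other readers
# (or other threads/processes) can open it by name
tmp = NamedTemporaryFile(delete=False)
tmp.write(b'attachment bytes')
tmp.flush()

path = tmp.name  # hand this path to whatever consumes the attachment

# Explicit cleanup, mirroring invalidate() in the hunk above
tmp.close()
try:
    os.unlink(path)
except OSError:
    pass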

test/test_attach_http.py

@@ -35,7 +35,7 @@ from os.path import join
 from os.path import dirname
 from os.path import getsize
 from apprise.attachment.AttachHTTP import AttachHTTP
-from apprise import AppriseAttachment
+from apprise import Apprise, AppriseAttachment
 from apprise.NotificationManager import NotificationManager
 from apprise.plugins.NotifyBase import NotifyBase
 from apprise.common import ContentLocation
@@ -113,8 +113,9 @@ def test_attach_http_query_string_dictionary():
     assert re.search(r'[?&]_var=test', obj.url())
 
 
+@mock.patch('requests.post')
 @mock.patch('requests.get')
-def test_attach_http(mock_get):
+def test_attach_http(mock_get, mock_post):
     """
     API: AttachHTTP() object
 
@@ -422,3 +423,50 @@ def test_attach_http(mock_get):
 
     # Restore value
     AttachHTTP.max_file_size = max_file_size
+
+    # Multi Message Testing
+    mock_get.side_effect = None
+    mock_get.return_value = DummyResponse()
+
+    # Prepare our POST response (from notify call)
+    response = requests.Request()
+    response.status_code = requests.codes.ok
+    response.content = ""
+    mock_post.return_value = response
+
+    mock_get.reset_mock()
+    mock_post.reset_mock()
+    assert mock_get.call_count == 0
+
+    apobj = Apprise()
+    assert apobj.add('form://localhost')
+    assert apobj.add('json://localhost')
+    assert apobj.add('xml://localhost')
+    assert len(apobj) == 3
+    assert apobj.notify(
+        body='one attachment split 3 times',
+        attach="http://localhost/test.gif",
+    ) is True
+
+    # We posted 3 times
+    assert mock_post.call_count == 3
+    # We only fetched once and re-used the same fetch for all posts
+    assert mock_get.call_count == 1
+
+    mock_get.reset_mock()
+    mock_post.reset_mock()
+
+    apobj = Apprise()
+    for n in range(10):
+        assert apobj.add(f'json://localhost?:entry={n}&method=post')
+        assert apobj.add(f'form://localhost?:entry={n}&method=post')
+        assert apobj.add(f'xml://localhost?:entry={n}&method=post')
+    assert apobj.notify(
+        body='one attachment split 30 times',
+        attach="http://localhost/test.gif",
+    ) is True
+
+    # We posted 30 times
+    assert mock_post.call_count == 30
+    # We only fetched once and re-used the same fetch for all posts
+    assert mock_get.call_count == 1
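
The behaviour these tests pin down is also visible from the public API: one remotely hosted attachment fanned out to any number of services now costs a single HTTP GET, even when notifications run in parallel. A hedged end-user sketch (service and attachment URLs are placeholders taken from the tests above):

from apprise import Apprise

apobj = Apprise()
apobj.add('json://localhost')
apobj.add('xml://localhost')
apobj.add('form://localhost')

# The attachment is fetched once; the cached temporary file is then
# re-used for every post
apobj.notify(
    body='one attachment, one fetch, three posts',
    attach='http://localhost/test.gif',
)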