diff --git a/patroni/scripts/aws.py b/patroni/scripts/aws.py index 31e7e398..7d18e828 100755 --- a/patroni/scripts/aws.py +++ b/patroni/scripts/aws.py @@ -6,9 +6,9 @@ import sys import boto3 from ..utils import Retry, RetryFailedError -from ..request import get as requests_get from botocore.exceptions import ClientError +from botocore.utils import IMDSFetcher logger = logging.getLogger(__name__) @@ -21,14 +21,16 @@ class AWSConnection(object): self._retry = Retry(deadline=300, max_delay=30, max_tries=-1, retry_exceptions=(ClientError,)) try: # get the instance id - r = requests_get('http://169.254.169.254/latest/dynamic/instance-identity/document', timeout=2.1) + fetcher = IMDSFetcher(timeout=2.1) + token = fetcher._fetch_metadata_token() + r = fetcher._get_request("/latest/dynamic/instance-identity/document", None, token) except Exception: logger.error('cannot query AWS meta-data') return - if r.status < 400: + if r.status_code < 400: try: - content = json.loads(r.data.decode('utf-8')) + content = json.loads(r.text) self.instance_id = content['instanceId'] self.region = content['region'] except Exception: diff --git a/tests/test_aws.py b/tests/test_aws.py index 91345f33..c291dbd2 100644 --- a/tests/test_aws.py +++ b/tests/test_aws.py @@ -1,9 +1,9 @@ import botocore +import botocore.awsrequest import sys import unittest -import urllib3 -from mock import Mock, patch +from mock import Mock, PropertyMock, patch from collections import namedtuple from patroni.scripts.aws import AWSConnection, main as _main @@ -28,31 +28,45 @@ class MockEc2Connection(object): return True +class MockIMDSFetcher(object): + + def __init__(self, timeout): + pass + + @staticmethod + def _fetch_metadata_token(): + return '' + + @staticmethod + def _get_request(*args): + return botocore.awsrequest.AWSResponse(url='', status_code=200, headers={}, raw=None) + + @patch('boto3.resource', Mock(return_value=MockEc2Connection())) +@patch('patroni.scripts.aws.IMDSFetcher', MockIMDSFetcher) class TestAWSConnection(unittest.TestCase): - @patch('patroni.scripts.aws.requests_get', Mock(return_value=urllib3.HTTPResponse( - status=200, body=b'{"instanceId": "012345", "region": "eu-west-1"}'))) - def setUp(self): - self.conn = AWSConnection('test') - + @patch.object(botocore.awsrequest.AWSResponse, 'text', + PropertyMock(return_value='{"instanceId": "012345", "region": "eu-west-1"}')) def test_on_role_change(self): - self.assertTrue(self.conn.on_role_change('primary')) + conn = AWSConnection('test') + self.assertTrue(conn.on_role_change('primary')) with patch.object(MockVolumes, 'filter', Mock(return_value=[])): - self.conn._retry.max_tries = 1 - self.assertFalse(self.conn.on_role_change('primary')) + conn._retry.max_tries = 1 + self.assertFalse(conn.on_role_change('primary')) - @patch('patroni.scripts.aws.requests_get', Mock(side_effect=Exception('foo'))) + @patch.object(MockIMDSFetcher, '_get_request', Mock(side_effect=Exception('foo'))) def test_non_aws(self): conn = AWSConnection('test') self.assertFalse(conn.on_role_change("primary")) - @patch('patroni.scripts.aws.requests_get', Mock(return_value=urllib3.HTTPResponse(status=200, body=b'foo'))) + @patch.object(botocore.awsrequest.AWSResponse, 'text', PropertyMock(return_value='boo')) def test_aws_bizare_response(self): conn = AWSConnection('test') self.assertFalse(conn.aws_available()) - @patch('patroni.scripts.aws.requests_get', Mock(return_value=urllib3.HTTPResponse(status=503, body=b'Error'))) + @patch.object(MockIMDSFetcher, '_get_request', Mock(return_value=botocore.awsrequest.AWSResponse( + url='', status_code=503, headers={}, raw=None))) @patch('sys.exit', Mock()) def test_main(self): self.assertIsNone(_main())