diff --git a/patroni/dcs/etcd.py b/patroni/dcs/etcd.py index a5b55e7e..1d8911d9 100644 --- a/patroni/dcs/etcd.py +++ b/patroni/dcs/etcd.py @@ -41,6 +41,10 @@ class EtcdRaftInternal(etcd.EtcdException): """Raft Internal Error""" +class StaleEtcdNode(Exception): + """Node is stale (raft term is older than previous known).""" + + class EtcdError(DCSError): pass @@ -102,6 +106,8 @@ class AbstractEtcdClientWithFailover(abc.ABC, etcd.Client): ERROR_CLS: Type[Exception] def __init__(self, config: Dict[str, Any], dns_resolver: DnsCachingResolver, cache_ttl: int = 300) -> None: + self._cluster_id = None + self._raft_term = 0 self._dns_resolver = dns_resolver self.set_machines_cache_ttl(cache_ttl) self._machines_cache_updated = 0 @@ -119,6 +125,32 @@ class AbstractEtcdClientWithFailover(abc.ABC, etcd.Client): self._read_options.add('retry') self._del_conditions.add('retry') + def _check_cluster_raft_term(self, cluster_id: Optional[str], value: Union[None, str, int]) -> None: + """Check that observed Raft Term in Etcd cluster is increasing. + + If we observe that the new value is smaller than the previously known one, it could be an + indicator that we connected to a stale node and should switch to some other node. + However, we need to reset the memorized value when we notice that Cluster ID changed. + """ + if not (cluster_id and value): + return + + if self._cluster_id and self._cluster_id != cluster_id: + logger.warning('Etcd Cluster ID changed from %s to %s', self._cluster_id, cluster_id) + self._raft_term = 0 + self._cluster_id = cluster_id + + try: + raft_term = int(value) + except Exception: + return + + if raft_term < self._raft_term: + logger.warning('Connected to Etcd node with term %d. Old known term %d. Switching to another node.', + raft_term, self._raft_term) + raise StaleEtcdNode + self._raft_term = raft_term + def _calculate_timeouts(self, etcd_nodes: int, timeout: Optional[float] = None) -> Tuple[int, float, int]: """Calculate a request timeout and number of retries per single etcd node. In case if the timeout per node is too small (less than one second) we will reduce the number of nodes. @@ -227,7 +259,7 @@ class AbstractEtcdClientWithFailover(abc.ABC, etcd.Client): def _do_http_request(self, retry: Optional[Retry], machines_cache: List[str], request_executor: Callable[..., urllib3.response.HTTPResponse], method: str, path: str, fields: Optional[Dict[str, Any]] = None, - **kwargs: Any) -> urllib3.response.HTTPResponse: + **kwargs: Any) -> Any: is_watch_request = isinstance(fields, dict) and fields.get('wait') == 'true' if fields is not None: kwargs['fields'] = fields @@ -241,8 +273,8 @@ class AbstractEtcdClientWithFailover(abc.ABC, etcd.Client): if some_request_failed: self.set_base_uri(base_uri) self._refresh_machines_cache() - return response - except (HTTPError, HTTPException, socket.error, socket.timeout) as e: + return self._handle_server_response(response) + except (HTTPError, HTTPException, socket.error, socket.timeout, StaleEtcdNode) as e: self.http.clear() if not retry: if len(machines_cache) == 1: @@ -285,8 +317,7 @@ class AbstractEtcdClientWithFailover(abc.ABC, etcd.Client): while True: try: - response = self._do_http_request(retry, machines_cache, request_executor, method, path, **kwargs) - return self._handle_server_response(response) + return self._do_http_request(retry, machines_cache, request_executor, method, path, **kwargs) except etcd.EtcdWatchTimedOut: raise except etcd.EtcdConnectionFailed as ex: @@ -463,6 +494,10 @@ class EtcdClient(AbstractEtcdClientWithFailover): def _prepare_get_members(self, etcd_nodes: int) -> Dict[str, Any]: return self._prepare_common_parameters(etcd_nodes) + def _handle_server_response(self, response: urllib3.response.HTTPResponse) -> Any: + self._check_cluster_raft_term(response.headers.get('x-etcd-cluster-id'), response.headers.get('x-raft-term')) + return super(EtcdClient, self)._handle_server_response(response) + def _get_members(self, base_uri: str, **kwargs: Any) -> List[str]: response = self.http.request(self._MGET, base_uri + self.version_prefix + '/machines', **kwargs) data = self._handle_server_response(response).data.decode('utf-8') diff --git a/patroni/dcs/etcd3.py b/patroni/dcs/etcd3.py index 4688cc21..9f790fa9 100644 --- a/patroni/dcs/etcd3.py +++ b/patroni/dcs/etcd3.py @@ -18,6 +18,7 @@ import urllib3 from urllib3.exceptions import ProtocolError, ReadTimeoutError +from ..collections import EMPTY_DICT from ..exceptions import DCSError, PatroniException from ..postgresql.mpp import AbstractMPP from ..utils import deep_compare, enable_keepalive, iter_response_objects, RetryFailedError, USER_AGENT @@ -240,6 +241,10 @@ class Etcd3Client(AbstractEtcdClientWithFailover): try: data = data.decode('utf-8') ret: Dict[str, Any] = json.loads(data) + + header = ret.get('header', EMPTY_DICT) + self._check_cluster_raft_term(header.get('cluster_id'), header.get('raft_term')) + if response.status < 400: return ret except (TypeError, ValueError, UnicodeError) as e: diff --git a/tests/test_etcd.py b/tests/test_etcd.py index eb88b686..d897ec8e 100644 --- a/tests/test_etcd.py +++ b/tests/test_etcd.py @@ -117,6 +117,12 @@ def http_request(method, url, **kwargs): ret.content = 'http://localhost:2379,http://localhost:4001' elif url == 'http://localhost:4001/v2/machines': ret.content = '' + elif url == 'http://localhost:4001/term/': + ret.headers['x-etcd-cluster-id'] = 'a' + ret.headers['x-raft-term'] = '1' + elif url == 'http://localhost:2379/term/': + ret.headers['x-etcd-cluster-id'] = 'b' + ret.headers['x-raft-term'] = 'x' elif url != 'http://localhost:2379/': raise socket.error return ret @@ -165,6 +171,23 @@ class TestClient(unittest.TestCase): except Exception: self.assertIsNone(machines) + def test__check_cluster_raft_term(self): + self.client._raft_term = 2 + self.client._base_uri = 'http://localhost:4001/term' + self.client._machines_cache = [self.client._base_uri, 'http://localhost:2379/term'] + rtry = Retry(deadline=10, max_delay=1, max_tries=-1, retry_exceptions=(etcd.EtcdLeaderElectionInProgress,)) + with patch('patroni.dcs.etcd.logger.warning') as mock_logger: + rtry(self.client.api_execute, '/', 'POST', timeout=0, params={'retry': rtry}) + self.assertEqual(mock_logger.call_args_list[0][0], + ('Connected to Etcd node with term %d. Old known term %d. Switching to another node.', + 1, 2)) + self.assertEqual(mock_logger.call_args_list[1][0], ('Etcd Cluster ID changed from %s to %s', 'a', 'b')) + self.client._base_uri = self.client._machines_cache[0] + with patch('patroni.dcs.etcd.logger.warning') as mock_logger: + rtry(self.client.api_execute, '/', 'POST', timeout=0, params={'retry': rtry}) + self.assertEqual(mock_logger.call_args[0], ('Etcd Cluster ID changed from %s to %s', 'b', 'a')) + + @patch('time.sleep', Mock()) @patch.object(EtcdClient, '_get_machines_list', Mock(return_value=['http://localhost:4001', 'http://localhost:2379'])) def test_api_execute(self):