diff --git a/patroni/dcs/zookeeper.py b/patroni/dcs/zookeeper.py index dae08ffb..be6fe45a 100644 --- a/patroni/dcs/zookeeper.py +++ b/patroni/dcs/zookeeper.py @@ -2,6 +2,7 @@ import logging from kazoo.client import KazooClient, KazooState from kazoo.exceptions import NoNodeError, NodeExistsError +from kazoo.handlers.threading import SequentialThreadingHandler from patroni.dcs import AbstractDCS, ClusterConfig, Cluster, Failover, Leader, Member from patroni.exceptions import DCSError @@ -12,6 +13,34 @@ class ZooKeeperError(DCSError): pass +class PatroniSequentialThreadingHandler(SequentialThreadingHandler): + + def __init__(self, connect_timeout): + super(PatroniSequentialThreadingHandler, self).__init__() + self.set_connect_timeout(connect_timeout) + + def set_connect_timeout(self, connect_timeout): + self._connect_timeout = max(1.0, connect_timeout/4.0) + + def create_connection(self, *args, **kwargs): + """This method is trying to establish connection with one of the zookeeper nodes. + Somehow strategy "fail earlier and retry more often" works way better compared to + the original strategy "try to connect with specified timeout". + Since we want to try to connect to zookeeper more often (with the smaller connect_timeout), + we have to override `create_connection` method in the `SequentialThreadingHandler` + class (which is used by `kazoo.client.KazooClient`). 
+ + :param args: always contains `tuple(host, port)` as the first element and could contain + `connect_timeout` (negotiated session timeout) as the second element.""" + + args = list(args) + if len(args) == 1: + args.append(self._connect_timeout) + else: + args[1] = max(self._connect_timeout, args[1]/10.0) + return super(PatroniSequentialThreadingHandler, self).create_connection(*args, **kwargs) + + class ZooKeeper(AbstractDCS): def __init__(self, config): @@ -21,7 +50,8 @@ class ZooKeeper(AbstractDCS): if isinstance(hosts, list): hosts = ','.join(hosts) - self._client = KazooClient(hosts, timeout=config['ttl'], connection_retry={'max_delay': 1, 'max_tries': -1}, + self._client = KazooClient(hosts, handler=PatroniSequentialThreadingHandler(config['retry_timeout']), + timeout=config['ttl'], connection_retry={'max_delay': 1, 'max_tries': -1}, command_retry={'deadline': config['retry_timeout'], 'max_delay': 1, 'max_tries': -1}) self._client.add_listener(self.session_listener) @@ -47,6 +77,7 @@ class ZooKeeper(AbstractDCS): self._client.restart() def set_retry_timeout(self, retry_timeout): + self._client.handler.set_connect_timeout(retry_timeout) self._client._retry.deadline = retry_timeout def get_node(self, key, watch=None): diff --git a/tests/test_zookeeper.py b/tests/test_zookeeper.py index 01555d8e..a406ec82 100644 --- a/tests/test_zookeeper.py +++ b/tests/test_zookeeper.py @@ -3,9 +3,10 @@ import unittest from kazoo.client import KazooState from kazoo.exceptions import NoNodeError, NodeExistsError +from kazoo.handlers.threading import SequentialThreadingHandler from kazoo.protocol.states import ZnodeStat from mock import Mock, patch -from patroni.dcs.zookeeper import Leader, ZooKeeper, ZooKeeperError +from patroni.dcs.zookeeper import Leader, PatroniSequentialThreadingHandler, ZooKeeper, ZooKeeperError class MockKazooClient(Mock): @@ -92,6 +93,17 @@ class MockKazooClient(Mock): raise NoNodeError +class TestPatroniSequentialThreadingHandler(unittest.TestCase): + + 
def setUp(self): + self.handler = PatroniSequentialThreadingHandler(10) + + @patch.object(SequentialThreadingHandler, 'create_connection', Mock()) + def test_create_connection(self): + self.assertIsNotNone(self.handler.create_connection(())) + self.assertIsNotNone(self.handler.create_connection((), 40)) + + class TestZooKeeper(unittest.TestCase): @patch('patroni.dcs.zookeeper.KazooClient', MockKazooClient)