From 76b3b99de2f2bfaa8ab2df9e47dbfc3749d14e84 Mon Sep 17 00:00:00 2001 From: Alexander Kukushkin Date: Tue, 9 May 2023 09:38:00 +0200 Subject: [PATCH] Enable pyright strict mode (#2652) - added pyrightconfig.json with typeCheckingMode=strict - added type hints to all files except api.py - added type stubs for dns, etcd, consul, kazoo, pysyncobj and other modules - added type stubs for psycopg2 and urllib3 with some little fixes - fixes most of the issues reported by pyright - remaining issues will be addressed later, along with enabling CI linting task --- patroni/__init__.py | 15 +- patroni/__main__.py | 43 ++- patroni/async_executor.py | 44 ++- patroni/collections.py | 15 +- patroni/config.py | 70 ++-- patroni/ctl.py | 305 ++++++++------- patroni/daemon.py | 13 +- patroni/dcs/__init__.py | 423 +++++++++++--------- patroni/dcs/consul.py | 174 +++++---- patroni/dcs/etcd.py | 263 +++++++------ patroni/dcs/etcd3.py | 406 +++++++++++--------- patroni/dcs/exhibitor.py | 29 +- patroni/dcs/kubernetes.py | 432 +++++++++++---------- patroni/dcs/raft.py | 155 ++++---- patroni/dcs/zookeeper.py | 182 ++++----- patroni/exceptions.py | 7 +- patroni/ha.py | 361 +++++++++--------- patroni/log.py | 8 +- patroni/postgresql/__init__.py | 305 ++++++++------- patroni/postgresql/bootstrap.py | 47 ++- patroni/postgresql/callback_executor.py | 16 +- patroni/postgresql/cancellable.py | 28 +- patroni/postgresql/citus.py | 114 +++--- patroni/postgresql/config.py | 289 +++++++------- patroni/postgresql/connection.py | 19 +- patroni/postgresql/misc.py | 16 +- patroni/postgresql/postmaster.py | 38 +- patroni/postgresql/rewind.py | 153 ++++---- patroni/postgresql/slots.py | 134 ++++--- patroni/postgresql/sync.py | 15 +- patroni/postgresql/validator.py | 94 +++-- patroni/psycopg.py | 22 +- patroni/raft_controller.py | 15 +- patroni/request.py | 4 +- patroni/scripts/aws.py | 22 +- patroni/scripts/wale_restore.py | 66 ++-- patroni/utils.py | 117 +++--- patroni/validator.py | 119 ++++-- patroni/watchdog/base.py | 82 ++-- patroni/watchdog/linux.py | 64 ++-- pyrightconfig.json | 27 ++ tests/__init__.py | 9 +- tests/test_citus.py | 2 +- tests/test_config.py | 5 +- tests/test_consul.py | 2 + tests/test_ctl.py | 6 +- tests/test_etcd.py | 4 +- tests/test_etcd3.py | 21 +- tests/test_ha.py | 49 +-- tests/test_kubernetes.py | 32 +- tests/test_postgresql.py | 8 +- tests/test_slots.py | 2 + tests/test_utils.py | 2 +- tests/test_wale_restore.py | 2 +- tests/test_watchdog.py | 5 + tests/test_zookeeper.py | 8 +- typings/botocore/__init__.pyi | 0 typings/botocore/exceptions.pyi | 1 + typings/botocore/utils.pyi | 11 + typings/cdiff/__init__.pyi | 5 + typings/consul/__init__.pyi | 2 + typings/consul/base.pyi | 24 ++ typings/dns/resolver.pyi | 17 + typings/etcd/__init__.pyi | 24 ++ typings/etcd/client.pyi | 29 ++ typings/kazoo/client.pyi | 34 ++ typings/kazoo/exceptions.pyi | 12 + typings/kazoo/handlers/threading.pyi | 13 + typings/kazoo/handlers/utils.pyi | 6 + typings/kazoo/protocol/connection.pyi | 6 + typings/kazoo/protocol/states.pyi | 29 ++ typings/kazoo/retry.pyi | 7 + typings/kazoo/security.pyi | 5 + typings/prettytable/__init__.pyi | 13 + typings/psycopg2/__init__.pyi | 52 +++ typings/psycopg2/_ipaddress.pyi | 9 + typings/psycopg2/_json.pyi | 26 ++ typings/psycopg2/_psycopg.pyi | 488 ++++++++++++++++++++++++ typings/psycopg2/_range.pyi | 62 +++ typings/psycopg2/errorcodes.pyi | 304 +++++++++++++++ typings/psycopg2/errors.pyi | 263 +++++++++++++ typings/psycopg2/extensions.pyi | 117 ++++++ typings/psycopg2/extras.pyi | 240 ++++++++++++ typings/psycopg2/pool.pyi | 24 ++ typings/psycopg2/sql.pyi | 50 +++ typings/psycopg2/tz.pyi | 26 ++ typings/pysyncobj/__init__.pyi | 2 + typings/pysyncobj/config.pyi | 13 + typings/pysyncobj/dns_resolver.pyi | 5 + typings/pysyncobj/node.pyi | 6 + typings/pysyncobj/syncobj.pyi | 24 ++ typings/pysyncobj/tcp_connection.pyi | 4 + typings/pysyncobj/transport.pyi | 10 + typings/pysyncobj/utility.pyi | 5 + typings/urllib3/__init__.pyi | 6 + typings/urllib3/_collections.pyi | 28 ++ typings/urllib3/connection.pyi | 2 + typings/urllib3/poolmanager.pyi | 9 + typings/urllib3/response.pyi | 14 + typings/urllib3/util/request.pyi | 9 + typings/urllib3/util/timeout.pyi | 4 + typings/ydiff/__init__.pyi | 5 + 102 files changed, 4803 insertions(+), 2150 deletions(-) create mode 100644 pyrightconfig.json create mode 100644 typings/botocore/__init__.pyi create mode 100644 typings/botocore/exceptions.pyi create mode 100644 typings/botocore/utils.pyi create mode 100644 typings/cdiff/__init__.pyi create mode 100644 typings/consul/__init__.pyi create mode 100644 typings/consul/base.pyi create mode 100644 typings/dns/resolver.pyi create mode 100644 typings/etcd/__init__.pyi create mode 100644 typings/etcd/client.pyi create mode 100644 typings/kazoo/client.pyi create mode 100644 typings/kazoo/exceptions.pyi create mode 100644 typings/kazoo/handlers/threading.pyi create mode 100644 typings/kazoo/handlers/utils.pyi create mode 100644 typings/kazoo/protocol/connection.pyi create mode 100644 typings/kazoo/protocol/states.pyi create mode 100644 typings/kazoo/retry.pyi create mode 100644 typings/kazoo/security.pyi create mode 100644 typings/prettytable/__init__.pyi create mode 100644 typings/psycopg2/__init__.pyi create mode 100644 typings/psycopg2/_ipaddress.pyi create mode 100644 typings/psycopg2/_json.pyi create mode 100644 typings/psycopg2/_psycopg.pyi create mode 100644 typings/psycopg2/_range.pyi create mode 100644 typings/psycopg2/errorcodes.pyi create mode 100644 typings/psycopg2/errors.pyi create mode 100644 typings/psycopg2/extensions.pyi create mode 100644 typings/psycopg2/extras.pyi create mode 100644 typings/psycopg2/pool.pyi create mode 100644 typings/psycopg2/sql.pyi create mode 100644 typings/psycopg2/tz.pyi create mode 100644 typings/pysyncobj/__init__.pyi create mode 100644 typings/pysyncobj/config.pyi create mode 100644 typings/pysyncobj/dns_resolver.pyi create mode 100644 typings/pysyncobj/node.pyi create mode 100644 typings/pysyncobj/syncobj.pyi create mode 100644 typings/pysyncobj/tcp_connection.pyi create mode 100644 typings/pysyncobj/transport.pyi create mode 100644 typings/pysyncobj/utility.pyi create mode 100644 typings/urllib3/__init__.pyi create mode 100644 typings/urllib3/_collections.pyi create mode 100644 typings/urllib3/connection.pyi create mode 100644 typings/urllib3/poolmanager.pyi create mode 100644 typings/urllib3/response.pyi create mode 100644 typings/urllib3/util/request.pyi create mode 100644 typings/urllib3/util/timeout.pyi create mode 100644 typings/ydiff/__init__.pyi diff --git a/patroni/__init__.py b/patroni/__init__.py index f6a101cc..8b91c26c 100644 --- a/patroni/__init__.py +++ b/patroni/__init__.py @@ -1,17 +1,19 @@ import sys +from typing import Any, Callable, Iterator, Tuple + PATRONI_ENV_PREFIX = 'PATRONI_' KUBERNETES_ENV_PREFIX = 'KUBERNETES_' MIN_PSYCOPG2 = (2, 5, 4) -def fatal(string, *args): +def fatal(string: str, *args: Any) -> None: sys.stderr.write('FATAL: ' + string.format(*args) + '\n') sys.exit(1) -def parse_version(version): - def _parse_version(version): +def parse_version(version: str) -> Tuple[int, ...]: + def _parse_version(version: str) -> Iterator[int]: for e in version.split('.'): try: yield int(e) @@ -21,9 +23,11 @@ def parse_version(version): # We pass MIN_PSYCOPG2 and parse_version as arguments to simplify usage of check_psycopg from the setup.py -def check_psycopg(_min_psycopg2=MIN_PSYCOPG2, _parse_version=parse_version): +def check_psycopg(_min_psycopg2: Tuple[int, ...] = MIN_PSYCOPG2, + _parse_version: Callable[[str], Tuple[int, ...]] = parse_version) -> None: min_psycopg2_str = '.'.join(map(str, _min_psycopg2)) + # try psycopg2 try: from psycopg2 import __version__ if _parse_version(__version__) >= _min_psycopg2: @@ -32,10 +36,11 @@ def check_psycopg(_min_psycopg2=MIN_PSYCOPG2, _parse_version=parse_version): except ImportError: version_str = None + # try psycopg3 try: from psycopg import __version__ except ImportError: error = 'Patroni requires psycopg2>={0}, psycopg2-binary, or psycopg>=3.0'.format(min_psycopg2_str) - if version_str: + if version_str is not None: error += ', but only psycopg2=={0} is available'.format(version_str) fatal(error) diff --git a/patroni/__main__.py b/patroni/__main__.py index 4aa21424..9c7f5cea 100644 --- a/patroni/__main__.py +++ b/patroni/__main__.py @@ -3,14 +3,19 @@ import os import signal import time +from typing import Any, Dict, Optional, TYPE_CHECKING + from patroni.daemon import AbstractPatroniDaemon, abstract_main +if TYPE_CHECKING: # pragma: no cover + from .config import Config + logger = logging.getLogger(__name__) class Patroni(AbstractPatroniDaemon): - def __init__(self, config): + def __init__(self, config: 'Config') -> None: from patroni.api import RestApiServer from patroni.dcs import get_dcs from patroni.ha import Ha @@ -33,9 +38,9 @@ class Patroni(AbstractPatroniDaemon): self.tags = self.get_tags() self.next_run = time.time() - self.scheduled_restart = {} + self.scheduled_restart: Dict[str, Any] = {} - def load_dynamic_configuration(self): + def load_dynamic_configuration(self) -> None: from patroni.exceptions import DCSError while True: try: @@ -53,19 +58,19 @@ class Patroni(AbstractPatroniDaemon): logger.warning('Can not get cluster from dcs') time.sleep(5) - def get_tags(self): + def get_tags(self) -> Dict[str, Any]: return {tag: value for tag, value in self.config.get('tags', {}).items() if tag not in ('clonefrom', 'nofailover', 'noloadbalance', 'nosync') or value} @property - def nofailover(self): + def nofailover(self) -> bool: return bool(self.tags.get('nofailover', False)) @property - def nosync(self): + def nosync(self) -> bool: return bool(self.tags.get('nosync', False)) - def reload_config(self, sighup=False, local=False): + def reload_config(self, sighup: bool = False, local: Optional[bool] = False) -> None: try: super(Patroni, self).reload_config(sighup, local) if local: @@ -87,7 +92,7 @@ class Patroni(AbstractPatroniDaemon): def noloadbalance(self): return bool(self.tags.get('noloadbalance', False)) - def schedule_next_run(self): + def schedule_next_run(self) -> None: self.next_run += self.dcs.loop_wait current_time = time.time() nap_time = self.next_run - current_time @@ -100,12 +105,12 @@ class Patroni(AbstractPatroniDaemon): elif self.ha.watch(nap_time): self.next_run = time.time() - def run(self): + def run(self) -> None: self.api.start() self.next_run = time.time() super(Patroni, self).run() - def _run_cycle(self): + def _run_cycle(self) -> None: logger.info(self.ha.run_cycle()) if self.dcs.cluster and self.dcs.cluster.config and self.dcs.cluster.config.data \ @@ -117,7 +122,7 @@ class Patroni(AbstractPatroniDaemon): self.schedule_next_run() - def _shutdown(self): + def _shutdown(self) -> None: try: self.api.shutdown() except Exception: @@ -128,7 +133,7 @@ class Patroni(AbstractPatroniDaemon): logger.exception('Exception during Ha.shutdown') -def patroni_main(): +def patroni_main() -> None: from multiprocessing import freeze_support from patroni.validator import schema @@ -136,18 +141,20 @@ def patroni_main(): abstract_main(Patroni, schema) -def main(): - if os.getpid() != 1: - from patroni import check_psycopg +def main() -> None: + from patroni import check_psycopg - check_psycopg() + check_psycopg() + + if os.getpid() != 1: return patroni_main() # Patroni started with PID=1, it looks like we are in the container + from types import FrameType pid = 0 # Looks like we are in a docker, so we will act like init - def sigchld_handler(signo, stack_frame): + def sigchld_handler(signo: int, stack_frame: Optional[FrameType]) -> None: try: while True: ret = os.waitpid(-1, os.WNOHANG) @@ -158,7 +165,7 @@ def main(): except OSError: pass - def passtochild(signo, stack_frame): + def passtochild(signo: int, stack_frame: Optional[FrameType]): if pid: os.kill(pid, signo) diff --git a/patroni/async_executor.py b/patroni/async_executor.py index 46194fed..39f11013 100644 --- a/patroni/async_executor.py +++ b/patroni/async_executor.py @@ -1,5 +1,10 @@ import logging + from threading import Event, Lock, RLock, Thread +from types import TracebackType +from typing import Any, Callable, Optional, Tuple, Type + +from .postgresql.cancellable import CancellableSubprocess logger = logging.getLogger(__name__) @@ -15,19 +20,19 @@ class CriticalTask(object): call cancel. If the task has completed `cancel()` will return False and `result` field will contain the result of the task. When cancel returns True it is guaranteed that the background task will notice the `is_cancelled` flag. """ - def __init__(self): + def __init__(self) -> None: self._lock = Lock() self.is_cancelled = False self.result = None - def reset(self): + def reset(self) -> None: """Must be called every time the background task is finished. Must be called from async thread. Caller must hold lock on async executor when calling.""" self.is_cancelled = False self.result = None - def cancel(self): + def cancel(self) -> bool: """Tries to cancel the task, returns True if the task has already run. Caller must hold lock on async executor and the task when calling.""" @@ -36,37 +41,38 @@ class CriticalTask(object): self.is_cancelled = True return True - def complete(self, result): + def complete(self, result: Any) -> None: """Mark task as completed along with a result. Must be called from async thread. Caller must hold lock on task when calling.""" self.result = result - def __enter__(self): + def __enter__(self) -> 'CriticalTask': self._lock.acquire() return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]) -> None: self._lock.release() class AsyncExecutor(object): - def __init__(self, cancellable, ha_wakeup): + def __init__(self, cancellable: CancellableSubprocess, ha_wakeup: Callable[..., None]) -> None: self._cancellable = cancellable self._ha_wakeup = ha_wakeup self._thread_lock = RLock() - self._scheduled_action = None + self._scheduled_action: Optional[str] = None self._scheduled_action_lock = RLock() self._is_cancelled = False self._finish_event = Event() self.critical_task = CriticalTask() @property - def busy(self): + def busy(self) -> bool: return self.scheduled_action is not None - def schedule(self, action): + def schedule(self, action: str) -> Optional[str]: with self._scheduled_action_lock: if self._scheduled_action is not None: return self._scheduled_action @@ -76,15 +82,15 @@ class AsyncExecutor(object): return None @property - def scheduled_action(self): + def scheduled_action(self) -> Optional[str]: with self._scheduled_action_lock: return self._scheduled_action - def reset_scheduled_action(self): + def reset_scheduled_action(self) -> None: with self._scheduled_action_lock: self._scheduled_action = None - def run(self, func, args=()): + def run(self, func: Callable[..., Any], args: Tuple[Any, ...] = ()) -> Optional[bool]: wakeup = False try: with self: @@ -107,16 +113,16 @@ class AsyncExecutor(object): if wakeup is not None: self._ha_wakeup() - def run_async(self, func, args=()): + def run_async(self, func: Callable[..., Any], args: Tuple[Any, ...] = ()) -> None: Thread(target=self.run, args=(func, args)).start() - def try_run_async(self, action, func, args=()): + def try_run_async(self, action: str, func: Callable[..., Any], args: Tuple[Any, ...] = ()) -> Optional[str]: prev = self.schedule(action) if prev is None: return self.run_async(func, args) return 'Failed to run {0}, {1} is already in progress'.format(action, prev) - def cancel(self): + def cancel(self) -> None: with self: with self._scheduled_action_lock: if self._scheduled_action is None: @@ -130,8 +136,10 @@ class AsyncExecutor(object): with self: self.reset_scheduled_action() - def __enter__(self): + def __enter__(self) -> 'AsyncExecutor': self._thread_lock.acquire() + return self - def __exit__(self, *args): + def __exit__(self, exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]) -> None: self._thread_lock.release() diff --git a/patroni/collections.py b/patroni/collections.py index f16a6f86..6560c35d 100644 --- a/patroni/collections.py +++ b/patroni/collections.py @@ -1,16 +1,15 @@ from collections import OrderedDict -from collections.abc import MutableMapping, MutableSet -from typing import Any, Collection, Dict, Iterable, Iterator, Optional, Tuple, Union +from typing import Any, Collection, Dict, Iterator, MutableMapping, MutableSet, Optional -class CaseInsensitiveSet(MutableSet): +class CaseInsensitiveSet(MutableSet[str]): """A case-insensitive ``set``-like object. Implements all methods and operations of :class:``MutableSet``. All values are expected to be strings. The structure remembers the case of the last value set, however, contains testing is case insensitive. """ def __init__(self, values: Optional[Collection[str]] = None) -> None: - self._values = {} + self._values: Dict[str, str] = {} for v in values or (): self.add(v) @@ -39,7 +38,7 @@ class CaseInsensitiveSet(MutableSet): return self <= other -class CaseInsensitiveDict(MutableMapping): +class CaseInsensitiveDict(MutableMapping[str, Any]): """A case-insensitive ``dict``-like object. Implements all methods and operations of :class:``MutableMapping`` as well as dict's :func:``copy``. @@ -47,8 +46,8 @@ class CaseInsensitiveDict(MutableMapping): and ``iter(instance)``, ``keys()``, ``items()``, ``iterkeys()``, and ``iteritems()`` will contain case-sensitive keys. However, querying and contains testing is case insensitive. """ - def __init__(self, data: Optional[Union[Dict[str, Any], Iterable[Tuple[str, Any]]]] = None) -> None: - self._values = OrderedDict() + def __init__(self, data: Optional[Dict[str, Any]] = None) -> None: + self._values: OrderedDict[str, Any] = OrderedDict() self.update(data or {}) def __setitem__(self, key: str, value: Any) -> None: @@ -68,7 +67,7 @@ class CaseInsensitiveDict(MutableMapping): return len(self._values) def copy(self) -> 'CaseInsensitiveDict': - return CaseInsensitiveDict(self._values.values()) + return CaseInsensitiveDict({v[0]: v[1] for v in self._values.values()}) def __repr__(self) -> str: return '<{0}{1} at {2:x}>'.format(type(self).__name__, dict(self.items()), id(self)) diff --git a/patroni/config.py b/patroni/config.py index 44737772..4d3eea8e 100644 --- a/patroni/config.py +++ b/patroni/config.py @@ -7,7 +7,7 @@ import yaml from collections import defaultdict from copy import deepcopy -from typing import Any, Dict, Optional, Union +from typing import Any, Callable, Collection, Dict, List, Optional, Union from . import PATRONI_ENV_PREFIX from .collections import CaseInsensitiveDict @@ -33,9 +33,10 @@ _AUTH_ALLOWED_PARAMETERS = ( ) -def default_validator(conf): +def default_validator(conf: Dict[str, Any]) -> List[str]: if not conf: raise ConfigParseError("Config is empty.") + return [] class GlobalConfig(object): @@ -86,7 +87,7 @@ class GlobalConfig(object): """:returns: `True` if at least one synchronous node is required.""" return self.check_mode('synchronous_mode_strict') - def get_standby_cluster_config(self) -> Any: + def get_standby_cluster_config(self) -> Union[Dict[str, Any], Any]: """:returns: "standby_cluster" configuration.""" return deepcopy(self.get('standby_cluster')) @@ -142,7 +143,7 @@ class GlobalConfig(object): if 'primary_stop_timeout' in self.__config else self.get_int('master_stop_timeout', default) -def get_global_config(cluster: Union[Cluster, None], default: Optional[Dict] = None) -> GlobalConfig: +def get_global_config(cluster: Union[Cluster, None], default: Optional[Dict[str, Any]] = None) -> GlobalConfig: """Instantiates :class:`GlobalConfig` based on the input. :param cluster: the currently known cluster state from DCS @@ -180,7 +181,7 @@ class Config(object): PATRONI_CONFIG_VARIABLE = PATRONI_ENV_PREFIX + 'CONFIGURATION' __CACHE_FILENAME = 'patroni.dynamic.json' - __DEFAULT_CONFIG = { + __DEFAULT_CONFIG: Dict[str, Any] = { 'ttl': 30, 'loop_wait': 10, 'retry_timeout': 10, 'standby_cluster': { 'create_replica_methods': '', @@ -199,19 +200,21 @@ class Config(object): } } - def __init__(self, configfile, validator=default_validator): + def __init__(self, configfile: str, + validator: Optional[Callable[[Dict[str, Any]], List[str]]] = default_validator) -> None: self._modify_index = -1 self._dynamic_configuration = {} self.__environment_configuration = self._build_environment_configuration() # Patroni reads the configuration from the command-line argument if it exists, otherwise from the environment - self._config_file = configfile and os.path.exists(configfile) and configfile + self._config_file = configfile if configfile and os.path.exists(configfile) else None if self._config_file: self._local_configuration = self._load_config_file() else: config_env = os.environ.pop(self.PATRONI_CONFIG_VARIABLE, None) self._local_configuration = config_env and yaml.safe_load(config_env) or self.__environment_configuration + if validator: errors = validator(self._local_configuration) if errors: @@ -224,14 +227,14 @@ class Config(object): self._cache_needs_saving = False @property - def config_file(self): + def config_file(self) -> Union[str, None]: return self._config_file @property - def dynamic_configuration(self): + def dynamic_configuration(self) -> Dict[str, Any]: return deepcopy(self._dynamic_configuration) - def _load_config_path(self, path): + def _load_config_path(self, path: str) -> Dict[str, Any]: """ If path is a file, loads the yml file pointed to by path. If path is a directory, loads all yml files in that directory in alphabetical order @@ -245,20 +248,21 @@ class Config(object): logger.error('config path %s is neither directory nor file', path) raise ConfigParseError('invalid config path') - overall_config = {} + overall_config: Dict[str, Any] = {} for fname in files: with open(fname) as f: config = yaml.safe_load(f) patch_config(overall_config, config) return overall_config - def _load_config_file(self): + def _load_config_file(self) -> Dict[str, Any]: """Loads config.yaml from filesystem and applies some values which were set via ENV""" + assert self._config_file is not None config = self._load_config_path(self._config_file) patch_config(config, self.__environment_configuration) return config - def _load_cache(self): + def _load_cache(self) -> None: if os.path.isfile(self._cache_file): try: with open(self._cache_file) as f: @@ -266,7 +270,7 @@ class Config(object): except Exception: logger.exception('Exception when loading file: %s', self._cache_file) - def save_cache(self): + def save_cache(self) -> None: if self._cache_needs_saving: tmpfile = fd = None try: @@ -290,7 +294,7 @@ class Config(object): logger.error('Can not remove temporary file %s', tmpfile) # configuration could be either ClusterConfig or dict - def set_dynamic_configuration(self, configuration): + def set_dynamic_configuration(self, configuration: Union[ClusterConfig, Dict[str, Any]]) -> bool: if isinstance(configuration, ClusterConfig): if self._modify_index == configuration.modify_index: return False # If the index didn't changed there is nothing to do @@ -306,8 +310,9 @@ class Config(object): return True except Exception: logger.exception('Exception when setting dynamic_configuration') + return False - def reload_local_configuration(self): + def reload_local_configuration(self) -> Optional[bool]: if self.config_file: try: configuration = self._load_config_file() @@ -322,12 +327,12 @@ class Config(object): logger.exception('Exception when reloading local configuration from %s', self.config_file) @staticmethod - def _process_postgresql_parameters(parameters, is_local=False): + def _process_postgresql_parameters(parameters: Dict[str, Any], is_local: bool = False) -> Dict[str, Any]: return {name: value for name, value in (parameters or {}).items() if name not in ConfigHandler.CMDLINE_OPTIONS or not is_local and ConfigHandler.CMDLINE_OPTIONS[name][1](value)} - def _safe_copy_dynamic_configuration(self, dynamic_configuration): + def _safe_copy_dynamic_configuration(self, dynamic_configuration: Dict[str, Any]) -> Dict[str, Any]: config = deepcopy(self.__DEFAULT_CONFIG) for name, value in dynamic_configuration.items(): @@ -347,10 +352,10 @@ class Config(object): return config @staticmethod - def _build_environment_configuration(): - ret = defaultdict(dict) + def _build_environment_configuration() -> Dict[str, Any]: + ret: Dict[str, Any] = defaultdict(dict) - def _popenv(name): + def _popenv(name: str) -> Union[str, None]: return os.environ.pop(PATRONI_ENV_PREFIX + name.upper(), None) for param in ('name', 'namespace', 'scope'): @@ -358,7 +363,7 @@ class Config(object): if value: ret[param] = value - def _fix_log_env(name, oldname): + def _fix_log_env(name: str, oldname: str) -> None: value = _popenv(oldname) name = PATRONI_ENV_PREFIX + 'LOG_' + name.upper() if value and name not in os.environ: @@ -367,7 +372,7 @@ class Config(object): for name, oldname in (('level', 'loglevel'), ('format', 'logformat'), ('dateformat', 'log_datefmt')): _fix_log_env(name, oldname) - def _set_section_values(section, params): + def _set_section_values(section: str, params: List[str]) -> None: for param in params: value = _popenv(section + '_' + param) if value: @@ -400,7 +405,7 @@ class Config(object): if value is not None: ret[first][second] = value - def _parse_list(value): + def _parse_list(value: str) -> Union[List[str], None]: if not (value.strip().startswith('-') or '[' in value): value = '[{0}]'.format(value) try: @@ -416,7 +421,7 @@ class Config(object): if value: ret[first][second] = value - def _parse_dict(value): + def _parse_dict(value: str) -> Union[Dict[str, Any], None]: if not value.strip().startswith('{'): value = '{{{0}}}'.format(value) try: @@ -433,8 +438,8 @@ class Config(object): if value: ret[first][second] = value - def _get_auth(name, params=None): - ret = {} + def _get_auth(name: str, params: Optional[Collection[str]] = None) -> Dict[str, str]: + ret: Dict[str, str] = {} for param in params or _AUTH_ALLOWED_PARAMETERS[:2]: value = _popenv(name + '_' + param) if value: @@ -503,7 +508,8 @@ class Config(object): return ret - def _build_effective_configuration(self, dynamic_configuration, local_configuration): + def _build_effective_configuration(self, dynamic_configuration: Dict[str, Any], + local_configuration: Dict[str, Union[Dict[str, Any], Any]]) -> Dict[str, Any]: config = self._safe_copy_dynamic_configuration(dynamic_configuration) for name, value in local_configuration.items(): if name == 'citus': # remove invalid citus configuration @@ -565,16 +571,16 @@ class Config(object): return config - def get(self, key, default=None): + def get(self, key: str, default: Optional[Any] = None) -> Any: return self.__effective_configuration.get(key, default) - def __contains__(self, key): + def __contains__(self, key: str) -> bool: return key in self.__effective_configuration - def __getitem__(self, key): + def __getitem__(self, key: str) -> Any: return self.__effective_configuration[key] - def copy(self): + def copy(self) -> Dict[str, Any]: return deepcopy(self.__effective_configuration) def get_global_config(self, cluster: Union[Cluster, None]) -> GlobalConfig: diff --git a/patroni/ctl.py b/patroni/ctl.py index e78c0f27..786ea4ec 100644 --- a/patroni/ctl.py +++ b/patroni/ctl.py @@ -18,23 +18,25 @@ import shutil import subprocess import sys import tempfile +import urllib3 import time import yaml -from typing import Any, Dict, Union - -from click import ClickException from collections import defaultdict from contextlib import contextmanager from prettytable import ALL, FRAME, PrettyTable from urllib.parse import urlparse +from typing import Any, Dict, Generator, Iterator, List, Optional, Union, Tuple, TYPE_CHECKING +if TYPE_CHECKING: # pragma: no cover + from psycopg import Cursor + from psycopg2 import cursor try: from ydiff import markup_to_pager, PatchStream except ImportError: # pragma: no cover from cdiff import markup_to_pager, PatchStream -from .dcs import get_dcs as _get_dcs +from .dcs import get_dcs as _get_dcs, AbstractDCS, Cluster, Member from .exceptions import PatroniException from .postgresql.misc import postgres_version_to_int from .utils import cluster_as_json, patch_config, polling_loop @@ -43,30 +45,31 @@ from .version import __version__ CONFIG_DIR_PATH = click.get_app_dir('patroni') CONFIG_FILE_PATH = os.path.join(CONFIG_DIR_PATH, 'patronictl.yaml') -DCS_DEFAULTS = {'zookeeper': {'port': 2181, 'template': "zookeeper:\n hosts: ['{host}:{port}']"}, - 'exhibitor': {'port': 8181, 'template': "exhibitor:\n hosts: [{host}]\n port: {port}"}, - 'consul': {'port': 8500, 'template': "consul:\n host: '{host}:{port}'"}, - 'etcd': {'port': 2379, 'template': "etcd:\n host: '{host}:{port}'"}, - 'etcd3': {'port': 2379, 'template': "etcd3:\n host: '{host}:{port}'"}} +DCS_DEFAULTS: Dict[str, Dict[str, Any]] = { + 'zookeeper': {'port': 2181, 'template': "zookeeper:\n hosts: ['{host}:{port}']"}, + 'exhibitor': {'port': 8181, 'template': "exhibitor:\n hosts: [{host}]\n port: {port}"}, + 'consul': {'port': 8500, 'template': "consul:\n host: '{host}:{port}'"}, + 'etcd': {'port': 2379, 'template': "etcd:\n host: '{host}:{port}'"}, + 'etcd3': {'port': 2379, 'template': "etcd3:\n host: '{host}:{port}'"}} -class PatroniCtlException(ClickException): +class PatroniCtlException(click.ClickException): pass class PatronictlPrettyTable(PrettyTable): - def __init__(self, header, *args, **kwargs): - PrettyTable.__init__(self, *args, **kwargs) + def __init__(self, header: str, *args: Any, **kwargs: Any) -> None: + super(PatronictlPrettyTable, self).__init__(*args, **kwargs) self.__table_header = header self.__hline_num = 0 - self.__hline = None + self.__hline: str - def __build_header(self, line): + def __build_header(self, line: str) -> str: header = self.__table_header[:len(line) - 2] return "".join([line[0], header, line[1 + len(header):]]) - def _stringify_hrule(self, *args, **kwargs): + def _stringify_hrule(self, *args: Any, **kwargs: Any) -> str: ret = super(PatronictlPrettyTable, self)._stringify_hrule(*args, **kwargs) where = args[1] if len(args) > 1 else kwargs.get('where') if where == 'top_' and self.__table_header: @@ -74,13 +77,13 @@ class PatronictlPrettyTable(PrettyTable): self.__hline_num += 1 return ret - def _is_first_hline(self): + def _is_first_hline(self) -> bool: return self.__hline_num == 0 - def _set_hline(self, value): + def _set_hline(self, value: str) -> None: self.__hline = value - def _get_hline(self): + def _get_hline(self) -> str: ret = self.__hline # Inject nice table header @@ -93,7 +96,7 @@ class PatronictlPrettyTable(PrettyTable): _hrule = property(_get_hline, _set_hline) -def parse_dcs(dcs: str) -> Union[Dict[str, Any], None]: +def parse_dcs(dcs: Optional[str]) -> Optional[Dict[str, Any]]: """Parse a DCS URL. :param dcs: the DCS URL in the format ``DCS://HOST:PORT``. ``DCS`` can be one among @@ -143,7 +146,7 @@ def parse_dcs(dcs: str) -> Union[Dict[str, Any], None]: return yaml.safe_load(default['template'].format(host=parsed.hostname or 'localhost', port=port or default['port'])) -def load_config(path, dcs_url): +def load_config(path: str, dcs_url: Optional[str]) -> Dict[str, Any]: from patroni.config import Config if not (os.path.exists(path) and os.access(path, os.R_OK)): @@ -156,11 +159,11 @@ def load_config(path, dcs_url): logging.debug('Loading configuration from file %s', path) config = Config(path, validator=None).copy() - dcs_url = parse_dcs(dcs_url) or {} - if dcs_url: + dcs_kwargs = parse_dcs(dcs_url) or {} + if dcs_kwargs: for d in DCS_DEFAULTS: config.pop(d, None) - config.update(dcs_url) + config.update(dcs_kwargs) return config @@ -183,7 +186,7 @@ role_choice = click.Choice(['leader', 'primary', 'standby-leader', 'replica', 's @click.option('--dcs-url', '--dcs', '-d', 'dcs_url', help='The DCS connect url', envvar='DCS_URL') @option_insecure @click.pass_context -def ctl(ctx, config_file, dcs_url, insecure): +def ctl(ctx: click.Context, config_file: str, dcs_url: Optional[str], insecure: bool) -> None: level = 'WARNING' for name in ('LOGLEVEL', 'PATRONI_LOGLEVEL', 'PATRONI_LOG_LEVEL'): level = os.environ.get(name, level) @@ -194,7 +197,7 @@ def ctl(ctx, config_file, dcs_url, insecure): ctx.obj.setdefault('ctl', {})['insecure'] = ctx.obj.get('ctl', {}).get('insecure') or insecure -def get_dcs(config, scope, group): +def get_dcs(config: Dict[str, Any], scope: str, group: Optional[int]) -> AbstractDCS: config.update({'scope': scope, 'patronictl': True}) if group is not None: config['citus'] = {'group': group} @@ -202,13 +205,14 @@ def get_dcs(config, scope, group): try: dcs = _get_dcs(config) if config.get('citus') and group is None: - dcs.get_cluster = dcs._get_citus_cluster + dcs.is_citus_coordinator = lambda: True return dcs except PatroniException as e: raise PatroniCtlException(str(e)) -def request_patroni(member, method='GET', endpoint=None, data=None): +def request_patroni(member: Member, method: str = 'GET', + endpoint: Optional[str] = None, data: Optional[Any] = None) -> urllib3.response.HTTPResponse: ctx = click.get_current_context() # the current click context request_executor = ctx.obj.get('__request_patroni') if not request_executor: @@ -216,19 +220,20 @@ def request_patroni(member, method='GET', endpoint=None, data=None): return request_executor(member, method, endpoint, data) -def print_output(columns, rows, alignment=None, fmt='pretty', header=None, delimiter='\t'): +def print_output(columns: Optional[List[str]], rows: List[List[Any]], alignment: Optional[Dict[str, str]] = None, + fmt: str = 'pretty', header: str = '', delimiter: str = '\t') -> None: if fmt in {'json', 'yaml', 'yml'}: - elements = [{k: v for k, v in zip(columns, r) if not header or str(v)} for r in rows] + elements = [{k: v for k, v in zip(columns or [], r) if not header or str(v)} for r in rows] func = json.dumps if fmt == 'json' else format_config_for_editing click.echo(func(elements)) elif fmt in {'pretty', 'tsv', 'topology'}: list_cluster = bool(header and columns and columns[0] == 'Cluster') - if list_cluster and 'Tags' in columns: # we want to format member tags as YAML + if list_cluster and columns and 'Tags' in columns: # we want to format member tags as YAML i = columns.index('Tags') for row in rows: if row[i]: row[i] = format_config_for_editing(row[i], fmt != 'pretty').strip() - if list_cluster and fmt != 'tsv': # skip cluster name and maybe Citus group if pretty-printing + if list_cluster and header and fmt != 'tsv': # skip cluster name and maybe Citus group if pretty-printing skip_cols = 2 if ' (group: ' in header else 1 columns = columns[skip_cols:] if columns else [] rows = [row[skip_cols:] for row in rows] @@ -247,7 +252,7 @@ def print_output(columns, rows, alignment=None, fmt='pretty', header=None, delim click.echo(table) -def watching(w, watch, max_count=None, clear=True): +def watching(w: bool, watch: Optional[int], max_count: Optional[int] = None, clear: bool = True) -> Iterator[int]: """ >>> len(list(watching(True, 1, 0))) 1 @@ -275,7 +280,8 @@ def watching(w, watch, max_count=None, clear=True): yield 0 -def get_all_members(obj, cluster, group, role='leader'): +def get_all_members(obj: Dict[str, Any], cluster: Cluster, + group: Optional[int], role: str = 'leader') -> Iterator[Member]: clusters = {0: cluster} if obj.get('citus') and group is None: clusters.update(cluster.workers) @@ -296,23 +302,25 @@ def get_all_members(obj, cluster, group, role='leader'): yield m -def get_any_member(obj, cluster, group, role='leader', member=None): +def get_any_member(obj: Dict[str, Any], cluster: Cluster, group: Optional[int], + role: str = 'leader', member: Optional[str] = None) -> Optional[Member]: for m in get_all_members(obj, cluster, group, role): if member is None or m.name == member: return m -def get_all_members_leader_first(cluster): +def get_all_members_leader_first(cluster: Cluster) -> Iterator[Member]: leader_name = cluster.leader.member.name if cluster.leader and cluster.leader.member.api_url else None - if leader_name: + if leader_name and cluster.leader: yield cluster.leader.member for member in cluster.members: if member.api_url and member.name != leader_name: yield member -def get_cursor(obj, cluster, group, connect_parameters, role='leader', member=None): - member = get_any_member(obj, cluster, group, role=role, member=member) +def get_cursor(obj: Dict[str, Any], cluster: Cluster, group: Optional[int], connect_parameters: Dict[str, Any], + role: str = 'leader', member_name: Optional[str] = None) -> Union['cursor', 'Cursor[Any]', None]: + member = get_any_member(obj, cluster, group, role=role, member=member_name) if member is None: return None @@ -330,7 +338,8 @@ def get_cursor(obj, cluster, group, connect_parameters, role='leader', member=No return cursor cursor.execute('SELECT pg_catalog.pg_is_in_recovery()') - in_recovery = cursor.fetchone()[0] + row = cursor.fetchone() + in_recovery = not row or row[0] if in_recovery and role in ('replica', 'standby', 'standby-leader')\ or not in_recovery and role in ('master', 'primary'): @@ -341,7 +350,8 @@ def get_cursor(obj, cluster, group, connect_parameters, role='leader', member=No return None -def get_members(obj, cluster, cluster_name, member_names, role, force, action, ask_confirmation=True, group=None): +def get_members(obj: Dict[str, Any], cluster: Cluster, cluster_name: str, member_names: List[str], role: str, + force: bool, action: str, ask_confirmation: bool = True, group: Optional[int] = None) -> List[Member]: members = list(get_all_members(obj, cluster, group, role)) candidates = {m.name for m in members} @@ -371,7 +381,8 @@ def get_members(obj, cluster, cluster_name, member_names, role, force, action, a return members -def confirm_members_action(members, force, action, scheduled_at=None): +def confirm_members_action(members: List[Member], force: bool, action: str, + scheduled_at: Optional[datetime.datetime] = None) -> None: if scheduled_at: if not force: confirm = click.confirm('Are you sure you want to schedule {0} of members {1} at {2}?' @@ -392,12 +403,13 @@ def confirm_members_action(members, force, action, scheduled_at=None): @arg_cluster_name @option_citus_group @click.pass_obj -def dsn(obj, cluster_name, group, role, member): +def dsn(obj: Dict[str, Any], cluster_name: str, group: Optional[int], + role: Optional[str], member: Optional[str]) -> None: if member is not None: if role is not None: raise PatroniCtlException('--role and --member are mutually exclusive options') role = 'any' - if member is None and role is None: + elif role is None: role = 'leader' cluster = get_dcs(obj, cluster_name, group).get_cluster() @@ -425,33 +437,36 @@ def dsn(obj, cluster_name, group, role, member): @click.option('-d', '--dbname', help='database name to connect to', type=str) @click.pass_obj def query( - obj, - cluster_name, - group, - role, - member, - w, - watch, - delimiter, - command, - p_file, - password, - username, - dbname, - fmt='tsv', -): + obj: Dict[str, Any], + cluster_name: str, + group: Optional[int], + role: Optional[str], + member: Optional[str], + w: bool, + watch: Optional[int], + delimiter: str, + command: Optional[str], + p_file: Optional[io.BufferedReader], + password: Optional[bool], + username: Optional[str], + dbname: Optional[str], + fmt: str = 'tsv' +) -> None: if member is not None: if role is not None: raise PatroniCtlException('--role and --member are mutually exclusive options') role = 'any' - if member is None and role is None: + elif role is None: role = 'leader' - if p_file is not None and command is not None: - raise PatroniCtlException('--file and --command are mutually exclusive options') - - if p_file is None and command is None: - raise PatroniCtlException('You need to specify either --command or --file') + if p_file is not None: + if command is not None: + raise PatroniCtlException('--file and --command are mutually exclusive options') + sql = p_file.read().decode('utf-8') + else: + if command is None: + raise PatroniCtlException('You need to specify either --command or --file') + sql = command connect_parameters = {} if username: @@ -461,25 +476,25 @@ def query( if dbname: connect_parameters['dbname'] = dbname - if p_file is not None: - command = p_file.read() - dcs = get_dcs(obj, cluster_name, group) - cursor = None + cluster = cursor = None for _ in watching(w, watch, clear=False): - if cursor is None: + if cluster is None: cluster = dcs.get_cluster() +# cursor = get_cursor(obj, cluster, group, connect_parameters, role=role, member=member) - output, header = query_member(obj, cluster, group, cursor, member, role, command, connect_parameters) + output, header = query_member(obj, cluster, group, cursor, member, role, sql, connect_parameters) print_output(header, output, fmt=fmt, delimiter=delimiter) -def query_member(obj, cluster, group, cursor, member, role, command, connect_parameters): +def query_member(obj: Dict[str, Any], cluster: Cluster, group: Optional[int], + cursor: Union['cursor', 'Cursor[Any]', None], member: Optional[str], role: str, + command: str, connect_parameters: Dict[str, Any]) -> Tuple[List[List[Any]], Optional[List[Any]]]: from . import psycopg try: if cursor is None: - cursor = get_cursor(obj, cluster, group, connect_parameters, role=role, member=member) + cursor = get_cursor(obj, cluster, group, connect_parameters, role=role, member_name=member) if cursor is None: if member is not None: @@ -489,8 +504,8 @@ def query_member(obj, cluster, group, cursor, member, role, command, connect_par logging.debug(message) return [[timestamp(0), message]], None - cursor.execute(command) - return cursor.fetchall(), [d.name for d in cursor.description] + cursor.execute(command.encode('utf-8')) + return [list(row) for row in cursor], cursor.description and [d.name for d in cursor.description] except psycopg.DatabaseError as de: logging.debug(de) if cursor is not None and not cursor.connection.closed: @@ -505,7 +520,7 @@ def query_member(obj, cluster, group, cursor, member, role, command, connect_par @option_citus_group @option_format @click.pass_obj -def remove(obj, cluster_name, group, fmt): +def remove(obj: Dict[str, Any], cluster_name: str, group: Optional[int], fmt: str) -> None: dcs = get_dcs(obj, cluster_name, group) cluster = dcs.get_cluster() @@ -532,7 +547,8 @@ def remove(obj, cluster_name, group, fmt): dcs.delete_cluster() -def check_response(response, member_name, action_name, silent_success=False): +def check_response(response: urllib3.response.HTTPResponse, member_name: str, + action_name: str, silent_success: bool = False) -> bool: if response.status >= 400: click.echo('Failed: {0} for member {1}, status code={2}, ({3})'.format( action_name, member_name, response.status, response.data.decode('utf-8') @@ -543,8 +559,8 @@ def check_response(response, member_name, action_name, silent_success=False): return True -def parse_scheduled(scheduled): - if (scheduled or 'now') != 'now': +def parse_scheduled(scheduled: Optional[str]) -> Optional[datetime.datetime]: + if scheduled is not None and (scheduled or 'now') != 'now': try: scheduled_at = dateutil.parser.parse(scheduled) if scheduled_at.tzinfo is None: @@ -564,7 +580,8 @@ def parse_scheduled(scheduled): @click.option('--role', '-r', help='Reload only members with this role', type=role_choice, default='any') @option_force @click.pass_obj -def reload(obj, cluster_name, member_names, group, force, role): +def reload(obj: Dict[str, Any], cluster_name: str, member_names: List[str], + group: Optional[int], force: bool, role: str) -> None: dcs = get_dcs(obj, cluster_name, group) cluster = dcs.get_cluster() @@ -575,8 +592,10 @@ def reload(obj, cluster_name, member_names, group, force, role): if r.status == 200: click.echo('No changes to apply on member {0}'.format(member.name)) elif r.status == 202: + from patroni.config import get_global_config + config = get_global_config(cluster) click.echo('Reload request received for member {0} and will be processed within {1} seconds'.format( - member.name, cluster.config.data.get('loop_wait', dcs.loop_wait)) + member.name, config.get('loop_wait') or dcs.loop_wait) ) else: click.echo('Failed: reload for member {0}, status code={1}, ({2})'.format( @@ -595,11 +614,12 @@ def reload(obj, cluster_name, member_names, group, force, role): @click.option('--pg-version', 'version', help='Restart if the PostgreSQL version is less than provided (e.g. 9.5.2)', default=None) @click.option('--pending', help='Restart if pending', is_flag=True) -@click.option('--timeout', - help='Return error and fail over if necessary when restarting takes longer than this.') +@click.option('--timeout', help='Return error and fail over if necessary when restarting takes longer than this.') @option_force @click.pass_obj -def restart(obj, cluster_name, group, member_names, force, role, p_any, scheduled, version, pending, timeout): +def restart(obj: Dict[str, Any], cluster_name: str, group: Optional[int], member_names: List[str], + force: bool, role: str, p_any: bool, scheduled: Optional[str], version: Optional[str], + pending: bool, timeout: Optional[str]) -> None: cluster = get_dcs(obj, cluster_name, group).get_cluster() members = get_members(obj, cluster, cluster_name, member_names, role, force, 'restart', False, group=group) @@ -666,13 +686,14 @@ def restart(obj, cluster_name, group, member_names, force, role, p_any, schedule @option_force @click.option('--wait', help='Wait until reinitialization completes', is_flag=True) @click.pass_obj -def reinit(obj, cluster_name, group, member_names, force, wait): +def reinit(obj: Dict[str, Any], cluster_name: str, group: Optional[int], + member_names: List[str], force: bool, wait: bool) -> None: cluster = get_dcs(obj, cluster_name, group).get_cluster() members = get_members(obj, cluster, cluster_name, member_names, 'replica', force, 'reinitialize', group=group) - wait_on_members = [] + wait_on_members: List[Member] = [] for member in members: - body = {'force': force} + body: Dict[str, bool] = {'force': force} while True: r = request_patroni(member, 'post', 'reinitialize', body) started = check_response(r, member.name, 'reinitialize') @@ -699,7 +720,9 @@ def reinit(obj, cluster_name, group, member_names, force, wait): wait_on_members.remove(member) -def _do_failover_or_switchover(obj, action, cluster_name, group, leader, candidate, force, scheduled=None): +def _do_failover_or_switchover(obj: Dict[str, Any], action: str, cluster_name: str, + group: Optional[int], leader: Optional[str], candidate: Optional[str], + force: bool, scheduled: Optional[str] = None) -> None: """ We want to trigger a failover or switchover for the specified cluster name. @@ -729,7 +752,7 @@ def _do_failover_or_switchover(obj, action, cluster_name, group, leader, candida else: from patroni.config import get_global_config prompt = 'Standby Leader' if get_global_config(cluster).is_standby_cluster else 'Primary' - leader = click.prompt(prompt, type=str, default=cluster.leader.member.name) + leader = click.prompt(prompt, type=str, default=(cluster.leader and cluster.leader.member.name)) if leader is not None and cluster.leader and cluster.leader.member.name != leader: raise PatroniCtlException('Member {0} is not the leader of cluster {1}'.format(leader, cluster_name)) @@ -788,8 +811,8 @@ def _do_failover_or_switchover(obj, action, cluster_name, group, leader, candida r = None try: - member = cluster.leader.member if cluster.leader else cluster.get_member(candidate, False) - + member = cluster.leader.member if cluster.leader else candidate and cluster.get_member(candidate, False) + assert isinstance(member, Member) r = request_patroni(member, 'post', action, failover_value) # probably old patroni, which doesn't support switchover yet @@ -820,7 +843,8 @@ def _do_failover_or_switchover(obj, action, cluster_name, group, leader, candida @click.option('--candidate', help='The name of the candidate', default=None) @option_force @click.pass_obj -def failover(obj, cluster_name, group, leader, candidate, force): +def failover(obj: Dict[str, Any], cluster_name: str, group: Optional[int], + leader: Optional[str], candidate: Optional[str], force: bool) -> None: action = 'switchover' if leader else 'failover' _do_failover_or_switchover(obj, action, cluster_name, group, leader, candidate, force) @@ -834,11 +858,13 @@ def failover(obj, cluster_name, group, leader, candidate, force): default=None) @option_force @click.pass_obj -def switchover(obj, cluster_name, group, leader, candidate, force, scheduled): +def switchover(obj: Dict[str, Any], cluster_name: str, group: Optional[int], + leader: Optional[str], candidate: Optional[str], force: bool, scheduled: Optional[str]) -> None: _do_failover_or_switchover(obj, 'switchover', cluster_name, group, leader, candidate, force, scheduled) -def generate_topology(level, member, topology): +def generate_topology(level: int, member: Dict[str, Any], + topology: Dict[str, List[Dict[str, Any]]]) -> Iterator[Dict[str, Any]]: members = topology.get(member['name'], []) if level > 0: @@ -852,8 +878,8 @@ def generate_topology(level, member, topology): yield member -def topology_sort(members): - topology = defaultdict(list) +def topology_sort(members: List[Dict[str, Any]]) -> Iterator[Dict[str, Any]]: + topology: Dict[str, List[Dict[str, Any]]] = defaultdict(list) leader = next((m for m in members if m['role'].endswith('leader')), {'name': None}) replicas = set(member['name'] for member in members if not member['role'].endswith('leader')) for member in members: @@ -865,8 +891,8 @@ def topology_sort(members): yield member -def get_cluster_service_info(cluster): - service_info = [] +def get_cluster_service_info(cluster: Dict[str, Any]) -> List[str]: + service_info: List[str] = [] if cluster.get('pause'): service_info.append('Maintenance mode: on') @@ -879,8 +905,9 @@ def get_cluster_service_info(cluster): return service_info -def output_members(obj, cluster, name, extended=False, fmt='pretty', group=None): - rows = [] +def output_members(obj: Dict[str, Any], cluster: Cluster, name: str, + extended: bool = False, fmt: str = 'pretty', group: Optional[int] = None) -> None: + rows: List[List[Any]] = [] logging.debug(cluster) initialize = {None: 'uninitialized', '': 'initializing'}.get(cluster.initialize, cluster.initialize) @@ -905,12 +932,12 @@ def output_members(obj, cluster, name, extended=False, fmt='pretty', group=None) len(set(m['host'] for m in all_members)) < len(all_members) sort = topology_sort if fmt == 'topology' else iter - for g, cluster in sorted(clusters.items()): - for member in sort(cluster['members']): + for g, c in sorted(clusters.items()): + for member in sort(c['members']): logging.debug(member) lag = member.get('lag', '') - member.update(cluster=name, member=member['name'], group=g, + member.update(c=name, member=member['name'], group=g, host=member.get('host', ''), tl=member.get('timeline', ''), role=member['role'].replace('_', ' ').title(), lag_in_mb=round(lag / 1024 / 1024) if isinstance(lag, int) else lag, @@ -936,8 +963,8 @@ def output_members(obj, cluster, name, extended=False, fmt='pretty', group=None) if fmt not in ('pretty', 'topology'): # Omit service info when using machine-readable formats return - for g, cluster in sorted(clusters.items()): - service_info = get_cluster_service_info(cluster) + for g, c in sorted(clusters.items()): + service_info = get_cluster_service_info(c) if service_info: if is_citus_cluster and group is None: click.echo('Citus group: {0}'.format(g)) @@ -953,7 +980,8 @@ def output_members(obj, cluster, name, extended=False, fmt='pretty', group=None) @option_watch @option_watchrefresh @click.pass_obj -def members(obj, cluster_names, group, fmt, watch, w, extended, ts): +def members(obj: Dict[str, Any], cluster_names: List[str], group: Optional[int], + fmt: str, watch: Optional[int], w: bool, extended: bool, ts: bool) -> None: if not cluster_names: if 'scope' in obj: cluster_names = [obj['scope']] @@ -976,13 +1004,12 @@ def members(obj, cluster_names, group, fmt, watch, w, extended, ts): @option_citus_group @option_watch @option_watchrefresh -@click.pass_obj @click.pass_context -def topology(ctx, obj, cluster_names, group, watch, w): +def topology(ctx: click.Context, cluster_names: List[str], group: Optional[int], watch: Optional[int], w: bool) -> None: ctx.forward(members, fmt='topology') -def timestamp(precision=6): +def timestamp(precision: int = 6) -> str: return datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')[:precision - 7] @@ -994,7 +1021,8 @@ def timestamp(precision=6): @click.option('--role', '-r', help='Flush only members with this role', type=role_choice, default='any') @option_force @click.pass_obj -def flush(obj, cluster_name, group, member_names, force, role, target): +def flush(obj: Dict[str, Any], cluster_name: str, group: Optional[int], + member_names: List[str], force: bool, role: str, target: str) -> None: dcs = get_dcs(obj, cluster_name, group) cluster = dcs.get_cluster() @@ -1015,28 +1043,34 @@ def flush(obj, cluster_name, group, member_names, force, role, target): if r.status in (200, 404): prefix = 'Success' if r.status == 200 else 'Failed' return click.echo('{0}: {1}'.format(prefix, r.data.decode('utf-8'))) + + click.echo('Failed: member={0}, status_code={1}, ({2})'.format( + member.name, r.status, r.data.decode('utf-8'))) except Exception as err: logging.warning(str(err)) logging.warning('Member %s is not accessible', member.name) - click.echo('Failed: member={0}, status_code={1}, ({2})'.format( - member.name, r.status, r.data.decode('utf-8'))) - logging.warning('Failing over to DCS') click.echo('{0} Could not find any accessible member of cluster {1}'.format(timestamp(), cluster_name)) dcs.manual_failover('', '', index=failover.index) -def wait_until_pause_is_applied(dcs, paused, old_cluster): +def wait_until_pause_is_applied(dcs: AbstractDCS, paused: bool, old_cluster: Cluster) -> None: + from patroni.config import get_global_config + config = get_global_config(old_cluster) + click.echo("'{0}' request sent, waiting until it is recognized by all nodes".format(paused and 'pause' or 'resume')) old = {m.name: m.index for m in old_cluster.members if m.api_url} - loop_wait = old_cluster.config.data.get('loop_wait', dcs.loop_wait) + loop_wait = config.get('loop_wait') or dcs.loop_wait + cluster = None for _ in polling_loop(loop_wait + 1): cluster = dcs.get_cluster() if all(m.data.get('pause', False) == paused for m in cluster.members if m.name in old): break else: + if TYPE_CHECKING: # pragma: no cover + assert cluster is not None remaining = [m.name for m in cluster.members if m.data.get('pause', False) != paused and m.name in old and old[m.name] != m.index] if remaining: @@ -1045,7 +1079,7 @@ def wait_until_pause_is_applied(dcs, paused, old_cluster): return click.echo('Success: cluster management is {0}'.format(paused and 'paused' or 'resumed')) -def toggle_pause(config, cluster_name, group, paused, wait): +def toggle_pause(config: Dict[str, Any], cluster_name: str, group: Optional[int], paused: bool, wait: bool) -> None: from patroni.config import get_global_config dcs = get_dcs(config, cluster_name, group) cluster = dcs.get_cluster() @@ -1078,7 +1112,7 @@ def toggle_pause(config, cluster_name, group, paused, wait): @option_default_citus_group @click.pass_obj @click.option('--wait', help='Wait until pause is applied on all nodes', is_flag=True) -def pause(obj, cluster_name, group, wait): +def pause(obj: Dict[str, Any], cluster_name: str, group: Optional[int], wait: bool) -> None: return toggle_pause(obj, cluster_name, group, True, wait) @@ -1087,12 +1121,12 @@ def pause(obj, cluster_name, group, wait): @option_default_citus_group @click.option('--wait', help='Wait until pause is cleared on all nodes', is_flag=True) @click.pass_obj -def resume(obj, cluster_name, group, wait): +def resume(obj: Dict[str, Any], cluster_name: str, group: Optional[int], wait: bool) -> None: return toggle_pause(obj, cluster_name, group, False, wait) @contextmanager -def temporary_file(contents, suffix='', prefix='tmp'): +def temporary_file(contents: bytes, suffix: str = '', prefix: str = 'tmp') -> Generator[str, None, None]: """Creates a temporary file with specified contents that persists for the context. :param contents: binary string that will be written to the file. @@ -1110,12 +1144,12 @@ def temporary_file(contents, suffix='', prefix='tmp'): os.unlink(tmp.name) -def show_diff(before_editing, after_editing): +def show_diff(before_editing: str, after_editing: str) -> None: """Shows a diff between two strings. If the output is to a tty the diff will be colored. Inputs are expected to be unicode strings. """ - def listify(string): + def listify(string: str) -> List[str]: return [line + '\n' for line in string.rstrip('\n').split('\n')] unified_diff = difflib.unified_diff(listify(before_editing), listify(after_editing)) @@ -1159,7 +1193,7 @@ def show_diff(before_editing, after_editing): click.echo(line.rstrip('\n')) -def format_config_for_editing(data, default_flow_style=False): +def format_config_for_editing(data: Any, default_flow_style: bool = False) -> str: """Formats configuration as YAML for human consumption. :param data: configuration as nested dictionaries @@ -1167,7 +1201,7 @@ def format_config_for_editing(data, default_flow_style=False): return yaml.safe_dump(data, default_flow_style=default_flow_style, encoding=None, allow_unicode=True, width=200) -def apply_config_changes(before_editing, data, kvpairs): +def apply_config_changes(before_editing: str, data: Dict[str, Any], kvpairs: List[str]) -> Tuple[str, Dict[str, Any]]: """Applies config changes specified as a list of key-value pairs. Keys are interpreted as dotted paths into the configuration data structure. Except for paths beginning with @@ -1181,7 +1215,7 @@ def apply_config_changes(before_editing, data, kvpairs): """ changed_data = copy.deepcopy(data) - def set_path_value(config, path, value, prefix=()): + def set_path_value(config: Dict[str, Any], path: List[str], value: Any, prefix: Tuple[str, ...] = ()): # Postgresql GUCs can't be nested, but can contain dots so we re-flatten the structure for this case if prefix == ('postgresql', 'parameters'): path = ['.'.join(path)] @@ -1208,7 +1242,7 @@ def apply_config_changes(before_editing, data, kvpairs): return format_config_for_editing(changed_data), changed_data -def apply_yaml_file(data, filename): +def apply_yaml_file(data: Dict[str, Any], filename: str) -> Tuple[str, Dict[str, Any]]: """Applies changes from a YAML file to configuration :param data: configuration datastructure @@ -1228,7 +1262,7 @@ def apply_yaml_file(data, filename): return format_config_for_editing(changed_data), changed_data -def invoke_editor(before_editing, cluster_name): +def invoke_editor(before_editing: str, cluster_name: str) -> Tuple[str, Dict[str, Any]]: """Starts editor command to edit configuration in human readable format :param before_editing: human representation before editing @@ -1272,10 +1306,15 @@ def invoke_editor(before_editing, cluster_name): ' Use - for stdin.') @option_force @click.pass_obj -def edit_config(obj, cluster_name, group, force, quiet, kvpairs, pgkvpairs, apply_filename, replace_filename): +def edit_config(obj: Dict[str, Any], cluster_name: str, group: Optional[int], + force: bool, quiet: bool, kvpairs: List[str], pgkvpairs: List[str], + apply_filename: Optional[str], replace_filename: Optional[str]) -> None: dcs = get_dcs(obj, cluster_name, group) cluster = dcs.get_cluster() + if not cluster.config: + raise PatroniCtlException('The config key does not exist in the cluster {0}'.format(cluster_name)) + before_editing = format_config_for_editing(cluster.config.data) after_editing = None # Serves as a flag if any changes were requested @@ -1317,10 +1356,10 @@ def edit_config(obj, cluster_name, group, force, quiet, kvpairs, pgkvpairs, appl @arg_cluster_name @option_default_citus_group @click.pass_obj -def show_config(obj, cluster_name, group): +def show_config(obj: Dict[str, Any], cluster_name: str, group: Optional[int]) -> None: cluster = get_dcs(obj, cluster_name, group).get_cluster() - - click.echo(format_config_for_editing(cluster.config.data)) + if cluster.config: + click.echo(format_config_for_editing(cluster.config.data)) @ctl.command('version', help='Output version of patronictl command or a running Patroni instance') @@ -1328,7 +1367,7 @@ def show_config(obj, cluster_name, group): @click.argument('member_names', nargs=-1) @option_citus_group @click.pass_obj -def version(obj, cluster_name, group, member_names): +def version(obj: Dict[str, Any], cluster_name: str, group: Optional[int], member_names: List[str]) -> None: click.echo("patronictl version {0}".format(__version__)) if not cluster_name: @@ -1355,9 +1394,9 @@ def version(obj, cluster_name, group, member_names): @option_default_citus_group @option_format @click.pass_obj -def history(obj, cluster_name, group, fmt): +def history(obj: Dict[str, Any], cluster_name: str, group: Optional[int], fmt: str) -> None: cluster = get_dcs(obj, cluster_name, group).get_cluster() - history = cluster.history and cluster.history.lines or [] + history: List[List[Any]] = list(map(list, cluster.history and cluster.history.lines or [])) table_header_row = ['TL', 'LSN', 'Reason', 'Timestamp', 'New Leader'] for line in history: if len(line) < len(table_header_row): @@ -1367,7 +1406,7 @@ def history(obj, cluster_name, group, fmt): print_output(table_header_row, history, {'TL': 'r', 'LSN': 'r'}, fmt) -def format_pg_version(version): +def format_pg_version(version: int) -> str: if version < 100000: return "{0}.{1}.{2}".format(version // 10000, version // 100 % 100, version % 100) else: diff --git a/patroni/daemon.py b/patroni/daemon.py index 404e4061..88f6b133 100644 --- a/patroni/daemon.py +++ b/patroni/daemon.py @@ -11,10 +11,11 @@ import signal import sys from threading import Lock -from typing import Any, Optional, Type +from typing import Any, Optional, Type, TYPE_CHECKING -from .config import Config -from .validator import Schema +if TYPE_CHECKING: # pragma: no cover + from .config import Config + from .validator import Schema class AbstractPatroniDaemon(abc.ABC): @@ -30,7 +31,7 @@ class AbstractPatroniDaemon(abc.ABC): :ivar config: configuration options for this daemon. """ - def __init__(self, config: Config) -> None: + def __init__(self, config: 'Config') -> None: """Set up signal handlers, logging handler and configuration. :param config: configuration options for this daemon. @@ -94,7 +95,7 @@ class AbstractPatroniDaemon(abc.ABC): with self._sigterm_lock: return self._received_sigterm - def reload_config(self, sighup: Optional[bool] = False, local: Optional[bool] = False) -> None: + def reload_config(self, sighup: bool = False, local: Optional[bool] = False) -> None: """Reload configuration. :param sighup: if it is related to a SIGHUP signal. @@ -140,7 +141,7 @@ class AbstractPatroniDaemon(abc.ABC): self.logger.shutdown() -def abstract_main(cls: Type[AbstractPatroniDaemon], validator: Optional[Schema] = None) -> None: +def abstract_main(cls: Type[AbstractPatroniDaemon], validator: Optional['Schema'] = None) -> None: """Create the main entry point of a given daemon process. Expose a basic argument parser, parse the command-line arguments, and run the given daemon process. diff --git a/patroni/dcs/__init__.py b/patroni/dcs/__init__.py index ae52e8ad..2c2d0187 100644 --- a/patroni/dcs/__init__.py +++ b/patroni/dcs/__init__.py @@ -1,5 +1,6 @@ import abc import dateutil.parser +import datetime import importlib import inspect import json @@ -10,15 +11,17 @@ import re import sys import time -from collections import defaultdict, namedtuple +from collections import defaultdict from copy import deepcopy from random import randint from threading import Event, Lock -from typing import Any, Collection, Dict, List, Optional, Union +from typing import Any, Callable, Collection, Dict, List, NamedTuple, Optional, Set, Tuple, Union, TYPE_CHECKING from urllib.parse import urlparse, urlunparse, parse_qsl from ..exceptions import PatroniFatalException from ..utils import deep_compare, uri +if TYPE_CHECKING: # pragma: no cover + from ..config import Config CITUS_COORDINATOR_GROUP_ID = 0 citus_group_re = re.compile('^(0|[1-9][0-9]*)$') @@ -26,7 +29,7 @@ slot_name_re = re.compile('^[a-z0-9_]{1,63}$') logger = logging.getLogger(__name__) -def slot_name_from_member_name(member_name): +def slot_name_from_member_name(member_name: str) -> str: """Translate member name to valid PostgreSQL slot name. PostgreSQL replication slot names must be valid PostgreSQL names. This function maps the wider space of @@ -34,7 +37,7 @@ def slot_name_from_member_name(member_name): are replaced with underscores, other characters are encoded as their unicode codepoint. Name is truncated to 64 characters. Multiple different member names may map to a single slot name.""" - def replace_char(match): + def replace_char(match: Any) -> str: c = match.group(0) return '_' if c in '-.' else "u{:04d}".format(ord(c)) @@ -42,7 +45,7 @@ def slot_name_from_member_name(member_name): return slot_name[0:63] -def parse_connection_string(value): +def parse_connection_string(value: str) -> Tuple[str, Union[str, None]]: """Original Governor stores connection strings for each cluster members if a following format: postgres://{username}:{password}@{connect_address}/postgres Since each of our patroni instances provides own REST API endpoint it's good to store this information @@ -59,7 +62,7 @@ def parse_connection_string(value): return conn_url, api_url -def dcs_modules(): +def dcs_modules() -> List[str]: """Get names of DCS modules, depending on execution environment. If being packaged with PyInstaller, modules aren't discoverable dynamically by scanning source directory because `FrozenImporter` doesn't implement `iter_modules` method. But it is still possible to find all potential DCS modules by @@ -69,20 +72,20 @@ def dcs_modules(): module_prefix = __package__ + '.' if getattr(sys, 'frozen', False): - toc = set() + toc: Set[str] = set() # dcs_dirname may contain a dot, which causes pkgutil.iter_importers() # to misinterpret the path as a package name. This can be avoided # altogether by not passing a path at all, because PyInstaller's # FrozenImporter is a singleton and registered as top-level finder. for importer in pkgutil.iter_importers(): if hasattr(importer, 'toc'): - toc |= importer.toc + toc |= getattr(importer, 'toc') return [module for module in toc if module.startswith(module_prefix) and module.count('.') == 2] else: return [module_prefix + name for _, name, is_pkg in pkgutil.iter_modules([dcs_dirname]) if not is_pkg] -def get_dcs(config): +def get_dcs(config: Union['Config', Dict[str, Any]]) -> 'AbstractDCS': modules = dcs_modules() for module_name in modules: @@ -103,7 +106,7 @@ def get_dcs(config): except ImportError: logger.debug('Failed to import %s', module_name) - available_implementations = [] + available_implementations: List[str] = [] for module_name in modules: name = module_name.split('.')[-1] try: @@ -116,8 +119,11 @@ def get_dcs(config): Available implementations: """ + ', '.join(sorted(set(available_implementations)))) -class Member(namedtuple('Member', 'index,name,session,data')): +_Version = Union[int, str] +_Session = Union[int, float, str, None] + +class Member(NamedTuple): """Immutable object (namedtuple) which represents single member of PostgreSQL cluster. Consists of the following fields: :param index: modification index of a given member key in a Configuration Store @@ -127,30 +133,34 @@ class Member(namedtuple('Member', 'index,name,session,data')): There are two mandatory keys in a data: conn_url: connection string containing host, user and password which could be used to access this member. - api_url: REST API url of patroni instance""" + api_url: REST API url of patroni instance + """ + index: _Version + name: str + session: _Session + data: Dict[str, Any] @staticmethod - def from_node(index, name, session, data): + def from_node(index: _Version, name: str, session: _Session, value: str) -> 'Member': """ >>> Member.from_node(-1, '', '', '{"conn_url": "postgres://foo@bar/postgres"}') is not None True >>> Member.from_node(-1, '', '', '{') Member(index=-1, name='', session='', data={}) """ - if data.startswith('postgres'): - conn_url, api_url = parse_connection_string(data) + if value.startswith('postgres'): + conn_url, api_url = parse_connection_string(value) data = {'conn_url': conn_url, 'api_url': api_url} else: try: - data = json.loads(data) - if not isinstance(data, dict): - data = {} - except (TypeError, ValueError): - data = {} + data = json.loads(value) + assert isinstance(data, dict) + except (AssertionError, TypeError, ValueError): + data: Dict[str, Any] = {} return Member(index, name, session, data) @property - def conn_url(self): + def conn_url(self) -> Optional[str]: conn_url = self.data.get('conn_url') if conn_url: return conn_url @@ -161,13 +171,13 @@ class Member(namedtuple('Member', 'index,name,session,data')): self.data['conn_url'] = conn_url return conn_url - def conn_kwargs(self, auth=None): + def conn_kwargs(self, auth: Union[Any, Dict[str, Any], None] = None) -> Dict[str, Any]: defaults = { "host": None, "port": None, "dbname": None } - ret = self.data.get('conn_kwargs') + ret: Optional[Dict[str, Any]] = self.data.get('conn_kwargs') if ret: defaults.update(ret) ret = defaults @@ -191,35 +201,35 @@ class Member(namedtuple('Member', 'index,name,session,data')): return ret @property - def api_url(self): + def api_url(self) -> Optional[str]: return self.data.get('api_url') @property - def tags(self): + def tags(self) -> Dict[str, Any]: return self.data.get('tags', {}) @property - def nofailover(self): + def nofailover(self) -> bool: return self.tags.get('nofailover', False) @property - def replicatefrom(self): + def replicatefrom(self) -> Optional[str]: return self.tags.get('replicatefrom') @property - def clonefrom(self): + def clonefrom(self) -> bool: return self.tags.get('clonefrom', False) and bool(self.conn_url) @property - def state(self): + def state(self) -> str: return self.data.get('state', 'unknown') @property - def is_running(self): + def is_running(self) -> bool: return self.state == 'running' @property - def version(self): + def version(self) -> Optional[Tuple[int, ...]]: version = self.data.get('version') if version: try: @@ -230,11 +240,11 @@ class Member(namedtuple('Member', 'index,name,session,data')): class RemoteMember(Member): """Represents a remote member (typically a primary) for a standby cluster""" - def __new__(cls, name, data): - return super(RemoteMember, cls).__new__(cls, None, name, None, data) + def __new__(cls, name: str, data: Dict[str, Any]) -> 'RemoteMember': + return super(RemoteMember, cls).__new__(cls, -1, name, None, data) @staticmethod - def allowed_keys(): + def allowed_keys() -> Tuple[str, ...]: return ('primary_slot_name', 'create_replica_methods', 'restore_command', @@ -242,42 +252,47 @@ class RemoteMember(Member): 'recovery_min_apply_delay', 'no_replication_slot') - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: if name in RemoteMember.allowed_keys(): return self.data.get(name) -class Leader(namedtuple('Leader', 'index,session,member')): - +class Leader(NamedTuple): """Immutable object (namedtuple) which represents leader key. + Consists of the following fields: :param index: modification index of a leader key in a Configuration Store :param session: either session id or just ttl in seconds - :param member: reference to a `Member` object which represents current leader (see `Cluster.members`)""" + :param member: reference to a `Member` object which represents current leader (see `Cluster.members`) + """ + index: _Version + session: _Session + member: Member @property - def name(self): + def name(self) -> str: return self.member.name - def conn_kwargs(self, auth=None): + def conn_kwargs(self, auth: Optional[Dict[str, str]] = None) -> Dict[str, str]: return self.member.conn_kwargs(auth) @property - def conn_url(self): + def conn_url(self) -> Optional[str]: return self.member.conn_url @property - def data(self): + def data(self) -> Dict[str, Any]: return self.member.data @property - def timeline(self): + def timeline(self) -> Optional[int]: return self.data.get('timeline') @property - def checkpoint_after_promote(self): + def checkpoint_after_promote(self) -> Optional[bool]: """ >>> Leader(1, '', Member.from_node(1, '', '', '{"version":"z"}')).checkpoint_after_promote + """ version = self.member.version # 1.5.6 is the last version which doesn't expose checkpoint_after_promote: false @@ -285,7 +300,7 @@ class Leader(namedtuple('Leader', 'index,session,member')): return self.data.get('role') in ('master', 'primary') and 'checkpoint_after_promote' not in self.data -class Failover(namedtuple('Failover', 'index,leader,candidate,scheduled_at')): +class Failover(NamedTuple): """ >>> 'Failover' in str(Failover.from_node(1, '{"leader": "cluster_leader"}')) @@ -306,20 +321,26 @@ class Failover(namedtuple('Failover', 'index,leader,candidate,scheduled_at')): >>> 'abc' in Failover.from_node(1, 'abc:def') True """ + index: _Version + leader: Optional[str] + candidate: Optional[str] + scheduled_at: Optional[datetime.datetime] + @staticmethod - def from_node(index, value): + def from_node(index: _Version, value: Union[str, Dict[str, str]]) -> 'Failover': if isinstance(value, dict): - data = value + data: Dict[str, Any] = value elif value: try: data = json.loads(value) - if not isinstance(data, dict): - data = {} + assert isinstance(data, dict) + except AssertionError: + data = {} except ValueError: t = [a.strip() for a in value.split(':')] leader = t[0] candidate = t[1] if len(t) > 1 else None - return Failover(index, leader, candidate, None) if leader or candidate else None + return Failover(index, leader, candidate, None) else: data = {} @@ -328,54 +349,57 @@ class Failover(namedtuple('Failover', 'index,leader,candidate,scheduled_at')): return Failover(index, data.get('leader'), data.get('member'), data.get('scheduled_at')) - def __len__(self): + def __len__(self) -> int: return int(bool(self.leader)) + int(bool(self.candidate)) -class ClusterConfig(namedtuple('ClusterConfig', 'index,data,modify_index')): +class ClusterConfig(NamedTuple): + index: _Version + data: Dict[str, Any] + modify_index: _Version @staticmethod - def from_node(index, data, modify_index=None): + def from_node(index: _Version, value: str, modify_index: Optional[_Version] = None) -> 'ClusterConfig': """ >>> ClusterConfig.from_node(1, '{') is None False """ try: - data = json.loads(data) - except (TypeError, ValueError): - data = None + data = json.loads(value) + assert isinstance(data, dict) + except (AssertionError, TypeError, ValueError): + data: Dict[str, Any] = {} modify_index = 0 - if not isinstance(data, dict): - data = {} return ClusterConfig(index, data, index if modify_index is None else modify_index) @property - def permanent_slots(self): - return isinstance(self.data, dict) and ( - self.data.get('permanent_replication_slots') - or self.data.get('permanent_slots') or self.data.get('slots') - ) or {} + def permanent_slots(self) -> Dict[str, Any]: + return self.data.get('permanent_replication_slots')\ + or self.data.get('permanent_slots') or self.data.get('slots') or {} @property - def ignore_slots_matchers(self): - return isinstance(self.data, dict) and self.data.get('ignore_slots') or [] + def ignore_slots_matchers(self) -> List[Dict[str, Any]]: + return self.data.get('ignore_slots') or [] @property - def max_timelines_history(self): + def max_timelines_history(self) -> int: return self.data.get('max_timelines_history', 0) -class SyncState(namedtuple('SyncState', 'index,leader,sync_standby')): +class SyncState(NamedTuple): """Immutable object (namedtuple) which represents last observed synhcronous replication state :param index: modification index of a synchronization key in a Configuration Store :param leader: reference to member that was leader :param sync_standby: synchronous standby list (comma delimited) which are last synchronized to leader """ + index: Optional[_Version] + leader: Optional[str] + sync_standby: Optional[str] @staticmethod - def from_node(index: Union[str, int], value: Union[str, Dict[str, Any]]) -> 'SyncState': + def from_node(index: Optional[_Version], value: Union[str, Dict[str, Any], None]) -> 'SyncState': """ >>> SyncState.from_node(1, None).leader is None True @@ -393,15 +417,14 @@ class SyncState(namedtuple('SyncState', 'index,leader,sync_standby')): try: if value and isinstance(value, str): value = json.loads(value) - if not isinstance(value, dict): - return SyncState.empty(index) + assert isinstance(value, dict) return SyncState(index, value.get('leader'), value.get('sync_standby')) - except (TypeError, ValueError): + except (AssertionError, TypeError, ValueError): return SyncState.empty(index) @staticmethod - def empty(index: Optional[Union[str, int]] = '') -> 'SyncState': - return SyncState(index, None, '') + def empty(index: Optional[_Version] = None) -> 'SyncState': + return SyncState(index, None, None) @property def is_empty(self) -> bool: @@ -422,7 +445,7 @@ class SyncState(namedtuple('SyncState', 'index,leader,sync_standby')): """:returns: sync_standby as list.""" return self._str_to_list(self.sync_standby) if not self.is_empty and self.sync_standby else [] - def matches(self, name: Union[str, None], check_leader: Optional[bool] = False) -> bool: + def matches(self, name: Optional[str], check_leader: bool = False) -> bool: """Checks if node is presented in the /sync state. Since PostgreSQL does case-insensitive checks for synchronous_standby_name we do it also. @@ -447,20 +470,26 @@ class SyncState(namedtuple('SyncState', 'index,leader,sync_standby')): """ ret = False if name and not self.is_empty: - search_str = (self.sync_standby or '') + (',' + self.leader if check_leader else '') + search_str = (self.sync_standby or '') + (',' + (self.leader or '') if check_leader else '') ret = name.lower() in self._str_to_list(search_str.lower()) return ret - def leader_matches(self, name: Union[str, None]) -> bool: + def leader_matches(self, name: Optional[str]) -> bool: """:returns: `True` if name is matching the `SyncState.leader` value.""" - return name and not self.is_empty and name.lower() == self.leader.lower() + return bool(name and not self.is_empty and name.lower() == (self.leader or '').lower()) -class TimelineHistory(namedtuple('TimelineHistory', 'index,value,lines')): +_HistoryTuple = Union[Tuple[int, int, str], Tuple[int, int, str, str], Tuple[int, int, str, str, str]] + + +class TimelineHistory(NamedTuple): """Object representing timeline history file""" + index: _Version + value: Any + lines: List[_HistoryTuple] @staticmethod - def from_node(index, value): + def from_node(index: _Version, value: str) -> 'TimelineHistory': """ >>> h = TimelineHistory.from_node(1, 2) >>> h.lines @@ -468,16 +497,13 @@ class TimelineHistory(namedtuple('TimelineHistory', 'index,value,lines')): """ try: lines = json.loads(value) - except (TypeError, ValueError): - lines = None - if not isinstance(lines, list): - lines = [] + assert isinstance(lines, list) + except (AssertionError, TypeError, ValueError): + lines: List[_HistoryTuple] = [] return TimelineHistory(index, value, lines) -class Cluster(namedtuple('Cluster', 'initialize,config,leader,last_lsn,members,' - 'failover,sync,history,slots,failsafe,workers')): - +class Cluster(NamedTuple): """Immutable object (namedtuple) which represents PostgreSQL cluster. Consists of the following fields: :param initialize: shows whether this cluster has initialization key stored in DC or not. @@ -491,55 +517,70 @@ class Cluster(namedtuple('Cluster', 'initialize,config,leader,last_lsn,members,' :param history: reference to `TimelineHistory` object :param slots: state of permanent logical replication slots on the primary in the format: {"slot_name": int} :param failsafe: failsafe topology. Node is allowed to become the leader only if its name is found in this list. - :param workers: workers of the Citus cluster, optional. Format: {int(group): Cluster()}""" - - def __new__(cls, *args): - # Make workers argument optional - if len(cls._fields) == len(args) + 1: - args = args + ({},) - return super(Cluster, cls).__new__(cls, *args) + :param workers: workers of the Citus cluster, optional. Format: {int(group): Cluster()} + """ + initialize: Optional[str] + config: Optional[ClusterConfig] + leader: Optional[Leader] + last_lsn: int + members: List[Member] + failover: Optional[Failover] + sync: SyncState + history: Optional[TimelineHistory] + slots: Optional[Dict[str, int]] + failsafe: Optional[Dict[str, str]] + workers: Dict[int, 'Cluster'] = {} @staticmethod - def empty(): + def empty() -> 'Cluster': return Cluster(None, None, None, 0, [], None, SyncState.empty(), None, None, None) + def is_empty(self): + return self.initialize is None and self.config is None and self.leader is None and self.last_lsn == 0\ + and self.members == [] and self.failover is None and self.sync.index is None\ + and self.history is None and self.slots is None and self.failsafe is None and self.workers == {} + + def __len__(self) -> int: + return int(not self.is_empty()) + @property - def leader_name(self): + def leader_name(self) -> Optional[str]: return self.leader and self.leader.name - def is_unlocked(self): + def is_unlocked(self) -> bool: return not self.leader_name - def has_member(self, member_name): + def has_member(self, member_name: str) -> bool: return any(m for m in self.members if m.name == member_name) - def get_member(self, member_name, fallback_to_leader=True): + def get_member(self, member_name: str, fallback_to_leader: bool = True) -> Union[Member, Leader, None]: return ([m for m in self.members if m.name == member_name] or [self.leader if fallback_to_leader else None])[0] - def get_clone_member(self, exclude): - exclude = [exclude] + [self.leader.name] if self.leader else [] + def get_clone_member(self, exclude_name: str) -> Union[Member, Leader, None]: + exclude = [exclude_name] + ([self.leader.name] if self.leader else []) candidates = [m for m in self.members if m.clonefrom and m.is_running and m.name not in exclude] return candidates[randint(0, len(candidates) - 1)] if candidates else self.leader @property - def __permanent_slots(self): + def __permanent_slots(self) -> Dict[str, Union[Dict[str, Any], Any]]: return self.config and self.config.permanent_slots or {} @property - def __permanent_physical_slots(self): + def __permanent_physical_slots(self) -> Dict[str, Any]: return {name: value for name, value in self.__permanent_slots.items() if not value or isinstance(value, dict) and value.get('type', 'physical') == 'physical'} @property - def __permanent_logical_slots(self): + def __permanent_logical_slots(self) -> Dict[str, Any]: return {name: value for name, value in self.__permanent_slots.items() if isinstance(value, dict) and value.get('type', 'logical') == 'logical' and value.get('database') and value.get('plugin')} @property - def use_slots(self): - return self.config and (self.config.data.get('postgresql') or {}).get('use_slots', True) + def use_slots(self) -> bool: + return bool(self.config and (self.config.data.get('postgresql') or {}).get('use_slots', True)) - def get_replication_slots(self, my_name, role, nofailover, major_version, show_error=False): + def get_replication_slots(self, my_name: str, role: str, nofailover: bool, + major_version: int, show_error: bool = False) -> Dict[str, Dict[str, Any]]: # if the replicatefrom tag is set on the member - we should not create the replication slot for it on # the current primary, because that member would replicate from elsewhere. We still create the slot if # the replicatefrom destination member is currently not a member of the cluster (fallback to the @@ -561,7 +602,7 @@ class Cluster(namedtuple('Cluster', 'initialize,config,leader,last_lsn,members,' if len(slots) < len(slot_members): # Find which names are conflicting for a nicer error message - slot_conflicts = defaultdict(list) + slot_conflicts: Dict[str, List[str]] = defaultdict(list) for name in slot_members: slot_conflicts[slot_name_from_member_name(name)].append(name) logger.error("Following cluster members share a replication slot name: %s", @@ -569,7 +610,7 @@ class Cluster(namedtuple('Cluster', 'initialize,config,leader,last_lsn,members,' for k, v in slot_conflicts.items() if len(v) > 1)) # "merge" replication slots for members with permanent_replication_slots - disabled_permanent_logical_slots = [] + disabled_permanent_logical_slots: List[str] = [] for name, value in permanent_slots.items(): if not slot_name_re.match(name): logger.error("Invalid permanent replication slot name '%s'", name) @@ -604,13 +645,13 @@ class Cluster(namedtuple('Cluster', 'initialize,config,leader,last_lsn,members,' return slots - def has_permanent_logical_slots(self, my_name, nofailover, major_version=110000): + def has_permanent_logical_slots(self, my_name: str, nofailover: bool, major_version: int = 110000) -> bool: if major_version < 110000: return False slots = self.get_replication_slots(my_name, 'replica', nofailover, major_version).values() return any(v for v in slots if v.get("type") == "logical") - def should_enforce_hot_standby_feedback(self, my_name, nofailover, major_version): + def should_enforce_hot_standby_feedback(self, my_name: str, nofailover: bool, major_version: int) -> bool: """ The hot_standby_feedback must be enabled if the current replica has logical slots or it is working as a cascading replica for the other node that has logical slots. @@ -627,7 +668,7 @@ class Cluster(namedtuple('Cluster', 'initialize,config,leader,last_lsn,members,' return any(self.should_enforce_hot_standby_feedback(m.name, m.nofailover, major_version) for m in members) return False - def get_my_slot_name_on_primary(self, my_name, replicatefrom): + def get_my_slot_name_on_primary(self, my_name: str, replicatefrom: Optional[str]) -> str: """ P <-- I <-- L In case of cascading replication we have to check not our physical slot, @@ -635,10 +676,11 @@ class Cluster(namedtuple('Cluster', 'initialize,config,leader,last_lsn,members,' """ m = self.get_member(replicatefrom, False) if replicatefrom else None - return self.get_my_slot_name_on_primary(m.name, m.replicatefrom) if m else slot_name_from_member_name(my_name) + return self.get_my_slot_name_on_primary(m.name, m.replicatefrom)\ + if isinstance(m, Member) else slot_name_from_member_name(my_name) @property - def timeline(self): + def timeline(self) -> int: """ >>> Cluster(0, 0, 0, 0, 0, 0, 0, 0, 0, None).timeline 0 @@ -658,16 +700,16 @@ class Cluster(namedtuple('Cluster', 'initialize,config,leader,last_lsn,members,' return 0 @property - def min_version(self): - return next(iter(sorted(filter(lambda v: v, [m.version for m in self.members])) + [None])) + def min_version(self) -> Optional[Tuple[int, ...]]: + return next(iter(sorted(m.version for m in self.members if m.version)), None) class ReturnFalseException(Exception): pass -def catch_return_false_exception(func): - def wrapper(*args, **kwargs): +def catch_return_false_exception(func: Callable[..., Any]) -> Any: + def wrapper(*args: Any, **kwargs: Any): try: return func(*args, **kwargs) except ReturnFalseException: @@ -690,7 +732,7 @@ class AbstractDCS(abc.ABC): _SYNC = 'sync' _FAILSAFE = 'failsafe' - def __init__(self, config): + def __init__(self, config: Dict[str, Any]) -> None: """ :param config: dict, reference to config section of selected DCS. i.e.: `zookeeper` for zookeeper, `etcd` for etcd, etc... @@ -701,16 +743,16 @@ class AbstractDCS(abc.ABC): self._set_loop_wait(config.get('loop_wait', 10)) self._ctl = bool(config.get('patronictl', False)) - self._cluster = None - self._cluster_valid_till = 0 + self._cluster: Optional[Cluster] = None + self._cluster_valid_till: float = 0 self._cluster_thread_lock = Lock() - self._last_lsn = '' - self._last_seen = 0 - self._last_status = {} - self._last_failsafe = {} + self._last_lsn: int = 0 + self._last_seen: int = 0 + self._last_status: Dict[str, Any] = {} + self._last_failsafe: Optional[Dict[str, str]] = {} self.event = Event() - def client_path(self, path): + def client_path(self, path: str) -> str: components = [self._base_path] if self._citus_group: components.append(self._citus_group) @@ -718,86 +760,88 @@ class AbstractDCS(abc.ABC): return '/'.join(components) @property - def initialize_path(self): + def initialize_path(self) -> str: return self.client_path(self._INITIALIZE) @property - def config_path(self): + def config_path(self) -> str: return self.client_path(self._CONFIG) @property - def members_path(self): + def members_path(self) -> str: return self.client_path(self._MEMBERS) @property - def member_path(self): + def member_path(self) -> str: return self.client_path(self._MEMBERS + self._name) @property - def leader_path(self): + def leader_path(self) -> str: return self.client_path(self._LEADER) @property - def failover_path(self): + def failover_path(self) -> str: return self.client_path(self._FAILOVER) @property - def history_path(self): + def history_path(self) -> str: return self.client_path(self._HISTORY) @property - def status_path(self): + def status_path(self) -> str: return self.client_path(self._STATUS) @property - def leader_optime_path(self): + def leader_optime_path(self) -> str: return self.client_path(self._LEADER_OPTIME) @property - def sync_path(self): + def sync_path(self) -> str: return self.client_path(self._SYNC) @property - def failsafe_path(self): + def failsafe_path(self) -> str: return self.client_path(self._FAILSAFE) @abc.abstractmethod - def set_ttl(self, ttl): + def set_ttl(self, ttl: int) -> Optional[bool]: """Set the new ttl value for leader key""" + @property @abc.abstractmethod - def ttl(self): + def ttl(self) -> int: """Get new ttl value""" @abc.abstractmethod - def set_retry_timeout(self, retry_timeout): + def set_retry_timeout(self, retry_timeout: int) -> None: """Set the new value for retry_timeout""" - def _set_loop_wait(self, loop_wait): + def _set_loop_wait(self, loop_wait: int) -> None: self._loop_wait = loop_wait - def reload_config(self, config): + def reload_config(self, config: Union['Config', Dict[str, Any]]) -> None: self._set_loop_wait(config['loop_wait']) self.set_ttl(config['ttl']) self.set_retry_timeout(config['retry_timeout']) @property - def loop_wait(self): + def loop_wait(self) -> int: return self._loop_wait @property - def last_seen(self): + def last_seen(self) -> int: return self._last_seen @abc.abstractmethod - def _cluster_loader(self, path): + def _cluster_loader(self, path: Any) -> Cluster: """Load and build the `Cluster` object from DCS, which represents a single Patroni cluster. :param path: the path in DCS where to load Cluster(s) from. :returns: `Cluster`""" - def _citus_cluster_loader(self, path): + @abc.abstractmethod + def _citus_cluster_loader(self, path: Any) -> Union[Cluster, Dict[int, Cluster]]: """Load and build `Cluster` onjects from DCS that represent all Patroni clusters from a single Citus cluster. @@ -805,7 +849,9 @@ class AbstractDCS(abc.ABC): :returns: all Citus groups as `dict`, with group ids as keys""" @abc.abstractmethod - def _load_cluster(self, path, loader): + def _load_cluster( + self, path: str, loader: Callable[[Any], Union[Cluster, Dict[int, Cluster]]] + ) -> Union[Cluster, Dict[int, Cluster]]: """Internally this method should call the `loader` method that will build `Cluster` object which represents current state and topology of the cluster in DCS. This method supposed to be @@ -817,20 +863,26 @@ class AbstractDCS(abc.ABC): If the current node was running as a primary and exception raised, instance would be demoted.""" - def _bypass_caches(self): + def _bypass_caches(self) -> None: """Used only in zookeeper""" - def is_citus_coordinator(self): + def __get_patroni_cluster(self, path: Optional[str] = None) -> Cluster: + if path is None: + path = self.client_path('') + cluster = self._load_cluster(path, self._cluster_loader) + assert isinstance(cluster, Cluster) + return cluster + + def is_citus_coordinator(self) -> bool: return self._citus_group == str(CITUS_COORDINATOR_GROUP_ID) - def get_citus_coordinator(self): + def get_citus_coordinator(self) -> Optional[Cluster]: try: - path = '{0}/{1}/'.format(self._base_path, CITUS_COORDINATOR_GROUP_ID) - return self._load_cluster(path, self._cluster_loader) + return self.__get_patroni_cluster('{0}/{1}/'.format(self._base_path, CITUS_COORDINATOR_GROUP_ID)) except Exception as e: logger.error('Failed to load Citus coordinator cluster from %s: %r', self.__class__.__name__, e) - def _get_citus_cluster(self): + def _get_citus_cluster(self) -> Cluster: groups = self._load_cluster(self._base_path + '/', self._citus_cluster_loader) if isinstance(groups, Cluster): # Zookeeper could return a cached version cluster = groups @@ -840,12 +892,11 @@ class AbstractDCS(abc.ABC): cluster.workers.update(groups) return cluster - def get_cluster(self, force=False): + def get_cluster(self, force: bool = False) -> Cluster: if force: self._bypass_caches() try: - cluster = self._get_citus_cluster() if self.is_citus_coordinator()\ - else self._load_cluster(self.client_path(''), self._cluster_loader) + cluster = self._get_citus_cluster() if self.is_citus_coordinator() else self.__get_patroni_cluster() except Exception: self.reset_cluster() raise @@ -860,33 +911,33 @@ class AbstractDCS(abc.ABC): return cluster @property - def cluster(self): + def cluster(self) -> Optional[Cluster]: with self._cluster_thread_lock: return self._cluster if self._cluster_valid_till > time.time() else None - def reset_cluster(self): + def reset_cluster(self) -> None: with self._cluster_thread_lock: self._cluster = None self._cluster_valid_till = 0 @abc.abstractmethod - def _write_leader_optime(self, last_lsn): + def _write_leader_optime(self, last_lsn: str) -> bool: """write current WAL LSN into `/optime/leader` key in DCS :param last_lsn: absolute WAL LSN in bytes :returns: `!True` on success.""" - def write_leader_optime(self, last_lsn): + def write_leader_optime(self, last_lsn: int) -> None: self.write_status({self._OPTIME: last_lsn}) @abc.abstractmethod - def _write_status(self, value): + def _write_status(self, value: str) -> bool: """write current WAL LSN and confirmed_flush_lsn of permanent slots into the `/status` key in DCS :param value: status serialized in JSON forman :returns: `!True` on success.""" - def write_status(self, value): + def write_status(self, value: Dict[str, Any]) -> None: if not deep_compare(self._last_status, value) and self._write_status(json.dumps(value, separators=(',', ':'))): self._last_status = value cluster = self.cluster @@ -896,23 +947,23 @@ class AbstractDCS(abc.ABC): self._write_leader_optime(str(value[self._OPTIME])) @abc.abstractmethod - def _write_failsafe(self, value): + def _write_failsafe(self, value: str) -> bool: """Write current cluster topology to DCS that will be used by failsafe mechanism (if enabled). :param value: failsafe topology serialized in JSON format :returns: `!True` on success.""" - def write_failsafe(self, value): + def write_failsafe(self, value: Dict[str, str]) -> None: if not (isinstance(self._last_failsafe, dict) and deep_compare(self._last_failsafe, value))\ and self._write_failsafe(json.dumps(value, separators=(',', ':'))): self._last_failsafe = value @property - def failsafe(self): + def failsafe(self) -> Optional[Dict[str, str]]: return self._last_failsafe @abc.abstractmethod - def _update_leader(self): + def _update_leader(self) -> bool: """Update leader key (or session) ttl :returns: `!True` if leader key (or session) has been updated successfully. @@ -922,7 +973,8 @@ class AbstractDCS(abc.ABC): If update fails due to DCS not being accessible or because it is not able to process requests (hopefuly temporary), the ~DCSError exception should be raised.""" - def update_leader(self, last_lsn, slots=None, failsafe=None): + def update_leader(self, last_lsn: Optional[int], slots: Optional[Dict[str, int]] = None, + failsafe: Optional[Dict[str, str]] = None) -> bool: """Update leader key (or session) ttl and optime/leader :param last_lsn: absolute WAL LSN in bytes @@ -931,7 +983,7 @@ class AbstractDCS(abc.ABC): ret = self._update_leader() if ret and last_lsn: - status = {self._OPTIME: last_lsn} + status: Dict[str, Any] = {self._OPTIME: last_lsn} if slots: status['slots'] = slots self.write_status(status) @@ -942,7 +994,7 @@ class AbstractDCS(abc.ABC): return ret @abc.abstractmethod - def attempt_to_acquire_leader(self): + def attempt_to_acquire_leader(self) -> bool: """Attempt to acquire leader lock This method should create `/leader` key with value=`~self._name` :returns: `!True` if key has been created successfully. @@ -954,10 +1006,11 @@ class AbstractDCS(abc.ABC): process requests (hopefuly temporary), the ~DCSError exception should be raised""" @abc.abstractmethod - def set_failover_value(self, value, index=None): + def set_failover_value(self, value: str, index: Optional[Any] = None) -> bool: """Create or update `/failover` key""" - def manual_failover(self, leader, candidate, scheduled_at=None, index=None): + def manual_failover(self, leader: Optional[str], candidate: Optional[str], + scheduled_at: Optional[datetime.datetime] = None, index: Optional[Any] = None) -> bool: failover_value = {} if leader: failover_value['leader'] = leader @@ -970,11 +1023,11 @@ class AbstractDCS(abc.ABC): return self.set_failover_value(json.dumps(failover_value, separators=(',', ':')), index) @abc.abstractmethod - def set_config_value(self, value, index=None): + def set_config_value(self, value: str, index: Optional[Any] = None) -> bool: """Create or update `/config` key""" @abc.abstractmethod - def touch_member(self, data): + def touch_member(self, data: Dict[str, Any]) -> bool: """Update member key in DCS. This method should create or update key with the name = '/members/' + `~self._name` and value = data in a given DCS. @@ -985,13 +1038,13 @@ class AbstractDCS(abc.ABC): """ @abc.abstractmethod - def take_leader(self): + def take_leader(self) -> bool: """This method should create leader key with value = `~self._name` and ttl=`~self.ttl` Since it could be called only on initial cluster bootstrap it could create this key regardless, overwriting the key if necessary.""" @abc.abstractmethod - def initialize(self, create_new=True, sysid=""): + def initialize(self, create_new: bool = True, sysid: str = "") -> bool: """Race for cluster initialization. :param create_new: False if the key should already exist (in the case we are setting the system_id) @@ -1002,11 +1055,11 @@ class AbstractDCS(abc.ABC): otherwise it should return `!False`""" @abc.abstractmethod - def _delete_leader(self): + def _delete_leader(self) -> bool: """Remove leader key from DCS. This method should remove leader key if current instance is the leader""" - def delete_leader(self, last_lsn=None): + def delete_leader(self, last_lsn: Optional[int] = None) -> bool: """Update optime/leader and voluntarily remove leader key from DCS. This method should remove leader key if current instance is the leader. :param last_lsn: latest checkpoint location in bytes""" @@ -1016,15 +1069,15 @@ class AbstractDCS(abc.ABC): return self._delete_leader() @abc.abstractmethod - def cancel_initialization(self): + def cancel_initialization(self) -> bool: """ Removes the initialize key for a cluster """ @abc.abstractmethod - def delete_cluster(self): + def delete_cluster(self) -> bool: """Delete cluster from DCS""" @staticmethod - def sync_state(leader: Union[str, None], sync_standby: Union[Collection[str], None]) -> Dict[str, Any]: + def sync_state(leader: Optional[str], sync_standby: Optional[Collection[str]]) -> Dict[str, Any]: """Build sync_state dict. The sync_standby key being kept for backward compatibility. :param leader: name of the leader node that manages /sync key @@ -1033,8 +1086,8 @@ class AbstractDCS(abc.ABC): """ return {'leader': leader, 'sync_standby': ','.join(sorted(sync_standby)) if sync_standby else None} - def write_sync_state(self, leader: Union[str, None], sync_standby: Union[Collection[str], None], - index: Optional[Union[int, str]] = None) -> bool: + def write_sync_state(self, leader: Optional[str], sync_standby: Optional[Collection[str]], + index: Optional[Any] = None) -> bool: """Write the new synchronous state to DCS. Calls :func:`sync_state` method to build a dict and than calls DCS specific :func:`set_sync_state_value` method. :param leader: name of the leader node that manages /sync key @@ -1046,11 +1099,11 @@ class AbstractDCS(abc.ABC): return self.set_sync_state_value(json.dumps(sync_value, separators=(',', ':')), index) @abc.abstractmethod - def set_history_value(self, value): + def set_history_value(self, value: str) -> bool: """""" @abc.abstractmethod - def set_sync_state_value(self, value: str, index: Optional[Union[int, str]] = None) -> bool: + def set_sync_state_value(self, value: str, index: Optional[Any] = None) -> bool: """Set synchronous state in DCS, should be implemented in the child class. :param value: the new value of /sync key @@ -1059,10 +1112,10 @@ class AbstractDCS(abc.ABC): """ @abc.abstractmethod - def delete_sync_state(self, index=None): + def delete_sync_state(self, index: Optional[Any] = None) -> bool: """""" - def watch(self, leader_index, timeout): + def watch(self, leader_index: Optional[Any], timeout: float) -> bool: """If the current node is a leader it should just sleep. Any other node should watch for changes of leader key with a given timeout diff --git a/patroni/dcs/consul.py b/patroni/dcs/consul.py index fbf4c63f..5a2e5700 100644 --- a/patroni/dcs/consul.py +++ b/patroni/dcs/consul.py @@ -8,16 +8,19 @@ import ssl import time import urllib3 -from collections import defaultdict, namedtuple +from collections import defaultdict from consul import ConsulException, NotFound, base from http.client import HTTPException from urllib3.exceptions import HTTPError from urllib.parse import urlencode, urlparse, quote +from typing import Any, Callable, Dict, List, Mapping, NamedTuple, Optional, Union, Tuple, TYPE_CHECKING from . import AbstractDCS, Cluster, ClusterConfig, Failover, Leader, Member, SyncState,\ TimelineHistory, ReturnFalseException, catch_return_false_exception, citus_group_re from ..exceptions import DCSError from ..utils import deep_compare, parse_bool, Retry, RetryFailedError, split_host_port, uri, USER_AGENT +if TYPE_CHECKING: # pragma: no cover + from ..config import Config logger = logging.getLogger(__name__) @@ -38,12 +41,17 @@ class InvalidSession(ConsulException): """invalid session""" -Response = namedtuple('Response', 'code,headers,body,content') +class Response(NamedTuple): + code: int + headers: Union[Mapping[str, str], Mapping[bytes, bytes], None] + body: str + content: bytes class HTTPClient(object): - def __init__(self, host='127.0.0.1', port=8500, token=None, scheme='http', verify=True, cert=None, ca_cert=None): + def __init__(self, host: str = '127.0.0.1', port: int = 8500, token: Optional[str] = None, scheme: str = 'http', + verify: bool = True, cert: Optional[str] = None, ca_cert: Optional[str] = None) -> None: self.token = token self._read_timeout = 10 self.base_uri = uri(scheme, (host, port)) @@ -60,22 +68,22 @@ class HTTPClient(object): kwargs['ca_certs'] = ca_cert kwargs['cert_reqs'] = ssl.CERT_REQUIRED if verify or ca_cert else ssl.CERT_NONE self.http = urllib3.PoolManager(num_pools=10, maxsize=10, **kwargs) - self._ttl = None + self._ttl = 30 - def set_read_timeout(self, timeout): + def set_read_timeout(self, timeout: float) -> None: self._read_timeout = timeout / 3.0 @property - def ttl(self): + def ttl(self) -> int: return self._ttl - def set_ttl(self, ttl): + def set_ttl(self, ttl: int) -> bool: ret = self._ttl != ttl self._ttl = ttl return ret @staticmethod - def response(response): + def response(response: urllib3.response.HTTPResponse) -> Response: content = response.data body = content.decode('utf-8') if response.status == 500: @@ -88,14 +96,19 @@ class HTTPClient(object): raise ConsulInternalError(msg) return Response(response.status, response.headers, body, content) - def uri(self, path, params=None): + def uri(self, path: str, + params: Union[None, Dict[str, Any], List[Tuple[str, Any]], Tuple[Tuple[str, Any], ...]] = None) -> str: return '{0}{1}{2}'.format(self.base_uri, path, params and '?' + urlencode(params) or '') - def __getattr__(self, method): + def __getattr__(self, method: str) -> Callable[[Callable[[Response], Union[bool, Any, Tuple[str, Any]]], + str, Union[None, Dict[str, Any], List[Tuple[str, Any]]], + str, Optional[Dict[str, str]]], Union[bool, Any, Tuple[str, Any]]]: if method not in ('get', 'post', 'put', 'delete'): raise AttributeError("HTTPClient instance has no attribute '{0}'".format(method)) - def wrapper(callback, path, params=None, data='', headers=None): + def wrapper(callback: Callable[[Response], Union[bool, Any, Tuple[str, Any]]], path: str, + params: Union[None, Dict[str, Any], List[Tuple[str, Any]]] = None, data: str = '', + headers: Optional[Dict[str, str]] = None) -> Union[bool, Any, Tuple[str, Any]]: # python-consul doesn't allow to specify ttl smaller then 10 seconds # because session_ttl_min defaults to 10s, so we have to do this ugly dirty hack... if method == 'put' and path == '/v1/session/create': @@ -106,7 +119,7 @@ class HTTPClient(object): data = data[:-1] + ', ' + ttl + '}' if isinstance(params, list): # starting from v1.1.0 python-consul switched from `dict` to `list` for params params = {k: v for k, v in params} - kwargs = {'retries': 0, 'preload_content': False, 'body': data} + kwargs: Dict[str, Any] = {'retries': 0, 'preload_content': False, 'body': data} if method == 'get' and isinstance(params, dict) and 'index' in params: timeout = float(params['wait'][:-1]) if 'wait' in params else 300 # According to the documentation a small random amount of additional wait time is added to the @@ -127,13 +140,13 @@ class HTTPClient(object): class ConsulClient(base.Consul): - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: self._cert = kwargs.pop('cert', None) self._ca_cert = kwargs.pop('ca_cert', None) self.token = kwargs.get('token') super(ConsulClient, self).__init__(*args, **kwargs) - def http_connect(self, *args, **kwargs): + def http_connect(self, *args: Any, **kwargs: Any) -> HTTPClient: kwargs.update(dict(zip(['host', 'port', 'scheme', 'verify'], args))) if self._cert: kwargs['cert'] = self._cert @@ -143,17 +156,17 @@ class ConsulClient(base.Consul): kwargs['token'] = self.token return HTTPClient(**kwargs) - def connect(self, *args, **kwargs): + def connect(self, *args: Any, **kwargs: Any) -> HTTPClient: return self.http_connect(*args, **kwargs) - def reload_config(self, config): + def reload_config(self, config: Dict[str, Any]) -> None: self.http.token = self.token = config.get('token') self.consistency = config.get('consistency', 'default') self.dc = config.get('dc') -def catch_consul_errors(func): - def wrapper(*args, **kwargs): +def catch_consul_errors(func: Callable[..., Any]) -> Callable[..., Any]: + def wrapper(*args: Any, **kwargs: Any) -> Any: try: return func(*args, **kwargs) except (RetryFailedError, ConsulException, HTTPException, HTTPError, socket.error, socket.timeout): @@ -161,24 +174,25 @@ def catch_consul_errors(func): return wrapper -def force_if_last_failed(func): - def wrapper(*args, **kwargs): - if wrapper.last_result is False: +def force_if_last_failed(func: Callable[..., Any]) -> Callable[..., Any]: + def wrapper(*args: Any, **kwargs: Any) -> Any: + if getattr(wrapper, 'last_result', None) is False: kwargs['force'] = True - wrapper.last_result = func(*args, **kwargs) - return wrapper.last_result + last_result = func(*args, **kwargs) + setattr(wrapper, 'last_result', last_result) + return last_result - wrapper.last_result = None + setattr(wrapper, 'last_result', None) return wrapper -def service_name_from_scope_name(scope_name): +def service_name_from_scope_name(scope_name: str) -> str: """Translate scope name to service name which can be used in dns. 230 = 253 - len('replica.') - len('.service.consul') """ - def replace_char(match): + def replace_char(match: Any) -> str: c = match.group(0) return '-' if c in '. _' else "u{:04d}".format(ord(c)) @@ -188,7 +202,7 @@ def service_name_from_scope_name(scope_name): class Consul(AbstractDCS): - def __init__(self, config): + def __init__(self, config: Dict[str, Any]) -> None: super(Consul, self).__init__(config) self._base_path = self._base_path[1:] self._scope = config['scope'] @@ -198,9 +212,9 @@ class Consul(AbstractDCS): retry_exceptions=(ConsulInternalError, HTTPException, HTTPError, socket.error, socket.timeout)) - kwargs = {} if 'url' in config: - r = urlparse(config['url']) + url: str = config['url'] + r = urlparse(url) config.update({'scheme': r.scheme, 'host': r.hostname, 'port': r.port or 8500}) elif 'host' in config: host, port = split_host_port(config.get('host', '127.0.0.1:8500'), 8500) @@ -215,7 +229,7 @@ class Consul(AbstractDCS): config['cert'] = (config['cert'], config['key']) config_keys = ('host', 'port', 'token', 'scheme', 'cert', 'ca_cert', 'dc', 'consistency') - kwargs = {p: config.get(p) for p in config_keys if config.get(p)} + kwargs: Dict[str, Any] = {p: config.get(p) for p in config_keys if config.get(p)} verify = config.get('verify') if not isinstance(verify, bool): @@ -240,10 +254,10 @@ class Consul(AbstractDCS): self.create_session() self._previous_loop_token = self._client.token - def retry(self, *args, **kwargs): - return self._retry.copy()(*args, **kwargs) + def retry(self, method: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: + return self._retry.copy()(method, *args, **kwargs) - def create_session(self): + def create_session(self) -> None: while not self._session: try: self.refresh_session() @@ -251,13 +265,14 @@ class Consul(AbstractDCS): logger.info('waiting on consul') time.sleep(5) - def reload_config(self, config): + def reload_config(self, config: Union['Config', Dict[str, Any]]) -> None: super(Consul, self).reload_config(config) consul_config = config.get('consul', {}) self._client.reload_config(consul_config) self._previous_loop_service_tags = self._service_tags - self._service_tags = sorted(consul_config.get('service_tags', [])) + self._service_tags: List[str] = consul_config.get('service_tags', []) + self._service_tags.sort() should_register_service = consul_config.get('register_service', False) if should_register_service and not self._register_service: @@ -266,20 +281,21 @@ class Consul(AbstractDCS): self._previous_loop_register_service = self._register_service self._register_service = should_register_service - def set_ttl(self, ttl): + def set_ttl(self, ttl: int) -> Optional[bool]: if self._client.http.set_ttl(ttl / 2.0): # Consul multiplies the TTL by 2x self._session = None self.__do_not_watch = True + return None @property - def ttl(self): + def ttl(self) -> int: return self._client.http.ttl * 2 # we multiply the value by 2 because it was divided in the `set_ttl()` method - def set_retry_timeout(self, retry_timeout): + def set_retry_timeout(self, retry_timeout: int) -> None: self._retry.deadline = retry_timeout self._client.http.set_read_timeout(retry_timeout) - def adjust_ttl(self): + def adjust_ttl(self) -> None: try: settings = self._client.agent.self() min_ttl = (settings['Config']['SessionTTLMin'] or 10000000000) / 1000000000.0 @@ -288,7 +304,7 @@ class Consul(AbstractDCS): except Exception: logger.exception('adjust_ttl') - def _do_refresh_session(self, force=False): + def _do_refresh_session(self, force: bool = False) -> bool: """:returns: `!True` if it had to create new session""" if not force and self._session and self._last_session_refresh + self._loop_wait > time.time(): return False @@ -312,7 +328,7 @@ class Consul(AbstractDCS): self._last_session_refresh = time.time() return ret - def refresh_session(self): + def refresh_session(self) -> bool: try: return self.retry(self._do_refresh_session) except (ConsulException, RetryFailedError): @@ -320,10 +336,10 @@ class Consul(AbstractDCS): raise ConsulError('Failed to renew/create session') @staticmethod - def member(node): + def member(node: Dict[str, str]) -> Member: return Member.from_node(node['ModifyIndex'], os.path.basename(node['Key']), node.get('Session'), node['Value']) - def _cluster_from_nodes(self, nodes): + def _cluster_from_nodes(self, nodes: Dict[str, Any]) -> Cluster: # get initialize flag initialize = nodes.get(self._INITIALIZE) initialize = initialize and initialize['Value'] @@ -351,7 +367,7 @@ class Consul(AbstractDCS): slots = None try: - last_lsn = int(last_lsn) + last_lsn = int(last_lsn or '') except Exception: last_lsn = 0 @@ -384,7 +400,7 @@ class Consul(AbstractDCS): return Cluster(initialize, config, leader, last_lsn, members, failover, sync, history, slots, failsafe) - def _cluster_loader(self, path): + def _cluster_loader(self, path: str) -> Cluster: _, results = self.retry(self._client.kv.get, path, recurse=True) if results is None: raise NotFound @@ -395,9 +411,9 @@ class Consul(AbstractDCS): return self._cluster_from_nodes(nodes) - def _citus_cluster_loader(self, path): + def _citus_cluster_loader(self, path: str) -> Dict[int, Cluster]: _, results = self.retry(self._client.kv.get, path, recurse=True) - clusters = defaultdict(dict) + clusters: Dict[int, Dict[str, Cluster]] = defaultdict(dict) for node in results or []: key = node['Key'][len(path):].split('/', 1) if len(key) == 2 and citus_group_re.match(key[0]): @@ -405,7 +421,9 @@ class Consul(AbstractDCS): clusters[int(key[0])][key[1]] = node return {group: self._cluster_from_nodes(nodes) for group, nodes in clusters.items()} - def _load_cluster(self, path, loader): + def _load_cluster( + self, path: str, loader: Callable[[str], Union[Cluster, Dict[int, Cluster]]] + ) -> Union[Cluster, Dict[int, Cluster]]: try: return loader(path) except NotFound: @@ -415,7 +433,7 @@ class Consul(AbstractDCS): raise ConsulError('Consul is not responding properly') @catch_consul_errors - def touch_member(self, data): + def touch_member(self, data: Dict[str, Any]) -> bool: cluster = self.cluster member = cluster and cluster.get_member(self._name, fallback_to_leader=False) @@ -447,30 +465,32 @@ class Consul(AbstractDCS): logger.exception('touch_member') return False - def _set_service_name(self): + def _set_service_name(self) -> None: self._service_name = service_name_from_scope_name(self._scope) if self._scope != self._service_name: logger.warning('Using %s as consul service name instead of scope name %s', self._service_name, self._scope) @catch_consul_errors - def register_service(self, service_name, **kwargs): + def register_service(self, service_name: str, **kwargs: Any) -> bool: logger.info('Register service %s, params %s', service_name, kwargs) return self._client.agent.service.register(service_name, **kwargs) @catch_consul_errors - def deregister_service(self, service_id): + def deregister_service(self, service_id: str) -> bool: logger.info('Deregister service %s', service_id) # service_id can contain special characters, but is used as part of uri in deregister request service_id = quote(service_id) return self._client.agent.service.deregister(service_id) - def _update_service(self, data): + def _update_service(self, data: Dict[str, Any]) -> Optional[bool]: service_name = self._service_name role = data['role'].replace('_', '-') state = data['state'] - api_parts = urlparse(data['api_url']) + api_url: str = data['api_url'] + api_parts = urlparse(api_url) api_parts = api_parts._replace(path='/{0}'.format(role)) - conn_parts = urlparse(data['conn_url']) + conn_url: str = data['conn_url'] + conn_parts = urlparse(conn_url) check = base.Check.http(api_parts.geturl(), self._service_check_interval, deregister='{0}s'.format(self._client.http.ttl * 10)) if self._service_check_tls_server_name is not None: @@ -506,7 +526,7 @@ class Consul(AbstractDCS): logger.warning('Could not register service: unknown role type %s', role) @force_if_last_failed - def update_service(self, old_data, new_data, force=False): + def update_service(self, old_data: Dict[str, Any], new_data: Dict[str, Any], force: bool = False) -> Optional[bool]: update = False for key in ['role', 'api_url', 'conn_url', 'state']: @@ -523,7 +543,7 @@ class Consul(AbstractDCS): ): return self._update_service(new_data) - def _do_attempt_to_acquire_leader(self, retry): + def _do_attempt_to_acquire_leader(self, retry: Retry) -> bool: try: return retry(self._client.kv.put, self.leader_path, self._name, acquire=self._session) except InvalidSession: @@ -540,7 +560,7 @@ class Consul(AbstractDCS): return retry(self._client.kv.put, self.leader_path, self._name, acquire=self._session) @catch_return_false_exception - def attempt_to_acquire_leader(self): + def attempt_to_acquire_leader(self) -> bool: retry = self._retry.copy() self._run_and_handle_exceptions(self._do_refresh_session, retry=retry) @@ -554,31 +574,31 @@ class Consul(AbstractDCS): return ret - def take_leader(self): + def take_leader(self) -> bool: return self.attempt_to_acquire_leader() @catch_consul_errors - def set_failover_value(self, value, index=None): + def set_failover_value(self, value: str, index: Optional[int] = None) -> bool: return self._client.kv.put(self.failover_path, value, cas=index) @catch_consul_errors - def set_config_value(self, value, index=None): + def set_config_value(self, value: str, index: Optional[int] = None) -> bool: return self._client.kv.put(self.config_path, value, cas=index) @catch_consul_errors - def _write_leader_optime(self, last_lsn): + def _write_leader_optime(self, last_lsn: str) -> bool: return self._client.kv.put(self.leader_optime_path, last_lsn) @catch_consul_errors - def _write_status(self, value): + def _write_status(self, value: str) -> bool: return self._client.kv.put(self.status_path, value) @catch_consul_errors - def _write_failsafe(self, value): + def _write_failsafe(self, value: str) -> bool: return self._client.kv.put(self.failsafe_path, value) @staticmethod - def _run_and_handle_exceptions(method, *args, **kwargs): + def _run_and_handle_exceptions(method: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: retry = kwargs.pop('retry', None) try: return retry(method, *args, **kwargs) if retry else method(*args, **kwargs) @@ -588,7 +608,7 @@ class Consul(AbstractDCS): raise ReturnFalseException @catch_return_false_exception - def _update_leader(self): + def _update_leader(self) -> bool: retry = self._retry.copy() self._run_and_handle_exceptions(self._do_refresh_session, True, retry=retry) @@ -601,7 +621,7 @@ class Consul(AbstractDCS): if retry.deadline < 1: raise ConsulError('update_leader timeout') logger.warning('Recreating the leader key due to session mismatch') - if cluster.leader: + if cluster and cluster.leader: self._run_and_handle_exceptions(self._client.kv.delete, self.leader_path, cas=cluster.leader.index) retry.deadline = retry.stoptime - time.time() @@ -613,37 +633,39 @@ class Consul(AbstractDCS): return bool(self._session) @catch_consul_errors - def initialize(self, create_new=True, sysid=''): + def initialize(self, create_new: bool = True, sysid: str = '') -> bool: kwargs = {'cas': 0} if create_new else {} return self.retry(self._client.kv.put, self.initialize_path, sysid, **kwargs) @catch_consul_errors - def cancel_initialization(self): + def cancel_initialization(self) -> bool: return self.retry(self._client.kv.delete, self.initialize_path) @catch_consul_errors - def delete_cluster(self): + def delete_cluster(self) -> bool: return self.retry(self._client.kv.delete, self.client_path(''), recurse=True) @catch_consul_errors - def set_history_value(self, value): + def set_history_value(self, value: str) -> bool: return self._client.kv.put(self.history_path, value) @catch_consul_errors - def _delete_leader(self): + def _delete_leader(self) -> bool: cluster = self.cluster - if cluster and isinstance(cluster.leader, Leader) and cluster.leader.name == self._name: + if cluster and isinstance(cluster.leader, Leader) and\ + cluster.leader.name == self._name and isinstance(cluster.leader.index, int): return self._client.kv.delete(self.leader_path, cas=cluster.leader.index) + return True @catch_consul_errors - def set_sync_state_value(self, value, index=None): + def set_sync_state_value(self, value: str, index: Optional[int] = None) -> bool: return self.retry(self._client.kv.put, self.sync_path, value, cas=index) @catch_consul_errors - def delete_sync_state(self, index=None): + def delete_sync_state(self, index: Optional[int] = None) -> bool: return self.retry(self._client.kv.delete, self.sync_path, cas=index) - def watch(self, leader_index, timeout): + def watch(self, leader_index: Optional[int], timeout: float) -> bool: self._last_session_refresh = 0 if self.__do_not_watch: self.__do_not_watch = False diff --git a/patroni/dcs/etcd.py b/patroni/dcs/etcd.py index c736c1c6..c2a83dab 100644 --- a/patroni/dcs/etcd.py +++ b/patroni/dcs/etcd.py @@ -16,7 +16,7 @@ from dns import resolver from http.client import HTTPException from queue import Queue from threading import Thread -from typing import List, Optional +from typing import Any, Callable, Collection, Dict, List, Optional, Union, Tuple, Type, TYPE_CHECKING from urllib.parse import urlparse from urllib3 import Timeout from urllib3.exceptions import HTTPError, ReadTimeoutError, ProtocolError @@ -26,6 +26,8 @@ from . import AbstractDCS, Cluster, ClusterConfig, Failover, Leader, Member, Syn from ..exceptions import DCSError from ..request import get as requests_get from ..utils import Retry, RetryFailedError, split_host_port, uri, USER_AGENT +if TYPE_CHECKING: # pragma: no cover + from ..config import Config logger = logging.getLogger(__name__) @@ -38,18 +40,21 @@ class EtcdError(DCSError): pass +_AddrInfo = Tuple[socket.AddressFamily, socket.SocketKind, int, str, Union[Tuple[str, int], Tuple[str, int, int, int]]] + + class DnsCachingResolver(Thread): - def __init__(self, cache_time=600.0, cache_fail_time=30.0): + def __init__(self, cache_time: float = 600.0, cache_fail_time: float = 30.0) -> None: super(DnsCachingResolver, self).__init__() - self._cache = {} + self._cache: Dict[Tuple[str, int], Tuple[float, List[_AddrInfo]]] = {} self._cache_time = cache_time self._cache_fail_time = cache_fail_time - self._resolve_queue = Queue() + self._resolve_queue: Queue[Tuple[Tuple[str, int], int]] = Queue() self.daemon = True self.start() - def run(self): + def run(self) -> None: while True: (host, port), attempt = self._resolve_queue.get() response = self._do_resolve(host, port) @@ -60,7 +65,7 @@ class DnsCachingResolver(Thread): self.resolve_async(host, port, attempt + 1) time.sleep(1) - def resolve(self, host, port): + def resolve(self, host: str, port: int) -> List[_AddrInfo]: current_time = time.time() cached_time, response = self._cache.get((host, port), (0, [])) time_passed = current_time - cached_time @@ -71,14 +76,14 @@ class DnsCachingResolver(Thread): response = new_response return response - def resolve_async(self, host, port, attempt=0): + def resolve_async(self, host: str, port: int, attempt: int = 0) -> None: self._resolve_queue.put(((host, port), attempt)) - def remove(self, host, port): + def remove(self, host: str, port: int) -> None: self._cache.pop((host, port), None) @staticmethod - def _do_resolve(host, port): + def _do_resolve(host: str, port: int) -> List[_AddrInfo]: try: return socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM, socket.IPPROTO_TCP) except Exception as e: @@ -88,13 +93,15 @@ class DnsCachingResolver(Thread): class AbstractEtcdClientWithFailover(abc.ABC, etcd.Client): - def __init__(self, config, dns_resolver, cache_ttl=300): + ERROR_CLS: Type[Exception] + + def __init__(self, config: Dict[str, Any], dns_resolver: DnsCachingResolver, cache_ttl: int = 300) -> None: self._dns_resolver = dns_resolver self.set_machines_cache_ttl(cache_ttl) self._machines_cache_updated = 0 - args = {p: config.get(p) for p in ('host', 'port', 'protocol', 'use_proxies', 'username', 'password', - 'cert', 'ca_cert') if config.get(p)} - super(AbstractEtcdClientWithFailover, self).__init__(read_timeout=config['retry_timeout'], **args) + kwargs = {p: config.get(p) for p in ('host', 'port', 'protocol', 'use_proxies', + 'username', 'password', 'cert', 'ca_cert') if config.get(p)} + super(AbstractEtcdClientWithFailover, self).__init__(read_timeout=config['retry_timeout'], **kwargs) # For some reason python3-etcd on debian and ubuntu are not based on the latest version # Workaround for the case when https://github.com/jplana/python-etcd/pull/196 is not applied self.http.connection_pool_kw.pop('ssl_version', None) @@ -106,7 +113,7 @@ class AbstractEtcdClientWithFailover(abc.ABC, etcd.Client): self._read_options.add('retry') self._del_conditions.add('retry') - def _calculate_timeouts(self, etcd_nodes, timeout=None): + def _calculate_timeouts(self, etcd_nodes: int, timeout: Optional[float] = None) -> Tuple[int, float, int]: """Calculate a request timeout and number of retries per single etcd node. In case if the timeout per node is too small (less than one second) we will reduce the number of nodes. For the cluster with only one node we will try to do 2 retries. @@ -133,38 +140,39 @@ class AbstractEtcdClientWithFailover(abc.ABC, etcd.Client): return etcd_nodes, per_node_timeout, per_node_retries - 1 - def reload_config(self, config): + def reload_config(self, config: Dict[str, Any]) -> None: self.username = config.get('username') self.password = config.get('password') - def _get_headers(self): + def _get_headers(self) -> Dict[str, str]: basic_auth = ':'.join((self.username, self.password)) if self.username and self.password else None return urllib3.make_headers(basic_auth=basic_auth, user_agent=USER_AGENT) - def _prepare_common_parameters(self, etcd_nodes, timeout=None): - kwargs = {'headers': self._get_headers(), 'redirect': self.allow_redirect, 'preload_content': False} + def _prepare_common_parameters(self, etcd_nodes: int, timeout: Optional[float] = None) -> Dict[str, Any]: + kwargs: Dict[str, Any] = {'headers': self._get_headers(), + 'redirect': self.allow_redirect, 'preload_content': False} if timeout is not None: kwargs.update(retries=0, timeout=timeout) else: _, per_node_timeout, per_node_retries = self._calculate_timeouts(etcd_nodes) - connect_timeout = max(1, per_node_timeout / 2) + connect_timeout = max(1.0, per_node_timeout / 2.0) kwargs.update(timeout=Timeout(connect=connect_timeout, total=per_node_timeout), retries=per_node_retries) return kwargs - def set_machines_cache_ttl(self, cache_ttl): + def set_machines_cache_ttl(self, cache_ttl: int) -> None: self._machines_cache_ttl = cache_ttl @abc.abstractmethod - def _prepare_get_members(self, etcd_nodes): + def _prepare_get_members(self, etcd_nodes: int) -> Dict[str, Any]: """returns: request parameters""" @abc.abstractmethod - def _get_members(self, base_uri, **kwargs): + def _get_members(self, base_uri: str, **kwargs: Any) -> List[str]: """returns: list of clientURLs""" @property - def machines_cache(self): + def machines_cache(self) -> List[str]: base_uri, cache = self._base_uri, self._machines_cache return ([base_uri] if base_uri in cache else []) + [machine for machine in cache if machine != base_uri] @@ -207,10 +215,14 @@ class AbstractEtcdClientWithFailover(abc.ABC, etcd.Client): return self._get_machines_list(self.machines_cache) - def set_read_timeout(self, timeout): + def set_read_timeout(self, timeout: float) -> None: self._read_timeout = timeout - def _do_http_request(self, retry, machines_cache, request_executor, method, path, fields=None, **kwargs): + def _do_http_request(self, retry: Optional[Retry], machines_cache: List[str], + request_executor: Callable[..., urllib3.response.HTTPResponse], + method: str, path: str, fields: Optional[Dict[str, Any]] = None, + **kwargs: Any) -> urllib3.response.HTTPResponse: + is_watch_request = isinstance(fields, dict) and fields.get('wait') == 'true' if fields is not None: kwargs['fields'] = fields some_request_failed = False @@ -233,8 +245,7 @@ class AbstractEtcdClientWithFailover(abc.ABC, etcd.Client): # whether the key didn't received an update or there is a network problem. elif i + 1 < len(machines_cache): self.set_base_uri(machines_cache[i + 1]) - if (isinstance(fields, dict) and fields.get("wait") == "true" - and isinstance(e, (ReadTimeoutError, ProtocolError))): + if is_watch_request and isinstance(e, (ReadTimeoutError, ProtocolError)): logger.debug("Watch timed out.") raise etcd.EtcdWatchTimedOut("Watch timed out: {0}".format(e), cause=e) logger.error("Request to server %s failed: %r", base_uri, e) @@ -246,10 +257,12 @@ class AbstractEtcdClientWithFailover(abc.ABC, etcd.Client): raise etcd.EtcdConnectionFailed('No more machines in the cluster') @abc.abstractmethod - def _prepare_request(self, kwargs, params=None, method=None): + def _prepare_request(self, kwargs: Dict[str, Any], params: Optional[Dict[str, Any]] = None, + method: Optional[str] = None) -> Callable[..., urllib3.response.HTTPResponse]: """returns: request_executor""" - def api_execute(self, path, method, params=None, timeout=None): + def api_execute(self, path: str, method: str, params: Optional[Dict[str, Any]] = None, + timeout: Optional[float] = None) -> Any: retry = params.pop('retry', None) if isinstance(params, dict) else None # Update machines_cache if previous attempt of update has failed @@ -277,6 +290,7 @@ class AbstractEtcdClientWithFailover(abc.ABC, etcd.Client): etcd_nodes = len(machines_cache) except Exception as e: logger.debug('Failed to update list of etcd nodes: %r', e) + assert isinstance(retry, Retry) # etcd.EtcdConnectionFailed is raised only if retry is not None! sleeptime = retry.sleeptime remaining_time = retry.stoptime - sleeptime - time.time() nodes, timeout, retries = self._calculate_timeouts(etcd_nodes, remaining_time) @@ -287,22 +301,22 @@ class AbstractEtcdClientWithFailover(abc.ABC, etcd.Client): retry.sleep_func(sleeptime) retry.update_delay() # We still have some time left. Partially reduce `machines_cache` and retry request - kwargs.update(timeout=Timeout(connect=max(1, timeout / 2), total=timeout), retries=retries) + kwargs.update(timeout=Timeout(connect=max(1.0, timeout / 2.0), total=timeout), retries=retries) machines_cache = machines_cache[:nodes] @staticmethod - def get_srv_record(host): + def get_srv_record(host: str) -> List[Tuple[str, int]]: try: return [(r.target.to_text(True), r.port) for r in resolver.query(host, 'SRV')] except DNSException: return [] - def _get_machines_cache_from_srv(self, srv, srv_suffix=None): + def _get_machines_cache_from_srv(self, srv: str, srv_suffix: Optional[str] = None) -> List[str]: """Fetch list of etcd-cluster member by resolving _etcd-server._tcp. SRV record. This record should contain list of host and peer ports which could be used to run 'GET http://{host}:{port}/members' request (peer protocol)""" - ret = [] + ret: List[str] = [] for r in ['-client-ssl', '-client', '-ssl', '', '-server-ssl', '-server']: r = '{0}-{1}'.format(r, srv_suffix) if srv_suffix else r protocol = 'https' if '-ssl' in r else 'http' @@ -327,15 +341,15 @@ class AbstractEtcdClientWithFailover(abc.ABC, etcd.Client): logger.warning('Can not resolve SRV for %s', srv) return list(set(ret)) - def _get_machines_cache_from_dns(self, host, port): + def _get_machines_cache_from_dns(self, host: str, port: int) -> List[str]: """One host might be resolved into multiple ip addresses. We will make list out of it""" if self.protocol == 'http': - ret = map(lambda res: uri(self.protocol, res[-1][:2]), self._dns_resolver.resolve(host, port)) + ret = [uri(self.protocol, res[-1][:2]) for res in self._dns_resolver.resolve(host, port)] if ret: return list(set(ret)) return [uri(self.protocol, (host, port))] - def _get_machines_cache_from_config(self): + def _get_machines_cache_from_config(self) -> List[str]: if 'proxy' in self._config: return [uri(self.protocol, (self._config['host'], self._config['port']))] @@ -351,13 +365,14 @@ class AbstractEtcdClientWithFailover(abc.ABC, etcd.Client): return machines_cache @staticmethod - def _update_dns_cache(func, machines): + def _update_dns_cache(func: Callable[[str, int], None], machines: List[str]) -> None: for url in machines: r = urlparse(url) - port = r.port or (443 if r.scheme == 'https' else 80) - func(r.hostname, port) + if r.hostname: + port = r.port or (443 if r.scheme == 'https' else 80) + func(r.hostname, port) - def _load_machines_cache(self): + def _load_machines_cache(self) -> bool: """This method should fill up `_machines_cache` from scratch. It could happen only in two cases: 1. During class initialization @@ -417,7 +432,7 @@ class AbstractEtcdClientWithFailover(abc.ABC, etcd.Client): self._machines_cache_updated = time.time() return ret - def set_base_uri(self, value): + def set_base_uri(self, value: str) -> None: if self._base_uri != value: logger.info('Selected new etcd server %s', value) self._base_uri = value @@ -427,22 +442,22 @@ class EtcdClient(AbstractEtcdClientWithFailover): ERROR_CLS = EtcdError - def __del__(self): - if self.http is not None: - try: - self.http.clear() - except (ReferenceError, TypeError, AttributeError): - pass + def __del__(self) -> None: + try: + self.http.clear() + except (ReferenceError, TypeError, AttributeError): + pass - def _prepare_get_members(self, etcd_nodes): + def _prepare_get_members(self, etcd_nodes: int) -> Dict[str, Any]: return self._prepare_common_parameters(etcd_nodes) - def _get_members(self, base_uri, **kwargs): + def _get_members(self, base_uri: str, **kwargs: Any) -> List[str]: response = self.http.request(self._MGET, base_uri + self.version_prefix + '/machines', **kwargs) data = self._handle_server_response(response).data.decode('utf-8') return [m.strip() for m in data.split(',') if m.strip()] - def _prepare_request(self, kwargs, params=None, method=None): + def _prepare_request(self, kwargs: Dict[str, Any], params: Optional[Dict[str, Any]] = None, + method: Optional[str] = None) -> Callable[..., urllib3.response.HTTPResponse]: kwargs['fields'] = params if method in (self._MPOST, self._MPUT): kwargs['encode_multipart'] = False @@ -451,25 +466,32 @@ class EtcdClient(AbstractEtcdClientWithFailover): class AbstractEtcd(AbstractDCS): - def __init__(self, config, client_cls, retry_errors_cls): + def __init__(self, config: Dict[str, Any], client_cls: Type[AbstractEtcdClientWithFailover], + retry_errors_cls: Union[Type[Exception], Tuple[Type[Exception], ...]]) -> None: super(AbstractEtcd, self).__init__(config) self._retry = Retry(deadline=config['retry_timeout'], max_delay=1, max_tries=-1, retry_exceptions=retry_errors_cls) self._ttl = int(config.get('ttl') or 30) - self._client = self.get_etcd_client(config, client_cls) + self._abstract_client = self.get_etcd_client(config, client_cls) self.__do_not_watch = False self._has_failed = False - def reload_config(self, config): + @property + @abc.abstractmethod + def _client(self) -> AbstractEtcdClientWithFailover: + """return correct type of etcd client""" + + def reload_config(self, config: Union['Config', Dict[str, Any]]) -> None: super(AbstractEtcd, self).reload_config(config) self._client.reload_config(config.get(self.__class__.__name__.lower(), {})) - def retry(self, *args, **kwargs): + def retry(self, method: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: retry = self._retry.copy() kwargs['retry'] = retry - return retry(*args, **kwargs) + return retry(method, *args, **kwargs) - def _handle_exception(self, e, name='', do_sleep=False, raise_ex=None): + def _handle_exception(self, e: Exception, name: str = '', do_sleep: bool = False, + raise_ex: Optional[Exception] = None) -> None: if not self._has_failed: logger.exception(name) else: @@ -480,7 +502,7 @@ class AbstractEtcd(AbstractDCS): if isinstance(raise_ex, Exception): raise raise_ex - def _run_and_handle_exceptions(self, method, *args, **kwargs): + def _run_and_handle_exceptions(self, method: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: retry = kwargs.pop('retry', self.retry) try: return retry(method, *args, **kwargs) if retry else method(*args, **kwargs) @@ -492,19 +514,20 @@ class AbstractEtcd(AbstractDCS): except Exception as e: self._handle_exception(e, raise_ex=self._client.ERROR_CLS('unexpected error')) - @staticmethod - def set_socket_options(sock, socket_options): + def set_socket_options(self, sock: socket.socket, + socket_options: Optional[Collection[Tuple[int, int, int]]]) -> None: if socket_options: for opt in socket_options: sock.setsockopt(*opt) - def get_etcd_client(self, config, client_cls): + def get_etcd_client(self, config: Dict[str, Any], + client_cls: Type[AbstractEtcdClientWithFailover]) -> AbstractEtcdClientWithFailover: config = deepcopy(config) if 'proxy' in config: config['use_proxies'] = True config['url'] = config['proxy'] - if 'url' in config: + if 'url' in config and isinstance(config['url'], str): r = urlparse(config['url']) config.update({'protocol': r.scheme, 'host': r.hostname, 'port': r.port or 2379, 'username': r.username, 'password': r.password}) @@ -516,10 +539,11 @@ class AbstractEtcd(AbstractDCS): if isinstance(hosts, str): hosts = hosts.split(',') - config['hosts'] = [] + config_hosts: List[str] = [] for value in hosts: if isinstance(value, str): - config['hosts'].append(uri(protocol, split_host_port(value.strip(), default_port))) + config_hosts.append(uri(protocol, split_host_port(value.strip(), default_port))) + config['hosts'] = config_hosts elif 'host' in config: host, port = split_host_port(config['host'], 2379) config['host'] = host @@ -538,8 +562,10 @@ class AbstractEtcd(AbstractDCS): dns_resolver = DnsCachingResolver() - def create_connection_patched(address, timeout=socket._GLOBAL_DEFAULT_TIMEOUT, - source_address=None, socket_options=None): + def create_connection_patched( + address: Tuple[str, int], timeout: Any = object(), + source_address: Optional[Any] = None, socket_options: Optional[Collection[Tuple[int, int, int]]] = None + ) -> socket.socket: host, port = address if host.startswith('['): host = host.strip('[]') @@ -549,7 +575,7 @@ class AbstractEtcd(AbstractDCS): try: sock = socket.socket(af, socktype, proto) self.set_socket_options(sock, socket_options) - if timeout is not socket._GLOBAL_DEFAULT_TIMEOUT: + if timeout is None or isinstance(timeout, (float, int)): sock.settimeout(timeout) if source_address: sock.bind(source_address) @@ -580,7 +606,7 @@ class AbstractEtcd(AbstractDCS): time.sleep(5) return client - def set_ttl(self, ttl): + def set_ttl(self, ttl: int) -> Optional[bool]: ttl = int(ttl) ret = self._ttl != ttl self._ttl = ttl @@ -588,16 +614,16 @@ class AbstractEtcd(AbstractDCS): return ret @property - def ttl(self): + def ttl(self) -> int: return self._ttl - def set_retry_timeout(self, retry_timeout): + def set_retry_timeout(self, retry_timeout: int) -> None: self._retry.deadline = retry_timeout self._client.set_read_timeout(retry_timeout) -def catch_etcd_errors(func): - def wrapper(self, *args, **kwargs): +def catch_etcd_errors(func: Callable[..., Any]) -> Any: + def wrapper(self: AbstractEtcd, *args: Any, **kwargs: Any) -> Any: try: retval = func(self, *args, **kwargs) is not None self._has_failed = False @@ -613,18 +639,24 @@ def catch_etcd_errors(func): class Etcd(AbstractEtcd): - def __init__(self, config): + def __init__(self, config: Dict[str, Any]) -> None: super(Etcd, self).__init__(config, EtcdClient, (etcd.EtcdLeaderElectionInProgress, EtcdRaftInternal)) self.__do_not_watch = False - def set_ttl(self, ttl): + @property + def _client(self) -> EtcdClient: + assert isinstance(self._abstract_client, EtcdClient) + return self._abstract_client + + def set_ttl(self, ttl: int) -> Optional[bool]: self.__do_not_watch = super(Etcd, self).set_ttl(ttl) + return None @staticmethod - def member(node): + def member(node: etcd.EtcdResult) -> Member: return Member.from_node(node.modifiedIndex, os.path.basename(node.key), node.ttl, node.value) - def _cluster_from_nodes(self, etcd_index, nodes): + def _cluster_from_nodes(self, etcd_index: int, nodes: Dict[str, etcd.EtcdResult]) -> Cluster: # get initialize flag initialize = nodes.get(self._INITIALIZE) initialize = initialize and initialize.value @@ -652,7 +684,7 @@ class Etcd(AbstractEtcd): slots = None try: - last_lsn = int(last_lsn) + last_lsn = int(last_lsn or '') except Exception: last_lsn = 0 @@ -685,13 +717,13 @@ class Etcd(AbstractEtcd): return Cluster(initialize, config, leader, last_lsn, members, failover, sync, history, slots, failsafe) - def _cluster_loader(self, path): + def _cluster_loader(self, path: str) -> Cluster: result = self.retry(self._client.read, path, recursive=True) nodes = {node.key[len(result.key):].lstrip('/'): node for node in result.leaves} return self._cluster_from_nodes(result.etcd_index, nodes) - def _citus_cluster_loader(self, path): - clusters = defaultdict(dict) + def _citus_cluster_loader(self, path: str) -> Dict[int, Cluster]: + clusters: Dict[int, Dict[str, etcd.EtcdResult]] = defaultdict(dict) result = self.retry(self._client.read, path, recursive=True) for node in result.leaves: key = node.key[len(result.key):].lstrip('/').split('/', 1) @@ -699,7 +731,9 @@ class Etcd(AbstractEtcd): clusters[int(key[0])][key[1]] = node return {group: self._cluster_from_nodes(result.etcd_index, nodes) for group, nodes in clusters.items()} - def _load_cluster(self, path, loader): + def _load_cluster( + self, path: str, loader: Callable[[str], Union[Cluster, Dict[int, Cluster]]] + ) -> Union[Cluster, Dict[int, Cluster]]: cluster = None try: cluster = loader(path) @@ -708,18 +742,19 @@ class Etcd(AbstractEtcd): except Exception as e: self._handle_exception(e, 'get_cluster', raise_ex=EtcdError('Etcd is not responding properly')) self._has_failed = False + assert cluster is not None return cluster @catch_etcd_errors - def touch_member(self, data): - data = json.dumps(data, separators=(',', ':')) - return self._client.set(self.member_path, data, self._ttl) + def touch_member(self, data: Dict[str, Any]) -> bool: + value = json.dumps(data, separators=(',', ':')) + return bool(self._client.set(self.member_path, value, self._ttl)) @catch_etcd_errors - def take_leader(self): + def take_leader(self) -> bool: return self.retry(self._client.write, self.leader_path, self._name, ttl=self._ttl) - def _do_attempt_to_acquire_leader(self): + def _do_attempt_to_acquire_leader(self) -> bool: try: return bool(self.retry(self._client.write, self.leader_path, self._name, ttl=self._ttl, prevExist=False)) except etcd.EtcdAlreadyExist: @@ -727,26 +762,26 @@ class Etcd(AbstractEtcd): return False @catch_return_false_exception - def attempt_to_acquire_leader(self): + def attempt_to_acquire_leader(self) -> bool: return self._run_and_handle_exceptions(self._do_attempt_to_acquire_leader, retry=None) @catch_etcd_errors - def set_failover_value(self, value, index=None): - return self._client.write(self.failover_path, value, prevIndex=index or 0) + def set_failover_value(self, value: str, index: Optional[int] = None) -> bool: + return bool(self._client.write(self.failover_path, value, prevIndex=index or 0)) @catch_etcd_errors - def set_config_value(self, value, index=None): - return self._client.write(self.config_path, value, prevIndex=index or 0) + def set_config_value(self, value: str, index: Optional[int] = None) -> bool: + return bool(self._client.write(self.config_path, value, prevIndex=index or 0)) @catch_etcd_errors - def _write_leader_optime(self, last_lsn): - return self._client.set(self.leader_optime_path, last_lsn) + def _write_leader_optime(self, last_lsn: str) -> bool: + return bool(self._client.set(self.leader_optime_path, last_lsn)) @catch_etcd_errors - def _write_status(self, value): - return self._client.set(self.status_path, value) + def _write_status(self, value: str) -> bool: + return bool(self._client.set(self.status_path, value)) - def _do_update_leader(self): + def _do_update_leader(self) -> bool: try: return self.retry(self._client.write, self.leader_path, self._name, prevValue=self._name, ttl=self._ttl) is not None @@ -754,42 +789,42 @@ class Etcd(AbstractEtcd): return self._do_attempt_to_acquire_leader() @catch_etcd_errors - def _write_failsafe(self, value): - return self._client.set(self.failsafe_path, value) + def _write_failsafe(self, value: str) -> bool: + return bool(self._client.set(self.failsafe_path, value)) @catch_return_false_exception - def _update_leader(self): - return self._run_and_handle_exceptions(self._do_update_leader, retry=None) + def _update_leader(self) -> bool: + return bool(self._run_and_handle_exceptions(self._do_update_leader, retry=None)) @catch_etcd_errors - def initialize(self, create_new=True, sysid=""): - return self.retry(self._client.write, self.initialize_path, sysid, prevExist=(not create_new)) + def initialize(self, create_new: bool = True, sysid: str = "") -> bool: + return bool(self.retry(self._client.write, self.initialize_path, sysid, prevExist=(not create_new))) @catch_etcd_errors - def _delete_leader(self): - return self._client.delete(self.leader_path, prevValue=self._name) + def _delete_leader(self) -> bool: + return bool(self._client.delete(self.leader_path, prevValue=self._name)) @catch_etcd_errors - def cancel_initialization(self): - return self.retry(self._client.delete, self.initialize_path) + def cancel_initialization(self) -> bool: + return bool(self.retry(self._client.delete, self.initialize_path)) @catch_etcd_errors - def delete_cluster(self): - return self.retry(self._client.delete, self.client_path(''), recursive=True) + def delete_cluster(self) -> bool: + return bool(self.retry(self._client.delete, self.client_path(''), recursive=True)) @catch_etcd_errors - def set_history_value(self, value): - return self._client.write(self.history_path, value) + def set_history_value(self, value: str) -> bool: + return bool(self._client.write(self.history_path, value)) @catch_etcd_errors - def set_sync_state_value(self, value, index=None): - return self.retry(self._client.write, self.sync_path, value, prevIndex=index or 0) + def set_sync_state_value(self, value: str, index: Optional[int] = None) -> bool: + return bool(self.retry(self._client.write, self.sync_path, value, prevIndex=index or 0)) @catch_etcd_errors - def delete_sync_state(self, index=None): - return self.retry(self._client.delete, self.sync_path, prevIndex=index or 0) + def delete_sync_state(self, index: Optional[int] = None) -> bool: + return bool(self.retry(self._client.delete, self.sync_path, prevIndex=index or 0)) - def watch(self, leader_index, timeout): + def watch(self, leader_index: Optional[int], timeout: float) -> bool: if self.__do_not_watch: self.__do_not_watch = False return True diff --git a/patroni/dcs/etcd3.py b/patroni/dcs/etcd3.py index b17e16fc..b526e1e2 100644 --- a/patroni/dcs/etcd3.py +++ b/patroni/dcs/etcd3.py @@ -10,12 +10,14 @@ import time import urllib3 from collections import defaultdict -from threading import Condition, Lock, Thread +from enum import IntEnum from urllib3.exceptions import ReadTimeoutError, ProtocolError +from threading import Condition, Lock, Thread +from typing import Any, Callable, Collection, Dict, Iterator, List, Optional, Tuple, Type, Union from . import ClusterConfig, Cluster, Failover, Leader, Member, SyncState,\ TimelineHistory, ReturnFalseException, catch_return_false_exception, citus_group_re -from .etcd import AbstractEtcdClientWithFailover, AbstractEtcd, catch_etcd_errors +from .etcd import AbstractEtcdClientWithFailover, AbstractEtcd, catch_etcd_errors, DnsCachingResolver, Retry from ..exceptions import DCSError, PatroniException from ..utils import deep_compare, enable_keepalive, iter_response_objects, RetryFailedError, USER_AGENT @@ -31,11 +33,27 @@ class UnsupportedEtcdVersion(PatroniException): # google.golang.org/grpc/codes -GRPCCode = type('Enum', (), {'OK': 0, 'Canceled': 1, 'Unknown': 2, 'InvalidArgument': 3, 'DeadlineExceeded': 4, - 'NotFound': 5, 'AlreadyExists': 6, 'PermissionDenied': 7, 'ResourceExhausted': 8, - 'FailedPrecondition': 9, 'Aborted': 10, 'OutOfRange': 11, 'Unimplemented': 12, - 'Internal': 13, 'Unavailable': 14, 'DataLoss': 15, 'Unauthenticated': 16}) -GRPCcodeToText = {v: k for k, v in GRPCCode.__dict__.items() if not k.startswith('__') and isinstance(v, int)} +class GRPCCode(IntEnum): + OK = 0 + Canceled = 1 + Unknown = 2 + InvalidArgument = 3 + DeadlineExceeded = 4 + NotFound = 5 + AlreadyExists = 6 + PermissionDenied = 7 + ResourceExhausted = 8 + FailedPrecondition = 9 + Aborted = 10 + OutOfRange = 11 + Unimplemented = 12 + Internal = 13 + Unavailable = 14 + DataLoss = 15 + Unauthenticated = 16 + + +GRPCcodeToText: Dict[int, str] = {v: k for k, v in GRPCCode.__dict__['_member_map_'].items()} class Etcd3Exception(etcd.EtcdException): @@ -44,22 +62,24 @@ class Etcd3Exception(etcd.EtcdException): class Etcd3ClientError(Etcd3Exception): - def __init__(self, code=None, error=None, status=None): + def __init__(self, code: Optional[int] = None, error: Optional[str] = None, status: Optional[int] = None) -> None: if not hasattr(self, 'error'): self.error = error and error.strip() - self.codeText = GRPCcodeToText.get(code) + self.codeText = GRPCcodeToText.get(code) if code is not None else None self.status = status - def __repr__(self): - return "<{0} error: '{1}', code: {2}>".format(self.__class__.__name__, self.error, self.code) + def __repr__(self) -> str: + return "<{0} error: '{1}', code: {2}>"\ + .format(self.__class__.__name__, getattr(self, 'error', None), getattr(self, 'code', None)) __str__ = __repr__ - def as_dict(self): - return {'error': self.error, 'code': self.code, 'codeText': self.codeText, 'status': self.status} + def as_dict(self) -> Dict[str, Any]: + return {'error': getattr(self, 'error', None), 'code': getattr(self, 'code', None), + 'codeText': self.codeText, 'status': self.status} @classmethod - def get_subclasses(cls): + def get_subclasses(cls) -> Iterator[Type['Etcd3ClientError']]: for subclass in cls.__subclasses__(): for subsubclass in subclass.get_subclasses(): yield subsubclass @@ -118,61 +138,93 @@ class InvalidAuthToken(Etcd3ClientError): error = "etcdserver: invalid auth token" -errStringToClientError = {s.error: s for s in Etcd3ClientError.get_subclasses() if hasattr(s, 'error')} -errCodeToClientError = {s.code: s for s in Etcd3ClientError.__subclasses__()} +errStringToClientError = {getattr(s, 'error'): s for s in Etcd3ClientError.get_subclasses() if hasattr(s, 'error')} +errCodeToClientError = {getattr(s, 'code'): s for s in Etcd3ClientError.__subclasses__()} -def _raise_for_data(data, status_code=None): +def _raise_for_data(data: Union[bytes, str, Dict[str, Union[Any, Dict[str, Any]]]], + status_code: Optional[int] = None) -> Etcd3ClientError: try: - error = data.get('error') or data.get('Error') - if isinstance(error, dict): # streaming response - status_code = error.get('http_code') - code = error['grpc_code'] - error = error['message'] + assert isinstance(data, dict) + data_error: Optional[Dict[str, Any]] = data.get('error') or data.get('Error') + if isinstance(data_error, dict): # streaming response + status_code = data_error.get('http_code') + code: Optional[int] = data_error['grpc_code'] + error: str = data_error['message'] else: - code = data.get('code') or data.get('Code') + data_code = data.get('code') or data.get('Code') + assert not isinstance(data_code, dict) + code = data_code + error = str(data_error) except Exception: error = str(data) code = GRPCCode.Unknown err = errStringToClientError.get(error) or errCodeToClientError.get(code) or Unknown - raise err(code, error, status_code) + return err(code, error, status_code) -def to_bytes(v): +def to_bytes(v: Union[str, bytes]) -> bytes: return v if isinstance(v, bytes) else v.encode('utf-8') -def prefix_range_end(v): - v = bytearray(to_bytes(v)) - for i in range(len(v) - 1, -1, -1): - if v[i] < 0xff: - v[i] += 1 +def prefix_range_end(v: str) -> bytes: + ret = bytearray(to_bytes(v)) + for i in range(len(ret) - 1, -1, -1): + if ret[i] < 0xff: + ret[i] += 1 break - return bytes(v) + return bytes(ret) -def base64_encode(v): +def base64_encode(v: Union[str, bytes]) -> str: return base64.b64encode(to_bytes(v)).decode('utf-8') -def base64_decode(v): +def base64_decode(v: str) -> str: return base64.b64decode(v).decode('utf-8') -def build_range_request(key, range_end=None): +def build_range_request(key: str, range_end: Union[bytes, str, None] = None) -> Dict[str, Any]: fields = {'key': base64_encode(key)} if range_end: fields['range_end'] = base64_encode(range_end) return fields +def _handle_auth_errors(func: Callable[..., Any]) -> Any: + def wrapper(self: 'Etcd3Client', *args: Any, **kwargs: Any) -> Any: + def retry(ex: Exception) -> Any: + if self.username and self.password: + self.authenticate() + return func(self, *args, **kwargs) + else: + logger.fatal('Username or password not set, authentication is not possible') + raise ex + + try: + return func(self, *args, **kwargs) + except (UserEmpty, PermissionDenied) as e: # no token provided + # PermissionDenied is raised on 3.0 and 3.1 + if self._cluster_version < (3, 3) and (not isinstance(e, PermissionDenied) + or self._cluster_version < (3, 2)): + raise UnsupportedEtcdVersion('Authentication is required by Etcd cluster but not ' + 'supported on version lower than 3.3.0. Cluster version: ' + '{0}'.format('.'.join(map(str, self._cluster_version)))) + return retry(e) + except InvalidAuthToken as e: + logger.error('Invalid auth token: %s', self._token) + return retry(e) + + return wrapper + + class Etcd3Client(AbstractEtcdClientWithFailover): ERROR_CLS = Etcd3Error - def __init__(self, config, dns_resolver, cache_ttl=300): + def __init__(self, config: Dict[str, Any], dns_resolver: DnsCachingResolver, cache_ttl: int = 300) -> None: self._token = None - self._cluster_version = None + self._cluster_version: Tuple[int] = tuple() self.version_prefix = '/v3beta' super(Etcd3Client, self).__init__(config, dns_resolver, cache_ttl) @@ -182,32 +234,33 @@ class Etcd3Client(AbstractEtcdClientWithFailover): logger.fatal('Etcd3 authentication failed: %r', e) sys.exit(1) - def _get_headers(self): + def _get_headers(self) -> Dict[str, str]: headers = urllib3.make_headers(user_agent=USER_AGENT) if self._token and self._cluster_version >= (3, 3, 0): headers['authorization'] = self._token return headers - def _prepare_request(self, kwargs, params=None, method=None): + def _prepare_request(self, kwargs: Dict[str, Any], params: Optional[Dict[str, Any]] = None, + method: Optional[str] = None) -> Callable[..., urllib3.response.HTTPResponse]: if params is not None: kwargs['body'] = json.dumps(params) kwargs['headers']['Content-Type'] = 'application/json' return self.http.urlopen - @staticmethod - def _handle_server_response(response): - data = response.data + def _handle_server_response(self, response: urllib3.response.HTTPResponse) -> Dict[str, Any]: + data: Union[bytes, str] = response.data try: data = data.decode('utf-8') - data = json.loads(data) + ret: Dict[str, Any] = json.loads(data) + if response.status < 400: + return ret except (TypeError, ValueError, UnicodeError) as e: if response.status < 400: raise etcd.EtcdException('Server response was not valid JSON: %r' % e) - if response.status < 400: - return data - _raise_for_data(data, response.status) + ret = {} + raise _raise_for_data(ret or data, response.status) - def _ensure_version_prefix(self, base_uri, **kwargs): + def _ensure_version_prefix(self, base_uri: str, **kwargs: Any) -> None: if self.version_prefix != '/v3': response = self.http.urlopen(self._MGET, base_uri + '/version', **kwargs) response = self._handle_server_response(response) @@ -234,87 +287,64 @@ class Etcd3Client(AbstractEtcdClientWithFailover): else: self.version_prefix = '/v3' - def _prepare_get_members(self, etcd_nodes): + def _prepare_get_members(self, etcd_nodes: int) -> Dict[str, Any]: kwargs = self._prepare_common_parameters(etcd_nodes) self._prepare_request(kwargs, {}) return kwargs - def _get_members(self, base_uri, **kwargs): + def _get_members(self, base_uri: str, **kwargs: Any) -> List[str]: self._ensure_version_prefix(base_uri, **kwargs) resp = self.http.urlopen(self._MPOST, base_uri + self.version_prefix + '/cluster/member/list', **kwargs) members = self._handle_server_response(resp)['members'] - return set(url for member in members for url in member.get('clientURLs', [])) + return [url for member in members for url in member.get('clientURLs', [])] - def call_rpc(self, method, fields, retry=None): + def call_rpc(self, method: str, fields: Dict[str, Any], retry: Optional[Retry] = None) -> Dict[str, Any]: fields['retry'] = retry return self.api_execute(self.version_prefix + method, self._MPOST, fields) - def authenticate(self): - if self._use_proxies and self._cluster_version is None: + def authenticate(self) -> bool: + if self._use_proxies and not self._cluster_version: kwargs = self._prepare_common_parameters(1) self._ensure_version_prefix(self._base_uri, **kwargs) - if self._cluster_version >= (3, 3) and self.username and self.password: - logger.info('Trying to authenticate on Etcd...') - old_token, self._token = self._token, None - try: - response = self.call_rpc('/auth/authenticate', {'name': self.username, 'password': self.password}) - except AuthNotEnabled: - logger.info('Etcd authentication is not enabled') - self._token = None - except Exception: - self._token = old_token - raise - else: - self._token = response.get('token') - return old_token != self._token - - def _handle_auth_errors(func): - def wrapper(self, *args, **kwargs): - def retry(ex): - if self.username and self.password: - self.authenticate() - return func(self, *args, **kwargs) - else: - logger.fatal('Username or password not set, authentication is not possible') - raise ex - - try: - return func(self, *args, **kwargs) - except (UserEmpty, PermissionDenied) as e: # no token provided - # PermissionDenied is raised on 3.0 and 3.1 - if self._cluster_version < (3, 3) and (not isinstance(e, PermissionDenied) - or self._cluster_version < (3, 2)): - raise UnsupportedEtcdVersion('Authentication is required by Etcd cluster but not ' - 'supported on version lower than 3.3.0. Cluster version: ' - '{0}'.format('.'.join(map(str, self._cluster_version)))) - return retry(e) - except InvalidAuthToken as e: - logger.error('Invalid auth token: %s', self._token) - return retry(e) - - return wrapper + if not (self._cluster_version >= (3, 3) and self.username and self.password): + return False + logger.info('Trying to authenticate on Etcd...') + old_token, self._token = self._token, None + try: + response = self.call_rpc('/auth/authenticate', {'name': self.username, 'password': self.password}) + except AuthNotEnabled: + logger.info('Etcd authentication is not enabled') + self._token = None + except Exception: + self._token = old_token + raise + else: + self._token = response.get('token') + return old_token != self._token @_handle_auth_errors - def range(self, key, range_end=None, retry=None): + def range(self, key: str, range_end: Union[bytes, str, None] = None, + retry: Optional[Retry] = None) -> Dict[str, Any]: params = build_range_request(key, range_end) params['serializable'] = True # For better performance. We can tolerate stale reads. return self.call_rpc('/kv/range', params, retry) - def prefix(self, key, retry=None): + def prefix(self, key: str, retry: Optional[Retry] = None) -> Dict[str, Any]: return self.range(key, prefix_range_end(key), retry) @_handle_auth_errors - def lease_grant(self, ttl, retry=None): + def lease_grant(self, ttl: int, retry: Optional[Retry] = None) -> str: return self.call_rpc('/lease/grant', {'TTL': ttl}, retry)['ID'] - def lease_keepalive(self, ID, retry=None): + def lease_keepalive(self, ID: str, retry: Optional[Retry] = None) -> Optional[str]: return self.call_rpc('/lease/keepalive', {'ID': ID}, retry).get('result', {}).get('TTL') - def txn(self, compare, success, retry=None): - return self.call_rpc('/kv/txn', {'compare': [compare], 'success': [success]}, retry).get('succeeded') + def txn(self, compare: Dict[str, Any], success: Dict[str, Any], retry: Optional[Retry] = None) -> Dict[str, Any]: + return self.call_rpc('/kv/txn', {'compare': [compare], 'success': [success]}, retry).get('succeeded', {}) @_handle_auth_errors - def put(self, key, value, lease=None, create_revision=None, mod_revision=None, retry=None): + def put(self, key: str, value: str, lease: Optional[str] = None, create_revision: Optional[str] = None, + mod_revision: Optional[str] = None, retry: Optional[Retry] = None) -> Dict[str, Any]: fields = {'key': base64_encode(key), 'value': base64_encode(value)} if lease: fields['lease'] = lease @@ -328,17 +358,20 @@ class Etcd3Client(AbstractEtcdClientWithFailover): return self.txn(compare, {'request_put': fields}, retry) @_handle_auth_errors - def deleterange(self, key, range_end=None, mod_revision=None, retry=None): + def deleterange(self, key: str, range_end: Union[bytes, str, None] = None, + mod_revision: Optional[str] = None, retry: Optional[Retry] = None) -> Dict[str, Any]: fields = build_range_request(key, range_end) if mod_revision is None: return self.call_rpc('/kv/deleterange', fields, retry) compare = {'target': 'MOD', 'mod_revision': mod_revision, 'key': fields['key']} return self.txn(compare, {'request_delete_range': fields}, retry) - def deleteprefix(self, key, retry=None): + def deleteprefix(self, key: str, retry: Optional[Retry] = None) -> Dict[str, Any]: return self.deleterange(key, prefix_range_end(key), retry=retry) - def watchrange(self, key, range_end=None, start_revision=None, filters=None, read_timeout=None): + def watchrange(self, key: str, range_end: Union[bytes, str, None] = None, + start_revision: Optional[str] = None, filters: Optional[List[Dict[str, Any]]] = None, + read_timeout: Optional[float] = None) -> urllib3.response.HTTPResponse: """returns: response object""" params = build_range_request(key, range_end) if start_revision is not None: @@ -349,14 +382,16 @@ class Etcd3Client(AbstractEtcdClientWithFailover): kwargs.update(timeout=urllib3.Timeout(connect=kwargs['timeout'], read=read_timeout), retries=0) return request_executor(self._MPOST, self._base_uri + self.version_prefix + '/watch', **kwargs) - def watchprefix(self, key, start_revision=None, filters=None, read_timeout=None): + def watchprefix(self, key: str, start_revision: Optional[str] = None, + filters: Optional[List[Dict[str, Any]]] = None, + read_timeout: Optional[float] = None) -> urllib3.response.HTTPResponse: return self.watchrange(key, prefix_range_end(key), start_revision, filters, read_timeout) class KVCache(Thread): - def __init__(self, dcs, client): - Thread.__init__(self) + def __init__(self, dcs: 'Etcd3', client: 'PatroniEtcd3Client') -> None: + super(KVCache, self).__init__() self.daemon = True self._dcs = dcs self._client = client @@ -373,32 +408,32 @@ class KVCache(Thread): self._object_cache_lock = Lock() self.start() - def set(self, value, overwrite=False): + def set(self, value: Dict[str, Any], overwrite: bool = False) -> Tuple[bool, Optional[Dict[str, Any]]]: with self._object_cache_lock: name = value['key'] old_value = self._object_cache.get(name) ret = not old_value or int(old_value['mod_revision']) < int(value['mod_revision']) - if ret or overwrite and old_value['mod_revision'] == value['mod_revision']: + if ret or overwrite and old_value and old_value['mod_revision'] == value['mod_revision']: self._object_cache[name] = value return ret, old_value - def delete(self, name, mod_revision): + def delete(self, name: str, mod_revision: str) -> Tuple[bool, Optional[Dict[str, Any]]]: with self._object_cache_lock: old_value = self._object_cache.get(name) ret = old_value and int(old_value['mod_revision']) < int(mod_revision) if ret: del self._object_cache[name] - return not old_value or ret, old_value + return bool(not old_value or ret), old_value - def copy(self): + def copy(self) -> List[Dict[str, Any]]: with self._object_cache_lock: return [v.copy() for v in self._object_cache.values()] - def get(self, name): + def get(self, name: str) -> Optional[Dict[str, Any]]: with self._object_cache_lock: return self._object_cache.get(name) - def _process_event(self, event): + def _process_event(self, event: Dict[str, Any]) -> None: kv = event['kv'] key = kv['key'] if event.get('type') == 'DELETE': @@ -422,21 +457,22 @@ class KVCache(Thread): or (self.get(self._leader_key) or {}).get('value') != self._name): self._dcs.event.set() - def _process_message(self, message): + def _process_message(self, message: Dict[str, Any]) -> None: logger.debug('Received message: %s', message) if 'error' in message: - _raise_for_data(message) - for event in message.get('result', {}).get('events', []): + raise _raise_for_data(message) + events: List[Dict[str, Any]] = message.get('result', {}).get('events', []) + for event in events: self._process_event(event) @staticmethod - def _finish_response(response): + def _finish_response(response: urllib3.response.HTTPResponse) -> None: try: response.close() finally: response.release_conn() - def _do_watch(self, revision): + def _do_watch(self, revision: str) -> None: with self._response_lock: self._response = None # We do most of requests with timeouts. The only exception /watch requests to Etcd v3. @@ -457,7 +493,7 @@ class KVCache(Thread): for message in iter_response_objects(response): self._process_message(message) - def _build_cache(self): + def _build_cache(self) -> None: result = self._dcs.retry(self._client.prefix, self._dcs.cluster_prefix) with self._object_cache_lock: self._object_cache = {node['key']: node for node in result.get('kvs', [])} @@ -476,10 +512,10 @@ class KVCache(Thread): self._is_ready = False with self._response_lock: response, self._response = self._response, None - if response: + if isinstance(response, urllib3.response.HTTPResponse): self._finish_response(response) - def run(self): + def run(self) -> None: while True: try: self._build_cache() @@ -487,12 +523,12 @@ class KVCache(Thread): logger.error('KVCache.run %r', e) time.sleep(1) - def kill_stream(self): + def kill_stream(self) -> None: sock = None with self._response_lock: - if self._response: + if isinstance(self._response, urllib3.response.HTTPResponse): try: - sock = self._response.connection.sock + sock = self._response.connection.sock if self._response.connection else None except Exception: sock = None else: @@ -504,48 +540,48 @@ class KVCache(Thread): except Exception as e: logger.debug('Error on socket.shutdown: %r', e) - def is_ready(self): + def is_ready(self) -> bool: """Must be called only when holding the lock on `condition`""" return self._is_ready class PatroniEtcd3Client(Etcd3Client): - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: self._kv_cache = None super(PatroniEtcd3Client, self).__init__(*args, **kwargs) - def configure(self, etcd3): + def configure(self, etcd3: 'Etcd3') -> None: self._etcd3 = etcd3 - def start_watcher(self): + def start_watcher(self) -> None: if self._cluster_version >= (3, 1): self._kv_cache = KVCache(self._etcd3, self) - def _restart_watcher(self): + def _restart_watcher(self) -> None: if self._kv_cache: self._kv_cache.kill_stream() - def set_base_uri(self, value): + def set_base_uri(self, value: str) -> None: super(PatroniEtcd3Client, self).set_base_uri(value) self._restart_watcher() - def authenticate(self): + def authenticate(self) -> bool: ret = super(PatroniEtcd3Client, self).authenticate() if ret: self._restart_watcher() return ret - def _wait_cache(self, timeout): + def _wait_cache(self, timeout: float) -> None: stop_time = time.time() + timeout - while not self._kv_cache.is_ready(): + while self._kv_cache and not self._kv_cache.is_ready(): timeout = stop_time - time.time() if timeout <= 0: raise RetryFailedError('Exceeded retry deadline') self._kv_cache.condition.wait(timeout) - def get_cluster(self, path): - if self._kv_cache and path.startswith(self._etcd3.cluster_prefix): + def get_cluster(self, path: str) -> List[Dict[str, Any]]: + if self._kv_cache and self._etcd3._retry.deadline is not None and path.startswith(self._etcd3.cluster_prefix): with self._kv_cache.condition: self._wait_cache(self._etcd3._retry.deadline) ret = self._kv_cache.copy() @@ -557,7 +593,7 @@ class PatroniEtcd3Client(Etcd3Client): 'lease': node.get('lease')}) return ret - def call_rpc(self, method, fields, retry=None): + def call_rpc(self, method: str, fields: Dict[str, Any], retry: Optional[Retry] = None) -> Dict[str, Any]: ret = super(PatroniEtcd3Client, self).call_rpc(method, fields, retry) if self._kv_cache: @@ -582,8 +618,9 @@ class PatroniEtcd3Client(Etcd3Client): class Etcd3(AbstractEtcd): - def __init__(self, config): + def __init__(self, config: Dict[str, Any]) -> None: super(Etcd3, self).__init__(config, PatroniEtcd3Client, (DeadlineExceeded, Unavailable, FailedPrecondition)) + assert isinstance(self._client, PatroniEtcd3Client) self.__do_not_watch = False self._lease = None self._last_lease_refresh = 0 @@ -593,15 +630,23 @@ class Etcd3(AbstractEtcd): self._client.start_watcher() self.create_lease() - def set_socket_options(self, sock, socket_options): + @property + def _client(self) -> PatroniEtcd3Client: + assert isinstance(self._abstract_client, PatroniEtcd3Client) + return self._abstract_client + + def set_socket_options(self, sock: socket.socket, + socket_options: Optional[Collection[Tuple[int, int, int]]]) -> None: + assert self._retry.deadline is not None enable_keepalive(sock, self.ttl, int(self.loop_wait + self._retry.deadline)) - def set_ttl(self, ttl): + def set_ttl(self, ttl: int) -> Optional[bool]: self.__do_not_watch = super(Etcd3, self).set_ttl(ttl) if self.__do_not_watch: self._lease = None + return None - def _do_refresh_lease(self, force=False, retry=None): + def _do_refresh_lease(self, force: bool = False, retry: Optional[Retry] = None) -> bool: if not force and self._lease and self._last_lease_refresh + self._loop_wait > time.time(): return False @@ -615,14 +660,14 @@ class Etcd3(AbstractEtcd): self._last_lease_refresh = time.time() return ret - def refresh_lease(self): + def refresh_lease(self) -> bool: try: return self.retry(self._do_refresh_lease) except (Etcd3ClientError, RetryFailedError): logger.exception('refresh_lease') raise Etcd3Error('Failed to keepalive/grant lease') - def create_lease(self): + def create_lease(self) -> None: while not self._lease: try: self.refresh_lease() @@ -631,14 +676,14 @@ class Etcd3(AbstractEtcd): time.sleep(5) @property - def cluster_prefix(self): + def cluster_prefix(self) -> str: return self._base_path + '/' if self.is_citus_coordinator() else self.client_path('') @staticmethod - def member(node): + def member(node: Dict[str, str]) -> Member: return Member.from_node(node['mod_revision'], os.path.basename(node['key']), node['lease'], node['value']) - def _cluster_from_nodes(self, nodes): + def _cluster_from_nodes(self, nodes: Dict[str, Any]) -> Cluster: # get initialize flag initialize = nodes.get(self._INITIALIZE) initialize = initialize and initialize['value'] @@ -666,7 +711,7 @@ class Etcd3(AbstractEtcd): slots = None try: - last_lsn = int(last_lsn) + last_lsn = int(last_lsn or '') except Exception: last_lsn = 0 @@ -701,14 +746,14 @@ class Etcd3(AbstractEtcd): return Cluster(initialize, config, leader, last_lsn, members, failover, sync, history, slots, failsafe) - def _cluster_loader(self, path): + def _cluster_loader(self, path: str) -> Cluster: nodes = {node['key'][len(path):]: node for node in self._client.get_cluster(path) if node['key'].startswith(path)} return self._cluster_from_nodes(nodes) - def _citus_cluster_loader(self, path): - clusters = defaultdict(dict) + def _citus_cluster_loader(self, path: str) -> Dict[int, Cluster]: + clusters: Dict[int, Dict[str, Dict[str, Any]]] = defaultdict(dict) path = self._base_path + '/' for node in self._client.get_cluster(path): key = node['key'][len(path):].split('/', 1) @@ -716,7 +761,9 @@ class Etcd3(AbstractEtcd): clusters[int(key[0])][key[1]] = node return {group: self._cluster_from_nodes(nodes) for group, nodes in clusters.items()} - def _load_cluster(self, path, loader): + def _load_cluster( + self, path: str, loader: Callable[[str], Union[Cluster, Dict[int, Cluster]]] + ) -> Union[Cluster, Dict[int, Cluster]]: cluster = None try: cluster = loader(path) @@ -725,10 +772,11 @@ class Etcd3(AbstractEtcd): except Exception as e: self._handle_exception(e, 'get_cluster', raise_ex=Etcd3Error('Etcd is not responding properly')) self._has_failed = False + assert cluster is not None return cluster @catch_etcd_errors - def touch_member(self, data): + def touch_member(self, data: Dict[str, Any]) -> bool: try: self.refresh_lease() except Etcd3Error: @@ -740,19 +788,20 @@ class Etcd3(AbstractEtcd): if member and member.session == self._lease and deep_compare(data, member.data): return True - data = json.dumps(data, separators=(',', ':')) + value = json.dumps(data, separators=(',', ':')) try: - return self._client.put(self.member_path, data, self._lease) + return bool(self._client.put(self.member_path, value, self._lease)) except LeaseNotFound: self._lease = None logger.error('Our lease disappeared from Etcd, can not "touch_member"') + return False @catch_etcd_errors - def take_leader(self): + def take_leader(self) -> bool: return self.retry(self._client.put, self.leader_path, self._name, self._lease) - def _do_attempt_to_acquire_leader(self, retry): - def _retry(*args, **kwargs): + def _do_attempt_to_acquire_leader(self, retry: Retry) -> bool: + def _retry(*args: Any, **kwargs: Any) -> Any: kwargs['retry'] = retry return retry(*args, **kwargs) @@ -772,10 +821,10 @@ class Etcd3(AbstractEtcd): return _retry(self._client.put, self.leader_path, self._name, self._lease, 0) @catch_return_false_exception - def attempt_to_acquire_leader(self): + def attempt_to_acquire_leader(self) -> bool: retry = self._retry.copy() - def _retry(*args, **kwargs): + def _retry(*args: Any, **kwargs: Any) -> Any: kwargs['retry'] = retry return retry(*args, **kwargs) @@ -791,30 +840,30 @@ class Etcd3(AbstractEtcd): return ret @catch_etcd_errors - def set_failover_value(self, value, index=None): - return self._client.put(self.failover_path, value, mod_revision=index) + def set_failover_value(self, value: str, index: Optional[str] = None) -> bool: + return bool(self._client.put(self.failover_path, value, mod_revision=index)) @catch_etcd_errors - def set_config_value(self, value, index=None): - return self._client.put(self.config_path, value, mod_revision=index) + def set_config_value(self, value: str, index: Optional[str] = None) -> bool: + return bool(self._client.put(self.config_path, value, mod_revision=index)) @catch_etcd_errors - def _write_leader_optime(self, last_lsn): - return self._client.put(self.leader_optime_path, last_lsn) + def _write_leader_optime(self, last_lsn: str) -> bool: + return bool(self._client.put(self.leader_optime_path, last_lsn)) @catch_etcd_errors - def _write_status(self, value): - return self._client.put(self.status_path, value) + def _write_status(self, value: str) -> bool: + return bool(self._client.put(self.status_path, value)) @catch_etcd_errors - def _write_failsafe(self, value): - return self._client.put(self.failsafe_path, value) + def _write_failsafe(self, value: str) -> bool: + return bool(self._client.put(self.failsafe_path, value)) @catch_return_false_exception - def _update_leader(self): + def _update_leader(self) -> bool: retry = self._retry.copy() - def _retry(*args, **kwargs): + def _retry(*args: Any, **kwargs: Any) -> Any: kwargs['retry'] = retry return retry(*args, **kwargs) @@ -836,36 +885,37 @@ class Etcd3(AbstractEtcd): return bool(self._lease) @catch_etcd_errors - def initialize(self, create_new=True, sysid=""): + def initialize(self, create_new: bool = True, sysid: str = ""): return self.retry(self._client.put, self.initialize_path, sysid, None, 0 if create_new else None) @catch_etcd_errors - def _delete_leader(self): + def _delete_leader(self) -> bool: cluster = self.cluster if cluster and isinstance(cluster.leader, Leader) and cluster.leader.name == self._name: return self._client.deleterange(self.leader_path, mod_revision=cluster.leader.index) + return True @catch_etcd_errors - def cancel_initialization(self): + def cancel_initialization(self) -> bool: return self.retry(self._client.deleterange, self.initialize_path) @catch_etcd_errors - def delete_cluster(self): + def delete_cluster(self) -> bool: return self.retry(self._client.deleteprefix, self.client_path('')) @catch_etcd_errors - def set_history_value(self, value): - return self._client.put(self.history_path, value) + def set_history_value(self, value: str) -> bool: + return bool(self._client.put(self.history_path, value)) @catch_etcd_errors - def set_sync_state_value(self, value, index=None): + def set_sync_state_value(self, value: str, index: Optional[str] = None) -> bool: return self.retry(self._client.put, self.sync_path, value, mod_revision=index) @catch_etcd_errors - def delete_sync_state(self, index=None): + def delete_sync_state(self, index: Optional[str] = None) -> bool: return self.retry(self._client.deleterange, self.sync_path, mod_revision=index) - def watch(self, leader_index, timeout): + def watch(self, leader_index: Optional[str], timeout: float) -> bool: if self.__do_not_watch: self.__do_not_watch = False return True diff --git a/patroni/dcs/exhibitor.py b/patroni/dcs/exhibitor.py index cca14a58..2b06073b 100644 --- a/patroni/dcs/exhibitor.py +++ b/patroni/dcs/exhibitor.py @@ -3,9 +3,12 @@ import logging import random import time -from patroni.dcs.zookeeper import ZooKeeper -from patroni.request import get as requests_get -from patroni.utils import uri +from typing import Any, Callable, Dict, List, Union + +from . import Cluster +from .zookeeper import ZooKeeper +from ..request import get as requests_get +from ..utils import uri logger = logging.getLogger(__name__) @@ -14,11 +17,12 @@ class ExhibitorEnsembleProvider(object): TIMEOUT = 3.1 - def __init__(self, hosts, port, uri_path='/exhibitor/v1/cluster/list', poll_interval=300): + def __init__(self, hosts: List[str], port: int, + uri_path: str = '/exhibitor/v1/cluster/list', poll_interval: int = 300) -> None: self._exhibitor_port = port self._uri_path = uri_path self._poll_interval = poll_interval - self._exhibitors = hosts + self._exhibitors: List[str] = hosts self._boot_exhibitors = hosts self._zookeeper_hosts = '' self._next_poll = None @@ -26,7 +30,7 @@ class ExhibitorEnsembleProvider(object): logger.info('waiting on exhibitor') time.sleep(5) - def poll(self): + def poll(self) -> bool: if self._next_poll and self._next_poll > time.time(): return False @@ -36,7 +40,8 @@ class ExhibitorEnsembleProvider(object): if isinstance(json, dict) and 'servers' in json and 'port' in json: self._next_poll = time.time() + self._poll_interval - zookeeper_hosts = ','.join([h + ':' + str(json['port']) for h in sorted(json['servers'])]) + servers: List[str] = json['servers'] + zookeeper_hosts = ','.join([h + ':' + str(json['port']) for h in sorted(servers)]) if self._zookeeper_hosts != zookeeper_hosts: logger.info('ZooKeeper connection string has changed: %s => %s', self._zookeeper_hosts, zookeeper_hosts) self._zookeeper_hosts = zookeeper_hosts @@ -44,7 +49,7 @@ class ExhibitorEnsembleProvider(object): return True return False - def _query_exhibitors(self, exhibitors): + def _query_exhibitors(self, exhibitors: List[str]) -> Union[Dict[str, Any], Any]: random.shuffle(exhibitors) for host in exhibitors: try: @@ -55,18 +60,20 @@ class ExhibitorEnsembleProvider(object): return None @property - def zookeeper_hosts(self): + def zookeeper_hosts(self) -> str: return self._zookeeper_hosts class Exhibitor(ZooKeeper): - def __init__(self, config): + def __init__(self, config: Dict[str, Any]) -> None: interval = config.get('poll_interval', 300) self._ensemble_provider = ExhibitorEnsembleProvider(config['hosts'], config['port'], poll_interval=interval) super(Exhibitor, self).__init__({**config, 'hosts': self._ensemble_provider.zookeeper_hosts}) - def _load_cluster(self, path, loader): + def _load_cluster( + self, path: str, loader: Callable[[str], Union[Cluster, Dict[int, Cluster]]] + ) -> Union[Cluster, Dict[int, Cluster]]: if self._ensemble_provider.poll(): self._client.set_hosts(self._ensemble_provider.zookeeper_hosts) return super(Exhibitor, self)._load_cluster(path, loader) diff --git a/patroni/dcs/kubernetes.py b/patroni/dcs/kubernetes.py index 981980f7..61cebc10 100644 --- a/patroni/dcs/kubernetes.py +++ b/patroni/dcs/kubernetes.py @@ -15,15 +15,17 @@ import yaml from collections import defaultdict from copy import deepcopy from http.client import HTTPException -from threading import Condition, Lock, Thread -from typing import Any, Collection, Dict, List, Optional, Union from urllib3.exceptions import HTTPError +from threading import Condition, Lock, Thread +from typing import Any, Callable, Collection, Dict, List, Optional, Tuple, Type, Union, TYPE_CHECKING from . import AbstractDCS, Cluster, ClusterConfig, Failover, Leader, Member, SyncState,\ TimelineHistory, CITUS_COORDINATOR_GROUP_ID, citus_group_re from ..exceptions import DCSError from ..utils import deep_compare, iter_response_objects, keepalive_socket_options,\ Retry, RetryFailedError, tzutc, uri, USER_AGENT +if TYPE_CHECKING: # pragma: no cover + from ..config import Config logger = logging.getLogger(__name__) @@ -32,14 +34,14 @@ SERVICE_HOST_ENV_NAME = 'KUBERNETES_SERVICE_HOST' SERVICE_PORT_ENV_NAME = 'KUBERNETES_SERVICE_PORT' SERVICE_TOKEN_FILENAME = '/var/run/secrets/kubernetes.io/serviceaccount/token' SERVICE_CERT_FILENAME = '/var/run/secrets/kubernetes.io/serviceaccount/ca.crt' -__temp_files = [] +__temp_files: List[str] = [] class KubernetesError(DCSError): pass -def _cleanup_temp_files(): +def _cleanup_temp_files() -> None: global __temp_files for temp_file in __temp_files: try: @@ -49,7 +51,7 @@ def _cleanup_temp_files(): __temp_files = [] -def _create_temp_file(content): +def _create_temp_file(content: bytes) -> str: if len(__temp_files) == 0: atexit.register(_cleanup_temp_files) @@ -61,7 +63,7 @@ def _create_temp_file(content): # this function does the same mapping of snake_case => camelCase for > 97% of cases as autogenerated swagger code -def to_camel_case(value): +def to_camel_case(value: str) -> str: reserved = {'api', 'apiv3', 'cidr', 'cpu', 'csi', 'id', 'io', 'ip', 'ipc', 'pid', 'tls', 'uri', 'url', 'uuid'} words = value.split('_') return words[0] + ''.join(w.upper() if w in reserved else w.title() for w in words[1:]) @@ -72,20 +74,21 @@ class K8sConfig(object): class ConfigException(Exception): pass - def __init__(self): - self.pool_config = {'maxsize': 10, 'num_pools': 10} # configuration for urllib3.PoolManager + def __init__(self) -> None: + self.pool_config: Dict[str, Union[str, int]] = {'maxsize': 10, 'num_pools': 10} # urllib3.PoolManager config self._token_expires_at = datetime.datetime.max + self._headers: Dict[str, str] = {} self._make_headers() - def _set_token(self, token): + def _set_token(self, token: str) -> None: self._headers['authorization'] = 'Bearer ' + token - def _make_headers(self, token=None, **kwargs): + def _make_headers(self, token: Optional[str] = None, **kwargs: Any) -> None: self._headers = urllib3.make_headers(user_agent=USER_AGENT, **kwargs) if token: self._set_token(token) - def _read_token_file(self): + def _read_token_file(self) -> str: if not os.path.isfile(SERVICE_TOKEN_FILENAME): raise self.ConfigException('Service token file does not exists.') with open(SERVICE_TOKEN_FILENAME) as f: @@ -95,8 +98,8 @@ class K8sConfig(object): self._token_expires_at = datetime.datetime.now() + self._token_refresh_interval return token - def load_incluster_config(self, ca_certs=SERVICE_CERT_FILENAME, - token_refresh_interval=datetime.timedelta(minutes=1)): + def load_incluster_config(self, ca_certs: str = SERVICE_CERT_FILENAME, + token_refresh_interval: datetime.timedelta = datetime.timedelta(minutes=1)) -> None: if SERVICE_HOST_ENV_NAME not in os.environ or SERVICE_PORT_ENV_NAME not in os.environ: raise self.ConfigException('Service host/port is not set.') if not os.environ[SERVICE_HOST_ENV_NAME] or not os.environ[SERVICE_PORT_ENV_NAME]: @@ -114,25 +117,29 @@ class K8sConfig(object): self._server = uri('https', (os.environ[SERVICE_HOST_ENV_NAME], os.environ[SERVICE_PORT_ENV_NAME])) @staticmethod - def _get_by_name(config, section, name): + def _get_by_name(config: Dict[str, List[Dict[str, Any]]], section: str, name: str) -> Optional[Dict[str, Any]]: for c in config[section + 's']: if c['name'] == name: return c[section] - def _pool_config_from_file_or_data(self, config, file_key_name, pool_key_name): + def _pool_config_from_file_or_data(self, config: Dict[str, str], file_key_name: str, pool_key_name: str) -> None: data_key_name = file_key_name + '-data' if data_key_name in config: self.pool_config[pool_key_name] = _create_temp_file(base64.b64decode(config[data_key_name])) elif file_key_name in config: self.pool_config[pool_key_name] = config[file_key_name] - def load_kube_config(self, context=None): + def load_kube_config(self, context: Optional[str] = None) -> None: with open(os.path.expanduser(KUBE_CONFIG_DEFAULT_LOCATION)) as f: - config = yaml.safe_load(f) + config: Dict[str, Any] = yaml.safe_load(f) - context = self._get_by_name(config, 'context', context or config['current-context']) - cluster = self._get_by_name(config, 'cluster', context['cluster']) - user = self._get_by_name(config, 'user', context['user']) + context = context or config['current-context'] + context_value = self._get_by_name(config, 'context', context) + assert isinstance(context_value, dict) + cluster = self._get_by_name(config, 'cluster', context_value['cluster']) + assert isinstance(cluster, dict) + user = self._get_by_name(config, 'user', context_value['user']) + assert isinstance(user, dict) self._server = cluster['server'].rstrip('/') if self._server.startswith('https'): @@ -143,14 +150,14 @@ class K8sConfig(object): if user.get('token'): self._make_headers(token=user['token']) elif 'username' in user and 'password' in user: - self._headers = self._make_headers(basic_auth=':'.join((user['username'], user['password']))) + self._make_headers(basic_auth=':'.join((user['username'], user['password']))) @property - def server(self): + def server(self) -> str: return self._server @property - def headers(self): + def headers(self) -> Dict[str, str]: if self._token_expires_at <= datetime.datetime.now(): try: self._set_token(self._read_token_file()) @@ -161,30 +168,32 @@ class K8sConfig(object): class K8sObject(object): - def __init__(self, kwargs): + def __init__(self, kwargs: Dict[str, Any]) -> None: self._dict = {k: self._wrap(k, v) for k, v in kwargs.items()} - def get(self, name, default=None): + def get(self, name: str, default: Optional[Any] = None) -> Optional[Any]: return self._dict.get(name, default) - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: return self.get(to_camel_case(name)) @classmethod - def _wrap(cls, parent, value): + def _wrap(cls, parent: Optional[str], value: Any) -> Any: if isinstance(value, dict): + data_dict: Dict[str, Any] = value # we know that `annotations` and `labels` are dicts and therefore don't want to convert them into K8sObject - return value if parent in {'annotations', 'labels'} and \ - all(isinstance(v, str) for v in value.values()) else cls(value) + return data_dict if parent in {'annotations', 'labels'} and \ + all(isinstance(v, str) for v in data_dict.values()) else cls(data_dict) elif isinstance(value, list): - return [cls._wrap(None, v) for v in value] + data_list: List[Any] = value + return [cls._wrap(None, v) for v in data_list] else: return value - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: return self._dict - def __repr__(self): + def __repr__(self) -> str: return json.dumps(self, indent=4, default=lambda o: o.to_dict()) @@ -201,13 +210,14 @@ class K8sClient(object): class rest(object): class ApiException(Exception): - def __init__(self, status=None, reason=None, http_resp=None): + def __init__(self, status: Optional[int] = None, reason: Optional[str] = None, + http_resp: Optional[urllib3.HTTPResponse] = None) -> None: self.status = http_resp.status if http_resp else status self.reason = http_resp.reason if http_resp else reason self.body = http_resp.data if http_resp else None - self.headers = http_resp.getheaders() if http_resp else None + self.headers = http_resp.headers if http_resp else None - def __str__(self): + def __str__(self) -> str: error_message = "({0})\nReason: {1}\n".format(self.status, self.reason) if self.headers: error_message += "HTTP response headers: {0}\n".format(self.headers) @@ -232,44 +242,46 @@ class K8sClient(object): except K8sException: pass - def set_read_timeout(self, timeout): + def set_read_timeout(self, timeout: Union[int, float]) -> None: self._read_timeout = timeout - def set_api_servers_cache_ttl(self, ttl): + def set_api_servers_cache_ttl(self, ttl: int) -> None: self._api_servers_cache_ttl = ttl - 0.5 - def set_base_uri(self, value): + def set_base_uri(self, value: str) -> None: logger.info('Selected new K8s API server endpoint %s', value) # We will connect by IP of the K8s master node which is not listed as alternative name self.pool_manager.connection_pool_kw['assert_hostname'] = False self._base_uri = value @staticmethod - def _handle_server_response(response, _preload_content): + def _handle_server_response(response: urllib3.HTTPResponse, + _preload_content: bool) -> Union[urllib3.HTTPResponse, K8sObject]: if response.status not in range(200, 206): raise k8s_client.rest.ApiException(http_resp=response) return K8sObject(json.loads(response.data.decode('utf-8'))) if _preload_content else response @staticmethod - def _make_headers(headers): + def _make_headers(headers: Optional[Dict[str, str]]) -> Dict[str, str]: ret = k8s_config.headers ret.update(headers or {}) return ret @property - def api_servers_cache(self): + def api_servers_cache(self) -> List[str]: base_uri, cache = self._base_uri, self._api_servers_cache return ([base_uri] if base_uri in cache else []) + [machine for machine in cache if machine != base_uri] - def _get_api_servers(self, api_servers_cache): + def _get_api_servers(self, api_servers_cache: List[str]) -> List[str]: _, per_node_timeout, per_node_retries = self._calculate_timeouts(len(api_servers_cache)) kwargs = {'headers': self._make_headers({}), 'preload_content': True, 'retries': per_node_retries, - 'timeout': urllib3.Timeout(connect=max(1, per_node_timeout / 2.0), total=per_node_timeout)} + 'timeout': urllib3.Timeout(connect=max(1.0, per_node_timeout / 2.0), total=per_node_timeout)} path = self._API_URL_PREFIX + 'default/endpoints/kubernetes' for base_uri in api_servers_cache: try: response = self.pool_manager.request('GET', base_uri + path, **kwargs) endpoint = self._handle_server_response(response, True) + assert isinstance(endpoint, K8sObject) for subset in endpoint.subsets: for port in subset.ports: if port.name == 'https' and port.protocol == 'TCP': @@ -284,7 +296,7 @@ class K8sClient(object): logger.error('Failed to get "kubernetes" endpoint from %s: %r', base_uri, e) raise K8sConnectionFailed('No more K8s API server nodes in the cluster') - def _refresh_api_servers_cache(self, updating_cache=False): + def _refresh_api_servers_cache(self, updating_cache: Optional[bool] = False) -> None: if self._bypass_api_service: try: api_servers_cache = [k8s_config.server] if updating_cache else self.api_servers_cache @@ -309,16 +321,16 @@ class K8sClient(object): self.set_base_uri(self._api_servers_cache[0]) self._api_servers_cache_updated = time.time() - def refresh_api_servers_cache(self): + def refresh_api_servers_cache(self) -> None: if self._bypass_api_service and time.time() - self._api_servers_cache_updated > self._api_servers_cache_ttl: self._refresh_api_servers_cache() - def _load_api_servers_cache(self): + def _load_api_servers_cache(self) -> None: self._update_api_servers_cache = True self._refresh_api_servers_cache(True) self._update_api_servers_cache = False - def _calculate_timeouts(self, api_servers, timeout=None): + def _calculate_timeouts(self, api_servers: int, timeout: Optional[float] = None) -> Tuple[int, float, int]: """Calculate a request timeout and number of retries per single K8s API server node. In case if the timeout per node is too small (less than one second) we will reduce the number of nodes. For the cluster with only one API server node we will try to do 1 retry. @@ -344,7 +356,8 @@ class K8sClient(object): return api_servers, per_node_timeout, per_node_retries - 1 - def _do_http_request(self, retry, api_servers_cache, method, path, **kwargs): + def _do_http_request(self, retry: Optional[Retry], api_servers_cache: List[str], + method: str, path: str, **kwargs: Any) -> urllib3.HTTPResponse: some_request_failed = False for i, base_uri in enumerate(api_servers_cache): if i > 0: @@ -367,7 +380,10 @@ class K8sClient(object): raise K8sConnectionFailed('No more API server nodes in the cluster') - def request(self, retry, method, path, timeout=None, **kwargs): + def request( + self, retry: Optional[Retry], method: str, path: str, + timeout: Union[int, float, Tuple[Union[int, float], Union[int, float]], urllib3.Timeout, None] = None, + **kwargs: Any) -> urllib3.HTTPResponse: if self._update_api_servers_cache: self._load_api_servers_cache() @@ -382,7 +398,7 @@ class K8sClient(object): retries = 0 else: _, timeout, retries = self._calculate_timeouts(api_servers) - timeout = urllib3.Timeout(connect=max(1, timeout / 2.0), total=timeout) + timeout = urllib3.Timeout(connect=max(1.0, timeout / 2.0), total=timeout) kwargs.update(retries=retries, timeout=timeout) while True: @@ -396,6 +412,7 @@ class K8sClient(object): except Exception as e: logger.debug('Failed to update list of K8s master nodes: %r', e) + assert isinstance(retry, Retry) # K8sConnectionFailed is raised only if retry is not None! sleeptime = retry.sleeptime remaining_time = (retry.stoptime or time.time()) - sleeptime - time.time() nodes, timeout, retries = self._calculate_timeouts(api_servers, remaining_time) @@ -405,12 +422,13 @@ class K8sClient(object): retry.sleep_func(sleeptime) retry.update_delay() # We still have some time left. Partially reduce `api_servers_cache` and retry request - kwargs.update(timeout=urllib3.Timeout(connect=max(1, timeout / 2.0), total=timeout), + kwargs.update(timeout=urllib3.Timeout(connect=max(1.0, timeout / 2.0), total=timeout), retries=retries) api_servers_cache = api_servers_cache[:nodes] - def call_api(self, method, path, headers=None, body=None, _retry=None, - _preload_content=True, _request_timeout=None, **kwargs): + def call_api(self, method: str, path: str, headers: Optional[Dict[str, str]] = None, + body: Optional[Any] = None, _retry: Optional[Retry] = None, _preload_content: bool = True, + _request_timeout: Optional[float] = None, **kwargs: Any) -> Union[urllib3.HTTPResponse, K8sObject]: headers = self._make_headers(headers) fields = {to_camel_case(k): v for k, v in kwargs.items()} # resource_version => resourceVersion body = json.dumps(body, default=lambda o: o.to_dict()) if body is not None else None @@ -422,14 +440,15 @@ class K8sClient(object): class CoreV1Api(object): - def __init__(self, api_client=None): + def __init__(self, api_client: Optional['K8sClient.ApiClient'] = None) -> None: self._api_client = api_client or k8s_client.ApiClient() - def __getattr__(self, func): # `func` name pattern: (action)_namespaced_(kind) + def __getattr__(self, func: str) -> Callable[..., Any]: + # `func` name pattern: (action)_namespaced_(kind) action, kind = func.split('_namespaced_') # (read|list|create|patch|replace|delete|delete_collection) kind = kind.replace('_', '') + ('s' * int(kind[-1] != 's')) # plural, single word - def wrapper(*args, **kwargs): + def wrapper(*args: Any, **kwargs: Any) -> Union[urllib3.HTTPResponse, K8sObject]: method = {'read': 'GET', 'list': 'GET', 'create': 'POST', 'replace': 'PUT'}.get(action, action.split('_')[0]).upper() @@ -454,14 +473,14 @@ class K8sClient(object): class _K8sObjectTemplate(K8sObject): """The template for objects which we create locally, e.g. k8s_client.V1ObjectMeta & co""" - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any) -> None: self._dict = {to_camel_case(k): v for k, v in kwargs.items()} - def __init__(self): - self.__cls_cache = {} + def __init__(self) -> None: + self.__cls_cache: Dict[str, Type['K8sClient._K8sObjectTemplate']] = {} self.__cls_lock = Lock() - def __getattr__(self, name): + def __getattr__(self, name: str) -> Type['K8sClient._K8sObjectTemplate']: with self.__cls_lock: if name not in self.__cls_cache: self.__cls_cache[name] = type(name, (self._K8sObjectTemplate,), {}) @@ -474,15 +493,15 @@ k8s_config = K8sConfig() class KubernetesRetriableException(k8s_client.rest.ApiException): - def __init__(self, orig): + def __init__(self, orig: K8sClient.rest.ApiException) -> None: super(KubernetesRetriableException, self).__init__(orig.status, orig.reason) self.body = orig.body self.headers = orig.headers @property - def sleeptime(self): + def sleeptime(self) -> Optional[int]: try: - return int(self.headers['retry-after']) + return int((self.headers or {}).get('retry-after', '')) except Exception: return None @@ -498,7 +517,7 @@ class CoreV1ApiProxy(object): self._use_endpoints = bool(use_endpoints) self._retriable_http_codes = set(self._DEFAULT_RETRIABLE_HTTP_CODES) - def configure_timeouts(self, loop_wait, retry_timeout, ttl): + def configure_timeouts(self, loop_wait: int, retry_timeout: Union[int, float], ttl: int) -> None: # Normally every loop_wait seconds we should have receive something from the socket. # If we didn't received anything after the loop_wait + retry_timeout it is a time # to start worrying (send keepalive messages). Finally, the connection should be @@ -511,10 +530,10 @@ class CoreV1ApiProxy(object): def configure_retriable_http_codes(self, retriable_http_codes: List[int]) -> None: self._retriable_http_codes = self._DEFAULT_RETRIABLE_HTTP_CODES | set(retriable_http_codes) - def refresh_api_servers_cache(self): + def refresh_api_servers_cache(self) -> None: self._api_client.refresh_api_servers_cache() - def __getattr__(self, func: str): + def __getattr__(self, func: str) -> Callable[..., Any]: """Intercepts calls to `CoreV1Api` methods. Handles two important cases: @@ -526,7 +545,7 @@ class CoreV1ApiProxy(object): if func.endswith('_kind'): func = func[:-4] + ('endpoints' if self._use_endpoints else 'config_map') - def wrapper(*args, **kwargs): + def wrapper(*args: Any, **kwargs: Any) -> Any: try: return getattr(self._core_v1_api, func)(*args, **kwargs) except k8s_client.rest.ApiException as e: @@ -536,12 +555,12 @@ class CoreV1ApiProxy(object): return wrapper @property - def use_endpoints(self): + def use_endpoints(self) -> bool: return self._use_endpoints -def catch_kubernetes_errors(func): - def wrapper(self, *args, **kwargs): +def catch_kubernetes_errors(func: Callable[..., Any]) -> Callable[..., Any]: + def wrapper(self: 'Kubernetes', *args: Any, **kwargs: Any) -> Any: try: return self._run_and_handle_exceptions(func, self, *args, **kwargs) except KubernetesError: @@ -551,8 +570,9 @@ def catch_kubernetes_errors(func): class ObjectCache(Thread): - def __init__(self, dcs, func, retry, condition, name=None): - Thread.__init__(self) + def __init__(self, dcs: 'Kubernetes', func: Callable[..., Any], retry: Retry, + condition: Condition, name: Optional[str] = None) -> None: + super(ObjectCache, self).__init__() self.daemon = True self._dcs = dcs self._func = func @@ -560,25 +580,25 @@ class ObjectCache(Thread): self._condition = condition self._name = name # name of this pod self._is_ready = False - self._response = None # needs to be accessible from the `kill_stream()` method + self._response: Union[urllib3.HTTPResponse, bool, None] = None # needs to be accessible from the `kill_stream` self._response_lock = Lock() # protect the `self._response` from concurrent access - self._object_cache = {} + self._object_cache: Dict[str, K8sObject] = {} self._object_cache_lock = Lock() self._annotations_map = {self._dcs.leader_path: self._dcs._LEADER, self._dcs.config_path: self._dcs._CONFIG} self.start() - def _list(self): + def _list(self) -> K8sObject: try: return self._func(_retry=self._retry.copy()) except Exception: time.sleep(1) raise - def _watch(self, resource_version): + def _watch(self, resource_version: str) -> urllib3.HTTPResponse: return self._func(_request_timeout=(self._retry.deadline, urllib3.Timeout.DEFAULT_TIMEOUT), _preload_content=False, watch=True, resource_version=resource_version) - def set(self, name, value): + def set(self, name: str, value: K8sObject) -> Tuple[bool, Optional[K8sObject]]: with self._object_cache_lock: old_value = self._object_cache.get(name) ret = not old_value or int(old_value.metadata.resource_version) < int(value.metadata.resource_version) @@ -586,27 +606,28 @@ class ObjectCache(Thread): self._object_cache[name] = value return ret, old_value - def delete(self, name, resource_version): + def delete(self, name: str, resource_version: str) -> Tuple[bool, Optional[K8sObject]]: with self._object_cache_lock: old_value = self._object_cache.get(name) ret = old_value and int(old_value.metadata.resource_version) < int(resource_version) if ret: del self._object_cache[name] - return not old_value or ret, old_value + return bool(not old_value or ret), old_value - def copy(self): + def copy(self) -> Dict[str, K8sObject]: with self._object_cache_lock: return self._object_cache.copy() - def get(self, name): + def get(self, name: str) -> Optional[K8sObject]: with self._object_cache_lock: return self._object_cache.get(name) - def _process_event(self, event): + def _process_event(self, event: Dict[str, Union[Any, Dict[str, Union[Any, Dict[str, Any]]]]]) -> None: ev_type = event['type'] obj = event['object'] name = obj['metadata']['name'] + new_value = None if ev_type in ('ADDED', 'MODIFIED'): obj = K8sObject(obj) success, old_value = self.set(name, obj) @@ -614,7 +635,6 @@ class ObjectCache(Thread): new_value = (obj.metadata.annotations or {}).get(self._annotations_map.get(name)) elif ev_type == 'DELETED': success, old_value = self.delete(name, obj['metadata']['resourceVersion']) - new_value = None else: return logger.warning('Unexpected event type: %s', ev_type) @@ -633,13 +653,13 @@ class ObjectCache(Thread): self._dcs.event.set() @staticmethod - def _finish_response(response): + def _finish_response(response: urllib3.HTTPResponse) -> None: try: response.close() finally: response.release_conn() - def _do_watch(self, resource_version): + def _do_watch(self, resource_version: str) -> None: with self._response_lock: self._response = None response = self._watch(resource_version) @@ -655,7 +675,7 @@ class ObjectCache(Thread): break self._process_event(event) - def _build_cache(self): + def _build_cache(self) -> None: objects = self._list() with self._object_cache_lock: self._object_cache = {item.metadata.name: item for item in objects.items} @@ -670,15 +690,15 @@ class ObjectCache(Thread): self._is_ready = False with self._response_lock: response, self._response = self._response, None - if response: + if isinstance(response, urllib3.HTTPResponse): self._finish_response(response) - def kill_stream(self): + def kill_stream(self) -> None: sock = None with self._response_lock: - if self._response: + if isinstance(self._response, urllib3.HTTPResponse): try: - sock = self._response.connection.sock + sock = self._response.connection.sock if self._response.connection else None except Exception: sock = None else: @@ -690,14 +710,14 @@ class ObjectCache(Thread): except Exception as e: logger.debug('Error on socket.shutdown: %r', e) - def run(self): + def run(self) -> None: while True: try: self._build_cache() except Exception as e: logger.error('ObjectCache.run %r', e) - def is_ready(self): + def is_ready(self) -> bool: """Must be called only when holding the lock on `_condition`""" return self._is_ready @@ -706,7 +726,7 @@ class Kubernetes(AbstractDCS): _CITUS_LABEL = 'citus-group' - def __init__(self, config): + def __init__(self, config: Dict[str, Any]) -> None: self._labels = deepcopy(config['labels']) self._labels[config.get('scope_label', 'cluster-name')] = config['scope'] self._label_selector = ','.join('{0}={1}'.format(k, v) for k, v in self._labels.items()) @@ -719,17 +739,18 @@ class Kubernetes(AbstractDCS): self._retry = Retry(deadline=config['retry_timeout'], max_delay=1, max_tries=-1, retry_exceptions=KubernetesRetriableException) - self._ttl = None + self._ttl = int(config.get('ttl') or 30) try: k8s_config.load_incluster_config(ca_certs=self._ca_certs) except k8s_config.ConfigException: k8s_config.load_kube_config(context=config.get('context', 'kind-kind')) - self.__my_pod = None - self.__ips = [] if config.get('patronictl') else [config.get('pod_ip')] - self.__ports = [] - for p in config.get('ports', [{}]): - port = {'port': int(p.get('port', '5432'))} + pod_ip = config.get('pod_ip') + self.__ips: List[str] = [] if config.get('patronictl') or not isinstance(pod_ip, str) else [pod_ip] + self.__ports: List[K8sObject] = [] + ports: List[Dict[str, Any]] = config.get('ports', [{}]) + for p in ports: + port: Dict[str, Any] = {'port': int(p.get('port', '5432'))} port.update({n: p[n] for n in ('name', 'protocol') if p.get(n)}) self.__ports.append(k8s_client.V1EndpointPort(**port)) @@ -738,7 +759,7 @@ class Kubernetes(AbstractDCS): self._should_create_config_service = self._api.use_endpoints self.reload_config(config) # leader_observed_record, leader_resource_version, and leader_observed_time are used only for leader race! - self._leader_observed_record = {} + self._leader_observed_record: Dict[str, str] = {} self._leader_observed_time = None self._leader_resource_version = None self.__do_not_watch = False @@ -753,13 +774,13 @@ class Kubernetes(AbstractDCS): label_selector=self._label_selector) self._kinds = ObjectCache(self, kinds_func, self._retry, self._condition, self._name) - def retry(self, *args, **kwargs): + def retry(self, method: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: retry = self._retry.copy() kwargs['_retry'] = retry - return retry(*args, **kwargs) + return retry(method, *args, **kwargs) @staticmethod - def _run_and_handle_exceptions(method, *args, **kwargs): + def _run_and_handle_exceptions(method: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: try: return method(*args, **kwargs) except k8s_client.rest.ApiException as e: @@ -771,31 +792,33 @@ class Kubernetes(AbstractDCS): except (RetryFailedError, K8sException) as e: raise KubernetesError(e) - def client_path(self, path): + def client_path(self, path: str) -> str: return super(Kubernetes, self).client_path(path)[1:].replace('/', '-') @property - def leader_path(self): + def leader_path(self) -> str: return super(Kubernetes, self).leader_path[:-7 if self._api.use_endpoints else None] - def set_ttl(self, ttl): + def set_ttl(self, ttl: int) -> Optional[bool]: ttl = int(ttl) self.__do_not_watch = self._ttl != ttl self._ttl = ttl + return None @property - def ttl(self): + def ttl(self) -> int: return self._ttl - def set_retry_timeout(self, retry_timeout): + def set_retry_timeout(self, retry_timeout: int) -> None: self._retry.deadline = retry_timeout - def reload_config(self, config: Dict[str, Any]) -> None: + def reload_config(self, config: Union['Config', Dict[str, Any]]) -> None: """Handles dynamic config changes. Either cause by changes in the local configuration file + SIGHUP or by changes of dynamic configuration""" super(Kubernetes, self).reload_config(config) + assert self._retry.deadline is not None self._api.configure_timeouts(self.loop_wait, self._retry.deadline, self.ttl) # retriable_http_codes supposed to be either int, list of integers or comma-separated string with integers. @@ -809,20 +832,20 @@ class Kubernetes(AbstractDCS): logger.warning('Invalid value of retriable_http_codes = %s: %r', config['retriable_http_codes'], e) @staticmethod - def member(pod): + def member(pod: K8sObject) -> Member: annotations = pod.metadata.annotations or {} member = Member.from_node(pod.metadata.resource_version, pod.metadata.name, None, annotations.get('status', '')) member.data['pod_labels'] = pod.metadata.labels return member - def _wait_caches(self, stop_time): + def _wait_caches(self, stop_time: float) -> None: while not (self._pods.is_ready() and self._kinds.is_ready()): timeout = stop_time - time.time() if timeout <= 0: raise RetryFailedError('Exceeded retry deadline') self._condition.wait(timeout) - def _cluster_from_nodes(self, group, nodes, pods): + def _cluster_from_nodes(self, group: str, nodes: Dict[str, K8sObject], pods: Collection[K8sObject]) -> Cluster: members = [self.member(pod) for pod in pods] path = self._base_path[1:] + '-' if group: @@ -836,45 +859,43 @@ class Kubernetes(AbstractDCS): initialize = annotations.get(self._INITIALIZE) # get global dynamic configuration - config = ClusterConfig.from_node(metadata and metadata.resource_version, - annotations.get(self._CONFIG) or '{}', - metadata.resource_version if self._CONFIG in annotations else 0) + config = metadata and ClusterConfig.from_node(metadata.resource_version, + annotations.get(self._CONFIG) or '{}', + metadata.resource_version if self._CONFIG in annotations else 0) # get timeline history - history = TimelineHistory.from_node(metadata and metadata.resource_version, - annotations.get(self._HISTORY) or '[]') + history = metadata and TimelineHistory.from_node(metadata.resource_version, + annotations.get(self._HISTORY) or '[]') leader_path = path[:-1] if self._api.use_endpoints else path + self._LEADER leader = nodes.get(leader_path) metadata = leader and leader.metadata if leader_path == self.leader_path: # We want to memorize leader_resource_version only for our cluster self._leader_resource_version = metadata.resource_version if metadata else None - annotations = metadata and metadata.annotations or {} + annotations: Dict[str, str] = metadata and metadata.annotations or {} # get last known leader lsn - last_lsn = annotations.get(self._OPTIME) try: - last_lsn = 0 if last_lsn is None else int(last_lsn) + last_lsn = int(annotations.get(self._OPTIME, '')) except Exception: last_lsn = 0 # get permanent slots state (confirmed_flush_lsn) slots = annotations.get('slots') try: - slots = slots and json.loads(slots) + slots = json.loads(annotations.get('slots', '')) except Exception: slots = None # get failsafe topology - failsafe = annotations.get(self._FAILSAFE) try: - failsafe = json.loads(failsafe) if failsafe else None + failsafe = json.loads(annotations.get(self._FAILSAFE, '')) except Exception: failsafe = None # get leader - leader_record = {n: annotations.get(n) for n in (self._LEADER, 'acquireTime', - 'ttl', 'renewTime', 'transitions') if n in annotations} + leader_record: Dict[str, str] = {n: annotations[n] for n in (self._LEADER, 'acquireTime', + 'ttl', 'renewTime', 'transitions') if n in annotations} # We want to memorize leader_observed_record and update leader_observed_time only for our cluster if leader_path == self.leader_path and (leader_record or self._leader_observed_record)\ and leader_record != self._leader_observed_record: @@ -883,7 +904,7 @@ class Kubernetes(AbstractDCS): leader = leader_record.get(self._LEADER) try: - ttl = int(leader_record.get('ttl')) or self._ttl + ttl = int(leader_record.get('ttl', self._ttl)) or self._ttl except (TypeError, ValueError): ttl = self._ttl @@ -893,15 +914,17 @@ class Kubernetes(AbstractDCS): leader = None if metadata: - member = Member(-1, leader, None, {}) + member = Member(-1, leader or '', None, {}) member = ([m for m in members if m.name == leader] or [member])[0] leader = Leader(metadata.resource_version, None, member) + else: + leader = None # failover key failover = nodes.get(path + self._FAILOVER) metadata = failover and failover.metadata - failover = Failover.from_node(metadata and metadata.resource_version, - metadata and (metadata.annotations or {}).copy()) + failover = metadata and Failover.from_node(metadata.resource_version, + (metadata.annotations or {}).copy()) # get synchronization state sync = nodes.get(path + self._SYNC) @@ -910,32 +933,35 @@ class Kubernetes(AbstractDCS): return Cluster(initialize, config, leader, last_lsn, members, failover, sync, history, slots, failsafe) - def _cluster_loader(self, path): - return self._cluster_from_nodes(path['group'], path['nodes'], path['pods']) + def _cluster_loader(self, path: Dict[str, Any]) -> Cluster: + return self._cluster_from_nodes(path['group'], path['nodes'], path['pods'].values()) - def _citus_cluster_loader(self, path): - clusters = defaultdict(lambda: {'pods': [], 'nodes': {}}) + def _citus_cluster_loader(self, path: Dict[str, Any]) -> Dict[int, Cluster]: + clusters: Dict[str, Dict[str, Dict[str, K8sObject]]] = defaultdict(lambda: defaultdict(dict)) - for pod in path['pods']: + for name, pod in path['pods'].items(): group = pod.metadata.labels.get(self._CITUS_LABEL) if group and citus_group_re.match(group): - clusters[group]['pods'].append(pod) + clusters[group]['pods'][name] = pod for name, kind in path['nodes'].items(): group = kind.metadata.labels.get(self._CITUS_LABEL) if group and citus_group_re.match(group): clusters[group]['nodes'][name] = kind - return {int(group): self._cluster_from_nodes(group, value['nodes'], value['pods']) + return {int(group): self._cluster_from_nodes(group, value['nodes'], value['pods'].values()) for group, value in clusters.items()} - def __load_cluster(self, group, loader): + def __load_cluster( + self, group: Optional[str], loader: Callable[[Dict[str, Any]], Union[Cluster, Dict[int, Cluster]]] + ) -> Union[Cluster, Dict[int, Cluster]]: + assert self._retry.deadline is not None stop_time = time.time() + self._retry.deadline self._api.refresh_api_servers_cache() try: with self._condition: self._wait_caches(stop_time) - pods = [pod for pod in self._pods.copy().values() - if not group or pod.metadata.labels.get(self._CITUS_LABEL) == group] + pods = {name: pod for name, pod in self._pods.copy().items() + if not group or pod.metadata.labels.get(self._CITUS_LABEL) == group} nodes = {name: kind for name, kind in self._kinds.copy().items() if not group or kind.metadata.labels.get(self._CITUS_LABEL) == group} return loader({'group': group, 'pods': pods, 'nodes': nodes}) @@ -943,25 +969,27 @@ class Kubernetes(AbstractDCS): logger.exception('get_cluster') raise KubernetesError('Kubernetes API is not responding properly') - def _load_cluster(self, path, loader): + def _load_cluster( + self, path: str, loader: Callable[[Any], Union[Cluster, Dict[int, Cluster]]] + ) -> Union[Cluster, Dict[int, Cluster]]: group = self._citus_group if path == self.client_path('') else None return self.__load_cluster(group, loader) - def get_citus_coordinator(self): + def get_citus_coordinator(self) -> Optional[Cluster]: try: - return self.__load_cluster(str(CITUS_COORDINATOR_GROUP_ID), self._cluster_loader) + ret = self.__load_cluster(str(CITUS_COORDINATOR_GROUP_ID), self._cluster_loader) + assert isinstance(ret, Cluster) + return ret except Exception as e: logger.error('Failed to load Citus coordinator cluster from Kubernetes: %r', e) @staticmethod - def compare_ports(p1, p2): + def compare_ports(p1: K8sObject, p2: K8sObject) -> bool: return p1.name == p2.name and p1.port == p2.port and (p1.protocol or 'TCP') == (p2.protocol or 'TCP') @staticmethod - def subsets_changed(last_observed_subsets, ip, ports): + def subsets_changed(last_observed_subsets: List[K8sObject], ip: str, ports: List[K8sObject]) -> bool: """ - >>> Kubernetes.subsets_changed([], None, []) - True >>> ip = '1.2.3.4' >>> a = [k8s_client.V1EndpointAddress(ip=ip)] >>> s = [k8s_client.V1EndpointSubset(addresses=a)] @@ -995,7 +1023,7 @@ class Kubernetes(AbstractDCS): return True return False - def __target_ref(self, leader_ip, latest_subsets, pod): + def __target_ref(self, leader_ip: str, latest_subsets: List[K8sObject], pod: K8sObject) -> K8sObject: # we want to re-use existing target_ref if possible for subset in latest_subsets: for address in subset.addresses or []: @@ -1004,7 +1032,7 @@ class Kubernetes(AbstractDCS): return k8s_client.V1ObjectReference(kind='Pod', uid=pod.metadata.uid, namespace=self._namespace, name=self._name, resource_version=pod.metadata.resource_version) - def _map_subsets(self, endpoints, ips): + def _map_subsets(self, endpoints: Dict[str, Any], ips: List[str]) -> None: leader = self._kinds.get(self.leader_path) latest_subsets = leader and leader.subsets or [] if not ips: @@ -1022,16 +1050,19 @@ class Kubernetes(AbstractDCS): address = k8s_client.V1EndpointAddress(ip=leader_ip, **kwargs) endpoints['subsets'] = [k8s_client.V1EndpointSubset(addresses=[address], ports=self.__ports)] - def _patch_or_create(self, name, annotations, resource_version=None, patch=False, retry=None, ips=None): + def _patch_or_create(self, name: str, annotations: Dict[str, Any], + resource_version: Optional[str] = None, patch: bool = False, + retry: Optional[Callable[..., Any]] = None, ips: Optional[List[str]] = None) -> K8sObject: metadata = {'namespace': self._namespace, 'name': name, 'labels': self._labels, 'annotations': annotations} if patch or resource_version: if resource_version is not None: metadata['resource_version'] = resource_version func = functools.partial(self._api.patch_namespaced_kind, name) + metadata['annotations'] = annotations else: func = functools.partial(self._api.create_namespaced_kind) # skip annotations with null values - metadata['annotations'] = {k: v for k, v in metadata['annotations'].items() if v is not None} + metadata['annotations'] = {k: v for k, v in annotations.items() if v is not None} metadata = k8s_client.V1ObjectMeta(**metadata) if ips is not None and self._api.use_endpoints: @@ -1046,11 +1077,11 @@ class Kubernetes(AbstractDCS): return ret @catch_kubernetes_errors - def patch_or_create(self, name, annotations, resource_version=None, patch=False, retry=True, ips=None): - if retry is True: - retry = self.retry + def patch_or_create(self, name: str, annotations: Dict[str, Any], resource_version: Optional[str] = None, + patch: bool = False, retry: bool = True, ips: Optional[List[str]] = None) -> bool: try: - return self._patch_or_create(name, annotations, resource_version, patch, retry, ips) + return bool(self._patch_or_create(name, annotations, resource_version, + patch, self.retry if retry else None, ips)) except k8s_client.rest.ApiException as e: if e.status == 409 and resource_version: # Conflict in resource_version # Terminate watchers, it could be a sign that K8s API is in a failed state @@ -1058,14 +1089,15 @@ class Kubernetes(AbstractDCS): self._pods.kill_stream() raise e - def patch_or_create_config(self, annotations, resource_version=None, patch=False, retry=True): + def patch_or_create_config(self, annotations: Dict[str, Any], + resource_version: Optional[str] = None, patch: bool = False, retry: bool = True) -> bool: # SCOPE-config endpoint requires corresponding service otherwise it might be "cleaned" by k8s master if self._api.use_endpoints and not patch and not resource_version: self._should_create_config_service = True self._create_config_service() return self.patch_or_create(self.config_path, annotations, resource_version, patch, retry) - def _create_config_service(self): + def _create_config_service(self) -> None: metadata = k8s_client.V1ObjectMeta(namespace=self._namespace, name=self.config_path, labels=self._labels) body = k8s_client.V1Service(metadata=metadata, spec=k8s_client.V1ServiceSpec(cluster_ip='None')) try: @@ -1077,27 +1109,32 @@ class Kubernetes(AbstractDCS): return logger.exception('create_config_service failed') self._should_create_config_service = False - def _write_leader_optime(self, last_lsn): + def _write_leader_optime(self, last_lsn: str) -> bool: """Unused""" + raise NotImplementedError # pragma: no cover - def _write_status(self, value): + def _write_status(self, value: str) -> bool: """Unused""" + raise NotImplementedError # pragma: no cover - def _write_failsafe(self, value): + def _write_failsafe(self, value: str) -> bool: """Unused""" + raise NotImplementedError # pragma: no cover - def _update_leader(self): + def _update_leader(self) -> bool: """Unused""" + raise NotImplementedError # pragma: no cover - def _update_leader_with_retry(self, annotations, resource_version, ips): + def _update_leader_with_retry(self, annotations: Dict[str, Any], + resource_version: Optional[str], ips: List[str]) -> bool: retry = self._retry.copy() - def _retry(*args, **kwargs): + def _retry(*args: Any, **kwargs: Any) -> Any: kwargs['_retry'] = retry return retry(*args, **kwargs) try: - return self._patch_or_create(self.leader_path, annotations, resource_version, ips=ips, retry=_retry) + return bool(self._patch_or_create(self.leader_path, annotations, resource_version, ips=ips, retry=_retry)) except k8s_client.rest.ApiException as e: if e.status == 409: logger.warning('Concurrent update of %s', self.leader_path) @@ -1134,10 +1171,11 @@ class Kubernetes(AbstractDCS): if kind and (kind_annotations.get(self._LEADER) != self._name or kind_resource_version == resource_version): return False - return self._run_and_handle_exceptions(self._patch_or_create, self.leader_path, annotations, - kind_resource_version, ips=ips, retry=_retry) + return bool(self._run_and_handle_exceptions(self._patch_or_create, self.leader_path, annotations, + kind_resource_version, ips=ips, retry=_retry)) - def update_leader(self, last_lsn, slots=None, failsafe=None): + def update_leader(self, last_lsn: Optional[int], slots: Optional[Dict[str, int]] = None, + failsafe: Optional[Dict[str, str]] = None) -> bool: kind = self._kinds.get(self.leader_path) kind_annotations = kind and kind.metadata.annotations or {} @@ -1159,13 +1197,13 @@ class Kubernetes(AbstractDCS): resource_version = kind and kind.metadata.resource_version return self._update_leader_with_retry(annotations, resource_version, self.__ips) - def attempt_to_acquire_leader(self): + def attempt_to_acquire_leader(self) -> bool: now = datetime.datetime.now(tzutc).isoformat() annotations = {self._LEADER: self._name, 'ttl': str(self._ttl), 'renewTime': now, 'acquireTime': now, 'transitions': '0'} if self._leader_observed_record: try: - transitions = int(self._leader_observed_record.get('transitions')) + transitions = int(self._leader_observed_record.get('transitions', '')) except (TypeError, ValueError): transitions = 0 @@ -1174,11 +1212,11 @@ class Kubernetes(AbstractDCS): else: annotations['acquireTime'] = self._leader_observed_record.get('acquireTime') or now annotations['transitions'] = str(transitions) - ips = [] if self._api.use_endpoints else None + ips: Optional[List[str]] = [] if self._api.use_endpoints else None try: - ret = self._patch_or_create(self.leader_path, annotations, - self._leader_resource_version, retry=self.retry, ips=ips) + ret = bool(self._patch_or_create(self.leader_path, annotations, + self._leader_resource_version, retry=self.retry, ips=ips)) except k8s_client.rest.ApiException as e: if e.status == 409 and self._leader_resource_version: # Conflict in resource_version # Terminate watchers, it could be a sign that K8s API is in a failed state @@ -1192,28 +1230,30 @@ class Kubernetes(AbstractDCS): logger.info('Could not take out TTL lock') return ret - def take_leader(self): + def take_leader(self) -> bool: return self.attempt_to_acquire_leader() - def set_failover_value(self, value, index=None): + def set_failover_value(self, value: str, index: Optional[str] = None) -> bool: """Unused""" + raise NotImplementedError # pragma: no cover - def manual_failover(self, leader, candidate, scheduled_at=None, index=None): + def manual_failover(self, leader: Optional[str], candidate: Optional[str], + scheduled_at: Optional[datetime.datetime] = None, index: Optional[str] = None) -> bool: annotations = {'leader': leader or None, 'member': candidate or None, 'scheduled_at': scheduled_at and scheduled_at.isoformat()} patch = bool(self.cluster and isinstance(self.cluster.failover, Failover) and self.cluster.failover.index) return self.patch_or_create(self.failover_path, annotations, index, bool(index or patch), False) @property - def _config_resource_version(self): + def _config_resource_version(self) -> Optional[str]: config = self._kinds.get(self.config_path) return config and config.metadata.resource_version - def set_config_value(self, value, index=None): + def set_config_value(self, value: str, index: Optional[str] = None) -> bool: return self.patch_or_create_config({self._CONFIG: value}, index, bool(self._config_resource_version), False) @catch_kubernetes_errors - def touch_member(self, data): + def touch_member(self, data: Dict[str, Any]) -> bool: cluster = self.cluster if cluster and cluster.leader and cluster.leader.name == self._name: role = 'master' @@ -1224,7 +1264,8 @@ class Kubernetes(AbstractDCS): member = cluster and cluster.get_member(self._name, fallback_to_leader=False) pod_labels = member and member.data.pop('pod_labels', None) - ret = pod_labels is not None and pod_labels.get(self._role_label) == role and deep_compare(data, member.data) + ret = member and pod_labels is not None\ + and pod_labels.get(self._role_label) == role and deep_compare(data, member.data) if not ret: metadata = {'namespace': self._namespace, 'name': self._name, 'labels': {self._role_label: role}, @@ -1235,40 +1276,45 @@ class Kubernetes(AbstractDCS): self._pods.set(self._name, ret) if self._should_create_config_service: self._create_config_service() - return ret + return bool(ret) - def initialize(self, create_new=True, sysid=""): + def initialize(self, create_new: bool = True, sysid: str = "") -> bool: cluster = self.cluster - resource_version = cluster.config.index if cluster and cluster.config and cluster.config.index else None + resource_version = str(cluster.config.index) if cluster and cluster.config and cluster.config.index else None return self.patch_or_create_config({self._INITIALIZE: sysid}, resource_version) - def _delete_leader(self): + def _delete_leader(self) -> bool: """Unused""" + raise NotImplementedError # pragma: no cover - def delete_leader(self, last_lsn=None): + def delete_leader(self, last_lsn: Optional[int] = None) -> bool: + ret = False kind = self._kinds.get(self.leader_path) if kind and (kind.metadata.annotations or {}).get(self._LEADER) == self._name: - annotations = {self._LEADER: None} + annotations: Dict[str, Optional[str]] = {self._LEADER: None} if last_lsn: annotations[self._OPTIME] = str(last_lsn) - self.patch_or_create(self.leader_path, annotations, kind.metadata.resource_version, True, False, []) + ret = self.patch_or_create(self.leader_path, annotations, kind.metadata.resource_version, True, False, []) self.reset_cluster() + return ret - def cancel_initialization(self): + def cancel_initialization(self) -> bool: return self.patch_or_create_config({self._INITIALIZE: None}, None, True) @catch_kubernetes_errors - def delete_cluster(self): - self.retry(self._api.delete_collection_namespaced_kind, self._namespace, label_selector=self._label_selector) + def delete_cluster(self) -> bool: + return bool(self.retry(self._api.delete_collection_namespaced_kind, + self._namespace, label_selector=self._label_selector)) - def set_history_value(self, value): + def set_history_value(self, value: str) -> bool: return self.patch_or_create_config({self._HISTORY: value}, None, bool(self._config_resource_version), False) - def set_sync_state_value(self, value, index=None): + def set_sync_state_value(self, value: str, index: Optional[str] = None) -> bool: """Unused""" + raise NotImplementedError # pragma: no cover - def write_sync_state(self, leader: Union[str, None], sync_standby: Union[Collection[str], None], - index: Optional[Union[int, str]] = None) -> bool: + def write_sync_state(self, leader: Optional[str], sync_standby: Optional[Collection[str]], + index: Optional[str] = None) -> bool: """Prepare and write annotations to $SCOPE-sync Endpoint or ConfigMap. :param leader: name of the leader node that manages /sync key @@ -1288,7 +1334,7 @@ class Kubernetes(AbstractDCS): """ return self.write_sync_state(None, None, index=index) - def watch(self, leader_index, timeout): + def watch(self, leader_index: Optional[str], timeout: float) -> bool: if self.__do_not_watch: self.__do_not_watch = False return True diff --git a/patroni/dcs/raft.py b/patroni/dcs/raft.py index 5d0dbc50..1c563788 100644 --- a/patroni/dcs/raft.py +++ b/patroni/dcs/raft.py @@ -10,10 +10,13 @@ from pysyncobj.dns_resolver import globalDnsResolver from pysyncobj.node import TCPNode from pysyncobj.transport import TCPTransport, CONNECTION_STATE from pysyncobj.utility import TcpUtility +from typing import Any, Callable, Collection, Dict, List, Optional, Union, TYPE_CHECKING from . import AbstractDCS, ClusterConfig, Cluster, Failover, Leader, Member, SyncState, TimelineHistory, citus_group_re from ..exceptions import DCSError from ..utils import validate_directory +if TYPE_CHECKING: # pragma: no cover + from ..config import Config logger = logging.getLogger(__name__) @@ -24,11 +27,12 @@ class RaftError(DCSError): class _TCPTransport(TCPTransport): - def __init__(self, syncObj, selfNode, otherNodes): + def __init__(self, syncObj: 'DynMemberSyncObj', selfNode: Optional[TCPNode], + otherNodes: Collection[TCPNode]) -> None: super(_TCPTransport, self).__init__(syncObj, selfNode, otherNodes) self.setOnUtilityMessageCallback('members', syncObj.getMembers) - def _connectIfNecessarySingle(self, node): + def _connectIfNecessarySingle(self, node: TCPNode) -> bool: try: return super(_TCPTransport, self)._connectIfNecessarySingle(node) except Exception as e: @@ -36,7 +40,7 @@ class _TCPTransport(TCPTransport): return False -def resolve_host(self): +def resolve_host(self: TCPNode) -> Optional[str]: return globalDnsResolver().resolve(self.host) @@ -45,17 +49,19 @@ setattr(TCPNode, 'ip', property(resolve_host)) class SyncObjUtility(object): - def __init__(self, otherNodes, conf, retry_timeout=10): + def __init__(self, otherNodes: Collection[Union[str, TCPNode]], conf: SyncObjConf, retry_timeout: int = 10) -> None: self._nodes = otherNodes self._utility = TcpUtility(conf.password, retry_timeout / max(1, len(otherNodes))) + self.__node = next(iter(otherNodes), None) - def executeCommand(self, command): + def executeCommand(self, command: List[Any]) -> Any: try: - return self._utility.executeCommand(self.__node, command) + if self.__node: + return self._utility.executeCommand(self.__node, command) except Exception: return None - def getMembers(self): + def getMembers(self) -> Optional[List[str]]: for self.__node in self._nodes: response = self.executeCommand(['members']) if response: @@ -64,7 +70,8 @@ class SyncObjUtility(object): class DynMemberSyncObj(SyncObj): - def __init__(self, selfAddress, partnerAddrs, conf, retry_timeout=10): + def __init__(self, selfAddress: Optional[str], partnerAddrs: Collection[str], + conf: SyncObjConf, retry_timeout: int = 10) -> None: self.__early_apply_local_log = selfAddress is not None self.applied_local_log = False @@ -81,12 +88,12 @@ class DynMemberSyncObj(SyncObj): thread.daemon = True thread.start() - def getMembers(self, args, callback): + def getMembers(self, args: Any, callback: Callable[[Any, Any], Any]) -> None: callback([{'addr': node.id, 'leader': node == self._getLeader(), 'status': CONNECTION_STATE.CONNECTED if self.isNodeConnected(node) else CONNECTION_STATE.DISCONNECTED} for node in self.otherNodes] + [{'addr': self.selfNode.id, 'leader': self._isLeader(), 'status': CONNECTION_STATE.CONNECTED}], None) - def _onTick(self, timeToWait=0.0): + def _onTick(self, timeToWait: float = 0.0): super(DynMemberSyncObj, self)._onTick(timeToWait) # The SyncObj calls onReady callback only when cluster got the leader and is ready for writes. @@ -98,11 +105,12 @@ class DynMemberSyncObj(SyncObj): class KVStoreTTL(DynMemberSyncObj): - def __init__(self, on_ready, on_set, on_delete, **config): + def __init__(self, on_ready: Optional[Callable[..., Any]], on_set: Optional[Callable[[str, Dict[str, Any]], None]], + on_delete: Optional[Callable[[str], None]], **config: Any) -> None: self.__thread = None self.__on_set = on_set self.__on_delete = on_delete - self.__limb = {} + self.__limb: Dict[str, Dict[str, Any]] = {} self.set_retry_timeout(int(config.get('retry_timeout') or 10)) self_addr = config.get('self_addr') @@ -128,22 +136,22 @@ class KVStoreTTL(DynMemberSyncObj): onReady=on_ready, dynamicMembershipChange=True) super(KVStoreTTL, self).__init__(self_addr, partner_addrs, conf, self.__retry_timeout) - self.__data = {} + self.__data: Dict[str, Dict[str, Any]] = {} @staticmethod - def __check_requirements(old_value, **kwargs): - return ('prevExist' not in kwargs or bool(kwargs['prevExist']) == bool(old_value)) and \ - ('prevValue' not in kwargs or old_value and old_value['value'] == kwargs['prevValue']) and \ - (not kwargs.get('prevIndex') or old_value and old_value['index'] == kwargs['prevIndex']) + def __check_requirements(old_value: Dict[str, Any], **kwargs: Any) -> bool: + return bool(('prevExist' not in kwargs or bool(kwargs['prevExist']) == bool(old_value)) + and ('prevValue' not in kwargs or old_value and old_value['value'] == kwargs['prevValue']) + and (not kwargs.get('prevIndex') or old_value and old_value['index'] == kwargs['prevIndex'])) - def set_retry_timeout(self, retry_timeout): + def set_retry_timeout(self, retry_timeout: int) -> None: self.__retry_timeout = retry_timeout - def retry(self, func, *args, **kwargs): + def retry(self, func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: event = threading.Event() ret = {'result': None, 'error': -1} - def callback(result, error): + def callback(result: Any, error: Any) -> None: ret.update(result=result, error=error) event.set() @@ -167,7 +175,7 @@ class KVStoreTTL(DynMemberSyncObj): return False @replicated - def _set(self, key, value, **kwargs): + def _set(self, key: str, value: Dict[str, Any], **kwargs: Any) -> bool: old_value = self.__data.get(key, {}) if not self.__check_requirements(old_value, **kwargs): return False @@ -181,29 +189,30 @@ class KVStoreTTL(DynMemberSyncObj): self.__on_set(key, value) return True - def set(self, key, value, ttl=None, handle_raft_error=True, **kwargs): + def set(self, key: str, value: str, ttl: Optional[int] = None, + handle_raft_error: bool = True, **kwargs: Any) -> bool: old_value = self.__data.get(key, {}) if not self.__check_requirements(old_value, **kwargs): return False - value = {'value': value, 'updated': time.time()} - value['created'] = old_value.get('created', value['updated']) + data: Dict[str, Any] = {'value': value, 'updated': time.time()} + data['created'] = old_value.get('created', data['updated']) if ttl: - value['expire'] = value['updated'] + ttl + data['expire'] = data['updated'] + ttl try: - return self.retry(self._set, key, value, **kwargs) + return self.retry(self._set, key, data, **kwargs) except RaftError: if not handle_raft_error: raise return False - def __pop(self, key): + def __pop(self, key: str) -> None: self.__data.pop(key) if self.__on_delete: self.__on_delete(key) @replicated - def _delete(self, key, recursive=False, **kwargs): + def _delete(self, key: str, recursive: bool = False, **kwargs: Any) -> bool: if recursive: for k in list(self.__data.keys()): if k.startswith(key): @@ -214,7 +223,7 @@ class KVStoreTTL(DynMemberSyncObj): self.__pop(key) return True - def delete(self, key, recursive=False, **kwargs): + def delete(self, key: str, recursive: bool = False, **kwargs: Any) -> bool: if not recursive and not self.__check_requirements(self.__data.get(key, {}), **kwargs): return False try: @@ -223,32 +232,32 @@ class KVStoreTTL(DynMemberSyncObj): return False @staticmethod - def __values_match(old, new): + def __values_match(old: Dict[str, Any], new: Dict[str, Any]) -> bool: return all(old.get(n) == new.get(n) for n in ('created', 'updated', 'expire', 'value')) @replicated - def _expire(self, key, value, callback=None): + def _expire(self, key: str, value: Dict[str, Any], callback: Optional[Callable[..., Any]] = None) -> None: current = self.__data.get(key) if current and self.__values_match(current, value): self.__pop(key) - def __expire_keys(self): + def __expire_keys(self) -> None: for key, value in self.__data.items(): if value and 'expire' in value and value['expire'] <= time.time() and \ not (key in self.__limb and self.__values_match(self.__limb[key], value)): self.__limb[key] = value - def callback(*args): + def callback(*args: Any) -> None: if key in self.__limb and self.__values_match(self.__limb[key], value): self.__limb.pop(key) self._expire(key, value, callback=callback) - def get(self, key, recursive=False): + def get(self, key: str, recursive: bool = False) -> Union[None, Dict[str, Any], Dict[str, Dict[str, Any]]]: if not recursive: return self.__data.get(key) return {k: v for k, v in self.__data.items() if k.startswith(key)} - def _onTick(self, timeToWait=0.0): + def _onTick(self, timeToWait: float = 0.0) -> None: super(KVStoreTTL, self)._onTick(timeToWait) if self._isLeader(): @@ -256,17 +265,17 @@ class KVStoreTTL(DynMemberSyncObj): else: self.__limb.clear() - def _autoTickThread(self): + def _autoTickThread(self) -> None: self.__destroying = False while not self.__destroying: self.doTick(self.conf.autoTickPeriod) - def startAutoTick(self): + def startAutoTick(self) -> None: self.__thread = threading.Thread(target=self._autoTickThread) self.__thread.daemon = True self.__thread.start() - def destroy(self): + def destroy(self) -> None: if self.__thread: self.__destroying = True self.__thread.join() @@ -275,7 +284,7 @@ class KVStoreTTL(DynMemberSyncObj): class Raft(AbstractDCS): - def __init__(self, config): + def __init__(self, config: Dict[str, Any]) -> None: super(Raft, self).__init__(config) self._ttl = int(config.get('ttl') or 30) @@ -290,7 +299,7 @@ class Raft(AbstractDCS): else: logger.info('waiting on raft') - def _on_set(self, key, value): + def _on_set(self, key: str, value: Dict[str, Any]) -> None: leader = (self._sync_obj.get(self.leader_path) or {}).get('value') if key == value['created'] == value['updated'] and \ (key.startswith(self.members_path) or key == self.leader_path and leader != self._name) or \ @@ -298,29 +307,29 @@ class Raft(AbstractDCS): key in (self.config_path, self.sync_path): self.event.set() - def _on_delete(self, key): + def _on_delete(self, key: str) -> None: if key == self.leader_path: self.event.set() - def set_ttl(self, ttl): + def set_ttl(self, ttl: int) -> Optional[bool]: self._ttl = ttl @property - def ttl(self): + def ttl(self) -> int: return self._ttl - def set_retry_timeout(self, retry_timeout): + def set_retry_timeout(self, retry_timeout: int) -> None: self._sync_obj.set_retry_timeout(retry_timeout) - def reload_config(self, config): + def reload_config(self, config: Union['Config', Dict[str, Any]]) -> None: super(Raft, self).reload_config(config) globalDnsResolver().setTimeouts(self.ttl, self.loop_wait) @staticmethod - def member(key, value): + def member(key: str, value: Dict[str, Any]) -> Member: return Member.from_node(value['index'], os.path.basename(key), None, value['value']) - def _cluster_from_nodes(self, nodes): + def _cluster_from_nodes(self, nodes: Dict[str, Any]) -> Cluster: # get initialize flag initialize = nodes.get(self._INITIALIZE) initialize = initialize and initialize['value'] @@ -348,7 +357,7 @@ class Raft(AbstractDCS): slots = None try: - last_lsn = int(last_lsn) + last_lsn = int(last_lsn or '') except Exception: last_lsn = 0 @@ -380,79 +389,81 @@ class Raft(AbstractDCS): return Cluster(initialize, config, leader, last_lsn, members, failover, sync, history, slots, failsafe) - def _cluster_loader(self, path): + def _cluster_loader(self, path: str) -> Cluster: response = self._sync_obj.get(path, recursive=True) if not response: return Cluster.empty() nodes = {key[len(path):]: value for key, value in response.items()} return self._cluster_from_nodes(nodes) - def _citus_cluster_loader(self, path): - clusters = defaultdict(dict) + def _citus_cluster_loader(self, path: str) -> Dict[int, Cluster]: + clusters: Dict[int, Dict[str, Any]] = defaultdict(dict) response = self._sync_obj.get(path, recursive=True) - for key, value in response.items(): + for key, value in (response or {}).items(): key = key[len(path):].split('/', 1) if len(key) == 2 and citus_group_re.match(key[0]): clusters[int(key[0])][key[1]] = value return {group: self._cluster_from_nodes(nodes) for group, nodes in clusters.items()} - def _load_cluster(self, path, loader): + def _load_cluster( + self, path: str, loader: Callable[[str], Union[Cluster, Dict[int, Cluster]]] + ) -> Union[Cluster, Dict[int, Cluster]]: return loader(path) - def _write_leader_optime(self, last_lsn): + def _write_leader_optime(self, last_lsn: str) -> bool: return self._sync_obj.set(self.leader_optime_path, last_lsn, timeout=1) - def _write_status(self, value): + def _write_status(self, value: str) -> bool: return self._sync_obj.set(self.status_path, value, timeout=1) - def _write_failsafe(self, value): + def _write_failsafe(self, value: str) -> bool: return self._sync_obj.set(self.failsafe_path, value, timeout=1) - def _update_leader(self): + def _update_leader(self) -> bool: ret = self._sync_obj.set(self.leader_path, self._name, ttl=self._ttl, handle_raft_error=False, prevValue=self._name) if not ret and self._sync_obj.get(self.leader_path) is None: ret = self.attempt_to_acquire_leader() return ret - def attempt_to_acquire_leader(self): + def attempt_to_acquire_leader(self) -> bool: return self._sync_obj.set(self.leader_path, self._name, ttl=self._ttl, handle_raft_error=False, prevExist=False) - def set_failover_value(self, value, index=None): + def set_failover_value(self, value: str, index: Optional[int] = None) -> bool: return self._sync_obj.set(self.failover_path, value, prevIndex=index) - def set_config_value(self, value, index=None): + def set_config_value(self, value: str, index: Optional[int] = None) -> bool: return self._sync_obj.set(self.config_path, value, prevIndex=index) - def touch_member(self, data): - data = json.dumps(data, separators=(',', ':')) - return self._sync_obj.set(self.member_path, data, self._ttl, timeout=2) + def touch_member(self, data: Dict[str, Any]) -> bool: + value = json.dumps(data, separators=(',', ':')) + return self._sync_obj.set(self.member_path, value, self._ttl, timeout=2) - def take_leader(self): + def take_leader(self) -> bool: return self._sync_obj.set(self.leader_path, self._name, ttl=self._ttl) - def initialize(self, create_new=True, sysid=''): + def initialize(self, create_new: bool = True, sysid: str = '') -> bool: return self._sync_obj.set(self.initialize_path, sysid, prevExist=(not create_new)) - def _delete_leader(self): + def _delete_leader(self) -> bool: return self._sync_obj.delete(self.leader_path, prevValue=self._name, timeout=1) - def cancel_initialization(self): + def cancel_initialization(self) -> bool: return self._sync_obj.delete(self.initialize_path) - def delete_cluster(self): + def delete_cluster(self) -> bool: return self._sync_obj.delete(self.client_path(''), recursive=True) - def set_history_value(self, value): + def set_history_value(self, value: str) -> bool: return self._sync_obj.set(self.history_path, value) - def set_sync_state_value(self, value, index=None): + def set_sync_state_value(self, value: str, index: Optional[int] = None) -> bool: return self._sync_obj.set(self.sync_path, value, prevIndex=index) - def delete_sync_state(self, index=None): + def delete_sync_state(self, index: Optional[int] = None) -> bool: return self._sync_obj.delete(self.sync_path, prevIndex=index) - def watch(self, leader_index, timeout): + def watch(self, leader_index: Optional[int], timeout: float) -> bool: try: return super(Raft, self).watch(leader_index, timeout) finally: diff --git a/patroni/dcs/zookeeper.py b/patroni/dcs/zookeeper.py index 3f102f8d..eff6aceb 100644 --- a/patroni/dcs/zookeeper.py +++ b/patroni/dcs/zookeeper.py @@ -1,18 +1,22 @@ import json import logging import select +import socket import time from kazoo.client import KazooClient, KazooState, KazooRetry from kazoo.exceptions import ConnectionClosedError, NoNodeError, NodeExistsError, SessionExpiredError -from kazoo.handlers.threading import SequentialThreadingHandler -from kazoo.protocol.states import KeeperState +from kazoo.handlers.threading import AsyncResult, SequentialThreadingHandler +from kazoo.protocol.states import KeeperState, WatchedEvent, ZnodeStat from kazoo.retry import RetryFailedError -from kazoo.security import make_acl +from kazoo.security import ACL, make_acl +from typing import Any, Callable, Dict, List, Optional, Union, Tuple, TYPE_CHECKING from . import AbstractDCS, ClusterConfig, Cluster, Failover, Leader, Member, SyncState, TimelineHistory, citus_group_re from ..exceptions import DCSError from ..utils import deep_compare +if TYPE_CHECKING: # pragma: no cover + from ..config import Config logger = logging.getLogger(__name__) @@ -23,14 +27,14 @@ class ZooKeeperError(DCSError): class PatroniSequentialThreadingHandler(SequentialThreadingHandler): - def __init__(self, connect_timeout): + def __init__(self, connect_timeout: Union[int, float]) -> None: super(PatroniSequentialThreadingHandler, self).__init__() self.set_connect_timeout(connect_timeout) - def set_connect_timeout(self, connect_timeout): + def set_connect_timeout(self, connect_timeout: Union[int, float]) -> None: self._connect_timeout = max(1.0, connect_timeout / 2.0) # try to connect to zookeeper node during loop_wait/2 - def create_connection(self, *args, **kwargs): + def create_connection(self, *args: Any, **kwargs: Any) -> socket.socket: """This method is trying to establish connection with one of the zookeeper nodes. Somehow strategy "fail earlier and retry more often" works way better comparing to the original strategy "try to connect with specified timeout". @@ -41,16 +45,16 @@ class PatroniSequentialThreadingHandler(SequentialThreadingHandler): :param args: always contains `tuple(host, port)` as the first element and could contain `connect_timeout` (negotiated session timeout) as the second element.""" - args = list(args) - if len(args) == 0: # kazoo 2.6.0 slightly changed the way how it calls create_connection method + args_list: List[Any] = list(args) + if len(args_list) == 0: # kazoo 2.6.0 slightly changed the way how it calls create_connection method kwargs['timeout'] = max(self._connect_timeout, kwargs.get('timeout', self._connect_timeout * 10) / 10.0) - elif len(args) == 1: - args.append(self._connect_timeout) + elif len(args_list) == 1: + args_list.append(self._connect_timeout) else: - args[1] = max(self._connect_timeout, args[1] / 10.0) - return super(PatroniSequentialThreadingHandler, self).create_connection(*args, **kwargs) + args_list[1] = max(self._connect_timeout, args_list[1] / 10.0) + return super(PatroniSequentialThreadingHandler, self).create_connection(*args_list, **kwargs) - def select(self, *args, **kwargs): + def select(self, *args: Any, **kwargs: Any) -> Any: """ Python 3.XY may raise following exceptions if select/poll are called with an invalid socket: - `ValueError`: because fd == -1 @@ -68,7 +72,7 @@ class PatroniSequentialThreadingHandler(SequentialThreadingHandler): class PatroniKazooClient(KazooClient): - def _call(self, request, async_object): + def _call(self, request: Tuple[Any], async_object: AsyncResult) -> Optional[bool]: # Before kazoo==2.7.0 it wasn't possible to send requests to zookeeper if # the connection is in the SUSPENDED state and Patroni was strongly relying on it. # The https://github.com/python-zk/kazoo/pull/588 changed it, and now such requests are queued. @@ -82,7 +86,7 @@ class PatroniKazooClient(KazooClient): class ZooKeeper(AbstractDCS): - def __init__(self, config): + def __init__(self, config: Dict[str, Any]) -> None: super(ZooKeeper, self).__init__(config) hosts = config.get('hosts', []) @@ -94,17 +98,18 @@ class ZooKeeper(AbstractDCS): kwargs = {v: config[k] for k, v in mapping.items() if k in config} if 'set_acls' in config: - kwargs['default_acl'] = [] + default_acl: List[ACL] = [] for principal, permissions in config['set_acls'].items(): normalizedPermissions = [p.upper() for p in permissions] - kwargs['default_acl'].append(make_acl(scheme='x509', - credential=principal, - read='READ' in normalizedPermissions, - write='WRITE' in normalizedPermissions, - create='CREATE' in normalizedPermissions, - delete='DELETE' in normalizedPermissions, - admin='ADMIN' in normalizedPermissions, - all='ALL' in normalizedPermissions)) + default_acl.append(make_acl(scheme='x509', + credential=principal, + read='READ' in normalizedPermissions, + write='WRITE' in normalizedPermissions, + create='CREATE' in normalizedPermissions, + delete='DELETE' in normalizedPermissions, + admin='ADMIN' in normalizedPermissions, + all='ALL' in normalizedPermissions)) + kwargs['default_acl'] = default_acl self._client = PatroniKazooClient(hosts, handler=PatroniSequentialThreadingHandler(config['retry_timeout']), timeout=config['ttl'], connection_retry=KazooRetry(max_delay=1, max_tries=-1, @@ -112,16 +117,16 @@ class ZooKeeper(AbstractDCS): deadline=config['retry_timeout'], sleep_func=time.sleep), **kwargs) self._client.add_listener(self.session_listener) - self._fetch_cluster = True - self._fetch_status = True - self.__last_member_data = None + self._fetch_cluster: bool = True + self._fetch_status: bool = True + self.__last_member_data: Optional[Dict[str, Any]] = None self._orig_kazoo_connect = self._client._connection._connect self._client._connection._connect = self._kazoo_connect self._client.start() - def _kazoo_connect(self, *args): + def _kazoo_connect(self, *args: Any) -> Tuple[Union[int, float], Union[int, float]]: """Kazoo is using Ping's to determine health of connection to zookeeper. If there is no response on Ping after Ping interval (1/2 from read_timeout) it will consider current connection dead and try to connect to another node. Without this "magic" it was taking @@ -136,30 +141,28 @@ class ZooKeeper(AbstractDCS): ret = self._orig_kazoo_connect(*args) return max(self.loop_wait - 2, 2) * 1000, ret[1] - def session_listener(self, state): + def session_listener(self, state: str) -> None: if state in [KazooState.SUSPENDED, KazooState.LOST]: self.cluster_watcher(None) - def status_watcher(self, event): + def status_watcher(self, event: Optional[WatchedEvent]) -> None: self._fetch_status = True self.event.set() - def cluster_watcher(self, event): + def cluster_watcher(self, event: Optional[WatchedEvent]) -> None: self._fetch_cluster = True if not event or event.state != KazooState.CONNECTED or event.path.startswith(self.client_path('')): self.status_watcher(event) - def members_watcher(self, event): - self._fetch_cluster = True - - def reload_config(self, config): + def reload_config(self, config: Union['Config', Dict[str, Any]]) -> None: self.set_retry_timeout(config['retry_timeout']) loop_wait = config['loop_wait'] loop_wait_changed = self._loop_wait != loop_wait self._loop_wait = loop_wait - self._client.handler.set_connect_timeout(loop_wait) + if isinstance(self._client.handler, PatroniSequentialThreadingHandler): + self._client.handler.set_connect_timeout(loop_wait) # We need to reestablish connection to zookeeper if we want to change # read_timeout (and Ping interval respectively), because read_timeout @@ -170,7 +173,7 @@ class ZooKeeper(AbstractDCS): if not self.set_ttl(config['ttl']) and loop_wait_changed: self._client._connection._socket.close() - def set_ttl(self, ttl): + def set_ttl(self, ttl: int) -> Optional[bool]: """It is not possible to change ttl (session_timeout) in zookeeper without destroying old session and creating the new one. This method returns `!True` if session_timeout has been changed (`restart()` has been called).""" @@ -181,21 +184,23 @@ class ZooKeeper(AbstractDCS): return True @property - def ttl(self): - return self._client._session_timeout / 1000.0 + def ttl(self) -> int: + return int(self._client._session_timeout / 1000.0) - def set_retry_timeout(self, retry_timeout): + def set_retry_timeout(self, retry_timeout: int) -> None: retry = self._client.retry if isinstance(self._client.retry, KazooRetry) else self._client._retry retry.deadline = retry_timeout - def get_node(self, key, watch=None): + def get_node( + self, key: str, watch: Optional[Callable[[WatchedEvent], None]] = None + ) -> Optional[Tuple[str, ZnodeStat]]: try: ret = self._client.get(key, watch) return (ret[0].decode('utf-8'), ret[1]) except NoNodeError: return None - def get_status(self, path, leader): + def get_status(self, path: str, leader: Optional[Leader]) -> Tuple[int, Optional[Dict[str, int]]]: watch = self.status_watcher if not leader or leader.name != self._name else None status = self.get_node(path + self._STATUS, watch) @@ -212,7 +217,7 @@ class ZooKeeper(AbstractDCS): slots = None try: - last_lsn = int(last_lsn) + last_lsn = int(last_lsn or '') except Exception: last_lsn = 0 @@ -220,24 +225,24 @@ class ZooKeeper(AbstractDCS): return last_lsn, slots @staticmethod - def member(name, value, znode): + def member(name: str, value: str, znode: ZnodeStat) -> Member: return Member.from_node(znode.version, name, znode.ephemeralOwner, value) - def get_children(self, key, watch=None): + def get_children(self, key: str, watch: Optional[Callable[[WatchedEvent], None]] = None) -> List[str]: try: return self._client.get_children(key, watch) except NoNodeError: return [] - def load_members(self, path): - members = [] + def load_members(self, path: str) -> List[Member]: + members: List[Member] = [] for member in self.get_children(path + self._MEMBERS, self.cluster_watcher): data = self.get_node(path + self._MEMBERS + member) if data is not None: members.append(self.member(member, *data)) return members - def _cluster_loader(self, path): + def _cluster_loader(self, path: str) -> Cluster: self._fetch_cluster = False self.event.clear() nodes = set(self.get_children(path, self.cluster_watcher)) @@ -286,9 +291,9 @@ class ZooKeeper(AbstractDCS): return Cluster(initialize, config, leader, last_lsn, members, failover, sync, history, slots, failsafe) - def _citus_cluster_loader(self, path): + def _citus_cluster_loader(self, path: str) -> Dict[int, Cluster]: fetch_cluster = False - ret = {} + ret: Dict[int, Cluster] = {} for node in self.get_children(path, self.cluster_watcher): if citus_group_re.match(node): ret[int(node)] = self._cluster_loader(path + node + '/') @@ -296,7 +301,9 @@ class ZooKeeper(AbstractDCS): self._fetch_cluster = fetch_cluster return ret - def _load_cluster(self, path, loader): + def _load_cluster( + self, path: str, loader: Callable[[str], Union[Cluster, Dict[int, Cluster]]] + ) -> Union[Cluster, Dict[int, Cluster]]: cluster = self.cluster if path == self._base_path + '/' else None if self._fetch_cluster or cluster is None: try: @@ -315,18 +322,18 @@ class ZooKeeper(AbstractDCS): try: last_lsn, slots = self.get_status(self.client_path(''), cluster.leader) self.event.clear() - cluster = list(cluster) - cluster[3] = last_lsn - cluster[8] = slots - cluster = Cluster(*cluster) + new_cluster: List[Any] = list(cluster) + new_cluster[3] = last_lsn + new_cluster[8] = slots + cluster = Cluster(*new_cluster) except Exception: pass return cluster - def _bypass_caches(self): + def _bypass_caches(self) -> None: self._fetch_cluster = True - def _create(self, path, value, retry=False, ephemeral=False): + def _create(self, path: str, value: bytes, retry: bool = False, ephemeral: bool = False) -> bool: try: if retry: self._client.retry(self._client.create, path, value, makepath=True, ephemeral=ephemeral) @@ -337,7 +344,7 @@ class ZooKeeper(AbstractDCS): logger.exception('Failed to create %s', path) return False - def attempt_to_acquire_leader(self): + def attempt_to_acquire_leader(self) -> bool: try: self._client.retry(self._client.create, self.leader_path, self._name.encode('utf-8'), makepath=True, ephemeral=True) @@ -350,43 +357,44 @@ class ZooKeeper(AbstractDCS): logger.info('Could not take out TTL lock') return False - def _set_or_create(self, key, value, index=None, retry=False, do_not_create_empty=False): - value = value.encode('utf-8') + def _set_or_create(self, key: str, value: str, index: Optional[int] = None, + retry: bool = False, do_not_create_empty: bool = False) -> bool: + value_bytes = value.encode('utf-8') try: if retry: - self._client.retry(self._client.set, key, value, version=index or -1) + self._client.retry(self._client.set, key, value_bytes, version=index or -1) else: - self._client.set_async(key, value, version=index or -1).get(timeout=1) + self._client.set_async(key, value_bytes, version=index or -1).get(timeout=1) return True except NoNodeError: - if do_not_create_empty and not value: + if do_not_create_empty and not value_bytes: return True elif index is None: - return self._create(key, value, retry) + return self._create(key, value_bytes, retry) else: return False except Exception: logger.exception('Failed to update %s', key) return False - def set_failover_value(self, value, index=None): + def set_failover_value(self, value: str, index: Optional[int] = None) -> bool: return self._set_or_create(self.failover_path, value, index) - def set_config_value(self, value, index=None): + def set_config_value(self, value: str, index: Optional[int] = None) -> bool: return self._set_or_create(self.config_path, value, index, retry=True) - def initialize(self, create_new=True, sysid=""): - sysid = sysid.encode('utf-8') - return self._create(self.initialize_path, sysid, retry=True) if create_new \ - else self._client.retry(self._client.set, self.initialize_path, sysid) + def initialize(self, create_new: bool = True, sysid: str = "") -> bool: + sysid_bytes = sysid.encode('utf-8') + return self._create(self.initialize_path, sysid_bytes, retry=True) if create_new \ + else self._client.retry(self._client.set, self.initialize_path, sysid_bytes) - def touch_member(self, data): + def touch_member(self, data: Dict[str, Any]) -> bool: cluster = self.cluster member = cluster and cluster.get_member(self._name, fallback_to_leader=False) member_data = self.__last_member_data or member and member.data # We want to notify leader if some important fields in the member key changed by removing ZNode if member and (self._client.client_id is not None and member.session != self._client.client_id[0] - or not (deep_compare(member_data.get('tags', {}), data.get('tags', {})) + or not (member_data and deep_compare(member_data.get('tags', {}), data.get('tags', {})) and (member_data.get('state') == data.get('state') or 'running' not in (member_data.get('state'), data.get('state'))) and member_data.get('version') == data.get('version') @@ -401,7 +409,7 @@ class ZooKeeper(AbstractDCS): member = None encoded_data = json.dumps(data, separators=(',', ':')).encode('utf-8') - if member: + if member and member_data: if deep_compare(data, member_data): return True else: @@ -422,19 +430,19 @@ class ZooKeeper(AbstractDCS): return False - def take_leader(self): + def take_leader(self) -> bool: return self.attempt_to_acquire_leader() - def _write_leader_optime(self, last_lsn): + def _write_leader_optime(self, last_lsn: str) -> bool: return self._set_or_create(self.leader_optime_path, last_lsn) - def _write_status(self, value): + def _write_status(self, value: str) -> bool: return self._set_or_create(self.status_path, value) - def _write_failsafe(self, value): + def _write_failsafe(self, value: str) -> bool: return self._set_or_create(self.failsafe_path, value) - def _update_leader(self): + def _update_leader(self) -> bool: cluster = self.cluster session = cluster and isinstance(cluster.leader, Leader) and cluster.leader.session if self._client.client_id and self._client.client_id[0] != session: @@ -459,37 +467,39 @@ class ZooKeeper(AbstractDCS): return False return True - def _delete_leader(self): + def _delete_leader(self) -> bool: self._client.restart() return True - def _cancel_initialization(self): + def _cancel_initialization(self) -> None: node = self.get_node(self.initialize_path) if node: self._client.delete(self.initialize_path, version=node[1].version) - def cancel_initialization(self): + def cancel_initialization(self) -> bool: try: self._client.retry(self._cancel_initialization) + return True except Exception: logger.exception("Unable to delete initialize key") + return False - def delete_cluster(self): + def delete_cluster(self) -> bool: try: return self._client.retry(self._client.delete, self.client_path(''), recursive=True) except NoNodeError: return True - def set_history_value(self, value): + def set_history_value(self, value: str) -> bool: return self._set_or_create(self.history_path, value) - def set_sync_state_value(self, value, index=None): + def set_sync_state_value(self, value: str, index: Optional[int] = None) -> bool: return self._set_or_create(self.sync_path, value, index, retry=True, do_not_create_empty=True) - def delete_sync_state(self, index=None): + def delete_sync_state(self, index: Optional[int] = None) -> bool: return self.set_sync_state_value("{}", index) - def watch(self, leader_index, timeout): + def watch(self, leader_index: Optional[int], timeout: float) -> bool: ret = super(ZooKeeper, self).watch(leader_index, timeout + 0.5) if ret and not self._fetch_status: self._fetch_cluster = True diff --git a/patroni/exceptions.py b/patroni/exceptions.py index 21df1f7b..88edfe07 100644 --- a/patroni/exceptions.py +++ b/patroni/exceptions.py @@ -1,11 +1,14 @@ +from typing import Any + + class PatroniException(Exception): """Parent class for all kind of exceptions related to selected distributed configuration store""" - def __init__(self, value): + def __init__(self, value: Any) -> None: self.value = value - def __str__(self): + def __str__(self) -> str: """ >>> str(PatroniException('foo')) "'foo'" diff --git a/patroni/ha.py b/patroni/ha.py index 444e8141..b994ad9a 100644 --- a/patroni/ha.py +++ b/patroni/ha.py @@ -6,27 +6,26 @@ import sys import time import uuid -from collections import namedtuple from multiprocessing.pool import ThreadPool from threading import RLock -from typing import List, Optional, Union +from typing import Any, Callable, Collection, Dict, List, NamedTuple, Optional, Union, Tuple, TYPE_CHECKING from . import psycopg +from .__main__ import Patroni from .async_executor import AsyncExecutor, CriticalTask from .collections import CaseInsensitiveSet +from .dcs import AbstractDCS, Cluster, Leader, Member, RemoteMember from .exceptions import DCSError, PostgresConnectionException, PatroniFatalException from .postgresql.callback_executor import CallbackAction from .postgresql.misc import postgres_version_to_int +from .postgresql.postmaster import PostmasterProcess from .postgresql.rewind import Rewind from .utils import polling_loop, tzutc -from .dcs import Cluster, Leader, Member, RemoteMember logger = logging.getLogger(__name__) -class _MemberStatus(namedtuple('_MemberStatus', ['member', 'reachable', 'in_recovery', - 'dcs_last_seen', 'timeline', 'wal_position', - 'tags', 'watchdog_failed'])): +class _MemberStatus(NamedTuple): """Node status distilled from API response: member - dcs.Member object of the node @@ -38,29 +37,37 @@ class _MemberStatus(namedtuple('_MemberStatus', ['member', 'reachable', 'in_reco tags - dictionary with values of different tags (i.e. nofailover) watchdog_failed - indicates that watchdog is required by configuration but not available or failed """ + member: Member + reachable: bool + in_recovery: Optional[bool] + dcs_last_seen: int + timeline: int + wal_position: int + tags: Dict[str, Any] + watchdog_failed: bool + @classmethod - def from_api_response(cls, member, json): + def from_api_response(cls, member: Member, json: Dict[str, Any]) -> '_MemberStatus': """ :param member: dcs.Member object :param json: RestApiHandler.get_postgresql_status() result :returns: _MemberStatus object """ # If one of those is not in a response we want to count the node as not healthy/reachable - assert 'wal' in json or 'xlog' in json - - wal = json.get('wal', json.get('xlog')) - in_recovery = not bool(wal.get('location')) # abuse difference in primary/replica response format + wal: Dict[str, Any] = json.get('wal') or json['xlog'] + # abuse difference in primary/replica response format + in_recovery = not bool(wal.get('location')) or json.get('role') in ('master', 'primary') timeline = json.get('timeline', 0) dcs_last_seen = json.get('dcs_last_seen', 0) - wal = in_recovery and max(wal.get('received_location', 0), wal.get('replayed_location', 0)) - return cls(member, True, in_recovery, dcs_last_seen, timeline, wal, + lsn = int(in_recovery and max(wal.get('received_location', 0), wal.get('replayed_location', 0))) + return cls(member, True, in_recovery, dcs_last_seen, timeline, lsn, json.get('tags', {}), json.get('watchdog_failed', False)) @classmethod - def unknown(cls, member): + def unknown(cls, member: Member) -> '_MemberStatus': return cls(member, False, None, 0, 0, 0, {}, False) - def failover_limitation(self): + def failover_limitation(self) -> Optional[str]: """Returns reason why this node can't promote or None if everything is ok.""" if not self.reachable: return 'not reachable' @@ -73,7 +80,7 @@ class _MemberStatus(namedtuple('_MemberStatus', ['member', 'reachable', 'in_reco class Failsafe(object): - def __init__(self, dcs): + def __init__(self, dcs: AbstractDCS) -> None: self._lock = RLock() self._dcs = dcs self._last_update = 0 @@ -82,7 +89,7 @@ class Failsafe(object): self._api_url = None self._slots = None - def update(self, data): + def update(self, data: Dict[str, Any]) -> None: with self._lock: self._last_update = time.time() self._name = data['name'] @@ -91,26 +98,22 @@ class Failsafe(object): self._slots = data.get('slots') @property - def leader(self): + def leader(self) -> Optional[Leader]: with self._lock: if self._last_update + self._dcs.ttl > time.time(): - return Leader(None, None, - RemoteMember(self._name, {'api_url': self._api_url, - 'conn_url': self._conn_url, - 'slots': self._slots})) + return Leader('', '', RemoteMember(self._name, {'api_url': self._api_url, + 'conn_url': self._conn_url, + 'slots': self._slots})) - def update_cluster(self, cluster): + def update_cluster(self, cluster: Cluster) -> Cluster: # Enreach cluster with the real leader if there was a ping from it leader = self.leader if leader: - cluster = list(cluster) # We rely on the strict order of fields in the namedtuple - cluster[2] = leader - cluster[8] = leader.member.data['slots'] - cluster = Cluster(*cluster) + cluster = Cluster(*cluster[0:2], leader, *cluster[3:8], leader.member.data['slots'], *cluster[9:]) return cluster - def is_active(self): + def is_active(self) -> bool: """Is used to report in REST API whether the failsafe mode was activated. On primary the self._last_update is set from the @@ -124,21 +127,21 @@ class Failsafe(object): with self._lock: return self._last_update + self._dcs.ttl > time.time() - def set_is_active(self, value): + def set_is_active(self, value: float) -> None: with self._lock: self._last_update = value class Ha(object): - def __init__(self, patroni): + def __init__(self, patroni: Patroni): self.patroni = patroni self.state_handler = patroni.postgresql self._rewind = Rewind(self.state_handler) self.dcs = patroni.dcs - self.cluster = None + self.cluster = Cluster.empty() self.global_config = self.patroni.config.get_global_config(None) - self.old_cluster = None + self.old_cluster = Cluster.empty() self._is_leader = False self._is_leader_lock = RLock() self._failsafe = Failsafe(patroni.dcs) @@ -147,7 +150,7 @@ class Ha(object): self.recovering = False self._async_response = CriticalTask() self._crash_recovery_executed = False - self._crash_recovery_started = None + self._crash_recovery_started = 0 self._start_timeout = None self._async_executor = AsyncExecutor(self.state_handler.cancellable, self.wakeup) self.watchdog = patroni.watchdog @@ -173,27 +176,27 @@ class Ha(object): ret = self.global_config.primary_stop_timeout return ret if ret > 0 and self.is_synchronous_mode() else None - def is_paused(self): + def is_paused(self) -> bool: """:returns: `True` if in maintenance mode.""" return self.global_config.is_paused - def check_timeline(self): + def check_timeline(self) -> bool: """:returns: `True` if should check whether the timeline is latest during the leader race.""" return self.global_config.check_mode('check_timeline') - def is_standby_cluster(self): + def is_standby_cluster(self) -> bool: """:returns: `True` if global configuration has a valid "standby_cluster" section.""" return self.global_config.is_standby_cluster - def is_leader(self): + def is_leader(self) -> bool: with self._is_leader_lock: return self._is_leader > time.time() - def set_is_leader(self, value): + def set_is_leader(self, value: bool) -> None: with self._is_leader_lock: self._is_leader = time.time() + self.dcs.ttl if value else 0 - def load_cluster_from_dcs(self): + def load_cluster_from_dcs(self) -> None: cluster = self.dcs.get_cluster() # We want to keep the state of cluster when it was healthy @@ -208,9 +211,9 @@ class Ha(object): if not self.has_lock(False): self.set_is_leader(False) - self._leader_timeline = None if cluster.is_unlocked() else cluster.leader.timeline + self._leader_timeline = cluster.leader.timeline if cluster.leader else None - def acquire_lock(self): + def acquire_lock(self) -> bool: try: ret = self.dcs.attempt_to_acquire_leader() except DCSError: @@ -221,14 +224,14 @@ class Ha(object): self.set_is_leader(ret) return ret - def _failsafe_config(self): + def _failsafe_config(self) -> Optional[Dict[str, str]]: if self.is_failsafe_mode(): - ret = {m.name: m.api_url for m in self.cluster.members} + ret = {m.name: m.api_url for m in self.cluster.members if m.api_url} if self.state_handler.name not in ret: ret[self.state_handler.name] = self.patroni.api.connection_string return ret - def update_lock(self, write_leader_optime=False): + def update_lock(self, write_leader_optime: bool = False) -> bool: last_lsn = slots = None if write_leader_optime: try: @@ -248,13 +251,13 @@ class Ha(object): self.watchdog.keepalive() return ret - def has_lock(self, info=True): + def has_lock(self, info: bool = True) -> bool: lock_owner = self.cluster.leader and self.cluster.leader.name if info: logger.info('Lock owner: %s; I am %s', lock_owner, self.state_handler.name) return lock_owner == self.state_handler.name - def get_effective_tags(self): + def get_effective_tags(self) -> Dict[str, Any]: """Return configuration tags merged with dynamically applied tags.""" tags = self.patroni.tags.copy() # _disable_sync could be modified concurrently, but we don't care as attribute get and set are atomic. @@ -262,10 +265,10 @@ class Ha(object): tags['nosync'] = True return tags - def notify_citus_coordinator(self, event): + def notify_citus_coordinator(self, event: str) -> None: if self.state_handler.citus_handler.is_worker(): coordinator = self.dcs.get_citus_coordinator() - if coordinator and coordinator.leader and coordinator.leader.conn_kwargs: + if coordinator and coordinator.leader and coordinator.leader.conn_url: try: data = {'type': event, 'group': self.state_handler.citus_handler.group(), @@ -278,9 +281,9 @@ class Ha(object): logger.warning('Request to Citus coordinator leader %s %s failed: %r', coordinator.leader.name, coordinator.leader.member.api_url, e) - def touch_member(self): + def touch_member(self) -> bool: with self._member_state_lock: - data = { + data: Dict[str, Any] = { 'conn_url': self.state_handler.connection_string, 'api_url': self.patroni.api.connection_string, 'state': self.state_handler.state, @@ -302,6 +305,7 @@ class Ha(object): if self._async_executor.scheduled_action in (None, 'promote') \ and data['state'] in ['running', 'restarting', 'starting']: try: + timeline: Optional[int] timeline, wal_position, pg_control_timeline = self.state_handler.timeline_wal_position() data['xlog_location'] = wal_position if not timeline: # try pg_stat_wal_receiver to get the timeline @@ -317,7 +321,7 @@ class Ha(object): if self.state_handler.role == 'standby_leader': timeline = pg_control_timeline or self.state_handler.pg_control_timeline() else: - timeline = self.state_handler.replica_cached_timeline(self._leader_timeline) + timeline = self.state_handler.replica_cached_timeline(self._leader_timeline) or 0 if timeline: data['timeline'] = timeline except Exception: @@ -338,7 +342,7 @@ class Ha(object): self._last_state = new_state return ret - def clone(self, clone_member=None, msg='(without leader)'): + def clone(self, clone_member: Union[Leader, Member, None] = None, msg: str = '(without leader)') -> Optional[bool]: if self.is_standby_cluster() and not isinstance(clone_member, RemoteMember): clone_member = self.get_remote_member(clone_member) @@ -352,16 +356,10 @@ class Ha(object): logger.error('failed to bootstrap %s', msg) self.state_handler.remove_data_directory() - def bootstrap(self): - if not self.cluster.is_unlocked(): # cluster already has leader - clone_member = self.cluster.get_clone_member(self.state_handler.name) - member_role = 'leader' if clone_member == self.cluster.leader else 'replica' - msg = "from {0} '{1}'".format(member_role, clone_member.name) - ret = self._async_executor.try_run_async('bootstrap {0}'.format(msg), self.clone, args=(clone_member, msg)) - return ret or 'trying to bootstrap {0}'.format(msg) - + def bootstrap(self) -> str: # no initialize key and node is allowed to be primary and has 'bootstrap' section in a configuration file - elif self.cluster.initialize is None and not self.patroni.nofailover and 'bootstrap' in self.patroni.config: + if self.cluster.is_unlocked() and self.cluster.initialize is None\ + and not self.patroni.nofailover and 'bootstrap' in self.patroni.config: if self.dcs.initialize(create_new=True): # race for initialization self.state_handler.bootstrapping = True with self._async_response: @@ -376,17 +374,26 @@ class Ha(object): return ret or 'trying to bootstrap a new cluster' else: return 'failed to acquire initialize lock' - else: - create_replica_methods = self.global_config.get_standby_cluster_config().get('create_replica_methods', []) \ - if self.is_standby_cluster() else None - can_bootstrap = self.state_handler.can_create_replica_without_replication_connection(create_replica_methods) - concurrent_bootstrap = self.cluster.initialize == "" - if can_bootstrap and not concurrent_bootstrap: - msg = 'bootstrap (without leader)' - return self._async_executor.try_run_async(msg, self.clone) or 'trying to ' + msg - return 'waiting for {0}leader to bootstrap'.format('standby_' if self.is_standby_cluster() else '') - def bootstrap_standby_leader(self): + clone_member = self.cluster.get_clone_member(self.state_handler.name) + # cluster already has a leader, we can bootstrap from it or from one of replicas (if they allow) + if not self.cluster.is_unlocked() and clone_member: + member_role = 'leader' if clone_member == self.cluster.leader else 'replica' + msg = "from {0} '{1}'".format(member_role, clone_member.name) + ret = self._async_executor.try_run_async('bootstrap {0}'.format(msg), self.clone, args=(clone_member, msg)) + return ret or 'trying to bootstrap {0}'.format(msg) + + # no leader, but configuration may allowed replica creation using backup tools + create_replica_methods = self.global_config.get_standby_cluster_config().get('create_replica_methods', []) \ + if self.is_standby_cluster() else None + can_bootstrap = self.state_handler.can_create_replica_without_replication_connection(create_replica_methods) + concurrent_bootstrap = self.cluster.initialize == "" + if can_bootstrap and not concurrent_bootstrap: + msg = 'bootstrap (without leader)' + return self._async_executor.try_run_async(msg, self.clone) or 'trying to ' + msg + return 'waiting for {0}leader to bootstrap'.format('standby_' if self.is_standby_cluster() else '') + + def bootstrap_standby_leader(self) -> Optional[bool]: """ If we found 'standby' key in the configuration, we need to bootstrap not a real primary, but a 'standby leader', that will take base backup from a remote member and start follow it. @@ -401,16 +408,16 @@ class Ha(object): return result - def _handle_crash_recovery(self): + def _handle_crash_recovery(self) -> Optional[str]: if not self._crash_recovery_executed and (self.cluster.is_unlocked() or self._rewind.can_rewind): self._crash_recovery_executed = True self._crash_recovery_started = time.time() msg = 'doing crash recovery in a single user mode' return self._async_executor.try_run_async(msg, self._rewind.ensure_clean_shutdown) or msg - def _handle_rewind_or_reinitialize(self): + def _handle_rewind_or_reinitialize(self) -> Optional[str]: leader = self.get_remote_member() if self.is_standby_cluster() else self.cluster.leader - if not self._rewind.rewind_or_reinitialize_needed_and_possible(leader): + if not self._rewind.rewind_or_reinitialize_needed_and_possible(leader) or not leader: return None if self._rewind.can_rewind: @@ -428,7 +435,7 @@ class Ha(object): msg = 'reinitializing due to diverged timelines' return self._async_executor.try_run_async(msg, self._do_reinitialize, args=(self.cluster,)) or msg - def recover(self): + def recover(self) -> str: # Postgres is not running and we will restart in standby mode. Watchdog is not needed until we promote. self.watchdog.disable() @@ -454,7 +461,10 @@ class Ha(object): self.load_cluster_from_dcs() role = 'replica' - if self.is_standby_cluster() or not self.has_lock(): + if self.has_lock() and not self.is_standby_cluster(): + msg = "starting as readonly because i had the session lock" + node_to_follow = None + else: if not self._rewind.executed: self._rewind.trigger_check_diverged_lsn() msg = self._handle_rewind_or_reinitialize() @@ -474,16 +484,13 @@ class Ha(object): if self.is_synchronous_mode(): self.state_handler.sync_handler.set_synchronous_standby_names(CaseInsensitiveSet()) - elif self.has_lock(): - msg = "starting as readonly because i had the session lock" - node_to_follow = None if self._async_executor.try_run_async('restarting after failure', self.state_handler.follow, args=(node_to_follow, role, timeout)) is None: self.recovering = True return msg - def _get_node_to_follow(self, cluster): + def _get_node_to_follow(self, cluster: Cluster) -> Union[Leader, Member, None]: # determine the node to follow. If replicatefrom tag is set, # try to follow the node mentioned there, otherwise, follow the leader. if self.is_standby_cluster() and (self.cluster.is_unlocked() or self.has_lock(False)): @@ -506,7 +513,7 @@ class Ha(object): return node_to_follow - def follow(self, demote_reason, follow_reason, refresh=True): + def follow(self, demote_reason: str, follow_reason: str, refresh: bool = True) -> str: if refresh: self.load_cluster_from_dcs() @@ -565,7 +572,7 @@ class Ha(object): """:returns: `True` if synchronous replication is requested.""" return self.global_config.is_synchronous_mode - def is_failsafe_mode(self): + def is_failsafe_mode(self) -> bool: """:returns: `True` if failsafe_mode is enabled in global configuration.""" return self.global_config.check_mode('failsafe_mode') @@ -625,10 +632,10 @@ class Ha(object): def is_sync_standby(self, cluster: Cluster) -> bool: """:returns: `True` if the current node is a synchronous standby.""" - return cluster.leader and cluster.sync.leader_matches(cluster.leader.name) \ + return bool(cluster.leader) and cluster.sync.leader_matches(cluster.leader.name) \ and cluster.sync.matches(self.state_handler.name) - def while_not_sync_standby(self, func): + def while_not_sync_standby(self, func: Callable[..., Any]) -> Any: """Runs specified action while trying to make sure that the node is not assigned synchronous standby status. Tags us as not allowed to be a sync standby as we are going to go away, if we currently are wait for @@ -665,32 +672,32 @@ class Ha(object): with self._member_state_lock: self._disable_sync -= 1 - def update_cluster_history(self): + def update_cluster_history(self) -> None: primary_timeline = self.state_handler.get_primary_timeline() - cluster_history = self.cluster.history and self.cluster.history.lines + cluster_history = self.cluster.history.lines if self.cluster.history else [] if primary_timeline == 1: if cluster_history: self.dcs.set_history_value('[]') elif not cluster_history or cluster_history[-1][0] != primary_timeline - 1 or len(cluster_history[-1]) != 5: - cluster_history = {line[0]: line for line in cluster_history or []} - history = self.state_handler.get_history(primary_timeline) - if history and self.cluster.config: + cluster_history = {line[0]: line for line in cluster_history} + history: List[List[Any]] = list(map(list, self.state_handler.get_history(primary_timeline))) + if self.cluster.config: history = history[-self.cluster.config.max_timelines_history:] - for line in history: - # enrich current history with promotion timestamps stored in DCS - if len(line) == 3 and line[0] in cluster_history \ - and len(cluster_history[line[0]]) >= 4 \ - and cluster_history[line[0]][1] == line[1]: - line.append(cluster_history[line[0]][3]) - if len(cluster_history[line[0]]) == 5: - line.append(cluster_history[line[0]][4]) + for line in history: + # enrich current history with promotion timestamps stored in DCS + cluster_history_line = list(cluster_history.get(line[0], [])) + if len(line) == 3 and len(cluster_history_line) >= 4 and cluster_history_line[1] == line[1]: + line.append(cluster_history_line[3]) + if len(cluster_history_line) == 5: + line.append(cluster_history_line[4]) + if history: self.dcs.set_history_value(json.dumps(history, separators=(',', ':'))) - def enforce_follow_remote_member(self, message): + def enforce_follow_remote_member(self, message: str) -> str: demote_reason = 'cannot be a real primary in standby cluster' return self.follow(demote_reason, message) - def enforce_primary_role(self, message, promote_message): + def enforce_primary_role(self, message: str, promote_message: str) -> str: """ Ensure the node that has won the race for the leader key meets criteria for promoting its PG server to the 'primary' role. @@ -749,7 +756,7 @@ class Ha(object): before_promote, on_success)) return promote_message - def fetch_node_status(self, member): + def fetch_node_status(self, member: Member) -> _MemberStatus: """This function perform http get request on member.api_url and fetches its status :returns: `_MemberStatus` object """ @@ -763,36 +770,36 @@ class Ha(object): logger.warning("Request failed to %s: GET %s (%s)", member.name, member.api_url, e) return _MemberStatus.unknown(member) - def fetch_nodes_statuses(self, members): + def fetch_nodes_statuses(self, members: List[Member]) -> List[_MemberStatus]: pool = ThreadPool(len(members)) results = pool.map(self.fetch_node_status, members) # Run API calls on members in parallel pool.close() pool.join() return results - def update_failsafe(self, data): + def update_failsafe(self, data: Dict[str, Any]) -> Optional[str]: if self.state_handler.state == 'running' and self.state_handler.role in ('master', 'primary'): return 'Running as a leader' self._failsafe.update(data) - def failsafe_is_active(self): + def failsafe_is_active(self) -> bool: return self._failsafe.is_active() - def call_failsafe_member(self, data, member): + def call_failsafe_member(self, data: Dict[str, Any], member: Member) -> bool: try: response = self.patroni.request(member, 'post', 'failsafe', data, timeout=2, retries=1) - data = response.data.decode('utf-8') - logger.info('Got response from %s %s: %s', member.name, member.api_url, data) - return response.status == 200 and data == 'Accepted' + response_data = response.data.decode('utf-8') + logger.info('Got response from %s %s: %s', member.name, member.api_url, response_data) + return response.status == 200 and response_data == 'Accepted' except Exception as e: logger.warning("Request failed to %s: POST %s (%s)", member.name, member.api_url, e) return False - def check_failsafe_topology(self): + def check_failsafe_topology(self) -> bool: failsafe = self.dcs.failsafe if not isinstance(failsafe, dict) or self.state_handler.name not in failsafe: return False - data = { + data: Dict[str, Any] = { 'name': self.state_handler.name, 'conn_url': self.state_handler.connection_string, 'api_url': self.patroni.api.connection_string, @@ -813,7 +820,7 @@ class Ha(object): pool.join() return all(results) - def is_lagging(self, wal_position): + def is_lagging(self, wal_position: int) -> bool: """Returns if instance with an wal should consider itself unhealthy to be promoted due to replication lag. :param wal_position: Current wal position. @@ -822,7 +829,7 @@ class Ha(object): lag = (self.cluster.last_lsn or 0) - wal_position return lag > self.global_config.maximum_lag_on_failover - def _is_healthiest_node(self, members, check_replication_lag=True): + def _is_healthiest_node(self, members: Collection[Member], check_replication_lag: bool = True) -> bool: """This method tries to determine whether I am healthy enough to became a new leader candidate or not.""" my_wal_position = self.state_handler.last_operation() @@ -833,6 +840,9 @@ class Ha(object): if not self.is_standby_cluster() and self.check_timeline(): cluster_timeline = self.cluster.timeline my_timeline = self.state_handler.replica_cached_timeline(cluster_timeline) + if my_timeline is None: + logger.info('Can not figure out my timeline') + return False if my_timeline < cluster_timeline: logger.info('My timeline %s is behind last known cluster timeline %s', my_timeline, cluster_timeline) return False @@ -843,7 +853,7 @@ class Ha(object): if members: for st in self.fetch_nodes_statuses(members): if st.failover_limitation() is None: - if not st.in_recovery: + if st.in_recovery is False: logger.warning('Primary (%s) is still alive', st.member.name) return False if my_wal_position < st.wal_position: @@ -875,8 +885,6 @@ class Ha(object): not_allowed_reason = st.failover_limitation() if not_allowed_reason: logger.info('Member %s is %s', st.member.name, not_allowed_reason) - elif not isinstance(st.wal_position, int): - logger.info('Member %s does not report wal_position', st.member.name) elif cluster_lsn and st.wal_position < cluster_lsn or\ not cluster_lsn and self.is_lagging(st.wal_position): logger.info('Member %s exceeds maximum replication lag', st.member.name) @@ -889,13 +897,15 @@ class Ha(object): logger.warning('manual failover: members list is empty') return ret - def manual_failover_process_no_leader(self) -> Union[bool, None]: + def manual_failover_process_no_leader(self) -> Optional[bool]: """Handles manual failover/switchover when the old leader already stepped down. :returns: - `True` if the current node is the best candidate to become the new leader - `None` if the current node is running as a primary and requested candidate doesn't exist """ failover = self.cluster.failover + if TYPE_CHECKING: # pragma: no cover + assert failover is not None if failover.candidate: # manual failover to specific member if failover.candidate == self.state_handler.name: # manual failover to me return True @@ -905,7 +915,7 @@ class Ha(object): if not self.cluster.get_member(failover.candidate, fallback_to_leader=False)\ and self.state_handler.is_leader(): logger.warning("manual failover: removing failover key because failover candidate is not running") - self.dcs.manual_failover('', '', index=self.cluster.failover.index) + self.dcs.manual_failover('', '', index=failover.index) return None return False @@ -916,7 +926,7 @@ class Ha(object): # find specific node and check that it is healthy member = self.cluster.get_member(failover.candidate, fallback_to_leader=False) - if member: + if isinstance(member, Member): st = self.fetch_node_status(member) not_allowed_reason = st.failover_limitation() if not_allowed_reason is None: # node is healthy @@ -981,7 +991,7 @@ class Ha(object): if self.is_synchronous_mode() and self.cluster.failover.leader and \ not self.cluster.sync.is_empty and not self.cluster.sync.matches(self.state_handler.name, True): return False - return self.manual_failover_process_no_leader() + return self.manual_failover_process_no_leader() or False if not self.watchdog.is_healthy: logger.warning('Watchdog device is not usable') @@ -1011,17 +1021,17 @@ class Ha(object): return self._is_healthiest_node(members.values()) - def _delete_leader(self, last_lsn=None): + def _delete_leader(self, last_lsn: Optional[int] = None) -> None: self.set_is_leader(False) self.dcs.delete_leader(last_lsn) self.dcs.reset_cluster() - def release_leader_key_voluntarily(self, last_lsn=None): + def release_leader_key_voluntarily(self, last_lsn: Optional[int] = None) -> None: self._delete_leader(last_lsn) self.touch_member() logger.info("Leader key released") - def demote(self, mode): + def demote(self, mode: str) -> Optional[bool]: """Demote PostgreSQL running as primary. :param mode: One of offline, graceful or immediate. @@ -1046,7 +1056,7 @@ class Ha(object): status = {'released': False} - def on_shutdown(checkpoint_location): + def on_shutdown(checkpoint_location: int) -> None: # Postmaster is still running, but pg_control already reports clean "shut down". # It could happen if Postgres is still archiving the backlog of WAL files. # If we know that there are replicas that received the shutdown checkpoint @@ -1057,13 +1067,13 @@ class Ha(object): self.release_leader_key_voluntarily(checkpoint_location) status['released'] = True - def before_shutdown(): + def before_shutdown() -> None: if self.state_handler.citus_handler.is_coordinator(): self.state_handler.citus_handler.on_demote() else: self.notify_citus_coordinator('before_demote') - self.state_handler.stop(mode_control['stop'], checkpoint=mode_control['checkpoint'], + self.state_handler.stop(str(mode_control['stop']), checkpoint=bool(mode_control['checkpoint']), on_safepoint=self.watchdog.disable if self.watchdog.is_running else None, on_shutdown=on_shutdown if mode_control['release'] else None, before_shutdown=before_shutdown if mode == 'graceful' else None, @@ -1099,7 +1109,8 @@ class Ha(object): return False # do not start postgres, but run pg_rewind on the next iteration self.state_handler.follow(node_to_follow) - def should_run_scheduled_action(self, action_name, scheduled_at, cleanup_fn): + def should_run_scheduled_action(self, action_name: str, scheduled_at: Optional[datetime.datetime], + cleanup_fn: Callable[..., Any]) -> bool: if scheduled_at and not self.is_paused(): # If the scheduled action is in the far future, we shouldn't do anything and just return. # If the scheduled action is in the past, we consider the value to be stale and we remove @@ -1133,7 +1144,7 @@ class Ha(object): cleanup_fn() return False - def process_manual_failover_from_leader(self): + def process_manual_failover_from_leader(self) -> Optional[str]: """Checks if manual failover is requested and takes action if appropriate. Cleans up failover key if failover conditions are not matched. @@ -1177,7 +1188,7 @@ class Ha(object): logger.info('Cleaning up failover key') self.dcs.manual_failover('', '', index=failover.index) - def process_unhealthy_cluster(self): + def process_unhealthy_cluster(self) -> str: """Cluster has no leader key""" if self.is_healthiest_node(): @@ -1219,7 +1230,7 @@ class Ha(object): return self.follow('demoting self because i am not the healthiest node', 'following a different leader because i am not the healthiest node') - def process_healthy_cluster(self): + def process_healthy_cluster(self) -> str: if self.has_lock(): if self.is_paused() and not self.state_handler.is_leader(): if self.cluster.failover and self.cluster.failover.candidate == self.state_handler.name: @@ -1271,7 +1282,7 @@ class Ha(object): 'no action. I am ({0}), a secondary, and following a leader ({1})'.format( self.state_handler.name, lock_owner), refresh=False) - def evaluate_scheduled_restart(self): + def evaluate_scheduled_restart(self) -> Optional[str]: if self._async_executor.busy: # Restart already in progress return None @@ -1297,7 +1308,7 @@ class Ha(object): finally: self.delete_future_restart() - def restart_matches(self, role, postgres_version, pending_restart): + def restart_matches(self, role: Optional[str], postgres_version: Optional[str], pending_restart: bool) -> bool: reason_to_cancel = "" # checking the restart filters here seem to be less ugly than moving them into the # run_scheduled_action. @@ -1316,7 +1327,7 @@ class Ha(object): logger.info("not proceeding with the restart: %s", reason_to_cancel) return False - def schedule_future_restart(self, restart_data): + def schedule_future_restart(self, restart_data: Dict[str, Any]) -> bool: with self._async_executor: restart_data['postmaster_start_time'] = self.state_handler.postmaster_start_time() if not self.patroni.scheduled_restart: @@ -1325,7 +1336,7 @@ class Ha(object): return True return False - def delete_future_restart(self): + def delete_future_restart(self) -> bool: ret = False with self._async_executor: if self.patroni.scheduled_restart: @@ -1334,14 +1345,13 @@ class Ha(object): ret = True return ret - def future_restart_scheduled(self): - return self.patroni.scheduled_restart.copy()\ - if (self.patroni.scheduled_restart and isinstance(self.patroni.scheduled_restart, dict)) else None + def future_restart_scheduled(self) -> Dict[str, Any]: + return self.patroni.scheduled_restart.copy() - def restart_scheduled(self): + def restart_scheduled(self) -> bool: return self._async_executor.scheduled_action == 'restart' - def restart(self, restart_data, run_async=False): + def restart(self, restart_data: Dict[str, Any], run_async: bool = False) -> Tuple[bool, str]: """ conditional and unconditional restart """ assert isinstance(restart_data, dict) @@ -1365,10 +1375,10 @@ class Ha(object): timeout = restart_data.get('timeout', self.global_config.primary_start_timeout) self.set_start_timeout(timeout) - def before_shutdown(): + def before_shutdown() -> None: self.notify_citus_coordinator('before_demote') - def after_start(): + def after_start() -> None: self.notify_citus_coordinator('after_promote') # For non async cases we want to wait for restart to complete or timeout before returning. @@ -1390,16 +1400,17 @@ class Ha(object): else: return (False, 'restart failed') - def _do_reinitialize(self, cluster): + def _do_reinitialize(self, cluster: Cluster) -> Optional[bool]: self.state_handler.stop('immediate', stop_timeout=self.patroni.config['retry_timeout']) # Commented redundant data directory cleanup here # self.state_handler.remove_data_directory() - clone_member = self.cluster.get_clone_member(self.state_handler.name) - member_role = 'leader' if clone_member == self.cluster.leader else 'replica' - return self.clone(clone_member, "from {0} '{1}'".format(member_role, clone_member.name)) + clone_member = cluster.get_clone_member(self.state_handler.name) + if clone_member: + member_role = 'leader' if clone_member == cluster.leader else 'replica' + return self.clone(clone_member, "from {0} '{1}'".format(member_role, clone_member.name)) - def reinitialize(self, force=False): + def reinitialize(self, force: bool = False) -> Optional[str]: with self._async_executor: self.load_cluster_from_dcs() @@ -1409,6 +1420,8 @@ class Ha(object): if self.has_lock(False): return 'I am the leader, can not reinitialize' + cluster = self.cluster + if force: self._async_executor.cancel() @@ -1417,9 +1430,9 @@ class Ha(object): if action is not None: return '{0} already in progress'.format(action) - self._async_executor.run_async(self._do_reinitialize, args=(self.cluster, )) + self._async_executor.run_async(self._do_reinitialize, args=(cluster, )) - def handle_long_action_in_progress(self): + def handle_long_action_in_progress(self) -> str: """ Figure out what to do with the task AsyncExecutor is performing. """ @@ -1432,7 +1445,7 @@ class Ha(object): self.demote('immediate') return 'terminated crash recovery because of startup timeout' - return 'updated leader lock during ' + self._async_executor.scheduled_action + return 'updated leader lock during {0}'.format(self._async_executor.scheduled_action) elif not self.state_handler.bootstrapping and not self.is_paused(): # Don't have lock, make sure we are not promoting or starting up a primary in the background if self._async_executor.scheduled_action == 'promote': @@ -1443,28 +1456,28 @@ class Ha(object): return 'lost leader before promote' if self.state_handler.role in ('master', 'primary'): - logger.info("Demoting primary during " + self._async_executor.scheduled_action) + logger.info('Demoting primary during %s', self._async_executor.scheduled_action) if self._async_executor.scheduled_action == 'restart': # Restart needs a special interlocking cancel because postmaster may be just started in a # background thread and has not even written a pid file yet. with self._async_executor.critical_task as task: - if not task.cancel(): + if not task.cancel() and isinstance(task.result, PostmasterProcess): self.state_handler.terminate_starting_postmaster(postmaster=task.result) self.demote('immediate-nolock') - return 'lost leader lock during ' + self._async_executor.scheduled_action + return 'lost leader lock during {0}'.format(self._async_executor.scheduled_action) if self.cluster.is_unlocked(): logger.info('not healthy enough for leader race') - return self._async_executor.scheduled_action + ' in progress' + return '{0} in progress'.format(self._async_executor.scheduled_action) @staticmethod - def sysid_valid(sysid): + def sysid_valid(sysid: Optional[str]) -> bool: # sysid does tv_sec << 32, where tv_sec is the number of seconds sine 1970, # so even 1 << 32 would have 10 digits. sysid = str(sysid) return len(sysid) >= 10 and sysid.isdigit() - def post_recover(self): + def post_recover(self) -> Optional[str]: if not self.state_handler.is_running(): self.watchdog.disable() if self.has_lock(): @@ -1475,14 +1488,14 @@ class Ha(object): return 'failed to start postgres' return None - def cancel_initialization(self): + def cancel_initialization(self) -> None: logger.info('removing initialize key after failed attempt to bootstrap the cluster') self.dcs.cancel_initialization() self.state_handler.stop('immediate', stop_timeout=self.patroni.config['retry_timeout']) self.state_handler.move_data_directory() raise PatroniFatalException('Failed to bootstrap cluster') - def post_bootstrap(self): + def post_bootstrap(self) -> str: with self._async_response: result = self._async_response.result # bootstrap has failed if postgres is not running @@ -1513,7 +1526,7 @@ class Ha(object): return 'initialized a new cluster' - def handle_starting_instance(self): + def handle_starting_instance(self) -> Optional[str]: """Starting up PostgreSQL may take a long time. In case we are the leader we may want to fail over to.""" @@ -1552,13 +1565,13 @@ class Ha(object): logger.info("Still starting up as a standby.") return None - def set_start_timeout(self, value): + def set_start_timeout(self, value: Optional[int]) -> None: """Sets timeout for starting as primary before eligible for failover. Must be called when async_executor is busy or in the main thread.""" self._start_timeout = value - def _run_cycle(self): + def _run_cycle(self) -> str: dcs_failed = False try: try: @@ -1619,6 +1632,8 @@ class Ha(object): return 'started as a secondary' # is data directory empty? + data_directory_error = '' + data_directory_is_empty = None try: data_directory_is_empty = self.state_handler.data_directory_empty() data_directory_is_accessible = True @@ -1727,7 +1742,7 @@ class Ha(object): self._failsafe.set_is_active(0) self.touch_member() - def _handle_dcs_error(self): + def _handle_dcs_error(self) -> str: if not self.is_paused() and self.state_handler.is_running(): if self.state_handler.is_leader(): if self.is_failsafe_mode() and self.check_failsafe_topology(): @@ -1746,13 +1761,13 @@ class Ha(object): self._sync_replication_slots(True) return 'DCS is not accessible' - def _sync_replication_slots(self, dcs_failed): + def _sync_replication_slots(self, dcs_failed: bool) -> List[str]: """Handles replication slots. :param dcs_failed: bool, indicates that communication with DCS failed (get_cluster() or update_leader()) :returns: list[str], replication slots names that should be copied from the primary""" - slots = [] + slots: List[str] = [] # If dcs_failed we don't want to touch replication slots on a leader or replicas if failsafe_mode isn't enabled. if not self.cluster or dcs_failed and (self.is_leader() or not self.is_failsafe_mode()): @@ -1772,7 +1787,7 @@ class Ha(object): # Don't copy replication slots if failsafe_mode is active return [] if self.failsafe_is_active() else slots - def run_cycle(self): + def run_cycle(self) -> str: with self._async_executor: try: info = self._run_cycle() @@ -1783,7 +1798,7 @@ class Ha(object): logger.exception('Unexpected exception') return 'Unexpected exception raised, please report it as a BUG' - def shutdown(self): + def shutdown(self) -> None: if self.is_paused(): logger.info('Leader key is not deleted and Postgresql is not stopped due paused state') self.watchdog.disable() @@ -1796,7 +1811,7 @@ class Ha(object): status = {'deleted': False} - def _on_shutdown(checkpoint_location): + def _on_shutdown(checkpoint_location: int) -> None: if self.is_leader(): # Postmaster is still running, but pg_control already reports clean "shut down". # It could happen if Postgres is still archiving the backlog of WAL files. @@ -1808,7 +1823,7 @@ class Ha(object): else: self.dcs.write_leader_optime(checkpoint_location) - def _before_shutdown(): + def _before_shutdown() -> None: self.notify_citus_coordinator('before_demote') on_shutdown = _on_shutdown if self.is_leader() else None @@ -1829,36 +1844,36 @@ class Ha(object): logger.error("PostgreSQL shutdown failed, leader key not removed.%s", (" Leaving watchdog running." if self.watchdog.is_running else "")) - def watch(self, timeout): + def watch(self, timeout: float) -> bool: # watch on leader key changes if the postgres is running and leader is known and current node is not lock owner if self._async_executor.busy or not self.cluster or self.cluster.is_unlocked() or self.has_lock(False): leader_index = None else: - leader_index = self.cluster.leader.index + leader_index = self.cluster.leader.index if self.cluster.leader else None return self.dcs.watch(leader_index, timeout) - def wakeup(self): + def wakeup(self) -> None: """Call of this method will trigger the next run of HA loop if there is no "active" leader watch request in progress. This usually happens on the leader or if the node is running async action""" self.dcs.event.set() - def get_remote_member(self, member=None): + def get_remote_member(self, member: Union[Leader, Member, None] = None) -> RemoteMember: """ In case of standby cluster this will tel us from which remote member to stream. Config can be both patroni config or cluster.config.data """ + data: Dict[str, Any] = {} cluster_params = self.global_config.get_standby_cluster_config() if cluster_params: - name = member.name if member else 'remote_member:{}'.format(uuid.uuid1()) - - data = {k: v for k, v in cluster_params.items() if k in RemoteMember.allowed_keys()} + data.update({k: v for k, v in cluster_params.items() if k in RemoteMember.allowed_keys()}) data['no_replication_slot'] = 'primary_slot_name' not in cluster_params conn_kwargs = member.conn_kwargs() if member else \ {k: cluster_params[k] for k in ('host', 'port') if k in cluster_params} if conn_kwargs: data['conn_kwargs'] = conn_kwargs - return RemoteMember(name, data) + name = member.name if member else 'remote_member:{}'.format(uuid.uuid1()) + return RemoteMember(name, data) diff --git a/patroni/log.py b/patroni/log.py index 2ccf053d..955594ff 100644 --- a/patroni/log.py +++ b/patroni/log.py @@ -13,7 +13,7 @@ from patroni.utils import deep_compare from queue import Queue, Full from threading import Lock, Thread -from typing import Any, Dict +from typing import Any, Dict, List, Optional, Union _LOGGER = logging.getLogger(__name__) @@ -68,7 +68,7 @@ class QueueHandler(logging.Handler): def __init__(self) -> None: """Queue initialised and initial records_lost established.""" super().__init__() - self.queue = Queue() + self.queue: Queue[Union[logging.LogRecord, None]] = Queue() self._records_lost = 0 def _put_record(self, record: logging.LogRecord) -> None: @@ -189,10 +189,10 @@ class PatroniLogger(Thread): super(PatroniLogger, self).__init__() self._queue_handler = QueueHandler() self._root_logger = logging.getLogger() - self._config = None + self._config: Optional[Dict[str, Any]] = None self.log_handler = None self.log_handler_lock = Lock() - self._old_handlers = [] + self._old_handlers: List[logging.Handler] = [] # initially set log level to ``DEBUG`` while the logger thread has not started running yet. The daemon process # will later adjust all log related settings with what was provided through the user configuration file. self.reload_config({'level': 'DEBUG'}) diff --git a/patroni/postgresql/__init__.py b/patroni/postgresql/__init__.py index 1ecf111a..f60d740a 100644 --- a/patroni/postgresql/__init__.py +++ b/patroni/postgresql/__init__.py @@ -12,7 +12,7 @@ from datetime import datetime from dateutil import tz from psutil import TimeoutExpired from threading import current_thread, Lock -from typing import Any, Dict, List, Optional, Union, TYPE_CHECKING +from typing import Any, Callable, Dict, Generator, List, Optional, Union, Tuple, TYPE_CHECKING from .bootstrap import Bootstrap from .callback_executor import CallbackAction, CallbackExecutor @@ -25,11 +25,14 @@ from .postmaster import PostmasterProcess from .slots import SlotsHandler from .sync import SyncHandler from .. import psycopg -from ..dcs import Cluster, Member +from ..async_executor import CriticalTask +from ..dcs import Cluster, Leader, Member from ..exceptions import PostgresConnectionException from ..utils import Retry, RetryFailedError, polling_loop, data_directory_is_empty, parse_int if TYPE_CHECKING: # pragma: no cover + from psycopg import Connection as Connection3, Cursor + from psycopg2 import connection as connection3, cursor from ..config import GlobalConfig logger = logging.getLogger(__name__) @@ -59,13 +62,15 @@ class Postgresql(object): "pg_catalog.pg_{0}_{1}_diff(COALESCE(pg_catalog.pg_last_{0}_receive_{1}(), '0/0'), '0/0')::bigint, " "pg_catalog.pg_is_in_recovery() AND pg_catalog.pg_is_{0}_replay_paused()") - def __init__(self, config): - self.name = config['name'] - self.scope = config['scope'] - self._data_dir = config['data_dir'] + def __init__(self, config: Dict[str, Any]) -> None: + self.name: str = config['name'] + self.scope: str = config['scope'] + self._data_dir: str = config['data_dir'] self._database = config.get('database', 'postgres') self._version_file = os.path.join(self._data_dir, 'PG_VERSION') self._pg_control = os.path.join(self._data_dir, 'global', 'pg_control') + self.connection_string: str + self.proxy_url: Optional[str] self._major_version = self.get_major_version() self._global_config = None @@ -92,7 +97,7 @@ class Postgresql(object): self.cancellable = CancellableSubprocess() - self._sysid = None + self._sysid = '' self.retry = Retry(max_tries=-1, deadline=config['retry_timeout'] / 2.0, max_delay=1, retry_exceptions=PostgresConnectionException) @@ -102,7 +107,7 @@ class Postgresql(object): self._role_lock = Lock() self.set_role(self.get_postgres_role_from_data_directory()) - self._state_entry_timestamp = None + self._state_entry_timestamp = 0 self._cluster_info_state = {} self._has_permanent_logical_slots = True @@ -126,35 +131,35 @@ class Postgresql(object): self.set_role('demoted') @property - def create_replica_methods(self): - return self.config.get('create_replica_methods', []) or self.config.get('create_replica_method', []) + def create_replica_methods(self) -> List[str]: + return self.config.get('create_replica_methods', []) or self.config.get('create_replica_method', []) or [] @property - def major_version(self): + def major_version(self) -> int: return self._major_version @property - def database(self): + def database(self) -> str: return self._database @property - def data_dir(self): + def data_dir(self) -> str: return self._data_dir @property - def callback(self): - return self.config.get('callbacks') or {} + def callback(self) -> Dict[str, str]: + return self.config.get('callbacks', {}) or {} @property - def wal_dir(self): + def wal_dir(self) -> str: return os.path.join(self._data_dir, 'pg_' + self.wal_name) @property - def wal_name(self): + def wal_name(self) -> str: return 'wal' if self._major_version >= 100000 else 'xlog' @property - def lsn_name(self): + def lsn_name(self) -> str: return 'lsn' if self._major_version >= 100000 else 'location' @property @@ -163,7 +168,7 @@ class Postgresql(object): return self._major_version >= 90600 @property - def cluster_info_query(self): + def cluster_info_query(self) -> str: """Returns the monitoring query with a fixed number of fields. The query text is constructed based on current state in DCS and PostgreSQL version: @@ -206,7 +211,7 @@ class Postgresql(object): return ("SELECT " + self.TL_LSN + ", {2}").format(self.wal_name, self.lsn_name, extra) - def _version_file_exists(self): + def _version_file_exists(self) -> bool: return not self.data_directory_empty() and os.path.isfile(self._version_file) def get_major_version(self) -> int: @@ -221,11 +226,11 @@ class Postgresql(object): logger.exception('Failed to read PG_VERSION from %s', self._data_dir) return 0 - def pgcommand(self, cmd): + def pgcommand(self, cmd: str) -> str: """Returns path to the specified PostgreSQL command""" return os.path.join(self._bin_dir, cmd) - def pg_ctl(self, cmd, *args, **kwargs): + def pg_ctl(self, cmd: str, *args: str, **kwargs: Any) -> bool: """Builds and executes pg_ctl command :returns: `!True` when return_code == 0, otherwise `!False`""" @@ -244,7 +249,7 @@ class Postgresql(object): initdb = [self.pgcommand('initdb')] + list(args) + [self.data_dir] return subprocess.call(initdb, **kwargs) == 0 - def pg_isready(self): + def pg_isready(self) -> str: """Runs pg_isready to see if PostgreSQL is accepting connections. :returns: 'ok' if PostgreSQL is up, 'reject' if starting up, 'no_resopnse' if not up.""" @@ -267,25 +272,25 @@ class Postgresql(object): 3: STATE_UNKNOWN} return return_codes.get(ret, STATE_UNKNOWN) - def reload_config(self, config, sighup=False): + def reload_config(self, config: Dict[str, Any], sighup: bool = False) -> None: self.config.reload_config(config, sighup) self._is_leader_retry.deadline = self.retry.deadline = config['retry_timeout'] / 2.0 @property - def pending_restart(self): + def pending_restart(self) -> bool: return self._pending_restart - def set_pending_restart(self, value): + def set_pending_restart(self, value: bool) -> None: self._pending_restart = value @property - def sysid(self): + def sysid(self) -> str: if not self._sysid and not self.bootstrapping: data = self.controldata() - self._sysid = data.get('Database system identifier', "") + self._sysid = data.get('Database system identifier', '') return self._sysid - def get_postgres_role_from_data_directory(self): + def get_postgres_role_from_data_directory(self) -> str: if self.data_directory_empty() or not self.controldata(): return 'uninitialized' elif self.config.recovery_conf_exists(): @@ -294,24 +299,24 @@ class Postgresql(object): return 'master' @property - def server_version(self): + def server_version(self) -> int: return self._connection.server_version - def connection(self): + def connection(self) -> Union['connection3', 'Connection3[Any]']: return self._connection.get() - def set_connection_kwargs(self, kwargs): + def set_connection_kwargs(self, kwargs: Dict[str, Any]) -> None: self._connection.set_conn_kwargs(kwargs.copy()) self.citus_handler.set_conn_kwargs(kwargs.copy()) - def _query(self, sql, *params): + def _query(self, sql: str, *params: Any) -> Union['Cursor[Any]', 'cursor']: """We are always using the same cursor, therefore this method is not thread-safe!!! You can call it from different threads only if you are holding explicit `AsyncExecutor` lock, because the main thread is always holding this lock when running HA cycle.""" cursor = None try: cursor = self._connection.cursor() - cursor.execute(sql, params or None) + cursor.execute(sql.encode('utf-8'), params or None) return cursor except psycopg.Error as e: if cursor and cursor.connection.closed == 0: @@ -327,7 +332,7 @@ class Postgresql(object): raise RetryFailedError('cluster is being restarted') raise PostgresConnectionException('connection problems') - def query(self, sql, *args, **kwargs): + def query(self, sql: str, *args: Any, **kwargs: Any) -> Union['Cursor[Any]', 'cursor']: if not kwargs.get('retry', True): return self._query(sql, *args) try: @@ -335,22 +340,22 @@ class Postgresql(object): except RetryFailedError as e: raise PostgresConnectionException(str(e)) - def pg_control_exists(self): + def pg_control_exists(self) -> bool: return os.path.isfile(self._pg_control) - def data_directory_empty(self): + def data_directory_empty(self) -> bool: if self.pg_control_exists(): return False return data_directory_is_empty(self._data_dir) - def replica_method_options(self, method): - return deepcopy(self.config.get(method, {})) + def replica_method_options(self, method: str) -> Dict[str, Any]: + return deepcopy(self.config.get(method, {}) or {}) - def replica_method_can_work_without_replication_connection(self, method): - return method != 'basebackup' and (self.replica_method_options(method).get('no_master') - or self.replica_method_options(method).get('no_leader')) + def replica_method_can_work_without_replication_connection(self, method: str) -> bool: + return method != 'basebackup' and bool(self.replica_method_options(method).get('no_master') + or self.replica_method_options(method).get('no_leader')) - def can_create_replica_without_replication_connection(self, replica_methods=None): + def can_create_replica_without_replication_connection(self, replica_methods: Optional[List[str]]) -> bool: """ go through the replication methods to see if there are ones that does not require a working replication connection. """ @@ -359,10 +364,10 @@ class Postgresql(object): return any(self.replica_method_can_work_without_replication_connection(m) for m in replica_methods) @property - def enforce_hot_standby_feedback(self): + def enforce_hot_standby_feedback(self) -> bool: return self._enforce_hot_standby_feedback - def set_enforce_hot_standby_feedback(self, value): + def set_enforce_hot_standby_feedback(self, value: bool) -> None: # If we enable or disable the hot_standby_feedback we need to update postgresql.conf and reload if self._enforce_hot_standby_feedback != value: self._enforce_hot_standby_feedback = value @@ -370,7 +375,7 @@ class Postgresql(object): self.config.write_postgresql_conf() self.reload() - def reset_cluster_info_state(self, cluster: Union[Cluster, None], nofailover: Optional[bool] = None, + def reset_cluster_info_state(self, cluster: Union[Cluster, None], nofailover: bool = False, global_config: Optional['GlobalConfig'] = None) -> None: """Reset monitoring query cache. @@ -395,7 +400,7 @@ class Postgresql(object): self._global_config = global_config - def _cluster_info_state_get(self, name): + def _cluster_info_state_get(self, name: str) -> Optional[Any]: if not self._cluster_info_state: try: result = self._is_leader_retry(self._query, self.cluster_info_query).fetchone() @@ -417,62 +422,63 @@ class Postgresql(object): return self._cluster_info_state.get(name) - def replayed_location(self): + def replayed_location(self) -> Optional[int]: return self._cluster_info_state_get('replayed_location') - def received_location(self): + def received_location(self) -> Optional[int]: return self._cluster_info_state_get('received_location') - def slots(self): - return self._cluster_info_state_get('slots') + def slots(self) -> Dict[str, int]: + return self._cluster_info_state_get('slots') or {} - def primary_slot_name(self): + def primary_slot_name(self) -> Optional[str]: return self._cluster_info_state_get('slot_name') - def primary_conninfo(self): + def primary_conninfo(self) -> Optional[str]: return self._cluster_info_state_get('conninfo') - def received_timeline(self): + def received_timeline(self) -> Optional[int]: return self._cluster_info_state_get('received_tli') def synchronous_commit(self) -> str: """:returns: "synchronous_commit" GUC value.""" - return self._cluster_info_state_get('synchronous_commit') + return self._cluster_info_state_get('synchronous_commit') or 'on' def synchronous_standby_names(self) -> str: """:returns: "synchronous_standby_names" GUC value.""" - return self._cluster_info_state_get('synchronous_standby_names') + return self._cluster_info_state_get('synchronous_standby_names') or '' def pg_stat_replication(self) -> List[Dict[str, Any]]: """:returns: a result set of 'SELECT * FROM pg_stat_replication'.""" return self._cluster_info_state_get('pg_stat_replication') or [] - def is_leader(self): + def is_leader(self) -> bool: try: return bool(self._cluster_info_state_get('timeline')) except PostgresConnectionException: logger.warning('Failed to determine PostgreSQL state from the connection, falling back to cached role') return bool(self.is_running() and self.role in ('master', 'primary')) - def replay_paused(self): - return self._cluster_info_state_get('replay_paused') + def replay_paused(self) -> bool: + return self._cluster_info_state_get('replay_paused') or False - def resume_wal_replay(self): + def resume_wal_replay(self) -> None: self._query('SELECT pg_catalog.pg_{0}_replay_resume()'.format(self.wal_name)) - def handle_parameter_change(self): + def handle_parameter_change(self) -> None: if self.major_version >= 140000 and not self.is_starting() and self.replay_paused(): logger.info('Resuming paused WAL replay for PostgreSQL 14+') self.resume_wal_replay() - def pg_control_timeline(self): + def pg_control_timeline(self) -> Optional[int]: try: - return int(self.controldata().get("Latest checkpoint's TimeLineID")) + return int(self.controldata().get("Latest checkpoint's TimeLineID", "")) except (TypeError, ValueError): logger.exception('Failed to parse timeline from pg_controldata output') - def parse_wal_record(self, timeline, lsn): + def parse_wal_record(self, timeline: str, + lsn: str) -> Union[Tuple[str, str, str, str], Tuple[None, None, None, None]]: out, err = self.waldump(timeline, lsn, 1) if out and not err: match = re.match(r'^rmgr:\s+(.+?)\s+len \(rec/tot\):\s+\d+/\s+\d+, tx:\s+\d+, ' @@ -482,7 +488,7 @@ class Postgresql(object): return match.groups() return None, None, None, None - def latest_checkpoint_location(self): + def latest_checkpoint_location(self) -> Optional[int]: """Returns checkpoint location for the cleanly shut down primary. But, if we know that the checkpoint was written to the new WAL due to the archive_mode=on, we will return the LSN of prev wal record (SWITCH).""" @@ -490,25 +496,25 @@ class Postgresql(object): data = self.controldata() timeline = data.get("Latest checkpoint's TimeLineID") lsn = checkpoint_lsn = data.get('Latest checkpoint location') - if data.get('Database cluster state') == 'shut down' and lsn and timeline: + if data.get('Database cluster state') == 'shut down' and lsn and timeline and checkpoint_lsn: try: checkpoint_lsn = parse_lsn(checkpoint_lsn) rm_name, lsn, prev, desc = self.parse_wal_record(timeline, lsn) - desc = desc.strip().lower() - if rm_name == 'XLOG' and parse_lsn(lsn) == checkpoint_lsn and prev and\ + desc = str(desc).strip().lower() + if rm_name == 'XLOG' and lsn and parse_lsn(lsn) == checkpoint_lsn and prev and\ desc.startswith('checkpoint') and desc.endswith('shutdown'): _, lsn, _, desc = self.parse_wal_record(timeline, prev) prev = parse_lsn(prev) # If the cluster is shutdown with archive_mode=on, WAL is switched before writing the checkpoint. # In this case we want to take the LSN of previous record (switch) as the last known WAL location. - if parse_lsn(lsn) == prev and desc.strip() in ('xlog switch', 'SWITCH'): + if lsn and parse_lsn(lsn) == prev and str(desc).strip() in ('xlog switch', 'SWITCH'): return prev except Exception as e: logger.error('Exception when parsing WAL pg_%sdump output: %r', self.wal_name, e) if isinstance(checkpoint_lsn, int): return checkpoint_lsn - def is_running(self): + def is_running(self) -> Optional[PostmasterProcess]: """Returns PostmasterProcess if one is running on the data directory or None. If most recently seen process is running updates the cached process based on pid file.""" if self._postmaster_proc: @@ -523,7 +529,7 @@ class Postgresql(object): return self._postmaster_proc @property - def cb_called(self): + def cb_called(self) -> bool: return self.__cb_called def call_nowait(self, cb_type: CallbackAction) -> None: @@ -544,31 +550,31 @@ class Postgresql(object): logger.exception('callback %s %r %s %s failed', cmd, cb_type, role, self.scope) @property - def role(self): + def role(self) -> str: with self._role_lock: return self._role - def set_role(self, value): + def set_role(self, value: str) -> None: with self._role_lock: self._role = value @property - def state(self): + def state(self) -> str: with self._state_lock: return self._state - def set_state(self, value): + def set_state(self, value: str) -> None: with self._state_lock: self._state = value self._state_entry_timestamp = time.time() - def time_in_state(self): + def time_in_state(self) -> float: return time.time() - self._state_entry_timestamp - def is_starting(self): + def is_starting(self) -> bool: return self.state == 'starting' - def wait_for_port_open(self, postmaster, timeout): + def wait_for_port_open(self, postmaster: PostmasterProcess, timeout: float) -> bool: """Waits until PostgreSQL opens ports.""" for _ in polling_loop(timeout): if self.cancellable.is_cancelled: @@ -588,7 +594,9 @@ class Postgresql(object): logger.warning("Timed out waiting for PostgreSQL to start") return False - def start(self, timeout=None, task=None, block_callbacks=False, role=None, after_start=None): + def start(self, timeout: Optional[float] = None, task: Optional[CriticalTask] = None, + block_callbacks: bool = False, role: Optional[str] = None, + after_start: Optional[Callable[..., Any]] = None) -> Optional[bool]: """Start PostgreSQL Waits for postmaster to open ports or terminate so pg_isready can be used to check startup completion @@ -650,7 +658,7 @@ class Postgresql(object): start_timeout = timeout if not start_timeout: try: - start_timeout = float(self.config.get('pg_ctl_timeout', 60)) + start_timeout = float(self.config.get('pg_ctl_timeout', 60) or 0) except ValueError: start_timeout = 60 @@ -668,7 +676,8 @@ class Postgresql(object): else: return None - def checkpoint(self, connect_kwargs=None, timeout=None): + def checkpoint(self, connect_kwargs: Optional[Dict[str, Any]] = None, + timeout: Optional[float] = None) -> Optional[str]: check_not_is_in_recovery = connect_kwargs is not None connect_kwargs = connect_kwargs or self.config.local_connect_kwargs for p in ['connect_timeout', 'options']: @@ -680,15 +689,17 @@ class Postgresql(object): cur.execute("SET statement_timeout = 0") if check_not_is_in_recovery: cur.execute('SELECT pg_catalog.pg_is_in_recovery()') - if cur.fetchone()[0]: + row = cur.fetchone() + if not row or row[0]: return 'is_in_recovery=true' cur.execute('CHECKPOINT') except psycopg.Error: logger.exception('Exception during CHECKPOINT') return 'not accessible or not healty' - def stop(self, mode='fast', block_callbacks=False, checkpoint=None, - on_safepoint=None, on_shutdown=None, before_shutdown=None, stop_timeout=None): + def stop(self, mode: str = 'fast', block_callbacks: bool = False, checkpoint: Optional[bool] = None, + on_safepoint: Optional[Callable[..., Any]] = None, on_shutdown: Optional[Callable[[int], Any]] = None, + before_shutdown: Optional[Callable[..., Any]] = None, stop_timeout: Optional[int] = None) -> bool: """Stop PostgreSQL Supports a callback when a safepoint is reached. A safepoint is when no user backend can return a successful @@ -716,7 +727,9 @@ class Postgresql(object): self.set_state('stop failed') return success - def _do_stop(self, mode, block_callbacks, checkpoint, on_safepoint, on_shutdown, before_shutdown, stop_timeout): + def _do_stop(self, mode: str, block_callbacks: bool, checkpoint: bool, + on_safepoint: Optional[Callable[..., Any]], on_shutdown: Optional[Callable[..., Any]], + before_shutdown: Optional[Callable[..., Any]], stop_timeout: Optional[int]) -> Tuple[bool, bool]: postmaster = self.is_running() if not postmaster: if on_safepoint: @@ -774,7 +787,8 @@ class Postgresql(object): return True, True - def terminate_postmaster(self, postmaster, mode, stop_timeout): + def terminate_postmaster(self, postmaster: PostmasterProcess, mode: str, + stop_timeout: Optional[int]) -> Optional[bool]: if mode in ['fast', 'smart']: try: success = postmaster.signal_stop('immediate', self.pgcommand('pg_ctl')) @@ -787,13 +801,13 @@ class Postgresql(object): logger.warning("Sending SIGKILL to Postmaster and its children") return postmaster.signal_kill() - def terminate_starting_postmaster(self, postmaster): + def terminate_starting_postmaster(self, postmaster: PostmasterProcess) -> None: """Terminates a postmaster that has not yet opened ports or possibly even written a pid file. Blocks until the process goes away.""" postmaster.signal_stop('immediate', self.pgcommand('pg_ctl')) postmaster.wait() - def _wait_for_connection_close(self, postmaster): + def _wait_for_connection_close(self, postmaster: PostmasterProcess) -> None: try: with self.connection().cursor() as cur: while postmaster.is_running(): # Need a timeout here? @@ -802,17 +816,17 @@ class Postgresql(object): except psycopg.Error: pass - def reload(self, block_callbacks=False): + def reload(self, block_callbacks: bool = False) -> bool: ret = self.pg_ctl('reload') if ret and not block_callbacks: self.call_nowait(CallbackAction.ON_RELOAD) return ret - def check_for_startup(self): + def check_for_startup(self) -> bool: """Checks PostgreSQL status and returns if PostgreSQL is in the middle of startup.""" return self.is_starting() and not self.check_startup_state_changed() - def check_startup_state_changed(self): + def check_startup_state_changed(self) -> bool: """Checks if PostgreSQL has completed starting up or failed or still starting. Should only be called when state == 'starting' @@ -847,7 +861,7 @@ class Postgresql(object): return True - def wait_for_startup(self, timeout=None): + def wait_for_startup(self, timeout: float = 0) -> Optional[bool]: """Waits for PostgreSQL startup to complete or fail. :returns: True if start was successful, False otherwise""" @@ -862,8 +876,10 @@ class Postgresql(object): return self.state == 'running' - def restart(self, timeout=None, task=None, block_callbacks=False, - role=None, before_shutdown=None, after_start=None): + def restart(self, timeout: Optional[float] = None, task: Optional[CriticalTask] = None, + block_callbacks: bool = False, role: Optional[str] = None, + before_shutdown: Optional[Callable[..., Any]] = None, + after_start: Optional[Callable[..., Any]] = None) -> Optional[bool]: """Restarts PostgreSQL. When timeout parameter is set the call will block either until PostgreSQL has started, failed to start or @@ -880,13 +896,13 @@ class Postgresql(object): self.set_state('restart failed ({0})'.format(self.state)) return ret - def is_healthy(self): + def is_healthy(self) -> bool: if not self.is_running(): logger.warning('Postgresql is not running.') return False return True - def get_guc_value(self, name): + def get_guc_value(self, name: str) -> Optional[str]: cmd = [self.pgcommand('postgres'), '-D', self._data_dir, '-C', name, '--config-file={}'.format(self.config.postgresql_conf)] try: @@ -896,7 +912,7 @@ class Postgresql(object): except Exception as e: logger.error('Failed to execute %s: %r', cmd, e) - def controldata(self): + def controldata(self) -> Dict[str, str]: """ return the contents of pg_controldata, or non-True value if pg_controldata call failed """ # Don't try to call pg_controldata during backup restore if self._version_file_exists() and self.state != 'creating replica': @@ -911,11 +927,11 @@ class Postgresql(object): logger.exception("Error when calling pg_controldata") return {} - def waldump(self, timeline, lsn, limit): + def waldump(self, timeline: Union[int, str], lsn: str, limit: int) -> Tuple[Optional[bytes], Optional[bytes]]: cmd = self.pgcommand('pg_{0}dump'.format(self.wal_name)) env = {**os.environ, 'LANG': 'C', 'LC_ALL': 'C', 'PGDATA': self._data_dir} try: - waldump = subprocess.Popen([cmd, '-t', str(timeline), '-s', str(lsn), '-n', str(limit)], + waldump = subprocess.Popen([cmd, '-t', str(timeline), '-s', lsn, '-n', str(limit)], stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env) out, err = waldump.communicate() waldump.wait() @@ -925,22 +941,24 @@ class Postgresql(object): return None, None @contextmanager - def get_replication_connection_cursor(self, host=None, port=5432, **kwargs): + def get_replication_connection_cursor(self, host: Optional[str] = None, port: int = 5432, + **kwargs: Any) -> Generator[Union['cursor', 'Cursor[Any]'], None, None]: conn_kwargs = self.config.replication.copy() conn_kwargs.update(host=host, port=int(port) if port else None, user=conn_kwargs.pop('username'), connect_timeout=3, replication=1, options='-c statement_timeout=2000') with get_connection_cursor(**conn_kwargs) as cur: yield cur - def get_replica_timeline(self): + def get_replica_timeline(self) -> Optional[int]: try: with self.get_replication_connection_cursor(**self.config.local_replication_address) as cur: cur.execute('IDENTIFY_SYSTEM') - return cur.fetchone()[1] + row = cur.fetchone() + return row[1] if row else None except Exception: logger.exception('Can not fetch local timeline and lsn from replication connection') - def replica_cached_timeline(self, primary_timeline): + def replica_cached_timeline(self, primary_timeline: Optional[int]) -> Optional[int]: if not self._cached_replica_timeline or not primary_timeline\ or self._cached_replica_timeline != primary_timeline: self._cached_replica_timeline = self.get_replica_timeline() @@ -948,29 +966,26 @@ class Postgresql(object): def get_primary_timeline(self) -> int: """:returns: current timeline if postgres is running as a primary or 0.""" - return self._cluster_info_state_get('timeline') + return self._cluster_info_state_get('timeline') or 0 - def get_history(self, timeline): + def get_history(self, timeline: int) -> List[Union[Tuple[int, int, str], Tuple[int, int, str, str, str]]]: history_path = os.path.join(self.wal_dir, '{0:08X}.history'.format(timeline)) history_mtime = mtime(history_path) + history: List[Union[Tuple[int, int, str], Tuple[int, int, str, str, str]]] = [] if history_mtime: try: with open(history_path, 'r') as f: - history = f.read() - history = list(parse_history(history)) + history_content = f.read() + history = list(parse_history(history_content)) if history[-1][0] == timeline - 1: history_mtime = datetime.fromtimestamp(history_mtime).replace(tzinfo=tz.tzlocal()) - history[-1].append(history_mtime.isoformat()) - history[-1].append(self.name) - return history + history[-1] = history[-1][:3] + (history_mtime.isoformat(), self.name) except Exception: logger.exception('Failed to read and parse %s', (history_path,)) + return history - def follow(self, - member: Member, - role: Optional[str] = 'replica', - timeout: Optional[float] = None, - do_reload: Optional[bool] = False) -> Optional[bool]: + def follow(self, member: Union[Leader, Member, None], role: str = 'replica', + timeout: Optional[float] = None, do_reload: bool = False) -> Optional[bool]: """Reconfigure postgres to follow a new member or use different recovery parameters. Method may call `on_role_change` callback if role is changing. @@ -1016,14 +1031,14 @@ class Postgresql(object): self.call_nowait(CallbackAction.ON_ROLE_CHANGE) return ret - def _wait_promote(self, wait_seconds): + def _wait_promote(self, wait_seconds: int) -> Optional[bool]: for _ in polling_loop(wait_seconds): data = self.controldata() if data.get('Database cluster state') == 'in production': self.set_role('master') return True - def _pre_promote(self): + def _pre_promote(self) -> bool: """ Runs a fencing script after the leader lock is acquired but before the replica is promoted. If the script exits with a non-zero code, promotion does not happen and the leader key is removed from DCS. @@ -1053,7 +1068,8 @@ class Postgresql(object): except Exception as e: logger.error('Exception when calling `%s`: %r', cmd, e) - def promote(self, wait_seconds, task, before_promote=None, on_success=None): + def promote(self, wait_seconds: int, task: CriticalTask, before_promote: Optional[Callable[..., Any]] = None, + on_success: Optional[Callable[..., Any]] = None) -> Optional[bool]: if self.role in ('promoted', 'master', 'primary'): return True @@ -1086,43 +1102,48 @@ class Postgresql(object): return ret @staticmethod - def _wal_position(is_leader, wal_position, received_location, replayed_location): + def _wal_position(is_leader: bool, wal_position: int, + received_location: Optional[int], replayed_location: Optional[int]) -> int: return wal_position if is_leader else max(received_location or 0, replayed_location or 0) - def timeline_wal_position(self): + def timeline_wal_position(self) -> Tuple[int, int, Optional[int]]: # This method could be called from different threads (simultaneously with some other `_query` calls). # If it is called not from main thread we will create a new cursor to execute statement. if current_thread().ident == self.__thread_ident: - timeline = self._cluster_info_state_get('timeline') - wal_position = self._cluster_info_state_get('wal_position') + timeline = self._cluster_info_state_get('timeline') or 0 + wal_position = self._cluster_info_state_get('wal_position') or 0 replayed_location = self.replayed_location() received_location = self.received_location() pg_control_timeline = self._cluster_info_state_get('pg_control_timeline') else: with self.connection().cursor() as cursor: - cursor.execute(self.cluster_info_query) - (timeline, wal_position, replayed_location, - received_location, _, pg_control_timeline) = cursor.fetchone()[:6] + cursor.execute(self.cluster_info_query.encode('utf-8')) + row = cursor.fetchone() + if TYPE_CHECKING: # pragma: no cover + assert row is not None + (timeline, wal_position, replayed_location, received_location, _, pg_control_timeline) = row[:6] - wal_position = self._wal_position(timeline, wal_position, received_location, replayed_location) + wal_position = self._wal_position(bool(timeline), wal_position, received_location, replayed_location) return (timeline, wal_position, pg_control_timeline) - def postmaster_start_time(self): + def postmaster_start_time(self) -> Optional[str]: try: query = "SELECT " + self.POSTMASTER_START_TIME if current_thread().ident == self.__thread_ident: - return self.query(query).fetchone()[0].isoformat(sep=' ') - with self.connection().cursor() as cursor: - cursor.execute(query) - return cursor.fetchone()[0].isoformat(sep=' ') + row = self.query(query).fetchone() + else: + with self.connection().cursor() as cursor: + cursor.execute(query) + row = cursor.fetchone() + return row[0].isoformat(sep=' ') if row else None except psycopg.Error: return None - def last_operation(self): - return self._wal_position(self.is_leader(), self._cluster_info_state_get('wal_position'), + def last_operation(self) -> int: + return self._wal_position(self.is_leader(), self._cluster_info_state_get('wal_position') or 0, self.received_location(), self.replayed_location()) - def configure_server_parameters(self): + def configure_server_parameters(self) -> None: self._major_version = self.get_major_version() self.config.setup_server_parameters() @@ -1135,9 +1156,9 @@ class Postgresql(object): self.configure_server_parameters() return self._major_version > 0 - def pg_wal_realpath(self): + def pg_wal_realpath(self) -> Dict[str, str]: """Returns a dict containing the symlink (key) and target (value) for the wal directory""" - links = {} + links: Dict[str, str] = {} for pg_wal_dir in ('pg_xlog', 'pg_wal'): pg_wal_path = os.path.join(self._data_dir, pg_wal_dir) if os.path.exists(pg_wal_path) and os.path.islink(pg_wal_path): @@ -1145,9 +1166,9 @@ class Postgresql(object): links[pg_wal_path] = pg_wal_realpath return links - def pg_tblspc_realpaths(self): + def pg_tblspc_realpaths(self) -> Dict[str, str]: """Returns a dict containing the symlink (key) and target (values) for the tablespaces""" - links = {} + links: Dict[str, str] = {} pg_tblsp_dir = os.path.join(self._data_dir, 'pg_tblspc') if os.path.exists(pg_tblsp_dir): for tsdn in os.listdir(pg_tblsp_dir): @@ -1157,7 +1178,7 @@ class Postgresql(object): links[pg_tsp_path] = pg_tsp_rpath return links - def move_data_directory(self): + def move_data_directory(self) -> None: if os.path.isdir(self._data_dir) and not self.is_running(): try: postfix = 'failed' @@ -1191,7 +1212,7 @@ class Postgresql(object): except OSError: logger.exception("Could not rename data directory %s", self._data_dir) - def remove_data_directory(self): + def remove_data_directory(self) -> None: self.set_role('uninitialized') logger.info('Removing data directory: %s', self._data_dir) try: @@ -1219,7 +1240,7 @@ class Postgresql(object): logger.exception('Could not remove data directory %s', self._data_dir) self.move_data_directory() - def schedule_sanity_checks_after_pause(self): + def schedule_sanity_checks_after_pause(self) -> None: """ After coming out of pause we have to: 1. configure server parameters if necessary @@ -1229,4 +1250,4 @@ class Postgresql(object): self.ensure_major_version_is_known() self.slots_handler.schedule() self.citus_handler.schedule_cache_rebuild() - self._sysid = None + self._sysid = '' diff --git a/patroni/postgresql/bootstrap.py b/patroni/postgresql/bootstrap.py index e21e1202..a76f9f5f 100644 --- a/patroni/postgresql/bootstrap.py +++ b/patroni/postgresql/bootstrap.py @@ -3,34 +3,39 @@ import os import shlex import tempfile import time -from typing import List, Dict, Union, Callable, Tuple -from ..dcs import RemoteMember +from typing import Any, Callable, Dict, List, Optional, Union, Tuple, TYPE_CHECKING + +from ..async_executor import CriticalTask +from ..dcs import Leader, Member, RemoteMember from ..psycopg import quote_ident, quote_literal from ..utils import deep_compare, unquote +if TYPE_CHECKING: # pragma: no cover + from . import Postgresql + logger = logging.getLogger(__name__) class Bootstrap(object): - def __init__(self, postgresql): + def __init__(self, postgresql: 'Postgresql') -> None: self._postgresql = postgresql self._running_custom_bootstrap = False @property - def running_custom_bootstrap(self): + def running_custom_bootstrap(self) -> bool: return self._running_custom_bootstrap @property - def keep_existing_recovery_conf(self): + def keep_existing_recovery_conf(self) -> bool: return self._running_custom_bootstrap and self._keep_existing_recovery_conf @staticmethod def process_user_options(tool: str, - options: Union[Dict[str, str], List[Union[str, Dict[str, str]]]], + options: Union[Any, Dict[str, str], List[Union[str, Dict[str, Any]]]], not_allowed_options: Tuple[str, ...], - error_handler: Callable[[str], None]) -> List: + error_handler: Callable[[str], None]) -> List[str]: """Format *options* in a list or dictionary format into command line long form arguments. .. note:: @@ -77,9 +82,9 @@ class Bootstrap(object): :param error_handler: A function which will be called when an error condition is encountered :returns: List of long form arguments to pass to the named tool """ - user_options = [] + user_options: List[str] = [] - def option_is_allowed(name): + def option_is_allowed(name: str) -> bool: ret = name not in not_allowed_options if not ret: error_handler('{0} option for {1} is not allowed'.format(name, tool)) @@ -106,11 +111,11 @@ class Bootstrap(object): error_handler('{0} options must be list or dict'.format(tool)) return user_options - def _initdb(self, config): + def _initdb(self, config: Any) -> bool: self._postgresql.set_state('initializing new cluster') not_allowed_options = ('pgdata', 'nosync', 'pwfile', 'sync-only', 'version') - def error_handler(e): + def error_handler(e: str) -> None: raise Exception(e) options = self.process_user_options('initdb', config or [], not_allowed_options, error_handler) @@ -134,7 +139,7 @@ class Bootstrap(object): self._postgresql.set_state('initdb failed') return ret - def _post_restore(self): + def _post_restore(self) -> None: self._postgresql.config.restore_configuration_files() self._postgresql.configure_server_parameters() @@ -145,7 +150,7 @@ class Bootstrap(object): if os.path.exists(trigger_file): os.unlink(trigger_file) - def _custom_bootstrap(self, config): + def _custom_bootstrap(self, config: Any) -> bool: self._postgresql.set_state('running custom bootstrap script') params = [] if config.get('no_params') else ['--scope=' + self._postgresql.scope, '--datadir=' + self._postgresql.data_dir] @@ -165,7 +170,7 @@ class Bootstrap(object): self._postgresql.config.remove_recovery_conf() return True - def call_post_bootstrap(self, config): + def call_post_bootstrap(self, config: Dict[str, Any]) -> bool: """ runs a script after initdb or custom bootstrap script is called and waits until completion. """ @@ -192,7 +197,7 @@ class Bootstrap(object): return False return True - def create_replica(self, clone_member): + def create_replica(self, clone_member: Union[Leader, Member, None]) -> Optional[int]: """ create the replica according to the replica_method defined by the user. this is a list, so we need to @@ -279,7 +284,7 @@ class Bootstrap(object): self._postgresql.set_state('stopped') return ret - def basebackup(self, conn_url, env, options): + def basebackup(self, conn_url: str, env: Dict[str, str], options: Dict[str, Any]) -> Optional[int]: # creates a replica data dir using pg_basebackup. # this is the default, built-in create_replica_methods # tries twice, then returns failure (as 1) @@ -314,7 +319,7 @@ class Bootstrap(object): return ret - def clone(self, clone_member): + def clone(self, clone_member: Union[Leader, Member, None]) -> bool: """ - initialize the replica from an existing member (primary or replica) - initialize the replica using the replica creation method that @@ -327,7 +332,7 @@ class Bootstrap(object): self._post_restore() return ret - def bootstrap(self, config): + def bootstrap(self, config: Dict[str, Any]) -> bool: """ Initialize a new node from scratch and start it. """ pg_hba = config.get('pg_hba', []) method = config.get('method') or 'initdb' @@ -339,9 +344,9 @@ class Bootstrap(object): method = 'initdb' do_initialize = self._initdb return do_initialize(config.get(method)) and self._postgresql.config.append_pg_hba(pg_hba) \ - and self._postgresql.config.save_configuration_files() and self._postgresql.start() + and self._postgresql.config.save_configuration_files() and bool(self._postgresql.start()) - def create_or_update_role(self, name, password, options): + def create_or_update_role(self, name: str, password: Optional[str], options: List[str]) -> None: options = list(map(str.upper, options)) if 'NOLOGIN' not in options and 'LOGIN' not in options: options.append('LOGIN') @@ -371,7 +376,7 @@ END;$$""".format(quote_literal(name), quote_ident(name, self._postgresql.connect self._postgresql.query('RESET log_statement') self._postgresql.query('RESET pg_stat_statements.track_utility') - def post_bootstrap(self, config, task): + def post_bootstrap(self, config: Dict[str, Any], task: CriticalTask) -> Optional[bool]: try: postgresql = self._postgresql superuser = postgresql.config.superuser diff --git a/patroni/postgresql/callback_executor.py b/patroni/postgresql/callback_executor.py index 99711b08..3ae073fd 100644 --- a/patroni/postgresql/callback_executor.py +++ b/patroni/postgresql/callback_executor.py @@ -17,7 +17,7 @@ class CallbackAction(str, Enum): ON_RELOAD = "on_reload" ON_ROLE_CHANGE = "on_role_change" - def __repr__(self): + def __repr__(self) -> str: return self.value @@ -60,15 +60,17 @@ class CallbackExecutor(CancellableExecutor, Thread): self._cmd = cmd self._condition.notify() - def run(self): + def run(self) -> None: while True: with self._condition: if self._cmd is None: self._condition.wait() cmd, self._cmd = self._cmd, None - with self._lock: - if not self._start_process(cmd, close_fds=True): - continue - self._process.wait() - self._kill_children() + if cmd is not None: + with self._lock: + if not self._start_process(cmd, close_fds=True): + continue + if self._process: + self._process.wait() + self._kill_children() diff --git a/patroni/postgresql/cancellable.py b/patroni/postgresql/cancellable.py index 64a9e6d7..83423808 100644 --- a/patroni/postgresql/cancellable.py +++ b/patroni/postgresql/cancellable.py @@ -5,6 +5,7 @@ import subprocess from patroni.exceptions import PostgresException from patroni.utils import polling_loop from threading import Lock +from typing import Any, Dict, List, Optional, Union logger = logging.getLogger(__name__) @@ -15,13 +16,13 @@ class CancellableExecutor(object): There must be only one such process so that AsyncExecutor can easily cancel it. """ - def __init__(self): + def __init__(self) -> None: self._process = None self._process_cmd = None - self._process_children = [] + self._process_children: List[psutil.Process] = [] self._lock = Lock() - def _start_process(self, cmd, *args, **kwargs): + def _start_process(self, cmd: List[str], *args: Any, **kwargs: Any) -> Optional[bool]: """This method must be executed only when the `_lock` is acquired""" try: @@ -32,7 +33,7 @@ class CancellableExecutor(object): return logger.exception('Failed to execute %s', cmd) return True - def _kill_process(self): + def _kill_process(self) -> None: with self._lock: if self._process is not None and self._process.is_running() and not self._process_children: try: @@ -53,8 +54,8 @@ class CancellableExecutor(object): except psutil.AccessDenied as e: logger.warning('Failed to kill the process: %s', e.msg) - def _kill_children(self): - waitlist = [] + def _kill_children(self) -> None: + waitlist: List[psutil.Process] = [] with self._lock: for child in self._process_children: try: @@ -69,15 +70,16 @@ class CancellableExecutor(object): class CancellableSubprocess(CancellableExecutor): - def __init__(self): + def __init__(self) -> None: super(CancellableSubprocess, self).__init__() self._is_cancelled = False - def call(self, *args, **kwargs): + def call(self, *args: Any, **kwargs: Union[Any, Dict[str, str]]) -> Optional[int]: for s in ('stdin', 'stdout', 'stderr'): kwargs.pop(s, None) - communicate = kwargs.pop('communicate', None) + communicate: Optional[Dict[str, str]] = kwargs.pop('communicate', None) + input_data = None if isinstance(communicate, dict): input_data = communicate.get('input') if input_data: @@ -96,7 +98,7 @@ class CancellableSubprocess(CancellableExecutor): self._is_cancelled = False started = self._start_process(*args, **kwargs) - if started: + if started and self._process is not None: if isinstance(communicate, dict): communicate['stdout'], communicate['stderr'] = self._process.communicate(input_data) return self._process.wait() @@ -105,16 +107,16 @@ class CancellableSubprocess(CancellableExecutor): self._process = None self._kill_children() - def reset_is_cancelled(self): + def reset_is_cancelled(self) -> None: with self._lock: self._is_cancelled = False @property - def is_cancelled(self): + def is_cancelled(self) -> bool: with self._lock: return self._is_cancelled - def cancel(self, kill=False): + def cancel(self, kill: bool = False) -> None: with self._lock: self._is_cancelled = True if self._process is None or not self._process.is_running(): diff --git a/patroni/postgresql/citus.py b/patroni/postgresql/citus.py index c4bfac15..740369ef 100644 --- a/patroni/postgresql/citus.py +++ b/patroni/postgresql/citus.py @@ -4,11 +4,17 @@ import time from threading import Condition, Event, Thread from urllib.parse import urlparse +from typing import Any, Dict, List, Optional, Union, Tuple, TYPE_CHECKING from .connection import Connection -from ..dcs import CITUS_COORDINATOR_GROUP_ID +from ..dcs import CITUS_COORDINATOR_GROUP_ID, Cluster from ..psycopg import connect, quote_ident +if TYPE_CHECKING: # pragma: no cover + from psycopg import Cursor + from psycopg2 import cursor + from . import Postgresql + CITUS_SLOT_NAME_RE = re.compile(r'^citus_shard_(move|split)_slot(_[1-9][0-9]*){2,3}$') logger = logging.getLogger(__name__) @@ -16,7 +22,8 @@ logger = logging.getLogger(__name__) class PgDistNode(object): """Represents a single row in the `pg_dist_node` table""" - def __init__(self, group, host, port, event, nodeid=None, timeout=None, cooldown=None): + def __init__(self, group: int, host: str, port: int, event: str, nodeid: Optional[int] = None, + timeout: Optional[float] = None, cooldown: Optional[float] = None) -> None: self.group = group # A weird way of pausing client connections by adding the `-demoted` suffix to the hostname self.host = host + ('-demoted' if event == 'before_demote' else '') @@ -29,7 +36,7 @@ class PgDistNode(object): # If transaction was started, we need to COMMIT/ROLLBACK before the deadline self.timeout = timeout self.cooldown = cooldown or 10000 # 10s by default - self.deadline = 0 + self.deadline: float = 0 # All changes in the pg_dist_node are serialized on the Patroni # side by performing them from a thread. The thread, that is @@ -38,74 +45,74 @@ class PgDistNode(object): # the worker, and once it is done notify the calling thread. self._event = Event() - def wait(self): + def wait(self) -> None: self._event.wait() - def wakeup(self): + def wakeup(self) -> None: self._event.set() - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return isinstance(other, PgDistNode) and self.event == other.event\ and self.host == other.host and self.port == other.port - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self == other - def __str__(self): + def __str__(self) -> str: return ('PgDistNode(nodeid={0},group={1},host={2},port={3},event={4})' .format(self.nodeid, self.group, self.host, self.port, self.event)) - def __repr__(self): + def __repr__(self) -> str: return str(self) class CitusHandler(Thread): - def __init__(self, postgresql, config): + def __init__(self, postgresql: 'Postgresql', config: Optional[Dict[str, Union[str, int]]]) -> None: super(CitusHandler, self).__init__() self.daemon = True self._postgresql = postgresql self._config = config self._connection = Connection() - self._pg_dist_node = {} # Cache of pg_dist_node: {groupid: PgDistNode()} - self._tasks = [] # Requests to change pg_dist_node, every task is a `PgDistNode` + self._pg_dist_node: Dict[int, PgDistNode] = {} # Cache of pg_dist_node: {groupid: PgDistNode()} + self._tasks: List[PgDistNode] = [] # Requests to change pg_dist_node, every task is a `PgDistNode` self._condition = Condition() # protects _pg_dist_node, _tasks, and _schedule_load_pg_dist_node - self._in_flight = None # Reference to the `PgDistNode` if there is a transaction in progress changing it + self._in_flight: Optional[PgDistNode] = None # Reference to the `PgDistNode` being changed in a transaction self.schedule_cache_rebuild() - def is_enabled(self): + def is_enabled(self) -> bool: return isinstance(self._config, dict) - def group(self): - return self._config['group'] + def group(self) -> Optional[int]: + return int(self._config['group']) if isinstance(self._config, dict) else None - def is_coordinator(self): + def is_coordinator(self) -> bool: return self.is_enabled() and self.group() == CITUS_COORDINATOR_GROUP_ID - def is_worker(self): + def is_worker(self) -> bool: return self.is_enabled() and not self.is_coordinator() - def set_conn_kwargs(self, kwargs): - if self.is_enabled(): + def set_conn_kwargs(self, kwargs: Dict[str, Any]) -> None: + if isinstance(self._config, dict): # self.is_enabled(): kwargs.update({'dbname': self._config['database'], 'options': '-c statement_timeout=0 -c idle_in_transaction_session_timeout=0'}) self._connection.set_conn_kwargs(kwargs) - def schedule_cache_rebuild(self): + def schedule_cache_rebuild(self) -> None: with self._condition: self._schedule_load_pg_dist_node = True - def on_demote(self): + def on_demote(self) -> None: with self._condition: self._pg_dist_node.clear() self._tasks[:] = [] self._in_flight = None - def query(self, sql, *params): + def query(self, sql: str, *params: Any) -> Union['Cursor[Any]', 'cursor']: try: logger.debug('query(%s, %s)', sql, params) cursor = self._connection.cursor() - cursor.execute(sql, params or None) + cursor.execute(sql.encode('utf-8'), params or None) return cursor except Exception as e: logger.error('Exception when executing query "%s", (%s): %r', sql, params, e) @@ -114,7 +121,7 @@ class CitusHandler(Thread): self.schedule_cache_rebuild() raise e - def load_pg_dist_node(self): + def load_pg_dist_node(self) -> bool: """Read from the `pg_dist_node` table and put it into the local cache""" with self._condition: @@ -132,7 +139,7 @@ class CitusHandler(Thread): self._pg_dist_node = {r[1]: PgDistNode(r[1], r[2], r[3], 'after_promote', r[0]) for r in cursor} return True - def sync_pg_dist_node(self, cluster): + def sync_pg_dist_node(self, cluster: Cluster) -> None: """Maintain the `pg_dist_node` from the coordinator leader every heartbeat loop. We can't always rely on REST API calls from worker nodes in order @@ -156,12 +163,12 @@ class CitusHandler(Thread): and leader.data.get('role') in ('master', 'primary') and leader.data.get('state') == 'running': self.add_task('after_promote', group, leader.conn_url) - def find_task_by_group(self, group): + def find_task_by_group(self, group: int) -> Optional[int]: for i, task in enumerate(self._tasks): if task.group == group: return i - def pick_task(self): + def pick_task(self) -> Tuple[Optional[int], Optional[PgDistNode]]: """Returns the tuple(i, task), where `i` - is the task index in the self._tasks list Tasks are picked by following priorities: @@ -195,15 +202,17 @@ class CitusHandler(Thread): task.nodeid = self._pg_dist_node[task.group].nodeid return i, task - def update_node(self, task): + def update_node(self, task: PgDistNode) -> None: if task.nodeid is not None: self.query('SELECT pg_catalog.citus_update_node(%s, %s, %s, true, %s)', task.nodeid, task.host, task.port, task.cooldown) elif task.event != 'before_demote': - task.nodeid = self.query("SELECT pg_catalog.citus_add_node(%s, %s, %s, 'primary', 'default')", - task.host, task.port, task.group).fetchone()[0] + row = self.query("SELECT pg_catalog.citus_add_node(%s, %s, %s, 'primary', 'default')", + task.host, task.port, task.group).fetchone() + if row is not None: + task.nodeid = row[0] - def process_task(self, task): + def process_task(self, task: PgDistNode) -> bool: """Updates a single row in `pg_dist_node` table, optionally in a transaction. The transaction is started if we do a demote of the worker node @@ -238,13 +247,13 @@ class CitusHandler(Thread): self._in_flight = task return False - def process_tasks(self): + def process_tasks(self) -> None: while True: if not self._in_flight and not self.load_pg_dist_node(): break i, task = self.pick_task() - if not task: + if not task or i is None: break try: update_cache = self.process_task(task) @@ -259,7 +268,7 @@ class CitusHandler(Thread): self._tasks.pop(i) task.wakeup() - def run(self): + def run(self) -> None: while True: try: with self._condition: @@ -280,7 +289,7 @@ class CitusHandler(Thread): except Exception: logger.exception('run') - def _add_task(self, task): + def _add_task(self, task: PgDistNode) -> bool: with self._condition: i = self.find_task_by_group(task.group) @@ -306,32 +315,34 @@ class CitusHandler(Thread): return True return False - def add_task(self, event, group, conn_url, timeout=None, cooldown=None): + def add_task(self, event: str, group: int, conn_url: str, + timeout: Optional[float] = None, cooldown: Optional[float] = None) -> Optional[PgDistNode]: try: r = urlparse(conn_url) except Exception as e: return logger.error('Failed to parse connection url %s: %r', conn_url, e) host = r.hostname - port = r.port or 5432 - task = PgDistNode(group, host, port, event, timeout=timeout, cooldown=cooldown) - return task if self._add_task(task) else None + if host: + port = r.port or 5432 + task = PgDistNode(group, host, port, event, timeout=timeout, cooldown=cooldown) + return task if self._add_task(task) else None - def handle_event(self, cluster, event): + def handle_event(self, cluster: Cluster, event: Dict[str, Any]) -> None: if not self.is_alive(): return - cluster = cluster.workers.get(event['group']) - if not (cluster and cluster.leader and cluster.leader.name == event['leader'] and cluster.leader.conn_url): + worker = cluster.workers.get(event['group']) + if not (worker and worker.leader and worker.leader.name == event['leader'] and worker.leader.conn_url): return task = self.add_task(event['type'], event['group'], - cluster.leader.conn_url, + worker.leader.conn_url, event['timeout'], event['cooldown'] * 1000) if task and event['type'] == 'before_demote': task.wait() - def bootstrap(self): - if not self.is_enabled(): + def bootstrap(self) -> None: + if not isinstance(self._config, dict): # self.is_enabled() return conn_kwargs = self._postgresql.config.local_connect_kwargs @@ -340,7 +351,8 @@ class CitusHandler(Thread): conn = connect(**conn_kwargs) try: with conn.cursor() as cur: - cur.execute('CREATE DATABASE {0}'.format(quote_ident(self._config['database'], conn))) + cur.execute('CREATE DATABASE {0}'.format( + quote_ident(self._config['database'], conn)).encode('utf-8')) finally: conn.close() @@ -364,7 +376,7 @@ class CitusHandler(Thread): finally: conn.close() - def adjust_postgres_gucs(self, parameters): + def adjust_postgres_gucs(self, parameters: Dict[str, Any]) -> None: if not self.is_enabled(): return @@ -381,9 +393,9 @@ class CitusHandler(Thread): # Resharding in Citus implemented using logical replication parameters['wal_level'] = 'logical' - def ignore_replication_slot(self, slot): - if self.is_enabled() and self._postgresql.is_leader() and\ + def ignore_replication_slot(self, slot: Dict[str, str]) -> bool: + if isinstance(self._config, dict) and self._postgresql.is_leader() and\ slot['type'] == 'logical' and slot['database'] == self._config['database']: m = CITUS_SLOT_NAME_RE.match(slot['name']) - return m and {'move': 'pgoutput', 'split': 'citus'}.get(m.group(1)) == slot['plugin'] + return bool(m and {'move': 'pgoutput', 'split': 'citus'}.get(m.group(1)) == slot['plugin']) return False diff --git a/patroni/postgresql/config.py b/patroni/postgresql/config.py index 3f3d8060..4d6bacfb 100644 --- a/patroni/postgresql/config.py +++ b/patroni/postgresql/config.py @@ -7,21 +7,26 @@ import stat import time from urllib.parse import urlparse, parse_qsl, unquote +from types import TracebackType +from typing import Any, Collection, Dict, List, Optional, Union, Tuple, Type, TYPE_CHECKING from .validator import recovery_parameters, transform_postgresql_parameter_value, transform_recovery_parameter_value -from ..collections import CaseInsensitiveDict -from ..dcs import RemoteMember, slot_name_from_member_name +from ..collections import CaseInsensitiveDict, CaseInsensitiveSet +from ..dcs import Leader, Member, RemoteMember, slot_name_from_member_name from ..exceptions import PatroniFatalException -from ..utils import compare_values, parse_bool, parse_int, split_host_port, uri, \ - validate_directory, is_subpath +from ..utils import compare_values, parse_bool, parse_int, split_host_port, uri, validate_directory, is_subpath +from ..validator import IntValidator + +if TYPE_CHECKING: # pragma: no cover + from . import Postgresql logger = logging.getLogger(__name__) PARAMETER_RE = re.compile(r'([a-z_]+)\s*=\s*') -def conninfo_uri_parse(dsn): - ret = {} +def conninfo_uri_parse(dsn: str) -> Dict[str, str]: + ret: Dict[str, str] = {} r = urlparse(dsn) if r.username: ret['user'] = r.username @@ -29,21 +34,18 @@ def conninfo_uri_parse(dsn): ret['password'] = r.password if r.path[1:]: ret['dbname'] = r.path[1:] - hosts = [] - ports = [] + hosts: List[str] = [] + ports: List[str] = [] for netloc in r.netloc.split('@')[-1].split(','): - host = port = None + host = None if '[' in netloc and ']' in netloc: host = netloc.split(']')[0][1:] tmp = netloc.split(':', 1) if host is None: host = tmp[0] + hosts.append(host) if len(tmp) == 2: - host, port = tmp - if host is not None: - hosts.append(host) - if port is not None: - ports.append(port) + ports.append(tmp[1]) if hosts: ret['host'] = ','.join(hosts) if ports: @@ -56,7 +58,7 @@ def conninfo_uri_parse(dsn): return ret -def read_param_value(value): +def read_param_value(value: str) -> Union[Tuple[None, None], Tuple[str, int]]: length = len(value) ret = '' is_quoted = value[0] == "'" @@ -76,8 +78,8 @@ def read_param_value(value): return (None, None) if is_quoted else (ret, i) -def conninfo_parse(dsn): - ret = {} +def conninfo_parse(dsn: str) -> Optional[Dict[str, str]]: + ret: Dict[str, str] = {} length = len(dsn) i = 0 while i < length: @@ -96,14 +98,14 @@ def conninfo_parse(dsn): return value, end = read_param_value(dsn[i:]) - if value is None: + if value is None or end is None: return i += end ret[param] = value return ret -def parse_dsn(value): +def parse_dsn(value: str) -> Optional[Dict[str, str]]: """ Very simple equivalent of `psycopg2.extensions.parse_dsn` introduced in 2.7.0. We are not using psycopg2 function in order to remain compatible with 2.5.4+. @@ -147,14 +149,14 @@ def parse_dsn(value): return ret -def strip_comment(value): +def strip_comment(value: str) -> str: i = value.find('#') if i > -1: value = value[:i].strip() return value -def read_recovery_param_value(value): +def read_recovery_param_value(value: str) -> Optional[str]: """ >>> read_recovery_param_value('') is None True @@ -211,7 +213,7 @@ def read_recovery_param_value(value): return value -def mtime(filename): +def mtime(filename: str) -> Optional[float]: try: return os.stat(filename).st_mtime except OSError: @@ -220,35 +222,49 @@ def mtime(filename): class ConfigWriter(object): - def __init__(self, filename): + def __init__(self, filename: str) -> None: self._filename = filename self._fd = None - def __enter__(self): + def __enter__(self) -> 'ConfigWriter': self._fd = open(self._filename, 'w') self.writeline('# Do not edit this file manually!\n# It will be overwritten by Patroni!') return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]) -> None: if self._fd: self._fd.close() - def writeline(self, line): - self._fd.write(line) - self._fd.write('\n') + def writeline(self, line: str) -> None: + if self._fd: + self._fd.write(line) + self._fd.write('\n') - def writelines(self, lines): + def writelines(self, lines: List[str]) -> None: for line in lines: self.writeline(line) @staticmethod - def escape(value): # Escape (by doubling) any single quotes or backslashes in given string + def escape(value: Any) -> str: # Escape (by doubling) any single quotes or backslashes in given string return re.sub(r'([\'\\])', r'\1\1', str(value)) - def write_param(self, param, value): + def write_param(self, param: str, value: Any) -> None: self.writeline("{0} = '{1}'".format(param, self.escape(value))) +def _false_validator(value: Any) -> bool: + return False + + +def _wal_level_validator(value: Any) -> bool: + return str(value).lower() in ('hot_standby', 'replica', 'logical') + + +def _bool_validator(value: Any) -> bool: + return parse_bool(value) is not None + + class ConfigHandler(object): # List of parameters which must be always passed to postmaster as command line options @@ -266,28 +282,28 @@ class ConfigHandler(object): # check_function -- if the new value is not correct must return `!False` # min_version -- major version of PostgreSQL when parameter was introduced CMDLINE_OPTIONS = CaseInsensitiveDict({ - 'listen_addresses': (None, lambda _: False, 90100), - 'port': (None, lambda _: False, 90100), - 'cluster_name': (None, lambda _: False, 90500), - 'wal_level': ('hot_standby', lambda v: v.lower() in ('hot_standby', 'replica', 'logical'), 90100), - 'hot_standby': ('on', lambda _: False, 90100), - 'max_connections': (100, lambda v: int(v) >= 25, 90100), - 'max_wal_senders': (10, lambda v: int(v) >= 3, 90100), - 'wal_keep_segments': (8, lambda v: int(v) >= 1, 90100), - 'wal_keep_size': ('128MB', lambda v: parse_int(v, 'MB') >= 16, 130000), - 'max_prepared_transactions': (0, lambda v: int(v) >= 0, 90100), - 'max_locks_per_transaction': (64, lambda v: int(v) >= 32, 90100), - 'track_commit_timestamp': ('off', lambda v: parse_bool(v) is not None, 90500), - 'max_replication_slots': (10, lambda v: int(v) >= 4, 90400), - 'max_worker_processes': (8, lambda v: int(v) >= 2, 90400), - 'wal_log_hints': ('on', lambda _: False, 90400) + 'listen_addresses': (None, _false_validator, 90100), + 'port': (None, _false_validator, 90100), + 'cluster_name': (None, _false_validator, 90500), + 'wal_level': ('hot_standby', _wal_level_validator, 90100), + 'hot_standby': ('on', _false_validator, 90100), + 'max_connections': (100, IntValidator(min=25), 90100), + 'max_wal_senders': (10, IntValidator(min=3), 90100), + 'wal_keep_segments': (8, IntValidator(min=1), 90100), + 'wal_keep_size': ('128MB', IntValidator(min=16, base_unit='MB'), 130000), + 'max_prepared_transactions': (0, IntValidator(min=0), 90100), + 'max_locks_per_transaction': (64, IntValidator(min=32), 90100), + 'track_commit_timestamp': ('off', _bool_validator, 90500), + 'max_replication_slots': (10, IntValidator(min=4), 90400), + 'max_worker_processes': (8, IntValidator(min=2), 90400), + 'wal_log_hints': ('on', _false_validator, 90400) }) - _RECOVERY_PARAMETERS = set(recovery_parameters.keys()) + _RECOVERY_PARAMETERS = CaseInsensitiveSet(recovery_parameters.keys()) - def __init__(self, postgresql, config): + def __init__(self, postgresql: 'Postgresql', config: Dict[str, Any]) -> None: self._postgresql = postgresql - self._config_dir = os.path.abspath(config.get('config_dir') or postgresql.data_dir) + self._config_dir = os.path.abspath(config.get('config_dir', '') or postgresql.data_dir) config_base_name = config.get('config_base_name', 'postgresql') self._postgresql_conf = os.path.join(self._config_dir, config_base_name + '.conf') self._postgresql_conf_mtime = None @@ -309,21 +325,22 @@ class ConfigHandler(object): self._passfile_mtime = None self._synchronous_standby_names = None self._postmaster_ctime = None - self._current_recovery_params = None + self._current_recovery_params: Optional[CaseInsensitiveDict] = None self._config = {} - self._recovery_params = {} + self._recovery_params = CaseInsensitiveDict() + self._server_parameters: CaseInsensitiveDict self.reload_config(config) - def setup_server_parameters(self): + def setup_server_parameters(self) -> None: self._server_parameters = self.get_server_parameters(self._config) self._adjust_recovery_parameters() - def try_to_create_dir(self, d, msg): - d = os.path.join(self._postgresql._data_dir, d) - if (not is_subpath(self._postgresql._data_dir, d) or not self._postgresql.data_directory_empty()): + def try_to_create_dir(self, d: str, msg: str) -> None: + d = os.path.join(self._postgresql.data_dir, d) + if (not is_subpath(self._postgresql.data_dir, d) or not self._postgresql.data_directory_empty()): validate_directory(d, msg) - def check_directories(self): + def check_directories(self) -> None: if "unix_socket_directories" in self._server_parameters: for d in self._server_parameters["unix_socket_directories"].split(","): self.try_to_create_dir(d.strip(), "'{}' is defined in unix_socket_directories, {}") @@ -335,7 +352,11 @@ class ConfigHandler(object): "'{}' is defined in `postgresql.pgpass`, {}") @property - def _configuration_to_save(self): + def config_dir(self) -> str: + return self._config_dir + + @property + def _configuration_to_save(self) -> List[str]: configuration = [os.path.basename(self._postgresql_conf)] if 'custom_conf' not in self._config: configuration.append(os.path.basename(self._postgresql_base_conf_name)) @@ -345,7 +366,7 @@ class ConfigHandler(object): configuration.append('pg_ident.conf') return configuration - def save_configuration_files(self, check_custom_bootstrap=False): + def save_configuration_files(self, check_custom_bootstrap: bool = False) -> bool: """ copy postgresql.conf to postgresql.conf.backup to be able to retrieve configuration files - originally stored as symlinks, those are normally skipped by pg_basebackup @@ -362,7 +383,7 @@ class ConfigHandler(object): logger.exception('unable to create backup copies of configuration files') return True - def restore_configuration_files(self): + def restore_configuration_files(self) -> None: """ restore a previously saved postgresql.conf """ try: for f in self._configuration_to_save: @@ -377,7 +398,7 @@ class ConfigHandler(object): except IOError: logger.exception('unable to restore configuration files from backup') - def write_postgresql_conf(self, configuration=None): + def write_postgresql_conf(self, configuration: Optional[CaseInsensitiveDict] = None) -> None: # rename the original configuration if it is necessary if 'custom_conf' not in self._config and not os.path.exists(self._postgresql_base_conf): os.rename(self._postgresql_conf, self._postgresql_base_conf) @@ -412,13 +433,13 @@ class ConfigHandler(object): if not self._postgresql.bootstrap.keep_existing_recovery_conf: self._sanitize_auto_conf() - def append_pg_hba(self, config): + def append_pg_hba(self, config: List[str]) -> bool: if not self.hba_file and not self._config.get('pg_hba'): with open(self._pg_hba_conf, 'a') as f: f.write('\n{}\n'.format('\n'.join(config))) return True - def replace_pg_hba(self): + def replace_pg_hba(self) -> Optional[bool]: """ Replace pg_hba.conf content in the PGDATA if hba_file is not defined in the `postgresql.parameters` and pg_hba is defined in `postgresql` configuration section. @@ -446,7 +467,7 @@ class ConfigHandler(object): f.writelines(self._config['pg_hba']) return True - def replace_pg_ident(self): + def replace_pg_ident(self) -> Optional[bool]: """ Replace pg_ident.conf content in the PGDATA if ident_file is not defined in the `postgresql.parameters` and pg_ident is defined in the `postgresql` section. @@ -459,8 +480,8 @@ class ConfigHandler(object): f.writelines(self._config['pg_ident']) return True - def primary_conninfo_params(self, member): - if not (member and member.conn_url) or member.name == self._postgresql.name: + def primary_conninfo_params(self, member: Union[Leader, Member, None]) -> Optional[Dict[str, Any]]: + if not member or not member.conn_url or member.name == self._postgresql.name: return None ret = member.conn_kwargs(self.replication) ret['application_name'] = self._postgresql.name @@ -475,7 +496,7 @@ class ConfigHandler(object): del ret['dbname'] return ret - def format_dsn(self, params, include_dbname=False): + def format_dsn(self, params: Dict[str, Any], include_dbname: bool = False) -> str: # A list of keywords that can be found in a conninfo string. Follows what is acceptable by libpq keywords = ('dbname', 'user', 'passfile' if params.get('passfile') else 'password', 'host', 'port', 'sslmode', 'sslcompression', 'sslcert', 'sslkey', 'sslpassword', 'sslrootcert', 'sslcrl', @@ -491,13 +512,13 @@ class ConfigHandler(object): else: skip = {'dbname'} - def escape(value): + def escape(value: Any) -> str: return re.sub(r'([\'\\ ])', r'\\\1', str(value)) return ' '.join('{0}={1}'.format(kw, escape(params[kw])) for kw in keywords if kw not in skip and params.get(kw) is not None) - def _write_recovery_params(self, fd, recovery_params): + def _write_recovery_params(self, fd: ConfigWriter, recovery_params: CaseInsensitiveDict) -> None: if self._postgresql.major_version >= 90500: pause_at_recovery_target = parse_bool(recovery_params.pop('pause_at_recovery_target', None)) if pause_at_recovery_target is not None: @@ -518,8 +539,8 @@ class ConfigHandler(object): continue fd.write_param(name, value) - def build_recovery_params(self, member): - recovery_params = CaseInsensitiveDict({p: v for p, v in self.get('recovery_conf', {}).items() + def build_recovery_params(self, member: Union[Leader, Member, None]) -> CaseInsensitiveDict: + recovery_params = CaseInsensitiveDict({p: v for p, v in (self.get('recovery_conf') or {}).items() if not p.lower().startswith('recovery_target') and p.lower() not in ('primary_conninfo', 'primary_slot_name')}) recovery_params.update({'standby_mode': 'on', 'recovery_target_timeline': 'latest'}) @@ -550,26 +571,26 @@ class ConfigHandler(object): recovery_params.update({p: member.data.get(p) for p in standby_cluster_params if member and member.data.get(p)}) return recovery_params - def recovery_conf_exists(self): + def recovery_conf_exists(self) -> bool: if self._postgresql.major_version >= 120000: return os.path.exists(self._standby_signal) or os.path.exists(self._recovery_signal) return os.path.exists(self._recovery_conf) @property - def triggerfile_good_name(self): + def triggerfile_good_name(self) -> str: return 'trigger_file' if self._postgresql.major_version < 120000 else 'promote_trigger_file' @property - def _triggerfile_wrong_name(self): + def _triggerfile_wrong_name(self) -> str: return 'trigger_file' if self._postgresql.major_version >= 120000 else 'promote_trigger_file' @property - def _recovery_parameters_to_compare(self): - skip_params = {'pause_at_recovery_target', 'recovery_target_inclusive', - 'recovery_target_action', 'standby_mode', self._triggerfile_wrong_name} - return self._RECOVERY_PARAMETERS - skip_params + def _recovery_parameters_to_compare(self) -> CaseInsensitiveSet: + skip_params = CaseInsensitiveSet({'pause_at_recovery_target', 'recovery_target_inclusive', + 'recovery_target_action', 'standby_mode', self._triggerfile_wrong_name}) + return CaseInsensitiveSet(self._RECOVERY_PARAMETERS - skip_params) - def _read_recovery_params(self): + def _read_recovery_params(self) -> Tuple[Optional[CaseInsensitiveDict], Optional[bool]]: if self._postgresql.is_starting(): return None, False @@ -586,7 +607,7 @@ class ConfigHandler(object): try: values = self._get_pg_settings(self._recovery_parameters_to_compare).values() - values = {p[0]: [p[1], p[4] == 'postmaster', p[5]] for p in values} + values = CaseInsensitiveDict({p[0]: [p[1], p[4] == 'postmaster', p[5]] for p in values}) self._postgresql_conf_mtime = pg_conf_mtime self._auto_conf_mtime = auto_conf_mtime self._postmaster_ctime = postmaster_ctime @@ -594,13 +615,13 @@ class ConfigHandler(object): values = None return values, True - def _read_recovery_params_pre_v12(self): + def _read_recovery_params_pre_v12(self) -> Tuple[Optional[CaseInsensitiveDict], Optional[bool]]: recovery_conf_mtime = mtime(self._recovery_conf) passfile_mtime = mtime(self._passfile) if self._passfile else False if recovery_conf_mtime == self._recovery_conf_mtime and passfile_mtime == self._passfile_mtime: return None, False - values = {} + values = CaseInsensitiveDict() with open(self._recovery_conf, 'r') as f: for line in f: line = line.strip() @@ -610,7 +631,7 @@ class ConfigHandler(object): match = PARAMETER_RE.match(line) if match: value = read_recovery_param_value(line[match.end():]) - if value is None: + if match is None or value is None: return None, True values[match.group(1)] = [value, True] self._recovery_conf_mtime = recovery_conf_mtime @@ -619,7 +640,7 @@ class ConfigHandler(object): values.update({param: ['', True] for param in self._recovery_parameters_to_compare if param not in values}) return values, True - def _check_passfile(self, passfile, wanted_primary_conninfo): + def _check_passfile(self, passfile: str, wanted_primary_conninfo: Dict[str, Any]) -> bool: # If there is a passfile in the primary_conninfo try to figure out that # the passfile contains the line(s) allowing connection to the given node. # We assume that the passfile was created by Patroni and therefore doing @@ -628,7 +649,7 @@ class ConfigHandler(object): if passfile_mtime: try: with open(passfile) as f: - wanted_lines = self._pgpass_line(wanted_primary_conninfo).splitlines() + wanted_lines = (self._pgpass_line(wanted_primary_conninfo) or '').splitlines() file_lines = f.read().splitlines() if set(wanted_lines) == set(file_lines): self._passfile = passfile @@ -638,7 +659,8 @@ class ConfigHandler(object): logger.info('Failed to read %s', passfile) return False - def _check_primary_conninfo(self, primary_conninfo, wanted_primary_conninfo): + def _check_primary_conninfo(self, primary_conninfo: Dict[str, Any], + wanted_primary_conninfo: Dict[str, Any]) -> bool: # first we will cover corner cases, when we are replicating from somewhere while shouldn't # or there is no primary_conninfo but we should replicate from some specific node. if not wanted_primary_conninfo: @@ -667,7 +689,7 @@ class ConfigHandler(object): return all(str(primary_conninfo.get(p)) == str(v) for p, v in wanted_primary_conninfo.items() if v is not None) - def check_recovery_conf(self, member): + def check_recovery_conf(self, member: Union[Leader, Member, None]) -> Tuple[bool, bool]: """Returns a tuple. The first boolean element indicates that recovery params don't match and the second is set to `True` if the restart is required in order to apply new values""" @@ -701,7 +723,7 @@ class ConfigHandler(object): else: # empty string, primary_conninfo is not in the config primary_conninfo[0] = {} - if not self._postgresql.is_starting(): + if not self._postgresql.is_starting() and self._current_recovery_params: # when wal receiver is alive take primary_slot_name from pg_stat_wal_receiver wal_receiver_primary_slot_name = self._postgresql.primary_slot_name() if not wal_receiver_primary_slot_name and self._postgresql.primary_conninfo(): @@ -715,11 +737,11 @@ class ConfigHandler(object): and not self._postgresql.cb_called and not self._postgresql.is_starting())} - def record_missmatch(mtype): + def record_missmatch(mtype: bool) -> None: required['restart' if mtype else 'reload'] += 1 wanted_recovery_params = self.build_recovery_params(member) - for param, value in self._current_recovery_params.items(): + for param, value in (self._current_recovery_params or {}).items(): # Skip certain parameters defined in the included postgres config files # if we know that they are not specified in the patroni configuration. if len(value) > 2 and value[2] not in (self._postgresql_conf, self._auto_conf) and \ @@ -741,25 +763,25 @@ class ConfigHandler(object): return required['restart'] + required['reload'] > 0, required['restart'] > 0 @staticmethod - def _remove_file_if_exists(name): + def _remove_file_if_exists(name: str) -> None: if os.path.isfile(name) or os.path.islink(name): os.unlink(name) @staticmethod - def _pgpass_line(record): + def _pgpass_line(record: Dict[str, Any]) -> Optional[str]: if 'password' in record: - def escape(value): + def escape(value: Any) -> str: return re.sub(r'([:\\])', r'\\\1', str(value)) record = {n: escape(record.get(n) or '*') for n in ('host', 'port', 'user', 'password')} # 'host' could be several comma-separated hostnames, in this case # we need to write on pgpass line per host line = '' - for hostname in record.get('host').split(','): + for hostname in record['host'].split(','): line += hostname + ':{port}:*:{user}:{password}'.format(**record) + '\n' return line.rstrip() - def write_pgpass(self, record): + def write_pgpass(self, record: Dict[str, Any]) -> Dict[str, str]: line = self._pgpass_line(record) if not line: return os.environ.copy() @@ -770,7 +792,7 @@ class ConfigHandler(object): return {**os.environ, 'PGPASSFILE': self._pgpass} - def write_recovery_conf(self, recovery_params): + def write_recovery_conf(self, recovery_params: CaseInsensitiveDict) -> None: self._recovery_params = recovery_params if self._postgresql.major_version >= 120000: if parse_bool(recovery_params.pop('standby_mode', None)): @@ -779,28 +801,28 @@ class ConfigHandler(object): self._remove_file_if_exists(self._standby_signal) open(self._recovery_signal, 'w').close() - def restart_required(name): + def restart_required(name: str) -> bool: if self._postgresql.major_version >= 140000: return False return name == 'restore_command' or (self._postgresql.major_version < 130000 and name in ('primary_conninfo', 'primary_slot_name')) - self._current_recovery_params = {n: [v, restart_required(n), self._postgresql_conf] - for n, v in recovery_params.items()} + self._current_recovery_params = CaseInsensitiveDict({n: [v, restart_required(n), self._postgresql_conf] + for n, v in recovery_params.items()}) else: with ConfigWriter(self._recovery_conf) as f: os.chmod(self._recovery_conf, stat.S_IWRITE | stat.S_IREAD) self._write_recovery_params(f, recovery_params) - def remove_recovery_conf(self): + def remove_recovery_conf(self) -> None: for name in (self._recovery_conf, self._standby_signal, self._recovery_signal): self._remove_file_if_exists(name) - self._recovery_params = {} + self._recovery_params = CaseInsensitiveDict() self._current_recovery_params = None - def _sanitize_auto_conf(self): + def _sanitize_auto_conf(self) -> None: overwrite = False - lines = [] + lines: List[str] = [] if os.path.exists(self._auto_conf): try: @@ -823,7 +845,7 @@ class ConfigHandler(object): except Exception: logger.exception('Failed to remove some unwanted parameters from %s', self._auto_conf) - def _adjust_recovery_parameters(self): + def _adjust_recovery_parameters(self) -> None: # It is not strictly necessary, but we can make patroni configs crossi-compatible with all postgres versions. recovery_conf = {n: v for n, v in self._server_parameters.items() if n.lower() in self._RECOVERY_PARAMETERS} if recovery_conf: @@ -834,7 +856,7 @@ class ConfigHandler(object): if self.triggerfile_good_name not in self._config['recovery_conf'] and value: self._config['recovery_conf'][self.triggerfile_good_name] = value - def get_server_parameters(self, config): + def get_server_parameters(self, config: Dict[str, Any]) -> CaseInsensitiveDict: parameters = config['parameters'].copy() listen_addresses, port = split_host_port(config['listen'], 5432) parameters.update(cluster_name=self._postgresql.scope, listen_addresses=listen_addresses, port=str(port)) @@ -860,7 +882,7 @@ class ConfigHandler(object): parameters.setdefault('wal_keep_size', str(int(wal_keep_segments) * 16) + 'MB') elif self._postgresql.major_version: wal_keep_size = parse_int(parameters.pop('wal_keep_size', self.CMDLINE_OPTIONS['wal_keep_size'][0]), 'MB') - parameters.setdefault('wal_keep_segments', int((wal_keep_size + 8) / 16)) + parameters.setdefault('wal_keep_segments', int(((wal_keep_size or 0) + 8) / 16)) self._postgresql.citus_handler.adjust_postgres_gucs(parameters) @@ -870,14 +892,14 @@ class ConfigHandler(object): return ret @staticmethod - def _get_unix_local_address(unix_socket_directories): + def _get_unix_local_address(unix_socket_directories: str) -> str: for d in unix_socket_directories.split(','): d = d.strip() if d.startswith('/'): # Only absolute path can be used to connect via unix-socket return d return '' - def _get_tcp_local_address(self): + def _get_tcp_local_address(self) -> str: listen_addresses = self._server_parameters['listen_addresses'].split(',') for la in listen_addresses: @@ -886,7 +908,7 @@ class ConfigHandler(object): return listen_addresses[0].strip() # can't use localhost, take first address from listen_addresses @property - def local_connect_kwargs(self): + def local_connect_kwargs(self) -> Dict[str, Any]: ret = self._local_address.copy() # add all of the other connection settings that are available ret.update(self._superuser) @@ -902,7 +924,7 @@ class ConfigHandler(object): 'options': '-c statement_timeout=2000'}) return ret - def resolve_connection_addresses(self): + def resolve_connection_addresses(self) -> None: port = self._server_parameters['port'] tcp_local_address = self._get_tcp_local_address() netloc = self._config.get('connect_address') or tcp_local_address + ':' + port @@ -922,19 +944,24 @@ class ConfigHandler(object): self._postgresql.connection_string = uri('postgres', netloc, self._postgresql.database) self._postgresql.set_connection_kwargs(self.local_connect_kwargs) - def _get_pg_settings(self, names): + def _get_pg_settings( + self, names: Collection[str] + ) -> Dict[str, Tuple[str, str, Optional[str], str, str, Optional[str]]]: return {r[0]: r for r in self._postgresql.query(('SELECT name, setting, unit, vartype, context, sourcefile' + ' FROM pg_catalog.pg_settings ' + ' WHERE pg_catalog.lower(name) = ANY(%s)'), [n.lower() for n in names])} @staticmethod - def _handle_wal_buffers(old_values, changes): - wal_block_size = parse_int(old_values['wal_block_size'][1]) + def _handle_wal_buffers(old_values: Dict[str, Tuple[str, str, Optional[str], str, str, Optional[str]]], + changes: CaseInsensitiveDict) -> None: + wal_block_size = parse_int(old_values['wal_block_size'][1]) or 8192 wal_segment_size = old_values['wal_segment_size'] - wal_segment_unit = parse_int(wal_segment_size[2], 'B') if wal_segment_size[2][0].isdigit() else 1 - wal_segment_size = parse_int(wal_segment_size[1]) * wal_segment_unit / wal_block_size - default_wal_buffers = min(max(parse_int(old_values['shared_buffers'][1]) / 32, 8), wal_segment_size) + wal_segment_unit = parse_int(wal_segment_size[2], 'B') or 8192 \ + if wal_segment_size[2] is not None and wal_segment_size[2][0].isdigit() else 1 + wal_segment_size = parse_int(wal_segment_size[1]) or (16777216 if wal_segment_size[2] is None else 2048) + wal_segment_size *= wal_segment_unit / wal_block_size + default_wal_buffers = min(max((parse_int(old_values['shared_buffers'][1]) or 16384) / 32, 8), wal_segment_size) wal_buffers = old_values['wal_buffers'] new_value = str(changes['wal_buffers'] or -1) @@ -945,7 +972,7 @@ class ConfigHandler(object): if new_value == old_value: del changes['wal_buffers'] - def reload_config(self, config, sighup=False): + def reload_config(self, config: Dict[str, Any], sighup: bool = False) -> None: self._superuser = config['authentication'].get('superuser', {}) server_parameters = self.get_server_parameters(config) @@ -956,6 +983,7 @@ class ConfigHandler(object): changes.update({p: None for p in self._server_parameters.keys() if not (p in changes or p.lower() in self._RECOVERY_PARAMETERS)}) if changes: + undef = [] if 'wal_buffers' in changes: # we need to calculate the default value of wal_buffers undef = [p for p in ('shared_buffers', 'wal_segment_size', 'wal_block_size') if p not in changes] changes.update({p: None for p in undef}) @@ -1030,16 +1058,17 @@ class ConfigHandler(object): if self._postgresql.major_version >= 90500: time.sleep(1) try: - pending_restart = self._postgresql.query( - 'SELECT COUNT(*) FROM pg_catalog.pg_settings WHERE pg_catalog.lower(name) != ALL(%s)' - ' AND pending_restart', [n.lower() for n in self._RECOVERY_PARAMETERS]).fetchone()[0] > 0 + pending_restart = (self._postgresql.query( + 'SELECT COUNT(*) FROM pg_catalog.pg_settings' + ' WHERE pg_catalog.lower(name) != ALL(%s) AND pending_restart', + [n.lower() for n in self._RECOVERY_PARAMETERS]).fetchone() or (0,))[0] > 0 self._postgresql.set_pending_restart(pending_restart) except Exception as e: logger.warning('Exception %r when running query', e) else: logger.info('No PostgreSQL configuration items changed, nothing to reload.') - def set_synchronous_standby_names(self, value): + def set_synchronous_standby_names(self, value: Optional[str]) -> Optional[bool]: """Updates synchronous_standby_names and reloads if necessary. :returns: True if value was updated.""" if value != self._synchronous_standby_names: @@ -1054,7 +1083,7 @@ class ConfigHandler(object): return True @property - def effective_configuration(self): + def effective_configuration(self) -> CaseInsensitiveDict: """It might happen that the current value of one (or more) below parameters stored in the controldata is higher than the value stored in the global cluster configuration. @@ -1089,7 +1118,7 @@ class ConfigHandler(object): continue cvalue = parse_int(data[cname]) - if cvalue > value: + if cvalue is not None and value is not None and cvalue > value: effective_configuration[name] = cvalue self._postgresql.set_pending_restart(True) @@ -1113,35 +1142,35 @@ class ConfigHandler(object): return effective_configuration @property - def replication(self): + def replication(self) -> Dict[str, Any]: return self._config['authentication']['replication'] @property - def superuser(self): + def superuser(self) -> Dict[str, Any]: return self._superuser @property - def rewind_credentials(self): + def rewind_credentials(self) -> Dict[str, Any]: return self._config['authentication'].get('rewind', self._superuser) \ if self._postgresql.major_version >= 110000 else self._superuser @property - def ident_file(self): + def ident_file(self) -> Optional[str]: ident_file = self._server_parameters.get('ident_file') return None if ident_file == self._pg_ident_conf else ident_file @property - def hba_file(self): + def hba_file(self) -> Optional[str]: hba_file = self._server_parameters.get('hba_file') return None if hba_file == self._pg_hba_conf else hba_file @property - def pg_hba_conf(self): + def pg_hba_conf(self) -> str: return self._pg_hba_conf @property - def postgresql_conf(self): + def postgresql_conf(self) -> str: return self._postgresql_conf - def get(self, key, default=None): + def get(self, key: str, default: Optional[Any] = None) -> Optional[Any]: return self._config.get(key, default) diff --git a/patroni/postgresql/connection.py b/patroni/postgresql/connection.py index c5729284..6a556983 100644 --- a/patroni/postgresql/connection.py +++ b/patroni/postgresql/connection.py @@ -2,6 +2,10 @@ import logging from contextlib import contextmanager from threading import Lock +from typing import Any, Dict, Generator, Union, TYPE_CHECKING +if TYPE_CHECKING: # pragma: no cover + from psycopg import Connection as Connection3, Cursor + from psycopg2 import connection, cursor from .. import psycopg @@ -9,29 +13,30 @@ logger = logging.getLogger(__name__) class Connection(object): + server_version: int - def __init__(self): + def __init__(self) -> None: self._lock = Lock() self._connection = None self._cursor_holder = None - def set_conn_kwargs(self, conn_kwargs): + def set_conn_kwargs(self, conn_kwargs: Dict[str, Any]) -> None: self._conn_kwargs = conn_kwargs - def get(self): + def get(self) -> Union['connection', 'Connection3[Any]']: with self._lock: if not self._connection or self._connection.closed != 0: self._connection = psycopg.connect(**self._conn_kwargs) - self.server_version = self._connection.server_version + self.server_version = getattr(self._connection, 'server_version', 0) return self._connection - def cursor(self): + def cursor(self) -> Union['cursor', 'Cursor[Any]']: if not self._cursor_holder or self._cursor_holder.closed or self._cursor_holder.connection.closed != 0: logger.info("establishing a new patroni connection to the postgres cluster") self._cursor_holder = self.get().cursor() return self._cursor_holder - def close(self): + def close(self) -> None: if self._connection and self._connection.closed == 0: self._connection.close() logger.info("closed patroni connection to the postgresql cluster") @@ -39,7 +44,7 @@ class Connection(object): @contextmanager -def get_connection_cursor(**kwargs): +def get_connection_cursor(**kwargs: Any) -> Generator[Union['cursor', 'Cursor[Any]'], None, None]: conn = psycopg.connect(**kwargs) with conn.cursor() as cur: yield cur diff --git a/patroni/postgresql/misc.py b/patroni/postgresql/misc.py index 462e5d43..a8a2c296 100644 --- a/patroni/postgresql/misc.py +++ b/patroni/postgresql/misc.py @@ -2,7 +2,9 @@ import errno import logging import os -from patroni.exceptions import PostgresException +from typing import Iterable, Tuple + +from ..exceptions import PostgresException logger = logging.getLogger(__name__) @@ -55,29 +57,27 @@ def postgres_major_version_to_int(pg_version: str) -> int: return postgres_version_to_int(pg_version + '.0') -def parse_lsn(lsn): +def parse_lsn(lsn: str) -> int: t = lsn.split('/') return int(t[0], 16) * 0x100000000 + int(t[1], 16) -def parse_history(data): +def parse_history(data: str) -> Iterable[Tuple[int, int, str]]: for line in data.split('\n'): values = line.strip().split('\t') if len(values) == 3: try: - values[0] = int(values[0]) - values[1] = parse_lsn(values[1]) - yield values + yield int(values[0]), parse_lsn(values[1]), values[2] except (IndexError, ValueError): logger.exception('Exception when parsing timeline history line "%s"', values) -def format_lsn(lsn, full=False): +def format_lsn(lsn: int, full: bool = False) -> str: template = '{0:X}/{1:08X}' if full else '{0:X}/{1:X}' return template.format(lsn >> 32, lsn & 0xFFFFFFFF) -def fsync_dir(path): +def fsync_dir(path: str) -> None: if os.name != 'nt': fd = os.open(path, os.O_DIRECTORY) try: diff --git a/patroni/postgresql/postmaster.py b/patroni/postgresql/postmaster.py index e2a88856..4505e7f7 100644 --- a/patroni/postgresql/postmaster.py +++ b/patroni/postgresql/postmaster.py @@ -7,6 +7,9 @@ import signal import subprocess import sys +from multiprocessing.connection import Connection +from typing import Dict, Optional, List + from patroni import PATRONI_ENV_PREFIX, KUBERNETES_ENV_PREFIX # avoid spawning the resource tracker process @@ -26,7 +29,7 @@ STOP_SIGNALS = { } -def pg_ctl_start(conn, cmdline, env): +def pg_ctl_start(conn: Connection, cmdline: List[str], env: Dict[str, str]) -> None: if os.name != 'nt': os.setsid() try: @@ -40,7 +43,8 @@ def pg_ctl_start(conn, cmdline, env): class PostmasterProcess(psutil.Process): - def __init__(self, pid): + def __init__(self, pid: int) -> None: + self._postmaster_pid: Dict[str, str] self.is_single_user = False if pid < 0: pid = -pid @@ -48,7 +52,7 @@ class PostmasterProcess(psutil.Process): super(PostmasterProcess, self).__init__(pid) @staticmethod - def _read_postmaster_pidfile(data_dir): + def _read_postmaster_pidfile(data_dir: str) -> Dict[str, str]: """Reads and parses postmaster.pid from the data directory :returns dictionary of values if successful, empty dictionary otherwise @@ -60,7 +64,7 @@ class PostmasterProcess(psutil.Process): except IOError: return {} - def _is_postmaster_process(self): + def _is_postmaster_process(self) -> bool: try: start_time = int(self._postmaster_pid.get('start_time', 0)) if start_time and abs(self.create_time() - start_time) > 3: @@ -79,7 +83,7 @@ class PostmasterProcess(psutil.Process): return True @classmethod - def _from_pidfile(cls, data_dir): + def _from_pidfile(cls, data_dir: str) -> Optional['PostmasterProcess']: postmaster_pid = PostmasterProcess._read_postmaster_pidfile(data_dir) try: pid = int(postmaster_pid.get('pid', 0)) @@ -88,10 +92,10 @@ class PostmasterProcess(psutil.Process): proc._postmaster_pid = postmaster_pid return proc except ValueError: - pass + return None @staticmethod - def from_pidfile(data_dir): + def from_pidfile(data_dir: str) -> Optional['PostmasterProcess']: try: proc = PostmasterProcess._from_pidfile(data_dir) return proc if proc and proc._is_postmaster_process() else None @@ -99,13 +103,13 @@ class PostmasterProcess(psutil.Process): return None @classmethod - def from_pid(cls, pid): + def from_pid(cls, pid: int) -> Optional['PostmasterProcess']: try: return cls(pid) except psutil.NoSuchProcess: return None - def signal_kill(self): + def signal_kill(self) -> bool: """to suspend and kill postmaster and all children :returns True if postmaster and children are killed, False if error @@ -141,7 +145,7 @@ class PostmasterProcess(psutil.Process): psutil.wait_procs(children + [self]) return True - def signal_stop(self, mode, pg_ctl='pg_ctl'): + def signal_stop(self, mode: str, pg_ctl: str = 'pg_ctl') -> Optional[bool]: """Signal postmaster process to stop :returns None if signaled, True if process is already gone, False if error @@ -161,7 +165,7 @@ class PostmasterProcess(psutil.Process): return None - def pg_ctl_kill(self, mode, pg_ctl): + def pg_ctl_kill(self, mode: str, pg_ctl: str) -> Optional[bool]: try: status = subprocess.call([pg_ctl, "kill", STOP_SIGNALS[mode], str(self.pid)]) except OSError: @@ -171,7 +175,7 @@ class PostmasterProcess(psutil.Process): else: return not self.is_running() - def wait_for_user_backends_to_close(self, stop_timeout): + def wait_for_user_backends_to_close(self, stop_timeout: Optional[float]) -> None: # These regexps are cross checked against versions PostgreSQL 9.1 .. 15 aux_proc_re = re.compile("(?:postgres:)( .*:)? (?:(?:archiver|startup|autovacuum launcher|autovacuum worker|" "checkpointer|logger|stats collector|wal receiver|wal writer|writer)(?: process )?|" @@ -183,8 +187,8 @@ class PostmasterProcess(psutil.Process): except psutil.Error: return logger.debug('Failed to get list of postmaster children') - user_backends = [] - user_backends_cmdlines = {} + user_backends: List[psutil.Process] = [] + user_backends_cmdlines: Dict[int, str] = {} for child in children: try: cmdline = child.cmdline() @@ -195,7 +199,7 @@ class PostmasterProcess(psutil.Process): pass if user_backends: logger.debug('Waiting for user backends %s to close', ', '.join(user_backends_cmdlines.values())) - gone, live = psutil.wait_procs(user_backends, stop_timeout) + _, live = psutil.wait_procs(user_backends, stop_timeout) if stop_timeout and live: live = [user_backends_cmdlines[b.pid] for b in live] logger.warning('Backends still alive after %s: %s', stop_timeout, ', '.join(live)) @@ -203,7 +207,7 @@ class PostmasterProcess(psutil.Process): logger.debug("Backends closed") @staticmethod - def start(pgcommand, data_dir, conf, options): + def start(pgcommand: str, data_dir: str, conf: str, options: List[str]) -> Optional['PostmasterProcess']: # Unfortunately `pg_ctl start` does not return postmaster pid to us. Without this information # it is hard to know the current state of postgres startup, so we had to reimplement pg_ctl start # in python. It will start postgres, wait for port to be open and wait until postgres will start @@ -234,7 +238,7 @@ class PostmasterProcess(psutil.Process): pass cmdline = [pgcommand, '-D', data_dir, '--config-file={}'.format(conf)] + options logger.debug("Starting postgres: %s", " ".join(cmdline)) - ctx = multiprocessing.get_context('spawn') if sys.version_info >= (3, 4) else multiprocessing + ctx = multiprocessing.get_context('spawn') parent_conn, child_conn = ctx.Pipe(False) proc = ctx.Process(target=pg_ctl_start, args=(child_conn, cmdline, env)) proc.start() diff --git a/patroni/postgresql/rewind.py b/patroni/postgresql/rewind.py index ba323eb5..f70ad065 100644 --- a/patroni/postgresql/rewind.py +++ b/patroni/postgresql/rewind.py @@ -5,36 +5,46 @@ import shlex import shutil import subprocess +from enum import IntEnum from threading import Lock, Thread +from typing import Any, Callable, Dict, List, Optional, Union, Tuple +from . import Postgresql from .connection import get_connection_cursor from .misc import format_lsn, fsync_dir, parse_history, parse_lsn from ..async_executor import CriticalTask -from ..dcs import Leader +from ..dcs import Leader, RemoteMember logger = logging.getLogger(__name__) -REWIND_STATUS = type('Enum', (), {'INITIAL': 0, 'CHECKPOINT': 1, 'CHECK': 2, 'NEED': 3, - 'NOT_NEED': 4, 'SUCCESS': 5, 'FAILED': 6}) + +class REWIND_STATUS(IntEnum): + INITIAL = 0 + CHECKPOINT = 1 + CHECK = 2 + NEED = 3 + NOT_NEED = 4 + SUCCESS = 5 + FAILED = 6 class Rewind(object): - def __init__(self, postgresql): + def __init__(self, postgresql: Postgresql) -> None: self._postgresql = postgresql self._checkpoint_task_lock = Lock() self.reset_state() @staticmethod - def configuration_allows_rewind(data): + def configuration_allows_rewind(data: Dict[str, str]) -> bool: return data.get('wal_log_hints setting', 'off') == 'on' or data.get('Data page checksum version', '0') != '0' @property - def enabled(self): - return self._postgresql.config.get('use_pg_rewind') + def enabled(self) -> bool: + return bool(self._postgresql.config.get('use_pg_rewind')) @property - def can_rewind(self): + def can_rewind(self) -> bool: """ check if pg_rewind executable is there and that pg_controldata indicates we have either wal_log_hints or checksums turned on """ @@ -52,43 +62,45 @@ class Rewind(object): return self.configuration_allows_rewind(self._postgresql.controldata()) @property - def should_remove_data_directory_on_diverged_timelines(self): - return self._postgresql.config.get('remove_data_directory_on_diverged_timelines') + def should_remove_data_directory_on_diverged_timelines(self) -> bool: + return bool(self._postgresql.config.get('remove_data_directory_on_diverged_timelines')) @property - def can_rewind_or_reinitialize_allowed(self): + def can_rewind_or_reinitialize_allowed(self) -> bool: return self.should_remove_data_directory_on_diverged_timelines or self.can_rewind - def trigger_check_diverged_lsn(self): + def trigger_check_diverged_lsn(self) -> None: if self.can_rewind_or_reinitialize_allowed and self._state != REWIND_STATUS.NEED: self._state = REWIND_STATUS.CHECK @staticmethod - def check_leader_is_not_in_recovery(conn_kwargs): + def check_leader_is_not_in_recovery(conn_kwargs: Dict[str, Any]) -> Optional[bool]: try: with get_connection_cursor(connect_timeout=3, options='-c statement_timeout=2000', **conn_kwargs) as cur: cur.execute('SELECT pg_catalog.pg_is_in_recovery()') - if not cur.fetchone()[0]: + row = cur.fetchone() + if not row or not row[0]: return True logger.info('Leader is still in_recovery and therefore can\'t be used for rewind') except Exception: return logger.exception('Exception when working with leader') @staticmethod - def check_leader_has_run_checkpoint(conn_kwargs): + def check_leader_has_run_checkpoint(conn_kwargs: Dict[str, Any]) -> Optional[str]: try: with get_connection_cursor(connect_timeout=3, options='-c statement_timeout=2000', **conn_kwargs) as cur: cur.execute("SELECT NOT pg_catalog.pg_is_in_recovery()" " AND ('x' || pg_catalog.substr(pg_catalog.pg_walfile_name(" " pg_catalog.pg_current_wal_lsn()), 1, 8))::bit(32)::int = timeline_id" " FROM pg_catalog.pg_control_checkpoint()") - if not cur.fetchone()[0]: + row = cur.fetchone() + if not row or not row[0]: return 'leader has not run a checkpoint yet' except Exception: logger.exception('Exception when working with leader') return 'not accessible or not healty' - def _get_checkpoint_end(self, timeline, lsn): + def _get_checkpoint_end(self, timeline: int, lsn: int) -> int: """The checkpoint record size in WAL depends on postgres major version and platform (memory alignment). Hence, the only reliable way to figure out where it ends, read the record from file with the help of pg_waldump and parse the output. We are trying to read two records, and expect that it will fail to read the second one: @@ -96,12 +108,12 @@ class Rewind(object): The error message contains information about LSN of the next record, which is exactly where checkpoint ends.""" lsn8 = format_lsn(lsn, True) - lsn = format_lsn(lsn) - out, err = self._postgresql.waldump(timeline, lsn, 2) + lsn_str = format_lsn(lsn) + out, err = self._postgresql.waldump(timeline, lsn_str, 2) if out is not None and err is not None: out = out.decode('utf-8').rstrip().split('\n') err = err.decode('utf-8').rstrip().split('\n') - pattern = 'error in WAL record at {0}: invalid record length at '.format(lsn) + pattern = 'error in WAL record at {0}: invalid record length at '.format(lsn_str) if len(out) == 1 and len(err) == 1 and ', lsn: {0}, prev '.format(lsn8) in out[0] and pattern in err[0]: i = err[0].find(pattern) + len(pattern) @@ -117,20 +129,20 @@ class Rewind(object): return 0 - def _get_local_timeline_lsn_from_controldata(self): + def _get_local_timeline_lsn_from_controldata(self) -> Tuple[Optional[bool], Optional[int], Optional[int]]: in_recovery = timeline = lsn = None data = self._postgresql.controldata() try: if data.get('Database cluster state') in ('shut down in recovery', 'in archive recovery'): in_recovery = True lsn = data.get('Minimum recovery ending location') - timeline = int(data.get("Min recovery ending loc's timeline")) + timeline = int(data.get("Min recovery ending loc's timeline", "")) if lsn == '0/0' or timeline == 0: # it was a primary when it crashed data['Database cluster state'] = 'shut down' if data.get('Database cluster state') == 'shut down': in_recovery = False lsn = data.get('Latest checkpoint location') - timeline = int(data.get("Latest checkpoint's TimeLineID")) + timeline = int(data.get("Latest checkpoint's TimeLineID", "")) except (TypeError, ValueError): logger.exception('Failed to get local timeline and lsn from pg_controldata output') @@ -143,7 +155,7 @@ class Rewind(object): return in_recovery, timeline, lsn - def _get_local_timeline_lsn(self): + def _get_local_timeline_lsn(self) -> Tuple[Optional[bool], Optional[int], Optional[int]]: if self._postgresql.is_running(): # if postgres is running - get timeline from replication connection in_recovery = True timeline = self._postgresql.received_timeline() or self._postgresql.get_replica_timeline() @@ -156,14 +168,15 @@ class Rewind(object): return in_recovery, timeline, lsn @staticmethod - def _log_primary_history(history, i): + def _log_primary_history(history: List[Tuple[int, int, str]], i: int) -> None: start = max(0, i - 3) end = None if i + 4 >= len(history) else i + 2 - history_show = [] + history_show: List[str] = [] - def format_history_line(line): + def format_history_line(line: Tuple[int, int, str]) -> str: return '{0}\t{1}\t{2}'.format(line[0], format_lsn(line[1]), line[2]) + line = None for line in history[start:end]: history_show.append(format_history_line(line)) @@ -173,7 +186,7 @@ class Rewind(object): logger.info('primary: history=%s', '\n'.join(history_show)) - def _conn_kwargs(self, member, auth): + def _conn_kwargs(self, member: Union[Leader, RemoteMember], auth: Dict[str, Any]) -> Dict[str, Any]: ret = member.conn_kwargs(auth) if not ret.get('dbname'): ret['dbname'] = self._postgresql.database @@ -183,7 +196,7 @@ class Rewind(object): ret['target_session_attrs'] = 'read-write' return ret - def _check_timeline_and_lsn(self, leader): + def _check_timeline_and_lsn(self, leader: Union[Leader, RemoteMember]) -> None: in_recovery, local_timeline, local_lsn = self._get_local_timeline_lsn() if local_timeline is None or local_lsn is None: return @@ -205,23 +218,28 @@ class Rewind(object): try: with self._postgresql.get_replication_connection_cursor(**leader.conn_kwargs()) as cur: cur.execute('IDENTIFY_SYSTEM') - primary_timeline = cur.fetchone()[1] - logger.info('primary_timeline=%s', primary_timeline) - if local_timeline > primary_timeline: # Not always supported by pg_rewind - need_rewind = True - elif local_timeline == primary_timeline: - need_rewind = False - elif primary_timeline > 1: - cur.execute('TIMELINE_HISTORY {0}'.format(primary_timeline)) - history = cur.fetchone()[1] - if not isinstance(history, str): - history = bytes(history).decode('utf-8') - logger.debug('primary: history=%s', history) + row = cur.fetchone() + if row: + primary_timeline = row[1] + logger.info('primary_timeline=%s', primary_timeline) + if local_timeline > primary_timeline: # Not always supported by pg_rewind + need_rewind = True + elif local_timeline == primary_timeline: + need_rewind = False + elif primary_timeline > 1: + cur.execute('TIMELINE_HISTORY {0}'.format(primary_timeline).encode('utf-8')) + row = cur.fetchone() + if row: + history = row[1] + if not isinstance(history, str): + history = bytes(history).decode('utf-8') + logger.debug('primary: history=%s', history) except Exception: return logger.exception('Exception when working with primary via replication connection') if history is not None: history = list(parse_history(history)) + i = len(history) for i, (parent_timeline, switchpoint, _) in enumerate(history): if parent_timeline == local_timeline: # We don't need to rewind when: @@ -243,12 +261,12 @@ class Rewind(object): self._state = need_rewind and REWIND_STATUS.NEED or REWIND_STATUS.NOT_NEED - def rewind_or_reinitialize_needed_and_possible(self, leader): + def rewind_or_reinitialize_needed_and_possible(self, leader: Union[Leader, RemoteMember, None]) -> bool: if leader and leader.name != self._postgresql.name and leader.conn_url and self._state == REWIND_STATUS.CHECK: self._check_timeline_and_lsn(leader) - return leader and leader.conn_url and self._state == REWIND_STATUS.NEED + return bool(leader and leader.conn_url) and self._state == REWIND_STATUS.NEED - def __checkpoint(self, task, wakeup): + def __checkpoint(self, task: CriticalTask, wakeup: Callable[..., Any]) -> None: try: result = self._postgresql.checkpoint() except Exception as e: @@ -258,7 +276,7 @@ class Rewind(object): if task.result: wakeup() - def ensure_checkpoint_after_promote(self, wakeup): + def ensure_checkpoint_after_promote(self, wakeup: Callable[..., Any]) -> None: """After promote issue a CHECKPOINT from a new thread and asynchronously check the result. In case if CHECKPOINT failed, just check that timeline in pg_control was updated.""" @@ -275,10 +293,10 @@ class Rewind(object): self._checkpoint_task = CriticalTask() Thread(target=self.__checkpoint, args=(self._checkpoint_task, wakeup)).start() - def checkpoint_after_promote(self): + def checkpoint_after_promote(self) -> bool: return self._state == REWIND_STATUS.CHECKPOINT - def _buid_archiver_command(self, command, wal_filename): + def _buid_archiver_command(self, command: str, wal_filename: str) -> str: """Replace placeholders in the given archiver command's template. Applicable for archive_command and restore_command. Can also be used for archive_cleanup_command and recovery_end_command, @@ -306,13 +324,13 @@ class Rewind(object): return cmd - def _fetch_missing_wal(self, restore_command, wal_filename): + def _fetch_missing_wal(self, restore_command: str, wal_filename: str) -> bool: cmd = self._buid_archiver_command(restore_command, wal_filename) logger.info('Trying to fetch the missing wal: %s', cmd) return self._postgresql.cancellable.call(shlex.split(cmd)) == 0 - def _find_missing_wal(self, data): + def _find_missing_wal(self, data: bytes) -> Optional[str]: # could not open file "$PGDATA/pg_wal/0000000A00006AA100000068": No such file or directory pattern = 'could not open file "' for line in data.decode('utf-8').split('\n'): @@ -325,7 +343,7 @@ class Rewind(object): if waldir.endswith('/pg_' + self._postgresql.wal_name) and len(wal_filename) == 24: return wal_filename - def _archive_ready_wals(self): + def _archive_ready_wals(self) -> None: """Try to archive WALs that have .ready files just in case archive_mode was not set to 'always' before promote, while after it the WALs were recycled on the promoted replica. @@ -361,7 +379,7 @@ class Rewind(object): else: logger.info('Failed to archive WAL segment %s', wal) - def _maybe_clean_pg_replslot(self): + def _maybe_clean_pg_replslot(self) -> None: """Clean pg_replslot directory if pg version is less then 11 (pg_rewind deletes $PGDATA/pg_replslot content only since pg11).""" if self._postgresql.major_version < 110000: @@ -373,33 +391,33 @@ class Rewind(object): except Exception as e: logger.warning('Unable to clean %s: %r', replslot_dir, e) - def pg_rewind(self, r): + def pg_rewind(self, r: Dict[str, Any]) -> bool: # prepare pg_rewind connection env = self._postgresql.config.write_pgpass(r) env.update(LANG='C', LC_ALL='C', PGOPTIONS='-c statement_timeout=0') dsn = self._postgresql.config.format_dsn(r, True) logger.info('running pg_rewind from %s', dsn) - restore_command = self._postgresql.config.get('recovery_conf', {}).get('restore_command') \ + restore_command = (self._postgresql.config.get('recovery_conf') or {}).get('restore_command') \ if self._postgresql.major_version < 120000 else self._postgresql.get_guc_value('restore_command') # Until v15 pg_rewind expected postgresql.conf to be inside $PGDATA, which is not the case on e.g. Debian pg_rewind_can_restore = restore_command and (self._postgresql.major_version >= 150000 or (self._postgresql.major_version >= 130000 - and self._postgresql.config._config_dir + and self._postgresql.config.config_dir == self._postgresql.data_dir)) cmd = [self._postgresql.pgcommand('pg_rewind')] if pg_rewind_can_restore: cmd.append('--restore-target-wal') if self._postgresql.major_version >= 150000 and\ - self._postgresql.config._config_dir != self._postgresql.data_dir: + self._postgresql.config.config_dir != self._postgresql.data_dir: cmd.append('--config-file={0}'.format(self._postgresql.config.postgresql_conf)) cmd.extend(['-D', self._postgresql.data_dir, '--source-server', dsn]) while True: - results = {} + results: Dict[str, bytes] = {} ret = self._postgresql.cancellable.call(cmd, env=env, communicate=results) logger.info('pg_rewind exit code=%s', ret) @@ -422,7 +440,7 @@ class Rewind(object): logger.info('Failed to fetch WAL segment %s required for pg_rewind', missing_wal) return False - def execute(self, leader): + def execute(self, leader: Union[Leader, RemoteMember]) -> Optional[bool]: if self._postgresql.is_running() and not self._postgresql.stop(checkpoint=False): return logger.warning('Can not run pg_rewind because postgres is still running') @@ -471,26 +489,26 @@ class Rewind(object): break return False - def reset_state(self): + def reset_state(self) -> None: self._state = REWIND_STATUS.INITIAL with self._checkpoint_task_lock: self._checkpoint_task = None @property - def is_needed(self): + def is_needed(self) -> bool: return self._state in (REWIND_STATUS.CHECK, REWIND_STATUS.NEED) @property - def executed(self): + def executed(self) -> bool: return self._state > REWIND_STATUS.NOT_NEED @property - def failed(self): + def failed(self) -> bool: return self._state == REWIND_STATUS.FAILED - def read_postmaster_opts(self): + def read_postmaster_opts(self) -> Dict[str, str]: """returns the list of option names/values from postgres.opts, Empty dict if read failed or no file""" - result = {} + result: Dict[str, str] = {} try: with open(os.path.join(self._postgresql.data_dir, 'postmaster.opts')) as f: data = f.read() @@ -502,7 +520,8 @@ class Rewind(object): logger.exception('Error when reading postmaster.opts') return result - def single_user_mode(self, communicate=None, options=None): + def single_user_mode(self, communicate: Optional[Dict[str, Any]] = None, + options: Optional[Dict[str, str]] = None) -> Optional[int]: """run a given command in a single-user mode. If the command is empty - then just start and stop""" cmd = [self._postgresql.pgcommand('postgres'), '--single', '-D', self._postgresql.data_dir] for opt, val in sorted((options or {}).items()): @@ -511,7 +530,7 @@ class Rewind(object): cmd.append('template1') return self._postgresql.cancellable.call(cmd, communicate=communicate) - def cleanup_archive_status(self): + def cleanup_archive_status(self) -> None: status_dir = os.path.join(self._postgresql.wal_dir, 'archive_status') try: for f in os.listdir(status_dir): @@ -526,7 +545,7 @@ class Rewind(object): except OSError: logger.exception('Unable to list %s', status_dir) - def ensure_clean_shutdown(self): + def ensure_clean_shutdown(self) -> Optional[bool]: self._archive_ready_wals() self.cleanup_archive_status() @@ -534,7 +553,7 @@ class Rewind(object): opts = self.read_postmaster_opts() opts.update({'archive_mode': 'on', 'archive_command': 'false'}) self._postgresql.config.remove_recovery_conf() - output = {} + output: Dict[str, bytes] = {} ret = self.single_user_mode(communicate=output, options=opts) if ret != 0: logger.error('Crash recovery finished with code=%s', ret) diff --git a/patroni/postgresql/slots.py b/patroni/postgresql/slots.py index aa5efd8e..b0a53561 100644 --- a/patroni/postgresql/slots.py +++ b/patroni/postgresql/slots.py @@ -5,36 +5,43 @@ import shutil from collections import defaultdict from contextlib import contextmanager from threading import Condition, Thread +from typing import Any, Dict, Generator, List, Optional, Union, Tuple, TYPE_CHECKING from .connection import get_connection_cursor from .misc import format_lsn, fsync_dir +from ..dcs import Cluster, Leader from ..psycopg import OperationalError +if TYPE_CHECKING: # pragma: no cover + from psycopg import Cursor + from psycopg2 import cursor + from . import Postgresql + logger = logging.getLogger(__name__) -def compare_slots(s1, s2, dbid='database'): +def compare_slots(s1: Dict[str, Any], s2: Dict[str, Any], dbid: str = 'database') -> bool: return s1['type'] == s2['type'] and (s1['type'] == 'physical' or s1.get(dbid) == s2.get(dbid) and s1['plugin'] == s2['plugin']) class SlotsAdvanceThread(Thread): - def __init__(self, slots_handler): + def __init__(self, slots_handler: 'SlotsHandler') -> None: super(SlotsAdvanceThread, self).__init__() self.daemon = True self._slots_handler = slots_handler # _copy_slots and _failed are used to asynchronously give some feedback to the main thread - self._copy_slots = [] + self._copy_slots: List[str] = [] self._failed = False - - self._scheduled = defaultdict(dict) # {'dbname1': {'slot1': 100, 'slot2': 100}, 'dbname2': {'slot3': 100}} + # {'dbname1': {'slot1': 100, 'slot2': 100}, 'dbname2': {'slot3': 100}} + self._scheduled: Dict[str, Dict[str, int]] = defaultdict(dict) self._condition = Condition() # protect self._scheduled from concurrent access and to wakeup the run() method self.start() - def sync_slot(self, cur, database, slot, lsn): + def sync_slot(self, cur: Union['cursor', 'Cursor[Any]'], database: str, slot: str, lsn: int) -> None: failed = copy = False try: cur.execute("SELECT pg_catalog.pg_replication_slot_advance(%s, %s)", (slot, format_lsn(lsn))) @@ -55,7 +62,7 @@ class SlotsAdvanceThread(Thread): if not self._scheduled[database]: self._scheduled.pop(database) - def sync_slots_in_database(self, database, slots): + def sync_slots_in_database(self, database: str, slots: List[str]) -> None: with self._slots_handler.get_local_connection_cursor(dbname=database, options='-c statement_timeout=0') as cur: for slot in slots: with self._condition: @@ -63,7 +70,7 @@ class SlotsAdvanceThread(Thread): if lsn: self.sync_slot(cur, database, slot, lsn) - def sync_slots(self): + def sync_slots(self) -> None: with self._condition: databases = list(self._scheduled.keys()) for database in databases: @@ -75,7 +82,7 @@ class SlotsAdvanceThread(Thread): except Exception as e: logger.error('Failed to advance replication slots in database %s: %r', database, e) - def run(self): + def run(self) -> None: while True: with self._condition: if not self._scheduled: @@ -83,7 +90,7 @@ class SlotsAdvanceThread(Thread): self.sync_slots() - def schedule(self, advance_slots): + def schedule(self, advance_slots: Dict[str, Dict[str, int]]) -> Tuple[bool, List[str]]: with self._condition: for database, values in advance_slots.items(): self._scheduled[database].update(values) @@ -94,7 +101,7 @@ class SlotsAdvanceThread(Thread): return ret - def on_promote(self): + def on_promote(self) -> None: with self._condition: self._scheduled.clear() self._failed = False @@ -103,22 +110,22 @@ class SlotsAdvanceThread(Thread): class SlotsHandler(object): - def __init__(self, postgresql): + def __init__(self, postgresql: 'Postgresql') -> None: self._postgresql = postgresql self._advance = None - self._replication_slots = {} # already existing replication slots - self._unready_logical_slots = {} + self._replication_slots: Dict[str, Dict[str, Any]] = {} # already existing replication slots + self._unready_logical_slots: Dict[str, Optional[int]] = {} self.pg_replslot_dir = os.path.join(self._postgresql.data_dir, 'pg_replslot') self.schedule() - def _query(self, sql, *params): + def _query(self, sql: str, *params: Any) -> Union['cursor', 'Cursor[Any]']: return self._postgresql.query(sql, *params, retry=False) @staticmethod - def _copy_items(src, dst, keys=None): + def _copy_items(src: Dict[str, Any], dst: Dict[str, Any], keys: Optional[List[str]] = None) -> None: dst.update({key: src[key] for key in keys or ('datoid', 'catalog_xmin', 'confirmed_flush_lsn')}) - def process_permanent_slots(self, slots): + def process_permanent_slots(self, slots: List[Dict[str, Any]]) -> Dict[str, int]: """This methods solves three problems at once (I know, it is weird). The cluster_info_query from `Postgresql` is executed every HA loop and returns @@ -130,11 +137,11 @@ class SlotsHandler(object): 3. Updates the local cache with the fresh catalog_xmin and confirmed_flush_lsn for every known slot. This info is used when performing the check of logical slot readiness on standbys. """ - ret = {} + ret: Dict[str, int] = {} - slots = {slot['slot_name']: slot for slot in slots or []} - if slots: - for name, value in slots.items(): + slots_dict: Dict[str, Dict[str, Any]] = {slot['slot_name']: slot for slot in slots or []} + if slots_dict: + for name, value in slots_dict.items(): if name in self._replication_slots: if compare_slots(value, self._replication_slots[name], 'datoid'): if value['type'] == 'logical': @@ -143,15 +150,15 @@ class SlotsHandler(object): else: self._schedule_load_slots = True - # It could happen that the slots was deleted in the background, we want to detect this case - if any(name not in slots for name in self._replication_slots.keys()): + # It could happen that the slot was deleted in the background, we want to detect this case + if any(name not in slots_dict for name in self._replication_slots.keys()): self._schedule_load_slots = True return ret - def load_replication_slots(self): + def load_replication_slots(self) -> None: if self._postgresql.major_version >= 90400 and self._schedule_load_slots: - replication_slots = {} + replication_slots: Dict[str, Dict[str, Any]] = {} extra = ", catalog_xmin, pg_catalog.pg_wal_lsn_diff(confirmed_flush_lsn, '0/0')::bigint"\ if self._postgresql.major_version >= 100000 else "" skip_temp_slots = ' WHERE NOT temporary' if self._postgresql.major_version >= 100000 else '' @@ -170,15 +177,16 @@ class SlotsHandler(object): self._unready_logical_slots = {n: None for n, v in replication_slots.items() if v['type'] == 'logical'} self._force_readiness_check = False - def ignore_replication_slot(self, cluster, name): + def ignore_replication_slot(self, cluster: Cluster, name: str) -> bool: slot = self._replication_slots[name] - for matcher in cluster.config.ignore_slots_matchers: - if ((matcher.get("name") is None or matcher["name"] == name) - and all(not matcher.get(a) or matcher[a] == slot.get(a) for a in ('database', 'plugin', 'type'))): - return True + if cluster.config: + for matcher in cluster.config.ignore_slots_matchers: + if ((matcher.get("name") is None or matcher["name"] == name) + and all(not matcher.get(a) or matcher[a] == slot.get(a) for a in ('database', 'plugin', 'type'))): + return True return self._postgresql.citus_handler.ignore_replication_slot(slot) - def drop_replication_slot(self, name): + def drop_replication_slot(self, name: str) -> Tuple[bool, bool]: """Returns a tuple(active, dropped)""" cursor = self._query(('WITH slots AS (SELECT slot_name, active' ' FROM pg_catalog.pg_replication_slots WHERE slot_name = %s),' @@ -186,9 +194,12 @@ class SlotsHandler(object): ' true AS dropped FROM slots WHERE not active) ' 'SELECT active, COALESCE(dropped, false) FROM slots' ' FULL OUTER JOIN dropped ON true'), name) - return cursor.fetchone() if cursor.rowcount == 1 else (False, False) + row = cursor.fetchone() + if not row: + row = (False, False) + return row - def _drop_incorrect_slots(self, cluster, slots, paused): + def _drop_incorrect_slots(self, cluster: Cluster, slots: Dict[str, Any], paused: bool) -> None: # drop old replication slots which are not presented in desired slots for name in set(self._replication_slots) - set(slots): if not paused and not self.ignore_replication_slot(cluster, name): @@ -211,7 +222,7 @@ class SlotsHandler(object): logger.error("Failed to drop replication slot '%s'", name) self._schedule_load_slots = True - def _ensure_physical_slots(self, slots): + def _ensure_physical_slots(self, slots: Dict[str, Any]) -> None: immediately_reserve = ', true' if self._postgresql.major_version >= 90600 else '' for name, value in slots.items(): if name not in self._replication_slots and value['type'] == 'physical': @@ -225,15 +236,15 @@ class SlotsHandler(object): self._schedule_load_slots = True @contextmanager - def get_local_connection_cursor(self, **kwargs): + def get_local_connection_cursor(self, **kwargs: Any) -> Generator[Union['cursor', 'Cursor[Any]'], None, None]: conn_kwargs = self._postgresql.config.local_connect_kwargs conn_kwargs.update(kwargs) with get_connection_cursor(**conn_kwargs) as cur: yield cur - def _ensure_logical_slots_primary(self, slots): + def _ensure_logical_slots_primary(self, slots: Dict[str, Any]) -> None: # Group logical slots to be created by database name - logical_slots = defaultdict(dict) + logical_slots: Dict[str, Dict[str, Dict[str, Any]]] = defaultdict(dict) for name, value in slots.items(): if value['type'] == 'logical': # If the logical already exists, copy some information about it into the original structure @@ -257,26 +268,27 @@ class SlotsHandler(object): slots.pop(name) self._schedule_load_slots = True - def schedule_advance_slots(self, slots): + def schedule_advance_slots(self, slots: Dict[str, Dict[str, int]]) -> Tuple[bool, List[str]]: if not self._advance: self._advance = SlotsAdvanceThread(self) return self._advance.schedule(slots) - def _ensure_logical_slots_replica(self, cluster, slots): - advance_slots = defaultdict(dict) # Group logical slots to be advanced by database name - create_slots = [] # And collect logical slots to be created on the replica + def _ensure_logical_slots_replica(self, cluster: Cluster, slots: Dict[str, Any]) -> List[str]: + # Group logical slots to be advanced by database name + advance_slots: Dict[str, Dict[str, int]] = defaultdict(dict) + create_slots: List[str] = [] # And collect logical slots to be created on the replica for name, value in slots.items(): if value['type'] == 'logical': # If the logical already exists, copy some information about it into the original structure if self._replication_slots.get(name, {}).get('datoid'): self._copy_items(self._replication_slots[name], value) - if name in cluster.slots: + if cluster.slots and name in cluster.slots: try: # Skip slots that doesn't need to be advanced if value['confirmed_flush_lsn'] < int(cluster.slots[name]): advance_slots[value['database']][name] = int(cluster.slots[name]) except Exception as e: logger.error('Failed to parse "%s": %r', cluster.slots[name], e) - elif name in cluster.slots: # We want to copy only slots with feedback in a DCS + elif cluster.slots and name in cluster.slots: # We want to copy only slots with feedback in a DCS create_slots.append(name) error, copy_slots = self.schedule_advance_slots(advance_slots) @@ -284,8 +296,9 @@ class SlotsHandler(object): self._schedule_load_slots = True return create_slots + copy_slots - def sync_replication_slots(self, cluster, nofailover, replicatefrom=None, paused=False): - ret = None + def sync_replication_slots(self, cluster: Cluster, nofailover: bool, + replicatefrom: Optional[str] = None, paused: bool = False) -> List[str]: + ret = [] if self._postgresql.major_version >= 90400 and cluster.config: try: self.load_replication_slots() @@ -312,14 +325,15 @@ class SlotsHandler(object): return ret @contextmanager - def _get_leader_connection_cursor(self, leader): + def _get_leader_connection_cursor(self, leader: Leader) -> Generator[Union['cursor', 'Cursor[Any]'], None, None]: conn_kwargs = leader.conn_kwargs(self._postgresql.config.rewind_credentials) conn_kwargs['dbname'] = self._postgresql.database with get_connection_cursor(connect_timeout=3, options="-c statement_timeout=2000", **conn_kwargs) as cur: yield cur - def check_logical_slots_readiness(self, cluster, nofailover, replicatefrom): - if self._unready_logical_slots: + def check_logical_slots_readiness(self, cluster: Cluster, nofailover: bool, replicatefrom: Optional[str]) -> None: + catalog_xmin = None + if self._unready_logical_slots and cluster.leader: slot_name = cluster.get_my_slot_name_on_primary(self._postgresql.name, replicatefrom) try: with self._get_leader_connection_cursor(cluster.leader) as cur: @@ -335,13 +349,14 @@ class SlotsHandler(object): # Remember catalog_xmin of logical slots on the primary when catalog_xmin of # the physical slot became valid. Logical slots on replica will be safe to use after # promote when catalog_xmin of the physical slot overtakes these values. - if catalog_xmin: + if catalog_xmin is not None: for name, value in slots.items(): self._unready_logical_slots[name] = value else: # Replica isn't streaming or the hot_standby_feedback isn't enabled try: cur = self._query("SELECT pg_catalog.current_setting('hot_standby_feedback')::boolean") - if not cur.fetchone()[0]: + row = cur.fetchone() + if row and not row[0]: logger.error('Logical slot failover requires "hot_standby_feedback".' ' Please check postgresql.auto.conf') except Exception as e: @@ -354,14 +369,18 @@ class SlotsHandler(object): # 1. has a nonzero/non-null catalog_xmin # 2. has a catalog_xmin that is not newer (greater) than the catalog_xmin of any slot on the standby # 3. overtook the catalog_xmin of remembered values of logical slots on the primary. - if not value or self._unready_logical_slots[name] <= catalog_xmin <= value['catalog_xmin']: + if not value or catalog_xmin is not None and\ + self._unready_logical_slots[name] <= catalog_xmin <= value['catalog_xmin']: del self._unready_logical_slots[name] if value: logger.info('Logical slot %s is safe to be used after a failover', name) - def copy_logical_slots(self, cluster, create_slots): + def copy_logical_slots(self, cluster: Cluster, create_slots: List[str]) -> None: leader = cluster.leader + if not leader: + return slots = cluster.get_replication_slots(self._postgresql.name, 'replica', False, self._postgresql.major_version) + copy_slots: Dict[str, Dict[str, Any]] = {} with self._get_leader_connection_cursor(leader) as cur: try: cur.execute("SELECT slot_name, slot_type, datname, plugin, catalog_xmin, " @@ -370,21 +389,20 @@ class SlotsHandler(object): " FROM pg_catalog.pg_get_replication_slots() JOIN pg_catalog.pg_database ON datoid = oid" " WHERE NOT pg_catalog.pg_is_in_recovery() AND slot_name = ANY(%s)", (create_slots,)) - create_slots = {} for r in cur: if r[0] in slots: # slot_name is defined in the global configuration slot = {'type': r[1], 'database': r[2], 'plugin': r[3], 'catalog_xmin': r[4], 'confirmed_flush_lsn': r[5], 'data': r[6]} if compare_slots(slot, slots[r[0]]): - create_slots[r[0]] = slot + copy_slots[r[0]] = slot else: logger.warning('Will not copy the logical slot "%s" due to the configuration mismatch: ' 'configuration=%s, slot on the primary=%s', r[0], slots[r[0]], slot) except Exception as e: logger.error("Failed to copy logical slots from the %s via postgresql connection: %r", leader.name, e) - if isinstance(create_slots, dict) and create_slots and self._postgresql.stop(): - for name, value in create_slots.items(): + if copy_slots and self._postgresql.stop(): + for name, value in copy_slots.items(): slot_dir = os.path.join(self._postgresql.slots_handler.pg_replslot_dir, name) slot_tmp_dir = slot_dir + '.tmp' if os.path.exists(slot_tmp_dir): @@ -403,12 +421,12 @@ class SlotsHandler(object): fsync_dir(self._postgresql.slots_handler.pg_replslot_dir) self._postgresql.start() - def schedule(self, value=None): + def schedule(self, value: Optional[bool] = None) -> None: if value is None: value = self._postgresql.major_version >= 90400 self._schedule_load_slots = self._force_readiness_check = value - def on_promote(self): + def on_promote(self) -> None: if self._advance: self._advance.on_promote() diff --git a/patroni/postgresql/sync.py b/patroni/postgresql/sync.py index 7d252b96..7ddd7722 100644 --- a/patroni/postgresql/sync.py +++ b/patroni/postgresql/sync.py @@ -3,7 +3,7 @@ import re import time from copy import deepcopy -from typing import Any, Collection, Dict, Tuple, TYPE_CHECKING +from typing import Any, Collection, Dict, List, Tuple, TYPE_CHECKING, Union from ..collections import CaseInsensitiveDict, CaseInsensitiveSet from ..dcs import Cluster @@ -109,7 +109,7 @@ def parse_sync_standby_names(value: str) -> Dict[str, Any]: result = {'type': 'priority', 'num': int(tokens[0][1])} synclist = tokens[2:-1] else: - result = {'type': 'priority', 'num': 1} + result: Dict[str, Union[int, str, CaseInsensitiveSet]] = {'type': 'priority', 'num': 1} synclist = tokens result['members'] = CaseInsensitiveSet() for i, (a_type, a_value, a_pos) in enumerate(synclist): @@ -149,12 +149,12 @@ class SyncHandler(object): # "sync" replication connections, that were verified to reach self._primary_flush_lsn at some point self._ready_replicas = CaseInsensitiveDict({}) # keys: member names, values: connection pids - def _handle_synchronous_standby_names_change(self): + def _handle_synchronous_standby_names_change(self) -> None: """If synchronous_standby_names has changed we need to check that newly added replicas have reached self._primary_flush_lsn. Only after that they could be counted as sync.""" synchronous_standby_names = self._postgresql.synchronous_standby_names() if synchronous_standby_names == self._synchronous_standby_names: - return False + return self._synchronous_standby_names = synchronous_standby_names try: @@ -165,7 +165,7 @@ class SyncHandler(object): # Invalidate cache of "sync" connections for app_name in list(self._ready_replicas.keys()): - if app_name not in self._ssn_data['members']: + if isinstance(self._ssn_data['members'], CaseInsensitiveSet) and app_name not in self._ssn_data['members']: del self._ready_replicas[app_name] # Newly connected replicas will be counted as sync only when reached self._primary_flush_lsn @@ -201,7 +201,7 @@ class SyncHandler(object): if r[sort_col] is not None] members = CaseInsensitiveDict({m.name: m for m in cluster.members}) - replica_list = [] + replica_list: List[Tuple[int, str, str, int, bool]] = [] # pg_stat_replication.sync_state has 4 possible states - async, potential, quorum, sync. # That is, alphabetically they are in the reversed order of priority. # Since we are doing reversed sort on (sync_state, lsn) tuples, it helps to keep the result @@ -227,7 +227,8 @@ class SyncHandler(object): for pid, app_name, sync_state, replica_lsn, _ in sorted(replica_list, key=lambda x: x[4]): # if standby name is listed in the /sync key we can count it as synchronous, otherwice # it becomes really synchronous when sync_state = 'sync' and it is known that it managed to catch up - if app_name not in self._ready_replicas and app_name in self._ssn_data['members'] and\ + if app_name not in self._ready_replicas and isinstance(self._ssn_data['members'], CaseInsensitiveSet)\ + and app_name in self._ssn_data['members'] and\ (cluster.sync.matches(app_name) or sync_state == 'sync' and replica_lsn >= self._primary_flush_lsn): self._ready_replicas[app_name] = pid diff --git a/patroni/postgresql/validator.py b/patroni/postgresql/validator.py index 95162d3e..30546444 100644 --- a/patroni/postgresql/validator.py +++ b/patroni/postgresql/validator.py @@ -1,7 +1,7 @@ import abc import logging -from collections import namedtuple +from typing import Any, MutableMapping, Optional, Tuple, Union from ..collections import CaseInsensitiveDict from ..utils import parse_bool, parse_int, parse_real @@ -9,23 +9,65 @@ from ..utils import parse_bool, parse_int, parse_real logger = logging.getLogger(__name__) -class Bool(namedtuple('Bool', 'version_from,version_till')): +class _Transformable(abc.ABC): - @staticmethod - def transform(name, value): + def __init__(self, version_from: int, version_till: Optional[int]) -> None: + self.__version_from = version_from + self.__version_till = version_till + + @property + def version_from(self) -> int: + return self.__version_from + + @property + def version_till(self) -> Optional[int]: + return self.__version_till + + @abc.abstractmethod + def transform(self, name: str, value: Any) -> Optional[Any]: + """Verify that provided value is valid. + + :param name: GUC's name + :param value: GUC's value + :returns: the value (sometimes clamped) or ``None`` if the value isn't valid + """ + + +class Bool(_Transformable): + + def transform(self, name: str, value: Any) -> Optional[Any]: if parse_bool(value) is not None: return value logger.warning('Removing bool parameter=%s from the config due to the invalid value=%s', name, value) -class Number(abc.ABC, namedtuple('Number', 'version_from,version_till,min_val,max_val,unit')): +class Number(_Transformable): + + def __init__(self, version_from: int, version_till: Optional[int], + min_val: Union[int, float], max_val: Union[int, float], unit: Optional[str]) -> None: + super(Number, self).__init__(version_from, version_till) + self.__min_val = min_val + self.__max_val = max_val + self.__unit = unit + + @property + def min_val(self) -> Union[int, float]: + return self.__min_val + + @property + def max_val(self) -> Union[int, float]: + return self.__max_val + + @property + def unit(self) -> Optional[str]: + return self.__unit @staticmethod @abc.abstractmethod - def parse(value, unit): - """parse value""" + def parse(value: Any, unit: Optional[str]) -> Optional[Any]: + """Convert provided value to unit.""" - def transform(self, name, value): + def transform(self, name: str, value: Any) -> Union[int, float, None]: num_value = self.parse(value, self.unit) if num_value is not None: if num_value < self.min_val: @@ -44,20 +86,28 @@ class Number(abc.ABC, namedtuple('Number', 'version_from,version_till,min_val,ma class Integer(Number): @staticmethod - def parse(value, unit): + def parse(value: Any, unit: Optional[str]) -> Optional[int]: return parse_int(value, unit) class Real(Number): @staticmethod - def parse(value, unit): + def parse(value: Any, unit: Optional[str]) -> Optional[float]: return parse_real(value, unit) -class Enum(namedtuple('Enum', 'version_from,version_till,possible_values')): +class Enum(_Transformable): - def transform(self, name, value): + def __init__(self, version_from: int, version_till: Optional[int], possible_values: Tuple[str, ...]) -> None: + super(Enum, self).__init__(version_from, version_till) + self.__possible_values = possible_values + + @property + def possible_values(self) -> Tuple[str, ...]: + return self.__possible_values + + def transform(self, name: str, value: Optional[Any]) -> Optional[Any]: if str(value).lower() in self.possible_values: return value logger.warning('Removing enum parameter=%s from the config due to the invalid value=%s', name, value) @@ -65,16 +115,15 @@ class Enum(namedtuple('Enum', 'version_from,version_till,possible_values')): class EnumBool(Enum): - def transform(self, name, value): + def transform(self, name: str, value: Optional[Any]) -> Optional[Any]: if parse_bool(value) is not None: return value return super(EnumBool, self).transform(name, value) -class String(namedtuple('String', 'version_from,version_till')): +class String(_Transformable): - @staticmethod - def transform(name, value): + def transform(self, name: str, value: Optional[Any]) -> Optional[Any]: return value @@ -511,17 +560,18 @@ recovery_parameters = CaseInsensitiveDict({ }) -def _transform_parameter_value(validators, version, name, value): - validators = validators.get(name) - if validators: - for validator in (validators if isinstance(validators[0], tuple) else [validators]): +def _transform_parameter_value(validators: MutableMapping[str, Union[_Transformable, Tuple[_Transformable, ...]]], + version: int, name: str, value: Any) -> Optional[Any]: + name_validators = validators.get(name) + if name_validators: + for validator in (name_validators if isinstance(name_validators, tuple) else (name_validators,)): if version >= validator.version_from and\ (validator.version_till is None or version < validator.version_till): return validator.transform(name, value) logger.warning('Removing unexpected parameter=%s value=%s from the config', name, value) -def transform_postgresql_parameter_value(version, name, value): +def transform_postgresql_parameter_value(version: int, name: str, value: Any) -> Optional[Any]: if '.' in name: return value if name in recovery_parameters: @@ -529,5 +579,5 @@ def transform_postgresql_parameter_value(version, name, value): return _transform_parameter_value(parameters, version, name, value) -def transform_recovery_parameter_value(version, name, value): +def transform_recovery_parameter_value(version: int, name: str, value: Any) -> Optional[Any]: return _transform_parameter_value(recovery_parameters, version, name, value) diff --git a/patroni/psycopg.py b/patroni/psycopg.py index 6b9b5629..d337dd3e 100644 --- a/patroni/psycopg.py +++ b/patroni/psycopg.py @@ -3,8 +3,10 @@ This module is able to handle both ``pyscopg2`` and ``psycopg3``, and it exposes a common interface for both. ``psycopg2`` takes precedence. ``psycopg3`` will only be used if ``psycopg2`` is either absent or older than ``2.5.4``. """ -from typing import Any, Optional - +from typing import Any, Optional, TYPE_CHECKING, Union +if TYPE_CHECKING: # pragma: no cover + from psycopg import Connection + from psycopg2 import connection, cursor __all__ = ['connect', 'quote_ident', 'quote_literal', 'DatabaseError', 'Error', 'OperationalError', 'ProgrammingError'] @@ -39,9 +41,9 @@ try: value.prepare(conn) return value.getquoted().decode('utf-8') except ImportError: - from psycopg import connect as __connect, sql, Error, DatabaseError, OperationalError, ProgrammingError, Connection + from psycopg import connect as __connect, sql, Error, DatabaseError, OperationalError, ProgrammingError - def _connect(dsn: str = "", **kwargs: Any) -> Any: + def _connect(dsn: Optional[str] = None, **kwargs: Any) -> 'Connection[Any]': """Call ``psycopg.connect`` with ``dsn`` and ``**kwargs``. .. note:: @@ -53,19 +55,19 @@ except ImportError: :returns: a connection to the database. """ - ret = __connect(dsn, **kwargs) + ret = __connect(dsn or "", **kwargs) setattr(ret, 'server_version', ret.pgconn.server_version) # compatibility with psycopg2 return ret - def _quote_ident(value: Any, conn: Connection) -> str: + def _quote_ident(value: Any, scope: Any) -> str: """Quote *value* as a SQL identifier. :param value: value to be quoted. - :param conn: connection to evaluate the returning string into. + :param scope: connection to evaluate the returning string into. :returns: *value* quoted as a SQL identifier. """ - return sql.Identifier(value).as_string(conn) + return sql.Identifier(value).as_string(scope) def quote_literal(value: Any, conn: Optional[Any] = None) -> str: """Quote *value* as a SQL literal. @@ -78,7 +80,7 @@ except ImportError: return sql.Literal(value).as_string(conn) -def connect(*args: Any, **kwargs: Any) -> Any: +def connect(*args: Any, **kwargs: Any) -> Union['connection', 'Connection[Any]']: """Get a connection to the database. .. note:: @@ -102,7 +104,7 @@ def connect(*args: Any, **kwargs: Any) -> Any: return ret -def quote_ident(value: Any, conn: Optional[Any] = None) -> str: +def quote_ident(value: Any, conn: Optional[Union['cursor', 'connection', 'Connection[Any]']] = None) -> str: """Quote *value* as a SQL identifier. :param value: value to be quoted. diff --git a/patroni/raft_controller.py b/patroni/raft_controller.py index 1baa9786..2f9f7858 100644 --- a/patroni/raft_controller.py +++ b/patroni/raft_controller.py @@ -1,5 +1,6 @@ import logging +from .config import Config from .daemon import AbstractPatroniDaemon, abstract_main from .dcs.raft import KVStoreTTL @@ -8,22 +9,22 @@ logger = logging.getLogger(__name__) class RaftController(AbstractPatroniDaemon): - def __init__(self, config): + def __init__(self, config: Config) -> None: super(RaftController, self).__init__(config) - config = self.config.get('raft') - assert 'self_addr' in config - self._raft = KVStoreTTL(None, None, None, **config) + kvstore_config = self.config.get('raft') + assert 'self_addr' in kvstore_config + self._raft = KVStoreTTL(None, None, None, **kvstore_config) - def _run_cycle(self): + def _run_cycle(self) -> None: try: self._raft.doTick(self._raft.conf.autoTickPeriod) except Exception: logger.exception('doTick') - def _shutdown(self): + def _shutdown(self) -> None: self._raft.destroy() -def main(): +def main() -> None: abstract_main(RaftController) diff --git a/patroni/request.py b/patroni/request.py index ae1d9c8a..d063310d 100644 --- a/patroni/request.py +++ b/patroni/request.py @@ -140,14 +140,14 @@ class PatroniRequest(object): :returns: the response returned upon request. """ - url = member.api_url + url = member.api_url or '' if endpoint: scheme, netloc, _, _, _, _ = urlparse(url) url = urlunparse((scheme, netloc, endpoint, '', '', '')) return self.request(method, url, data, **kwargs) -def get(url: str, verify: Optional[bool] = True, **kwargs: Any) -> urllib3.response.HTTPResponse: +def get(url: str, verify: bool = True, **kwargs: Any) -> urllib3.response.HTTPResponse: """Perform an HTTP GET request. .. note:: diff --git a/patroni/scripts/aws.py b/patroni/scripts/aws.py index 7d18e828..9fcc8160 100755 --- a/patroni/scripts/aws.py +++ b/patroni/scripts/aws.py @@ -5,20 +5,22 @@ import logging import sys import boto3 -from ..utils import Retry, RetryFailedError - from botocore.exceptions import ClientError from botocore.utils import IMDSFetcher +from typing import Any, Optional + +from ..utils import Retry, RetryFailedError + logger = logging.getLogger(__name__) class AWSConnection(object): - def __init__(self, cluster_name): + def __init__(self, cluster_name: Optional[str]) -> None: self.available = False self.cluster_name = cluster_name if cluster_name is not None else 'unknown' - self._retry = Retry(deadline=300, max_delay=30, max_tries=-1, retry_exceptions=(ClientError,)) + self._retry = Retry(deadline=300, max_delay=30, max_tries=-1, retry_exceptions=ClientError) try: # get the instance id fetcher = IMDSFetcher(timeout=2.1) @@ -38,13 +40,13 @@ class AWSConnection(object): return self.available = True - def retry(self, *args, **kwargs): + def retry(self, *args: Any, **kwargs: Any) -> Any: return self._retry.copy()(*args, **kwargs) - def aws_available(self): + def aws_available(self) -> bool: return self.available - def _tag_ebs(self, conn, role): + def _tag_ebs(self, conn: Any, role: str) -> None: """ set tags, carrying the cluster name, instance role and instance id for the EBS storage """ tags = [{'Key': 'Name', 'Value': 'spilo_' + self.cluster_name}, {'Key': 'Role', 'Value': role}, @@ -52,16 +54,16 @@ class AWSConnection(object): volumes = conn.volumes.filter(Filters=[{'Name': 'attachment.instance-id', 'Values': [self.instance_id]}]) conn.create_tags(Resources=[v.id for v in volumes], Tags=tags) - def _tag_ec2(self, conn, role): + def _tag_ec2(self, conn: Any, role: str) -> None: """ tag the current EC2 instance with a cluster role """ tags = [{'Key': 'Role', 'Value': role}] conn.create_tags(Resources=[self.instance_id], Tags=tags) - def on_role_change(self, new_role): + def on_role_change(self, new_role: str) -> bool: if not self.available: return False try: - conn = boto3.resource('ec2', region_name=self.region) + conn = boto3.resource('ec2', region_name=self.region) # type: ignore self.retry(self._tag_ec2, conn, new_role) self.retry(self._tag_ebs, conn, new_role) except RetryFailedError: diff --git a/patroni/scripts/wale_restore.py b/patroni/scripts/wale_restore.py index 7e6de976..c2d3545c 100755 --- a/patroni/scripts/wale_restore.py +++ b/patroni/scripts/wale_restore.py @@ -31,7 +31,8 @@ import subprocess import sys import time -from collections import namedtuple +from enum import IntEnum +from typing import Any, List, NamedTuple, Optional, Tuple from .. import psycopg @@ -42,15 +43,14 @@ si_prefixes = ['K', 'M', 'G', 'T', 'P', 'E', 'Z', 'Y'] # Meaningful names to the exit codes used by WALERestore -ExitCode = type('Enum', (), { - 'SUCCESS': 0, #: Succeeded - 'RETRY_LATER': 1, #: External issue, retry later - 'FAIL': 2 #: Don't try again unless configuration changes -}) +class ExitCode(IntEnum): + SUCCESS = 0 #: Succeeded + RETRY_LATER = 1 #: External issue, retry later + FAIL = 2 #: Don't try again unless configuration changes # We need to know the current PG version in order to figure out the correct WAL directory name -def get_major_version(data_dir): +def get_major_version(data_dir: str) -> float: version_file = os.path.join(data_dir, 'PG_VERSION') if os.path.isfile(version_file): # version file exists try: @@ -61,7 +61,7 @@ def get_major_version(data_dir): return 0.0 -def repr_size(n_bytes): +def repr_size(n_bytes: float) -> str: """ >>> repr_size(1000) '1000 Bytes' @@ -77,7 +77,7 @@ def repr_size(n_bytes): return '{0} {1}iB'.format(round(n_bytes, 1), si_prefixes[i]) -def size_as_bytes(size_, prefix): +def size_as_bytes(size: float, prefix: str) -> int: """ >>> size_as_bytes(7.5, 'T') 8246337208320 @@ -88,23 +88,19 @@ def size_as_bytes(size_, prefix): exponent = si_prefixes.index(prefix) + 1 - return int(size_ * (1024.0 ** exponent)) + return int(size * (1024.0 ** exponent)) -WALEConfig = namedtuple( - 'WALEConfig', - [ - 'env_dir', - 'threshold_mb', - 'threshold_pct', - 'cmd', - ] -) +class WALEConfig(NamedTuple): + env_dir: str + threshold_mb: int + threshold_pct: int + cmd: List[str] class WALERestore(object): - def __init__(self, scope, datadir, connstring, env_dir, threshold_mb, - threshold_pct, use_iam, no_leader, retries): + def __init__(self, scope: str, datadir: str, connstring: str, env_dir: str, threshold_mb: int, + threshold_pct: int, use_iam: int, no_leader: bool, retries: int) -> None: self.scope = scope self.leader_connection = connstring self.data_dir = datadir @@ -129,7 +125,7 @@ class WALERestore(object): self.init_error = (not os.path.exists(self.wal_e.env_dir)) self.retries = retries - def run(self): + def run(self) -> int: """ Creates a new replica using WAL-E @@ -158,7 +154,7 @@ class WALERestore(object): logger.exception("Unhandled exception when running WAL-E restore") return ExitCode.FAIL - def should_use_s3_to_create_replica(self): + def should_use_s3_to_create_replica(self) -> Optional[bool]: """ determine whether it makes sense to use S3 and not pg_basebackup """ threshold_megabytes = self.wal_e.threshold_mb @@ -218,7 +214,7 @@ class WALERestore(object): try: # get the difference in bytes between the current WAL location and the backup start offset con = psycopg.connect(self.leader_connection) - if con.server_version >= 100000: + if getattr(con, 'server_version', 0) >= 100000: wal_name = 'wal' lsn_name = 'lsn' else: @@ -232,8 +228,9 @@ class WALERestore(object): " ELSE pg_catalog.pg_{0}_{1}_diff(pg_catalog.pg_current_{0}_{1}(), %s)::bigint" " END").format(wal_name, lsn_name), (backup_start_lsn, backup_start_lsn, backup_start_lsn)) - - diff_in_bytes = int(cur.fetchone()[0]) + for row in cur: + diff_in_bytes = int(row[0]) + break except psycopg.Error: logger.exception('could not determine difference with the leader location') if attempts_no < self.retries: # retry in case of a temporarily connection issue @@ -262,11 +259,11 @@ class WALERestore(object): are_thresholds_ok = is_size_thresh_ok and is_percentage_thresh_ok class Size(object): - def __init__(self, n_bytes, prefix=None): + def __init__(self, n_bytes: float, prefix: Optional[str] = None) -> None: self.n_bytes = n_bytes self.prefix = prefix - def __repr__(self): + def __repr__(self) -> str: if self.prefix is not None: n_bytes = size_as_bytes(self.n_bytes, self.prefix) else: @@ -274,10 +271,10 @@ class WALERestore(object): return repr_size(n_bytes) class HumanContext(object): - def __init__(self, items): + def __init__(self, items: List[Tuple[str, Any]]) -> None: self.items = items - def __repr__(self): + def __repr__(self) -> str: return ', '.join('{}={!r}'.format(key, value) for key, value in self.items) @@ -298,7 +295,7 @@ class WALERestore(object): logger.info('Thresholds are OK, using wal-e basebackup: %s', human_context) return are_thresholds_ok - def fix_subdirectory_path_if_broken(self, dirname): + def fix_subdirectory_path_if_broken(self, dirname: str) -> bool: # in case it is a symlink pointing to a non-existing location, remove it and create the actual directory path = os.path.join(self.data_dir, dirname) if not os.path.exists(path): @@ -316,7 +313,7 @@ class WALERestore(object): return False return True - def create_replica_with_s3(self): + def create_replica_with_s3(self) -> int: # if we're set up, restore the replica using fetch latest try: cmd = self.wal_e.cmd + ['backup-fetch', @@ -334,7 +331,7 @@ class WALERestore(object): return exit_code -def main(): +def main() -> int: logging.basicConfig(format='%(asctime)s %(levelname)s: %(message)s', level=logging.INFO) parser = argparse.ArgumentParser(description='Script to image replicas using WAL-E') parser.add_argument('--scope', required=True) @@ -363,11 +360,12 @@ def main(): threshold_pct=args.threshold_backup_size_percentage, use_iam=args.use_iam, no_leader=args.no_leader, retries=args.retries) exit_code = restore.run() - if not exit_code == ExitCode.RETRY_LATER: # only WAL-E failures lead to the retry + if exit_code != ExitCode.RETRY_LATER: # only WAL-E failures lead to the retry logger.debug('exit_code is %r, not retrying', exit_code) break time.sleep(RETRY_SLEEP_INTERVAL) + assert exit_code is not None return exit_code diff --git a/patroni/utils.py b/patroni/utils.py index f8f9e2fd..f6cfa919 100644 --- a/patroni/utils.py +++ b/patroni/utils.py @@ -7,9 +7,9 @@ :var DEC_RE: regular expression to match decimal numbers, signed or unsigned. :var HEX_RE: regular expression to match hex strings, signed or unsigned. :var DBL_RE: regular expression to match double precision numbers, signed or unsigned. Matches scientific notation too. +:var WHITESPACE_RE: regular expression to match whitespace characters """ import errno -import json.decoder as json_decoder import logging import os import platform @@ -21,9 +21,10 @@ import tempfile import time from shlex import split -from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union, TYPE_CHECKING +from typing import Any, Callable, Dict, Iterator, List, Optional, Union, Tuple, Type, TYPE_CHECKING from dateutil import tz +from json import JSONDecoder from urllib3.response import HTTPResponse from .exceptions import PatroniException @@ -42,9 +43,10 @@ OCT_RE = re.compile(r'^[-+]?0[0-7]*') DEC_RE = re.compile(r'^[-+]?(0|[1-9][0-9]*)') HEX_RE = re.compile(r'^[-+]?0x[0-9a-fA-F]+') DBL_RE = re.compile(r'^[-+]?[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?') +WHITESPACE_RE = re.compile(r'[ \t\n\r]*', re.VERBOSE | re.MULTILINE | re.DOTALL) -def deep_compare(obj1: Dict, obj2: Dict) -> bool: +def deep_compare(obj1: Dict[Any, Union[Any, Dict[Any, Any]]], obj2: Dict[Any, Union[Any, Dict[Any, Any]]]) -> bool: """Recursively compare two dictionaries to check if they are equal in terms of keys and values. .. note:: @@ -84,7 +86,7 @@ def deep_compare(obj1: Dict, obj2: Dict) -> bool: return True -def patch_config(config: Dict, data: Dict) -> bool: +def patch_config(config: Dict[Any, Union[Any, Dict[Any, Any]]], data: Dict[Any, Union[Any, Dict[Any, Any]]]) -> bool: """Update and append to dictionary *config* from overrides in *data*. .. note:: @@ -239,7 +241,7 @@ def strtod(value: Any) -> Tuple[Union[float, None], str]: return None, value -def convert_to_base_unit(value: Union[int, float], unit: str, base_unit: str) -> Union[int, float, None]: +def convert_to_base_unit(value: Union[int, float], unit: str, base_unit: Optional[str]) -> Union[int, float, None]: """Convert *value* as a *unit* of compute information or time to *base_unit*. :param value: value to be converted to the base unit. @@ -268,7 +270,7 @@ def convert_to_base_unit(value: Union[int, float], unit: str, base_unit: str) -> >>> convert_to_base_unit(1, 'GB', '512 MB') is None True """ - convert = { + convert: Dict[str, Dict[str, Union[int, float]]] = { 'B': {'B': 1, 'kB': 1024, 'MB': 1024 * 1024, 'GB': 1024 * 1024 * 1024, 'TB': 1024 * 1024 * 1024 * 1024}, 'kB': {'B': 1.0 / 1024, 'kB': 1, 'MB': 1024, 'GB': 1024 * 1024, 'TB': 1024 * 1024 * 1024}, 'MB': {'B': 1.0 / (1024 * 1024), 'kB': 1.0 / 1024, 'MB': 1, 'GB': 1024, 'TB': 1024 * 1024}, @@ -287,7 +289,7 @@ def convert_to_base_unit(value: Union[int, float], unit: str, base_unit: str) -> else: base_value = 1 - if base_unit in convert and unit in convert[base_unit]: + if base_value is not None and base_unit in convert and unit in convert[base_unit]: value *= convert[base_unit][unit] / float(base_value) if unit in round_order: @@ -297,7 +299,7 @@ def convert_to_base_unit(value: Union[int, float], unit: str, base_unit: str) -> return value -def parse_int(value: Any, base_unit: Optional[str] = None) -> Union[int, None]: +def parse_int(value: Any, base_unit: Optional[str] = None) -> Optional[int]: """Parse *value* as an :class:`int`. :param value: any value that can be handled either by :func:`strtol` or :func:`strtod`. If *value* contains a @@ -350,7 +352,7 @@ def parse_int(value: Any, base_unit: Optional[str] = None) -> Union[int, None]: return round(val) -def parse_real(value: Any, base_unit: Optional[str] = None) -> Union[float, None]: +def parse_real(value: Any, base_unit: Optional[str] = None) -> Optional[float]: """Parse *value* as a :class:`float`. :param value: any value that can be handled by :func:`strtod`. If *value* contains a unit, then *base_unit* must @@ -381,7 +383,7 @@ def parse_real(value: Any, base_unit: Optional[str] = None) -> Union[float, None return convert_to_base_unit(val, unit, base_unit) -def compare_values(vartype: str, unit: str, old_value: Any, new_value: Any) -> bool: +def compare_values(vartype: str, unit: Optional[str], old_value: Any, new_value: Any) -> bool: """Check if *old_value* and *new_value* are equivalent after parsing them as *vartype*. :param vartpe: the target type to parse *old_value* and *new_value* before comparing them. Accepts any among of the @@ -426,7 +428,7 @@ def compare_values(vartype: str, unit: str, old_value: Any, new_value: Any) -> b >>> compare_values('integer', 'kB', 4098, '4097.5kB') True """ - converters = { + converters: Dict[str, Callable[[str, Optional[str]], Union[None, bool, int, float, str]]] = { 'bool': lambda v1, v2: parse_bool(v1), 'integer': parse_int, 'real': parse_real, @@ -434,11 +436,11 @@ def compare_values(vartype: str, unit: str, old_value: Any, new_value: Any) -> b 'string': lambda v1, v2: str(v1) } - convert = converters.get(vartype) or converters['string'] - old_value = convert(old_value, None) - new_value = convert(new_value, unit) + converter = converters.get(vartype) or converters['string'] + old_converted = converter(old_value, None) + new_converted = converter(new_value, unit) - return old_value is not None and new_value is not None and old_value == new_value + return old_converted is not None and new_converted is not None and old_converted == new_converted def _sleep(interval: Union[int, float]) -> None: @@ -467,11 +469,11 @@ class Retry(object): :ivar retry_exceptions: single exception or tuple """ - def __init__(self, max_tries: Optional[int] = 1, delay: Optional[float] = 0.1, backoff: Optional[int] = 2, - max_jitter: Optional[float] = 0.8, max_delay: Optional[int] = 3600, - sleep_func: Optional[Callable[[Union[int, float]], None]] = _sleep, + def __init__(self, max_tries: Optional[int] = 1, delay: float = 0.1, backoff: int = 2, + max_jitter: float = 0.8, max_delay: int = 3600, + sleep_func: Callable[[Union[int, float]], None] = _sleep, deadline: Optional[Union[int, float]] = None, - retry_exceptions: Optional[Union[Exception, Tuple[Exception]]] = PatroniException) -> None: + retry_exceptions: Union[Type[Exception], Tuple[Type[Exception], ...]] = PatroniException) -> None: """Create a :class:`Retry` instance for retrying function calls. :param max_tries: how many times to retry the command. ``-1`` means infinite tries. @@ -504,7 +506,7 @@ class Retry(object): def copy(self) -> 'Retry': """Return a clone of this retry manager.""" return Retry(max_tries=self.max_tries, delay=self.delay, backoff=self.backoff, - max_jitter=self.max_jitter / 100.0, max_delay=self.max_delay, sleep_func=self.sleep_func, + max_jitter=self.max_jitter / 100.0, max_delay=int(self.max_delay), sleep_func=self.sleep_func, deadline=self.deadline, retry_exceptions=self.retry_exceptions) @property @@ -525,11 +527,11 @@ class Retry(object): self._cur_delay = min(self._cur_delay * self.backoff, self.max_delay) @property - def stoptime(self) -> Union[float, None]: + def stoptime(self) -> float: """Get the current stop time.""" - return self._cur_stoptime + return self._cur_stoptime or 0 - def __call__(self, func: Callable, *args: Any, **kwargs: Any) -> Any: + def __call__(self, func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: """Call a function *func* with arguments ``*args`` and ``*kwargs`` in a loop. *func* will be called until one of the following conditions is met: @@ -561,7 +563,9 @@ class Retry(object): logger.warning('Retry got exception: %s', e) raise RetryFailedError("Too many retry attempts") self._attempts += 1 - sleeptime = hasattr(e, 'sleeptime') and e.sleeptime or self.sleeptime + sleeptime = getattr(e, 'sleeptime', None) + if not isinstance(sleeptime, (int, float)): + sleeptime = self.sleeptime if self._cur_stoptime is not None and time.time() + sleeptime >= self._cur_stoptime: logger.warning('Retry got exception: %s', e) @@ -571,7 +575,7 @@ class Retry(object): self.update_delay() -def polling_loop(timeout: Union[int, float], interval: Optional[Union[int, float]] = 1) -> Iterator[int]: +def polling_loop(timeout: Union[int, float], interval: Union[int, float] = 1) -> Iterator[int]: """Return an iterator that returns values every *interval* seconds until *timeout* has passed. .. note:: @@ -587,10 +591,10 @@ def polling_loop(timeout: Union[int, float], interval: Optional[Union[int, float while time.time() < end_time: yield iteration iteration += 1 - time.sleep(interval) + time.sleep(float(interval)) -def split_host_port(value: str, default_port: int) -> Tuple[str, int]: +def split_host_port(value: str, default_port: Optional[int]) -> Tuple[str, int]: """Extract host(s) and port from *value*. :param value: string from where host(s) and port will be extracted. Accepts either of these formats @@ -625,11 +629,11 @@ def split_host_port(value: str, default_port: int) -> Tuple[str, int]: # If *value* contains ``:`` we consider it to be an IPv6 address, so we attempt to remove possible square brackets if ':' in t[0]: t[0] = ','.join([h.strip().strip('[]') for h in t[0].split(',')]) - t.append(default_port) + t.append(str(default_port)) return t[0], int(t[1]) -def uri(proto: str, netloc: Union[List, Tuple[str, int], str], path: Optional[str] = '', +def uri(proto: str, netloc: Union[List[str], Tuple[str, Union[int, str]], str], path: Optional[str] = '', user: Optional[str] = None) -> str: """Construct URI from given arguments. @@ -670,17 +674,15 @@ def iter_response_objects(response: HTTPResponse) -> Iterator[Dict[str, Any]]: :rtype: Iterator[:class:`dict`] with current JSON document. """ prev = '' - decoder = json_decoder.JSONDecoder() + decoder = JSONDecoder() for chunk in response.read_chunked(decode_content=False): - if isinstance(chunk, bytes): - chunk = chunk.decode('utf-8') - chunk = prev + chunk + chunk = prev + chunk.decode('utf-8') length = len(chunk) # ``chunk`` is analyzed in parts. ``idx`` holds the position of the first character in the current part that is # neither space nor tab nor line-break, or in other words, the position in the ``chunk`` where it is likely # that a JSON document begins - idx = json_decoder.WHITESPACE.match(chunk, 0).end() + idx = WHITESPACE_RE.match(chunk, 0).end() # pyright: ignore [reportOptionalMemberAccess] while idx < length: try: # Get a JSON document from the chunk. ``message`` is a dictionary representing the JSON document, and @@ -690,7 +692,7 @@ def iter_response_objects(response: HTTPResponse) -> Iterator[Dict[str, Any]]: break else: yield message - idx = json_decoder.WHITESPACE.match(chunk, idx).end() + idx = WHITESPACE_RE.match(chunk, idx).end() # pyright: ignore [reportOptionalMemberAccess] # It is not usual that a ``chunk`` would contain more than one JSON document, but we handle that just in case prev = chunk[idx:] @@ -732,7 +734,7 @@ def cluster_as_json(cluster: 'Cluster', global_config: Optional['GlobalConfig'] leader_name = cluster.leader.name if cluster.leader else None cluster_lsn = cluster.last_lsn or 0 - ret = {'members': []} + ret: Dict[str, Any] = {'members': []} for m in cluster.members: if m.name == leader_name: role = 'standby_leader' if global_config.is_standby_cluster else 'leader' @@ -762,7 +764,8 @@ def cluster_as_json(cluster: 'Cluster', global_config: Optional['GlobalConfig'] ret['members'].append(member) # sort members by name for consistency - ret['members'].sort(key=lambda m: m['name']) + cmp: Callable[[Dict[str, Any]], bool] = lambda m: m['name'] + ret['members'].sort(key=cmp) if global_config.is_paused: ret['pause'] = True if cluster.failover and cluster.failover.scheduled_at: @@ -791,7 +794,7 @@ def is_subpath(d1: str, d2: str) -> bool: return os.path.commonprefix([real_d1, real_d2 + os.path.sep]) == real_d1 -def validate_directory(d: str, msg: Optional[str] = "{} {}") -> None: +def validate_directory(d: str, msg: str = "{} {}") -> None: """Ensure directory exists and is writable. .. note:: @@ -842,7 +845,7 @@ def data_directory_is_empty(data_dir: str) -> bool: return all(os.name != 'nt' and (n.startswith('.') or n == 'lost+found') for n in os.listdir(data_dir)) -def keepalive_intvl(timeout: int, idle: int, cnt: Optional[int] = 3) -> int: +def keepalive_intvl(timeout: int, idle: int, cnt: int = 3) -> int: """Calculate the value to be used as ``TCP_KEEPINTVL`` based on *timeout*, *idle*, and *cnt*. :param timeout: value for ``TCP_USER_TIMEOUT``. @@ -854,7 +857,7 @@ def keepalive_intvl(timeout: int, idle: int, cnt: Optional[int] = 3) -> int: return max(1, int(float(timeout - idle) / cnt)) -def keepalive_socket_options(timeout: int, idle: int, cnt: Optional[int] = 3) -> Iterator[Tuple[int, int, int]]: +def keepalive_socket_options(timeout: int, idle: int, cnt: int = 3) -> Iterator[Tuple[int, int, int]]: """Get all keepalive related options to be set in a socket. :param timeout: value for ``TCP_USER_TIMEOUT``. @@ -871,25 +874,27 @@ def keepalive_socket_options(timeout: int, idle: int, cnt: Optional[int] = 3) -> """ yield (socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) - if sys.platform.startswith('linux'): - yield (socket.SOL_TCP, 18, int(timeout * 1000)) # TCP_USER_TIMEOUT - TCP_KEEPIDLE = getattr(socket, 'TCP_KEEPIDLE', None) - TCP_KEEPINTVL = getattr(socket, 'TCP_KEEPINTVL', None) - TCP_KEEPCNT = getattr(socket, 'TCP_KEEPCNT', None) - elif sys.platform.startswith('darwin'): - TCP_KEEPIDLE = 0x10 # (named "TCP_KEEPALIVE" in C) - TCP_KEEPINTVL = 0x101 - TCP_KEEPCNT = 0x102 - else: + if not (sys.platform.startswith('linux') or sys.platform.startswith('darwin')): return - intvl = keepalive_intvl(timeout, idle, cnt) - yield (socket.IPPROTO_TCP, TCP_KEEPIDLE, idle) - yield (socket.IPPROTO_TCP, TCP_KEEPINTVL, intvl) - yield (socket.IPPROTO_TCP, TCP_KEEPCNT, cnt) + if sys.platform.startswith('linux'): + yield (socket.SOL_TCP, 18, int(timeout * 1000)) # TCP_USER_TIMEOUT + + # The socket constants from MacOS netinet/tcp.h are not exported by python's + # socket module, therefore we are using 0x10, 0x101, 0x102 constants. + TCP_KEEPIDLE = getattr(socket, 'TCP_KEEPIDLE', 0x10 if sys.platform.startswith('darwin') else None) + if TCP_KEEPIDLE is not None: + yield (socket.IPPROTO_TCP, TCP_KEEPIDLE, idle) + TCP_KEEPINTVL = getattr(socket, 'TCP_KEEPINTVL', 0x101 if sys.platform.startswith('darwin') else None) + if TCP_KEEPINTVL is not None: + intvl = keepalive_intvl(timeout, idle, cnt) + yield (socket.IPPROTO_TCP, TCP_KEEPINTVL, intvl) + TCP_KEEPCNT = getattr(socket, 'TCP_KEEPCNT', 0x102 if sys.platform.startswith('darwin') else None) + if TCP_KEEPCNT is not None: + yield (socket.IPPROTO_TCP, TCP_KEEPCNT, cnt) -def enable_keepalive(sock: socket.socket, timeout: int, idle: int, cnt: Optional[int] = 3) -> Union[int, None]: +def enable_keepalive(sock: socket.socket, timeout: int, idle: int, cnt: int = 3) -> None: """Enable keepalive for *sock*. Will set socket options depending on the platform, as per return of :func:`keepalive_socket_options`. @@ -908,7 +913,7 @@ def enable_keepalive(sock: socket.socket, timeout: int, idle: int, cnt: Optional SIO_KEEPALIVE_VALS = getattr(socket, 'SIO_KEEPALIVE_VALS', None) if SIO_KEEPALIVE_VALS is not None: # Windows intvl = keepalive_intvl(timeout, idle, cnt) - return sock.ioctl(SIO_KEEPALIVE_VALS, (1, idle * 1000, intvl * 1000)) + sock.ioctl(SIO_KEEPALIVE_VALS, (1, idle * 1000, intvl * 1000)) for opt in keepalive_socket_options(timeout, idle, cnt): sock.setsockopt(*opt) diff --git a/patroni/validator.py b/patroni/validator.py index ecd3bc58..569894d1 100644 --- a/patroni/validator.py +++ b/patroni/validator.py @@ -11,9 +11,9 @@ import shutil import socket import subprocess -from typing import Any, Union, Iterator, List, Optional as OptionalType +from typing import Any, Dict, Union, Iterator, List, Optional as OptionalType -from .utils import split_host_port, data_directory_is_empty +from .utils import parse_int, split_host_port, data_directory_is_empty from .dcs import dcs_modules from .exceptions import ConfigParseError @@ -48,8 +48,7 @@ def validate_connect_address(address: str) -> bool: return True -def validate_host_port(host_port: str, listen: OptionalType[bool] = False, - multiple_hosts: OptionalType[bool] = False) -> bool: +def validate_host_port(host_port: str, listen: bool = False, multiple_hosts: bool = False) -> bool: """Check if host(s) and port are valid and available for usage. :param host_port: the host(s) and port to be validated. It can be in either of these formats @@ -68,7 +67,7 @@ def validate_host_port(host_port: str, listen: OptionalType[bool] = False, * If :class:`socket.gaierror` is thrown by socket module when attempting to connect to the given address(es). """ try: - hosts, port = split_host_port(host_port, None) + hosts, port = split_host_port(host_port, 1) except (ValueError, TypeError): raise ConfigParseError("contains a wrong value") else: @@ -197,6 +196,7 @@ def get_major_version(bin_dir: OptionalType[str] = None) -> str: binary = os.path.join(bin_dir, 'postgres') version = subprocess.check_output([binary, '--version']).decode() version = re.match(r'^[^\s]+ [^\s]+ (\d+)(\.(\d+))?', version) + assert version is not None return '.'.join([version.group(1), version.group(3)]) if int(version.group(1)) < 10 else version.group(1) @@ -251,8 +251,8 @@ class Result(object): :ivar error: error message if the validation failed, otherwise ``None``. """ - def __init__(self, status: bool, error: OptionalType[str] = "didn't pass validation", level: OptionalType[int] = 0, - path: OptionalType[str] = "", data: OptionalType[Any] = "") -> None: + def __init__(self, status: bool, error: OptionalType[str] = "didn't pass validation", level: int = 0, + path: str = "", data: Any = "") -> None: """Create a :class:`Result` object based on the given arguments. .. note:: @@ -290,7 +290,7 @@ class Case(object): them, if they are set. """ - def __init__(self, schema: dict) -> None: + def __init__(self, schema: Dict[str, Any]) -> None: """Create a :class:`Case` object. :param schema: the schema for validating a set of attributes that may be available in the configuration. @@ -317,7 +317,7 @@ class Or(object): validation functions and/or expected types for a given configuration option. """ - def __init__(self, *args) -> None: + def __init__(self, *args: Any) -> None: """Create an :class:`Or` object. :param `*args`: any arguments that the caller wants to be stored in this :class:`Or` object. @@ -486,7 +486,7 @@ class Schema(object): :param data: configuration to be validated against ``validator``. :returns: list of errors identified while validating the *data*, if any. """ - errors = [] + errors: List[str] = [] for i in self.validate(data): if not i.status: errors.append(str(i)) @@ -533,14 +533,13 @@ class Schema(object): except Exception as e: yield Result(False, "didn't pass validation: {}".format(e), data=self.data) elif isinstance(self.validator, dict): - if not len(self.validator): + if not isinstance(self.data, dict): yield Result(isinstance(self.data, dict), "is not a dictionary", level=1, data=self.data) elif isinstance(self.validator, list): if not isinstance(self.data, list): yield Result(isinstance(self.data, list), "is not a list", level=1, data=self.data) return - for i in self.iter(): - yield i + yield from self.iter() def iter(self) -> Iterator[Result]: """Iterate over ``validator``, if it is an iterable object, to validate the corresponding entries in ``data``. @@ -553,12 +552,11 @@ class Schema(object): if not isinstance(self.data, dict): yield Result(False, "is not a dictionary.", level=1) else: - for i in self.iter_dict(): - yield i + yield from self.iter_dict() elif isinstance(self.validator, list): if len(self.data) == 0: yield Result(False, "is an empty list", data=self.data) - if len(self.validator) > 0: + if self.validator: for key, value in enumerate(self.data): # Although the value in the configuration (`data`) is expected to contain 1 or more entries, only # the first validator defined in `validator` property list will be used. It is only defined as a @@ -569,11 +567,9 @@ class Schema(object): yield Result(v.status, v.error, path=(str(key) + ("." + v.path if v.path else "")), level=v.level, data=value) elif isinstance(self.validator, Directory): - for v in self.validator.validate(self.data): - yield v + yield from self.validator.validate(self.data) elif isinstance(self.validator, Or): - for i in self.iter_or(): - yield i + yield from self.iter_or() def iter_dict(self) -> Iterator[Result]: """Iterate over a :class:`dict` based ``validator`` to validate the corresponding entries in ``data``. @@ -606,14 +602,14 @@ class Schema(object): :rtype: Iterator[:class:`Result`] objects with the error message related to the failure, if any check fails. """ - results = [] + results: List[Result] = [] for a in self.validator.args: - r = [] + r: List[Result] = [] # Each of the `Or` validators can throw 0 to many `Result` instances. for v in Schema(a).validate(self.data): r.append(v) if any([x.status for x in r]) and not all([x.status for x in r]): - results += filter(lambda x: not x.status, r) + results += [x for x in r if not x.status] else: results += r # None of the `Or` validators succeeded to validate `data`, so we report the issues back. @@ -664,12 +660,12 @@ def _get_type_name(python_type: Any) -> str: Returns: User friendly name of the given Python type. """ - return {str: 'a string', int: 'an integer', float: 'a number', - bool: 'a boolean', list: 'an array', dict: 'a dictionary'}.get( - python_type, getattr(python_type, __name__, "unknown type")) + types: Dict[Any, str] = {str: 'a string', int: 'an integer', float: 'a number', + bool: 'a boolean', list: 'an array', dict: 'a dictionary'} + return types.get(python_type, getattr(python_type, __name__, "unknown type")) -def assert_(condition: bool, message: OptionalType[str] = "Wrong value") -> None: +def assert_(condition: bool, message: str = "Wrong value") -> None: """Assert that a given condition is ``True``. If the assertion fails, then throw a message. @@ -680,14 +676,63 @@ def assert_(condition: bool, message: OptionalType[str] = "Wrong value") -> None assert condition, message +class IntValidator(object): + """Validate an integer setting. + + :cvar expected_type: the expect Python type for an integer setting (:class:`int`). + :ivar min: minimum allowed value for the setting, if any. + :ivar max: maximum allowed value for the setting, if any. + :ivar base_unit: the base unit to convert the value to before checking if it's within `min` and `max` range. + :ivar raise_assert: if an ``assert`` call should be performed regarding expected type and valid range. + """ + expected_type = int + + def __init__(self, min: OptionalType[int] = None, max: OptionalType[int] = None, + base_unit: OptionalType[str] = None, raise_assert: bool = False) -> None: + """Create an :class:`IntValidator` object with the given rules. + + :param min: minimum allowed value for the setting, if any. + :param max: maximum allowed value for the setting, if any. + :param base_unit: the base unit to convert the value to before checking if it's within *min* and *max* range. + :param raise_assert: if an ``assert`` call should be performed regarding expected type and valid range. + """ + self.min = min + self.max = max + self.base_unit = base_unit + self.raise_assert = raise_assert + + def __call__(self, value: Union[int, str]) -> bool: + """Check if *value* is a valid integer and within the expected range. + + .. note:: + If ``raise_assert`` is ``True`` and *value* is not valid, then an ``AssertionError`` will be triggered. + :param value: value to be checked against the rules defined for this :class:`IntValidator` instance. + :returns: ``True`` if *value* is valid and within the expected range. + """ + if self.base_unit: + value = parse_int(value, self.base_unit) or "" + ret = isinstance(value, int)\ + and (self.min is None or value >= self.min)\ + and (self.max is None or value <= self.max) + + if self.raise_assert: + assert_(ret) + return ret + + +def validate_watchdog_mode(value: Any) -> None: + assert_(isinstance(value, (str, bool)), "expected type is not a string") + assert_(value in (False, "off", "automatic", "required")) + + userattributes = {"username": "", Optional("password"): ""} available_dcs = [m.split(".")[-1] for m in dcs_modules()] -validate_host_port_list.expected_type = list -comma_separated_host_port.expected_type = str -validate_connect_address.expected_type = str -validate_host_port_listen.expected_type = str -validate_host_port_listen_multiple_hosts.expected_type = str -validate_data_dir.expected_type = str +setattr(validate_host_port_list, 'expected_type', list) +setattr(comma_separated_host_port, 'expected_type', str) +setattr(validate_connect_address, 'expected_type', str) +setattr(validate_host_port_listen, 'expected_type', str) +setattr(validate_host_port_listen_multiple_hosts, 'expected_type', str) +setattr(validate_data_dir, 'expected_type', str) validate_etcd = { Or("host", "hosts", "srv", "srv_suffix", "url", "proxy"): Case({ "host": validate_host_port, @@ -704,7 +749,7 @@ schema = Schema({ "restapi": { "listen": validate_host_port_listen, "connect_address": validate_connect_address, - Optional("request_queue_size"): lambda i: assert_(0 <= int(i) <= 4096) + Optional("request_queue_size"): IntValidator(min=0, max=4096, raise_assert=True) }, Optional("bootstrap"): { "dcs": { @@ -726,7 +771,7 @@ schema = Schema({ "etcd3": validate_etcd, "exhibitor": { "hosts": [str], - "port": lambda i: assert_(int(i) <= 65535), + "port": IntValidator(max=65535, raise_assert=True), Optional("pool_interval"): int }, "raft": { @@ -767,7 +812,7 @@ schema = Schema({ Optional("bin_dir"): Directory(contains_executable=["pg_ctl", "initdb", "pg_controldata", "pg_basebackup", "postgres", "pg_isready"]), Optional("parameters"): { - Optional("unix_socket_directories"): lambda s: assert_(all([isinstance(s, str), len(s)])) + Optional("unix_socket_directories"): str }, Optional("pg_hba"): [str], Optional("pg_ident"): [str], @@ -775,7 +820,7 @@ schema = Schema({ Optional("use_pg_rewind"): bool }, Optional("watchdog"): { - Optional("mode"): lambda m: assert_(m in ["off", "automatic", "required"]), + Optional("mode"): validate_watchdog_mode, Optional("device"): str }, Optional("tags"): { diff --git a/patroni/watchdog/base.py b/patroni/watchdog/base.py index 4f771a56..bebdabe2 100644 --- a/patroni/watchdog/base.py +++ b/patroni/watchdog/base.py @@ -4,7 +4,10 @@ import platform import sys from threading import RLock -from patroni.exceptions import WatchdogError +from typing import Any, Callable, Dict, Optional, Union + +from ..config import Config +from ..exceptions import WatchdogError __all__ = ['WatchdogError', 'Watchdog'] @@ -15,10 +18,10 @@ MODE_AUTOMATIC = 'automatic' # Will use a watchdog if one is available MODE_OFF = 'off' # Will not try to use a watchdog -def parse_mode(mode): +def parse_mode(mode: Union[bool, str]) -> str: if mode is False: return MODE_OFF - mode = mode.lower() + mode = str(mode).lower() if mode in ['require', 'required']: return MODE_REQUIRED elif mode in ['auto', 'automatic']: @@ -29,8 +32,8 @@ def parse_mode(mode): return MODE_OFF -def synchronized(func): - def wrapped(self, *args, **kwargs): +def synchronized(func: Callable[..., Any]) -> Callable[..., Any]: + def wrapped(self: 'Watchdog', *args: Any, **kwargs: Any) -> Any: with self._lock: return func(self, *args, **kwargs) return wrapped @@ -38,7 +41,7 @@ def synchronized(func): class WatchdogConfig(object): """Helper to contain a snapshot of configuration""" - def __init__(self, config): + def __init__(self, config: Config) -> None: watchdog_config = config.get("watchdog") or {'mode': 'automatic'} self.mode = parse_mode(watchdog_config.get('mode', 'automatic')) @@ -49,15 +52,15 @@ class WatchdogConfig(object): self.driver_config = dict((k, v) for k, v in watchdog_config.items() if k not in ['mode', 'safety_margin', 'driver']) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return isinstance(other, WatchdogConfig) and \ all(getattr(self, attr) == getattr(other, attr) for attr in ['mode', 'ttl', 'loop_wait', 'safety_margin', 'driver', 'driver_config']) - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self == other - def get_impl(self): + def get_impl(self) -> 'WatchdogBase': if self.driver == 'testing': # pragma: no cover from patroni.watchdog.linux import TestingWatchdogDevice return TestingWatchdogDevice.from_config(self.driver_config) @@ -68,14 +71,14 @@ class WatchdogConfig(object): return NullWatchdog() @property - def timeout(self): + def timeout(self) -> int: if self.safety_margin == -1: return int(self.ttl // 2) else: return self.ttl - self.safety_margin @property - def timing_slack(self): + def timing_slack(self) -> int: return self.timeout - self.loop_wait @@ -84,8 +87,9 @@ class Watchdog(object): When activation fails underlying implementation will be switched to a Null implementation. To avoid log spam activation will only be retried when watchdog configuration is changed.""" - def __init__(self, config): - self.active_config = self.config = WatchdogConfig(config) + def __init__(self, config: Config) -> None: + self.config = WatchdogConfig(config) + self.active_config: WatchdogConfig = self.config self._lock = RLock() self.active = False @@ -98,7 +102,7 @@ class Watchdog(object): sys.exit(1) @synchronized - def reload_config(self, config): + def reload_config(self, config: Config) -> None: self.config = WatchdogConfig(config) # Turning a watchdog off can always be done immediately if self.config.mode == MODE_OFF: @@ -115,7 +119,7 @@ class Watchdog(object): self.active_config = self.config @synchronized - def activate(self): + def activate(self) -> bool: """Activates the watchdog device with suitable timeouts. While watchdog is active keepalive needs to be called every time loop_wait expires. @@ -124,7 +128,7 @@ class Watchdog(object): self.active = True return self._activate() - def _activate(self): + def _activate(self) -> bool: self.active_config = self.config if self.config.timing_slack < 0: @@ -138,12 +142,13 @@ class Watchdog(object): except WatchdogError as e: logger.warning("Could not activate %s: %s", self.impl.describe(), e) self.impl = NullWatchdog() + actual_timeout = self.impl.get_timeout() if self.impl.is_running and not self.impl.can_be_disabled: logger.warning("Watchdog implementation can't be disabled." " Watchdog will trigger after Patroni loses leader key.") - if not self.impl.is_running or actual_timeout > self.config.timeout: + if not self.impl.is_running or actual_timeout and actual_timeout > self.config.timeout: if self.config.mode == MODE_REQUIRED: if self.impl.is_null: logger.error("Configuration requires watchdog, but watchdog could not be configured.") @@ -167,7 +172,7 @@ class Watchdog(object): return True - def _set_timeout(self): + def _set_timeout(self) -> Optional[int]: if self.impl.has_set_timeout(): self.impl.set_timeout(self.config.timeout) @@ -184,11 +189,11 @@ class Watchdog(object): return actual_timeout @synchronized - def disable(self): + def disable(self) -> None: self._disable() self.active = False - def _disable(self): + def _disable(self) -> None: try: if self.impl.is_running and not self.impl.can_be_disabled: # Give sysadmin some extra time to clean stuff up. @@ -200,7 +205,7 @@ class Watchdog(object): logger.error("Error while disabling watchdog: %s", e) @synchronized - def keepalive(self): + def keepalive(self) -> None: try: if self.active: self.impl.keepalive() @@ -225,12 +230,12 @@ class Watchdog(object): @property @synchronized - def is_running(self): + def is_running(self) -> bool: return self.impl.is_running @property @synchronized - def is_healthy(self): + def is_healthy(self) -> bool: if self.config.mode != MODE_REQUIRED: return True return self.config.timing_slack >= 0 and self.impl.is_healthy @@ -242,60 +247,59 @@ class WatchdogBase(abc.ABC): is_null = False @property - def is_running(self): + def is_running(self) -> bool: """Returns True when watchdog is activated and capable of performing it's task.""" return False @property - def is_healthy(self): + def is_healthy(self) -> bool: """Returns False when calling open() is known to fail.""" return False @property - def can_be_disabled(self): + def can_be_disabled(self) -> bool: """Returns True when watchdog will be disabled by calling close(). Some watchdog devices will keep running no matter what once activated. May raise WatchdogError if called without calling open() first.""" return True @abc.abstractmethod - def open(self): + def open(self) -> None: """Open watchdog device. When watchdog is opened keepalive must be called. Returns nothing on success or raises WatchdogError if the device could not be opened.""" @abc.abstractmethod - def close(self): + def close(self) -> None: """Gracefully close watchdog device.""" @abc.abstractmethod - def keepalive(self): + def keepalive(self) -> None: """Resets the watchdog timer. Watchdog must be open when keepalive is called.""" @abc.abstractmethod - def get_timeout(self): + def get_timeout(self) -> int: """Returns the current keepalive timeout in effect.""" - @staticmethod - def has_set_timeout(): + def has_set_timeout(self) -> bool: """Returns True if setting a timeout is supported.""" return False - def set_timeout(self, timeout): + def set_timeout(self, timeout: int) -> None: """Set the watchdog timer timeout. :param timeout: watchdog timeout in seconds""" raise WatchdogError("Setting timeout is not supported on {0}".format(self.describe())) - def describe(self): + def describe(self) -> str: """Human readable name for this device""" return self.__class__.__name__ @classmethod - def from_config(cls, config): + def from_config(cls, config: Dict[str, Any]) -> 'WatchdogBase': return cls() @@ -303,15 +307,15 @@ class NullWatchdog(WatchdogBase): """Null implementation when watchdog is not supported.""" is_null = True - def open(self): + def open(self) -> None: return - def close(self): + def close(self) -> None: return - def keepalive(self): + def keepalive(self) -> None: return - def get_timeout(self): + def get_timeout(self) -> int: # A big enough number to not matter return 1000000000 diff --git a/patroni/watchdog/linux.py b/patroni/watchdog/linux.py index d64846b0..3c96a268 100644 --- a/patroni/watchdog/linux.py +++ b/patroni/watchdog/linux.py @@ -1,8 +1,10 @@ -import collections import ctypes import os import platform -from patroni.watchdog.base import WatchdogBase, WatchdogError + +from typing import Any, Dict, NamedTuple + +from .base import WatchdogBase, WatchdogError # Pythonification of linux/ioctl.h IOC_NONE = 0 @@ -19,7 +21,7 @@ machine = platform.machine() if machine in ['mips', 'sparc', 'powerpc', 'ppc64', 'ppc64le']: # pragma: no cover IOC_SIZEBITS = 13 IOC_DIRBITS = 3 - IOC_NONE, IOC_WRITE, IOC_READ = 1, 4, 2 + IOC_NONE, IOC_WRITE = 1, 4 elif machine == 'parisc': # pragma: no cover IOC_WRITE, IOC_READ = 2, 1 @@ -29,19 +31,19 @@ IOC_SIZESHIFT = IOC_TYPESHIFT + IOC_TYPEBITS IOC_DIRSHIFT = IOC_SIZESHIFT + IOC_SIZEBITS -def IOW(type_, nr, size): +def IOW(type_: str, nr: int, size: int) -> int: return IOC(IOC_WRITE, type_, nr, size) -def IOR(type_, nr, size): +def IOR(type_: str, nr: int, size: int) -> int: return IOC(IOC_READ, type_, nr, size) -def IOWR(type_, nr, size): +def IOWR(type_: str, nr: int, size: int) -> int: return IOC(IOC_READ | IOC_WRITE, type_, nr, size) -def IOC(dir_, type_, nr, size): +def IOC(dir_: int, type_: str, nr: int, size: int) -> int: return (dir_ << IOC_DIRSHIFT) \ | (ord(type_) << IOC_TYPESHIFT) \ | (nr << IOC_NRSHIFT) \ @@ -104,9 +106,13 @@ WDIOS = { # Implementation -class WatchdogInfo(collections.namedtuple('WatchdogInfo', 'options,version,identity')): +class WatchdogInfo(NamedTuple): """Watchdog descriptor from the kernel""" - def __getattr__(self, name): + options: int + version: int + identity: str + + def __getattr__(self, name: str) -> bool: """Convenience has_XYZ attributes for checking WDIOF bits in options""" if name.startswith('has_') and name[4:] in WDIOF: return bool(self.options & WDIOF[name[4:]]) @@ -117,32 +123,32 @@ class WatchdogInfo(collections.namedtuple('WatchdogInfo', 'options,version,ident class LinuxWatchdogDevice(WatchdogBase): DEFAULT_DEVICE = '/dev/watchdog' - def __init__(self, device): + def __init__(self, device: str) -> None: self.device = device self._support_cache = None self._fd = None @classmethod - def from_config(cls, config): + def from_config(cls, config: Dict[str, Any]) -> 'LinuxWatchdogDevice': device = config.get('device', cls.DEFAULT_DEVICE) return cls(device) @property - def is_running(self): + def is_running(self) -> bool: return self._fd is not None @property - def is_healthy(self): + def is_healthy(self) -> bool: return os.path.exists(self.device) and os.access(self.device, os.W_OK) - def open(self): + def open(self) -> None: try: self._fd = os.open(self.device, os.O_WRONLY) except OSError as e: raise WatchdogError("Can't open watchdog device: {0}".format(e)) - def close(self): - if self.is_running: + def close(self) -> None: + if self._fd is not None: # self.is_running try: os.write(self._fd, b'V') os.close(self._fd) @@ -151,10 +157,10 @@ class LinuxWatchdogDevice(WatchdogBase): raise WatchdogError("Error while closing {0}: {1}".format(self.describe(), e)) @property - def can_be_disabled(self): + def can_be_disabled(self) -> bool: return self.get_support().has_MAGICCLOSE - def _ioctl(self, func, arg): + def _ioctl(self, func: int, arg: Any) -> None: """Runs the specified ioctl on the underlying fd. Raises WatchdogError if the device is closed. @@ -165,7 +171,7 @@ class LinuxWatchdogDevice(WatchdogBase): import fcntl fcntl.ioctl(self._fd, func, arg, True) - def get_support(self): + def get_support(self) -> WatchdogInfo: if self._support_cache is None: info = watchdog_info() try: @@ -177,7 +183,7 @@ class LinuxWatchdogDevice(WatchdogBase): bytearray(info.identity).decode(errors='ignore').rstrip('\x00')) return self._support_cache - def describe(self): + def describe(self) -> str: dev_str = " at {0}".format(self.device) if self.device != self.DEFAULT_DEVICE else "" ver_str = "" identity = "Linux watchdog device" @@ -190,17 +196,19 @@ class LinuxWatchdogDevice(WatchdogBase): return identity + ver_str + dev_str - def keepalive(self): + def keepalive(self) -> None: + if self._fd is None: + raise WatchdogError("Watchdog device is closed") try: os.write(self._fd, b'1') except OSError as e: raise WatchdogError("Could not send watchdog keepalive: {0}".format(e)) - def has_set_timeout(self): + def has_set_timeout(self) -> bool: """Returns True if setting a timeout is supported.""" return self.get_support().has_SETTIMEOUT - def set_timeout(self, timeout): + def set_timeout(self, timeout: int) -> None: timeout = int(timeout) if not 0 < timeout < 0xFFFF: raise WatchdogError("Invalid timeout {0}. Supported values are between 1 and 65535".format(timeout)) @@ -209,7 +217,7 @@ class LinuxWatchdogDevice(WatchdogBase): except (WatchdogError, OSError, IOError) as e: raise WatchdogError("Could not set timeout on watchdog device: {}".format(e)) - def get_timeout(self): + def get_timeout(self) -> int: timeout = ctypes.c_int() try: self._ioctl(WDIOC_GETTIMEOUT, timeout) @@ -222,14 +230,16 @@ class TestingWatchdogDevice(LinuxWatchdogDevice): # pragma: no cover """Converts timeout ioctls to regular writes that can be intercepted from a named pipe.""" timeout = 60 - def get_support(self): + def get_support(self) -> WatchdogInfo: return WatchdogInfo(WDIOF['MAGICCLOSE'] | WDIOF['SETTIMEOUT'], 0, "Watchdog test harness") - def set_timeout(self, timeout): + def set_timeout(self, timeout: int) -> None: + if self._fd is None: + raise WatchdogError("Watchdog device is closed") buf = "Ctimeout={0}\n".format(timeout).encode('utf8') while len(buf): buf = buf[os.write(self._fd, buf):] self.timeout = timeout - def get_timeout(self): + def get_timeout(self) -> int: return self.timeout diff --git a/pyrightconfig.json b/pyrightconfig.json new file mode 100644 index 00000000..20980afb --- /dev/null +++ b/pyrightconfig.json @@ -0,0 +1,27 @@ +{ + "include": [ + "patroni" + ], + + "exclude": [ + "**/__pycache__" + ], + + "ignore": [ + ], + + "defineConstant": { + "DEBUG": true + }, + + "stubPath": "typings/", + + "reportMissingImports": true, + "reportMissingTypeStubs": false, + + "pythonVersion": "3.6", + "pythonPlatform": "All", + + "typeCheckingMode": "strict" + +} diff --git a/tests/__init__.py b/tests/__init__.py index 2778a846..838c2636 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -23,6 +23,7 @@ class MockResponse(object): def __init__(self, status_code=200): self.status_code = status_code + self.headers = {'content-type': 'json'} self.content = '{}' self.reason = 'Not Found' @@ -38,10 +39,6 @@ class MockResponse(object): def getheader(*args): return '' - @staticmethod - def getheaders(): - return {'content-type': 'json'} - def requests_get(url, method='GET', endpoint=None, data='', **kwargs): members = '[{"id":14855829450254237642,"peerURLs":["http://localhost:2380","http://localhost:7001"],' +\ @@ -85,6 +82,8 @@ class MockCursor(object): self.description = [Mock()] def execute(self, sql, *params): + if isinstance(sql, bytes): + sql = sql.decode('utf-8') if sql.startswith('blabla'): raise psycopg.ProgrammingError() elif sql == 'CHECKPOINT' or sql.startswith('SELECT pg_catalog.pg_create_'): @@ -98,7 +97,7 @@ class MockCursor(object): elif sql.startswith('SELECT slot_name'): self.results = [('blabla', 'physical'), ('foobar', 'physical'), ('ls', 'logical', 'a', 'b', 5, 100, 500)] elif sql.startswith('WITH slots AS (SELECT slot_name, active'): - self.results = [(False, True)] + self.results = [(False, True)] if self.rowcount == 1 else [None] elif sql.startswith('SELECT CASE WHEN pg_catalog.pg_is_in_recovery()'): self.results = [(1, 2, 1, 0, False, 1, 1, None, None, [{"slot_name": "ls", "confirmed_flush_lsn": 12345}], diff --git a/tests/test_citus.py b/tests/test_citus.py index cd1d99f0..a2d096ba 100644 --- a/tests/test_citus.py +++ b/tests/test_citus.py @@ -136,8 +136,8 @@ class TestCitus(BaseTestPostgresql): self.assertEqual(parameters['shared_preload_libraries'], 'citus,foo,bar') self.assertEqual(parameters['wal_level'], 'logical') - @patch.object(CitusHandler, 'is_enabled', Mock(return_value=False)) def test_bootstrap(self): + self.c._config = None self.c.bootstrap() def test_ignore_replication_slot(self): diff --git a/tests/test_config.py b/tests/test_config.py index 3cd94a9e..8223a62f 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -19,8 +19,9 @@ class TestConfig(unittest.TestCase): def test_set_dynamic_configuration(self): with patch.object(Config, '_build_effective_configuration', Mock(side_effect=Exception)): - self.assertIsNone(self.config.set_dynamic_configuration({'foo': 'bar'})) - self.assertTrue(self.config.set_dynamic_configuration({'standby_cluster': {}})) + self.assertFalse(self.config.set_dynamic_configuration({'foo': 'bar'})) + self.assertTrue(self.config.set_dynamic_configuration({'standby_cluster': {}, 'postgresql': { + 'parameters': {'cluster_name': 1, 'wal_keep_size': 1, 'track_commit_timestamp': 1, 'wal_level': 1}}})) def test_reload_local_configuration(self): os.environ.update({ diff --git a/tests/test_consul.py b/tests/test_consul.py index 7bbdbc02..f2bb6bc5 100644 --- a/tests/test_consul.py +++ b/tests/test_consul.py @@ -195,6 +195,8 @@ class TestConsul(unittest.TestCase): @patch.object(consul.Consul.KV, 'delete', Mock(return_value=True)) def test_delete_leader(self): self.c.delete_leader() + self.c._name = 'other' + self.c.delete_leader() @patch.object(consul.Consul.KV, 'put', Mock(return_value=True)) def test_initialize(self): diff --git a/tests/test_ctl.py b/tests/test_ctl.py index eba4d1c1..f7d3eec5 100644 --- a/tests/test_ctl.py +++ b/tests/test_ctl.py @@ -9,7 +9,7 @@ from mock import patch, Mock, PropertyMock from patroni.ctl import ctl, load_config, output_members, get_dcs, parse_dcs, \ get_all_members, get_any_member, get_cursor, query_member, PatroniCtlException, apply_config_changes, \ format_config_for_editing, show_diff, invoke_editor, format_pg_version, CONFIG_FILE_PATH, PatronictlPrettyTable -from patroni.dcs.etcd import AbstractEtcdClientWithFailover, Failover +from patroni.dcs.etcd import AbstractEtcdClientWithFailover, Cluster, Failover from patroni.psycopg import OperationalError from patroni.utils import tzutc from prettytable import PrettyTable, ALL @@ -639,6 +639,10 @@ class TestCtl(unittest.TestCase): self.runner.invoke(ctl, ['edit-config', 'dummy', '--force', '--apply', '-'], input='foo: bar') mock_get_dcs.return_value.set_config_value.return_value = True self.runner.invoke(ctl, ['edit-config', 'dummy', '--force', '--apply', '-'], input='foo: bar') + mock_get_dcs.return_value.get_cluster = Mock(return_value=Cluster.empty()) + result = self.runner.invoke(ctl, ['edit-config', 'dummy']) + assert result.exit_code == 1 + assert 'The config key does not exist in the cluster dummy' in result.output @patch('patroni.ctl.get_dcs') def test_version(self, mock_get_dcs): diff --git a/tests/test_etcd.py b/tests/test_etcd.py index 8f780342..914ddeb8 100644 --- a/tests/test_etcd.py +++ b/tests/test_etcd.py @@ -138,7 +138,9 @@ class TestClient(unittest.TestCase): @patch.object(EtcdClient, '_get_machines_list', Mock(return_value=['http://localhost:2379', 'http://localhost:4001'])) def setUp(self): - self.client = EtcdClient({'srv': 'test', 'retry_timeout': 3}, DnsCachingResolver()) + self.etcd = Etcd({'namespace': '/patroni/', 'ttl': 30, 'retry_timeout': 3, + 'srv': 'test', 'scope': 'test', 'name': 'foo'}) + self.client = self.etcd._client self.client.http.request = http_request self.client.http.request_encode_body = http_request diff --git a/tests/test_etcd3.py b/tests/test_etcd3.py index 58433bbb..c0acb67f 100644 --- a/tests/test_etcd3.py +++ b/tests/test_etcd3.py @@ -3,7 +3,7 @@ import json import unittest import urllib3 -from mock import Mock, patch +from mock import Mock, PropertyMock, patch from patroni.dcs.etcd import DnsCachingResolver from patroni.dcs.etcd3 import PatroniEtcd3Client, Cluster, Etcd3Client, Etcd3Error, Etcd3ClientError, RetryFailedError,\ InvalidAuthToken, Unavailable, Unknown, UnsupportedEtcdVersion, UserEmpty, AuthFailed, base64_encode, Etcd3 @@ -85,6 +85,11 @@ class BaseTestEtcd3(unittest.TestCase): class TestKVCache(BaseTestEtcd3): + @patch.object(urllib3.PoolManager, 'urlopen', mock_urlopen) + @patch.object(Etcd3Client, 'watchprefix', Mock(return_value=urllib3.response.HTTPResponse())) + def test__build_cache(self): + self.kv_cache._build_cache() + def test__do_watch(self): self.client.watchprefix = Mock(return_value=False) self.assertRaises(AttributeError, self.kv_cache._do_watch, '1') @@ -94,14 +99,17 @@ class TestKVCache(BaseTestEtcd3): def test_run(self): self.assertRaises(SleepException, self.kv_cache.run) - @patch.object(urllib3.PoolManager, 'urlopen', mock_urlopen) + @patch.object(urllib3.response.HTTPResponse, 'read_chunked', + Mock(return_value=[b'{"error":{"grpc_code":14,"message":"","http_code":503}}'])) + @patch.object(Etcd3Client, 'watchprefix', Mock(return_value=urllib3.response.HTTPResponse())) def test_kill_stream(self): self.assertRaises(Unavailable, self.kv_cache._do_watch, '1') - self.kv_cache.kill_stream() - with patch.object(MockResponse, 'connection', create=True) as mock_conn: + with patch.object(urllib3.response.HTTPResponse, 'connection') as mock_conn: self.kv_cache.kill_stream() mock_conn.sock.close.side_effect = Exception self.kv_cache.kill_stream() + type(mock_conn).sock = PropertyMock(side_effect=Exception) + self.kv_cache.kill_stream() class TestPatroniEtcd3Client(BaseTestEtcd3): @@ -180,7 +188,8 @@ class TestEtcd3(BaseTestEtcd3): @patch.object(urllib3.PoolManager, 'urlopen', mock_urlopen) def setUp(self): super(TestEtcd3, self).setUp() - self.assertRaises(AttributeError, self.kv_cache._build_cache) +# self.assertRaises(AttributeError, self.kv_cache._build_cache) + self.kv_cache._build_cache() self.kv_cache._is_ready = True self.etcd3.get_cluster() @@ -276,6 +285,8 @@ class TestEtcd3(BaseTestEtcd3): def test_delete_leader(self): self.etcd3.delete_leader() + self.etcd3._name = 'other' + self.etcd3.delete_leader() def test_delete_cluster(self): self.etcd3.delete_cluster() diff --git a/tests/test_ha.py b/tests/test_ha.py index 067fd891..98e9966e 100644 --- a/tests/test_ha.py +++ b/tests/test_ha.py @@ -244,7 +244,6 @@ class TestHa(PostgresInit): def test_bootstrap_as_standby_leader(self, initialize): self.p.data_directory_empty = true self.ha.cluster = get_cluster_not_initialized_without_leader(cluster_config=ClusterConfig(0, {}, 0)) - self.ha.cluster.is_unlocked = true self.ha.patroni.config._dynamic_configuration = {"standby_cluster": {"port": 5432}} self.assertEqual(self.ha.run_cycle(), 'trying to bootstrap a new standby leader') @@ -261,7 +260,6 @@ class TestHa(PostgresInit): def test_start_as_cascade_replica_in_standby_cluster(self): self.p.data_directory_empty = true self.ha.cluster = get_standby_cluster_initialized_with_only_leader() - self.ha.cluster.is_unlocked = false self.assertEqual(self.ha.run_cycle(), "trying to bootstrap from replica 'test'") def test_recover_replica_failed(self): @@ -381,8 +379,8 @@ class TestHa(PostgresInit): self.ha._async_executor.schedule('promote') self.assertEqual(self.ha.run_cycle(), 'lost leader before promote') + @patch.object(Cluster, 'is_unlocked', Mock(return_value=False)) def test_long_promote(self): - self.ha.cluster.is_unlocked = false self.ha.has_lock = true self.p.is_leader = false self.p.set_role('primary') @@ -407,14 +405,13 @@ class TestHa(PostgresInit): self.p.is_leader = false self.assertEqual(self.ha.run_cycle(), 'following a different leader because i am not the healthiest node') + @patch.object(Cluster, 'is_unlocked', Mock(return_value=False)) def test_promote_because_have_lock(self): - self.ha.cluster.is_unlocked = false self.ha.has_lock = true self.p.is_leader = false self.assertEqual(self.ha.run_cycle(), 'promoted self to leader because I had the session lock') def test_promote_without_watchdog(self): - self.ha.cluster.is_unlocked = false self.ha.has_lock = true self.p.is_leader = true with patch.object(Watchdog, 'activate', Mock(return_value=False)): @@ -424,24 +421,22 @@ class TestHa(PostgresInit): def test_leader_with_lock(self): self.ha.cluster = get_cluster_initialized_with_leader() - self.ha.cluster.is_unlocked = false self.ha.has_lock = true self.assertEqual(self.ha.run_cycle(), 'no action. I am (postgresql0), the leader with the lock') def test_coordinator_leader_with_lock(self): self.ha.cluster = get_cluster_initialized_with_leader() - self.ha.cluster.is_unlocked = false self.ha.has_lock = true self.assertEqual(self.ha.run_cycle(), 'no action. I am (postgresql0), the leader with the lock') @patch.object(Postgresql, '_wait_for_connection_close', Mock()) + @patch.object(Cluster, 'is_unlocked', Mock(return_value=False)) def test_demote_because_not_having_lock(self): - self.ha.cluster.is_unlocked = false with patch.object(Watchdog, 'is_running', PropertyMock(return_value=True)): self.assertEqual(self.ha.run_cycle(), 'demoting self because I do not have the lock and I was a leader') + @patch.object(Cluster, 'is_unlocked', Mock(return_value=False)) def test_demote_because_update_lock_failed(self): - self.ha.cluster.is_unlocked = false self.ha.has_lock = true self.ha.update_lock = false self.assertEqual(self.ha.run_cycle(), 'demoted self because failed to update leader lock in DCS') @@ -450,8 +445,8 @@ class TestHa(PostgresInit): self.p.is_leader = false self.assertEqual(self.ha.run_cycle(), 'not promoting because failed to update leader lock in DCS') + @patch.object(Cluster, 'is_unlocked', Mock(return_value=False)) def test_follow(self): - self.ha.cluster.is_unlocked = false self.p.is_leader = false self.assertEqual(self.ha.run_cycle(), 'no action. I am (postgresql0), a secondary, and following a leader ()') self.ha.patroni.replicatefrom = "foo" @@ -465,8 +460,8 @@ class TestHa(PostgresInit): self.assertEqual(self.ha.run_cycle(), 'no action. I am (postgresql0), a secondary, and following a leader ()') del self.ha.cluster.config.data['postgresql']['use_slots'] + @patch.object(Cluster, 'is_unlocked', Mock(return_value=False)) def test_follow_in_pause(self): - self.ha.cluster.is_unlocked = false self.ha.is_paused = true self.assertEqual(self.ha.run_cycle(), 'PAUSE: continue to run as primary without lock') self.p.is_leader = false @@ -659,10 +654,12 @@ class TestHa(PostgresInit): self.ha.update_lock = false self.p.set_role('primary') - with patch('patroni.async_executor.CriticalTask.cancel', Mock(return_value=False)): - with patch('patroni.postgresql.Postgresql.terminate_starting_postmaster') as mock_terminate: - self.assertEqual(self.ha.run_cycle(), 'lost leader lock during restart') - mock_terminate.assert_called() + with patch('patroni.async_executor.CriticalTask.cancel', Mock(return_value=False)),\ + patch('patroni.async_executor.CriticalTask.result', + PropertyMock(return_value=PostmasterProcess(os.getpid())), create=True),\ + patch('patroni.postgresql.Postgresql.terminate_starting_postmaster') as mock_terminate: + self.assertEqual(self.ha.run_cycle(), 'lost leader lock during restart') + mock_terminate.assert_called() self.ha.is_paused = true self.assertEqual(self.ha.run_cycle(), 'PAUSE: restart in progress') @@ -741,7 +738,6 @@ class TestHa(PostgresInit): def test_manual_failover_process_no_leader(self): self.p.is_leader = false self.ha.cluster = get_cluster_initialized_without_leader(failover=Failover(0, '', self.p.name, None)) - self.assertEqual(self.ha.run_cycle(), 'promoted self to leader by acquiring session lock') self.ha.cluster = get_cluster_initialized_without_leader(failover=Failover(0, '', 'leader', None)) self.p.set_role('replica') self.assertEqual(self.ha.run_cycle(), 'promoted self to leader by acquiring session lock') @@ -847,10 +843,10 @@ class TestHa(PostgresInit): self.assertFalse(self.ha.is_healthiest_node()) def test__is_healthiest_node(self): + self.p.is_leader = false self.ha.cluster = get_cluster_initialized_without_leader(sync=('postgresql1', self.p.name)) self.ha.global_config = self.ha.patroni.config.get_global_config(self.ha.cluster) self.assertTrue(self.ha._is_healthiest_node(self.ha.old_cluster.members)) - self.p.is_leader = false self.ha.fetch_node_status = get_node_status() # accessible, in_recovery self.assertTrue(self.ha._is_healthiest_node(self.ha.old_cluster.members)) self.ha.fetch_node_status = get_node_status(in_recovery=False) # accessible, not in_recovery @@ -864,6 +860,8 @@ class TestHa(PostgresInit): self.ha.global_config = self.ha.patroni.config.get_global_config(self.ha.cluster) with patch('patroni.postgresql.Postgresql.last_operation', return_value=1): self.assertFalse(self.ha._is_healthiest_node(self.ha.old_cluster.members)) + with patch('patroni.postgresql.Postgresql.replica_cached_timeline', return_value=None): + self.assertFalse(self.ha._is_healthiest_node(self.ha.old_cluster.members)) with patch('patroni.postgresql.Postgresql.replica_cached_timeline', return_value=1): self.assertFalse(self.ha._is_healthiest_node(self.ha.old_cluster.members)) self.ha.patroni.nofailover = True @@ -972,11 +970,11 @@ class TestHa(PostgresInit): with patch.object(Leader, 'conn_url', PropertyMock(return_value='')): self.assertEqual(self.ha.run_cycle(), 'continue following the old known standby leader') + @patch.object(Cluster, 'is_unlocked', Mock(return_value=True)) def test_process_unhealthy_standby_cluster_as_standby_leader(self): self.p.is_leader = false self.p.name = 'leader' self.ha.cluster = get_standby_cluster_initialized_with_only_leader() - self.ha.cluster.is_unlocked = true self.ha.sysid_valid = true self.p._sysid = True self.assertEqual(self.ha.run_cycle(), 'promoted self to a standby leader by acquiring session lock') @@ -987,7 +985,6 @@ class TestHa(PostgresInit): self.p.is_leader = false self.p.name = 'replica' self.ha.cluster = get_standby_cluster_initialized_with_only_leader() - self.ha.is_unlocked = true self.assertTrue(self.ha.run_cycle().startswith('running pg_rewind from remote_member:')) def test_recover_unhealthy_leader_in_standby_cluster(self): @@ -998,13 +995,13 @@ class TestHa(PostgresInit): self.ha.cluster = get_standby_cluster_initialized_with_only_leader() self.assertEqual(self.ha.run_cycle(), 'starting as a standby leader because i had the session lock') + @patch.object(Cluster, 'is_unlocked', Mock(return_value=True)) def test_recover_unhealthy_unlocked_standby_cluster(self): self.p.is_leader = false self.p.name = 'leader' self.p.is_running = false self.p.follow = false self.ha.cluster = get_standby_cluster_initialized_with_only_leader() - self.ha.cluster.is_unlocked = true self.ha.has_lock = false self.assertEqual(self.ha.run_cycle(), 'trying to follow a remote member because standby cluster is unhealthy') @@ -1287,10 +1284,10 @@ class TestHa(PostgresInit): @patch('patroni.postgresql.mtime', Mock(return_value=1588316884)) @patch('builtins.open', Mock(side_effect=Exception)) + @patch.object(Cluster, 'is_unlocked', Mock(return_value=False)) def test_restore_cluster_config(self): self.ha.cluster.config.data.clear() self.ha.has_lock = true - self.ha.cluster.is_unlocked = false self.assertEqual(self.ha.run_cycle(), 'no action. I am (postgresql0), the leader with the lock') def test_watch(self): @@ -1341,9 +1338,9 @@ class TestHa(PostgresInit): @patch('patroni.postgresql.mtime', Mock(return_value=1588316884)) @patch('builtins.open', mock_open(read_data=('1\t0/40159C0\tno recovery target specified\n\n' '2\t1/40159C0\tno recovery target specified\n'))) + @patch.object(Cluster, 'is_unlocked', Mock(return_value=False)) def test_update_cluster_history(self): self.ha.has_lock = true - self.ha.cluster.is_unlocked = false for tl in (1, 3): self.p.get_primary_timeline = Mock(return_value=tl) self.assertEqual(self.ha.run_cycle(), 'no action. I am (postgresql0), the leader with the lock') @@ -1355,9 +1352,9 @@ class TestHa(PostgresInit): self.ha.run_cycle() exit_mock.assert_called_once_with(1) + @patch.object(Cluster, 'is_unlocked', Mock(return_value=False)) def test_after_pause(self): self.ha.has_lock = true - self.ha.cluster.is_unlocked = false self.ha.is_paused = true self.assertEqual(self.ha.run_cycle(), 'PAUSE: no action. I am (postgresql0), the leader with the lock') self.ha.is_paused = false @@ -1409,16 +1406,10 @@ class TestHa(PostgresInit): @patch.object(Postgresql, 'major_version', PropertyMock(return_value=130000)) @patch.object(SlotsHandler, 'sync_replication_slots', Mock(return_value=['ls'])) def test_follow_copy(self): - self.ha.cluster.is_unlocked = false self.ha.cluster.config.data['slots'] = {'ls': {'database': 'a', 'plugin': 'b'}} self.p.is_leader = false self.assertTrue(self.ha.run_cycle().startswith('Copying logical slots')) - def test_is_failover_possible(self): - self.ha.fetch_node_status = Mock(return_value=_MemberStatus(self.ha.cluster.members[0], - True, True, 0, 2, None, {}, False)) - self.assertFalse(self.ha.is_failover_possible(self.ha.cluster.members)) - def test_acquire_lock(self): self.ha.dcs.attempt_to_acquire_leader = Mock(side_effect=[DCSError('foo'), Exception]) self.assertRaises(DCSError, self.ha.acquire_lock) diff --git a/tests/test_kubernetes.py b/tests/test_kubernetes.py index 86669a08..fd2185e0 100644 --- a/tests/test_kubernetes.py +++ b/tests/test_kubernetes.py @@ -5,6 +5,7 @@ import mock import socket import time import unittest +import urllib3 from mock import Mock, PropertyMock, mock_open, patch from patroni.dcs.kubernetes import Cluster, k8s_client, k8s_config, K8sConfig, K8sConnectionFailed,\ @@ -34,6 +35,9 @@ def mock_list_namespaced_config_map(*args, **kwargs): metadata.update({'name': 'test-1-leader', 'labels': {Kubernetes._CITUS_LABEL: '1'}, 'annotations': {'leader': 'p-3', 'ttl': '30s'}}) items.append(k8s_client.V1ConfigMap(metadata=k8s_client.V1ObjectMeta(**metadata))) + metadata.update({'name': 'test-2-config', 'labels': {Kubernetes._CITUS_LABEL: '2'}, 'annotations': {}}) + items.append(k8s_client.V1ConfigMap(metadata=k8s_client.V1ObjectMeta(**metadata))) + metadata = k8s_client.V1ObjectMeta(resource_version='1') return k8s_client.V1ConfigMapList(metadata=metadata, items=items, kind='ConfigMapList') @@ -403,13 +407,18 @@ class TestKubernetesEndpoints(BaseTestKubernetes): self.assertEqual(('create_config_service failed',), mock_logger_exception.call_args[0]) +def mock_watch(*args): + return urllib3.HTTPResponse() + + class TestCacheBuilder(BaseTestKubernetes): @patch.object(k8s_client.CoreV1Api, 'list_namespaced_config_map', mock_list_namespaced_config_map, create=True) - @patch('patroni.dcs.kubernetes.ObjectCache._watch') - def test__build_cache(self, mock_response): + @patch('patroni.dcs.kubernetes.ObjectCache._watch', mock_watch) + @patch.object(urllib3.HTTPResponse, 'read_chunked') + def test__build_cache(self, mock_read_chunked): self.k._citus_group = '0' - mock_response.return_value.read_chunked.return_value = [json.dumps( + mock_read_chunked.return_value = [json.dumps( {'type': 'MODIFIED', 'object': {'metadata': { 'name': self.k.config_path, 'resourceVersion': '2', 'annotations': {self.k._CONFIG: 'foo'}}}} ).encode('utf-8'), ('\n' + json.dumps( @@ -435,12 +444,13 @@ class TestCacheBuilder(BaseTestKubernetes): self.assertRaises(AttributeError, self.k._kinds._do_watch, '1') @patch.object(k8s_client.CoreV1Api, 'list_namespaced_config_map', mock_list_namespaced_config_map, create=True) - @patch('patroni.dcs.kubernetes.ObjectCache._watch') - def test_kill_stream(self, mock_watch): - self.k._kinds.kill_stream() - mock_watch.return_value.read_chunked.return_value = [] - mock_watch.return_value.connection.sock.close.side_effect = Exception - self.k._kinds._do_watch('1') - self.k._kinds.kill_stream() - type(mock_watch.return_value).connection = PropertyMock(side_effect=Exception) + @patch('patroni.dcs.kubernetes.ObjectCache._watch', mock_watch) + @patch.object(urllib3.HTTPResponse, 'read_chunked', Mock(return_value=[])) + def test_kill_stream(self): self.k._kinds.kill_stream() + with patch.object(urllib3.HTTPResponse, 'connection') as mock_connection: + mock_connection.sock.close.side_effect = Exception + self.k._kinds._do_watch('1') + self.k._kinds.kill_stream() + with patch.object(urllib3.HTTPResponse, 'connection', PropertyMock(side_effect=Exception)): + self.k._kinds.kill_stream() diff --git a/tests/test_postgresql.py b/tests/test_postgresql.py index 49e58c7a..1a44b671 100644 --- a/tests/test_postgresql.py +++ b/tests/test_postgresql.py @@ -369,6 +369,10 @@ class TestPostgresql(BaseTestPostgresql): def test_latest_checkpoint_location(self, mock_popen): mock_popen.return_value.communicate.return_value = (None, None) self.assertEqual(self.p.latest_checkpoint_location(), 28163096) + with patch.object(Postgresql, 'controldata', Mock(return_value={'Database cluster state': 'shut down', + 'Latest checkpoint location': 'k/1ADBC18', + "Latest checkpoint's TimeLineID": '1'})): + self.assertIsNone(self.p.latest_checkpoint_location()) # 9.3 and 9.4 format mock_popen.return_value.communicate.side_effect = [ (b'rmgr: XLOG len (rec/tot): 72/ 104, tx: 0, lsn: 0/01ADBC18, prev 0/01ADBBB8, ' @@ -518,10 +522,10 @@ class TestPostgresql(BaseTestPostgresql): def test_can_create_replica_without_replication_connection(self): self.p.config._config['create_replica_method'] = [] - self.assertFalse(self.p.can_create_replica_without_replication_connection()) + self.assertFalse(self.p.can_create_replica_without_replication_connection(None)) self.p.config._config['create_replica_method'] = ['wale', 'basebackup'] self.p.config._config['wale'] = {'command': 'foo', 'no_leader': 1} - self.assertTrue(self.p.can_create_replica_without_replication_connection()) + self.assertTrue(self.p.can_create_replica_without_replication_connection(None)) def test_replica_method_can_work_without_replication_connection(self): self.assertFalse(self.p.replica_method_can_work_without_replication_connection('basebackup')) diff --git a/tests/test_slots.py b/tests/test_slots.py index f88cd881..bdb4bc5d 100644 --- a/tests/test_slots.py +++ b/tests/test_slots.py @@ -111,6 +111,8 @@ class TestSlotsHandler(BaseTestPostgresql): self.s.copy_logical_slots(self.cluster, ['ls']) with patch.object(MockCursor, 'execute', Mock(side_effect=psycopg.OperationalError)): self.s.copy_logical_slots(self.cluster, ['foo']) + with patch.object(Cluster, 'leader', PropertyMock(return_value=None)): + self.s.copy_logical_slots(self.cluster, ['foo']) @patch.object(Postgresql, 'stop', Mock(return_value=True)) @patch.object(Postgresql, 'start', Mock(return_value=True)) diff --git a/tests/test_utils.py b/tests/test_utils.py index 36925041..023c8823 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -36,7 +36,7 @@ class TestUtils(unittest.TestCase): def test_enable_keepalive(self): with patch('socket.SIO_KEEPALIVE_VALS', 1, create=True): - self.assertIsNotNone(enable_keepalive(Mock(), 10, 5)) + self.assertIsNone(enable_keepalive(Mock(), 10, 5)) with patch('socket.SIO_KEEPALIVE_VALS', None, create=True): for platform in ('linux2', 'darwin', 'other'): with patch('sys.platform', platform): diff --git a/tests/test_wale_restore.py b/tests/test_wale_restore.py index 568073c9..0a1150ab 100644 --- a/tests/test_wale_restore.py +++ b/tests/test_wale_restore.py @@ -123,7 +123,7 @@ class TestWALERestore(unittest.TestCase): with patch.object(WALERestore, 'run', Mock(return_value=1)), \ patch('time.sleep', mock_sleep): self.assertEqual(_main(), 1) - self.assertTrue(sleeps[0], WALE_TEST_RETRIES) + self.assertEqual(sleeps[0], WALE_TEST_RETRIES) @patch('os.path.isfile', Mock(return_value=True)) def test_get_major_version(self): diff --git a/tests/test_watchdog.py b/tests/test_watchdog.py index 78752719..af37a5ff 100644 --- a/tests/test_watchdog.py +++ b/tests/test_watchdog.py @@ -110,6 +110,11 @@ class TestWatchdog(unittest.TestCase): watchdog.keepalive() self.assertEqual(len(device.writes), 1) + watchdog.impl._fd, fd = None, watchdog.impl._fd + watchdog.keepalive() + self.assertEqual(len(device.writes), 1) + watchdog.impl._fd = fd + watchdog.disable() self.assertFalse(device.open) self.assertEqual(device.writes[-1], b'V') diff --git a/tests/test_zookeeper.py b/tests/test_zookeeper.py index 067877b5..c2f3f583 100644 --- a/tests/test_zookeeper.py +++ b/tests/test_zookeeper.py @@ -13,6 +13,7 @@ from patroni.dcs.zookeeper import Cluster, Leader, PatroniKazooClient,\ class MockKazooClient(Mock): + handler = PatroniSequentialThreadingHandler(10) leader = False exists = True @@ -154,11 +155,6 @@ class TestZooKeeper(unittest.TestCase): def test_session_listener(self): self.zk.session_listener(KazooState.SUSPENDED) - def test_members_watcher(self): - self.zk._fetch_cluster = False - self.zk.members_watcher(None) - self.assertTrue(self.zk._fetch_cluster) - def test_reload_config(self): self.zk.reload_config({'ttl': 20, 'retry_timeout': 10, 'loop_wait': 10}) self.zk.reload_config({'ttl': 20, 'retry_timeout': 10, 'loop_wait': 5}) @@ -223,6 +219,8 @@ class TestZooKeeper(unittest.TestCase): def test_cancel_initialization(self): self.zk.cancel_initialization() + with patch.object(MockKazooClient, 'delete', Mock()): + self.zk.cancel_initialization() def test_touch_member(self): self.zk._name = 'buzz' diff --git a/typings/botocore/__init__.pyi b/typings/botocore/__init__.pyi new file mode 100644 index 00000000..e69de29b diff --git a/typings/botocore/exceptions.pyi b/typings/botocore/exceptions.pyi new file mode 100644 index 00000000..481fee4e --- /dev/null +++ b/typings/botocore/exceptions.pyi @@ -0,0 +1 @@ +class ClientError(Exception): ... diff --git a/typings/botocore/utils.pyi b/typings/botocore/utils.pyi new file mode 100644 index 00000000..1efc61c5 --- /dev/null +++ b/typings/botocore/utils.pyi @@ -0,0 +1,11 @@ +from typing import Any, Callable, Dict, Optional +DEFAULT_METADATA_SERVICE_TIMEOUT = 1 +METADATA_BASE_URL = 'http://169.254.169.254/' +class AWSResponse: + status_code: int + @property + def text(self) -> str: ... +class IMDSFetcher: + def __init__(self, timeout: float = DEFAULT_METADATA_SERVICE_TIMEOUT, num_attempts: int = 1, base_url: str = METADATA_BASE_URL, env: Optional[Dict[str, str]] = None, user_agent: Optional[str] = None, config: Optional[Dict[str, Any]] = None) -> None: ... + def _fetch_metadata_token(self) -> Optional[str]: ...: + def _get_request(self, url_path: str, retry_func: Optional[Callable[[AWSResponse], bool]] = None, token: Optional[str] = None) -> AWSResponse: ... diff --git a/typings/cdiff/__init__.pyi b/typings/cdiff/__init__.pyi new file mode 100644 index 00000000..4578d468 --- /dev/null +++ b/typings/cdiff/__init__.pyi @@ -0,0 +1,5 @@ +import io +from typing import Any +class PatchStream: + def __init__(self, diff_hdl: io.TextIOBase) -> None: ... +def markup_to_pager(stream: Any, opts: Any) -> None: ... diff --git a/typings/consul/__init__.pyi b/typings/consul/__init__.pyi new file mode 100644 index 00000000..c0ca380e --- /dev/null +++ b/typings/consul/__init__.pyi @@ -0,0 +1,2 @@ +from consul.base import ConsulException, NotFound +__all__ = ['ConsulException', 'Consul', 'NotFound'] diff --git a/typings/consul/base.pyi b/typings/consul/base.pyi new file mode 100644 index 00000000..d512d108 --- /dev/null +++ b/typings/consul/base.pyi @@ -0,0 +1,24 @@ +from typing import Any, Dict, List, Optional, Tuple +class ConsulException(Exception): ... +class NotFound(ConsulException): ... +class Check: + @classmethod + def http(klass, url: str, interval: str, timeout: Optional[str] = None, deregister: Optional[str] = None) -> Dict[str, str]: ... +class Consul: + http: Any + agent: 'Consul.Agent' + session: 'Consul.Session' + kv: 'Consul.KV' + class KV: + def get(self, key: str, index: Optional[int]=None, recurse: bool = False, wait: Optional[str] = None, token: Optional[str] = None, consistency: Optional[str] = None, keys: bool = False, separator: Optional[str] = '', dc: Optional[str] = None) -> Tuple[int, Dict[str, Any]]: ... + def put(self, key: str, value: str, cas: Optional[int] = None, flags: Optional[int] = None, acquire: Optional[str] = None, release: Optional[str] = None, token: Optional[str] = None, dc: Optional[str] = None) -> bool: ... + def delete(self, key: str, recurse: Optional[bool] = None, cas: Optional[int] = None, token: Optional[str] = None, dc: Optional[str] = None) -> bool: ... + class Agent: + service: 'Consul.Agent.Service' + def self(self) -> Dict[str, Dict[str, Any]]: ... + class Service: + def register(self, name: str, service_id=..., address=..., port=..., tags=..., check=..., token=..., script=..., interval=..., ttl=..., http=..., timeout=..., enable_tag_override=...) -> bool: ... + def deregister(self, service_id: str) -> bool: ... + class Session: + def create(self, name: Optional[str] = None, node: Optional[str] = [], checks: Optional[List[str]]=None, lock_delay: float = 15, behavior: str = 'release', ttl: Optional[int] = None, dc: Optional[str] = None) -> str: ... + def renew(self, session_id: str, dc: Optional[str] = None) -> Optional[str]: ... diff --git a/typings/dns/resolver.pyi b/typings/dns/resolver.pyi new file mode 100644 index 00000000..c68b1fb4 --- /dev/null +++ b/typings/dns/resolver.pyi @@ -0,0 +1,17 @@ +from typing import Union, Optional, Iterator +class Name: + def to_text(self, omit_final_dot: bool = ...) -> str: ... +class Rdata: + target: Name = ... + port: int = ... +class Answer: + def __iter__(self) -> Iterator[Rdata]: ... +def resolve(qname : str, rdtype : Union[int,str] = 0, + rdclass : Union[int,str] = 0, + tcp=False, source=None, raise_on_no_answer=True, + source_port=0, lifetime : Optional[float]=None, + search : Optional[bool]=None) -> Answer: ... +def query(qname : str, rdtype : Union[int,str] = 0, + rdclass : Union[int,str] = 0, + tcp=False, source: Optional[str] = None, raise_on_no_answer=True, + source_port=0, lifetime : Optional[float]=None) -> Answer: ... diff --git a/typings/etcd/__init__.pyi b/typings/etcd/__init__.pyi new file mode 100644 index 00000000..58e1fb4d --- /dev/null +++ b/typings/etcd/__init__.pyi @@ -0,0 +1,24 @@ +from typing import Dict, Optional, Type, List +from .client import Client +__all__ = ['Client', 'EtcdError', 'EtcdException', 'EtcdEventIndexCleared', 'EtcdWatcherCleared', 'EtcdKeyNotFound', 'EtcdAlreadyExist', 'EtcdResult', 'EtcdConnectionFailed', 'EtcdWatchTimedOut'] +class EtcdResult: + action: str = ... + modifiedIndex: int = ... + key: str = ... + value: str = ... + ttl: Optional[float] = ... + @property + def leaves(self) -> List['EtcdResult']: ... +class EtcdException(Exception): + def __init__(self, message=..., payload=...) -> None: ... +class EtcdConnectionFailed(EtcdException): + def __init__(self, message=..., payload=..., cause=...) -> None: ... +class EtcdKeyError(EtcdException): ... +class EtcdKeyNotFound(EtcdKeyError): ... +class EtcdAlreadyExist(EtcdKeyError): ... +class EtcdEventIndexCleared(EtcdException): ... +class EtcdWatchTimedOut(EtcdConnectionFailed): ... +class EtcdWatcherCleared(EtcdException): ... +class EtcdLeaderElectionInProgress(EtcdException): ... +class EtcdError: + error_exceptions: Dict[int, Type[EtcdException]] = ... diff --git a/typings/etcd/client.pyi b/typings/etcd/client.pyi new file mode 100644 index 00000000..b7760c82 --- /dev/null +++ b/typings/etcd/client.pyi @@ -0,0 +1,29 @@ +import urllib3 +from typing import Any, Optional, Set +from . import EtcdResult +class Client: + _MGET: str + _MPUT: str + _MPOST: str + _MDELETE: str + _comparison_conditions: Set[str] + _read_options: Set[str] + _del_conditions: Set[str] + http: urllib3.poolmanager.PoolManager + _use_proxies: bool + version_prefix: str + username: Optional[str] + password: Optional[str] + def __init__(self, host=..., port=..., srv_domain=..., version_prefix=..., read_timeout=..., allow_redirect=..., protocol=..., cert=..., ca_cert=..., username=..., password=..., allow_reconnect=..., use_proxies=..., expected_cluster_id=..., per_host_pool_size=..., lock_prefix=...): ... + @property + def protocol(self) -> str: ... + @property + def read_timeout(self) -> int: ... + @property + def allow_redirect(self) -> bool: ... + def write(self, key: str, value: str, ttl: int = ..., dir: bool = ..., append: bool = ..., **kwdargs: Any) -> EtcdResult: ... + def read(self, key: str, **kwdargs: Any) -> EtcdResult: ... + def delete(self, key: str, recursive: bool = ..., dir: bool = ..., **kwdargs: Any) -> EtcdResult: ... + def set(self, key: str, value: str, ttl: int = ...) -> EtcdResult: ... + def watch(self, key: str, index: int = ..., timeout: float = ..., recursive: bool = ...) -> EtcdResult: ... + def _handle_server_response(self, response: urllib3.response.HTTPResponse) -> Any: ... diff --git a/typings/kazoo/client.pyi b/typings/kazoo/client.pyi new file mode 100644 index 00000000..654ea739 --- /dev/null +++ b/typings/kazoo/client.pyi @@ -0,0 +1,34 @@ +__all__ = ['KazooState', 'KazooClient', 'KazooRetry'] + +from kazoo.protocol.connection import ConnectionHandler +from kazoo.protocol.states import KazooState, WatchedEvent, ZnodeStat +from kazoo.handlers.threading import AsyncResult, SequentialThreadingHandler +from kazoo.retry import KazooRetry +from kazoo.security import ACL + +from typing import Any, Callable, Optional, Tuple, List + + +class KazooClient: + handler: SequentialThreadingHandler + _state: str + _connection: ConnectionHandler + _session_timeout: int + retry: Callable[..., Any] + _retry: KazooRetry + def __init__(self, hosts=..., timeout=..., client_id=..., handler=..., default_acl=..., auth_data=..., sasl_options=..., read_only=..., randomize_hosts=..., connection_retry=..., command_retry=..., logger=..., keyfile=..., keyfile_password=..., certfile=..., ca=..., use_ssl=..., verify_certs=..., **kwargs) -> None: ... + @property + def client_id(self) -> Optional[Tuple[Any]]: ... + def add_listener(self, listener: Callable[[str], None]) -> None: ... + def start(self, timeout: int = ...) -> None: ... + def restart(self) -> None: ... + def set_hosts(self, hosts: str, randomize_hosts: Optional[bool] = None) -> None: ... + def create(self, path: str, value: bytes = b'', acl: Optional[ACL]=None, ephemeral: bool = False, sequence: bool = False, makepath: bool = False, include_data: bool = False) -> None: ... + def create_async(self, path: str, value: bytes = b'', acl: Optional[ACL]=None, ephemeral: bool = False, sequence: bool = False, makepath: bool = False, include_data: bool = False) -> AsyncResult: ... + def get(self, path: str, watch: Optional[Callable[[WatchedEvent], None]] = None) -> Tuple[bytes, ZnodeStat]: ... + def get_children(self, path: str, watch: Optional[Callable[[WatchedEvent], None]] = None, include_data: bool = False) -> List[str]: ... + def set(self, path: str, value: bytes, version: int = -1) -> ZnodeStat: ... + def set_async(self, path: str, value: bytes, version: int = -1) -> AsyncResult: ... + def delete(self, path: str, version: int = -1, recursive: bool = False) -> None: ... + def delete_async(self, path: str, version: int = -1) -> AsyncResult: ... + def _call(self, request: Tuple[Any], async_object: AsyncResult) -> Optional[bool]: ... diff --git a/typings/kazoo/exceptions.pyi b/typings/kazoo/exceptions.pyi new file mode 100644 index 00000000..db6f4ed1 --- /dev/null +++ b/typings/kazoo/exceptions.pyi @@ -0,0 +1,12 @@ +class KazooException(Exception): + ... +class ZookeeperError(KazooException): + ... +class SessionExpiredError(ZookeeperError): + ... +class ConnectionClosedError(SessionExpiredError): + ... +class NoNodeError(ZookeeperError): + ... +class NodeExistsError(ZookeeperError): + ... diff --git a/typings/kazoo/handlers/threading.pyi b/typings/kazoo/handlers/threading.pyi new file mode 100644 index 00000000..b8d70068 --- /dev/null +++ b/typings/kazoo/handlers/threading.pyi @@ -0,0 +1,13 @@ +import socket +from kazoo.handlers import utils +from typing import Any + +class AsyncResult(utils.AsyncResult): + ... + +class SequentialThreadingHandler: + def select(self, *args: Any, **kwargs: Any) -> Any: + ... + + def create_connection(self, *args: Any, **kwargs: Any) -> socket.socket: + ... diff --git a/typings/kazoo/handlers/utils.pyi b/typings/kazoo/handlers/utils.pyi new file mode 100644 index 00000000..f26c910d --- /dev/null +++ b/typings/kazoo/handlers/utils.pyi @@ -0,0 +1,6 @@ +from typing import Any, Optional +class AsyncResult: + def set_exception(self, exception: Exception) -> None: + ... + def get(self, block: bool = False, timeout: Optional[float] = None) -> Any: + ... diff --git a/typings/kazoo/protocol/connection.pyi b/typings/kazoo/protocol/connection.pyi new file mode 100644 index 00000000..2ce4ed05 --- /dev/null +++ b/typings/kazoo/protocol/connection.pyi @@ -0,0 +1,6 @@ +import socket +from typing import Any, Union, Tuple +class ConnectionHandler: + _socket: socket.socket + def _connect(self, *args: Any) -> Tuple[Union[int, float], Union[int, float]]: + ... diff --git a/typings/kazoo/protocol/states.pyi b/typings/kazoo/protocol/states.pyi new file mode 100644 index 00000000..3642d969 --- /dev/null +++ b/typings/kazoo/protocol/states.pyi @@ -0,0 +1,29 @@ +from typing import Any, NamedTuple +class KazooState: + SUSPENDED: str + CONNECTED: str + LOST: str +class KeeperState: + AUTH_FAILED: str + CONNECTED: str + CONNECTED_RO: str + CONNECTING: str + CLOSED: str + EXPIRED_SESSION: str +class WatchedEvent(NamedTuple): + type: str + state: str + path: str +class ZnodeStat(NamedTuple): + + czxid: int + mzxid: int + ctime: float + mtime: float + version: int + cversion: int + aversion: int + ephemeralOwner: Any + dataLength: int + numChildren: int + pzxid: int diff --git a/typings/kazoo/retry.pyi b/typings/kazoo/retry.pyi new file mode 100644 index 00000000..5231a6a8 --- /dev/null +++ b/typings/kazoo/retry.pyi @@ -0,0 +1,7 @@ +from kazoo.exceptions import KazooException +class RetryFailedError(KazooException): + ... +class KazooRetry: + deadline: float + def __init__(self, max_tries=..., delay=..., backoff=..., max_jitter=..., max_delay=..., ignore_expire=..., sleep_func=..., deadline=..., interrupt=...) -> None: + ... diff --git a/typings/kazoo/security.pyi b/typings/kazoo/security.pyi new file mode 100644 index 00000000..9e3840c2 --- /dev/null +++ b/typings/kazoo/security.pyi @@ -0,0 +1,5 @@ +from collections import namedtuple +class ACL(namedtuple('ACL', 'perms id')): + ... +def make_acl(scheme: str, credential: str, read: bool = ..., write: bool = ..., create: bool = ..., delete: bool = ..., admin: bool = ..., all: bool = ...) -> ACL: + ... diff --git a/typings/prettytable/__init__.pyi b/typings/prettytable/__init__.pyi new file mode 100644 index 00000000..fd7d5398 --- /dev/null +++ b/typings/prettytable/__init__.pyi @@ -0,0 +1,13 @@ +from typing import Any, Dict, List +FRAME = 1 +ALL = 1 +class PrettyTable: + def __init__(self, *args: str, **kwargs: Any) -> None: ... + def _stringify_hrule(self, options: Dict[str, Any], where: str = '') -> str: ... + @property + def align(self) -> Dict[str, str]: ... + @align.setter + def align(self, val: str) -> None: ... + def add_row(self, row: List[Any]) -> None: ... + def __str__(self) -> str: ... + def __repr__(self) -> str: ... diff --git a/typings/psycopg2/__init__.pyi b/typings/psycopg2/__init__.pyi new file mode 100644 index 00000000..60aa696a --- /dev/null +++ b/typings/psycopg2/__init__.pyi @@ -0,0 +1,52 @@ +from collections.abc import Callable +from typing import Any, TypeVar, overload + +# connection and cursor not available at runtime +from psycopg2._psycopg import ( + BINARY as BINARY, + DATETIME as DATETIME, + NUMBER as NUMBER, + ROWID as ROWID, + STRING as STRING, + Binary as Binary, + DatabaseError as DatabaseError, + DataError as DataError, + Date as Date, + DateFromTicks as DateFromTicks, + Error as Error, + IntegrityError as IntegrityError, + InterfaceError as InterfaceError, + InternalError as InternalError, + NotSupportedError as NotSupportedError, + OperationalError as OperationalError, + ProgrammingError as ProgrammingError, + Time as Time, + TimeFromTicks as TimeFromTicks, + Timestamp as Timestamp, + TimestampFromTicks as TimestampFromTicks, + Warning as Warning, + __libpq_version__ as __libpq_version__, + apilevel as apilevel, + connection as connection, + cursor as cursor, + paramstyle as paramstyle, + threadsafety as threadsafety, +) + +__version__: str + +_T_conn = TypeVar("_T_conn", bound=connection) + +@overload +def connect(dsn: str, connection_factory: Callable[..., _T_conn], cursor_factory: None = None, **kwargs: Any) -> _T_conn: ... +@overload +def connect( + dsn: str | None = None, *, connection_factory: Callable[..., _T_conn], cursor_factory: None = None, **kwargs: Any +) -> _T_conn: ... +@overload +def connect( + dsn: str | None = None, + connection_factory: Callable[..., connection] | None = None, + cursor_factory: Callable[..., cursor] | None = None, + **kwargs: Any, +) -> connection: ... diff --git a/typings/psycopg2/_ipaddress.pyi b/typings/psycopg2/_ipaddress.pyi new file mode 100644 index 00000000..7c9cb8b0 --- /dev/null +++ b/typings/psycopg2/_ipaddress.pyi @@ -0,0 +1,9 @@ +from _typeshed import Incomplete +from typing import Any + +ipaddress: Any + +def register_ipaddress(conn_or_curs: Incomplete | None = None) -> None: ... +def cast_interface(s, cur: Incomplete | None = None): ... +def cast_network(s, cur: Incomplete | None = None): ... +def adapt_ipaddress(obj): ... diff --git a/typings/psycopg2/_json.pyi b/typings/psycopg2/_json.pyi new file mode 100644 index 00000000..c1dd3ee3 --- /dev/null +++ b/typings/psycopg2/_json.pyi @@ -0,0 +1,26 @@ +from _typeshed import Incomplete +from typing import Any + +JSON_OID: int +JSONARRAY_OID: int +JSONB_OID: int +JSONBARRAY_OID: int + +class Json: + adapted: Any + def __init__(self, adapted, dumps: Incomplete | None = None) -> None: ... + def __conform__(self, proto): ... + def dumps(self, obj): ... + def prepare(self, conn) -> None: ... + def getquoted(self): ... + +def register_json( + conn_or_curs: Incomplete | None = None, + globally: bool = False, + loads: Incomplete | None = None, + oid: Incomplete | None = None, + array_oid: Incomplete | None = None, + name: str = "json", +): ... +def register_default_json(conn_or_curs: Incomplete | None = None, globally: bool = False, loads: Incomplete | None = None): ... +def register_default_jsonb(conn_or_curs: Incomplete | None = None, globally: bool = False, loads: Incomplete | None = None): ... diff --git a/typings/psycopg2/_psycopg.pyi b/typings/psycopg2/_psycopg.pyi new file mode 100644 index 00000000..b6854aa1 --- /dev/null +++ b/typings/psycopg2/_psycopg.pyi @@ -0,0 +1,488 @@ +from collections.abc import Callable, Iterable, Mapping, Sequence +from types import TracebackType +from typing import Any, TypeVar, overload +from typing_extensions import Literal, Self, TypeAlias + +import psycopg2 +import psycopg2.extensions +from psycopg2.sql import Composable + +_Vars: TypeAlias = Sequence[Any] | Mapping[str, Any] | None + +BINARY: Any +BINARYARRAY: Any +BOOLEAN: Any +BOOLEANARRAY: Any +BYTES: Any +BYTESARRAY: Any +CIDRARRAY: Any +DATE: Any +DATEARRAY: Any +DATETIME: Any +DATETIMEARRAY: Any +DATETIMETZ: Any +DATETIMETZARRAY: Any +DECIMAL: Any +DECIMALARRAY: Any +FLOAT: Any +FLOATARRAY: Any +INETARRAY: Any +INTEGER: Any +INTEGERARRAY: Any +INTERVAL: Any +INTERVALARRAY: Any +LONGINTEGER: Any +LONGINTEGERARRAY: Any +MACADDRARRAY: Any +NUMBER: Any +PYDATE: Any +PYDATEARRAY: Any +PYDATETIME: Any +PYDATETIMEARRAY: Any +PYDATETIMETZ: Any +PYDATETIMETZARRAY: Any +PYINTERVAL: Any +PYINTERVALARRAY: Any +PYTIME: Any +PYTIMEARRAY: Any +REPLICATION_LOGICAL: int +REPLICATION_PHYSICAL: int +ROWID: Any +ROWIDARRAY: Any +STRING: Any +STRINGARRAY: Any +TIME: Any +TIMEARRAY: Any +UNICODE: Any +UNICODEARRAY: Any +UNKNOWN: Any +adapters: dict[Any, Any] +apilevel: str +binary_types: dict[Any, Any] +encodings: dict[Any, Any] +paramstyle: str +sqlstate_errors: dict[Any, Any] +string_types: dict[Any, Any] +threadsafety: int + +__libpq_version__: int + +class cursor: + arraysize: int + binary_types: Any + closed: Any + connection: Any + description: Any + itersize: Any + lastrowid: Any + name: Any + pgresult_ptr: Any + query: Any + row_factory: Any + rowcount: int + rownumber: int + scrollable: bool | None + statusmessage: Any + string_types: Any + typecaster: Any + tzinfo_factory: Any + withhold: bool + def __init__(self, conn: connection, name: str | bytes | None = ...) -> None: ... + def callproc(self, procname, parameters=...): ... + def cast(self, oid, s): ... + def close(self): ... + def copy_expert(self, sql: str | bytes | Composable, file, size=...): ... + def copy_from(self, file, table, sep=..., null=..., size=..., columns=...): ... + def copy_to(self, file, table, sep=..., null=..., columns=...): ... + def execute(self, query: str | bytes | Composable, vars: _Vars = ...) -> None: ... + def executemany(self, query: str | bytes | Composable, vars_list: Iterable[_Vars]) -> None: ... + def fetchall(self) -> list[tuple[Any, ...]]: ... + def fetchmany(self, size: int | None = ...) -> list[tuple[Any, ...]]: ... + def fetchone(self) -> tuple[Any, ...] | None: ... + def mogrify(self, *args, **kwargs): ... + def nextset(self): ... + def scroll(self, value, mode=...): ... + def setinputsizes(self, sizes): ... + def setoutputsize(self, size, column=...): ... + def __enter__(self) -> Self: ... + def __exit__( + self, type: type[BaseException] | None, value: BaseException | None, traceback: TracebackType | None + ) -> None: ... + def __iter__(self) -> Self: ... + def __next__(self) -> tuple[Any, ...]: ... + +_Cursor: TypeAlias = cursor + +class AsIs: + adapted: Any + def __init__(self, *args, **kwargs) -> None: ... + def getquoted(self, *args, **kwargs): ... + def __conform__(self, *args, **kwargs): ... + +class Binary: + adapted: Any + buffer: Any + def __init__(self, *args, **kwargs) -> None: ... + def getquoted(self, *args, **kwargs): ... + def prepare(self, conn): ... + def __conform__(self, *args, **kwargs): ... + +class Boolean: + adapted: Any + def __init__(self, *args, **kwargs) -> None: ... + def getquoted(self, *args, **kwargs): ... + def __conform__(self, *args, **kwargs): ... + +class Column: + display_size: Any + internal_size: Any + name: Any + null_ok: Any + precision: Any + scale: Any + table_column: Any + table_oid: Any + type_code: Any + def __init__(self, *args, **kwargs) -> None: ... + def __eq__(self, __other): ... + def __ge__(self, __other): ... + def __getitem__(self, __index): ... + def __getstate__(self): ... + def __gt__(self, __other): ... + def __le__(self, __other): ... + def __len__(self) -> int: ... + def __lt__(self, __other): ... + def __ne__(self, __other): ... + def __setstate__(self, state): ... + +class ConnectionInfo: + # Note: the following properties can be None if their corresponding libpq function + # returns NULL. They're not annotated as such, because this is very unlikely in + # practice---the psycopg2 docs [1] don't even mention this as a possibility! + # + # - db_name + # - user + # - password + # - host + # - port + # - options + # + # (To prove this, one needs to inspect the psycopg2 source code [2], plus the + # documentation [3] and source code [4] of the corresponding libpq calls.) + # + # [1]: https://www.psycopg.org/docs/extensions.html#psycopg2.extensions.ConnectionInfo + # [2]: https://github.com/psycopg/psycopg2/blob/1d3a89a0bba621dc1cc9b32db6d241bd2da85ad1/psycopg/conninfo_type.c#L52 and below + # [3]: https://www.postgresql.org/docs/current/libpq-status.html + # [4]: https://github.com/postgres/postgres/blob/b39838889e76274b107935fa8e8951baf0e8b31b/src/interfaces/libpq/fe-connect.c#L6754 and below + @property + def backend_pid(self) -> int: ... + @property + def dbname(self) -> str: ... + @property + def dsn_parameters(self) -> dict[str, str]: ... + @property + def error_message(self) -> str | None: ... + @property + def host(self) -> str: ... + @property + def needs_password(self) -> bool: ... + @property + def options(self) -> str: ... + @property + def password(self) -> str: ... + @property + def port(self) -> int: ... + @property + def protocol_version(self) -> int: ... + @property + def server_version(self) -> int: ... + @property + def socket(self) -> int: ... + @property + def ssl_attribute_names(self) -> list[str]: ... + @property + def ssl_in_use(self) -> bool: ... + @property + def status(self) -> int: ... + @property + def transaction_status(self) -> int: ... + @property + def used_password(self) -> bool: ... + @property + def user(self) -> str: ... + def __init__(self, *args, **kwargs) -> None: ... + def parameter_status(self, name: str) -> str | None: ... + def ssl_attribute(self, name: str) -> str | None: ... + +class DataError(psycopg2.DatabaseError): ... +class DatabaseError(psycopg2.Error): ... + +class Decimal: + adapted: Any + def __init__(self, *args, **kwargs) -> None: ... + def getquoted(self, *args, **kwargs): ... + def __conform__(self, *args, **kwargs): ... + +class Diagnostics: + column_name: str | None + constraint_name: str | None + context: str | None + datatype_name: str | None + internal_position: str | None + internal_query: str | None + message_detail: str | None + message_hint: str | None + message_primary: str | None + schema_name: str | None + severity: str | None + severity_nonlocalized: str | None + source_file: str | None + source_function: str | None + source_line: str | None + sqlstate: str | None + statement_position: str | None + table_name: str | None + def __init__(self, __err: Error) -> None: ... + +class Error(Exception): + cursor: _Cursor | None + diag: Diagnostics + pgcode: str | None + pgerror: str | None + def __init__(self, *args, **kwargs) -> None: ... + def __reduce__(self): ... + def __setstate__(self, state): ... + +class Float: + adapted: Any + def __init__(self, *args, **kwargs) -> None: ... + def getquoted(self, *args, **kwargs): ... + def __conform__(self, *args, **kwargs): ... + +class ISQLQuote: + _wrapped: Any + def __init__(self, *args, **kwargs) -> None: ... + def getbinary(self, *args, **kwargs): ... + def getbuffer(self, *args, **kwargs): ... + def getquoted(self, *args, **kwargs) -> bytes: ... + +class Int: + adapted: Any + def __init__(self, *args, **kwargs) -> None: ... + def getquoted(self, *args, **kwargs): ... + def __conform__(self, *args, **kwargs): ... + +class IntegrityError(psycopg2.DatabaseError): ... +class InterfaceError(psycopg2.Error): ... +class InternalError(psycopg2.DatabaseError): ... + +class List: + adapted: Any + def __init__(self, *args, **kwargs) -> None: ... + def getquoted(self, *args, **kwargs): ... + def prepare(self, *args, **kwargs): ... + def __conform__(self, *args, **kwargs): ... + +class NotSupportedError(psycopg2.DatabaseError): ... + +class Notify: + channel: Any + payload: Any + pid: Any + def __init__(self, *args, **kwargs) -> None: ... + def __eq__(self, __other): ... + def __ge__(self, __other): ... + def __getitem__(self, __index): ... + def __gt__(self, __other): ... + def __hash__(self) -> int: ... + def __le__(self, __other): ... + def __len__(self) -> int: ... + def __lt__(self, __other): ... + def __ne__(self, __other): ... + +class OperationalError(psycopg2.DatabaseError): ... +class ProgrammingError(psycopg2.DatabaseError): ... +class QueryCanceledError(psycopg2.OperationalError): ... + +class QuotedString: + adapted: Any + buffer: Any + encoding: Any + def __init__(self, *args, **kwargs) -> None: ... + def getquoted(self, *args, **kwargs): ... + def prepare(self, *args, **kwargs): ... + def __conform__(self, *args, **kwargs): ... + +class ReplicationConnection(psycopg2.extensions.connection): + autocommit: Any + isolation_level: Any + replication_type: Any + reset: Any + set_isolation_level: Any + set_session: Any + def __init__(self, *args, **kwargs) -> None: ... + +class ReplicationCursor(cursor): + feedback_timestamp: Any + io_timestamp: Any + wal_end: Any + def __init__(self, *args, **kwargs) -> None: ... + def consume_stream(self, consumer, keepalive_interval=...): ... + def read_message(self, *args, **kwargs): ... + def send_feedback(self, write_lsn=..., flush_lsn=..., apply_lsn=..., reply=..., force=...): ... + def start_replication_expert(self, command, decode=..., status_interval=...): ... + +class ReplicationMessage: + cursor: Any + data_size: Any + data_start: Any + payload: Any + send_time: Any + wal_end: Any + def __init__(self, *args, **kwargs) -> None: ... + +class TransactionRollbackError(psycopg2.OperationalError): ... +class Warning(Exception): ... + +class Xid: + bqual: Any + database: Any + format_id: Any + gtrid: Any + owner: Any + prepared: Any + def __init__(self, *args, **kwargs) -> None: ... + def from_string(self, *args, **kwargs): ... + def __getitem__(self, __index): ... + def __len__(self) -> int: ... + +_T_cur = TypeVar("_T_cur", bound=cursor) + +class connection: + DataError: Any + DatabaseError: Any + Error: Any + IntegrityError: Any + InterfaceError: Any + InternalError: Any + NotSupportedError: Any + OperationalError: Any + ProgrammingError: Any + Warning: Any + @property + def async_(self) -> int: ... + autocommit: bool + @property + def binary_types(self) -> Any: ... + @property + def closed(self) -> int: ... + cursor_factory: Callable[..., _Cursor] + @property + def dsn(self) -> str: ... + @property + def encoding(self) -> str: ... + @property + def info(self) -> ConnectionInfo: ... + @property + def isolation_level(self) -> int | None: ... + @isolation_level.setter + def isolation_level(self, __value: str | bytes | int | None) -> None: ... + notices: list[Any] + notifies: list[Any] + @property + def pgconn_ptr(self) -> int | None: ... + @property + def protocol_version(self) -> int: ... + @property + def deferrable(self) -> bool | None: ... + @deferrable.setter + def deferrable(self, __value: Literal["default"] | bool | None) -> None: ... + @property + def readonly(self) -> bool | None: ... + @readonly.setter + def readonly(self, __value: Literal["default"] | bool | None) -> None: ... + @property + def server_version(self) -> int: ... + @property + def status(self) -> int: ... + @property + def string_types(self) -> Any: ... + # Really it's dsn: str, async: int = ..., async_: int = ..., but + # that would be a syntax error. + def __init__(self, dsn: str, *, async_: int = ...) -> None: ... + def cancel(self) -> None: ... + def close(self) -> None: ... + def commit(self) -> None: ... + @overload + def cursor(self, name: str | bytes | None = ..., *, withhold: bool = ..., scrollable: bool | None = ...) -> _Cursor: ... + def fileno(self) -> int: ... + def get_backend_pid(self) -> int: ... + def get_dsn_parameters(self) -> dict[str, str]: ... + def get_native_connection(self): ... + def get_parameter_status(self, parameter: str) -> str | None: ... + def get_transaction_status(self) -> int: ... + def isexecuting(self) -> bool: ... + def lobject( + self, + oid: int = ..., + mode: str | None = ..., + new_oid: int = ..., + new_file: str | None = ..., + lobject_factory: type[lobject] = ..., + ) -> lobject: ... + def poll(self) -> int: ... + def reset(self) -> None: ... + def rollback(self) -> None: ... + def set_client_encoding(self, encoding: str) -> None: ... + def set_isolation_level(self, level: int | None) -> None: ... + def set_session( + self, + isolation_level: str | bytes | int | None = ..., + readonly: bool | Literal["default", b"default"] | None = ..., + deferrable: bool | Literal["default", b"default"] | None = ..., + autocommit: bool = ..., + ) -> None: ... + def tpc_begin(self, xid: str | bytes | Xid) -> None: ... + def tpc_commit(self, __xid: str | bytes | Xid = ...) -> None: ... + def tpc_prepare(self) -> None: ... + def tpc_recover(self) -> list[Xid]: ... + def tpc_rollback(self, __xid: str | bytes | Xid = ...) -> None: ... + def xid(self, format_id, gtrid, bqual) -> Xid: ... + def __enter__(self) -> Self: ... + def __exit__(self, __type: type[BaseException] | None, __name: BaseException | None, __tb: TracebackType | None) -> None: ... + +class lobject: + closed: Any + mode: Any + oid: Any + def __init__(self, *args, **kwargs) -> None: ... + def close(self): ... + def export(self, filename): ... + def read(self, size=...): ... + def seek(self, offset, whence=...): ... + def tell(self): ... + def truncate(self, len=...): ... + def unlink(self): ... + def write(self, str): ... + +def Date(year, month, day): ... +def DateFromPy(*args, **kwargs): ... +def DateFromTicks(ticks): ... +def IntervalFromPy(*args, **kwargs): ... +def Time(hour, minutes, seconds, tzinfo=...): ... +def TimeFromPy(*args, **kwargs): ... +def TimeFromTicks(ticks): ... +def Timestamp(year, month, day, hour, minutes, seconds, tzinfo=...): ... +def TimestampFromPy(*args, **kwargs): ... +def TimestampFromTicks(ticks): ... +def _connect(*args, **kwargs): ... +def adapt(*args, **kwargs): ... +def encrypt_password(*args, **kwargs): ... +def get_wait_callback(*args, **kwargs): ... +def libpq_version(*args, **kwargs): ... +def new_array_type(oids, name, baseobj): ... +def new_type(oids, name, castobj): ... +def parse_dsn(dsn: str | bytes) -> dict[str, Any]: ... +def quote_ident(value: Any, scope: connection | cursor | None) -> str: ... +def register_type(*args, **kwargs): ... +def set_wait_callback(_none): ... diff --git a/typings/psycopg2/_range.pyi b/typings/psycopg2/_range.pyi new file mode 100644 index 00000000..a3b3a118 --- /dev/null +++ b/typings/psycopg2/_range.pyi @@ -0,0 +1,62 @@ +from _typeshed import Incomplete +from typing import Any + +class Range: + def __init__( + self, lower: Incomplete | None = None, upper: Incomplete | None = None, bounds: str = "[)", empty: bool = False + ) -> None: ... + @property + def lower(self): ... + @property + def upper(self): ... + @property + def isempty(self): ... + @property + def lower_inf(self): ... + @property + def upper_inf(self): ... + @property + def lower_inc(self): ... + @property + def upper_inc(self): ... + def __contains__(self, x): ... + def __bool__(self) -> bool: ... + def __eq__(self, other): ... + def __ne__(self, other): ... + def __hash__(self) -> int: ... + def __lt__(self, other): ... + def __le__(self, other): ... + def __gt__(self, other): ... + def __ge__(self, other): ... + +def register_range(pgrange, pyrange, conn_or_curs, globally: bool = False): ... + +class RangeAdapter: + name: Any + adapted: Any + def __init__(self, adapted) -> None: ... + def __conform__(self, proto): ... + def prepare(self, conn) -> None: ... + def getquoted(self): ... + +class RangeCaster: + subtype_oid: Any + typecaster: Any + array_typecaster: Any + def __init__(self, pgrange, pyrange, oid, subtype_oid, array_oid: Incomplete | None = None) -> None: ... + def parse(self, s, cur: Incomplete | None = None): ... + +class NumericRange(Range): ... +class DateRange(Range): ... +class DateTimeRange(Range): ... +class DateTimeTZRange(Range): ... + +class NumberRangeAdapter(RangeAdapter): + def getquoted(self): ... + +int4range_caster: Any +int8range_caster: Any +numrange_caster: Any +daterange_caster: Any +tsrange_caster: Any +tstzrange_caster: Any diff --git a/typings/psycopg2/errorcodes.pyi b/typings/psycopg2/errorcodes.pyi new file mode 100644 index 00000000..66e6cef1 --- /dev/null +++ b/typings/psycopg2/errorcodes.pyi @@ -0,0 +1,304 @@ +def lookup(code, _cache={}): ... + +CLASS_SUCCESSFUL_COMPLETION: str +CLASS_WARNING: str +CLASS_NO_DATA: str +CLASS_SQL_STATEMENT_NOT_YET_COMPLETE: str +CLASS_CONNECTION_EXCEPTION: str +CLASS_TRIGGERED_ACTION_EXCEPTION: str +CLASS_FEATURE_NOT_SUPPORTED: str +CLASS_INVALID_TRANSACTION_INITIATION: str +CLASS_LOCATOR_EXCEPTION: str +CLASS_INVALID_GRANTOR: str +CLASS_INVALID_ROLE_SPECIFICATION: str +CLASS_DIAGNOSTICS_EXCEPTION: str +CLASS_CASE_NOT_FOUND: str +CLASS_CARDINALITY_VIOLATION: str +CLASS_DATA_EXCEPTION: str +CLASS_INTEGRITY_CONSTRAINT_VIOLATION: str +CLASS_INVALID_CURSOR_STATE: str +CLASS_INVALID_TRANSACTION_STATE: str +CLASS_INVALID_SQL_STATEMENT_NAME: str +CLASS_TRIGGERED_DATA_CHANGE_VIOLATION: str +CLASS_INVALID_AUTHORIZATION_SPECIFICATION: str +CLASS_DEPENDENT_PRIVILEGE_DESCRIPTORS_STILL_EXIST: str +CLASS_INVALID_TRANSACTION_TERMINATION: str +CLASS_SQL_ROUTINE_EXCEPTION: str +CLASS_INVALID_CURSOR_NAME: str +CLASS_EXTERNAL_ROUTINE_EXCEPTION: str +CLASS_EXTERNAL_ROUTINE_INVOCATION_EXCEPTION: str +CLASS_SAVEPOINT_EXCEPTION: str +CLASS_INVALID_CATALOG_NAME: str +CLASS_INVALID_SCHEMA_NAME: str +CLASS_TRANSACTION_ROLLBACK: str +CLASS_SYNTAX_ERROR_OR_ACCESS_RULE_VIOLATION: str +CLASS_WITH_CHECK_OPTION_VIOLATION: str +CLASS_INSUFFICIENT_RESOURCES: str +CLASS_PROGRAM_LIMIT_EXCEEDED: str +CLASS_OBJECT_NOT_IN_PREREQUISITE_STATE: str +CLASS_OPERATOR_INTERVENTION: str +CLASS_SYSTEM_ERROR: str +CLASS_SNAPSHOT_FAILURE: str +CLASS_CONFIGURATION_FILE_ERROR: str +CLASS_FOREIGN_DATA_WRAPPER_ERROR: str +CLASS_PL_PGSQL_ERROR: str +CLASS_INTERNAL_ERROR: str +SUCCESSFUL_COMPLETION: str +WARNING: str +NULL_VALUE_ELIMINATED_IN_SET_FUNCTION: str +STRING_DATA_RIGHT_TRUNCATION_: str +PRIVILEGE_NOT_REVOKED: str +PRIVILEGE_NOT_GRANTED: str +IMPLICIT_ZERO_BIT_PADDING: str +DYNAMIC_RESULT_SETS_RETURNED: str +DEPRECATED_FEATURE: str +NO_DATA: str +NO_ADDITIONAL_DYNAMIC_RESULT_SETS_RETURNED: str +SQL_STATEMENT_NOT_YET_COMPLETE: str +CONNECTION_EXCEPTION: str +SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION: str +CONNECTION_DOES_NOT_EXIST: str +SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION: str +CONNECTION_FAILURE: str +TRANSACTION_RESOLUTION_UNKNOWN: str +PROTOCOL_VIOLATION: str +TRIGGERED_ACTION_EXCEPTION: str +FEATURE_NOT_SUPPORTED: str +INVALID_TRANSACTION_INITIATION: str +LOCATOR_EXCEPTION: str +INVALID_LOCATOR_SPECIFICATION: str +INVALID_GRANTOR: str +INVALID_GRANT_OPERATION: str +INVALID_ROLE_SPECIFICATION: str +DIAGNOSTICS_EXCEPTION: str +STACKED_DIAGNOSTICS_ACCESSED_WITHOUT_ACTIVE_HANDLER: str +CASE_NOT_FOUND: str +CARDINALITY_VIOLATION: str +DATA_EXCEPTION: str +STRING_DATA_RIGHT_TRUNCATION: str +NULL_VALUE_NO_INDICATOR_PARAMETER: str +NUMERIC_VALUE_OUT_OF_RANGE: str +NULL_VALUE_NOT_ALLOWED_: str +ERROR_IN_ASSIGNMENT: str +INVALID_DATETIME_FORMAT: str +DATETIME_FIELD_OVERFLOW: str +INVALID_TIME_ZONE_DISPLACEMENT_VALUE: str +ESCAPE_CHARACTER_CONFLICT: str +INVALID_USE_OF_ESCAPE_CHARACTER: str +INVALID_ESCAPE_OCTET: str +ZERO_LENGTH_CHARACTER_STRING: str +MOST_SPECIFIC_TYPE_MISMATCH: str +SEQUENCE_GENERATOR_LIMIT_EXCEEDED: str +NOT_AN_XML_DOCUMENT: str +INVALID_XML_DOCUMENT: str +INVALID_XML_CONTENT: str +INVALID_XML_COMMENT: str +INVALID_XML_PROCESSING_INSTRUCTION: str +INVALID_INDICATOR_PARAMETER_VALUE: str +SUBSTRING_ERROR: str +DIVISION_BY_ZERO: str +INVALID_PRECEDING_OR_FOLLOWING_SIZE: str +INVALID_ARGUMENT_FOR_NTILE_FUNCTION: str +INTERVAL_FIELD_OVERFLOW: str +INVALID_ARGUMENT_FOR_NTH_VALUE_FUNCTION: str +INVALID_CHARACTER_VALUE_FOR_CAST: str +INVALID_ESCAPE_CHARACTER: str +INVALID_REGULAR_EXPRESSION: str +INVALID_ARGUMENT_FOR_LOGARITHM: str +INVALID_ARGUMENT_FOR_POWER_FUNCTION: str +INVALID_ARGUMENT_FOR_WIDTH_BUCKET_FUNCTION: str +INVALID_ROW_COUNT_IN_LIMIT_CLAUSE: str +INVALID_ROW_COUNT_IN_RESULT_OFFSET_CLAUSE: str +INVALID_LIMIT_VALUE: str +CHARACTER_NOT_IN_REPERTOIRE: str +INDICATOR_OVERFLOW: str +INVALID_PARAMETER_VALUE: str +UNTERMINATED_C_STRING: str +INVALID_ESCAPE_SEQUENCE: str +STRING_DATA_LENGTH_MISMATCH: str +TRIM_ERROR: str +ARRAY_SUBSCRIPT_ERROR: str +INVALID_TABLESAMPLE_REPEAT: str +INVALID_TABLESAMPLE_ARGUMENT: str +DUPLICATE_JSON_OBJECT_KEY_VALUE: str +INVALID_ARGUMENT_FOR_SQL_JSON_DATETIME_FUNCTION: str +INVALID_JSON_TEXT: str +INVALID_SQL_JSON_SUBSCRIPT: str +MORE_THAN_ONE_SQL_JSON_ITEM: str +NO_SQL_JSON_ITEM: str +NON_NUMERIC_SQL_JSON_ITEM: str +NON_UNIQUE_KEYS_IN_A_JSON_OBJECT: str +SINGLETON_SQL_JSON_ITEM_REQUIRED: str +SQL_JSON_ARRAY_NOT_FOUND: str +SQL_JSON_MEMBER_NOT_FOUND: str +SQL_JSON_NUMBER_NOT_FOUND: str +SQL_JSON_OBJECT_NOT_FOUND: str +TOO_MANY_JSON_ARRAY_ELEMENTS: str +TOO_MANY_JSON_OBJECT_MEMBERS: str +SQL_JSON_SCALAR_REQUIRED: str +FLOATING_POINT_EXCEPTION: str +INVALID_TEXT_REPRESENTATION: str +INVALID_BINARY_REPRESENTATION: str +BAD_COPY_FILE_FORMAT: str +UNTRANSLATABLE_CHARACTER: str +NONSTANDARD_USE_OF_ESCAPE_CHARACTER: str +INTEGRITY_CONSTRAINT_VIOLATION: str +RESTRICT_VIOLATION: str +NOT_NULL_VIOLATION: str +FOREIGN_KEY_VIOLATION: str +UNIQUE_VIOLATION: str +CHECK_VIOLATION: str +EXCLUSION_VIOLATION: str +INVALID_CURSOR_STATE: str +INVALID_TRANSACTION_STATE: str +ACTIVE_SQL_TRANSACTION: str +BRANCH_TRANSACTION_ALREADY_ACTIVE: str +INAPPROPRIATE_ACCESS_MODE_FOR_BRANCH_TRANSACTION: str +INAPPROPRIATE_ISOLATION_LEVEL_FOR_BRANCH_TRANSACTION: str +NO_ACTIVE_SQL_TRANSACTION_FOR_BRANCH_TRANSACTION: str +READ_ONLY_SQL_TRANSACTION: str +SCHEMA_AND_DATA_STATEMENT_MIXING_NOT_SUPPORTED: str +HELD_CURSOR_REQUIRES_SAME_ISOLATION_LEVEL: str +NO_ACTIVE_SQL_TRANSACTION: str +IN_FAILED_SQL_TRANSACTION: str +IDLE_IN_TRANSACTION_SESSION_TIMEOUT: str +INVALID_SQL_STATEMENT_NAME: str +TRIGGERED_DATA_CHANGE_VIOLATION: str +INVALID_AUTHORIZATION_SPECIFICATION: str +INVALID_PASSWORD: str +DEPENDENT_PRIVILEGE_DESCRIPTORS_STILL_EXIST: str +DEPENDENT_OBJECTS_STILL_EXIST: str +INVALID_TRANSACTION_TERMINATION: str +SQL_ROUTINE_EXCEPTION: str +MODIFYING_SQL_DATA_NOT_PERMITTED_: str +PROHIBITED_SQL_STATEMENT_ATTEMPTED_: str +READING_SQL_DATA_NOT_PERMITTED_: str +FUNCTION_EXECUTED_NO_RETURN_STATEMENT: str +INVALID_CURSOR_NAME: str +EXTERNAL_ROUTINE_EXCEPTION: str +CONTAINING_SQL_NOT_PERMITTED: str +MODIFYING_SQL_DATA_NOT_PERMITTED: str +PROHIBITED_SQL_STATEMENT_ATTEMPTED: str +READING_SQL_DATA_NOT_PERMITTED: str +EXTERNAL_ROUTINE_INVOCATION_EXCEPTION: str +INVALID_SQLSTATE_RETURNED: str +NULL_VALUE_NOT_ALLOWED: str +TRIGGER_PROTOCOL_VIOLATED: str +SRF_PROTOCOL_VIOLATED: str +EVENT_TRIGGER_PROTOCOL_VIOLATED: str +SAVEPOINT_EXCEPTION: str +INVALID_SAVEPOINT_SPECIFICATION: str +INVALID_CATALOG_NAME: str +INVALID_SCHEMA_NAME: str +TRANSACTION_ROLLBACK: str +SERIALIZATION_FAILURE: str +TRANSACTION_INTEGRITY_CONSTRAINT_VIOLATION: str +STATEMENT_COMPLETION_UNKNOWN: str +DEADLOCK_DETECTED: str +SYNTAX_ERROR_OR_ACCESS_RULE_VIOLATION: str +INSUFFICIENT_PRIVILEGE: str +SYNTAX_ERROR: str +INVALID_NAME: str +INVALID_COLUMN_DEFINITION: str +NAME_TOO_LONG: str +DUPLICATE_COLUMN: str +AMBIGUOUS_COLUMN: str +UNDEFINED_COLUMN: str +UNDEFINED_OBJECT: str +DUPLICATE_OBJECT: str +DUPLICATE_ALIAS: str +DUPLICATE_FUNCTION: str +AMBIGUOUS_FUNCTION: str +GROUPING_ERROR: str +DATATYPE_MISMATCH: str +WRONG_OBJECT_TYPE: str +INVALID_FOREIGN_KEY: str +CANNOT_COERCE: str +UNDEFINED_FUNCTION: str +GENERATED_ALWAYS: str +RESERVED_NAME: str +UNDEFINED_TABLE: str +UNDEFINED_PARAMETER: str +DUPLICATE_CURSOR: str +DUPLICATE_DATABASE: str +DUPLICATE_PREPARED_STATEMENT: str +DUPLICATE_SCHEMA: str +DUPLICATE_TABLE: str +AMBIGUOUS_PARAMETER: str +AMBIGUOUS_ALIAS: str +INVALID_COLUMN_REFERENCE: str +INVALID_CURSOR_DEFINITION: str +INVALID_DATABASE_DEFINITION: str +INVALID_FUNCTION_DEFINITION: str +INVALID_PREPARED_STATEMENT_DEFINITION: str +INVALID_SCHEMA_DEFINITION: str +INVALID_TABLE_DEFINITION: str +INVALID_OBJECT_DEFINITION: str +INDETERMINATE_DATATYPE: str +INVALID_RECURSION: str +WINDOWING_ERROR: str +COLLATION_MISMATCH: str +INDETERMINATE_COLLATION: str +WITH_CHECK_OPTION_VIOLATION: str +INSUFFICIENT_RESOURCES: str +DISK_FULL: str +OUT_OF_MEMORY: str +TOO_MANY_CONNECTIONS: str +CONFIGURATION_LIMIT_EXCEEDED: str +PROGRAM_LIMIT_EXCEEDED: str +STATEMENT_TOO_COMPLEX: str +TOO_MANY_COLUMNS: str +TOO_MANY_ARGUMENTS: str +OBJECT_NOT_IN_PREREQUISITE_STATE: str +OBJECT_IN_USE: str +CANT_CHANGE_RUNTIME_PARAM: str +LOCK_NOT_AVAILABLE: str +UNSAFE_NEW_ENUM_VALUE_USAGE: str +OPERATOR_INTERVENTION: str +QUERY_CANCELED: str +ADMIN_SHUTDOWN: str +CRASH_SHUTDOWN: str +CANNOT_CONNECT_NOW: str +DATABASE_DROPPED: str +SYSTEM_ERROR: str +IO_ERROR: str +UNDEFINED_FILE: str +DUPLICATE_FILE: str +SNAPSHOT_TOO_OLD: str +CONFIG_FILE_ERROR: str +LOCK_FILE_EXISTS: str +FDW_ERROR: str +FDW_OUT_OF_MEMORY: str +FDW_DYNAMIC_PARAMETER_VALUE_NEEDED: str +FDW_INVALID_DATA_TYPE: str +FDW_COLUMN_NAME_NOT_FOUND: str +FDW_INVALID_DATA_TYPE_DESCRIPTORS: str +FDW_INVALID_COLUMN_NAME: str +FDW_INVALID_COLUMN_NUMBER: str +FDW_INVALID_USE_OF_NULL_POINTER: str +FDW_INVALID_STRING_FORMAT: str +FDW_INVALID_HANDLE: str +FDW_INVALID_OPTION_INDEX: str +FDW_INVALID_OPTION_NAME: str +FDW_OPTION_NAME_NOT_FOUND: str +FDW_REPLY_HANDLE: str +FDW_UNABLE_TO_CREATE_EXECUTION: str +FDW_UNABLE_TO_CREATE_REPLY: str +FDW_UNABLE_TO_ESTABLISH_CONNECTION: str +FDW_NO_SCHEMAS: str +FDW_SCHEMA_NOT_FOUND: str +FDW_TABLE_NOT_FOUND: str +FDW_FUNCTION_SEQUENCE_ERROR: str +FDW_TOO_MANY_HANDLES: str +FDW_INCONSISTENT_DESCRIPTOR_INFORMATION: str +FDW_INVALID_ATTRIBUTE_VALUE: str +FDW_INVALID_STRING_LENGTH_OR_BUFFER_LENGTH: str +FDW_INVALID_DESCRIPTOR_FIELD_IDENTIFIER: str +PLPGSQL_ERROR: str +RAISE_EXCEPTION: str +NO_DATA_FOUND: str +TOO_MANY_ROWS: str +ASSERT_FAILURE: str +INTERNAL_ERROR: str +DATA_CORRUPTED: str +INDEX_CORRUPTED: str diff --git a/typings/psycopg2/errors.pyi b/typings/psycopg2/errors.pyi new file mode 100644 index 00000000..bf7d3579 --- /dev/null +++ b/typings/psycopg2/errors.pyi @@ -0,0 +1,263 @@ +from psycopg2._psycopg import Error as Error, Warning as Warning + +class DatabaseError(Error): ... +class InterfaceError(Error): ... +class DataError(DatabaseError): ... +class DiagnosticsException(DatabaseError): ... +class IntegrityError(DatabaseError): ... +class InternalError(DatabaseError): ... +class InvalidGrantOperation(DatabaseError): ... +class InvalidGrantor(DatabaseError): ... +class InvalidLocatorSpecification(DatabaseError): ... +class InvalidRoleSpecification(DatabaseError): ... +class InvalidTransactionInitiation(DatabaseError): ... +class LocatorException(DatabaseError): ... +class NoAdditionalDynamicResultSetsReturned(DatabaseError): ... +class NoData(DatabaseError): ... +class NotSupportedError(DatabaseError): ... +class OperationalError(DatabaseError): ... +class ProgrammingError(DatabaseError): ... +class SnapshotTooOld(DatabaseError): ... +class SqlStatementNotYetComplete(DatabaseError): ... +class StackedDiagnosticsAccessedWithoutActiveHandler(DatabaseError): ... +class TriggeredActionException(DatabaseError): ... +class ActiveSqlTransaction(InternalError): ... +class AdminShutdown(OperationalError): ... +class AmbiguousAlias(ProgrammingError): ... +class AmbiguousColumn(ProgrammingError): ... +class AmbiguousFunction(ProgrammingError): ... +class AmbiguousParameter(ProgrammingError): ... +class ArraySubscriptError(DataError): ... +class AssertFailure(InternalError): ... +class BadCopyFileFormat(DataError): ... +class BranchTransactionAlreadyActive(InternalError): ... +class CannotCoerce(ProgrammingError): ... +class CannotConnectNow(OperationalError): ... +class CantChangeRuntimeParam(OperationalError): ... +class CardinalityViolation(ProgrammingError): ... +class CaseNotFound(ProgrammingError): ... +class CharacterNotInRepertoire(DataError): ... +class CheckViolation(IntegrityError): ... +class CollationMismatch(ProgrammingError): ... +class ConfigFileError(InternalError): ... +class ConfigurationLimitExceeded(OperationalError): ... +class ConnectionDoesNotExist(OperationalError): ... +class ConnectionException(OperationalError): ... +class ConnectionFailure(OperationalError): ... +class ContainingSqlNotPermitted(InternalError): ... +class CrashShutdown(OperationalError): ... +class DataCorrupted(InternalError): ... +class DataException(DataError): ... +class DatabaseDropped(OperationalError): ... +class DatatypeMismatch(ProgrammingError): ... +class DatetimeFieldOverflow(DataError): ... +class DependentObjectsStillExist(InternalError): ... +class DependentPrivilegeDescriptorsStillExist(InternalError): ... +class DiskFull(OperationalError): ... +class DivisionByZero(DataError): ... +class DuplicateAlias(ProgrammingError): ... +class DuplicateColumn(ProgrammingError): ... +class DuplicateCursor(ProgrammingError): ... +class DuplicateDatabase(ProgrammingError): ... +class DuplicateFile(OperationalError): ... +class DuplicateFunction(ProgrammingError): ... +class DuplicateJsonObjectKeyValue(DataError): ... +class DuplicateObject(ProgrammingError): ... +class DuplicatePreparedStatement(ProgrammingError): ... +class DuplicateSchema(ProgrammingError): ... +class DuplicateTable(ProgrammingError): ... +class ErrorInAssignment(DataError): ... +class EscapeCharacterConflict(DataError): ... +class EventTriggerProtocolViolated(InternalError): ... +class ExclusionViolation(IntegrityError): ... +class ExternalRoutineException(InternalError): ... +class ExternalRoutineInvocationException(InternalError): ... +class FdwColumnNameNotFound(OperationalError): ... +class FdwDynamicParameterValueNeeded(OperationalError): ... +class FdwError(OperationalError): ... +class FdwFunctionSequenceError(OperationalError): ... +class FdwInconsistentDescriptorInformation(OperationalError): ... +class FdwInvalidAttributeValue(OperationalError): ... +class FdwInvalidColumnName(OperationalError): ... +class FdwInvalidColumnNumber(OperationalError): ... +class FdwInvalidDataType(OperationalError): ... +class FdwInvalidDataTypeDescriptors(OperationalError): ... +class FdwInvalidDescriptorFieldIdentifier(OperationalError): ... +class FdwInvalidHandle(OperationalError): ... +class FdwInvalidOptionIndex(OperationalError): ... +class FdwInvalidOptionName(OperationalError): ... +class FdwInvalidStringFormat(OperationalError): ... +class FdwInvalidStringLengthOrBufferLength(OperationalError): ... +class FdwInvalidUseOfNullPointer(OperationalError): ... +class FdwNoSchemas(OperationalError): ... +class FdwOptionNameNotFound(OperationalError): ... +class FdwOutOfMemory(OperationalError): ... +class FdwReplyHandle(OperationalError): ... +class FdwSchemaNotFound(OperationalError): ... +class FdwTableNotFound(OperationalError): ... +class FdwTooManyHandles(OperationalError): ... +class FdwUnableToCreateExecution(OperationalError): ... +class FdwUnableToCreateReply(OperationalError): ... +class FdwUnableToEstablishConnection(OperationalError): ... +class FeatureNotSupported(NotSupportedError): ... +class FloatingPointException(DataError): ... +class ForeignKeyViolation(IntegrityError): ... +class FunctionExecutedNoReturnStatement(InternalError): ... +class GeneratedAlways(ProgrammingError): ... +class GroupingError(ProgrammingError): ... +class HeldCursorRequiresSameIsolationLevel(InternalError): ... +class IdleInTransactionSessionTimeout(InternalError): ... +class InFailedSqlTransaction(InternalError): ... +class InappropriateAccessModeForBranchTransaction(InternalError): ... +class InappropriateIsolationLevelForBranchTransaction(InternalError): ... +class IndeterminateCollation(ProgrammingError): ... +class IndeterminateDatatype(ProgrammingError): ... +class IndexCorrupted(InternalError): ... +class IndicatorOverflow(DataError): ... +class InsufficientPrivilege(ProgrammingError): ... +class InsufficientResources(OperationalError): ... +class IntegrityConstraintViolation(IntegrityError): ... +class InternalError_(InternalError): ... +class IntervalFieldOverflow(DataError): ... +class InvalidArgumentForLogarithm(DataError): ... +class InvalidArgumentForNthValueFunction(DataError): ... +class InvalidArgumentForNtileFunction(DataError): ... +class InvalidArgumentForPowerFunction(DataError): ... +class InvalidArgumentForSqlJsonDatetimeFunction(DataError): ... +class InvalidArgumentForWidthBucketFunction(DataError): ... +class InvalidAuthorizationSpecification(OperationalError): ... +class InvalidBinaryRepresentation(DataError): ... +class InvalidCatalogName(ProgrammingError): ... +class InvalidCharacterValueForCast(DataError): ... +class InvalidColumnDefinition(ProgrammingError): ... +class InvalidColumnReference(ProgrammingError): ... +class InvalidCursorDefinition(ProgrammingError): ... +class InvalidCursorName(OperationalError): ... +class InvalidCursorState(InternalError): ... +class InvalidDatabaseDefinition(ProgrammingError): ... +class InvalidDatetimeFormat(DataError): ... +class InvalidEscapeCharacter(DataError): ... +class InvalidEscapeOctet(DataError): ... +class InvalidEscapeSequence(DataError): ... +class InvalidForeignKey(ProgrammingError): ... +class InvalidFunctionDefinition(ProgrammingError): ... +class InvalidIndicatorParameterValue(DataError): ... +class InvalidJsonText(DataError): ... +class InvalidName(ProgrammingError): ... +class InvalidObjectDefinition(ProgrammingError): ... +class InvalidParameterValue(DataError): ... +class InvalidPassword(OperationalError): ... +class InvalidPrecedingOrFollowingSize(DataError): ... +class InvalidPreparedStatementDefinition(ProgrammingError): ... +class InvalidRecursion(ProgrammingError): ... +class InvalidRegularExpression(DataError): ... +class InvalidRowCountInLimitClause(DataError): ... +class InvalidRowCountInResultOffsetClause(DataError): ... +class InvalidSavepointSpecification(InternalError): ... +class InvalidSchemaDefinition(ProgrammingError): ... +class InvalidSchemaName(ProgrammingError): ... +class InvalidSqlJsonSubscript(DataError): ... +class InvalidSqlStatementName(OperationalError): ... +class InvalidSqlstateReturned(InternalError): ... +class InvalidTableDefinition(ProgrammingError): ... +class InvalidTablesampleArgument(DataError): ... +class InvalidTablesampleRepeat(DataError): ... +class InvalidTextRepresentation(DataError): ... +class InvalidTimeZoneDisplacementValue(DataError): ... +class InvalidTransactionState(InternalError): ... +class InvalidTransactionTermination(InternalError): ... +class InvalidUseOfEscapeCharacter(DataError): ... +class InvalidXmlComment(DataError): ... +class InvalidXmlContent(DataError): ... +class InvalidXmlDocument(DataError): ... +class InvalidXmlProcessingInstruction(DataError): ... +class IoError(OperationalError): ... +class LockFileExists(InternalError): ... +class LockNotAvailable(OperationalError): ... +class ModifyingSqlDataNotPermitted(InternalError): ... +class ModifyingSqlDataNotPermittedExt(InternalError): ... +class MoreThanOneSqlJsonItem(DataError): ... +class MostSpecificTypeMismatch(DataError): ... +class NameTooLong(ProgrammingError): ... +class NoActiveSqlTransaction(InternalError): ... +class NoActiveSqlTransactionForBranchTransaction(InternalError): ... +class NoDataFound(InternalError): ... +class NoSqlJsonItem(DataError): ... +class NonNumericSqlJsonItem(DataError): ... +class NonUniqueKeysInAJsonObject(DataError): ... +class NonstandardUseOfEscapeCharacter(DataError): ... +class NotAnXmlDocument(DataError): ... +class NotNullViolation(IntegrityError): ... +class NullValueNoIndicatorParameter(DataError): ... +class NullValueNotAllowed(DataError): ... +class NullValueNotAllowedExt(InternalError): ... +class NumericValueOutOfRange(DataError): ... +class ObjectInUse(OperationalError): ... +class ObjectNotInPrerequisiteState(OperationalError): ... +class OperatorIntervention(OperationalError): ... +class OutOfMemory(OperationalError): ... +class PlpgsqlError(InternalError): ... +class ProgramLimitExceeded(OperationalError): ... +class ProhibitedSqlStatementAttempted(InternalError): ... +class ProhibitedSqlStatementAttemptedExt(InternalError): ... +class ProtocolViolation(OperationalError): ... +class QueryCanceledError(OperationalError): ... +class RaiseException(InternalError): ... +class ReadOnlySqlTransaction(InternalError): ... +class ReadingSqlDataNotPermitted(InternalError): ... +class ReadingSqlDataNotPermittedExt(InternalError): ... +class ReservedName(ProgrammingError): ... +class RestrictViolation(IntegrityError): ... +class SavepointException(InternalError): ... +class SchemaAndDataStatementMixingNotSupported(InternalError): ... +class SequenceGeneratorLimitExceeded(DataError): ... +class SingletonSqlJsonItemRequired(DataError): ... +class SqlJsonArrayNotFound(DataError): ... +class SqlJsonMemberNotFound(DataError): ... +class SqlJsonNumberNotFound(DataError): ... +class SqlJsonObjectNotFound(DataError): ... +class SqlJsonScalarRequired(DataError): ... +class SqlRoutineException(InternalError): ... +class SqlclientUnableToEstablishSqlconnection(OperationalError): ... +class SqlserverRejectedEstablishmentOfSqlconnection(OperationalError): ... +class SrfProtocolViolated(InternalError): ... +class StatementTooComplex(OperationalError): ... +class StringDataLengthMismatch(DataError): ... +class StringDataRightTruncation(DataError): ... +class SubstringError(DataError): ... +class SyntaxError(ProgrammingError): ... +class SyntaxErrorOrAccessRuleViolation(ProgrammingError): ... +class SystemError(OperationalError): ... +class TooManyArguments(OperationalError): ... +class TooManyColumns(OperationalError): ... +class TooManyConnections(OperationalError): ... +class TooManyJsonArrayElements(DataError): ... +class TooManyJsonObjectMembers(DataError): ... +class TooManyRows(InternalError): ... +class TransactionResolutionUnknown(OperationalError): ... +class TransactionRollbackError(OperationalError): ... +class TriggerProtocolViolated(InternalError): ... +class TriggeredDataChangeViolation(OperationalError): ... +class TrimError(DataError): ... +class UndefinedColumn(ProgrammingError): ... +class UndefinedFile(OperationalError): ... +class UndefinedFunction(ProgrammingError): ... +class UndefinedObject(ProgrammingError): ... +class UndefinedParameter(ProgrammingError): ... +class UndefinedTable(ProgrammingError): ... +class UniqueViolation(IntegrityError): ... +class UnsafeNewEnumValueUsage(OperationalError): ... +class UnterminatedCString(DataError): ... +class UntranslatableCharacter(DataError): ... +class WindowingError(ProgrammingError): ... +class WithCheckOptionViolation(ProgrammingError): ... +class WrongObjectType(ProgrammingError): ... +class ZeroLengthCharacterString(DataError): ... +class DeadlockDetected(TransactionRollbackError): ... +class QueryCanceled(QueryCanceledError): ... +class SerializationFailure(TransactionRollbackError): ... +class StatementCompletionUnknown(TransactionRollbackError): ... +class TransactionIntegrityConstraintViolation(TransactionRollbackError): ... +class TransactionRollback(TransactionRollbackError): ... + +def lookup(code): ... diff --git a/typings/psycopg2/extensions.pyi b/typings/psycopg2/extensions.pyi new file mode 100644 index 00000000..b0789a23 --- /dev/null +++ b/typings/psycopg2/extensions.pyi @@ -0,0 +1,117 @@ +from _typeshed import Incomplete +from typing import Any + +from psycopg2._psycopg import ( + BINARYARRAY as BINARYARRAY, + BOOLEAN as BOOLEAN, + BOOLEANARRAY as BOOLEANARRAY, + BYTES as BYTES, + BYTESARRAY as BYTESARRAY, + DATE as DATE, + DATEARRAY as DATEARRAY, + DATETIMEARRAY as DATETIMEARRAY, + DECIMAL as DECIMAL, + DECIMALARRAY as DECIMALARRAY, + FLOAT as FLOAT, + FLOATARRAY as FLOATARRAY, + INTEGER as INTEGER, + INTEGERARRAY as INTEGERARRAY, + INTERVAL as INTERVAL, + INTERVALARRAY as INTERVALARRAY, + LONGINTEGER as LONGINTEGER, + LONGINTEGERARRAY as LONGINTEGERARRAY, + PYDATE as PYDATE, + PYDATEARRAY as PYDATEARRAY, + PYDATETIME as PYDATETIME, + PYDATETIMEARRAY as PYDATETIMEARRAY, + PYDATETIMETZ as PYDATETIMETZ, + PYDATETIMETZARRAY as PYDATETIMETZARRAY, + PYINTERVAL as PYINTERVAL, + PYINTERVALARRAY as PYINTERVALARRAY, + PYTIME as PYTIME, + PYTIMEARRAY as PYTIMEARRAY, + ROWIDARRAY as ROWIDARRAY, + STRINGARRAY as STRINGARRAY, + TIME as TIME, + TIMEARRAY as TIMEARRAY, + UNICODE as UNICODE, + UNICODEARRAY as UNICODEARRAY, + AsIs as AsIs, + Binary as Binary, + Boolean as Boolean, + Column as Column, + ConnectionInfo as ConnectionInfo, + DateFromPy as DateFromPy, + Diagnostics as Diagnostics, + Float as Float, + Int as Int, + IntervalFromPy as IntervalFromPy, + ISQLQuote as ISQLQuote, + Notify as Notify, + QueryCanceledError as QueryCanceledError, + QuotedString as QuotedString, + TimeFromPy as TimeFromPy, + TimestampFromPy as TimestampFromPy, + TransactionRollbackError as TransactionRollbackError, + Xid as Xid, + adapt as adapt, + adapters as adapters, + binary_types as binary_types, + connection as connection, + cursor as cursor, + encodings as encodings, + encrypt_password as encrypt_password, + get_wait_callback as get_wait_callback, + libpq_version as libpq_version, + lobject as lobject, + new_array_type as new_array_type, + new_type as new_type, + parse_dsn as parse_dsn, + quote_ident as quote_ident, + register_type as register_type, + set_wait_callback as set_wait_callback, + string_types as string_types, +) + +ISOLATION_LEVEL_AUTOCOMMIT: int +ISOLATION_LEVEL_READ_UNCOMMITTED: int +ISOLATION_LEVEL_READ_COMMITTED: int +ISOLATION_LEVEL_REPEATABLE_READ: int +ISOLATION_LEVEL_SERIALIZABLE: int +ISOLATION_LEVEL_DEFAULT: Any +STATUS_SETUP: int +STATUS_READY: int +STATUS_BEGIN: int +STATUS_SYNC: int +STATUS_ASYNC: int +STATUS_PREPARED: int +STATUS_IN_TRANSACTION: int +POLL_OK: int +POLL_READ: int +POLL_WRITE: int +POLL_ERROR: int +TRANSACTION_STATUS_IDLE: int +TRANSACTION_STATUS_ACTIVE: int +TRANSACTION_STATUS_INTRANS: int +TRANSACTION_STATUS_INERROR: int +TRANSACTION_STATUS_UNKNOWN: int + +def register_adapter(typ, callable) -> None: ... + +class SQL_IN: + def __init__(self, seq) -> None: ... + def prepare(self, conn) -> None: ... + def getquoted(self): ... + +class NoneAdapter: + def __init__(self, obj) -> None: ... + def getquoted(self, _null: bytes = b"NULL"): ... + +def make_dsn(dsn: Incomplete | None = None, **kwargs): ... + +JSON: Any +JSONARRAY: Any +JSONB: Any +JSONBARRAY: Any + +def adapt(obj: Any) -> ISQLQuote: ... diff --git a/typings/psycopg2/extras.pyi b/typings/psycopg2/extras.pyi new file mode 100644 index 00000000..5c809e51 --- /dev/null +++ b/typings/psycopg2/extras.pyi @@ -0,0 +1,240 @@ +from _typeshed import Incomplete +from collections import OrderedDict +from collections.abc import Callable +from typing import Any, NamedTuple, TypeVar, overload + +from psycopg2._ipaddress import register_ipaddress as register_ipaddress +from psycopg2._json import ( + Json as Json, + register_default_json as register_default_json, + register_default_jsonb as register_default_jsonb, + register_json as register_json, +) +from psycopg2._psycopg import ( + REPLICATION_LOGICAL as REPLICATION_LOGICAL, + REPLICATION_PHYSICAL as REPLICATION_PHYSICAL, + ReplicationConnection as _replicationConnection, + ReplicationCursor as _replicationCursor, + ReplicationMessage as ReplicationMessage, +) +from psycopg2._range import ( + DateRange as DateRange, + DateTimeRange as DateTimeRange, + DateTimeTZRange as DateTimeTZRange, + NumericRange as NumericRange, + Range as Range, + RangeAdapter as RangeAdapter, + RangeCaster as RangeCaster, + register_range as register_range, +) + +from .extensions import connection as _connection, cursor as _cursor, quote_ident as quote_ident + +_T_cur = TypeVar("_T_cur", bound=_cursor) + +class DictCursorBase(_cursor): + def __init__(self, *args, **kwargs) -> None: ... + +class DictConnection(_connection): + @overload + def cursor(self, name: str | bytes | None = ..., *, withhold: bool = ..., scrollable: bool | None = ...) -> DictCursor: ... + @overload + def cursor( + self, + name: str | bytes | None = ..., + *, + cursor_factory: Callable[..., _T_cur], + withhold: bool = ..., + scrollable: bool | None = ..., + ) -> _T_cur: ... + @overload + def cursor( + self, name: str | bytes | None, cursor_factory: Callable[..., _T_cur], withhold: bool = ..., scrollable: bool | None = ... + ) -> _T_cur: ... + +class DictCursor(DictCursorBase): + def __init__(self, *args, **kwargs) -> None: ... + index: Any + def execute(self, query, vars: Incomplete | None = None): ... + def callproc(self, procname, vars: Incomplete | None = None): ... + def fetchone(self) -> DictRow | None: ... # type: ignore[override] + def fetchmany(self, size: int | None = None) -> list[DictRow]: ... # type: ignore[override] + def fetchall(self) -> list[DictRow]: ... # type: ignore[override] + def __next__(self) -> DictRow: ... # type: ignore[override] + +class DictRow(list[Any]): + def __init__(self, cursor) -> None: ... + def __getitem__(self, x): ... + def __setitem__(self, x, v) -> None: ... + def items(self): ... + def keys(self): ... + def values(self): ... + def get(self, x, default: Incomplete | None = None): ... + def copy(self): ... + def __contains__(self, x): ... + def __reduce__(self): ... + +class RealDictConnection(_connection): + @overload + def cursor( + self, name: str | bytes | None = ..., *, withhold: bool = ..., scrollable: bool | None = ... + ) -> RealDictCursor: ... + @overload + def cursor( + self, + name: str | bytes | None = ..., + *, + cursor_factory: Callable[..., _T_cur], + withhold: bool = ..., + scrollable: bool | None = ..., + ) -> _T_cur: ... + @overload + def cursor( + self, name: str | bytes | None, cursor_factory: Callable[..., _T_cur], withhold: bool = ..., scrollable: bool | None = ... + ) -> _T_cur: ... + +class RealDictCursor(DictCursorBase): + def __init__(self, *args, **kwargs) -> None: ... + column_mapping: Any + def execute(self, query, vars: Incomplete | None = None): ... + def callproc(self, procname, vars: Incomplete | None = None): ... + def fetchone(self) -> RealDictRow | None: ... # type: ignore[override] + def fetchmany(self, size: int | None = None) -> list[RealDictRow]: ... # type: ignore[override] + def fetchall(self) -> list[RealDictRow]: ... # type: ignore[override] + def __next__(self) -> RealDictRow: ... # type: ignore[override] + +class RealDictRow(OrderedDict[Any, Any]): + def __init__(self, *args, **kwargs) -> None: ... + def __setitem__(self, key, value) -> None: ... + +class NamedTupleConnection(_connection): + @overload + def cursor( + self, name: str | bytes | None = ..., *, withhold: bool = ..., scrollable: bool | None = ... + ) -> NamedTupleCursor: ... + @overload + def cursor( + self, + name: str | bytes | None = ..., + *, + cursor_factory: Callable[..., _T_cur], + withhold: bool = ..., + scrollable: bool | None = ..., + ) -> _T_cur: ... + @overload + def cursor( + self, name: str | bytes | None, cursor_factory: Callable[..., _T_cur], withhold: bool = ..., scrollable: bool | None = ... + ) -> _T_cur: ... + +class NamedTupleCursor(_cursor): + Record: Any + MAX_CACHE: int + def execute(self, query, vars: Incomplete | None = None): ... + def executemany(self, query, vars): ... + def callproc(self, procname, vars: Incomplete | None = None): ... + def fetchone(self) -> NamedTuple | None: ... + def fetchmany(self, size: int | None = None) -> list[NamedTuple]: ... # type: ignore[override] + def fetchall(self) -> list[NamedTuple]: ... # type: ignore[override] + def __next__(self) -> NamedTuple: ... + +class LoggingConnection(_connection): + log: Any + def initialize(self, logobj) -> None: ... + def filter(self, msg, curs): ... + def cursor(self, *args, **kwargs): ... + +class LoggingCursor(_cursor): + def execute(self, query, vars: Incomplete | None = None): ... + def callproc(self, procname, vars: Incomplete | None = None): ... + +class MinTimeLoggingConnection(LoggingConnection): + def initialize(self, logobj, mintime: int = 0) -> None: ... + def filter(self, msg, curs): ... + def cursor(self, *args, **kwargs): ... + +class MinTimeLoggingCursor(LoggingCursor): + timestamp: Any + def execute(self, query, vars: Incomplete | None = None): ... + def callproc(self, procname, vars: Incomplete | None = None): ... + +class LogicalReplicationConnection(_replicationConnection): + def __init__(self, *args, **kwargs) -> None: ... + +class PhysicalReplicationConnection(_replicationConnection): + def __init__(self, *args, **kwargs) -> None: ... + +class StopReplication(Exception): ... + +class ReplicationCursor(_replicationCursor): + def create_replication_slot( + self, slot_name, slot_type: Incomplete | None = None, output_plugin: Incomplete | None = None + ) -> None: ... + def drop_replication_slot(self, slot_name) -> None: ... + def start_replication( + self, + slot_name: Incomplete | None = None, + slot_type: Incomplete | None = None, + start_lsn: int = 0, + timeline: int = 0, + options: Incomplete | None = None, + decode: bool = False, + status_interval: int = 10, + ) -> None: ... + def fileno(self): ... + +class UUID_adapter: + def __init__(self, uuid) -> None: ... + def __conform__(self, proto): ... + def getquoted(self): ... + +def register_uuid(oids: Incomplete | None = None, conn_or_curs: Incomplete | None = None): ... + +class Inet: + addr: Any + def __init__(self, addr) -> None: ... + def prepare(self, conn) -> None: ... + def getquoted(self): ... + def __conform__(self, proto): ... + +def register_inet(oid: Incomplete | None = None, conn_or_curs: Incomplete | None = None): ... +def wait_select(conn) -> None: ... + +class HstoreAdapter: + wrapped: Any + def __init__(self, wrapped) -> None: ... + conn: Any + getquoted: Any + def prepare(self, conn) -> None: ... + @classmethod + def parse(cls, s, cur, _bsdec=...): ... + @classmethod + def parse_unicode(cls, s, cur): ... + @classmethod + def get_oids(cls, conn_or_curs): ... + +def register_hstore( + conn_or_curs, + globally: bool = False, + unicode: bool = False, + oid: Incomplete | None = None, + array_oid: Incomplete | None = None, +) -> None: ... + +class CompositeCaster: + name: Any + schema: Any + oid: Any + array_oid: Any + attnames: Any + atttypes: Any + typecaster: Any + array_typecaster: Any + def __init__(self, name, oid, attrs, array_oid: Incomplete | None = None, schema: Incomplete | None = None) -> None: ... + def parse(self, s, curs): ... + def make(self, values): ... + @classmethod + def tokenize(cls, s): ... + +def register_composite(name, conn_or_curs, globally: bool = False, factory: Incomplete | None = None): ... +def execute_batch(cur, sql, argslist, page_size: int = 100) -> None: ... +def execute_values(cur, sql, argslist, template: Incomplete | None = None, page_size: int = 100, fetch: bool = False): ... diff --git a/typings/psycopg2/pool.pyi b/typings/psycopg2/pool.pyi new file mode 100644 index 00000000..d380aa19 --- /dev/null +++ b/typings/psycopg2/pool.pyi @@ -0,0 +1,24 @@ +from _typeshed import Incomplete +from typing import Any + +import psycopg2 + +class PoolError(psycopg2.Error): ... + +class AbstractConnectionPool: + minconn: Any + maxconn: Any + closed: bool + def __init__(self, minconn, maxconn, *args, **kwargs) -> None: ... + # getconn, putconn and closeall are officially documented as methods of the + # abstract base class, but in reality, they only exist on the children classes + def getconn(self, key: Incomplete | None = ...): ... + def putconn(self, conn: Any, key: Incomplete | None = ..., close: bool = ...) -> None: ... + def closeall(self) -> None: ... + +class SimpleConnectionPool(AbstractConnectionPool): ... + +class ThreadedConnectionPool(AbstractConnectionPool): + # This subclass has a default value for conn which doesn't exist + # in the SimpleConnectionPool class, nor in the documentation + def putconn(self, conn: Incomplete | None = None, key: Incomplete | None = None, close: bool = False) -> None: ... diff --git a/typings/psycopg2/sql.pyi b/typings/psycopg2/sql.pyi new file mode 100644 index 00000000..5721e9f4 --- /dev/null +++ b/typings/psycopg2/sql.pyi @@ -0,0 +1,50 @@ +from _typeshed import Incomplete +from collections.abc import Iterator +from typing import Any + +class Composable: + def __init__(self, wrapped) -> None: ... + def as_string(self, context) -> str: ... + def __add__(self, other) -> Composed: ... + def __mul__(self, n) -> Composed: ... + def __eq__(self, other) -> bool: ... + def __ne__(self, other) -> bool: ... + +class Composed(Composable): + def __init__(self, seq) -> None: ... + @property + def seq(self) -> list[Composable]: ... + def as_string(self, context) -> str: ... + def __iter__(self) -> Iterator[Composable]: ... + def __add__(self, other) -> Composed: ... + def join(self, joiner) -> Composed: ... + +class SQL(Composable): + def __init__(self, string) -> None: ... + @property + def string(self) -> str: ... + def as_string(self, context) -> str: ... + def format(self, *args, **kwargs) -> Composed: ... + def join(self, seq) -> Composed: ... + +class Identifier(Composable): + def __init__(self, *strings) -> None: ... + @property + def strings(self) -> tuple[str, ...]: ... + @property + def string(self) -> str: ... + def as_string(self, context) -> str: ... + +class Literal(Composable): + @property + def wrapped(self): ... + def as_string(self, context) -> str: ... + +class Placeholder(Composable): + def __init__(self, name: Incomplete | None = None) -> None: ... + @property + def name(self) -> str | None: ... + def as_string(self, context) -> str: ... + +NULL: Any +DEFAULT: Any diff --git a/typings/psycopg2/tz.pyi b/typings/psycopg2/tz.pyi new file mode 100644 index 00000000..5095ae74 --- /dev/null +++ b/typings/psycopg2/tz.pyi @@ -0,0 +1,26 @@ +import datetime +from _typeshed import Incomplete +from typing import Any + +ZERO: Any + +class FixedOffsetTimezone(datetime.tzinfo): + def __init__(self, offset: Incomplete | None = None, name: Incomplete | None = None) -> None: ... + def __new__(cls, offset: Incomplete | None = None, name: Incomplete | None = None): ... + def __eq__(self, other): ... + def __ne__(self, other): ... + def __getinitargs__(self): ... + def utcoffset(self, dt): ... + def tzname(self, dt): ... + def dst(self, dt): ... + +STDOFFSET: Any +DSTOFFSET: Any +DSTDIFF: Any + +class LocalTimezone(datetime.tzinfo): + def utcoffset(self, dt): ... + def dst(self, dt): ... + def tzname(self, dt): ... + +LOCAL: Any diff --git a/typings/pysyncobj/__init__.pyi b/typings/pysyncobj/__init__.pyi new file mode 100644 index 00000000..f31daa0e --- /dev/null +++ b/typings/pysyncobj/__init__.pyi @@ -0,0 +1,2 @@ +from .syncobj import FAIL_REASON, SyncObj, SyncObjConf, replicated +__all__ = ['SyncObj', 'SyncObjConf', 'replicated', 'FAIL_REASON'] diff --git a/typings/pysyncobj/config.pyi b/typings/pysyncobj/config.pyi new file mode 100644 index 00000000..e9141f22 --- /dev/null +++ b/typings/pysyncobj/config.pyi @@ -0,0 +1,13 @@ +from typing import Optional +class FAIL_REASON: + SUCCESS = ... + QUEUE_FULL = ... + MISSING_LEADER = ... + DISCARDED = ... + NOT_LEADER = ... + LEADER_CHANGED = ... + REQUEST_DENIED = ... +class SyncObjConf: + password: Optional[str] + autoTickPeriod: int + def __init__(self, **kwargs) -> None: ... diff --git a/typings/pysyncobj/dns_resolver.pyi b/typings/pysyncobj/dns_resolver.pyi new file mode 100644 index 00000000..5fd9a13f --- /dev/null +++ b/typings/pysyncobj/dns_resolver.pyi @@ -0,0 +1,5 @@ +from typing import Optional +class DnsCachingResolver: + def setTimeouts(self, cacheTime: float, failCacheTime: float) -> None: ... + def resolve(self, hostname: str) -> Optional[str]: ... +def globalDnsResolver() -> DnsCachingResolver: ... diff --git a/typings/pysyncobj/node.pyi b/typings/pysyncobj/node.pyi new file mode 100644 index 00000000..032ce75d --- /dev/null +++ b/typings/pysyncobj/node.pyi @@ -0,0 +1,6 @@ +class Node: + @property + def id(self) -> str: ... +class TCPNode(Node): + @property + def host(self) -> str: ... diff --git a/typings/pysyncobj/syncobj.pyi b/typings/pysyncobj/syncobj.pyi new file mode 100644 index 00000000..9261f3bc --- /dev/null +++ b/typings/pysyncobj/syncobj.pyi @@ -0,0 +1,24 @@ +from typing import Any, Callable, Collection, List, Optional, Set, Type +from .config import FAIL_REASON, SyncObjConf +from .node import Node +from .transport import Transport +__all__ = ['FAIL_REASON', 'SyncObj', 'SyncObjConf', 'replicated'] +class SyncObj: + def __init__(self, selfNode: Optional[str], otherNodes: Collection[str], conf: SyncObjConf=..., consumers=..., nodeClass=..., transport=..., transportClass: Type[Transport]=...) -> None: ... + def destroy(self) -> None: ... + def doTick(self, timeToWait: float = 0.0) -> None: ... + def isNodeConnected(self, node: Node) -> bool: ... + @property + def selfNode(self) -> Node: ... + @property + def otherNodes(self) -> Set[Node]: ... + @property + def raftLastApplied(self) -> int: ... + @property + def raftCommitIndex(self) -> int: ... + @property + def conf(self) -> SyncObjConf: ... + def _getLeader(self) -> Optional[Node]: ... + def _isLeader(self) -> bool: ... + def _onTick(self, timeToWait: float = 0.0) -> None: ... +def replicated(*decArgs: Any, **decKwargs: Any) -> Callable[..., Any]: ... diff --git a/typings/pysyncobj/tcp_connection.pyi b/typings/pysyncobj/tcp_connection.pyi new file mode 100644 index 00000000..d0a1777d --- /dev/null +++ b/typings/pysyncobj/tcp_connection.pyi @@ -0,0 +1,4 @@ +class CONNECTION_STATE: + DISCONNECTED = ... + CONNECTING = ... + CONNECTED = ... diff --git a/typings/pysyncobj/transport.pyi b/typings/pysyncobj/transport.pyi new file mode 100644 index 00000000..3aafa77e --- /dev/null +++ b/typings/pysyncobj/transport.pyi @@ -0,0 +1,10 @@ +from typing import Any, Callable, Collection, Optional +from .node import TCPNode +from .syncobj import SyncObj +from .tcp_connection import CONNECTION_STATE +__all__ = ['CONNECTION_STATE', 'TCPTransport'] +class Transport: + def setOnUtilityMessageCallback(self, message: str, callback: Callable[[Any, Callable[..., Any]], Any]) -> None: ... +class TCPTransport(Transport): + def __init__(self, syncObj: SyncObj, selfNode: Optional[TCPNode], otherNodes: Collection[TCPNode]) -> None: ... + def _connectIfNecessarySingle(self, node: TCPNode) -> bool: ... diff --git a/typings/pysyncobj/utility.pyi b/typings/pysyncobj/utility.pyi new file mode 100644 index 00000000..86f331c7 --- /dev/null +++ b/typings/pysyncobj/utility.pyi @@ -0,0 +1,5 @@ +from typing import Any, List, Optional, Union +from .node import TCPNode +class TcpUtility(Utility): + def __init__(self, password: Optional[str] = None, timeout: float=900.0) -> None: ... + def executeCommand(self, node: Union[str, TCPNode], command: List[Any]) -> Any: ... diff --git a/typings/urllib3/__init__.pyi b/typings/urllib3/__init__.pyi new file mode 100644 index 00000000..ff573518 --- /dev/null +++ b/typings/urllib3/__init__.pyi @@ -0,0 +1,6 @@ +from .poolmanager import PoolManager +from .response import HTTPResponse +from .util.request import make_headers +from .util.timeout import Timeout + +__all__ = ['HTTPResponse', 'PoolManager', 'Timeout', 'make_headers'] diff --git a/typings/urllib3/_collections.pyi b/typings/urllib3/_collections.pyi new file mode 100644 index 00000000..b7a8cebe --- /dev/null +++ b/typings/urllib3/_collections.pyi @@ -0,0 +1,28 @@ +from typing import Any +class HTTPHeaderDict(MutableMapping[str, str]): + def __init__(self, headers=None, **kwargs) -> None: ... + def __setitem__(self, key, val) -> None: ... + def __getitem__(self, key): ... + def __delitem__(self, key) -> None: ... + def __contains__(self, key): ... + def __eq__(self, other): ... + def __iter__(self) -> NoReturn: ... + def __len__(self) -> int: ... + def __ne__(self, other): ... + values: Any + get: Any + update: Any + iterkeys: Any + itervalues: Any + def pop(self, key, default=...): ... + def discard(self, key): ... + def add(self, key, val): ... + def extend(self, *args, **kwargs): ... + def getlist(self, key): ... + getheaders: Any + getallmatchingheaders: Any + iget: Any + def copy(self): ... + def iteritems(self): ... + def itermerged(self): ... + def items(self): ... diff --git a/typings/urllib3/connection.pyi b/typings/urllib3/connection.pyi new file mode 100644 index 00000000..06b265db --- /dev/null +++ b/typings/urllib3/connection.pyi @@ -0,0 +1,2 @@ +from http.client import HTTPConnection as _HTTPConnection +class HTTPConnection(_HTTPConnection): ... diff --git a/typings/urllib3/poolmanager.pyi b/typings/urllib3/poolmanager.pyi new file mode 100644 index 00000000..5ea41a68 --- /dev/null +++ b/typings/urllib3/poolmanager.pyi @@ -0,0 +1,9 @@ +from typing import Any, Dict, Optional +from .response import HTTPResponse +class PoolManager: + headers: Dict[str, str] + connection_pool_kw: Dict[str, Any] + def __init__(self, num_pools: int = 10, headers: Optional[Dict[str, str]] = None, **connection_pool_kw: Any) -> None: ... + def urlopen(self, method: str, url: str, body: Optional[Any] = None, headers: Optional[Dict[str,str]] = None, encode_multipart: bool = True, multipart_boundary: Optional[str] = None, **kw: Any) -> HTTPResponse: ... + def request(self, method: str, url: str, fields: Optional[Any] = None, headers: Optional[Dict[str, str]] = None, **urlopen_kw: Any) -> HTTPResponse: ... + def clear(self) -> None: ... diff --git a/typings/urllib3/response.pyi b/typings/urllib3/response.pyi new file mode 100644 index 00000000..aef96fbc --- /dev/null +++ b/typings/urllib3/response.pyi @@ -0,0 +1,14 @@ +import io +from typing import Any, Iterator, Optional, Union +from ._collections import HTTPHeaderDict +from .connection import HTTPConnection +class HTTPResponse(io.IOBase): + headers: HTTPHeaderDict + status: int + reason: Optional[str] + def release_conn(self) -> None: ... + @property + def data(self) -> Union[bytes, Any]: ... + @property + def connection(self) -> Optional[HTTPConnection]: ... + def read_chunked(self, amt: Optional[int] = None, decode_content: Optional[bool] = None) -> Iterator[bytes]: ... diff --git a/typings/urllib3/util/request.pyi b/typings/urllib3/util/request.pyi new file mode 100644 index 00000000..7c949655 --- /dev/null +++ b/typings/urllib3/util/request.pyi @@ -0,0 +1,9 @@ +from typing import Optional, Union, Dict, List +def make_headers( + keep_alive: Optional[bool] = None, + accept_encoding: Union[bool, List[str], str, None] = None, + user_agent: Optional[str] = None, + basic_auth: Optional[str] = None, + proxy_basic_auth: Optional[str] = None, + disable_cache: Optional[bool] = None, +) -> Dict[str, str]: ... diff --git a/typings/urllib3/util/timeout.pyi b/typings/urllib3/util/timeout.pyi new file mode 100644 index 00000000..8ed327a3 --- /dev/null +++ b/typings/urllib3/util/timeout.pyi @@ -0,0 +1,4 @@ +from typing import Any, Optional +class Timeout: + DEFAULT_TIMEOUT: Any + def __init__(self, total: Optional[float] = None, connect: Optional[float] = None, read: Optional[float] = None) -> None: ... diff --git a/typings/ydiff/__init__.pyi b/typings/ydiff/__init__.pyi new file mode 100644 index 00000000..4578d468 --- /dev/null +++ b/typings/ydiff/__init__.pyi @@ -0,0 +1,5 @@ +import io +from typing import Any +class PatchStream: + def __init__(self, diff_hdl: io.TextIOBase) -> None: ... +def markup_to_pager(stream: Any, opts: Any) -> None: ...