Enable pyright strict mode (#2652)

- added pyrightconfig.json with typeCheckingMode=strict
- added type hints to all files except api.py
- added type stubs for dns, etcd, consul, kazoo, pysyncobj and other modules
- added type stubs for psycopg2 and urllib3 with some minor fixes
- fixed most of the issues reported by pyright
- remaining issues will be addressed later, along with enabling the CI linting task
Alexander Kukushkin
2023-05-09 09:38:00 +02:00
committed by GitHub
parent 1ac9b11f33
commit 76b3b99de2
102 changed files with 4803 additions and 2150 deletions
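
The commit message references a new pyrightconfig.json with typeCheckingMode=strict. A minimal sketch of such a file follows; the keys are standard pyright options, but the include/exclude/stubPath values shown here are illustrative rather than the commit's exact contents:

    {
      "typeCheckingMode": "strict",
      "include": ["patroni"],
      "exclude": ["**/__pycache__"],
      "stubPath": "typings"
    }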

View File

@@ -1,17 +1,19 @@
import sys
from typing import Any, Callable, Iterator, Tuple
PATRONI_ENV_PREFIX = 'PATRONI_'
KUBERNETES_ENV_PREFIX = 'KUBERNETES_'
MIN_PSYCOPG2 = (2, 5, 4)
def fatal(string, *args):
def fatal(string: str, *args: Any) -> None:
sys.stderr.write('FATAL: ' + string.format(*args) + '\n')
sys.exit(1)
def parse_version(version):
def _parse_version(version):
def parse_version(version: str) -> Tuple[int, ...]:
def _parse_version(version: str) -> Iterator[int]:
for e in version.split('.'):
try:
yield int(e)
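
For illustration: psycopg2.__version__ strings typically look like '2.9.6 (dt dec pq3 ext lo64)', and the inner generator stops at the first non-integer component. Assuming the rest of the function (cut off by the hunk) collects the generator into a tuple, the expected behavior is:

    # parse_version('2.9.6') == (2, 9, 6)
    # parse_version('2.9.6.dev1') == (2, 9, 6)   # int('dev1') raises ValueError -> break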
@@ -21,9 +23,11 @@ def parse_version(version):
# We pass MIN_PSYCOPG2 and parse_version as arguments to simplify usage of check_psycopg from the setup.py
def check_psycopg(_min_psycopg2=MIN_PSYCOPG2, _parse_version=parse_version):
def check_psycopg(_min_psycopg2: Tuple[int, ...] = MIN_PSYCOPG2,
_parse_version: Callable[[str], Tuple[int, ...]] = parse_version) -> None:
min_psycopg2_str = '.'.join(map(str, _min_psycopg2))
# try psycopg2
try:
from psycopg2 import __version__
if _parse_version(__version__) >= _min_psycopg2:
@@ -32,10 +36,11 @@ def check_psycopg(_min_psycopg2=MIN_PSYCOPG2, _parse_version=parse_version):
except ImportError:
version_str = None
# try psycopg3
try:
from psycopg import __version__
except ImportError:
error = 'Patroni requires psycopg2>={0}, psycopg2-binary, or psycopg>=3.0'.format(min_psycopg2_str)
if version_str:
if version_str is not None:
error += ', but only psycopg2=={0} is available'.format(version_str)
fatal(error)
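
As the comment above explains, the two keyword arguments exist so setup.py can reuse this check; a hypothetical call site (the real setup.py may pass its own local copies of the defaults):

    from patroni import check_psycopg
    check_psycopg()  # aborts via fatal() if no suitable psycopg is installed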

View File

@@ -3,14 +3,19 @@ import os
import signal
import time
from typing import Any, Dict, Optional, TYPE_CHECKING
from patroni.daemon import AbstractPatroniDaemon, abstract_main
if TYPE_CHECKING: # pragma: no cover
from .config import Config
logger = logging.getLogger(__name__)
class Patroni(AbstractPatroniDaemon):
def __init__(self, config):
def __init__(self, config: 'Config') -> None:
from patroni.api import RestApiServer
from patroni.dcs import get_dcs
from patroni.ha import Ha
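
The TYPE_CHECKING guard paired with the quoted 'Config' annotation is the standard recipe for type-only imports that would otherwise create runtime import cycles; a generic sketch:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:  # evaluated by the type checker only, never at runtime
        from patroni.config import Config

    def example(config: 'Config') -> None:  # quoted: Config is undefined at runtime
        ...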
@@ -33,9 +38,9 @@ class Patroni(AbstractPatroniDaemon):
self.tags = self.get_tags()
self.next_run = time.time()
self.scheduled_restart = {}
self.scheduled_restart: Dict[str, Any] = {}
def load_dynamic_configuration(self):
def load_dynamic_configuration(self) -> None:
from patroni.exceptions import DCSError
while True:
try:
@@ -53,19 +58,19 @@ class Patroni(AbstractPatroniDaemon):
logger.warning('Can not get cluster from dcs')
time.sleep(5)
def get_tags(self):
def get_tags(self) -> Dict[str, Any]:
return {tag: value for tag, value in self.config.get('tags', {}).items()
if tag not in ('clonefrom', 'nofailover', 'noloadbalance', 'nosync') or value}
@property
def nofailover(self):
def nofailover(self) -> bool:
return bool(self.tags.get('nofailover', False))
@property
def nosync(self):
def nosync(self) -> bool:
return bool(self.tags.get('nosync', False))
def reload_config(self, sighup=False, local=False):
def reload_config(self, sighup: bool = False, local: Optional[bool] = False) -> None:
try:
super(Patroni, self).reload_config(sighup, local)
if local:
@@ -87,7 +92,7 @@ class Patroni(AbstractPatroniDaemon):
def noloadbalance(self):
return bool(self.tags.get('noloadbalance', False))
def schedule_next_run(self):
def schedule_next_run(self) -> None:
self.next_run += self.dcs.loop_wait
current_time = time.time()
nap_time = self.next_run - current_time
@@ -100,12 +105,12 @@ class Patroni(AbstractPatroniDaemon):
elif self.ha.watch(nap_time):
self.next_run = time.time()
def run(self):
def run(self) -> None:
self.api.start()
self.next_run = time.time()
super(Patroni, self).run()
def _run_cycle(self):
def _run_cycle(self) -> None:
logger.info(self.ha.run_cycle())
if self.dcs.cluster and self.dcs.cluster.config and self.dcs.cluster.config.data \
@@ -117,7 +122,7 @@ class Patroni(AbstractPatroniDaemon):
self.schedule_next_run()
def _shutdown(self):
def _shutdown(self) -> None:
try:
self.api.shutdown()
except Exception:
@@ -128,7 +133,7 @@ class Patroni(AbstractPatroniDaemon):
logger.exception('Exception during Ha.shutdown')
def patroni_main():
def patroni_main() -> None:
from multiprocessing import freeze_support
from patroni.validator import schema
@@ -136,18 +141,20 @@ def patroni_main():
abstract_main(Patroni, schema)
def main():
if os.getpid() != 1:
from patroni import check_psycopg
def main() -> None:
from patroni import check_psycopg
check_psycopg()
check_psycopg()
if os.getpid() != 1:
return patroni_main()
# Patroni started with PID=1, it looks like we are in the container
from types import FrameType
pid = 0
# Looks like we are in a docker, so we will act like init
def sigchld_handler(signo, stack_frame):
def sigchld_handler(signo: int, stack_frame: Optional[FrameType]) -> None:
try:
while True:
ret = os.waitpid(-1, os.WNOHANG)
@@ -158,7 +165,7 @@ def main():
except OSError:
pass
def passtochild(signo, stack_frame):
def passtochild(signo: int, stack_frame: Optional[FrameType]) -> None:
if pid:
os.kill(pid, signo)

View File

@@ -1,5 +1,10 @@
import logging
from threading import Event, Lock, RLock, Thread
from types import TracebackType
from typing import Any, Callable, Optional, Tuple, Type
from .postgresql.cancellable import CancellableSubprocess
logger = logging.getLogger(__name__)
@@ -15,19 +20,19 @@ class CriticalTask(object):
call cancel. If the task has completed `cancel()` will return False and `result` field will contain the result of
the task. When cancel returns True it is guaranteed that the background task will notice the `is_cancelled` flag.
"""
def __init__(self):
def __init__(self) -> None:
self._lock = Lock()
self.is_cancelled = False
self.result = None
def reset(self):
def reset(self) -> None:
"""Must be called every time the background task is finished.
Must be called from async thread. Caller must hold lock on async executor when calling."""
self.is_cancelled = False
self.result = None
def cancel(self):
def cancel(self) -> bool:
"""Tries to cancel the task, returns True if the task has already run.
Caller must hold lock on async executor and the task when calling."""
@@ -36,37 +41,38 @@ class CriticalTask(object):
self.is_cancelled = True
return True
def complete(self, result):
def complete(self, result: Any) -> None:
"""Mark task as completed along with a result.
Must be called from async thread. Caller must hold lock on task when calling."""
self.result = result
def __enter__(self):
def __enter__(self) -> 'CriticalTask':
self._lock.acquire()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
def __exit__(self, exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]) -> None:
self._lock.release()
class AsyncExecutor(object):
def __init__(self, cancellable, ha_wakeup):
def __init__(self, cancellable: CancellableSubprocess, ha_wakeup: Callable[..., None]) -> None:
self._cancellable = cancellable
self._ha_wakeup = ha_wakeup
self._thread_lock = RLock()
self._scheduled_action = None
self._scheduled_action: Optional[str] = None
self._scheduled_action_lock = RLock()
self._is_cancelled = False
self._finish_event = Event()
self.critical_task = CriticalTask()
@property
def busy(self):
def busy(self) -> bool:
return self.scheduled_action is not None
def schedule(self, action):
def schedule(self, action: str) -> Optional[str]:
with self._scheduled_action_lock:
if self._scheduled_action is not None:
return self._scheduled_action
@@ -76,15 +82,15 @@ class AsyncExecutor(object):
return None
@property
def scheduled_action(self):
def scheduled_action(self) -> Optional[str]:
with self._scheduled_action_lock:
return self._scheduled_action
def reset_scheduled_action(self):
def reset_scheduled_action(self) -> None:
with self._scheduled_action_lock:
self._scheduled_action = None
def run(self, func, args=()):
def run(self, func: Callable[..., Any], args: Tuple[Any, ...] = ()) -> Optional[bool]:
wakeup = False
try:
with self:
@@ -107,16 +113,16 @@ class AsyncExecutor(object):
if wakeup is not None:
self._ha_wakeup()
def run_async(self, func, args=()):
def run_async(self, func: Callable[..., Any], args: Tuple[Any, ...] = ()) -> None:
Thread(target=self.run, args=(func, args)).start()
def try_run_async(self, action, func, args=()):
def try_run_async(self, action: str, func: Callable[..., Any], args: Tuple[Any, ...] = ()) -> Optional[str]:
prev = self.schedule(action)
if prev is None:
return self.run_async(func, args)
return 'Failed to run {0}, {1} is already in progress'.format(action, prev)
def cancel(self):
def cancel(self) -> None:
with self:
with self._scheduled_action_lock:
if self._scheduled_action is None:
@@ -130,8 +136,10 @@ class AsyncExecutor(object):
with self:
self.reset_scheduled_action()
def __enter__(self):
def __enter__(self) -> 'AsyncExecutor':
self._thread_lock.acquire()
return self
def __exit__(self, *args):
def __exit__(self, exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]) -> None:
self._thread_lock.release()
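
A rough sketch of the hand-off the CriticalTask docstrings above describe (the thread roles are illustrative):

    task = CriticalTask()

    # HA thread: try to cancel the in-flight task
    with task:                  # __enter__ acquires the task lock
        if not task.cancel():   # False -> the task already completed
            result = task.result

    # async worker thread: publish the outcome unless cancelled
    with task:
        if not task.is_cancelled:
            task.complete('some result')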

View File

@@ -1,16 +1,15 @@
from collections import OrderedDict
from collections.abc import MutableMapping, MutableSet
from typing import Any, Collection, Dict, Iterable, Iterator, Optional, Tuple, Union
from typing import Any, Collection, Dict, Iterator, MutableMapping, MutableSet, Optional
class CaseInsensitiveSet(MutableSet):
class CaseInsensitiveSet(MutableSet[str]):
"""A case-insensitive ``set``-like object.
Implements all methods and operations of :class:``MutableSet``. All values are expected to be strings.
The structure remembers the case of the last value set, however, contains testing is case insensitive.
"""
def __init__(self, values: Optional[Collection[str]] = None) -> None:
self._values = {}
self._values: Dict[str, str] = {}
for v in values or ():
self.add(v)
@@ -39,7 +38,7 @@ class CaseInsensitiveSet(MutableSet):
return self <= other
class CaseInsensitiveDict(MutableMapping):
class CaseInsensitiveDict(MutableMapping[str, Any]):
"""A case-insensitive ``dict``-like object.
Implements all methods and operations of :class:``MutableMapping`` as well as dict's :func:``copy``.
@@ -47,8 +46,8 @@ class CaseInsensitiveDict(MutableMapping):
and ``iter(instance)``, ``keys()``, ``items()``, ``iterkeys()``, and ``iteritems()`` will contain
case-sensitive keys. However, querying and contains testing is case insensitive.
"""
def __init__(self, data: Optional[Union[Dict[str, Any], Iterable[Tuple[str, Any]]]] = None) -> None:
self._values = OrderedDict()
def __init__(self, data: Optional[Dict[str, Any]] = None) -> None:
self._values: OrderedDict[str, Any] = OrderedDict()
self.update(data or {})
def __setitem__(self, key: str, value: Any) -> None:
@@ -68,7 +67,7 @@ class CaseInsensitiveDict(MutableMapping):
return len(self._values)
def copy(self) -> 'CaseInsensitiveDict':
return CaseInsensitiveDict(self._values.values())
return CaseInsensitiveDict({v[0]: v[1] for v in self._values.values()})
def __repr__(self) -> str:
return '<{0}{1} at {2:x}>'.format(type(self).__name__, dict(self.items()), id(self))
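
The docstring semantics, as a short example:

    s = CaseInsensitiveSet(['Foo'])
    assert 'foo' in s and 'FOO' in s      # contains testing ignores case
    s.add('FOO')                          # the last-set spelling wins on iteration

    d = CaseInsensitiveDict({'Host': 'localhost'})
    assert d['HOST'] == 'localhost'       # querying ignores case
    assert list(d) == ['Host']            # iteration keeps the stored case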

View File

@@ -7,7 +7,7 @@ import yaml
from collections import defaultdict
from copy import deepcopy
from typing import Any, Dict, Optional, Union
from typing import Any, Callable, Collection, Dict, List, Optional, Union
from . import PATRONI_ENV_PREFIX
from .collections import CaseInsensitiveDict
@@ -33,9 +33,10 @@ _AUTH_ALLOWED_PARAMETERS = (
)
def default_validator(conf):
def default_validator(conf: Dict[str, Any]) -> List[str]:
if not conf:
raise ConfigParseError("Config is empty.")
return []
class GlobalConfig(object):
@@ -86,7 +87,7 @@ class GlobalConfig(object):
""":returns: `True` if at least one synchronous node is required."""
return self.check_mode('synchronous_mode_strict')
def get_standby_cluster_config(self) -> Any:
def get_standby_cluster_config(self) -> Union[Dict[str, Any], Any]:
""":returns: "standby_cluster" configuration."""
return deepcopy(self.get('standby_cluster'))
@@ -142,7 +143,7 @@ class GlobalConfig(object):
if 'primary_stop_timeout' in self.__config else self.get_int('master_stop_timeout', default)
def get_global_config(cluster: Union[Cluster, None], default: Optional[Dict] = None) -> GlobalConfig:
def get_global_config(cluster: Union[Cluster, None], default: Optional[Dict[str, Any]] = None) -> GlobalConfig:
"""Instantiates :class:`GlobalConfig` based on the input.
:param cluster: the currently known cluster state from DCS
@@ -180,7 +181,7 @@ class Config(object):
PATRONI_CONFIG_VARIABLE = PATRONI_ENV_PREFIX + 'CONFIGURATION'
__CACHE_FILENAME = 'patroni.dynamic.json'
__DEFAULT_CONFIG = {
__DEFAULT_CONFIG: Dict[str, Any] = {
'ttl': 30, 'loop_wait': 10, 'retry_timeout': 10,
'standby_cluster': {
'create_replica_methods': '',
@@ -199,19 +200,21 @@ class Config(object):
}
}
def __init__(self, configfile, validator=default_validator):
def __init__(self, configfile: str,
validator: Optional[Callable[[Dict[str, Any]], List[str]]] = default_validator) -> None:
self._modify_index = -1
self._dynamic_configuration = {}
self.__environment_configuration = self._build_environment_configuration()
# Patroni reads the configuration from the command-line argument if it exists, otherwise from the environment
self._config_file = configfile and os.path.exists(configfile) and configfile
self._config_file = configfile if configfile and os.path.exists(configfile) else None
if self._config_file:
self._local_configuration = self._load_config_file()
else:
config_env = os.environ.pop(self.PATRONI_CONFIG_VARIABLE, None)
self._local_configuration = config_env and yaml.safe_load(config_env) or self.__environment_configuration
if validator:
errors = validator(self._local_configuration)
if errors:
@@ -224,14 +227,14 @@ class Config(object):
self._cache_needs_saving = False
@property
def config_file(self):
def config_file(self) -> Union[str, None]:
return self._config_file
@property
def dynamic_configuration(self):
def dynamic_configuration(self) -> Dict[str, Any]:
return deepcopy(self._dynamic_configuration)
def _load_config_path(self, path):
def _load_config_path(self, path: str) -> Dict[str, Any]:
"""
If path is a file, loads the yml file pointed to by path.
If path is a directory, loads all yml files in that directory in alphabetical order
@@ -245,20 +248,21 @@ class Config(object):
logger.error('config path %s is neither directory nor file', path)
raise ConfigParseError('invalid config path')
overall_config = {}
overall_config: Dict[str, Any] = {}
for fname in files:
with open(fname) as f:
config = yaml.safe_load(f)
patch_config(overall_config, config)
return overall_config
def _load_config_file(self):
def _load_config_file(self) -> Dict[str, Any]:
"""Loads config.yaml from filesystem and applies some values which were set via ENV"""
assert self._config_file is not None
config = self._load_config_path(self._config_file)
patch_config(config, self.__environment_configuration)
return config
def _load_cache(self):
def _load_cache(self) -> None:
if os.path.isfile(self._cache_file):
try:
with open(self._cache_file) as f:
@@ -266,7 +270,7 @@ class Config(object):
except Exception:
logger.exception('Exception when loading file: %s', self._cache_file)
def save_cache(self):
def save_cache(self) -> None:
if self._cache_needs_saving:
tmpfile = fd = None
try:
@@ -290,7 +294,7 @@ class Config(object):
logger.error('Can not remove temporary file %s', tmpfile)
# configuration could be either ClusterConfig or dict
def set_dynamic_configuration(self, configuration):
def set_dynamic_configuration(self, configuration: Union[ClusterConfig, Dict[str, Any]]) -> bool:
if isinstance(configuration, ClusterConfig):
if self._modify_index == configuration.modify_index:
return False # If the index didn't change there is nothing to do
@@ -306,8 +310,9 @@ class Config(object):
return True
except Exception:
logger.exception('Exception when setting dynamic_configuration')
return False
def reload_local_configuration(self):
def reload_local_configuration(self) -> Optional[bool]:
if self.config_file:
try:
configuration = self._load_config_file()
@@ -322,12 +327,12 @@ class Config(object):
logger.exception('Exception when reloading local configuration from %s', self.config_file)
@staticmethod
def _process_postgresql_parameters(parameters, is_local=False):
def _process_postgresql_parameters(parameters: Dict[str, Any], is_local: bool = False) -> Dict[str, Any]:
return {name: value for name, value in (parameters or {}).items()
if name not in ConfigHandler.CMDLINE_OPTIONS
or not is_local and ConfigHandler.CMDLINE_OPTIONS[name][1](value)}
def _safe_copy_dynamic_configuration(self, dynamic_configuration):
def _safe_copy_dynamic_configuration(self, dynamic_configuration: Dict[str, Any]) -> Dict[str, Any]:
config = deepcopy(self.__DEFAULT_CONFIG)
for name, value in dynamic_configuration.items():
@@ -347,10 +352,10 @@ class Config(object):
return config
@staticmethod
def _build_environment_configuration():
ret = defaultdict(dict)
def _build_environment_configuration() -> Dict[str, Any]:
ret: Dict[str, Any] = defaultdict(dict)
def _popenv(name):
def _popenv(name: str) -> Union[str, None]:
return os.environ.pop(PATRONI_ENV_PREFIX + name.upper(), None)
for param in ('name', 'namespace', 'scope'):
@@ -358,7 +363,7 @@ class Config(object):
if value:
ret[param] = value
def _fix_log_env(name, oldname):
def _fix_log_env(name: str, oldname: str) -> None:
value = _popenv(oldname)
name = PATRONI_ENV_PREFIX + 'LOG_' + name.upper()
if value and name not in os.environ:
@@ -367,7 +372,7 @@ class Config(object):
for name, oldname in (('level', 'loglevel'), ('format', 'logformat'), ('dateformat', 'log_datefmt')):
_fix_log_env(name, oldname)
def _set_section_values(section, params):
def _set_section_values(section: str, params: List[str]) -> None:
for param in params:
value = _popenv(section + '_' + param)
if value:
@@ -400,7 +405,7 @@ class Config(object):
if value is not None:
ret[first][second] = value
def _parse_list(value):
def _parse_list(value: str) -> Union[List[str], None]:
if not (value.strip().startswith('-') or '[' in value):
value = '[{0}]'.format(value)
try:
@@ -416,7 +421,7 @@ class Config(object):
if value:
ret[first][second] = value
def _parse_dict(value):
def _parse_dict(value: str) -> Union[Dict[str, Any], None]:
if not value.strip().startswith('{'):
value = '{{{0}}}'.format(value)
try:
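
Both helpers rely on the same trick: wrap the raw environment value so that yaml.safe_load can parse it as a flow sequence or mapping. A standalone sketch (not the exact code, which also handles the error paths shown above):

    import yaml

    def parse_list_sketch(value: str):
        # wrap bare 'a, b, c' into '[a, b, c]' so YAML sees a flow sequence
        if not (value.strip().startswith('-') or '[' in value):
            value = '[{0}]'.format(value)
        try:
            return yaml.safe_load(value)
        except Exception:
            return None

    assert parse_list_sketch('a, b, c') == ['a', 'b', 'c']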
@@ -433,8 +438,8 @@ class Config(object):
if value:
ret[first][second] = value
def _get_auth(name, params=None):
ret = {}
def _get_auth(name: str, params: Optional[Collection[str]] = None) -> Dict[str, str]:
ret: Dict[str, str] = {}
for param in params or _AUTH_ALLOWED_PARAMETERS[:2]:
value = _popenv(name + '_' + param)
if value:
@@ -503,7 +508,8 @@ class Config(object):
return ret
def _build_effective_configuration(self, dynamic_configuration, local_configuration):
def _build_effective_configuration(self, dynamic_configuration: Dict[str, Any],
local_configuration: Dict[str, Union[Dict[str, Any], Any]]) -> Dict[str, Any]:
config = self._safe_copy_dynamic_configuration(dynamic_configuration)
for name, value in local_configuration.items():
if name == 'citus': # remove invalid citus configuration
@@ -565,16 +571,16 @@ class Config(object):
return config
def get(self, key, default=None):
def get(self, key: str, default: Optional[Any] = None) -> Any:
return self.__effective_configuration.get(key, default)
def __contains__(self, key):
def __contains__(self, key: str) -> bool:
return key in self.__effective_configuration
def __getitem__(self, key):
def __getitem__(self, key: str) -> Any:
return self.__effective_configuration[key]
def copy(self):
def copy(self) -> Dict[str, Any]:
return deepcopy(self.__effective_configuration)
def get_global_config(self, cluster: Union[Cluster, None]) -> GlobalConfig:

View File

@@ -18,23 +18,25 @@ import shutil
import subprocess
import sys
import tempfile
import urllib3
import time
import yaml
from typing import Any, Dict, Union
from click import ClickException
from collections import defaultdict
from contextlib import contextmanager
from prettytable import ALL, FRAME, PrettyTable
from urllib.parse import urlparse
from typing import Any, Dict, Generator, Iterator, List, Optional, Union, Tuple, TYPE_CHECKING
if TYPE_CHECKING: # pragma: no cover
from psycopg import Cursor
from psycopg2 import cursor
try:
from ydiff import markup_to_pager, PatchStream
except ImportError: # pragma: no cover
from cdiff import markup_to_pager, PatchStream
from .dcs import get_dcs as _get_dcs
from .dcs import get_dcs as _get_dcs, AbstractDCS, Cluster, Member
from .exceptions import PatroniException
from .postgresql.misc import postgres_version_to_int
from .utils import cluster_as_json, patch_config, polling_loop
@@ -43,30 +45,31 @@ from .version import __version__
CONFIG_DIR_PATH = click.get_app_dir('patroni')
CONFIG_FILE_PATH = os.path.join(CONFIG_DIR_PATH, 'patronictl.yaml')
DCS_DEFAULTS = {'zookeeper': {'port': 2181, 'template': "zookeeper:\n hosts: ['{host}:{port}']"},
'exhibitor': {'port': 8181, 'template': "exhibitor:\n hosts: [{host}]\n port: {port}"},
'consul': {'port': 8500, 'template': "consul:\n host: '{host}:{port}'"},
'etcd': {'port': 2379, 'template': "etcd:\n host: '{host}:{port}'"},
'etcd3': {'port': 2379, 'template': "etcd3:\n host: '{host}:{port}'"}}
DCS_DEFAULTS: Dict[str, Dict[str, Any]] = {
'zookeeper': {'port': 2181, 'template': "zookeeper:\n hosts: ['{host}:{port}']"},
'exhibitor': {'port': 8181, 'template': "exhibitor:\n hosts: [{host}]\n port: {port}"},
'consul': {'port': 8500, 'template': "consul:\n host: '{host}:{port}'"},
'etcd': {'port': 2379, 'template': "etcd:\n host: '{host}:{port}'"},
'etcd3': {'port': 2379, 'template': "etcd3:\n host: '{host}:{port}'"}}
class PatroniCtlException(ClickException):
class PatroniCtlException(click.ClickException):
pass
class PatronictlPrettyTable(PrettyTable):
def __init__(self, header, *args, **kwargs):
PrettyTable.__init__(self, *args, **kwargs)
def __init__(self, header: str, *args: Any, **kwargs: Any) -> None:
super(PatronictlPrettyTable, self).__init__(*args, **kwargs)
self.__table_header = header
self.__hline_num = 0
self.__hline = None
self.__hline: str
def __build_header(self, line):
def __build_header(self, line: str) -> str:
header = self.__table_header[:len(line) - 2]
return "".join([line[0], header, line[1 + len(header):]])
def _stringify_hrule(self, *args, **kwargs):
def _stringify_hrule(self, *args: Any, **kwargs: Any) -> str:
ret = super(PatronictlPrettyTable, self)._stringify_hrule(*args, **kwargs)
where = args[1] if len(args) > 1 else kwargs.get('where')
if where == 'top_' and self.__table_header:
@@ -74,13 +77,13 @@ class PatronictlPrettyTable(PrettyTable):
self.__hline_num += 1
return ret
def _is_first_hline(self):
def _is_first_hline(self) -> bool:
return self.__hline_num == 0
def _set_hline(self, value):
def _set_hline(self, value: str) -> None:
self.__hline = value
def _get_hline(self):
def _get_hline(self) -> str:
ret = self.__hline
# Inject nice table header
@@ -93,7 +96,7 @@ class PatronictlPrettyTable(PrettyTable):
_hrule = property(_get_hline, _set_hline)
def parse_dcs(dcs: str) -> Union[Dict[str, Any], None]:
def parse_dcs(dcs: Optional[str]) -> Optional[Dict[str, Any]]:
"""Parse a DCS URL.
:param dcs: the DCS URL in the format ``DCS://HOST:PORT``. ``DCS`` can be one among
@@ -143,7 +146,7 @@ def parse_dcs(dcs: str) -> Union[Dict[str, Any], None]:
return yaml.safe_load(default['template'].format(host=parsed.hostname or 'localhost', port=port or default['port']))
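
Given the DCS_DEFAULTS templates above, the expected results look roughly like:

    # parse_dcs('etcd3://localhost:2379') -> {'etcd3': {'host': 'localhost:2379'}}
    # parse_dcs('zookeeper://zk1')        -> {'zookeeper': {'hosts': ['zk1:2181']}}  # default port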
def load_config(path, dcs_url):
def load_config(path: str, dcs_url: Optional[str]) -> Dict[str, Any]:
from patroni.config import Config
if not (os.path.exists(path) and os.access(path, os.R_OK)):
@@ -156,11 +159,11 @@ def load_config(path, dcs_url):
logging.debug('Loading configuration from file %s', path)
config = Config(path, validator=None).copy()
dcs_url = parse_dcs(dcs_url) or {}
if dcs_url:
dcs_kwargs = parse_dcs(dcs_url) or {}
if dcs_kwargs:
for d in DCS_DEFAULTS:
config.pop(d, None)
config.update(dcs_url)
config.update(dcs_kwargs)
return config
@@ -183,7 +186,7 @@ role_choice = click.Choice(['leader', 'primary', 'standby-leader', 'replica', 's
@click.option('--dcs-url', '--dcs', '-d', 'dcs_url', help='The DCS connect url', envvar='DCS_URL')
@option_insecure
@click.pass_context
def ctl(ctx, config_file, dcs_url, insecure):
def ctl(ctx: click.Context, config_file: str, dcs_url: Optional[str], insecure: bool) -> None:
level = 'WARNING'
for name in ('LOGLEVEL', 'PATRONI_LOGLEVEL', 'PATRONI_LOG_LEVEL'):
level = os.environ.get(name, level)
@@ -194,7 +197,7 @@ def ctl(ctx, config_file, dcs_url, insecure):
ctx.obj.setdefault('ctl', {})['insecure'] = ctx.obj.get('ctl', {}).get('insecure') or insecure
def get_dcs(config, scope, group):
def get_dcs(config: Dict[str, Any], scope: str, group: Optional[int]) -> AbstractDCS:
config.update({'scope': scope, 'patronictl': True})
if group is not None:
config['citus'] = {'group': group}
@@ -202,13 +205,14 @@ def get_dcs(config, scope, group):
try:
dcs = _get_dcs(config)
if config.get('citus') and group is None:
dcs.get_cluster = dcs._get_citus_cluster
dcs.is_citus_coordinator = lambda: True
return dcs
except PatroniException as e:
raise PatroniCtlException(str(e))
def request_patroni(member, method='GET', endpoint=None, data=None):
def request_patroni(member: Member, method: str = 'GET',
endpoint: Optional[str] = None, data: Optional[Any] = None) -> urllib3.response.HTTPResponse:
ctx = click.get_current_context() # the current click context
request_executor = ctx.obj.get('__request_patroni')
if not request_executor:
@@ -216,19 +220,20 @@ def request_patroni(member, method='GET', endpoint=None, data=None):
return request_executor(member, method, endpoint, data)
def print_output(columns, rows, alignment=None, fmt='pretty', header=None, delimiter='\t'):
def print_output(columns: Optional[List[str]], rows: List[List[Any]], alignment: Optional[Dict[str, str]] = None,
fmt: str = 'pretty', header: str = '', delimiter: str = '\t') -> None:
if fmt in {'json', 'yaml', 'yml'}:
elements = [{k: v for k, v in zip(columns, r) if not header or str(v)} for r in rows]
elements = [{k: v for k, v in zip(columns or [], r) if not header or str(v)} for r in rows]
func = json.dumps if fmt == 'json' else format_config_for_editing
click.echo(func(elements))
elif fmt in {'pretty', 'tsv', 'topology'}:
list_cluster = bool(header and columns and columns[0] == 'Cluster')
if list_cluster and 'Tags' in columns: # we want to format member tags as YAML
if list_cluster and columns and 'Tags' in columns: # we want to format member tags as YAML
i = columns.index('Tags')
for row in rows:
if row[i]:
row[i] = format_config_for_editing(row[i], fmt != 'pretty').strip()
if list_cluster and fmt != 'tsv': # skip cluster name and maybe Citus group if pretty-printing
if list_cluster and header and fmt != 'tsv': # skip cluster name and maybe Citus group if pretty-printing
skip_cols = 2 if ' (group: ' in header else 1
columns = columns[skip_cols:] if columns else []
rows = [row[skip_cols:] for row in rows]
@@ -247,7 +252,7 @@ def print_output(columns, rows, alignment=None, fmt='pretty', header=None, delim
click.echo(table)
def watching(w, watch, max_count=None, clear=True):
def watching(w: bool, watch: Optional[int], max_count: Optional[int] = None, clear: bool = True) -> Iterator[int]:
"""
>>> len(list(watching(True, 1, 0)))
1
@@ -275,7 +280,8 @@ def watching(w, watch, max_count=None, clear=True):
yield 0
def get_all_members(obj, cluster, group, role='leader'):
def get_all_members(obj: Dict[str, Any], cluster: Cluster,
group: Optional[int], role: str = 'leader') -> Iterator[Member]:
clusters = {0: cluster}
if obj.get('citus') and group is None:
clusters.update(cluster.workers)
@@ -296,23 +302,25 @@ def get_all_members(obj, cluster, group, role='leader'):
yield m
def get_any_member(obj, cluster, group, role='leader', member=None):
def get_any_member(obj: Dict[str, Any], cluster: Cluster, group: Optional[int],
role: str = 'leader', member: Optional[str] = None) -> Optional[Member]:
for m in get_all_members(obj, cluster, group, role):
if member is None or m.name == member:
return m
def get_all_members_leader_first(cluster):
def get_all_members_leader_first(cluster: Cluster) -> Iterator[Member]:
leader_name = cluster.leader.member.name if cluster.leader and cluster.leader.member.api_url else None
if leader_name:
if leader_name and cluster.leader:
yield cluster.leader.member
for member in cluster.members:
if member.api_url and member.name != leader_name:
yield member
def get_cursor(obj, cluster, group, connect_parameters, role='leader', member=None):
member = get_any_member(obj, cluster, group, role=role, member=member)
def get_cursor(obj: Dict[str, Any], cluster: Cluster, group: Optional[int], connect_parameters: Dict[str, Any],
role: str = 'leader', member_name: Optional[str] = None) -> Union['cursor', 'Cursor[Any]', None]:
member = get_any_member(obj, cluster, group, role=role, member=member_name)
if member is None:
return None
@@ -330,7 +338,8 @@ def get_cursor(obj, cluster, group, connect_parameters, role='leader', member=No
return cursor
cursor.execute('SELECT pg_catalog.pg_is_in_recovery()')
in_recovery = cursor.fetchone()[0]
row = cursor.fetchone()
in_recovery = not row or row[0]
if in_recovery and role in ('replica', 'standby', 'standby-leader')\
or not in_recovery and role in ('master', 'primary'):
@@ -341,7 +350,8 @@ def get_cursor(obj, cluster, group, connect_parameters, role='leader', member=No
return None
def get_members(obj, cluster, cluster_name, member_names, role, force, action, ask_confirmation=True, group=None):
def get_members(obj: Dict[str, Any], cluster: Cluster, cluster_name: str, member_names: List[str], role: str,
force: bool, action: str, ask_confirmation: bool = True, group: Optional[int] = None) -> List[Member]:
members = list(get_all_members(obj, cluster, group, role))
candidates = {m.name for m in members}
@@ -371,7 +381,8 @@ def get_members(obj, cluster, cluster_name, member_names, role, force, action, a
return members
def confirm_members_action(members, force, action, scheduled_at=None):
def confirm_members_action(members: List[Member], force: bool, action: str,
scheduled_at: Optional[datetime.datetime] = None) -> None:
if scheduled_at:
if not force:
confirm = click.confirm('Are you sure you want to schedule {0} of members {1} at {2}?'
@@ -392,12 +403,13 @@ def confirm_members_action(members, force, action, scheduled_at=None):
@arg_cluster_name
@option_citus_group
@click.pass_obj
def dsn(obj, cluster_name, group, role, member):
def dsn(obj: Dict[str, Any], cluster_name: str, group: Optional[int],
role: Optional[str], member: Optional[str]) -> None:
if member is not None:
if role is not None:
raise PatroniCtlException('--role and --member are mutually exclusive options')
role = 'any'
if member is None and role is None:
elif role is None:
role = 'leader'
cluster = get_dcs(obj, cluster_name, group).get_cluster()
@@ -425,33 +437,36 @@ def dsn(obj, cluster_name, group, role, member):
@click.option('-d', '--dbname', help='database name to connect to', type=str)
@click.pass_obj
def query(
obj,
cluster_name,
group,
role,
member,
w,
watch,
delimiter,
command,
p_file,
password,
username,
dbname,
fmt='tsv',
):
obj: Dict[str, Any],
cluster_name: str,
group: Optional[int],
role: Optional[str],
member: Optional[str],
w: bool,
watch: Optional[int],
delimiter: str,
command: Optional[str],
p_file: Optional[io.BufferedReader],
password: Optional[bool],
username: Optional[str],
dbname: Optional[str],
fmt: str = 'tsv'
) -> None:
if member is not None:
if role is not None:
raise PatroniCtlException('--role and --member are mutually exclusive options')
role = 'any'
if member is None and role is None:
elif role is None:
role = 'leader'
if p_file is not None and command is not None:
raise PatroniCtlException('--file and --command are mutually exclusive options')
if p_file is None and command is None:
raise PatroniCtlException('You need to specify either --command or --file')
if p_file is not None:
if command is not None:
raise PatroniCtlException('--file and --command are mutually exclusive options')
sql = p_file.read().decode('utf-8')
else:
if command is None:
raise PatroniCtlException('You need to specify either --command or --file')
sql = command
connect_parameters = {}
if username:
@@ -461,25 +476,25 @@ def query(
if dbname:
connect_parameters['dbname'] = dbname
if p_file is not None:
command = p_file.read()
dcs = get_dcs(obj, cluster_name, group)
cursor = None
cluster = cursor = None
for _ in watching(w, watch, clear=False):
if cursor is None:
if cluster is None:
cluster = dcs.get_cluster()
# cursor = get_cursor(obj, cluster, group, connect_parameters, role=role, member=member)
output, header = query_member(obj, cluster, group, cursor, member, role, command, connect_parameters)
output, header = query_member(obj, cluster, group, cursor, member, role, sql, connect_parameters)
print_output(header, output, fmt=fmt, delimiter=delimiter)
def query_member(obj, cluster, group, cursor, member, role, command, connect_parameters):
def query_member(obj: Dict[str, Any], cluster: Cluster, group: Optional[int],
cursor: Union['cursor', 'Cursor[Any]', None], member: Optional[str], role: str,
command: str, connect_parameters: Dict[str, Any]) -> Tuple[List[List[Any]], Optional[List[Any]]]:
from . import psycopg
try:
if cursor is None:
cursor = get_cursor(obj, cluster, group, connect_parameters, role=role, member=member)
cursor = get_cursor(obj, cluster, group, connect_parameters, role=role, member_name=member)
if cursor is None:
if member is not None:
@@ -489,8 +504,8 @@ def query_member(obj, cluster, group, cursor, member, role, command, connect_par
logging.debug(message)
return [[timestamp(0), message]], None
cursor.execute(command)
return cursor.fetchall(), [d.name for d in cursor.description]
cursor.execute(command.encode('utf-8'))
return [list(row) for row in cursor], cursor.description and [d.name for d in cursor.description]
except psycopg.DatabaseError as de:
logging.debug(de)
if cursor is not None and not cursor.connection.closed:
@@ -505,7 +520,7 @@ def query_member(obj, cluster, group, cursor, member, role, command, connect_par
@option_citus_group
@option_format
@click.pass_obj
def remove(obj, cluster_name, group, fmt):
def remove(obj: Dict[str, Any], cluster_name: str, group: Optional[int], fmt: str) -> None:
dcs = get_dcs(obj, cluster_name, group)
cluster = dcs.get_cluster()
@@ -532,7 +547,8 @@ def remove(obj, cluster_name, group, fmt):
dcs.delete_cluster()
def check_response(response, member_name, action_name, silent_success=False):
def check_response(response: urllib3.response.HTTPResponse, member_name: str,
action_name: str, silent_success: bool = False) -> bool:
if response.status >= 400:
click.echo('Failed: {0} for member {1}, status code={2}, ({3})'.format(
action_name, member_name, response.status, response.data.decode('utf-8')
@@ -543,8 +559,8 @@ def check_response(response, member_name, action_name, silent_success=False):
return True
def parse_scheduled(scheduled):
if (scheduled or 'now') != 'now':
def parse_scheduled(scheduled: Optional[str]) -> Optional[datetime.datetime]:
if scheduled is not None and (scheduled or 'now') != 'now':
try:
scheduled_at = dateutil.parser.parse(scheduled)
if scheduled_at.tzinfo is None:
@@ -564,7 +580,8 @@ def parse_scheduled(scheduled):
@click.option('--role', '-r', help='Reload only members with this role', type=role_choice, default='any')
@option_force
@click.pass_obj
def reload(obj, cluster_name, member_names, group, force, role):
def reload(obj: Dict[str, Any], cluster_name: str, member_names: List[str],
group: Optional[int], force: bool, role: str) -> None:
dcs = get_dcs(obj, cluster_name, group)
cluster = dcs.get_cluster()
@@ -575,8 +592,10 @@ def reload(obj, cluster_name, member_names, group, force, role):
if r.status == 200:
click.echo('No changes to apply on member {0}'.format(member.name))
elif r.status == 202:
from patroni.config import get_global_config
config = get_global_config(cluster)
click.echo('Reload request received for member {0} and will be processed within {1} seconds'.format(
member.name, cluster.config.data.get('loop_wait', dcs.loop_wait))
member.name, config.get('loop_wait') or dcs.loop_wait)
)
else:
click.echo('Failed: reload for member {0}, status code={1}, ({2})'.format(
@@ -595,11 +614,12 @@ def reload(obj, cluster_name, member_names, group, force, role):
@click.option('--pg-version', 'version', help='Restart if the PostgreSQL version is less than provided (e.g. 9.5.2)',
default=None)
@click.option('--pending', help='Restart if pending', is_flag=True)
@click.option('--timeout',
help='Return error and fail over if necessary when restarting takes longer than this.')
@click.option('--timeout', help='Return error and fail over if necessary when restarting takes longer than this.')
@option_force
@click.pass_obj
def restart(obj, cluster_name, group, member_names, force, role, p_any, scheduled, version, pending, timeout):
def restart(obj: Dict[str, Any], cluster_name: str, group: Optional[int], member_names: List[str],
force: bool, role: str, p_any: bool, scheduled: Optional[str], version: Optional[str],
pending: bool, timeout: Optional[str]) -> None:
cluster = get_dcs(obj, cluster_name, group).get_cluster()
members = get_members(obj, cluster, cluster_name, member_names, role, force, 'restart', False, group=group)
@@ -666,13 +686,14 @@ def restart(obj, cluster_name, group, member_names, force, role, p_any, schedule
@option_force
@click.option('--wait', help='Wait until reinitialization completes', is_flag=True)
@click.pass_obj
def reinit(obj, cluster_name, group, member_names, force, wait):
def reinit(obj: Dict[str, Any], cluster_name: str, group: Optional[int],
member_names: List[str], force: bool, wait: bool) -> None:
cluster = get_dcs(obj, cluster_name, group).get_cluster()
members = get_members(obj, cluster, cluster_name, member_names, 'replica', force, 'reinitialize', group=group)
wait_on_members = []
wait_on_members: List[Member] = []
for member in members:
body = {'force': force}
body: Dict[str, bool] = {'force': force}
while True:
r = request_patroni(member, 'post', 'reinitialize', body)
started = check_response(r, member.name, 'reinitialize')
@@ -699,7 +720,9 @@ def reinit(obj, cluster_name, group, member_names, force, wait):
wait_on_members.remove(member)
def _do_failover_or_switchover(obj, action, cluster_name, group, leader, candidate, force, scheduled=None):
def _do_failover_or_switchover(obj: Dict[str, Any], action: str, cluster_name: str,
group: Optional[int], leader: Optional[str], candidate: Optional[str],
force: bool, scheduled: Optional[str] = None) -> None:
"""
We want to trigger a failover or switchover for the specified cluster name.
@@ -729,7 +752,7 @@ def _do_failover_or_switchover(obj, action, cluster_name, group, leader, candida
else:
from patroni.config import get_global_config
prompt = 'Standby Leader' if get_global_config(cluster).is_standby_cluster else 'Primary'
leader = click.prompt(prompt, type=str, default=cluster.leader.member.name)
leader = click.prompt(prompt, type=str, default=(cluster.leader and cluster.leader.member.name))
if leader is not None and cluster.leader and cluster.leader.member.name != leader:
raise PatroniCtlException('Member {0} is not the leader of cluster {1}'.format(leader, cluster_name))
@@ -788,8 +811,8 @@ def _do_failover_or_switchover(obj, action, cluster_name, group, leader, candida
r = None
try:
member = cluster.leader.member if cluster.leader else cluster.get_member(candidate, False)
member = cluster.leader.member if cluster.leader else candidate and cluster.get_member(candidate, False)
assert isinstance(member, Member)
r = request_patroni(member, 'post', action, failover_value)
# probably old patroni, which doesn't support switchover yet
@@ -820,7 +843,8 @@ def _do_failover_or_switchover(obj, action, cluster_name, group, leader, candida
@click.option('--candidate', help='The name of the candidate', default=None)
@option_force
@click.pass_obj
def failover(obj, cluster_name, group, leader, candidate, force):
def failover(obj: Dict[str, Any], cluster_name: str, group: Optional[int],
leader: Optional[str], candidate: Optional[str], force: bool) -> None:
action = 'switchover' if leader else 'failover'
_do_failover_or_switchover(obj, action, cluster_name, group, leader, candidate, force)
@@ -834,11 +858,13 @@ def failover(obj, cluster_name, group, leader, candidate, force):
default=None)
@option_force
@click.pass_obj
def switchover(obj, cluster_name, group, leader, candidate, force, scheduled):
def switchover(obj: Dict[str, Any], cluster_name: str, group: Optional[int],
leader: Optional[str], candidate: Optional[str], force: bool, scheduled: Optional[str]) -> None:
_do_failover_or_switchover(obj, 'switchover', cluster_name, group, leader, candidate, force, scheduled)
def generate_topology(level, member, topology):
def generate_topology(level: int, member: Dict[str, Any],
topology: Dict[str, List[Dict[str, Any]]]) -> Iterator[Dict[str, Any]]:
members = topology.get(member['name'], [])
if level > 0:
@@ -852,8 +878,8 @@ def generate_topology(level, member, topology):
yield member
def topology_sort(members):
topology = defaultdict(list)
def topology_sort(members: List[Dict[str, Any]]) -> Iterator[Dict[str, Any]]:
topology: Dict[str, List[Dict[str, Any]]] = defaultdict(list)
leader = next((m for m in members if m['role'].endswith('leader')), {'name': None})
replicas = set(member['name'] for member in members if not member['role'].endswith('leader'))
for member in members:
@@ -865,8 +891,8 @@ def topology_sort(members):
yield member
def get_cluster_service_info(cluster):
service_info = []
def get_cluster_service_info(cluster: Dict[str, Any]) -> List[str]:
service_info: List[str] = []
if cluster.get('pause'):
service_info.append('Maintenance mode: on')
@@ -879,8 +905,9 @@ def get_cluster_service_info(cluster):
return service_info
def output_members(obj, cluster, name, extended=False, fmt='pretty', group=None):
rows = []
def output_members(obj: Dict[str, Any], cluster: Cluster, name: str,
extended: bool = False, fmt: str = 'pretty', group: Optional[int] = None) -> None:
rows: List[List[Any]] = []
logging.debug(cluster)
initialize = {None: 'uninitialized', '': 'initializing'}.get(cluster.initialize, cluster.initialize)
@@ -905,12 +932,12 @@ def output_members(obj, cluster, name, extended=False, fmt='pretty', group=None)
len(set(m['host'] for m in all_members)) < len(all_members)
sort = topology_sort if fmt == 'topology' else iter
for g, cluster in sorted(clusters.items()):
for member in sort(cluster['members']):
for g, c in sorted(clusters.items()):
for member in sort(c['members']):
logging.debug(member)
lag = member.get('lag', '')
member.update(cluster=name, member=member['name'], group=g,
host=member.get('host', ''), tl=member.get('timeline', ''),
role=member['role'].replace('_', ' ').title(),
lag_in_mb=round(lag / 1024 / 1024) if isinstance(lag, int) else lag,
@@ -936,8 +963,8 @@ def output_members(obj, cluster, name, extended=False, fmt='pretty', group=None)
if fmt not in ('pretty', 'topology'): # Omit service info when using machine-readable formats
return
for g, cluster in sorted(clusters.items()):
service_info = get_cluster_service_info(cluster)
for g, c in sorted(clusters.items()):
service_info = get_cluster_service_info(c)
if service_info:
if is_citus_cluster and group is None:
click.echo('Citus group: {0}'.format(g))
@@ -953,7 +980,8 @@ def output_members(obj, cluster, name, extended=False, fmt='pretty', group=None)
@option_watch
@option_watchrefresh
@click.pass_obj
def members(obj, cluster_names, group, fmt, watch, w, extended, ts):
def members(obj: Dict[str, Any], cluster_names: List[str], group: Optional[int],
fmt: str, watch: Optional[int], w: bool, extended: bool, ts: bool) -> None:
if not cluster_names:
if 'scope' in obj:
cluster_names = [obj['scope']]
@@ -976,13 +1004,12 @@ def members(obj, cluster_names, group, fmt, watch, w, extended, ts):
@option_citus_group
@option_watch
@option_watchrefresh
@click.pass_obj
@click.pass_context
def topology(ctx, obj, cluster_names, group, watch, w):
def topology(ctx: click.Context, cluster_names: List[str], group: Optional[int], watch: Optional[int], w: bool) -> None:
ctx.forward(members, fmt='topology')
def timestamp(precision=6):
def timestamp(precision: int = 6) -> str:
return datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')[:precision - 7]
@@ -994,7 +1021,8 @@ def timestamp(precision=6):
@click.option('--role', '-r', help='Flush only members with this role', type=role_choice, default='any')
@option_force
@click.pass_obj
def flush(obj, cluster_name, group, member_names, force, role, target):
def flush(obj: Dict[str, Any], cluster_name: str, group: Optional[int],
member_names: List[str], force: bool, role: str, target: str) -> None:
dcs = get_dcs(obj, cluster_name, group)
cluster = dcs.get_cluster()
@@ -1015,28 +1043,34 @@ def flush(obj, cluster_name, group, member_names, force, role, target):
if r.status in (200, 404):
prefix = 'Success' if r.status == 200 else 'Failed'
return click.echo('{0}: {1}'.format(prefix, r.data.decode('utf-8')))
click.echo('Failed: member={0}, status_code={1}, ({2})'.format(
member.name, r.status, r.data.decode('utf-8')))
except Exception as err:
logging.warning(str(err))
logging.warning('Member %s is not accessible', member.name)
click.echo('Failed: member={0}, status_code={1}, ({2})'.format(
member.name, r.status, r.data.decode('utf-8')))
logging.warning('Failing over to DCS')
click.echo('{0} Could not find any accessible member of cluster {1}'.format(timestamp(), cluster_name))
dcs.manual_failover('', '', index=failover.index)
def wait_until_pause_is_applied(dcs, paused, old_cluster):
def wait_until_pause_is_applied(dcs: AbstractDCS, paused: bool, old_cluster: Cluster) -> None:
from patroni.config import get_global_config
config = get_global_config(old_cluster)
click.echo("'{0}' request sent, waiting until it is recognized by all nodes".format(paused and 'pause' or 'resume'))
old = {m.name: m.index for m in old_cluster.members if m.api_url}
loop_wait = old_cluster.config.data.get('loop_wait', dcs.loop_wait)
loop_wait = config.get('loop_wait') or dcs.loop_wait
cluster = None
for _ in polling_loop(loop_wait + 1):
cluster = dcs.get_cluster()
if all(m.data.get('pause', False) == paused for m in cluster.members if m.name in old):
break
else:
if TYPE_CHECKING: # pragma: no cover
assert cluster is not None
remaining = [m.name for m in cluster.members if m.data.get('pause', False) != paused
and m.name in old and old[m.name] != m.index]
if remaining:
@@ -1045,7 +1079,7 @@ def wait_until_pause_is_applied(dcs, paused, old_cluster):
return click.echo('Success: cluster management is {0}'.format(paused and 'paused' or 'resumed'))
def toggle_pause(config, cluster_name, group, paused, wait):
def toggle_pause(config: Dict[str, Any], cluster_name: str, group: Optional[int], paused: bool, wait: bool) -> None:
from patroni.config import get_global_config
dcs = get_dcs(config, cluster_name, group)
cluster = dcs.get_cluster()
@@ -1078,7 +1112,7 @@ def toggle_pause(config, cluster_name, group, paused, wait):
@option_default_citus_group
@click.pass_obj
@click.option('--wait', help='Wait until pause is applied on all nodes', is_flag=True)
def pause(obj, cluster_name, group, wait):
def pause(obj: Dict[str, Any], cluster_name: str, group: Optional[int], wait: bool) -> None:
return toggle_pause(obj, cluster_name, group, True, wait)
@@ -1087,12 +1121,12 @@ def pause(obj, cluster_name, group, wait):
@option_default_citus_group
@click.option('--wait', help='Wait until pause is cleared on all nodes', is_flag=True)
@click.pass_obj
def resume(obj, cluster_name, group, wait):
def resume(obj: Dict[str, Any], cluster_name: str, group: Optional[int], wait: bool) -> None:
return toggle_pause(obj, cluster_name, group, False, wait)
@contextmanager
def temporary_file(contents, suffix='', prefix='tmp'):
def temporary_file(contents: bytes, suffix: str = '', prefix: str = 'tmp') -> Generator[str, None, None]:
"""Creates a temporary file with specified contents that persists for the context.
:param contents: binary string that will be written to the file.
@@ -1110,12 +1144,12 @@ def temporary_file(contents, suffix='', prefix='tmp'):
os.unlink(tmp.name)
def show_diff(before_editing, after_editing):
def show_diff(before_editing: str, after_editing: str) -> None:
"""Shows a diff between two strings.
If the output is to a tty the diff will be colored. Inputs are expected to be unicode strings.
"""
def listify(string):
def listify(string: str) -> List[str]:
return [line + '\n' for line in string.rstrip('\n').split('\n')]
unified_diff = difflib.unified_diff(listify(before_editing), listify(after_editing))
@@ -1159,7 +1193,7 @@ def show_diff(before_editing, after_editing):
click.echo(line.rstrip('\n'))
def format_config_for_editing(data, default_flow_style=False):
def format_config_for_editing(data: Any, default_flow_style: bool = False) -> str:
"""Formats configuration as YAML for human consumption.
:param data: configuration as nested dictionaries
@@ -1167,7 +1201,7 @@ def format_config_for_editing(data, default_flow_style=False):
return yaml.safe_dump(data, default_flow_style=default_flow_style, encoding=None, allow_unicode=True, width=200)
def apply_config_changes(before_editing, data, kvpairs):
def apply_config_changes(before_editing: str, data: Dict[str, Any], kvpairs: List[str]) -> Tuple[str, Dict[str, Any]]:
"""Applies config changes specified as a list of key-value pairs.
Keys are interpreted as dotted paths into the configuration data structure. Except for paths beginning with
@@ -1181,7 +1215,7 @@ def apply_config_changes(before_editing, data, kvpairs):
"""
changed_data = copy.deepcopy(data)
def set_path_value(config, path, value, prefix=()):
def set_path_value(config: Dict[str, Any], path: List[str], value: Any, prefix: Tuple[str, ...] = ()):
# Postgresql GUCs can't be nested, but can contain dots so we re-flatten the structure for this case
if prefix == ('postgresql', 'parameters'):
path = ['.'.join(path)]
@@ -1208,7 +1242,7 @@ def apply_config_changes(before_editing, data, kvpairs):
return format_config_for_editing(changed_data), changed_data
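
To make the dotted-path convention concrete (hypothetical pairs):

    # ttl=30                          sets the top-level 'ttl' key
    # postgresql.parameters.auto_explain.log_min_duration=100
    #   is re-flattened under ('postgresql', 'parameters') into the single
    #   GUC key 'auto_explain.log_min_duration', since GUCs cannot be nested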
def apply_yaml_file(data, filename):
def apply_yaml_file(data: Dict[str, Any], filename: str) -> Tuple[str, Dict[str, Any]]:
"""Applies changes from a YAML file to configuration
:param data: configuration datastructure
@@ -1228,7 +1262,7 @@ def apply_yaml_file(data, filename):
return format_config_for_editing(changed_data), changed_data
def invoke_editor(before_editing, cluster_name):
def invoke_editor(before_editing: str, cluster_name: str) -> Tuple[str, Dict[str, Any]]:
"""Starts editor command to edit configuration in human readable format
:param before_editing: human representation before editing
@@ -1272,10 +1306,15 @@ def invoke_editor(before_editing, cluster_name):
' Use - for stdin.')
@option_force
@click.pass_obj
def edit_config(obj, cluster_name, group, force, quiet, kvpairs, pgkvpairs, apply_filename, replace_filename):
def edit_config(obj: Dict[str, Any], cluster_name: str, group: Optional[int],
force: bool, quiet: bool, kvpairs: List[str], pgkvpairs: List[str],
apply_filename: Optional[str], replace_filename: Optional[str]) -> None:
dcs = get_dcs(obj, cluster_name, group)
cluster = dcs.get_cluster()
if not cluster.config:
raise PatroniCtlException('The config key does not exist in the cluster {0}'.format(cluster_name))
before_editing = format_config_for_editing(cluster.config.data)
after_editing = None # Serves as a flag if any changes were requested
@@ -1317,10 +1356,10 @@ def edit_config(obj, cluster_name, group, force, quiet, kvpairs, pgkvpairs, appl
@arg_cluster_name
@option_default_citus_group
@click.pass_obj
def show_config(obj, cluster_name, group):
def show_config(obj: Dict[str, Any], cluster_name: str, group: Optional[int]) -> None:
cluster = get_dcs(obj, cluster_name, group).get_cluster()
click.echo(format_config_for_editing(cluster.config.data))
if cluster.config:
click.echo(format_config_for_editing(cluster.config.data))
@ctl.command('version', help='Output version of patronictl command or a running Patroni instance')
@@ -1328,7 +1367,7 @@ def show_config(obj, cluster_name, group):
@click.argument('member_names', nargs=-1)
@option_citus_group
@click.pass_obj
def version(obj, cluster_name, group, member_names):
def version(obj: Dict[str, Any], cluster_name: str, group: Optional[int], member_names: List[str]) -> None:
click.echo("patronictl version {0}".format(__version__))
if not cluster_name:
@@ -1355,9 +1394,9 @@ def version(obj, cluster_name, group, member_names):
@option_default_citus_group
@option_format
@click.pass_obj
def history(obj, cluster_name, group, fmt):
def history(obj: Dict[str, Any], cluster_name: str, group: Optional[int], fmt: str) -> None:
cluster = get_dcs(obj, cluster_name, group).get_cluster()
history = cluster.history and cluster.history.lines or []
history: List[List[Any]] = list(map(list, cluster.history and cluster.history.lines or []))
table_header_row = ['TL', 'LSN', 'Reason', 'Timestamp', 'New Leader']
for line in history:
if len(line) < len(table_header_row):
@@ -1367,7 +1406,7 @@ def history(obj, cluster_name, group, fmt):
print_output(table_header_row, history, {'TL': 'r', 'LSN': 'r'}, fmt)
def format_pg_version(version):
def format_pg_version(version: int) -> str:
if version < 100000:
return "{0}.{1}.{2}".format(version // 10000, version // 100 % 100, version % 100)
else:

View File

@@ -11,10 +11,11 @@ import signal
import sys
from threading import Lock
from typing import Any, Optional, Type
from typing import Any, Optional, Type, TYPE_CHECKING
from .config import Config
from .validator import Schema
if TYPE_CHECKING: # pragma: no cover
from .config import Config
from .validator import Schema
class AbstractPatroniDaemon(abc.ABC):
@@ -30,7 +31,7 @@ class AbstractPatroniDaemon(abc.ABC):
:ivar config: configuration options for this daemon.
"""
def __init__(self, config: Config) -> None:
def __init__(self, config: 'Config') -> None:
"""Set up signal handlers, logging handler and configuration.
:param config: configuration options for this daemon.
@@ -94,7 +95,7 @@ class AbstractPatroniDaemon(abc.ABC):
with self._sigterm_lock:
return self._received_sigterm
def reload_config(self, sighup: Optional[bool] = False, local: Optional[bool] = False) -> None:
def reload_config(self, sighup: bool = False, local: Optional[bool] = False) -> None:
"""Reload configuration.
:param sighup: if it is related to a SIGHUP signal.
@@ -140,7 +141,7 @@ class AbstractPatroniDaemon(abc.ABC):
self.logger.shutdown()
def abstract_main(cls: Type[AbstractPatroniDaemon], validator: Optional[Schema] = None) -> None:
def abstract_main(cls: Type[AbstractPatroniDaemon], validator: Optional['Schema'] = None) -> None:
"""Create the main entry point of a given daemon process.
Expose a basic argument parser, parse the command-line arguments, and run the given daemon process.

File diff suppressed because it is too large

View File

@@ -8,16 +8,19 @@ import ssl
import time
import urllib3
from collections import defaultdict, namedtuple
from collections import defaultdict
from consul import ConsulException, NotFound, base
from http.client import HTTPException
from urllib3.exceptions import HTTPError
from urllib.parse import urlencode, urlparse, quote
from typing import Any, Callable, Dict, List, Mapping, NamedTuple, Optional, Union, Tuple, TYPE_CHECKING
from . import AbstractDCS, Cluster, ClusterConfig, Failover, Leader, Member, SyncState,\
TimelineHistory, ReturnFalseException, catch_return_false_exception, citus_group_re
from ..exceptions import DCSError
from ..utils import deep_compare, parse_bool, Retry, RetryFailedError, split_host_port, uri, USER_AGENT
if TYPE_CHECKING: # pragma: no cover
from ..config import Config
logger = logging.getLogger(__name__)
@@ -38,12 +41,17 @@ class InvalidSession(ConsulException):
"""invalid session"""
Response = namedtuple('Response', 'code,headers,body,content')
class Response(NamedTuple):
code: int
headers: Union[Mapping[str, str], Mapping[bytes, bytes], None]
body: str
content: bytes
class HTTPClient(object):
def __init__(self, host='127.0.0.1', port=8500, token=None, scheme='http', verify=True, cert=None, ca_cert=None):
def __init__(self, host: str = '127.0.0.1', port: int = 8500, token: Optional[str] = None, scheme: str = 'http',
verify: bool = True, cert: Optional[str] = None, ca_cert: Optional[str] = None) -> None:
self.token = token
self._read_timeout = 10
self.base_uri = uri(scheme, (host, port))
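
Compared to the old collections.namedtuple, the typed Response above declares a type per field, so strict mode can check every access; a small sketch (reveal_type is recognized by pyright without an import):

r = Response(code=200, headers=None, body='{}', content=b'{}')
reveal_type(r.code)  # pyright reports: int, instead of an untyped field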
@@ -60,22 +68,22 @@ class HTTPClient(object):
kwargs['ca_certs'] = ca_cert
kwargs['cert_reqs'] = ssl.CERT_REQUIRED if verify or ca_cert else ssl.CERT_NONE
self.http = urllib3.PoolManager(num_pools=10, maxsize=10, **kwargs)
self._ttl = None
self._ttl = 30
def set_read_timeout(self, timeout):
def set_read_timeout(self, timeout: float) -> None:
self._read_timeout = timeout / 3.0
@property
def ttl(self):
def ttl(self) -> int:
return self._ttl
def set_ttl(self, ttl):
def set_ttl(self, ttl: int) -> bool:
ret = self._ttl != ttl
self._ttl = ttl
return ret
@staticmethod
def response(response):
def response(response: urllib3.response.HTTPResponse) -> Response:
content = response.data
body = content.decode('utf-8')
if response.status == 500:
@@ -88,14 +96,19 @@ class HTTPClient(object):
raise ConsulInternalError(msg)
return Response(response.status, response.headers, body, content)
def uri(self, path, params=None):
def uri(self, path: str,
params: Union[None, Dict[str, Any], List[Tuple[str, Any]], Tuple[Tuple[str, Any], ...]] = None) -> str:
return '{0}{1}{2}'.format(self.base_uri, path, params and '?' + urlencode(params) or '')
def __getattr__(self, method):
def __getattr__(self, method: str) -> Callable[[Callable[[Response], Union[bool, Any, Tuple[str, Any]]],
str, Union[None, Dict[str, Any], List[Tuple[str, Any]]],
str, Optional[Dict[str, str]]], Union[bool, Any, Tuple[str, Any]]]:
if method not in ('get', 'post', 'put', 'delete'):
raise AttributeError("HTTPClient instance has no attribute '{0}'".format(method))
def wrapper(callback, path, params=None, data='', headers=None):
def wrapper(callback: Callable[[Response], Union[bool, Any, Tuple[str, Any]]], path: str,
params: Union[None, Dict[str, Any], List[Tuple[str, Any]]] = None, data: str = '',
headers: Optional[Dict[str, str]] = None) -> Union[bool, Any, Tuple[str, Any]]:
# python-consul doesn't allow specifying a ttl smaller than 10 seconds,
# because session_ttl_min defaults to 10s, so we have to resort to this ugly hack...
if method == 'put' and path == '/v1/session/create':
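
Typing a dynamic __getattr__ like the one above is the awkward part of this file: the returned wrapper has to be described as a Callable. A simplified sketch of the shape (hypothetical Client class; the real annotation spells out every parameter of wrapper):

from typing import Any, Callable

class Client:
    def __getattr__(self, method: str) -> Callable[..., Any]:
        if method not in ('get', 'post', 'put', 'delete'):
            raise AttributeError(method)

        def wrapper(*args: Any, **kwargs: Any) -> Any:
            return (method, args, kwargs)  # stand-in for the real HTTP request

        return wrapper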
@@ -106,7 +119,7 @@ class HTTPClient(object):
data = data[:-1] + ', ' + ttl + '}'
if isinstance(params, list): # starting from v1.1.0 python-consul switched from `dict` to `list` for params
params = {k: v for k, v in params}
kwargs = {'retries': 0, 'preload_content': False, 'body': data}
kwargs: Dict[str, Any] = {'retries': 0, 'preload_content': False, 'body': data}
if method == 'get' and isinstance(params, dict) and 'index' in params:
timeout = float(params['wait'][:-1]) if 'wait' in params else 300
# According to the documentation a small random amount of additional wait time is added to the
@@ -127,13 +140,13 @@ class HTTPClient(object):
class ConsulClient(base.Consul):
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
self._cert = kwargs.pop('cert', None)
self._ca_cert = kwargs.pop('ca_cert', None)
self.token = kwargs.get('token')
super(ConsulClient, self).__init__(*args, **kwargs)
def http_connect(self, *args, **kwargs):
def http_connect(self, *args: Any, **kwargs: Any) -> HTTPClient:
kwargs.update(dict(zip(['host', 'port', 'scheme', 'verify'], args)))
if self._cert:
kwargs['cert'] = self._cert
@@ -143,17 +156,17 @@ class ConsulClient(base.Consul):
kwargs['token'] = self.token
return HTTPClient(**kwargs)
def connect(self, *args, **kwargs):
def connect(self, *args: Any, **kwargs: Any) -> HTTPClient:
return self.http_connect(*args, **kwargs)
def reload_config(self, config):
def reload_config(self, config: Dict[str, Any]) -> None:
self.http.token = self.token = config.get('token')
self.consistency = config.get('consistency', 'default')
self.dc = config.get('dc')
def catch_consul_errors(func):
def wrapper(*args, **kwargs):
def catch_consul_errors(func: Callable[..., Any]) -> Callable[..., Any]:
def wrapper(*args: Any, **kwargs: Any) -> Any:
try:
return func(*args, **kwargs)
except (RetryFailedError, ConsulException, HTTPException, HTTPError, socket.error, socket.timeout):
@@ -161,24 +174,25 @@ def catch_consul_errors(func):
return wrapper
def force_if_last_failed(func):
def wrapper(*args, **kwargs):
if wrapper.last_result is False:
def force_if_last_failed(func: Callable[..., Any]) -> Callable[..., Any]:
def wrapper(*args: Any, **kwargs: Any) -> Any:
if getattr(wrapper, 'last_result', None) is False:
kwargs['force'] = True
wrapper.last_result = func(*args, **kwargs)
return wrapper.last_result
last_result = func(*args, **kwargs)
setattr(wrapper, 'last_result', last_result)
return last_result
wrapper.last_result = None
setattr(wrapper, 'last_result', None)
return wrapper
def service_name_from_scope_name(scope_name):
def service_name_from_scope_name(scope_name: str) -> str:
"""Translate scope name to service name which can be used in dns.
230 = 253 - len('replica.') - len('.service.consul')
"""
def replace_char(match):
def replace_char(match: Any) -> str:
c = match.group(0)
return '-' if c in '. _' else "u{:04d}".format(ord(c))
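
Function objects carry no declared attributes, so strict mode rejects plain wrapper.last_result assignments; the getattr/setattr dance in force_if_last_failed above keeps the memoized result type-clean. Usage is unchanged, as in this hypothetical sketch:

@force_if_last_failed
def register(name: str, force: bool = False) -> bool:
    print('register', name, 'forced' if force else 'normal')
    return force  # pretend the plain attempt fails (returns False)

register('patroni')  # normal attempt, wrapper remembers last_result=False
register('patroni')  # last_result is False, so force=True gets injected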
@@ -188,7 +202,7 @@ def service_name_from_scope_name(scope_name):
class Consul(AbstractDCS):
def __init__(self, config):
def __init__(self, config: Dict[str, Any]) -> None:
super(Consul, self).__init__(config)
self._base_path = self._base_path[1:]
self._scope = config['scope']
@@ -198,9 +212,9 @@ class Consul(AbstractDCS):
retry_exceptions=(ConsulInternalError, HTTPException,
HTTPError, socket.error, socket.timeout))
kwargs = {}
if 'url' in config:
r = urlparse(config['url'])
url: str = config['url']
r = urlparse(url)
config.update({'scheme': r.scheme, 'host': r.hostname, 'port': r.port or 8500})
elif 'host' in config:
host, port = split_host_port(config.get('host', '127.0.0.1:8500'), 8500)
@@ -215,7 +229,7 @@ class Consul(AbstractDCS):
config['cert'] = (config['cert'], config['key'])
config_keys = ('host', 'port', 'token', 'scheme', 'cert', 'ca_cert', 'dc', 'consistency')
kwargs = {p: config.get(p) for p in config_keys if config.get(p)}
kwargs: Dict[str, Any] = {p: config.get(p) for p in config_keys if config.get(p)}
verify = config.get('verify')
if not isinstance(verify, bool):
@@ -240,10 +254,10 @@ class Consul(AbstractDCS):
self.create_session()
self._previous_loop_token = self._client.token
def retry(self, *args, **kwargs):
return self._retry.copy()(*args, **kwargs)
def retry(self, method: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
return self._retry.copy()(method, *args, **kwargs)
def create_session(self):
def create_session(self) -> None:
while not self._session:
try:
self.refresh_session()
@@ -251,13 +265,14 @@ class Consul(AbstractDCS):
logger.info('waiting on consul')
time.sleep(5)
def reload_config(self, config):
def reload_config(self, config: Union['Config', Dict[str, Any]]) -> None:
super(Consul, self).reload_config(config)
consul_config = config.get('consul', {})
self._client.reload_config(consul_config)
self._previous_loop_service_tags = self._service_tags
self._service_tags = sorted(consul_config.get('service_tags', []))
self._service_tags: List[str] = consul_config.get('service_tags', [])
self._service_tags.sort()
should_register_service = consul_config.get('register_service', False)
if should_register_service and not self._register_service:
@@ -266,20 +281,21 @@ class Consul(AbstractDCS):
self._previous_loop_register_service = self._register_service
self._register_service = should_register_service
def set_ttl(self, ttl):
def set_ttl(self, ttl: int) -> Optional[bool]:
if self._client.http.set_ttl(ttl / 2.0): # Consul multiplies the TTL by 2x
self._session = None
self.__do_not_watch = True
return None
@property
def ttl(self):
def ttl(self) -> int:
return self._client.http.ttl * 2 # we multiply the value by 2 because it was divided in the `set_ttl()` method
def set_retry_timeout(self, retry_timeout):
def set_retry_timeout(self, retry_timeout: int) -> None:
self._retry.deadline = retry_timeout
self._client.http.set_read_timeout(retry_timeout)
def adjust_ttl(self):
def adjust_ttl(self) -> None:
try:
settings = self._client.agent.self()
min_ttl = (settings['Config']['SessionTTLMin'] or 10000000000) / 1000000000.0
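
The halving and doubling around Consul session TTLs is easy to misread; a short worked example of the code above (dcs stands for a configured Consul instance):

dcs.set_ttl(30)   # stores 30 / 2.0 = 15 in the HTTP client, because Consul
                  # waits up to 2x the session TTL before invalidating it
dcs.ttl           # -> 15 * 2 = 30, the originally configured value
10000000000 / 1000000000.0  # adjust_ttl: SessionTTLMin is nanoseconds -> 10.0s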
@@ -288,7 +304,7 @@ class Consul(AbstractDCS):
except Exception:
logger.exception('adjust_ttl')
def _do_refresh_session(self, force=False):
def _do_refresh_session(self, force: bool = False) -> bool:
""":returns: `!True` if it had to create new session"""
if not force and self._session and self._last_session_refresh + self._loop_wait > time.time():
return False
@@ -312,7 +328,7 @@ class Consul(AbstractDCS):
self._last_session_refresh = time.time()
return ret
def refresh_session(self):
def refresh_session(self) -> bool:
try:
return self.retry(self._do_refresh_session)
except (ConsulException, RetryFailedError):
@@ -320,10 +336,10 @@ class Consul(AbstractDCS):
raise ConsulError('Failed to renew/create session')
@staticmethod
def member(node):
def member(node: Dict[str, str]) -> Member:
return Member.from_node(node['ModifyIndex'], os.path.basename(node['Key']), node.get('Session'), node['Value'])
def _cluster_from_nodes(self, nodes):
def _cluster_from_nodes(self, nodes: Dict[str, Any]) -> Cluster:
# get initialize flag
initialize = nodes.get(self._INITIALIZE)
initialize = initialize and initialize['Value']
@@ -351,7 +367,7 @@ class Consul(AbstractDCS):
slots = None
try:
last_lsn = int(last_lsn)
last_lsn = int(last_lsn or '')
except Exception:
last_lsn = 0
@@ -384,7 +400,7 @@ class Consul(AbstractDCS):
return Cluster(initialize, config, leader, last_lsn, members, failover, sync, history, slots, failsafe)
def _cluster_loader(self, path):
def _cluster_loader(self, path: str) -> Cluster:
_, results = self.retry(self._client.kv.get, path, recurse=True)
if results is None:
raise NotFound
@@ -395,9 +411,9 @@ class Consul(AbstractDCS):
return self._cluster_from_nodes(nodes)
def _citus_cluster_loader(self, path):
def _citus_cluster_loader(self, path: str) -> Dict[int, Cluster]:
_, results = self.retry(self._client.kv.get, path, recurse=True)
clusters = defaultdict(dict)
clusters: Dict[int, Dict[str, Cluster]] = defaultdict(dict)
for node in results or []:
key = node['Key'][len(path):].split('/', 1)
if len(key) == 2 and citus_group_re.match(key[0]):
@@ -405,7 +421,9 @@ class Consul(AbstractDCS):
clusters[int(key[0])][key[1]] = node
return {group: self._cluster_from_nodes(nodes) for group, nodes in clusters.items()}
def _load_cluster(self, path, loader):
def _load_cluster(
self, path: str, loader: Callable[[str], Union[Cluster, Dict[int, Cluster]]]
) -> Union[Cluster, Dict[int, Cluster]]:
try:
return loader(path)
except NotFound:
@@ -415,7 +433,7 @@ class Consul(AbstractDCS):
raise ConsulError('Consul is not responding properly')
@catch_consul_errors
def touch_member(self, data):
def touch_member(self, data: Dict[str, Any]) -> bool:
cluster = self.cluster
member = cluster and cluster.get_member(self._name, fallback_to_leader=False)
@@ -447,30 +465,32 @@ class Consul(AbstractDCS):
logger.exception('touch_member')
return False
def _set_service_name(self):
def _set_service_name(self) -> None:
self._service_name = service_name_from_scope_name(self._scope)
if self._scope != self._service_name:
logger.warning('Using %s as consul service name instead of scope name %s', self._service_name, self._scope)
@catch_consul_errors
def register_service(self, service_name, **kwargs):
def register_service(self, service_name: str, **kwargs: Any) -> bool:
logger.info('Register service %s, params %s', service_name, kwargs)
return self._client.agent.service.register(service_name, **kwargs)
@catch_consul_errors
def deregister_service(self, service_id):
def deregister_service(self, service_id: str) -> bool:
logger.info('Deregister service %s', service_id)
# service_id may contain special characters, but it is used as part of the URI in the deregister request
service_id = quote(service_id)
return self._client.agent.service.deregister(service_id)
def _update_service(self, data):
def _update_service(self, data: Dict[str, Any]) -> Optional[bool]:
service_name = self._service_name
role = data['role'].replace('_', '-')
state = data['state']
api_parts = urlparse(data['api_url'])
api_url: str = data['api_url']
api_parts = urlparse(api_url)
api_parts = api_parts._replace(path='/{0}'.format(role))
conn_parts = urlparse(data['conn_url'])
conn_url: str = data['conn_url']
conn_parts = urlparse(conn_url)
check = base.Check.http(api_parts.geturl(), self._service_check_interval,
deregister='{0}s'.format(self._client.http.ttl * 10))
if self._service_check_tls_server_name is not None:
@@ -506,7 +526,7 @@ class Consul(AbstractDCS):
logger.warning('Could not register service: unknown role type %s', role)
@force_if_last_failed
def update_service(self, old_data, new_data, force=False):
def update_service(self, old_data: Dict[str, Any], new_data: Dict[str, Any], force: bool = False) -> Optional[bool]:
update = False
for key in ['role', 'api_url', 'conn_url', 'state']:
@@ -523,7 +543,7 @@ class Consul(AbstractDCS):
):
return self._update_service(new_data)
def _do_attempt_to_acquire_leader(self, retry):
def _do_attempt_to_acquire_leader(self, retry: Retry) -> bool:
try:
return retry(self._client.kv.put, self.leader_path, self._name, acquire=self._session)
except InvalidSession:
@@ -540,7 +560,7 @@ class Consul(AbstractDCS):
return retry(self._client.kv.put, self.leader_path, self._name, acquire=self._session)
@catch_return_false_exception
def attempt_to_acquire_leader(self):
def attempt_to_acquire_leader(self) -> bool:
retry = self._retry.copy()
self._run_and_handle_exceptions(self._do_refresh_session, retry=retry)
@@ -554,31 +574,31 @@ class Consul(AbstractDCS):
return ret
def take_leader(self):
def take_leader(self) -> bool:
return self.attempt_to_acquire_leader()
@catch_consul_errors
def set_failover_value(self, value, index=None):
def set_failover_value(self, value: str, index: Optional[int] = None) -> bool:
return self._client.kv.put(self.failover_path, value, cas=index)
@catch_consul_errors
def set_config_value(self, value, index=None):
def set_config_value(self, value: str, index: Optional[int] = None) -> bool:
return self._client.kv.put(self.config_path, value, cas=index)
@catch_consul_errors
def _write_leader_optime(self, last_lsn):
def _write_leader_optime(self, last_lsn: str) -> bool:
return self._client.kv.put(self.leader_optime_path, last_lsn)
@catch_consul_errors
def _write_status(self, value):
def _write_status(self, value: str) -> bool:
return self._client.kv.put(self.status_path, value)
@catch_consul_errors
def _write_failsafe(self, value):
def _write_failsafe(self, value: str) -> bool:
return self._client.kv.put(self.failsafe_path, value)
@staticmethod
def _run_and_handle_exceptions(method, *args, **kwargs):
def _run_and_handle_exceptions(method: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
retry = kwargs.pop('retry', None)
try:
return retry(method, *args, **kwargs) if retry else method(*args, **kwargs)
@@ -588,7 +608,7 @@ class Consul(AbstractDCS):
raise ReturnFalseException
@catch_return_false_exception
def _update_leader(self):
def _update_leader(self) -> bool:
retry = self._retry.copy()
self._run_and_handle_exceptions(self._do_refresh_session, True, retry=retry)
@@ -601,7 +621,7 @@ class Consul(AbstractDCS):
if retry.deadline < 1:
raise ConsulError('update_leader timeout')
logger.warning('Recreating the leader key due to session mismatch')
if cluster.leader:
if cluster and cluster.leader:
self._run_and_handle_exceptions(self._client.kv.delete, self.leader_path, cas=cluster.leader.index)
retry.deadline = retry.stoptime - time.time()
@@ -613,37 +633,39 @@ class Consul(AbstractDCS):
return bool(self._session)
@catch_consul_errors
def initialize(self, create_new=True, sysid=''):
def initialize(self, create_new: bool = True, sysid: str = '') -> bool:
kwargs = {'cas': 0} if create_new else {}
return self.retry(self._client.kv.put, self.initialize_path, sysid, **kwargs)
@catch_consul_errors
def cancel_initialization(self):
def cancel_initialization(self) -> bool:
return self.retry(self._client.kv.delete, self.initialize_path)
@catch_consul_errors
def delete_cluster(self):
def delete_cluster(self) -> bool:
return self.retry(self._client.kv.delete, self.client_path(''), recurse=True)
@catch_consul_errors
def set_history_value(self, value):
def set_history_value(self, value: str) -> bool:
return self._client.kv.put(self.history_path, value)
@catch_consul_errors
def _delete_leader(self):
def _delete_leader(self) -> bool:
cluster = self.cluster
if cluster and isinstance(cluster.leader, Leader) and cluster.leader.name == self._name:
if cluster and isinstance(cluster.leader, Leader) and\
cluster.leader.name == self._name and isinstance(cluster.leader.index, int):
return self._client.kv.delete(self.leader_path, cas=cluster.leader.index)
return True
@catch_consul_errors
def set_sync_state_value(self, value, index=None):
def set_sync_state_value(self, value: str, index: Optional[int] = None) -> bool:
return self.retry(self._client.kv.put, self.sync_path, value, cas=index)
@catch_consul_errors
def delete_sync_state(self, index=None):
def delete_sync_state(self, index: Optional[int] = None) -> bool:
return self.retry(self._client.kv.delete, self.sync_path, cas=index)
def watch(self, leader_index, timeout):
def watch(self, leader_index: Optional[int], timeout: float) -> bool:
self._last_session_refresh = 0
if self.__do_not_watch:
self.__do_not_watch = False

View File

@@ -16,7 +16,7 @@ from dns import resolver
from http.client import HTTPException
from queue import Queue
from threading import Thread
from typing import List, Optional
from typing import Any, Callable, Collection, Dict, List, Optional, Union, Tuple, Type, TYPE_CHECKING
from urllib.parse import urlparse
from urllib3 import Timeout
from urllib3.exceptions import HTTPError, ReadTimeoutError, ProtocolError
@@ -26,6 +26,8 @@ from . import AbstractDCS, Cluster, ClusterConfig, Failover, Leader, Member, Syn
from ..exceptions import DCSError
from ..request import get as requests_get
from ..utils import Retry, RetryFailedError, split_host_port, uri, USER_AGENT
if TYPE_CHECKING: # pragma: no cover
from ..config import Config
logger = logging.getLogger(__name__)
@@ -38,18 +40,21 @@ class EtcdError(DCSError):
pass
_AddrInfo = Tuple[socket.AddressFamily, socket.SocketKind, int, str, Union[Tuple[str, int], Tuple[str, int, int, int]]]
class DnsCachingResolver(Thread):
def __init__(self, cache_time=600.0, cache_fail_time=30.0):
def __init__(self, cache_time: float = 600.0, cache_fail_time: float = 30.0) -> None:
super(DnsCachingResolver, self).__init__()
self._cache = {}
self._cache: Dict[Tuple[str, int], Tuple[float, List[_AddrInfo]]] = {}
self._cache_time = cache_time
self._cache_fail_time = cache_fail_time
self._resolve_queue = Queue()
self._resolve_queue: Queue[Tuple[Tuple[str, int], int]] = Queue()
self.daemon = True
self.start()
def run(self):
def run(self) -> None:
while True:
(host, port), attempt = self._resolve_queue.get()
response = self._do_resolve(host, port)
@@ -60,7 +65,7 @@ class DnsCachingResolver(Thread):
self.resolve_async(host, port, attempt + 1)
time.sleep(1)
def resolve(self, host, port):
def resolve(self, host: str, port: int) -> List[_AddrInfo]:
current_time = time.time()
cached_time, response = self._cache.get((host, port), (0, []))
time_passed = current_time - cached_time
@@ -71,14 +76,14 @@ class DnsCachingResolver(Thread):
response = new_response
return response
def resolve_async(self, host, port, attempt=0):
def resolve_async(self, host: str, port: int, attempt: int = 0) -> None:
self._resolve_queue.put(((host, port), attempt))
def remove(self, host, port):
def remove(self, host: str, port: int) -> None:
self._cache.pop((host, port), None)
@staticmethod
def _do_resolve(host, port):
def _do_resolve(host: str, port: int) -> List[_AddrInfo]:
try:
return socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM, socket.IPPROTO_TCP)
except Exception as e:
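
A small usage sketch of the resolver above (the host name is hypothetical):

resolver = DnsCachingResolver(cache_time=600.0, cache_fail_time=30.0)
addrs = resolver.resolve('etcd1.example.com', 2379)  # getaddrinfo result, cached ~10 min
resolver.resolve_async('etcd1.example.com', 2379)    # refresh in the background thread
resolver.remove('etcd1.example.com', 2379)           # evict an entry explicitly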
@@ -88,13 +93,15 @@ class DnsCachingResolver(Thread):
class AbstractEtcdClientWithFailover(abc.ABC, etcd.Client):
def __init__(self, config, dns_resolver, cache_ttl=300):
ERROR_CLS: Type[Exception]
def __init__(self, config: Dict[str, Any], dns_resolver: DnsCachingResolver, cache_ttl: int = 300) -> None:
self._dns_resolver = dns_resolver
self.set_machines_cache_ttl(cache_ttl)
self._machines_cache_updated = 0
args = {p: config.get(p) for p in ('host', 'port', 'protocol', 'use_proxies', 'username', 'password',
'cert', 'ca_cert') if config.get(p)}
super(AbstractEtcdClientWithFailover, self).__init__(read_timeout=config['retry_timeout'], **args)
kwargs = {p: config.get(p) for p in ('host', 'port', 'protocol', 'use_proxies',
'username', 'password', 'cert', 'ca_cert') if config.get(p)}
super(AbstractEtcdClientWithFailover, self).__init__(read_timeout=config['retry_timeout'], **kwargs)
# For some reason python3-etcd on Debian and Ubuntu is not based on the latest version
# Workaround for the case when https://github.com/jplana/python-etcd/pull/196 is not applied
self.http.connection_pool_kw.pop('ssl_version', None)
@@ -106,7 +113,7 @@ class AbstractEtcdClientWithFailover(abc.ABC, etcd.Client):
self._read_options.add('retry')
self._del_conditions.add('retry')
def _calculate_timeouts(self, etcd_nodes, timeout=None):
def _calculate_timeouts(self, etcd_nodes: int, timeout: Optional[float] = None) -> Tuple[int, float, int]:
"""Calculate a request timeout and number of retries per single etcd node.
If the timeout per node is too small (less than one second) we will reduce the number of nodes.
For a cluster with only one node we will attempt 2 retries.
@@ -133,38 +140,39 @@ class AbstractEtcdClientWithFailover(abc.ABC, etcd.Client):
return etcd_nodes, per_node_timeout, per_node_retries - 1
def reload_config(self, config):
def reload_config(self, config: Dict[str, Any]) -> None:
self.username = config.get('username')
self.password = config.get('password')
def _get_headers(self):
def _get_headers(self) -> Dict[str, str]:
basic_auth = ':'.join((self.username, self.password)) if self.username and self.password else None
return urllib3.make_headers(basic_auth=basic_auth, user_agent=USER_AGENT)
def _prepare_common_parameters(self, etcd_nodes, timeout=None):
kwargs = {'headers': self._get_headers(), 'redirect': self.allow_redirect, 'preload_content': False}
def _prepare_common_parameters(self, etcd_nodes: int, timeout: Optional[float] = None) -> Dict[str, Any]:
kwargs: Dict[str, Any] = {'headers': self._get_headers(),
'redirect': self.allow_redirect, 'preload_content': False}
if timeout is not None:
kwargs.update(retries=0, timeout=timeout)
else:
_, per_node_timeout, per_node_retries = self._calculate_timeouts(etcd_nodes)
connect_timeout = max(1, per_node_timeout / 2)
connect_timeout = max(1.0, per_node_timeout / 2.0)
kwargs.update(timeout=Timeout(connect=connect_timeout, total=per_node_timeout), retries=per_node_retries)
return kwargs
def set_machines_cache_ttl(self, cache_ttl):
def set_machines_cache_ttl(self, cache_ttl: int) -> None:
self._machines_cache_ttl = cache_ttl
@abc.abstractmethod
def _prepare_get_members(self, etcd_nodes):
def _prepare_get_members(self, etcd_nodes: int) -> Dict[str, Any]:
"""returns: request parameters"""
@abc.abstractmethod
def _get_members(self, base_uri, **kwargs):
def _get_members(self, base_uri: str, **kwargs: Any) -> List[str]:
"""returns: list of clientURLs"""
@property
def machines_cache(self):
def machines_cache(self) -> List[str]:
base_uri, cache = self._base_uri, self._machines_cache
return ([base_uri] if base_uri in cache else []) + [machine for machine in cache if machine != base_uri]
@@ -207,10 +215,14 @@ class AbstractEtcdClientWithFailover(abc.ABC, etcd.Client):
return self._get_machines_list(self.machines_cache)
def set_read_timeout(self, timeout):
def set_read_timeout(self, timeout: float) -> None:
self._read_timeout = timeout
def _do_http_request(self, retry, machines_cache, request_executor, method, path, fields=None, **kwargs):
def _do_http_request(self, retry: Optional[Retry], machines_cache: List[str],
request_executor: Callable[..., urllib3.response.HTTPResponse],
method: str, path: str, fields: Optional[Dict[str, Any]] = None,
**kwargs: Any) -> urllib3.response.HTTPResponse:
is_watch_request = isinstance(fields, dict) and fields.get('wait') == 'true'
if fields is not None:
kwargs['fields'] = fields
some_request_failed = False
@@ -233,8 +245,7 @@ class AbstractEtcdClientWithFailover(abc.ABC, etcd.Client):
# whether the key didn't receive an update or there is a network problem.
elif i + 1 < len(machines_cache):
self.set_base_uri(machines_cache[i + 1])
if (isinstance(fields, dict) and fields.get("wait") == "true"
and isinstance(e, (ReadTimeoutError, ProtocolError))):
if is_watch_request and isinstance(e, (ReadTimeoutError, ProtocolError)):
logger.debug("Watch timed out.")
raise etcd.EtcdWatchTimedOut("Watch timed out: {0}".format(e), cause=e)
logger.error("Request to server %s failed: %r", base_uri, e)
@@ -246,10 +257,12 @@ class AbstractEtcdClientWithFailover(abc.ABC, etcd.Client):
raise etcd.EtcdConnectionFailed('No more machines in the cluster')
@abc.abstractmethod
def _prepare_request(self, kwargs, params=None, method=None):
def _prepare_request(self, kwargs: Dict[str, Any], params: Optional[Dict[str, Any]] = None,
method: Optional[str] = None) -> Callable[..., urllib3.response.HTTPResponse]:
"""returns: request_executor"""
def api_execute(self, path, method, params=None, timeout=None):
def api_execute(self, path: str, method: str, params: Optional[Dict[str, Any]] = None,
timeout: Optional[float] = None) -> Any:
retry = params.pop('retry', None) if isinstance(params, dict) else None
# Update machines_cache if previous attempt of update has failed
@@ -277,6 +290,7 @@ class AbstractEtcdClientWithFailover(abc.ABC, etcd.Client):
etcd_nodes = len(machines_cache)
except Exception as e:
logger.debug('Failed to update list of etcd nodes: %r', e)
assert isinstance(retry, Retry) # etcd.EtcdConnectionFailed is raised only if retry is not None!
sleeptime = retry.sleeptime
remaining_time = retry.stoptime - sleeptime - time.time()
nodes, timeout, retries = self._calculate_timeouts(etcd_nodes, remaining_time)
@@ -287,22 +301,22 @@ class AbstractEtcdClientWithFailover(abc.ABC, etcd.Client):
retry.sleep_func(sleeptime)
retry.update_delay()
# We still have some time left. Partially reduce `machines_cache` and retry request
kwargs.update(timeout=Timeout(connect=max(1, timeout / 2), total=timeout), retries=retries)
kwargs.update(timeout=Timeout(connect=max(1.0, timeout / 2.0), total=timeout), retries=retries)
machines_cache = machines_cache[:nodes]
@staticmethod
def get_srv_record(host):
def get_srv_record(host: str) -> List[Tuple[str, int]]:
try:
return [(r.target.to_text(True), r.port) for r in resolver.query(host, 'SRV')]
except DNSException:
return []
def _get_machines_cache_from_srv(self, srv, srv_suffix=None):
def _get_machines_cache_from_srv(self, srv: str, srv_suffix: Optional[str] = None) -> List[str]:
"""Fetch list of etcd-cluster member by resolving _etcd-server._tcp. SRV record.
This record should contain list of host and peer ports which could be used to run
'GET http://{host}:{port}/members' request (peer protocol)"""
ret = []
ret: List[str] = []
for r in ['-client-ssl', '-client', '-ssl', '', '-server-ssl', '-server']:
r = '{0}-{1}'.format(r, srv_suffix) if srv_suffix else r
protocol = 'https' if '-ssl' in r else 'http'
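
For context, get_srv_record() above turns SRV answers into (host, port) pairs; with hypothetical zone data such as:

#   _etcd-client._tcp.example.com. 300 IN SRV 0 0 2379 etcd1.example.com.
#   _etcd-client._tcp.example.com. 300 IN SRV 0 0 2379 etcd2.example.com.
AbstractEtcdClientWithFailover.get_srv_record('_etcd-client._tcp.example.com')
# -> [('etcd1.example.com', 2379), ('etcd2.example.com', 2379)]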
@@ -327,15 +341,15 @@ class AbstractEtcdClientWithFailover(abc.ABC, etcd.Client):
logger.warning('Can not resolve SRV for %s', srv)
return list(set(ret))
def _get_machines_cache_from_dns(self, host, port):
def _get_machines_cache_from_dns(self, host: str, port: int) -> List[str]:
"""One host might be resolved into multiple ip addresses. We will make list out of it"""
if self.protocol == 'http':
ret = map(lambda res: uri(self.protocol, res[-1][:2]), self._dns_resolver.resolve(host, port))
ret = [uri(self.protocol, res[-1][:2]) for res in self._dns_resolver.resolve(host, port)]
if ret:
return list(set(ret))
return [uri(self.protocol, (host, port))]
def _get_machines_cache_from_config(self):
def _get_machines_cache_from_config(self) -> List[str]:
if 'proxy' in self._config:
return [uri(self.protocol, (self._config['host'], self._config['port']))]
@@ -351,13 +365,14 @@ class AbstractEtcdClientWithFailover(abc.ABC, etcd.Client):
return machines_cache
@staticmethod
def _update_dns_cache(func, machines):
def _update_dns_cache(func: Callable[[str, int], None], machines: List[str]) -> None:
for url in machines:
r = urlparse(url)
port = r.port or (443 if r.scheme == 'https' else 80)
func(r.hostname, port)
if r.hostname:
port = r.port or (443 if r.scheme == 'https' else 80)
func(r.hostname, port)
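
The hostname guard added to _update_dns_cache above matters because urlparse() types hostname as Optional[str]:

from urllib.parse import urlparse

urlparse('unix:///tmp/patroni.sock').hostname  # -> None, would crash a (str, int) callback
r = urlparse('https://etcd1:2379')
if r.hostname:            # pyright narrows Optional[str] to str inside this branch
    port = r.port or 443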
def _load_machines_cache(self):
def _load_machines_cache(self) -> bool:
"""This method should fill up `_machines_cache` from scratch.
It could happen only in two cases:
1. During class initialization
@@ -417,7 +432,7 @@ class AbstractEtcdClientWithFailover(abc.ABC, etcd.Client):
self._machines_cache_updated = time.time()
return ret
def set_base_uri(self, value):
def set_base_uri(self, value: str) -> None:
if self._base_uri != value:
logger.info('Selected new etcd server %s', value)
self._base_uri = value
@@ -427,22 +442,22 @@ class EtcdClient(AbstractEtcdClientWithFailover):
ERROR_CLS = EtcdError
def __del__(self):
if self.http is not None:
try:
self.http.clear()
except (ReferenceError, TypeError, AttributeError):
pass
def __del__(self) -> None:
try:
self.http.clear()
except (ReferenceError, TypeError, AttributeError):
pass
def _prepare_get_members(self, etcd_nodes):
def _prepare_get_members(self, etcd_nodes: int) -> Dict[str, Any]:
return self._prepare_common_parameters(etcd_nodes)
def _get_members(self, base_uri, **kwargs):
def _get_members(self, base_uri: str, **kwargs: Any) -> List[str]:
response = self.http.request(self._MGET, base_uri + self.version_prefix + '/machines', **kwargs)
data = self._handle_server_response(response).data.decode('utf-8')
return [m.strip() for m in data.split(',') if m.strip()]
def _prepare_request(self, kwargs, params=None, method=None):
def _prepare_request(self, kwargs: Dict[str, Any], params: Optional[Dict[str, Any]] = None,
method: Optional[str] = None) -> Callable[..., urllib3.response.HTTPResponse]:
kwargs['fields'] = params
if method in (self._MPOST, self._MPUT):
kwargs['encode_multipart'] = False
@@ -451,25 +466,32 @@ class EtcdClient(AbstractEtcdClientWithFailover):
class AbstractEtcd(AbstractDCS):
def __init__(self, config, client_cls, retry_errors_cls):
def __init__(self, config: Dict[str, Any], client_cls: Type[AbstractEtcdClientWithFailover],
retry_errors_cls: Union[Type[Exception], Tuple[Type[Exception], ...]]) -> None:
super(AbstractEtcd, self).__init__(config)
self._retry = Retry(deadline=config['retry_timeout'], max_delay=1, max_tries=-1,
retry_exceptions=retry_errors_cls)
self._ttl = int(config.get('ttl') or 30)
self._client = self.get_etcd_client(config, client_cls)
self._abstract_client = self.get_etcd_client(config, client_cls)
self.__do_not_watch = False
self._has_failed = False
def reload_config(self, config):
@property
@abc.abstractmethod
def _client(self) -> AbstractEtcdClientWithFailover:
"""return correct type of etcd client"""
def reload_config(self, config: Union['Config', Dict[str, Any]]) -> None:
super(AbstractEtcd, self).reload_config(config)
self._client.reload_config(config.get(self.__class__.__name__.lower(), {}))
def retry(self, *args, **kwargs):
def retry(self, method: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
retry = self._retry.copy()
kwargs['retry'] = retry
return retry(*args, **kwargs)
return retry(method, *args, **kwargs)
def _handle_exception(self, e, name='', do_sleep=False, raise_ex=None):
def _handle_exception(self, e: Exception, name: str = '', do_sleep: bool = False,
raise_ex: Optional[Exception] = None) -> None:
if not self._has_failed:
logger.exception(name)
else:
@@ -480,7 +502,7 @@ class AbstractEtcd(AbstractDCS):
if isinstance(raise_ex, Exception):
raise raise_ex
def _run_and_handle_exceptions(self, method, *args, **kwargs):
def _run_and_handle_exceptions(self, method: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
retry = kwargs.pop('retry', self.retry)
try:
return retry(method, *args, **kwargs) if retry else method(*args, **kwargs)
@@ -492,19 +514,20 @@ class AbstractEtcd(AbstractDCS):
except Exception as e:
self._handle_exception(e, raise_ex=self._client.ERROR_CLS('unexpected error'))
@staticmethod
def set_socket_options(sock, socket_options):
def set_socket_options(self, sock: socket.socket,
socket_options: Optional[Collection[Tuple[int, int, int]]]) -> None:
if socket_options:
for opt in socket_options:
sock.setsockopt(*opt)
def get_etcd_client(self, config, client_cls):
def get_etcd_client(self, config: Dict[str, Any],
client_cls: Type[AbstractEtcdClientWithFailover]) -> AbstractEtcdClientWithFailover:
config = deepcopy(config)
if 'proxy' in config:
config['use_proxies'] = True
config['url'] = config['proxy']
if 'url' in config:
if 'url' in config and isinstance(config['url'], str):
r = urlparse(config['url'])
config.update({'protocol': r.scheme, 'host': r.hostname, 'port': r.port or 2379,
'username': r.username, 'password': r.password})
@@ -516,10 +539,11 @@ class AbstractEtcd(AbstractDCS):
if isinstance(hosts, str):
hosts = hosts.split(',')
config['hosts'] = []
config_hosts: List[str] = []
for value in hosts:
if isinstance(value, str):
config['hosts'].append(uri(protocol, split_host_port(value.strip(), default_port)))
config_hosts.append(uri(protocol, split_host_port(value.strip(), default_port)))
config['hosts'] = config_hosts
elif 'host' in config:
host, port = split_host_port(config['host'], 2379)
config['host'] = host
@@ -538,8 +562,10 @@ class AbstractEtcd(AbstractDCS):
dns_resolver = DnsCachingResolver()
def create_connection_patched(address, timeout=socket._GLOBAL_DEFAULT_TIMEOUT,
source_address=None, socket_options=None):
def create_connection_patched(
address: Tuple[str, int], timeout: Any = object(),
source_address: Optional[Any] = None, socket_options: Optional[Collection[Tuple[int, int, int]]] = None
) -> socket.socket:
host, port = address
if host.startswith('['):
host = host.strip('[]')
@@ -549,7 +575,7 @@ class AbstractEtcd(AbstractDCS):
try:
sock = socket.socket(af, socktype, proto)
self.set_socket_options(sock, socket_options)
if timeout is not socket._GLOBAL_DEFAULT_TIMEOUT:
if timeout is None or isinstance(timeout, (float, int)):
sock.settimeout(timeout)
if source_address:
sock.bind(source_address)
@@ -580,7 +606,7 @@ class AbstractEtcd(AbstractDCS):
time.sleep(5)
return client
def set_ttl(self, ttl):
def set_ttl(self, ttl: int) -> Optional[bool]:
ttl = int(ttl)
ret = self._ttl != ttl
self._ttl = ttl
@@ -588,16 +614,16 @@ class AbstractEtcd(AbstractDCS):
return ret
@property
def ttl(self):
def ttl(self) -> int:
return self._ttl
def set_retry_timeout(self, retry_timeout):
def set_retry_timeout(self, retry_timeout: int) -> None:
self._retry.deadline = retry_timeout
self._client.set_read_timeout(retry_timeout)
def catch_etcd_errors(func):
def wrapper(self, *args, **kwargs):
def catch_etcd_errors(func: Callable[..., Any]) -> Any:
def wrapper(self: AbstractEtcd, *args: Any, **kwargs: Any) -> Any:
try:
retval = func(self, *args, **kwargs) is not None
self._has_failed = False
@@ -613,18 +639,24 @@ def catch_etcd_errors(func):
class Etcd(AbstractEtcd):
def __init__(self, config):
def __init__(self, config: Dict[str, Any]) -> None:
super(Etcd, self).__init__(config, EtcdClient, (etcd.EtcdLeaderElectionInProgress, EtcdRaftInternal))
self.__do_not_watch = False
def set_ttl(self, ttl):
@property
def _client(self) -> EtcdClient:
assert isinstance(self._abstract_client, EtcdClient)
return self._abstract_client
def set_ttl(self, ttl: int) -> Optional[bool]:
self.__do_not_watch = super(Etcd, self).set_ttl(ttl)
return None
@staticmethod
def member(node):
def member(node: etcd.EtcdResult) -> Member:
return Member.from_node(node.modifiedIndex, os.path.basename(node.key), node.ttl, node.value)
def _cluster_from_nodes(self, etcd_index, nodes):
def _cluster_from_nodes(self, etcd_index: int, nodes: Dict[str, etcd.EtcdResult]) -> Cluster:
# get initialize flag
initialize = nodes.get(self._INITIALIZE)
initialize = initialize and initialize.value
@@ -652,7 +684,7 @@ class Etcd(AbstractEtcd):
slots = None
try:
last_lsn = int(last_lsn)
last_lsn = int(last_lsn or '')
except Exception:
last_lsn = 0
@@ -685,13 +717,13 @@ class Etcd(AbstractEtcd):
return Cluster(initialize, config, leader, last_lsn, members, failover, sync, history, slots, failsafe)
def _cluster_loader(self, path):
def _cluster_loader(self, path: str) -> Cluster:
result = self.retry(self._client.read, path, recursive=True)
nodes = {node.key[len(result.key):].lstrip('/'): node for node in result.leaves}
return self._cluster_from_nodes(result.etcd_index, nodes)
def _citus_cluster_loader(self, path):
clusters = defaultdict(dict)
def _citus_cluster_loader(self, path: str) -> Dict[int, Cluster]:
clusters: Dict[int, Dict[str, etcd.EtcdResult]] = defaultdict(dict)
result = self.retry(self._client.read, path, recursive=True)
for node in result.leaves:
key = node.key[len(result.key):].lstrip('/').split('/', 1)
@@ -699,7 +731,9 @@ class Etcd(AbstractEtcd):
clusters[int(key[0])][key[1]] = node
return {group: self._cluster_from_nodes(result.etcd_index, nodes) for group, nodes in clusters.items()}
def _load_cluster(self, path, loader):
def _load_cluster(
self, path: str, loader: Callable[[str], Union[Cluster, Dict[int, Cluster]]]
) -> Union[Cluster, Dict[int, Cluster]]:
cluster = None
try:
cluster = loader(path)
@@ -708,18 +742,19 @@ class Etcd(AbstractEtcd):
except Exception as e:
self._handle_exception(e, 'get_cluster', raise_ex=EtcdError('Etcd is not responding properly'))
self._has_failed = False
assert cluster is not None
return cluster
@catch_etcd_errors
def touch_member(self, data):
data = json.dumps(data, separators=(',', ':'))
return self._client.set(self.member_path, data, self._ttl)
def touch_member(self, data: Dict[str, Any]) -> bool:
value = json.dumps(data, separators=(',', ':'))
return bool(self._client.set(self.member_path, value, self._ttl))
@catch_etcd_errors
def take_leader(self):
def take_leader(self) -> bool:
return self.retry(self._client.write, self.leader_path, self._name, ttl=self._ttl)
def _do_attempt_to_acquire_leader(self):
def _do_attempt_to_acquire_leader(self) -> bool:
try:
return bool(self.retry(self._client.write, self.leader_path, self._name, ttl=self._ttl, prevExist=False))
except etcd.EtcdAlreadyExist:
@@ -727,26 +762,26 @@ class Etcd(AbstractEtcd):
return False
@catch_return_false_exception
def attempt_to_acquire_leader(self):
def attempt_to_acquire_leader(self) -> bool:
return self._run_and_handle_exceptions(self._do_attempt_to_acquire_leader, retry=None)
@catch_etcd_errors
def set_failover_value(self, value, index=None):
return self._client.write(self.failover_path, value, prevIndex=index or 0)
def set_failover_value(self, value: str, index: Optional[int] = None) -> bool:
return bool(self._client.write(self.failover_path, value, prevIndex=index or 0))
@catch_etcd_errors
def set_config_value(self, value, index=None):
return self._client.write(self.config_path, value, prevIndex=index or 0)
def set_config_value(self, value: str, index: Optional[int] = None) -> bool:
return bool(self._client.write(self.config_path, value, prevIndex=index or 0))
@catch_etcd_errors
def _write_leader_optime(self, last_lsn):
return self._client.set(self.leader_optime_path, last_lsn)
def _write_leader_optime(self, last_lsn: str) -> bool:
return bool(self._client.set(self.leader_optime_path, last_lsn))
@catch_etcd_errors
def _write_status(self, value):
return self._client.set(self.status_path, value)
def _write_status(self, value: str) -> bool:
return bool(self._client.set(self.status_path, value))
def _do_update_leader(self):
def _do_update_leader(self) -> bool:
try:
return self.retry(self._client.write, self.leader_path, self._name,
prevValue=self._name, ttl=self._ttl) is not None
@@ -754,42 +789,42 @@ class Etcd(AbstractEtcd):
return self._do_attempt_to_acquire_leader()
@catch_etcd_errors
def _write_failsafe(self, value):
return self._client.set(self.failsafe_path, value)
def _write_failsafe(self, value: str) -> bool:
return bool(self._client.set(self.failsafe_path, value))
@catch_return_false_exception
def _update_leader(self):
return self._run_and_handle_exceptions(self._do_update_leader, retry=None)
def _update_leader(self) -> bool:
return bool(self._run_and_handle_exceptions(self._do_update_leader, retry=None))
@catch_etcd_errors
def initialize(self, create_new=True, sysid=""):
return self.retry(self._client.write, self.initialize_path, sysid, prevExist=(not create_new))
def initialize(self, create_new: bool = True, sysid: str = "") -> bool:
return bool(self.retry(self._client.write, self.initialize_path, sysid, prevExist=(not create_new)))
@catch_etcd_errors
def _delete_leader(self):
return self._client.delete(self.leader_path, prevValue=self._name)
def _delete_leader(self) -> bool:
return bool(self._client.delete(self.leader_path, prevValue=self._name))
@catch_etcd_errors
def cancel_initialization(self):
return self.retry(self._client.delete, self.initialize_path)
def cancel_initialization(self) -> bool:
return bool(self.retry(self._client.delete, self.initialize_path))
@catch_etcd_errors
def delete_cluster(self):
return self.retry(self._client.delete, self.client_path(''), recursive=True)
def delete_cluster(self) -> bool:
return bool(self.retry(self._client.delete, self.client_path(''), recursive=True))
@catch_etcd_errors
def set_history_value(self, value):
return self._client.write(self.history_path, value)
def set_history_value(self, value: str) -> bool:
return bool(self._client.write(self.history_path, value))
@catch_etcd_errors
def set_sync_state_value(self, value, index=None):
return self.retry(self._client.write, self.sync_path, value, prevIndex=index or 0)
def set_sync_state_value(self, value: str, index: Optional[int] = None) -> bool:
return bool(self.retry(self._client.write, self.sync_path, value, prevIndex=index or 0))
@catch_etcd_errors
def delete_sync_state(self, index=None):
return self.retry(self._client.delete, self.sync_path, prevIndex=index or 0)
def delete_sync_state(self, index: Optional[int] = None) -> bool:
return bool(self.retry(self._client.delete, self.sync_path, prevIndex=index or 0))
def watch(self, leader_index, timeout):
def watch(self, leader_index: Optional[int], timeout: float) -> bool:
if self.__do_not_watch:
self.__do_not_watch = False
return True

View File

@@ -10,12 +10,14 @@ import time
import urllib3
from collections import defaultdict
from threading import Condition, Lock, Thread
from enum import IntEnum
from urllib3.exceptions import ReadTimeoutError, ProtocolError
from threading import Condition, Lock, Thread
from typing import Any, Callable, Collection, Dict, Iterator, List, Optional, Tuple, Type, Union
from . import ClusterConfig, Cluster, Failover, Leader, Member, SyncState,\
TimelineHistory, ReturnFalseException, catch_return_false_exception, citus_group_re
from .etcd import AbstractEtcdClientWithFailover, AbstractEtcd, catch_etcd_errors
from .etcd import AbstractEtcdClientWithFailover, AbstractEtcd, catch_etcd_errors, DnsCachingResolver, Retry
from ..exceptions import DCSError, PatroniException
from ..utils import deep_compare, enable_keepalive, iter_response_objects, RetryFailedError, USER_AGENT
@@ -31,11 +33,27 @@ class UnsupportedEtcdVersion(PatroniException):
# google.golang.org/grpc/codes
GRPCCode = type('Enum', (), {'OK': 0, 'Canceled': 1, 'Unknown': 2, 'InvalidArgument': 3, 'DeadlineExceeded': 4,
'NotFound': 5, 'AlreadyExists': 6, 'PermissionDenied': 7, 'ResourceExhausted': 8,
'FailedPrecondition': 9, 'Aborted': 10, 'OutOfRange': 11, 'Unimplemented': 12,
'Internal': 13, 'Unavailable': 14, 'DataLoss': 15, 'Unauthenticated': 16})
GRPCcodeToText = {v: k for k, v in GRPCCode.__dict__.items() if not k.startswith('__') and isinstance(v, int)}
class GRPCCode(IntEnum):
OK = 0
Canceled = 1
Unknown = 2
InvalidArgument = 3
DeadlineExceeded = 4
NotFound = 5
AlreadyExists = 6
PermissionDenied = 7
ResourceExhausted = 8
FailedPrecondition = 9
Aborted = 10
OutOfRange = 11
Unimplemented = 12
Internal = 13
Unavailable = 14
DataLoss = 15
Unauthenticated = 16
GRPCcodeToText: Dict[int, str] = {v: k for k, v in GRPCCode.__dict__['_member_map_'].items()}
class Etcd3Exception(etcd.EtcdException):
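
The IntEnum above keeps the old integer-based lookups working while giving pyright concrete types; a quick sketch of the equivalences:

GRPCCode.NotFound == 5  # True: IntEnum members compare equal to plain ints
GRPCcodeToText[5]       # -> 'NotFound'; members hash like their int values,
                        # so the reversed _member_map_ can be keyed by an int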
@@ -44,22 +62,24 @@ class Etcd3Exception(etcd.EtcdException):
class Etcd3ClientError(Etcd3Exception):
def __init__(self, code=None, error=None, status=None):
def __init__(self, code: Optional[int] = None, error: Optional[str] = None, status: Optional[int] = None) -> None:
if not hasattr(self, 'error'):
self.error = error and error.strip()
self.codeText = GRPCcodeToText.get(code)
self.codeText = GRPCcodeToText.get(code) if code is not None else None
self.status = status
def __repr__(self):
return "<{0} error: '{1}', code: {2}>".format(self.__class__.__name__, self.error, self.code)
def __repr__(self) -> str:
return "<{0} error: '{1}', code: {2}>"\
.format(self.__class__.__name__, getattr(self, 'error', None), getattr(self, 'code', None))
__str__ = __repr__
def as_dict(self):
return {'error': self.error, 'code': self.code, 'codeText': self.codeText, 'status': self.status}
def as_dict(self) -> Dict[str, Any]:
return {'error': getattr(self, 'error', None), 'code': getattr(self, 'code', None),
'codeText': self.codeText, 'status': self.status}
@classmethod
def get_subclasses(cls):
def get_subclasses(cls) -> Iterator[Type['Etcd3ClientError']]:
for subclass in cls.__subclasses__():
for subsubclass in subclass.get_subclasses():
yield subsubclass
@@ -118,61 +138,93 @@ class InvalidAuthToken(Etcd3ClientError):
error = "etcdserver: invalid auth token"
errStringToClientError = {s.error: s for s in Etcd3ClientError.get_subclasses() if hasattr(s, 'error')}
errCodeToClientError = {s.code: s for s in Etcd3ClientError.__subclasses__()}
errStringToClientError = {getattr(s, 'error'): s for s in Etcd3ClientError.get_subclasses() if hasattr(s, 'error')}
errCodeToClientError = {getattr(s, 'code'): s for s in Etcd3ClientError.__subclasses__()}
def _raise_for_data(data, status_code=None):
def _raise_for_data(data: Union[bytes, str, Dict[str, Union[Any, Dict[str, Any]]]],
status_code: Optional[int] = None) -> Etcd3ClientError:
try:
error = data.get('error') or data.get('Error')
if isinstance(error, dict): # streaming response
status_code = error.get('http_code')
code = error['grpc_code']
error = error['message']
assert isinstance(data, dict)
data_error: Optional[Dict[str, Any]] = data.get('error') or data.get('Error')
if isinstance(data_error, dict): # streaming response
status_code = data_error.get('http_code')
code: Optional[int] = data_error['grpc_code']
error: str = data_error['message']
else:
code = data.get('code') or data.get('Code')
data_code = data.get('code') or data.get('Code')
assert not isinstance(data_code, dict)
code = data_code
error = str(data_error)
except Exception:
error = str(data)
code = GRPCCode.Unknown
err = errStringToClientError.get(error) or errCodeToClientError.get(code) or Unknown
raise err(code, error, status_code)
return err(code, error, status_code)
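
Returning the exception instead of raising it lets call sites spell the raise out, so pyright can see that such branches never fall through, as in the watcher code later in this file:

if 'error' in message:
    raise _raise_for_data(message)  # the checker knows nothing below runs on this path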
def to_bytes(v):
def to_bytes(v: Union[str, bytes]) -> bytes:
return v if isinstance(v, bytes) else v.encode('utf-8')
def prefix_range_end(v):
v = bytearray(to_bytes(v))
for i in range(len(v) - 1, -1, -1):
if v[i] < 0xff:
v[i] += 1
def prefix_range_end(v: str) -> bytes:
ret = bytearray(to_bytes(v))
for i in range(len(ret) - 1, -1, -1):
if ret[i] < 0xff:
ret[i] += 1
break
return bytes(v)
return bytes(ret)
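
A worked example of prefix_range_end(), which computes the exclusive upper bound an etcd range request needs for a prefix scan by bumping the last byte that is still below 0xff (the key is illustrative):

prefix_range_end('/service/batman/')  # -> b'/service/batman0': '/' is 0x2f, bumped to 0x30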
def base64_encode(v):
def base64_encode(v: Union[str, bytes]) -> str:
return base64.b64encode(to_bytes(v)).decode('utf-8')
def base64_decode(v):
def base64_decode(v: str) -> str:
return base64.b64decode(v).decode('utf-8')
def build_range_request(key, range_end=None):
def build_range_request(key: str, range_end: Union[bytes, str, None] = None) -> Dict[str, Any]:
fields = {'key': base64_encode(key)}
if range_end:
fields['range_end'] = base64_encode(range_end)
return fields
def _handle_auth_errors(func: Callable[..., Any]) -> Any:
def wrapper(self: 'Etcd3Client', *args: Any, **kwargs: Any) -> Any:
def retry(ex: Exception) -> Any:
if self.username and self.password:
self.authenticate()
return func(self, *args, **kwargs)
else:
logger.fatal('Username or password not set, authentication is not possible')
raise ex
try:
return func(self, *args, **kwargs)
except (UserEmpty, PermissionDenied) as e: # no token provided
# PermissionDenied is raised on 3.0 and 3.1
if self._cluster_version < (3, 3) and (not isinstance(e, PermissionDenied)
or self._cluster_version < (3, 2)):
raise UnsupportedEtcdVersion('Authentication is required by Etcd cluster but not '
'supported on version lower than 3.3.0. Cluster version: '
'{0}'.format('.'.join(map(str, self._cluster_version))))
return retry(e)
except InvalidAuthToken as e:
logger.error('Invalid auth token: %s', self._token)
return retry(e)
return wrapper
class Etcd3Client(AbstractEtcdClientWithFailover):
ERROR_CLS = Etcd3Error
def __init__(self, config, dns_resolver, cache_ttl=300):
def __init__(self, config: Dict[str, Any], dns_resolver: DnsCachingResolver, cache_ttl: int = 300) -> None:
self._token = None
self._cluster_version = None
self._cluster_version: Tuple[int, ...] = tuple()
self.version_prefix = '/v3beta'
super(Etcd3Client, self).__init__(config, dns_resolver, cache_ttl)
@@ -182,32 +234,33 @@ class Etcd3Client(AbstractEtcdClientWithFailover):
logger.fatal('Etcd3 authentication failed: %r', e)
sys.exit(1)
def _get_headers(self):
def _get_headers(self) -> Dict[str, str]:
headers = urllib3.make_headers(user_agent=USER_AGENT)
if self._token and self._cluster_version >= (3, 3, 0):
headers['authorization'] = self._token
return headers
def _prepare_request(self, kwargs, params=None, method=None):
def _prepare_request(self, kwargs: Dict[str, Any], params: Optional[Dict[str, Any]] = None,
method: Optional[str] = None) -> Callable[..., urllib3.response.HTTPResponse]:
if params is not None:
kwargs['body'] = json.dumps(params)
kwargs['headers']['Content-Type'] = 'application/json'
return self.http.urlopen
@staticmethod
def _handle_server_response(response):
data = response.data
def _handle_server_response(self, response: urllib3.response.HTTPResponse) -> Dict[str, Any]:
data: Union[bytes, str] = response.data
try:
data = data.decode('utf-8')
data = json.loads(data)
ret: Dict[str, Any] = json.loads(data)
if response.status < 400:
return ret
except (TypeError, ValueError, UnicodeError) as e:
if response.status < 400:
raise etcd.EtcdException('Server response was not valid JSON: %r' % e)
if response.status < 400:
return data
_raise_for_data(data, response.status)
ret = {}
raise _raise_for_data(ret or data, response.status)
def _ensure_version_prefix(self, base_uri, **kwargs):
def _ensure_version_prefix(self, base_uri: str, **kwargs: Any) -> None:
if self.version_prefix != '/v3':
response = self.http.urlopen(self._MGET, base_uri + '/version', **kwargs)
response = self._handle_server_response(response)
@@ -234,87 +287,64 @@ class Etcd3Client(AbstractEtcdClientWithFailover):
else:
self.version_prefix = '/v3'
def _prepare_get_members(self, etcd_nodes):
def _prepare_get_members(self, etcd_nodes: int) -> Dict[str, Any]:
kwargs = self._prepare_common_parameters(etcd_nodes)
self._prepare_request(kwargs, {})
return kwargs
def _get_members(self, base_uri, **kwargs):
def _get_members(self, base_uri: str, **kwargs: Any) -> List[str]:
self._ensure_version_prefix(base_uri, **kwargs)
resp = self.http.urlopen(self._MPOST, base_uri + self.version_prefix + '/cluster/member/list', **kwargs)
members = self._handle_server_response(resp)['members']
return set(url for member in members for url in member.get('clientURLs', []))
return [url for member in members for url in member.get('clientURLs', [])]
def call_rpc(self, method, fields, retry=None):
def call_rpc(self, method: str, fields: Dict[str, Any], retry: Optional[Retry] = None) -> Dict[str, Any]:
fields['retry'] = retry
return self.api_execute(self.version_prefix + method, self._MPOST, fields)
def authenticate(self):
if self._use_proxies and self._cluster_version is None:
def authenticate(self) -> bool:
if self._use_proxies and not self._cluster_version:
kwargs = self._prepare_common_parameters(1)
self._ensure_version_prefix(self._base_uri, **kwargs)
if self._cluster_version >= (3, 3) and self.username and self.password:
logger.info('Trying to authenticate on Etcd...')
old_token, self._token = self._token, None
try:
response = self.call_rpc('/auth/authenticate', {'name': self.username, 'password': self.password})
except AuthNotEnabled:
logger.info('Etcd authentication is not enabled')
self._token = None
except Exception:
self._token = old_token
raise
else:
self._token = response.get('token')
return old_token != self._token
def _handle_auth_errors(func):
def wrapper(self, *args, **kwargs):
def retry(ex):
if self.username and self.password:
self.authenticate()
return func(self, *args, **kwargs)
else:
logger.fatal('Username or password not set, authentication is not possible')
raise ex
try:
return func(self, *args, **kwargs)
except (UserEmpty, PermissionDenied) as e: # no token provided
# PermissionDenied is raised on 3.0 and 3.1
if self._cluster_version < (3, 3) and (not isinstance(e, PermissionDenied)
or self._cluster_version < (3, 2)):
raise UnsupportedEtcdVersion('Authentication is required by Etcd cluster but not '
'supported on version lower than 3.3.0. Cluster version: '
'{0}'.format('.'.join(map(str, self._cluster_version))))
return retry(e)
except InvalidAuthToken as e:
logger.error('Invalid auth token: %s', self._token)
return retry(e)
return wrapper
if not (self._cluster_version >= (3, 3) and self.username and self.password):
return False
logger.info('Trying to authenticate on Etcd...')
old_token, self._token = self._token, None
try:
response = self.call_rpc('/auth/authenticate', {'name': self.username, 'password': self.password})
except AuthNotEnabled:
logger.info('Etcd authentication is not enabled')
self._token = None
except Exception:
self._token = old_token
raise
else:
self._token = response.get('token')
return old_token != self._token
@_handle_auth_errors
def range(self, key, range_end=None, retry=None):
def range(self, key: str, range_end: Union[bytes, str, None] = None,
retry: Optional[Retry] = None) -> Dict[str, Any]:
params = build_range_request(key, range_end)
params['serializable'] = True # For better performance. We can tolerate stale reads.
return self.call_rpc('/kv/range', params, retry)
def prefix(self, key, retry=None):
def prefix(self, key: str, retry: Optional[Retry] = None) -> Dict[str, Any]:
return self.range(key, prefix_range_end(key), retry)
@_handle_auth_errors
def lease_grant(self, ttl, retry=None):
def lease_grant(self, ttl: int, retry: Optional[Retry] = None) -> str:
return self.call_rpc('/lease/grant', {'TTL': ttl}, retry)['ID']
def lease_keepalive(self, ID, retry=None):
def lease_keepalive(self, ID: str, retry: Optional[Retry] = None) -> Optional[str]:
return self.call_rpc('/lease/keepalive', {'ID': ID}, retry).get('result', {}).get('TTL')
def txn(self, compare, success, retry=None):
return self.call_rpc('/kv/txn', {'compare': [compare], 'success': [success]}, retry).get('succeeded')
def txn(self, compare: Dict[str, Any], success: Dict[str, Any], retry: Optional[Retry] = None) -> Dict[str, Any]:
return self.call_rpc('/kv/txn', {'compare': [compare], 'success': [success]}, retry).get('succeeded', {})
@_handle_auth_errors
def put(self, key, value, lease=None, create_revision=None, mod_revision=None, retry=None):
def put(self, key: str, value: str, lease: Optional[str] = None, create_revision: Optional[str] = None,
mod_revision: Optional[str] = None, retry: Optional[Retry] = None) -> Dict[str, Any]:
fields = {'key': base64_encode(key), 'value': base64_encode(value)}
if lease:
fields['lease'] = lease
@@ -328,17 +358,20 @@ class Etcd3Client(AbstractEtcdClientWithFailover):
return self.txn(compare, {'request_put': fields}, retry)
@_handle_auth_errors
def deleterange(self, key, range_end=None, mod_revision=None, retry=None):
def deleterange(self, key: str, range_end: Union[bytes, str, None] = None,
mod_revision: Optional[str] = None, retry: Optional[Retry] = None) -> Dict[str, Any]:
fields = build_range_request(key, range_end)
if mod_revision is None:
return self.call_rpc('/kv/deleterange', fields, retry)
compare = {'target': 'MOD', 'mod_revision': mod_revision, 'key': fields['key']}
return self.txn(compare, {'request_delete_range': fields}, retry)
def deleteprefix(self, key, retry=None):
def deleteprefix(self, key: str, retry: Optional[Retry] = None) -> Dict[str, Any]:
return self.deleterange(key, prefix_range_end(key), retry=retry)
def watchrange(self, key, range_end=None, start_revision=None, filters=None, read_timeout=None):
def watchrange(self, key: str, range_end: Union[bytes, str, None] = None,
start_revision: Optional[str] = None, filters: Optional[List[Dict[str, Any]]] = None,
read_timeout: Optional[float] = None) -> urllib3.response.HTTPResponse:
"""returns: response object"""
params = build_range_request(key, range_end)
if start_revision is not None:
@@ -349,14 +382,16 @@ class Etcd3Client(AbstractEtcdClientWithFailover):
kwargs.update(timeout=urllib3.Timeout(connect=kwargs['timeout'], read=read_timeout), retries=0)
return request_executor(self._MPOST, self._base_uri + self.version_prefix + '/watch', **kwargs)
def watchprefix(self, key, start_revision=None, filters=None, read_timeout=None):
def watchprefix(self, key: str, start_revision: Optional[str] = None,
filters: Optional[List[Dict[str, Any]]] = None,
read_timeout: Optional[float] = None) -> urllib3.response.HTTPResponse:
return self.watchrange(key, prefix_range_end(key), start_revision, filters, read_timeout)
class KVCache(Thread):
def __init__(self, dcs, client):
Thread.__init__(self)
def __init__(self, dcs: 'Etcd3', client: 'PatroniEtcd3Client') -> None:
super(KVCache, self).__init__()
self.daemon = True
self._dcs = dcs
self._client = client
@@ -373,32 +408,32 @@ class KVCache(Thread):
self._object_cache_lock = Lock()
self.start()
def set(self, value, overwrite=False):
def set(self, value: Dict[str, Any], overwrite: bool = False) -> Tuple[bool, Optional[Dict[str, Any]]]:
with self._object_cache_lock:
name = value['key']
old_value = self._object_cache.get(name)
ret = not old_value or int(old_value['mod_revision']) < int(value['mod_revision'])
if ret or overwrite and old_value['mod_revision'] == value['mod_revision']:
if ret or overwrite and old_value and old_value['mod_revision'] == value['mod_revision']:
self._object_cache[name] = value
return ret, old_value
def delete(self, name, mod_revision):
def delete(self, name: str, mod_revision: str) -> Tuple[bool, Optional[Dict[str, Any]]]:
with self._object_cache_lock:
old_value = self._object_cache.get(name)
ret = old_value and int(old_value['mod_revision']) < int(mod_revision)
if ret:
del self._object_cache[name]
return not old_value or ret, old_value
return bool(not old_value or ret), old_value
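# --- A small illustration of the mod_revision guard implemented by set()
# and delete(): an event older than the cached state must never win.
cached = {'key': '/leader', 'mod_revision': '15', 'value': 'node1'}
stale_event = {'key': '/leader', 'mod_revision': '12', 'value': 'node0'}
assert not (int(cached['mod_revision']) < int(stale_event['mod_revision']))  # stale event is ignored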
def copy(self):
def copy(self) -> List[Dict[str, Any]]:
with self._object_cache_lock:
return [v.copy() for v in self._object_cache.values()]
def get(self, name):
def get(self, name: str) -> Optional[Dict[str, Any]]:
with self._object_cache_lock:
return self._object_cache.get(name)
def _process_event(self, event):
def _process_event(self, event: Dict[str, Any]) -> None:
kv = event['kv']
key = kv['key']
if event.get('type') == 'DELETE':
@@ -422,21 +457,22 @@ class KVCache(Thread):
or (self.get(self._leader_key) or {}).get('value') != self._name):
self._dcs.event.set()
def _process_message(self, message):
def _process_message(self, message: Dict[str, Any]) -> None:
logger.debug('Received message: %s', message)
if 'error' in message:
_raise_for_data(message)
for event in message.get('result', {}).get('events', []):
raise _raise_for_data(message)
events: List[Dict[str, Any]] = message.get('result', {}).get('events', [])
for event in events:
self._process_event(event)
@staticmethod
def _finish_response(response):
def _finish_response(response: urllib3.response.HTTPResponse) -> None:
try:
response.close()
finally:
response.release_conn()
def _do_watch(self, revision):
def _do_watch(self, revision: str) -> None:
with self._response_lock:
self._response = None
# We do most requests with timeouts. The only exception is /watch requests to Etcd v3.
@@ -457,7 +493,7 @@ class KVCache(Thread):
for message in iter_response_objects(response):
self._process_message(message)
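# --- A rough sketch of streamed-JSON handling (not Patroni's exact
# iter_response_objects): the /watch endpoint delivers one JSON object per
# message on a never-ending response body, so bytes are buffered until a
# complete object can be decoded.
import json
from typing import Any, Dict, Iterator

def _iter_json_objects(chunks: Iterator[bytes]) -> Iterator[Dict[str, Any]]:
    decoder = json.JSONDecoder()
    buf = ''
    for chunk in chunks:
        buf = (buf + chunk.decode('utf-8')).lstrip()
        while buf:
            try:
                obj, end = decoder.raw_decode(buf)
            except ValueError:
                break  # incomplete object, wait for more data
            buf = buf[end:].lstrip()
            yield obj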
def _build_cache(self):
def _build_cache(self) -> None:
result = self._dcs.retry(self._client.prefix, self._dcs.cluster_prefix)
with self._object_cache_lock:
self._object_cache = {node['key']: node for node in result.get('kvs', [])}
@@ -476,10 +512,10 @@ class KVCache(Thread):
self._is_ready = False
with self._response_lock:
response, self._response = self._response, None
if response:
if isinstance(response, urllib3.response.HTTPResponse):
self._finish_response(response)
def run(self):
def run(self) -> None:
while True:
try:
self._build_cache()
@@ -487,12 +523,12 @@ class KVCache(Thread):
logger.error('KVCache.run %r', e)
time.sleep(1)
def kill_stream(self):
def kill_stream(self) -> None:
sock = None
with self._response_lock:
if self._response:
if isinstance(self._response, urllib3.response.HTTPResponse):
try:
sock = self._response.connection.sock
sock = self._response.connection.sock if self._response.connection else None
except Exception:
sock = None
else:
@@ -504,48 +540,48 @@ class KVCache(Thread):
except Exception as e:
logger.debug('Error on socket.shutdown: %r', e)
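# --- A minimal demonstration of why kill_stream() reaches for
# socket.shutdown(): a recv() blocked on a stream with no read timeout is
# woken up as soon as the socket is shut down from another thread.
import socket
import threading
import time

a, b = socket.socketpair()

def _reader() -> None:
    try:
        b.recv(1)  # blocks, like a /watch read with no read timeout
    except OSError:
        pass  # some platforms raise instead of returning b''

t = threading.Thread(target=_reader)
t.start()
time.sleep(0.1)
b.shutdown(socket.SHUT_RDWR)  # wakes the blocked recv() with EOF
t.join(1)
assert not t.is_alive()
a.close()
b.close()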
def is_ready(self):
def is_ready(self) -> bool:
"""Must be called only when holding the lock on `condition`"""
return self._is_ready
class PatroniEtcd3Client(Etcd3Client):
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
self._kv_cache = None
super(PatroniEtcd3Client, self).__init__(*args, **kwargs)
def configure(self, etcd3):
def configure(self, etcd3: 'Etcd3') -> None:
self._etcd3 = etcd3
def start_watcher(self):
def start_watcher(self) -> None:
if self._cluster_version >= (3, 1):
self._kv_cache = KVCache(self._etcd3, self)
def _restart_watcher(self):
def _restart_watcher(self) -> None:
if self._kv_cache:
self._kv_cache.kill_stream()
def set_base_uri(self, value):
def set_base_uri(self, value: str) -> None:
super(PatroniEtcd3Client, self).set_base_uri(value)
self._restart_watcher()
def authenticate(self):
def authenticate(self) -> bool:
ret = super(PatroniEtcd3Client, self).authenticate()
if ret:
self._restart_watcher()
return ret
def _wait_cache(self, timeout):
def _wait_cache(self, timeout: float) -> None:
stop_time = time.time() + timeout
while not self._kv_cache.is_ready():
while self._kv_cache and not self._kv_cache.is_ready():
timeout = stop_time - time.time()
if timeout <= 0:
raise RetryFailedError('Exceeded retry deadline')
self._kv_cache.condition.wait(timeout)
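# --- The generic pattern behind _wait_cache(): wait on a condition
# variable, but never past an absolute deadline. A sketch, assuming the
# caller already holds the condition's lock (as get_cluster() does):
import threading
import time
from typing import Callable

def _wait_until(predicate: Callable[[], bool], condition: threading.Condition, deadline: float) -> None:
    stop_time = time.time() + deadline
    while not predicate():
        timeout = stop_time - time.time()
        if timeout <= 0:
            raise TimeoutError('Exceeded retry deadline')
        condition.wait(timeout)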
def get_cluster(self, path):
if self._kv_cache and path.startswith(self._etcd3.cluster_prefix):
def get_cluster(self, path: str) -> List[Dict[str, Any]]:
if self._kv_cache and self._etcd3._retry.deadline is not None and path.startswith(self._etcd3.cluster_prefix):
with self._kv_cache.condition:
self._wait_cache(self._etcd3._retry.deadline)
ret = self._kv_cache.copy()
@@ -557,7 +593,7 @@ class PatroniEtcd3Client(Etcd3Client):
'lease': node.get('lease')})
return ret
def call_rpc(self, method, fields, retry=None):
def call_rpc(self, method: str, fields: Dict[str, Any], retry: Optional[Retry] = None) -> Dict[str, Any]:
ret = super(PatroniEtcd3Client, self).call_rpc(method, fields, retry)
if self._kv_cache:
@@ -582,8 +618,9 @@ class PatroniEtcd3Client(Etcd3Client):
class Etcd3(AbstractEtcd):
def __init__(self, config):
def __init__(self, config: Dict[str, Any]) -> None:
super(Etcd3, self).__init__(config, PatroniEtcd3Client, (DeadlineExceeded, Unavailable, FailedPrecondition))
assert isinstance(self._client, PatroniEtcd3Client)
self.__do_not_watch = False
self._lease = None
self._last_lease_refresh = 0
@@ -593,15 +630,23 @@ class Etcd3(AbstractEtcd):
self._client.start_watcher()
self.create_lease()
def set_socket_options(self, sock, socket_options):
@property
def _client(self) -> PatroniEtcd3Client:
assert isinstance(self._abstract_client, PatroniEtcd3Client)
return self._abstract_client
def set_socket_options(self, sock: socket.socket,
socket_options: Optional[Collection[Tuple[int, int, int]]]) -> None:
assert self._retry.deadline is not None
enable_keepalive(sock, self.ttl, int(self.loop_wait + self._retry.deadline))
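# --- A rough illustration of what enabling TCP keepalive boils down to on
# Linux; Patroni's enable_keepalive() helper also covers other platforms,
# and the parameter names here are assumptions:
import socket
import sys

def _enable_keepalive_sketch(sock: socket.socket, idle: int, interval: int) -> None:
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
    if sys.platform.startswith('linux'):
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, idle)       # first probe after `idle` seconds
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, interval)  # spacing between probes
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 3)           # probes before the peer is declared dead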
def set_ttl(self, ttl):
def set_ttl(self, ttl: int) -> Optional[bool]:
self.__do_not_watch = super(Etcd3, self).set_ttl(ttl)
if self.__do_not_watch:
self._lease = None
return None
def _do_refresh_lease(self, force=False, retry=None):
def _do_refresh_lease(self, force: bool = False, retry: Optional[Retry] = None) -> bool:
if not force and self._lease and self._last_lease_refresh + self._loop_wait > time.time():
return False
@@ -615,14 +660,14 @@ class Etcd3(AbstractEtcd):
self._last_lease_refresh = time.time()
return ret
def refresh_lease(self):
def refresh_lease(self) -> bool:
try:
return self.retry(self._do_refresh_lease)
except (Etcd3ClientError, RetryFailedError):
logger.exception('refresh_lease')
raise Etcd3Error('Failed to keepalive/grant lease')
def create_lease(self):
def create_lease(self) -> None:
while not self._lease:
try:
self.refresh_lease()
@@ -631,14 +676,14 @@ class Etcd3(AbstractEtcd):
time.sleep(5)
@property
def cluster_prefix(self):
def cluster_prefix(self) -> str:
return self._base_path + '/' if self.is_citus_coordinator() else self.client_path('')
@staticmethod
def member(node):
def member(node: Dict[str, str]) -> Member:
return Member.from_node(node['mod_revision'], os.path.basename(node['key']), node['lease'], node['value'])
def _cluster_from_nodes(self, nodes):
def _cluster_from_nodes(self, nodes: Dict[str, Any]) -> Cluster:
# get initialize flag
initialize = nodes.get(self._INITIALIZE)
initialize = initialize and initialize['value']
@@ -666,7 +711,7 @@ class Etcd3(AbstractEtcd):
slots = None
try:
last_lsn = int(last_lsn)
last_lsn = int(last_lsn or '')
except Exception:
last_lsn = 0
@@ -701,14 +746,14 @@ class Etcd3(AbstractEtcd):
return Cluster(initialize, config, leader, last_lsn, members, failover, sync, history, slots, failsafe)
def _cluster_loader(self, path):
def _cluster_loader(self, path: str) -> Cluster:
nodes = {node['key'][len(path):]: node
for node in self._client.get_cluster(path)
if node['key'].startswith(path)}
return self._cluster_from_nodes(nodes)
def _citus_cluster_loader(self, path):
clusters = defaultdict(dict)
def _citus_cluster_loader(self, path: str) -> Dict[int, Cluster]:
clusters: Dict[int, Dict[str, Dict[str, Any]]] = defaultdict(dict)
path = self._base_path + '/'
for node in self._client.get_cluster(path):
key = node['key'][len(path):].split('/', 1)
@@ -716,7 +761,9 @@ class Etcd3(AbstractEtcd):
clusters[int(key[0])][key[1]] = node
return {group: self._cluster_from_nodes(nodes) for group, nodes in clusters.items()}
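# --- An illustration of the grouping step: flat keys below the Citus base
# path look like '<group>/<name>' and are bucketed by integer group id
# (citus_group_re is approximated with str.isdigit() here; paths made up):
from collections import defaultdict

path = '/service/batman5/'
keys = ['/service/batman5/0/leader', '/service/batman5/0/members/node1',
        '/service/batman5/1/leader']
groups = defaultdict(dict)
for full_key in keys:
    key = full_key[len(path):].split('/', 1)
    if len(key) == 2 and key[0].isdigit():
        groups[int(key[0])][key[1]] = {'key': full_key}
assert sorted(groups) == [0, 1]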
def _load_cluster(self, path, loader):
def _load_cluster(
self, path: str, loader: Callable[[str], Union[Cluster, Dict[int, Cluster]]]
) -> Union[Cluster, Dict[int, Cluster]]:
cluster = None
try:
cluster = loader(path)
@@ -725,10 +772,11 @@ class Etcd3(AbstractEtcd):
except Exception as e:
self._handle_exception(e, 'get_cluster', raise_ex=Etcd3Error('Etcd is not responding properly'))
self._has_failed = False
assert cluster is not None
return cluster
@catch_etcd_errors
def touch_member(self, data):
def touch_member(self, data: Dict[str, Any]) -> bool:
try:
self.refresh_lease()
except Etcd3Error:
@@ -740,19 +788,20 @@ class Etcd3(AbstractEtcd):
if member and member.session == self._lease and deep_compare(data, member.data):
return True
data = json.dumps(data, separators=(',', ':'))
value = json.dumps(data, separators=(',', ':'))
try:
return self._client.put(self.member_path, data, self._lease)
return bool(self._client.put(self.member_path, value, self._lease))
except LeaseNotFound:
self._lease = None
logger.error('Our lease disappeared from Etcd, can not "touch_member"')
return False
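# --- A hypothetical member payload as serialized by touch_member(); the
# exact fields vary with configuration, but the compact separators keep the
# value small, and deep_compare() above skips rewriting an unchanged key.
import json

data = {'conn_url': 'postgres://10.0.0.1:5432/postgres',
        'api_url': 'http://10.0.0.1:8008/patroni',
        'state': 'running', 'role': 'replica',
        'version': '3.0.2', 'timeline': 4}
value = json.dumps(data, separators=(',', ':'))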
@catch_etcd_errors
def take_leader(self):
def take_leader(self) -> bool:
return self.retry(self._client.put, self.leader_path, self._name, self._lease)
def _do_attempt_to_acquire_leader(self, retry):
def _retry(*args, **kwargs):
def _do_attempt_to_acquire_leader(self, retry: Retry) -> bool:
def _retry(*args: Any, **kwargs: Any) -> Any:
kwargs['retry'] = retry
return retry(*args, **kwargs)
@@ -772,10 +821,10 @@ class Etcd3(AbstractEtcd):
return _retry(self._client.put, self.leader_path, self._name, self._lease, 0)
@catch_return_false_exception
def attempt_to_acquire_leader(self):
def attempt_to_acquire_leader(self) -> bool:
retry = self._retry.copy()
def _retry(*args, **kwargs):
def _retry(*args: Any, **kwargs: Any) -> Any:
kwargs['retry'] = retry
return retry(*args, **kwargs)
@@ -791,30 +840,30 @@ class Etcd3(AbstractEtcd):
return ret
@catch_etcd_errors
def set_failover_value(self, value, index=None):
return self._client.put(self.failover_path, value, mod_revision=index)
def set_failover_value(self, value: str, index: Optional[str] = None) -> bool:
return bool(self._client.put(self.failover_path, value, mod_revision=index))
@catch_etcd_errors
def set_config_value(self, value, index=None):
return self._client.put(self.config_path, value, mod_revision=index)
def set_config_value(self, value: str, index: Optional[str] = None) -> bool:
return bool(self._client.put(self.config_path, value, mod_revision=index))
@catch_etcd_errors
def _write_leader_optime(self, last_lsn):
return self._client.put(self.leader_optime_path, last_lsn)
def _write_leader_optime(self, last_lsn: str) -> bool:
return bool(self._client.put(self.leader_optime_path, last_lsn))
@catch_etcd_errors
def _write_status(self, value):
return self._client.put(self.status_path, value)
def _write_status(self, value: str) -> bool:
return bool(self._client.put(self.status_path, value))
@catch_etcd_errors
def _write_failsafe(self, value):
return self._client.put(self.failsafe_path, value)
def _write_failsafe(self, value: str) -> bool:
return bool(self._client.put(self.failsafe_path, value))
@catch_return_false_exception
def _update_leader(self):
def _update_leader(self) -> bool:
retry = self._retry.copy()
def _retry(*args, **kwargs):
def _retry(*args: Any, **kwargs: Any) -> Any:
kwargs['retry'] = retry
return retry(*args, **kwargs)
@@ -836,36 +885,37 @@ class Etcd3(AbstractEtcd):
return bool(self._lease)
@catch_etcd_errors
def initialize(self, create_new=True, sysid=""):
def initialize(self, create_new: bool = True, sysid: str = ""):
return self.retry(self._client.put, self.initialize_path, sysid, None, 0 if create_new else None)
@catch_etcd_errors
def _delete_leader(self):
def _delete_leader(self) -> bool:
cluster = self.cluster
if cluster and isinstance(cluster.leader, Leader) and cluster.leader.name == self._name:
return self._client.deleterange(self.leader_path, mod_revision=cluster.leader.index)
return True
@catch_etcd_errors
def cancel_initialization(self):
def cancel_initialization(self) -> bool:
return self.retry(self._client.deleterange, self.initialize_path)
@catch_etcd_errors
def delete_cluster(self):
def delete_cluster(self) -> bool:
return self.retry(self._client.deleteprefix, self.client_path(''))
@catch_etcd_errors
def set_history_value(self, value):
return self._client.put(self.history_path, value)
def set_history_value(self, value: str) -> bool:
return bool(self._client.put(self.history_path, value))
@catch_etcd_errors
def set_sync_state_value(self, value, index=None):
def set_sync_state_value(self, value: str, index: Optional[str] = None) -> bool:
return self.retry(self._client.put, self.sync_path, value, mod_revision=index)
@catch_etcd_errors
def delete_sync_state(self, index=None):
def delete_sync_state(self, index: Optional[str] = None) -> bool:
return self.retry(self._client.deleterange, self.sync_path, mod_revision=index)
def watch(self, leader_index, timeout):
def watch(self, leader_index: Optional[str], timeout: float) -> bool:
if self.__do_not_watch:
self.__do_not_watch = False
return True

View File

@@ -3,9 +3,12 @@ import logging
import random
import time
from patroni.dcs.zookeeper import ZooKeeper
from patroni.request import get as requests_get
from patroni.utils import uri
from typing import Any, Callable, Dict, List, Union
from . import Cluster
from .zookeeper import ZooKeeper
from ..request import get as requests_get
from ..utils import uri
logger = logging.getLogger(__name__)
@@ -14,11 +17,12 @@ class ExhibitorEnsembleProvider(object):
TIMEOUT = 3.1
def __init__(self, hosts, port, uri_path='/exhibitor/v1/cluster/list', poll_interval=300):
def __init__(self, hosts: List[str], port: int,
uri_path: str = '/exhibitor/v1/cluster/list', poll_interval: int = 300) -> None:
self._exhibitor_port = port
self._uri_path = uri_path
self._poll_interval = poll_interval
self._exhibitors = hosts
self._exhibitors: List[str] = hosts
self._boot_exhibitors = hosts
self._zookeeper_hosts = ''
self._next_poll = None
@@ -26,7 +30,7 @@ class ExhibitorEnsembleProvider(object):
logger.info('waiting on exhibitor')
time.sleep(5)
def poll(self):
def poll(self) -> bool:
if self._next_poll and self._next_poll > time.time():
return False
@@ -36,7 +40,8 @@ class ExhibitorEnsembleProvider(object):
if isinstance(json, dict) and 'servers' in json and 'port' in json:
self._next_poll = time.time() + self._poll_interval
zookeeper_hosts = ','.join([h + ':' + str(json['port']) for h in sorted(json['servers'])])
servers: List[str] = json['servers']
zookeeper_hosts = ','.join([h + ':' + str(json['port']) for h in sorted(servers)])
if self._zookeeper_hosts != zookeeper_hosts:
logger.info('ZooKeeper connection string has changed: %s => %s', self._zookeeper_hosts, zookeeper_hosts)
self._zookeeper_hosts = zookeeper_hosts
@@ -44,7 +49,7 @@ class ExhibitorEnsembleProvider(object):
return True
return False
def _query_exhibitors(self, exhibitors):
def _query_exhibitors(self, exhibitors: List[str]) -> Union[Dict[str, Any], Any]:
random.shuffle(exhibitors)
for host in exhibitors:
try:
@@ -55,18 +60,20 @@ class ExhibitorEnsembleProvider(object):
return None
@property
def zookeeper_hosts(self):
def zookeeper_hosts(self) -> str:
return self._zookeeper_hosts
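# --- A hypothetical stand-alone version of the polling step: query each
# Exhibitor host for the cluster list and derive a ZooKeeper connection
# string in the same 'host:port,host:port' shape as above.
import json
import random
from typing import List, Optional
from urllib.request import urlopen

def _fetch_zookeeper_hosts(exhibitors: List[str], port: int, timeout: float = 3.1) -> Optional[str]:
    hosts = list(exhibitors)
    random.shuffle(hosts)
    for host in hosts:
        try:
            with urlopen('http://{0}:{1}/exhibitor/v1/cluster/list'.format(host, port),
                         timeout=timeout) as resp:
                data = json.loads(resp.read().decode('utf-8'))
        except Exception:
            continue
        if isinstance(data, dict) and 'servers' in data and 'port' in data:
            return ','.join('{0}:{1}'.format(h, data['port']) for h in sorted(data['servers']))
    return None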
class Exhibitor(ZooKeeper):
def __init__(self, config):
def __init__(self, config: Dict[str, Any]) -> None:
interval = config.get('poll_interval', 300)
self._ensemble_provider = ExhibitorEnsembleProvider(config['hosts'], config['port'], poll_interval=interval)
super(Exhibitor, self).__init__({**config, 'hosts': self._ensemble_provider.zookeeper_hosts})
def _load_cluster(self, path, loader):
def _load_cluster(
self, path: str, loader: Callable[[str], Union[Cluster, Dict[int, Cluster]]]
) -> Union[Cluster, Dict[int, Cluster]]:
if self._ensemble_provider.poll():
self._client.set_hosts(self._ensemble_provider.zookeeper_hosts)
return super(Exhibitor, self)._load_cluster(path, loader)

File diff suppressed because it is too large

View File

@@ -10,10 +10,13 @@ from pysyncobj.dns_resolver import globalDnsResolver
from pysyncobj.node import TCPNode
from pysyncobj.transport import TCPTransport, CONNECTION_STATE
from pysyncobj.utility import TcpUtility
from typing import Any, Callable, Collection, Dict, List, Optional, Union, TYPE_CHECKING
from . import AbstractDCS, ClusterConfig, Cluster, Failover, Leader, Member, SyncState, TimelineHistory, citus_group_re
from ..exceptions import DCSError
from ..utils import validate_directory
if TYPE_CHECKING: # pragma: no cover
from ..config import Config
logger = logging.getLogger(__name__)
@@ -24,11 +27,12 @@ class RaftError(DCSError):
class _TCPTransport(TCPTransport):
def __init__(self, syncObj, selfNode, otherNodes):
def __init__(self, syncObj: 'DynMemberSyncObj', selfNode: Optional[TCPNode],
otherNodes: Collection[TCPNode]) -> None:
super(_TCPTransport, self).__init__(syncObj, selfNode, otherNodes)
self.setOnUtilityMessageCallback('members', syncObj.getMembers)
def _connectIfNecessarySingle(self, node):
def _connectIfNecessarySingle(self, node: TCPNode) -> bool:
try:
return super(_TCPTransport, self)._connectIfNecessarySingle(node)
except Exception as e:
@@ -36,7 +40,7 @@ class _TCPTransport(TCPTransport):
return False
def resolve_host(self):
def resolve_host(self: TCPNode) -> Optional[str]:
return globalDnsResolver().resolve(self.host)
@@ -45,17 +49,19 @@ setattr(TCPNode, 'ip', property(resolve_host))
class SyncObjUtility(object):
def __init__(self, otherNodes, conf, retry_timeout=10):
def __init__(self, otherNodes: Collection[Union[str, TCPNode]], conf: SyncObjConf, retry_timeout: int = 10) -> None:
self._nodes = otherNodes
self._utility = TcpUtility(conf.password, retry_timeout / max(1, len(otherNodes)))
self.__node = next(iter(otherNodes), None)
def executeCommand(self, command):
def executeCommand(self, command: List[Any]) -> Any:
try:
return self._utility.executeCommand(self.__node, command)
if self.__node:
return self._utility.executeCommand(self.__node, command)
except Exception:
return None
def getMembers(self):
def getMembers(self) -> Optional[List[str]]:
for self.__node in self._nodes:
response = self.executeCommand(['members'])
if response:
@@ -64,7 +70,8 @@ class SyncObjUtility(object):
class DynMemberSyncObj(SyncObj):
def __init__(self, selfAddress, partnerAddrs, conf, retry_timeout=10):
def __init__(self, selfAddress: Optional[str], partnerAddrs: Collection[str],
conf: SyncObjConf, retry_timeout: int = 10) -> None:
self.__early_apply_local_log = selfAddress is not None
self.applied_local_log = False
@@ -81,12 +88,12 @@ class DynMemberSyncObj(SyncObj):
thread.daemon = True
thread.start()
def getMembers(self, args, callback):
def getMembers(self, args: Any, callback: Callable[[Any, Any], Any]) -> None:
callback([{'addr': node.id, 'leader': node == self._getLeader(), 'status': CONNECTION_STATE.CONNECTED
if self.isNodeConnected(node) else CONNECTION_STATE.DISCONNECTED} for node in self.otherNodes]
+ [{'addr': self.selfNode.id, 'leader': self._isLeader(), 'status': CONNECTION_STATE.CONNECTED}], None)
def _onTick(self, timeToWait=0.0):
def _onTick(self, timeToWait: float = 0.0):
super(DynMemberSyncObj, self)._onTick(timeToWait)
# The SyncObj calls the onReady callback only when the cluster has a leader and is ready for writes.
@@ -98,11 +105,12 @@ class DynMemberSyncObj(SyncObj):
class KVStoreTTL(DynMemberSyncObj):
def __init__(self, on_ready, on_set, on_delete, **config):
def __init__(self, on_ready: Optional[Callable[..., Any]], on_set: Optional[Callable[[str, Dict[str, Any]], None]],
on_delete: Optional[Callable[[str], None]], **config: Any) -> None:
self.__thread = None
self.__on_set = on_set
self.__on_delete = on_delete
self.__limb = {}
self.__limb: Dict[str, Dict[str, Any]] = {}
self.set_retry_timeout(int(config.get('retry_timeout') or 10))
self_addr = config.get('self_addr')
@@ -128,22 +136,22 @@ class KVStoreTTL(DynMemberSyncObj):
onReady=on_ready, dynamicMembershipChange=True)
super(KVStoreTTL, self).__init__(self_addr, partner_addrs, conf, self.__retry_timeout)
self.__data = {}
self.__data: Dict[str, Dict[str, Any]] = {}
@staticmethod
def __check_requirements(old_value, **kwargs):
return ('prevExist' not in kwargs or bool(kwargs['prevExist']) == bool(old_value)) and \
('prevValue' not in kwargs or old_value and old_value['value'] == kwargs['prevValue']) and \
(not kwargs.get('prevIndex') or old_value and old_value['index'] == kwargs['prevIndex'])
def __check_requirements(old_value: Dict[str, Any], **kwargs: Any) -> bool:
return bool(('prevExist' not in kwargs or bool(kwargs['prevExist']) == bool(old_value))
and ('prevValue' not in kwargs or old_value and old_value['value'] == kwargs['prevValue'])
and (not kwargs.get('prevIndex') or old_value and old_value['index'] == kwargs['prevIndex']))
def set_retry_timeout(self, retry_timeout):
def set_retry_timeout(self, retry_timeout: int) -> None:
self.__retry_timeout = retry_timeout
def retry(self, func, *args, **kwargs):
def retry(self, func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
event = threading.Event()
ret = {'result': None, 'error': -1}
def callback(result, error):
def callback(result: Any, error: Any) -> None:
ret.update(result=result, error=error)
event.set()
@@ -167,7 +175,7 @@ class KVStoreTTL(DynMemberSyncObj):
return False
@replicated
def _set(self, key, value, **kwargs):
def _set(self, key: str, value: Dict[str, Any], **kwargs: Any) -> bool:
old_value = self.__data.get(key, {})
if not self.__check_requirements(old_value, **kwargs):
return False
@@ -181,29 +189,30 @@ class KVStoreTTL(DynMemberSyncObj):
self.__on_set(key, value)
return True
def set(self, key, value, ttl=None, handle_raft_error=True, **kwargs):
def set(self, key: str, value: str, ttl: Optional[int] = None,
handle_raft_error: bool = True, **kwargs: Any) -> bool:
old_value = self.__data.get(key, {})
if not self.__check_requirements(old_value, **kwargs):
return False
value = {'value': value, 'updated': time.time()}
value['created'] = old_value.get('created', value['updated'])
data: Dict[str, Any] = {'value': value, 'updated': time.time()}
data['created'] = old_value.get('created', data['updated'])
if ttl:
value['expire'] = value['updated'] + ttl
data['expire'] = data['updated'] + ttl
try:
return self.retry(self._set, key, value, **kwargs)
return self.retry(self._set, key, data, **kwargs)
except RaftError:
if not handle_raft_error:
raise
return False
def __pop(self, key):
def __pop(self, key: str) -> None:
self.__data.pop(key)
if self.__on_delete:
self.__on_delete(key)
@replicated
def _delete(self, key, recursive=False, **kwargs):
def _delete(self, key: str, recursive: bool = False, **kwargs: Any) -> bool:
if recursive:
for k in list(self.__data.keys()):
if k.startswith(key):
@@ -214,7 +223,7 @@ class KVStoreTTL(DynMemberSyncObj):
self.__pop(key)
return True
def delete(self, key, recursive=False, **kwargs):
def delete(self, key: str, recursive: bool = False, **kwargs: Any) -> bool:
if not recursive and not self.__check_requirements(self.__data.get(key, {}), **kwargs):
return False
try:
@@ -223,32 +232,32 @@ class KVStoreTTL(DynMemberSyncObj):
return False
@staticmethod
def __values_match(old, new):
def __values_match(old: Dict[str, Any], new: Dict[str, Any]) -> bool:
return all(old.get(n) == new.get(n) for n in ('created', 'updated', 'expire', 'value'))
@replicated
def _expire(self, key, value, callback=None):
def _expire(self, key: str, value: Dict[str, Any], callback: Optional[Callable[..., Any]] = None) -> None:
current = self.__data.get(key)
if current and self.__values_match(current, value):
self.__pop(key)
def __expire_keys(self):
def __expire_keys(self) -> None:
for key, value in self.__data.items():
if value and 'expire' in value and value['expire'] <= time.time() and \
not (key in self.__limb and self.__values_match(self.__limb[key], value)):
self.__limb[key] = value
def callback(*args):
def callback(*args: Any) -> None:
if key in self.__limb and self.__values_match(self.__limb[key], value):
self.__limb.pop(key)
self._expire(key, value, callback=callback)
def get(self, key, recursive=False):
def get(self, key: str, recursive: bool = False) -> Union[None, Dict[str, Any], Dict[str, Dict[str, Any]]]:
if not recursive:
return self.__data.get(key)
return {k: v for k, v in self.__data.items() if k.startswith(key)}
def _onTick(self, timeToWait=0.0):
def _onTick(self, timeToWait: float = 0.0) -> None:
super(KVStoreTTL, self)._onTick(timeToWait)
if self._isLeader():
@@ -256,17 +265,17 @@ class KVStoreTTL(DynMemberSyncObj):
else:
self.__limb.clear()
def _autoTickThread(self):
def _autoTickThread(self) -> None:
self.__destroying = False
while not self.__destroying:
self.doTick(self.conf.autoTickPeriod)
def startAutoTick(self):
def startAutoTick(self) -> None:
self.__thread = threading.Thread(target=self._autoTickThread)
self.__thread.daemon = True
self.__thread.start()
def destroy(self):
def destroy(self) -> None:
if self.__thread:
self.__destroying = True
self.__thread.join()
@@ -275,7 +284,7 @@ class KVStoreTTL(DynMemberSyncObj):
class Raft(AbstractDCS):
def __init__(self, config):
def __init__(self, config: Dict[str, Any]) -> None:
super(Raft, self).__init__(config)
self._ttl = int(config.get('ttl') or 30)
@@ -290,7 +299,7 @@ class Raft(AbstractDCS):
else:
logger.info('waiting on raft')
def _on_set(self, key, value):
def _on_set(self, key: str, value: Dict[str, Any]) -> None:
leader = (self._sync_obj.get(self.leader_path) or {}).get('value')
if key == value['created'] == value['updated'] and \
(key.startswith(self.members_path) or key == self.leader_path and leader != self._name) or \
@@ -298,29 +307,29 @@ class Raft(AbstractDCS):
key in (self.config_path, self.sync_path):
self.event.set()
def _on_delete(self, key):
def _on_delete(self, key: str) -> None:
if key == self.leader_path:
self.event.set()
def set_ttl(self, ttl):
def set_ttl(self, ttl: int) -> Optional[bool]:
self._ttl = ttl
@property
def ttl(self):
def ttl(self) -> int:
return self._ttl
def set_retry_timeout(self, retry_timeout):
def set_retry_timeout(self, retry_timeout: int) -> None:
self._sync_obj.set_retry_timeout(retry_timeout)
def reload_config(self, config):
def reload_config(self, config: Union['Config', Dict[str, Any]]) -> None:
super(Raft, self).reload_config(config)
globalDnsResolver().setTimeouts(self.ttl, self.loop_wait)
@staticmethod
def member(key, value):
def member(key: str, value: Dict[str, Any]) -> Member:
return Member.from_node(value['index'], os.path.basename(key), None, value['value'])
def _cluster_from_nodes(self, nodes):
def _cluster_from_nodes(self, nodes: Dict[str, Any]) -> Cluster:
# get initialize flag
initialize = nodes.get(self._INITIALIZE)
initialize = initialize and initialize['value']
@@ -348,7 +357,7 @@ class Raft(AbstractDCS):
slots = None
try:
last_lsn = int(last_lsn)
last_lsn = int(last_lsn or '')
except Exception:
last_lsn = 0
@@ -380,79 +389,81 @@ class Raft(AbstractDCS):
return Cluster(initialize, config, leader, last_lsn, members, failover, sync, history, slots, failsafe)
def _cluster_loader(self, path):
def _cluster_loader(self, path: str) -> Cluster:
response = self._sync_obj.get(path, recursive=True)
if not response:
return Cluster.empty()
nodes = {key[len(path):]: value for key, value in response.items()}
return self._cluster_from_nodes(nodes)
def _citus_cluster_loader(self, path):
clusters = defaultdict(dict)
def _citus_cluster_loader(self, path: str) -> Dict[int, Cluster]:
clusters: Dict[int, Dict[str, Any]] = defaultdict(dict)
response = self._sync_obj.get(path, recursive=True)
for key, value in response.items():
for key, value in (response or {}).items():
key = key[len(path):].split('/', 1)
if len(key) == 2 and citus_group_re.match(key[0]):
clusters[int(key[0])][key[1]] = value
return {group: self._cluster_from_nodes(nodes) for group, nodes in clusters.items()}
def _load_cluster(self, path, loader):
def _load_cluster(
self, path: str, loader: Callable[[str], Union[Cluster, Dict[int, Cluster]]]
) -> Union[Cluster, Dict[int, Cluster]]:
return loader(path)
def _write_leader_optime(self, last_lsn):
def _write_leader_optime(self, last_lsn: str) -> bool:
return self._sync_obj.set(self.leader_optime_path, last_lsn, timeout=1)
def _write_status(self, value):
def _write_status(self, value: str) -> bool:
return self._sync_obj.set(self.status_path, value, timeout=1)
def _write_failsafe(self, value):
def _write_failsafe(self, value: str) -> bool:
return self._sync_obj.set(self.failsafe_path, value, timeout=1)
def _update_leader(self):
def _update_leader(self) -> bool:
ret = self._sync_obj.set(self.leader_path, self._name, ttl=self._ttl,
handle_raft_error=False, prevValue=self._name)
if not ret and self._sync_obj.get(self.leader_path) is None:
ret = self.attempt_to_acquire_leader()
return ret
def attempt_to_acquire_leader(self):
def attempt_to_acquire_leader(self) -> bool:
return self._sync_obj.set(self.leader_path, self._name, ttl=self._ttl, handle_raft_error=False, prevExist=False)
def set_failover_value(self, value, index=None):
def set_failover_value(self, value: str, index: Optional[int] = None) -> bool:
return self._sync_obj.set(self.failover_path, value, prevIndex=index)
def set_config_value(self, value, index=None):
def set_config_value(self, value: str, index: Optional[int] = None) -> bool:
return self._sync_obj.set(self.config_path, value, prevIndex=index)
def touch_member(self, data):
data = json.dumps(data, separators=(',', ':'))
return self._sync_obj.set(self.member_path, data, self._ttl, timeout=2)
def touch_member(self, data: Dict[str, Any]) -> bool:
value = json.dumps(data, separators=(',', ':'))
return self._sync_obj.set(self.member_path, value, self._ttl, timeout=2)
def take_leader(self):
def take_leader(self) -> bool:
return self._sync_obj.set(self.leader_path, self._name, ttl=self._ttl)
def initialize(self, create_new=True, sysid=''):
def initialize(self, create_new: bool = True, sysid: str = '') -> bool:
return self._sync_obj.set(self.initialize_path, sysid, prevExist=(not create_new))
def _delete_leader(self):
def _delete_leader(self) -> bool:
return self._sync_obj.delete(self.leader_path, prevValue=self._name, timeout=1)
def cancel_initialization(self):
def cancel_initialization(self) -> bool:
return self._sync_obj.delete(self.initialize_path)
def delete_cluster(self):
def delete_cluster(self) -> bool:
return self._sync_obj.delete(self.client_path(''), recursive=True)
def set_history_value(self, value):
def set_history_value(self, value: str) -> bool:
return self._sync_obj.set(self.history_path, value)
def set_sync_state_value(self, value, index=None):
def set_sync_state_value(self, value: str, index: Optional[int] = None) -> bool:
return self._sync_obj.set(self.sync_path, value, prevIndex=index)
def delete_sync_state(self, index=None):
def delete_sync_state(self, index: Optional[int] = None) -> bool:
return self._sync_obj.delete(self.sync_path, prevIndex=index)
def watch(self, leader_index, timeout):
def watch(self, leader_index: Optional[int], timeout: float) -> bool:
try:
return super(Raft, self).watch(leader_index, timeout)
finally:

View File

@@ -1,18 +1,22 @@
import json
import logging
import select
import socket
import time
from kazoo.client import KazooClient, KazooState, KazooRetry
from kazoo.exceptions import ConnectionClosedError, NoNodeError, NodeExistsError, SessionExpiredError
from kazoo.handlers.threading import SequentialThreadingHandler
from kazoo.protocol.states import KeeperState
from kazoo.handlers.threading import AsyncResult, SequentialThreadingHandler
from kazoo.protocol.states import KeeperState, WatchedEvent, ZnodeStat
from kazoo.retry import RetryFailedError
from kazoo.security import make_acl
from kazoo.security import ACL, make_acl
from typing import Any, Callable, Dict, List, Optional, Union, Tuple, TYPE_CHECKING
from . import AbstractDCS, ClusterConfig, Cluster, Failover, Leader, Member, SyncState, TimelineHistory, citus_group_re
from ..exceptions import DCSError
from ..utils import deep_compare
if TYPE_CHECKING: # pragma: no cover
from ..config import Config
logger = logging.getLogger(__name__)
@@ -23,14 +27,14 @@ class ZooKeeperError(DCSError):
class PatroniSequentialThreadingHandler(SequentialThreadingHandler):
def __init__(self, connect_timeout):
def __init__(self, connect_timeout: Union[int, float]) -> None:
super(PatroniSequentialThreadingHandler, self).__init__()
self.set_connect_timeout(connect_timeout)
def set_connect_timeout(self, connect_timeout):
def set_connect_timeout(self, connect_timeout: Union[int, float]) -> None:
self._connect_timeout = max(1.0, connect_timeout / 2.0) # try to connect to zookeeper node during loop_wait/2
def create_connection(self, *args, **kwargs):
def create_connection(self, *args: Any, **kwargs: Any) -> socket.socket:
"""This method is trying to establish connection with one of the zookeeper nodes.
Somehow strategy "fail earlier and retry more often" works way better comparing to
the original strategy "try to connect with specified timeout".
@@ -41,16 +45,16 @@ class PatroniSequentialThreadingHandler(SequentialThreadingHandler):
:param args: always contains `tuple(host, port)` as the first element and could contain
`connect_timeout` (negotiated session timeout) as the second element."""
args = list(args)
if len(args) == 0: # kazoo 2.6.0 slightly changed the way how it calls create_connection method
args_list: List[Any] = list(args)
if len(args_list) == 0: # kazoo 2.6.0 slightly changed the way how it calls create_connection method
kwargs['timeout'] = max(self._connect_timeout, kwargs.get('timeout', self._connect_timeout * 10) / 10.0)
elif len(args) == 1:
args.append(self._connect_timeout)
elif len(args_list) == 1:
args_list.append(self._connect_timeout)
else:
args[1] = max(self._connect_timeout, args[1] / 10.0)
return super(PatroniSequentialThreadingHandler, self).create_connection(*args, **kwargs)
args_list[1] = max(self._connect_timeout, args_list[1] / 10.0)
return super(PatroniSequentialThreadingHandler, self).create_connection(*args_list, **kwargs)
def select(self, *args, **kwargs):
def select(self, *args: Any, **kwargs: Any) -> Any:
"""
Python 3.XY may raise the following exceptions if select/poll are called with an invalid socket:
- `ValueError`: because fd == -1
@@ -68,7 +72,7 @@ class PatroniSequentialThreadingHandler(SequentialThreadingHandler):
class PatroniKazooClient(KazooClient):
def _call(self, request, async_object):
def _call(self, request: Tuple[Any], async_object: AsyncResult) -> Optional[bool]:
# Before kazoo==2.7.0 it wasn't possible to send requests to zookeeper if
# the connection was in the SUSPENDED state, and Patroni was strongly relying on that.
# The https://github.com/python-zk/kazoo/pull/588 changed it, and now such requests are queued.
@@ -82,7 +86,7 @@ class PatroniKazooClient(KazooClient):
class ZooKeeper(AbstractDCS):
def __init__(self, config):
def __init__(self, config: Dict[str, Any]) -> None:
super(ZooKeeper, self).__init__(config)
hosts = config.get('hosts', [])
@@ -94,17 +98,18 @@ class ZooKeeper(AbstractDCS):
kwargs = {v: config[k] for k, v in mapping.items() if k in config}
if 'set_acls' in config:
kwargs['default_acl'] = []
default_acl: List[ACL] = []
for principal, permissions in config['set_acls'].items():
normalizedPermissions = [p.upper() for p in permissions]
kwargs['default_acl'].append(make_acl(scheme='x509',
credential=principal,
read='READ' in normalizedPermissions,
write='WRITE' in normalizedPermissions,
create='CREATE' in normalizedPermissions,
delete='DELETE' in normalizedPermissions,
admin='ADMIN' in normalizedPermissions,
all='ALL' in normalizedPermissions))
default_acl.append(make_acl(scheme='x509',
credential=principal,
read='READ' in normalizedPermissions,
write='WRITE' in normalizedPermissions,
create='CREATE' in normalizedPermissions,
delete='DELETE' in normalizedPermissions,
admin='ADMIN' in normalizedPermissions,
all='ALL' in normalizedPermissions))
kwargs['default_acl'] = default_acl
self._client = PatroniKazooClient(hosts, handler=PatroniSequentialThreadingHandler(config['retry_timeout']),
timeout=config['ttl'], connection_retry=KazooRetry(max_delay=1, max_tries=-1,
@@ -112,16 +117,16 @@ class ZooKeeper(AbstractDCS):
deadline=config['retry_timeout'], sleep_func=time.sleep), **kwargs)
self._client.add_listener(self.session_listener)
self._fetch_cluster = True
self._fetch_status = True
self.__last_member_data = None
self._fetch_cluster: bool = True
self._fetch_status: bool = True
self.__last_member_data: Optional[Dict[str, Any]] = None
self._orig_kazoo_connect = self._client._connection._connect
self._client._connection._connect = self._kazoo_connect
self._client.start()
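# --- A hypothetical configuration fragment, expressed as the dict this
# constructor receives; each 'set_acls' principal is turned into an x509
# make_acl() entry above (the principal names here are made up):
config = {
    'hosts': ['zk1:2181', 'zk2:2181', 'zk3:2181'],
    'ttl': 30,
    'retry_timeout': 10,
    'set_acls': {
        'CN=patroni-node': ['all'],   # full access for Patroni itself
        'CN=monitoring': ['read'],    # read-only principal
    },
}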
def _kazoo_connect(self, *args):
def _kazoo_connect(self, *args: Any) -> Tuple[Union[int, float], Union[int, float]]:
"""Kazoo is using Ping's to determine health of connection to zookeeper. If there is no
response on Ping after Ping interval (1/2 from read_timeout) it will consider current
connection dead and try to connect to another node. Without this "magic" it was taking
@@ -136,30 +141,28 @@ class ZooKeeper(AbstractDCS):
ret = self._orig_kazoo_connect(*args)
return max(self.loop_wait - 2, 2) * 1000, ret[1]
def session_listener(self, state):
def session_listener(self, state: str) -> None:
if state in [KazooState.SUSPENDED, KazooState.LOST]:
self.cluster_watcher(None)
def status_watcher(self, event):
def status_watcher(self, event: Optional[WatchedEvent]) -> None:
self._fetch_status = True
self.event.set()
def cluster_watcher(self, event):
def cluster_watcher(self, event: Optional[WatchedEvent]) -> None:
self._fetch_cluster = True
if not event or event.state != KazooState.CONNECTED or event.path.startswith(self.client_path('')):
self.status_watcher(event)
def members_watcher(self, event):
self._fetch_cluster = True
def reload_config(self, config):
def reload_config(self, config: Union['Config', Dict[str, Any]]) -> None:
self.set_retry_timeout(config['retry_timeout'])
loop_wait = config['loop_wait']
loop_wait_changed = self._loop_wait != loop_wait
self._loop_wait = loop_wait
self._client.handler.set_connect_timeout(loop_wait)
if isinstance(self._client.handler, PatroniSequentialThreadingHandler):
self._client.handler.set_connect_timeout(loop_wait)
# We need to reestablish connection to zookeeper if we want to change
# read_timeout (and Ping interval respectively), because read_timeout
@@ -170,7 +173,7 @@ class ZooKeeper(AbstractDCS):
if not self.set_ttl(config['ttl']) and loop_wait_changed:
self._client._connection._socket.close()
def set_ttl(self, ttl):
def set_ttl(self, ttl: int) -> Optional[bool]:
"""It is not possible to change ttl (session_timeout) in zookeeper without
destroying old session and creating the new one. This method returns `!True`
if session_timeout has been changed (`restart()` has been called)."""
@@ -181,21 +184,23 @@ class ZooKeeper(AbstractDCS):
return True
@property
def ttl(self):
return self._client._session_timeout / 1000.0
def ttl(self) -> int:
return int(self._client._session_timeout / 1000.0)
def set_retry_timeout(self, retry_timeout):
def set_retry_timeout(self, retry_timeout: int) -> None:
retry = self._client.retry if isinstance(self._client.retry, KazooRetry) else self._client._retry
retry.deadline = retry_timeout
def get_node(self, key, watch=None):
def get_node(
self, key: str, watch: Optional[Callable[[WatchedEvent], None]] = None
) -> Optional[Tuple[str, ZnodeStat]]:
try:
ret = self._client.get(key, watch)
return (ret[0].decode('utf-8'), ret[1])
except NoNodeError:
return None
def get_status(self, path, leader):
def get_status(self, path: str, leader: Optional[Leader]) -> Tuple[int, Optional[Dict[str, int]]]:
watch = self.status_watcher if not leader or leader.name != self._name else None
status = self.get_node(path + self._STATUS, watch)
@@ -212,7 +217,7 @@ class ZooKeeper(AbstractDCS):
slots = None
try:
last_lsn = int(last_lsn)
last_lsn = int(last_lsn or '')
except Exception:
last_lsn = 0
@@ -220,24 +225,24 @@ class ZooKeeper(AbstractDCS):
return last_lsn, slots
@staticmethod
def member(name, value, znode):
def member(name: str, value: str, znode: ZnodeStat) -> Member:
return Member.from_node(znode.version, name, znode.ephemeralOwner, value)
def get_children(self, key, watch=None):
def get_children(self, key: str, watch: Optional[Callable[[WatchedEvent], None]] = None) -> List[str]:
try:
return self._client.get_children(key, watch)
except NoNodeError:
return []
def load_members(self, path):
members = []
def load_members(self, path: str) -> List[Member]:
members: List[Member] = []
for member in self.get_children(path + self._MEMBERS, self.cluster_watcher):
data = self.get_node(path + self._MEMBERS + member)
if data is not None:
members.append(self.member(member, *data))
return members
def _cluster_loader(self, path):
def _cluster_loader(self, path: str) -> Cluster:
self._fetch_cluster = False
self.event.clear()
nodes = set(self.get_children(path, self.cluster_watcher))
@@ -286,9 +291,9 @@ class ZooKeeper(AbstractDCS):
return Cluster(initialize, config, leader, last_lsn, members, failover, sync, history, slots, failsafe)
def _citus_cluster_loader(self, path):
def _citus_cluster_loader(self, path: str) -> Dict[int, Cluster]:
fetch_cluster = False
ret = {}
ret: Dict[int, Cluster] = {}
for node in self.get_children(path, self.cluster_watcher):
if citus_group_re.match(node):
ret[int(node)] = self._cluster_loader(path + node + '/')
@@ -296,7 +301,9 @@ class ZooKeeper(AbstractDCS):
self._fetch_cluster = fetch_cluster
return ret
def _load_cluster(self, path, loader):
def _load_cluster(
self, path: str, loader: Callable[[str], Union[Cluster, Dict[int, Cluster]]]
) -> Union[Cluster, Dict[int, Cluster]]:
cluster = self.cluster if path == self._base_path + '/' else None
if self._fetch_cluster or cluster is None:
try:
@@ -315,18 +322,18 @@ class ZooKeeper(AbstractDCS):
try:
last_lsn, slots = self.get_status(self.client_path(''), cluster.leader)
self.event.clear()
cluster = list(cluster)
cluster[3] = last_lsn
cluster[8] = slots
cluster = Cluster(*cluster)
new_cluster: List[Any] = list(cluster)
new_cluster[3] = last_lsn
new_cluster[8] = slots
cluster = Cluster(*new_cluster)
except Exception:
pass
return cluster
def _bypass_caches(self):
def _bypass_caches(self) -> None:
self._fetch_cluster = True
def _create(self, path, value, retry=False, ephemeral=False):
def _create(self, path: str, value: bytes, retry: bool = False, ephemeral: bool = False) -> bool:
try:
if retry:
self._client.retry(self._client.create, path, value, makepath=True, ephemeral=ephemeral)
@@ -337,7 +344,7 @@ class ZooKeeper(AbstractDCS):
logger.exception('Failed to create %s', path)
return False
def attempt_to_acquire_leader(self):
def attempt_to_acquire_leader(self) -> bool:
try:
self._client.retry(self._client.create, self.leader_path, self._name.encode('utf-8'),
makepath=True, ephemeral=True)
@@ -350,43 +357,44 @@ class ZooKeeper(AbstractDCS):
logger.info('Could not take out TTL lock')
return False
def _set_or_create(self, key, value, index=None, retry=False, do_not_create_empty=False):
value = value.encode('utf-8')
def _set_or_create(self, key: str, value: str, index: Optional[int] = None,
retry: bool = False, do_not_create_empty: bool = False) -> bool:
value_bytes = value.encode('utf-8')
try:
if retry:
self._client.retry(self._client.set, key, value, version=index or -1)
self._client.retry(self._client.set, key, value_bytes, version=index or -1)
else:
self._client.set_async(key, value, version=index or -1).get(timeout=1)
self._client.set_async(key, value_bytes, version=index or -1).get(timeout=1)
return True
except NoNodeError:
if do_not_create_empty and not value:
if do_not_create_empty and not value_bytes:
return True
elif index is None:
return self._create(key, value, retry)
return self._create(key, value_bytes, retry)
else:
return False
except Exception:
logger.exception('Failed to update %s', key)
return False
def set_failover_value(self, value, index=None):
def set_failover_value(self, value: str, index: Optional[int] = None) -> bool:
return self._set_or_create(self.failover_path, value, index)
def set_config_value(self, value, index=None):
def set_config_value(self, value: str, index: Optional[int] = None) -> bool:
return self._set_or_create(self.config_path, value, index, retry=True)
def initialize(self, create_new=True, sysid=""):
sysid = sysid.encode('utf-8')
return self._create(self.initialize_path, sysid, retry=True) if create_new \
else self._client.retry(self._client.set, self.initialize_path, sysid)
def initialize(self, create_new: bool = True, sysid: str = "") -> bool:
sysid_bytes = sysid.encode('utf-8')
return self._create(self.initialize_path, sysid_bytes, retry=True) if create_new \
else self._client.retry(self._client.set, self.initialize_path, sysid_bytes)
def touch_member(self, data):
def touch_member(self, data: Dict[str, Any]) -> bool:
cluster = self.cluster
member = cluster and cluster.get_member(self._name, fallback_to_leader=False)
member_data = self.__last_member_data or member and member.data
# By removing the ZNode we want to notify the leader when some important fields in the member key have changed
if member and (self._client.client_id is not None and member.session != self._client.client_id[0]
or not (deep_compare(member_data.get('tags', {}), data.get('tags', {}))
or not (member_data and deep_compare(member_data.get('tags', {}), data.get('tags', {}))
and (member_data.get('state') == data.get('state')
or 'running' not in (member_data.get('state'), data.get('state')))
and member_data.get('version') == data.get('version')
@@ -401,7 +409,7 @@ class ZooKeeper(AbstractDCS):
member = None
encoded_data = json.dumps(data, separators=(',', ':')).encode('utf-8')
if member:
if member and member_data:
if deep_compare(data, member_data):
return True
else:
@@ -422,19 +430,19 @@ class ZooKeeper(AbstractDCS):
return False
def take_leader(self):
def take_leader(self) -> bool:
return self.attempt_to_acquire_leader()
def _write_leader_optime(self, last_lsn):
def _write_leader_optime(self, last_lsn: str) -> bool:
return self._set_or_create(self.leader_optime_path, last_lsn)
def _write_status(self, value):
def _write_status(self, value: str) -> bool:
return self._set_or_create(self.status_path, value)
def _write_failsafe(self, value):
def _write_failsafe(self, value: str) -> bool:
return self._set_or_create(self.failsafe_path, value)
def _update_leader(self):
def _update_leader(self) -> bool:
cluster = self.cluster
session = cluster and isinstance(cluster.leader, Leader) and cluster.leader.session
if self._client.client_id and self._client.client_id[0] != session:
@@ -459,37 +467,39 @@ class ZooKeeper(AbstractDCS):
return False
return True
def _delete_leader(self):
def _delete_leader(self) -> bool:
self._client.restart()
return True
def _cancel_initialization(self):
def _cancel_initialization(self) -> None:
node = self.get_node(self.initialize_path)
if node:
self._client.delete(self.initialize_path, version=node[1].version)
def cancel_initialization(self):
def cancel_initialization(self) -> bool:
try:
self._client.retry(self._cancel_initialization)
return True
except Exception:
logger.exception("Unable to delete initialize key")
return False
def delete_cluster(self):
def delete_cluster(self) -> bool:
try:
return self._client.retry(self._client.delete, self.client_path(''), recursive=True)
except NoNodeError:
return True
def set_history_value(self, value):
def set_history_value(self, value: str) -> bool:
return self._set_or_create(self.history_path, value)
def set_sync_state_value(self, value, index=None):
def set_sync_state_value(self, value: str, index: Optional[int] = None) -> bool:
return self._set_or_create(self.sync_path, value, index, retry=True, do_not_create_empty=True)
def delete_sync_state(self, index=None):
def delete_sync_state(self, index: Optional[int] = None) -> bool:
return self.set_sync_state_value("{}", index)
def watch(self, leader_index, timeout):
def watch(self, leader_index: Optional[int], timeout: float) -> bool:
ret = super(ZooKeeper, self).watch(leader_index, timeout + 0.5)
if ret and not self._fetch_status:
self._fetch_cluster = True

View File

@@ -1,11 +1,14 @@
from typing import Any
class PatroniException(Exception):
"""Parent class for all kind of exceptions related to selected distributed configuration store"""
def __init__(self, value):
def __init__(self, value: Any) -> None:
self.value = value
def __str__(self):
def __str__(self) -> str:
"""
>>> str(PatroniException('foo'))
"'foo'"

View File

@@ -6,27 +6,26 @@ import sys
import time
import uuid
from collections import namedtuple
from multiprocessing.pool import ThreadPool
from threading import RLock
from typing import List, Optional, Union
from typing import Any, Callable, Collection, Dict, List, NamedTuple, Optional, Union, Tuple, TYPE_CHECKING
from . import psycopg
from .__main__ import Patroni
from .async_executor import AsyncExecutor, CriticalTask
from .collections import CaseInsensitiveSet
from .dcs import AbstractDCS, Cluster, Leader, Member, RemoteMember
from .exceptions import DCSError, PostgresConnectionException, PatroniFatalException
from .postgresql.callback_executor import CallbackAction
from .postgresql.misc import postgres_version_to_int
from .postgresql.postmaster import PostmasterProcess
from .postgresql.rewind import Rewind
from .utils import polling_loop, tzutc
from .dcs import Cluster, Leader, Member, RemoteMember
logger = logging.getLogger(__name__)
class _MemberStatus(namedtuple('_MemberStatus', ['member', 'reachable', 'in_recovery',
'dcs_last_seen', 'timeline', 'wal_position',
'tags', 'watchdog_failed'])):
class _MemberStatus(NamedTuple):
"""Node status distilled from API response:
member - dcs.Member object of the node
@@ -38,29 +37,37 @@ class _MemberStatus(namedtuple('_MemberStatus', ['member', 'reachable', 'in_reco
tags - dictionary with values of different tags (i.e. nofailover)
watchdog_failed - indicates that watchdog is required by configuration but not available or failed
"""
member: Member
reachable: bool
in_recovery: Optional[bool]
dcs_last_seen: int
timeline: int
wal_position: int
tags: Dict[str, Any]
watchdog_failed: bool
@classmethod
def from_api_response(cls, member, json):
def from_api_response(cls, member: Member, json: Dict[str, Any]) -> '_MemberStatus':
"""
:param member: dcs.Member object
:param json: RestApiHandler.get_postgresql_status() result
:returns: _MemberStatus object
"""
# If one of those is not in the response we want to count the node as not healthy/reachable
assert 'wal' in json or 'xlog' in json
wal = json.get('wal', json.get('xlog'))
in_recovery = not bool(wal.get('location')) # abuse difference in primary/replica response format
wal: Dict[str, Any] = json.get('wal') or json['xlog']
# abuse difference in primary/replica response format
in_recovery = not (bool(wal.get('location')) or json.get('role') in ('master', 'primary'))
timeline = json.get('timeline', 0)
dcs_last_seen = json.get('dcs_last_seen', 0)
wal = in_recovery and max(wal.get('received_location', 0), wal.get('replayed_location', 0))
return cls(member, True, in_recovery, dcs_last_seen, timeline, wal,
lsn = int(in_recovery and max(wal.get('received_location', 0), wal.get('replayed_location', 0)))
return cls(member, True, in_recovery, dcs_last_seen, timeline, lsn,
json.get('tags', {}), json.get('watchdog_failed', False))
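# --- Hypothetical /patroni payloads showing the fields read above: a
# primary reports wal.location, while a replica reports received/replayed
# positions (under 'wal', or 'xlog' for older Patroni versions).
primary_json = {'role': 'primary', 'timeline': 4, 'dcs_last_seen': 1683619200,
                'wal': {'location': 67108864}, 'tags': {}, 'watchdog_failed': False}
replica_json = {'role': 'replica', 'timeline': 4, 'dcs_last_seen': 1683619201,
                'xlog': {'received_location': 67108800, 'replayed_location': 67108792},
                'tags': {'nofailover': False}}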
@classmethod
def unknown(cls, member):
def unknown(cls, member: Member) -> '_MemberStatus':
return cls(member, False, None, 0, 0, 0, {}, False)
def failover_limitation(self):
def failover_limitation(self) -> Optional[str]:
"""Returns reason why this node can't promote or None if everything is ok."""
if not self.reachable:
return 'not reachable'
@@ -73,7 +80,7 @@ class _MemberStatus(namedtuple('_MemberStatus', ['member', 'reachable', 'in_reco
class Failsafe(object):
def __init__(self, dcs):
def __init__(self, dcs: AbstractDCS) -> None:
self._lock = RLock()
self._dcs = dcs
self._last_update = 0
@@ -82,7 +89,7 @@ class Failsafe(object):
self._api_url = None
self._slots = None
def update(self, data):
def update(self, data: Dict[str, Any]) -> None:
with self._lock:
self._last_update = time.time()
self._name = data['name']
@@ -91,26 +98,22 @@ class Failsafe(object):
self._slots = data.get('slots')
@property
def leader(self):
def leader(self) -> Optional[Leader]:
with self._lock:
if self._last_update + self._dcs.ttl > time.time():
return Leader(None, None,
RemoteMember(self._name, {'api_url': self._api_url,
'conn_url': self._conn_url,
'slots': self._slots}))
return Leader('', '', RemoteMember(self._name, {'api_url': self._api_url,
'conn_url': self._conn_url,
'slots': self._slots}))
def update_cluster(self, cluster):
def update_cluster(self, cluster: Cluster) -> Cluster:
# Enrich the cluster with the real leader if there was a ping from it
leader = self.leader
if leader:
cluster = list(cluster)
# We rely on the strict order of fields in the namedtuple
cluster[2] = leader
cluster[8] = leader.member.data['slots']
cluster = Cluster(*cluster)
cluster = Cluster(*cluster[0:2], leader, *cluster[3:8], leader.member.data['slots'], *cluster[9:])
return cluster
def is_active(self):
def is_active(self) -> bool:
"""Is used to report in REST API whether the failsafe mode was activated.
On primary the self._last_update is set from the
@@ -124,21 +127,21 @@ class Failsafe(object):
with self._lock:
return self._last_update + self._dcs.ttl > time.time()
def set_is_active(self, value):
def set_is_active(self, value: float) -> None:
with self._lock:
self._last_update = value
class Ha(object):
def __init__(self, patroni):
def __init__(self, patroni: Patroni):
self.patroni = patroni
self.state_handler = patroni.postgresql
self._rewind = Rewind(self.state_handler)
self.dcs = patroni.dcs
self.cluster = None
self.cluster = Cluster.empty()
self.global_config = self.patroni.config.get_global_config(None)
self.old_cluster = None
self.old_cluster = Cluster.empty()
self._is_leader = False
self._is_leader_lock = RLock()
self._failsafe = Failsafe(patroni.dcs)
@@ -147,7 +150,7 @@ class Ha(object):
self.recovering = False
self._async_response = CriticalTask()
self._crash_recovery_executed = False
self._crash_recovery_started = None
self._crash_recovery_started = 0
self._start_timeout = None
self._async_executor = AsyncExecutor(self.state_handler.cancellable, self.wakeup)
self.watchdog = patroni.watchdog
@@ -173,27 +176,27 @@ class Ha(object):
ret = self.global_config.primary_stop_timeout
return ret if ret > 0 and self.is_synchronous_mode() else None
def is_paused(self):
def is_paused(self) -> bool:
""":returns: `True` if in maintenance mode."""
return self.global_config.is_paused
def check_timeline(self):
def check_timeline(self) -> bool:
""":returns: `True` if should check whether the timeline is latest during the leader race."""
return self.global_config.check_mode('check_timeline')
def is_standby_cluster(self):
def is_standby_cluster(self) -> bool:
""":returns: `True` if global configuration has a valid "standby_cluster" section."""
return self.global_config.is_standby_cluster
def is_leader(self):
def is_leader(self) -> bool:
with self._is_leader_lock:
return self._is_leader > time.time()
def set_is_leader(self, value):
def set_is_leader(self, value: bool) -> None:
with self._is_leader_lock:
self._is_leader = time.time() + self.dcs.ttl if value else 0
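is_leader/set_is_leader implement a TTL-stamped boolean: the setter stores an expiry timestamp rather than True/False, so the flag falsifies itself unless it keeps being refreshed within the DCS ttl. A minimal sketch of the idea, detached from Patroni's classes:

import time

class TtlFlag:
    def __init__(self, ttl: float) -> None:
        self._ttl = ttl
        self._expires_at = 0.0

    def set(self, value: bool) -> None:
        # store an expiry time instead of a boolean
        self._expires_at = time.time() + self._ttl if value else 0.0

    def get(self) -> bool:
        # the flag silently expires if nobody refreshes it
        return self._expires_at > time.time()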
def load_cluster_from_dcs(self):
def load_cluster_from_dcs(self) -> None:
cluster = self.dcs.get_cluster()
# We want to keep the state of cluster when it was healthy
@@ -208,9 +211,9 @@ class Ha(object):
if not self.has_lock(False):
self.set_is_leader(False)
self._leader_timeline = None if cluster.is_unlocked() else cluster.leader.timeline
self._leader_timeline = cluster.leader.timeline if cluster.leader else None
def acquire_lock(self):
def acquire_lock(self) -> bool:
try:
ret = self.dcs.attempt_to_acquire_leader()
except DCSError:
@@ -221,14 +224,14 @@ class Ha(object):
self.set_is_leader(ret)
return ret
def _failsafe_config(self):
def _failsafe_config(self) -> Optional[Dict[str, str]]:
if self.is_failsafe_mode():
ret = {m.name: m.api_url for m in self.cluster.members}
ret = {m.name: m.api_url for m in self.cluster.members if m.api_url}
if self.state_handler.name not in ret:
ret[self.state_handler.name] = self.patroni.api.connection_string
return ret
def update_lock(self, write_leader_optime=False):
def update_lock(self, write_leader_optime: bool = False) -> bool:
last_lsn = slots = None
if write_leader_optime:
try:
@@ -248,13 +251,13 @@ class Ha(object):
self.watchdog.keepalive()
return ret
def has_lock(self, info=True):
def has_lock(self, info: bool = True) -> bool:
lock_owner = self.cluster.leader and self.cluster.leader.name
if info:
logger.info('Lock owner: %s; I am %s', lock_owner, self.state_handler.name)
return lock_owner == self.state_handler.name
def get_effective_tags(self):
def get_effective_tags(self) -> Dict[str, Any]:
"""Return configuration tags merged with dynamically applied tags."""
tags = self.patroni.tags.copy()
# _disable_sync could be modified concurrently, but we don't care as attribute get and set are atomic.
@@ -262,10 +265,10 @@ class Ha(object):
tags['nosync'] = True
return tags
def notify_citus_coordinator(self, event):
def notify_citus_coordinator(self, event: str) -> None:
if self.state_handler.citus_handler.is_worker():
coordinator = self.dcs.get_citus_coordinator()
if coordinator and coordinator.leader and coordinator.leader.conn_kwargs:
if coordinator and coordinator.leader and coordinator.leader.conn_url:
try:
data = {'type': event,
'group': self.state_handler.citus_handler.group(),
@@ -278,9 +281,9 @@ class Ha(object):
logger.warning('Request to Citus coordinator leader %s %s failed: %r',
coordinator.leader.name, coordinator.leader.member.api_url, e)
def touch_member(self):
def touch_member(self) -> bool:
with self._member_state_lock:
data = {
data: Dict[str, Any] = {
'conn_url': self.state_handler.connection_string,
'api_url': self.patroni.api.connection_string,
'state': self.state_handler.state,
@@ -302,6 +305,7 @@ class Ha(object):
if self._async_executor.scheduled_action in (None, 'promote') \
and data['state'] in ['running', 'restarting', 'starting']:
try:
timeline: Optional[int]
timeline, wal_position, pg_control_timeline = self.state_handler.timeline_wal_position()
data['xlog_location'] = wal_position
if not timeline: # try pg_stat_wal_receiver to get the timeline
@@ -317,7 +321,7 @@ class Ha(object):
if self.state_handler.role == 'standby_leader':
timeline = pg_control_timeline or self.state_handler.pg_control_timeline()
else:
timeline = self.state_handler.replica_cached_timeline(self._leader_timeline)
timeline = self.state_handler.replica_cached_timeline(self._leader_timeline) or 0
if timeline:
data['timeline'] = timeline
except Exception:
@@ -338,7 +342,7 @@ class Ha(object):
self._last_state = new_state
return ret
def clone(self, clone_member=None, msg='(without leader)'):
def clone(self, clone_member: Union[Leader, Member, None] = None, msg: str = '(without leader)') -> Optional[bool]:
if self.is_standby_cluster() and not isinstance(clone_member, RemoteMember):
clone_member = self.get_remote_member(clone_member)
@@ -352,16 +356,10 @@ class Ha(object):
logger.error('failed to bootstrap %s', msg)
self.state_handler.remove_data_directory()
def bootstrap(self):
if not self.cluster.is_unlocked(): # cluster already has leader
clone_member = self.cluster.get_clone_member(self.state_handler.name)
member_role = 'leader' if clone_member == self.cluster.leader else 'replica'
msg = "from {0} '{1}'".format(member_role, clone_member.name)
ret = self._async_executor.try_run_async('bootstrap {0}'.format(msg), self.clone, args=(clone_member, msg))
return ret or 'trying to bootstrap {0}'.format(msg)
def bootstrap(self) -> str:
# no initialize key and node is allowed to be primary and has 'bootstrap' section in a configuration file
elif self.cluster.initialize is None and not self.patroni.nofailover and 'bootstrap' in self.patroni.config:
if self.cluster.is_unlocked() and self.cluster.initialize is None\
and not self.patroni.nofailover and 'bootstrap' in self.patroni.config:
if self.dcs.initialize(create_new=True): # race for initialization
self.state_handler.bootstrapping = True
with self._async_response:
@@ -376,17 +374,26 @@ class Ha(object):
return ret or 'trying to bootstrap a new cluster'
else:
return 'failed to acquire initialize lock'
else:
create_replica_methods = self.global_config.get_standby_cluster_config().get('create_replica_methods', []) \
if self.is_standby_cluster() else None
can_bootstrap = self.state_handler.can_create_replica_without_replication_connection(create_replica_methods)
concurrent_bootstrap = self.cluster.initialize == ""
if can_bootstrap and not concurrent_bootstrap:
msg = 'bootstrap (without leader)'
return self._async_executor.try_run_async(msg, self.clone) or 'trying to ' + msg
return 'waiting for {0}leader to bootstrap'.format('standby_' if self.is_standby_cluster() else '')
def bootstrap_standby_leader(self):
clone_member = self.cluster.get_clone_member(self.state_handler.name)
# cluster already has a leader, we can bootstrap from it or from one of the replicas (if they allow it)
if not self.cluster.is_unlocked() and clone_member:
member_role = 'leader' if clone_member == self.cluster.leader else 'replica'
msg = "from {0} '{1}'".format(member_role, clone_member.name)
ret = self._async_executor.try_run_async('bootstrap {0}'.format(msg), self.clone, args=(clone_member, msg))
return ret or 'trying to bootstrap {0}'.format(msg)
# no leader, but configuration may allow replica creation using backup tools
create_replica_methods = self.global_config.get_standby_cluster_config().get('create_replica_methods', []) \
if self.is_standby_cluster() else None
can_bootstrap = self.state_handler.can_create_replica_without_replication_connection(create_replica_methods)
concurrent_bootstrap = self.cluster.initialize == ""
if can_bootstrap and not concurrent_bootstrap:
msg = 'bootstrap (without leader)'
return self._async_executor.try_run_async(msg, self.clone) or 'trying to ' + msg
return 'waiting for {0}leader to bootstrap'.format('standby_' if self.is_standby_cluster() else '')
def bootstrap_standby_leader(self) -> Optional[bool]:
""" If we found 'standby' key in the configuration, we need to bootstrap
not a real primary, but a 'standby leader', that will take base backup
from a remote member and start follow it.
@@ -401,16 +408,16 @@ class Ha(object):
return result
def _handle_crash_recovery(self):
def _handle_crash_recovery(self) -> Optional[str]:
if not self._crash_recovery_executed and (self.cluster.is_unlocked() or self._rewind.can_rewind):
self._crash_recovery_executed = True
self._crash_recovery_started = time.time()
msg = 'doing crash recovery in a single user mode'
return self._async_executor.try_run_async(msg, self._rewind.ensure_clean_shutdown) or msg
def _handle_rewind_or_reinitialize(self):
def _handle_rewind_or_reinitialize(self) -> Optional[str]:
leader = self.get_remote_member() if self.is_standby_cluster() else self.cluster.leader
if not self._rewind.rewind_or_reinitialize_needed_and_possible(leader):
if not self._rewind.rewind_or_reinitialize_needed_and_possible(leader) or not leader:
return None
if self._rewind.can_rewind:
@@ -428,7 +435,7 @@ class Ha(object):
msg = 'reinitializing due to diverged timelines'
return self._async_executor.try_run_async(msg, self._do_reinitialize, args=(self.cluster,)) or msg
def recover(self):
def recover(self) -> str:
# Postgres is not running and we will restart in standby mode. Watchdog is not needed until we promote.
self.watchdog.disable()
@@ -454,7 +461,10 @@ class Ha(object):
self.load_cluster_from_dcs()
role = 'replica'
if self.is_standby_cluster() or not self.has_lock():
if self.has_lock() and not self.is_standby_cluster():
msg = "starting as readonly because i had the session lock"
node_to_follow = None
else:
if not self._rewind.executed:
self._rewind.trigger_check_diverged_lsn()
msg = self._handle_rewind_or_reinitialize()
@@ -474,16 +484,13 @@ class Ha(object):
if self.is_synchronous_mode():
self.state_handler.sync_handler.set_synchronous_standby_names(CaseInsensitiveSet())
elif self.has_lock():
msg = "starting as readonly because i had the session lock"
node_to_follow = None
if self._async_executor.try_run_async('restarting after failure', self.state_handler.follow,
args=(node_to_follow, role, timeout)) is None:
self.recovering = True
return msg
def _get_node_to_follow(self, cluster):
def _get_node_to_follow(self, cluster: Cluster) -> Union[Leader, Member, None]:
# determine the node to follow. If replicatefrom tag is set,
# try to follow the node mentioned there, otherwise, follow the leader.
if self.is_standby_cluster() and (self.cluster.is_unlocked() or self.has_lock(False)):
@@ -506,7 +513,7 @@ class Ha(object):
return node_to_follow
def follow(self, demote_reason, follow_reason, refresh=True):
def follow(self, demote_reason: str, follow_reason: str, refresh: bool = True) -> str:
if refresh:
self.load_cluster_from_dcs()
@@ -565,7 +572,7 @@ class Ha(object):
""":returns: `True` if synchronous replication is requested."""
return self.global_config.is_synchronous_mode
def is_failsafe_mode(self):
def is_failsafe_mode(self) -> bool:
""":returns: `True` if failsafe_mode is enabled in global configuration."""
return self.global_config.check_mode('failsafe_mode')
@@ -625,10 +632,10 @@ class Ha(object):
def is_sync_standby(self, cluster: Cluster) -> bool:
""":returns: `True` if the current node is a synchronous standby."""
return cluster.leader and cluster.sync.leader_matches(cluster.leader.name) \
return bool(cluster.leader) and cluster.sync.leader_matches(cluster.leader.name) \
and cluster.sync.matches(self.state_handler.name)
def while_not_sync_standby(self, func):
def while_not_sync_standby(self, func: Callable[..., Any]) -> Any:
"""Runs specified action while trying to make sure that the node is not assigned synchronous standby status.
Tags us as not allowed to be a sync standby as we are going to go away; if we currently are, wait for
@@ -665,32 +672,32 @@ class Ha(object):
with self._member_state_lock:
self._disable_sync -= 1
def update_cluster_history(self):
def update_cluster_history(self) -> None:
primary_timeline = self.state_handler.get_primary_timeline()
cluster_history = self.cluster.history and self.cluster.history.lines
cluster_history = self.cluster.history.lines if self.cluster.history else []
if primary_timeline == 1:
if cluster_history:
self.dcs.set_history_value('[]')
elif not cluster_history or cluster_history[-1][0] != primary_timeline - 1 or len(cluster_history[-1]) != 5:
cluster_history = {line[0]: line for line in cluster_history or []}
history = self.state_handler.get_history(primary_timeline)
if history and self.cluster.config:
cluster_history = {line[0]: line for line in cluster_history}
history: List[List[Any]] = list(map(list, self.state_handler.get_history(primary_timeline)))
if self.cluster.config:
history = history[-self.cluster.config.max_timelines_history:]
for line in history:
# enrich current history with promotion timestamps stored in DCS
if len(line) == 3 and line[0] in cluster_history \
and len(cluster_history[line[0]]) >= 4 \
and cluster_history[line[0]][1] == line[1]:
line.append(cluster_history[line[0]][3])
if len(cluster_history[line[0]]) == 5:
line.append(cluster_history[line[0]][4])
for line in history:
# enrich current history with promotion timestamps stored in DCS
cluster_history_line = list(cluster_history.get(line[0], []))
if len(line) == 3 and len(cluster_history_line) >= 4 and cluster_history_line[1] == line[1]:
line.append(cluster_history_line[3])
if len(cluster_history_line) == 5:
line.append(cluster_history_line[4])
if history:
self.dcs.set_history_value(json.dumps(history, separators=(',', ':')))
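The enrichment loop above assumes each history line starts as [timeline, lsn, reason] and may grow to [timeline, lsn, reason, timestamp, new_leader] once promotion metadata stored in DCS is merged in. A small illustration with made-up values:

# Made-up history lines, shaped like the ones handled above
dcs_history = {1: [1, 1234, 'no recovery target specified', '2023-05-09T09:38:00', 'node2']}
fresh = [[1, 1234, 'no recovery target specified']]

for line in fresh:
    stored = list(dcs_history.get(line[0], []))
    if len(line) == 3 and len(stored) >= 4 and stored[1] == line[1]:
        line.append(stored[3])        # promotion timestamp from DCS
        if len(stored) == 5:
            line.append(stored[4])    # name of the new leader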
def enforce_follow_remote_member(self, message):
def enforce_follow_remote_member(self, message: str) -> str:
demote_reason = 'cannot be a real primary in standby cluster'
return self.follow(demote_reason, message)
def enforce_primary_role(self, message, promote_message):
def enforce_primary_role(self, message: str, promote_message: str) -> str:
"""
Ensure the node that has won the race for the leader key meets criteria
for promoting its PG server to the 'primary' role.
@@ -749,7 +756,7 @@ class Ha(object):
before_promote, on_success))
return promote_message
def fetch_node_status(self, member):
def fetch_node_status(self, member: Member) -> _MemberStatus:
"""This function perform http get request on member.api_url and fetches its status
:returns: `_MemberStatus` object
"""
@@ -763,36 +770,36 @@ class Ha(object):
logger.warning("Request failed to %s: GET %s (%s)", member.name, member.api_url, e)
return _MemberStatus.unknown(member)
def fetch_nodes_statuses(self, members):
def fetch_nodes_statuses(self, members: List[Member]) -> List[_MemberStatus]:
pool = ThreadPool(len(members))
results = pool.map(self.fetch_node_status, members) # Run API calls on members in parallel
pool.close()
pool.join()
return results
def update_failsafe(self, data):
def update_failsafe(self, data: Dict[str, Any]) -> Optional[str]:
if self.state_handler.state == 'running' and self.state_handler.role in ('master', 'primary'):
return 'Running as a leader'
self._failsafe.update(data)
def failsafe_is_active(self):
def failsafe_is_active(self) -> bool:
return self._failsafe.is_active()
def call_failsafe_member(self, data, member):
def call_failsafe_member(self, data: Dict[str, Any], member: Member) -> bool:
try:
response = self.patroni.request(member, 'post', 'failsafe', data, timeout=2, retries=1)
data = response.data.decode('utf-8')
logger.info('Got response from %s %s: %s', member.name, member.api_url, data)
return response.status == 200 and data == 'Accepted'
response_data = response.data.decode('utf-8')
logger.info('Got response from %s %s: %s', member.name, member.api_url, response_data)
return response.status == 200 and response_data == 'Accepted'
except Exception as e:
logger.warning("Request failed to %s: POST %s (%s)", member.name, member.api_url, e)
return False
def check_failsafe_topology(self):
def check_failsafe_topology(self) -> bool:
failsafe = self.dcs.failsafe
if not isinstance(failsafe, dict) or self.state_handler.name not in failsafe:
return False
data = {
data: Dict[str, Any] = {
'name': self.state_handler.name,
'conn_url': self.state_handler.connection_string,
'api_url': self.patroni.api.connection_string,
@@ -813,7 +820,7 @@ class Ha(object):
pool.join()
return all(results)
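The dict built above is POSTed to the /failsafe endpoint of every member listed in the failsafe key; on the receiving side Failsafe.update() reads the same fields. An illustrative payload (values are made up; the real one may also carry a 'slots' entry):

data = {
    'name': 'node1',
    'conn_url': 'postgres://10.0.0.1:5432/postgres',
    'api_url': 'http://10.0.0.1:8008/patroni',
}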
def is_lagging(self, wal_position):
def is_lagging(self, wal_position: int) -> bool:
"""Returns if instance with an wal should consider itself unhealthy to be promoted due to replication lag.
:param wal_position: Current wal position.
@@ -822,7 +829,7 @@ class Ha(object):
lag = (self.cluster.last_lsn or 0) - wal_position
return lag > self.global_config.maximum_lag_on_failover
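Worked example with made-up numbers and the common 1 MiB default for maximum_lag_on_failover; note that a lag exactly at the threshold still passes, since the comparison is strict:

maximum_lag_on_failover = 1_048_576   # 1 MiB
cluster_last_lsn = 67_108_864         # last LSN the leader wrote to DCS
my_wal_position = 66_060_288          # this node's wal position

lag = cluster_last_lsn - my_wal_position    # 1_048_576 bytes behind
assert not (lag > maximum_lag_on_failover)  # exactly at the limit: still healthy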
def _is_healthiest_node(self, members, check_replication_lag=True):
def _is_healthiest_node(self, members: Collection[Member], check_replication_lag: bool = True) -> bool:
"""This method tries to determine whether I am healthy enough to became a new leader candidate or not."""
my_wal_position = self.state_handler.last_operation()
@@ -833,6 +840,9 @@ class Ha(object):
if not self.is_standby_cluster() and self.check_timeline():
cluster_timeline = self.cluster.timeline
my_timeline = self.state_handler.replica_cached_timeline(cluster_timeline)
if my_timeline is None:
logger.info('Can not figure out my timeline')
return False
if my_timeline < cluster_timeline:
logger.info('My timeline %s is behind last known cluster timeline %s', my_timeline, cluster_timeline)
return False
@@ -843,7 +853,7 @@ class Ha(object):
if members:
for st in self.fetch_nodes_statuses(members):
if st.failover_limitation() is None:
if not st.in_recovery:
if st.in_recovery is False:
logger.warning('Primary (%s) is still alive', st.member.name)
return False
if my_wal_position < st.wal_position:
@@ -875,8 +885,6 @@ class Ha(object):
not_allowed_reason = st.failover_limitation()
if not_allowed_reason:
logger.info('Member %s is %s', st.member.name, not_allowed_reason)
elif not isinstance(st.wal_position, int):
logger.info('Member %s does not report wal_position', st.member.name)
elif cluster_lsn and st.wal_position < cluster_lsn or\
not cluster_lsn and self.is_lagging(st.wal_position):
logger.info('Member %s exceeds maximum replication lag', st.member.name)
@@ -889,13 +897,15 @@ class Ha(object):
logger.warning('manual failover: members list is empty')
return ret
def manual_failover_process_no_leader(self) -> Union[bool, None]:
def manual_failover_process_no_leader(self) -> Optional[bool]:
"""Handles manual failover/switchover when the old leader already stepped down.
:returns: - `True` if the current node is the best candidate to become the new leader
- `None` if the current node is running as a primary and requested candidate doesn't exist
"""
failover = self.cluster.failover
if TYPE_CHECKING: # pragma: no cover
assert failover is not None
if failover.candidate: # manual failover to specific member
if failover.candidate == self.state_handler.name: # manual failover to me
return True
@@ -905,7 +915,7 @@ class Ha(object):
if not self.cluster.get_member(failover.candidate, fallback_to_leader=False)\
and self.state_handler.is_leader():
logger.warning("manual failover: removing failover key because failover candidate is not running")
self.dcs.manual_failover('', '', index=self.cluster.failover.index)
self.dcs.manual_failover('', '', index=failover.index)
return None
return False
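The `if TYPE_CHECKING: assert failover is not None` idiom used in this method narrows an Optional for pyright at zero runtime cost, since TYPE_CHECKING is False when the module actually runs. A minimal sketch of the pattern:

from typing import List, Optional, TYPE_CHECKING

def first(items: Optional[List[int]]) -> int:
    # the caller guarantees items is not None by the time we get here
    if TYPE_CHECKING:  # pragma: no cover
        assert items is not None
    return items[0]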
@@ -916,7 +926,7 @@ class Ha(object):
# find specific node and check that it is healthy
member = self.cluster.get_member(failover.candidate, fallback_to_leader=False)
if member:
if isinstance(member, Member):
st = self.fetch_node_status(member)
not_allowed_reason = st.failover_limitation()
if not_allowed_reason is None: # node is healthy
@@ -981,7 +991,7 @@ class Ha(object):
if self.is_synchronous_mode() and self.cluster.failover.leader and \
not self.cluster.sync.is_empty and not self.cluster.sync.matches(self.state_handler.name, True):
return False
return self.manual_failover_process_no_leader()
return self.manual_failover_process_no_leader() or False
if not self.watchdog.is_healthy:
logger.warning('Watchdog device is not usable')
@@ -1011,17 +1021,17 @@ class Ha(object):
return self._is_healthiest_node(members.values())
def _delete_leader(self, last_lsn=None):
def _delete_leader(self, last_lsn: Optional[int] = None) -> None:
self.set_is_leader(False)
self.dcs.delete_leader(last_lsn)
self.dcs.reset_cluster()
def release_leader_key_voluntarily(self, last_lsn=None):
def release_leader_key_voluntarily(self, last_lsn: Optional[int] = None) -> None:
self._delete_leader(last_lsn)
self.touch_member()
logger.info("Leader key released")
def demote(self, mode):
def demote(self, mode: str) -> Optional[bool]:
"""Demote PostgreSQL running as primary.
:param mode: One of offline, graceful or immediate.
@@ -1046,7 +1056,7 @@ class Ha(object):
status = {'released': False}
def on_shutdown(checkpoint_location):
def on_shutdown(checkpoint_location: int) -> None:
# Postmaster is still running, but pg_control already reports clean "shut down".
# It could happen if Postgres is still archiving the backlog of WAL files.
# If we know that there are replicas that received the shutdown checkpoint
@@ -1057,13 +1067,13 @@ class Ha(object):
self.release_leader_key_voluntarily(checkpoint_location)
status['released'] = True
def before_shutdown():
def before_shutdown() -> None:
if self.state_handler.citus_handler.is_coordinator():
self.state_handler.citus_handler.on_demote()
else:
self.notify_citus_coordinator('before_demote')
self.state_handler.stop(mode_control['stop'], checkpoint=mode_control['checkpoint'],
self.state_handler.stop(str(mode_control['stop']), checkpoint=bool(mode_control['checkpoint']),
on_safepoint=self.watchdog.disable if self.watchdog.is_running else None,
on_shutdown=on_shutdown if mode_control['release'] else None,
before_shutdown=before_shutdown if mode == 'graceful' else None,
@@ -1099,7 +1109,8 @@ class Ha(object):
return False # do not start postgres, but run pg_rewind on the next iteration
self.state_handler.follow(node_to_follow)
def should_run_scheduled_action(self, action_name, scheduled_at, cleanup_fn):
def should_run_scheduled_action(self, action_name: str, scheduled_at: Optional[datetime.datetime],
cleanup_fn: Callable[..., Any]) -> bool:
if scheduled_at and not self.is_paused():
# If the scheduled action is in the far future, we shouldn't do anything and just return.
# If the scheduled action is in the past, we consider the value to be stale and we remove
@@ -1133,7 +1144,7 @@ class Ha(object):
cleanup_fn()
return False
def process_manual_failover_from_leader(self):
def process_manual_failover_from_leader(self) -> Optional[str]:
"""Checks if manual failover is requested and takes action if appropriate.
Cleans up failover key if failover conditions are not matched.
@@ -1177,7 +1188,7 @@ class Ha(object):
logger.info('Cleaning up failover key')
self.dcs.manual_failover('', '', index=failover.index)
def process_unhealthy_cluster(self):
def process_unhealthy_cluster(self) -> str:
"""Cluster has no leader key"""
if self.is_healthiest_node():
@@ -1219,7 +1230,7 @@ class Ha(object):
return self.follow('demoting self because i am not the healthiest node',
'following a different leader because i am not the healthiest node')
def process_healthy_cluster(self):
def process_healthy_cluster(self) -> str:
if self.has_lock():
if self.is_paused() and not self.state_handler.is_leader():
if self.cluster.failover and self.cluster.failover.candidate == self.state_handler.name:
@@ -1271,7 +1282,7 @@ class Ha(object):
'no action. I am ({0}), a secondary, and following a leader ({1})'.format(
self.state_handler.name, lock_owner), refresh=False)
def evaluate_scheduled_restart(self):
def evaluate_scheduled_restart(self) -> Optional[str]:
if self._async_executor.busy: # Restart already in progress
return None
@@ -1297,7 +1308,7 @@ class Ha(object):
finally:
self.delete_future_restart()
def restart_matches(self, role, postgres_version, pending_restart):
def restart_matches(self, role: Optional[str], postgres_version: Optional[str], pending_restart: bool) -> bool:
reason_to_cancel = ""
# checking the restart filters here seems to be less ugly than moving them into the
# run_scheduled_action.
@@ -1316,7 +1327,7 @@ class Ha(object):
logger.info("not proceeding with the restart: %s", reason_to_cancel)
return False
def schedule_future_restart(self, restart_data):
def schedule_future_restart(self, restart_data: Dict[str, Any]) -> bool:
with self._async_executor:
restart_data['postmaster_start_time'] = self.state_handler.postmaster_start_time()
if not self.patroni.scheduled_restart:
@@ -1325,7 +1336,7 @@ class Ha(object):
return True
return False
def delete_future_restart(self):
def delete_future_restart(self) -> bool:
ret = False
with self._async_executor:
if self.patroni.scheduled_restart:
@@ -1334,14 +1345,13 @@ class Ha(object):
ret = True
return ret
def future_restart_scheduled(self):
return self.patroni.scheduled_restart.copy()\
if (self.patroni.scheduled_restart and isinstance(self.patroni.scheduled_restart, dict)) else None
def future_restart_scheduled(self) -> Dict[str, Any]:
return self.patroni.scheduled_restart.copy()
def restart_scheduled(self):
def restart_scheduled(self) -> bool:
return self._async_executor.scheduled_action == 'restart'
def restart(self, restart_data, run_async=False):
def restart(self, restart_data: Dict[str, Any], run_async: bool = False) -> Tuple[bool, str]:
""" conditional and unconditional restart """
assert isinstance(restart_data, dict)
@@ -1365,10 +1375,10 @@ class Ha(object):
timeout = restart_data.get('timeout', self.global_config.primary_start_timeout)
self.set_start_timeout(timeout)
def before_shutdown():
def before_shutdown() -> None:
self.notify_citus_coordinator('before_demote')
def after_start():
def after_start() -> None:
self.notify_citus_coordinator('after_promote')
# For non async cases we want to wait for restart to complete or timeout before returning.
@@ -1390,16 +1400,17 @@ class Ha(object):
else:
return (False, 'restart failed')
def _do_reinitialize(self, cluster):
def _do_reinitialize(self, cluster: Cluster) -> Optional[bool]:
self.state_handler.stop('immediate', stop_timeout=self.patroni.config['retry_timeout'])
# Commented redundant data directory cleanup here
# self.state_handler.remove_data_directory()
clone_member = self.cluster.get_clone_member(self.state_handler.name)
member_role = 'leader' if clone_member == self.cluster.leader else 'replica'
return self.clone(clone_member, "from {0} '{1}'".format(member_role, clone_member.name))
clone_member = cluster.get_clone_member(self.state_handler.name)
if clone_member:
member_role = 'leader' if clone_member == cluster.leader else 'replica'
return self.clone(clone_member, "from {0} '{1}'".format(member_role, clone_member.name))
def reinitialize(self, force=False):
def reinitialize(self, force: bool = False) -> Optional[str]:
with self._async_executor:
self.load_cluster_from_dcs()
@@ -1409,6 +1420,8 @@ class Ha(object):
if self.has_lock(False):
return 'I am the leader, can not reinitialize'
cluster = self.cluster
if force:
self._async_executor.cancel()
@@ -1417,9 +1430,9 @@ class Ha(object):
if action is not None:
return '{0} already in progress'.format(action)
self._async_executor.run_async(self._do_reinitialize, args=(self.cluster, ))
self._async_executor.run_async(self._do_reinitialize, args=(cluster, ))
def handle_long_action_in_progress(self):
def handle_long_action_in_progress(self) -> str:
"""
Figure out what to do with the task AsyncExecutor is performing.
"""
@@ -1432,7 +1445,7 @@ class Ha(object):
self.demote('immediate')
return 'terminated crash recovery because of startup timeout'
return 'updated leader lock during ' + self._async_executor.scheduled_action
return 'updated leader lock during {0}'.format(self._async_executor.scheduled_action)
elif not self.state_handler.bootstrapping and not self.is_paused():
# Don't have lock, make sure we are not promoting or starting up a primary in the background
if self._async_executor.scheduled_action == 'promote':
@@ -1443,28 +1456,28 @@ class Ha(object):
return 'lost leader before promote'
if self.state_handler.role in ('master', 'primary'):
logger.info("Demoting primary during " + self._async_executor.scheduled_action)
logger.info('Demoting primary during %s', self._async_executor.scheduled_action)
if self._async_executor.scheduled_action == 'restart':
# Restart needs a special interlocking cancel because postmaster may be just started in a
# background thread and has not even written a pid file yet.
with self._async_executor.critical_task as task:
if not task.cancel():
if not task.cancel() and isinstance(task.result, PostmasterProcess):
self.state_handler.terminate_starting_postmaster(postmaster=task.result)
self.demote('immediate-nolock')
return 'lost leader lock during ' + self._async_executor.scheduled_action
return 'lost leader lock during {0}'.format(self._async_executor.scheduled_action)
if self.cluster.is_unlocked():
logger.info('not healthy enough for leader race')
return self._async_executor.scheduled_action + ' in progress'
return '{0} in progress'.format(self._async_executor.scheduled_action)
@staticmethod
def sysid_valid(sysid):
def sysid_valid(sysid: Optional[str]) -> bool:
# sysid does tv_sec << 32, where tv_sec is the number of seconds since 1970,
# so even 1 << 32 would have 10 digits.
sysid = str(sysid)
return len(sysid) >= 10 and sysid.isdigit()
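The comment's arithmetic is easy to verify: even the smallest shifted value already has ten digits.

assert (1 << 32) == 4294967296
assert len(str(1 << 32)) == 10   # hence the `len(sysid) >= 10` check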
def post_recover(self):
def post_recover(self) -> Optional[str]:
if not self.state_handler.is_running():
self.watchdog.disable()
if self.has_lock():
@@ -1475,14 +1488,14 @@ class Ha(object):
return 'failed to start postgres'
return None
def cancel_initialization(self):
def cancel_initialization(self) -> None:
logger.info('removing initialize key after failed attempt to bootstrap the cluster')
self.dcs.cancel_initialization()
self.state_handler.stop('immediate', stop_timeout=self.patroni.config['retry_timeout'])
self.state_handler.move_data_directory()
raise PatroniFatalException('Failed to bootstrap cluster')
def post_bootstrap(self):
def post_bootstrap(self) -> str:
with self._async_response:
result = self._async_response.result
# bootstrap has failed if postgres is not running
@@ -1513,7 +1526,7 @@ class Ha(object):
return 'initialized a new cluster'
def handle_starting_instance(self):
def handle_starting_instance(self) -> Optional[str]:
"""Starting up PostgreSQL may take a long time. In case we are the leader we may want to
fail over to another node."""
@@ -1552,13 +1565,13 @@ class Ha(object):
logger.info("Still starting up as a standby.")
return None
def set_start_timeout(self, value):
def set_start_timeout(self, value: Optional[int]) -> None:
"""Sets timeout for starting as primary before eligible for failover.
Must be called when async_executor is busy or in the main thread."""
self._start_timeout = value
def _run_cycle(self):
def _run_cycle(self) -> str:
dcs_failed = False
try:
try:
@@ -1619,6 +1632,8 @@ class Ha(object):
return 'started as a secondary'
# is data directory empty?
data_directory_error = ''
data_directory_is_empty = None
try:
data_directory_is_empty = self.state_handler.data_directory_empty()
data_directory_is_accessible = True
@@ -1727,7 +1742,7 @@ class Ha(object):
self._failsafe.set_is_active(0)
self.touch_member()
def _handle_dcs_error(self):
def _handle_dcs_error(self) -> str:
if not self.is_paused() and self.state_handler.is_running():
if self.state_handler.is_leader():
if self.is_failsafe_mode() and self.check_failsafe_topology():
@@ -1746,13 +1761,13 @@ class Ha(object):
self._sync_replication_slots(True)
return 'DCS is not accessible'
def _sync_replication_slots(self, dcs_failed):
def _sync_replication_slots(self, dcs_failed: bool) -> List[str]:
"""Handles replication slots.
:param dcs_failed: bool, indicates that communication with DCS failed (get_cluster() or update_leader())
:returns: list[str], replication slot names that should be copied from the primary"""
slots = []
slots: List[str] = []
# If dcs_failed we don't want to touch replication slots on a leader or replicas if failsafe_mode isn't enabled.
if not self.cluster or dcs_failed and (self.is_leader() or not self.is_failsafe_mode()):
@@ -1772,7 +1787,7 @@ class Ha(object):
# Don't copy replication slots if failsafe_mode is active
return [] if self.failsafe_is_active() else slots
def run_cycle(self):
def run_cycle(self) -> str:
with self._async_executor:
try:
info = self._run_cycle()
@@ -1783,7 +1798,7 @@ class Ha(object):
logger.exception('Unexpected exception')
return 'Unexpected exception raised, please report it as a BUG'
def shutdown(self):
def shutdown(self) -> None:
if self.is_paused():
logger.info('Leader key is not deleted and Postgresql is not stopped due to paused state')
self.watchdog.disable()
@@ -1796,7 +1811,7 @@ class Ha(object):
status = {'deleted': False}
def _on_shutdown(checkpoint_location):
def _on_shutdown(checkpoint_location: int) -> None:
if self.is_leader():
# Postmaster is still running, but pg_control already reports clean "shut down".
# It could happen if Postgres is still archiving the backlog of WAL files.
@@ -1808,7 +1823,7 @@ class Ha(object):
else:
self.dcs.write_leader_optime(checkpoint_location)
def _before_shutdown():
def _before_shutdown() -> None:
self.notify_citus_coordinator('before_demote')
on_shutdown = _on_shutdown if self.is_leader() else None
@@ -1829,36 +1844,36 @@ class Ha(object):
logger.error("PostgreSQL shutdown failed, leader key not removed.%s",
(" Leaving watchdog running." if self.watchdog.is_running else ""))
def watch(self, timeout):
def watch(self, timeout: float) -> bool:
# watch on leader key changes if the postgres is running and leader is known and current node is not lock owner
if self._async_executor.busy or not self.cluster or self.cluster.is_unlocked() or self.has_lock(False):
leader_index = None
else:
leader_index = self.cluster.leader.index
leader_index = self.cluster.leader.index if self.cluster.leader else None
return self.dcs.watch(leader_index, timeout)
def wakeup(self):
def wakeup(self) -> None:
"""Call of this method will trigger the next run of HA loop if there is
no "active" leader watch request in progress.
This usually happens on the leader or if the node is running async action"""
self.dcs.event.set()
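The wakeup mechanism is just a threading.Event that the DCS watch blocks on; setting it makes the sleeping HA loop return early. A minimal sketch, assuming the event-based watch used by the DCS implementations:

import threading

event = threading.Event()

def watch(timeout: float) -> bool:
    # True when woken up explicitly, False on timeout
    woken = event.wait(timeout)
    event.clear()
    return woken

# from another thread: trigger the next HA cycle immediately
event.set()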
def get_remote_member(self, member=None):
def get_remote_member(self, member: Union[Leader, Member, None] = None) -> RemoteMember:
""" In case of standby cluster this will tel us from which remote
member to stream. Config can be both patroni config or
cluster.config.data
"""
data: Dict[str, Any] = {}
cluster_params = self.global_config.get_standby_cluster_config()
if cluster_params:
name = member.name if member else 'remote_member:{}'.format(uuid.uuid1())
data = {k: v for k, v in cluster_params.items() if k in RemoteMember.allowed_keys()}
data.update({k: v for k, v in cluster_params.items() if k in RemoteMember.allowed_keys()})
data['no_replication_slot'] = 'primary_slot_name' not in cluster_params
conn_kwargs = member.conn_kwargs() if member else \
{k: cluster_params[k] for k in ('host', 'port') if k in cluster_params}
if conn_kwargs:
data['conn_kwargs'] = conn_kwargs
return RemoteMember(name, data)
name = member.name if member else 'remote_member:{}'.format(uuid.uuid1())
return RemoteMember(name, data)

View File

@@ -13,7 +13,7 @@ from patroni.utils import deep_compare
from queue import Queue, Full
from threading import Lock, Thread
from typing import Any, Dict
from typing import Any, Dict, List, Optional, Union
_LOGGER = logging.getLogger(__name__)
@@ -68,7 +68,7 @@ class QueueHandler(logging.Handler):
def __init__(self) -> None:
"""Queue initialised and initial records_lost established."""
super().__init__()
self.queue = Queue()
self.queue: Queue[Union[logging.LogRecord, None]] = Queue()
self._records_lost = 0
def _put_record(self, record: logging.LogRecord) -> None:
@@ -189,10 +189,10 @@ class PatroniLogger(Thread):
super(PatroniLogger, self).__init__()
self._queue_handler = QueueHandler()
self._root_logger = logging.getLogger()
self._config = None
self._config: Optional[Dict[str, Any]] = None
self.log_handler = None
self.log_handler_lock = Lock()
self._old_handlers = []
self._old_handlers: List[logging.Handler] = []
# initially set log level to ``DEBUG`` while the logger thread has not started running yet. The daemon process
# will later adjust all log related settings with what was provided through the user configuration file.
self.reload_config({'level': 'DEBUG'})

View File

@@ -12,7 +12,7 @@ from datetime import datetime
from dateutil import tz
from psutil import TimeoutExpired
from threading import current_thread, Lock
from typing import Any, Dict, List, Optional, Union, TYPE_CHECKING
from typing import Any, Callable, Dict, Generator, List, Optional, Union, Tuple, TYPE_CHECKING
from .bootstrap import Bootstrap
from .callback_executor import CallbackAction, CallbackExecutor
@@ -25,11 +25,14 @@ from .postmaster import PostmasterProcess
from .slots import SlotsHandler
from .sync import SyncHandler
from .. import psycopg
from ..dcs import Cluster, Member
from ..async_executor import CriticalTask
from ..dcs import Cluster, Leader, Member
from ..exceptions import PostgresConnectionException
from ..utils import Retry, RetryFailedError, polling_loop, data_directory_is_empty, parse_int
if TYPE_CHECKING: # pragma: no cover
from psycopg import Connection as Connection3, Cursor
from psycopg2 import connection as connection3, cursor
from ..config import GlobalConfig
logger = logging.getLogger(__name__)
@@ -59,13 +62,15 @@ class Postgresql(object):
"pg_catalog.pg_{0}_{1}_diff(COALESCE(pg_catalog.pg_last_{0}_receive_{1}(), '0/0'), '0/0')::bigint, "
"pg_catalog.pg_is_in_recovery() AND pg_catalog.pg_is_{0}_replay_paused()")
def __init__(self, config):
self.name = config['name']
self.scope = config['scope']
self._data_dir = config['data_dir']
def __init__(self, config: Dict[str, Any]) -> None:
self.name: str = config['name']
self.scope: str = config['scope']
self._data_dir: str = config['data_dir']
self._database = config.get('database', 'postgres')
self._version_file = os.path.join(self._data_dir, 'PG_VERSION')
self._pg_control = os.path.join(self._data_dir, 'global', 'pg_control')
self.connection_string: str
self.proxy_url: Optional[str]
self._major_version = self.get_major_version()
self._global_config = None
@@ -92,7 +97,7 @@ class Postgresql(object):
self.cancellable = CancellableSubprocess()
self._sysid = None
self._sysid = ''
self.retry = Retry(max_tries=-1, deadline=config['retry_timeout'] / 2.0, max_delay=1,
retry_exceptions=PostgresConnectionException)
@@ -102,7 +107,7 @@ class Postgresql(object):
self._role_lock = Lock()
self.set_role(self.get_postgres_role_from_data_directory())
self._state_entry_timestamp = None
self._state_entry_timestamp = 0
self._cluster_info_state = {}
self._has_permanent_logical_slots = True
@@ -126,35 +131,35 @@ class Postgresql(object):
self.set_role('demoted')
@property
def create_replica_methods(self):
return self.config.get('create_replica_methods', []) or self.config.get('create_replica_method', [])
def create_replica_methods(self) -> List[str]:
return self.config.get('create_replica_methods', []) or self.config.get('create_replica_method', []) or []
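The trailing `or []` matters because a key that is present but empty in the YAML config comes back as an explicit None, which `.get(key, [])` passes through unchanged. A quick illustration, with a plain dict standing in for the config object:

cfg = {'create_replica_methods': None}                  # key present, value empty in YAML
assert cfg.get('create_replica_methods', []) is None    # default is not applied
assert (cfg.get('create_replica_methods', []) or []) == []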
@property
def major_version(self):
def major_version(self) -> int:
return self._major_version
@property
def database(self):
def database(self) -> str:
return self._database
@property
def data_dir(self):
def data_dir(self) -> str:
return self._data_dir
@property
def callback(self):
return self.config.get('callbacks') or {}
def callback(self) -> Dict[str, str]:
return self.config.get('callbacks', {}) or {}
@property
def wal_dir(self):
def wal_dir(self) -> str:
return os.path.join(self._data_dir, 'pg_' + self.wal_name)
@property
def wal_name(self):
def wal_name(self) -> str:
return 'wal' if self._major_version >= 100000 else 'xlog'
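The 100000 cutoff is PostgreSQL's integer version encoding (major * 10000 + minor, e.g. 9.6 -> 90600, 10.0 -> 100000); pg_xlog was renamed to pg_wal in PostgreSQL 10. A tiny check of the same rule:

def wal_name(major_version: int) -> str:
    # PostgreSQL 10 renamed pg_xlog to pg_wal
    return 'wal' if major_version >= 100000 else 'xlog'

assert wal_name(90600) == 'xlog'    # 9.6
assert wal_name(150002) == 'wal'    # 15.2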
@property
def lsn_name(self):
def lsn_name(self) -> str:
return 'lsn' if self._major_version >= 100000 else 'location'
@property
@@ -163,7 +168,7 @@ class Postgresql(object):
return self._major_version >= 90600
@property
def cluster_info_query(self):
def cluster_info_query(self) -> str:
"""Returns the monitoring query with a fixed number of fields.
The query text is constructed based on current state in DCS and PostgreSQL version:
@@ -206,7 +211,7 @@ class Postgresql(object):
return ("SELECT " + self.TL_LSN + ", {2}").format(self.wal_name, self.lsn_name, extra)
def _version_file_exists(self):
def _version_file_exists(self) -> bool:
return not self.data_directory_empty() and os.path.isfile(self._version_file)
def get_major_version(self) -> int:
@@ -221,11 +226,11 @@ class Postgresql(object):
logger.exception('Failed to read PG_VERSION from %s', self._data_dir)
return 0
def pgcommand(self, cmd):
def pgcommand(self, cmd: str) -> str:
"""Returns path to the specified PostgreSQL command"""
return os.path.join(self._bin_dir, cmd)
def pg_ctl(self, cmd, *args, **kwargs):
def pg_ctl(self, cmd: str, *args: str, **kwargs: Any) -> bool:
"""Builds and executes pg_ctl command
:returns: `!True` when return_code == 0, otherwise `!False`"""
@@ -244,7 +249,7 @@ class Postgresql(object):
initdb = [self.pgcommand('initdb')] + list(args) + [self.data_dir]
return subprocess.call(initdb, **kwargs) == 0
def pg_isready(self):
def pg_isready(self) -> str:
"""Runs pg_isready to see if PostgreSQL is accepting connections.
:returns: 'ok' if PostgreSQL is up, 'reject' if starting up, 'no_response' if not up."""
@@ -267,25 +272,25 @@ class Postgresql(object):
3: STATE_UNKNOWN}
return return_codes.get(ret, STATE_UNKNOWN)
def reload_config(self, config, sighup=False):
def reload_config(self, config: Dict[str, Any], sighup: bool = False) -> None:
self.config.reload_config(config, sighup)
self._is_leader_retry.deadline = self.retry.deadline = config['retry_timeout'] / 2.0
@property
def pending_restart(self):
def pending_restart(self) -> bool:
return self._pending_restart
def set_pending_restart(self, value):
def set_pending_restart(self, value: bool) -> None:
self._pending_restart = value
@property
def sysid(self):
def sysid(self) -> str:
if not self._sysid and not self.bootstrapping:
data = self.controldata()
self._sysid = data.get('Database system identifier', "")
self._sysid = data.get('Database system identifier', '')
return self._sysid
def get_postgres_role_from_data_directory(self):
def get_postgres_role_from_data_directory(self) -> str:
if self.data_directory_empty() or not self.controldata():
return 'uninitialized'
elif self.config.recovery_conf_exists():
@@ -294,24 +299,24 @@ class Postgresql(object):
return 'master'
@property
def server_version(self):
def server_version(self) -> int:
return self._connection.server_version
def connection(self):
def connection(self) -> Union['connection3', 'Connection3[Any]']:
return self._connection.get()
def set_connection_kwargs(self, kwargs):
def set_connection_kwargs(self, kwargs: Dict[str, Any]) -> None:
self._connection.set_conn_kwargs(kwargs.copy())
self.citus_handler.set_conn_kwargs(kwargs.copy())
def _query(self, sql, *params):
def _query(self, sql: str, *params: Any) -> Union['Cursor[Any]', 'cursor']:
"""We are always using the same cursor, therefore this method is not thread-safe!!!
You can call it from different threads only if you are holding explicit `AsyncExecutor` lock,
because the main thread is always holding this lock when running HA cycle."""
cursor = None
try:
cursor = self._connection.cursor()
cursor.execute(sql, params or None)
cursor.execute(sql.encode('utf-8'), params or None)
return cursor
except psycopg.Error as e:
if cursor and cursor.connection.closed == 0:
@@ -327,7 +332,7 @@ class Postgresql(object):
raise RetryFailedError('cluster is being restarted')
raise PostgresConnectionException('connection problems')
def query(self, sql, *args, **kwargs):
def query(self, sql: str, *args: Any, **kwargs: Any) -> Union['Cursor[Any]', 'cursor']:
if not kwargs.get('retry', True):
return self._query(sql, *args)
try:
@@ -335,22 +340,22 @@ class Postgresql(object):
except RetryFailedError as e:
raise PostgresConnectionException(str(e))
def pg_control_exists(self):
def pg_control_exists(self) -> bool:
return os.path.isfile(self._pg_control)
def data_directory_empty(self):
def data_directory_empty(self) -> bool:
if self.pg_control_exists():
return False
return data_directory_is_empty(self._data_dir)
def replica_method_options(self, method):
return deepcopy(self.config.get(method, {}))
def replica_method_options(self, method: str) -> Dict[str, Any]:
return deepcopy(self.config.get(method, {}) or {})
def replica_method_can_work_without_replication_connection(self, method):
return method != 'basebackup' and (self.replica_method_options(method).get('no_master')
or self.replica_method_options(method).get('no_leader'))
def replica_method_can_work_without_replication_connection(self, method: str) -> bool:
return method != 'basebackup' and bool(self.replica_method_options(method).get('no_master')
or self.replica_method_options(method).get('no_leader'))
def can_create_replica_without_replication_connection(self, replica_methods=None):
def can_create_replica_without_replication_connection(self, replica_methods: Optional[List[str]]) -> bool:
""" go through the replication methods to see if there are ones
that do not require a working replication connection.
"""
@@ -359,10 +364,10 @@ class Postgresql(object):
return any(self.replica_method_can_work_without_replication_connection(m) for m in replica_methods)
@property
def enforce_hot_standby_feedback(self):
def enforce_hot_standby_feedback(self) -> bool:
return self._enforce_hot_standby_feedback
def set_enforce_hot_standby_feedback(self, value):
def set_enforce_hot_standby_feedback(self, value: bool) -> None:
# If we enable or disable the hot_standby_feedback we need to update postgresql.conf and reload
if self._enforce_hot_standby_feedback != value:
self._enforce_hot_standby_feedback = value
@@ -370,7 +375,7 @@ class Postgresql(object):
self.config.write_postgresql_conf()
self.reload()
def reset_cluster_info_state(self, cluster: Union[Cluster, None], nofailover: Optional[bool] = None,
def reset_cluster_info_state(self, cluster: Union[Cluster, None], nofailover: bool = False,
global_config: Optional['GlobalConfig'] = None) -> None:
"""Reset monitoring query cache.
@@ -395,7 +400,7 @@ class Postgresql(object):
self._global_config = global_config
def _cluster_info_state_get(self, name):
def _cluster_info_state_get(self, name: str) -> Optional[Any]:
if not self._cluster_info_state:
try:
result = self._is_leader_retry(self._query, self.cluster_info_query).fetchone()
@@ -417,62 +422,63 @@ class Postgresql(object):
return self._cluster_info_state.get(name)
def replayed_location(self):
def replayed_location(self) -> Optional[int]:
return self._cluster_info_state_get('replayed_location')
def received_location(self):
def received_location(self) -> Optional[int]:
return self._cluster_info_state_get('received_location')
def slots(self):
return self._cluster_info_state_get('slots')
def slots(self) -> Dict[str, int]:
return self._cluster_info_state_get('slots') or {}
def primary_slot_name(self):
def primary_slot_name(self) -> Optional[str]:
return self._cluster_info_state_get('slot_name')
def primary_conninfo(self):
def primary_conninfo(self) -> Optional[str]:
return self._cluster_info_state_get('conninfo')
def received_timeline(self):
def received_timeline(self) -> Optional[int]:
return self._cluster_info_state_get('received_tli')
def synchronous_commit(self) -> str:
""":returns: "synchronous_commit" GUC value."""
return self._cluster_info_state_get('synchronous_commit')
return self._cluster_info_state_get('synchronous_commit') or 'on'
def synchronous_standby_names(self) -> str:
""":returns: "synchronous_standby_names" GUC value."""
return self._cluster_info_state_get('synchronous_standby_names')
return self._cluster_info_state_get('synchronous_standby_names') or ''
def pg_stat_replication(self) -> List[Dict[str, Any]]:
""":returns: a result set of 'SELECT * FROM pg_stat_replication'."""
return self._cluster_info_state_get('pg_stat_replication') or []
def is_leader(self):
def is_leader(self) -> bool:
try:
return bool(self._cluster_info_state_get('timeline'))
except PostgresConnectionException:
logger.warning('Failed to determine PostgreSQL state from the connection, falling back to cached role')
return bool(self.is_running() and self.role in ('master', 'primary'))
def replay_paused(self):
return self._cluster_info_state_get('replay_paused')
def replay_paused(self) -> bool:
return self._cluster_info_state_get('replay_paused') or False
def resume_wal_replay(self):
def resume_wal_replay(self) -> None:
self._query('SELECT pg_catalog.pg_{0}_replay_resume()'.format(self.wal_name))
def handle_parameter_change(self):
def handle_parameter_change(self) -> None:
if self.major_version >= 140000 and not self.is_starting() and self.replay_paused():
logger.info('Resuming paused WAL replay for PostgreSQL 14+')
self.resume_wal_replay()
def pg_control_timeline(self):
def pg_control_timeline(self) -> Optional[int]:
try:
return int(self.controldata().get("Latest checkpoint's TimeLineID"))
return int(self.controldata().get("Latest checkpoint's TimeLineID", ""))
except (TypeError, ValueError):
logger.exception('Failed to parse timeline from pg_controldata output')
def parse_wal_record(self, timeline, lsn):
def parse_wal_record(self, timeline: str,
lsn: str) -> Union[Tuple[str, str, str, str], Tuple[None, None, None, None]]:
out, err = self.waldump(timeline, lsn, 1)
if out and not err:
match = re.match(r'^rmgr:\s+(.+?)\s+len \(rec/tot\):\s+\d+/\s+\d+, tx:\s+\d+, '
@@ -482,7 +488,7 @@ class Postgresql(object):
return match.groups()
return None, None, None, None
def latest_checkpoint_location(self):
def latest_checkpoint_location(self) -> Optional[int]:
"""Returns checkpoint location for the cleanly shut down primary.
But, if we know that the checkpoint was written to the new WAL
due to archive_mode=on, we will return the LSN of the previous wal record (SWITCH)."""
@@ -490,25 +496,25 @@ class Postgresql(object):
data = self.controldata()
timeline = data.get("Latest checkpoint's TimeLineID")
lsn = checkpoint_lsn = data.get('Latest checkpoint location')
if data.get('Database cluster state') == 'shut down' and lsn and timeline:
if data.get('Database cluster state') == 'shut down' and lsn and timeline and checkpoint_lsn:
try:
checkpoint_lsn = parse_lsn(checkpoint_lsn)
rm_name, lsn, prev, desc = self.parse_wal_record(timeline, lsn)
desc = desc.strip().lower()
if rm_name == 'XLOG' and parse_lsn(lsn) == checkpoint_lsn and prev and\
desc = str(desc).strip().lower()
if rm_name == 'XLOG' and lsn and parse_lsn(lsn) == checkpoint_lsn and prev and\
desc.startswith('checkpoint') and desc.endswith('shutdown'):
_, lsn, _, desc = self.parse_wal_record(timeline, prev)
prev = parse_lsn(prev)
# If the cluster is shutdown with archive_mode=on, WAL is switched before writing the checkpoint.
# In this case we want to take the LSN of previous record (switch) as the last known WAL location.
if parse_lsn(lsn) == prev and desc.strip() in ('xlog switch', 'SWITCH'):
if lsn and parse_lsn(lsn) == prev and str(desc).strip() in ('xlog switch', 'SWITCH'):
return prev
except Exception as e:
logger.error('Exception when parsing WAL pg_%sdump output: %r', self.wal_name, e)
if isinstance(checkpoint_lsn, int):
return checkpoint_lsn
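parse_lsn converts the textual X/Y form into a 64-bit integer. A sketch matching the standard PostgreSQL LSN format (two hex halves, high word first):

def parse_lsn(lsn: str) -> int:
    hi, lo = lsn.split('/')
    return int(hi, 16) * 0x100000000 + int(lo, 16)

assert parse_lsn('0/3000028') == 0x3000028
assert parse_lsn('1/0') == 1 << 32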
def is_running(self):
def is_running(self) -> Optional[PostmasterProcess]:
"""Returns PostmasterProcess if one is running on the data directory or None. If most recently seen process
is running updates the cached process based on pid file."""
if self._postmaster_proc:
@@ -523,7 +529,7 @@ class Postgresql(object):
return self._postmaster_proc
@property
def cb_called(self):
def cb_called(self) -> bool:
return self.__cb_called
def call_nowait(self, cb_type: CallbackAction) -> None:
@@ -544,31 +550,31 @@ class Postgresql(object):
logger.exception('callback %s %r %s %s failed', cmd, cb_type, role, self.scope)
@property
def role(self):
def role(self) -> str:
with self._role_lock:
return self._role
def set_role(self, value):
def set_role(self, value: str) -> None:
with self._role_lock:
self._role = value
@property
def state(self):
def state(self) -> str:
with self._state_lock:
return self._state
def set_state(self, value):
def set_state(self, value: str) -> None:
with self._state_lock:
self._state = value
self._state_entry_timestamp = time.time()
def time_in_state(self):
def time_in_state(self) -> float:
return time.time() - self._state_entry_timestamp
def is_starting(self):
def is_starting(self) -> bool:
return self.state == 'starting'
def wait_for_port_open(self, postmaster, timeout):
def wait_for_port_open(self, postmaster: PostmasterProcess, timeout: float) -> bool:
"""Waits until PostgreSQL opens ports."""
for _ in polling_loop(timeout):
if self.cancellable.is_cancelled:
@@ -588,7 +594,9 @@ class Postgresql(object):
logger.warning("Timed out waiting for PostgreSQL to start")
return False
def start(self, timeout=None, task=None, block_callbacks=False, role=None, after_start=None):
def start(self, timeout: Optional[float] = None, task: Optional[CriticalTask] = None,
block_callbacks: bool = False, role: Optional[str] = None,
after_start: Optional[Callable[..., Any]] = None) -> Optional[bool]:
"""Start PostgreSQL
Waits for postmaster to open ports or terminate so pg_isready can be used to check startup completion
@@ -650,7 +658,7 @@ class Postgresql(object):
start_timeout = timeout
if not start_timeout:
try:
start_timeout = float(self.config.get('pg_ctl_timeout', 60))
start_timeout = float(self.config.get('pg_ctl_timeout', 60) or 0)
except ValueError:
start_timeout = 60
@@ -668,7 +676,8 @@ class Postgresql(object):
else:
return None
def checkpoint(self, connect_kwargs=None, timeout=None):
def checkpoint(self, connect_kwargs: Optional[Dict[str, Any]] = None,
timeout: Optional[float] = None) -> Optional[str]:
check_not_is_in_recovery = connect_kwargs is not None
connect_kwargs = connect_kwargs or self.config.local_connect_kwargs
for p in ['connect_timeout', 'options']:
@@ -680,15 +689,17 @@ class Postgresql(object):
cur.execute("SET statement_timeout = 0")
if check_not_is_in_recovery:
cur.execute('SELECT pg_catalog.pg_is_in_recovery()')
if cur.fetchone()[0]:
row = cur.fetchone()
if not row or row[0]:
return 'is_in_recovery=true'
cur.execute('CHECKPOINT')
except psycopg.Error:
logger.exception('Exception during CHECKPOINT')
return 'not accessible or not healthy'
def stop(self, mode='fast', block_callbacks=False, checkpoint=None,
on_safepoint=None, on_shutdown=None, before_shutdown=None, stop_timeout=None):
def stop(self, mode: str = 'fast', block_callbacks: bool = False, checkpoint: Optional[bool] = None,
on_safepoint: Optional[Callable[..., Any]] = None, on_shutdown: Optional[Callable[[int], Any]] = None,
before_shutdown: Optional[Callable[..., Any]] = None, stop_timeout: Optional[int] = None) -> bool:
"""Stop PostgreSQL
Supports a callback when a safepoint is reached. A safepoint is when no user backend can return a successful
@@ -716,7 +727,9 @@ class Postgresql(object):
self.set_state('stop failed')
return success
def _do_stop(self, mode, block_callbacks, checkpoint, on_safepoint, on_shutdown, before_shutdown, stop_timeout):
def _do_stop(self, mode: str, block_callbacks: bool, checkpoint: bool,
on_safepoint: Optional[Callable[..., Any]], on_shutdown: Optional[Callable[..., Any]],
before_shutdown: Optional[Callable[..., Any]], stop_timeout: Optional[int]) -> Tuple[bool, bool]:
postmaster = self.is_running()
if not postmaster:
if on_safepoint:
@@ -774,7 +787,8 @@ class Postgresql(object):
return True, True
def terminate_postmaster(self, postmaster, mode, stop_timeout):
def terminate_postmaster(self, postmaster: PostmasterProcess, mode: str,
stop_timeout: Optional[int]) -> Optional[bool]:
if mode in ['fast', 'smart']:
try:
success = postmaster.signal_stop('immediate', self.pgcommand('pg_ctl'))
@@ -787,13 +801,13 @@ class Postgresql(object):
logger.warning("Sending SIGKILL to Postmaster and its children")
return postmaster.signal_kill()
def terminate_starting_postmaster(self, postmaster):
def terminate_starting_postmaster(self, postmaster: PostmasterProcess) -> None:
"""Terminates a postmaster that has not yet opened ports or possibly even written a pid file. Blocks
until the process goes away."""
postmaster.signal_stop('immediate', self.pgcommand('pg_ctl'))
postmaster.wait()
def _wait_for_connection_close(self, postmaster):
def _wait_for_connection_close(self, postmaster: PostmasterProcess) -> None:
try:
with self.connection().cursor() as cur:
while postmaster.is_running(): # Need a timeout here?
@@ -802,17 +816,17 @@ class Postgresql(object):
except psycopg.Error:
pass
def reload(self, block_callbacks=False):
def reload(self, block_callbacks: bool = False) -> bool:
ret = self.pg_ctl('reload')
if ret and not block_callbacks:
self.call_nowait(CallbackAction.ON_RELOAD)
return ret
def check_for_startup(self):
def check_for_startup(self) -> bool:
"""Checks PostgreSQL status and returns if PostgreSQL is in the middle of startup."""
return self.is_starting() and not self.check_startup_state_changed()
def check_startup_state_changed(self):
def check_startup_state_changed(self) -> bool:
"""Checks if PostgreSQL has completed starting up or failed or still starting.
Should only be called when state == 'starting'
@@ -847,7 +861,7 @@ class Postgresql(object):
return True
def wait_for_startup(self, timeout=None):
def wait_for_startup(self, timeout: float = 0) -> Optional[bool]:
"""Waits for PostgreSQL startup to complete or fail.
:returns: True if start was successful, False otherwise"""
@@ -862,8 +876,10 @@ class Postgresql(object):
return self.state == 'running'
def restart(self, timeout=None, task=None, block_callbacks=False,
role=None, before_shutdown=None, after_start=None):
def restart(self, timeout: Optional[float] = None, task: Optional[CriticalTask] = None,
block_callbacks: bool = False, role: Optional[str] = None,
before_shutdown: Optional[Callable[..., Any]] = None,
after_start: Optional[Callable[..., Any]] = None) -> Optional[bool]:
"""Restarts PostgreSQL.
When the timeout parameter is set, the call will block either until PostgreSQL has started, failed to start, or
@@ -880,13 +896,13 @@ class Postgresql(object):
self.set_state('restart failed ({0})'.format(self.state))
return ret
def is_healthy(self):
def is_healthy(self) -> bool:
if not self.is_running():
logger.warning('Postgresql is not running.')
return False
return True
def get_guc_value(self, name):
def get_guc_value(self, name: str) -> Optional[str]:
cmd = [self.pgcommand('postgres'), '-D', self._data_dir, '-C', name,
'--config-file={}'.format(self.config.postgresql_conf)]
try:
@@ -896,7 +912,7 @@ class Postgresql(object):
except Exception as e:
logger.error('Failed to execute %s: %r', cmd, e)
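`postgres -C <name>` prints a single configuration value and exits without starting the server, which is what `get_guc_value()` relies on. A hedged standalone sketch of the same idea (binary and data directory paths are assumptions):

import subprocess
from typing import Optional

def read_guc(postgres_bin: str, datadir: str, name: str) -> Optional[str]:
    # `postgres -C <guc>` prints the setting and exits; no postmaster is
    # started, so this is safe to run against a stopped cluster.
    cmd = [postgres_bin, '-D', datadir, '-C', name]
    try:
        out = subprocess.check_output(cmd, stderr=subprocess.DEVNULL)
        return out.decode().strip() or None
    except (subprocess.CalledProcessError, OSError):
        return None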
def controldata(self):
def controldata(self) -> Dict[str, str]:
""" return the contents of pg_controldata, or non-True value if pg_controldata call failed """
# Don't try to call pg_controldata during backup restore
if self._version_file_exists() and self.state != 'creating replica':
@@ -911,11 +927,11 @@ class Postgresql(object):
logger.exception("Error when calling pg_controldata")
return {}
def waldump(self, timeline, lsn, limit):
def waldump(self, timeline: Union[int, str], lsn: str, limit: int) -> Tuple[Optional[bytes], Optional[bytes]]:
cmd = self.pgcommand('pg_{0}dump'.format(self.wal_name))
env = {**os.environ, 'LANG': 'C', 'LC_ALL': 'C', 'PGDATA': self._data_dir}
try:
waldump = subprocess.Popen([cmd, '-t', str(timeline), '-s', str(lsn), '-n', str(limit)],
waldump = subprocess.Popen([cmd, '-t', str(timeline), '-s', lsn, '-n', str(limit)],
stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env)
out, err = waldump.communicate()
waldump.wait()
@@ -925,22 +941,24 @@ class Postgresql(object):
return None, None
@contextmanager
def get_replication_connection_cursor(self, host=None, port=5432, **kwargs):
def get_replication_connection_cursor(self, host: Optional[str] = None, port: int = 5432,
**kwargs: Any) -> Generator[Union['cursor', 'Cursor[Any]'], None, None]:
conn_kwargs = self.config.replication.copy()
conn_kwargs.update(host=host, port=int(port) if port else None, user=conn_kwargs.pop('username'),
connect_timeout=3, replication=1, options='-c statement_timeout=2000')
with get_connection_cursor(**conn_kwargs) as cur:
yield cur
def get_replica_timeline(self):
def get_replica_timeline(self) -> Optional[int]:
try:
with self.get_replication_connection_cursor(**self.config.local_replication_address) as cur:
cur.execute('IDENTIFY_SYSTEM')
return cur.fetchone()[1]
row = cur.fetchone()
return row[1] if row else None
except Exception:
logger.exception('Can not fetch local timeline and lsn from replication connection')
def replica_cached_timeline(self, primary_timeline):
def replica_cached_timeline(self, primary_timeline: Optional[int]) -> Optional[int]:
if not self._cached_replica_timeline or not primary_timeline\
or self._cached_replica_timeline != primary_timeline:
self._cached_replica_timeline = self.get_replica_timeline()
@@ -948,29 +966,26 @@ class Postgresql(object):
def get_primary_timeline(self) -> int:
""":returns: current timeline if postgres is running as a primary or 0."""
return self._cluster_info_state_get('timeline')
return self._cluster_info_state_get('timeline') or 0
def get_history(self, timeline):
def get_history(self, timeline: int) -> List[Union[Tuple[int, int, str], Tuple[int, int, str, str, str]]]:
history_path = os.path.join(self.wal_dir, '{0:08X}.history'.format(timeline))
history_mtime = mtime(history_path)
history: List[Union[Tuple[int, int, str], Tuple[int, int, str, str, str]]] = []
if history_mtime:
try:
with open(history_path, 'r') as f:
history = f.read()
history = list(parse_history(history))
history_content = f.read()
history = list(parse_history(history_content))
if history[-1][0] == timeline - 1:
history_mtime = datetime.fromtimestamp(history_mtime).replace(tzinfo=tz.tzlocal())
history[-1].append(history_mtime.isoformat())
history[-1].append(self.name)
return history
history[-1] = history[-1][:3] + (history_mtime.isoformat(), self.name)
except Exception:
logger.exception('Failed to read and parse %s', (history_path,))
return history
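The change above is not just cosmetic: `parse_history()` yields tuples, and tuples are immutable, so the old `history[-1].append(...)` calls were type-incorrect. Rebuilding the last entry by slicing and concatenation is the fix; a minimal sketch:

from typing import List, Tuple, Union

HistEntry = Union[Tuple[int, int, str], Tuple[int, int, str, str, str]]

def annotate_last(history: List[HistEntry], ts: str, node: str) -> None:
    # Tuples have no append(); slice the three original fields and
    # rebuild the entry with the timestamp and node name attached.
    history[-1] = history[-1][:3] + (ts, node)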
def follow(self,
member: Member,
role: Optional[str] = 'replica',
timeout: Optional[float] = None,
do_reload: Optional[bool] = False) -> Optional[bool]:
def follow(self, member: Union[Leader, Member, None], role: str = 'replica',
timeout: Optional[float] = None, do_reload: bool = False) -> Optional[bool]:
"""Reconfigure postgres to follow a new member or use different recovery parameters.
Method may call `on_role_change` callback if role is changing.
@@ -1016,14 +1031,14 @@ class Postgresql(object):
self.call_nowait(CallbackAction.ON_ROLE_CHANGE)
return ret
def _wait_promote(self, wait_seconds):
def _wait_promote(self, wait_seconds: int) -> Optional[bool]:
for _ in polling_loop(wait_seconds):
data = self.controldata()
if data.get('Database cluster state') == 'in production':
self.set_role('master')
return True
def _pre_promote(self):
def _pre_promote(self) -> bool:
"""
Runs a fencing script after the leader lock is acquired but before the replica is promoted.
If the script exits with a non-zero code, promotion does not happen and the leader key is removed from DCS.
@@ -1053,7 +1068,8 @@ class Postgresql(object):
except Exception as e:
logger.error('Exception when calling `%s`: %r', cmd, e)
def promote(self, wait_seconds, task, before_promote=None, on_success=None):
def promote(self, wait_seconds: int, task: CriticalTask, before_promote: Optional[Callable[..., Any]] = None,
on_success: Optional[Callable[..., Any]] = None) -> Optional[bool]:
if self.role in ('promoted', 'master', 'primary'):
return True
@@ -1086,43 +1102,48 @@ class Postgresql(object):
return ret
@staticmethod
def _wal_position(is_leader, wal_position, received_location, replayed_location):
def _wal_position(is_leader: bool, wal_position: int,
received_location: Optional[int], replayed_location: Optional[int]) -> int:
return wal_position if is_leader else max(received_location or 0, replayed_location or 0)
def timeline_wal_position(self):
def timeline_wal_position(self) -> Tuple[int, int, Optional[int]]:
# This method could be called from different threads (simultaneously with some other `_query` calls).
# If it is called not from main thread we will create a new cursor to execute statement.
if current_thread().ident == self.__thread_ident:
timeline = self._cluster_info_state_get('timeline')
wal_position = self._cluster_info_state_get('wal_position')
timeline = self._cluster_info_state_get('timeline') or 0
wal_position = self._cluster_info_state_get('wal_position') or 0
replayed_location = self.replayed_location()
received_location = self.received_location()
pg_control_timeline = self._cluster_info_state_get('pg_control_timeline')
else:
with self.connection().cursor() as cursor:
cursor.execute(self.cluster_info_query)
(timeline, wal_position, replayed_location,
received_location, _, pg_control_timeline) = cursor.fetchone()[:6]
cursor.execute(self.cluster_info_query.encode('utf-8'))
row = cursor.fetchone()
if TYPE_CHECKING: # pragma: no cover
assert row is not None
(timeline, wal_position, replayed_location, received_location, _, pg_control_timeline) = row[:6]
wal_position = self._wal_position(timeline, wal_position, received_location, replayed_location)
wal_position = self._wal_position(bool(timeline), wal_position, received_location, replayed_location)
return (timeline, wal_position, pg_control_timeline)
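The `if TYPE_CHECKING: assert row is not None` idiom is worth calling out: `TYPE_CHECKING` is `False` at runtime, so the assert never executes; it exists purely to narrow the optional row type for pyright in places where the query is known to always return a row. A minimal sketch:

from typing import TYPE_CHECKING, Any, Optional, Tuple

def first_row(cur: Any) -> Tuple[Any, ...]:
    cur.execute('SELECT 1, 2, 3')
    row: Optional[Tuple[Any, ...]] = cur.fetchone()
    if TYPE_CHECKING:  # pragma: no cover
        # Never executed at runtime; only narrows the type for pyright,
        # because this query always returns exactly one row.
        assert row is not None
    return row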
def postmaster_start_time(self):
def postmaster_start_time(self) -> Optional[str]:
try:
query = "SELECT " + self.POSTMASTER_START_TIME
if current_thread().ident == self.__thread_ident:
return self.query(query).fetchone()[0].isoformat(sep=' ')
with self.connection().cursor() as cursor:
cursor.execute(query)
return cursor.fetchone()[0].isoformat(sep=' ')
row = self.query(query).fetchone()
else:
with self.connection().cursor() as cursor:
cursor.execute(query)
row = cursor.fetchone()
return row[0].isoformat(sep=' ') if row else None
except psycopg.Error:
return None
def last_operation(self):
return self._wal_position(self.is_leader(), self._cluster_info_state_get('wal_position'),
def last_operation(self) -> int:
return self._wal_position(self.is_leader(), self._cluster_info_state_get('wal_position') or 0,
self.received_location(), self.replayed_location())
def configure_server_parameters(self):
def configure_server_parameters(self) -> None:
self._major_version = self.get_major_version()
self.config.setup_server_parameters()
@@ -1135,9 +1156,9 @@ class Postgresql(object):
self.configure_server_parameters()
return self._major_version > 0
def pg_wal_realpath(self):
def pg_wal_realpath(self) -> Dict[str, str]:
"""Returns a dict containing the symlink (key) and target (value) for the wal directory"""
links = {}
links: Dict[str, str] = {}
for pg_wal_dir in ('pg_xlog', 'pg_wal'):
pg_wal_path = os.path.join(self._data_dir, pg_wal_dir)
if os.path.exists(pg_wal_path) and os.path.islink(pg_wal_path):
@@ -1145,9 +1166,9 @@ class Postgresql(object):
links[pg_wal_path] = pg_wal_realpath
return links
def pg_tblspc_realpaths(self):
def pg_tblspc_realpaths(self) -> Dict[str, str]:
"""Returns a dict containing the symlink (key) and target (values) for the tablespaces"""
links = {}
links: Dict[str, str] = {}
pg_tblsp_dir = os.path.join(self._data_dir, 'pg_tblspc')
if os.path.exists(pg_tblsp_dir):
for tsdn in os.listdir(pg_tblsp_dir):
@@ -1157,7 +1178,7 @@ class Postgresql(object):
links[pg_tsp_path] = pg_tsp_rpath
return links
def move_data_directory(self):
def move_data_directory(self) -> None:
if os.path.isdir(self._data_dir) and not self.is_running():
try:
postfix = 'failed'
@@ -1191,7 +1212,7 @@ class Postgresql(object):
except OSError:
logger.exception("Could not rename data directory %s", self._data_dir)
def remove_data_directory(self):
def remove_data_directory(self) -> None:
self.set_role('uninitialized')
logger.info('Removing data directory: %s', self._data_dir)
try:
@@ -1219,7 +1240,7 @@ class Postgresql(object):
logger.exception('Could not remove data directory %s', self._data_dir)
self.move_data_directory()
def schedule_sanity_checks_after_pause(self):
def schedule_sanity_checks_after_pause(self) -> None:
"""
After coming out of pause we have to:
1. configure server parameters if necessary
@@ -1229,4 +1250,4 @@ class Postgresql(object):
self.ensure_major_version_is_known()
self.slots_handler.schedule()
self.citus_handler.schedule_cache_rebuild()
self._sysid = None
self._sysid = ''

View File

@@ -3,34 +3,39 @@ import os
import shlex
import tempfile
import time
from typing import List, Dict, Union, Callable, Tuple
from ..dcs import RemoteMember
from typing import Any, Callable, Dict, List, Optional, Union, Tuple, TYPE_CHECKING
from ..async_executor import CriticalTask
from ..dcs import Leader, Member, RemoteMember
from ..psycopg import quote_ident, quote_literal
from ..utils import deep_compare, unquote
if TYPE_CHECKING: # pragma: no cover
from . import Postgresql
logger = logging.getLogger(__name__)
class Bootstrap(object):
def __init__(self, postgresql):
def __init__(self, postgresql: 'Postgresql') -> None:
self._postgresql = postgresql
self._running_custom_bootstrap = False
@property
def running_custom_bootstrap(self):
def running_custom_bootstrap(self) -> bool:
return self._running_custom_bootstrap
@property
def keep_existing_recovery_conf(self):
def keep_existing_recovery_conf(self) -> bool:
return self._running_custom_bootstrap and self._keep_existing_recovery_conf
@staticmethod
def process_user_options(tool: str,
options: Union[Dict[str, str], List[Union[str, Dict[str, str]]]],
options: Union[Any, Dict[str, str], List[Union[str, Dict[str, Any]]]],
not_allowed_options: Tuple[str, ...],
error_handler: Callable[[str], None]) -> List:
error_handler: Callable[[str], None]) -> List[str]:
"""Format *options* in a list or dictionary format into command line long form arguments.
.. note::
@@ -77,9 +82,9 @@ class Bootstrap(object):
:param error_handler: A function which will be called when an error condition is encountered
:returns: List of long form arguments to pass to the named tool
"""
user_options = []
user_options: List[str] = []
def option_is_allowed(name):
def option_is_allowed(name: str) -> bool:
ret = name not in not_allowed_options
if not ret:
error_handler('{0} option for {1} is not allowed'.format(name, tool))
@@ -106,11 +111,11 @@ class Bootstrap(object):
error_handler('{0} options must be list or dict'.format(tool))
return user_options
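A usage sketch of `process_user_options()`, assuming the long-form rendering described in the docstring (dict entries become `--name=value`, bare list entries become `--name`); the `raise_error` helper is hypothetical:

def raise_error(message: str) -> None:
    raise Exception(message)

# Dictionary form:
Bootstrap.process_user_options(
    'initdb', {'encoding': 'UTF8'}, ('pgdata',), raise_error)
# expected -> ['--encoding=UTF8']

# List form, mixing bare flags and single-pair dicts:
Bootstrap.process_user_options(
    'initdb', ['data-checksums', {'encoding': 'UTF8'}], ('pgdata',), raise_error)
# expected -> ['--data-checksums', '--encoding=UTF8']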
def _initdb(self, config):
def _initdb(self, config: Any) -> bool:
self._postgresql.set_state('initializing new cluster')
not_allowed_options = ('pgdata', 'nosync', 'pwfile', 'sync-only', 'version')
def error_handler(e):
def error_handler(e: str) -> None:
raise Exception(e)
options = self.process_user_options('initdb', config or [], not_allowed_options, error_handler)
@@ -134,7 +139,7 @@ class Bootstrap(object):
self._postgresql.set_state('initdb failed')
return ret
def _post_restore(self):
def _post_restore(self) -> None:
self._postgresql.config.restore_configuration_files()
self._postgresql.configure_server_parameters()
@@ -145,7 +150,7 @@ class Bootstrap(object):
if os.path.exists(trigger_file):
os.unlink(trigger_file)
def _custom_bootstrap(self, config):
def _custom_bootstrap(self, config: Any) -> bool:
self._postgresql.set_state('running custom bootstrap script')
params = [] if config.get('no_params') else ['--scope=' + self._postgresql.scope,
'--datadir=' + self._postgresql.data_dir]
@@ -165,7 +170,7 @@ class Bootstrap(object):
self._postgresql.config.remove_recovery_conf()
return True
def call_post_bootstrap(self, config):
def call_post_bootstrap(self, config: Dict[str, Any]) -> bool:
"""
runs a script after initdb or the custom bootstrap script has been called, and waits until it completes.
"""
@@ -192,7 +197,7 @@ class Bootstrap(object):
return False
return True
def create_replica(self, clone_member):
def create_replica(self, clone_member: Union[Leader, Member, None]) -> Optional[int]:
"""
Create the replica according to the replica_method
defined by the user. This is a list, so we need to
@@ -279,7 +284,7 @@ class Bootstrap(object):
self._postgresql.set_state('stopped')
return ret
def basebackup(self, conn_url, env, options):
def basebackup(self, conn_url: str, env: Dict[str, str], options: Dict[str, Any]) -> Optional[int]:
# creates a replica data dir using pg_basebackup.
# this is the default, built-in create_replica_methods
# tries twice, then returns failure (as 1)
@@ -314,7 +319,7 @@ class Bootstrap(object):
return ret
def clone(self, clone_member):
def clone(self, clone_member: Union[Leader, Member, None]) -> bool:
"""
- initialize the replica from an existing member (primary or replica)
- initialize the replica using the replica creation method that
@@ -327,7 +332,7 @@ class Bootstrap(object):
self._post_restore()
return ret
def bootstrap(self, config):
def bootstrap(self, config: Dict[str, Any]) -> bool:
""" Initialize a new node from scratch and start it. """
pg_hba = config.get('pg_hba', [])
method = config.get('method') or 'initdb'
@@ -339,9 +344,9 @@ class Bootstrap(object):
method = 'initdb'
do_initialize = self._initdb
return do_initialize(config.get(method)) and self._postgresql.config.append_pg_hba(pg_hba) \
and self._postgresql.config.save_configuration_files() and self._postgresql.start()
and self._postgresql.config.save_configuration_files() and bool(self._postgresql.start())
def create_or_update_role(self, name, password, options):
def create_or_update_role(self, name: str, password: Optional[str], options: List[str]) -> None:
options = list(map(str.upper, options))
if 'NOLOGIN' not in options and 'LOGIN' not in options:
options.append('LOGIN')
@@ -371,7 +376,7 @@ END;$$""".format(quote_literal(name), quote_ident(name, self._postgresql.connect
self._postgresql.query('RESET log_statement')
self._postgresql.query('RESET pg_stat_statements.track_utility')
def post_bootstrap(self, config, task):
def post_bootstrap(self, config: Dict[str, Any], task: CriticalTask) -> Optional[bool]:
try:
postgresql = self._postgresql
superuser = postgresql.config.superuser

View File

@@ -17,7 +17,7 @@ class CallbackAction(str, Enum):
ON_RELOAD = "on_reload"
ON_ROLE_CHANGE = "on_role_change"
def __repr__(self):
def __repr__(self) -> str:
return self.value
@@ -60,15 +60,17 @@ class CallbackExecutor(CancellableExecutor, Thread):
self._cmd = cmd
self._condition.notify()
def run(self):
def run(self) -> None:
while True:
with self._condition:
if self._cmd is None:
self._condition.wait()
cmd, self._cmd = self._cmd, None
with self._lock:
if not self._start_process(cmd, close_fds=True):
continue
self._process.wait()
self._kill_children()
if cmd is not None:
with self._lock:
if not self._start_process(cmd, close_fds=True):
continue
if self._process:
self._process.wait()
self._kill_children()

View File

@@ -5,6 +5,7 @@ import subprocess
from patroni.exceptions import PostgresException
from patroni.utils import polling_loop
from threading import Lock
from typing import Any, Dict, List, Optional, Union
logger = logging.getLogger(__name__)
@@ -15,13 +16,13 @@ class CancellableExecutor(object):
There must be only one such process so that AsyncExecutor can easily cancel it.
"""
def __init__(self):
def __init__(self) -> None:
self._process = None
self._process_cmd = None
self._process_children = []
self._process_children: List[psutil.Process] = []
self._lock = Lock()
def _start_process(self, cmd, *args, **kwargs):
def _start_process(self, cmd: List[str], *args: Any, **kwargs: Any) -> Optional[bool]:
"""This method must be executed only when the `_lock` is acquired"""
try:
@@ -32,7 +33,7 @@ class CancellableExecutor(object):
return logger.exception('Failed to execute %s', cmd)
return True
def _kill_process(self):
def _kill_process(self) -> None:
with self._lock:
if self._process is not None and self._process.is_running() and not self._process_children:
try:
@@ -53,8 +54,8 @@ class CancellableExecutor(object):
except psutil.AccessDenied as e:
logger.warning('Failed to kill the process: %s', e.msg)
def _kill_children(self):
waitlist = []
def _kill_children(self) -> None:
waitlist: List[psutil.Process] = []
with self._lock:
for child in self._process_children:
try:
@@ -69,15 +70,16 @@ class CancellableExecutor(object):
class CancellableSubprocess(CancellableExecutor):
def __init__(self):
def __init__(self) -> None:
super(CancellableSubprocess, self).__init__()
self._is_cancelled = False
def call(self, *args, **kwargs):
def call(self, *args: Any, **kwargs: Union[Any, Dict[str, str]]) -> Optional[int]:
for s in ('stdin', 'stdout', 'stderr'):
kwargs.pop(s, None)
communicate = kwargs.pop('communicate', None)
communicate: Optional[Dict[str, str]] = kwargs.pop('communicate', None)
input_data = None
if isinstance(communicate, dict):
input_data = communicate.get('input')
if input_data:
@@ -96,7 +98,7 @@ class CancellableSubprocess(CancellableExecutor):
self._is_cancelled = False
started = self._start_process(*args, **kwargs)
if started:
if started and self._process is not None:
if isinstance(communicate, dict):
communicate['stdout'], communicate['stderr'] = self._process.communicate(input_data)
return self._process.wait()
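`call()` uses the `communicate` dict as an in/out parameter: an optional `input` key feeds stdin, and `stdout`/`stderr` are written back into the same dict once the process exits. A hedged usage sketch:

proc = CancellableSubprocess()
comm = {'input': 'payload sent to stdin'}
# stdin/stdout/stderr kwargs are stripped by call(); the pipes are wired
# up internally whenever a `communicate` dict is passed.
rc = proc.call(['cat'], communicate=comm)
# After the call the same dict carries the captured output:
print(rc, comm['stdout'], comm['stderr'])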
@@ -105,16 +107,16 @@ class CancellableSubprocess(CancellableExecutor):
self._process = None
self._kill_children()
def reset_is_cancelled(self):
def reset_is_cancelled(self) -> None:
with self._lock:
self._is_cancelled = False
@property
def is_cancelled(self):
def is_cancelled(self) -> bool:
with self._lock:
return self._is_cancelled
def cancel(self, kill=False):
def cancel(self, kill: bool = False) -> None:
with self._lock:
self._is_cancelled = True
if self._process is None or not self._process.is_running():

View File

@@ -4,11 +4,17 @@ import time
from threading import Condition, Event, Thread
from urllib.parse import urlparse
from typing import Any, Dict, List, Optional, Union, Tuple, TYPE_CHECKING
from .connection import Connection
from ..dcs import CITUS_COORDINATOR_GROUP_ID
from ..dcs import CITUS_COORDINATOR_GROUP_ID, Cluster
from ..psycopg import connect, quote_ident
if TYPE_CHECKING: # pragma: no cover
from psycopg import Cursor
from psycopg2 import cursor
from . import Postgresql
CITUS_SLOT_NAME_RE = re.compile(r'^citus_shard_(move|split)_slot(_[1-9][0-9]*){2,3}$')
logger = logging.getLogger(__name__)
@@ -16,7 +22,8 @@ logger = logging.getLogger(__name__)
class PgDistNode(object):
"""Represents a single row in the `pg_dist_node` table"""
def __init__(self, group, host, port, event, nodeid=None, timeout=None, cooldown=None):
def __init__(self, group: int, host: str, port: int, event: str, nodeid: Optional[int] = None,
timeout: Optional[float] = None, cooldown: Optional[float] = None) -> None:
self.group = group
# A weird way of pausing client connections by adding the `-demoted` suffix to the hostname
self.host = host + ('-demoted' if event == 'before_demote' else '')
@@ -29,7 +36,7 @@ class PgDistNode(object):
# If transaction was started, we need to COMMIT/ROLLBACK before the deadline
self.timeout = timeout
self.cooldown = cooldown or 10000 # 10s by default
self.deadline = 0
self.deadline: float = 0
# All changes in the pg_dist_node are serialized on the Patroni
# side by performing them from a thread. The thread, that is
@@ -38,74 +45,74 @@ class PgDistNode(object):
# the worker, and once it is done notify the calling thread.
self._event = Event()
def wait(self):
def wait(self) -> None:
self._event.wait()
def wakeup(self):
def wakeup(self) -> None:
self._event.set()
def __eq__(self, other):
def __eq__(self, other: Any) -> bool:
return isinstance(other, PgDistNode) and self.event == other.event\
and self.host == other.host and self.port == other.port
def __ne__(self, other):
def __ne__(self, other: Any) -> bool:
return not self == other
def __str__(self):
def __str__(self) -> str:
return ('PgDistNode(nodeid={0},group={1},host={2},port={3},event={4})'
.format(self.nodeid, self.group, self.host, self.port, self.event))
def __repr__(self):
def __repr__(self) -> str:
return str(self)
class CitusHandler(Thread):
def __init__(self, postgresql, config):
def __init__(self, postgresql: 'Postgresql', config: Optional[Dict[str, Union[str, int]]]) -> None:
super(CitusHandler, self).__init__()
self.daemon = True
self._postgresql = postgresql
self._config = config
self._connection = Connection()
self._pg_dist_node = {} # Cache of pg_dist_node: {groupid: PgDistNode()}
self._tasks = [] # Requests to change pg_dist_node, every task is a `PgDistNode`
self._pg_dist_node: Dict[int, PgDistNode] = {} # Cache of pg_dist_node: {groupid: PgDistNode()}
self._tasks: List[PgDistNode] = [] # Requests to change pg_dist_node, every task is a `PgDistNode`
self._condition = Condition() # protects _pg_dist_node, _tasks, and _schedule_load_pg_dist_node
self._in_flight = None # Reference to the `PgDistNode` if there is a transaction in progress changing it
self._in_flight: Optional[PgDistNode] = None # Reference to the `PgDistNode` being changed in a transaction
self.schedule_cache_rebuild()
def is_enabled(self):
def is_enabled(self) -> bool:
return isinstance(self._config, dict)
def group(self):
return self._config['group']
def group(self) -> Optional[int]:
return int(self._config['group']) if isinstance(self._config, dict) else None
def is_coordinator(self):
def is_coordinator(self) -> bool:
return self.is_enabled() and self.group() == CITUS_COORDINATOR_GROUP_ID
def is_worker(self):
def is_worker(self) -> bool:
return self.is_enabled() and not self.is_coordinator()
def set_conn_kwargs(self, kwargs):
if self.is_enabled():
def set_conn_kwargs(self, kwargs: Dict[str, Any]) -> None:
if isinstance(self._config, dict): # self.is_enabled():
kwargs.update({'dbname': self._config['database'],
'options': '-c statement_timeout=0 -c idle_in_transaction_session_timeout=0'})
self._connection.set_conn_kwargs(kwargs)
def schedule_cache_rebuild(self):
def schedule_cache_rebuild(self) -> None:
with self._condition:
self._schedule_load_pg_dist_node = True
def on_demote(self):
def on_demote(self) -> None:
with self._condition:
self._pg_dist_node.clear()
self._tasks[:] = []
self._in_flight = None
def query(self, sql, *params):
def query(self, sql: str, *params: Any) -> Union['Cursor[Any]', 'cursor']:
try:
logger.debug('query(%s, %s)', sql, params)
cursor = self._connection.cursor()
cursor.execute(sql, params or None)
cursor.execute(sql.encode('utf-8'), params or None)
return cursor
except Exception as e:
logger.error('Exception when executing query "%s", (%s): %r', sql, params, e)
@@ -114,7 +121,7 @@ class CitusHandler(Thread):
self.schedule_cache_rebuild()
raise e
def load_pg_dist_node(self):
def load_pg_dist_node(self) -> bool:
"""Read from the `pg_dist_node` table and put it into the local cache"""
with self._condition:
@@ -132,7 +139,7 @@ class CitusHandler(Thread):
self._pg_dist_node = {r[1]: PgDistNode(r[1], r[2], r[3], 'after_promote', r[0]) for r in cursor}
return True
def sync_pg_dist_node(self, cluster):
def sync_pg_dist_node(self, cluster: Cluster) -> None:
"""Maintain the `pg_dist_node` from the coordinator leader every heartbeat loop.
We can't always rely on REST API calls from worker nodes in order
@@ -156,12 +163,12 @@ class CitusHandler(Thread):
and leader.data.get('role') in ('master', 'primary') and leader.data.get('state') == 'running':
self.add_task('after_promote', group, leader.conn_url)
def find_task_by_group(self, group):
def find_task_by_group(self, group: int) -> Optional[int]:
for i, task in enumerate(self._tasks):
if task.group == group:
return i
def pick_task(self):
def pick_task(self) -> Tuple[Optional[int], Optional[PgDistNode]]:
"""Returns the tuple(i, task), where `i` - is the task index in the self._tasks list
Tasks are picked by following priorities:
@@ -195,15 +202,17 @@ class CitusHandler(Thread):
task.nodeid = self._pg_dist_node[task.group].nodeid
return i, task
def update_node(self, task):
def update_node(self, task: PgDistNode) -> None:
if task.nodeid is not None:
self.query('SELECT pg_catalog.citus_update_node(%s, %s, %s, true, %s)',
task.nodeid, task.host, task.port, task.cooldown)
elif task.event != 'before_demote':
task.nodeid = self.query("SELECT pg_catalog.citus_add_node(%s, %s, %s, 'primary', 'default')",
task.host, task.port, task.group).fetchone()[0]
row = self.query("SELECT pg_catalog.citus_add_node(%s, %s, %s, 'primary', 'default')",
task.host, task.port, task.group).fetchone()
if row is not None:
task.nodeid = row[0]
def process_task(self, task):
def process_task(self, task: PgDistNode) -> bool:
"""Updates a single row in `pg_dist_node` table, optionally in a transaction.
The transaction is started if we do a demote of the worker node
@@ -238,13 +247,13 @@ class CitusHandler(Thread):
self._in_flight = task
return False
def process_tasks(self):
def process_tasks(self) -> None:
while True:
if not self._in_flight and not self.load_pg_dist_node():
break
i, task = self.pick_task()
if not task:
if not task or i is None:
break
try:
update_cache = self.process_task(task)
@@ -259,7 +268,7 @@ class CitusHandler(Thread):
self._tasks.pop(i)
task.wakeup()
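The extra `or i is None` is a type-narrowing detail: `pick_task()` returns `Tuple[Optional[int], Optional[PgDistNode]]`, and pyright tracks the two elements independently, so checking only `task` would still leave `i` typed `Optional[int]` at the later `self._tasks.pop(i)`. A minimal sketch of the issue:

from typing import Optional, Tuple

def pick() -> Tuple[Optional[int], Optional[str]]:
    return 0, 'task'   # or (None, None) when there is nothing to do

i, task = pick()
if not task or i is None:
    pass  # both must be checked: narrowing task says nothing about i
else:
    print(i + 1, task)  # here pyright knows i is int and task is str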
def run(self):
def run(self) -> None:
while True:
try:
with self._condition:
@@ -280,7 +289,7 @@ class CitusHandler(Thread):
except Exception:
logger.exception('run')
def _add_task(self, task):
def _add_task(self, task: PgDistNode) -> bool:
with self._condition:
i = self.find_task_by_group(task.group)
@@ -306,32 +315,34 @@ class CitusHandler(Thread):
return True
return False
def add_task(self, event, group, conn_url, timeout=None, cooldown=None):
def add_task(self, event: str, group: int, conn_url: str,
timeout: Optional[float] = None, cooldown: Optional[float] = None) -> Optional[PgDistNode]:
try:
r = urlparse(conn_url)
except Exception as e:
return logger.error('Failed to parse connection url %s: %r', conn_url, e)
host = r.hostname
port = r.port or 5432
task = PgDistNode(group, host, port, event, timeout=timeout, cooldown=cooldown)
return task if self._add_task(task) else None
if host:
port = r.port or 5432
task = PgDistNode(group, host, port, event, timeout=timeout, cooldown=cooldown)
return task if self._add_task(task) else None
def handle_event(self, cluster, event):
def handle_event(self, cluster: Cluster, event: Dict[str, Any]) -> None:
if not self.is_alive():
return
cluster = cluster.workers.get(event['group'])
if not (cluster and cluster.leader and cluster.leader.name == event['leader'] and cluster.leader.conn_url):
worker = cluster.workers.get(event['group'])
if not (worker and worker.leader and worker.leader.name == event['leader'] and worker.leader.conn_url):
return
task = self.add_task(event['type'], event['group'],
cluster.leader.conn_url,
worker.leader.conn_url,
event['timeout'], event['cooldown'] * 1000)
if task and event['type'] == 'before_demote':
task.wait()
def bootstrap(self):
if not self.is_enabled():
def bootstrap(self) -> None:
if not isinstance(self._config, dict): # self.is_enabled()
return
conn_kwargs = self._postgresql.config.local_connect_kwargs
@@ -340,7 +351,8 @@ class CitusHandler(Thread):
conn = connect(**conn_kwargs)
try:
with conn.cursor() as cur:
cur.execute('CREATE DATABASE {0}'.format(quote_ident(self._config['database'], conn)))
cur.execute('CREATE DATABASE {0}'.format(
quote_ident(self._config['database'], conn)).encode('utf-8'))
finally:
conn.close()
@@ -364,7 +376,7 @@ class CitusHandler(Thread):
finally:
conn.close()
def adjust_postgres_gucs(self, parameters):
def adjust_postgres_gucs(self, parameters: Dict[str, Any]) -> None:
if not self.is_enabled():
return
@@ -381,9 +393,9 @@ class CitusHandler(Thread):
# Resharding in Citus is implemented using logical replication
parameters['wal_level'] = 'logical'
def ignore_replication_slot(self, slot):
if self.is_enabled() and self._postgresql.is_leader() and\
def ignore_replication_slot(self, slot: Dict[str, str]) -> bool:
if isinstance(self._config, dict) and self._postgresql.is_leader() and\
slot['type'] == 'logical' and slot['database'] == self._config['database']:
m = CITUS_SLOT_NAME_RE.match(slot['name'])
return m and {'move': 'pgoutput', 'split': 'citus'}.get(m.group(1)) == slot['plugin']
return bool(m and {'move': 'pgoutput', 'split': 'citus'}.get(m.group(1)) == slot['plugin'])
return False
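The `bool()` wrapper is needed because `m and some_dict.get(...) == plugin` evaluates to `Optional[Match[str]] | bool`: `and` returns its left operand when that operand is falsy, so the expression can produce `None` rather than `False`. A minimal sketch:

import re
from typing import Optional

def is_citus_slot(name: str, plugin: str) -> bool:
    m: Optional['re.Match[str]'] = re.match(r'^citus_shard_(move|split)_', name)
    # Without bool() the return type would be Optional[Match[str]] | bool,
    # because `m and ...` yields None whenever the match fails.
    return bool(m and {'move': 'pgoutput', 'split': 'citus'}.get(m.group(1)) == plugin)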

View File

@@ -7,21 +7,26 @@ import stat
import time
from urllib.parse import urlparse, parse_qsl, unquote
from types import TracebackType
from typing import Any, Collection, Dict, List, Optional, Union, Tuple, Type, TYPE_CHECKING
from .validator import recovery_parameters, transform_postgresql_parameter_value, transform_recovery_parameter_value
from ..collections import CaseInsensitiveDict
from ..dcs import RemoteMember, slot_name_from_member_name
from ..collections import CaseInsensitiveDict, CaseInsensitiveSet
from ..dcs import Leader, Member, RemoteMember, slot_name_from_member_name
from ..exceptions import PatroniFatalException
from ..utils import compare_values, parse_bool, parse_int, split_host_port, uri, \
validate_directory, is_subpath
from ..utils import compare_values, parse_bool, parse_int, split_host_port, uri, validate_directory, is_subpath
from ..validator import IntValidator
if TYPE_CHECKING: # pragma: no cover
from . import Postgresql
logger = logging.getLogger(__name__)
PARAMETER_RE = re.compile(r'([a-z_]+)\s*=\s*')
def conninfo_uri_parse(dsn):
ret = {}
def conninfo_uri_parse(dsn: str) -> Dict[str, str]:
ret: Dict[str, str] = {}
r = urlparse(dsn)
if r.username:
ret['user'] = r.username
@@ -29,21 +34,18 @@ def conninfo_uri_parse(dsn):
ret['password'] = r.password
if r.path[1:]:
ret['dbname'] = r.path[1:]
hosts = []
ports = []
hosts: List[str] = []
ports: List[str] = []
for netloc in r.netloc.split('@')[-1].split(','):
host = port = None
host = None
if '[' in netloc and ']' in netloc:
host = netloc.split(']')[0][1:]
tmp = netloc.split(':', 1)
if host is None:
host = tmp[0]
hosts.append(host)
if len(tmp) == 2:
host, port = tmp
if host is not None:
hosts.append(host)
if port is not None:
ports.append(port)
ports.append(tmp[1])
if hosts:
ret['host'] = ','.join(hosts)
if ports:
@@ -56,7 +58,7 @@ def conninfo_uri_parse(dsn):
return ret
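For multi-host DSNs the hosts and ports are collected into separate comma-joined fields, libpq-style. A hedged usage sketch under that assumption:

conninfo_uri_parse('postgres://alice:secret@db1:5432,db2:5433/app')
# expected -> {'user': 'alice', 'password': 'secret', 'dbname': 'app',
#              'host': 'db1,db2', 'port': '5432,5433'}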
def read_param_value(value):
def read_param_value(value: str) -> Union[Tuple[None, None], Tuple[str, int]]:
length = len(value)
ret = ''
is_quoted = value[0] == "'"
@@ -76,8 +78,8 @@ def read_param_value(value):
return (None, None) if is_quoted else (ret, i)
def conninfo_parse(dsn):
ret = {}
def conninfo_parse(dsn: str) -> Optional[Dict[str, str]]:
ret: Dict[str, str] = {}
length = len(dsn)
i = 0
while i < length:
@@ -96,14 +98,14 @@ def conninfo_parse(dsn):
return
value, end = read_param_value(dsn[i:])
if value is None:
if value is None or end is None:
return
i += end
ret[param] = value
return ret
def parse_dsn(value):
def parse_dsn(value: str) -> Optional[Dict[str, str]]:
"""
A very simple equivalent of `psycopg2.extensions.parse_dsn` (introduced in 2.7.0).
We are not using the psycopg2 function in order to remain compatible with 2.5.4+.
@@ -147,14 +149,14 @@ def parse_dsn(value):
return ret
def strip_comment(value):
def strip_comment(value: str) -> str:
i = value.find('#')
if i > -1:
value = value[:i].strip()
return value
def read_recovery_param_value(value):
def read_recovery_param_value(value: str) -> Optional[str]:
"""
>>> read_recovery_param_value('') is None
True
@@ -211,7 +213,7 @@ def read_recovery_param_value(value):
return value
def mtime(filename):
def mtime(filename: str) -> Optional[float]:
try:
return os.stat(filename).st_mtime
except OSError:
@@ -220,35 +222,49 @@ def mtime(filename):
class ConfigWriter(object):
def __init__(self, filename):
def __init__(self, filename: str) -> None:
self._filename = filename
self._fd = None
def __enter__(self):
def __enter__(self) -> 'ConfigWriter':
self._fd = open(self._filename, 'w')
self.writeline('# Do not edit this file manually!\n# It will be overwritten by Patroni!')
return self
def __exit__(self, exc_type, exc_val, exc_tb):
def __exit__(self, exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]) -> None:
if self._fd:
self._fd.close()
def writeline(self, line):
self._fd.write(line)
self._fd.write('\n')
def writeline(self, line: str) -> None:
if self._fd:
self._fd.write(line)
self._fd.write('\n')
def writelines(self, lines):
def writelines(self, lines: List[str]) -> None:
for line in lines:
self.writeline(line)
@staticmethod
def escape(value): # Escape (by doubling) any single quotes or backslashes in given string
def escape(value: Any) -> str: # Escape (by doubling) any single quotes or backslashes in given string
return re.sub(r'([\'\\])', r'\1\1', str(value))
def write_param(self, param, value):
def write_param(self, param: str, value: Any) -> None:
self.writeline("{0} = '{1}'".format(param, self.escape(value)))
def _false_validator(value: Any) -> bool:
return False
def _wal_level_validator(value: Any) -> bool:
return str(value).lower() in ('hot_standby', 'replica', 'logical')
def _bool_validator(value: Any) -> bool:
return parse_bool(value) is not None
class ConfigHandler(object):
# List of parameters which must be always passed to postmaster as command line options
@@ -266,28 +282,28 @@ class ConfigHandler(object):
# check_function -- if the new value is not correct must return `!False`
# min_version -- major version of PostgreSQL when parameter was introduced
CMDLINE_OPTIONS = CaseInsensitiveDict({
'listen_addresses': (None, lambda _: False, 90100),
'port': (None, lambda _: False, 90100),
'cluster_name': (None, lambda _: False, 90500),
'wal_level': ('hot_standby', lambda v: v.lower() in ('hot_standby', 'replica', 'logical'), 90100),
'hot_standby': ('on', lambda _: False, 90100),
'max_connections': (100, lambda v: int(v) >= 25, 90100),
'max_wal_senders': (10, lambda v: int(v) >= 3, 90100),
'wal_keep_segments': (8, lambda v: int(v) >= 1, 90100),
'wal_keep_size': ('128MB', lambda v: parse_int(v, 'MB') >= 16, 130000),
'max_prepared_transactions': (0, lambda v: int(v) >= 0, 90100),
'max_locks_per_transaction': (64, lambda v: int(v) >= 32, 90100),
'track_commit_timestamp': ('off', lambda v: parse_bool(v) is not None, 90500),
'max_replication_slots': (10, lambda v: int(v) >= 4, 90400),
'max_worker_processes': (8, lambda v: int(v) >= 2, 90400),
'wal_log_hints': ('on', lambda _: False, 90400)
'listen_addresses': (None, _false_validator, 90100),
'port': (None, _false_validator, 90100),
'cluster_name': (None, _false_validator, 90500),
'wal_level': ('hot_standby', _wal_level_validator, 90100),
'hot_standby': ('on', _false_validator, 90100),
'max_connections': (100, IntValidator(min=25), 90100),
'max_wal_senders': (10, IntValidator(min=3), 90100),
'wal_keep_segments': (8, IntValidator(min=1), 90100),
'wal_keep_size': ('128MB', IntValidator(min=16, base_unit='MB'), 130000),
'max_prepared_transactions': (0, IntValidator(min=0), 90100),
'max_locks_per_transaction': (64, IntValidator(min=32), 90100),
'track_commit_timestamp': ('off', _bool_validator, 90500),
'max_replication_slots': (10, IntValidator(min=4), 90400),
'max_worker_processes': (8, IntValidator(min=2), 90400),
'wal_log_hints': ('on', _false_validator, 90400)
})
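Replacing the lambdas with named validators is not only stylistic: a lambda cannot carry parameter annotations, so under strict mode every `lambda v: int(v) >= 25` surfaced as partially unknown. A minimal sketch of the callable-object pattern, with a hypothetical `MinIntValidator` standing in for the real `IntValidator` from `patroni.validator`:

from typing import Any

class MinIntValidator:
    """Hypothetical stand-in for patroni.validator.IntValidator: a fully
    typed callable object replaces an unannotatable lambda."""
    def __init__(self, min: int) -> None:
        self.min = min

    def __call__(self, value: Any) -> bool:
        try:
            return int(value) >= self.min
        except (TypeError, ValueError):
            return False

check = MinIntValidator(min=25)
assert check('100') and not check(24) and not check('abc')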
_RECOVERY_PARAMETERS = set(recovery_parameters.keys())
_RECOVERY_PARAMETERS = CaseInsensitiveSet(recovery_parameters.keys())
def __init__(self, postgresql, config):
def __init__(self, postgresql: 'Postgresql', config: Dict[str, Any]) -> None:
self._postgresql = postgresql
self._config_dir = os.path.abspath(config.get('config_dir') or postgresql.data_dir)
self._config_dir = os.path.abspath(config.get('config_dir', '') or postgresql.data_dir)
config_base_name = config.get('config_base_name', 'postgresql')
self._postgresql_conf = os.path.join(self._config_dir, config_base_name + '.conf')
self._postgresql_conf_mtime = None
@@ -309,21 +325,22 @@ class ConfigHandler(object):
self._passfile_mtime = None
self._synchronous_standby_names = None
self._postmaster_ctime = None
self._current_recovery_params = None
self._current_recovery_params: Optional[CaseInsensitiveDict] = None
self._config = {}
self._recovery_params = {}
self._recovery_params = CaseInsensitiveDict()
self._server_parameters: CaseInsensitiveDict
self.reload_config(config)
def setup_server_parameters(self):
def setup_server_parameters(self) -> None:
self._server_parameters = self.get_server_parameters(self._config)
self._adjust_recovery_parameters()
def try_to_create_dir(self, d, msg):
d = os.path.join(self._postgresql._data_dir, d)
if (not is_subpath(self._postgresql._data_dir, d) or not self._postgresql.data_directory_empty()):
def try_to_create_dir(self, d: str, msg: str) -> None:
d = os.path.join(self._postgresql.data_dir, d)
if (not is_subpath(self._postgresql.data_dir, d) or not self._postgresql.data_directory_empty()):
validate_directory(d, msg)
def check_directories(self):
def check_directories(self) -> None:
if "unix_socket_directories" in self._server_parameters:
for d in self._server_parameters["unix_socket_directories"].split(","):
self.try_to_create_dir(d.strip(), "'{}' is defined in unix_socket_directories, {}")
@@ -335,7 +352,11 @@ class ConfigHandler(object):
"'{}' is defined in `postgresql.pgpass`, {}")
@property
def _configuration_to_save(self):
def config_dir(self) -> str:
return self._config_dir
@property
def _configuration_to_save(self) -> List[str]:
configuration = [os.path.basename(self._postgresql_conf)]
if 'custom_conf' not in self._config:
configuration.append(os.path.basename(self._postgresql_base_conf_name))
@@ -345,7 +366,7 @@ class ConfigHandler(object):
configuration.append('pg_ident.conf')
return configuration
def save_configuration_files(self, check_custom_bootstrap=False):
def save_configuration_files(self, check_custom_bootstrap: bool = False) -> bool:
"""
copy postgresql.conf to postgresql.conf.backup to be able to retrieve configuration files
- originally stored as symlinks, those are normally skipped by pg_basebackup
@@ -362,7 +383,7 @@ class ConfigHandler(object):
logger.exception('unable to create backup copies of configuration files')
return True
def restore_configuration_files(self):
def restore_configuration_files(self) -> None:
""" restore a previously saved postgresql.conf """
try:
for f in self._configuration_to_save:
@@ -377,7 +398,7 @@ class ConfigHandler(object):
except IOError:
logger.exception('unable to restore configuration files from backup')
def write_postgresql_conf(self, configuration=None):
def write_postgresql_conf(self, configuration: Optional[CaseInsensitiveDict] = None) -> None:
# rename the original configuration if it is necessary
if 'custom_conf' not in self._config and not os.path.exists(self._postgresql_base_conf):
os.rename(self._postgresql_conf, self._postgresql_base_conf)
@@ -412,13 +433,13 @@ class ConfigHandler(object):
if not self._postgresql.bootstrap.keep_existing_recovery_conf:
self._sanitize_auto_conf()
def append_pg_hba(self, config):
def append_pg_hba(self, config: List[str]) -> bool:
if not self.hba_file and not self._config.get('pg_hba'):
with open(self._pg_hba_conf, 'a') as f:
f.write('\n{}\n'.format('\n'.join(config)))
return True
def replace_pg_hba(self):
def replace_pg_hba(self) -> Optional[bool]:
"""
Replace pg_hba.conf content in the PGDATA if hba_file is not defined in the
`postgresql.parameters` and pg_hba is defined in `postgresql` configuration section.
@@ -446,7 +467,7 @@ class ConfigHandler(object):
f.writelines(self._config['pg_hba'])
return True
def replace_pg_ident(self):
def replace_pg_ident(self) -> Optional[bool]:
"""
Replace pg_ident.conf content in the PGDATA if ident_file is not defined in the
`postgresql.parameters` and pg_ident is defined in the `postgresql` section.
@@ -459,8 +480,8 @@ class ConfigHandler(object):
f.writelines(self._config['pg_ident'])
return True
def primary_conninfo_params(self, member):
if not (member and member.conn_url) or member.name == self._postgresql.name:
def primary_conninfo_params(self, member: Union[Leader, Member, None]) -> Optional[Dict[str, Any]]:
if not member or not member.conn_url or member.name == self._postgresql.name:
return None
ret = member.conn_kwargs(self.replication)
ret['application_name'] = self._postgresql.name
@@ -475,7 +496,7 @@ class ConfigHandler(object):
del ret['dbname']
return ret
def format_dsn(self, params, include_dbname=False):
def format_dsn(self, params: Dict[str, Any], include_dbname: bool = False) -> str:
# A list of keywords that can be found in a conninfo string. Follows what is acceptable by libpq
keywords = ('dbname', 'user', 'passfile' if params.get('passfile') else 'password', 'host', 'port',
'sslmode', 'sslcompression', 'sslcert', 'sslkey', 'sslpassword', 'sslrootcert', 'sslcrl',
@@ -491,13 +512,13 @@ class ConfigHandler(object):
else:
skip = {'dbname'}
def escape(value):
def escape(value: Any) -> str:
return re.sub(r'([\'\\ ])', r'\\\1', str(value))
return ' '.join('{0}={1}'.format(kw, escape(params[kw])) for kw in keywords
if kw not in skip and params.get(kw) is not None)
def _write_recovery_params(self, fd, recovery_params):
def _write_recovery_params(self, fd: ConfigWriter, recovery_params: CaseInsensitiveDict) -> None:
if self._postgresql.major_version >= 90500:
pause_at_recovery_target = parse_bool(recovery_params.pop('pause_at_recovery_target', None))
if pause_at_recovery_target is not None:
@@ -518,8 +539,8 @@ class ConfigHandler(object):
continue
fd.write_param(name, value)
def build_recovery_params(self, member):
recovery_params = CaseInsensitiveDict({p: v for p, v in self.get('recovery_conf', {}).items()
def build_recovery_params(self, member: Union[Leader, Member, None]) -> CaseInsensitiveDict:
recovery_params = CaseInsensitiveDict({p: v for p, v in (self.get('recovery_conf') or {}).items()
if not p.lower().startswith('recovery_target')
and p.lower() not in ('primary_conninfo', 'primary_slot_name')})
recovery_params.update({'standby_mode': 'on', 'recovery_target_timeline': 'latest'})
@@ -550,26 +571,26 @@ class ConfigHandler(object):
recovery_params.update({p: member.data.get(p) for p in standby_cluster_params if member and member.data.get(p)})
return recovery_params
def recovery_conf_exists(self):
def recovery_conf_exists(self) -> bool:
if self._postgresql.major_version >= 120000:
return os.path.exists(self._standby_signal) or os.path.exists(self._recovery_signal)
return os.path.exists(self._recovery_conf)
@property
def triggerfile_good_name(self):
def triggerfile_good_name(self) -> str:
return 'trigger_file' if self._postgresql.major_version < 120000 else 'promote_trigger_file'
@property
def _triggerfile_wrong_name(self):
def _triggerfile_wrong_name(self) -> str:
return 'trigger_file' if self._postgresql.major_version >= 120000 else 'promote_trigger_file'
@property
def _recovery_parameters_to_compare(self):
skip_params = {'pause_at_recovery_target', 'recovery_target_inclusive',
'recovery_target_action', 'standby_mode', self._triggerfile_wrong_name}
return self._RECOVERY_PARAMETERS - skip_params
def _recovery_parameters_to_compare(self) -> CaseInsensitiveSet:
skip_params = CaseInsensitiveSet({'pause_at_recovery_target', 'recovery_target_inclusive',
'recovery_target_action', 'standby_mode', self._triggerfile_wrong_name})
return CaseInsensitiveSet(self._RECOVERY_PARAMETERS - skip_params)
def _read_recovery_params(self):
def _read_recovery_params(self) -> Tuple[Optional[CaseInsensitiveDict], Optional[bool]]:
if self._postgresql.is_starting():
return None, False
@@ -586,7 +607,7 @@ class ConfigHandler(object):
try:
values = self._get_pg_settings(self._recovery_parameters_to_compare).values()
values = {p[0]: [p[1], p[4] == 'postmaster', p[5]] for p in values}
values = CaseInsensitiveDict({p[0]: [p[1], p[4] == 'postmaster', p[5]] for p in values})
self._postgresql_conf_mtime = pg_conf_mtime
self._auto_conf_mtime = auto_conf_mtime
self._postmaster_ctime = postmaster_ctime
@@ -594,13 +615,13 @@ class ConfigHandler(object):
values = None
return values, True
def _read_recovery_params_pre_v12(self):
def _read_recovery_params_pre_v12(self) -> Tuple[Optional[CaseInsensitiveDict], Optional[bool]]:
recovery_conf_mtime = mtime(self._recovery_conf)
passfile_mtime = mtime(self._passfile) if self._passfile else False
if recovery_conf_mtime == self._recovery_conf_mtime and passfile_mtime == self._passfile_mtime:
return None, False
values = {}
values = CaseInsensitiveDict()
with open(self._recovery_conf, 'r') as f:
for line in f:
line = line.strip()
@@ -610,7 +631,7 @@ class ConfigHandler(object):
match = PARAMETER_RE.match(line)
if match:
value = read_recovery_param_value(line[match.end():])
if value is None:
if match is None or value is None:
return None, True
values[match.group(1)] = [value, True]
self._recovery_conf_mtime = recovery_conf_mtime
@@ -619,7 +640,7 @@ class ConfigHandler(object):
values.update({param: ['', True] for param in self._recovery_parameters_to_compare if param not in values})
return values, True
def _check_passfile(self, passfile, wanted_primary_conninfo):
def _check_passfile(self, passfile: str, wanted_primary_conninfo: Dict[str, Any]) -> bool:
# If there is a passfile in the primary_conninfo try to figure out that
# the passfile contains the line(s) allowing connection to the given node.
# We assume that the passfile was created by Patroni and therefore doing
@@ -628,7 +649,7 @@ class ConfigHandler(object):
if passfile_mtime:
try:
with open(passfile) as f:
wanted_lines = self._pgpass_line(wanted_primary_conninfo).splitlines()
wanted_lines = (self._pgpass_line(wanted_primary_conninfo) or '').splitlines()
file_lines = f.read().splitlines()
if set(wanted_lines) == set(file_lines):
self._passfile = passfile
@@ -638,7 +659,8 @@ class ConfigHandler(object):
logger.info('Failed to read %s', passfile)
return False
def _check_primary_conninfo(self, primary_conninfo, wanted_primary_conninfo):
def _check_primary_conninfo(self, primary_conninfo: Dict[str, Any],
wanted_primary_conninfo: Dict[str, Any]) -> bool:
# first we will cover corner cases, when we are replicating from somewhere while shouldn't
# or there is no primary_conninfo but we should replicate from some specific node.
if not wanted_primary_conninfo:
@@ -667,7 +689,7 @@ class ConfigHandler(object):
return all(str(primary_conninfo.get(p)) == str(v) for p, v in wanted_primary_conninfo.items() if v is not None)
def check_recovery_conf(self, member):
def check_recovery_conf(self, member: Union[Leader, Member, None]) -> Tuple[bool, bool]:
"""Returns a tuple. The first boolean element indicates that recovery params don't match
and the second is set to `True` if the restart is required in order to apply new values"""
@@ -701,7 +723,7 @@ class ConfigHandler(object):
else: # empty string, primary_conninfo is not in the config
primary_conninfo[0] = {}
if not self._postgresql.is_starting():
if not self._postgresql.is_starting() and self._current_recovery_params:
# when wal receiver is alive take primary_slot_name from pg_stat_wal_receiver
wal_receiver_primary_slot_name = self._postgresql.primary_slot_name()
if not wal_receiver_primary_slot_name and self._postgresql.primary_conninfo():
@@ -715,11 +737,11 @@ class ConfigHandler(object):
and not self._postgresql.cb_called
and not self._postgresql.is_starting())}
def record_missmatch(mtype):
def record_missmatch(mtype: bool) -> None:
required['restart' if mtype else 'reload'] += 1
wanted_recovery_params = self.build_recovery_params(member)
for param, value in self._current_recovery_params.items():
for param, value in (self._current_recovery_params or {}).items():
# Skip certain parameters defined in the included postgres config files
# if we know that they are not specified in the patroni configuration.
if len(value) > 2 and value[2] not in (self._postgresql_conf, self._auto_conf) and \
@@ -741,25 +763,25 @@ class ConfigHandler(object):
return required['restart'] + required['reload'] > 0, required['restart'] > 0
@staticmethod
def _remove_file_if_exists(name):
def _remove_file_if_exists(name: str) -> None:
if os.path.isfile(name) or os.path.islink(name):
os.unlink(name)
@staticmethod
def _pgpass_line(record):
def _pgpass_line(record: Dict[str, Any]) -> Optional[str]:
if 'password' in record:
def escape(value):
def escape(value: Any) -> str:
return re.sub(r'([:\\])', r'\\\1', str(value))
record = {n: escape(record.get(n) or '*') for n in ('host', 'port', 'user', 'password')}
# 'host' could be several comma-separated hostnames, in this case
# we need to write on pgpass line per host
line = ''
for hostname in record.get('host').split(','):
for hostname in record['host'].split(','):
line += hostname + ':{port}:*:{user}:{password}'.format(**record) + '\n'
return line.rstrip()
def write_pgpass(self, record):
def write_pgpass(self, record: Dict[str, Any]) -> Dict[str, str]:
line = self._pgpass_line(record)
if not line:
return os.environ.copy()
@@ -770,7 +792,7 @@ class ConfigHandler(object):
return {**os.environ, 'PGPASSFILE': self._pgpass}
def write_recovery_conf(self, recovery_params):
def write_recovery_conf(self, recovery_params: CaseInsensitiveDict) -> None:
self._recovery_params = recovery_params
if self._postgresql.major_version >= 120000:
if parse_bool(recovery_params.pop('standby_mode', None)):
@@ -779,28 +801,28 @@ class ConfigHandler(object):
self._remove_file_if_exists(self._standby_signal)
open(self._recovery_signal, 'w').close()
def restart_required(name):
def restart_required(name: str) -> bool:
if self._postgresql.major_version >= 140000:
return False
return name == 'restore_command' or (self._postgresql.major_version < 130000
and name in ('primary_conninfo', 'primary_slot_name'))
self._current_recovery_params = {n: [v, restart_required(n), self._postgresql_conf]
for n, v in recovery_params.items()}
self._current_recovery_params = CaseInsensitiveDict({n: [v, restart_required(n), self._postgresql_conf]
for n, v in recovery_params.items()})
else:
with ConfigWriter(self._recovery_conf) as f:
os.chmod(self._recovery_conf, stat.S_IWRITE | stat.S_IREAD)
self._write_recovery_params(f, recovery_params)
def remove_recovery_conf(self):
def remove_recovery_conf(self) -> None:
for name in (self._recovery_conf, self._standby_signal, self._recovery_signal):
self._remove_file_if_exists(name)
self._recovery_params = {}
self._recovery_params = CaseInsensitiveDict()
self._current_recovery_params = None
def _sanitize_auto_conf(self):
def _sanitize_auto_conf(self) -> None:
overwrite = False
lines = []
lines: List[str] = []
if os.path.exists(self._auto_conf):
try:
@@ -823,7 +845,7 @@ class ConfigHandler(object):
except Exception:
logger.exception('Failed to remove some unwanted parameters from %s', self._auto_conf)
def _adjust_recovery_parameters(self):
def _adjust_recovery_parameters(self) -> None:
# It is not strictly necessary, but we can make patroni configs cross-compatible with all postgres versions.
recovery_conf = {n: v for n, v in self._server_parameters.items() if n.lower() in self._RECOVERY_PARAMETERS}
if recovery_conf:
@@ -834,7 +856,7 @@ class ConfigHandler(object):
if self.triggerfile_good_name not in self._config['recovery_conf'] and value:
self._config['recovery_conf'][self.triggerfile_good_name] = value
def get_server_parameters(self, config):
def get_server_parameters(self, config: Dict[str, Any]) -> CaseInsensitiveDict:
parameters = config['parameters'].copy()
listen_addresses, port = split_host_port(config['listen'], 5432)
parameters.update(cluster_name=self._postgresql.scope, listen_addresses=listen_addresses, port=str(port))
@@ -860,7 +882,7 @@ class ConfigHandler(object):
parameters.setdefault('wal_keep_size', str(int(wal_keep_segments) * 16) + 'MB')
elif self._postgresql.major_version:
wal_keep_size = parse_int(parameters.pop('wal_keep_size', self.CMDLINE_OPTIONS['wal_keep_size'][0]), 'MB')
parameters.setdefault('wal_keep_segments', int((wal_keep_size + 8) / 16))
parameters.setdefault('wal_keep_segments', int(((wal_keep_size or 0) + 8) / 16))
self._postgresql.citus_handler.adjust_postgres_gucs(parameters)
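The segment/size conversion assumes PostgreSQL's default 16MB WAL segments; the `or 0` fallback covers `parse_int()` returning `None`. With the default `wal_keep_size` of 128MB the formula lands back on the old `wal_keep_segments` default:

wal_keep_size_mb = 128                                # parse_int('128MB', 'MB')
wal_keep_segments = int((wal_keep_size_mb + 8) / 16)  # -> 8, the pre-13 default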
@@ -870,14 +892,14 @@ class ConfigHandler(object):
return ret
@staticmethod
def _get_unix_local_address(unix_socket_directories):
def _get_unix_local_address(unix_socket_directories: str) -> str:
for d in unix_socket_directories.split(','):
d = d.strip()
if d.startswith('/'): # Only absolute path can be used to connect via unix-socket
return d
return ''
def _get_tcp_local_address(self):
def _get_tcp_local_address(self) -> str:
listen_addresses = self._server_parameters['listen_addresses'].split(',')
for la in listen_addresses:
@@ -886,7 +908,7 @@ class ConfigHandler(object):
return listen_addresses[0].strip() # can't use localhost, take first address from listen_addresses
@property
def local_connect_kwargs(self):
def local_connect_kwargs(self) -> Dict[str, Any]:
ret = self._local_address.copy()
# add all of the other connection settings that are available
ret.update(self._superuser)
@@ -902,7 +924,7 @@ class ConfigHandler(object):
'options': '-c statement_timeout=2000'})
return ret
def resolve_connection_addresses(self):
def resolve_connection_addresses(self) -> None:
port = self._server_parameters['port']
tcp_local_address = self._get_tcp_local_address()
netloc = self._config.get('connect_address') or tcp_local_address + ':' + port
@@ -922,19 +944,24 @@ class ConfigHandler(object):
self._postgresql.connection_string = uri('postgres', netloc, self._postgresql.database)
self._postgresql.set_connection_kwargs(self.local_connect_kwargs)
def _get_pg_settings(self, names):
def _get_pg_settings(
self, names: Collection[str]
) -> Dict[str, Tuple[str, str, Optional[str], str, str, Optional[str]]]:
return {r[0]: r for r in self._postgresql.query(('SELECT name, setting, unit, vartype, context, sourcefile'
+ ' FROM pg_catalog.pg_settings '
+ ' WHERE pg_catalog.lower(name) = ANY(%s)'),
[n.lower() for n in names])}
@staticmethod
def _handle_wal_buffers(old_values, changes):
wal_block_size = parse_int(old_values['wal_block_size'][1])
def _handle_wal_buffers(old_values: Dict[str, Tuple[str, str, Optional[str], str, str, Optional[str]]],
changes: CaseInsensitiveDict) -> None:
wal_block_size = parse_int(old_values['wal_block_size'][1]) or 8192
wal_segment_size = old_values['wal_segment_size']
wal_segment_unit = parse_int(wal_segment_size[2], 'B') if wal_segment_size[2][0].isdigit() else 1
wal_segment_size = parse_int(wal_segment_size[1]) * wal_segment_unit / wal_block_size
default_wal_buffers = min(max(parse_int(old_values['shared_buffers'][1]) / 32, 8), wal_segment_size)
wal_segment_unit = parse_int(wal_segment_size[2], 'B') or 8192 \
if wal_segment_size[2] is not None and wal_segment_size[2][0].isdigit() else 1
wal_segment_size = parse_int(wal_segment_size[1]) or (16777216 if wal_segment_size[2] is None else 2048)
wal_segment_size *= wal_segment_unit / wal_block_size
default_wal_buffers = min(max((parse_int(old_values['shared_buffers'][1]) or 16384) / 32, 8), wal_segment_size)
wal_buffers = old_values['wal_buffers']
new_value = str(changes['wal_buffers'] or -1)
@@ -945,7 +972,7 @@ class ConfigHandler(object):
if new_value == old_value:
del changes['wal_buffers']
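# Worked example of the default wal_buffers computed above, assuming an 8kB
# WAL block size, 16MB segments and shared_buffers of 16384 blocks (128MB):
wal_block_size = 8192
shared_buffers = 16384                        # in blocks, i.e. 128MB
wal_segment_size = 16777216 / wal_block_size  # 2048 blocks
assert min(max(shared_buffers / 32, 8), wal_segment_size) == 512  # blocks == 4MB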
def reload_config(self, config, sighup=False):
def reload_config(self, config: Dict[str, Any], sighup: bool = False) -> None:
self._superuser = config['authentication'].get('superuser', {})
server_parameters = self.get_server_parameters(config)
@@ -956,6 +983,7 @@ class ConfigHandler(object):
changes.update({p: None for p in self._server_parameters.keys()
if not (p in changes or p.lower() in self._RECOVERY_PARAMETERS)})
if changes:
undef = []
if 'wal_buffers' in changes: # we need to calculate the default value of wal_buffers
undef = [p for p in ('shared_buffers', 'wal_segment_size', 'wal_block_size') if p not in changes]
changes.update({p: None for p in undef})
@@ -1030,16 +1058,17 @@ class ConfigHandler(object):
if self._postgresql.major_version >= 90500:
time.sleep(1)
try:
pending_restart = self._postgresql.query(
'SELECT COUNT(*) FROM pg_catalog.pg_settings WHERE pg_catalog.lower(name) != ALL(%s)'
' AND pending_restart', [n.lower() for n in self._RECOVERY_PARAMETERS]).fetchone()[0] > 0
pending_restart = (self._postgresql.query(
'SELECT COUNT(*) FROM pg_catalog.pg_settings'
' WHERE pg_catalog.lower(name) != ALL(%s) AND pending_restart',
[n.lower() for n in self._RECOVERY_PARAMETERS]).fetchone() or (0,))[0] > 0
self._postgresql.set_pending_restart(pending_restart)
except Exception as e:
logger.warning('Exception %r when running query', e)
else:
logger.info('No PostgreSQL configuration items changed, nothing to reload.')
def set_synchronous_standby_names(self, value):
def set_synchronous_standby_names(self, value: Optional[str]) -> Optional[bool]:
"""Updates synchronous_standby_names and reloads if necessary.
:returns: True if value was updated."""
if value != self._synchronous_standby_names:
@@ -1054,7 +1083,7 @@ class ConfigHandler(object):
return True
@property
def effective_configuration(self):
def effective_configuration(self) -> CaseInsensitiveDict:
"""It might happen that the current value of one (or more) below parameters stored in
the controldata is higher than the value stored in the global cluster configuration.
@@ -1089,7 +1118,7 @@ class ConfigHandler(object):
continue
cvalue = parse_int(data[cname])
if cvalue > value:
if cvalue is not None and value is not None and cvalue > value:
effective_configuration[name] = cvalue
self._postgresql.set_pending_restart(True)
@@ -1113,35 +1142,35 @@ class ConfigHandler(object):
return effective_configuration
@property
def replication(self):
def replication(self) -> Dict[str, Any]:
return self._config['authentication']['replication']
@property
def superuser(self):
def superuser(self) -> Dict[str, Any]:
return self._superuser
@property
def rewind_credentials(self):
def rewind_credentials(self) -> Dict[str, Any]:
return self._config['authentication'].get('rewind', self._superuser) \
if self._postgresql.major_version >= 110000 else self._superuser
@property
def ident_file(self):
def ident_file(self) -> Optional[str]:
ident_file = self._server_parameters.get('ident_file')
return None if ident_file == self._pg_ident_conf else ident_file
@property
def hba_file(self):
def hba_file(self) -> Optional[str]:
hba_file = self._server_parameters.get('hba_file')
return None if hba_file == self._pg_hba_conf else hba_file
@property
def pg_hba_conf(self):
def pg_hba_conf(self) -> str:
return self._pg_hba_conf
@property
def postgresql_conf(self):
def postgresql_conf(self) -> str:
return self._postgresql_conf
def get(self, key, default=None):
def get(self, key: str, default: Optional[Any] = None) -> Optional[Any]:
return self._config.get(key, default)

View File

@@ -2,6 +2,10 @@ import logging
from contextlib import contextmanager
from threading import Lock
from typing import Any, Dict, Generator, Union, TYPE_CHECKING
if TYPE_CHECKING: # pragma: no cover
from psycopg import Connection as Connection3, Cursor
from psycopg2 import connection, cursor
from .. import psycopg
@@ -9,29 +13,30 @@ logger = logging.getLogger(__name__)
class Connection(object):
server_version: int
def __init__(self):
def __init__(self) -> None:
self._lock = Lock()
self._connection = None
self._cursor_holder = None
def set_conn_kwargs(self, conn_kwargs):
def set_conn_kwargs(self, conn_kwargs: Dict[str, Any]) -> None:
self._conn_kwargs = conn_kwargs
def get(self):
def get(self) -> Union['connection', 'Connection3[Any]']:
with self._lock:
if not self._connection or self._connection.closed != 0:
self._connection = psycopg.connect(**self._conn_kwargs)
self.server_version = self._connection.server_version
self.server_version = getattr(self._connection, 'server_version', 0)
return self._connection
def cursor(self):
def cursor(self) -> Union['cursor', 'Cursor[Any]']:
if not self._cursor_holder or self._cursor_holder.closed or self._cursor_holder.connection.closed != 0:
logger.info("establishing a new patroni connection to the postgres cluster")
self._cursor_holder = self.get().cursor()
return self._cursor_holder
def close(self):
def close(self) -> None:
if self._connection and self._connection.closed == 0:
self._connection.close()
logger.info("closed patroni connection to the postgresql cluster")
@@ -39,7 +44,7 @@ class Connection(object):
@contextmanager
def get_connection_cursor(**kwargs):
def get_connection_cursor(**kwargs: Any) -> Generator[Union['cursor', 'Cursor[Any]'], None, None]:
conn = psycopg.connect(**kwargs)
with conn.cursor() as cur:
yield cur

View File

@@ -2,7 +2,9 @@ import errno
import logging
import os
from patroni.exceptions import PostgresException
from typing import Iterable, Tuple
from ..exceptions import PostgresException
logger = logging.getLogger(__name__)
@@ -55,29 +57,27 @@ def postgres_major_version_to_int(pg_version: str) -> int:
return postgres_version_to_int(pg_version + '.0')
def parse_lsn(lsn):
def parse_lsn(lsn: str) -> int:
t = lsn.split('/')
return int(t[0], 16) * 0x100000000 + int(t[1], 16)
def parse_history(data):
def parse_history(data: str) -> Iterable[Tuple[int, int, str]]:
for line in data.split('\n'):
values = line.strip().split('\t')
if len(values) == 3:
try:
values[0] = int(values[0])
values[1] = parse_lsn(values[1])
yield values
yield int(values[0]), parse_lsn(values[1]), values[2]
except (IndexError, ValueError):
logger.exception('Exception when parsing timeline history line "%s"', values)
def format_lsn(lsn, full=False):
def format_lsn(lsn: int, full: bool = False) -> str:
template = '{0:X}/{1:08X}' if full else '{0:X}/{1:X}'
return template.format(lsn >> 32, lsn & 0xFFFFFFFF)
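# Quick illustrative check of the helpers above: parse_lsn() and format_lsn()
# are inverses, and parse_history() yields timeline-history tuples
# (sample values, not real cluster data):
assert parse_lsn('16/B374D848') == (0x16 << 32) | 0xB374D848
assert format_lsn(parse_lsn('16/B374D848')) == '16/B374D848'
assert format_lsn(parse_lsn('0/48'), full=True) == '0/00000048'
assert list(parse_history('1\t0/3000000\treason\n')) == [(1, 0x3000000, 'reason')]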
def fsync_dir(path):
def fsync_dir(path: str) -> None:
if os.name != 'nt':
fd = os.open(path, os.O_DIRECTORY)
try:

View File

@@ -7,6 +7,9 @@ import signal
import subprocess
import sys
from multiprocessing.connection import Connection
from typing import Dict, Optional, List
from patroni import PATRONI_ENV_PREFIX, KUBERNETES_ENV_PREFIX
# avoid spawning the resource tracker process
@@ -26,7 +29,7 @@ STOP_SIGNALS = {
}
def pg_ctl_start(conn, cmdline, env):
def pg_ctl_start(conn: Connection, cmdline: List[str], env: Dict[str, str]) -> None:
if os.name != 'nt':
os.setsid()
try:
@@ -40,7 +43,8 @@ def pg_ctl_start(conn, cmdline, env):
class PostmasterProcess(psutil.Process):
def __init__(self, pid):
def __init__(self, pid: int) -> None:
self._postmaster_pid: Dict[str, str]
self.is_single_user = False
if pid < 0:
pid = -pid
@@ -48,7 +52,7 @@ class PostmasterProcess(psutil.Process):
super(PostmasterProcess, self).__init__(pid)
@staticmethod
def _read_postmaster_pidfile(data_dir):
def _read_postmaster_pidfile(data_dir: str) -> Dict[str, str]:
"""Reads and parses postmaster.pid from the data directory
:returns: dictionary of values if successful, empty dictionary otherwise
@@ -60,7 +64,7 @@ class PostmasterProcess(psutil.Process):
except IOError:
return {}
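# Illustrative postmaster.pid layout, one value per line (the exact set of
# lines depends on the PostgreSQL version):
#   12345                  pid
#   /var/lib/pgsql/data    data directory
#   1683619200             start time (epoch)
#   5432                   port
#   ...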
def _is_postmaster_process(self):
def _is_postmaster_process(self) -> bool:
try:
start_time = int(self._postmaster_pid.get('start_time', 0))
if start_time and abs(self.create_time() - start_time) > 3:
@@ -79,7 +83,7 @@ class PostmasterProcess(psutil.Process):
return True
@classmethod
def _from_pidfile(cls, data_dir):
def _from_pidfile(cls, data_dir: str) -> Optional['PostmasterProcess']:
postmaster_pid = PostmasterProcess._read_postmaster_pidfile(data_dir)
try:
pid = int(postmaster_pid.get('pid', 0))
@@ -88,10 +92,10 @@ class PostmasterProcess(psutil.Process):
proc._postmaster_pid = postmaster_pid
return proc
except ValueError:
pass
return None
@staticmethod
def from_pidfile(data_dir):
def from_pidfile(data_dir: str) -> Optional['PostmasterProcess']:
try:
proc = PostmasterProcess._from_pidfile(data_dir)
return proc if proc and proc._is_postmaster_process() else None
@@ -99,13 +103,13 @@ class PostmasterProcess(psutil.Process):
return None
@classmethod
def from_pid(cls, pid):
def from_pid(cls, pid: int) -> Optional['PostmasterProcess']:
try:
return cls(pid)
except psutil.NoSuchProcess:
return None
def signal_kill(self):
def signal_kill(self) -> bool:
"""to suspend and kill postmaster and all children
:returns True if postmaster and children are killed, False if error
@@ -141,7 +145,7 @@ class PostmasterProcess(psutil.Process):
psutil.wait_procs(children + [self])
return True
def signal_stop(self, mode, pg_ctl='pg_ctl'):
def signal_stop(self, mode: str, pg_ctl: str = 'pg_ctl') -> Optional[bool]:
"""Signal postmaster process to stop
:returns: None if signaled, True if process is already gone, False if error
@@ -161,7 +165,7 @@ class PostmasterProcess(psutil.Process):
return None
def pg_ctl_kill(self, mode, pg_ctl):
def pg_ctl_kill(self, mode: str, pg_ctl: str) -> Optional[bool]:
try:
status = subprocess.call([pg_ctl, "kill", STOP_SIGNALS[mode], str(self.pid)])
except OSError:
@@ -171,7 +175,7 @@ class PostmasterProcess(psutil.Process):
else:
return not self.is_running()
def wait_for_user_backends_to_close(self, stop_timeout):
def wait_for_user_backends_to_close(self, stop_timeout: Optional[float]) -> None:
# These regexps are cross checked against versions PostgreSQL 9.1 .. 15
aux_proc_re = re.compile("(?:postgres:)( .*:)? (?:(?:archiver|startup|autovacuum launcher|autovacuum worker|"
"checkpointer|logger|stats collector|wal receiver|wal writer|writer)(?: process )?|"
@@ -183,8 +187,8 @@ class PostmasterProcess(psutil.Process):
except psutil.Error:
return logger.debug('Failed to get list of postmaster children')
user_backends = []
user_backends_cmdlines = {}
user_backends: List[psutil.Process] = []
user_backends_cmdlines: Dict[int, str] = {}
for child in children:
try:
cmdline = child.cmdline()
@@ -195,7 +199,7 @@ class PostmasterProcess(psutil.Process):
pass
if user_backends:
logger.debug('Waiting for user backends %s to close', ', '.join(user_backends_cmdlines.values()))
gone, live = psutil.wait_procs(user_backends, stop_timeout)
_, live = psutil.wait_procs(user_backends, stop_timeout)
if stop_timeout and live:
live = [user_backends_cmdlines[b.pid] for b in live]
logger.warning('Backends still alive after %s: %s', stop_timeout, ', '.join(live))
@@ -203,7 +207,7 @@ class PostmasterProcess(psutil.Process):
logger.debug("Backends closed")
@staticmethod
def start(pgcommand, data_dir, conf, options):
def start(pgcommand: str, data_dir: str, conf: str, options: List[str]) -> Optional['PostmasterProcess']:
# Unfortunately `pg_ctl start` does not return postmaster pid to us. Without this information
# it is hard to know the current state of postgres startup, so we had to reimplement pg_ctl start
# in python. It will start postgres, wait for port to be open and wait until postgres will start
@@ -234,7 +238,7 @@ class PostmasterProcess(psutil.Process):
pass
cmdline = [pgcommand, '-D', data_dir, '--config-file={}'.format(conf)] + options
logger.debug("Starting postgres: %s", " ".join(cmdline))
ctx = multiprocessing.get_context('spawn') if sys.version_info >= (3, 4) else multiprocessing
ctx = multiprocessing.get_context('spawn')
parent_conn, child_conn = ctx.Pipe(False)
proc = ctx.Process(target=pg_ctl_start, args=(child_conn, cmdline, env))
proc.start()

View File

@@ -5,36 +5,46 @@ import shlex
import shutil
import subprocess
from enum import IntEnum
from threading import Lock, Thread
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
from . import Postgresql
from .connection import get_connection_cursor
from .misc import format_lsn, fsync_dir, parse_history, parse_lsn
from ..async_executor import CriticalTask
from ..dcs import Leader
from ..dcs import Leader, RemoteMember
logger = logging.getLogger(__name__)
REWIND_STATUS = type('Enum', (), {'INITIAL': 0, 'CHECKPOINT': 1, 'CHECK': 2, 'NEED': 3,
'NOT_NEED': 4, 'SUCCESS': 5, 'FAILED': 6})
class REWIND_STATUS(IntEnum):
INITIAL = 0
CHECKPOINT = 1
CHECK = 2
NEED = 3
NOT_NEED = 4
SUCCESS = 5
FAILED = 6
class Rewind(object):
def __init__(self, postgresql):
def __init__(self, postgresql: Postgresql) -> None:
self._postgresql = postgresql
self._checkpoint_task_lock = Lock()
self.reset_state()
@staticmethod
def configuration_allows_rewind(data):
def configuration_allows_rewind(data: Dict[str, str]) -> bool:
return data.get('wal_log_hints setting', 'off') == 'on' or data.get('Data page checksum version', '0') != '0'
@property
def enabled(self):
return self._postgresql.config.get('use_pg_rewind')
def enabled(self) -> bool:
return bool(self._postgresql.config.get('use_pg_rewind'))
@property
def can_rewind(self):
def can_rewind(self) -> bool:
""" check if pg_rewind executable is there and that pg_controldata indicates
we have either wal_log_hints or checksums turned on
"""
@@ -52,43 +62,45 @@ class Rewind(object):
return self.configuration_allows_rewind(self._postgresql.controldata())
@property
def should_remove_data_directory_on_diverged_timelines(self):
return self._postgresql.config.get('remove_data_directory_on_diverged_timelines')
def should_remove_data_directory_on_diverged_timelines(self) -> bool:
return bool(self._postgresql.config.get('remove_data_directory_on_diverged_timelines'))
@property
def can_rewind_or_reinitialize_allowed(self):
def can_rewind_or_reinitialize_allowed(self) -> bool:
return self.should_remove_data_directory_on_diverged_timelines or self.can_rewind
def trigger_check_diverged_lsn(self):
def trigger_check_diverged_lsn(self) -> None:
if self.can_rewind_or_reinitialize_allowed and self._state != REWIND_STATUS.NEED:
self._state = REWIND_STATUS.CHECK
@staticmethod
def check_leader_is_not_in_recovery(conn_kwargs):
def check_leader_is_not_in_recovery(conn_kwargs: Dict[str, Any]) -> Optional[bool]:
try:
with get_connection_cursor(connect_timeout=3, options='-c statement_timeout=2000', **conn_kwargs) as cur:
cur.execute('SELECT pg_catalog.pg_is_in_recovery()')
if not cur.fetchone()[0]:
row = cur.fetchone()
if not row or not row[0]:
return True
logger.info('Leader is still in_recovery and therefore can\'t be used for rewind')
except Exception:
return logger.exception('Exception when working with leader')
@staticmethod
def check_leader_has_run_checkpoint(conn_kwargs):
def check_leader_has_run_checkpoint(conn_kwargs: Dict[str, Any]) -> Optional[str]:
try:
with get_connection_cursor(connect_timeout=3, options='-c statement_timeout=2000', **conn_kwargs) as cur:
cur.execute("SELECT NOT pg_catalog.pg_is_in_recovery()"
" AND ('x' || pg_catalog.substr(pg_catalog.pg_walfile_name("
" pg_catalog.pg_current_wal_lsn()), 1, 8))::bit(32)::int = timeline_id"
" FROM pg_catalog.pg_control_checkpoint()")
if not cur.fetchone()[0]:
row = cur.fetchone()
if not row or not row[0]:
return 'leader has not run a checkpoint yet'
except Exception:
logger.exception('Exception when working with leader')
return 'not accessible or not healthy'
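# The SQL above derives the timeline id from the current WAL file name; its
# first 8 hex characters encode the timeline (sample file name):
walfile_name = '000000020000000100000078'
assert int(walfile_name[:8], 16) == 2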
def _get_checkpoint_end(self, timeline, lsn):
def _get_checkpoint_end(self, timeline: int, lsn: int) -> int:
"""The checkpoint record size in WAL depends on postgres major version and platform (memory alignment).
Hence, the only reliable way to figure out where it ends is to read the record from the file with the help of
pg_waldump and parse the output. We try to read two records and expect that reading the second one will fail:
@@ -96,12 +108,12 @@ class Rewind(object):
The error message contains the LSN of the next record, which is exactly where the checkpoint ends."""
lsn8 = format_lsn(lsn, True)
lsn = format_lsn(lsn)
out, err = self._postgresql.waldump(timeline, lsn, 2)
lsn_str = format_lsn(lsn)
out, err = self._postgresql.waldump(timeline, lsn_str, 2)
if out is not None and err is not None:
out = out.decode('utf-8').rstrip().split('\n')
err = err.decode('utf-8').rstrip().split('\n')
pattern = 'error in WAL record at {0}: invalid record length at '.format(lsn)
pattern = 'error in WAL record at {0}: invalid record length at '.format(lsn_str)
if len(out) == 1 and len(err) == 1 and ', lsn: {0}, prev '.format(lsn8) in out[0] and pattern in err[0]:
i = err[0].find(pattern) + len(pattern)
@@ -117,20 +129,20 @@ class Rewind(object):
return 0
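# Minimal sketch of the stderr parsing above, with a made-up pg_waldump error
# line (the exact wording may differ between PostgreSQL versions):
err0 = 'pg_waldump: error: error in WAL record at 0/2000028: invalid record length at 0/2000098: wanted 24, got 0'
pattern = 'error in WAL record at {0}: invalid record length at '.format('0/2000028')
i = err0.find(pattern) + len(pattern)
assert err0[i:].split(':')[0] == '0/2000098'  # LSN right after the checkpoint record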
def _get_local_timeline_lsn_from_controldata(self):
def _get_local_timeline_lsn_from_controldata(self) -> Tuple[Optional[bool], Optional[int], Optional[int]]:
in_recovery = timeline = lsn = None
data = self._postgresql.controldata()
try:
if data.get('Database cluster state') in ('shut down in recovery', 'in archive recovery'):
in_recovery = True
lsn = data.get('Minimum recovery ending location')
timeline = int(data.get("Min recovery ending loc's timeline"))
timeline = int(data.get("Min recovery ending loc's timeline", ""))
if lsn == '0/0' or timeline == 0: # it was a primary when it crashed
data['Database cluster state'] = 'shut down'
if data.get('Database cluster state') == 'shut down':
in_recovery = False
lsn = data.get('Latest checkpoint location')
timeline = int(data.get("Latest checkpoint's TimeLineID"))
timeline = int(data.get("Latest checkpoint's TimeLineID", ""))
except (TypeError, ValueError):
logger.exception('Failed to get local timeline and lsn from pg_controldata output')
@@ -143,7 +155,7 @@ class Rewind(object):
return in_recovery, timeline, lsn
def _get_local_timeline_lsn(self):
def _get_local_timeline_lsn(self) -> Tuple[Optional[bool], Optional[int], Optional[int]]:
if self._postgresql.is_running(): # if postgres is running - get timeline from replication connection
in_recovery = True
timeline = self._postgresql.received_timeline() or self._postgresql.get_replica_timeline()
@@ -156,14 +168,15 @@ class Rewind(object):
return in_recovery, timeline, lsn
@staticmethod
def _log_primary_history(history, i):
def _log_primary_history(history: List[Tuple[int, int, str]], i: int) -> None:
start = max(0, i - 3)
end = None if i + 4 >= len(history) else i + 2
history_show = []
history_show: List[str] = []
def format_history_line(line):
def format_history_line(line: Tuple[int, int, str]) -> str:
return '{0}\t{1}\t{2}'.format(line[0], format_lsn(line[1]), line[2])
line = None
for line in history[start:end]:
history_show.append(format_history_line(line))
@@ -173,7 +186,7 @@ class Rewind(object):
logger.info('primary: history=%s', '\n'.join(history_show))
def _conn_kwargs(self, member, auth):
def _conn_kwargs(self, member: Union[Leader, RemoteMember], auth: Dict[str, Any]) -> Dict[str, Any]:
ret = member.conn_kwargs(auth)
if not ret.get('dbname'):
ret['dbname'] = self._postgresql.database
@@ -183,7 +196,7 @@ class Rewind(object):
ret['target_session_attrs'] = 'read-write'
return ret
def _check_timeline_and_lsn(self, leader):
def _check_timeline_and_lsn(self, leader: Union[Leader, RemoteMember]) -> None:
in_recovery, local_timeline, local_lsn = self._get_local_timeline_lsn()
if local_timeline is None or local_lsn is None:
return
@@ -205,23 +218,28 @@ class Rewind(object):
try:
with self._postgresql.get_replication_connection_cursor(**leader.conn_kwargs()) as cur:
cur.execute('IDENTIFY_SYSTEM')
primary_timeline = cur.fetchone()[1]
logger.info('primary_timeline=%s', primary_timeline)
if local_timeline > primary_timeline: # Not always supported by pg_rewind
need_rewind = True
elif local_timeline == primary_timeline:
need_rewind = False
elif primary_timeline > 1:
cur.execute('TIMELINE_HISTORY {0}'.format(primary_timeline))
history = cur.fetchone()[1]
if not isinstance(history, str):
history = bytes(history).decode('utf-8')
logger.debug('primary: history=%s', history)
row = cur.fetchone()
if row:
primary_timeline = row[1]
logger.info('primary_timeline=%s', primary_timeline)
if local_timeline > primary_timeline: # Not always supported by pg_rewind
need_rewind = True
elif local_timeline == primary_timeline:
need_rewind = False
elif primary_timeline > 1:
cur.execute('TIMELINE_HISTORY {0}'.format(primary_timeline).encode('utf-8'))
row = cur.fetchone()
if row:
history = row[1]
if not isinstance(history, str):
history = bytes(history).decode('utf-8')
logger.debug('primary: history=%s', history)
except Exception:
return logger.exception('Exception when working with primary via replication connection')
if history is not None:
history = list(parse_history(history))
i = len(history)
for i, (parent_timeline, switchpoint, _) in enumerate(history):
if parent_timeline == local_timeline:
# We don't need to rewind when:
@@ -243,12 +261,12 @@ class Rewind(object):
self._state = need_rewind and REWIND_STATUS.NEED or REWIND_STATUS.NOT_NEED
def rewind_or_reinitialize_needed_and_possible(self, leader):
def rewind_or_reinitialize_needed_and_possible(self, leader: Union[Leader, RemoteMember, None]) -> bool:
if leader and leader.name != self._postgresql.name and leader.conn_url and self._state == REWIND_STATUS.CHECK:
self._check_timeline_and_lsn(leader)
return leader and leader.conn_url and self._state == REWIND_STATUS.NEED
return bool(leader and leader.conn_url) and self._state == REWIND_STATUS.NEED
def __checkpoint(self, task, wakeup):
def __checkpoint(self, task: CriticalTask, wakeup: Callable[..., Any]) -> None:
try:
result = self._postgresql.checkpoint()
except Exception as e:
@@ -258,7 +276,7 @@ class Rewind(object):
if task.result:
wakeup()
def ensure_checkpoint_after_promote(self, wakeup):
def ensure_checkpoint_after_promote(self, wakeup: Callable[..., Any]) -> None:
"""After promote issue a CHECKPOINT from a new thread and asynchronously check the result.
If the CHECKPOINT failed, just check that the timeline in pg_control was updated."""
@@ -275,10 +293,10 @@ class Rewind(object):
self._checkpoint_task = CriticalTask()
Thread(target=self.__checkpoint, args=(self._checkpoint_task, wakeup)).start()
def checkpoint_after_promote(self):
def checkpoint_after_promote(self) -> bool:
return self._state == REWIND_STATUS.CHECKPOINT
def _buid_archiver_command(self, command, wal_filename):
def _buid_archiver_command(self, command: str, wal_filename: str) -> str:
"""Replace placeholders in the given archiver command's template.
Applicable for archive_command and restore_command.
Can also be used for archive_cleanup_command and recovery_end_command,
@@ -306,13 +324,13 @@ class Rewind(object):
return cmd
def _fetch_missing_wal(self, restore_command, wal_filename):
def _fetch_missing_wal(self, restore_command: str, wal_filename: str) -> bool:
cmd = self._buid_archiver_command(restore_command, wal_filename)
logger.info('Trying to fetch the missing wal: %s', cmd)
return self._postgresql.cancellable.call(shlex.split(cmd)) == 0
def _find_missing_wal(self, data):
def _find_missing_wal(self, data: bytes) -> Optional[str]:
# could not open file "$PGDATA/pg_wal/0000000A00006AA100000068": No such file or directory
pattern = 'could not open file "'
for line in data.decode('utf-8').split('\n'):
@@ -325,7 +343,7 @@ class Rewind(object):
if waldir.endswith('/pg_' + self._postgresql.wal_name) and len(wal_filename) == 24:
return wal_filename
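# Sketch of the extraction with a sample line; the elided code between the
# pattern match and this length check splits out the quoted file path:
line = 'could not open file "/pgdata/pg_wal/0000000A00006AA100000068": No such file or directory'
path = line.split('"')[1]
waldir, wal_filename = path.rsplit('/', 1)
assert wal_filename == '0000000A00006AA100000068' and len(wal_filename) == 24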
def _archive_ready_wals(self):
def _archive_ready_wals(self) -> None:
"""Try to archive WALs that have .ready files just in case
archive_mode was not set to 'always' before promote, while
after it the WALs were recycled on the promoted replica.
@@ -361,7 +379,7 @@ class Rewind(object):
else:
logger.info('Failed to archive WAL segment %s', wal)
def _maybe_clean_pg_replslot(self):
def _maybe_clean_pg_replslot(self) -> None:
"""Clean pg_replslot directory if pg version is less then 11
(pg_rewind deletes $PGDATA/pg_replslot content only since pg11)."""
if self._postgresql.major_version < 110000:
@@ -373,33 +391,33 @@ class Rewind(object):
except Exception as e:
logger.warning('Unable to clean %s: %r', replslot_dir, e)
def pg_rewind(self, r):
def pg_rewind(self, r: Dict[str, Any]) -> bool:
# prepare pg_rewind connection
env = self._postgresql.config.write_pgpass(r)
env.update(LANG='C', LC_ALL='C', PGOPTIONS='-c statement_timeout=0')
dsn = self._postgresql.config.format_dsn(r, True)
logger.info('running pg_rewind from %s', dsn)
restore_command = self._postgresql.config.get('recovery_conf', {}).get('restore_command') \
restore_command = (self._postgresql.config.get('recovery_conf') or {}).get('restore_command') \
if self._postgresql.major_version < 120000 else self._postgresql.get_guc_value('restore_command')
# Until v15 pg_rewind expected postgresql.conf to be inside $PGDATA, which is not the case on e.g. Debian
pg_rewind_can_restore = restore_command and (self._postgresql.major_version >= 150000
or (self._postgresql.major_version >= 130000
and self._postgresql.config._config_dir
and self._postgresql.config.config_dir
== self._postgresql.data_dir))
cmd = [self._postgresql.pgcommand('pg_rewind')]
if pg_rewind_can_restore:
cmd.append('--restore-target-wal')
if self._postgresql.major_version >= 150000 and\
self._postgresql.config._config_dir != self._postgresql.data_dir:
self._postgresql.config.config_dir != self._postgresql.data_dir:
cmd.append('--config-file={0}'.format(self._postgresql.config.postgresql_conf))
cmd.extend(['-D', self._postgresql.data_dir, '--source-server', dsn])
while True:
results = {}
results: Dict[str, bytes] = {}
ret = self._postgresql.cancellable.call(cmd, env=env, communicate=results)
logger.info('pg_rewind exit code=%s', ret)
@@ -422,7 +440,7 @@ class Rewind(object):
logger.info('Failed to fetch WAL segment %s required for pg_rewind', missing_wal)
return False
def execute(self, leader):
def execute(self, leader: Union[Leader, RemoteMember]) -> Optional[bool]:
if self._postgresql.is_running() and not self._postgresql.stop(checkpoint=False):
return logger.warning('Can not run pg_rewind because postgres is still running')
@@ -471,26 +489,26 @@ class Rewind(object):
break
return False
def reset_state(self):
def reset_state(self) -> None:
self._state = REWIND_STATUS.INITIAL
with self._checkpoint_task_lock:
self._checkpoint_task = None
@property
def is_needed(self):
def is_needed(self) -> bool:
return self._state in (REWIND_STATUS.CHECK, REWIND_STATUS.NEED)
@property
def executed(self):
def executed(self) -> bool:
return self._state > REWIND_STATUS.NOT_NEED
@property
def failed(self):
def failed(self) -> bool:
return self._state == REWIND_STATUS.FAILED
def read_postmaster_opts(self):
def read_postmaster_opts(self) -> Dict[str, str]:
"""returns the list of option names/values from postgres.opts, Empty dict if read failed or no file"""
result = {}
result: Dict[str, str] = {}
try:
with open(os.path.join(self._postgresql.data_dir, 'postmaster.opts')) as f:
data = f.read()
@@ -502,7 +520,8 @@ class Rewind(object):
logger.exception('Error when reading postmaster.opts')
return result
def single_user_mode(self, communicate=None, options=None):
def single_user_mode(self, communicate: Optional[Dict[str, Any]] = None,
options: Optional[Dict[str, str]] = None) -> Optional[int]:
"""run a given command in a single-user mode. If the command is empty - then just start and stop"""
cmd = [self._postgresql.pgcommand('postgres'), '--single', '-D', self._postgresql.data_dir]
for opt, val in sorted((options or {}).items()):
@@ -511,7 +530,7 @@ class Rewind(object):
cmd.append('template1')
return self._postgresql.cancellable.call(cmd, communicate=communicate)
def cleanup_archive_status(self):
def cleanup_archive_status(self) -> None:
status_dir = os.path.join(self._postgresql.wal_dir, 'archive_status')
try:
for f in os.listdir(status_dir):
@@ -526,7 +545,7 @@ class Rewind(object):
except OSError:
logger.exception('Unable to list %s', status_dir)
def ensure_clean_shutdown(self):
def ensure_clean_shutdown(self) -> Optional[bool]:
self._archive_ready_wals()
self.cleanup_archive_status()
@@ -534,7 +553,7 @@ class Rewind(object):
opts = self.read_postmaster_opts()
opts.update({'archive_mode': 'on', 'archive_command': 'false'})
self._postgresql.config.remove_recovery_conf()
output = {}
output: Dict[str, bytes] = {}
ret = self.single_user_mode(communicate=output, options=opts)
if ret != 0:
logger.error('Crash recovery finished with code=%s', ret)

View File

@@ -5,36 +5,43 @@ import shutil
from collections import defaultdict
from contextlib import contextmanager
from threading import Condition, Thread
from typing import Any, Dict, Generator, List, Optional, Union, Tuple, TYPE_CHECKING
from .connection import get_connection_cursor
from .misc import format_lsn, fsync_dir
from ..dcs import Cluster, Leader
from ..psycopg import OperationalError
if TYPE_CHECKING: # pragma: no cover
from psycopg import Cursor
from psycopg2 import cursor
from . import Postgresql
logger = logging.getLogger(__name__)
def compare_slots(s1, s2, dbid='database'):
def compare_slots(s1: Dict[str, Any], s2: Dict[str, Any], dbid: str = 'database') -> bool:
return s1['type'] == s2['type'] and (s1['type'] == 'physical'
or s1.get(dbid) == s2.get(dbid) and s1['plugin'] == s2['plugin'])
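# Illustrative behaviour of compare_slots() with sample slot dicts:
assert compare_slots({'type': 'physical'}, {'type': 'physical'})
assert not compare_slots({'type': 'logical', 'database': 'a', 'plugin': 'test_decoding'},
                         {'type': 'logical', 'database': 'a', 'plugin': 'pgoutput'})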
class SlotsAdvanceThread(Thread):
def __init__(self, slots_handler):
def __init__(self, slots_handler: 'SlotsHandler') -> None:
super(SlotsAdvanceThread, self).__init__()
self.daemon = True
self._slots_handler = slots_handler
# _copy_slots and _failed are used to asynchronously give some feedback to the main thread
self._copy_slots = []
self._copy_slots: List[str] = []
self._failed = False
self._scheduled = defaultdict(dict) # {'dbname1': {'slot1': 100, 'slot2': 100}, 'dbname2': {'slot3': 100}}
# {'dbname1': {'slot1': 100, 'slot2': 100}, 'dbname2': {'slot3': 100}}
self._scheduled: Dict[str, Dict[str, int]] = defaultdict(dict)
self._condition = Condition() # protect self._scheduled from concurrent access and to wakeup the run() method
self.start()
def sync_slot(self, cur, database, slot, lsn):
def sync_slot(self, cur: Union['cursor', 'Cursor[Any]'], database: str, slot: str, lsn: int) -> None:
failed = copy = False
try:
cur.execute("SELECT pg_catalog.pg_replication_slot_advance(%s, %s)", (slot, format_lsn(lsn)))
@@ -55,7 +62,7 @@ class SlotsAdvanceThread(Thread):
if not self._scheduled[database]:
self._scheduled.pop(database)
def sync_slots_in_database(self, database, slots):
def sync_slots_in_database(self, database: str, slots: List[str]) -> None:
with self._slots_handler.get_local_connection_cursor(dbname=database, options='-c statement_timeout=0') as cur:
for slot in slots:
with self._condition:
@@ -63,7 +70,7 @@ class SlotsAdvanceThread(Thread):
if lsn:
self.sync_slot(cur, database, slot, lsn)
def sync_slots(self):
def sync_slots(self) -> None:
with self._condition:
databases = list(self._scheduled.keys())
for database in databases:
@@ -75,7 +82,7 @@ class SlotsAdvanceThread(Thread):
except Exception as e:
logger.error('Failed to advance replication slots in database %s: %r', database, e)
def run(self):
def run(self) -> None:
while True:
with self._condition:
if not self._scheduled:
@@ -83,7 +90,7 @@ class SlotsAdvanceThread(Thread):
self.sync_slots()
def schedule(self, advance_slots):
def schedule(self, advance_slots: Dict[str, Dict[str, int]]) -> Tuple[bool, List[str]]:
with self._condition:
for database, values in advance_slots.items():
self._scheduled[database].update(values)
@@ -94,7 +101,7 @@ class SlotsAdvanceThread(Thread):
return ret
def on_promote(self):
def on_promote(self) -> None:
with self._condition:
self._scheduled.clear()
self._failed = False
@@ -103,22 +110,22 @@ class SlotsAdvanceThread(Thread):
class SlotsHandler(object):
def __init__(self, postgresql):
def __init__(self, postgresql: 'Postgresql') -> None:
self._postgresql = postgresql
self._advance = None
self._replication_slots = {} # already existing replication slots
self._unready_logical_slots = {}
self._replication_slots: Dict[str, Dict[str, Any]] = {} # already existing replication slots
self._unready_logical_slots: Dict[str, Optional[int]] = {}
self.pg_replslot_dir = os.path.join(self._postgresql.data_dir, 'pg_replslot')
self.schedule()
def _query(self, sql, *params):
def _query(self, sql: str, *params: Any) -> Union['cursor', 'Cursor[Any]']:
return self._postgresql.query(sql, *params, retry=False)
@staticmethod
def _copy_items(src, dst, keys=None):
def _copy_items(src: Dict[str, Any], dst: Dict[str, Any], keys: Optional[List[str]] = None) -> None:
dst.update({key: src[key] for key in keys or ('datoid', 'catalog_xmin', 'confirmed_flush_lsn')})
def process_permanent_slots(self, slots):
def process_permanent_slots(self, slots: List[Dict[str, Any]]) -> Dict[str, int]:
"""This methods solves three problems at once (I know, it is weird).
The cluster_info_query from `Postgresql` is executed every HA loop and returns
@@ -130,11 +137,11 @@ class SlotsHandler(object):
3. Updates the local cache with the fresh catalog_xmin and confirmed_flush_lsn for every known slot.
This info is used when performing the check of logical slot readiness on standbys.
"""
ret = {}
ret: Dict[str, int] = {}
slots = {slot['slot_name']: slot for slot in slots or []}
if slots:
for name, value in slots.items():
slots_dict: Dict[str, Dict[str, Any]] = {slot['slot_name']: slot for slot in slots or []}
if slots_dict:
for name, value in slots_dict.items():
if name in self._replication_slots:
if compare_slots(value, self._replication_slots[name], 'datoid'):
if value['type'] == 'logical':
@@ -143,15 +150,15 @@ class SlotsHandler(object):
else:
self._schedule_load_slots = True
# It could happen that the slots was deleted in the background, we want to detect this case
if any(name not in slots for name in self._replication_slots.keys()):
# It could happen that the slot was deleted in the background, we want to detect this case
if any(name not in slots_dict for name in self._replication_slots.keys()):
self._schedule_load_slots = True
return ret
def load_replication_slots(self):
def load_replication_slots(self) -> None:
if self._postgresql.major_version >= 90400 and self._schedule_load_slots:
replication_slots = {}
replication_slots: Dict[str, Dict[str, Any]] = {}
extra = ", catalog_xmin, pg_catalog.pg_wal_lsn_diff(confirmed_flush_lsn, '0/0')::bigint"\
if self._postgresql.major_version >= 100000 else ""
skip_temp_slots = ' WHERE NOT temporary' if self._postgresql.major_version >= 100000 else ''
@@ -170,15 +177,16 @@ class SlotsHandler(object):
self._unready_logical_slots = {n: None for n, v in replication_slots.items() if v['type'] == 'logical'}
self._force_readiness_check = False
def ignore_replication_slot(self, cluster, name):
def ignore_replication_slot(self, cluster: Cluster, name: str) -> bool:
slot = self._replication_slots[name]
for matcher in cluster.config.ignore_slots_matchers:
if ((matcher.get("name") is None or matcher["name"] == name)
and all(not matcher.get(a) or matcher[a] == slot.get(a) for a in ('database', 'plugin', 'type'))):
return True
if cluster.config:
for matcher in cluster.config.ignore_slots_matchers:
if ((matcher.get("name") is None or matcher["name"] == name)
and all(not matcher.get(a) or matcher[a] == slot.get(a) for a in ('database', 'plugin', 'type'))):
return True
return self._postgresql.citus_handler.ignore_replication_slot(slot)
def drop_replication_slot(self, name):
def drop_replication_slot(self, name: str) -> Tuple[bool, bool]:
"""Returns a tuple(active, dropped)"""
cursor = self._query(('WITH slots AS (SELECT slot_name, active'
' FROM pg_catalog.pg_replication_slots WHERE slot_name = %s),'
@@ -186,9 +194,12 @@ class SlotsHandler(object):
' true AS dropped FROM slots WHERE not active) '
'SELECT active, COALESCE(dropped, false) FROM slots'
' FULL OUTER JOIN dropped ON true'), name)
return cursor.fetchone() if cursor.rowcount == 1 else (False, False)
row = cursor.fetchone()
if not row:
row = (False, False)
return row
def _drop_incorrect_slots(self, cluster, slots, paused):
def _drop_incorrect_slots(self, cluster: Cluster, slots: Dict[str, Any], paused: bool) -> None:
# drop old replication slots which are not presented in desired slots
for name in set(self._replication_slots) - set(slots):
if not paused and not self.ignore_replication_slot(cluster, name):
@@ -211,7 +222,7 @@ class SlotsHandler(object):
logger.error("Failed to drop replication slot '%s'", name)
self._schedule_load_slots = True
def _ensure_physical_slots(self, slots):
def _ensure_physical_slots(self, slots: Dict[str, Any]) -> None:
immediately_reserve = ', true' if self._postgresql.major_version >= 90600 else ''
for name, value in slots.items():
if name not in self._replication_slots and value['type'] == 'physical':
@@ -225,15 +236,15 @@ class SlotsHandler(object):
self._schedule_load_slots = True
@contextmanager
def get_local_connection_cursor(self, **kwargs):
def get_local_connection_cursor(self, **kwargs: Any) -> Generator[Union['cursor', 'Cursor[Any]'], None, None]:
conn_kwargs = self._postgresql.config.local_connect_kwargs
conn_kwargs.update(kwargs)
with get_connection_cursor(**conn_kwargs) as cur:
yield cur
def _ensure_logical_slots_primary(self, slots):
def _ensure_logical_slots_primary(self, slots: Dict[str, Any]) -> None:
# Group logical slots to be created by database name
logical_slots = defaultdict(dict)
logical_slots: Dict[str, Dict[str, Dict[str, Any]]] = defaultdict(dict)
for name, value in slots.items():
if value['type'] == 'logical':
# If the logical already exists, copy some information about it into the original structure
@@ -257,26 +268,27 @@ class SlotsHandler(object):
slots.pop(name)
self._schedule_load_slots = True
def schedule_advance_slots(self, slots):
def schedule_advance_slots(self, slots: Dict[str, Dict[str, int]]) -> Tuple[bool, List[str]]:
if not self._advance:
self._advance = SlotsAdvanceThread(self)
return self._advance.schedule(slots)
def _ensure_logical_slots_replica(self, cluster, slots):
advance_slots = defaultdict(dict) # Group logical slots to be advanced by database name
create_slots = [] # And collect logical slots to be created on the replica
def _ensure_logical_slots_replica(self, cluster: Cluster, slots: Dict[str, Any]) -> List[str]:
# Group logical slots to be advanced by database name
advance_slots: Dict[str, Dict[str, int]] = defaultdict(dict)
create_slots: List[str] = [] # And collect logical slots to be created on the replica
for name, value in slots.items():
if value['type'] == 'logical':
# If the logical already exists, copy some information about it into the original structure
if self._replication_slots.get(name, {}).get('datoid'):
self._copy_items(self._replication_slots[name], value)
if name in cluster.slots:
if cluster.slots and name in cluster.slots:
try: # Skip slots that don't need to be advanced
if value['confirmed_flush_lsn'] < int(cluster.slots[name]):
advance_slots[value['database']][name] = int(cluster.slots[name])
except Exception as e:
logger.error('Failed to parse "%s": %r', cluster.slots[name], e)
elif name in cluster.slots: # We want to copy only slots with feedback in a DCS
elif cluster.slots and name in cluster.slots: # We want to copy only slots with feedback in a DCS
create_slots.append(name)
error, copy_slots = self.schedule_advance_slots(advance_slots)
@@ -284,8 +296,9 @@ class SlotsHandler(object):
self._schedule_load_slots = True
return create_slots + copy_slots
def sync_replication_slots(self, cluster, nofailover, replicatefrom=None, paused=False):
ret = None
def sync_replication_slots(self, cluster: Cluster, nofailover: bool,
replicatefrom: Optional[str] = None, paused: bool = False) -> List[str]:
ret = []
if self._postgresql.major_version >= 90400 and cluster.config:
try:
self.load_replication_slots()
@@ -312,14 +325,15 @@ class SlotsHandler(object):
return ret
@contextmanager
def _get_leader_connection_cursor(self, leader):
def _get_leader_connection_cursor(self, leader: Leader) -> Generator[Union['cursor', 'Cursor[Any]'], None, None]:
conn_kwargs = leader.conn_kwargs(self._postgresql.config.rewind_credentials)
conn_kwargs['dbname'] = self._postgresql.database
with get_connection_cursor(connect_timeout=3, options="-c statement_timeout=2000", **conn_kwargs) as cur:
yield cur
def check_logical_slots_readiness(self, cluster, nofailover, replicatefrom):
if self._unready_logical_slots:
def check_logical_slots_readiness(self, cluster: Cluster, nofailover: bool, replicatefrom: Optional[str]) -> None:
catalog_xmin = None
if self._unready_logical_slots and cluster.leader:
slot_name = cluster.get_my_slot_name_on_primary(self._postgresql.name, replicatefrom)
try:
with self._get_leader_connection_cursor(cluster.leader) as cur:
@@ -335,13 +349,14 @@ class SlotsHandler(object):
# Remember catalog_xmin of logical slots on the primary when catalog_xmin of
# the physical slot became valid. Logical slots on replica will be safe to use after
# promote when catalog_xmin of the physical slot overtakes these values.
if catalog_xmin:
if catalog_xmin is not None:
for name, value in slots.items():
self._unready_logical_slots[name] = value
else: # Replica isn't streaming or the hot_standby_feedback isn't enabled
try:
cur = self._query("SELECT pg_catalog.current_setting('hot_standby_feedback')::boolean")
if not cur.fetchone()[0]:
row = cur.fetchone()
if row and not row[0]:
logger.error('Logical slot failover requires "hot_standby_feedback".'
' Please check postgresql.auto.conf')
except Exception as e:
@@ -354,14 +369,18 @@ class SlotsHandler(object):
# 1. has a nonzero/non-null catalog_xmin
# 2. has a catalog_xmin that is not newer (greater) than the catalog_xmin of any slot on the standby
# 3. overtook the catalog_xmin of remembered values of logical slots on the primary.
if not value or self._unready_logical_slots[name] <= catalog_xmin <= value['catalog_xmin']:
if not value or catalog_xmin is not None and\
self._unready_logical_slots[name] <= catalog_xmin <= value['catalog_xmin']:
del self._unready_logical_slots[name]
if value:
logger.info('Logical slot %s is safe to be used after a failover', name)
def copy_logical_slots(self, cluster, create_slots):
def copy_logical_slots(self, cluster: Cluster, create_slots: List[str]) -> None:
leader = cluster.leader
if not leader:
return
slots = cluster.get_replication_slots(self._postgresql.name, 'replica', False, self._postgresql.major_version)
copy_slots: Dict[str, Dict[str, Any]] = {}
with self._get_leader_connection_cursor(leader) as cur:
try:
cur.execute("SELECT slot_name, slot_type, datname, plugin, catalog_xmin, "
@@ -370,21 +389,20 @@ class SlotsHandler(object):
" FROM pg_catalog.pg_get_replication_slots() JOIN pg_catalog.pg_database ON datoid = oid"
" WHERE NOT pg_catalog.pg_is_in_recovery() AND slot_name = ANY(%s)", (create_slots,))
create_slots = {}
for r in cur:
if r[0] in slots: # slot_name is defined in the global configuration
slot = {'type': r[1], 'database': r[2], 'plugin': r[3],
'catalog_xmin': r[4], 'confirmed_flush_lsn': r[5], 'data': r[6]}
if compare_slots(slot, slots[r[0]]):
create_slots[r[0]] = slot
copy_slots[r[0]] = slot
else:
logger.warning('Will not copy the logical slot "%s" due to the configuration mismatch: '
'configuration=%s, slot on the primary=%s', r[0], slots[r[0]], slot)
except Exception as e:
logger.error("Failed to copy logical slots from the %s via postgresql connection: %r", leader.name, e)
if isinstance(create_slots, dict) and create_slots and self._postgresql.stop():
for name, value in create_slots.items():
if copy_slots and self._postgresql.stop():
for name, value in copy_slots.items():
slot_dir = os.path.join(self._postgresql.slots_handler.pg_replslot_dir, name)
slot_tmp_dir = slot_dir + '.tmp'
if os.path.exists(slot_tmp_dir):
@@ -403,12 +421,12 @@ class SlotsHandler(object):
fsync_dir(self._postgresql.slots_handler.pg_replslot_dir)
self._postgresql.start()
def schedule(self, value=None):
def schedule(self, value: Optional[bool] = None) -> None:
if value is None:
value = self._postgresql.major_version >= 90400
self._schedule_load_slots = self._force_readiness_check = value
def on_promote(self):
def on_promote(self) -> None:
if self._advance:
self._advance.on_promote()

View File

@@ -3,7 +3,7 @@ import re
import time
from copy import deepcopy
from typing import Any, Collection, Dict, Tuple, TYPE_CHECKING
from typing import Any, Collection, Dict, List, Tuple, TYPE_CHECKING, Union
from ..collections import CaseInsensitiveDict, CaseInsensitiveSet
from ..dcs import Cluster
@@ -109,7 +109,7 @@ def parse_sync_standby_names(value: str) -> Dict[str, Any]:
result = {'type': 'priority', 'num': int(tokens[0][1])}
synclist = tokens[2:-1]
else:
result = {'type': 'priority', 'num': 1}
result: Dict[str, Union[int, str, CaseInsensitiveSet]] = {'type': 'priority', 'num': 1}
synclist = tokens
result['members'] = CaseInsensitiveSet()
for i, (a_type, a_value, a_pos) in enumerate(synclist):
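# Expected results for typical inputs (a sketch; 'members' is the
# CaseInsensitiveSet filled by the loop above):
#   'FIRST 2 (a, b)' -> {'type': 'priority', 'num': 2, 'members': {a, b}}
#   'a, b'           -> {'type': 'priority', 'num': 1, 'members': {a, b}}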
@@ -149,12 +149,12 @@ class SyncHandler(object):
# "sync" replication connections, that were verified to reach self._primary_flush_lsn at some point
self._ready_replicas = CaseInsensitiveDict({}) # keys: member names, values: connection pids
def _handle_synchronous_standby_names_change(self):
def _handle_synchronous_standby_names_change(self) -> None:
"""If synchronous_standby_names has changed we need to check that newly added replicas
have reached self._primary_flush_lsn. Only after that they could be counted as sync."""
synchronous_standby_names = self._postgresql.synchronous_standby_names()
if synchronous_standby_names == self._synchronous_standby_names:
return False
return
self._synchronous_standby_names = synchronous_standby_names
try:
@@ -165,7 +165,7 @@ class SyncHandler(object):
# Invalidate cache of "sync" connections
for app_name in list(self._ready_replicas.keys()):
if app_name not in self._ssn_data['members']:
if isinstance(self._ssn_data['members'], CaseInsensitiveSet) and app_name not in self._ssn_data['members']:
del self._ready_replicas[app_name]
# Newly connected replicas will be counted as sync only once they have reached self._primary_flush_lsn
@@ -201,7 +201,7 @@ class SyncHandler(object):
if r[sort_col] is not None]
members = CaseInsensitiveDict({m.name: m for m in cluster.members})
replica_list = []
replica_list: List[Tuple[int, str, str, int, bool]] = []
# pg_stat_replication.sync_state has 4 possible states - async, potential, quorum, sync.
# That is, alphabetically they are in the reversed order of priority.
# Since we are doing reversed sort on (sync_state, lsn) tuples, it helps to keep the result
@@ -227,7 +227,8 @@ class SyncHandler(object):
for pid, app_name, sync_state, replica_lsn, _ in sorted(replica_list, key=lambda x: x[4]):
# if standby name is listed in the /sync key we can count it as synchronous, otherwise
# it becomes really synchronous when sync_state = 'sync' and it is known that it managed to catch up
if app_name not in self._ready_replicas and app_name in self._ssn_data['members'] and\
if app_name not in self._ready_replicas and isinstance(self._ssn_data['members'], CaseInsensitiveSet)\
and app_name in self._ssn_data['members'] and\
(cluster.sync.matches(app_name) or sync_state == 'sync' and replica_lsn >= self._primary_flush_lsn):
self._ready_replicas[app_name] = pid

View File

@@ -1,7 +1,7 @@
import abc
import logging
from collections import namedtuple
from typing import Any, MutableMapping, Optional, Tuple, Union
from ..collections import CaseInsensitiveDict
from ..utils import parse_bool, parse_int, parse_real
@@ -9,23 +9,65 @@ from ..utils import parse_bool, parse_int, parse_real
logger = logging.getLogger(__name__)
class Bool(namedtuple('Bool', 'version_from,version_till')):
class _Transformable(abc.ABC):
@staticmethod
def transform(name, value):
def __init__(self, version_from: int, version_till: Optional[int]) -> None:
self.__version_from = version_from
self.__version_till = version_till
@property
def version_from(self) -> int:
return self.__version_from
@property
def version_till(self) -> Optional[int]:
return self.__version_till
@abc.abstractmethod
def transform(self, name: str, value: Any) -> Optional[Any]:
"""Verify that provided value is valid.
:param name: GUC's name
:param value: GUC's value
:returns: the value (sometimes clamped) or ``None`` if the value isn't valid
"""
class Bool(_Transformable):
def transform(self, name: str, value: Any) -> Optional[Any]:
if parse_bool(value) is not None:
return value
logger.warning('Removing bool parameter=%s from the config due to the invalid value=%s', name, value)
class Number(abc.ABC, namedtuple('Number', 'version_from,version_till,min_val,max_val,unit')):
class Number(_Transformable):
def __init__(self, version_from: int, version_till: Optional[int],
min_val: Union[int, float], max_val: Union[int, float], unit: Optional[str]) -> None:
super(Number, self).__init__(version_from, version_till)
self.__min_val = min_val
self.__max_val = max_val
self.__unit = unit
@property
def min_val(self) -> Union[int, float]:
return self.__min_val
@property
def max_val(self) -> Union[int, float]:
return self.__max_val
@property
def unit(self) -> Optional[str]:
return self.__unit
@staticmethod
@abc.abstractmethod
def parse(value, unit):
"""parse value"""
def parse(value: Any, unit: Optional[str]) -> Optional[Any]:
"""Convert provided value to unit."""
def transform(self, name, value):
def transform(self, name: str, value: Any) -> Union[int, float, None]:
num_value = self.parse(value, self.unit)
if num_value is not None:
if num_value < self.min_val:
@@ -44,20 +86,28 @@ class Number(abc.ABC, namedtuple('Number', 'version_from,version_till,min_val,ma
class Integer(Number):
@staticmethod
def parse(value, unit):
def parse(value: Any, unit: Optional[str]) -> Optional[int]:
return parse_int(value, unit)
class Real(Number):
@staticmethod
def parse(value, unit):
def parse(value: Any, unit: Optional[str]) -> Optional[float]:
return parse_real(value, unit)
class Enum(namedtuple('Enum', 'version_from,version_till,possible_values')):
class Enum(_Transformable):
def transform(self, name, value):
def __init__(self, version_from: int, version_till: Optional[int], possible_values: Tuple[str, ...]) -> None:
super(Enum, self).__init__(version_from, version_till)
self.__possible_values = possible_values
@property
def possible_values(self) -> Tuple[str, ...]:
return self.__possible_values
def transform(self, name: str, value: Optional[Any]) -> Optional[Any]:
if str(value).lower() in self.possible_values:
return value
logger.warning('Removing enum parameter=%s from the config due to the invalid value=%s', name, value)
@@ -65,16 +115,15 @@ class Enum(namedtuple('Enum', 'version_from,version_till,possible_values')):
class EnumBool(Enum):
def transform(self, name, value):
def transform(self, name: str, value: Optional[Any]) -> Optional[Any]:
if parse_bool(value) is not None:
return value
return super(EnumBool, self).transform(name, value)
class String(namedtuple('String', 'version_from,version_till')):
class String(_Transformable):
@staticmethod
def transform(name, value):
def transform(self, name: str, value: Optional[Any]) -> Optional[Any]:
return value
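# Hedged usage sketch of the validator classes above (constructor signatures
# as defined in this module; version bounds are PostgreSQL version integers):
int_mb = Integer(version_from=90300, version_till=None, min_val=1, max_val=10, unit='MB')
assert int_mb.parse('5MB', int_mb.unit) == 5
assert Bool(90300, None).transform('hot_standby', 'on') == 'on'
assert Enum(90300, None, ('replica', 'logical')).transform('wal_level', 'REPLICA') == 'REPLICA'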
@@ -511,17 +560,18 @@ recovery_parameters = CaseInsensitiveDict({
})
def _transform_parameter_value(validators, version, name, value):
validators = validators.get(name)
if validators:
for validator in (validators if isinstance(validators[0], tuple) else [validators]):
def _transform_parameter_value(validators: MutableMapping[str, Union[_Transformable, Tuple[_Transformable, ...]]],
version: int, name: str, value: Any) -> Optional[Any]:
name_validators = validators.get(name)
if name_validators:
for validator in (name_validators if isinstance(name_validators, tuple) else (name_validators,)):
if version >= validator.version_from and\
(validator.version_till is None or version < validator.version_till):
return validator.transform(name, value)
logger.warning('Removing unexpected parameter=%s value=%s from the config', name, value)
def transform_postgresql_parameter_value(version, name, value):
def transform_postgresql_parameter_value(version: int, name: str, value: Any) -> Optional[Any]:
if '.' in name:
return value
if name in recovery_parameters:
@@ -529,5 +579,5 @@ def transform_postgresql_parameter_value(version, name, value):
return _transform_parameter_value(parameters, version, name, value)
def transform_recovery_parameter_value(version, name, value):
def transform_recovery_parameter_value(version: int, name: str, value: Any) -> Optional[Any]:
return _transform_parameter_value(recovery_parameters, version, name, value)
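# e.g. transform_postgresql_parameter_value(150000, 'wal_level', 'replica')
# dispatches to the validator registered for 'wal_level' and returns
# 'replica'; a value no validator accepts is dropped with a warning (None).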

View File

@@ -3,8 +3,10 @@
This module is able to handle both ``pyscopg2`` and ``psycopg3``, and it exposes a common interface for both.
``psycopg2`` takes precedence. ``psycopg3`` will only be used if ``psycopg2`` is either absent or older than ``2.5.4``.
"""
from typing import Any, Optional
from typing import Any, Optional, TYPE_CHECKING, Union
if TYPE_CHECKING: # pragma: no cover
from psycopg import Connection
from psycopg2 import connection, cursor
__all__ = ['connect', 'quote_ident', 'quote_literal', 'DatabaseError', 'Error', 'OperationalError', 'ProgrammingError']
@@ -39,9 +41,9 @@ try:
value.prepare(conn)
return value.getquoted().decode('utf-8')
except ImportError:
from psycopg import connect as __connect, sql, Error, DatabaseError, OperationalError, ProgrammingError, Connection
from psycopg import connect as __connect, sql, Error, DatabaseError, OperationalError, ProgrammingError
def _connect(dsn: str = "", **kwargs: Any) -> Any:
def _connect(dsn: Optional[str] = None, **kwargs: Any) -> 'Connection[Any]':
"""Call ``psycopg.connect`` with ``dsn`` and ``**kwargs``.
.. note::
@@ -53,19 +55,19 @@ except ImportError:
:returns: a connection to the database.
"""
ret = __connect(dsn, **kwargs)
ret = __connect(dsn or "", **kwargs)
setattr(ret, 'server_version', ret.pgconn.server_version) # compatibility with psycopg2
return ret
def _quote_ident(value: Any, conn: Connection) -> str:
def _quote_ident(value: Any, scope: Any) -> str:
"""Quote *value* as a SQL identifier.
:param value: value to be quoted.
:param conn: connection to evaluate the returning string into.
:param scope: connection to evaluate the returning string into.
:returns: *value* quoted as a SQL identifier.
"""
return sql.Identifier(value).as_string(conn)
return sql.Identifier(value).as_string(scope)
def quote_literal(value: Any, conn: Optional[Any] = None) -> str:
"""Quote *value* as a SQL literal.
@@ -78,7 +80,7 @@ except ImportError:
return sql.Literal(value).as_string(conn)
def connect(*args: Any, **kwargs: Any) -> Any:
def connect(*args: Any, **kwargs: Any) -> Union['connection', 'Connection[Any]']:
"""Get a connection to the database.
.. note::
@@ -102,7 +104,7 @@ def connect(*args: Any, **kwargs: Any) -> Any:
return ret
def quote_ident(value: Any, conn: Optional[Any] = None) -> str:
def quote_ident(value: Any, conn: Optional[Union['cursor', 'connection', 'Connection[Any]']] = None) -> str:
"""Quote *value* as a SQL identifier.
:param value: value to be quoted.

View File

@@ -1,5 +1,6 @@
import logging
from .config import Config
from .daemon import AbstractPatroniDaemon, abstract_main
from .dcs.raft import KVStoreTTL
@@ -8,22 +9,22 @@ logger = logging.getLogger(__name__)
class RaftController(AbstractPatroniDaemon):
def __init__(self, config):
def __init__(self, config: Config) -> None:
super(RaftController, self).__init__(config)
config = self.config.get('raft')
assert 'self_addr' in config
self._raft = KVStoreTTL(None, None, None, **config)
kvstore_config = self.config.get('raft')
assert 'self_addr' in kvstore_config
self._raft = KVStoreTTL(None, None, None, **kvstore_config)
def _run_cycle(self):
def _run_cycle(self) -> None:
try:
self._raft.doTick(self._raft.conf.autoTickPeriod)
except Exception:
logger.exception('doTick')
def _shutdown(self):
def _shutdown(self) -> None:
self._raft.destroy()
def main():
def main() -> None:
abstract_main(RaftController)

View File

@@ -140,14 +140,14 @@ class PatroniRequest(object):
:returns: the response returned upon request.
"""
url = member.api_url
url = member.api_url or ''
if endpoint:
scheme, netloc, _, _, _, _ = urlparse(url)
url = urlunparse((scheme, netloc, endpoint, '', '', ''))
return self.request(method, url, data, **kwargs)
def get(url: str, verify: Optional[bool] = True, **kwargs: Any) -> urllib3.response.HTTPResponse:
def get(url: str, verify: bool = True, **kwargs: Any) -> urllib3.response.HTTPResponse:
"""Perform an HTTP GET request.
.. note::
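
For reference, the endpoint rebuild shown in this hunk boils down to the following sketch (the member URL is a hypothetical value):

```python
# Keep scheme and host from the member's api_url, replace the path with the
# requested endpoint; the `or ''` guard covers members without an api_url.
from urllib.parse import urlparse, urlunparse

url = 'https://10.0.0.1:8008/patroni'  # hypothetical member.api_url
scheme, netloc, _, _, _, _ = urlparse(url)
print(urlunparse((scheme, netloc, '/health', '', '', '')))  # https://10.0.0.1:8008/health
```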

View File

@@ -5,20 +5,22 @@ import logging
import sys
import boto3
from ..utils import Retry, RetryFailedError
from botocore.exceptions import ClientError
from botocore.utils import IMDSFetcher
from typing import Any, Optional
from ..utils import Retry, RetryFailedError
logger = logging.getLogger(__name__)
class AWSConnection(object):
def __init__(self, cluster_name):
def __init__(self, cluster_name: Optional[str]) -> None:
self.available = False
self.cluster_name = cluster_name if cluster_name is not None else 'unknown'
self._retry = Retry(deadline=300, max_delay=30, max_tries=-1, retry_exceptions=(ClientError,))
self._retry = Retry(deadline=300, max_delay=30, max_tries=-1, retry_exceptions=ClientError)
try:
# get the instance id
fetcher = IMDSFetcher(timeout=2.1)
@@ -38,13 +40,13 @@ class AWSConnection(object):
return
self.available = True
def retry(self, *args, **kwargs):
def retry(self, *args: Any, **kwargs: Any) -> Any:
return self._retry.copy()(*args, **kwargs)
def aws_available(self):
def aws_available(self) -> bool:
return self.available
def _tag_ebs(self, conn, role):
def _tag_ebs(self, conn: Any, role: str) -> None:
""" set tags, carrying the cluster name, instance role and instance id for the EBS storage """
tags = [{'Key': 'Name', 'Value': 'spilo_' + self.cluster_name},
{'Key': 'Role', 'Value': role},
@@ -52,16 +54,16 @@ class AWSConnection(object):
volumes = conn.volumes.filter(Filters=[{'Name': 'attachment.instance-id', 'Values': [self.instance_id]}])
conn.create_tags(Resources=[v.id for v in volumes], Tags=tags)
def _tag_ec2(self, conn, role):
def _tag_ec2(self, conn: Any, role: str) -> None:
""" tag the current EC2 instance with a cluster role """
tags = [{'Key': 'Role', 'Value': role}]
conn.create_tags(Resources=[self.instance_id], Tags=tags)
def on_role_change(self, new_role):
def on_role_change(self, new_role: str) -> bool:
if not self.available:
return False
try:
conn = boto3.resource('ec2', region_name=self.region)
conn = boto3.resource('ec2', region_name=self.region) # type: ignore
self.retry(self._tag_ec2, conn, new_role)
self.retry(self._tag_ebs, conn, new_role)
except RetryFailedError:
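
The retry_exceptions change is purely a typing correction; a hedged sketch of the equivalent call:

```python
# Retry accepts an exception class or a tuple of classes, so the bare
# ClientError is equivalent to the old one-element tuple (ClientError,).
from botocore.exceptions import ClientError
from patroni.utils import Retry

retry = Retry(deadline=300, max_delay=30, max_tries=-1, retry_exceptions=ClientError)
# retry(func, *args) would re-invoke func while it keeps raising ClientError
```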

View File

@@ -31,7 +31,8 @@ import subprocess
import sys
import time
from collections import namedtuple
from enum import IntEnum
from typing import Any, List, NamedTuple, Optional, Tuple
from .. import psycopg
@@ -42,15 +43,14 @@ si_prefixes = ['K', 'M', 'G', 'T', 'P', 'E', 'Z', 'Y']
# Meaningful names to the exit codes used by WALERestore
ExitCode = type('Enum', (), {
'SUCCESS': 0, #: Succeeded
'RETRY_LATER': 1, #: External issue, retry later
'FAIL': 2 #: Don't try again unless configuration changes
})
class ExitCode(IntEnum):
SUCCESS = 0 #: Succeeded
RETRY_LATER = 1 #: External issue, retry later
FAIL = 2 #: Don't try again unless configuration changes
# We need to know the current PG version in order to figure out the correct WAL directory name
def get_major_version(data_dir):
def get_major_version(data_dir: str) -> float:
version_file = os.path.join(data_dir, 'PG_VERSION')
if os.path.isfile(version_file): # version file exists
try:
@@ -61,7 +61,7 @@ def get_major_version(data_dir):
return 0.0
def repr_size(n_bytes):
def repr_size(n_bytes: float) -> str:
"""
>>> repr_size(1000)
'1000 Bytes'
@@ -77,7 +77,7 @@ def repr_size(n_bytes):
return '{0} {1}iB'.format(round(n_bytes, 1), si_prefixes[i])
def size_as_bytes(size_, prefix):
def size_as_bytes(size: float, prefix: str) -> int:
"""
>>> size_as_bytes(7.5, 'T')
8246337208320
@@ -88,23 +88,19 @@ def size_as_bytes(size_, prefix):
exponent = si_prefixes.index(prefix) + 1
return int(size_ * (1024.0 ** exponent))
return int(size * (1024.0 ** exponent))
WALEConfig = namedtuple(
'WALEConfig',
[
'env_dir',
'threshold_mb',
'threshold_pct',
'cmd',
]
)
class WALEConfig(NamedTuple):
env_dir: str
threshold_mb: int
threshold_pct: int
cmd: List[str]
class WALERestore(object):
def __init__(self, scope, datadir, connstring, env_dir, threshold_mb,
threshold_pct, use_iam, no_leader, retries):
def __init__(self, scope: str, datadir: str, connstring: str, env_dir: str, threshold_mb: int,
threshold_pct: int, use_iam: int, no_leader: bool, retries: int) -> None:
self.scope = scope
self.leader_connection = connstring
self.data_dir = datadir
@@ -129,7 +125,7 @@ class WALERestore(object):
self.init_error = (not os.path.exists(self.wal_e.env_dir))
self.retries = retries
def run(self):
def run(self) -> int:
"""
Creates a new replica using WAL-E
@@ -158,7 +154,7 @@ class WALERestore(object):
logger.exception("Unhandled exception when running WAL-E restore")
return ExitCode.FAIL
def should_use_s3_to_create_replica(self):
def should_use_s3_to_create_replica(self) -> Optional[bool]:
""" determine whether it makes sense to use S3 and not pg_basebackup """
threshold_megabytes = self.wal_e.threshold_mb
@@ -218,7 +214,7 @@ class WALERestore(object):
try:
# get the difference in bytes between the current WAL location and the backup start offset
con = psycopg.connect(self.leader_connection)
if con.server_version >= 100000:
if getattr(con, 'server_version', 0) >= 100000:
wal_name = 'wal'
lsn_name = 'lsn'
else:
@@ -232,8 +228,9 @@ class WALERestore(object):
" ELSE pg_catalog.pg_{0}_{1}_diff(pg_catalog.pg_current_{0}_{1}(), %s)::bigint"
" END").format(wal_name, lsn_name),
(backup_start_lsn, backup_start_lsn, backup_start_lsn))
diff_in_bytes = int(cur.fetchone()[0])
for row in cur:
diff_in_bytes = int(row[0])
break
except psycopg.Error:
logger.exception('could not determine difference with the leader location')
if attempts_no < self.retries: # retry in case of a temporarily connection issue
@@ -262,11 +259,11 @@ class WALERestore(object):
are_thresholds_ok = is_size_thresh_ok and is_percentage_thresh_ok
class Size(object):
def __init__(self, n_bytes, prefix=None):
def __init__(self, n_bytes: float, prefix: Optional[str] = None) -> None:
self.n_bytes = n_bytes
self.prefix = prefix
def __repr__(self):
def __repr__(self) -> str:
if self.prefix is not None:
n_bytes = size_as_bytes(self.n_bytes, self.prefix)
else:
@@ -274,10 +271,10 @@ class WALERestore(object):
return repr_size(n_bytes)
class HumanContext(object):
def __init__(self, items):
def __init__(self, items: List[Tuple[str, Any]]) -> None:
self.items = items
def __repr__(self):
def __repr__(self) -> str:
return ', '.join('{}={!r}'.format(key, value)
for key, value in self.items)
@@ -298,7 +295,7 @@ class WALERestore(object):
logger.info('Thresholds are OK, using wal-e basebackup: %s', human_context)
return are_thresholds_ok
def fix_subdirectory_path_if_broken(self, dirname):
def fix_subdirectory_path_if_broken(self, dirname: str) -> bool:
# in case it is a symlink pointing to a non-existing location, remove it and create the actual directory
path = os.path.join(self.data_dir, dirname)
if not os.path.exists(path):
@@ -316,7 +313,7 @@ class WALERestore(object):
return False
return True
def create_replica_with_s3(self):
def create_replica_with_s3(self) -> int:
# if we're set up, restore the replica using fetch latest
try:
cmd = self.wal_e.cmd + ['backup-fetch',
@@ -334,7 +331,7 @@ class WALERestore(object):
return exit_code
def main():
def main() -> int:
logging.basicConfig(format='%(asctime)s %(levelname)s: %(message)s', level=logging.INFO)
parser = argparse.ArgumentParser(description='Script to image replicas using WAL-E')
parser.add_argument('--scope', required=True)
@@ -363,11 +360,12 @@ def main():
threshold_pct=args.threshold_backup_size_percentage, use_iam=args.use_iam,
no_leader=args.no_leader, retries=args.retries)
exit_code = restore.run()
if not exit_code == ExitCode.RETRY_LATER: # only WAL-E failures lead to the retry
if exit_code != ExitCode.RETRY_LATER: # only WAL-E failures lead to the retry
logger.debug('exit_code is %r, not retrying', exit_code)
break
time.sleep(RETRY_SLEEP_INTERVAL)
assert exit_code is not None
return exit_code
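
Replacing the ad-hoc `type('Enum', ...)` class with IntEnum keeps all integer comparisons working; a minimal sketch of why:

```python
# IntEnum members compare equal to plain ints, so sys.exit(ExitCode.FAIL)
# and `exit_code != ExitCode.RETRY_LATER` behave as before while the members
# gain names for logging.
from enum import IntEnum

class ExitCode(IntEnum):
    SUCCESS = 0
    RETRY_LATER = 1
    FAIL = 2

assert ExitCode.FAIL == 2
print(ExitCode(1).name)  # RETRY_LATER
```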

View File

@@ -7,9 +7,9 @@
:var DEC_RE: regular expression to match decimal numbers, signed or unsigned.
:var HEX_RE: regular expression to match hex strings, signed or unsigned.
:var DBL_RE: regular expression to match double precision numbers, signed or unsigned. Matches scientific notation too.
:var WHITESPACE_RE: regular expression to match whitespace characters.
"""
import errno
import json.decoder as json_decoder
import logging
import os
import platform
@@ -21,9 +21,10 @@ import tempfile
import time
from shlex import split
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union, TYPE_CHECKING
from typing import Any, Callable, Dict, Iterator, List, Optional, Union, Tuple, Type, TYPE_CHECKING
from dateutil import tz
from json import JSONDecoder
from urllib3.response import HTTPResponse
from .exceptions import PatroniException
@@ -42,9 +43,10 @@ OCT_RE = re.compile(r'^[-+]?0[0-7]*')
DEC_RE = re.compile(r'^[-+]?(0|[1-9][0-9]*)')
HEX_RE = re.compile(r'^[-+]?0x[0-9a-fA-F]+')
DBL_RE = re.compile(r'^[-+]?[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?')
WHITESPACE_RE = re.compile(r'[ \t\n\r]*', re.VERBOSE | re.MULTILINE | re.DOTALL)
def deep_compare(obj1: Dict, obj2: Dict) -> bool:
def deep_compare(obj1: Dict[Any, Union[Any, Dict[Any, Any]]], obj2: Dict[Any, Union[Any, Dict[Any, Any]]]) -> bool:
"""Recursively compare two dictionaries to check if they are equal in terms of keys and values.
.. note::
@@ -84,7 +86,7 @@ def deep_compare(obj1: Dict, obj2: Dict) -> bool:
return True
def patch_config(config: Dict, data: Dict) -> bool:
def patch_config(config: Dict[Any, Union[Any, Dict[Any, Any]]], data: Dict[Any, Union[Any, Dict[Any, Any]]]) -> bool:
"""Update and append to dictionary *config* from overrides in *data*.
.. note::
@@ -239,7 +241,7 @@ def strtod(value: Any) -> Tuple[Union[float, None], str]:
return None, value
def convert_to_base_unit(value: Union[int, float], unit: str, base_unit: str) -> Union[int, float, None]:
def convert_to_base_unit(value: Union[int, float], unit: str, base_unit: Optional[str]) -> Union[int, float, None]:
"""Convert *value* as a *unit* of compute information or time to *base_unit*.
:param value: value to be converted to the base unit.
@@ -268,7 +270,7 @@ def convert_to_base_unit(value: Union[int, float], unit: str, base_unit: str) ->
>>> convert_to_base_unit(1, 'GB', '512 MB') is None
True
"""
convert = {
convert: Dict[str, Dict[str, Union[int, float]]] = {
'B': {'B': 1, 'kB': 1024, 'MB': 1024 * 1024, 'GB': 1024 * 1024 * 1024, 'TB': 1024 * 1024 * 1024 * 1024},
'kB': {'B': 1.0 / 1024, 'kB': 1, 'MB': 1024, 'GB': 1024 * 1024, 'TB': 1024 * 1024 * 1024},
'MB': {'B': 1.0 / (1024 * 1024), 'kB': 1.0 / 1024, 'MB': 1, 'GB': 1024, 'TB': 1024 * 1024},
@@ -287,7 +289,7 @@ def convert_to_base_unit(value: Union[int, float], unit: str, base_unit: str) ->
else:
base_value = 1
if base_unit in convert and unit in convert[base_unit]:
if base_value is not None and base_unit in convert and unit in convert[base_unit]:
value *= convert[base_unit][unit] / float(base_value)
if unit in round_order:
@@ -297,7 +299,7 @@ def convert_to_base_unit(value: Union[int, float], unit: str, base_unit: str) ->
return value
def parse_int(value: Any, base_unit: Optional[str] = None) -> Union[int, None]:
def parse_int(value: Any, base_unit: Optional[str] = None) -> Optional[int]:
"""Parse *value* as an :class:`int`.
:param value: any value that can be handled either by :func:`strtol` or :func:`strtod`. If *value* contains a
@@ -350,7 +352,7 @@ def parse_int(value: Any, base_unit: Optional[str] = None) -> Union[int, None]:
return round(val)
def parse_real(value: Any, base_unit: Optional[str] = None) -> Union[float, None]:
def parse_real(value: Any, base_unit: Optional[str] = None) -> Optional[float]:
"""Parse *value* as a :class:`float`.
:param value: any value that can be handled by :func:`strtod`. If *value* contains a unit, then *base_unit* must
@@ -381,7 +383,7 @@ def parse_real(value: Any, base_unit: Optional[str] = None) -> Union[float, None
return convert_to_base_unit(val, unit, base_unit)
def compare_values(vartype: str, unit: str, old_value: Any, new_value: Any) -> bool:
def compare_values(vartype: str, unit: Optional[str], old_value: Any, new_value: Any) -> bool:
"""Check if *old_value* and *new_value* are equivalent after parsing them as *vartype*.
:param vartype: the target type to parse *old_value* and *new_value* before comparing them. Accepts any of the
@@ -426,7 +428,7 @@ def compare_values(vartype: str, unit: str, old_value: Any, new_value: Any) -> b
>>> compare_values('integer', 'kB', 4098, '4097.5kB')
True
"""
converters = {
converters: Dict[str, Callable[[str, Optional[str]], Union[None, bool, int, float, str]]] = {
'bool': lambda v1, v2: parse_bool(v1),
'integer': parse_int,
'real': parse_real,
@@ -434,11 +436,11 @@ def compare_values(vartype: str, unit: str, old_value: Any, new_value: Any) -> b
'string': lambda v1, v2: str(v1)
}
convert = converters.get(vartype) or converters['string']
old_value = convert(old_value, None)
new_value = convert(new_value, unit)
converter = converters.get(vartype) or converters['string']
old_converted = converter(old_value, None)
new_converted = converter(new_value, unit)
return old_value is not None and new_value is not None and old_value == new_value
return old_converted is not None and new_converted is not None and old_converted == new_converted
def _sleep(interval: Union[int, float]) -> None:
@@ -467,11 +469,11 @@ class Retry(object):
:ivar retry_exceptions: single exception or tuple
"""
def __init__(self, max_tries: Optional[int] = 1, delay: Optional[float] = 0.1, backoff: Optional[int] = 2,
max_jitter: Optional[float] = 0.8, max_delay: Optional[int] = 3600,
sleep_func: Optional[Callable[[Union[int, float]], None]] = _sleep,
def __init__(self, max_tries: Optional[int] = 1, delay: float = 0.1, backoff: int = 2,
max_jitter: float = 0.8, max_delay: int = 3600,
sleep_func: Callable[[Union[int, float]], None] = _sleep,
deadline: Optional[Union[int, float]] = None,
retry_exceptions: Optional[Union[Exception, Tuple[Exception]]] = PatroniException) -> None:
retry_exceptions: Union[Type[Exception], Tuple[Type[Exception], ...]] = PatroniException) -> None:
"""Create a :class:`Retry` instance for retrying function calls.
:param max_tries: how many times to retry the command. ``-1`` means infinite tries.
@@ -504,7 +506,7 @@ class Retry(object):
def copy(self) -> 'Retry':
"""Return a clone of this retry manager."""
return Retry(max_tries=self.max_tries, delay=self.delay, backoff=self.backoff,
max_jitter=self.max_jitter / 100.0, max_delay=self.max_delay, sleep_func=self.sleep_func,
max_jitter=self.max_jitter / 100.0, max_delay=int(self.max_delay), sleep_func=self.sleep_func,
deadline=self.deadline, retry_exceptions=self.retry_exceptions)
@property
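
A hedged usage sketch of Retry with the tightened signature (the retried callable is arbitrary):

```python
# retry_exceptions is now typed as an exception class or a tuple of classes;
# the callable and its arguments are passed straight through.
from patroni.utils import Retry, RetryFailedError

retry = Retry(max_tries=3, deadline=10, retry_exceptions=(OSError, ConnectionError))
try:
    retry(print, 'attempt')  # any callable plus its arguments
except RetryFailedError:
    pass  # raised once max_tries or the deadline is exhausted
```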
@@ -525,11 +527,11 @@ class Retry(object):
self._cur_delay = min(self._cur_delay * self.backoff, self.max_delay)
@property
def stoptime(self) -> Union[float, None]:
def stoptime(self) -> float:
"""Get the current stop time."""
return self._cur_stoptime
return self._cur_stoptime or 0
def __call__(self, func: Callable, *args: Any, **kwargs: Any) -> Any:
def __call__(self, func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
"""Call a function *func* with arguments ``*args`` and ``*kwargs`` in a loop.
*func* will be called until one of the following conditions is met:
@@ -561,7 +563,9 @@ class Retry(object):
logger.warning('Retry got exception: %s', e)
raise RetryFailedError("Too many retry attempts")
self._attempts += 1
sleeptime = hasattr(e, 'sleeptime') and e.sleeptime or self.sleeptime
sleeptime = getattr(e, 'sleeptime', None)
if not isinstance(sleeptime, (int, float)):
sleeptime = self.sleeptime
if self._cur_stoptime is not None and time.time() + sleeptime >= self._cur_stoptime:
logger.warning('Retry got exception: %s', e)
@@ -571,7 +575,7 @@ class Retry(object):
self.update_delay()
def polling_loop(timeout: Union[int, float], interval: Optional[Union[int, float]] = 1) -> Iterator[int]:
def polling_loop(timeout: Union[int, float], interval: Union[int, float] = 1) -> Iterator[int]:
"""Return an iterator that returns values every *interval* seconds until *timeout* has passed.
.. note::
@@ -587,10 +591,10 @@ def polling_loop(timeout: Union[int, float], interval: Optional[Union[int, float
while time.time() < end_time:
yield iteration
iteration += 1
time.sleep(interval)
time.sleep(float(interval))
def split_host_port(value: str, default_port: int) -> Tuple[str, int]:
def split_host_port(value: str, default_port: Optional[int]) -> Tuple[str, int]:
"""Extract host(s) and port from *value*.
:param value: string from where host(s) and port will be extracted. Accepts either of these formats
@@ -625,11 +629,11 @@ def split_host_port(value: str, default_port: int) -> Tuple[str, int]:
# If *value* contains ``:`` we consider it to be an IPv6 address, so we attempt to remove possible square brackets
if ':' in t[0]:
t[0] = ','.join([h.strip().strip('[]') for h in t[0].split(',')])
t.append(default_port)
t.append(str(default_port))
return t[0], int(t[1])
def uri(proto: str, netloc: Union[List, Tuple[str, int], str], path: Optional[str] = '',
def uri(proto: str, netloc: Union[List[str], Tuple[str, Union[int, str]], str], path: Optional[str] = '',
user: Optional[str] = None) -> str:
"""Construct URI from given arguments.
@@ -670,17 +674,15 @@ def iter_response_objects(response: HTTPResponse) -> Iterator[Dict[str, Any]]:
:rtype: Iterator[:class:`dict`] with current JSON document.
"""
prev = ''
decoder = json_decoder.JSONDecoder()
decoder = JSONDecoder()
for chunk in response.read_chunked(decode_content=False):
if isinstance(chunk, bytes):
chunk = chunk.decode('utf-8')
chunk = prev + chunk
chunk = prev + chunk.decode('utf-8')
length = len(chunk)
# ``chunk`` is analyzed in parts. ``idx`` holds the position of the first character in the current part that is
# neither space nor tab nor line-break, or in other words, the position in the ``chunk`` where it is likely
# that a JSON document begins
idx = json_decoder.WHITESPACE.match(chunk, 0).end()
idx = WHITESPACE_RE.match(chunk, 0).end() # pyright: ignore [reportOptionalMemberAccess]
while idx < length:
try:
# Get a JSON document from the chunk. ``message`` is a dictionary representing the JSON document, and
@@ -690,7 +692,7 @@ def iter_response_objects(response: HTTPResponse) -> Iterator[Dict[str, Any]]:
break
else:
yield message
idx = json_decoder.WHITESPACE.match(chunk, idx).end()
idx = WHITESPACE_RE.match(chunk, idx).end() # pyright: ignore [reportOptionalMemberAccess]
# It is not usual that a ``chunk`` would contain more than one JSON document, but we handle that just in case
prev = chunk[idx:]
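
The WHITESPACE_RE replacement for the private json.decoder.WHITESPACE can be exercised in isolation; a minimal sketch of the streaming parse loop above:

```python
# raw_decode() consumes one JSON document from the buffer and returns the end
# offset; WHITESPACE_RE skips separators. The regex always matches (possibly
# zero characters), hence the pyright ignore on Optional in the real code.
import re
from json import JSONDecoder

WHITESPACE_RE = re.compile(r'[ \t\n\r]*')
decoder = JSONDecoder()
chunk = '{"a": 1}\n{"b": 2}'
idx = WHITESPACE_RE.match(chunk, 0).end()
while idx < len(chunk):
    message, idx = decoder.raw_decode(chunk, idx)
    print(message)
    idx = WHITESPACE_RE.match(chunk, idx).end()
```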
@@ -732,7 +734,7 @@ def cluster_as_json(cluster: 'Cluster', global_config: Optional['GlobalConfig']
leader_name = cluster.leader.name if cluster.leader else None
cluster_lsn = cluster.last_lsn or 0
ret = {'members': []}
ret: Dict[str, Any] = {'members': []}
for m in cluster.members:
if m.name == leader_name:
role = 'standby_leader' if global_config.is_standby_cluster else 'leader'
@@ -762,7 +764,8 @@ def cluster_as_json(cluster: 'Cluster', global_config: Optional['GlobalConfig']
ret['members'].append(member)
# sort members by name for consistency
ret['members'].sort(key=lambda m: m['name'])
cmp: Callable[[Dict[str, Any]], str] = lambda m: m['name']
ret['members'].sort(key=cmp)
if global_config.is_paused:
ret['pause'] = True
if cluster.failover and cluster.failover.scheduled_at:
@@ -791,7 +794,7 @@ def is_subpath(d1: str, d2: str) -> bool:
return os.path.commonprefix([real_d1, real_d2 + os.path.sep]) == real_d1
def validate_directory(d: str, msg: Optional[str] = "{} {}") -> None:
def validate_directory(d: str, msg: str = "{} {}") -> None:
"""Ensure directory exists and is writable.
.. note::
@@ -842,7 +845,7 @@ def data_directory_is_empty(data_dir: str) -> bool:
return all(os.name != 'nt' and (n.startswith('.') or n == 'lost+found') for n in os.listdir(data_dir))
def keepalive_intvl(timeout: int, idle: int, cnt: Optional[int] = 3) -> int:
def keepalive_intvl(timeout: int, idle: int, cnt: int = 3) -> int:
"""Calculate the value to be used as ``TCP_KEEPINTVL`` based on *timeout*, *idle*, and *cnt*.
:param timeout: value for ``TCP_USER_TIMEOUT``.
@@ -854,7 +857,7 @@ def keepalive_intvl(timeout: int, idle: int, cnt: Optional[int] = 3) -> int:
return max(1, int(float(timeout - idle) / cnt))
def keepalive_socket_options(timeout: int, idle: int, cnt: Optional[int] = 3) -> Iterator[Tuple[int, int, int]]:
def keepalive_socket_options(timeout: int, idle: int, cnt: int = 3) -> Iterator[Tuple[int, int, int]]:
"""Get all keepalive related options to be set in a socket.
:param timeout: value for ``TCP_USER_TIMEOUT``.
@@ -871,25 +874,27 @@ def keepalive_socket_options(timeout: int, idle: int, cnt: Optional[int] = 3) ->
"""
yield (socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
if sys.platform.startswith('linux'):
yield (socket.SOL_TCP, 18, int(timeout * 1000)) # TCP_USER_TIMEOUT
TCP_KEEPIDLE = getattr(socket, 'TCP_KEEPIDLE', None)
TCP_KEEPINTVL = getattr(socket, 'TCP_KEEPINTVL', None)
TCP_KEEPCNT = getattr(socket, 'TCP_KEEPCNT', None)
elif sys.platform.startswith('darwin'):
TCP_KEEPIDLE = 0x10 # (named "TCP_KEEPALIVE" in C)
TCP_KEEPINTVL = 0x101
TCP_KEEPCNT = 0x102
else:
if not (sys.platform.startswith('linux') or sys.platform.startswith('darwin')):
return
intvl = keepalive_intvl(timeout, idle, cnt)
yield (socket.IPPROTO_TCP, TCP_KEEPIDLE, idle)
yield (socket.IPPROTO_TCP, TCP_KEEPINTVL, intvl)
yield (socket.IPPROTO_TCP, TCP_KEEPCNT, cnt)
if sys.platform.startswith('linux'):
yield (socket.SOL_TCP, 18, int(timeout * 1000)) # TCP_USER_TIMEOUT
# The socket constants from macOS netinet/tcp.h are not exported by Python's
# socket module, therefore we are using the 0x10, 0x101, 0x102 constants.
TCP_KEEPIDLE = getattr(socket, 'TCP_KEEPIDLE', 0x10 if sys.platform.startswith('darwin') else None)
if TCP_KEEPIDLE is not None:
yield (socket.IPPROTO_TCP, TCP_KEEPIDLE, idle)
TCP_KEEPINTVL = getattr(socket, 'TCP_KEEPINTVL', 0x101 if sys.platform.startswith('darwin') else None)
if TCP_KEEPINTVL is not None:
intvl = keepalive_intvl(timeout, idle, cnt)
yield (socket.IPPROTO_TCP, TCP_KEEPINTVL, intvl)
TCP_KEEPCNT = getattr(socket, 'TCP_KEEPCNT', 0x102 if sys.platform.startswith('darwin') else None)
if TCP_KEEPCNT is not None:
yield (socket.IPPROTO_TCP, TCP_KEEPCNT, cnt)
def enable_keepalive(sock: socket.socket, timeout: int, idle: int, cnt: Optional[int] = 3) -> Union[int, None]:
def enable_keepalive(sock: socket.socket, timeout: int, idle: int, cnt: int = 3) -> None:
"""Enable keepalive for *sock*.
Will set socket options depending on the platform, as per return of :func:`keepalive_socket_options`.
@@ -908,7 +913,7 @@ def enable_keepalive(sock: socket.socket, timeout: int, idle: int, cnt: Optional
SIO_KEEPALIVE_VALS = getattr(socket, 'SIO_KEEPALIVE_VALS', None)
if SIO_KEEPALIVE_VALS is not None: # Windows
intvl = keepalive_intvl(timeout, idle, cnt)
return sock.ioctl(SIO_KEEPALIVE_VALS, (1, idle * 1000, intvl * 1000))
sock.ioctl(SIO_KEEPALIVE_VALS, (1, idle * 1000, intvl * 1000))
for opt in keepalive_socket_options(timeout, idle, cnt):
sock.setsockopt(*opt)
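
A hedged usage sketch of the simplified keepalive entry point:

```python
# enable_keepalive() now returns None on all platforms; on Windows the
# SIO_KEEPALIVE_VALS ioctl result is discarded instead of returned.
import socket
from patroni.utils import enable_keepalive

sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
enable_keepalive(sock, timeout=30, idle=10, cnt=3)
```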

View File

@@ -11,9 +11,9 @@ import shutil
import socket
import subprocess
from typing import Any, Union, Iterator, List, Optional as OptionalType
from typing import Any, Dict, Union, Iterator, List, Optional as OptionalType
from .utils import split_host_port, data_directory_is_empty
from .utils import parse_int, split_host_port, data_directory_is_empty
from .dcs import dcs_modules
from .exceptions import ConfigParseError
@@ -48,8 +48,7 @@ def validate_connect_address(address: str) -> bool:
return True
def validate_host_port(host_port: str, listen: OptionalType[bool] = False,
multiple_hosts: OptionalType[bool] = False) -> bool:
def validate_host_port(host_port: str, listen: bool = False, multiple_hosts: bool = False) -> bool:
"""Check if host(s) and port are valid and available for usage.
:param host_port: the host(s) and port to be validated. It can be in either of these formats
@@ -68,7 +67,7 @@ def validate_host_port(host_port: str, listen: OptionalType[bool] = False,
* If :class:`socket.gaierror` is thrown by socket module when attempting to connect to the given address(es).
"""
try:
hosts, port = split_host_port(host_port, None)
hosts, port = split_host_port(host_port, 1)
except (ValueError, TypeError):
raise ConfigParseError("contains a wrong value")
else:
@@ -197,6 +196,7 @@ def get_major_version(bin_dir: OptionalType[str] = None) -> str:
binary = os.path.join(bin_dir, 'postgres')
version = subprocess.check_output([binary, '--version']).decode()
version = re.match(r'^[^\s]+ [^\s]+ (\d+)(\.(\d+))?', version)
assert version is not None
return '.'.join([version.group(1), version.group(3)]) if int(version.group(1)) < 10 else version.group(1)
@@ -251,8 +251,8 @@ class Result(object):
:ivar error: error message if the validation failed, otherwise ``None``.
"""
def __init__(self, status: bool, error: OptionalType[str] = "didn't pass validation", level: OptionalType[int] = 0,
path: OptionalType[str] = "", data: OptionalType[Any] = "") -> None:
def __init__(self, status: bool, error: OptionalType[str] = "didn't pass validation", level: int = 0,
path: str = "", data: Any = "") -> None:
"""Create a :class:`Result` object based on the given arguments.
.. note::
@@ -290,7 +290,7 @@ class Case(object):
them, if they are set.
"""
def __init__(self, schema: dict) -> None:
def __init__(self, schema: Dict[str, Any]) -> None:
"""Create a :class:`Case` object.
:param schema: the schema for validating a set of attributes that may be available in the configuration.
@@ -317,7 +317,7 @@ class Or(object):
validation functions and/or expected types for a given configuration option.
"""
def __init__(self, *args) -> None:
def __init__(self, *args: Any) -> None:
"""Create an :class:`Or` object.
:param `*args`: any arguments that the caller wants to be stored in this :class:`Or` object.
@@ -486,7 +486,7 @@ class Schema(object):
:param data: configuration to be validated against ``validator``.
:returns: list of errors identified while validating the *data*, if any.
"""
errors = []
errors: List[str] = []
for i in self.validate(data):
if not i.status:
errors.append(str(i))
@@ -533,14 +533,13 @@ class Schema(object):
except Exception as e:
yield Result(False, "didn't pass validation: {}".format(e), data=self.data)
elif isinstance(self.validator, dict):
if not len(self.validator):
if not isinstance(self.data, dict):
yield Result(isinstance(self.data, dict), "is not a dictionary", level=1, data=self.data)
elif isinstance(self.validator, list):
if not isinstance(self.data, list):
yield Result(isinstance(self.data, list), "is not a list", level=1, data=self.data)
return
for i in self.iter():
yield i
yield from self.iter()
def iter(self) -> Iterator[Result]:
"""Iterate over ``validator``, if it is an iterable object, to validate the corresponding entries in ``data``.
@@ -553,12 +552,11 @@ class Schema(object):
if not isinstance(self.data, dict):
yield Result(False, "is not a dictionary.", level=1)
else:
for i in self.iter_dict():
yield i
yield from self.iter_dict()
elif isinstance(self.validator, list):
if len(self.data) == 0:
yield Result(False, "is an empty list", data=self.data)
if len(self.validator) > 0:
if self.validator:
for key, value in enumerate(self.data):
# Although the value in the configuration (`data`) is expected to contain 1 or more entries, only
# the first validator defined in `validator` property list will be used. It is only defined as a
@@ -569,11 +567,9 @@ class Schema(object):
yield Result(v.status, v.error,
path=(str(key) + ("." + v.path if v.path else "")), level=v.level, data=value)
elif isinstance(self.validator, Directory):
for v in self.validator.validate(self.data):
yield v
yield from self.validator.validate(self.data)
elif isinstance(self.validator, Or):
for i in self.iter_or():
yield i
yield from self.iter_or()
def iter_dict(self) -> Iterator[Result]:
"""Iterate over a :class:`dict` based ``validator`` to validate the corresponding entries in ``data``.
@@ -606,14 +602,14 @@ class Schema(object):
:rtype: Iterator[:class:`Result`] objects with the error message related to the failure, if any check fails.
"""
results = []
results: List[Result] = []
for a in self.validator.args:
r = []
r: List[Result] = []
# Each of the `Or` validators can throw 0 to many `Result` instances.
for v in Schema(a).validate(self.data):
r.append(v)
if any([x.status for x in r]) and not all([x.status for x in r]):
results += filter(lambda x: not x.status, r)
results += [x for x in r if not x.status]
else:
results += r
# None of the `Or` validators succeeded to validate `data`, so we report the issues back.
@@ -664,12 +660,12 @@ def _get_type_name(python_type: Any) -> str:
Returns:
User friendly name of the given Python type.
"""
return {str: 'a string', int: 'an integer', float: 'a number',
bool: 'a boolean', list: 'an array', dict: 'a dictionary'}.get(
python_type, getattr(python_type, __name__, "unknown type"))
types: Dict[Any, str] = {str: 'a string', int: 'an integer', float: 'a number',
bool: 'a boolean', list: 'an array', dict: 'a dictionary'}
return types.get(python_type, getattr(python_type, '__name__', "unknown type"))
def assert_(condition: bool, message: OptionalType[str] = "Wrong value") -> None:
def assert_(condition: bool, message: str = "Wrong value") -> None:
"""Assert that a given condition is ``True``.
If the assertion fails, then throw a message.
@@ -680,14 +676,63 @@ def assert_(condition: bool, message: OptionalType[str] = "Wrong value") -> None
assert condition, message
class IntValidator(object):
"""Validate an integer setting.
:cvar expected_type: the expected Python type for an integer setting (:class:`int`).
:ivar min: minimum allowed value for the setting, if any.
:ivar max: maximum allowed value for the setting, if any.
:ivar base_unit: the base unit to convert the value to before checking if it's within `min` and `max` range.
:ivar raise_assert: if an ``assert`` call should be performed regarding expected type and valid range.
"""
expected_type = int
def __init__(self, min: OptionalType[int] = None, max: OptionalType[int] = None,
base_unit: OptionalType[str] = None, raise_assert: bool = False) -> None:
"""Create an :class:`IntValidator` object with the given rules.
:param min: minimum allowed value for the setting, if any.
:param max: maximum allowed value for the setting, if any.
:param base_unit: the base unit to convert the value to before checking if it's within *min* and *max* range.
:param raise_assert: if an ``assert`` call should be performed regarding expected type and valid range.
"""
self.min = min
self.max = max
self.base_unit = base_unit
self.raise_assert = raise_assert
def __call__(self, value: Union[int, str]) -> bool:
"""Check if *value* is a valid integer and within the expected range.
.. note::
If ``raise_assert`` is ``True`` and *value* is not valid, then an ``AssertionError`` will be triggered.
:param value: value to be checked against the rules defined for this :class:`IntValidator` instance.
:returns: ``True`` if *value* is valid and within the expected range.
"""
if self.base_unit:
value = parse_int(value, self.base_unit) or ""
ret = isinstance(value, int)\
and (self.min is None or value >= self.min)\
and (self.max is None or value <= self.max)
if self.raise_assert:
assert_(ret)
return ret
def validate_watchdog_mode(value: Any) -> None:
assert_(isinstance(value, (str, bool)), "expected type is not a string")
assert_(value in (False, "off", "automatic", "required"))
userattributes = {"username": "", Optional("password"): ""}
available_dcs = [m.split(".")[-1] for m in dcs_modules()]
validate_host_port_list.expected_type = list
comma_separated_host_port.expected_type = str
validate_connect_address.expected_type = str
validate_host_port_listen.expected_type = str
validate_host_port_listen_multiple_hosts.expected_type = str
validate_data_dir.expected_type = str
setattr(validate_host_port_list, 'expected_type', list)
setattr(comma_separated_host_port, 'expected_type', str)
setattr(validate_connect_address, 'expected_type', str)
setattr(validate_host_port_listen, 'expected_type', str)
setattr(validate_host_port_listen_multiple_hosts, 'expected_type', str)
setattr(validate_data_dir, 'expected_type', str)
validate_etcd = {
Or("host", "hosts", "srv", "srv_suffix", "url", "proxy"): Case({
"host": validate_host_port,
@@ -704,7 +749,7 @@ schema = Schema({
"restapi": {
"listen": validate_host_port_listen,
"connect_address": validate_connect_address,
Optional("request_queue_size"): lambda i: assert_(0 <= int(i) <= 4096)
Optional("request_queue_size"): IntValidator(min=0, max=4096, raise_assert=True)
},
Optional("bootstrap"): {
"dcs": {
@@ -726,7 +771,7 @@ schema = Schema({
"etcd3": validate_etcd,
"exhibitor": {
"hosts": [str],
"port": lambda i: assert_(int(i) <= 65535),
"port": IntValidator(max=65535, raise_assert=True),
Optional("pool_interval"): int
},
"raft": {
@@ -767,7 +812,7 @@ schema = Schema({
Optional("bin_dir"): Directory(contains_executable=["pg_ctl", "initdb", "pg_controldata", "pg_basebackup",
"postgres", "pg_isready"]),
Optional("parameters"): {
Optional("unix_socket_directories"): lambda s: assert_(all([isinstance(s, str), len(s)]))
Optional("unix_socket_directories"): str
},
Optional("pg_hba"): [str],
Optional("pg_ident"): [str],
@@ -775,7 +820,7 @@ schema = Schema({
Optional("use_pg_rewind"): bool
},
Optional("watchdog"): {
Optional("mode"): lambda m: assert_(m in ["off", "automatic", "required"]),
Optional("mode"): validate_watchdog_mode,
Optional("device"): str
},
Optional("tags"): {

View File

@@ -4,7 +4,10 @@ import platform
import sys
from threading import RLock
from patroni.exceptions import WatchdogError
from typing import Any, Callable, Dict, Optional, Union
from ..config import Config
from ..exceptions import WatchdogError
__all__ = ['WatchdogError', 'Watchdog']
@@ -15,10 +18,10 @@ MODE_AUTOMATIC = 'automatic' # Will use a watchdog if one is available
MODE_OFF = 'off' # Will not try to use a watchdog
def parse_mode(mode):
def parse_mode(mode: Union[bool, str]) -> str:
if mode is False:
return MODE_OFF
mode = mode.lower()
mode = str(mode).lower()
if mode in ['require', 'required']:
return MODE_REQUIRED
elif mode in ['auto', 'automatic']:
@@ -29,8 +32,8 @@ def parse_mode(mode):
return MODE_OFF
def synchronized(func):
def wrapped(self, *args, **kwargs):
def synchronized(func: Callable[..., Any]) -> Callable[..., Any]:
def wrapped(self: 'Watchdog', *args: Any, **kwargs: Any) -> Any:
with self._lock:
return func(self, *args, **kwargs)
return wrapped
@@ -38,7 +41,7 @@ def synchronized(func):
class WatchdogConfig(object):
"""Helper to contain a snapshot of configuration"""
def __init__(self, config):
def __init__(self, config: Config) -> None:
watchdog_config = config.get("watchdog") or {'mode': 'automatic'}
self.mode = parse_mode(watchdog_config.get('mode', 'automatic'))
@@ -49,15 +52,15 @@ class WatchdogConfig(object):
self.driver_config = dict((k, v) for k, v in watchdog_config.items()
if k not in ['mode', 'safety_margin', 'driver'])
def __eq__(self, other):
def __eq__(self, other: Any) -> bool:
return isinstance(other, WatchdogConfig) and \
all(getattr(self, attr) == getattr(other, attr) for attr in
['mode', 'ttl', 'loop_wait', 'safety_margin', 'driver', 'driver_config'])
def __ne__(self, other):
def __ne__(self, other: Any) -> bool:
return not self == other
def get_impl(self):
def get_impl(self) -> 'WatchdogBase':
if self.driver == 'testing': # pragma: no cover
from patroni.watchdog.linux import TestingWatchdogDevice
return TestingWatchdogDevice.from_config(self.driver_config)
@@ -68,14 +71,14 @@ class WatchdogConfig(object):
return NullWatchdog()
@property
def timeout(self):
def timeout(self) -> int:
if self.safety_margin == -1:
return int(self.ttl // 2)
else:
return self.ttl - self.safety_margin
@property
def timing_slack(self):
def timing_slack(self) -> int:
return self.timeout - self.loop_wait
@@ -84,8 +87,9 @@ class Watchdog(object):
When activation fails underlying implementation will be switched to a Null implementation. To avoid log spam
activation will only be retried when watchdog configuration is changed."""
def __init__(self, config):
self.active_config = self.config = WatchdogConfig(config)
def __init__(self, config: Config) -> None:
self.config = WatchdogConfig(config)
self.active_config: WatchdogConfig = self.config
self._lock = RLock()
self.active = False
@@ -98,7 +102,7 @@ class Watchdog(object):
sys.exit(1)
@synchronized
def reload_config(self, config):
def reload_config(self, config: Config) -> None:
self.config = WatchdogConfig(config)
# Turning a watchdog off can always be done immediately
if self.config.mode == MODE_OFF:
@@ -115,7 +119,7 @@ class Watchdog(object):
self.active_config = self.config
@synchronized
def activate(self):
def activate(self) -> bool:
"""Activates the watchdog device with suitable timeouts. While watchdog is active keepalive needs
to be called every time loop_wait expires.
@@ -124,7 +128,7 @@ class Watchdog(object):
self.active = True
return self._activate()
def _activate(self):
def _activate(self) -> bool:
self.active_config = self.config
if self.config.timing_slack < 0:
@@ -138,12 +142,13 @@ class Watchdog(object):
except WatchdogError as e:
logger.warning("Could not activate %s: %s", self.impl.describe(), e)
self.impl = NullWatchdog()
actual_timeout = self.impl.get_timeout()
if self.impl.is_running and not self.impl.can_be_disabled:
logger.warning("Watchdog implementation can't be disabled."
" Watchdog will trigger after Patroni loses leader key.")
if not self.impl.is_running or actual_timeout > self.config.timeout:
if not self.impl.is_running or actual_timeout and actual_timeout > self.config.timeout:
if self.config.mode == MODE_REQUIRED:
if self.impl.is_null:
logger.error("Configuration requires watchdog, but watchdog could not be configured.")
@@ -167,7 +172,7 @@ class Watchdog(object):
return True
def _set_timeout(self):
def _set_timeout(self) -> Optional[int]:
if self.impl.has_set_timeout():
self.impl.set_timeout(self.config.timeout)
@@ -184,11 +189,11 @@ class Watchdog(object):
return actual_timeout
@synchronized
def disable(self):
def disable(self) -> None:
self._disable()
self.active = False
def _disable(self):
def _disable(self) -> None:
try:
if self.impl.is_running and not self.impl.can_be_disabled:
# Give sysadmin some extra time to clean stuff up.
@@ -200,7 +205,7 @@ class Watchdog(object):
logger.error("Error while disabling watchdog: %s", e)
@synchronized
def keepalive(self):
def keepalive(self) -> None:
try:
if self.active:
self.impl.keepalive()
@@ -225,12 +230,12 @@ class Watchdog(object):
@property
@synchronized
def is_running(self):
def is_running(self) -> bool:
return self.impl.is_running
@property
@synchronized
def is_healthy(self):
def is_healthy(self) -> bool:
if self.config.mode != MODE_REQUIRED:
return True
return self.config.timing_slack >= 0 and self.impl.is_healthy
@@ -242,60 +247,59 @@ class WatchdogBase(abc.ABC):
is_null = False
@property
def is_running(self):
def is_running(self) -> bool:
"""Returns True when watchdog is activated and capable of performing it's task."""
return False
@property
def is_healthy(self):
def is_healthy(self) -> bool:
"""Returns False when calling open() is known to fail."""
return False
@property
def can_be_disabled(self):
def can_be_disabled(self) -> bool:
"""Returns True when watchdog will be disabled by calling close(). Some watchdog devices
will keep running no matter what once activated. May raise WatchdogError if called without
calling open() first."""
return True
@abc.abstractmethod
def open(self):
def open(self) -> None:
"""Open watchdog device.
When watchdog is opened keepalive must be called. Returns nothing on success
or raises WatchdogError if the device could not be opened."""
@abc.abstractmethod
def close(self):
def close(self) -> None:
"""Gracefully close watchdog device."""
@abc.abstractmethod
def keepalive(self):
def keepalive(self) -> None:
"""Resets the watchdog timer.
Watchdog must be open when keepalive is called."""
@abc.abstractmethod
def get_timeout(self):
def get_timeout(self) -> int:
"""Returns the current keepalive timeout in effect."""
@staticmethod
def has_set_timeout():
def has_set_timeout(self) -> bool:
"""Returns True if setting a timeout is supported."""
return False
def set_timeout(self, timeout):
def set_timeout(self, timeout: int) -> None:
"""Set the watchdog timer timeout.
:param timeout: watchdog timeout in seconds"""
raise WatchdogError("Setting timeout is not supported on {0}".format(self.describe()))
def describe(self):
def describe(self) -> str:
"""Human readable name for this device"""
return self.__class__.__name__
@classmethod
def from_config(cls, config):
def from_config(cls, config: Dict[str, Any]) -> 'WatchdogBase':
return cls()
@@ -303,15 +307,15 @@ class NullWatchdog(WatchdogBase):
"""Null implementation when watchdog is not supported."""
is_null = True
def open(self):
def open(self) -> None:
return
def close(self):
def close(self) -> None:
return
def keepalive(self):
def keepalive(self) -> None:
return
def get_timeout(self):
def get_timeout(self) -> int:
# A big enough number to not matter
return 1000000000
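
The timeout/timing_slack arithmetic in WatchdogConfig is worth a worked example (values are hypothetical):

```python
# With ttl=30 and the default safety_margin=-1 the watchdog fires at ttl//2;
# with loop_wait=10 that leaves 5 seconds of timing slack per HA cycle.
ttl, loop_wait, safety_margin = 30, 10, -1
timeout = ttl // 2 if safety_margin == -1 else ttl - safety_margin
timing_slack = timeout - loop_wait
assert (timeout, timing_slack) == (15, 5)
```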

View File

@@ -1,8 +1,10 @@
import collections
import ctypes
import os
import platform
from patroni.watchdog.base import WatchdogBase, WatchdogError
from typing import Any, Dict, NamedTuple
from .base import WatchdogBase, WatchdogError
# Pythonification of linux/ioctl.h
IOC_NONE = 0
@@ -19,7 +21,7 @@ machine = platform.machine()
if machine in ['mips', 'sparc', 'powerpc', 'ppc64', 'ppc64le']: # pragma: no cover
IOC_SIZEBITS = 13
IOC_DIRBITS = 3
IOC_NONE, IOC_WRITE, IOC_READ = 1, 4, 2
IOC_NONE, IOC_WRITE = 1, 4
elif machine == 'parisc': # pragma: no cover
IOC_WRITE, IOC_READ = 2, 1
@@ -29,19 +31,19 @@ IOC_SIZESHIFT = IOC_TYPESHIFT + IOC_TYPEBITS
IOC_DIRSHIFT = IOC_SIZESHIFT + IOC_SIZEBITS
def IOW(type_, nr, size):
def IOW(type_: str, nr: int, size: int) -> int:
return IOC(IOC_WRITE, type_, nr, size)
def IOR(type_, nr, size):
def IOR(type_: str, nr: int, size: int) -> int:
return IOC(IOC_READ, type_, nr, size)
def IOWR(type_, nr, size):
def IOWR(type_: str, nr: int, size: int) -> int:
return IOC(IOC_READ | IOC_WRITE, type_, nr, size)
def IOC(dir_, type_, nr, size):
def IOC(dir_: int, type_: str, nr: int, size: int) -> int:
return (dir_ << IOC_DIRSHIFT) \
| (ord(type_) << IOC_TYPESHIFT) \
| (nr << IOC_NRSHIFT) \
@@ -104,9 +106,13 @@ WDIOS = {
# Implementation
class WatchdogInfo(collections.namedtuple('WatchdogInfo', 'options,version,identity')):
class WatchdogInfo(NamedTuple):
"""Watchdog descriptor from the kernel"""
def __getattr__(self, name):
options: int
version: int
identity: str
def __getattr__(self, name: str) -> bool:
"""Convenience has_XYZ attributes for checking WDIOF bits in options"""
if name.startswith('has_') and name[4:] in WDIOF:
return bool(self.options & WDIOF[name[4:]])
@@ -117,32 +123,32 @@ class WatchdogInfo(collections.namedtuple('WatchdogInfo', 'options,version,ident
class LinuxWatchdogDevice(WatchdogBase):
DEFAULT_DEVICE = '/dev/watchdog'
def __init__(self, device):
def __init__(self, device: str) -> None:
self.device = device
self._support_cache = None
self._fd = None
@classmethod
def from_config(cls, config):
def from_config(cls, config: Dict[str, Any]) -> 'LinuxWatchdogDevice':
device = config.get('device', cls.DEFAULT_DEVICE)
return cls(device)
@property
def is_running(self):
def is_running(self) -> bool:
return self._fd is not None
@property
def is_healthy(self):
def is_healthy(self) -> bool:
return os.path.exists(self.device) and os.access(self.device, os.W_OK)
def open(self):
def open(self) -> None:
try:
self._fd = os.open(self.device, os.O_WRONLY)
except OSError as e:
raise WatchdogError("Can't open watchdog device: {0}".format(e))
def close(self):
if self.is_running:
def close(self) -> None:
if self._fd is not None: # self.is_running
try:
os.write(self._fd, b'V')
os.close(self._fd)
@@ -151,10 +157,10 @@ class LinuxWatchdogDevice(WatchdogBase):
raise WatchdogError("Error while closing {0}: {1}".format(self.describe(), e))
@property
def can_be_disabled(self):
def can_be_disabled(self) -> bool:
return self.get_support().has_MAGICCLOSE
def _ioctl(self, func, arg):
def _ioctl(self, func: int, arg: Any) -> None:
"""Runs the specified ioctl on the underlying fd.
Raises WatchdogError if the device is closed.
@@ -165,7 +171,7 @@ class LinuxWatchdogDevice(WatchdogBase):
import fcntl
fcntl.ioctl(self._fd, func, arg, True)
def get_support(self):
def get_support(self) -> WatchdogInfo:
if self._support_cache is None:
info = watchdog_info()
try:
@@ -177,7 +183,7 @@ class LinuxWatchdogDevice(WatchdogBase):
bytearray(info.identity).decode(errors='ignore').rstrip('\x00'))
return self._support_cache
def describe(self):
def describe(self) -> str:
dev_str = " at {0}".format(self.device) if self.device != self.DEFAULT_DEVICE else ""
ver_str = ""
identity = "Linux watchdog device"
@@ -190,17 +196,19 @@ class LinuxWatchdogDevice(WatchdogBase):
return identity + ver_str + dev_str
def keepalive(self):
def keepalive(self) -> None:
if self._fd is None:
raise WatchdogError("Watchdog device is closed")
try:
os.write(self._fd, b'1')
except OSError as e:
raise WatchdogError("Could not send watchdog keepalive: {0}".format(e))
def has_set_timeout(self):
def has_set_timeout(self) -> bool:
"""Returns True if setting a timeout is supported."""
return self.get_support().has_SETTIMEOUT
def set_timeout(self, timeout):
def set_timeout(self, timeout: int) -> None:
timeout = int(timeout)
if not 0 < timeout < 0xFFFF:
raise WatchdogError("Invalid timeout {0}. Supported values are between 1 and 65535".format(timeout))
@@ -209,7 +217,7 @@ class LinuxWatchdogDevice(WatchdogBase):
except (WatchdogError, OSError, IOError) as e:
raise WatchdogError("Could not set timeout on watchdog device: {}".format(e))
def get_timeout(self):
def get_timeout(self) -> int:
timeout = ctypes.c_int()
try:
self._ioctl(WDIOC_GETTIMEOUT, timeout)
@@ -222,14 +230,16 @@ class TestingWatchdogDevice(LinuxWatchdogDevice): # pragma: no cover
"""Converts timeout ioctls to regular writes that can be intercepted from a named pipe."""
timeout = 60
def get_support(self):
def get_support(self) -> WatchdogInfo:
return WatchdogInfo(WDIOF['MAGICCLOSE'] | WDIOF['SETTIMEOUT'], 0, "Watchdog test harness")
def set_timeout(self, timeout):
def set_timeout(self, timeout: int) -> None:
if self._fd is None:
raise WatchdogError("Watchdog device is closed")
buf = "Ctimeout={0}\n".format(timeout).encode('utf8')
while len(buf):
buf = buf[os.write(self._fd, buf):]
self.timeout = timeout
def get_timeout(self):
def get_timeout(self) -> int:
return self.timeout
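
The ioctl request numbers used above come from the IOC bit packing; a worked sketch using the common x86-64 bit widths (an assumption, the module adjusts them per architecture):

```python
# Direction, type, number and size are packed into one 32-bit request number.
IOC_NRBITS, IOC_TYPEBITS, IOC_SIZEBITS = 8, 8, 14
IOC_NRSHIFT = 0
IOC_TYPESHIFT = IOC_NRSHIFT + IOC_NRBITS
IOC_SIZESHIFT = IOC_TYPESHIFT + IOC_TYPEBITS
IOC_DIRSHIFT = IOC_SIZESHIFT + IOC_SIZEBITS
IOC_READ = 2

def IOR(type_: str, nr: int, size: int) -> int:
    return (IOC_READ << IOC_DIRSHIFT) | (ord(type_) << IOC_TYPESHIFT) \
        | (nr << IOC_NRSHIFT) | (size << IOC_SIZESHIFT)

# WDIOC_GETTIMEOUT is _IOR('W', 7, sizeof(int)) in the kernel headers:
print(hex(IOR('W', 7, 4)))  # 0x80045707
```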

pyrightconfig.json Normal file
View File

@@ -0,0 +1,27 @@
{
"include": [
"patroni"
],
"exclude": [
"**/__pycache__"
],
"ignore": [
],
"defineConstant": {
"DEBUG": true
},
"stubPath": "typings/",
"reportMissingImports": true,
"reportMissingTypeStubs": false,
"pythonVersion": "3.6",
"pythonPlatform": "All",
"typeCheckingMode": "strict"
}

View File

@@ -23,6 +23,7 @@ class MockResponse(object):
def __init__(self, status_code=200):
self.status_code = status_code
self.headers = {'content-type': 'json'}
self.content = '{}'
self.reason = 'Not Found'
@@ -38,10 +39,6 @@ class MockResponse(object):
def getheader(*args):
return ''
@staticmethod
def getheaders():
return {'content-type': 'json'}
def requests_get(url, method='GET', endpoint=None, data='', **kwargs):
members = '[{"id":14855829450254237642,"peerURLs":["http://localhost:2380","http://localhost:7001"],' +\
@@ -85,6 +82,8 @@ class MockCursor(object):
self.description = [Mock()]
def execute(self, sql, *params):
if isinstance(sql, bytes):
sql = sql.decode('utf-8')
if sql.startswith('blabla'):
raise psycopg.ProgrammingError()
elif sql == 'CHECKPOINT' or sql.startswith('SELECT pg_catalog.pg_create_'):
@@ -98,7 +97,7 @@ class MockCursor(object):
elif sql.startswith('SELECT slot_name'):
self.results = [('blabla', 'physical'), ('foobar', 'physical'), ('ls', 'logical', 'a', 'b', 5, 100, 500)]
elif sql.startswith('WITH slots AS (SELECT slot_name, active'):
self.results = [(False, True)]
self.results = [(False, True)] if self.rowcount == 1 else [None]
elif sql.startswith('SELECT CASE WHEN pg_catalog.pg_is_in_recovery()'):
self.results = [(1, 2, 1, 0, False, 1, 1, None, None,
[{"slot_name": "ls", "confirmed_flush_lsn": 12345}],

View File

@@ -136,8 +136,8 @@ class TestCitus(BaseTestPostgresql):
self.assertEqual(parameters['shared_preload_libraries'], 'citus,foo,bar')
self.assertEqual(parameters['wal_level'], 'logical')
@patch.object(CitusHandler, 'is_enabled', Mock(return_value=False))
def test_bootstrap(self):
self.c._config = None
self.c.bootstrap()
def test_ignore_replication_slot(self):

View File

@@ -19,8 +19,9 @@ class TestConfig(unittest.TestCase):
def test_set_dynamic_configuration(self):
with patch.object(Config, '_build_effective_configuration', Mock(side_effect=Exception)):
self.assertIsNone(self.config.set_dynamic_configuration({'foo': 'bar'}))
self.assertTrue(self.config.set_dynamic_configuration({'standby_cluster': {}}))
self.assertFalse(self.config.set_dynamic_configuration({'foo': 'bar'}))
self.assertTrue(self.config.set_dynamic_configuration({'standby_cluster': {}, 'postgresql': {
'parameters': {'cluster_name': 1, 'wal_keep_size': 1, 'track_commit_timestamp': 1, 'wal_level': 1}}}))
def test_reload_local_configuration(self):
os.environ.update({

View File

@@ -195,6 +195,8 @@ class TestConsul(unittest.TestCase):
@patch.object(consul.Consul.KV, 'delete', Mock(return_value=True))
def test_delete_leader(self):
self.c.delete_leader()
self.c._name = 'other'
self.c.delete_leader()
@patch.object(consul.Consul.KV, 'put', Mock(return_value=True))
def test_initialize(self):

View File

@@ -9,7 +9,7 @@ from mock import patch, Mock, PropertyMock
from patroni.ctl import ctl, load_config, output_members, get_dcs, parse_dcs, \
get_all_members, get_any_member, get_cursor, query_member, PatroniCtlException, apply_config_changes, \
format_config_for_editing, show_diff, invoke_editor, format_pg_version, CONFIG_FILE_PATH, PatronictlPrettyTable
from patroni.dcs.etcd import AbstractEtcdClientWithFailover, Failover
from patroni.dcs.etcd import AbstractEtcdClientWithFailover, Cluster, Failover
from patroni.psycopg import OperationalError
from patroni.utils import tzutc
from prettytable import PrettyTable, ALL
@@ -639,6 +639,10 @@ class TestCtl(unittest.TestCase):
self.runner.invoke(ctl, ['edit-config', 'dummy', '--force', '--apply', '-'], input='foo: bar')
mock_get_dcs.return_value.set_config_value.return_value = True
self.runner.invoke(ctl, ['edit-config', 'dummy', '--force', '--apply', '-'], input='foo: bar')
mock_get_dcs.return_value.get_cluster = Mock(return_value=Cluster.empty())
result = self.runner.invoke(ctl, ['edit-config', 'dummy'])
assert result.exit_code == 1
assert 'The config key does not exist in the cluster dummy' in result.output
@patch('patroni.ctl.get_dcs')
def test_version(self, mock_get_dcs):

View File

@@ -138,7 +138,9 @@ class TestClient(unittest.TestCase):
@patch.object(EtcdClient, '_get_machines_list',
Mock(return_value=['http://localhost:2379', 'http://localhost:4001']))
def setUp(self):
self.client = EtcdClient({'srv': 'test', 'retry_timeout': 3}, DnsCachingResolver())
self.etcd = Etcd({'namespace': '/patroni/', 'ttl': 30, 'retry_timeout': 3,
'srv': 'test', 'scope': 'test', 'name': 'foo'})
self.client = self.etcd._client
self.client.http.request = http_request
self.client.http.request_encode_body = http_request

View File

@@ -3,7 +3,7 @@ import json
import unittest
import urllib3
from mock import Mock, patch
from mock import Mock, PropertyMock, patch
from patroni.dcs.etcd import DnsCachingResolver
from patroni.dcs.etcd3 import PatroniEtcd3Client, Cluster, Etcd3Client, Etcd3Error, Etcd3ClientError, RetryFailedError,\
InvalidAuthToken, Unavailable, Unknown, UnsupportedEtcdVersion, UserEmpty, AuthFailed, base64_encode, Etcd3
@@ -85,6 +85,11 @@ class BaseTestEtcd3(unittest.TestCase):
class TestKVCache(BaseTestEtcd3):
@patch.object(urllib3.PoolManager, 'urlopen', mock_urlopen)
@patch.object(Etcd3Client, 'watchprefix', Mock(return_value=urllib3.response.HTTPResponse()))
def test__build_cache(self):
self.kv_cache._build_cache()
def test__do_watch(self):
self.client.watchprefix = Mock(return_value=False)
self.assertRaises(AttributeError, self.kv_cache._do_watch, '1')
@@ -94,14 +99,17 @@ class TestKVCache(BaseTestEtcd3):
def test_run(self):
self.assertRaises(SleepException, self.kv_cache.run)
@patch.object(urllib3.PoolManager, 'urlopen', mock_urlopen)
@patch.object(urllib3.response.HTTPResponse, 'read_chunked',
Mock(return_value=[b'{"error":{"grpc_code":14,"message":"","http_code":503}}']))
@patch.object(Etcd3Client, 'watchprefix', Mock(return_value=urllib3.response.HTTPResponse()))
def test_kill_stream(self):
self.assertRaises(Unavailable, self.kv_cache._do_watch, '1')
self.kv_cache.kill_stream()
with patch.object(MockResponse, 'connection', create=True) as mock_conn:
with patch.object(urllib3.response.HTTPResponse, 'connection') as mock_conn:
self.kv_cache.kill_stream()
mock_conn.sock.close.side_effect = Exception
self.kv_cache.kill_stream()
type(mock_conn).sock = PropertyMock(side_effect=Exception)
self.kv_cache.kill_stream()
class TestPatroniEtcd3Client(BaseTestEtcd3):
@@ -180,7 +188,8 @@ class TestEtcd3(BaseTestEtcd3):
@patch.object(urllib3.PoolManager, 'urlopen', mock_urlopen)
def setUp(self):
super(TestEtcd3, self).setUp()
self.assertRaises(AttributeError, self.kv_cache._build_cache)
# self.assertRaises(AttributeError, self.kv_cache._build_cache)
self.kv_cache._build_cache()
self.kv_cache._is_ready = True
self.etcd3.get_cluster()
@@ -276,6 +285,8 @@ class TestEtcd3(BaseTestEtcd3):
def test_delete_leader(self):
self.etcd3.delete_leader()
self.etcd3._name = 'other'
self.etcd3.delete_leader()
def test_delete_cluster(self):
self.etcd3.delete_cluster()

View File

@@ -244,7 +244,6 @@ class TestHa(PostgresInit):
def test_bootstrap_as_standby_leader(self, initialize):
self.p.data_directory_empty = true
self.ha.cluster = get_cluster_not_initialized_without_leader(cluster_config=ClusterConfig(0, {}, 0))
self.ha.cluster.is_unlocked = true
self.ha.patroni.config._dynamic_configuration = {"standby_cluster": {"port": 5432}}
self.assertEqual(self.ha.run_cycle(), 'trying to bootstrap a new standby leader')
@@ -261,7 +260,6 @@ class TestHa(PostgresInit):
def test_start_as_cascade_replica_in_standby_cluster(self):
self.p.data_directory_empty = true
self.ha.cluster = get_standby_cluster_initialized_with_only_leader()
self.ha.cluster.is_unlocked = false
self.assertEqual(self.ha.run_cycle(), "trying to bootstrap from replica 'test'")
def test_recover_replica_failed(self):
@@ -381,8 +379,8 @@ class TestHa(PostgresInit):
self.ha._async_executor.schedule('promote')
self.assertEqual(self.ha.run_cycle(), 'lost leader before promote')
@patch.object(Cluster, 'is_unlocked', Mock(return_value=False))
def test_long_promote(self):
self.ha.cluster.is_unlocked = false
self.ha.has_lock = true
self.p.is_leader = false
self.p.set_role('primary')
@@ -407,14 +405,13 @@ class TestHa(PostgresInit):
self.p.is_leader = false
self.assertEqual(self.ha.run_cycle(), 'following a different leader because i am not the healthiest node')
@patch.object(Cluster, 'is_unlocked', Mock(return_value=False))
def test_promote_because_have_lock(self):
self.ha.cluster.is_unlocked = false
self.ha.has_lock = true
self.p.is_leader = false
self.assertEqual(self.ha.run_cycle(), 'promoted self to leader because I had the session lock')
def test_promote_without_watchdog(self):
self.ha.cluster.is_unlocked = false
self.ha.has_lock = true
self.p.is_leader = true
with patch.object(Watchdog, 'activate', Mock(return_value=False)):
@@ -424,24 +421,22 @@ class TestHa(PostgresInit):
def test_leader_with_lock(self):
self.ha.cluster = get_cluster_initialized_with_leader()
self.ha.cluster.is_unlocked = false
self.ha.has_lock = true
self.assertEqual(self.ha.run_cycle(), 'no action. I am (postgresql0), the leader with the lock')
def test_coordinator_leader_with_lock(self):
self.ha.cluster = get_cluster_initialized_with_leader()
self.ha.cluster.is_unlocked = false
self.ha.has_lock = true
self.assertEqual(self.ha.run_cycle(), 'no action. I am (postgresql0), the leader with the lock')
@patch.object(Postgresql, '_wait_for_connection_close', Mock())
@patch.object(Cluster, 'is_unlocked', Mock(return_value=False))
def test_demote_because_not_having_lock(self):
self.ha.cluster.is_unlocked = false
with patch.object(Watchdog, 'is_running', PropertyMock(return_value=True)):
self.assertEqual(self.ha.run_cycle(), 'demoting self because I do not have the lock and I was a leader')
@patch.object(Cluster, 'is_unlocked', Mock(return_value=False))
def test_demote_because_update_lock_failed(self):
self.ha.cluster.is_unlocked = false
self.ha.has_lock = true
self.ha.update_lock = false
self.assertEqual(self.ha.run_cycle(), 'demoted self because failed to update leader lock in DCS')
@@ -450,8 +445,8 @@ class TestHa(PostgresInit):
self.p.is_leader = false
self.assertEqual(self.ha.run_cycle(), 'not promoting because failed to update leader lock in DCS')
@patch.object(Cluster, 'is_unlocked', Mock(return_value=False))
def test_follow(self):
self.ha.cluster.is_unlocked = false
self.p.is_leader = false
self.assertEqual(self.ha.run_cycle(), 'no action. I am (postgresql0), a secondary, and following a leader ()')
self.ha.patroni.replicatefrom = "foo"
@@ -465,8 +460,8 @@ class TestHa(PostgresInit):
self.assertEqual(self.ha.run_cycle(), 'no action. I am (postgresql0), a secondary, and following a leader ()')
del self.ha.cluster.config.data['postgresql']['use_slots']
@patch.object(Cluster, 'is_unlocked', Mock(return_value=False))
def test_follow_in_pause(self):
self.ha.cluster.is_unlocked = false
self.ha.is_paused = true
self.assertEqual(self.ha.run_cycle(), 'PAUSE: continue to run as primary without lock')
self.p.is_leader = false
@@ -659,10 +654,12 @@ class TestHa(PostgresInit):
self.ha.update_lock = false
self.p.set_role('primary')
with patch('patroni.async_executor.CriticalTask.cancel', Mock(return_value=False)):
with patch('patroni.postgresql.Postgresql.terminate_starting_postmaster') as mock_terminate:
self.assertEqual(self.ha.run_cycle(), 'lost leader lock during restart')
mock_terminate.assert_called()
with patch('patroni.async_executor.CriticalTask.cancel', Mock(return_value=False)),\
patch('patroni.async_executor.CriticalTask.result',
PropertyMock(return_value=PostmasterProcess(os.getpid())), create=True),\
patch('patroni.postgresql.Postgresql.terminate_starting_postmaster') as mock_terminate:
self.assertEqual(self.ha.run_cycle(), 'lost leader lock during restart')
mock_terminate.assert_called()
self.ha.is_paused = true
self.assertEqual(self.ha.run_cycle(), 'PAUSE: restart in progress')
@@ -741,7 +738,6 @@ class TestHa(PostgresInit):
def test_manual_failover_process_no_leader(self):
self.p.is_leader = false
self.ha.cluster = get_cluster_initialized_without_leader(failover=Failover(0, '', self.p.name, None))
self.assertEqual(self.ha.run_cycle(), 'promoted self to leader by acquiring session lock')
self.ha.cluster = get_cluster_initialized_without_leader(failover=Failover(0, '', 'leader', None))
self.p.set_role('replica')
self.assertEqual(self.ha.run_cycle(), 'promoted self to leader by acquiring session lock')
@@ -847,10 +843,10 @@ class TestHa(PostgresInit):
self.assertFalse(self.ha.is_healthiest_node())
def test__is_healthiest_node(self):
self.p.is_leader = false
self.ha.cluster = get_cluster_initialized_without_leader(sync=('postgresql1', self.p.name))
self.ha.global_config = self.ha.patroni.config.get_global_config(self.ha.cluster)
self.assertTrue(self.ha._is_healthiest_node(self.ha.old_cluster.members))
self.p.is_leader = false
self.ha.fetch_node_status = get_node_status() # accessible, in_recovery
self.assertTrue(self.ha._is_healthiest_node(self.ha.old_cluster.members))
self.ha.fetch_node_status = get_node_status(in_recovery=False) # accessible, not in_recovery
@@ -864,6 +860,8 @@ class TestHa(PostgresInit):
self.ha.global_config = self.ha.patroni.config.get_global_config(self.ha.cluster)
with patch('patroni.postgresql.Postgresql.last_operation', return_value=1):
self.assertFalse(self.ha._is_healthiest_node(self.ha.old_cluster.members))
with patch('patroni.postgresql.Postgresql.replica_cached_timeline', return_value=None):
self.assertFalse(self.ha._is_healthiest_node(self.ha.old_cluster.members))
with patch('patroni.postgresql.Postgresql.replica_cached_timeline', return_value=1):
self.assertFalse(self.ha._is_healthiest_node(self.ha.old_cluster.members))
self.ha.patroni.nofailover = True
@@ -972,11 +970,11 @@ class TestHa(PostgresInit):
with patch.object(Leader, 'conn_url', PropertyMock(return_value='')):
self.assertEqual(self.ha.run_cycle(), 'continue following the old known standby leader')
@patch.object(Cluster, 'is_unlocked', Mock(return_value=True))
def test_process_unhealthy_standby_cluster_as_standby_leader(self):
self.p.is_leader = false
self.p.name = 'leader'
self.ha.cluster = get_standby_cluster_initialized_with_only_leader()
self.ha.cluster.is_unlocked = true
self.ha.sysid_valid = true
self.p._sysid = True
self.assertEqual(self.ha.run_cycle(), 'promoted self to a standby leader by acquiring session lock')
@@ -987,7 +985,6 @@ class TestHa(PostgresInit):
self.p.is_leader = false
self.p.name = 'replica'
self.ha.cluster = get_standby_cluster_initialized_with_only_leader()
self.ha.is_unlocked = true
self.assertTrue(self.ha.run_cycle().startswith('running pg_rewind from remote_member:'))
def test_recover_unhealthy_leader_in_standby_cluster(self):
@@ -998,13 +995,13 @@ class TestHa(PostgresInit):
self.ha.cluster = get_standby_cluster_initialized_with_only_leader()
self.assertEqual(self.ha.run_cycle(), 'starting as a standby leader because i had the session lock')
@patch.object(Cluster, 'is_unlocked', Mock(return_value=True))
def test_recover_unhealthy_unlocked_standby_cluster(self):
self.p.is_leader = false
self.p.name = 'leader'
self.p.is_running = false
self.p.follow = false
self.ha.cluster = get_standby_cluster_initialized_with_only_leader()
self.ha.cluster.is_unlocked = true
self.ha.has_lock = false
self.assertEqual(self.ha.run_cycle(), 'trying to follow a remote member because standby cluster is unhealthy')
@@ -1287,10 +1284,10 @@ class TestHa(PostgresInit):
@patch('patroni.postgresql.mtime', Mock(return_value=1588316884))
@patch('builtins.open', Mock(side_effect=Exception))
@patch.object(Cluster, 'is_unlocked', Mock(return_value=False))
def test_restore_cluster_config(self):
self.ha.cluster.config.data.clear()
self.ha.has_lock = true
self.ha.cluster.is_unlocked = false
self.assertEqual(self.ha.run_cycle(), 'no action. I am (postgresql0), the leader with the lock')
def test_watch(self):
@@ -1341,9 +1338,9 @@ class TestHa(PostgresInit):
@patch('patroni.postgresql.mtime', Mock(return_value=1588316884))
@patch('builtins.open', mock_open(read_data=('1\t0/40159C0\tno recovery target specified\n\n'
'2\t1/40159C0\tno recovery target specified\n')))
@patch.object(Cluster, 'is_unlocked', Mock(return_value=False))
def test_update_cluster_history(self):
self.ha.has_lock = true
self.ha.cluster.is_unlocked = false
for tl in (1, 3):
self.p.get_primary_timeline = Mock(return_value=tl)
self.assertEqual(self.ha.run_cycle(), 'no action. I am (postgresql0), the leader with the lock')
@@ -1355,9 +1352,9 @@ class TestHa(PostgresInit):
self.ha.run_cycle()
exit_mock.assert_called_once_with(1)
@patch.object(Cluster, 'is_unlocked', Mock(return_value=False))
def test_after_pause(self):
self.ha.has_lock = true
self.ha.cluster.is_unlocked = false
self.ha.is_paused = true
self.assertEqual(self.ha.run_cycle(), 'PAUSE: no action. I am (postgresql0), the leader with the lock')
self.ha.is_paused = false
@@ -1409,16 +1406,10 @@ class TestHa(PostgresInit):
@patch.object(Postgresql, 'major_version', PropertyMock(return_value=130000))
@patch.object(SlotsHandler, 'sync_replication_slots', Mock(return_value=['ls']))
def test_follow_copy(self):
self.ha.cluster.is_unlocked = false
self.ha.cluster.config.data['slots'] = {'ls': {'database': 'a', 'plugin': 'b'}}
self.p.is_leader = false
self.assertTrue(self.ha.run_cycle().startswith('Copying logical slots'))
def test_is_failover_possible(self):
self.ha.fetch_node_status = Mock(return_value=_MemberStatus(self.ha.cluster.members[0],
True, True, 0, 2, None, {}, False))
self.assertFalse(self.ha.is_failover_possible(self.ha.cluster.members))
def test_acquire_lock(self):
self.ha.dcs.attempt_to_acquire_leader = Mock(side_effect=[DCSError('foo'), Exception])
self.assertRaises(DCSError, self.ha.acquire_lock)
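The recurring change in this file replaces ad-hoc instance assignments such as `self.ha.cluster.is_unlocked = false` with class-level patching. Under strict type checking, assigning a bare function to a method attribute defeats the declared signature, while `patch.object` keeps the class contract and restores the original afterwards. A minimal sketch of the pattern (toy class, not Patroni's):

from unittest.mock import Mock, patch

class Cluster:
    def is_unlocked(self) -> bool:
        return True

with patch.object(Cluster, 'is_unlocked', Mock(return_value=False)):
    assert Cluster().is_unlocked() is False
assert Cluster().is_unlocked() is True  # original method restored on exit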

View File

@@ -5,6 +5,7 @@ import mock
import socket
import time
import unittest
import urllib3
from mock import Mock, PropertyMock, mock_open, patch
from patroni.dcs.kubernetes import Cluster, k8s_client, k8s_config, K8sConfig, K8sConnectionFailed,\
@@ -34,6 +35,9 @@ def mock_list_namespaced_config_map(*args, **kwargs):
metadata.update({'name': 'test-1-leader', 'labels': {Kubernetes._CITUS_LABEL: '1'},
'annotations': {'leader': 'p-3', 'ttl': '30s'}})
items.append(k8s_client.V1ConfigMap(metadata=k8s_client.V1ObjectMeta(**metadata)))
metadata.update({'name': 'test-2-config', 'labels': {Kubernetes._CITUS_LABEL: '2'}, 'annotations': {}})
items.append(k8s_client.V1ConfigMap(metadata=k8s_client.V1ObjectMeta(**metadata)))
metadata = k8s_client.V1ObjectMeta(resource_version='1')
return k8s_client.V1ConfigMapList(metadata=metadata, items=items, kind='ConfigMapList')
@@ -403,13 +407,18 @@ class TestKubernetesEndpoints(BaseTestKubernetes):
self.assertEqual(('create_config_service failed',), mock_logger_exception.call_args[0])
def mock_watch(*args):
return urllib3.HTTPResponse()
class TestCacheBuilder(BaseTestKubernetes):
@patch.object(k8s_client.CoreV1Api, 'list_namespaced_config_map', mock_list_namespaced_config_map, create=True)
@patch('patroni.dcs.kubernetes.ObjectCache._watch')
def test__build_cache(self, mock_response):
@patch('patroni.dcs.kubernetes.ObjectCache._watch', mock_watch)
@patch.object(urllib3.HTTPResponse, 'read_chunked')
def test__build_cache(self, mock_read_chunked):
self.k._citus_group = '0'
mock_response.return_value.read_chunked.return_value = [json.dumps(
mock_read_chunked.return_value = [json.dumps(
{'type': 'MODIFIED', 'object': {'metadata': {
'name': self.k.config_path, 'resourceVersion': '2', 'annotations': {self.k._CONFIG: 'foo'}}}}
).encode('utf-8'), ('\n' + json.dumps(
@@ -435,12 +444,13 @@ class TestCacheBuilder(BaseTestKubernetes):
self.assertRaises(AttributeError, self.k._kinds._do_watch, '1')
@patch.object(k8s_client.CoreV1Api, 'list_namespaced_config_map', mock_list_namespaced_config_map, create=True)
@patch('patroni.dcs.kubernetes.ObjectCache._watch')
def test_kill_stream(self, mock_watch):
self.k._kinds.kill_stream()
mock_watch.return_value.read_chunked.return_value = []
mock_watch.return_value.connection.sock.close.side_effect = Exception
self.k._kinds._do_watch('1')
self.k._kinds.kill_stream()
type(mock_watch.return_value).connection = PropertyMock(side_effect=Exception)
@patch('patroni.dcs.kubernetes.ObjectCache._watch', mock_watch)
@patch.object(urllib3.HTTPResponse, 'read_chunked', Mock(return_value=[]))
def test_kill_stream(self):
self.k._kinds.kill_stream()
with patch.object(urllib3.HTTPResponse, 'connection') as mock_connection:
mock_connection.sock.close.side_effect = Exception
self.k._kinds._do_watch('1')
self.k._kinds.kill_stream()
with patch.object(urllib3.HTTPResponse, 'connection', PropertyMock(side_effect=Exception)):
self.k._kinds.kill_stream()
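The rewritten tests drive `kill_stream()` directly against `urllib3.HTTPResponse`: closing the socket under a chunked watch response wakes the thread blocked in `read_chunked()`, and both the close call and the `connection` property are allowed to raise. A toy illustration of why closing one end unblocks a reader (plain sockets, not the Patroni implementation):

import socket
import threading

reader, writer = socket.socketpair()

def watch() -> None:
    # Blocks the way read_chunked() does on an idle watch stream.
    print('stream yielded:', reader.recv(4096))  # b'' once the peer closes

t = threading.Thread(target=watch)
t.start()
writer.close()  # the "kill_stream": the blocked recv() returns immediately
t.join()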

View File

@@ -369,6 +369,10 @@ class TestPostgresql(BaseTestPostgresql):
def test_latest_checkpoint_location(self, mock_popen):
mock_popen.return_value.communicate.return_value = (None, None)
self.assertEqual(self.p.latest_checkpoint_location(), 28163096)
with patch.object(Postgresql, 'controldata', Mock(return_value={'Database cluster state': 'shut down',
'Latest checkpoint location': 'k/1ADBC18',
"Latest checkpoint's TimeLineID": '1'})):
self.assertIsNone(self.p.latest_checkpoint_location())
# 9.3 and 9.4 format
mock_popen.return_value.communicate.side_effect = [
(b'rmgr: XLOG len (rec/tot): 72/ 104, tx: 0, lsn: 0/01ADBC18, prev 0/01ADBBB8, '
@@ -518,10 +522,10 @@ class TestPostgresql(BaseTestPostgresql):
def test_can_create_replica_without_replication_connection(self):
self.p.config._config['create_replica_method'] = []
self.assertFalse(self.p.can_create_replica_without_replication_connection())
self.assertFalse(self.p.can_create_replica_without_replication_connection(None))
self.p.config._config['create_replica_method'] = ['wale', 'basebackup']
self.p.config._config['wale'] = {'command': 'foo', 'no_leader': 1}
self.assertTrue(self.p.can_create_replica_without_replication_connection())
self.assertTrue(self.p.can_create_replica_without_replication_connection(None))
def test_replica_method_can_work_without_replication_connection(self):
self.assertFalse(self.p.replica_method_can_work_without_replication_connection('basebackup'))
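The new `controldata` case above feeds `'k/1ADBC18'` as the checkpoint location; its high half is not valid hex, so parsing fails and `latest_checkpoint_location()` yields `None`. A hypothetical helper mirroring the LSN arithmetic (an `X/Y` LSN is two hex halves of a 64-bit position):

from typing import Optional

def parse_lsn(lsn: str) -> Optional[int]:
    try:
        hi, lo = lsn.split('/')
        return (int(hi, 16) << 32) + int(lo, 16)
    except ValueError:
        return None

assert parse_lsn('0/01ADBC18') == 28163096  # the value the test expects
assert parse_lsn('k/1ADBC18') is None       # invalid hex half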

View File

@@ -111,6 +111,8 @@ class TestSlotsHandler(BaseTestPostgresql):
self.s.copy_logical_slots(self.cluster, ['ls'])
with patch.object(MockCursor, 'execute', Mock(side_effect=psycopg.OperationalError)):
self.s.copy_logical_slots(self.cluster, ['foo'])
with patch.object(Cluster, 'leader', PropertyMock(return_value=None)):
self.s.copy_logical_slots(self.cluster, ['foo'])
@patch.object(Postgresql, 'stop', Mock(return_value=True))
@patch.object(Postgresql, 'start', Mock(return_value=True))

View File

@@ -36,7 +36,7 @@ class TestUtils(unittest.TestCase):
def test_enable_keepalive(self):
with patch('socket.SIO_KEEPALIVE_VALS', 1, create=True):
self.assertIsNotNone(enable_keepalive(Mock(), 10, 5))
self.assertIsNone(enable_keepalive(Mock(), 10, 5))
with patch('socket.SIO_KEEPALIVE_VALS', None, create=True):
for platform in ('linux2', 'darwin', 'other'):
with patch('sys.platform', platform):

View File

@@ -123,7 +123,7 @@ class TestWALERestore(unittest.TestCase):
with patch.object(WALERestore, 'run', Mock(return_value=1)), \
patch('time.sleep', mock_sleep):
self.assertEqual(_main(), 1)
self.assertTrue(sleeps[0], WALE_TEST_RETRIES)
self.assertEqual(sleeps[0], WALE_TEST_RETRIES)
@patch('os.path.isfile', Mock(return_value=True))
def test_get_major_version(self):
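The `assertTrue` to `assertEqual` change above fixes a silent no-op: `assertTrue(a, b)` treats `b` as the failure message, so the old assertion passed for any truthy `a`. A minimal demonstration:

import unittest

class Demo(unittest.TestCase):
    def test_pitfall(self) -> None:
        self.assertTrue(1, 2)  # passes: 2 is only the msg= argument
        with self.assertRaises(AssertionError):
            self.assertEqual(1, 2)  # actually compares the two values

if __name__ == '__main__':
    unittest.main()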

View File

@@ -110,6 +110,11 @@ class TestWatchdog(unittest.TestCase):
watchdog.keepalive()
self.assertEqual(len(device.writes), 1)
watchdog.impl._fd, fd = None, watchdog.impl._fd
watchdog.keepalive()
self.assertEqual(len(device.writes), 1)
watchdog.impl._fd = fd
watchdog.disable()
self.assertFalse(device.open)
self.assertEqual(device.writes[-1], b'V')

View File

@@ -13,6 +13,7 @@ from patroni.dcs.zookeeper import Cluster, Leader, PatroniKazooClient,\
class MockKazooClient(Mock):
handler = PatroniSequentialThreadingHandler(10)
leader = False
exists = True
@@ -154,11 +155,6 @@ class TestZooKeeper(unittest.TestCase):
def test_session_listener(self):
self.zk.session_listener(KazooState.SUSPENDED)
def test_members_watcher(self):
self.zk._fetch_cluster = False
self.zk.members_watcher(None)
self.assertTrue(self.zk._fetch_cluster)
def test_reload_config(self):
self.zk.reload_config({'ttl': 20, 'retry_timeout': 10, 'loop_wait': 10})
self.zk.reload_config({'ttl': 20, 'retry_timeout': 10, 'loop_wait': 5})
@@ -223,6 +219,8 @@ class TestZooKeeper(unittest.TestCase):
def test_cancel_initialization(self):
self.zk.cancel_initialization()
with patch.object(MockKazooClient, 'delete', Mock()):
self.zk.cancel_initialization()
def test_touch_member(self):
self.zk._name = 'buzz'

View File


@@ -0,0 +1 @@
class ClientError(Exception): ...

View File

@@ -0,0 +1,11 @@
from typing import Any, Callable, Dict, Optional
DEFAULT_METADATA_SERVICE_TIMEOUT = 1
METADATA_BASE_URL = 'http://169.254.169.254/'
class AWSResponse:
status_code: int
@property
def text(self) -> str: ...
class IMDSFetcher:
def __init__(self, timeout: float = DEFAULT_METADATA_SERVICE_TIMEOUT, num_attempts: int = 1, base_url: str = METADATA_BASE_URL, env: Optional[Dict[str, str]] = None, user_agent: Optional[str] = None, config: Optional[Dict[str, Any]] = None) -> None: ...
def _fetch_metadata_token(self) -> Optional[str]: ...
def _get_request(self, url_path: str, retry_func: Optional[Callable[[AWSResponse], bool]] = None, token: Optional[str] = None) -> AWSResponse: ...
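These are private botocore helpers that Patroni's AWS support relies on; a sketch of the IMDSv2 token flow they implement (only meaningful on an EC2 instance, and the metadata path is illustrative):

from botocore.utils import IMDSFetcher

fetcher = IMDSFetcher(timeout=1, num_attempts=2)
token = fetcher._fetch_metadata_token()  # None on IMDSv1-only hosts
resp = fetcher._get_request('latest/meta-data/instance-id', None, token)
print(resp.status_code, resp.text)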

View File

@@ -0,0 +1,5 @@
import io
from typing import Any
class PatchStream:
def __init__(self, diff_hdl: io.TextIOBase) -> None: ...
def markup_to_pager(stream: Any, opts: Any) -> None: ...

View File

@@ -0,0 +1,2 @@
from consul.base import ConsulException, NotFound
__all__ = ['ConsulException', 'Consul', 'NotFound']

24 typings/consul/base.pyi Normal file
View File

@@ -0,0 +1,24 @@
from typing import Any, Dict, List, Optional, Tuple
class ConsulException(Exception): ...
class NotFound(ConsulException): ...
class Check:
@classmethod
def http(klass, url: str, interval: str, timeout: Optional[str] = None, deregister: Optional[str] = None) -> Dict[str, str]: ...
class Consul:
http: Any
agent: 'Consul.Agent'
session: 'Consul.Session'
kv: 'Consul.KV'
class KV:
def get(self, key: str, index: Optional[int] = None, recurse: bool = False, wait: Optional[str] = None, token: Optional[str] = None, consistency: Optional[str] = None, keys: bool = False, separator: Optional[str] = '', dc: Optional[str] = None) -> Tuple[int, Dict[str, Any]]: ...
def put(self, key: str, value: str, cas: Optional[int] = None, flags: Optional[int] = None, acquire: Optional[str] = None, release: Optional[str] = None, token: Optional[str] = None, dc: Optional[str] = None) -> bool: ...
def delete(self, key: str, recurse: Optional[bool] = None, cas: Optional[int] = None, token: Optional[str] = None, dc: Optional[str] = None) -> bool: ...
class Agent:
service: 'Consul.Agent.Service'
def self(self) -> Dict[str, Dict[str, Any]]: ...
class Service:
def register(self, name: str, service_id=..., address=..., port=..., tags=..., check=..., token=..., script=..., interval=..., ttl=..., http=..., timeout=..., enable_tag_override=...) -> bool: ...
def deregister(self, service_id: str) -> bool: ...
class Session:
def create(self, name: Optional[str] = None, node: Optional[str] = None, checks: Optional[List[str]] = None, lock_delay: float = 15, behavior: str = 'release', ttl: Optional[int] = None, dc: Optional[str] = None) -> str: ...
def renew(self, session_id: str, dc: Optional[str] = None) -> Optional[str]: ...
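A sketch of the session-based locking flow these signatures cover (key and service names are placeholders):

import consul

c = consul.Consul()
session = c.session.create(name='demo', ttl=30, behavior='delete')
# Acquiring a lock means writing the key with the session attached.
if c.kv.put('service/demo/leader', 'me', acquire=session):
    index, data = c.kv.get('service/demo/leader')
    print(index, data and data['Value'])
c.session.renew(session)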

17 typings/dns/resolver.pyi Normal file
View File

@@ -0,0 +1,17 @@
from typing import Union, Optional, Iterator
class Name:
def to_text(self, omit_final_dot: bool = ...) -> str: ...
class Rdata:
target: Name = ...
port: int = ...
class Answer:
def __iter__(self) -> Iterator[Rdata]: ...
def resolve(qname: str, rdtype: Union[int, str] = 0,
rdclass: Union[int, str] = 0,
tcp: bool = False, source: Optional[str] = None, raise_on_no_answer: bool = True,
source_port: int = 0, lifetime: Optional[float] = None,
search: Optional[bool] = None) -> Answer: ...
def query(qname: str, rdtype: Union[int, str] = 0,
rdclass: Union[int, str] = 0,
tcp: bool = False, source: Optional[str] = None, raise_on_no_answer: bool = True,
source_port: int = 0, lifetime: Optional[float] = None) -> Answer: ...
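A sketch of the SRV lookup these signatures describe; Patroni resolves SRV records to discover etcd hosts, and the domain below is a placeholder:

from dns.resolver import resolve

for rdata in resolve('_etcd-server._tcp.example.com', 'SRV'):
    # Each Rdata carries the target host and port from the SRV record.
    print(rdata.target.to_text(omit_final_dot=True), rdata.port)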

24 typings/etcd/__init__.pyi Normal file
View File

@@ -0,0 +1,24 @@
from typing import Dict, Optional, Type, List
from .client import Client
__all__ = ['Client', 'EtcdError', 'EtcdException', 'EtcdEventIndexCleared', 'EtcdWatcherCleared', 'EtcdKeyNotFound', 'EtcdAlreadyExist', 'EtcdResult', 'EtcdConnectionFailed', 'EtcdWatchTimedOut']
class EtcdResult:
action: str = ...
modifiedIndex: int = ...
key: str = ...
value: str = ...
ttl: Optional[float] = ...
@property
def leaves(self) -> List['EtcdResult']: ...
class EtcdException(Exception):
def __init__(self, message=..., payload=...) -> None: ...
class EtcdConnectionFailed(EtcdException):
def __init__(self, message=..., payload=..., cause=...) -> None: ...
class EtcdKeyError(EtcdException): ...
class EtcdKeyNotFound(EtcdKeyError): ...
class EtcdAlreadyExist(EtcdKeyError): ...
class EtcdEventIndexCleared(EtcdException): ...
class EtcdWatchTimedOut(EtcdConnectionFailed): ...
class EtcdWatcherCleared(EtcdException): ...
class EtcdLeaderElectionInProgress(EtcdException): ...
class EtcdError:
error_exceptions: Dict[int, Type[EtcdException]] = ...

29 typings/etcd/client.pyi Normal file
View File

@@ -0,0 +1,29 @@
import urllib3
from typing import Any, Optional, Set
from . import EtcdResult
class Client:
_MGET: str
_MPUT: str
_MPOST: str
_MDELETE: str
_comparison_conditions: Set[str]
_read_options: Set[str]
_del_conditions: Set[str]
http: urllib3.poolmanager.PoolManager
_use_proxies: bool
version_prefix: str
username: Optional[str]
password: Optional[str]
def __init__(self, host=..., port=..., srv_domain=..., version_prefix=..., read_timeout=..., allow_redirect=..., protocol=..., cert=..., ca_cert=..., username=..., password=..., allow_reconnect=..., use_proxies=..., expected_cluster_id=..., per_host_pool_size=..., lock_prefix=...): ...
@property
def protocol(self) -> str: ...
@property
def read_timeout(self) -> int: ...
@property
def allow_redirect(self) -> bool: ...
def write(self, key: str, value: str, ttl: int = ..., dir: bool = ..., append: bool = ..., **kwdargs: Any) -> EtcdResult: ...
def read(self, key: str, **kwdargs: Any) -> EtcdResult: ...
def delete(self, key: str, recursive: bool = ..., dir: bool = ..., **kwdargs: Any) -> EtcdResult: ...
def set(self, key: str, value: str, ttl: int = ...) -> EtcdResult: ...
def watch(self, key: str, index: int = ..., timeout: float = ..., recursive: bool = ...) -> EtcdResult: ...
def _handle_server_response(self, response: urllib3.response.HTTPResponse) -> Any: ...
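A sketch of the python-etcd subset the stub covers (endpoint and key are placeholders):

import etcd

client = etcd.Client(host='127.0.0.1', port=2379)
client.write('/service/demo/leader', 'me', ttl=30, prevExist=False)
result = client.read('/service/demo/leader')
print(result.value, result.ttl)
client.delete('/service/demo/leader')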

34 typings/kazoo/client.pyi Normal file
View File

@@ -0,0 +1,34 @@
__all__ = ['KazooState', 'KazooClient', 'KazooRetry']
from kazoo.protocol.connection import ConnectionHandler
from kazoo.protocol.states import KazooState, WatchedEvent, ZnodeStat
from kazoo.handlers.threading import AsyncResult, SequentialThreadingHandler
from kazoo.retry import KazooRetry
from kazoo.security import ACL
from typing import Any, Callable, Optional, Tuple, List
class KazooClient:
handler: SequentialThreadingHandler
_state: str
_connection: ConnectionHandler
_session_timeout: int
retry: Callable[..., Any]
_retry: KazooRetry
def __init__(self, hosts=..., timeout=..., client_id=..., handler=..., default_acl=..., auth_data=..., sasl_options=..., read_only=..., randomize_hosts=..., connection_retry=..., command_retry=..., logger=..., keyfile=..., keyfile_password=..., certfile=..., ca=..., use_ssl=..., verify_certs=..., **kwargs) -> None: ...
@property
def client_id(self) -> Optional[Tuple[Any]]: ...
def add_listener(self, listener: Callable[[str], None]) -> None: ...
def start(self, timeout: int = ...) -> None: ...
def restart(self) -> None: ...
def set_hosts(self, hosts: str, randomize_hosts: Optional[bool] = None) -> None: ...
def create(self, path: str, value: bytes = b'', acl: Optional[ACL] = None, ephemeral: bool = False, sequence: bool = False, makepath: bool = False, include_data: bool = False) -> None: ...
def create_async(self, path: str, value: bytes = b'', acl: Optional[ACL] = None, ephemeral: bool = False, sequence: bool = False, makepath: bool = False, include_data: bool = False) -> AsyncResult: ...
def get(self, path: str, watch: Optional[Callable[[WatchedEvent], None]] = None) -> Tuple[bytes, ZnodeStat]: ...
def get_children(self, path: str, watch: Optional[Callable[[WatchedEvent], None]] = None, include_data: bool = False) -> List[str]: ...
def set(self, path: str, value: bytes, version: int = -1) -> ZnodeStat: ...
def set_async(self, path: str, value: bytes, version: int = -1) -> AsyncResult: ...
def delete(self, path: str, version: int = -1, recursive: bool = False) -> None: ...
def delete_async(self, path: str, version: int = -1) -> AsyncResult: ...
def _call(self, request: Tuple[Any], async_object: AsyncResult) -> Optional[bool]: ...
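A sketch of the KazooClient calls the stub annotates (ZooKeeper endpoint and paths are placeholders):

from kazoo.client import KazooClient

zk = KazooClient(hosts='127.0.0.1:2181')
zk.start(timeout=5)
zk.create('/demo/leader', b'me', ephemeral=True, makepath=True)
data, stat = zk.get('/demo/leader')
print(data, stat.version)
zk.delete('/demo/leader')
zk.stop()  # part of kazoo's API even though the stub omits it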

View File

@@ -0,0 +1,12 @@
class KazooException(Exception):
...
class ZookeeperError(KazooException):
...
class SessionExpiredError(ZookeeperError):
...
class ConnectionClosedError(SessionExpiredError):
...
class NoNodeError(ZookeeperError):
...
class NodeExistsError(ZookeeperError):
...

View File

@@ -0,0 +1,13 @@
import socket
from kazoo.handlers import utils
from typing import Any
class AsyncResult(utils.AsyncResult):
...
class SequentialThreadingHandler:
def select(self, *args: Any, **kwargs: Any) -> Any:
...
def create_connection(self, *args: Any, **kwargs: Any) -> socket.socket:
...

View File

@@ -0,0 +1,6 @@
from typing import Any, Optional
class AsyncResult:
def set_exception(self, exception: Exception) -> None:
...
def get(self, block: bool = False, timeout: Optional[float] = None) -> Any:
...

View File

@@ -0,0 +1,6 @@
import socket
from typing import Any, Union, Tuple
class ConnectionHandler:
_socket: socket.socket
def _connect(self, *args: Any) -> Tuple[Union[int, float], Union[int, float]]:
...

View File

@@ -0,0 +1,29 @@
from typing import Any, NamedTuple
class KazooState:
SUSPENDED: str
CONNECTED: str
LOST: str
class KeeperState:
AUTH_FAILED: str
CONNECTED: str
CONNECTED_RO: str
CONNECTING: str
CLOSED: str
EXPIRED_SESSION: str
class WatchedEvent(NamedTuple):
type: str
state: str
path: str
class ZnodeStat(NamedTuple):
czxid: int
mzxid: int
ctime: float
mtime: float
version: int
cversion: int
aversion: int
ephemeralOwner: Any
dataLength: int
numChildren: int
pzxid: int

7 typings/kazoo/retry.pyi Normal file
View File

@@ -0,0 +1,7 @@
from kazoo.exceptions import KazooException
class RetryFailedError(KazooException):
...
class KazooRetry:
deadline: float
def __init__(self, max_tries=..., delay=..., backoff=..., max_jitter=..., max_delay=..., ignore_expire=..., sleep_func=..., deadline=..., interrupt=...) -> None:
...

View File

@@ -0,0 +1,5 @@
from collections import namedtuple
class ACL(namedtuple('ACL', 'perms id')):
...
def make_acl(scheme: str, credential: str, read: bool = ..., write: bool = ..., create: bool = ..., delete: bool = ..., admin: bool = ..., all: bool = ...) -> ACL:
...

View File

@@ -0,0 +1,13 @@
from typing import Any, Dict, List
FRAME = 1
ALL = 1
class PrettyTable:
def __init__(self, *args: str, **kwargs: Any) -> None: ...
def _stringify_hrule(self, options: Dict[str, Any], where: str = '') -> str: ...
@property
def align(self) -> Dict[str, str]: ...
@align.setter
def align(self, val: str) -> None: ...
def add_row(self, row: List[Any]) -> None: ...
def __str__(self) -> str: ...
def __repr__(self) -> str: ...

View File

@@ -0,0 +1,52 @@
from collections.abc import Callable
from typing import Any, TypeVar, overload
# connection and cursor not available at runtime
from psycopg2._psycopg import (
BINARY as BINARY,
DATETIME as DATETIME,
NUMBER as NUMBER,
ROWID as ROWID,
STRING as STRING,
Binary as Binary,
DatabaseError as DatabaseError,
DataError as DataError,
Date as Date,
DateFromTicks as DateFromTicks,
Error as Error,
IntegrityError as IntegrityError,
InterfaceError as InterfaceError,
InternalError as InternalError,
NotSupportedError as NotSupportedError,
OperationalError as OperationalError,
ProgrammingError as ProgrammingError,
Time as Time,
TimeFromTicks as TimeFromTicks,
Timestamp as Timestamp,
TimestampFromTicks as TimestampFromTicks,
Warning as Warning,
__libpq_version__ as __libpq_version__,
apilevel as apilevel,
connection as connection,
cursor as cursor,
paramstyle as paramstyle,
threadsafety as threadsafety,
)
__version__: str
_T_conn = TypeVar("_T_conn", bound=connection)
@overload
def connect(dsn: str, connection_factory: Callable[..., _T_conn], cursor_factory: None = None, **kwargs: Any) -> _T_conn: ...
@overload
def connect(
dsn: str | None = None, *, connection_factory: Callable[..., _T_conn], cursor_factory: None = None, **kwargs: Any
) -> _T_conn: ...
@overload
def connect(
dsn: str | None = None,
connection_factory: Callable[..., connection] | None = None,
cursor_factory: Callable[..., cursor] | None = None,
**kwargs: Any,
) -> connection: ...
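These overloads exist so a checker can follow `connection_factory`: with a factory supplied, the inferred return type is the factory's connection subclass instead of plain `connection`. A sketch of what pyright infers (hypothetical subclass; connecting will of course fail without a server):

import psycopg2
from psycopg2.extensions import connection

class LoggingConnection(connection):
    ...  # hypothetical subclass; any connection subclass works

try:
    conn = psycopg2.connect('dbname=test', connection_factory=LoggingConnection)
    # pyright infers conn: LoggingConnection via the first overload
except psycopg2.OperationalError:
    pass  # no server needed to make the typing point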

View File

@@ -0,0 +1,9 @@
from _typeshed import Incomplete
from typing import Any
ipaddress: Any
def register_ipaddress(conn_or_curs: Incomplete | None = None) -> None: ...
def cast_interface(s, cur: Incomplete | None = None): ...
def cast_network(s, cur: Incomplete | None = None): ...
def adapt_ipaddress(obj): ...

View File

@@ -0,0 +1,26 @@
from _typeshed import Incomplete
from typing import Any
JSON_OID: int
JSONARRAY_OID: int
JSONB_OID: int
JSONBARRAY_OID: int
class Json:
adapted: Any
def __init__(self, adapted, dumps: Incomplete | None = None) -> None: ...
def __conform__(self, proto): ...
def dumps(self, obj): ...
def prepare(self, conn) -> None: ...
def getquoted(self): ...
def register_json(
conn_or_curs: Incomplete | None = None,
globally: bool = False,
loads: Incomplete | None = None,
oid: Incomplete | None = None,
array_oid: Incomplete | None = None,
name: str = "json",
): ...
def register_default_json(conn_or_curs: Incomplete | None = None, globally: bool = False, loads: Incomplete | None = None): ...
def register_default_jsonb(conn_or_curs: Incomplete | None = None, globally: bool = False, loads: Incomplete | None = None): ...

View File

@@ -0,0 +1,488 @@
from collections.abc import Callable, Iterable, Mapping, Sequence
from types import TracebackType
from typing import Any, TypeVar, overload
from typing_extensions import Literal, Self, TypeAlias
import psycopg2
import psycopg2.extensions
from psycopg2.sql import Composable
_Vars: TypeAlias = Sequence[Any] | Mapping[str, Any] | None
BINARY: Any
BINARYARRAY: Any
BOOLEAN: Any
BOOLEANARRAY: Any
BYTES: Any
BYTESARRAY: Any
CIDRARRAY: Any
DATE: Any
DATEARRAY: Any
DATETIME: Any
DATETIMEARRAY: Any
DATETIMETZ: Any
DATETIMETZARRAY: Any
DECIMAL: Any
DECIMALARRAY: Any
FLOAT: Any
FLOATARRAY: Any
INETARRAY: Any
INTEGER: Any
INTEGERARRAY: Any
INTERVAL: Any
INTERVALARRAY: Any
LONGINTEGER: Any
LONGINTEGERARRAY: Any
MACADDRARRAY: Any
NUMBER: Any
PYDATE: Any
PYDATEARRAY: Any
PYDATETIME: Any
PYDATETIMEARRAY: Any
PYDATETIMETZ: Any
PYDATETIMETZARRAY: Any
PYINTERVAL: Any
PYINTERVALARRAY: Any
PYTIME: Any
PYTIMEARRAY: Any
REPLICATION_LOGICAL: int
REPLICATION_PHYSICAL: int
ROWID: Any
ROWIDARRAY: Any
STRING: Any
STRINGARRAY: Any
TIME: Any
TIMEARRAY: Any
UNICODE: Any
UNICODEARRAY: Any
UNKNOWN: Any
adapters: dict[Any, Any]
apilevel: str
binary_types: dict[Any, Any]
encodings: dict[Any, Any]
paramstyle: str
sqlstate_errors: dict[Any, Any]
string_types: dict[Any, Any]
threadsafety: int
__libpq_version__: int
class cursor:
arraysize: int
binary_types: Any
closed: Any
connection: Any
description: Any
itersize: Any
lastrowid: Any
name: Any
pgresult_ptr: Any
query: Any
row_factory: Any
rowcount: int
rownumber: int
scrollable: bool | None
statusmessage: Any
string_types: Any
typecaster: Any
tzinfo_factory: Any
withhold: bool
def __init__(self, conn: connection, name: str | bytes | None = ...) -> None: ...
def callproc(self, procname, parameters=...): ...
def cast(self, oid, s): ...
def close(self): ...
def copy_expert(self, sql: str | bytes | Composable, file, size=...): ...
def copy_from(self, file, table, sep=..., null=..., size=..., columns=...): ...
def copy_to(self, file, table, sep=..., null=..., columns=...): ...
def execute(self, query: str | bytes | Composable, vars: _Vars = ...) -> None: ...
def executemany(self, query: str | bytes | Composable, vars_list: Iterable[_Vars]) -> None: ...
def fetchall(self) -> list[tuple[Any, ...]]: ...
def fetchmany(self, size: int | None = ...) -> list[tuple[Any, ...]]: ...
def fetchone(self) -> tuple[Any, ...] | None: ...
def mogrify(self, *args, **kwargs): ...
def nextset(self): ...
def scroll(self, value, mode=...): ...
def setinputsizes(self, sizes): ...
def setoutputsize(self, size, column=...): ...
def __enter__(self) -> Self: ...
def __exit__(
self, type: type[BaseException] | None, value: BaseException | None, traceback: TracebackType | None
) -> None: ...
def __iter__(self) -> Self: ...
def __next__(self) -> tuple[Any, ...]: ...
_Cursor: TypeAlias = cursor
class AsIs:
adapted: Any
def __init__(self, *args, **kwargs) -> None: ...
def getquoted(self, *args, **kwargs): ...
def __conform__(self, *args, **kwargs): ...
class Binary:
adapted: Any
buffer: Any
def __init__(self, *args, **kwargs) -> None: ...
def getquoted(self, *args, **kwargs): ...
def prepare(self, conn): ...
def __conform__(self, *args, **kwargs): ...
class Boolean:
adapted: Any
def __init__(self, *args, **kwargs) -> None: ...
def getquoted(self, *args, **kwargs): ...
def __conform__(self, *args, **kwargs): ...
class Column:
display_size: Any
internal_size: Any
name: Any
null_ok: Any
precision: Any
scale: Any
table_column: Any
table_oid: Any
type_code: Any
def __init__(self, *args, **kwargs) -> None: ...
def __eq__(self, __other): ...
def __ge__(self, __other): ...
def __getitem__(self, __index): ...
def __getstate__(self): ...
def __gt__(self, __other): ...
def __le__(self, __other): ...
def __len__(self) -> int: ...
def __lt__(self, __other): ...
def __ne__(self, __other): ...
def __setstate__(self, state): ...
class ConnectionInfo:
# Note: the following properties can be None if their corresponding libpq function
# returns NULL. They're not annotated as such, because this is very unlikely in
# practice---the psycopg2 docs [1] don't even mention this as a possibility!
#
# - db_name
# - user
# - password
# - host
# - port
# - options
#
# (To prove this, one needs to inspect the psycopg2 source code [2], plus the
# documentation [3] and source code [4] of the corresponding libpq calls.)
#
# [1]: https://www.psycopg.org/docs/extensions.html#psycopg2.extensions.ConnectionInfo
# [2]: https://github.com/psycopg/psycopg2/blob/1d3a89a0bba621dc1cc9b32db6d241bd2da85ad1/psycopg/conninfo_type.c#L52 and below
# [3]: https://www.postgresql.org/docs/current/libpq-status.html
# [4]: https://github.com/postgres/postgres/blob/b39838889e76274b107935fa8e8951baf0e8b31b/src/interfaces/libpq/fe-connect.c#L6754 and below
@property
def backend_pid(self) -> int: ...
@property
def dbname(self) -> str: ...
@property
def dsn_parameters(self) -> dict[str, str]: ...
@property
def error_message(self) -> str | None: ...
@property
def host(self) -> str: ...
@property
def needs_password(self) -> bool: ...
@property
def options(self) -> str: ...
@property
def password(self) -> str: ...
@property
def port(self) -> int: ...
@property
def protocol_version(self) -> int: ...
@property
def server_version(self) -> int: ...
@property
def socket(self) -> int: ...
@property
def ssl_attribute_names(self) -> list[str]: ...
@property
def ssl_in_use(self) -> bool: ...
@property
def status(self) -> int: ...
@property
def transaction_status(self) -> int: ...
@property
def used_password(self) -> bool: ...
@property
def user(self) -> str: ...
def __init__(self, *args, **kwargs) -> None: ...
def parameter_status(self, name: str) -> str | None: ...
def ssl_attribute(self, name: str) -> str | None: ...
class DataError(psycopg2.DatabaseError): ...
class DatabaseError(psycopg2.Error): ...
class Decimal:
adapted: Any
def __init__(self, *args, **kwargs) -> None: ...
def getquoted(self, *args, **kwargs): ...
def __conform__(self, *args, **kwargs): ...
class Diagnostics:
column_name: str | None
constraint_name: str | None
context: str | None
datatype_name: str | None
internal_position: str | None
internal_query: str | None
message_detail: str | None
message_hint: str | None
message_primary: str | None
schema_name: str | None
severity: str | None
severity_nonlocalized: str | None
source_file: str | None
source_function: str | None
source_line: str | None
sqlstate: str | None
statement_position: str | None
table_name: str | None
def __init__(self, __err: Error) -> None: ...
class Error(Exception):
cursor: _Cursor | None
diag: Diagnostics
pgcode: str | None
pgerror: str | None
def __init__(self, *args, **kwargs) -> None: ...
def __reduce__(self): ...
def __setstate__(self, state): ...
class Float:
adapted: Any
def __init__(self, *args, **kwargs) -> None: ...
def getquoted(self, *args, **kwargs): ...
def __conform__(self, *args, **kwargs): ...
class ISQLQuote:
_wrapped: Any
def __init__(self, *args, **kwargs) -> None: ...
def getbinary(self, *args, **kwargs): ...
def getbuffer(self, *args, **kwargs): ...
def getquoted(self, *args, **kwargs) -> bytes: ...
class Int:
adapted: Any
def __init__(self, *args, **kwargs) -> None: ...
def getquoted(self, *args, **kwargs): ...
def __conform__(self, *args, **kwargs): ...
class IntegrityError(psycopg2.DatabaseError): ...
class InterfaceError(psycopg2.Error): ...
class InternalError(psycopg2.DatabaseError): ...
class List:
adapted: Any
def __init__(self, *args, **kwargs) -> None: ...
def getquoted(self, *args, **kwargs): ...
def prepare(self, *args, **kwargs): ...
def __conform__(self, *args, **kwargs): ...
class NotSupportedError(psycopg2.DatabaseError): ...
class Notify:
channel: Any
payload: Any
pid: Any
def __init__(self, *args, **kwargs) -> None: ...
def __eq__(self, __other): ...
def __ge__(self, __other): ...
def __getitem__(self, __index): ...
def __gt__(self, __other): ...
def __hash__(self) -> int: ...
def __le__(self, __other): ...
def __len__(self) -> int: ...
def __lt__(self, __other): ...
def __ne__(self, __other): ...
class OperationalError(psycopg2.DatabaseError): ...
class ProgrammingError(psycopg2.DatabaseError): ...
class QueryCanceledError(psycopg2.OperationalError): ...
class QuotedString:
adapted: Any
buffer: Any
encoding: Any
def __init__(self, *args, **kwargs) -> None: ...
def getquoted(self, *args, **kwargs): ...
def prepare(self, *args, **kwargs): ...
def __conform__(self, *args, **kwargs): ...
class ReplicationConnection(psycopg2.extensions.connection):
autocommit: Any
isolation_level: Any
replication_type: Any
reset: Any
set_isolation_level: Any
set_session: Any
def __init__(self, *args, **kwargs) -> None: ...
class ReplicationCursor(cursor):
feedback_timestamp: Any
io_timestamp: Any
wal_end: Any
def __init__(self, *args, **kwargs) -> None: ...
def consume_stream(self, consumer, keepalive_interval=...): ...
def read_message(self, *args, **kwargs): ...
def send_feedback(self, write_lsn=..., flush_lsn=..., apply_lsn=..., reply=..., force=...): ...
def start_replication_expert(self, command, decode=..., status_interval=...): ...
class ReplicationMessage:
cursor: Any
data_size: Any
data_start: Any
payload: Any
send_time: Any
wal_end: Any
def __init__(self, *args, **kwargs) -> None: ...
class TransactionRollbackError(psycopg2.OperationalError): ...
class Warning(Exception): ...
class Xid:
bqual: Any
database: Any
format_id: Any
gtrid: Any
owner: Any
prepared: Any
def __init__(self, *args, **kwargs) -> None: ...
def from_string(self, *args, **kwargs): ...
def __getitem__(self, __index): ...
def __len__(self) -> int: ...
_T_cur = TypeVar("_T_cur", bound=cursor)
class connection:
DataError: Any
DatabaseError: Any
Error: Any
IntegrityError: Any
InterfaceError: Any
InternalError: Any
NotSupportedError: Any
OperationalError: Any
ProgrammingError: Any
Warning: Any
@property
def async_(self) -> int: ...
autocommit: bool
@property
def binary_types(self) -> Any: ...
@property
def closed(self) -> int: ...
cursor_factory: Callable[..., _Cursor]
@property
def dsn(self) -> str: ...
@property
def encoding(self) -> str: ...
@property
def info(self) -> ConnectionInfo: ...
@property
def isolation_level(self) -> int | None: ...
@isolation_level.setter
def isolation_level(self, __value: str | bytes | int | None) -> None: ...
notices: list[Any]
notifies: list[Any]
@property
def pgconn_ptr(self) -> int | None: ...
@property
def protocol_version(self) -> int: ...
@property
def deferrable(self) -> bool | None: ...
@deferrable.setter
def deferrable(self, __value: Literal["default"] | bool | None) -> None: ...
@property
def readonly(self) -> bool | None: ...
@readonly.setter
def readonly(self, __value: Literal["default"] | bool | None) -> None: ...
@property
def server_version(self) -> int: ...
@property
def status(self) -> int: ...
@property
def string_types(self) -> Any: ...
# Really it's dsn: str, async: int = ..., async_: int = ..., but
# that would be a syntax error.
def __init__(self, dsn: str, *, async_: int = ...) -> None: ...
def cancel(self) -> None: ...
def close(self) -> None: ...
def commit(self) -> None: ...
@overload
def cursor(self, name: str | bytes | None = ..., *, withhold: bool = ..., scrollable: bool | None = ...) -> _Cursor: ...
@overload
def cursor(self, name: str | bytes | None = ..., *, cursor_factory: Callable[..., _T_cur], withhold: bool = ..., scrollable: bool | None = ...) -> _T_cur: ...
def fileno(self) -> int: ...
def get_backend_pid(self) -> int: ...
def get_dsn_parameters(self) -> dict[str, str]: ...
def get_native_connection(self): ...
def get_parameter_status(self, parameter: str) -> str | None: ...
def get_transaction_status(self) -> int: ...
def isexecuting(self) -> bool: ...
def lobject(
self,
oid: int = ...,
mode: str | None = ...,
new_oid: int = ...,
new_file: str | None = ...,
lobject_factory: type[lobject] = ...,
) -> lobject: ...
def poll(self) -> int: ...
def reset(self) -> None: ...
def rollback(self) -> None: ...
def set_client_encoding(self, encoding: str) -> None: ...
def set_isolation_level(self, level: int | None) -> None: ...
def set_session(
self,
isolation_level: str | bytes | int | None = ...,
readonly: bool | Literal["default", b"default"] | None = ...,
deferrable: bool | Literal["default", b"default"] | None = ...,
autocommit: bool = ...,
) -> None: ...
def tpc_begin(self, xid: str | bytes | Xid) -> None: ...
def tpc_commit(self, __xid: str | bytes | Xid = ...) -> None: ...
def tpc_prepare(self) -> None: ...
def tpc_recover(self) -> list[Xid]: ...
def tpc_rollback(self, __xid: str | bytes | Xid = ...) -> None: ...
def xid(self, format_id, gtrid, bqual) -> Xid: ...
def __enter__(self) -> Self: ...
def __exit__(self, __type: type[BaseException] | None, __name: BaseException | None, __tb: TracebackType | None) -> None: ...
class lobject:
closed: Any
mode: Any
oid: Any
def __init__(self, *args, **kwargs) -> None: ...
def close(self): ...
def export(self, filename): ...
def read(self, size=...): ...
def seek(self, offset, whence=...): ...
def tell(self): ...
def truncate(self, len=...): ...
def unlink(self): ...
def write(self, str): ...
def Date(year, month, day): ...
def DateFromPy(*args, **kwargs): ...
def DateFromTicks(ticks): ...
def IntervalFromPy(*args, **kwargs): ...
def Time(hour, minutes, seconds, tzinfo=...): ...
def TimeFromPy(*args, **kwargs): ...
def TimeFromTicks(ticks): ...
def Timestamp(year, month, day, hour, minutes, seconds, tzinfo=...): ...
def TimestampFromPy(*args, **kwargs): ...
def TimestampFromTicks(ticks): ...
def _connect(*args, **kwargs): ...
def adapt(*args, **kwargs): ...
def encrypt_password(*args, **kwargs): ...
def get_wait_callback(*args, **kwargs): ...
def libpq_version(*args, **kwargs): ...
def new_array_type(oids, name, baseobj): ...
def new_type(oids, name, castobj): ...
def parse_dsn(dsn: str | bytes) -> dict[str, Any]: ...
def quote_ident(value: Any, scope: connection | cursor | None) -> str: ...
def register_type(*args, **kwargs): ...
def set_wait_callback(_none): ...

View File

@@ -0,0 +1,62 @@
from _typeshed import Incomplete
from typing import Any
class Range:
def __init__(
self, lower: Incomplete | None = None, upper: Incomplete | None = None, bounds: str = "[)", empty: bool = False
) -> None: ...
@property
def lower(self): ...
@property
def upper(self): ...
@property
def isempty(self): ...
@property
def lower_inf(self): ...
@property
def upper_inf(self): ...
@property
def lower_inc(self): ...
@property
def upper_inc(self): ...
def __contains__(self, x): ...
def __bool__(self) -> bool: ...
def __eq__(self, other): ...
def __ne__(self, other): ...
def __hash__(self) -> int: ...
def __lt__(self, other): ...
def __le__(self, other): ...
def __gt__(self, other): ...
def __ge__(self, other): ...
def register_range(pgrange, pyrange, conn_or_curs, globally: bool = False): ...
class RangeAdapter:
name: Any
adapted: Any
def __init__(self, adapted) -> None: ...
def __conform__(self, proto): ...
def prepare(self, conn) -> None: ...
def getquoted(self): ...
class RangeCaster:
subtype_oid: Any
typecaster: Any
array_typecaster: Any
def __init__(self, pgrange, pyrange, oid, subtype_oid, array_oid: Incomplete | None = None) -> None: ...
def parse(self, s, cur: Incomplete | None = None): ...
class NumericRange(Range): ...
class DateRange(Range): ...
class DateTimeRange(Range): ...
class DateTimeTZRange(Range): ...
class NumberRangeAdapter(RangeAdapter):
def getquoted(self): ...
int4range_caster: Any
int8range_caster: Any
numrange_caster: Any
daterange_caster: Any
tsrange_caster: Any
tstzrange_caster: Any
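The `Range` family backs PostgreSQL range types on the Python side; a small usage sketch:

from psycopg2.extras import NumericRange

r = NumericRange(1, 10)  # default bounds '[)': lower inclusive, upper exclusive
assert 1 in r and 5 in r and 10 not in r
assert NumericRange(empty=True).isempty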

View File

@@ -0,0 +1,304 @@
def lookup(code, _cache={}): ...
CLASS_SUCCESSFUL_COMPLETION: str
CLASS_WARNING: str
CLASS_NO_DATA: str
CLASS_SQL_STATEMENT_NOT_YET_COMPLETE: str
CLASS_CONNECTION_EXCEPTION: str
CLASS_TRIGGERED_ACTION_EXCEPTION: str
CLASS_FEATURE_NOT_SUPPORTED: str
CLASS_INVALID_TRANSACTION_INITIATION: str
CLASS_LOCATOR_EXCEPTION: str
CLASS_INVALID_GRANTOR: str
CLASS_INVALID_ROLE_SPECIFICATION: str
CLASS_DIAGNOSTICS_EXCEPTION: str
CLASS_CASE_NOT_FOUND: str
CLASS_CARDINALITY_VIOLATION: str
CLASS_DATA_EXCEPTION: str
CLASS_INTEGRITY_CONSTRAINT_VIOLATION: str
CLASS_INVALID_CURSOR_STATE: str
CLASS_INVALID_TRANSACTION_STATE: str
CLASS_INVALID_SQL_STATEMENT_NAME: str
CLASS_TRIGGERED_DATA_CHANGE_VIOLATION: str
CLASS_INVALID_AUTHORIZATION_SPECIFICATION: str
CLASS_DEPENDENT_PRIVILEGE_DESCRIPTORS_STILL_EXIST: str
CLASS_INVALID_TRANSACTION_TERMINATION: str
CLASS_SQL_ROUTINE_EXCEPTION: str
CLASS_INVALID_CURSOR_NAME: str
CLASS_EXTERNAL_ROUTINE_EXCEPTION: str
CLASS_EXTERNAL_ROUTINE_INVOCATION_EXCEPTION: str
CLASS_SAVEPOINT_EXCEPTION: str
CLASS_INVALID_CATALOG_NAME: str
CLASS_INVALID_SCHEMA_NAME: str
CLASS_TRANSACTION_ROLLBACK: str
CLASS_SYNTAX_ERROR_OR_ACCESS_RULE_VIOLATION: str
CLASS_WITH_CHECK_OPTION_VIOLATION: str
CLASS_INSUFFICIENT_RESOURCES: str
CLASS_PROGRAM_LIMIT_EXCEEDED: str
CLASS_OBJECT_NOT_IN_PREREQUISITE_STATE: str
CLASS_OPERATOR_INTERVENTION: str
CLASS_SYSTEM_ERROR: str
CLASS_SNAPSHOT_FAILURE: str
CLASS_CONFIGURATION_FILE_ERROR: str
CLASS_FOREIGN_DATA_WRAPPER_ERROR: str
CLASS_PL_PGSQL_ERROR: str
CLASS_INTERNAL_ERROR: str
SUCCESSFUL_COMPLETION: str
WARNING: str
NULL_VALUE_ELIMINATED_IN_SET_FUNCTION: str
STRING_DATA_RIGHT_TRUNCATION_: str
PRIVILEGE_NOT_REVOKED: str
PRIVILEGE_NOT_GRANTED: str
IMPLICIT_ZERO_BIT_PADDING: str
DYNAMIC_RESULT_SETS_RETURNED: str
DEPRECATED_FEATURE: str
NO_DATA: str
NO_ADDITIONAL_DYNAMIC_RESULT_SETS_RETURNED: str
SQL_STATEMENT_NOT_YET_COMPLETE: str
CONNECTION_EXCEPTION: str
SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION: str
CONNECTION_DOES_NOT_EXIST: str
SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION: str
CONNECTION_FAILURE: str
TRANSACTION_RESOLUTION_UNKNOWN: str
PROTOCOL_VIOLATION: str
TRIGGERED_ACTION_EXCEPTION: str
FEATURE_NOT_SUPPORTED: str
INVALID_TRANSACTION_INITIATION: str
LOCATOR_EXCEPTION: str
INVALID_LOCATOR_SPECIFICATION: str
INVALID_GRANTOR: str
INVALID_GRANT_OPERATION: str
INVALID_ROLE_SPECIFICATION: str
DIAGNOSTICS_EXCEPTION: str
STACKED_DIAGNOSTICS_ACCESSED_WITHOUT_ACTIVE_HANDLER: str
CASE_NOT_FOUND: str
CARDINALITY_VIOLATION: str
DATA_EXCEPTION: str
STRING_DATA_RIGHT_TRUNCATION: str
NULL_VALUE_NO_INDICATOR_PARAMETER: str
NUMERIC_VALUE_OUT_OF_RANGE: str
NULL_VALUE_NOT_ALLOWED_: str
ERROR_IN_ASSIGNMENT: str
INVALID_DATETIME_FORMAT: str
DATETIME_FIELD_OVERFLOW: str
INVALID_TIME_ZONE_DISPLACEMENT_VALUE: str
ESCAPE_CHARACTER_CONFLICT: str
INVALID_USE_OF_ESCAPE_CHARACTER: str
INVALID_ESCAPE_OCTET: str
ZERO_LENGTH_CHARACTER_STRING: str
MOST_SPECIFIC_TYPE_MISMATCH: str
SEQUENCE_GENERATOR_LIMIT_EXCEEDED: str
NOT_AN_XML_DOCUMENT: str
INVALID_XML_DOCUMENT: str
INVALID_XML_CONTENT: str
INVALID_XML_COMMENT: str
INVALID_XML_PROCESSING_INSTRUCTION: str
INVALID_INDICATOR_PARAMETER_VALUE: str
SUBSTRING_ERROR: str
DIVISION_BY_ZERO: str
INVALID_PRECEDING_OR_FOLLOWING_SIZE: str
INVALID_ARGUMENT_FOR_NTILE_FUNCTION: str
INTERVAL_FIELD_OVERFLOW: str
INVALID_ARGUMENT_FOR_NTH_VALUE_FUNCTION: str
INVALID_CHARACTER_VALUE_FOR_CAST: str
INVALID_ESCAPE_CHARACTER: str
INVALID_REGULAR_EXPRESSION: str
INVALID_ARGUMENT_FOR_LOGARITHM: str
INVALID_ARGUMENT_FOR_POWER_FUNCTION: str
INVALID_ARGUMENT_FOR_WIDTH_BUCKET_FUNCTION: str
INVALID_ROW_COUNT_IN_LIMIT_CLAUSE: str
INVALID_ROW_COUNT_IN_RESULT_OFFSET_CLAUSE: str
INVALID_LIMIT_VALUE: str
CHARACTER_NOT_IN_REPERTOIRE: str
INDICATOR_OVERFLOW: str
INVALID_PARAMETER_VALUE: str
UNTERMINATED_C_STRING: str
INVALID_ESCAPE_SEQUENCE: str
STRING_DATA_LENGTH_MISMATCH: str
TRIM_ERROR: str
ARRAY_SUBSCRIPT_ERROR: str
INVALID_TABLESAMPLE_REPEAT: str
INVALID_TABLESAMPLE_ARGUMENT: str
DUPLICATE_JSON_OBJECT_KEY_VALUE: str
INVALID_ARGUMENT_FOR_SQL_JSON_DATETIME_FUNCTION: str
INVALID_JSON_TEXT: str
INVALID_SQL_JSON_SUBSCRIPT: str
MORE_THAN_ONE_SQL_JSON_ITEM: str
NO_SQL_JSON_ITEM: str
NON_NUMERIC_SQL_JSON_ITEM: str
NON_UNIQUE_KEYS_IN_A_JSON_OBJECT: str
SINGLETON_SQL_JSON_ITEM_REQUIRED: str
SQL_JSON_ARRAY_NOT_FOUND: str
SQL_JSON_MEMBER_NOT_FOUND: str
SQL_JSON_NUMBER_NOT_FOUND: str
SQL_JSON_OBJECT_NOT_FOUND: str
TOO_MANY_JSON_ARRAY_ELEMENTS: str
TOO_MANY_JSON_OBJECT_MEMBERS: str
SQL_JSON_SCALAR_REQUIRED: str
FLOATING_POINT_EXCEPTION: str
INVALID_TEXT_REPRESENTATION: str
INVALID_BINARY_REPRESENTATION: str
BAD_COPY_FILE_FORMAT: str
UNTRANSLATABLE_CHARACTER: str
NONSTANDARD_USE_OF_ESCAPE_CHARACTER: str
INTEGRITY_CONSTRAINT_VIOLATION: str
RESTRICT_VIOLATION: str
NOT_NULL_VIOLATION: str
FOREIGN_KEY_VIOLATION: str
UNIQUE_VIOLATION: str
CHECK_VIOLATION: str
EXCLUSION_VIOLATION: str
INVALID_CURSOR_STATE: str
INVALID_TRANSACTION_STATE: str
ACTIVE_SQL_TRANSACTION: str
BRANCH_TRANSACTION_ALREADY_ACTIVE: str
INAPPROPRIATE_ACCESS_MODE_FOR_BRANCH_TRANSACTION: str
INAPPROPRIATE_ISOLATION_LEVEL_FOR_BRANCH_TRANSACTION: str
NO_ACTIVE_SQL_TRANSACTION_FOR_BRANCH_TRANSACTION: str
READ_ONLY_SQL_TRANSACTION: str
SCHEMA_AND_DATA_STATEMENT_MIXING_NOT_SUPPORTED: str
HELD_CURSOR_REQUIRES_SAME_ISOLATION_LEVEL: str
NO_ACTIVE_SQL_TRANSACTION: str
IN_FAILED_SQL_TRANSACTION: str
IDLE_IN_TRANSACTION_SESSION_TIMEOUT: str
INVALID_SQL_STATEMENT_NAME: str
TRIGGERED_DATA_CHANGE_VIOLATION: str
INVALID_AUTHORIZATION_SPECIFICATION: str
INVALID_PASSWORD: str
DEPENDENT_PRIVILEGE_DESCRIPTORS_STILL_EXIST: str
DEPENDENT_OBJECTS_STILL_EXIST: str
INVALID_TRANSACTION_TERMINATION: str
SQL_ROUTINE_EXCEPTION: str
MODIFYING_SQL_DATA_NOT_PERMITTED_: str
PROHIBITED_SQL_STATEMENT_ATTEMPTED_: str
READING_SQL_DATA_NOT_PERMITTED_: str
FUNCTION_EXECUTED_NO_RETURN_STATEMENT: str
INVALID_CURSOR_NAME: str
EXTERNAL_ROUTINE_EXCEPTION: str
CONTAINING_SQL_NOT_PERMITTED: str
MODIFYING_SQL_DATA_NOT_PERMITTED: str
PROHIBITED_SQL_STATEMENT_ATTEMPTED: str
READING_SQL_DATA_NOT_PERMITTED: str
EXTERNAL_ROUTINE_INVOCATION_EXCEPTION: str
INVALID_SQLSTATE_RETURNED: str
NULL_VALUE_NOT_ALLOWED: str
TRIGGER_PROTOCOL_VIOLATED: str
SRF_PROTOCOL_VIOLATED: str
EVENT_TRIGGER_PROTOCOL_VIOLATED: str
SAVEPOINT_EXCEPTION: str
INVALID_SAVEPOINT_SPECIFICATION: str
INVALID_CATALOG_NAME: str
INVALID_SCHEMA_NAME: str
TRANSACTION_ROLLBACK: str
SERIALIZATION_FAILURE: str
TRANSACTION_INTEGRITY_CONSTRAINT_VIOLATION: str
STATEMENT_COMPLETION_UNKNOWN: str
DEADLOCK_DETECTED: str
SYNTAX_ERROR_OR_ACCESS_RULE_VIOLATION: str
INSUFFICIENT_PRIVILEGE: str
SYNTAX_ERROR: str
INVALID_NAME: str
INVALID_COLUMN_DEFINITION: str
NAME_TOO_LONG: str
DUPLICATE_COLUMN: str
AMBIGUOUS_COLUMN: str
UNDEFINED_COLUMN: str
UNDEFINED_OBJECT: str
DUPLICATE_OBJECT: str
DUPLICATE_ALIAS: str
DUPLICATE_FUNCTION: str
AMBIGUOUS_FUNCTION: str
GROUPING_ERROR: str
DATATYPE_MISMATCH: str
WRONG_OBJECT_TYPE: str
INVALID_FOREIGN_KEY: str
CANNOT_COERCE: str
UNDEFINED_FUNCTION: str
GENERATED_ALWAYS: str
RESERVED_NAME: str
UNDEFINED_TABLE: str
UNDEFINED_PARAMETER: str
DUPLICATE_CURSOR: str
DUPLICATE_DATABASE: str
DUPLICATE_PREPARED_STATEMENT: str
DUPLICATE_SCHEMA: str
DUPLICATE_TABLE: str
AMBIGUOUS_PARAMETER: str
AMBIGUOUS_ALIAS: str
INVALID_COLUMN_REFERENCE: str
INVALID_CURSOR_DEFINITION: str
INVALID_DATABASE_DEFINITION: str
INVALID_FUNCTION_DEFINITION: str
INVALID_PREPARED_STATEMENT_DEFINITION: str
INVALID_SCHEMA_DEFINITION: str
INVALID_TABLE_DEFINITION: str
INVALID_OBJECT_DEFINITION: str
INDETERMINATE_DATATYPE: str
INVALID_RECURSION: str
WINDOWING_ERROR: str
COLLATION_MISMATCH: str
INDETERMINATE_COLLATION: str
WITH_CHECK_OPTION_VIOLATION: str
INSUFFICIENT_RESOURCES: str
DISK_FULL: str
OUT_OF_MEMORY: str
TOO_MANY_CONNECTIONS: str
CONFIGURATION_LIMIT_EXCEEDED: str
PROGRAM_LIMIT_EXCEEDED: str
STATEMENT_TOO_COMPLEX: str
TOO_MANY_COLUMNS: str
TOO_MANY_ARGUMENTS: str
OBJECT_NOT_IN_PREREQUISITE_STATE: str
OBJECT_IN_USE: str
CANT_CHANGE_RUNTIME_PARAM: str
LOCK_NOT_AVAILABLE: str
UNSAFE_NEW_ENUM_VALUE_USAGE: str
OPERATOR_INTERVENTION: str
QUERY_CANCELED: str
ADMIN_SHUTDOWN: str
CRASH_SHUTDOWN: str
CANNOT_CONNECT_NOW: str
DATABASE_DROPPED: str
SYSTEM_ERROR: str
IO_ERROR: str
UNDEFINED_FILE: str
DUPLICATE_FILE: str
SNAPSHOT_TOO_OLD: str
CONFIG_FILE_ERROR: str
LOCK_FILE_EXISTS: str
FDW_ERROR: str
FDW_OUT_OF_MEMORY: str
FDW_DYNAMIC_PARAMETER_VALUE_NEEDED: str
FDW_INVALID_DATA_TYPE: str
FDW_COLUMN_NAME_NOT_FOUND: str
FDW_INVALID_DATA_TYPE_DESCRIPTORS: str
FDW_INVALID_COLUMN_NAME: str
FDW_INVALID_COLUMN_NUMBER: str
FDW_INVALID_USE_OF_NULL_POINTER: str
FDW_INVALID_STRING_FORMAT: str
FDW_INVALID_HANDLE: str
FDW_INVALID_OPTION_INDEX: str
FDW_INVALID_OPTION_NAME: str
FDW_OPTION_NAME_NOT_FOUND: str
FDW_REPLY_HANDLE: str
FDW_UNABLE_TO_CREATE_EXECUTION: str
FDW_UNABLE_TO_CREATE_REPLY: str
FDW_UNABLE_TO_ESTABLISH_CONNECTION: str
FDW_NO_SCHEMAS: str
FDW_SCHEMA_NOT_FOUND: str
FDW_TABLE_NOT_FOUND: str
FDW_FUNCTION_SEQUENCE_ERROR: str
FDW_TOO_MANY_HANDLES: str
FDW_INCONSISTENT_DESCRIPTOR_INFORMATION: str
FDW_INVALID_ATTRIBUTE_VALUE: str
FDW_INVALID_STRING_LENGTH_OR_BUFFER_LENGTH: str
FDW_INVALID_DESCRIPTOR_FIELD_IDENTIFIER: str
PLPGSQL_ERROR: str
RAISE_EXCEPTION: str
NO_DATA_FOUND: str
TOO_MANY_ROWS: str
ASSERT_FAILURE: str
INTERNAL_ERROR: str
DATA_CORRUPTED: str
INDEX_CORRUPTED: str
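These constants map symbolic names to five-character SQLSTATE codes, and `lookup()` goes the other way:

from psycopg2 import errorcodes

assert errorcodes.UNIQUE_VIOLATION == '23505'
assert errorcodes.lookup('23505') == 'UNIQUE_VIOLATION'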

263 typings/psycopg2/errors.pyi Normal file
View File

@@ -0,0 +1,263 @@
from psycopg2._psycopg import Error as Error, Warning as Warning
class DatabaseError(Error): ...
class InterfaceError(Error): ...
class DataError(DatabaseError): ...
class DiagnosticsException(DatabaseError): ...
class IntegrityError(DatabaseError): ...
class InternalError(DatabaseError): ...
class InvalidGrantOperation(DatabaseError): ...
class InvalidGrantor(DatabaseError): ...
class InvalidLocatorSpecification(DatabaseError): ...
class InvalidRoleSpecification(DatabaseError): ...
class InvalidTransactionInitiation(DatabaseError): ...
class LocatorException(DatabaseError): ...
class NoAdditionalDynamicResultSetsReturned(DatabaseError): ...
class NoData(DatabaseError): ...
class NotSupportedError(DatabaseError): ...
class OperationalError(DatabaseError): ...
class ProgrammingError(DatabaseError): ...
class SnapshotTooOld(DatabaseError): ...
class SqlStatementNotYetComplete(DatabaseError): ...
class StackedDiagnosticsAccessedWithoutActiveHandler(DatabaseError): ...
class TriggeredActionException(DatabaseError): ...
class ActiveSqlTransaction(InternalError): ...
class AdminShutdown(OperationalError): ...
class AmbiguousAlias(ProgrammingError): ...
class AmbiguousColumn(ProgrammingError): ...
class AmbiguousFunction(ProgrammingError): ...
class AmbiguousParameter(ProgrammingError): ...
class ArraySubscriptError(DataError): ...
class AssertFailure(InternalError): ...
class BadCopyFileFormat(DataError): ...
class BranchTransactionAlreadyActive(InternalError): ...
class CannotCoerce(ProgrammingError): ...
class CannotConnectNow(OperationalError): ...
class CantChangeRuntimeParam(OperationalError): ...
class CardinalityViolation(ProgrammingError): ...
class CaseNotFound(ProgrammingError): ...
class CharacterNotInRepertoire(DataError): ...
class CheckViolation(IntegrityError): ...
class CollationMismatch(ProgrammingError): ...
class ConfigFileError(InternalError): ...
class ConfigurationLimitExceeded(OperationalError): ...
class ConnectionDoesNotExist(OperationalError): ...
class ConnectionException(OperationalError): ...
class ConnectionFailure(OperationalError): ...
class ContainingSqlNotPermitted(InternalError): ...
class CrashShutdown(OperationalError): ...
class DataCorrupted(InternalError): ...
class DataException(DataError): ...
class DatabaseDropped(OperationalError): ...
class DatatypeMismatch(ProgrammingError): ...
class DatetimeFieldOverflow(DataError): ...
class DependentObjectsStillExist(InternalError): ...
class DependentPrivilegeDescriptorsStillExist(InternalError): ...
class DiskFull(OperationalError): ...
class DivisionByZero(DataError): ...
class DuplicateAlias(ProgrammingError): ...
class DuplicateColumn(ProgrammingError): ...
class DuplicateCursor(ProgrammingError): ...
class DuplicateDatabase(ProgrammingError): ...
class DuplicateFile(OperationalError): ...
class DuplicateFunction(ProgrammingError): ...
class DuplicateJsonObjectKeyValue(DataError): ...
class DuplicateObject(ProgrammingError): ...
class DuplicatePreparedStatement(ProgrammingError): ...
class DuplicateSchema(ProgrammingError): ...
class DuplicateTable(ProgrammingError): ...
class ErrorInAssignment(DataError): ...
class EscapeCharacterConflict(DataError): ...
class EventTriggerProtocolViolated(InternalError): ...
class ExclusionViolation(IntegrityError): ...
class ExternalRoutineException(InternalError): ...
class ExternalRoutineInvocationException(InternalError): ...
class FdwColumnNameNotFound(OperationalError): ...
class FdwDynamicParameterValueNeeded(OperationalError): ...
class FdwError(OperationalError): ...
class FdwFunctionSequenceError(OperationalError): ...
class FdwInconsistentDescriptorInformation(OperationalError): ...
class FdwInvalidAttributeValue(OperationalError): ...
class FdwInvalidColumnName(OperationalError): ...
class FdwInvalidColumnNumber(OperationalError): ...
class FdwInvalidDataType(OperationalError): ...
class FdwInvalidDataTypeDescriptors(OperationalError): ...
class FdwInvalidDescriptorFieldIdentifier(OperationalError): ...
class FdwInvalidHandle(OperationalError): ...
class FdwInvalidOptionIndex(OperationalError): ...
class FdwInvalidOptionName(OperationalError): ...
class FdwInvalidStringFormat(OperationalError): ...
class FdwInvalidStringLengthOrBufferLength(OperationalError): ...
class FdwInvalidUseOfNullPointer(OperationalError): ...
class FdwNoSchemas(OperationalError): ...
class FdwOptionNameNotFound(OperationalError): ...
class FdwOutOfMemory(OperationalError): ...
class FdwReplyHandle(OperationalError): ...
class FdwSchemaNotFound(OperationalError): ...
class FdwTableNotFound(OperationalError): ...
class FdwTooManyHandles(OperationalError): ...
class FdwUnableToCreateExecution(OperationalError): ...
class FdwUnableToCreateReply(OperationalError): ...
class FdwUnableToEstablishConnection(OperationalError): ...
class FeatureNotSupported(NotSupportedError): ...
class FloatingPointException(DataError): ...
class ForeignKeyViolation(IntegrityError): ...
class FunctionExecutedNoReturnStatement(InternalError): ...
class GeneratedAlways(ProgrammingError): ...
class GroupingError(ProgrammingError): ...
class HeldCursorRequiresSameIsolationLevel(InternalError): ...
class IdleInTransactionSessionTimeout(InternalError): ...
class InFailedSqlTransaction(InternalError): ...
class InappropriateAccessModeForBranchTransaction(InternalError): ...
class InappropriateIsolationLevelForBranchTransaction(InternalError): ...
class IndeterminateCollation(ProgrammingError): ...
class IndeterminateDatatype(ProgrammingError): ...
class IndexCorrupted(InternalError): ...
class IndicatorOverflow(DataError): ...
class InsufficientPrivilege(ProgrammingError): ...
class InsufficientResources(OperationalError): ...
class IntegrityConstraintViolation(IntegrityError): ...
class InternalError_(InternalError): ...
class IntervalFieldOverflow(DataError): ...
class InvalidArgumentForLogarithm(DataError): ...
class InvalidArgumentForNthValueFunction(DataError): ...
class InvalidArgumentForNtileFunction(DataError): ...
class InvalidArgumentForPowerFunction(DataError): ...
class InvalidArgumentForSqlJsonDatetimeFunction(DataError): ...
class InvalidArgumentForWidthBucketFunction(DataError): ...
class InvalidAuthorizationSpecification(OperationalError): ...
class InvalidBinaryRepresentation(DataError): ...
class InvalidCatalogName(ProgrammingError): ...
class InvalidCharacterValueForCast(DataError): ...
class InvalidColumnDefinition(ProgrammingError): ...
class InvalidColumnReference(ProgrammingError): ...
class InvalidCursorDefinition(ProgrammingError): ...
class InvalidCursorName(OperationalError): ...
class InvalidCursorState(InternalError): ...
class InvalidDatabaseDefinition(ProgrammingError): ...
class InvalidDatetimeFormat(DataError): ...
class InvalidEscapeCharacter(DataError): ...
class InvalidEscapeOctet(DataError): ...
class InvalidEscapeSequence(DataError): ...
class InvalidForeignKey(ProgrammingError): ...
class InvalidFunctionDefinition(ProgrammingError): ...
class InvalidIndicatorParameterValue(DataError): ...
class InvalidJsonText(DataError): ...
class InvalidName(ProgrammingError): ...
class InvalidObjectDefinition(ProgrammingError): ...
class InvalidParameterValue(DataError): ...
class InvalidPassword(OperationalError): ...
class InvalidPrecedingOrFollowingSize(DataError): ...
class InvalidPreparedStatementDefinition(ProgrammingError): ...
class InvalidRecursion(ProgrammingError): ...
class InvalidRegularExpression(DataError): ...
class InvalidRowCountInLimitClause(DataError): ...
class InvalidRowCountInResultOffsetClause(DataError): ...
class InvalidSavepointSpecification(InternalError): ...
class InvalidSchemaDefinition(ProgrammingError): ...
class InvalidSchemaName(ProgrammingError): ...
class InvalidSqlJsonSubscript(DataError): ...
class InvalidSqlStatementName(OperationalError): ...
class InvalidSqlstateReturned(InternalError): ...
class InvalidTableDefinition(ProgrammingError): ...
class InvalidTablesampleArgument(DataError): ...
class InvalidTablesampleRepeat(DataError): ...
class InvalidTextRepresentation(DataError): ...
class InvalidTimeZoneDisplacementValue(DataError): ...
class InvalidTransactionState(InternalError): ...
class InvalidTransactionTermination(InternalError): ...
class InvalidUseOfEscapeCharacter(DataError): ...
class InvalidXmlComment(DataError): ...
class InvalidXmlContent(DataError): ...
class InvalidXmlDocument(DataError): ...
class InvalidXmlProcessingInstruction(DataError): ...
class IoError(OperationalError): ...
class LockFileExists(InternalError): ...
class LockNotAvailable(OperationalError): ...
class ModifyingSqlDataNotPermitted(InternalError): ...
class ModifyingSqlDataNotPermittedExt(InternalError): ...
class MoreThanOneSqlJsonItem(DataError): ...
class MostSpecificTypeMismatch(DataError): ...
class NameTooLong(ProgrammingError): ...
class NoActiveSqlTransaction(InternalError): ...
class NoActiveSqlTransactionForBranchTransaction(InternalError): ...
class NoDataFound(InternalError): ...
class NoSqlJsonItem(DataError): ...
class NonNumericSqlJsonItem(DataError): ...
class NonUniqueKeysInAJsonObject(DataError): ...
class NonstandardUseOfEscapeCharacter(DataError): ...
class NotAnXmlDocument(DataError): ...
class NotNullViolation(IntegrityError): ...
class NullValueNoIndicatorParameter(DataError): ...
class NullValueNotAllowed(DataError): ...
class NullValueNotAllowedExt(InternalError): ...
class NumericValueOutOfRange(DataError): ...
class ObjectInUse(OperationalError): ...
class ObjectNotInPrerequisiteState(OperationalError): ...
class OperatorIntervention(OperationalError): ...
class OutOfMemory(OperationalError): ...
class PlpgsqlError(InternalError): ...
class ProgramLimitExceeded(OperationalError): ...
class ProhibitedSqlStatementAttempted(InternalError): ...
class ProhibitedSqlStatementAttemptedExt(InternalError): ...
class ProtocolViolation(OperationalError): ...
class QueryCanceledError(OperationalError): ...
class RaiseException(InternalError): ...
class ReadOnlySqlTransaction(InternalError): ...
class ReadingSqlDataNotPermitted(InternalError): ...
class ReadingSqlDataNotPermittedExt(InternalError): ...
class ReservedName(ProgrammingError): ...
class RestrictViolation(IntegrityError): ...
class SavepointException(InternalError): ...
class SchemaAndDataStatementMixingNotSupported(InternalError): ...
class SequenceGeneratorLimitExceeded(DataError): ...
class SingletonSqlJsonItemRequired(DataError): ...
class SqlJsonArrayNotFound(DataError): ...
class SqlJsonMemberNotFound(DataError): ...
class SqlJsonNumberNotFound(DataError): ...
class SqlJsonObjectNotFound(DataError): ...
class SqlJsonScalarRequired(DataError): ...
class SqlRoutineException(InternalError): ...
class SqlclientUnableToEstablishSqlconnection(OperationalError): ...
class SqlserverRejectedEstablishmentOfSqlconnection(OperationalError): ...
class SrfProtocolViolated(InternalError): ...
class StatementTooComplex(OperationalError): ...
class StringDataLengthMismatch(DataError): ...
class StringDataRightTruncation(DataError): ...
class SubstringError(DataError): ...
class SyntaxError(ProgrammingError): ...
class SyntaxErrorOrAccessRuleViolation(ProgrammingError): ...
class SystemError(OperationalError): ...
class TooManyArguments(OperationalError): ...
class TooManyColumns(OperationalError): ...
class TooManyConnections(OperationalError): ...
class TooManyJsonArrayElements(DataError): ...
class TooManyJsonObjectMembers(DataError): ...
class TooManyRows(InternalError): ...
class TransactionResolutionUnknown(OperationalError): ...
class TransactionRollbackError(OperationalError): ...
class TriggerProtocolViolated(InternalError): ...
class TriggeredDataChangeViolation(OperationalError): ...
class TrimError(DataError): ...
class UndefinedColumn(ProgrammingError): ...
class UndefinedFile(OperationalError): ...
class UndefinedFunction(ProgrammingError): ...
class UndefinedObject(ProgrammingError): ...
class UndefinedParameter(ProgrammingError): ...
class UndefinedTable(ProgrammingError): ...
class UniqueViolation(IntegrityError): ...
class UnsafeNewEnumValueUsage(OperationalError): ...
class UnterminatedCString(DataError): ...
class UntranslatableCharacter(DataError): ...
class WindowingError(ProgrammingError): ...
class WithCheckOptionViolation(ProgrammingError): ...
class WrongObjectType(ProgrammingError): ...
class ZeroLengthCharacterString(DataError): ...
class DeadlockDetected(TransactionRollbackError): ...
class QueryCanceled(QueryCanceledError): ...
class SerializationFailure(TransactionRollbackError): ...
class StatementCompletionUnknown(TransactionRollbackError): ...
class TransactionIntegrityConstraintViolation(TransactionRollbackError): ...
class TransactionRollback(TransactionRollbackError): ...
def lookup(code: str) -> type[Error]: ...
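
For context, lookup() resolves a five-character SQLSTATE into one of the exception classes above. A minimal sketch of typical usage (the DSN and table are placeholders, not part of the stub):

import psycopg2
from psycopg2 import errors

conn = psycopg2.connect("dbname=test")  # placeholder DSN
try:
    with conn.cursor() as cur:
        cur.execute("INSERT INTO t VALUES (1)")  # assumes a table t with a unique constraint
    conn.commit()
except errors.lookup("23505"):  # SQLSTATE 23505, i.e. errors.UniqueViolation
    conn.rollback()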

View File

@@ -0,0 +1,117 @@
from _typeshed import Incomplete
from typing import Any
from psycopg2._psycopg import (
BINARYARRAY as BINARYARRAY,
BOOLEAN as BOOLEAN,
BOOLEANARRAY as BOOLEANARRAY,
BYTES as BYTES,
BYTESARRAY as BYTESARRAY,
DATE as DATE,
DATEARRAY as DATEARRAY,
DATETIMEARRAY as DATETIMEARRAY,
DECIMAL as DECIMAL,
DECIMALARRAY as DECIMALARRAY,
FLOAT as FLOAT,
FLOATARRAY as FLOATARRAY,
INTEGER as INTEGER,
INTEGERARRAY as INTEGERARRAY,
INTERVAL as INTERVAL,
INTERVALARRAY as INTERVALARRAY,
LONGINTEGER as LONGINTEGER,
LONGINTEGERARRAY as LONGINTEGERARRAY,
PYDATE as PYDATE,
PYDATEARRAY as PYDATEARRAY,
PYDATETIME as PYDATETIME,
PYDATETIMEARRAY as PYDATETIMEARRAY,
PYDATETIMETZ as PYDATETIMETZ,
PYDATETIMETZARRAY as PYDATETIMETZARRAY,
PYINTERVAL as PYINTERVAL,
PYINTERVALARRAY as PYINTERVALARRAY,
PYTIME as PYTIME,
PYTIMEARRAY as PYTIMEARRAY,
ROWIDARRAY as ROWIDARRAY,
STRINGARRAY as STRINGARRAY,
TIME as TIME,
TIMEARRAY as TIMEARRAY,
UNICODE as UNICODE,
UNICODEARRAY as UNICODEARRAY,
AsIs as AsIs,
Binary as Binary,
Boolean as Boolean,
Column as Column,
ConnectionInfo as ConnectionInfo,
DateFromPy as DateFromPy,
Diagnostics as Diagnostics,
Float as Float,
Int as Int,
IntervalFromPy as IntervalFromPy,
ISQLQuote as ISQLQuote,
Notify as Notify,
QueryCanceledError as QueryCanceledError,
QuotedString as QuotedString,
TimeFromPy as TimeFromPy,
TimestampFromPy as TimestampFromPy,
TransactionRollbackError as TransactionRollbackError,
Xid as Xid,
adapters as adapters,
binary_types as binary_types,
connection as connection,
cursor as cursor,
encodings as encodings,
encrypt_password as encrypt_password,
get_wait_callback as get_wait_callback,
libpq_version as libpq_version,
lobject as lobject,
new_array_type as new_array_type,
new_type as new_type,
parse_dsn as parse_dsn,
quote_ident as quote_ident,
register_type as register_type,
set_wait_callback as set_wait_callback,
string_types as string_types,
)
ISOLATION_LEVEL_AUTOCOMMIT: int
ISOLATION_LEVEL_READ_UNCOMMITTED: int
ISOLATION_LEVEL_READ_COMMITTED: int
ISOLATION_LEVEL_REPEATABLE_READ: int
ISOLATION_LEVEL_SERIALIZABLE: int
ISOLATION_LEVEL_DEFAULT: Any
STATUS_SETUP: int
STATUS_READY: int
STATUS_BEGIN: int
STATUS_SYNC: int
STATUS_ASYNC: int
STATUS_PREPARED: int
STATUS_IN_TRANSACTION: int
POLL_OK: int
POLL_READ: int
POLL_WRITE: int
POLL_ERROR: int
TRANSACTION_STATUS_IDLE: int
TRANSACTION_STATUS_ACTIVE: int
TRANSACTION_STATUS_INTRANS: int
TRANSACTION_STATUS_INERROR: int
TRANSACTION_STATUS_UNKNOWN: int
def register_adapter(typ, callable) -> None: ...
class SQL_IN:
def __init__(self, seq) -> None: ...
def prepare(self, conn) -> None: ...
def getquoted(self): ...
class NoneAdapter:
def __init__(self, obj) -> None: ...
def getquoted(self, _null: bytes = b"NULL"): ...
def make_dsn(dsn: Incomplete | None = None, **kwargs): ...
JSON: Any
JSONARRAY: Any
JSONB: Any
JSONBARRAY: Any
def adapt(obj: Any) -> ISQLQuote: ...
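
For context, a sketch of the surface these declarations cover: pinning an isolation level and registering an adapter through AsIs (the DSN and the Point class are illustrative):

import psycopg2
from psycopg2 import extensions

conn = psycopg2.connect("dbname=test")  # placeholder DSN
conn.set_isolation_level(extensions.ISOLATION_LEVEL_AUTOCOMMIT)

class Point:
    def __init__(self, x: float, y: float) -> None:
        self.x, self.y = x, y

# register_adapter maps a Python type to its SQL literal representation
extensions.register_adapter(Point, lambda p: extensions.AsIs("'(%s,%s)'" % (p.x, p.y)))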

240
typings/psycopg2/extras.pyi Normal file
View File

@@ -0,0 +1,240 @@
from _typeshed import Incomplete
from collections import OrderedDict
from collections.abc import Callable
from typing import Any, NamedTuple, TypeVar, overload
from psycopg2._ipaddress import register_ipaddress as register_ipaddress
from psycopg2._json import (
Json as Json,
register_default_json as register_default_json,
register_default_jsonb as register_default_jsonb,
register_json as register_json,
)
from psycopg2._psycopg import (
REPLICATION_LOGICAL as REPLICATION_LOGICAL,
REPLICATION_PHYSICAL as REPLICATION_PHYSICAL,
ReplicationConnection as _replicationConnection,
ReplicationCursor as _replicationCursor,
ReplicationMessage as ReplicationMessage,
)
from psycopg2._range import (
DateRange as DateRange,
DateTimeRange as DateTimeRange,
DateTimeTZRange as DateTimeTZRange,
NumericRange as NumericRange,
Range as Range,
RangeAdapter as RangeAdapter,
RangeCaster as RangeCaster,
register_range as register_range,
)
from .extensions import connection as _connection, cursor as _cursor, quote_ident as quote_ident
_T_cur = TypeVar("_T_cur", bound=_cursor)
class DictCursorBase(_cursor):
def __init__(self, *args, **kwargs) -> None: ...
class DictConnection(_connection):
@overload
def cursor(self, name: str | bytes | None = ..., *, withhold: bool = ..., scrollable: bool | None = ...) -> DictCursor: ...
@overload
def cursor(
self,
name: str | bytes | None = ...,
*,
cursor_factory: Callable[..., _T_cur],
withhold: bool = ...,
scrollable: bool | None = ...,
) -> _T_cur: ...
@overload
def cursor(
self, name: str | bytes | None, cursor_factory: Callable[..., _T_cur], withhold: bool = ..., scrollable: bool | None = ...
) -> _T_cur: ...
class DictCursor(DictCursorBase):
def __init__(self, *args, **kwargs) -> None: ...
index: Any
def execute(self, query, vars: Incomplete | None = None): ...
def callproc(self, procname, vars: Incomplete | None = None): ...
def fetchone(self) -> DictRow | None: ... # type: ignore[override]
def fetchmany(self, size: int | None = None) -> list[DictRow]: ... # type: ignore[override]
def fetchall(self) -> list[DictRow]: ... # type: ignore[override]
def __next__(self) -> DictRow: ... # type: ignore[override]
class DictRow(list[Any]):
def __init__(self, cursor) -> None: ...
def __getitem__(self, x): ...
def __setitem__(self, x, v) -> None: ...
def items(self): ...
def keys(self): ...
def values(self): ...
def get(self, x, default: Incomplete | None = None): ...
def copy(self): ...
def __contains__(self, x): ...
def __reduce__(self): ...
class RealDictConnection(_connection):
@overload
def cursor(
self, name: str | bytes | None = ..., *, withhold: bool = ..., scrollable: bool | None = ...
) -> RealDictCursor: ...
@overload
def cursor(
self,
name: str | bytes | None = ...,
*,
cursor_factory: Callable[..., _T_cur],
withhold: bool = ...,
scrollable: bool | None = ...,
) -> _T_cur: ...
@overload
def cursor(
self, name: str | bytes | None, cursor_factory: Callable[..., _T_cur], withhold: bool = ..., scrollable: bool | None = ...
) -> _T_cur: ...
class RealDictCursor(DictCursorBase):
def __init__(self, *args, **kwargs) -> None: ...
column_mapping: Any
def execute(self, query, vars: Incomplete | None = None): ...
def callproc(self, procname, vars: Incomplete | None = None): ...
def fetchone(self) -> RealDictRow | None: ... # type: ignore[override]
def fetchmany(self, size: int | None = None) -> list[RealDictRow]: ... # type: ignore[override]
def fetchall(self) -> list[RealDictRow]: ... # type: ignore[override]
def __next__(self) -> RealDictRow: ... # type: ignore[override]
class RealDictRow(OrderedDict[Any, Any]):
def __init__(self, *args, **kwargs) -> None: ...
def __setitem__(self, key, value) -> None: ...
class NamedTupleConnection(_connection):
@overload
def cursor(
self, name: str | bytes | None = ..., *, withhold: bool = ..., scrollable: bool | None = ...
) -> NamedTupleCursor: ...
@overload
def cursor(
self,
name: str | bytes | None = ...,
*,
cursor_factory: Callable[..., _T_cur],
withhold: bool = ...,
scrollable: bool | None = ...,
) -> _T_cur: ...
@overload
def cursor(
self, name: str | bytes | None, cursor_factory: Callable[..., _T_cur], withhold: bool = ..., scrollable: bool | None = ...
) -> _T_cur: ...
class NamedTupleCursor(_cursor):
Record: Any
MAX_CACHE: int
def execute(self, query, vars: Incomplete | None = None): ...
def executemany(self, query, vars): ...
def callproc(self, procname, vars: Incomplete | None = None): ...
def fetchone(self) -> NamedTuple | None: ...
def fetchmany(self, size: int | None = None) -> list[NamedTuple]: ... # type: ignore[override]
def fetchall(self) -> list[NamedTuple]: ... # type: ignore[override]
def __next__(self) -> NamedTuple: ...
class LoggingConnection(_connection):
log: Any
def initialize(self, logobj) -> None: ...
def filter(self, msg, curs): ...
def cursor(self, *args, **kwargs): ...
class LoggingCursor(_cursor):
def execute(self, query, vars: Incomplete | None = None): ...
def callproc(self, procname, vars: Incomplete | None = None): ...
class MinTimeLoggingConnection(LoggingConnection):
def initialize(self, logobj, mintime: int = 0) -> None: ...
def filter(self, msg, curs): ...
def cursor(self, *args, **kwargs): ...
class MinTimeLoggingCursor(LoggingCursor):
timestamp: Any
def execute(self, query, vars: Incomplete | None = None): ...
def callproc(self, procname, vars: Incomplete | None = None): ...
class LogicalReplicationConnection(_replicationConnection):
def __init__(self, *args, **kwargs) -> None: ...
class PhysicalReplicationConnection(_replicationConnection):
def __init__(self, *args, **kwargs) -> None: ...
class StopReplication(Exception): ...
class ReplicationCursor(_replicationCursor):
def create_replication_slot(
self, slot_name, slot_type: Incomplete | None = None, output_plugin: Incomplete | None = None
) -> None: ...
def drop_replication_slot(self, slot_name) -> None: ...
def start_replication(
self,
slot_name: Incomplete | None = None,
slot_type: Incomplete | None = None,
start_lsn: int = 0,
timeline: int = 0,
options: Incomplete | None = None,
decode: bool = False,
status_interval: int = 10,
) -> None: ...
def fileno(self): ...
class UUID_adapter:
def __init__(self, uuid) -> None: ...
def __conform__(self, proto): ...
def getquoted(self): ...
def register_uuid(oids: Incomplete | None = None, conn_or_curs: Incomplete | None = None): ...
class Inet:
addr: Any
def __init__(self, addr) -> None: ...
def prepare(self, conn) -> None: ...
def getquoted(self): ...
def __conform__(self, proto): ...
def register_inet(oid: Incomplete | None = None, conn_or_curs: Incomplete | None = None): ...
def wait_select(conn) -> None: ...
class HstoreAdapter:
wrapped: Any
def __init__(self, wrapped) -> None: ...
conn: Any
getquoted: Any
def prepare(self, conn) -> None: ...
@classmethod
def parse(cls, s, cur, _bsdec=...): ...
@classmethod
def parse_unicode(cls, s, cur): ...
@classmethod
def get_oids(cls, conn_or_curs): ...
def register_hstore(
conn_or_curs,
globally: bool = False,
unicode: bool = False,
oid: Incomplete | None = None,
array_oid: Incomplete | None = None,
) -> None: ...
class CompositeCaster:
name: Any
schema: Any
oid: Any
array_oid: Any
attnames: Any
atttypes: Any
typecaster: Any
array_typecaster: Any
def __init__(self, name, oid, attrs, array_oid: Incomplete | None = None, schema: Incomplete | None = None) -> None: ...
def parse(self, s, curs): ...
def make(self, values): ...
@classmethod
def tokenize(cls, s): ...
def register_composite(name, conn_or_curs, globally: bool = False, factory: Incomplete | None = None): ...
def execute_batch(cur, sql, argslist, page_size: int = 100) -> None: ...
def execute_values(cur, sql, argslist, template: Incomplete | None = None, page_size: int = 100, fetch: bool = False): ...
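
For context, the cursor() overloads above let a cursor_factory narrow the row type. A minimal sketch with DictCursor (placeholder DSN; exactly how this type-checks depends on the connection stub in _psycopg.pyi, which is not shown here):

import psycopg2
from psycopg2.extras import DictCursor

conn = psycopg2.connect("dbname=test")  # placeholder DSN
with conn.cursor(cursor_factory=DictCursor) as cur:
    cur.execute("SELECT 1 AS one")
    row = cur.fetchone()  # DictRow | None under these stubs
    if row is not None:
        print(row["one"])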

24
typings/psycopg2/pool.pyi Normal file
View File

@@ -0,0 +1,24 @@
from _typeshed import Incomplete
from typing import Any
import psycopg2
class PoolError(psycopg2.Error): ...
class AbstractConnectionPool:
minconn: Any
maxconn: Any
closed: bool
def __init__(self, minconn, maxconn, *args, **kwargs) -> None: ...
# getconn, putconn and closeall are officially documented as methods of the
# abstract base class, but in reality, they only exist on the children classes
def getconn(self, key: Incomplete | None = ...): ...
def putconn(self, conn: Any, key: Incomplete | None = ..., close: bool = ...) -> None: ...
def closeall(self) -> None: ...
class SimpleConnectionPool(AbstractConnectionPool): ...
class ThreadedConnectionPool(AbstractConnectionPool):
# This subclass has a default value for conn which doesn't exist
# in the SimpleConnectionPool class, nor in the documentation
def putconn(self, conn: Incomplete | None = None, key: Incomplete | None = None, close: bool = False) -> None: ...
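
As the comments above note, getconn/putconn are called on the concrete pool classes. A minimal sketch (placeholder DSN):

from psycopg2.pool import ThreadedConnectionPool

pool = ThreadedConnectionPool(minconn=1, maxconn=4, dsn="dbname=test")  # placeholder DSN
conn = pool.getconn()
try:
    with conn.cursor() as cur:
        cur.execute("SELECT 1")
finally:
    pool.putconn(conn)
pool.closeall()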

50
typings/psycopg2/sql.pyi Normal file
View File

@@ -0,0 +1,50 @@
from _typeshed import Incomplete
from collections.abc import Iterator
from typing import Any
class Composable:
def __init__(self, wrapped) -> None: ...
def as_string(self, context) -> str: ...
def __add__(self, other) -> Composed: ...
def __mul__(self, n) -> Composed: ...
def __eq__(self, other) -> bool: ...
def __ne__(self, other) -> bool: ...
class Composed(Composable):
def __init__(self, seq) -> None: ...
@property
def seq(self) -> list[Composable]: ...
def as_string(self, context) -> str: ...
def __iter__(self) -> Iterator[Composable]: ...
def __add__(self, other) -> Composed: ...
def join(self, joiner) -> Composed: ...
class SQL(Composable):
def __init__(self, string) -> None: ...
@property
def string(self) -> str: ...
def as_string(self, context) -> str: ...
def format(self, *args, **kwargs) -> Composed: ...
def join(self, seq) -> Composed: ...
class Identifier(Composable):
def __init__(self, *strings) -> None: ...
@property
def strings(self) -> tuple[str, ...]: ...
@property
def string(self) -> str: ...
def as_string(self, context) -> str: ...
class Literal(Composable):
@property
def wrapped(self): ...
def as_string(self, context) -> str: ...
class Placeholder(Composable):
def __init__(self, name: Incomplete | None = None) -> None: ...
@property
def name(self) -> str | None: ...
def as_string(self, context) -> str: ...
NULL: Any
DEFAULT: Any
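
For context, the composition pattern these declarations describe (table and column names are illustrative):

from psycopg2 import sql

query = sql.SQL("SELECT {fields} FROM {table} WHERE id = {id}").format(
    fields=sql.SQL(", ").join([sql.Identifier("id"), sql.Identifier("name")]),
    table=sql.Identifier("users"),
    id=sql.Placeholder(),
)
# query.as_string(conn) renders the statement for a live connection;
# cur.execute(query, (42,)) binds the placeholder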

26
typings/psycopg2/tz.pyi Normal file
View File

@@ -0,0 +1,26 @@
import datetime
from _typeshed import Incomplete
from typing import Any
ZERO: Any
class FixedOffsetTimezone(datetime.tzinfo):
def __init__(self, offset: Incomplete | None = None, name: Incomplete | None = None) -> None: ...
def __new__(cls, offset: Incomplete | None = None, name: Incomplete | None = None): ...
def __eq__(self, other): ...
def __ne__(self, other): ...
def __getinitargs__(self): ...
def utcoffset(self, dt): ...
def tzname(self, dt): ...
def dst(self, dt): ...
STDOFFSET: Any
DSTOFFSET: Any
DSTDIFF: Any
class LocalTimezone(datetime.tzinfo):
def utcoffset(self, dt): ...
def dst(self, dt): ...
def tzname(self, dt): ...
LOCAL: Any
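
FixedOffsetTimezone takes its offset in minutes east of UTC; for instance, the +02:00 of this commit's timestamp:

from datetime import datetime
from psycopg2.tz import FixedOffsetTimezone

cest = FixedOffsetTimezone(offset=120, name="CEST")  # 120 minutes = UTC+02:00
ts = datetime(2023, 5, 9, 9, 38, tzinfo=cest)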

View File

@@ -0,0 +1,2 @@
from .syncobj import FAIL_REASON, SyncObj, SyncObjConf, replicated
__all__ = ['SyncObj', 'SyncObjConf', 'replicated', 'FAIL_REASON']

View File

@@ -0,0 +1,13 @@
from typing import Optional
class FAIL_REASON:
    SUCCESS: int
    QUEUE_FULL: int
    MISSING_LEADER: int
    DISCARDED: int
    NOT_LEADER: int
    LEADER_CHANGED: int
    REQUEST_DENIED: int
class SyncObjConf:
password: Optional[str]
autoTickPeriod: int
def __init__(self, **kwargs) -> None: ...

View File

@@ -0,0 +1,5 @@
from typing import Optional
class DnsCachingResolver:
def setTimeouts(self, cacheTime: float, failCacheTime: float) -> None: ...
def resolve(self, hostname: str) -> Optional[str]: ...
def globalDnsResolver() -> DnsCachingResolver: ...
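
This mirrors pysyncobj's module-level resolver singleton; a sketch, assuming the runtime module path pysyncobj.dns_resolver (the timeout values are illustrative):

from pysyncobj.dns_resolver import globalDnsResolver  # assumed module path

resolver = globalDnsResolver()
resolver.setTimeouts(cacheTime=600.0, failCacheTime=30.0)
ip = resolver.resolve("localhost")  # None when resolution fails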

View File

@@ -0,0 +1,6 @@
class Node:
@property
def id(self) -> str: ...
class TCPNode(Node):
@property
def host(self) -> str: ...

View File

@@ -0,0 +1,24 @@
from typing import Any, Callable, Collection, Optional, Set, Type
from .config import FAIL_REASON, SyncObjConf
from .node import Node
from .transport import Transport
__all__ = ['FAIL_REASON', 'SyncObj', 'SyncObjConf', 'replicated']
class SyncObj:
    def __init__(self, selfNode: Optional[str], otherNodes: Collection[str], conf: SyncObjConf = ..., consumers=..., nodeClass=..., transport=..., transportClass: Type[Transport] = ...) -> None: ...
def destroy(self) -> None: ...
def doTick(self, timeToWait: float = 0.0) -> None: ...
def isNodeConnected(self, node: Node) -> bool: ...
@property
def selfNode(self) -> Node: ...
@property
def otherNodes(self) -> Set[Node]: ...
@property
def raftLastApplied(self) -> int: ...
@property
def raftCommitIndex(self) -> int: ...
@property
def conf(self) -> SyncObjConf: ...
def _getLeader(self) -> Optional[Node]: ...
def _isLeader(self) -> bool: ...
def _onTick(self, timeToWait: float = 0.0) -> None: ...
def replicated(*decArgs: Any, **decKwargs: Any) -> Callable[..., Any]: ...
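
For context, the canonical pysyncobj pattern these declarations describe, a replicated counter (node addresses are placeholders):

from typing import Collection
from pysyncobj import SyncObj, replicated

class Counter(SyncObj):
    def __init__(self, self_addr: str, partners: Collection[str]) -> None:
        super().__init__(self_addr, partners)
        self._value = 0

    @replicated
    def incr(self):
        # Goes through the Raft log, so it is applied on every node in order
        self._value += 1
        return self._value

counter = Counter("127.0.0.1:4321", ["127.0.0.1:4322"])  # placeholder addresses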

View File

@@ -0,0 +1,4 @@
class CONNECTION_STATE:
    DISCONNECTED: int
    CONNECTING: int
    CONNECTED: int

View File

@@ -0,0 +1,10 @@
from typing import Any, Callable, Collection, Optional
from .node import TCPNode
from .syncobj import SyncObj
from .tcp_connection import CONNECTION_STATE
__all__ = ['CONNECTION_STATE', 'TCPTransport']
class Transport:
def setOnUtilityMessageCallback(self, message: str, callback: Callable[[Any, Callable[..., Any]], Any]) -> None: ...
class TCPTransport(Transport):
def __init__(self, syncObj: SyncObj, selfNode: Optional[TCPNode], otherNodes: Collection[TCPNode]) -> None: ...
def _connectIfNecessarySingle(self, node: TCPNode) -> bool: ...

View File

@@ -0,0 +1,6 @@
from typing import Any, List, Optional, Union
from .node import TCPNode
class Utility: ...
class TcpUtility(Utility):
    def __init__(self, password: Optional[str] = None, timeout: float = 900.0) -> None: ...
def executeCommand(self, node: Union[str, TCPNode], command: List[Any]) -> Any: ...
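
TcpUtility is the client used by pysyncobj's syncobj_admin tool; a sketch (the node address and command are illustrative):

from pysyncobj.utility import TcpUtility

util = TcpUtility(password=None, timeout=5.0)
status = util.executeCommand("127.0.0.1:4321", ["status"])  # placeholder node address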

View File

@@ -0,0 +1,6 @@
from .poolmanager import PoolManager
from .response import HTTPResponse
from .util.request import make_headers
from .util.timeout import Timeout
__all__ = ['HTTPResponse', 'PoolManager', 'Timeout', 'make_headers']

View File

@@ -0,0 +1,29 @@
from collections.abc import MutableMapping
from typing import Any, NoReturn
class HTTPHeaderDict(MutableMapping[str, str]):
def __init__(self, headers=None, **kwargs) -> None: ...
def __setitem__(self, key, val) -> None: ...
def __getitem__(self, key): ...
def __delitem__(self, key) -> None: ...
def __contains__(self, key): ...
def __eq__(self, other): ...
def __iter__(self) -> NoReturn: ...
def __len__(self) -> int: ...
def __ne__(self, other): ...
values: Any
get: Any
update: Any
iterkeys: Any
itervalues: Any
def pop(self, key, default=...): ...
def discard(self, key): ...
def add(self, key, val): ...
def extend(self, *args, **kwargs): ...
def getlist(self, key): ...
getheaders: Any
getallmatchingheaders: Any
iget: Any
def copy(self): ...
def iteritems(self): ...
def itermerged(self): ...
def items(self): ...

View File

@@ -0,0 +1,2 @@
from http.client import HTTPConnection as _HTTPConnection
class HTTPConnection(_HTTPConnection): ...

View File

@@ -0,0 +1,9 @@
from typing import Any, Dict, Optional
from .response import HTTPResponse
class PoolManager:
headers: Dict[str, str]
connection_pool_kw: Dict[str, Any]
def __init__(self, num_pools: int = 10, headers: Optional[Dict[str, str]] = None, **connection_pool_kw: Any) -> None: ...
    def urlopen(self, method: str, url: str, body: Optional[Any] = None, headers: Optional[Dict[str, str]] = None, encode_multipart: bool = True, multipart_boundary: Optional[str] = None, **kw: Any) -> HTTPResponse: ...
def request(self, method: str, url: str, fields: Optional[Any] = None, headers: Optional[Dict[str, str]] = None, **urlopen_kw: Any) -> HTTPResponse: ...
def clear(self) -> None: ...

View File

@@ -0,0 +1,14 @@
import io
from typing import Any, Iterator, Optional, Union
from ._collections import HTTPHeaderDict
from .connection import HTTPConnection
class HTTPResponse(io.IOBase):
headers: HTTPHeaderDict
status: int
reason: Optional[str]
def release_conn(self) -> None: ...
@property
def data(self) -> Union[bytes, Any]: ...
@property
def connection(self) -> Optional[HTTPConnection]: ...
def read_chunked(self, amt: Optional[int] = None, decode_content: Optional[bool] = None) -> Iterator[bytes]: ...

View File

@@ -0,0 +1,9 @@
from typing import Optional, Union, Dict, List
def make_headers(
keep_alive: Optional[bool] = None,
accept_encoding: Union[bool, List[str], str, None] = None,
user_agent: Optional[str] = None,
basic_auth: Optional[str] = None,
proxy_basic_auth: Optional[str] = None,
disable_cache: Optional[bool] = None,
) -> Dict[str, str]: ...
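
Tying the urllib3 stubs together, a minimal typed request (the URL is a placeholder):

from urllib3 import PoolManager, Timeout, make_headers

http = PoolManager(num_pools=2, headers=make_headers(keep_alive=True))
resp = http.request("GET", "https://example.com/", timeout=Timeout(connect=2.0, read=5.0))
print(resp.status, len(resp.data))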

Some files were not shown because too many files have changed in this diff