Enable pyright strict mode (#2652)
- added pyrightconfig.json with typeCheckingMode=strict
- added type hints to all files except api.py
- added type stubs for dns, etcd, consul, kazoo, pysyncobj and other modules
- added type stubs for psycopg2 and urllib3 with a few small fixes
- fixes most of the issues reported by pyright
- remaining issues will be addressed later, along with enabling the CI linting task
Commit 76b3b99de2 (parent 1ac9b11f33), committed via GitHub.
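The diff below does not include the new pyrightconfig.json itself; beyond the typeCheckingMode=strict setting named in the commit message, its exact contents are not shown here. A rough sketch of a minimal strict-mode configuration looks like the following (the include, exclude and stubPath values are illustrative assumptions, not taken from the commit):

{
  "typeCheckingMode": "strict",
  "include": ["patroni"],
  "exclude": ["**/__pycache__"],
  "stubPath": "typings"
}

With such a file at the repository root, running pyright picks the settings up automatically; bundled stubs for third-party modules (for example psycopg2 or urllib3, as mentioned in the commit message) would live under the configured stub path.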
@@ -1,17 +1,19 @@
 import sys
 
+from typing import Any, Callable, Iterator, Tuple
+
 PATRONI_ENV_PREFIX = 'PATRONI_'
 KUBERNETES_ENV_PREFIX = 'KUBERNETES_'
 MIN_PSYCOPG2 = (2, 5, 4)
 
 
-def fatal(string, *args):
+def fatal(string: str, *args: Any) -> None:
     sys.stderr.write('FATAL: ' + string.format(*args) + '\n')
     sys.exit(1)
 
 
-def parse_version(version):
-    def _parse_version(version):
+def parse_version(version: str) -> Tuple[int, ...]:
+    def _parse_version(version: str) -> Iterator[int]:
         for e in version.split('.'):
             try:
                 yield int(e)
@@ -21,9 +23,11 @@ def parse_version(version):
 
 # We pass MIN_PSYCOPG2 and parse_version as arguments to simplify usage of check_psycopg from the setup.py
-def check_psycopg(_min_psycopg2=MIN_PSYCOPG2, _parse_version=parse_version):
+def check_psycopg(_min_psycopg2: Tuple[int, ...] = MIN_PSYCOPG2,
+                  _parse_version: Callable[[str], Tuple[int, ...]] = parse_version) -> None:
     min_psycopg2_str = '.'.join(map(str, _min_psycopg2))
 
     # try psycopg2
     try:
         from psycopg2 import __version__
         if _parse_version(__version__) >= _min_psycopg2:
@@ -32,10 +36,11 @@ def check_psycopg(_min_psycopg2=MIN_PSYCOPG2, _parse_version=parse_version):
     except ImportError:
         version_str = None
 
     # try psycopg3
     try:
         from psycopg import __version__
     except ImportError:
         error = 'Patroni requires psycopg2>={0}, psycopg2-binary, or psycopg>=3.0'.format(min_psycopg2_str)
-        if version_str:
+        if version_str is not None:
             error += ', but only psycopg2=={0} is available'.format(version_str)
         fatal(error)
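The hunks above show the core pattern of the commit: signatures gain annotations while behaviour stays the same. As a small, self-contained illustration of what strict checking buys here, the snippet below reuses the same signature as the helper above (the body is a simplified stand-in, not code from the commit):

from typing import Tuple

MIN_PSYCOPG2 = (2, 5, 4)

def parse_version(version: str) -> Tuple[int, ...]:
    # simplified stand-in for the Patroni helper shown above
    return tuple(int(e) for e in version.split('.') if e.isdigit())

print(parse_version('2.9.6') >= MIN_PSYCOPG2)  # True
# parse_version('2.9.6') >= '2.5.4'  -> rejected by pyright strict: tuple and str are not comparable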
@@ -3,14 +3,19 @@ import os
|
||||
import signal
|
||||
import time
|
||||
|
||||
from typing import Any, Dict, Optional, TYPE_CHECKING
|
||||
|
||||
from patroni.daemon import AbstractPatroniDaemon, abstract_main
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from .config import Config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Patroni(AbstractPatroniDaemon):
|
||||
|
||||
def __init__(self, config):
|
||||
def __init__(self, config: 'Config') -> None:
|
||||
from patroni.api import RestApiServer
|
||||
from patroni.dcs import get_dcs
|
||||
from patroni.ha import Ha
|
||||
@@ -33,9 +38,9 @@ class Patroni(AbstractPatroniDaemon):
|
||||
|
||||
self.tags = self.get_tags()
|
||||
self.next_run = time.time()
|
||||
self.scheduled_restart = {}
|
||||
self.scheduled_restart: Dict[str, Any] = {}
|
||||
|
||||
def load_dynamic_configuration(self):
|
||||
def load_dynamic_configuration(self) -> None:
|
||||
from patroni.exceptions import DCSError
|
||||
while True:
|
||||
try:
|
||||
@@ -53,19 +58,19 @@ class Patroni(AbstractPatroniDaemon):
|
||||
logger.warning('Can not get cluster from dcs')
|
||||
time.sleep(5)
|
||||
|
||||
def get_tags(self):
|
||||
def get_tags(self) -> Dict[str, Any]:
|
||||
return {tag: value for tag, value in self.config.get('tags', {}).items()
|
||||
if tag not in ('clonefrom', 'nofailover', 'noloadbalance', 'nosync') or value}
|
||||
|
||||
@property
|
||||
def nofailover(self):
|
||||
def nofailover(self) -> bool:
|
||||
return bool(self.tags.get('nofailover', False))
|
||||
|
||||
@property
|
||||
def nosync(self):
|
||||
def nosync(self) -> bool:
|
||||
return bool(self.tags.get('nosync', False))
|
||||
|
||||
def reload_config(self, sighup=False, local=False):
|
||||
def reload_config(self, sighup: bool = False, local: Optional[bool] = False) -> None:
|
||||
try:
|
||||
super(Patroni, self).reload_config(sighup, local)
|
||||
if local:
|
||||
@@ -87,7 +92,7 @@ class Patroni(AbstractPatroniDaemon):
|
||||
def noloadbalance(self):
|
||||
return bool(self.tags.get('noloadbalance', False))
|
||||
|
||||
def schedule_next_run(self):
|
||||
def schedule_next_run(self) -> None:
|
||||
self.next_run += self.dcs.loop_wait
|
||||
current_time = time.time()
|
||||
nap_time = self.next_run - current_time
|
||||
@@ -100,12 +105,12 @@ class Patroni(AbstractPatroniDaemon):
|
||||
elif self.ha.watch(nap_time):
|
||||
self.next_run = time.time()
|
||||
|
||||
def run(self):
|
||||
def run(self) -> None:
|
||||
self.api.start()
|
||||
self.next_run = time.time()
|
||||
super(Patroni, self).run()
|
||||
|
||||
def _run_cycle(self):
|
||||
def _run_cycle(self) -> None:
|
||||
logger.info(self.ha.run_cycle())
|
||||
|
||||
if self.dcs.cluster and self.dcs.cluster.config and self.dcs.cluster.config.data \
|
||||
@@ -117,7 +122,7 @@ class Patroni(AbstractPatroniDaemon):
|
||||
|
||||
self.schedule_next_run()
|
||||
|
||||
def _shutdown(self):
|
||||
def _shutdown(self) -> None:
|
||||
try:
|
||||
self.api.shutdown()
|
||||
except Exception:
|
||||
@@ -128,7 +133,7 @@ class Patroni(AbstractPatroniDaemon):
|
||||
logger.exception('Exception during Ha.shutdown')
|
||||
|
||||
|
||||
def patroni_main():
|
||||
def patroni_main() -> None:
|
||||
from multiprocessing import freeze_support
|
||||
from patroni.validator import schema
|
||||
|
||||
@@ -136,18 +141,20 @@ def patroni_main():
|
||||
abstract_main(Patroni, schema)
|
||||
|
||||
|
||||
def main():
|
||||
if os.getpid() != 1:
|
||||
from patroni import check_psycopg
|
||||
def main() -> None:
|
||||
from patroni import check_psycopg
|
||||
|
||||
check_psycopg()
|
||||
check_psycopg()
|
||||
|
||||
if os.getpid() != 1:
|
||||
return patroni_main()
|
||||
|
||||
# Patroni started with PID=1, it looks like we are in the container
|
||||
from types import FrameType
|
||||
pid = 0
|
||||
|
||||
# Looks like we are in a docker, so we will act like init
|
||||
def sigchld_handler(signo, stack_frame):
|
||||
def sigchld_handler(signo: int, stack_frame: Optional[FrameType]) -> None:
|
||||
try:
|
||||
while True:
|
||||
ret = os.waitpid(-1, os.WNOHANG)
|
||||
@@ -158,7 +165,7 @@ def main():
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def passtochild(signo, stack_frame):
|
||||
def passtochild(signo: int, stack_frame: Optional[FrameType]):
|
||||
if pid:
|
||||
os.kill(pid, signo)
|
||||
|
||||
|
||||
@@ -1,5 +1,10 @@
|
||||
import logging
|
||||
|
||||
from threading import Event, Lock, RLock, Thread
|
||||
from types import TracebackType
|
||||
from typing import Any, Callable, Optional, Tuple, Type
|
||||
|
||||
from .postgresql.cancellable import CancellableSubprocess
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -15,19 +20,19 @@ class CriticalTask(object):
|
||||
call cancel. If the task has completed `cancel()` will return False and `result` field will contain the result of
|
||||
the task. When cancel returns True it is guaranteed that the background task will notice the `is_cancelled` flag.
|
||||
"""
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self._lock = Lock()
|
||||
self.is_cancelled = False
|
||||
self.result = None
|
||||
|
||||
def reset(self):
|
||||
def reset(self) -> None:
|
||||
"""Must be called every time the background task is finished.
|
||||
|
||||
Must be called from async thread. Caller must hold lock on async executor when calling."""
|
||||
self.is_cancelled = False
|
||||
self.result = None
|
||||
|
||||
def cancel(self):
|
||||
def cancel(self) -> bool:
|
||||
"""Tries to cancel the task, returns True if the task has already run.
|
||||
|
||||
Caller must hold lock on async executor and the task when calling."""
|
||||
@@ -36,37 +41,38 @@ class CriticalTask(object):
|
||||
self.is_cancelled = True
|
||||
return True
|
||||
|
||||
def complete(self, result):
|
||||
def complete(self, result: Any) -> None:
|
||||
"""Mark task as completed along with a result.
|
||||
|
||||
Must be called from async thread. Caller must hold lock on task when calling."""
|
||||
self.result = result
|
||||
|
||||
def __enter__(self):
|
||||
def __enter__(self) -> 'CriticalTask':
|
||||
self._lock.acquire()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
def __exit__(self, exc_type: Optional[Type[BaseException]],
|
||||
exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]) -> None:
|
||||
self._lock.release()
|
||||
|
||||
|
||||
class AsyncExecutor(object):
|
||||
|
||||
def __init__(self, cancellable, ha_wakeup):
|
||||
def __init__(self, cancellable: CancellableSubprocess, ha_wakeup: Callable[..., None]) -> None:
|
||||
self._cancellable = cancellable
|
||||
self._ha_wakeup = ha_wakeup
|
||||
self._thread_lock = RLock()
|
||||
self._scheduled_action = None
|
||||
self._scheduled_action: Optional[str] = None
|
||||
self._scheduled_action_lock = RLock()
|
||||
self._is_cancelled = False
|
||||
self._finish_event = Event()
|
||||
self.critical_task = CriticalTask()
|
||||
|
||||
@property
|
||||
def busy(self):
|
||||
def busy(self) -> bool:
|
||||
return self.scheduled_action is not None
|
||||
|
||||
def schedule(self, action):
|
||||
def schedule(self, action: str) -> Optional[str]:
|
||||
with self._scheduled_action_lock:
|
||||
if self._scheduled_action is not None:
|
||||
return self._scheduled_action
|
||||
@@ -76,15 +82,15 @@ class AsyncExecutor(object):
|
||||
return None
|
||||
|
||||
@property
|
||||
def scheduled_action(self):
|
||||
def scheduled_action(self) -> Optional[str]:
|
||||
with self._scheduled_action_lock:
|
||||
return self._scheduled_action
|
||||
|
||||
def reset_scheduled_action(self):
|
||||
def reset_scheduled_action(self) -> None:
|
||||
with self._scheduled_action_lock:
|
||||
self._scheduled_action = None
|
||||
|
||||
def run(self, func, args=()):
|
||||
def run(self, func: Callable[..., Any], args: Tuple[Any, ...] = ()) -> Optional[bool]:
|
||||
wakeup = False
|
||||
try:
|
||||
with self:
|
||||
@@ -107,16 +113,16 @@ class AsyncExecutor(object):
|
||||
if wakeup is not None:
|
||||
self._ha_wakeup()
|
||||
|
||||
def run_async(self, func, args=()):
|
||||
def run_async(self, func: Callable[..., Any], args: Tuple[Any, ...] = ()) -> None:
|
||||
Thread(target=self.run, args=(func, args)).start()
|
||||
|
||||
def try_run_async(self, action, func, args=()):
|
||||
def try_run_async(self, action: str, func: Callable[..., Any], args: Tuple[Any, ...] = ()) -> Optional[str]:
|
||||
prev = self.schedule(action)
|
||||
if prev is None:
|
||||
return self.run_async(func, args)
|
||||
return 'Failed to run {0}, {1} is already in progress'.format(action, prev)
|
||||
|
||||
def cancel(self):
|
||||
def cancel(self) -> None:
|
||||
with self:
|
||||
with self._scheduled_action_lock:
|
||||
if self._scheduled_action is None:
|
||||
@@ -130,8 +136,10 @@ class AsyncExecutor(object):
|
||||
with self:
|
||||
self.reset_scheduled_action()
|
||||
|
||||
def __enter__(self):
|
||||
def __enter__(self) -> 'AsyncExecutor':
|
||||
self._thread_lock.acquire()
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
def __exit__(self, exc_type: Optional[Type[BaseException]],
|
||||
exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]) -> None:
|
||||
self._thread_lock.release()
|
||||
|
||||
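In the async executor changes above, the loosely typed __enter__/__exit__ methods on CriticalTask and AsyncExecutor are replaced with fully annotated ones. The general pattern, shown standalone below, is what pyright strict expects from a context manager (the class name is hypothetical, not from the commit):

from types import TracebackType
from typing import Optional, Type

class Guarded:
    def __enter__(self) -> 'Guarded':
        # acquire a lock or resource here
        return self

    def __exit__(self, exc_type: Optional[Type[BaseException]],
                 exc_val: Optional[BaseException],
                 exc_tb: Optional[TracebackType]) -> None:
        # release the resource; returning None lets exceptions propagate
        pass

with Guarded():
    pass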
@@ -1,16 +1,15 @@
|
||||
from collections import OrderedDict
|
||||
from collections.abc import MutableMapping, MutableSet
|
||||
from typing import Any, Collection, Dict, Iterable, Iterator, Optional, Tuple, Union
|
||||
from typing import Any, Collection, Dict, Iterator, MutableMapping, MutableSet, Optional
|
||||
|
||||
|
||||
class CaseInsensitiveSet(MutableSet):
|
||||
class CaseInsensitiveSet(MutableSet[str]):
|
||||
"""A case-insensitive ``set``-like object.
|
||||
|
||||
Implements all methods and operations of :class:``MutableSet``. All values are expected to be strings.
|
||||
The structure remembers the case of the last value set, however, contains testing is case insensitive.
|
||||
"""
|
||||
def __init__(self, values: Optional[Collection[str]] = None) -> None:
|
||||
self._values = {}
|
||||
self._values: Dict[str, str] = {}
|
||||
for v in values or ():
|
||||
self.add(v)
|
||||
|
||||
@@ -39,7 +38,7 @@ class CaseInsensitiveSet(MutableSet):
|
||||
return self <= other
|
||||
|
||||
|
||||
class CaseInsensitiveDict(MutableMapping):
|
||||
class CaseInsensitiveDict(MutableMapping[str, Any]):
|
||||
"""A case-insensitive ``dict``-like object.
|
||||
|
||||
Implements all methods and operations of :class:``MutableMapping`` as well as dict's :func:``copy``.
|
||||
@@ -47,8 +46,8 @@ class CaseInsensitiveDict(MutableMapping):
|
||||
and ``iter(instance)``, ``keys()``, ``items()``, ``iterkeys()``, and ``iteritems()`` will contain
|
||||
case-sensitive keys. However, querying and contains testing is case insensitive.
|
||||
"""
|
||||
def __init__(self, data: Optional[Union[Dict[str, Any], Iterable[Tuple[str, Any]]]] = None) -> None:
|
||||
self._values = OrderedDict()
|
||||
def __init__(self, data: Optional[Dict[str, Any]] = None) -> None:
|
||||
self._values: OrderedDict[str, Any] = OrderedDict()
|
||||
self.update(data or {})
|
||||
|
||||
def __setitem__(self, key: str, value: Any) -> None:
|
||||
@@ -68,7 +67,7 @@ class CaseInsensitiveDict(MutableMapping):
|
||||
return len(self._values)
|
||||
|
||||
def copy(self) -> 'CaseInsensitiveDict':
|
||||
return CaseInsensitiveDict(self._values.values())
|
||||
return CaseInsensitiveDict({v[0]: v[1] for v in self._values.values()})
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return '<{0}{1} at {2:x}>'.format(type(self).__name__, dict(self.items()), id(self))
|
||||
|
||||
@@ -7,7 +7,7 @@ import yaml
|
||||
|
||||
from collections import defaultdict
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from typing import Any, Callable, Collection, Dict, List, Optional, Union
|
||||
|
||||
from . import PATRONI_ENV_PREFIX
|
||||
from .collections import CaseInsensitiveDict
|
||||
@@ -33,9 +33,10 @@ _AUTH_ALLOWED_PARAMETERS = (
|
||||
)
|
||||
|
||||
|
||||
def default_validator(conf):
|
||||
def default_validator(conf: Dict[str, Any]) -> List[str]:
|
||||
if not conf:
|
||||
raise ConfigParseError("Config is empty.")
|
||||
return []
|
||||
|
||||
|
||||
class GlobalConfig(object):
|
||||
@@ -86,7 +87,7 @@ class GlobalConfig(object):
|
||||
""":returns: `True` if at least one synchronous node is required."""
|
||||
return self.check_mode('synchronous_mode_strict')
|
||||
|
||||
def get_standby_cluster_config(self) -> Any:
|
||||
def get_standby_cluster_config(self) -> Union[Dict[str, Any], Any]:
|
||||
""":returns: "standby_cluster" configuration."""
|
||||
return deepcopy(self.get('standby_cluster'))
|
||||
|
||||
@@ -142,7 +143,7 @@ class GlobalConfig(object):
|
||||
if 'primary_stop_timeout' in self.__config else self.get_int('master_stop_timeout', default)
|
||||
|
||||
|
||||
def get_global_config(cluster: Union[Cluster, None], default: Optional[Dict] = None) -> GlobalConfig:
|
||||
def get_global_config(cluster: Union[Cluster, None], default: Optional[Dict[str, Any]] = None) -> GlobalConfig:
|
||||
"""Instantiates :class:`GlobalConfig` based on the input.
|
||||
|
||||
:param cluster: the currently known cluster state from DCS
|
||||
@@ -180,7 +181,7 @@ class Config(object):
|
||||
PATRONI_CONFIG_VARIABLE = PATRONI_ENV_PREFIX + 'CONFIGURATION'
|
||||
|
||||
__CACHE_FILENAME = 'patroni.dynamic.json'
|
||||
__DEFAULT_CONFIG = {
|
||||
__DEFAULT_CONFIG: Dict[str, Any] = {
|
||||
'ttl': 30, 'loop_wait': 10, 'retry_timeout': 10,
|
||||
'standby_cluster': {
|
||||
'create_replica_methods': '',
|
||||
@@ -199,19 +200,21 @@ class Config(object):
|
||||
}
|
||||
}
|
||||
|
||||
def __init__(self, configfile, validator=default_validator):
|
||||
def __init__(self, configfile: str,
|
||||
validator: Optional[Callable[[Dict[str, Any]], List[str]]] = default_validator) -> None:
|
||||
self._modify_index = -1
|
||||
self._dynamic_configuration = {}
|
||||
|
||||
self.__environment_configuration = self._build_environment_configuration()
|
||||
|
||||
# Patroni reads the configuration from the command-line argument if it exists, otherwise from the environment
|
||||
self._config_file = configfile and os.path.exists(configfile) and configfile
|
||||
self._config_file = configfile if configfile and os.path.exists(configfile) else None
|
||||
if self._config_file:
|
||||
self._local_configuration = self._load_config_file()
|
||||
else:
|
||||
config_env = os.environ.pop(self.PATRONI_CONFIG_VARIABLE, None)
|
||||
self._local_configuration = config_env and yaml.safe_load(config_env) or self.__environment_configuration
|
||||
|
||||
if validator:
|
||||
errors = validator(self._local_configuration)
|
||||
if errors:
|
||||
@@ -224,14 +227,14 @@ class Config(object):
|
||||
self._cache_needs_saving = False
|
||||
|
||||
@property
|
||||
def config_file(self):
|
||||
def config_file(self) -> Union[str, None]:
|
||||
return self._config_file
|
||||
|
||||
@property
|
||||
def dynamic_configuration(self):
|
||||
def dynamic_configuration(self) -> Dict[str, Any]:
|
||||
return deepcopy(self._dynamic_configuration)
|
||||
|
||||
def _load_config_path(self, path):
|
||||
def _load_config_path(self, path: str) -> Dict[str, Any]:
|
||||
"""
|
||||
If path is a file, loads the yml file pointed to by path.
|
||||
If path is a directory, loads all yml files in that directory in alphabetical order
|
||||
@@ -245,20 +248,21 @@ class Config(object):
|
||||
logger.error('config path %s is neither directory nor file', path)
|
||||
raise ConfigParseError('invalid config path')
|
||||
|
||||
overall_config = {}
|
||||
overall_config: Dict[str, Any] = {}
|
||||
for fname in files:
|
||||
with open(fname) as f:
|
||||
config = yaml.safe_load(f)
|
||||
patch_config(overall_config, config)
|
||||
return overall_config
|
||||
|
||||
def _load_config_file(self):
|
||||
def _load_config_file(self) -> Dict[str, Any]:
|
||||
"""Loads config.yaml from filesystem and applies some values which were set via ENV"""
|
||||
assert self._config_file is not None
|
||||
config = self._load_config_path(self._config_file)
|
||||
patch_config(config, self.__environment_configuration)
|
||||
return config
|
||||
|
||||
def _load_cache(self):
|
||||
def _load_cache(self) -> None:
|
||||
if os.path.isfile(self._cache_file):
|
||||
try:
|
||||
with open(self._cache_file) as f:
|
||||
@@ -266,7 +270,7 @@ class Config(object):
|
||||
except Exception:
|
||||
logger.exception('Exception when loading file: %s', self._cache_file)
|
||||
|
||||
def save_cache(self):
|
||||
def save_cache(self) -> None:
|
||||
if self._cache_needs_saving:
|
||||
tmpfile = fd = None
|
||||
try:
|
||||
@@ -290,7 +294,7 @@ class Config(object):
|
||||
logger.error('Can not remove temporary file %s', tmpfile)
|
||||
|
||||
# configuration could be either ClusterConfig or dict
|
||||
def set_dynamic_configuration(self, configuration):
|
||||
def set_dynamic_configuration(self, configuration: Union[ClusterConfig, Dict[str, Any]]) -> bool:
|
||||
if isinstance(configuration, ClusterConfig):
|
||||
if self._modify_index == configuration.modify_index:
|
||||
return False # If the index didn't changed there is nothing to do
|
||||
@@ -306,8 +310,9 @@ class Config(object):
|
||||
return True
|
||||
except Exception:
|
||||
logger.exception('Exception when setting dynamic_configuration')
|
||||
return False
|
||||
|
||||
def reload_local_configuration(self):
|
||||
def reload_local_configuration(self) -> Optional[bool]:
|
||||
if self.config_file:
|
||||
try:
|
||||
configuration = self._load_config_file()
|
||||
@@ -322,12 +327,12 @@ class Config(object):
|
||||
logger.exception('Exception when reloading local configuration from %s', self.config_file)
|
||||
|
||||
@staticmethod
|
||||
def _process_postgresql_parameters(parameters, is_local=False):
|
||||
def _process_postgresql_parameters(parameters: Dict[str, Any], is_local: bool = False) -> Dict[str, Any]:
|
||||
return {name: value for name, value in (parameters or {}).items()
|
||||
if name not in ConfigHandler.CMDLINE_OPTIONS
|
||||
or not is_local and ConfigHandler.CMDLINE_OPTIONS[name][1](value)}
|
||||
|
||||
def _safe_copy_dynamic_configuration(self, dynamic_configuration):
|
||||
def _safe_copy_dynamic_configuration(self, dynamic_configuration: Dict[str, Any]) -> Dict[str, Any]:
|
||||
config = deepcopy(self.__DEFAULT_CONFIG)
|
||||
|
||||
for name, value in dynamic_configuration.items():
|
||||
@@ -347,10 +352,10 @@ class Config(object):
|
||||
return config
|
||||
|
||||
@staticmethod
|
||||
def _build_environment_configuration():
|
||||
ret = defaultdict(dict)
|
||||
def _build_environment_configuration() -> Dict[str, Any]:
|
||||
ret: Dict[str, Any] = defaultdict(dict)
|
||||
|
||||
def _popenv(name):
|
||||
def _popenv(name: str) -> Union[str, None]:
|
||||
return os.environ.pop(PATRONI_ENV_PREFIX + name.upper(), None)
|
||||
|
||||
for param in ('name', 'namespace', 'scope'):
|
||||
@@ -358,7 +363,7 @@ class Config(object):
|
||||
if value:
|
||||
ret[param] = value
|
||||
|
||||
def _fix_log_env(name, oldname):
|
||||
def _fix_log_env(name: str, oldname: str) -> None:
|
||||
value = _popenv(oldname)
|
||||
name = PATRONI_ENV_PREFIX + 'LOG_' + name.upper()
|
||||
if value and name not in os.environ:
|
||||
@@ -367,7 +372,7 @@ class Config(object):
|
||||
for name, oldname in (('level', 'loglevel'), ('format', 'logformat'), ('dateformat', 'log_datefmt')):
|
||||
_fix_log_env(name, oldname)
|
||||
|
||||
def _set_section_values(section, params):
|
||||
def _set_section_values(section: str, params: List[str]) -> None:
|
||||
for param in params:
|
||||
value = _popenv(section + '_' + param)
|
||||
if value:
|
||||
@@ -400,7 +405,7 @@ class Config(object):
|
||||
if value is not None:
|
||||
ret[first][second] = value
|
||||
|
||||
def _parse_list(value):
|
||||
def _parse_list(value: str) -> Union[List[str], None]:
|
||||
if not (value.strip().startswith('-') or '[' in value):
|
||||
value = '[{0}]'.format(value)
|
||||
try:
|
||||
@@ -416,7 +421,7 @@ class Config(object):
|
||||
if value:
|
||||
ret[first][second] = value
|
||||
|
||||
def _parse_dict(value):
|
||||
def _parse_dict(value: str) -> Union[Dict[str, Any], None]:
|
||||
if not value.strip().startswith('{'):
|
||||
value = '{{{0}}}'.format(value)
|
||||
try:
|
||||
@@ -433,8 +438,8 @@ class Config(object):
|
||||
if value:
|
||||
ret[first][second] = value
|
||||
|
||||
def _get_auth(name, params=None):
|
||||
ret = {}
|
||||
def _get_auth(name: str, params: Optional[Collection[str]] = None) -> Dict[str, str]:
|
||||
ret: Dict[str, str] = {}
|
||||
for param in params or _AUTH_ALLOWED_PARAMETERS[:2]:
|
||||
value = _popenv(name + '_' + param)
|
||||
if value:
|
||||
@@ -503,7 +508,8 @@ class Config(object):
|
||||
|
||||
return ret
|
||||
|
||||
def _build_effective_configuration(self, dynamic_configuration, local_configuration):
|
||||
def _build_effective_configuration(self, dynamic_configuration: Dict[str, Any],
|
||||
local_configuration: Dict[str, Union[Dict[str, Any], Any]]) -> Dict[str, Any]:
|
||||
config = self._safe_copy_dynamic_configuration(dynamic_configuration)
|
||||
for name, value in local_configuration.items():
|
||||
if name == 'citus': # remove invalid citus configuration
|
||||
@@ -565,16 +571,16 @@ class Config(object):
|
||||
|
||||
return config
|
||||
|
||||
def get(self, key, default=None):
|
||||
def get(self, key: str, default: Optional[Any] = None) -> Any:
|
||||
return self.__effective_configuration.get(key, default)
|
||||
|
||||
def __contains__(self, key):
|
||||
def __contains__(self, key: str) -> bool:
|
||||
return key in self.__effective_configuration
|
||||
|
||||
def __getitem__(self, key):
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
return self.__effective_configuration[key]
|
||||
|
||||
def copy(self):
|
||||
def copy(self) -> Dict[str, Any]:
|
||||
return deepcopy(self.__effective_configuration)
|
||||
|
||||
def get_global_config(self, cluster: Union[Cluster, None]) -> GlobalConfig:
|
||||
|
||||
305
patroni/ctl.py
@@ -18,23 +18,25 @@ import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import urllib3
|
||||
import time
|
||||
import yaml
|
||||
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
from click import ClickException
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
from prettytable import ALL, FRAME, PrettyTable
|
||||
from urllib.parse import urlparse
|
||||
from typing import Any, Dict, Generator, Iterator, List, Optional, Union, Tuple, TYPE_CHECKING
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from psycopg import Cursor
|
||||
from psycopg2 import cursor
|
||||
|
||||
try:
|
||||
from ydiff import markup_to_pager, PatchStream
|
||||
except ImportError: # pragma: no cover
|
||||
from cdiff import markup_to_pager, PatchStream
|
||||
|
||||
from .dcs import get_dcs as _get_dcs
|
||||
from .dcs import get_dcs as _get_dcs, AbstractDCS, Cluster, Member
|
||||
from .exceptions import PatroniException
|
||||
from .postgresql.misc import postgres_version_to_int
|
||||
from .utils import cluster_as_json, patch_config, polling_loop
|
||||
@@ -43,30 +45,31 @@ from .version import __version__
|
||||
|
||||
CONFIG_DIR_PATH = click.get_app_dir('patroni')
|
||||
CONFIG_FILE_PATH = os.path.join(CONFIG_DIR_PATH, 'patronictl.yaml')
|
||||
DCS_DEFAULTS = {'zookeeper': {'port': 2181, 'template': "zookeeper:\n hosts: ['{host}:{port}']"},
|
||||
'exhibitor': {'port': 8181, 'template': "exhibitor:\n hosts: [{host}]\n port: {port}"},
|
||||
'consul': {'port': 8500, 'template': "consul:\n host: '{host}:{port}'"},
|
||||
'etcd': {'port': 2379, 'template': "etcd:\n host: '{host}:{port}'"},
|
||||
'etcd3': {'port': 2379, 'template': "etcd3:\n host: '{host}:{port}'"}}
|
||||
DCS_DEFAULTS: Dict[str, Dict[str, Any]] = {
|
||||
'zookeeper': {'port': 2181, 'template': "zookeeper:\n hosts: ['{host}:{port}']"},
|
||||
'exhibitor': {'port': 8181, 'template': "exhibitor:\n hosts: [{host}]\n port: {port}"},
|
||||
'consul': {'port': 8500, 'template': "consul:\n host: '{host}:{port}'"},
|
||||
'etcd': {'port': 2379, 'template': "etcd:\n host: '{host}:{port}'"},
|
||||
'etcd3': {'port': 2379, 'template': "etcd3:\n host: '{host}:{port}'"}}
|
||||
|
||||
|
||||
class PatroniCtlException(ClickException):
|
||||
class PatroniCtlException(click.ClickException):
|
||||
pass
|
||||
|
||||
|
||||
class PatronictlPrettyTable(PrettyTable):
|
||||
|
||||
def __init__(self, header, *args, **kwargs):
|
||||
PrettyTable.__init__(self, *args, **kwargs)
|
||||
def __init__(self, header: str, *args: Any, **kwargs: Any) -> None:
|
||||
super(PatronictlPrettyTable, self).__init__(*args, **kwargs)
|
||||
self.__table_header = header
|
||||
self.__hline_num = 0
|
||||
self.__hline = None
|
||||
self.__hline: str
|
||||
|
||||
def __build_header(self, line):
|
||||
def __build_header(self, line: str) -> str:
|
||||
header = self.__table_header[:len(line) - 2]
|
||||
return "".join([line[0], header, line[1 + len(header):]])
|
||||
|
||||
def _stringify_hrule(self, *args, **kwargs):
|
||||
def _stringify_hrule(self, *args: Any, **kwargs: Any) -> str:
|
||||
ret = super(PatronictlPrettyTable, self)._stringify_hrule(*args, **kwargs)
|
||||
where = args[1] if len(args) > 1 else kwargs.get('where')
|
||||
if where == 'top_' and self.__table_header:
|
||||
@@ -74,13 +77,13 @@ class PatronictlPrettyTable(PrettyTable):
|
||||
self.__hline_num += 1
|
||||
return ret
|
||||
|
||||
def _is_first_hline(self):
|
||||
def _is_first_hline(self) -> bool:
|
||||
return self.__hline_num == 0
|
||||
|
||||
def _set_hline(self, value):
|
||||
def _set_hline(self, value: str) -> None:
|
||||
self.__hline = value
|
||||
|
||||
def _get_hline(self):
|
||||
def _get_hline(self) -> str:
|
||||
ret = self.__hline
|
||||
|
||||
# Inject nice table header
|
||||
@@ -93,7 +96,7 @@ class PatronictlPrettyTable(PrettyTable):
|
||||
_hrule = property(_get_hline, _set_hline)
|
||||
|
||||
|
||||
def parse_dcs(dcs: str) -> Union[Dict[str, Any], None]:
|
||||
def parse_dcs(dcs: Optional[str]) -> Optional[Dict[str, Any]]:
|
||||
"""Parse a DCS URL.
|
||||
|
||||
:param dcs: the DCS URL in the format ``DCS://HOST:PORT``. ``DCS`` can be one among
|
||||
@@ -143,7 +146,7 @@ def parse_dcs(dcs: str) -> Union[Dict[str, Any], None]:
|
||||
return yaml.safe_load(default['template'].format(host=parsed.hostname or 'localhost', port=port or default['port']))
|
||||
|
||||
|
||||
def load_config(path, dcs_url):
|
||||
def load_config(path: str, dcs_url: Optional[str]) -> Dict[str, Any]:
|
||||
from patroni.config import Config
|
||||
|
||||
if not (os.path.exists(path) and os.access(path, os.R_OK)):
|
||||
@@ -156,11 +159,11 @@ def load_config(path, dcs_url):
|
||||
logging.debug('Loading configuration from file %s', path)
|
||||
config = Config(path, validator=None).copy()
|
||||
|
||||
dcs_url = parse_dcs(dcs_url) or {}
|
||||
if dcs_url:
|
||||
dcs_kwargs = parse_dcs(dcs_url) or {}
|
||||
if dcs_kwargs:
|
||||
for d in DCS_DEFAULTS:
|
||||
config.pop(d, None)
|
||||
config.update(dcs_url)
|
||||
config.update(dcs_kwargs)
|
||||
return config
|
||||
|
||||
|
||||
@@ -183,7 +186,7 @@ role_choice = click.Choice(['leader', 'primary', 'standby-leader', 'replica', 's
|
||||
@click.option('--dcs-url', '--dcs', '-d', 'dcs_url', help='The DCS connect url', envvar='DCS_URL')
|
||||
@option_insecure
|
||||
@click.pass_context
|
||||
def ctl(ctx, config_file, dcs_url, insecure):
|
||||
def ctl(ctx: click.Context, config_file: str, dcs_url: Optional[str], insecure: bool) -> None:
|
||||
level = 'WARNING'
|
||||
for name in ('LOGLEVEL', 'PATRONI_LOGLEVEL', 'PATRONI_LOG_LEVEL'):
|
||||
level = os.environ.get(name, level)
|
||||
@@ -194,7 +197,7 @@ def ctl(ctx, config_file, dcs_url, insecure):
|
||||
ctx.obj.setdefault('ctl', {})['insecure'] = ctx.obj.get('ctl', {}).get('insecure') or insecure
|
||||
|
||||
|
||||
def get_dcs(config, scope, group):
|
||||
def get_dcs(config: Dict[str, Any], scope: str, group: Optional[int]) -> AbstractDCS:
|
||||
config.update({'scope': scope, 'patronictl': True})
|
||||
if group is not None:
|
||||
config['citus'] = {'group': group}
|
||||
@@ -202,13 +205,14 @@ def get_dcs(config, scope, group):
|
||||
try:
|
||||
dcs = _get_dcs(config)
|
||||
if config.get('citus') and group is None:
|
||||
dcs.get_cluster = dcs._get_citus_cluster
|
||||
dcs.is_citus_coordinator = lambda: True
|
||||
return dcs
|
||||
except PatroniException as e:
|
||||
raise PatroniCtlException(str(e))
|
||||
|
||||
|
||||
def request_patroni(member, method='GET', endpoint=None, data=None):
|
||||
def request_patroni(member: Member, method: str = 'GET',
|
||||
endpoint: Optional[str] = None, data: Optional[Any] = None) -> urllib3.response.HTTPResponse:
|
||||
ctx = click.get_current_context() # the current click context
|
||||
request_executor = ctx.obj.get('__request_patroni')
|
||||
if not request_executor:
|
||||
@@ -216,19 +220,20 @@ def request_patroni(member, method='GET', endpoint=None, data=None):
|
||||
return request_executor(member, method, endpoint, data)
|
||||
|
||||
|
||||
def print_output(columns, rows, alignment=None, fmt='pretty', header=None, delimiter='\t'):
|
||||
def print_output(columns: Optional[List[str]], rows: List[List[Any]], alignment: Optional[Dict[str, str]] = None,
|
||||
fmt: str = 'pretty', header: str = '', delimiter: str = '\t') -> None:
|
||||
if fmt in {'json', 'yaml', 'yml'}:
|
||||
elements = [{k: v for k, v in zip(columns, r) if not header or str(v)} for r in rows]
|
||||
elements = [{k: v for k, v in zip(columns or [], r) if not header or str(v)} for r in rows]
|
||||
func = json.dumps if fmt == 'json' else format_config_for_editing
|
||||
click.echo(func(elements))
|
||||
elif fmt in {'pretty', 'tsv', 'topology'}:
|
||||
list_cluster = bool(header and columns and columns[0] == 'Cluster')
|
||||
if list_cluster and 'Tags' in columns: # we want to format member tags as YAML
|
||||
if list_cluster and columns and 'Tags' in columns: # we want to format member tags as YAML
|
||||
i = columns.index('Tags')
|
||||
for row in rows:
|
||||
if row[i]:
|
||||
row[i] = format_config_for_editing(row[i], fmt != 'pretty').strip()
|
||||
if list_cluster and fmt != 'tsv': # skip cluster name and maybe Citus group if pretty-printing
|
||||
if list_cluster and header and fmt != 'tsv': # skip cluster name and maybe Citus group if pretty-printing
|
||||
skip_cols = 2 if ' (group: ' in header else 1
|
||||
columns = columns[skip_cols:] if columns else []
|
||||
rows = [row[skip_cols:] for row in rows]
|
||||
@@ -247,7 +252,7 @@ def print_output(columns, rows, alignment=None, fmt='pretty', header=None, delim
|
||||
click.echo(table)
|
||||
|
||||
|
||||
def watching(w, watch, max_count=None, clear=True):
|
||||
def watching(w: bool, watch: Optional[int], max_count: Optional[int] = None, clear: bool = True) -> Iterator[int]:
|
||||
"""
|
||||
>>> len(list(watching(True, 1, 0)))
|
||||
1
|
||||
@@ -275,7 +280,8 @@ def watching(w, watch, max_count=None, clear=True):
|
||||
yield 0
|
||||
|
||||
|
||||
def get_all_members(obj, cluster, group, role='leader'):
|
||||
def get_all_members(obj: Dict[str, Any], cluster: Cluster,
|
||||
group: Optional[int], role: str = 'leader') -> Iterator[Member]:
|
||||
clusters = {0: cluster}
|
||||
if obj.get('citus') and group is None:
|
||||
clusters.update(cluster.workers)
|
||||
@@ -296,23 +302,25 @@ def get_all_members(obj, cluster, group, role='leader'):
|
||||
yield m
|
||||
|
||||
|
||||
def get_any_member(obj, cluster, group, role='leader', member=None):
|
||||
def get_any_member(obj: Dict[str, Any], cluster: Cluster, group: Optional[int],
|
||||
role: str = 'leader', member: Optional[str] = None) -> Optional[Member]:
|
||||
for m in get_all_members(obj, cluster, group, role):
|
||||
if member is None or m.name == member:
|
||||
return m
|
||||
|
||||
|
||||
def get_all_members_leader_first(cluster):
|
||||
def get_all_members_leader_first(cluster: Cluster) -> Iterator[Member]:
|
||||
leader_name = cluster.leader.member.name if cluster.leader and cluster.leader.member.api_url else None
|
||||
if leader_name:
|
||||
if leader_name and cluster.leader:
|
||||
yield cluster.leader.member
|
||||
for member in cluster.members:
|
||||
if member.api_url and member.name != leader_name:
|
||||
yield member
|
||||
|
||||
|
||||
def get_cursor(obj, cluster, group, connect_parameters, role='leader', member=None):
|
||||
member = get_any_member(obj, cluster, group, role=role, member=member)
|
||||
def get_cursor(obj: Dict[str, Any], cluster: Cluster, group: Optional[int], connect_parameters: Dict[str, Any],
|
||||
role: str = 'leader', member_name: Optional[str] = None) -> Union['cursor', 'Cursor[Any]', None]:
|
||||
member = get_any_member(obj, cluster, group, role=role, member=member_name)
|
||||
if member is None:
|
||||
return None
|
||||
|
||||
@@ -330,7 +338,8 @@ def get_cursor(obj, cluster, group, connect_parameters, role='leader', member=No
|
||||
return cursor
|
||||
|
||||
cursor.execute('SELECT pg_catalog.pg_is_in_recovery()')
|
||||
in_recovery = cursor.fetchone()[0]
|
||||
row = cursor.fetchone()
|
||||
in_recovery = not row or row[0]
|
||||
|
||||
if in_recovery and role in ('replica', 'standby', 'standby-leader')\
|
||||
or not in_recovery and role in ('master', 'primary'):
|
||||
@@ -341,7 +350,8 @@ def get_cursor(obj, cluster, group, connect_parameters, role='leader', member=No
|
||||
return None
|
||||
|
||||
|
||||
def get_members(obj, cluster, cluster_name, member_names, role, force, action, ask_confirmation=True, group=None):
|
||||
def get_members(obj: Dict[str, Any], cluster: Cluster, cluster_name: str, member_names: List[str], role: str,
|
||||
force: bool, action: str, ask_confirmation: bool = True, group: Optional[int] = None) -> List[Member]:
|
||||
members = list(get_all_members(obj, cluster, group, role))
|
||||
|
||||
candidates = {m.name for m in members}
|
||||
@@ -371,7 +381,8 @@ def get_members(obj, cluster, cluster_name, member_names, role, force, action, a
|
||||
return members
|
||||
|
||||
|
||||
def confirm_members_action(members, force, action, scheduled_at=None):
|
||||
def confirm_members_action(members: List[Member], force: bool, action: str,
|
||||
scheduled_at: Optional[datetime.datetime] = None) -> None:
|
||||
if scheduled_at:
|
||||
if not force:
|
||||
confirm = click.confirm('Are you sure you want to schedule {0} of members {1} at {2}?'
|
||||
@@ -392,12 +403,13 @@ def confirm_members_action(members, force, action, scheduled_at=None):
|
||||
@arg_cluster_name
|
||||
@option_citus_group
|
||||
@click.pass_obj
|
||||
def dsn(obj, cluster_name, group, role, member):
|
||||
def dsn(obj: Dict[str, Any], cluster_name: str, group: Optional[int],
|
||||
role: Optional[str], member: Optional[str]) -> None:
|
||||
if member is not None:
|
||||
if role is not None:
|
||||
raise PatroniCtlException('--role and --member are mutually exclusive options')
|
||||
role = 'any'
|
||||
if member is None and role is None:
|
||||
elif role is None:
|
||||
role = 'leader'
|
||||
|
||||
cluster = get_dcs(obj, cluster_name, group).get_cluster()
|
||||
@@ -425,33 +437,36 @@ def dsn(obj, cluster_name, group, role, member):
|
||||
@click.option('-d', '--dbname', help='database name to connect to', type=str)
|
||||
@click.pass_obj
|
||||
def query(
|
||||
obj,
|
||||
cluster_name,
|
||||
group,
|
||||
role,
|
||||
member,
|
||||
w,
|
||||
watch,
|
||||
delimiter,
|
||||
command,
|
||||
p_file,
|
||||
password,
|
||||
username,
|
||||
dbname,
|
||||
fmt='tsv',
|
||||
):
|
||||
obj: Dict[str, Any],
|
||||
cluster_name: str,
|
||||
group: Optional[int],
|
||||
role: Optional[str],
|
||||
member: Optional[str],
|
||||
w: bool,
|
||||
watch: Optional[int],
|
||||
delimiter: str,
|
||||
command: Optional[str],
|
||||
p_file: Optional[io.BufferedReader],
|
||||
password: Optional[bool],
|
||||
username: Optional[str],
|
||||
dbname: Optional[str],
|
||||
fmt: str = 'tsv'
|
||||
) -> None:
|
||||
if member is not None:
|
||||
if role is not None:
|
||||
raise PatroniCtlException('--role and --member are mutually exclusive options')
|
||||
role = 'any'
|
||||
if member is None and role is None:
|
||||
elif role is None:
|
||||
role = 'leader'
|
||||
|
||||
if p_file is not None and command is not None:
|
||||
raise PatroniCtlException('--file and --command are mutually exclusive options')
|
||||
|
||||
if p_file is None and command is None:
|
||||
raise PatroniCtlException('You need to specify either --command or --file')
|
||||
if p_file is not None:
|
||||
if command is not None:
|
||||
raise PatroniCtlException('--file and --command are mutually exclusive options')
|
||||
sql = p_file.read().decode('utf-8')
|
||||
else:
|
||||
if command is None:
|
||||
raise PatroniCtlException('You need to specify either --command or --file')
|
||||
sql = command
|
||||
|
||||
connect_parameters = {}
|
||||
if username:
|
||||
@@ -461,25 +476,25 @@ def query(
|
||||
if dbname:
|
||||
connect_parameters['dbname'] = dbname
|
||||
|
||||
if p_file is not None:
|
||||
command = p_file.read()
|
||||
|
||||
dcs = get_dcs(obj, cluster_name, group)
|
||||
|
||||
cursor = None
|
||||
cluster = cursor = None
|
||||
for _ in watching(w, watch, clear=False):
|
||||
if cursor is None:
|
||||
if cluster is None:
|
||||
cluster = dcs.get_cluster()
|
||||
# cursor = get_cursor(obj, cluster, group, connect_parameters, role=role, member=member)
|
||||
|
||||
output, header = query_member(obj, cluster, group, cursor, member, role, command, connect_parameters)
|
||||
output, header = query_member(obj, cluster, group, cursor, member, role, sql, connect_parameters)
|
||||
print_output(header, output, fmt=fmt, delimiter=delimiter)
|
||||
|
||||
|
||||
def query_member(obj, cluster, group, cursor, member, role, command, connect_parameters):
|
||||
def query_member(obj: Dict[str, Any], cluster: Cluster, group: Optional[int],
|
||||
cursor: Union['cursor', 'Cursor[Any]', None], member: Optional[str], role: str,
|
||||
command: str, connect_parameters: Dict[str, Any]) -> Tuple[List[List[Any]], Optional[List[Any]]]:
|
||||
from . import psycopg
|
||||
try:
|
||||
if cursor is None:
|
||||
cursor = get_cursor(obj, cluster, group, connect_parameters, role=role, member=member)
|
||||
cursor = get_cursor(obj, cluster, group, connect_parameters, role=role, member_name=member)
|
||||
|
||||
if cursor is None:
|
||||
if member is not None:
|
||||
@@ -489,8 +504,8 @@ def query_member(obj, cluster, group, cursor, member, role, command, connect_par
|
||||
logging.debug(message)
|
||||
return [[timestamp(0), message]], None
|
||||
|
||||
cursor.execute(command)
|
||||
return cursor.fetchall(), [d.name for d in cursor.description]
|
||||
cursor.execute(command.encode('utf-8'))
|
||||
return [list(row) for row in cursor], cursor.description and [d.name for d in cursor.description]
|
||||
except psycopg.DatabaseError as de:
|
||||
logging.debug(de)
|
||||
if cursor is not None and not cursor.connection.closed:
|
||||
@@ -505,7 +520,7 @@ def query_member(obj, cluster, group, cursor, member, role, command, connect_par
|
||||
@option_citus_group
|
||||
@option_format
|
||||
@click.pass_obj
|
||||
def remove(obj, cluster_name, group, fmt):
|
||||
def remove(obj: Dict[str, Any], cluster_name: str, group: Optional[int], fmt: str) -> None:
|
||||
dcs = get_dcs(obj, cluster_name, group)
|
||||
cluster = dcs.get_cluster()
|
||||
|
||||
@@ -532,7 +547,8 @@ def remove(obj, cluster_name, group, fmt):
|
||||
dcs.delete_cluster()
|
||||
|
||||
|
||||
def check_response(response, member_name, action_name, silent_success=False):
|
||||
def check_response(response: urllib3.response.HTTPResponse, member_name: str,
|
||||
action_name: str, silent_success: bool = False) -> bool:
|
||||
if response.status >= 400:
|
||||
click.echo('Failed: {0} for member {1}, status code={2}, ({3})'.format(
|
||||
action_name, member_name, response.status, response.data.decode('utf-8')
|
||||
@@ -543,8 +559,8 @@ def check_response(response, member_name, action_name, silent_success=False):
|
||||
return True
|
||||
|
||||
|
||||
def parse_scheduled(scheduled):
|
||||
if (scheduled or 'now') != 'now':
|
||||
def parse_scheduled(scheduled: Optional[str]) -> Optional[datetime.datetime]:
|
||||
if scheduled is not None and (scheduled or 'now') != 'now':
|
||||
try:
|
||||
scheduled_at = dateutil.parser.parse(scheduled)
|
||||
if scheduled_at.tzinfo is None:
|
||||
@@ -564,7 +580,8 @@ def parse_scheduled(scheduled):
|
||||
@click.option('--role', '-r', help='Reload only members with this role', type=role_choice, default='any')
|
||||
@option_force
|
||||
@click.pass_obj
|
||||
def reload(obj, cluster_name, member_names, group, force, role):
|
||||
def reload(obj: Dict[str, Any], cluster_name: str, member_names: List[str],
|
||||
group: Optional[int], force: bool, role: str) -> None:
|
||||
dcs = get_dcs(obj, cluster_name, group)
|
||||
cluster = dcs.get_cluster()
|
||||
|
||||
@@ -575,8 +592,10 @@ def reload(obj, cluster_name, member_names, group, force, role):
|
||||
if r.status == 200:
|
||||
click.echo('No changes to apply on member {0}'.format(member.name))
|
||||
elif r.status == 202:
|
||||
from patroni.config import get_global_config
|
||||
config = get_global_config(cluster)
|
||||
click.echo('Reload request received for member {0} and will be processed within {1} seconds'.format(
|
||||
member.name, cluster.config.data.get('loop_wait', dcs.loop_wait))
|
||||
member.name, config.get('loop_wait') or dcs.loop_wait)
|
||||
)
|
||||
else:
|
||||
click.echo('Failed: reload for member {0}, status code={1}, ({2})'.format(
|
||||
@@ -595,11 +614,12 @@ def reload(obj, cluster_name, member_names, group, force, role):
|
||||
@click.option('--pg-version', 'version', help='Restart if the PostgreSQL version is less than provided (e.g. 9.5.2)',
|
||||
default=None)
|
||||
@click.option('--pending', help='Restart if pending', is_flag=True)
|
||||
@click.option('--timeout',
|
||||
help='Return error and fail over if necessary when restarting takes longer than this.')
|
||||
@click.option('--timeout', help='Return error and fail over if necessary when restarting takes longer than this.')
|
||||
@option_force
|
||||
@click.pass_obj
|
||||
def restart(obj, cluster_name, group, member_names, force, role, p_any, scheduled, version, pending, timeout):
|
||||
def restart(obj: Dict[str, Any], cluster_name: str, group: Optional[int], member_names: List[str],
|
||||
force: bool, role: str, p_any: bool, scheduled: Optional[str], version: Optional[str],
|
||||
pending: bool, timeout: Optional[str]) -> None:
|
||||
cluster = get_dcs(obj, cluster_name, group).get_cluster()
|
||||
|
||||
members = get_members(obj, cluster, cluster_name, member_names, role, force, 'restart', False, group=group)
|
||||
@@ -666,13 +686,14 @@ def restart(obj, cluster_name, group, member_names, force, role, p_any, schedule
|
||||
@option_force
|
||||
@click.option('--wait', help='Wait until reinitialization completes', is_flag=True)
|
||||
@click.pass_obj
|
||||
def reinit(obj, cluster_name, group, member_names, force, wait):
|
||||
def reinit(obj: Dict[str, Any], cluster_name: str, group: Optional[int],
|
||||
member_names: List[str], force: bool, wait: bool) -> None:
|
||||
cluster = get_dcs(obj, cluster_name, group).get_cluster()
|
||||
members = get_members(obj, cluster, cluster_name, member_names, 'replica', force, 'reinitialize', group=group)
|
||||
|
||||
wait_on_members = []
|
||||
wait_on_members: List[Member] = []
|
||||
for member in members:
|
||||
body = {'force': force}
|
||||
body: Dict[str, bool] = {'force': force}
|
||||
while True:
|
||||
r = request_patroni(member, 'post', 'reinitialize', body)
|
||||
started = check_response(r, member.name, 'reinitialize')
|
||||
@@ -699,7 +720,9 @@ def reinit(obj, cluster_name, group, member_names, force, wait):
|
||||
wait_on_members.remove(member)
|
||||
|
||||
|
||||
def _do_failover_or_switchover(obj, action, cluster_name, group, leader, candidate, force, scheduled=None):
|
||||
def _do_failover_or_switchover(obj: Dict[str, Any], action: str, cluster_name: str,
|
||||
group: Optional[int], leader: Optional[str], candidate: Optional[str],
|
||||
force: bool, scheduled: Optional[str] = None) -> None:
|
||||
"""
|
||||
We want to trigger a failover or switchover for the specified cluster name.
|
||||
|
||||
@@ -729,7 +752,7 @@ def _do_failover_or_switchover(obj, action, cluster_name, group, leader, candida
|
||||
else:
|
||||
from patroni.config import get_global_config
|
||||
prompt = 'Standby Leader' if get_global_config(cluster).is_standby_cluster else 'Primary'
|
||||
leader = click.prompt(prompt, type=str, default=cluster.leader.member.name)
|
||||
leader = click.prompt(prompt, type=str, default=(cluster.leader and cluster.leader.member.name))
|
||||
|
||||
if leader is not None and cluster.leader and cluster.leader.member.name != leader:
|
||||
raise PatroniCtlException('Member {0} is not the leader of cluster {1}'.format(leader, cluster_name))
|
||||
@@ -788,8 +811,8 @@ def _do_failover_or_switchover(obj, action, cluster_name, group, leader, candida
|
||||
|
||||
r = None
|
||||
try:
|
||||
member = cluster.leader.member if cluster.leader else cluster.get_member(candidate, False)
|
||||
|
||||
member = cluster.leader.member if cluster.leader else candidate and cluster.get_member(candidate, False)
|
||||
assert isinstance(member, Member)
|
||||
r = request_patroni(member, 'post', action, failover_value)
|
||||
|
||||
# probably old patroni, which doesn't support switchover yet
|
||||
@@ -820,7 +843,8 @@ def _do_failover_or_switchover(obj, action, cluster_name, group, leader, candida
|
||||
@click.option('--candidate', help='The name of the candidate', default=None)
|
||||
@option_force
|
||||
@click.pass_obj
|
||||
def failover(obj, cluster_name, group, leader, candidate, force):
|
||||
def failover(obj: Dict[str, Any], cluster_name: str, group: Optional[int],
|
||||
leader: Optional[str], candidate: Optional[str], force: bool) -> None:
|
||||
action = 'switchover' if leader else 'failover'
|
||||
_do_failover_or_switchover(obj, action, cluster_name, group, leader, candidate, force)
|
||||
|
||||
@@ -834,11 +858,13 @@ def failover(obj, cluster_name, group, leader, candidate, force):
|
||||
default=None)
|
||||
@option_force
|
||||
@click.pass_obj
|
||||
def switchover(obj, cluster_name, group, leader, candidate, force, scheduled):
|
||||
def switchover(obj: Dict[str, Any], cluster_name: str, group: Optional[int],
|
||||
leader: Optional[str], candidate: Optional[str], force: bool, scheduled: Optional[str]) -> None:
|
||||
_do_failover_or_switchover(obj, 'switchover', cluster_name, group, leader, candidate, force, scheduled)
|
||||
|
||||
|
||||
def generate_topology(level, member, topology):
|
||||
def generate_topology(level: int, member: Dict[str, Any],
|
||||
topology: Dict[str, List[Dict[str, Any]]]) -> Iterator[Dict[str, Any]]:
|
||||
members = topology.get(member['name'], [])
|
||||
|
||||
if level > 0:
|
||||
@@ -852,8 +878,8 @@ def generate_topology(level, member, topology):
|
||||
yield member
|
||||
|
||||
|
||||
def topology_sort(members):
|
||||
topology = defaultdict(list)
|
||||
def topology_sort(members: List[Dict[str, Any]]) -> Iterator[Dict[str, Any]]:
|
||||
topology: Dict[str, List[Dict[str, Any]]] = defaultdict(list)
|
||||
leader = next((m for m in members if m['role'].endswith('leader')), {'name': None})
|
||||
replicas = set(member['name'] for member in members if not member['role'].endswith('leader'))
|
||||
for member in members:
|
||||
@@ -865,8 +891,8 @@ def topology_sort(members):
|
||||
yield member
|
||||
|
||||
|
||||
def get_cluster_service_info(cluster):
|
||||
service_info = []
|
||||
def get_cluster_service_info(cluster: Dict[str, Any]) -> List[str]:
|
||||
service_info: List[str] = []
|
||||
if cluster.get('pause'):
|
||||
service_info.append('Maintenance mode: on')
|
||||
|
||||
@@ -879,8 +905,9 @@ def get_cluster_service_info(cluster):
|
||||
return service_info
|
||||
|
||||
|
||||
def output_members(obj, cluster, name, extended=False, fmt='pretty', group=None):
|
||||
rows = []
|
||||
def output_members(obj: Dict[str, Any], cluster: Cluster, name: str,
|
||||
extended: bool = False, fmt: str = 'pretty', group: Optional[int] = None) -> None:
|
||||
rows: List[List[Any]] = []
|
||||
logging.debug(cluster)
|
||||
|
||||
initialize = {None: 'uninitialized', '': 'initializing'}.get(cluster.initialize, cluster.initialize)
|
||||
@@ -905,12 +932,12 @@ def output_members(obj, cluster, name, extended=False, fmt='pretty', group=None)
|
||||
len(set(m['host'] for m in all_members)) < len(all_members)
|
||||
|
||||
sort = topology_sort if fmt == 'topology' else iter
|
||||
for g, cluster in sorted(clusters.items()):
|
||||
for member in sort(cluster['members']):
|
||||
for g, c in sorted(clusters.items()):
|
||||
for member in sort(c['members']):
|
||||
logging.debug(member)
|
||||
|
||||
lag = member.get('lag', '')
|
||||
member.update(cluster=name, member=member['name'], group=g,
|
||||
member.update(c=name, member=member['name'], group=g,
|
||||
host=member.get('host', ''), tl=member.get('timeline', ''),
|
||||
role=member['role'].replace('_', ' ').title(),
|
||||
lag_in_mb=round(lag / 1024 / 1024) if isinstance(lag, int) else lag,
|
||||
@@ -936,8 +963,8 @@ def output_members(obj, cluster, name, extended=False, fmt='pretty', group=None)
|
||||
if fmt not in ('pretty', 'topology'): # Omit service info when using machine-readable formats
|
||||
return
|
||||
|
||||
for g, cluster in sorted(clusters.items()):
|
||||
service_info = get_cluster_service_info(cluster)
|
||||
for g, c in sorted(clusters.items()):
|
||||
service_info = get_cluster_service_info(c)
|
||||
if service_info:
|
||||
if is_citus_cluster and group is None:
|
||||
click.echo('Citus group: {0}'.format(g))
|
||||
@@ -953,7 +980,8 @@ def output_members(obj, cluster, name, extended=False, fmt='pretty', group=None)
|
||||
@option_watch
|
||||
@option_watchrefresh
|
||||
@click.pass_obj
|
||||
def members(obj, cluster_names, group, fmt, watch, w, extended, ts):
|
||||
def members(obj: Dict[str, Any], cluster_names: List[str], group: Optional[int],
|
||||
fmt: str, watch: Optional[int], w: bool, extended: bool, ts: bool) -> None:
|
||||
if not cluster_names:
|
||||
if 'scope' in obj:
|
||||
cluster_names = [obj['scope']]
|
||||
@@ -976,13 +1004,12 @@ def members(obj, cluster_names, group, fmt, watch, w, extended, ts):
|
||||
@option_citus_group
|
||||
@option_watch
|
||||
@option_watchrefresh
|
||||
@click.pass_obj
|
||||
@click.pass_context
|
||||
def topology(ctx, obj, cluster_names, group, watch, w):
|
||||
def topology(ctx: click.Context, cluster_names: List[str], group: Optional[int], watch: Optional[int], w: bool) -> None:
|
||||
ctx.forward(members, fmt='topology')
|
||||
|
||||
|
||||
def timestamp(precision=6):
|
||||
def timestamp(precision: int = 6) -> str:
|
||||
return datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')[:precision - 7]
|
||||
|
||||
|
||||
@@ -994,7 +1021,8 @@ def timestamp(precision=6):
|
||||
@click.option('--role', '-r', help='Flush only members with this role', type=role_choice, default='any')
|
||||
@option_force
|
||||
@click.pass_obj
|
||||
def flush(obj, cluster_name, group, member_names, force, role, target):
|
||||
def flush(obj: Dict[str, Any], cluster_name: str, group: Optional[int],
|
||||
member_names: List[str], force: bool, role: str, target: str) -> None:
|
||||
dcs = get_dcs(obj, cluster_name, group)
|
||||
cluster = dcs.get_cluster()
|
||||
|
||||
@@ -1015,28 +1043,34 @@ def flush(obj, cluster_name, group, member_names, force, role, target):
|
||||
if r.status in (200, 404):
|
||||
prefix = 'Success' if r.status == 200 else 'Failed'
|
||||
return click.echo('{0}: {1}'.format(prefix, r.data.decode('utf-8')))
|
||||
|
||||
click.echo('Failed: member={0}, status_code={1}, ({2})'.format(
|
||||
member.name, r.status, r.data.decode('utf-8')))
|
||||
except Exception as err:
|
||||
logging.warning(str(err))
|
||||
logging.warning('Member %s is not accessible', member.name)
|
||||
|
||||
click.echo('Failed: member={0}, status_code={1}, ({2})'.format(
|
||||
member.name, r.status, r.data.decode('utf-8')))
|
||||
|
||||
logging.warning('Failing over to DCS')
|
||||
click.echo('{0} Could not find any accessible member of cluster {1}'.format(timestamp(), cluster_name))
|
||||
dcs.manual_failover('', '', index=failover.index)
|
||||
|
||||
|
||||
def wait_until_pause_is_applied(dcs, paused, old_cluster):
|
||||
def wait_until_pause_is_applied(dcs: AbstractDCS, paused: bool, old_cluster: Cluster) -> None:
|
||||
from patroni.config import get_global_config
|
||||
config = get_global_config(old_cluster)
|
||||
|
||||
click.echo("'{0}' request sent, waiting until it is recognized by all nodes".format(paused and 'pause' or 'resume'))
|
||||
old = {m.name: m.index for m in old_cluster.members if m.api_url}
|
||||
loop_wait = old_cluster.config.data.get('loop_wait', dcs.loop_wait)
|
||||
loop_wait = config.get('loop_wait') or dcs.loop_wait
|
||||
|
||||
cluster = None
|
||||
for _ in polling_loop(loop_wait + 1):
|
||||
cluster = dcs.get_cluster()
|
||||
if all(m.data.get('pause', False) == paused for m in cluster.members if m.name in old):
|
||||
break
|
||||
else:
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
assert cluster is not None
|
||||
remaining = [m.name for m in cluster.members if m.data.get('pause', False) != paused
|
||||
and m.name in old and old[m.name] != m.index]
|
||||
if remaining:
|
||||
@@ -1045,7 +1079,7 @@ def wait_until_pause_is_applied(dcs, paused, old_cluster):
|
||||
return click.echo('Success: cluster management is {0}'.format(paused and 'paused' or 'resumed'))
|
||||
|
||||
|
||||
def toggle_pause(config, cluster_name, group, paused, wait):
def toggle_pause(config: Dict[str, Any], cluster_name: str, group: Optional[int], paused: bool, wait: bool) -> None:
from patroni.config import get_global_config
dcs = get_dcs(config, cluster_name, group)
cluster = dcs.get_cluster()
@@ -1078,7 +1112,7 @@ def toggle_pause(config, cluster_name, group, paused, wait):
@option_default_citus_group
@click.pass_obj
@click.option('--wait', help='Wait until pause is applied on all nodes', is_flag=True)
def pause(obj, cluster_name, group, wait):
def pause(obj: Dict[str, Any], cluster_name: str, group: Optional[int], wait: bool) -> None:
return toggle_pause(obj, cluster_name, group, True, wait)

@@ -1087,12 +1121,12 @@ def pause(obj, cluster_name, group, wait):
@option_default_citus_group
@click.option('--wait', help='Wait until pause is cleared on all nodes', is_flag=True)
@click.pass_obj
def resume(obj, cluster_name, group, wait):
def resume(obj: Dict[str, Any], cluster_name: str, group: Optional[int], wait: bool) -> None:
return toggle_pause(obj, cluster_name, group, False, wait)

@contextmanager
def temporary_file(contents, suffix='', prefix='tmp'):
def temporary_file(contents: bytes, suffix: str = '', prefix: str = 'tmp') -> Generator[str, None, None]:
"""Creates a temporary file with specified contents that persists for the context.

:param contents: binary string that will be written to the file.
@@ -1110,12 +1144,12 @@ def temporary_file(contents, suffix='', prefix='tmp'):
os.unlink(tmp.name)

def show_diff(before_editing, after_editing):
def show_diff(before_editing: str, after_editing: str) -> None:
"""Shows a diff between two strings.

If the output is to a tty the diff will be colored. Inputs are expected to be unicode strings.
"""
def listify(string):
def listify(string: str) -> List[str]:
return [line + '\n' for line in string.rstrip('\n').split('\n')]

unified_diff = difflib.unified_diff(listify(before_editing), listify(after_editing))
@@ -1159,7 +1193,7 @@ def show_diff(before_editing, after_editing):
click.echo(line.rstrip('\n'))

def format_config_for_editing(data, default_flow_style=False):
def format_config_for_editing(data: Any, default_flow_style: bool = False) -> str:
"""Formats configuration as YAML for human consumption.

:param data: configuration as nested dictionaries
@@ -1167,7 +1201,7 @@ def format_config_for_editing(data, default_flow_style=False):
return yaml.safe_dump(data, default_flow_style=default_flow_style, encoding=None, allow_unicode=True, width=200)

def apply_config_changes(before_editing, data, kvpairs):
def apply_config_changes(before_editing: str, data: Dict[str, Any], kvpairs: List[str]) -> Tuple[str, Dict[str, Any]]:
"""Applies config changes specified as a list of key-value pairs.

Keys are interpreted as dotted paths into the configuration data structure. Except for paths beginning with
@@ -1181,7 +1215,7 @@ def apply_config_changes(before_editing, data, kvpairs):
"""
changed_data = copy.deepcopy(data)

def set_path_value(config, path, value, prefix=()):
def set_path_value(config: Dict[str, Any], path: List[str], value: Any, prefix: Tuple[str, ...] = ()):
# Postgresql GUCs can't be nested, but can contain dots so we re-flatten the structure for this case
if prefix == ('postgresql', 'parameters'):
path = ['.'.join(path)]
@@ -1208,7 +1242,7 @@ def apply_config_changes(before_editing, data, kvpairs):
return format_config_for_editing(changed_data), changed_data

def apply_yaml_file(data, filename):
def apply_yaml_file(data: Dict[str, Any], filename: str) -> Tuple[str, Dict[str, Any]]:
"""Applies changes from a YAML file to configuration

:param data: configuration datastructure
@@ -1228,7 +1262,7 @@ def apply_yaml_file(data, filename):
return format_config_for_editing(changed_data), changed_data
def invoke_editor(before_editing, cluster_name):
def invoke_editor(before_editing: str, cluster_name: str) -> Tuple[str, Dict[str, Any]]:
"""Starts editor command to edit configuration in human readable format

:param before_editing: human representation before editing
@@ -1272,10 +1306,15 @@ def invoke_editor(before_editing, cluster_name):
' Use - for stdin.')
@option_force
@click.pass_obj
def edit_config(obj, cluster_name, group, force, quiet, kvpairs, pgkvpairs, apply_filename, replace_filename):
def edit_config(obj: Dict[str, Any], cluster_name: str, group: Optional[int],
force: bool, quiet: bool, kvpairs: List[str], pgkvpairs: List[str],
apply_filename: Optional[str], replace_filename: Optional[str]) -> None:
dcs = get_dcs(obj, cluster_name, group)
cluster = dcs.get_cluster()

if not cluster.config:
raise PatroniCtlException('The config key does not exist in the cluster {0}'.format(cluster_name))

before_editing = format_config_for_editing(cluster.config.data)

after_editing = None # Serves as a flag if any changes were requested
@@ -1317,10 +1356,10 @@ def edit_config(obj, cluster_name, group, force, quiet, kvpairs, pgkvpairs, appl
@arg_cluster_name
@option_default_citus_group
@click.pass_obj
def show_config(obj, cluster_name, group):
def show_config(obj: Dict[str, Any], cluster_name: str, group: Optional[int]) -> None:
cluster = get_dcs(obj, cluster_name, group).get_cluster()

click.echo(format_config_for_editing(cluster.config.data))
if cluster.config:
click.echo(format_config_for_editing(cluster.config.data))

@ctl.command('version', help='Output version of patronictl command or a running Patroni instance')
@@ -1328,7 +1367,7 @@ def show_config(obj, cluster_name, group):
@click.argument('member_names', nargs=-1)
@option_citus_group
@click.pass_obj
def version(obj, cluster_name, group, member_names):
def version(obj: Dict[str, Any], cluster_name: str, group: Optional[int], member_names: List[str]) -> None:
click.echo("patronictl version {0}".format(__version__))

if not cluster_name:
@@ -1355,9 +1394,9 @@ def version(obj, cluster_name, group, member_names):
@option_default_citus_group
@option_format
@click.pass_obj
def history(obj, cluster_name, group, fmt):
def history(obj: Dict[str, Any], cluster_name: str, group: Optional[int], fmt: str) -> None:
cluster = get_dcs(obj, cluster_name, group).get_cluster()
history = cluster.history and cluster.history.lines or []
history: List[List[Any]] = list(map(list, cluster.history and cluster.history.lines or []))
table_header_row = ['TL', 'LSN', 'Reason', 'Timestamp', 'New Leader']
for line in history:
if len(line) < len(table_header_row):
@@ -1367,7 +1406,7 @@ def history(obj, cluster_name, group, fmt):
print_output(table_header_row, history, {'TL': 'r', 'LSN': 'r'}, fmt)

def format_pg_version(version):
def format_pg_version(version: int) -> str:
if version < 100000:
return "{0}.{1}.{2}".format(version // 10000, version // 100 % 100, version % 100)
else:
@@ -11,10 +11,11 @@ import signal
import sys

from threading import Lock
from typing import Any, Optional, Type
from typing import Any, Optional, Type, TYPE_CHECKING

from .config import Config
from .validator import Schema
if TYPE_CHECKING: # pragma: no cover
from .config import Config
from .validator import Schema
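The pattern above (importing `Config` and `Schema` only under `TYPE_CHECKING` and quoting them in annotations, as the hunks below do) avoids module import cycles at runtime while still giving pyright full type information. A minimal, self-contained sketch of the same idiom; the module name is illustrative, not Patroni's:

# --- illustrative sketch (not part of the commit) ---
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Imported for type checking only, never executed at runtime.
    from mypackage.config import Config  # hypothetical module used for illustration


def run(config: 'Config') -> None:
    # The quoted annotation is resolved lazily, so no runtime import is needed.
    print(config)
# --- end sketch ---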
class AbstractPatroniDaemon(abc.ABC):
@@ -30,7 +31,7 @@ class AbstractPatroniDaemon(abc.ABC):
:ivar config: configuration options for this daemon.
"""

def __init__(self, config: Config) -> None:
def __init__(self, config: 'Config') -> None:
"""Set up signal handlers, logging handler and configuration.

:param config: configuration options for this daemon.
@@ -94,7 +95,7 @@ class AbstractPatroniDaemon(abc.ABC):
with self._sigterm_lock:
return self._received_sigterm

def reload_config(self, sighup: Optional[bool] = False, local: Optional[bool] = False) -> None:
def reload_config(self, sighup: bool = False, local: Optional[bool] = False) -> None:
"""Reload configuration.

:param sighup: if it is related to a SIGHUP signal.
@@ -140,7 +141,7 @@ class AbstractPatroniDaemon(abc.ABC):
self.logger.shutdown()

def abstract_main(cls: Type[AbstractPatroniDaemon], validator: Optional[Schema] = None) -> None:
def abstract_main(cls: Type[AbstractPatroniDaemon], validator: Optional['Schema'] = None) -> None:
"""Create the main entry point of a given daemon process.

Expose a basic argument parser, parse the command-line arguments, and run the given daemon process.
File diff suppressed because it is too large
@@ -8,16 +8,19 @@ import ssl
import time
import urllib3

from collections import defaultdict, namedtuple
from collections import defaultdict
from consul import ConsulException, NotFound, base
from http.client import HTTPException
from urllib3.exceptions import HTTPError
from urllib.parse import urlencode, urlparse, quote
from typing import Any, Callable, Dict, List, Mapping, NamedTuple, Optional, Union, Tuple, TYPE_CHECKING

from . import AbstractDCS, Cluster, ClusterConfig, Failover, Leader, Member, SyncState,\
TimelineHistory, ReturnFalseException, catch_return_false_exception, citus_group_re
from ..exceptions import DCSError
from ..utils import deep_compare, parse_bool, Retry, RetryFailedError, split_host_port, uri, USER_AGENT
if TYPE_CHECKING: # pragma: no cover
from ..config import Config

logger = logging.getLogger(__name__)
@@ -38,12 +41,17 @@ class InvalidSession(ConsulException):
"""invalid session"""

Response = namedtuple('Response', 'code,headers,body,content')
class Response(NamedTuple):
code: int
headers: Union[Mapping[str, str], Mapping[bytes, bytes], None]
body: str
content: bytes
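Replacing `collections.namedtuple` with a `typing.NamedTuple` subclass, as above, keeps the same runtime tuple behaviour but lets the checker see per-field types. A hedged, standalone sketch of the difference:

# --- illustrative sketch (not part of the commit) ---
from typing import NamedTuple, Optional


class Point(NamedTuple):
    # Each field carries a type, so point.x is known to be an int by the checker.
    x: int
    y: int
    label: Optional[str] = None


point = Point(1, 2)
assert point == (1, 2, None)  # still behaves like a plain tuple
# --- end sketch ---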
class HTTPClient(object):

def __init__(self, host='127.0.0.1', port=8500, token=None, scheme='http', verify=True, cert=None, ca_cert=None):
def __init__(self, host: str = '127.0.0.1', port: int = 8500, token: Optional[str] = None, scheme: str = 'http',
verify: bool = True, cert: Optional[str] = None, ca_cert: Optional[str] = None) -> None:
self.token = token
self._read_timeout = 10
self.base_uri = uri(scheme, (host, port))
@@ -60,22 +68,22 @@ class HTTPClient(object):
kwargs['ca_certs'] = ca_cert
kwargs['cert_reqs'] = ssl.CERT_REQUIRED if verify or ca_cert else ssl.CERT_NONE
self.http = urllib3.PoolManager(num_pools=10, maxsize=10, **kwargs)
self._ttl = None
self._ttl = 30

def set_read_timeout(self, timeout):
def set_read_timeout(self, timeout: float) -> None:
self._read_timeout = timeout / 3.0

@property
def ttl(self):
def ttl(self) -> int:
return self._ttl

def set_ttl(self, ttl):
def set_ttl(self, ttl: int) -> bool:
ret = self._ttl != ttl
self._ttl = ttl
return ret

@staticmethod
def response(response):
def response(response: urllib3.response.HTTPResponse) -> Response:
content = response.data
body = content.decode('utf-8')
if response.status == 500:
@@ -88,14 +96,19 @@ class HTTPClient(object):
raise ConsulInternalError(msg)
return Response(response.status, response.headers, body, content)

def uri(self, path, params=None):
def uri(self, path: str,
params: Union[None, Dict[str, Any], List[Tuple[str, Any]], Tuple[Tuple[str, Any], ...]] = None) -> str:
return '{0}{1}{2}'.format(self.base_uri, path, params and '?' + urlencode(params) or '')

def __getattr__(self, method):
def __getattr__(self, method: str) -> Callable[[Callable[[Response], Union[bool, Any, Tuple[str, Any]]],
str, Union[None, Dict[str, Any], List[Tuple[str, Any]]],
str, Optional[Dict[str, str]]], Union[bool, Any, Tuple[str, Any]]]:
if method not in ('get', 'post', 'put', 'delete'):
raise AttributeError("HTTPClient instance has no attribute '{0}'".format(method))

def wrapper(callback, path, params=None, data='', headers=None):
def wrapper(callback: Callable[[Response], Union[bool, Any, Tuple[str, Any]]], path: str,
params: Union[None, Dict[str, Any], List[Tuple[str, Any]]] = None, data: str = '',
headers: Optional[Dict[str, str]] = None) -> Union[bool, Any, Tuple[str, Any]]:
# python-consul doesn't allow to specify ttl smaller then 10 seconds
# because session_ttl_min defaults to 10s, so we have to do this ugly dirty hack...
if method == 'put' and path == '/v1/session/create':
@@ -106,7 +119,7 @@ class HTTPClient(object):
data = data[:-1] + ', ' + ttl + '}'
if isinstance(params, list): # starting from v1.1.0 python-consul switched from `dict` to `list` for params
params = {k: v for k, v in params}
kwargs = {'retries': 0, 'preload_content': False, 'body': data}
kwargs: Dict[str, Any] = {'retries': 0, 'preload_content': False, 'body': data}
if method == 'get' and isinstance(params, dict) and 'index' in params:
timeout = float(params['wait'][:-1]) if 'wait' in params else 300
# According to the documentation a small random amount of additional wait time is added to the
@@ -127,13 +140,13 @@ class HTTPClient(object):
class ConsulClient(base.Consul):

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
self._cert = kwargs.pop('cert', None)
self._ca_cert = kwargs.pop('ca_cert', None)
self.token = kwargs.get('token')
super(ConsulClient, self).__init__(*args, **kwargs)

def http_connect(self, *args, **kwargs):
def http_connect(self, *args: Any, **kwargs: Any) -> HTTPClient:
kwargs.update(dict(zip(['host', 'port', 'scheme', 'verify'], args)))
if self._cert:
kwargs['cert'] = self._cert
@@ -143,17 +156,17 @@ class ConsulClient(base.Consul):
kwargs['token'] = self.token
return HTTPClient(**kwargs)

def connect(self, *args, **kwargs):
def connect(self, *args: Any, **kwargs: Any) -> HTTPClient:
return self.http_connect(*args, **kwargs)

def reload_config(self, config):
def reload_config(self, config: Dict[str, Any]) -> None:
self.http.token = self.token = config.get('token')
self.consistency = config.get('consistency', 'default')
self.dc = config.get('dc')

def catch_consul_errors(func):
def wrapper(*args, **kwargs):
def catch_consul_errors(func: Callable[..., Any]) -> Callable[..., Any]:
def wrapper(*args: Any, **kwargs: Any) -> Any:
try:
return func(*args, **kwargs)
except (RetryFailedError, ConsulException, HTTPException, HTTPError, socket.error, socket.timeout):
@@ -161,24 +174,25 @@ def catch_consul_errors(func):
return wrapper
def force_if_last_failed(func):
def wrapper(*args, **kwargs):
if wrapper.last_result is False:
def force_if_last_failed(func: Callable[..., Any]) -> Callable[..., Any]:
def wrapper(*args: Any, **kwargs: Any) -> Any:
if getattr(wrapper, 'last_result', None) is False:
kwargs['force'] = True
wrapper.last_result = func(*args, **kwargs)
return wrapper.last_result
last_result = func(*args, **kwargs)
setattr(wrapper, 'last_result', last_result)
return last_result

wrapper.last_result = None
setattr(wrapper, 'last_result', None)
return wrapper
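Stashing state on the wrapper function itself is a common trick, but pyright in strict mode flags ad-hoc attributes on function objects; routing the access through getattr/setattr, as the hunk above does, keeps the behaviour while satisfying the checker. A small sketch of the same idea, independent of Patroni:

# --- illustrative sketch (not part of the commit) ---
from typing import Any, Callable


def remember_last(func: Callable[..., Any]) -> Callable[..., Any]:
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        result = func(*args, **kwargs)
        # setattr/getattr keep strict type checkers happy about dynamic attributes
        setattr(wrapper, 'last_result', result)
        return result

    setattr(wrapper, 'last_result', None)
    return wrapper


@remember_last
def add(a: int, b: int) -> int:
    return a + b


add(1, 2)
print(getattr(add, 'last_result'))  # 3
# --- end sketch ---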
def service_name_from_scope_name(scope_name):
def service_name_from_scope_name(scope_name: str) -> str:
"""Translate scope name to service name which can be used in dns.

230 = 253 - len('replica.') - len('.service.consul')
"""

def replace_char(match):
def replace_char(match: Any) -> str:
c = match.group(0)
return '-' if c in '. _' else "u{:04d}".format(ord(c))
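From the visible fragment, replace_char maps spaces, dots and underscores to '-' and any other disallowed character to a 'u'-prefixed ordinal; the regex that drives it lives in the elided part of the hunk, so the sketch below only illustrates the character mapping with an assumed pattern, not the full function:

# --- illustrative sketch (not part of the commit) ---
import re

# Assumed pattern for illustration only: keep lowercase letters, digits and hyphens.
_illustrative_pattern = re.compile(r'[^a-z0-9\-]')


def _replace_char(match: 're.Match[str]') -> str:
    c = match.group(0)
    return '-' if c in '. _' else "u{:04d}".format(ord(c))


print(_illustrative_pattern.sub(_replace_char, 'my scope.1$'))  # my-scope-1u0036
# --- end sketch ---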
@@ -188,7 +202,7 @@ def service_name_from_scope_name(scope_name):

class Consul(AbstractDCS):

def __init__(self, config):
def __init__(self, config: Dict[str, Any]) -> None:
super(Consul, self).__init__(config)
self._base_path = self._base_path[1:]
self._scope = config['scope']
@@ -198,9 +212,9 @@ class Consul(AbstractDCS):
retry_exceptions=(ConsulInternalError, HTTPException,
HTTPError, socket.error, socket.timeout))

kwargs = {}
if 'url' in config:
r = urlparse(config['url'])
url: str = config['url']
r = urlparse(url)
config.update({'scheme': r.scheme, 'host': r.hostname, 'port': r.port or 8500})
elif 'host' in config:
host, port = split_host_port(config.get('host', '127.0.0.1:8500'), 8500)
@@ -215,7 +229,7 @@ class Consul(AbstractDCS):
config['cert'] = (config['cert'], config['key'])

config_keys = ('host', 'port', 'token', 'scheme', 'cert', 'ca_cert', 'dc', 'consistency')
kwargs = {p: config.get(p) for p in config_keys if config.get(p)}
kwargs: Dict[str, Any] = {p: config.get(p) for p in config_keys if config.get(p)}

verify = config.get('verify')
if not isinstance(verify, bool):
@@ -240,10 +254,10 @@ class Consul(AbstractDCS):
self.create_session()
self._previous_loop_token = self._client.token

def retry(self, *args, **kwargs):
return self._retry.copy()(*args, **kwargs)
def retry(self, method: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
return self._retry.copy()(method, *args, **kwargs)
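Making the callable an explicit first parameter, rather than burying it in *args, is what lets pyright check retry() call sites; usage stays the same. A hedged illustration with a simplified stand-in for the Retry object:

# --- illustrative sketch (not part of the commit) ---
from typing import Any, Callable


def retry(method: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
    # Simplified stand-in for Retry.copy()(...): call once, no retry policy.
    return method(*args, **kwargs)


def fetch(path: str, recurse: bool = False) -> str:
    return '{0} (recurse={1})'.format(path, recurse)


print(retry(fetch, '/service/batman', recurse=True))
# --- end sketch ---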
def create_session(self):
def create_session(self) -> None:
while not self._session:
try:
self.refresh_session()
@@ -251,13 +265,14 @@ class Consul(AbstractDCS):
logger.info('waiting on consul')
time.sleep(5)

def reload_config(self, config):
def reload_config(self, config: Union['Config', Dict[str, Any]]) -> None:
super(Consul, self).reload_config(config)

consul_config = config.get('consul', {})
self._client.reload_config(consul_config)
self._previous_loop_service_tags = self._service_tags
self._service_tags = sorted(consul_config.get('service_tags', []))
self._service_tags: List[str] = consul_config.get('service_tags', [])
self._service_tags.sort()

should_register_service = consul_config.get('register_service', False)
if should_register_service and not self._register_service:
@@ -266,20 +281,21 @@ class Consul(AbstractDCS):
self._previous_loop_register_service = self._register_service
self._register_service = should_register_service

def set_ttl(self, ttl):
def set_ttl(self, ttl: int) -> Optional[bool]:
if self._client.http.set_ttl(ttl / 2.0): # Consul multiplies the TTL by 2x
self._session = None
self.__do_not_watch = True
return None

@property
def ttl(self):
def ttl(self) -> int:
return self._client.http.ttl * 2 # we multiply the value by 2 because it was divided in the `set_ttl()` method

def set_retry_timeout(self, retry_timeout):
def set_retry_timeout(self, retry_timeout: int) -> None:
self._retry.deadline = retry_timeout
self._client.http.set_read_timeout(retry_timeout)

def adjust_ttl(self):
def adjust_ttl(self) -> None:
try:
settings = self._client.agent.self()
min_ttl = (settings['Config']['SessionTTLMin'] or 10000000000) / 1000000000.0
@@ -288,7 +304,7 @@ class Consul(AbstractDCS):
except Exception:
logger.exception('adjust_ttl')

def _do_refresh_session(self, force=False):
def _do_refresh_session(self, force: bool = False) -> bool:
""":returns: `!True` if it had to create new session"""
if not force and self._session and self._last_session_refresh + self._loop_wait > time.time():
return False
@@ -312,7 +328,7 @@ class Consul(AbstractDCS):
self._last_session_refresh = time.time()
return ret

def refresh_session(self):
def refresh_session(self) -> bool:
try:
return self.retry(self._do_refresh_session)
except (ConsulException, RetryFailedError):
@@ -320,10 +336,10 @@ class Consul(AbstractDCS):
raise ConsulError('Failed to renew/create session')

@staticmethod
def member(node):
def member(node: Dict[str, str]) -> Member:
return Member.from_node(node['ModifyIndex'], os.path.basename(node['Key']), node.get('Session'), node['Value'])

def _cluster_from_nodes(self, nodes):
def _cluster_from_nodes(self, nodes: Dict[str, Any]) -> Cluster:
# get initialize flag
initialize = nodes.get(self._INITIALIZE)
initialize = initialize and initialize['Value']
@@ -351,7 +367,7 @@ class Consul(AbstractDCS):
slots = None

try:
last_lsn = int(last_lsn)
last_lsn = int(last_lsn or '')
except Exception:
last_lsn = 0

@@ -384,7 +400,7 @@ class Consul(AbstractDCS):

return Cluster(initialize, config, leader, last_lsn, members, failover, sync, history, slots, failsafe)

def _cluster_loader(self, path):
def _cluster_loader(self, path: str) -> Cluster:
_, results = self.retry(self._client.kv.get, path, recurse=True)
if results is None:
raise NotFound
@@ -395,9 +411,9 @@ class Consul(AbstractDCS):

return self._cluster_from_nodes(nodes)

def _citus_cluster_loader(self, path):
def _citus_cluster_loader(self, path: str) -> Dict[int, Cluster]:
_, results = self.retry(self._client.kv.get, path, recurse=True)
clusters = defaultdict(dict)
clusters: Dict[int, Dict[str, Cluster]] = defaultdict(dict)
for node in results or []:
key = node['Key'][len(path):].split('/', 1)
if len(key) == 2 and citus_group_re.match(key[0]):
@@ -405,7 +421,9 @@ class Consul(AbstractDCS):
clusters[int(key[0])][key[1]] = node
return {group: self._cluster_from_nodes(nodes) for group, nodes in clusters.items()}
def _load_cluster(self, path, loader):
def _load_cluster(
self, path: str, loader: Callable[[str], Union[Cluster, Dict[int, Cluster]]]
) -> Union[Cluster, Dict[int, Cluster]]:
try:
return loader(path)
except NotFound:
@@ -415,7 +433,7 @@ class Consul(AbstractDCS):
raise ConsulError('Consul is not responding properly')

@catch_consul_errors
def touch_member(self, data):
def touch_member(self, data: Dict[str, Any]) -> bool:
cluster = self.cluster
member = cluster and cluster.get_member(self._name, fallback_to_leader=False)

@@ -447,30 +465,32 @@ class Consul(AbstractDCS):
logger.exception('touch_member')
return False

def _set_service_name(self):
def _set_service_name(self) -> None:
self._service_name = service_name_from_scope_name(self._scope)
if self._scope != self._service_name:
logger.warning('Using %s as consul service name instead of scope name %s', self._service_name, self._scope)

@catch_consul_errors
def register_service(self, service_name, **kwargs):
def register_service(self, service_name: str, **kwargs: Any) -> bool:
logger.info('Register service %s, params %s', service_name, kwargs)
return self._client.agent.service.register(service_name, **kwargs)

@catch_consul_errors
def deregister_service(self, service_id):
def deregister_service(self, service_id: str) -> bool:
logger.info('Deregister service %s', service_id)
# service_id can contain special characters, but is used as part of uri in deregister request
service_id = quote(service_id)
return self._client.agent.service.deregister(service_id)

def _update_service(self, data):
def _update_service(self, data: Dict[str, Any]) -> Optional[bool]:
service_name = self._service_name
role = data['role'].replace('_', '-')
state = data['state']
api_parts = urlparse(data['api_url'])
api_url: str = data['api_url']
api_parts = urlparse(api_url)
api_parts = api_parts._replace(path='/{0}'.format(role))
conn_parts = urlparse(data['conn_url'])
conn_url: str = data['conn_url']
conn_parts = urlparse(conn_url)
check = base.Check.http(api_parts.geturl(), self._service_check_interval,
deregister='{0}s'.format(self._client.http.ttl * 10))
if self._service_check_tls_server_name is not None:
@@ -506,7 +526,7 @@ class Consul(AbstractDCS):
logger.warning('Could not register service: unknown role type %s', role)

@force_if_last_failed
def update_service(self, old_data, new_data, force=False):
def update_service(self, old_data: Dict[str, Any], new_data: Dict[str, Any], force: bool = False) -> Optional[bool]:
update = False

for key in ['role', 'api_url', 'conn_url', 'state']:
@@ -523,7 +543,7 @@ class Consul(AbstractDCS):
):
return self._update_service(new_data)

def _do_attempt_to_acquire_leader(self, retry):
def _do_attempt_to_acquire_leader(self, retry: Retry) -> bool:
try:
return retry(self._client.kv.put, self.leader_path, self._name, acquire=self._session)
except InvalidSession:
@@ -540,7 +560,7 @@ class Consul(AbstractDCS):
return retry(self._client.kv.put, self.leader_path, self._name, acquire=self._session)

@catch_return_false_exception
def attempt_to_acquire_leader(self):
def attempt_to_acquire_leader(self) -> bool:
retry = self._retry.copy()
self._run_and_handle_exceptions(self._do_refresh_session, retry=retry)
@@ -554,31 +574,31 @@ class Consul(AbstractDCS):

return ret

def take_leader(self):
def take_leader(self) -> bool:
return self.attempt_to_acquire_leader()

@catch_consul_errors
def set_failover_value(self, value, index=None):
def set_failover_value(self, value: str, index: Optional[int] = None) -> bool:
return self._client.kv.put(self.failover_path, value, cas=index)

@catch_consul_errors
def set_config_value(self, value, index=None):
def set_config_value(self, value: str, index: Optional[int] = None) -> bool:
return self._client.kv.put(self.config_path, value, cas=index)

@catch_consul_errors
def _write_leader_optime(self, last_lsn):
def _write_leader_optime(self, last_lsn: str) -> bool:
return self._client.kv.put(self.leader_optime_path, last_lsn)

@catch_consul_errors
def _write_status(self, value):
def _write_status(self, value: str) -> bool:
return self._client.kv.put(self.status_path, value)

@catch_consul_errors
def _write_failsafe(self, value):
def _write_failsafe(self, value: str) -> bool:
return self._client.kv.put(self.failsafe_path, value)

@staticmethod
def _run_and_handle_exceptions(method, *args, **kwargs):
def _run_and_handle_exceptions(method: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
retry = kwargs.pop('retry', None)
try:
return retry(method, *args, **kwargs) if retry else method(*args, **kwargs)
@@ -588,7 +608,7 @@ class Consul(AbstractDCS):
raise ReturnFalseException

@catch_return_false_exception
def _update_leader(self):
def _update_leader(self) -> bool:
retry = self._retry.copy()

self._run_and_handle_exceptions(self._do_refresh_session, True, retry=retry)
@@ -601,7 +621,7 @@ class Consul(AbstractDCS):
if retry.deadline < 1:
raise ConsulError('update_leader timeout')
logger.warning('Recreating the leader key due to session mismatch')
if cluster.leader:
if cluster and cluster.leader:
self._run_and_handle_exceptions(self._client.kv.delete, self.leader_path, cas=cluster.leader.index)

retry.deadline = retry.stoptime - time.time()
@@ -613,37 +633,39 @@ class Consul(AbstractDCS):
return bool(self._session)

@catch_consul_errors
def initialize(self, create_new=True, sysid=''):
def initialize(self, create_new: bool = True, sysid: str = '') -> bool:
kwargs = {'cas': 0} if create_new else {}
return self.retry(self._client.kv.put, self.initialize_path, sysid, **kwargs)

@catch_consul_errors
def cancel_initialization(self):
def cancel_initialization(self) -> bool:
return self.retry(self._client.kv.delete, self.initialize_path)

@catch_consul_errors
def delete_cluster(self):
def delete_cluster(self) -> bool:
return self.retry(self._client.kv.delete, self.client_path(''), recurse=True)

@catch_consul_errors
def set_history_value(self, value):
def set_history_value(self, value: str) -> bool:
return self._client.kv.put(self.history_path, value)

@catch_consul_errors
def _delete_leader(self):
def _delete_leader(self) -> bool:
cluster = self.cluster
if cluster and isinstance(cluster.leader, Leader) and cluster.leader.name == self._name:
if cluster and isinstance(cluster.leader, Leader) and\
cluster.leader.name == self._name and isinstance(cluster.leader.index, int):
return self._client.kv.delete(self.leader_path, cas=cluster.leader.index)
return True

@catch_consul_errors
def set_sync_state_value(self, value, index=None):
def set_sync_state_value(self, value: str, index: Optional[int] = None) -> bool:
return self.retry(self._client.kv.put, self.sync_path, value, cas=index)

@catch_consul_errors
def delete_sync_state(self, index=None):
def delete_sync_state(self, index: Optional[int] = None) -> bool:
return self.retry(self._client.kv.delete, self.sync_path, cas=index)

def watch(self, leader_index, timeout):
def watch(self, leader_index: Optional[int], timeout: float) -> bool:
self._last_session_refresh = 0
if self.__do_not_watch:
self.__do_not_watch = False
@@ -16,7 +16,7 @@ from dns import resolver
from http.client import HTTPException
from queue import Queue
from threading import Thread
from typing import List, Optional
from typing import Any, Callable, Collection, Dict, List, Optional, Union, Tuple, Type, TYPE_CHECKING
from urllib.parse import urlparse
from urllib3 import Timeout
from urllib3.exceptions import HTTPError, ReadTimeoutError, ProtocolError
@@ -26,6 +26,8 @@ from . import AbstractDCS, Cluster, ClusterConfig, Failover, Leader, Member, Syn
from ..exceptions import DCSError
from ..request import get as requests_get
from ..utils import Retry, RetryFailedError, split_host_port, uri, USER_AGENT
if TYPE_CHECKING: # pragma: no cover
from ..config import Config

logger = logging.getLogger(__name__)

@@ -38,18 +40,21 @@ class EtcdError(DCSError):
pass

_AddrInfo = Tuple[socket.AddressFamily, socket.SocketKind, int, str, Union[Tuple[str, int], Tuple[str, int, int, int]]]
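The _AddrInfo alias introduced above mirrors the 5-tuples returned by socket.getaddrinfo, which is what the resolver below caches. A quick standalone check of that shape:

# --- illustrative sketch (not part of the commit) ---
import socket
from typing import List, Tuple, Union

_AddrInfo = Tuple[socket.AddressFamily, socket.SocketKind, int, str,
                  Union[Tuple[str, int], Tuple[str, int, int, int]]]

addrs: List[_AddrInfo] = socket.getaddrinfo('localhost', 2379, 0, socket.SOCK_STREAM, socket.IPPROTO_TCP)
for family, kind, proto, canonname, sockaddr in addrs:
    # sockaddr is (host, port) for IPv4 and (host, port, flowinfo, scope_id) for IPv6
    print(family, sockaddr)
# --- end sketch ---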
class DnsCachingResolver(Thread):

def __init__(self, cache_time=600.0, cache_fail_time=30.0):
def __init__(self, cache_time: float = 600.0, cache_fail_time: float = 30.0) -> None:
super(DnsCachingResolver, self).__init__()
self._cache = {}
self._cache: Dict[Tuple[str, int], Tuple[float, List[_AddrInfo]]] = {}
self._cache_time = cache_time
self._cache_fail_time = cache_fail_time
self._resolve_queue = Queue()
self._resolve_queue: Queue[Tuple[Tuple[str, int], int]] = Queue()
self.daemon = True
self.start()

def run(self):
def run(self) -> None:
while True:
(host, port), attempt = self._resolve_queue.get()
response = self._do_resolve(host, port)
@@ -60,7 +65,7 @@ class DnsCachingResolver(Thread):
self.resolve_async(host, port, attempt + 1)
time.sleep(1)

def resolve(self, host, port):
def resolve(self, host: str, port: int) -> List[_AddrInfo]:
current_time = time.time()
cached_time, response = self._cache.get((host, port), (0, []))
time_passed = current_time - cached_time
@@ -71,14 +76,14 @@ class DnsCachingResolver(Thread):
response = new_response
return response

def resolve_async(self, host, port, attempt=0):
def resolve_async(self, host: str, port: int, attempt: int = 0) -> None:
self._resolve_queue.put(((host, port), attempt))

def remove(self, host, port):
def remove(self, host: str, port: int) -> None:
self._cache.pop((host, port), None)

@staticmethod
def _do_resolve(host, port):
def _do_resolve(host: str, port: int) -> List[_AddrInfo]:
try:
return socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM, socket.IPPROTO_TCP)
except Exception as e:
@@ -88,13 +93,15 @@ class DnsCachingResolver(Thread):
class AbstractEtcdClientWithFailover(abc.ABC, etcd.Client):

def __init__(self, config, dns_resolver, cache_ttl=300):
ERROR_CLS: Type[Exception]

def __init__(self, config: Dict[str, Any], dns_resolver: DnsCachingResolver, cache_ttl: int = 300) -> None:
self._dns_resolver = dns_resolver
self.set_machines_cache_ttl(cache_ttl)
self._machines_cache_updated = 0
args = {p: config.get(p) for p in ('host', 'port', 'protocol', 'use_proxies', 'username', 'password',
'cert', 'ca_cert') if config.get(p)}
super(AbstractEtcdClientWithFailover, self).__init__(read_timeout=config['retry_timeout'], **args)
kwargs = {p: config.get(p) for p in ('host', 'port', 'protocol', 'use_proxies',
'username', 'password', 'cert', 'ca_cert') if config.get(p)}
super(AbstractEtcdClientWithFailover, self).__init__(read_timeout=config['retry_timeout'], **kwargs)
# For some reason python3-etcd on debian and ubuntu are not based on the latest version
# Workaround for the case when https://github.com/jplana/python-etcd/pull/196 is not applied
self.http.connection_pool_kw.pop('ssl_version', None)
@@ -106,7 +113,7 @@ class AbstractEtcdClientWithFailover(abc.ABC, etcd.Client):
self._read_options.add('retry')
self._del_conditions.add('retry')

def _calculate_timeouts(self, etcd_nodes, timeout=None):
def _calculate_timeouts(self, etcd_nodes: int, timeout: Optional[float] = None) -> Tuple[int, float, int]:
"""Calculate a request timeout and number of retries per single etcd node.
In case if the timeout per node is too small (less than one second) we will reduce the number of nodes.
For the cluster with only one node we will try to do 2 retries.
@@ -133,38 +140,39 @@ class AbstractEtcdClientWithFailover(abc.ABC, etcd.Client):

return etcd_nodes, per_node_timeout, per_node_retries - 1

def reload_config(self, config):
def reload_config(self, config: Dict[str, Any]) -> None:
self.username = config.get('username')
self.password = config.get('password')

def _get_headers(self):
def _get_headers(self) -> Dict[str, str]:
basic_auth = ':'.join((self.username, self.password)) if self.username and self.password else None
return urllib3.make_headers(basic_auth=basic_auth, user_agent=USER_AGENT)

def _prepare_common_parameters(self, etcd_nodes, timeout=None):
kwargs = {'headers': self._get_headers(), 'redirect': self.allow_redirect, 'preload_content': False}
def _prepare_common_parameters(self, etcd_nodes: int, timeout: Optional[float] = None) -> Dict[str, Any]:
kwargs: Dict[str, Any] = {'headers': self._get_headers(),
'redirect': self.allow_redirect, 'preload_content': False}

if timeout is not None:
kwargs.update(retries=0, timeout=timeout)
else:
_, per_node_timeout, per_node_retries = self._calculate_timeouts(etcd_nodes)
connect_timeout = max(1, per_node_timeout / 2)
connect_timeout = max(1.0, per_node_timeout / 2.0)
kwargs.update(timeout=Timeout(connect=connect_timeout, total=per_node_timeout), retries=per_node_retries)
return kwargs

def set_machines_cache_ttl(self, cache_ttl):
def set_machines_cache_ttl(self, cache_ttl: int) -> None:
self._machines_cache_ttl = cache_ttl

@abc.abstractmethod
def _prepare_get_members(self, etcd_nodes):
def _prepare_get_members(self, etcd_nodes: int) -> Dict[str, Any]:
"""returns: request parameters"""

@abc.abstractmethod
def _get_members(self, base_uri, **kwargs):
def _get_members(self, base_uri: str, **kwargs: Any) -> List[str]:
"""returns: list of clientURLs"""

@property
def machines_cache(self):
def machines_cache(self) -> List[str]:
base_uri, cache = self._base_uri, self._machines_cache
return ([base_uri] if base_uri in cache else []) + [machine for machine in cache if machine != base_uri]

@@ -207,10 +215,14 @@ class AbstractEtcdClientWithFailover(abc.ABC, etcd.Client):

return self._get_machines_list(self.machines_cache)

def set_read_timeout(self, timeout):
def set_read_timeout(self, timeout: float) -> None:
self._read_timeout = timeout

def _do_http_request(self, retry, machines_cache, request_executor, method, path, fields=None, **kwargs):
def _do_http_request(self, retry: Optional[Retry], machines_cache: List[str],
request_executor: Callable[..., urllib3.response.HTTPResponse],
method: str, path: str, fields: Optional[Dict[str, Any]] = None,
**kwargs: Any) -> urllib3.response.HTTPResponse:
is_watch_request = isinstance(fields, dict) and fields.get('wait') == 'true'
if fields is not None:
kwargs['fields'] = fields
some_request_failed = False
@@ -233,8 +245,7 @@ class AbstractEtcdClientWithFailover(abc.ABC, etcd.Client):
# whether the key didn't received an update or there is a network problem.
elif i + 1 < len(machines_cache):
self.set_base_uri(machines_cache[i + 1])
if (isinstance(fields, dict) and fields.get("wait") == "true"
and isinstance(e, (ReadTimeoutError, ProtocolError))):
if is_watch_request and isinstance(e, (ReadTimeoutError, ProtocolError)):
logger.debug("Watch timed out.")
raise etcd.EtcdWatchTimedOut("Watch timed out: {0}".format(e), cause=e)
logger.error("Request to server %s failed: %r", base_uri, e)
@@ -246,10 +257,12 @@ class AbstractEtcdClientWithFailover(abc.ABC, etcd.Client):
raise etcd.EtcdConnectionFailed('No more machines in the cluster')

@abc.abstractmethod
def _prepare_request(self, kwargs, params=None, method=None):
def _prepare_request(self, kwargs: Dict[str, Any], params: Optional[Dict[str, Any]] = None,
method: Optional[str] = None) -> Callable[..., urllib3.response.HTTPResponse]:
"""returns: request_executor"""

def api_execute(self, path, method, params=None, timeout=None):
def api_execute(self, path: str, method: str, params: Optional[Dict[str, Any]] = None,
timeout: Optional[float] = None) -> Any:
retry = params.pop('retry', None) if isinstance(params, dict) else None

# Update machines_cache if previous attempt of update has failed
@@ -277,6 +290,7 @@ class AbstractEtcdClientWithFailover(abc.ABC, etcd.Client):
etcd_nodes = len(machines_cache)
except Exception as e:
logger.debug('Failed to update list of etcd nodes: %r', e)
assert isinstance(retry, Retry) # etcd.EtcdConnectionFailed is raised only if retry is not None!
sleeptime = retry.sleeptime
remaining_time = retry.stoptime - sleeptime - time.time()
nodes, timeout, retries = self._calculate_timeouts(etcd_nodes, remaining_time)
@@ -287,22 +301,22 @@ class AbstractEtcdClientWithFailover(abc.ABC, etcd.Client):
retry.sleep_func(sleeptime)
retry.update_delay()
# We still have some time left. Partially reduce `machines_cache` and retry request
kwargs.update(timeout=Timeout(connect=max(1, timeout / 2), total=timeout), retries=retries)
kwargs.update(timeout=Timeout(connect=max(1.0, timeout / 2.0), total=timeout), retries=retries)
machines_cache = machines_cache[:nodes]

@staticmethod
def get_srv_record(host):
def get_srv_record(host: str) -> List[Tuple[str, int]]:
try:
return [(r.target.to_text(True), r.port) for r in resolver.query(host, 'SRV')]
except DNSException:
return []

def _get_machines_cache_from_srv(self, srv, srv_suffix=None):
def _get_machines_cache_from_srv(self, srv: str, srv_suffix: Optional[str] = None) -> List[str]:
"""Fetch list of etcd-cluster member by resolving _etcd-server._tcp. SRV record.
This record should contain list of host and peer ports which could be used to run
'GET http://{host}:{port}/members' request (peer protocol)"""

ret = []
ret: List[str] = []
for r in ['-client-ssl', '-client', '-ssl', '', '-server-ssl', '-server']:
r = '{0}-{1}'.format(r, srv_suffix) if srv_suffix else r
protocol = 'https' if '-ssl' in r else 'http'
@@ -327,15 +341,15 @@ class AbstractEtcdClientWithFailover(abc.ABC, etcd.Client):
logger.warning('Can not resolve SRV for %s', srv)
return list(set(ret))

def _get_machines_cache_from_dns(self, host, port):
def _get_machines_cache_from_dns(self, host: str, port: int) -> List[str]:
"""One host might be resolved into multiple ip addresses. We will make list out of it"""
if self.protocol == 'http':
ret = map(lambda res: uri(self.protocol, res[-1][:2]), self._dns_resolver.resolve(host, port))
ret = [uri(self.protocol, res[-1][:2]) for res in self._dns_resolver.resolve(host, port)]
if ret:
return list(set(ret))
return [uri(self.protocol, (host, port))]

def _get_machines_cache_from_config(self):
def _get_machines_cache_from_config(self) -> List[str]:
if 'proxy' in self._config:
return [uri(self.protocol, (self._config['host'], self._config['port']))]

@@ -351,13 +365,14 @@ class AbstractEtcdClientWithFailover(abc.ABC, etcd.Client):
return machines_cache

@staticmethod
def _update_dns_cache(func, machines):
def _update_dns_cache(func: Callable[[str, int], None], machines: List[str]) -> None:
for url in machines:
r = urlparse(url)
port = r.port or (443 if r.scheme == 'https' else 80)
func(r.hostname, port)
if r.hostname:
port = r.port or (443 if r.scheme == 'https' else 80)
func(r.hostname, port)

def _load_machines_cache(self):
def _load_machines_cache(self) -> bool:
"""This method should fill up `_machines_cache` from scratch.
It could happen only in two cases:
1. During class initialization
@@ -417,7 +432,7 @@ class AbstractEtcdClientWithFailover(abc.ABC, etcd.Client):
self._machines_cache_updated = time.time()
return ret

def set_base_uri(self, value):
def set_base_uri(self, value: str) -> None:
if self._base_uri != value:
logger.info('Selected new etcd server %s', value)
self._base_uri = value
@@ -427,22 +442,22 @@ class EtcdClient(AbstractEtcdClientWithFailover):

ERROR_CLS = EtcdError

def __del__(self):
if self.http is not None:
try:
self.http.clear()
except (ReferenceError, TypeError, AttributeError):
pass
def __del__(self) -> None:
try:
self.http.clear()
except (ReferenceError, TypeError, AttributeError):
pass

def _prepare_get_members(self, etcd_nodes):
def _prepare_get_members(self, etcd_nodes: int) -> Dict[str, Any]:
return self._prepare_common_parameters(etcd_nodes)

def _get_members(self, base_uri, **kwargs):
def _get_members(self, base_uri: str, **kwargs: Any) -> List[str]:
response = self.http.request(self._MGET, base_uri + self.version_prefix + '/machines', **kwargs)
data = self._handle_server_response(response).data.decode('utf-8')
return [m.strip() for m in data.split(',') if m.strip()]

def _prepare_request(self, kwargs, params=None, method=None):
def _prepare_request(self, kwargs: Dict[str, Any], params: Optional[Dict[str, Any]] = None,
method: Optional[str] = None) -> Callable[..., urllib3.response.HTTPResponse]:
kwargs['fields'] = params
if method in (self._MPOST, self._MPUT):
kwargs['encode_multipart'] = False
@@ -451,25 +466,32 @@ class EtcdClient(AbstractEtcdClientWithFailover):

class AbstractEtcd(AbstractDCS):

def __init__(self, config, client_cls, retry_errors_cls):
def __init__(self, config: Dict[str, Any], client_cls: Type[AbstractEtcdClientWithFailover],
retry_errors_cls: Union[Type[Exception], Tuple[Type[Exception], ...]]) -> None:
super(AbstractEtcd, self).__init__(config)
self._retry = Retry(deadline=config['retry_timeout'], max_delay=1, max_tries=-1,
retry_exceptions=retry_errors_cls)
self._ttl = int(config.get('ttl') or 30)
self._client = self.get_etcd_client(config, client_cls)
self._abstract_client = self.get_etcd_client(config, client_cls)
self.__do_not_watch = False
self._has_failed = False

def reload_config(self, config):
@property
@abc.abstractmethod
def _client(self) -> AbstractEtcdClientWithFailover:
"""return correct type of etcd client"""

def reload_config(self, config: Union['Config', Dict[str, Any]]) -> None:
super(AbstractEtcd, self).reload_config(config)
self._client.reload_config(config.get(self.__class__.__name__.lower(), {}))

def retry(self, *args, **kwargs):
def retry(self, method: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
retry = self._retry.copy()
kwargs['retry'] = retry
return retry(*args, **kwargs)
return retry(method, *args, **kwargs)

def _handle_exception(self, e, name='', do_sleep=False, raise_ex=None):
def _handle_exception(self, e: Exception, name: str = '', do_sleep: bool = False,
raise_ex: Optional[Exception] = None) -> None:
if not self._has_failed:
logger.exception(name)
else:
@@ -480,7 +502,7 @@ class AbstractEtcd(AbstractDCS):
if isinstance(raise_ex, Exception):
raise raise_ex

def _run_and_handle_exceptions(self, method, *args, **kwargs):
def _run_and_handle_exceptions(self, method: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
retry = kwargs.pop('retry', self.retry)
try:
return retry(method, *args, **kwargs) if retry else method(*args, **kwargs)
@@ -492,19 +514,20 @@ class AbstractEtcd(AbstractDCS):
except Exception as e:
self._handle_exception(e, raise_ex=self._client.ERROR_CLS('unexpected error'))

@staticmethod
def set_socket_options(sock, socket_options):
def set_socket_options(self, sock: socket.socket,
socket_options: Optional[Collection[Tuple[int, int, int]]]) -> None:
if socket_options:
for opt in socket_options:
sock.setsockopt(*opt)

def get_etcd_client(self, config, client_cls):
def get_etcd_client(self, config: Dict[str, Any],
client_cls: Type[AbstractEtcdClientWithFailover]) -> AbstractEtcdClientWithFailover:
config = deepcopy(config)
if 'proxy' in config:
config['use_proxies'] = True
config['url'] = config['proxy']

if 'url' in config:
if 'url' in config and isinstance(config['url'], str):
r = urlparse(config['url'])
config.update({'protocol': r.scheme, 'host': r.hostname, 'port': r.port or 2379,
'username': r.username, 'password': r.password})
@@ -516,10 +539,11 @@ class AbstractEtcd(AbstractDCS):
if isinstance(hosts, str):
hosts = hosts.split(',')

config['hosts'] = []
config_hosts: List[str] = []
for value in hosts:
if isinstance(value, str):
config['hosts'].append(uri(protocol, split_host_port(value.strip(), default_port)))
config_hosts.append(uri(protocol, split_host_port(value.strip(), default_port)))
config['hosts'] = config_hosts
elif 'host' in config:
host, port = split_host_port(config['host'], 2379)
config['host'] = host
@@ -538,8 +562,10 @@ class AbstractEtcd(AbstractDCS):

dns_resolver = DnsCachingResolver()

def create_connection_patched(address, timeout=socket._GLOBAL_DEFAULT_TIMEOUT,
source_address=None, socket_options=None):
def create_connection_patched(
address: Tuple[str, int], timeout: Any = object(),
source_address: Optional[Any] = None, socket_options: Optional[Collection[Tuple[int, int, int]]] = None
) -> socket.socket:
host, port = address
if host.startswith('['):
host = host.strip('[]')
@@ -549,7 +575,7 @@ class AbstractEtcd(AbstractDCS):
try:
sock = socket.socket(af, socktype, proto)
self.set_socket_options(sock, socket_options)
if timeout is not socket._GLOBAL_DEFAULT_TIMEOUT:
if timeout is None or isinstance(timeout, (float, int)):
sock.settimeout(timeout)
if source_address:
sock.bind(source_address)
@@ -580,7 +606,7 @@ class AbstractEtcd(AbstractDCS):
time.sleep(5)
return client

def set_ttl(self, ttl):
def set_ttl(self, ttl: int) -> Optional[bool]:
ttl = int(ttl)
ret = self._ttl != ttl
self._ttl = ttl
@@ -588,16 +614,16 @@ class AbstractEtcd(AbstractDCS):
return ret

@property
def ttl(self):
def ttl(self) -> int:
return self._ttl

def set_retry_timeout(self, retry_timeout):
def set_retry_timeout(self, retry_timeout: int) -> None:
self._retry.deadline = retry_timeout
self._client.set_read_timeout(retry_timeout)

def catch_etcd_errors(func):
def wrapper(self, *args, **kwargs):
def catch_etcd_errors(func: Callable[..., Any]) -> Any:
def wrapper(self: AbstractEtcd, *args: Any, **kwargs: Any) -> Any:
try:
retval = func(self, *args, **kwargs) is not None
self._has_failed = False
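The `is not None` and `bool(...)` coercions that appear throughout the etcd methods in this commit exist so that functions declared `-> bool` really return a bool under strict checking, instead of whatever truthy object the client library hands back. A minimal sketch of the pattern:

# --- illustrative sketch (not part of the commit) ---
from typing import Any, Callable


def catch_errors(func: Callable[..., Any]) -> Callable[..., bool]:
    def wrapper(*args: Any, **kwargs: Any) -> bool:
        try:
            # Coerce an arbitrary client return value (object, dict, None, ...) to bool
            return func(*args, **kwargs) is not None
        except Exception:
            return False

    return wrapper


@catch_errors
def write_key() -> str:
    return 'ok'


print(write_key())  # True
# --- end sketch ---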
@@ -613,18 +639,24 @@ def catch_etcd_errors(func):

class Etcd(AbstractEtcd):

def __init__(self, config):
def __init__(self, config: Dict[str, Any]) -> None:
super(Etcd, self).__init__(config, EtcdClient, (etcd.EtcdLeaderElectionInProgress, EtcdRaftInternal))
self.__do_not_watch = False

def set_ttl(self, ttl):
@property
def _client(self) -> EtcdClient:
assert isinstance(self._abstract_client, EtcdClient)
return self._abstract_client

def set_ttl(self, ttl: int) -> Optional[bool]:
self.__do_not_watch = super(Etcd, self).set_ttl(ttl)
return None

@staticmethod
def member(node):
def member(node: etcd.EtcdResult) -> Member:
return Member.from_node(node.modifiedIndex, os.path.basename(node.key), node.ttl, node.value)

def _cluster_from_nodes(self, etcd_index, nodes):
def _cluster_from_nodes(self, etcd_index: int, nodes: Dict[str, etcd.EtcdResult]) -> Cluster:
# get initialize flag
initialize = nodes.get(self._INITIALIZE)
initialize = initialize and initialize.value
@@ -652,7 +684,7 @@ class Etcd(AbstractEtcd):
slots = None

try:
last_lsn = int(last_lsn)
last_lsn = int(last_lsn or '')
except Exception:
last_lsn = 0

@@ -685,13 +717,13 @@ class Etcd(AbstractEtcd):

return Cluster(initialize, config, leader, last_lsn, members, failover, sync, history, slots, failsafe)

def _cluster_loader(self, path):
def _cluster_loader(self, path: str) -> Cluster:
result = self.retry(self._client.read, path, recursive=True)
nodes = {node.key[len(result.key):].lstrip('/'): node for node in result.leaves}
return self._cluster_from_nodes(result.etcd_index, nodes)

def _citus_cluster_loader(self, path):
clusters = defaultdict(dict)
def _citus_cluster_loader(self, path: str) -> Dict[int, Cluster]:
clusters: Dict[int, Dict[str, etcd.EtcdResult]] = defaultdict(dict)
result = self.retry(self._client.read, path, recursive=True)
for node in result.leaves:
key = node.key[len(result.key):].lstrip('/').split('/', 1)
@@ -699,7 +731,9 @@ class Etcd(AbstractEtcd):
clusters[int(key[0])][key[1]] = node
return {group: self._cluster_from_nodes(result.etcd_index, nodes) for group, nodes in clusters.items()}

def _load_cluster(self, path, loader):
def _load_cluster(
self, path: str, loader: Callable[[str], Union[Cluster, Dict[int, Cluster]]]
) -> Union[Cluster, Dict[int, Cluster]]:
cluster = None
try:
cluster = loader(path)
@@ -708,18 +742,19 @@ class Etcd(AbstractEtcd):
except Exception as e:
self._handle_exception(e, 'get_cluster', raise_ex=EtcdError('Etcd is not responding properly'))
self._has_failed = False
assert cluster is not None
return cluster

@catch_etcd_errors
def touch_member(self, data):
data = json.dumps(data, separators=(',', ':'))
return self._client.set(self.member_path, data, self._ttl)
def touch_member(self, data: Dict[str, Any]) -> bool:
value = json.dumps(data, separators=(',', ':'))
return bool(self._client.set(self.member_path, value, self._ttl))

@catch_etcd_errors
def take_leader(self):
def take_leader(self) -> bool:
return self.retry(self._client.write, self.leader_path, self._name, ttl=self._ttl)

def _do_attempt_to_acquire_leader(self):
def _do_attempt_to_acquire_leader(self) -> bool:
try:
return bool(self.retry(self._client.write, self.leader_path, self._name, ttl=self._ttl, prevExist=False))
except etcd.EtcdAlreadyExist:
@@ -727,26 +762,26 @@ class Etcd(AbstractEtcd):
return False

@catch_return_false_exception
def attempt_to_acquire_leader(self):
def attempt_to_acquire_leader(self) -> bool:
return self._run_and_handle_exceptions(self._do_attempt_to_acquire_leader, retry=None)

@catch_etcd_errors
def set_failover_value(self, value, index=None):
return self._client.write(self.failover_path, value, prevIndex=index or 0)
def set_failover_value(self, value: str, index: Optional[int] = None) -> bool:
return bool(self._client.write(self.failover_path, value, prevIndex=index or 0))

@catch_etcd_errors
def set_config_value(self, value, index=None):
return self._client.write(self.config_path, value, prevIndex=index or 0)
def set_config_value(self, value: str, index: Optional[int] = None) -> bool:
return bool(self._client.write(self.config_path, value, prevIndex=index or 0))
@catch_etcd_errors
|
||||
def _write_leader_optime(self, last_lsn):
|
||||
return self._client.set(self.leader_optime_path, last_lsn)
|
||||
def _write_leader_optime(self, last_lsn: str) -> bool:
|
||||
return bool(self._client.set(self.leader_optime_path, last_lsn))
|
||||
|
||||
@catch_etcd_errors
|
||||
def _write_status(self, value):
|
||||
return self._client.set(self.status_path, value)
|
||||
def _write_status(self, value: str) -> bool:
|
||||
return bool(self._client.set(self.status_path, value))
|
||||
|
||||
def _do_update_leader(self):
|
||||
def _do_update_leader(self) -> bool:
|
||||
try:
|
||||
return self.retry(self._client.write, self.leader_path, self._name,
|
||||
prevValue=self._name, ttl=self._ttl) is not None
|
||||
@@ -754,42 +789,42 @@ class Etcd(AbstractEtcd):
|
||||
return self._do_attempt_to_acquire_leader()
|
||||
|
||||
@catch_etcd_errors
|
||||
def _write_failsafe(self, value):
|
||||
return self._client.set(self.failsafe_path, value)
|
||||
def _write_failsafe(self, value: str) -> bool:
|
||||
return bool(self._client.set(self.failsafe_path, value))
|
||||
|
||||
@catch_return_false_exception
|
||||
def _update_leader(self):
|
||||
return self._run_and_handle_exceptions(self._do_update_leader, retry=None)
|
||||
def _update_leader(self) -> bool:
|
||||
return bool(self._run_and_handle_exceptions(self._do_update_leader, retry=None))
|
||||
|
||||
@catch_etcd_errors
|
||||
def initialize(self, create_new=True, sysid=""):
|
||||
return self.retry(self._client.write, self.initialize_path, sysid, prevExist=(not create_new))
|
||||
def initialize(self, create_new: bool = True, sysid: str = "") -> bool:
|
||||
return bool(self.retry(self._client.write, self.initialize_path, sysid, prevExist=(not create_new)))
|
||||
|
||||
@catch_etcd_errors
|
||||
def _delete_leader(self):
|
||||
return self._client.delete(self.leader_path, prevValue=self._name)
|
||||
def _delete_leader(self) -> bool:
|
||||
return bool(self._client.delete(self.leader_path, prevValue=self._name))
|
||||
|
||||
@catch_etcd_errors
|
||||
def cancel_initialization(self):
|
||||
return self.retry(self._client.delete, self.initialize_path)
|
||||
def cancel_initialization(self) -> bool:
|
||||
return bool(self.retry(self._client.delete, self.initialize_path))
|
||||
|
||||
@catch_etcd_errors
|
||||
def delete_cluster(self):
|
||||
return self.retry(self._client.delete, self.client_path(''), recursive=True)
|
||||
def delete_cluster(self) -> bool:
|
||||
return bool(self.retry(self._client.delete, self.client_path(''), recursive=True))
|
||||
|
||||
@catch_etcd_errors
|
||||
def set_history_value(self, value):
|
||||
return self._client.write(self.history_path, value)
|
||||
def set_history_value(self, value: str) -> bool:
|
||||
return bool(self._client.write(self.history_path, value))
|
||||
|
||||
@catch_etcd_errors
|
||||
def set_sync_state_value(self, value, index=None):
|
||||
return self.retry(self._client.write, self.sync_path, value, prevIndex=index or 0)
|
||||
def set_sync_state_value(self, value: str, index: Optional[int] = None) -> bool:
|
||||
return bool(self.retry(self._client.write, self.sync_path, value, prevIndex=index or 0))
|
||||
|
||||
@catch_etcd_errors
|
||||
def delete_sync_state(self, index=None):
|
||||
return self.retry(self._client.delete, self.sync_path, prevIndex=index or 0)
|
||||
def delete_sync_state(self, index: Optional[int] = None) -> bool:
|
||||
return bool(self.retry(self._client.delete, self.sync_path, prevIndex=index or 0))
|
||||
|
||||
def watch(self, leader_index, timeout):
|
||||
def watch(self, leader_index: Optional[int], timeout: float) -> bool:
|
||||
if self.__do_not_watch:
|
||||
self.__do_not_watch = False
|
||||
return True
|
||||
|
||||
@@ -10,12 +10,14 @@ import time
import urllib3

from collections import defaultdict
from threading import Condition, Lock, Thread
from enum import IntEnum
from urllib3.exceptions import ReadTimeoutError, ProtocolError
from threading import Condition, Lock, Thread
from typing import Any, Callable, Collection, Dict, Iterator, List, Optional, Tuple, Type, Union

from . import ClusterConfig, Cluster, Failover, Leader, Member, SyncState,\
TimelineHistory, ReturnFalseException, catch_return_false_exception, citus_group_re
from .etcd import AbstractEtcdClientWithFailover, AbstractEtcd, catch_etcd_errors
from .etcd import AbstractEtcdClientWithFailover, AbstractEtcd, catch_etcd_errors, DnsCachingResolver, Retry
from ..exceptions import DCSError, PatroniException
from ..utils import deep_compare, enable_keepalive, iter_response_objects, RetryFailedError, USER_AGENT

@@ -31,11 +33,27 @@ class UnsupportedEtcdVersion(PatroniException):


# google.golang.org/grpc/codes
GRPCCode = type('Enum', (), {'OK': 0, 'Canceled': 1, 'Unknown': 2, 'InvalidArgument': 3, 'DeadlineExceeded': 4,
'NotFound': 5, 'AlreadyExists': 6, 'PermissionDenied': 7, 'ResourceExhausted': 8,
'FailedPrecondition': 9, 'Aborted': 10, 'OutOfRange': 11, 'Unimplemented': 12,
'Internal': 13, 'Unavailable': 14, 'DataLoss': 15, 'Unauthenticated': 16})
GRPCcodeToText = {v: k for k, v in GRPCCode.__dict__.items() if not k.startswith('__') and isinstance(v, int)}
class GRPCCode(IntEnum):
OK = 0
Canceled = 1
Unknown = 2
InvalidArgument = 3
DeadlineExceeded = 4
NotFound = 5
AlreadyExists = 6
PermissionDenied = 7
ResourceExhausted = 8
FailedPrecondition = 9
Aborted = 10
OutOfRange = 11
Unimplemented = 12
Internal = 13
Unavailable = 14
DataLoss = 15
Unauthenticated = 16


GRPCcodeToText: Dict[int, str] = {v: k for k, v in GRPCCode.__dict__['_member_map_'].items()}

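For illustration only (not part of this commit): the reverse lookup above keys the dict by the IntEnum members themselves, which hash and compare like plain ints, so a numeric gRPC code resolves to its symbolic name. A minimal, abridged sketch:

    from enum import IntEnum

    class GRPCCode(IntEnum):        # abridged copy of the enum above
        OK = 0
        NotFound = 5
        Unavailable = 14

    GRPCcodeToText = {v: k for k, v in GRPCCode.__dict__['_member_map_'].items()}
    assert GRPCcodeToText[5] == 'NotFound'
    # An equivalent spelling would be {member.value: member.name for member in GRPCCode}.
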
class Etcd3Exception(etcd.EtcdException):
|
||||
@@ -44,22 +62,24 @@ class Etcd3Exception(etcd.EtcdException):
|
||||
|
||||
class Etcd3ClientError(Etcd3Exception):
|
||||
|
||||
def __init__(self, code=None, error=None, status=None):
|
||||
def __init__(self, code: Optional[int] = None, error: Optional[str] = None, status: Optional[int] = None) -> None:
|
||||
if not hasattr(self, 'error'):
|
||||
self.error = error and error.strip()
|
||||
self.codeText = GRPCcodeToText.get(code)
|
||||
self.codeText = GRPCcodeToText.get(code) if code is not None else None
|
||||
self.status = status
|
||||
|
||||
def __repr__(self):
|
||||
return "<{0} error: '{1}', code: {2}>".format(self.__class__.__name__, self.error, self.code)
|
||||
def __repr__(self) -> str:
|
||||
return "<{0} error: '{1}', code: {2}>"\
|
||||
.format(self.__class__.__name__, getattr(self, 'error', None), getattr(self, 'code', None))
|
||||
|
||||
__str__ = __repr__
|
||||
|
||||
def as_dict(self):
|
||||
return {'error': self.error, 'code': self.code, 'codeText': self.codeText, 'status': self.status}
|
||||
def as_dict(self) -> Dict[str, Any]:
|
||||
return {'error': getattr(self, 'error', None), 'code': getattr(self, 'code', None),
|
||||
'codeText': self.codeText, 'status': self.status}
|
||||
|
||||
@classmethod
|
||||
def get_subclasses(cls):
|
||||
def get_subclasses(cls) -> Iterator[Type['Etcd3ClientError']]:
|
||||
for subclass in cls.__subclasses__():
|
||||
for subsubclass in subclass.get_subclasses():
|
||||
yield subsubclass
|
||||
@@ -118,61 +138,93 @@ class InvalidAuthToken(Etcd3ClientError):
error = "etcdserver: invalid auth token"


errStringToClientError = {s.error: s for s in Etcd3ClientError.get_subclasses() if hasattr(s, 'error')}
errCodeToClientError = {s.code: s for s in Etcd3ClientError.__subclasses__()}
errStringToClientError = {getattr(s, 'error'): s for s in Etcd3ClientError.get_subclasses() if hasattr(s, 'error')}
errCodeToClientError = {getattr(s, 'code'): s for s in Etcd3ClientError.__subclasses__()}


def _raise_for_data(data, status_code=None):
def _raise_for_data(data: Union[bytes, str, Dict[str, Union[Any, Dict[str, Any]]]],
status_code: Optional[int] = None) -> Etcd3ClientError:
try:
error = data.get('error') or data.get('Error')
if isinstance(error, dict): # streaming response
status_code = error.get('http_code')
code = error['grpc_code']
error = error['message']
assert isinstance(data, dict)
data_error: Optional[Dict[str, Any]] = data.get('error') or data.get('Error')
if isinstance(data_error, dict): # streaming response
status_code = data_error.get('http_code')
code: Optional[int] = data_error['grpc_code']
error: str = data_error['message']
else:
code = data.get('code') or data.get('Code')
data_code = data.get('code') or data.get('Code')
assert not isinstance(data_code, dict)
code = data_code
error = str(data_error)
except Exception:
error = str(data)
code = GRPCCode.Unknown
err = errStringToClientError.get(error) or errCodeToClientError.get(code) or Unknown
raise err(code, error, status_code)
return err(code, error, status_code)

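For illustration only (not part of this commit): _raise_for_data() now returns the exception instead of raising it, so call sites read `raise _raise_for_data(data, response.status)`. The lookup order is message string first, then numeric gRPC code, then Unknown. A minimal sketch with stand-in tables (the real ones are built from the Etcd3ClientError subclasses above):

    class StubError(Exception):          # stand-in for an Etcd3ClientError subclass
        pass

    errStringToClientError = {'etcdserver: permission denied': StubError}   # made-up subset
    errCodeToClientError = {5: StubError}
    Unknown = StubError

    error, code = 'some unrecognized message', 5
    err = errStringToClientError.get(error) or errCodeToClientError.get(code) or Unknown
    assert err is StubError              # falls through to the code-based lookup
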
def to_bytes(v):
def to_bytes(v: Union[str, bytes]) -> bytes:
return v if isinstance(v, bytes) else v.encode('utf-8')


def prefix_range_end(v):
v = bytearray(to_bytes(v))
for i in range(len(v) - 1, -1, -1):
if v[i] < 0xff:
v[i] += 1
def prefix_range_end(v: str) -> bytes:
ret = bytearray(to_bytes(v))
for i in range(len(ret) - 1, -1, -1):
if ret[i] < 0xff:
ret[i] += 1
break
return bytes(v)
return bytes(ret)

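For illustration only (not part of this commit): prefix_range_end() increments the last byte of the prefix that is below 0xff, producing the exclusive upper bound of the key range that shares the prefix. A standalone sketch of the same logic:

    def prefix_range_end(v: str) -> bytes:      # standalone copy for illustration
        ret = bytearray(v.encode('utf-8'))
        for i in range(len(ret) - 1, -1, -1):
            if ret[i] < 0xff:
                ret[i] += 1
                break
        return bytes(ret)

    assert prefix_range_end('/service/') == b'/service0'   # '/' (0x2f) becomes '0' (0x30)
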
def base64_encode(v):
def base64_encode(v: Union[str, bytes]) -> str:
return base64.b64encode(to_bytes(v)).decode('utf-8')


def base64_decode(v):
def base64_decode(v: str) -> str:
return base64.b64decode(v).decode('utf-8')


def build_range_request(key, range_end=None):
def build_range_request(key: str, range_end: Union[bytes, str, None] = None) -> Dict[str, Any]:
fields = {'key': base64_encode(key)}
if range_end:
fields['range_end'] = base64_encode(range_end)
return fields

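For illustration only (not part of this commit): a prefix read is just a range request whose range_end is the value computed above, with both keys base64-encoded. The path '/service/' below is a made-up example:

    import base64

    def b64(v: bytes) -> str:
        return base64.b64encode(v).decode('utf-8')

    fields = {'key': b64(b'/service/'), 'range_end': b64(b'/service0')}
    # Etcd3Client.range() additionally sets 'serializable': True to tolerate stale reads.
    print(fields)   # {'key': 'L3NlcnZpY2Uv', 'range_end': 'L3NlcnZpY2Uw'}
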
def _handle_auth_errors(func: Callable[..., Any]) -> Any:
|
||||
def wrapper(self: 'Etcd3Client', *args: Any, **kwargs: Any) -> Any:
|
||||
def retry(ex: Exception) -> Any:
|
||||
if self.username and self.password:
|
||||
self.authenticate()
|
||||
return func(self, *args, **kwargs)
|
||||
else:
|
||||
logger.fatal('Username or password not set, authentication is not possible')
|
||||
raise ex
|
||||
|
||||
try:
|
||||
return func(self, *args, **kwargs)
|
||||
except (UserEmpty, PermissionDenied) as e: # no token provided
|
||||
# PermissionDenied is raised on 3.0 and 3.1
|
||||
if self._cluster_version < (3, 3) and (not isinstance(e, PermissionDenied)
|
||||
or self._cluster_version < (3, 2)):
|
||||
raise UnsupportedEtcdVersion('Authentication is required by Etcd cluster but not '
|
||||
'supported on version lower than 3.3.0. Cluster version: '
|
||||
'{0}'.format('.'.join(map(str, self._cluster_version))))
|
||||
return retry(e)
|
||||
except InvalidAuthToken as e:
|
||||
logger.error('Invalid auth token: %s', self._token)
|
||||
return retry(e)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class Etcd3Client(AbstractEtcdClientWithFailover):
|
||||
|
||||
ERROR_CLS = Etcd3Error
|
||||
|
||||
def __init__(self, config, dns_resolver, cache_ttl=300):
|
||||
def __init__(self, config: Dict[str, Any], dns_resolver: DnsCachingResolver, cache_ttl: int = 300) -> None:
|
||||
self._token = None
|
||||
self._cluster_version = None
|
||||
self._cluster_version: Tuple[int, ...] = tuple()
|
||||
self.version_prefix = '/v3beta'
|
||||
super(Etcd3Client, self).__init__(config, dns_resolver, cache_ttl)
|
||||
|
||||
@@ -182,32 +234,33 @@ class Etcd3Client(AbstractEtcdClientWithFailover):
|
||||
logger.fatal('Etcd3 authentication failed: %r', e)
|
||||
sys.exit(1)
|
||||
|
||||
def _get_headers(self):
|
||||
def _get_headers(self) -> Dict[str, str]:
|
||||
headers = urllib3.make_headers(user_agent=USER_AGENT)
|
||||
if self._token and self._cluster_version >= (3, 3, 0):
|
||||
headers['authorization'] = self._token
|
||||
return headers
|
||||
|
||||
def _prepare_request(self, kwargs, params=None, method=None):
|
||||
def _prepare_request(self, kwargs: Dict[str, Any], params: Optional[Dict[str, Any]] = None,
|
||||
method: Optional[str] = None) -> Callable[..., urllib3.response.HTTPResponse]:
|
||||
if params is not None:
|
||||
kwargs['body'] = json.dumps(params)
|
||||
kwargs['headers']['Content-Type'] = 'application/json'
|
||||
return self.http.urlopen
|
||||
|
||||
@staticmethod
|
||||
def _handle_server_response(response):
|
||||
data = response.data
|
||||
def _handle_server_response(self, response: urllib3.response.HTTPResponse) -> Dict[str, Any]:
|
||||
data: Union[bytes, str] = response.data
|
||||
try:
|
||||
data = data.decode('utf-8')
|
||||
data = json.loads(data)
|
||||
ret: Dict[str, Any] = json.loads(data)
|
||||
if response.status < 400:
|
||||
return ret
|
||||
except (TypeError, ValueError, UnicodeError) as e:
|
||||
if response.status < 400:
|
||||
raise etcd.EtcdException('Server response was not valid JSON: %r' % e)
|
||||
if response.status < 400:
|
||||
return data
|
||||
_raise_for_data(data, response.status)
|
||||
ret = {}
|
||||
raise _raise_for_data(ret or data, response.status)
|
||||
|
||||
def _ensure_version_prefix(self, base_uri, **kwargs):
|
||||
def _ensure_version_prefix(self, base_uri: str, **kwargs: Any) -> None:
|
||||
if self.version_prefix != '/v3':
|
||||
response = self.http.urlopen(self._MGET, base_uri + '/version', **kwargs)
|
||||
response = self._handle_server_response(response)
|
||||
@@ -234,87 +287,64 @@ class Etcd3Client(AbstractEtcdClientWithFailover):
|
||||
else:
|
||||
self.version_prefix = '/v3'
|
||||
|
||||
def _prepare_get_members(self, etcd_nodes):
|
||||
def _prepare_get_members(self, etcd_nodes: int) -> Dict[str, Any]:
|
||||
kwargs = self._prepare_common_parameters(etcd_nodes)
|
||||
self._prepare_request(kwargs, {})
|
||||
return kwargs
|
||||
|
||||
def _get_members(self, base_uri, **kwargs):
|
||||
def _get_members(self, base_uri: str, **kwargs: Any) -> List[str]:
|
||||
self._ensure_version_prefix(base_uri, **kwargs)
|
||||
resp = self.http.urlopen(self._MPOST, base_uri + self.version_prefix + '/cluster/member/list', **kwargs)
|
||||
members = self._handle_server_response(resp)['members']
|
||||
return set(url for member in members for url in member.get('clientURLs', []))
|
||||
return [url for member in members for url in member.get('clientURLs', [])]
|
||||
|
||||
def call_rpc(self, method, fields, retry=None):
|
||||
def call_rpc(self, method: str, fields: Dict[str, Any], retry: Optional[Retry] = None) -> Dict[str, Any]:
|
||||
fields['retry'] = retry
|
||||
return self.api_execute(self.version_prefix + method, self._MPOST, fields)
|
||||
|
||||
def authenticate(self):
|
||||
if self._use_proxies and self._cluster_version is None:
|
||||
def authenticate(self) -> bool:
|
||||
if self._use_proxies and not self._cluster_version:
|
||||
kwargs = self._prepare_common_parameters(1)
|
||||
self._ensure_version_prefix(self._base_uri, **kwargs)
|
||||
if self._cluster_version >= (3, 3) and self.username and self.password:
|
||||
logger.info('Trying to authenticate on Etcd...')
|
||||
old_token, self._token = self._token, None
|
||||
try:
|
||||
response = self.call_rpc('/auth/authenticate', {'name': self.username, 'password': self.password})
|
||||
except AuthNotEnabled:
|
||||
logger.info('Etcd authentication is not enabled')
|
||||
self._token = None
|
||||
except Exception:
|
||||
self._token = old_token
|
||||
raise
|
||||
else:
|
||||
self._token = response.get('token')
|
||||
return old_token != self._token
|
||||
|
||||
def _handle_auth_errors(func):
|
||||
def wrapper(self, *args, **kwargs):
|
||||
def retry(ex):
|
||||
if self.username and self.password:
|
||||
self.authenticate()
|
||||
return func(self, *args, **kwargs)
|
||||
else:
|
||||
logger.fatal('Username or password not set, authentication is not possible')
|
||||
raise ex
|
||||
|
||||
try:
|
||||
return func(self, *args, **kwargs)
|
||||
except (UserEmpty, PermissionDenied) as e: # no token provided
|
||||
# PermissionDenied is raised on 3.0 and 3.1
|
||||
if self._cluster_version < (3, 3) and (not isinstance(e, PermissionDenied)
|
||||
or self._cluster_version < (3, 2)):
|
||||
raise UnsupportedEtcdVersion('Authentication is required by Etcd cluster but not '
|
||||
'supported on version lower than 3.3.0. Cluster version: '
|
||||
'{0}'.format('.'.join(map(str, self._cluster_version))))
|
||||
return retry(e)
|
||||
except InvalidAuthToken as e:
|
||||
logger.error('Invalid auth token: %s', self._token)
|
||||
return retry(e)
|
||||
|
||||
return wrapper
|
||||
if not (self._cluster_version >= (3, 3) and self.username and self.password):
|
||||
return False
|
||||
logger.info('Trying to authenticate on Etcd...')
|
||||
old_token, self._token = self._token, None
|
||||
try:
|
||||
response = self.call_rpc('/auth/authenticate', {'name': self.username, 'password': self.password})
|
||||
except AuthNotEnabled:
|
||||
logger.info('Etcd authentication is not enabled')
|
||||
self._token = None
|
||||
except Exception:
|
||||
self._token = old_token
|
||||
raise
|
||||
else:
|
||||
self._token = response.get('token')
|
||||
return old_token != self._token
|
||||
|
||||
@_handle_auth_errors
|
||||
def range(self, key, range_end=None, retry=None):
|
||||
def range(self, key: str, range_end: Union[bytes, str, None] = None,
|
||||
retry: Optional[Retry] = None) -> Dict[str, Any]:
|
||||
params = build_range_request(key, range_end)
|
||||
params['serializable'] = True # For better performance. We can tolerate stale reads.
|
||||
return self.call_rpc('/kv/range', params, retry)
|
||||
|
||||
def prefix(self, key, retry=None):
|
||||
def prefix(self, key: str, retry: Optional[Retry] = None) -> Dict[str, Any]:
|
||||
return self.range(key, prefix_range_end(key), retry)
|
||||
|
||||
@_handle_auth_errors
|
||||
def lease_grant(self, ttl, retry=None):
|
||||
def lease_grant(self, ttl: int, retry: Optional[Retry] = None) -> str:
|
||||
return self.call_rpc('/lease/grant', {'TTL': ttl}, retry)['ID']
|
||||
|
||||
def lease_keepalive(self, ID, retry=None):
|
||||
def lease_keepalive(self, ID: str, retry: Optional[Retry] = None) -> Optional[str]:
|
||||
return self.call_rpc('/lease/keepalive', {'ID': ID}, retry).get('result', {}).get('TTL')
|
||||
|
||||
def txn(self, compare, success, retry=None):
|
||||
return self.call_rpc('/kv/txn', {'compare': [compare], 'success': [success]}, retry).get('succeeded')
|
||||
def txn(self, compare: Dict[str, Any], success: Dict[str, Any], retry: Optional[Retry] = None) -> Dict[str, Any]:
|
||||
return self.call_rpc('/kv/txn', {'compare': [compare], 'success': [success]}, retry).get('succeeded', {})
|
||||
|
||||
@_handle_auth_errors
|
||||
def put(self, key, value, lease=None, create_revision=None, mod_revision=None, retry=None):
|
||||
def put(self, key: str, value: str, lease: Optional[str] = None, create_revision: Optional[str] = None,
|
||||
mod_revision: Optional[str] = None, retry: Optional[Retry] = None) -> Dict[str, Any]:
|
||||
fields = {'key': base64_encode(key), 'value': base64_encode(value)}
|
||||
if lease:
|
||||
fields['lease'] = lease
|
||||
@@ -328,17 +358,20 @@ class Etcd3Client(AbstractEtcdClientWithFailover):
|
||||
return self.txn(compare, {'request_put': fields}, retry)
|
||||
|
||||
@_handle_auth_errors
|
||||
def deleterange(self, key, range_end=None, mod_revision=None, retry=None):
|
||||
def deleterange(self, key: str, range_end: Union[bytes, str, None] = None,
|
||||
mod_revision: Optional[str] = None, retry: Optional[Retry] = None) -> Dict[str, Any]:
|
||||
fields = build_range_request(key, range_end)
|
||||
if mod_revision is None:
|
||||
return self.call_rpc('/kv/deleterange', fields, retry)
|
||||
compare = {'target': 'MOD', 'mod_revision': mod_revision, 'key': fields['key']}
|
||||
return self.txn(compare, {'request_delete_range': fields}, retry)
|
||||
|
||||
def deleteprefix(self, key, retry=None):
|
||||
def deleteprefix(self, key: str, retry: Optional[Retry] = None) -> Dict[str, Any]:
|
||||
return self.deleterange(key, prefix_range_end(key), retry=retry)
|
||||
|
||||
def watchrange(self, key, range_end=None, start_revision=None, filters=None, read_timeout=None):
|
||||
def watchrange(self, key: str, range_end: Union[bytes, str, None] = None,
|
||||
start_revision: Optional[str] = None, filters: Optional[List[Dict[str, Any]]] = None,
|
||||
read_timeout: Optional[float] = None) -> urllib3.response.HTTPResponse:
|
||||
"""returns: response object"""
|
||||
params = build_range_request(key, range_end)
|
||||
if start_revision is not None:
|
||||
@@ -349,14 +382,16 @@ class Etcd3Client(AbstractEtcdClientWithFailover):
|
||||
kwargs.update(timeout=urllib3.Timeout(connect=kwargs['timeout'], read=read_timeout), retries=0)
|
||||
return request_executor(self._MPOST, self._base_uri + self.version_prefix + '/watch', **kwargs)
|
||||
|
||||
def watchprefix(self, key, start_revision=None, filters=None, read_timeout=None):
|
||||
def watchprefix(self, key: str, start_revision: Optional[str] = None,
|
||||
filters: Optional[List[Dict[str, Any]]] = None,
|
||||
read_timeout: Optional[float] = None) -> urllib3.response.HTTPResponse:
|
||||
return self.watchrange(key, prefix_range_end(key), start_revision, filters, read_timeout)
|
||||
|
||||
|
||||
class KVCache(Thread):
|
||||
|
||||
def __init__(self, dcs, client):
|
||||
Thread.__init__(self)
|
||||
def __init__(self, dcs: 'Etcd3', client: 'PatroniEtcd3Client') -> None:
|
||||
super(KVCache, self).__init__()
|
||||
self.daemon = True
|
||||
self._dcs = dcs
|
||||
self._client = client
|
||||
@@ -373,32 +408,32 @@ class KVCache(Thread):
|
||||
self._object_cache_lock = Lock()
|
||||
self.start()
|
||||
|
||||
def set(self, value, overwrite=False):
def set(self, value: Dict[str, Any], overwrite: bool = False) -> Tuple[bool, Optional[Dict[str, Any]]]:
with self._object_cache_lock:
name = value['key']
old_value = self._object_cache.get(name)
ret = not old_value or int(old_value['mod_revision']) < int(value['mod_revision'])
if ret or overwrite and old_value['mod_revision'] == value['mod_revision']:
if ret or overwrite and old_value and old_value['mod_revision'] == value['mod_revision']:
self._object_cache[name] = value
return ret, old_value

def delete(self, name, mod_revision):
def delete(self, name: str, mod_revision: str) -> Tuple[bool, Optional[Dict[str, Any]]]:
with self._object_cache_lock:
old_value = self._object_cache.get(name)
ret = old_value and int(old_value['mod_revision']) < int(mod_revision)
if ret:
del self._object_cache[name]
return not old_value or ret, old_value
return bool(not old_value or ret), old_value

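For illustration only (not part of this commit): KVCache accepts a new value only when its mod_revision is strictly newer than the cached one (or equal, when overwrite is requested), which the typed set() above still does. A minimal sketch of that comparison:

    cache = {'/svc/leader': {'key': '/svc/leader', 'mod_revision': '7', 'value': 'a'}}   # made-up key
    update = {'key': '/svc/leader', 'mod_revision': '9', 'value': 'b'}
    old = cache.get(update['key'])
    if not old or int(old['mod_revision']) < int(update['mod_revision']):
        cache[update['key']] = update        # accepted: revision 9 > 7
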
def copy(self):
|
||||
def copy(self) -> List[Dict[str, Any]]:
|
||||
with self._object_cache_lock:
|
||||
return [v.copy() for v in self._object_cache.values()]
|
||||
|
||||
def get(self, name):
|
||||
def get(self, name: str) -> Optional[Dict[str, Any]]:
|
||||
with self._object_cache_lock:
|
||||
return self._object_cache.get(name)
|
||||
|
||||
def _process_event(self, event):
|
||||
def _process_event(self, event: Dict[str, Any]) -> None:
|
||||
kv = event['kv']
|
||||
key = kv['key']
|
||||
if event.get('type') == 'DELETE':
|
||||
@@ -422,21 +457,22 @@ class KVCache(Thread):
|
||||
or (self.get(self._leader_key) or {}).get('value') != self._name):
|
||||
self._dcs.event.set()
|
||||
|
||||
def _process_message(self, message):
|
||||
def _process_message(self, message: Dict[str, Any]) -> None:
|
||||
logger.debug('Received message: %s', message)
|
||||
if 'error' in message:
|
||||
_raise_for_data(message)
|
||||
for event in message.get('result', {}).get('events', []):
|
||||
raise _raise_for_data(message)
|
||||
events: List[Dict[str, Any]] = message.get('result', {}).get('events', [])
|
||||
for event in events:
|
||||
self._process_event(event)
|
||||
|
||||
@staticmethod
|
||||
def _finish_response(response):
|
||||
def _finish_response(response: urllib3.response.HTTPResponse) -> None:
|
||||
try:
|
||||
response.close()
|
||||
finally:
|
||||
response.release_conn()
|
||||
|
||||
def _do_watch(self, revision):
|
||||
def _do_watch(self, revision: str) -> None:
|
||||
with self._response_lock:
|
||||
self._response = None
|
||||
# We do most of requests with timeouts. The only exception /watch requests to Etcd v3.
|
||||
@@ -457,7 +493,7 @@ class KVCache(Thread):
|
||||
for message in iter_response_objects(response):
|
||||
self._process_message(message)
|
||||
|
||||
def _build_cache(self):
|
||||
def _build_cache(self) -> None:
|
||||
result = self._dcs.retry(self._client.prefix, self._dcs.cluster_prefix)
|
||||
with self._object_cache_lock:
|
||||
self._object_cache = {node['key']: node for node in result.get('kvs', [])}
|
||||
@@ -476,10 +512,10 @@ class KVCache(Thread):
|
||||
self._is_ready = False
|
||||
with self._response_lock:
|
||||
response, self._response = self._response, None
|
||||
if response:
|
||||
if isinstance(response, urllib3.response.HTTPResponse):
|
||||
self._finish_response(response)
|
||||
|
||||
def run(self):
|
||||
def run(self) -> None:
|
||||
while True:
|
||||
try:
|
||||
self._build_cache()
|
||||
@@ -487,12 +523,12 @@ class KVCache(Thread):
|
||||
logger.error('KVCache.run %r', e)
|
||||
time.sleep(1)
|
||||
|
||||
def kill_stream(self):
|
||||
def kill_stream(self) -> None:
|
||||
sock = None
|
||||
with self._response_lock:
|
||||
if self._response:
|
||||
if isinstance(self._response, urllib3.response.HTTPResponse):
|
||||
try:
|
||||
sock = self._response.connection.sock
|
||||
sock = self._response.connection.sock if self._response.connection else None
|
||||
except Exception:
|
||||
sock = None
|
||||
else:
|
||||
@@ -504,48 +540,48 @@ class KVCache(Thread):
|
||||
except Exception as e:
|
||||
logger.debug('Error on socket.shutdown: %r', e)
|
||||
|
||||
def is_ready(self):
|
||||
def is_ready(self) -> bool:
|
||||
"""Must be called only when holding the lock on `condition`"""
|
||||
return self._is_ready
|
||||
|
||||
|
||||
class PatroniEtcd3Client(Etcd3Client):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
self._kv_cache = None
|
||||
super(PatroniEtcd3Client, self).__init__(*args, **kwargs)
|
||||
|
||||
def configure(self, etcd3):
|
||||
def configure(self, etcd3: 'Etcd3') -> None:
|
||||
self._etcd3 = etcd3
|
||||
|
||||
def start_watcher(self):
|
||||
def start_watcher(self) -> None:
|
||||
if self._cluster_version >= (3, 1):
|
||||
self._kv_cache = KVCache(self._etcd3, self)
|
||||
|
||||
def _restart_watcher(self):
|
||||
def _restart_watcher(self) -> None:
|
||||
if self._kv_cache:
|
||||
self._kv_cache.kill_stream()
|
||||
|
||||
def set_base_uri(self, value):
|
||||
def set_base_uri(self, value: str) -> None:
|
||||
super(PatroniEtcd3Client, self).set_base_uri(value)
|
||||
self._restart_watcher()
|
||||
|
||||
def authenticate(self):
|
||||
def authenticate(self) -> bool:
|
||||
ret = super(PatroniEtcd3Client, self).authenticate()
|
||||
if ret:
|
||||
self._restart_watcher()
|
||||
return ret
|
||||
|
||||
def _wait_cache(self, timeout):
|
||||
def _wait_cache(self, timeout: float) -> None:
|
||||
stop_time = time.time() + timeout
|
||||
while not self._kv_cache.is_ready():
|
||||
while self._kv_cache and not self._kv_cache.is_ready():
|
||||
timeout = stop_time - time.time()
|
||||
if timeout <= 0:
|
||||
raise RetryFailedError('Exceeded retry deadline')
|
||||
self._kv_cache.condition.wait(timeout)
|
||||
|
||||
def get_cluster(self, path):
|
||||
if self._kv_cache and path.startswith(self._etcd3.cluster_prefix):
|
||||
def get_cluster(self, path: str) -> List[Dict[str, Any]]:
|
||||
if self._kv_cache and self._etcd3._retry.deadline is not None and path.startswith(self._etcd3.cluster_prefix):
|
||||
with self._kv_cache.condition:
|
||||
self._wait_cache(self._etcd3._retry.deadline)
|
||||
ret = self._kv_cache.copy()
|
||||
@@ -557,7 +593,7 @@ class PatroniEtcd3Client(Etcd3Client):
|
||||
'lease': node.get('lease')})
|
||||
return ret
|
||||
|
||||
def call_rpc(self, method, fields, retry=None):
|
||||
def call_rpc(self, method: str, fields: Dict[str, Any], retry: Optional[Retry] = None) -> Dict[str, Any]:
|
||||
ret = super(PatroniEtcd3Client, self).call_rpc(method, fields, retry)
|
||||
|
||||
if self._kv_cache:
|
||||
@@ -582,8 +618,9 @@ class PatroniEtcd3Client(Etcd3Client):
|
||||
|
||||
class Etcd3(AbstractEtcd):
|
||||
|
||||
def __init__(self, config):
|
||||
def __init__(self, config: Dict[str, Any]) -> None:
|
||||
super(Etcd3, self).__init__(config, PatroniEtcd3Client, (DeadlineExceeded, Unavailable, FailedPrecondition))
|
||||
assert isinstance(self._client, PatroniEtcd3Client)
|
||||
self.__do_not_watch = False
|
||||
self._lease = None
|
||||
self._last_lease_refresh = 0
|
||||
@@ -593,15 +630,23 @@ class Etcd3(AbstractEtcd):
|
||||
self._client.start_watcher()
|
||||
self.create_lease()
|
||||
|
||||
def set_socket_options(self, sock, socket_options):
|
||||
@property
|
||||
def _client(self) -> PatroniEtcd3Client:
|
||||
assert isinstance(self._abstract_client, PatroniEtcd3Client)
|
||||
return self._abstract_client
|
||||
|
||||
def set_socket_options(self, sock: socket.socket,
|
||||
socket_options: Optional[Collection[Tuple[int, int, int]]]) -> None:
|
||||
assert self._retry.deadline is not None
|
||||
enable_keepalive(sock, self.ttl, int(self.loop_wait + self._retry.deadline))
|
||||
|
||||
def set_ttl(self, ttl):
|
||||
def set_ttl(self, ttl: int) -> Optional[bool]:
|
||||
self.__do_not_watch = super(Etcd3, self).set_ttl(ttl)
|
||||
if self.__do_not_watch:
|
||||
self._lease = None
|
||||
return None
|
||||
|
||||
def _do_refresh_lease(self, force=False, retry=None):
|
||||
def _do_refresh_lease(self, force: bool = False, retry: Optional[Retry] = None) -> bool:
|
||||
if not force and self._lease and self._last_lease_refresh + self._loop_wait > time.time():
|
||||
return False
|
||||
|
||||
@@ -615,14 +660,14 @@ class Etcd3(AbstractEtcd):
|
||||
self._last_lease_refresh = time.time()
|
||||
return ret
|
||||
|
||||
def refresh_lease(self):
|
||||
def refresh_lease(self) -> bool:
|
||||
try:
|
||||
return self.retry(self._do_refresh_lease)
|
||||
except (Etcd3ClientError, RetryFailedError):
|
||||
logger.exception('refresh_lease')
|
||||
raise Etcd3Error('Failed to keepalive/grant lease')
|
||||
|
||||
def create_lease(self):
|
||||
def create_lease(self) -> None:
|
||||
while not self._lease:
|
||||
try:
|
||||
self.refresh_lease()
|
||||
@@ -631,14 +676,14 @@ class Etcd3(AbstractEtcd):
|
||||
time.sleep(5)
|
||||
|
||||
@property
|
||||
def cluster_prefix(self):
|
||||
def cluster_prefix(self) -> str:
|
||||
return self._base_path + '/' if self.is_citus_coordinator() else self.client_path('')
|
||||
|
||||
@staticmethod
|
||||
def member(node):
|
||||
def member(node: Dict[str, str]) -> Member:
|
||||
return Member.from_node(node['mod_revision'], os.path.basename(node['key']), node['lease'], node['value'])
|
||||
|
||||
def _cluster_from_nodes(self, nodes):
|
||||
def _cluster_from_nodes(self, nodes: Dict[str, Any]) -> Cluster:
|
||||
# get initialize flag
|
||||
initialize = nodes.get(self._INITIALIZE)
|
||||
initialize = initialize and initialize['value']
|
||||
@@ -666,7 +711,7 @@ class Etcd3(AbstractEtcd):
|
||||
slots = None
|
||||
|
||||
try:
|
||||
last_lsn = int(last_lsn)
|
||||
last_lsn = int(last_lsn or '')
|
||||
except Exception:
|
||||
last_lsn = 0
|
||||
|
||||
@@ -701,14 +746,14 @@ class Etcd3(AbstractEtcd):
|
||||
|
||||
return Cluster(initialize, config, leader, last_lsn, members, failover, sync, history, slots, failsafe)
|
||||
|
||||
def _cluster_loader(self, path):
|
||||
def _cluster_loader(self, path: str) -> Cluster:
|
||||
nodes = {node['key'][len(path):]: node
|
||||
for node in self._client.get_cluster(path)
|
||||
if node['key'].startswith(path)}
|
||||
return self._cluster_from_nodes(nodes)
|
||||
|
||||
def _citus_cluster_loader(self, path):
|
||||
clusters = defaultdict(dict)
|
||||
def _citus_cluster_loader(self, path: str) -> Dict[int, Cluster]:
|
||||
clusters: Dict[int, Dict[str, Dict[str, Any]]] = defaultdict(dict)
|
||||
path = self._base_path + '/'
|
||||
for node in self._client.get_cluster(path):
|
||||
key = node['key'][len(path):].split('/', 1)
|
||||
@@ -716,7 +761,9 @@ class Etcd3(AbstractEtcd):
|
||||
clusters[int(key[0])][key[1]] = node
|
||||
return {group: self._cluster_from_nodes(nodes) for group, nodes in clusters.items()}
|
||||
|
||||
def _load_cluster(self, path, loader):
|
||||
def _load_cluster(
|
||||
self, path: str, loader: Callable[[str], Union[Cluster, Dict[int, Cluster]]]
|
||||
) -> Union[Cluster, Dict[int, Cluster]]:
|
||||
cluster = None
|
||||
try:
|
||||
cluster = loader(path)
|
||||
@@ -725,10 +772,11 @@ class Etcd3(AbstractEtcd):
|
||||
except Exception as e:
|
||||
self._handle_exception(e, 'get_cluster', raise_ex=Etcd3Error('Etcd is not responding properly'))
|
||||
self._has_failed = False
|
||||
assert cluster is not None
|
||||
return cluster
|
||||
|
||||
@catch_etcd_errors
|
||||
def touch_member(self, data):
|
||||
def touch_member(self, data: Dict[str, Any]) -> bool:
|
||||
try:
|
||||
self.refresh_lease()
|
||||
except Etcd3Error:
|
||||
@@ -740,19 +788,20 @@ class Etcd3(AbstractEtcd):
|
||||
if member and member.session == self._lease and deep_compare(data, member.data):
|
||||
return True
|
||||
|
||||
data = json.dumps(data, separators=(',', ':'))
|
||||
value = json.dumps(data, separators=(',', ':'))
|
||||
try:
|
||||
return self._client.put(self.member_path, data, self._lease)
|
||||
return bool(self._client.put(self.member_path, value, self._lease))
|
||||
except LeaseNotFound:
|
||||
self._lease = None
|
||||
logger.error('Our lease disappeared from Etcd, can not "touch_member"')
|
||||
return False
|
||||
|
||||
@catch_etcd_errors
|
||||
def take_leader(self):
|
||||
def take_leader(self) -> bool:
|
||||
return self.retry(self._client.put, self.leader_path, self._name, self._lease)
|
||||
|
||||
def _do_attempt_to_acquire_leader(self, retry):
|
||||
def _retry(*args, **kwargs):
|
||||
def _do_attempt_to_acquire_leader(self, retry: Retry) -> bool:
|
||||
def _retry(*args: Any, **kwargs: Any) -> Any:
|
||||
kwargs['retry'] = retry
|
||||
return retry(*args, **kwargs)
|
||||
|
||||
@@ -772,10 +821,10 @@ class Etcd3(AbstractEtcd):
|
||||
return _retry(self._client.put, self.leader_path, self._name, self._lease, 0)
|
||||
|
||||
@catch_return_false_exception
|
||||
def attempt_to_acquire_leader(self):
|
||||
def attempt_to_acquire_leader(self) -> bool:
|
||||
retry = self._retry.copy()
|
||||
|
||||
def _retry(*args, **kwargs):
|
||||
def _retry(*args: Any, **kwargs: Any) -> Any:
|
||||
kwargs['retry'] = retry
|
||||
return retry(*args, **kwargs)
|
||||
|
||||
@@ -791,30 +840,30 @@ class Etcd3(AbstractEtcd):
|
||||
return ret
|
||||
|
||||
@catch_etcd_errors
|
||||
def set_failover_value(self, value, index=None):
|
||||
return self._client.put(self.failover_path, value, mod_revision=index)
|
||||
def set_failover_value(self, value: str, index: Optional[str] = None) -> bool:
|
||||
return bool(self._client.put(self.failover_path, value, mod_revision=index))
|
||||
|
||||
@catch_etcd_errors
|
||||
def set_config_value(self, value, index=None):
|
||||
return self._client.put(self.config_path, value, mod_revision=index)
|
||||
def set_config_value(self, value: str, index: Optional[str] = None) -> bool:
|
||||
return bool(self._client.put(self.config_path, value, mod_revision=index))
|
||||
|
||||
@catch_etcd_errors
|
||||
def _write_leader_optime(self, last_lsn):
|
||||
return self._client.put(self.leader_optime_path, last_lsn)
|
||||
def _write_leader_optime(self, last_lsn: str) -> bool:
|
||||
return bool(self._client.put(self.leader_optime_path, last_lsn))
|
||||
|
||||
@catch_etcd_errors
|
||||
def _write_status(self, value):
|
||||
return self._client.put(self.status_path, value)
|
||||
def _write_status(self, value: str) -> bool:
|
||||
return bool(self._client.put(self.status_path, value))
|
||||
|
||||
@catch_etcd_errors
|
||||
def _write_failsafe(self, value):
|
||||
return self._client.put(self.failsafe_path, value)
|
||||
def _write_failsafe(self, value: str) -> bool:
|
||||
return bool(self._client.put(self.failsafe_path, value))
|
||||
|
||||
@catch_return_false_exception
|
||||
def _update_leader(self):
|
||||
def _update_leader(self) -> bool:
|
||||
retry = self._retry.copy()
|
||||
|
||||
def _retry(*args, **kwargs):
|
||||
def _retry(*args: Any, **kwargs: Any) -> Any:
|
||||
kwargs['retry'] = retry
|
||||
return retry(*args, **kwargs)
|
||||
|
||||
@@ -836,36 +885,37 @@ class Etcd3(AbstractEtcd):
|
||||
return bool(self._lease)
|
||||
|
||||
@catch_etcd_errors
|
||||
def initialize(self, create_new=True, sysid=""):
|
||||
def initialize(self, create_new: bool = True, sysid: str = ""):
|
||||
return self.retry(self._client.put, self.initialize_path, sysid, None, 0 if create_new else None)
|
||||
|
||||
@catch_etcd_errors
|
||||
def _delete_leader(self):
|
||||
def _delete_leader(self) -> bool:
|
||||
cluster = self.cluster
|
||||
if cluster and isinstance(cluster.leader, Leader) and cluster.leader.name == self._name:
|
||||
return self._client.deleterange(self.leader_path, mod_revision=cluster.leader.index)
|
||||
return True
|
||||
|
||||
@catch_etcd_errors
|
||||
def cancel_initialization(self):
|
||||
def cancel_initialization(self) -> bool:
|
||||
return self.retry(self._client.deleterange, self.initialize_path)
|
||||
|
||||
@catch_etcd_errors
|
||||
def delete_cluster(self):
|
||||
def delete_cluster(self) -> bool:
|
||||
return self.retry(self._client.deleteprefix, self.client_path(''))
|
||||
|
||||
@catch_etcd_errors
|
||||
def set_history_value(self, value):
|
||||
return self._client.put(self.history_path, value)
|
||||
def set_history_value(self, value: str) -> bool:
|
||||
return bool(self._client.put(self.history_path, value))
|
||||
|
||||
@catch_etcd_errors
|
||||
def set_sync_state_value(self, value, index=None):
|
||||
def set_sync_state_value(self, value: str, index: Optional[str] = None) -> bool:
|
||||
return self.retry(self._client.put, self.sync_path, value, mod_revision=index)
|
||||
|
||||
@catch_etcd_errors
|
||||
def delete_sync_state(self, index=None):
|
||||
def delete_sync_state(self, index: Optional[str] = None) -> bool:
|
||||
return self.retry(self._client.deleterange, self.sync_path, mod_revision=index)
|
||||
|
||||
def watch(self, leader_index, timeout):
|
||||
def watch(self, leader_index: Optional[str], timeout: float) -> bool:
|
||||
if self.__do_not_watch:
|
||||
self.__do_not_watch = False
|
||||
return True
|
||||
|
||||
@@ -3,9 +3,12 @@ import logging
|
||||
import random
|
||||
import time
|
||||
|
||||
from patroni.dcs.zookeeper import ZooKeeper
|
||||
from patroni.request import get as requests_get
|
||||
from patroni.utils import uri
|
||||
from typing import Any, Callable, Dict, List, Union
|
||||
|
||||
from . import Cluster
|
||||
from .zookeeper import ZooKeeper
|
||||
from ..request import get as requests_get
|
||||
from ..utils import uri
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -14,11 +17,12 @@ class ExhibitorEnsembleProvider(object):
|
||||
|
||||
TIMEOUT = 3.1
|
||||
|
||||
def __init__(self, hosts, port, uri_path='/exhibitor/v1/cluster/list', poll_interval=300):
|
||||
def __init__(self, hosts: List[str], port: int,
|
||||
uri_path: str = '/exhibitor/v1/cluster/list', poll_interval: int = 300) -> None:
|
||||
self._exhibitor_port = port
|
||||
self._uri_path = uri_path
|
||||
self._poll_interval = poll_interval
|
||||
self._exhibitors = hosts
|
||||
self._exhibitors: List[str] = hosts
|
||||
self._boot_exhibitors = hosts
|
||||
self._zookeeper_hosts = ''
|
||||
self._next_poll = None
|
||||
@@ -26,7 +30,7 @@ class ExhibitorEnsembleProvider(object):
|
||||
logger.info('waiting on exhibitor')
|
||||
time.sleep(5)
|
||||
|
||||
def poll(self):
def poll(self) -> bool:
if self._next_poll and self._next_poll > time.time():
return False

@@ -36,7 +40,8 @@ class ExhibitorEnsembleProvider(object):

if isinstance(json, dict) and 'servers' in json and 'port' in json:
self._next_poll = time.time() + self._poll_interval
zookeeper_hosts = ','.join([h + ':' + str(json['port']) for h in sorted(json['servers'])])
servers: List[str] = json['servers']
zookeeper_hosts = ','.join([h + ':' + str(json['port']) for h in sorted(servers)])
if self._zookeeper_hosts != zookeeper_hosts:
logger.info('ZooKeeper connection string has changed: %s => %s', self._zookeeper_hosts, zookeeper_hosts)
self._zookeeper_hosts = zookeeper_hosts
@@ -44,7 +49,7 @@ class ExhibitorEnsembleProvider(object):
return True
return False

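For illustration only (not part of this commit): poll() turns the Exhibitor cluster listing into a sorted, comma-separated ZooKeeper connection string. With a made-up response:

    response = {'servers': ['zk2.example.com', 'zk1.example.com'], 'port': 2181}
    servers = response['servers']
    zookeeper_hosts = ','.join([h + ':' + str(response['port']) for h in sorted(servers)])
    assert zookeeper_hosts == 'zk1.example.com:2181,zk2.example.com:2181'
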
def _query_exhibitors(self, exhibitors):
|
||||
def _query_exhibitors(self, exhibitors: List[str]) -> Union[Dict[str, Any], Any]:
|
||||
random.shuffle(exhibitors)
|
||||
for host in exhibitors:
|
||||
try:
|
||||
@@ -55,18 +60,20 @@ class ExhibitorEnsembleProvider(object):
|
||||
return None
|
||||
|
||||
@property
|
||||
def zookeeper_hosts(self):
|
||||
def zookeeper_hosts(self) -> str:
|
||||
return self._zookeeper_hosts
|
||||
|
||||
|
||||
class Exhibitor(ZooKeeper):
|
||||
|
||||
def __init__(self, config):
|
||||
def __init__(self, config: Dict[str, Any]) -> None:
|
||||
interval = config.get('poll_interval', 300)
|
||||
self._ensemble_provider = ExhibitorEnsembleProvider(config['hosts'], config['port'], poll_interval=interval)
|
||||
super(Exhibitor, self).__init__({**config, 'hosts': self._ensemble_provider.zookeeper_hosts})
|
||||
|
||||
def _load_cluster(self, path, loader):
|
||||
def _load_cluster(
|
||||
self, path: str, loader: Callable[[str], Union[Cluster, Dict[int, Cluster]]]
|
||||
) -> Union[Cluster, Dict[int, Cluster]]:
|
||||
if self._ensemble_provider.poll():
|
||||
self._client.set_hosts(self._ensemble_provider.zookeeper_hosts)
|
||||
return super(Exhibitor, self)._load_cluster(path, loader)
|
||||
|
||||
File diff suppressed because it is too large
@@ -10,10 +10,13 @@ from pysyncobj.dns_resolver import globalDnsResolver
|
||||
from pysyncobj.node import TCPNode
|
||||
from pysyncobj.transport import TCPTransport, CONNECTION_STATE
|
||||
from pysyncobj.utility import TcpUtility
|
||||
from typing import Any, Callable, Collection, Dict, List, Optional, Union, TYPE_CHECKING
|
||||
|
||||
from . import AbstractDCS, ClusterConfig, Cluster, Failover, Leader, Member, SyncState, TimelineHistory, citus_group_re
|
||||
from ..exceptions import DCSError
|
||||
from ..utils import validate_directory
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from ..config import Config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -24,11 +27,12 @@ class RaftError(DCSError):
|
||||
|
||||
class _TCPTransport(TCPTransport):
|
||||
|
||||
def __init__(self, syncObj, selfNode, otherNodes):
|
||||
def __init__(self, syncObj: 'DynMemberSyncObj', selfNode: Optional[TCPNode],
|
||||
otherNodes: Collection[TCPNode]) -> None:
|
||||
super(_TCPTransport, self).__init__(syncObj, selfNode, otherNodes)
|
||||
self.setOnUtilityMessageCallback('members', syncObj.getMembers)
|
||||
|
||||
def _connectIfNecessarySingle(self, node):
|
||||
def _connectIfNecessarySingle(self, node: TCPNode) -> bool:
|
||||
try:
|
||||
return super(_TCPTransport, self)._connectIfNecessarySingle(node)
|
||||
except Exception as e:
|
||||
@@ -36,7 +40,7 @@ class _TCPTransport(TCPTransport):
|
||||
return False
|
||||
|
||||
|
||||
def resolve_host(self):
|
||||
def resolve_host(self: TCPNode) -> Optional[str]:
|
||||
return globalDnsResolver().resolve(self.host)
|
||||
|
||||
|
||||
@@ -45,17 +49,19 @@ setattr(TCPNode, 'ip', property(resolve_host))
|
||||
|
||||
class SyncObjUtility(object):
|
||||
|
||||
def __init__(self, otherNodes, conf, retry_timeout=10):
|
||||
def __init__(self, otherNodes: Collection[Union[str, TCPNode]], conf: SyncObjConf, retry_timeout: int = 10) -> None:
|
||||
self._nodes = otherNodes
|
||||
self._utility = TcpUtility(conf.password, retry_timeout / max(1, len(otherNodes)))
|
||||
self.__node = next(iter(otherNodes), None)
|
||||
|
||||
def executeCommand(self, command):
|
||||
def executeCommand(self, command: List[Any]) -> Any:
|
||||
try:
|
||||
return self._utility.executeCommand(self.__node, command)
|
||||
if self.__node:
|
||||
return self._utility.executeCommand(self.__node, command)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def getMembers(self):
|
||||
def getMembers(self) -> Optional[List[str]]:
|
||||
for self.__node in self._nodes:
|
||||
response = self.executeCommand(['members'])
|
||||
if response:
|
||||
@@ -64,7 +70,8 @@ class SyncObjUtility(object):
|
||||
|
||||
class DynMemberSyncObj(SyncObj):
|
||||
|
||||
def __init__(self, selfAddress, partnerAddrs, conf, retry_timeout=10):
|
||||
def __init__(self, selfAddress: Optional[str], partnerAddrs: Collection[str],
|
||||
conf: SyncObjConf, retry_timeout: int = 10) -> None:
|
||||
self.__early_apply_local_log = selfAddress is not None
|
||||
self.applied_local_log = False
|
||||
|
||||
@@ -81,12 +88,12 @@ class DynMemberSyncObj(SyncObj):
|
||||
thread.daemon = True
|
||||
thread.start()
|
||||
|
||||
def getMembers(self, args, callback):
|
||||
def getMembers(self, args: Any, callback: Callable[[Any, Any], Any]) -> None:
|
||||
callback([{'addr': node.id, 'leader': node == self._getLeader(), 'status': CONNECTION_STATE.CONNECTED
|
||||
if self.isNodeConnected(node) else CONNECTION_STATE.DISCONNECTED} for node in self.otherNodes]
|
||||
+ [{'addr': self.selfNode.id, 'leader': self._isLeader(), 'status': CONNECTION_STATE.CONNECTED}], None)
|
||||
|
||||
def _onTick(self, timeToWait=0.0):
|
||||
def _onTick(self, timeToWait: float = 0.0):
|
||||
super(DynMemberSyncObj, self)._onTick(timeToWait)
|
||||
|
||||
# The SyncObj calls onReady callback only when cluster got the leader and is ready for writes.
|
||||
@@ -98,11 +105,12 @@ class DynMemberSyncObj(SyncObj):
|
||||
|
||||
class KVStoreTTL(DynMemberSyncObj):
|
||||
|
||||
def __init__(self, on_ready, on_set, on_delete, **config):
|
||||
def __init__(self, on_ready: Optional[Callable[..., Any]], on_set: Optional[Callable[[str, Dict[str, Any]], None]],
|
||||
on_delete: Optional[Callable[[str], None]], **config: Any) -> None:
|
||||
self.__thread = None
|
||||
self.__on_set = on_set
|
||||
self.__on_delete = on_delete
|
||||
self.__limb = {}
|
||||
self.__limb: Dict[str, Dict[str, Any]] = {}
|
||||
self.set_retry_timeout(int(config.get('retry_timeout') or 10))
|
||||
|
||||
self_addr = config.get('self_addr')
|
||||
@@ -128,22 +136,22 @@ class KVStoreTTL(DynMemberSyncObj):
|
||||
onReady=on_ready, dynamicMembershipChange=True)
|
||||
|
||||
super(KVStoreTTL, self).__init__(self_addr, partner_addrs, conf, self.__retry_timeout)
|
||||
self.__data = {}
|
||||
self.__data: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
@staticmethod
def __check_requirements(old_value, **kwargs):
return ('prevExist' not in kwargs or bool(kwargs['prevExist']) == bool(old_value)) and \
('prevValue' not in kwargs or old_value and old_value['value'] == kwargs['prevValue']) and \
(not kwargs.get('prevIndex') or old_value and old_value['index'] == kwargs['prevIndex'])
def __check_requirements(old_value: Dict[str, Any], **kwargs: Any) -> bool:
return bool(('prevExist' not in kwargs or bool(kwargs['prevExist']) == bool(old_value))
and ('prevValue' not in kwargs or old_value and old_value['value'] == kwargs['prevValue'])
and (not kwargs.get('prevIndex') or old_value and old_value['index'] == kwargs['prevIndex']))

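For illustration only (not part of this commit): __check_requirements() implements etcd-style compare-and-set preconditions on top of the replicated dict. A standalone sketch of the same checks:

    def check_requirements(old_value, **kwargs):     # standalone copy for illustration
        return bool(('prevExist' not in kwargs or bool(kwargs['prevExist']) == bool(old_value))
                    and ('prevValue' not in kwargs or old_value and old_value['value'] == kwargs['prevValue'])
                    and (not kwargs.get('prevIndex') or old_value and old_value['index'] == kwargs['prevIndex']))

    existing = {'value': 'on', 'index': 4}
    assert check_requirements(existing, prevValue='on') is True
    assert check_requirements(existing, prevExist=False) is False    # key already exists
    assert check_requirements({}, prevExist=False) is True           # create-only write succeeds
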
def set_retry_timeout(self, retry_timeout):
|
||||
def set_retry_timeout(self, retry_timeout: int) -> None:
|
||||
self.__retry_timeout = retry_timeout
|
||||
|
||||
def retry(self, func, *args, **kwargs):
|
||||
def retry(self, func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
|
||||
event = threading.Event()
|
||||
ret = {'result': None, 'error': -1}
|
||||
|
||||
def callback(result, error):
|
||||
def callback(result: Any, error: Any) -> None:
|
||||
ret.update(result=result, error=error)
|
||||
event.set()
|
||||
|
||||
@@ -167,7 +175,7 @@ class KVStoreTTL(DynMemberSyncObj):
|
||||
return False
|
||||
|
||||
@replicated
|
||||
def _set(self, key, value, **kwargs):
|
||||
def _set(self, key: str, value: Dict[str, Any], **kwargs: Any) -> bool:
|
||||
old_value = self.__data.get(key, {})
|
||||
if not self.__check_requirements(old_value, **kwargs):
|
||||
return False
|
||||
@@ -181,29 +189,30 @@ class KVStoreTTL(DynMemberSyncObj):
|
||||
self.__on_set(key, value)
|
||||
return True
|
||||
|
||||
def set(self, key, value, ttl=None, handle_raft_error=True, **kwargs):
def set(self, key: str, value: str, ttl: Optional[int] = None,
handle_raft_error: bool = True, **kwargs: Any) -> bool:
old_value = self.__data.get(key, {})
if not self.__check_requirements(old_value, **kwargs):
return False

value = {'value': value, 'updated': time.time()}
value['created'] = old_value.get('created', value['updated'])
data: Dict[str, Any] = {'value': value, 'updated': time.time()}
data['created'] = old_value.get('created', data['updated'])
if ttl:
value['expire'] = value['updated'] + ttl
data['expire'] = data['updated'] + ttl
try:
return self.retry(self._set, key, value, **kwargs)
return self.retry(self._set, key, data, **kwargs)
except RaftError:
if not handle_raft_error:
raise
return False

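For illustration only (not part of this commit): the record that set() replicates is a plain dict with created/updated timestamps and an optional expire deadline; expired keys are later moved through __limb and removed by the leader in _onTick(). A made-up example of such a record for a 30-second TTL:

    import time

    now = time.time()
    record = {'value': 'postgresql0', 'created': now, 'updated': now, 'expire': now + 30}
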
def __pop(self, key):
|
||||
def __pop(self, key: str) -> None:
|
||||
self.__data.pop(key)
|
||||
if self.__on_delete:
|
||||
self.__on_delete(key)
|
||||
|
||||
@replicated
|
||||
def _delete(self, key, recursive=False, **kwargs):
|
||||
def _delete(self, key: str, recursive: bool = False, **kwargs: Any) -> bool:
|
||||
if recursive:
|
||||
for k in list(self.__data.keys()):
|
||||
if k.startswith(key):
|
||||
@@ -214,7 +223,7 @@ class KVStoreTTL(DynMemberSyncObj):
|
||||
self.__pop(key)
|
||||
return True
|
||||
|
||||
def delete(self, key, recursive=False, **kwargs):
|
||||
def delete(self, key: str, recursive: bool = False, **kwargs: Any) -> bool:
|
||||
if not recursive and not self.__check_requirements(self.__data.get(key, {}), **kwargs):
|
||||
return False
|
||||
try:
|
||||
@@ -223,32 +232,32 @@ class KVStoreTTL(DynMemberSyncObj):
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def __values_match(old, new):
|
||||
def __values_match(old: Dict[str, Any], new: Dict[str, Any]) -> bool:
|
||||
return all(old.get(n) == new.get(n) for n in ('created', 'updated', 'expire', 'value'))
|
||||
|
||||
@replicated
|
||||
def _expire(self, key, value, callback=None):
|
||||
def _expire(self, key: str, value: Dict[str, Any], callback: Optional[Callable[..., Any]] = None) -> None:
|
||||
current = self.__data.get(key)
|
||||
if current and self.__values_match(current, value):
|
||||
self.__pop(key)
|
||||
|
||||
def __expire_keys(self):
|
||||
def __expire_keys(self) -> None:
|
||||
for key, value in self.__data.items():
|
||||
if value and 'expire' in value and value['expire'] <= time.time() and \
|
||||
not (key in self.__limb and self.__values_match(self.__limb[key], value)):
|
||||
self.__limb[key] = value
|
||||
|
||||
def callback(*args):
|
||||
def callback(*args: Any) -> None:
|
||||
if key in self.__limb and self.__values_match(self.__limb[key], value):
|
||||
self.__limb.pop(key)
|
||||
self._expire(key, value, callback=callback)
|
||||
|
||||
def get(self, key, recursive=False):
|
||||
def get(self, key: str, recursive: bool = False) -> Union[None, Dict[str, Any], Dict[str, Dict[str, Any]]]:
|
||||
if not recursive:
|
||||
return self.__data.get(key)
|
||||
return {k: v for k, v in self.__data.items() if k.startswith(key)}
|
||||
|
||||
def _onTick(self, timeToWait=0.0):
|
||||
def _onTick(self, timeToWait: float = 0.0) -> None:
|
||||
super(KVStoreTTL, self)._onTick(timeToWait)
|
||||
|
||||
if self._isLeader():
|
||||
@@ -256,17 +265,17 @@ class KVStoreTTL(DynMemberSyncObj):
|
||||
else:
|
||||
self.__limb.clear()
|
||||
|
||||
def _autoTickThread(self):
|
||||
def _autoTickThread(self) -> None:
|
||||
self.__destroying = False
|
||||
while not self.__destroying:
|
||||
self.doTick(self.conf.autoTickPeriod)
|
||||
|
||||
def startAutoTick(self):
|
||||
def startAutoTick(self) -> None:
|
||||
self.__thread = threading.Thread(target=self._autoTickThread)
|
||||
self.__thread.daemon = True
|
||||
self.__thread.start()
|
||||
|
||||
def destroy(self):
|
||||
def destroy(self) -> None:
|
||||
if self.__thread:
|
||||
self.__destroying = True
|
||||
self.__thread.join()
|
||||
@@ -275,7 +284,7 @@ class KVStoreTTL(DynMemberSyncObj):
|
||||
|
||||
class Raft(AbstractDCS):
|
||||
|
||||
def __init__(self, config):
|
||||
def __init__(self, config: Dict[str, Any]) -> None:
|
||||
super(Raft, self).__init__(config)
|
||||
self._ttl = int(config.get('ttl') or 30)
|
||||
|
||||
@@ -290,7 +299,7 @@ class Raft(AbstractDCS):
|
||||
else:
|
||||
logger.info('waiting on raft')
|
||||
|
||||
def _on_set(self, key, value):
|
||||
def _on_set(self, key: str, value: Dict[str, Any]) -> None:
|
||||
leader = (self._sync_obj.get(self.leader_path) or {}).get('value')
|
||||
if key == value['created'] == value['updated'] and \
|
||||
(key.startswith(self.members_path) or key == self.leader_path and leader != self._name) or \
|
||||
@@ -298,29 +307,29 @@ class Raft(AbstractDCS):
|
||||
key in (self.config_path, self.sync_path):
|
||||
self.event.set()
|
||||
|
||||
def _on_delete(self, key):
|
||||
def _on_delete(self, key: str) -> None:
|
||||
if key == self.leader_path:
|
||||
self.event.set()
|
||||
|
||||
def set_ttl(self, ttl):
|
||||
def set_ttl(self, ttl: int) -> Optional[bool]:
|
||||
self._ttl = ttl
|
||||
|
||||
@property
|
||||
def ttl(self):
|
||||
def ttl(self) -> int:
|
||||
return self._ttl
|
||||
|
||||
def set_retry_timeout(self, retry_timeout):
|
||||
def set_retry_timeout(self, retry_timeout: int) -> None:
|
||||
self._sync_obj.set_retry_timeout(retry_timeout)
|
||||
|
||||
def reload_config(self, config):
|
||||
def reload_config(self, config: Union['Config', Dict[str, Any]]) -> None:
|
||||
super(Raft, self).reload_config(config)
|
||||
globalDnsResolver().setTimeouts(self.ttl, self.loop_wait)
|
||||
|
||||
@staticmethod
|
||||
def member(key, value):
|
||||
def member(key: str, value: Dict[str, Any]) -> Member:
|
||||
return Member.from_node(value['index'], os.path.basename(key), None, value['value'])
|
||||
|
||||
def _cluster_from_nodes(self, nodes):
|
||||
def _cluster_from_nodes(self, nodes: Dict[str, Any]) -> Cluster:
|
||||
# get initialize flag
|
||||
initialize = nodes.get(self._INITIALIZE)
|
||||
initialize = initialize and initialize['value']
|
||||
@@ -348,7 +357,7 @@ class Raft(AbstractDCS):
|
||||
slots = None
|
||||
|
||||
try:
|
||||
last_lsn = int(last_lsn)
|
||||
last_lsn = int(last_lsn or '')
|
||||
except Exception:
|
||||
last_lsn = 0
|
||||
|
||||
@@ -380,79 +389,81 @@ class Raft(AbstractDCS):
|
||||
|
||||
return Cluster(initialize, config, leader, last_lsn, members, failover, sync, history, slots, failsafe)
|
||||
|
||||
def _cluster_loader(self, path):
|
||||
def _cluster_loader(self, path: str) -> Cluster:
|
||||
response = self._sync_obj.get(path, recursive=True)
|
||||
if not response:
|
||||
return Cluster.empty()
|
||||
nodes = {key[len(path):]: value for key, value in response.items()}
|
||||
return self._cluster_from_nodes(nodes)
|
||||
|
||||
def _citus_cluster_loader(self, path):
|
||||
clusters = defaultdict(dict)
|
||||
def _citus_cluster_loader(self, path: str) -> Dict[int, Cluster]:
|
||||
clusters: Dict[int, Dict[str, Any]] = defaultdict(dict)
|
||||
response = self._sync_obj.get(path, recursive=True)
|
||||
for key, value in response.items():
|
||||
for key, value in (response or {}).items():
|
||||
key = key[len(path):].split('/', 1)
|
||||
if len(key) == 2 and citus_group_re.match(key[0]):
|
||||
clusters[int(key[0])][key[1]] = value
|
||||
return {group: self._cluster_from_nodes(nodes) for group, nodes in clusters.items()}
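A minimal, self-contained sketch (with a stand-in regex and made-up keys) of the grouping step this loader relies on: flat DCS keys of the form "<group>/<subkey>" are bucketed into one dict per Citus group id before each group is turned into a Cluster.

import re
from collections import defaultdict

group_re = re.compile(r'^(0|[1-9][0-9]*)$')   # stand-in for citus_group_re

def group_nodes(response, path):
    # bucket flat "<group>/<subkey>" keys into one dict per Citus group id
    clusters = defaultdict(dict)
    for key, value in (response or {}).items():
        parts = key[len(path):].split('/', 1)
        if len(parts) == 2 and group_re.match(parts[0]):
            clusters[int(parts[0])][parts[1]] = value
    return clusters

print(dict(group_nodes({'/c/0/leader': 'a', '/c/1/members/n1': 'b'}, '/c/')))
# {0: {'leader': 'a'}, 1: {'members/n1': 'b'}}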
def _load_cluster(self, path, loader):
|
||||
def _load_cluster(
|
||||
self, path: str, loader: Callable[[str], Union[Cluster, Dict[int, Cluster]]]
|
||||
) -> Union[Cluster, Dict[int, Cluster]]:
|
||||
return loader(path)
|
||||
|
||||
def _write_leader_optime(self, last_lsn):
|
||||
def _write_leader_optime(self, last_lsn: str) -> bool:
|
||||
return self._sync_obj.set(self.leader_optime_path, last_lsn, timeout=1)
|
||||
|
||||
def _write_status(self, value):
|
||||
def _write_status(self, value: str) -> bool:
|
||||
return self._sync_obj.set(self.status_path, value, timeout=1)
|
||||
|
||||
def _write_failsafe(self, value):
|
||||
def _write_failsafe(self, value: str) -> bool:
|
||||
return self._sync_obj.set(self.failsafe_path, value, timeout=1)
|
||||
|
||||
def _update_leader(self):
|
||||
def _update_leader(self) -> bool:
|
||||
ret = self._sync_obj.set(self.leader_path, self._name, ttl=self._ttl,
|
||||
handle_raft_error=False, prevValue=self._name)
|
||||
if not ret and self._sync_obj.get(self.leader_path) is None:
|
||||
ret = self.attempt_to_acquire_leader()
|
||||
return ret
|
||||
|
||||
def attempt_to_acquire_leader(self):
|
||||
def attempt_to_acquire_leader(self) -> bool:
|
||||
return self._sync_obj.set(self.leader_path, self._name, ttl=self._ttl, handle_raft_error=False, prevExist=False)
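Together, _update_leader and attempt_to_acquire_leader form the usual DCS leader-lock pattern: refresh the key only while it still holds our name (prevValue), and create it only if nobody else holds it (prevExist=False). A rough sketch of those conditional-set semantics against a plain in-memory dict, with invented names, to make the two calls easier to read:

import time

store = {}  # toy stand-in for the replicated key-value store

def kv_set(key, value, ttl, prev_value=None, prev_exist=None):
    current = store.get(key)
    if prev_exist is False and current is not None:
        return False                      # someone else already holds the key
    if prev_value is not None and (current is None or current[0] != prev_value):
        return False                      # the key no longer holds our name
    store[key] = (value, time.time() + ttl)
    return True

# refresh our own leader key; if it vanished, race to re-create it
if not kv_set('/leader', 'node1', 30, prev_value='node1'):
    kv_set('/leader', 'node1', 30, prev_exist=False)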
def set_failover_value(self, value, index=None):
|
||||
def set_failover_value(self, value: str, index: Optional[int] = None) -> bool:
|
||||
return self._sync_obj.set(self.failover_path, value, prevIndex=index)
|
||||
|
||||
def set_config_value(self, value, index=None):
|
||||
def set_config_value(self, value: str, index: Optional[int] = None) -> bool:
|
||||
return self._sync_obj.set(self.config_path, value, prevIndex=index)
|
||||
|
||||
def touch_member(self, data):
|
||||
data = json.dumps(data, separators=(',', ':'))
|
||||
return self._sync_obj.set(self.member_path, data, self._ttl, timeout=2)
|
||||
def touch_member(self, data: Dict[str, Any]) -> bool:
|
||||
value = json.dumps(data, separators=(',', ':'))
|
||||
return self._sync_obj.set(self.member_path, value, self._ttl, timeout=2)
|
||||
|
||||
def take_leader(self):
|
||||
def take_leader(self) -> bool:
|
||||
return self._sync_obj.set(self.leader_path, self._name, ttl=self._ttl)
|
||||
|
||||
def initialize(self, create_new=True, sysid=''):
|
||||
def initialize(self, create_new: bool = True, sysid: str = '') -> bool:
|
||||
return self._sync_obj.set(self.initialize_path, sysid, prevExist=(not create_new))
|
||||
|
||||
def _delete_leader(self):
|
||||
def _delete_leader(self) -> bool:
|
||||
return self._sync_obj.delete(self.leader_path, prevValue=self._name, timeout=1)
|
||||
|
||||
def cancel_initialization(self):
|
||||
def cancel_initialization(self) -> bool:
|
||||
return self._sync_obj.delete(self.initialize_path)
|
||||
|
||||
def delete_cluster(self):
|
||||
def delete_cluster(self) -> bool:
|
||||
return self._sync_obj.delete(self.client_path(''), recursive=True)
|
||||
|
||||
def set_history_value(self, value):
|
||||
def set_history_value(self, value: str) -> bool:
|
||||
return self._sync_obj.set(self.history_path, value)
|
||||
|
||||
def set_sync_state_value(self, value, index=None):
|
||||
def set_sync_state_value(self, value: str, index: Optional[int] = None) -> bool:
|
||||
return self._sync_obj.set(self.sync_path, value, prevIndex=index)
|
||||
|
||||
def delete_sync_state(self, index=None):
|
||||
def delete_sync_state(self, index: Optional[int] = None) -> bool:
|
||||
return self._sync_obj.delete(self.sync_path, prevIndex=index)
|
||||
|
||||
def watch(self, leader_index, timeout):
|
||||
def watch(self, leader_index: Optional[int], timeout: float) -> bool:
|
||||
try:
|
||||
return super(Raft, self).watch(leader_index, timeout)
|
||||
finally:
|
||||
|
||||
@@ -1,18 +1,22 @@
|
||||
import json
|
||||
import logging
|
||||
import select
|
||||
import socket
|
||||
import time
|
||||
|
||||
from kazoo.client import KazooClient, KazooState, KazooRetry
|
||||
from kazoo.exceptions import ConnectionClosedError, NoNodeError, NodeExistsError, SessionExpiredError
|
||||
from kazoo.handlers.threading import SequentialThreadingHandler
|
||||
from kazoo.protocol.states import KeeperState
|
||||
from kazoo.handlers.threading import AsyncResult, SequentialThreadingHandler
|
||||
from kazoo.protocol.states import KeeperState, WatchedEvent, ZnodeStat
|
||||
from kazoo.retry import RetryFailedError
|
||||
from kazoo.security import make_acl
|
||||
from kazoo.security import ACL, make_acl
|
||||
from typing import Any, Callable, Dict, List, Optional, Union, Tuple, TYPE_CHECKING
|
||||
|
||||
from . import AbstractDCS, ClusterConfig, Cluster, Failover, Leader, Member, SyncState, TimelineHistory, citus_group_re
|
||||
from ..exceptions import DCSError
|
||||
from ..utils import deep_compare
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from ..config import Config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -23,14 +27,14 @@ class ZooKeeperError(DCSError):
|
||||
|
||||
class PatroniSequentialThreadingHandler(SequentialThreadingHandler):
|
||||
|
||||
def __init__(self, connect_timeout):
|
||||
def __init__(self, connect_timeout: Union[int, float]) -> None:
|
||||
super(PatroniSequentialThreadingHandler, self).__init__()
|
||||
self.set_connect_timeout(connect_timeout)
|
||||
|
||||
def set_connect_timeout(self, connect_timeout):
|
||||
def set_connect_timeout(self, connect_timeout: Union[int, float]) -> None:
|
||||
self._connect_timeout = max(1.0, connect_timeout / 2.0) # try to connect to zookeeper node during loop_wait/2
|
||||
|
||||
def create_connection(self, *args, **kwargs):
|
||||
def create_connection(self, *args: Any, **kwargs: Any) -> socket.socket:
|
||||
"""This method is trying to establish connection with one of the zookeeper nodes.
|
||||
Somehow strategy "fail earlier and retry more often" works way better comparing to
|
||||
the original strategy "try to connect with specified timeout".
|
||||
@@ -41,16 +45,16 @@ class PatroniSequentialThreadingHandler(SequentialThreadingHandler):
|
||||
:param args: always contains `tuple(host, port)` as the first element and could contain
|
||||
`connect_timeout` (negotiated session timeout) as the second element."""
|
||||
|
||||
args = list(args)
|
||||
if len(args) == 0: # kazoo 2.6.0 slightly changed the way how it calls create_connection method
|
||||
args_list: List[Any] = list(args)
|
||||
if len(args_list) == 0: # kazoo 2.6.0 slightly changed the way how it calls create_connection method
|
||||
kwargs['timeout'] = max(self._connect_timeout, kwargs.get('timeout', self._connect_timeout * 10) / 10.0)
|
||||
elif len(args) == 1:
|
||||
args.append(self._connect_timeout)
|
||||
elif len(args_list) == 1:
|
||||
args_list.append(self._connect_timeout)
|
||||
else:
|
||||
args[1] = max(self._connect_timeout, args[1] / 10.0)
|
||||
return super(PatroniSequentialThreadingHandler, self).create_connection(*args, **kwargs)
|
||||
args_list[1] = max(self._connect_timeout, args_list[1] / 10.0)
|
||||
return super(PatroniSequentialThreadingHandler, self).create_connection(*args_list, **kwargs)
|
||||
|
||||
def select(self, *args, **kwargs):
|
||||
def select(self, *args: Any, **kwargs: Any) -> Any:
|
||||
"""
|
||||
Python 3.XY may raise following exceptions if select/poll are called with an invalid socket:
|
||||
- `ValueError`: because fd == -1
|
||||
@@ -68,7 +72,7 @@ class PatroniSequentialThreadingHandler(SequentialThreadingHandler):
|
||||
|
||||
class PatroniKazooClient(KazooClient):
|
||||
|
||||
def _call(self, request, async_object):
|
||||
def _call(self, request: Tuple[Any], async_object: AsyncResult) -> Optional[bool]:
|
||||
# Before kazoo==2.7.0 it wasn't possible to send requests to zookeeper if
|
||||
# the connection is in the SUSPENDED state and Patroni was strongly relying on it.
|
||||
# The https://github.com/python-zk/kazoo/pull/588 changed it, and now such requests are queued.
|
||||
@@ -82,7 +86,7 @@ class PatroniKazooClient(KazooClient):
|
||||
|
||||
class ZooKeeper(AbstractDCS):
|
||||
|
||||
def __init__(self, config):
|
||||
def __init__(self, config: Dict[str, Any]) -> None:
|
||||
super(ZooKeeper, self).__init__(config)
|
||||
|
||||
hosts = config.get('hosts', [])
|
||||
@@ -94,17 +98,18 @@ class ZooKeeper(AbstractDCS):
|
||||
kwargs = {v: config[k] for k, v in mapping.items() if k in config}
|
||||
|
||||
if 'set_acls' in config:
|
||||
kwargs['default_acl'] = []
|
||||
default_acl: List[ACL] = []
|
||||
for principal, permissions in config['set_acls'].items():
|
||||
normalizedPermissions = [p.upper() for p in permissions]
|
||||
kwargs['default_acl'].append(make_acl(scheme='x509',
|
||||
credential=principal,
|
||||
read='READ' in normalizedPermissions,
|
||||
write='WRITE' in normalizedPermissions,
|
||||
create='CREATE' in normalizedPermissions,
|
||||
delete='DELETE' in normalizedPermissions,
|
||||
admin='ADMIN' in normalizedPermissions,
|
||||
all='ALL' in normalizedPermissions))
|
||||
default_acl.append(make_acl(scheme='x509',
|
||||
credential=principal,
|
||||
read='READ' in normalizedPermissions,
|
||||
write='WRITE' in normalizedPermissions,
|
||||
create='CREATE' in normalizedPermissions,
|
||||
delete='DELETE' in normalizedPermissions,
|
||||
admin='ADMIN' in normalizedPermissions,
|
||||
all='ALL' in normalizedPermissions))
|
||||
kwargs['default_acl'] = default_acl
|
||||
|
||||
self._client = PatroniKazooClient(hosts, handler=PatroniSequentialThreadingHandler(config['retry_timeout']),
|
||||
timeout=config['ttl'], connection_retry=KazooRetry(max_delay=1, max_tries=-1,
|
||||
@@ -112,16 +117,16 @@ class ZooKeeper(AbstractDCS):
|
||||
deadline=config['retry_timeout'], sleep_func=time.sleep), **kwargs)
|
||||
self._client.add_listener(self.session_listener)
|
||||
|
||||
self._fetch_cluster = True
|
||||
self._fetch_status = True
|
||||
self.__last_member_data = None
|
||||
self._fetch_cluster: bool = True
|
||||
self._fetch_status: bool = True
|
||||
self.__last_member_data: Optional[Dict[str, Any]] = None
|
||||
|
||||
self._orig_kazoo_connect = self._client._connection._connect
|
||||
self._client._connection._connect = self._kazoo_connect
|
||||
|
||||
self._client.start()
|
||||
|
||||
def _kazoo_connect(self, *args):
|
||||
def _kazoo_connect(self, *args: Any) -> Tuple[Union[int, float], Union[int, float]]:
"""Kazoo is using Ping's to determine health of connection to zookeeper. If there is no
|
||||
response on Ping after Ping interval (1/2 from read_timeout) it will consider current
|
||||
connection dead and try to connect to another node. Without this "magic" it was taking
@@ -136,30 +141,28 @@ class ZooKeeper(AbstractDCS):
|
||||
ret = self._orig_kazoo_connect(*args)
|
||||
return max(self.loop_wait - 2, 2) * 1000, ret[1]
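Because kazoo issues a ping after half of the read timeout it negotiated, returning max(loop_wait - 2, 2) * 1000 here effectively ties the ping interval to Patroni's loop_wait. A quick numeric illustration (values are arbitrary):

def ping_interval_ms(loop_wait):
    read_timeout_ms = max(loop_wait - 2, 2) * 1000   # what this method reports back to kazoo
    return read_timeout_ms / 2                       # kazoo pings after half of it

print(ping_interval_ms(10))  # 4000.0 -> a ping roughly every 4 seconds
print(ping_interval_ms(3))   # 1000.0 -> the max() keeps a floor for very small loop_wait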
def session_listener(self, state):
|
||||
def session_listener(self, state: str) -> None:
|
||||
if state in [KazooState.SUSPENDED, KazooState.LOST]:
|
||||
self.cluster_watcher(None)
|
||||
|
||||
def status_watcher(self, event):
|
||||
def status_watcher(self, event: Optional[WatchedEvent]) -> None:
|
||||
self._fetch_status = True
|
||||
self.event.set()
|
||||
|
||||
def cluster_watcher(self, event):
|
||||
def cluster_watcher(self, event: Optional[WatchedEvent]) -> None:
|
||||
self._fetch_cluster = True
|
||||
if not event or event.state != KazooState.CONNECTED or event.path.startswith(self.client_path('')):
|
||||
self.status_watcher(event)
|
||||
|
||||
def members_watcher(self, event):
|
||||
self._fetch_cluster = True
|
||||
|
||||
def reload_config(self, config):
|
||||
def reload_config(self, config: Union['Config', Dict[str, Any]]) -> None:
|
||||
self.set_retry_timeout(config['retry_timeout'])
|
||||
|
||||
loop_wait = config['loop_wait']
|
||||
|
||||
loop_wait_changed = self._loop_wait != loop_wait
|
||||
self._loop_wait = loop_wait
|
||||
self._client.handler.set_connect_timeout(loop_wait)
|
||||
if isinstance(self._client.handler, PatroniSequentialThreadingHandler):
|
||||
self._client.handler.set_connect_timeout(loop_wait)
|
||||
|
||||
# We need to reestablish connection to zookeeper if we want to change
|
||||
# read_timeout (and Ping interval respectively), because read_timeout
|
||||
@@ -170,7 +173,7 @@ class ZooKeeper(AbstractDCS):
|
||||
if not self.set_ttl(config['ttl']) and loop_wait_changed:
|
||||
self._client._connection._socket.close()
|
||||
|
||||
def set_ttl(self, ttl):
|
||||
def set_ttl(self, ttl: int) -> Optional[bool]:
"""It is not possible to change ttl (session_timeout) in zookeeper without
|
||||
destroying old session and creating the new one. This method returns `!True`
|
||||
if session_timeout has been changed (`restart()` has been called)."""
@@ -181,21 +184,23 @@ class ZooKeeper(AbstractDCS):
|
||||
return True
|
||||
|
||||
@property
|
||||
def ttl(self):
|
||||
return self._client._session_timeout / 1000.0
|
||||
def ttl(self) -> int:
|
||||
return int(self._client._session_timeout / 1000.0)
|
||||
|
||||
def set_retry_timeout(self, retry_timeout):
|
||||
def set_retry_timeout(self, retry_timeout: int) -> None:
|
||||
retry = self._client.retry if isinstance(self._client.retry, KazooRetry) else self._client._retry
|
||||
retry.deadline = retry_timeout
|
||||
|
||||
def get_node(self, key, watch=None):
|
||||
def get_node(
|
||||
self, key: str, watch: Optional[Callable[[WatchedEvent], None]] = None
|
||||
) -> Optional[Tuple[str, ZnodeStat]]:
|
||||
try:
|
||||
ret = self._client.get(key, watch)
|
||||
return (ret[0].decode('utf-8'), ret[1])
|
||||
except NoNodeError:
|
||||
return None
|
||||
|
||||
def get_status(self, path, leader):
|
||||
def get_status(self, path: str, leader: Optional[Leader]) -> Tuple[int, Optional[Dict[str, int]]]:
|
||||
watch = self.status_watcher if not leader or leader.name != self._name else None
|
||||
|
||||
status = self.get_node(path + self._STATUS, watch)
|
||||
@@ -212,7 +217,7 @@ class ZooKeeper(AbstractDCS):
|
||||
slots = None
|
||||
|
||||
try:
|
||||
last_lsn = int(last_lsn)
|
||||
last_lsn = int(last_lsn or '')
|
||||
except Exception:
|
||||
last_lsn = 0
|
||||
|
||||
@@ -220,24 +225,24 @@ class ZooKeeper(AbstractDCS):
|
||||
return last_lsn, slots
|
||||
|
||||
@staticmethod
|
||||
def member(name, value, znode):
|
||||
def member(name: str, value: str, znode: ZnodeStat) -> Member:
|
||||
return Member.from_node(znode.version, name, znode.ephemeralOwner, value)
|
||||
|
||||
def get_children(self, key, watch=None):
|
||||
def get_children(self, key: str, watch: Optional[Callable[[WatchedEvent], None]] = None) -> List[str]:
|
||||
try:
|
||||
return self._client.get_children(key, watch)
|
||||
except NoNodeError:
|
||||
return []
|
||||
|
||||
def load_members(self, path):
|
||||
members = []
|
||||
def load_members(self, path: str) -> List[Member]:
|
||||
members: List[Member] = []
|
||||
for member in self.get_children(path + self._MEMBERS, self.cluster_watcher):
|
||||
data = self.get_node(path + self._MEMBERS + member)
|
||||
if data is not None:
|
||||
members.append(self.member(member, *data))
|
||||
return members
|
||||
|
||||
def _cluster_loader(self, path):
|
||||
def _cluster_loader(self, path: str) -> Cluster:
|
||||
self._fetch_cluster = False
|
||||
self.event.clear()
|
||||
nodes = set(self.get_children(path, self.cluster_watcher))
|
||||
@@ -286,9 +291,9 @@ class ZooKeeper(AbstractDCS):
|
||||
|
||||
return Cluster(initialize, config, leader, last_lsn, members, failover, sync, history, slots, failsafe)
|
||||
|
||||
def _citus_cluster_loader(self, path):
|
||||
def _citus_cluster_loader(self, path: str) -> Dict[int, Cluster]:
|
||||
fetch_cluster = False
|
||||
ret = {}
|
||||
ret: Dict[int, Cluster] = {}
|
||||
for node in self.get_children(path, self.cluster_watcher):
|
||||
if citus_group_re.match(node):
|
||||
ret[int(node)] = self._cluster_loader(path + node + '/')
|
||||
@@ -296,7 +301,9 @@ class ZooKeeper(AbstractDCS):
|
||||
self._fetch_cluster = fetch_cluster
|
||||
return ret
|
||||
|
||||
def _load_cluster(self, path, loader):
|
||||
def _load_cluster(
|
||||
self, path: str, loader: Callable[[str], Union[Cluster, Dict[int, Cluster]]]
|
||||
) -> Union[Cluster, Dict[int, Cluster]]:
|
||||
cluster = self.cluster if path == self._base_path + '/' else None
|
||||
if self._fetch_cluster or cluster is None:
|
||||
try:
|
||||
@@ -315,18 +322,18 @@ class ZooKeeper(AbstractDCS):
|
||||
try:
|
||||
last_lsn, slots = self.get_status(self.client_path(''), cluster.leader)
|
||||
self.event.clear()
|
||||
cluster = list(cluster)
|
||||
cluster[3] = last_lsn
|
||||
cluster[8] = slots
|
||||
cluster = Cluster(*cluster)
|
||||
new_cluster: List[Any] = list(cluster)
|
||||
new_cluster[3] = last_lsn
|
||||
new_cluster[8] = slots
|
||||
cluster = Cluster(*new_cluster)
|
||||
except Exception:
|
||||
pass
|
||||
return cluster
|
||||
|
||||
def _bypass_caches(self):
|
||||
def _bypass_caches(self) -> None:
|
||||
self._fetch_cluster = True
|
||||
|
||||
def _create(self, path, value, retry=False, ephemeral=False):
|
||||
def _create(self, path: str, value: bytes, retry: bool = False, ephemeral: bool = False) -> bool:
|
||||
try:
|
||||
if retry:
|
||||
self._client.retry(self._client.create, path, value, makepath=True, ephemeral=ephemeral)
|
||||
@@ -337,7 +344,7 @@ class ZooKeeper(AbstractDCS):
|
||||
logger.exception('Failed to create %s', path)
|
||||
return False
|
||||
|
||||
def attempt_to_acquire_leader(self):
|
||||
def attempt_to_acquire_leader(self) -> bool:
|
||||
try:
|
||||
self._client.retry(self._client.create, self.leader_path, self._name.encode('utf-8'),
|
||||
makepath=True, ephemeral=True)
|
||||
@@ -350,43 +357,44 @@ class ZooKeeper(AbstractDCS):
|
||||
logger.info('Could not take out TTL lock')
|
||||
return False
|
||||
|
||||
def _set_or_create(self, key, value, index=None, retry=False, do_not_create_empty=False):
|
||||
value = value.encode('utf-8')
|
||||
def _set_or_create(self, key: str, value: str, index: Optional[int] = None,
|
||||
retry: bool = False, do_not_create_empty: bool = False) -> bool:
|
||||
value_bytes = value.encode('utf-8')
|
||||
try:
|
||||
if retry:
|
||||
self._client.retry(self._client.set, key, value, version=index or -1)
|
||||
self._client.retry(self._client.set, key, value_bytes, version=index or -1)
|
||||
else:
|
||||
self._client.set_async(key, value, version=index or -1).get(timeout=1)
|
||||
self._client.set_async(key, value_bytes, version=index or -1).get(timeout=1)
|
||||
return True
|
||||
except NoNodeError:
|
||||
if do_not_create_empty and not value:
|
||||
if do_not_create_empty and not value_bytes:
|
||||
return True
|
||||
elif index is None:
|
||||
return self._create(key, value, retry)
|
||||
return self._create(key, value_bytes, retry)
|
||||
else:
|
||||
return False
|
||||
except Exception:
|
||||
logger.exception('Failed to update %s', key)
|
||||
return False
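_set_or_create is optimistic concurrency on znode versions: try a conditional set (version=index, or -1 for "any version"), fall back to create when the node is missing and no specific version was demanded, and report a clean failure otherwise. A hedged, stand-alone sketch of the same control flow with kazoo (the path, value and helper name are made up):

from kazoo.client import KazooClient
from kazoo.exceptions import BadVersionError, NoNodeError

def set_or_create(client: KazooClient, key: str, value: bytes, version: int = -1) -> bool:
    try:
        client.set(key, value, version=version)   # conditional set; -1 accepts any version
        return True
    except NoNodeError:
        if version == -1:                          # only create if no specific version was required
            client.create(key, value, makepath=True)
            return True
        return False
    except BadVersionError:
        return False                               # somebody changed the node since we read it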
def set_failover_value(self, value, index=None):
|
||||
def set_failover_value(self, value: str, index: Optional[int] = None) -> bool:
|
||||
return self._set_or_create(self.failover_path, value, index)
|
||||
|
||||
def set_config_value(self, value, index=None):
|
||||
def set_config_value(self, value: str, index: Optional[int] = None) -> bool:
|
||||
return self._set_or_create(self.config_path, value, index, retry=True)
|
||||
|
||||
def initialize(self, create_new=True, sysid=""):
|
||||
sysid = sysid.encode('utf-8')
|
||||
return self._create(self.initialize_path, sysid, retry=True) if create_new \
|
||||
else self._client.retry(self._client.set, self.initialize_path, sysid)
|
||||
def initialize(self, create_new: bool = True, sysid: str = "") -> bool:
|
||||
sysid_bytes = sysid.encode('utf-8')
|
||||
return self._create(self.initialize_path, sysid_bytes, retry=True) if create_new \
|
||||
else self._client.retry(self._client.set, self.initialize_path, sysid_bytes)
|
||||
|
||||
def touch_member(self, data):
|
||||
def touch_member(self, data: Dict[str, Any]) -> bool:
|
||||
cluster = self.cluster
|
||||
member = cluster and cluster.get_member(self._name, fallback_to_leader=False)
|
||||
member_data = self.__last_member_data or member and member.data
|
||||
# We want to notify leader if some important fields in the member key changed by removing ZNode
|
||||
if member and (self._client.client_id is not None and member.session != self._client.client_id[0]
|
||||
or not (deep_compare(member_data.get('tags', {}), data.get('tags', {}))
|
||||
or not (member_data and deep_compare(member_data.get('tags', {}), data.get('tags', {}))
|
||||
and (member_data.get('state') == data.get('state')
|
||||
or 'running' not in (member_data.get('state'), data.get('state')))
|
||||
and member_data.get('version') == data.get('version')
|
||||
@@ -401,7 +409,7 @@ class ZooKeeper(AbstractDCS):
|
||||
member = None
|
||||
|
||||
encoded_data = json.dumps(data, separators=(',', ':')).encode('utf-8')
|
||||
if member:
|
||||
if member and member_data:
|
||||
if deep_compare(data, member_data):
|
||||
return True
|
||||
else:
|
||||
@@ -422,19 +430,19 @@ class ZooKeeper(AbstractDCS):
|
||||
|
||||
return False
|
||||
|
||||
def take_leader(self):
|
||||
def take_leader(self) -> bool:
|
||||
return self.attempt_to_acquire_leader()
|
||||
|
||||
def _write_leader_optime(self, last_lsn):
|
||||
def _write_leader_optime(self, last_lsn: str) -> bool:
|
||||
return self._set_or_create(self.leader_optime_path, last_lsn)
|
||||
|
||||
def _write_status(self, value):
|
||||
def _write_status(self, value: str) -> bool:
|
||||
return self._set_or_create(self.status_path, value)
|
||||
|
||||
def _write_failsafe(self, value):
|
||||
def _write_failsafe(self, value: str) -> bool:
|
||||
return self._set_or_create(self.failsafe_path, value)
|
||||
|
||||
def _update_leader(self):
|
||||
def _update_leader(self) -> bool:
|
||||
cluster = self.cluster
|
||||
session = cluster and isinstance(cluster.leader, Leader) and cluster.leader.session
|
||||
if self._client.client_id and self._client.client_id[0] != session:
|
||||
@@ -459,37 +467,39 @@ class ZooKeeper(AbstractDCS):
|
||||
return False
|
||||
return True
|
||||
|
||||
def _delete_leader(self):
|
||||
def _delete_leader(self) -> bool:
|
||||
self._client.restart()
|
||||
return True
|
||||
|
||||
def _cancel_initialization(self):
|
||||
def _cancel_initialization(self) -> None:
|
||||
node = self.get_node(self.initialize_path)
|
||||
if node:
|
||||
self._client.delete(self.initialize_path, version=node[1].version)
|
||||
|
||||
def cancel_initialization(self):
|
||||
def cancel_initialization(self) -> bool:
|
||||
try:
|
||||
self._client.retry(self._cancel_initialization)
|
||||
return True
|
||||
except Exception:
|
||||
logger.exception("Unable to delete initialize key")
|
||||
return False
|
||||
|
||||
def delete_cluster(self):
|
||||
def delete_cluster(self) -> bool:
|
||||
try:
|
||||
return self._client.retry(self._client.delete, self.client_path(''), recursive=True)
|
||||
except NoNodeError:
|
||||
return True
|
||||
|
||||
def set_history_value(self, value):
|
||||
def set_history_value(self, value: str) -> bool:
|
||||
return self._set_or_create(self.history_path, value)
|
||||
|
||||
def set_sync_state_value(self, value, index=None):
|
||||
def set_sync_state_value(self, value: str, index: Optional[int] = None) -> bool:
|
||||
return self._set_or_create(self.sync_path, value, index, retry=True, do_not_create_empty=True)
|
||||
|
||||
def delete_sync_state(self, index=None):
|
||||
def delete_sync_state(self, index: Optional[int] = None) -> bool:
|
||||
return self.set_sync_state_value("{}", index)
|
||||
|
||||
def watch(self, leader_index, timeout):
|
||||
def watch(self, leader_index: Optional[int], timeout: float) -> bool:
|
||||
ret = super(ZooKeeper, self).watch(leader_index, timeout + 0.5)
|
||||
if ret and not self._fetch_status:
|
||||
self._fetch_cluster = True
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
from typing import Any
|
||||
|
||||
|
||||
class PatroniException(Exception):
|
||||
|
||||
"""Parent class for all kind of exceptions related to selected distributed configuration store"""
|
||||
|
||||
def __init__(self, value):
|
||||
def __init__(self, value: Any) -> None:
|
||||
self.value = value
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
"""
|
||||
>>> str(PatroniException('foo'))
|
||||
"'foo'"
|
||||
|
||||
361
patroni/ha.py
@@ -6,27 +6,26 @@ import sys
|
||||
import time
|
||||
import uuid
|
||||
|
||||
from collections import namedtuple
|
||||
from multiprocessing.pool import ThreadPool
|
||||
from threading import RLock
|
||||
from typing import List, Optional, Union
|
||||
from typing import Any, Callable, Collection, Dict, List, NamedTuple, Optional, Union, Tuple, TYPE_CHECKING
|
||||
|
||||
from . import psycopg
|
||||
from .__main__ import Patroni
|
||||
from .async_executor import AsyncExecutor, CriticalTask
|
||||
from .collections import CaseInsensitiveSet
|
||||
from .dcs import AbstractDCS, Cluster, Leader, Member, RemoteMember
|
||||
from .exceptions import DCSError, PostgresConnectionException, PatroniFatalException
|
||||
from .postgresql.callback_executor import CallbackAction
|
||||
from .postgresql.misc import postgres_version_to_int
|
||||
from .postgresql.postmaster import PostmasterProcess
|
||||
from .postgresql.rewind import Rewind
|
||||
from .utils import polling_loop, tzutc
|
||||
from .dcs import Cluster, Leader, Member, RemoteMember
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _MemberStatus(namedtuple('_MemberStatus', ['member', 'reachable', 'in_recovery',
|
||||
'dcs_last_seen', 'timeline', 'wal_position',
|
||||
'tags', 'watchdog_failed'])):
|
||||
class _MemberStatus(NamedTuple):
|
||||
"""Node status distilled from API response:
|
||||
|
||||
member - dcs.Member object of the node
|
||||
@@ -38,29 +37,37 @@ class _MemberStatus(namedtuple('_MemberStatus', ['member', 'reachable', 'in_reco
|
||||
tags - dictionary with values of different tags (i.e. nofailover)
|
||||
watchdog_failed - indicates that watchdog is required by configuration but not available or failed
|
||||
"""
|
||||
member: Member
|
||||
reachable: bool
|
||||
in_recovery: Optional[bool]
|
||||
dcs_last_seen: int
|
||||
timeline: int
|
||||
wal_position: int
|
||||
tags: Dict[str, Any]
|
||||
watchdog_failed: bool
|
||||
|
||||
@classmethod
|
||||
def from_api_response(cls, member, json):
|
||||
def from_api_response(cls, member: Member, json: Dict[str, Any]) -> '_MemberStatus':
|
||||
"""
|
||||
:param member: dcs.Member object
|
||||
:param json: RestApiHandler.get_postgresql_status() result
|
||||
:returns: _MemberStatus object
|
||||
"""
|
||||
# If one of those is not in a response we want to count the node as not healthy/reachable
|
||||
assert 'wal' in json or 'xlog' in json
|
||||
|
||||
wal = json.get('wal', json.get('xlog'))
|
||||
in_recovery = not bool(wal.get('location')) # abuse difference in primary/replica response format
|
||||
wal: Dict[str, Any] = json.get('wal') or json['xlog']
|
||||
# abuse difference in primary/replica response format
|
||||
in_recovery = not bool(wal.get('location')) or json.get('role') in ('master', 'primary')
|
||||
timeline = json.get('timeline', 0)
|
||||
dcs_last_seen = json.get('dcs_last_seen', 0)
|
||||
wal = in_recovery and max(wal.get('received_location', 0), wal.get('replayed_location', 0))
|
||||
return cls(member, True, in_recovery, dcs_last_seen, timeline, wal,
|
||||
lsn = int(in_recovery and max(wal.get('received_location', 0), wal.get('replayed_location', 0)))
|
||||
return cls(member, True, in_recovery, dcs_last_seen, timeline, lsn,
|
||||
json.get('tags', {}), json.get('watchdog_failed', False))
|
|
||||
def unknown(cls, member):
|
||||
def unknown(cls, member: Member) -> '_MemberStatus':
|
||||
return cls(member, False, None, 0, 0, 0, {}, False)
|
||||
|
||||
def failover_limitation(self):
|
||||
def failover_limitation(self) -> Optional[str]:
|
||||
"""Returns reason why this node can't promote or None if everything is ok."""
|
||||
if not self.reachable:
|
||||
return 'not reachable'
|
||||
@@ -73,7 +80,7 @@ class _MemberStatus(namedtuple('_MemberStatus', ['member', 'reachable', 'in_reco
|
||||
|
||||
class Failsafe(object):
|
||||
|
||||
def __init__(self, dcs):
|
||||
def __init__(self, dcs: AbstractDCS) -> None:
|
||||
self._lock = RLock()
|
||||
self._dcs = dcs
|
||||
self._last_update = 0
|
||||
@@ -82,7 +89,7 @@ class Failsafe(object):
|
||||
self._api_url = None
|
||||
self._slots = None
|
||||
|
||||
def update(self, data):
|
||||
def update(self, data: Dict[str, Any]) -> None:
|
||||
with self._lock:
|
||||
self._last_update = time.time()
|
||||
self._name = data['name']
|
||||
@@ -91,26 +98,22 @@ class Failsafe(object):
|
||||
self._slots = data.get('slots')
|
||||
|
||||
@property
|
||||
def leader(self):
|
||||
def leader(self) -> Optional[Leader]:
|
||||
with self._lock:
|
||||
if self._last_update + self._dcs.ttl > time.time():
|
||||
return Leader(None, None,
|
||||
RemoteMember(self._name, {'api_url': self._api_url,
|
||||
'conn_url': self._conn_url,
|
||||
'slots': self._slots}))
|
||||
return Leader('', '', RemoteMember(self._name, {'api_url': self._api_url,
|
||||
'conn_url': self._conn_url,
|
||||
'slots': self._slots}))
|
||||
|
||||
def update_cluster(self, cluster):
|
||||
def update_cluster(self, cluster: Cluster) -> Cluster:
# Enrich cluster with the real leader if there was a ping from it
leader = self.leader
|
||||
if leader:
|
||||
cluster = list(cluster)
|
||||
# We rely on the strict order of fields in the namedtuple
|
||||
cluster[2] = leader
|
||||
cluster[8] = leader.member.data['slots']
|
||||
cluster = Cluster(*cluster)
|
||||
cluster = Cluster(*cluster[0:2], leader, *cluster[3:8], leader.member.data['slots'], *cluster[9:])
|
||||
return cluster
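The rewritten line rebuilds the Cluster tuple by splicing the failsafe leader into field 2 and its slots into field 8, relying on the field order the old comment mentions. If Cluster is indeed a namedtuple (as that comment implies), the same result can be expressed with _replace; a generic illustration with a made-up tuple type:

from collections import namedtuple

Demo = namedtuple('Demo', 'initialize config leader other slots')
c = Demo('init', 'cfg', None, 'x', None)

spliced = Demo(*c[0:2], 'new-leader', *c[3:4], {'slot': 1})       # positional splice
replaced = c._replace(leader='new-leader', slots={'slot': 1})     # field-name based
print(spliced == replaced)  # True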
def is_active(self):
|
||||
def is_active(self) -> bool:
|
||||
"""Is used to report in REST API whether the failsafe mode was activated.
|
||||
|
||||
On primary the self._last_update is set from the
|
||||
@@ -124,21 +127,21 @@ class Failsafe(object):
|
||||
with self._lock:
|
||||
return self._last_update + self._dcs.ttl > time.time()
|
||||
|
||||
def set_is_active(self, value):
|
||||
def set_is_active(self, value: float) -> None:
|
||||
with self._lock:
|
||||
self._last_update = value
|
||||
|
||||
|
||||
class Ha(object):
|
||||
|
||||
def __init__(self, patroni):
|
||||
def __init__(self, patroni: Patroni):
|
||||
self.patroni = patroni
|
||||
self.state_handler = patroni.postgresql
|
||||
self._rewind = Rewind(self.state_handler)
|
||||
self.dcs = patroni.dcs
|
||||
self.cluster = None
|
||||
self.cluster = Cluster.empty()
|
||||
self.global_config = self.patroni.config.get_global_config(None)
|
||||
self.old_cluster = None
|
||||
self.old_cluster = Cluster.empty()
|
||||
self._is_leader = False
|
||||
self._is_leader_lock = RLock()
|
||||
self._failsafe = Failsafe(patroni.dcs)
|
||||
@@ -147,7 +150,7 @@ class Ha(object):
|
||||
self.recovering = False
|
||||
self._async_response = CriticalTask()
|
||||
self._crash_recovery_executed = False
|
||||
self._crash_recovery_started = None
|
||||
self._crash_recovery_started = 0
|
||||
self._start_timeout = None
|
||||
self._async_executor = AsyncExecutor(self.state_handler.cancellable, self.wakeup)
|
||||
self.watchdog = patroni.watchdog
|
||||
@@ -173,27 +176,27 @@ class Ha(object):
|
||||
ret = self.global_config.primary_stop_timeout
|
||||
return ret if ret > 0 and self.is_synchronous_mode() else None
|
||||
|
||||
def is_paused(self):
|
||||
def is_paused(self) -> bool:
|
||||
""":returns: `True` if in maintenance mode."""
|
||||
return self.global_config.is_paused
|
||||
|
||||
def check_timeline(self):
|
||||
def check_timeline(self) -> bool:
|
||||
""":returns: `True` if should check whether the timeline is latest during the leader race."""
|
||||
return self.global_config.check_mode('check_timeline')
|
||||
|
||||
def is_standby_cluster(self):
|
||||
def is_standby_cluster(self) -> bool:
|
||||
""":returns: `True` if global configuration has a valid "standby_cluster" section."""
|
||||
return self.global_config.is_standby_cluster
|
||||
|
||||
def is_leader(self):
|
||||
def is_leader(self) -> bool:
|
||||
with self._is_leader_lock:
|
||||
return self._is_leader > time.time()
|
||||
|
||||
def set_is_leader(self, value):
|
||||
def set_is_leader(self, value: bool) -> None:
|
||||
with self._is_leader_lock:
|
||||
self._is_leader = time.time() + self.dcs.ttl if value else 0
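Keeping the leader flag as an expiry timestamp (now + ttl on success, 0 otherwise) means is_leader() silently becomes False once the lease would have expired, even if the main loop stalls. A few self-contained lines showing the pattern (numbers are arbitrary):

import time

class LeaderFlag:
    def __init__(self, ttl: int) -> None:
        self._ttl = ttl
        self._deadline = 0.0

    def set(self, value: bool) -> None:
        self._deadline = time.time() + self._ttl if value else 0   # remember *until when* we lead

    def get(self) -> bool:
        return self._deadline > time.time()

flag = LeaderFlag(ttl=30)
flag.set(True)
print(flag.get())   # True now, and it flips to False by itself after ~30 seconds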
def load_cluster_from_dcs(self):
|
||||
def load_cluster_from_dcs(self) -> None:
|
||||
cluster = self.dcs.get_cluster()
|
||||
|
||||
# We want to keep the state of cluster when it was healthy
|
||||
@@ -208,9 +211,9 @@ class Ha(object):
|
||||
if not self.has_lock(False):
|
||||
self.set_is_leader(False)
|
||||
|
||||
self._leader_timeline = None if cluster.is_unlocked() else cluster.leader.timeline
|
||||
self._leader_timeline = cluster.leader.timeline if cluster.leader else None
|
||||
|
||||
def acquire_lock(self):
|
||||
def acquire_lock(self) -> bool:
|
||||
try:
|
||||
ret = self.dcs.attempt_to_acquire_leader()
|
||||
except DCSError:
|
||||
@@ -221,14 +224,14 @@ class Ha(object):
|
||||
self.set_is_leader(ret)
|
||||
return ret
|
||||
|
||||
def _failsafe_config(self):
|
||||
def _failsafe_config(self) -> Optional[Dict[str, str]]:
|
||||
if self.is_failsafe_mode():
|
||||
ret = {m.name: m.api_url for m in self.cluster.members}
|
||||
ret = {m.name: m.api_url for m in self.cluster.members if m.api_url}
|
||||
if self.state_handler.name not in ret:
|
||||
ret[self.state_handler.name] = self.patroni.api.connection_string
|
||||
return ret
|
||||
|
||||
def update_lock(self, write_leader_optime=False):
|
||||
def update_lock(self, write_leader_optime: bool = False) -> bool:
|
||||
last_lsn = slots = None
|
||||
if write_leader_optime:
|
||||
try:
|
||||
@@ -248,13 +251,13 @@ class Ha(object):
|
||||
self.watchdog.keepalive()
|
||||
return ret
|
||||
|
||||
def has_lock(self, info=True):
|
||||
def has_lock(self, info: bool = True) -> bool:
|
||||
lock_owner = self.cluster.leader and self.cluster.leader.name
|
||||
if info:
|
||||
logger.info('Lock owner: %s; I am %s', lock_owner, self.state_handler.name)
|
||||
return lock_owner == self.state_handler.name
|
||||
|
||||
def get_effective_tags(self):
|
||||
def get_effective_tags(self) -> Dict[str, Any]:
|
||||
"""Return configuration tags merged with dynamically applied tags."""
|
||||
tags = self.patroni.tags.copy()
|
||||
# _disable_sync could be modified concurrently, but we don't care as attribute get and set are atomic.
|
||||
@@ -262,10 +265,10 @@ class Ha(object):
|
||||
tags['nosync'] = True
|
||||
return tags
|
||||
|
||||
def notify_citus_coordinator(self, event):
|
||||
def notify_citus_coordinator(self, event: str) -> None:
|
||||
if self.state_handler.citus_handler.is_worker():
|
||||
coordinator = self.dcs.get_citus_coordinator()
|
||||
if coordinator and coordinator.leader and coordinator.leader.conn_kwargs:
|
||||
if coordinator and coordinator.leader and coordinator.leader.conn_url:
|
||||
try:
|
||||
data = {'type': event,
|
||||
'group': self.state_handler.citus_handler.group(),
|
||||
@@ -278,9 +281,9 @@ class Ha(object):
|
||||
logger.warning('Request to Citus coordinator leader %s %s failed: %r',
|
||||
coordinator.leader.name, coordinator.leader.member.api_url, e)
|
||||
|
||||
def touch_member(self):
|
||||
def touch_member(self) -> bool:
|
||||
with self._member_state_lock:
|
||||
data = {
|
||||
data: Dict[str, Any] = {
|
||||
'conn_url': self.state_handler.connection_string,
|
||||
'api_url': self.patroni.api.connection_string,
|
||||
'state': self.state_handler.state,
|
||||
@@ -302,6 +305,7 @@ class Ha(object):
|
||||
if self._async_executor.scheduled_action in (None, 'promote') \
|
||||
and data['state'] in ['running', 'restarting', 'starting']:
|
||||
try:
|
||||
timeline: Optional[int]
|
||||
timeline, wal_position, pg_control_timeline = self.state_handler.timeline_wal_position()
|
||||
data['xlog_location'] = wal_position
|
||||
if not timeline: # try pg_stat_wal_receiver to get the timeline
|
||||
@@ -317,7 +321,7 @@ class Ha(object):
|
||||
if self.state_handler.role == 'standby_leader':
|
||||
timeline = pg_control_timeline or self.state_handler.pg_control_timeline()
|
||||
else:
|
||||
timeline = self.state_handler.replica_cached_timeline(self._leader_timeline)
|
||||
timeline = self.state_handler.replica_cached_timeline(self._leader_timeline) or 0
|
||||
if timeline:
|
||||
data['timeline'] = timeline
|
||||
except Exception:
|
||||
@@ -338,7 +342,7 @@ class Ha(object):
|
||||
self._last_state = new_state
|
||||
return ret
|
||||
|
||||
def clone(self, clone_member=None, msg='(without leader)'):
|
||||
def clone(self, clone_member: Union[Leader, Member, None] = None, msg: str = '(without leader)') -> Optional[bool]:
|
||||
if self.is_standby_cluster() and not isinstance(clone_member, RemoteMember):
|
||||
clone_member = self.get_remote_member(clone_member)
|
||||
|
||||
@@ -352,16 +356,10 @@ class Ha(object):
|
||||
logger.error('failed to bootstrap %s', msg)
|
||||
self.state_handler.remove_data_directory()
|
||||
|
||||
def bootstrap(self):
|
||||
if not self.cluster.is_unlocked(): # cluster already has leader
|
||||
clone_member = self.cluster.get_clone_member(self.state_handler.name)
|
||||
member_role = 'leader' if clone_member == self.cluster.leader else 'replica'
|
||||
msg = "from {0} '{1}'".format(member_role, clone_member.name)
|
||||
ret = self._async_executor.try_run_async('bootstrap {0}'.format(msg), self.clone, args=(clone_member, msg))
|
||||
return ret or 'trying to bootstrap {0}'.format(msg)
|
||||
|
||||
def bootstrap(self) -> str:
|
||||
# no initialize key and node is allowed to be primary and has 'bootstrap' section in a configuration file
|
||||
elif self.cluster.initialize is None and not self.patroni.nofailover and 'bootstrap' in self.patroni.config:
|
||||
if self.cluster.is_unlocked() and self.cluster.initialize is None\
|
||||
and not self.patroni.nofailover and 'bootstrap' in self.patroni.config:
|
||||
if self.dcs.initialize(create_new=True): # race for initialization
|
||||
self.state_handler.bootstrapping = True
|
||||
with self._async_response:
|
||||
@@ -376,17 +374,26 @@ class Ha(object):
|
||||
return ret or 'trying to bootstrap a new cluster'
|
||||
else:
|
||||
return 'failed to acquire initialize lock'
|
||||
else:
|
||||
create_replica_methods = self.global_config.get_standby_cluster_config().get('create_replica_methods', []) \
|
||||
if self.is_standby_cluster() else None
|
||||
can_bootstrap = self.state_handler.can_create_replica_without_replication_connection(create_replica_methods)
|
||||
concurrent_bootstrap = self.cluster.initialize == ""
|
||||
if can_bootstrap and not concurrent_bootstrap:
|
||||
msg = 'bootstrap (without leader)'
|
||||
return self._async_executor.try_run_async(msg, self.clone) or 'trying to ' + msg
|
||||
return 'waiting for {0}leader to bootstrap'.format('standby_' if self.is_standby_cluster() else '')
|
||||
|
||||
def bootstrap_standby_leader(self):
|
||||
clone_member = self.cluster.get_clone_member(self.state_handler.name)
|
||||
# cluster already has a leader, we can bootstrap from it or from one of replicas (if they allow)
|
||||
if not self.cluster.is_unlocked() and clone_member:
|
||||
member_role = 'leader' if clone_member == self.cluster.leader else 'replica'
|
||||
msg = "from {0} '{1}'".format(member_role, clone_member.name)
|
||||
ret = self._async_executor.try_run_async('bootstrap {0}'.format(msg), self.clone, args=(clone_member, msg))
|
||||
return ret or 'trying to bootstrap {0}'.format(msg)
|
||||
|
||||
# no leader, but configuration may allow replica creation using backup tools
create_replica_methods = self.global_config.get_standby_cluster_config().get('create_replica_methods', []) \
|
||||
if self.is_standby_cluster() else None
|
||||
can_bootstrap = self.state_handler.can_create_replica_without_replication_connection(create_replica_methods)
|
||||
concurrent_bootstrap = self.cluster.initialize == ""
|
||||
if can_bootstrap and not concurrent_bootstrap:
|
||||
msg = 'bootstrap (without leader)'
|
||||
return self._async_executor.try_run_async(msg, self.clone) or 'trying to ' + msg
|
||||
return 'waiting for {0}leader to bootstrap'.format('standby_' if self.is_standby_cluster() else '')
|
||||
|
||||
def bootstrap_standby_leader(self) -> Optional[bool]:
""" If we found 'standby' key in the configuration, we need to bootstrap
|
||||
not a real primary, but a 'standby leader', that will take base backup
|
||||
from a remote member and start follow it.
@@ -401,16 +408,16 @@ class Ha(object):
|
||||
|
||||
return result
|
||||
|
||||
def _handle_crash_recovery(self):
|
||||
def _handle_crash_recovery(self) -> Optional[str]:
|
||||
if not self._crash_recovery_executed and (self.cluster.is_unlocked() or self._rewind.can_rewind):
|
||||
self._crash_recovery_executed = True
|
||||
self._crash_recovery_started = time.time()
|
||||
msg = 'doing crash recovery in a single user mode'
|
||||
return self._async_executor.try_run_async(msg, self._rewind.ensure_clean_shutdown) or msg
|
||||
|
||||
def _handle_rewind_or_reinitialize(self):
|
||||
def _handle_rewind_or_reinitialize(self) -> Optional[str]:
|
||||
leader = self.get_remote_member() if self.is_standby_cluster() else self.cluster.leader
|
||||
if not self._rewind.rewind_or_reinitialize_needed_and_possible(leader):
|
||||
if not self._rewind.rewind_or_reinitialize_needed_and_possible(leader) or not leader:
|
||||
return None
|
||||
|
||||
if self._rewind.can_rewind:
|
||||
@@ -428,7 +435,7 @@ class Ha(object):
|
||||
msg = 'reinitializing due to diverged timelines'
|
||||
return self._async_executor.try_run_async(msg, self._do_reinitialize, args=(self.cluster,)) or msg
|
||||
|
||||
def recover(self):
|
||||
def recover(self) -> str:
|
||||
# Postgres is not running and we will restart in standby mode. Watchdog is not needed until we promote.
|
||||
self.watchdog.disable()
|
||||
|
||||
@@ -454,7 +461,10 @@ class Ha(object):
|
||||
self.load_cluster_from_dcs()
|
||||
|
||||
role = 'replica'
|
||||
if self.is_standby_cluster() or not self.has_lock():
|
||||
if self.has_lock() and not self.is_standby_cluster():
|
||||
msg = "starting as readonly because i had the session lock"
|
||||
node_to_follow = None
|
||||
else:
|
||||
if not self._rewind.executed:
|
||||
self._rewind.trigger_check_diverged_lsn()
|
||||
msg = self._handle_rewind_or_reinitialize()
|
||||
@@ -474,16 +484,13 @@ class Ha(object):
|
||||
|
||||
if self.is_synchronous_mode():
|
||||
self.state_handler.sync_handler.set_synchronous_standby_names(CaseInsensitiveSet())
|
||||
elif self.has_lock():
|
||||
msg = "starting as readonly because i had the session lock"
|
||||
node_to_follow = None
|
||||
|
||||
if self._async_executor.try_run_async('restarting after failure', self.state_handler.follow,
|
||||
args=(node_to_follow, role, timeout)) is None:
|
||||
self.recovering = True
|
||||
return msg
|
||||
|
||||
def _get_node_to_follow(self, cluster):
|
||||
def _get_node_to_follow(self, cluster: Cluster) -> Union[Leader, Member, None]:
|
||||
# determine the node to follow. If replicatefrom tag is set,
|
||||
# try to follow the node mentioned there, otherwise, follow the leader.
|
||||
if self.is_standby_cluster() and (self.cluster.is_unlocked() or self.has_lock(False)):
|
||||
@@ -506,7 +513,7 @@ class Ha(object):
|
||||
|
||||
return node_to_follow
|
||||
|
||||
def follow(self, demote_reason, follow_reason, refresh=True):
|
||||
def follow(self, demote_reason: str, follow_reason: str, refresh: bool = True) -> str:
|
||||
if refresh:
|
||||
self.load_cluster_from_dcs()
|
||||
|
||||
@@ -565,7 +572,7 @@ class Ha(object):
|
||||
""":returns: `True` if synchronous replication is requested."""
|
||||
return self.global_config.is_synchronous_mode
|
||||
|
||||
def is_failsafe_mode(self):
|
||||
def is_failsafe_mode(self) -> bool:
|
||||
""":returns: `True` if failsafe_mode is enabled in global configuration."""
|
||||
return self.global_config.check_mode('failsafe_mode')
|
||||
|
||||
@@ -625,10 +632,10 @@ class Ha(object):
|
||||
|
||||
def is_sync_standby(self, cluster: Cluster) -> bool:
|
||||
""":returns: `True` if the current node is a synchronous standby."""
|
||||
return cluster.leader and cluster.sync.leader_matches(cluster.leader.name) \
|
||||
return bool(cluster.leader) and cluster.sync.leader_matches(cluster.leader.name) \
|
||||
and cluster.sync.matches(self.state_handler.name)
|
||||
|
||||
def while_not_sync_standby(self, func):
|
||||
def while_not_sync_standby(self, func: Callable[..., Any]) -> Any:
|
||||
"""Runs specified action while trying to make sure that the node is not assigned synchronous standby status.
|
||||
|
||||
Tags us as not allowed to be a sync standby as we are going to go away, if we currently are wait for
|
||||
@@ -665,32 +672,32 @@ class Ha(object):
|
||||
with self._member_state_lock:
|
||||
self._disable_sync -= 1
|
||||
|
||||
def update_cluster_history(self):
|
||||
def update_cluster_history(self) -> None:
|
||||
primary_timeline = self.state_handler.get_primary_timeline()
|
||||
cluster_history = self.cluster.history and self.cluster.history.lines
|
||||
cluster_history = self.cluster.history.lines if self.cluster.history else []
|
||||
if primary_timeline == 1:
|
||||
if cluster_history:
|
||||
self.dcs.set_history_value('[]')
|
||||
elif not cluster_history or cluster_history[-1][0] != primary_timeline - 1 or len(cluster_history[-1]) != 5:
|
||||
cluster_history = {line[0]: line for line in cluster_history or []}
|
||||
history = self.state_handler.get_history(primary_timeline)
|
||||
if history and self.cluster.config:
|
||||
cluster_history = {line[0]: line for line in cluster_history}
|
||||
history: List[List[Any]] = list(map(list, self.state_handler.get_history(primary_timeline)))
|
||||
if self.cluster.config:
|
||||
history = history[-self.cluster.config.max_timelines_history:]
|
||||
for line in history:
|
||||
# enrich current history with promotion timestamps stored in DCS
|
||||
if len(line) == 3 and line[0] in cluster_history \
|
||||
and len(cluster_history[line[0]]) >= 4 \
|
||||
and cluster_history[line[0]][1] == line[1]:
|
||||
line.append(cluster_history[line[0]][3])
|
||||
if len(cluster_history[line[0]]) == 5:
|
||||
line.append(cluster_history[line[0]][4])
|
||||
for line in history:
|
||||
# enrich current history with promotion timestamps stored in DCS
|
||||
cluster_history_line = list(cluster_history.get(line[0], []))
|
||||
if len(line) == 3 and len(cluster_history_line) >= 4 and cluster_history_line[1] == line[1]:
|
||||
line.append(cluster_history_line[3])
|
||||
if len(cluster_history_line) == 5:
|
||||
line.append(cluster_history_line[4])
|
||||
if history:
|
||||
self.dcs.set_history_value(json.dumps(history, separators=(',', ':')))
|
||||
|
||||
def enforce_follow_remote_member(self, message):
|
||||
def enforce_follow_remote_member(self, message: str) -> str:
|
||||
demote_reason = 'cannot be a real primary in standby cluster'
|
||||
return self.follow(demote_reason, message)
|
||||
|
||||
def enforce_primary_role(self, message, promote_message):
|
||||
def enforce_primary_role(self, message: str, promote_message: str) -> str:
|
||||
"""
|
||||
Ensure the node that has won the race for the leader key meets criteria
|
||||
for promoting its PG server to the 'primary' role.
|
||||
@@ -749,7 +756,7 @@ class Ha(object):
|
||||
before_promote, on_success))
|
||||
return promote_message
|
||||
|
||||
def fetch_node_status(self, member):
|
||||
def fetch_node_status(self, member: Member) -> _MemberStatus:
"""This function perform http get request on member.api_url and fetches its status
:returns: `_MemberStatus` object
|
||||
"""
|
||||
@@ -763,36 +770,36 @@ class Ha(object):
|
||||
logger.warning("Request failed to %s: GET %s (%s)", member.name, member.api_url, e)
|
||||
return _MemberStatus.unknown(member)
|
||||
|
||||
def fetch_nodes_statuses(self, members):
|
||||
def fetch_nodes_statuses(self, members: List[Member]) -> List[_MemberStatus]:
|
||||
pool = ThreadPool(len(members))
|
||||
results = pool.map(self.fetch_node_status, members) # Run API calls on members in parallel
|
||||
pool.close()
|
||||
pool.join()
|
||||
return results
|
||||
|
||||
def update_failsafe(self, data):
|
||||
def update_failsafe(self, data: Dict[str, Any]) -> Optional[str]:
|
||||
if self.state_handler.state == 'running' and self.state_handler.role in ('master', 'primary'):
|
||||
return 'Running as a leader'
|
||||
self._failsafe.update(data)
|
||||
|
||||
def failsafe_is_active(self):
|
||||
def failsafe_is_active(self) -> bool:
|
||||
return self._failsafe.is_active()
|
||||
|
||||
def call_failsafe_member(self, data, member):
|
||||
def call_failsafe_member(self, data: Dict[str, Any], member: Member) -> bool:
|
||||
try:
|
||||
response = self.patroni.request(member, 'post', 'failsafe', data, timeout=2, retries=1)
|
||||
data = response.data.decode('utf-8')
|
||||
logger.info('Got response from %s %s: %s', member.name, member.api_url, data)
|
||||
return response.status == 200 and data == 'Accepted'
|
||||
response_data = response.data.decode('utf-8')
|
||||
logger.info('Got response from %s %s: %s', member.name, member.api_url, response_data)
|
||||
return response.status == 200 and response_data == 'Accepted'
|
||||
except Exception as e:
|
||||
logger.warning("Request failed to %s: POST %s (%s)", member.name, member.api_url, e)
|
||||
return False
|
||||
|
||||
def check_failsafe_topology(self):
|
||||
def check_failsafe_topology(self) -> bool:
|
||||
failsafe = self.dcs.failsafe
|
||||
if not isinstance(failsafe, dict) or self.state_handler.name not in failsafe:
|
||||
return False
|
||||
data = {
|
||||
data: Dict[str, Any] = {
|
||||
'name': self.state_handler.name,
|
||||
'conn_url': self.state_handler.connection_string,
|
||||
'api_url': self.patroni.api.connection_string,
|
||||
@@ -813,7 +820,7 @@ class Ha(object):
|
||||
pool.join()
|
||||
return all(results)
|
||||
|
||||
def is_lagging(self, wal_position):
|
||||
def is_lagging(self, wal_position: int) -> bool:
"""Returns if instance with an wal should consider itself unhealthy to be promoted due to replication lag.
:param wal_position: Current wal position.
|
||||
@@ -822,7 +829,7 @@ class Ha(object):
|
||||
lag = (self.cluster.last_lsn or 0) - wal_position
|
||||
return lag > self.global_config.maximum_lag_on_failover
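A quick numeric check of that rule with invented LSNs: the node only considers itself lagging when it is more than maximum_lag_on_failover bytes behind the last position the leader published to the DCS (assuming the common 1 MiB default).

def is_lagging(cluster_last_lsn, my_wal_position, maximum_lag_on_failover=1048576):
    lag = (cluster_last_lsn or 0) - my_wal_position
    return lag > maximum_lag_on_failover

print(is_lagging(300_000_000, 299_500_000))  # False: 500 kB behind, within the 1 MiB default
print(is_lagging(300_000_000, 298_000_000))  # True: 2 MB behind, too far to promote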
def _is_healthiest_node(self, members, check_replication_lag=True):
|
||||
def _is_healthiest_node(self, members: Collection[Member], check_replication_lag: bool = True) -> bool:
"""This method tries to determine whether I am healthy enough to became a new leader candidate or not."""
my_wal_position = self.state_handler.last_operation()
|
||||
@@ -833,6 +840,9 @@ class Ha(object):
|
||||
if not self.is_standby_cluster() and self.check_timeline():
|
||||
cluster_timeline = self.cluster.timeline
|
||||
my_timeline = self.state_handler.replica_cached_timeline(cluster_timeline)
|
||||
if my_timeline is None:
|
||||
logger.info('Can not figure out my timeline')
|
||||
return False
|
||||
if my_timeline < cluster_timeline:
|
||||
logger.info('My timeline %s is behind last known cluster timeline %s', my_timeline, cluster_timeline)
|
||||
return False
|
||||
@@ -843,7 +853,7 @@ class Ha(object):
|
||||
if members:
|
||||
for st in self.fetch_nodes_statuses(members):
|
||||
if st.failover_limitation() is None:
|
||||
if not st.in_recovery:
|
||||
if st.in_recovery is False:
|
||||
logger.warning('Primary (%s) is still alive', st.member.name)
|
||||
return False
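The change from "not st.in_recovery" to "st.in_recovery is False" lines up with in_recovery now being Optional[bool]: None stands for "could not tell", and only an explicit False should be read as "a primary is still alive". The difference in one small truth table:

for in_recovery in (True, False, None):
    old = not in_recovery          # also fires for the unknown (None) case
    new = in_recovery is False     # fires only for an explicit "not in recovery"
    print(in_recovery, old, new)
# True  False False
# False True  True
# None  True  False   <- where the two spellings differ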
if my_wal_position < st.wal_position:
|
||||
@@ -875,8 +885,6 @@ class Ha(object):
|
||||
not_allowed_reason = st.failover_limitation()
|
||||
if not_allowed_reason:
|
||||
logger.info('Member %s is %s', st.member.name, not_allowed_reason)
|
||||
elif not isinstance(st.wal_position, int):
|
||||
logger.info('Member %s does not report wal_position', st.member.name)
|
||||
elif cluster_lsn and st.wal_position < cluster_lsn or\
|
||||
not cluster_lsn and self.is_lagging(st.wal_position):
|
||||
logger.info('Member %s exceeds maximum replication lag', st.member.name)
|
||||
@@ -889,13 +897,15 @@ class Ha(object):
|
||||
logger.warning('manual failover: members list is empty')
|
||||
return ret
|
||||
|
||||
def manual_failover_process_no_leader(self) -> Union[bool, None]:
|
||||
def manual_failover_process_no_leader(self) -> Optional[bool]:
|
||||
"""Handles manual failover/switchover when the old leader already stepped down.
|
||||
|
||||
:returns: - `True` if the current node is the best candidate to become the new leader
|
||||
- `None` if the current node is running as a primary and requested candidate doesn't exist
|
||||
"""
|
||||
failover = self.cluster.failover
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
assert failover is not None
|
||||
if failover.candidate: # manual failover to specific member
|
||||
if failover.candidate == self.state_handler.name: # manual failover to me
|
||||
return True
|
||||
@@ -905,7 +915,7 @@ class Ha(object):
|
||||
if not self.cluster.get_member(failover.candidate, fallback_to_leader=False)\
|
||||
and self.state_handler.is_leader():
|
||||
logger.warning("manual failover: removing failover key because failover candidate is not running")
|
||||
self.dcs.manual_failover('', '', index=self.cluster.failover.index)
|
||||
self.dcs.manual_failover('', '', index=failover.index)
|
||||
return None
|
||||
return False
|
||||
|
||||
@@ -916,7 +926,7 @@ class Ha(object):
|
||||
|
||||
# find specific node and check that it is healthy
|
||||
member = self.cluster.get_member(failover.candidate, fallback_to_leader=False)
|
||||
if member:
|
||||
if isinstance(member, Member):
|
||||
st = self.fetch_node_status(member)
|
||||
not_allowed_reason = st.failover_limitation()
|
||||
if not_allowed_reason is None: # node is healthy
|
||||
@@ -981,7 +991,7 @@ class Ha(object):
|
||||
if self.is_synchronous_mode() and self.cluster.failover.leader and \
|
||||
not self.cluster.sync.is_empty and not self.cluster.sync.matches(self.state_handler.name, True):
|
||||
return False
|
||||
return self.manual_failover_process_no_leader()
|
||||
return self.manual_failover_process_no_leader() or False
|
||||
|
||||
if not self.watchdog.is_healthy:
|
||||
logger.warning('Watchdog device is not usable')
|
||||
@@ -1011,17 +1021,17 @@ class Ha(object):
|
||||
|
||||
return self._is_healthiest_node(members.values())
|
||||
|
||||
def _delete_leader(self, last_lsn=None):
|
||||
def _delete_leader(self, last_lsn: Optional[int] = None) -> None:
|
||||
self.set_is_leader(False)
|
||||
self.dcs.delete_leader(last_lsn)
|
||||
self.dcs.reset_cluster()
|
||||
|
||||
def release_leader_key_voluntarily(self, last_lsn=None):
|
||||
def release_leader_key_voluntarily(self, last_lsn: Optional[int] = None) -> None:
|
||||
self._delete_leader(last_lsn)
|
||||
self.touch_member()
|
||||
logger.info("Leader key released")
|
||||
|
||||
def demote(self, mode):
|
||||
def demote(self, mode: str) -> Optional[bool]:
|
||||
"""Demote PostgreSQL running as primary.
|
||||
|
||||
:param mode: One of offline, graceful or immediate.
|
||||
@@ -1046,7 +1056,7 @@ class Ha(object):
|
||||
|
||||
status = {'released': False}
|
||||
|
||||
def on_shutdown(checkpoint_location):
|
||||
def on_shutdown(checkpoint_location: int) -> None:
|
||||
# Postmaster is still running, but pg_control already reports clean "shut down".
|
||||
# It could happen if Postgres is still archiving the backlog of WAL files.
|
||||
# If we know that there are replicas that received the shutdown checkpoint
|
||||
@@ -1057,13 +1067,13 @@ class Ha(object):
|
||||
self.release_leader_key_voluntarily(checkpoint_location)
|
||||
status['released'] = True
|
||||
|
||||
def before_shutdown():
|
||||
def before_shutdown() -> None:
|
||||
if self.state_handler.citus_handler.is_coordinator():
|
||||
self.state_handler.citus_handler.on_demote()
|
||||
else:
|
||||
self.notify_citus_coordinator('before_demote')
|
||||
|
||||
self.state_handler.stop(mode_control['stop'], checkpoint=mode_control['checkpoint'],
|
||||
self.state_handler.stop(str(mode_control['stop']), checkpoint=bool(mode_control['checkpoint']),
|
||||
on_safepoint=self.watchdog.disable if self.watchdog.is_running else None,
|
||||
on_shutdown=on_shutdown if mode_control['release'] else None,
|
||||
before_shutdown=before_shutdown if mode == 'graceful' else None,
|
||||
@@ -1099,7 +1109,8 @@ class Ha(object):
|
||||
return False # do not start postgres, but run pg_rewind on the next iteration
|
||||
self.state_handler.follow(node_to_follow)
|
||||
|
||||
def should_run_scheduled_action(self, action_name, scheduled_at, cleanup_fn):
|
||||
def should_run_scheduled_action(self, action_name: str, scheduled_at: Optional[datetime.datetime],
|
||||
cleanup_fn: Callable[..., Any]) -> bool:
|
||||
if scheduled_at and not self.is_paused():
|
||||
# If the scheduled action is in the far future, we shouldn't do anything and just return.
|
||||
# If the scheduled action is in the past, we consider the value to be stale and we remove
|
||||
@@ -1133,7 +1144,7 @@ class Ha(object):
|
||||
cleanup_fn()
|
||||
return False
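# Editor's sketch (not part of the commit) of the reasoning in the comments above about
# scheduled actions: one far in the future is left alone, one in the past is stale and
# should be cleaned up. The helper name `due_now` and the 30-second slack are assumptions
# used only for illustration, not Patroni's actual constants.
import datetime

def due_now(scheduled_at, now, slack=datetime.timedelta(seconds=30)):
    if scheduled_at > now + slack:
        return False            # too far in the future: re-check on a later HA loop
    if scheduled_at < now - slack:
        return None             # stale value: the caller removes the scheduled key
    return True                 # inside the window: run the action now

now = datetime.datetime(2023, 5, 1, 12, 0, 0)
assert due_now(now + datetime.timedelta(hours=1), now) is False
assert due_now(now - datetime.timedelta(hours=1), now) is None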
|
||||
|
||||
def process_manual_failover_from_leader(self):
|
||||
def process_manual_failover_from_leader(self) -> Optional[str]:
|
||||
"""Checks if manual failover is requested and takes action if appropriate.
|
||||
|
||||
Cleans up failover key if failover conditions are not matched.
|
||||
@@ -1177,7 +1188,7 @@ class Ha(object):
|
||||
logger.info('Cleaning up failover key')
|
||||
self.dcs.manual_failover('', '', index=failover.index)
|
||||
|
||||
def process_unhealthy_cluster(self):
|
||||
def process_unhealthy_cluster(self) -> str:
|
||||
"""Cluster has no leader key"""
|
||||
|
||||
if self.is_healthiest_node():
|
||||
@@ -1219,7 +1230,7 @@ class Ha(object):
|
||||
return self.follow('demoting self because i am not the healthiest node',
|
||||
'following a different leader because i am not the healthiest node')
|
||||
|
||||
def process_healthy_cluster(self):
|
||||
def process_healthy_cluster(self) -> str:
|
||||
if self.has_lock():
|
||||
if self.is_paused() and not self.state_handler.is_leader():
|
||||
if self.cluster.failover and self.cluster.failover.candidate == self.state_handler.name:
|
||||
@@ -1271,7 +1282,7 @@ class Ha(object):
|
||||
'no action. I am ({0}), a secondary, and following a leader ({1})'.format(
|
||||
self.state_handler.name, lock_owner), refresh=False)
|
||||
|
||||
def evaluate_scheduled_restart(self):
|
||||
def evaluate_scheduled_restart(self) -> Optional[str]:
|
||||
if self._async_executor.busy: # Restart already in progress
|
||||
return None
|
||||
|
||||
@@ -1297,7 +1308,7 @@ class Ha(object):
|
||||
finally:
|
||||
self.delete_future_restart()
|
||||
|
||||
def restart_matches(self, role, postgres_version, pending_restart):
|
||||
def restart_matches(self, role: Optional[str], postgres_version: Optional[str], pending_restart: bool) -> bool:
|
||||
reason_to_cancel = ""
|
||||
# checking the restart filters here seems to be less ugly than moving them into the
|
||||
# run_scheduled_action.
|
||||
@@ -1316,7 +1327,7 @@ class Ha(object):
|
||||
logger.info("not proceeding with the restart: %s", reason_to_cancel)
|
||||
return False
|
||||
|
||||
def schedule_future_restart(self, restart_data):
|
||||
def schedule_future_restart(self, restart_data: Dict[str, Any]) -> bool:
|
||||
with self._async_executor:
|
||||
restart_data['postmaster_start_time'] = self.state_handler.postmaster_start_time()
|
||||
if not self.patroni.scheduled_restart:
|
||||
@@ -1325,7 +1336,7 @@ class Ha(object):
|
||||
return True
|
||||
return False
|
||||
|
||||
def delete_future_restart(self):
|
||||
def delete_future_restart(self) -> bool:
|
||||
ret = False
|
||||
with self._async_executor:
|
||||
if self.patroni.scheduled_restart:
|
||||
@@ -1334,14 +1345,13 @@ class Ha(object):
|
||||
ret = True
|
||||
return ret
|
||||
|
||||
def future_restart_scheduled(self):
|
||||
return self.patroni.scheduled_restart.copy()\
|
||||
if (self.patroni.scheduled_restart and isinstance(self.patroni.scheduled_restart, dict)) else None
|
||||
def future_restart_scheduled(self) -> Dict[str, Any]:
|
||||
return self.patroni.scheduled_restart.copy()
|
||||
|
||||
def restart_scheduled(self):
|
||||
def restart_scheduled(self) -> bool:
|
||||
return self._async_executor.scheduled_action == 'restart'
|
||||
|
||||
def restart(self, restart_data, run_async=False):
|
||||
def restart(self, restart_data: Dict[str, Any], run_async: bool = False) -> Tuple[bool, str]:
|
||||
""" conditional and unconditional restart """
|
||||
assert isinstance(restart_data, dict)
|
||||
|
||||
@@ -1365,10 +1375,10 @@ class Ha(object):
|
||||
timeout = restart_data.get('timeout', self.global_config.primary_start_timeout)
|
||||
self.set_start_timeout(timeout)
|
||||
|
||||
def before_shutdown():
|
||||
def before_shutdown() -> None:
|
||||
self.notify_citus_coordinator('before_demote')
|
||||
|
||||
def after_start():
|
||||
def after_start() -> None:
|
||||
self.notify_citus_coordinator('after_promote')
|
||||
|
||||
# For non async cases we want to wait for restart to complete or timeout before returning.
|
||||
@@ -1390,16 +1400,17 @@ class Ha(object):
|
||||
else:
|
||||
return (False, 'restart failed')
|
||||
|
||||
def _do_reinitialize(self, cluster):
|
||||
def _do_reinitialize(self, cluster: Cluster) -> Optional[bool]:
|
||||
self.state_handler.stop('immediate', stop_timeout=self.patroni.config['retry_timeout'])
|
||||
# Commented redundant data directory cleanup here
|
||||
# self.state_handler.remove_data_directory()
|
||||
|
||||
clone_member = self.cluster.get_clone_member(self.state_handler.name)
|
||||
member_role = 'leader' if clone_member == self.cluster.leader else 'replica'
|
||||
return self.clone(clone_member, "from {0} '{1}'".format(member_role, clone_member.name))
|
||||
clone_member = cluster.get_clone_member(self.state_handler.name)
|
||||
if clone_member:
|
||||
member_role = 'leader' if clone_member == cluster.leader else 'replica'
|
||||
return self.clone(clone_member, "from {0} '{1}'".format(member_role, clone_member.name))
|
||||
|
||||
def reinitialize(self, force=False):
|
||||
def reinitialize(self, force: bool = False) -> Optional[str]:
|
||||
with self._async_executor:
|
||||
self.load_cluster_from_dcs()
|
||||
|
||||
@@ -1409,6 +1420,8 @@ class Ha(object):
|
||||
if self.has_lock(False):
|
||||
return 'I am the leader, can not reinitialize'
|
||||
|
||||
cluster = self.cluster
|
||||
|
||||
if force:
|
||||
self._async_executor.cancel()
|
||||
|
||||
@@ -1417,9 +1430,9 @@ class Ha(object):
|
||||
if action is not None:
|
||||
return '{0} already in progress'.format(action)
|
||||
|
||||
self._async_executor.run_async(self._do_reinitialize, args=(self.cluster, ))
|
||||
self._async_executor.run_async(self._do_reinitialize, args=(cluster, ))
|
||||
|
||||
def handle_long_action_in_progress(self):
|
||||
def handle_long_action_in_progress(self) -> str:
|
||||
"""
|
||||
Figure out what to do with the task AsyncExecutor is performing.
|
||||
"""
|
||||
@@ -1432,7 +1445,7 @@ class Ha(object):
|
||||
self.demote('immediate')
|
||||
return 'terminated crash recovery because of startup timeout'
|
||||
|
||||
return 'updated leader lock during ' + self._async_executor.scheduled_action
|
||||
return 'updated leader lock during {0}'.format(self._async_executor.scheduled_action)
|
||||
elif not self.state_handler.bootstrapping and not self.is_paused():
|
||||
# Don't have lock, make sure we are not promoting or starting up a primary in the background
|
||||
if self._async_executor.scheduled_action == 'promote':
|
||||
@@ -1443,28 +1456,28 @@ class Ha(object):
|
||||
return 'lost leader before promote'
|
||||
|
||||
if self.state_handler.role in ('master', 'primary'):
|
||||
logger.info("Demoting primary during " + self._async_executor.scheduled_action)
|
||||
logger.info('Demoting primary during %s', self._async_executor.scheduled_action)
|
||||
if self._async_executor.scheduled_action == 'restart':
|
||||
# Restart needs a special interlocking cancel because postmaster may be just started in a
|
||||
# background thread and has not even written a pid file yet.
|
||||
with self._async_executor.critical_task as task:
|
||||
if not task.cancel():
|
||||
if not task.cancel() and isinstance(task.result, PostmasterProcess):
|
||||
self.state_handler.terminate_starting_postmaster(postmaster=task.result)
|
||||
self.demote('immediate-nolock')
|
||||
return 'lost leader lock during ' + self._async_executor.scheduled_action
|
||||
return 'lost leader lock during {0}'.format(self._async_executor.scheduled_action)
|
||||
if self.cluster.is_unlocked():
|
||||
logger.info('not healthy enough for leader race')
|
||||
|
||||
return self._async_executor.scheduled_action + ' in progress'
|
||||
return '{0} in progress'.format(self._async_executor.scheduled_action)
|
||||
|
||||
@staticmethod
def sysid_valid(sysid):
def sysid_valid(sysid: Optional[str]) -> bool:
# sysid does tv_sec << 32, where tv_sec is the number of seconds since 1970,
# so even 1 << 32 would have 10 digits.
sysid = str(sysid)
return len(sysid) >= 10 and sysid.isdigit()
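# Editor's illustration (not part of the commit) of why a valid "Database system
# identifier" always has at least 10 decimal digits: per the comment above it is derived
# from the epoch seconds shifted left by 32 bits, and even 1 << 32 is already 4294967296.
# The variable `sysid_like` is a simplified approximation, not how initdb really builds it.
import time
assert len(str(1 << 32)) == 10
sysid_like = int(time.time()) << 32
assert len(str(sysid_like)) >= 10 and str(sysid_like).isdigit()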
||||
|
||||
def post_recover(self):
|
||||
def post_recover(self) -> Optional[str]:
|
||||
if not self.state_handler.is_running():
|
||||
self.watchdog.disable()
|
||||
if self.has_lock():
|
||||
@@ -1475,14 +1488,14 @@ class Ha(object):
|
||||
return 'failed to start postgres'
|
||||
return None
|
||||
|
||||
def cancel_initialization(self):
|
||||
def cancel_initialization(self) -> None:
|
||||
logger.info('removing initialize key after failed attempt to bootstrap the cluster')
|
||||
self.dcs.cancel_initialization()
|
||||
self.state_handler.stop('immediate', stop_timeout=self.patroni.config['retry_timeout'])
|
||||
self.state_handler.move_data_directory()
|
||||
raise PatroniFatalException('Failed to bootstrap cluster')
|
||||
|
||||
def post_bootstrap(self):
|
||||
def post_bootstrap(self) -> str:
|
||||
with self._async_response:
|
||||
result = self._async_response.result
|
||||
# bootstrap has failed if postgres is not running
|
||||
@@ -1513,7 +1526,7 @@ class Ha(object):
|
||||
|
||||
return 'initialized a new cluster'
|
||||
|
||||
def handle_starting_instance(self):
|
||||
def handle_starting_instance(self) -> Optional[str]:
|
||||
"""Starting up PostgreSQL may take a long time. In case we are the leader we may want to
|
||||
fail over to."""
|
||||
|
||||
@@ -1552,13 +1565,13 @@ class Ha(object):
|
||||
logger.info("Still starting up as a standby.")
|
||||
return None
|
||||
|
||||
def set_start_timeout(self, value):
|
||||
def set_start_timeout(self, value: Optional[int]) -> None:
|
||||
"""Sets timeout for starting as primary before eligible for failover.
|
||||
|
||||
Must be called when async_executor is busy or in the main thread."""
|
||||
self._start_timeout = value
|
||||
|
||||
def _run_cycle(self):
|
||||
def _run_cycle(self) -> str:
|
||||
dcs_failed = False
|
||||
try:
|
||||
try:
|
||||
@@ -1619,6 +1632,8 @@ class Ha(object):
|
||||
return 'started as a secondary'
|
||||
|
||||
# is data directory empty?
|
||||
data_directory_error = ''
|
||||
data_directory_is_empty = None
|
||||
try:
|
||||
data_directory_is_empty = self.state_handler.data_directory_empty()
|
||||
data_directory_is_accessible = True
|
||||
@@ -1727,7 +1742,7 @@ class Ha(object):
|
||||
self._failsafe.set_is_active(0)
|
||||
self.touch_member()
|
||||
|
||||
def _handle_dcs_error(self):
|
||||
def _handle_dcs_error(self) -> str:
|
||||
if not self.is_paused() and self.state_handler.is_running():
|
||||
if self.state_handler.is_leader():
|
||||
if self.is_failsafe_mode() and self.check_failsafe_topology():
|
||||
@@ -1746,13 +1761,13 @@ class Ha(object):
|
||||
self._sync_replication_slots(True)
|
||||
return 'DCS is not accessible'
|
||||
|
||||
def _sync_replication_slots(self, dcs_failed):
|
||||
def _sync_replication_slots(self, dcs_failed: bool) -> List[str]:
|
||||
"""Handles replication slots.
|
||||
|
||||
:param dcs_failed: bool, indicates that communication with DCS failed (get_cluster() or update_leader())
|
||||
:returns: list[str], replication slots names that should be copied from the primary"""
|
||||
|
||||
slots = []
|
||||
slots: List[str] = []
|
||||
|
||||
# If dcs_failed we don't want to touch replication slots on a leader or replicas if failsafe_mode isn't enabled.
|
||||
if not self.cluster or dcs_failed and (self.is_leader() or not self.is_failsafe_mode()):
|
||||
@@ -1772,7 +1787,7 @@ class Ha(object):
|
||||
# Don't copy replication slots if failsafe_mode is active
|
||||
return [] if self.failsafe_is_active() else slots
|
||||
|
||||
def run_cycle(self):
|
||||
def run_cycle(self) -> str:
|
||||
with self._async_executor:
|
||||
try:
|
||||
info = self._run_cycle()
|
||||
@@ -1783,7 +1798,7 @@ class Ha(object):
|
||||
logger.exception('Unexpected exception')
|
||||
return 'Unexpected exception raised, please report it as a BUG'
|
||||
|
||||
def shutdown(self):
|
||||
def shutdown(self) -> None:
|
||||
if self.is_paused():
|
||||
logger.info('Leader key is not deleted and Postgresql is not stopped due to paused state')
|
||||
self.watchdog.disable()
|
||||
@@ -1796,7 +1811,7 @@ class Ha(object):
|
||||
|
||||
status = {'deleted': False}
|
||||
|
||||
def _on_shutdown(checkpoint_location):
|
||||
def _on_shutdown(checkpoint_location: int) -> None:
|
||||
if self.is_leader():
|
||||
# Postmaster is still running, but pg_control already reports clean "shut down".
|
||||
# It could happen if Postgres is still archiving the backlog of WAL files.
|
||||
@@ -1808,7 +1823,7 @@ class Ha(object):
|
||||
else:
|
||||
self.dcs.write_leader_optime(checkpoint_location)
|
||||
|
||||
def _before_shutdown():
|
||||
def _before_shutdown() -> None:
|
||||
self.notify_citus_coordinator('before_demote')
|
||||
|
||||
on_shutdown = _on_shutdown if self.is_leader() else None
|
||||
@@ -1829,36 +1844,36 @@ class Ha(object):
|
||||
logger.error("PostgreSQL shutdown failed, leader key not removed.%s",
|
||||
(" Leaving watchdog running." if self.watchdog.is_running else ""))
|
||||
|
||||
def watch(self, timeout):
|
||||
def watch(self, timeout: float) -> bool:
|
||||
# watch on leader key changes if the postgres is running and leader is known and current node is not lock owner
|
||||
if self._async_executor.busy or not self.cluster or self.cluster.is_unlocked() or self.has_lock(False):
|
||||
leader_index = None
|
||||
else:
|
||||
leader_index = self.cluster.leader.index
|
||||
leader_index = self.cluster.leader.index if self.cluster.leader else None
|
||||
|
||||
return self.dcs.watch(leader_index, timeout)
|
||||
|
||||
def wakeup(self):
|
||||
def wakeup(self) -> None:
|
||||
"""Call of this method will trigger the next run of HA loop if there is
|
||||
no "active" leader watch request in progress.
|
||||
This usually happens on the leader or if the node is running async action"""
|
||||
self.dcs.event.set()
|
||||
|
||||
def get_remote_member(self, member=None):
|
||||
def get_remote_member(self, member: Union[Leader, Member, None] = None) -> RemoteMember:
|
||||
""" In case of standby cluster this will tel us from which remote
|
||||
member to stream. Config can be both patroni config or
|
||||
cluster.config.data
|
||||
"""
|
||||
data: Dict[str, Any] = {}
|
||||
cluster_params = self.global_config.get_standby_cluster_config()
|
||||
|
||||
if cluster_params:
|
||||
name = member.name if member else 'remote_member:{}'.format(uuid.uuid1())
|
||||
|
||||
data = {k: v for k, v in cluster_params.items() if k in RemoteMember.allowed_keys()}
|
||||
data.update({k: v for k, v in cluster_params.items() if k in RemoteMember.allowed_keys()})
|
||||
data['no_replication_slot'] = 'primary_slot_name' not in cluster_params
|
||||
conn_kwargs = member.conn_kwargs() if member else \
|
||||
{k: cluster_params[k] for k in ('host', 'port') if k in cluster_params}
|
||||
if conn_kwargs:
|
||||
data['conn_kwargs'] = conn_kwargs
|
||||
|
||||
return RemoteMember(name, data)
|
||||
name = member.name if member else 'remote_member:{}'.format(uuid.uuid1())
|
||||
return RemoteMember(name, data)
|
||||
|
||||
@@ -13,7 +13,7 @@ from patroni.utils import deep_compare
|
||||
from queue import Queue, Full
|
||||
from threading import Lock, Thread
|
||||
|
||||
from typing import Any, Dict
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@@ -68,7 +68,7 @@ class QueueHandler(logging.Handler):
|
||||
def __init__(self) -> None:
|
||||
"""Queue initialised and initial records_lost established."""
|
||||
super().__init__()
|
||||
self.queue = Queue()
|
||||
self.queue: Queue[Union[logging.LogRecord, None]] = Queue()
|
||||
self._records_lost = 0
|
||||
|
||||
def _put_record(self, record: logging.LogRecord) -> None:
|
||||
@@ -189,10 +189,10 @@ class PatroniLogger(Thread):
|
||||
super(PatroniLogger, self).__init__()
|
||||
self._queue_handler = QueueHandler()
|
||||
self._root_logger = logging.getLogger()
|
||||
self._config = None
|
||||
self._config: Optional[Dict[str, Any]] = None
|
||||
self.log_handler = None
|
||||
self.log_handler_lock = Lock()
|
||||
self._old_handlers = []
|
||||
self._old_handlers: List[logging.Handler] = []
|
||||
# initially set log level to ``DEBUG`` while the logger thread has not started running yet. The daemon process
|
||||
# will later adjust all log related settings with what was provided through the user configuration file.
|
||||
self.reload_config({'level': 'DEBUG'})
|
||||
|
||||
@@ -12,7 +12,7 @@ from datetime import datetime
|
||||
from dateutil import tz
|
||||
from psutil import TimeoutExpired
|
||||
from threading import current_thread, Lock
|
||||
from typing import Any, Dict, List, Optional, Union, TYPE_CHECKING
|
||||
from typing import Any, Callable, Dict, Generator, List, Optional, Union, Tuple, TYPE_CHECKING
|
||||
|
||||
from .bootstrap import Bootstrap
|
||||
from .callback_executor import CallbackAction, CallbackExecutor
|
||||
@@ -25,11 +25,14 @@ from .postmaster import PostmasterProcess
|
||||
from .slots import SlotsHandler
|
||||
from .sync import SyncHandler
|
||||
from .. import psycopg
|
||||
from ..dcs import Cluster, Member
|
||||
from ..async_executor import CriticalTask
|
||||
from ..dcs import Cluster, Leader, Member
|
||||
from ..exceptions import PostgresConnectionException
|
||||
from ..utils import Retry, RetryFailedError, polling_loop, data_directory_is_empty, parse_int
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from psycopg import Connection as Connection3, Cursor
|
||||
from psycopg2 import connection as connection3, cursor
|
||||
from ..config import GlobalConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -59,13 +62,15 @@ class Postgresql(object):
|
||||
"pg_catalog.pg_{0}_{1}_diff(COALESCE(pg_catalog.pg_last_{0}_receive_{1}(), '0/0'), '0/0')::bigint, "
|
||||
"pg_catalog.pg_is_in_recovery() AND pg_catalog.pg_is_{0}_replay_paused()")
|
||||
|
||||
def __init__(self, config):
|
||||
self.name = config['name']
|
||||
self.scope = config['scope']
|
||||
self._data_dir = config['data_dir']
|
||||
def __init__(self, config: Dict[str, Any]) -> None:
|
||||
self.name: str = config['name']
|
||||
self.scope: str = config['scope']
|
||||
self._data_dir: str = config['data_dir']
|
||||
self._database = config.get('database', 'postgres')
|
||||
self._version_file = os.path.join(self._data_dir, 'PG_VERSION')
|
||||
self._pg_control = os.path.join(self._data_dir, 'global', 'pg_control')
|
||||
self.connection_string: str
|
||||
self.proxy_url: Optional[str]
|
||||
self._major_version = self.get_major_version()
|
||||
self._global_config = None
|
||||
|
||||
@@ -92,7 +97,7 @@ class Postgresql(object):
|
||||
|
||||
self.cancellable = CancellableSubprocess()
|
||||
|
||||
self._sysid = None
|
||||
self._sysid = ''
|
||||
self.retry = Retry(max_tries=-1, deadline=config['retry_timeout'] / 2.0, max_delay=1,
|
||||
retry_exceptions=PostgresConnectionException)
|
||||
|
||||
@@ -102,7 +107,7 @@ class Postgresql(object):
|
||||
|
||||
self._role_lock = Lock()
|
||||
self.set_role(self.get_postgres_role_from_data_directory())
|
||||
self._state_entry_timestamp = None
|
||||
self._state_entry_timestamp = 0
|
||||
|
||||
self._cluster_info_state = {}
|
||||
self._has_permanent_logical_slots = True
|
||||
@@ -126,35 +131,35 @@ class Postgresql(object):
|
||||
self.set_role('demoted')
|
||||
|
||||
@property
|
||||
def create_replica_methods(self):
|
||||
return self.config.get('create_replica_methods', []) or self.config.get('create_replica_method', [])
|
||||
def create_replica_methods(self) -> List[str]:
|
||||
return self.config.get('create_replica_methods', []) or self.config.get('create_replica_method', []) or []
|
||||
|
||||
@property
|
||||
def major_version(self):
|
||||
def major_version(self) -> int:
|
||||
return self._major_version
|
||||
|
||||
@property
|
||||
def database(self):
|
||||
def database(self) -> str:
|
||||
return self._database
|
||||
|
||||
@property
|
||||
def data_dir(self):
|
||||
def data_dir(self) -> str:
|
||||
return self._data_dir
|
||||
|
||||
@property
|
||||
def callback(self):
|
||||
return self.config.get('callbacks') or {}
|
||||
def callback(self) -> Dict[str, str]:
|
||||
return self.config.get('callbacks', {}) or {}
|
||||
|
||||
@property
|
||||
def wal_dir(self):
|
||||
def wal_dir(self) -> str:
|
||||
return os.path.join(self._data_dir, 'pg_' + self.wal_name)
|
||||
|
||||
@property
def wal_name(self):
def wal_name(self) -> str:
return 'wal' if self._major_version >= 100000 else 'xlog'

@property
def lsn_name(self):
def lsn_name(self) -> str:
return 'lsn' if self._major_version >= 100000 else 'location'
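# Editor's note (illustrative, not from the commit): _major_version holds the server
# version in PostgreSQL's numeric form, e.g. 90604 for 9.6.4 and 150002 for 15.2, which
# is why "major version >= 100000" is the PostgreSQL 10+ check used in both properties
# above. The helper name `uses_wal_naming` is hypothetical.
def uses_wal_naming(major_version: int) -> bool:
    """True when the version uses the post-9.x 'wal'/'lsn' terminology."""
    return major_version >= 100000

assert uses_wal_naming(150002) is True    # PostgreSQL 15.2
assert uses_wal_naming(90604) is False    # PostgreSQL 9.6.4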
||||
|
||||
@property
|
||||
@@ -163,7 +168,7 @@ class Postgresql(object):
|
||||
return self._major_version >= 90600
|
||||
|
||||
@property
|
||||
def cluster_info_query(self):
|
||||
def cluster_info_query(self) -> str:
|
||||
"""Returns the monitoring query with a fixed number of fields.
|
||||
|
||||
The query text is constructed based on current state in DCS and PostgreSQL version:
|
||||
@@ -206,7 +211,7 @@ class Postgresql(object):
|
||||
|
||||
return ("SELECT " + self.TL_LSN + ", {2}").format(self.wal_name, self.lsn_name, extra)
|
||||
|
||||
def _version_file_exists(self):
|
||||
def _version_file_exists(self) -> bool:
|
||||
return not self.data_directory_empty() and os.path.isfile(self._version_file)
|
||||
|
||||
def get_major_version(self) -> int:
|
||||
@@ -221,11 +226,11 @@ class Postgresql(object):
|
||||
logger.exception('Failed to read PG_VERSION from %s', self._data_dir)
|
||||
return 0
|
||||
|
||||
def pgcommand(self, cmd):
|
||||
def pgcommand(self, cmd: str) -> str:
|
||||
"""Returns path to the specified PostgreSQL command"""
|
||||
return os.path.join(self._bin_dir, cmd)
|
||||
|
||||
def pg_ctl(self, cmd, *args, **kwargs):
|
||||
def pg_ctl(self, cmd: str, *args: str, **kwargs: Any) -> bool:
|
||||
"""Builds and executes pg_ctl command
|
||||
|
||||
:returns: `!True` when return_code == 0, otherwise `!False`"""
|
||||
@@ -244,7 +249,7 @@ class Postgresql(object):
|
||||
initdb = [self.pgcommand('initdb')] + list(args) + [self.data_dir]
|
||||
return subprocess.call(initdb, **kwargs) == 0
|
||||
|
||||
def pg_isready(self):
|
||||
def pg_isready(self) -> str:
|
||||
"""Runs pg_isready to see if PostgreSQL is accepting connections.
|
||||
|
||||
:returns: 'ok' if PostgreSQL is up, 'reject' if starting up, 'no_response' if not up."""
|
||||
@@ -267,25 +272,25 @@ class Postgresql(object):
|
||||
3: STATE_UNKNOWN}
|
||||
return return_codes.get(ret, STATE_UNKNOWN)
|
||||
|
||||
def reload_config(self, config, sighup=False):
|
||||
def reload_config(self, config: Dict[str, Any], sighup: bool = False) -> None:
|
||||
self.config.reload_config(config, sighup)
|
||||
self._is_leader_retry.deadline = self.retry.deadline = config['retry_timeout'] / 2.0
|
||||
|
||||
@property
|
||||
def pending_restart(self):
|
||||
def pending_restart(self) -> bool:
|
||||
return self._pending_restart
|
||||
|
||||
def set_pending_restart(self, value):
|
||||
def set_pending_restart(self, value: bool) -> None:
|
||||
self._pending_restart = value
|
||||
|
||||
@property
|
||||
def sysid(self):
|
||||
def sysid(self) -> str:
|
||||
if not self._sysid and not self.bootstrapping:
|
||||
data = self.controldata()
|
||||
self._sysid = data.get('Database system identifier', "")
|
||||
self._sysid = data.get('Database system identifier', '')
|
||||
return self._sysid
|
||||
|
||||
def get_postgres_role_from_data_directory(self):
|
||||
def get_postgres_role_from_data_directory(self) -> str:
|
||||
if self.data_directory_empty() or not self.controldata():
|
||||
return 'uninitialized'
|
||||
elif self.config.recovery_conf_exists():
|
||||
@@ -294,24 +299,24 @@ class Postgresql(object):
|
||||
return 'master'
|
||||
|
||||
@property
|
||||
def server_version(self):
|
||||
def server_version(self) -> int:
|
||||
return self._connection.server_version
|
||||
|
||||
def connection(self):
|
||||
def connection(self) -> Union['connection3', 'Connection3[Any]']:
|
||||
return self._connection.get()
|
||||
|
||||
def set_connection_kwargs(self, kwargs):
|
||||
def set_connection_kwargs(self, kwargs: Dict[str, Any]) -> None:
|
||||
self._connection.set_conn_kwargs(kwargs.copy())
|
||||
self.citus_handler.set_conn_kwargs(kwargs.copy())
|
||||
|
||||
def _query(self, sql, *params):
|
||||
def _query(self, sql: str, *params: Any) -> Union['Cursor[Any]', 'cursor']:
|
||||
"""We are always using the same cursor, therefore this method is not thread-safe!!!
|
||||
You can call it from different threads only if you are holding explicit `AsyncExecutor` lock,
|
||||
because the main thread is always holding this lock when running HA cycle."""
|
||||
cursor = None
|
||||
try:
|
||||
cursor = self._connection.cursor()
|
||||
cursor.execute(sql, params or None)
|
||||
cursor.execute(sql.encode('utf-8'), params or None)
|
||||
return cursor
|
||||
except psycopg.Error as e:
|
||||
if cursor and cursor.connection.closed == 0:
|
||||
@@ -327,7 +332,7 @@ class Postgresql(object):
|
||||
raise RetryFailedError('cluster is being restarted')
|
||||
raise PostgresConnectionException('connection problems')
|
||||
|
||||
def query(self, sql, *args, **kwargs):
|
||||
def query(self, sql: str, *args: Any, **kwargs: Any) -> Union['Cursor[Any]', 'cursor']:
|
||||
if not kwargs.get('retry', True):
|
||||
return self._query(sql, *args)
|
||||
try:
|
||||
@@ -335,22 +340,22 @@ class Postgresql(object):
|
||||
except RetryFailedError as e:
|
||||
raise PostgresConnectionException(str(e))
|
||||
|
||||
def pg_control_exists(self):
|
||||
def pg_control_exists(self) -> bool:
|
||||
return os.path.isfile(self._pg_control)
|
||||
|
||||
def data_directory_empty(self):
|
||||
def data_directory_empty(self) -> bool:
|
||||
if self.pg_control_exists():
|
||||
return False
|
||||
return data_directory_is_empty(self._data_dir)
|
||||
|
||||
def replica_method_options(self, method):
|
||||
return deepcopy(self.config.get(method, {}))
|
||||
def replica_method_options(self, method: str) -> Dict[str, Any]:
|
||||
return deepcopy(self.config.get(method, {}) or {})
|
||||
|
||||
def replica_method_can_work_without_replication_connection(self, method):
|
||||
return method != 'basebackup' and (self.replica_method_options(method).get('no_master')
|
||||
or self.replica_method_options(method).get('no_leader'))
|
||||
def replica_method_can_work_without_replication_connection(self, method: str) -> bool:
|
||||
return method != 'basebackup' and bool(self.replica_method_options(method).get('no_master')
|
||||
or self.replica_method_options(method).get('no_leader'))
|
||||
|
||||
def can_create_replica_without_replication_connection(self, replica_methods=None):
|
||||
def can_create_replica_without_replication_connection(self, replica_methods: Optional[List[str]]) -> bool:
|
||||
""" go through the replication methods to see if there are ones
|
||||
that do not require a working replication connection.
|
||||
"""
|
||||
@@ -359,10 +364,10 @@ class Postgresql(object):
|
||||
return any(self.replica_method_can_work_without_replication_connection(m) for m in replica_methods)
|
||||
|
||||
@property
|
||||
def enforce_hot_standby_feedback(self):
|
||||
def enforce_hot_standby_feedback(self) -> bool:
|
||||
return self._enforce_hot_standby_feedback
|
||||
|
||||
def set_enforce_hot_standby_feedback(self, value):
|
||||
def set_enforce_hot_standby_feedback(self, value: bool) -> None:
|
||||
# If we enable or disable the hot_standby_feedback we need to update postgresql.conf and reload
|
||||
if self._enforce_hot_standby_feedback != value:
|
||||
self._enforce_hot_standby_feedback = value
|
||||
@@ -370,7 +375,7 @@ class Postgresql(object):
|
||||
self.config.write_postgresql_conf()
|
||||
self.reload()
|
||||
|
||||
def reset_cluster_info_state(self, cluster: Union[Cluster, None], nofailover: Optional[bool] = None,
|
||||
def reset_cluster_info_state(self, cluster: Union[Cluster, None], nofailover: bool = False,
|
||||
global_config: Optional['GlobalConfig'] = None) -> None:
|
||||
"""Reset monitoring query cache.
|
||||
|
||||
@@ -395,7 +400,7 @@ class Postgresql(object):
|
||||
|
||||
self._global_config = global_config
|
||||
|
||||
def _cluster_info_state_get(self, name):
|
||||
def _cluster_info_state_get(self, name: str) -> Optional[Any]:
|
||||
if not self._cluster_info_state:
|
||||
try:
|
||||
result = self._is_leader_retry(self._query, self.cluster_info_query).fetchone()
|
||||
@@ -417,62 +422,63 @@ class Postgresql(object):
|
||||
|
||||
return self._cluster_info_state.get(name)
|
||||
|
||||
def replayed_location(self):
|
||||
def replayed_location(self) -> Optional[int]:
|
||||
return self._cluster_info_state_get('replayed_location')
|
||||
|
||||
def received_location(self):
|
||||
def received_location(self) -> Optional[int]:
|
||||
return self._cluster_info_state_get('received_location')
|
||||
|
||||
def slots(self):
|
||||
return self._cluster_info_state_get('slots')
|
||||
def slots(self) -> Dict[str, int]:
|
||||
return self._cluster_info_state_get('slots') or {}
|
||||
|
||||
def primary_slot_name(self):
|
||||
def primary_slot_name(self) -> Optional[str]:
|
||||
return self._cluster_info_state_get('slot_name')
|
||||
|
||||
def primary_conninfo(self):
|
||||
def primary_conninfo(self) -> Optional[str]:
|
||||
return self._cluster_info_state_get('conninfo')
|
||||
|
||||
def received_timeline(self):
|
||||
def received_timeline(self) -> Optional[int]:
|
||||
return self._cluster_info_state_get('received_tli')
|
||||
|
||||
def synchronous_commit(self) -> str:
|
||||
""":returns: "synchronous_commit" GUC value."""
|
||||
return self._cluster_info_state_get('synchronous_commit')
|
||||
return self._cluster_info_state_get('synchronous_commit') or 'on'
|
||||
|
||||
def synchronous_standby_names(self) -> str:
|
||||
""":returns: "synchronous_standby_names" GUC value."""
|
||||
return self._cluster_info_state_get('synchronous_standby_names')
|
||||
return self._cluster_info_state_get('synchronous_standby_names') or ''
|
||||
|
||||
def pg_stat_replication(self) -> List[Dict[str, Any]]:
|
||||
""":returns: a result set of 'SELECT * FROM pg_stat_replication'."""
|
||||
return self._cluster_info_state_get('pg_stat_replication') or []
|
||||
|
||||
def is_leader(self):
|
||||
def is_leader(self) -> bool:
|
||||
try:
|
||||
return bool(self._cluster_info_state_get('timeline'))
|
||||
except PostgresConnectionException:
|
||||
logger.warning('Failed to determine PostgreSQL state from the connection, falling back to cached role')
|
||||
return bool(self.is_running() and self.role in ('master', 'primary'))
|
||||
|
||||
def replay_paused(self):
|
||||
return self._cluster_info_state_get('replay_paused')
|
||||
def replay_paused(self) -> bool:
|
||||
return self._cluster_info_state_get('replay_paused') or False
|
||||
|
||||
def resume_wal_replay(self):
|
||||
def resume_wal_replay(self) -> None:
|
||||
self._query('SELECT pg_catalog.pg_{0}_replay_resume()'.format(self.wal_name))
|
||||
|
||||
def handle_parameter_change(self):
|
||||
def handle_parameter_change(self) -> None:
|
||||
if self.major_version >= 140000 and not self.is_starting() and self.replay_paused():
|
||||
logger.info('Resuming paused WAL replay for PostgreSQL 14+')
|
||||
self.resume_wal_replay()
|
||||
|
||||
def pg_control_timeline(self):
|
||||
def pg_control_timeline(self) -> Optional[int]:
|
||||
try:
|
||||
|
||||
return int(self.controldata().get("Latest checkpoint's TimeLineID"))
|
||||
return int(self.controldata().get("Latest checkpoint's TimeLineID", ""))
|
||||
except (TypeError, ValueError):
|
||||
logger.exception('Failed to parse timeline from pg_controldata output')
|
||||
|
||||
def parse_wal_record(self, timeline, lsn):
|
||||
def parse_wal_record(self, timeline: str,
|
||||
lsn: str) -> Union[Tuple[str, str, str, str], Tuple[None, None, None, None]]:
|
||||
out, err = self.waldump(timeline, lsn, 1)
|
||||
if out and not err:
|
||||
match = re.match(r'^rmgr:\s+(.+?)\s+len \(rec/tot\):\s+\d+/\s+\d+, tx:\s+\d+, '
|
||||
@@ -482,7 +488,7 @@ class Postgresql(object):
|
||||
return match.groups()
|
||||
return None, None, None, None
|
||||
|
||||
def latest_checkpoint_location(self):
|
||||
def latest_checkpoint_location(self) -> Optional[int]:
|
||||
"""Returns checkpoint location for the cleanly shut down primary.
|
||||
But, if we know that the checkpoint was written to the new WAL
|
||||
due to the archive_mode=on, we will return the LSN of prev wal record (SWITCH)."""
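# Editor's illustration of the record parsing described above (assumed pg_waldump output
# shape; the regex shown in the hunk continues beyond the fragment quoted in this diff,
# so the lsn/prev/desc groups here are an assumption, not Patroni's exact pattern):
import re

SAMPLE = ('rmgr: XLOG        len (rec/tot):    114/   114, tx:          0, '
          'lsn: 0/02000100, prev 0/020000D8, desc: CHECKPOINT_SHUTDOWN redo 0/2000100')
match = re.match(r'^rmgr:\s+(.+?)\s+len \(rec/tot\):\s+\d+/\s+\d+, tx:\s+\d+, '
                 r'lsn: ([0-9A-Fa-f/]+), prev ([0-9A-Fa-f/]+), desc: (.+)', SAMPLE)
assert match and match.group(1) == 'XLOG' and match.group(2) == '0/02000100'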
|
||||
@@ -490,25 +496,25 @@ class Postgresql(object):
|
||||
data = self.controldata()
|
||||
timeline = data.get("Latest checkpoint's TimeLineID")
|
||||
lsn = checkpoint_lsn = data.get('Latest checkpoint location')
|
||||
if data.get('Database cluster state') == 'shut down' and lsn and timeline:
|
||||
if data.get('Database cluster state') == 'shut down' and lsn and timeline and checkpoint_lsn:
|
||||
try:
|
||||
checkpoint_lsn = parse_lsn(checkpoint_lsn)
|
||||
rm_name, lsn, prev, desc = self.parse_wal_record(timeline, lsn)
|
||||
desc = desc.strip().lower()
|
||||
if rm_name == 'XLOG' and parse_lsn(lsn) == checkpoint_lsn and prev and\
|
||||
desc = str(desc).strip().lower()
|
||||
if rm_name == 'XLOG' and lsn and parse_lsn(lsn) == checkpoint_lsn and prev and\
|
||||
desc.startswith('checkpoint') and desc.endswith('shutdown'):
|
||||
_, lsn, _, desc = self.parse_wal_record(timeline, prev)
|
||||
prev = parse_lsn(prev)
|
||||
# If the cluster is shutdown with archive_mode=on, WAL is switched before writing the checkpoint.
|
||||
# In this case we want to take the LSN of previous record (switch) as the last known WAL location.
|
||||
if parse_lsn(lsn) == prev and desc.strip() in ('xlog switch', 'SWITCH'):
|
||||
if lsn and parse_lsn(lsn) == prev and str(desc).strip() in ('xlog switch', 'SWITCH'):
|
||||
return prev
|
||||
except Exception as e:
|
||||
logger.error('Exception when parsing WAL pg_%sdump output: %r', self.wal_name, e)
|
||||
if isinstance(checkpoint_lsn, int):
|
||||
return checkpoint_lsn
|
||||
|
||||
def is_running(self):
|
||||
def is_running(self) -> Optional[PostmasterProcess]:
|
||||
"""Returns PostmasterProcess if one is running on the data directory or None. If most recently seen process
|
||||
is running updates the cached process based on pid file."""
|
||||
if self._postmaster_proc:
|
||||
@@ -523,7 +529,7 @@ class Postgresql(object):
|
||||
return self._postmaster_proc
|
||||
|
||||
@property
|
||||
def cb_called(self):
|
||||
def cb_called(self) -> bool:
|
||||
return self.__cb_called
|
||||
|
||||
def call_nowait(self, cb_type: CallbackAction) -> None:
|
||||
@@ -544,31 +550,31 @@ class Postgresql(object):
|
||||
logger.exception('callback %s %r %s %s failed', cmd, cb_type, role, self.scope)
|
||||
|
||||
@property
|
||||
def role(self):
|
||||
def role(self) -> str:
|
||||
with self._role_lock:
|
||||
return self._role
|
||||
|
||||
def set_role(self, value):
|
||||
def set_role(self, value: str) -> None:
|
||||
with self._role_lock:
|
||||
self._role = value
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
def state(self) -> str:
|
||||
with self._state_lock:
|
||||
return self._state
|
||||
|
||||
def set_state(self, value):
|
||||
def set_state(self, value: str) -> None:
|
||||
with self._state_lock:
|
||||
self._state = value
|
||||
self._state_entry_timestamp = time.time()
|
||||
|
||||
def time_in_state(self):
|
||||
def time_in_state(self) -> float:
|
||||
return time.time() - self._state_entry_timestamp
|
||||
|
||||
def is_starting(self):
|
||||
def is_starting(self) -> bool:
|
||||
return self.state == 'starting'
|
||||
|
||||
def wait_for_port_open(self, postmaster, timeout):
|
||||
def wait_for_port_open(self, postmaster: PostmasterProcess, timeout: float) -> bool:
|
||||
"""Waits until PostgreSQL opens ports."""
|
||||
for _ in polling_loop(timeout):
|
||||
if self.cancellable.is_cancelled:
|
||||
@@ -588,7 +594,9 @@ class Postgresql(object):
|
||||
logger.warning("Timed out waiting for PostgreSQL to start")
|
||||
return False
|
||||
|
||||
def start(self, timeout=None, task=None, block_callbacks=False, role=None, after_start=None):
|
||||
def start(self, timeout: Optional[float] = None, task: Optional[CriticalTask] = None,
|
||||
block_callbacks: bool = False, role: Optional[str] = None,
|
||||
after_start: Optional[Callable[..., Any]] = None) -> Optional[bool]:
|
||||
"""Start PostgreSQL
|
||||
|
||||
Waits for postmaster to open ports or terminate so pg_isready can be used to check startup completion
|
||||
@@ -650,7 +658,7 @@ class Postgresql(object):
|
||||
start_timeout = timeout
|
||||
if not start_timeout:
|
||||
try:
|
||||
start_timeout = float(self.config.get('pg_ctl_timeout', 60))
|
||||
start_timeout = float(self.config.get('pg_ctl_timeout', 60) or 0)
|
||||
except ValueError:
|
||||
start_timeout = 60
|
||||
|
||||
@@ -668,7 +676,8 @@ class Postgresql(object):
|
||||
else:
|
||||
return None
|
||||
|
||||
def checkpoint(self, connect_kwargs=None, timeout=None):
|
||||
def checkpoint(self, connect_kwargs: Optional[Dict[str, Any]] = None,
|
||||
timeout: Optional[float] = None) -> Optional[str]:
|
||||
check_not_is_in_recovery = connect_kwargs is not None
|
||||
connect_kwargs = connect_kwargs or self.config.local_connect_kwargs
|
||||
for p in ['connect_timeout', 'options']:
|
||||
@@ -680,15 +689,17 @@ class Postgresql(object):
|
||||
cur.execute("SET statement_timeout = 0")
|
||||
if check_not_is_in_recovery:
|
||||
cur.execute('SELECT pg_catalog.pg_is_in_recovery()')
|
||||
if cur.fetchone()[0]:
|
||||
row = cur.fetchone()
|
||||
if not row or row[0]:
|
||||
return 'is_in_recovery=true'
|
||||
cur.execute('CHECKPOINT')
|
||||
except psycopg.Error:
|
||||
logger.exception('Exception during CHECKPOINT')
|
||||
return 'not accessible or not healthy'
|
||||
|
||||
def stop(self, mode='fast', block_callbacks=False, checkpoint=None,
|
||||
on_safepoint=None, on_shutdown=None, before_shutdown=None, stop_timeout=None):
|
||||
def stop(self, mode: str = 'fast', block_callbacks: bool = False, checkpoint: Optional[bool] = None,
|
||||
on_safepoint: Optional[Callable[..., Any]] = None, on_shutdown: Optional[Callable[[int], Any]] = None,
|
||||
before_shutdown: Optional[Callable[..., Any]] = None, stop_timeout: Optional[int] = None) -> bool:
|
||||
"""Stop PostgreSQL
|
||||
|
||||
Supports a callback when a safepoint is reached. A safepoint is when no user backend can return a successful
|
||||
@@ -716,7 +727,9 @@ class Postgresql(object):
|
||||
self.set_state('stop failed')
|
||||
return success
|
||||
|
||||
def _do_stop(self, mode, block_callbacks, checkpoint, on_safepoint, on_shutdown, before_shutdown, stop_timeout):
|
||||
def _do_stop(self, mode: str, block_callbacks: bool, checkpoint: bool,
|
||||
on_safepoint: Optional[Callable[..., Any]], on_shutdown: Optional[Callable[..., Any]],
|
||||
before_shutdown: Optional[Callable[..., Any]], stop_timeout: Optional[int]) -> Tuple[bool, bool]:
|
||||
postmaster = self.is_running()
|
||||
if not postmaster:
|
||||
if on_safepoint:
|
||||
@@ -774,7 +787,8 @@ class Postgresql(object):
|
||||
|
||||
return True, True
|
||||
|
||||
def terminate_postmaster(self, postmaster, mode, stop_timeout):
|
||||
def terminate_postmaster(self, postmaster: PostmasterProcess, mode: str,
|
||||
stop_timeout: Optional[int]) -> Optional[bool]:
|
||||
if mode in ['fast', 'smart']:
|
||||
try:
|
||||
success = postmaster.signal_stop('immediate', self.pgcommand('pg_ctl'))
|
||||
@@ -787,13 +801,13 @@ class Postgresql(object):
|
||||
logger.warning("Sending SIGKILL to Postmaster and its children")
|
||||
return postmaster.signal_kill()
|
||||
|
||||
def terminate_starting_postmaster(self, postmaster):
|
||||
def terminate_starting_postmaster(self, postmaster: PostmasterProcess) -> None:
|
||||
"""Terminates a postmaster that has not yet opened ports or possibly even written a pid file. Blocks
|
||||
until the process goes away."""
|
||||
postmaster.signal_stop('immediate', self.pgcommand('pg_ctl'))
|
||||
postmaster.wait()
|
||||
|
||||
def _wait_for_connection_close(self, postmaster):
|
||||
def _wait_for_connection_close(self, postmaster: PostmasterProcess) -> None:
|
||||
try:
|
||||
with self.connection().cursor() as cur:
|
||||
while postmaster.is_running(): # Need a timeout here?
|
||||
@@ -802,17 +816,17 @@ class Postgresql(object):
|
||||
except psycopg.Error:
|
||||
pass
|
||||
|
||||
def reload(self, block_callbacks=False):
|
||||
def reload(self, block_callbacks: bool = False) -> bool:
|
||||
ret = self.pg_ctl('reload')
|
||||
if ret and not block_callbacks:
|
||||
self.call_nowait(CallbackAction.ON_RELOAD)
|
||||
return ret
|
||||
|
||||
def check_for_startup(self):
|
||||
def check_for_startup(self) -> bool:
|
||||
"""Checks PostgreSQL status and returns if PostgreSQL is in the middle of startup."""
|
||||
return self.is_starting() and not self.check_startup_state_changed()
|
||||
|
||||
def check_startup_state_changed(self):
|
||||
def check_startup_state_changed(self) -> bool:
|
||||
"""Checks if PostgreSQL has completed starting up or failed or still starting.
|
||||
|
||||
Should only be called when state == 'starting'
|
||||
@@ -847,7 +861,7 @@ class Postgresql(object):
|
||||
|
||||
return True
|
||||
|
||||
def wait_for_startup(self, timeout=None):
|
||||
def wait_for_startup(self, timeout: float = 0) -> Optional[bool]:
|
||||
"""Waits for PostgreSQL startup to complete or fail.
|
||||
|
||||
:returns: True if start was successful, False otherwise"""
|
||||
@@ -862,8 +876,10 @@ class Postgresql(object):
|
||||
|
||||
return self.state == 'running'
|
||||
|
||||
def restart(self, timeout=None, task=None, block_callbacks=False,
|
||||
role=None, before_shutdown=None, after_start=None):
|
||||
def restart(self, timeout: Optional[float] = None, task: Optional[CriticalTask] = None,
|
||||
block_callbacks: bool = False, role: Optional[str] = None,
|
||||
before_shutdown: Optional[Callable[..., Any]] = None,
|
||||
after_start: Optional[Callable[..., Any]] = None) -> Optional[bool]:
|
||||
"""Restarts PostgreSQL.
|
||||
|
||||
When timeout parameter is set the call will block either until PostgreSQL has started, failed to start or
|
||||
@@ -880,13 +896,13 @@ class Postgresql(object):
|
||||
self.set_state('restart failed ({0})'.format(self.state))
|
||||
return ret
|
||||
|
||||
def is_healthy(self):
|
||||
def is_healthy(self) -> bool:
|
||||
if not self.is_running():
|
||||
logger.warning('Postgresql is not running.')
|
||||
return False
|
||||
return True
|
||||
|
||||
def get_guc_value(self, name):
|
||||
def get_guc_value(self, name: str) -> Optional[str]:
|
||||
cmd = [self.pgcommand('postgres'), '-D', self._data_dir, '-C', name,
|
||||
'--config-file={}'.format(self.config.postgresql_conf)]
|
||||
try:
|
||||
@@ -896,7 +912,7 @@ class Postgresql(object):
|
||||
except Exception as e:
|
||||
logger.error('Failed to execute %s: %r', cmd, e)
|
||||
|
||||
def controldata(self):
|
||||
def controldata(self) -> Dict[str, str]:
|
||||
""" return the contents of pg_controldata, or non-True value if pg_controldata call failed """
|
||||
# Don't try to call pg_controldata during backup restore
|
||||
if self._version_file_exists() and self.state != 'creating replica':
|
||||
@@ -911,11 +927,11 @@ class Postgresql(object):
|
||||
logger.exception("Error when calling pg_controldata")
|
||||
return {}
|
||||
|
||||
def waldump(self, timeline, lsn, limit):
|
||||
def waldump(self, timeline: Union[int, str], lsn: str, limit: int) -> Tuple[Optional[bytes], Optional[bytes]]:
|
||||
cmd = self.pgcommand('pg_{0}dump'.format(self.wal_name))
|
||||
env = {**os.environ, 'LANG': 'C', 'LC_ALL': 'C', 'PGDATA': self._data_dir}
|
||||
try:
|
||||
waldump = subprocess.Popen([cmd, '-t', str(timeline), '-s', str(lsn), '-n', str(limit)],
|
||||
waldump = subprocess.Popen([cmd, '-t', str(timeline), '-s', lsn, '-n', str(limit)],
|
||||
stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env)
|
||||
out, err = waldump.communicate()
|
||||
waldump.wait()
|
||||
@@ -925,22 +941,24 @@ class Postgresql(object):
|
||||
return None, None
|
||||
|
||||
@contextmanager
|
||||
def get_replication_connection_cursor(self, host=None, port=5432, **kwargs):
|
||||
def get_replication_connection_cursor(self, host: Optional[str] = None, port: int = 5432,
|
||||
**kwargs: Any) -> Generator[Union['cursor', 'Cursor[Any]'], None, None]:
|
||||
conn_kwargs = self.config.replication.copy()
|
||||
conn_kwargs.update(host=host, port=int(port) if port else None, user=conn_kwargs.pop('username'),
|
||||
connect_timeout=3, replication=1, options='-c statement_timeout=2000')
|
||||
with get_connection_cursor(**conn_kwargs) as cur:
|
||||
yield cur
|
||||
|
||||
def get_replica_timeline(self):
|
||||
def get_replica_timeline(self) -> Optional[int]:
|
||||
try:
|
||||
with self.get_replication_connection_cursor(**self.config.local_replication_address) as cur:
|
||||
cur.execute('IDENTIFY_SYSTEM')
|
||||
return cur.fetchone()[1]
|
||||
row = cur.fetchone()
|
||||
return row[1] if row else None
|
||||
except Exception:
|
||||
logger.exception('Can not fetch local timeline and lsn from replication connection')
|
||||
|
||||
def replica_cached_timeline(self, primary_timeline):
|
||||
def replica_cached_timeline(self, primary_timeline: Optional[int]) -> Optional[int]:
|
||||
if not self._cached_replica_timeline or not primary_timeline\
|
||||
or self._cached_replica_timeline != primary_timeline:
|
||||
self._cached_replica_timeline = self.get_replica_timeline()
|
||||
@@ -948,29 +966,26 @@ class Postgresql(object):
|
||||
|
||||
def get_primary_timeline(self) -> int:
|
||||
""":returns: current timeline if postgres is running as a primary or 0."""
|
||||
return self._cluster_info_state_get('timeline')
|
||||
return self._cluster_info_state_get('timeline') or 0
|
||||
|
||||
def get_history(self, timeline):
|
||||
def get_history(self, timeline: int) -> List[Union[Tuple[int, int, str], Tuple[int, int, str, str, str]]]:
|
||||
history_path = os.path.join(self.wal_dir, '{0:08X}.history'.format(timeline))
|
||||
history_mtime = mtime(history_path)
|
||||
history: List[Union[Tuple[int, int, str], Tuple[int, int, str, str, str]]] = []
|
||||
if history_mtime:
|
||||
try:
|
||||
with open(history_path, 'r') as f:
|
||||
history = f.read()
|
||||
history = list(parse_history(history))
|
||||
history_content = f.read()
|
||||
history = list(parse_history(history_content))
|
||||
if history[-1][0] == timeline - 1:
|
||||
history_mtime = datetime.fromtimestamp(history_mtime).replace(tzinfo=tz.tzlocal())
|
||||
history[-1].append(history_mtime.isoformat())
|
||||
history[-1].append(self.name)
|
||||
return history
|
||||
history[-1] = history[-1][:3] + (history_mtime.isoformat(), self.name)
|
||||
except Exception:
|
||||
logger.exception('Failed to read and parse %s', (history_path,))
|
||||
return history
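# Editor's sketch of the input handled above (assumed timeline-history file format: one
# whitespace-separated line per previous timeline, "parent_tli  switchpoint_lsn  reason").
# parse_history() turns each line into a tuple, and the newest entry may additionally get
# the file mtime and the node name appended, as the hunk above shows.
HISTORY_SAMPLE = "1\t0/0501A2E8\tno recovery target specified\n"
tli, lsn, reason = HISTORY_SAMPLE.strip().split('\t')
assert int(tli) == 1 and lsn == '0/0501A2E8'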
|
||||
|
||||
def follow(self,
|
||||
member: Member,
|
||||
role: Optional[str] = 'replica',
|
||||
timeout: Optional[float] = None,
|
||||
do_reload: Optional[bool] = False) -> Optional[bool]:
|
||||
def follow(self, member: Union[Leader, Member, None], role: str = 'replica',
|
||||
timeout: Optional[float] = None, do_reload: bool = False) -> Optional[bool]:
|
||||
"""Reconfigure postgres to follow a new member or use different recovery parameters.
|
||||
|
||||
Method may call `on_role_change` callback if role is changing.
|
||||
@@ -1016,14 +1031,14 @@ class Postgresql(object):
|
||||
self.call_nowait(CallbackAction.ON_ROLE_CHANGE)
|
||||
return ret
|
||||
|
||||
def _wait_promote(self, wait_seconds):
|
||||
def _wait_promote(self, wait_seconds: int) -> Optional[bool]:
|
||||
for _ in polling_loop(wait_seconds):
|
||||
data = self.controldata()
|
||||
if data.get('Database cluster state') == 'in production':
|
||||
self.set_role('master')
|
||||
return True
|
||||
|
||||
def _pre_promote(self):
|
||||
def _pre_promote(self) -> bool:
|
||||
"""
|
||||
Runs a fencing script after the leader lock is acquired but before the replica is promoted.
|
||||
If the script exits with a non-zero code, promotion does not happen and the leader key is removed from DCS.
|
||||
@@ -1053,7 +1068,8 @@ class Postgresql(object):
|
||||
except Exception as e:
|
||||
logger.error('Exception when calling `%s`: %r', cmd, e)
|
||||
|
||||
def promote(self, wait_seconds, task, before_promote=None, on_success=None):
|
||||
def promote(self, wait_seconds: int, task: CriticalTask, before_promote: Optional[Callable[..., Any]] = None,
|
||||
on_success: Optional[Callable[..., Any]] = None) -> Optional[bool]:
|
||||
if self.role in ('promoted', 'master', 'primary'):
|
||||
return True
|
||||
|
||||
@@ -1086,43 +1102,48 @@ class Postgresql(object):
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
def _wal_position(is_leader, wal_position, received_location, replayed_location):
|
||||
def _wal_position(is_leader: bool, wal_position: int,
|
||||
received_location: Optional[int], replayed_location: Optional[int]) -> int:
|
||||
return wal_position if is_leader else max(received_location or 0, replayed_location or 0)
|
||||
|
||||
def timeline_wal_position(self):
|
||||
def timeline_wal_position(self) -> Tuple[int, int, Optional[int]]:
|
||||
# This method could be called from different threads (simultaneously with some other `_query` calls).
|
||||
# If it is called not from main thread we will create a new cursor to execute statement.
|
||||
if current_thread().ident == self.__thread_ident:
|
||||
timeline = self._cluster_info_state_get('timeline')
|
||||
wal_position = self._cluster_info_state_get('wal_position')
|
||||
timeline = self._cluster_info_state_get('timeline') or 0
|
||||
wal_position = self._cluster_info_state_get('wal_position') or 0
|
||||
replayed_location = self.replayed_location()
|
||||
received_location = self.received_location()
|
||||
pg_control_timeline = self._cluster_info_state_get('pg_control_timeline')
|
||||
else:
|
||||
with self.connection().cursor() as cursor:
|
||||
cursor.execute(self.cluster_info_query)
|
||||
(timeline, wal_position, replayed_location,
|
||||
received_location, _, pg_control_timeline) = cursor.fetchone()[:6]
|
||||
cursor.execute(self.cluster_info_query.encode('utf-8'))
|
||||
row = cursor.fetchone()
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
assert row is not None
|
||||
(timeline, wal_position, replayed_location, received_location, _, pg_control_timeline) = row[:6]
|
||||
|
||||
wal_position = self._wal_position(timeline, wal_position, received_location, replayed_location)
|
||||
wal_position = self._wal_position(bool(timeline), wal_position, received_location, replayed_location)
|
||||
return (timeline, wal_position, pg_control_timeline)
|
||||
|
||||
def postmaster_start_time(self):
|
||||
def postmaster_start_time(self) -> Optional[str]:
|
||||
try:
|
||||
query = "SELECT " + self.POSTMASTER_START_TIME
|
||||
if current_thread().ident == self.__thread_ident:
|
||||
return self.query(query).fetchone()[0].isoformat(sep=' ')
|
||||
with self.connection().cursor() as cursor:
|
||||
cursor.execute(query)
|
||||
return cursor.fetchone()[0].isoformat(sep=' ')
|
||||
row = self.query(query).fetchone()
|
||||
else:
|
||||
with self.connection().cursor() as cursor:
|
||||
cursor.execute(query)
|
||||
row = cursor.fetchone()
|
||||
return row[0].isoformat(sep=' ') if row else None
|
||||
except psycopg.Error:
|
||||
return None
|
||||
|
||||
def last_operation(self):
|
||||
return self._wal_position(self.is_leader(), self._cluster_info_state_get('wal_position'),
|
||||
def last_operation(self) -> int:
|
||||
return self._wal_position(self.is_leader(), self._cluster_info_state_get('wal_position') or 0,
|
||||
self.received_location(), self.replayed_location())
|
||||
|
||||
def configure_server_parameters(self):
|
||||
def configure_server_parameters(self) -> None:
|
||||
self._major_version = self.get_major_version()
|
||||
self.config.setup_server_parameters()
|
||||
|
||||
@@ -1135,9 +1156,9 @@ class Postgresql(object):
|
||||
self.configure_server_parameters()
|
||||
return self._major_version > 0
|
||||
|
||||
def pg_wal_realpath(self):
|
||||
def pg_wal_realpath(self) -> Dict[str, str]:
|
||||
"""Returns a dict containing the symlink (key) and target (value) for the wal directory"""
|
||||
links = {}
|
||||
links: Dict[str, str] = {}
|
||||
for pg_wal_dir in ('pg_xlog', 'pg_wal'):
|
||||
pg_wal_path = os.path.join(self._data_dir, pg_wal_dir)
|
||||
if os.path.exists(pg_wal_path) and os.path.islink(pg_wal_path):
|
||||
@@ -1145,9 +1166,9 @@ class Postgresql(object):
|
||||
links[pg_wal_path] = pg_wal_realpath
|
||||
return links
|
||||
|
||||
def pg_tblspc_realpaths(self):
|
||||
def pg_tblspc_realpaths(self) -> Dict[str, str]:
|
||||
"""Returns a dict containing the symlink (key) and target (values) for the tablespaces"""
|
||||
links = {}
|
||||
links: Dict[str, str] = {}
|
||||
pg_tblsp_dir = os.path.join(self._data_dir, 'pg_tblspc')
|
||||
if os.path.exists(pg_tblsp_dir):
|
||||
for tsdn in os.listdir(pg_tblsp_dir):
|
||||
@@ -1157,7 +1178,7 @@ class Postgresql(object):
|
||||
links[pg_tsp_path] = pg_tsp_rpath
|
||||
return links
|
||||
|
||||
def move_data_directory(self):
|
||||
def move_data_directory(self) -> None:
|
||||
if os.path.isdir(self._data_dir) and not self.is_running():
|
||||
try:
|
||||
postfix = 'failed'
|
||||
@@ -1191,7 +1212,7 @@ class Postgresql(object):
|
||||
except OSError:
|
||||
logger.exception("Could not rename data directory %s", self._data_dir)
|
||||
|
||||
def remove_data_directory(self):
|
||||
def remove_data_directory(self) -> None:
|
||||
self.set_role('uninitialized')
|
||||
logger.info('Removing data directory: %s', self._data_dir)
|
||||
try:
|
||||
@@ -1219,7 +1240,7 @@ class Postgresql(object):
|
||||
logger.exception('Could not remove data directory %s', self._data_dir)
|
||||
self.move_data_directory()
|
||||
|
||||
def schedule_sanity_checks_after_pause(self):
|
||||
def schedule_sanity_checks_after_pause(self) -> None:
|
||||
"""
|
||||
After coming out of pause we have to:
|
||||
1. configure server parameters if necessary
|
||||
@@ -1229,4 +1250,4 @@ class Postgresql(object):
|
||||
self.ensure_major_version_is_known()
|
||||
self.slots_handler.schedule()
|
||||
self.citus_handler.schedule_cache_rebuild()
|
||||
self._sysid = None
|
||||
self._sysid = ''
|
||||
|
||||
@@ -3,34 +3,39 @@ import os
import shlex
import tempfile
import time
from typing import List, Dict, Union, Callable, Tuple

from ..dcs import RemoteMember
from typing import Any, Callable, Dict, List, Optional, Union, Tuple, TYPE_CHECKING

from ..async_executor import CriticalTask
from ..dcs import Leader, Member, RemoteMember
from ..psycopg import quote_ident, quote_literal
from ..utils import deep_compare, unquote

if TYPE_CHECKING: # pragma: no cover
from . import Postgresql

logger = logging.getLogger(__name__)


class Bootstrap(object):

def __init__(self, postgresql):
def __init__(self, postgresql: 'Postgresql') -> None:
self._postgresql = postgresql
self._running_custom_bootstrap = False

@property
def running_custom_bootstrap(self):
def running_custom_bootstrap(self) -> bool:
return self._running_custom_bootstrap

@property
def keep_existing_recovery_conf(self):
def keep_existing_recovery_conf(self) -> bool:
return self._running_custom_bootstrap and self._keep_existing_recovery_conf

@staticmethod
def process_user_options(tool: str,
options: Union[Dict[str, str], List[Union[str, Dict[str, str]]]],
options: Union[Any, Dict[str, str], List[Union[str, Dict[str, Any]]]],
not_allowed_options: Tuple[str, ...],
error_handler: Callable[[str], None]) -> List:
error_handler: Callable[[str], None]) -> List[str]:
"""Format *options* in a list or dictionary format into command line long form arguments.

.. note::
@@ -77,9 +82,9 @@ class Bootstrap(object):
:param error_handler: A function which will be called when an error condition is encountered
:returns: List of long form arguments to pass to the named tool
"""
user_options = []
user_options: List[str] = []

def option_is_allowed(name):
def option_is_allowed(name: str) -> bool:
ret = name not in not_allowed_options
if not ret:
error_handler('{0} option for {1} is not allowed'.format(name, tool))
@@ -106,11 +111,11 @@ class Bootstrap(object):
error_handler('{0} options must be list or dict'.format(tool))
return user_options
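As a quick illustration of the conversion described in the `process_user_options` docstring above (the option names and values below are invented for the example, not taken from the diff), both input shapes should collapse to the same kind of long-form argument list:

    # A hedged sketch: dict form and list form of bootstrap options.
    initdb_as_dict = {'encoding': 'UTF8', 'waldir': '/pg_wal'}
    initdb_as_list = ['data-checksums', {'encoding': 'UTF8'}]
    # Expected results, roughly:
    #   ['--encoding=UTF8', '--waldir=/pg_wal']
    #   ['--data-checksums', '--encoding=UTF8']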
def _initdb(self, config):
def _initdb(self, config: Any) -> bool:
self._postgresql.set_state('initializing new cluster')
not_allowed_options = ('pgdata', 'nosync', 'pwfile', 'sync-only', 'version')

def error_handler(e):
def error_handler(e: str) -> None:
raise Exception(e)

options = self.process_user_options('initdb', config or [], not_allowed_options, error_handler)
@@ -134,7 +139,7 @@ class Bootstrap(object):
self._postgresql.set_state('initdb failed')
return ret

def _post_restore(self):
def _post_restore(self) -> None:
self._postgresql.config.restore_configuration_files()
self._postgresql.configure_server_parameters()

@@ -145,7 +150,7 @@ class Bootstrap(object):
if os.path.exists(trigger_file):
os.unlink(trigger_file)

def _custom_bootstrap(self, config):
def _custom_bootstrap(self, config: Any) -> bool:
self._postgresql.set_state('running custom bootstrap script')
params = [] if config.get('no_params') else ['--scope=' + self._postgresql.scope,
'--datadir=' + self._postgresql.data_dir]
@@ -165,7 +170,7 @@ class Bootstrap(object):
self._postgresql.config.remove_recovery_conf()
return True

def call_post_bootstrap(self, config):
def call_post_bootstrap(self, config: Dict[str, Any]) -> bool:
"""
runs a script after initdb or custom bootstrap script is called and waits until completion.
"""
@@ -192,7 +197,7 @@ class Bootstrap(object):
return False
return True

def create_replica(self, clone_member):
def create_replica(self, clone_member: Union[Leader, Member, None]) -> Optional[int]:
"""
create the replica according to the replica_method
defined by the user. this is a list, so we need to
@@ -279,7 +284,7 @@ class Bootstrap(object):
self._postgresql.set_state('stopped')
return ret

def basebackup(self, conn_url, env, options):
def basebackup(self, conn_url: str, env: Dict[str, str], options: Dict[str, Any]) -> Optional[int]:
# creates a replica data dir using pg_basebackup.
# this is the default, built-in create_replica_methods
# tries twice, then returns failure (as 1)
@@ -314,7 +319,7 @@ class Bootstrap(object):

return ret

def clone(self, clone_member):
def clone(self, clone_member: Union[Leader, Member, None]) -> bool:
"""
- initialize the replica from an existing member (primary or replica)
- initialize the replica using the replica creation method that
@@ -327,7 +332,7 @@ class Bootstrap(object):
self._post_restore()
return ret

def bootstrap(self, config):
def bootstrap(self, config: Dict[str, Any]) -> bool:
""" Initialize a new node from scratch and start it. """
pg_hba = config.get('pg_hba', [])
method = config.get('method') or 'initdb'
@@ -339,9 +344,9 @@ class Bootstrap(object):
method = 'initdb'
do_initialize = self._initdb
return do_initialize(config.get(method)) and self._postgresql.config.append_pg_hba(pg_hba) \
and self._postgresql.config.save_configuration_files() and self._postgresql.start()
and self._postgresql.config.save_configuration_files() and bool(self._postgresql.start())

def create_or_update_role(self, name, password, options):
def create_or_update_role(self, name: str, password: Optional[str], options: List[str]) -> None:
options = list(map(str.upper, options))
if 'NOLOGIN' not in options and 'LOGIN' not in options:
options.append('LOGIN')
@@ -371,7 +376,7 @@ END;$$""".format(quote_literal(name), quote_ident(name, self._postgresql.connect
self._postgresql.query('RESET log_statement')
self._postgresql.query('RESET pg_stat_statements.track_utility')

def post_bootstrap(self, config, task):
def post_bootstrap(self, config: Dict[str, Any], task: CriticalTask) -> Optional[bool]:
try:
postgresql = self._postgresql
superuser = postgresql.config.superuser
@@ -17,7 +17,7 @@ class CallbackAction(str, Enum):
ON_RELOAD = "on_reload"
ON_ROLE_CHANGE = "on_role_change"

def __repr__(self):
def __repr__(self) -> str:
return self.value


@@ -60,15 +60,17 @@ class CallbackExecutor(CancellableExecutor, Thread):
self._cmd = cmd
self._condition.notify()

def run(self):
def run(self) -> None:
while True:
with self._condition:
if self._cmd is None:
self._condition.wait()
cmd, self._cmd = self._cmd, None

with self._lock:
if not self._start_process(cmd, close_fds=True):
continue
self._process.wait()
self._kill_children()
if cmd is not None:
with self._lock:
if not self._start_process(cmd, close_fds=True):
continue
if self._process:
self._process.wait()
self._kill_children()
@@ -5,6 +5,7 @@ import subprocess
from patroni.exceptions import PostgresException
from patroni.utils import polling_loop
from threading import Lock
from typing import Any, Dict, List, Optional, Union

logger = logging.getLogger(__name__)

@@ -15,13 +16,13 @@ class CancellableExecutor(object):
There must be only one such process so that AsyncExecutor can easily cancel it.
"""

def __init__(self):
def __init__(self) -> None:
self._process = None
self._process_cmd = None
self._process_children = []
self._process_children: List[psutil.Process] = []
self._lock = Lock()

def _start_process(self, cmd, *args, **kwargs):
def _start_process(self, cmd: List[str], *args: Any, **kwargs: Any) -> Optional[bool]:
"""This method must be executed only when the `_lock` is acquired"""

try:
@@ -32,7 +33,7 @@ class CancellableExecutor(object):
return logger.exception('Failed to execute %s', cmd)
return True

def _kill_process(self):
def _kill_process(self) -> None:
with self._lock:
if self._process is not None and self._process.is_running() and not self._process_children:
try:
@@ -53,8 +54,8 @@ class CancellableExecutor(object):
except psutil.AccessDenied as e:
logger.warning('Failed to kill the process: %s', e.msg)

def _kill_children(self):
waitlist = []
def _kill_children(self) -> None:
waitlist: List[psutil.Process] = []
with self._lock:
for child in self._process_children:
try:
@@ -69,15 +70,16 @@ class CancellableExecutor(object):

class CancellableSubprocess(CancellableExecutor):

def __init__(self):
def __init__(self) -> None:
super(CancellableSubprocess, self).__init__()
self._is_cancelled = False

def call(self, *args, **kwargs):
def call(self, *args: Any, **kwargs: Union[Any, Dict[str, str]]) -> Optional[int]:
for s in ('stdin', 'stdout', 'stderr'):
kwargs.pop(s, None)

communicate = kwargs.pop('communicate', None)
communicate: Optional[Dict[str, str]] = kwargs.pop('communicate', None)
input_data = None
if isinstance(communicate, dict):
input_data = communicate.get('input')
if input_data:
@@ -96,7 +98,7 @@ class CancellableSubprocess(CancellableExecutor):
self._is_cancelled = False
started = self._start_process(*args, **kwargs)

if started:
if started and self._process is not None:
if isinstance(communicate, dict):
communicate['stdout'], communicate['stderr'] = self._process.communicate(input_data)
return self._process.wait()
@@ -105,16 +107,16 @@ class CancellableSubprocess(CancellableExecutor):
self._process = None
self._kill_children()

def reset_is_cancelled(self):
def reset_is_cancelled(self) -> None:
with self._lock:
self._is_cancelled = False

@property
def is_cancelled(self):
def is_cancelled(self) -> bool:
with self._lock:
return self._is_cancelled

def cancel(self, kill=False):
def cancel(self, kill: bool = False) -> None:
with self._lock:
self._is_cancelled = True
if self._process is None or not self._process.is_running():
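A usage sketch for the typed `call()` above (the command and stdin payload are made up): the optional `communicate` dictionary both feeds stdin to the child process and receives the captured output once it exits.

    executor = CancellableSubprocess()
    communicate = {'input': 'line passed to stdin\n'}      # consumed as the child's stdin
    returncode = executor.call(['/usr/local/bin/my_replica_script', '--datadir=/pgdata'],
                               communicate=communicate)
    # communicate['stdout'] and communicate['stderr'] now hold the captured output;
    # returncode is None if the process could not be started.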
@@ -4,11 +4,17 @@ import time

from threading import Condition, Event, Thread
from urllib.parse import urlparse
from typing import Any, Dict, List, Optional, Union, Tuple, TYPE_CHECKING

from .connection import Connection
from ..dcs import CITUS_COORDINATOR_GROUP_ID
from ..dcs import CITUS_COORDINATOR_GROUP_ID, Cluster
from ..psycopg import connect, quote_ident

if TYPE_CHECKING: # pragma: no cover
from psycopg import Cursor
from psycopg2 import cursor
from . import Postgresql

CITUS_SLOT_NAME_RE = re.compile(r'^citus_shard_(move|split)_slot(_[1-9][0-9]*){2,3}$')
logger = logging.getLogger(__name__)

@@ -16,7 +22,8 @@ logger = logging.getLogger(__name__)
class PgDistNode(object):
"""Represents a single row in the `pg_dist_node` table"""

def __init__(self, group, host, port, event, nodeid=None, timeout=None, cooldown=None):
def __init__(self, group: int, host: str, port: int, event: str, nodeid: Optional[int] = None,
timeout: Optional[float] = None, cooldown: Optional[float] = None) -> None:
self.group = group
# A weird way of pausing client connections by adding the `-demoted` suffix to the hostname
self.host = host + ('-demoted' if event == 'before_demote' else '')
@@ -29,7 +36,7 @@ class PgDistNode(object):
# If transaction was started, we need to COMMIT/ROLLBACK before the deadline
self.timeout = timeout
self.cooldown = cooldown or 10000 # 10s by default
self.deadline = 0
self.deadline: float = 0

# All changes in the pg_dist_node are serialized on the Patroni
# side by performing them from a thread. The thread, that is
@@ -38,74 +45,74 @@ class PgDistNode(object):
# the worker, and once it is done notify the calling thread.
self._event = Event()

def wait(self):
def wait(self) -> None:
self._event.wait()

def wakeup(self):
def wakeup(self) -> None:
self._event.set()

def __eq__(self, other):
def __eq__(self, other: Any) -> bool:
return isinstance(other, PgDistNode) and self.event == other.event\
and self.host == other.host and self.port == other.port

def __ne__(self, other):
def __ne__(self, other: Any) -> bool:
return not self == other

def __str__(self):
def __str__(self) -> str:
return ('PgDistNode(nodeid={0},group={1},host={2},port={3},event={4})'
.format(self.nodeid, self.group, self.host, self.port, self.event))

def __repr__(self):
def __repr__(self) -> str:
return str(self)


class CitusHandler(Thread):

def __init__(self, postgresql, config):
def __init__(self, postgresql: 'Postgresql', config: Optional[Dict[str, Union[str, int]]]) -> None:
super(CitusHandler, self).__init__()
self.daemon = True
self._postgresql = postgresql
self._config = config
self._connection = Connection()
self._pg_dist_node = {} # Cache of pg_dist_node: {groupid: PgDistNode()}
self._tasks = [] # Requests to change pg_dist_node, every task is a `PgDistNode`
self._pg_dist_node: Dict[int, PgDistNode] = {} # Cache of pg_dist_node: {groupid: PgDistNode()}
self._tasks: List[PgDistNode] = [] # Requests to change pg_dist_node, every task is a `PgDistNode`
self._condition = Condition() # protects _pg_dist_node, _tasks, and _schedule_load_pg_dist_node
self._in_flight = None # Reference to the `PgDistNode` if there is a transaction in progress changing it
self._in_flight: Optional[PgDistNode] = None # Reference to the `PgDistNode` being changed in a transaction
self.schedule_cache_rebuild()
def is_enabled(self):
def is_enabled(self) -> bool:
return isinstance(self._config, dict)

def group(self):
return self._config['group']
def group(self) -> Optional[int]:
return int(self._config['group']) if isinstance(self._config, dict) else None

def is_coordinator(self):
def is_coordinator(self) -> bool:
return self.is_enabled() and self.group() == CITUS_COORDINATOR_GROUP_ID

def is_worker(self):
def is_worker(self) -> bool:
return self.is_enabled() and not self.is_coordinator()

def set_conn_kwargs(self, kwargs):
if self.is_enabled():
def set_conn_kwargs(self, kwargs: Dict[str, Any]) -> None:
if isinstance(self._config, dict): # self.is_enabled():
kwargs.update({'dbname': self._config['database'],
'options': '-c statement_timeout=0 -c idle_in_transaction_session_timeout=0'})
self._connection.set_conn_kwargs(kwargs)

def schedule_cache_rebuild(self):
def schedule_cache_rebuild(self) -> None:
with self._condition:
self._schedule_load_pg_dist_node = True

def on_demote(self):
def on_demote(self) -> None:
with self._condition:
self._pg_dist_node.clear()
self._tasks[:] = []
self._in_flight = None

def query(self, sql, *params):
def query(self, sql: str, *params: Any) -> Union['Cursor[Any]', 'cursor']:
try:
logger.debug('query(%s, %s)', sql, params)
cursor = self._connection.cursor()
cursor.execute(sql, params or None)
cursor.execute(sql.encode('utf-8'), params or None)
return cursor
except Exception as e:
logger.error('Exception when executing query "%s", (%s): %r', sql, params, e)
@@ -114,7 +121,7 @@ class CitusHandler(Thread):
self.schedule_cache_rebuild()
raise e

def load_pg_dist_node(self):
def load_pg_dist_node(self) -> bool:
"""Read from the `pg_dist_node` table and put it into the local cache"""

with self._condition:
@@ -132,7 +139,7 @@ class CitusHandler(Thread):
self._pg_dist_node = {r[1]: PgDistNode(r[1], r[2], r[3], 'after_promote', r[0]) for r in cursor}
return True

def sync_pg_dist_node(self, cluster):
def sync_pg_dist_node(self, cluster: Cluster) -> None:
"""Maintain the `pg_dist_node` from the coordinator leader every heartbeat loop.

We can't always rely on REST API calls from worker nodes in order
@@ -156,12 +163,12 @@ class CitusHandler(Thread):
and leader.data.get('role') in ('master', 'primary') and leader.data.get('state') == 'running':
self.add_task('after_promote', group, leader.conn_url)

def find_task_by_group(self, group):
def find_task_by_group(self, group: int) -> Optional[int]:
for i, task in enumerate(self._tasks):
if task.group == group:
return i

def pick_task(self):
def pick_task(self) -> Tuple[Optional[int], Optional[PgDistNode]]:
"""Returns the tuple(i, task), where `i` - is the task index in the self._tasks list

Tasks are picked by following priorities:
@@ -195,15 +202,17 @@ class CitusHandler(Thread):
task.nodeid = self._pg_dist_node[task.group].nodeid
return i, task

def update_node(self, task):
def update_node(self, task: PgDistNode) -> None:
if task.nodeid is not None:
self.query('SELECT pg_catalog.citus_update_node(%s, %s, %s, true, %s)',
task.nodeid, task.host, task.port, task.cooldown)
elif task.event != 'before_demote':
task.nodeid = self.query("SELECT pg_catalog.citus_add_node(%s, %s, %s, 'primary', 'default')",
task.host, task.port, task.group).fetchone()[0]
row = self.query("SELECT pg_catalog.citus_add_node(%s, %s, %s, 'primary', 'default')",
task.host, task.port, task.group).fetchone()
if row is not None:
task.nodeid = row[0]

def process_task(self, task):
def process_task(self, task: PgDistNode) -> bool:
"""Updates a single row in `pg_dist_node` table, optionally in a transaction.

The transaction is started if we do a demote of the worker node
@@ -238,13 +247,13 @@ class CitusHandler(Thread):
self._in_flight = task
return False

def process_tasks(self):
def process_tasks(self) -> None:
while True:
if not self._in_flight and not self.load_pg_dist_node():
break

i, task = self.pick_task()
if not task:
if not task or i is None:
break
try:
update_cache = self.process_task(task)
@@ -259,7 +268,7 @@ class CitusHandler(Thread):
self._tasks.pop(i)
task.wakeup()

def run(self):
def run(self) -> None:
while True:
try:
with self._condition:
@@ -280,7 +289,7 @@ class CitusHandler(Thread):
except Exception:
logger.exception('run')

def _add_task(self, task):
def _add_task(self, task: PgDistNode) -> bool:
with self._condition:
i = self.find_task_by_group(task.group)

@@ -306,32 +315,34 @@ class CitusHandler(Thread):
return True
return False
def add_task(self, event, group, conn_url, timeout=None, cooldown=None):
def add_task(self, event: str, group: int, conn_url: str,
timeout: Optional[float] = None, cooldown: Optional[float] = None) -> Optional[PgDistNode]:
try:
r = urlparse(conn_url)
except Exception as e:
return logger.error('Failed to parse connection url %s: %r', conn_url, e)
host = r.hostname
port = r.port or 5432
task = PgDistNode(group, host, port, event, timeout=timeout, cooldown=cooldown)
return task if self._add_task(task) else None
if host:
port = r.port or 5432
task = PgDistNode(group, host, port, event, timeout=timeout, cooldown=cooldown)
return task if self._add_task(task) else None

def handle_event(self, cluster, event):
def handle_event(self, cluster: Cluster, event: Dict[str, Any]) -> None:
if not self.is_alive():
return

cluster = cluster.workers.get(event['group'])
if not (cluster and cluster.leader and cluster.leader.name == event['leader'] and cluster.leader.conn_url):
worker = cluster.workers.get(event['group'])
if not (worker and worker.leader and worker.leader.name == event['leader'] and worker.leader.conn_url):
return

task = self.add_task(event['type'], event['group'],
cluster.leader.conn_url,
worker.leader.conn_url,
event['timeout'], event['cooldown'] * 1000)
if task and event['type'] == 'before_demote':
task.wait()

def bootstrap(self):
if not self.is_enabled():
def bootstrap(self) -> None:
if not isinstance(self._config, dict): # self.is_enabled()
return

conn_kwargs = self._postgresql.config.local_connect_kwargs
@@ -340,7 +351,8 @@ class CitusHandler(Thread):
conn = connect(**conn_kwargs)
try:
with conn.cursor() as cur:
cur.execute('CREATE DATABASE {0}'.format(quote_ident(self._config['database'], conn)))
cur.execute('CREATE DATABASE {0}'.format(
quote_ident(self._config['database'], conn)).encode('utf-8'))
finally:
conn.close()

@@ -364,7 +376,7 @@ class CitusHandler(Thread):
finally:
conn.close()

def adjust_postgres_gucs(self, parameters):
def adjust_postgres_gucs(self, parameters: Dict[str, Any]) -> None:
if not self.is_enabled():
return

@@ -381,9 +393,9 @@ class CitusHandler(Thread):
# Resharding in Citus implemented using logical replication
parameters['wal_level'] = 'logical'

def ignore_replication_slot(self, slot):
if self.is_enabled() and self._postgresql.is_leader() and\
def ignore_replication_slot(self, slot: Dict[str, str]) -> bool:
if isinstance(self._config, dict) and self._postgresql.is_leader() and\
slot['type'] == 'logical' and slot['database'] == self._config['database']:
m = CITUS_SLOT_NAME_RE.match(slot['name'])
return m and {'move': 'pgoutput', 'split': 'citus'}.get(m.group(1)) == slot['plugin']
return bool(m and {'move': 'pgoutput', 'split': 'citus'}.get(m.group(1)) == slot['plugin'])
return False
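A short sketch of how the typed `add_task()` above is typically driven (the group numbers and connection URLs are placeholders): every change to `pg_dist_node` is queued as a `PgDistNode` task that the handler's worker thread applies later, and `before_demote` callers block on `task.wait()` just as `handle_event()` does.

    handler.add_task('after_promote', 1, 'postgres://10.0.0.7:5432/postgres')
    task = handler.add_task('before_demote', 2, 'postgres://10.0.0.8:5432/postgres',
                            timeout=30, cooldown=10 * 1000)
    if task:
        task.wait()   # returns once the worker thread has processed the change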
@@ -7,21 +7,26 @@ import stat
import time

from urllib.parse import urlparse, parse_qsl, unquote
from types import TracebackType
from typing import Any, Collection, Dict, List, Optional, Union, Tuple, Type, TYPE_CHECKING

from .validator import recovery_parameters, transform_postgresql_parameter_value, transform_recovery_parameter_value
from ..collections import CaseInsensitiveDict
from ..dcs import RemoteMember, slot_name_from_member_name
from ..collections import CaseInsensitiveDict, CaseInsensitiveSet
from ..dcs import Leader, Member, RemoteMember, slot_name_from_member_name
from ..exceptions import PatroniFatalException
from ..utils import compare_values, parse_bool, parse_int, split_host_port, uri, \
validate_directory, is_subpath
from ..utils import compare_values, parse_bool, parse_int, split_host_port, uri, validate_directory, is_subpath
from ..validator import IntValidator

if TYPE_CHECKING: # pragma: no cover
from . import Postgresql

logger = logging.getLogger(__name__)

PARAMETER_RE = re.compile(r'([a-z_]+)\s*=\s*')


def conninfo_uri_parse(dsn):
ret = {}
def conninfo_uri_parse(dsn: str) -> Dict[str, str]:
ret: Dict[str, str] = {}
r = urlparse(dsn)
if r.username:
ret['user'] = r.username
@@ -29,21 +34,18 @@ def conninfo_uri_parse(dsn):
ret['password'] = r.password
if r.path[1:]:
ret['dbname'] = r.path[1:]
hosts = []
ports = []
hosts: List[str] = []
ports: List[str] = []
for netloc in r.netloc.split('@')[-1].split(','):
host = port = None
host = None
if '[' in netloc and ']' in netloc:
host = netloc.split(']')[0][1:]
tmp = netloc.split(':', 1)
if host is None:
host = tmp[0]
hosts.append(host)
if len(tmp) == 2:
host, port = tmp
if host is not None:
hosts.append(host)
if port is not None:
ports.append(port)
ports.append(tmp[1])
if hosts:
ret['host'] = ','.join(hosts)
if ports:
@@ -56,7 +58,7 @@ def conninfo_uri_parse(dsn):
return ret
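The reworked host/port loop above is easier to follow with a concrete value (a made-up multi-host URI): each `host:port` element of the netloc is collected separately and the pieces are re-joined with commas.

    conninfo_uri_parse('postgres://repl:secret@node1.example.com:5433,node2.example.com:5434/postgres')
    # roughly -> {'user': 'repl', 'password': 'secret', 'dbname': 'postgres',
    #             'host': 'node1.example.com,node2.example.com', 'port': '5433,5434'}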
def read_param_value(value):
def read_param_value(value: str) -> Union[Tuple[None, None], Tuple[str, int]]:
length = len(value)
ret = ''
is_quoted = value[0] == "'"
@@ -76,8 +78,8 @@ def read_param_value(value):
return (None, None) if is_quoted else (ret, i)


def conninfo_parse(dsn):
ret = {}
def conninfo_parse(dsn: str) -> Optional[Dict[str, str]]:
ret: Dict[str, str] = {}
length = len(dsn)
i = 0
while i < length:
@@ -96,14 +98,14 @@ def conninfo_parse(dsn):
return

value, end = read_param_value(dsn[i:])
if value is None:
if value is None or end is None:
return
i += end
ret[param] = value
return ret


def parse_dsn(value):
def parse_dsn(value: str) -> Optional[Dict[str, str]]:
"""
Very simple equivalent of `psycopg2.extensions.parse_dsn` introduced in 2.7.0.
We are not using psycopg2 function in order to remain compatible with 2.5.4+.
@@ -147,14 +149,14 @@ def parse_dsn(value):
return ret


def strip_comment(value):
def strip_comment(value: str) -> str:
i = value.find('#')
if i > -1:
value = value[:i].strip()
return value


def read_recovery_param_value(value):
def read_recovery_param_value(value: str) -> Optional[str]:
"""
>>> read_recovery_param_value('') is None
True
@@ -211,7 +213,7 @@ def read_recovery_param_value(value):
return value


def mtime(filename):
def mtime(filename: str) -> Optional[float]:
try:
return os.stat(filename).st_mtime
except OSError:
@@ -220,35 +222,49 @@ def mtime(filename):

class ConfigWriter(object):

def __init__(self, filename):
def __init__(self, filename: str) -> None:
self._filename = filename
self._fd = None

def __enter__(self):
def __enter__(self) -> 'ConfigWriter':
self._fd = open(self._filename, 'w')
self.writeline('# Do not edit this file manually!\n# It will be overwritten by Patroni!')
return self

def __exit__(self, exc_type, exc_val, exc_tb):
def __exit__(self, exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]) -> None:
if self._fd:
self._fd.close()

def writeline(self, line):
self._fd.write(line)
self._fd.write('\n')
def writeline(self, line: str) -> None:
if self._fd:
self._fd.write(line)
self._fd.write('\n')

def writelines(self, lines):
def writelines(self, lines: List[str]) -> None:
for line in lines:
self.writeline(line)

@staticmethod
def escape(value): # Escape (by doubling) any single quotes or backslashes in given string
def escape(value: Any) -> str: # Escape (by doubling) any single quotes or backslashes in given string
return re.sub(r'([\'\\])', r'\1\1', str(value))

def write_param(self, param, value):
def write_param(self, param: str, value: Any) -> None:
self.writeline("{0} = '{1}'".format(param, self.escape(value)))
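A minimal usage sketch for the `ConfigWriter` context manager typed above (the path and values are illustrative): parameters are written single-quoted and `escape()` doubles embedded quotes and backslashes.

    with ConfigWriter('/tmp/patroni_demo.conf') as cw:
        cw.write_param('cluster_name', "bob's cluster")
        cw.writelines(["wal_level = 'replica'"])
    # The file starts with the "Do not edit this file manually!" banner and contains:
    #   cluster_name = 'bob''s cluster'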
def _false_validator(value: Any) -> bool:
return False


def _wal_level_validator(value: Any) -> bool:
return str(value).lower() in ('hot_standby', 'replica', 'logical')


def _bool_validator(value: Any) -> bool:
return parse_bool(value) is not None


class ConfigHandler(object):

# List of parameters which must be always passed to postmaster as command line options
@@ -266,28 +282,28 @@ class ConfigHandler(object):
# check_function -- if the new value is not correct must return `!False`
# min_version -- major version of PostgreSQL when parameter was introduced
CMDLINE_OPTIONS = CaseInsensitiveDict({
'listen_addresses': (None, lambda _: False, 90100),
'port': (None, lambda _: False, 90100),
'cluster_name': (None, lambda _: False, 90500),
'wal_level': ('hot_standby', lambda v: v.lower() in ('hot_standby', 'replica', 'logical'), 90100),
'hot_standby': ('on', lambda _: False, 90100),
'max_connections': (100, lambda v: int(v) >= 25, 90100),
'max_wal_senders': (10, lambda v: int(v) >= 3, 90100),
'wal_keep_segments': (8, lambda v: int(v) >= 1, 90100),
'wal_keep_size': ('128MB', lambda v: parse_int(v, 'MB') >= 16, 130000),
'max_prepared_transactions': (0, lambda v: int(v) >= 0, 90100),
'max_locks_per_transaction': (64, lambda v: int(v) >= 32, 90100),
'track_commit_timestamp': ('off', lambda v: parse_bool(v) is not None, 90500),
'max_replication_slots': (10, lambda v: int(v) >= 4, 90400),
'max_worker_processes': (8, lambda v: int(v) >= 2, 90400),
'wal_log_hints': ('on', lambda _: False, 90400)
'listen_addresses': (None, _false_validator, 90100),
'port': (None, _false_validator, 90100),
'cluster_name': (None, _false_validator, 90500),
'wal_level': ('hot_standby', _wal_level_validator, 90100),
'hot_standby': ('on', _false_validator, 90100),
'max_connections': (100, IntValidator(min=25), 90100),
'max_wal_senders': (10, IntValidator(min=3), 90100),
'wal_keep_segments': (8, IntValidator(min=1), 90100),
'wal_keep_size': ('128MB', IntValidator(min=16, base_unit='MB'), 130000),
'max_prepared_transactions': (0, IntValidator(min=0), 90100),
'max_locks_per_transaction': (64, IntValidator(min=32), 90100),
'track_commit_timestamp': ('off', _bool_validator, 90500),
'max_replication_slots': (10, IntValidator(min=4), 90400),
'max_worker_processes': (8, IntValidator(min=2), 90400),
'wal_log_hints': ('on', _false_validator, 90400)
})

_RECOVERY_PARAMETERS = set(recovery_parameters.keys())
_RECOVERY_PARAMETERS = CaseInsensitiveSet(recovery_parameters.keys())

def __init__(self, postgresql, config):
def __init__(self, postgresql: 'Postgresql', config: Dict[str, Any]) -> None:
self._postgresql = postgresql
self._config_dir = os.path.abspath(config.get('config_dir') or postgresql.data_dir)
self._config_dir = os.path.abspath(config.get('config_dir', '') or postgresql.data_dir)
config_base_name = config.get('config_base_name', 'postgresql')
self._postgresql_conf = os.path.join(self._config_dir, config_base_name + '.conf')
self._postgresql_conf_mtime = None
@@ -309,21 +325,22 @@ class ConfigHandler(object):
self._passfile_mtime = None
self._synchronous_standby_names = None
self._postmaster_ctime = None
self._current_recovery_params = None
self._current_recovery_params: Optional[CaseInsensitiveDict] = None
self._config = {}
self._recovery_params = {}
self._recovery_params = CaseInsensitiveDict()
self._server_parameters: CaseInsensitiveDict
self.reload_config(config)

def setup_server_parameters(self):
def setup_server_parameters(self) -> None:
self._server_parameters = self.get_server_parameters(self._config)
self._adjust_recovery_parameters()

def try_to_create_dir(self, d, msg):
d = os.path.join(self._postgresql._data_dir, d)
if (not is_subpath(self._postgresql._data_dir, d) or not self._postgresql.data_directory_empty()):
def try_to_create_dir(self, d: str, msg: str) -> None:
d = os.path.join(self._postgresql.data_dir, d)
if (not is_subpath(self._postgresql.data_dir, d) or not self._postgresql.data_directory_empty()):
validate_directory(d, msg)

def check_directories(self):
def check_directories(self) -> None:
if "unix_socket_directories" in self._server_parameters:
for d in self._server_parameters["unix_socket_directories"].split(","):
self.try_to_create_dir(d.strip(), "'{}' is defined in unix_socket_directories, {}")
@@ -335,7 +352,11 @@ class ConfigHandler(object):
"'{}' is defined in `postgresql.pgpass`, {}")
@property
def _configuration_to_save(self):
def config_dir(self) -> str:
return self._config_dir

@property
def _configuration_to_save(self) -> List[str]:
configuration = [os.path.basename(self._postgresql_conf)]
if 'custom_conf' not in self._config:
configuration.append(os.path.basename(self._postgresql_base_conf_name))
@@ -345,7 +366,7 @@ class ConfigHandler(object):
configuration.append('pg_ident.conf')
return configuration

def save_configuration_files(self, check_custom_bootstrap=False):
def save_configuration_files(self, check_custom_bootstrap: bool = False) -> bool:
"""
copy postgresql.conf to postgresql.conf.backup to be able to retrieve configuration files
- originally stored as symlinks, those are normally skipped by pg_basebackup
@@ -362,7 +383,7 @@ class ConfigHandler(object):
logger.exception('unable to create backup copies of configuration files')
return True

def restore_configuration_files(self):
def restore_configuration_files(self) -> None:
""" restore a previously saved postgresql.conf """
try:
for f in self._configuration_to_save:
@@ -377,7 +398,7 @@ class ConfigHandler(object):
except IOError:
logger.exception('unable to restore configuration files from backup')

def write_postgresql_conf(self, configuration=None):
def write_postgresql_conf(self, configuration: Optional[CaseInsensitiveDict] = None) -> None:
# rename the original configuration if it is necessary
if 'custom_conf' not in self._config and not os.path.exists(self._postgresql_base_conf):
os.rename(self._postgresql_conf, self._postgresql_base_conf)
@@ -412,13 +433,13 @@ class ConfigHandler(object):
if not self._postgresql.bootstrap.keep_existing_recovery_conf:
self._sanitize_auto_conf()

def append_pg_hba(self, config):
def append_pg_hba(self, config: List[str]) -> bool:
if not self.hba_file and not self._config.get('pg_hba'):
with open(self._pg_hba_conf, 'a') as f:
f.write('\n{}\n'.format('\n'.join(config)))
return True

def replace_pg_hba(self):
def replace_pg_hba(self) -> Optional[bool]:
"""
Replace pg_hba.conf content in the PGDATA if hba_file is not defined in the
`postgresql.parameters` and pg_hba is defined in `postgresql` configuration section.
@@ -446,7 +467,7 @@ class ConfigHandler(object):
f.writelines(self._config['pg_hba'])
return True

def replace_pg_ident(self):
def replace_pg_ident(self) -> Optional[bool]:
"""
Replace pg_ident.conf content in the PGDATA if ident_file is not defined in the
`postgresql.parameters` and pg_ident is defined in the `postgresql` section.
@@ -459,8 +480,8 @@ class ConfigHandler(object):
f.writelines(self._config['pg_ident'])
return True

def primary_conninfo_params(self, member):
if not (member and member.conn_url) or member.name == self._postgresql.name:
def primary_conninfo_params(self, member: Union[Leader, Member, None]) -> Optional[Dict[str, Any]]:
if not member or not member.conn_url or member.name == self._postgresql.name:
return None
ret = member.conn_kwargs(self.replication)
ret['application_name'] = self._postgresql.name
@@ -475,7 +496,7 @@ class ConfigHandler(object):
del ret['dbname']
return ret

def format_dsn(self, params, include_dbname=False):
def format_dsn(self, params: Dict[str, Any], include_dbname: bool = False) -> str:
# A list of keywords that can be found in a conninfo string. Follows what is acceptable by libpq
keywords = ('dbname', 'user', 'passfile' if params.get('passfile') else 'password', 'host', 'port',
'sslmode', 'sslcompression', 'sslcert', 'sslkey', 'sslpassword', 'sslrootcert', 'sslcrl',
@@ -491,13 +512,13 @@ class ConfigHandler(object):
else:
skip = {'dbname'}

def escape(value):
def escape(value: Any) -> str:
return re.sub(r'([\'\\ ])', r'\\\1', str(value))

return ' '.join('{0}={1}'.format(kw, escape(params[kw])) for kw in keywords
if kw not in skip and params.get(kw) is not None)
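For the typed `format_dsn()` above, a worked illustration with made-up parameters: keywords are emitted in the libpq order of the tuple above and `escape()` backslash-escapes quotes, backslashes and spaces.

    params = {'host': 'node1', 'port': 5432, 'user': 'replicator', 'password': "it's secret"}
    # format_dsn(params) should yield roughly:
    #   user=replicator password=it\'s\ secret host=node1 port=5432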
def _write_recovery_params(self, fd, recovery_params):
def _write_recovery_params(self, fd: ConfigWriter, recovery_params: CaseInsensitiveDict) -> None:
if self._postgresql.major_version >= 90500:
pause_at_recovery_target = parse_bool(recovery_params.pop('pause_at_recovery_target', None))
if pause_at_recovery_target is not None:
@@ -518,8 +539,8 @@ class ConfigHandler(object):
continue
fd.write_param(name, value)

def build_recovery_params(self, member):
recovery_params = CaseInsensitiveDict({p: v for p, v in self.get('recovery_conf', {}).items()
def build_recovery_params(self, member: Union[Leader, Member, None]) -> CaseInsensitiveDict:
recovery_params = CaseInsensitiveDict({p: v for p, v in (self.get('recovery_conf') or {}).items()
if not p.lower().startswith('recovery_target')
and p.lower() not in ('primary_conninfo', 'primary_slot_name')})
recovery_params.update({'standby_mode': 'on', 'recovery_target_timeline': 'latest'})
@@ -550,26 +571,26 @@ class ConfigHandler(object):
recovery_params.update({p: member.data.get(p) for p in standby_cluster_params if member and member.data.get(p)})
return recovery_params

def recovery_conf_exists(self):
def recovery_conf_exists(self) -> bool:
if self._postgresql.major_version >= 120000:
return os.path.exists(self._standby_signal) or os.path.exists(self._recovery_signal)
return os.path.exists(self._recovery_conf)

@property
def triggerfile_good_name(self):
def triggerfile_good_name(self) -> str:
return 'trigger_file' if self._postgresql.major_version < 120000 else 'promote_trigger_file'

@property
def _triggerfile_wrong_name(self):
def _triggerfile_wrong_name(self) -> str:
return 'trigger_file' if self._postgresql.major_version >= 120000 else 'promote_trigger_file'

@property
def _recovery_parameters_to_compare(self):
skip_params = {'pause_at_recovery_target', 'recovery_target_inclusive',
'recovery_target_action', 'standby_mode', self._triggerfile_wrong_name}
return self._RECOVERY_PARAMETERS - skip_params
def _recovery_parameters_to_compare(self) -> CaseInsensitiveSet:
skip_params = CaseInsensitiveSet({'pause_at_recovery_target', 'recovery_target_inclusive',
'recovery_target_action', 'standby_mode', self._triggerfile_wrong_name})
return CaseInsensitiveSet(self._RECOVERY_PARAMETERS - skip_params)

def _read_recovery_params(self):
def _read_recovery_params(self) -> Tuple[Optional[CaseInsensitiveDict], Optional[bool]]:
if self._postgresql.is_starting():
return None, False

@@ -586,7 +607,7 @@ class ConfigHandler(object):

try:
values = self._get_pg_settings(self._recovery_parameters_to_compare).values()
values = {p[0]: [p[1], p[4] == 'postmaster', p[5]] for p in values}
values = CaseInsensitiveDict({p[0]: [p[1], p[4] == 'postmaster', p[5]] for p in values})
self._postgresql_conf_mtime = pg_conf_mtime
self._auto_conf_mtime = auto_conf_mtime
self._postmaster_ctime = postmaster_ctime
@@ -594,13 +615,13 @@ class ConfigHandler(object):
values = None
return values, True

def _read_recovery_params_pre_v12(self):
def _read_recovery_params_pre_v12(self) -> Tuple[Optional[CaseInsensitiveDict], Optional[bool]]:
recovery_conf_mtime = mtime(self._recovery_conf)
passfile_mtime = mtime(self._passfile) if self._passfile else False
if recovery_conf_mtime == self._recovery_conf_mtime and passfile_mtime == self._passfile_mtime:
return None, False

values = {}
values = CaseInsensitiveDict()
with open(self._recovery_conf, 'r') as f:
for line in f:
line = line.strip()
@@ -610,7 +631,7 @@ class ConfigHandler(object):
match = PARAMETER_RE.match(line)
if match:
value = read_recovery_param_value(line[match.end():])
if value is None:
if match is None or value is None:
return None, True
values[match.group(1)] = [value, True]
self._recovery_conf_mtime = recovery_conf_mtime
@@ -619,7 +640,7 @@ class ConfigHandler(object):
values.update({param: ['', True] for param in self._recovery_parameters_to_compare if param not in values})
return values, True
def _check_passfile(self, passfile, wanted_primary_conninfo):
def _check_passfile(self, passfile: str, wanted_primary_conninfo: Dict[str, Any]) -> bool:
# If there is a passfile in the primary_conninfo try to figure out that
# the passfile contains the line(s) allowing connection to the given node.
# We assume that the passfile was created by Patroni and therefore doing
@@ -628,7 +649,7 @@ class ConfigHandler(object):
if passfile_mtime:
try:
with open(passfile) as f:
wanted_lines = self._pgpass_line(wanted_primary_conninfo).splitlines()
wanted_lines = (self._pgpass_line(wanted_primary_conninfo) or '').splitlines()
file_lines = f.read().splitlines()
if set(wanted_lines) == set(file_lines):
self._passfile = passfile
@@ -638,7 +659,8 @@ class ConfigHandler(object):
logger.info('Failed to read %s', passfile)
return False

def _check_primary_conninfo(self, primary_conninfo, wanted_primary_conninfo):
def _check_primary_conninfo(self, primary_conninfo: Dict[str, Any],
wanted_primary_conninfo: Dict[str, Any]) -> bool:
# first we will cover corner cases, when we are replicating from somewhere while shouldn't
# or there is no primary_conninfo but we should replicate from some specific node.
if not wanted_primary_conninfo:
@@ -667,7 +689,7 @@ class ConfigHandler(object):

return all(str(primary_conninfo.get(p)) == str(v) for p, v in wanted_primary_conninfo.items() if v is not None)

def check_recovery_conf(self, member):
def check_recovery_conf(self, member: Union[Leader, Member, None]) -> Tuple[bool, bool]:
"""Returns a tuple. The first boolean element indicates that recovery params don't match
and the second is set to `True` if the restart is required in order to apply new values"""

@@ -701,7 +723,7 @@ class ConfigHandler(object):
else: # empty string, primary_conninfo is not in the config
primary_conninfo[0] = {}

if not self._postgresql.is_starting():
if not self._postgresql.is_starting() and self._current_recovery_params:
# when wal receiver is alive take primary_slot_name from pg_stat_wal_receiver
wal_receiver_primary_slot_name = self._postgresql.primary_slot_name()
if not wal_receiver_primary_slot_name and self._postgresql.primary_conninfo():
@@ -715,11 +737,11 @@ class ConfigHandler(object):
and not self._postgresql.cb_called
and not self._postgresql.is_starting())}

def record_missmatch(mtype):
def record_missmatch(mtype: bool) -> None:
required['restart' if mtype else 'reload'] += 1

wanted_recovery_params = self.build_recovery_params(member)
for param, value in self._current_recovery_params.items():
for param, value in (self._current_recovery_params or {}).items():
# Skip certain parameters defined in the included postgres config files
# if we know that they are not specified in the patroni configuration.
if len(value) > 2 and value[2] not in (self._postgresql_conf, self._auto_conf) and \
@@ -741,25 +763,25 @@ class ConfigHandler(object):
return required['restart'] + required['reload'] > 0, required['restart'] > 0

@staticmethod
def _remove_file_if_exists(name):
def _remove_file_if_exists(name: str) -> None:
if os.path.isfile(name) or os.path.islink(name):
os.unlink(name)

@staticmethod
def _pgpass_line(record):
def _pgpass_line(record: Dict[str, Any]) -> Optional[str]:
if 'password' in record:
def escape(value):
def escape(value: Any) -> str:
return re.sub(r'([:\\])', r'\\\1', str(value))

record = {n: escape(record.get(n) or '*') for n in ('host', 'port', 'user', 'password')}
# 'host' could be several comma-separated hostnames, in this case
# we need to write on pgpass line per host
line = ''
for hostname in record.get('host').split(','):
for hostname in record['host'].split(','):
line += hostname + ':{port}:*:{user}:{password}'.format(**record) + '\n'
return line.rstrip()
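A worked example of the `_pgpass_line()` behaviour described in the comments above (the host names are invented): one pgpass line is produced per comma-separated host, missing fields fall back to `*`, and `:` or `\` inside values are backslash-escaped.

    record = {'host': 'node1,node2', 'port': 5432, 'user': 'replicator', 'password': 'pw:1'}
    # _pgpass_line(record) should return, roughly:
    #   node1:5432:*:replicator:pw\:1
    #   node2:5432:*:replicator:pw\:1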
def write_pgpass(self, record):
def write_pgpass(self, record: Dict[str, Any]) -> Dict[str, str]:
line = self._pgpass_line(record)
if not line:
return os.environ.copy()
@@ -770,7 +792,7 @@ class ConfigHandler(object):

return {**os.environ, 'PGPASSFILE': self._pgpass}

def write_recovery_conf(self, recovery_params):
def write_recovery_conf(self, recovery_params: CaseInsensitiveDict) -> None:
self._recovery_params = recovery_params
if self._postgresql.major_version >= 120000:
if parse_bool(recovery_params.pop('standby_mode', None)):
@@ -779,28 +801,28 @@ class ConfigHandler(object):
self._remove_file_if_exists(self._standby_signal)
open(self._recovery_signal, 'w').close()

def restart_required(name):
def restart_required(name: str) -> bool:
if self._postgresql.major_version >= 140000:
return False
return name == 'restore_command' or (self._postgresql.major_version < 130000
and name in ('primary_conninfo', 'primary_slot_name'))

self._current_recovery_params = {n: [v, restart_required(n), self._postgresql_conf]
for n, v in recovery_params.items()}
self._current_recovery_params = CaseInsensitiveDict({n: [v, restart_required(n), self._postgresql_conf]
for n, v in recovery_params.items()})
else:
with ConfigWriter(self._recovery_conf) as f:
os.chmod(self._recovery_conf, stat.S_IWRITE | stat.S_IREAD)
self._write_recovery_params(f, recovery_params)

def remove_recovery_conf(self):
def remove_recovery_conf(self) -> None:
for name in (self._recovery_conf, self._standby_signal, self._recovery_signal):
self._remove_file_if_exists(name)
self._recovery_params = {}
self._recovery_params = CaseInsensitiveDict()
self._current_recovery_params = None

def _sanitize_auto_conf(self):
def _sanitize_auto_conf(self) -> None:
overwrite = False
lines = []
lines: List[str] = []

if os.path.exists(self._auto_conf):
try:
@@ -823,7 +845,7 @@ class ConfigHandler(object):
except Exception:
logger.exception('Failed to remove some unwanted parameters from %s', self._auto_conf)

def _adjust_recovery_parameters(self):
def _adjust_recovery_parameters(self) -> None:
# It is not strictly necessary, but we can make patroni configs crossi-compatible with all postgres versions.
recovery_conf = {n: v for n, v in self._server_parameters.items() if n.lower() in self._RECOVERY_PARAMETERS}
if recovery_conf:
@@ -834,7 +856,7 @@ class ConfigHandler(object):
if self.triggerfile_good_name not in self._config['recovery_conf'] and value:
self._config['recovery_conf'][self.triggerfile_good_name] = value

def get_server_parameters(self, config):
def get_server_parameters(self, config: Dict[str, Any]) -> CaseInsensitiveDict:
parameters = config['parameters'].copy()
listen_addresses, port = split_host_port(config['listen'], 5432)
parameters.update(cluster_name=self._postgresql.scope, listen_addresses=listen_addresses, port=str(port))
@@ -860,7 +882,7 @@ class ConfigHandler(object):
parameters.setdefault('wal_keep_size', str(int(wal_keep_segments) * 16) + 'MB')
elif self._postgresql.major_version:
wal_keep_size = parse_int(parameters.pop('wal_keep_size', self.CMDLINE_OPTIONS['wal_keep_size'][0]), 'MB')
parameters.setdefault('wal_keep_segments', int((wal_keep_size + 8) / 16))
parameters.setdefault('wal_keep_segments', int(((wal_keep_size or 0) + 8) / 16))

self._postgresql.citus_handler.adjust_postgres_gucs(parameters)
@@ -870,14 +892,14 @@ class ConfigHandler(object):
return ret

@staticmethod
def _get_unix_local_address(unix_socket_directories):
def _get_unix_local_address(unix_socket_directories: str) -> str:
for d in unix_socket_directories.split(','):
d = d.strip()
if d.startswith('/'): # Only absolute path can be used to connect via unix-socket
return d
return ''

def _get_tcp_local_address(self):
def _get_tcp_local_address(self) -> str:
listen_addresses = self._server_parameters['listen_addresses'].split(',')

for la in listen_addresses:
@@ -886,7 +908,7 @@ class ConfigHandler(object):
return listen_addresses[0].strip() # can't use localhost, take first address from listen_addresses

@property
def local_connect_kwargs(self):
def local_connect_kwargs(self) -> Dict[str, Any]:
ret = self._local_address.copy()
# add all of the other connection settings that are available
ret.update(self._superuser)
@@ -902,7 +924,7 @@ class ConfigHandler(object):
'options': '-c statement_timeout=2000'})
return ret

def resolve_connection_addresses(self):
def resolve_connection_addresses(self) -> None:
port = self._server_parameters['port']
tcp_local_address = self._get_tcp_local_address()
netloc = self._config.get('connect_address') or tcp_local_address + ':' + port
@@ -922,19 +944,24 @@ class ConfigHandler(object):
self._postgresql.connection_string = uri('postgres', netloc, self._postgresql.database)
self._postgresql.set_connection_kwargs(self.local_connect_kwargs)

def _get_pg_settings(self, names):
def _get_pg_settings(
self, names: Collection[str]
) -> Dict[str, Tuple[str, str, Optional[str], str, str, Optional[str]]]:
return {r[0]: r for r in self._postgresql.query(('SELECT name, setting, unit, vartype, context, sourcefile'
+ ' FROM pg_catalog.pg_settings '
+ ' WHERE pg_catalog.lower(name) = ANY(%s)'),
[n.lower() for n in names])}

@staticmethod
def _handle_wal_buffers(old_values, changes):
wal_block_size = parse_int(old_values['wal_block_size'][1])
def _handle_wal_buffers(old_values: Dict[str, Tuple[str, str, Optional[str], str, str, Optional[str]]],
changes: CaseInsensitiveDict) -> None:
wal_block_size = parse_int(old_values['wal_block_size'][1]) or 8192
wal_segment_size = old_values['wal_segment_size']
wal_segment_unit = parse_int(wal_segment_size[2], 'B') if wal_segment_size[2][0].isdigit() else 1
wal_segment_size = parse_int(wal_segment_size[1]) * wal_segment_unit / wal_block_size
default_wal_buffers = min(max(parse_int(old_values['shared_buffers'][1]) / 32, 8), wal_segment_size)
wal_segment_unit = parse_int(wal_segment_size[2], 'B') or 8192 \
if wal_segment_size[2] is not None and wal_segment_size[2][0].isdigit() else 1
wal_segment_size = parse_int(wal_segment_size[1]) or (16777216 if wal_segment_size[2] is None else 2048)
wal_segment_size *= wal_segment_unit / wal_block_size
default_wal_buffers = min(max((parse_int(old_values['shared_buffers'][1]) or 16384) / 32, 8), wal_segment_size)

wal_buffers = old_values['wal_buffers']
new_value = str(changes['wal_buffers'] or -1)
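The fallbacks introduced above (`or 8192`, `or 16384`, and so on) only take effect when a setting cannot be parsed; with ordinary values the default wal_buffers calculation works out as in this illustrative arithmetic (all sizes in 8 kB blocks):

    wal_block_size = 8192                          # bytes per WAL block
    wal_segment_size = 16 * 1024 * 1024 // 8192    # 2048 blocks for a 16MB segment
    shared_buffers = 16384                         # 16384 blocks == 128MB
    default_wal_buffers = min(max(shared_buffers / 32, 8), wal_segment_size)
    # == min(max(512, 8), 2048) == 512 blocks, i.e. 4MB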
def reload_config(self, config, sighup=False):
def reload_config(self, config: Dict[str, Any], sighup: bool = False) -> None:
self._superuser = config['authentication'].get('superuser', {})
server_parameters = self.get_server_parameters(config)

@@ -956,6 +983,7 @@ class ConfigHandler(object):
changes.update({p: None for p in self._server_parameters.keys()
if not (p in changes or p.lower() in self._RECOVERY_PARAMETERS)})
if changes:
undef = []
if 'wal_buffers' in changes: # we need to calculate the default value of wal_buffers
undef = [p for p in ('shared_buffers', 'wal_segment_size', 'wal_block_size') if p not in changes]
changes.update({p: None for p in undef})
@@ -1030,16 +1058,17 @@ class ConfigHandler(object):
if self._postgresql.major_version >= 90500:
time.sleep(1)
try:
pending_restart = self._postgresql.query(
'SELECT COUNT(*) FROM pg_catalog.pg_settings WHERE pg_catalog.lower(name) != ALL(%s)'
' AND pending_restart', [n.lower() for n in self._RECOVERY_PARAMETERS]).fetchone()[0] > 0
pending_restart = (self._postgresql.query(
'SELECT COUNT(*) FROM pg_catalog.pg_settings'
' WHERE pg_catalog.lower(name) != ALL(%s) AND pending_restart',
[n.lower() for n in self._RECOVERY_PARAMETERS]).fetchone() or (0,))[0] > 0
self._postgresql.set_pending_restart(pending_restart)
except Exception as e:
logger.warning('Exception %r when running query', e)
else:
logger.info('No PostgreSQL configuration items changed, nothing to reload.')

def set_synchronous_standby_names(self, value):
def set_synchronous_standby_names(self, value: Optional[str]) -> Optional[bool]:
"""Updates synchronous_standby_names and reloads if necessary.
:returns: True if value was updated."""
if value != self._synchronous_standby_names:
@@ -1054,7 +1083,7 @@ class ConfigHandler(object):
return True

@property
def effective_configuration(self):
def effective_configuration(self) -> CaseInsensitiveDict:
"""It might happen that the current value of one (or more) below parameters stored in
the controldata is higher than the value stored in the global cluster configuration.

@@ -1089,7 +1118,7 @@ class ConfigHandler(object):
continue

cvalue = parse_int(data[cname])
if cvalue > value:
if cvalue is not None and value is not None and cvalue > value:
effective_configuration[name] = cvalue
self._postgresql.set_pending_restart(True)

@@ -1113,35 +1142,35 @@ class ConfigHandler(object):
return effective_configuration

@property
def replication(self):
def replication(self) -> Dict[str, Any]:
return self._config['authentication']['replication']

@property
def superuser(self):
def superuser(self) -> Dict[str, Any]:
return self._superuser

@property
def rewind_credentials(self):
def rewind_credentials(self) -> Dict[str, Any]:
return self._config['authentication'].get('rewind', self._superuser) \
if self._postgresql.major_version >= 110000 else self._superuser

@property
def ident_file(self):
def ident_file(self) -> Optional[str]:
ident_file = self._server_parameters.get('ident_file')
return None if ident_file == self._pg_ident_conf else ident_file

@property
def hba_file(self):
def hba_file(self) -> Optional[str]:
hba_file = self._server_parameters.get('hba_file')
return None if hba_file == self._pg_hba_conf else hba_file

@property
def pg_hba_conf(self):
def pg_hba_conf(self) -> str:
return self._pg_hba_conf

@property
def postgresql_conf(self):
def postgresql_conf(self) -> str:
return self._postgresql_conf

def get(self, key, default=None):
def get(self, key: str, default: Optional[Any] = None) -> Optional[Any]:
return self._config.get(key, default)
@@ -2,6 +2,10 @@ import logging
|
||||
|
||||
from contextlib import contextmanager
|
||||
from threading import Lock
|
||||
from typing import Any, Dict, Generator, Union, TYPE_CHECKING
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from psycopg import Connection as Connection3, Cursor
|
||||
from psycopg2 import connection, cursor
|
||||
|
||||
from .. import psycopg
|
||||
|
||||
@@ -9,29 +13,30 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Connection(object):
|
||||
server_version: int
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self._lock = Lock()
|
||||
self._connection = None
|
||||
self._cursor_holder = None
|
||||
|
||||
def set_conn_kwargs(self, conn_kwargs):
|
||||
def set_conn_kwargs(self, conn_kwargs: Dict[str, Any]) -> None:
|
||||
self._conn_kwargs = conn_kwargs
|
||||
|
||||
def get(self):
|
||||
def get(self) -> Union['connection', 'Connection3[Any]']:
|
||||
with self._lock:
|
||||
if not self._connection or self._connection.closed != 0:
|
||||
self._connection = psycopg.connect(**self._conn_kwargs)
|
||||
self.server_version = self._connection.server_version
|
||||
self.server_version = getattr(self._connection, 'server_version', 0)
|
||||
return self._connection
|
||||
|
||||
def cursor(self):
|
||||
def cursor(self) -> Union['cursor', 'Cursor[Any]']:
|
||||
if not self._cursor_holder or self._cursor_holder.closed or self._cursor_holder.connection.closed != 0:
|
||||
logger.info("establishing a new patroni connection to the postgres cluster")
|
||||
self._cursor_holder = self.get().cursor()
|
||||
return self._cursor_holder
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
if self._connection and self._connection.closed == 0:
|
||||
self._connection.close()
|
||||
logger.info("closed patroni connection to the postgresql cluster")
@@ -39,7 +44,7 @@ class Connection(object):


@contextmanager
def get_connection_cursor(**kwargs):
def get_connection_cursor(**kwargs: Any) -> Generator[Union['cursor', 'Cursor[Any]'], None, None]:
conn = psycopg.connect(**kwargs)
with conn.cursor() as cur:
yield cur

@@ -2,7 +2,9 @@ import errno
import logging
import os

from patroni.exceptions import PostgresException
from typing import Iterable, Tuple

from ..exceptions import PostgresException

logger = logging.getLogger(__name__)

@@ -55,29 +57,27 @@ def postgres_major_version_to_int(pg_version: str) -> int:
return postgres_version_to_int(pg_version + '.0')


def parse_lsn(lsn):
def parse_lsn(lsn: str) -> int:
t = lsn.split('/')
return int(t[0], 16) * 0x100000000 + int(t[1], 16)


def parse_history(data):
def parse_history(data: str) -> Iterable[Tuple[int, int, str]]:
for line in data.split('\n'):
values = line.strip().split('\t')
if len(values) == 3:
try:
values[0] = int(values[0])
values[1] = parse_lsn(values[1])
yield values
yield int(values[0]), parse_lsn(values[1]), values[2]
except (IndexError, ValueError):
logger.exception('Exception when parsing timeline history line "%s"', values)


def format_lsn(lsn, full=False):
def format_lsn(lsn: int, full: bool = False) -> str:
template = '{0:X}/{1:08X}' if full else '{0:X}/{1:X}'
return template.format(lsn >> 32, lsn & 0xFFFFFFFF)
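A standalone round trip through the two LSN helpers above (re-declared here so the snippet runs on its own; the sample LSN is arbitrary):

    # Standalone copies of the helpers above, to show the round trip on a sample LSN.
    def parse_lsn(lsn: str) -> int:
        t = lsn.split('/')
        return int(t[0], 16) * 0x100000000 + int(t[1], 16)

    def format_lsn(lsn: int, full: bool = False) -> str:
        template = '{0:X}/{1:08X}' if full else '{0:X}/{1:X}'
        return template.format(lsn >> 32, lsn & 0xFFFFFFFF)

    value = parse_lsn('16/B374D848')                   # 97500059720
    assert format_lsn(value) == '16/B374D848'
    assert format_lsn(value, True) == '16/B374D848'    # already 8 hex digits, so padding changes nothing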
def fsync_dir(path):
def fsync_dir(path: str) -> None:
if os.name != 'nt':
fd = os.open(path, os.O_DIRECTORY)
try:

@@ -7,6 +7,9 @@ import signal
import subprocess
import sys

from multiprocessing.connection import Connection
from typing import Dict, Optional, List

from patroni import PATRONI_ENV_PREFIX, KUBERNETES_ENV_PREFIX

# avoid spawning the resource tracker process
@@ -26,7 +29,7 @@ STOP_SIGNALS = {
}


def pg_ctl_start(conn, cmdline, env):
def pg_ctl_start(conn: Connection, cmdline: List[str], env: Dict[str, str]) -> None:
if os.name != 'nt':
os.setsid()
try:
@@ -40,7 +43,8 @@ def pg_ctl_start(conn, cmdline, env):

class PostmasterProcess(psutil.Process):

def __init__(self, pid):
def __init__(self, pid: int) -> None:
self._postmaster_pid: Dict[str, str]
self.is_single_user = False
if pid < 0:
pid = -pid
@@ -48,7 +52,7 @@ class PostmasterProcess(psutil.Process):
super(PostmasterProcess, self).__init__(pid)

@staticmethod
def _read_postmaster_pidfile(data_dir):
def _read_postmaster_pidfile(data_dir: str) -> Dict[str, str]:
"""Reads and parses postmaster.pid from the data directory

:returns dictionary of values if successful, empty dictionary otherwise
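postmaster.pid is a plain text file with one value per line (pid, data directory, start time, port, socket directory, listen address, shared-memory key, ...). Only the 'pid' and 'start_time' keys are referenced by the surrounding code; the other key names in this sketch are made up for illustration, as is the sample file content:

    # Illustrative parsing of a postmaster.pid file (sample contents invented).
    sample = '10325\n/var/lib/postgresql/data\n1690000000\n5432\n/var/run/postgresql\n*\n  5432001  73728\n'
    names = ['pid', 'data_dir', 'start_time', 'port', 'socket_dir', 'listen_addr', 'shm_key']
    postmaster_pid = dict(zip(names, sample.splitlines()))
    assert int(postmaster_pid['pid']) == 10325
    assert int(postmaster_pid['start_time']) == 1690000000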
@@ -60,7 +64,7 @@ class PostmasterProcess(psutil.Process):
except IOError:
return {}

def _is_postmaster_process(self):
def _is_postmaster_process(self) -> bool:
try:
start_time = int(self._postmaster_pid.get('start_time', 0))
if start_time and abs(self.create_time() - start_time) > 3:
@@ -79,7 +83,7 @@ class PostmasterProcess(psutil.Process):
return True

@classmethod
def _from_pidfile(cls, data_dir):
def _from_pidfile(cls, data_dir: str) -> Optional['PostmasterProcess']:
postmaster_pid = PostmasterProcess._read_postmaster_pidfile(data_dir)
try:
pid = int(postmaster_pid.get('pid', 0))
@@ -88,10 +92,10 @@ class PostmasterProcess(psutil.Process):
proc._postmaster_pid = postmaster_pid
return proc
except ValueError:
pass
return None

@staticmethod
def from_pidfile(data_dir):
def from_pidfile(data_dir: str) -> Optional['PostmasterProcess']:
try:
proc = PostmasterProcess._from_pidfile(data_dir)
return proc if proc and proc._is_postmaster_process() else None
@@ -99,13 +103,13 @@ class PostmasterProcess(psutil.Process):
return None

@classmethod
def from_pid(cls, pid):
def from_pid(cls, pid: int) -> Optional['PostmasterProcess']:
try:
return cls(pid)
except psutil.NoSuchProcess:
return None

def signal_kill(self):
def signal_kill(self) -> bool:
"""to suspend and kill postmaster and all children

:returns True if postmaster and children are killed, False if error
@@ -141,7 +145,7 @@ class PostmasterProcess(psutil.Process):
psutil.wait_procs(children + [self])
return True

def signal_stop(self, mode, pg_ctl='pg_ctl'):
def signal_stop(self, mode: str, pg_ctl: str = 'pg_ctl') -> Optional[bool]:
"""Signal postmaster process to stop

:returns None if signaled, True if process is already gone, False if error
@@ -161,7 +165,7 @@ class PostmasterProcess(psutil.Process):

return None

def pg_ctl_kill(self, mode, pg_ctl):
def pg_ctl_kill(self, mode: str, pg_ctl: str) -> Optional[bool]:
try:
status = subprocess.call([pg_ctl, "kill", STOP_SIGNALS[mode], str(self.pid)])
except OSError:
@@ -171,7 +175,7 @@ class PostmasterProcess(psutil.Process):
else:
return not self.is_running()

def wait_for_user_backends_to_close(self, stop_timeout):
def wait_for_user_backends_to_close(self, stop_timeout: Optional[float]) -> None:
# These regexps are cross checked against versions PostgreSQL 9.1 .. 15
aux_proc_re = re.compile("(?:postgres:)( .*:)? (?:(?:archiver|startup|autovacuum launcher|autovacuum worker|"
"checkpointer|logger|stats collector|wal receiver|wal writer|writer)(?: process )?|"
@@ -183,8 +187,8 @@ class PostmasterProcess(psutil.Process):
except psutil.Error:
return logger.debug('Failed to get list of postmaster children')

user_backends = []
user_backends_cmdlines = {}
user_backends: List[psutil.Process] = []
user_backends_cmdlines: Dict[int, str] = {}
for child in children:
try:
cmdline = child.cmdline()
@@ -195,7 +199,7 @@ class PostmasterProcess(psutil.Process):
pass
if user_backends:
logger.debug('Waiting for user backends %s to close', ', '.join(user_backends_cmdlines.values()))
gone, live = psutil.wait_procs(user_backends, stop_timeout)
_, live = psutil.wait_procs(user_backends, stop_timeout)
if stop_timeout and live:
live = [user_backends_cmdlines[b.pid] for b in live]
logger.warning('Backends still alive after %s: %s', stop_timeout, ', '.join(live))
@@ -203,7 +207,7 @@ class PostmasterProcess(psutil.Process):
logger.debug("Backends closed")

@staticmethod
def start(pgcommand, data_dir, conf, options):
def start(pgcommand: str, data_dir: str, conf: str, options: List[str]) -> Optional['PostmasterProcess']:
# Unfortunately `pg_ctl start` does not return postmaster pid to us. Without this information
# it is hard to know the current state of postgres startup, so we had to reimplement pg_ctl start
# in python. It will start postgres, wait for port to be open and wait until postgres will start
@@ -234,7 +238,7 @@ class PostmasterProcess(psutil.Process):
pass
cmdline = [pgcommand, '-D', data_dir, '--config-file={}'.format(conf)] + options
logger.debug("Starting postgres: %s", " ".join(cmdline))
ctx = multiprocessing.get_context('spawn') if sys.version_info >= (3, 4) else multiprocessing
ctx = multiprocessing.get_context('spawn')
parent_conn, child_conn = ctx.Pipe(False)
proc = ctx.Process(target=pg_ctl_start, args=(child_conn, cmdline, env))
proc.start()

@@ -5,36 +5,46 @@ import shlex
import shutil
import subprocess

from enum import IntEnum
from threading import Lock, Thread
from typing import Any, Callable, Dict, List, Optional, Union, Tuple

from . import Postgresql
from .connection import get_connection_cursor
from .misc import format_lsn, fsync_dir, parse_history, parse_lsn
from ..async_executor import CriticalTask
from ..dcs import Leader
from ..dcs import Leader, RemoteMember

logger = logging.getLogger(__name__)

REWIND_STATUS = type('Enum', (), {'INITIAL': 0, 'CHECKPOINT': 1, 'CHECK': 2, 'NEED': 3,
'NOT_NEED': 4, 'SUCCESS': 5, 'FAILED': 6})

class REWIND_STATUS(IntEnum):
INITIAL = 0
CHECKPOINT = 1
CHECK = 2
NEED = 3
NOT_NEED = 4
SUCCESS = 5
FAILED = 6
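Moving REWIND_STATUS from an ad-hoc type() namespace to an IntEnum keeps the integer comparisons used by this class (for example the executed property further down tests self._state > REWIND_STATUS.NOT_NEED) while giving the type checker real members. A small standalone demonstration of why IntEnum is a drop-in here:

    from enum import IntEnum

    class DemoStatus(IntEnum):
        NOT_NEED = 4
        SUCCESS = 5
        FAILED = 6

    # IntEnum members behave like plain ints, so ordering and equality keep working:
    assert DemoStatus.SUCCESS > DemoStatus.NOT_NEED
    assert DemoStatus.FAILED == 6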
class Rewind(object):

def __init__(self, postgresql):
def __init__(self, postgresql: Postgresql) -> None:
self._postgresql = postgresql
self._checkpoint_task_lock = Lock()
self.reset_state()

@staticmethod
def configuration_allows_rewind(data):
def configuration_allows_rewind(data: Dict[str, str]) -> bool:
return data.get('wal_log_hints setting', 'off') == 'on' or data.get('Data page checksum version', '0') != '0'

@property
def enabled(self):
return self._postgresql.config.get('use_pg_rewind')
def enabled(self) -> bool:
return bool(self._postgresql.config.get('use_pg_rewind'))

@property
def can_rewind(self):
def can_rewind(self) -> bool:
""" check if pg_rewind executable is there and that pg_controldata indicates
we have either wal_log_hints or checksums turned on
"""
@@ -52,43 +62,45 @@ class Rewind(object):
return self.configuration_allows_rewind(self._postgresql.controldata())

@property
def should_remove_data_directory_on_diverged_timelines(self):
return self._postgresql.config.get('remove_data_directory_on_diverged_timelines')
def should_remove_data_directory_on_diverged_timelines(self) -> bool:
return bool(self._postgresql.config.get('remove_data_directory_on_diverged_timelines'))

@property
def can_rewind_or_reinitialize_allowed(self):
def can_rewind_or_reinitialize_allowed(self) -> bool:
return self.should_remove_data_directory_on_diverged_timelines or self.can_rewind

def trigger_check_diverged_lsn(self):
def trigger_check_diverged_lsn(self) -> None:
if self.can_rewind_or_reinitialize_allowed and self._state != REWIND_STATUS.NEED:
self._state = REWIND_STATUS.CHECK

@staticmethod
def check_leader_is_not_in_recovery(conn_kwargs):
def check_leader_is_not_in_recovery(conn_kwargs: Dict[str, Any]) -> Optional[bool]:
try:
with get_connection_cursor(connect_timeout=3, options='-c statement_timeout=2000', **conn_kwargs) as cur:
cur.execute('SELECT pg_catalog.pg_is_in_recovery()')
if not cur.fetchone()[0]:
row = cur.fetchone()
if not row or not row[0]:
return True
logger.info('Leader is still in_recovery and therefore can\'t be used for rewind')
except Exception:
return logger.exception('Exception when working with leader')

@staticmethod
def check_leader_has_run_checkpoint(conn_kwargs):
def check_leader_has_run_checkpoint(conn_kwargs: Dict[str, Any]) -> Optional[str]:
try:
with get_connection_cursor(connect_timeout=3, options='-c statement_timeout=2000', **conn_kwargs) as cur:
cur.execute("SELECT NOT pg_catalog.pg_is_in_recovery()"
" AND ('x' || pg_catalog.substr(pg_catalog.pg_walfile_name("
" pg_catalog.pg_current_wal_lsn()), 1, 8))::bit(32)::int = timeline_id"
" FROM pg_catalog.pg_control_checkpoint()")
if not cur.fetchone()[0]:
row = cur.fetchone()
if not row or not row[0]:
return 'leader has not run a checkpoint yet'
except Exception:
logger.exception('Exception when working with leader')
return 'not accessible or not healthy'
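The query above checks whether the first 8 hex digits of the current WAL file name, which encode the timeline, already match the timeline recorded in pg_control, i.e. whether the leader has written a checkpoint on its own timeline. The same conversion in plain Python, on a made-up segment name:

    # The first 8 hex characters of a WAL segment file name encode the timeline id.
    # The file name below is invented; the conversion mirrors the ::bit(32)::int cast above.
    wal_file_name = '000000030000000A000000FE'
    assert int(wal_file_name[:8], 16) == 3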

def _get_checkpoint_end(self, timeline, lsn):
def _get_checkpoint_end(self, timeline: int, lsn: int) -> int:
"""The checkpoint record size in WAL depends on postgres major version and platform (memory alignment).
Hence, the only reliable way to figure out where it ends is to read the record from the file with the help of pg_waldump
and parse the output. We are trying to read two records, and expect that it will fail to read the second one:
@@ -96,12 +108,12 @@ class Rewind(object):
The error message contains information about LSN of the next record, which is exactly where checkpoint ends."""

lsn8 = format_lsn(lsn, True)
lsn = format_lsn(lsn)
out, err = self._postgresql.waldump(timeline, lsn, 2)
lsn_str = format_lsn(lsn)
out, err = self._postgresql.waldump(timeline, lsn_str, 2)
if out is not None and err is not None:
out = out.decode('utf-8').rstrip().split('\n')
err = err.decode('utf-8').rstrip().split('\n')
pattern = 'error in WAL record at {0}: invalid record length at '.format(lsn)
pattern = 'error in WAL record at {0}: invalid record length at '.format(lsn_str)

if len(out) == 1 and len(err) == 1 and ', lsn: {0}, prev '.format(lsn8) in out[0] and pattern in err[0]:
i = err[0].find(pattern) + len(pattern)
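To make the docstring concrete: the second read is expected to fail, and the failure message names the LSN right after the checkpoint record. A sketch of pulling that LSN out of such a message (the sample line and LSNs are invented; only the quoted pattern is what the code above relies on):

    # Illustrative parsing of a pg_waldump error line, mirroring the pattern match above.
    def parse_lsn(lsn: str) -> int:
        t = lsn.split('/')
        return int(t[0], 16) * 0x100000000 + int(t[1], 16)

    lsn_str = '0/2000060'
    err_line = 'pg_waldump: error in WAL record at 0/2000060: invalid record length at 0/20000D8: wanted 24, got 0'
    pattern = 'error in WAL record at {0}: invalid record length at '.format(lsn_str)
    i = err_line.find(pattern) + len(pattern)
    next_lsn = err_line[i:].split(':')[0]
    print(parse_lsn(next_lsn))   # byte position where the checkpoint record ends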
@@ -117,20 +129,20 @@ class Rewind(object):

return 0

def _get_local_timeline_lsn_from_controldata(self):
def _get_local_timeline_lsn_from_controldata(self) -> Tuple[Optional[bool], Optional[int], Optional[int]]:
in_recovery = timeline = lsn = None
data = self._postgresql.controldata()
try:
if data.get('Database cluster state') in ('shut down in recovery', 'in archive recovery'):
in_recovery = True
lsn = data.get('Minimum recovery ending location')
timeline = int(data.get("Min recovery ending loc's timeline"))
timeline = int(data.get("Min recovery ending loc's timeline", ""))
if lsn == '0/0' or timeline == 0:  # it was a primary when it crashed
data['Database cluster state'] = 'shut down'
if data.get('Database cluster state') == 'shut down':
in_recovery = False
lsn = data.get('Latest checkpoint location')
timeline = int(data.get("Latest checkpoint's TimeLineID"))
timeline = int(data.get("Latest checkpoint's TimeLineID", ""))
except (TypeError, ValueError):
logger.exception('Failed to get local timeline and lsn from pg_controldata output')

@@ -143,7 +155,7 @@ class Rewind(object):

return in_recovery, timeline, lsn

def _get_local_timeline_lsn(self):
def _get_local_timeline_lsn(self) -> Tuple[Optional[bool], Optional[int], Optional[int]]:
if self._postgresql.is_running():  # if postgres is running - get timeline from replication connection
in_recovery = True
timeline = self._postgresql.received_timeline() or self._postgresql.get_replica_timeline()
@@ -156,14 +168,15 @@ class Rewind(object):
return in_recovery, timeline, lsn

@staticmethod
def _log_primary_history(history, i):
def _log_primary_history(history: List[Tuple[int, int, str]], i: int) -> None:
start = max(0, i - 3)
end = None if i + 4 >= len(history) else i + 2
history_show = []
history_show: List[str] = []

def format_history_line(line):
def format_history_line(line: Tuple[int, int, str]) -> str:
return '{0}\t{1}\t{2}'.format(line[0], format_lsn(line[1]), line[2])

line = None
for line in history[start:end]:
history_show.append(format_history_line(line))

@@ -173,7 +186,7 @@ class Rewind(object):

logger.info('primary: history=%s', '\n'.join(history_show))

def _conn_kwargs(self, member, auth):
def _conn_kwargs(self, member: Union[Leader, RemoteMember], auth: Dict[str, Any]) -> Dict[str, Any]:
ret = member.conn_kwargs(auth)
if not ret.get('dbname'):
ret['dbname'] = self._postgresql.database
@@ -183,7 +196,7 @@ class Rewind(object):
ret['target_session_attrs'] = 'read-write'
return ret

def _check_timeline_and_lsn(self, leader):
def _check_timeline_and_lsn(self, leader: Union[Leader, RemoteMember]) -> None:
in_recovery, local_timeline, local_lsn = self._get_local_timeline_lsn()
if local_timeline is None or local_lsn is None:
return
@@ -205,23 +218,28 @@ class Rewind(object):
try:
with self._postgresql.get_replication_connection_cursor(**leader.conn_kwargs()) as cur:
cur.execute('IDENTIFY_SYSTEM')
primary_timeline = cur.fetchone()[1]
logger.info('primary_timeline=%s', primary_timeline)
if local_timeline > primary_timeline:  # Not always supported by pg_rewind
need_rewind = True
elif local_timeline == primary_timeline:
need_rewind = False
elif primary_timeline > 1:
cur.execute('TIMELINE_HISTORY {0}'.format(primary_timeline))
history = cur.fetchone()[1]
if not isinstance(history, str):
history = bytes(history).decode('utf-8')
logger.debug('primary: history=%s', history)
row = cur.fetchone()
if row:
primary_timeline = row[1]
logger.info('primary_timeline=%s', primary_timeline)
if local_timeline > primary_timeline:  # Not always supported by pg_rewind
need_rewind = True
elif local_timeline == primary_timeline:
need_rewind = False
elif primary_timeline > 1:
cur.execute('TIMELINE_HISTORY {0}'.format(primary_timeline).encode('utf-8'))
row = cur.fetchone()
if row:
history = row[1]
if not isinstance(history, str):
history = bytes(history).decode('utf-8')
logger.debug('primary: history=%s', history)
except Exception:
return logger.exception('Exception when working with primary via replication connection')

if history is not None:
history = list(parse_history(history))
i = len(history)
for i, (parent_timeline, switchpoint, _) in enumerate(history):
if parent_timeline == local_timeline:
# We don't need to rewind when:
@@ -243,12 +261,12 @@ class Rewind(object):

self._state = need_rewind and REWIND_STATUS.NEED or REWIND_STATUS.NOT_NEED

def rewind_or_reinitialize_needed_and_possible(self, leader):
def rewind_or_reinitialize_needed_and_possible(self, leader: Union[Leader, RemoteMember, None]) -> bool:
if leader and leader.name != self._postgresql.name and leader.conn_url and self._state == REWIND_STATUS.CHECK:
self._check_timeline_and_lsn(leader)
return leader and leader.conn_url and self._state == REWIND_STATUS.NEED
return bool(leader and leader.conn_url) and self._state == REWIND_STATUS.NEED

def __checkpoint(self, task, wakeup):
def __checkpoint(self, task: CriticalTask, wakeup: Callable[..., Any]) -> None:
try:
result = self._postgresql.checkpoint()
except Exception as e:
@@ -258,7 +276,7 @@ class Rewind(object):
if task.result:
wakeup()

def ensure_checkpoint_after_promote(self, wakeup):
def ensure_checkpoint_after_promote(self, wakeup: Callable[..., Any]) -> None:
"""After promote issue a CHECKPOINT from a new thread and asynchronously check the result.
In case if CHECKPOINT failed, just check that timeline in pg_control was updated."""

@@ -275,10 +293,10 @@ class Rewind(object):
self._checkpoint_task = CriticalTask()
Thread(target=self.__checkpoint, args=(self._checkpoint_task, wakeup)).start()

def checkpoint_after_promote(self):
def checkpoint_after_promote(self) -> bool:
return self._state == REWIND_STATUS.CHECKPOINT

def _buid_archiver_command(self, command, wal_filename):
def _buid_archiver_command(self, command: str, wal_filename: str) -> str:
"""Replace placeholders in the given archiver command's template.
Applicable for archive_command and restore_command.
Can also be used for archive_cleanup_command and recovery_end_command,
@@ -306,13 +324,13 @@ class Rewind(object):

return cmd

def _fetch_missing_wal(self, restore_command, wal_filename):
def _fetch_missing_wal(self, restore_command: str, wal_filename: str) -> bool:
cmd = self._buid_archiver_command(restore_command, wal_filename)

logger.info('Trying to fetch the missing wal: %s', cmd)
return self._postgresql.cancellable.call(shlex.split(cmd)) == 0

def _find_missing_wal(self, data):
def _find_missing_wal(self, data: bytes) -> Optional[str]:
# could not open file "$PGDATA/pg_wal/0000000A00006AA100000068": No such file or directory
pattern = 'could not open file "'
for line in data.decode('utf-8').split('\n'):
@@ -325,7 +343,7 @@ class Rewind(object):
if waldir.endswith('/pg_' + self._postgresql.wal_name) and len(wal_filename) == 24:
return wal_filename

def _archive_ready_wals(self):
def _archive_ready_wals(self) -> None:
"""Try to archive WALs that have .ready files just in case
archive_mode was not set to 'always' before promote, while
after it the WALs were recycled on the promoted replica.
@@ -361,7 +379,7 @@ class Rewind(object):
else:
logger.info('Failed to archive WAL segment %s', wal)

def _maybe_clean_pg_replslot(self):
def _maybe_clean_pg_replslot(self) -> None:
"""Clean pg_replslot directory if pg version is less then 11
(pg_rewind deletes $PGDATA/pg_replslot content only since pg11)."""
if self._postgresql.major_version < 110000:
@@ -373,33 +391,33 @@ class Rewind(object):
except Exception as e:
logger.warning('Unable to clean %s: %r', replslot_dir, e)

def pg_rewind(self, r):
def pg_rewind(self, r: Dict[str, Any]) -> bool:
# prepare pg_rewind connection
env = self._postgresql.config.write_pgpass(r)
env.update(LANG='C', LC_ALL='C', PGOPTIONS='-c statement_timeout=0')
dsn = self._postgresql.config.format_dsn(r, True)
logger.info('running pg_rewind from %s', dsn)

restore_command = self._postgresql.config.get('recovery_conf', {}).get('restore_command') \
restore_command = (self._postgresql.config.get('recovery_conf') or {}).get('restore_command') \
if self._postgresql.major_version < 120000 else self._postgresql.get_guc_value('restore_command')

# Until v15 pg_rewind expected postgresql.conf to be inside $PGDATA, which is not the case on e.g. Debian
pg_rewind_can_restore = restore_command and (self._postgresql.major_version >= 150000
or (self._postgresql.major_version >= 130000
and self._postgresql.config._config_dir
and self._postgresql.config.config_dir
== self._postgresql.data_dir))

cmd = [self._postgresql.pgcommand('pg_rewind')]
if pg_rewind_can_restore:
cmd.append('--restore-target-wal')
if self._postgresql.major_version >= 150000 and\
self._postgresql.config._config_dir != self._postgresql.data_dir:
self._postgresql.config.config_dir != self._postgresql.data_dir:
cmd.append('--config-file={0}'.format(self._postgresql.config.postgresql_conf))

cmd.extend(['-D', self._postgresql.data_dir, '--source-server', dsn])

while True:
results = {}
results: Dict[str, bytes] = {}
ret = self._postgresql.cancellable.call(cmd, env=env, communicate=results)

logger.info('pg_rewind exit code=%s', ret)
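For orientation, on a v15 cluster whose configuration directory differs from the data directory the loop above ends up invoking a command of roughly this shape (every path and the DSN below are placeholders invented for the sketch, not values Patroni computes here):

    # Illustrative shape of the pg_rewind invocation assembled above (all values are placeholders).
    cmd = [
        '/usr/lib/postgresql/15/bin/pg_rewind',        # self._postgresql.pgcommand('pg_rewind')
        '--restore-target-wal',                        # only when restore_command can be used
        '--config-file=/etc/patroni/postgresql.conf',  # v15+ and config dir != data dir
        '-D', '/var/lib/postgresql/data',
        '--source-server', 'host=10.0.0.1 port=5432 user=rewind_user dbname=postgres',
    ]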
@@ -422,7 +440,7 @@ class Rewind(object):
logger.info('Failed to fetch WAL segment %s required for pg_rewind', missing_wal)
return False

def execute(self, leader):
def execute(self, leader: Union[Leader, RemoteMember]) -> Optional[bool]:
if self._postgresql.is_running() and not self._postgresql.stop(checkpoint=False):
return logger.warning('Can not run pg_rewind because postgres is still running')

@@ -471,26 +489,26 @@ class Rewind(object):
break
return False

def reset_state(self):
def reset_state(self) -> None:
self._state = REWIND_STATUS.INITIAL
with self._checkpoint_task_lock:
self._checkpoint_task = None

@property
def is_needed(self):
def is_needed(self) -> bool:
return self._state in (REWIND_STATUS.CHECK, REWIND_STATUS.NEED)

@property
def executed(self):
def executed(self) -> bool:
return self._state > REWIND_STATUS.NOT_NEED

@property
def failed(self):
def failed(self) -> bool:
return self._state == REWIND_STATUS.FAILED

def read_postmaster_opts(self):
def read_postmaster_opts(self) -> Dict[str, str]:
"""returns the list of option names/values from postgres.opts, Empty dict if read failed or no file"""
result = {}
result: Dict[str, str] = {}
try:
with open(os.path.join(self._postgresql.data_dir, 'postmaster.opts')) as f:
data = f.read()
@@ -502,7 +520,8 @@ class Rewind(object):
logger.exception('Error when reading postmaster.opts')
return result

def single_user_mode(self, communicate=None, options=None):
def single_user_mode(self, communicate: Optional[Dict[str, Any]] = None,
options: Optional[Dict[str, str]] = None) -> Optional[int]:
"""run a given command in a single-user mode. If the command is empty - then just start and stop"""
cmd = [self._postgresql.pgcommand('postgres'), '--single', '-D', self._postgresql.data_dir]
for opt, val in sorted((options or {}).items()):
@@ -511,7 +530,7 @@ class Rewind(object):
cmd.append('template1')
return self._postgresql.cancellable.call(cmd, communicate=communicate)

def cleanup_archive_status(self):
def cleanup_archive_status(self) -> None:
status_dir = os.path.join(self._postgresql.wal_dir, 'archive_status')
try:
for f in os.listdir(status_dir):
@@ -526,7 +545,7 @@ class Rewind(object):
except OSError:
logger.exception('Unable to list %s', status_dir)

def ensure_clean_shutdown(self):
def ensure_clean_shutdown(self) -> Optional[bool]:
self._archive_ready_wals()
self.cleanup_archive_status()

@@ -534,7 +553,7 @@ class Rewind(object):
opts = self.read_postmaster_opts()
opts.update({'archive_mode': 'on', 'archive_command': 'false'})
self._postgresql.config.remove_recovery_conf()
output = {}
output: Dict[str, bytes] = {}
ret = self.single_user_mode(communicate=output, options=opts)
if ret != 0:
logger.error('Crash recovery finished with code=%s', ret)

@@ -5,36 +5,43 @@ import shutil
from collections import defaultdict
from contextlib import contextmanager
from threading import Condition, Thread
from typing import Any, Dict, Generator, List, Optional, Union, Tuple, TYPE_CHECKING

from .connection import get_connection_cursor
from .misc import format_lsn, fsync_dir
from ..dcs import Cluster, Leader
from ..psycopg import OperationalError

if TYPE_CHECKING:  # pragma: no cover
from psycopg import Cursor
from psycopg2 import cursor
from . import Postgresql

logger = logging.getLogger(__name__)


def compare_slots(s1, s2, dbid='database'):
def compare_slots(s1: Dict[str, Any], s2: Dict[str, Any], dbid: str = 'database') -> bool:
return s1['type'] == s2['type'] and (s1['type'] == 'physical'
or s1.get(dbid) == s2.get(dbid) and s1['plugin'] == s2['plugin'])
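compare_slots treats two physical slots as equal by type alone, while logical slots must also agree on database (or datoid) and plugin. A quick standalone illustration with made-up slot definitions shaped like the dictionaries this module passes around:

    def compare_slots(s1, s2, dbid='database'):
        # same logic as above: physical slots only need a matching type,
        # logical slots must also match database/datoid and plugin
        return s1['type'] == s2['type'] and (s1['type'] == 'physical'
                                             or s1.get(dbid) == s2.get(dbid) and s1['plugin'] == s2['plugin'])

    a = {'type': 'logical', 'database': 'appdb', 'plugin': 'pgoutput'}
    b = {'type': 'logical', 'database': 'appdb', 'plugin': 'wal2json'}
    assert compare_slots(a, dict(a))                                  # identical definition
    assert not compare_slots(a, b)                                    # same database, different plugin
    assert compare_slots({'type': 'physical'}, {'type': 'physical'})  # physical slots match by type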
class SlotsAdvanceThread(Thread):

def __init__(self, slots_handler):
def __init__(self, slots_handler: 'SlotsHandler') -> None:
super(SlotsAdvanceThread, self).__init__()
self.daemon = True
self._slots_handler = slots_handler

# _copy_slots and _failed are used to asynchronously give some feedback to the main thread
self._copy_slots = []
self._copy_slots: List[str] = []
self._failed = False

self._scheduled = defaultdict(dict)  # {'dbname1': {'slot1': 100, 'slot2': 100}, 'dbname2': {'slot3': 100}}
# {'dbname1': {'slot1': 100, 'slot2': 100}, 'dbname2': {'slot3': 100}}
self._scheduled: Dict[str, Dict[str, int]] = defaultdict(dict)
self._condition = Condition()  # protect self._scheduled from concurrent access and to wakeup the run() method

self.start()
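The inline comment above shows the shape _scheduled takes; schedule() later in this class merges payloads of exactly that shape, keyed by database name and carrying target LSNs as integers. An invented example payload:

    # Invented payload of the shape accepted by SlotsAdvanceThread.schedule()
    # and merged into self._scheduled: {database: {slot_name: target_lsn}}.
    advance_slots = {
        'dbname1': {'slot1': 100, 'slot2': 100},
        'dbname2': {'slot3': 100},
    }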

def sync_slot(self, cur, database, slot, lsn):
def sync_slot(self, cur: Union['cursor', 'Cursor[Any]'], database: str, slot: str, lsn: int) -> None:
failed = copy = False
try:
cur.execute("SELECT pg_catalog.pg_replication_slot_advance(%s, %s)", (slot, format_lsn(lsn)))
@@ -55,7 +62,7 @@ class SlotsAdvanceThread(Thread):
if not self._scheduled[database]:
self._scheduled.pop(database)

def sync_slots_in_database(self, database, slots):
def sync_slots_in_database(self, database: str, slots: List[str]) -> None:
with self._slots_handler.get_local_connection_cursor(dbname=database, options='-c statement_timeout=0') as cur:
for slot in slots:
with self._condition:
@@ -63,7 +70,7 @@ class SlotsAdvanceThread(Thread):
if lsn:
self.sync_slot(cur, database, slot, lsn)

def sync_slots(self):
def sync_slots(self) -> None:
with self._condition:
databases = list(self._scheduled.keys())
for database in databases:
@@ -75,7 +82,7 @@ class SlotsAdvanceThread(Thread):
except Exception as e:
logger.error('Failed to advance replication slots in database %s: %r', database, e)

def run(self):
def run(self) -> None:
while True:
with self._condition:
if not self._scheduled:
@@ -83,7 +90,7 @@ class SlotsAdvanceThread(Thread):

self.sync_slots()

def schedule(self, advance_slots):
def schedule(self, advance_slots: Dict[str, Dict[str, int]]) -> Tuple[bool, List[str]]:
with self._condition:
for database, values in advance_slots.items():
self._scheduled[database].update(values)
@@ -94,7 +101,7 @@ class SlotsAdvanceThread(Thread):

return ret

def on_promote(self):
def on_promote(self) -> None:
with self._condition:
self._scheduled.clear()
self._failed = False
@@ -103,22 +110,22 @@ class SlotsAdvanceThread(Thread):

class SlotsHandler(object):

def __init__(self, postgresql):
def __init__(self, postgresql: 'Postgresql') -> None:
self._postgresql = postgresql
self._advance = None
self._replication_slots = {}  # already existing replication slots
self._unready_logical_slots = {}
self._replication_slots: Dict[str, Dict[str, Any]] = {}  # already existing replication slots
self._unready_logical_slots: Dict[str, Optional[int]] = {}
self.pg_replslot_dir = os.path.join(self._postgresql.data_dir, 'pg_replslot')
self.schedule()

def _query(self, sql, *params):
def _query(self, sql: str, *params: Any) -> Union['cursor', 'Cursor[Any]']:
return self._postgresql.query(sql, *params, retry=False)

@staticmethod
def _copy_items(src, dst, keys=None):
def _copy_items(src: Dict[str, Any], dst: Dict[str, Any], keys: Optional[List[str]] = None) -> None:
dst.update({key: src[key] for key in keys or ('datoid', 'catalog_xmin', 'confirmed_flush_lsn')})

def process_permanent_slots(self, slots):
def process_permanent_slots(self, slots: List[Dict[str, Any]]) -> Dict[str, int]:
"""This methods solves three problems at once (I know, it is weird).

The cluster_info_query from `Postgresql` is executed every HA loop and returns
@@ -130,11 +137,11 @@ class SlotsHandler(object):
3. Updates the local cache with the fresh catalog_xmin and confirmed_flush_lsn for every known slot.
This info is used when performing the check of logical slot readiness on standbys.
"""
ret = {}
ret: Dict[str, int] = {}

slots = {slot['slot_name']: slot for slot in slots or []}
if slots:
for name, value in slots.items():
slots_dict: Dict[str, Dict[str, Any]] = {slot['slot_name']: slot for slot in slots or []}
if slots_dict:
for name, value in slots_dict.items():
if name in self._replication_slots:
if compare_slots(value, self._replication_slots[name], 'datoid'):
if value['type'] == 'logical':
@@ -143,15 +150,15 @@ class SlotsHandler(object):
else:
self._schedule_load_slots = True

# It could happen that the slots was deleted in the background, we want to detect this case
if any(name not in slots for name in self._replication_slots.keys()):
# It could happen that the slot was deleted in the background, we want to detect this case
if any(name not in slots_dict for name in self._replication_slots.keys()):
self._schedule_load_slots = True

return ret

def load_replication_slots(self):
def load_replication_slots(self) -> None:
if self._postgresql.major_version >= 90400 and self._schedule_load_slots:
replication_slots = {}
replication_slots: Dict[str, Dict[str, Any]] = {}
extra = ", catalog_xmin, pg_catalog.pg_wal_lsn_diff(confirmed_flush_lsn, '0/0')::bigint"\
if self._postgresql.major_version >= 100000 else ""
skip_temp_slots = ' WHERE NOT temporary' if self._postgresql.major_version >= 100000 else ''
@@ -170,15 +177,16 @@ class SlotsHandler(object):
self._unready_logical_slots = {n: None for n, v in replication_slots.items() if v['type'] == 'logical'}
self._force_readiness_check = False

def ignore_replication_slot(self, cluster, name):
def ignore_replication_slot(self, cluster: Cluster, name: str) -> bool:
slot = self._replication_slots[name]
for matcher in cluster.config.ignore_slots_matchers:
if ((matcher.get("name") is None or matcher["name"] == name)
and all(not matcher.get(a) or matcher[a] == slot.get(a) for a in ('database', 'plugin', 'type'))):
return True
if cluster.config:
for matcher in cluster.config.ignore_slots_matchers:
if ((matcher.get("name") is None or matcher["name"] == name)
and all(not matcher.get(a) or matcher[a] == slot.get(a) for a in ('database', 'plugin', 'type'))):
return True
return self._postgresql.citus_handler.ignore_replication_slot(slot)

def drop_replication_slot(self, name):
def drop_replication_slot(self, name: str) -> Tuple[bool, bool]:
"""Returns a tuple(active, dropped)"""
cursor = self._query(('WITH slots AS (SELECT slot_name, active'
' FROM pg_catalog.pg_replication_slots WHERE slot_name = %s),'
@@ -186,9 +194,12 @@ class SlotsHandler(object):
' true AS dropped FROM slots WHERE not active) '
'SELECT active, COALESCE(dropped, false) FROM slots'
' FULL OUTER JOIN dropped ON true'), name)
return cursor.fetchone() if cursor.rowcount == 1 else (False, False)
row = cursor.fetchone()
if not row:
row = (False, False)
return row

def _drop_incorrect_slots(self, cluster, slots, paused):
def _drop_incorrect_slots(self, cluster: Cluster, slots: Dict[str, Any], paused: bool) -> None:
# drop old replication slots which are not presented in desired slots
for name in set(self._replication_slots) - set(slots):
if not paused and not self.ignore_replication_slot(cluster, name):
@@ -211,7 +222,7 @@ class SlotsHandler(object):
logger.error("Failed to drop replication slot '%s'", name)
self._schedule_load_slots = True

def _ensure_physical_slots(self, slots):
def _ensure_physical_slots(self, slots: Dict[str, Any]) -> None:
immediately_reserve = ', true' if self._postgresql.major_version >= 90600 else ''
for name, value in slots.items():
if name not in self._replication_slots and value['type'] == 'physical':
@@ -225,15 +236,15 @@ class SlotsHandler(object):
self._schedule_load_slots = True

@contextmanager
def get_local_connection_cursor(self, **kwargs):
def get_local_connection_cursor(self, **kwargs: Any) -> Generator[Union['cursor', 'Cursor[Any]'], None, None]:
conn_kwargs = self._postgresql.config.local_connect_kwargs
conn_kwargs.update(kwargs)
with get_connection_cursor(**conn_kwargs) as cur:
yield cur

def _ensure_logical_slots_primary(self, slots):
def _ensure_logical_slots_primary(self, slots: Dict[str, Any]) -> None:
# Group logical slots to be created by database name
logical_slots = defaultdict(dict)
logical_slots: Dict[str, Dict[str, Dict[str, Any]]] = defaultdict(dict)
for name, value in slots.items():
if value['type'] == 'logical':
# If the logical already exists, copy some information about it into the original structure
@@ -257,26 +268,27 @@ class SlotsHandler(object):
slots.pop(name)
self._schedule_load_slots = True

def schedule_advance_slots(self, slots):
def schedule_advance_slots(self, slots: Dict[str, Dict[str, int]]) -> Tuple[bool, List[str]]:
if not self._advance:
self._advance = SlotsAdvanceThread(self)
return self._advance.schedule(slots)

def _ensure_logical_slots_replica(self, cluster, slots):
advance_slots = defaultdict(dict)  # Group logical slots to be advanced by database name
create_slots = []  # And collect logical slots to be created on the replica
def _ensure_logical_slots_replica(self, cluster: Cluster, slots: Dict[str, Any]) -> List[str]:
# Group logical slots to be advanced by database name
advance_slots: Dict[str, Dict[str, int]] = defaultdict(dict)
create_slots: List[str] = []  # And collect logical slots to be created on the replica
for name, value in slots.items():
if value['type'] == 'logical':
# If the logical already exists, copy some information about it into the original structure
if self._replication_slots.get(name, {}).get('datoid'):
self._copy_items(self._replication_slots[name], value)
if name in cluster.slots:
if cluster.slots and name in cluster.slots:
try:  # Skip slots that doesn't need to be advanced
if value['confirmed_flush_lsn'] < int(cluster.slots[name]):
advance_slots[value['database']][name] = int(cluster.slots[name])
except Exception as e:
logger.error('Failed to parse "%s": %r', cluster.slots[name], e)
elif name in cluster.slots:  # We want to copy only slots with feedback in a DCS
elif cluster.slots and name in cluster.slots:  # We want to copy only slots with feedback in a DCS
create_slots.append(name)

error, copy_slots = self.schedule_advance_slots(advance_slots)
@@ -284,8 +296,9 @@ class SlotsHandler(object):
self._schedule_load_slots = True
return create_slots + copy_slots

def sync_replication_slots(self, cluster, nofailover, replicatefrom=None, paused=False):
ret = None
def sync_replication_slots(self, cluster: Cluster, nofailover: bool,
replicatefrom: Optional[str] = None, paused: bool = False) -> List[str]:
ret = []
if self._postgresql.major_version >= 90400 and cluster.config:
try:
self.load_replication_slots()
@@ -312,14 +325,15 @@ class SlotsHandler(object):
return ret

@contextmanager
def _get_leader_connection_cursor(self, leader):
def _get_leader_connection_cursor(self, leader: Leader) -> Generator[Union['cursor', 'Cursor[Any]'], None, None]:
conn_kwargs = leader.conn_kwargs(self._postgresql.config.rewind_credentials)
conn_kwargs['dbname'] = self._postgresql.database
with get_connection_cursor(connect_timeout=3, options="-c statement_timeout=2000", **conn_kwargs) as cur:
yield cur

def check_logical_slots_readiness(self, cluster, nofailover, replicatefrom):
if self._unready_logical_slots:
def check_logical_slots_readiness(self, cluster: Cluster, nofailover: bool, replicatefrom: Optional[str]) -> None:
catalog_xmin = None
if self._unready_logical_slots and cluster.leader:
slot_name = cluster.get_my_slot_name_on_primary(self._postgresql.name, replicatefrom)
try:
with self._get_leader_connection_cursor(cluster.leader) as cur:
@@ -335,13 +349,14 @@ class SlotsHandler(object):
# Remember catalog_xmin of logical slots on the primary when catalog_xmin of
# the physical slot became valid. Logical slots on replica will be safe to use after
# promote when catalog_xmin of the physical slot overtakes these values.
if catalog_xmin:
if catalog_xmin is not None:
for name, value in slots.items():
self._unready_logical_slots[name] = value
else:  # Replica isn't streaming or the hot_standby_feedback isn't enabled
try:
cur = self._query("SELECT pg_catalog.current_setting('hot_standby_feedback')::boolean")
if not cur.fetchone()[0]:
row = cur.fetchone()
if row and not row[0]:
logger.error('Logical slot failover requires "hot_standby_feedback".'
' Please check postgresql.auto.conf')
except Exception as e:
@@ -354,14 +369,18 @@ class SlotsHandler(object):
# 1. has a nonzero/non-null catalog_xmin
# 2. has a catalog_xmin that is not newer (greater) than the catalog_xmin of any slot on the standby
# 3. overtook the catalog_xmin of remembered values of logical slots on the primary.
if not value or self._unready_logical_slots[name] <= catalog_xmin <= value['catalog_xmin']:
if not value or catalog_xmin is not None and\
self._unready_logical_slots[name] <= catalog_xmin <= value['catalog_xmin']:
del self._unready_logical_slots[name]
if value:
logger.info('Logical slot %s is safe to be used after a failover', name)

def copy_logical_slots(self, cluster, create_slots):
def copy_logical_slots(self, cluster: Cluster, create_slots: List[str]) -> None:
leader = cluster.leader
if not leader:
return
slots = cluster.get_replication_slots(self._postgresql.name, 'replica', False, self._postgresql.major_version)
copy_slots: Dict[str, Dict[str, Any]] = {}
with self._get_leader_connection_cursor(leader) as cur:
try:
cur.execute("SELECT slot_name, slot_type, datname, plugin, catalog_xmin, "
@@ -370,21 +389,20 @@ class SlotsHandler(object):
" FROM pg_catalog.pg_get_replication_slots() JOIN pg_catalog.pg_database ON datoid = oid"
" WHERE NOT pg_catalog.pg_is_in_recovery() AND slot_name = ANY(%s)", (create_slots,))

create_slots = {}
for r in cur:
if r[0] in slots:  # slot_name is defined in the global configuration
slot = {'type': r[1], 'database': r[2], 'plugin': r[3],
'catalog_xmin': r[4], 'confirmed_flush_lsn': r[5], 'data': r[6]}
if compare_slots(slot, slots[r[0]]):
create_slots[r[0]] = slot
copy_slots[r[0]] = slot
else:
logger.warning('Will not copy the logical slot "%s" due to the configuration mismatch: '
'configuration=%s, slot on the primary=%s', r[0], slots[r[0]], slot)
except Exception as e:
logger.error("Failed to copy logical slots from the %s via postgresql connection: %r", leader.name, e)

if isinstance(create_slots, dict) and create_slots and self._postgresql.stop():
for name, value in create_slots.items():
if copy_slots and self._postgresql.stop():
for name, value in copy_slots.items():
slot_dir = os.path.join(self._postgresql.slots_handler.pg_replslot_dir, name)
slot_tmp_dir = slot_dir + '.tmp'
if os.path.exists(slot_tmp_dir):
@@ -403,12 +421,12 @@ class SlotsHandler(object):
fsync_dir(self._postgresql.slots_handler.pg_replslot_dir)
self._postgresql.start()

def schedule(self, value=None):
def schedule(self, value: Optional[bool] = None) -> None:
if value is None:
value = self._postgresql.major_version >= 90400
self._schedule_load_slots = self._force_readiness_check = value

def on_promote(self):
def on_promote(self) -> None:
if self._advance:
self._advance.on_promote()


@@ -3,7 +3,7 @@ import re
import time

from copy import deepcopy
from typing import Any, Collection, Dict, Tuple, TYPE_CHECKING
from typing import Any, Collection, Dict, List, Tuple, TYPE_CHECKING, Union

from ..collections import CaseInsensitiveDict, CaseInsensitiveSet
from ..dcs import Cluster
@@ -109,7 +109,7 @@ def parse_sync_standby_names(value: str) -> Dict[str, Any]:
result = {'type': 'priority', 'num': int(tokens[0][1])}
synclist = tokens[2:-1]
else:
result = {'type': 'priority', 'num': 1}
result: Dict[str, Union[int, str, CaseInsensitiveSet]] = {'type': 'priority', 'num': 1}
synclist = tokens
result['members'] = CaseInsensitiveSet()
for i, (a_type, a_value, a_pos) in enumerate(synclist):
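parse_sync_standby_names turns the synchronous_standby_names text into a small dict describing the requested sync set. A hedged sketch of the result for a FIRST-style value (the exact return object stores 'members' as a CaseInsensitiveSet; the plain set below is only for illustration):

    # Rough shape of parse_sync_standby_names('FIRST 2 (node1, node2, node3)')
    result = {
        'type': 'priority',                      # 'quorum' would be used for ANY n (...)
        'num': 2,                                # how many listed standbys must confirm
        'members': {'node1', 'node2', 'node3'},  # a CaseInsensitiveSet in the real code
    }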
@@ -149,12 +149,12 @@ class SyncHandler(object):
# "sync" replication connections, that were verified to reach self._primary_flush_lsn at some point
self._ready_replicas = CaseInsensitiveDict({})  # keys: member names, values: connection pids

def _handle_synchronous_standby_names_change(self):
def _handle_synchronous_standby_names_change(self) -> None:
"""If synchronous_standby_names has changed we need to check that newly added replicas
have reached self._primary_flush_lsn. Only after that they could be counted as sync."""
synchronous_standby_names = self._postgresql.synchronous_standby_names()
if synchronous_standby_names == self._synchronous_standby_names:
return False
return

self._synchronous_standby_names = synchronous_standby_names
try:
@@ -165,7 +165,7 @@ class SyncHandler(object):

# Invalidate cache of "sync" connections
for app_name in list(self._ready_replicas.keys()):
if app_name not in self._ssn_data['members']:
if isinstance(self._ssn_data['members'], CaseInsensitiveSet) and app_name not in self._ssn_data['members']:
del self._ready_replicas[app_name]

# Newly connected replicas will be counted as sync only when reached self._primary_flush_lsn
@@ -201,7 +201,7 @@ class SyncHandler(object):
if r[sort_col] is not None]

members = CaseInsensitiveDict({m.name: m for m in cluster.members})
replica_list = []
replica_list: List[Tuple[int, str, str, int, bool]] = []
# pg_stat_replication.sync_state has 4 possible states - async, potential, quorum, sync.
# That is, alphabetically they are in the reversed order of priority.
# Since we are doing reversed sort on (sync_state, lsn) tuples, it helps to keep the result
@@ -227,7 +227,8 @@ class SyncHandler(object):
for pid, app_name, sync_state, replica_lsn, _ in sorted(replica_list, key=lambda x: x[4]):
# if standby name is listed in the /sync key we can count it as synchronous, otherwise
# it becomes really synchronous when sync_state = 'sync' and it is known that it managed to catch up
if app_name not in self._ready_replicas and app_name in self._ssn_data['members'] and\
if app_name not in self._ready_replicas and isinstance(self._ssn_data['members'], CaseInsensitiveSet)\
and app_name in self._ssn_data['members'] and\
(cluster.sync.matches(app_name) or sync_state == 'sync' and replica_lsn >= self._primary_flush_lsn):
self._ready_replicas[app_name] = pid


@@ -1,7 +1,7 @@
import abc
import logging

from collections import namedtuple
from typing import Any, MutableMapping, Optional, Tuple, Union

from ..collections import CaseInsensitiveDict
from ..utils import parse_bool, parse_int, parse_real
@@ -9,23 +9,65 @@ from ..utils import parse_bool, parse_int, parse_real
logger = logging.getLogger(__name__)


class Bool(namedtuple('Bool', 'version_from,version_till')):
class _Transformable(abc.ABC):

@staticmethod
def transform(name, value):
def __init__(self, version_from: int, version_till: Optional[int]) -> None:
self.__version_from = version_from
self.__version_till = version_till

@property
def version_from(self) -> int:
return self.__version_from

@property
def version_till(self) -> Optional[int]:
return self.__version_till

@abc.abstractmethod
def transform(self, name: str, value: Any) -> Optional[Any]:
"""Verify that provided value is valid.

:param name: GUC's name
:param value: GUC's value
:returns: the value (sometimes clamped) or ``None`` if the value isn't valid
"""


class Bool(_Transformable):

def transform(self, name: str, value: Any) -> Optional[Any]:
if parse_bool(value) is not None:
return value
logger.warning('Removing bool parameter=%s from the config due to the invalid value=%s', name, value)


class Number(abc.ABC, namedtuple('Number', 'version_from,version_till,min_val,max_val,unit')):
class Number(_Transformable):

def __init__(self, version_from: int, version_till: Optional[int],
min_val: Union[int, float], max_val: Union[int, float], unit: Optional[str]) -> None:
super(Number, self).__init__(version_from, version_till)
self.__min_val = min_val
self.__max_val = max_val
self.__unit = unit

@property
def min_val(self) -> Union[int, float]:
return self.__min_val

@property
def max_val(self) -> Union[int, float]:
return self.__max_val

@property
def unit(self) -> Optional[str]:
return self.__unit

@staticmethod
@abc.abstractmethod
def parse(value, unit):
"""parse value"""
def parse(value: Any, unit: Optional[str]) -> Optional[Any]:
"""Convert provided value to unit."""

def transform(self, name, value):
def transform(self, name: str, value: Any) -> Union[int, float, None]:
num_value = self.parse(value, self.unit)
if num_value is not None:
if num_value < self.min_val:
@@ -44,20 +86,28 @@ class Number(abc.ABC, namedtuple('Number', 'version_from,version_till,min_val,ma
class Integer(Number):

@staticmethod
def parse(value, unit):
def parse(value: Any, unit: Optional[str]) -> Optional[int]:
return parse_int(value, unit)


class Real(Number):

@staticmethod
def parse(value, unit):
def parse(value: Any, unit: Optional[str]) -> Optional[float]:
return parse_real(value, unit)


class Enum(namedtuple('Enum', 'version_from,version_till,possible_values')):
class Enum(_Transformable):

def transform(self, name, value):
def __init__(self, version_from: int, version_till: Optional[int], possible_values: Tuple[str, ...]) -> None:
super(Enum, self).__init__(version_from, version_till)
self.__possible_values = possible_values

@property
def possible_values(self) -> Tuple[str, ...]:
return self.__possible_values

def transform(self, name: str, value: Optional[Any]) -> Optional[Any]:
if str(value).lower() in self.possible_values:
return value
logger.warning('Removing enum parameter=%s from the config due to the invalid value=%s', name, value)
@@ -65,16 +115,15 @@ class Enum(namedtuple('Enum', 'version_from,version_till,possible_values')):

class EnumBool(Enum):

def transform(self, name, value):
def transform(self, name: str, value: Optional[Any]) -> Optional[Any]:
if parse_bool(value) is not None:
return value
return super(EnumBool, self).transform(name, value)


class String(namedtuple('String', 'version_from,version_till')):
class String(_Transformable):

@staticmethod
def transform(name, value):
def transform(self, name: str, value: Optional[Any]) -> Optional[Any]:
return value


@@ -511,17 +560,18 @@ recovery_parameters = CaseInsensitiveDict({
})


def _transform_parameter_value(validators, version, name, value):
validators = validators.get(name)
if validators:
for validator in (validators if isinstance(validators[0], tuple) else [validators]):
def _transform_parameter_value(validators: MutableMapping[str, Union[_Transformable, Tuple[_Transformable, ...]]],
version: int, name: str, value: Any) -> Optional[Any]:
name_validators = validators.get(name)
if name_validators:
for validator in (name_validators if isinstance(name_validators, tuple) else (name_validators,)):
if version >= validator.version_from and\
(validator.version_till is None or version < validator.version_till):
return validator.transform(name, value)
logger.warning('Removing unexpected parameter=%s value=%s from the config', name, value)


def transform_postgresql_parameter_value(version, name, value):
def transform_postgresql_parameter_value(version: int, name: str, value: Any) -> Optional[Any]:
if '.' in name:
return value
if name in recovery_parameters:
@@ -529,5 +579,5 @@ def transform_postgresql_parameter_value(version, name, value):
return _transform_parameter_value(parameters, version, name, value)


def transform_recovery_parameter_value(version, name, value):
def transform_recovery_parameter_value(version: int, name: str, value: Any) -> Optional[Any]:
return _transform_parameter_value(recovery_parameters, version, name, value)
|
||||
|
||||
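As a usage sketch of the version-gated validators above (illustrative, not part of the patch; the parameter names and version number are only examples):

# Illustrative sketch: dotted (extension) GUCs are passed through untouched, while unknown
# parameters are dropped with a warning and the function returns None.
from patroni.postgresql.validator import transform_postgresql_parameter_value

transform_postgresql_parameter_value(140000, 'pg_stat_statements.max', '5000')  # returned as-is ('.' in name)
transform_postgresql_parameter_value(140000, 'made_up_parameter', 'x')          # None, logged as unexpected
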
@@ -3,8 +3,10 @@
This module is able to handle both ``psycopg2`` and ``psycopg3``, and it exposes a common interface for both.
``psycopg2`` takes precedence. ``psycopg3`` will only be used if ``psycopg2`` is either absent or older than ``2.5.4``.
"""
from typing import Any, Optional

from typing import Any, Optional, TYPE_CHECKING, Union
if TYPE_CHECKING:  # pragma: no cover
    from psycopg import Connection
    from psycopg2 import connection, cursor

__all__ = ['connect', 'quote_ident', 'quote_literal', 'DatabaseError', 'Error', 'OperationalError', 'ProgrammingError']

@@ -39,9 +41,9 @@ try:
            value.prepare(conn)
        return value.getquoted().decode('utf-8')
except ImportError:
    from psycopg import connect as __connect, sql, Error, DatabaseError, OperationalError, ProgrammingError, Connection
    from psycopg import connect as __connect, sql, Error, DatabaseError, OperationalError, ProgrammingError

    def _connect(dsn: str = "", **kwargs: Any) -> Any:
    def _connect(dsn: Optional[str] = None, **kwargs: Any) -> 'Connection[Any]':
        """Call ``psycopg.connect`` with ``dsn`` and ``**kwargs``.

        .. note::
@@ -53,19 +55,19 @@ except ImportError:

        :returns: a connection to the database.
        """
        ret = __connect(dsn, **kwargs)
        ret = __connect(dsn or "", **kwargs)
        setattr(ret, 'server_version', ret.pgconn.server_version)  # compatibility with psycopg2
        return ret

    def _quote_ident(value: Any, conn: Connection) -> str:
    def _quote_ident(value: Any, scope: Any) -> str:
        """Quote *value* as a SQL identifier.

        :param value: value to be quoted.
        :param conn: connection to evaluate the returning string into.
        :param scope: connection to evaluate the returning string into.

        :returns: *value* quoted as a SQL identifier.
        """
        return sql.Identifier(value).as_string(conn)
        return sql.Identifier(value).as_string(scope)

    def quote_literal(value: Any, conn: Optional[Any] = None) -> str:
        """Quote *value* as a SQL literal.
@@ -78,7 +80,7 @@ except ImportError:
        return sql.Literal(value).as_string(conn)


def connect(*args: Any, **kwargs: Any) -> Any:
def connect(*args: Any, **kwargs: Any) -> Union['connection', 'Connection[Any]']:
    """Get a connection to the database.

    .. note::
@@ -102,7 +104,7 @@ def connect(*args: Any, **kwargs: Any) -> Any:
    return ret


def quote_ident(value: Any, conn: Optional[Any] = None) -> str:
def quote_ident(value: Any, conn: Optional[Union['cursor', 'connection', 'Connection[Any]']] = None) -> str:
    """Quote *value* as a SQL identifier.

    :param value: value to be quoted.

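A minimal usage sketch of this compatibility layer (illustrative, not part of the patch; it assumes a reachable PostgreSQL instance and example credentials):

# Illustrative sketch: the same helpers work whether psycopg2 or psycopg 3 was imported underneath.
from patroni import psycopg

conn = psycopg.connect(dbname='postgres', user='postgres')   # kwargs go to the underlying driver
print(psycopg.quote_ident('my table', conn))                  # identifier quoting, e.g. "my table"
print(psycopg.quote_literal("o'clock", conn))                 # literal quoting with proper escaping
conn.close()
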
@@ -1,5 +1,6 @@
import logging

from .config import Config
from .daemon import AbstractPatroniDaemon, abstract_main
from .dcs.raft import KVStoreTTL

@@ -8,22 +9,22 @@ logger = logging.getLogger(__name__)

class RaftController(AbstractPatroniDaemon):

    def __init__(self, config):
    def __init__(self, config: Config) -> None:
        super(RaftController, self).__init__(config)

        config = self.config.get('raft')
        assert 'self_addr' in config
        self._raft = KVStoreTTL(None, None, None, **config)
        kvstore_config = self.config.get('raft')
        assert 'self_addr' in kvstore_config
        self._raft = KVStoreTTL(None, None, None, **kvstore_config)

    def _run_cycle(self):
    def _run_cycle(self) -> None:
        try:
            self._raft.doTick(self._raft.conf.autoTickPeriod)
        except Exception:
            logger.exception('doTick')

    def _shutdown(self):
    def _shutdown(self) -> None:
        self._raft.destroy()


def main():
def main() -> None:
    abstract_main(RaftController)

@@ -140,14 +140,14 @@ class PatroniRequest(object):

        :returns: the response returned upon request.
        """
        url = member.api_url
        url = member.api_url or ''
        if endpoint:
            scheme, netloc, _, _, _, _ = urlparse(url)
            url = urlunparse((scheme, netloc, endpoint, '', '', ''))
        return self.request(method, url, data, **kwargs)


def get(url: str, verify: Optional[bool] = True, **kwargs: Any) -> urllib3.response.HTTPResponse:
def get(url: str, verify: bool = True, **kwargs: Any) -> urllib3.response.HTTPResponse:
    """Perform an HTTP GET request.

    .. note::

@@ -5,20 +5,22 @@ import logging
import sys
import boto3

from ..utils import Retry, RetryFailedError

from botocore.exceptions import ClientError
from botocore.utils import IMDSFetcher

from typing import Any, Optional

from ..utils import Retry, RetryFailedError

logger = logging.getLogger(__name__)


class AWSConnection(object):

    def __init__(self, cluster_name):
    def __init__(self, cluster_name: Optional[str]) -> None:
        self.available = False
        self.cluster_name = cluster_name if cluster_name is not None else 'unknown'
        self._retry = Retry(deadline=300, max_delay=30, max_tries=-1, retry_exceptions=(ClientError,))
        self._retry = Retry(deadline=300, max_delay=30, max_tries=-1, retry_exceptions=ClientError)
        try:
            # get the instance id
            fetcher = IMDSFetcher(timeout=2.1)
@@ -38,13 +40,13 @@ class AWSConnection(object):
                return
        self.available = True

    def retry(self, *args, **kwargs):
    def retry(self, *args: Any, **kwargs: Any) -> Any:
        return self._retry.copy()(*args, **kwargs)

    def aws_available(self):
    def aws_available(self) -> bool:
        return self.available

    def _tag_ebs(self, conn, role):
    def _tag_ebs(self, conn: Any, role: str) -> None:
        """ set tags, carrying the cluster name, instance role and instance id for the EBS storage """
        tags = [{'Key': 'Name', 'Value': 'spilo_' + self.cluster_name},
                {'Key': 'Role', 'Value': role},
@@ -52,16 +54,16 @@ class AWSConnection(object):
        volumes = conn.volumes.filter(Filters=[{'Name': 'attachment.instance-id', 'Values': [self.instance_id]}])
        conn.create_tags(Resources=[v.id for v in volumes], Tags=tags)

    def _tag_ec2(self, conn, role):
    def _tag_ec2(self, conn: Any, role: str) -> None:
        """ tag the current EC2 instance with a cluster role """
        tags = [{'Key': 'Role', 'Value': role}]
        conn.create_tags(Resources=[self.instance_id], Tags=tags)

    def on_role_change(self, new_role):
    def on_role_change(self, new_role: str) -> bool:
        if not self.available:
            return False
        try:
            conn = boto3.resource('ec2', region_name=self.region)
            conn = boto3.resource('ec2', region_name=self.region)  # type: ignore
            self.retry(self._tag_ec2, conn, new_role)
            self.retry(self._tag_ebs, conn, new_role)
        except RetryFailedError:

@@ -31,7 +31,8 @@ import subprocess
import sys
import time

from collections import namedtuple
from enum import IntEnum
from typing import Any, List, NamedTuple, Optional, Tuple

from .. import psycopg

@@ -42,15 +43,14 @@ si_prefixes = ['K', 'M', 'G', 'T', 'P', 'E', 'Z', 'Y']


# Meaningful names to the exit codes used by WALERestore
ExitCode = type('Enum', (), {
    'SUCCESS': 0,  #: Succeeded
    'RETRY_LATER': 1,  #: External issue, retry later
    'FAIL': 2  #: Don't try again unless configuration changes
})
class ExitCode(IntEnum):
    SUCCESS = 0  #: Succeeded
    RETRY_LATER = 1  #: External issue, retry later
    FAIL = 2  #: Don't try again unless configuration changes


# We need to know the current PG version in order to figure out the correct WAL directory name
def get_major_version(data_dir):
def get_major_version(data_dir: str) -> float:
    version_file = os.path.join(data_dir, 'PG_VERSION')
    if os.path.isfile(version_file):  # version file exists
        try:
@@ -61,7 +61,7 @@ def get_major_version(data_dir):
    return 0.0


def repr_size(n_bytes):
def repr_size(n_bytes: float) -> str:
    """
    >>> repr_size(1000)
    '1000 Bytes'
@@ -77,7 +77,7 @@ def repr_size(n_bytes):
    return '{0} {1}iB'.format(round(n_bytes, 1), si_prefixes[i])


def size_as_bytes(size_, prefix):
def size_as_bytes(size: float, prefix: str) -> int:
    """
    >>> size_as_bytes(7.5, 'T')
    8246337208320
@@ -88,23 +88,19 @@ def size_as_bytes(size_, prefix):

    exponent = si_prefixes.index(prefix) + 1

    return int(size_ * (1024.0 ** exponent))
    return int(size * (1024.0 ** exponent))


WALEConfig = namedtuple(
    'WALEConfig',
    [
        'env_dir',
        'threshold_mb',
        'threshold_pct',
        'cmd',
    ]
)
class WALEConfig(NamedTuple):
    env_dir: str
    threshold_mb: int
    threshold_pct: int
    cmd: List[str]


class WALERestore(object):
    def __init__(self, scope, datadir, connstring, env_dir, threshold_mb,
                 threshold_pct, use_iam, no_leader, retries):
    def __init__(self, scope: str, datadir: str, connstring: str, env_dir: str, threshold_mb: int,
                 threshold_pct: int, use_iam: int, no_leader: bool, retries: int) -> None:
        self.scope = scope
        self.leader_connection = connstring
        self.data_dir = datadir
@@ -129,7 +125,7 @@ class WALERestore(object):
        self.init_error = (not os.path.exists(self.wal_e.env_dir))
        self.retries = retries

    def run(self):
    def run(self) -> int:
        """
        Creates a new replica using WAL-E

@@ -158,7 +154,7 @@ class WALERestore(object):
            logger.exception("Unhandled exception when running WAL-E restore")
        return ExitCode.FAIL

    def should_use_s3_to_create_replica(self):
    def should_use_s3_to_create_replica(self) -> Optional[bool]:
        """ determine whether it makes sense to use S3 and not pg_basebackup """

        threshold_megabytes = self.wal_e.threshold_mb
@@ -218,7 +214,7 @@ class WALERestore(object):
        try:
            # get the difference in bytes between the current WAL location and the backup start offset
            con = psycopg.connect(self.leader_connection)
            if con.server_version >= 100000:
            if getattr(con, 'server_version', 0) >= 100000:
                wal_name = 'wal'
                lsn_name = 'lsn'
            else:
@@ -232,8 +228,9 @@ class WALERestore(object):
                        " ELSE pg_catalog.pg_{0}_{1}_diff(pg_catalog.pg_current_{0}_{1}(), %s)::bigint"
                        " END").format(wal_name, lsn_name),
                       (backup_start_lsn, backup_start_lsn, backup_start_lsn))

            diff_in_bytes = int(cur.fetchone()[0])
            for row in cur:
                diff_in_bytes = int(row[0])
                break
        except psycopg.Error:
            logger.exception('could not determine difference with the leader location')
            if attempts_no < self.retries:  # retry in case of a temporary connection issue
@@ -262,11 +259,11 @@ class WALERestore(object):
        are_thresholds_ok = is_size_thresh_ok and is_percentage_thresh_ok

        class Size(object):
            def __init__(self, n_bytes, prefix=None):
            def __init__(self, n_bytes: float, prefix: Optional[str] = None) -> None:
                self.n_bytes = n_bytes
                self.prefix = prefix

            def __repr__(self):
            def __repr__(self) -> str:
                if self.prefix is not None:
                    n_bytes = size_as_bytes(self.n_bytes, self.prefix)
                else:
@@ -274,10 +271,10 @@ class WALERestore(object):
                return repr_size(n_bytes)

        class HumanContext(object):
            def __init__(self, items):
            def __init__(self, items: List[Tuple[str, Any]]) -> None:
                self.items = items

            def __repr__(self):
            def __repr__(self) -> str:
                return ', '.join('{}={!r}'.format(key, value)
                                 for key, value in self.items)

@@ -298,7 +295,7 @@ class WALERestore(object):
        logger.info('Thresholds are OK, using wal-e basebackup: %s', human_context)
        return are_thresholds_ok

    def fix_subdirectory_path_if_broken(self, dirname):
    def fix_subdirectory_path_if_broken(self, dirname: str) -> bool:
        # in case it is a symlink pointing to a non-existing location, remove it and create the actual directory
        path = os.path.join(self.data_dir, dirname)
        if not os.path.exists(path):
@@ -316,7 +313,7 @@ class WALERestore(object):
            return False
        return True

    def create_replica_with_s3(self):
    def create_replica_with_s3(self) -> int:
        # if we're set up, restore the replica using fetch latest
        try:
            cmd = self.wal_e.cmd + ['backup-fetch',
@@ -334,7 +331,7 @@ class WALERestore(object):
        return exit_code


def main():
def main() -> int:
    logging.basicConfig(format='%(asctime)s %(levelname)s: %(message)s', level=logging.INFO)
    parser = argparse.ArgumentParser(description='Script to image replicas using WAL-E')
    parser.add_argument('--scope', required=True)
@@ -363,11 +360,12 @@ def main():
                                threshold_pct=args.threshold_backup_size_percentage, use_iam=args.use_iam,
                                no_leader=args.no_leader, retries=args.retries)
        exit_code = restore.run()
        if not exit_code == ExitCode.RETRY_LATER:  # only WAL-E failures lead to the retry
        if exit_code != ExitCode.RETRY_LATER:  # only WAL-E failures lead to the retry
            logger.debug('exit_code is %r, not retrying', exit_code)
            break
        time.sleep(RETRY_SLEEP_INTERVAL)

    assert exit_code is not None
    return exit_code

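For context, the threshold check in should_use_s3_to_create_replica() boils down to two comparisons; a rough sketch with made-up numbers (illustrative only, the exact computation lives in the elided part of the method):

# Illustrative sketch: prefer wal-e backup-fetch only when the leader has not advanced too far
# past the latest backup, both in absolute megabytes and as a percentage of the backup size.
threshold_mb, threshold_pct = 10240, 30          # example WALEConfig values
backup_size = 200 * 1024 ** 3                    # bytes in the newest S3 backup (example)
diff_in_bytes = 5 * 1024 ** 3                    # bytes produced on the leader since then (example)
is_size_thresh_ok = diff_in_bytes < threshold_mb * 1024 ** 2
is_percentage_thresh_ok = diff_in_bytes < backup_size * threshold_pct / 100.0
are_thresholds_ok = is_size_thresh_ok and is_percentage_thresh_ok
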
@@ -7,9 +7,9 @@
:var DEC_RE: regular expression to match decimal numbers, signed or unsigned.
:var HEX_RE: regular expression to match hex strings, signed or unsigned.
:var DBL_RE: regular expression to match double precision numbers, signed or unsigned. Matches scientific notation too.
:var WHITESPACE_RE: regular expression to match whitespace characters
"""
import errno
import json.decoder as json_decoder
import logging
import os
import platform
@@ -21,9 +21,10 @@ import tempfile
import time
from shlex import split

from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union, TYPE_CHECKING
from typing import Any, Callable, Dict, Iterator, List, Optional, Union, Tuple, Type, TYPE_CHECKING

from dateutil import tz
from json import JSONDecoder
from urllib3.response import HTTPResponse

from .exceptions import PatroniException
@@ -42,9 +43,10 @@ OCT_RE = re.compile(r'^[-+]?0[0-7]*')
DEC_RE = re.compile(r'^[-+]?(0|[1-9][0-9]*)')
HEX_RE = re.compile(r'^[-+]?0x[0-9a-fA-F]+')
DBL_RE = re.compile(r'^[-+]?[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?')
WHITESPACE_RE = re.compile(r'[ \t\n\r]*', re.VERBOSE | re.MULTILINE | re.DOTALL)


def deep_compare(obj1: Dict, obj2: Dict) -> bool:
def deep_compare(obj1: Dict[Any, Union[Any, Dict[Any, Any]]], obj2: Dict[Any, Union[Any, Dict[Any, Any]]]) -> bool:
    """Recursively compare two dictionaries to check if they are equal in terms of keys and values.

    .. note::
@@ -84,7 +86,7 @@ def deep_compare(obj1: Dict, obj2: Dict) -> bool:
    return True


def patch_config(config: Dict, data: Dict) -> bool:
def patch_config(config: Dict[Any, Union[Any, Dict[Any, Any]]], data: Dict[Any, Union[Any, Dict[Any, Any]]]) -> bool:
    """Update and append to dictionary *config* from overrides in *data*.

    .. note::
@@ -239,7 +241,7 @@ def strtod(value: Any) -> Tuple[Union[float, None], str]:
    return None, value


def convert_to_base_unit(value: Union[int, float], unit: str, base_unit: str) -> Union[int, float, None]:
def convert_to_base_unit(value: Union[int, float], unit: str, base_unit: Optional[str]) -> Union[int, float, None]:
    """Convert *value* as a *unit* of compute information or time to *base_unit*.

    :param value: value to be converted to the base unit.
@@ -268,7 +270,7 @@ def convert_to_base_unit(value: Union[int, float], unit: str, base_unit: str) ->
    >>> convert_to_base_unit(1, 'GB', '512 MB') is None
    True
    """
    convert = {
    convert: Dict[str, Dict[str, Union[int, float]]] = {
        'B': {'B': 1, 'kB': 1024, 'MB': 1024 * 1024, 'GB': 1024 * 1024 * 1024, 'TB': 1024 * 1024 * 1024 * 1024},
        'kB': {'B': 1.0 / 1024, 'kB': 1, 'MB': 1024, 'GB': 1024 * 1024, 'TB': 1024 * 1024 * 1024},
        'MB': {'B': 1.0 / (1024 * 1024), 'kB': 1.0 / 1024, 'MB': 1, 'GB': 1024, 'TB': 1024 * 1024},
@@ -287,7 +289,7 @@ def convert_to_base_unit(value: Union[int, float], unit: str, base_unit: str) ->
    else:
        base_value = 1

    if base_unit in convert and unit in convert[base_unit]:
    if base_value is not None and base_unit in convert and unit in convert[base_unit]:
        value *= convert[base_unit][unit] / float(base_value)

    if unit in round_order:
@@ -297,7 +299,7 @@ def convert_to_base_unit(value: Union[int, float], unit: str, base_unit: str) ->
    return value


def parse_int(value: Any, base_unit: Optional[str] = None) -> Union[int, None]:
def parse_int(value: Any, base_unit: Optional[str] = None) -> Optional[int]:
    """Parse *value* as an :class:`int`.

    :param value: any value that can be handled either by :func:`strtol` or :func:`strtod`. If *value* contains a
@@ -350,7 +352,7 @@ def parse_int(value: Any, base_unit: Optional[str] = None) -> Union[int, None]:
    return round(val)


def parse_real(value: Any, base_unit: Optional[str] = None) -> Union[float, None]:
def parse_real(value: Any, base_unit: Optional[str] = None) -> Optional[float]:
    """Parse *value* as a :class:`float`.

    :param value: any value that can be handled by :func:`strtod`. If *value* contains a unit, then *base_unit* must
@@ -381,7 +383,7 @@ def parse_real(value: Any, base_unit: Optional[str] = None) -> Union[float, None
    return convert_to_base_unit(val, unit, base_unit)


def compare_values(vartype: str, unit: str, old_value: Any, new_value: Any) -> bool:
def compare_values(vartype: str, unit: Optional[str], old_value: Any, new_value: Any) -> bool:
    """Check if *old_value* and *new_value* are equivalent after parsing them as *vartype*.

    :param vartype: the target type to parse *old_value* and *new_value* before comparing them. Accepts any among of the
@@ -426,7 +428,7 @@ def compare_values(vartype: str, unit: str, old_value: Any, new_value: Any) -> b
    >>> compare_values('integer', 'kB', 4098, '4097.5kB')
    True
    """
    converters = {
    converters: Dict[str, Callable[[str, Optional[str]], Union[None, bool, int, float, str]]] = {
        'bool': lambda v1, v2: parse_bool(v1),
        'integer': parse_int,
        'real': parse_real,
@@ -434,11 +436,11 @@ def compare_values(vartype: str, unit: str, old_value: Any, new_value: Any) -> b
        'string': lambda v1, v2: str(v1)
    }

    convert = converters.get(vartype) or converters['string']
    old_value = convert(old_value, None)
    new_value = convert(new_value, unit)
    converter = converters.get(vartype) or converters['string']
    old_converted = converter(old_value, None)
    new_converted = converter(new_value, unit)

    return old_value is not None and new_value is not None and old_value == new_value
    return old_converted is not None and new_converted is not None and old_converted == new_converted

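A small usage sketch of the unit-aware helpers (illustrative, consistent with the doctests above):

# Illustrative sketch: values are normalised to the base unit before comparison.
from patroni.utils import compare_values, parse_int

parse_int('2MB', 'kB')                          # -> 2048
compare_values('integer', 'kB', 4096, '4MB')    # -> True, both sides normalise to 4096 kB
compare_values('bool', None, 'on', True)        # -> True, both parse as boolean true
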
def _sleep(interval: Union[int, float]) -> None:
@@ -467,11 +469,11 @@ class Retry(object):
    :ivar retry_exceptions: single exception or tuple
    """

    def __init__(self, max_tries: Optional[int] = 1, delay: Optional[float] = 0.1, backoff: Optional[int] = 2,
                 max_jitter: Optional[float] = 0.8, max_delay: Optional[int] = 3600,
                 sleep_func: Optional[Callable[[Union[int, float]], None]] = _sleep,
    def __init__(self, max_tries: Optional[int] = 1, delay: float = 0.1, backoff: int = 2,
                 max_jitter: float = 0.8, max_delay: int = 3600,
                 sleep_func: Callable[[Union[int, float]], None] = _sleep,
                 deadline: Optional[Union[int, float]] = None,
                 retry_exceptions: Optional[Union[Exception, Tuple[Exception]]] = PatroniException) -> None:
                 retry_exceptions: Union[Type[Exception], Tuple[Type[Exception], ...]] = PatroniException) -> None:
        """Create a :class:`Retry` instance for retrying function calls.

        :param max_tries: how many times to retry the command. ``-1`` means infinite tries.
@@ -504,7 +506,7 @@ class Retry(object):
    def copy(self) -> 'Retry':
        """Return a clone of this retry manager."""
        return Retry(max_tries=self.max_tries, delay=self.delay, backoff=self.backoff,
                     max_jitter=self.max_jitter / 100.0, max_delay=self.max_delay, sleep_func=self.sleep_func,
                     max_jitter=self.max_jitter / 100.0, max_delay=int(self.max_delay), sleep_func=self.sleep_func,
                     deadline=self.deadline, retry_exceptions=self.retry_exceptions)

    @property
@@ -525,11 +527,11 @@ class Retry(object):
        self._cur_delay = min(self._cur_delay * self.backoff, self.max_delay)

    @property
    def stoptime(self) -> Union[float, None]:
    def stoptime(self) -> float:
        """Get the current stop time."""
        return self._cur_stoptime
        return self._cur_stoptime or 0

    def __call__(self, func: Callable, *args: Any, **kwargs: Any) -> Any:
    def __call__(self, func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
        """Call a function *func* with arguments ``*args`` and ``*kwargs`` in a loop.

        *func* will be called until one of the following conditions is met:
@@ -561,7 +563,9 @@ class Retry(object):
                    logger.warning('Retry got exception: %s', e)
                    raise RetryFailedError("Too many retry attempts")
                self._attempts += 1
                sleeptime = hasattr(e, 'sleeptime') and e.sleeptime or self.sleeptime
                sleeptime = getattr(e, 'sleeptime', None)
                if not isinstance(sleeptime, (int, float)):
                    sleeptime = self.sleeptime

                if self._cur_stoptime is not None and time.time() + sleeptime >= self._cur_stoptime:
                    logger.warning('Retry got exception: %s', e)
@@ -571,7 +575,7 @@ class Retry(object):
            self.update_delay()

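A usage sketch of the retry manager (illustrative, not part of the patch); copy() keeps per-call state such as the attempt counter and deadline isolated, which is why callers like AWSConnection.retry() above clone it before every call:

# Illustrative sketch of Retry usage.
from patroni.utils import Retry, RetryFailedError

attempts = {'n': 0}

def flaky() -> str:
    attempts['n'] += 1
    if attempts['n'] < 3:          # fail twice, then succeed
        raise ConnectionError('transient failure')
    return 'ok'

retry = Retry(deadline=10, max_delay=1, max_tries=-1, retry_exceptions=ConnectionError)
result = retry.copy()(flaky)       # retries until success, or RetryFailedError once the deadline is hit
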
def polling_loop(timeout: Union[int, float], interval: Optional[Union[int, float]] = 1) -> Iterator[int]:
def polling_loop(timeout: Union[int, float], interval: Union[int, float] = 1) -> Iterator[int]:
    """Return an iterator that returns values every *interval* seconds until *timeout* has passed.

    .. note::
@@ -587,10 +591,10 @@ def polling_loop(timeout: Union[int, float], interval: Optional[Union[int, float
    while time.time() < end_time:
        yield iteration
        iteration += 1
        time.sleep(interval)
        time.sleep(float(interval))


def split_host_port(value: str, default_port: int) -> Tuple[str, int]:
def split_host_port(value: str, default_port: Optional[int]) -> Tuple[str, int]:
    """Extract host(s) and port from *value*.

    :param value: string from where host(s) and port will be extracted. Accepts either of these formats
@@ -625,11 +629,11 @@ def split_host_port(value: str, default_port: int) -> Tuple[str, int]:
    # If *value* contains ``:`` we consider it to be an IPv6 address, so we attempt to remove possible square brackets
    if ':' in t[0]:
        t[0] = ','.join([h.strip().strip('[]') for h in t[0].split(',')])
    t.append(default_port)
    t.append(str(default_port))
    return t[0], int(t[1])


def uri(proto: str, netloc: Union[List, Tuple[str, int], str], path: Optional[str] = '',
def uri(proto: str, netloc: Union[List[str], Tuple[str, Union[int, str]], str], path: Optional[str] = '',
        user: Optional[str] = None) -> str:
    """Construct URI from given arguments.

@@ -670,17 +674,15 @@ def iter_response_objects(response: HTTPResponse) -> Iterator[Dict[str, Any]]:
    :rtype: Iterator[:class:`dict`] with current JSON document.
    """
    prev = ''
    decoder = json_decoder.JSONDecoder()
    decoder = JSONDecoder()
    for chunk in response.read_chunked(decode_content=False):
        if isinstance(chunk, bytes):
            chunk = chunk.decode('utf-8')
        chunk = prev + chunk
        chunk = prev + chunk.decode('utf-8')

        length = len(chunk)
        # ``chunk`` is analyzed in parts. ``idx`` holds the position of the first character in the current part that is
        # neither space nor tab nor line-break, or in other words, the position in the ``chunk`` where it is likely
        # that a JSON document begins
        idx = json_decoder.WHITESPACE.match(chunk, 0).end()
        idx = WHITESPACE_RE.match(chunk, 0).end()  # pyright: ignore [reportOptionalMemberAccess]
        while idx < length:
            try:
                # Get a JSON document from the chunk. ``message`` is a dictionary representing the JSON document, and
@@ -690,7 +692,7 @@ def iter_response_objects(response: HTTPResponse) -> Iterator[Dict[str, Any]]:
                break
            else:
                yield message
                idx = json_decoder.WHITESPACE.match(chunk, idx).end()
                idx = WHITESPACE_RE.match(chunk, idx).end()  # pyright: ignore [reportOptionalMemberAccess]
        # It is not usual that a ``chunk`` would contain more than one JSON document, but we handle that just in case
        prev = chunk[idx:]

@@ -732,7 +734,7 @@ def cluster_as_json(cluster: 'Cluster', global_config: Optional['GlobalConfig']
    leader_name = cluster.leader.name if cluster.leader else None
    cluster_lsn = cluster.last_lsn or 0

    ret = {'members': []}
    ret: Dict[str, Any] = {'members': []}
    for m in cluster.members:
        if m.name == leader_name:
            role = 'standby_leader' if global_config.is_standby_cluster else 'leader'
@@ -762,7 +764,8 @@ def cluster_as_json(cluster: 'Cluster', global_config: Optional['GlobalConfig']
        ret['members'].append(member)

    # sort members by name for consistency
    ret['members'].sort(key=lambda m: m['name'])
    cmp: Callable[[Dict[str, Any]], bool] = lambda m: m['name']
    ret['members'].sort(key=cmp)
    if global_config.is_paused:
        ret['pause'] = True
    if cluster.failover and cluster.failover.scheduled_at:
@@ -791,7 +794,7 @@ def is_subpath(d1: str, d2: str) -> bool:
    return os.path.commonprefix([real_d1, real_d2 + os.path.sep]) == real_d1


def validate_directory(d: str, msg: Optional[str] = "{} {}") -> None:
def validate_directory(d: str, msg: str = "{} {}") -> None:
    """Ensure directory exists and is writable.

    .. note::
@@ -842,7 +845,7 @@ def data_directory_is_empty(data_dir: str) -> bool:
    return all(os.name != 'nt' and (n.startswith('.') or n == 'lost+found') for n in os.listdir(data_dir))


def keepalive_intvl(timeout: int, idle: int, cnt: Optional[int] = 3) -> int:
def keepalive_intvl(timeout: int, idle: int, cnt: int = 3) -> int:
    """Calculate the value to be used as ``TCP_KEEPINTVL`` based on *timeout*, *idle*, and *cnt*.

    :param timeout: value for ``TCP_USER_TIMEOUT``.
@@ -854,7 +857,7 @@ def keepalive_intvl(timeout: int, idle: int, cnt: Optional[int] = 3) -> int:
    return max(1, int(float(timeout - idle) / cnt))


def keepalive_socket_options(timeout: int, idle: int, cnt: Optional[int] = 3) -> Iterator[Tuple[int, int, int]]:
def keepalive_socket_options(timeout: int, idle: int, cnt: int = 3) -> Iterator[Tuple[int, int, int]]:
    """Get all keepalive related options to be set in a socket.

    :param timeout: value for ``TCP_USER_TIMEOUT``.
@@ -871,25 +874,27 @@ def keepalive_socket_options(timeout: int, idle: int, cnt: Optional[int] = 3) ->
    """
    yield (socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)

    if sys.platform.startswith('linux'):
        yield (socket.SOL_TCP, 18, int(timeout * 1000))  # TCP_USER_TIMEOUT
        TCP_KEEPIDLE = getattr(socket, 'TCP_KEEPIDLE', None)
        TCP_KEEPINTVL = getattr(socket, 'TCP_KEEPINTVL', None)
        TCP_KEEPCNT = getattr(socket, 'TCP_KEEPCNT', None)
    elif sys.platform.startswith('darwin'):
        TCP_KEEPIDLE = 0x10  # (named "TCP_KEEPALIVE" in C)
        TCP_KEEPINTVL = 0x101
        TCP_KEEPCNT = 0x102
    else:
    if not (sys.platform.startswith('linux') or sys.platform.startswith('darwin')):
        return

    intvl = keepalive_intvl(timeout, idle, cnt)
    yield (socket.IPPROTO_TCP, TCP_KEEPIDLE, idle)
    yield (socket.IPPROTO_TCP, TCP_KEEPINTVL, intvl)
    yield (socket.IPPROTO_TCP, TCP_KEEPCNT, cnt)
    if sys.platform.startswith('linux'):
        yield (socket.SOL_TCP, 18, int(timeout * 1000))  # TCP_USER_TIMEOUT

    # The socket constants from MacOS netinet/tcp.h are not exported by python's
    # socket module, therefore we are using 0x10, 0x101, 0x102 constants.
    TCP_KEEPIDLE = getattr(socket, 'TCP_KEEPIDLE', 0x10 if sys.platform.startswith('darwin') else None)
    if TCP_KEEPIDLE is not None:
        yield (socket.IPPROTO_TCP, TCP_KEEPIDLE, idle)
    TCP_KEEPINTVL = getattr(socket, 'TCP_KEEPINTVL', 0x101 if sys.platform.startswith('darwin') else None)
    if TCP_KEEPINTVL is not None:
        intvl = keepalive_intvl(timeout, idle, cnt)
        yield (socket.IPPROTO_TCP, TCP_KEEPINTVL, intvl)
    TCP_KEEPCNT = getattr(socket, 'TCP_KEEPCNT', 0x102 if sys.platform.startswith('darwin') else None)
    if TCP_KEEPCNT is not None:
        yield (socket.IPPROTO_TCP, TCP_KEEPCNT, cnt)


def enable_keepalive(sock: socket.socket, timeout: int, idle: int, cnt: Optional[int] = 3) -> Union[int, None]:
def enable_keepalive(sock: socket.socket, timeout: int, idle: int, cnt: int = 3) -> None:
    """Enable keepalive for *sock*.

    Will set socket options depending on the platform, as per return of :func:`keepalive_socket_options`.

@@ -908,7 +913,7 @@ def enable_keepalive(sock: socket.socket, timeout: int, idle: int, cnt: Optional
    SIO_KEEPALIVE_VALS = getattr(socket, 'SIO_KEEPALIVE_VALS', None)
    if SIO_KEEPALIVE_VALS is not None:  # Windows
        intvl = keepalive_intvl(timeout, idle, cnt)
        return sock.ioctl(SIO_KEEPALIVE_VALS, (1, idle * 1000, intvl * 1000))
        sock.ioctl(SIO_KEEPALIVE_VALS, (1, idle * 1000, intvl * 1000))

    for opt in keepalive_socket_options(timeout, idle, cnt):
        sock.setsockopt(*opt)

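A short sketch of how these helpers are meant to be used on a client socket (illustrative; the host and port are examples):

# Illustrative sketch: enable TCP keepalive with a 30s user timeout and 10s idle time.
import socket
from patroni.utils import enable_keepalive

sock = socket.create_connection(('localhost', 5432), timeout=5)
enable_keepalive(sock, timeout=30, idle=10, cnt=3)
# Linux gets TCP_USER_TIMEOUT plus TCP_KEEPIDLE/TCP_KEEPINTVL/TCP_KEEPCNT, macOS the
# 0x10/0x101/0x102 equivalents, and Windows SIO_KEEPALIVE_VALS, as implemented above.
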
@@ -11,9 +11,9 @@ import shutil
import socket
import subprocess

from typing import Any, Union, Iterator, List, Optional as OptionalType
from typing import Any, Dict, Union, Iterator, List, Optional as OptionalType

from .utils import split_host_port, data_directory_is_empty
from .utils import parse_int, split_host_port, data_directory_is_empty
from .dcs import dcs_modules
from .exceptions import ConfigParseError

@@ -48,8 +48,7 @@ def validate_connect_address(address: str) -> bool:
    return True


def validate_host_port(host_port: str, listen: OptionalType[bool] = False,
                       multiple_hosts: OptionalType[bool] = False) -> bool:
def validate_host_port(host_port: str, listen: bool = False, multiple_hosts: bool = False) -> bool:
    """Check if host(s) and port are valid and available for usage.

    :param host_port: the host(s) and port to be validated. It can be in either of these formats
@@ -68,7 +67,7 @@ def validate_host_port(host_port: str, listen: OptionalType[bool] = False,
        * If :class:`socket.gaierror` is thrown by socket module when attempting to connect to the given address(es).
    """
    try:
        hosts, port = split_host_port(host_port, None)
        hosts, port = split_host_port(host_port, 1)
    except (ValueError, TypeError):
        raise ConfigParseError("contains a wrong value")
    else:
@@ -197,6 +196,7 @@ def get_major_version(bin_dir: OptionalType[str] = None) -> str:
        binary = os.path.join(bin_dir, 'postgres')
    version = subprocess.check_output([binary, '--version']).decode()
    version = re.match(r'^[^\s]+ [^\s]+ (\d+)(\.(\d+))?', version)
    assert version is not None
    return '.'.join([version.group(1), version.group(3)]) if int(version.group(1)) < 10 else version.group(1)


@@ -251,8 +251,8 @@ class Result(object):
    :ivar error: error message if the validation failed, otherwise ``None``.
    """

    def __init__(self, status: bool, error: OptionalType[str] = "didn't pass validation", level: OptionalType[int] = 0,
                 path: OptionalType[str] = "", data: OptionalType[Any] = "") -> None:
    def __init__(self, status: bool, error: OptionalType[str] = "didn't pass validation", level: int = 0,
                 path: str = "", data: Any = "") -> None:
        """Create a :class:`Result` object based on the given arguments.

        .. note::
@@ -290,7 +290,7 @@ class Case(object):
        them, if they are set.
    """

    def __init__(self, schema: dict) -> None:
    def __init__(self, schema: Dict[str, Any]) -> None:
        """Create a :class:`Case` object.

        :param schema: the schema for validating a set of attributes that may be available in the configuration.
@@ -317,7 +317,7 @@ class Or(object):
        validation functions and/or expected types for a given configuration option.
    """

    def __init__(self, *args) -> None:
    def __init__(self, *args: Any) -> None:
        """Create an :class:`Or` object.

        :param `*args`: any arguments that the caller wants to be stored in this :class:`Or` object.
@@ -486,7 +486,7 @@ class Schema(object):
        :param data: configuration to be validated against ``validator``.
        :returns: list of errors identified while validating the *data*, if any.
        """
        errors = []
        errors: List[str] = []
        for i in self.validate(data):
            if not i.status:
                errors.append(str(i))
@@ -533,14 +533,13 @@ class Schema(object):
            except Exception as e:
                yield Result(False, "didn't pass validation: {}".format(e), data=self.data)
        elif isinstance(self.validator, dict):
            if not len(self.validator):
            if not isinstance(self.data, dict):
                yield Result(isinstance(self.data, dict), "is not a dictionary", level=1, data=self.data)
        elif isinstance(self.validator, list):
            if not isinstance(self.data, list):
                yield Result(isinstance(self.data, list), "is not a list", level=1, data=self.data)
                return
        for i in self.iter():
            yield i
        yield from self.iter()

    def iter(self) -> Iterator[Result]:
        """Iterate over ``validator``, if it is an iterable object, to validate the corresponding entries in ``data``.
@@ -553,12 +552,11 @@ class Schema(object):
            if not isinstance(self.data, dict):
                yield Result(False, "is not a dictionary.", level=1)
            else:
                for i in self.iter_dict():
                    yield i
                yield from self.iter_dict()
        elif isinstance(self.validator, list):
            if len(self.data) == 0:
                yield Result(False, "is an empty list", data=self.data)
            if len(self.validator) > 0:
            if self.validator:
                for key, value in enumerate(self.data):
                    # Although the value in the configuration (`data`) is expected to contain 1 or more entries, only
                    # the first validator defined in `validator` property list will be used. It is only defined as a
@@ -569,11 +567,9 @@ class Schema(object):
                    yield Result(v.status, v.error,
                                 path=(str(key) + ("." + v.path if v.path else "")), level=v.level, data=value)
        elif isinstance(self.validator, Directory):
            for v in self.validator.validate(self.data):
                yield v
            yield from self.validator.validate(self.data)
        elif isinstance(self.validator, Or):
            for i in self.iter_or():
                yield i
            yield from self.iter_or()

    def iter_dict(self) -> Iterator[Result]:
        """Iterate over a :class:`dict` based ``validator`` to validate the corresponding entries in ``data``.
@@ -606,14 +602,14 @@ class Schema(object):

        :rtype: Iterator[:class:`Result`] objects with the error message related to the failure, if any check fails.
        """
        results = []
        results: List[Result] = []
        for a in self.validator.args:
            r = []
            r: List[Result] = []
            # Each of the `Or` validators can throw 0 to many `Result` instances.
            for v in Schema(a).validate(self.data):
                r.append(v)
            if any([x.status for x in r]) and not all([x.status for x in r]):
                results += filter(lambda x: not x.status, r)
                results += [x for x in r if not x.status]
            else:
                results += r
        # None of the `Or` validators succeeded to validate `data`, so we report the issues back.
@@ -664,12 +660,12 @@ def _get_type_name(python_type: Any) -> str:
    Returns:
        User friendly name of the given Python type.
    """
    return {str: 'a string', int: 'an integer', float: 'a number',
            bool: 'a boolean', list: 'an array', dict: 'a dictionary'}.get(
                python_type, getattr(python_type, __name__, "unknown type"))
    types: Dict[Any, str] = {str: 'a string', int: 'an integer', float: 'a number',
                             bool: 'a boolean', list: 'an array', dict: 'a dictionary'}
    return types.get(python_type, getattr(python_type, __name__, "unknown type"))


def assert_(condition: bool, message: OptionalType[str] = "Wrong value") -> None:
def assert_(condition: bool, message: str = "Wrong value") -> None:
    """Assert that a given condition is ``True``.

    If the assertion fails, then throw a message.
@@ -680,14 +676,63 @@ def assert_(condition: bool, message: OptionalType[str] = "Wrong value") -> None
    assert condition, message


class IntValidator(object):
    """Validate an integer setting.

    :cvar expected_type: the expected Python type for an integer setting (:class:`int`).
    :ivar min: minimum allowed value for the setting, if any.
    :ivar max: maximum allowed value for the setting, if any.
    :ivar base_unit: the base unit to convert the value to before checking if it's within `min` and `max` range.
    :ivar raise_assert: if an ``assert`` call should be performed regarding expected type and valid range.
    """
    expected_type = int

    def __init__(self, min: OptionalType[int] = None, max: OptionalType[int] = None,
                 base_unit: OptionalType[str] = None, raise_assert: bool = False) -> None:
        """Create an :class:`IntValidator` object with the given rules.

        :param min: minimum allowed value for the setting, if any.
        :param max: maximum allowed value for the setting, if any.
        :param base_unit: the base unit to convert the value to before checking if it's within *min* and *max* range.
        :param raise_assert: if an ``assert`` call should be performed regarding expected type and valid range.
        """
        self.min = min
        self.max = max
        self.base_unit = base_unit
        self.raise_assert = raise_assert

    def __call__(self, value: Union[int, str]) -> bool:
        """Check if *value* is a valid integer and within the expected range.

        .. note::
            If ``raise_assert`` is ``True`` and *value* is not valid, then an ``AssertionError`` will be triggered.
        :param value: value to be checked against the rules defined for this :class:`IntValidator` instance.
        :returns: ``True`` if *value* is valid and within the expected range.
        """
        if self.base_unit:
            value = parse_int(value, self.base_unit) or ""
        ret = isinstance(value, int)\
            and (self.min is None or value >= self.min)\
            and (self.max is None or value <= self.max)

        if self.raise_assert:
            assert_(ret)
        return ret

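A usage sketch of IntValidator (illustrative; the unit-aware validator below is a hypothetical example, only the schema entries further down are part of the patch):

# Illustrative sketch of IntValidator behaviour.
queue_size = IntValidator(min=0, max=4096, raise_assert=True)
queue_size(1024)                                          # -> True
shared_buffers = IntValidator(min=128, base_unit='kB')    # hypothetical, unit-aware validator
shared_buffers('16MB')                                    # -> True, '16MB' is converted to 16384 kB first
shared_buffers(64)                                        # -> False, below the 128 kB minimum
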
def validate_watchdog_mode(value: Any) -> None:
    assert_(isinstance(value, (str, bool)), "expected type is not a string")
    assert_(value in (False, "off", "automatic", "required"))


userattributes = {"username": "", Optional("password"): ""}
available_dcs = [m.split(".")[-1] for m in dcs_modules()]
validate_host_port_list.expected_type = list
comma_separated_host_port.expected_type = str
validate_connect_address.expected_type = str
validate_host_port_listen.expected_type = str
validate_host_port_listen_multiple_hosts.expected_type = str
validate_data_dir.expected_type = str
setattr(validate_host_port_list, 'expected_type', list)
setattr(comma_separated_host_port, 'expected_type', str)
setattr(validate_connect_address, 'expected_type', str)
setattr(validate_host_port_listen, 'expected_type', str)
setattr(validate_host_port_listen_multiple_hosts, 'expected_type', str)
setattr(validate_data_dir, 'expected_type', str)
validate_etcd = {
    Or("host", "hosts", "srv", "srv_suffix", "url", "proxy"): Case({
        "host": validate_host_port,
@@ -704,7 +749,7 @@ schema = Schema({
    "restapi": {
        "listen": validate_host_port_listen,
        "connect_address": validate_connect_address,
        Optional("request_queue_size"): lambda i: assert_(0 <= int(i) <= 4096)
        Optional("request_queue_size"): IntValidator(min=0, max=4096, raise_assert=True)
    },
    Optional("bootstrap"): {
        "dcs": {
@@ -726,7 +771,7 @@ schema = Schema({
    "etcd3": validate_etcd,
    "exhibitor": {
        "hosts": [str],
        "port": lambda i: assert_(int(i) <= 65535),
        "port": IntValidator(max=65535, raise_assert=True),
        Optional("pool_interval"): int
    },
    "raft": {
@@ -767,7 +812,7 @@ schema = Schema({
        Optional("bin_dir"): Directory(contains_executable=["pg_ctl", "initdb", "pg_controldata", "pg_basebackup",
                                                            "postgres", "pg_isready"]),
        Optional("parameters"): {
            Optional("unix_socket_directories"): lambda s: assert_(all([isinstance(s, str), len(s)]))
            Optional("unix_socket_directories"): str
        },
        Optional("pg_hba"): [str],
        Optional("pg_ident"): [str],
@@ -775,7 +820,7 @@ schema = Schema({
        Optional("use_pg_rewind"): bool
    },
    Optional("watchdog"): {
        Optional("mode"): lambda m: assert_(m in ["off", "automatic", "required"]),
        Optional("mode"): validate_watchdog_mode,
        Optional("device"): str
    },
    Optional("tags"): {

@@ -4,7 +4,10 @@ import platform
import sys
from threading import RLock

from patroni.exceptions import WatchdogError
from typing import Any, Callable, Dict, Optional, Union

from ..config import Config
from ..exceptions import WatchdogError

__all__ = ['WatchdogError', 'Watchdog']

@@ -15,10 +18,10 @@ MODE_AUTOMATIC = 'automatic'  # Will use a watchdog if one is available
MODE_OFF = 'off'  # Will not try to use a watchdog


def parse_mode(mode):
def parse_mode(mode: Union[bool, str]) -> str:
    if mode is False:
        return MODE_OFF
    mode = mode.lower()
    mode = str(mode).lower()
    if mode in ['require', 'required']:
        return MODE_REQUIRED
    elif mode in ['auto', 'automatic']:
@@ -29,8 +32,8 @@ def parse_mode(mode):
        return MODE_OFF


def synchronized(func):
    def wrapped(self, *args, **kwargs):
def synchronized(func: Callable[..., Any]) -> Callable[..., Any]:
    def wrapped(self: 'Watchdog', *args: Any, **kwargs: Any) -> Any:
        with self._lock:
            return func(self, *args, **kwargs)
    return wrapped
@@ -38,7 +41,7 @@ def synchronized(func):

class WatchdogConfig(object):
    """Helper to contain a snapshot of configuration"""
    def __init__(self, config):
    def __init__(self, config: Config) -> None:
        watchdog_config = config.get("watchdog") or {'mode': 'automatic'}

        self.mode = parse_mode(watchdog_config.get('mode', 'automatic'))
@@ -49,15 +52,15 @@ class WatchdogConfig(object):
        self.driver_config = dict((k, v) for k, v in watchdog_config.items()
                                  if k not in ['mode', 'safety_margin', 'driver'])

    def __eq__(self, other):
    def __eq__(self, other: Any) -> bool:
        return isinstance(other, WatchdogConfig) and \
            all(getattr(self, attr) == getattr(other, attr) for attr in
                ['mode', 'ttl', 'loop_wait', 'safety_margin', 'driver', 'driver_config'])

    def __ne__(self, other):
    def __ne__(self, other: Any) -> bool:
        return not self == other

    def get_impl(self):
    def get_impl(self) -> 'WatchdogBase':
        if self.driver == 'testing':  # pragma: no cover
            from patroni.watchdog.linux import TestingWatchdogDevice
            return TestingWatchdogDevice.from_config(self.driver_config)
@@ -68,14 +71,14 @@ class WatchdogConfig(object):
            return NullWatchdog()

    @property
    def timeout(self):
    def timeout(self) -> int:
        if self.safety_margin == -1:
            return int(self.ttl // 2)
        else:
            return self.ttl - self.safety_margin

    @property
    def timing_slack(self):
    def timing_slack(self) -> int:
        return self.timeout - self.loop_wait

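The timeout computed above is simple arithmetic; a worked example with made-up configuration values (illustrative, not part of the patch):

# Illustrative sketch: ttl=30, loop_wait=10, safety_margin=5 gives a 25s watchdog timeout
# and 15s of timing slack; safety_margin=-1 would instead use half of the TTL (15s).
ttl, loop_wait, safety_margin = 30, 10, 5
timeout = ttl // 2 if safety_margin == -1 else ttl - safety_margin   # -> 25
timing_slack = timeout - loop_wait                                   # -> 15 (negative slack blocks activation)
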
@@ -84,8 +87,9 @@ class Watchdog(object):
|
||||
|
||||
When activation fails underlying implementation will be switched to a Null implementation. To avoid log spam
|
||||
activation will only be retried when watchdog configuration is changed."""
|
||||
def __init__(self, config):
|
||||
self.active_config = self.config = WatchdogConfig(config)
|
||||
def __init__(self, config: Config) -> None:
|
||||
self.config = WatchdogConfig(config)
|
||||
self.active_config: WatchdogConfig = self.config
|
||||
self._lock = RLock()
|
||||
self.active = False
|
||||
|
||||
@@ -98,7 +102,7 @@ class Watchdog(object):
|
||||
sys.exit(1)
|
||||
|
||||
@synchronized
|
||||
def reload_config(self, config):
|
||||
def reload_config(self, config: Config) -> None:
|
||||
self.config = WatchdogConfig(config)
|
||||
# Turning a watchdog off can always be done immediately
|
||||
if self.config.mode == MODE_OFF:
|
||||
@@ -115,7 +119,7 @@ class Watchdog(object):
|
||||
self.active_config = self.config
|
||||
|
||||
@synchronized
|
||||
def activate(self):
|
||||
def activate(self) -> bool:
|
||||
"""Activates the watchdog device with suitable timeouts. While watchdog is active keepalive needs
|
||||
to be called every time loop_wait expires.
|
||||
|
||||
@@ -124,7 +128,7 @@ class Watchdog(object):
|
||||
self.active = True
|
||||
return self._activate()
|
||||
|
||||
def _activate(self):
|
||||
def _activate(self) -> bool:
|
||||
self.active_config = self.config
|
||||
|
||||
if self.config.timing_slack < 0:
|
||||
@@ -138,12 +142,13 @@ class Watchdog(object):
|
||||
except WatchdogError as e:
|
||||
logger.warning("Could not activate %s: %s", self.impl.describe(), e)
|
||||
self.impl = NullWatchdog()
|
||||
actual_timeout = self.impl.get_timeout()
|
||||
|
||||
if self.impl.is_running and not self.impl.can_be_disabled:
|
||||
logger.warning("Watchdog implementation can't be disabled."
|
||||
" Watchdog will trigger after Patroni loses leader key.")
|
||||
|
||||
if not self.impl.is_running or actual_timeout > self.config.timeout:
|
||||
if not self.impl.is_running or actual_timeout and actual_timeout > self.config.timeout:
|
||||
if self.config.mode == MODE_REQUIRED:
|
||||
if self.impl.is_null:
|
||||
logger.error("Configuration requires watchdog, but watchdog could not be configured.")
|
||||
@@ -167,7 +172,7 @@ class Watchdog(object):
|
||||
|
||||
return True
|
||||
|
||||
def _set_timeout(self):
|
||||
def _set_timeout(self) -> Optional[int]:
|
||||
if self.impl.has_set_timeout():
|
||||
self.impl.set_timeout(self.config.timeout)
|
||||
|
||||
@@ -184,11 +189,11 @@ class Watchdog(object):
|
||||
return actual_timeout
|
||||
|
||||
@synchronized
|
||||
def disable(self):
|
||||
def disable(self) -> None:
|
||||
self._disable()
|
||||
self.active = False
|
||||
|
||||
def _disable(self):
|
||||
def _disable(self) -> None:
|
||||
try:
|
||||
if self.impl.is_running and not self.impl.can_be_disabled:
|
||||
# Give sysadmin some extra time to clean stuff up.
|
||||
@@ -200,7 +205,7 @@ class Watchdog(object):
|
||||
logger.error("Error while disabling watchdog: %s", e)
|
||||
|
||||
@synchronized
|
||||
def keepalive(self):
|
||||
def keepalive(self) -> None:
|
||||
try:
|
||||
if self.active:
|
||||
self.impl.keepalive()
|
||||
@@ -225,12 +230,12 @@ class Watchdog(object):
|
||||
|
||||
@property
|
||||
@synchronized
|
||||
def is_running(self):
|
||||
def is_running(self) -> bool:
|
||||
return self.impl.is_running
|
||||
|
||||
@property
|
||||
@synchronized
|
||||
def is_healthy(self):
|
||||
def is_healthy(self) -> bool:
|
||||
if self.config.mode != MODE_REQUIRED:
|
||||
return True
|
||||
return self.config.timing_slack >= 0 and self.impl.is_healthy
|
||||
@@ -242,60 +247,59 @@ class WatchdogBase(abc.ABC):
|
||||
is_null = False
|
||||
|
||||
@property
|
||||
def is_running(self):
|
||||
def is_running(self) -> bool:
|
||||
"""Returns True when watchdog is activated and capable of performing it's task."""
|
||||
return False
|
||||
|
||||
@property
|
||||
def is_healthy(self):
|
||||
def is_healthy(self) -> bool:
|
||||
"""Returns False when calling open() is known to fail."""
|
||||
return False
|
||||
|
||||
@property
|
||||
def can_be_disabled(self):
|
||||
def can_be_disabled(self) -> bool:
|
||||
"""Returns True when watchdog will be disabled by calling close(). Some watchdog devices
|
||||
will keep running no matter what once activated. May raise WatchdogError if called without
|
||||
calling open() first."""
|
||||
return True
|
||||
|
||||
@abc.abstractmethod
|
||||
def open(self):
|
||||
def open(self) -> None:
|
||||
"""Open watchdog device.
|
||||
|
||||
When watchdog is opened keepalive must be called. Returns nothing on success
|
||||
or raises WatchdogError if the device could not be opened."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
"""Gracefully close watchdog device."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def keepalive(self):
|
||||
def keepalive(self) -> None:
|
||||
"""Resets the watchdog timer.
|
||||
|
||||
Watchdog must be open when keepalive is called."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_timeout(self):
|
||||
def get_timeout(self) -> int:
|
||||
"""Returns the current keepalive timeout in effect."""
|
||||
|
||||
@staticmethod
|
||||
def has_set_timeout():
|
||||
def has_set_timeout(self) -> bool:
|
||||
"""Returns True if setting a timeout is supported."""
|
||||
return False
|
||||
|
||||
def set_timeout(self, timeout):
|
||||
def set_timeout(self, timeout: int) -> None:
|
||||
"""Set the watchdog timer timeout.
|
||||
|
||||
:param timeout: watchdog timeout in seconds"""
|
||||
raise WatchdogError("Setting timeout is not supported on {0}".format(self.describe()))
|
||||
|
||||
def describe(self):
|
||||
def describe(self) -> str:
|
||||
"""Human readable name for this device"""
|
||||
return self.__class__.__name__
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config):
|
||||
def from_config(cls, config: Dict[str, Any]) -> 'WatchdogBase':
|
||||
return cls()
|
||||
|
||||
|
||||
@@ -303,15 +307,15 @@ class NullWatchdog(WatchdogBase):
|
||||
"""Null implementation when watchdog is not supported."""
|
||||
is_null = True
|
||||
|
||||
def open(self):
|
||||
def open(self) -> None:
|
||||
return
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
return
|
||||
|
||||
def keepalive(self):
|
||||
def keepalive(self) -> None:
|
||||
return
|
||||
|
||||
def get_timeout(self):
|
||||
def get_timeout(self) -> int:
|
||||
# A big enough number to not matter
|
||||
return 1000000000
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import collections
|
||||
import ctypes
|
||||
import os
|
||||
import platform
|
||||
from patroni.watchdog.base import WatchdogBase, WatchdogError
|
||||
|
||||
from typing import Any, Dict, NamedTuple
|
||||
|
||||
from .base import WatchdogBase, WatchdogError
|
||||
|
||||
# Pythonification of linux/ioctl.h
|
||||
IOC_NONE = 0
|
||||
@@ -19,7 +21,7 @@ machine = platform.machine()
|
||||
if machine in ['mips', 'sparc', 'powerpc', 'ppc64', 'ppc64le']: # pragma: no cover
|
||||
IOC_SIZEBITS = 13
|
||||
IOC_DIRBITS = 3
|
||||
IOC_NONE, IOC_WRITE, IOC_READ = 1, 4, 2
|
||||
IOC_NONE, IOC_WRITE = 1, 4
|
||||
elif machine == 'parisc': # pragma: no cover
|
||||
IOC_WRITE, IOC_READ = 2, 1
|
||||
|
||||
@@ -29,19 +31,19 @@ IOC_SIZESHIFT = IOC_TYPESHIFT + IOC_TYPEBITS
|
||||
IOC_DIRSHIFT = IOC_SIZESHIFT + IOC_SIZEBITS
|
||||
|
||||
|
||||
def IOW(type_, nr, size):
|
||||
def IOW(type_: str, nr: int, size: int) -> int:
|
||||
return IOC(IOC_WRITE, type_, nr, size)
|
||||
|
||||
|
||||
def IOR(type_, nr, size):
|
||||
def IOR(type_: str, nr: int, size: int) -> int:
|
||||
return IOC(IOC_READ, type_, nr, size)
|
||||
|
||||
|
||||
def IOWR(type_, nr, size):
|
||||
def IOWR(type_: str, nr: int, size: int) -> int:
|
||||
return IOC(IOC_READ | IOC_WRITE, type_, nr, size)
|
||||
|
||||
|
||||
def IOC(dir_, type_, nr, size):
|
||||
def IOC(dir_: int, type_: str, nr: int, size: int) -> int:
|
||||
return (dir_ << IOC_DIRSHIFT) \
|
||||
| (ord(type_) << IOC_TYPESHIFT) \
|
||||
| (nr << IOC_NRSHIFT) \
|
||||
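IOW/IOR/IOWR mirror the kernel's _IOW/_IOR/_IOWR macros. On the default bit layout, the WDIOC_* request numbers used further down are equivalent to compositions like the following (a sketch for orientation, not how the module literally spells them; watchdog_info is the ctypes.Structure defined in this module):

import ctypes

WDIOC_GETSUPPORT_EXAMPLE = IOR('W', 0, ctypes.sizeof(watchdog_info))
WDIOC_KEEPALIVE_EXAMPLE = IOR('W', 5, ctypes.sizeof(ctypes.c_int))
WDIOC_SETTIMEOUT_EXAMPLE = IOWR('W', 6, ctypes.sizeof(ctypes.c_int))
WDIOC_GETTIMEOUT_EXAMPLE = IOR('W', 7, ctypes.sizeof(ctypes.c_int))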
@@ -104,9 +106,13 @@ WDIOS = {
|
||||
# Implementation
|
||||
|
||||
|
||||
class WatchdogInfo(collections.namedtuple('WatchdogInfo', 'options,version,identity')):
|
||||
class WatchdogInfo(NamedTuple):
|
||||
"""Watchdog descriptor from the kernel"""
|
||||
def __getattr__(self, name):
|
||||
options: int
|
||||
version: int
|
||||
identity: str
|
||||
|
||||
def __getattr__(self, name: str) -> bool:
|
||||
"""Convenience has_XYZ attributes for checking WDIOF bits in options"""
|
||||
if name.startswith('has_') and name[4:] in WDIOF:
|
||||
return bool(self.options & WDIOF[name[4:]])
|
||||
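Because WatchdogInfo is now a NamedTuple, the three fields are statically typed while __getattr__ keeps the dynamic has_* accessors working; a short sketch (both WDIOF keys appear in the WDIOF table defined earlier in this module):

info = WatchdogInfo(options=WDIOF['SETTIMEOUT'], version=0, identity='Software Watchdog')
info.identity          # 'Software Watchdog' -- a typed NamedTuple field
info.has_SETTIMEOUT    # True  -- resolved by __getattr__ from the options bits
info.has_MAGICCLOSE    # False -- the bit is not set in options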
@@ -117,32 +123,32 @@ class WatchdogInfo(collections.namedtuple('WatchdogInfo', 'options,version,ident
|
||||
class LinuxWatchdogDevice(WatchdogBase):
|
||||
DEFAULT_DEVICE = '/dev/watchdog'
|
||||
|
||||
def __init__(self, device):
|
||||
def __init__(self, device: str) -> None:
|
||||
self.device = device
|
||||
self._support_cache = None
|
||||
self._fd = None
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config):
|
||||
def from_config(cls, config: Dict[str, Any]) -> 'LinuxWatchdogDevice':
|
||||
device = config.get('device', cls.DEFAULT_DEVICE)
|
||||
return cls(device)
|
||||
|
||||
@property
|
||||
def is_running(self):
|
||||
def is_running(self) -> bool:
|
||||
return self._fd is not None
|
||||
|
||||
@property
|
||||
def is_healthy(self):
|
||||
def is_healthy(self) -> bool:
|
||||
return os.path.exists(self.device) and os.access(self.device, os.W_OK)
|
||||
|
||||
def open(self):
|
||||
def open(self) -> None:
|
||||
try:
|
||||
self._fd = os.open(self.device, os.O_WRONLY)
|
||||
except OSError as e:
|
||||
raise WatchdogError("Can't open watchdog device: {0}".format(e))
|
||||
|
||||
def close(self):
|
||||
if self.is_running:
|
||||
def close(self) -> None:
|
||||
if self._fd is not None: # self.is_running
|
||||
try:
|
||||
os.write(self._fd, b'V')
|
||||
os.close(self._fd)
|
||||
@@ -151,10 +157,10 @@ class LinuxWatchdogDevice(WatchdogBase):
|
||||
raise WatchdogError("Error while closing {0}: {1}".format(self.describe(), e))
|
||||
|
||||
@property
|
||||
def can_be_disabled(self):
|
||||
def can_be_disabled(self) -> bool:
|
||||
return self.get_support().has_MAGICCLOSE
|
||||
|
||||
def _ioctl(self, func, arg):
|
||||
def _ioctl(self, func: int, arg: Any) -> None:
|
||||
"""Runs the specified ioctl on the underlying fd.
|
||||
|
||||
Raises WatchdogError if the device is closed.
|
||||
@@ -165,7 +171,7 @@ class LinuxWatchdogDevice(WatchdogBase):
|
||||
import fcntl
|
||||
fcntl.ioctl(self._fd, func, arg, True)
|
||||
|
||||
def get_support(self):
|
||||
def get_support(self) -> WatchdogInfo:
|
||||
if self._support_cache is None:
|
||||
info = watchdog_info()
|
||||
try:
|
||||
@@ -177,7 +183,7 @@ class LinuxWatchdogDevice(WatchdogBase):
|
||||
bytearray(info.identity).decode(errors='ignore').rstrip('\x00'))
|
||||
return self._support_cache
|
||||
|
||||
def describe(self):
|
||||
def describe(self) -> str:
|
||||
dev_str = " at {0}".format(self.device) if self.device != self.DEFAULT_DEVICE else ""
|
||||
ver_str = ""
|
||||
identity = "Linux watchdog device"
|
||||
@@ -190,17 +196,19 @@ class LinuxWatchdogDevice(WatchdogBase):
|
||||
|
||||
return identity + ver_str + dev_str
|
||||
|
||||
def keepalive(self):
|
||||
def keepalive(self) -> None:
|
||||
if self._fd is None:
|
||||
raise WatchdogError("Watchdog device is closed")
|
||||
try:
|
||||
os.write(self._fd, b'1')
|
||||
except OSError as e:
|
||||
raise WatchdogError("Could not send watchdog keepalive: {0}".format(e))
|
||||
|
||||
def has_set_timeout(self):
|
||||
def has_set_timeout(self) -> bool:
|
||||
"""Returns True if setting a timeout is supported."""
|
||||
return self.get_support().has_SETTIMEOUT
|
||||
|
||||
def set_timeout(self, timeout):
|
||||
def set_timeout(self, timeout: int) -> None:
|
||||
timeout = int(timeout)
|
||||
if not 0 < timeout < 0xFFFF:
|
||||
raise WatchdogError("Invalid timeout {0}. Supported values are between 1 and 65535".format(timeout))
|
||||
@@ -209,7 +217,7 @@ class LinuxWatchdogDevice(WatchdogBase):
|
||||
except (WatchdogError, OSError, IOError) as e:
|
||||
raise WatchdogError("Could not set timeout on watchdog device: {}".format(e))
|
||||
|
||||
def get_timeout(self):
|
||||
def get_timeout(self) -> int:
|
||||
timeout = ctypes.c_int()
|
||||
try:
|
||||
self._ioctl(WDIOC_GETTIMEOUT, timeout)
|
||||
@@ -222,14 +230,16 @@ class TestingWatchdogDevice(LinuxWatchdogDevice): # pragma: no cover
|
||||
"""Converts timeout ioctls to regular writes that can be intercepted from a named pipe."""
|
||||
timeout = 60
|
||||
|
||||
def get_support(self):
|
||||
def get_support(self) -> WatchdogInfo:
|
||||
return WatchdogInfo(WDIOF['MAGICCLOSE'] | WDIOF['SETTIMEOUT'], 0, "Watchdog test harness")
|
||||
|
||||
def set_timeout(self, timeout):
|
||||
def set_timeout(self, timeout: int) -> None:
|
||||
if self._fd is None:
|
||||
raise WatchdogError("Watchdog device is closed")
|
||||
buf = "Ctimeout={0}\n".format(timeout).encode('utf8')
|
||||
while len(buf):
|
||||
buf = buf[os.write(self._fd, buf):]
|
||||
self.timeout = timeout
|
||||
|
||||
def get_timeout(self):
|
||||
def get_timeout(self) -> int:
|
||||
return self.timeout
|
||||
|
||||
27 pyrightconfig.json Normal file
@@ -0,0 +1,27 @@
|
||||
{
|
||||
"include": [
|
||||
"patroni"
|
||||
],
|
||||
|
||||
"exclude": [
|
||||
"**/__pycache__"
|
||||
],
|
||||
|
||||
"ignore": [
|
||||
],
|
||||
|
||||
"defineConstant": {
|
||||
"DEBUG": true
|
||||
},
|
||||
|
||||
"stubPath": "typings/",
|
||||
|
||||
"reportMissingImports": true,
|
||||
"reportMissingTypeStubs": false,
|
||||
|
||||
"pythonVersion": "3.6",
|
||||
"pythonPlatform": "All",
|
||||
|
||||
"typeCheckingMode": "strict"
|
||||
|
||||
}
|
||||
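With typeCheckingMode set to strict, pyright flags implicit Any and partially unknown container types, which is what most of the annotations in this commit exist to satisfy. A generic illustration (not taken from the code base) of the kind of change strict mode forces:

from typing import Any, Dict

untyped = {}                      # strict mode: declared type is partially unknown
typed: Dict[str, Any] = {}        # strict mode: accepted, the container type is explicit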
@@ -23,6 +23,7 @@ class MockResponse(object):
|
||||
|
||||
def __init__(self, status_code=200):
|
||||
self.status_code = status_code
|
||||
self.headers = {'content-type': 'json'}
|
||||
self.content = '{}'
|
||||
self.reason = 'Not Found'
|
||||
|
||||
@@ -38,10 +39,6 @@ class MockResponse(object):
|
||||
def getheader(*args):
|
||||
return ''
|
||||
|
||||
@staticmethod
|
||||
def getheaders():
|
||||
return {'content-type': 'json'}
|
||||
|
||||
|
||||
def requests_get(url, method='GET', endpoint=None, data='', **kwargs):
|
||||
members = '[{"id":14855829450254237642,"peerURLs":["http://localhost:2380","http://localhost:7001"],' +\
|
||||
@@ -85,6 +82,8 @@ class MockCursor(object):
|
||||
self.description = [Mock()]
|
||||
|
||||
def execute(self, sql, *params):
|
||||
if isinstance(sql, bytes):
|
||||
sql = sql.decode('utf-8')
|
||||
if sql.startswith('blabla'):
|
||||
raise psycopg.ProgrammingError()
|
||||
elif sql == 'CHECKPOINT' or sql.startswith('SELECT pg_catalog.pg_create_'):
|
||||
@@ -98,7 +97,7 @@ class MockCursor(object):
|
||||
elif sql.startswith('SELECT slot_name'):
|
||||
self.results = [('blabla', 'physical'), ('foobar', 'physical'), ('ls', 'logical', 'a', 'b', 5, 100, 500)]
|
||||
elif sql.startswith('WITH slots AS (SELECT slot_name, active'):
|
||||
self.results = [(False, True)]
|
||||
self.results = [(False, True)] if self.rowcount == 1 else [None]
|
||||
elif sql.startswith('SELECT CASE WHEN pg_catalog.pg_is_in_recovery()'):
|
||||
self.results = [(1, 2, 1, 0, False, 1, 1, None, None,
|
||||
[{"slot_name": "ls", "confirmed_flush_lsn": 12345}],
|
||||
|
||||
@@ -136,8 +136,8 @@ class TestCitus(BaseTestPostgresql):
|
||||
self.assertEqual(parameters['shared_preload_libraries'], 'citus,foo,bar')
|
||||
self.assertEqual(parameters['wal_level'], 'logical')
|
||||
|
||||
@patch.object(CitusHandler, 'is_enabled', Mock(return_value=False))
|
||||
def test_bootstrap(self):
|
||||
self.c._config = None
|
||||
self.c.bootstrap()
|
||||
|
||||
def test_ignore_replication_slot(self):
|
||||
|
||||
@@ -19,8 +19,9 @@ class TestConfig(unittest.TestCase):
|
||||
|
||||
def test_set_dynamic_configuration(self):
|
||||
with patch.object(Config, '_build_effective_configuration', Mock(side_effect=Exception)):
|
||||
self.assertIsNone(self.config.set_dynamic_configuration({'foo': 'bar'}))
|
||||
self.assertTrue(self.config.set_dynamic_configuration({'standby_cluster': {}}))
|
||||
self.assertFalse(self.config.set_dynamic_configuration({'foo': 'bar'}))
|
||||
self.assertTrue(self.config.set_dynamic_configuration({'standby_cluster': {}, 'postgresql': {
|
||||
'parameters': {'cluster_name': 1, 'wal_keep_size': 1, 'track_commit_timestamp': 1, 'wal_level': 1}}}))
|
||||
|
||||
def test_reload_local_configuration(self):
|
||||
os.environ.update({
|
||||
|
||||
@@ -195,6 +195,8 @@ class TestConsul(unittest.TestCase):
|
||||
@patch.object(consul.Consul.KV, 'delete', Mock(return_value=True))
|
||||
def test_delete_leader(self):
|
||||
self.c.delete_leader()
|
||||
self.c._name = 'other'
|
||||
self.c.delete_leader()
|
||||
|
||||
@patch.object(consul.Consul.KV, 'put', Mock(return_value=True))
|
||||
def test_initialize(self):
|
||||
|
||||
@@ -9,7 +9,7 @@ from mock import patch, Mock, PropertyMock
|
||||
from patroni.ctl import ctl, load_config, output_members, get_dcs, parse_dcs, \
|
||||
get_all_members, get_any_member, get_cursor, query_member, PatroniCtlException, apply_config_changes, \
|
||||
format_config_for_editing, show_diff, invoke_editor, format_pg_version, CONFIG_FILE_PATH, PatronictlPrettyTable
|
||||
from patroni.dcs.etcd import AbstractEtcdClientWithFailover, Failover
|
||||
from patroni.dcs.etcd import AbstractEtcdClientWithFailover, Cluster, Failover
|
||||
from patroni.psycopg import OperationalError
|
||||
from patroni.utils import tzutc
|
||||
from prettytable import PrettyTable, ALL
|
||||
@@ -639,6 +639,10 @@ class TestCtl(unittest.TestCase):
|
||||
self.runner.invoke(ctl, ['edit-config', 'dummy', '--force', '--apply', '-'], input='foo: bar')
|
||||
mock_get_dcs.return_value.set_config_value.return_value = True
|
||||
self.runner.invoke(ctl, ['edit-config', 'dummy', '--force', '--apply', '-'], input='foo: bar')
|
||||
mock_get_dcs.return_value.get_cluster = Mock(return_value=Cluster.empty())
|
||||
result = self.runner.invoke(ctl, ['edit-config', 'dummy'])
|
||||
assert result.exit_code == 1
|
||||
assert 'The config key does not exist in the cluster dummy' in result.output
|
||||
|
||||
@patch('patroni.ctl.get_dcs')
|
||||
def test_version(self, mock_get_dcs):
|
||||
|
||||
@@ -138,7 +138,9 @@ class TestClient(unittest.TestCase):
|
||||
@patch.object(EtcdClient, '_get_machines_list',
|
||||
Mock(return_value=['http://localhost:2379', 'http://localhost:4001']))
|
||||
def setUp(self):
|
||||
self.client = EtcdClient({'srv': 'test', 'retry_timeout': 3}, DnsCachingResolver())
|
||||
self.etcd = Etcd({'namespace': '/patroni/', 'ttl': 30, 'retry_timeout': 3,
|
||||
'srv': 'test', 'scope': 'test', 'name': 'foo'})
|
||||
self.client = self.etcd._client
|
||||
self.client.http.request = http_request
|
||||
self.client.http.request_encode_body = http_request
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ import json
|
||||
import unittest
|
||||
import urllib3
|
||||
|
||||
from mock import Mock, patch
|
||||
from mock import Mock, PropertyMock, patch
|
||||
from patroni.dcs.etcd import DnsCachingResolver
|
||||
from patroni.dcs.etcd3 import PatroniEtcd3Client, Cluster, Etcd3Client, Etcd3Error, Etcd3ClientError, RetryFailedError,\
|
||||
InvalidAuthToken, Unavailable, Unknown, UnsupportedEtcdVersion, UserEmpty, AuthFailed, base64_encode, Etcd3
|
||||
@@ -85,6 +85,11 @@ class BaseTestEtcd3(unittest.TestCase):
|
||||
|
||||
class TestKVCache(BaseTestEtcd3):
|
||||
|
||||
@patch.object(urllib3.PoolManager, 'urlopen', mock_urlopen)
|
||||
@patch.object(Etcd3Client, 'watchprefix', Mock(return_value=urllib3.response.HTTPResponse()))
|
||||
def test__build_cache(self):
|
||||
self.kv_cache._build_cache()
|
||||
|
||||
def test__do_watch(self):
|
||||
self.client.watchprefix = Mock(return_value=False)
|
||||
self.assertRaises(AttributeError, self.kv_cache._do_watch, '1')
|
||||
@@ -94,14 +99,17 @@ class TestKVCache(BaseTestEtcd3):
|
||||
def test_run(self):
|
||||
self.assertRaises(SleepException, self.kv_cache.run)
|
||||
|
||||
@patch.object(urllib3.PoolManager, 'urlopen', mock_urlopen)
|
||||
@patch.object(urllib3.response.HTTPResponse, 'read_chunked',
|
||||
Mock(return_value=[b'{"error":{"grpc_code":14,"message":"","http_code":503}}']))
|
||||
@patch.object(Etcd3Client, 'watchprefix', Mock(return_value=urllib3.response.HTTPResponse()))
|
||||
def test_kill_stream(self):
|
||||
self.assertRaises(Unavailable, self.kv_cache._do_watch, '1')
|
||||
self.kv_cache.kill_stream()
|
||||
with patch.object(MockResponse, 'connection', create=True) as mock_conn:
|
||||
with patch.object(urllib3.response.HTTPResponse, 'connection') as mock_conn:
|
||||
self.kv_cache.kill_stream()
|
||||
mock_conn.sock.close.side_effect = Exception
|
||||
self.kv_cache.kill_stream()
|
||||
type(mock_conn).sock = PropertyMock(side_effect=Exception)
|
||||
self.kv_cache.kill_stream()
|
||||
|
||||
|
||||
class TestPatroniEtcd3Client(BaseTestEtcd3):
|
||||
@@ -180,7 +188,8 @@ class TestEtcd3(BaseTestEtcd3):
|
||||
@patch.object(urllib3.PoolManager, 'urlopen', mock_urlopen)
|
||||
def setUp(self):
|
||||
super(TestEtcd3, self).setUp()
|
||||
self.assertRaises(AttributeError, self.kv_cache._build_cache)
|
||||
# self.assertRaises(AttributeError, self.kv_cache._build_cache)
|
||||
self.kv_cache._build_cache()
|
||||
self.kv_cache._is_ready = True
|
||||
self.etcd3.get_cluster()
|
||||
|
||||
@@ -276,6 +285,8 @@ class TestEtcd3(BaseTestEtcd3):
|
||||
|
||||
def test_delete_leader(self):
|
||||
self.etcd3.delete_leader()
|
||||
self.etcd3._name = 'other'
|
||||
self.etcd3.delete_leader()
|
||||
|
||||
def test_delete_cluster(self):
|
||||
self.etcd3.delete_cluster()
|
||||
|
||||
@@ -244,7 +244,6 @@ class TestHa(PostgresInit):
|
||||
def test_bootstrap_as_standby_leader(self, initialize):
|
||||
self.p.data_directory_empty = true
|
||||
self.ha.cluster = get_cluster_not_initialized_without_leader(cluster_config=ClusterConfig(0, {}, 0))
|
||||
self.ha.cluster.is_unlocked = true
|
||||
self.ha.patroni.config._dynamic_configuration = {"standby_cluster": {"port": 5432}}
|
||||
self.assertEqual(self.ha.run_cycle(), 'trying to bootstrap a new standby leader')
|
||||
|
||||
@@ -261,7 +260,6 @@ class TestHa(PostgresInit):
|
||||
def test_start_as_cascade_replica_in_standby_cluster(self):
|
||||
self.p.data_directory_empty = true
|
||||
self.ha.cluster = get_standby_cluster_initialized_with_only_leader()
|
||||
self.ha.cluster.is_unlocked = false
|
||||
self.assertEqual(self.ha.run_cycle(), "trying to bootstrap from replica 'test'")
|
||||
|
||||
def test_recover_replica_failed(self):
|
||||
@@ -381,8 +379,8 @@ class TestHa(PostgresInit):
|
||||
self.ha._async_executor.schedule('promote')
|
||||
self.assertEqual(self.ha.run_cycle(), 'lost leader before promote')
|
||||
|
||||
@patch.object(Cluster, 'is_unlocked', Mock(return_value=False))
|
||||
def test_long_promote(self):
|
||||
self.ha.cluster.is_unlocked = false
|
||||
self.ha.has_lock = true
|
||||
self.p.is_leader = false
|
||||
self.p.set_role('primary')
|
||||
@@ -407,14 +405,13 @@ class TestHa(PostgresInit):
|
||||
self.p.is_leader = false
|
||||
self.assertEqual(self.ha.run_cycle(), 'following a different leader because i am not the healthiest node')
|
||||
|
||||
@patch.object(Cluster, 'is_unlocked', Mock(return_value=False))
|
||||
def test_promote_because_have_lock(self):
|
||||
self.ha.cluster.is_unlocked = false
|
||||
self.ha.has_lock = true
|
||||
self.p.is_leader = false
|
||||
self.assertEqual(self.ha.run_cycle(), 'promoted self to leader because I had the session lock')
|
||||
|
||||
def test_promote_without_watchdog(self):
|
||||
self.ha.cluster.is_unlocked = false
|
||||
self.ha.has_lock = true
|
||||
self.p.is_leader = true
|
||||
with patch.object(Watchdog, 'activate', Mock(return_value=False)):
|
||||
@@ -424,24 +421,22 @@ class TestHa(PostgresInit):
|
||||
|
||||
def test_leader_with_lock(self):
|
||||
self.ha.cluster = get_cluster_initialized_with_leader()
|
||||
self.ha.cluster.is_unlocked = false
|
||||
self.ha.has_lock = true
|
||||
self.assertEqual(self.ha.run_cycle(), 'no action. I am (postgresql0), the leader with the lock')
|
||||
|
||||
def test_coordinator_leader_with_lock(self):
|
||||
self.ha.cluster = get_cluster_initialized_with_leader()
|
||||
self.ha.cluster.is_unlocked = false
|
||||
self.ha.has_lock = true
|
||||
self.assertEqual(self.ha.run_cycle(), 'no action. I am (postgresql0), the leader with the lock')
|
||||
|
||||
@patch.object(Postgresql, '_wait_for_connection_close', Mock())
|
||||
@patch.object(Cluster, 'is_unlocked', Mock(return_value=False))
|
||||
def test_demote_because_not_having_lock(self):
|
||||
self.ha.cluster.is_unlocked = false
|
||||
with patch.object(Watchdog, 'is_running', PropertyMock(return_value=True)):
|
||||
self.assertEqual(self.ha.run_cycle(), 'demoting self because I do not have the lock and I was a leader')
|
||||
|
||||
@patch.object(Cluster, 'is_unlocked', Mock(return_value=False))
|
||||
def test_demote_because_update_lock_failed(self):
|
||||
self.ha.cluster.is_unlocked = false
|
||||
self.ha.has_lock = true
|
||||
self.ha.update_lock = false
|
||||
self.assertEqual(self.ha.run_cycle(), 'demoted self because failed to update leader lock in DCS')
|
||||
@@ -450,8 +445,8 @@ class TestHa(PostgresInit):
|
||||
self.p.is_leader = false
|
||||
self.assertEqual(self.ha.run_cycle(), 'not promoting because failed to update leader lock in DCS')
|
||||
|
||||
@patch.object(Cluster, 'is_unlocked', Mock(return_value=False))
|
||||
def test_follow(self):
|
||||
self.ha.cluster.is_unlocked = false
|
||||
self.p.is_leader = false
|
||||
self.assertEqual(self.ha.run_cycle(), 'no action. I am (postgresql0), a secondary, and following a leader ()')
|
||||
self.ha.patroni.replicatefrom = "foo"
|
||||
@@ -465,8 +460,8 @@ class TestHa(PostgresInit):
|
||||
self.assertEqual(self.ha.run_cycle(), 'no action. I am (postgresql0), a secondary, and following a leader ()')
|
||||
del self.ha.cluster.config.data['postgresql']['use_slots']
|
||||
|
||||
@patch.object(Cluster, 'is_unlocked', Mock(return_value=False))
|
||||
def test_follow_in_pause(self):
|
||||
self.ha.cluster.is_unlocked = false
|
||||
self.ha.is_paused = true
|
||||
self.assertEqual(self.ha.run_cycle(), 'PAUSE: continue to run as primary without lock')
|
||||
self.p.is_leader = false
|
||||
@@ -659,10 +654,12 @@ class TestHa(PostgresInit):
|
||||
|
||||
self.ha.update_lock = false
|
||||
self.p.set_role('primary')
|
||||
with patch('patroni.async_executor.CriticalTask.cancel', Mock(return_value=False)):
|
||||
with patch('patroni.postgresql.Postgresql.terminate_starting_postmaster') as mock_terminate:
|
||||
self.assertEqual(self.ha.run_cycle(), 'lost leader lock during restart')
|
||||
mock_terminate.assert_called()
|
||||
with patch('patroni.async_executor.CriticalTask.cancel', Mock(return_value=False)),\
|
||||
patch('patroni.async_executor.CriticalTask.result',
|
||||
PropertyMock(return_value=PostmasterProcess(os.getpid())), create=True),\
|
||||
patch('patroni.postgresql.Postgresql.terminate_starting_postmaster') as mock_terminate:
|
||||
self.assertEqual(self.ha.run_cycle(), 'lost leader lock during restart')
|
||||
mock_terminate.assert_called()
|
||||
|
||||
self.ha.is_paused = true
|
||||
self.assertEqual(self.ha.run_cycle(), 'PAUSE: restart in progress')
|
||||
@@ -741,7 +738,6 @@ class TestHa(PostgresInit):
|
||||
def test_manual_failover_process_no_leader(self):
|
||||
self.p.is_leader = false
|
||||
self.ha.cluster = get_cluster_initialized_without_leader(failover=Failover(0, '', self.p.name, None))
|
||||
self.assertEqual(self.ha.run_cycle(), 'promoted self to leader by acquiring session lock')
|
||||
self.ha.cluster = get_cluster_initialized_without_leader(failover=Failover(0, '', 'leader', None))
|
||||
self.p.set_role('replica')
|
||||
self.assertEqual(self.ha.run_cycle(), 'promoted self to leader by acquiring session lock')
|
||||
@@ -847,10 +843,10 @@ class TestHa(PostgresInit):
|
||||
self.assertFalse(self.ha.is_healthiest_node())
|
||||
|
||||
def test__is_healthiest_node(self):
|
||||
self.p.is_leader = false
|
||||
self.ha.cluster = get_cluster_initialized_without_leader(sync=('postgresql1', self.p.name))
|
||||
self.ha.global_config = self.ha.patroni.config.get_global_config(self.ha.cluster)
|
||||
self.assertTrue(self.ha._is_healthiest_node(self.ha.old_cluster.members))
|
||||
self.p.is_leader = false
|
||||
self.ha.fetch_node_status = get_node_status() # accessible, in_recovery
|
||||
self.assertTrue(self.ha._is_healthiest_node(self.ha.old_cluster.members))
|
||||
self.ha.fetch_node_status = get_node_status(in_recovery=False) # accessible, not in_recovery
|
||||
@@ -864,6 +860,8 @@ class TestHa(PostgresInit):
|
||||
self.ha.global_config = self.ha.patroni.config.get_global_config(self.ha.cluster)
|
||||
with patch('patroni.postgresql.Postgresql.last_operation', return_value=1):
|
||||
self.assertFalse(self.ha._is_healthiest_node(self.ha.old_cluster.members))
|
||||
with patch('patroni.postgresql.Postgresql.replica_cached_timeline', return_value=None):
|
||||
self.assertFalse(self.ha._is_healthiest_node(self.ha.old_cluster.members))
|
||||
with patch('patroni.postgresql.Postgresql.replica_cached_timeline', return_value=1):
|
||||
self.assertFalse(self.ha._is_healthiest_node(self.ha.old_cluster.members))
|
||||
self.ha.patroni.nofailover = True
|
||||
@@ -972,11 +970,11 @@ class TestHa(PostgresInit):
|
||||
with patch.object(Leader, 'conn_url', PropertyMock(return_value='')):
|
||||
self.assertEqual(self.ha.run_cycle(), 'continue following the old known standby leader')
|
||||
|
||||
@patch.object(Cluster, 'is_unlocked', Mock(return_value=True))
|
||||
def test_process_unhealthy_standby_cluster_as_standby_leader(self):
|
||||
self.p.is_leader = false
|
||||
self.p.name = 'leader'
|
||||
self.ha.cluster = get_standby_cluster_initialized_with_only_leader()
|
||||
self.ha.cluster.is_unlocked = true
|
||||
self.ha.sysid_valid = true
|
||||
self.p._sysid = True
|
||||
self.assertEqual(self.ha.run_cycle(), 'promoted self to a standby leader by acquiring session lock')
|
||||
@@ -987,7 +985,6 @@ class TestHa(PostgresInit):
|
||||
self.p.is_leader = false
|
||||
self.p.name = 'replica'
|
||||
self.ha.cluster = get_standby_cluster_initialized_with_only_leader()
|
||||
self.ha.is_unlocked = true
|
||||
self.assertTrue(self.ha.run_cycle().startswith('running pg_rewind from remote_member:'))
|
||||
|
||||
def test_recover_unhealthy_leader_in_standby_cluster(self):
|
||||
@@ -998,13 +995,13 @@ class TestHa(PostgresInit):
|
||||
self.ha.cluster = get_standby_cluster_initialized_with_only_leader()
|
||||
self.assertEqual(self.ha.run_cycle(), 'starting as a standby leader because i had the session lock')
|
||||
|
||||
@patch.object(Cluster, 'is_unlocked', Mock(return_value=True))
|
||||
def test_recover_unhealthy_unlocked_standby_cluster(self):
|
||||
self.p.is_leader = false
|
||||
self.p.name = 'leader'
|
||||
self.p.is_running = false
|
||||
self.p.follow = false
|
||||
self.ha.cluster = get_standby_cluster_initialized_with_only_leader()
|
||||
self.ha.cluster.is_unlocked = true
|
||||
self.ha.has_lock = false
|
||||
self.assertEqual(self.ha.run_cycle(), 'trying to follow a remote member because standby cluster is unhealthy')
|
||||
|
||||
@@ -1287,10 +1284,10 @@ class TestHa(PostgresInit):
|
||||
|
||||
@patch('patroni.postgresql.mtime', Mock(return_value=1588316884))
|
||||
@patch('builtins.open', Mock(side_effect=Exception))
|
||||
@patch.object(Cluster, 'is_unlocked', Mock(return_value=False))
|
||||
def test_restore_cluster_config(self):
|
||||
self.ha.cluster.config.data.clear()
|
||||
self.ha.has_lock = true
|
||||
self.ha.cluster.is_unlocked = false
|
||||
self.assertEqual(self.ha.run_cycle(), 'no action. I am (postgresql0), the leader with the lock')
|
||||
|
||||
def test_watch(self):
|
||||
@@ -1341,9 +1338,9 @@ class TestHa(PostgresInit):
|
||||
@patch('patroni.postgresql.mtime', Mock(return_value=1588316884))
|
||||
@patch('builtins.open', mock_open(read_data=('1\t0/40159C0\tno recovery target specified\n\n'
|
||||
'2\t1/40159C0\tno recovery target specified\n')))
|
||||
@patch.object(Cluster, 'is_unlocked', Mock(return_value=False))
|
||||
def test_update_cluster_history(self):
|
||||
self.ha.has_lock = true
|
||||
self.ha.cluster.is_unlocked = false
|
||||
for tl in (1, 3):
|
||||
self.p.get_primary_timeline = Mock(return_value=tl)
|
||||
self.assertEqual(self.ha.run_cycle(), 'no action. I am (postgresql0), the leader with the lock')
|
||||
@@ -1355,9 +1352,9 @@ class TestHa(PostgresInit):
|
||||
self.ha.run_cycle()
|
||||
exit_mock.assert_called_once_with(1)
|
||||
|
||||
@patch.object(Cluster, 'is_unlocked', Mock(return_value=False))
|
||||
def test_after_pause(self):
|
||||
self.ha.has_lock = true
|
||||
self.ha.cluster.is_unlocked = false
|
||||
self.ha.is_paused = true
|
||||
self.assertEqual(self.ha.run_cycle(), 'PAUSE: no action. I am (postgresql0), the leader with the lock')
|
||||
self.ha.is_paused = false
|
||||
@@ -1409,16 +1406,10 @@ class TestHa(PostgresInit):
|
||||
@patch.object(Postgresql, 'major_version', PropertyMock(return_value=130000))
|
||||
@patch.object(SlotsHandler, 'sync_replication_slots', Mock(return_value=['ls']))
|
||||
def test_follow_copy(self):
|
||||
self.ha.cluster.is_unlocked = false
|
||||
self.ha.cluster.config.data['slots'] = {'ls': {'database': 'a', 'plugin': 'b'}}
|
||||
self.p.is_leader = false
|
||||
self.assertTrue(self.ha.run_cycle().startswith('Copying logical slots'))
|
||||
|
||||
def test_is_failover_possible(self):
|
||||
self.ha.fetch_node_status = Mock(return_value=_MemberStatus(self.ha.cluster.members[0],
|
||||
True, True, 0, 2, None, {}, False))
|
||||
self.assertFalse(self.ha.is_failover_possible(self.ha.cluster.members))
|
||||
|
||||
def test_acquire_lock(self):
|
||||
self.ha.dcs.attempt_to_acquire_leader = Mock(side_effect=[DCSError('foo'), Exception])
|
||||
self.assertRaises(DCSError, self.ha.acquire_lock)
|
||||
|
||||
@@ -5,6 +5,7 @@ import mock
|
||||
import socket
|
||||
import time
|
||||
import unittest
|
||||
import urllib3
|
||||
|
||||
from mock import Mock, PropertyMock, mock_open, patch
|
||||
from patroni.dcs.kubernetes import Cluster, k8s_client, k8s_config, K8sConfig, K8sConnectionFailed,\
|
||||
@@ -34,6 +35,9 @@ def mock_list_namespaced_config_map(*args, **kwargs):
|
||||
metadata.update({'name': 'test-1-leader', 'labels': {Kubernetes._CITUS_LABEL: '1'},
|
||||
'annotations': {'leader': 'p-3', 'ttl': '30s'}})
|
||||
items.append(k8s_client.V1ConfigMap(metadata=k8s_client.V1ObjectMeta(**metadata)))
|
||||
metadata.update({'name': 'test-2-config', 'labels': {Kubernetes._CITUS_LABEL: '2'}, 'annotations': {}})
|
||||
items.append(k8s_client.V1ConfigMap(metadata=k8s_client.V1ObjectMeta(**metadata)))
|
||||
|
||||
metadata = k8s_client.V1ObjectMeta(resource_version='1')
|
||||
return k8s_client.V1ConfigMapList(metadata=metadata, items=items, kind='ConfigMapList')
|
||||
|
||||
@@ -403,13 +407,18 @@ class TestKubernetesEndpoints(BaseTestKubernetes):
|
||||
self.assertEqual(('create_config_service failed',), mock_logger_exception.call_args[0])
|
||||
|
||||
|
||||
def mock_watch(*args):
|
||||
return urllib3.HTTPResponse()
|
||||
|
||||
|
||||
class TestCacheBuilder(BaseTestKubernetes):
|
||||
|
||||
@patch.object(k8s_client.CoreV1Api, 'list_namespaced_config_map', mock_list_namespaced_config_map, create=True)
|
||||
@patch('patroni.dcs.kubernetes.ObjectCache._watch')
|
||||
def test__build_cache(self, mock_response):
|
||||
@patch('patroni.dcs.kubernetes.ObjectCache._watch', mock_watch)
|
||||
@patch.object(urllib3.HTTPResponse, 'read_chunked')
|
||||
def test__build_cache(self, mock_read_chunked):
|
||||
self.k._citus_group = '0'
|
||||
mock_response.return_value.read_chunked.return_value = [json.dumps(
|
||||
mock_read_chunked.return_value = [json.dumps(
|
||||
{'type': 'MODIFIED', 'object': {'metadata': {
|
||||
'name': self.k.config_path, 'resourceVersion': '2', 'annotations': {self.k._CONFIG: 'foo'}}}}
|
||||
).encode('utf-8'), ('\n' + json.dumps(
|
||||
@@ -435,12 +444,13 @@ class TestCacheBuilder(BaseTestKubernetes):
|
||||
self.assertRaises(AttributeError, self.k._kinds._do_watch, '1')
|
||||
|
||||
@patch.object(k8s_client.CoreV1Api, 'list_namespaced_config_map', mock_list_namespaced_config_map, create=True)
|
||||
@patch('patroni.dcs.kubernetes.ObjectCache._watch')
|
||||
def test_kill_stream(self, mock_watch):
|
||||
self.k._kinds.kill_stream()
|
||||
mock_watch.return_value.read_chunked.return_value = []
|
||||
mock_watch.return_value.connection.sock.close.side_effect = Exception
|
||||
self.k._kinds._do_watch('1')
|
||||
self.k._kinds.kill_stream()
|
||||
type(mock_watch.return_value).connection = PropertyMock(side_effect=Exception)
|
||||
@patch('patroni.dcs.kubernetes.ObjectCache._watch', mock_watch)
|
||||
@patch.object(urllib3.HTTPResponse, 'read_chunked', Mock(return_value=[]))
|
||||
def test_kill_stream(self):
|
||||
self.k._kinds.kill_stream()
|
||||
with patch.object(urllib3.HTTPResponse, 'connection') as mock_connection:
|
||||
mock_connection.sock.close.side_effect = Exception
|
||||
self.k._kinds._do_watch('1')
|
||||
self.k._kinds.kill_stream()
|
||||
with patch.object(urllib3.HTTPResponse, 'connection', PropertyMock(side_effect=Exception)):
|
||||
self.k._kinds.kill_stream()
|
||||
|
||||
@@ -369,6 +369,10 @@ class TestPostgresql(BaseTestPostgresql):
|
||||
def test_latest_checkpoint_location(self, mock_popen):
|
||||
mock_popen.return_value.communicate.return_value = (None, None)
|
||||
self.assertEqual(self.p.latest_checkpoint_location(), 28163096)
|
||||
with patch.object(Postgresql, 'controldata', Mock(return_value={'Database cluster state': 'shut down',
|
||||
'Latest checkpoint location': 'k/1ADBC18',
|
||||
"Latest checkpoint's TimeLineID": '1'})):
|
||||
self.assertIsNone(self.p.latest_checkpoint_location())
|
||||
# 9.3 and 9.4 format
|
||||
mock_popen.return_value.communicate.side_effect = [
|
||||
(b'rmgr: XLOG len (rec/tot): 72/ 104, tx: 0, lsn: 0/01ADBC18, prev 0/01ADBBB8, '
|
||||
@@ -518,10 +522,10 @@ class TestPostgresql(BaseTestPostgresql):
|
||||
|
||||
def test_can_create_replica_without_replication_connection(self):
|
||||
self.p.config._config['create_replica_method'] = []
|
||||
self.assertFalse(self.p.can_create_replica_without_replication_connection())
|
||||
self.assertFalse(self.p.can_create_replica_without_replication_connection(None))
|
||||
self.p.config._config['create_replica_method'] = ['wale', 'basebackup']
|
||||
self.p.config._config['wale'] = {'command': 'foo', 'no_leader': 1}
|
||||
self.assertTrue(self.p.can_create_replica_without_replication_connection())
|
||||
self.assertTrue(self.p.can_create_replica_without_replication_connection(None))
|
||||
|
||||
def test_replica_method_can_work_without_replication_connection(self):
|
||||
self.assertFalse(self.p.replica_method_can_work_without_replication_connection('basebackup'))
|
||||
|
||||
@@ -111,6 +111,8 @@ class TestSlotsHandler(BaseTestPostgresql):
|
||||
self.s.copy_logical_slots(self.cluster, ['ls'])
|
||||
with patch.object(MockCursor, 'execute', Mock(side_effect=psycopg.OperationalError)):
|
||||
self.s.copy_logical_slots(self.cluster, ['foo'])
|
||||
with patch.object(Cluster, 'leader', PropertyMock(return_value=None)):
|
||||
self.s.copy_logical_slots(self.cluster, ['foo'])
|
||||
|
||||
@patch.object(Postgresql, 'stop', Mock(return_value=True))
|
||||
@patch.object(Postgresql, 'start', Mock(return_value=True))
|
||||
|
||||
@@ -36,7 +36,7 @@ class TestUtils(unittest.TestCase):
|
||||
|
||||
def test_enable_keepalive(self):
|
||||
with patch('socket.SIO_KEEPALIVE_VALS', 1, create=True):
|
||||
self.assertIsNotNone(enable_keepalive(Mock(), 10, 5))
|
||||
self.assertIsNone(enable_keepalive(Mock(), 10, 5))
|
||||
with patch('socket.SIO_KEEPALIVE_VALS', None, create=True):
|
||||
for platform in ('linux2', 'darwin', 'other'):
|
||||
with patch('sys.platform', platform):
|
||||
|
||||
@@ -123,7 +123,7 @@ class TestWALERestore(unittest.TestCase):
|
||||
with patch.object(WALERestore, 'run', Mock(return_value=1)), \
|
||||
patch('time.sleep', mock_sleep):
|
||||
self.assertEqual(_main(), 1)
|
||||
self.assertTrue(sleeps[0], WALE_TEST_RETRIES)
|
||||
self.assertEqual(sleeps[0], WALE_TEST_RETRIES)
|
||||
|
||||
@patch('os.path.isfile', Mock(return_value=True))
|
||||
def test_get_major_version(self):
|
||||
|
||||
@@ -110,6 +110,11 @@ class TestWatchdog(unittest.TestCase):
|
||||
watchdog.keepalive()
|
||||
self.assertEqual(len(device.writes), 1)
|
||||
|
||||
watchdog.impl._fd, fd = None, watchdog.impl._fd
|
||||
watchdog.keepalive()
|
||||
self.assertEqual(len(device.writes), 1)
|
||||
watchdog.impl._fd = fd
|
||||
|
||||
watchdog.disable()
|
||||
self.assertFalse(device.open)
|
||||
self.assertEqual(device.writes[-1], b'V')
|
||||
|
||||
@@ -13,6 +13,7 @@ from patroni.dcs.zookeeper import Cluster, Leader, PatroniKazooClient,\
|
||||
|
||||
class MockKazooClient(Mock):
|
||||
|
||||
handler = PatroniSequentialThreadingHandler(10)
|
||||
leader = False
|
||||
exists = True
|
||||
|
||||
@@ -154,11 +155,6 @@ class TestZooKeeper(unittest.TestCase):
|
||||
def test_session_listener(self):
|
||||
self.zk.session_listener(KazooState.SUSPENDED)
|
||||
|
||||
def test_members_watcher(self):
|
||||
self.zk._fetch_cluster = False
|
||||
self.zk.members_watcher(None)
|
||||
self.assertTrue(self.zk._fetch_cluster)
|
||||
|
||||
def test_reload_config(self):
|
||||
self.zk.reload_config({'ttl': 20, 'retry_timeout': 10, 'loop_wait': 10})
|
||||
self.zk.reload_config({'ttl': 20, 'retry_timeout': 10, 'loop_wait': 5})
|
||||
@@ -223,6 +219,8 @@ class TestZooKeeper(unittest.TestCase):
|
||||
|
||||
def test_cancel_initialization(self):
|
||||
self.zk.cancel_initialization()
|
||||
with patch.object(MockKazooClient, 'delete', Mock()):
|
||||
self.zk.cancel_initialization()
|
||||
|
||||
def test_touch_member(self):
|
||||
self.zk._name = 'buzz'
|
||||
|
||||
0 typings/botocore/__init__.pyi Normal file
1 typings/botocore/exceptions.pyi Normal file
@@ -0,0 +1 @@
|
||||
class ClientError(Exception): ...
|
||||
11 typings/botocore/utils.pyi Normal file
@@ -0,0 +1,11 @@
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
DEFAULT_METADATA_SERVICE_TIMEOUT = 1
|
||||
METADATA_BASE_URL = 'http://169.254.169.254/'
|
||||
class AWSResponse:
|
||||
status_code: int
|
||||
@property
|
||||
def text(self) -> str: ...
|
||||
class IMDSFetcher:
|
||||
def __init__(self, timeout: float = DEFAULT_METADATA_SERVICE_TIMEOUT, num_attempts: int = 1, base_url: str = METADATA_BASE_URL, env: Optional[Dict[str, str]] = None, user_agent: Optional[str] = None, config: Optional[Dict[str, Any]] = None) -> None: ...
|
||||
def _fetch_metadata_token(self) -> Optional[str]: ...
|
||||
def _get_request(self, url_path: str, retry_func: Optional[Callable[[AWSResponse], bool]] = None, token: Optional[str] = None) -> AWSResponse: ...
|
||||
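The stub only covers the two IMDSFetcher entry points Patroni relies on; put together they are used roughly like this (the metadata path is just an example):

from botocore.utils import IMDSFetcher

fetcher = IMDSFetcher(timeout=2, num_attempts=2)
token = fetcher._fetch_metadata_token()          # Optional[str]; None when IMDSv2 is unavailable
response = fetcher._get_request('/latest/meta-data/instance-id', None, token)
instance_id = response.text if response.status_code == 200 else None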
5 typings/cdiff/__init__.pyi Normal file
@@ -0,0 +1,5 @@
|
||||
import io
|
||||
from typing import Any
|
||||
class PatchStream:
|
||||
def __init__(self, diff_hdl: io.TextIOBase) -> None: ...
|
||||
def markup_to_pager(stream: Any, opts: Any) -> None: ...
|
||||
2 typings/consul/__init__.pyi Normal file
@@ -0,0 +1,2 @@
|
||||
from consul.base import ConsulException, NotFound
|
||||
__all__ = ['ConsulException', 'Consul', 'NotFound']
|
||||
24 typings/consul/base.pyi Normal file
@@ -0,0 +1,24 @@
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
class ConsulException(Exception): ...
|
||||
class NotFound(ConsulException): ...
|
||||
class Check:
|
||||
@classmethod
|
||||
def http(klass, url: str, interval: str, timeout: Optional[str] = None, deregister: Optional[str] = None) -> Dict[str, str]: ...
|
||||
class Consul:
|
||||
http: Any
|
||||
agent: 'Consul.Agent'
|
||||
session: 'Consul.Session'
|
||||
kv: 'Consul.KV'
|
||||
class KV:
|
||||
def get(self, key: str, index: Optional[int]=None, recurse: bool = False, wait: Optional[str] = None, token: Optional[str] = None, consistency: Optional[str] = None, keys: bool = False, separator: Optional[str] = '', dc: Optional[str] = None) -> Tuple[int, Dict[str, Any]]: ...
|
||||
def put(self, key: str, value: str, cas: Optional[int] = None, flags: Optional[int] = None, acquire: Optional[str] = None, release: Optional[str] = None, token: Optional[str] = None, dc: Optional[str] = None) -> bool: ...
|
||||
def delete(self, key: str, recurse: Optional[bool] = None, cas: Optional[int] = None, token: Optional[str] = None, dc: Optional[str] = None) -> bool: ...
|
||||
class Agent:
|
||||
service: 'Consul.Agent.Service'
|
||||
def self(self) -> Dict[str, Dict[str, Any]]: ...
|
||||
class Service:
|
||||
def register(self, name: str, service_id=..., address=..., port=..., tags=..., check=..., token=..., script=..., interval=..., ttl=..., http=..., timeout=..., enable_tag_override=...) -> bool: ...
|
||||
def deregister(self, service_id: str) -> bool: ...
|
||||
class Session:
|
||||
def create(self, name: Optional[str] = None, node: Optional[str] = None, checks: Optional[List[str]] = None, lock_delay: float = 15, behavior: str = 'release', ttl: Optional[int] = None, dc: Optional[str] = None) -> str: ...
|
||||
def renew(self, session_id: str, dc: Optional[str] = None) -> Optional[str]: ...
|
||||
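A minimal sketch of how this typed surface is consumed (the key and session names are illustrative, not the ones Patroni uses verbatim):

from consul import Consul

c = Consul()
session_id = c.session.create(name='demo-session', ttl=30, behavior='delete')
acquired = c.kv.put('service/demo/leader', 'postgresql0', acquire=session_id)   # -> bool
index, item = c.kv.get('service/demo/leader')                                   # -> Tuple[int, Dict[str, Any]]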
17 typings/dns/resolver.pyi Normal file
@@ -0,0 +1,17 @@
|
||||
from typing import Union, Optional, Iterator
|
||||
class Name:
|
||||
def to_text(self, omit_final_dot: bool = ...) -> str: ...
|
||||
class Rdata:
|
||||
target: Name = ...
|
||||
port: int = ...
|
||||
class Answer:
|
||||
def __iter__(self) -> Iterator[Rdata]: ...
|
||||
def resolve(qname : str, rdtype : Union[int,str] = 0,
|
||||
rdclass : Union[int,str] = 0,
|
||||
tcp=False, source=None, raise_on_no_answer=True,
|
||||
source_port=0, lifetime : Optional[float]=None,
|
||||
search : Optional[bool]=None) -> Answer: ...
|
||||
def query(qname : str, rdtype : Union[int,str] = 0,
|
||||
rdclass : Union[int,str] = 0,
|
||||
tcp=False, source: Optional[str] = None, raise_on_no_answer=True,
|
||||
source_port=0, lifetime : Optional[float]=None) -> Answer: ...
|
||||
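The Answer object iterates over Rdata records, which is all the SRV discovery code needs; roughly:

from dns.resolver import resolve

# The hostname is a placeholder; any SRV record is handled the same way.
for rdata in resolve('_etcd-server._tcp.example.com', 'SRV'):
    print(rdata.target.to_text(omit_final_dot=True), rdata.port)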
24 typings/etcd/__init__.pyi Normal file
@@ -0,0 +1,24 @@
|
||||
from typing import Dict, Optional, Type, List
|
||||
from .client import Client
|
||||
__all__ = ['Client', 'EtcdError', 'EtcdException', 'EtcdEventIndexCleared', 'EtcdWatcherCleared', 'EtcdKeyNotFound', 'EtcdAlreadyExist', 'EtcdResult', 'EtcdConnectionFailed', 'EtcdWatchTimedOut']
|
||||
class EtcdResult:
|
||||
action: str = ...
|
||||
modifiedIndex: int = ...
|
||||
key: str = ...
|
||||
value: str = ...
|
||||
ttl: Optional[float] = ...
|
||||
@property
|
||||
def leaves(self) -> List['EtcdResult']: ...
|
||||
class EtcdException(Exception):
|
||||
def __init__(self, message=..., payload=...) -> None: ...
|
||||
class EtcdConnectionFailed(EtcdException):
|
||||
def __init__(self, message=..., payload=..., cause=...) -> None: ...
|
||||
class EtcdKeyError(EtcdException): ...
|
||||
class EtcdKeyNotFound(EtcdKeyError): ...
|
||||
class EtcdAlreadyExist(EtcdKeyError): ...
|
||||
class EtcdEventIndexCleared(EtcdException): ...
|
||||
class EtcdWatchTimedOut(EtcdConnectionFailed): ...
|
||||
class EtcdWatcherCleared(EtcdException): ...
|
||||
class EtcdLeaderElectionInProgress(EtcdException): ...
|
||||
class EtcdError:
|
||||
error_exceptions: Dict[int, Type[EtcdException]] = ...
|
||||
29 typings/etcd/client.pyi Normal file
@@ -0,0 +1,29 @@
|
||||
import urllib3
|
||||
from typing import Any, Optional, Set
|
||||
from . import EtcdResult
|
||||
class Client:
|
||||
_MGET: str
|
||||
_MPUT: str
|
||||
_MPOST: str
|
||||
_MDELETE: str
|
||||
_comparison_conditions: Set[str]
|
||||
_read_options: Set[str]
|
||||
_del_conditions: Set[str]
|
||||
http: urllib3.poolmanager.PoolManager
|
||||
_use_proxies: bool
|
||||
version_prefix: str
|
||||
username: Optional[str]
|
||||
password: Optional[str]
|
||||
def __init__(self, host=..., port=..., srv_domain=..., version_prefix=..., read_timeout=..., allow_redirect=..., protocol=..., cert=..., ca_cert=..., username=..., password=..., allow_reconnect=..., use_proxies=..., expected_cluster_id=..., per_host_pool_size=..., lock_prefix=...): ...
|
||||
@property
|
||||
def protocol(self) -> str: ...
|
||||
@property
|
||||
def read_timeout(self) -> int: ...
|
||||
@property
|
||||
def allow_redirect(self) -> bool: ...
|
||||
def write(self, key: str, value: str, ttl: int = ..., dir: bool = ..., append: bool = ..., **kwdargs: Any) -> EtcdResult: ...
|
||||
def read(self, key: str, **kwdargs: Any) -> EtcdResult: ...
|
||||
def delete(self, key: str, recursive: bool = ..., dir: bool = ..., **kwdargs: Any) -> EtcdResult: ...
|
||||
def set(self, key: str, value: str, ttl: int = ...) -> EtcdResult: ...
|
||||
def watch(self, key: str, index: int = ..., timeout: float = ..., recursive: bool = ...) -> EtcdResult: ...
|
||||
def _handle_server_response(self, response: urllib3.response.HTTPResponse) -> Any: ...
|
||||
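Typed like this, a read/write round trip through the client checks cleanly under strict mode; a sketch with placeholder endpoint and key:

import etcd

client = etcd.Client(host='127.0.0.1', port=2379)
result = client.write('/service/demo/leader', 'postgresql0', ttl=30)   # -> EtcdResult
current = client.read('/service/demo/leader')
print(current.value, current.modifiedIndex, current.ttl)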
34 typings/kazoo/client.pyi Normal file
@@ -0,0 +1,34 @@
|
||||
__all__ = ['KazooState', 'KazooClient', 'KazooRetry']
|
||||
|
||||
from kazoo.protocol.connection import ConnectionHandler
|
||||
from kazoo.protocol.states import KazooState, WatchedEvent, ZnodeStat
|
||||
from kazoo.handlers.threading import AsyncResult, SequentialThreadingHandler
|
||||
from kazoo.retry import KazooRetry
|
||||
from kazoo.security import ACL
|
||||
|
||||
from typing import Any, Callable, Optional, Tuple, List
|
||||
|
||||
|
||||
class KazooClient:
|
||||
handler: SequentialThreadingHandler
|
||||
_state: str
|
||||
_connection: ConnectionHandler
|
||||
_session_timeout: int
|
||||
retry: Callable[..., Any]
|
||||
_retry: KazooRetry
|
||||
def __init__(self, hosts=..., timeout=..., client_id=..., handler=..., default_acl=..., auth_data=..., sasl_options=..., read_only=..., randomize_hosts=..., connection_retry=..., command_retry=..., logger=..., keyfile=..., keyfile_password=..., certfile=..., ca=..., use_ssl=..., verify_certs=..., **kwargs) -> None: ...
|
||||
@property
|
||||
def client_id(self) -> Optional[Tuple[Any]]: ...
|
||||
def add_listener(self, listener: Callable[[str], None]) -> None: ...
|
||||
def start(self, timeout: int = ...) -> None: ...
|
||||
def restart(self) -> None: ...
|
||||
def set_hosts(self, hosts: str, randomize_hosts: Optional[bool] = None) -> None: ...
|
||||
def create(self, path: str, value: bytes = b'', acl: Optional[ACL]=None, ephemeral: bool = False, sequence: bool = False, makepath: bool = False, include_data: bool = False) -> None: ...
|
||||
def create_async(self, path: str, value: bytes = b'', acl: Optional[ACL]=None, ephemeral: bool = False, sequence: bool = False, makepath: bool = False, include_data: bool = False) -> AsyncResult: ...
|
||||
def get(self, path: str, watch: Optional[Callable[[WatchedEvent], None]] = None) -> Tuple[bytes, ZnodeStat]: ...
|
||||
def get_children(self, path: str, watch: Optional[Callable[[WatchedEvent], None]] = None, include_data: bool = False) -> List[str]: ...
|
||||
def set(self, path: str, value: bytes, version: int = -1) -> ZnodeStat: ...
|
||||
def set_async(self, path: str, value: bytes, version: int = -1) -> AsyncResult: ...
|
||||
def delete(self, path: str, version: int = -1, recursive: bool = False) -> None: ...
|
||||
def delete_async(self, path: str, version: int = -1) -> AsyncResult: ...
|
||||
def _call(self, request: Tuple[Any], async_object: AsyncResult) -> Optional[bool]: ...
|
||||
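The annotations give get() and get_children() concrete return types, so callers no longer need casts; for example (paths are placeholders):

from kazoo.client import KazooClient

zk = KazooClient(hosts='localhost:2181')
zk.start(timeout=10)
zk.create('/service/demo/leader', b'postgresql0', ephemeral=True, makepath=True)
data, stat = zk.get('/service/demo/leader')       # Tuple[bytes, ZnodeStat]
children = zk.get_children('/service/demo')       # List[str]
print(data.decode('utf-8'), stat.version, children)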
12 typings/kazoo/exceptions.pyi Normal file
@@ -0,0 +1,12 @@
|
||||
class KazooException(Exception):
|
||||
...
|
||||
class ZookeeperError(KazooException):
|
||||
...
|
||||
class SessionExpiredError(ZookeeperError):
|
||||
...
|
||||
class ConnectionClosedError(SessionExpiredError):
|
||||
...
|
||||
class NoNodeError(ZookeeperError):
|
||||
...
|
||||
class NodeExistsError(ZookeeperError):
|
||||
...
|
||||
13 typings/kazoo/handlers/threading.pyi Normal file
@@ -0,0 +1,13 @@
|
||||
import socket
|
||||
from kazoo.handlers import utils
|
||||
from typing import Any
|
||||
|
||||
class AsyncResult(utils.AsyncResult):
|
||||
...
|
||||
|
||||
class SequentialThreadingHandler:
|
||||
def select(self, *args: Any, **kwargs: Any) -> Any:
|
||||
...
|
||||
|
||||
def create_connection(self, *args: Any, **kwargs: Any) -> socket.socket:
|
||||
...
|
||||
6 typings/kazoo/handlers/utils.pyi Normal file
@@ -0,0 +1,6 @@
|
||||
from typing import Any, Optional
|
||||
class AsyncResult:
|
||||
def set_exception(self, exception: Exception) -> None:
|
||||
...
|
||||
def get(self, block: bool = False, timeout: Optional[float] = None) -> Any:
|
||||
...
|
||||
6 typings/kazoo/protocol/connection.pyi Normal file
@@ -0,0 +1,6 @@
|
||||
import socket
|
||||
from typing import Any, Union, Tuple
|
||||
class ConnectionHandler:
|
||||
_socket: socket.socket
|
||||
def _connect(self, *args: Any) -> Tuple[Union[int, float], Union[int, float]]:
|
||||
...
|
||||
29 typings/kazoo/protocol/states.pyi Normal file
@@ -0,0 +1,29 @@
|
||||
from typing import Any, NamedTuple
|
||||
class KazooState:
|
||||
SUSPENDED: str
|
||||
CONNECTED: str
|
||||
LOST: str
|
||||
class KeeperState:
|
||||
AUTH_FAILED: str
|
||||
CONNECTED: str
|
||||
CONNECTED_RO: str
|
||||
CONNECTING: str
|
||||
CLOSED: str
|
||||
EXPIRED_SESSION: str
|
||||
class WatchedEvent(NamedTuple):
|
||||
type: str
|
||||
state: str
|
||||
path: str
|
||||
class ZnodeStat(NamedTuple):
|
||||
|
||||
czxid: int
|
||||
mzxid: int
|
||||
ctime: float
|
||||
mtime: float
|
||||
version: int
|
||||
cversion: int
|
||||
aversion: int
|
||||
ephemeralOwner: Any
|
||||
dataLength: int
|
||||
numChildren: int
|
||||
pzxid: int
|
||||
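Both structures are plain NamedTuples, so watcher callbacks can be annotated precisely; a small sketch of a typed watcher:

from kazoo.protocol.states import KeeperState, WatchedEvent

def leader_watcher(event: WatchedEvent) -> None:
    # type, state and path are all plain str fields on the NamedTuple
    if event.state == KeeperState.CONNECTED:
        print('change on', event.path, 'of type', event.type)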
7 typings/kazoo/retry.pyi Normal file
@@ -0,0 +1,7 @@
|
||||
from kazoo.exceptions import KazooException
|
||||
class RetryFailedError(KazooException):
|
||||
...
|
||||
class KazooRetry:
|
||||
deadline: float
|
||||
def __init__(self, max_tries=..., delay=..., backoff=..., max_jitter=..., max_delay=..., ignore_expire=..., sleep_func=..., deadline=..., interrupt=...) -> None:
|
||||
...
|
||||
5 typings/kazoo/security.pyi Normal file
@@ -0,0 +1,5 @@
|
||||
from collections import namedtuple
|
||||
class ACL(namedtuple('ACL', 'perms id')):
|
||||
...
|
||||
def make_acl(scheme: str, credential: str, read: bool = ..., write: bool = ..., create: bool = ..., delete: bool = ..., admin: bool = ..., all: bool = ...) -> ACL:
|
||||
...
|
||||
13 typings/prettytable/__init__.pyi Normal file
@@ -0,0 +1,13 @@
|
||||
from typing import Any, Dict, List
|
||||
FRAME = 1
|
||||
ALL = 1
|
||||
class PrettyTable:
|
||||
def __init__(self, *args: str, **kwargs: Any) -> None: ...
|
||||
def _stringify_hrule(self, options: Dict[str, Any], where: str = '') -> str: ...
|
||||
@property
|
||||
def align(self) -> Dict[str, str]: ...
|
||||
@align.setter
|
||||
def align(self, val: str) -> None: ...
|
||||
def add_row(self, row: List[Any]) -> None: ...
|
||||
def __str__(self) -> str: ...
|
||||
def __repr__(self) -> str: ...
|
||||
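The stub keeps patronictl's table rendering typed without pulling in full prettytable stubs; the calls it covers are used roughly like this:

from prettytable import PrettyTable

table = PrettyTable(field_names=['Member', 'Host', 'Role'])
table.align = 'l'            # goes through the @align.setter declared above
table.add_row(['postgresql0', '127.0.0.1', 'Leader'])
print(table)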
52 typings/psycopg2/__init__.pyi Normal file
@@ -0,0 +1,52 @@
|
||||
from collections.abc import Callable
|
||||
from typing import Any, TypeVar, overload
|
||||
|
||||
# connection and cursor not available at runtime
|
||||
from psycopg2._psycopg import (
|
||||
BINARY as BINARY,
|
||||
DATETIME as DATETIME,
|
||||
NUMBER as NUMBER,
|
||||
ROWID as ROWID,
|
||||
STRING as STRING,
|
||||
Binary as Binary,
|
||||
DatabaseError as DatabaseError,
|
||||
DataError as DataError,
|
||||
Date as Date,
|
||||
DateFromTicks as DateFromTicks,
|
||||
Error as Error,
|
||||
IntegrityError as IntegrityError,
|
||||
InterfaceError as InterfaceError,
|
||||
InternalError as InternalError,
|
||||
NotSupportedError as NotSupportedError,
|
||||
OperationalError as OperationalError,
|
||||
ProgrammingError as ProgrammingError,
|
||||
Time as Time,
|
||||
TimeFromTicks as TimeFromTicks,
|
||||
Timestamp as Timestamp,
|
||||
TimestampFromTicks as TimestampFromTicks,
|
||||
Warning as Warning,
|
||||
__libpq_version__ as __libpq_version__,
|
||||
apilevel as apilevel,
|
||||
connection as connection,
|
||||
cursor as cursor,
|
||||
paramstyle as paramstyle,
|
||||
threadsafety as threadsafety,
|
||||
)
|
||||
|
||||
__version__: str
|
||||
|
||||
_T_conn = TypeVar("_T_conn", bound=connection)
|
||||
|
||||
@overload
|
||||
def connect(dsn: str, connection_factory: Callable[..., _T_conn], cursor_factory: None = None, **kwargs: Any) -> _T_conn: ...
|
||||
@overload
|
||||
def connect(
|
||||
dsn: str | None = None, *, connection_factory: Callable[..., _T_conn], cursor_factory: None = None, **kwargs: Any
|
||||
) -> _T_conn: ...
|
||||
@overload
|
||||
def connect(
|
||||
dsn: str | None = None,
|
||||
connection_factory: Callable[..., connection] | None = None,
|
||||
cursor_factory: Callable[..., cursor] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> connection: ...
|
||||
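The three connect() overloads let pyright keep the concrete connection subtype when a connection_factory is supplied and fall back to the base connection otherwise; for example (the DSN is a placeholder and needs a reachable server):

import psycopg2
from psycopg2.extras import LogicalReplicationConnection

conn = psycopg2.connect('dbname=postgres')        # inferred as the plain 'connection' type

# inferred as LogicalReplicationConnection via the first overload
repl = psycopg2.connect('dbname=postgres', connection_factory=LogicalReplicationConnection)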
9 typings/psycopg2/_ipaddress.pyi Normal file
@@ -0,0 +1,9 @@
|
||||
from _typeshed import Incomplete
|
||||
from typing import Any
|
||||
|
||||
ipaddress: Any
|
||||
|
||||
def register_ipaddress(conn_or_curs: Incomplete | None = None) -> None: ...
|
||||
def cast_interface(s, cur: Incomplete | None = None): ...
|
||||
def cast_network(s, cur: Incomplete | None = None): ...
|
||||
def adapt_ipaddress(obj): ...
|
||||
26 typings/psycopg2/_json.pyi Normal file
@@ -0,0 +1,26 @@
|
||||
from _typeshed import Incomplete
|
||||
from typing import Any
|
||||
|
||||
JSON_OID: int
|
||||
JSONARRAY_OID: int
|
||||
JSONB_OID: int
|
||||
JSONBARRAY_OID: int
|
||||
|
||||
class Json:
|
||||
adapted: Any
|
||||
def __init__(self, adapted, dumps: Incomplete | None = None) -> None: ...
|
||||
def __conform__(self, proto): ...
|
||||
def dumps(self, obj): ...
|
||||
def prepare(self, conn) -> None: ...
|
||||
def getquoted(self): ...
|
||||
|
||||
def register_json(
|
||||
conn_or_curs: Incomplete | None = None,
|
||||
globally: bool = False,
|
||||
loads: Incomplete | None = None,
|
||||
oid: Incomplete | None = None,
|
||||
array_oid: Incomplete | None = None,
|
||||
name: str = "json",
|
||||
): ...
|
||||
def register_default_json(conn_or_curs: Incomplete | None = None, globally: bool = False, loads: Incomplete | None = None): ...
|
||||
def register_default_jsonb(conn_or_curs: Incomplete | None = None, globally: bool = False, loads: Incomplete | None = None): ...
|
||||
488 typings/psycopg2/_psycopg.pyi Normal file
@@ -0,0 +1,488 @@
|
||||
from collections.abc import Callable, Iterable, Mapping, Sequence
|
||||
from types import TracebackType
|
||||
from typing import Any, TypeVar, overload
|
||||
from typing_extensions import Literal, Self, TypeAlias
|
||||
|
||||
import psycopg2
|
||||
import psycopg2.extensions
|
||||
from psycopg2.sql import Composable
|
||||
|
||||
_Vars: TypeAlias = Sequence[Any] | Mapping[str, Any] | None
|
||||
|
||||
BINARY: Any
|
||||
BINARYARRAY: Any
|
||||
BOOLEAN: Any
|
||||
BOOLEANARRAY: Any
|
||||
BYTES: Any
|
||||
BYTESARRAY: Any
|
||||
CIDRARRAY: Any
|
||||
DATE: Any
|
||||
DATEARRAY: Any
|
||||
DATETIME: Any
|
||||
DATETIMEARRAY: Any
|
||||
DATETIMETZ: Any
|
||||
DATETIMETZARRAY: Any
|
||||
DECIMAL: Any
|
||||
DECIMALARRAY: Any
|
||||
FLOAT: Any
|
||||
FLOATARRAY: Any
|
||||
INETARRAY: Any
|
||||
INTEGER: Any
|
||||
INTEGERARRAY: Any
|
||||
INTERVAL: Any
|
||||
INTERVALARRAY: Any
|
||||
LONGINTEGER: Any
|
||||
LONGINTEGERARRAY: Any
|
||||
MACADDRARRAY: Any
|
||||
NUMBER: Any
|
||||
PYDATE: Any
|
||||
PYDATEARRAY: Any
|
||||
PYDATETIME: Any
|
||||
PYDATETIMEARRAY: Any
|
||||
PYDATETIMETZ: Any
|
||||
PYDATETIMETZARRAY: Any
|
||||
PYINTERVAL: Any
|
||||
PYINTERVALARRAY: Any
|
||||
PYTIME: Any
|
||||
PYTIMEARRAY: Any
|
||||
REPLICATION_LOGICAL: int
|
||||
REPLICATION_PHYSICAL: int
|
||||
ROWID: Any
|
||||
ROWIDARRAY: Any
|
||||
STRING: Any
|
||||
STRINGARRAY: Any
|
||||
TIME: Any
|
||||
TIMEARRAY: Any
|
||||
UNICODE: Any
|
||||
UNICODEARRAY: Any
|
||||
UNKNOWN: Any
|
||||
adapters: dict[Any, Any]
|
||||
apilevel: str
|
||||
binary_types: dict[Any, Any]
|
||||
encodings: dict[Any, Any]
|
||||
paramstyle: str
|
||||
sqlstate_errors: dict[Any, Any]
|
||||
string_types: dict[Any, Any]
|
||||
threadsafety: int
|
||||
|
||||
__libpq_version__: int
|
||||
|
||||
class cursor:
|
||||
arraysize: int
|
||||
binary_types: Any
|
||||
closed: Any
|
||||
connection: Any
|
||||
description: Any
|
||||
itersize: Any
|
||||
lastrowid: Any
|
||||
name: Any
|
||||
pgresult_ptr: Any
|
||||
query: Any
|
||||
row_factory: Any
|
||||
rowcount: int
|
||||
rownumber: int
|
||||
scrollable: bool | None
|
||||
statusmessage: Any
|
||||
string_types: Any
|
||||
typecaster: Any
|
||||
tzinfo_factory: Any
|
||||
withhold: bool
|
||||
def __init__(self, conn: connection, name: str | bytes | None = ...) -> None: ...
|
||||
def callproc(self, procname, parameters=...): ...
|
||||
def cast(self, oid, s): ...
|
||||
def close(self): ...
|
||||
def copy_expert(self, sql: str | bytes | Composable, file, size=...): ...
|
||||
def copy_from(self, file, table, sep=..., null=..., size=..., columns=...): ...
|
||||
def copy_to(self, file, table, sep=..., null=..., columns=...): ...
|
||||
def execute(self, query: str | bytes | Composable, vars: _Vars = ...) -> None: ...
|
||||
def executemany(self, query: str | bytes | Composable, vars_list: Iterable[_Vars]) -> None: ...
|
||||
def fetchall(self) -> list[tuple[Any, ...]]: ...
|
||||
def fetchmany(self, size: int | None = ...) -> list[tuple[Any, ...]]: ...
|
||||
def fetchone(self) -> tuple[Any, ...] | None: ...
|
||||
def mogrify(self, *args, **kwargs): ...
|
||||
def nextset(self): ...
|
||||
def scroll(self, value, mode=...): ...
|
||||
def setinputsizes(self, sizes): ...
|
||||
def setoutputsize(self, size, column=...): ...
|
||||
def __enter__(self) -> Self: ...
|
||||
def __exit__(
|
||||
self, type: type[BaseException] | None, value: BaseException | None, traceback: TracebackType | None
|
||||
) -> None: ...
|
||||
def __iter__(self) -> Self: ...
|
||||
def __next__(self) -> tuple[Any, ...]: ...
|
||||
|
||||
_Cursor: TypeAlias = cursor
|
||||
|
||||
class AsIs:
|
||||
adapted: Any
|
||||
def __init__(self, *args, **kwargs) -> None: ...
|
||||
def getquoted(self, *args, **kwargs): ...
|
||||
def __conform__(self, *args, **kwargs): ...
|
||||
|
||||
class Binary:
|
||||
adapted: Any
|
||||
buffer: Any
|
||||
def __init__(self, *args, **kwargs) -> None: ...
|
||||
def getquoted(self, *args, **kwargs): ...
|
||||
def prepare(self, conn): ...
|
||||
def __conform__(self, *args, **kwargs): ...
|
||||
|
||||
class Boolean:
|
||||
adapted: Any
|
||||
def __init__(self, *args, **kwargs) -> None: ...
|
||||
def getquoted(self, *args, **kwargs): ...
|
||||
def __conform__(self, *args, **kwargs): ...
|
||||
|
||||
class Column:
|
||||
display_size: Any
|
||||
internal_size: Any
|
||||
name: Any
|
||||
null_ok: Any
|
||||
precision: Any
|
||||
scale: Any
|
||||
table_column: Any
|
||||
table_oid: Any
|
||||
type_code: Any
|
||||
def __init__(self, *args, **kwargs) -> None: ...
|
||||
def __eq__(self, __other): ...
|
||||
def __ge__(self, __other): ...
|
||||
def __getitem__(self, __index): ...
|
||||
def __getstate__(self): ...
|
||||
def __gt__(self, __other): ...
|
||||
def __le__(self, __other): ...
|
||||
def __len__(self) -> int: ...
|
||||
def __lt__(self, __other): ...
|
||||
def __ne__(self, __other): ...
|
||||
def __setstate__(self, state): ...
|
||||
|
||||
class ConnectionInfo:
|
||||
# Note: the following properties can be None if their corresponding libpq function
# returns NULL. They're not annotated as such, because this is very unlikely in
# practice---the psycopg2 docs [1] don't even mention this as a possibility!
#
# - dbname
# - user
# - password
# - host
# - port
# - options
#
# (To prove this, one needs to inspect the psycopg2 source code [2], plus the
# documentation [3] and source code [4] of the corresponding libpq calls.)
#
# [1]: https://www.psycopg.org/docs/extensions.html#psycopg2.extensions.ConnectionInfo
# [2]: https://github.com/psycopg/psycopg2/blob/1d3a89a0bba621dc1cc9b32db6d241bd2da85ad1/psycopg/conninfo_type.c#L52 and below
# [3]: https://www.postgresql.org/docs/current/libpq-status.html
# [4]: https://github.com/postgres/postgres/blob/b39838889e76274b107935fa8e8951baf0e8b31b/src/interfaces/libpq/fe-connect.c#L6754 and below
@property
|
||||
def backend_pid(self) -> int: ...
|
||||
@property
|
||||
def dbname(self) -> str: ...
|
||||
@property
|
||||
def dsn_parameters(self) -> dict[str, str]: ...
|
||||
@property
|
||||
def error_message(self) -> str | None: ...
|
||||
@property
|
||||
def host(self) -> str: ...
|
||||
@property
|
||||
def needs_password(self) -> bool: ...
|
||||
@property
|
||||
def options(self) -> str: ...
|
||||
@property
|
||||
def password(self) -> str: ...
|
||||
@property
|
||||
def port(self) -> int: ...
|
||||
@property
|
||||
def protocol_version(self) -> int: ...
|
||||
@property
|
||||
def server_version(self) -> int: ...
|
||||
@property
|
||||
def socket(self) -> int: ...
|
||||
@property
|
||||
def ssl_attribute_names(self) -> list[str]: ...
|
||||
@property
|
||||
def ssl_in_use(self) -> bool: ...
|
||||
@property
|
||||
def status(self) -> int: ...
|
||||
@property
|
||||
def transaction_status(self) -> int: ...
|
||||
@property
|
||||
def used_password(self) -> bool: ...
|
||||
@property
|
||||
def user(self) -> str: ...
|
||||
def __init__(self, *args, **kwargs) -> None: ...
|
||||
def parameter_status(self, name: str) -> str | None: ...
|
||||
def ssl_attribute(self, name: str) -> str | None: ...
|
||||
|
||||
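As the note above explains, ConnectionInfo mirrors the libpq status functions. A small sketch of how it is typically read (illustrative only, not part of this diff):

# Illustrative sketch: inspecting connection metadata via conn.info.
import psycopg2

conn = psycopg2.connect("dbname=test")  # placeholder DSN
info = conn.info
print(info.dbname, info.host, info.port)
print(info.server_version, info.transaction_status)
print(info.parameter_status("server_version"))
conn.close()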
class DataError(psycopg2.DatabaseError): ...
|
||||
class DatabaseError(psycopg2.Error): ...
|
||||
|
||||
class Decimal:
|
||||
adapted: Any
|
||||
def __init__(self, *args, **kwargs) -> None: ...
|
||||
def getquoted(self, *args, **kwargs): ...
|
||||
def __conform__(self, *args, **kwargs): ...
|
||||
|
||||
class Diagnostics:
|
||||
column_name: str | None
|
||||
constraint_name: str | None
|
||||
context: str | None
|
||||
datatype_name: str | None
|
||||
internal_position: str | None
|
||||
internal_query: str | None
|
||||
message_detail: str | None
|
||||
message_hint: str | None
|
||||
message_primary: str | None
|
||||
schema_name: str | None
|
||||
severity: str | None
|
||||
severity_nonlocalized: str | None
|
||||
source_file: str | None
|
||||
source_function: str | None
|
||||
source_line: str | None
|
||||
sqlstate: str | None
|
||||
statement_position: str | None
|
||||
table_name: str | None
|
||||
def __init__(self, __err: Error) -> None: ...
|
||||
|
||||
class Error(Exception):
|
||||
cursor: _Cursor | None
|
||||
diag: Diagnostics
|
||||
pgcode: str | None
|
||||
pgerror: str | None
|
||||
def __init__(self, *args, **kwargs) -> None: ...
|
||||
def __reduce__(self): ...
|
||||
def __setstate__(self, state): ...
|
||||
|
||||
class Float:
|
||||
adapted: Any
|
||||
def __init__(self, *args, **kwargs) -> None: ...
|
||||
def getquoted(self, *args, **kwargs): ...
|
||||
def __conform__(self, *args, **kwargs): ...
|
||||
|
||||
class ISQLQuote:
|
||||
_wrapped: Any
|
||||
def __init__(self, *args, **kwargs) -> None: ...
|
||||
def getbinary(self, *args, **kwargs): ...
|
||||
def getbuffer(self, *args, **kwargs): ...
|
||||
def getquoted(self, *args, **kwargs) -> bytes: ...
|
||||
|
||||
class Int:
|
||||
adapted: Any
|
||||
def __init__(self, *args, **kwargs) -> None: ...
|
||||
def getquoted(self, *args, **kwargs): ...
|
||||
def __conform__(self, *args, **kwargs): ...
|
||||
|
||||
class IntegrityError(psycopg2.DatabaseError): ...
|
||||
class InterfaceError(psycopg2.Error): ...
|
||||
class InternalError(psycopg2.DatabaseError): ...
|
||||
|
||||
class List:
|
||||
adapted: Any
|
||||
def __init__(self, *args, **kwargs) -> None: ...
|
||||
def getquoted(self, *args, **kwargs): ...
|
||||
def prepare(self, *args, **kwargs): ...
|
||||
def __conform__(self, *args, **kwargs): ...
|
||||
|
||||
class NotSupportedError(psycopg2.DatabaseError): ...
|
||||
|
||||
class Notify:
|
||||
channel: Any
|
||||
payload: Any
|
||||
pid: Any
|
||||
def __init__(self, *args, **kwargs) -> None: ...
|
||||
def __eq__(self, __other): ...
|
||||
def __ge__(self, __other): ...
|
||||
def __getitem__(self, __index): ...
|
||||
def __gt__(self, __other): ...
|
||||
def __hash__(self) -> int: ...
|
||||
def __le__(self, __other): ...
|
||||
def __len__(self) -> int: ...
|
||||
def __lt__(self, __other): ...
|
||||
def __ne__(self, __other): ...
|
||||
|
||||
class OperationalError(psycopg2.DatabaseError): ...
|
||||
class ProgrammingError(psycopg2.DatabaseError): ...
|
||||
class QueryCanceledError(psycopg2.OperationalError): ...
|
||||
|
||||
class QuotedString:
|
||||
adapted: Any
|
||||
buffer: Any
|
||||
encoding: Any
|
||||
def __init__(self, *args, **kwargs) -> None: ...
|
||||
def getquoted(self, *args, **kwargs): ...
|
||||
def prepare(self, *args, **kwargs): ...
|
||||
def __conform__(self, *args, **kwargs): ...
|
||||
|
||||
class ReplicationConnection(psycopg2.extensions.connection):
|
||||
autocommit: Any
|
||||
isolation_level: Any
|
||||
replication_type: Any
|
||||
reset: Any
|
||||
set_isolation_level: Any
|
||||
set_session: Any
|
||||
def __init__(self, *args, **kwargs) -> None: ...
|
||||
|
||||
class ReplicationCursor(cursor):
|
||||
feedback_timestamp: Any
|
||||
io_timestamp: Any
|
||||
wal_end: Any
|
||||
def __init__(self, *args, **kwargs) -> None: ...
|
||||
def consume_stream(self, consumer, keepalive_interval=...): ...
|
||||
def read_message(self, *args, **kwargs): ...
|
||||
def send_feedback(self, write_lsn=..., flush_lsn=..., apply_lsn=..., reply=..., force=...): ...
|
||||
def start_replication_expert(self, command, decode=..., status_interval=...): ...
|
||||
|
||||
class ReplicationMessage:
|
||||
cursor: Any
|
||||
data_size: Any
|
||||
data_start: Any
|
||||
payload: Any
|
||||
send_time: Any
|
||||
wal_end: Any
|
||||
def __init__(self, *args, **kwargs) -> None: ...
|
||||
|
||||
class TransactionRollbackError(psycopg2.OperationalError): ...
|
||||
class Warning(Exception): ...
|
||||
|
||||
class Xid:
|
||||
bqual: Any
|
||||
database: Any
|
||||
format_id: Any
|
||||
gtrid: Any
|
||||
owner: Any
|
||||
prepared: Any
|
||||
def __init__(self, *args, **kwargs) -> None: ...
|
||||
def from_string(self, *args, **kwargs): ...
|
||||
def __getitem__(self, __index): ...
|
||||
def __len__(self) -> int: ...
|
||||
|
||||
_T_cur = TypeVar("_T_cur", bound=cursor)
|
||||
|
||||
class connection:
|
||||
DataError: Any
|
||||
DatabaseError: Any
|
||||
Error: Any
|
||||
IntegrityError: Any
|
||||
InterfaceError: Any
|
||||
InternalError: Any
|
||||
NotSupportedError: Any
|
||||
OperationalError: Any
|
||||
ProgrammingError: Any
|
||||
Warning: Any
|
||||
@property
|
||||
def async_(self) -> int: ...
|
||||
autocommit: bool
|
||||
@property
|
||||
def binary_types(self) -> Any: ...
|
||||
@property
|
||||
def closed(self) -> int: ...
|
||||
cursor_factory: Callable[..., _Cursor]
|
||||
@property
|
||||
def dsn(self) -> str: ...
|
||||
@property
|
||||
def encoding(self) -> str: ...
|
||||
@property
|
||||
def info(self) -> ConnectionInfo: ...
|
||||
@property
|
||||
def isolation_level(self) -> int | None: ...
|
||||
@isolation_level.setter
|
||||
def isolation_level(self, __value: str | bytes | int | None) -> None: ...
|
||||
notices: list[Any]
|
||||
notifies: list[Any]
|
||||
@property
|
||||
def pgconn_ptr(self) -> int | None: ...
|
||||
@property
|
||||
def protocol_version(self) -> int: ...
|
||||
@property
|
||||
def deferrable(self) -> bool | None: ...
|
||||
@deferrable.setter
|
||||
def deferrable(self, __value: Literal["default"] | bool | None) -> None: ...
|
||||
@property
|
||||
def readonly(self) -> bool | None: ...
|
||||
@readonly.setter
|
||||
def readonly(self, __value: Literal["default"] | bool | None) -> None: ...
|
||||
@property
|
||||
def server_version(self) -> int: ...
|
||||
@property
|
||||
def status(self) -> int: ...
|
||||
@property
|
||||
def string_types(self) -> Any: ...
|
||||
# Really it's dsn: str, async: int = ..., async_: int = ..., but
|
||||
# that would be a syntax error.
|
||||
def __init__(self, dsn: str, *, async_: int = ...) -> None: ...
|
||||
def cancel(self) -> None: ...
|
||||
def close(self) -> None: ...
|
||||
def commit(self) -> None: ...
|
||||
@overload
|
||||
def cursor(self, name: str | bytes | None = ..., *, withhold: bool = ..., scrollable: bool | None = ...) -> _Cursor: ...
|
||||
def fileno(self) -> int: ...
|
||||
def get_backend_pid(self) -> int: ...
|
||||
def get_dsn_parameters(self) -> dict[str, str]: ...
|
||||
def get_native_connection(self): ...
|
||||
def get_parameter_status(self, parameter: str) -> str | None: ...
|
||||
def get_transaction_status(self) -> int: ...
|
||||
def isexecuting(self) -> bool: ...
|
||||
def lobject(
|
||||
self,
|
||||
oid: int = ...,
|
||||
mode: str | None = ...,
|
||||
new_oid: int = ...,
|
||||
new_file: str | None = ...,
|
||||
lobject_factory: type[lobject] = ...,
|
||||
) -> lobject: ...
|
||||
def poll(self) -> int: ...
|
||||
def reset(self) -> None: ...
|
||||
def rollback(self) -> None: ...
|
||||
def set_client_encoding(self, encoding: str) -> None: ...
|
||||
def set_isolation_level(self, level: int | None) -> None: ...
|
||||
def set_session(
|
||||
self,
|
||||
isolation_level: str | bytes | int | None = ...,
|
||||
readonly: bool | Literal["default", b"default"] | None = ...,
|
||||
deferrable: bool | Literal["default", b"default"] | None = ...,
|
||||
autocommit: bool = ...,
|
||||
) -> None: ...
|
||||
def tpc_begin(self, xid: str | bytes | Xid) -> None: ...
|
||||
def tpc_commit(self, __xid: str | bytes | Xid = ...) -> None: ...
|
||||
def tpc_prepare(self) -> None: ...
|
||||
def tpc_recover(self) -> list[Xid]: ...
|
||||
def tpc_rollback(self, __xid: str | bytes | Xid = ...) -> None: ...
|
||||
def xid(self, format_id, gtrid, bqual) -> Xid: ...
|
||||
def __enter__(self) -> Self: ...
|
||||
def __exit__(self, __type: type[BaseException] | None, __name: BaseException | None, __tb: TracebackType | None) -> None: ...
|
||||
|
||||
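The __enter__/__exit__ pair declared on connection wraps a transaction: leaving the block commits on success and rolls back on error, but does not close the connection. An illustrative sketch (not part of this diff; placeholder DSN):

# Illustrative sketch: connection and cursor as context managers.
import psycopg2

conn = psycopg2.connect("dbname=test")  # placeholder DSN
conn.set_session(autocommit=False)
with conn:                      # commits or rolls back the transaction
    with conn.cursor() as cur:  # closes the cursor
        cur.execute("SELECT current_user")
        print(cur.fetchone())
conn.close()                    # closing is still explicit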
class lobject:
|
||||
closed: Any
|
||||
mode: Any
|
||||
oid: Any
|
||||
def __init__(self, *args, **kwargs) -> None: ...
|
||||
def close(self): ...
|
||||
def export(self, filename): ...
|
||||
def read(self, size=...): ...
|
||||
def seek(self, offset, whence=...): ...
|
||||
def tell(self): ...
|
||||
def truncate(self, len=...): ...
|
||||
def unlink(self): ...
|
||||
def write(self, str): ...
|
||||
|
||||
def Date(year, month, day): ...
|
||||
def DateFromPy(*args, **kwargs): ...
|
||||
def DateFromTicks(ticks): ...
|
||||
def IntervalFromPy(*args, **kwargs): ...
|
||||
def Time(hour, minutes, seconds, tzinfo=...): ...
|
||||
def TimeFromPy(*args, **kwargs): ...
|
||||
def TimeFromTicks(ticks): ...
|
||||
def Timestamp(year, month, day, hour, minutes, seconds, tzinfo=...): ...
|
||||
def TimestampFromPy(*args, **kwargs): ...
|
||||
def TimestampFromTicks(ticks): ...
|
||||
def _connect(*args, **kwargs): ...
|
||||
def adapt(*args, **kwargs): ...
|
||||
def encrypt_password(*args, **kwargs): ...
|
||||
def get_wait_callback(*args, **kwargs): ...
|
||||
def libpq_version(*args, **kwargs): ...
|
||||
def new_array_type(oids, name, baseobj): ...
|
||||
def new_type(oids, name, castobj): ...
|
||||
def parse_dsn(dsn: str | bytes) -> dict[str, Any]: ...
|
||||
def quote_ident(value: Any, scope: connection | cursor | None) -> str: ...
|
||||
def register_type(*args, **kwargs): ...
|
||||
def set_wait_callback(_none): ...
|
||||
typings/psycopg2/_range.pyi (new file)
@@ -0,0 +1,62 @@
from _typeshed import Incomplete
from typing import Any

class Range:
    def __init__(
        self, lower: Incomplete | None = None, upper: Incomplete | None = None, bounds: str = "[)", empty: bool = False
    ) -> None: ...
    @property
    def lower(self): ...
    @property
    def upper(self): ...
    @property
    def isempty(self): ...
    @property
    def lower_inf(self): ...
    @property
    def upper_inf(self): ...
    @property
    def lower_inc(self): ...
    @property
    def upper_inc(self): ...
    def __contains__(self, x): ...
    def __bool__(self) -> bool: ...
    def __eq__(self, other): ...
    def __ne__(self, other): ...
    def __hash__(self) -> int: ...
    def __lt__(self, other): ...
    def __le__(self, other): ...
    def __gt__(self, other): ...
    def __ge__(self, other): ...

def register_range(pgrange, pyrange, conn_or_curs, globally: bool = False): ...

class RangeAdapter:
    name: Any
    adapted: Any
    def __init__(self, adapted) -> None: ...
    def __conform__(self, proto): ...
    def prepare(self, conn) -> None: ...
    def getquoted(self): ...

class RangeCaster:
    subtype_oid: Any
    typecaster: Any
    array_typecaster: Any
    def __init__(self, pgrange, pyrange, oid, subtype_oid, array_oid: Incomplete | None = None) -> None: ...
    def parse(self, s, cur: Incomplete | None = None): ...

class NumericRange(Range): ...
class DateRange(Range): ...
class DateTimeRange(Range): ...
class DateTimeTZRange(Range): ...

class NumberRangeAdapter(RangeAdapter):
    def getquoted(self): ...

int4range_caster: Any
int8range_caster: Any
numrange_caster: Any
daterange_caster: Any
tsrange_caster: Any
tstzrange_caster: Any
typings/psycopg2/errorcodes.pyi (new file)
@@ -0,0 +1,304 @@
|
||||
def lookup(code, _cache={}): ...
|
||||
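lookup() maps a five-character SQLSTATE back to the symbolic constant names defined below; a quick sketch (illustrative only, not part of this diff):

# Illustrative sketch: resolving SQLSTATE codes to their symbolic names.
from psycopg2 import errorcodes

assert errorcodes.lookup("23505") == "UNIQUE_VIOLATION"
assert errorcodes.lookup("40P01") == "DEADLOCK_DETECTED"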
|
||||
CLASS_SUCCESSFUL_COMPLETION: str
|
||||
CLASS_WARNING: str
|
||||
CLASS_NO_DATA: str
|
||||
CLASS_SQL_STATEMENT_NOT_YET_COMPLETE: str
|
||||
CLASS_CONNECTION_EXCEPTION: str
|
||||
CLASS_TRIGGERED_ACTION_EXCEPTION: str
|
||||
CLASS_FEATURE_NOT_SUPPORTED: str
|
||||
CLASS_INVALID_TRANSACTION_INITIATION: str
|
||||
CLASS_LOCATOR_EXCEPTION: str
|
||||
CLASS_INVALID_GRANTOR: str
|
||||
CLASS_INVALID_ROLE_SPECIFICATION: str
|
||||
CLASS_DIAGNOSTICS_EXCEPTION: str
|
||||
CLASS_CASE_NOT_FOUND: str
|
||||
CLASS_CARDINALITY_VIOLATION: str
|
||||
CLASS_DATA_EXCEPTION: str
|
||||
CLASS_INTEGRITY_CONSTRAINT_VIOLATION: str
|
||||
CLASS_INVALID_CURSOR_STATE: str
|
||||
CLASS_INVALID_TRANSACTION_STATE: str
|
||||
CLASS_INVALID_SQL_STATEMENT_NAME: str
|
||||
CLASS_TRIGGERED_DATA_CHANGE_VIOLATION: str
|
||||
CLASS_INVALID_AUTHORIZATION_SPECIFICATION: str
|
||||
CLASS_DEPENDENT_PRIVILEGE_DESCRIPTORS_STILL_EXIST: str
|
||||
CLASS_INVALID_TRANSACTION_TERMINATION: str
|
||||
CLASS_SQL_ROUTINE_EXCEPTION: str
|
||||
CLASS_INVALID_CURSOR_NAME: str
|
||||
CLASS_EXTERNAL_ROUTINE_EXCEPTION: str
|
||||
CLASS_EXTERNAL_ROUTINE_INVOCATION_EXCEPTION: str
|
||||
CLASS_SAVEPOINT_EXCEPTION: str
|
||||
CLASS_INVALID_CATALOG_NAME: str
|
||||
CLASS_INVALID_SCHEMA_NAME: str
|
||||
CLASS_TRANSACTION_ROLLBACK: str
|
||||
CLASS_SYNTAX_ERROR_OR_ACCESS_RULE_VIOLATION: str
|
||||
CLASS_WITH_CHECK_OPTION_VIOLATION: str
|
||||
CLASS_INSUFFICIENT_RESOURCES: str
|
||||
CLASS_PROGRAM_LIMIT_EXCEEDED: str
|
||||
CLASS_OBJECT_NOT_IN_PREREQUISITE_STATE: str
|
||||
CLASS_OPERATOR_INTERVENTION: str
|
||||
CLASS_SYSTEM_ERROR: str
|
||||
CLASS_SNAPSHOT_FAILURE: str
|
||||
CLASS_CONFIGURATION_FILE_ERROR: str
|
||||
CLASS_FOREIGN_DATA_WRAPPER_ERROR: str
|
||||
CLASS_PL_PGSQL_ERROR: str
|
||||
CLASS_INTERNAL_ERROR: str
|
||||
SUCCESSFUL_COMPLETION: str
|
||||
WARNING: str
|
||||
NULL_VALUE_ELIMINATED_IN_SET_FUNCTION: str
|
||||
STRING_DATA_RIGHT_TRUNCATION_: str
|
||||
PRIVILEGE_NOT_REVOKED: str
|
||||
PRIVILEGE_NOT_GRANTED: str
|
||||
IMPLICIT_ZERO_BIT_PADDING: str
|
||||
DYNAMIC_RESULT_SETS_RETURNED: str
|
||||
DEPRECATED_FEATURE: str
|
||||
NO_DATA: str
|
||||
NO_ADDITIONAL_DYNAMIC_RESULT_SETS_RETURNED: str
|
||||
SQL_STATEMENT_NOT_YET_COMPLETE: str
|
||||
CONNECTION_EXCEPTION: str
|
||||
SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION: str
|
||||
CONNECTION_DOES_NOT_EXIST: str
|
||||
SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION: str
|
||||
CONNECTION_FAILURE: str
|
||||
TRANSACTION_RESOLUTION_UNKNOWN: str
|
||||
PROTOCOL_VIOLATION: str
|
||||
TRIGGERED_ACTION_EXCEPTION: str
|
||||
FEATURE_NOT_SUPPORTED: str
|
||||
INVALID_TRANSACTION_INITIATION: str
|
||||
LOCATOR_EXCEPTION: str
|
||||
INVALID_LOCATOR_SPECIFICATION: str
|
||||
INVALID_GRANTOR: str
|
||||
INVALID_GRANT_OPERATION: str
|
||||
INVALID_ROLE_SPECIFICATION: str
|
||||
DIAGNOSTICS_EXCEPTION: str
|
||||
STACKED_DIAGNOSTICS_ACCESSED_WITHOUT_ACTIVE_HANDLER: str
|
||||
CASE_NOT_FOUND: str
|
||||
CARDINALITY_VIOLATION: str
|
||||
DATA_EXCEPTION: str
|
||||
STRING_DATA_RIGHT_TRUNCATION: str
|
||||
NULL_VALUE_NO_INDICATOR_PARAMETER: str
|
||||
NUMERIC_VALUE_OUT_OF_RANGE: str
|
||||
NULL_VALUE_NOT_ALLOWED_: str
|
||||
ERROR_IN_ASSIGNMENT: str
|
||||
INVALID_DATETIME_FORMAT: str
|
||||
DATETIME_FIELD_OVERFLOW: str
|
||||
INVALID_TIME_ZONE_DISPLACEMENT_VALUE: str
|
||||
ESCAPE_CHARACTER_CONFLICT: str
|
||||
INVALID_USE_OF_ESCAPE_CHARACTER: str
|
||||
INVALID_ESCAPE_OCTET: str
|
||||
ZERO_LENGTH_CHARACTER_STRING: str
|
||||
MOST_SPECIFIC_TYPE_MISMATCH: str
|
||||
SEQUENCE_GENERATOR_LIMIT_EXCEEDED: str
|
||||
NOT_AN_XML_DOCUMENT: str
|
||||
INVALID_XML_DOCUMENT: str
|
||||
INVALID_XML_CONTENT: str
|
||||
INVALID_XML_COMMENT: str
|
||||
INVALID_XML_PROCESSING_INSTRUCTION: str
|
||||
INVALID_INDICATOR_PARAMETER_VALUE: str
|
||||
SUBSTRING_ERROR: str
|
||||
DIVISION_BY_ZERO: str
|
||||
INVALID_PRECEDING_OR_FOLLOWING_SIZE: str
|
||||
INVALID_ARGUMENT_FOR_NTILE_FUNCTION: str
|
||||
INTERVAL_FIELD_OVERFLOW: str
|
||||
INVALID_ARGUMENT_FOR_NTH_VALUE_FUNCTION: str
|
||||
INVALID_CHARACTER_VALUE_FOR_CAST: str
|
||||
INVALID_ESCAPE_CHARACTER: str
|
||||
INVALID_REGULAR_EXPRESSION: str
|
||||
INVALID_ARGUMENT_FOR_LOGARITHM: str
|
||||
INVALID_ARGUMENT_FOR_POWER_FUNCTION: str
|
||||
INVALID_ARGUMENT_FOR_WIDTH_BUCKET_FUNCTION: str
|
||||
INVALID_ROW_COUNT_IN_LIMIT_CLAUSE: str
|
||||
INVALID_ROW_COUNT_IN_RESULT_OFFSET_CLAUSE: str
|
||||
INVALID_LIMIT_VALUE: str
|
||||
CHARACTER_NOT_IN_REPERTOIRE: str
|
||||
INDICATOR_OVERFLOW: str
|
||||
INVALID_PARAMETER_VALUE: str
|
||||
UNTERMINATED_C_STRING: str
|
||||
INVALID_ESCAPE_SEQUENCE: str
|
||||
STRING_DATA_LENGTH_MISMATCH: str
|
||||
TRIM_ERROR: str
|
||||
ARRAY_SUBSCRIPT_ERROR: str
|
||||
INVALID_TABLESAMPLE_REPEAT: str
|
||||
INVALID_TABLESAMPLE_ARGUMENT: str
|
||||
DUPLICATE_JSON_OBJECT_KEY_VALUE: str
|
||||
INVALID_ARGUMENT_FOR_SQL_JSON_DATETIME_FUNCTION: str
|
||||
INVALID_JSON_TEXT: str
|
||||
INVALID_SQL_JSON_SUBSCRIPT: str
|
||||
MORE_THAN_ONE_SQL_JSON_ITEM: str
|
||||
NO_SQL_JSON_ITEM: str
|
||||
NON_NUMERIC_SQL_JSON_ITEM: str
|
||||
NON_UNIQUE_KEYS_IN_A_JSON_OBJECT: str
|
||||
SINGLETON_SQL_JSON_ITEM_REQUIRED: str
|
||||
SQL_JSON_ARRAY_NOT_FOUND: str
|
||||
SQL_JSON_MEMBER_NOT_FOUND: str
|
||||
SQL_JSON_NUMBER_NOT_FOUND: str
|
||||
SQL_JSON_OBJECT_NOT_FOUND: str
|
||||
TOO_MANY_JSON_ARRAY_ELEMENTS: str
|
||||
TOO_MANY_JSON_OBJECT_MEMBERS: str
|
||||
SQL_JSON_SCALAR_REQUIRED: str
|
||||
FLOATING_POINT_EXCEPTION: str
|
||||
INVALID_TEXT_REPRESENTATION: str
|
||||
INVALID_BINARY_REPRESENTATION: str
|
||||
BAD_COPY_FILE_FORMAT: str
|
||||
UNTRANSLATABLE_CHARACTER: str
|
||||
NONSTANDARD_USE_OF_ESCAPE_CHARACTER: str
|
||||
INTEGRITY_CONSTRAINT_VIOLATION: str
|
||||
RESTRICT_VIOLATION: str
|
||||
NOT_NULL_VIOLATION: str
|
||||
FOREIGN_KEY_VIOLATION: str
|
||||
UNIQUE_VIOLATION: str
|
||||
CHECK_VIOLATION: str
|
||||
EXCLUSION_VIOLATION: str
|
||||
INVALID_CURSOR_STATE: str
|
||||
INVALID_TRANSACTION_STATE: str
|
||||
ACTIVE_SQL_TRANSACTION: str
|
||||
BRANCH_TRANSACTION_ALREADY_ACTIVE: str
|
||||
INAPPROPRIATE_ACCESS_MODE_FOR_BRANCH_TRANSACTION: str
|
||||
INAPPROPRIATE_ISOLATION_LEVEL_FOR_BRANCH_TRANSACTION: str
|
||||
NO_ACTIVE_SQL_TRANSACTION_FOR_BRANCH_TRANSACTION: str
|
||||
READ_ONLY_SQL_TRANSACTION: str
|
||||
SCHEMA_AND_DATA_STATEMENT_MIXING_NOT_SUPPORTED: str
|
||||
HELD_CURSOR_REQUIRES_SAME_ISOLATION_LEVEL: str
|
||||
NO_ACTIVE_SQL_TRANSACTION: str
|
||||
IN_FAILED_SQL_TRANSACTION: str
|
||||
IDLE_IN_TRANSACTION_SESSION_TIMEOUT: str
|
||||
INVALID_SQL_STATEMENT_NAME: str
|
||||
TRIGGERED_DATA_CHANGE_VIOLATION: str
|
||||
INVALID_AUTHORIZATION_SPECIFICATION: str
|
||||
INVALID_PASSWORD: str
|
||||
DEPENDENT_PRIVILEGE_DESCRIPTORS_STILL_EXIST: str
|
||||
DEPENDENT_OBJECTS_STILL_EXIST: str
|
||||
INVALID_TRANSACTION_TERMINATION: str
|
||||
SQL_ROUTINE_EXCEPTION: str
|
||||
MODIFYING_SQL_DATA_NOT_PERMITTED_: str
|
||||
PROHIBITED_SQL_STATEMENT_ATTEMPTED_: str
|
||||
READING_SQL_DATA_NOT_PERMITTED_: str
|
||||
FUNCTION_EXECUTED_NO_RETURN_STATEMENT: str
|
||||
INVALID_CURSOR_NAME: str
|
||||
EXTERNAL_ROUTINE_EXCEPTION: str
|
||||
CONTAINING_SQL_NOT_PERMITTED: str
|
||||
MODIFYING_SQL_DATA_NOT_PERMITTED: str
|
||||
PROHIBITED_SQL_STATEMENT_ATTEMPTED: str
|
||||
READING_SQL_DATA_NOT_PERMITTED: str
|
||||
EXTERNAL_ROUTINE_INVOCATION_EXCEPTION: str
|
||||
INVALID_SQLSTATE_RETURNED: str
|
||||
NULL_VALUE_NOT_ALLOWED: str
|
||||
TRIGGER_PROTOCOL_VIOLATED: str
|
||||
SRF_PROTOCOL_VIOLATED: str
|
||||
EVENT_TRIGGER_PROTOCOL_VIOLATED: str
|
||||
SAVEPOINT_EXCEPTION: str
|
||||
INVALID_SAVEPOINT_SPECIFICATION: str
|
||||
INVALID_CATALOG_NAME: str
|
||||
INVALID_SCHEMA_NAME: str
|
||||
TRANSACTION_ROLLBACK: str
|
||||
SERIALIZATION_FAILURE: str
|
||||
TRANSACTION_INTEGRITY_CONSTRAINT_VIOLATION: str
|
||||
STATEMENT_COMPLETION_UNKNOWN: str
|
||||
DEADLOCK_DETECTED: str
|
||||
SYNTAX_ERROR_OR_ACCESS_RULE_VIOLATION: str
|
||||
INSUFFICIENT_PRIVILEGE: str
|
||||
SYNTAX_ERROR: str
|
||||
INVALID_NAME: str
|
||||
INVALID_COLUMN_DEFINITION: str
|
||||
NAME_TOO_LONG: str
|
||||
DUPLICATE_COLUMN: str
|
||||
AMBIGUOUS_COLUMN: str
|
||||
UNDEFINED_COLUMN: str
|
||||
UNDEFINED_OBJECT: str
|
||||
DUPLICATE_OBJECT: str
|
||||
DUPLICATE_ALIAS: str
|
||||
DUPLICATE_FUNCTION: str
|
||||
AMBIGUOUS_FUNCTION: str
|
||||
GROUPING_ERROR: str
|
||||
DATATYPE_MISMATCH: str
|
||||
WRONG_OBJECT_TYPE: str
|
||||
INVALID_FOREIGN_KEY: str
|
||||
CANNOT_COERCE: str
|
||||
UNDEFINED_FUNCTION: str
|
||||
GENERATED_ALWAYS: str
|
||||
RESERVED_NAME: str
|
||||
UNDEFINED_TABLE: str
|
||||
UNDEFINED_PARAMETER: str
|
||||
DUPLICATE_CURSOR: str
|
||||
DUPLICATE_DATABASE: str
|
||||
DUPLICATE_PREPARED_STATEMENT: str
|
||||
DUPLICATE_SCHEMA: str
|
||||
DUPLICATE_TABLE: str
|
||||
AMBIGUOUS_PARAMETER: str
|
||||
AMBIGUOUS_ALIAS: str
|
||||
INVALID_COLUMN_REFERENCE: str
|
||||
INVALID_CURSOR_DEFINITION: str
|
||||
INVALID_DATABASE_DEFINITION: str
|
||||
INVALID_FUNCTION_DEFINITION: str
|
||||
INVALID_PREPARED_STATEMENT_DEFINITION: str
|
||||
INVALID_SCHEMA_DEFINITION: str
|
||||
INVALID_TABLE_DEFINITION: str
|
||||
INVALID_OBJECT_DEFINITION: str
|
||||
INDETERMINATE_DATATYPE: str
|
||||
INVALID_RECURSION: str
|
||||
WINDOWING_ERROR: str
|
||||
COLLATION_MISMATCH: str
|
||||
INDETERMINATE_COLLATION: str
|
||||
WITH_CHECK_OPTION_VIOLATION: str
|
||||
INSUFFICIENT_RESOURCES: str
|
||||
DISK_FULL: str
|
||||
OUT_OF_MEMORY: str
|
||||
TOO_MANY_CONNECTIONS: str
|
||||
CONFIGURATION_LIMIT_EXCEEDED: str
|
||||
PROGRAM_LIMIT_EXCEEDED: str
|
||||
STATEMENT_TOO_COMPLEX: str
|
||||
TOO_MANY_COLUMNS: str
|
||||
TOO_MANY_ARGUMENTS: str
|
||||
OBJECT_NOT_IN_PREREQUISITE_STATE: str
|
||||
OBJECT_IN_USE: str
|
||||
CANT_CHANGE_RUNTIME_PARAM: str
|
||||
LOCK_NOT_AVAILABLE: str
|
||||
UNSAFE_NEW_ENUM_VALUE_USAGE: str
|
||||
OPERATOR_INTERVENTION: str
|
||||
QUERY_CANCELED: str
|
||||
ADMIN_SHUTDOWN: str
|
||||
CRASH_SHUTDOWN: str
|
||||
CANNOT_CONNECT_NOW: str
|
||||
DATABASE_DROPPED: str
|
||||
SYSTEM_ERROR: str
|
||||
IO_ERROR: str
|
||||
UNDEFINED_FILE: str
|
||||
DUPLICATE_FILE: str
|
||||
SNAPSHOT_TOO_OLD: str
|
||||
CONFIG_FILE_ERROR: str
|
||||
LOCK_FILE_EXISTS: str
|
||||
FDW_ERROR: str
|
||||
FDW_OUT_OF_MEMORY: str
|
||||
FDW_DYNAMIC_PARAMETER_VALUE_NEEDED: str
|
||||
FDW_INVALID_DATA_TYPE: str
|
||||
FDW_COLUMN_NAME_NOT_FOUND: str
|
||||
FDW_INVALID_DATA_TYPE_DESCRIPTORS: str
|
||||
FDW_INVALID_COLUMN_NAME: str
|
||||
FDW_INVALID_COLUMN_NUMBER: str
|
||||
FDW_INVALID_USE_OF_NULL_POINTER: str
|
||||
FDW_INVALID_STRING_FORMAT: str
|
||||
FDW_INVALID_HANDLE: str
|
||||
FDW_INVALID_OPTION_INDEX: str
|
||||
FDW_INVALID_OPTION_NAME: str
|
||||
FDW_OPTION_NAME_NOT_FOUND: str
|
||||
FDW_REPLY_HANDLE: str
|
||||
FDW_UNABLE_TO_CREATE_EXECUTION: str
|
||||
FDW_UNABLE_TO_CREATE_REPLY: str
|
||||
FDW_UNABLE_TO_ESTABLISH_CONNECTION: str
|
||||
FDW_NO_SCHEMAS: str
|
||||
FDW_SCHEMA_NOT_FOUND: str
|
||||
FDW_TABLE_NOT_FOUND: str
|
||||
FDW_FUNCTION_SEQUENCE_ERROR: str
|
||||
FDW_TOO_MANY_HANDLES: str
|
||||
FDW_INCONSISTENT_DESCRIPTOR_INFORMATION: str
|
||||
FDW_INVALID_ATTRIBUTE_VALUE: str
|
||||
FDW_INVALID_STRING_LENGTH_OR_BUFFER_LENGTH: str
|
||||
FDW_INVALID_DESCRIPTOR_FIELD_IDENTIFIER: str
|
||||
PLPGSQL_ERROR: str
|
||||
RAISE_EXCEPTION: str
|
||||
NO_DATA_FOUND: str
|
||||
TOO_MANY_ROWS: str
|
||||
ASSERT_FAILURE: str
|
||||
INTERNAL_ERROR: str
|
||||
DATA_CORRUPTED: str
|
||||
INDEX_CORRUPTED: str
|
||||
typings/psycopg2/errors.pyi (new file)
@@ -0,0 +1,263 @@
|
||||
from psycopg2._psycopg import Error as Error, Warning as Warning
|
||||
|
||||
class DatabaseError(Error): ...
|
||||
class InterfaceError(Error): ...
|
||||
class DataError(DatabaseError): ...
|
||||
class DiagnosticsException(DatabaseError): ...
|
||||
class IntegrityError(DatabaseError): ...
|
||||
class InternalError(DatabaseError): ...
|
||||
class InvalidGrantOperation(DatabaseError): ...
|
||||
class InvalidGrantor(DatabaseError): ...
|
||||
class InvalidLocatorSpecification(DatabaseError): ...
|
||||
class InvalidRoleSpecification(DatabaseError): ...
|
||||
class InvalidTransactionInitiation(DatabaseError): ...
|
||||
class LocatorException(DatabaseError): ...
|
||||
class NoAdditionalDynamicResultSetsReturned(DatabaseError): ...
|
||||
class NoData(DatabaseError): ...
|
||||
class NotSupportedError(DatabaseError): ...
|
||||
class OperationalError(DatabaseError): ...
|
||||
class ProgrammingError(DatabaseError): ...
|
||||
class SnapshotTooOld(DatabaseError): ...
|
||||
class SqlStatementNotYetComplete(DatabaseError): ...
|
||||
class StackedDiagnosticsAccessedWithoutActiveHandler(DatabaseError): ...
|
||||
class TriggeredActionException(DatabaseError): ...
|
||||
class ActiveSqlTransaction(InternalError): ...
|
||||
class AdminShutdown(OperationalError): ...
|
||||
class AmbiguousAlias(ProgrammingError): ...
|
||||
class AmbiguousColumn(ProgrammingError): ...
|
||||
class AmbiguousFunction(ProgrammingError): ...
|
||||
class AmbiguousParameter(ProgrammingError): ...
|
||||
class ArraySubscriptError(DataError): ...
|
||||
class AssertFailure(InternalError): ...
|
||||
class BadCopyFileFormat(DataError): ...
|
||||
class BranchTransactionAlreadyActive(InternalError): ...
|
||||
class CannotCoerce(ProgrammingError): ...
|
||||
class CannotConnectNow(OperationalError): ...
|
||||
class CantChangeRuntimeParam(OperationalError): ...
|
||||
class CardinalityViolation(ProgrammingError): ...
|
||||
class CaseNotFound(ProgrammingError): ...
|
||||
class CharacterNotInRepertoire(DataError): ...
|
||||
class CheckViolation(IntegrityError): ...
|
||||
class CollationMismatch(ProgrammingError): ...
|
||||
class ConfigFileError(InternalError): ...
|
||||
class ConfigurationLimitExceeded(OperationalError): ...
|
||||
class ConnectionDoesNotExist(OperationalError): ...
|
||||
class ConnectionException(OperationalError): ...
|
||||
class ConnectionFailure(OperationalError): ...
|
||||
class ContainingSqlNotPermitted(InternalError): ...
|
||||
class CrashShutdown(OperationalError): ...
|
||||
class DataCorrupted(InternalError): ...
|
||||
class DataException(DataError): ...
|
||||
class DatabaseDropped(OperationalError): ...
|
||||
class DatatypeMismatch(ProgrammingError): ...
|
||||
class DatetimeFieldOverflow(DataError): ...
|
||||
class DependentObjectsStillExist(InternalError): ...
|
||||
class DependentPrivilegeDescriptorsStillExist(InternalError): ...
|
||||
class DiskFull(OperationalError): ...
|
||||
class DivisionByZero(DataError): ...
|
||||
class DuplicateAlias(ProgrammingError): ...
|
||||
class DuplicateColumn(ProgrammingError): ...
|
||||
class DuplicateCursor(ProgrammingError): ...
|
||||
class DuplicateDatabase(ProgrammingError): ...
|
||||
class DuplicateFile(OperationalError): ...
|
||||
class DuplicateFunction(ProgrammingError): ...
|
||||
class DuplicateJsonObjectKeyValue(DataError): ...
|
||||
class DuplicateObject(ProgrammingError): ...
|
||||
class DuplicatePreparedStatement(ProgrammingError): ...
|
||||
class DuplicateSchema(ProgrammingError): ...
|
||||
class DuplicateTable(ProgrammingError): ...
|
||||
class ErrorInAssignment(DataError): ...
|
||||
class EscapeCharacterConflict(DataError): ...
|
||||
class EventTriggerProtocolViolated(InternalError): ...
|
||||
class ExclusionViolation(IntegrityError): ...
|
||||
class ExternalRoutineException(InternalError): ...
|
||||
class ExternalRoutineInvocationException(InternalError): ...
|
||||
class FdwColumnNameNotFound(OperationalError): ...
|
||||
class FdwDynamicParameterValueNeeded(OperationalError): ...
|
||||
class FdwError(OperationalError): ...
|
||||
class FdwFunctionSequenceError(OperationalError): ...
|
||||
class FdwInconsistentDescriptorInformation(OperationalError): ...
|
||||
class FdwInvalidAttributeValue(OperationalError): ...
|
||||
class FdwInvalidColumnName(OperationalError): ...
|
||||
class FdwInvalidColumnNumber(OperationalError): ...
|
||||
class FdwInvalidDataType(OperationalError): ...
|
||||
class FdwInvalidDataTypeDescriptors(OperationalError): ...
|
||||
class FdwInvalidDescriptorFieldIdentifier(OperationalError): ...
|
||||
class FdwInvalidHandle(OperationalError): ...
|
||||
class FdwInvalidOptionIndex(OperationalError): ...
|
||||
class FdwInvalidOptionName(OperationalError): ...
|
||||
class FdwInvalidStringFormat(OperationalError): ...
|
||||
class FdwInvalidStringLengthOrBufferLength(OperationalError): ...
|
||||
class FdwInvalidUseOfNullPointer(OperationalError): ...
|
||||
class FdwNoSchemas(OperationalError): ...
|
||||
class FdwOptionNameNotFound(OperationalError): ...
|
||||
class FdwOutOfMemory(OperationalError): ...
|
||||
class FdwReplyHandle(OperationalError): ...
|
||||
class FdwSchemaNotFound(OperationalError): ...
|
||||
class FdwTableNotFound(OperationalError): ...
|
||||
class FdwTooManyHandles(OperationalError): ...
|
||||
class FdwUnableToCreateExecution(OperationalError): ...
|
||||
class FdwUnableToCreateReply(OperationalError): ...
|
||||
class FdwUnableToEstablishConnection(OperationalError): ...
|
||||
class FeatureNotSupported(NotSupportedError): ...
|
||||
class FloatingPointException(DataError): ...
|
||||
class ForeignKeyViolation(IntegrityError): ...
|
||||
class FunctionExecutedNoReturnStatement(InternalError): ...
|
||||
class GeneratedAlways(ProgrammingError): ...
|
||||
class GroupingError(ProgrammingError): ...
|
||||
class HeldCursorRequiresSameIsolationLevel(InternalError): ...
|
||||
class IdleInTransactionSessionTimeout(InternalError): ...
|
||||
class InFailedSqlTransaction(InternalError): ...
|
||||
class InappropriateAccessModeForBranchTransaction(InternalError): ...
|
||||
class InappropriateIsolationLevelForBranchTransaction(InternalError): ...
|
||||
class IndeterminateCollation(ProgrammingError): ...
|
||||
class IndeterminateDatatype(ProgrammingError): ...
|
||||
class IndexCorrupted(InternalError): ...
|
||||
class IndicatorOverflow(DataError): ...
|
||||
class InsufficientPrivilege(ProgrammingError): ...
|
||||
class InsufficientResources(OperationalError): ...
|
||||
class IntegrityConstraintViolation(IntegrityError): ...
|
||||
class InternalError_(InternalError): ...
|
||||
class IntervalFieldOverflow(DataError): ...
|
||||
class InvalidArgumentForLogarithm(DataError): ...
|
||||
class InvalidArgumentForNthValueFunction(DataError): ...
|
||||
class InvalidArgumentForNtileFunction(DataError): ...
|
||||
class InvalidArgumentForPowerFunction(DataError): ...
|
||||
class InvalidArgumentForSqlJsonDatetimeFunction(DataError): ...
|
||||
class InvalidArgumentForWidthBucketFunction(DataError): ...
|
||||
class InvalidAuthorizationSpecification(OperationalError): ...
|
||||
class InvalidBinaryRepresentation(DataError): ...
|
||||
class InvalidCatalogName(ProgrammingError): ...
|
||||
class InvalidCharacterValueForCast(DataError): ...
|
||||
class InvalidColumnDefinition(ProgrammingError): ...
|
||||
class InvalidColumnReference(ProgrammingError): ...
|
||||
class InvalidCursorDefinition(ProgrammingError): ...
|
||||
class InvalidCursorName(OperationalError): ...
|
||||
class InvalidCursorState(InternalError): ...
|
||||
class InvalidDatabaseDefinition(ProgrammingError): ...
|
||||
class InvalidDatetimeFormat(DataError): ...
|
||||
class InvalidEscapeCharacter(DataError): ...
|
||||
class InvalidEscapeOctet(DataError): ...
|
||||
class InvalidEscapeSequence(DataError): ...
|
||||
class InvalidForeignKey(ProgrammingError): ...
|
||||
class InvalidFunctionDefinition(ProgrammingError): ...
|
||||
class InvalidIndicatorParameterValue(DataError): ...
|
||||
class InvalidJsonText(DataError): ...
|
||||
class InvalidName(ProgrammingError): ...
|
||||
class InvalidObjectDefinition(ProgrammingError): ...
|
||||
class InvalidParameterValue(DataError): ...
|
||||
class InvalidPassword(OperationalError): ...
|
||||
class InvalidPrecedingOrFollowingSize(DataError): ...
|
||||
class InvalidPreparedStatementDefinition(ProgrammingError): ...
|
||||
class InvalidRecursion(ProgrammingError): ...
|
||||
class InvalidRegularExpression(DataError): ...
|
||||
class InvalidRowCountInLimitClause(DataError): ...
|
||||
class InvalidRowCountInResultOffsetClause(DataError): ...
|
||||
class InvalidSavepointSpecification(InternalError): ...
|
||||
class InvalidSchemaDefinition(ProgrammingError): ...
|
||||
class InvalidSchemaName(ProgrammingError): ...
|
||||
class InvalidSqlJsonSubscript(DataError): ...
|
||||
class InvalidSqlStatementName(OperationalError): ...
|
||||
class InvalidSqlstateReturned(InternalError): ...
|
||||
class InvalidTableDefinition(ProgrammingError): ...
|
||||
class InvalidTablesampleArgument(DataError): ...
|
||||
class InvalidTablesampleRepeat(DataError): ...
|
||||
class InvalidTextRepresentation(DataError): ...
|
||||
class InvalidTimeZoneDisplacementValue(DataError): ...
|
||||
class InvalidTransactionState(InternalError): ...
|
||||
class InvalidTransactionTermination(InternalError): ...
|
||||
class InvalidUseOfEscapeCharacter(DataError): ...
|
||||
class InvalidXmlComment(DataError): ...
|
||||
class InvalidXmlContent(DataError): ...
|
||||
class InvalidXmlDocument(DataError): ...
|
||||
class InvalidXmlProcessingInstruction(DataError): ...
|
||||
class IoError(OperationalError): ...
|
||||
class LockFileExists(InternalError): ...
|
||||
class LockNotAvailable(OperationalError): ...
|
||||
class ModifyingSqlDataNotPermitted(InternalError): ...
|
||||
class ModifyingSqlDataNotPermittedExt(InternalError): ...
|
||||
class MoreThanOneSqlJsonItem(DataError): ...
|
||||
class MostSpecificTypeMismatch(DataError): ...
|
||||
class NameTooLong(ProgrammingError): ...
|
||||
class NoActiveSqlTransaction(InternalError): ...
|
||||
class NoActiveSqlTransactionForBranchTransaction(InternalError): ...
|
||||
class NoDataFound(InternalError): ...
|
||||
class NoSqlJsonItem(DataError): ...
|
||||
class NonNumericSqlJsonItem(DataError): ...
|
||||
class NonUniqueKeysInAJsonObject(DataError): ...
|
||||
class NonstandardUseOfEscapeCharacter(DataError): ...
|
||||
class NotAnXmlDocument(DataError): ...
|
||||
class NotNullViolation(IntegrityError): ...
|
||||
class NullValueNoIndicatorParameter(DataError): ...
|
||||
class NullValueNotAllowed(DataError): ...
|
||||
class NullValueNotAllowedExt(InternalError): ...
|
||||
class NumericValueOutOfRange(DataError): ...
|
||||
class ObjectInUse(OperationalError): ...
|
||||
class ObjectNotInPrerequisiteState(OperationalError): ...
|
||||
class OperatorIntervention(OperationalError): ...
|
||||
class OutOfMemory(OperationalError): ...
|
||||
class PlpgsqlError(InternalError): ...
|
||||
class ProgramLimitExceeded(OperationalError): ...
|
||||
class ProhibitedSqlStatementAttempted(InternalError): ...
|
||||
class ProhibitedSqlStatementAttemptedExt(InternalError): ...
|
||||
class ProtocolViolation(OperationalError): ...
|
||||
class QueryCanceledError(OperationalError): ...
|
||||
class RaiseException(InternalError): ...
|
||||
class ReadOnlySqlTransaction(InternalError): ...
|
||||
class ReadingSqlDataNotPermitted(InternalError): ...
|
||||
class ReadingSqlDataNotPermittedExt(InternalError): ...
|
||||
class ReservedName(ProgrammingError): ...
|
||||
class RestrictViolation(IntegrityError): ...
|
||||
class SavepointException(InternalError): ...
|
||||
class SchemaAndDataStatementMixingNotSupported(InternalError): ...
|
||||
class SequenceGeneratorLimitExceeded(DataError): ...
|
||||
class SingletonSqlJsonItemRequired(DataError): ...
|
||||
class SqlJsonArrayNotFound(DataError): ...
|
||||
class SqlJsonMemberNotFound(DataError): ...
|
||||
class SqlJsonNumberNotFound(DataError): ...
|
||||
class SqlJsonObjectNotFound(DataError): ...
|
||||
class SqlJsonScalarRequired(DataError): ...
|
||||
class SqlRoutineException(InternalError): ...
|
||||
class SqlclientUnableToEstablishSqlconnection(OperationalError): ...
|
||||
class SqlserverRejectedEstablishmentOfSqlconnection(OperationalError): ...
|
||||
class SrfProtocolViolated(InternalError): ...
|
||||
class StatementTooComplex(OperationalError): ...
|
||||
class StringDataLengthMismatch(DataError): ...
|
||||
class StringDataRightTruncation(DataError): ...
|
||||
class SubstringError(DataError): ...
|
||||
class SyntaxError(ProgrammingError): ...
|
||||
class SyntaxErrorOrAccessRuleViolation(ProgrammingError): ...
|
||||
class SystemError(OperationalError): ...
|
||||
class TooManyArguments(OperationalError): ...
|
||||
class TooManyColumns(OperationalError): ...
|
||||
class TooManyConnections(OperationalError): ...
|
||||
class TooManyJsonArrayElements(DataError): ...
|
||||
class TooManyJsonObjectMembers(DataError): ...
|
||||
class TooManyRows(InternalError): ...
|
||||
class TransactionResolutionUnknown(OperationalError): ...
|
||||
class TransactionRollbackError(OperationalError): ...
|
||||
class TriggerProtocolViolated(InternalError): ...
|
||||
class TriggeredDataChangeViolation(OperationalError): ...
|
||||
class TrimError(DataError): ...
|
||||
class UndefinedColumn(ProgrammingError): ...
|
||||
class UndefinedFile(OperationalError): ...
|
||||
class UndefinedFunction(ProgrammingError): ...
|
||||
class UndefinedObject(ProgrammingError): ...
|
||||
class UndefinedParameter(ProgrammingError): ...
|
||||
class UndefinedTable(ProgrammingError): ...
|
||||
class UniqueViolation(IntegrityError): ...
|
||||
class UnsafeNewEnumValueUsage(OperationalError): ...
|
||||
class UnterminatedCString(DataError): ...
|
||||
class UntranslatableCharacter(DataError): ...
|
||||
class WindowingError(ProgrammingError): ...
|
||||
class WithCheckOptionViolation(ProgrammingError): ...
|
||||
class WrongObjectType(ProgrammingError): ...
|
||||
class ZeroLengthCharacterString(DataError): ...
|
||||
class DeadlockDetected(TransactionRollbackError): ...
|
||||
class QueryCanceled(QueryCanceledError): ...
|
||||
class SerializationFailure(TransactionRollbackError): ...
|
||||
class StatementCompletionUnknown(TransactionRollbackError): ...
|
||||
class TransactionIntegrityConstraintViolation(TransactionRollbackError): ...
|
||||
class TransactionRollback(TransactionRollbackError): ...
|
||||
|
||||
def lookup(code): ...
|
||||
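These exception classes let callers catch a specific SQLSTATE instead of inspecting pgcode by hand. A hedged sketch (illustrative only; the DSN and table t are placeholders):

# Illustrative sketch: catching a specific SQLSTATE as a typed exception.
import psycopg2
from psycopg2 import errors

conn = psycopg2.connect("dbname=test")  # placeholder DSN
with conn.cursor() as cur:
    try:
        cur.execute("INSERT INTO t (id) VALUES (1)")  # hypothetical table
        conn.commit()
    except errors.UniqueViolation:
        conn.rollback()
conn.close()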
typings/psycopg2/extensions.pyi (new file)
@@ -0,0 +1,117 @@
|
||||
from _typeshed import Incomplete
|
||||
from typing import Any
|
||||
|
||||
from psycopg2._psycopg import (
|
||||
BINARYARRAY as BINARYARRAY,
|
||||
BOOLEAN as BOOLEAN,
|
||||
BOOLEANARRAY as BOOLEANARRAY,
|
||||
BYTES as BYTES,
|
||||
BYTESARRAY as BYTESARRAY,
|
||||
DATE as DATE,
|
||||
DATEARRAY as DATEARRAY,
|
||||
DATETIMEARRAY as DATETIMEARRAY,
|
||||
DECIMAL as DECIMAL,
|
||||
DECIMALARRAY as DECIMALARRAY,
|
||||
FLOAT as FLOAT,
|
||||
FLOATARRAY as FLOATARRAY,
|
||||
INTEGER as INTEGER,
|
||||
INTEGERARRAY as INTEGERARRAY,
|
||||
INTERVAL as INTERVAL,
|
||||
INTERVALARRAY as INTERVALARRAY,
|
||||
LONGINTEGER as LONGINTEGER,
|
||||
LONGINTEGERARRAY as LONGINTEGERARRAY,
|
||||
PYDATE as PYDATE,
|
||||
PYDATEARRAY as PYDATEARRAY,
|
||||
PYDATETIME as PYDATETIME,
|
||||
PYDATETIMEARRAY as PYDATETIMEARRAY,
|
||||
PYDATETIMETZ as PYDATETIMETZ,
|
||||
PYDATETIMETZARRAY as PYDATETIMETZARRAY,
|
||||
PYINTERVAL as PYINTERVAL,
|
||||
PYINTERVALARRAY as PYINTERVALARRAY,
|
||||
PYTIME as PYTIME,
|
||||
PYTIMEARRAY as PYTIMEARRAY,
|
||||
ROWIDARRAY as ROWIDARRAY,
|
||||
STRINGARRAY as STRINGARRAY,
|
||||
TIME as TIME,
|
||||
TIMEARRAY as TIMEARRAY,
|
||||
UNICODE as UNICODE,
|
||||
UNICODEARRAY as UNICODEARRAY,
|
||||
AsIs as AsIs,
|
||||
Binary as Binary,
|
||||
Boolean as Boolean,
|
||||
Column as Column,
|
||||
ConnectionInfo as ConnectionInfo,
|
||||
DateFromPy as DateFromPy,
|
||||
Diagnostics as Diagnostics,
|
||||
Float as Float,
|
||||
Int as Int,
|
||||
IntervalFromPy as IntervalFromPy,
|
||||
ISQLQuote as ISQLQuote,
|
||||
Notify as Notify,
|
||||
QueryCanceledError as QueryCanceledError,
|
||||
QuotedString as QuotedString,
|
||||
TimeFromPy as TimeFromPy,
|
||||
TimestampFromPy as TimestampFromPy,
|
||||
TransactionRollbackError as TransactionRollbackError,
|
||||
Xid as Xid,
|
||||
adapt as adapt,
|
||||
adapters as adapters,
|
||||
binary_types as binary_types,
|
||||
connection as connection,
|
||||
cursor as cursor,
|
||||
encodings as encodings,
|
||||
encrypt_password as encrypt_password,
|
||||
get_wait_callback as get_wait_callback,
|
||||
libpq_version as libpq_version,
|
||||
lobject as lobject,
|
||||
new_array_type as new_array_type,
|
||||
new_type as new_type,
|
||||
parse_dsn as parse_dsn,
|
||||
quote_ident as quote_ident,
|
||||
register_type as register_type,
|
||||
set_wait_callback as set_wait_callback,
|
||||
string_types as string_types,
|
||||
)
|
||||
|
||||
ISOLATION_LEVEL_AUTOCOMMIT: int
|
||||
ISOLATION_LEVEL_READ_UNCOMMITTED: int
|
||||
ISOLATION_LEVEL_READ_COMMITTED: int
|
||||
ISOLATION_LEVEL_REPEATABLE_READ: int
|
||||
ISOLATION_LEVEL_SERIALIZABLE: int
|
||||
ISOLATION_LEVEL_DEFAULT: Any
|
||||
STATUS_SETUP: int
|
||||
STATUS_READY: int
|
||||
STATUS_BEGIN: int
|
||||
STATUS_SYNC: int
|
||||
STATUS_ASYNC: int
|
||||
STATUS_PREPARED: int
|
||||
STATUS_IN_TRANSACTION: int
|
||||
POLL_OK: int
|
||||
POLL_READ: int
|
||||
POLL_WRITE: int
|
||||
POLL_ERROR: int
|
||||
TRANSACTION_STATUS_IDLE: int
|
||||
TRANSACTION_STATUS_ACTIVE: int
|
||||
TRANSACTION_STATUS_INTRANS: int
|
||||
TRANSACTION_STATUS_INERROR: int
|
||||
TRANSACTION_STATUS_UNKNOWN: int
|
||||
|
||||
def register_adapter(typ, callable) -> None: ...
|
||||
|
||||
class SQL_IN:
|
||||
def __init__(self, seq) -> None: ...
|
||||
def prepare(self, conn) -> None: ...
|
||||
def getquoted(self): ...
|
||||
|
||||
class NoneAdapter:
|
||||
def __init__(self, obj) -> None: ...
|
||||
def getquoted(self, _null: bytes = b"NULL"): ...
|
||||
|
||||
def make_dsn(dsn: Incomplete | None = None, **kwargs): ...
|
||||
|
||||
JSON: Any
|
||||
JSONARRAY: Any
|
||||
JSONB: Any
|
||||
JSONBARRAY: Any
|
||||
|
||||
def adapt(obj: Any) -> ISQLQuote: ...
|
||||
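A short sketch of how the re-exported names above are commonly used: the isolation-level constants feed set_isolation_level(), and register_adapter() teaches psycopg2 how to quote a custom type (illustrative only; Point is a made-up class, the DSN is a placeholder):

# Illustrative sketch: isolation levels and a custom adapter.
import psycopg2
from psycopg2.extensions import ISOLATION_LEVEL_SERIALIZABLE, AsIs, register_adapter

class Point:
    def __init__(self, x: float, y: float) -> None:
        self.x, self.y = x, y

# render Point as a point literal, e.g. '(1.0, 2.0)'
register_adapter(Point, lambda p: AsIs("'(%s, %s)'" % (p.x, p.y)))

conn = psycopg2.connect("dbname=test")  # placeholder DSN
conn.set_isolation_level(ISOLATION_LEVEL_SERIALIZABLE)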
typings/psycopg2/extras.pyi (new file)
@@ -0,0 +1,240 @@
|
||||
from _typeshed import Incomplete
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Callable
|
||||
from typing import Any, NamedTuple, TypeVar, overload
|
||||
|
||||
from psycopg2._ipaddress import register_ipaddress as register_ipaddress
|
||||
from psycopg2._json import (
|
||||
Json as Json,
|
||||
register_default_json as register_default_json,
|
||||
register_default_jsonb as register_default_jsonb,
|
||||
register_json as register_json,
|
||||
)
|
||||
from psycopg2._psycopg import (
|
||||
REPLICATION_LOGICAL as REPLICATION_LOGICAL,
|
||||
REPLICATION_PHYSICAL as REPLICATION_PHYSICAL,
|
||||
ReplicationConnection as _replicationConnection,
|
||||
ReplicationCursor as _replicationCursor,
|
||||
ReplicationMessage as ReplicationMessage,
|
||||
)
|
||||
from psycopg2._range import (
|
||||
DateRange as DateRange,
|
||||
DateTimeRange as DateTimeRange,
|
||||
DateTimeTZRange as DateTimeTZRange,
|
||||
NumericRange as NumericRange,
|
||||
Range as Range,
|
||||
RangeAdapter as RangeAdapter,
|
||||
RangeCaster as RangeCaster,
|
||||
register_range as register_range,
|
||||
)
|
||||
|
||||
from .extensions import connection as _connection, cursor as _cursor, quote_ident as quote_ident
|
||||
|
||||
_T_cur = TypeVar("_T_cur", bound=_cursor)
|
||||
|
||||
class DictCursorBase(_cursor):
|
||||
def __init__(self, *args, **kwargs) -> None: ...
|
||||
|
||||
class DictConnection(_connection):
|
||||
@overload
|
||||
def cursor(self, name: str | bytes | None = ..., *, withhold: bool = ..., scrollable: bool | None = ...) -> DictCursor: ...
|
||||
@overload
|
||||
def cursor(
|
||||
self,
|
||||
name: str | bytes | None = ...,
|
||||
*,
|
||||
cursor_factory: Callable[..., _T_cur],
|
||||
withhold: bool = ...,
|
||||
scrollable: bool | None = ...,
|
||||
) -> _T_cur: ...
|
||||
@overload
|
||||
def cursor(
|
||||
self, name: str | bytes | None, cursor_factory: Callable[..., _T_cur], withhold: bool = ..., scrollable: bool | None = ...
|
||||
) -> _T_cur: ...
|
||||
|
||||
class DictCursor(DictCursorBase):
|
||||
def __init__(self, *args, **kwargs) -> None: ...
|
||||
index: Any
|
||||
def execute(self, query, vars: Incomplete | None = None): ...
|
||||
def callproc(self, procname, vars: Incomplete | None = None): ...
|
||||
def fetchone(self) -> DictRow | None: ... # type: ignore[override]
|
||||
def fetchmany(self, size: int | None = None) -> list[DictRow]: ... # type: ignore[override]
|
||||
def fetchall(self) -> list[DictRow]: ... # type: ignore[override]
|
||||
def __next__(self) -> DictRow: ... # type: ignore[override]
|
||||
|
||||
class DictRow(list[Any]):
|
||||
def __init__(self, cursor) -> None: ...
|
||||
def __getitem__(self, x): ...
|
||||
def __setitem__(self, x, v) -> None: ...
|
||||
def items(self): ...
|
||||
def keys(self): ...
|
||||
def values(self): ...
|
||||
def get(self, x, default: Incomplete | None = None): ...
|
||||
def copy(self): ...
|
||||
def __contains__(self, x): ...
|
||||
def __reduce__(self): ...
|
||||
|
||||
class RealDictConnection(_connection):
|
||||
@overload
|
||||
def cursor(
|
||||
self, name: str | bytes | None = ..., *, withhold: bool = ..., scrollable: bool | None = ...
|
||||
) -> RealDictCursor: ...
|
||||
@overload
|
||||
def cursor(
|
||||
self,
|
||||
name: str | bytes | None = ...,
|
||||
*,
|
||||
cursor_factory: Callable[..., _T_cur],
|
||||
withhold: bool = ...,
|
||||
scrollable: bool | None = ...,
|
||||
) -> _T_cur: ...
|
||||
@overload
|
||||
def cursor(
|
||||
self, name: str | bytes | None, cursor_factory: Callable[..., _T_cur], withhold: bool = ..., scrollable: bool | None = ...
|
||||
) -> _T_cur: ...
|
||||
|
||||
class RealDictCursor(DictCursorBase):
|
||||
def __init__(self, *args, **kwargs) -> None: ...
|
||||
column_mapping: Any
|
||||
def execute(self, query, vars: Incomplete | None = None): ...
|
||||
def callproc(self, procname, vars: Incomplete | None = None): ...
|
||||
def fetchone(self) -> RealDictRow | None: ... # type: ignore[override]
|
||||
def fetchmany(self, size: int | None = None) -> list[RealDictRow]: ... # type: ignore[override]
|
||||
def fetchall(self) -> list[RealDictRow]: ... # type: ignore[override]
|
||||
def __next__(self) -> RealDictRow: ... # type: ignore[override]
|
||||
|
||||
class RealDictRow(OrderedDict[Any, Any]):
|
||||
def __init__(self, *args, **kwargs) -> None: ...
|
||||
def __setitem__(self, key, value) -> None: ...
|
||||
|
||||
class NamedTupleConnection(_connection):
|
||||
@overload
|
||||
def cursor(
|
||||
self, name: str | bytes | None = ..., *, withhold: bool = ..., scrollable: bool | None = ...
|
||||
) -> NamedTupleCursor: ...
|
||||
@overload
|
||||
def cursor(
|
||||
self,
|
||||
name: str | bytes | None = ...,
|
||||
*,
|
||||
cursor_factory: Callable[..., _T_cur],
|
||||
withhold: bool = ...,
|
||||
scrollable: bool | None = ...,
|
||||
) -> _T_cur: ...
|
||||
@overload
|
||||
def cursor(
|
||||
self, name: str | bytes | None, cursor_factory: Callable[..., _T_cur], withhold: bool = ..., scrollable: bool | None = ...
|
||||
) -> _T_cur: ...
|
||||
|
||||
class NamedTupleCursor(_cursor):
|
||||
Record: Any
|
||||
MAX_CACHE: int
|
||||
def execute(self, query, vars: Incomplete | None = None): ...
|
||||
def executemany(self, query, vars): ...
|
||||
def callproc(self, procname, vars: Incomplete | None = None): ...
|
||||
def fetchone(self) -> NamedTuple | None: ...
|
||||
def fetchmany(self, size: int | None = None) -> list[NamedTuple]: ... # type: ignore[override]
|
||||
def fetchall(self) -> list[NamedTuple]: ... # type: ignore[override]
|
||||
def __next__(self) -> NamedTuple: ...
|
||||
|
||||
class LoggingConnection(_connection):
|
||||
log: Any
|
||||
def initialize(self, logobj) -> None: ...
|
||||
def filter(self, msg, curs): ...
|
||||
def cursor(self, *args, **kwargs): ...
|
||||
|
||||
class LoggingCursor(_cursor):
|
||||
def execute(self, query, vars: Incomplete | None = None): ...
|
||||
def callproc(self, procname, vars: Incomplete | None = None): ...
|
||||
|
||||
class MinTimeLoggingConnection(LoggingConnection):
|
||||
def initialize(self, logobj, mintime: int = 0) -> None: ...
|
||||
def filter(self, msg, curs): ...
|
||||
def cursor(self, *args, **kwargs): ...
|
||||
|
||||
class MinTimeLoggingCursor(LoggingCursor):
|
||||
timestamp: Any
|
||||
def execute(self, query, vars: Incomplete | None = None): ...
|
||||
def callproc(self, procname, vars: Incomplete | None = None): ...
|
||||
|
||||
class LogicalReplicationConnection(_replicationConnection):
|
||||
def __init__(self, *args, **kwargs) -> None: ...
|
||||
|
||||
class PhysicalReplicationConnection(_replicationConnection):
|
||||
def __init__(self, *args, **kwargs) -> None: ...
|
||||
|
||||
class StopReplication(Exception): ...
|
||||
|
||||
class ReplicationCursor(_replicationCursor):
|
||||
def create_replication_slot(
|
||||
self, slot_name, slot_type: Incomplete | None = None, output_plugin: Incomplete | None = None
|
||||
) -> None: ...
|
||||
def drop_replication_slot(self, slot_name) -> None: ...
|
||||
def start_replication(
|
||||
self,
|
||||
slot_name: Incomplete | None = None,
|
||||
slot_type: Incomplete | None = None,
|
||||
start_lsn: int = 0,
|
||||
timeline: int = 0,
|
||||
options: Incomplete | None = None,
|
||||
decode: bool = False,
|
||||
status_interval: int = 10,
|
||||
) -> None: ...
|
||||
def fileno(self): ...
|
||||
|
||||
class UUID_adapter:
|
||||
def __init__(self, uuid) -> None: ...
|
||||
def __conform__(self, proto): ...
|
||||
def getquoted(self): ...
|
||||
|
||||
def register_uuid(oids: Incomplete | None = None, conn_or_curs: Incomplete | None = None): ...
|
||||
|
||||
class Inet:
|
||||
addr: Any
|
||||
def __init__(self, addr) -> None: ...
|
||||
def prepare(self, conn) -> None: ...
|
||||
def getquoted(self): ...
|
||||
def __conform__(self, proto): ...
|
||||
|
||||
def register_inet(oid: Incomplete | None = None, conn_or_curs: Incomplete | None = None): ...
|
||||
def wait_select(conn) -> None: ...
|
||||
|
||||
class HstoreAdapter:
|
||||
wrapped: Any
|
||||
def __init__(self, wrapped) -> None: ...
|
||||
conn: Any
|
||||
getquoted: Any
|
||||
def prepare(self, conn) -> None: ...
|
||||
@classmethod
|
||||
def parse(cls, s, cur, _bsdec=...): ...
|
||||
@classmethod
|
||||
def parse_unicode(cls, s, cur): ...
|
||||
@classmethod
|
||||
def get_oids(cls, conn_or_curs): ...
|
||||
|
||||
def register_hstore(
|
||||
conn_or_curs,
|
||||
globally: bool = False,
|
||||
unicode: bool = False,
|
||||
oid: Incomplete | None = None,
|
||||
array_oid: Incomplete | None = None,
|
||||
) -> None: ...
|
||||
|
||||
class CompositeCaster:
|
||||
name: Any
|
||||
schema: Any
|
||||
oid: Any
|
||||
array_oid: Any
|
||||
attnames: Any
|
||||
atttypes: Any
|
||||
typecaster: Any
|
||||
array_typecaster: Any
|
||||
def __init__(self, name, oid, attrs, array_oid: Incomplete | None = None, schema: Incomplete | None = None) -> None: ...
|
||||
def parse(self, s, curs): ...
|
||||
def make(self, values): ...
|
||||
@classmethod
|
||||
def tokenize(cls, s): ...
|
||||
|
||||
def register_composite(name, conn_or_curs, globally: bool = False, factory: Incomplete | None = None): ...
|
||||
def execute_batch(cur, sql, argslist, page_size: int = 100) -> None: ...
|
||||
def execute_values(cur, sql, argslist, template: Incomplete | None = None, page_size: int = 100, fetch: bool = False): ...
|
||||
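The extras module combines row factories and batch helpers; a brief sketch (illustrative only; the DSN and table t are placeholders):

# Illustrative sketch: DictCursor rows plus batched inserts with execute_values().
import psycopg2
from psycopg2.extras import DictCursor, execute_values

conn = psycopg2.connect("dbname=test")  # placeholder DSN
with conn, conn.cursor(cursor_factory=DictCursor) as cur:
    execute_values(cur, "INSERT INTO t (a, b) VALUES %s", [(1, "x"), (2, "y")])
    cur.execute("SELECT a, b FROM t")
    for row in cur:
        print(row["a"], row["b"])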
typings/psycopg2/pool.pyi (new file)
@@ -0,0 +1,24 @@
from _typeshed import Incomplete
from typing import Any

import psycopg2

class PoolError(psycopg2.Error): ...

class AbstractConnectionPool:
    minconn: Any
    maxconn: Any
    closed: bool
    def __init__(self, minconn, maxconn, *args, **kwargs) -> None: ...
    # getconn, putconn and closeall are officially documented as methods of the
    # abstract base class, but in reality, they only exist on the children classes
    def getconn(self, key: Incomplete | None = ...): ...
    def putconn(self, conn: Any, key: Incomplete | None = ..., close: bool = ...) -> None: ...
    def closeall(self) -> None: ...

class SimpleConnectionPool(AbstractConnectionPool): ...

class ThreadedConnectionPool(AbstractConnectionPool):
    # This subclass has a default value for conn which doesn't exist
    # in the SimpleConnectionPool class, nor in the documentation
    def putconn(self, conn: Incomplete | None = None, key: Incomplete | None = None, close: bool = False) -> None: ...
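As the comments above note, getconn()/putconn() live on the concrete pool classes; a usage sketch (illustrative only, placeholder DSN):

# Illustrative sketch: borrowing and returning connections from a pool.
from psycopg2.pool import ThreadedConnectionPool

pool = ThreadedConnectionPool(1, 5, dsn="dbname=test")  # placeholder DSN
conn = pool.getconn()
try:
    with conn.cursor() as cur:
        cur.execute("SELECT 1")
        print(cur.fetchone())
finally:
    pool.putconn(conn)
pool.closeall()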
typings/psycopg2/sql.pyi (new file)
@@ -0,0 +1,50 @@
|
||||
from _typeshed import Incomplete
from collections.abc import Iterator
from typing import Any

class Composable:
    def __init__(self, wrapped) -> None: ...
    def as_string(self, context) -> str: ...
    def __add__(self, other) -> Composed: ...
    def __mul__(self, n) -> Composed: ...
    def __eq__(self, other) -> bool: ...
    def __ne__(self, other) -> bool: ...

class Composed(Composable):
    def __init__(self, seq) -> None: ...
    @property
    def seq(self) -> list[Composable]: ...
    def as_string(self, context) -> str: ...
    def __iter__(self) -> Iterator[Composable]: ...
    def __add__(self, other) -> Composed: ...
    def join(self, joiner) -> Composed: ...

class SQL(Composable):
    def __init__(self, string) -> None: ...
    @property
    def string(self) -> str: ...
    def as_string(self, context) -> str: ...
    def format(self, *args, **kwargs) -> Composed: ...
    def join(self, seq) -> Composed: ...

class Identifier(Composable):
    def __init__(self, *strings) -> None: ...
    @property
    def strings(self) -> tuple[str, ...]: ...
    @property
    def string(self) -> str: ...
    def as_string(self, context) -> str: ...

class Literal(Composable):
    @property
    def wrapped(self): ...
    def as_string(self, context) -> str: ...

class Placeholder(Composable):
    def __init__(self, name: Incomplete | None = None) -> None: ...
    @property
    def name(self) -> str | None: ...
    def as_string(self, context) -> str: ...

NULL: Any
DEFAULT: Any
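
A brief sketch of how these Composable classes are usually combined (not part of the commit; the DSN, table and column names are made up):

import psycopg2
from psycopg2 import sql

conn = psycopg2.connect('dbname=postgres')  # assumed DSN
# Identifiers are quoted safely; Placeholder renders as a named parameter.
query = sql.SQL('SELECT {cols} FROM {tbl} WHERE id = {id}').format(
    cols=sql.SQL(', ').join([sql.Identifier('id'), sql.Identifier('name')]),
    tbl=sql.Identifier('users'),
    id=sql.Placeholder('id'),
)
print(query.as_string(conn))  # rendering requires a connection or cursor as context
with conn.cursor() as cur:
    cur.execute(query, {'id': 1})
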

typings/psycopg2/tz.pyi (new file)
@@ -0,0 +1,26 @@
import datetime
from _typeshed import Incomplete
from typing import Any

ZERO: Any

class FixedOffsetTimezone(datetime.tzinfo):
    def __init__(self, offset: Incomplete | None = None, name: Incomplete | None = None) -> None: ...
    def __new__(cls, offset: Incomplete | None = None, name: Incomplete | None = None): ...
    def __eq__(self, other): ...
    def __ne__(self, other): ...
    def __getinitargs__(self): ...
    def utcoffset(self, dt): ...
    def tzname(self, dt): ...
    def dst(self, dt): ...

STDOFFSET: Any
DSTOFFSET: Any
DSTDIFF: Any

class LocalTimezone(datetime.tzinfo):
    def utcoffset(self, dt): ...
    def dst(self, dt): ...
    def tzname(self, dt): ...

LOCAL: Any
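
For reference only (not in the commit): psycopg2's FixedOffsetTimezone takes its offset in minutes east of UTC, so a +02:00 timestamp can be sketched like this:

from datetime import datetime
from psycopg2.tz import FixedOffsetTimezone

# 120 minutes east of UTC; the name argument is optional.
tz_plus2 = FixedOffsetTimezone(offset=120, name='UTC+02')
ts = datetime(2023, 5, 1, 12, 30, tzinfo=tz_plus2)
print(ts.isoformat())   # 2023-05-01T12:30:00+02:00
print(ts.utcoffset())   # 2:00:00
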

typings/pysyncobj/__init__.pyi (new file)
@@ -0,0 +1,2 @@
from .syncobj import FAIL_REASON, SyncObj, SyncObjConf, replicated
__all__ = ['SyncObj', 'SyncObjConf', 'replicated', 'FAIL_REASON']

typings/pysyncobj/config.pyi (new file)
@@ -0,0 +1,13 @@
from typing import Optional
class FAIL_REASON:
    SUCCESS = ...
    QUEUE_FULL = ...
    MISSING_LEADER = ...
    DISCARDED = ...
    NOT_LEADER = ...
    LEADER_CHANGED = ...
    REQUEST_DENIED = ...
class SyncObjConf:
    password: Optional[str]
    autoTickPeriod: int
    def __init__(self, **kwargs) -> None: ...

typings/pysyncobj/dns_resolver.pyi (new file)
@@ -0,0 +1,5 @@
from typing import Optional
class DnsCachingResolver:
    def setTimeouts(self, cacheTime: float, failCacheTime: float) -> None: ...
    def resolve(self, hostname: str) -> Optional[str]: ...
def globalDnsResolver() -> DnsCachingResolver: ...

typings/pysyncobj/node.pyi (new file)
@@ -0,0 +1,6 @@
class Node:
    @property
    def id(self) -> str: ...
class TCPNode(Node):
    @property
    def host(self) -> str: ...

typings/pysyncobj/syncobj.pyi (new file)
@@ -0,0 +1,24 @@
from typing import Any, Callable, Collection, List, Optional, Set, Type
from .config import FAIL_REASON, SyncObjConf
from .node import Node
from .transport import Transport
__all__ = ['FAIL_REASON', 'SyncObj', 'SyncObjConf', 'replicated']
class SyncObj:
    def __init__(self, selfNode: Optional[str], otherNodes: Collection[str], conf: SyncObjConf = ..., consumers=..., nodeClass=..., transport=..., transportClass: Type[Transport] = ...) -> None: ...
    def destroy(self) -> None: ...
    def doTick(self, timeToWait: float = 0.0) -> None: ...
    def isNodeConnected(self, node: Node) -> bool: ...
    @property
    def selfNode(self) -> Node: ...
    @property
    def otherNodes(self) -> Set[Node]: ...
    @property
    def raftLastApplied(self) -> int: ...
    @property
    def raftCommitIndex(self) -> int: ...
    @property
    def conf(self) -> SyncObjConf: ...
    def _getLeader(self) -> Optional[Node]: ...
    def _isLeader(self) -> bool: ...
    def _onTick(self, timeToWait: float = 0.0) -> None: ...
def replicated(*decArgs: Any, **decKwargs: Any) -> Callable[..., Any]: ...
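
These stubs cover the subset of the pysyncobj API that Patroni's Raft support relies on. For orientation, a minimal replicated counter in the style of the upstream pysyncobj examples (addresses and password are placeholders, not part of the commit):

from pysyncobj import SyncObj, SyncObjConf, replicated

class ReplicatedCounter(SyncObj):
    def __init__(self, self_addr, partner_addrs):
        super().__init__(self_addr, partner_addrs, conf=SyncObjConf(password='secret'))
        self.__value = 0

    @replicated  # the call is appended to the Raft log and applied on every node
    def incr(self):
        self.__value += 1
        return self.__value

    def get_value(self):
        return self.__value

# One process per node; each node lists the other peers.
counter = ReplicatedCounter('127.0.0.1:4321', ['127.0.0.1:4322', '127.0.0.1:4323'])
counter.incr()
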

typings/pysyncobj/tcp_connection.pyi (new file)
@@ -0,0 +1,4 @@
class CONNECTION_STATE:
    DISCONNECTED = ...
    CONNECTING = ...
    CONNECTED = ...

typings/pysyncobj/transport.pyi (new file)
@@ -0,0 +1,10 @@
from typing import Any, Callable, Collection, Optional
from .node import TCPNode
from .syncobj import SyncObj
from .tcp_connection import CONNECTION_STATE
__all__ = ['CONNECTION_STATE', 'TCPTransport']
class Transport:
    def setOnUtilityMessageCallback(self, message: str, callback: Callable[[Any, Callable[..., Any]], Any]) -> None: ...
class TCPTransport(Transport):
    def __init__(self, syncObj: SyncObj, selfNode: Optional[TCPNode], otherNodes: Collection[TCPNode]) -> None: ...
    def _connectIfNecessarySingle(self, node: TCPNode) -> bool: ...

typings/pysyncobj/utility.pyi (new file)
@@ -0,0 +1,5 @@
from typing import Any, List, Optional, Union
from .node import TCPNode
class Utility: ...
class TcpUtility(Utility):
    def __init__(self, password: Optional[str] = None, timeout: float = 900.0) -> None: ...
    def executeCommand(self, node: Union[str, TCPNode], command: List[Any]) -> Any: ...
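
TcpUtility is the out-of-band client used to query a running pysyncobj node (the same mechanism as the syncobj_admin tool). A hedged sketch, not part of the commit; the address is a placeholder:

from pysyncobj.utility import TcpUtility

# Address and password are placeholders for a running Raft node.
utility = TcpUtility(password=None, timeout=5.0)
status = utility.executeCommand('127.0.0.1:2222', ['status'])
print(status)
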

typings/urllib3/__init__.pyi (new file)
@@ -0,0 +1,6 @@
from .poolmanager import PoolManager
from .response import HTTPResponse
from .util.request import make_headers
from .util.timeout import Timeout

__all__ = ['HTTPResponse', 'PoolManager', 'Timeout', 'make_headers']

typings/urllib3/_collections.pyi (new file)
@@ -0,0 +1,28 @@
from typing import Any, MutableMapping, NoReturn
class HTTPHeaderDict(MutableMapping[str, str]):
    def __init__(self, headers=None, **kwargs) -> None: ...
    def __setitem__(self, key, val) -> None: ...
    def __getitem__(self, key): ...
    def __delitem__(self, key) -> None: ...
    def __contains__(self, key): ...
    def __eq__(self, other): ...
    def __iter__(self) -> NoReturn: ...
    def __len__(self) -> int: ...
    def __ne__(self, other): ...
    values: Any
    get: Any
    update: Any
    iterkeys: Any
    itervalues: Any
    def pop(self, key, default=...): ...
    def discard(self, key): ...
    def add(self, key, val): ...
    def extend(self, *args, **kwargs): ...
    def getlist(self, key): ...
    getheaders: Any
    getallmatchingheaders: Any
    iget: Any
    def copy(self): ...
    def iteritems(self): ...
    def itermerged(self): ...
    def items(self): ...
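
HTTPHeaderDict is urllib3's case-insensitive, multi-value header mapping. An illustrative sketch against urllib3 1.x behaviour (not part of the commit):

from urllib3._collections import HTTPHeaderDict

headers = HTTPHeaderDict()
headers['Content-Type'] = 'application/json'
headers.add('Vary', 'Accept')
headers.add('vary', 'Accept-Encoding')   # keys are case-insensitive, values accumulate
print(headers['VARY'])                   # 'Accept, Accept-Encoding'
print(headers.getlist('vary'))           # ['Accept', 'Accept-Encoding']
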

typings/urllib3/connection.pyi (new file)
@@ -0,0 +1,2 @@
from http.client import HTTPConnection as _HTTPConnection
class HTTPConnection(_HTTPConnection): ...

typings/urllib3/poolmanager.pyi (new file)
@@ -0,0 +1,9 @@
from typing import Any, Dict, Optional
from .response import HTTPResponse
class PoolManager:
    headers: Dict[str, str]
    connection_pool_kw: Dict[str, Any]
    def __init__(self, num_pools: int = 10, headers: Optional[Dict[str, str]] = None, **connection_pool_kw: Any) -> None: ...
    def urlopen(self, method: str, url: str, body: Optional[Any] = None, headers: Optional[Dict[str, str]] = None, encode_multipart: bool = True, multipart_boundary: Optional[str] = None, **kw: Any) -> HTTPResponse: ...
    def request(self, method: str, url: str, fields: Optional[Any] = None, headers: Optional[Dict[str, str]] = None, **urlopen_kw: Any) -> HTTPResponse: ...
    def clear(self) -> None: ...
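
A short end-to-end sketch of the urllib3 surface stubbed here (PoolManager, make_headers, HTTPResponse); the URL is a hypothetical local REST endpoint, not part of the commit:

import urllib3

http = urllib3.PoolManager(num_pools=2, headers=urllib3.make_headers(user_agent='stub-example'))
# request() returns an HTTPResponse whose body is already read into .data
resp = http.request('GET', 'http://127.0.0.1:8008/health')
print(resp.status, resp.headers.get('Content-Type'))
print(resp.data[:200])
http.clear()
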

typings/urllib3/response.pyi (new file)
@@ -0,0 +1,14 @@
import io
from typing import Any, Iterator, Optional, Union
from ._collections import HTTPHeaderDict
from .connection import HTTPConnection
class HTTPResponse(io.IOBase):
    headers: HTTPHeaderDict
    status: int
    reason: Optional[str]
    def release_conn(self) -> None: ...
    @property
    def data(self) -> Union[bytes, Any]: ...
    @property
    def connection(self) -> Optional[HTTPConnection]: ...
    def read_chunked(self, amt: Optional[int] = None, decode_content: Optional[bool] = None) -> Iterator[bytes]: ...

typings/urllib3/util/request.pyi (new file)
@@ -0,0 +1,9 @@
from typing import Optional, Union, Dict, List
def make_headers(
    keep_alive: Optional[bool] = None,
    accept_encoding: Union[bool, List[str], str, None] = None,
    user_agent: Optional[str] = None,
    basic_auth: Optional[str] = None,
    proxy_basic_auth: Optional[str] = None,
    disable_cache: Optional[bool] = None,
) -> Dict[str, str]: ...

Some files were not shown because too many files have changed in this diff.