Implement generate-config (#2786)

New patroni.py options that allow you to:

* generate a patroni.yml configuration file with the values from a running cluster
* generate a sample patroni.yml configuration file
This commit is contained in:
Polina Bungina
2023-08-09 17:46:53 +02:00
committed by GitHub
parent 713244975c
commit 3734ecc851
8 changed files with 912 additions and 36 deletions

View File

@@ -163,11 +163,26 @@ def patroni_main(configfile: str) -> None:
def process_arguments() -> Namespace:
from patroni.config_generator import generate_config
parser = get_base_arg_parser()
parser.add_argument('--validate-config', action='store_true', help='Run config validator and exit')
group = parser.add_mutually_exclusive_group()
group.add_argument('--validate-config', action='store_true', help='Run config validator and exit')
group.add_argument('--generate-sample-config', action='store_true',
help='Generate a sample Patroni yaml configuration file')
group.add_argument('--generate-config', action='store_true',
help='Generate a Patroni yaml configuration file for a running instance')
parser.add_argument('--dsn', help='Optional DSN string of the instance to be used as a source \
for config generation. Superuser connection is required.')
args = parser.parse_args()
if args.validate_config:
if args.generate_sample_config:
generate_config(args.configfile, True, None)
sys.exit(0)
elif args.generate_config:
generate_config(args.configfile, False, args.dsn)
sys.exit(0)
elif args.validate_config:
from patroni.validator import schema
from patroni.config import Config, ConfigParseError

View File

@@ -3,7 +3,7 @@
Provides a case insensitive :class:`dict` and :class:`set` object types.
"""
from collections import OrderedDict
from typing import Any, Collection, Dict, Iterator, MutableMapping, MutableSet, Optional
from typing import Any, Collection, Dict, Iterator, KeysView, MutableMapping, MutableSet, Optional
class CaseInsensitiveSet(MutableSet[str]):
@@ -187,6 +187,13 @@ class CaseInsensitiveDict(MutableMapping[str, Any]):
"""
return CaseInsensitiveDict({v[0]: v[1] for v in self._values.values()})
def keys(self) -> KeysView[str]:
    """Expose the keys of the internal storage mapping as a view.

    :returns: the set-like view obtained by calling ``keys()`` on ``self._values``.
    """
    storage = self._values
    return storage.keys()
def __repr__(self) -> str:
"""Get a string representation of the dict.

View File

@@ -194,10 +194,9 @@ class Config(object):
'recovery_min_apply_delay': ''
},
'postgresql': {
'bin_dir': '',
'use_slots': True,
'parameters': CaseInsensitiveDict({p: v[0] for p, v in ConfigHandler.CMDLINE_OPTIONS.items()
if p not in ('wal_keep_segments', 'wal_keep_size')})
if v[0] is not None and p not in ('wal_keep_segments', 'wal_keep_size')})
}
}
@@ -236,6 +235,22 @@ class Config(object):
def dynamic_configuration(self) -> Dict[str, Any]:
return deepcopy(self._dynamic_configuration)
@property
def local_configuration(self) -> Dict[str, Any]:
    """Hand out an independent snapshot of the cached Patroni local configuration.

    :returns: a plain :class:`dict` built from :attr:`~Config._local_configuration` and then deep-copied, so
        changes made by the caller do not reach the cached value.
    """
    snapshot = dict(self._local_configuration)
    return deepcopy(snapshot)
@classmethod
def get_default_config(cls) -> Dict[str, Any]:
    """Produce a fresh deep copy of the default configuration.

    :returns: copy of :attr:`~Config.__DEFAULT_CONFIG` that the caller is free to modify.
    """
    defaults = cls.__DEFAULT_CONFIG
    return deepcopy(defaults)
def _load_config_path(self, path: str) -> Dict[str, Any]:
"""
If path is a file, loads the yml file pointed to by path.
@@ -348,7 +363,7 @@ class Config(object):
return pg_params
def _safe_copy_dynamic_configuration(self, dynamic_configuration: Dict[str, Any]) -> Dict[str, Any]:
config = deepcopy(self.__DEFAULT_CONFIG)
config = self.get_default_config()
for name, value in dynamic_configuration.items():
if name == 'postgresql':

463
patroni/config_generator.py Normal file
View File

@@ -0,0 +1,463 @@
"""patroni ``--generate-config`` machinery."""
import abc
import logging
import os
import psutil
import socket
import sys
import yaml
from getpass import getuser, getpass
from contextlib import contextmanager
from typing import Any, Dict, Iterator, List, Optional, Tuple, TYPE_CHECKING, Union
if TYPE_CHECKING: # pragma: no cover
from psycopg import Cursor
from psycopg2 import cursor
from . import psycopg
from .config import Config
from .exceptions import PatroniException
from .postgresql.config import ConfigHandler, parse_dsn
from .postgresql.misc import postgres_major_version_to_int
from .utils import get_major_version, parse_bool, patch_config, read_stripped
# Mapping between the libpq connection parameters and the environment variables.
# This dict should be kept in sync with `patroni.utils._AUTH_ALLOWED_PARAMETERS`
# (we use "username" in the Patroni config for some reason, other parameter names are the same).
# Each key is a connection parameter name; each value is the name of the environment variable
# that is consulted when the parameter is not present in the parsed DSN.
_AUTH_ALLOWED_PARAMETERS_MAPPING = {
    'user': 'PGUSER',
    'password': 'PGPASSWORD',
    'sslmode': 'PGSSLMODE',
    'sslcert': 'PGSSLCERT',
    'sslkey': 'PGSSLKEY',
    # empty variable name: presumably there is no environment fallback for this parameter -- TODO confirm
    'sslpassword': '',
    'sslrootcert': 'PGSSLROOTCERT',
    'sslcrl': 'PGSSLCRL',
    'sslcrldir': 'PGSSLCRLDIR',
    'gssencmode': 'PGGSSENCMODE',
    'channel_binding': 'PGCHANNELBINDING'
}
# Placeholder put into the generated config wherever a real value could not be determined.
_NO_VALUE_MSG = '#FIXME'
def get_address() -> Tuple[str, str]:
    """Try to get hostname and the ip address for it returned by :func:`~socket.gethostname`.

    .. note::
        Can also return local ip.

    :returns: tuple consisting of the hostname returned by :func:`~socket.gethostname`
        and the first element in the sorted list of the addresses returned by :func:`~socket.getaddrinfo`.
        Sorting guarantees it will prefer IPv4.
        If an exception occurred, hostname and ip values are equal to :data:`~patroni.config_generator._NO_VALUE_MSG`.
    """
    try:
        hostname = socket.gethostname()
        # entries are ordered by their first field (the address family) and the first one wins
        addresses = sorted(socket.getaddrinfo(hostname, 0, socket.AF_UNSPEC, socket.SOCK_STREAM, 0),
                           key=lambda x: x[0])
        return hostname, addresses[0][4][0]
    except Exception as err:
        # deliberately best effort: any failure only degrades the result to placeholders
        logging.warning('Failed to obtain address: %r', err)
        return _NO_VALUE_MSG, _NO_VALUE_MSG
class AbstractConfigGenerator(abc.ABC):
    """Base class holding a Patroni configuration while it is being generated.

    :ivar output_file: full path to the output file to be used.
    :ivar pg_major: integer representation of the major PostgreSQL version.
    :ivar config: dictionary used for the generated configuration storage.
    """

    # evaluated once, when the class body is executed
    _HOSTNAME, _IP = get_address()

    def __init__(self, output_file: Optional[str]) -> None:
        """Store the output file (if passed), initialise helper attributes and run the generation.

        :param output_file: full path to the output file to be used.
        """
        self.output_file = output_file
        self.pg_major = 0
        self.config = self.get_template_config()

        # subclasses fill in :attr:`config` here, so they must set up their own state beforehand
        self.generate()

    @classmethod
    def get_template_config(cls) -> Dict[str, Any]:
        """Build the template config that the inherited classes extend.

        :returns: dictionary made of placeholder values (:data:`~patroni.config_generator._NO_VALUE_MSG`), the
            hostname and ip address stored on the class, the default dynamic configuration placed under
            ``bootstrap.dcs`` (without ``standby_cluster``) and the local configuration reported by
            :class:`~patroni.config.Config`, the latter being patched over the placeholders.
        """
        pg_placeholder_address = _NO_VALUE_MSG + ':5432'
        restapi_address = cls._IP + ':8008'

        template_config: Dict[str, Any] = {
            'scope': _NO_VALUE_MSG,
            'name': cls._HOSTNAME,
            'postgresql': {
                'data_dir': _NO_VALUE_MSG,
                'connect_address': pg_placeholder_address,
                'listen': pg_placeholder_address,
                'bin_dir': '',
                'authentication': {
                    'superuser': {
                        'username': 'postgres',
                        'password': _NO_VALUE_MSG
                    },
                    'replication': {
                        'username': 'replicator',
                        'password': _NO_VALUE_MSG
                    }
                }
            },
            'restapi': {
                'connect_address': restapi_address,
                'listen': restapi_address
            }
        }

        dynamic_config = Config.get_default_config()
        # to properly dump CaseInsensitiveDict as YAML later
        pg_parameters = dynamic_config['postgresql']['parameters']
        dynamic_config['postgresql']['parameters'] = dict(pg_parameters)

        local_config = Config('', None).local_configuration  # Get values from env
        local_config.setdefault('bootstrap', {})['dcs'] = dynamic_config
        local_config.setdefault('postgresql', {})
        # ``dynamic_config`` is the very object stored under ``bootstrap.dcs`` above
        del dynamic_config['standby_cluster']

        patch_config(template_config, local_config)
        return template_config

    @abc.abstractmethod
    def generate(self) -> None:
        """Generate config and store in :attr:`~AbstractConfigGenerator.config`."""

    def write_config(self) -> None:
        """Dump :attr:`~AbstractConfigGenerator.config` as YAML to the output file if provided, else to stdout.

        The directory of the output file is created when it does not exist yet.
        """
        dump_options = {'default_flow_style': False, 'allow_unicode': True}

        if not self.output_file:
            yaml.safe_dump(self.config, sys.stdout, **dump_options)
            return

        target_dir = os.path.dirname(self.output_file)
        if target_dir and not os.path.isdir(target_dir):
            os.makedirs(target_dir)
        with open(self.output_file, 'w', encoding='UTF-8') as output_file:
            yaml.safe_dump(self.config, output_file, **dump_options)
class SampleConfigGenerator(AbstractConfigGenerator):
    """Object representing the generated sample Patroni config.

    Sane defaults are used based on the gathered PG version.
    """

    @property
    def get_auth_method(self) -> str:
        """Pick the authentication method matching :attr:`~AbstractConfigGenerator.pg_major`.

        :returns: ``scram-sha-256`` when the major version is set and is at least ``100000``, ``md5`` otherwise.
        """
        if self.pg_major and self.pg_major >= 100000:
            return 'scram-sha-256'
        return 'md5'

    def _get_int_major_version(self) -> int:
        """Get major PostgreSQL version from the binary as an integer.

        The binary name is taken from ``postgresql.bin_name.postgres`` (``postgres`` if not configured) and the
        directory from ``postgresql.bin_dir``.

        :returns: an integer PostgreSQL major version representation gathered from the PostgreSQL binary.
            See :func:`~patroni.postgresql.misc.postgres_major_version_to_int` and
            :func:`~patroni.utils.get_major_version`.
        """
        pg_section = self.config.get('postgresql') or {}
        bin_names = pg_section.get('bin_name') or {}
        postgres_bin = bin_names.get('postgres', 'postgres')
        bin_dir = self.config['postgresql'].get('bin_dir')
        major_version = get_major_version(bin_dir, postgres_bin)
        return postgres_major_version_to_int(major_version)

    def generate(self) -> None:
        """Generate sample config using some sane defaults and update :attr:`~AbstractConfigGenerator.config`."""
        self.pg_major = self._get_int_major_version()
        auth_method = self.get_auth_method

        local_pg = self.config['postgresql']
        local_pg['parameters'] = {'password_encryption': auth_method}
        username = local_pg["authentication"]["replication"]["username"]
        local_pg['pg_hba'] = [
            f'host all all all {auth_method}',
            f'host replication {username} all {auth_method}'
        ]

        # add version-specific configuration
        dcs_pg = self.config['bootstrap']['dcs']['postgresql']
        wal_keep_param = 'wal_keep_size' if self.pg_major >= 130000 else 'wal_keep_segments'
        dcs_pg['parameters'][wal_keep_param] = ConfigHandler.CMDLINE_OPTIONS[wal_keep_param][0]

        dcs_pg['use_pg_rewind'] = True
        if self.pg_major >= 110000:
            # keep an already present rewind user and only fill in what is missing
            rewind_user = local_pg['authentication'].setdefault('rewind', {'username': 'rewind_user'})
            rewind_user.setdefault('password', _NO_VALUE_MSG)
class RunningClusterConfigGenerator(AbstractConfigGenerator):
    """Object representing the Patroni config generated using information gathered from the running instance.

    :ivar dsn: DSN string for the local instance to get GUC values from (if provided).
    :ivar parsed_dsn: DSN string parsed into a dictionary (see :func:`~patroni.postgresql.config.parse_dsn`).
    """

    def __init__(self, output_file: Optional[str] = None, dsn: Optional[str] = None) -> None:
        """Additionally store the passed dsn (if any) in both original and parsed version and run config generation.

        :param output_file: full path to the output file to be used.
        :param dsn: DSN string for the local instance to get GUC values from.

        :raises:
            :exc:`~patroni.exceptions.PatroniException`: if DSN parsing failed.
        """
        # both attributes must exist before the parent constructor runs, because it calls :meth:`generate`
        self.dsn = dsn
        self.parsed_dsn: Dict[str, Any] = {}
        super().__init__(output_file)

    @property
    def _get_hba_conn_types(self) -> Tuple[str, ...]:
        """Return the connection types allowed.

        If :attr:`~RunningClusterConfigGenerator.pg_major` is defined, adds additional parameters
        for PostgreSQL version >=16.

        :returns: tuple of the connection methods allowed.
        """
        allowed_types = ('local', 'host', 'hostssl', 'hostnossl', 'hostgssenc', 'hostnogssenc')
        if self.pg_major and self.pg_major >= 160000:
            allowed_types += ('include', 'include_if_exists', 'include_dir')
        return allowed_types

    @property
    def _required_pg_params(self) -> List[str]:
        """PG configuration parameters that have to be always present in the generated config.

        :returns: list of the parameter names.
        """
        return ['hba_file', 'ident_file', 'config_file', 'data_directory'] + \
            list(ConfigHandler.CMDLINE_OPTIONS.keys())

    def _get_bin_dir_from_running_instance(self) -> str:
        """Define the directory postgres binaries reside using postmaster's pid executable.

        :returns: path to the PostgreSQL binaries directory.

        :raises:
            :exc:`~patroni.exceptions.PatroniException`: if:

                * pid could not be obtained from the ``postmaster.pid`` file (first line empty or not a number); or
                * :exc:`OSError` occurred during ``postmaster.pid`` file handling; or
                * the obtained postmaster pid doesn't exist.
        """
        data_dir = self.config['postgresql']['data_dir']
        try:
            with open(f"{data_dir}/postmaster.pid", 'r') as pid_file:
                # only the first line of the file is used as the pid
                pid_line = pid_file.readline().strip()
        except OSError as err:
            raise PatroniException(f'Error while reading postmaster.pid file: {err}')

        try:
            # an empty, whitespace-only or otherwise non-numeric line ends up here as ValueError
            postmaster_pid = int(pid_line)
        except ValueError:
            raise PatroniException('Failed to obtain postmaster pid from postmaster.pid file')

        try:
            return os.path.dirname(psutil.Process(postmaster_pid).exe())
        except psutil.NoSuchProcess:
            raise PatroniException("Obtained postmaster pid doesn't exist.")

    @contextmanager
    def _get_connection_cursor(self) -> Iterator[Union['cursor', 'Cursor[Any]']]:
        """Get cursor for the PG connection established based on the stored information.

        The connection is closed when the context is left, no matter whether the body succeeded or raised.

        :yields: cursor of the established connection.

        :raises:
            :exc:`~patroni.exceptions.PatroniException`: if :exc:`psycopg.Error` occurred.
        """
        try:
            conn = psycopg.connect(dsn=self.dsn,
                                   password=self.config['postgresql']['authentication']['superuser']['password'])
            try:
                with conn.cursor() as cur:
                    yield cur
            finally:
                # exceptions raised by the caller's ``with`` body arrive at the ``yield`` above;
                # without ``finally`` they would skip the close and leak the connection
                conn.close()
        except psycopg.Error as e:
            raise PatroniException(f'Failed to establish PostgreSQL connection: {e}')

    def _set_pg_params(self, cur: Union['cursor', 'Cursor[Any]']) -> None:
        """Extend :attr:`~RunningClusterConfigGenerator.config` with the actual PG GUCs values.

        The following GUC values are set:

        * Non-internal having configuration file, postmaster command line or environment variable
          as a source.
        * List of the always required parameters (see :meth:`~RunningClusterConfigGenerator._required_pg_params`).

        ``port`` and ``listen_addresses`` are not stored as parameters, they are used to build
        ``postgresql.connect_address`` and ``postgresql.listen`` instead.

        :param cur: connection cursor to use.
        """
        cur.execute("SELECT name, current_setting(name) FROM pg_settings "
                    "WHERE context <> 'internal' "
                    "AND source IN ('configuration file', 'command line', 'environment variable') "
                    "AND category <> 'Write-Ahead Log / Recovery Target' "
                    "AND setting <> '(disabled)' "
                    "OR name = ANY(%s)", (self._required_pg_params,))

        helper_dict = dict.fromkeys(['port', 'listen_addresses'])
        self.config['postgresql'].setdefault('parameters', {})
        for param, value in cur.fetchall():
            if param == 'data_directory':
                self.config['postgresql']['data_dir'] = value
            elif param == 'cluster_name' and value:
                self.config['scope'] = value
            elif param in ('archive_command', 'restore_command',
                           'archive_cleanup_command', 'recovery_end_command',
                           'ssl_passphrase_command', 'hba_file',
                           'ident_file', 'config_file'):
                # write commands to the local config due to security implications
                # write hba/ident/config_file to local config to ensure they are not removed later
                self.config['postgresql']['parameters'][param] = value
            elif param in helper_dict:
                helper_dict[param] = value
            else:
                self.config['bootstrap']['dcs']['postgresql']['parameters'][param] = value

        # port precedence: DSN, then PGPORT environment variable, then the ``port`` GUC
        connect_port = self.parsed_dsn.get('port', os.getenv('PGPORT', helper_dict['port']))
        self.config['postgresql']['connect_address'] = f'{self._IP}:{connect_port}'
        self.config['postgresql']['listen'] = f'{helper_dict["listen_addresses"]}:{helper_dict["port"]}'

    def _set_su_params(self) -> None:
        """Extend :attr:`~RunningClusterConfigGenerator.config` with the superuser auth information.

        Information set is based on the options used for connection: every parameter is looked up in the parsed
        DSN first and in the mapped environment variable second. The password is asked for interactively when
        it could not be found.
        """
        su_params: Dict[str, str] = {}
        for conn_param, env_var in _AUTH_ALLOWED_PARAMETERS_MAPPING.items():
            val = self.parsed_dsn.get(conn_param, os.getenv(env_var))
            if val:
                su_params[conn_param] = val

        # NOTE(review): this reads a top-level ``authentication`` key, while the template config built by the
        # parent class keeps it under ``postgresql`` -- confirm whether ``self.config['postgresql']`` was meant.
        # Left as is: switching would pick up the template's placeholder password and suppress the prompt below.
        patroni_env_su_username = ((self.config.get('authentication') or {}).get('superuser') or {}).get('username')
        patroni_env_su_pwd = ((self.config.get('authentication') or {}).get('superuser') or {}).get('password')
        # because we use "username" in the config for some reason
        su_params['username'] = su_params.pop('user', patroni_env_su_username) or getuser()
        su_params['password'] = su_params.get('password', patroni_env_su_pwd) or \
            getpass('Please enter the user password:')
        self.config['postgresql']['authentication'] = {
            'superuser': su_params,
            'replication': {'username': _NO_VALUE_MSG, 'password': _NO_VALUE_MSG}
        }

    def _set_conf_files(self) -> None:
        """Extend :attr:`~RunningClusterConfigGenerator.config` with ``pg_hba.conf`` and ``pg_ident.conf`` content.

        .. note::
            This function only defines ``postgresql.pg_hba`` and ``postgresql.pg_ident`` when
            ``hba_file`` and ``ident_file`` are set to the defaults. It may happen these files
            are located outside of ``PGDATA`` and Patroni doesn't have write permissions for them.

        :raises:
            :exc:`~patroni.exceptions.PatroniException`: if :exc:`OSError` occurred during the conf files handling.
        """
        pg_config = self.config['postgresql']

        default_hba_path = os.path.join(pg_config['data_dir'], 'pg_hba.conf')
        if pg_config['parameters']['hba_file'] == default_hba_path:
            # the property builds a new tuple on every access, so evaluate it once instead of per line
            conn_types = self._get_hba_conn_types
            try:
                # keep only the lines starting with a known connection type (drops comments and blank lines)
                pg_config['pg_hba'] = [line for line in read_stripped(default_hba_path)
                                       if line and line.split()[0] in conn_types]
            except OSError as err:
                raise PatroniException(f'Failed to read pg_hba.conf: {err}')

        default_ident_path = os.path.join(pg_config['data_dir'], 'pg_ident.conf')
        if pg_config['parameters']['ident_file'] == default_ident_path:
            try:
                pg_config['pg_ident'] = [line for line in read_stripped(default_ident_path)
                                         if line and not line.startswith('#')]
            except OSError as err:
                raise PatroniException(f'Failed to read pg_ident.conf: {err}')
            # do not emit an empty ``pg_ident`` section
            if not pg_config['pg_ident']:
                del pg_config['pg_ident']

    def _enrich_config_from_running_instance(self) -> None:
        """Extend :attr:`~RunningClusterConfigGenerator.config` with the values gathered from the running instance.

        Retrieve the following information from the running PostgreSQL instance:

        * superuser auth parameters (see :meth:`~RunningClusterConfigGenerator._set_su_params`);
        * some GUC values (see :meth:`~RunningClusterConfigGenerator._set_pg_params`);
        * ``postgresql.connect_address``, ``postgresql.listen``;
        * ``postgresql.pg_hba`` and ``postgresql.pg_ident`` (see :meth:`~RunningClusterConfigGenerator._set_conf_files`)

        And redefine ``scope`` with the ``cluster_name`` GUC value if set.

        :raises:
            :exc:`~patroni.exceptions.PatroniException`: if the provided user doesn't have superuser privileges.
        """
        self._set_su_params()

        with self._get_connection_cursor() as cur:
            self.pg_major = getattr(cur.connection, 'server_version', 0)

            if not parse_bool(cur.connection.info.parameter_status('is_superuser')):
                raise PatroniException('The provided user does not have superuser privilege')

            self._set_pg_params(cur)

        self._set_conf_files()

    def generate(self) -> None:
        """Generate config using the info gathered from the specified running PG instance.

        Result is written to :attr:`~RunningClusterConfigGenerator.config`.

        :raises:
            :exc:`~patroni.exceptions.PatroniException`: if a DSN was provided but could not be parsed.
        """
        if self.dsn:
            self.parsed_dsn = parse_dsn(self.dsn) or {}
            if not self.parsed_dsn:
                raise PatroniException('Failed to parse DSN string')

        self._enrich_config_from_running_instance()
        self.config['postgresql']['bin_dir'] = self._get_bin_dir_from_running_instance()
def generate_config(output_file: Optional[str], sample: bool, dsn: Optional[str]) -> None:
    """Generate Patroni configuration file.

    Gather all the available non-internal GUC values having configuration file, postmaster command line or environment
    variable as a source and store them in the appropriate part of Patroni configuration (``postgresql.parameters`` or
    ``bootstrap.dcs.postgresql.parameters``). Either the provided DSN (takes precedence) or PG ENV vars will be used
    for the connection. If password is not provided, it should be entered via prompt.

    The created configuration contains:

    * ``scope``: ``cluster_name`` GUC value or ``PATRONI_SCOPE`` ENV variable value if available.
    * ``name``: ``PATRONI_NAME`` ENV variable value if set, otherwise hostname.
    * ``bootstrap.dcs``: section with all the parameters (incl. the majority of PG GUCs) set to their default values
      defined by Patroni and adjusted by the source instance's configuration values.
    * ``postgresql.parameters``: the source instance's ``archive_command``, ``restore_command``,
      ``archive_cleanup_command``, ``recovery_end_command``, ``ssl_passphrase_command``, ``hba_file``, ``ident_file``,
      ``config_file`` GUC values.
    * ``postgresql.bin_dir``: path to Postgres binaries gathered from the running instance or, if not available,
      the value of ``PATRONI_POSTGRESQL_BIN_DIR`` ENV variable. Otherwise, an empty string.
    * ``postgresql.data_dir``: the value gathered from the corresponding PG GUC.
    * ``postgresql.listen``: source instance's ``listen_addresses`` and port GUC values.
    * ``postgresql.connect_address``: if possible, generated from the connection params.
    * ``postgresql.authentication``:

        * superuser and replication users defined (if possible, usernames are set from the respective Patroni ENV vars,
          otherwise the default ``postgres`` and ``replicator`` values are used).
          If not a sample config, either DSN or PG ENV vars are used to define superuser authentication parameters.
        * rewind user is defined only for sample config, if PG version can be defined and PG version is >=11
          (if possible, username is set from the respective Patroni ENV var).

    * ``bootstrap.dcs.postgresql.use_pg_rewind`` set to ``True`` for a sample config only.
    * ``postgresql.pg_hba`` defaults or the lines gathered from the source instance's ``hba_file``.
    * ``postgresql.pg_ident`` the lines gathered from the source instance's ``ident_file``.

    Any failure terminates the process through :func:`sys.exit` with the error text as the exit message.

    :param output_file: Full path to the configuration file to be used. If not provided, result is sent to ``stdout``.
    :param sample: If set, no source instance will be used - generate config with some sane defaults.
    :param dsn: Optional DSN string for the local instance to get GUC values from.
    """
    try:
        if sample:
            config_generator = SampleConfigGenerator(output_file)
        else:
            config_generator = RunningClusterConfigGenerator(output_file, dsn)

        config_generator.write_config()
    except PatroniException as e:
        # expected failures already carry a user-oriented message
        sys.exit(str(e))
    except Exception as e:
        # top-level boundary of the command: report anything else instead of printing a traceback
        sys.exit(f'Unexpected exception: {e}')

View File

@@ -16,6 +16,7 @@ import platform
import random
import re
import socket
import subprocess
import sys
import tempfile
import time
@@ -467,6 +468,18 @@ def _sleep(interval: Union[int, float]) -> None:
time.sleep(interval)
def read_stripped(file_path: str) -> Iterator[str]:
    """Lazily produce the lines of a file with surrounding whitespace removed.

    :param file_path: path to the file to read from

    :yields: each line from the given file stripped
    """
    with open(file_path) as source:
        yield from map(str.strip, source)
class RetryFailedError(PatroniException):
"""Maximum number of attempts exhausted in retry operation."""
@@ -977,3 +990,34 @@ def unquote(string: str) -> str:
except ValueError:
ret = string
return ret
def get_major_version(bin_dir: Optional[str] = None, bin_name: str = 'postgres') -> str:
    """Get the major version of PostgreSQL.

    It is based on the output of ``postgres --version``.

    :param bin_dir: path to the PostgreSQL binaries directory. If ``None`` or an empty string, it will use the first
                    *bin_name* binary that is found by the subprocess in the ``PATH``.
    :param bin_name: name of the postgres binary to call (``postgres`` by default)

    :returns: the PostgreSQL major version.

    :raises:
        :exc:`~patroni.exceptions.PatroniException`: if the postgres binary call failed due to :exc:`OSError`, or
            if no version could be found in its output.

    :Example:

        * Returns `9.6` for PostgreSQL 9.6.24
        * Returns `15` for PostgreSQL 15.2
    """
    binary = os.path.join(bin_dir, bin_name) if bin_dir else bin_name
    try:
        output = subprocess.check_output([binary, '--version']).decode()
    except OSError as e:
        raise PatroniException(f'Failed to get postgres version: {e}')

    # two whitespace-free words, then ``<major>`` optionally followed by ``.<minor>``
    match = re.match(r'^[^\s]+ [^\s]+ (\d+)(\.(\d+))?', output)
    if match is None:
        raise PatroniException(f'Failed to get postgres version: unexpected output {output!r}')

    major, minor = match.group(1), match.group(3)
    if int(major) >= 10:
        return major
    # before version 10 the major version consists of two numbers
    if minor is None:
        raise PatroniException(f'Failed to get postgres version: unexpected output {output!r}')
    return '.'.join([major, minor])

View File

@@ -6,17 +6,16 @@ This module contains facilities for validating configuration of Patroni processe
:var schema: configuration schema of the daemon launched by `patroni` command.
"""
import os
import re
import shutil
import socket
import subprocess
from typing import Any, Dict, Union, Iterator, List, Optional as OptionalType, Tuple, TYPE_CHECKING
from typing import Any, Dict, Union, Iterator, List, Optional as OptionalType, Tuple
from .collections import CaseInsensitiveSet
from .dcs import dcs_modules
from .exceptions import ConfigParseError
from .utils import parse_int, split_host_port, data_directory_is_empty
from .utils import parse_int, split_host_port, data_directory_is_empty, get_major_version
def data_directory_empty(data_dir: str) -> bool:
@@ -187,31 +186,6 @@ def get_bin_name(bin_name: str) -> str:
return (schema.data.get('postgresql', {}).get('bin_name', {}) or {}).get(bin_name, bin_name)
def get_major_version(bin_dir: OptionalType[str] = None) -> str:
    """Report the PostgreSQL major version according to ``postgres --version``.

    The name of the binary is resolved through :func:`get_bin_name`.

    :param bin_dir: path to PostgreSQL binaries directory. If ``None`` it will use the first ``postgres`` binary that
                    is found by subprocess in the ``PATH``.

    :returns: the PostgreSQL major version.

    :Example:

        * Returns `9.6` for PostgreSQL 9.6.24
        * Returns `15` for PostgreSQL 15.2
    """
    if bin_dir:
        binary = os.path.join(bin_dir, get_bin_name('postgres'))
    else:
        binary = get_bin_name('postgres')
    raw_output = subprocess.check_output([binary, '--version']).decode()
    version = re.match(r'^[^\s]+ [^\s]+ (\d+)(\.(\d+))?', raw_output)
    if TYPE_CHECKING:  # pragma: no cover
        assert version is not None
    major = version.group(1)
    if int(major) < 10:
        # before version 10 the major version consists of two numbers
        return '.'.join([major, version.group(3)])
    return major
def validate_data_dir(data_dir: str) -> bool:
"""Validate the value of ``postgresql.data_dir`` configuration option.
@@ -246,7 +220,7 @@ def validate_data_dir(data_dir: str) -> bool:
raise ConfigParseError("data dir for the cluster is not empty, but doesn't contain"
" \"{}\" directory".format(waldir))
bin_dir = schema.data.get("postgresql", {}).get("bin_dir", None)
major_version = get_major_version(bin_dir)
major_version = get_major_version(bin_dir, get_bin_name('postgres'))
if pgversion != major_version:
raise ConfigParseError("data_dir directory postgresql version ({}) doesn't match with "
"'postgres --version' output ({})".format(pgversion, major_version))

View File

@@ -118,6 +118,21 @@ class MockCursor(object):
'"state":"streaming","sync_state":"async","sync_priority":0}]'
now = datetime.datetime.now(tzutc)
self.results = [(now, 0, '', 0, '', False, now, 'streaming', None, replication_info)]
elif sql.startswith('SELECT name, current_setting(name) FROM pg_settings'):
self.results = [('data_directory', 'data'),
('hba_file', os.path.join('data', 'pg_hba.conf')),
('ident_file', os.path.join('data', 'pg_ident.conf')),
('max_connections', 42),
('max_locks_per_transaction', 73),
('max_prepared_transactions', 0),
('max_replication_slots', 21),
('max_wal_senders', 37),
('track_commit_timestamp', 'off'),
('wal_level', 'replica'),
('listen_addresses', '6.6.6.6'),
('port', 1984),
('archive_command', 'my archive command'),
('cluster_name', 'my_cluster')]
elif sql.startswith('SELECT name, setting'):
self.results = [('wal_segment_size', '2048', '8kB', 'integer', 'internal'),
('wal_block_size', '8192', None, 'integer', 'internal'),
@@ -159,11 +174,20 @@ class MockCursor(object):
pass
class MockConnectionInfo(object):
    """Test double for the ``info`` attribute of a connection."""

    def parameter_status(self, param_name):
        """Return ``'on'`` for ``is_superuser`` and ``'0'`` for every other parameter name."""
        return 'on' if param_name == 'is_superuser' else '0'
class MockConnect(object):
server_version = 99999
autocommit = False
closed = 0
info = MockConnectionInfo()
def cursor(self):
return MockCursor(self)

View File

@@ -0,0 +1,334 @@
import os
import psutil
import socket
import unittest
from . import MockConnect, MockCursor, MockConnectionInfo
from copy import deepcopy
from mock import MagicMock, Mock, PropertyMock, mock_open, patch
from patroni.__main__ import main as _main
from patroni.config import Config
from patroni.config_generator import AbstractConfigGenerator, get_address
from patroni.utils import patch_config
from . import psycopg_connect
@patch('patroni.psycopg.connect', psycopg_connect)
@patch('socket.getaddrinfo', Mock(return_value=[(0, 0, 0, 0, ('1.9.8.4', 1984))]))
@patch('builtins.open', MagicMock())
@patch('subprocess.check_output', Mock(return_value=b"postgres (PostgreSQL) 16.2"))
@patch('psutil.Process.exe', Mock(return_value='/bin/dir/from/running/postgres'))
@patch('psutil.Process.__init__', Mock(return_value=None))
class TestGenerateConfig(unittest.TestCase):
no_value_msg = '#FIXME'
_HOSTNAME = socket.gethostname()
_IP = sorted(socket.getaddrinfo(_HOSTNAME, 0, socket.AF_UNSPEC, socket.SOCK_STREAM, 0), key=lambda x: x[0])[0][4][0]
def setUp(self):
    """Prepare the environment variables and the config expected for a sample config generation.

    The environment is changed through ``patch.dict`` and restored by a cleanup, so the variables set here
    do not leak into tests executed afterwards.
    """
    self.maxDiff = None

    env_patcher = patch.dict(os.environ, {
        'PATRONI_SCOPE': 'scope_from_env',
        'PATRONI_POSTGRESQL_BIN_DIR': '/bin/from/env',
        'PATRONI_SUPERUSER_USERNAME': 'su_user_from_env',
        'PATRONI_SUPERUSER_PASSWORD': 'su_pwd_from_env',
        'PATRONI_REPLICATION_USERNAME': 'repl_user_from_env',
        'PATRONI_REPLICATION_PASSWORD': 'repl_pwd_from_env',
        'PATRONI_REWIND_USERNAME': 'rewind_user_from_env',
        'PGUSER': 'pguser_from_env',
        'PGPASSWORD': 'pguser_pwd_from_env',
        'PATRONI_RESTAPI_CONNECT_ADDRESS': 'localhost:8080',
        'PATRONI_RESTAPI_LISTEN': 'localhost:8080',
        'PATRONI_POSTGRESQL_BIN_POSTGRES': 'custom_postgres_bin_from_env',
    })
    env_patcher.start()
    # restores os.environ to its previous content, including anything a test adds on top
    self.addCleanup(env_patcher.stop)

    self.environ = deepcopy(os.environ)

    dynamic_config = Config.get_default_config()
    dynamic_config['postgresql']['parameters'] = dict(dynamic_config['postgresql']['parameters'])
    del dynamic_config['standby_cluster']
    dynamic_config['postgresql']['parameters']['wal_keep_segments'] = 8
    dynamic_config['postgresql']['use_pg_rewind'] = True

    # expected outcome of a sample config generated with a mocked pre-10 ``postgres --version`` output
    self.config = {
        'scope': self.environ['PATRONI_SCOPE'],
        'name': self._HOSTNAME,
        'bootstrap': {
            'dcs': dynamic_config
        },
        'postgresql': {
            'connect_address': self.no_value_msg + ':5432',
            'data_dir': self.no_value_msg,
            'listen': self.no_value_msg + ':5432',
            'pg_hba': ['host all all all md5',
                       f'host replication {self.environ["PATRONI_REPLICATION_USERNAME"]} all md5'],
            'authentication': {'superuser': {'username': self.environ['PATRONI_SUPERUSER_USERNAME'],
                                             'password': self.environ['PATRONI_SUPERUSER_PASSWORD']},
                               'replication': {'username': self.environ['PATRONI_REPLICATION_USERNAME'],
                                               'password': self.environ['PATRONI_REPLICATION_PASSWORD']},
                               'rewind': {'username': self.environ['PATRONI_REWIND_USERNAME']}},
            'bin_dir': self.environ['PATRONI_POSTGRESQL_BIN_DIR'],
            'bin_name': {'postgres': self.environ['PATRONI_POSTGRESQL_BIN_POSTGRES']},
            'parameters': {'password_encryption': 'md5'}
        },
        'restapi': {
            'connect_address': self.environ['PATRONI_RESTAPI_CONNECT_ADDRESS'],
            'listen': self.environ['PATRONI_RESTAPI_LISTEN']
        }
    }
def _set_running_instance_config_vals(self):
    """Patch the expected config with the values a config generated from the running instance should have.

    NOTE(review): entries set to ``None`` are presumably removed from the expected config by ``patch_config``
    -- confirm against its implementation.
    """
    # values are taken from tests/__init__.py
    expected_dcs_postgresql = {
        'parameters': {
            'max_connections': 42,
            'max_locks_per_transaction': 73,
            'max_replication_slots': 21,
            'max_wal_senders': 37,
            'wal_level': 'replica',
            'wal_keep_segments': None
        },
        'use_pg_rewind': None
    }
    expected_local_parameters = {
        'archive_command': 'my archive command',
        'hba_file': os.path.join('data', 'pg_hba.conf'),
        'ident_file': os.path.join('data', 'pg_ident.conf'),
        'password_encryption': None
    }
    expected_authentication = {
        'superuser': {
            'username': 'foobar',
            'password': 'qwerty',
            'channel_binding': 'prefer',
            'gssencmode': 'prefer',
            'sslmode': 'prefer'
        },
        'replication': {
            'username': self.no_value_msg,
            'password': self.no_value_msg
        },
        'rewind': None
    }
    overrides = {
        'scope': 'my_cluster',
        'bootstrap': {
            'dcs': {
                'postgresql': expected_dcs_postgresql
            }
        },
        'postgresql': {
            'connect_address': f'{self._IP}:bar',
            'listen': '6.6.6.6:1984',
            'data_dir': 'data',
            'bin_dir': '/bin/dir/from/running',
            'parameters': expected_local_parameters,
            'authentication': expected_authentication,
        }
    }
    patch_config(self.config, overrides)
def _get_running_instance_open_res(self):
    """Build the mocked file handles for a running-instance run and extend the expected ``pg_hba``.

    The hba file content consists of the currently expected ``pg_hba`` lines followed by
    four extra lines. After the content is built, one ``host all all all md5`` entry is
    appended to the expected ``pg_hba`` list in ``self.config``.

    NOTE(review): the handles are presumably consumed in the order pg_hba.conf,
    pg_ident.conf, postmaster.pid, output file -- confirm against the generator.

    :returns: list of four mock file handles: hba content, ident content, ``'1984'``
              and one created without read data.
    """
    expected_hba = self.config['postgresql']['pg_hba']
    extra_hba_lines = [
        '#host all all all md5',
        ' host all all all md5',
        '',
        'hostall all all md5',
    ]
    hba_content = '\n'.join(expected_hba + extra_hba_lines)
    ident_content = '\n'.join(['# something very interesting', ' '])

    # The file content above was built first; only now does the expected list grow.
    expected_hba.append('host all all all md5')

    handles = [mock_open(read_data=content)() for content in (hba_content, ident_content, '1984')]
    handles.append(mock_open()())
    return handles
@patch('os.makedirs')
@patch('yaml.safe_dump')
def test_generate_sample_config_pre_13_dir_creation(self, mock_config_dump, mock_makedir):
    """Generate a sample config while ``postgres --version`` reports 9.4.3.

    Checks that ``_main()`` exits with code 0, that the dumped config equals
    ``self.config``, that ``os.makedirs`` was called exactly once and that the
    version probe was run against the binary configured through the environment.
    """
    argv = ['patroni.py', '--generate-sample-config', '/foo/bar.yml']
    version_probe = Mock(return_value=b"postgres (PostgreSQL) 9.4.3")

    with patch('sys.argv', argv):
        with patch('subprocess.check_output', version_probe) as pg_bin_mock:
            with self.assertRaises(SystemExit) as exit_ctx:
                _main()

    self.assertEqual(exit_ctx.exception.code, 0)

    dumped_config = mock_config_dump.call_args[0][0]
    self.assertEqual(self.config, dumped_config)

    mock_makedir.assert_called_once()

    postgres_binary = os.path.join(self.environ['PATRONI_POSTGRESQL_BIN_DIR'],
                                   self.environ['PATRONI_POSTGRESQL_BIN_POSTGRES'])
    pg_bin_mock.assert_called_once_with([postgres_binary, '--version'])
@patch('os.makedirs', Mock())
@patch('yaml.safe_dump')
def test_generate_sample_config_16(self, mock_config_dump):
    """Generate a sample config and compare it with the adjusted expectations.

    ``self.config`` is patched to expect ``wal_keep_size``, ``scram-sha-256``
    authentication and a rewind user with a placeholder password before
    ``_main()`` is run.

    NOTE(review): the PostgreSQL version seen by the generator is presumably
    supplied by a patch outside this method -- confirm at class level.
    """
    replication_user = self.environ["PATRONI_REPLICATION_USERNAME"]
    dynamic_pg_parameters = {
        'wal_keep_size': '128MB',
        'wal_keep_segments': None,
    }
    rewind_auth = {
        'username': self.environ['PATRONI_REWIND_USERNAME'],
        'password': self.no_value_msg,
    }
    expected_overrides = {
        'bootstrap': {
            'dcs': {
                'postgresql': {
                    'parameters': dynamic_pg_parameters,
                },
            },
        },
        'postgresql': {
            'parameters': {
                'password_encryption': 'scram-sha-256',
            },
            'pg_hba': [
                'host all all all scram-sha-256',
                f'host replication {replication_user} all scram-sha-256',
            ],
            'authentication': {
                'rewind': rewind_auth,
            },
        },
    }
    patch_config(self.config, expected_overrides)

    argv = ['patroni.py', '--generate-sample-config', '/foo/bar.yml']
    with patch('sys.argv', argv):
        with self.assertRaises(SystemExit) as exit_ctx:
            _main()

    self.assertEqual(exit_ctx.exception.code, 0)
    dumped_config = mock_config_dump.call_args[0][0]
    self.assertEqual(self.config, dumped_config)
@patch('os.makedirs', Mock())
@patch('yaml.safe_dump')
def test_generate_config_running_instance_16(self, mock_config_dump):
    """Generate a config for a running instance addressed through ``--dsn``.

    The expected config is prepared by ``_set_running_instance_config_vals()`` and
    ``_get_running_instance_open_res()``; the latter also supplies the handles
    returned by the mocked ``open``.
    """
    self._set_running_instance_config_vals()

    mocked_open = Mock(side_effect=self._get_running_instance_open_res())
    argv = ['patroni.py', '--generate-config',
            '--dsn', 'host=foo port=bar user=foobar password=qwerty']

    with patch('builtins.open', mocked_open):
        with patch('sys.argv', argv):
            with self.assertRaises(SystemExit) as exit_ctx:
                _main()

    self.assertEqual(exit_ctx.exception.code, 0)
    dumped_config = mock_config_dump.call_args[0][0]
    self.assertEqual(self.config, dumped_config)
@patch('os.makedirs', Mock())
@patch('yaml.safe_dump')
def test_generate_config_running_instance_16_connect_from_env(self, mock_config_dump):
    """Generate a config for a running instance without passing ``--dsn``.

    The superuser name and password are expected to come from ``PGUSER`` and
    ``PGPASSWORD``, and ``channel_binding`` from ``PGCHANNELBINDING``.

    ``PGCHANNELBINDING`` is injected with ``patch.dict`` for the duration of the
    ``_main()`` call only, so the process environment is restored afterwards and the
    variable cannot leak into tests that run later.
    """
    self._set_running_instance_config_vals()

    # su auth params and connect host from env
    channel_binding = 'disable'
    self.config['postgresql']['authentication']['superuser']['channel_binding'] = channel_binding

    # NOTE(review): 'scope' and 'bootstrap' repeat what _set_running_instance_config_vals()
    # has already applied -- presumably redundant, kept as in the original expectations.
    conf = {
        'scope': 'my_cluster',
        'bootstrap': {
            'dcs': {
                'postgresql': {
                    'parameters': {
                        'max_connections': 42,
                        'max_locks_per_transaction': 73,
                        'max_replication_slots': 21,
                        'max_wal_senders': 37,
                        'wal_level': 'replica',
                        'wal_keep_segments': None
                    },
                    'use_pg_rewind': None
                }
            }
        },
        'postgresql': {
            'connect_address': f'{self._IP}:1984',
            'authentication': {
                'superuser': {
                    'username': self.environ['PGUSER'],
                    'password': self.environ['PGPASSWORD'],
                    'gssencmode': None,
                    'sslmode': None
                },
            },
        }
    }
    patch_config(self.config, conf)

    # patch.dict undoes the environment change on exit, even if the test fails.
    with patch.dict(os.environ, {'PGCHANNELBINDING': channel_binding}), \
            patch('builtins.open', Mock(side_effect=self._get_running_instance_open_res())), \
            patch('sys.argv', ['patroni.py', '--generate-config']), \
            patch.object(MockConnect, 'server_version', PropertyMock(return_value=160000)), \
            self.assertRaises(SystemExit) as e:
        _main()

    self.assertEqual(e.exception.code, 0)
    self.assertEqual(self.config, mock_config_dump.call_args[0][0])
def test_generate_config_running_instance_errors(self):
    """Check the exit message of ``_main()`` for each failure scenario.

    Every scenario installs the patches it needs, runs ``_main()`` and asserts that
    the resulting ``SystemExit`` code contains the expected text.
    """

    def main_exit_code():
        # Run _main(), require it to exit and hand back the SystemExit code.
        with self.assertRaises(SystemExit) as raised:
            _main()
        return raised.exception.code

    def hba_and_ident_handles():
        # Fresh mocked handles for the first two open() calls of a scenario.
        return [mock_open(read_data='hba_content')(),
                mock_open(read_data='ident_content')()]

    # 1. Wrong DSN format
    malformed_dsn_argv = ['patroni.py', '--generate-config', '--dsn', 'host:foo port:bar user:foobar']
    with patch('sys.argv', malformed_dsn_argv):
        self.assertIn('Failed to parse DSN string', main_exit_code())

    # 2. User is not a superuser
    non_superuser_argv = ['patroni.py',
                          '--generate-config', '--dsn', 'host=foo port=bar user=foobar password=pwd_from_dsn']
    with patch('sys.argv', non_superuser_argv), \
            patch.object(MockCursor, 'rowcount', PropertyMock(return_value=0), create=True), \
            patch.object(MockConnectionInfo, 'parameter_status', Mock(return_value='off')):
        self.assertIn('The provided user does not have superuser privilege', main_exit_code())

    # 3. Error while calling postgres --version
    with patch('subprocess.check_output', Mock(side_effect=OSError)), \
            patch('sys.argv', ['patroni.py', '--generate-sample-config']):
        self.assertIn('Failed to get postgres version:', main_exit_code())

    # The remaining scenarios all run with the same command line.
    with patch('sys.argv', ['patroni.py', '--generate-config']):

        # 4. empty postmaster.pid
        with patch('builtins.open', Mock(side_effect=hba_and_ident_handles() + [mock_open(read_data='')()])):
            self.assertIn('Failed to obtain postmaster pid from postmaster.pid file', main_exit_code())

        # 5. Failed to open postmaster.pid
        with patch('builtins.open', Mock(side_effect=hba_and_ident_handles() + [OSError])):
            self.assertIn('Error while reading postmaster.pid file', main_exit_code())

        # 6. Invalid postmaster pid
        with patch('builtins.open', Mock(side_effect=hba_and_ident_handles() + [mock_open(read_data='1984')()])), \
                patch('psutil.Process.__init__', Mock(return_value=None)), \
                patch('psutil.Process.exe', Mock(side_effect=psutil.NoSuchProcess(1984))):
            self.assertIn("Obtained postmaster pid doesn't exist", main_exit_code())

        # 7. Failed to open pg_hba
        with patch('builtins.open', Mock(side_effect=OSError)):
            self.assertIn('Failed to read pg_hba.conf', main_exit_code())

        # 8. Failed to open pg_ident
        with patch('builtins.open', Mock(side_effect=[mock_open(read_data='hba_content')(), OSError])):
            self.assertIn('Failed to read pg_ident.conf', main_exit_code())

        # 9. Failed PG connection
        from . import psycopg
        with patch('patroni.psycopg.connect', side_effect=psycopg.Error):
            self.assertIn('Failed to establish PostgreSQL connection', main_exit_code())

        # 10. An unexpected error
        with patch.object(AbstractConfigGenerator, '__init__', side_effect=psycopg.Error):
            self.assertIn('Unexpected exception', main_exit_code())
def test_get_address(self):
    """Check what ``get_address()`` returns and logs when ``socket.getaddrinfo`` raises.

    The function is expected to return a pair of placeholder values and to pass the
    ``'Failed to obtain address: %r'`` format string to ``logging.warning``.
    """
    failing_lookup = Mock(side_effect=Exception)
    with patch('socket.getaddrinfo', failing_lookup):
        with patch('logging.warning') as mock_warning:
            address = get_address()
            self.assertEqual(address, (self.no_value_msg, self.no_value_msg))

            first_warning_args = mock_warning.call_args_list[0][0]
            self.assertIn('Failed to obtain address: %r', first_warning_args)