From 85db209c19c5e99deb4e8d5bc9fde0b9cc9f2a55 Mon Sep 17 00:00:00 2001 From: Polina Bungina <27892524+hughcapet@users.noreply.github.com> Date: Thu, 14 Sep 2023 18:34:45 +0200 Subject: [PATCH] Always store CMDLINE_OPTIONS config values as int (#2861) --- patroni/config.py | 6 ++++-- patroni/postgresql/citus.py | 2 +- tests/test_bootstrap.py | 3 ++- tests/test_config.py | 25 ++++++++++++++++++++++++- tests/test_postgresql.py | 2 +- 5 files changed, 32 insertions(+), 6 deletions(-) diff --git a/patroni/config.py b/patroni/config.py index 78523dfc..437ce4ec 100644 --- a/patroni/config.py +++ b/patroni/config.py @@ -15,6 +15,7 @@ from .dcs import ClusterConfig, Cluster from .exceptions import ConfigParseError from .file_perm import pg_perm from .postgresql.config import ConfigHandler +from .validator import IntValidator from .utils import deep_compare, parse_bool, parse_int, patch_config logger = logging.getLogger(__name__) @@ -339,8 +340,9 @@ class Config(object): if name not in ConfigHandler.CMDLINE_OPTIONS: pg_params[name] = value elif not is_local: - if ConfigHandler.CMDLINE_OPTIONS[name][1](value): - pg_params[name] = value + validator = ConfigHandler.CMDLINE_OPTIONS[name][1] + if validator(value): + pg_params[name] = int(value) if isinstance(validator, IntValidator) else value else: logger.warning("postgresql parameter %s=%s failed validation, defaulting to %s", name, value, ConfigHandler.CMDLINE_OPTIONS[name][0]) diff --git a/patroni/postgresql/citus.py b/patroni/postgresql/citus.py index e11c206c..fc4f1bd2 100644 --- a/patroni/postgresql/citus.py +++ b/patroni/postgresql/citus.py @@ -405,7 +405,7 @@ class CitusHandler(Thread): parameters['shared_preload_libraries'] = ','.join(['citus'] + shared_preload_libraries) # if not explicitly set Citus overrides max_prepared_transactions to max_connections*2 - if parameters.get('max_prepared_transactions') == 0: + if parameters['max_prepared_transactions'] == 0: parameters['max_prepared_transactions'] = parameters['max_connections'] * 2 # Resharding in Citus implemented using logical replication diff --git a/tests/test_bootstrap.py b/tests/test_bootstrap.py index 9f98fecb..94d8ad95 100644 --- a/tests/test_bootstrap.py +++ b/tests/test_bootstrap.py @@ -238,7 +238,8 @@ class TestBootstrap(BaseTestPostgresql): self.p.reload_config({'authentication': {'superuser': {'username': 'p', 'password': 'p'}, 'replication': {'username': 'r', 'password': 'r'}, 'rewind': {'username': 'rw', 'password': 'rw'}}, - 'listen': '*', 'retry_timeout': 10, 'parameters': {'wal_level': '', 'hba_file': 'foo'}}) + 'listen': '*', 'retry_timeout': 10, + 'parameters': {'wal_level': '', 'hba_file': 'foo', 'max_prepared_transactions': 10}}) with patch.object(Postgresql, 'major_version', PropertyMock(return_value=110000)), \ patch.object(Postgresql, 'restart', Mock()) as mock_restart: self.b.post_bootstrap({}, task) diff --git a/tests/test_config.py b/tests/test_config.py index cf798d00..bd8d0a90 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -3,6 +3,7 @@ import sys import unittest import io +from copy import deepcopy from mock import MagicMock, Mock, patch from patroni.config import Config, ConfigParseError @@ -22,7 +23,7 @@ class TestConfig(unittest.TestCase): self.assertFalse(self.config.set_dynamic_configuration({'foo': 'bar'})) self.assertTrue(self.config.set_dynamic_configuration({'standby_cluster': {}, 'postgresql': { 'parameters': {'cluster_name': 1, 'hot_standby': 1, 'wal_keep_size': 1, - 'track_commit_timestamp': 1, 'wal_level': 1}}})) + 'track_commit_timestamp': 1, 'wal_level': 1, 'max_connections': '100'}}})) def test_reload_local_configuration(self): os.environ.update({ @@ -149,3 +150,25 @@ class TestConfig(unittest.TestCase): @patch('os.path.isdir', Mock(return_value=False)) def test_invalid_path(self): self.assertRaises(ConfigParseError, Config, 'postgres0') + + def test__process_postgresql_parameters(self): + expected_params = { + 'f.oo': 'bar', # not in ConfigHandler.CMDLINE_OPTIONS + 'max_connections': 100, # IntValidator + 'wal_level': 'hot_standby', # EnumValidator + } + input_params = deepcopy(expected_params) + + input_params['max_connections'] = '100' + self.assertEqual(self.config._process_postgresql_parameters(input_params), expected_params) + + expected_params['f.oo'] = input_params['f.oo'] = '100' + self.assertEqual(self.config._process_postgresql_parameters(input_params), expected_params) + + input_params['wal_level'] = 'cold_standby' + expected_params.pop('wal_level') + self.assertEqual(self.config._process_postgresql_parameters(input_params), expected_params) + + input_params['max_connections'] = 10 + expected_params.pop('max_connections') + self.assertEqual(self.config._process_postgresql_parameters(input_params), expected_params) diff --git a/tests/test_postgresql.py b/tests/test_postgresql.py index 8022170a..00b90b94 100644 --- a/tests/test_postgresql.py +++ b/tests/test_postgresql.py @@ -682,7 +682,7 @@ class TestPostgresql(BaseTestPostgresql): self.assertIsNone(self.p.wait_for_startup()) def test_get_server_parameters(self): - config = {'parameters': {'wal_level': 'hot_standby'}, 'listen': '0'} + config = {'parameters': {'wal_level': 'hot_standby', 'max_prepared_transactions': 100}, 'listen': '0'} self.p._global_config = GlobalConfig({'synchronous_mode': True}) self.p.config.get_server_parameters(config) self.p._global_config = GlobalConfig({'synchronous_mode': True, 'synchronous_mode_strict': True})