Enforce search_path=pg_catalog for non-replication connections (#2496)

There is a known [vector of attact](https://pganalyze.com/blog/5mins-postgres-security-patch-releases-pgspot-pghostile) by creating functions and/or operators in a public scheme with the same name and signature as corresponding objects in `pg_catalog`.

Since Patroni is heavily relying on superuser connections we want to mitigate it by enforcing `search_path=pg_catalog` for all connections created by Patroni (except replication connections). It is achieved by introducing a new function, that wraps psycopg.connect() and appends ` -c search_path=pg_catalog` to `options` parameter.

In addition to that, we set connection.autocommit to True before returning it.
This commit is contained in:
Alexander Kukushkin
2022-12-20 09:56:14 +01:00
committed by GitHub
parent b6b220dddb
commit 4d77b444dc
9 changed files with 51 additions and 44 deletions

View File

@@ -177,7 +177,7 @@ class PatroniController(AbstractController):
patroni_config_name = self.PATRONI_CONFIG.format(name)
patroni_config_path = os.path.join(self._output_dir, patroni_config_name)
with open(patroni_config_name) as f:
with open('postgres0.yml') as f:
config = yaml.safe_load(f)
config.pop('etcd', None)
@@ -186,8 +186,10 @@ class PatroniController(AbstractController):
os.environ['RAFT_PORT'] = str(int(raft_port) + 1)
config['raft'] = {'data_dir': self._output_dir, 'self_addr': 'localhost:' + os.environ['RAFT_PORT']}
host = config['postgresql']['listen'].split(':')[0]
host = config['restapi']['listen'].rsplit(':', 1)[0]
config['restapi']['listen'] = config['restapi']['connect_address'] = '{0}:{1}'.format(host, 8008+int(name[-1]))
host = config['postgresql']['listen'].rsplit(':', 1)[0]
config['postgresql']['listen'] = config['postgresql']['connect_address'] = '{0}:{1}'.format(host, self.__PORT)
config['name'] = name
@@ -231,7 +233,6 @@ class PatroniController(AbstractController):
def _connection(self):
if not self._conn or self._conn.closed != 0:
self._conn = psycopg.connect(**self._connkwargs)
self._conn.autocommit = True
return self._conn
def _cursor(self):

View File

@@ -28,7 +28,7 @@ def stop_postgres(context, name):
def add_table(context, table_name, pg_name):
# parse the configuration file and get the port
try:
context.pctl.query(pg_name, "CREATE TABLE {0}()".format(table_name))
context.pctl.query(pg_name, "CREATE TABLE public.{0}()".format(table_name))
except pg.Error as e:
assert False, "Error creating table {0} on {1}: {2}".format(table_name, pg_name, e)
@@ -37,9 +37,9 @@ def add_table(context, table_name, pg_name):
def toggle_wal_replay(context, action, pg_name):
# pause or resume the wal replay process
try:
version = context.pctl.query(pg_name, "select pg_catalog.pg_read_file('PG_VERSION', 0, 2)").fetchone()
wal = version and version[0] and int(version[0].split('.')[0]) < 10 and "xlog" or "wal"
context.pctl.query(pg_name, "SELECT pg_{0}_replay_{1}()".format(wal, action))
version = context.pctl.query(pg_name, "SHOW server_version_num").fetchone()[0]
wal_name = 'xlog' if int(version)/10000 < 10 else 'wal'
context.pctl.query(pg_name, "SELECT pg_{0}_replay_{1}()".format(wal_name, action))
except pg.Error as e:
assert False, "Error during {0} wal recovery on {1}: {2}".format(action, pg_name, e)
@@ -47,10 +47,10 @@ def toggle_wal_replay(context, action, pg_name):
@step('I {action:w} table on {pg_name:w}')
def crdr_mytest(context, action, pg_name):
try:
if (action == "create"):
context.pctl.query(pg_name, "create table if not exists mytest(id Numeric)")
else:
context.pctl.query(pg_name, "drop table if exists mytest")
if (action == "create"):
context.pctl.query(pg_name, "create table if not exists public.mytest(id numeric)")
else:
context.pctl.query(pg_name, "drop table if exists public.mytest")
except pg.Error as e:
assert False, "Error {0} table mytest on {1}: {2}".format(action, pg_name, e)
@@ -59,7 +59,7 @@ def crdr_mytest(context, action, pg_name):
def initiate_load(context, pg_name):
# perform dummy load
try:
context.pctl.query(pg_name, "begin; insert into mytest select r::numeric from generate_series(1, 350000) r; commit;")
context.pctl.query(pg_name, "insert into public.mytest select r::numeric from generate_series(1, 350000) r")
except pg.Error as e:
assert False, "Error loading test data on {0}: {1}".format(pg_name, e)
@@ -68,7 +68,7 @@ def initiate_load(context, pg_name):
def table_is_present_on(context, table_name, pg_name, max_replication_delay):
max_replication_delay *= context.timeout_multiplier
for _ in range(int(max_replication_delay)):
if context.pctl.query(pg_name, "SELECT 1 FROM {0}".format(table_name), fail_ok=True) is not None:
if context.pctl.query(pg_name, "SELECT 1 FROM public.{0}".format(table_name), fail_ok=True) is not None:
break
sleep(1)
else:

View File

@@ -1,5 +1,4 @@
import json
import os
import parse
import shlex
import subprocess
@@ -86,10 +85,7 @@ def do_request(context, request_method, url, data):
def do_run(context, cmd):
cmd = [sys.executable, '-m', 'coverage', 'run', '--source=patroni', '-p'] + shlex.split(cmd)
try:
# XXX: Dirty hack! We need to take name/passwd from the config!
env = os.environ.copy()
env.update({'PATRONI_RESTAPI_USERNAME': 'username', 'PATRONI_RESTAPI_PASSWORD': 'password'})
response = subprocess.check_output(cmd, stderr=subprocess.STDOUT, env=env)
response = subprocess.check_output(cmd, stderr=subprocess.STDOUT)
context.status_code = 0
except subprocess.CalledProcessError as e:
response = e.output

View File

@@ -629,7 +629,7 @@ class RestApiHandler(BaseHTTPRequestHandler):
stmt = ("SELECT " + postgresql.POSTMASTER_START_TIME + ", " + postgresql.TL_LSN + ","
" pg_catalog.pg_last_xact_replay_timestamp(),"
" pg_catalog.array_to_json(pg_catalog.array_agg(pg_catalog.row_to_json(ri))) "
"FROM (SELECT (SELECT rolname FROM pg_authid WHERE oid = usesysid) AS usename,"
"FROM (SELECT (SELECT rolname FROM pg_catalog.pg_authid WHERE oid = usesysid) AS usename,"
" application_name, client_addr, w.state, sync_state, sync_priority"
" FROM pg_catalog.pg_stat_get_wal_senders() w, pg_catalog.pg_stat_get_activity(pid)) AS ri")

View File

@@ -274,7 +274,6 @@ def get_cursor(cluster, connect_parameters, role='master', member=None):
from . import psycopg
conn = psycopg.connect(**params)
conn.autocommit = True
cursor = conn.cursor()
if role == 'any':
return cursor

View File

@@ -22,7 +22,6 @@ class Connection(object):
with self._lock:
if not self._connection or self._connection.closed != 0:
self._connection = psycopg.connect(**self._conn_kwargs)
self._connection.autocommit = True
self.server_version = self._connection.server_version
return self._connection
@@ -42,7 +41,6 @@ class Connection(object):
@contextmanager
def get_connection_cursor(**kwargs):
conn = psycopg.connect(**kwargs)
conn.autocommit = True
with conn.cursor() as cur:
yield cur
conn.close()

View File

@@ -6,7 +6,7 @@ try:
from . import MIN_PSYCOPG2, parse_version
if parse_version(__version__) < MIN_PSYCOPG2:
raise ImportError
from psycopg2 import connect, Error, DatabaseError, OperationalError, ProgrammingError
from psycopg2 import connect as _connect, Error, DatabaseError, OperationalError, ProgrammingError
from psycopg2.extensions import adapt
try:
@@ -20,10 +20,10 @@ try:
value.prepare(conn)
return value.getquoted().decode('utf-8')
except ImportError:
from psycopg import connect as _connect, sql, Error, DatabaseError, OperationalError, ProgrammingError
from psycopg import connect as __connect, sql, Error, DatabaseError, OperationalError, ProgrammingError
def connect(*args, **kwargs):
ret = _connect(*args, **kwargs)
def _connect(*args, **kwargs):
ret = __connect(*args, **kwargs)
ret.server_version = ret.pgconn.server_version # compatibility with psycopg2
return ret
@@ -34,6 +34,16 @@ except ImportError:
return sql.Literal(value).as_string(conn)
def connect(*args, **kwargs):
if kwargs and 'replication' not in kwargs and kwargs.get('fallback_application_name') != 'Patroni ctl':
options = [kwargs['options']] if 'options' in kwargs else []
options.append('-c search_path=pg_catalog')
kwargs['options'] = ' '.join(options)
ret = _connect(*args, **kwargs)
ret.autocommit = True
return ret
def quote_ident(value, conn=None):
if _legacy or conn is None:
return '"{0}"'.format(value.replace('"', '""'))

View File

@@ -214,26 +214,26 @@ class WALERestore(object):
attempts_no = 0
while True:
if self.master_connection:
con = None
try:
# get the difference in bytes between the current WAL location and the backup start offset
with psycopg.connect(self.master_connection) as con:
if con.server_version >= 100000:
wal_name = 'wal'
lsn_name = 'lsn'
else:
wal_name = 'xlog'
lsn_name = 'location'
con.autocommit = True
with con.cursor() as cur:
cur.execute(("SELECT CASE WHEN pg_catalog.pg_is_in_recovery()"
" THEN GREATEST(pg_catalog.pg_{0}_{1}_diff(COALESCE("
"pg_last_{0}_receive_{1}(), '0/0'), %s)::bigint, "
"pg_catalog.pg_{0}_{1}_diff(pg_catalog.pg_last_{0}_replay_{1}(), %s)::bigint)"
" ELSE pg_catalog.pg_{0}_{1}_diff(pg_catalog.pg_current_{0}_{1}(), %s)::bigint"
" END").format(wal_name, lsn_name),
(backup_start_lsn, backup_start_lsn, backup_start_lsn))
con = psycopg.connect(self.master_connection)
if con.server_version >= 100000:
wal_name = 'wal'
lsn_name = 'lsn'
else:
wal_name = 'xlog'
lsn_name = 'location'
with con.cursor() as cur:
cur.execute(("SELECT CASE WHEN pg_catalog.pg_is_in_recovery()"
" THEN GREATEST(pg_catalog.pg_{0}_{1}_diff(COALESCE("
"pg_last_{0}_receive_{1}(), '0/0'), %s)::bigint, "
"pg_catalog.pg_{0}_{1}_diff(pg_catalog.pg_last_{0}_replay_{1}(), %s)::bigint)"
" ELSE pg_catalog.pg_{0}_{1}_diff(pg_catalog.pg_current_{0}_{1}(), %s)::bigint"
" END").format(wal_name, lsn_name),
(backup_start_lsn, backup_start_lsn, backup_start_lsn))
diff_in_bytes = int(cur.fetchone()[0])
diff_in_bytes = int(cur.fetchone()[0])
except psycopg.Error:
logger.exception('could not determine difference with the master location')
if attempts_no < self.retries: # retry in case of a temporarily connection issue
@@ -246,6 +246,9 @@ class WALERestore(object):
logger.info("continue with base backup from S3 since master is not available")
diff_in_bytes = 0
break
finally:
if con:
con.close()
else:
# always try to use WAL-E if master connection string is not available
diff_in_bytes = 0

View File

@@ -177,7 +177,7 @@ class PostgresInit(unittest.TestCase):
'force_parallel_mode': '1', 'constraint_exclusion': '',
'max_stack_depth': 'Z', 'vacuum_cost_limit': -1, 'vacuum_cost_delay': 200}
@patch('patroni.psycopg.connect', psycopg_connect)
@patch('patroni.psycopg._connect', psycopg_connect)
@patch('patroni.postgresql.CallbackExecutor', Mock())
@patch.object(ConfigHandler, 'write_postgresql_conf', Mock())
@patch.object(ConfigHandler, 'replace_pg_hba', Mock())