mirror of
https://github.com/outbackdingo/patroni.git
synced 2026-01-27 18:20:05 +00:00
Enforce search_path=pg_catalog for non-replication connections (#2496)
There is a known [vector of attact](https://pganalyze.com/blog/5mins-postgres-security-patch-releases-pgspot-pghostile) by creating functions and/or operators in a public scheme with the same name and signature as corresponding objects in `pg_catalog`. Since Patroni is heavily relying on superuser connections we want to mitigate it by enforcing `search_path=pg_catalog` for all connections created by Patroni (except replication connections). It is achieved by introducing a new function, that wraps psycopg.connect() and appends ` -c search_path=pg_catalog` to `options` parameter. In addition to that, we set connection.autocommit to True before returning it.
This commit is contained in:
committed by
GitHub
parent
b6b220dddb
commit
4d77b444dc
@@ -177,7 +177,7 @@ class PatroniController(AbstractController):
|
||||
patroni_config_name = self.PATRONI_CONFIG.format(name)
|
||||
patroni_config_path = os.path.join(self._output_dir, patroni_config_name)
|
||||
|
||||
with open(patroni_config_name) as f:
|
||||
with open('postgres0.yml') as f:
|
||||
config = yaml.safe_load(f)
|
||||
config.pop('etcd', None)
|
||||
|
||||
@@ -186,8 +186,10 @@ class PatroniController(AbstractController):
|
||||
os.environ['RAFT_PORT'] = str(int(raft_port) + 1)
|
||||
config['raft'] = {'data_dir': self._output_dir, 'self_addr': 'localhost:' + os.environ['RAFT_PORT']}
|
||||
|
||||
host = config['postgresql']['listen'].split(':')[0]
|
||||
host = config['restapi']['listen'].rsplit(':', 1)[0]
|
||||
config['restapi']['listen'] = config['restapi']['connect_address'] = '{0}:{1}'.format(host, 8008+int(name[-1]))
|
||||
|
||||
host = config['postgresql']['listen'].rsplit(':', 1)[0]
|
||||
config['postgresql']['listen'] = config['postgresql']['connect_address'] = '{0}:{1}'.format(host, self.__PORT)
|
||||
|
||||
config['name'] = name
|
||||
@@ -231,7 +233,6 @@ class PatroniController(AbstractController):
|
||||
def _connection(self):
|
||||
if not self._conn or self._conn.closed != 0:
|
||||
self._conn = psycopg.connect(**self._connkwargs)
|
||||
self._conn.autocommit = True
|
||||
return self._conn
|
||||
|
||||
def _cursor(self):
|
||||
|
||||
@@ -28,7 +28,7 @@ def stop_postgres(context, name):
|
||||
def add_table(context, table_name, pg_name):
|
||||
# parse the configuration file and get the port
|
||||
try:
|
||||
context.pctl.query(pg_name, "CREATE TABLE {0}()".format(table_name))
|
||||
context.pctl.query(pg_name, "CREATE TABLE public.{0}()".format(table_name))
|
||||
except pg.Error as e:
|
||||
assert False, "Error creating table {0} on {1}: {2}".format(table_name, pg_name, e)
|
||||
|
||||
@@ -37,9 +37,9 @@ def add_table(context, table_name, pg_name):
|
||||
def toggle_wal_replay(context, action, pg_name):
|
||||
# pause or resume the wal replay process
|
||||
try:
|
||||
version = context.pctl.query(pg_name, "select pg_catalog.pg_read_file('PG_VERSION', 0, 2)").fetchone()
|
||||
wal = version and version[0] and int(version[0].split('.')[0]) < 10 and "xlog" or "wal"
|
||||
context.pctl.query(pg_name, "SELECT pg_{0}_replay_{1}()".format(wal, action))
|
||||
version = context.pctl.query(pg_name, "SHOW server_version_num").fetchone()[0]
|
||||
wal_name = 'xlog' if int(version)/10000 < 10 else 'wal'
|
||||
context.pctl.query(pg_name, "SELECT pg_{0}_replay_{1}()".format(wal_name, action))
|
||||
except pg.Error as e:
|
||||
assert False, "Error during {0} wal recovery on {1}: {2}".format(action, pg_name, e)
|
||||
|
||||
@@ -47,10 +47,10 @@ def toggle_wal_replay(context, action, pg_name):
|
||||
@step('I {action:w} table on {pg_name:w}')
|
||||
def crdr_mytest(context, action, pg_name):
|
||||
try:
|
||||
if (action == "create"):
|
||||
context.pctl.query(pg_name, "create table if not exists mytest(id Numeric)")
|
||||
else:
|
||||
context.pctl.query(pg_name, "drop table if exists mytest")
|
||||
if (action == "create"):
|
||||
context.pctl.query(pg_name, "create table if not exists public.mytest(id numeric)")
|
||||
else:
|
||||
context.pctl.query(pg_name, "drop table if exists public.mytest")
|
||||
except pg.Error as e:
|
||||
assert False, "Error {0} table mytest on {1}: {2}".format(action, pg_name, e)
|
||||
|
||||
@@ -59,7 +59,7 @@ def crdr_mytest(context, action, pg_name):
|
||||
def initiate_load(context, pg_name):
|
||||
# perform dummy load
|
||||
try:
|
||||
context.pctl.query(pg_name, "begin; insert into mytest select r::numeric from generate_series(1, 350000) r; commit;")
|
||||
context.pctl.query(pg_name, "insert into public.mytest select r::numeric from generate_series(1, 350000) r")
|
||||
except pg.Error as e:
|
||||
assert False, "Error loading test data on {0}: {1}".format(pg_name, e)
|
||||
|
||||
@@ -68,7 +68,7 @@ def initiate_load(context, pg_name):
|
||||
def table_is_present_on(context, table_name, pg_name, max_replication_delay):
|
||||
max_replication_delay *= context.timeout_multiplier
|
||||
for _ in range(int(max_replication_delay)):
|
||||
if context.pctl.query(pg_name, "SELECT 1 FROM {0}".format(table_name), fail_ok=True) is not None:
|
||||
if context.pctl.query(pg_name, "SELECT 1 FROM public.{0}".format(table_name), fail_ok=True) is not None:
|
||||
break
|
||||
sleep(1)
|
||||
else:
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import json
|
||||
import os
|
||||
import parse
|
||||
import shlex
|
||||
import subprocess
|
||||
@@ -86,10 +85,7 @@ def do_request(context, request_method, url, data):
|
||||
def do_run(context, cmd):
|
||||
cmd = [sys.executable, '-m', 'coverage', 'run', '--source=patroni', '-p'] + shlex.split(cmd)
|
||||
try:
|
||||
# XXX: Dirty hack! We need to take name/passwd from the config!
|
||||
env = os.environ.copy()
|
||||
env.update({'PATRONI_RESTAPI_USERNAME': 'username', 'PATRONI_RESTAPI_PASSWORD': 'password'})
|
||||
response = subprocess.check_output(cmd, stderr=subprocess.STDOUT, env=env)
|
||||
response = subprocess.check_output(cmd, stderr=subprocess.STDOUT)
|
||||
context.status_code = 0
|
||||
except subprocess.CalledProcessError as e:
|
||||
response = e.output
|
||||
|
||||
@@ -629,7 +629,7 @@ class RestApiHandler(BaseHTTPRequestHandler):
|
||||
stmt = ("SELECT " + postgresql.POSTMASTER_START_TIME + ", " + postgresql.TL_LSN + ","
|
||||
" pg_catalog.pg_last_xact_replay_timestamp(),"
|
||||
" pg_catalog.array_to_json(pg_catalog.array_agg(pg_catalog.row_to_json(ri))) "
|
||||
"FROM (SELECT (SELECT rolname FROM pg_authid WHERE oid = usesysid) AS usename,"
|
||||
"FROM (SELECT (SELECT rolname FROM pg_catalog.pg_authid WHERE oid = usesysid) AS usename,"
|
||||
" application_name, client_addr, w.state, sync_state, sync_priority"
|
||||
" FROM pg_catalog.pg_stat_get_wal_senders() w, pg_catalog.pg_stat_get_activity(pid)) AS ri")
|
||||
|
||||
|
||||
@@ -274,7 +274,6 @@ def get_cursor(cluster, connect_parameters, role='master', member=None):
|
||||
|
||||
from . import psycopg
|
||||
conn = psycopg.connect(**params)
|
||||
conn.autocommit = True
|
||||
cursor = conn.cursor()
|
||||
if role == 'any':
|
||||
return cursor
|
||||
|
||||
@@ -22,7 +22,6 @@ class Connection(object):
|
||||
with self._lock:
|
||||
if not self._connection or self._connection.closed != 0:
|
||||
self._connection = psycopg.connect(**self._conn_kwargs)
|
||||
self._connection.autocommit = True
|
||||
self.server_version = self._connection.server_version
|
||||
return self._connection
|
||||
|
||||
@@ -42,7 +41,6 @@ class Connection(object):
|
||||
@contextmanager
|
||||
def get_connection_cursor(**kwargs):
|
||||
conn = psycopg.connect(**kwargs)
|
||||
conn.autocommit = True
|
||||
with conn.cursor() as cur:
|
||||
yield cur
|
||||
conn.close()
|
||||
|
||||
@@ -6,7 +6,7 @@ try:
|
||||
from . import MIN_PSYCOPG2, parse_version
|
||||
if parse_version(__version__) < MIN_PSYCOPG2:
|
||||
raise ImportError
|
||||
from psycopg2 import connect, Error, DatabaseError, OperationalError, ProgrammingError
|
||||
from psycopg2 import connect as _connect, Error, DatabaseError, OperationalError, ProgrammingError
|
||||
from psycopg2.extensions import adapt
|
||||
|
||||
try:
|
||||
@@ -20,10 +20,10 @@ try:
|
||||
value.prepare(conn)
|
||||
return value.getquoted().decode('utf-8')
|
||||
except ImportError:
|
||||
from psycopg import connect as _connect, sql, Error, DatabaseError, OperationalError, ProgrammingError
|
||||
from psycopg import connect as __connect, sql, Error, DatabaseError, OperationalError, ProgrammingError
|
||||
|
||||
def connect(*args, **kwargs):
|
||||
ret = _connect(*args, **kwargs)
|
||||
def _connect(*args, **kwargs):
|
||||
ret = __connect(*args, **kwargs)
|
||||
ret.server_version = ret.pgconn.server_version # compatibility with psycopg2
|
||||
return ret
|
||||
|
||||
@@ -34,6 +34,16 @@ except ImportError:
|
||||
return sql.Literal(value).as_string(conn)
|
||||
|
||||
|
||||
def connect(*args, **kwargs):
|
||||
if kwargs and 'replication' not in kwargs and kwargs.get('fallback_application_name') != 'Patroni ctl':
|
||||
options = [kwargs['options']] if 'options' in kwargs else []
|
||||
options.append('-c search_path=pg_catalog')
|
||||
kwargs['options'] = ' '.join(options)
|
||||
ret = _connect(*args, **kwargs)
|
||||
ret.autocommit = True
|
||||
return ret
|
||||
|
||||
|
||||
def quote_ident(value, conn=None):
|
||||
if _legacy or conn is None:
|
||||
return '"{0}"'.format(value.replace('"', '""'))
|
||||
|
||||
@@ -214,26 +214,26 @@ class WALERestore(object):
|
||||
attempts_no = 0
|
||||
while True:
|
||||
if self.master_connection:
|
||||
con = None
|
||||
try:
|
||||
# get the difference in bytes between the current WAL location and the backup start offset
|
||||
with psycopg.connect(self.master_connection) as con:
|
||||
if con.server_version >= 100000:
|
||||
wal_name = 'wal'
|
||||
lsn_name = 'lsn'
|
||||
else:
|
||||
wal_name = 'xlog'
|
||||
lsn_name = 'location'
|
||||
con.autocommit = True
|
||||
with con.cursor() as cur:
|
||||
cur.execute(("SELECT CASE WHEN pg_catalog.pg_is_in_recovery()"
|
||||
" THEN GREATEST(pg_catalog.pg_{0}_{1}_diff(COALESCE("
|
||||
"pg_last_{0}_receive_{1}(), '0/0'), %s)::bigint, "
|
||||
"pg_catalog.pg_{0}_{1}_diff(pg_catalog.pg_last_{0}_replay_{1}(), %s)::bigint)"
|
||||
" ELSE pg_catalog.pg_{0}_{1}_diff(pg_catalog.pg_current_{0}_{1}(), %s)::bigint"
|
||||
" END").format(wal_name, lsn_name),
|
||||
(backup_start_lsn, backup_start_lsn, backup_start_lsn))
|
||||
con = psycopg.connect(self.master_connection)
|
||||
if con.server_version >= 100000:
|
||||
wal_name = 'wal'
|
||||
lsn_name = 'lsn'
|
||||
else:
|
||||
wal_name = 'xlog'
|
||||
lsn_name = 'location'
|
||||
with con.cursor() as cur:
|
||||
cur.execute(("SELECT CASE WHEN pg_catalog.pg_is_in_recovery()"
|
||||
" THEN GREATEST(pg_catalog.pg_{0}_{1}_diff(COALESCE("
|
||||
"pg_last_{0}_receive_{1}(), '0/0'), %s)::bigint, "
|
||||
"pg_catalog.pg_{0}_{1}_diff(pg_catalog.pg_last_{0}_replay_{1}(), %s)::bigint)"
|
||||
" ELSE pg_catalog.pg_{0}_{1}_diff(pg_catalog.pg_current_{0}_{1}(), %s)::bigint"
|
||||
" END").format(wal_name, lsn_name),
|
||||
(backup_start_lsn, backup_start_lsn, backup_start_lsn))
|
||||
|
||||
diff_in_bytes = int(cur.fetchone()[0])
|
||||
diff_in_bytes = int(cur.fetchone()[0])
|
||||
except psycopg.Error:
|
||||
logger.exception('could not determine difference with the master location')
|
||||
if attempts_no < self.retries: # retry in case of a temporarily connection issue
|
||||
@@ -246,6 +246,9 @@ class WALERestore(object):
|
||||
logger.info("continue with base backup from S3 since master is not available")
|
||||
diff_in_bytes = 0
|
||||
break
|
||||
finally:
|
||||
if con:
|
||||
con.close()
|
||||
else:
|
||||
# always try to use WAL-E if master connection string is not available
|
||||
diff_in_bytes = 0
|
||||
|
||||
@@ -177,7 +177,7 @@ class PostgresInit(unittest.TestCase):
|
||||
'force_parallel_mode': '1', 'constraint_exclusion': '',
|
||||
'max_stack_depth': 'Z', 'vacuum_cost_limit': -1, 'vacuum_cost_delay': 200}
|
||||
|
||||
@patch('patroni.psycopg.connect', psycopg_connect)
|
||||
@patch('patroni.psycopg._connect', psycopg_connect)
|
||||
@patch('patroni.postgresql.CallbackExecutor', Mock())
|
||||
@patch.object(ConfigHandler, 'write_postgresql_conf', Mock())
|
||||
@patch.object(ConfigHandler, 'replace_pg_hba', Mock())
|
||||
|
||||
Reference in New Issue
Block a user