diff --git a/patroni/api.py b/patroni/api.py index 45b50a4d..0fb39157 100644 --- a/patroni/api.py +++ b/patroni/api.py @@ -1,3 +1,11 @@ +"""Implement Patroni's REST API. + +Exposes a REST API of patroni operations functions, such as status, performance and management to web clients. + +Much of what can be achieved with the command line tool patronictl can be done via the API. Patroni CLI and daemon +utilises the API to perform these functions. +""" + import base64 import hmac import json @@ -11,13 +19,15 @@ import socket import sys from http.server import BaseHTTPRequestHandler, HTTPServer -from ipaddress import ip_address, ip_network +from ipaddress import ip_address, ip_network, IPv4Network, IPv6Network from socketserver import ThreadingMixIn from threading import Thread from urllib.parse import urlparse, parse_qs -from typing import Any, Dict, Optional, Union + +from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union from . import psycopg +from .__main__ import Patroni from .dcs import Cluster from .exceptions import PostgresConnectionException, PostgresException from .postgresql.misc import postgres_version_to_int @@ -27,14 +37,97 @@ from .utils import deep_compare, enable_keepalive, parse_bool, patch_config, Ret logger = logging.getLogger(__name__) -class RestApiHandler(BaseHTTPRequestHandler): +def check_access(func: Callable[['RestApiHandler'], None]) -> Callable[..., None]: + """Check the source ip, authorization header, or client certificates. - def _write_status_code_only(self, status_code): + .. note:: + The actual logic to check access is implemented through :func:`RestApiServer.check_access`. + + :param func: function to be decorated. + + :returns: a decorator that executes *func* only if :func:`RestApiServer.check_access` returns ``True``. + + :Example: + + @check_access + def do_PUT_foo(): + pass + """ + + def wrapper(self: 'RestApiHandler', *args: Any, **kwargs: Any) -> None: + if self.server.check_access(self): + return func(self, *args, **kwargs) + + return wrapper + + +class RestApiHandler(BaseHTTPRequestHandler): + """Define how to handle each of the requests that are made against the REST API server.""" + + # Comment from pyi stub file. These unions can cause typing errors with IDEs, e.g. PyCharm + # + # Those are technically of types, respectively: + # * _RequestType = Union[socket.socket, Tuple[bytes, socket.socket]] + # * _AddressType = Tuple[str, int] + # But there are some concerns that having unions here would cause + # too much inconvenience to people using it (see + # https://github.com/python/typeshed/pull/384#issuecomment-234649696) + + def __init__(self, request: Any, + client_address: Any, + server: Union['RestApiServer', HTTPServer]) -> None: + """Create a :class:`RestApiHandler` instance. + + .. note:: + Currently not different from its superclass :func:`__init__`, and only used so ``pyright`` can understand + the type of ``server`` attribute. + + :param request: client request to be processed. + :param client_address: address of the client connection. + :param server: HTTP server that received the request. + """ + assert isinstance(server, RestApiServer) + super(RestApiHandler, self).__init__(request, client_address, server) + self.server: 'RestApiServer' = server + self.__start_time: float = 0.0 + self.path_query: Dict[str, List[str]] = {} + + def _write_status_code_only(self, status_code: int) -> None: + """Write a response that is composed only of the HTTP status. + + The response is written with these values separated by space: + * HTTP protocol version; + * *status_code*; + * description of *status_code*. + + .. note:: + This is usually useful for replying to requests from software like HAProxy. + + :param status_code: HTTP status code. + + :Example: + + * ``_write_status_code_only(200)`` would write a response like ``HTTP/1.0 200 OK``. + """ message = self.responses[status_code][0] self.wfile.write('{0} {1} {2}\r\n\r\n'.format(self.protocol_version, status_code, message).encode('utf-8')) self.log_request(status_code) - def _write_response(self, status_code, body, content_type='text/html', headers=None): + def _write_response(self, status_code: int, body: str, content_type: str = 'text/html', + headers: Optional[Dict[str, str]] = None) -> None: + """Write an HTTP response. + + .. note:: + Besides ``Content-Type`` header, and the HTTP headers passed through *headers*, this function will also + write the HTTP headers defined through ``restapi.http_extra_headers`` and ``restapi.https_extra_headers`` + from Patroni configuration. + + :param status_code: response HTTP status code. + :param body: response body. + :param content_type: value for ``Content-Type`` HTTP header. + :param headers: dictionary of additional HTTP headers to set for the response. Each key is the header name, and + the corresponding value is the value for the header in the response. + """ # TODO: try-catch ConnectionResetError: [Errno 104] Connection reset by peer and log it in DEBUG level self.send_response(status_code) headers = headers or {} @@ -42,31 +135,45 @@ class RestApiHandler(BaseHTTPRequestHandler): headers['Content-Type'] = content_type for name, value in headers.items(): self.send_header(name, value) - for name, value in self.server.http_extra_headers.items(): + for name, value in (self.server.http_extra_headers or {}).items(): self.send_header(name, value) self.end_headers() self.wfile.write(body.encode('utf-8')) - def _write_json_response(self, status_code, response): + def _write_json_response(self, status_code: int, response: Any) -> None: + """Write an HTTP response with a JSON content type. + + Call :func:`_write_response` with ``content_type`` as ``application/json``. + + :param status_code: response HTTP status code. + :param response: value to be dumped as a JSON string and to be used as the response body. + """ self._write_response(status_code, json.dumps(response, default=str), content_type='application/json') - def check_access(func): - """Decorator function to check the source ip, authorization header. or client certificates - - Usage example: - @check_access - def do_PUT_foo(): - pass - """ - - def wrapper(self, *args, **kwargs): - if self.server.check_access(self): - return func(self, *args, **kwargs) - - return wrapper - def _write_status_response(self, status_code: int, response: Dict[str, Any]) -> None: - """Sends HTTP response with Patroni/Postgres status in JSON format.""" + """Write an HTTP response with Patroni/Postgres status in JSON format. + + Modifies *response* before sending it to the client. Defines the ``patroni`` key, which is a + dictionary that contains the mandatory keys: + + * ``version``: Patroni version, e.g. ``3.0.2``; + * ``scope``: value of ``scope`` setting from Patroni configuration. + + May also add the following optional keys, depending on the status of this Patroni/PostgreSQL node: + + * ``tags``: tags that were set through Patroni configuration merged with dynamically applied tags; + * ``database_system_identifier``: ``Database system identifier`` from ``pg_controldata`` output; + * ``pending_restart``: ``True`` if PostgreSQL is pending to be restarted; + * ``scheduled_restart``: a dictionary with a single key ``schedule``, which is the timestamp for the scheduled + restart; + * ``watchdog_failed``: ``True`` if watchdog device is unhealthy; + * ``logger_queue_size``: log queue length if it is longer than expected; + * ``logger_records_lost``: number of log records that have been lost while the log queue was full. + + :param status_code: response HTTP status code. + :param response: represents the status of the PostgreSQL node, and is used as a basis for the HTTP response. + This dictionary is built through :func:`get_postgresql_status`. + """ patroni = self.server.patroni tags = patroni.ha.get_effective_tags() if tags: @@ -76,7 +183,7 @@ class RestApiHandler(BaseHTTPRequestHandler): if patroni.postgresql.pending_restart: response['pending_restart'] = True response['patroni'] = {'version': patroni.version, 'scope': patroni.postgresql.scope} - if patroni.scheduled_restart and isinstance(patroni.scheduled_restart, dict): + if patroni.scheduled_restart: response['scheduled_restart'] = patroni.scheduled_restart.copy() del response['scheduled_restart']['postmaster_start_time'] response['scheduled_restart']['schedule'] = (response['scheduled_restart']['schedule']).isoformat() @@ -90,15 +197,52 @@ class RestApiHandler(BaseHTTPRequestHandler): response['logger_records_lost'] = lost self._write_json_response(status_code, response) - def do_GET(self, write_status_code_only: Optional[bool] = False) -> None: - """Default method for processing all GET requests which can not be routed to other methods. + def do_GET(self, write_status_code_only: bool = False) -> None: + """Process all GET requests which can not be routed to other methods. - Is used for handling all health-checks requests. E.g. "GET /(primary|replica|sync|async|etc...)" - :param write_status_code_only: indicates that instead of normal HTTP response we should - send only HTTP Status Code and close the connection. - It is useful to when health-checks are executed by HAProxy. + Is used for handling all health-checks requests. E.g. "GET /(primary|replica|sync|async|etc...)". + + The (optional) query parameters and the HTTP response status depend on the requested path: + * ``/``, ``primary``, or ``read-write``: + * HTTP status ``200``: if a primary with the leader lock. + * ``/standby-leader``: + * HTTP status ``200``: if holds the leader lock in a standby cluster. + * ``/leader``: + * HTTP status ``200``: if holds the leader lock. + * ``/replica``: + * Query parameters: + * ``lag``: only accept replication lag up to ``lag``. Accepts either an :class:`int`, which + represents lag in bytes, or a :class:`str` representing lag in human-readable format (e.g. + ``10MB``). + * Any custom parameter: will attempt to match them against node tags. + * HTTP status ``200``: if up and running as a standby and without ``noloadbalance`` tag. + * ``/read-only``: + * HTTP status ``200``: if up and running and without ``noloadbalance`` tag. + * ``/synchronous`` or ``/sync``: + * HTTP status ``200``: if up and running as a synchronous standby. + * ``/read-only-sync``: + * HTTP status ``200``: if up and running as a synchronous standby or primary. + * ``/asynchronous``: + * Query parameters: + * ``lag``: only accept replication lag up to ``lag``. Accepts either an :class:`int`, which + represents lag in bytes, or a :class:`str` representing lag in human-readable format (e.g. + ``10MB``). + * HTTP status ``200``: if up and running as an asynchronous standby. + * ``/health``: + * HTTP status ``200``: if up and running. + + .. note:: + If not able to honor the query parameter, or not able to match the condition described for HTTP status + ``200`` in each path above, then HTTP status will be ``503``. + + .. note:: + Independently of the requested path, if *write_status_code_only* is ``False``, then it always write an HTTP + response through :func:`_write_status_response`, with the node status. + + :param write_status_code_only: indicates that instead of a normal HTTP response we should + send only the HTTP Status Code and close the connection. + Useful when health-checks are executed by HAProxy. """ - path = '/primary' if self.path == '/' else self.path response = self.get_postgresql_status() @@ -185,14 +329,33 @@ class RestApiHandler(BaseHTTPRequestHandler): else: self._write_status_response(status_code, response) - def do_OPTIONS(self): + def do_OPTIONS(self) -> None: + """Handle an ``OPTIONS`` request. + + Write a simple HTTP response that represents the current PostgreSQL status. Send only `200 OK` or + `503 Service Unavailable` as a response and nothing more, particularly no headers. + """ self.do_GET(write_status_code_only=True) - def do_HEAD(self): + def do_HEAD(self) -> None: + """Handle a ``HEAD`` request. + + Write a simple HTTP response that represents the current PostgreSQL status. Send only `200 OK` or + `503 Service Unavailable` as a response and nothing more, particularly no headers. + """ self.do_GET(write_status_code_only=True) - def do_GET_liveness(self): - patroni = self.server.patroni + def do_GET_liveness(self) -> None: + """Handle a ``GET`` request to ``/liveness`` path. + + Write a simple HTTP response with HTTP status: + * ``200``: + * If the cluster is in maintenance mode; or + * If Patroni heartbeat loop is properly running; + * ``503`` if Patroni heartbeat loop last run was more than ``ttl`` setting ago on the primary (or twice the + value of ``ttl`` on a replica). + """ + patroni: Patroni = self.server.patroni is_primary = patroni.postgresql.role in ('master', 'primary') and patroni.postgresql.is_running() # We can tolerate Patroni problems longer on the replica. # On the primary the liveness probe most likely will start failing only after the leader key expired. @@ -203,7 +366,15 @@ class RestApiHandler(BaseHTTPRequestHandler): status_code = 200 if patroni.ha.is_paused() or patroni.next_run + liveness_threshold > time.time() else 503 self._write_status_code_only(status_code) - def do_GET_readiness(self): + def do_GET_readiness(self) -> None: + """Handle a ``GET`` request to ``/readiness`` path. + + Write a simple HTTP response which HTTP status can be: + * ``200``: + * If this Patroni node holds the DCS leader lock; or + * If this PostgreSQL instance is up and running; + * ``503``: if none of the previous conditions apply. + """ patroni = self.server.patroni if patroni.ha.is_leader(): status_code = 200 @@ -213,21 +384,51 @@ class RestApiHandler(BaseHTTPRequestHandler): status_code = 503 self._write_status_code_only(status_code) - def do_GET_patroni(self): + def do_GET_patroni(self) -> None: + """Handle a ``GET`` request to ``/patroni`` path. + + Write an HTTP response through :func:`_write_status_response`, with HTTP status ``200`` and the status of + Postgres. + """ response = self.get_postgresql_status(True) self._write_status_response(200, response) def do_GET_cluster(self) -> None: - """Sends response with JSON representaion of Cluster topology.""" + """Handle a ``GET`` request to ``/cluster`` path. + + Write an HTTP response with JSON content based on the output of :func:`cluster_as_json`, with HTTP status + ``200`` and the JSON representation of the cluster topology. + """ cluster = self.server.patroni.dcs.get_cluster(True) global_config = self.server.patroni.config.get_global_config(cluster) self._write_json_response(200, cluster_as_json(cluster, global_config)) - def do_GET_history(self): + def do_GET_history(self) -> None: + """Handle a ``GET`` request to ``/history`` path. + + Write an HTTP response with a JSON content representing the history of events in the cluster, with HTTP status + ``200``. + + The response contains a :class:`list` of failover/switchover events. Each item is a :class:`list` with the + following items: + * Timeline when the event occurred (class:`int`); + * LSN at which the event occurred (class:`int`); + * The reason for the event (class:`str`); + * Timestamp when the new timeline was created (class:`str`); + * Name of the involved Patroni node (class:`str`). + """ cluster = self.server.patroni.dcs.cluster or self.server.patroni.dcs.get_cluster() self._write_json_response(200, cluster.history and cluster.history.lines or []) - def do_GET_config(self): + def do_GET_config(self) -> None: + """Handle a ``GET`` request to ``/config`` path. + + Write an HTTP response with a JSON content representing the Patroni configuration that is stored in the DCS, + with HTTP status ``200``. + + If the cluster information is not available in the DCS, then it will respond with no body and HTTP status + ``502`` instead. + """ cluster = self.server.patroni.dcs.cluster or self.server.patroni.dcs.get_cluster() if cluster.config: self._write_json_response(200, cluster.config.data) @@ -235,12 +436,38 @@ class RestApiHandler(BaseHTTPRequestHandler): self.send_error(502) def do_GET_metrics(self) -> None: - """Sends response in Prometheus format.""" + """Handle a ``GET`` request to ``/metrics`` path. + + Write an HTTP response with plain text content in the format used by Prometheus, with HTTP status ``200``. + + The response contains the following items: + + * ``patroni_version``: Patroni version without periods, e.g. ``030002`` for Patroni ``3.0.2``; + * ``patroni_postgres_running``: ``1`` if PostgreSQL is running, else ``0``; + * ``patroni_postmaster_start_time``: epoch timestamp since Postmaster was started; + * ``patroni_master``: ``1`` if this node holds the leader lock, else ``0``; + * ``patroni_primary``: same as ``patroni_master``; + * ``patroni_xlog_location``: ``pg_wal_lsn_diff(pg_current_wal_lsn(), '0/0')`` if leader, else ``0``; + * ``patroni_standby_leader``: ``1`` if standby leader node, else ``0``; + * ``patroni_replica``: ``1`` if a replica, else ``0``; + * ``patroni_sync_standby``: ``1`` if a sync replica, else ``0``; + * ``patroni_xlog_received_location``: ``pg_wal_lsn_diff(pg_last_wal_receive_lsn(), '0/0')``; + * ``patroni_xlog_replayed_location``: ``pg_wal_lsn_diff(pg_last_wal_replay_lsn(), '0/0)``; + * ``patroni_xlog_replayed_timestamp``: ``pg_last_xact_replay_timestamp``; + * ``patroni_xlog_paused``: ``pg_is_wal_replay_paused()``; + * ``patroni_postgres_server_version``: Postgres version without periods, e.g. ``150002`` for Postgres ``15.2``; + * ``patroni_cluster_unlocked``: ``1`` if no one holds the leader lock, else ``0``; + * ``patroni_failsafe_mode_is_active``: ``1`` if ``failmode`` is currently active, else ``0``; + * ``patroni_postgres_timeline``: PostgreSQL timeline based on current WAL file name; + * ``patroni_dcs_last_seen``: epoch timestamp when DCS was last contacted successfully; + * ``patroni_pending_restart``: ``1`` if this PostgreSQL node is pending a restart, else ``0``; + * ``patroni_is_paused``: ``1`` if Patroni is in maintenance node, else ``0``. + """ postgres = self.get_postgresql_status(True) patroni = self.server.patroni epoch = datetime.datetime(1970, 1, 1, tzinfo=tzutc) - metrics = [] + metrics: List[str] = [] scope_label = '{{scope="{0}"}}'.format(patroni.postgresql.scope) metrics.append("# HELP patroni_version Patroni semver without periods.") @@ -315,7 +542,7 @@ class RestApiHandler(BaseHTTPRequestHandler): metrics.append("# TYPE patroni_cluster_unlocked gauge") metrics.append("patroni_cluster_unlocked{0} {1}".format(scope_label, int(postgres.get('cluster_unlocked', 0)))) - metrics.append("# HELP patroni_failsafe_mode_is_active Value is 1 if the cluster is unlocked, 0 if locked.") + metrics.append("# HELP patroni_failsafe_mode_is_active Value is 1 if failsafe mode is active, 0 if inactive.") metrics.append("# TYPE patroni_failsafe_mode_is_active gauge") metrics.append("patroni_failsafe_mode_is_active{0} {1}" .format(scope_label, int(postgres.get('failsafe_mode_is_active', 0)))) @@ -340,14 +567,32 @@ class RestApiHandler(BaseHTTPRequestHandler): self._write_response(200, '\n'.join(metrics) + '\n', content_type='text/plain') - def _read_json_content(self, body_is_optional=False): + def _read_json_content(self, body_is_optional: bool = False) -> Optional[Dict[Any, Any]]: + """Read JSON from HTTP request body. + + .. note:: + Retrieves the request body based on `content-length` HTTP header. The body is expected to be a JSON + string with that length. + + If request body is expected but `content-length` HTTP header is absent, then write an HTTP response + with HTTP status ``411``. + + If request body is expected but contains nothing, or if an exception is faced, then write an HTTP + response with HTTP status ``400``. + + :param body_is_optional: if ``False`` then the request must contain a body. If ``True``, then the request may or + may not contain a body. + + :returns: deserialized JSON string from request body, if present. If body is absent, but *body_is_optional* is + ``True``, then return an empty dictionary. Returns ``None`` otherwise. + """ if 'content-length' not in self.headers: return self.send_error(411) if not body_is_optional else {} try: - content_length = int(self.headers.get('content-length')) + content_length = int(self.headers.get('content-length') or 0) if content_length == 0 and body_is_optional: return {} - request = json.loads(self.rfile.read(content_length).decode('utf-8')) + request: Union[Dict[str, Any], Any] = json.loads(self.rfile.read(content_length).decode('utf-8')) if isinstance(request, dict) and (request or body_is_optional): return request except Exception: @@ -355,7 +600,18 @@ class RestApiHandler(BaseHTTPRequestHandler): self.send_error(400) @check_access - def do_PATCH_config(self): + def do_PATCH_config(self) -> None: + """Handle a ``PATCH`` request to ``/config`` path. + + Updates the Patroni configuration based on the JSON request body, then writes a response with the new + configuration, with HTTP status ``200``. + + .. note:: + If the configuration has been previously wiped out from DCS, then write a response with + HTTP status ``503``. + + If applying a configuration value fails, then write a response with HTTP status ``409``. + """ request = self._read_json_content() if request: cluster = self.server.patroni.dcs.get_cluster(True) @@ -370,22 +626,43 @@ class RestApiHandler(BaseHTTPRequestHandler): self._write_json_response(200, data) @check_access - def do_PUT_config(self): + def do_PUT_config(self) -> None: + """Handle a ``PUT`` request to ``/config`` path. + + Overwrites the Patroni configuration based on the JSON request body, then writes a response with the new + configuration, with HTTP status ``200``. + + .. note:: + If applying the new configuration fails, then write a response with HTTP status ``502``. + """ request = self._read_json_content() if request: cluster = self.server.patroni.dcs.get_cluster() - if not deep_compare(request, cluster.config.data): + if not (cluster.config and deep_compare(request, cluster.config.data)): value = json.dumps(request, separators=(',', ':')) if not self.server.patroni.dcs.set_config_value(value): return self.send_error(502) self._write_json_response(200, request) @check_access - def do_POST_reload(self): + def do_POST_reload(self) -> None: + """Handle a ``POST`` request to ``/reload`` path. + + Schedules a reload to Patroni and writes a response with HTTP status `202`. + """ self.server.patroni.sighup_handler() self._write_response(202, 'reload scheduled') - def do_GET_failsafe(self): + def do_GET_failsafe(self) -> None: + """Handle a ``GET`` request to ``/failsafe`` path. + + Writes a response with a JSON string body containing all nodes that are known to Patroni at a given point + in time, with HTTP status ``200``. The JSON contains a dictionary, each key is the name of the Patroni node, + and the corresponding value is the URI to access `/patroni` path of its REST API. + + .. note:: + If ``failsafe_mode`` is not enabled, then write a response with HTTP status ``502``. + """ failsafe = self.server.patroni.dcs.failsafe if isinstance(failsafe, dict): self._write_json_response(200, failsafe) @@ -393,7 +670,15 @@ class RestApiHandler(BaseHTTPRequestHandler): self.send_error(502) @check_access - def do_POST_failsafe(self): + def do_POST_failsafe(self) -> None: + """Handle a ``POST`` request to ``/failsafe`` path. + + Writes a response with HTTP status ``200`` if this node is a Standby, or with HTTP status ``500`` if this is + the primary. + + .. note:: + If ``failsafe_mode`` is not enabled, then write a response with HTTP status ``502``. + """ if self.server.patroni.ha.is_failsafe_mode(): request = self._read_json_content() if request: @@ -404,16 +689,34 @@ class RestApiHandler(BaseHTTPRequestHandler): self.send_error(502) @check_access - def do_POST_sigterm(self): - """Only for behave testing on windows""" + def do_POST_sigterm(self) -> None: + """Handle a ``POST`` request to ``/sigterm`` path. + Schedule a shutdown and write a response with HTTP status ``202``. + + .. note:: + Only for behave testing on Windows. + """ if os.name == 'nt' and os.getenv('BEHAVE_DEBUG'): self.server.patroni.api_sigterm() self._write_response(202, 'shutdown scheduled') @staticmethod - def parse_schedule(schedule, action): - """ parses the given schedule and validates at """ + def parse_schedule(schedule: str, + action: str) -> Tuple[Union[int, None], Union[str, None], Union[datetime.datetime, None]]: + """Parse the given *schedule* and validate it. + + :param schedule: a string representing a timestamp, e.g. ``2023-04-14T20:27:00+00:00``. + :param action: the action to be scheduled (``restart``, ``switchover``, or ``failover``). + + :returns: a tuple composed of 3 items + * Suggested HTTP status code for a response: + * ``None``: if no issue was faced while parsing, leaving it up to the caller to decide the status; or + * ``400``: if no timezone information could be found in *schedule*; or + * ``422``: if *schedule* is invalid -- in the past or not parsable. + * An error message, if any error is faced, otherwise ``None``; + * Parsed *schedule*, if able to parse, otherwise ``None``. + """ error = None scheduled_at = None try: @@ -430,11 +733,41 @@ class RestApiHandler(BaseHTTPRequestHandler): logger.exception('Invalid scheduled %s time: %s', action, schedule) error = 'Unable to parse scheduled timestamp. It should be in an unambiguous format, e.g. ISO 8601' status_code = 422 - return (status_code, error, scheduled_at) + return status_code, error, scheduled_at @check_access def do_POST_restart(self) -> None: - """Is used to restart postgres, mainly by "patronictl restart".""" + """Handle a ``POST`` request to ``/restart`` path. + + Used to restart postgres (or schedule a restart), mainly by ``patronictl restart``. + + The request body should be a JSON dictionary, and it can contain the following keys: + * ``schedule``: timestamp at which the restart should occur; + * ``role``: restart only nodes which role is ``role``. Can be either: + * ``primary`` (or ``master``); or + * ``replica``. + * ``postgres_version``: restart only nodes which PostgreSQL version is less than ``postgres_version``, e.g. + ``15.2``; + * ``timeout``: if restart takes longer than ``timeout`` return an error and fail over to a replica; + * ``restart_pending``: if we should restart only when have ``pending restart`` flag; + + Response HTTP status codes: + * ``200``: if successfully performed an immediate restart; or + * ``202``: if successfully scheduled a restart for later; or + * ``500``: if the cluster is in maintenance mode; or + * ``400``: if + * ``role`` value is invalid; or + * ``postgres_version`` value is invalid; or + * ``timeout`` is not a number, or lesser than ``0``; or + * request contains an unknown key; or + * exception is faced while performing an immediate restart. + * ``409``: if another restart was already previously scheduled; or + * ``503``: if any issue was found while performing an immediate restart; or + * HTTP status returned by :func:`parse_schedule`, if any error was observed while parsing the schedule. + + .. note:: + If it's not able to parse the request body, then the request is silently discarded. + """ status_code = 500 data = 'restart failed' request = self._read_json_content(body_is_optional=True) @@ -492,10 +825,21 @@ class RestApiHandler(BaseHTTPRequestHandler): else: data = "Another restart is already scheduled" status_code = 409 + # pyright thinks ``data`` can be ``None`` because ``parse_schedule`` call may return ``None``. However, if + # that's the case, ``data`` will be overwritten when the ``for`` loop ends + assert isinstance(data, str) self._write_response(status_code, data) @check_access - def do_DELETE_restart(self): + def do_DELETE_restart(self) -> None: + """Handle a ``DELETE`` request to ``/restart`` path. + + Used to remove a scheduled restart of PostgreSQL. + + Response HTTP status codes: + * ``200``: if a scheduled restart was removed; or + * ``404``: if no scheduled restart could be found. + """ if self.server.patroni.ha.delete_future_restart(): data = "scheduled restart deleted" code = 200 @@ -505,7 +849,16 @@ class RestApiHandler(BaseHTTPRequestHandler): self._write_response(code, data) @check_access - def do_DELETE_switchover(self): + def do_DELETE_switchover(self) -> None: + """Handle a ``DELETE`` request to ``/switchover`` path. + + Used to remove a scheduled switchover in the cluster. + + It writes a response, and the HTTP status code can be: + * ``200``: if a scheduled switchover was removed; or + * ``404``: if no scheduled switchover could be found; or + * ``409``: if not able to update the switchover info in the DCS. + """ failover = self.server.patroni.dcs.get_cluster().failover if failover and failover.scheduled_at: if not self.server.patroni.dcs.manual_failover('', '', index=failover.index): @@ -519,7 +872,16 @@ class RestApiHandler(BaseHTTPRequestHandler): self._write_response(code, data) @check_access - def do_POST_reinitialize(self): + def do_POST_reinitialize(self) -> None: + """Handle a ``POST`` request to ``/reinitialize`` path. + + The request body may contain a JSON dictionary with the following key: + * ``force``: ``True`` if we want to cancel an already running task in order to reinit a replica. + + Response HTTP status codes: + * ``200``: if the reinit operation has started; or + * ``503``: if any error is returned by :func:`Ha.reinitialize`. + """ request = self._read_json_content(body_is_optional=True) if request: @@ -535,13 +897,25 @@ class RestApiHandler(BaseHTTPRequestHandler): status_code = 503 self._write_response(status_code, data) - def poll_failover_result(self, leader, candidate, action): + def poll_failover_result(self, leader: Optional[str], candidate: Optional[str], action: str) -> Tuple[int, str]: + """Poll failover/switchover operation until it finishes or times out. + + :param leader: name of the current Patroni leader. + :param candidate: name of the Patroni node to be promoted. + :param action: the action that is ongoing (``switchover`` or ``failover``). + + :returns: a tuple composed of 2 items + * Response HTTP status codes: + * ``200``: if the operation succeeded; or + * ``503``: if the operation failed or timed out. + * A status message about the operation. + """ timeout = max(10, self.server.patroni.dcs.loop_wait) for _ in range(0, timeout * 2): time.sleep(1) try: cluster = self.server.patroni.dcs.get_cluster() - if not cluster.is_unlocked() and cluster.leader.name != leader: + if not cluster.is_unlocked() and cluster.leader and cluster.leader.name != leader: if not candidate or candidate == cluster.leader.name: return 200, 'Successfully {0}ed over to "{1}"'.format(action[:-4], cluster.leader.name) else: @@ -553,10 +927,16 @@ class RestApiHandler(BaseHTTPRequestHandler): logger.debug('Exception occurred during polling %s result: %s', action, e) return 503, action.title() + ' status unknown' - def is_failover_possible(self, cluster: Cluster, leader: str, candidate: str, action: str) -> Union[str, None]: - """Checks whether there are nodes that could take it over after demoting the primary. + def is_failover_possible(self, cluster: Cluster, leader: Optional[str], candidate: Optional[str], + action: str) -> Optional[str]: + """Checks whether there are nodes that could take over after demoting the primary. - :returns: a string with the error message or `None` if good nodes are found + :param cluster: the Patroni cluster. + :param leader: name of the current Patroni leader. + :param candidate: name of the Patroni node to be promoted. + :param action: the action to be performed (``switchover`` or ``failover``). + + :returns: a string with the error message or ``None`` if good nodes are found. """ is_synchronous_mode = self.server.patroni.config.get_global_config(cluster).is_synchronous_mode if leader and (not cluster.leader or cluster.leader.name != leader): @@ -572,7 +952,7 @@ class RestApiHandler(BaseHTTPRequestHandler): if not members: return action + ' is not possible: can not find sync_standby' else: - members = [m for m in cluster.members if m.name != cluster.leader.name and m.api_url] + members = [m for m in cluster.members if not cluster.leader or m.name != cluster.leader.name and m.api_url] if not members: return action + ' is not possible: cluster does not have members except leader' for st in self.server.patroni.ha.fetch_nodes_statuses(members): @@ -581,8 +961,29 @@ class RestApiHandler(BaseHTTPRequestHandler): return action + ' is not possible: no good candidates have been found' @check_access - def do_POST_failover(self, action: Optional[str] = 'failover') -> None: - """Handles manual failovers/switchovers, mainly from "patronictl".""" + def do_POST_failover(self, action: str = 'failover') -> None: + """Handle a ``POST`` request to ``/failover`` path. + + Handles manual failovers/switchovers, mainly from ``patronictl``. + + The request body should be a JSON dictionary, and it can contain the following keys: + * ``leader``: name of the current leader in the cluster; + * ``candidate``: name of the Patroni node to be promoted; + * ``scheduled_at``: a string representing the timestamp when to execute the switchover/failover, e.g. + ``2023-04-14T20:27:00+00:00``. + + Response HTTP status codes: + * ``202``: if operation has been scheduled; + * ``412``: if operation is not possible; + * ``503``: if unable to register the operation to the DCS; + * HTTP status returned by :func:`parse_schedule`, if any error was observed while parsing the schedule; + * HTTP status returned by :func:`poll_failover_result` if the operation has been processed immediately. + + .. note:: + If unable to parse the request body, then the request is silently discarded. + + :param action: the action to be performed (``switchover`` or ``failover``). + """ request = self._read_json_content() (status_code, data) = (400, '') if not request: @@ -630,13 +1031,29 @@ class RestApiHandler(BaseHTTPRequestHandler): else: data = 'failed to write {0} key into DCS'.format(action) status_code = 503 + # pyright thinks ``status_code`` can be ``None`` because ``parse_schedule`` call may return ``None``. However, + # if that's the case, ``status_code`` will be overwritten somewhere between ``parse_schedule`` and + # ``_write_response`` calls. + assert isinstance(status_code, int) self._write_response(status_code, data) - def do_POST_switchover(self): + def do_POST_switchover(self) -> None: + """Handle a ``POST`` request to ``/switchover`` path. + + Calls :func:`do_POST_failover` with ``switchover`` option. + """ self.do_POST_failover(action='switchover') @check_access - def do_POST_citus(self): + def do_POST_citus(self) -> None: + """Handle a ``POST`` request to ``/citus`` path. + + Call :func:`CitusHandler.handle_event` to handle the request, then write a response with HTTP status code + ``200``. + + .. note:: + If unable to parse the request body, then the request is silently discarded. + """ request = self._read_json_content() if not request: return @@ -647,16 +1064,20 @@ class RestApiHandler(BaseHTTPRequestHandler): patroni.postgresql.citus_handler.handle_event(cluster, request) self._write_response(200, 'OK') - def parse_request(self): - """Override parse_request method to enrich basic functionality of `BaseHTTPRequestHandler` class + def parse_request(self) -> bool: + """Override :func:`parse_request` method to enrich basic functionality of :class:`BaseHTTPRequestHandler`. - Original class can only invoke do_GET, do_POST, do_PUT, etc method implementations if they are defined. + Original class can only invoke :func:`do_GET`, :func:`do_POST`, :func:`do_PUT`, etc method implementations if + they are defined. But we would like to have at least some simple routing mechanism, i.e.: - GET /uri1/part2 request should invoke `do_GET_uri1()` - POST /other should invoke `do_POST_other()` + * ``GET /uri1/part2`` request should invoke :func:`do_GET_uri1()` + * ``POST /other`` should invoke :func:`do_POST_other()` - If the `do__` method does not exists we'll fallback to original behavior.""" + If the :func:`do__` method does not exist we'll fall back to original behavior. + :returns: ``True`` for success, ``False`` for failure; on failure, any relevant error response has already been + sent back. + """ ret = BaseHTTPRequestHandler.parse_request(self) if ret: urlpath = urlparse(self.path) @@ -668,18 +1089,57 @@ class RestApiHandler(BaseHTTPRequestHandler): self.command = mname return ret - def query(self, sql, *params, **kwargs): + def query(self, sql: str, *params: Any, **kwargs: Any) -> List[Tuple[Any, ...]]: + """Execute *sql* query with *params*. + + :param sql: the SQL statement to be run. + :param params: positional arguments to call :func:`RestApiServer.query` with. + :param kwargs: can contain the key ``retry``. If the key is present its value should be a :class:`bool` which + indicates whether the query should be retried upon failure or given up immediately. + + :returns: a list of rows that were fetched from the database. + """ if not kwargs.get('retry', False): return self.server.query(sql, *params) retry = Retry(delay=1, retry_exceptions=PostgresConnectionException) return retry(self.server.query, sql, *params) - def get_postgresql_status(self, retry: Optional[bool] = False) -> Dict[str, Any]: + def get_postgresql_status(self, retry: bool = False) -> Dict[str, Any]: """Builds an object representing a status of "postgres". - Some of values are collected by executing a query and other are taken from the state stored in memory. + Some of the values are collected by executing a query and other are taken from the state stored in memory. + :param retry: whether the query should be retried if failed or give up immediately - :returns: a dict with the status of Postgres/Patroni + :returns: a dict with the status of Postgres/Patroni. The keys are: + * ``state``: Postgres state among ``stopping``, ``stopped``, ``stop failed``, ``crashed``, ``running``, + ``starting``, ``start failed``, ``restarting``, ``restart failed``, ``initializing new cluster``, + ``initdb failed``, ``running custom bootstrap script``, ``custom bootstrap failed``, + ``creating replica``, or ``unknown``; + * ``postmaster_start_time``: ``pg_postmaster_start_time()``; + * ``role``: ``replica`` or ``master`` based on ``pg_is_in_recovery()`` output; + * ``server_version``: Postgres version without periods, e.g. ``150002`` for Postgres ``15.2``; + * ``xlog``: dictionary. Its structure depends on ``role``: + * If ``master``: + * ``location``: ``pg_current_wal_lsn()`` + * If ``replica``: + * ``received_location``: ``pg_wal_lsn_diff(pg_last_wal_receive_lsn(), '0/0')``; + * ``replayed_location``: ``pg_wal_lsn_diff(pg_last_wal_replay_lsn(), '0/0)``; + * ``replayed_timestamp``: ``pg_last_xact_replay_timestamp``; + * ``paused``: ``pg_is_wal_replay_paused()``; + * ``sync_standby``: ``True`` if replication mode is synchronous and this is a sync standby; + * ``timeline``: PostgreSQL primary node timeline; + * ``replication``: :class:`list` of :class:`dict` entries, one for each replication connection. Each entry + contains the following keys: + * ``application_name``: ``pg_stat_activity.application_name``; + * ``client_addr``: ``pg_stat_activity.client_addr``; + * ``state``: ``pg_stat_replication.state``; + * ``sync_priority``: ``pg_stat_replication.sync_priority``; + * ``sync_state``: ``pg_stat_replication.sync_state``; + * ``usename``: ``pg_stat_activity.usename``. + * ``pause``: ``True`` if cluster is in maintenance mode; + * ``cluster_unlocked``: ``True`` if cluster has no node holding the leader lock; + * ``failsafe_mode_is_active``: ``True`` if DCS failsafe mode is currently active; + * ``dcs_last_seen``: epoch timestamp DCS was last reached by Patroni. """ postgresql = self.server.patroni.postgresql cluster = self.server.patroni.dcs.cluster @@ -715,13 +1175,14 @@ class RestApiHandler(BaseHTTPRequestHandler): result['role'] = postgresql.role if result['role'] == 'replica' and global_config.is_synchronous_mode\ - and cluster.sync.matches(postgresql.name): + and cluster and cluster.sync.matches(postgresql.name): result['sync_standby'] = True if row[1] > 0: result['timeline'] = row[1] else: - leader_timeline = None if not cluster or cluster.is_unlocked() else cluster.leader.timeline + leader_timeline = None\ + if not cluster or cluster.is_unlocked() or not cluster.leader else cluster.leader.timeline result['timeline'] = postgresql.replica_cached_timeline(leader_timeline) if row[7]: @@ -732,7 +1193,7 @@ class RestApiHandler(BaseHTTPRequestHandler): if state == 'running': logger.exception('get_postgresql_status') state = 'unknown' - result = {'state': state, 'role': postgresql.role} + result: Dict[str, Any] = {'state': state, 'role': postgresql.role} if global_config.is_paused: result['pause'] = True @@ -743,34 +1204,72 @@ class RestApiHandler(BaseHTTPRequestHandler): result['dcs_last_seen'] = self.server.patroni.dcs.last_seen return result - def handle_one_request(self): + def handle_one_request(self) -> None: + """Parse and dispatch a request to the appropriate ``do_*`` method. + + .. note:: + This is only used to keep track of latency when logging messages through :func:`log_message`. + """ self.__start_time = time.time() BaseHTTPRequestHandler.handle_one_request(self) - def log_message(self, fmt, *args): + def log_message(self, format: str, *args: Any) -> None: + """Log a custom ``debug`` message. + + Additionally, to *format*, the log entry contains the client IP address and the current latency of the request. + + :param format: printf-style format string message to be logged. + :param args: arguments to be applied as inputs to *format*. + """ latency = 1000.0 * (time.time() - self.__start_time) - logger.debug("API thread: %s - - %s latency: %0.3f ms", self.client_address[0], fmt % args, latency) + logger.debug("API thread: %s - - %s latency: %0.3f ms", self.client_address[0], format % args, latency) class RestApiServer(ThreadingMixIn, HTTPServer, Thread): + """Patroni REST API server. + + An asynchronous thread-based HTTP server. + """ + # On 3.7+ the `ThreadingMixIn` gathers all non-daemon worker threads in order to join on them at server close. daemon_threads = True # Make worker threads "fire and forget" to prevent a memory leak. - def __init__(self, patroni, config): + def __init__(self, patroni: Patroni, config: Dict[str, Any]) -> None: + """Establish patroni configuration for the REST API daemon. + + Create a :class:`RestApiServer` instance. + + :param patroni: Patroni daemon process. + :param config: ``restapi`` section of Patroni configuration. + """ + self.connection_string = None + self.__auth_key = None + self.__allowlist_include_members: Optional[bool] = None + self.__allowlist: Tuple[Union[IPv4Network, IPv6Network], ...] = () + self.http_extra_headers: Dict[str, str] = {} self.patroni = patroni self.__listen = None self.request_queue_size = int(config.get('request_queue_size', 5)) - self.__ssl_options = None + self.__ssl_options: Dict[str, Any] = {} self.__ssl_serial_number = None self._received_new_cert = False self.reload_config(config) self.daemon = True - def query(self, sql, *params): + def query(self, sql: str, *params: Any) -> List[Tuple[Any, ...]]: + """Execute *sql* query with *params*. + + :param sql: the SQL statement to be run. + :param params: positional arguments to be used as parameters for *sql*. + + :returns: a list of rows that were fetched from the database. + :raises psycopg.Error: if had issues while executing *sql*. + :raises PostgresConnectionException: if had issues while connecting to the database. + """ cursor = None try: with self.patroni.postgresql.connection().cursor() as cursor: - cursor.execute(sql, params) + cursor.execute(sql.encode('utf-8'), params) return [r for r in cursor] except psycopg.Error as e: if cursor and cursor.connection.closed == 0: @@ -778,16 +1277,39 @@ class RestApiServer(ThreadingMixIn, HTTPServer, Thread): raise PostgresConnectionException('connection problems') @staticmethod - def _set_fd_cloexec(fd): + def _set_fd_cloexec(fd: socket.socket) -> None: + """Set ``FD_CLOEXEC`` for *fd*. + + It is used to avoid inheriting the REST API port when forking its process. + + .. note:: + Only takes effect on non-Windows environments. + + :param fd: socket file descriptor. + """ if os.name != 'nt': import fcntl flags = fcntl.fcntl(fd, fcntl.F_GETFD) fcntl.fcntl(fd, fcntl.F_SETFD, flags | fcntl.FD_CLOEXEC) - def check_basic_auth_key(self, key): + def check_basic_auth_key(self, key: str) -> bool: + """Check if *key* matches the password configured for the REST API. + + :param key: the password received through the Basic authorization header of an HTTP request. + + :returns: ``True`` if *key* matches the password configured for the REST API. + """ + # pyright -- ``__auth_key`` was already checked through the caller method (:func:`check_auth_header`). + assert self.__auth_key is not None return hmac.compare_digest(self.__auth_key, key.encode('utf-8')) - def check_auth_header(self, auth_header): + def check_auth_header(self, auth_header: Optional[str]) -> Optional[str]: + """Validate HTTP Basic authorization header, if present. + + :param auth_header: value of ``Authorization`` HTTP header, if present, else ``None``. + + :returns: an error message if any issue is found, ``None`` otherwise. + """ if self.__auth_key: if auth_header is None: return 'no auth header received' @@ -795,14 +1317,29 @@ class RestApiServer(ThreadingMixIn, HTTPServer, Thread): return 'not authenticated' @staticmethod - def __resolve_ips(host, port): + def __resolve_ips(host: str, port: int) -> Iterator[Union[IPv4Network, IPv6Network]]: + """Resolve *host* + *port* to one or more IP networks. + + :param host: hostname to be checked. + :param port: port to be checked. + + :rtype: Iterator[Union[IPv4Network, IPv6Network]] of *host* + *port* resolved to IP networks. + """ try: for _, _, _, _, sa in socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM, socket.IPPROTO_TCP): yield ip_network(sa[0], False) except Exception as e: logger.error('Failed to resolve %s: %r', host, e) - def __members_ips(self): + def __members_ips(self) -> Iterator[Union[IPv4Network, IPv6Network]]: + """Resolve each Patroni node ``restapi.connect_address`` to IP networks. + + .. note:: + Only yields object if ``restapi.allowlist_include_members`` setting is enabled. + + :rtype: Iterator[Union[IPv4Network, IPv6Network]] of each node ``restapi.connect_address`` resolved to an IP + network. + """ cluster = self.patroni.dcs.cluster if self.__allowlist_include_members and cluster: for cluster in [cluster] + list(cluster.workers.values()): @@ -810,14 +1347,27 @@ class RestApiServer(ThreadingMixIn, HTTPServer, Thread): if member.api_url: try: r = urlparse(member.api_url) - host = r.hostname - port = r.port or (443 if r.scheme == 'https' else 80) - for ip in self.__resolve_ips(host, port): - yield ip + if r.hostname: + port = r.port or (443 if r.scheme == 'https' else 80) + for ip in self.__resolve_ips(r.hostname, port): + yield ip except Exception as e: logger.debug('Failed to parse url %s: %r', member.api_url, e) - def check_access(self, rh): + def check_access(self, rh: RestApiHandler) -> Optional[bool]: + """Ensure client has enough privileges to perform a given request. + + Write a response back to the client if any issue is observed, and the HTTP status may be: + * ``401``: if ``Authorization`` header is missing or contain an invalid password; + * ``403``: if: + * ``restapi.allowlist`` was configured, but client IP is not in the allowed list; or + * ``restapi.allowlist_include_members`` is enabled, but client IP is not in the members list; or + * a client certificate is expected by the server, but is missing in the request. + + :param rh: the request which access should be checked. + + :returns: ``True`` if client access verification succeeded, otherwise ``None``. + """ if self.__allowlist or self.__allowlist_include_members: incoming_ip = ip_address(rh.client_address[0]) if not any(incoming_ip in net for net in self.__allowlist + tuple(self.__members_ips())): @@ -834,7 +1384,11 @@ class RestApiServer(ThreadingMixIn, HTTPServer, Thread): return True @staticmethod - def __has_dual_stack(): + def __has_dual_stack() -> bool: + """Check if the system has support for dual stack sockets. + + :returns: ``True`` if it has support for dual stack sockets. + """ if hasattr(socket, 'AF_INET6') and hasattr(socket, 'IPPROTO_IPV6') and hasattr(socket, 'IPV6_V6ONLY'): sock = None try: @@ -848,12 +1402,21 @@ class RestApiServer(ThreadingMixIn, HTTPServer, Thread): sock.close() return False - def __httpserver_init(self, host, port): - dual_stack = self.__has_dual_stack() - if host in ('', '*'): - host = None + def __httpserver_init(self, host: str, port: int) -> None: + """Start REST API HTTP server. - info = socket.getaddrinfo(host, port, socket.AF_UNSPEC, socket.SOCK_STREAM, 0, socket.AI_PASSIVE) + .. note:: + If system has no support for dual stack sockets, then IPv4 is preferred over IPv6. + + :param host: host to bind REST API to. + :param port: port to bind REST API to. + """ + dual_stack = self.__has_dual_stack() + hostname = host + if hostname in ('', '*'): + hostname = None + + info = socket.getaddrinfo(hostname, port, socket.AF_UNSPEC, socket.SOCK_STREAM, 0, socket.AI_PASSIVE) # in case dual stack is not supported we want IPv4 to be preferred over IPv6 info.sort(key=lambda x: x[0] == socket.AF_INET, reverse=not dual_stack) @@ -862,10 +1425,32 @@ class RestApiServer(ThreadingMixIn, HTTPServer, Thread): HTTPServer.__init__(self, info[0][-1][:2], RestApiHandler) except socket.error: logger.error( - "Couldn't start a service on '%s:%s', please check your `restapi.listen` configuration", host, port) + "Couldn't start a service on '%s:%s', please check your `restapi.listen` configuration", hostname, port) raise - def __initialize(self, listen, ssl_options): + def __initialize(self, listen: str, ssl_options: Dict[str, Any]) -> None: + """Configure and start REST API HTTP server. + + .. note:: + This method can be called upon first initialization, and also when reloading Patroni. When reloading + Patroni, it restarts the HTTP server thread. + + :param listen: IP and port to bind REST API to. It should be a string in the format ``host:port``, where + ``host`` can be a hostname or IP address. It is the value of ``restapi.listen`` setting. + :param ssl_options: dictionary that may contain the following keys, depending on what has been configured in + ``restapi` section: + * ``certfile``: path to PEM certificate. If given, will start in HTTPS mode; + * ``keyfile``: path to key of ``certfile``; + * ``keyfile_password``: password for decrypting ``keyfile``; + * ``cafile``: path to CA file to validate client certificates; + * ``ciphers``: permitted cipher suites; + * ``verify_client``: value can be one among: + * ``none``: do not check client certificates; + * ``optional``: check client certificate only for unsafe REST API endpoints; + * ``required``: check client certificate for all REST API endpoints. + + :raises ValueError: if any issue is faced while parsing *listen*. + """ try: host, port = split_host_port(listen, None) except Exception: @@ -908,30 +1493,63 @@ class RestApiServer(ThreadingMixIn, HTTPServer, Thread): if reloading_config: self.start() - def process_request_thread(self, request, client_address): - enable_keepalive(request, 10, 3) + def process_request_thread(self, request: Union[socket.socket, Tuple[bytes, socket.socket]], + client_address: Tuple[str, int]) -> None: + """Process a request to the REST API. + + Wrapper for :func:`ThreadingMixIn.process_request_thread` that additionally: + * Enable TCP keepalive + * Perform SSL handshake (if an SSL socket). + + :param request: socket to handle the client request. + :param client_address: tuple containing the client IP and port. + """ + if isinstance(request, socket.socket): + enable_keepalive(request, 10, 3) if hasattr(request, 'context'): # SSLSocket - request.do_handshake() + from ssl import SSLSocket + if isinstance(request, SSLSocket): # pyright + request.do_handshake() super(RestApiServer, self).process_request_thread(request, client_address) - def shutdown_request(self, request): + def shutdown_request(self, request: Union[socket.socket, Tuple[bytes, socket.socket]]) -> None: + """Shut down a request to the REST API. + + Wrapper for :func:`HTTPServer.shutdown_request` that additionally: + * Perform SSL shutdown handshake (if a SSL socket). + + :param request: socket to handle the client request. + """ if hasattr(request, 'context'): # SSLSocket try: - request.unwrap() + from ssl import SSLSocket + if isinstance(request, SSLSocket): # pyright + request.unwrap() except Exception as e: logger.debug('Failed to shutdown SSL connection: %r', e) super(RestApiServer, self).shutdown_request(request) - def get_certificate_serial_number(self): + def get_certificate_serial_number(self) -> Optional[str]: + """Get serial number of the certificate used by the REST API. + + :returns: serial number of the certificate configured through ``restapi.certfile`` setting. + """ if self.__ssl_options.get('certfile'): import ssl try: - crt = ssl._ssl._test_decode_cert(self.__ssl_options['certfile']) - return crt.get('serialNumber') - except ssl.SSLError as e: + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + crts = ctx.load_verify_locations(self.__ssl_options['certfile']) + if crts: + return crts[0].get('serialNumber') + except Exception as e: logger.error('Failed to get serial number from certificate %s: %r', self.__ssl_options['certfile'], e) - def reload_local_certificate(self): + def reload_local_certificate(self) -> Optional[bool]: + """Reload the SSL certificate used by the REST API. + + :return: ``True`` if a different certificate has been configured through ``restapi.certfile` setting, ``None`` + otherwise. + """ if self.__protocol == 'https': on_disk_cert_serial_number = self.get_certificate_serial_number() if on_disk_cert_serial_number != self.__ssl_serial_number: @@ -939,7 +1557,14 @@ class RestApiServer(ThreadingMixIn, HTTPServer, Thread): self.__ssl_serial_number = on_disk_cert_serial_number return True - def _build_allowlist(self, value): + def _build_allowlist(self, value: Optional[List[str]]) -> Iterator[Union[IPv4Network, IPv6Network]]: + """Resolve each entry in *value* to an IP network object. + + :param value: list of IPs and/or networks contained in ``restapi.allowlist`` setting. Each item can be a host, + an IP, or a network in CIDR format. + + :rtype: Iterator[Union[IPv4Network, IPv6Network]] of *host* + *port* resolved to IP networks. + """ if isinstance(value, list): for v in value: if '/' in v: # netmask @@ -951,7 +1576,12 @@ class RestApiServer(ThreadingMixIn, HTTPServer, Thread): for ip in self.__resolve_ips(v, 8080): yield ip - def reload_config(self, config): + def reload_config(self, config: Dict[str, Any]) -> None: + """Reload REST API configuration. + + :param config: dictionary representing values under the ``restapi`` configuration section. + :raises ValueError: if ``listen`` key is not present in *config*. + """ if 'listen' not in config: # changing config in runtime raise ValueError('Can not find "restapi.listen" config') @@ -971,10 +1601,20 @@ class RestApiServer(ThreadingMixIn, HTTPServer, Thread): self.__initialize(config['listen'], ssl_options) self.__auth_key = base64.b64encode(config['auth'].encode('utf-8')) if 'auth' in config else None + # pyright -- ``__listen`` is initially created as ``None``, but right after that it is replaced with a string + # through :func:`__initialize`. + assert isinstance(self.__listen, str) self.connection_string = uri(self.__protocol, config.get('connect_address') or self.__listen, 'patroni') - @staticmethod - def handle_error(request, client_address): + def handle_error(self, request: Union[socket.socket, Tuple[bytes, socket.socket]], + client_address: Tuple[str, int]) -> None: + """Handle any exception that is thrown while handling a request to the REST API. + + Logs ``WARNING`` messages with the client information, and the stack trace of the faced exception. + + :param request: the request that faced an exception. + :param client_address: a tuple composed of the IP and port of the client connection. + """ logger.warning('Exception happened during processing of request from %s:%s', client_address[0], client_address[1]) logger.warning(traceback.format_exc()) diff --git a/tests/test_api.py b/tests/test_api.py index 85c01b49..166d3eb1 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -180,6 +180,7 @@ class MockRestApiServer(RestApiServer): @patch('ssl.SSLContext.load_cert_chain', Mock()) @patch('ssl.SSLContext.wrap_socket', Mock(return_value=0)) +@patch('ssl.SSLContext.load_verify_locations', Mock(return_value=[Mock()])) @patch.object(HTTPServer, '__init__', Mock()) class TestRestApiHandler(unittest.TestCase): @@ -588,6 +589,7 @@ class TestRestApiServer(unittest.TestCase): @patch('ssl.SSLContext.load_cert_chain', Mock()) @patch('ssl.SSLContext.set_ciphers', Mock()) @patch('ssl.SSLContext.wrap_socket', Mock(return_value=0)) + @patch('ssl.SSLContext.load_verify_locations', Mock(return_value=[Mock()])) @patch.object(HTTPServer, '__init__', Mock()) def setUp(self): self.srv = MockRestApiServer(Mock(), '', {'listen': '*:8008', 'certfile': 'a', 'verify_client': 'required', @@ -621,24 +623,39 @@ class TestRestApiServer(unittest.TestCase): try: raise Exception() except Exception: - self.assertIsNone(MockRestApiServer.handle_error(None, ('127.0.0.1', 55555))) + self.assertIsNone(self.srv.handle_error(None, ('127.0.0.1', 55555))) @patch.object(HTTPServer, '__init__', Mock(side_effect=socket.error)) def test_socket_error(self): self.assertRaises(socket.error, MockRestApiServer, Mock(), '', {'listen': '*:8008'}) + def __create_socket(self): + sock = socket.socket() + try: + import ssl + ctx = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) + ctx.check_hostname = False + sock = ctx.wrap_socket(sock=sock) + sock.do_handshake = Mock() + sock.unwrap = Mock(side_effect=Exception) + except Exception: + pass + return sock + @patch.object(ThreadingMixIn, 'process_request_thread', Mock()) def test_process_request_thread(self): - self.srv.process_request_thread(Mock(), '2') + self.srv.process_request_thread(self.__create_socket(), ('2', 54321)) @patch.object(MockRestApiServer, 'process_request', Mock(side_effect=RuntimeError)) @patch.object(MockRestApiServer, 'get_request') def test_process_request_error(self, mock_get_request): - mock_request = Mock() - mock_request.unwrap.side_effect = Exception - mock_get_request.return_value = (mock_request, ('127.0.0.1', 55555)) + mock_get_request.return_value = (self.__create_socket(), ('127.0.0.1', 55555)) self.srv._handle_request_noblock() - @patch('ssl._ssl._test_decode_cert', Mock()) + @patch('ssl.SSLContext.load_verify_locations', Mock(return_value=[Mock()])) def test_reload_local_certificate(self): self.assertTrue(self.srv.reload_local_certificate()) + + @patch('ssl.SSLContext.load_verify_locations', Mock(side_effect=Exception)) + def test_get_certificate_serial_number(self): + self.assertIsNone(self.srv.get_certificate_serial_number())