Skip to content
Snippets Groups Projects
test_postgresql.py 11.5 KiB
Newer Older
#!/usr/bin/env python3
"""
Criticality: High
Checks the current state of the PostgreSQL database cluster.
"""

from pathlib import Path
import re
import shlex
import socket
import subprocess
import sys
import time
import urllib
import urllib.parse
import uuid
# psycopg2 is used by the replication check; when it cannot be imported the
# script exits immediately with status 2.
try:
    import psycopg2
except ImportError:
    sys.exit(2)
# make the parent of this script's directory importable so that the
# project-local modules imported below can be found
sys.path.append(str(Path(__file__).parents[1].resolve()))

import utils as u  # noqa: E402
Nicolas KAROLAK's avatar
Nicolas KAROLAK committed
from utils_lib.apt import Apt  # noqa: E402
def is_ha(port: int) -> bool:
    """Tell whether this setup is using highly-available databases.

    The setup is considered highly-available when the configured database
    port is 54321.

    :param port: Port number
    :type port: int
    :return: Whether it is a highly-available setup or not
    :rtype: bool
    """

    ha_port = 54321
    return port == ha_port


def get_haproxy_conf(path: str = "/etc/haproxy/haproxy.cfg") -> dict:
    """Get HAProxy configuration in a dictionary.

    Every line starting at column 0 with a word opens a section and becomes
    a key of the dictionary; every indented line that follows is appended to
    the list of parameters of that section. Lines matching neither pattern
    are ignored, as are indented lines found before the first section.

    :param path: HAProxy configuration file, defaults to "/etc/haproxy/haproxy.cfg"
    :type path: str
    :return: HAProxy configuration file content, or an empty dictionary if
        the file cannot be read
    :rtype: dict
    """

    # init configuration dictionary
    conf = {}

    # load configuration file
    try:
        with open(path) as data:
            lines = data.readlines()
    except EnvironmentError:
        return conf

    # define patterns
    pattern_block = re.compile(r"^([a-zA-Z0-9_.-]+ *[a-zA-Z0-9_.-]+)")
    pattern_param = re.compile(r"^\s+([ /:()|a-zA-Z0-9_.-]+)")

    # parse configuration file
    block = None  # name of the section currently being filled
    for line in lines:
        match_block = pattern_block.match(line)
        if match_block:
            block = match_block.group(1)
            conf[block] = []
            continue

        match_param = pattern_param.match(line)
        # a parameter can only be attached to an already opened section
        if match_param and block is not None:
            conf[block].append(match_param.group(1))

    return conf


def get_nodes(conf: dict) -> dict:
    """Get the list of nodes from HAProxy configuration.

    Only the sections whose name contains ``pgsql-primary`` are inspected.
    Each of their ``server`` lines is split on whitespace: the second field
    is the node name, the third one its ``host:port`` address and the eighth
    one is stored as the RepHACheck port (kept as a string).

    :param conf: The HAProxy configuration file content
    :type conf: dict
    :return: The list of nodes found in HAProxy configuration, indexed by
        node name
    :rtype: dict
    """

    # init nodes dictionary
    servers = {}

    for item, params in conf.items():
        if "pgsql-primary" not in item:
            continue

        # filter `server` lines
        server_lines = [x for x in params if x.startswith("server ")]
        for line in server_lines:
            # split line
            elements = line.split()

            # get needed elements
            name = elements[1]
            address = elements[2].split(":")
            host = address[0]
            port = int(address[1])
            rephacheck = elements[7]

            # update dictionary
            servers[name] = {"host": host, "port": port, "rephacheck": rephacheck}

    return servers


def get_node_state(host: str, port: int) -> str:
    """Get the current state of a node from its RepHACheck daemon.

    A TCP connection is opened to ``host``:``port`` and up to 1024 bytes are
    read from it; the daemon is expected to send the state right after the
    connection is established.

    :param host: The node's hostname or IP address
    :param port: The node's port on which RepHACheck is listening
    :type host: str
    :type port: int
    :return: The current state of the node according to its RepHACheck daemon
    :rtype: str
    :raises OSError: If the connection cannot be established or read
    """

    # connect and get tcp stream data; the context manager closes the socket
    # even when connect() or recv() raises
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as client:
        client.connect((host, port))
        data = client.recv(1024)

    # extract string from data output
    return data.decode("utf-8").rstrip()


def check_fenced(nodes: dict) -> tuple:
    """Look for a fenced node in the cluster.

    Every node is asked for its state through its RepHACheck daemon; the
    search stops at the first node reporting ``fenced``.

    :param nodes: The dictionary containing nodes and their information
    :type nodes: dict
    :return: ``(True, node_name)`` for the first fenced node found,
        ``(False, None)`` when no node is fenced
    :rtype: tuple
    """

    for name, details in nodes.items():
        state = get_node_state(details["host"], int(details["rephacheck"]))
        if state == "fenced":
            return True, name

    return False, None


def check_listen(host: str, port: int) -> bool:
    """Check if server is listening (TCP only).

    :param host: The hostname or IP address to connect to
    :param port: The port number to connect to
    :type host: str
    :type port: int
    :return: Whether the `host` is listening on TCP/`port`
    :rtype: bool
    """

    # try to connect; connect_ex() returns 0 on success and an error code
    # otherwise, but can still raise for problems such as an unresolvable
    # host, which is reported as "not listening" too
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as client:
            result = client.connect_ex((host, port))
    except OSError:
        return False

    return result == 0


def check_psql(db_conn: dict, query: str) -> tuple:
    """Check if a query can be executed on this node with ``psql``.

    The query is run by the ``psql`` client started as the ``postgres``
    system user through ``su -l postgres -c``. When `db_conn` holds a
    password a full connection URI is used, otherwise only the database name
    is given.

    :param db_conn: Database connection parameters
    :type db_conn: dict
    :param query: Query to execute
    :type query: str
    :return: Whether the query can be executed or not, followed by the error
        message (``None`` on success)
    :rtype: tuple
    """

    # build database connection uri
    if "password" in db_conn:
        uri = "postgresql://{}:{}@{}:{}/{}".format(
            db_conn["user"],
            urllib.parse.quote_plus(db_conn["password"]),
            db_conn["host"],
            db_conn["port"],
            db_conn["dbname"],
        )
    else:
        uri = "postgresql:///{}".format(db_conn["dbname"])

    # build command: `su -c` hands a single string to the login shell of
    # `postgres`, so the uri and the query are quoted for that shell; the
    # outer command is passed as a list and needs no shell of its own
    psql_command = "psql {} -c {}".format(shlex.quote(uri), shlex.quote(query))
    command = ["su", "-l", "postgres", "-c", psql_command]

    # execute
    try:
        subprocess.check_output(command)
    except (subprocess.CalledProcessError, OSError) as psql_error:
        # OSError covers a missing or non-executable `su` binary
        return False, str(psql_error).rstrip()

    return True, None


def check_replication(primary: dict, standby: dict) -> tuple:
    """Check if replication is working between the primary and standby servers.

    A uniquely named table is created through the primary connection, then
    the standby connection is polled until the table can be read or the time
    budget (6 seconds) is exhausted. The table is dropped afterwards and
    both connections are closed.

    :param primary: Database connection parameters for primary server
    :type primary: dict
    :param standby: Database connection parameters for standby server
    :type standby: dict
    :return: Whether replication between primary/standby is working or not,
        followed by an informational or error message
    :rtype: tuple
    """

    # connections
    primary_client = None
    try:
        primary_client = psycopg2.connect(**primary)
        standby_client = psycopg2.connect(**standby)
    except psycopg2.Error as repl_conn_error:
        # do not leak the primary connection when only the standby one failed
        if primary_client is not None:
            primary_client.close()
        return False, str(repl_conn_error).rstrip()

    # random id
    rand = uuid.uuid4().hex
    write_query = "CREATE TABLE es_test_{} (id serial PRIMARY KEY);".format(rand)
    read_query = "SELECT * FROM es_test_{};".format(rand)
    drop_query = "DROP TABLE es_test_{};".format(rand)

    max_time = 6.0  # total time allowed for the table to reach the standby
    poll_interval = 0.2  # pause between two read attempts

    try:
        # autocommit: the CREATE has to be committed to be replicated, and a
        # failed SELECT must not leave the standby session in an aborted
        # transaction that would make every following attempt fail
        primary_client.autocommit = True
        standby_client.autocommit = True

        # write
        try:
            primary_psql = primary_client.cursor()
            primary_psql.execute(write_query)
        except psycopg2.Error as repl_write_error:
            return False, str(repl_write_error).rstrip()

        # read
        status = False
        msg = ""
        standby_psql = standby_client.cursor()
        start = time.monotonic()
        while True:
            try:
                standby_psql.execute(read_query)
            except psycopg2.Error as repl_read_error:
                msg = str(repl_read_error).rstrip()
            else:
                status = True
                msg = "took ~{}s".format(round(time.monotonic() - start, 1))
                break
            if time.monotonic() - start >= max_time:
                break
            time.sleep(poll_interval)

        # delete
        try:
            primary_psql.execute(drop_query)
        except psycopg2.Error:
            # best-effort cleanup: a failed DROP must not change the verdict
            pass

        return status, msg
    finally:
        # close (closing a connection also invalidates its cursors)
        primary_client.close()
        standby_client.close()

def check_ha(db_conn: dict, errors: int = 0) -> int:
    """Run all tests for a highly-available setup.

    The following checks are run, each failure incrementing the counter:
    local HAProxy frontends (ports 54321 and 54322), reachability of every
    node found in HAProxy configuration, presence of a fenced node and
    replication between the primary and standby frontends.

    :param db_conn: Database connection parameters
    :type db_conn: dict
    :param errors: Error counter, defaults to 0
    :type errors: int, optional
    :return: Number of errors
    :rtype: int
    """

    # get haproxy conf
    ha_conf = get_haproxy_conf()

    # get nodes data
    nodes = get_nodes(ha_conf)

    # check haproxy
    print("Checking local HAProxy frontends:")
    if not check_listen(db_conn["host"], 54321):
        u.error("HAProxy pgsql-primary frontend is not listening")
        errors += 1
    else:
        u.success("HAProxy pgsql-primary frontend is listening")
    if not check_listen(db_conn["host"], 54322):
        u.error("HAProxy pgsql-standby frontend is not listening")
        errors += 1
    else:
        u.success("HAProxy pgsql-standby frontend is listening")

    # check remote nodes
    print("Checking remote PostgreSQL nodes:")
    for node in nodes:
        node_host = nodes[node]["host"]
        node_port = nodes[node]["port"]
        if not check_listen(node_host, node_port):
            u.error("cannot bind {}:{}".format(node_host, node_port))
            errors += 1
        else:
            u.success("can bind {}:{}".format(node_host, node_port))

    # check fenced
    print("Checking cluster state:")
    try:
        fenced, node = check_fenced(nodes)
    except OSError as state_error:
        # an unreachable RepHACheck daemon is an error, not a crash
        u.error("cannot get cluster state ({})".format(state_error))
        errors += 1
    else:
        if fenced:
            u.error("Node `{}` is fenced".format(node))
            errors += 1
        else:
            u.success("No fenced node found")

    # check replication
    print("Checking replication state:")
    # same credentials as `db_conn`, only the frontend port differs
    primary = dict(db_conn)
    standby = dict(db_conn)
    primary["port"] = 54321
    standby["port"] = 54322
    status, info = check_replication(primary, standby)
    if not status:
        u.error("cannot replicate between primary/standby ({})".format(info))
        errors += 1
    else:
        u.success("can replicate between primary/standby ({})".format(info))

    return errors


def check_local(db_conn: dict, errors: int = 0) -> int:
    """Run all tests for a local (non highly-available) setup.

    The following checks are run, each failure incrementing the counter:
    the server is listening, a read query can be executed and a temporary
    table can be created (it is dropped right after a successful creation).

    :param db_conn: Database connection parameters
    :type db_conn: dict
    :param errors: Error counter, defaults to 0
    :type errors: int, optional
    :return: Number of errors
    :rtype: int
    """

    host = db_conn["host"]
    port = db_conn["port"]
    user = db_conn["user"]

    # check listen
    print("Checking local PostgreSQL node:")
    if not check_listen(host, port):
        u.error("cannot connect to {}:{}".format(host, port))
        errors += 1
    else:
        u.success("can connect to {}:{}".format(host, port))

    # check read
    print("Checking read operation:")
    read_query = "SELECT 1;"
    status, info = check_psql(db_conn, read_query)
    if not status:
        u.error("cannot read from {}@{}:{} ({})".format(user, host, port, info))
        errors += 1
    else:
        u.success("can read from {}@{}:{}".format(user, host, port))

    # check write
    print("Checking write operation:")
    rand = uuid.uuid4().hex
    write_query = "CREATE TABLE es_test_{} (id serial PRIMARY KEY);".format(rand)
    status, info = check_psql(db_conn, write_query)
    if not status:
        u.error("cannot write on {}@{}:{} ({})".format(user, host, port, info))
        errors += 1
    else:
        u.success("can write on {}@{}:{}".format(user, host, port))
        # remove test table
        check_psql(db_conn, "DROP TABLE es_test_{};".format(rand))

    return errors


def main():
    """Run all checks and exits with corresponding exit code.

    Exit codes: 0 when no check failed, 1 when at least one check failed and
    2 when the ``postgresql`` package is not installed.
    """

    # nothing to check when PostgreSQL is not installed on this host
    apt = Apt()
    if "postgresql" not in apt.installed_packages:
        sys.exit(2)

    # load configuration
    # NOTE(review): assumes the project `utils` module exposes `load_conf()`
    # returning a dict-like object -- confirm against utils.py
    conf = u.load_conf()

    # get database configuration
    db_host = conf.get("DB_HOST") or "127.0.0.1"
    db_port = int(conf.get("DB_PORT") or 5432)
    db_user = conf.get("DB_USER") or "postgres"
    db_pass = conf.get("DB_PG_ROOT_PWD")
    db_conn = {"dbname": db_user, "host": db_host, "port": db_port, "user": db_user}
    if db_pass:
        db_conn.update({"password": db_pass})

    # determine if HA setup and run according tests
    print("Checking availibility mode:")
    if is_ha(db_port):
        u.info("this setup is using a HA database")
        errors = check_ha(db_conn)
    else:
        u.info("this setup is using a local database")
        errors = check_local(db_conn)

    sys.exit(1 if errors else 0)


# run the checks only when executed as a script, not when imported
if __name__ == "__main__":
    main()