Skip to content
Snippets Groups Projects
test_postgresql.py 14.6 KiB
Newer Older
#!/usr/bin/env python3
"""This test checks the current state of the PostgreSQL database cluster."""

import imp
import os
import re
import socket
import sys
import time

import psycopg2

# ANSI escape sequences used to colorize the status marker printed in front
# of each message (see `success`, `warning` and `error`)
GREEN = "\033[92m"  # bright green foreground
YELLOW = "\033[93m"  # bright yellow foreground
RED = "\033[91m"  # bright red foreground
DEF = "\033[0m"  # reset to the terminal's default rendering


def success(message: str):
    """Display a success line: a green check mark followed by the message.

    :param message: Text to display after the marker
    :type message: str
    """

    line = " {}✔{} {}".format(GREEN, DEF, message)
    print(line)


def warning(message: str):
    """Print formatted warning message.

    The marker is a yellow warning sign so that a warning can be told apart
    from a success line even when colors are not rendered.

    :param message: Message to print
    :type message: str
    """

    print(" {}⚠{} {}".format(YELLOW, DEF, message))


def error(message: str):
    """Print formatted error message.

    The marker is a red cross: a failure must not be displayed with the check
    mark used for successes, otherwise both look alike without colors.

    :param message: Message to print
    :type message: str
    """

    print(" {}✖{} {}".format(RED, DEF, message))


def is_ha(port: int) -> bool:
    """Tell whether this setup relies on highly-available databases.

    The setup is considered highly available when the configured database
    port is 54321.

    :param port: Configured database port number
    :type port: int
    :return: Whether it is a highly-available setup or not
    :rtype: bool
    """

    ha_port = 54321

    return port == ha_port


def get_haproxy_conf(path: str = "/etc/haproxy/haproxy.cfg") -> dict:
    """Get HAProxy configuration in a dictionary.

    Each section header (a line starting at column 0, e.g.
    ``backend pgsql-primary``) becomes a key; its value is the list of the
    indented parameter lines that follow it.

    :param path: HAProxy configuration file, defaults to "/etc/haproxy/haproxy.cfg"
    :type path: str
    :return: HAProxy configuration file content, empty if the file cannot be read
    :rtype: dict
    """

    # init configuration dictionary
    conf = {}

    # load configuration file
    try:
        with open(path) as data:
            lines = data.readlines()
    except OSError:
        # unreadable or missing file: report an empty configuration
        return conf

    # define patterns
    pattern_block = re.compile(r"^([a-zA-Z0-9_.-]+ *[a-zA-Z0-9_.-]+)")
    pattern_param = re.compile(r"^\s+([ /:()|a-zA-Z0-9_.-]+)")

    # parse configuration file
    block = None
    for line in lines:
        match_block = pattern_block.match(line)
        if match_block:
            block = match_block.group(1)
            conf[block] = []
            continue

        match_param = pattern_param.match(line)
        # ignore indented lines found before the first section header
        if match_param and block is not None:
            conf[block].append(match_param.group(1))

    return conf


def get_nodes(conf: dict) -> dict:
    """Get the list of nodes from HAProxy configuration.

    Only the ``server`` lines of sections whose name contains
    ``pgsql-primary`` are considered. Each line is split on whitespace: the
    second element is the node name, the third one its ``host:port`` address
    and the eighth one is kept as the RepHACheck port.

    :param conf: The HAProxy configuration file content
    :type conf: dict
    :return: The nodes found, as ``{name: {"host", "port", "rephacheck"}}``
    :rtype: dict
    """

    # init nodes dictionary
    servers = {}

    for section, params in conf.items():
        if "pgsql-primary" not in section:
            continue

        # filter `server` lines
        for line in params:
            if not line.startswith("server "):
                continue

            # split line
            elements = line.split()

            # get needed elements
            name = elements[1]
            address = elements[2].split(":")
            host = address[0]
            port = int(address[1])
            rephacheck = elements[7]

            # update dictionary
            servers[name] = {"host": host, "port": port, "rephacheck": rephacheck}

    return servers


def check_odd_number(number: int) -> bool:
    """Tell whether the number of nodes is odd, which allows a quorum.

    :param number: The number of nodes in the cluster
    :type number: int
    :return: Whether it is an odd number or not
    :rtype: bool
    """

    is_even = number % 2 == 0

    return not is_even


def get_node_state(host: str, port: int) -> str:
    """Get the current state of a node from its RepHACheck daemon.

    Opens a TCP connection, reads up to 1024 bytes and returns them decoded
    as UTF-8 with trailing whitespace removed.

    :param host: The node's hostname or IP address
    :param port: The node's port on which RepHACheck is listening
    :type host: str
    :type port: int
    :return: The current state of the node according to its RepHACheck daemon
    :rtype: str
    :raises OSError: If the connection cannot be established
    """

    # connect and get tcp stream data; the context manager closes the socket
    # even when connecting or receiving fails
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as client:
        client.connect((host, port))
        data = client.recv(1024)

    # extract string from data output
    return data.decode("utf-8").rstrip()


def check_primary(nodes: dict) -> tuple:
    """Look for a node reporting the `primary` state.

    Nodes are queried in dictionary order through their RepHACheck port and
    the search stops at the first match.

    :param nodes: The dictionary containing nodes and their information
    :type nodes: dict
    :return: `(True, node name)` for the first primary found, else `(False, None)`
    :rtype: tuple
    """

    for name, details in nodes.items():
        state = get_node_state(details["host"], int(details["rephacheck"]))
        if state == "primary":
            return True, name

    return False, None


def check_standby(nodes: dict) -> tuple:
    """Look for a node reporting the `standby` state.

    Nodes are queried in dictionary order through their RepHACheck port and
    the search stops at the first match.

    :param nodes: The dictionary containing nodes and their information
    :type nodes: dict
    :return: `(True, node name)` for the first standby found, else `(False, None)`
    :rtype: tuple
    """

    for name, details in nodes.items():
        state = get_node_state(details["host"], int(details["rephacheck"]))
        if state == "standby":
            return True, name

    return False, None


def check_witness(nodes: dict) -> tuple:
    """Look for a node reporting the `witness` state.

    Nodes are queried in dictionary order through their RepHACheck port and
    the search stops at the first match.

    :param nodes: The dictionary containing nodes and their information
    :type nodes: dict
    :return: `(True, node name)` for the first witness found, else `(False, None)`
    :rtype: tuple
    """

    for name, details in nodes.items():
        state = get_node_state(details["host"], int(details["rephacheck"]))
        if state == "witness":
            return True, name

    return False, None


def check_fenced(nodes: dict) -> tuple:
    """Look for a node reporting the `fenced` state.

    Nodes are queried in dictionary order through their RepHACheck port and
    the search stops at the first match.

    :param nodes: The dictionary containing nodes and their information
    :type nodes: dict
    :return: `(True, node name)` for the first fenced node found, else `(False, None)`
    :rtype: tuple
    """

    for name, details in nodes.items():
        state = get_node_state(details["host"], int(details["rephacheck"]))
        if state == "fenced":
            return True, name

    return False, None


# pylint: disable=bad-continuation
def check_write(
    host: str, port: int, user: str, pswd: str, name: str = "postgres"
) -> bool:
    """Check if we can write data on this node.

    A temporary `es_test` table is created then dropped. Nothing is
    committed: the transaction is discarded when the connection is closed.

    :param host: Database server's hostname or IP address
    :type host: str
    :param port: Database server's port
    :type port: int
    :param user: Database username
    :type user: str
    :param pswd: Database username's password
    :type pswd: str
    :param name: Database name, defaults to "postgres"
    :type name: str
    :return: Whether it is writeable or not
    :rtype: bool
    """

    # connection
    try:
        client = psycopg2.connect(
            dbname=name, user=user, password=pswd, host=host, port=port
        )
    except psycopg2.Error:
        error("Cannot connect to the database")
        return False

    # query
    try:
        with client.cursor() as psql:
            psql.execute("CREATE TABLE es_test (id serial PRIMARY KEY);")
            psql.execute("DROP TABLE es_test;")
    except psycopg2.Error:
        return False
    finally:
        # always release the connection, including when a query fails
        client.close()

    return True


# pylint: disable=bad-continuation
def check_read(
    host: str, port: int, user: str, pswd: str, name: str = "postgres"
) -> bool:
    """Check if we can read data on this node.

    :param host: Database server's hostname or IP address
    :type host: str
    :param port: Database server's port
    :type port: int
    :param user: Database username
    :type user: str
    :param pswd: Database username's password
    :type pswd: str
    :param name: Database name, defaults to "postgres"
    :type name: str
    :return: Whether it is readable or not
    :rtype: bool
    """

    # connection
    try:
        client = psycopg2.connect(
            dbname=name, user=user, password=pswd, host=host, port=port
        )
    except psycopg2.Error:
        error("Cannot connect to the database")
        return False

    # query
    try:
        with client.cursor() as psql:
            psql.execute("SELECT;")
    except psycopg2.Error:
        return False
    finally:
        # always release the connection, including when the query fails
        client.close()

    return True


def check_replication(primary: dict, standby: dict) -> bool:
    """Check if replication is working between the primary and standby servers.

    A temporary `es_test` table is created (and committed) on the primary,
    then the standby is polled until the table can be read from it. The table
    is dropped from the primary afterwards, whatever the outcome.

    :param primary: Connection details for primary server
    :type primary: dict
    :param standby: Connection details for standby server
    :type standby: dict
    :return: Whether replication between primary/standby is working or not
    :rtype: bool
    """

    # replication may lag: poll the standby for at most attempts * delay seconds
    attempts = 10
    delay = 0.5

    # connections
    try:
        primary_client = psycopg2.connect(**primary)
    except psycopg2.Error:
        error("Cannot connect to the databases")
        return False
    try:
        standby_client = psycopg2.connect(**standby)
    except psycopg2.Error:
        primary_client.close()
        error("Cannot connect to the databases")
        return False

    # queries
    replicated = False
    try:
        # autocommit: the table must be committed on the primary to be
        # replicated, and a failed SELECT must not leave the standby
        # connection in an aborted transaction between two attempts
        primary_client.autocommit = True
        standby_client.autocommit = True
        primary_psql = primary_client.cursor()
        # the read has to go through the standby connection, otherwise the
        # standby is never actually tested
        standby_psql = standby_client.cursor()

        primary_psql.execute("CREATE TABLE es_test (id serial PRIMARY KEY);")
        try:
            for _ in range(attempts):
                try:
                    standby_psql.execute("SELECT * FROM es_test;")
                except psycopg2.ProgrammingError:
                    # table not visible on the standby yet
                    time.sleep(delay)
                else:
                    replicated = True
                    break
        finally:
            # cleanup the test table even when the standby check failed
            primary_psql.execute("DROP TABLE es_test;")
    except psycopg2.Error:
        return False
    finally:
        # close (this also releases the cursors)
        primary_client.close()
        standby_client.close()

    return replicated


def check_listen(host: str, port: int) -> bool:
    """Check if server is listening (TCP only).

    A TCP connection is attempted to `host`:`port`; nothing is sent.

    :param host: The hostname or IP address to connect to
    :param port: The port number to connect to
    :type host: str
    :type port: int
    :return: Whether the `host` is listening on TCP/`port`
    :rtype: bool
    """

    # the context manager closes the socket whatever happens
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as client:
        try:
            result = client.connect_ex((host, port))
        except socket.gaierror:
            # connect_ex still raises when the hostname cannot be resolved:
            # an unknown host is not listening
            return False

    return result == 0


def check_ha(db_conn: dict, errors: int = 0, warnings: int = 0) -> tuple:
    """Run all tests for a highly-available setup.

    The checks are, in order: the local HAProxy frontends (ports 54321 and
    54322), every PostgreSQL node declared in the HAProxy configuration, the
    absence of a fenced node, and replication between the primary and standby
    frontends.

    :param db_conn: Database connection parameters (`host`, `user`, `pass`)
    :type db_conn: dict
    :param errors: Error counter, defaults to 0
    :type errors: int, optional
    :param warnings: Warning counter, defaults to 0
    :type warnings: int, optional
    :return: Numbers of errors and warnings
    :rtype: tuple
    """

    db_host = db_conn["host"]
    db_user = db_conn["user"]
    db_pass = db_conn["pass"]

    # get haproxy conf
    ha_conf = get_haproxy_conf()

    # get nodes data
    nodes = get_nodes(ha_conf)

    # check haproxy
    print("Checking local HAProxy frontends:")
    if not check_listen(db_host, 54321):
        error("HAProxy pgsql-primary frontend is not listening")
        errors += 1
    else:
        success("HAProxy pgsql-primary frontend is listening")
    if not check_listen(db_host, 54322):
        error("HAProxy pgsql-standby frontend is not listening")
        errors += 1
    else:
        success("HAProxy pgsql-standby frontend is listening")

    # check remotes
    print("Checking remote PostgreSQL nodes:")
    for node in nodes:
        node_host = nodes[node]["host"]
        node_port = nodes[node]["port"]
        if not check_listen(node_host, node_port):
            error("Cannot bind {}:{}".format(node_host, node_port))
            errors += 1
        else:
            success("Can bind {}:{}".format(node_host, node_port))

    # check fenced
    print("Checking cluster state:")
    fenced, node = check_fenced(nodes)
    if fenced:
        error("Node `{}` is fenced".format(node))
        errors += 1
    else:
        success("No fenced node found")

    # check replication
    print("Checking replication state:")
    primary = {
        "dbname": "postgres",
        "user": db_user,
        "password": db_pass,
        "host": db_host,
        "port": 54321,
    }
    standby = {
        "dbname": "postgres",
        "user": db_user,
        "password": db_pass,
        "host": db_host,
        "port": 54322,
    }
    if not check_replication(primary, standby):
        error("Cannot replicate data between primary/standby")
        errors += 1
    else:
        success("Can replicate data between primary/standby")

    # the caller unpacks both counters to compute the exit code
    return errors, warnings
def check_local(db_conn: dict, errors: int = 0, warnings: int = 0) -> tuple:
    """Run all tests for a local (non highly-available) setup.

    The checks are, in order: the server is listening, data can be read and
    data can be written.

    :param db_conn: Database connection parameters (`host`, `port`, `user`, `pass`)
    :type db_conn: dict
    :param errors: Error counter, defaults to 0
    :type errors: int, optional
    :param warnings: Warning counter, defaults to 0
    :type warnings: int, optional
    :return: Numbers of errors and warnings
    :rtype: tuple
    """

    db_host = db_conn["host"]
    db_port = db_conn["port"]
    db_user = db_conn["user"]
    db_pass = db_conn["pass"]

    # check listen
    print("Checking local PostgreSQL node:")
    if not check_listen(db_host, db_port):
        error("Cannot connect to {}:{}".format(db_host, db_port))
        errors += 1
    else:
        # only report a success when the connection attempt succeeded
        success("Can connect to {}:{}".format(db_host, db_port))

    # check read
    print("Checking read operation:")
    if not check_read(db_host, db_port, db_user, db_pass):
        error("Cannot read data on {}@{}:{}".format(db_user, db_host, db_port))
        errors += 1
    else:
        success("Can read data on {}@{}:{}".format(db_user, db_host, db_port))

    # check write
    print("Checking write operation:")
    if not check_write(db_host, db_port, db_user, db_pass):
        error("Cannot write data on {}@{}:{}".format(db_user, db_host, db_port))
        errors += 1
    else:
        success("Can write data on {}@{}:{}".format(db_user, db_host, db_port))

    return errors, warnings


def main():
    """Run all checks and exit with the corresponding exit code.

    Exit codes: 1 when at least one error was found, 2 when there are only
    warnings, 0 otherwise. The script also exits with 1 when the envsetup
    `utils.py` file is not found next to the tests directory.
    """

    # local import: `importlib.util` replaces the `imp` module, which is
    # deprecated and no longer available in Python 3.12
    import importlib.util

    # envsetup utils path
    cwd = os.path.dirname(__file__)
    utils = os.path.join(cwd, "..", "utils.py")

    # check envsetup utils presence
    if not os.path.isfile(utils):
        error("{} not found.".format(utils))
        sys.exit(1)

    # load envsetup utils
    spec = importlib.util.spec_from_file_location("es_utils", utils)
    es_utils = importlib.util.module_from_spec(spec)
    # register the module like `imp.load_source` used to do
    sys.modules["es_utils"] = es_utils
    spec.loader.exec_module(es_utils)

    # load configuration
    conf = es_utils.load_conf()

    # get database configuration, falling back on defaults for unset values
    db_host = conf.get("DB_HOST") or "127.0.0.1"
    db_port = int(conf.get("DB_PORT") or 5432)
    db_user = conf.get("DB_USER") or "postgres"
    db_pass = conf.get("DB_PG_ROOT_PWD")
    db_conf = {"host": db_host, "port": db_port, "user": db_user, "pass": db_pass}

    # determine if HA setup
    if is_ha(db_port):
        print("This setup is using a HA database")

        errors, warnings = check_ha(db_conf)
    else:
        print("This setup is using a local database")

        errors, warnings = check_local(db_conf)

    if errors:
        sys.exit(1)
    elif warnings:
        sys.exit(2)
    else:
        sys.exit(0)


if __name__ == "__main__":
    main()