#!/usr/bin/env python3
# Copyright (c) 2017, VMware, Inc. or its affiliates.

import os
from functools import reduce


class ValidationException(Exception):
    """Error carrying a human-readable description of a failed validation.

    The description is passed to ``Exception`` (so ``str(exc)`` returns it)
    and is also stored on the ``message`` attribute.
    """

    def __init__(self, message):
        self.message = message
        Exception.__init__(self, message)


class CgroupValidation(object):
    """Base class providing helpers shared by cgroup validators."""

    @staticmethod
    def detect_cgroup_mount_point(proc_mounts_path="/proc/self/mounts"):
        """
        Return the mount point of the first cgroup v2 filesystem.

        Each line of proc_mounts_path is split on whitespace; the second
        field is returned for the first line whose third field is "cgroup2".

        proc_mounts_path: mounts table to scan, defaults to the table of
                          the current process.

        Returns "" when the file does not exist or lists no cgroup2 entry.
        """
        if not os.path.exists(proc_mounts_path):
            return ""

        with open(proc_mounts_path) as f:
            for line in f:
                mount_specs = line.split()
                # Blank or truncated lines have fewer than three fields;
                # skip them instead of failing with an IndexError.
                if len(mount_specs) < 3:
                    continue
                if mount_specs[2] == "cgroup2":
                    return mount_specs[1]
        return ""


class CgroupValidationVersionTwo(CgroupValidation):
    """Check that the cgroup v2 hierarchy has the files gpdb needs."""

    def __init__(self):
        # Empty string when no cgroup2 mount was found; validate_all()
        # reports that case.
        self.mount_point = self.detect_cgroup_mount_point()
        # Maps each mode letter accepted by validate_permission() to the
        # corresponding os.access() flag.
        self.tab = {"r": os.R_OK, "w": os.W_OK, "x": os.X_OK, "f": os.F_OK}

    def validate_all(self):
        """
        Check the permissions of the toplevel gpdb cgroup dirs.

        The checks should keep in sync with
        src/backend/utils/resgroup/cgroup-ops-v2.c

        Raises ValidationException on the first failed check, or when no
        cgroup v2 mount point was detected.
        """

        if not self.mount_point:
            self.die("failed to detect cgroup v2 mount point.")

        # (path relative to the mount point, required mode letters),
        # checked in order; directories end with '/'.
        required = (
            ("cgroup.procs", "rw"),

            ("gpdb/", "rwx"),
            ("gpdb/cgroup.procs", "rw"),

            ("gpdb/cpu.max", "rw"),
            ("gpdb/cpu.weight", "rw"),
            ("gpdb/cpu.weight.nice", "rw"),
            ("gpdb/cpu.stat", "r"),

            ("gpdb/cpuset.cpus", "rw"),
            ("gpdb/cpuset.cpus.partition", "rw"),
            ("gpdb/cpuset.mems", "rw"),
            ("gpdb/cpuset.cpus.effective", "r"),
            ("gpdb/cpuset.mems.effective", "r"),

            ("gpdb/memory.current", "r"),

            ("gpdb/io.max", "rw"),
        )
        for path, mode in required:
            self.validate_permission(path, mode)

    def die(self, msg):
        """Raise ValidationException with msg appended to a fixed prefix."""
        raise ValidationException("cgroup is not properly configured: {}".format(msg))

    def validate_permission(self, path, mode):
        """
        Validate permission on path.
        If path is a dir it must end with '/'.

        path: location relative to self.mount_point.
        mode: string of letters from self.tab ('r', 'w', 'x', 'f'); their
              flags are OR-ed together and passed to os.access().
              An unknown letter raises KeyError.

        Raises ValidationException if the path does not exist or the
        requested access is not granted.
        """
        fullpath = os.path.join(self.mount_point, path)
        # endswith() also copes with an empty path, where path[-1] would
        # raise IndexError.
        pathtype = "directory" if path.endswith("/") else "file"

        mode_bits = 0
        for letter in mode:
            mode_bits |= self.tab[letter]

        try:
            if not os.path.exists(fullpath):
                self.die("{} '{}' does not exist".
                         format(pathtype, fullpath))
            if not os.access(fullpath, mode_bits):
                self.die("{} '{}' permission denied: require permission '{}'".
                         format(pathtype, fullpath, mode))
        except IOError as e:
            # die() raises ValidationException, which is not an IOError,
            # so failures reported above pass through this handler.
            self.die("can't check permission on {} '{}': {}".format(pathtype, fullpath, str(e)))


if __name__ == '__main__':
    # Run every cgroup v2 check; on failure print the reason to stderr and
    # exit with a non-zero status.
    try:
        CgroupValidationVersionTwo().validate_all()
    except ValidationException as e:
        # Raise SystemExit directly instead of calling exit(): that helper
        # is injected by the site module for interactive use and is not
        # available when site is disabled (e.g. python -S).
        raise SystemExit(e.message)
