"""
Copyright (c) 2025 Proton AG

This file is part of Proton VPN.

Proton VPN is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

Proton VPN is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with ProtonVPN.  If not, see <https://www.gnu.org/licenses/>.
"""
from typing import Optional
import ctypes
import os
import re
import subprocess  # nosec # nosemgrep: gitlab.bandit.B404

from bcc import BPF

from proton.vpn import logging

logger = logging.getLogger(__name__)


# Key under which the WireGuard fwmark is stored in the eBPF `fwmark_map`.
# It has to match both the `fwmark_key` constant hard-coded in BPF_PROGRAM
# and the u32 key type that map is declared with.
_FWMARK_MAP_KEY = ctypes.c_uint32(1)


# eBPF program that split-tunnels traffic based on the PID map.
#
# Two hash maps are shared with user space:
#   - fwmark_map: holds, under key 1, the mark to set on excluded sockets.
#   - pid_map: its keys are the ids of the processes to exclude from the VPN.
#
# bpf_get_current_pid_tgid() returns `tgid << 32 | pid`, where the lower half
# is the kernel thread id and the upper half is the process id as seen from
# user space. The upper half is used as the lookup key so that sockets created
# by any thread of an excluded process are matched, not only those created by
# its main thread.
#
# The function always returns 1, which lets the socket creation proceed.
BPF_PROGRAM = """
BPF_HASH(fwmark_map, u32, u32);
BPF_HASH(pid_map, u32, u32);

int split_tunnel(struct bpf_sock *sk) {
    u32 pid = bpf_get_current_pid_tgid() >> 32;

    u32 fwmark_key = 1;
    u32 *fwmark = fwmark_map.lookup(&fwmark_key);

    u32 *pid_found = pid_map.lookup(&pid);

    if (fwmark && pid_found) {
        bpf_trace_printk("Excluding PID %d from VPN\\n", pid);
        sk->mark = *fwmark;
    }

    return 1;
}
"""


def _get_wireguard_fwmark(interface) -> Optional[int]:
    """
    Returns the fwmark configured on the specified WireGuard interface.

    The value is parsed, as a base 16 number, from the `fwmark:` line
    printed by `/usr/bin/wg show <interface>`.

    @param interface: name of the WireGuard interface to query.
    @return: the fwmark, or None if the `wg` command exited with an error.
    @raise RuntimeError: if the command succeeded but its output does not
    contain a fwmark.
    """
    # FIXME: get fwmark via D-Bus from NetworkManager  # pylint: disable=fixme
    #  instead of with sudo wg show. The WG backend could somehow make it available?
    command = ["/usr/bin/wg", "show", interface]
    try:
        completed = subprocess.run(  # nosec
            command, capture_output=True, text=True, check=True
        )
    except subprocess.CalledProcessError:
        # `wg show proton0` failed: the user is not connected to the VPN
        return None

    match = re.search(r'fwmark:\s+(\S+)', completed.stdout)
    if match is None:
        raise RuntimeError(
            f"Couldn't get fwmark from interface {interface}: \n\n{completed.stdout}"
        )

    return int(match.group(1), 16)


class SocketMonitor:
    """Split-tunnels sockets based on the process that creates them.

    An eBPF function is attached to the socket creation hook of a cgroup.
    For each socket created by a process registered with
    exclude_process_from_vpn(), that function sets the socket mark to the
    fwmark of the WireGuard interface.

    Instances can be used as context managers: entering the context calls
    start() and leaving it calls stop().
    """

    # FIXME could the WG backend somehow pass the interface name (proton0)  # pylint: disable=fixme
    WIREGUARD_INTERFACE_NAME = "proton0"

    # cgroup the split tunneling eBPF function is attached to.
    CGROUP_PATH = "/sys/fs/cgroup/user.slice"

    def __init__(self):
        """
        Compiles the eBPF program and loads the split tunneling function.
        Nothing is attached to the cgroup until start() is called.
        """
        self._bpf = BPF(text=BPF_PROGRAM)
        self._bpf_fwmark_map = self._bpf.get_table("fwmark_map")
        self._bpf_pid_map = self._bpf.get_table("pid_map")
        self._bpf_split_tunneling_func = self._bpf.load_func(
            "split_tunnel", self._bpf.CGROUP_SOCK
        )
        # File descriptor of the cgroup the function is attached to,
        # or None while the monitor is stopped.
        self._cgroup = None
        # Lazily resolved by _backwards_compatible_bfp_attach_type.
        self._bpf_enum_group = None

    def log_status(self):
        """Logs the socket monitor status."""
        logger.info("==============Socket monitor status==================")
        logger.info("Tracked PIDs: %s", [key.value for key in self._bpf_pid_map.keys()])
        logger.info("fwmark: %s", self._bpf_fwmark_map.get(_FWMARK_MAP_KEY))
        logger.info("=====================================================")

    def start(self):
        """
        Starts monitoring sockets.

        The fwmark stored in the eBPF map is refreshed on every call, even
        if the monitor was already running. When no fwmark is available
        the entry is removed from the map.
        @raise RuntimeError: if the WireGuard interface is up but has
        no fwmark.
        """
        fwmark = _get_wireguard_fwmark(self.WIREGUARD_INTERFACE_NAME)
        if fwmark is not None:
            # The map values are declared as u32 in the eBPF program.
            self._bpf_fwmark_map[_FWMARK_MAP_KEY] = ctypes.c_uint32(fwmark)
        elif _FWMARK_MAP_KEY in self._bpf_fwmark_map:
            del self._bpf_fwmark_map[_FWMARK_MAP_KEY]

        if self._started:
            logger.info("Socket monitor already running: fwmark updated to %s", fwmark)
            return

        logger.info("Starting socket monitor (with fwmark %s)", fwmark)
        cgroup_fd = os.open(self.CGROUP_PATH, os.O_RDONLY)

        try:
            # attach the ST eBPF function to the cgroup/sock_create event
            # of the root user cgroup
            self._bpf.attach_func(
                self._bpf_split_tunneling_func,
                cgroup_fd,
                self._backwards_compatible_bfp_attach_type.CGROUP_INET_SOCK_CREATE
            )
        except Exception:
            # Do not leak the descriptor, nor report the monitor as started,
            # when the function could not be attached.
            os.close(cgroup_fd)
            raise

        # Only flag the monitor as started once the function is attached.
        self._cgroup = cgroup_fd

    @property
    def _started(self):
        """Whether the eBPF function is currently attached to the cgroup."""
        return self._cgroup is not None

    def _cleanup(self):
        """Removes all entries from the pid and fwmark eBPF maps."""
        self._bpf_pid_map.clear()
        self._bpf_fwmark_map.clear()

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        self.stop()

    def exclude_process_from_vpn(self, pid: int):  # pylint: disable=redefined-outer-name
        """
        Split tunnels the specified process. Afterwards, traffic generated by sockets
        created by such process will be redirected outside the VPN.
        @param pid: the unix id (pid) of the process to split tunnel.
        @raise RuntimeError: if the socket monitor was not started.
        """
        if not self._started:
            raise RuntimeError("Socket monitor was not started yet")

        logger.debug("Adding %s to pid map", pid)
        # Only the presence of the key matters to the eBPF function.
        self._bpf_pid_map[ctypes.c_uint32(pid)] = ctypes.c_uint32(1)

    def stop_tracking_process(self, pid: int):  # pylint: disable=redefined-outer-name
        """
        Stops tracking sockets opened by the specified process.
        @param pid: the unix id of the process.
        If the specified process was not being split tunneled then this is a noop.
        @raise RuntimeError: if the socket monitor was not started.
        """
        if not self._started:
            raise RuntimeError("Socket monitor was not started yet")

        pid_c_uint32 = ctypes.c_uint32(pid)
        if pid_c_uint32 in self._bpf_pid_map:
            logger.debug("Removing %s from pid map", pid)
            del self._bpf_pid_map[pid_c_uint32]

    def stop(self):
        """
        Stops monitoring sockets.

        Detaches the eBPF function from the cgroup, empties the eBPF maps
        and closes the cgroup file descriptor. If the monitor is already
        stopped this is a noop.
        """
        if not self._started:
            logger.info("Socket monitor is already stopped")
            return

        logger.info("Unloading split tunneling ebpf function")
        self._bpf.detach_func(
            self._bpf_split_tunneling_func,
            self._cgroup,
            self._backwards_compatible_bfp_attach_type.CGROUP_INET_SOCK_CREATE
        )

        # From here on the function is detached: make sure the monitor ends up
        # flagged as stopped and the descriptor closed even if emptying the
        # maps fails.
        try:
            self._cleanup()
        finally:
            cgroup_fd, self._cgroup = self._cgroup, None
            os.close(cgroup_fd)

        logger.info("Socket monitor stopped")

    @property
    def _backwards_compatible_bfp_attach_type(self) -> object:
        """In v20 of bcc a refactor was made where enums
        were extracted into their own type, the BPFAttachType.

        Before that the types were part of the bpf program.

        See more here:
        https://github.com/iovisor/bcc/commit/2731825b9327a9a720f2ef92ed891ce0525a8dc3

        The result is resolved on first access and cached afterwards.

        Returns:
            object: Either the BPFAttachType or bpf program.
        """
        # Compare against None rather than relying on truthiness: the cached
        # fallback object may evaluate as false (e.g. if it defines __len__),
        # which would trigger a new resolution on every access.
        if self._bpf_enum_group is None:
            try:
                from bcc import BPFAttachType  # pylint: disable=import-outside-toplevel
                self._bpf_enum_group = BPFAttachType
            except ImportError:
                self._bpf_enum_group = self._bpf

        return self._bpf_enum_group


def main():
    """
    Manual test script.

    Forks and, in the process where os.fork() returns a non-zero value
    (that is, the parent), excludes that process from the VPN and prints the
    IP address reported by https://ip.me. The other process returns right away.
    """
    fork_result = os.fork()
    if not fork_result:
        # os.fork() returned 0: there is nothing to do in this process.
        return

    from urllib.request import Request, urlopen  # pylint: disable=import-outside-toplevel

    own_pid = os.getpid()
    request = Request("https://ip.me", headers={"User-Agent": "curl/7.54.1"})

    with SocketMonitor() as socket_monitor:
        socket_monitor.exclude_process_from_vpn(own_pid)
        # nosemgrep: python.lang.security.audit.dynamic-urllib-use-detected.dynamic-urllib-use-detected # pylint: disable=line-too-long # noqa: E501
        with urlopen(request) as response:  # nosemgrep: gitlab.bandit.B310-1 # nosec B310
            body = response.read()
            print(f"Current IP: {body.decode('utf-8')}")
