Source code for aiopnsense.vpn

"""OpenVPN and WireGuard methods for OPNsenseClient."""

from collections.abc import MutableMapping
from datetime import datetime, timedelta
from typing import Any

from ._typing import AiopnsenseClientProtocol
from .helpers import (
    _LOGGER,
    _log_errors,
    api_value_matches,
    dict_get,
    timestamp_to_datetime,
    try_to_int,
)


class VPNMixin(AiopnsenseClientProtocol):
    """VPN methods for OPNsenseClient.

    Normalizes OpenVPN and WireGuard API payloads into
    ``{"servers": {...}, "clients": {...}}`` mappings keyed by UUID, and
    provides a helper that toggles an instance and reconfigures its service.
    """

    @staticmethod
    def _mapping_value(data: MutableMapping[str, Any], key: str) -> MutableMapping[str, Any]:
        """Return a nested mapping field or an empty mapping.

        Args:
            data (MutableMapping[str, Any]): Source mapping.
            key (str): Field expected to contain a nested mapping.

        Returns:
            MutableMapping[str, Any]: Nested mapping, or an empty mapping when
            the field is unavailable or malformed.
        """
        value = data.get(key, {})
        return value if isinstance(value, MutableMapping) else {}

    @staticmethod
    def _mapping_rows(data: MutableMapping[str, Any]) -> list[MutableMapping[str, Any]]:
        """Return the mapping entries of a search payload's ``rows`` list.

        Args:
            data (MutableMapping[str, Any]): Search-style API payload expected to
                contain a ``rows`` list.

        Returns:
            list[MutableMapping[str, Any]]: Rows that are mappings. The list is
            empty when ``rows`` is missing, ``None``, or not a list, so callers
            never iterate over ``None``.
        """
        rows = data.get("rows") if isinstance(data, MutableMapping) else None
        if not isinstance(rows, list):
            return []
        return [row for row in rows if isinstance(row, MutableMapping)]

    @staticmethod
    def _selected_options(
        options: MutableMapping[str, Any],
    ) -> list[tuple[Any, MutableMapping[str, Any]]]:
        """Return the selected, non-empty entries of an option mapping.

        Option mappings have the form ``{key: {"value": ..., "selected": ...}}``.
        An entry is kept when ``selected`` matches ``"1"`` and ``value`` is truthy.

        Args:
            options (MutableMapping[str, Any]): Option mapping to filter.

        Returns:
            list[tuple[Any, MutableMapping[str, Any]]]: ``(key, option)`` pairs
            in the mapping's iteration order. Each ``option["value"]`` is
            guaranteed to be truthy.
        """
        return [
            (key, option)
            for key, option in options.items()
            if isinstance(option, MutableMapping)
            and api_value_matches(option.get("selected"), "1")
            and option.get("value")
        ]

    async def _get_vpn_payload(
        self, path: str, unavailable_msg: str, *msg_args: Any
    ) -> MutableMapping[str, Any]:
        """Fetch ``path`` when the endpoint is available, otherwise return ``{}``.

        Args:
            path (str): API endpoint to query.
            unavailable_msg (str): Debug log format string emitted when the
                endpoint is reported unavailable.
            *msg_args (Any): Arguments interpolated into ``unavailable_msg``.

        Returns:
            MutableMapping[str, Any]: Payload returned by ``_safe_dict_get``, or
            an empty mapping when the endpoint is unavailable.
        """
        if await self.is_endpoint_available(path):
            return await self._safe_dict_get(path)
        _LOGGER.debug(unavailable_msg, *msg_args)
        return {}

    @staticmethod
    def _wireguard_is_connected(past_time: datetime | None) -> bool:
        """Determine whether a WireGuard session is still considered active.

        Args:
            past_time (datetime | None): Timestamp of the most recent WireGuard handshake.

        Returns:
            bool: True when a handshake exists and is at most three minutes old;
            otherwise, False.
        """
        if not past_time:
            return False
        # NOTE(review): three minutes presumably mirrors WireGuard's
        # REJECT_AFTER_TIME (180 s). Subtracting requires ``past_time`` to be
        # timezone-aware; this assumes timestamp_to_datetime returns aware values.
        return datetime.now().astimezone() - past_time <= timedelta(minutes=3)

    @_log_errors
    async def get_openvpn(self) -> MutableMapping[str, Any]:
        """Return OpenVPN server and client status information.

        Returns:
            MutableMapping[str, Any]: Mapping with ``servers`` and ``clients``
                keyed by UUID. Server entries combine instance, provider,
                session, route, tunnel-address, DNS, status, byte-counter, and
                connected-client data when available.
        """
        # https://docs.opnsense.org/development/api/core/openvpn.html
        # https://github.com/opnsense/core/blob/master/src/opnsense/www/js/widgets/OpenVPNClients.js
        # https://github.com/opnsense/core/blob/master/src/opnsense/www/js/widgets/OpenVPNServers.js
        openvpn: dict[str, Any] = {"servers": {}, "clients": {}}

        # Fetch data
        sessions_endpoint = await self._get_endpoint_path(
            snake_case_path="/api/openvpn/service/search_sessions",
            camel_case_path="/api/openvpn/service/searchSessions",
        )
        sessions_info = await self._get_vpn_payload(
            sessions_endpoint, "OpenVPN sessions endpoint unavailable"
        )

        routes_endpoint = await self._get_endpoint_path(
            snake_case_path="/api/openvpn/service/search_routes",
            camel_case_path="/api/openvpn/service/searchRoutes",
        )
        routes_info = await self._get_vpn_payload(
            routes_endpoint, "OpenVPN routes endpoint unavailable"
        )

        providers_info = await self._get_vpn_payload(
            "/api/openvpn/export/providers", "OpenVPN providers endpoint unavailable"
        )
        instances_info = await self._get_vpn_payload(
            "/api/openvpn/instances/search", "OpenVPN instances endpoint unavailable"
        )

        # Instances go first so each server's ``enabled`` flag is already set
        # when session rows derive its status.
        self._process_openvpn_instances(instances_info, openvpn)
        self._process_openvpn_providers(providers_info, openvpn)
        self._process_openvpn_sessions(sessions_info, openvpn)
        self._process_openvpn_routes(routes_info, openvpn)
        await self._fetch_openvpn_server_details(openvpn)
        _LOGGER.debug(
            "[get_openvpn] servers: %s, clients: %s",
            len(openvpn["servers"]),
            len(openvpn["clients"]),
        )
        return openvpn

    @staticmethod
    def _process_openvpn_instances(
        instances_info: MutableMapping[str, Any], openvpn: MutableMapping[str, Any]
    ) -> None:
        """Process OpenVPN instances into servers and clients.

        Args:
            instances_info (MutableMapping[str, Any]): Raw OpenVPN instance payload from the API.
            openvpn (MutableMapping[str, Any]): Accumulated OpenVPN data structure being populated.
        """
        for instance in VPNMixin._mapping_rows(instances_info):
            # ``or ""`` also covers an explicit ``"role": null``, which would
            # otherwise fail on ``.lower()``.
            role = str(instance.get("role") or "").lower()
            if role == "server":
                VPNMixin._add_openvpn_server(instance, openvpn)
            elif role == "client" and (uuid := instance.get("uuid")):
                openvpn["clients"][uuid] = {
                    "name": instance.get("description"),
                    "uuid": uuid,
                    "enabled": api_value_matches(instance.get("enabled"), "1"),
                }

    @staticmethod
    def _add_openvpn_server(
        instance: MutableMapping[str, Any], openvpn: MutableMapping[str, Any]
    ) -> None:
        """Add a server to the OpenVPN structure.

        Instances without a UUID, or whose UUID is already present, are ignored.

        Args:
            instance (MutableMapping[str, Any]): OpenVPN instance entry from the API payload.
            openvpn (MutableMapping[str, Any]): Accumulated OpenVPN data structure being populated.
        """
        uuid = instance.get("uuid")
        if not uuid or uuid in openvpn["servers"]:
            return
        openvpn["servers"][uuid] = {
            "uuid": uuid,
            "name": instance.get("description"),
            "enabled": api_value_matches(instance.get("enabled"), "1"),
            "dev_type": instance.get("dev_type"),
            "clients": [],
        }

    @staticmethod
    def _process_openvpn_providers(
        providers_info: MutableMapping[str, Any], openvpn: MutableMapping[str, Any]
    ) -> None:
        """Process OpenVPN providers.

        Sets each server's ``name`` and, when both hostname and local port are
        present, its ``endpoint`` as ``"hostname:port"``. Unknown provider UUIDs
        create new server entries.

        Args:
            providers_info (MutableMapping[str, Any]): Raw OpenVPN provider payload from the API.
            openvpn (MutableMapping[str, Any]): Accumulated OpenVPN data structure being populated.
        """
        for uuid, vpn_info in providers_info.items():
            if not uuid or not isinstance(vpn_info, MutableMapping):
                continue
            server = openvpn["servers"].setdefault(uuid, {"uuid": uuid, "clients": []})
            # Prefer the provider name, but keep the instance description when
            # the provider row carries no name instead of blanking it.
            server["name"] = vpn_info.get("name") or server.get("name")
            hostname = vpn_info.get("hostname")
            local_port = vpn_info.get("local_port")
            if hostname and local_port:
                server["endpoint"] = f"{hostname}:{local_port}"

    @staticmethod
    def _process_openvpn_sessions(
        sessions_info: MutableMapping[str, Any], openvpn: MutableMapping[str, Any]
    ) -> None:
        """Process OpenVPN sessions.

        Only rows whose ``type`` is ``"server"`` are used. The server UUID is the
        part of the row ``id`` before the first ``_``.

        Args:
            sessions_info (MutableMapping[str, Any]): Raw OpenVPN session payload from the API.
            openvpn (MutableMapping[str, Any]): Accumulated OpenVPN data structure being populated.
        """
        for session in VPNMixin._mapping_rows(sessions_info):
            if "id" not in session or session.get("type") != "server":
                continue
            server_id = str(session["id"]).split("_", 1)[0]
            if not server_id:
                continue
            server = openvpn["servers"].setdefault(server_id, {"uuid": server_id, "clients": []})
            if description := session.get("description"):
                server["name"] = description
            VPNMixin._update_openvpn_server_status(server, session)

    @staticmethod
    def _update_openvpn_server_status(
        server: MutableMapping[str, Any], session: MutableMapping[str, Any]
    ) -> None:
        """Update server status based on session data.

        Rows without a truthy ``is_client`` describe the server itself and set
        ``status``: ``"disabled"`` for disabled servers, ``"up"`` for
        ``connected``/``ok``, otherwise the raw status, or ``"down"`` when it is
        empty.

        Rows with ``is_client`` describe a connected client of the server. A
        server can report several such rows, because their ids share the server
        UUID prefix. For each one, the byte counters are added to the server
        totals, and ``latest_handshake`` keeps the most recent
        ``connected_since`` timestamp.

        Args:
            server (MutableMapping[str, Any]): Server entry to update.
            session (MutableMapping[str, Any]): Session entry payload retrieved from the API.
        """
        if session.get("is_client", False):
            server["status"] = "up"
            connected_since = timestamp_to_datetime(session.get("connected_since__time_t_"))
            previous = server.get("latest_handshake")
            if previous is None or (connected_since is not None and connected_since > previous):
                server["latest_handshake"] = connected_since
            # Accumulate rather than overwrite, so multiple clients all count.
            server["total_bytes_recv"] = server.get("total_bytes_recv", 0) + (
                try_to_int(session.get("bytes_received", 0), 0) or 0
            )
            server["total_bytes_sent"] = server.get("total_bytes_sent", 0) + (
                try_to_int(session.get("bytes_sent", 0), 0) or 0
            )
            return

        status = session.get("status")
        if not server.get("enabled", True):
            server["status"] = "disabled"
        elif status in ("connected", "ok"):
            server["status"] = "up"
        else:
            # Passes "failed" (and any other reported state) through verbatim.
            server["status"] = status or "down"

    @staticmethod
    def _process_openvpn_routes(
        routes_info: MutableMapping[str, Any], openvpn: MutableMapping[str, Any]
    ) -> None:
        """Process OpenVPN routes.

        Each route whose ``id`` matches a known server UUID is appended to that
        server's ``clients`` list.

        Args:
            routes_info (MutableMapping[str, Any]): Raw OpenVPN route payload from the API.
            openvpn (MutableMapping[str, Any]): Accumulated OpenVPN data structure being populated.
        """
        servers = openvpn["servers"]
        for route in VPNMixin._mapping_rows(routes_info):
            server_id = route.get("id")
            # Server keys are strings. The type check also avoids a TypeError
            # on unhashable ids.
            if not isinstance(server_id, str) or server_id not in servers:
                continue
            servers[server_id].setdefault("clients", []).append(
                {
                    "name": route.get("common_name"),
                    "endpoint": route.get("real_address"),
                    "tunnel_addresses": [route.get("virtual_address")],
                    "latest_handshake": timestamp_to_datetime(route.get("last_ref__time_t_", 0)),
                }
            )

    async def _fetch_openvpn_server_details(self, openvpn: MutableMapping[str, Any]) -> None:
        """Fetch detailed server information.

        Ensures byte counters exist, sets ``connected_clients`` from the route
        list, and fills ``tunnel_addresses`` and ``dns_servers`` from each
        instance's details endpoint.

        Args:
            openvpn (MutableMapping[str, Any]): Accumulated OpenVPN data structure being populated.
        """
        for uuid, server in openvpn["servers"].items():
            server.setdefault("total_bytes_sent", 0)
            server.setdefault("total_bytes_recv", 0)
            server["connected_clients"] = len(server.get("clients", []))
            details_info = await self._get_vpn_payload(
                f"/api/openvpn/instances/get/{uuid}",
                "OpenVPN instance details endpoint unavailable for uuid: %s",
                uuid,
            )
            details = self._mapping_value(details_info, "instance")
            if details.get("server"):
                server["tunnel_addresses"] = [details["server"]]
            server["dns_servers"] = [
                option["value"]
                for _, option in self._selected_options(self._mapping_value(details, "dns_servers"))
            ]

    @_log_errors
    async def get_wireguard(self) -> MutableMapping[str, Any]:
        """Return WireGuard server and client status information.

        Returns:
            MutableMapping[str, Any]: Mapping with ``servers`` and ``clients``
                keyed by UUID. Entries combine configured tunnel addresses,
                peer links, enabled state, interface names, handshake status,
                byte counters, and connected peer counts where available.
        """
        data_sources = {
            "summary_raw": "/api/wireguard/service/show",
            "clients_raw": "/api/wireguard/client/get",
            "servers_raw": "/api/wireguard/server/get",
        }
        data: dict[str, MutableMapping[str, Any]] = {}
        for key, path in data_sources.items():
            data[key] = await self._get_vpn_payload(
                path, "WireGuard endpoint unavailable: %s", path
            )

        summary = data["summary_raw"].get("rows", [])
        client_summ = dict_get(data["clients_raw"], "client.clients.client", {})
        server_summ = dict_get(data["servers_raw"], "server.servers.server", {})

        if (
            not isinstance(summary, list)
            or not isinstance(client_summ, MutableMapping)
            or not isinstance(server_summ, MutableMapping)
        ):
            _LOGGER.debug("[get_wireguard] servers: 0, clients: 0")
            return {"servers": {}, "clients": {}}

        # Servers are built first because client entries link to them.
        servers = {
            uid: self._process_wireguard_server(uid, srv, client_summ)
            for uid, srv in server_summ.items()
            if isinstance(srv, MutableMapping)
        }
        clients = {
            uid: self._process_wireguard_client(uid, clnt, servers)
            for uid, clnt in client_summ.items()
            if isinstance(clnt, MutableMapping)
        }

        self._update_wireguard_status(summary, servers, clients)

        wireguard = {"servers": servers, "clients": clients}
        _LOGGER.debug(
            "[get_wireguard] servers: %s, clients: %s",
            len(servers),
            len(clients),
        )
        return wireguard

    @staticmethod
    def _process_wireguard_server(
        uid: str, srv: MutableMapping[str, Any], client_summ: MutableMapping[str, Any]
    ) -> MutableMapping[str, Any]:
        """Process a single WireGuard server entry.

        Args:
            uid (str): WireGuard server UUID from the server settings payload.
            srv (MutableMapping[str, Any]): WireGuard server settings row.
            client_summ (MutableMapping[str, Any]): Client settings mapping
                used to enrich selected peer links with public keys.

        Returns:
            MutableMapping[str, Any]: Normalized server mapping with identity,
                enabled state, interface, DNS servers, tunnel addresses,
                configured client peers, and initialized connection counters.
        """
        tunnel_addresses = VPNMixin._selected_options(VPNMixin._mapping_value(srv, "tunneladdress"))
        peers = VPNMixin._selected_options(VPNMixin._mapping_value(srv, "peers"))
        return {
            "uuid": uid,
            "name": srv.get("name"),
            "pubkey": srv.get("pubkey"),
            "enabled": api_value_matches(srv.get("enabled"), "1"),
            "interface": f"wg{srv.get('instance', '')}",
            "dns_servers": [srv.get("peer_dns")] if srv.get("peer_dns") else [],
            "tunnel_addresses": [addr["value"] for _, addr in tunnel_addresses],
            "clients": [
                {
                    "name": peer["value"],
                    "uuid": peer_id,
                    # A missing or malformed client row yields ``None`` instead of crashing.
                    "pubkey": VPNMixin._mapping_value(client_summ, peer_id).get("pubkey"),
                    "connected": False,
                }
                for peer_id, peer in peers
            ],
            "connected_clients": 0,
            "total_bytes_recv": 0,
            "total_bytes_sent": 0,
        }

    @staticmethod
    def _process_wireguard_client(
        uid: str, clnt: MutableMapping[str, Any], servers: MutableMapping[str, Any]
    ) -> MutableMapping[str, Any]:
        """Process a single WireGuard client entry.

        Args:
            uid (str): WireGuard client UUID from the client settings payload.
            clnt (MutableMapping[str, Any]): WireGuard client settings row.
            servers (MutableMapping[str, Any]): Normalized server mapping used
                to enrich configured server links.

        Returns:
            MutableMapping[str, Any]: Normalized client mapping with identity,
                enabled state, tunnel addresses, configured server links, and
                initialized connection counters.
        """
        tunnel_addresses = VPNMixin._selected_options(
            VPNMixin._mapping_value(clnt, "tunneladdress")
        )
        server_links = VPNMixin._selected_options(VPNMixin._mapping_value(clnt, "servers"))
        return {
            "uuid": uid,
            "name": clnt.get("name"),
            "pubkey": clnt.get("pubkey"),
            "enabled": api_value_matches(clnt.get("enabled"), "1"),
            "tunnel_addresses": [addr["value"] for _, addr in tunnel_addresses],
            "servers": [
                VPNMixin._link_wireguard_client_to_server(srv_id, servers, srv)
                for srv_id, srv in server_links
            ],
            "connected_servers": 0,
            "total_bytes_recv": 0,
            "total_bytes_sent": 0,
        }

    @staticmethod
    def _link_wireguard_client_to_server(
        srv_id: str, servers: MutableMapping[str, Any], srv: MutableMapping[str, Any]
    ) -> MutableMapping[str, Any]:
        """Link a WireGuard client to its corresponding server.

        Args:
            srv_id (str): Server identifier used to match related entries.
            servers (MutableMapping[str, Any]): Server mapping keyed by server identifier.
            srv (MutableMapping[str, Any]): Selected server-link option from the
                client settings row. Its ``value`` is used as the name when the
                server is unknown.

        Returns:
            MutableMapping[str, Any]: Mapping that describes the linked
                client-to-server relationship, including keys such as
                ``name``, ``uuid``, and ``connected``, and optionally
                ``pubkey``, ``interface``, and ``tunnel_addresses`` when the
                referenced server exists in ``servers``.
        """
        if srv_id not in servers:
            return {
                "name": srv.get("value"),
                "uuid": srv_id,
                "connected": False,
            }
        server = servers[srv_id]
        return {
            "name": server.get("name"),
            "uuid": srv_id,
            "connected": False,
            "pubkey": server.get("pubkey"),
            "interface": server.get("interface"),
            "tunnel_addresses": server.get("tunnel_addresses"),
        }

    @staticmethod
    def _update_wireguard_status(
        summary: list[MutableMapping[str, Any]],
        servers: MutableMapping[str, Any],
        clients: MutableMapping[str, Any],
    ) -> None:
        """Update WireGuard server and client statuses based on the summary.

        ``interface`` rows set the ``status`` of servers whose pubkey matches the
        row's ``public-key``. ``peer`` rows update peer links and counters.

        Args:
            summary (list[MutableMapping[str, Any]]): WireGuard summary payload from the API.
            servers (MutableMapping[str, Any]): Server mapping keyed by server identifier.
            clients (MutableMapping[str, Any]): Client mapping keyed by client identifier.
        """
        for entry in summary:
            if not isinstance(entry, MutableMapping):
                continue
            entry_type = entry.get("type")
            if entry_type == "interface":
                interface_pubkey = entry.get("public-key")
                # Without a key, ``None == None`` would match every server
                # that also lacks a pubkey.
                if not interface_pubkey:
                    continue
                for server in servers.values():
                    if server.get("pubkey") == interface_pubkey:
                        server["status"] = entry.get("status")
            elif entry_type == "peer":
                VPNMixin._update_wireguard_peer_status(entry, servers, clients)

    @staticmethod
    def _update_wireguard_peer_status(
        entry: MutableMapping[str, Any],
        servers: MutableMapping[str, Any],
        clients: MutableMapping[str, Any],
    ) -> None:
        """Update the WireGuard peer status for clients and servers.

        A peer row is matched against:

        * server-side client links, where the server interface equals the row's
          ``if`` and the client pubkey equals ``public-key``;
        * client-side server links, where the client pubkey equals
          ``public-key`` and the linked server interface equals ``if``.

        Args:
            entry (MutableMapping[str, Any]): WireGuard summary peer row from
                ``/api/wireguard/service/show``.
            servers (MutableMapping[str, Any]): Server mapping keyed by server identifier.
            clients (MutableMapping[str, Any]): Client mapping keyed by client identifier.
        """
        pubkey = entry.get("public-key", "-")
        interface = entry.get("if", "-")
        endpoint = entry.get("endpoint", None)
        transfer_rx: int = try_to_int(entry.get("transfer-rx", 0), 0) or 0
        transfer_tx: int = try_to_int(entry.get("transfer-tx", 0), 0) or 0
        latest_handshake = try_to_int(entry.get("latest-handshake", 0), 0)
        handshake_time = timestamp_to_datetime(latest_handshake)
        is_connected = VPNMixin._wireguard_is_connected(handshake_time)

        # Update servers
        for server in servers.values():
            if server.get("interface") != interface:
                continue
            for client in server.get("clients", []):
                if client.get("pubkey") == pubkey:
                    VPNMixin._update_wireguard_peer_details(
                        peer=client,
                        server_or_client=server,
                        endpoint=endpoint,
                        transfer_rx=transfer_rx,
                        transfer_tx=transfer_tx,
                        handshake_time=handshake_time,
                        is_connected=is_connected,
                        connection_counter_key="connected_clients",
                    )

        # Update clients
        for client in clients.values():
            if client.get("pubkey") != pubkey:
                continue
            for server in client.get("servers", []):
                if server.get("interface") == interface:
                    VPNMixin._update_wireguard_peer_details(
                        peer=server,
                        server_or_client=client,
                        endpoint=endpoint,
                        transfer_rx=transfer_rx,
                        transfer_tx=transfer_tx,
                        handshake_time=handshake_time,
                        is_connected=is_connected,
                        connection_counter_key="connected_servers",
                    )

    @staticmethod
    def _update_wireguard_peer_details(
        peer: MutableMapping[str, Any],
        server_or_client: MutableMapping[str, Any],
        endpoint: str | None,
        transfer_rx: int,
        transfer_tx: int,
        handshake_time: datetime | None,
        is_connected: bool,
        connection_counter_key: str,
    ) -> None:
        """Apply live WireGuard peer counters to a linked server or client.

        Args:
            peer (MutableMapping[str, Any]): Linked peer mapping to update.
            server_or_client (MutableMapping[str, Any]): Parent server or
                client mapping whose aggregate counters should be incremented.
            endpoint (str | None): Remote endpoint string reported by WireGuard.
                It is ignored when empty or ``"(none)"``.
            transfer_rx (int): Received byte counter for peer statistics.
            transfer_tx (int): Transmitted byte counter for peer statistics.
            handshake_time (datetime | None): Time of the peer's latest handshake.
            is_connected (bool): Whether the peer is considered connected.
            connection_counter_key (str): Parent key incremented for connected
                peers (``connected_clients`` or ``connected_servers``).
        """
        if endpoint and endpoint != "(none)":
            peer["endpoint"] = endpoint
        peer["bytes_recv"] = transfer_rx
        peer["bytes_sent"] = transfer_tx
        peer["latest_handshake"] = handshake_time
        peer["connected"] = is_connected

        # Update the parent (server or client) stats
        server_or_client["total_bytes_recv"] = (
            server_or_client.get("total_bytes_recv", 0) + transfer_rx
        )
        server_or_client["total_bytes_sent"] = (
            server_or_client.get("total_bytes_sent", 0) + transfer_tx
        )

        if is_connected:
            server_or_client[connection_counter_key] = (
                server_or_client.get(connection_counter_key, 0) + 1
            )
            # Keep the newest handshake across all connected peers. A connected
            # peer always has a handshake_time (see _wireguard_is_connected).
            previous = server_or_client.get("latest_handshake")
            if previous is None or previous < handshake_time:
                server_or_client["latest_handshake"] = handshake_time

    async def toggle_vpn_instance(self, vpn_type: str, clients_servers: str, uuid: str) -> bool:
        """Toggle the specified VPN instance on or off.

        Args:
            vpn_type (str): ``"openvpn"`` or ``"wireguard"``. Any other value
                returns False without calling the API.
            clients_servers (str): WireGuard collection to toggle. Use
                ``clients`` for a client entry or ``servers`` for a server
                entry. Ignored for OpenVPN.
            uuid (str): UUID of the VPN instance to toggle.

        Returns:
            bool: True when the toggle reports a change and the subsequent
            service reconfigure returns ``result == "ok"``; otherwise, False.
        """
        if vpn_type == "openvpn":
            toggle_endpoint = f"/api/openvpn/instances/toggle/{uuid}"
        elif vpn_type == "wireguard" and clients_servers == "clients":
            toggle_endpoint = await self._get_endpoint_path(
                snake_case_path=f"/api/wireguard/client/toggle_client/{uuid}",
                camel_case_path=f"/api/wireguard/client/toggleClient/{uuid}",
            )
        elif vpn_type == "wireguard" and clients_servers == "servers":
            toggle_endpoint = await self._get_endpoint_path(
                snake_case_path=f"/api/wireguard/server/toggle_server/{uuid}",
                camel_case_path=f"/api/wireguard/server/toggleServer/{uuid}",
            )
        else:
            return False

        toggled = await self._safe_dict_post(toggle_endpoint)
        if not toggled.get("changed", False):
            # Nothing changed, so the service is not reconfigured.
            return False
        # vpn_type is "openvpn" or "wireguard" here.
        reconfigure = await self._safe_dict_post(f"/api/{vpn_type}/service/reconfigure")
        return reconfigure.get("result", "") == "ok"