"""OpenVPN and WireGuard methods for OPNsenseClient."""
from collections.abc import MutableMapping
from datetime import datetime, timedelta
from typing import Any
from ._typing import AiopnsenseClientProtocol
from .helpers import (
_LOGGER,
_log_errors,
api_value_matches,
dict_get,
timestamp_to_datetime,
try_to_int,
)
class VPNMixin(AiopnsenseClientProtocol):
"""VPN methods for OPNsenseClient."""
@staticmethod
def _mapping_value(data: MutableMapping[str, Any], key: str) -> MutableMapping[str, Any]:
"""Return a nested mapping field or an empty mapping.
Args:
data (MutableMapping[str, Any]): Source mapping.
key (str): Field expected to contain a nested mapping.
Returns:
MutableMapping[str, Any]: Nested mapping, or an empty mapping when
the field is unavailable or malformed.
"""
value = data.get(key, {})
return value if isinstance(value, MutableMapping) else {}
@staticmethod
def _wireguard_is_connected(past_time: datetime | None) -> bool:
"""Determine whether a WireGuard session is still considered active.
Args:
past_time (datetime | None): Timestamp of the most recent WireGuard handshake.
Returns:
bool: True if a wireguard session is still considered active; otherwise, False.
"""
if not past_time:
return False
return datetime.now().astimezone() - past_time <= timedelta(minutes=3)
@_log_errors
async def get_openvpn(self) -> MutableMapping[str, Any]:
    """Collect OpenVPN server and client information from the API.

    Four payloads are fetched in order (sessions, routes, providers,
    instances). Each is replaced by an empty mapping when its endpoint is
    reported unavailable. The payloads are then merged in the order
    instances, providers, sessions, routes, and finally each server is
    enriched with its per-instance details.

    Returns:
        MutableMapping[str, Any]: Mapping with ``servers`` and ``clients``
            sub-mappings populated by the processing helpers.
    """
    # https://docs.opnsense.org/development/api/core/openvpn.html
    # https://github.com/opnsense/core/blob/master/src/opnsense/www/js/widgets/OpenVPNClients.js
    # https://github.com/opnsense/core/blob/master/src/opnsense/www/js/widgets/OpenVPNServers.js

    async def _load(endpoint: str, unavailable_msg: str) -> MutableMapping[str, Any]:
        """Return the endpoint payload, or an empty dict when it is unavailable."""
        if await self.is_endpoint_available(endpoint):
            return await self._safe_dict_get(endpoint)
        _LOGGER.debug(unavailable_msg)
        return {}

    # Fetch data
    sessions_info = await _load(
        await self._get_endpoint_path(
            snake_case_path="/api/openvpn/service/search_sessions",
            camel_case_path="/api/openvpn/service/searchSessions",
        ),
        "OpenVPN sessions endpoint unavailable",
    )
    routes_info = await _load(
        await self._get_endpoint_path(
            snake_case_path="/api/openvpn/service/search_routes",
            camel_case_path="/api/openvpn/service/searchRoutes",
        ),
        "OpenVPN routes endpoint unavailable",
    )
    providers_info = await _load(
        "/api/openvpn/export/providers",
        "OpenVPN providers endpoint unavailable",
    )
    instances_info = await _load(
        "/api/openvpn/instances/search",
        "OpenVPN instances endpoint unavailable",
    )

    result: dict[str, Any] = {"servers": {}, "clients": {}}
    self._process_openvpn_instances(instances_info, result)
    self._process_openvpn_providers(providers_info, result)
    self._process_openvpn_sessions(sessions_info, result)
    self._process_openvpn_routes(routes_info, result)
    await self._fetch_openvpn_server_details(result)

    _LOGGER.debug(
        "[get_openvpn] servers: %s, clients: %s",
        len(result["servers"]),
        len(result["clients"]),
    )
    return result
@staticmethod
def _process_openvpn_instances(
instances_info: MutableMapping[str, Any], openvpn: MutableMapping[str, Any]
) -> None:
"""Process OpenVPN instances into servers and clients.
Args:
instances_info (MutableMapping[str, Any]): Raw OpenVPN instance payload from the API.
openvpn (MutableMapping[str, Any]): Accumulated OpenVPN data structure being populated.
"""
for instance in instances_info.get("rows", []):
if not isinstance(instance, MutableMapping):
continue
role = instance.get("role", "").lower()
uuid = instance.get("uuid")
if role == "server":
VPNMixin._add_openvpn_server(instance, openvpn)
elif role == "client" and uuid:
openvpn["clients"][uuid] = {
"name": instance.get("description"),
"uuid": uuid,
"enabled": api_value_matches(instance.get("enabled"), "1"),
}
@staticmethod
def _add_openvpn_server(
instance: MutableMapping[str, Any], openvpn: MutableMapping[str, Any]
) -> None:
"""Add a server to the OpenVPN structure.
Args:
instance (MutableMapping[str, Any]): OpenVPN instance entry from the API payload.
openvpn (MutableMapping[str, Any]): Accumulated OpenVPN data structure being populated.
"""
uuid = instance.get("uuid")
if not uuid:
return
if uuid not in openvpn["servers"]:
openvpn["servers"][uuid] = {
"uuid": uuid,
"name": instance.get("description"),
"enabled": api_value_matches(instance.get("enabled"), "1"),
"dev_type": instance.get("dev_type"),
"clients": [],
}
@staticmethod
def _process_openvpn_providers(
providers_info: MutableMapping[str, Any], openvpn: MutableMapping[str, Any]
) -> None:
"""Process OpenVPN providers.
Args:
providers_info (MutableMapping[str, Any]): Raw OpenVPN provider payload from the API.
openvpn (MutableMapping[str, Any]): Accumulated OpenVPN data structure being populated.
"""
for uuid, vpn_info in providers_info.items():
if not uuid or not isinstance(vpn_info, MutableMapping):
continue
server = openvpn["servers"].setdefault(uuid, {"uuid": uuid, "clients": []})
server.update({"name": vpn_info.get("name")})
if vpn_info.get("hostname") and vpn_info.get("local_port"):
server["endpoint"] = f"{vpn_info['hostname']}:{vpn_info['local_port']}"
@staticmethod
def _process_openvpn_sessions(
sessions_info: MutableMapping[str, Any], openvpn: MutableMapping[str, Any]
) -> None:
"""Process OpenVPN sessions.
Args:
sessions_info (MutableMapping[str, Any]): Raw OpenVPN session payload from the API.
openvpn (MutableMapping[str, Any]): Accumulated OpenVPN data structure being populated.
"""
for session in sessions_info.get("rows", []):
if not isinstance(session, MutableMapping) or "id" not in session:
continue
if session.get("type") != "server":
continue
server_id = str(session["id"]).split("_", 1)[0]
server = openvpn["servers"].setdefault(server_id, {"uuid": server_id, "clients": []})
if description := session.get("description"):
server["name"] = description
VPNMixin._update_openvpn_server_status(server, session)
@staticmethod
def _update_openvpn_server_status(
server: MutableMapping[str, Any], session: MutableMapping[str, Any]
) -> None:
"""Update server status based on session data.
Args:
server (MutableMapping[str, Any]): Server entry to update.
session (MutableMapping[str, Any]): Session entry payload retrieved from the API.
"""
status = session.get("status")
if not session.get("is_client", False):
server["status"] = (
"disabled"
if not server.get("enabled", True)
else "up"
if status in {"connected", "ok"}
else "failed"
if status == "failed"
else status or "down"
)
else:
server.update(
{
"status": "up",
"latest_handshake": timestamp_to_datetime(
session.get("connected_since__time_t_")
),
"total_bytes_recv": try_to_int(session.get("bytes_received", 0), 0),
"total_bytes_sent": try_to_int(session.get("bytes_sent", 0), 0),
}
)
@staticmethod
def _process_openvpn_routes(
routes_info: MutableMapping[str, Any], openvpn: MutableMapping[str, Any]
) -> None:
"""Process OpenVPN routes.
Args:
routes_info (MutableMapping[str, Any]): Raw OpenVPN route payload from the API.
openvpn (MutableMapping[str, Any]): Accumulated OpenVPN data structure being populated.
"""
for route in routes_info.get("rows", []):
if not isinstance(route, MutableMapping):
continue
server_id = route.get("id")
if server_id not in openvpn["servers"]:
continue
openvpn["servers"][server_id]["clients"].append(
{
"name": route.get("common_name"),
"endpoint": route.get("real_address"),
"tunnel_addresses": [route.get("virtual_address")],
"latest_handshake": timestamp_to_datetime(route.get("last_ref__time_t_", 0)),
}
)
async def _fetch_openvpn_server_details(self, openvpn: MutableMapping[str, Any]) -> None:
    """Complete every server entry with counters and per-instance details.

    For each server the byte counters are defaulted to zero, the number of
    connected clients is derived from its ``clients`` list, and the
    instance details endpoint is queried (when available) for the
    ``server`` value and the selected DNS servers.

    Args:
        openvpn (MutableMapping[str, Any]): Accumulator whose ``servers``
            entries are updated in place.
    """
    for uuid, server in openvpn["servers"].items():
        server.setdefault("total_bytes_sent", 0)
        server.setdefault("total_bytes_recv", 0)
        server["connected_clients"] = len(server.get("clients", []))

        endpoint = f"/api/openvpn/instances/get/{uuid}"
        if not await self.is_endpoint_available(endpoint):
            _LOGGER.debug("OpenVPN instance details endpoint unavailable for uuid: %s", uuid)
            payload = {}
        else:
            payload = await self._safe_dict_get(endpoint)

        instance = self._mapping_value(payload, "instance")
        tunnel_address = instance.get("server")
        if tunnel_address:
            server["tunnel_addresses"] = [tunnel_address]

        # Keep only DNS options that are mappings, flagged as selected,
        # and carry a non-empty value.
        selected_dns = []
        for option in self._mapping_value(instance, "dns_servers").values():
            if not isinstance(option, MutableMapping):
                continue
            if not api_value_matches(option.get("selected"), "1"):
                continue
            if not option.get("value"):
                continue
            selected_dns.append(option["value"])
        server["dns_servers"] = selected_dns
@_log_errors
async def get_wireguard(self) -> MutableMapping[str, Any]:
    """Collect WireGuard server and client information from the API.

    Three payloads are fetched in order (service summary, client settings,
    server settings). Each is replaced by an empty mapping when its
    endpoint is reported unavailable. If any payload has an unexpected
    shape, an empty result is returned. Otherwise servers and clients are
    normalized and then updated from the live summary rows.

    Returns:
        MutableMapping[str, Any]: Mapping with ``servers`` and ``clients``
            sub-mappings keyed by UUID.
    """
    endpoints = {
        "summary_raw": "/api/wireguard/service/show",
        "clients_raw": "/api/wireguard/client/get",
        "servers_raw": "/api/wireguard/server/get",
    }
    payloads: dict[str, dict[str, Any]] = {}
    for key, path in endpoints.items():
        if not await self.is_endpoint_available(path):
            _LOGGER.debug("WireGuard endpoint unavailable: %s", path)
            payloads[key] = {}
            continue
        payloads[key] = await self._safe_dict_get(path)

    summary = payloads["summary_raw"].get("rows", [])
    client_settings = dict_get(payloads["clients_raw"], "client.clients.client", {})
    server_settings = dict_get(payloads["servers_raw"], "server.servers.server", {})

    payload_ok = (
        isinstance(summary, list)
        and isinstance(client_settings, MutableMapping)
        and isinstance(server_settings, MutableMapping)
    )
    if not payload_ok:
        _LOGGER.debug("[get_wireguard] servers: 0, clients: 0")
        return {"servers": {}, "clients": {}}

    servers = {}
    for uid, srv in server_settings.items():
        if isinstance(srv, MutableMapping):
            servers[uid] = self._process_wireguard_server(uid, srv, client_settings)

    clients = {}
    for uid, clnt in client_settings.items():
        if isinstance(clnt, MutableMapping):
            clients[uid] = self._process_wireguard_client(uid, clnt, servers)

    self._update_wireguard_status(summary, servers, clients)

    _LOGGER.debug(
        "[get_wireguard] servers: %s, clients: %s",
        len(servers),
        len(clients),
    )
    return {"servers": servers, "clients": clients}
@staticmethod
def _process_wireguard_server(
    uid: str, srv: MutableMapping[str, Any], client_summ: MutableMapping[str, Any]
) -> MutableMapping[str, Any]:
    """Normalize one WireGuard server settings row.

    Args:
        uid (str): UUID of the server in the server settings payload.
        srv (MutableMapping[str, Any]): Server settings row.
        client_summ (MutableMapping[str, Any]): Client settings keyed by
            client UUID; used to look up each selected peer's public key.

    Returns:
        MutableMapping[str, Any]: Server mapping with ``uuid``, ``name``,
            ``pubkey``, ``enabled``, ``interface`` (``wg`` followed by the
            row's ``instance`` value), ``dns_servers``, the selected
            ``tunnel_addresses``, the selected peers as ``clients`` (each
            initially not connected), and zeroed connection counters.
    """
    tunnel_addresses = VPNMixin._mapping_value(srv, "tunneladdress")
    peers = VPNMixin._mapping_value(srv, "peers")
    return {
        "uuid": uid,
        "name": srv.get("name"),
        "pubkey": srv.get("pubkey"),
        "enabled": api_value_matches(srv.get("enabled"), "1"),
        "interface": f"wg{srv.get('instance', '')}",
        "dns_servers": [srv.get("peer_dns")] if srv.get("peer_dns") else [],
        "tunnel_addresses": [
            addr.get("value")
            for addr in tunnel_addresses.values()
            if isinstance(addr, MutableMapping)
            and api_value_matches(addr.get("selected"), "1")
            and addr.get("value")
        ],
        "clients": [
            {
                "name": peer.get("value"),
                "uuid": peer_id,
                # The referenced client entry may be missing or not a
                # mapping (e.g. null); _mapping_value keeps the lookup
                # from raising AttributeError in that case.
                "pubkey": VPNMixin._mapping_value(client_summ, peer_id).get("pubkey"),
                "connected": False,
            }
            for peer_id, peer in peers.items()
            if isinstance(peer, MutableMapping)
            and api_value_matches(peer.get("selected"), "1")
            and peer.get("value")
        ],
        "connected_clients": 0,
        "total_bytes_recv": 0,
        "total_bytes_sent": 0,
    }
@staticmethod
def _process_wireguard_client(
    uid: str, clnt: MutableMapping[str, Any], servers: MutableMapping[str, Any]
) -> MutableMapping[str, Any]:
    """Normalize one WireGuard client settings row.

    Args:
        uid (str): UUID of the client in the client settings payload.
        clnt (MutableMapping[str, Any]): Client settings row.
        servers (MutableMapping[str, Any]): Already normalized servers,
            keyed by UUID, used to describe each selected server link.

    Returns:
        MutableMapping[str, Any]: Client mapping with ``uuid``, ``name``,
            ``pubkey``, ``enabled``, the selected ``tunnel_addresses``,
            the selected server links as ``servers``, and zeroed
            connection counters.
    """
    enabled = api_value_matches(clnt.get("enabled"), "1")

    selected_addresses = []
    for addr in VPNMixin._mapping_value(clnt, "tunneladdress").values():
        if not isinstance(addr, MutableMapping):
            continue
        if not api_value_matches(addr.get("selected"), "1"):
            continue
        if not addr.get("value"):
            continue
        selected_addresses.append(addr.get("value"))

    linked_servers = []
    for srv_id, srv in VPNMixin._mapping_value(clnt, "servers").items():
        if not isinstance(srv, MutableMapping):
            continue
        if not api_value_matches(srv.get("selected"), "1"):
            continue
        if not srv.get("value"):
            continue
        linked_servers.append(
            VPNMixin._link_wireguard_client_to_server(srv_id, servers, srv)
        )

    return {
        "uuid": uid,
        "name": clnt.get("name"),
        "pubkey": clnt.get("pubkey"),
        "enabled": enabled,
        "tunnel_addresses": selected_addresses,
        "servers": linked_servers,
        "connected_servers": 0,
        "total_bytes_recv": 0,
        "total_bytes_sent": 0,
    }
@staticmethod
def _link_wireguard_client_to_server(
srv_id: str, servers: MutableMapping[str, Any], srv: MutableMapping[str, Any]
) -> MutableMapping[str, Any]:
"""Link a WireGuard client to its corresponding server.
Args:
srv_id (str): Server identifier used to match related entries.
servers (MutableMapping[str, Any]): Server mapping keyed by server identifier.
srv (MutableMapping[str, Any]): WireGuard server mapping entry.
Returns:
MutableMapping[str, Any]: Mapping that describes the linked
client-to-server relationship, including keys such as
``name``, ``uuid``, and ``connected``, and optionally
``pubkey``, ``interface``, and ``tunnel_addresses`` when the
referenced server exists in ``servers``.
"""
if srv_id in servers:
server = servers[srv_id]
return {
"name": server.get("name"),
"uuid": srv_id,
"connected": False,
"pubkey": server.get("pubkey"),
"interface": server.get("interface"),
"tunnel_addresses": server.get("tunnel_addresses"),
}
return {
"name": srv.get("value"),
"uuid": srv_id,
"connected": False,
}
@staticmethod
def _update_wireguard_status(
summary: list[MutableMapping[str, Any]],
servers: MutableMapping[str, Any],
clients: MutableMapping[str, Any],
) -> None:
"""Update WireGuard server and client statuses based on the summary.
Args:
summary (list[MutableMapping[str, Any]]): WireGuard summary payload from the API.
servers (MutableMapping[str, Any]): Server mapping keyed by server identifier.
clients (MutableMapping[str, Any]): Client mapping keyed by client identifier.
"""
for entry in summary:
if not isinstance(entry, MutableMapping):
continue
if entry.get("type") == "interface":
for server in servers.values():
if server.get("pubkey") == entry.get("public-key"):
server["status"] = entry.get("status")
elif entry.get("type") == "peer":
VPNMixin._update_wireguard_peer_status(entry, servers, clients)
@staticmethod
def _update_wireguard_peer_status(
    entry: MutableMapping[str, Any],
    servers: MutableMapping[str, Any],
    clients: MutableMapping[str, Any],
) -> None:
    """Propagate one live peer row to the matching server and client links.

    The row is matched twice: against each server on the row's interface
    for a configured client with the row's public key, and against each
    client with the row's public key for a linked server on the row's
    interface. Every match is updated via
    ``_update_wireguard_peer_details``.

    Args:
        entry (MutableMapping[str, Any]): Peer row from
            ``/api/wireguard/service/show``.
        servers (MutableMapping[str, Any]): Servers keyed by identifier;
            updated in place.
        clients (MutableMapping[str, Any]): Clients keyed by identifier;
            updated in place.
    """
    pubkey = entry.get("public-key", "-")
    interface = entry.get("if", "-")
    rx_bytes: int = try_to_int(entry.get("transfer-rx", 0), 0) or 0
    tx_bytes: int = try_to_int(entry.get("transfer-tx", 0), 0) or 0
    handshake_time = timestamp_to_datetime(
        try_to_int(entry.get("latest-handshake", 0), 0)
    )

    # Values shared by every matching link for this row.
    live = {
        "endpoint": entry.get("endpoint", None),
        "transfer_rx": rx_bytes,
        "transfer_tx": tx_bytes,
        "handshake_time": handshake_time,
        "is_connected": VPNMixin._wireguard_is_connected(handshake_time),
    }

    # Update servers
    for server in servers.values():
        if server.get("interface") != interface:
            continue
        for linked_client in server.get("clients", []):
            if linked_client.get("pubkey") != pubkey:
                continue
            VPNMixin._update_wireguard_peer_details(
                peer=linked_client,
                server_or_client=server,
                connection_counter_key="connected_clients",
                **live,
            )

    # Update clients
    for client in clients.values():
        if client.get("pubkey") != pubkey:
            continue
        for linked_server in client.get("servers", []):
            if linked_server.get("interface") != interface:
                continue
            VPNMixin._update_wireguard_peer_details(
                peer=linked_server,
                server_or_client=client,
                connection_counter_key="connected_servers",
                **live,
            )
@staticmethod
def _update_wireguard_peer_details(
peer: MutableMapping[str, Any],
server_or_client: MutableMapping[str, Any],
endpoint: str,
transfer_rx: int,
transfer_tx: int,
handshake_time: datetime | None,
is_connected: bool,
connection_counter_key: str,
) -> None:
"""Apply live WireGuard peer counters to a linked server or client.
Args:
peer (MutableMapping[str, Any]): Linked peer mapping to update.
server_or_client (MutableMapping[str, Any]): Parent server or
client mapping whose aggregate counters should be incremented.
endpoint (str): Remote endpoint string reported by WireGuard.
transfer_rx (int): Received byte counter for peer statistics.
transfer_tx (int): Transmitted byte counter for peer statistics.
handshake_time (datetime | None): Handshake time used by this operation.
is_connected (bool): Connection status flag for the interface entry.
connection_counter_key (str): Counter key used for interface connection tracking.
"""
if endpoint and endpoint != "(none)":
peer["endpoint"] = endpoint
peer["bytes_recv"] = transfer_rx
peer["bytes_sent"] = transfer_tx
peer["latest_handshake"] = handshake_time
peer["connected"] = is_connected
# Update the parent (server or client) stats
server_or_client["total_bytes_recv"] = (
server_or_client.get("total_bytes_recv", 0) + transfer_rx
)
server_or_client["total_bytes_sent"] = (
server_or_client.get("total_bytes_sent", 0) + transfer_tx
)
if is_connected:
server_or_client[connection_counter_key] = (
server_or_client.get(connection_counter_key, 0) + 1
)
# Update the latest handshake time if it's newer
if (
server_or_client.get("latest_handshake") is None
or server_or_client["latest_handshake"] < handshake_time
):
server_or_client["latest_handshake"] = handshake_time
async def toggle_vpn_instance(self, vpn_type: str, clients_servers: str, uuid: str) -> bool:
    """Toggle a VPN instance and reconfigure the owning service.

    Args:
        vpn_type (str): ``openvpn`` or ``wireguard``; any other value
            makes the call return False.
        clients_servers (str): For WireGuard, ``clients`` or ``servers``
            to choose which collection the UUID belongs to; any other
            value makes the call return False. Not used for OpenVPN.
        uuid (str): UUID of the instance to toggle.

    Returns:
        bool: True only when the toggle response reports ``changed`` and
            the following reconfigure response has ``result`` equal to
            ``ok``; otherwise False.
    """
    if vpn_type == "openvpn":
        toggle_endpoint = f"/api/openvpn/instances/toggle/{uuid}"
        reconfigure_endpoint = "/api/openvpn/service/reconfigure"
    elif vpn_type == "wireguard":
        if clients_servers == "clients":
            toggle_endpoint = await self._get_endpoint_path(
                snake_case_path=f"/api/wireguard/client/toggle_client/{uuid}",
                camel_case_path=f"/api/wireguard/client/toggleClient/{uuid}",
            )
        elif clients_servers == "servers":
            toggle_endpoint = await self._get_endpoint_path(
                snake_case_path=f"/api/wireguard/server/toggle_server/{uuid}",
                camel_case_path=f"/api/wireguard/server/toggleServer/{uuid}",
            )
        else:
            return False
        reconfigure_endpoint = "/api/wireguard/service/reconfigure"
    else:
        return False

    toggled = await self._safe_dict_post(toggle_endpoint)
    if not toggled.get("changed", False):
        return False

    # The toggle only takes effect once the service has been reconfigured.
    reconfigured = await self._safe_dict_post(reconfigure_endpoint)
    return reconfigured.get("result", "") == "ok"