diff --git a/src/wg/wgclient_win.cpp b/src/wg/wgclient_win.cpp new file mode 100644 index 0000000..85278d6 --- /dev/null +++ b/src/wg/wgclient_win.cpp @@ -0,0 +1,317 @@ +// src/wg/wgclient_win.cpp — Windows WireGuard tunnel for Artemis. +// +// Requires: Administrator privileges (Wintun kernel driver). +// Links against: boringtun.lib, iphlpapi.lib, ws2_32.lib, ntdll.lib, ole32.lib +// +// Packet format on Windows: raw IP (no AF-prefix unlike macOS utun). + +#include "wgclient.h" + +#ifdef _WIN32 + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +// boringtun C ABI (header lives in the same directory — symlinked or copied) +#include "boringtun_ffi.h" + +// Wintun typedefs & loader +#include "wintun_artemis.h" + +namespace wg { + +// ── helpers ─────────────────────────────────────────────────────────────────── + +static std::string winErr(const char *ctx) { + char buf[256] = {}; + FormatMessageA(FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, + nullptr, GetLastError(), 0, buf, sizeof(buf), nullptr); + return std::string(ctx) + ": " + buf; +} + +static bool parseCIDR(const std::string &cidr, IN_ADDR *addr, UINT8 *prefix) { + auto slash = cidr.find('/'); + std::string ip = (slash != std::string::npos) ? cidr.substr(0, slash) : cidr; + *prefix = (slash != std::string::npos) ? (UINT8)std::stoi(cidr.substr(slash + 1)) : 32; + return InetPtonA(AF_INET, ip.c_str(), addr) == 1; +} + +// ── Wintun function pointer types ───────────────────────────────────────────── +// (Mirror of DragonMoonlight's wintun.h — kept minimal to avoid header dep) + +typedef PVOID WINTUN_ADAPTER_HANDLE; +typedef PVOID WINTUN_SESSION_HANDLE; +typedef DWORD WINTUN_CAPACITY; + +typedef WINTUN_ADAPTER_HANDLE (WINAPI *PFN_WintunCreateAdapter)(LPCWSTR, LPCWSTR, const GUID *); +typedef void (WINAPI *PFN_WintunDeleteAdapter)(WINTUN_ADAPTER_HANDLE); +typedef void (WINAPI *PFN_WintunGetAdapterLUID)(WINTUN_ADAPTER_HANDLE, NET_LUID *); +typedef WINTUN_SESSION_HANDLE (WINAPI *PFN_WintunStartSession)(WINTUN_ADAPTER_HANDLE, DWORD); +typedef void (WINAPI *PFN_WintunEndSession)(WINTUN_SESSION_HANDLE); +typedef BYTE *(WINAPI *PFN_WintunReceivePacket)(WINTUN_SESSION_HANDLE, DWORD *); +typedef void (WINAPI *PFN_WintunReleaseReceivePacket)(WINTUN_SESSION_HANDLE, BYTE *); +typedef BYTE *(WINAPI *PFN_WintunAllocateSendPacket)(WINTUN_SESSION_HANDLE, DWORD); +typedef void (WINAPI *PFN_WintunSendPacket)(WINTUN_SESSION_HANDLE, BYTE *); +typedef HANDLE (WINAPI *PFN_WintunGetReadWaitEvent)(WINTUN_SESSION_HANDLE); + +struct WintunFns { + HMODULE hmod = nullptr; + PFN_WintunCreateAdapter CreateAdapter = nullptr; + PFN_WintunDeleteAdapter DeleteAdapter = nullptr; + PFN_WintunGetAdapterLUID GetAdapterLUID = nullptr; + PFN_WintunStartSession StartSession = nullptr; + PFN_WintunEndSession EndSession = nullptr; + PFN_WintunReceivePacket ReceivePacket = nullptr; + PFN_WintunReleaseReceivePacket ReleaseReceivePacket = nullptr; + PFN_WintunAllocateSendPacket AllocateSendPacket = nullptr; + PFN_WintunSendPacket SendPacket = nullptr; + PFN_WintunGetReadWaitEvent GetReadWaitEvent = nullptr; + + bool load() { + hmod = LoadLibraryExW(L"wintun.dll", nullptr, + LOAD_LIBRARY_SEARCH_APPLICATION_DIR | + LOAD_LIBRARY_SEARCH_SYSTEM32); + if (!hmod) return false; +#define R(fn) fn = (PFN_##fn)GetProcAddress(hmod, #fn); if (!fn) { unload(); return false; } + R(WintunCreateAdapter) + R(WintunDeleteAdapter) + R(WintunGetAdapterLUID) + R(WintunStartSession) + R(WintunEndSession) + R(WintunReceivePacket) + R(WintunReleaseReceivePacket) + R(WintunAllocateSendPacket) + R(WintunSendPacket) + R(WintunGetReadWaitEvent) +#undef R + return true; + } + void unload() { + if (hmod) { FreeLibrary(hmod); hmod = nullptr; } + } +}; + +// ── ClientImpl ──────────────────────────────────────────────────────────────── + +class ClientImpl { +public: + WintunFns wt; + WINTUN_ADAPTER_HANDLE adapter = nullptr; + WINTUN_SESSION_HANDLE session = nullptr; + wireguard_tunnel *wg = nullptr; + SOCKET udpSock = INVALID_SOCKET; + HANDLE stopEvt = nullptr; + NET_LUID luid {}; + std::atomic live {false}; + std::thread tTunToUdp, tUdpToTun, tTicker; + Config cfg; + std::string localIPStr; + LogFn log; + ErrorFn err; +}; + +// ── I/O threads ─────────────────────────────────────────────────────────────── + +static void tunToUdp(ClientImpl *p) { + struct sockaddr_in peer {}; peer.sin_family = AF_INET; + InetPtonA(AF_INET, p->cfg.endpointHost().c_str(), &peer.sin_addr); + peer.sin_port = htons((u_short)p->cfg.endpointPort()); + + HANDLE re = p->wt.GetReadWaitEvent(p->session); + HANDLE evs[2] = { re, p->stopEvt }; + std::vector dst(65536); + + while (p->live) { + if (WaitForMultipleObjects(2, evs, FALSE, INFINITE) != WAIT_OBJECT_0) break; + DWORD sz = 0; + BYTE *pkt = p->wt.ReceivePacket(p->session, &sz); + if (!pkt) continue; + size_t dlen = dst.size(); + wireguard_result r = wireguard_write(p->wg, pkt, sz, dst.data(), &dlen); + p->wt.ReleaseReceivePacket(p->session, pkt); + if (r.op == WRITE_TO_NETWORK && dlen > 0) + sendto(p->udpSock, (char *)dst.data(), (int)dlen, 0, + (sockaddr *)&peer, sizeof(peer)); + } +} + +static void udpToTun(ClientImpl *p) { + struct sockaddr_in peer {}; peer.sin_family = AF_INET; + InetPtonA(AF_INET, p->cfg.endpointHost().c_str(), &peer.sin_addr); + peer.sin_port = htons((u_short)p->cfg.endpointPort()); + + WSAEVENT se = WSACreateEvent(); + WSAEventSelect(p->udpSock, se, FD_READ); + HANDLE evs[2] = { se, p->stopEvt }; + std::vector enc(65536), plain(65536); + + while (p->live) { + if (WaitForMultipleObjects(2, evs, FALSE, INFINITE) != WAIT_OBJECT_0) break; + WSAResetEvent(se); + struct sockaddr_in from {}; int fl = sizeof(from); + int n = recvfrom(p->udpSock, (char *)enc.data(), (int)enc.size(), + 0, (sockaddr *)&from, &fl); + if (n <= 0) continue; + size_t plen = plain.size(); + wireguard_result r = wireguard_read(p->wg, enc.data(), n, plain.data(), &plen); + if ((r.op == WRITE_TO_TUNNEL_IPV4 || r.op == WRITE_TO_TUNNEL_IPV6) && plen > 0) { + BYTE *out = p->wt.AllocateSendPacket(p->session, (DWORD)plen); + if (out) { memcpy(out, plain.data(), plen); p->wt.SendPacket(p->session, out); } + } else if (r.op == WRITE_TO_NETWORK && plen > 0) { + sendto(p->udpSock, (char *)plain.data(), (int)plen, 0, + (sockaddr *)&peer, sizeof(peer)); + } + } + WSACloseEvent(se); +} + +static void ticker(ClientImpl *p) { + struct sockaddr_in peer {}; peer.sin_family = AF_INET; + InetPtonA(AF_INET, p->cfg.endpointHost().c_str(), &peer.sin_addr); + peer.sin_port = htons((u_short)p->cfg.endpointPort()); + std::vector buf(512); + while (p->live) { + if (WaitForSingleObject(p->stopEvt, 100) != WAIT_TIMEOUT) break; + size_t len = buf.size(); + wireguard_result r = wireguard_tick(p->wg, buf.data(), &len); + if (r.op == WRITE_TO_NETWORK && len > 0) + sendto(p->udpSock, (char *)buf.data(), (int)len, 0, + (sockaddr *)&peer, sizeof(peer)); + } +} + +// ── Client public API ───────────────────────────────────────────────────────── + +Client::Client() = default; +Client::~Client() { stop(); } + +void Client::start(const Config &cfg) { + if (m_impl && m_impl->live) return; + + auto p = std::make_unique(); + p->cfg = cfg; + p->log = m_log; + p->err = m_error; + + if (!p->wt.load()) + throw std::runtime_error("Failed to load wintun.dll"); + + p->wg = new_tunnel( + cfg.privateKey().c_str(), + cfg.peerPublicKey().c_str(), + cfg.presharedKey().empty() ? nullptr : cfg.presharedKey().c_str(), + cfg.persistentKeepalive(), 0); + if (!p->wg) throw std::runtime_error("boringtun: tunnel creation failed"); + + GUID guid; CoCreateGuid(&guid); + p->adapter = p->wt.CreateAdapter(L"Artemis", L"WireGuard", &guid); + if (!p->adapter) { tunnel_free(p->wg); p->wt.unload(); + throw std::runtime_error(winErr("WintunCreateAdapter")); } + p->wt.GetAdapterLUID(p->adapter, &p->luid); + + // Configure IP address + { + IN_ADDR addr {}; UINT8 prefix = 24; + parseCIDR(cfg.address(), &addr, &prefix); + MIB_UNICASTIPADDRESS_ROW row {}; InitializeUnicastIpAddressEntry(&row); + row.InterfaceLuid = p->luid; + row.Address.si_family = AF_INET; + row.Address.Ipv4.sin_family = AF_INET; + row.Address.Ipv4.sin_addr = addr; + row.OnLinkPrefixLength = prefix; + row.DadState = IpDadStatePreferred; + CreateUnicastIpAddressEntry(&row); + } + + // Configure routes (AllowedIPs) + for (const auto &cidr : cfg.allowedIPs()) { + IN_ADDR net {}; UINT8 prefix = 24; + if (!parseCIDR(cidr, &net, &prefix)) continue; + MIB_IPFORWARD_ROW2 row {}; InitializeIpForwardEntry(&row); + row.InterfaceLuid = p->luid; + row.DestinationPrefix.Prefix.si_family = AF_INET; + row.DestinationPrefix.Prefix.Ipv4.sin_family = AF_INET; + row.DestinationPrefix.Prefix.Ipv4.sin_addr = net; + row.DestinationPrefix.PrefixLength = prefix; + row.NextHop.si_family = AF_INET; + row.Metric = 1; + row.Protocol = MIB_IPPROTO_NETMGMT; + CreateIpForwardEntry2(&row); + } + + p->session = p->wt.StartSession(p->adapter, 0x400000); + if (!p->session) { + p->wt.DeleteAdapter(p->adapter); tunnel_free(p->wg); p->wt.unload(); + throw std::runtime_error(winErr("WintunStartSession")); + } + + WSADATA wsd {}; WSAStartup(MAKEWORD(2,2), &wsd); + p->udpSock = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP); + if (p->udpSock == INVALID_SOCKET) { + p->wt.EndSession(p->session); p->wt.DeleteAdapter(p->adapter); + tunnel_free(p->wg); p->wt.unload(); + throw std::runtime_error("socket() failed"); + } + sockaddr_in local {}; local.sin_family = AF_INET; local.sin_port = 0; + bind(p->udpSock, (sockaddr *)&local, sizeof(local)); + + p->stopEvt = CreateEvent(nullptr, TRUE, FALSE, nullptr); + p->localIPStr = cfg.addressIP(); + + // Initial handshake + { + std::vector hs(512); size_t len = hs.size(); + wireguard_force_handshake(p->wg, hs.data(), &len); + if (len > 0) { + sockaddr_in peer {}; peer.sin_family = AF_INET; + InetPtonA(AF_INET, cfg.endpointHost().c_str(), &peer.sin_addr); + peer.sin_port = htons((u_short)cfg.endpointPort()); + sendto(p->udpSock, (char *)hs.data(), (int)len, 0, + (sockaddr *)&peer, sizeof(peer)); + } + } + + p->live = true; + p->tTunToUdp = std::thread(tunToUdp, p.get()); + p->tUdpToTun = std::thread(udpToTun, p.get()); + p->tTicker = std::thread(ticker, p.get()); + + m_impl = std::move(p); +} + +void Client::stop() { + if (!m_impl) return; + auto *p = m_impl.get(); + p->live = false; + SetEvent(p->stopEvt); + if (p->tTunToUdp.joinable()) p->tTunToUdp.join(); + if (p->tUdpToTun.joinable()) p->tUdpToTun.join(); + if (p->tTicker.joinable()) p->tTicker.join(); + p->wt.EndSession(p->session); + p->wt.DeleteAdapter(p->adapter); + tunnel_free(p->wg); + closesocket(p->udpSock); + WSACleanup(); + CloseHandle(p->stopEvt); + p->wt.unload(); + m_impl.reset(); +} + +bool Client::running() const { return m_impl && m_impl->live; } +std::string Client::localIP() const { + return (m_impl && m_impl->live) ? m_impl->localIPStr : ""; +} + +} // namespace wg + +#endif // _WIN32