Compare commits

...

46 Commits

Author SHA1 Message Date
Owen
1a43f1ef4b Handle newt online offline with websocket 2026-03-14 11:59:20 -07:00
Owen
75ab074805 Attempt to improve handling bandwidth tracking 2026-03-13 12:06:01 -07:00
Owen
dc4e0253de Add message compression for large messages 2026-03-13 11:46:03 -07:00
Owen
cccf236042 Add optional compression 2026-03-12 17:49:21 -07:00
Owen
63fd63c65c Send less data down 2026-03-12 17:27:15 -07:00
Owen
beee1d692d revert: telemetry comment 2026-03-12 17:11:13 -07:00
Owen
fde786ca84 Add todo 2026-03-12 17:10:46 -07:00
Owen
3086fdd064 Merge branch 'dev' into jit 2026-03-12 16:58:23 -07:00
Owen
6c30f6db31 Dont send site if it missing public key 2026-03-12 16:33:33 -07:00
Owen
f021b73458 Add alert about domain error 2026-03-11 18:00:23 -07:00
Owen
74f4751bcc Dont show raw resource option unless remote node 2026-03-11 17:47:15 -07:00
Owen
e5bce4e180 Merge branch 'main' into dev 2026-03-11 15:55:59 -07:00
Owen
9b0e7b381c Fix error to gerbil 2026-03-11 15:49:03 -07:00
Owen
90afe5a7ac Log errors 2026-03-11 15:42:40 -07:00
Owen
b24de85157 Handle gerbil rejecting 0
Closes #2605
2026-03-11 15:06:26 -07:00
Owen
eda43dffe1 Fix not pulling wildcard cert updates 2026-03-11 15:06:26 -07:00
Owen
82c9a1eb70 Add demo link 2026-03-11 15:06:26 -07:00
Owen
a3d4553d14 Merge branch 'main' into dev 2026-03-11 14:53:55 -07:00
Owen
072c89e704 Bump dompurify 2026-03-10 16:43:40 -07:00
Owen
dbdff6812d Bump esbuild 2026-03-10 16:31:19 -07:00
Owen
cc841d5640 Add some logging to debug 2026-03-10 14:24:57 -07:00
Owen
dec358c4cd Use native drizzle count 2026-03-10 10:03:49 -07:00
Owen
e98f873f81 Clean up 2026-03-09 21:16:37 -07:00
Owen
e9a2a7e752 Reorder delete 2026-03-09 20:46:27 -07:00
Owen
06015d5191 Handle gerbil rejecting 0
Closes #2605
2026-03-09 17:35:25 -07:00
Owen
af688d2a23 Add demo link 2026-03-09 17:35:04 -07:00
Owen
7d0b3ec6b5 Fix not pulling wildcard cert updates 2026-03-09 17:34:48 -07:00
Owen
cf5fb8dc33 Working on jit 2026-03-09 16:36:13 -07:00
Owen Schwartz
9a0a255445 Merge pull request #2524 from shreyaspapi/fix/2294-path-based-routing
fix: path-based routing broken due to key collisions in sanitize()
2026-03-07 21:18:59 -08:00
Owen Schwartz
d5a37436c0 Merge pull request #2616 from LaurenceJJones/fix/issue-240-hcStatus-missing
fix(newt): missing hcStatus in hc config on reconnect
2026-03-07 21:14:27 -08:00
Laurence
be609b5000 Fix missing hcStatus field in health check config on reconnect
The buildTargetConfigurationForNewtClient function was not including the
  hcStatus field when building health check targets for the newt/wg/connect
  message. This caused custom expected response codes (e.g., 409) to revert
  to the default 2xx range check after Pangolin server restart.

  Added hcStatus to both the database select query and the returned health
  check target object, matching the behavior in targets.ts addTargets.
2026-03-07 06:28:10 +00:00
Owen
0503c6e66e Handle JIT for ssh 2026-03-06 15:49:17 -08:00
Owen
9405b0b70a Force jit above site limit 2026-03-06 14:09:57 -08:00
Owen
a26ee4ac1a Adjust billing upgrade language 2026-03-06 12:17:26 -08:00
Owen
2a5c9465e9 Add chainId field passthrough 2026-03-04 22:17:58 -08:00
Owen
f36b66e397 Merge branch 'dev' into jit 2026-03-04 17:58:50 -08:00
Owen
8c6d44677d Update lock 2026-03-04 17:48:58 -08:00
Owen
1bfff630bf Jit working for sites 2026-03-04 17:46:58 -08:00
miloschwartz
ebcef28b05 remove resend from config 2026-03-04 17:45:48 -08:00
miloschwartz
e87e12898c remove resend 2026-03-04 17:45:22 -08:00
miloschwartz
d60ab281cf remove resend from package.json 2026-03-04 17:42:25 -08:00
Owen
c73a39f797 Allow JIT based on site or resource 2026-03-04 15:44:27 -08:00
Shreyas Papinwar
75a909784a fix: simplify path encoding per review — inline utils, use single key scheme
Address PR review comments:
- Remove pathUtils.ts and move sanitize/encodePath directly into utils.ts
- Simplify dual-key approach to single key using encodePath for map keys
- Remove backward-compat logic (not needed per reviewer)
- Update tests to match simplified approach
2026-03-01 15:48:26 +05:30
Shreyas
244f497a9c test: add comprehensive backward compatibility tests for path routing fix 2026-03-01 15:48:26 +05:30
Shreyas
e58f0c9f07 fix: preserve backward-compatible router names while fixing path collisions
Use encodePath only for internal map key grouping (collision-free) and
sanitize for Traefik-facing router/service names (unchanged for existing
users). Extract pure functions into pathUtils.ts so tests can run without
DB dependencies.
2026-03-01 15:48:26 +05:30
Shreyas
5f18c06e03 fix: use collision-free path encoding for Traefik router key generation 2026-03-01 15:48:26 +05:30
55 changed files with 2292 additions and 972 deletions

View File

@@ -2343,8 +2343,8 @@
"logRetentionEndOfFollowingYear": "End of following year",
"actionLogsDescription": "View a history of actions performed in this organization",
"accessLogsDescription": "View access auth requests for resources in this organization",
"licenseRequiredToUse": "An <enterpriseLicenseLink>Enterprise Edition</enterpriseLicenseLink> license or <pangolinCloudLink>Pangolin Cloud</pangolinCloudLink> is required to use this feature.",
"ossEnterpriseEditionRequired": "The <enterpriseEditionLink>Enterprise Edition</enterpriseEditionLink> is required to use this feature. This feature is also available in <pangolinCloudLink>Pangolin Cloud</pangolinCloudLink>.",
"licenseRequiredToUse": "An <enterpriseLicenseLink>Enterprise Edition</enterpriseLicenseLink> license or <pangolinCloudLink>Pangolin Cloud</pangolinCloudLink> is required to use this feature. <bookADemoLink>Book a demo or POC trial</bookADemoLink>.",
"ossEnterpriseEditionRequired": "The <enterpriseEditionLink>Enterprise Edition</enterpriseEditionLink> is required to use this feature. This feature is also available in <pangolinCloudLink>Pangolin Cloud</pangolinCloudLink>. <bookADemoLink>Book a demo or POC trial</bookADemoLink>.",
"certResolver": "Certificate Resolver",
"certResolverDescription": "Select the certificate resolver to use for this resource.",
"selectCertResolver": "Select Certificate Resolver",
@@ -2681,5 +2681,6 @@
"approvalsEmptyStateStep2Title": "Enable Device Approvals",
"approvalsEmptyStateStep2Description": "Edit a role and enable the 'Require Device Approvals' option. Users with this role will need admin approval for new devices.",
"approvalsEmptyStatePreviewDescription": "Preview: When enabled, pending device requests will appear here for review",
"approvalsEmptyStateButtonText": "Manage Roles"
"approvalsEmptyStateButtonText": "Manage Roles",
"domainErrorTitle": "We are having trouble verifying your domain"
}

4
package-lock.json generated
View File

@@ -16773,7 +16773,8 @@
"node_modules/postal-mime": {
"version": "2.7.3",
"resolved": "https://registry.npmjs.org/postal-mime/-/postal-mime-2.7.3.tgz",
"integrity": "sha512-MjhXadAJaWgYzevi46+3kLak8y6gbg0ku14O1gO/LNOuay8dO+1PtcSGvAdgDR0DoIsSaiIA8y/Ddw6MnrO0Tw=="
"integrity": "sha512-MjhXadAJaWgYzevi46+3kLak8y6gbg0ku14O1gO/LNOuay8dO+1PtcSGvAdgDR0DoIsSaiIA8y/Ddw6MnrO0Tw==",
"license": "MIT-0"
},
"node_modules/postcss": {
"version": "8.5.6",
@@ -17814,6 +17815,7 @@
"version": "6.9.2",
"resolved": "https://registry.npmjs.org/resend/-/resend-6.9.2.tgz",
"integrity": "sha512-uIM6CQ08tS+hTCRuKBFbOBvHIGaEhqZe8s4FOgqsVXSbQLAhmNWpmUhG3UAtRnmcwTWFUqnHa/+Vux8YGPyDBA==",
"license": "MIT",
"dependencies": {
"postal-mime": "2.7.3",
"svix": "1.84.1"

View File

@@ -1,6 +1,10 @@
import { flushBandwidthToDb } from "@server/routers/newt/handleReceiveBandwidthMessage";
import { flushSiteBandwidthToDb } from "@server/routers/gerbil/receiveBandwidth";
import { cleanup as wsCleanup } from "#dynamic/routers/ws";
async function cleanup() {
await flushBandwidthToDb();
await flushSiteBandwidthToDb();
await wsCleanup();
process.exit(0);
@@ -10,4 +14,4 @@ export async function initCleanup() {
// Handle process termination
process.on("SIGTERM", () => cleanup());
process.on("SIGINT", () => cleanup());
}
}

View File

@@ -22,7 +22,8 @@ export const domains = pgTable("domains", {
tries: integer("tries").notNull().default(0),
certResolver: varchar("certResolver"),
customCertResolver: varchar("customCertResolver"),
preferWildcardCert: boolean("preferWildcardCert")
preferWildcardCert: boolean("preferWildcardCert"),
errorMessage: text("errorMessage")
});
export const dnsRecords = pgTable("dnsRecords", {
@@ -88,6 +89,7 @@ export const sites = pgTable("sites", {
lastBandwidthUpdate: varchar("lastBandwidthUpdate"),
type: varchar("type").notNull(), // "newt" or "wireguard"
online: boolean("online").notNull().default(false),
lastPing: integer("lastPing"),
address: varchar("address"),
endpoint: varchar("endpoint"),
publicKey: varchar("publicKey"),
@@ -720,6 +722,7 @@ export const clientSitesAssociationsCache = pgTable(
.notNull(),
siteId: integer("siteId").notNull(),
isRelayed: boolean("isRelayed").notNull().default(false),
isJitMode: boolean("isJitMode").notNull().default(false),
endpoint: varchar("endpoint"),
publicKey: varchar("publicKey") // this will act as the session's public key for hole punching so we can track when it changes
}

View File

@@ -13,7 +13,8 @@ export const domains = sqliteTable("domains", {
failed: integer("failed", { mode: "boolean" }).notNull().default(false),
tries: integer("tries").notNull().default(0),
certResolver: text("certResolver"),
preferWildcardCert: integer("preferWildcardCert", { mode: "boolean" })
preferWildcardCert: integer("preferWildcardCert", { mode: "boolean" }),
errorMessage: text("errorMessage")
});
export const dnsRecords = sqliteTable("dnsRecords", {
@@ -89,6 +90,7 @@ export const sites = sqliteTable("sites", {
lastBandwidthUpdate: text("lastBandwidthUpdate"),
type: text("type").notNull(), // "newt" or "wireguard"
online: integer("online", { mode: "boolean" }).notNull().default(false),
lastPing: integer("lastPing"),
// exit node stuff that is how to connect to the site when it has a wg server
address: text("address"), // this is the address of the wireguard interface in newt
@@ -409,6 +411,9 @@ export const clientSitesAssociationsCache = sqliteTable(
isRelayed: integer("isRelayed", { mode: "boolean" })
.notNull()
.default(false),
isJitMode: integer("isJitMode", { mode: "boolean" })
.notNull()
.default(false),
endpoint: text("endpoint"),
publicKey: text("publicKey") // this will act as the session's public key for hole punching so we can track when it changes
}

View File

@@ -107,7 +107,7 @@ export async function applyBlueprint({
[target],
matchingHealthcheck ? [matchingHealthcheck] : [],
result.proxyResource.protocol,
result.proxyResource.proxyPort
site.newt.version
);
}
}

View File

@@ -0,0 +1,20 @@
import semver from "semver";
export function canCompress(
clientVersion: string | null | undefined,
type: "newt" | "olm"
): boolean {
try {
if (!clientVersion) return false;
// check if it is a valid semver
if (!semver.valid(clientVersion)) return false;
if (type === "newt") {
return semver.gte(clientVersion, "1.10.3");
} else if (type === "olm") {
return semver.gte(clientVersion, "1.4.3");
}
return false;
} catch {
return false;
}
}

View File

@@ -477,6 +477,7 @@ async function handleMessagesForSiteClients(
}
if (isAdd) {
// TODO: if we are in jit mode here should we really be sending this?
await initPeerAddHandshake(
// this will kick off the add peer process for the client
client.clientId,
@@ -571,7 +572,7 @@ export async function updateClientSiteDestinations(
destinations: [
{
destinationIP: site.sites.subnet.split("/")[0],
destinationPort: site.sites.listenPort || 0
destinationPort: site.sites.listenPort || 1 // this satisfies gerbil for now but should be reevaluated
}
]
};
@@ -579,7 +580,7 @@ export async function updateClientSiteDestinations(
// add to the existing destinations
destinations.destinations.push({
destinationIP: site.sites.subnet.split("/")[0],
destinationPort: site.sites.listenPort || 0
destinationPort: site.sites.listenPort || 1 // this satisfies gerbil for now but should be reevaluated
});
}
@@ -669,7 +670,11 @@ async function handleSubnetProxyTargetUpdates(
`Adding ${targetsToAdd.length} subnet proxy targets for siteResource ${siteResource.siteResourceId}`
);
proxyJobs.push(
addSubnetProxyTargets(newt.newtId, targetsToAdd)
addSubnetProxyTargets(
newt.newtId,
targetsToAdd,
newt.version
)
);
}
@@ -705,7 +710,11 @@ async function handleSubnetProxyTargetUpdates(
`Removing ${targetsToRemove.length} subnet proxy targets for siteResource ${siteResource.siteResourceId}`
);
proxyJobs.push(
removeSubnetProxyTargets(newt.newtId, targetsToRemove)
removeSubnetProxyTargets(
newt.newtId,
targetsToRemove,
newt.version
)
);
}
@@ -1080,6 +1089,7 @@ async function handleMessagesForClientSites(
continue;
}
// TODO: if we are in jit mode here should we really be sending this?
await initPeerAddHandshake(
// this will kick off the add peer process for the client
client.clientId,
@@ -1146,7 +1156,7 @@ async function handleMessagesForClientResources(
// Add subnet proxy targets for each site
for (const [siteId, resources] of addedBySite.entries()) {
const [newt] = await trx
.select({ newtId: newts.newtId })
.select({ newtId: newts.newtId, version: newts.version })
.from(newts)
.where(eq(newts.siteId, siteId))
.limit(1);
@@ -1168,7 +1178,13 @@ async function handleMessagesForClientResources(
]);
if (targets.length > 0) {
proxyJobs.push(addSubnetProxyTargets(newt.newtId, targets));
proxyJobs.push(
addSubnetProxyTargets(
newt.newtId,
targets,
newt.version
)
);
}
try {
@@ -1217,7 +1233,7 @@ async function handleMessagesForClientResources(
// Remove subnet proxy targets for each site
for (const [siteId, resources] of removedBySite.entries()) {
const [newt] = await trx
.select({ newtId: newts.newtId })
.select({ newtId: newts.newtId, version: newts.version })
.from(newts)
.where(eq(newts.siteId, siteId))
.limit(1);
@@ -1240,7 +1256,11 @@ async function handleMessagesForClientResources(
if (targets.length > 0) {
proxyJobs.push(
removeSubnetProxyTargets(newt.newtId, targets)
removeSubnetProxyTargets(
newt.newtId,
targets,
newt.version
)
);
}

View File

@@ -1,16 +0,0 @@
export enum AudienceIds {
SignUps = "",
Subscribed = "",
Churned = "",
Newsletter = ""
}
let resend;
export default resend;
export async function moveEmailToAudience(
email: string,
audienceId: AudienceIds
) {
return;
}

View File

@@ -218,10 +218,11 @@ export class TraefikConfigManager {
return true;
}
// Fetch if it's been more than 24 hours (for renewals)
const dayInMs = 24 * 60 * 60 * 1000;
const timeSinceLastFetch =
Date.now() - this.lastCertificateFetch.getTime();
// Fetch if it's been more than 24 hours (daily routine check)
if (timeSinceLastFetch > dayInMs) {
logger.info("Fetching certificates due to 24-hour renewal check");
return true;
@@ -265,7 +266,7 @@ export class TraefikConfigManager {
return true;
}
// Check if any local certificates are missing or appear to be outdated
// Check if any local certificates are missing (needs immediate fetch)
for (const domain of domainsNeedingCerts) {
const localState = this.lastLocalCertificateState.get(domain);
if (!localState || !localState.exists) {
@@ -274,17 +275,55 @@ export class TraefikConfigManager {
);
return true;
}
}
// Check if certificate is expiring soon (within 30 days)
if (localState.expiresAt) {
const nowInSeconds = Math.floor(Date.now() / 1000);
const secondsUntilExpiry = localState.expiresAt - nowInSeconds;
const daysUntilExpiry = secondsUntilExpiry / (60 * 60 * 24);
if (daysUntilExpiry < 30) {
logger.info(
`Fetching certificates due to upcoming expiry for ${domain} (${Math.round(daysUntilExpiry)} days remaining)`
);
return true;
// For expiry checks, throttle to every 6 hours to avoid querying the
// API/DB on every monitor loop. The certificate-service renews certs
// 45 days before expiry, so checking every 6 hours is plenty frequent
// to pick up renewed certs promptly.
const renewalCheckIntervalMs = 6 * 60 * 60 * 1000; // 6 hours
if (timeSinceLastFetch > renewalCheckIntervalMs) {
// Check non-wildcard certs for expiry (within 45 days to match
// the server-side renewal window in certificate-service)
for (const domain of domainsNeedingCerts) {
const localState =
this.lastLocalCertificateState.get(domain);
if (localState?.expiresAt) {
const nowInSeconds = Math.floor(Date.now() / 1000);
const secondsUntilExpiry =
localState.expiresAt - nowInSeconds;
const daysUntilExpiry =
secondsUntilExpiry / (60 * 60 * 24);
if (daysUntilExpiry < 45) {
logger.info(
`Fetching certificates due to upcoming expiry for ${domain} (${Math.round(daysUntilExpiry)} days remaining)`
);
return true;
}
}
}
// Also check wildcard certificates for expiry. These are not
// included in domainsNeedingCerts since their subdomains are
// filtered out, so we must check them separately.
for (const [certDomain, state] of this
.lastLocalCertificateState) {
if (
state.exists &&
state.wildcard &&
state.expiresAt
) {
const nowInSeconds = Math.floor(Date.now() / 1000);
const secondsUntilExpiry =
state.expiresAt - nowInSeconds;
const daysUntilExpiry =
secondsUntilExpiry / (60 * 60 * 24);
if (daysUntilExpiry < 45) {
logger.info(
`Fetching certificates due to upcoming expiry for wildcard cert ${certDomain} (${Math.round(daysUntilExpiry)} days remaining)`
);
return true;
}
}
}
}
@@ -361,6 +400,32 @@ export class TraefikConfigManager {
}
}
// Also include wildcard cert base domains that are
// expiring or expired so they get re-fetched even though
// their subdomains were filtered out above.
for (const [certDomain, state] of this
.lastLocalCertificateState) {
if (
state.exists &&
state.wildcard &&
state.expiresAt
) {
const nowInSeconds = Math.floor(
Date.now() / 1000
);
const secondsUntilExpiry =
state.expiresAt - nowInSeconds;
const daysUntilExpiry =
secondsUntilExpiry / (60 * 60 * 24);
if (daysUntilExpiry < 45) {
domainsToFetch.add(certDomain);
logger.info(
`Including expiring wildcard cert domain ${certDomain} in fetch (${Math.round(daysUntilExpiry)} days remaining)`
);
}
}
}
if (domainsToFetch.size > 0) {
// Get valid certificates for domains not covered by wildcards
validCertificates =

View File

@@ -14,7 +14,7 @@ import logger from "@server/logger";
import config from "@server/lib/config";
import { resources, sites, Target, targets } from "@server/db";
import createPathRewriteMiddleware from "./middleware";
import { sanitize, validatePathRewriteConfig } from "./utils";
import { sanitize, encodePath, validatePathRewriteConfig } from "./utils";
const redirectHttpsMiddlewareName = "redirect-to-https";
const badgerMiddlewareName = "badger";
@@ -44,7 +44,7 @@ export async function getTraefikConfig(
filterOutNamespaceDomains = false, // UNUSED BUT USED IN PRIVATE
generateLoginPageRouters = false, // UNUSED BUT USED IN PRIVATE
allowRawResources = true,
allowMaintenancePage = true, // UNUSED BUT USED IN PRIVATE
allowMaintenancePage = true // UNUSED BUT USED IN PRIVATE
): Promise<any> {
// Get resources with their targets and sites in a single optimized query
// Start from sites on this exit node, then join to targets and resources
@@ -127,7 +127,7 @@ export async function getTraefikConfig(
resourcesWithTargetsAndSites.forEach((row) => {
const resourceId = row.resourceId;
const resourceName = sanitize(row.resourceName) || "";
const targetPath = sanitize(row.path) || ""; // Handle null/undefined paths
const targetPath = encodePath(row.path); // Use encodePath to avoid collisions (e.g. "/a/b" vs "/a-b")
const pathMatchType = row.pathMatchType || "";
const rewritePath = row.rewritePath || "";
const rewritePathType = row.rewritePathType || "";
@@ -145,7 +145,7 @@ export async function getTraefikConfig(
const mapKey = [resourceId, pathKey].filter(Boolean).join("-");
const key = sanitize(mapKey);
if (!resourcesMap.has(key)) {
if (!resourcesMap.has(mapKey)) {
const validation = validatePathRewriteConfig(
row.path,
row.pathMatchType,
@@ -160,9 +160,10 @@ export async function getTraefikConfig(
return;
}
resourcesMap.set(key, {
resourcesMap.set(mapKey, {
resourceId: row.resourceId,
name: resourceName,
key: key,
fullDomain: row.fullDomain,
ssl: row.ssl,
http: row.http,
@@ -190,7 +191,7 @@ export async function getTraefikConfig(
});
}
resourcesMap.get(key).targets.push({
resourcesMap.get(mapKey).targets.push({
resourceId: row.resourceId,
targetId: row.targetId,
ip: row.ip,
@@ -227,8 +228,9 @@ export async function getTraefikConfig(
};
// get the key and the resource
for (const [key, resource] of resourcesMap.entries()) {
for (const [, resource] of resourcesMap.entries()) {
const targets = resource.targets as TargetWithSite[];
const key = resource.key;
const routerName = `${key}-${resource.name}-router`;
const serviceName = `${key}-${resource.name}-service`;

View File

@@ -0,0 +1,323 @@
import { assertEquals } from "../../../test/assert";
// ── Pure function copies (inlined to avoid pulling in server dependencies) ──
function sanitize(input: string | null | undefined): string | undefined {
if (!input) return undefined;
if (input.length > 50) {
input = input.substring(0, 50);
}
return input
.replace(/[^a-zA-Z0-9-]/g, "-")
.replace(/-+/g, "-")
.replace(/^-|-$/g, "");
}
function encodePath(path: string | null | undefined): string {
if (!path) return "";
return path.replace(/[^a-zA-Z0-9]/g, (ch) => {
return ch.charCodeAt(0).toString(16);
});
}
// ── Helpers ──────────────────────────────────────────────────────────
/**
* Exact replica of the OLD key computation from upstream main.
* Uses sanitize() for paths — this is what had the collision bug.
*/
function oldKeyComputation(
resourceId: number,
path: string | null,
pathMatchType: string | null,
rewritePath: string | null,
rewritePathType: string | null
): string {
const targetPath = sanitize(path) || "";
const pmt = pathMatchType || "";
const rp = rewritePath || "";
const rpt = rewritePathType || "";
const pathKey = [targetPath, pmt, rp, rpt].filter(Boolean).join("-");
const mapKey = [resourceId, pathKey].filter(Boolean).join("-");
return sanitize(mapKey) || "";
}
/**
* Replica of the NEW key computation from our fix.
* Uses encodePath() for paths — collision-free.
*/
function newKeyComputation(
resourceId: number,
path: string | null,
pathMatchType: string | null,
rewritePath: string | null,
rewritePathType: string | null
): string {
const targetPath = encodePath(path);
const pmt = pathMatchType || "";
const rp = rewritePath || "";
const rpt = rewritePathType || "";
const pathKey = [targetPath, pmt, rp, rpt].filter(Boolean).join("-");
const mapKey = [resourceId, pathKey].filter(Boolean).join("-");
return sanitize(mapKey) || "";
}
// ── Tests ────────────────────────────────────────────────────────────
function runTests() {
console.log("Running path encoding tests...\n");
let passed = 0;
// ── encodePath unit tests ────────────────────────────────────────
// Test 1: null/undefined/empty
{
assertEquals(encodePath(null), "", "null should return empty");
assertEquals(
encodePath(undefined),
"",
"undefined should return empty"
);
assertEquals(encodePath(""), "", "empty string should return empty");
console.log(" PASS: encodePath handles null/undefined/empty");
passed++;
}
// Test 2: root path
{
assertEquals(encodePath("/"), "2f", "/ should encode to 2f");
console.log(" PASS: encodePath encodes root path");
passed++;
}
// Test 3: alphanumeric passthrough
{
assertEquals(encodePath("/api"), "2fapi", "/api encodes slash only");
assertEquals(encodePath("/v1"), "2fv1", "/v1 encodes slash only");
assertEquals(encodePath("abc"), "abc", "plain alpha passes through");
console.log(" PASS: encodePath preserves alphanumeric chars");
passed++;
}
// Test 4: all special chars produce unique hex
{
const paths = ["/a/b", "/a-b", "/a.b", "/a_b", "/a b"];
const results = paths.map((p) => encodePath(p));
const unique = new Set(results);
assertEquals(
unique.size,
paths.length,
"all special-char paths must produce unique encodings"
);
console.log(
" PASS: encodePath produces unique output for different special chars"
);
passed++;
}
// Test 5: output is always alphanumeric (safe for Traefik names)
{
const paths = [
"/",
"/api",
"/a/b",
"/a-b",
"/a.b",
"/complex/path/here"
];
for (const p of paths) {
const e = encodePath(p);
assertEquals(
/^[a-zA-Z0-9]+$/.test(e),
true,
`encodePath("${p}") = "${e}" must be alphanumeric`
);
}
console.log(" PASS: encodePath output is always alphanumeric");
passed++;
}
// Test 6: deterministic
{
assertEquals(
encodePath("/api"),
encodePath("/api"),
"same input same output"
);
assertEquals(
encodePath("/a/b/c"),
encodePath("/a/b/c"),
"same input same output"
);
console.log(" PASS: encodePath is deterministic");
passed++;
}
// Test 7: many distinct paths never collide
{
const paths = [
"/",
"/api",
"/api/v1",
"/api/v2",
"/a/b",
"/a-b",
"/a.b",
"/a_b",
"/health",
"/health/check",
"/admin",
"/admin/users",
"/api/v1/users",
"/api/v1/posts",
"/app",
"/app/dashboard"
];
const encoded = new Set(paths.map((p) => encodePath(p)));
assertEquals(
encoded.size,
paths.length,
`expected ${paths.length} unique encodings, got ${encoded.size}`
);
console.log(" PASS: 16 realistic paths all produce unique encodings");
passed++;
}
// ── Collision fix: the actual bug we're fixing ───────────────────
// Test 8: /a/b and /a-b now have different keys (THE BUG FIX)
{
const keyAB = newKeyComputation(1, "/a/b", "prefix", null, null);
const keyDash = newKeyComputation(1, "/a-b", "prefix", null, null);
assertEquals(
keyAB !== keyDash,
true,
"/a/b and /a-b MUST have different keys"
);
console.log(" PASS: collision fix — /a/b vs /a-b have different keys");
passed++;
}
// Test 9: demonstrate the old bug — old code maps /a/b and /a-b to same key
{
const oldKeyAB = oldKeyComputation(1, "/a/b", "prefix", null, null);
const oldKeyDash = oldKeyComputation(1, "/a-b", "prefix", null, null);
assertEquals(
oldKeyAB,
oldKeyDash,
"old code MUST have this collision (confirms the bug exists)"
);
console.log(" PASS: confirmed old code bug — /a/b and /a-b collided");
passed++;
}
// Test 10: /api/v1 and /api-v1 — old code collision, new code fixes it
{
const oldKey1 = oldKeyComputation(1, "/api/v1", "prefix", null, null);
const oldKey2 = oldKeyComputation(1, "/api-v1", "prefix", null, null);
assertEquals(
oldKey1,
oldKey2,
"old code collision for /api/v1 vs /api-v1"
);
const newKey1 = newKeyComputation(1, "/api/v1", "prefix", null, null);
const newKey2 = newKeyComputation(1, "/api-v1", "prefix", null, null);
assertEquals(
newKey1 !== newKey2,
true,
"new code must separate /api/v1 and /api-v1"
);
console.log(" PASS: collision fix — /api/v1 vs /api-v1");
passed++;
}
// Test 11: /app.v2 and /app/v2 and /app-v2 — three-way collision fixed
{
const a = newKeyComputation(1, "/app.v2", "prefix", null, null);
const b = newKeyComputation(1, "/app/v2", "prefix", null, null);
const c = newKeyComputation(1, "/app-v2", "prefix", null, null);
const keys = new Set([a, b, c]);
assertEquals(
keys.size,
3,
"three paths must produce three unique keys"
);
console.log(
" PASS: collision fix — three-way /app.v2, /app/v2, /app-v2"
);
passed++;
}
// ── Edge cases ───────────────────────────────────────────────────
// Test 12: same path in different resources — always separate
{
const key1 = newKeyComputation(1, "/api", "prefix", null, null);
const key2 = newKeyComputation(2, "/api", "prefix", null, null);
assertEquals(
key1 !== key2,
true,
"different resources with same path must have different keys"
);
console.log(" PASS: edge case — same path, different resources");
passed++;
}
// Test 13: same resource, different pathMatchType — separate keys
{
const exact = newKeyComputation(1, "/api", "exact", null, null);
const prefix = newKeyComputation(1, "/api", "prefix", null, null);
assertEquals(
exact !== prefix,
true,
"exact vs prefix must have different keys"
);
console.log(" PASS: edge case — same path, different match types");
passed++;
}
// Test 14: same resource and path, different rewrite config — separate keys
{
const noRewrite = newKeyComputation(1, "/api", "prefix", null, null);
const withRewrite = newKeyComputation(
1,
"/api",
"prefix",
"/backend",
"prefix"
);
assertEquals(
noRewrite !== withRewrite,
true,
"with vs without rewrite must have different keys"
);
console.log(" PASS: edge case — same path, different rewrite config");
passed++;
}
// Test 15: paths with special URL characters
{
const paths = ["/api?foo", "/api#bar", "/api%20baz", "/api+qux"];
const keys = new Set(
paths.map((p) => newKeyComputation(1, p, "prefix", null, null))
);
assertEquals(
keys.size,
paths.length,
"special URL chars must produce unique keys"
);
console.log(" PASS: edge case — special URL characters in paths");
passed++;
}
console.log(`\nAll ${passed} tests passed!`);
}
try {
runTests();
} catch (error) {
console.error("Test failed:", error);
process.exit(1);
}

View File

@@ -13,6 +13,26 @@ export function sanitize(input: string | null | undefined): string | undefined {
.replace(/^-|-$/g, "");
}
/**
* Encode a URL path into a collision-free alphanumeric string suitable for use
* in Traefik map keys.
*
* Unlike sanitize(), this preserves uniqueness by encoding each non-alphanumeric
* character as its hex code. Different paths always produce different outputs.
*
* encodePath("/api") => "2fapi"
* encodePath("/a/b") => "2fa2fb"
* encodePath("/a-b") => "2fa2db" (different from /a/b)
* encodePath("/") => "2f"
* encodePath(null) => ""
*/
export function encodePath(path: string | null | undefined): string {
if (!path) return "";
return path.replace(/[^a-zA-Z0-9]/g, (ch) => {
return ch.charCodeAt(0).toString(16);
});
}
export function validatePathRewriteConfig(
path: string | null,
pathMatchType: string | null,

View File

@@ -13,8 +13,12 @@
import { rateLimitService } from "#private/lib/rateLimit";
import { cleanup as wsCleanup } from "#private/routers/ws";
import { flushBandwidthToDb } from "@server/routers/newt/handleReceiveBandwidthMessage";
import { flushSiteBandwidthToDb } from "@server/routers/gerbil/receiveBandwidth";
async function cleanup() {
await flushBandwidthToDb();
await flushSiteBandwidthToDb();
await rateLimitService.cleanup();
await wsCleanup();
@@ -25,4 +29,4 @@ export async function initCleanup() {
// Handle process termination
process.on("SIGTERM", () => cleanup());
process.on("SIGINT", () => cleanup());
}
}

View File

@@ -38,10 +38,6 @@ export const privateConfigSchema = z.object({
.string()
.optional()
.transform(getEnvOrYaml("SERVER_ENCRYPTION_KEY")),
resend_api_key: z
.string()
.optional()
.transform(getEnvOrYaml("RESEND_API_KEY")),
reo_client_id: z
.string()
.optional()

View File

@@ -1,127 +0,0 @@
/*
* This file is part of a proprietary work.
*
* Copyright (c) 2025 Fossorial, Inc.
* All rights reserved.
*
* This file is licensed under the Fossorial Commercial License.
* You may not use this file except in compliance with the License.
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
*
* This file is not licensed under the AGPLv3.
*/
import { Resend } from "resend";
import privateConfig from "#private/lib/config";
import logger from "@server/logger";
export enum AudienceIds {
SignUps = "6c4e77b2-0851-4bd6-bac8-f51f91360f1a",
Subscribed = "870b43fd-387f-44de-8fc1-707335f30b20",
Churned = "f3ae92bd-2fdb-4d77-8746-2118afd62549",
Newsletter = "5500c431-191c-42f0-a5d4-8b6d445b4ea0"
}
const resend = new Resend(
privateConfig.getRawPrivateConfig().server.resend_api_key || "missing"
);
export default resend;
export async function moveEmailToAudience(
email: string,
audienceId: AudienceIds
) {
if (process.env.ENVIRONMENT !== "prod") {
logger.debug(
`Skipping moving email ${email} to audience ${audienceId} in non-prod environment`
);
return;
}
const { error, data } = await retryWithBackoff(async () => {
const { data, error } = await resend.contacts.create({
email,
unsubscribed: false,
audienceId
});
if (error) {
throw new Error(
`Error adding email ${email} to audience ${audienceId}: ${error}`
);
}
return { error, data };
});
if (error) {
logger.error(
`Error adding email ${email} to audience ${audienceId}: ${error}`
);
return;
}
if (data) {
logger.debug(
`Added email ${email} to audience ${audienceId} with contact ID ${data.id}`
);
}
const otherAudiences = Object.values(AudienceIds).filter(
(id) => id !== audienceId
);
for (const otherAudienceId of otherAudiences) {
const { error, data } = await retryWithBackoff(async () => {
const { data, error } = await resend.contacts.remove({
email,
audienceId: otherAudienceId
});
if (error) {
throw new Error(
`Error removing email ${email} from audience ${otherAudienceId}: ${error}`
);
}
return { error, data };
});
if (error) {
logger.error(
`Error removing email ${email} from audience ${otherAudienceId}: ${error}`
);
}
if (data) {
logger.info(
`Removed email ${email} from audience ${otherAudienceId}`
);
}
}
}
type RetryOptions = {
retries?: number;
initialDelayMs?: number;
factor?: number;
};
export async function retryWithBackoff<T>(
fn: () => Promise<T>,
options: RetryOptions = {}
): Promise<T> {
const { retries = 5, initialDelayMs = 500, factor = 2 } = options;
let attempt = 0;
let delay = initialDelayMs;
while (true) {
try {
return await fn();
} catch (err) {
attempt++;
if (attempt > retries) throw err;
await new Promise((resolve) => setTimeout(resolve, delay));
delay *= factor;
}
}
}

View File

@@ -34,7 +34,11 @@ import {
import logger from "@server/logger";
import config from "@server/lib/config";
import { orgs, resources, sites, Target, targets } from "@server/db";
import { sanitize, validatePathRewriteConfig } from "@server/lib/traefik/utils";
import {
sanitize,
encodePath,
validatePathRewriteConfig
} from "@server/lib/traefik/utils";
import privateConfig from "#private/lib/config";
import createPathRewriteMiddleware from "@server/lib/traefik/middleware";
import {
@@ -170,7 +174,7 @@ export async function getTraefikConfig(
resourcesWithTargetsAndSites.forEach((row) => {
const resourceId = row.resourceId;
const resourceName = sanitize(row.resourceName) || "";
const targetPath = sanitize(row.path) || ""; // Handle null/undefined paths
const targetPath = encodePath(row.path); // Use encodePath to avoid collisions (e.g. "/a/b" vs "/a-b")
const pathMatchType = row.pathMatchType || "";
const rewritePath = row.rewritePath || "";
const rewritePathType = row.rewritePathType || "";
@@ -192,7 +196,7 @@ export async function getTraefikConfig(
const mapKey = [resourceId, pathKey].filter(Boolean).join("-");
const key = sanitize(mapKey);
if (!resourcesMap.has(key)) {
if (!resourcesMap.has(mapKey)) {
const validation = validatePathRewriteConfig(
row.path,
row.pathMatchType,
@@ -207,9 +211,10 @@ export async function getTraefikConfig(
return;
}
resourcesMap.set(key, {
resourcesMap.set(mapKey, {
resourceId: row.resourceId,
name: resourceName,
key: key,
fullDomain: row.fullDomain,
ssl: row.ssl,
http: row.http,
@@ -243,7 +248,7 @@ export async function getTraefikConfig(
}
// Add target with its associated site data
resourcesMap.get(key).targets.push({
resourcesMap.get(mapKey).targets.push({
resourceId: row.resourceId,
targetId: row.targetId,
ip: row.ip,
@@ -296,8 +301,9 @@ export async function getTraefikConfig(
};
// get the key and the resource
for (const [key, resource] of resourcesMap.entries()) {
for (const [, resource] of resourcesMap.entries()) {
const targets = resource.targets as TargetWithSite[];
const key = resource.key;
const routerName = `${key}-${resource.name}-router`;
const serviceName = `${key}-${resource.name}-service`;

View File

@@ -24,7 +24,6 @@ import { eq, and } from "drizzle-orm";
import logger from "@server/logger";
import stripe from "#private/lib/stripe";
import { handleSubscriptionLifesycle } from "../subscriptionLifecycle";
import { AudienceIds, moveEmailToAudience } from "#private/lib/resend";
import { getSubType } from "./getSubType";
import privateConfig from "#private/lib/config";
import { getLicensePriceSet, LicenseId } from "@server/lib/billing/licenses";
@@ -172,7 +171,7 @@ export async function handleSubscriptionCreated(
const email = orgUserRes.user.email;
if (email) {
moveEmailToAudience(email, AudienceIds.Subscribed);
// TODO: update user in Sendy
}
}
} else if (type === "license") {

View File

@@ -23,7 +23,6 @@ import {
import { eq, and } from "drizzle-orm";
import logger from "@server/logger";
import { handleSubscriptionLifesycle } from "../subscriptionLifecycle";
import { AudienceIds, moveEmailToAudience } from "#private/lib/resend";
import { getSubType } from "./getSubType";
import stripe from "#private/lib/stripe";
import privateConfig from "#private/lib/config";
@@ -109,7 +108,7 @@ export async function handleSubscriptionDeleted(
const email = orgUserRes.user.email;
if (email) {
moveEmailToAudience(email, AudienceIds.Churned);
// TODO: update user in Sendy
}
}
} else if (type === "license") {

View File

@@ -29,7 +29,6 @@ import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import { OpenAPITags, registry } from "@server/openApi";
import { eq, or, and } from "drizzle-orm";
import { canUserAccessSiteResource } from "@server/auth/canUserAccessSiteResource";
import { signPublicKey, getOrgCAKeys } from "@server/lib/sshCA";
@@ -64,6 +63,7 @@ export type SignSshKeyResponse = {
sshUsername: string;
sshHost: string;
resourceId: number;
siteId: number;
keyId: string;
validPrincipals: string[];
validAfter: string;
@@ -453,6 +453,7 @@ export async function signSshKey(
sshUsername: usernameToUse,
sshHost: sshHost,
resourceId: resource.siteResourceId,
siteId: resource.siteId,
keyId: cert.keyId,
validPrincipals: cert.validPrincipals,
validAfter: cert.validAfter.toISOString(),

View File

@@ -12,6 +12,7 @@
*/
import { Router, Request, Response } from "express";
import zlib from "zlib";
import { Server as HttpServer } from "http";
import { WebSocket, WebSocketServer } from "ws";
import { Socket } from "net";
@@ -24,7 +25,8 @@ import {
OlmSession,
RemoteExitNode,
RemoteExitNodeSession,
remoteExitNodes
remoteExitNodes,
sites
} from "@server/db";
import { eq } from "drizzle-orm";
import { db } from "@server/db";
@@ -57,11 +59,13 @@ const MAX_PENDING_MESSAGES = 50; // Maximum messages to queue during connection
const processMessage = async (
ws: AuthenticatedWebSocket,
data: Buffer,
isBinary: boolean,
clientId: string,
clientType: ClientType
): Promise<void> => {
try {
const message: WSMessage = JSON.parse(data.toString());
const messageBuffer = isBinary ? zlib.gunzipSync(data) : data;
const message: WSMessage = JSON.parse(messageBuffer.toString());
// logger.debug(
// `Processing message from ${clientType.toUpperCase()} ID: ${clientId}, type: ${message.type}`
@@ -76,7 +80,7 @@ const processMessage = async (
clientId,
message.type, // Pass message type for granular limiting
100, // max requests per window
20, // max requests per message type per window
100, // max requests per message type per window
60 * 1000 // window in milliseconds
);
if (rateLimitResult.isLimited) {
@@ -163,8 +167,16 @@ const processPendingMessages = async (
);
const jobs = [];
for (const messageData of ws.pendingMessages) {
jobs.push(processMessage(ws, messageData, clientId, clientType));
for (const pending of ws.pendingMessages) {
jobs.push(
processMessage(
ws,
pending.data,
pending.isBinary,
clientId,
clientType
)
);
}
await Promise.all(jobs);
@@ -325,7 +337,9 @@ const addClient = async (
// Check Redis first if enabled
if (redisManager.isRedisEnabled()) {
try {
const redisVersion = await redisManager.get(getConfigVersionKey(clientId));
const redisVersion = await redisManager.get(
getConfigVersionKey(clientId)
);
if (redisVersion !== null) {
configVersion = parseInt(redisVersion, 10);
// Sync to local cache
@@ -337,7 +351,10 @@ const addClient = async (
} else {
// Use local cache version and sync to Redis
configVersion = clientConfigVersions.get(clientId) || 0;
await redisManager.set(getConfigVersionKey(clientId), configVersion.toString());
await redisManager.set(
getConfigVersionKey(clientId),
configVersion.toString()
);
}
} catch (error) {
logger.error("Failed to get/set config version in Redis:", error);
@@ -432,7 +449,9 @@ const removeClient = async (
};
// Helper to get the current config version for a client
const getClientConfigVersion = async (clientId: string): Promise<number | undefined> => {
const getClientConfigVersion = async (
clientId: string
): Promise<number | undefined> => {
// Try Redis first if available
if (redisManager.isRedisEnabled()) {
try {
@@ -502,11 +521,26 @@ const sendToClientLocal = async (
};
const messageString = JSON.stringify(messageWithVersion);
clients.forEach((client) => {
if (client.readyState === WebSocket.OPEN) {
client.send(messageString);
}
});
if (options.compress) {
logger.debug(
`Message size before compression: ${messageString.length} bytes`
);
const compressed = zlib.gzipSync(Buffer.from(messageString, "utf8"));
logger.debug(
`Message size after compression: ${compressed.length} bytes`
);
clients.forEach((client) => {
if (client.readyState === WebSocket.OPEN) {
client.send(compressed);
}
});
} else {
clients.forEach((client) => {
if (client.readyState === WebSocket.OPEN) {
client.send(messageString);
}
});
}
return true;
};
@@ -532,11 +566,22 @@ const broadcastToAllExceptLocal = async (
configVersion
};
clients.forEach((client) => {
if (client.readyState === WebSocket.OPEN) {
client.send(JSON.stringify(messageWithVersion));
}
});
if (options.compress) {
const compressed = zlib.gzipSync(
Buffer.from(JSON.stringify(messageWithVersion), "utf8")
);
clients.forEach((client) => {
if (client.readyState === WebSocket.OPEN) {
client.send(compressed);
}
});
} else {
clients.forEach((client) => {
if (client.readyState === WebSocket.OPEN) {
client.send(JSON.stringify(messageWithVersion));
}
});
}
}
}
};
@@ -762,7 +807,7 @@ const setupConnection = async (
}
// Set up message handler FIRST to prevent race condition
ws.on("message", async (data) => {
ws.on("message", async (data, isBinary) => {
if (!ws.isFullyConnected) {
// Queue message for later processing with limits
ws.pendingMessages = ws.pendingMessages || [];
@@ -777,11 +822,17 @@ const setupConnection = async (
logger.debug(
`Queueing message from ${clientType.toUpperCase()} ID: ${clientId} (connection not fully established)`
);
ws.pendingMessages.push(data as Buffer);
ws.pendingMessages.push({ data: data as Buffer, isBinary });
return;
}
await processMessage(ws, data as Buffer, clientId, clientType);
await processMessage(
ws,
data as Buffer,
isBinary,
clientId,
clientType
);
});
// Set up other event handlers before async operations
@@ -796,6 +847,31 @@ const setupConnection = async (
);
});
// Handle WebSocket protocol-level pings from older newt clients that do
// not send application-level "newt/ping" messages. Update the site's
// online state and lastPing timestamp so the offline checker treats them
// the same as modern newt clients.
if (clientType === "newt") {
const newtClient = client as Newt;
ws.on("ping", async () => {
if (!newtClient.siteId) return;
try {
await db
.update(sites)
.set({
online: true,
lastPing: Math.floor(Date.now() / 1000)
})
.where(eq(sites.siteId, newtClient.siteId));
} catch (error) {
logger.error(
"Error updating newt site online state on WS ping",
{ error }
);
}
});
}
ws.on("error", (error: Error) => {
logger.error(
`WebSocket error for ${clientType.toUpperCase()} ID ${clientId}:`,

View File

@@ -22,7 +22,6 @@ import { checkValidInvite } from "@server/auth/checkValidInvite";
import { passwordSchema } from "@server/auth/passwordSchema";
import { UserType } from "@server/types/UserTypes";
import { build } from "@server/build";
import resend, { AudienceIds, moveEmailToAudience } from "#dynamic/lib/resend";
export const signupBodySchema = z.object({
email: z.email().toLowerCase(),
@@ -237,7 +236,7 @@ export async function signup(
logger.debug(
`User ${email} opted in to marketing emails during signup.`
);
moveEmailToAudience(email, AudienceIds.SignUps);
// TODO: update user in Sendy
}
if (config.getRawConfig().flags?.require_email_verification) {

View File

@@ -1,51 +1,38 @@
import { sendToClient } from "#dynamic/routers/ws";
import { db, olms, Transaction } from "@server/db";
import { canCompress } from "@server/lib/clientVersionChecks";
import { Alias, SubnetProxyTarget } from "@server/lib/ip";
import logger from "@server/logger";
import { eq } from "drizzle-orm";
const BATCH_SIZE = 50;
const BATCH_DELAY_MS = 50;
function sleep(ms: number): Promise<void> {
return new Promise((resolve) => setTimeout(resolve, ms));
}
function chunkArray<T>(array: T[], size: number): T[][] {
const chunks: T[][] = [];
for (let i = 0; i < array.length; i += size) {
chunks.push(array.slice(i, i + size));
}
return chunks;
}
export async function addTargets(newtId: string, targets: SubnetProxyTarget[]) {
const batches = chunkArray(targets, BATCH_SIZE);
for (let i = 0; i < batches.length; i++) {
if (i > 0) {
await sleep(BATCH_DELAY_MS);
}
await sendToClient(newtId, {
export async function addTargets(
newtId: string,
targets: SubnetProxyTarget[],
version?: string | null
) {
await sendToClient(
newtId,
{
type: `newt/wg/targets/add`,
data: batches[i]
}, { incrementConfigVersion: true });
}
data: targets
},
{ incrementConfigVersion: true, compress: canCompress(version, "newt") }
);
}
export async function removeTargets(
newtId: string,
targets: SubnetProxyTarget[]
targets: SubnetProxyTarget[],
version?: string | null
) {
const batches = chunkArray(targets, BATCH_SIZE);
for (let i = 0; i < batches.length; i++) {
if (i > 0) {
await sleep(BATCH_DELAY_MS);
}
await sendToClient(newtId, {
await sendToClient(
newtId,
{
type: `newt/wg/targets/remove`,
data: batches[i]
},{ incrementConfigVersion: true });
}
data: targets
},
{ incrementConfigVersion: true, compress: canCompress(version, "newt") }
);
}
export async function updateTargets(
@@ -53,26 +40,22 @@ export async function updateTargets(
targets: {
oldTargets: SubnetProxyTarget[];
newTargets: SubnetProxyTarget[];
}
},
version?: string | null
) {
const oldBatches = chunkArray(targets.oldTargets, BATCH_SIZE);
const newBatches = chunkArray(targets.newTargets, BATCH_SIZE);
const maxBatches = Math.max(oldBatches.length, newBatches.length);
for (let i = 0; i < maxBatches; i++) {
if (i > 0) {
await sleep(BATCH_DELAY_MS);
}
await sendToClient(newtId, {
await sendToClient(
newtId,
{
type: `newt/wg/targets/update`,
data: {
oldTargets: oldBatches[i] || [],
newTargets: newBatches[i] || []
oldTargets: targets.oldTargets,
newTargets: targets.newTargets
}
}, { incrementConfigVersion: true }).catch((error) => {
logger.warn(`Error sending message:`, error);
});
}
},
{ incrementConfigVersion: true, compress: canCompress(version, "newt") }
).catch((error) => {
logger.warn(`Error sending message:`, error);
});
}
export async function addPeerData(
@@ -80,7 +63,8 @@ export async function addPeerData(
siteId: number,
remoteSubnets: string[],
aliases: Alias[],
olmId?: string
olmId?: string,
version?: string | null
) {
if (!olmId) {
const [olm] = await db
@@ -92,16 +76,21 @@ export async function addPeerData(
return; // ignore this because an olm might not be associated with the client anymore
}
olmId = olm.olmId;
version = olm.version;
}
await sendToClient(olmId, {
type: `olm/wg/peer/data/add`,
data: {
siteId: siteId,
remoteSubnets: remoteSubnets,
aliases: aliases
}
}, { incrementConfigVersion: true }).catch((error) => {
await sendToClient(
olmId,
{
type: `olm/wg/peer/data/add`,
data: {
siteId: siteId,
remoteSubnets: remoteSubnets,
aliases: aliases
}
},
{ incrementConfigVersion: true, compress: canCompress(version, "olm") }
).catch((error) => {
logger.warn(`Error sending message:`, error);
});
}
@@ -111,7 +100,8 @@ export async function removePeerData(
siteId: number,
remoteSubnets: string[],
aliases: Alias[],
olmId?: string
olmId?: string,
version?: string | null
) {
if (!olmId) {
const [olm] = await db
@@ -123,16 +113,21 @@ export async function removePeerData(
return;
}
olmId = olm.olmId;
version = olm.version;
}
await sendToClient(olmId, {
type: `olm/wg/peer/data/remove`,
data: {
siteId: siteId,
remoteSubnets: remoteSubnets,
aliases: aliases
}
}, { incrementConfigVersion: true }).catch((error) => {
await sendToClient(
olmId,
{
type: `olm/wg/peer/data/remove`,
data: {
siteId: siteId,
remoteSubnets: remoteSubnets,
aliases: aliases
}
},
{ incrementConfigVersion: true, compress: canCompress(version, "olm") }
).catch((error) => {
logger.warn(`Error sending message:`, error);
});
}
@@ -152,7 +147,8 @@ export async function updatePeerData(
newAliases: Alias[];
}
| undefined,
olmId?: string
olmId?: string,
version?: string | null
) {
if (!olmId) {
const [olm] = await db
@@ -164,16 +160,21 @@ export async function updatePeerData(
return;
}
olmId = olm.olmId;
version = olm.version;
}
await sendToClient(olmId, {
type: `olm/wg/peer/data/update`,
data: {
siteId: siteId,
...remoteSubnets,
...aliases
}
}, { incrementConfigVersion: true }).catch((error) => {
await sendToClient(
olmId,
{
type: `olm/wg/peer/data/update`,
data: {
siteId: siteId,
...remoteSubnets,
...aliases
}
},
{ incrementConfigVersion: true, compress: canCompress(version, "olm") }
).catch((error) => {
logger.warn(`Error sending message:`, error);
});
}

View File

@@ -40,7 +40,8 @@ async function queryDomains(orgId: string, limit: number, offset: number) {
tries: domains.tries,
configManaged: domains.configManaged,
certResolver: domains.certResolver,
preferWildcardCert: domains.preferWildcardCert
preferWildcardCert: domains.preferWildcardCert,
errorMessage: domains.errorMessage
})
.from(orgDomains)
.where(eq(orgDomains.orgId, orgId))

View File

@@ -125,7 +125,7 @@ export async function generateRelayMappings(exitNode: ExitNode) {
// Add site as a destination for this client
const destination: PeerDestination = {
destinationIP: site.subnet.split("/")[0],
destinationPort: site.listenPort
destinationPort: site.listenPort || 1 // this satisfies gerbil for now but should be reevaluated
};
// Check if this destination is already in the array to avoid duplicates
@@ -165,7 +165,7 @@ export async function generateRelayMappings(exitNode: ExitNode) {
const destination: PeerDestination = {
destinationIP: peer.subnet.split("/")[0],
destinationPort: peer.listenPort
destinationPort: peer.listenPort || 1 // this satisfies gerbil for now but should be reevaluated
};
// Check for duplicates

View File

@@ -1,5 +1,5 @@
import { Request, Response, NextFunction } from "express";
import { eq, and, lt, inArray, sql } from "drizzle-orm";
import { eq, sql } from "drizzle-orm";
import { sites } from "@server/db";
import { db } from "@server/db";
import logger from "@server/logger";
@@ -11,19 +11,31 @@ import { FeatureId } from "@server/lib/billing/features";
import { checkExitNodeOrg } from "#dynamic/lib/exitNodes";
import { build } from "@server/build";
// Track sites that are already offline to avoid unnecessary queries
const offlineSites = new Set<string>();
// Retry configuration for deadlock handling
const MAX_RETRIES = 3;
const BASE_DELAY_MS = 50;
interface PeerBandwidth {
publicKey: string;
bytesIn: number;
bytesOut: number;
}
interface AccumulatorEntry {
bytesIn: number;
bytesOut: number;
/** Present when the update came through a remote exit node. */
exitNodeId?: number;
/** Whether to record egress usage for billing purposes. */
calcUsage: boolean;
}
// Retry configuration for deadlock handling
const MAX_RETRIES = 3;
const BASE_DELAY_MS = 50;
// How often to flush accumulated bandwidth data to the database
const FLUSH_INTERVAL_MS = 30_000; // 30 seconds
// In-memory accumulator: publicKey -> AccumulatorEntry
let accumulator = new Map<string, AccumulatorEntry>();
/**
* Check if an error is a deadlock error
*/
@@ -63,6 +75,220 @@ async function withDeadlockRetry<T>(
}
}
/**
* Flush all accumulated site bandwidth data to the database.
*
* Swaps out the accumulator before writing so that any bandwidth messages
* received during the flush are captured in the new accumulator rather than
* being lost or causing contention. Entries that fail to write are re-queued
* back into the accumulator so they will be retried on the next flush.
*
* This function is exported so that the application's graceful-shutdown
* cleanup handler can call it before the process exits.
*/
export async function flushSiteBandwidthToDb(): Promise<void> {
if (accumulator.size === 0) {
return;
}
// Atomically swap out the accumulator so new data keeps flowing in
// while we write the snapshot to the database.
const snapshot = accumulator;
accumulator = new Map<string, AccumulatorEntry>();
const currentTime = new Date().toISOString();
// Sort by publicKey for consistent lock ordering across concurrent
// writers — deadlock-prevention strategy.
const sortedEntries = [...snapshot.entries()].sort(([a], [b]) =>
a.localeCompare(b)
);
logger.debug(
`Flushing accumulated bandwidth data for ${sortedEntries.length} site(s) to the database`
);
// Aggregate billing usage by org, collected during the DB update loop.
const orgUsageMap = new Map<string, number>();
for (const [publicKey, { bytesIn, bytesOut, exitNodeId, calcUsage }] of sortedEntries) {
try {
const updatedSite = await withDeadlockRetry(async () => {
const [result] = await db
.update(sites)
.set({
megabytesOut: sql`COALESCE(${sites.megabytesOut}, 0) + ${bytesIn}`,
megabytesIn: sql`COALESCE(${sites.megabytesIn}, 0) + ${bytesOut}`,
lastBandwidthUpdate: currentTime
})
.where(eq(sites.pubKey, publicKey))
.returning({
orgId: sites.orgId,
siteId: sites.siteId
});
return result;
}, `flush bandwidth for site ${publicKey}`);
if (updatedSite) {
if (exitNodeId) {
const notAllowed = await checkExitNodeOrg(
exitNodeId,
updatedSite.orgId
);
if (notAllowed) {
logger.warn(
`Exit node ${exitNodeId} is not allowed for org ${updatedSite.orgId}`
);
// Skip usage tracking for this site but continue
// processing the rest.
continue;
}
}
if (calcUsage) {
const totalBandwidth = bytesIn + bytesOut;
const current = orgUsageMap.get(updatedSite.orgId) ?? 0;
orgUsageMap.set(updatedSite.orgId, current + totalBandwidth);
}
}
} catch (error) {
logger.error(
`Failed to flush bandwidth for site ${publicKey}:`,
error
);
// Re-queue the failed entry so it is retried on the next flush
// rather than silently dropped.
const existing = accumulator.get(publicKey);
if (existing) {
existing.bytesIn += bytesIn;
existing.bytesOut += bytesOut;
} else {
accumulator.set(publicKey, {
bytesIn,
bytesOut,
exitNodeId,
calcUsage
});
}
}
}
// Process billing usage updates outside the site-update loop to keep
// lock scope small and concerns separated.
if (orgUsageMap.size > 0) {
// Sort org IDs for consistent lock ordering.
const sortedOrgIds = [...orgUsageMap.keys()].sort();
for (const orgId of sortedOrgIds) {
try {
const totalBandwidth = orgUsageMap.get(orgId)!;
const bandwidthUsage = await usageService.add(
orgId,
FeatureId.EGRESS_DATA_MB,
totalBandwidth
);
if (bandwidthUsage) {
// Fire-and-forget — don't block the flush on limit checking.
usageService
.checkLimitSet(
orgId,
FeatureId.EGRESS_DATA_MB,
bandwidthUsage
)
.catch((error: any) => {
logger.error(
`Error checking bandwidth limits for org ${orgId}:`,
error
);
});
}
} catch (error) {
logger.error(
`Error processing usage for org ${orgId}:`,
error
);
// Continue with other orgs.
}
}
}
}
// ---------------------------------------------------------------------------
// Periodic flush timer
// ---------------------------------------------------------------------------
const flushTimer = setInterval(async () => {
try {
await flushSiteBandwidthToDb();
} catch (error) {
logger.error(
"Unexpected error during periodic site bandwidth flush:",
error
);
}
}, FLUSH_INTERVAL_MS);
// Allow the process to exit normally even while the timer is pending.
// The graceful-shutdown path (see server/cleanup.ts) will call
// flushSiteBandwidthToDb() explicitly before process.exit(), so no data
// is lost.
flushTimer.unref();
// ---------------------------------------------------------------------------
// Public API
// ---------------------------------------------------------------------------
/**
* Accumulate bandwidth data reported by a gerbil or remote exit node.
*
* Only peers that actually transferred data (bytesIn > 0) are added to the
* accumulator; peers with no activity are silently ignored, which means the
* flush will only write rows that have genuinely changed.
*
* The function is intentionally synchronous in its fast path so that the
* HTTP handler can respond immediately without waiting for any I/O.
*/
export async function updateSiteBandwidth(
bandwidthData: PeerBandwidth[],
calcUsageAndLimits: boolean,
exitNodeId?: number
): Promise<void> {
for (const { publicKey, bytesIn, bytesOut } of bandwidthData) {
// Skip peers that haven't transferred any data — writing zeros to the
// database would be a no-op anyway.
if (bytesIn <= 0 && bytesOut <= 0) {
continue;
}
const existing = accumulator.get(publicKey);
if (existing) {
existing.bytesIn += bytesIn;
existing.bytesOut += bytesOut;
// Retain the most-recent exitNodeId for this peer.
if (exitNodeId !== undefined) {
existing.exitNodeId = exitNodeId;
}
// Once calcUsage has been requested for a peer, keep it set for
// the lifetime of this flush window.
if (calcUsageAndLimits) {
existing.calcUsage = true;
}
} else {
accumulator.set(publicKey, {
bytesIn,
bytesOut,
exitNodeId,
calcUsage: calcUsageAndLimits
});
}
}
}
// ---------------------------------------------------------------------------
// HTTP handler
// ---------------------------------------------------------------------------
export const receiveBandwidth = async (
req: Request,
res: Response,
@@ -75,7 +301,9 @@ export const receiveBandwidth = async (
throw new Error("Invalid bandwidth data");
}
await updateSiteBandwidth(bandwidthData, build == "saas"); // we are checking the usage on saas only
// Accumulate in memory; the periodic timer (and the shutdown hook)
// will write to the database.
await updateSiteBandwidth(bandwidthData, build == "saas");
return response(res, {
data: {},
@@ -93,202 +321,4 @@ export const receiveBandwidth = async (
)
);
}
};
export async function updateSiteBandwidth(
bandwidthData: PeerBandwidth[],
calcUsageAndLimits: boolean,
exitNodeId?: number
) {
const currentTime = new Date();
const oneMinuteAgo = new Date(currentTime.getTime() - 60000); // 1 minute ago
// Sort bandwidth data by publicKey to ensure consistent lock ordering across all instances
// This is critical for preventing deadlocks when multiple instances update the same sites
const sortedBandwidthData = [...bandwidthData].sort((a, b) =>
a.publicKey.localeCompare(b.publicKey)
);
// First, handle sites that are actively reporting bandwidth
const activePeers = sortedBandwidthData.filter((peer) => peer.bytesIn > 0);
// Aggregate usage data by organization (collected outside transaction)
const orgUsageMap = new Map<string, number>();
if (activePeers.length > 0) {
// Remove any active peers from offline tracking since they're sending data
activePeers.forEach((peer) => offlineSites.delete(peer.publicKey));
// Update each active site individually with retry logic
// This reduces transaction scope and allows retries per-site
for (const peer of activePeers) {
try {
const updatedSite = await withDeadlockRetry(async () => {
const [result] = await db
.update(sites)
.set({
megabytesOut: sql`${sites.megabytesOut} + ${peer.bytesIn}`,
megabytesIn: sql`${sites.megabytesIn} + ${peer.bytesOut}`,
lastBandwidthUpdate: currentTime.toISOString(),
online: true
})
.where(eq(sites.pubKey, peer.publicKey))
.returning({
online: sites.online,
orgId: sites.orgId,
siteId: sites.siteId,
lastBandwidthUpdate: sites.lastBandwidthUpdate
});
return result;
}, `update active site ${peer.publicKey}`);
if (updatedSite) {
if (exitNodeId) {
const notAllowed = await checkExitNodeOrg(
exitNodeId,
updatedSite.orgId
);
if (notAllowed) {
logger.warn(
`Exit node ${exitNodeId} is not allowed for org ${updatedSite.orgId}`
);
// Skip this site but continue processing others
continue;
}
}
// Aggregate bandwidth usage for the org
const totalBandwidth = peer.bytesIn + peer.bytesOut;
const currentOrgUsage =
orgUsageMap.get(updatedSite.orgId) || 0;
orgUsageMap.set(
updatedSite.orgId,
currentOrgUsage + totalBandwidth
);
}
} catch (error) {
logger.error(
`Failed to update bandwidth for site ${peer.publicKey}:`,
error
);
// Continue with other sites
}
}
}
// Process usage updates outside of site update transactions
// This separates the concerns and reduces lock contention
if (calcUsageAndLimits && orgUsageMap.size > 0) {
// Sort org IDs to ensure consistent lock ordering
const allOrgIds = [...new Set([...orgUsageMap.keys()])].sort();
for (const orgId of allOrgIds) {
try {
// Process bandwidth usage for this org
const totalBandwidth = orgUsageMap.get(orgId);
if (totalBandwidth) {
const bandwidthUsage = await usageService.add(
orgId,
FeatureId.EGRESS_DATA_MB,
totalBandwidth
);
if (bandwidthUsage) {
// Fire and forget - don't block on limit checking
usageService
.checkLimitSet(
orgId,
FeatureId.EGRESS_DATA_MB,
bandwidthUsage
)
.catch((error: any) => {
logger.error(
`Error checking bandwidth limits for org ${orgId}:`,
error
);
});
}
}
} catch (error) {
logger.error(`Error processing usage for org ${orgId}:`, error);
// Continue with other orgs
}
}
}
// Handle sites that reported zero bandwidth but need online status updated
const zeroBandwidthPeers = sortedBandwidthData.filter(
(peer) => peer.bytesIn === 0 && !offlineSites.has(peer.publicKey)
);
if (zeroBandwidthPeers.length > 0) {
// Fetch all zero bandwidth sites in one query
const zeroBandwidthSites = await db
.select()
.from(sites)
.where(
inArray(
sites.pubKey,
zeroBandwidthPeers.map((p) => p.publicKey)
)
);
// Sort by siteId to ensure consistent lock ordering
const sortedZeroBandwidthSites = zeroBandwidthSites.sort(
(a, b) => a.siteId - b.siteId
);
for (const site of sortedZeroBandwidthSites) {
let newOnlineStatus = site.online;
// Check if site should go offline based on last bandwidth update WITH DATA
if (site.lastBandwidthUpdate) {
const lastUpdateWithData = new Date(site.lastBandwidthUpdate);
if (lastUpdateWithData < oneMinuteAgo) {
newOnlineStatus = false;
}
} else {
// No previous data update recorded, set to offline
newOnlineStatus = false;
}
// Only update online status if it changed
if (site.online !== newOnlineStatus) {
try {
const updatedSite = await withDeadlockRetry(async () => {
const [result] = await db
.update(sites)
.set({
online: newOnlineStatus
})
.where(eq(sites.siteId, site.siteId))
.returning();
return result;
}, `update offline status for site ${site.siteId}`);
if (updatedSite && exitNodeId) {
const notAllowed = await checkExitNodeOrg(
exitNodeId,
updatedSite.orgId
);
if (notAllowed) {
logger.warn(
`Exit node ${exitNodeId} is not allowed for org ${updatedSite.orgId}`
);
}
}
// If site went offline, add it to our tracking set
if (!newOnlineStatus && site.pubKey) {
offlineSites.add(site.pubKey);
}
} catch (error) {
logger.error(
`Failed to update offline status for site ${site.siteId}:`,
error
);
// Continue with other sites
}
}
}
}
}
};

View File

@@ -112,7 +112,7 @@ export async function updateHolePunch(
destinations: destinations
});
} catch (error) {
// logger.error(error); // FIX THIS
logger.error(error);
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
@@ -262,7 +262,7 @@ export async function updateAndGenerateEndpointDestinations(
if (site.subnet && site.listenPort) {
destinations.push({
destinationIP: site.subnet.split("/")[0],
destinationPort: site.listenPort
destinationPort: site.listenPort || 1 // this satisfies gerbil for now but should be reevaluated
});
}
}
@@ -339,10 +339,10 @@ export async function updateAndGenerateEndpointDestinations(
handleSiteEndpointChange(newt.siteId, updatedSite.endpoint!);
}
if (!updatedSite || !updatedSite.subnet) {
logger.warn(`Site not found: ${newt.siteId}`);
throw new Error("Site not found");
}
// if (!updatedSite || !updatedSite.subnet) {
// logger.warn(`Site not found: ${newt.siteId}`);
// throw new Error("Site not found");
// }
// Find all clients that connect to this site
// const sitesClientPairs = await db

View File

@@ -1,4 +1,15 @@
import { clients, clientSiteResourcesAssociationsCache, clientSitesAssociationsCache, db, ExitNode, resources, Site, siteResources, targetHealthCheck, targets } from "@server/db";
import {
clients,
clientSiteResourcesAssociationsCache,
clientSitesAssociationsCache,
db,
ExitNode,
resources,
Site,
siteResources,
targetHealthCheck,
targets
} from "@server/db";
import logger from "@server/logger";
import { initPeerAddHandshake, updatePeer } from "../olm/peers";
import { eq, and } from "drizzle-orm";
@@ -69,40 +80,42 @@ export async function buildClientConfigurationForNewtClient(
// )
// );
// update the peer info on the olm
// if the peer has not been added yet this will be a no-op
await updatePeer(client.clients.clientId, {
siteId: site.siteId,
endpoint: site.endpoint!,
relayEndpoint: `${exitNode.endpoint}:${config.getRawConfig().gerbil.clients_start_port}`,
publicKey: site.publicKey!,
serverIP: site.address,
serverPort: site.listenPort
// remoteSubnets: generateRemoteSubnets(
// allSiteResources.map(
// ({ siteResources }) => siteResources
// )
// ),
// aliases: generateAliasConfig(
// allSiteResources.map(
// ({ siteResources }) => siteResources
// )
// )
});
if (!client.clientSitesAssociationsCache.isJitMode) { // if we are adding sites through jit then dont add the site to the olm
// update the peer info on the olm
// if the peer has not been added yet this will be a no-op
await updatePeer(client.clients.clientId, {
siteId: site.siteId,
endpoint: site.endpoint!,
relayEndpoint: `${exitNode.endpoint}:${config.getRawConfig().gerbil.clients_start_port}`,
publicKey: site.publicKey!,
serverIP: site.address,
serverPort: site.listenPort
// remoteSubnets: generateRemoteSubnets(
// allSiteResources.map(
// ({ siteResources }) => siteResources
// )
// ),
// aliases: generateAliasConfig(
// allSiteResources.map(
// ({ siteResources }) => siteResources
// )
// )
});
// also trigger the peer add handshake in case the peer was not already added to the olm and we need to hole punch
// if it has already been added this will be a no-op
await initPeerAddHandshake(
// this will kick off the add peer process for the client
client.clients.clientId,
{
siteId,
exitNode: {
publicKey: exitNode.publicKey,
endpoint: exitNode.endpoint
// also trigger the peer add handshake in case the peer was not already added to the olm and we need to hole punch
// if it has already been added this will be a no-op
await initPeerAddHandshake(
// this will kick off the add peer process for the client
client.clients.clientId,
{
siteId,
exitNode: {
publicKey: exitNode.publicKey,
endpoint: exitNode.endpoint
}
}
}
);
);
}
return {
publicKey: client.clients.pubKey!,
@@ -188,7 +201,8 @@ export async function buildTargetConfigurationForNewtClient(siteId: number) {
hcTimeout: targetHealthCheck.hcTimeout,
hcHeaders: targetHealthCheck.hcHeaders,
hcMethod: targetHealthCheck.hcMethod,
hcTlsServerName: targetHealthCheck.hcTlsServerName
hcTlsServerName: targetHealthCheck.hcTlsServerName,
hcStatus: targetHealthCheck.hcStatus
})
.from(targets)
.innerJoin(resources, eq(targets.resourceId, resources.resourceId))
@@ -229,9 +243,9 @@ export async function buildTargetConfigurationForNewtClient(siteId: number) {
!target.hcInterval ||
!target.hcMethod
) {
logger.debug(
`Skipping adding target health check ${target.targetId} due to missing health check fields`
);
// logger.debug(
// `Skipping adding target health check ${target.targetId} due to missing health check fields`
// );
return null; // Skip targets with missing health check fields
}
@@ -261,7 +275,8 @@ export async function buildTargetConfigurationForNewtClient(siteId: number) {
hcTimeout: target.hcTimeout, // in seconds
hcHeaders: hcHeadersSend,
hcMethod: target.hcMethod,
hcTlsServerName: target.hcTlsServerName
hcTlsServerName: target.hcTlsServerName,
hcStatus: target.hcStatus
};
});

View File

@@ -6,6 +6,7 @@ import { db, ExitNode, exitNodes, Newt, sites } from "@server/db";
import { eq } from "drizzle-orm";
import { sendToExitNode } from "#dynamic/lib/exitNodes";
import { buildClientConfigurationForNewtClient } from "./buildConfiguration";
import { canCompress } from "@server/lib/clientVersionChecks";
const inputSchema = z.object({
publicKey: z.string(),
@@ -104,11 +105,11 @@ export const handleGetConfigMessage: MessageHandler = async (context) => {
const payload = {
oldDestination: {
destinationIP: existingSite.subnet?.split("/")[0],
destinationPort: existingSite.listenPort
destinationPort: existingSite.listenPort || 1 // this satisfies gerbil for now but should be reevaluated
},
newDestination: {
destinationIP: site.subnet?.split("/")[0],
destinationPort: site.listenPort
destinationPort: site.listenPort || 1 // this satisfies gerbil for now but should be reevaluated
}
};
@@ -135,6 +136,9 @@ export const handleGetConfigMessage: MessageHandler = async (context) => {
targets
}
},
options: {
compress: canCompress(newt.version, "newt")
},
broadcast: false,
excludeSender: false
};

View File

@@ -1,105 +1,107 @@
import { db, sites } from "@server/db";
import { disconnectClient, getClientConfigVersion } from "#dynamic/routers/ws";
import { db, newts, sites } from "@server/db";
import { hasActiveConnections, getClientConfigVersion } from "#dynamic/routers/ws";
import { MessageHandler } from "@server/routers/ws";
import { clients, Newt } from "@server/db";
import { Newt } from "@server/db";
import { eq, lt, isNull, and, or } from "drizzle-orm";
import logger from "@server/logger";
import { validateSessionToken } from "@server/auth/sessions/app";
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
import { sendTerminateClient } from "../client/terminate";
import { encodeHexLowerCase } from "@oslojs/encoding";
import { sha256 } from "@oslojs/crypto/sha2";
import { sendNewtSyncMessage } from "./sync";
// Track if the offline checker interval is running
// let offlineCheckerInterval: NodeJS.Timeout | null = null;
// const OFFLINE_CHECK_INTERVAL = 30 * 1000; // Check every 30 seconds
// const OFFLINE_THRESHOLD_MS = 2 * 60 * 1000; // 2 minutes
let offlineCheckerInterval: NodeJS.Timeout | null = null;
const OFFLINE_CHECK_INTERVAL = 30 * 1000; // Check every 30 seconds
const OFFLINE_THRESHOLD_MS = 2 * 60 * 1000; // 2 minutes
/**
* Starts the background interval that checks for clients that haven't pinged recently
* and marks them as offline
* Starts the background interval that checks for newt sites that haven't
* pinged recently and marks them as offline. For backward compatibility,
* a site is only marked offline when there is no active WebSocket connection
* either — so older newt versions that don't send pings but remain connected
* continue to be treated as online.
*/
// export const startNewtOfflineChecker = (): void => {
// if (offlineCheckerInterval) {
// return; // Already running
// }
export const startNewtOfflineChecker = (): void => {
if (offlineCheckerInterval) {
return; // Already running
}
// offlineCheckerInterval = setInterval(async () => {
// try {
// const twoMinutesAgo = Math.floor(
// (Date.now() - OFFLINE_THRESHOLD_MS) / 1000
// );
offlineCheckerInterval = setInterval(async () => {
try {
const twoMinutesAgo = Math.floor(
(Date.now() - OFFLINE_THRESHOLD_MS) / 1000
);
// // TODO: WE NEED TO MAKE SURE THIS WORKS WITH DISTRIBUTED NODES ALL DOING THE SAME THING
// Find all online newt-type sites that haven't pinged recently
// (or have never pinged at all). Join newts to obtain the newtId
// needed for the WebSocket connection check.
const staleSites = await db
.select({
siteId: sites.siteId,
newtId: newts.newtId,
lastPing: sites.lastPing
})
.from(sites)
.innerJoin(newts, eq(newts.siteId, sites.siteId))
.where(
and(
eq(sites.online, true),
eq(sites.type, "newt"),
or(
lt(sites.lastPing, twoMinutesAgo),
isNull(sites.lastPing)
)
)
);
// // Find clients that haven't pinged in the last 2 minutes and mark them as offline
// const offlineClients = await db
// .update(clients)
// .set({ online: false })
// .where(
// and(
// eq(clients.online, true),
// or(
// lt(clients.lastPing, twoMinutesAgo),
// isNull(clients.lastPing)
// )
// )
// )
// .returning();
for (const staleSite of staleSites) {
// Backward-compatibility check: if the newt still has an
// active WebSocket connection (older clients that don't send
// pings), keep the site online.
const isConnected = await hasActiveConnections(staleSite.newtId);
if (isConnected) {
logger.debug(
`Newt ${staleSite.newtId} has not pinged recently but is still connected via WebSocket — keeping site ${staleSite.siteId} online`
);
continue;
}
// for (const offlineClient of offlineClients) {
// logger.info(
// `Kicking offline newt client ${offlineClient.clientId} due to inactivity`
// );
logger.info(
`Marking site ${staleSite.siteId} offline: newt ${staleSite.newtId} has no recent ping and no active WebSocket connection`
);
// if (!offlineClient.newtId) {
// logger.warn(
// `Offline client ${offlineClient.clientId} has no newtId, cannot disconnect`
// );
// continue;
// }
await db
.update(sites)
.set({ online: false })
.where(eq(sites.siteId, staleSite.siteId));
}
} catch (error) {
logger.error("Error in newt offline checker interval", { error });
}
}, OFFLINE_CHECK_INTERVAL);
// // Send a disconnect message to the client if connected
// try {
// await sendTerminateClient(
// offlineClient.clientId,
// offlineClient.newtId
// ); // terminate first
// // wait a moment to ensure the message is sent
// await new Promise((resolve) => setTimeout(resolve, 1000));
// await disconnectClient(offlineClient.newtId);
// } catch (error) {
// logger.error(
// `Error sending disconnect to offline newt ${offlineClient.clientId}`,
// { error }
// );
// }
// }
// } catch (error) {
// logger.error("Error in offline checker interval", { error });
// }
// }, OFFLINE_CHECK_INTERVAL);
// logger.debug("Started offline checker interval");
// };
logger.debug("Started newt offline checker interval");
};
/**
* Stops the background interval that checks for offline clients
* Stops the background interval that checks for offline newt sites.
*/
// export const stopNewtOfflineChecker = (): void => {
// if (offlineCheckerInterval) {
// clearInterval(offlineCheckerInterval);
// offlineCheckerInterval = null;
// logger.info("Stopped offline checker interval");
// }
// };
export const stopNewtOfflineChecker = (): void => {
if (offlineCheckerInterval) {
clearInterval(offlineCheckerInterval);
offlineCheckerInterval = null;
logger.info("Stopped newt offline checker interval");
}
};
/**
* Handles ping messages from clients and responds with pong
* Handles ping messages from newt clients.
*
* On each ping:
* - Marks the associated site as online.
* - Records the current timestamp as the newt's last-ping time.
* - Triggers a config sync if the newt is running an outdated config version.
* - Responds with a pong message.
*/
export const handleNewtPingMessage: MessageHandler = async (context) => {
const { message, client: c, sendToClient } = context;
const { message, client: c } = context;
const newt = c as Newt;
if (!newt) {
@@ -112,15 +114,31 @@ export const handleNewtPingMessage: MessageHandler = async (context) => {
return;
}
// get the version
try {
// Mark the site as online and record the ping timestamp.
await db
.update(sites)
.set({
online: true,
lastPing: Math.floor(Date.now() / 1000)
})
.where(eq(sites.siteId, newt.siteId));
} catch (error) {
logger.error("Error updating online state on newt ping", { error });
}
// Check config version and sync if stale.
const configVersion = await getClientConfigVersion(newt.newtId);
if (message.configVersion && configVersion != null && configVersion != message.configVersion) {
if (
message.configVersion != null &&
configVersion != null &&
configVersion !== message.configVersion
) {
logger.warn(
`Newt ping with outdated config version: ${message.configVersion} (current: ${configVersion})`
);
// get the site
const [site] = await db
.select()
.from(sites)
@@ -137,19 +155,6 @@ export const handleNewtPingMessage: MessageHandler = async (context) => {
await sendNewtSyncMessage(newt, site);
}
// try {
// // Update the client's last ping timestamp
// await db
// .update(clients)
// .set({
// lastPing: Math.floor(Date.now() / 1000),
// online: true
// })
// .where(eq(clients.clientId, newt.clientId));
// } catch (error) {
// logger.error("Error handling ping message", { error });
// }
return {
message: {
type: "pong",

View File

@@ -5,9 +5,7 @@ import { eq } from "drizzle-orm";
import { addPeer, deletePeer } from "../gerbil/peers";
import logger from "@server/logger";
import config from "@server/lib/config";
import {
findNextAvailableCidr,
} from "@server/lib/ip";
import { findNextAvailableCidr } from "@server/lib/ip";
import {
selectBestExitNode,
verifyExitNodeOrgAccess
@@ -15,6 +13,7 @@ import {
import { fetchContainers } from "./dockerSocket";
import { lockManager } from "#dynamic/lib/lock";
import { buildTargetConfigurationForNewtClient } from "./buildConfiguration";
import { canCompress } from "@server/lib/clientVersionChecks";
export type ExitNodePingResult = {
exitNodeId: number;
@@ -215,6 +214,9 @@ export const handleNewtRegisterMessage: MessageHandler = async (context) => {
healthCheckTargets: validHealthCheckTargets
}
},
options: {
compress: canCompress(newt.version, "newt")
},
broadcast: false, // Send to all clients
excludeSender: false // Include sender in broadcast
};

View File

@@ -10,10 +10,21 @@ interface PeerBandwidth {
bytesOut: number;
}
interface BandwidthAccumulator {
bytesIn: number;
bytesOut: number;
}
// Retry configuration for deadlock handling
const MAX_RETRIES = 3;
const BASE_DELAY_MS = 50;
// How often to flush accumulated bandwidth data to the database
const FLUSH_INTERVAL_MS = 120_000; // 120 seconds
// In-memory accumulator: publicKey -> { bytesIn, bytesOut }
let accumulator = new Map<string, BandwidthAccumulator>();
/**
* Check if an error is a deadlock error
*/
@@ -53,6 +64,90 @@ async function withDeadlockRetry<T>(
}
}
/**
* Flush all accumulated bandwidth data to the database.
*
* Swaps out the accumulator before writing so that any bandwidth messages
* received during the flush are captured in the new accumulator rather than
* being lost or causing contention. Entries that fail to write are re-queued
* back into the accumulator so they will be retried on the next flush.
*
* This function is exported so that the application's graceful-shutdown
* cleanup handler can call it before the process exits.
*/
export async function flushBandwidthToDb(): Promise<void> {
if (accumulator.size === 0) {
return;
}
// Atomically swap out the accumulator so new data keeps flowing in
// while we write the snapshot to the database.
const snapshot = accumulator;
accumulator = new Map<string, BandwidthAccumulator>();
const currentTime = new Date().toISOString();
// Sort by publicKey for consistent lock ordering across concurrent
// writers — this is the same deadlock-prevention strategy used in the
// original per-message implementation.
const sortedEntries = [...snapshot.entries()].sort(([a], [b]) =>
a.localeCompare(b)
);
logger.debug(
`Flushing accumulated bandwidth data for ${sortedEntries.length} client(s) to the database`
);
for (const [publicKey, { bytesIn, bytesOut }] of sortedEntries) {
try {
await withDeadlockRetry(async () => {
// Use atomic SQL increment to avoid the SELECT-then-UPDATE
// anti-pattern and the races it would introduce.
await db
.update(clients)
.set({
// Note: bytesIn from peer goes to megabytesOut (data
// sent to client) and bytesOut from peer goes to
// megabytesIn (data received from client).
megabytesOut: sql`COALESCE(${clients.megabytesOut}, 0) + ${bytesIn}`,
megabytesIn: sql`COALESCE(${clients.megabytesIn}, 0) + ${bytesOut}`,
lastBandwidthUpdate: currentTime
})
.where(eq(clients.pubKey, publicKey));
}, `flush bandwidth for client ${publicKey}`);
} catch (error) {
logger.error(
`Failed to flush bandwidth for client ${publicKey}:`,
error
);
// Re-queue the failed entry so it is retried on the next flush
// rather than silently dropped.
const existing = accumulator.get(publicKey);
if (existing) {
existing.bytesIn += bytesIn;
existing.bytesOut += bytesOut;
} else {
accumulator.set(publicKey, { bytesIn, bytesOut });
}
}
}
}
const flushTimer = setInterval(async () => {
try {
await flushBandwidthToDb();
} catch (error) {
logger.error("Unexpected error during periodic bandwidth flush:", error);
}
}, FLUSH_INTERVAL_MS);
// Calling unref() means this timer will not keep the Node.js event loop alive
// on its own — the process can still exit normally when there is no other work
// left. The graceful-shutdown path (see server/cleanup.ts) will call
// flushBandwidthToDb() explicitly before process.exit(), so no data is lost.
flushTimer.unref();
export const handleReceiveBandwidthMessage: MessageHandler = async (
context
) => {
@@ -69,40 +164,21 @@ export const handleReceiveBandwidthMessage: MessageHandler = async (
throw new Error("Invalid bandwidth data");
}
// Sort bandwidth data by publicKey to ensure consistent lock ordering across all instances
// This is critical for preventing deadlocks when multiple instances update the same clients
const sortedBandwidthData = [...bandwidthData].sort((a, b) =>
a.publicKey.localeCompare(b.publicKey)
);
// Accumulate the incoming data in memory; the periodic timer (and the
// shutdown hook) will take care of writing it to the database.
for (const { publicKey, bytesIn, bytesOut } of bandwidthData) {
// Skip peers that haven't transferred any data — writing zeros to the
// database would be a no-op anyway.
if (bytesIn <= 0 && bytesOut <= 0) {
continue;
}
const currentTime = new Date().toISOString();
// Update each client individually with retry logic
// This reduces transaction scope and allows retries per-client
for (const peer of sortedBandwidthData) {
const { publicKey, bytesIn, bytesOut } = peer;
try {
await withDeadlockRetry(async () => {
// Use atomic SQL increment to avoid SELECT then UPDATE pattern
// This eliminates the need to read the current value first
await db
.update(clients)
.set({
// Note: bytesIn from peer goes to megabytesOut (data sent to client)
// and bytesOut from peer goes to megabytesIn (data received from client)
megabytesOut: sql`COALESCE(${clients.megabytesOut}, 0) + ${bytesIn}`,
megabytesIn: sql`COALESCE(${clients.megabytesIn}, 0) + ${bytesOut}`,
lastBandwidthUpdate: currentTime
})
.where(eq(clients.pubKey, publicKey));
}, `update client bandwidth ${publicKey}`);
} catch (error) {
logger.error(
`Failed to update bandwidth for client ${publicKey}:`,
error
);
// Continue with other clients even if one fails
const existing = accumulator.get(publicKey);
if (existing) {
existing.bytesIn += bytesIn;
existing.bytesOut += bytesOut;
} else {
accumulator.set(publicKey, { bytesIn, bytesOut });
}
}
};

View File

@@ -6,6 +6,7 @@ import {
buildClientConfigurationForNewtClient,
buildTargetConfigurationForNewtClient
} from "./buildConfiguration";
import { canCompress } from "@server/lib/clientVersionChecks";
export async function sendNewtSyncMessage(newt: Newt, site: Site) {
const { tcpTargets, udpTargets, validHealthCheckTargets } =
@@ -24,18 +25,24 @@ export async function sendNewtSyncMessage(newt: Newt, site: Site) {
exitNode
);
await sendToClient(newt.newtId, {
type: "newt/sync",
data: {
proxyTargets: {
udp: udpTargets,
tcp: tcpTargets
},
healthCheckTargets: validHealthCheckTargets,
peers: peers,
clientTargets: targets
await sendToClient(
newt.newtId,
{
type: "newt/sync",
data: {
proxyTargets: {
udp: udpTargets,
tcp: tcpTargets
},
healthCheckTargets: validHealthCheckTargets,
peers: peers,
clientTargets: targets
}
},
{
compress: canCompress(newt.version, "newt")
}
}).catch((error) => {
).catch((error) => {
logger.warn(`Error sending newt sync message:`, error);
});
}

View File

@@ -2,13 +2,14 @@ import { Target, TargetHealthCheck, db, targetHealthCheck } from "@server/db";
import { sendToClient } from "#dynamic/routers/ws";
import logger from "@server/logger";
import { eq, inArray } from "drizzle-orm";
import { canCompress } from "@server/lib/clientVersionChecks";
export async function addTargets(
newtId: string,
targets: Target[],
healthCheckData: TargetHealthCheck[],
protocol: string,
port: number | null = null
version?: string | null
) {
//create a list of udp and tcp targets
const payloadTargets = targets.map((target) => {
@@ -22,7 +23,7 @@ export async function addTargets(
data: {
targets: payloadTargets
}
}, { incrementConfigVersion: true });
}, { incrementConfigVersion: true, compress: canCompress(version, "newt") });
// Create a map for quick lookup
const healthCheckMap = new Map<number, TargetHealthCheck>();
@@ -103,14 +104,14 @@ export async function addTargets(
data: {
targets: validHealthCheckTargets
}
}, { incrementConfigVersion: true });
}, { incrementConfigVersion: true, compress: canCompress(version, "newt") });
}
export async function removeTargets(
newtId: string,
targets: Target[],
protocol: string,
port: number | null = null
version?: string | null
) {
//create a list of udp and tcp targets
const payloadTargets = targets.map((target) => {
@@ -135,5 +136,5 @@ export async function removeTargets(
data: {
ids: healthCheckTargets
}
}, { incrementConfigVersion: true });
}, { incrementConfigVersion: true, compress: canCompress(version, "newt") });
}

View File

@@ -1,5 +1,17 @@
import { Client, clientSiteResourcesAssociationsCache, clientSitesAssociationsCache, db, exitNodes, siteResources, sites } from "@server/db";
import { generateAliasConfig, generateRemoteSubnets } from "@server/lib/ip";
import {
Client,
clientSiteResourcesAssociationsCache,
clientSitesAssociationsCache,
db,
exitNodes,
siteResources,
sites
} from "@server/db";
import {
Alias,
generateAliasConfig,
generateRemoteSubnets
} from "@server/lib/ip";
import logger from "@server/logger";
import { and, eq } from "drizzle-orm";
import { addPeer, deletePeer } from "../newt/peers";
@@ -8,9 +20,19 @@ import config from "@server/lib/config";
export async function buildSiteConfigurationForOlmClient(
client: Client,
publicKey: string | null,
relay: boolean
relay: boolean,
jitMode: boolean = false
) {
const siteConfigurations = [];
const siteConfigurations: {
siteId: number;
name?: string
endpoint?: string
publicKey?: string
serverIP?: string | null
serverPort?: number | null
remoteSubnets?: string[];
aliases: Alias[];
}[] = [];
// Get all sites data
const sitesData = await db
@@ -27,6 +49,40 @@ export async function buildSiteConfigurationForOlmClient(
sites: site,
clientSitesAssociationsCache: association
} of sitesData) {
const allSiteResources = await db // only get the site resources that this client has access to
.select()
.from(siteResources)
.innerJoin(
clientSiteResourcesAssociationsCache,
eq(
siteResources.siteResourceId,
clientSiteResourcesAssociationsCache.siteResourceId
)
)
.where(
and(
eq(siteResources.siteId, site.siteId),
eq(
clientSiteResourcesAssociationsCache.clientId,
client.clientId
)
)
);
if (jitMode) {
// Add site configuration to the array
siteConfigurations.push({
siteId: site.siteId,
// remoteSubnets: generateRemoteSubnets(
// allSiteResources.map(({ siteResources }) => siteResources)
// ),
aliases: generateAliasConfig(
allSiteResources.map(({ siteResources }) => siteResources)
)
});
continue;
}
if (!site.exitNodeId) {
logger.warn(
`Site ${site.siteId} does not have exit node, skipping`
@@ -42,6 +98,13 @@ export async function buildSiteConfigurationForOlmClient(
continue;
}
if (!site.publicKey || site.publicKey == "") { // the site is not ready to accept new peers
logger.warn(
`Site ${site.siteId} has no public key, skipping`
);
continue;
}
// if (site.lastHolePunch && now - site.lastHolePunch > 6 && relay) {
// logger.warn(
// `Site ${site.siteId} last hole punch is too old, skipping`
@@ -103,26 +166,6 @@ export async function buildSiteConfigurationForOlmClient(
relayEndpoint = `${exitNode.endpoint}:${config.getRawConfig().gerbil.clients_start_port}`;
}
const allSiteResources = await db // only get the site resources that this client has access to
.select()
.from(siteResources)
.innerJoin(
clientSiteResourcesAssociationsCache,
eq(
siteResources.siteResourceId,
clientSiteResourcesAssociationsCache.siteResourceId
)
)
.where(
and(
eq(siteResources.siteId, site.siteId),
eq(
clientSiteResourcesAssociationsCache.clientId,
client.clientId
)
)
);
// Add site configuration to the array
siteConfigurations.push({
siteId: site.siteId,

View File

@@ -17,6 +17,9 @@ import { getUserDeviceName } from "@server/db/names";
import { buildSiteConfigurationForOlmClient } from "./buildConfiguration";
import { OlmErrorCodes, sendOlmError } from "./error";
import { handleFingerprintInsertion } from "./fingerprintingUtils";
import { Alias } from "@server/lib/ip";
import { build } from "@server/build";
import { canCompress } from "@server/lib/clientVersionChecks";
export const handleOlmRegisterMessage: MessageHandler = async (context) => {
logger.info("Handling register olm message!");
@@ -207,6 +210,32 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
}
}
// Get all sites data
const sitesCountResult = await db
.select({ count: count() })
.from(sites)
.innerJoin(
clientSitesAssociationsCache,
eq(sites.siteId, clientSitesAssociationsCache.siteId)
)
.where(eq(clientSitesAssociationsCache.clientId, client.clientId));
// Extract the count value from the result array
const sitesCount =
sitesCountResult.length > 0 ? sitesCountResult[0].count : 0;
// Prepare an array to store site configurations
logger.debug(`Found ${sitesCount} sites for client ${client.clientId}`);
let jitMode = true;
if (sitesCount > 250 && build == "saas") {
// THIS IS THE MAX ON THE BUSINESS TIER
// we have too many sites
// If we have too many sites we need to drop into fully JIT mode by not sending any of the sites
logger.info("Too many sites (%d), dropping into JIT mode", sitesCount);
jitMode = true;
}
logger.debug(
`Olm client ID: ${client.clientId}, Public Key: ${publicKey}, Relay: ${relay}`
);
@@ -233,28 +262,12 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
await db
.update(clientSitesAssociationsCache)
.set({
isRelayed: relay == true
isRelayed: relay == true,
isJitMode: jitMode
})
.where(eq(clientSitesAssociationsCache.clientId, client.clientId));
}
// Get all sites data
const sitesCountResult = await db
.select({ count: count() })
.from(sites)
.innerJoin(
clientSitesAssociationsCache,
eq(sites.siteId, clientSitesAssociationsCache.siteId)
)
.where(eq(clientSitesAssociationsCache.clientId, client.clientId));
// Extract the count value from the result array
const sitesCount =
sitesCountResult.length > 0 ? sitesCountResult[0].count : 0;
// Prepare an array to store site configurations
logger.debug(`Found ${sitesCount} sites for client ${client.clientId}`);
// this prevents us from accepting a register from an olm that has not hole punched yet.
// the olm will pump the register so we can keep checking
// TODO: I still think there is a better way to do this rather than locking it out here but ???
@@ -265,19 +278,14 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
return;
}
// NOTE: its important that the client here is the old client and the public key is the new key
// NOTE: its important that the client here is the old client and the public key is the new key
const siteConfigurations = await buildSiteConfigurationForOlmClient(
client,
publicKey,
relay
relay,
jitMode
);
// REMOVED THIS SO IT CREATES THE INTERFACE AND JUST WAITS FOR THE SITES
// if (siteConfigurations.length === 0) {
// logger.warn("No valid site configurations found");
// return;
// }
// Return connect message with all site configurations
return {
message: {
@@ -288,6 +296,9 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
utilitySubnet: org.utilitySubnet
}
},
options: {
compress: canCompress(olm.version, "olm")
},
broadcast: false,
excludeSender: false
};

View File

@@ -18,7 +18,7 @@ export const handleOlmRelayMessage: MessageHandler = async (context) => {
}
if (!olm.clientId) {
logger.warn("Olm has no site!"); // TODO: Maybe we create the site here?
logger.warn("Olm has no client!");
return;
}
@@ -41,7 +41,7 @@ export const handleOlmRelayMessage: MessageHandler = async (context) => {
return;
}
const { siteId } = message.data;
const { siteId, chainId } = message.data;
// Get the site
const [site] = await db
@@ -90,7 +90,8 @@ export const handleOlmRelayMessage: MessageHandler = async (context) => {
data: {
siteId: siteId,
relayEndpoint: exitNode.endpoint,
relayPort: config.getRawConfig().gerbil.clients_start_port
relayPort: config.getRawConfig().gerbil.clients_start_port,
chainId
}
},
broadcast: false,

View File

@@ -0,0 +1,241 @@
import {
clientSiteResourcesAssociationsCache,
clientSitesAssociationsCache,
db,
exitNodes,
Site,
siteResources
} from "@server/db";
import { MessageHandler } from "@server/routers/ws";
import { clients, Olm, sites } from "@server/db";
import { and, eq, or } from "drizzle-orm";
import logger from "@server/logger";
import { initPeerAddHandshake } from "./peers";
export const handleOlmServerInitAddPeerHandshake: MessageHandler = async (
context
) => {
logger.info("Handling register olm message!");
const { message, client: c, sendToClient } = context;
const olm = c as Olm;
if (!olm) {
logger.warn("Olm not found");
return;
}
if (!olm.clientId) {
logger.warn("Olm has no client!"); // TODO: Maybe we create the site here?
return;
}
const clientId = olm.clientId;
const [client] = await db
.select()
.from(clients)
.where(eq(clients.clientId, clientId))
.limit(1);
if (!client) {
logger.warn("Client not found");
return;
}
const { siteId, resourceId, chainId } = message.data;
let site: Site | null = null;
if (siteId) {
// get the site
const [siteRes] = await db
.select()
.from(sites)
.where(eq(sites.siteId, siteId))
.limit(1);
if (siteRes) {
site = siteRes;
}
}
if (resourceId && !site) {
const resources = await db
.select()
.from(siteResources)
.where(
and(
or(
eq(siteResources.niceId, resourceId),
eq(siteResources.alias, resourceId)
),
eq(siteResources.orgId, client.orgId)
)
);
if (!resources || resources.length === 0) {
logger.error(`handleOlmServerPeerAddMessage: Resource not found`);
// cancel the request from the olm side to not keep doing this
await sendToClient(
olm.olmId,
{
type: "olm/wg/peer/chain/cancel",
data: {
chainId
}
},
{ incrementConfigVersion: false }
).catch((error) => {
logger.warn(`Error sending message:`, error);
});
return;
}
if (resources.length > 1) {
// error but this should not happen because the nice id cant contain a dot and the alias has to have a dot and both have to be unique within the org so there should never be multiple matches
logger.error(
`handleOlmServerPeerAddMessage: Multiple resources found matching the criteria`
);
return;
}
const resource = resources[0];
const currentResourceAssociationCaches = await db
.select()
.from(clientSiteResourcesAssociationsCache)
.where(
and(
eq(
clientSiteResourcesAssociationsCache.siteResourceId,
resource.siteResourceId
),
eq(
clientSiteResourcesAssociationsCache.clientId,
client.clientId
)
)
);
if (currentResourceAssociationCaches.length === 0) {
logger.error(
`handleOlmServerPeerAddMessage: Client ${client.clientId} does not have access to resource ${resource.siteResourceId}`
);
// cancel the request from the olm side to not keep doing this
await sendToClient(
olm.olmId,
{
type: "olm/wg/peer/chain/cancel",
data: {
chainId
}
},
{ incrementConfigVersion: false }
).catch((error) => {
logger.warn(`Error sending message:`, error);
});
return;
}
const siteIdFromResource = resource.siteId;
// get the site
const [siteRes] = await db
.select()
.from(sites)
.where(eq(sites.siteId, siteIdFromResource));
if (!siteRes) {
logger.error(
`handleOlmServerPeerAddMessage: Site with ID ${site} not found`
);
return;
}
site = siteRes;
}
if (!site) {
logger.error(`handleOlmServerPeerAddMessage: Site not found`);
return;
}
// check if the client can access this site using the cache
const currentSiteAssociationCaches = await db
.select()
.from(clientSitesAssociationsCache)
.where(
and(
eq(clientSitesAssociationsCache.clientId, client.clientId),
eq(clientSitesAssociationsCache.siteId, site.siteId)
)
);
if (currentSiteAssociationCaches.length === 0) {
logger.error(
`handleOlmServerPeerAddMessage: Client ${client.clientId} does not have access to site ${site.siteId}`
);
// cancel the request from the olm side to not keep doing this
await sendToClient(
olm.olmId,
{
type: "olm/wg/peer/chain/cancel",
data: {
chainId
}
},
{ incrementConfigVersion: false }
).catch((error) => {
logger.warn(`Error sending message:`, error);
});
return;
}
if (!site.exitNodeId) {
logger.error(
`handleOlmServerPeerAddMessage: Site with ID ${site.siteId} has no exit node`
);
// cancel the request from the olm side to not keep doing this
await sendToClient(
olm.olmId,
{
type: "olm/wg/peer/chain/cancel",
data: {
chainId
}
},
{ incrementConfigVersion: false }
).catch((error) => {
logger.warn(`Error sending message:`, error);
});
return;
}
// get the exit node from the side
const [exitNode] = await db
.select()
.from(exitNodes)
.where(eq(exitNodes.exitNodeId, site.exitNodeId));
if (!exitNode) {
logger.error(
`handleOlmServerPeerAddMessage: Site with ID ${site.siteId} has no exit node`
);
return;
}
// also trigger the peer add handshake in case the peer was not already added to the olm and we need to hole punch
// if it has already been added this will be a no-op
await initPeerAddHandshake(
// this will kick off the add peer process for the client
client.clientId,
{
siteId: site.siteId,
exitNode: {
publicKey: exitNode.publicKey,
endpoint: exitNode.endpoint
}
},
olm.olmId,
chainId
);
return;
};

View File

@@ -54,7 +54,7 @@ export const handleOlmServerPeerAddMessage: MessageHandler = async (
return;
}
const { siteId } = message.data;
const { siteId, chainId } = message.data;
// get the site
const [site] = await db
@@ -179,7 +179,8 @@ export const handleOlmServerPeerAddMessage: MessageHandler = async (
),
aliases: generateAliasConfig(
allSiteResources.map(({ siteResources }) => siteResources)
)
),
chainId: chainId,
}
},
broadcast: false,

View File

@@ -17,7 +17,7 @@ export const handleOlmUnRelayMessage: MessageHandler = async (context) => {
}
if (!olm.clientId) {
logger.warn("Olm has no site!"); // TODO: Maybe we create the site here?
logger.warn("Olm has no client!");
return;
}
@@ -40,7 +40,7 @@ export const handleOlmUnRelayMessage: MessageHandler = async (context) => {
return;
}
const { siteId } = message.data;
const { siteId, chainId } = message.data;
// Get the site
const [site] = await db
@@ -87,7 +87,8 @@ export const handleOlmUnRelayMessage: MessageHandler = async (context) => {
type: "olm/wg/peer/unrelay",
data: {
siteId: siteId,
endpoint: site.endpoint
endpoint: site.endpoint,
chainId
}
},
broadcast: false,

View File

@@ -11,3 +11,4 @@ export * from "./handleOlmServerPeerAddMessage";
export * from "./handleOlmUnRelayMessage";
export * from "./recoverOlmWithFingerprint";
export * from "./handleOlmDisconnectingMessage";
export * from "./handleOlmServerInitAddPeerHandshake";

View File

@@ -1,8 +1,9 @@
import { sendToClient } from "#dynamic/routers/ws";
import { db, olms } from "@server/db";
import { clientSitesAssociationsCache, db, olms } from "@server/db";
import { canCompress } from "@server/lib/clientVersionChecks";
import config from "@server/lib/config";
import logger from "@server/logger";
import { eq } from "drizzle-orm";
import { and, eq } from "drizzle-orm";
import { Alias } from "yaml";
export async function addPeer(
@@ -18,7 +19,8 @@ export async function addPeer(
remoteSubnets: string[] | null; // optional, comma-separated list of subnets that this site can access
aliases: Alias[];
},
olmId?: string
olmId?: string,
version?: string | null
) {
if (!olmId) {
const [olm] = await db
@@ -30,6 +32,7 @@ export async function addPeer(
return; // ignore this because an olm might not be associated with the client anymore
}
olmId = olm.olmId;
version = olm.version;
}
await sendToClient(
@@ -48,7 +51,7 @@ export async function addPeer(
aliases: peer.aliases
}
},
{ incrementConfigVersion: true }
{ incrementConfigVersion: true, compress: canCompress(version, "olm") }
).catch((error) => {
logger.warn(`Error sending message:`, error);
});
@@ -60,7 +63,8 @@ export async function deletePeer(
clientId: number,
siteId: number,
publicKey: string,
olmId?: string
olmId?: string,
version?: string | null
) {
if (!olmId) {
const [olm] = await db
@@ -72,6 +76,7 @@ export async function deletePeer(
return;
}
olmId = olm.olmId;
version = olm.version;
}
await sendToClient(
@@ -83,7 +88,7 @@ export async function deletePeer(
siteId: siteId
}
},
{ incrementConfigVersion: true }
{ incrementConfigVersion: true, compress: canCompress(version, "olm") }
).catch((error) => {
logger.warn(`Error sending message:`, error);
});
@@ -103,7 +108,8 @@ export async function updatePeer(
remoteSubnets?: string[] | null; // optional, comma-separated list of subnets that
aliases?: Alias[] | null;
},
olmId?: string
olmId?: string,
version?: string | null
) {
if (!olmId) {
const [olm] = await db
@@ -115,6 +121,7 @@ export async function updatePeer(
return;
}
olmId = olm.olmId;
version = olm.version;
}
await sendToClient(
@@ -132,7 +139,7 @@ export async function updatePeer(
aliases: peer.aliases
}
},
{ incrementConfigVersion: true }
{ incrementConfigVersion: true, compress: canCompress(version, "olm") }
).catch((error) => {
logger.warn(`Error sending message:`, error);
});
@@ -149,7 +156,8 @@ export async function initPeerAddHandshake(
endpoint: string;
};
},
olmId?: string
olmId?: string,
chainId?: string
) {
if (!olmId) {
const [olm] = await db
@@ -173,7 +181,8 @@ export async function initPeerAddHandshake(
publicKey: peer.exitNode.publicKey,
relayPort: config.getRawConfig().gerbil.clients_start_port,
endpoint: peer.exitNode.endpoint
}
},
chainId
}
},
{ incrementConfigVersion: true }
@@ -181,6 +190,17 @@ export async function initPeerAddHandshake(
logger.warn(`Error sending message:`, error);
});
// update the clientSiteAssociationsCache to make the isJitMode flag false so that JIT mode is disabled for this site if it restarts or something after the connection
await db
.update(clientSitesAssociationsCache)
.set({ isJitMode: false })
.where(
and(
eq(clientSitesAssociationsCache.clientId, clientId),
eq(clientSitesAssociationsCache.siteId, peer.siteId)
)
);
logger.info(
`Initiated peer add handshake for site ${peer.siteId} to olm ${olmId}`
);

View File

@@ -1,9 +1,17 @@
import { Client, db, exitNodes, Olm, sites, clientSitesAssociationsCache } from "@server/db";
import {
Client,
db,
exitNodes,
Olm,
sites,
clientSitesAssociationsCache
} from "@server/db";
import { buildSiteConfigurationForOlmClient } from "./buildConfiguration";
import { sendToClient } from "#dynamic/routers/ws";
import logger from "@server/logger";
import { eq, inArray } from "drizzle-orm";
import config from "@server/lib/config";
import { canCompress } from "@server/lib/clientVersionChecks";
export async function sendOlmSyncMessage(olm: Olm, client: Client) {
// NOTE: WE ARE HARDCODING THE RELAY PARAMETER TO FALSE HERE BUT IN THE REGISTER MESSAGE ITS DEFINED BY THE CLIENT
@@ -17,10 +25,7 @@ export async function sendOlmSyncMessage(olm: Olm, client: Client) {
const clientSites = await db
.select()
.from(clientSitesAssociationsCache)
.innerJoin(
sites,
eq(sites.siteId, clientSitesAssociationsCache.siteId)
)
.innerJoin(sites, eq(sites.siteId, clientSitesAssociationsCache.siteId))
.where(eq(clientSitesAssociationsCache.clientId, client.clientId));
// Extract unique exit node IDs
@@ -68,13 +73,20 @@ export async function sendOlmSyncMessage(olm: Olm, client: Client) {
logger.debug("sendOlmSyncMessage: sending sync message");
await sendToClient(olm.olmId, {
type: "olm/sync",
data: {
sites: siteConfigurations,
exitNodes: exitNodesData
await sendToClient(
olm.olmId,
{
type: "olm/sync",
data: {
sites: siteConfigurations,
exitNodes: exitNodesData
}
},
{
compress: canCompress(olm.version, "olm")
}
}).catch((error) => {
).catch((error) => {
logger.warn(`Error sending olm sync message:`, error);
});
}

View File

@@ -620,7 +620,7 @@ export async function handleMessagingForUpdatedSiteResource(
await updateTargets(newt.newtId, {
oldTargets: oldTargets,
newTargets: newTargets
});
}, newt.version);
}
const olmJobs: Promise<void>[] = [];

View File

@@ -264,7 +264,7 @@ export async function createTarget(
newTarget,
healthCheck,
resource.protocol,
resource.proxyPort
newt.version
);
}
}

View File

@@ -262,7 +262,7 @@ export async function updateTarget(
[updatedTarget],
[updatedHc],
resource.protocol,
resource.proxyPort
newt.version
);
}
}

View File

@@ -6,7 +6,8 @@ import {
handleDockerContainersMessage,
handleNewtPingRequestMessage,
handleApplyBlueprintMessage,
handleNewtPingMessage
handleNewtPingMessage,
startNewtOfflineChecker
} from "../newt";
import {
handleOlmRegisterMessage,
@@ -15,7 +16,8 @@ import {
startOlmOfflineChecker,
handleOlmServerPeerAddMessage,
handleOlmUnRelayMessage,
handleOlmDisconnecingMessage
handleOlmDisconnecingMessage,
handleOlmServerInitAddPeerHandshake
} from "../olm";
import { handleHealthcheckStatusMessage } from "../target";
import { handleRoundTripMessage } from "./handleRoundTripMessage";
@@ -23,6 +25,7 @@ import { MessageHandler } from "./types";
export const messageHandlers: Record<string, MessageHandler> = {
"olm/wg/server/peer/add": handleOlmServerPeerAddMessage,
"olm/wg/server/peer/init": handleOlmServerInitAddPeerHandshake,
"olm/wg/register": handleOlmRegisterMessage,
"olm/wg/relay": handleOlmRelayMessage,
"olm/wg/unrelay": handleOlmUnRelayMessage,
@@ -41,3 +44,4 @@ export const messageHandlers: Record<string, MessageHandler> = {
};
startOlmOfflineChecker(); // this is to handle the offline check for olms
startNewtOfflineChecker(); // this is to handle the offline check for newts

View File

@@ -24,7 +24,7 @@ export interface AuthenticatedWebSocket extends WebSocket {
clientType?: ClientType;
connectionId?: string;
isFullyConnected?: boolean;
pendingMessages?: Buffer[];
pendingMessages?: { data: Buffer; isBinary: boolean }[];
configVersion?: number;
}
@@ -73,6 +73,7 @@ export type MessageHandler = (
// Options for sending messages with config version tracking
export interface SendMessageOptions {
incrementConfigVersion?: boolean;
compress?: boolean;
}
// Redis message type for cross-node communication

View File

@@ -1,8 +1,9 @@
import { Router, Request, Response } from "express";
import zlib from "zlib";
import { Server as HttpServer } from "http";
import { WebSocket, WebSocketServer } from "ws";
import { Socket } from "net";
import { Newt, newts, NewtSession, olms, Olm, OlmSession } from "@server/db";
import { Newt, newts, NewtSession, olms, Olm, OlmSession, sites } from "@server/db";
import { eq } from "drizzle-orm";
import { db } from "@server/db";
import { validateNewtSessionToken } from "@server/auth/sessions/newt";
@@ -116,11 +117,20 @@ const sendToClientLocal = async (
};
const messageString = JSON.stringify(messageWithVersion);
clients.forEach((client) => {
if (client.readyState === WebSocket.OPEN) {
client.send(messageString);
}
});
if (options.compress) {
const compressed = zlib.gzipSync(Buffer.from(messageString, "utf8"));
clients.forEach((client) => {
if (client.readyState === WebSocket.OPEN) {
client.send(compressed);
}
});
} else {
clients.forEach((client) => {
if (client.readyState === WebSocket.OPEN) {
client.send(messageString);
}
});
}
return true;
};
@@ -147,11 +157,22 @@ const broadcastToAllExceptLocal = async (
...message,
configVersion
};
clients.forEach((client) => {
if (client.readyState === WebSocket.OPEN) {
client.send(JSON.stringify(messageWithVersion));
}
});
if (options.compress) {
const compressed = zlib.gzipSync(
Buffer.from(JSON.stringify(messageWithVersion), "utf8")
);
clients.forEach((client) => {
if (client.readyState === WebSocket.OPEN) {
client.send(compressed);
}
});
} else {
clients.forEach((client) => {
if (client.readyState === WebSocket.OPEN) {
client.send(JSON.stringify(messageWithVersion));
}
});
}
}
});
};
@@ -286,9 +307,12 @@ const setupConnection = async (
clientType === "newt" ? (client as Newt).newtId : (client as Olm).olmId;
await addClient(clientType, clientId, ws);
ws.on("message", async (data) => {
ws.on("message", async (data, isBinary) => {
try {
const message: WSMessage = JSON.parse(data.toString());
const messageBuffer = isBinary
? zlib.gunzipSync(data as Buffer)
: (data as Buffer);
const message: WSMessage = JSON.parse(messageBuffer.toString());
if (!message.type || typeof message.type !== "string") {
throw new Error(
@@ -356,6 +380,31 @@ const setupConnection = async (
);
});
// Handle WebSocket protocol-level pings from older newt clients that do
// not send application-level "newt/ping" messages. Update the site's
// online state and lastPing timestamp so the offline checker treats them
// the same as modern newt clients.
if (clientType === "newt") {
const newtClient = client as Newt;
ws.on("ping", async () => {
if (!newtClient.siteId) return;
try {
await db
.update(sites)
.set({
online: true,
lastPing: Math.floor(Date.now() / 1000)
})
.where(eq(sites.siteId, newtClient.siteId));
} catch (error) {
logger.error(
"Error updating newt site online state on WS ping",
{ error }
);
}
});
}
ws.on("error", (error: Error) => {
logger.error(
`WebSocket error for ${clientType.toUpperCase()} ID ${clientId}:`,

View File

@@ -35,11 +35,7 @@ import {
} from "@app/components/Credenza";
import { cn } from "@app/lib/cn";
import { CreditCard, ExternalLink, Check, AlertTriangle } from "lucide-react";
import {
Alert,
AlertTitle,
AlertDescription
} from "@app/components/ui/alert";
import { Alert, AlertTitle, AlertDescription } from "@app/components/ui/alert";
import {
Tooltip,
TooltipTrigger,
@@ -69,6 +65,7 @@ type PlanOption = {
price: string;
priceDetail?: string;
tierType: Tier | null;
features: string[];
};
const planOptions: PlanOption[] = [
@@ -76,41 +73,87 @@ const planOptions: PlanOption[] = [
id: "basic",
name: "Basic",
price: "Free",
tierType: null
tierType: null,
features: [
"Basic Pangolin features",
"Free provided domains",
"Web-based proxy resources",
"Private resources and clients",
"Peer-to-peer connections"
]
},
{
id: "home",
name: "Home",
price: "$12.50",
priceDetail: "/ month",
tierType: "tier1"
tierType: "tier1",
features: [
"Everything in Basic",
"OAuth2/OIDC, Google, & Azure SSO",
"Bring your own identity provider",
"Pangolin SSH",
"Custom branding",
"Device admin approvals"
]
},
{
id: "team",
name: "Team",
price: "$4",
priceDetail: "per user / month",
tierType: "tier2"
tierType: "tier2",
features: [
"Everything in Basic",
"Custom domains",
"OAuth2/OIDC, Google, & Azure SSO",
"Access and action audit logs",
"Device posture information"
]
},
{
id: "business",
name: "Business",
price: "$9",
priceDetail: "per user / month",
tierType: "tier3"
tierType: "tier3",
features: [
"Everything in Team",
"Multiple organizations (multi-tenancy)",
"Auto-provisioning via IdP",
"Pangolin SSH",
"Device approvals",
"Custom branding",
"Business support"
]
},
{
id: "enterprise",
name: "Enterprise",
price: "Custom",
tierType: null
tierType: null,
features: [
"Everything in Business",
"Custom limits",
"Priority support and SLA",
"Log push and export",
"Private and Gov-Cloud deployment options",
"Dedicated, premium relay/exit nodes",
"Pay by invoice "
]
}
];
// Tier limits mapping derived from limit sets
const tierLimits: Record<
Tier | "basic",
{ users: number; sites: number; domains: number; remoteNodes: number; organizations: number }
{
users: number;
sites: number;
domains: number;
remoteNodes: number;
organizations: number;
}
> = {
basic: {
users: freeLimitSet[FeatureId.USERS]?.value ?? 0,
@@ -463,31 +506,43 @@ export default function BillingPage() {
const isProblematicState = hasProblematicSubscription();
// Get user-friendly subscription status message
const getSubscriptionStatusMessage = (): { title: string; description: string } | null => {
const getSubscriptionStatusMessage = (): {
title: string;
description: string;
} | null => {
if (!tierSubscription?.subscription || !isProblematicState) return null;
const status = tierSubscription.subscription.status;
switch (status) {
case "past_due":
return {
title: t("billingPastDueTitle") || "Payment Past Due",
description: t("billingPastDueDescription") || "Your payment is past due. Please update your payment method to continue using your current plan features. If not resolved, your subscription will be canceled and you'll be reverted to the free tier."
description:
t("billingPastDueDescription") ||
"Your payment is past due. Please update your payment method to continue using your current plan features. If not resolved, your subscription will be canceled and you'll be reverted to the free tier."
};
case "unpaid":
return {
title: t("billingUnpaidTitle") || "Subscription Unpaid",
description: t("billingUnpaidDescription") || "Your subscription is unpaid and you have been reverted to the free tier. Please update your payment method to restore your subscription."
description:
t("billingUnpaidDescription") ||
"Your subscription is unpaid and you have been reverted to the free tier. Please update your payment method to restore your subscription."
};
case "incomplete":
return {
title: t("billingIncompleteTitle") || "Payment Incomplete",
description: t("billingIncompleteDescription") || "Your payment is incomplete. Please complete the payment process to activate your subscription."
description:
t("billingIncompleteDescription") ||
"Your payment is incomplete. Please complete the payment process to activate your subscription."
};
case "incomplete_expired":
return {
title: t("billingIncompleteExpiredTitle") || "Payment Expired",
description: t("billingIncompleteExpiredDescription") || "Your payment was never completed and has expired. You have been reverted to the free tier. Please subscribe again to restore access to paid features."
title:
t("billingIncompleteExpiredTitle") || "Payment Expired",
description:
t("billingIncompleteExpiredDescription") ||
"Your payment was never completed and has expired. You have been reverted to the free tier. Please subscribe again to restore access to paid features."
};
default:
return null;
@@ -509,7 +564,11 @@ export default function BillingPage() {
if (plan.id === currentPlanId) {
// If it's the basic plan (basic with no subscription), show as current but disabled
if (plan.id === "basic" && !hasSubscription && !isProblematicState) {
if (
plan.id === "basic" &&
!hasSubscription &&
!isProblematicState
) {
return {
label: "Current Plan",
action: () => {},
@@ -632,7 +691,9 @@ export default function BillingPage() {
};
// Check if downgrading to a tier would violate current usage limits
const checkLimitViolations = (targetTier: Tier | "basic"): Array<{
const checkLimitViolations = (
targetTier: Tier | "basic"
): Array<{
feature: string;
currentUsage: number;
newLimit: number;
@@ -687,7 +748,10 @@ export default function BillingPage() {
// Check organizations
const organizationsUsage = getUsageValue(ORGINIZATIONS);
if (limits.organizations > 0 && organizationsUsage > limits.organizations) {
if (
limits.organizations > 0 &&
organizationsUsage > limits.organizations
) {
violations.push({
feature: "Organizations",
currentUsage: organizationsUsage,
@@ -712,17 +776,15 @@ export default function BillingPage() {
{isProblematicState && statusMessage && (
<Alert variant="destructive" className="mb-6">
<AlertTriangle className="h-4 w-4" />
<AlertTitle>
{statusMessage.title}
</AlertTitle>
<AlertTitle>{statusMessage.title}</AlertTitle>
<AlertDescription>
{statusMessage.description}
{" "}
{statusMessage.description}{" "}
<button
onClick={handleModifySubscription}
className="underline font-semibold hover:no-underline"
>
{t("billingManageSubscription") || "Manage your subscription"}
{t("billingManageSubscription") ||
"Manage your subscription"}
</button>
</AlertDescription>
</Alert>
@@ -772,7 +834,10 @@ export default function BillingPage() {
</div>
</div>
<div className="mt-4">
{isProblematicState && planAction.disabled && !isCurrentPlan && plan.id !== "enterprise" ? (
{isProblematicState &&
planAction.disabled &&
!isCurrentPlan &&
plan.id !== "enterprise" ? (
<Tooltip>
<TooltipTrigger asChild>
<div>
@@ -784,18 +849,29 @@ export default function BillingPage() {
}
size="sm"
className="w-full"
onClick={planAction.action}
disabled={
isLoading || planAction.disabled
onClick={
planAction.action
}
disabled={
isLoading ||
planAction.disabled
}
loading={
isLoading &&
isCurrentPlan
}
loading={isLoading && isCurrentPlan}
>
{planAction.label}
</Button>
</div>
</TooltipTrigger>
<TooltipContent>
<p>{t("billingResolvePaymentIssue") || "Please resolve your payment issue before upgrading or downgrading"}</p>
<p>
{t(
"billingResolvePaymentIssue"
) ||
"Please resolve your payment issue before upgrading or downgrading"}
</p>
</TooltipContent>
</Tooltip>
) : (
@@ -809,9 +885,12 @@ export default function BillingPage() {
className="w-full"
onClick={planAction.action}
disabled={
isLoading || planAction.disabled
isLoading ||
planAction.disabled
}
loading={
isLoading && isCurrentPlan
}
loading={isLoading && isCurrentPlan}
>
{planAction.label}
</Button>
@@ -886,18 +965,38 @@ export default function BillingPage() {
<Tooltip>
<TooltipTrigger className="flex items-center gap-1">
<AlertTriangle className="h-3 w-3 text-orange-400" />
<span className={cn(
"text-orange-600 dark:text-orange-400 font-medium"
)}>
<span
className={cn(
"text-orange-600 dark:text-orange-400 font-medium"
)}
>
{getLimitValue(USERS) ??
t("billingUnlimited") ??
t(
"billingUnlimited"
) ??
"∞"}{" "}
{getLimitValue(USERS) !== null &&
"users"}
{getLimitValue(
USERS
) !== null && "users"}
</span>
</TooltipTrigger>
<TooltipContent>
<p>{t("billingUsageExceedsLimit", { current: getUsageValue(USERS), limit: getLimitValue(USERS) ?? 0 }) || `Current usage (${getUsageValue(USERS)}) exceeds limit (${getLimitValue(USERS)})`}</p>
<p>
{t(
"billingUsageExceedsLimit",
{
current:
getUsageValue(
USERS
),
limit:
getLimitValue(
USERS
) ?? 0
}
) ||
`Current usage (${getUsageValue(USERS)}) exceeds limit (${getLimitValue(USERS)})`}
</p>
</TooltipContent>
</Tooltip>
) : (
@@ -905,8 +1004,8 @@ export default function BillingPage() {
{getLimitValue(USERS) ??
t("billingUnlimited") ??
"∞"}{" "}
{getLimitValue(USERS) !== null &&
"users"}
{getLimitValue(USERS) !==
null && "users"}
</>
)}
</InfoSectionContent>
@@ -920,18 +1019,38 @@ export default function BillingPage() {
<Tooltip>
<TooltipTrigger className="flex items-center gap-1">
<AlertTriangle className="h-3 w-3 text-orange-400" />
<span className={cn(
"text-orange-600 dark:text-orange-400 font-medium"
)}>
<span
className={cn(
"text-orange-600 dark:text-orange-400 font-medium"
)}
>
{getLimitValue(SITES) ??
t("billingUnlimited") ??
t(
"billingUnlimited"
) ??
"∞"}{" "}
{getLimitValue(SITES) !== null &&
"sites"}
{getLimitValue(
SITES
) !== null && "sites"}
</span>
</TooltipTrigger>
<TooltipContent>
<p>{t("billingUsageExceedsLimit", { current: getUsageValue(SITES), limit: getLimitValue(SITES) ?? 0 }) || `Current usage (${getUsageValue(SITES)}) exceeds limit (${getLimitValue(SITES)})`}</p>
<p>
{t(
"billingUsageExceedsLimit",
{
current:
getUsageValue(
SITES
),
limit:
getLimitValue(
SITES
) ?? 0
}
) ||
`Current usage (${getUsageValue(SITES)}) exceeds limit (${getLimitValue(SITES)})`}
</p>
</TooltipContent>
</Tooltip>
) : (
@@ -939,8 +1058,8 @@ export default function BillingPage() {
{getLimitValue(SITES) ??
t("billingUnlimited") ??
"∞"}{" "}
{getLimitValue(SITES) !== null &&
"sites"}
{getLimitValue(SITES) !==
null && "sites"}
</>
)}
</InfoSectionContent>
@@ -954,18 +1073,40 @@ export default function BillingPage() {
<Tooltip>
<TooltipTrigger className="flex items-center gap-1">
<AlertTriangle className="h-3 w-3 text-orange-400" />
<span className={cn(
"text-orange-600 dark:text-orange-400 font-medium"
)}>
{getLimitValue(DOMAINS) ??
t("billingUnlimited") ??
<span
className={cn(
"text-orange-600 dark:text-orange-400 font-medium"
)}
>
{getLimitValue(
DOMAINS
) ??
t(
"billingUnlimited"
) ??
"∞"}{" "}
{getLimitValue(DOMAINS) !== null &&
"domains"}
{getLimitValue(
DOMAINS
) !== null && "domains"}
</span>
</TooltipTrigger>
<TooltipContent>
<p>{t("billingUsageExceedsLimit", { current: getUsageValue(DOMAINS), limit: getLimitValue(DOMAINS) ?? 0 }) || `Current usage (${getUsageValue(DOMAINS)}) exceeds limit (${getLimitValue(DOMAINS)})`}</p>
<p>
{t(
"billingUsageExceedsLimit",
{
current:
getUsageValue(
DOMAINS
),
limit:
getLimitValue(
DOMAINS
) ?? 0
}
) ||
`Current usage (${getUsageValue(DOMAINS)}) exceeds limit (${getLimitValue(DOMAINS)})`}
</p>
</TooltipContent>
</Tooltip>
) : (
@@ -973,8 +1114,8 @@ export default function BillingPage() {
{getLimitValue(DOMAINS) ??
t("billingUnlimited") ??
"∞"}{" "}
{getLimitValue(DOMAINS) !== null &&
"domains"}
{getLimitValue(DOMAINS) !==
null && "domains"}
</>
)}
</InfoSectionContent>
@@ -989,18 +1130,40 @@ export default function BillingPage() {
<Tooltip>
<TooltipTrigger className="flex items-center gap-1">
<AlertTriangle className="h-3 w-3 text-orange-400" />
<span className={cn(
"text-orange-600 dark:text-orange-400 font-medium"
)}>
{getLimitValue(ORGINIZATIONS) ??
t("billingUnlimited") ??
<span
className={cn(
"text-orange-600 dark:text-orange-400 font-medium"
)}
>
{getLimitValue(
ORGINIZATIONS
) ??
t(
"billingUnlimited"
) ??
"∞"}{" "}
{getLimitValue(ORGINIZATIONS) !==
null && "orgs"}
{getLimitValue(
ORGINIZATIONS
) !== null && "orgs"}
</span>
</TooltipTrigger>
<TooltipContent>
<p>{t("billingUsageExceedsLimit", { current: getUsageValue(ORGINIZATIONS), limit: getLimitValue(ORGINIZATIONS) ?? 0 }) || `Current usage (${getUsageValue(ORGINIZATIONS)}) exceeds limit (${getLimitValue(ORGINIZATIONS)})`}</p>
<p>
{t(
"billingUsageExceedsLimit",
{
current:
getUsageValue(
ORGINIZATIONS
),
limit:
getLimitValue(
ORGINIZATIONS
) ?? 0
}
) ||
`Current usage (${getUsageValue(ORGINIZATIONS)}) exceeds limit (${getLimitValue(ORGINIZATIONS)})`}
</p>
</TooltipContent>
</Tooltip>
) : (
@@ -1008,8 +1171,9 @@ export default function BillingPage() {
{getLimitValue(ORGINIZATIONS) ??
t("billingUnlimited") ??
"∞"}{" "}
{getLimitValue(ORGINIZATIONS) !==
null && "orgs"}
{getLimitValue(
ORGINIZATIONS
) !== null && "orgs"}
</>
)}
</InfoSectionContent>
@@ -1024,27 +1188,52 @@ export default function BillingPage() {
<Tooltip>
<TooltipTrigger className="flex items-center gap-1">
<AlertTriangle className="h-3 w-3 text-orange-400" />
<span className={cn(
"text-orange-600 dark:text-orange-400 font-medium"
)}>
{getLimitValue(REMOTE_EXIT_NODES) ??
t("billingUnlimited") ??
<span
className={cn(
"text-orange-600 dark:text-orange-400 font-medium"
)}
>
{getLimitValue(
REMOTE_EXIT_NODES
) ??
t(
"billingUnlimited"
) ??
"∞"}{" "}
{getLimitValue(REMOTE_EXIT_NODES) !==
null && "nodes"}
{getLimitValue(
REMOTE_EXIT_NODES
) !== null && "nodes"}
</span>
</TooltipTrigger>
<TooltipContent>
<p>{t("billingUsageExceedsLimit", { current: getUsageValue(REMOTE_EXIT_NODES), limit: getLimitValue(REMOTE_EXIT_NODES) ?? 0 }) || `Current usage (${getUsageValue(REMOTE_EXIT_NODES)}) exceeds limit (${getLimitValue(REMOTE_EXIT_NODES)})`}</p>
<p>
{t(
"billingUsageExceedsLimit",
{
current:
getUsageValue(
REMOTE_EXIT_NODES
),
limit:
getLimitValue(
REMOTE_EXIT_NODES
) ?? 0
}
) ||
`Current usage (${getUsageValue(REMOTE_EXIT_NODES)}) exceeds limit (${getLimitValue(REMOTE_EXIT_NODES)})`}
</p>
</TooltipContent>
</Tooltip>
) : (
<>
{getLimitValue(REMOTE_EXIT_NODES) ??
{getLimitValue(
REMOTE_EXIT_NODES
) ??
t("billingUnlimited") ??
"∞"}{" "}
{getLimitValue(REMOTE_EXIT_NODES) !==
null && "nodes"}
{getLimitValue(
REMOTE_EXIT_NODES
) !== null && "nodes"}
</>
)}
</InfoSectionContent>
@@ -1072,7 +1261,8 @@ export default function BillingPage() {
<div className="flex flex-col md:flex-row items-start md:items-center justify-between gap-4 border rounded-lg p-4">
<div>
<div className="text-sm text-muted-foreground mb-1">
{t("billingCurrentKeys") || "Current Keys"}
{t("billingCurrentKeys") ||
"Current Keys"}
</div>
<div className="flex items-baseline gap-2">
<span className="text-3xl font-bold">
@@ -1137,61 +1327,101 @@ export default function BillingPage() {
</div>
</div>
{/* Features with check marks */}
{(() => {
const plan = planOptions.find(
(p) =>
p.tierType === pendingTier.tier ||
(pendingTier.tier === "basic" &&
p.id === "basic")
);
return plan?.features?.length ? (
<div>
<h4 className="font-semibold mb-3">
{"What's included:"}
</h4>
<div className="space-y-2">
{plan.features.map(
(feature, i) => (
<div
key={i}
className="flex items-center gap-2"
>
<Check className="h-4 w-4 text-green-600 shrink-0" />
<span>
{feature}
</span>
</div>
)
)}
</div>
</div>
) : null;
})()}
{/* Limits without check marks */}
{tierLimits[pendingTier.tier] && (
<div>
<h4 className="font-semibold mb-3">
{t("billingPlanIncludes") ||
"Plan Includes:"}
{"Up to:"}
</h4>
<div className="space-y-2">
<div className="space-y-2 text-sm text-muted-foreground">
<div className="flex items-center gap-2">
<Check className="h-4 w-4 text-green-600" />
<span className="w-1.5 h-1.5 rounded-full bg-muted-foreground/50 shrink-0" />
<span>
{
tierLimits[pendingTier.tier]
.users
tierLimits[
pendingTier.tier
].users
}{" "}
{t("billingUsers") || "Users"}
{t("billingUsers") ||
"Users"}
</span>
</div>
<div className="flex items-center gap-2">
<Check className="h-4 w-4 text-green-600" />
<span className="w-1.5 h-1.5 rounded-full bg-muted-foreground/50 shrink-0" />
<span>
{
tierLimits[pendingTier.tier]
.sites
tierLimits[
pendingTier.tier
].sites
}{" "}
{t("billingSites") || "Sites"}
{t("billingSites") ||
"Sites"}
</span>
</div>
<div className="flex items-center gap-2">
<Check className="h-4 w-4 text-green-600" />
<span className="w-1.5 h-1.5 rounded-full bg-muted-foreground/50 shrink-0" />
<span>
{
tierLimits[pendingTier.tier]
.domains
tierLimits[
pendingTier.tier
].domains
}{" "}
{t("billingDomains") ||
"Domains"}
</span>
</div>
<div className="flex items-center gap-2">
<Check className="h-4 w-4 text-green-600" />
<span className="w-1.5 h-1.5 rounded-full bg-muted-foreground/50 shrink-0" />
<span>
{
tierLimits[pendingTier.tier]
.organizations
tierLimits[
pendingTier.tier
].organizations
}{" "}
{t("billingOrganizations") ||
"Organizations"}
{t(
"billingOrganizations"
) || "Organizations"}
</span>
</div>
<div className="flex items-center gap-2">
<Check className="h-4 w-4 text-green-600" />
<span className="w-1.5 h-1.5 rounded-full bg-muted-foreground/50 shrink-0" />
<span>
{
tierLimits[pendingTier.tier]
.remoteNodes
tierLimits[
pendingTier.tier
].remoteNodes
}{" "}
{t("billingRemoteNodes") ||
"Remote Nodes"}
@@ -1202,43 +1432,84 @@ export default function BillingPage() {
)}
{/* Warning for limit violations when downgrading */}
{pendingTier.action === "downgrade" && (() => {
const violations = checkLimitViolations(pendingTier.tier);
if (violations.length > 0) {
return (
<Alert variant="destructive">
<AlertTriangle className="h-4 w-4" />
<AlertTitle>
{t("billingLimitViolationWarning") || "Usage Exceeds New Plan Limits"}
</AlertTitle>
<AlertDescription>
<p className="mb-3">
{t("billingLimitViolationDescription") || "Your current usage exceeds the limits of this plan. The following features will be disabled until you reduce usage:"}
</p>
<ul className="space-y-2">
{violations.map((violation, index) => (
<li key={index} className="flex items-center gap-2">
<span className="font-medium">{violation.feature}:</span>
<span>Currently using {violation.currentUsage}, new limit is {violation.newLimit}</span>
</li>
))}
</ul>
</AlertDescription>
</Alert>
{pendingTier.action === "downgrade" &&
(() => {
const violations = checkLimitViolations(
pendingTier.tier
);
}
return null;
})()}
if (violations.length > 0) {
return (
<Alert variant="destructive">
<AlertTriangle className="h-4 w-4" />
<AlertTitle>
{t(
"billingLimitViolationWarning"
) ||
"Usage Exceeds New Plan Limits"}
</AlertTitle>
<AlertDescription>
<p className="mb-3">
{t(
"billingLimitViolationDescription"
) ||
"Your current usage exceeds the limits of this plan. The following features will be disabled until you reduce usage:"}
</p>
<ul className="space-y-2">
{violations.map(
(
violation,
index
) => (
<li
key={
index
}
className="flex items-center gap-2"
>
<span className="font-medium">
{
violation.feature
}
:
</span>
<span>
Currently
using{" "}
{
violation.currentUsage
}
,
new
limit
is{" "}
{
violation.newLimit
}
</span>
</li>
)
)}
</ul>
</AlertDescription>
</Alert>
);
}
return null;
})()}
{/* Warning for feature loss when downgrading */}
{pendingTier.action === "downgrade" && (
<Alert variant="warning">
<AlertTriangle className="h-4 w-4" />
<AlertTitle>
{t("billingFeatureLossWarning") || "Feature Availability Notice"}
{t("billingFeatureLossWarning") ||
"Feature Availability Notice"}
</AlertTitle>
<AlertDescription>
{t("billingFeatureLossDescription") || "By downgrading, features not available in the new plan will be automatically disabled. Some settings and configurations may be lost. Please review the pricing matrix to understand which features will no longer be available."}
{t(
"billingFeatureLossDescription"
) ||
"By downgrading, features not available in the new plan will be automatically disabled. Some settings and configurations may be lost. Please review the pricing matrix to understand which features will no longer be available."}
</AlertDescription>
</Alert>
)}

View File

@@ -69,6 +69,7 @@ export default async function DomainSettingsPage({
failed={domain.failed}
verified={domain.verified}
type={domain.type}
errorMessage={domain.errorMessage}
/>
<DNSRecordsTable records={dnsRecords} type={domain.type} />

View File

@@ -54,6 +54,7 @@ import {
TooltipProvider,
TooltipTrigger
} from "@app/components/ui/tooltip";
import { Alert, AlertDescription, AlertTitle } from "@app/components/ui/alert";
import { useEnvContext } from "@app/hooks/useEnvContext";
import { toast } from "@app/hooks/useToast";
import { createApiClient, formatAxiosError } from "@app/lib/api";
@@ -65,6 +66,7 @@ import { build } from "@server/build";
import { Resource } from "@server/db";
import { isTargetValid } from "@server/lib/validators";
import { ListTargetsResponse } from "@server/routers/target";
import { ListRemoteExitNodesResponse } from "@server/routers/remoteExitNode/types";
import { ArrayElement } from "@server/types/ArrayElement";
import { useQuery } from "@tanstack/react-query";
import {
@@ -81,6 +83,7 @@ import {
CircleCheck,
CircleX,
Info,
InfoIcon,
Plus,
Settings,
SquareArrowOutUpRight
@@ -210,6 +213,13 @@ export default function Page() {
orgQueries.sites({ orgId: orgId as string })
);
const [remoteExitNodes, setRemoteExitNodes] = useState<
ListRemoteExitNodesResponse["remoteExitNodes"]
>([]);
const [loadingExitNodes, setLoadingExitNodes] = useState(
build === "saas"
);
const [createLoading, setCreateLoading] = useState(false);
const [showSnippets, setShowSnippets] = useState(false);
const [niceId, setNiceId] = useState<string>("");
@@ -224,6 +234,27 @@ export default function Page() {
useState<LocalTarget | null>(null);
const [healthCheckDialogOpen, setHealthCheckDialogOpen] = useState(false);
useEffect(() => {
if (build !== "saas") return;
const fetchExitNodes = async () => {
try {
const res = await api.get<
AxiosResponse<ListRemoteExitNodesResponse>
>(`/org/${orgId}/remote-exit-nodes`);
if (res && res.status === 200) {
setRemoteExitNodes(res.data.data.remoteExitNodes);
}
} catch (e) {
console.error("Failed to fetch remote exit nodes:", e);
} finally {
setLoadingExitNodes(false);
}
};
fetchExitNodes();
}, [orgId]);
const [isAdvancedMode, setIsAdvancedMode] = useState(() => {
if (typeof window !== "undefined") {
const saved = localStorage.getItem("create-advanced-mode");
@@ -289,15 +320,25 @@ export default function Page() {
},
...(!env.flags.allowRawResources
? []
: [
{
id: "raw" as ResourceType,
title: t("resourceRaw"),
description: build == "saas" ? t("resourceRawDescriptionCloud") : t("resourceRawDescription")
}
])
: build === "saas" && remoteExitNodes.length === 0
? []
: [
{
id: "raw" as ResourceType,
title: t("resourceRaw"),
description:
build == "saas"
? t("resourceRawDescriptionCloud")
: t("resourceRawDescription")
}
])
];
// In saas mode with no exit nodes, force HTTP
const showTypeSelector =
build !== "saas" ||
(!loadingExitNodes && remoteExitNodes.length > 0);
const baseForm = useForm({
resolver: zodResolver(baseResourceFormSchema),
defaultValues: {
@@ -984,34 +1025,35 @@ export default function Page() {
</SettingsSectionTitle>
</SettingsSectionHeader>
<SettingsSectionBody>
{resourceTypes.length > 1 && (
<>
<div className="mb-2">
<span className="text-sm font-medium">
{t("type")}
</span>
</div>
{showTypeSelector &&
resourceTypes.length > 1 && (
<>
<div className="mb-2">
<span className="text-sm font-medium">
{t("type")}
</span>
</div>
<StrategySelect
options={resourceTypes}
defaultValue="http"
onChange={(value) => {
baseForm.setValue(
"http",
value === "http"
);
// Update method default when switching resource type
addTargetForm.setValue(
"method",
value === "http"
? "http"
: null
);
}}
cols={2}
/>
</>
)}
<StrategySelect
options={resourceTypes}
defaultValue="http"
onChange={(value) => {
baseForm.setValue(
"http",
value === "http"
);
// Update method default when switching resource type
addTargetForm.setValue(
"method",
value === "http"
? "http"
: null
);
}}
cols={2}
/>
</>
)}
<SettingsSectionForm>
<Form {...baseForm}>

View File

@@ -10,17 +10,20 @@ import {
import { useTranslations } from "next-intl";
import { Badge } from "./ui/badge";
import { useEnvContext } from "@app/hooks/useEnvContext";
import { AlertTriangle } from "lucide-react";
type DomainInfoCardProps = {
failed: boolean;
verified: boolean;
type: string | null;
errorMessage?: string | null;
};
export default function DomainInfoCard({
failed,
verified,
type
type,
errorMessage
}: DomainInfoCardProps) {
const t = useTranslations();
const env = useEnvContext();
@@ -39,6 +42,7 @@ export default function DomainInfoCard({
};
return (
<div className="space-y-3">
<Alert>
<AlertDescription>
<InfoSections cols={3}>
@@ -79,5 +83,19 @@ export default function DomainInfoCard({
</InfoSections>
</AlertDescription>
</Alert>
{errorMessage && (failed || !verified) && (
<Alert variant={failed ? "destructive" : "warning"}>
<AlertTriangle className="h-4 w-4" />
<AlertTitle>
{failed
? t("domainErrorTitle", { fallback: "Domain Error" })
: t("domainPendingErrorTitle", { fallback: "Verification Issue" })}
</AlertTitle>
<AlertDescription className="font-mono text-xs break-all">
{errorMessage}
</AlertDescription>
</Alert>
)}
</div>
);
}

View File

@@ -27,6 +27,12 @@ import {
DropdownMenuItem,
DropdownMenuTrigger
} from "./ui/dropdown-menu";
import {
Tooltip,
TooltipContent,
TooltipProvider,
TooltipTrigger
} from "./ui/tooltip";
import Link from "next/link";
export type DomainRow = {
@@ -39,6 +45,7 @@ export type DomainRow = {
configManaged: boolean;
certResolver: string;
preferWildcardCert: boolean;
errorMessage?: string | null;
};
type Props = {
@@ -175,7 +182,7 @@ export default function DomainsTable({ domains, orgId }: Props) {
);
},
cell: ({ row }) => {
const { verified, failed, type } = row.original;
const { verified, failed, type, errorMessage } = row.original;
if (verified) {
return type == "wildcard" ? (
<Badge variant="outlinePrimary">{t("manual")}</Badge>
@@ -183,12 +190,44 @@ export default function DomainsTable({ domains, orgId }: Props) {
<Badge variant="green">{t("verified")}</Badge>
);
} else if (failed) {
if (errorMessage) {
return (
<TooltipProvider>
<Tooltip>
<TooltipTrigger asChild>
<Badge variant="red" className="cursor-help">
{t("failed", { fallback: "Failed" })}
</Badge>
</TooltipTrigger>
<TooltipContent className="max-w-xs">
<p className="break-words">{errorMessage}</p>
</TooltipContent>
</Tooltip>
</TooltipProvider>
);
}
return (
<Badge variant="red">
{t("failed", { fallback: "Failed" })}
</Badge>
);
} else {
if (errorMessage) {
return (
<TooltipProvider>
<Tooltip>
<TooltipTrigger asChild>
<Badge variant="yellow" className="cursor-help">
{t("pending")}
</Badge>
</TooltipTrigger>
<TooltipContent className="max-w-xs">
<p className="break-words">{errorMessage}</p>
</TooltipContent>
</Tooltip>
</TooltipProvider>
);
}
return <Badge variant="yellow">{t("pending")}</Badge>;
}
}

View File

@@ -51,6 +51,7 @@ const docsLinkClassName =
const PANGOLIN_CLOUD_SIGNUP_URL = "https://app.pangolin.net/auth/signup/";
const ENTERPRISE_DOCS_URL =
"https://docs.pangolin.net/self-host/enterprise-edition";
const BOOK_A_DEMO_URL = "https://click.fossorial.io/ep922";
function getTierLinkRenderer(billingHref: string) {
return function tierLinkRenderer(chunks: React.ReactNode) {
@@ -78,6 +79,22 @@ function getPangolinCloudLinkRenderer() {
};
}
function getBookADemoLinkRenderer() {
return function bookADemoLinkRenderer(chunks: React.ReactNode) {
return (
<Link
href={BOOK_A_DEMO_URL}
target="_blank"
rel="noopener noreferrer"
className={docsLinkClassName}
>
{chunks}
<ExternalLink className="size-3.5 shrink-0" />
</Link>
);
};
}
function getDocsLinkRenderer(href: string) {
return function docsLinkRenderer(chunks: React.ReactNode) {
return (
@@ -116,6 +133,7 @@ export function PaidFeaturesAlert({ tiers }: Props) {
const tierLinkRenderer = getTierLinkRenderer(billingHref);
const pangolinCloudLinkRenderer = getPangolinCloudLinkRenderer();
const enterpriseDocsLinkRenderer = getDocsLinkRenderer(ENTERPRISE_DOCS_URL);
const bookADemoLinkRenderer = getBookADemoLinkRenderer();
if (env.flags.disableEnterpriseFeatures) {
return null;
@@ -157,7 +175,8 @@ export function PaidFeaturesAlert({ tiers }: Props) {
{t.rich("licenseRequiredToUse", {
enterpriseLicenseLink:
enterpriseDocsLinkRenderer,
pangolinCloudLink: pangolinCloudLinkRenderer
pangolinCloudLink: pangolinCloudLinkRenderer,
bookADemoLink: bookADemoLinkRenderer
})}
</span>
</div>
@@ -174,7 +193,8 @@ export function PaidFeaturesAlert({ tiers }: Props) {
{t.rich("ossEnterpriseEditionRequired", {
enterpriseEditionLink:
enterpriseDocsLinkRenderer,
pangolinCloudLink: pangolinCloudLinkRenderer
pangolinCloudLink: pangolinCloudLinkRenderer,
bookADemoLink: bookADemoLinkRenderer
})}
</span>
</div>