Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions apps/sim/app/api/speech/token/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import { createLogger } from '@sim/logger'
import { eq } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import { getSession } from '@/lib/auth'
import { hasExceededCostLimit } from '@/lib/billing/core/subscription'
import { checkServerSideUsageLimits } from '@/lib/billing/calculations/usage-monitor'
import { recordUsage } from '@/lib/billing/core/usage-log'
import { env } from '@/lib/core/config/env'
import { getCostMultiplier, isBillingEnabled } from '@/lib/core/config/feature-flags'
Expand Down Expand Up @@ -110,11 +110,14 @@ export async function POST(request: NextRequest) {
}
}

if (billingUserId && isBillingEnabled) {
const exceeded = await hasExceededCostLimit(billingUserId)
if (exceeded) {
if (billingUserId) {
const usageCheck = await checkServerSideUsageLimits(billingUserId)
if (usageCheck.isExceeded) {
return NextResponse.json(
{ error: 'Usage limit exceeded. Please upgrade your plan to continue.' },
{
error:
usageCheck.message || 'Usage limit exceeded. Please upgrade your plan to continue.',
},
{ status: 402 }
)
}
Expand Down
146 changes: 1 addition & 145 deletions apps/sim/lib/billing/core/subscription.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import { db } from '@sim/db'
import { member, subscription, user, userStats } from '@sim/db/schema'
import { member, subscription, user } from '@sim/db/schema'
import { createLogger } from '@sim/logger'
import { and, eq, inArray, sql } from 'drizzle-orm'
import { getEffectiveBillingStatus, isOrganizationBillingBlocked } from '@/lib/billing/core/access'
import { getHighestPrioritySubscription } from '@/lib/billing/core/plan'
import { getUserUsageLimit } from '@/lib/billing/core/usage'
import {
getPlanTierCredits,
isOrgPlan,
isPro as isPlanPro,
isTeam as isPlanTeam,
} from '@/lib/billing/plan-helpers'
Expand All @@ -16,12 +14,9 @@ import {
checkProPlan,
checkTeamPlan,
ENTITLED_SUBSCRIPTION_STATUSES,
getFreeTierLimit,
getPerUserMinimumLimit,
hasUsableSubscriptionAccess,
USABLE_SUBSCRIPTION_STATUSES,
} from '@/lib/billing/subscriptions/utils'
import type { UserSubscriptionState } from '@/lib/billing/types'
import {
isAccessControlEnabled,
isBillingEnabled,
Expand Down Expand Up @@ -485,145 +480,6 @@ export async function hasLiveSyncAccess(userId: string): Promise<boolean> {
}
}

/**
* Check if user has exceeded their cost limit based on current period usage
*/
export async function hasExceededCostLimit(userId: string): Promise<boolean> {
try {
if (!isBillingEnabled) {
return false
}

const subscription = await getHighestPrioritySubscription(userId)

let limit = getFreeTierLimit() // Default free tier limit

if (subscription) {
// Team/Enterprise: Use organization limit
if (isOrgPlan(subscription.plan)) {
limit = await getUserUsageLimit(userId)
logger.info('Using organization limit', {
userId,
plan: subscription.plan,
limit,
})
} else {
// Pro/Free: Use individual limit
limit = getPerUserMinimumLimit(subscription)
logger.info('Using subscription-based limit', {
userId,
plan: subscription.plan,
limit,
})
}
} else {
logger.info('Using free tier limit', { userId, limit })
}

// Get user stats to check current period usage
const statsRecords = await db.select().from(userStats).where(eq(userStats.userId, userId))

if (statsRecords.length === 0) {
return false
}

// Use current period cost instead of total cost for accurate billing period tracking
const currentCost = Number.parseFloat(
statsRecords[0].currentPeriodCost?.toString() || statsRecords[0].totalCost.toString()
)

logger.info('Checking cost limit', { userId, currentCost, limit })

return currentCost >= limit
} catch (error) {
logger.error('Error checking cost limit', { error, userId })
return false // Be conservative in case of error
}
}

/**
* Check if sharing features are enabled for user
*/
// Removed unused feature flag helpers: isSharingEnabled, isMultiplayerEnabled, isWorkspaceCollaborationEnabled

/**
* Get comprehensive subscription state for a user
* Single function to get all subscription information
*/
export async function getUserSubscriptionState(userId: string): Promise<UserSubscriptionState> {
try {
// Get subscription and user stats in parallel to minimize DB calls
const [subscription, statsRecords] = await Promise.all([
getHighestPrioritySubscription(userId),
db.select().from(userStats).where(eq(userStats.userId, userId)).limit(1),
])

// Determine plan types based on subscription (avoid redundant DB calls)
const isPro =
!isBillingEnabled ||
!!(
subscription &&
(checkProPlan(subscription) ||
checkTeamPlan(subscription) ||
checkEnterprisePlan(subscription))
)
const isTeam =
!isBillingEnabled ||
!!(subscription && (checkTeamPlan(subscription) || checkEnterprisePlan(subscription)))
const isEnterprise = !isBillingEnabled || !!(subscription && checkEnterprisePlan(subscription))
const isFree = !isPro && !isTeam && !isEnterprise

// Determine plan name
let planName = 'free'
if (isEnterprise) planName = 'enterprise'
else if (isTeam) planName = 'team'
else if (isPro) planName = 'pro'

// Check cost limit using already-fetched user stats
let hasExceededLimit = false
if (isBillingEnabled && statsRecords.length > 0) {
let limit = getFreeTierLimit() // Default free tier limit
if (subscription) {
// Team/Enterprise: Use organization limit
if (isOrgPlan(subscription.plan)) {
limit = await getUserUsageLimit(userId)
} else {
// Pro/Free: Use individual limit
limit = getPerUserMinimumLimit(subscription)
}
}

const currentCost = Number.parseFloat(
statsRecords[0].currentPeriodCost?.toString() || statsRecords[0].totalCost.toString()
)
hasExceededLimit = currentCost >= limit
}

return {
isPro,
isTeam,
isEnterprise,
isFree,
highestPrioritySubscription: subscription,
hasExceededLimit,
planName,
}
} catch (error) {
logger.error('Error getting user subscription state', { error, userId })

// Return safe defaults in case of error
return {
isPro: false,
isTeam: false,
isEnterprise: false,
isFree: true,
highestPrioritySubscription: null,
hasExceededLimit: false,
planName: 'free',
}
}
}

/**
* Send welcome email for Pro and Team plan subscriptions
*/
Expand Down
1 change: 0 additions & 1 deletion apps/sim/lib/billing/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ export * from '@/lib/billing/core/organization'
export * from '@/lib/billing/core/subscription'
export {
getHighestPrioritySubscription as getActiveSubscription,
getUserSubscriptionState as getSubscriptionState,
hasAccessControlAccess,
hasCredentialSetsAccess,
hasPaidSubscription,
Expand Down
10 changes: 0 additions & 10 deletions apps/sim/lib/billing/types/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -73,16 +73,6 @@ export interface BillingData {
daysRemaining: number
}

export interface UserSubscriptionState {
isPro: boolean
isTeam: boolean
isEnterprise: boolean
isFree: boolean
highestPrioritySubscription: any | null
hasExceededLimit: boolean
planName: string
}

export interface SubscriptionPlan {
name: string
priceId: string
Expand Down
35 changes: 19 additions & 16 deletions apps/sim/lib/copilot/chat-payload.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,27 @@
import { beforeEach, describe, expect, it, vi } from 'vitest'

vi.mock('@sim/logger', () => {
const createMockLogger = (): Record<string, any> => ({
const createMockLogger = () => ({
info: vi.fn(),
warn: vi.fn(),
error: vi.fn(),
withMetadata: vi.fn(() => createMockLogger()),
debug: vi.fn(),
})
return { createLogger: vi.fn(() => createMockLogger()) }
})

const { mockGetHighestPrioritySubscription } = vi.hoisted(() => ({
mockGetHighestPrioritySubscription: vi.fn(),
}))

vi.mock('@/lib/billing/core/subscription', () => ({
getUserSubscriptionState: vi.fn(),
getHighestPrioritySubscription: mockGetHighestPrioritySubscription,
}))

vi.mock('@/lib/billing/plan-helpers', () => ({
isPaid: vi.fn(
(plan: string | null) => plan === 'pro' || plan === 'team' || plan === 'enterprise'
),
}))

vi.mock('@/lib/copilot/chat-context', () => ({
Expand Down Expand Up @@ -57,48 +67,41 @@ vi.mock('@/tools/params', () => ({
createUserToolSchema: vi.fn(() => ({ type: 'object', properties: {} })),
}))

import { getUserSubscriptionState } from '@/lib/billing/core/subscription'
import { buildIntegrationToolSchemas } from '@/lib/copilot/chat-payload'

const mockedGetUserSubscriptionState = getUserSubscriptionState as unknown as {
mockResolvedValue: (value: unknown) => void
mockRejectedValue: (value: unknown) => void
mockClear: () => void
}

describe('buildIntegrationToolSchemas', () => {
beforeEach(() => {
vi.clearAllMocks()
})

it('appends the email footer prompt for free users', async () => {
mockedGetUserSubscriptionState.mockResolvedValue({ isFree: true })
mockGetHighestPrioritySubscription.mockResolvedValue(null)

const toolSchemas = await buildIntegrationToolSchemas('user-free')
const gmailTool = toolSchemas.find((tool) => tool.name === 'gmail_send')

expect(getUserSubscriptionState).toHaveBeenCalledWith('user-free')
expect(mockGetHighestPrioritySubscription).toHaveBeenCalledWith('user-free')
expect(gmailTool?.description).toContain('sent with sim ai')
})

it('does not append the email footer prompt for paid users', async () => {
mockedGetUserSubscriptionState.mockResolvedValue({ isFree: false })
mockGetHighestPrioritySubscription.mockResolvedValue({ plan: 'pro', status: 'active' })

const toolSchemas = await buildIntegrationToolSchemas('user-paid')
const gmailTool = toolSchemas.find((tool) => tool.name === 'gmail_send')

expect(getUserSubscriptionState).toHaveBeenCalledWith('user-paid')
expect(mockGetHighestPrioritySubscription).toHaveBeenCalledWith('user-paid')
expect(gmailTool?.description).toBe('Send emails using Gmail')
})

it('still builds integration tools when subscription lookup fails', async () => {
mockedGetUserSubscriptionState.mockRejectedValue(new Error('db unavailable'))
mockGetHighestPrioritySubscription.mockRejectedValue(new Error('db unavailable'))

const toolSchemas = await buildIntegrationToolSchemas('user-error')
const gmailTool = toolSchemas.find((tool) => tool.name === 'gmail_send')
const brandfetchTool = toolSchemas.find((tool) => tool.name === 'brandfetch_search')

expect(getUserSubscriptionState).toHaveBeenCalledWith('user-error')
expect(mockGetHighestPrioritySubscription).toHaveBeenCalledWith('user-error')
expect(gmailTool?.description).toBe('Send emails using Gmail')
expect(brandfetchTool?.description).toBe('Search for brands by company name')
})
Expand Down
9 changes: 5 additions & 4 deletions apps/sim/lib/copilot/chat-payload.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { createLogger } from '@sim/logger'
import { getUserSubscriptionState } from '@/lib/billing/core/subscription'
import { getHighestPrioritySubscription } from '@/lib/billing/core/subscription'
import { isPaid } from '@/lib/billing/plan-helpers'
import { getCopilotToolDescription } from '@/lib/copilot/tool-descriptions'
import { isHosted } from '@/lib/core/config/feature-flags'
import { createMcpToolId } from '@/lib/mcp/utils'
Expand Down Expand Up @@ -49,7 +50,7 @@
userId: string,
messageId?: string
): Promise<ToolSchema[]> {
const reqLogger = logger.withMetadata({ messageId })

Check failure on line 53 in apps/sim/lib/copilot/chat-payload.ts

View workflow job for this annotation

GitHub Actions / Test and Build / Test and Build

lib/copilot/chat-payload.test.ts > buildIntegrationToolSchemas > still builds integration tools when subscription lookup fails

TypeError: logger.withMetadata is not a function ❯ Module.buildIntegrationToolSchemas lib/copilot/chat-payload.ts:53:28 ❯ lib/copilot/chat-payload.test.ts:100:31

Check failure on line 53 in apps/sim/lib/copilot/chat-payload.ts

View workflow job for this annotation

GitHub Actions / Test and Build / Test and Build

lib/copilot/chat-payload.test.ts > buildIntegrationToolSchemas > does not append the email footer prompt for paid users

TypeError: logger.withMetadata is not a function ❯ Module.buildIntegrationToolSchemas lib/copilot/chat-payload.ts:53:28 ❯ lib/copilot/chat-payload.test.ts:90:31

Check failure on line 53 in apps/sim/lib/copilot/chat-payload.ts

View workflow job for this annotation

GitHub Actions / Test and Build / Test and Build

lib/copilot/chat-payload.test.ts > buildIntegrationToolSchemas > appends the email footer prompt for free users

TypeError: logger.withMetadata is not a function ❯ Module.buildIntegrationToolSchemas lib/copilot/chat-payload.ts:53:28 ❯ lib/copilot/chat-payload.test.ts:80:31
const integrationTools: ToolSchema[] = []
try {
const { createUserToolSchema } = await import('@/tools/params')
Expand All @@ -57,10 +58,10 @@
let shouldAppendEmailTagline = false

try {
const subscriptionState = await getUserSubscriptionState(userId)
shouldAppendEmailTagline = subscriptionState.isFree
const subscription = await getHighestPrioritySubscription(userId)
shouldAppendEmailTagline = !subscription || !isPaid(subscription.plan)
} catch (error) {
reqLogger.warn('Failed to load subscription state for copilot tool descriptions', {
reqLogger.warn('Failed to load subscription for copilot tool descriptions', {
userId,
error: error instanceof Error ? error.message : String(error),
})
Expand Down
7 changes: 4 additions & 3 deletions apps/sim/lib/table/billing.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
*/

import { createLogger } from '@sim/logger'
import { getUserSubscriptionState } from '@/lib/billing/core/subscription'
import { getHighestPrioritySubscription } from '@/lib/billing/core/subscription'
import { getPlanTypeForLimits } from '@/lib/billing/plan-helpers'
import { getWorkspaceBilledAccountUserId } from '@/lib/workspaces/utils'
import { type PlanName, TABLE_PLAN_LIMITS, type TablePlanLimits } from './constants'

Expand All @@ -29,8 +30,8 @@ export async function getWorkspaceTableLimits(workspaceId: string): Promise<Tabl
return TABLE_PLAN_LIMITS.free
}

const subscriptionState = await getUserSubscriptionState(billedAccountUserId)
const planName = subscriptionState.planName as PlanName
const subscription = await getHighestPrioritySubscription(billedAccountUserId)
const planName = getPlanTypeForLimits(subscription?.plan) as PlanName

const limits = TABLE_PLAN_LIMITS[planName] ?? TABLE_PLAN_LIMITS.free

Expand Down
Loading