diff --git a/.gitignore b/.gitignore index 3b6cc969..5d417368 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,5 @@ coverage *.tsbuildinfo .turbo + +.databricks diff --git a/apps/dev-playground/client/src/appKitServingTypes.d.ts b/apps/dev-playground/client/src/appKitServingTypes.d.ts new file mode 100644 index 00000000..28f610b4 --- /dev/null +++ b/apps/dev-playground/client/src/appKitServingTypes.d.ts @@ -0,0 +1,114 @@ +// Auto-generated by AppKit - DO NOT EDIT +// Generated from serving endpoint OpenAPI schemas +import "@databricks/appkit"; +import "@databricks/appkit-ui/react"; + +declare module "@databricks/appkit" { + interface ServingEndpointRegistry { + default: { + request: { + messages?: { + role?: "user" | "assistant"; + content?: string; + }[]; + /** @openapi integer, nullable */ + n?: number | null; + max_tokens?: number; + /** @openapi double, nullable */ + top_p?: number | null; + reasoning_effort?: "low" | "medium" | "high" | null; + /** @openapi double, nullable */ + temperature?: number | null; + stop?: string | string[] | null; + }; + response: { + model?: string; + choices?: { + index?: number; + message?: { + role?: "user" | "assistant"; + content?: string; + }; + finish_reason?: string; + }[]; + usage?: { + prompt_tokens?: number; + completion_tokens?: number; + total_tokens?: number; + } | null; + object?: string; + id?: string; + created?: number; + }; + chunk: { + model?: string; + choices?: { + index?: number; + delta?: { + role?: "user" | "assistant"; + content?: string; + }; + finish_reason?: string | null; + }[]; + object?: string; + id?: string; + created?: number; + }; + }; + } +} + +declare module "@databricks/appkit-ui/react" { + interface ServingEndpointRegistry { + default: { + request: { + messages?: { + role?: "user" | "assistant"; + content?: string; + }[]; + /** @openapi integer, nullable */ + n?: number | null; + max_tokens?: number; + /** @openapi double, nullable */ + top_p?: number | null; + reasoning_effort?: "low" | "medium" | "high" | null; + /** @openapi double, nullable */ + temperature?: number | null; + stop?: string | string[] | null; + }; + response: { + model?: string; + choices?: { + index?: number; + message?: { + role?: "user" | "assistant"; + content?: string; + }; + finish_reason?: string; + }[]; + usage?: { + prompt_tokens?: number; + completion_tokens?: number; + total_tokens?: number; + } | null; + object?: string; + id?: string; + created?: number; + }; + chunk: { + model?: string; + choices?: { + index?: number; + delta?: { + role?: "user" | "assistant"; + content?: string; + }; + finish_reason?: string | null; + }[]; + object?: string; + id?: string; + created?: number; + }; + }; + } +} diff --git a/docs/docs/api/appkit/Function.appKitServingTypesPlugin.md b/docs/docs/api/appkit/Function.appKitServingTypesPlugin.md new file mode 100644 index 00000000..e53e5bd3 --- /dev/null +++ b/docs/docs/api/appkit/Function.appKitServingTypesPlugin.md @@ -0,0 +1,24 @@ +# Function: appKitServingTypesPlugin() + +```ts +function appKitServingTypesPlugin(options?: AppKitServingTypesPluginOptions): Plugin$1; +``` + +Vite plugin to generate TypeScript types for AppKit serving endpoints. +Fetches OpenAPI schemas from Databricks and generates a .d.ts with +ServingEndpointRegistry module augmentation. + +Endpoint discovery order: +1. Explicit `endpoints` option (override) +2. AST extraction from server file (server/index.ts or server/server.ts) +3. DATABRICKS_SERVING_ENDPOINT_NAME env var (single default endpoint) + +## Parameters + +| Parameter | Type | +| ------ | ------ | +| `options?` | `AppKitServingTypesPluginOptions` | + +## Returns + +`Plugin$1` diff --git a/docs/docs/api/appkit/Function.extractServingEndpoints.md b/docs/docs/api/appkit/Function.extractServingEndpoints.md new file mode 100644 index 00000000..24a5b00d --- /dev/null +++ b/docs/docs/api/appkit/Function.extractServingEndpoints.md @@ -0,0 +1,24 @@ +# Function: extractServingEndpoints() + +```ts +function extractServingEndpoints(serverFilePath: string): + | Record + | null; +``` + +Extract serving endpoint config from a server file by AST-parsing it. +Looks for `serving({ endpoints: { alias: { env: "..." }, ... } })` calls +and extracts the endpoint alias names and their environment variable mappings. + +## Parameters + +| Parameter | Type | Description | +| ------ | ------ | ------ | +| `serverFilePath` | `string` | Absolute path to the server entry file | + +## Returns + + \| `Record`\<`string`, [`EndpointConfig`](Interface.EndpointConfig.md)\> + \| `null` + +Extracted endpoint config, or null if not found or not extractable diff --git a/docs/docs/api/appkit/Function.findServerFile.md b/docs/docs/api/appkit/Function.findServerFile.md new file mode 100644 index 00000000..2ed4e268 --- /dev/null +++ b/docs/docs/api/appkit/Function.findServerFile.md @@ -0,0 +1,19 @@ +# Function: findServerFile() + +```ts +function findServerFile(basePath: string): string | null; +``` + +Find the server entry file by checking candidate paths in order. + +## Parameters + +| Parameter | Type | Description | +| ------ | ------ | ------ | +| `basePath` | `string` | Project root directory to search from | + +## Returns + +`string` \| `null` + +Absolute path to the server file, or null if none found diff --git a/docs/docs/api/appkit/index.md b/docs/docs/api/appkit/index.md index f4685e04..faadf237 100644 --- a/docs/docs/api/appkit/index.md +++ b/docs/docs/api/appkit/index.md @@ -70,9 +70,12 @@ plugin architecture, and React integration. | Function | Description | | ------ | ------ | +| [appKitServingTypesPlugin](Function.appKitServingTypesPlugin.md) | Vite plugin to generate TypeScript types for AppKit serving endpoints. Fetches OpenAPI schemas from Databricks and generates a .d.ts with ServingEndpointRegistry module augmentation. | | [appKitTypesPlugin](Function.appKitTypesPlugin.md) | Vite plugin to generate types for AppKit queries. Calls generateFromEntryPoint under the hood. | | [createApp](Function.createApp.md) | Bootstraps AppKit with the provided configuration. | | [createLakebasePool](Function.createLakebasePool.md) | Create a Lakebase pool with appkit's logger integration. Telemetry automatically uses appkit's OpenTelemetry configuration via global registry. | +| [extractServingEndpoints](Function.extractServingEndpoints.md) | Extract serving endpoint config from a server file by AST-parsing it. Looks for `serving({ endpoints: { alias: { env: "..." }, ... } })` calls and extracts the endpoint alias names and their environment variable mappings. | +| [findServerFile](Function.findServerFile.md) | Find the server entry file by checking candidate paths in order. | | [generateDatabaseCredential](Function.generateDatabaseCredential.md) | Generate OAuth credentials for Postgres database connection using the proper Postgres API. | | [getExecutionContext](Function.getExecutionContext.md) | Get the current execution context. | | [getLakebaseOrmConfig](Function.getLakebaseOrmConfig.md) | Get Lakebase connection configuration for ORMs that don't accept pg.Pool directly. | diff --git a/docs/docs/api/appkit/typedoc-sidebar.ts b/docs/docs/api/appkit/typedoc-sidebar.ts index 91815e3d..1d498d1a 100644 --- a/docs/docs/api/appkit/typedoc-sidebar.ts +++ b/docs/docs/api/appkit/typedoc-sidebar.ts @@ -225,6 +225,11 @@ const typedocSidebar: SidebarsConfig = { type: "category", label: "Functions", items: [ + { + type: "doc", + id: "api/appkit/Function.appKitServingTypesPlugin", + label: "appKitServingTypesPlugin" + }, { type: "doc", id: "api/appkit/Function.appKitTypesPlugin", @@ -240,6 +245,16 @@ const typedocSidebar: SidebarsConfig = { id: "api/appkit/Function.createLakebasePool", label: "createLakebasePool" }, + { + type: "doc", + id: "api/appkit/Function.extractServingEndpoints", + label: "extractServingEndpoints" + }, + { + type: "doc", + id: "api/appkit/Function.findServerFile", + label: "findServerFile" + }, { type: "doc", id: "api/appkit/Function.generateDatabaseCredential", diff --git a/packages/appkit-ui/src/react/hooks/__tests__/use-serving-invoke.test.ts b/packages/appkit-ui/src/react/hooks/__tests__/use-serving-invoke.test.ts new file mode 100644 index 00000000..b4f5bab0 --- /dev/null +++ b/packages/appkit-ui/src/react/hooks/__tests__/use-serving-invoke.test.ts @@ -0,0 +1,209 @@ +import { act, renderHook, waitFor } from "@testing-library/react"; +import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; + +const mockUsePluginClientConfig = vi + .fn() + .mockReturnValue({ isNamedMode: false, aliases: ["default"] }); + +vi.mock("../use-plugin-config", () => ({ + usePluginClientConfig: (...args: unknown[]) => + mockUsePluginClientConfig(...args), +})); + +import { useServingInvoke } from "../use-serving-invoke"; + +describe("useServingInvoke", () => { + beforeEach(() => { + mockUsePluginClientConfig.mockReturnValue({ + isNamedMode: false, + aliases: ["default"], + }); + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify({ choices: [] }), { status: 200 }), + ); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + test("initial state is idle", () => { + const { result } = renderHook(() => useServingInvoke({ messages: [] })); + + expect(result.current.data).toBeNull(); + expect(result.current.loading).toBe(false); + expect(result.current.error).toBeNull(); + expect(typeof result.current.invoke).toBe("function"); + }); + + test("calls fetch to correct URL on invoke", async () => { + const fetchSpy = vi.spyOn(globalThis, "fetch"); + + const { result } = renderHook(() => + useServingInvoke({ messages: [{ role: "user", content: "Hello" }] }), + ); + + act(() => { + result.current.invoke(); + }); + + await waitFor(() => { + expect(fetchSpy).toHaveBeenCalledWith( + "/api/serving/invoke", + expect.objectContaining({ + method: "POST", + body: JSON.stringify({ + messages: [{ role: "user", content: "Hello" }], + }), + }), + ); + }); + }); + + test("uses alias in URL when provided", async () => { + mockUsePluginClientConfig.mockReturnValue({ + isNamedMode: true, + aliases: ["llm", "embedder"], + }); + const fetchSpy = vi.spyOn(globalThis, "fetch"); + + const { result } = renderHook(() => + useServingInvoke({ messages: [] }, { alias: "llm" }), + ); + + act(() => { + result.current.invoke(); + }); + + await waitFor(() => { + expect(fetchSpy).toHaveBeenCalledWith( + "/api/serving/llm/invoke", + expect.any(Object), + ); + }); + }); + + test("sets error for unknown alias", () => { + mockUsePluginClientConfig.mockReturnValue({ + isNamedMode: true, + aliases: ["llm", "embedder"], + }); + + const { result } = renderHook(() => + useServingInvoke({ messages: [] }, { alias: "unknown" as any }), + ); + + expect(result.current.error).toBe( + 'Unknown serving alias "unknown". Available: llm, embedder', + ); + }); + + test("invoke returns null for unknown alias without calling fetch", async () => { + mockUsePluginClientConfig.mockReturnValue({ + isNamedMode: true, + aliases: ["llm"], + }); + const fetchSpy = vi.spyOn(globalThis, "fetch"); + + const { result } = renderHook(() => + useServingInvoke({ messages: [] }, { alias: "bad" as any }), + ); + + let returnValue: unknown; + act(() => { + returnValue = result.current.invoke(); + }); + + expect(await returnValue).toBeNull(); + expect(fetchSpy).not.toHaveBeenCalled(); + }); + + test("sets data on successful response", async () => { + const responseData = { + choices: [{ message: { content: "Hi" } }], + }; + + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify(responseData), { status: 200 }), + ); + + const { result } = renderHook(() => useServingInvoke({ messages: [] })); + + act(() => { + result.current.invoke(); + }); + + await waitFor(() => { + expect(result.current.data).toEqual(responseData); + expect(result.current.loading).toBe(false); + }); + }); + + test("sets error on failed response", async () => { + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(JSON.stringify({ error: "Not found" }), { status: 404 }), + ); + + const { result } = renderHook(() => useServingInvoke({ messages: [] })); + + await act(async () => { + result.current.invoke(); + // Wait for the fetch promise chain to resolve + await new Promise((r) => setTimeout(r, 10)); + }); + + await waitFor(() => { + expect(result.current.error).toBe("Not found"); + expect(result.current.loading).toBe(false); + }); + }); + + test("sets error with HTTP status on non-JSON error response", async () => { + vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response("Not Found", { + status: 404, + headers: { "Content-Type": "text/html" }, + }), + ); + + const { result } = renderHook(() => useServingInvoke({ messages: [] })); + + await act(async () => { + result.current.invoke(); + await new Promise((r) => setTimeout(r, 10)); + }); + + await waitFor(() => { + expect(result.current.error).toBe("HTTP 404"); + expect(result.current.loading).toBe(false); + }); + }); + + test("sets error on fetch network failure", async () => { + vi.spyOn(globalThis, "fetch").mockRejectedValue( + new Error("Network timeout"), + ); + + const { result } = renderHook(() => useServingInvoke({ messages: [] })); + + await act(async () => { + result.current.invoke(); + await new Promise((r) => setTimeout(r, 10)); + }); + + await waitFor(() => { + expect(result.current.error).toBe("Network timeout"); + expect(result.current.loading).toBe(false); + }); + }); + + test("auto starts when autoStart is true", async () => { + const fetchSpy = vi.spyOn(globalThis, "fetch"); + + renderHook(() => useServingInvoke({ messages: [] }, { autoStart: true })); + + await waitFor(() => { + expect(fetchSpy).toHaveBeenCalled(); + }); + }); +}); diff --git a/packages/appkit-ui/src/react/hooks/__tests__/use-serving-stream.test.ts b/packages/appkit-ui/src/react/hooks/__tests__/use-serving-stream.test.ts new file mode 100644 index 00000000..c9f27ca9 --- /dev/null +++ b/packages/appkit-ui/src/react/hooks/__tests__/use-serving-stream.test.ts @@ -0,0 +1,358 @@ +import { act, renderHook, waitFor } from "@testing-library/react"; +import { afterEach, describe, expect, test, vi } from "vitest"; + +// Mock connectSSE — capture callbacks so we can simulate SSE events +let capturedCallbacks: { + onMessage?: (msg: { data: string }) => void; + onError?: (err: Error) => void; + signal?: AbortSignal; +} = {}; + +let resolveStream: (() => void) | null = null; + +const mockConnectSSE = vi.fn().mockImplementation((opts: any) => { + capturedCallbacks = { + onMessage: opts.onMessage, + onError: opts.onError, + signal: opts.signal, + }; + return new Promise((resolve) => { + resolveStream = resolve; + // Also resolve after a tick as fallback for tests that don't manually resolve + setTimeout(resolve, 0); + }); +}); + +const mockUsePluginClientConfig = vi + .fn() + .mockReturnValue({ isNamedMode: false, aliases: ["default"] }); + +vi.mock("@/js", () => ({ + connectSSE: (...args: unknown[]) => mockConnectSSE(...args), +})); + +vi.mock("../use-plugin-config", () => ({ + usePluginClientConfig: (...args: unknown[]) => + mockUsePluginClientConfig(...args), +})); + +import { useServingStream } from "../use-serving-stream"; + +describe("useServingStream", () => { + afterEach(() => { + capturedCallbacks = {}; + resolveStream = null; + vi.clearAllMocks(); + }); + + test("initial state is idle", () => { + const { result } = renderHook(() => useServingStream({ messages: [] })); + + expect(result.current.chunks).toEqual([]); + expect(result.current.streaming).toBe(false); + expect(result.current.error).toBeNull(); + expect(typeof result.current.stream).toBe("function"); + expect(typeof result.current.reset).toBe("function"); + }); + + test("calls connectSSE with correct URL on stream", () => { + const { result } = renderHook(() => useServingStream({ messages: [] })); + + act(() => { + result.current.stream(); + }); + + expect(mockConnectSSE).toHaveBeenCalledWith( + expect.objectContaining({ + url: "/api/serving/stream", + payload: JSON.stringify({ messages: [] }), + }), + ); + }); + + test("uses override body when passed to stream()", () => { + const { result } = renderHook(() => + useServingStream({ messages: [{ role: "user", content: "old" }] }), + ); + + const overrideBody = { + messages: [{ role: "user" as const, content: "new" }], + }; + + act(() => { + result.current.stream(overrideBody); + }); + + expect(mockConnectSSE).toHaveBeenCalledWith( + expect.objectContaining({ + payload: JSON.stringify(overrideBody), + }), + ); + }); + + test("uses alias in URL when provided", () => { + mockUsePluginClientConfig.mockReturnValue({ + isNamedMode: true, + aliases: ["embedder", "llm"], + }); + const { result } = renderHook(() => + useServingStream({ messages: [] }, { alias: "embedder" }), + ); + + act(() => { + result.current.stream(); + }); + + expect(mockConnectSSE).toHaveBeenCalledWith( + expect.objectContaining({ + url: "/api/serving/embedder/stream", + }), + ); + }); + + test("sets error for unknown alias", () => { + mockUsePluginClientConfig.mockReturnValue({ + isNamedMode: true, + aliases: ["llm", "embedder"], + }); + + const { result } = renderHook(() => + useServingStream({ messages: [] }, { alias: "unknown" as any }), + ); + + expect(result.current.error).toBe( + 'Unknown serving alias "unknown". Available: llm, embedder', + ); + }); + + test("stream does not call connectSSE for unknown alias", () => { + mockUsePluginClientConfig.mockReturnValue({ + isNamedMode: true, + aliases: ["llm"], + }); + + const { result } = renderHook(() => + useServingStream({ messages: [] }, { alias: "bad" as any }), + ); + + act(() => { + result.current.stream(); + }); + + expect(mockConnectSSE).not.toHaveBeenCalled(); + }); + + test("sets streaming to true when stream() is called", () => { + const { result } = renderHook(() => useServingStream({ messages: [] })); + + act(() => { + result.current.stream(); + }); + + expect(result.current.streaming).toBe(true); + }); + + test("accumulates chunks from onMessage", async () => { + const { result } = renderHook(() => useServingStream({ messages: [] })); + + act(() => { + result.current.stream(); + }); + + act(() => { + capturedCallbacks.onMessage?.({ data: JSON.stringify({ id: 1 }) }); + }); + + act(() => { + capturedCallbacks.onMessage?.({ data: JSON.stringify({ id: 2 }) }); + }); + + expect(result.current.chunks).toEqual([{ id: 1 }, { id: 2 }]); + }); + + test("accumulates chunks with error field as normal data", async () => { + const { result } = renderHook(() => useServingStream({ messages: [] })); + + act(() => { + result.current.stream(); + }); + + act(() => { + capturedCallbacks.onMessage?.({ + data: JSON.stringify({ error: "Model overloaded" }), + }); + }); + + // Chunks with an `error` field are treated as data, not stream errors. + // Transport-level errors are delivered via onError callback instead. + expect(result.current.chunks).toEqual([{ error: "Model overloaded" }]); + expect(result.current.error).toBeNull(); + expect(result.current.streaming).toBe(true); + }); + + test("sets error from onError callback", async () => { + const { result } = renderHook(() => useServingStream({ messages: [] })); + + act(() => { + result.current.stream(); + }); + + act(() => { + capturedCallbacks.onError?.(new Error("Connection lost")); + }); + + expect(result.current.error).toBe("Connection lost"); + expect(result.current.streaming).toBe(false); + }); + + test("silently skips malformed JSON messages", () => { + const { result } = renderHook(() => useServingStream({ messages: [] })); + + act(() => { + result.current.stream(); + }); + + act(() => { + capturedCallbacks.onMessage?.({ data: "not valid json{" }); + }); + + // No chunks added, no error set + expect(result.current.chunks).toEqual([]); + expect(result.current.error).toBeNull(); + }); + + test("reset() clears state and aborts active stream", () => { + const { result } = renderHook(() => useServingStream({ messages: [] })); + + act(() => { + result.current.stream(); + }); + + act(() => { + capturedCallbacks.onMessage?.({ data: JSON.stringify({ id: 1 }) }); + }); + + expect(result.current.chunks).toHaveLength(1); + expect(result.current.streaming).toBe(true); + + act(() => { + result.current.reset(); + }); + + expect(result.current.chunks).toEqual([]); + expect(result.current.streaming).toBe(false); + expect(result.current.error).toBeNull(); + }); + + test("autoStart triggers stream on mount", async () => { + renderHook(() => useServingStream({ messages: [] }, { autoStart: true })); + + await waitFor(() => { + expect(mockConnectSSE).toHaveBeenCalled(); + }); + }); + + test("passes abort signal to connectSSE", () => { + const { result } = renderHook(() => useServingStream({ messages: [] })); + + act(() => { + result.current.stream(); + }); + + expect(capturedCallbacks.signal).toBeDefined(); + expect(capturedCallbacks.signal?.aborted).toBe(false); + }); + + test("aborts stream on unmount", () => { + const { result, unmount } = renderHook(() => + useServingStream({ messages: [] }), + ); + + act(() => { + result.current.stream(); + }); + + const signal = capturedCallbacks.signal; + expect(signal?.aborted).toBe(false); + + unmount(); + + expect(signal?.aborted).toBe(true); + }); + + test("sets streaming to false when connectSSE resolves", async () => { + const { result } = renderHook(() => useServingStream({ messages: [] })); + + act(() => { + result.current.stream(); + }); + + await waitFor(() => { + expect(result.current.streaming).toBe(false); + }); + }); + + test("sets error when connectSSE promise rejects", async () => { + mockConnectSSE.mockImplementationOnce((opts: any) => { + capturedCallbacks = { + onMessage: opts.onMessage, + onError: opts.onError, + signal: opts.signal, + }; + return Promise.reject(new Error("Network failure")); + }); + + const { result } = renderHook(() => useServingStream({ messages: [] })); + + act(() => { + result.current.stream(); + }); + + await waitFor(() => { + expect(result.current.error).toBe("Connection error"); + expect(result.current.streaming).toBe(false); + }); + }); + + test("calls onComplete with accumulated chunks when stream finishes", async () => { + const onComplete = vi.fn(); + + // Use a controllable mock so stream doesn't auto-resolve + mockConnectSSE.mockImplementationOnce((opts: any) => { + capturedCallbacks = { + onMessage: opts.onMessage, + onError: opts.onError, + signal: opts.signal, + }; + return new Promise((resolve) => { + resolveStream = resolve; + }); + }); + + const { result } = renderHook(() => + useServingStream({ messages: [] }, { onComplete }), + ); + + act(() => { + result.current.stream(); + }); + + // Send two chunks + act(() => { + capturedCallbacks.onMessage?.({ data: JSON.stringify({ id: 1 }) }); + }); + act(() => { + capturedCallbacks.onMessage?.({ data: JSON.stringify({ id: 2 }) }); + }); + + expect(onComplete).not.toHaveBeenCalled(); + + // Complete the stream + await act(async () => { + resolveStream?.(); + await new Promise((r) => setTimeout(r, 0)); + }); + + expect(onComplete).toHaveBeenCalledWith([{ id: 1 }, { id: 2 }]); + }); +}); diff --git a/packages/appkit-ui/src/react/hooks/index.ts b/packages/appkit-ui/src/react/hooks/index.ts index 84d51b53..a425b010 100644 --- a/packages/appkit-ui/src/react/hooks/index.ts +++ b/packages/appkit-ui/src/react/hooks/index.ts @@ -2,8 +2,13 @@ export type { AnalyticsFormat, InferResultByFormat, InferRowType, + InferServingChunk, + InferServingRequest, + InferServingResponse, PluginRegistry, QueryRegistry, + ServingAlias, + ServingEndpointRegistry, TypedArrowTable, UseAnalyticsQueryOptions, UseAnalyticsQueryResult, @@ -15,3 +20,13 @@ export { useChartData, } from "./use-chart-data"; export { usePluginClientConfig } from "./use-plugin-config"; +export { + type UseServingInvokeOptions, + type UseServingInvokeResult, + useServingInvoke, +} from "./use-serving-invoke"; +export { + type UseServingStreamOptions, + type UseServingStreamResult, + useServingStream, +} from "./use-serving-stream"; diff --git a/packages/appkit-ui/src/react/hooks/types.ts b/packages/appkit-ui/src/react/hooks/types.ts index 5db725fc..bd5a7dc2 100644 --- a/packages/appkit-ui/src/react/hooks/types.ts +++ b/packages/appkit-ui/src/react/hooks/types.ts @@ -134,3 +134,59 @@ export type InferParams = K extends AugmentedRegistry export interface PluginRegistry { [key: string]: Record; } + +export interface ServingClientConfig { + isNamedMode: boolean; + aliases: string[]; +} + +// ============================================================================ +// Serving Endpoint Registry +// ============================================================================ + +/** + * Serving endpoint registry for type-safe alias names. + * Extend this interface via module augmentation to get alias autocomplete: + * + * @example + * ```typescript + * // Auto-generated by appKitServingTypesPlugin() + * declare module "@databricks/appkit-ui/react" { + * interface ServingEndpointRegistry { + * llm: { request: {...}; response: {...}; chunk: {...} }; + * } + * } + * ``` + */ +// biome-ignore lint/suspicious/noEmptyInterface: intentionally empty — populated via module augmentation +export interface ServingEndpointRegistry {} + +/** Resolves to registry keys if populated, otherwise string */ +export type ServingAlias = + AugmentedRegistry extends never + ? string + : AugmentedRegistry; + +/** Infers chunk type from registry when alias is a known key */ +export type InferServingChunk = + K extends AugmentedRegistry + ? ServingEndpointRegistry[K] extends { chunk: infer C } + ? C + : unknown + : unknown; + +/** Infers response type from registry when alias is a known key */ +export type InferServingResponse = + K extends AugmentedRegistry + ? ServingEndpointRegistry[K] extends { response: infer R } + ? R + : unknown + : unknown; + +/** Infers request type from registry when alias is a known key */ +export type InferServingRequest = + K extends AugmentedRegistry + ? ServingEndpointRegistry[K] extends { request: infer Req } + ? Req + : Record + : Record; diff --git a/packages/appkit-ui/src/react/hooks/use-serving-invoke.ts b/packages/appkit-ui/src/react/hooks/use-serving-invoke.ts new file mode 100644 index 00000000..ea44f280 --- /dev/null +++ b/packages/appkit-ui/src/react/hooks/use-serving-invoke.ts @@ -0,0 +1,129 @@ +import { useCallback, useEffect, useMemo, useRef, useState } from "react"; +import type { + InferServingRequest, + InferServingResponse, + ServingAlias, + ServingClientConfig, +} from "./types"; +import { usePluginClientConfig } from "./use-plugin-config"; + +export interface UseServingInvokeOptions< + K extends ServingAlias = ServingAlias, +> { + /** Endpoint alias for named mode. Omit for default mode. */ + alias?: K; + /** If false, does not invoke automatically on mount. Default: false */ + autoStart?: boolean; +} + +export interface UseServingInvokeResult< + T = unknown, + TBody = Record, +> { + /** Trigger the invocation. Pass an optional body override for this invocation. */ + invoke: (overrideBody?: TBody) => Promise; + /** Response data, null until loaded. */ + data: T | null; + /** Whether a request is in progress. */ + loading: boolean; + /** Error message, if any. */ + error: string | null; +} + +/** + * Hook for non-streaming invocation of a serving endpoint. + * Calls `POST /api/serving/invoke` (default) or `POST /api/serving/{alias}/invoke` (named). + * + * When the type generator has populated `ServingEndpointRegistry`, the response type + * is automatically inferred from the endpoint's OpenAPI schema. + */ +export function useServingInvoke( + body: InferServingRequest, + options: UseServingInvokeOptions = {} as UseServingInvokeOptions, +): UseServingInvokeResult, InferServingRequest> { + type TResponse = InferServingResponse; + const { alias, autoStart = false } = options; + + const config = usePluginClientConfig("serving"); + + const aliasError = useMemo(() => { + if (!alias || !config.aliases) return null; + const aliasStr = String(alias); + if (!config.aliases.includes(aliasStr)) { + return `Unknown serving alias "${aliasStr}". Available: ${config.aliases.join(", ")}`; + } + return null; + }, [alias, config.aliases]); + + const [data, setData] = useState(null); + const [loading, setLoading] = useState(false); + const [error, setError] = useState(aliasError); + const abortControllerRef = useRef(null); + + const urlSuffix = alias + ? `/api/serving/${encodeURIComponent(String(alias))}/invoke` + : "/api/serving/invoke"; + + const bodyJson = JSON.stringify(body); + + const invoke = useCallback( + (overrideBody?: InferServingRequest): Promise => { + if (aliasError) { + setError(aliasError); + return Promise.resolve(null); + } + + if (abortControllerRef.current) { + abortControllerRef.current.abort(); + } + + setLoading(true); + setError(null); + setData(null); + + const abortController = new AbortController(); + abortControllerRef.current = abortController; + + const payload = overrideBody ? JSON.stringify(overrideBody) : bodyJson; + + return fetch(urlSuffix, { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: payload, + signal: abortController.signal, + }) + .then(async (res) => { + if (!res.ok) { + const errorBody = await res.json().catch(() => null); + throw new Error(errorBody?.error || `HTTP ${res.status}`); + } + return res.json(); + }) + .then((result: TResponse) => { + if (abortController.signal.aborted) return null; + setData(result); + setLoading(false); + return result; + }) + .catch((err: Error) => { + if (abortController.signal.aborted) return null; + setError(err.message || "Request failed"); + setLoading(false); + return null; + }); + }, + [urlSuffix, bodyJson, aliasError], + ); + + useEffect(() => { + if (autoStart) { + invoke(); + } + + return () => { + abortControllerRef.current?.abort(); + }; + }, [invoke, autoStart]); + + return { invoke, data, loading, error }; +} diff --git a/packages/appkit-ui/src/react/hooks/use-serving-stream.ts b/packages/appkit-ui/src/react/hooks/use-serving-stream.ts new file mode 100644 index 00000000..741bcf97 --- /dev/null +++ b/packages/appkit-ui/src/react/hooks/use-serving-stream.ts @@ -0,0 +1,155 @@ +import { useCallback, useEffect, useMemo, useRef, useState } from "react"; +import { connectSSE } from "@/js"; +import type { + InferServingChunk, + InferServingRequest, + ServingAlias, + ServingClientConfig, +} from "./types"; +import { usePluginClientConfig } from "./use-plugin-config"; + +export interface UseServingStreamOptions< + K extends ServingAlias = ServingAlias, + T = InferServingChunk, +> { + /** Endpoint alias for named mode. Omit for default mode. */ + alias?: K; + /** If true, starts streaming automatically on mount. Default: false */ + autoStart?: boolean; + /** Called with accumulated chunks when the stream completes successfully. */ + onComplete?: (chunks: T[]) => void; +} + +export interface UseServingStreamResult< + T = unknown, + TBody = Record, +> { + /** Trigger the streaming invocation. Pass an optional body override for this invocation. */ + stream: (overrideBody?: TBody) => void; + /** Accumulated chunks received so far. */ + chunks: T[]; + /** Whether streaming is in progress. */ + streaming: boolean; + /** Error message, if any. */ + error: string | null; + /** Reset chunks and abort any active stream. */ + reset: () => void; +} + +/** + * Hook for streaming invocation of a serving endpoint via SSE. + * Calls `POST /api/serving/stream` (default) or `POST /api/serving/{alias}/stream` (named). + * Accumulates parsed chunks in state. + * + * When the type generator has populated `ServingEndpointRegistry`, the chunk type + * is automatically inferred from the endpoint's OpenAPI schema. + */ +export function useServingStream( + body: InferServingRequest, + options: UseServingStreamOptions = {} as UseServingStreamOptions, +): UseServingStreamResult, InferServingRequest> { + type TChunk = InferServingChunk; + const { alias, autoStart = false, onComplete } = options; + + const config = usePluginClientConfig("serving"); + + const aliasError = useMemo(() => { + if (!alias || !config.aliases) return null; + const aliasStr = String(alias); + if (!config.aliases.includes(aliasStr)) { + return `Unknown serving alias "${aliasStr}". Available: ${config.aliases.join(", ")}`; + } + return null; + }, [alias, config.aliases]); + + const [chunks, setChunks] = useState([]); + const [streaming, setStreaming] = useState(false); + const [error, setError] = useState(aliasError); + const abortControllerRef = useRef(null); + const chunksRef = useRef([]); + const onCompleteRef = useRef(onComplete); + onCompleteRef.current = onComplete; + + const urlSuffix = alias + ? `/api/serving/${encodeURIComponent(String(alias))}/stream` + : "/api/serving/stream"; + + const reset = useCallback(() => { + abortControllerRef.current?.abort(); + abortControllerRef.current = null; + chunksRef.current = []; + setChunks([]); + setStreaming(false); + setError(null); + }, []); + + const bodyJson = JSON.stringify(body); + + const stream = useCallback( + (overrideBody?: InferServingRequest) => { + if (aliasError) { + setError(aliasError); + return; + } + + // Abort any existing stream + abortControllerRef.current?.abort(); + + setStreaming(true); + setError(null); + setChunks([]); + chunksRef.current = []; + + const abortController = new AbortController(); + abortControllerRef.current = abortController; + + const payload = overrideBody ? JSON.stringify(overrideBody) : bodyJson; + + connectSSE({ + url: urlSuffix, + payload, + signal: abortController.signal, + onMessage: async (message) => { + if (abortController.signal.aborted) return; + try { + const parsed = JSON.parse(message.data); + + chunksRef.current = [...chunksRef.current, parsed as TChunk]; + setChunks(chunksRef.current); + } catch { + // Skip malformed messages + } + }, + onError: (err) => { + if (abortController.signal.aborted) return; + setStreaming(false); + setError(err instanceof Error ? err.message : "Streaming failed"); + }, + }) + .then(() => { + if (abortController.signal.aborted) return; + // Stream completed + setStreaming(false); + onCompleteRef.current?.(chunksRef.current); + }) + .catch(() => { + if (abortController.signal.aborted) return; + setStreaming(false); + setError("Connection error"); + }); + }, + [urlSuffix, bodyJson, aliasError], + ); + + useEffect(() => { + if (autoStart) { + stream(); + } + + return () => { + abortControllerRef.current?.abort(); + }; + }, [stream, autoStart]); + + return { stream, chunks, streaming, error, reset }; +} diff --git a/packages/appkit/package.json b/packages/appkit/package.json index 0613ec51..64166c4c 100644 --- a/packages/appkit/package.json +++ b/packages/appkit/package.json @@ -50,6 +50,7 @@ "typecheck": "tsc --noEmit" }, "dependencies": { + "@ast-grep/napi": "0.37.0", "@databricks/lakebase": "workspace:*", "@databricks/sdk-experimental": "0.16.0", "@opentelemetry/api": "1.9.0", diff --git a/packages/appkit/src/index.ts b/packages/appkit/src/index.ts index 662a9178..3df5572b 100644 --- a/packages/appkit/src/index.ts +++ b/packages/appkit/src/index.ts @@ -81,6 +81,10 @@ export { SpanStatusCode, type TelemetryConfig, } from "./telemetry"; - +export { + extractServingEndpoints, + findServerFile, +} from "./type-generator/serving/server-file-extractor"; +export { appKitServingTypesPlugin } from "./type-generator/serving/vite-plugin"; // Vite plugin and type generation export { appKitTypesPlugin } from "./type-generator/vite-plugin"; diff --git a/packages/appkit/src/plugins/server/vite-dev-server.ts b/packages/appkit/src/plugins/server/vite-dev-server.ts index 65182d15..1edaa070 100644 --- a/packages/appkit/src/plugins/server/vite-dev-server.ts +++ b/packages/appkit/src/plugins/server/vite-dev-server.ts @@ -5,6 +5,7 @@ import type { ViteDevServer as ViteDevServerType } from "vite"; import { mergeConfigDedup } from "@/utils"; import { ServerError } from "../../errors"; import { createLogger } from "../../logging/logger"; +import { appKitServingTypesPlugin } from "../../type-generator/serving/vite-plugin"; import { appKitTypesPlugin } from "../../type-generator/vite-plugin"; import { BaseServer } from "./base-server"; import type { PluginClientConfigs, PluginEndpoints } from "./utils"; @@ -78,7 +79,11 @@ export class ViteDevServer extends BaseServer { ignored: ["**/node_modules/**", "!**/node_modules/@databricks/**"], }, }, - plugins: [react.default(), appKitTypesPlugin()], + plugins: [ + react.default(), + appKitTypesPlugin(), + appKitServingTypesPlugin(), + ], appType: "custom", }; diff --git a/packages/appkit/src/plugins/serving/schema-filter.ts b/packages/appkit/src/plugins/serving/schema-filter.ts index 6e52294a..92a25c69 100644 --- a/packages/appkit/src/plugins/serving/schema-filter.ts +++ b/packages/appkit/src/plugins/serving/schema-filter.ts @@ -1,19 +1,9 @@ import fs from "node:fs/promises"; import { createLogger } from "../../logging/logger"; - -const CACHE_VERSION = "1"; - -interface ServingCacheEntry { - hash: string; - requestType: string; - responseType: string; - chunkType: string | null; -} - -interface ServingCache { - version: string; - endpoints: Record; -} +import { + CACHE_VERSION, + type ServingCache, +} from "../../type-generator/serving/cache"; const logger = createLogger("serving:schema-filter"); @@ -47,11 +37,8 @@ export async function loadEndpointSchemas( const cache = parsed; for (const [alias, entry] of Object.entries(cache.endpoints)) { - // Extract property keys from the requestType string - // The requestType is a TypeScript object type like "{ messages: ...; temperature: ...; }" - const keys = extractPropertyKeys(entry.requestType); - if (keys.size > 0) { - allowlists.set(alias, keys); + if (entry.requestKeys && entry.requestKeys.length > 0) { + allowlists.set(alias, new Set(entry.requestKeys)); } } } catch (err) { @@ -67,25 +54,6 @@ export async function loadEndpointSchemas( return allowlists; } -/** - * Extracts top-level property keys from a TypeScript object type string. - * Matches patterns like `key:` or `key?:` at the first nesting level. - */ -function extractPropertyKeys(typeStr: string): Set { - const keys = new Set(); - // Match property names at the top level of the object type - // Looking for patterns: ` propertyName:` or ` propertyName?:` - const propRegex = /^\s{2}(?:\/\*\*[^*]*\*\/\s*)?(\w+)\??:/gm; - for ( - let match = propRegex.exec(typeStr); - match !== null; - match = propRegex.exec(typeStr) - ) { - keys.add(match[1]); - } - return keys; -} - /** * Filters a request body against the allowed keys for an endpoint alias. * Returns the filtered body and logs a warning for stripped params. diff --git a/packages/appkit/src/plugins/serving/serving.ts b/packages/appkit/src/plugins/serving/serving.ts index f64f6c95..bde6d9d6 100644 --- a/packages/appkit/src/plugins/serving/serving.ts +++ b/packages/appkit/src/plugins/serving/serving.ts @@ -281,6 +281,13 @@ export class ServingPlugin extends Plugin { ); } + clientConfig(): Record { + return { + isNamedMode: this.isNamedMode, + aliases: Object.keys(this.endpoints), + }; + } + async shutdown(): Promise { this.streamManager.abortAll(); } diff --git a/packages/appkit/src/plugins/serving/tests/schema-filter.test.ts b/packages/appkit/src/plugins/serving/tests/schema-filter.test.ts index 948b47f9..4fc030d8 100644 --- a/packages/appkit/src/plugins/serving/tests/schema-filter.test.ts +++ b/packages/appkit/src/plugins/serving/tests/schema-filter.test.ts @@ -109,7 +109,7 @@ describe("schema-filter", () => { expect(result.size).toBe(0); }); - test("extracts property keys from cached types", async () => { + test("reads requestKeys from cache entries", async () => { const fs = (await import("node:fs/promises")).default; vi.mocked(fs.readFile).mockResolvedValue( JSON.stringify({ @@ -117,13 +117,10 @@ describe("schema-filter", () => { endpoints: { default: { hash: "abc", - requestType: `{ - messages: string[]; - temperature?: number | null; - max_tokens: number; -}`, + requestType: "{}", responseType: "{}", chunkType: null, + requestKeys: ["messages", "temperature", "max_tokens"], }, }, }), @@ -137,5 +134,26 @@ describe("schema-filter", () => { expect(keys?.has("temperature")).toBe(true); expect(keys?.has("max_tokens")).toBe(true); }); + + test("skips entries without requestKeys (backwards compat)", async () => { + const fs = (await import("node:fs/promises")).default; + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify({ + version: "1", + endpoints: { + default: { + hash: "abc", + requestType: "{ messages: string[] }", + responseType: "{}", + chunkType: null, + }, + }, + }), + ); + + const result = await loadEndpointSchemas("/some/path"); + // No requestKeys → passthrough mode (no allowlist) + expect(result.size).toBe(0); + }); }); }); diff --git a/packages/appkit/src/type-generator/index.ts b/packages/appkit/src/type-generator/index.ts index aa1478e8..8361e2cf 100644 --- a/packages/appkit/src/type-generator/index.ts +++ b/packages/appkit/src/type-generator/index.ts @@ -2,6 +2,7 @@ import fs from "node:fs/promises"; import dotenv from "dotenv"; import { createLogger } from "../logging/logger"; import { generateQueriesFromDescribe } from "./query-registry"; +import { generateServingTypes as generateServingTypesImpl } from "./serving/generator"; import type { QuerySchema } from "./types"; dotenv.config(); @@ -86,3 +87,8 @@ export async function generateFromEntryPoint(options: { logger.debug("Type generation complete!"); } + +// Rolldown tree-shaking only preserves "own exports" (locally defined) — not re-exports. +// A local binding ensures the serving vite plugin's import keeps this in the dependency graph, +// mirroring how generateFromEntryPoint (also defined here) is preserved via the analytics vite plugin. +export const generateServingTypes = generateServingTypesImpl; diff --git a/packages/appkit/src/type-generator/serving/cache.ts b/packages/appkit/src/type-generator/serving/cache.ts new file mode 100644 index 00000000..dc9bf7e2 --- /dev/null +++ b/packages/appkit/src/type-generator/serving/cache.ts @@ -0,0 +1,56 @@ +import crypto from "node:crypto"; +import fs from "node:fs/promises"; +import path from "node:path"; +import { createLogger } from "../../logging/logger"; + +const logger = createLogger("type-generator:serving:cache"); + +export const CACHE_VERSION = "1"; +const CACHE_FILE = ".appkit-serving-types-cache.json"; +const CACHE_DIR = path.join( + process.cwd(), + "node_modules", + ".databricks", + "appkit", +); + +export interface ServingCacheEntry { + hash: string; + requestType: string; + responseType: string; + chunkType: string | null; + requestKeys: string[]; +} + +export interface ServingCache { + version: string; + endpoints: Record; +} + +export function hashSchema(schemaJson: string): string { + return crypto.createHash("sha256").update(schemaJson).digest("hex"); +} + +export async function loadServingCache(): Promise { + const cachePath = path.join(CACHE_DIR, CACHE_FILE); + try { + await fs.mkdir(CACHE_DIR, { recursive: true }); + const raw = await fs.readFile(cachePath, "utf8"); + const cache = JSON.parse(raw) as ServingCache; + if (cache.version === CACHE_VERSION) { + return cache; + } + logger.debug("Cache version mismatch, starting fresh"); + } catch (err) { + if ((err as NodeJS.ErrnoException).code !== "ENOENT") { + logger.warn("Cache file is corrupted, flushing cache completely."); + } + } + return { version: CACHE_VERSION, endpoints: {} }; +} + +export async function saveServingCache(cache: ServingCache): Promise { + const cachePath = path.join(CACHE_DIR, CACHE_FILE); + await fs.mkdir(CACHE_DIR, { recursive: true }); + await fs.writeFile(cachePath, JSON.stringify(cache, null, 2), "utf8"); +} diff --git a/packages/appkit/src/type-generator/serving/converter.ts b/packages/appkit/src/type-generator/serving/converter.ts new file mode 100644 index 00000000..b56b0460 --- /dev/null +++ b/packages/appkit/src/type-generator/serving/converter.ts @@ -0,0 +1,159 @@ +import type { OpenApiOperation, OpenApiSchema } from "./fetcher"; + +/** + * Converts an OpenAPI schema to a TypeScript type string. + */ +function schemaToTypeString(schema: OpenApiSchema, indent = 0): string { + const pad = " ".repeat(indent); + + if (schema.oneOf) { + return schema.oneOf.map((s) => schemaToTypeString(s, indent)).join(" | "); + } + + if (schema.enum) { + return schema.enum.map((v) => JSON.stringify(v)).join(" | "); + } + + switch (schema.type) { + case "string": + return "string"; + case "integer": + case "number": + return "number"; + case "boolean": + return "boolean"; + case "array": { + if (!schema.items) return "unknown[]"; + const itemType = schemaToTypeString(schema.items, indent); + // Wrap union types in parens for array + if (itemType.includes(" | ") && !itemType.startsWith("{")) { + return `(${itemType})[]`; + } + return `${itemType}[]`; + } + case "object": { + if (!schema.properties) return "Record"; + const required = new Set(schema.required ?? []); + const entries = Object.entries(schema.properties).map(([key, prop]) => { + const optional = !required.has(key) ? "?" : ""; + const nullable = prop.nullable ? " | null" : ""; + const typeStr = schemaToTypeString(prop, indent + 1); + const formatComment = + prop.format && (prop.type === "number" || prop.type === "integer") + ? `/** @openapi ${prop.format}${prop.nullable ? ", nullable" : ""} */\n${pad} ` + : prop.nullable && prop.type === "integer" + ? `/** @openapi integer, nullable */\n${pad} ` + : ""; + return `${pad} ${formatComment}${key}${optional}: ${typeStr}${nullable};`; + }); + return `{\n${entries.join("\n")}\n${pad}}`; + } + default: + return "unknown"; + } +} + +/** + * Extracts the top-level property keys from the request schema. + * Strips the `stream` property (plugin-controlled). + */ +export function extractRequestKeys(operation: OpenApiOperation): string[] { + const schema = operation.requestBody?.content?.["application/json"]?.schema; + if (!schema?.properties) return []; + return Object.keys(schema.properties).filter((k) => k !== "stream"); +} + +/** + * Extracts and converts the request schema from an OpenAPI path operation. + * Strips the `stream` property from the request type. + */ +export function convertRequestSchema(operation: OpenApiOperation): string { + const schema = operation.requestBody?.content?.["application/json"]?.schema; + if (!schema || !schema.properties) return "Record"; + + // Strip `stream` property — the plugin controls this + const { stream: _stream, ...filteredProps } = schema.properties; + const filteredRequired = (schema.required ?? []).filter( + (r) => r !== "stream", + ); + + const filteredSchema: OpenApiSchema = { + ...schema, + properties: filteredProps, + required: filteredRequired.length > 0 ? filteredRequired : undefined, + }; + + return schemaToTypeString(filteredSchema); +} + +/** + * Extracts and converts the response schema from an OpenAPI path operation. + */ +export function convertResponseSchema(operation: OpenApiOperation): string { + const response = operation.responses?.["200"]; + const schema = response?.content?.["application/json"]?.schema; + if (!schema) return "unknown"; + return schemaToTypeString(schema); +} + +/** + * Derives a streaming chunk type from the response schema. + * Returns null if the response doesn't follow OpenAI-compatible format. + * + * OpenAI-compatible heuristic: response has `choices` array where items + * have a `message` object property. + */ +export function deriveChunkType(operation: OpenApiOperation): string | null { + const response = operation.responses?.["200"]; + const schema = response?.content?.["application/json"]?.schema; + if (!schema?.properties) return null; + + const choicesProp = schema.properties.choices; + if (!choicesProp || choicesProp.type !== "array" || !choicesProp.items) + return null; + + const choiceItemProps = choicesProp.items.properties; + if (!choiceItemProps?.message) return null; + + // It's OpenAI-compatible. Build the chunk type by transforming. + const messageSchema = choiceItemProps.message; + + // Build chunk schema: replace message with delta (Partial), make finish_reason nullable, drop usage + const chunkProperties: Record = {}; + + for (const [key, prop] of Object.entries(schema.properties)) { + if (key === "usage") continue; // Drop usage from chunks + if (key === "choices") { + // Transform choices items + const chunkChoiceProps: Record = {}; + for (const [ck, cp] of Object.entries(choiceItemProps)) { + if (ck === "message") { + // Replace message with delta: Partial + chunkChoiceProps.delta = { ...messageSchema }; + } else if (ck === "finish_reason") { + chunkChoiceProps[ck] = { ...cp, nullable: true }; + } else { + chunkChoiceProps[ck] = cp; + } + } + chunkProperties[key] = { + type: "array", + items: { + type: "object", + properties: chunkChoiceProps, + }, + }; + } else { + chunkProperties[key] = prop; + } + } + + const chunkSchema: OpenApiSchema = { + type: "object", + properties: chunkProperties, + }; + + // Delta properties are already optional (no `required` array in the schema), + // so schemaToTypeString renders them with `?:` — no Partial<> wrapper needed. + return schemaToTypeString(chunkSchema); +} diff --git a/packages/appkit/src/type-generator/serving/fetcher.ts b/packages/appkit/src/type-generator/serving/fetcher.ts new file mode 100644 index 00000000..c47775d7 --- /dev/null +++ b/packages/appkit/src/type-generator/serving/fetcher.ts @@ -0,0 +1,137 @@ +import { ApiError, type WorkspaceClient } from "@databricks/sdk-experimental"; +import { createLogger } from "../../logging/logger"; + +const logger = createLogger("type-generator:serving:fetcher"); + +interface OpenApiSpec { + openapi: string; + info: { title: string; version: string }; + paths: Record>; +} + +export interface OpenApiOperation { + requestBody?: { + content: { + "application/json": { + schema: OpenApiSchema; + }; + }; + }; + responses?: Record< + string, + { + content?: { + "application/json": { + schema: OpenApiSchema; + }; + }; + } + >; +} + +export interface OpenApiSchema { + type?: string; + properties?: Record; + required?: string[]; + items?: OpenApiSchema; + enum?: string[]; + nullable?: boolean; + oneOf?: OpenApiSchema[]; + format?: string; +} + +/** + * Fetches the OpenAPI schema for a serving endpoint using the SDK. + * Returns null if the endpoint is not found or access is denied. + */ +export async function fetchOpenApiSchema( + client: WorkspaceClient, + endpointName: string, + servedModel?: string, +): Promise<{ spec: OpenApiSpec; pathKey: string } | null> { + try { + const response = await client.servingEndpoints.getOpenApi({ + name: endpointName, + }); + + if (!response.contents) { + logger.warn( + "Empty OpenAPI response for '%s', skipping type generation", + endpointName, + ); + return null; + } + + const text = await new Response(response.contents).text(); + const rawSpec: unknown = JSON.parse(text); + + if ( + typeof rawSpec !== "object" || + rawSpec === null || + !("paths" in rawSpec) || + typeof (rawSpec as OpenApiSpec).paths !== "object" + ) { + logger.warn( + "Invalid OpenAPI schema structure for '%s', skipping", + endpointName, + ); + return null; + } + const spec = rawSpec as OpenApiSpec; + + // Find the right path key + const pathKeys = Object.keys(spec.paths ?? {}); + if (pathKeys.length === 0) { + logger.warn("No paths in OpenAPI schema for '%s'", endpointName); + return null; + } + + let pathKey: string; + if (servedModel) { + const match = pathKeys.find((k) => k.includes(`/${servedModel}/`)); + if (!match) { + logger.warn( + "Served model '%s' not found in schema for '%s', using first path", + servedModel, + endpointName, + ); + pathKey = pathKeys[0]; + } else { + pathKey = match; + } + } else { + pathKey = pathKeys[0]; + } + + return { spec, pathKey }; + } catch (err) { + if (err instanceof ApiError) { + const status = err.statusCode ?? 0; + if (status === 404) { + logger.warn( + "Endpoint '%s' not found, skipping type generation", + endpointName, + ); + } else if (status === 403) { + logger.warn( + "Access denied to endpoint '%s' schema, skipping type generation", + endpointName, + ); + } else { + logger.warn( + "Failed to fetch schema for '%s' (HTTP %d), skipping: %s", + endpointName, + status, + err.message, + ); + } + } else { + logger.warn( + "Error fetching schema for '%s': %s", + endpointName, + (err as Error).message, + ); + } + return null; + } +} diff --git a/packages/appkit/src/type-generator/serving/generator.ts b/packages/appkit/src/type-generator/serving/generator.ts new file mode 100644 index 00000000..bcf4fd50 --- /dev/null +++ b/packages/appkit/src/type-generator/serving/generator.ts @@ -0,0 +1,276 @@ +import fs from "node:fs/promises"; +import { WorkspaceClient } from "@databricks/sdk-experimental"; +import pc from "picocolors"; +import { createLogger } from "../../logging/logger"; +import type { EndpointConfig } from "../../plugins/serving/types"; +import { + CACHE_VERSION, + hashSchema, + loadServingCache, + type ServingCache, + saveServingCache, +} from "./cache"; +import { + convertRequestSchema, + convertResponseSchema, + deriveChunkType, + extractRequestKeys, +} from "./converter"; +import { fetchOpenApiSchema } from "./fetcher"; + +const logger = createLogger("type-generator:serving"); + +const GENERIC_REQUEST = "Record"; +const GENERIC_RESPONSE = "unknown"; +const GENERIC_CHUNK = "unknown"; + +interface GenerateServingTypesOptions { + outFile: string; + endpoints?: Record; + noCache?: boolean; +} + +/** + * Generates TypeScript type declarations for serving endpoints + * by fetching their OpenAPI schemas and converting to TypeScript. + */ +export async function generateServingTypes( + options: GenerateServingTypesOptions, +): Promise { + const { outFile, noCache } = options; + + // Resolve endpoints from config or env + const endpoints = options.endpoints ?? resolveDefaultEndpoints(); + if (Object.keys(endpoints).length === 0) { + logger.debug("No serving endpoints configured, skipping type generation"); + return; + } + + const startTime = performance.now(); + + const cache = noCache + ? { version: CACHE_VERSION, endpoints: {} } + : await loadServingCache(); + + let client: WorkspaceClient | undefined; + let updated = false; + + const registryEntries: string[] = []; + const logEntries: Array<{ + alias: string; + status: "HIT" | "MISS"; + error?: string; + }> = []; + + for (const [alias, config] of Object.entries(endpoints)) { + client ??= new WorkspaceClient({}); + const result = await processEndpoint(alias, config, client, cache); + if (result.cacheUpdated) updated = true; + registryEntries.push(result.entry); + logEntries.push(result.log); + } + + printLogTable(logEntries, startTime); + + const output = generateTypeDeclarations(registryEntries); + await fs.writeFile(outFile, output, "utf-8"); + + if (registryEntries.length === 0) { + logger.debug( + "Wrote empty serving types to %s (no endpoints resolved)", + outFile, + ); + } else { + logger.debug("Wrote serving types to %s", outFile); + } + + if (updated) { + await saveServingCache(cache as ServingCache); + } +} + +interface EndpointResult { + entry: string; + log: { alias: string; status: "HIT" | "MISS"; error?: string }; + cacheUpdated: boolean; +} + +function genericEntry(alias: string): string { + return buildRegistryEntry( + alias, + GENERIC_REQUEST, + GENERIC_RESPONSE, + GENERIC_CHUNK, + ); +} + +async function processEndpoint( + alias: string, + config: EndpointConfig, + client: WorkspaceClient, + cache: { endpoints: Record }, +): Promise { + const endpointName = process.env[config.env]; + if (!endpointName) { + return { + entry: genericEntry(alias), + log: { alias, status: "MISS", error: `env ${config.env} not set` }, + cacheUpdated: false, + }; + } + + const result = await fetchOpenApiSchema( + client, + endpointName, + config.servedModel, + ); + if (!result) { + return { + entry: genericEntry(alias), + log: { alias, status: "MISS", error: "schema fetch failed" }, + cacheUpdated: false, + }; + } + + const { spec, pathKey } = result; + const hash = hashSchema(JSON.stringify(spec)); + + // Cache hit + const cached = cache.endpoints[alias]; + if (cached && cached.hash === hash) { + return { + entry: buildRegistryEntry( + alias, + cached.requestType, + cached.responseType, + cached.chunkType, + ), + log: { alias, status: "HIT" }, + cacheUpdated: false, + }; + } + + // Cache miss — convert schema to types + const operation = spec.paths[pathKey]?.post; + if (!operation) { + return { + entry: genericEntry(alias), + log: { alias, status: "MISS", error: "no POST operation" }, + cacheUpdated: false, + }; + } + + try { + const requestType = convertRequestSchema(operation); + const responseType = convertResponseSchema(operation); + const chunkType = deriveChunkType(operation); + const requestKeys = extractRequestKeys(operation); + + cache.endpoints[alias] = { + hash, + requestType, + responseType, + chunkType, + requestKeys, + }; + + return { + entry: buildRegistryEntry(alias, requestType, responseType, chunkType), + log: { alias, status: "MISS" }, + cacheUpdated: true, + }; + } catch (convErr) { + logger.warn( + "Schema conversion failed for '%s': %s", + alias, + (convErr as Error).message, + ); + return { + entry: genericEntry(alias), + log: { alias, status: "MISS", error: "schema conversion failed" }, + cacheUpdated: false, + }; + } +} + +function printLogTable( + logEntries: Array<{ alias: string; status: "HIT" | "MISS"; error?: string }>, + startTime: number, +): void { + if (logEntries.length === 0) return; + + const maxNameLen = Math.max(...logEntries.map((e) => e.alias.length)); + const separator = pc.dim("─".repeat(50)); + console.log(""); + console.log( + ` ${pc.bold("Typegen Serving")} ${pc.dim(`(${logEntries.length})`)}`, + ); + console.log(` ${separator}`); + for (const entry of logEntries) { + const tag = + entry.status === "HIT" + ? `cache ${pc.bold(pc.green("HIT "))}` + : `cache ${pc.bold(pc.yellow("MISS "))}`; + const rawName = entry.alias.padEnd(maxNameLen); + const reason = entry.error ? ` ${pc.dim(entry.error)}` : ""; + console.log(` ${tag} ${rawName}${reason}`); + } + const elapsed = ((performance.now() - startTime) / 1000).toFixed(2); + const newCount = logEntries.filter((e) => e.status === "MISS").length; + const cacheCount = logEntries.filter((e) => e.status === "HIT").length; + console.log(` ${separator}`); + console.log( + ` ${newCount} new, ${cacheCount} from cache. ${pc.dim(`${elapsed}s`)}`, + ); + console.log(""); +} + +function resolveDefaultEndpoints(): Record { + if (process.env.DATABRICKS_SERVING_ENDPOINT_NAME) { + return { default: { env: "DATABRICKS_SERVING_ENDPOINT_NAME" } }; + } + return {}; +} + +function buildRegistryEntry( + alias: string, + requestType: string, + responseType: string, + chunkType: string | null, +): string { + const indent = " "; + const chunkEntry = chunkType ? chunkType : "unknown"; + return ` ${alias}: { +${indent}request: ${indentType(requestType, indent)}; +${indent}response: ${indentType(responseType, indent)}; +${indent}chunk: ${indentType(chunkEntry, indent)}; + };`; +} + +function indentType(typeStr: string, baseIndent: string): string { + if (!typeStr.includes("\n")) return typeStr; + return typeStr + .split("\n") + .map((line, i) => (i === 0 ? line : `${baseIndent}${line}`)) + .join("\n"); +} + +function generateTypeDeclarations(entries: string[]): string { + return `// Auto-generated by AppKit - DO NOT EDIT +// Generated from serving endpoint OpenAPI schemas +import "@databricks/appkit"; +import "@databricks/appkit-ui/react"; + +declare module "@databricks/appkit" { + interface ServingEndpointRegistry { +${entries.join("\n")} + } +} + +declare module "@databricks/appkit-ui/react" { + interface ServingEndpointRegistry { +${entries.join("\n")} + } +} +`; +} diff --git a/packages/appkit/src/type-generator/serving/server-file-extractor.ts b/packages/appkit/src/type-generator/serving/server-file-extractor.ts new file mode 100644 index 00000000..b26b0bf1 --- /dev/null +++ b/packages/appkit/src/type-generator/serving/server-file-extractor.ts @@ -0,0 +1,221 @@ +import fs from "node:fs"; +import path from "node:path"; +import { Lang, parse, type SgNode } from "@ast-grep/napi"; +import { createLogger } from "../../logging/logger"; +import type { EndpointConfig } from "../../plugins/serving/types"; + +const logger = createLogger("type-generator:serving:extractor"); + +/** + * Candidate paths for the server entry file, relative to the project root. + * Checked in order; the first that exists is used. + * Same convention as plugin sync (sync.ts SERVER_FILE_CANDIDATES). + */ +const SERVER_FILE_CANDIDATES = ["server/index.ts", "server/server.ts"]; + +/** + * Find the server entry file by checking candidate paths in order. + * + * @param basePath - Project root directory to search from + * @returns Absolute path to the server file, or null if none found + */ +export function findServerFile(basePath: string): string | null { + for (const candidate of SERVER_FILE_CANDIDATES) { + const fullPath = path.join(basePath, candidate); + if (fs.existsSync(fullPath)) { + return fullPath; + } + } + return null; +} + +/** + * Extract serving endpoint config from a server file by AST-parsing it. + * Looks for `serving({ endpoints: { alias: { env: "..." }, ... } })` calls + * and extracts the endpoint alias names and their environment variable mappings. + * + * @param serverFilePath - Absolute path to the server entry file + * @returns Extracted endpoint config, or null if not found or not extractable + */ +export function extractServingEndpoints( + serverFilePath: string, +): Record | null { + let content: string; + try { + content = fs.readFileSync(serverFilePath, "utf-8"); + } catch { + logger.debug("Could not read server file: %s", serverFilePath); + return null; + } + + const lang = serverFilePath.endsWith(".tsx") ? Lang.Tsx : Lang.TypeScript; + const ast = parse(lang, content); + const root = ast.root(); + + // Find serving(...) call expressions + const servingCall = findServingCall(root); + if (!servingCall) { + logger.debug("No serving() call found in %s", serverFilePath); + return null; + } + + // Get the first argument (the config object) + const args = servingCall.field("arguments"); + if (!args) { + return null; + } + + const configArg = args.children().find((child) => child.kind() === "object"); + if (!configArg) { + // serving() called with no args or non-object arg + return null; + } + + // Find the "endpoints" property in the config object + const endpointsPair = findProperty(configArg, "endpoints"); + if (!endpointsPair) { + // Config object has no "endpoints" property (e.g. serving({ timeout: 5000 })) + return null; + } + + // Get the value of the endpoints property + const endpointsValue = getPropertyValue(endpointsPair); + if (!endpointsValue || endpointsValue.kind() !== "object") { + // endpoints is a variable reference, not an inline object + logger.debug( + "serving() endpoints is not an inline object literal in %s. " + + "Pass endpoints explicitly via appKitServingTypesPlugin({ endpoints }) in vite.config.ts.", + serverFilePath, + ); + return null; + } + + // Extract each endpoint entry + const endpoints: Record = {}; + const pairs = endpointsValue + .children() + .filter((child) => child.kind() === "pair"); + + for (const pair of pairs) { + const entry = extractEndpointEntry(pair); + if (entry) { + endpoints[entry.alias] = entry.config; + } + } + + if (Object.keys(endpoints).length === 0) { + return null; + } + + logger.debug( + "Extracted %d endpoint(s) from %s: %s", + Object.keys(endpoints).length, + serverFilePath, + Object.keys(endpoints).join(", "), + ); + + return endpoints; +} + +/** + * Find the serving() call expression in the AST. + * Looks for call expressions where the callee identifier is "serving". + */ +function findServingCall(root: SgNode): SgNode | null { + const callExpressions = root.findAll({ + rule: { kind: "call_expression" }, + }); + + for (const call of callExpressions) { + const callee = call.children()[0]; + if (callee?.kind() === "identifier" && callee.text() === "serving") { + return call; + } + } + + return null; +} + +/** + * Find a property (pair node) with the given key name in an object expression. + */ +function findProperty(objectNode: SgNode, propertyName: string): SgNode | null { + const pairs = objectNode + .children() + .filter((child) => child.kind() === "pair"); + + for (const pair of pairs) { + const key = pair.children()[0]; + if (!key) continue; + + const keyText = + key.kind() === "property_identifier" + ? key.text() + : key.kind() === "string" + ? key.text().replace(/^['"]|['"]$/g, "") + : null; + + if (keyText === propertyName) { + return pair; + } + } + + return null; +} + +/** + * Get the value node from a pair (property: value). + * The value is typically the last meaningful child after the colon. + */ +function getPropertyValue(pairNode: SgNode): SgNode | null { + const children = pairNode.children(); + // pair children: [key, ":", value] + return children.length >= 3 ? children[children.length - 1] : null; +} + +/** + * Extract a single endpoint entry from a pair node like: + * `demo: { env: "DATABRICKS_SERVING_ENDPOINT_NAME", servedModel: "my-model" }` + */ +function extractEndpointEntry( + pair: SgNode, +): { alias: string; config: EndpointConfig } | null { + const children = pair.children(); + if (children.length < 3) return null; + + // Get alias name (the key) + const keyNode = children[0]; + const alias = + keyNode.kind() === "property_identifier" + ? keyNode.text() + : keyNode.kind() === "string" + ? keyNode.text().replace(/^['"]|['"]$/g, "") + : null; + + if (!alias) return null; + + // Get the value (should be an object like { env: "..." }) + const valueNode = children[children.length - 1]; + if (valueNode.kind() !== "object") return null; + + // Extract env field + const envPair = findProperty(valueNode, "env"); + if (!envPair) return null; + + const envValue = getPropertyValue(envPair); + if (!envValue || envValue.kind() !== "string") return null; + + const env = envValue.text().replace(/^['"]|['"]$/g, ""); + + // Extract optional servedModel field + const config: EndpointConfig = { env }; + const servedModelPair = findProperty(valueNode, "servedModel"); + if (servedModelPair) { + const servedModelValue = getPropertyValue(servedModelPair); + if (servedModelValue?.kind() === "string") { + config.servedModel = servedModelValue.text().replace(/^['"]|['"]$/g, ""); + } + } + + return { alias, config }; +} diff --git a/packages/appkit/src/type-generator/serving/tests/cache.test.ts b/packages/appkit/src/type-generator/serving/tests/cache.test.ts new file mode 100644 index 00000000..0c99c997 --- /dev/null +++ b/packages/appkit/src/type-generator/serving/tests/cache.test.ts @@ -0,0 +1,109 @@ +import fs from "node:fs/promises"; +import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; +import { + CACHE_VERSION, + hashSchema, + loadServingCache, + type ServingCache, + saveServingCache, +} from "../cache"; + +vi.mock("node:fs/promises"); + +describe("serving cache", () => { + beforeEach(() => { + vi.mocked(fs.mkdir).mockResolvedValue(undefined); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + describe("hashSchema", () => { + test("returns consistent SHA256 hash", () => { + const hash1 = hashSchema('{"openapi": "3.1.0"}'); + const hash2 = hashSchema('{"openapi": "3.1.0"}'); + expect(hash1).toBe(hash2); + expect(hash1).toHaveLength(64); // SHA256 hex + }); + + test("different inputs produce different hashes", () => { + const hash1 = hashSchema('{"a": 1}'); + const hash2 = hashSchema('{"a": 2}'); + expect(hash1).not.toBe(hash2); + }); + }); + + describe("loadServingCache", () => { + test("returns empty cache when file does not exist", async () => { + vi.mocked(fs.readFile).mockRejectedValue( + Object.assign(new Error("ENOENT"), { code: "ENOENT" }), + ); + + const cache = await loadServingCache(); + expect(cache).toEqual({ version: CACHE_VERSION, endpoints: {} }); + }); + + test("returns parsed cache when file exists with correct version", async () => { + const cached: ServingCache = { + version: CACHE_VERSION, + endpoints: { + llm: { + hash: "abc", + requestType: "{ messages: string[] }", + responseType: "{ model: string }", + chunkType: null, + requestKeys: ["messages"], + }, + }, + }; + vi.mocked(fs.readFile).mockResolvedValue(JSON.stringify(cached)); + + const cache = await loadServingCache(); + expect(cache).toEqual(cached); + }); + + test("flushes cache when version mismatches", async () => { + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify({ version: "0", endpoints: { old: {} } }), + ); + + const cache = await loadServingCache(); + expect(cache).toEqual({ version: CACHE_VERSION, endpoints: {} }); + }); + + test("flushes cache when file is corrupted", async () => { + vi.mocked(fs.readFile).mockResolvedValue("not json"); + + const cache = await loadServingCache(); + expect(cache).toEqual({ version: CACHE_VERSION, endpoints: {} }); + }); + }); + + describe("saveServingCache", () => { + test("writes cache to file", async () => { + vi.mocked(fs.writeFile).mockResolvedValue(); + + const cache: ServingCache = { + version: CACHE_VERSION, + endpoints: { + test: { + hash: "xyz", + requestType: "{}", + responseType: "{}", + chunkType: null, + requestKeys: [], + }, + }, + }; + + await saveServingCache(cache); + + expect(fs.writeFile).toHaveBeenCalledWith( + expect.stringContaining(".appkit-serving-types-cache.json"), + JSON.stringify(cache, null, 2), + "utf8", + ); + }); + }); +}); diff --git a/packages/appkit/src/type-generator/serving/tests/converter.test.ts b/packages/appkit/src/type-generator/serving/tests/converter.test.ts new file mode 100644 index 00000000..1be30738 --- /dev/null +++ b/packages/appkit/src/type-generator/serving/tests/converter.test.ts @@ -0,0 +1,308 @@ +import { describe, expect, test } from "vitest"; +import { + convertRequestSchema, + convertResponseSchema, + deriveChunkType, + extractRequestKeys, +} from "../converter"; +import type { OpenApiOperation, OpenApiSchema } from "../fetcher"; + +function makeOperation( + requestProps: Record, + responseProps?: Record, + required?: string[], +): OpenApiOperation { + return { + requestBody: { + content: { + "application/json": { + schema: { + type: "object", + properties: requestProps, + required, + }, + }, + }, + }, + responses: responseProps + ? { + "200": { + content: { + "application/json": { + schema: { + type: "object", + properties: responseProps, + }, + }, + }, + }, + } + : undefined, + }; +} + +describe("converter", () => { + describe("convertRequestSchema", () => { + test("converts string type", () => { + const op = makeOperation({ name: { type: "string" } }); + const result = convertRequestSchema(op); + expect(result).toContain("name?: string;"); + }); + + test("converts integer type to number", () => { + const op = makeOperation({ count: { type: "integer" } }); + expect(convertRequestSchema(op)).toContain("count?: number;"); + }); + + test("converts number type", () => { + const op = makeOperation({ + temp: { type: "number", format: "double" }, + }); + expect(convertRequestSchema(op)).toContain("temp?: number;"); + }); + + test("converts boolean type", () => { + const op = makeOperation({ flag: { type: "boolean" } }); + expect(convertRequestSchema(op)).toContain("flag?: boolean;"); + }); + + test("converts enum to string literal union", () => { + const op = makeOperation({ + role: { type: "string", enum: ["user", "assistant"] }, + }); + const result = convertRequestSchema(op); + expect(result).toContain('"user" | "assistant"'); + }); + + test("converts array type", () => { + const op = makeOperation({ + items: { type: "array", items: { type: "string" } }, + }); + expect(convertRequestSchema(op)).toContain("items?: string[];"); + }); + + test("converts nested object", () => { + const op = makeOperation({ + messages: { + type: "array", + items: { + type: "object", + properties: { + role: { type: "string" }, + content: { type: "string" }, + }, + }, + }, + }); + const result = convertRequestSchema(op); + expect(result).toContain("role?: string;"); + expect(result).toContain("content?: string;"); + }); + + test("handles nullable properties", () => { + const op = makeOperation({ + temp: { type: "number", nullable: true }, + }); + expect(convertRequestSchema(op)).toContain("temp?: number | null;"); + }); + + test("handles oneOf union types", () => { + const op = makeOperation({ + stop: { + oneOf: [ + { type: "string" }, + { type: "array", items: { type: "string" } }, + ], + }, + }); + const result = convertRequestSchema(op); + expect(result).toContain("string | string[]"); + }); + + test("strips stream property from request", () => { + const op = makeOperation({ + messages: { type: "array", items: { type: "string" } }, + stream: { type: "boolean", nullable: true }, + temperature: { type: "number" }, + }); + const result = convertRequestSchema(op); + expect(result).not.toContain("stream"); + expect(result).toContain("messages"); + expect(result).toContain("temperature"); + }); + + test("marks required properties without ?", () => { + const op = makeOperation( + { + messages: { type: "array", items: { type: "string" } }, + temperature: { type: "number" }, + }, + undefined, + ["messages"], + ); + const result = convertRequestSchema(op); + expect(result).toContain("messages: string[];"); + expect(result).toContain("temperature?: number;"); + }); + + test("returns Record for missing schema", () => { + const op: OpenApiOperation = {}; + expect(convertRequestSchema(op)).toBe("Record"); + }); + }); + + describe("convertResponseSchema", () => { + test("converts response schema", () => { + const op = makeOperation( + {}, + { + model: { type: "string" }, + id: { type: "string" }, + }, + ); + const result = convertResponseSchema(op); + expect(result).toContain("model?: string;"); + expect(result).toContain("id?: string;"); + }); + + test("returns unknown for missing response", () => { + const op: OpenApiOperation = {}; + expect(convertResponseSchema(op)).toBe("unknown"); + }); + }); + + describe("deriveChunkType", () => { + test("derives chunk type from OpenAI-compatible response", () => { + const op: OpenApiOperation = { + responses: { + "200": { + content: { + "application/json": { + schema: { + type: "object", + properties: { + model: { type: "string" }, + choices: { + type: "array", + items: { + type: "object", + properties: { + index: { type: "integer" }, + message: { + type: "object", + properties: { + role: { + type: "string", + enum: ["user", "assistant"], + }, + content: { type: "string" }, + }, + }, + finish_reason: { type: "string" }, + }, + }, + }, + usage: { + type: "object", + properties: { + prompt_tokens: { type: "integer" }, + }, + nullable: true, + }, + id: { type: "string" }, + }, + }, + }, + }, + }, + }, + }; + + const result = deriveChunkType(op); + expect(result).not.toBeNull(); + // Should have delta instead of message + expect(result).toContain("delta"); + expect(result).not.toContain("message"); + // Should make finish_reason nullable + expect(result).toContain("finish_reason"); + expect(result).toContain("| null"); + // Should drop usage + expect(result).not.toContain("usage"); + // Should keep model and id + expect(result).toContain("model"); + expect(result).toContain("id"); + }); + + test("returns null for non-OpenAI response (no choices)", () => { + const op = makeOperation( + {}, + { + predictions: { type: "array", items: { type: "number" } }, + }, + ); + expect(deriveChunkType(op)).toBeNull(); + }); + + test("returns null for choices without message", () => { + const op: OpenApiOperation = { + responses: { + "200": { + content: { + "application/json": { + schema: { + type: "object", + properties: { + choices: { + type: "array", + items: { + type: "object", + properties: { + score: { type: "number" }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }; + expect(deriveChunkType(op)).toBeNull(); + }); + + test("returns null for missing response", () => { + const op: OpenApiOperation = {}; + expect(deriveChunkType(op)).toBeNull(); + }); + }); + + describe("extractRequestKeys", () => { + test("extracts top-level property keys excluding stream", () => { + const op = makeOperation({ + messages: { type: "array", items: { type: "string" } }, + temperature: { type: "number" }, + stream: { type: "boolean", nullable: true }, + }); + expect(extractRequestKeys(op)).toEqual(["messages", "temperature"]); + }); + + test("returns empty array for missing schema", () => { + const op: OpenApiOperation = {}; + expect(extractRequestKeys(op)).toEqual([]); + }); + + test("returns empty array for schema without properties", () => { + const op: OpenApiOperation = { + requestBody: { + content: { + "application/json": { + schema: { type: "object" }, + }, + }, + }, + }; + expect(extractRequestKeys(op)).toEqual([]); + }); + }); +}); diff --git a/packages/appkit/src/type-generator/serving/tests/fetcher.test.ts b/packages/appkit/src/type-generator/serving/tests/fetcher.test.ts new file mode 100644 index 00000000..cae3ec7b --- /dev/null +++ b/packages/appkit/src/type-generator/serving/tests/fetcher.test.ts @@ -0,0 +1,148 @@ +import { ApiError } from "@databricks/sdk-experimental"; +import { afterEach, describe, expect, test, vi } from "vitest"; +import { fetchOpenApiSchema } from "../fetcher"; + +function makeValidSpec( + paths: Record = { "/invocations": { post: {} } }, +) { + return { + openapi: "3.0.0", + info: { title: "test", version: "1" }, + paths, + }; +} + +function createReadableStream(data: string): ReadableStream { + const encoder = new TextEncoder(); + return new ReadableStream({ + start(controller) { + controller.enqueue(encoder.encode(data)); + controller.close(); + }, + }); +} + +function createMockClient(getOpenApiImpl?: (...args: any[]) => any) { + const defaultImpl = async () => ({ + contents: createReadableStream(JSON.stringify(makeValidSpec())), + }); + return { + servingEndpoints: { + getOpenApi: vi.fn(getOpenApiImpl ?? defaultImpl), + }, + } as any; +} + +describe("fetchOpenApiSchema", () => { + afterEach(() => { + vi.restoreAllMocks(); + }); + + test("returns null on ApiError 404", async () => { + const client = createMockClient(async () => { + throw new ApiError("Not found", "NOT_FOUND", 404, undefined, []); + }); + const result = await fetchOpenApiSchema(client, "ep"); + expect(result).toBeNull(); + }); + + test("returns null on ApiError 403", async () => { + const client = createMockClient(async () => { + throw new ApiError("Forbidden", "FORBIDDEN", 403, undefined, []); + }); + const result = await fetchOpenApiSchema(client, "ep"); + expect(result).toBeNull(); + }); + + test("returns null on ApiError 500", async () => { + const client = createMockClient(async () => { + throw new ApiError("Server error", "INTERNAL", 500, undefined, []); + }); + const result = await fetchOpenApiSchema(client, "ep"); + expect(result).toBeNull(); + }); + + test("returns null on generic error", async () => { + const client = createMockClient(async () => { + throw new Error("network failure"); + }); + const result = await fetchOpenApiSchema(client, "ep"); + expect(result).toBeNull(); + }); + + test("returns null when response has no contents", async () => { + const client = createMockClient(async () => ({ contents: undefined })); + const result = await fetchOpenApiSchema(client, "ep"); + expect(result).toBeNull(); + }); + + test("returns spec and pathKey for valid response", async () => { + const spec = makeValidSpec({ + "/serving-endpoints/ep/invocations": { post: { requestBody: {} } }, + }); + const client = createMockClient(async () => ({ + contents: createReadableStream(JSON.stringify(spec)), + })); + + const result = await fetchOpenApiSchema(client, "ep"); + expect(result).not.toBeNull(); + expect(result?.pathKey).toBe("/serving-endpoints/ep/invocations"); + expect(result?.spec.openapi).toBe("3.0.0"); + }); + + test("matches servedModel path when provided", async () => { + const spec = makeValidSpec({ + "/serving-endpoints/ep/served-models/gpt4/invocations": { post: {} }, + "/serving-endpoints/ep/invocations": { post: {} }, + }); + const client = createMockClient(async () => ({ + contents: createReadableStream(JSON.stringify(spec)), + })); + + const result = await fetchOpenApiSchema(client, "ep", "gpt4"); + expect(result?.pathKey).toBe( + "/serving-endpoints/ep/served-models/gpt4/invocations", + ); + }); + + test("falls back to first path when servedModel not found", async () => { + const spec = makeValidSpec({ + "/serving-endpoints/ep/invocations": { post: {} }, + }); + const client = createMockClient(async () => ({ + contents: createReadableStream(JSON.stringify(spec)), + })); + + const result = await fetchOpenApiSchema(client, "ep", "nonexistent-model"); + expect(result?.pathKey).toBe("/serving-endpoints/ep/invocations"); + }); + + test("returns null for invalid spec structure (missing paths)", async () => { + const client = createMockClient(async () => ({ + contents: createReadableStream( + JSON.stringify({ openapi: "3.0.0", info: {} }), + ), + })); + + const result = await fetchOpenApiSchema(client, "ep"); + expect(result).toBeNull(); + }); + + test("returns null when paths object is empty", async () => { + const client = createMockClient(async () => ({ + contents: createReadableStream(JSON.stringify(makeValidSpec({}))), + })); + + const result = await fetchOpenApiSchema(client, "ep"); + expect(result).toBeNull(); + }); + + test("calls SDK getOpenApi with correct endpoint name", async () => { + const client = createMockClient(); + await fetchOpenApiSchema(client, "my-endpoint"); + + expect(client.servingEndpoints.getOpenApi).toHaveBeenCalledWith({ + name: "my-endpoint", + }); + }); +}); diff --git a/packages/appkit/src/type-generator/serving/tests/generator.test.ts b/packages/appkit/src/type-generator/serving/tests/generator.test.ts new file mode 100644 index 00000000..8761519b --- /dev/null +++ b/packages/appkit/src/type-generator/serving/tests/generator.test.ts @@ -0,0 +1,215 @@ +import fs from "node:fs/promises"; +import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; +import { generateServingTypes } from "../generator"; + +vi.mock("node:fs/promises"); + +// Mock cache module +vi.mock("../cache", () => ({ + CACHE_VERSION: "1", + hashSchema: vi.fn(() => "mock-hash"), + loadServingCache: vi.fn(async () => ({ version: "1", endpoints: {} })), + saveServingCache: vi.fn(async () => {}), +})); + +// Mock fetcher +const mockFetchOpenApiSchema = vi.fn(); +vi.mock("../fetcher", () => ({ + fetchOpenApiSchema: (...args: any[]) => mockFetchOpenApiSchema(...args), +})); + +// Mock WorkspaceClient +vi.mock("@databricks/sdk-experimental", () => ({ + WorkspaceClient: vi.fn(() => ({ config: {} })), +})); + +const CHAT_OPENAPI_SPEC = { + openapi: "3.1.0", + info: { title: "test", version: "1" }, + paths: { + "/served-models/llm/invocations": { + post: { + requestBody: { + content: { + "application/json": { + schema: { + type: "object", + properties: { + messages: { + type: "array", + items: { + type: "object", + properties: { + role: { type: "string" }, + content: { type: "string" }, + }, + }, + }, + temperature: { type: "number", nullable: true }, + stream: { type: "boolean", nullable: true }, + }, + }, + }, + }, + }, + responses: { + "200": { + content: { + "application/json": { + schema: { + type: "object", + properties: { + model: { type: "string" }, + choices: { + type: "array", + items: { + type: "object", + properties: { + message: { + type: "object", + properties: { + role: { type: "string" }, + content: { type: "string" }, + }, + }, + finish_reason: { type: "string" }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, +}; + +describe("generateServingTypes", () => { + const outFile = "/tmp/test-serving-types.d.ts"; + + beforeEach(() => { + vi.mocked(fs.writeFile).mockResolvedValue(); + process.env.TEST_SERVING_ENDPOINT = "my-endpoint"; + }); + + afterEach(() => { + delete process.env.TEST_SERVING_ENDPOINT; + delete process.env.DATABRICKS_SERVING_ENDPOINT_NAME; + vi.restoreAllMocks(); + }); + + test("generates .d.ts with module augmentation for a chat endpoint", async () => { + mockFetchOpenApiSchema.mockResolvedValue({ + spec: CHAT_OPENAPI_SPEC, + pathKey: "/served-models/llm/invocations", + }); + + await generateServingTypes({ + outFile, + endpoints: { llm: { env: "TEST_SERVING_ENDPOINT" } }, + noCache: true, + }); + + expect(fs.writeFile).toHaveBeenCalledWith( + outFile, + expect.any(String), + "utf-8", + ); + const output = vi.mocked(fs.writeFile).mock.calls[0][1] as string; + + // Verify module augmentation structure + expect(output).toContain("// Auto-generated by AppKit - DO NOT EDIT"); + expect(output).toContain('import "@databricks/appkit"'); + expect(output).toContain('import "@databricks/appkit-ui/react"'); + expect(output).toContain('declare module "@databricks/appkit"'); + expect(output).toContain('declare module "@databricks/appkit-ui/react"'); + expect(output).toContain("interface ServingEndpointRegistry"); + expect(output).toContain("llm:"); + expect(output).toContain("request:"); + expect(output).toContain("response:"); + expect(output).toContain("chunk:"); + }); + + test("strips stream property from generated request type", async () => { + mockFetchOpenApiSchema.mockResolvedValue({ + spec: CHAT_OPENAPI_SPEC, + pathKey: "/served-models/llm/invocations", + }); + + await generateServingTypes({ + outFile, + endpoints: { llm: { env: "TEST_SERVING_ENDPOINT" } }, + noCache: true, + }); + + const output = vi.mocked(fs.writeFile).mock.calls[0][1] as string; + // `stream` should be stripped from request type + expect(output).toContain("messages"); + expect(output).toContain("temperature"); + expect(output).not.toMatch(/\bstream\??\s*:/); + }); + + test("emits generic types when env var is not set", async () => { + delete process.env.TEST_SERVING_ENDPOINT; + + await generateServingTypes({ + outFile, + endpoints: { llm: { env: "TEST_SERVING_ENDPOINT" } }, + noCache: true, + }); + + expect(mockFetchOpenApiSchema).not.toHaveBeenCalled(); + const output = vi.mocked(fs.writeFile).mock.calls[0][1] as string; + expect(output).toContain("llm:"); + expect(output).toContain("Record"); + }); + + test("skips generation when no endpoints configured and no env var", async () => { + await generateServingTypes({ + outFile, + noCache: true, + }); + + expect(mockFetchOpenApiSchema).not.toHaveBeenCalled(); + expect(fs.writeFile).not.toHaveBeenCalled(); + }); + + test("emits generic types when schema fetch returns null", async () => { + mockFetchOpenApiSchema.mockResolvedValue(null); + + await generateServingTypes({ + outFile, + endpoints: { llm: { env: "TEST_SERVING_ENDPOINT" } }, + noCache: true, + }); + + const output = vi.mocked(fs.writeFile).mock.calls[0][1] as string; + expect(output).toContain("llm:"); + expect(output).toContain("Record"); + }); + + test("resolves default endpoint from DATABRICKS_SERVING_ENDPOINT_NAME", async () => { + process.env.DATABRICKS_SERVING_ENDPOINT_NAME = "my-default-endpoint"; + mockFetchOpenApiSchema.mockResolvedValue({ + spec: CHAT_OPENAPI_SPEC, + pathKey: "/served-models/llm/invocations", + }); + + await generateServingTypes({ + outFile, + noCache: true, + }); + + expect(mockFetchOpenApiSchema).toHaveBeenCalledWith( + expect.anything(), + "my-default-endpoint", + undefined, + ); + + const output = vi.mocked(fs.writeFile).mock.calls[0][1] as string; + expect(output).toContain("default:"); + }); +}); diff --git a/packages/appkit/src/type-generator/serving/tests/server-file-extractor.test.ts b/packages/appkit/src/type-generator/serving/tests/server-file-extractor.test.ts new file mode 100644 index 00000000..4d1a73c7 --- /dev/null +++ b/packages/appkit/src/type-generator/serving/tests/server-file-extractor.test.ts @@ -0,0 +1,216 @@ +import fs from "node:fs"; +import path from "node:path"; +import { afterEach, describe, expect, test, vi } from "vitest"; +import { + extractServingEndpoints, + findServerFile, +} from "../server-file-extractor"; + +describe("findServerFile", () => { + afterEach(() => { + vi.restoreAllMocks(); + }); + + test("returns server/index.ts when it exists", () => { + vi.spyOn(fs, "existsSync").mockImplementation((p) => + String(p).endsWith(path.join("server", "index.ts")), + ); + expect(findServerFile("/app")).toBe( + path.join("/app", "server", "index.ts"), + ); + }); + + test("returns server/server.ts when index.ts does not exist", () => { + vi.spyOn(fs, "existsSync").mockImplementation((p) => + String(p).endsWith(path.join("server", "server.ts")), + ); + expect(findServerFile("/app")).toBe( + path.join("/app", "server", "server.ts"), + ); + }); + + test("returns null when no server file exists", () => { + vi.spyOn(fs, "existsSync").mockReturnValue(false); + expect(findServerFile("/app")).toBeNull(); + }); +}); + +describe("extractServingEndpoints", () => { + afterEach(() => { + vi.restoreAllMocks(); + }); + + function mockServerFile(content: string) { + vi.spyOn(fs, "readFileSync").mockReturnValue(content); + } + + test("extracts inline endpoints from serving() call", () => { + mockServerFile(` +import { createApp, serving } from '@databricks/appkit'; + +createApp({ + plugins: [ + serving({ + endpoints: { + demo: { env: "DATABRICKS_SERVING_ENDPOINT_NAME" }, + second: { env: "DATABRICKS_SERVING_ENDPOINT_SECOND" }, + } + }), + ], +}); +`); + + const result = extractServingEndpoints("/app/server/index.ts"); + expect(result).toEqual({ + demo: { env: "DATABRICKS_SERVING_ENDPOINT_NAME" }, + second: { env: "DATABRICKS_SERVING_ENDPOINT_SECOND" }, + }); + }); + + test("extracts servedModel when present", () => { + mockServerFile(` +import { createApp, serving } from '@databricks/appkit'; + +createApp({ + plugins: [ + serving({ + endpoints: { + demo: { env: "DATABRICKS_SERVING_ENDPOINT_NAME", servedModel: "my-model" }, + } + }), + ], +}); +`); + + const result = extractServingEndpoints("/app/server/index.ts"); + expect(result).toEqual({ + demo: { + env: "DATABRICKS_SERVING_ENDPOINT_NAME", + servedModel: "my-model", + }, + }); + }); + + test("returns null when serving() has no arguments", () => { + mockServerFile(` +import { createApp, serving } from '@databricks/appkit'; + +createApp({ + plugins: [serving()], +}); +`); + + const result = extractServingEndpoints("/app/server/index.ts"); + expect(result).toBeNull(); + }); + + test("returns null when serving() has config but no endpoints", () => { + mockServerFile(` +import { createApp, serving } from '@databricks/appkit'; + +createApp({ + plugins: [ + serving({ timeout: 5000 }), + ], +}); +`); + + const result = extractServingEndpoints("/app/server/index.ts"); + expect(result).toBeNull(); + }); + + test("returns null when serving() has empty config object", () => { + mockServerFile(` +import { createApp, serving } from '@databricks/appkit'; + +createApp({ + plugins: [serving({})], +}); +`); + + const result = extractServingEndpoints("/app/server/index.ts"); + expect(result).toBeNull(); + }); + + test("returns null when endpoints is a variable reference", () => { + mockServerFile(` +import { createApp, serving } from '@databricks/appkit'; + +const myEndpoints = { demo: { env: "DATABRICKS_SERVING_ENDPOINT_NAME" } }; +createApp({ + plugins: [ + serving({ endpoints: myEndpoints }), + ], +}); +`); + + const result = extractServingEndpoints("/app/server/index.ts"); + expect(result).toBeNull(); + }); + + test("returns null when no serving() call exists", () => { + mockServerFile(` +import { createApp, analytics } from '@databricks/appkit'; + +createApp({ + plugins: [analytics({})], +}); +`); + + const result = extractServingEndpoints("/app/server/index.ts"); + expect(result).toBeNull(); + }); + + test("returns null when server file cannot be read", () => { + vi.spyOn(fs, "readFileSync").mockImplementation(() => { + throw new Error("ENOENT"); + }); + + const result = extractServingEndpoints("/app/server/nonexistent.ts"); + expect(result).toBeNull(); + }); + + test("handles single-quoted env values", () => { + mockServerFile(` +import { createApp, serving } from '@databricks/appkit'; + +createApp({ + plugins: [ + serving({ + endpoints: { + demo: { env: 'DATABRICKS_SERVING_ENDPOINT_NAME' }, + } + }), + ], +}); +`); + + const result = extractServingEndpoints("/app/server/index.ts"); + expect(result).toEqual({ + demo: { env: "DATABRICKS_SERVING_ENDPOINT_NAME" }, + }); + }); + + test("handles endpoints with trailing commas", () => { + mockServerFile(` +import { createApp, serving } from '@databricks/appkit'; + +createApp({ + plugins: [ + serving({ + endpoints: { + demo: { env: "DATABRICKS_SERVING_ENDPOINT_NAME" }, + second: { env: "DATABRICKS_SERVING_ENDPOINT_SECOND" }, + }, + }), + ], +}); +`); + + const result = extractServingEndpoints("/app/server/index.ts"); + expect(result).toEqual({ + demo: { env: "DATABRICKS_SERVING_ENDPOINT_NAME" }, + second: { env: "DATABRICKS_SERVING_ENDPOINT_SECOND" }, + }); + }); +}); diff --git a/packages/appkit/src/type-generator/serving/tests/vite-plugin.test.ts b/packages/appkit/src/type-generator/serving/tests/vite-plugin.test.ts new file mode 100644 index 00000000..074d3d44 --- /dev/null +++ b/packages/appkit/src/type-generator/serving/tests/vite-plugin.test.ts @@ -0,0 +1,186 @@ +import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; + +const mockGenerateServingTypes = vi.fn(async () => {}); +const mockFindServerFile = vi.fn((): string | null => null); +const mockExtractServingEndpoints = vi.fn( + (): Record | null => null, +); + +vi.mock("../generator", () => ({ + generateServingTypes: (...args: any[]) => mockGenerateServingTypes(...args), +})); + +vi.mock("../server-file-extractor", () => ({ + findServerFile: (...args: any[]) => mockFindServerFile(...args), + extractServingEndpoints: (...args: any[]) => + mockExtractServingEndpoints(...args), +})); + +import { appKitServingTypesPlugin } from "../vite-plugin"; + +describe("appKitServingTypesPlugin", () => { + const originalEnv = { ...process.env }; + + beforeEach(() => { + mockGenerateServingTypes.mockReset(); + mockFindServerFile.mockReset(); + mockExtractServingEndpoints.mockReset(); + }); + + afterEach(() => { + process.env = { ...originalEnv }; + vi.restoreAllMocks(); + }); + + describe("apply()", () => { + test("returns true when explicit endpoints provided", () => { + const plugin = appKitServingTypesPlugin({ + endpoints: { llm: { env: "LLM_ENDPOINT" } }, + }); + expect((plugin as any).apply()).toBe(true); + }); + + test("returns true when DATABRICKS_SERVING_ENDPOINT_NAME is set", () => { + process.env.DATABRICKS_SERVING_ENDPOINT_NAME = "my-endpoint"; + const plugin = appKitServingTypesPlugin(); + expect((plugin as any).apply()).toBe(true); + }); + + test("returns true when server file found in cwd", () => { + mockFindServerFile.mockReturnValueOnce("/app/server/index.ts"); + const plugin = appKitServingTypesPlugin(); + expect((plugin as any).apply()).toBe(true); + }); + + test("returns true when server file found in parent dir", () => { + mockFindServerFile + .mockReturnValueOnce(null) // cwd check + .mockReturnValueOnce("/app/server/index.ts"); // parent check + const plugin = appKitServingTypesPlugin(); + expect((plugin as any).apply()).toBe(true); + }); + + test("returns false when nothing configured", () => { + delete process.env.DATABRICKS_SERVING_ENDPOINT_NAME; + mockFindServerFile.mockReturnValue(null); + const plugin = appKitServingTypesPlugin(); + expect((plugin as any).apply()).toBe(false); + }); + }); + + describe("configResolved()", () => { + test("resolves outFile relative to config.root", async () => { + const plugin = appKitServingTypesPlugin({ + endpoints: { llm: { env: "LLM" } }, + }); + (plugin as any).configResolved({ root: "/app/client" }); + await (plugin as any).buildStart(); + + expect(mockGenerateServingTypes).toHaveBeenCalledWith( + expect.objectContaining({ + outFile: expect.stringContaining( + "/app/client/src/appKitServingTypes.d.ts", + ), + }), + ); + }); + + test("uses custom outFile when provided", async () => { + const plugin = appKitServingTypesPlugin({ + outFile: "types/serving.d.ts", + endpoints: { llm: { env: "LLM" } }, + }); + (plugin as any).configResolved({ root: "/app/client" }); + await (plugin as any).buildStart(); + + expect(mockGenerateServingTypes).toHaveBeenCalledWith( + expect.objectContaining({ + outFile: expect.stringContaining("types/serving.d.ts"), + }), + ); + }); + }); + + describe("buildStart()", () => { + test("calls generateServingTypes with explicit endpoints", async () => { + const endpoints = { llm: { env: "LLM_ENDPOINT" } }; + const plugin = appKitServingTypesPlugin({ endpoints }); + (plugin as any).configResolved({ root: "/app/client" }); + + await (plugin as any).buildStart(); + + expect(mockGenerateServingTypes).toHaveBeenCalledWith( + expect.objectContaining({ + endpoints, + noCache: false, + }), + ); + }); + + test("extracts endpoints from server file when not explicit", async () => { + const extracted = { llm: { env: "LLM_EP" } }; + mockFindServerFile.mockReturnValue("/app/server/index.ts"); + mockExtractServingEndpoints.mockReturnValue(extracted); + + const plugin = appKitServingTypesPlugin(); + (plugin as any).configResolved({ root: "/app/client" }); + await (plugin as any).buildStart(); + + expect(mockGenerateServingTypes).toHaveBeenCalledWith( + expect.objectContaining({ endpoints: extracted }), + ); + }); + + test("passes undefined endpoints when no server file found", async () => { + mockFindServerFile.mockReturnValue(null); + + const plugin = appKitServingTypesPlugin(); + (plugin as any).configResolved({ root: "/app/client" }); + await (plugin as any).buildStart(); + + expect(mockGenerateServingTypes).toHaveBeenCalledWith( + expect.objectContaining({ endpoints: undefined }), + ); + }); + + test("passes undefined when AST extraction returns null", async () => { + mockFindServerFile.mockReturnValue("/app/server/index.ts"); + mockExtractServingEndpoints.mockReturnValue(null); + + const plugin = appKitServingTypesPlugin(); + (plugin as any).configResolved({ root: "/app/client" }); + await (plugin as any).buildStart(); + + expect(mockGenerateServingTypes).toHaveBeenCalledWith( + expect.objectContaining({ endpoints: undefined }), + ); + }); + + test("swallows errors in dev mode", async () => { + process.env.NODE_ENV = "development"; + mockGenerateServingTypes.mockRejectedValue(new Error("fetch failed")); + + const plugin = appKitServingTypesPlugin({ + endpoints: { llm: { env: "LLM" } }, + }); + (plugin as any).configResolved({ root: "/app/client" }); + + // Should not throw + await expect((plugin as any).buildStart()).resolves.toBeUndefined(); + }); + + test("rethrows errors in production mode", async () => { + process.env.NODE_ENV = "production"; + mockGenerateServingTypes.mockRejectedValue(new Error("fetch failed")); + + const plugin = appKitServingTypesPlugin({ + endpoints: { llm: { env: "LLM" } }, + }); + (plugin as any).configResolved({ root: "/app/client" }); + + await expect((plugin as any).buildStart()).rejects.toThrow( + "fetch failed", + ); + }); + }); +}); diff --git a/packages/appkit/src/type-generator/serving/vite-plugin.ts b/packages/appkit/src/type-generator/serving/vite-plugin.ts new file mode 100644 index 00000000..f94306c4 --- /dev/null +++ b/packages/appkit/src/type-generator/serving/vite-plugin.ts @@ -0,0 +1,109 @@ +import path from "node:path"; +import type { Plugin } from "vite"; +import { createLogger } from "../../logging/logger"; +import type { EndpointConfig } from "../../plugins/serving/types"; +import { generateServingTypes } from "../index"; +import { + extractServingEndpoints, + findServerFile, +} from "./server-file-extractor"; + +const logger = createLogger("type-generator:serving:vite-plugin"); + +interface AppKitServingTypesPluginOptions { + /** Path to the output .d.ts file (relative to client root). Default: "src/appKitServingTypes.d.ts" */ + outFile?: string; + /** Endpoint config override. If omitted, auto-discovers from the server file or falls back to DATABRICKS_SERVING_ENDPOINT_NAME env var. */ + endpoints?: Record; +} + +/** + * Vite plugin to generate TypeScript types for AppKit serving endpoints. + * Fetches OpenAPI schemas from Databricks and generates a .d.ts with + * ServingEndpointRegistry module augmentation. + * + * Endpoint discovery order: + * 1. Explicit `endpoints` option (override) + * 2. AST extraction from server file (server/index.ts or server/server.ts) + * 3. DATABRICKS_SERVING_ENDPOINT_NAME env var (single default endpoint) + */ +export function appKitServingTypesPlugin( + options?: AppKitServingTypesPluginOptions, +): Plugin { + let outFile: string; + let projectRoot: string; + + async function generate() { + try { + // Resolve endpoints: explicit option > server file AST > env var fallback (handled by generator) + let endpoints = options?.endpoints; + if (!endpoints) { + const serverFile = findServerFile(projectRoot); + if (serverFile) { + endpoints = extractServingEndpoints(serverFile) ?? undefined; + } + } + + await generateServingTypes({ + outFile, + endpoints, + noCache: false, + }); + } catch (error) { + if (process.env.NODE_ENV === "production") { + throw error; + } + logger.error("Error generating serving types: %O", error); + } + } + + return { + name: "appkit-serving-types", + + apply() { + // Fast checks — no AST parsing here + if (options?.endpoints && Object.keys(options.endpoints).length > 0) { + return true; + } + + if (process.env.DATABRICKS_SERVING_ENDPOINT_NAME) { + return true; + } + + // Check if a server file exists (may contain serving() config) + // Use process.cwd() for apply() since configResolved hasn't run yet + if (findServerFile(process.cwd())) { + return true; + } + + // Also check parent dir (for when cwd is client/) + const parentDir = path.resolve(process.cwd(), ".."); + if (findServerFile(parentDir)) { + return true; + } + + logger.debug( + "No serving endpoints configured. Skipping type generation.", + ); + return false; + }, + + configResolved(config) { + // Resolve project root: go up one level from Vite root (client dir) + // This handles both: + // - pnpm dev: process.cwd() is app root, config.root is client/ + // - pnpm build: process.cwd() is client/ (cd client && vite build), config.root is client/ + projectRoot = path.resolve(config.root, ".."); + outFile = path.resolve( + config.root, + options?.outFile ?? "src/appKitServingTypes.d.ts", + ); + }, + + async buildStart() { + await generate(); + }, + + // No configureServer / watcher — schemas change on endpoint redeploy, not on file edit + }; +} diff --git a/packages/shared/src/cli/commands/generate-types.ts b/packages/shared/src/cli/commands/generate-types.ts index 06c8b016..3be45091 100644 --- a/packages/shared/src/cli/commands/generate-types.ts +++ b/packages/shared/src/cli/commands/generate-types.ts @@ -12,35 +12,34 @@ async function runGenerateTypes( options?: { noCache?: boolean }, ) { try { - const resolvedWarehouseId = - warehouseId || process.env.DATABRICKS_WAREHOUSE_ID; + const resolvedRootDir = rootDir || process.cwd(); + const noCache = options?.noCache || false; - if (!resolvedWarehouseId) { - process.exit(0); - } + const typeGen = await import("@databricks/appkit/type-generator"); - // Try to import the type generator from @databricks/appkit - const { generateFromEntryPoint } = await import( - "@databricks/appkit/type-generator" - ); + // Generate analytics query types (requires warehouse ID) + const resolvedWarehouseId = + warehouseId || process.env.DATABRICKS_WAREHOUSE_ID; - const resolvedRootDir = rootDir || process.cwd(); - const resolvedOutFile = - outFile || path.join(process.cwd(), "client/src/appKitTypes.d.ts"); + if (resolvedWarehouseId) { + const resolvedOutFile = + outFile || path.join(process.cwd(), "client/src/appKitTypes.d.ts"); - const queryFolder = path.join(resolvedRootDir, "config/queries"); - if (!fs.existsSync(queryFolder)) { - console.warn( - `Warning: No queries found at ${queryFolder}. Skipping type generation.`, - ); - return; + const queryFolder = path.join(resolvedRootDir, "config/queries"); + if (fs.existsSync(queryFolder)) { + await typeGen.generateFromEntryPoint({ + queryFolder, + outFile: resolvedOutFile, + warehouseId: resolvedWarehouseId, + noCache, + }); + } } - await generateFromEntryPoint({ - queryFolder, - outFile: resolvedOutFile, - warehouseId: resolvedWarehouseId, - noCache: options?.noCache || false, + // Generate serving endpoint types (no warehouse required) + await typeGen.generateServingTypes({ + outFile: path.join(process.cwd(), "client/src/appKitServingTypes.d.ts"), + noCache, }); } catch (error) { if ( diff --git a/packages/shared/src/cli/commands/type-generator.d.ts b/packages/shared/src/cli/commands/type-generator.d.ts index debda666..ce69781f 100644 --- a/packages/shared/src/cli/commands/type-generator.d.ts +++ b/packages/shared/src/cli/commands/type-generator.d.ts @@ -1,9 +1,14 @@ // Type declarations for optional @databricks/appkit/type-generator module declare module "@databricks/appkit/type-generator" { export function generateFromEntryPoint(options: { - queryFolder: string; + queryFolder?: string; outFile: string; warehouseId: string; - noCache: boolean; + noCache?: boolean; + }): Promise; + + export function generateServingTypes(options: { + outFile: string; + noCache?: boolean; }): Promise; } diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 199fcfb8..9ca11b81 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -242,6 +242,9 @@ importers: packages/appkit: dependencies: + '@ast-grep/napi': + specifier: 0.37.0 + version: 0.37.0 '@databricks/lakebase': specifier: workspace:* version: link:../lakebase