Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 24 additions & 4 deletions packages/genkit_openai/lib/genkit_openai.dart
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,21 @@ import 'dart:async';

import 'package:genkit/plugin.dart';

import 'src/chat.dart' as chat;
import 'src/chat.dart' as chat_lib;
import 'src/openai_plugin.dart';
import 'src/tts.dart' as tts_lib;

export 'src/chat.dart' show OpenAIChatOptions, OpenAIOptions;
export 'src/converters.dart' show GenkitConverter;
export 'src/tts.dart'
show
OpenAITtsOptions,
audioMimeTypes,
parseTtsOptions,
supportedTtsModels,
ttsModelInfo,
ttsModelRef,
ttsOptionsSchema;
export 'src/utils.dart'
show
defaultModelInfo,
Expand Down Expand Up @@ -65,11 +75,21 @@ class OpenAICompatPluginHandle {
);
}

/// Reference to a model
ModelRef<chat.OpenAIChatOptions> model(String name) {
/// Reference to a chat model.
///
/// Use [name] to select a supported model such as `gpt-4o`, `gpt-4o-mini`,
/// or `o3-mini`.
ModelRef<chat_lib.OpenAIChatOptions> model(String name) {
return modelRef(
'openai/$name',
customOptions: chat.chatModelOptionsSchema(),
customOptions: chat_lib.chatModelOptionsSchema(),
);
}

/// Reference to a TTS (text-to-speech) model.
///
/// Use [name] to select a supported model such as `tts-1`, `tts-1-hd`,
/// or `gpt-4o-mini-tts`.
ModelRef<tts_lib.OpenAITtsOptions> speech(String name) {
  // Delegates straight to the TTS library's reference factory.
  return tts_lib.ttsModelRef(name);
}
}
164 changes: 101 additions & 63 deletions packages/genkit_openai/lib/src/openai_plugin.dart
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ import 'package:genkit/plugin.dart';
import 'package:openai_dart/openai_dart.dart' as sdk;

import '../genkit_openai.dart';
import 'chat.dart' as chat;
import 'chat.dart' as chat_lib;
import 'tts.dart' as tts_lib;
import 'utils.dart';

/// Core plugin implementation
class OpenAIPlugin extends GenkitPlugin {
Expand Down Expand Up @@ -48,20 +50,20 @@ class OpenAIPlugin extends GenkitPlugin {
Future<List<Action>> init() async {
final actions = <Action>[];

// Fetch and register models from OpenAI API only for default OpenAI host.
// Fetch and register models from OpenAI API only for the default host.
if (baseUrl == null) {
try {
final availableModelIds = await _fetchAvailableModels();

for (final modelId in availableModelIds) {
final modelType = getModelType(modelId);

if (modelType != 'chat' && modelType != 'unknown') {
continue;
if (modelType == 'chat' || modelType == 'unknown') {
actions.add(_createModel(modelId, modelInfoFor(modelId)));
} else if (modelType == 'audio' &&
tts_lib.supportedTtsModels.contains(modelId)) {
actions.add(_createTtsModel(modelId));
}

final info = modelInfoFor(modelId);
actions.add(_createModel(modelId, info));
}
} catch (e) {
throw GenkitException(
Expand All @@ -71,42 +73,27 @@ class OpenAIPlugin extends GenkitPlugin {
}
}

// Register custom models
// Register custom models.
for (final model in customModels) {
actions.add(_createModel(model.name, model.info));
}

return actions;
}

/// Fetch available model IDs from OpenAI API
/// Fetch available model IDs from OpenAI API.
Future<List<String>> _fetchAvailableModels() async {
final resolvedConfig = await _resolveClientConfig();

final client = sdk.OpenAIClient(
config: sdk.OpenAIConfig(
authProvider: sdk.ApiKeyProvider(resolvedConfig.apiKey),
baseUrl: resolvedConfig.baseUrl ?? 'https://api.openai.com/v1',
defaultHeaders: resolvedConfig.headers ?? const {},
),
);

final client = buildOpenAIClient(resolvedConfig);
try {
final response = await client.models.list();
final modelIds = <String>[];

// Collect all model IDs
for (final model in response.data) {
modelIds.add(model.id);
}

return modelIds;
return response.data.map((m) => m.id).toList();
} finally {
client.close();
}
}

Future<_ResolvedClientConfig> _resolveClientConfig() async {
Future<OpenAIClientConfig> _resolveClientConfig() async {
final configuredApiKey = await _resolveApiKey();
if (configuredApiKey == null || configuredApiKey.trim().isEmpty) {
throw GenkitException(
Expand All @@ -115,7 +102,7 @@ class OpenAIPlugin extends GenkitPlugin {
);
}

return _ResolvedClientConfig(
return OpenAIClientConfig(
apiKey: configuredApiKey.trim(),
baseUrl: baseUrl,
headers: headers,
Expand All @@ -135,25 +122,36 @@ class OpenAIPlugin extends GenkitPlugin {
list() async {
try {
final modelIds = await _fetchAvailableModels();
final modelMetadataList =
<ActionMetadata<dynamic, dynamic, dynamic, dynamic>>[];
final metadata = <ActionMetadata<dynamic, dynamic, dynamic, dynamic>>[];

for (final modelId in modelIds) {
final modelType = getModelType(modelId);
if (modelType != 'chat' && modelType != 'unknown') {
continue;

if (modelType == 'chat' || modelType == 'unknown') {
metadata.add(
modelMetadata(
'openai/$modelId',
modelInfo: modelInfoFor(modelId),
customOptions: chat_lib.chatModelOptionsSchema(),
),
);
}
}

modelMetadataList.add(
modelMetadata(
'openai/$modelId',
modelInfo: modelInfoFor(modelId),
customOptions: chat.chatModelOptionsSchema(),
),
);
// Include known TTS models when using the default OpenAI endpoint.
if (baseUrl == null) {
for (final modelId in tts_lib.supportedTtsModels) {
metadata.add(
modelMetadata(
'openai/$modelId',
modelInfo: tts_lib.ttsModelInfo(modelId),
customOptions: tts_lib.ttsOptionsSchema(),
),
);
}
}

return modelMetadataList;
return metadata;
} catch (e, stackTrace) {
throw GenkitException(
'Error listing models from OpenAI: $e',
Expand All @@ -166,40 +164,90 @@ class OpenAIPlugin extends GenkitPlugin {
/// Resolves a model action by [name] on demand.
///
/// Names listed in [tts_lib.supportedTtsModels] resolve to a speech action;
/// every other name falls through to a chat-completion model. Action types
/// other than `'model'` are not handled by this plugin.
@override
Action? resolve(String actionType, String name) {
  // Guard clause: only model actions are resolvable here.
  if (actionType != 'model') return null;
  return tts_lib.supportedTtsModels.contains(name)
      ? _createTtsModel(name)
      : _createModel(name, null);
}

/// Creates a [Model] action that calls the OpenAI TTS endpoint.
///
/// The returned action validates that the request carries non-empty text,
/// issues a single (non-streaming) speech-synthesis call, and returns the
/// audio wrapped as a model response. Throws [GenkitException] with
/// `INVALID_ARGUMENT` for missing or empty input text.
Model _createTtsModel(String modelName) {
return Model(
name: 'openai/$modelName',
customOptions: tts_lib.ttsOptionsSchema(),
metadata: {'model': tts_lib.ttsModelInfo(modelName).toJson()},
fn: (req, ctx) async {
final request = req!;
final options = tts_lib.parseTtsOptions(request.config);
// Default to mp3 output when the caller does not pick a format.
final responseFormat = options.responseFormat ?? 'mp3';
// An explicit `version` in the options overrides the registered model name.
final modelVersion = options.version ?? modelName;

if (request.messages.isEmpty) {
throw GenkitException(
'TTS requires a text input message.',
status: StatusCodes.INVALID_ARGUMENT,
);
}
// Only the first message's text is synthesized — NOTE(review): confirm
// ignoring any subsequent messages is intended for multi-message requests.
final inputText = request.messages.first.text;
if (inputText.isEmpty) {
throw GenkitException(
'TTS requires non-empty text input.',
status: StatusCodes.INVALID_ARGUMENT,
);
}

// Resolve credentials/endpoint lazily, per call, and build a fresh client.
final resolvedConfig = await _resolveClientConfig();
final client = buildOpenAIClient(resolvedConfig);
try {
final audioBytes = await client.audio.speech.create(
sdk.SpeechRequest(
model: modelVersion,
input: inputText,
voice: tts_lib.parseSpeechVoice(options.voice),
responseFormat: tts_lib.parseSpeechResponseFormat(responseFormat),
speed: options.speed,
),
);
// Wrap raw audio bytes plus request metadata into a ModelResponse.
return tts_lib.speechToModelResponse(audioBytes, responseFormat, {
'model': modelVersion,
'responseFormat': responseFormat,
});
} catch (e, stackTrace) {
// Presumably wraps SDK errors as GenkitException and rethrows — verify
// rethrowAsGenkitException never returns normally.
rethrowAsGenkitException(e, stackTrace, 'TTS');
} finally {
// Always release the HTTP client, even when the call fails.
client.close();
}
},
);
}

// ─── Chat model factory ────────────────────────────────────────────────────────

Model _createModel(String modelName, ModelInfo? info) {
final modelInfo = info ?? modelInfoFor(modelName);

return Model(
name: 'openai/$modelName',
customOptions: chat.chatModelOptionsSchema(),
customOptions: chat_lib.chatModelOptionsSchema(),
metadata: {'model': modelInfo.toJson()},
fn: (req, ctx) async {
final modelRequest = req!;
final options = chat.parseChatModelOptions(modelRequest.config);
final options = chat_lib.parseChatModelOptions(modelRequest.config);

final resolvedConfig = await _resolveClientConfig();
final client = sdk.OpenAIClient(
config: sdk.OpenAIConfig(
authProvider: sdk.ApiKeyProvider(resolvedConfig.apiKey),
baseUrl: resolvedConfig.baseUrl ?? 'https://api.openai.com/v1',
defaultHeaders: resolvedConfig.headers ?? const {},
),
);
final client = buildOpenAIClient(resolvedConfig);

try {
final supports = modelInfo.supports;
final supportsTools = supports?['tools'] == true;

final isJsonMode = chat.isJsonStructuredOutput(
final isJsonMode = chat_lib.isJsonStructuredOutput(
modelRequest.output?.format,
modelRequest.output?.contentType,
);
final responseFormat = chat.buildOpenAIResponseFormat(
final responseFormat = chat_lib.buildOpenAIResponseFormat(
modelRequest.output?.schema,
);
final request = sdk.ChatCompletionCreateRequest(
Expand Down Expand Up @@ -253,6 +301,8 @@ class OpenAIPlugin extends GenkitPlugin {
);
}

// ─── Streaming helpers ─────────────────────────────────────────────────────────

/// Handle streaming response
Future<ModelResponse> _handleStreaming(
sdk.OpenAIClient client,
Expand Down Expand Up @@ -329,15 +379,3 @@ class OpenAIPlugin extends GenkitPlugin {
);
}
}

/// Immutable bundle of resolved OpenAI client settings.
///
/// Produced after the API key has been looked up and validated; [baseUrl]
/// and [headers] stay `null` when the plugin targets the default OpenAI
/// endpoint with no extra request headers.
final class _ResolvedClientConfig {
// The API key (trimmed by the producer before construction).
final String apiKey;
// Optional override of the API base URL; `null` means the default host.
final String? baseUrl;
// Optional extra HTTP headers to send with every request.
final Map<String, String>? headers;

const _ResolvedClientConfig({
required this.apiKey,
required this.baseUrl,
required this.headers,
});
}
Loading