Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 31 additions & 4 deletions packages/genkit_openai/lib/genkit_openai.dart
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,23 @@ import 'dart:async';

import 'package:genkit/plugin.dart';

import 'src/chat.dart' as chat;
import 'src/chat.dart' as chat_lib;
import 'src/openai_plugin.dart';
import 'src/stt.dart' as stt_lib;

export 'src/chat.dart' show OpenAIChatOptions, OpenAIOptions;
export 'src/converters.dart' show GenkitConverter;
export 'src/stt.dart'
show
OpenAISttOptions,
buildTranscriptionRequest,
buildTranslationRequest,
parseSttModelOptions,
sttModelInfo,
sttModelOptionsSchema,
transcriptionModelIds,
transcriptionToModelResponse,
whisperModelIds;
export 'src/utils.dart'
show
defaultModelInfo,
Expand Down Expand Up @@ -65,11 +77,26 @@ class OpenAICompatPluginHandle {
);
}

/// Reference to a model
ModelRef<chat.OpenAIChatOptions> model(String name) {
/// Reference to a chat model.
ModelRef<chat_lib.OpenAIChatOptions> model(String name) {
return modelRef(
'openai/$name',
customOptions: chat_lib.chatModelOptionsSchema(),
);
}

/// Reference to a speech-to-text model.
///
/// Works with both Whisper models (e.g. `'whisper-1'`) and GPT transcription
/// models (e.g. `'gpt-4o-transcribe'`, `'gpt-4o-mini-transcribe'`).
ModelRef<stt_lib.OpenAISttOptions> stt(
String name, {
stt_lib.OpenAISttOptions? config,
}) {
return modelRef(
'openai/$name',
customOptions: chat.chatModelOptionsSchema(),
customOptions: stt_lib.sttModelOptionsSchema(),
config: config,
);
}
}
186 changes: 113 additions & 73 deletions packages/genkit_openai/lib/src/openai_plugin.dart
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ import 'package:genkit/plugin.dart';
import 'package:openai_dart/openai_dart.dart' as sdk;

import '../genkit_openai.dart';
import 'chat.dart' as chat;
import 'chat.dart' as chat_lib;
import 'stt.dart' as stt_lib;
import 'utils.dart';

/// Core plugin implementation
class OpenAIPlugin extends GenkitPlugin {
Expand Down Expand Up @@ -56,12 +58,14 @@ class OpenAIPlugin extends GenkitPlugin {
for (final modelId in availableModelIds) {
final modelType = getModelType(modelId);

if (modelType != 'chat' && modelType != 'unknown') {
continue;
if (modelType == 'chat' || modelType == 'unknown') {
final info = modelInfoFor(modelId);
actions.add(_createModel(modelId, info));
} else if (modelType == 'audio') {
if (_isSttModelId(modelId)) {
actions.add(_createSttModel(modelId));
}
}

final info = modelInfoFor(modelId);
actions.add(_createModel(modelId, info));
}
} catch (e) {
throw GenkitException(
Expand All @@ -83,13 +87,7 @@ class OpenAIPlugin extends GenkitPlugin {
Future<List<String>> _fetchAvailableModels() async {
final resolvedConfig = await _resolveClientConfig();

final client = sdk.OpenAIClient(
config: sdk.OpenAIConfig(
authProvider: sdk.ApiKeyProvider(resolvedConfig.apiKey),
baseUrl: resolvedConfig.baseUrl ?? 'https://api.openai.com/v1',
defaultHeaders: resolvedConfig.headers ?? const {},
),
);
final client = buildOpenAIClient(resolvedConfig);

try {
final response = await client.models.list();
Expand All @@ -106,7 +104,7 @@ class OpenAIPlugin extends GenkitPlugin {
}
}

Future<_ResolvedClientConfig> _resolveClientConfig() async {
Future<OpenAIClientConfig> _resolveClientConfig() async {
final configuredApiKey = await _resolveApiKey();
if (configuredApiKey == null || configuredApiKey.trim().isEmpty) {
throw GenkitException(
Expand All @@ -115,7 +113,7 @@ class OpenAIPlugin extends GenkitPlugin {
);
}

return _ResolvedClientConfig(
return OpenAIClientConfig(
apiKey: configuredApiKey.trim(),
baseUrl: baseUrl,
headers: headers,
Expand All @@ -135,25 +133,32 @@ class OpenAIPlugin extends GenkitPlugin {
list() async {
try {
final modelIds = await _fetchAvailableModels();
final modelMetadataList =
final metadataList =
<ActionMetadata<dynamic, dynamic, dynamic, dynamic>>[];

for (final modelId in modelIds) {
final modelType = getModelType(modelId);
if (modelType != 'chat' && modelType != 'unknown') {
continue;
}

modelMetadataList.add(
modelMetadata(
'openai/$modelId',
modelInfo: modelInfoFor(modelId),
customOptions: chat.chatModelOptionsSchema(),
),
);
if (modelType == 'chat' || modelType == 'unknown') {
metadataList.add(
modelMetadata(
'openai/$modelId',
modelInfo: modelInfoFor(modelId),
customOptions: chat_lib.chatModelOptionsSchema(),
),
);
} else if (modelType == 'audio' && _isSttModelId(modelId)) {
metadataList.add(
modelMetadata(
'openai/$modelId',
modelInfo: stt_lib.sttModelInfo,
customOptions: stt_lib.sttModelOptionsSchema(),
),
);
}
}

return modelMetadataList;
return metadataList;
} catch (e, stackTrace) {
throw GenkitException(
'Error listing models from OpenAI: $e',
Expand All @@ -165,41 +170,106 @@ class OpenAIPlugin extends GenkitPlugin {

@override
Action? resolve(String actionType, String name) {
if (actionType == 'model') {
return _createModel(name, null);
}
return null;
if (actionType != 'model') return null;
if (_isSttModelId(name)) return _createSttModel(name);
return _createModel(name, null);
}

/// Whether [modelId] names a speech-to-text model.
///
/// The check is case-insensitive and matches any id containing `whisper`
/// or `transcribe`.
static bool _isSttModelId(String modelId) {
  const sttMarkers = ['whisper', 'transcribe'];
  final normalized = modelId.toLowerCase();
  return sttMarkers.any((marker) => normalized.contains(marker));
}

/// Creates a Genkit [Model] that calls the OpenAI speech-to-text API.
///
/// The model is registered as `openai/<modelName>` and advertises
/// `stt_lib.sttModelInfo` as its metadata. Every invocation resolves the
/// client configuration, builds a fresh client, and closes that client in a
/// `finally` block regardless of the outcome.
///
/// Request routing, in order of precedence:
///
/// * `translate: true` on a Whisper model sends the request to the
///   translations endpoint.
/// * `responseFormat: 'verbose_json'` uses the verbose transcription call.
/// * Anything else uses the plain transcription call.
///
/// The `translate` option is ignored for non-Whisper models. Any error raised
/// while building or sending the request is passed to
/// `rethrowAsGenkitException` with the `'STT'` label.
Model<stt_lib.OpenAISttOptions> _createSttModel(String modelName) {
  // Compare case-insensitively so this agrees with [_isSttModelId], which
  // lower-cases the id before deciding that a model is an STT model. A
  // case-sensitive check here would silently drop `translate: true` for ids
  // such as `Whisper-1`.
  final isWhisperModel = modelName.toLowerCase().contains('whisper');

  return Model<stt_lib.OpenAISttOptions>(
    name: 'openai/$modelName',
    customOptions: stt_lib.sttModelOptionsSchema(),
    metadata: {'model': stt_lib.sttModelInfo.toJson()},
    fn: (req, ctx) async {
      final modelRequest = req!;
      final options = stt_lib.parseSttModelOptions(modelRequest.config);

      final resolvedConfig = await _resolveClientConfig();
      final client = buildOpenAIClient(resolvedConfig);

      try {
        final shouldTranslate = options.translate == true && isWhisperModel;

        if (shouldTranslate) {
          final translationReq = stt_lib.buildTranslationRequest(
            modelId: modelName,
            request: modelRequest,
            options: options,
          );
          final response = await client.audio.translations.create(
            translationReq,
          );
          return stt_lib.transcriptionToModelResponse(
            response.text,
            raw: response.toJson(),
          );
        }

        final transcriptionReq = stt_lib.buildTranscriptionRequest(
          modelId: modelName,
          request: modelRequest,
          options: options,
        );

        // The verbose format is fetched through a dedicated SDK call.
        if (options.responseFormat == 'verbose_json') {
          final response = await client.audio.transcriptions.createVerbose(
            transcriptionReq,
          );
          return stt_lib.transcriptionToModelResponse(
            response.text,
            raw: response.toJson(),
          );
        }

        final response = await client.audio.transcriptions.create(
          transcriptionReq,
        );
        return stt_lib.transcriptionToModelResponse(
          response.text,
          raw: response.toJson(),
        );
      } catch (e, stackTrace) {
        // NOTE(review): presumably never returns normally (always throws);
        // confirm against its declaration in utils.dart.
        rethrowAsGenkitException(e, stackTrace, 'STT');
      } finally {
        client.close();
      }
    },
  );
}

Model _createModel(String modelName, ModelInfo? info) {
final modelInfo = info ?? modelInfoFor(modelName);

return Model(
name: 'openai/$modelName',
customOptions: chat.chatModelOptionsSchema(),
customOptions: chat_lib.chatModelOptionsSchema(),
metadata: {'model': modelInfo.toJson()},
fn: (req, ctx) async {
final modelRequest = req!;
final options = chat.parseChatModelOptions(modelRequest.config);
final options = chat_lib.parseChatModelOptions(modelRequest.config);

final resolvedConfig = await _resolveClientConfig();
final client = sdk.OpenAIClient(
config: sdk.OpenAIConfig(
authProvider: sdk.ApiKeyProvider(resolvedConfig.apiKey),
baseUrl: resolvedConfig.baseUrl ?? 'https://api.openai.com/v1',
defaultHeaders: resolvedConfig.headers ?? const {},
),
);
final client = buildOpenAIClient(resolvedConfig);

try {
final supports = modelInfo.supports;
final supportsTools = supports?['tools'] == true;

final isJsonMode = chat.isJsonStructuredOutput(
final isJsonMode = chat_lib.isJsonStructuredOutput(
modelRequest.output?.format,
modelRequest.output?.contentType,
);
final responseFormat = chat.buildOpenAIResponseFormat(
final responseFormat = chat_lib.buildOpenAIResponseFormat(
modelRequest.output?.schema,
);
final request = sdk.ChatCompletionCreateRequest(
Expand Down Expand Up @@ -227,25 +297,7 @@ class OpenAIPlugin extends GenkitPlugin {
return await _handleNonStreaming(client, request);
}
} catch (e, stackTrace) {
if (e is GenkitException) {
rethrow;
}

StatusCodes? status;
String? details;

if (e is sdk.ApiException) {
status = StatusCodes.fromHttpStatus(e.statusCode);
details = e.body?.toString();
}

throw GenkitException(
'OpenAI API error: $e',
status: status,
details: details ?? e.toString(),
underlyingException: e,
stackTrace: stackTrace,
);
rethrowAsGenkitException(e, stackTrace, 'Chat');
} finally {
client.close();
}
Expand Down Expand Up @@ -329,15 +381,3 @@ class OpenAIPlugin extends GenkitPlugin {
);
}
}

final class _ResolvedClientConfig {
final String apiKey;
final String? baseUrl;
final Map<String, String>? headers;

const _ResolvedClientConfig({
required this.apiKey,
required this.baseUrl,
required this.headers,
});
}
Loading