diff --git a/packages/genkit_openai/lib/genkit_openai.dart b/packages/genkit_openai/lib/genkit_openai.dart index 1c33cb47..388509fb 100644 --- a/packages/genkit_openai/lib/genkit_openai.dart +++ b/packages/genkit_openai/lib/genkit_openai.dart @@ -16,11 +16,21 @@ import 'dart:async'; import 'package:genkit/plugin.dart'; -import 'src/chat.dart' as chat; +import 'src/chat.dart' as chat_lib; import 'src/openai_plugin.dart'; +import 'src/tts.dart' as tts_lib; export 'src/chat.dart' show OpenAIChatOptions, OpenAIOptions; export 'src/converters.dart' show GenkitConverter; +export 'src/tts.dart' + show + OpenAITtsOptions, + audioMimeTypes, + parseTtsOptions, + supportedTtsModels, + ttsModelInfo, + ttsModelRef, + ttsOptionsSchema; export 'src/utils.dart' show defaultModelInfo, @@ -65,11 +75,21 @@ class OpenAICompatPluginHandle { ); } - /// Reference to a model - ModelRef model(String name) { + /// Reference to a chat model. + /// + /// Use [name] to select a supported model such as `gpt-4o`, `gpt-4o-mini`, + /// or `o3-mini`. + ModelRef model(String name) { return modelRef( 'openai/$name', - customOptions: chat.chatModelOptionsSchema(), + customOptions: chat_lib.chatModelOptionsSchema(), ); } + + /// Reference to a TTS (text-to-speech) model. + /// + /// Use [name] to select a supported model such as `tts-1`, `tts-1-hd`, + /// or `gpt-4o-mini-tts`. + ModelRef speech(String name) => + tts_lib.ttsModelRef(name); } diff --git a/packages/genkit_openai/lib/src/openai_plugin.dart b/packages/genkit_openai/lib/src/openai_plugin.dart index d95c44cf..50e4a49a 100644 --- a/packages/genkit_openai/lib/src/openai_plugin.dart +++ b/packages/genkit_openai/lib/src/openai_plugin.dart @@ -16,7 +16,9 @@ import 'package:genkit/plugin.dart'; import 'package:openai_dart/openai_dart.dart' as sdk; import '../genkit_openai.dart'; -import 'chat.dart' as chat; +import 'chat.dart' as chat_lib; +import 'tts.dart' as tts_lib; +import 'utils.dart'; /// Core plugin implementation class OpenAIPlugin extends GenkitPlugin { @@ -48,7 +50,7 @@ class OpenAIPlugin extends GenkitPlugin { Future> init() async { final actions = []; - // Fetch and register models from OpenAI API only for default OpenAI host. + // Fetch and register models from OpenAI API only for the default host. if (baseUrl == null) { try { final availableModelIds = await _fetchAvailableModels(); @@ -56,12 +58,12 @@ class OpenAIPlugin extends GenkitPlugin { for (final modelId in availableModelIds) { final modelType = getModelType(modelId); - if (modelType != 'chat' && modelType != 'unknown') { - continue; + if (modelType == 'chat' || modelType == 'unknown') { + actions.add(_createModel(modelId, modelInfoFor(modelId))); + } else if (modelType == 'audio' && + tts_lib.supportedTtsModels.contains(modelId)) { + actions.add(_createTtsModel(modelId)); } - - final info = modelInfoFor(modelId); - actions.add(_createModel(modelId, info)); } } catch (e) { throw GenkitException( @@ -71,7 +73,7 @@ class OpenAIPlugin extends GenkitPlugin { } } - // Register custom models + // Register custom models. for (final model in customModels) { actions.add(_createModel(model.name, model.info)); } @@ -79,34 +81,19 @@ class OpenAIPlugin extends GenkitPlugin { return actions; } - /// Fetch available model IDs from OpenAI API + /// Fetch available model IDs from OpenAI API. Future> _fetchAvailableModels() async { final resolvedConfig = await _resolveClientConfig(); - - final client = sdk.OpenAIClient( - config: sdk.OpenAIConfig( - authProvider: sdk.ApiKeyProvider(resolvedConfig.apiKey), - baseUrl: resolvedConfig.baseUrl ?? 'https://api.openai.com/v1', - defaultHeaders: resolvedConfig.headers ?? const {}, - ), - ); - + final client = buildOpenAIClient(resolvedConfig); try { final response = await client.models.list(); - final modelIds = []; - - // Collect all model IDs - for (final model in response.data) { - modelIds.add(model.id); - } - - return modelIds; + return response.data.map((m) => m.id).toList(); } finally { client.close(); } } - Future<_ResolvedClientConfig> _resolveClientConfig() async { + Future _resolveClientConfig() async { final configuredApiKey = await _resolveApiKey(); if (configuredApiKey == null || configuredApiKey.trim().isEmpty) { throw GenkitException( @@ -115,7 +102,7 @@ class OpenAIPlugin extends GenkitPlugin { ); } - return _ResolvedClientConfig( + return OpenAIClientConfig( apiKey: configuredApiKey.trim(), baseUrl: baseUrl, headers: headers, @@ -135,25 +122,36 @@ class OpenAIPlugin extends GenkitPlugin { list() async { try { final modelIds = await _fetchAvailableModels(); - final modelMetadataList = - >[]; + final metadata = >[]; for (final modelId in modelIds) { final modelType = getModelType(modelId); - if (modelType != 'chat' && modelType != 'unknown') { - continue; + + if (modelType == 'chat' || modelType == 'unknown') { + metadata.add( + modelMetadata( + 'openai/$modelId', + modelInfo: modelInfoFor(modelId), + customOptions: chat_lib.chatModelOptionsSchema(), + ), + ); } + } - modelMetadataList.add( - modelMetadata( - 'openai/$modelId', - modelInfo: modelInfoFor(modelId), - customOptions: chat.chatModelOptionsSchema(), - ), - ); + // Include known TTS models when using the default OpenAI endpoint. + if (baseUrl == null) { + for (final modelId in tts_lib.supportedTtsModels) { + metadata.add( + modelMetadata( + 'openai/$modelId', + modelInfo: tts_lib.ttsModelInfo(modelId), + customOptions: tts_lib.ttsOptionsSchema(), + ), + ); + } } - return modelMetadataList; + return metadata; } catch (e, stackTrace) { throw GenkitException( 'Error listing models from OpenAI: $e', @@ -166,40 +164,90 @@ class OpenAIPlugin extends GenkitPlugin { @override Action? resolve(String actionType, String name) { if (actionType == 'model') { + if (tts_lib.supportedTtsModels.contains(name)) { + return _createTtsModel(name); + } return _createModel(name, null); } return null; } + /// Creates a [Model] action that calls the OpenAI TTS endpoint. + Model _createTtsModel(String modelName) { + return Model( + name: 'openai/$modelName', + customOptions: tts_lib.ttsOptionsSchema(), + metadata: {'model': tts_lib.ttsModelInfo(modelName).toJson()}, + fn: (req, ctx) async { + final request = req!; + final options = tts_lib.parseTtsOptions(request.config); + final responseFormat = options.responseFormat ?? 'mp3'; + final modelVersion = options.version ?? modelName; + + if (request.messages.isEmpty) { + throw GenkitException( + 'TTS requires a text input message.', + status: StatusCodes.INVALID_ARGUMENT, + ); + } + final inputText = request.messages.first.text; + if (inputText.isEmpty) { + throw GenkitException( + 'TTS requires non-empty text input.', + status: StatusCodes.INVALID_ARGUMENT, + ); + } + + final resolvedConfig = await _resolveClientConfig(); + final client = buildOpenAIClient(resolvedConfig); + try { + final audioBytes = await client.audio.speech.create( + sdk.SpeechRequest( + model: modelVersion, + input: inputText, + voice: tts_lib.parseSpeechVoice(options.voice), + responseFormat: tts_lib.parseSpeechResponseFormat(responseFormat), + speed: options.speed, + ), + ); + return tts_lib.speechToModelResponse(audioBytes, responseFormat, { + 'model': modelVersion, + 'responseFormat': responseFormat, + }); + } catch (e, stackTrace) { + rethrowAsGenkitException(e, stackTrace, 'TTS'); + } finally { + client.close(); + } + }, + ); + } + + // ─── Chat model factory ──────────────────────────────────────────────────────── + Model _createModel(String modelName, ModelInfo? info) { final modelInfo = info ?? modelInfoFor(modelName); return Model( name: 'openai/$modelName', - customOptions: chat.chatModelOptionsSchema(), + customOptions: chat_lib.chatModelOptionsSchema(), metadata: {'model': modelInfo.toJson()}, fn: (req, ctx) async { final modelRequest = req!; - final options = chat.parseChatModelOptions(modelRequest.config); + final options = chat_lib.parseChatModelOptions(modelRequest.config); final resolvedConfig = await _resolveClientConfig(); - final client = sdk.OpenAIClient( - config: sdk.OpenAIConfig( - authProvider: sdk.ApiKeyProvider(resolvedConfig.apiKey), - baseUrl: resolvedConfig.baseUrl ?? 'https://api.openai.com/v1', - defaultHeaders: resolvedConfig.headers ?? const {}, - ), - ); + final client = buildOpenAIClient(resolvedConfig); try { final supports = modelInfo.supports; final supportsTools = supports?['tools'] == true; - final isJsonMode = chat.isJsonStructuredOutput( + final isJsonMode = chat_lib.isJsonStructuredOutput( modelRequest.output?.format, modelRequest.output?.contentType, ); - final responseFormat = chat.buildOpenAIResponseFormat( + final responseFormat = chat_lib.buildOpenAIResponseFormat( modelRequest.output?.schema, ); final request = sdk.ChatCompletionCreateRequest( @@ -253,6 +301,8 @@ class OpenAIPlugin extends GenkitPlugin { ); } + // ─── Streaming helpers ───────────────────────────────────────────────────────── + /// Handle streaming response Future _handleStreaming( sdk.OpenAIClient client, @@ -329,15 +379,3 @@ class OpenAIPlugin extends GenkitPlugin { ); } } - -final class _ResolvedClientConfig { - final String apiKey; - final String? baseUrl; - final Map? headers; - - const _ResolvedClientConfig({ - required this.apiKey, - required this.baseUrl, - required this.headers, - }); -} diff --git a/packages/genkit_openai/lib/src/tts.dart b/packages/genkit_openai/lib/src/tts.dart new file mode 100644 index 00000000..e6fadda6 --- /dev/null +++ b/packages/genkit_openai/lib/src/tts.dart @@ -0,0 +1,141 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/// TTS (text-to-speech) support. +/// +/// Speech synthesis is performed via [sdk.OpenAIClient.audio.speech.create]. +library; + +import 'dart:convert'; +import 'dart:typed_data'; + +import 'package:genkit/plugin.dart'; +import 'package:openai_dart/openai_dart.dart' as sdk; +import 'package:schemantic/schemantic.dart'; + +part 'tts.g.dart'; + +/// Options for OpenAI TTS models (`tts-1`, `tts-1-hd`, `gpt-4o-mini-tts`). +/// +/// Note: [speed] is ignored by `gpt-4o-mini-tts`, which does not support it. +@Schema() +abstract class $OpenAITtsOptions { + /// Model version override (e.g. `tts-1-hd`). + String? get version; + + /// Voice to use for speech synthesis. + @StringField( + enumValues: ['alloy', 'echo', 'fable', 'onyx', 'nova', 'shimmer'], + ) + String? get voice; + + /// Speaking speed multiplier (0.25 – 4.0). Not supported by gpt-4o-mini-tts. + @DoubleField(minimum: 0.25, maximum: 4.0) + double? get speed; + + /// Audio encoding format for the response. + @StringField(enumValues: ['mp3', 'opus', 'aac', 'flac', 'wav', 'pcm']) + String? get responseFormat; +} + +/// Maps OpenAI TTS response format strings to their MIME types. +const Map audioMimeTypes = { + 'mp3': 'audio/mpeg', + 'opus': 'audio/opus', + 'aac': 'audio/aac', + 'flac': 'audio/flac', + 'wav': 'audio/wav', + 'pcm': 'audio/L16', +}; + +/// TTS model IDs supported by the plugin. +const List supportedTtsModels = [ + 'tts-1', + 'tts-1-hd', + 'gpt-4o-mini-tts', +]; + +/// Returns a [ModelInfo] for a TTS (speech output) model. +ModelInfo ttsModelInfo(String label) => ModelInfo( + label: label, + supports: { + 'media': false, + 'output': ['media'], + 'multiturn': false, + 'systemRole': false, + 'tools': false, + }, +); + +/// Returns the [SchemanticType] for TTS model options. +SchemanticType ttsOptionsSchema() => OpenAITtsOptions.$schema; + +/// Parses TTS model options from an action config map. +OpenAITtsOptions parseTtsOptions(Map? config) { + return config != null + ? OpenAITtsOptions.$schema.parse(config) + : OpenAITtsOptions(); +} + +/// Returns a [ModelRef] for the named TTS model under the `openai` namespace. +ModelRef ttsModelRef(String name) => + modelRef( + 'openai/$name', + customOptions: OpenAITtsOptions.$schema, + ); + +/// Converts TTS audio [audioBytes] to a [ModelResponse] with a base64 data-URI +/// media part whose MIME type is derived from [responseFormat]. +ModelResponse speechToModelResponse( + Uint8List audioBytes, + String responseFormat, + Map raw, +) { + final mimeType = audioMimeTypes[responseFormat] ?? audioMimeTypes['mp3']!; + return ModelResponse( + finishReason: FinishReason.stop, + message: Message( + role: Role.model, + content: [ + MediaPart( + media: Media( + contentType: mimeType, + url: 'data:$mimeType;base64,${base64Encode(audioBytes)}', + ), + ), + ], + ), + raw: raw, + ); +} + +/// Parses [voice] into a [sdk.SpeechVoice], defaulting to `alloy` on failure. +sdk.SpeechVoice parseSpeechVoice(String? voice) { + if (voice == null) return sdk.SpeechVoice.alloy; + try { + return sdk.SpeechVoice.fromJson(voice); + } catch (_) { + return sdk.SpeechVoice.alloy; + } +} + +/// Parses [format] into a [sdk.SpeechResponseFormat], defaulting to `mp3`. +sdk.SpeechResponseFormat parseSpeechResponseFormat(String? format) { + if (format == null) return sdk.SpeechResponseFormat.mp3; + try { + return sdk.SpeechResponseFormat.fromJson(format); + } catch (_) { + return sdk.SpeechResponseFormat.mp3; + } +} diff --git a/packages/genkit_openai/lib/src/tts.g.dart b/packages/genkit_openai/lib/src/tts.g.dart new file mode 100644 index 00000000..3a51e1a8 --- /dev/null +++ b/packages/genkit_openai/lib/src/tts.g.dart @@ -0,0 +1,136 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// GENERATED CODE - DO NOT MODIFY BY HAND +// dart format width=80 + +part of 'tts.dart'; + +// ************************************************************************** +// SchemaGenerator +// ************************************************************************** + +base class OpenAITtsOptions { + factory OpenAITtsOptions.fromJson(Map json) => + $schema.parse(json); + + OpenAITtsOptions._(this._json); + + OpenAITtsOptions({ + String? version, + String? voice, + double? speed, + String? responseFormat, + }) { + _json = { + 'version': ?version, + 'voice': ?voice, + 'speed': ?speed, + 'responseFormat': ?responseFormat, + }; + } + + late final Map _json; + + static const SchemanticType $schema = + _OpenAITtsOptionsTypeFactory(); + + String? get version { + return _json['version'] as String?; + } + + set version(String? value) { + if (value == null) { + _json.remove('version'); + } else { + _json['version'] = value; + } + } + + String? get voice { + return _json['voice'] as String?; + } + + set voice(String? value) { + if (value == null) { + _json.remove('voice'); + } else { + _json['voice'] = value; + } + } + + double? get speed { + return (_json['speed'] as num?)?.toDouble(); + } + + set speed(double? value) { + if (value == null) { + _json.remove('speed'); + } else { + _json['speed'] = value; + } + } + + String? get responseFormat { + return _json['responseFormat'] as String?; + } + + set responseFormat(String? value) { + if (value == null) { + _json.remove('responseFormat'); + } else { + _json['responseFormat'] = value; + } + } + + @override + String toString() { + return _json.toString(); + } + + Map toJson() { + return _json; + } +} + +base class _OpenAITtsOptionsTypeFactory + extends SchemanticType { + const _OpenAITtsOptionsTypeFactory(); + + @override + OpenAITtsOptions parse(Object? json) { + return OpenAITtsOptions._(json as Map); + } + + @override + JsonSchemaMetadata get schemaMetadata => JsonSchemaMetadata( + name: 'OpenAITtsOptions', + definition: $Schema + .object( + properties: { + 'version': $Schema.string(), + 'voice': $Schema.string( + enumValues: ['alloy', 'echo', 'fable', 'onyx', 'nova', 'shimmer'], + ), + 'speed': $Schema.number(minimum: 0.25, maximum: 4.0), + 'responseFormat': $Schema.string( + enumValues: ['mp3', 'opus', 'aac', 'flac', 'wav', 'pcm'], + ), + }, + required: [], + ) + .value, + dependencies: [], + ); +} diff --git a/packages/genkit_openai/lib/src/utils.dart b/packages/genkit_openai/lib/src/utils.dart index 30fb3a65..54861a6b 100644 --- a/packages/genkit_openai/lib/src/utils.dart +++ b/packages/genkit_openai/lib/src/utils.dart @@ -13,6 +13,8 @@ // limitations under the License. import 'package:genkit/genkit.dart'; +import 'package:genkit/plugin.dart'; +import 'package:openai_dart/openai_dart.dart' as sdk; final RegExp _oSeriesPattern = RegExp(r'^o\d+(?:-|$)'); final RegExp _gptPattern = RegExp(r'^gpt-\d+(\.\d+)?o?(?:-|$)'); @@ -264,3 +266,54 @@ String getModelType(String modelId) { // Unknown model type. return 'unknown'; } + +final class OpenAIClientConfig { + final String apiKey; + final String? baseUrl; + final Map? headers; + + const OpenAIClientConfig({ + required this.apiKey, + required this.baseUrl, + required this.headers, + }); +} + +sdk.OpenAIClient buildOpenAIClient(OpenAIClientConfig config) { + return sdk.OpenAIClient( + config: sdk.OpenAIConfig( + authProvider: sdk.ApiKeyProvider(config.apiKey), + baseUrl: config.baseUrl ?? 'https://api.openai.com/v1', + defaultHeaders: config.headers ?? const {}, + ), + ); +} + +/// Converts any caught exception into a [GenkitException], re-throwing +/// [GenkitException]s unchanged. +Never rethrowAsGenkitException( + Object e, + StackTrace stackTrace, + String operation, +) { + if (e is GenkitException) Error.throwWithStackTrace(e, stackTrace); + + StatusCodes? status; + String? details; + + if (e is sdk.ApiException) { + status = StatusCodes.fromHttpStatus(e.statusCode); + details = e.body?.toString(); + } + + Error.throwWithStackTrace( + GenkitException( + 'OpenAI $operation error: $e', + status: status, + details: details ?? e.toString(), + underlyingException: e, + stackTrace: stackTrace, + ), + stackTrace, + ); +} diff --git a/packages/genkit_openai/test/openai_plugin_tts_test.dart b/packages/genkit_openai/test/openai_plugin_tts_test.dart new file mode 100644 index 00000000..dac82be9 --- /dev/null +++ b/packages/genkit_openai/test/openai_plugin_tts_test.dart @@ -0,0 +1,258 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import 'dart:convert'; +import 'dart:typed_data'; + +import 'package:genkit/plugin.dart'; +import 'package:genkit_openai/genkit_openai.dart'; +import 'package:genkit_openai/src/tts.dart' as tts; +import 'package:openai_dart/openai_dart.dart' as sdk; +import 'package:test/test.dart'; + +void main() { + group('supportedTtsModels', () { + test('contains all expected model IDs', () { + expect( + supportedTtsModels, + containsAll(['tts-1', 'tts-1-hd', 'gpt-4o-mini-tts']), + ); + }); + }); + + group('OpenAITtsOptions', () { + test('default constructor produces empty options', () { + final opts = OpenAITtsOptions(); + expect(opts.voice, isNull); + expect(opts.speed, isNull); + expect(opts.responseFormat, isNull); + expect(opts.version, isNull); + }); + + test('named constructor round-trips via toJson / fromJson', () { + final opts = OpenAITtsOptions( + voice: 'nova', + speed: 1.5, + responseFormat: 'wav', + version: 'tts-1-hd', + ); + final json = opts.toJson(); + final parsed = OpenAITtsOptions.fromJson(json); + + expect(parsed.voice, 'nova'); + expect(parsed.speed, 1.5); + expect(parsed.responseFormat, 'wav'); + expect(parsed.version, 'tts-1-hd'); + }); + + test('schema exposes voice enum values', () { + final props = + OpenAITtsOptions.$schema.jsonSchema()['properties'] + as Map; + final voiceProp = props['voice'] as Map; + expect( + voiceProp['enum'], + containsAll(['alloy', 'echo', 'fable', 'onyx', 'nova', 'shimmer']), + ); + }); + + test('schema exposes responseFormat enum values', () { + final props = + OpenAITtsOptions.$schema.jsonSchema()['properties'] + as Map; + final fmtProp = props['responseFormat'] as Map; + expect( + fmtProp['enum'], + containsAll(['mp3', 'opus', 'aac', 'flac', 'wav', 'pcm']), + ); + }); + + test('schema exposes speed with min/max constraints', () { + final props = + OpenAITtsOptions.$schema.jsonSchema()['properties'] + as Map; + final speedProp = props['speed'] as Map; + expect(speedProp['minimum'], 0.25); + expect(speedProp['maximum'], 4.0); + }); + }); + + group('parseTtsOptions', () { + test('parses a populated config map', () { + final opts = parseTtsOptions({ + 'voice': 'shimmer', + 'speed': 0.8, + 'responseFormat': 'opus', + 'version': 'tts-1', + }); + expect(opts.voice, 'shimmer'); + expect(opts.speed, 0.8); + expect(opts.responseFormat, 'opus'); + expect(opts.version, 'tts-1'); + }); + + test('returns defaults for null config', () { + final opts = parseTtsOptions(null); + expect(opts.voice, isNull); + expect(opts.speed, isNull); + }); + }); + + group('ttsModelInfo', () { + test('reports media output and no media input', () { + final info = ttsModelInfo('tts-1'); + expect(info.label, 'tts-1'); + expect(info.supports?['output'], contains('media')); + expect(info.supports?['media'], isFalse); + expect(info.supports?['multiturn'], isFalse); + expect(info.supports?['tools'], isFalse); + }); + }); + + group('ttsModelRef', () { + test('produces a ref under the openai namespace', () { + expect(ttsModelRef('tts-1').name, 'openai/tts-1'); + }); + + test('customOptions is the OpenAITtsOptions schema', () { + final ref = ttsModelRef('tts-1'); + expect(ref.customOptions, same(OpenAITtsOptions.$schema)); + }); + }); + + group('openAI handle', () { + test('speech() returns a TTS ModelRef', () { + final ref = openAI.speech('tts-1'); + expect(ref.name, 'openai/tts-1'); + }); + + test('speech() ModelRef uses OpenAITtsOptions schema', () { + final ref = openAI.speech('tts-1-hd'); + expect(ref.customOptions, same(OpenAITtsOptions.$schema)); + }); + }); + + group('parseSpeechVoice', () { + test('parses all valid OpenAI voices', () { + for (final voice in [ + 'alloy', + 'echo', + 'fable', + 'onyx', + 'nova', + 'shimmer', + ]) { + expect( + tts.parseSpeechVoice(voice).toJson(), + voice, + reason: 'failed for voice: $voice', + ); + } + }); + + test('defaults to alloy for null', () { + expect(tts.parseSpeechVoice(null), sdk.SpeechVoice.alloy); + }); + + test('defaults to alloy for an unrecognised voice', () { + expect(tts.parseSpeechVoice('robot-voice'), sdk.SpeechVoice.alloy); + }); + }); + + group('parseSpeechResponseFormat', () { + test('parses all valid audio formats', () { + for (final fmt in ['mp3', 'opus', 'aac', 'flac', 'wav', 'pcm']) { + expect( + tts.parseSpeechResponseFormat(fmt).toJson(), + fmt, + reason: 'failed for format: $fmt', + ); + } + }); + + test('defaults to mp3 for null', () { + expect(tts.parseSpeechResponseFormat(null), sdk.SpeechResponseFormat.mp3); + }); + + test('defaults to mp3 for an unrecognised format', () { + expect( + tts.parseSpeechResponseFormat('xyz'), + sdk.SpeechResponseFormat.mp3, + ); + }); + }); + + group('speechToModelResponse', () { + final fakeAudio = Uint8List.fromList([0xDE, 0xAD, 0xBE, 0xEF]); + + test('finish reason is stop', () { + final r = tts.speechToModelResponse(fakeAudio, 'mp3', {}); + expect(r.finishReason, FinishReason.stop); + }); + + test('message has model role with one MediaPart', () { + final r = tts.speechToModelResponse(fakeAudio, 'mp3', {}); + final msg = r.message!; + expect(msg.role, Role.model); + expect(msg.content.length, 1); + expect(msg.content.first.isMedia, isTrue); + }); + + test('mp3 uses audio/mpeg MIME type', () { + final r = tts.speechToModelResponse(fakeAudio, 'mp3', {}); + expect(r.message!.media!.contentType, 'audio/mpeg'); + expect(r.message!.media!.url, startsWith('data:audio/mpeg;base64,')); + }); + + test('wav uses audio/wav MIME type', () { + final r = tts.speechToModelResponse(fakeAudio, 'wav', {}); + expect(r.message!.media!.contentType, 'audio/wav'); + }); + + test('opus uses audio/opus MIME type', () { + final r = tts.speechToModelResponse(fakeAudio, 'opus', {}); + expect(r.message!.media!.contentType, 'audio/opus'); + }); + + test('base64 payload round-trips correctly', () { + final r = tts.speechToModelResponse(fakeAudio, 'mp3', {}); + final url = r.message!.media!.url; + final decoded = base64Decode(url.substring(url.indexOf(',') + 1)); + expect(decoded, equals(fakeAudio)); + }); + + test('unknown format falls back to audio/mpeg', () { + final r = tts.speechToModelResponse(fakeAudio, 'unknown', {}); + expect(r.message!.media!.contentType, 'audio/mpeg'); + }); + + test('raw map is passed through unchanged', () { + final raw = {'model': 'tts-1', 'responseFormat': 'mp3'}; + final r = tts.speechToModelResponse(fakeAudio, 'mp3', raw); + expect(r.raw, raw); + }); + }); + + group('audioMimeTypes', () { + test('covers all TTS response formats', () { + for (final fmt in ['mp3', 'opus', 'aac', 'flac', 'wav', 'pcm']) { + expect( + audioMimeTypes.containsKey(fmt), + isTrue, + reason: 'missing entry for $fmt', + ); + } + }); + }); +} diff --git a/testapps/openai_sample/lib/text_to_speech.dart b/testapps/openai_sample/lib/text_to_speech.dart new file mode 100644 index 00000000..7ecc7e29 --- /dev/null +++ b/testapps/openai_sample/lib/text_to_speech.dart @@ -0,0 +1,129 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import 'dart:io'; + +import 'package:genkit/genkit.dart'; +import 'package:genkit_openai/genkit_openai.dart'; + +/// Defines a flow using the defaults +Flow defineDefaultTextToSpeechFlow(Genkit ai) { + return ai.defineFlow( + name: 'defaultTextToSpeech', + inputSchema: .string( + defaultValue: 'Genkit Dart supports OpenAI text to speech.', + ), + outputSchema: Media.$schema, + fn: (prompt, _) async { + final response = await ai.generate( + model: openAI.speech('gpt-4o-mini-tts'), + prompt: prompt, + config: OpenAITtsOptions(), + ); + + final media = response.media; + if (media == null) throw StateError('Model returned no audio media.'); + return media; + }, + ); +} + +/// Demonstrates the [OpenAITtsOptions] fields on the `tts-1-hd` model. +Flow defineCustomTextToSpeechFlow(Genkit ai) { + return ai.defineFlow( + name: 'customTextToSpeech', + inputSchema: .string( + defaultValue: 'Genkit Dart supports high-quality speech synthesis.', + ), + outputSchema: Media.$schema, + fn: (prompt, _) async { + final response = await ai.generate( + model: openAI.speech('tts-1-hd'), + prompt: prompt, + config: OpenAITtsOptions( + voice: 'nova', + speed: 1.25, + responseFormat: 'wav', + ), + ); + + final media = response.media; + if (media == null) throw StateError('Model returned no audio media.'); + return media; + }, + ); +} + +/// Shows supported voices on the `tts-1` model by defining one flow per voice. +List> defineVoiceShowcaseFlows(Genkit ai) { + const voices = ['alloy', 'echo', 'fable', 'onyx', 'nova', 'shimmer']; + + return [ + for (final voice in voices) + ai.defineFlow( + name: 'tts_$voice', + inputSchema: .string(defaultValue: 'Hello, I am the $voice voice.'), + outputSchema: Media.$schema, + fn: (prompt, _) async { + final response = await ai.generate( + model: openAI.speech('tts-1'), + prompt: prompt, + config: OpenAITtsOptions(voice: voice, responseFormat: 'mp3'), + ); + + final media = response.media; + if (media == null) throw StateError('Model returned no audio media.'); + return media; + }, + ), + ]; +} + +/// Shows supported `responseFormat` values on the `tts-1` model. +List> defineFormatShowcaseFlows(Genkit ai) { + const formats = ['mp3', 'opus', 'aac', 'flac', 'wav', 'pcm']; + + return [ + for (final format in formats) + ai.defineFlow( + name: 'tts_format_$format', + inputSchema: .string( + defaultValue: 'This sample is encoded as $format.', + ), + outputSchema: Media.$schema, + fn: (prompt, _) async { + final response = await ai.generate( + model: openAI.speech('tts-1'), + prompt: prompt, + config: OpenAITtsOptions(voice: 'alloy', responseFormat: format), + ); + + final media = response.media; + if (media == null) throw StateError('Model returned no audio media.'); + return media; + }, + ), + ]; +} + +void main() { + final ai = Genkit( + plugins: [openAI(apiKey: Platform.environment['OPENAI_API_KEY'])], + ); + + defineDefaultTextToSpeechFlow(ai); + defineCustomTextToSpeechFlow(ai); + defineVoiceShowcaseFlows(ai); + defineFormatShowcaseFlows(ai); +}