diff --git a/packages/genkit_openai/lib/genkit_openai.dart b/packages/genkit_openai/lib/genkit_openai.dart index 1c33cb47..3efb20be 100644 --- a/packages/genkit_openai/lib/genkit_openai.dart +++ b/packages/genkit_openai/lib/genkit_openai.dart @@ -16,11 +16,23 @@ import 'dart:async'; import 'package:genkit/plugin.dart'; -import 'src/chat.dart' as chat; +import 'src/chat.dart' as chat_lib; import 'src/openai_plugin.dart'; +import 'src/stt.dart' as stt_lib; export 'src/chat.dart' show OpenAIChatOptions, OpenAIOptions; export 'src/converters.dart' show GenkitConverter; +export 'src/stt.dart' + show + OpenAISttOptions, + buildTranscriptionRequest, + buildTranslationRequest, + parseSttModelOptions, + sttModelInfo, + sttModelOptionsSchema, + transcriptionModelIds, + transcriptionToModelResponse, + whisperModelIds; export 'src/utils.dart' show defaultModelInfo, @@ -65,11 +77,26 @@ class OpenAICompatPluginHandle { ); } - /// Reference to a model - ModelRef model(String name) { + /// Reference to a chat model. + ModelRef model(String name) { + return modelRef( + 'openai/$name', + customOptions: chat_lib.chatModelOptionsSchema(), + ); + } + + /// Reference to a speech-to-text model. + /// + /// Works with both Whisper models (e.g. `'whisper-1'`) and GPT transcription + /// models (e.g. `'gpt-4o-transcribe'`, `'gpt-4o-mini-transcribe'`). + ModelRef stt( + String name, { + stt_lib.OpenAISttOptions? 
config, + }) { return modelRef( 'openai/$name', - customOptions: chat.chatModelOptionsSchema(), + customOptions: stt_lib.sttModelOptionsSchema(), + config: config, ); } } diff --git a/packages/genkit_openai/lib/src/openai_plugin.dart b/packages/genkit_openai/lib/src/openai_plugin.dart index d95c44cf..ac309710 100644 --- a/packages/genkit_openai/lib/src/openai_plugin.dart +++ b/packages/genkit_openai/lib/src/openai_plugin.dart @@ -16,7 +16,9 @@ import 'package:genkit/plugin.dart'; import 'package:openai_dart/openai_dart.dart' as sdk; import '../genkit_openai.dart'; -import 'chat.dart' as chat; +import 'chat.dart' as chat_lib; +import 'stt.dart' as stt_lib; +import 'utils.dart'; /// Core plugin implementation class OpenAIPlugin extends GenkitPlugin { @@ -56,12 +58,14 @@ class OpenAIPlugin extends GenkitPlugin { for (final modelId in availableModelIds) { final modelType = getModelType(modelId); - if (modelType != 'chat' && modelType != 'unknown') { - continue; + if (modelType == 'chat' || modelType == 'unknown') { + final info = modelInfoFor(modelId); + actions.add(_createModel(modelId, info)); + } else if (modelType == 'audio') { + if (_isSttModelId(modelId)) { + actions.add(_createSttModel(modelId)); + } } - - final info = modelInfoFor(modelId); - actions.add(_createModel(modelId, info)); } } catch (e) { throw GenkitException( @@ -83,13 +87,7 @@ class OpenAIPlugin extends GenkitPlugin { Future> _fetchAvailableModels() async { final resolvedConfig = await _resolveClientConfig(); - final client = sdk.OpenAIClient( - config: sdk.OpenAIConfig( - authProvider: sdk.ApiKeyProvider(resolvedConfig.apiKey), - baseUrl: resolvedConfig.baseUrl ?? 'https://api.openai.com/v1', - defaultHeaders: resolvedConfig.headers ?? 
const {}, - ), - ); + final client = buildOpenAIClient(resolvedConfig); try { final response = await client.models.list(); @@ -106,7 +104,7 @@ class OpenAIPlugin extends GenkitPlugin { } } - Future<_ResolvedClientConfig> _resolveClientConfig() async { + Future _resolveClientConfig() async { final configuredApiKey = await _resolveApiKey(); if (configuredApiKey == null || configuredApiKey.trim().isEmpty) { throw GenkitException( @@ -115,7 +113,7 @@ class OpenAIPlugin extends GenkitPlugin { ); } - return _ResolvedClientConfig( + return OpenAIClientConfig( apiKey: configuredApiKey.trim(), baseUrl: baseUrl, headers: headers, @@ -135,25 +133,32 @@ class OpenAIPlugin extends GenkitPlugin { list() async { try { final modelIds = await _fetchAvailableModels(); - final modelMetadataList = + final metadataList = >[]; for (final modelId in modelIds) { final modelType = getModelType(modelId); - if (modelType != 'chat' && modelType != 'unknown') { - continue; - } - modelMetadataList.add( - modelMetadata( - 'openai/$modelId', - modelInfo: modelInfoFor(modelId), - customOptions: chat.chatModelOptionsSchema(), - ), - ); + if (modelType == 'chat' || modelType == 'unknown') { + metadataList.add( + modelMetadata( + 'openai/$modelId', + modelInfo: modelInfoFor(modelId), + customOptions: chat_lib.chatModelOptionsSchema(), + ), + ); + } else if (modelType == 'audio' && _isSttModelId(modelId)) { + metadataList.add( + modelMetadata( + 'openai/$modelId', + modelInfo: stt_lib.sttModelInfo, + customOptions: stt_lib.sttModelOptionsSchema(), + ), + ); + } } - return modelMetadataList; + return metadataList; } catch (e, stackTrace) { throw GenkitException( 'Error listing models from OpenAI: $e', @@ -165,10 +170,81 @@ class OpenAIPlugin extends GenkitPlugin { @override Action? 
resolve(String actionType, String name) { - if (actionType == 'model') { - return _createModel(name, null); - } - return null; + if (actionType != 'model') return null; + if (_isSttModelId(name)) return _createSttModel(name); + return _createModel(name, null); + } + + static bool _isSttModelId(String modelId) { + final id = modelId.toLowerCase(); + return id.contains('whisper') || id.contains('transcribe'); + } + + /// Creates a Genkit [Model] that calls the OpenAI speech-to-text API. + /// + /// Whisper models support an additional `translate: true` option that routes + /// the request through the translation endpoint instead of transcriptions. + Model _createSttModel(String modelName) { + return Model( + name: 'openai/$modelName', + customOptions: stt_lib.sttModelOptionsSchema(), + metadata: {'model': stt_lib.sttModelInfo.toJson()}, + fn: (req, ctx) async { + final modelRequest = req!; + final options = stt_lib.parseSttModelOptions(modelRequest.config); + + final resolvedConfig = await _resolveClientConfig(); + final client = buildOpenAIClient(resolvedConfig); + + try { + final shouldTranslate = + options.translate == true && modelName.contains('whisper'); + + if (shouldTranslate) { + final translationReq = stt_lib.buildTranslationRequest( + modelId: modelName, + request: modelRequest, + options: options, + ); + final response = await client.audio.translations.create( + translationReq, + ); + return stt_lib.transcriptionToModelResponse( + response.text, + raw: response.toJson(), + ); + } + + final transcriptionReq = stt_lib.buildTranscriptionRequest( + modelId: modelName, + request: modelRequest, + options: options, + ); + + if (options.responseFormat == 'verbose_json') { + final response = await client.audio.transcriptions.createVerbose( + transcriptionReq, + ); + return stt_lib.transcriptionToModelResponse( + response.text, + raw: response.toJson(), + ); + } + + final response = await client.audio.transcriptions.create( + transcriptionReq, + ); + return 
stt_lib.transcriptionToModelResponse( + response.text, + raw: response.toJson(), + ); + } catch (e, stackTrace) { + rethrowAsGenkitException(e, stackTrace, 'STT'); + } finally { + client.close(); + } + }, + ); } Model _createModel(String modelName, ModelInfo? info) { @@ -176,30 +252,24 @@ class OpenAIPlugin extends GenkitPlugin { return Model( name: 'openai/$modelName', - customOptions: chat.chatModelOptionsSchema(), + customOptions: chat_lib.chatModelOptionsSchema(), metadata: {'model': modelInfo.toJson()}, fn: (req, ctx) async { final modelRequest = req!; - final options = chat.parseChatModelOptions(modelRequest.config); + final options = chat_lib.parseChatModelOptions(modelRequest.config); final resolvedConfig = await _resolveClientConfig(); - final client = sdk.OpenAIClient( - config: sdk.OpenAIConfig( - authProvider: sdk.ApiKeyProvider(resolvedConfig.apiKey), - baseUrl: resolvedConfig.baseUrl ?? 'https://api.openai.com/v1', - defaultHeaders: resolvedConfig.headers ?? const {}, - ), - ); + final client = buildOpenAIClient(resolvedConfig); try { final supports = modelInfo.supports; final supportsTools = supports?['tools'] == true; - final isJsonMode = chat.isJsonStructuredOutput( + final isJsonMode = chat_lib.isJsonStructuredOutput( modelRequest.output?.format, modelRequest.output?.contentType, ); - final responseFormat = chat.buildOpenAIResponseFormat( + final responseFormat = chat_lib.buildOpenAIResponseFormat( modelRequest.output?.schema, ); final request = sdk.ChatCompletionCreateRequest( @@ -227,25 +297,7 @@ class OpenAIPlugin extends GenkitPlugin { return await _handleNonStreaming(client, request); } } catch (e, stackTrace) { - if (e is GenkitException) { - rethrow; - } - - StatusCodes? status; - String? details; - - if (e is sdk.ApiException) { - status = StatusCodes.fromHttpStatus(e.statusCode); - details = e.body?.toString(); - } - - throw GenkitException( - 'OpenAI API error: $e', - status: status, - details: details ?? 
e.toString(), - underlyingException: e, - stackTrace: stackTrace, - ); + rethrowAsGenkitException(e, stackTrace, 'Chat'); } finally { client.close(); } @@ -329,15 +381,3 @@ class OpenAIPlugin extends GenkitPlugin { ); } } - -final class _ResolvedClientConfig { - final String apiKey; - final String? baseUrl; - final Map? headers; - - const _ResolvedClientConfig({ - required this.apiKey, - required this.baseUrl, - required this.headers, - }); -} diff --git a/packages/genkit_openai/lib/src/stt.dart b/packages/genkit_openai/lib/src/stt.dart new file mode 100644 index 00000000..fd723fb6 --- /dev/null +++ b/packages/genkit_openai/lib/src/stt.dart @@ -0,0 +1,248 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import 'dart:convert'; +import 'dart:typed_data'; + +import 'package:genkit/plugin.dart'; +import 'package:openai_dart/openai_dart.dart' as sdk; +import 'package:schemantic/schemantic.dart'; + +part 'stt.g.dart'; + +/// Model info for Whisper and transcription models. +final ModelInfo sttModelInfo = ModelInfo( + label: 'OpenAI STT', + supports: { + 'media': true, + 'output': ['text', 'json'], + 'multiturn': false, + 'systemRole': false, + 'tools': false, + }, +); + +/// Known whisper models. +const List whisperModelIds = ['whisper-1']; + +/// Known GPT transcription models. 
+const List transcriptionModelIds = [ + 'gpt-4o-transcribe', + 'gpt-4o-mini-transcribe', +]; + +/// Options for OpenAI speech-to-text (transcription / Whisper) models. +@Schema() +abstract class $OpenAISttOptions { + /// Model version override (e.g. 'whisper-1'). + String? get version; + + /// Sampling temperature (0.0 – 1.0). + @DoubleField(minimum: 0.0, maximum: 1.0) + double? get temperature; + + /// BCP-47 language code of the audio (e.g. 'en', 'fr'). + /// + /// When provided, transcription accuracy improves. + String? get language; + + /// Output format for the transcription result. + /// + /// One of: 'json', 'text', 'srt', 'verbose_json', 'vtt'. + @StringField(enumValues: ['json', 'text', 'srt', 'verbose_json', 'vtt']) + String? get responseFormat; + + /// Timestamp granularities to include (requires 'verbose_json' format). + /// + /// Each value must be one of: 'word', 'segment'. + List? get timestampGranularities; + + /// When true, translates audio to English instead of transcribing in-language. + /// + /// Only supported by Whisper models; ignored for gpt-4o-transcribe variants. + bool? get translate; +} + +/// Returns the [SchemanticType] for [OpenAISttOptions]. +SchemanticType sttModelOptionsSchema() => + OpenAISttOptions.$schema; + +/// Parses STT model options from action config. +OpenAISttOptions parseSttModelOptions(Map? config) { + return config != null + ? OpenAISttOptions.$schema.parse(config) + : OpenAISttOptions(); +} + +/// Builds an SDK [sdk.TranscriptionRequest] from a Genkit [ModelRequest]. 
+sdk.TranscriptionRequest buildTranscriptionRequest({ + required String modelId, + required ModelRequest request, + required OpenAISttOptions options, +}) { + final audioFile = _extractAudioFile(request); + final granularities = _parseGranularities(options.timestampGranularities); + final format = _parseTranscriptionFormat(options.responseFormat); + return sdk.TranscriptionRequest( + file: audioFile.bytes, + filename: audioFile.filename, + model: options.version ?? modelId, + language: options.language, + prompt: _extractPromptText(request), + responseFormat: format, + temperature: options.temperature, + timestampGranularities: granularities.isNotEmpty ? granularities : null, + ); +} + +/// Builds an SDK [sdk.TranslationRequest] from a Genkit [ModelRequest]. +/// +/// Used when `translate: true` is set on a Whisper model. +sdk.TranslationRequest buildTranslationRequest({ + required String modelId, + required ModelRequest request, + required OpenAISttOptions options, +}) { + final audioFile = _extractAudioFile(request); + final format = _parseTranscriptionFormat(options.responseFormat); + return sdk.TranslationRequest( + file: audioFile.bytes, + filename: audioFile.filename, + model: options.version ?? modelId, + prompt: _extractPromptText(request), + responseFormat: format, + temperature: options.temperature, + ); +} + +/// Converts a transcription/translation text result to a [ModelResponse]. +ModelResponse transcriptionToModelResponse( + String text, { + Map? raw, +}) { + return ModelResponse( + finishReason: FinishReason.stop, + message: Message( + role: Role.model, + content: [TextPart(text: text)], + ), + raw: raw ?? 
{'text': text}, + ); +} + +// --------------------------------------------------------------------------- +// Private helpers +// --------------------------------------------------------------------------- + +class _AudioFile { + const _AudioFile({required this.bytes, required this.filename}); + final Uint8List bytes; + final String filename; +} + +_AudioFile _extractAudioFile(ModelRequest request) { + if (request.messages.isEmpty) { + throw GenkitException( + 'STT request must contain at least one message with a media part.', + status: StatusCodes.INVALID_ARGUMENT, + ); + } + + for (final message in request.messages) { + for (final part in message.content) { + if (part.isMedia) { + final media = part.media; + if (media != null) { + return _mediaToFile(media); + } + } + } + } + + throw GenkitException( + 'No audio media part found in the request messages.', + status: StatusCodes.INVALID_ARGUMENT, + ); +} + +_AudioFile _mediaToFile(Media media) { + final contentType = media.contentType ?? _contentTypeFromDataUrl(media.url); + if (contentType == null || contentType.isEmpty) { + throw GenkitException( + 'Media part is missing a content type.', + status: StatusCodes.INVALID_ARGUMENT, + ); + } + final ext = _extensionFromContentType(contentType); + final bytes = _bytesFromDataUrl(media.url); + return _AudioFile(filename: 'input.$ext', bytes: Uint8List.fromList(bytes)); +} + +String? _contentTypeFromDataUrl(String url) { + if (!url.startsWith('data:')) return null; + final semi = url.indexOf(';'); + if (semi <= 'data:'.length) return null; + return url.substring('data:'.length, semi); +} + +List _bytesFromDataUrl(String dataUrl) { + final marker = dataUrl.indexOf(','); + final body = marker >= 0 ? 
dataUrl.substring(marker + 1) : dataUrl; + return base64Decode(body); +} + +String _extensionFromContentType(String contentType) { + return switch (contentType) { + 'audio/mpeg' || 'audio/mp3' => 'mp3', + 'audio/mp4' => 'mp4', + 'audio/ogg' => 'ogg', + 'audio/wav' || 'audio/x-wav' => 'wav', + 'audio/webm' => 'webm', + 'audio/flac' => 'flac', + 'audio/m4a' || 'audio/x-m4a' => 'm4a', + _ => 'mp3', + }; +} + +String? _extractPromptText(ModelRequest request) { + for (final message in request.messages) { + for (final part in message.content) { + final text = part.text; + if (text != null && text.isNotEmpty) return text; + } + } + return null; +} + +sdk.TranscriptionResponseFormat? _parseTranscriptionFormat(String? value) { + if (value == null) return null; + try { + return sdk.TranscriptionResponseFormat.fromJson(value); + } catch (_) { + return null; + } +} + +List _parseGranularities(List? values) { + if (values == null || values.isEmpty) return const []; + final result = []; + for (final v in values) { + try { + result.add(sdk.TimestampGranularity.fromJson(v)); + } catch (_) { + // Skip unknown values. + } + } + return result; +} diff --git a/packages/genkit_openai/lib/src/stt.g.dart b/packages/genkit_openai/lib/src/stt.g.dart new file mode 100644 index 00000000..55e57a9f --- /dev/null +++ b/packages/genkit_openai/lib/src/stt.g.dart @@ -0,0 +1,164 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// GENERATED CODE - DO NOT MODIFY BY HAND +// dart format width=80 + +part of 'stt.dart'; + +// ************************************************************************** +// SchemaGenerator +// ************************************************************************** + +base class OpenAISttOptions { + factory OpenAISttOptions.fromJson(Map json) => + $schema.parse(json); + + OpenAISttOptions._(this._json); + + OpenAISttOptions({ + String? version, + double? temperature, + String? language, + String? responseFormat, + List? timestampGranularities, + bool? translate, + }) { + _json = { + 'version': ?version, + 'temperature': ?temperature, + 'language': ?language, + 'responseFormat': ?responseFormat, + 'timestampGranularities': ?timestampGranularities, + 'translate': ?translate, + }; + } + + late final Map _json; + + static const SchemanticType $schema = + _OpenAISttOptionsTypeFactory(); + + String? get version { + return _json['version'] as String?; + } + + set version(String? value) { + if (value == null) { + _json.remove('version'); + } else { + _json['version'] = value; + } + } + + double? get temperature { + return (_json['temperature'] as num?)?.toDouble(); + } + + set temperature(double? value) { + if (value == null) { + _json.remove('temperature'); + } else { + _json['temperature'] = value; + } + } + + String? get language { + return _json['language'] as String?; + } + + set language(String? value) { + if (value == null) { + _json.remove('language'); + } else { + _json['language'] = value; + } + } + + String? get responseFormat { + return _json['responseFormat'] as String?; + } + + set responseFormat(String? value) { + if (value == null) { + _json.remove('responseFormat'); + } else { + _json['responseFormat'] = value; + } + } + + List? get timestampGranularities { + return (_json['timestampGranularities'] as List?)?.cast(); + } + + set timestampGranularities(List? 
value) { + if (value == null) { + _json.remove('timestampGranularities'); + } else { + _json['timestampGranularities'] = value; + } + } + + bool? get translate { + return _json['translate'] as bool?; + } + + set translate(bool? value) { + if (value == null) { + _json.remove('translate'); + } else { + _json['translate'] = value; + } + } + + @override + String toString() { + return _json.toString(); + } + + Map toJson() { + return _json; + } +} + +base class _OpenAISttOptionsTypeFactory + extends SchemanticType { + const _OpenAISttOptionsTypeFactory(); + + @override + OpenAISttOptions parse(Object? json) { + return OpenAISttOptions._(json as Map); + } + + @override + JsonSchemaMetadata get schemaMetadata => JsonSchemaMetadata( + name: 'OpenAISttOptions', + definition: $Schema + .object( + properties: { + 'version': $Schema.string(), + 'temperature': $Schema.number(minimum: 0.0, maximum: 1.0), + 'language': $Schema.string(), + 'responseFormat': $Schema.string( + enumValues: ['json', 'text', 'srt', 'verbose_json', 'vtt'], + ), + 'timestampGranularities': $Schema.list(items: $Schema.string()), + 'translate': $Schema.boolean(), + }, + required: [], + ) + .value, + dependencies: [], + ); +} diff --git a/packages/genkit_openai/lib/src/utils.dart b/packages/genkit_openai/lib/src/utils.dart index 30fb3a65..b5978c21 100644 --- a/packages/genkit_openai/lib/src/utils.dart +++ b/packages/genkit_openai/lib/src/utils.dart @@ -13,6 +13,7 @@ // limitations under the License. import 'package:genkit/genkit.dart'; +import 'package:openai_dart/openai_dart.dart' as sdk; final RegExp _oSeriesPattern = RegExp(r'^o\d+(?:-|$)'); final RegExp _gptPattern = RegExp(r'^gpt-\d+(\.\d+)?o?(?:-|$)'); @@ -264,3 +265,53 @@ String getModelType(String modelId) { // Unknown model type. return 'unknown'; } + +/// Resolved OpenAI client config used to construct SDK clients consistently. +final class OpenAIClientConfig { + final String apiKey; + final String? baseUrl; + final Map? 
headers; + + const OpenAIClientConfig({ + required this.apiKey, + required this.baseUrl, + required this.headers, + }); +} + +/// Builds an OpenAI SDK client from a resolved [OpenAIClientConfig]. +sdk.OpenAIClient buildOpenAIClient(OpenAIClientConfig config) { + return sdk.OpenAIClient( + config: sdk.OpenAIConfig( + authProvider: sdk.ApiKeyProvider(config.apiKey), + baseUrl: config.baseUrl ?? 'https://api.openai.com/v1', + defaultHeaders: config.headers ?? const {}, + ), + ); +} + +/// Re-throws unknown SDK errors as a normalized [GenkitException]. +Never rethrowAsGenkitException( + Object error, + StackTrace stackTrace, + String surface, +) { + if (error is GenkitException) { + throw error; + } + + StatusCodes? status; + String? details; + if (error is sdk.ApiException) { + status = StatusCodes.fromHttpStatus(error.statusCode); + details = error.body?.toString(); + } + + throw GenkitException( + 'OpenAI $surface API error: $error', + status: status, + details: details ?? error.toString(), + underlyingException: error, + stackTrace: stackTrace, + ); +} diff --git a/packages/genkit_openai/test/openai_plugin_stt_test.dart b/packages/genkit_openai/test/openai_plugin_stt_test.dart new file mode 100644 index 00000000..e77cd1f7 --- /dev/null +++ b/packages/genkit_openai/test/openai_plugin_stt_test.dart @@ -0,0 +1,291 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +import 'dart:convert'; + +import 'package:genkit/genkit.dart'; +import 'package:genkit_openai/genkit_openai.dart'; +import 'package:openai_dart/openai_dart.dart' as sdk; +import 'package:test/test.dart'; + +void main() { + group('OpenAISttOptions', () { + test('creates default options with all nulls', () { + final opts = OpenAISttOptions(); + expect(opts.version, isNull); + expect(opts.temperature, isNull); + expect(opts.language, isNull); + expect(opts.responseFormat, isNull); + expect(opts.timestampGranularities, isNull); + expect(opts.translate, isNull); + }); + + test('parses temperature from config', () { + final opts = parseSttModelOptions({'temperature': 0.5}); + expect(opts.temperature, 0.5); + }); + + test('parses language from config', () { + final opts = parseSttModelOptions({'language': 'fr'}); + expect(opts.language, 'fr'); + }); + + test('parses responseFormat from config', () { + final opts = parseSttModelOptions({'responseFormat': 'verbose_json'}); + expect(opts.responseFormat, 'verbose_json'); + }); + + test('parses timestampGranularities from config', () { + final opts = parseSttModelOptions({ + 'timestampGranularities': ['word', 'segment'], + }); + expect(opts.timestampGranularities, ['word', 'segment']); + }); + + test('parses translate flag from config', () { + final opts = parseSttModelOptions({'translate': true}); + expect(opts.translate, isTrue); + }); + + test('returns default options for null config', () { + final opts = parseSttModelOptions(null); + expect(opts.temperature, isNull); + expect(opts.translate, isNull); + }); + }); + + group('buildTranscriptionRequest', () { + test('builds request from model request with media part', () { + final wavBytes = _fakeAudioBase64('audio/wav'); + final request = ModelRequest( + messages: [ + Message( + role: Role.user, + content: [ + MediaPart( + media: Media( + url: 'data:audio/wav;base64,$wavBytes', + contentType: 'audio/wav', + ), + ), + ], + ), + ], + ); + final opts = 
OpenAISttOptions(temperature: 0.3, language: 'en'); + + final transcriptionReq = buildTranscriptionRequest( + modelId: 'whisper-1', + request: request, + options: opts, + ); + + expect(transcriptionReq.model, 'whisper-1'); + expect(transcriptionReq.filename, 'input.wav'); + expect(transcriptionReq.temperature, 0.3); + expect(transcriptionReq.language, 'en'); + expect(transcriptionReq.file, isNotEmpty); + }); + + test('uses version override when provided', () { + final request = _minimalAudioRequest(); + final opts = OpenAISttOptions(version: 'whisper-1'); + + final req = buildTranscriptionRequest( + modelId: 'whisper-1', + request: request, + options: opts, + ); + + expect(req.model, 'whisper-1'); + }); + + test('includes prompt text from message', () { + final mp3Bytes = _fakeAudioBase64('audio/mpeg'); + final request = ModelRequest( + messages: [ + Message( + role: Role.user, + content: [ + TextPart(text: 'please transcribe carefully'), + MediaPart( + media: Media( + url: 'data:audio/mpeg;base64,$mp3Bytes', + contentType: 'audio/mpeg', + ), + ), + ], + ), + ], + ); + + final req = buildTranscriptionRequest( + modelId: 'whisper-1', + request: request, + options: OpenAISttOptions(), + ); + + expect(req.prompt, 'please transcribe carefully'); + }); + + test('maps responseFormat to SDK enum', () { + final request = _minimalAudioRequest(); + final opts = OpenAISttOptions(responseFormat: 'srt'); + + final req = buildTranscriptionRequest( + modelId: 'whisper-1', + request: request, + options: opts, + ); + + expect(req.responseFormat, sdk.TranscriptionResponseFormat.srt); + }); + + test('includes timestamp granularities', () { + final request = _minimalAudioRequest(); + final opts = OpenAISttOptions( + responseFormat: 'verbose_json', + timestampGranularities: ['word'], + ); + + final req = buildTranscriptionRequest( + modelId: 'whisper-1', + request: request, + options: opts, + ); + + expect( + req.timestampGranularities, + contains(sdk.TimestampGranularity.word), + ); + 
}); + + test('throws when no media part found', () { + final request = ModelRequest( + messages: [ + Message( + role: Role.user, + content: [TextPart(text: 'hello')], + ), + ], + ); + + expect( + () => buildTranscriptionRequest( + modelId: 'whisper-1', + request: request, + options: OpenAISttOptions(), + ), + throwsA(isA()), + ); + }); + }); + + group('buildTranslationRequest', () { + test('builds translation request with correct model', () { + final request = _minimalAudioRequest(); + final opts = OpenAISttOptions(temperature: 0.2); + + final req = buildTranslationRequest( + modelId: 'whisper-1', + request: request, + options: opts, + ); + + expect(req.model, 'whisper-1'); + expect(req.temperature, 0.2); + expect(req.file, isNotEmpty); + }); + }); + + group('transcriptionToModelResponse', () { + test('creates ModelResponse with transcribed text', () { + final response = transcriptionToModelResponse('Hello world'); + + expect(response.finishReason, FinishReason.stop); + expect(response.message?.role, Role.model); + expect(response.message?.text, 'Hello world'); + }); + + test('includes raw map in response', () { + final raw = {'text': 'Hi', 'language': 'en'}; + final response = transcriptionToModelResponse('Hi', raw: raw); + + expect(response.raw, raw); + }); + }); + + group('sttModelInfo', () { + test('reports correct capabilities', () { + expect(sttModelInfo.supports?['media'], isTrue); + expect(sttModelInfo.supports?['multiturn'], isFalse); + expect(sttModelInfo.supports?['tools'], isFalse); + expect(sttModelInfo.supports?['systemRole'], isFalse); + }); + }); + + group('OpenAICompatPluginHandle.stt', () { + test('creates ref for a Whisper model', () { + final ref = openAI.stt('whisper-1'); + expect(ref.name, 'openai/whisper-1'); + }); + + test('creates ref for a GPT transcription model', () { + final ref = openAI.stt('gpt-4o-transcribe'); + expect(ref.name, 'openai/gpt-4o-transcribe'); + }); + + test('carries config', () { + final config = 
OpenAISttOptions(language: 'de'); + final ref = openAI.stt('whisper-1', config: config); + expect(ref.config?.language, 'de'); + }); + }); + + group('STT model ID constants', () { + test('whisperModelIds contains whisper-1', () { + expect(whisperModelIds, contains('whisper-1')); + }); + + test('transcriptionModelIds contains gpt-4o-transcribe', () { + expect(transcriptionModelIds, contains('gpt-4o-transcribe')); + expect(transcriptionModelIds, contains('gpt-4o-mini-transcribe')); + }); + }); +} + +ModelRequest _minimalAudioRequest() { + final bytes = _fakeAudioBase64('audio/wav'); + return ModelRequest( + messages: [ + Message( + role: Role.user, + content: [ + MediaPart( + media: Media( + url: 'data:audio/wav;base64,$bytes', + contentType: 'audio/wav', + ), + ), + ], + ), + ], + ); +} + +/// Returns a base64-encoded stub audio payload for the given [contentType]. +String _fakeAudioBase64(String contentType) { + final stub = utf8.encode('fake-audio-data'); + return base64Encode(stub); +} diff --git a/testapps/openai_sample/lib/speech_to_text.dart b/testapps/openai_sample/lib/speech_to_text.dart new file mode 100644 index 00000000..6548730e --- /dev/null +++ b/testapps/openai_sample/lib/speech_to_text.dart @@ -0,0 +1,297 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
import 'dart:convert';
import 'dart:io';

import 'package:genkit/genkit.dart';
import 'package:genkit_openai/genkit_openai.dart';

// NOTE(review): the generic type arguments on the Flow return types were
// garbled in this copy of the file; they are reconstructed here from each
// flow's input/output schemas — confirm against the original source.

/// Defines a flow that performs basic speech-to-text on a data URL.
///
/// This is the simplest transcription path: provide `data:audio/...;base64,...`
/// and get plain text back.
Flow<String, String, void, void> defineWhisperTranscriptionFlow(
  Genkit ai, {
  String model = 'whisper-1',
}) {
  return ai.defineFlow(
    name: 'whisperTranscribeDataUrl',
    inputSchema: .string(
      defaultValue: 'data:audio/wav;base64,',
    ),
    outputSchema: .string(),
    fn: (audioDataUrl, _) => _generateTranscriptFromDataUrl(
      ai: ai,
      model: model,
      audioDataUrl: audioDataUrl,
      prompt: 'Transcribe this audio. Return only the transcript text.',
    ),
  );
}

/// Defines a flow that returns the raw JSON transcription object.
///
/// Uses `responseFormat: 'json'` (the default `gpt-4o-transcribe` model does
/// not support `verbose_json`), so the raw response map contains the
/// transcript text only. For language, duration, and segment metadata use
/// [defineWhisperVerboseTimestampFlow] with `whisper-1`.
Flow<String, Map<String, dynamic>, void, void> defineWhisperJsonFlow(
  Genkit ai, {
  String model = 'gpt-4o-transcribe',
}) {
  return ai.defineFlow(
    name: 'whisperTranscribeJson',
    inputSchema: .string(
      defaultValue: 'data:audio/wav;base64,',
    ),
    outputSchema: .map(.string(), .dynamicSchema()),
    fn: (audioDataUrl, _) => _generateTranscriptionJsonFromDataUrl(
      ai: ai,
      model: model,
      audioDataUrl: audioDataUrl,
      prompt: 'Transcribe this audio. Return transcript text.',
      options: OpenAISttOptions(responseFormat: 'json'),
    ),
  );
}

/// Defines a flow that returns a verbose transcription with word and segment
/// timestamps.
///
/// Uses `responseFormat: 'verbose_json'` with
/// `timestampGranularities: ['word', 'segment']`, producing a response that
/// includes `language`, `duration`, `text`, per-segment metadata, and
/// per-word start/end times. Only supported by `whisper-1` — gpt-4o-transcribe
/// variants only accept `json` or `text`.
Flow<String, Map<String, dynamic>, void, void>
defineWhisperVerboseTimestampFlow(Genkit ai, {String model = 'whisper-1'}) {
  return ai.defineFlow(
    name: 'whisperTranscribeVerboseTimestamps',
    inputSchema: .string(
      defaultValue: 'data:audio/wav;base64,',
    ),
    outputSchema: .map(.string(), .dynamicSchema()),
    fn: (audioDataUrl, _) => _generateTranscriptionJsonFromDataUrl(
      ai: ai,
      model: model,
      audioDataUrl: audioDataUrl,
      prompt: 'Transcribe this audio. Return transcript text.',
      options: OpenAISttOptions(
        responseFormat: 'verbose_json',
        timestampGranularities: ['word', 'segment'],
      ),
    ),
  );
}

/// Defines a flow that transcribes audio using a GPT transcription model and
/// returns the raw JSON response object.
///
/// Uses `gpt-4o-transcribe` with `responseFormat: 'json'`. Note that
/// `verbose_json` is not supported by gpt-4o-transcribe variants — use
/// [defineWhisperVerboseTimestampFlow] with `whisper-1` for segment/word
/// timestamps.
Flow<String, Map<String, dynamic>, void, void> defineGptTranscribeJsonFlow(
  Genkit ai, {
  String model = 'gpt-4o-transcribe',
}) {
  return ai.defineFlow(
    name: 'gptTranscribeJson',
    inputSchema: .string(
      defaultValue: 'data:audio/wav;base64,',
    ),
    outputSchema: .map(.string(), .dynamicSchema()),
    fn: (audioDataUrl, _) => _generateTranscriptionJsonFromDataUrl(
      ai: ai,
      model: model,
      audioDataUrl: audioDataUrl,
      // No text prompt: let the transcription model run with defaults.
      prompt: null,
      options: OpenAISttOptions(responseFormat: 'json'),
    ),
  );
}

/// Defines a flow that translates source speech into English.
///
/// The plugin-level `translate` option routes the request to
/// `/audio/translations`.
Flow<String, String, void, void> defineWhisperTranslationFlow(
  Genkit ai, {
  String model = 'whisper-1',
}) {
  return ai.defineFlow(
    name: 'whisperTranslateToEnglish',
    inputSchema: .string(
      defaultValue: 'data:audio/wav;base64,',
    ),
    outputSchema: .string(),
    fn: (audioDataUrl, _) => _generateTranscriptFromDataUrl(
      ai: ai,
      model: model,
      audioDataUrl: audioDataUrl,
      prompt: 'Translate this audio into English.',
      // `translate: true` selects the /audio/translations endpoint.
      options: OpenAISttOptions(translate: true),
    ),
  );
}

/// Defines a flow that transcribes a local audio file path.
///
/// The file is read, base64-encoded into a data URL, and sent through the
/// same transcription path as the data-URL flows.
Flow<String, String, void, void> defineWhisperAudioFileTranscriptionFlow(
  Genkit ai, {
  String model = 'whisper-1',
}) {
  return ai.defineFlow(
    name: 'whisperTranscribeAudioFile',
    inputSchema: .string(defaultValue: './sample.wav'),
    outputSchema: .string(),
    fn: (audioPath, _) async {
      final encodedUrl = await _readFileAsDataUrl(audioPath);
      return _generateTranscriptFromDataUrl(
        ai: ai,
        model: model,
        audioDataUrl: encodedUrl,
        prompt: 'Transcribe this audio. Return only the transcript text.',
      );
    },
  );
}

/// Sample entry point: registers every speech-to-text demo flow.
///
/// Requires `OPENAI_API_KEY` in the environment; fails fast otherwise.
void main() {
  final apiKey = Platform.environment['OPENAI_API_KEY'];
  if (apiKey == null || apiKey.isEmpty) {
    throw StateError('OPENAI_API_KEY is required.');
  }

  final ai = Genkit(plugins: [openAI(apiKey: apiKey)]);

  defineWhisperTranscriptionFlow(ai);
  defineWhisperJsonFlow(ai);
  defineWhisperVerboseTimestampFlow(ai);
  defineGptTranscribeJsonFlow(ai);
  defineWhisperTranslationFlow(ai);
  defineWhisperAudioFileTranscriptionFlow(ai);
}

/// Runs a transcription request and returns the trimmed transcript text.
///
/// Throws [StateError] if the model produced an empty transcript.
Future<String> _generateTranscriptFromDataUrl({
  required Genkit ai,
  required String model,
  required String audioDataUrl,
  required String? prompt,
  OpenAISttOptions? options,
}) async {
  final response = await _requestTranscriptionFromDataUrl(
    ai: ai,
    model: model,
    audioDataUrl: audioDataUrl,
    prompt: prompt,
    options: options,
  );

  final transcript = response.text.trim();
  if (transcript.isEmpty) {
    throw StateError('Model returned empty transcription.');
  }
  return transcript;
}

/// Runs a transcription request and returns the raw JSON response map.
///
/// Throws [StateError] if the response carried no structured payload.
Future<Map<String, dynamic>> _generateTranscriptionJsonFromDataUrl({
  required Genkit ai,
  required String model,
  required String audioDataUrl,
  required String? prompt,
  OpenAISttOptions? options,
}) async {
  final response = await _requestTranscriptionFromDataUrl(
    ai: ai,
    model: model,
    audioDataUrl: audioDataUrl,
    prompt: prompt,
    options: options,
  );

  final payload = response.raw;
  if (payload == null) {
    throw StateError(
      'Model returned non-JSON transcription payload for a JSON flow.',
    );
  }
  return payload;
}

/// Validates the data URL, builds the STT request, and calls the model.
///
/// Throws [ArgumentError] when the input is not a base64 audio/video data URL.
// NOTE(review): the declared return type was garbled in this copy; left as a
// raw `Future` exactly as the source had it — confirm the concrete type.
Future _requestTranscriptionFromDataUrl({
  required Genkit ai,
  required String model,
  required String audioDataUrl,
  required String? prompt,
  OpenAISttOptions? options,
}) async {
  final contentType = _extractAudioMimeTypeFromDataUrl(audioDataUrl);
  if (contentType == null) {
    throw ArgumentError(
      'Input must be a base64 media data URL (data:audio/...;base64,... or data:video/...;base64,...).',
    );
  }

  return ai.generate(
    model: openAI.stt(model),
    messages: [
      Message(
        role: Role.user,
        content: [
          // Only attach a text part when the prompt has real content.
          if (prompt != null && prompt.trim().isNotEmpty)
            TextPart(text: prompt),
          MediaPart(
            media: Media(url: audioDataUrl, contentType: contentType),
          ),
        ],
      ),
    ],
    config: options,
  );
}

/// Reads a local file and encodes it as a `data:<mime>;base64,...` URL.
///
/// Throws [ArgumentError] when the file does not exist or has an unsupported
/// extension (via [_mediaMimeTypeFromPath]).
Future<String> _readFileAsDataUrl(String path) async {
  final file = File(path);
  final exists = await file.exists();
  if (!exists) {
    throw ArgumentError('File not found: $path');
  }

  final mimeType = _mediaMimeTypeFromPath(path);
  final bytes = await file.readAsBytes();
  return 'data:$mimeType;base64,${base64Encode(bytes)}';
}

/// Extracts the MIME type from a base64 audio/video data URL, or returns
/// `null` when the URL does not match that shape.
String? _extractAudioMimeTypeFromDataUrl(String url) {
  final dataUrlPattern = RegExp(
    r'^data:((?:audio|video)\/[^;]+);base64,',
    caseSensitive: false,
  );
  return dataUrlPattern.firstMatch(url)?.group(1);
}

/// Maps a file extension to the MIME type OpenAI's audio endpoints accept.
///
/// Throws [ArgumentError] for extensions outside the supported set.
String _mediaMimeTypeFromPath(String path) {
  // Suffix -> MIME table; iteration order matches the original check order.
  const mimeByExtension = <String, String>{
    '.flac': 'audio/flac',
    '.mp3': 'audio/mpeg',
    '.mp4': 'video/mp4',
    '.mpeg': 'video/mpeg',
    '.mpga': 'audio/mpga',
    '.m4a': 'audio/m4a',
    '.ogg': 'audio/ogg',
    '.wav': 'audio/wav',
    '.webm': 'audio/webm',
  };

  final lower = path.toLowerCase();
  for (final entry in mimeByExtension.entries) {
    if (lower.endsWith(entry.key)) {
      return entry.value;
    }
  }

  throw ArgumentError(
    'Unsupported media file extension for "$path". Supported extensions: .flac, .mp3, .mp4, .mpeg, .mpga, .m4a, .ogg, .wav, .webm.',
  );
}