diff --git a/bedrock/bedrock.go b/bedrock/bedrock.go index 5508c0072..3e67640ce 100644 --- a/bedrock/bedrock.go +++ b/bedrock/bedrock.go @@ -72,6 +72,12 @@ func (e *eventstreamDecoder) Next() bool { msg, err := e.Decoder.Decode(e.rc, nil) if err != nil { + // Filter io.EOF — the SSE decoder (eventStreamDecoder) returns nil from + // scn.Err() on normal stream end, but the EventStream decoder exposes EOF + // as an error. Treat EOF as normal stream termination for consistency. + if err == io.EOF { + return false + } e.err = err return false } @@ -106,54 +112,49 @@ func (e *eventstreamDecoder) Next() bool { Type: gjson.GetBytes(decoded, "type").String(), Data: decoded, } + return true } - - case eventstreamapi.ExceptionMessageType: - // See https://github.com/aws/aws-sdk-go-v2/blob/885de40869f9bcee29ad11d60967aa0f1b571d46/service/iotsitewise/deserializers.go#L15511C1-L15567C2 - exceptionType := msg.Headers.Get(eventstreamapi.ExceptionTypeHeader) - if exceptionType == nil { - e.err = fmt.Errorf("%s event header not present", eventstreamapi.ExceptionTypeHeader) - return false - } - - // See https://github.com/aws/aws-sdk-go-v2/blob/885de40869f9bcee29ad11d60967aa0f1b571d46/aws/protocol/restjson/decoder_util.go#L15-L48k - var errInfo struct { - Code string - Type string `json:"__type"` - Message string - } - err = json.Unmarshal(msg.Payload, &errInfo) - if err != nil && err != io.EOF { - e.err = fmt.Errorf("received exception %s: parsing exception payload failed: %w", exceptionType.String(), err) - return false - } - + // Non-chunk event type — skip (don't return stale e.evt from previous call) + return e.Next() + + case eventstreamapi.ExceptionMessageType, eventstreamapi.ErrorMessageType: + // Handle both exception and error message types uniformly. + // Previously ExceptionMessageType had its own case that didn't fall through + // to ErrorMessageType (Go case statements don't fall through by default), + // causing Bedrock exceptions to be silently dropped. See #71. errorCode := "UnknownError" errorMessage := errorCode - if ev := exceptionType.String(); len(ev) > 0 { - errorCode = ev - } else if len(errInfo.Code) > 0 { - errorCode = errInfo.Code - } else if len(errInfo.Type) > 0 { - errorCode = errInfo.Type - } - if len(errInfo.Message) > 0 { - errorMessage = errInfo.Message + if messageType.String() == eventstreamapi.ExceptionMessageType { + exceptionType := msg.Headers.Get(eventstreamapi.ExceptionTypeHeader) + if exceptionType != nil { + errorCode = exceptionType.String() + } + var errInfo struct { + Code string + Type string `json:"__type"` + Message string + } + if err := json.Unmarshal(msg.Payload, &errInfo); err == nil { + if len(errInfo.Code) > 0 { + errorCode = errInfo.Code + } else if len(errInfo.Type) > 0 { + errorCode = errInfo.Type + } + if len(errInfo.Message) > 0 { + errorMessage = errInfo.Message + } + } + } else { + if header := msg.Headers.Get(eventstreamapi.ErrorCodeHeader); header != nil { + errorCode = header.String() + } + if header := msg.Headers.Get(eventstreamapi.ErrorMessageHeader); header != nil { + errorMessage = header.String() + } } - e.err = fmt.Errorf("received exception %s: %s", errorCode, errorMessage) - return false - case eventstreamapi.ErrorMessageType: - errorCode := "UnknownError" - errorMessage := errorCode - if header := msg.Headers.Get(eventstreamapi.ErrorCodeHeader); header != nil { - errorCode = header.String() - } - if header := msg.Headers.Get(eventstreamapi.ErrorMessageHeader); header != nil { - errorMessage = header.String() - } - e.err = fmt.Errorf("received error %s: %s", errorCode, errorMessage) + e.err = fmt.Errorf("received %s %s: %s", messageType.String(), errorCode, errorMessage) return false }