diff --git a/spine/device_local.go b/spine/device_local.go index d401e3f..f170958 100644 --- a/spine/device_local.go +++ b/spine/device_local.go @@ -129,6 +129,11 @@ var _ api.DeviceLocalInterface = (*DeviceLocal)(nil) // Setup a new remote device with a given SKI and triggers SPINE requesting device details func (r *DeviceLocal) SetupRemoteDevice(ski string, writeI shipapi.ShipConnectionDataWriterInterface) shipapi.ShipConnectionDataReaderInterface { + // Clean up any existing device for this SKI (e.g., fast reconnect) + if existing := r.RemoteDeviceForSki(ski); existing != nil { + r.RemoveRemoteDevice(ski) + } + sender := NewSender(writeI) rDevice := NewDeviceRemote(r, ski, sender) @@ -286,6 +291,45 @@ func (r *DeviceLocal) RemoveEntity(entity api.EntityLocalInterface) { r.notifySubscribersOfEntity(entity, model.NetworkManagementStateChangeTypeRemoved) } +// Close shuts down the DeviceLocal, stopping all goroutines and cleaning up all state. +// It removes all remote device connections, waits for pending event handlers, +// removes all application entities (stopping heartbeats), and unsubscribes from events. +// Safe to call multiple times. +func (r *DeviceLocal) Close() { + // Snapshot remote device SKIs + r.mux.Lock() + skis := make([]string, 0, len(r.remoteDevices)) + for ski := range r.remoteDevices { + skis = append(skis, ski) + } + r.mux.Unlock() + + // Remove all remote devices (cleans subscriptions, bindings, write approval timers) + for _, ski := range skis { + r.RemoveRemoteDeviceConnection(ski) + } + + // Wait for disconnect event handlers to finish + r.events.drain() + + // Snapshot entities + r.mux.Lock() + entities := make([]api.EntityLocalInterface, len(r.entities)) + copy(entities, r.entities) + r.mux.Unlock() + + // Remove all application entities (stops heartbeats), skip DeviceInformation entity[0] + for _, entity := range entities { + addr := entity.Address().Entity + if len(addr) > 0 && addr[0] != model.AddressEntityType(DeviceInformationEntityId) { + r.RemoveEntity(entity) + } + } + + // Unsubscribe from core events + _ = r.events.unsubscribe(api.EventHandlerLevelCore, r) +} + func (r *DeviceLocal) Entities() []api.EntityLocalInterface { r.mux.Lock() defer r.mux.Unlock() @@ -346,28 +390,28 @@ func (r *DeviceLocal) ProcessCmd(datagram model.DatagramType, remoteDevice api.D // Validate cmd.function consistency when filters are present // Per SPINE spec section 5.3.4: "SHALL be present if datagram.payload.cmd.filter is present." // The primary security concern is type confusion attacks when filters target wrong functions - + filterPartial, filterDelete := cmd.ExtractFilter() hasFilters := filterPartial != nil || filterDelete != nil - + if hasFilters { // Filters present: cmd.Function MUST be present and consistent // This is the critical validation to prevent type confusion attacks if err := cmd.ValidateFunctionConsistencyStrict(); err != nil { inconsistencies := cmd.GetInconsistentFunctions() errorMsg := fmt.Sprintf("cmd function validation failed: %s", err.Error()) - + // Log validation failure for security monitoring (non-sensitive info only) - logging.Log().Debugf("Command function validation failed: %s (inconsistencies: %d, device: %s, classifier: %v)", - err.Error(), + logging.Log().Debugf("Command function validation failed: %s (inconsistencies: %d, device: %s, classifier: %v)", + err.Error(), len(inconsistencies), remoteDevice.Address(), cmdClassifier) - + // Send proper error response to remote device validationError := model.NewErrorType(model.ErrorNumberTypeCommandRejected, errorMsg) _ = remoteDevice.Sender().ResultError(&datagram.Header, destAddr, validationError) - + return fmt.Errorf("cmd function validation failed: %w", err) } } diff --git a/spine/device_local_test.go b/spine/device_local_test.go index 6beab34..ac8a3ac 100644 --- a/spine/device_local_test.go +++ b/spine/device_local_test.go @@ -1,6 +1,7 @@ package spine import ( + "runtime" "testing" "time" @@ -426,3 +427,95 @@ func (d *DeviceLocalTestSuite) Test_ProcessCmd() { err = sut.ProcessCmd(datagram, remote) assert.NotNil(d.T(), err) } + +// Step 6 — Issue 9: SetupRemoteDevice should clean up existing device for same SKI. +func (d *DeviceLocalTestSuite) Test_SetupRemoteDevice_CleansUpExisting() { + sut := NewDeviceLocal("brand", "model", "serial", "code", "address", + model.DeviceTypeTypeEnergyManagementSystem, model.NetworkManagementFeatureSetTypeSmart) + + ski := "reconnect-ski" + _ = sut.SetupRemoteDevice(ski, d) + device1 := sut.RemoteDeviceForSki(ski) + assert.NotNil(d.T(), device1) + + // Second setup for the same SKI — should clean up the first + _ = sut.SetupRemoteDevice(ski, d) + device2 := sut.RemoteDeviceForSki(ski) + assert.NotNil(d.T(), device2) + + // The new device should be different from the old one + assert.NotEqual(d.T(), device1, device2) + + // Only one device should exist for that SKI + remotes := sut.RemoteDevices() + skiCount := 0 + for _, r := range remotes { + if r.Ski() == ski { + skiCount++ + } + } + assert.Equal(d.T(), 1, skiCount) + + sut.RemoveRemoteDeviceConnection(ski) +} + +// Step 8 — Issue 1: Close() stops all goroutines and cleans up all state. +func (d *DeviceLocalTestSuite) Test_Close_StopsAllGoroutines() { + runtime.GC() + time.Sleep(50 * time.Millisecond) + baseline := runtime.NumGoroutine() + + sut := NewDeviceLocal("brand", "model", "serial", "code", "address", + model.DeviceTypeTypeEnergyManagementSystem, model.NetworkManagementFeatureSetTypeSmart) + + // Create 3 entities with heartbeat + for i := uint(1); i <= 3; i++ { + entity := NewEntityLocal(sut, model.EntityTypeTypeCEM, + []model.AddressEntityType{model.AddressEntityType(i)}, 4*time.Second) + sut.AddEntity(entity) + + diagFeature := NewFeatureLocal(entity.NextFeatureId(), entity, + model.FeatureTypeTypeDeviceDiagnosis, model.RoleTypeServer) + diagFeature.AddFunctionType(model.FunctionTypeDeviceDiagnosisHeartbeatData, true, false) + entity.AddFeature(diagFeature) + } + + // Connect 2 remote devices + _ = sut.SetupRemoteDevice("remote-1", d) + _ = sut.SetupRemoteDevice("remote-2", d) + + time.Sleep(50 * time.Millisecond) + running := runtime.NumGoroutine() + assert.Greater(d.T(), running, baseline, "heartbeat goroutines should be running") + + // Close should clean up everything + sut.Close() + + time.Sleep(50 * time.Millisecond) + after := runtime.NumGoroutine() + assert.LessOrEqual(d.T(), after, baseline+1, + "all goroutines should be stopped after Close()") + + // All remote devices should be gone + assert.Empty(d.T(), sut.RemoteDevices()) + + // Only entity[0] should remain + entities := sut.Entities() + assert.Equal(d.T(), 1, len(entities)) +} + +// Step 8 — Close() is idempotent. +func (d *DeviceLocalTestSuite) Test_Close_Idempotent() { + sut := NewDeviceLocal("brand", "model", "serial", "code", "address", + model.DeviceTypeTypeEnergyManagementSystem, model.NetworkManagementFeatureSetTypeSmart) + + entity := NewEntityLocal(sut, model.EntityTypeTypeCEM, + []model.AddressEntityType{1}, 4*time.Second) + sut.AddEntity(entity) + + _ = sut.SetupRemoteDevice("ski-1", d) + + // Close twice — should not panic + sut.Close() + sut.Close() +} diff --git a/spine/events.go b/spine/events.go index 9135d04..627abff 100644 --- a/spine/events.go +++ b/spine/events.go @@ -24,6 +24,7 @@ type eventHandlerItem struct { type events struct { mu sync.Mutex muHandle sync.Mutex + wg sync.WaitGroup handlers []eventHandlerItem // event handling outside of the core stack } @@ -105,9 +106,19 @@ func (r *events) Publish(payload api.EventPayload) { // and expected actions are taken item.Handler.HandleEvent(payload) } else { - go item.Handler.HandleEvent(payload) + r.wg.Add(1) + go func(h api.EventHandlerInterface) { + defer r.wg.Done() + h.HandleEvent(payload) + }(item.Handler) } } } r.muHandle.Unlock() } + +// drain waits for all dispatched application-level event handlers to complete. +// Used during shutdown to ensure no handlers are still running. +func (r *events) drain() { + r.wg.Wait() +} diff --git a/spine/events_test.go b/spine/events_test.go index dfee5f8..347c443 100644 --- a/spine/events_test.go +++ b/spine/events_test.go @@ -97,3 +97,40 @@ func (s *EventsTestSuite) Test_Publish_Application() { err = s.events.Unsubscribe(s) assert.Nil(s.T(), err) } + +// Step 7 — Issue 6: drain() must wait for all dispatched handlers to complete. +type slowTestHandler struct { + mu sync.Mutex + finished bool +} + +func (h *slowTestHandler) HandleEvent(event api.EventPayload) { + time.Sleep(200 * time.Millisecond) + h.mu.Lock() + h.finished = true + h.mu.Unlock() +} + +func (h *slowTestHandler) isFinished() bool { + h.mu.Lock() + defer h.mu.Unlock() + return h.finished +} + +func (s *EventsTestSuite) Test_Drain_WaitsForHandlers() { + handler := &slowTestHandler{} + err := s.events.Subscribe(handler) + assert.Nil(s.T(), err) + + s.events.Publish(api.EventPayload{}) + + // Handler is still running at this point + assert.False(s.T(), handler.isFinished()) + + // drain() should block until handler completes + s.events.drain() + assert.True(s.T(), handler.isFinished()) + + err = s.events.Unsubscribe(handler) + assert.Nil(s.T(), err) +} diff --git a/spine/feature_local.go b/spine/feature_local.go index d3f9d2c..ff4d014 100644 --- a/spine/feature_local.go +++ b/spine/feature_local.go @@ -25,7 +25,6 @@ type FeatureLocal struct { writeTimeout time.Duration writeApprovalCallbacks []api.WriteApprovalCallbackFunc - muxWriteReceived sync.Mutex writeApprovalReceived map[string]map[model.MsgCounterType]int pendingWriteApprovals map[string]map[model.MsgCounterType]*time.Timer @@ -193,6 +192,10 @@ func (r *FeatureLocal) addPendingApproval(msg *api.Message) { newTimer := time.AfterFunc(r.writeTimeout, func() { r.muxResponseCB.Lock() + if _, ok := r.pendingWriteApprovals[ski]; !ok { + r.muxResponseCB.Unlock() + return + } delete(r.pendingWriteApprovals[ski], *msg.RequestHeader.MsgCounter) r.muxResponseCB.Unlock() @@ -217,18 +220,17 @@ func (r *FeatureLocal) ApproveOrDenyWrite(msg *api.Message, err model.ErrorType) ski := msg.DeviceRemote.Ski() r.muxResponseCB.Lock() - timer, ok := r.pendingWriteApprovals[ski][*msg.RequestHeader.MsgCounter] - count := len(r.writeApprovalCallbacks) - r.muxResponseCB.Unlock() - // if there is no timer running, we are too late and error has already been sent + timer, ok := r.pendingWriteApprovals[ski][*msg.RequestHeader.MsgCounter] + // if there is no timer, we are too late and error has already been sent if !ok || timer == nil { + r.muxResponseCB.Unlock() return } + count := len(r.writeApprovalCallbacks) + // do we have enough approvals? - r.muxWriteReceived.Lock() - defer r.muxWriteReceived.Unlock() if count > 1 && err.ErrorNumber == 0 { amount, ok := r.writeApprovalReceived[ski][*msg.RequestHeader.MsgCounter] if ok { @@ -239,18 +241,20 @@ func (r *FeatureLocal) ApproveOrDenyWrite(msg *api.Message, err model.ErrorType) } // do we have enough approve messages, if not exit if r.writeApprovalReceived[ski][*msg.RequestHeader.MsgCounter] < count { + r.muxResponseCB.Unlock() return } } + // Atomically stop the timer and clean up entries under the lock. + // This prevents the TOCTOU race where the timer fires between lock acquisitions. timer.Stop() - delete(r.writeApprovalReceived[ski], *msg.RequestHeader.MsgCounter) - - r.muxResponseCB.Lock() - defer r.muxResponseCB.Unlock() delete(r.pendingWriteApprovals[ski], *msg.RequestHeader.MsgCounter) + r.muxResponseCB.Unlock() + + // Process outside the lock to avoid holding it during network I/O if err.ErrorNumber == 0 { r.processWrite(msg) return @@ -267,6 +271,13 @@ func (r *FeatureLocal) CleanWriteApprovalCaches(ski string) { r.muxResponseCB.Lock() defer r.muxResponseCB.Unlock() + // Stop all pending timers for this SKI before deleting + if timers, ok := r.pendingWriteApprovals[ski]; ok { + for _, timer := range timers { + timer.Stop() + } + } + delete(r.pendingWriteApprovals, ski) delete(r.writeApprovalReceived, ski) } diff --git a/spine/feature_local_test.go b/spine/feature_local_test.go index 5d42317..ae1240d 100644 --- a/spine/feature_local_test.go +++ b/spine/feature_local_test.go @@ -1234,9 +1234,117 @@ func (s *LocalFeatureTestSuite) Test_Read_WithPartialFilter_NoErrors() { // Test that partial read capability is correctly reported as false func (s *LocalFeatureTestSuite) Test_Operations_NoPartialReadSupport() { operations := s.localServerFeatureWrite.Operations() - + // Verify that partial read is not supported operation, exists := operations[s.serverWriteFunction] assert.True(s.T(), exists) assert.False(s.T(), operation.ReadPartial()) } + +// Step 3 — Issue 2: CleanWriteApprovalCaches must stop running timers. +// After cleanup, no timer should fire and send messages on the dead connection. +func (s *LocalFeatureTestSuite) Test_CleanWriteApprovalCaches_StopsTimers() { + writeHandler := &WriteMessageHandler{} + device, localEntity := createLocalDeviceAndEntity(1) + _, serverFeature := createLocalFeatures(localEntity, model.FeatureTypeTypeLoadControl, model.FunctionTypeLoadControlLimitListData) + + serverFeature.AddWriteApprovalCallback(func(msg *api.Message) { + // Intentionally don't approve — let the timer expire. + }) + + ski := "timer-test" + sender := NewSender(writeHandler) + remoteDevice := createRemoteDevice(device, ski, sender) + device.AddRemoteDeviceForSki(ski, remoteDevice) + remoteFeature, _ := createRemoteEntityAndFeature(remoteDevice, 1, + model.FeatureTypeTypeLoadControl, model.FunctionTypeLoadControlLimitListData) + + serverFeature.SetWriteApprovalTimeout(200 * time.Millisecond) + + msgCounter := model.MsgCounterType(42) + msg := &api.Message{ + RequestHeader: &model.HeaderType{ + MsgCounter: util.Ptr(msgCounter), + AddressSource: &model.FeatureAddressType{ + Device: remoteDevice.Address(), + Entity: remoteDevice.Entity([]model.AddressEntityType{1}).Address().Entity, + Feature: remoteFeature.Address().Feature, + }, + AddressDestination: serverFeature.Address(), + }, + DeviceRemote: remoteDevice, + EntityRemote: remoteDevice.Entity([]model.AddressEntityType{1}), + FeatureRemote: remoteFeature, + } + + sf := serverFeature.(*FeatureLocal) + sf.addPendingApproval(msg) + + callsBefore := len(writeHandler.sentMessages) + + // Cleanup should stop the timer + serverFeature.CleanWriteApprovalCaches(ski) + + // Wait past the timer expiry + time.Sleep(400 * time.Millisecond) + + callsAfter := len(writeHandler.sentMessages) + assert.Equal(s.T(), callsBefore, callsAfter, + "no messages should be sent after CleanWriteApprovalCaches stops timers") +} + +// Step 4 — Issue 7: ApproveOrDenyWrite must not send double responses. +// The timer and approval path must not both send a response for the same request. +func (s *LocalFeatureTestSuite) Test_ApproveOrDenyWrite_NoDoubleResponse() { + writeHandler := &WriteMessageHandler{} + device, localEntity := createLocalDeviceAndEntity(1) + _, serverFeature := createLocalFeatures(localEntity, model.FeatureTypeTypeLoadControl, model.FunctionTypeLoadControlLimitListData) + + serverFeature.AddWriteApprovalCallback(func(msg *api.Message) { + // Will be approved externally + }) + + ski := "toctou-test" + sender := NewSender(writeHandler) + remoteDevice := createRemoteDevice(device, ski, sender) + device.AddRemoteDeviceForSki(ski, remoteDevice) + remoteFeature, _ := createRemoteEntityAndFeature(remoteDevice, 1, + model.FeatureTypeTypeLoadControl, model.FunctionTypeLoadControlLimitListData) + + // Very short timeout to make the race window likely + serverFeature.SetWriteApprovalTimeout(80 * time.Millisecond) + + msgCounter := model.MsgCounterType(99) + msg := &api.Message{ + RequestHeader: &model.HeaderType{ + MsgCounter: util.Ptr(msgCounter), + AddressSource: &model.FeatureAddressType{ + Device: remoteDevice.Address(), + Entity: remoteDevice.Entity([]model.AddressEntityType{1}).Address().Entity, + Feature: remoteFeature.Address().Feature, + }, + AddressDestination: serverFeature.Address(), + }, + DeviceRemote: remoteDevice, + EntityRemote: remoteDevice.Entity([]model.AddressEntityType{1}), + FeatureRemote: remoteFeature, + } + + sf := serverFeature.(*FeatureLocal) + sf.addPendingApproval(msg) + + // Wait close to timeout, then approve + time.Sleep(70 * time.Millisecond) + serverFeature.ApproveOrDenyWrite(msg, model.ErrorType{ErrorNumber: model.ErrorNumberType(0)}) + + // Wait for any timer to fire + time.Sleep(100 * time.Millisecond) + + writeHandler.mux.Lock() + totalMessages := len(writeHandler.sentMessages) + writeHandler.mux.Unlock() + + // At most 1 response should be sent (either success from approval or error from timeout, not both) + assert.LessOrEqual(s.T(), totalMessages, 1, + "at most 1 response should be sent, not both timer error and approval success") +} diff --git a/spine/heartbeat_manager.go b/spine/heartbeat_manager.go index 8e663b9..33cbbb6 100644 --- a/spine/heartbeat_manager.go +++ b/spine/heartbeat_manager.go @@ -15,6 +15,7 @@ type HeartbeatManager struct { heartBeatNum uint64 // see https://github.com/golang/go/issues/11891 stopHeartbeatC chan struct{} + doneC chan struct{} // closed when the goroutine exits stopMux sync.Mutex heartBeatTimeout *model.DurationType @@ -88,21 +89,38 @@ func (c *HeartbeatManager) StartHeartbeat() error { return err } - // stop an already running heartbeat + // stop an already running heartbeat and wait for its goroutine to exit c.StopHeartbeat() + c.stopMux.Lock() c.stopHeartbeatC = make(chan struct{}) + c.doneC = make(chan struct{}) + stopC := c.stopHeartbeatC + doneC := c.doneC + c.stopMux.Unlock() - go c.updateHeartbeatData(c.stopHeartbeatC, timeout) + go func() { + defer close(doneC) + c.updateHeartbeatData(stopC, timeout) + }() return nil } -// Stop updating heartbeat data +// Stop updating heartbeat data and wait for the goroutine to exit. // Note: No active subscribers will get any further notifications! func (c *HeartbeatManager) StopHeartbeat() { - if c.IsHeartbeatRunning() { + c.stopMux.Lock() + needsStop := c.stopHeartbeatC != nil && !c.isHeartbeatClosed() + var doneC chan struct{} + if needsStop { close(c.stopHeartbeatC) + doneC = c.doneC + } + c.stopMux.Unlock() + + if doneC != nil { + <-doneC } } @@ -126,6 +144,7 @@ func (c *HeartbeatManager) updateHeartbeatData(stopC chan struct{}, d time.Durat d -= 2 * time.Second } ticker := time.NewTicker(d) + defer ticker.Stop() for { select { case <-ticker.C: diff --git a/spine/heartbeat_manager_test.go b/spine/heartbeat_manager_test.go index 08d035d..dad3a3d 100644 --- a/spine/heartbeat_manager_test.go +++ b/spine/heartbeat_manager_test.go @@ -1,6 +1,8 @@ package spine import ( + "runtime" + "sync" "testing" "time" @@ -178,3 +180,80 @@ func (s *HeartBeatManagerSuite) Test_HeartbeatSuccess() { isHeartbeatRunning = s.sut.IsHeartbeatRunning() assert.Equal(s.T(), false, isHeartbeatRunning) } + +// Step 1 — Issue 5: Verify ticker is stopped when heartbeat goroutine exits. +// After multiple start/stop cycles, goroutine count should return to baseline. +func (s *HeartBeatManagerSuite) Test_TickerStoppedOnShutdown() { + localFeature := s.localEntity.GetOrAddFeature(model.FeatureTypeTypeDeviceDiagnosis, model.RoleTypeServer) + localFeature.AddFunctionType(model.FunctionTypeDeviceDiagnosisHeartbeatData, true, false) + s.localEntity.AddFeature(localFeature) + + runtime.GC() + time.Sleep(50 * time.Millisecond) + baseline := runtime.NumGoroutine() + + for i := 0; i < 10; i++ { + _ = s.sut.StartHeartbeat() + time.Sleep(20 * time.Millisecond) + s.sut.StopHeartbeat() + time.Sleep(20 * time.Millisecond) + } + + runtime.GC() + time.Sleep(100 * time.Millisecond) + after := runtime.NumGoroutine() + + // With the ticker leak, goroutine count would grow. After fix, it should be stable. + assert.LessOrEqual(s.T(), after, baseline+1, + "goroutine count should not grow after repeated start/stop cycles") +} + +// Step 2 — Issue 4: Concurrent StartHeartbeat/IsHeartbeatRunning must not race. +// Run with: go test -race -run TestHeartbeatManagerSuite/Test_StartStopHeartbeat_ConcurrentAccess +func (s *HeartBeatManagerSuite) Test_StartStopHeartbeat_ConcurrentAccess() { + localFeature := s.localEntity.GetOrAddFeature(model.FeatureTypeTypeDeviceDiagnosis, model.RoleTypeServer) + localFeature.AddFunctionType(model.FunctionTypeDeviceDiagnosisHeartbeatData, true, false) + s.localEntity.AddFeature(localFeature) + + var wg sync.WaitGroup + for i := 0; i < 20; i++ { + wg.Add(2) + go func() { + defer wg.Done() + _ = s.sut.StartHeartbeat() + }() + go func() { + defer wg.Done() + _ = s.sut.IsHeartbeatRunning() + }() + } + wg.Wait() + + s.sut.StopHeartbeat() + time.Sleep(50 * time.Millisecond) + assert.False(s.T(), s.sut.IsHeartbeatRunning()) +} + +// Step 5 — Issue 3: StopHeartbeat must wait for the goroutine to exit. +func (s *HeartBeatManagerSuite) Test_StopHeartbeat_WaitsForGoroutine() { + localFeature := s.localEntity.GetOrAddFeature(model.FeatureTypeTypeDeviceDiagnosis, model.RoleTypeServer) + localFeature.AddFunctionType(model.FunctionTypeDeviceDiagnosisHeartbeatData, true, false) + s.localEntity.AddFeature(localFeature) + + runtime.GC() + time.Sleep(50 * time.Millisecond) + baseline := runtime.NumGoroutine() + + _ = s.sut.StartHeartbeat() + time.Sleep(20 * time.Millisecond) + assert.True(s.T(), s.sut.IsHeartbeatRunning()) + + // After StopHeartbeat returns, the goroutine should be fully exited. + s.sut.StopHeartbeat() + + // No sleep needed — StopHeartbeat should block until goroutine exits. + after := runtime.NumGoroutine() + assert.LessOrEqual(s.T(), after, baseline+1, + "goroutine should be fully stopped when StopHeartbeat returns") + assert.False(s.T(), s.sut.IsHeartbeatRunning()) +} diff --git a/spine/lifecycle_issues_demo_test.go b/spine/lifecycle_issues_demo_test.go new file mode 100644 index 0000000..663e0d0 --- /dev/null +++ b/spine/lifecycle_issues_demo_test.go @@ -0,0 +1,469 @@ +package spine + +// lifecycle_issues_demo_test.go +// +// Verifies that lifecycle issues in spine-go's start/stop handling are fixed. +// Run with: +// go test -v -run TestLifecycleIssues ./spine/ +// go test -v -race -run TestLifecycleIssues_Race ./spine/ + +import ( + "fmt" + "runtime" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/enbility/spine-go/api" + "github.com/enbility/spine-go/model" + "github.com/enbility/spine-go/util" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" +) + +// --------------------------------------------------------------------------- +// Shared helpers +// --------------------------------------------------------------------------- + +func countGoroutines() int { + return runtime.NumGoroutine() +} + +func stabilize() { + runtime.Gosched() + time.Sleep(50 * time.Millisecond) +} + +type trackingWriter struct { + mu sync.Mutex + messages [][]byte + calls atomic.Int64 +} + +func (w *trackingWriter) WriteShipMessageWithPayload(msg []byte) { + w.calls.Add(1) + w.mu.Lock() + defer w.mu.Unlock() + w.messages = append(w.messages, msg) +} + +func (w *trackingWriter) callCount() int64 { + return w.calls.Load() +} + +type slowEventHandler struct { + mu sync.Mutex + events []api.EventPayload + started atomic.Int64 + finished atomic.Int64 + delay time.Duration +} + +func newSlowEventHandler(delay time.Duration) *slowEventHandler { + return &slowEventHandler{delay: delay} +} + +func (h *slowEventHandler) HandleEvent(payload api.EventPayload) { + h.started.Add(1) + time.Sleep(h.delay) + h.mu.Lock() + h.events = append(h.events, payload) + h.mu.Unlock() + h.finished.Add(1) +} + +// --------------------------------------------------------------------------- +// Test Suite +// --------------------------------------------------------------------------- + +type LifecycleIssuesSuite struct { + suite.Suite +} + +func TestLifecycleIssues(t *testing.T) { + suite.Run(t, new(LifecycleIssuesSuite)) +} + +// ========================================================================= +// Issue 1: Close() stops all goroutines — no leak when used +// ========================================================================= + +func (s *LifecycleIssuesSuite) TestIssue1_CloseStopsAllGoroutines() { + stabilize() + before := countGoroutines() + + device := NewDeviceLocal( + "Demo", "Model", "Serial", "Code", "device-1", + model.DeviceTypeTypeEnergyManagementSystem, + model.NetworkManagementFeatureSetTypeSmart) + + entity := NewEntityLocal(device, model.EntityTypeTypeCEM, + []model.AddressEntityType{1}, 4*time.Second) + device.AddEntity(entity) + + diagFeature := NewFeatureLocal( + entity.NextFeatureId(), entity, + model.FeatureTypeTypeDeviceDiagnosis, model.RoleTypeServer) + diagFeature.AddFunctionType(model.FunctionTypeDeviceDiagnosisHeartbeatData, true, false) + entity.AddFeature(diagFeature) + + assert.True(s.T(), entity.HeartbeatManager().IsHeartbeatRunning()) + + // Close() should clean up everything + device.Close() + + stabilize() + after := countGoroutines() + + assert.LessOrEqual(s.T(), after, before+1, + "no goroutines should leak after Close()") +} + +// ========================================================================= +// Issue 2: CleanWriteApprovalCaches stops timers — no post-cleanup fire +// ========================================================================= + +func (s *LifecycleIssuesSuite) TestIssue2_TimersStoppedOnCleanup() { + writer := &trackingWriter{} + + device, localEntity := createLocalDeviceAndEntity(1) + _, serverFeature := createLocalFeatures(localEntity, model.FeatureTypeTypeLoadControl, model.FunctionTypeLoadControlLimitListData) + + serverFeature.AddWriteApprovalCallback(func(msg *api.Message) {}) + + ski := "remote-1" + sender := NewSender(writer) + remoteDevice := createRemoteDevice(device, ski, sender) + device.AddRemoteDeviceForSki(ski, remoteDevice) + + remoteFeature, _ := createRemoteEntityAndFeature(remoteDevice, 1, + model.FeatureTypeTypeLoadControl, model.FunctionTypeLoadControlLimitListData) + + serverFeature.SetWriteApprovalTimeout(200 * time.Millisecond) + + msgCounter := model.MsgCounterType(42) + msg := &api.Message{ + RequestHeader: &model.HeaderType{ + MsgCounter: util.Ptr(msgCounter), + AddressSource: &model.FeatureAddressType{ + Device: remoteDevice.Address(), + Entity: remoteDevice.Entity([]model.AddressEntityType{1}).Address().Entity, + Feature: remoteFeature.Address().Feature, + }, + AddressDestination: serverFeature.Address(), + }, + DeviceRemote: remoteDevice, + EntityRemote: remoteDevice.Entity([]model.AddressEntityType{1}), + FeatureRemote: remoteFeature, + } + + serverFeature.(*FeatureLocal).addPendingApproval(msg) + + callsBefore := writer.callCount() + device.RemoveRemoteDevice(ski) + + // Wait past timer expiry + time.Sleep(400 * time.Millisecond) + + callsAfter := writer.callCount() + assert.Equal(s.T(), callsBefore, callsAfter, + "timer should not fire after CleanWriteApprovalCaches stops it") +} + +// ========================================================================= +// Issue 3: StopHeartbeat waits for goroutine — synchronous stop +// ========================================================================= + +func (s *LifecycleIssuesSuite) TestIssue3_HeartbeatJoinedOnRemoval() { + device := NewDeviceLocal( + "Demo", "Model", "Serial", "Code", "device-1", + model.DeviceTypeTypeEnergyManagementSystem, + model.NetworkManagementFeatureSetTypeSmart) + + entity := NewEntityLocal(device, model.EntityTypeTypeCEM, + []model.AddressEntityType{1}, 4*time.Second) + device.AddEntity(entity) + + diagFeature := NewFeatureLocal( + entity.NextFeatureId(), entity, + model.FeatureTypeTypeDeviceDiagnosis, model.RoleTypeServer) + diagFeature.AddFunctionType(model.FunctionTypeDeviceDiagnosisHeartbeatData, true, false) + entity.AddFeature(diagFeature) + + assert.True(s.T(), entity.HeartbeatManager().IsHeartbeatRunning()) + + goroutinesBefore := countGoroutines() + + // RemoveEntity -> StopHeartbeat now blocks until goroutine exits + device.RemoveEntity(entity) + + // Goroutine should be gone immediately (no sleep needed) + goroutinesAfter := countGoroutines() + assert.LessOrEqual(s.T(), goroutinesAfter, goroutinesBefore, + "goroutine should exit synchronously when StopHeartbeat returns") +} + +// ========================================================================= +// Issue 4: No data race on stopHeartbeatC (run with -race) +// ========================================================================= + +type LifecycleRaceSuite struct { + suite.Suite +} + +func TestLifecycleIssues_Race(t *testing.T) { + suite.Run(t, new(LifecycleRaceSuite)) +} + +func (s *LifecycleRaceSuite) TestIssue4_NoRaceOnStopHeartbeatChannel() { + device := NewDeviceLocal( + "Demo", "Model", "Serial", "Code", "device-1", + model.DeviceTypeTypeEnergyManagementSystem, + model.NetworkManagementFeatureSetTypeSmart) + + entity := NewEntityLocal(device, model.EntityTypeTypeCEM, + []model.AddressEntityType{1}, 4*time.Second) + device.AddEntity(entity) + + diagFeature := NewFeatureLocal( + entity.NextFeatureId(), entity, + model.FeatureTypeTypeDeviceDiagnosis, model.RoleTypeServer) + diagFeature.AddFunctionType(model.FunctionTypeDeviceDiagnosisHeartbeatData, true, false) + entity.AddFeature(diagFeature) + + hbm := entity.HeartbeatManager() + + // Concurrent StartHeartbeat/IsHeartbeatRunning — should NOT trigger race detector. + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(2) + go func() { + defer wg.Done() + _ = hbm.StartHeartbeat() + }() + go func() { + defer wg.Done() + _ = hbm.IsHeartbeatRunning() + }() + } + wg.Wait() + + hbm.StopHeartbeat() + assert.False(s.T(), hbm.IsHeartbeatRunning()) +} + +// ========================================================================= +// Issue 5: Ticker properly stopped — no resource leak +// ========================================================================= + +func (s *LifecycleIssuesSuite) TestIssue5_TickerStoppedOnExit() { + device := NewDeviceLocal( + "Demo", "Model", "Serial", "Code", "device-1", + model.DeviceTypeTypeEnergyManagementSystem, + model.NetworkManagementFeatureSetTypeSmart) + + entity := NewEntityLocal(device, model.EntityTypeTypeCEM, + []model.AddressEntityType{1}, 4*time.Second) + device.AddEntity(entity) + + diagFeature := NewFeatureLocal( + entity.NextFeatureId(), entity, + model.FeatureTypeTypeDeviceDiagnosis, model.RoleTypeServer) + diagFeature.AddFunctionType(model.FunctionTypeDeviceDiagnosisHeartbeatData, true, false) + entity.AddFeature(diagFeature) + + stabilize() + baseline := countGoroutines() + + // 5 start/stop cycles — goroutine count should be stable + for i := 0; i < 5; i++ { + _ = entity.HeartbeatManager().StartHeartbeat() + time.Sleep(20 * time.Millisecond) + entity.HeartbeatManager().StopHeartbeat() + } + + stabilize() + after := countGoroutines() + + assert.LessOrEqual(s.T(), after, baseline+1, + "goroutine count should not grow after repeated start/stop (ticker properly stopped)") +} + +// ========================================================================= +// Issue 6: Event handlers can be drained after teardown +// ========================================================================= + +func (s *LifecycleIssuesSuite) TestIssue6_EventHandlersDrainedAfterTeardown() { + writer := &trackingWriter{} + handler := newSlowEventHandler(300 * time.Millisecond) + + device := NewDeviceLocal( + "Demo", "Model", "Serial", "Code", "device-1", + model.DeviceTypeTypeEnergyManagementSystem, + model.NetworkManagementFeatureSetTypeSmart) + + _ = device.events.Subscribe(handler) + defer func() { _ = device.events.Unsubscribe(handler) }() + + ski := "remote-1" + _ = device.SetupRemoteDevice(ski, writer) + + device.RemoveRemoteDeviceConnection(ski) + + // Handler is still running — drain waits for it + device.events.drain() + + started := handler.started.Load() + finished := handler.finished.Load() + assert.Equal(s.T(), started, finished, + "all event handlers should be finished after drain()") +} + +// ========================================================================= +// Issue 7: No double response from ApproveOrDenyWrite +// ========================================================================= + +func (s *LifecycleIssuesSuite) TestIssue7_NoDoubleResponse() { + writer := &trackingWriter{} + + device, localEntity := createLocalDeviceAndEntity(1) + _, serverFeature := createLocalFeatures(localEntity, model.FeatureTypeTypeLoadControl, model.FunctionTypeLoadControlLimitListData) + + serverFeature.AddWriteApprovalCallback(func(msg *api.Message) {}) + + ski := "remote-1" + sender := NewSender(writer) + remoteDevice := createRemoteDevice(device, ski, sender) + device.AddRemoteDeviceForSki(ski, remoteDevice) + + remoteFeature, _ := createRemoteEntityAndFeature(remoteDevice, 1, + model.FeatureTypeTypeLoadControl, model.FunctionTypeLoadControlLimitListData) + + serverFeature.SetWriteApprovalTimeout(80 * time.Millisecond) + + msgCounter := model.MsgCounterType(99) + msg := &api.Message{ + RequestHeader: &model.HeaderType{ + MsgCounter: util.Ptr(msgCounter), + AddressSource: &model.FeatureAddressType{ + Device: remoteDevice.Address(), + Entity: remoteDevice.Entity([]model.AddressEntityType{1}).Address().Entity, + Feature: remoteFeature.Address().Feature, + }, + AddressDestination: serverFeature.Address(), + }, + DeviceRemote: remoteDevice, + EntityRemote: remoteDevice.Entity([]model.AddressEntityType{1}), + FeatureRemote: remoteFeature, + } + + serverFeature.(*FeatureLocal).addPendingApproval(msg) + + // Wait close to timeout, then approve + time.Sleep(70 * time.Millisecond) + serverFeature.ApproveOrDenyWrite(msg, model.ErrorType{ErrorNumber: model.ErrorNumberType(0)}) + + // Wait for any timer to fire + time.Sleep(100 * time.Millisecond) + + assert.LessOrEqual(s.T(), writer.callCount(), int64(1), + "at most 1 response should be sent, not both timer error and approval success") + + device.RemoveRemoteDevice(ski) +} + +// ========================================================================= +// Issue 9: SetupRemoteDevice cleans up existing device +// ========================================================================= + +func (s *LifecycleIssuesSuite) TestIssue9_SetupRemoteDeviceCleansUpExisting() { + writer := &trackingWriter{} + + device := NewDeviceLocal( + "Demo", "Model", "Serial", "Code", "device-1", + model.DeviceTypeTypeEnergyManagementSystem, + model.NetworkManagementFeatureSetTypeSmart) + + ski := "remote-1" + + reader1 := device.SetupRemoteDevice(ski, writer) + assert.NotNil(s.T(), reader1) + + device1 := device.RemoteDeviceForSki(ski) + assert.NotNil(s.T(), device1) + + // Second setup for same SKI — should clean up old device first + reader2 := device.SetupRemoteDevice(ski, writer) + assert.NotNil(s.T(), reader2) + + device2 := device.RemoteDeviceForSki(ski) + assert.NotNil(s.T(), device2) + + // New device should be different object from old one + assert.NotEqual(s.T(), device1, device2, + "second SetupRemoteDevice should create a new device, not reuse the old one") + + device.RemoveRemoteDeviceConnection(ski) +} + +// ========================================================================= +// Combined: Full lifecycle with Close() — all clean +// ========================================================================= + +func (s *LifecycleIssuesSuite) TestCombined_FullLifecycleWithClose() { + writer := &trackingWriter{} + handler := newSlowEventHandler(200 * time.Millisecond) + + stabilize() + goroutinesBefore := countGoroutines() + + device := NewDeviceLocal( + "Demo", "Model", "Serial", "Code", "device-1", + model.DeviceTypeTypeEnergyManagementSystem, + model.NetworkManagementFeatureSetTypeSmart) + + _ = device.events.Subscribe(handler) + defer func() { + _ = device.events.Unsubscribe(handler) + stabilize() + }() + + for i := uint(1); i <= 3; i++ { + entity := NewEntityLocal(device, model.EntityTypeTypeCEM, + []model.AddressEntityType{model.AddressEntityType(i)}, 4*time.Second) + device.AddEntity(entity) + + diagFeature := NewFeatureLocal( + entity.NextFeatureId(), entity, + model.FeatureTypeTypeDeviceDiagnosis, model.RoleTypeServer) + diagFeature.AddFunctionType(model.FunctionTypeDeviceDiagnosisHeartbeatData, true, false) + entity.AddFeature(diagFeature) + } + + for i := 1; i <= 2; i++ { + ski := fmt.Sprintf("remote-%d", i) + _ = device.SetupRemoteDevice(ski, writer) + } + + stabilize() + goroutinesRunning := countGoroutines() + assert.Greater(s.T(), goroutinesRunning, goroutinesBefore, + "heartbeat goroutines should be running") + + // Single Close() call cleans up everything + device.Close() + + stabilize() + goroutinesAfter := countGoroutines() + + assert.LessOrEqual(s.T(), goroutinesAfter, goroutinesBefore+1, + "all goroutines should be cleaned up after Close()") + assert.Empty(s.T(), device.RemoteDevices(), + "all remote devices should be removed after Close()") + + // Only entity[0] should remain + entities := device.Entities() + assert.Equal(s.T(), 1, len(entities), + "only DeviceInformation entity should remain after Close()") +}