Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 51 additions & 7 deletions spine/device_local.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,11 @@ var _ api.DeviceLocalInterface = (*DeviceLocal)(nil)

// Setup a new remote device with a given SKI and triggers SPINE requesting device details
func (r *DeviceLocal) SetupRemoteDevice(ski string, writeI shipapi.ShipConnectionDataWriterInterface) shipapi.ShipConnectionDataReaderInterface {
// Clean up any existing device for this SKI (e.g., fast reconnect)
if existing := r.RemoteDeviceForSki(ski); existing != nil {
r.RemoveRemoteDevice(ski)
}

sender := NewSender(writeI)
rDevice := NewDeviceRemote(r, ski, sender)

Expand Down Expand Up @@ -286,6 +291,45 @@ func (r *DeviceLocal) RemoveEntity(entity api.EntityLocalInterface) {
r.notifySubscribersOfEntity(entity, model.NetworkManagementStateChangeTypeRemoved)
}

// Close shuts down the DeviceLocal, stopping all goroutines and cleaning up all state.
// It removes all remote device connections, waits for pending event handlers,
// removes all application entities (stopping heartbeats), and unsubscribes from events.
// Safe to call multiple times.
func (r *DeviceLocal) Close() {
	// Collect the SKIs of all currently connected remote devices under the lock.
	r.mux.Lock()
	remoteSkis := make([]string, 0, len(r.remoteDevices))
	for remoteSki := range r.remoteDevices {
		remoteSkis = append(remoteSkis, remoteSki)
	}
	r.mux.Unlock()

	// Tear down every remote connection (cleans subscriptions, bindings,
	// and write approval timers).
	for _, remoteSki := range remoteSkis {
		r.RemoveRemoteDeviceConnection(remoteSki)
	}

	// Block until every in-flight disconnect event handler has returned.
	r.events.drain()

	// Copy the entity list so we don't hold the lock while removing entities.
	r.mux.Lock()
	entitySnapshot := append([]api.EntityLocalInterface(nil), r.entities...)
	r.mux.Unlock()

	// Remove every application entity (which stops its heartbeat); the
	// DeviceInformation entity stays in place.
	for _, ent := range entitySnapshot {
		if address := ent.Address().Entity; len(address) > 0 &&
			address[0] != model.AddressEntityType(DeviceInformationEntityId) {
			r.RemoveEntity(ent)
		}
	}

	// Detach from core events; the error is deliberately ignored so a
	// repeated Close stays a no-op.
	_ = r.events.unsubscribe(api.EventHandlerLevelCore, r)
}

func (r *DeviceLocal) Entities() []api.EntityLocalInterface {
r.mux.Lock()
defer r.mux.Unlock()
Expand Down Expand Up @@ -346,28 +390,28 @@ func (r *DeviceLocal) ProcessCmd(datagram model.DatagramType, remoteDevice api.D
// Validate cmd.function consistency when filters are present
// Per SPINE spec section 5.3.4: "SHALL be present if datagram.payload.cmd.filter is present."
// The primary security concern is type confusion attacks when filters target wrong functions

filterPartial, filterDelete := cmd.ExtractFilter()
hasFilters := filterPartial != nil || filterDelete != nil

if hasFilters {
// Filters present: cmd.Function MUST be present and consistent
// This is the critical validation to prevent type confusion attacks
if err := cmd.ValidateFunctionConsistencyStrict(); err != nil {
inconsistencies := cmd.GetInconsistentFunctions()
errorMsg := fmt.Sprintf("cmd function validation failed: %s", err.Error())

// Log validation failure for security monitoring (non-sensitive info only)
logging.Log().Debugf("Command function validation failed: %s (inconsistencies: %d, device: %s, classifier: %v)",
err.Error(),
logging.Log().Debugf("Command function validation failed: %s (inconsistencies: %d, device: %s, classifier: %v)",
err.Error(),
len(inconsistencies),
remoteDevice.Address(),
cmdClassifier)

// Send proper error response to remote device
validationError := model.NewErrorType(model.ErrorNumberTypeCommandRejected, errorMsg)
_ = remoteDevice.Sender().ResultError(&datagram.Header, destAddr, validationError)

return fmt.Errorf("cmd function validation failed: %w", err)
}
}
Expand Down
93 changes: 93 additions & 0 deletions spine/device_local_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package spine

import (
"runtime"
"testing"
"time"

Expand Down Expand Up @@ -426,3 +427,95 @@ func (d *DeviceLocalTestSuite) Test_ProcessCmd() {
err = sut.ProcessCmd(datagram, remote)
assert.NotNil(d.T(), err)
}

// Step 6 — Issue 9: SetupRemoteDevice should clean up existing device for same SKI.
func (d *DeviceLocalTestSuite) Test_SetupRemoteDevice_CleansUpExisting() {
	localDevice := NewDeviceLocal("brand", "model", "serial", "code", "address",
		model.DeviceTypeTypeEnergyManagementSystem, model.NetworkManagementFeatureSetTypeSmart)

	const testSki = "reconnect-ski"

	// First connection for the SKI.
	_ = localDevice.SetupRemoteDevice(testSki, d)
	first := localDevice.RemoteDeviceForSki(testSki)
	assert.NotNil(d.T(), first)

	// Reconnecting with the same SKI must replace the earlier device.
	_ = localDevice.SetupRemoteDevice(testSki, d)
	second := localDevice.RemoteDeviceForSki(testSki)
	assert.NotNil(d.T(), second)

	// The replacement must be a distinct device instance.
	assert.NotEqual(d.T(), first, second)

	// Exactly one registered remote device may carry this SKI.
	matches := 0
	for _, remote := range localDevice.RemoteDevices() {
		if remote.Ski() == testSki {
			matches++
		}
	}
	assert.Equal(d.T(), 1, matches)

	localDevice.RemoveRemoteDeviceConnection(testSki)
}

// Step 8 — Issue 1: Close() stops all goroutines and cleans up all state.
func (d *DeviceLocalTestSuite) Test_Close_StopsAllGoroutines() {
	// Settle the runtime before taking the goroutine baseline.
	runtime.GC()
	time.Sleep(50 * time.Millisecond)
	before := runtime.NumGoroutine()

	localDevice := NewDeviceLocal("brand", "model", "serial", "code", "address",
		model.DeviceTypeTypeEnergyManagementSystem, model.NetworkManagementFeatureSetTypeSmart)

	// Add three heartbeat-capable entities, each with a DeviceDiagnosis server feature.
	for id := uint(1); id <= 3; id++ {
		ent := NewEntityLocal(localDevice, model.EntityTypeTypeCEM,
			[]model.AddressEntityType{model.AddressEntityType(id)}, 4*time.Second)
		localDevice.AddEntity(ent)

		feature := NewFeatureLocal(ent.NextFeatureId(), ent,
			model.FeatureTypeTypeDeviceDiagnosis, model.RoleTypeServer)
		feature.AddFunctionType(model.FunctionTypeDeviceDiagnosisHeartbeatData, true, false)
		ent.AddFeature(feature)
	}

	// Attach two remote device connections.
	_ = localDevice.SetupRemoteDevice("remote-1", d)
	_ = localDevice.SetupRemoteDevice("remote-2", d)

	// Give the background goroutines a moment to spin up.
	time.Sleep(50 * time.Millisecond)
	during := runtime.NumGoroutine()
	assert.Greater(d.T(), during, before, "heartbeat goroutines should be running")

	// Close must tear everything down.
	localDevice.Close()

	// Allow the stopped goroutines to exit before re-counting.
	time.Sleep(50 * time.Millisecond)
	after := runtime.NumGoroutine()
	assert.LessOrEqual(d.T(), after, before+1,
		"all goroutines should be stopped after Close()")

	// No remote devices may survive Close().
	assert.Empty(d.T(), localDevice.RemoteDevices())

	// Only the DeviceInformation entity should remain.
	remaining := localDevice.Entities()
	assert.Equal(d.T(), 1, len(remaining))
}

// Step 8 — Close() is idempotent.
func (d *DeviceLocalTestSuite) Test_Close_Idempotent() {
	localDevice := NewDeviceLocal("brand", "model", "serial", "code", "address",
		model.DeviceTypeTypeEnergyManagementSystem, model.NetworkManagementFeatureSetTypeSmart)

	heartbeatEntity := NewEntityLocal(localDevice, model.EntityTypeTypeCEM,
		[]model.AddressEntityType{1}, 4*time.Second)
	localDevice.AddEntity(heartbeatEntity)

	_ = localDevice.SetupRemoteDevice("ski-1", d)

	// Calling Close twice must not panic.
	localDevice.Close()
	localDevice.Close()
}
13 changes: 12 additions & 1 deletion spine/events.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ type eventHandlerItem struct {
// events dispatches published SPINE events to registered handlers.
type events struct {
	// NOTE(review): presumably guards handler (un)subscription — the
	// subscribe path is not visible here, confirm before relying on it.
	mu sync.Mutex
	// held while iterating handlers during Publish
	muHandle sync.Mutex
	// tracks goroutines spawned for application-level handlers in Publish;
	// waited on by drain() during shutdown
	wg sync.WaitGroup

	handlers []eventHandlerItem // event handling outside of the core stack
}
Expand Down Expand Up @@ -105,9 +106,19 @@ func (r *events) Publish(payload api.EventPayload) {
// and expected actions are taken
item.Handler.HandleEvent(payload)
} else {
go item.Handler.HandleEvent(payload)
r.wg.Add(1)
go func(h api.EventHandlerInterface) {
defer r.wg.Done()
h.HandleEvent(payload)
}(item.Handler)
}
}
}
r.muHandle.Unlock()
}

// drain waits for all dispatched application-level event handlers to complete.
// Used during shutdown to ensure no handlers are still running.
// NOTE(review): assumes no concurrent Publish can add new handler goroutines
// while draining — wg.Add racing with wg.Wait is not allowed by sync.WaitGroup;
// confirm callers stop publishing before calling drain.
func (r *events) drain() {
	r.wg.Wait()
}
37 changes: 37 additions & 0 deletions spine/events_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,40 @@ func (s *EventsTestSuite) Test_Publish_Application() {
err = s.events.Unsubscribe(s)
assert.Nil(s.T(), err)
}

// Step 7 — Issue 6: drain() must wait for all dispatched handlers to complete.
// slowTestHandler is an event handler that deliberately takes a long time,
// recording when it has run to completion.
type slowTestHandler struct {
	mu   sync.Mutex
	done bool
}

// HandleEvent simulates a slow application handler: sleep, then mark completion.
func (h *slowTestHandler) HandleEvent(event api.EventPayload) {
	time.Sleep(200 * time.Millisecond)
	h.mu.Lock()
	defer h.mu.Unlock()
	h.done = true
}

// isFinished reports whether HandleEvent has run to completion.
func (h *slowTestHandler) isFinished() bool {
	h.mu.Lock()
	defer h.mu.Unlock()
	return h.done
}

func (s *EventsTestSuite) Test_Drain_WaitsForHandlers() {
	slow := &slowTestHandler{}
	assert.Nil(s.T(), s.events.Subscribe(slow))

	s.events.Publish(api.EventPayload{})

	// The handler sleeps 200ms, so it cannot have completed yet.
	assert.False(s.T(), slow.isFinished())

	// drain() must block until the dispatched handler has returned.
	s.events.drain()
	assert.True(s.T(), slow.isFinished())

	assert.Nil(s.T(), s.events.Unsubscribe(slow))
}
33 changes: 22 additions & 11 deletions spine/feature_local.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@

writeTimeout time.Duration
writeApprovalCallbacks []api.WriteApprovalCallbackFunc
muxWriteReceived sync.Mutex
writeApprovalReceived map[string]map[model.MsgCounterType]int
pendingWriteApprovals map[string]map[model.MsgCounterType]*time.Timer

Expand Down Expand Up @@ -82,7 +81,7 @@
}
}
// Partial reads are intentionally not supported (spec-compliant design decision)
// SPINE specification section 5.3.4.5 states: "A server MAY ignore unsupported cmdOption

Check failure on line 84 in spine/feature_local.go

View workflow job for this annotation

GitHub Actions / Build

File is not properly formatted (gofmt)
// combinations and then replies with more than the requested parts instead."
// By setting readPartial to false, we ensure all read requests return full data,
// which provides the safest interoperability behavior for multi-vendor scenarios.
Expand Down Expand Up @@ -193,6 +192,10 @@

newTimer := time.AfterFunc(r.writeTimeout, func() {
r.muxResponseCB.Lock()
if _, ok := r.pendingWriteApprovals[ski]; !ok {
r.muxResponseCB.Unlock()
return
}
delete(r.pendingWriteApprovals[ski], *msg.RequestHeader.MsgCounter)
r.muxResponseCB.Unlock()

Expand All @@ -217,18 +220,17 @@
ski := msg.DeviceRemote.Ski()

r.muxResponseCB.Lock()
timer, ok := r.pendingWriteApprovals[ski][*msg.RequestHeader.MsgCounter]
count := len(r.writeApprovalCallbacks)
r.muxResponseCB.Unlock()

// if there is no timer running, we are too late and error has already been sent
timer, ok := r.pendingWriteApprovals[ski][*msg.RequestHeader.MsgCounter]
// if there is no timer, we are too late and error has already been sent
if !ok || timer == nil {
r.muxResponseCB.Unlock()
return
}

count := len(r.writeApprovalCallbacks)

// do we have enough approvals?
r.muxWriteReceived.Lock()
defer r.muxWriteReceived.Unlock()
if count > 1 && err.ErrorNumber == 0 {
amount, ok := r.writeApprovalReceived[ski][*msg.RequestHeader.MsgCounter]
if ok {
Expand All @@ -239,18 +241,20 @@
}
// do we have enough approve messages, if not exit
if r.writeApprovalReceived[ski][*msg.RequestHeader.MsgCounter] < count {
r.muxResponseCB.Unlock()
return
}
}

// Atomically stop the timer and clean up entries under the lock.
// This prevents the TOCTOU race where the timer fires between lock acquisitions.
timer.Stop()

delete(r.writeApprovalReceived[ski], *msg.RequestHeader.MsgCounter)

r.muxResponseCB.Lock()
defer r.muxResponseCB.Unlock()
delete(r.pendingWriteApprovals[ski], *msg.RequestHeader.MsgCounter)

r.muxResponseCB.Unlock()

// Process outside the lock to avoid holding it during network I/O
if err.ErrorNumber == 0 {
r.processWrite(msg)
return
Expand All @@ -267,6 +271,13 @@
r.muxResponseCB.Lock()
defer r.muxResponseCB.Unlock()

// Stop all pending timers for this SKI before deleting
if timers, ok := r.pendingWriteApprovals[ski]; ok {
for _, timer := range timers {
timer.Stop()
}
}

delete(r.pendingWriteApprovals, ski)
delete(r.writeApprovalReceived, ski)
}
Expand Down
Loading
Loading