Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions device/device_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,8 @@ func (t *fakeTUNDeviceSized) Name() (string, error) { ret
func (t *fakeTUNDeviceSized) Events() <-chan tun.Event { return nil }
func (t *fakeTUNDeviceSized) Close() error { return nil }
func (t *fakeTUNDeviceSized) BatchSize() int { return t.size }
func (t *fakeTUNDeviceSized) MinOffset() int { return 0 }
func (t *fakeTUNDeviceSized) SetCarrier(bool) error { return nil }

func TestBatchSize(t *testing.T) {
d := Device{}
Expand Down
104 changes: 93 additions & 11 deletions tun/checksum.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,34 @@ import (
"math/bits"
)

// TODO: Explore SIMD and/or other assembly optimizations.
func checksumNoFold(b []byte, initial uint64) uint64 {
// IP protocol constants
const (
ProtocolICMP4 = 1
ProtocolTCP = 6
ProtocolUDP = 17
ProtocolICMP6 = 58
)

const (
IPv4SrcAddrOffset = 12
IPv6SrcAddrOffset = 8
)

var (
// PseudoHeaderProtocolTCP TCP protocol field of the TCP pseudoheader
PseudoHeaderProtocolTCP = []byte{0, ProtocolTCP}
// PseudoHeaderProtocolUDP UDP protocol field of the UDP pseudoheader
PseudoHeaderProtocolUDP = []byte{0, ProtocolUDP}
// PseudoHeaderProtocolMap provides dispatch for IP protocol to the corresponding protocol pseudo-header field
PseudoHeaderProtocolMap = map[uint8][]byte{
ProtocolTCP: PseudoHeaderProtocolTCP,
ProtocolUDP: PseudoHeaderProtocolUDP,
}
)

// ChecksumNoFold performs intermediate checksum computation per RFC 1071
func ChecksumNoFold(b []byte, initial uint64) uint64 {
// TODO: Explore SIMD and/or other assembly optimizations.
tmp := make([]byte, 8)
binary.NativeEndian.PutUint64(tmp, initial)
ac := binary.BigEndian.Uint64(tmp)
Expand Down Expand Up @@ -83,20 +109,76 @@ func checksumNoFold(b []byte, initial uint64) uint64 {
return binary.BigEndian.Uint64(tmp)
}

func checksum(b []byte, initial uint64) uint16 {
ac := checksumNoFold(b, initial)
// Checksum performs final checksum computation per RFC 1071
func Checksum(b []byte, initial uint64) uint16 {
ac := ChecksumNoFold(b, initial)
ac = (ac >> 16) + (ac & 0xffff)
ac = (ac >> 16) + (ac & 0xffff)
ac = (ac >> 16) + (ac & 0xffff)
ac = (ac >> 16) + (ac & 0xffff)
return uint16(ac)
}

func pseudoHeaderChecksumNoFold(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint64 {
sum := checksumNoFold(srcAddr, 0)
sum = checksumNoFold(dstAddr, sum)
sum = checksumNoFold([]byte{0, protocol}, sum)
tmp := make([]byte, 2)
binary.BigEndian.PutUint16(tmp, totalLen)
return checksumNoFold(tmp, sum)
// PseudoHeaderChecksumNoFold performs intermediate checksum computation for TCP/UDP pseudoheader values
func PseudoHeaderChecksumNoFold(protocol, srcDstAddr, totalLen []byte) uint64 {
sum := ChecksumNoFold(srcDstAddr, 0)
sum = ChecksumNoFold(protocol, sum)
return ChecksumNoFold(totalLen, sum)
}

// ComputeIPChecksum updates IP and TCP/UDP checksums
func ComputeIPChecksum(pkt []byte) {
ComputeIPChecksumBuffer(pkt, false)
}

// ComputeIPChecksumBuffer updates IP and TCP/UDP checksums using the provided length buffer of size 2
func ComputeIPChecksumBuffer(pkt []byte, partial bool) {
var (
lenbuf [2]byte
addrsum uint64
protocol uint8
headerLen int
totalLen uint16
)

if pkt[0]>>4 == 4 {
pkt[10], pkt[11] = 0, 0 // clear IP header checksum
protocol = pkt[9]
ihl := pkt[0] & 0xF
headerLen = int(ihl * 4)
totalLen = binary.BigEndian.Uint16(pkt[2:])
addrsum = ChecksumNoFold(pkt[IPv4SrcAddrOffset:IPv4SrcAddrOffset+8], 0)
binary.BigEndian.PutUint16(pkt[10:], ^Checksum(pkt[:IPv4SrcAddrOffset], addrsum))
} else {
protocol = pkt[6]
headerLen = 40
totalLen = 40 + binary.BigEndian.Uint16(pkt[4:])
addrsum = ChecksumNoFold(pkt[IPv6SrcAddrOffset:IPv6SrcAddrOffset+32], 0)
}

switch protocol {
case ProtocolTCP:
pkt[headerLen+16], pkt[headerLen+17] = 0, 0
binary.BigEndian.PutUint16(lenbuf[:], totalLen-uint16(headerLen))
tcpCSum := ChecksumNoFold(PseudoHeaderProtocolTCP, addrsum)
tcpCSum = ChecksumNoFold(lenbuf[:], tcpCSum)
if partial {
binary.BigEndian.PutUint16(pkt[headerLen+16:], Checksum([]byte{}, tcpCSum))
} else {
binary.BigEndian.PutUint16(pkt[headerLen+16:], ^Checksum(pkt[headerLen:totalLen], tcpCSum))
}
case ProtocolUDP:
pkt[headerLen+6], pkt[headerLen+7] = 0, 0
binary.BigEndian.PutUint16(lenbuf[:], totalLen-uint16(headerLen))
udpCSum := ChecksumNoFold(PseudoHeaderProtocolUDP, addrsum)
udpCSum = ChecksumNoFold(lenbuf[:], udpCSum)
if partial {
binary.BigEndian.PutUint16(pkt[headerLen+6:], Checksum([]byte{}, udpCSum))
} else {
binary.BigEndian.PutUint16(pkt[headerLen+6:], ^Checksum(pkt[headerLen:totalLen], udpCSum))
}
case ProtocolICMP4, ProtocolICMP6:
pkt[headerLen+2], pkt[headerLen+3] = 0, 0
binary.BigEndian.PutUint16(pkt[headerLen+2:], ^Checksum(pkt[headerLen:totalLen], 0))
}
}
22 changes: 12 additions & 10 deletions tun/checksum_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func TestChecksum(t *testing.T) {
buf := make([]byte, length)
rng := rand.New(rand.NewSource(1))
rng.Read(buf)
csum := checksum(buf, 0x1234)
csum := Checksum(buf, 0x1234)
csumRef := checksumRef(buf, 0x1234)
if csum != csumRef {
t.Error("Expected checksum", csumRef, "got", csum)
Expand All @@ -49,18 +49,20 @@ func TestChecksum(t *testing.T) {
}

func TestPseudoHeaderChecksum(t *testing.T) {
lenbuf := make([]byte, 2)

for _, addrLen := range []int{4, 16} {
for length := 0; length <= 9001; length++ {
srcAddr := make([]byte, addrLen)
dstAddr := make([]byte, addrLen)
buf := make([]byte, length)
srcDstAddr := make([]byte, addrLen*2)
rng := rand.New(rand.NewSource(1))
rng.Read(srcAddr)
rng.Read(dstAddr)
rng.Read(srcDstAddr)
rng.Read(srcDstAddr[addrLen:])
buf := make([]byte, length)
rng.Read(buf)
phSum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(length))
csum := checksum(buf, phSum)
phSumRef := pseudoHeaderChecksumRefNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(length))
binary.BigEndian.PutUint16(lenbuf, uint16(length))
phSum := PseudoHeaderChecksumNoFold(PseudoHeaderProtocolTCP, srcDstAddr, lenbuf)
csum := Checksum(buf, phSum)
phSumRef := pseudoHeaderChecksumRefNoFold(unix.IPPROTO_TCP, srcDstAddr[:addrLen], srcDstAddr[addrLen:], uint16(length))
csumRef := checksumRef(buf, phSumRef)
if csum != csumRef {
t.Error("Expected checksumRef", csumRef, "got", csum)
Expand Down Expand Up @@ -91,7 +93,7 @@ func BenchmarkChecksum(b *testing.B) {
rng.Read(buf)
b.ResetTimer()
for i := 0; i < b.N; i++ {
checksum(buf, 0)
Checksum(buf, 0)
}
})
}
Expand Down
8 changes: 8 additions & 0 deletions tun/netstack/tun.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,14 @@ func (tun *netTun) BatchSize() int {
return 1
}

func (tun *netTun) MinOffset() int {
return 0
}

func (tun *netTun) SetCarrier(bool) error {
return nil
}

func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) {
var protoNumber tcpip.NetworkProtocolNumber
if endpoint.Addr().Is4() {
Expand Down
Loading