diff --git a/device/device_test.go b/device/device_test.go index 0091e2052..c75a07a4f 100644 --- a/device/device_test.go +++ b/device/device_test.go @@ -446,6 +446,8 @@ func (t *fakeTUNDeviceSized) Name() (string, error) { ret func (t *fakeTUNDeviceSized) Events() <-chan tun.Event { return nil } func (t *fakeTUNDeviceSized) Close() error { return nil } func (t *fakeTUNDeviceSized) BatchSize() int { return t.size } +func (t *fakeTUNDeviceSized) MinOffset() int { return 0 } +func (t *fakeTUNDeviceSized) SetCarrier(bool) error { return nil } func TestBatchSize(t *testing.T) { d := Device{} diff --git a/tun/checksum.go b/tun/checksum.go index b489c56f5..7c11d3658 100644 --- a/tun/checksum.go +++ b/tun/checksum.go @@ -5,8 +5,34 @@ import ( "math/bits" ) -// TODO: Explore SIMD and/or other assembly optimizations. -func checksumNoFold(b []byte, initial uint64) uint64 { +// IP protocol constants +const ( + ProtocolICMP4 = 1 + ProtocolTCP = 6 + ProtocolUDP = 17 + ProtocolICMP6 = 58 +) + +const ( + IPv4SrcAddrOffset = 12 + IPv6SrcAddrOffset = 8 +) + +var ( + // PseudoHeaderProtocolTCP TCP protocol field of the TCP pseudoheader + PseudoHeaderProtocolTCP = []byte{0, ProtocolTCP} + // PseudoHeaderProtocolUDP UDP protocol field of the UDP pseudoheader + PseudoHeaderProtocolUDP = []byte{0, ProtocolUDP} + // PseudoHeaderProtocolMap provides dispatch for IP protocol to the corresponding protocol pseudo-header field + PseudoHeaderProtocolMap = map[uint8][]byte{ + ProtocolTCP: PseudoHeaderProtocolTCP, + ProtocolUDP: PseudoHeaderProtocolUDP, + } +) + +// ChecksumNoFold performs intermediate checksum computation per RFC 1071 +func ChecksumNoFold(b []byte, initial uint64) uint64 { + // TODO: Explore SIMD and/or other assembly optimizations. tmp := make([]byte, 8) binary.NativeEndian.PutUint64(tmp, initial) ac := binary.BigEndian.Uint64(tmp) @@ -83,8 +109,9 @@ func checksumNoFold(b []byte, initial uint64) uint64 { return binary.BigEndian.Uint64(tmp) } -func checksum(b []byte, initial uint64) uint16 { - ac := checksumNoFold(b, initial) +// Checksum performs final checksum computation per RFC 1071 +func Checksum(b []byte, initial uint64) uint16 { + ac := ChecksumNoFold(b, initial) ac = (ac >> 16) + (ac & 0xffff) ac = (ac >> 16) + (ac & 0xffff) ac = (ac >> 16) + (ac & 0xffff) @@ -92,11 +119,66 @@ func checksum(b []byte, initial uint64) uint16 { return uint16(ac) } -func pseudoHeaderChecksumNoFold(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint64 { - sum := checksumNoFold(srcAddr, 0) - sum = checksumNoFold(dstAddr, sum) - sum = checksumNoFold([]byte{0, protocol}, sum) - tmp := make([]byte, 2) - binary.BigEndian.PutUint16(tmp, totalLen) - return checksumNoFold(tmp, sum) +// PseudoHeaderChecksumNoFold performs intermediate checksum computation for TCP/UDP pseudoheader values +func PseudoHeaderChecksumNoFold(protocol, srcDstAddr, totalLen []byte) uint64 { + sum := ChecksumNoFold(srcDstAddr, 0) + sum = ChecksumNoFold(protocol, sum) + return ChecksumNoFold(totalLen, sum) +} + +// ComputeIPChecksum updates IP and TCP/UDP checksums +func ComputeIPChecksum(pkt []byte) { + ComputeIPChecksumBuffer(pkt, false) +} + +// ComputeIPChecksumBuffer updates IP and TCP/UDP checksums using the provided length buffer of size 2 +func ComputeIPChecksumBuffer(pkt []byte, partial bool) { + var ( + lenbuf [2]byte + addrsum uint64 + protocol uint8 + headerLen int + totalLen uint16 + ) + + if pkt[0]>>4 == 4 { + pkt[10], pkt[11] = 0, 0 // clear IP header checksum + protocol = pkt[9] + ihl := pkt[0] & 0xF + headerLen = int(ihl * 4) + totalLen = binary.BigEndian.Uint16(pkt[2:]) + addrsum = ChecksumNoFold(pkt[IPv4SrcAddrOffset:IPv4SrcAddrOffset+8], 0) + binary.BigEndian.PutUint16(pkt[10:], ^Checksum(pkt[:IPv4SrcAddrOffset], addrsum)) + } else { + protocol = pkt[6] + headerLen = 40 + totalLen = 40 + binary.BigEndian.Uint16(pkt[4:]) + addrsum = ChecksumNoFold(pkt[IPv6SrcAddrOffset:IPv6SrcAddrOffset+32], 0) + } + + switch protocol { + case ProtocolTCP: + pkt[headerLen+16], pkt[headerLen+17] = 0, 0 + binary.BigEndian.PutUint16(lenbuf[:], totalLen-uint16(headerLen)) + tcpCSum := ChecksumNoFold(PseudoHeaderProtocolTCP, addrsum) + tcpCSum = ChecksumNoFold(lenbuf[:], tcpCSum) + if partial { + binary.BigEndian.PutUint16(pkt[headerLen+16:], Checksum([]byte{}, tcpCSum)) + } else { + binary.BigEndian.PutUint16(pkt[headerLen+16:], ^Checksum(pkt[headerLen:totalLen], tcpCSum)) + } + case ProtocolUDP: + pkt[headerLen+6], pkt[headerLen+7] = 0, 0 + binary.BigEndian.PutUint16(lenbuf[:], totalLen-uint16(headerLen)) + udpCSum := ChecksumNoFold(PseudoHeaderProtocolUDP, addrsum) + udpCSum = ChecksumNoFold(lenbuf[:], udpCSum) + if partial { + binary.BigEndian.PutUint16(pkt[headerLen+6:], Checksum([]byte{}, udpCSum)) + } else { + binary.BigEndian.PutUint16(pkt[headerLen+6:], ^Checksum(pkt[headerLen:totalLen], udpCSum)) + } + case ProtocolICMP4, ProtocolICMP6: + pkt[headerLen+2], pkt[headerLen+3] = 0, 0 + binary.BigEndian.PutUint16(pkt[headerLen+2:], ^Checksum(pkt[headerLen:totalLen], 0)) + } } diff --git a/tun/checksum_test.go b/tun/checksum_test.go index 4ea9b8b52..b398cb5f5 100644 --- a/tun/checksum_test.go +++ b/tun/checksum_test.go @@ -40,7 +40,7 @@ func TestChecksum(t *testing.T) { buf := make([]byte, length) rng := rand.New(rand.NewSource(1)) rng.Read(buf) - csum := checksum(buf, 0x1234) + csum := Checksum(buf, 0x1234) csumRef := checksumRef(buf, 0x1234) if csum != csumRef { t.Error("Expected checksum", csumRef, "got", csum) @@ -49,18 +49,20 @@ func TestChecksum(t *testing.T) { } func TestPseudoHeaderChecksum(t *testing.T) { + lenbuf := make([]byte, 2) + for _, addrLen := range []int{4, 16} { for length := 0; length <= 9001; length++ { - srcAddr := make([]byte, addrLen) - dstAddr := make([]byte, addrLen) - buf := make([]byte, length) + srcDstAddr := make([]byte, addrLen*2) rng := rand.New(rand.NewSource(1)) - rng.Read(srcAddr) - rng.Read(dstAddr) + rng.Read(srcDstAddr) + rng.Read(srcDstAddr[addrLen:]) + buf := make([]byte, length) rng.Read(buf) - phSum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(length)) - csum := checksum(buf, phSum) - phSumRef := pseudoHeaderChecksumRefNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(length)) + binary.BigEndian.PutUint16(lenbuf, uint16(length)) + phSum := PseudoHeaderChecksumNoFold(PseudoHeaderProtocolTCP, srcDstAddr, lenbuf) + csum := Checksum(buf, phSum) + phSumRef := pseudoHeaderChecksumRefNoFold(unix.IPPROTO_TCP, srcDstAddr[:addrLen], srcDstAddr[addrLen:], uint16(length)) csumRef := checksumRef(buf, phSumRef) if csum != csumRef { t.Error("Expected checksumRef", csumRef, "got", csum) @@ -91,7 +93,7 @@ func BenchmarkChecksum(b *testing.B) { rng.Read(buf) b.ResetTimer() for i := 0; i < b.N; i++ { - checksum(buf, 0) + Checksum(buf, 0) } }) } diff --git a/tun/netstack/tun.go b/tun/netstack/tun.go index a7aec9e82..fdd1db4a4 100644 --- a/tun/netstack/tun.go +++ b/tun/netstack/tun.go @@ -191,6 +191,14 @@ func (tun *netTun) BatchSize() int { return 1 } +func (tun *netTun) MinOffset() int { + return 0 +} + +func (tun *netTun) SetCarrier(bool) error { + return nil +} + func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) { var protoNumber tcpip.NetworkProtocolNumber if endpoint.Addr().Is4() { diff --git a/tun/offload_linux.go b/tun/offload_linux.go index 5f0db062c..acdfd3e2c 100644 --- a/tun/offload_linux.go +++ b/tun/offload_linux.go @@ -10,6 +10,7 @@ import ( "encoding/binary" "errors" "io" + "math" "unsafe" "golang.org/x/sys/unix" @@ -385,16 +386,19 @@ func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet return coalesceUnavailable } -func checksumValid(pkt []byte, iphLen, proto uint8, isV6 bool) bool { - srcAddrAt := ipv4SrcAddrOffset +func ChecksumValid(pkt []byte, iphLen, proto uint8, isV6 bool) bool { + var lenBuf [2]byte + + srcAddrAt := IPv4SrcAddrOffset addrSize := 4 if isV6 { - srcAddrAt = ipv6SrcAddrOffset + srcAddrAt = IPv6SrcAddrOffset addrSize = 16 } lenForPseudo := uint16(len(pkt) - int(iphLen)) - cSum := pseudoHeaderChecksumNoFold(proto, pkt[srcAddrAt:srcAddrAt+addrSize], pkt[srcAddrAt+addrSize:srcAddrAt+addrSize*2], lenForPseudo) - return ^checksum(pkt[iphLen:], cSum) == 0 + binary.BigEndian.PutUint16(lenBuf[:], lenForPseudo) + cSum := PseudoHeaderChecksumNoFold(PseudoHeaderProtocolMap[proto], pkt[srcAddrAt:srcAddrAt+addrSize*2], lenBuf[:]) + return ^Checksum(pkt[iphLen:], cSum) == 0 } // coalesceResult represents the result of attempting to coalesce two TCP @@ -422,11 +426,11 @@ func coalesceUDPPackets(pkt []byte, item *udpGROItem, bufs [][]byte, bufsOffset return coalesceInsufficientCap } if item.numMerged == 0 { - if item.cSumKnownInvalid || !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_UDP, isV6) { + if item.cSumKnownInvalid || !ChecksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_UDP, isV6) { return coalesceItemInvalidCSum } } - if !checksumValid(pkt, item.iphLen, unix.IPPROTO_UDP, isV6) { + if !ChecksumValid(pkt, item.iphLen, unix.IPPROTO_UDP, isV6) { return coalescePktInvalidCSum } extendBy := len(pkt) - int(headersLen) @@ -458,11 +462,11 @@ func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize return coalescePSHEnding } if item.numMerged == 0 { - if !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_TCP, isV6) { + if !ChecksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_TCP, isV6) { return coalesceItemInvalidCSum } } - if !checksumValid(pkt, item.iphLen, unix.IPPROTO_TCP, isV6) { + if !ChecksumValid(pkt, item.iphLen, unix.IPPROTO_TCP, isV6) { return coalescePktInvalidCSum } item.sentSeq = seq @@ -480,11 +484,11 @@ func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize return coalesceInsufficientCap } if item.numMerged == 0 { - if !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_TCP, isV6) { + if !ChecksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_TCP, isV6) { return coalesceItemInvalidCSum } } - if !checksumValid(pkt, item.iphLen, unix.IPPROTO_TCP, isV6) { + if !ChecksumValid(pkt, item.iphLen, unix.IPPROTO_TCP, isV6) { return coalescePktInvalidCSum } if pshSet { @@ -509,12 +513,6 @@ const ( ipv4FlagMoreFragments uint8 = 0x20 ) -const ( - ipv4SrcAddrOffset = 12 - ipv6SrcAddrOffset = 8 - maxUint16 = 1<<16 - 1 -) - type groResult int const ( @@ -530,7 +528,7 @@ const ( // coalesced with another packet in table. func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) groResult { pkt := bufs[pktI][offset:] - if len(pkt) > maxUint16 { + if len(pkt) > math.MaxUint16 { // A valid IPv4 or IPv6 packet will never exceed this. return groResultNoop } @@ -578,10 +576,10 @@ func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) return groResultNoop } seq := binary.BigEndian.Uint32(pkt[iphLen+4:]) - srcAddrOffset := ipv4SrcAddrOffset + srcAddrOffset := IPv4SrcAddrOffset addrLen := 4 if isV6 { - srcAddrOffset = ipv6SrcAddrOffset + srcAddrOffset = IPv6SrcAddrOffset addrLen = 16 } items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI) @@ -622,7 +620,7 @@ func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) // applyTCPCoalesceAccounting updates bufs to account for coalescing based on the // metadata found in table. -func applyTCPCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable) error { +func applyTCPCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable, forceCSum bool) error { for _, items := range table.itemsByFlow { for _, item := range items { if item.numMerged > 0 { @@ -642,10 +640,7 @@ func applyTCPCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable) e binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len } else { hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV4 - pkt[10], pkt[11] = 0, 0 binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length - iphCSum := ^checksum(pkt[:item.iphLen], 0) // compute IPv4 header checksum - binary.BigEndian.PutUint16(pkt[10:], iphCSum) // set IPv4 header checksum field } err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:]) if err != nil { @@ -655,18 +650,12 @@ func applyTCPCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable) e // Calculate the pseudo header checksum and place it at the TCP // checksum offset. Downstream checksum offloading will combine // this with computation of the tcp header and payload checksum. - addrLen := 4 - addrOffset := ipv4SrcAddrOffset - if item.key.isV6 { - addrLen = 16 - addrOffset = ipv6SrcAddrOffset - } - srcAddrAt := offset + addrOffset - srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen] - dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2] - psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen))) - binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], checksum([]byte{}, psum)) + ComputeIPChecksumBuffer(pkt, true) } else { + if forceCSum { + ComputeIPChecksumBuffer(bufs[item.bufsIndex][offset:], false) + } + hdr := virtioNetHdr{} err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:]) if err != nil { @@ -680,7 +669,7 @@ func applyTCPCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable) e // applyUDPCoalesceAccounting updates bufs to account for coalescing based on the // metadata found in table. -func applyUDPCoalesceAccounting(bufs [][]byte, offset int, table *udpGROTable) error { +func applyUDPCoalesceAccounting(bufs [][]byte, offset int, table *udpGROTable, forceCSum bool) error { for _, items := range table.itemsByFlow { for _, item := range items { if item.numMerged > 0 { @@ -699,10 +688,7 @@ func applyUDPCoalesceAccounting(bufs [][]byte, offset int, table *udpGROTable) e if item.key.isV6 { binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len } else { - pkt[10], pkt[11] = 0, 0 binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length - iphCSum := ^checksum(pkt[:item.iphLen], 0) // compute IPv4 header checksum - binary.BigEndian.PutUint16(pkt[10:], iphCSum) // set IPv4 header checksum field } err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:]) if err != nil { @@ -715,18 +701,12 @@ func applyUDPCoalesceAccounting(bufs [][]byte, offset int, table *udpGROTable) e // Calculate the pseudo header checksum and place it at the UDP // checksum offset. Downstream checksum offloading will combine // this with computation of the udp header and payload checksum. - addrLen := 4 - addrOffset := ipv4SrcAddrOffset - if item.key.isV6 { - addrLen = 16 - addrOffset = ipv6SrcAddrOffset - } - srcAddrAt := offset + addrOffset - srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen] - dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2] - psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_UDP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen))) - binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], checksum([]byte{}, psum)) + ComputeIPChecksumBuffer(bufs[item.bufsIndex][offset:], true) } else { + if forceCSum { + ComputeIPChecksumBuffer(bufs[item.bufsIndex][offset:], false) + } + hdr := virtioNetHdr{} err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:]) if err != nil { @@ -785,7 +765,7 @@ const ( // coalesced with another packet in table. func udpGRO(bufs [][]byte, offset int, pktI int, table *udpGROTable, isV6 bool) groResult { pkt := bufs[pktI][offset:] - if len(pkt) > maxUint16 { + if len(pkt) > math.MaxUint16 { // A valid IPv4 or IPv6 packet will never exceed this. return groResultNoop } @@ -819,10 +799,10 @@ func udpGRO(bufs [][]byte, offset int, pktI int, table *udpGROTable, isV6 bool) if gsoSize < 1 { return groResultNoop } - srcAddrOffset := ipv4SrcAddrOffset + srcAddrOffset := IPv4SrcAddrOffset addrLen := 4 if isV6 { - srcAddrOffset = ipv6SrcAddrOffset + srcAddrOffset = IPv6SrcAddrOffset addrLen = 16 } items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, pktI) @@ -862,7 +842,7 @@ func udpGRO(bufs [][]byte, offset int, pktI int, table *udpGROTable, isV6 bool) // empty (but non-nil), and are passed in to save allocs as the caller may reset // and recycle them across vectors of packets. canUDPGRO indicates if UDP GRO is // supported. -func handleGRO(bufs [][]byte, offset int, tcpTable *tcpGROTable, udpTable *udpGROTable, canUDPGRO bool, toWrite *[]int) error { +func handleGRO(bufs [][]byte, offset int, tcpTable *tcpGROTable, udpTable *udpGROTable, canUDPGRO, forceCSum bool, toWrite *[]int) error { for i := range bufs { if offset < virtioNetHdrLen || offset > len(bufs[i])-1 { return errors.New("invalid offset") @@ -878,8 +858,13 @@ func handleGRO(bufs [][]byte, offset int, tcpTable *tcpGROTable, udpTable *udpGR case udp6GROCandidate: result = udpGRO(bufs, offset, i, udpTable, true) } + switch result { case groResultNoop: + if forceCSum { + ComputeIPChecksumBuffer(bufs[i][offset:], false) + } + hdr := virtioNetHdr{} err := hdr.encode(bufs[i][offset-virtioNetHdrLen:]) if err != nil { @@ -890,8 +875,8 @@ func handleGRO(bufs [][]byte, offset int, tcpTable *tcpGROTable, udpTable *udpGR *toWrite = append(*toWrite, i) } } - errTCP := applyTCPCoalesceAccounting(bufs, offset, tcpTable) - errUDP := applyUDPCoalesceAccounting(bufs, offset, udpTable) + errTCP := applyTCPCoalesceAccounting(bufs, offset, tcpTable, forceCSum) + errUDP := applyUDPCoalesceAccounting(bufs, offset, udpTable, forceCSum) return errors.Join(errTCP, errUDP) } @@ -900,22 +885,26 @@ func handleGRO(bufs [][]byte, offset int, tcpTable *tcpGROTable, udpTable *udpGR // error. func gsoSplit(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOffset int, isV6 bool) (int, error) { iphLen := int(hdr.csumStart) - srcAddrOffset := ipv6SrcAddrOffset + srcAddrOffset := IPv6SrcAddrOffset addrLen := 16 if !isV6 { in[10], in[11] = 0, 0 // clear ipv4 header checksum - srcAddrOffset = ipv4SrcAddrOffset + srcAddrOffset = IPv4SrcAddrOffset addrLen = 4 } transportCsumAt := int(hdr.csumStart + hdr.csumOffset) in[transportCsumAt], in[transportCsumAt+1] = 0, 0 // clear tcp/udp checksum var firstTCPSeqNum uint32 var protocol uint8 + var protocolfield []byte + var lenBuf [2]byte if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 || hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV6 { protocol = unix.IPPROTO_TCP + protocolfield = PseudoHeaderProtocolTCP firstTCPSeqNum = binary.BigEndian.Uint32(in[hdr.csumStart+4:]) } else { protocol = unix.IPPROTO_UDP + protocolfield = PseudoHeaderProtocolUDP } nextSegmentDataAt := int(hdr.hdrLen) i := 0 @@ -943,7 +932,7 @@ func gsoSplit(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOf binary.BigEndian.PutUint16(out[4:], id) } binary.BigEndian.PutUint16(out[2:], uint16(totalLen)) - ipv4CSum := ^checksum(out[:iphLen], 0) + ipv4CSum := ^Checksum(out[:iphLen], 0) binary.BigEndian.PutUint16(out[10:], ipv4CSum) } else { // For IPv6 we are responsible for updating the payload length field. @@ -972,9 +961,9 @@ func gsoSplit(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOf // transport checksum transportHeaderLen := int(hdr.hdrLen - hdr.csumStart) - lenForPseudo := uint16(transportHeaderLen + segmentDataLen) - transportCSumNoFold := pseudoHeaderChecksumNoFold(protocol, in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], lenForPseudo) - transportCSum := ^checksum(out[hdr.csumStart:totalLen], transportCSumNoFold) + binary.BigEndian.PutUint16(lenBuf[:], uint16(transportHeaderLen+segmentDataLen)) + transportCSumNoFold := PseudoHeaderChecksumNoFold(protocolfield, in[srcAddrOffset:srcAddrOffset+addrLen*2], lenBuf[:]) + transportCSum := ^Checksum(out[hdr.csumStart:totalLen], transportCSumNoFold) binary.BigEndian.PutUint16(out[hdr.csumStart+hdr.csumOffset:], transportCSum) nextSegmentDataAt += int(hdr.gsoSize) @@ -988,6 +977,6 @@ func gsoNoneChecksum(in []byte, cSumStart, cSumOffset uint16) error { // checksum we compute. This is typically the pseudo-header checksum. initial := binary.BigEndian.Uint16(in[cSumAt:]) in[cSumAt], in[cSumAt+1] = 0, 0 - binary.BigEndian.PutUint16(in[cSumAt:], ^checksum(in[cSumStart:], uint64(initial))) + binary.BigEndian.PutUint16(in[cSumAt:], ^Checksum(in[cSumStart:], uint64(initial))) return nil } diff --git a/tun/offload_linux_test.go b/tun/offload_linux_test.go index d87e63612..f4504c389 100644 --- a/tun/offload_linux_test.go +++ b/tun/offload_linux_test.go @@ -6,6 +6,7 @@ package tun import ( + "fmt" "net/netip" "testing" @@ -236,6 +237,7 @@ func Test_handleVirtioRead(t *testing.T) { t.Run(tt.name, func(t *testing.T) { out := make([][]byte, conn.IdealBatchSize) sizes := make([]int, conn.IdealBatchSize) + for i := range out { out[i] = make([]byte, 65535) } @@ -288,244 +290,248 @@ func Fuzz_handleGRO(f *testing.F) { pkt11 := udp6Packet(ip6PortA, ip6PortC, 100) f.Add(pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11, true, offset) f.Fuzz(func(t *testing.T, pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11 []byte, canUDPGRO bool, offset int) { - pkts := [][]byte{pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11} - toWrite := make([]int, 0, len(pkts)) - handleGRO(pkts, offset, newTCPGROTable(), newUDPGROTable(), canUDPGRO, &toWrite) - if len(toWrite) > len(pkts) { - t.Errorf("len(toWrite): %d > len(pkts): %d", len(toWrite), len(pkts)) - } - seenWriteI := make(map[int]bool) - for _, writeI := range toWrite { - if writeI < 0 || writeI > len(pkts)-1 { - t.Errorf("toWrite value (%d) outside bounds of len(pkts): %d", writeI, len(pkts)) + for _, forceCSum := range []bool{false, true} { + pkts := [][]byte{pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11} + toWrite := make([]int, 0, len(pkts)) + handleGRO(pkts, offset, newTCPGROTable(), newUDPGROTable(), canUDPGRO, forceCSum, &toWrite) + if len(toWrite) > len(pkts) { + t.Errorf("len(toWrite): %d > len(pkts): %d", len(toWrite), len(pkts)) } - if seenWriteI[writeI] { - t.Errorf("duplicate toWrite value: %d", writeI) + seenWriteI := make(map[int]bool) + for _, writeI := range toWrite { + if writeI < 0 || writeI > len(pkts)-1 { + t.Errorf("toWrite value (%d) outside bounds of len(pkts): %d", writeI, len(pkts)) + } + if seenWriteI[writeI] { + t.Errorf("duplicate toWrite value: %d", writeI) + } + seenWriteI[writeI] = true } - seenWriteI[writeI] = true } }) } func Test_handleGRO(t *testing.T) { - tests := []struct { - name string - pktsIn [][]byte - canUDPGRO bool - wantToWrite []int - wantLens []int - wantErr bool - }{ - { - "multiple protocols and flows", - [][]byte{ - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // tcp4 flow 1 - udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1 - udp4Packet(ip4PortA, ip4PortC, 100), // udp4 flow 2 - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // tcp4 flow 1 - tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201), // tcp4 flow 2 - tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // tcp6 flow 1 - tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101), // tcp6 flow 1 - tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201), // tcp6 flow 2 - udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1 - udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1 - udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1 + for _, forceCSum := range []bool{true, false} { + tests := []struct { + name string + pktsIn [][]byte + canUDPGRO bool + wantToWrite []int + wantLens []int + wantErr bool + }{ + { + "multiple protocols and flows", + [][]byte{ + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // tcp4 flow 1 + udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1 + udp4Packet(ip4PortA, ip4PortC, 100), // udp4 flow 2 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // tcp4 flow 1 + tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201), // tcp4 flow 2 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // tcp6 flow 1 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101), // tcp6 flow 1 + tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201), // tcp6 flow 2 + udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1 + udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1 + udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1 + }, + true, + []int{0, 1, 2, 4, 5, 7, 9}, + []int{240, 228, 128, 140, 260, 160, 248}, + false, }, - true, - []int{0, 1, 2, 4, 5, 7, 9}, - []int{240, 228, 128, 140, 260, 160, 248}, - false, - }, - { - "multiple protocols and flows no UDP GRO", - [][]byte{ - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // tcp4 flow 1 - udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1 - udp4Packet(ip4PortA, ip4PortC, 100), // udp4 flow 2 - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // tcp4 flow 1 - tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201), // tcp4 flow 2 - tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // tcp6 flow 1 - tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101), // tcp6 flow 1 - tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201), // tcp6 flow 2 - udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1 - udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1 - udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1 + { + "multiple protocols and flows no UDP GRO", + [][]byte{ + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // tcp4 flow 1 + udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1 + udp4Packet(ip4PortA, ip4PortC, 100), // udp4 flow 2 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // tcp4 flow 1 + tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201), // tcp4 flow 2 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // tcp6 flow 1 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101), // tcp6 flow 1 + tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201), // tcp6 flow 2 + udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1 + udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1 + udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1 + }, + false, + []int{0, 1, 2, 4, 5, 7, 8, 9, 10}, + []int{240, 128, 128, 140, 260, 160, 128, 148, 148}, + false, }, - false, - []int{0, 1, 2, 4, 5, 7, 8, 9, 10}, - []int{240, 128, 128, 140, 260, 160, 128, 148, 148}, - false, - }, - { - "PSH interleaved", - [][]byte{ - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1 - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v4 flow 1 - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 301), // v4 flow 1 - tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // v6 flow 1 - tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v6 flow 1 - tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 201), // v6 flow 1 - tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 301), // v6 flow 1 + { + "PSH interleaved", + [][]byte{ + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v4 flow 1 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 301), // v4 flow 1 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // v6 flow 1 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v6 flow 1 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 201), // v6 flow 1 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 301), // v6 flow 1 + }, + true, + []int{0, 2, 4, 6}, + []int{240, 240, 260, 260}, + false, }, - true, - []int{0, 2, 4, 6}, - []int{240, 240, 260, 260}, - false, - }, - { - "coalesceItemInvalidCSum", - [][]byte{ - flipTCP4Checksum(tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)), // v4 flow 1 seq 1 len 100 - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100 - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100 - flipUDP4Checksum(udp4Packet(ip4PortA, ip4PortB, 100)), - udp4Packet(ip4PortA, ip4PortB, 100), - udp4Packet(ip4PortA, ip4PortB, 100), + { + "coalesceItemInvalidCSum", + [][]byte{ + flipTCP4Checksum(tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)), // v4 flow 1 seq 1 len 100 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100 + flipUDP4Checksum(udp4Packet(ip4PortA, ip4PortB, 100)), + udp4Packet(ip4PortA, ip4PortB, 100), + udp4Packet(ip4PortA, ip4PortB, 100), + }, + true, + []int{0, 1, 3, 4}, + []int{140, 240, 128, 228}, + false, }, - true, - []int{0, 1, 3, 4}, - []int{140, 240, 128, 228}, - false, - }, - { - "out of order", - [][]byte{ - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100 - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1 seq 1 len 100 - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100 + { + "out of order", + [][]byte{ + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1 seq 1 len 100 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100 + }, + true, + []int{0}, + []int{340}, + false, }, - true, - []int{0}, - []int{340}, - false, - }, - { - "unequal TTL", - [][]byte{ - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), - tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) { - fields.TTL++ - }), - udp4Packet(ip4PortA, ip4PortB, 100), - udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) { - fields.TTL++ - }), + { + "unequal TTL", + [][]byte{ + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), + tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) { + fields.TTL++ + }), + udp4Packet(ip4PortA, ip4PortB, 100), + udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) { + fields.TTL++ + }), + }, + true, + []int{0, 1, 2, 3}, + []int{140, 140, 128, 128}, + false, }, - true, - []int{0, 1, 2, 3}, - []int{140, 140, 128, 128}, - false, - }, - { - "unequal ToS", - [][]byte{ - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), - tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) { - fields.TOS++ - }), - udp4Packet(ip4PortA, ip4PortB, 100), - udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) { - fields.TOS++ - }), + { + "unequal ToS", + [][]byte{ + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), + tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) { + fields.TOS++ + }), + udp4Packet(ip4PortA, ip4PortB, 100), + udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) { + fields.TOS++ + }), + }, + true, + []int{0, 1, 2, 3}, + []int{140, 140, 128, 128}, + false, }, - true, - []int{0, 1, 2, 3}, - []int{140, 140, 128, 128}, - false, - }, - { - "unequal flags more fragments set", - [][]byte{ - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), - tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) { - fields.Flags = 1 - }), - udp4Packet(ip4PortA, ip4PortB, 100), - udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) { - fields.Flags = 1 - }), + { + "unequal flags more fragments set", + [][]byte{ + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), + tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) { + fields.Flags = 1 + }), + udp4Packet(ip4PortA, ip4PortB, 100), + udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) { + fields.Flags = 1 + }), + }, + true, + []int{0, 1, 2, 3}, + []int{140, 140, 128, 128}, + false, }, - true, - []int{0, 1, 2, 3}, - []int{140, 140, 128, 128}, - false, - }, - { - "unequal flags DF set", - [][]byte{ - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), - tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) { - fields.Flags = 2 - }), - udp4Packet(ip4PortA, ip4PortB, 100), - udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) { - fields.Flags = 2 - }), + { + "unequal flags DF set", + [][]byte{ + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), + tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) { + fields.Flags = 2 + }), + udp4Packet(ip4PortA, ip4PortB, 100), + udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) { + fields.Flags = 2 + }), + }, + true, + []int{0, 1, 2, 3}, + []int{140, 140, 128, 128}, + false, }, - true, - []int{0, 1, 2, 3}, - []int{140, 140, 128, 128}, - false, - }, - { - "ipv6 unequal hop limit", - [][]byte{ - tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), - tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) { - fields.HopLimit++ - }), - udp6Packet(ip6PortA, ip6PortB, 100), - udp6PacketMutateIPFields(ip6PortA, ip6PortB, 100, func(fields *header.IPv6Fields) { - fields.HopLimit++ - }), + { + "ipv6 unequal hop limit", + [][]byte{ + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), + tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) { + fields.HopLimit++ + }), + udp6Packet(ip6PortA, ip6PortB, 100), + udp6PacketMutateIPFields(ip6PortA, ip6PortB, 100, func(fields *header.IPv6Fields) { + fields.HopLimit++ + }), + }, + true, + []int{0, 1, 2, 3}, + []int{160, 160, 148, 148}, + false, }, - true, - []int{0, 1, 2, 3}, - []int{160, 160, 148, 148}, - false, - }, - { - "ipv6 unequal traffic class", - [][]byte{ - tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), - tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) { - fields.TrafficClass++ - }), - udp6Packet(ip6PortA, ip6PortB, 100), - udp6PacketMutateIPFields(ip6PortA, ip6PortB, 100, func(fields *header.IPv6Fields) { - fields.TrafficClass++ - }), + { + "ipv6 unequal traffic class", + [][]byte{ + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), + tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) { + fields.TrafficClass++ + }), + udp6Packet(ip6PortA, ip6PortB, 100), + udp6PacketMutateIPFields(ip6PortA, ip6PortB, 100, func(fields *header.IPv6Fields) { + fields.TrafficClass++ + }), + }, + true, + []int{0, 1, 2, 3}, + []int{160, 160, 148, 148}, + false, }, - true, - []int{0, 1, 2, 3}, - []int{160, 160, 148, 148}, - false, - }, - } + } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - toWrite := make([]int, 0, len(tt.pktsIn)) - err := handleGRO(tt.pktsIn, offset, newTCPGROTable(), newUDPGROTable(), tt.canUDPGRO, &toWrite) - if err != nil { - if tt.wantErr { - return + for _, tt := range tests { + t.Run(fmt.Sprint(tt.name, " force ", forceCSum), func(t *testing.T) { + toWrite := make([]int, 0, len(tt.pktsIn)) + err := handleGRO(tt.pktsIn, offset, newTCPGROTable(), newUDPGROTable(), tt.canUDPGRO, forceCSum, &toWrite) + if err != nil { + if tt.wantErr { + return + } + t.Fatalf("got err: %v", err) } - t.Fatalf("got err: %v", err) - } - if len(toWrite) != len(tt.wantToWrite) { - t.Fatalf("got %d packets, wanted %d", len(toWrite), len(tt.wantToWrite)) - } - for i, pktI := range tt.wantToWrite { - if tt.wantToWrite[i] != toWrite[i] { - t.Fatalf("wantToWrite[%d]: %d != toWrite: %d", i, tt.wantToWrite[i], toWrite[i]) + if len(toWrite) != len(tt.wantToWrite) { + t.Fatalf("got %d packets, wanted %d", len(toWrite), len(tt.wantToWrite)) } - if tt.wantLens[i] != len(tt.pktsIn[pktI][offset:]) { - t.Errorf("wanted len %d packet at %d, got: %d", tt.wantLens[i], i, len(tt.pktsIn[pktI][offset:])) + for i, pktI := range tt.wantToWrite { + if tt.wantToWrite[i] != toWrite[i] { + t.Fatalf("wantToWrite[%d]: %d != toWrite: %d", i, tt.wantToWrite[i], toWrite[i]) + } + if tt.wantLens[i] != len(tt.pktsIn[pktI][offset:]) { + t.Errorf("wanted len %d packet at %d, got: %d", tt.wantLens[i], i, len(tt.pktsIn[pktI][offset:])) + } } - } - }) + }) + } } } diff --git a/tun/tun.go b/tun/tun.go index 336d64225..583012dbe 100644 --- a/tun/tun.go +++ b/tun/tun.go @@ -26,12 +26,14 @@ type Device interface { // packet lengths within the sizes slice. len(sizes) must be >= len(bufs). // A nonzero offset can be used to instruct the Device on where to begin // reading into each element of the bufs slice. + // A value of offset must be equal or greater than indicated by MinOffset() Read(bufs [][]byte, sizes []int, offset int) (n int, err error) // Write one or more packets to the device (without any additional headers). // On a successful write it returns the number of packets written. A nonzero // offset can be used to instruct the Device on where to begin writing from // each packet contained within the bufs slice. + // A value of offset must be equal or greater than indicated by MinOffset() Write(bufs [][]byte, offset int) (int, error) // MTU returns the MTU of the Device. @@ -43,6 +45,9 @@ type Device interface { // Events returns a channel of type Event, which is fed Device events. Events() <-chan Event + // SetCarrier sets carrier indication + SetCarrier(present bool) error + // Close stops the Device and closes the Event channel. Close() error @@ -50,4 +55,7 @@ type Device interface { // written in a single read/write call. BatchSize must not change over the // lifetime of a Device. BatchSize() int + + // MinOffset indicates minimum offset value buffers must use + MinOffset() int } diff --git a/tun/tun_darwin.go b/tun/tun_darwin.go index 341afe3c5..eb33626ac 100644 --- a/tun/tun_darwin.go +++ b/tun/tun_darwin.go @@ -20,14 +20,40 @@ import ( const utunControlName = "com.apple.net.utun_control" type NativeTun struct { - name string - tunFile *os.File - events chan Event - errors chan error - routeSocket int - closeOnce sync.Once + name string + tunFile *os.File + events chan Event + errors chan error + routeSocket int + closeOnce sync.Once + writeForceChecksum bool } +// Option functional option interface +type Option func(tun *NativeTun) error + +// WithOffload API compatability stub, does nothing on this platform +func WithOffload(offload bool) Option { + return func(tun *NativeTun) error { + return nil + } +} + +// WithCarrier API compatability stub, does nothing on this platform +func WithCarrier(carrier bool) Option { + return func(tun *NativeTun) error { + return nil + } +} + +// WithWriteForceChecksum force checksum computation for sent packets, default false +func WithWriteForceChecksum(forceChecksum bool) Option { + return func(tun *NativeTun) error { + tun.writeForceChecksum = forceChecksum + + return nil + } +} func (tun *NativeTun) routineRouteListener(tunIfindex int) { var ( statusUp bool @@ -82,7 +108,7 @@ func (tun *NativeTun) routineRouteListener(tunIfindex int) { } } -func CreateTUN(name string, mtu int) (Device, error) { +func CreateTUN(name string, mtu int, options ...Option) (Device, error) { ifIndex := -1 if name != "utun" { _, err := fmt.Sscanf(name, "utun%d", &ifIndex) @@ -120,7 +146,7 @@ func CreateTUN(name string, mtu int) (Device, error) { unix.Close(fd) return nil, err } - tun, err := CreateTUNFromFile(os.NewFile(uintptr(fd), ""), mtu) + tun, err := CreateTUNFromFile(os.NewFile(uintptr(fd), ""), mtu, options...) if err == nil && name == "utun" { fname := os.Getenv("WG_TUN_NAME_FILE") @@ -132,13 +158,20 @@ func CreateTUN(name string, mtu int) (Device, error) { return tun, err } -func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { +func CreateTUNFromFile(file *os.File, mtu int, options ...Option) (Device, error) { tun := &NativeTun{ tunFile: file, events: make(chan Event, 10), errors: make(chan error, 5), } + for _, opt := range options { + err := opt(tun) + if err != nil { + return nil, err + } + } + name, err := tun.Name() if err != nil { tun.tunFile.Close() @@ -224,6 +257,10 @@ func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) { return 0, io.ErrShortBuffer } for i, buf := range bufs { + if tun.writeForceChecksum { + ComputeIPChecksumBuffer(bufs[i][offset:], false) + } + buf = buf[offset-4:] buf[0] = 0x00 buf[1] = 0x00 @@ -303,10 +340,18 @@ func (tun *NativeTun) MTU() (int, error) { return int(ifr.MTU), nil } +func (tun *NativeTun) SetCarrier(carrier bool) (err error) { + return nil +} + func (tun *NativeTun) BatchSize() int { return 1 } +func (tun *NativeTun) MinOffset() int { + return 4 +} + func socketCloexec(family, sotype, proto int) (fd int, err error) { // See go/src/net/sys_cloexec.go for background. syscall.ForkLock.RLock() diff --git a/tun/tun_freebsd.go b/tun/tun_freebsd.go index 4adf3a165..fbfbfeac8 100644 --- a/tun/tun_freebsd.go +++ b/tun/tun_freebsd.go @@ -68,12 +68,39 @@ type nd6Req struct { } type NativeTun struct { - name string - tunFile *os.File - events chan Event - errors chan error - routeSocket int - closeOnce sync.Once + name string + tunFile *os.File + events chan Event + errors chan error + routeSocket int + closeOnce sync.Once + writeForceChecksum bool +} + +// Option functional option interface +type Option func(tun *NativeTun) error + +// WithOffload API compatability stub, does nothing on this platform +func WithOffload(offload bool) Option { + return func(tun *NativeTun) error { + return nil + } +} + +// WithCarrier API compatability stub, does nothing on this platform +func WithCarrier(carrier bool) Option { + return func(tun *NativeTun) error { + return nil + } +} + +// WithWriteForceChecksum force checksum computation for sent packets, default false +func WithWriteForceChecksum(forceChecksum bool) Option { + return func(tun *NativeTun) error { + tun.writeForceChecksum = forceChecksum + + return nil + } } func (tun *NativeTun) routineRouteListener(tunIfindex int) { @@ -159,7 +186,7 @@ func tunDestroy(name string) error { return nil } -func CreateTUN(name string, mtu int) (Device, error) { +func CreateTUN(name string, mtu int, options ...Option) (Device, error) { if len(name) > unix.IFNAMSIZ-1 { return nil, errors.New("interface name too long") } @@ -258,16 +285,23 @@ func CreateTUN(name string, mtu int) (Device, error) { } } - return CreateTUNFromFile(tunFile, mtu) + return CreateTUNFromFile(tunFile, mtu, options...) } -func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { +func CreateTUNFromFile(file *os.File, mtu int, options ...Option) (Device, error) { tun := &NativeTun{ tunFile: file, events: make(chan Event, 10), errors: make(chan error, 1), } + for _, opt := range options { + err := opt(tun) + if err != nil { + return nil, err + } + } + var errno syscall.Errno tun.operateOnFd(func(fd uintptr) { _, _, errno = unix.Syscall(unix.SYS_IOCTL, fd, _TUNSIFPID, uintptr(0)) @@ -353,6 +387,10 @@ func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) { return 0, io.ErrShortBuffer } for i, buf := range bufs { + if tun.writeForceChecksum { + ComputeIPChecksumBuffer(bufs[i][offset:], false) + } + buf = buf[offset-4:] if len(buf) < 5 { return i, io.ErrShortBuffer @@ -430,6 +468,14 @@ func (tun *NativeTun) MTU() (int, error) { return int(*(*int32)(unsafe.Pointer(&ifr.MTU))), nil } +func (tun *NativeTun) SetCarrier(carrier bool) (err error) { + return nil +} + func (tun *NativeTun) BatchSize() int { return 1 } + +func (tun *NativeTun) MinOffset() int { + return 4 +} diff --git a/tun/tun_linux.go b/tun/tun_linux.go index 1461e068d..60d3dcb90 100644 --- a/tun/tun_linux.go +++ b/tun/tun_linux.go @@ -49,10 +49,46 @@ type NativeTun struct { readOpMu sync.Mutex // readOpMu guards readBuff readBuff [virtioNetHdrLen + 65535]byte // if vnetHdr every read() is prefixed by virtioNetHdr - writeOpMu sync.Mutex // writeOpMu guards toWrite, tcpGROTable - toWrite []int - tcpGROTable *tcpGROTable - udpGROTable *udpGROTable + writeOpMu sync.Mutex // writeOpMu guards toWrite, tcpGROTable + toWrite []int + tcpGROTable *tcpGROTable + udpGROTable *udpGROTable + carrier *bool + offload bool + writeForceChecksum bool +} + +// Option functional option interface +type Option func(tun *NativeTun) error + +// WithOffload set whether offloads should be used if available, default true +func WithOffload(offload bool) Option { + return func(tun *NativeTun) error { + tun.offload = offload + + return nil + } +} + +// WithCarrier set initial carrier state. +// If provided and false, interface is created without carrier. +// If provided and true, interface is marked as having carrier right after creation. +// If not provided (default), interface is created with unknown carrier state. +func WithCarrier(carrier bool) Option { + return func(tun *NativeTun) error { + tun.carrier = &carrier + + return nil + } +} + +// WithWriteForceChecksum force checksum computation for sent packets, default false +func WithWriteForceChecksum(forceChecksum bool) Option { + return func(tun *NativeTun) error { + tun.writeForceChecksum = forceChecksum + + return nil + } } func (tun *NativeTun) File() *os.File { @@ -339,19 +375,26 @@ func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) { tun.udpGROTable.reset() tun.writeOpMu.Unlock() }() + var ( errs error total int ) + tun.toWrite = tun.toWrite[:0] + if tun.vnetHdr { - err := handleGRO(bufs, offset, tun.tcpGROTable, tun.udpGROTable, tun.udpGSO, &tun.toWrite) + err := handleGRO(bufs, offset, tun.tcpGROTable, tun.udpGROTable, tun.udpGSO, tun.writeForceChecksum, &tun.toWrite) if err != nil { return 0, err } offset -= virtioNetHdrLen } else { for i := range bufs { + if tun.writeForceChecksum { + ComputeIPChecksumBuffer(bufs[i][offset:], false) + } + tun.toWrite = append(tun.toWrite, i) } } @@ -498,10 +541,40 @@ func (tun *NativeTun) Close() error { return err2 } +func (tun *NativeTun) SetCarrier(carrier bool) (err error) { + sys, err := tun.File().SyscallConn() + if err != nil { + return err + } + + var value int + + if carrier { + value = 1 + } + + cerr := sys.Control(func(fd uintptr) { + err = unix.IoctlSetPointerInt(int(fd), unix.TUNSETCARRIER, value) + }) + if err != nil { + return + } + + return cerr +} + func (tun *NativeTun) BatchSize() int { return tun.batchSize } +func (tun *NativeTun) MinOffset() int { + if tun.vnetHdr { + return virtioNetHdrLen + } + + return 0 +} + const ( // TODO: support TSO with ECN bits tunTCPOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6 @@ -526,7 +599,7 @@ func (tun *NativeTun) initFromFlags(name string) error { return } got := ifr.Uint16() - if got&unix.IFF_VNET_HDR != 0 { + if tun.offload && got&unix.IFF_VNET_HDR != 0 { // tunTCPOffloads were added in Linux v2.6. We require their support // if IFF_VNET_HDR is set. err = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunTCPOffloads) @@ -541,6 +614,13 @@ func (tun *NativeTun) initFromFlags(name string) error { } else { tun.batchSize = 1 } + + if tun.carrier != nil && *tun.carrier { + err = unix.IoctlSetPointerInt(int(fd), unix.TUNSETCARRIER, 1) + if err != nil { + return + } + } }); e != nil { return e } @@ -548,7 +628,7 @@ func (tun *NativeTun) initFromFlags(name string) error { } // CreateTUN creates a Device with the provided name and MTU. -func CreateTUN(name string, mtu int) (Device, error) { +func CreateTUN(name string, mtu int, options ...Option) (Device, error) { nfd, err := unix.Open(cloneDevicePath, unix.O_RDWR|unix.O_CLOEXEC, 0) if err != nil { if os.IsNotExist(err) { @@ -561,9 +641,21 @@ func CreateTUN(name string, mtu int) (Device, error) { if err != nil { return nil, err } + + tun, err := createTUN(options...) + if err != nil { + return nil, err + } + + flags := uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_VNET_HDR) + + if tun.carrier != nil && !*tun.carrier { + flags |= unix.IFF_NO_CARRIER + } + // IFF_VNET_HDR enables the "tun status hack" via routineHackListener() // where a null write will return EINVAL indicating the TUN is up. - ifr.SetUint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_VNET_HDR) + ifr.SetUint16(flags) err = unix.IoctlIfreq(nfd, unix.TUNSETIFF, ifr) if err != nil { return nil, err @@ -578,21 +670,43 @@ func CreateTUN(name string, mtu int) (Device, error) { // Note that the above -- open,ioctl,nonblock -- must happen prior to handing it to netpoll as below this line. fd := os.NewFile(uintptr(nfd), cloneDevicePath) - return CreateTUNFromFile(fd, mtu) + return CreateTUNFromFile(fd, mtu, options...) } // CreateTUNFromFile creates a Device from an os.File with the provided MTU. -func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { +func CreateTUNFromFile(file *os.File, mtu int, options ...Option) (Device, error) { + tun, err := createTUN(options...) + if err != nil { + return nil, err + } + + return populateTUN(file, tun, mtu) +} + +func createTUN(options ...Option) (*NativeTun, error) { tun := &NativeTun{ - tunFile: file, events: make(chan Event, 5), errors: make(chan error, 5), statusListenersShutdown: make(chan struct{}), tcpGROTable: newTCPGROTable(), udpGROTable: newUDPGROTable(), toWrite: make([]int, 0, conn.IdealBatchSize), + offload: true, } + for _, opt := range options { + err := opt(tun) + if err != nil { + return nil, err + } + } + + return tun, nil +} + +func populateTUN(file *os.File, tun *NativeTun, mtu int) (Device, error) { + tun.tunFile = file + name, err := tun.Name() if err != nil { return nil, err @@ -634,7 +748,7 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { // CreateUnmonitoredTUNFromFD creates a Device from the provided file // descriptor. -func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) { +func CreateUnmonitoredTUNFromFD(fd int, options ...Option) (Device, string, error) { err := unix.SetNonblock(fd, true) if err != nil { return nil, "", err @@ -647,7 +761,16 @@ func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) { tcpGROTable: newTCPGROTable(), udpGROTable: newUDPGROTable(), toWrite: make([]int, 0, conn.IdealBatchSize), + offload: true, + } + + for _, opt := range options { + err = opt(tun) + if err != nil { + return nil, "", err + } } + name, err := tun.Name() if err != nil { return nil, "", err diff --git a/tun/tun_openbsd.go b/tun/tun_openbsd.go index 5aa90705e..8e107a463 100644 --- a/tun/tun_openbsd.go +++ b/tun/tun_openbsd.go @@ -28,12 +28,39 @@ type ifreq_mtu struct { const _TUNSIFMODE = 0x8004745d type NativeTun struct { - name string - tunFile *os.File - events chan Event - errors chan error - routeSocket int - closeOnce sync.Once + name string + tunFile *os.File + events chan Event + errors chan error + routeSocket int + closeOnce sync.Once + writeForceChecksum bool +} + +// Option functional option interface +type Option func(tun *NativeTun) error + +// WithOffload API compatability stub, does nothing on this platform +func WithOffload(offload bool) Option { + return func(tun *NativeTun) error { + return nil + } +} + +// WithCarrier API compatability stub, does nothing on this platform +func WithCarrier(carrier bool) Option { + return func(tun *NativeTun) error { + return nil + } +} + +// WithWriteForceChecksum force checksum computation for sent packets, default false +func WithWriteForceChecksum(forceChecksum bool) Option { + return func(tun *NativeTun) error { + tun.writeForceChecksum = forceChecksum + + return nil + } } func (tun *NativeTun) routineRouteListener(tunIfindex int) { @@ -101,7 +128,7 @@ func (tun *NativeTun) routineRouteListener(tunIfindex int) { } } -func CreateTUN(name string, mtu int) (Device, error) { +func CreateTUN(name string, mtu int, options ...Option) (Device, error) { ifIndex := -1 if name != "tun" { _, err := fmt.Sscanf(name, "tun%d", &ifIndex) @@ -128,7 +155,7 @@ func CreateTUN(name string, mtu int) (Device, error) { return nil, err } - tun, err := CreateTUNFromFile(tunfile, mtu) + tun, err := CreateTUNFromFile(tunfile, mtu, options...) if err == nil && name == "tun" { fname := os.Getenv("WG_TUN_NAME_FILE") @@ -140,13 +167,21 @@ func CreateTUN(name string, mtu int) (Device, error) { return tun, err } -func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { +func CreateTUNFromFile(file *os.File, mtu int, options ...Option) (Device, error) { tun := &NativeTun{ tunFile: file, events: make(chan Event, 10), errors: make(chan error, 1), } + for _, opt := range options { + err := opt(tun) + if err != nil { + tun.tunFile.Close() + return nil, err + } + } + name, err := tun.Name() if err != nil { tun.tunFile.Close() @@ -224,6 +259,9 @@ func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) { return 0, io.ErrShortBuffer } for i, buf := range bufs { + if tun.writeForceChecksum { + ComputeIPChecksumBuffer(bufs[i][offset:], false) + } buf = buf[offset-4:] buf[0] = 0x00 buf[1] = 0x00 @@ -328,6 +366,14 @@ func (tun *NativeTun) MTU() (int, error) { return int(*(*int32)(unsafe.Pointer(&ifr.MTU))), nil } +func (tun *NativeTun) SetCarrier(carrier bool) (err error) { + return nil +} + func (tun *NativeTun) BatchSize() int { return 1 } + +func (tun *NativeTun) MinOffset() int { + return 4 +} diff --git a/tun/tun_windows.go b/tun/tun_windows.go index de65fb446..9e334639b 100644 --- a/tun/tun_windows.go +++ b/tun/tun_windows.go @@ -32,18 +32,45 @@ type rateJuggler struct { } type NativeTun struct { - wt *wintun.Adapter - name string - handle windows.Handle - rate rateJuggler - session wintun.Session - readWait windows.Handle - events chan Event - running sync.WaitGroup - closeOnce sync.Once - close atomic.Bool - forcedMTU int - outSizes []int + wt *wintun.Adapter + name string + handle windows.Handle + rate rateJuggler + session wintun.Session + readWait windows.Handle + events chan Event + running sync.WaitGroup + closeOnce sync.Once + close atomic.Bool + forcedMTU int + outSizes []int + writeForceChecksum bool +} + +// Option functional option interface +type Option func(tun *NativeTun) error + +// WithOffload API compatability stub, does nothing on this platform +func WithOffload(offload bool) Option { + return func(tun *NativeTun) error { + return nil + } +} + +// WithCarrier API compatability stub, does nothing on this platform +func WithCarrier(carrier bool) Option { + return func(tun *NativeTun) error { + return nil + } +} + +// WithWriteForceChecksum force checksum computation for sent packets, default false +func WithWriteForceChecksum(forceChecksum bool) Option { + return func(tun *NativeTun) error { + tun.writeForceChecksum = forceChecksum + + return nil + } } var ( @@ -59,13 +86,13 @@ func nanotime() int64 // CreateTUN creates a Wintun interface with the given name. Should a Wintun // interface with the same name exist, it is reused. -func CreateTUN(ifname string, mtu int) (Device, error) { - return CreateTUNWithRequestedGUID(ifname, WintunStaticRequestedGUID, mtu) +func CreateTUN(ifname string, mtu int, options ...Option) (Device, error) { + return CreateTUNWithRequestedGUID(ifname, WintunStaticRequestedGUID, mtu, options...) } // CreateTUNWithRequestedGUID creates a Wintun interface with the given name and // a requested GUID. Should a Wintun interface with the same name exist, it is reused. -func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int) (Device, error) { +func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int, options ...Option) (Device, error) { wt, err := wintun.CreateAdapter(ifname, WintunTunnelType, requestedGUID) if err != nil { return nil, fmt.Errorf("Error creating interface: %w", err) @@ -84,6 +111,15 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu forcedMTU: forcedMTU, } + for _, opt := range options { + err = opt(tun) + if err != nil { + tun.wt.Close() + close(tun.events) + return nil, fmt.Errorf("Error configuring interface: %w", err) + } + } + tun.session, err = wt.StartSession(0x800000) // Ring capacity, 8 MiB if err != nil { tun.wt.Close() @@ -137,11 +173,19 @@ func (tun *NativeTun) ForceMTU(mtu int) { } } +func (tun *NativeTun) SetCarrier(carrier bool) (err error) { + return nil +} + func (tun *NativeTun) BatchSize() int { // TODO: implement batching with wintun return 1 } +func (tun *NativeTun) MinOffset() int { + return 0 +} + // Note: Read() and Write() assume the caller comes only from a single thread; there's no locking. func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) { @@ -189,6 +233,10 @@ func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) { } for i, buf := range bufs { + if tun.writeForceChecksum { + ComputeIPChecksumBuffer(bufs[i][offset:], false) + } + packetSize := len(buf) - offset tun.rate.update(uint64(packetSize)) diff --git a/tun/tuntest/tuntest.go b/tun/tuntest/tuntest.go index 9c4564f26..f5b62c6a7 100644 --- a/tun/tuntest/tuntest.go +++ b/tun/tuntest/tuntest.go @@ -153,3 +153,5 @@ func (t *chTun) Close() error { t.Write(nil, -1) return nil } +func (t *chTun) MinOffset() int { return 0 } +func (t *chTun) SetCarrier(bool) error { return nil }