package webrtc import ( "sync" "testing" "github.com/pion/rtp" ) // makePacket returns a minimal *rtp.Packet with the given payload bytes. func makePacket(payload []byte) *rtp.Packet { return &rtp.Packet{Payload: payload} } // --- isH264IDRStart --- func TestIsH264IDRStart_Empty(t *testing.T) { if isH264IDRStart(makePacket(nil)) { t.Error("empty payload should not be IDR") } if isH264IDRStart(makePacket([]byte{})) { t.Error("zero-length payload should not be IDR") } } func TestIsH264IDRStart_SingleNAL_IDR(t *testing.T) { // NAL type 5 = IDR slice. Forbidden zero bit + NRI can be anything. p := makePacket([]byte{0x65, 0xb8, 0x00}) // 0x65 = 0110_0101 → type=5 if !isH264IDRStart(p) { t.Error("single NAL type 5 should be detected as IDR start") } } func TestIsH264IDRStart_SingleNAL_NonIDR(t *testing.T) { tests := []struct { name string payload []byte }{ {"SPS (type 7)", []byte{0x67, 0x42, 0x00}}, {"PPS (type 8)", []byte{0x68, 0xce, 0x38}}, {"P-frame (type 1)", []byte{0x41, 0x9a}}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { nalType := tt.payload[0] & 0x1F if isH264IDRStart(makePacket(tt.payload)) { t.Errorf("NAL type %d should not be IDR start", nalType) } }) } } func TestIsH264IDRStart_FUA_Start_IDR(t *testing.T) { // FU-A: header byte NAL type = 28 (0x1C), FU header start bit set (0x80), inner type = 5 p := makePacket([]byte{0x7c, 0x85, 0x00, 0x00}) // 0x7c = type 28; 0x85 = start|IDR if !isH264IDRStart(p) { t.Error("FU-A with start bit + inner type 5 should be IDR start") } } func TestIsH264IDRStart_FUA_Start_NonIDR(t *testing.T) { // FU-A start, but inner NAL type = 1 (P-frame fragment) p := makePacket([]byte{0x7c, 0x81}) // start bit set, inner type = 1 if isH264IDRStart(p) { t.Error("FU-A P-frame start should not be IDR") } } func TestIsH264IDRStart_FUA_Continuation(t *testing.T) { // FU-A continuation: start bit NOT set, even if inner type byte = 5 p := makePacket([]byte{0x7c, 0x05}) // 0x05 & 0x80 == 0 — no start bit if isH264IDRStart(p) { t.Error("FU-A continuation should not be IDR start") } } func TestIsH264IDRStart_FUA_TruncatedPayload(t *testing.T) { // FU-A with only 1 byte — no FU header byte present p := makePacket([]byte{0x7c}) if isH264IDRStart(p) { t.Error("truncated FU-A (1 byte) should not panic or return true") } } func TestIsH264IDRStart_STAPA_LeadingIDR(t *testing.T) { // STAP-A (type 24): byte 0 = NAL header (0x78 = type 24), // bytes 1-2 = first NAL size (big-endian), byte 3 = first NAL header. // First NAL type = 5 (IDR) → should be detected. p := makePacket([]byte{ 0x78, // STAP-A header: NRI=3, type=24 0x00, 0x03, // first NAL size = 3 bytes 0x65, 0x88, 0x84, // first NAL: type 5 (IDR slice) 0x00, 0x02, // second NAL size = 2 bytes (SPS, doesn't matter) 0x67, 0x42, // second NAL: SPS }) if !isH264IDRStart(p) { t.Error("STAP-A with leading IDR NAL (type 5) should be detected as IDR start") } } func TestIsH264IDRStart_STAPA_LeadingNonIDR(t *testing.T) { // STAP-A where the first NAL is SPS (type 7), not IDR. // Common pattern: encoders bundle SPS+PPS+IDR in separate STAP-A, // then IDR in a single NAL or FU-A. This STAP-A should not trigger reset. p := makePacket([]byte{ 0x78, // STAP-A header 0x00, 0x03, // first NAL size = 3 0x67, 0x42, 0x00, // first NAL: SPS (type 7) }) if isH264IDRStart(p) { t.Error("STAP-A with leading SPS (type 7) should not be IDR start") } } func TestIsH264IDRStart_STAPA_Truncated(t *testing.T) { // STAP-A with fewer than 4 bytes — cannot safely read first NAL header. tests := []struct { name string payload []byte }{ {"1 byte (header only)", []byte{0x78}}, {"2 bytes (header + 1 size byte)", []byte{0x78, 0x00}}, {"3 bytes (header + size, no NAL)", []byte{0x78, 0x00, 0x01}}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if isH264IDRStart(makePacket(tt.payload)) { t.Errorf("truncated STAP-A (%d bytes) should not panic or return true", len(tt.payload)) } }) } } func TestIsH264IDRStart_Opus(t *testing.T) { // Opus RTP payload starts with a TOC byte — definitely not H.264 p := makePacket([]byte{0xf8, 0xff, 0xfe}) if isH264IDRStart(p) { t.Error("Opus payload should not be detected as IDR") } } // --- keyFrameCache push / snapshot --- func TestKeyFrameCache_EmptySnapshot(t *testing.T) { c := newKeyFrameCache() if snap := c.snapshot(); snap != nil { t.Errorf("expected nil snapshot from empty cache, got %d packets", len(snap)) } } func TestKeyFrameCache_IDRResetsPreviousBurst(t *testing.T) { c := newKeyFrameCache() // Push some non-IDR packets first. for i := 0; i < 5; i++ { c.push(makePacket([]byte{0x41, byte(i)})) } // Now push an IDR start — cache should reset to just this packet. idrPkt := makePacket([]byte{0x65, 0x88, 0x84}) c.push(idrPkt) snap := c.snapshot() if len(snap) != 1 { t.Errorf("expected exactly 1 packet after IDR reset, got %d", len(snap)) } if snap[0] != idrPkt { t.Error("snapshot should contain the IDR packet itself") } } func TestKeyFrameCache_BurstAccumulation(t *testing.T) { c := newKeyFrameCache() idr := makePacket([]byte{0x65, 0x88, 0x84}) c.push(idr) // Push 9 more packets (P-frames) for i := 0; i < 9; i++ { c.push(makePacket([]byte{0x41, byte(i)})) } snap := c.snapshot() if len(snap) != 10 { t.Errorf("expected 10 packets in burst, got %d", len(snap)) } } func TestKeyFrameCache_SecondIDRResetsAgain(t *testing.T) { c := newKeyFrameCache() c.push(makePacket([]byte{0x65, 0x01})) // first IDR for i := 0; i < 4; i++ { c.push(makePacket([]byte{0x41, byte(i)})) } second := makePacket([]byte{0x65, 0x02}) // second IDR — resets burst c.push(second) snap := c.snapshot() if len(snap) != 1 { t.Errorf("second IDR should reset burst to 1 packet, got %d", len(snap)) } if snap[0] != second { t.Error("snapshot should contain the second IDR packet") } } func TestKeyFrameCache_STAPA_IDR_ResetsCache(t *testing.T) { // Verify that a STAP-A with a leading IDR NAL correctly resets the burst, // just like a single-NAL IDR packet does. c := newKeyFrameCache() // Pre-load some P-frames. for i := 0; i < 5; i++ { c.push(makePacket([]byte{0x41, byte(i)})) } stapA := makePacket([]byte{ 0x78, // STAP-A header 0x00, 0x03, // first NAL size = 3 0x65, 0x88, 0x84, // first NAL: IDR (type 5) }) c.push(stapA) snap := c.snapshot() if len(snap) != 1 { t.Errorf("STAP-A IDR should reset burst to 1 packet, got %d", len(snap)) } if snap[0] != stapA { t.Error("snapshot should contain the STAP-A IDR packet") } } func TestKeyFrameCache_MaxPacketsCap(t *testing.T) { c := newKeyFrameCache() c.maxPackets = 5 c.maxBytes = 1 << 20 // generous byte cap idr := makePacket([]byte{0x65, 0x88}) c.push(idr) for i := 0; i < 20; i++ { c.push(makePacket([]byte{0x41, byte(i)})) } snap := c.snapshot() if len(snap) != 5 { t.Errorf("cache should stop at maxPackets=5, got %d", len(snap)) } } func TestKeyFrameCache_MaxBytesCap(t *testing.T) { c := newKeyFrameCache() c.maxPackets = 512 c.maxBytes = 10 // tiny — only 2 packets of 4 bytes each fit (8 ≤ 10 < 12) idr := makePacket([]byte{0x65, 0x88, 0x84, 0x21}) // 4 bytes payload c.push(idr) c.push(makePacket([]byte{0x41, 0x9a, 0xab, 0xcd})) // 4 bytes → total 8 ≤ 10 ✓ c.push(makePacket([]byte{0x41, 0x01, 0x02, 0x03})) // 4 bytes → would be 12 > 10 ✗ snap := c.snapshot() if len(snap) != 2 { t.Errorf("expected 2 packets within byte cap, got %d", len(snap)) } } func TestKeyFrameCache_SnapshotIsACopy(t *testing.T) { c := newKeyFrameCache() c.push(makePacket([]byte{0x65, 0x88})) c.push(makePacket([]byte{0x41, 0x01})) snap1 := c.snapshot() // Push another IDR — clears the internal slice. c.push(makePacket([]byte{0x65, 0x99})) // snap1 should still hold the original 2 packets. if len(snap1) != 2 { t.Errorf("snapshot should be independent of later cache mutations, got %d", len(snap1)) } } func TestKeyFrameCache_ConcurrentSnapshotAndPush(t *testing.T) { c := newKeyFrameCache() var wg sync.WaitGroup // Single writer: 1000 pushes alternating IDR / P-frame. wg.Add(1) go func() { defer wg.Done() for i := 0; i < 1000; i++ { if i%50 == 0 { c.push(makePacket([]byte{0x65, byte(i)})) } else { c.push(makePacket([]byte{0x41, byte(i)})) } } }() // Multiple readers: concurrent snapshots — must not data-race or panic. for r := 0; r < 4; r++ { wg.Add(1) go func() { defer wg.Done() for i := 0; i < 250; i++ { _ = c.snapshot() } }() } wg.Wait() }