diff --git a/core/webrtc/ice.go b/core/webrtc/ice.go new file mode 100644 index 0000000..a728055 --- /dev/null +++ b/core/webrtc/ice.go @@ -0,0 +1,47 @@ +package webrtc + +import ( + "github.com/pion/webrtc/v4" +) + +// BuildICEConfig translates a Config into the two Pion config pieces every +// PeerConnection needs: a webrtc.Configuration (with ICE servers) and a +// SettingEngine (with NAT1To1 and port range tuning). +// +// The returned *SettingEngine may be nil if no engine-level tuning is +// required (i.e. PublicIP unset and UDPPortRange at defaults). Callers +// should only pass it to webrtc.NewAPI when non-nil. +func BuildICEConfig(c Config) (webrtc.Configuration, *webrtc.SettingEngine, error) { + if err := c.Validate(); err != nil { + return webrtc.Configuration{}, nil, err + } + + rtcConfig := webrtc.Configuration{ + ICEServers: make([]webrtc.ICEServer, 0, len(c.ICEServers)), + } + for _, uri := range c.ICEServers { + rtcConfig.ICEServers = append(rtcConfig.ICEServers, webrtc.ICEServer{ + URLs: []string{uri}, + }) + } + + var se *webrtc.SettingEngine + if c.PublicIP != "" || c.UDPPortRange.Low > 0 { + engine := webrtc.SettingEngine{} + if c.PublicIP != "" { + engine.SetNAT1To1IPs([]string{c.PublicIP}, webrtc.ICECandidateTypeHost) + } + // Constrain the ephemeral UDP range Pion allocates for ICE candidates. + // Note: this is a separate concern from our FFmpeg→Source UDP ports; + // Pion uses its own port pool for the WebRTC media path. + if c.UDPPortRange.Low > 0 && c.UDPPortRange.High >= c.UDPPortRange.Low { + if err := engine.SetEphemeralUDPPortRange( + uint16(c.UDPPortRange.Low), uint16(c.UDPPortRange.High)); err != nil { + return webrtc.Configuration{}, nil, err + } + } + se = &engine + } + + return rtcConfig, se, nil +} diff --git a/core/webrtc/ice_test.go b/core/webrtc/ice_test.go new file mode 100644 index 0000000..99294c1 --- /dev/null +++ b/core/webrtc/ice_test.go @@ -0,0 +1,50 @@ +package webrtc + +import ( + "testing" + + "github.com/pion/webrtc/v4" +) + +func TestBuildICEConfig_Defaults(t *testing.T) { + c := DefaultConfig() + rtcConfig, _, err := BuildICEConfig(c) + if err != nil { + t.Fatalf("BuildICEConfig: %v", err) + } + if len(rtcConfig.ICEServers) == 0 { + t.Error("ICEServers should not be empty") + } + // First default is Cloudflare STUN. + if rtcConfig.ICEServers[0].URLs[0] != "stun:stun.cloudflare.com:3478" { + t.Errorf("first ICE server = %q, want stun:stun.cloudflare.com:3478", + rtcConfig.ICEServers[0].URLs[0]) + } +} + +func TestBuildICEConfig_PublicIP(t *testing.T) { + c := DefaultConfig() + c.PublicIP = "203.0.113.10" + _, se, err := BuildICEConfig(c) + if err != nil { + t.Fatalf("BuildICEConfig: %v", err) + } + if se == nil { + t.Fatal("SettingEngine should not be nil when PublicIP is set") + } + // We can't introspect NAT1To1IPs directly from Pion's public API; the + // smoke test is that building an API from this engine works. + api := webrtc.NewAPI(webrtc.WithSettingEngine(*se)) + if api == nil { + t.Fatal("NewAPI returned nil") + } +} + +func TestBuildICEConfig_InvalidConfig(t *testing.T) { + c := DefaultConfig() + c.WHEPListen = "" + _, _, err := BuildICEConfig(c) + if err == nil { + t.Error("BuildICEConfig should reject invalid config") + } +} diff --git a/go.mod b/go.mod index fbd79ba..6ad4d98 100644 --- a/go.mod +++ b/go.mod @@ -21,6 +21,7 @@ require ( github.com/mattn/go-isatty v0.0.20 github.com/minio/minio-go/v7 v7.0.70 github.com/pion/rtp v1.10.1 + github.com/pion/webrtc/v4 v4.2.11 github.com/prep/average v0.0.0-20200506183628-d26c465f48c3 github.com/prometheus/client_golang v1.19.1 github.com/puzpuzpuz/xsync/v3 v3.1.0 @@ -72,7 +73,20 @@ require ( github.com/miekg/dns v1.1.59 // indirect github.com/minio/md5-simd v1.1.2 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect + github.com/pion/datachannel v1.6.0 // indirect + github.com/pion/dtls/v3 v3.1.2 // indirect + github.com/pion/ice/v4 v4.2.2 // indirect + github.com/pion/interceptor v0.1.44 // indirect + github.com/pion/logging v0.2.4 // indirect + github.com/pion/mdns/v2 v2.1.0 // indirect github.com/pion/randutil v0.1.0 // indirect + github.com/pion/rtcp v1.2.16 // indirect + github.com/pion/sctp v1.9.4 // indirect + github.com/pion/sdp/v3 v3.0.18 // indirect + github.com/pion/srtp/v3 v3.0.10 // indirect + github.com/pion/stun/v3 v3.1.1 // indirect + github.com/pion/transport/v4 v4.0.1 // indirect + github.com/pion/turn/v4 v4.1.4 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect github.com/prometheus/client_model v0.6.1 // indirect @@ -88,6 +102,7 @@ require ( github.com/urfave/cli/v2 v2.27.2 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/fasttemplate v1.2.2 // indirect + github.com/wlynxg/anet v0.0.5 // indirect github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect github.com/xrash/smetrics v0.0.0-20240312152122-5f08fbb34913 // indirect diff --git a/go.sum b/go.sum index fb52ee4..9015077 100644 --- a/go.sum +++ b/go.sum @@ -134,10 +134,40 @@ github.com/minio/minio-go/v7 v7.0.70 h1:1u9NtMgfK1U42kUxcsl5v0yj6TEOPR497OAQxpJn github.com/minio/minio-go/v7 v7.0.70/go.mod h1:4yBA8v80xGA30cfM3fz0DKYMXunWl/AV/6tWEs9ryzo= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/pion/datachannel v1.6.0 h1:XecBlj+cvsxhAMZWFfFcPyUaDZtd7IJvrXqlXD/53i0= +github.com/pion/datachannel v1.6.0/go.mod h1:ur+wzYF8mWdC+Mkis5Thosk+u/VOL287apDNEbFpsIk= +github.com/pion/dtls/v3 v3.1.2 h1:gqEdOUXLtCGW+afsBLO0LtDD8GnuBBjEy6HRtyofZTc= +github.com/pion/dtls/v3 v3.1.2/go.mod h1:Hw/igcX4pdY69z1Hgv5x7wJFrUkdgHwAn/Q/uo7YHRo= +github.com/pion/ice/v4 v4.2.2 h1:dQJzzcgTFHDYyV3BoCfjPeX+JEtr58BWPi4PGyo6Vjg= +github.com/pion/ice/v4 v4.2.2/go.mod h1:2quLV1S5v1tAx3VvAJaH//KGitRXvo4RKlX6D3tnN+c= +github.com/pion/interceptor v0.1.44 h1:sNlZwM8dWXU9JQAkJh8xrarC0Etn8Oolcniukmuy0/I= +github.com/pion/interceptor v0.1.44/go.mod h1:4atVlBkcgXuUP+ykQF0qOCGU2j7pQzX2ofvPRFsY5RY= +github.com/pion/logging v0.2.4 h1:tTew+7cmQ+Mc1pTBLKH2puKsOvhm32dROumOZ655zB8= +github.com/pion/logging v0.2.4/go.mod h1:DffhXTKYdNZU+KtJ5pyQDjvOAh/GsNSyv1lbkFbe3so= +github.com/pion/mdns/v2 v2.1.0 h1:3IJ9+Xio6tWYjhN6WwuY142P/1jA0D5ERaIqawg/fOY= +github.com/pion/mdns/v2 v2.1.0/go.mod h1:pcez23GdynwcfRU1977qKU0mDxSeucttSHbCSfFOd9A= github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA= github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8= +github.com/pion/rtcp v1.2.16 h1:fk1B1dNW4hsI78XUCljZJlC4kZOPk67mNRuQ0fcEkSo= +github.com/pion/rtcp v1.2.16/go.mod h1:/as7VKfYbs5NIb4h6muQ35kQF/J0ZVNz2Z3xKoCBYOo= github.com/pion/rtp v1.10.1 h1:xP1prZcCTUuhO2c83XtxyOHJteISg6o8iPsE2acaMtA= github.com/pion/rtp v1.10.1/go.mod h1:rF5nS1GqbR7H/TCpKwylzeq6yDM+MM6k+On5EgeThEM= +github.com/pion/sctp v1.9.4 h1:cMxEu0F5tbP4qH07bKf1Zjf4rUih9LIo0qQt424e258= +github.com/pion/sctp v1.9.4/go.mod h1:N20Dq6LY+JvJDAh9VVh1JELngb2rQ8dPgds5yBWiPgw= +github.com/pion/sdp/v3 v3.0.18 h1:l0bAXazKHpepazVdp+tPYnrsy9dfh7ZbT8DxesH5ZnI= +github.com/pion/sdp/v3 v3.0.18/go.mod h1:ZREGo6A9ZygQ9XkqAj5xYCQtQpif0i6Pa81HOiAdqQ8= +github.com/pion/srtp/v3 v3.0.10 h1:tFirkpBb3XccP5VEXLi50GqXhv5SKPxqrdlhDCJlZrQ= +github.com/pion/srtp/v3 v3.0.10/go.mod h1:3mOTIB0cq9qlbn59V4ozvv9ClW/BSEbRp4cY0VtaR7M= +github.com/pion/stun/v3 v3.1.1 h1:CkQxveJ4xGQjulGSROXbXq94TAWu8gIX2dT+ePhUkqw= +github.com/pion/stun/v3 v3.1.1/go.mod h1:qC1DfmcCTQjl9PBaMa5wSn3x9IPmKxSdcCsxBcDBndM= +github.com/pion/transport/v3 v3.1.1 h1:Tr684+fnnKlhPceU+ICdrw6KKkTms+5qHMgw6bIkYOM= +github.com/pion/transport/v3 v3.1.1/go.mod h1:+c2eewC5WJQHiAA46fkMMzoYZSuGzA/7E2FPrOYHctQ= +github.com/pion/transport/v4 v4.0.1 h1:sdROELU6BZ63Ab7FrOLn13M6YdJLY20wldXW2Cu2k8o= +github.com/pion/transport/v4 v4.0.1/go.mod h1:nEuEA4AD5lPdcIegQDpVLgNoDGreqM/YqmEx3ovP4jM= +github.com/pion/turn/v4 v4.1.4 h1:EU11yMXKIsK43FhcUnjLlrhE4nboHZq+TXBIi3QpcxQ= +github.com/pion/turn/v4 v4.1.4/go.mod h1:ES1DXVFKnOhuDkqn9hn5VJlSWmZPaRJLyBXoOeO/BmQ= +github.com/pion/webrtc/v4 v4.2.11 h1:QUX1QZKlNIn4O7U5JxLPGP0sV5RTncZkzu9SPR3jVNU= +github.com/pion/webrtc/v4 v4.2.11/go.mod h1:s/rAiyy77GyRFrZMx+Ls6aua26dIBPudH8/ZHYbIRWY= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -204,6 +234,8 @@ github.com/valyala/fasttemplate v1.2.2 h1:lxLXG0uE3Qnshl9QyaK6XJxMXlQZELvChBOCmQ github.com/valyala/fasttemplate v1.2.2/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= github.com/vektah/gqlparser/v2 v2.5.12 h1:COMhVVnql6RoaF7+aTBWiTADdpLGyZWU3K/NwW0ph98= github.com/vektah/gqlparser/v2 v2.5.12/go.mod h1:WQQjFc+I1YIzoPvZBhUQX7waZgg3pMLi0r8KymvAE2w= +github.com/wlynxg/anet v0.0.5 h1:J3VJGi1gvo0JwZ/P1/Yc/8p63SoW98B5dHkYDmpgvvU= +github.com/wlynxg/anet v0.0.5/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA= github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb h1:zGWFAtiMcyryUHoUjUJX0/lt1H2+i2Ka2n+D3DImSNo= github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= diff --git a/vendor/github.com/pion/datachannel/.gitignore b/vendor/github.com/pion/datachannel/.gitignore new file mode 100644 index 0000000..6e2f206 --- /dev/null +++ b/vendor/github.com/pion/datachannel/.gitignore @@ -0,0 +1,28 @@ +# SPDX-FileCopyrightText: 2023 The Pion community +# SPDX-License-Identifier: MIT + +### JetBrains IDE ### +##################### +.idea/ + +### Emacs Temporary Files ### +############################# +*~ + +### Folders ### +############### +bin/ +vendor/ +node_modules/ + +### Files ### +############# +*.ivf +*.ogg +tags +cover.out +*.sw[poe] +*.wasm +examples/sfu-ws/cert.pem +examples/sfu-ws/key.pem +wasm_exec.js diff --git a/vendor/github.com/pion/datachannel/.golangci.yml b/vendor/github.com/pion/datachannel/.golangci.yml new file mode 100644 index 0000000..4b4025f --- /dev/null +++ b/vendor/github.com/pion/datachannel/.golangci.yml @@ -0,0 +1,147 @@ +# SPDX-FileCopyrightText: 2023 The Pion community +# SPDX-License-Identifier: MIT + +version: "2" +linters: + enable: + - asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers + - bidichk # Checks for dangerous unicode character sequences + - bodyclose # checks whether HTTP response body is closed successfully + - containedctx # containedctx is a linter that detects struct contained context.Context field + - contextcheck # check the function whether use a non-inherited context + - cyclop # checks function and package cyclomatic complexity + - decorder # check declaration order and count of types, constants, variables and functions + - dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f()) + - dupl # Tool for code clone detection + - durationcheck # check for two durations multiplied together + - err113 # Golang linter to check the errors handling expressions + - errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases + - errchkjson # Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occations, where the check for the returned error can be omitted. + - errname # Checks that sentinel errors are prefixed with the `Err` and error types are suffixed with the `Error`. + - errorlint # errorlint is a linter for that can be used to find code that will cause problems with the error wrapping scheme introduced in Go 1.13. + - exhaustive # check exhaustiveness of enum switch statements + - forbidigo # Forbids identifiers + - forcetypeassert # finds forced type assertions + - gochecknoglobals # Checks that no globals are present in Go code + - gocognit # Computes and checks the cognitive complexity of functions + - goconst # Finds repeated strings that could be replaced by a constant + - gocritic # The most opinionated Go source code linter + - gocyclo # Computes and checks the cyclomatic complexity of functions + - godot # Check if comments end in a period + - godox # Tool for detection of FIXME, TODO and other comment keywords + - goheader # Checks is file header matches to pattern + - gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod. + - goprintffuncname # Checks that printf-like functions are named with `f` at the end + - gosec # Inspects source code for security problems + - govet # Vet examines Go source code and reports suspicious constructs, such as Printf calls whose arguments do not align with the format string + - grouper # An analyzer to analyze expression groups. + - importas # Enforces consistent import aliases + - ineffassign # Detects when assignments to existing variables are not used + - lll # Reports long lines + - maintidx # maintidx measures the maintainability index of each function. + - makezero # Finds slice declarations with non-zero initial length + - misspell # Finds commonly misspelled English words in comments + - nakedret # Finds naked returns in functions greater than a specified function length + - nestif # Reports deeply nested if statements + - nilerr # Finds the code that returns nil even if it checks that the error is not nil. + - nilnil # Checks that there is no simultaneous return of `nil` error and an invalid value. + - nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity + - noctx # noctx finds sending http request without context.Context + - predeclared # find code that shadows one of Go's predeclared identifiers + - revive # golint replacement, finds style mistakes + - staticcheck # Staticcheck is a go vet on steroids, applying a ton of static analysis checks + - tagliatelle # Checks the struct tags. + - thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers + - unconvert # Remove unnecessary type conversions + - unparam # Reports unused function parameters + - unused # Checks Go code for unused constants, variables, functions and types + - varnamelen # checks that the length of a variable's name matches its scope + - wastedassign # wastedassign finds wasted assignment statements + - whitespace # Tool for detection of leading and trailing whitespace + disable: + - depguard # Go linter that checks if package imports are in a list of acceptable packages + - funlen # Tool for detection of long functions + - gochecknoinits # Checks that no init functions are present in Go code + - gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations. + - interfacebloat # A linter that checks length of interface. + - ireturn # Accept Interfaces, Return Concrete Types + - mnd # An analyzer to detect magic numbers + - nolintlint # Reports ill-formed or insufficient nolint directives + - paralleltest # paralleltest detects missing usage of t.Parallel() method in your Go test + - prealloc # Finds slice declarations that could potentially be preallocated + - promlinter # Check Prometheus metrics naming via promlint + - rowserrcheck # checks whether Err of rows is checked successfully + - sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed. + - testpackage # linter that makes you use a separate _test package + - tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes + - wrapcheck # Checks that errors returned from external packages are wrapped + - wsl # Whitespace Linter - Forces you to use empty lines! + settings: + staticcheck: + checks: + - all + - -QF1008 # "could remove embedded field", to keep it explicit! + - -QF1003 # "could use tagged switch on enum", Cases conflicts with exhaustive! + exhaustive: + default-signifies-exhaustive: true + forbidigo: + forbid: + - pattern: ^fmt.Print(f|ln)?$ + - pattern: ^log.(Panic|Fatal|Print)(f|ln)?$ + - pattern: ^os.Exit$ + - pattern: ^panic$ + - pattern: ^print(ln)?$ + - pattern: ^testing.T.(Error|Errorf|Fatal|Fatalf|Fail|FailNow)$ + pkg: ^testing$ + msg: use testify/assert instead + analyze-types: true + gomodguard: + blocked: + modules: + - github.com/pkg/errors: + recommendations: + - errors + govet: + enable: + - shadow + revive: + rules: + # Prefer 'any' type alias over 'interface{}' for Go 1.18+ compatibility + - name: use-any + severity: warning + disabled: false + misspell: + locale: US + varnamelen: + max-distance: 12 + min-name-length: 2 + ignore-type-assert-ok: true + ignore-map-index-ok: true + ignore-chan-recv-ok: true + ignore-decls: + - i int + - n int + - w io.Writer + - r io.Reader + - b []byte + exclusions: + generated: lax + rules: + - linters: + - forbidigo + - gocognit + path: (examples|main\.go) + - linters: + - gocognit + path: _test\.go + - linters: + - forbidigo + path: cmd +formatters: + enable: + - gci # Gci control golang package import order and make it always deterministic. + - gofmt # Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification + - gofumpt # Gofumpt checks whether code was gofumpt-ed. + - goimports # Goimports does everything that gofmt does. Additionally it checks unused imports + exclusions: + generated: lax diff --git a/vendor/github.com/pion/datachannel/.goreleaser.yml b/vendor/github.com/pion/datachannel/.goreleaser.yml new file mode 100644 index 0000000..30093e9 --- /dev/null +++ b/vendor/github.com/pion/datachannel/.goreleaser.yml @@ -0,0 +1,5 @@ +# SPDX-FileCopyrightText: 2023 The Pion community +# SPDX-License-Identifier: MIT + +builds: +- skip: true diff --git a/vendor/github.com/pion/datachannel/LICENSE b/vendor/github.com/pion/datachannel/LICENSE new file mode 100644 index 0000000..491caf6 --- /dev/null +++ b/vendor/github.com/pion/datachannel/LICENSE @@ -0,0 +1,9 @@ +MIT License + +Copyright (c) 2023 The Pion community + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/vendor/github.com/pion/datachannel/README.md b/vendor/github.com/pion/datachannel/README.md new file mode 100644 index 0000000..9885a7e --- /dev/null +++ b/vendor/github.com/pion/datachannel/README.md @@ -0,0 +1,34 @@ +

+
+ Pion Data Channels +
+

+

A Go implementation of WebRTC Data Channels

+

+ Pion Data Channels + join us on Discord Follow us on Bluesky +
+ GitHub Workflow Status + Go Reference + Coverage Status + Go Report Card + License: MIT +

+
+ +### Roadmap +The library is used as a part of our WebRTC implementation. Please refer to that [roadmap](https://github.com/pion/webrtc/issues/9) to track our major milestones. + +### Community +Pion has an active community on the [Discord](https://discord.gg/PngbdqpFbt). + +Follow the [Pion Bluesky](https://bsky.app/profile/pion.ly) or [Pion Twitter](https://twitter.com/_pion) for project updates and important WebRTC news. + +We are always looking to support **your projects**. Please reach out if you have something to build! +If you need commercial support or don't want to use public methods you can contact us at [team@pion.ly](mailto:team@pion.ly) + +### Contributing +Check out the [contributing wiki](https://github.com/pion/webrtc/wiki/Contributing) to join the group of amazing people making this project possible + +### License +MIT License - see [LICENSE](LICENSE) for full text diff --git a/vendor/github.com/pion/datachannel/codecov.yml b/vendor/github.com/pion/datachannel/codecov.yml new file mode 100644 index 0000000..263e4d4 --- /dev/null +++ b/vendor/github.com/pion/datachannel/codecov.yml @@ -0,0 +1,22 @@ +# +# DO NOT EDIT THIS FILE +# +# It is automatically copied from https://github.com/pion/.goassets repository. +# +# SPDX-FileCopyrightText: 2023 The Pion community +# SPDX-License-Identifier: MIT + +coverage: + status: + project: + default: + # Allow decreasing 2% of total coverage to avoid noise. + threshold: 2% + patch: + default: + target: 70% + only_pulls: true + +ignore: + - "examples/*" + - "examples/**/*" diff --git a/vendor/github.com/pion/datachannel/datachannel.go b/vendor/github.com/pion/datachannel/datachannel.go new file mode 100644 index 0000000..976a3fb --- /dev/null +++ b/vendor/github.com/pion/datachannel/datachannel.go @@ -0,0 +1,445 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +// Package datachannel implements WebRTC Data Channels +package datachannel + +import ( + "errors" + "fmt" + "io" + "sync" + "sync/atomic" + "time" + + "github.com/pion/logging" + "github.com/pion/sctp" +) + +const receiveMTU = 8192 + +// Reader is an extended io.Reader +// that also returns if the message is text. +type Reader interface { + ReadDataChannel([]byte) (int, bool, error) +} + +// ReadDeadliner extends an io.Reader to expose setting a read deadline. +type ReadDeadliner interface { + SetReadDeadline(time.Time) error +} + +// Writer is an extended io.Writer +// that also allows indicating if a message is text. +type Writer interface { + WriteDataChannel([]byte, bool) (int, error) +} + +// WriteDeadliner extends an io.Writer to expose setting a write deadline. +type WriteDeadliner interface { + SetWriteDeadline(time.Time) error +} + +// ReadWriteCloser is an extended io.ReadWriteCloser +// that also implements our Reader and Writer. +type ReadWriteCloser interface { + io.Reader + io.Writer + Reader + Writer + io.Closer +} + +// ReadWriteCloserDeadliner is an extended ReadWriteCloser +// that also implements r/w deadline. +type ReadWriteCloserDeadliner interface { + ReadWriteCloser + ReadDeadliner + WriteDeadliner +} + +// DataChannel represents a data channel. +type DataChannel struct { + Config + + // stats + messagesSent uint32 + messagesReceived uint32 + bytesSent uint64 + bytesReceived uint64 + + mu sync.Mutex + onOpenCompleteHandler func() + openCompleteHandlerOnce sync.Once + + stream *sctp.Stream + log logging.LeveledLogger +} + +// Config is used to configure the data channel. +type Config struct { + ChannelType ChannelType + Negotiated bool + Priority uint16 + ReliabilityParameter uint32 + Label string + Protocol string + LoggerFactory logging.LoggerFactory +} + +func newDataChannel(stream *sctp.Stream, config *Config) *DataChannel { + return &DataChannel{ + Config: *config, + stream: stream, + log: config.LoggerFactory.NewLogger("datachannel"), + } +} + +// Dial opens a data channels over SCTP. +func Dial(a *sctp.Association, id uint16, config *Config) (*DataChannel, error) { + stream, err := a.OpenStream(id, sctp.PayloadTypeWebRTCBinary) + if err != nil { + return nil, err + } + + dc, err := Client(stream, config) + if err != nil { + return nil, err + } + + isReliable := dc.ChannelType == ChannelTypeReliable || dc.ChannelType == ChannelTypeReliableUnordered + if isReliable && dc.ReliabilityParameter != 0 { + dc.log.Warnf("DataChannel opened with channel type %s, but has a non-zero reliability parameter: %d (expected 0)", + dc.ChannelType, + dc.ReliabilityParameter) + } + + return dc, nil +} + +// Client opens a data channel over an SCTP stream. +func Client(stream *sctp.Stream, config *Config) (*DataChannel, error) { + msg := &channelOpen{ + ChannelType: config.ChannelType, + Priority: config.Priority, + ReliabilityParameter: config.ReliabilityParameter, + + Label: []byte(config.Label), + Protocol: []byte(config.Protocol), + } + + if !config.Negotiated { + rawMsg, err := msg.Marshal() + if err != nil { + return nil, fmt.Errorf("failed to marshal ChannelOpen %w", err) + } + + if _, err = stream.WriteSCTP(rawMsg, sctp.PayloadTypeWebRTCDCEP); err != nil { + return nil, fmt.Errorf("failed to send ChannelOpen %w", err) + } + } + + return newDataChannel(stream, config), nil +} + +// Accept is used to accept incoming data channels over SCTP. +func Accept(a *sctp.Association, config *Config, existingChannels ...*DataChannel) (*DataChannel, error) { + stream, err := a.AcceptStream() + if err != nil { + return nil, err + } + for _, ch := range existingChannels { + if ch.StreamIdentifier() == stream.StreamIdentifier() { + ch.stream.SetDefaultPayloadType(sctp.PayloadTypeWebRTCBinary) + + return ch, nil + } + } + + stream.SetDefaultPayloadType(sctp.PayloadTypeWebRTCBinary) + + dc, err := Server(stream, config) + if err != nil { + return nil, err + } + + return dc, nil +} + +// Server accepts a data channel over an SCTP stream. +func Server(stream *sctp.Stream, config *Config) (*DataChannel, error) { + buffer := make([]byte, receiveMTU) + n, ppi, err := stream.ReadSCTP(buffer) + if err != nil { + return nil, err + } + + if ppi != sctp.PayloadTypeWebRTCDCEP { + return nil, fmt.Errorf("%w %s", ErrInvalidPayloadProtocolIdentifier, ppi) + } + + openMsg, err := parseExpectDataChannelOpen(buffer[:n]) + if err != nil { + return nil, fmt.Errorf("failed to parse DataChannelOpen packet %w", err) + } + + config.ChannelType = openMsg.ChannelType + config.Priority = openMsg.Priority + config.ReliabilityParameter = openMsg.ReliabilityParameter + config.Label = string(openMsg.Label) + config.Protocol = string(openMsg.Protocol) + + dataChannel := newDataChannel(stream, config) + + err = dataChannel.writeDataChannelAck() + if err != nil { + return nil, err + } + + err = dataChannel.commitReliabilityParams() + if err != nil { + return nil, err + } + + return dataChannel, nil +} + +// Read reads a packet of len(pkt) bytes as binary data. +func (c *DataChannel) Read(pkt []byte) (int, error) { + n, _, err := c.ReadDataChannel(pkt) + + return n, err +} + +// ReadDataChannel reads a packet of len(pkt) bytes. +func (c *DataChannel) ReadDataChannel(pkt []byte) (int, bool, error) { + for { + n, ppi, err := c.stream.ReadSCTP(pkt) + if errors.Is(err, io.EOF) { + // When the peer sees that an incoming stream was + // reset, it also resets its corresponding outgoing stream. + if closeErr := c.stream.Close(); closeErr != nil { + return 0, false, closeErr + } + } + if err != nil { + return 0, false, err + } + + if ppi == sctp.PayloadTypeWebRTCDCEP { + if err = c.handleDCEP(pkt[:n]); err != nil { + c.log.Errorf("Failed to handle DCEP: %s", err.Error()) + } + + continue + } else if ppi == sctp.PayloadTypeWebRTCBinaryEmpty || ppi == sctp.PayloadTypeWebRTCStringEmpty { + n = 0 + } + + atomic.AddUint32(&c.messagesReceived, 1) + atomic.AddUint64(&c.bytesReceived, uint64(n)) //nolint:gosec //G115 + + isString := ppi == sctp.PayloadTypeWebRTCString || ppi == sctp.PayloadTypeWebRTCStringEmpty + + return n, isString, err + } +} + +// SetReadDeadline sets a deadline for reads to return. +func (c *DataChannel) SetReadDeadline(t time.Time) error { + return c.stream.SetReadDeadline(t) +} + +// SetWriteDeadline sets a deadline for writes to return, +// only available if the BlockWrite is enabled for sctp. +func (c *DataChannel) SetWriteDeadline(t time.Time) error { + return c.stream.SetWriteDeadline(t) +} + +// MessagesSent returns the number of messages sent. +func (c *DataChannel) MessagesSent() uint32 { + return atomic.LoadUint32(&c.messagesSent) +} + +// MessagesReceived returns the number of messages received. +func (c *DataChannel) MessagesReceived() uint32 { + return atomic.LoadUint32(&c.messagesReceived) +} + +// OnOpen sets an event handler which is invoked when +// a DATA_CHANNEL_ACK message is received. +// The handler is called only on thefor the channel opened +// https://datatracker.ietf.org/doc/html/draft-ietf-rtcweb-data-protocol-09#section-5.2 +func (c *DataChannel) OnOpen(f func()) { + c.mu.Lock() + c.openCompleteHandlerOnce = sync.Once{} + c.onOpenCompleteHandler = f + c.mu.Unlock() +} + +func (c *DataChannel) onOpenComplete() { + c.mu.Lock() + hdlr := c.onOpenCompleteHandler + c.mu.Unlock() + + if hdlr != nil { + go c.openCompleteHandlerOnce.Do(func() { + hdlr() + }) + } +} + +// BytesSent returns the number of bytes sent. +func (c *DataChannel) BytesSent() uint64 { + return atomic.LoadUint64(&c.bytesSent) +} + +// BytesReceived returns the number of bytes received. +func (c *DataChannel) BytesReceived() uint64 { + return atomic.LoadUint64(&c.bytesReceived) +} + +// StreamIdentifier returns the Stream identifier associated to the stream. +func (c *DataChannel) StreamIdentifier() uint16 { + return c.stream.StreamIdentifier() +} + +func (c *DataChannel) handleDCEP(data []byte) error { + msg, err := parse(data) + if err != nil { + return fmt.Errorf("failed to parse DataChannel packet %w", err) + } + + switch msg := msg.(type) { + case *channelAck: + if err := c.commitReliabilityParams(); err != nil { + return err + } + c.onOpenComplete() + default: + return fmt.Errorf("%w, wanted ACK got %v", ErrUnexpectedDataChannelType, msg) + } + + return nil +} + +// Write writes len(pkt) bytes from pkt as binary data. +func (c *DataChannel) Write(pkt []byte) (n int, err error) { + return c.WriteDataChannel(pkt, false) +} + +// WriteDataChannel writes len(pkt) bytes from pkt. +func (c *DataChannel) WriteDataChannel(pkt []byte, isString bool) (n int, err error) { + // https://tools.ietf.org/html/draft-ietf-rtcweb-data-channel-12#section-6.6 + // SCTP does not support the sending of empty user messages. Therefore, + // if an empty message has to be sent, the appropriate PPID (WebRTC + // String Empty or WebRTC Binary Empty) is used and the SCTP user + // message of one zero byte is sent. When receiving an SCTP user + // message with one of these PPIDs, the receiver MUST ignore the SCTP + // user message and process it as an empty message. + var ppi sctp.PayloadProtocolIdentifier + switch { + case !isString && len(pkt) > 0: + ppi = sctp.PayloadTypeWebRTCBinary + case !isString && len(pkt) == 0: + ppi = sctp.PayloadTypeWebRTCBinaryEmpty + case isString && len(pkt) > 0: + ppi = sctp.PayloadTypeWebRTCString + case isString && len(pkt) == 0: + ppi = sctp.PayloadTypeWebRTCStringEmpty + } + + atomic.AddUint32(&c.messagesSent, 1) + atomic.AddUint64(&c.bytesSent, uint64(len(pkt))) + + if len(pkt) == 0 { + _, err := c.stream.WriteSCTP([]byte{0}, ppi) + + return 0, err + } + + return c.stream.WriteSCTP(pkt, ppi) +} + +func (c *DataChannel) writeDataChannelAck() error { + ack := channelAck{} + ackMsg, err := ack.Marshal() + if err != nil { + return fmt.Errorf("failed to marshal ChannelOpen ACK: %w", err) + } + + if _, err = c.stream.WriteSCTP(ackMsg, sctp.PayloadTypeWebRTCDCEP); err != nil { + return fmt.Errorf("failed to send ChannelOpen ACK: %w", err) + } + + return err +} + +// Close closes the DataChannel and the underlying SCTP stream. +func (c *DataChannel) Close() error { + // https://tools.ietf.org/html/draft-ietf-rtcweb-data-channel-13#section-6.7 + // Closing of a data channel MUST be signaled by resetting the + // corresponding outgoing streams [RFC6525]. This means that if one + // side decides to close the data channel, it resets the corresponding + // outgoing stream. When the peer sees that an incoming stream was + // reset, it also resets its corresponding outgoing stream. Once this + // is completed, the data channel is closed. Resetting a stream sets + // the Stream Sequence Numbers (SSNs) of the stream back to 'zero' with + // a corresponding notification to the application layer that the reset + // has been performed. Streams are available for reuse after a reset + // has been performed. + return c.stream.Close() +} + +// BufferedAmount returns the number of bytes of data currently queued to be +// sent over this stream. +func (c *DataChannel) BufferedAmount() uint64 { + return c.stream.BufferedAmount() +} + +// BufferedAmountLowThreshold returns the number of bytes of buffered outgoing +// data that is considered "low." Defaults to 0. +func (c *DataChannel) BufferedAmountLowThreshold() uint64 { + return c.stream.BufferedAmountLowThreshold() +} + +// SetBufferedAmountLowThreshold is used to update the threshold. +// See BufferedAmountLowThreshold(). +func (c *DataChannel) SetBufferedAmountLowThreshold(th uint64) { + c.stream.SetBufferedAmountLowThreshold(th) +} + +// OnBufferedAmountLow sets the callback handler which would be called when the +// number of bytes of outgoing data buffered is lower than the threshold. +func (c *DataChannel) OnBufferedAmountLow(f func()) { + c.stream.OnBufferedAmountLow(f) +} + +func (c *DataChannel) commitReliabilityParams() error { + switch c.Config.ChannelType { + case ChannelTypeReliable: + c.stream.SetReliabilityParams(false, sctp.ReliabilityTypeReliable, c.Config.ReliabilityParameter) // RFC 8832 sec 5.1 + if c.Config.ReliabilityParameter != 0 { + c.log.Warnf("Channel type is Reliable but has a non-zero reliability parameter: %d (expected 0)", + c.Config.ReliabilityParameter) + } + case ChannelTypeReliableUnordered: + c.stream.SetReliabilityParams(true, sctp.ReliabilityTypeReliable, c.Config.ReliabilityParameter) // RFC 8832 sec 5.1 + if c.Config.ReliabilityParameter != 0 { + c.log.Warnf("Channel type is ReliableUnordered but has a non-zero reliability parameter: %d (expected 0)", + c.Config.ReliabilityParameter) + } + case ChannelTypePartialReliableRexmit: + c.stream.SetReliabilityParams(false, sctp.ReliabilityTypeRexmit, c.Config.ReliabilityParameter) + case ChannelTypePartialReliableRexmitUnordered: + c.stream.SetReliabilityParams(true, sctp.ReliabilityTypeRexmit, c.Config.ReliabilityParameter) + case ChannelTypePartialReliableTimed: + c.stream.SetReliabilityParams(false, sctp.ReliabilityTypeTimed, c.Config.ReliabilityParameter) + case ChannelTypePartialReliableTimedUnordered: + c.stream.SetReliabilityParams(true, sctp.ReliabilityTypeTimed, c.Config.ReliabilityParameter) + default: + return fmt.Errorf("%w %v", ErrInvalidChannelType, c.Config.ChannelType) + } + + return nil +} diff --git a/vendor/github.com/pion/datachannel/errors.go b/vendor/github.com/pion/datachannel/errors.go new file mode 100644 index 0000000..a72d5bf --- /dev/null +++ b/vendor/github.com/pion/datachannel/errors.go @@ -0,0 +1,29 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package datachannel + +import "errors" + +var ( + // ErrDataChannelMessageTooShort means that the data isn't long enough to be a valid DataChannel message. + ErrDataChannelMessageTooShort = errors.New("DataChannel message is not long enough to determine type") + + // ErrInvalidPayloadProtocolIdentifier means that we got a DataChannel messages with a Payload Protocol Identifier + // we don't know how to handle. + ErrInvalidPayloadProtocolIdentifier = errors.New( + "DataChannel message Payload Protocol Identifier is value we can't handle", + ) + + // ErrInvalidChannelType means that the remote requested a channel type that we don't support. + ErrInvalidChannelType = errors.New("invalid Channel Type") + + // ErrInvalidMessageType is returned when a DataChannel Message has a type we don't support. + ErrInvalidMessageType = errors.New("invalid Message Type") + + // ErrExpectedAndActualLengthMismatch is when the declared length and actual length don't match. + ErrExpectedAndActualLengthMismatch = errors.New("expected and actual length do not match") + + // ErrUnexpectedDataChannelType is when a message type does not match the expected type. + ErrUnexpectedDataChannelType = errors.New("expected and actual message type does not match") +) diff --git a/vendor/github.com/pion/datachannel/message.go b/vendor/github.com/pion/datachannel/message.go new file mode 100644 index 0000000..9568a7a --- /dev/null +++ b/vendor/github.com/pion/datachannel/message.go @@ -0,0 +1,92 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package datachannel + +import ( + "fmt" +) + +// message is a parsed DataChannel message. +type message interface { + Marshal() ([]byte, error) + Unmarshal([]byte) error + String() string +} + +// messageType is the first byte in a DataChannel message that specifies type. +type messageType byte + +// DataChannel Message Types. +const ( + dataChannelAck messageType = 0x02 + dataChannelOpen messageType = 0x03 +) + +func (t messageType) String() string { + switch t { + case dataChannelAck: + return "DataChannelAck" + case dataChannelOpen: + return "DataChannelOpen" + default: + return fmt.Sprintf("Unknown MessageType: %d", t) + } +} + +// parse accepts raw input and returns a DataChannel message. +func parse(raw []byte) (message, error) { + if len(raw) == 0 { + return nil, ErrDataChannelMessageTooShort + } + + var msg message + switch messageType(raw[0]) { + case dataChannelOpen: + msg = &channelOpen{} + case dataChannelAck: + msg = &channelAck{} + default: + return nil, fmt.Errorf("%w %v", ErrInvalidMessageType, messageType(raw[0])) + } + + if err := msg.Unmarshal(raw); err != nil { + return nil, err + } + + return msg, nil +} + +// parseExpectDataChannelOpen parses a DataChannelOpen message +// or throws an error. +func parseExpectDataChannelOpen(raw []byte) (*channelOpen, error) { + if len(raw) == 0 { + return nil, ErrDataChannelMessageTooShort + } + + if actualTyp := messageType(raw[0]); actualTyp != dataChannelOpen { + return nil, fmt.Errorf("%w expected(%s) actual(%s)", ErrUnexpectedDataChannelType, actualTyp, dataChannelOpen) + } + + msg := &channelOpen{} + if err := msg.Unmarshal(raw); err != nil { + return nil, err + } + + return msg, nil +} + +// TryMarshalUnmarshal attempts to marshal and unmarshal a message. Added for fuzzing. +func TryMarshalUnmarshal(msg []byte) int { + message, err := parse(msg) + if err != nil { + return 0 + } + + _, err = message.Marshal() + if err != nil { + return 0 + } + + return 1 +} diff --git a/vendor/github.com/pion/datachannel/message_channel_ack.go b/vendor/github.com/pion/datachannel/message_channel_ack.go new file mode 100644 index 0000000..f4ce81e --- /dev/null +++ b/vendor/github.com/pion/datachannel/message_channel_ack.go @@ -0,0 +1,29 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package datachannel + +// channelAck is used to ACK a DataChannel open. +type channelAck struct{} + +const ( + channelOpenAckLength = 4 +) + +// Marshal returns raw bytes for the given message. +func (c *channelAck) Marshal() ([]byte, error) { + raw := make([]byte, channelOpenAckLength) + raw[0] = uint8(dataChannelAck) + + return raw, nil +} + +// Unmarshal populates the struct with the given raw data. +func (c *channelAck) Unmarshal(_ []byte) error { + // Message type already checked in Parse and there is no further data + return nil +} + +func (c channelAck) String() string { + return "ACK" +} diff --git a/vendor/github.com/pion/datachannel/message_channel_open.go b/vendor/github.com/pion/datachannel/message_channel_open.go new file mode 100644 index 0000000..5605c80 --- /dev/null +++ b/vendor/github.com/pion/datachannel/message_channel_open.go @@ -0,0 +1,155 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package datachannel + +import ( + "encoding/binary" + "fmt" +) + +/* +channelOpen represents a DATA_CHANNEL_OPEN Message + + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Message Type | Channel Type | Priority | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Reliability Parameter | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Label Length | Protocol Length | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| | +| Label | +| | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| | +| Protocol | +| | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +. +*/ +type channelOpen struct { + ChannelType ChannelType + Priority uint16 + ReliabilityParameter uint32 + + Label []byte + Protocol []byte +} + +const ( + channelOpenHeaderLength = 12 +) + +// ChannelType determines the reliability of the WebRTC DataChannel. +type ChannelType byte + +// ChannelType enums. +const ( + // ChannelTypeReliable determines the Data Channel provides a + // reliable in-order bi-directional communication. + ChannelTypeReliable ChannelType = 0x00 + // ChannelTypeReliableUnordered determines the Data Channel + // provides a reliable unordered bi-directional communication. + ChannelTypeReliableUnordered ChannelType = 0x80 + // ChannelTypePartialReliableRexmit determines the Data Channel + // provides a partially-reliable in-order bi-directional communication. + // User messages will not be retransmitted more times than specified in the Reliability Parameter. + ChannelTypePartialReliableRexmit ChannelType = 0x01 + // ChannelTypePartialReliableRexmitUnordered determines + // the Data Channel provides a partial reliable unordered bi-directional communication. + // User messages will not be retransmitted more times than specified in the Reliability Parameter. + ChannelTypePartialReliableRexmitUnordered ChannelType = 0x81 + // ChannelTypePartialReliableTimed determines the Data Channel + // provides a partial reliable in-order bi-directional communication. + // User messages might not be transmitted or retransmitted after + // a specified life-time given in milli- seconds in the Reliability Parameter. + // This life-time starts when providing the user message to the protocol stack. + ChannelTypePartialReliableTimed ChannelType = 0x02 + // The Data Channel provides a partial reliable unordered bi-directional + // communication. User messages might not be transmitted or retransmitted + // after a specified life-time given in milli- seconds in the Reliability Parameter. + // This life-time starts when providing the user message to the protocol stack. + ChannelTypePartialReliableTimedUnordered ChannelType = 0x82 +) + +func (c ChannelType) String() string { + switch c { + case ChannelTypeReliable: + return "ReliableOrdered" + case ChannelTypeReliableUnordered: + return "ReliableUnordered" + case ChannelTypePartialReliableRexmit: + return "PartialReliableRexmit" + case ChannelTypePartialReliableRexmitUnordered: + return "PartialReliableRexmitUnordered" + case ChannelTypePartialReliableTimed: + return "PartialReliableTimed" + case ChannelTypePartialReliableTimedUnordered: + return "PartialReliableTimedUnordered" + } + + return "Unknown" +} + +// ChannelPriority enums. +const ( + ChannelPriorityBelowNormal uint16 = 128 + ChannelPriorityNormal uint16 = 256 + ChannelPriorityHigh uint16 = 512 + ChannelPriorityExtraHigh uint16 = 1024 +) + +// Marshal returns raw bytes for the given message. +func (c *channelOpen) Marshal() ([]byte, error) { + labelLength := len(c.Label) + protocolLength := len(c.Protocol) + + totalLen := channelOpenHeaderLength + labelLength + protocolLength + raw := make([]byte, totalLen) + + raw[0] = uint8(dataChannelOpen) + raw[1] = byte(c.ChannelType) + + binary.BigEndian.PutUint16(raw[2:], c.Priority) + binary.BigEndian.PutUint32(raw[4:], c.ReliabilityParameter) + binary.BigEndian.PutUint16(raw[8:], uint16(labelLength)) //nolint:gosec //G115 + binary.BigEndian.PutUint16(raw[10:], uint16(protocolLength)) //nolint:gosec //G115 + endLabel := channelOpenHeaderLength + labelLength + copy(raw[channelOpenHeaderLength:endLabel], c.Label) + copy(raw[endLabel:endLabel+protocolLength], c.Protocol) + + return raw, nil +} + +// Unmarshal populates the struct with the given raw data. +func (c *channelOpen) Unmarshal(raw []byte) error { + if len(raw) < channelOpenHeaderLength { + return fmt.Errorf("%w expected(%d) actual(%d)", ErrExpectedAndActualLengthMismatch, channelOpenHeaderLength, len(raw)) + } + c.ChannelType = ChannelType(raw[1]) + c.Priority = binary.BigEndian.Uint16(raw[2:]) + c.ReliabilityParameter = binary.BigEndian.Uint32(raw[4:]) + + labelLength := binary.BigEndian.Uint16(raw[8:]) + protocolLength := binary.BigEndian.Uint16(raw[10:]) + + if expectedLen := channelOpenHeaderLength + int(labelLength) + int(protocolLength); len(raw) != expectedLen { + return fmt.Errorf("%w expected(%d) actual(%d)", ErrExpectedAndActualLengthMismatch, expectedLen, len(raw)) + } + + c.Label = raw[channelOpenHeaderLength : channelOpenHeaderLength+labelLength] + c.Protocol = raw[channelOpenHeaderLength+labelLength : channelOpenHeaderLength+labelLength+protocolLength] + + return nil +} + +func (c channelOpen) String() string { + return fmt.Sprintf( + "Open ChannelType(%s) Priority(%v) ReliabilityParameter(%d) Label(%s) Protocol(%s)", + c.ChannelType, c.Priority, c.ReliabilityParameter, string(c.Label), string(c.Protocol), + ) +} diff --git a/vendor/github.com/pion/datachannel/renovate.json b/vendor/github.com/pion/datachannel/renovate.json new file mode 100644 index 0000000..f1bb98c --- /dev/null +++ b/vendor/github.com/pion/datachannel/renovate.json @@ -0,0 +1,6 @@ +{ + "$schema": "https://docs.renovatebot.com/renovate-schema.json", + "extends": [ + "github>pion/renovate-config" + ] +} diff --git a/vendor/github.com/pion/dtls/v3/.editorconfig b/vendor/github.com/pion/dtls/v3/.editorconfig new file mode 100644 index 0000000..714d21d --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/.editorconfig @@ -0,0 +1,23 @@ +# http://editorconfig.org/ +# SPDX-FileCopyrightText: 2026 The Pion community +# SPDX-License-Identifier: MIT + +root = true + +[*] +charset = utf-8 +insert_final_newline = true +trim_trailing_whitespace = true +end_of_line = lf + +[*.go] +indent_style = tab +indent_size = 4 + +[{*.yml,*.yaml}] +indent_style = space +indent_size = 2 + +# Makefiles always use tabs for indentation +[Makefile] +indent_style = tab diff --git a/vendor/github.com/pion/dtls/v3/.gitignore b/vendor/github.com/pion/dtls/v3/.gitignore new file mode 100644 index 0000000..2394557 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/.gitignore @@ -0,0 +1,28 @@ +# SPDX-FileCopyrightText: 2026 The Pion community +# SPDX-License-Identifier: MIT + +### JetBrains IDE ### +##################### +.idea/ + +### Emacs Temporary Files ### +############################# +*~ + +### Folders ### +############### +bin/ +vendor/ +node_modules/ + +### Files ### +############# +*.ivf +*.ogg +tags +cover.out +*.sw[poe] +*.wasm +examples/sfu-ws/cert.pem +examples/sfu-ws/key.pem +wasm_exec.js diff --git a/vendor/github.com/pion/dtls/v3/.golangci.yml b/vendor/github.com/pion/dtls/v3/.golangci.yml new file mode 100644 index 0000000..43af4c3 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/.golangci.yml @@ -0,0 +1,147 @@ +# SPDX-FileCopyrightText: 2026 The Pion community +# SPDX-License-Identifier: MIT + +version: "2" +linters: + enable: + - asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers + - bidichk # Checks for dangerous unicode character sequences + - bodyclose # checks whether HTTP response body is closed successfully + - containedctx # containedctx is a linter that detects struct contained context.Context field + - contextcheck # check the function whether use a non-inherited context + - cyclop # checks function and package cyclomatic complexity + - decorder # check declaration order and count of types, constants, variables and functions + - dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f()) + - dupl # Tool for code clone detection + - durationcheck # check for two durations multiplied together + - err113 # Golang linter to check the errors handling expressions + - errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases + - errchkjson # Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occations, where the check for the returned error can be omitted. + - errname # Checks that sentinel errors are prefixed with the `Err` and error types are suffixed with the `Error`. + - errorlint # errorlint is a linter for that can be used to find code that will cause problems with the error wrapping scheme introduced in Go 1.13. + - exhaustive # check exhaustiveness of enum switch statements + - forbidigo # Forbids identifiers + - forcetypeassert # finds forced type assertions + - gochecknoglobals # Checks that no globals are present in Go code + - gocognit # Computes and checks the cognitive complexity of functions + - goconst # Finds repeated strings that could be replaced by a constant + - gocritic # The most opinionated Go source code linter + - gocyclo # Computes and checks the cyclomatic complexity of functions + - godot # Check if comments end in a period + - godox # Tool for detection of FIXME, TODO and other comment keywords + - goheader # Checks is file header matches to pattern + - gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod. + - goprintffuncname # Checks that printf-like functions are named with `f` at the end + - gosec # Inspects source code for security problems + - govet # Vet examines Go source code and reports suspicious constructs, such as Printf calls whose arguments do not align with the format string + - grouper # An analyzer to analyze expression groups. + - importas # Enforces consistent import aliases + - ineffassign # Detects when assignments to existing variables are not used + - lll # Reports long lines + - maintidx # maintidx measures the maintainability index of each function. + - makezero # Finds slice declarations with non-zero initial length + - misspell # Finds commonly misspelled English words in comments + - nakedret # Finds naked returns in functions greater than a specified function length + - nestif # Reports deeply nested if statements + - nilerr # Finds the code that returns nil even if it checks that the error is not nil. + - nilnil # Checks that there is no simultaneous return of `nil` error and an invalid value. + - nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity + - noctx # noctx finds sending http request without context.Context + - predeclared # find code that shadows one of Go's predeclared identifiers + - revive # golint replacement, finds style mistakes + - staticcheck # Staticcheck is a go vet on steroids, applying a ton of static analysis checks + - tagliatelle # Checks the struct tags. + - thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers + - unconvert # Remove unnecessary type conversions + - unparam # Reports unused function parameters + - unused # Checks Go code for unused constants, variables, functions and types + - varnamelen # checks that the length of a variable's name matches its scope + - wastedassign # wastedassign finds wasted assignment statements + - whitespace # Tool for detection of leading and trailing whitespace + disable: + - depguard # Go linter that checks if package imports are in a list of acceptable packages + - funlen # Tool for detection of long functions + - gochecknoinits # Checks that no init functions are present in Go code + - gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations. + - interfacebloat # A linter that checks length of interface. + - ireturn # Accept Interfaces, Return Concrete Types + - mnd # An analyzer to detect magic numbers + - nolintlint # Reports ill-formed or insufficient nolint directives + - paralleltest # paralleltest detects missing usage of t.Parallel() method in your Go test + - prealloc # Finds slice declarations that could potentially be preallocated + - promlinter # Check Prometheus metrics naming via promlint + - rowserrcheck # checks whether Err of rows is checked successfully + - sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed. + - testpackage # linter that makes you use a separate _test package + - tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes + - wrapcheck # Checks that errors returned from external packages are wrapped + - wsl # Whitespace Linter - Forces you to use empty lines! + settings: + staticcheck: + checks: + - all + - -QF1008 # "could remove embedded field", to keep it explicit! + - -QF1003 # "could use tagged switch on enum", Cases conflicts with exhaustive! + exhaustive: + default-signifies-exhaustive: true + forbidigo: + forbid: + - pattern: ^fmt.Print(f|ln)?$ + - pattern: ^log.(Panic|Fatal|Print)(f|ln)?$ + - pattern: ^os.Exit$ + - pattern: ^panic$ + - pattern: ^print(ln)?$ + - pattern: ^testing.T.(Error|Errorf|Fatal|Fatalf|Fail|FailNow)$ + pkg: ^testing$ + msg: use testify/assert instead + analyze-types: true + gomodguard: + blocked: + modules: + - github.com/pkg/errors: + recommendations: + - errors + govet: + enable: + - shadow + revive: + rules: + # Prefer 'any' type alias over 'interface{}' for Go 1.18+ compatibility + - name: use-any + severity: warning + disabled: false + misspell: + locale: US + varnamelen: + max-distance: 12 + min-name-length: 2 + ignore-type-assert-ok: true + ignore-map-index-ok: true + ignore-chan-recv-ok: true + ignore-decls: + - i int + - n int + - w io.Writer + - r io.Reader + - b []byte + exclusions: + generated: lax + rules: + - linters: + - forbidigo + - gocognit + path: (examples|main\.go) + - linters: + - gocognit + path: _test\.go + - linters: + - forbidigo + path: cmd +formatters: + enable: + - gci # Gci control golang package import order and make it always deterministic. + - gofmt # Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification + - gofumpt # Gofumpt checks whether code was gofumpt-ed. + - goimports # Goimports does everything that gofmt does. Additionally it checks unused imports + exclusions: + generated: lax diff --git a/vendor/github.com/pion/dtls/v3/.goreleaser.yml b/vendor/github.com/pion/dtls/v3/.goreleaser.yml new file mode 100644 index 0000000..8577d86 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/.goreleaser.yml @@ -0,0 +1,5 @@ +# SPDX-FileCopyrightText: 2026 The Pion community +# SPDX-License-Identifier: MIT + +builds: +- skip: true diff --git a/vendor/github.com/pion/dtls/v3/LICENSE b/vendor/github.com/pion/dtls/v3/LICENSE new file mode 100644 index 0000000..d96df05 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/LICENSE @@ -0,0 +1,9 @@ +MIT License + +Copyright (c) 2026 The Pion community + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/vendor/github.com/pion/dtls/v3/README.md b/vendor/github.com/pion/dtls/v3/README.md new file mode 100644 index 0000000..a73a197 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/README.md @@ -0,0 +1,159 @@ +

+
+ Pion DTLS +
+

+

A Go implementation of DTLS

+

+ Pion DTLS + Sourcegraph Widget + join us on Discord Follow us on Bluesky +
+ GitHub Workflow Status + Go Reference + Coverage Status + Go Report Card + License: MIT +

+
+ +Native [DTLS 1.2][rfc6347] implementation in the Go programming language. + +A long term goal is a professional security review, and maybe an inclusion in stdlib. + +### RFCs +#### Implemented +- **RFC 6347**: [Datagram Transport Layer Security Version 1.2][rfc6347] +- **RFC 5705**: [Keying Material Exporters for Transport Layer Security (TLS)][rfc5705] +- **RFC 7627**: [Transport Layer Security (TLS) - Session Hash and Extended Master Secret Extension][rfc7627] +- **RFC 7301**: [Transport Layer Security (TLS) - Application-Layer Protocol Negotiation Extension][rfc7301] + +[rfc5289]: https://tools.ietf.org/html/rfc5289 +[rfc5487]: https://tools.ietf.org/html/rfc5487 +[rfc5489]: https://tools.ietf.org/html/rfc5489 +[rfc5705]: https://tools.ietf.org/html/rfc5705 +[rfc6347]: https://tools.ietf.org/html/rfc6347 +[rfc6655]: https://tools.ietf.org/html/rfc6655 +[rfc7301]: https://tools.ietf.org/html/rfc7301 +[rfc7627]: https://tools.ietf.org/html/rfc7627 +[rfc8422]: https://tools.ietf.org/html/rfc8422 +[rfc9147]: https://tools.ietf.org/html/rfc9147 + +### Goals/Progress +This will only be targeting DTLS 1.2, and the most modern/common cipher suites. +We would love contributions that fall under the 'Planned Features' and any bug fixes! + +#### Current features +* DTLS 1.2 Client/Server +* Key Exchange via ECDHE(curve25519, nistp256, nistp384) and PSK +* Packet loss and re-ordering is handled during handshaking +* Key export ([RFC 5705][rfc5705]) +* Serialization and Resumption of sessions +* Extended Master Secret extension ([RFC 7627][rfc7627]) +* ALPN extension ([RFC 7301][rfc7301]) + +#### Supported ciphers + +##### ECDHE + +* TLS_ECDHE_ECDSA_WITH_AES_128_CCM ([RFC 6655][rfc6655]) +* TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 ([RFC 6655][rfc6655]) +* TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 ([RFC 5289][rfc5289]) +* TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 ([RFC 5289][rfc5289]) +* TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 ([RFC 5289][rfc5289]) +* TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 ([RFC 5289][rfc5289]) +* TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA ([RFC 8422][rfc8422]) +* TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA ([RFC 8422][rfc8422]) + +##### PSK + +* TLS_PSK_WITH_AES_128_CCM ([RFC 6655][rfc6655]) +* TLS_PSK_WITH_AES_128_CCM_8 ([RFC 6655][rfc6655]) +* TLS_PSK_WITH_AES_256_CCM_8 ([RFC 6655][rfc6655]) +* TLS_PSK_WITH_AES_128_GCM_SHA256 ([RFC 5487][rfc5487]) +* TLS_PSK_WITH_AES_128_CBC_SHA256 ([RFC 5487][rfc5487]) + +##### ECDHE & PSK + +* TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256 ([RFC 5489][rfc5489]) + +#### Planned Features +* DTLS 1.3 ([RFC 9147][rfc9147]) +* Chacha20Poly1305 + +#### Excluded Features +* DTLS 1.0 +* Renegotiation +* Compression + +### Using + +This library needs at least Go 1.21, and you should have [Go modules +enabled](https://github.com/golang/go/wiki/Modules). + +#### Pion DTLS +For a DTLS 1.2 Server that listens on 127.0.0.1:4444 +```sh +go run examples/listen/selfsign/main.go +``` + +For a DTLS 1.2 Client that connects to 127.0.0.1:4444 +```sh +go run examples/dial/selfsign/main.go +``` + +#### OpenSSL +Pion DTLS can connect to itself and OpenSSL. +``` + // Generate a certificate + openssl ecparam -out key.pem -name prime256v1 -genkey + openssl req -new -sha256 -key key.pem -out server.csr + openssl x509 -req -sha256 -days 365 -in server.csr -signkey key.pem -out cert.pem + + // Use with examples/dial/selfsign/main.go + openssl s_server -dtls1_2 -cert cert.pem -key key.pem -accept 4444 + + // Use with examples/listen/selfsign/main.go + openssl s_client -dtls1_2 -connect 127.0.0.1:4444 -debug -cert cert.pem -key key.pem +``` + +### Using with PSK +Pion DTLS also comes with examples that do key exchange via PSK + +#### Pion DTLS +```sh +go run examples/listen/psk/main.go +``` + +```sh +go run examples/dial/psk/main.go +``` + +#### OpenSSL +``` + // Use with examples/dial/psk/main.go + openssl s_server -dtls1_2 -accept 4444 -nocert -psk abc123 -cipher PSK-AES128-CCM8 + + // Use with examples/listen/psk/main.go + openssl s_client -dtls1_2 -connect 127.0.0.1:4444 -psk abc123 -cipher PSK-AES128-CCM8 +``` + +### Community +Pion has an active community on the [Discord](https://discord.gg/PngbdqpFbt). + +Follow the [Pion Bluesky](https://bsky.app/profile/pion.ly) or [Pion Twitter](https://twitter.com/_pion) for project updates and important WebRTC news. + +We are always looking to support **your projects**. Please reach out if you have something to build! +If you need commercial support or don't want to use public methods you can contact us at [team@pion.ly](mailto:team@pion.ly) + +### Contributing +Check out the [contributing wiki](https://github.com/pion/webrtc/wiki/Contributing) to join the group of amazing people making this project possible + +### Funding +NLnet foundation logo +NLnet foundation logo + +The DTLS 1.3 implementation in this project is funded through the [NGI0 Commons Fund](https://nlnet.nl/commonsfund), a fund established by [NLnet](https://nlnet.nl/) with financial support from the European Commission's [Next Generation Internet](https://ngi.eu/) programme, under the aegis of [DG Communications Networks, Content and Technology](https://commission.europa.eu/about-european-commission/departments-and-executive-agencies/communications-networks-content-and-technology_en) under grant agreement No [101135429](https://cordis.europa.eu/project/id/101135429). Additional funding is made available by the [Swiss State Secretariat for Education, Research and Innovation](https://www.sbfi.admin.ch/sbfi/en/home.html) (SERI). Learn more on the [NLnet project page](https://nlnet.nl/project/PION-DTLS1.3/). + +### License +MIT License - see [LICENSE](LICENSE) for full text diff --git a/vendor/github.com/pion/dtls/v3/certificate.go b/vendor/github.com/pion/dtls/v3/certificate.go new file mode 100644 index 0000000..7aeb49e --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/certificate.go @@ -0,0 +1,167 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package dtls + +import ( + "bytes" + "crypto/tls" + "crypto/x509" + "fmt" + "strings" + + "github.com/pion/dtls/v3/pkg/protocol/handshake" +) + +// ClientHelloInfo contains information from a ClientHello message in order to +// guide application logic in the GetCertificate. +type ClientHelloInfo struct { + // ServerName indicates the name of the server requested by the client + // in order to support virtual hosting. ServerName is only set if the + // client is using SNI (see RFC 4366, Section 3.1). + ServerName string + + // CipherSuites lists the CipherSuites supported by the client (e.g. + // TLS_AES_128_GCM_SHA256, TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256). + CipherSuites []CipherSuiteID + + // RandomBytes stores the client hello random bytes + RandomBytes [handshake.RandomBytesLength]byte +} + +// CertificateRequestInfo contains information from a server's +// CertificateRequest message, which is used to demand a certificate and proof +// of control from a client. +type CertificateRequestInfo struct { + // AcceptableCAs contains zero or more, DER-encoded, X.501 + // Distinguished Names. These are the names of root or intermediate CAs + // that the server wishes the returned certificate to be signed by. An + // empty slice indicates that the server has no preference. + AcceptableCAs [][]byte +} + +// SupportsCertificate returns nil if the provided certificate is supported by +// the server that sent the CertificateRequest. Otherwise, it returns an error +// describing the reason for the incompatibility. +// NOTE: original src: +// https://github.com/golang/go/blob/29b9a328d268d53833d2cc063d1d8b4bf6852675/src/crypto/tls/common.go#L1273 +func (cri *CertificateRequestInfo) SupportsCertificate(c *tls.Certificate) error { + if len(cri.AcceptableCAs) == 0 { + return nil + } + + for j, cert := range c.Certificate { + x509Cert := c.Leaf + // Parse the certificate if this isn't the leaf node, or if + // chain.Leaf was nil. + if j != 0 || x509Cert == nil { + var err error + if x509Cert, err = x509.ParseCertificate(cert); err != nil { + return fmt.Errorf("failed to parse certificate #%d in the chain: %w", j, err) + } + } + + for _, ca := range cri.AcceptableCAs { + if bytes.Equal(x509Cert.RawIssuer, ca) { + return nil + } + } + } + + return errNotAcceptableCertificateChain +} + +func (c *handshakeConfig) setNameToCertificateLocked() { + nameToCertificate := make(map[string]*tls.Certificate) + for i := range c.localCertificates { + cert := &c.localCertificates[i] + x509Cert := cert.Leaf + if x509Cert == nil { + var parseErr error + x509Cert, parseErr = x509.ParseCertificate(cert.Certificate[0]) + if parseErr != nil { + continue + } + } + if len(x509Cert.Subject.CommonName) > 0 { + nameToCertificate[strings.ToLower(x509Cert.Subject.CommonName)] = cert + } + for _, san := range x509Cert.DNSNames { + nameToCertificate[strings.ToLower(san)] = cert + } + } + c.nameToCertificate = nameToCertificate +} + +//nolint:cyclop +func (c *handshakeConfig) getCertificate(clientHelloInfo *ClientHelloInfo) (*tls.Certificate, error) { + c.mu.Lock() + defer c.mu.Unlock() + + if c.localGetCertificate != nil && + (len(c.localCertificates) == 0 || len(clientHelloInfo.ServerName) > 0) { + cert, err := c.localGetCertificate(clientHelloInfo) + if cert != nil || err != nil { + return cert, err + } + } + + if c.nameToCertificate == nil { + c.setNameToCertificateLocked() + } + + if len(c.localCertificates) == 0 { + return nil, errNoCertificates + } + + if len(c.localCertificates) == 1 { + // There's only one choice, so no point doing any work. + return &c.localCertificates[0], nil + } + + if len(clientHelloInfo.ServerName) == 0 { + return &c.localCertificates[0], nil + } + + name := strings.TrimRight(strings.ToLower(clientHelloInfo.ServerName), ".") + + if cert, ok := c.nameToCertificate[name]; ok { + return cert, nil + } + + // try replacing labels in the name with wildcards until we get a + // match. + labels := strings.Split(name, ".") + for i := range labels { + labels[i] = "*" + candidate := strings.Join(labels, ".") + if cert, ok := c.nameToCertificate[candidate]; ok { + return cert, nil + } + } + + // If nothing matches, return the first certificate. + return &c.localCertificates[0], nil +} + +// NOTE: original src: +// https://github.com/golang/go/blob/29b9a328d268d53833d2cc063d1d8b4bf6852675/src/crypto/tls/handshake_client.go#L974 +func (c *handshakeConfig) getClientCertificate(cri *CertificateRequestInfo) (*tls.Certificate, error) { + c.mu.Lock() + defer c.mu.Unlock() + if c.localGetClientCertificate != nil { + return c.localGetClientCertificate(cri) + } + + for i := range c.localCertificates { + chain := c.localCertificates[i] + if err := cri.SupportsCertificate(&chain); err != nil { + continue + } + + return &chain, nil + } + + // No acceptable certificate found. Don't send a certificate. + return new(tls.Certificate), nil +} diff --git a/vendor/github.com/pion/dtls/v3/cipher_suite.go b/vendor/github.com/pion/dtls/v3/cipher_suite.go new file mode 100644 index 0000000..a53202f --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/cipher_suite.go @@ -0,0 +1,295 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package dtls + +import ( + "crypto" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/rsa" + "crypto/tls" + "fmt" + "hash" + + "github.com/pion/dtls/v3/internal/ciphersuite" + "github.com/pion/dtls/v3/pkg/crypto/clientcertificate" + "github.com/pion/dtls/v3/pkg/protocol/recordlayer" +) + +// CipherSuiteID is an ID for our supported CipherSuites. +type CipherSuiteID = ciphersuite.ID + +// Supported Cipher Suites. +const ( + + // nolint: godot + // AES-128-CCM + TLS_ECDHE_ECDSA_WITH_AES_128_CCM CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_128_CCM // nolint: revive,staticcheck,lll + TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 // nolint: revive,staticcheck,lll + + // nolint: godot + // AES-128-GCM-SHA256 + TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 // nolint: revive,staticcheck,lll + TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 CipherSuiteID = ciphersuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 // nolint: revive,staticcheck,lll + + TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 // nolint: revive,staticcheck,lll + TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 CipherSuiteID = ciphersuite.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 // nolint: revive,staticcheck,lll + + // nolint: godot + // AES-256-CBC-SHA + TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA // nolint: revive,staticcheck,lll + TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA CipherSuiteID = ciphersuite.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA // nolint: revive,staticcheck,lll + + TLS_PSK_WITH_AES_128_CCM CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_128_CCM // nolint: revive,staticcheck,lll + TLS_PSK_WITH_AES_128_CCM_8 CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_128_CCM_8 // nolint: revive,staticcheck,lll + TLS_PSK_WITH_AES_256_CCM_8 CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_256_CCM_8 // nolint: revive,staticcheck,lll + TLS_PSK_WITH_AES_128_GCM_SHA256 CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_128_GCM_SHA256 // nolint: revive,staticcheck,lll + TLS_PSK_WITH_AES_128_CBC_SHA256 CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_128_CBC_SHA256 // nolint: revive,staticcheck,lll + + TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256 CipherSuiteID = ciphersuite.TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256 // nolint: revive,staticcheck,lll +) + +// CipherSuiteAuthenticationType controls what authentication method is using during the handshake for a CipherSuite. +type CipherSuiteAuthenticationType = ciphersuite.AuthenticationType + +// AuthenticationType Enums. +const ( + CipherSuiteAuthenticationTypeCertificate CipherSuiteAuthenticationType = ciphersuite.AuthenticationTypeCertificate + CipherSuiteAuthenticationTypePreSharedKey CipherSuiteAuthenticationType = ciphersuite.AuthenticationTypePreSharedKey + CipherSuiteAuthenticationTypeAnonymous CipherSuiteAuthenticationType = ciphersuite.AuthenticationTypeAnonymous +) + +// CipherSuiteKeyExchangeAlgorithm controls what exchange algorithm is using during the handshake for a CipherSuite. +type CipherSuiteKeyExchangeAlgorithm = ciphersuite.KeyExchangeAlgorithm + +// CipherSuiteKeyExchangeAlgorithm Bitmask. +const ( + CipherSuiteKeyExchangeAlgorithmNone CipherSuiteKeyExchangeAlgorithm = ciphersuite.KeyExchangeAlgorithmNone + CipherSuiteKeyExchangeAlgorithmPsk CipherSuiteKeyExchangeAlgorithm = ciphersuite.KeyExchangeAlgorithmPsk + CipherSuiteKeyExchangeAlgorithmEcdhe CipherSuiteKeyExchangeAlgorithm = ciphersuite.KeyExchangeAlgorithmEcdhe +) + +var _ = allCipherSuites() // Necessary until this function isn't only used by Go 1.14 + +// CipherSuite is an interface that all DTLS CipherSuites must satisfy. +type CipherSuite interface { + // String of CipherSuite, only used for logging + String() string + + // ID of CipherSuite. + ID() CipherSuiteID + + // What type of Certificate does this CipherSuite use + CertificateType() clientcertificate.Type + + // What Hash function is used during verification + HashFunc() func() hash.Hash + + // AuthenticationType controls what authentication method is using during the handshake + AuthenticationType() CipherSuiteAuthenticationType + + // KeyExchangeAlgorithm controls what exchange algorithm is using during the handshake + KeyExchangeAlgorithm() CipherSuiteKeyExchangeAlgorithm + + // ECC (Elliptic Curve Cryptography) determines whether ECC extesions will be send during handshake. + // https://datatracker.ietf.org/doc/html/rfc4492#page-10 + ECC() bool + + // Called when keying material has been generated, should initialize the internal cipher + Init(masterSecret, clientRandom, serverRandom []byte, isClient bool) error + IsInitialized() bool + Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error) + Decrypt(h recordlayer.Header, in []byte) ([]byte, error) +} + +// CipherSuiteName provides the same functionality as tls.CipherSuiteName +// that appeared first in Go 1.14. +// +// Our implementation differs slightly in that it takes in a CiperSuiteID, +// like the rest of our library, instead of a uint16 like crypto/tls. +func CipherSuiteName(id CipherSuiteID) string { + suite := cipherSuiteForID(id, nil) + if suite != nil { + return suite.String() + } + + return fmt.Sprintf("0x%04X", uint16(id)) +} + +// Taken from https://www.iana.org/assignments/tls-parameters/tls-parameters.xml +// A cipherSuite is a specific combination of key agreement, cipher and MAC +// function. +func cipherSuiteForID(id CipherSuiteID, customCiphers func() []CipherSuite) CipherSuite { //nolint:cyclop + switch id { //nolint:exhaustive + case TLS_ECDHE_ECDSA_WITH_AES_128_CCM: + return ciphersuite.NewTLSEcdheEcdsaWithAes128Ccm() + case TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8: + return ciphersuite.NewTLSEcdheEcdsaWithAes128Ccm8() + case TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: + return &ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{} + case TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256: + return &ciphersuite.TLSEcdheRsaWithAes128GcmSha256{} + case TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA: + return &ciphersuite.TLSEcdheEcdsaWithAes256CbcSha{} + case TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA: + return &ciphersuite.TLSEcdheRsaWithAes256CbcSha{} + case TLS_PSK_WITH_AES_128_CCM: + return ciphersuite.NewTLSPskWithAes128Ccm() + case TLS_PSK_WITH_AES_128_CCM_8: + return ciphersuite.NewTLSPskWithAes128Ccm8() + case TLS_PSK_WITH_AES_256_CCM_8: + return ciphersuite.NewTLSPskWithAes256Ccm8() + case TLS_PSK_WITH_AES_128_GCM_SHA256: + return &ciphersuite.TLSPskWithAes128GcmSha256{} + case TLS_PSK_WITH_AES_128_CBC_SHA256: + return &ciphersuite.TLSPskWithAes128CbcSha256{} + case TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384: + return &ciphersuite.TLSEcdheEcdsaWithAes256GcmSha384{} + case TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384: + return &ciphersuite.TLSEcdheRsaWithAes256GcmSha384{} + case TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256: + return ciphersuite.NewTLSEcdhePskWithAes128CbcSha256() + } + + if customCiphers != nil { + for _, c := range customCiphers() { + if c.ID() == id { + return c + } + } + } + + return nil +} + +// CipherSuites we support in order of preference. +func defaultCipherSuites() []CipherSuite { + return []CipherSuite{ + &ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{}, + &ciphersuite.TLSEcdheRsaWithAes128GcmSha256{}, + &ciphersuite.TLSEcdheEcdsaWithAes256CbcSha{}, + &ciphersuite.TLSEcdheRsaWithAes256CbcSha{}, + &ciphersuite.TLSEcdheEcdsaWithAes256GcmSha384{}, + &ciphersuite.TLSEcdheRsaWithAes256GcmSha384{}, + } +} + +func allCipherSuites() []CipherSuite { + return []CipherSuite{ + ciphersuite.NewTLSEcdheEcdsaWithAes128Ccm(), + ciphersuite.NewTLSEcdheEcdsaWithAes128Ccm8(), + &ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{}, + &ciphersuite.TLSEcdheRsaWithAes128GcmSha256{}, + &ciphersuite.TLSEcdheEcdsaWithAes256CbcSha{}, + &ciphersuite.TLSEcdheRsaWithAes256CbcSha{}, + ciphersuite.NewTLSPskWithAes128Ccm(), + ciphersuite.NewTLSPskWithAes128Ccm8(), + ciphersuite.NewTLSPskWithAes256Ccm8(), + &ciphersuite.TLSPskWithAes128GcmSha256{}, + &ciphersuite.TLSEcdheEcdsaWithAes256GcmSha384{}, + &ciphersuite.TLSEcdheRsaWithAes256GcmSha384{}, + } +} + +func cipherSuiteIDs(cipherSuites []CipherSuite) []uint16 { + rtrn := []uint16{} + for _, c := range cipherSuites { + rtrn = append(rtrn, uint16(c.ID())) + } + + return rtrn +} + +//nolint:cyclop +func parseCipherSuites( + userSelectedSuites []CipherSuiteID, + customCipherSuites func() []CipherSuite, + includeCertificateSuites, includePSKSuites bool, +) ([]CipherSuite, error) { + cipherSuitesForIDs := func(ids []CipherSuiteID) ([]CipherSuite, error) { + cipherSuites := []CipherSuite{} + for _, id := range ids { + c := cipherSuiteForID(id, nil) + if c == nil { + return nil, &invalidCipherSuiteError{id} + } + cipherSuites = append(cipherSuites, c) + } + + return cipherSuites, nil + } + + var ( + cipherSuites []CipherSuite + err error + i int + ) + if userSelectedSuites != nil { + cipherSuites, err = cipherSuitesForIDs(userSelectedSuites) + if err != nil { + return nil, err + } + } else { + cipherSuites = defaultCipherSuites() + } + + // Put CustomCipherSuites before ID selected suites + if customCipherSuites != nil { + cipherSuites = append(customCipherSuites(), cipherSuites...) + } + + var foundCertificateSuite, foundPSKSuite, foundAnonymousSuite bool + for _, c := range cipherSuites { + switch { + case includeCertificateSuites && c.AuthenticationType() == CipherSuiteAuthenticationTypeCertificate: + foundCertificateSuite = true + case includePSKSuites && c.AuthenticationType() == CipherSuiteAuthenticationTypePreSharedKey: + foundPSKSuite = true + case c.AuthenticationType() == CipherSuiteAuthenticationTypeAnonymous: + foundAnonymousSuite = true + default: + continue + } + cipherSuites[i] = c + i++ + } + + switch { + case includeCertificateSuites && !foundCertificateSuite && !foundAnonymousSuite: + return nil, errNoAvailableCertificateCipherSuite + case includePSKSuites && !foundPSKSuite: + return nil, errNoAvailablePSKCipherSuite + case i == 0: + return nil, errNoAvailableCipherSuites + } + + return cipherSuites[:i], nil +} + +func filterCipherSuitesForCertificate(cert *tls.Certificate, cipherSuites []CipherSuite) []CipherSuite { + if cert == nil || cert.PrivateKey == nil { + return cipherSuites + } + signer, ok := cert.PrivateKey.(crypto.Signer) + if !ok { + return cipherSuites + } + + var certType clientcertificate.Type + switch signer.Public().(type) { + case ed25519.PublicKey, *ecdsa.PublicKey: + certType = clientcertificate.ECDSASign + case *rsa.PublicKey: + certType = clientcertificate.RSASign + } + + filtered := []CipherSuite{} + for _, c := range cipherSuites { + if c.AuthenticationType() != CipherSuiteAuthenticationTypeCertificate || certType == c.CertificateType() { + filtered = append(filtered, c) + } + } + + return filtered +} diff --git a/vendor/github.com/pion/dtls/v3/cipher_suite_go114.go b/vendor/github.com/pion/dtls/v3/cipher_suite_go114.go new file mode 100644 index 0000000..4017462 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/cipher_suite_go114.go @@ -0,0 +1,46 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +//go:build go1.14 +// +build go1.14 + +package dtls + +import ( + "crypto/tls" +) + +// VersionDTLS12 is the DTLS version in the same style as +// VersionTLSXX from crypto/tls. +const VersionDTLS12 = 0xfefd + +// Convert from our cipherSuite interface to a tls.CipherSuite struct. +func toTLSCipherSuite(c CipherSuite) *tls.CipherSuite { + return &tls.CipherSuite{ + ID: uint16(c.ID()), + Name: c.String(), + SupportedVersions: []uint16{VersionDTLS12}, + Insecure: false, + } +} + +// CipherSuites returns a list of cipher suites currently implemented by this +// package, excluding those with security issues, which are returned by +// InsecureCipherSuites. +func CipherSuites() []*tls.CipherSuite { + suites := allCipherSuites() + res := make([]*tls.CipherSuite, len(suites)) + for i, c := range suites { + res[i] = toTLSCipherSuite(c) + } + + return res +} + +// InsecureCipherSuites returns a list of cipher suites currently implemented by +// this package and which have security issues. +func InsecureCipherSuites() []*tls.CipherSuite { + var res []*tls.CipherSuite + + return res +} diff --git a/vendor/github.com/pion/dtls/v3/codecov.yml b/vendor/github.com/pion/dtls/v3/codecov.yml new file mode 100644 index 0000000..b9639c2 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/codecov.yml @@ -0,0 +1,22 @@ +# +# DO NOT EDIT THIS FILE +# +# It is automatically copied from https://github.com/pion/.goassets repository. +# +# SPDX-FileCopyrightText: 2026 The Pion community +# SPDX-License-Identifier: MIT + +coverage: + status: + project: + default: + # Allow decreasing 2% of total coverage to avoid noise. + threshold: 2% + patch: + default: + target: 70% + only_pulls: true + +ignore: + - "examples/*" + - "examples/**/*" diff --git a/vendor/github.com/pion/dtls/v3/compression_method.go b/vendor/github.com/pion/dtls/v3/compression_method.go new file mode 100644 index 0000000..1b93599 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/compression_method.go @@ -0,0 +1,12 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package dtls + +import "github.com/pion/dtls/v3/pkg/protocol" + +func defaultCompressionMethods() []*protocol.CompressionMethod { + return []*protocol.CompressionMethod{ + {}, + } +} diff --git a/vendor/github.com/pion/dtls/v3/config.go b/vendor/github.com/pion/dtls/v3/config.go new file mode 100644 index 0000000..77f977d --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/config.go @@ -0,0 +1,306 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package dtls + +import ( + "crypto" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "io" + "net" + "time" + + "github.com/pion/dtls/v3/pkg/crypto/elliptic" + "github.com/pion/dtls/v3/pkg/protocol/handshake" + "github.com/pion/logging" +) + +const keyLogLabelTLS12 = "CLIENT_RANDOM" + +// Config is used to configure a DTLS client or server. +// After a Config is passed to a DTLS function it must not be modified. +// +// Deprecated: prefer the options-based APIs (`*WithOptions`) to construct immutable configs, +// This will be removed in the next major version. +type Config struct { //nolint:dupl + // Certificates contains certificate chain to present to the other side of the connection. + // Server MUST set this if PSK is non-nil + // client SHOULD sets this so CertificateRequests can be handled if PSK is non-nil + Certificates []tls.Certificate + + // CipherSuites is a list of supported cipher suites. + // If CipherSuites is nil, a default list is used + CipherSuites []CipherSuiteID + + // CustomCipherSuites is a list of CipherSuites that can be + // provided by the user. This allow users to user Ciphers that are reserved + // for private usage. + CustomCipherSuites func() []CipherSuite + + // SignatureSchemes contains the signature and hash schemes that the peer requests to verify. + SignatureSchemes []tls.SignatureScheme + + // CertificateSignatureSchemes contains the signature and hash schemes that may be used + // in digital signatures for X.509 certificates. If not set, the signature_algorithms_cert + // extension is not sent, and SignatureSchemes is used for both handshake signatures and + // certificate chain validation, as specified in RFC 8446 Section 4.2.3. + CertificateSignatureSchemes []tls.SignatureScheme + + // SRTPProtectionProfiles are the supported protection profiles + // Clients will send this via use_srtp and assert that the server properly responds + // Servers will assert that clients send one of these profiles and will respond as needed + SRTPProtectionProfiles []SRTPProtectionProfile + + // SRTPMasterKeyIdentifier value (if any) is sent via the use_srtp + // extension for Clients and Servers + SRTPMasterKeyIdentifier []byte + + // ClientAuth determines the server's policy for + // TLS Client Authentication. The default is NoClientCert. + ClientAuth ClientAuthType + + // RequireExtendedMasterSecret determines if the "Extended Master Secret" extension + // should be disabled, requested, or required (default requested). + ExtendedMasterSecret ExtendedMasterSecretType + + // FlightInterval controls how often we send outbound handshake messages + // defaults to time.Second + FlightInterval time.Duration + + // DisableRetransmitBackoff can be used to the disable the backoff feature + // when sending outbound messages as specified in RFC 4347 4.2.4.1 + DisableRetransmitBackoff bool + + // PSK sets the pre-shared key used by this DTLS connection + // If PSK is non-nil only PSK CipherSuites will be used + PSK PSKCallback + PSKIdentityHint []byte + + // InsecureSkipVerify controls whether a client verifies the + // server's certificate chain and host name. + // If InsecureSkipVerify is true, TLS accepts any certificate + // presented by the server and any host name in that certificate. + // In this mode, TLS is susceptible to man-in-the-middle attacks. + // This should be used only for testing. + InsecureSkipVerify bool + + // InsecureHashes allows the use of hashing algorithms that are known + // to be vulnerable. + InsecureHashes bool + + // VerifyPeerCertificate, if not nil, is called after normal + // certificate verification by either a client or server. It + // receives the certificate provided by the peer and also a flag + // that tells if normal verification has succeedded. If it returns a + // non-nil error, the handshake is aborted and that error results. + // + // If normal verification fails then the handshake will abort before + // considering this callback. If normal verification is disabled by + // setting InsecureSkipVerify, or (for a server) when ClientAuth is + // RequestClientCert or RequireAnyClientCert, then this callback will + // be considered but the verifiedChains will always be nil. + VerifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error + + // VerifyConnection, if not nil, is called after normal certificate + // verification/PSK and after VerifyPeerCertificate by either a TLS client + // or server. If it returns a non-nil error, the handshake is aborted + // and that error results. + // + // If normal verification fails then the handshake will abort before + // considering this callback. This callback will run for all connections + // regardless of InsecureSkipVerify or ClientAuth settings. + VerifyConnection func(*State) error + + // RootCAs defines the set of root certificate authorities + // that one peer uses when verifying the other peer's certificates. + // If RootCAs is nil, TLS uses the host's root CA set. + RootCAs *x509.CertPool + + // ClientCAs defines the set of root certificate authorities + // that servers use if required to verify a client certificate + // by the policy in ClientAuth. + ClientCAs *x509.CertPool + + // ServerName is used to verify the hostname on the returned + // certificates unless InsecureSkipVerify is given. + ServerName string + + LoggerFactory logging.LoggerFactory + + // MTU is the length at which handshake messages will be fragmented to + // fit within the maximum transmission unit (default is 1200 bytes) + MTU int + + // ReplayProtectionWindow is the size of the replay attack protection window. + // Duplication of the sequence number is checked in this window size. + // Packet with sequence number older than this value compared to the latest + // accepted packet will be discarded. (default is 64) + ReplayProtectionWindow int + + // KeyLogWriter optionally specifies a destination for TLS master secrets + // in NSS key log format that can be used to allow external programs + // such as Wireshark to decrypt TLS connections. + // See https://developer.mozilla.org/en-US/docs/Mozilla/Projects/NSS/Key_Log_Format. + // Use of KeyLogWriter compromises security and should only be + // used for debugging. + KeyLogWriter io.Writer + + // SessionStore is the container to store session for resumption. + SessionStore SessionStore + + // List of application protocols the peer supports, for ALPN + SupportedProtocols []string + + // List of Elliptic Curves to use + // + // If an ECC ciphersuite is configured and EllipticCurves is empty + // it will default to X25519, P-256, P-384 in this specific order. + EllipticCurves []elliptic.Curve + + // GetCertificate returns a Certificate based on the given + // ClientHelloInfo. It will only be called if the client supplies SNI + // information or if Certificates is empty. + // + // If GetCertificate is nil or returns nil, then the certificate is + // retrieved from NameToCertificate. If NameToCertificate is nil, the + // best element of Certificates will be used. + GetCertificate func(*ClientHelloInfo) (*tls.Certificate, error) + + // GetClientCertificate, if not nil, is called when a server requests a + // certificate from a client. If set, the contents of Certificates will + // be ignored. + // + // If GetClientCertificate returns an error, the handshake will be + // aborted and that error will be returned. Otherwise + // GetClientCertificate must return a non-nil Certificate. If + // Certificate.Certificate is empty then no certificate will be sent to + // the server. If this is unacceptable to the server then it may abort + // the handshake. + GetClientCertificate func(*CertificateRequestInfo) (*tls.Certificate, error) + + // InsecureSkipVerifyHello, if true and when acting as server, allow client to + // skip hello verify phase and receive ServerHello after initial ClientHello. + // This have implication on DoS attack resistance. + InsecureSkipVerifyHello bool + + // ConnectionIDGenerator generates connection identifiers that should be + // sent by the remote party if it supports the DTLS Connection Identifier + // extension, as determined during the handshake. Generated connection + // identifiers must always have the same length. Returning a zero-length + // connection identifier indicates that the local party supports sending + // connection identifiers but does not require the remote party to send + // them. A nil ConnectionIDGenerator indicates that connection identifiers + // are not supported. + // https://datatracker.ietf.org/doc/html/rfc9146 + ConnectionIDGenerator func() []byte + + // PaddingLengthGenerator generates the number of padding bytes used to + // inflate ciphertext size in order to obscure content size from observers. + // The length of the content is passed to the generator such that both + // deterministic and random padding schemes can be applied while not + // exceeding maximum record size. + // If no PaddingLengthGenerator is specified, padding will not be applied. + // https://datatracker.ietf.org/doc/html/rfc9146#section-4 + PaddingLengthGenerator func(uint) uint + + // HelloRandomBytesGenerator generates custom client hello random bytes. + HelloRandomBytesGenerator func() [handshake.RandomBytesLength]byte + + // Handshake hooks: hooks can be used for testing invalid messages, + // mimicking other implementations or randomizing fields, which is valuable + // for applications that need censorship-resistance by making + // fingerprinting more difficult. + + // ClientHelloMessageHook, if not nil, is called when a Client Hello message is sent + // from a client. The returned handshake message replaces the original message. + ClientHelloMessageHook func(handshake.MessageClientHello) handshake.Message + + // ServerHelloMessageHook, if not nil, is called when a Server Hello message is sent + // from a server. The returned handshake message replaces the original message. + ServerHelloMessageHook func(handshake.MessageServerHello) handshake.Message + + // CertificateRequestMessageHook, if not nil, is called when a Certificate Request + // message is sent from a server. The returned handshake message replaces the original message. + CertificateRequestMessageHook func(handshake.MessageCertificateRequest) handshake.Message + + // OnConnectionAttempt is fired Whenever a connection attempt is made, + // the server or application can call this callback function. + // The callback function can then implement logic to handle the connection attempt, such as logging the attempt, + // checking against a list of blocked IPs, or counting the attempts to prevent brute force attacks. + // If the callback function returns an error, the connection attempt will be aborted. + OnConnectionAttempt func(net.Addr) error +} + +func (c *Config) includeCertificateSuites() bool { + return c.PSK == nil || len(c.Certificates) > 0 || c.GetCertificate != nil || c.GetClientCertificate != nil +} + +const defaultMTU = 1200 // bytes + +var defaultCurves = []elliptic.Curve{elliptic.X25519, elliptic.P256, elliptic.P384} //nolint:gochecknoglobals + +// PSKCallback is called once we have the remote's PSKIdentityHint. +// If the remote provided none it will be nil. +type PSKCallback func([]byte) ([]byte, error) + +// ClientAuthType declares the policy the server will follow for +// TLS Client Authentication. +type ClientAuthType int + +// ClientAuthType enums. +const ( + NoClientCert ClientAuthType = iota + RequestClientCert + RequireAnyClientCert + VerifyClientCertIfGiven + RequireAndVerifyClientCert +) + +// ExtendedMasterSecretType declares the policy the client and server +// will follow for the Extended Master Secret extension. +type ExtendedMasterSecretType int + +// ExtendedMasterSecretType enums. +const ( + RequestExtendedMasterSecret ExtendedMasterSecretType = iota + RequireExtendedMasterSecret + DisableExtendedMasterSecret +) + +func validateConfig(config *Config) error { //nolint:cyclop + switch { + case config == nil: + return errNoConfigProvided + case config.PSKIdentityHint != nil && config.PSK == nil: + return errIdentityNoPSK + } + + for _, cert := range config.Certificates { + if cert.Certificate == nil { + return errInvalidCertificate + } + if cert.PrivateKey != nil { + signer, ok := cert.PrivateKey.(crypto.Signer) + if !ok { + return errInvalidPrivateKey + } + switch signer.Public().(type) { + case ed25519.PublicKey: + case *ecdsa.PublicKey: + case *rsa.PublicKey: + default: + return errInvalidPrivateKey + } + } + } + + _, err := parseCipherSuites( + config.CipherSuites, config.CustomCipherSuites, config.includeCertificateSuites(), config.PSK != nil, + ) + + return err +} diff --git a/vendor/github.com/pion/dtls/v3/conn.go b/vendor/github.com/pion/dtls/v3/conn.go new file mode 100644 index 0000000..e08151f --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/conn.go @@ -0,0 +1,1397 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package dtls + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "net" + "sync" + "sync/atomic" + "time" + + "github.com/pion/dtls/v3/internal/closer" + "github.com/pion/dtls/v3/pkg/crypto/elliptic" + "github.com/pion/dtls/v3/pkg/crypto/signaturehash" + "github.com/pion/dtls/v3/pkg/protocol" + "github.com/pion/dtls/v3/pkg/protocol/alert" + "github.com/pion/dtls/v3/pkg/protocol/handshake" + "github.com/pion/dtls/v3/pkg/protocol/recordlayer" + "github.com/pion/logging" + "github.com/pion/transport/v4/deadline" + "github.com/pion/transport/v4/netctx" + "github.com/pion/transport/v4/replaydetector" +) + +const ( + initialTickerInterval = time.Second + cookieLength = 20 + sessionLength = 32 + defaultNamedCurve = elliptic.X25519 + inboundBufferSize = 8192 + // Default replay protection window is specified by RFC 6347 Section 4.1.2.6. + defaultReplayProtectionWindow = 64 + // maxAppDataPacketQueueSize is the maximum number of app data packets we will. + // enqueue before the handshake is completed. + maxAppDataPacketQueueSize = 100 +) + +func invalidKeyingLabels() map[string]bool { + return map[string]bool{ + "client finished": true, + "server finished": true, + "master secret": true, + "key expansion": true, + } +} + +type addrPkt struct { + rAddr net.Addr + data []byte +} + +type recvHandshakeState struct { + done chan struct{} + isRetransmit bool +} + +// Conn represents a DTLS connection. +type Conn struct { + lock sync.RWMutex // Internal lock (must not be public) + nextConn netctx.PacketConn // Embedded Conn, typically a udpconn we read/write from + fragmentBuffer *fragmentBuffer // out-of-order and missing fragment handling + handshakeCache *handshakeCache // caching of handshake messages for verifyData generation + decrypted chan any // Decrypted Application Data or error, pull by calling `Read` + rAddr net.Addr + state State // Internal state + + maximumTransmissionUnit int + paddingLengthGenerator func(uint) uint + + handshakeCompletedSuccessfully atomic.Bool + handshakeMutex sync.Mutex + handshakeDone chan struct{} + + encryptedPackets []addrPkt + + connectionClosedByUser bool + closeLock sync.Mutex + closed *closer.Closer + + readDeadline *deadline.Deadline + writeDeadline *deadline.Deadline + + log logging.LeveledLogger + + reading chan struct{} + handshakeRecv chan recvHandshakeState + cancelHandshaker func() + cancelHandshakeReader func() + + fsm *handshakeFSM + + replayProtectionWindow uint + + handshakeConfig *handshakeConfig +} + +// createConn creates a new DTLS connection. +// Caller is responsible for validating the config before calling this function. +// +//nolint:cyclop +func createConn( + nextConn net.PacketConn, + rAddr net.Addr, + config *Config, + isClient bool, + resumeState *State, +) (*Conn, error) { + if nextConn == nil { + return nil, errNilNextConn + } + + loggerFactory := config.LoggerFactory + if loggerFactory == nil { + loggerFactory = logging.NewDefaultLoggerFactory() + } + + logger := loggerFactory.NewLogger("dtls") + + mtu := config.MTU + if mtu <= 0 { + mtu = defaultMTU + } + + replayProtectionWindow := config.ReplayProtectionWindow + if replayProtectionWindow <= 0 { + replayProtectionWindow = defaultReplayProtectionWindow + } + + paddingLengthGenerator := config.PaddingLengthGenerator + if paddingLengthGenerator == nil { + paddingLengthGenerator = func(uint) uint { return 0 } + } + + cipherSuites, err := parseCipherSuites( + config.CipherSuites, + config.CustomCipherSuites, + config.includeCertificateSuites(), + config.PSK != nil, + ) + if err != nil { + return nil, err + } + + signatureSchemes, err := signaturehash.ParseSignatureSchemes(config.SignatureSchemes, config.InsecureHashes) + if err != nil { + return nil, err + } + + // Parse certificate signature schemes only if explicitly configured + var certSignatureSchemes []signaturehash.Algorithm + if len(config.CertificateSignatureSchemes) > 0 { + certSignatureSchemes, err = signaturehash.ParseSignatureSchemes( + config.CertificateSignatureSchemes, + config.InsecureHashes, + ) + if err != nil { + return nil, err + } + } + + workerInterval := initialTickerInterval + if config.FlightInterval > 0 { + workerInterval = config.FlightInterval + } + + serverName := config.ServerName + // Do not allow the use of an IP address literal as an SNI value. + // See RFC 6066, Section 3. + if net.ParseIP(serverName) != nil { + serverName = "" + } + + curves := config.EllipticCurves + if len(curves) == 0 { + curves = defaultCurves + } + + handshakeConfig := &handshakeConfig{ + localPSKCallback: config.PSK, + localPSKIdentityHint: config.PSKIdentityHint, + localCipherSuites: cipherSuites, + localSignatureSchemes: signatureSchemes, + localCertSignatureSchemes: certSignatureSchemes, + extendedMasterSecret: config.ExtendedMasterSecret, + localSRTPProtectionProfiles: config.SRTPProtectionProfiles, + localSRTPMasterKeyIdentifier: config.SRTPMasterKeyIdentifier, + serverName: serverName, + supportedProtocols: config.SupportedProtocols, + clientAuth: config.ClientAuth, + localCertificates: config.Certificates, + insecureSkipVerify: config.InsecureSkipVerify, + verifyPeerCertificate: config.VerifyPeerCertificate, + verifyConnection: config.VerifyConnection, + rootCAs: config.RootCAs, + clientCAs: config.ClientCAs, + customCipherSuites: config.CustomCipherSuites, + initialRetransmitInterval: workerInterval, + disableRetransmitBackoff: config.DisableRetransmitBackoff, + log: logger, + initialEpoch: 0, + keyLogWriter: config.KeyLogWriter, + sessionStore: config.SessionStore, + ellipticCurves: curves, + localGetCertificate: config.GetCertificate, + localGetClientCertificate: config.GetClientCertificate, + insecureSkipHelloVerify: config.InsecureSkipVerifyHello, + connectionIDGenerator: config.ConnectionIDGenerator, + helloRandomBytesGenerator: config.HelloRandomBytesGenerator, + clientHelloMessageHook: config.ClientHelloMessageHook, + serverHelloMessageHook: config.ServerHelloMessageHook, + certificateRequestMessageHook: config.CertificateRequestMessageHook, + resumeState: resumeState, + } + + conn := &Conn{ + rAddr: rAddr, + nextConn: netctx.NewPacketConn(nextConn), + handshakeConfig: handshakeConfig, + fragmentBuffer: newFragmentBuffer(), + handshakeCache: newHandshakeCache(), + maximumTransmissionUnit: mtu, + paddingLengthGenerator: paddingLengthGenerator, + + decrypted: make(chan any, 1), + log: logger, + + readDeadline: deadline.New(), + writeDeadline: deadline.New(), + + reading: make(chan struct{}, 1), + handshakeRecv: make(chan recvHandshakeState), + closed: closer.NewCloser(), + cancelHandshaker: func() {}, + cancelHandshakeReader: func() {}, + + replayProtectionWindow: uint(replayProtectionWindow), //nolint:gosec // G115 + + state: State{ + isClient: isClient, + }, + } + + conn.setRemoteEpoch(0) + conn.setLocalEpoch(0) + + return conn, nil +} + +// Handshake runs the client or server DTLS handshake +// protocol if it has not yet been run. +// +// Most uses of this package need not call Handshake explicitly: the +// first [Conn.Read] or [Conn.Write] will call it automatically. +// +// For control over canceling or setting a timeout on a handshake, use +// [Conn.HandshakeContext]. +func (c *Conn) Handshake() error { + return c.HandshakeContext(context.Background()) +} + +// HandshakeContext runs the client or server DTLS handshake +// protocol if it has not yet been run. +// +// The provided Context must be non-nil. If the context is canceled before +// the handshake is complete, the handshake is interrupted and an error is returned. +// Once the handshake has completed, cancellation of the context will not affect the +// connection. +// +// Most uses of this package need not call HandshakeContext explicitly: the +// first [Conn.Read] or [Conn.Write] will call it automatically. +func (c *Conn) HandshakeContext(ctx context.Context) error { + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() + + if c.isHandshakeCompletedSuccessfully() { + return nil + } + + handshakeDone := make(chan struct{}) + defer close(handshakeDone) + c.closeLock.Lock() + c.handshakeDone = handshakeDone + c.closeLock.Unlock() + + // rfc5246#section-7.4.3 + // In addition, the hash and signature algorithms MUST be compatible + // with the key in the server's end-entity certificate. + if !c.state.isClient { + cert, err := c.handshakeConfig.getCertificate(&ClientHelloInfo{}) + if err != nil && !errors.Is(err, errNoCertificates) { + return err + } + c.handshakeConfig.localCipherSuites = filterCipherSuitesForCertificate(cert, c.handshakeConfig.localCipherSuites) + } + + var initialFlight flightVal + var initialFSMState handshakeState + + if c.handshakeConfig.resumeState != nil { //nolint:nestif + if c.state.isClient { + initialFlight = flight5 + } else { + initialFlight = flight6 + } + initialFSMState = handshakeFinished + + c.state = *c.handshakeConfig.resumeState + } else { + if c.state.isClient { + initialFlight = flight1 + } else { + initialFlight = flight0 + } + initialFSMState = handshakePreparing + } + // Do handshake + if err := c.handshake(ctx, c.handshakeConfig, initialFlight, initialFSMState); err != nil { + return err + } + + c.log.Trace("Handshake Completed") + + return nil +} + +// Dial connects to the given network address and establishes a DTLS connection on top. +// +// Deprecated: Use DialWithOptions instead. +func Dial(network string, rAddr *net.UDPAddr, config *Config) (*Conn, error) { + // net.ListenUDP is used rather than net.DialUDP as the latter prevents the + // use of net.PacketConn.WriteTo. + // https://github.com/golang/go/blob/ce5e37ec21442c6eb13a43e68ca20129102ebac0/src/net/udpsock_posix.go#L115 + pConn, err := net.ListenUDP(network, nil) + if err != nil { + return nil, err + } + + return Client(pConn, rAddr, config) +} + +// DialWithOptions connects to the given network address and establishes a DTLS connection on top. +func DialWithOptions(network string, rAddr *net.UDPAddr, opts ...ClientOption) (*Conn, error) { + config, err := buildClientConfig(opts...) + if err != nil { + return nil, err + } + + return Dial(network, rAddr, config) +} + +// Client establishes a DTLS connection over an existing connection. +// +// Deprecated: Use ClientWithOptions instead. +func Client(conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) { + switch { + case config == nil: + return nil, errNoConfigProvided + case config.PSK != nil && config.PSKIdentityHint == nil: + return nil, errPSKAndIdentityMustBeSetForClient + } + + if err := validateConfig(config); err != nil { + return nil, err + } + + return createConn(conn, rAddr, config, true, nil) +} + +// ClientWithOptions establishes a DTLS connection over an existing connection. +func ClientWithOptions(conn net.PacketConn, rAddr net.Addr, opts ...ClientOption) (*Conn, error) { + config, err := buildClientConfig(opts...) + if err != nil { + return nil, err + } + + return Client(conn, rAddr, config) +} + +// serverWithConfig is an internal helper that accepts a *Config. +func serverWithConfig(conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) { + if config == nil { + return nil, errNoConfigProvided + } + if config.OnConnectionAttempt != nil { + if err := config.OnConnectionAttempt(rAddr); err != nil { + return nil, err + } + } + + return createConn(conn, rAddr, config, false, nil) +} + +// Server listens for incoming DTLS connections. +// +// Deprecated: Use ServerWithOptions instead. +func Server(conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) { + if config == nil { + return nil, errNoConfigProvided + } + + if err := validateConfig(config); err != nil { + return nil, err + } + + return serverWithConfig(conn, rAddr, config) +} + +// ServerWithOptions listens for incoming DTLS connections. +func ServerWithOptions(conn net.PacketConn, rAddr net.Addr, opts ...ServerOption) (*Conn, error) { + config, err := buildServerConfig(opts...) + if err != nil { + return nil, err + } + + return Server(conn, rAddr, config) +} + +// Read reads data from the connection. +func (c *Conn) Read(buff []byte) (n int, err error) { //nolint:cyclop + if err := c.Handshake(); err != nil { + return 0, err + } + + select { + case <-c.readDeadline.Done(): + return 0, errDeadlineExceeded + default: + } + + for { + select { + case <-c.readDeadline.Done(): + return 0, errDeadlineExceeded + case out, ok := <-c.decrypted: + if !ok { + return 0, io.EOF + } + switch val := out.(type) { + case ([]byte): + if len(buff) < len(val) { + return 0, errBufferTooSmall + } + copy(buff, val) + + return len(val), nil + case (error): + return 0, val + } + } + } +} + +// Write writes len(payload) bytes from payload to the DTLS connection. +func (c *Conn) Write(payload []byte) (int, error) { + if c.isConnectionClosed() { + return 0, ErrConnClosed + } + + select { + case <-c.writeDeadline.Done(): + return 0, errDeadlineExceeded + default: + } + + if err := c.Handshake(); err != nil { + return 0, err + } + + return len(payload), c.writePackets(c.writeDeadline, []*packet{ + { + record: &recordlayer.RecordLayer{ + Header: recordlayer.Header{ + Epoch: c.state.getLocalEpoch(), + Version: protocol.Version1_2, + }, + Content: &protocol.ApplicationData{ + Data: payload, + }, + }, + shouldWrapCID: len(c.state.remoteConnectionID) > 0, + shouldEncrypt: true, + }, + }) +} + +// Close closes the connection. +func (c *Conn) Close() error { + err := c.close(true) //nolint:contextcheck + c.closeLock.Lock() + handshakeDone := c.handshakeDone + c.closeLock.Unlock() + if handshakeDone != nil { + <-handshakeDone + } + + return err +} + +// ConnectionState returns basic DTLS details about the connection. +// Note that this replaced the `Export` function of v1. +func (c *Conn) ConnectionState() (State, bool) { + c.lock.RLock() + defer c.lock.RUnlock() + stateClone, err := c.state.clone() + if err != nil { + return State{}, false + } + + return *stateClone, true +} + +// SelectedSRTPProtectionProfile returns the selected SRTPProtectionProfile. +func (c *Conn) SelectedSRTPProtectionProfile() (SRTPProtectionProfile, bool) { + profile := c.state.getSRTPProtectionProfile() + if profile == 0 { + return 0, false + } + + return profile, true +} + +// RemoteSRTPMasterKeyIdentifier returns the MasterKeyIdentifier value from the use_srtp. +func (c *Conn) RemoteSRTPMasterKeyIdentifier() ([]byte, bool) { + if profile := c.state.getSRTPProtectionProfile(); profile == 0 { + return nil, false + } + + return c.state.remoteSRTPMasterKeyIdentifier, true +} + +func (c *Conn) writePackets(ctx context.Context, pkts []*packet) error { + c.lock.Lock() + defer c.lock.Unlock() + + var rawPackets [][]byte + + for _, pkt := range pkts { + if dtlsHandshake, ok := pkt.record.Content.(*handshake.Handshake); ok { + handshakeRaw, err := pkt.record.Marshal() + if err != nil { + return err + } + + c.log.Tracef("[handshake:%v] -> %s (epoch: %d, seq: %d)", + srvCliStr(c.state.isClient), dtlsHandshake.Header.Type.String(), + pkt.record.Header.Epoch, dtlsHandshake.Header.MessageSequence) + + c.handshakeCache.push( + handshakeRaw[recordlayer.FixedHeaderSize:], + pkt.record.Header.Epoch, + dtlsHandshake.Header.MessageSequence, + dtlsHandshake.Header.Type, + c.state.isClient, + ) + + rawHandshakePackets, err := c.processHandshakePacket(pkt, dtlsHandshake) + if err != nil { + return err + } + rawPackets = append(rawPackets, rawHandshakePackets...) + } else { + rawPacket, err := c.processPacket(pkt) + if err != nil { + return err + } + rawPackets = append(rawPackets, rawPacket) + } + } + if len(rawPackets) == 0 { + return nil + } + compactedRawPackets := c.compactRawPackets(rawPackets) + + for _, compactedRawPackets := range compactedRawPackets { + if _, err := c.nextConn.WriteToContext(ctx, compactedRawPackets, c.rAddr); err != nil { + return netError(err) + } + } + + return nil +} + +func (c *Conn) compactRawPackets(rawPackets [][]byte) [][]byte { + // avoid a useless copy in the common case + if len(rawPackets) == 1 { + return rawPackets + } + + combinedRawPackets := make([][]byte, 0) + currentCombinedRawPacket := make([]byte, 0) + + for _, rawPacket := range rawPackets { + if len(currentCombinedRawPacket) > 0 && len(currentCombinedRawPacket)+len(rawPacket) >= c.maximumTransmissionUnit { + combinedRawPackets = append(combinedRawPackets, currentCombinedRawPacket) + currentCombinedRawPacket = []byte{} + } + currentCombinedRawPacket = append(currentCombinedRawPacket, rawPacket...) + } + + combinedRawPackets = append(combinedRawPackets, currentCombinedRawPacket) + + return combinedRawPackets +} + +func (c *Conn) processPacket(pkt *packet) ([]byte, error) { //nolint:cyclop + epoch := pkt.record.Header.Epoch + for len(c.state.localSequenceNumber) <= int(epoch) { + c.state.localSequenceNumber = append(c.state.localSequenceNumber, uint64(0)) + } + seq := atomic.AddUint64(&c.state.localSequenceNumber[epoch], 1) - 1 + if seq > recordlayer.MaxSequenceNumber { + // RFC 6347 Section 4.1.0 + // The implementation must either abandon an association or rehandshake + // prior to allowing the sequence number to wrap. + return nil, errSequenceNumberOverflow + } + pkt.record.Header.SequenceNumber = seq + + var rawPacket []byte + if pkt.shouldWrapCID { //nolint:nestif + // Record must be marshaled to populate fields used in inner plaintext. + if _, err := pkt.record.Marshal(); err != nil { + return nil, err + } + content, err := pkt.record.Content.Marshal() + if err != nil { + return nil, err + } + inner := &recordlayer.InnerPlaintext{ + Content: content, + RealType: pkt.record.Header.ContentType, + } + rawInner, err := inner.Marshal() //nolint:govet + if err != nil { + return nil, err + } + cidHeader := &recordlayer.Header{ + Version: pkt.record.Header.Version, + ContentType: protocol.ContentTypeConnectionID, + Epoch: pkt.record.Header.Epoch, + ContentLen: uint16(len(rawInner)), //nolint:gosec //G115 + ConnectionID: c.state.remoteConnectionID, + SequenceNumber: pkt.record.Header.SequenceNumber, + } + rawPacket, err = cidHeader.Marshal() + if err != nil { + return nil, err + } + pkt.record.Header = *cidHeader + rawPacket = append(rawPacket, rawInner...) + } else { + var err error + rawPacket, err = pkt.record.Marshal() + if err != nil { + return nil, err + } + } + + if pkt.shouldEncrypt { + var err error + rawPacket, err = c.state.cipherSuite.Encrypt(pkt.record, rawPacket) + if err != nil { + return nil, err + } + } + + return rawPacket, nil +} + +//nolint:cyclop +func (c *Conn) processHandshakePacket(pkt *packet, dtlsHandshake *handshake.Handshake) ([][]byte, error) { + rawPackets := make([][]byte, 0) + + handshakeFragments, err := c.fragmentHandshake(dtlsHandshake) + if err != nil { + return nil, err + } + epoch := pkt.record.Header.Epoch + for len(c.state.localSequenceNumber) <= int(epoch) { + c.state.localSequenceNumber = append(c.state.localSequenceNumber, uint64(0)) + } + + for _, handshakeFragment := range handshakeFragments { + seq := atomic.AddUint64(&c.state.localSequenceNumber[epoch], 1) - 1 + if seq > recordlayer.MaxSequenceNumber { + return nil, errSequenceNumberOverflow + } + + var rawPacket []byte + if pkt.shouldWrapCID { + inner := &recordlayer.InnerPlaintext{ + Content: handshakeFragment, + RealType: protocol.ContentTypeHandshake, + Zeros: c.paddingLengthGenerator(uint(len(handshakeFragment))), + } + rawInner, err := inner.Marshal() //nolint:govet + if err != nil { + return nil, err + } + cidHeader := &recordlayer.Header{ + Version: pkt.record.Header.Version, + ContentType: protocol.ContentTypeConnectionID, + Epoch: pkt.record.Header.Epoch, + ContentLen: uint16(len(rawInner)), //nolint:gosec //G115 + ConnectionID: c.state.remoteConnectionID, + SequenceNumber: pkt.record.Header.SequenceNumber, + } + rawPacket, err = cidHeader.Marshal() + if err != nil { + return nil, err + } + pkt.record.Header = *cidHeader + rawPacket = append(rawPacket, rawInner...) + } else { + recordlayerHeader := &recordlayer.Header{ + Version: pkt.record.Header.Version, + ContentType: pkt.record.Header.ContentType, + ContentLen: uint16(len(handshakeFragment)), //nolint:gosec // G115 + Epoch: pkt.record.Header.Epoch, + SequenceNumber: seq, + } + + rawPacket, err = recordlayerHeader.Marshal() + if err != nil { + return nil, err + } + + pkt.record.Header = *recordlayerHeader + rawPacket = append(rawPacket, handshakeFragment...) + } + + if pkt.shouldEncrypt { + var err error + rawPacket, err = c.state.cipherSuite.Encrypt(pkt.record, rawPacket) + if err != nil { + return nil, err + } + } + + rawPackets = append(rawPackets, rawPacket) + } + + return rawPackets, nil +} + +func (c *Conn) fragmentHandshake(dtlsHandshake *handshake.Handshake) ([][]byte, error) { + content, err := dtlsHandshake.Message.Marshal() + if err != nil { + return nil, err + } + + fragmentedHandshakes := make([][]byte, 0) + + contentFragments := splitBytes(content, c.maximumTransmissionUnit) + if len(contentFragments) == 0 { + contentFragments = [][]byte{ + {}, + } + } + + offset := 0 + for _, contentFragment := range contentFragments { + contentFragmentLen := len(contentFragment) + + headerFragment := &handshake.Header{ + Type: dtlsHandshake.Header.Type, + Length: dtlsHandshake.Header.Length, + MessageSequence: dtlsHandshake.Header.MessageSequence, + FragmentOffset: uint32(offset), //nolint:gosec // G115 + FragmentLength: uint32(contentFragmentLen), //nolint:gosec // G115 + } + + offset += contentFragmentLen + + fragmentedHandshake, err := headerFragment.Marshal() + if err != nil { + return nil, err + } + + fragmentedHandshake = append(fragmentedHandshake, contentFragment...) + fragmentedHandshakes = append(fragmentedHandshakes, fragmentedHandshake) + } + + return fragmentedHandshakes, nil +} + +var poolReadBuffer = sync.Pool{ //nolint:gochecknoglobals + New: func() any { + b := make([]byte, inboundBufferSize) + + return &b + }, +} + +func (c *Conn) readAndBuffer(ctx context.Context) error { //nolint:cyclop + bufptr, ok := poolReadBuffer.Get().(*[]byte) + if !ok { + return errFailedToAccessPoolReadBuffer + } + defer poolReadBuffer.Put(bufptr) + + b := *bufptr + i, rAddr, err := c.nextConn.ReadFromContext(ctx, b) + if err != nil { + return netError(err) + } + + pkts, err := recordlayer.ContentAwareUnpackDatagram(b[:i], len(c.state.getLocalConnectionID())) + if err != nil { + return err + } + + var hasHandshake, isRetransmit bool + for _, p := range pkts { + hs, rtx, alert, err := c.handleIncomingPacket(ctx, p, rAddr, true) + if alert != nil { + if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil { + if err == nil { + err = alertErr + } + } + } + + var e *alertError + if errors.As(err, &e) && e.IsFatalOrCloseNotify() { + return e + } + if err != nil { + return err + } + if hs { + hasHandshake = true + } + if rtx { + isRetransmit = true + } + } + if hasHandshake { + s := recvHandshakeState{ + done: make(chan struct{}), + isRetransmit: isRetransmit, + } + select { + case c.handshakeRecv <- s: + // If the other party may retransmit the flight, + // we should respond even if it not a new message. + <-s.done + case <-c.fsm.Done(): + } + } + + return nil +} + +func (c *Conn) handleQueuedPackets(ctx context.Context) error { + c.lock.Lock() + pkts := c.encryptedPackets + c.encryptedPackets = nil + c.lock.Unlock() + + for _, p := range pkts { + _, _, alert, err := c.handleIncomingPacket(ctx, p.data, p.rAddr, false) // don't re-enqueue + if alert != nil { + if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil { + if err == nil { + err = alertErr + } + } + } + var e *alertError + if errors.As(err, &e) && e.IsFatalOrCloseNotify() { + return e + } + if err != nil { + return err + } + } + + return nil +} + +func (c *Conn) enqueueEncryptedPackets(packet addrPkt) bool { + c.lock.Lock() + defer c.lock.Unlock() + + if len(c.encryptedPackets) < maxAppDataPacketQueueSize { + c.encryptedPackets = append(c.encryptedPackets, packet) + + return true + } + + return false +} + +//nolint:gocognit,gocyclo,cyclop,maintidx +func (c *Conn) handleIncomingPacket( + ctx context.Context, + buf []byte, + rAddr net.Addr, + enqueue bool, +) (bool, bool, *alert.Alert, error) { + header := &recordlayer.Header{} + // Set connection ID size so that records of content type tls12_cid will + // be parsed correctly. + if len(c.state.getLocalConnectionID()) > 0 { + header.ConnectionID = make([]byte, len(c.state.getLocalConnectionID())) + } + if err := header.Unmarshal(buf); err != nil { + // Decode error must be silently discarded + // [RFC6347 Section-4.1.2.7] + c.log.Debugf("discarded broken packet: %v", err) + + return false, false, nil, nil + } + // Validate epoch + remoteEpoch := c.state.getRemoteEpoch() + if header.Epoch > remoteEpoch { + if header.Epoch > remoteEpoch+1 { + c.log.Debugf("discarded future packet (epoch: %d, seq: %d)", + header.Epoch, header.SequenceNumber, + ) + + return false, false, nil, nil + } + if enqueue { + if ok := c.enqueueEncryptedPackets(addrPkt{rAddr, buf}); ok { + c.log.Debug("received packet of next epoch, queuing packet") + } + } + + return false, false, nil, nil + } + + // Anti-replay protection + for len(c.state.replayDetector) <= int(header.Epoch) { + c.state.replayDetector = append(c.state.replayDetector, + replaydetector.New(c.replayProtectionWindow, recordlayer.MaxSequenceNumber), + ) + } + markPacketAsValid, ok := c.state.replayDetector[int(header.Epoch)].Check(header.SequenceNumber) + if !ok { + c.log.Debugf("discarded duplicated packet (epoch: %d, seq: %d)", + header.Epoch, header.SequenceNumber, + ) + + return false, false, nil, nil + } + + // originalCID indicates whether the original record had content type + // Connection ID. + originalCID := false + + // Decrypt + if header.Epoch != 0 { //nolint:nestif + if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() { + if enqueue { + if ok := c.enqueueEncryptedPackets(addrPkt{rAddr, buf}); ok { + c.log.Debug("handshake not finished, queuing packet") + } + } + + return false, false, nil, nil + } + + // If a connection identifier had been negotiated and encryption is + // enabled, the connection identifier MUST be sent. + if len(c.state.getLocalConnectionID()) > 0 && header.ContentType != protocol.ContentTypeConnectionID { + c.log.Debug("discarded packet missing connection ID after value negotiated") + + return false, false, nil, nil + } + + var err error + var hdr recordlayer.Header + if header.ContentType == protocol.ContentTypeConnectionID { + hdr.ConnectionID = make([]byte, len(c.state.getLocalConnectionID())) + } + buf, err = c.state.cipherSuite.Decrypt(hdr, buf) + if err != nil { + c.log.Debugf("%s: decrypt failed: %s", srvCliStr(c.state.isClient), err) + + return false, false, nil, nil + } + // If this is a connection ID record, make it look like a normal record for + // further processing. + if header.ContentType == protocol.ContentTypeConnectionID { + originalCID = true + ip := &recordlayer.InnerPlaintext{} + if err := ip.Unmarshal(buf[header.Size():]); err != nil { //nolint:govet + c.log.Debugf("unpacking inner plaintext failed: %s", err) + + return false, false, nil, nil + } + unpacked := &recordlayer.Header{ + ContentType: ip.RealType, + ContentLen: uint16(len(ip.Content)), //nolint:gosec // G115 + Version: header.Version, + Epoch: header.Epoch, + SequenceNumber: header.SequenceNumber, + } + buf, err = unpacked.Marshal() + if err != nil { + c.log.Debugf("converting CID record to inner plaintext failed: %s", err) + + return false, false, nil, nil + } + buf = append(buf, ip.Content...) + } + + // If connection ID does not match discard the packet. + if !bytes.Equal(c.state.getLocalConnectionID(), header.ConnectionID) { + c.log.Debug("unexpected connection ID") + + return false, false, nil, nil + } + } + + isHandshake, isRetransmit, err := c.fragmentBuffer.push(append([]byte{}, buf...)) + if err != nil { + // Decode error must be silently discarded + // [RFC6347 Section-4.1.2.7] + c.log.Debugf("defragment failed: %s", err) + + return false, false, nil, nil + } else if isHandshake { + markPacketAsValid() + + for out, epoch := c.fragmentBuffer.pop(); out != nil; out, epoch = c.fragmentBuffer.pop() { + header := &handshake.Header{} + if err := header.Unmarshal(out); err != nil { + c.log.Debugf("%s: handshake parse failed: %s", srvCliStr(c.state.isClient), err) + + continue + } + c.handshakeCache.push(out, epoch, header.MessageSequence, header.Type, !c.state.isClient) + } + + return true, isRetransmit, nil, nil + } + + r := &recordlayer.RecordLayer{} + if err := r.Unmarshal(buf); err != nil { + return false, false, &alert.Alert{Level: alert.Fatal, Description: alert.DecodeError}, err + } + + isLatestSeqNum := false + switch content := r.Content.(type) { + case *alert.Alert: + c.log.Tracef("%s: <- %s", srvCliStr(c.state.isClient), content.String()) + var a *alert.Alert + if content.Description == alert.CloseNotify { + // Respond with a close_notify [RFC5246 Section 7.2.1] + a = &alert.Alert{Level: alert.Warning, Description: alert.CloseNotify} + } + _ = markPacketAsValid() + + return false, false, a, &alertError{content} + case *protocol.ChangeCipherSpec: + if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() { + if enqueue { + if ok := c.enqueueEncryptedPackets(addrPkt{rAddr, buf}); ok { + c.log.Debugf("CipherSuite not initialized, queuing packet") + } + } + + return false, false, nil, nil + } + + newRemoteEpoch := header.Epoch + 1 + c.log.Tracef("%s: <- ChangeCipherSpec (epoch: %d)", srvCliStr(c.state.isClient), newRemoteEpoch) + + if c.state.getRemoteEpoch()+1 == newRemoteEpoch { + c.setRemoteEpoch(newRemoteEpoch) + isLatestSeqNum = markPacketAsValid() + } + case *protocol.ApplicationData: + if header.Epoch == 0 { + return false, false, &alert.Alert{ + Level: alert.Fatal, Description: alert.UnexpectedMessage, + }, errApplicationDataEpochZero + } + + isLatestSeqNum = markPacketAsValid() + + select { + case c.decrypted <- content.Data: + case <-c.closed.Done(): + case <-ctx.Done(): + } + + default: + return false, false, &alert.Alert{ + Level: alert.Fatal, Description: alert.UnexpectedMessage, + }, fmt.Errorf("%w: %d", errUnhandledContextType, content.ContentType()) + } + + // Any valid connection ID record is a candidate for updating the remote + // address if it is the latest record received. + // https://datatracker.ietf.org/doc/html/rfc9146#peer-address-update + if originalCID && isLatestSeqNum { + if rAddr != c.RemoteAddr() { + c.lock.Lock() + c.rAddr = rAddr + c.lock.Unlock() + } + } + + return false, false, nil, nil +} + +func (c *Conn) recvHandshake() <-chan recvHandshakeState { + return c.handshakeRecv +} + +func (c *Conn) notify(ctx context.Context, level alert.Level, desc alert.Description) error { + if level == alert.Fatal && len(c.state.SessionID) > 0 { + // According to the RFC, we need to delete the stored session. + // https://datatracker.ietf.org/doc/html/rfc5246#section-7.2 + if ss := c.fsm.cfg.sessionStore; ss != nil { + c.log.Tracef("clean invalid session: %s", c.state.SessionID) + if err := ss.Del(c.sessionKey()); err != nil { + return err + } + } + } + + return c.writePackets(ctx, []*packet{ + { + record: &recordlayer.RecordLayer{ + Header: recordlayer.Header{ + Epoch: c.state.getLocalEpoch(), + Version: protocol.Version1_2, + }, + Content: &alert.Alert{ + Level: level, + Description: desc, + }, + }, + shouldWrapCID: len(c.state.remoteConnectionID) > 0, + shouldEncrypt: c.isHandshakeCompletedSuccessfully(), + }, + }) +} + +func (c *Conn) setHandshakeCompletedSuccessfully() bool { + return c.handshakeCompletedSuccessfully.CompareAndSwap(false, true) +} + +func (c *Conn) isHandshakeCompletedSuccessfully() bool { + return c.handshakeCompletedSuccessfully.Load() +} + +//nolint:cyclop,gocognit,contextcheck +func (c *Conn) handshake( + ctx context.Context, + cfg *handshakeConfig, + initialFlight flightVal, + initialState handshakeState, +) error { + c.fsm = newHandshakeFSM(&c.state, c.handshakeCache, cfg, initialFlight) + + done := make(chan struct{}) + ctxRead, cancelRead := context.WithCancel(context.Background()) + cfg.onFlightState = func(_ flightVal, s handshakeState) { + if s == handshakeFinished && c.setHandshakeCompletedSuccessfully() { + close(done) + } + } + + ctxHs, cancel := context.WithCancel(context.Background()) + + c.closeLock.Lock() + c.cancelHandshaker = cancel + c.cancelHandshakeReader = cancelRead + c.closeLock.Unlock() + + firstErr := make(chan error, 1) + + var handshakeLoopsFinished sync.WaitGroup + handshakeLoopsFinished.Add(2) + + // Handshake routine should be live until close. + // The other party may request retransmission of the last flight to cope with packet drop. + go func() { + defer handshakeLoopsFinished.Done() + err := c.fsm.Run(ctxHs, c, initialState) + if !errors.Is(err, context.Canceled) { + select { + case firstErr <- err: + default: + } + } + }() + go func() { + defer func() { + if c.isHandshakeCompletedSuccessfully() { + // Escaping read loop. + // It's safe to close decrypted channnel now. + close(c.decrypted) + } + + // Force stop handshaker when the underlying connection is closed. + cancel() + }() + defer handshakeLoopsFinished.Done() + for { + if err := c.readAndBuffer(ctxRead); err != nil { //nolint:nestif + var alertErr *alertError + if errors.As(err, &alertErr) { + if !alertErr.IsFatalOrCloseNotify() { + if c.isHandshakeCompletedSuccessfully() { + // Pass the error to Read() + select { + case c.decrypted <- err: + case <-c.closed.Done(): + case <-ctxRead.Done(): + } + } + + continue // non-fatal alert must not stop read loop + } + } else { + switch { + case errors.Is(err, context.DeadlineExceeded), + errors.Is(err, context.Canceled), + errors.Is(err, io.EOF), + errors.Is(err, net.ErrClosed): + case errors.Is(err, recordlayer.ErrInvalidPacketLength): + // Decode error must be silently discarded + // [RFC6347 Section-4.1.2.7] + continue + default: + if c.isHandshakeCompletedSuccessfully() { + // Keep read loop and pass the read error to Read() + select { + case c.decrypted <- err: + case <-c.closed.Done(): + case <-ctxRead.Done(): + } + + continue // non-fatal alert must not stop read loop + } + } + } + + select { + case firstErr <- err: + default: + } + + if alertErr != nil { + if alertErr.IsFatalOrCloseNotify() { + _ = c.close(false) //nolint:contextcheck + } + } + if !c.isConnectionClosed() && errors.Is(err, context.Canceled) { + c.log.Trace("handshake timeouts - closing underline connection") + _ = c.close(false) //nolint:contextcheck + } + + return + } + } + }() + + select { + case err := <-firstErr: + cancelRead() + cancel() + handshakeLoopsFinished.Wait() + + return c.translateHandshakeCtxError(err) + case <-ctx.Done(): + cancelRead() + cancel() + handshakeLoopsFinished.Wait() + + return c.translateHandshakeCtxError(ctx.Err()) + case <-done: + return nil + } +} + +func (c *Conn) translateHandshakeCtxError(err error) error { + if err == nil { + return nil + } + if errors.Is(err, context.Canceled) && c.isHandshakeCompletedSuccessfully() { + return nil + } + + return &HandshakeError{Err: err} +} + +func (c *Conn) close(byUser bool) error { + c.closeLock.Lock() + cancelHandshaker := c.cancelHandshaker + cancelHandshakeReader := c.cancelHandshakeReader + c.closeLock.Unlock() + + cancelHandshaker() + cancelHandshakeReader() + + if c.isHandshakeCompletedSuccessfully() && byUser { + // Discard error from notify() to return non-error on the first user call of Close() + // even if the underlying connection is already closed. + _ = c.notify(context.Background(), alert.Warning, alert.CloseNotify) + } + + c.closeLock.Lock() + // Don't return ErrConnClosed at the first time of the call from user. + closedByUser := c.connectionClosedByUser + if byUser { + c.connectionClosedByUser = true + } + isClosed := c.isConnectionClosed() + c.closed.Close() + c.closeLock.Unlock() + + if closedByUser { + return ErrConnClosed + } + + if isClosed { + return nil + } + + return c.nextConn.Close() +} + +func (c *Conn) isConnectionClosed() bool { + select { + case <-c.closed.Done(): + return true + default: + return false + } +} + +func (c *Conn) setLocalEpoch(epoch uint16) { + c.state.localEpoch.Store(epoch) +} + +func (c *Conn) setRemoteEpoch(epoch uint16) { + c.state.remoteEpoch.Store(epoch) +} + +// LocalAddr implements net.Conn.LocalAddr. +func (c *Conn) LocalAddr() net.Addr { + return c.nextConn.LocalAddr() +} + +// RemoteAddr implements net.Conn.RemoteAddr. +func (c *Conn) RemoteAddr() net.Addr { + c.lock.RLock() + defer c.lock.RUnlock() + + return c.rAddr +} + +func (c *Conn) sessionKey() []byte { + if c.state.isClient { + // As ServerName can be like 0.example.com, it's better to add + // delimiter character which is not allowed to be in + // neither address or domain name. + return []byte(c.rAddr.String() + "_" + c.fsm.cfg.serverName) + } + + return c.state.SessionID +} + +// SetDeadline implements net.Conn.SetDeadline. +func (c *Conn) SetDeadline(t time.Time) error { + c.readDeadline.Set(t) + + return c.SetWriteDeadline(t) +} + +// SetReadDeadline implements net.Conn.SetReadDeadline. +func (c *Conn) SetReadDeadline(t time.Time) error { + c.readDeadline.Set(t) + // Read deadline is fully managed by this layer. + // Don't set read deadline to underlying connection. + return nil +} + +// SetWriteDeadline implements net.Conn.SetWriteDeadline. +func (c *Conn) SetWriteDeadline(t time.Time) error { + c.writeDeadline.Set(t) + // Write deadline is also fully managed by this layer. + return nil +} diff --git a/vendor/github.com/pion/dtls/v3/connection_id.go b/vendor/github.com/pion/dtls/v3/connection_id.go new file mode 100644 index 0000000..a8622eb --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/connection_id.go @@ -0,0 +1,105 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package dtls + +import ( + "crypto/rand" + + "github.com/pion/dtls/v3/pkg/protocol" + "github.com/pion/dtls/v3/pkg/protocol/extension" + "github.com/pion/dtls/v3/pkg/protocol/handshake" + "github.com/pion/dtls/v3/pkg/protocol/recordlayer" +) + +// RandomCIDGenerator is a random Connection ID generator where CID is the +// specified size. Specifying a size of 0 will indicate to peers that sending a +// Connection ID is not necessary. +func RandomCIDGenerator(size int) func() []byte { + return func() []byte { + cid := make([]byte, size) + if _, err := rand.Read(cid); err != nil { + panic(err) //nolint -- nonrecoverable + } + + return cid + } +} + +// OnlySendCIDGenerator enables sending Connection IDs negotiated with a peer, +// but indicates to the peer that sending Connection IDs in return is not +// necessary. +func OnlySendCIDGenerator() func() []byte { + return func() []byte { + return nil + } +} + +// cidDatagramRouter extracts connection IDs from incoming datagram payloads and +// uses them to route to the proper connection. +// NOTE: properly routing datagrams based on connection IDs requires using +// constant size connection IDs. +func cidDatagramRouter(size int) func([]byte) (string, bool) { + return func(packet []byte) (string, bool) { + pkts, err := recordlayer.ContentAwareUnpackDatagram(packet, size) + if err != nil || len(pkts) < 1 { + return "", false + } + for _, pkt := range pkts { + h := &recordlayer.Header{ + ConnectionID: make([]byte, size), + } + if err := h.Unmarshal(pkt); err != nil { + continue + } + if h.ContentType != protocol.ContentTypeConnectionID { + continue + } + + return string(h.ConnectionID), true + } + + return "", false + } +} + +// cidConnIdentifier extracts connection IDs from outgoing ServerHello records +// and associates them with the associated connection. +// NOTE: a ServerHello should always be the first record in a datagram if +// multiple are present, so we avoid iterating through all packets if the first +// is not a ServerHello. +func cidConnIdentifier() func([]byte) (string, bool) { //nolint:cyclop + return func(packet []byte) (string, bool) { + pkts, err := recordlayer.UnpackDatagram(packet) + if err != nil || len(pkts) < 1 { + return "", false + } + var h recordlayer.Header + if hErr := h.Unmarshal(pkts[0]); hErr != nil { + return "", false + } + if h.ContentType != protocol.ContentTypeHandshake { + return "", false + } + var hh handshake.Header + var sh handshake.MessageServerHello + for _, pkt := range pkts { + if hhErr := hh.Unmarshal(pkt[recordlayer.FixedHeaderSize:]); hhErr != nil { + continue + } + if err = sh.Unmarshal(pkt[recordlayer.FixedHeaderSize+handshake.HeaderLength:]); err == nil { + break + } + } + if err != nil { + return "", false + } + for _, ext := range sh.Extensions { + if e, ok := ext.(*extension.ConnectionID); ok { + return string(e.CID), true + } + } + + return "", false + } +} diff --git a/vendor/github.com/pion/dtls/v3/crypto.go b/vendor/github.com/pion/dtls/v3/crypto.go new file mode 100644 index 0000000..ad00bfd --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/crypto.go @@ -0,0 +1,457 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package dtls + +import ( + "crypto" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/asn1" + "encoding/binary" + "math/big" + "time" + + "github.com/pion/dtls/v3/pkg/crypto/elliptic" + "github.com/pion/dtls/v3/pkg/crypto/hash" + "github.com/pion/dtls/v3/pkg/crypto/signature" + "github.com/pion/dtls/v3/pkg/crypto/signaturehash" +) + +type ecdsaSignature struct { + R, S *big.Int +} + +func valueKeyMessage(clientRandom, serverRandom, publicKey []byte, namedCurve elliptic.Curve) []byte { + serverECDHParams := make([]byte, 4) + serverECDHParams[0] = 3 // named curve + binary.BigEndian.PutUint16(serverECDHParams[1:], uint16(namedCurve)) + serverECDHParams[3] = byte(len(publicKey)) + + plaintext := []byte{} + plaintext = append(plaintext, clientRandom...) + plaintext = append(plaintext, serverRandom...) + plaintext = append(plaintext, serverECDHParams...) + plaintext = append(plaintext, publicKey...) + + return plaintext +} + +// validateSignatureAlgOID validates that the signature scheme matches the +// certificate's public key algorithm OID. This is required by RFC 8446 Section 4.2.3: +// - RSA_PSS_RSAE requires rsaEncryption OID +// - RSA_PSS_PSS requires id-RSASSA-PSS OID +// +// Note: returns nil if the given signature.Algorithm is not PSS based. +// +// https://www.rfc-editor.org/rfc/rfc8446#section-4.2.3 +func validateSignatureAlgOID(cert *x509.Certificate, sigAlg signature.Algorithm) error { + if !sigAlg.IsPSS() { + return nil + } + + // Get the certificate's public key algorithm OID from the raw certificate + // We need to parse the SubjectPublicKeyInfo to get the algorithm OID + var spki struct { + Algorithm pkix.AlgorithmIdentifier + PublicKey asn1.BitString + } + if _, err := asn1.Unmarshal(cert.RawSubjectPublicKeyInfo, &spki); err != nil { + return err + } + + certOID := spki.Algorithm.Algorithm + + switch sigAlg { + // Check RSAE variants (0x0804-0x0806) require rsaEncryption OID + case signature.RSA_PSS_RSAE_SHA256, signature.RSA_PSS_RSAE_SHA384, signature.RSA_PSS_RSAE_SHA512: + oidPublicKeyRSA := asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 1} // OID: rsaEncryption + if !certOID.Equal(oidPublicKeyRSA) { + return errInvalidCertificateOID + } + + return nil + + // Check PSS variants (0x0809-0x080b) require id-RSASSA-PSS OID + case signature.RSA_PSS_PSS_SHA256, signature.RSA_PSS_PSS_SHA384, signature.RSA_PSS_PSS_SHA512: + oidPublicKeyRSAPSS := asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 10} // OID: id-RSASSA-PSS + if !certOID.Equal(oidPublicKeyRSAPSS) { + return errInvalidCertificateOID + } + + return nil + + default: + return nil + } +} + +// If the client provided a "signature_algorithms" extension, then all +// certificates provided by the server MUST be signed by a +// hash/signature algorithm pair that appears in that extension +// +// https://tools.ietf.org/html/rfc5246#section-7.4.2 +func generateKeySignature( + clientRandom, serverRandom, publicKey []byte, + namedCurve elliptic.Curve, + signer crypto.Signer, + hashAlgorithm hash.Algorithm, + signatureAlgorithm signature.Algorithm, +) ([]byte, error) { + msg := valueKeyMessage(clientRandom, serverRandom, publicKey, namedCurve) + switch signer.Public().(type) { + case ed25519.PublicKey: + // https://crypto.stackexchange.com/a/55483 + return signer.Sign(rand.Reader, msg, crypto.Hash(0)) + case *ecdsa.PublicKey: + hashed := hashAlgorithm.Digest(msg) + + return signer.Sign(rand.Reader, hashed, hashAlgorithm.CryptoHash()) + case *rsa.PublicKey: + hashed := hashAlgorithm.Digest(msg) + + // Use RSA-PSS if the signature algorithm is PSS + if signatureAlgorithm.IsPSS() { + pssOpts := &rsa.PSSOptions{ + SaltLength: rsa.PSSSaltLengthEqualsHash, + Hash: hashAlgorithm.CryptoHash(), + } + + return signer.Sign(rand.Reader, hashed, pssOpts) + } + + // Otherwise use PKCS#1 v1.5 + return signer.Sign(rand.Reader, hashed, hashAlgorithm.CryptoHash()) + } + + return nil, errKeySignatureGenerateUnimplemented +} + +//nolint:dupl,cyclop +func verifyKeySignature( + message, remoteKeySignature []byte, + hashAlgorithm hash.Algorithm, + signatureAlgorithm signature.Algorithm, + rawCertificates [][]byte, +) error { + if len(rawCertificates) == 0 { + return errLengthMismatch + } + certificate, err := x509.ParseCertificate(rawCertificates[0]) + if err != nil { + return err + } + + // Validate that the signature algorithm matches the certificate's OID + if err := validateSignatureAlgOID(certificate, signatureAlgorithm); err != nil { + return err + } + + switch pubKey := certificate.PublicKey.(type) { + case ed25519.PublicKey: + if ok := ed25519.Verify(pubKey, message, remoteKeySignature); !ok { + return errKeySignatureMismatch + } + + return nil + case *ecdsa.PublicKey: + ecdsaSig := &ecdsaSignature{} + if _, err := asn1.Unmarshal(remoteKeySignature, ecdsaSig); err != nil { + return err + } + if ecdsaSig.R.Sign() <= 0 || ecdsaSig.S.Sign() <= 0 { + return errInvalidECDSASignature + } + hashed := hashAlgorithm.Digest(message) + if !ecdsa.Verify(pubKey, hashed, ecdsaSig.R, ecdsaSig.S) { + return errKeySignatureMismatch + } + + return nil + case *rsa.PublicKey: + hashed := hashAlgorithm.Digest(message) + + // Use RSA-PSS verification if the signature algorithm is PSS + if signatureAlgorithm.IsPSS() { + pssOpts := &rsa.PSSOptions{ + SaltLength: rsa.PSSSaltLengthEqualsHash, + Hash: hashAlgorithm.CryptoHash(), + } + if err := rsa.VerifyPSS(pubKey, hashAlgorithm.CryptoHash(), hashed, remoteKeySignature, pssOpts); err != nil { + return errKeySignatureMismatch + } + + return nil + } + + // Otherwise use PKCS#1 v1.5 + if rsa.VerifyPKCS1v15(pubKey, hashAlgorithm.CryptoHash(), hashed, remoteKeySignature) != nil { + return errKeySignatureMismatch + } + + return nil + } + + return errKeySignatureVerifyUnimplemented +} + +// If the server has sent a CertificateRequest message, the client MUST send the Certificate +// message. The ClientKeyExchange message is now sent, and the content +// of that message will depend on the public key algorithm selected +// between the ClientHello and the ServerHello. If the client has sent +// a certificate with signing ability, a digitally-signed +// CertificateVerify message is sent to explicitly verify possession of +// the private key in the certificate. +// https://tools.ietf.org/html/rfc5246#section-7.3 +func generateCertificateVerify( + handshakeBodies []byte, + signer crypto.Signer, + hashAlgorithm hash.Algorithm, + signatureAlgorithm signature.Algorithm, +) ([]byte, error) { + if _, ok := signer.Public().(ed25519.PublicKey); ok { + // https://pkg.go.dev/crypto/ed25519#PrivateKey.Sign + // Sign signs the given message with priv. Ed25519 performs two passes over + // messages to be signed and therefore cannot handle pre-hashed messages. + return signer.Sign(rand.Reader, handshakeBodies, crypto.Hash(0)) + } + + hashed := hashAlgorithm.Digest(handshakeBodies) + + switch signer.Public().(type) { + case *ecdsa.PublicKey: + return signer.Sign(rand.Reader, hashed, hashAlgorithm.CryptoHash()) + case *rsa.PublicKey: + // Use RSA-PSS if the signature algorithm is PSS + if signatureAlgorithm.IsPSS() { + pssOpts := &rsa.PSSOptions{ + SaltLength: rsa.PSSSaltLengthEqualsHash, + Hash: hashAlgorithm.CryptoHash(), + } + + return signer.Sign(rand.Reader, hashed, pssOpts) + } + + // Otherwise use PKCS#1 v1.5 + return signer.Sign(rand.Reader, hashed, hashAlgorithm.CryptoHash()) + } + + return nil, errInvalidSignatureAlgorithm +} + +//nolint:dupl,cyclop +func verifyCertificateVerify( + handshakeBodies []byte, + hashAlgorithm hash.Algorithm, + signatureAlgorithm signature.Algorithm, + remoteKeySignature []byte, + rawCertificates [][]byte, +) error { + if len(rawCertificates) == 0 { + return errLengthMismatch + } + certificate, err := x509.ParseCertificate(rawCertificates[0]) + if err != nil { + return err + } + + // Validate that the signature algorithm matches the certificate's OID + if err := validateSignatureAlgOID(certificate, signatureAlgorithm); err != nil { + return err + } + + switch pubKey := certificate.PublicKey.(type) { + case ed25519.PublicKey: + if ok := ed25519.Verify(pubKey, handshakeBodies, remoteKeySignature); !ok { + return errKeySignatureMismatch + } + + return nil + case *ecdsa.PublicKey: + ecdsaSig := &ecdsaSignature{} + if _, err := asn1.Unmarshal(remoteKeySignature, ecdsaSig); err != nil { + return err + } + if ecdsaSig.R.Sign() <= 0 || ecdsaSig.S.Sign() <= 0 { + return errInvalidECDSASignature + } + hash := hashAlgorithm.Digest(handshakeBodies) + if !ecdsa.Verify(pubKey, hash, ecdsaSig.R, ecdsaSig.S) { + return errKeySignatureMismatch + } + + return nil + case *rsa.PublicKey: + hash := hashAlgorithm.Digest(handshakeBodies) + + // Use RSA-PSS verification if the signature algorithm is PSS + if signatureAlgorithm.IsPSS() { + pssOpts := &rsa.PSSOptions{ + SaltLength: rsa.PSSSaltLengthEqualsHash, + Hash: hashAlgorithm.CryptoHash(), + } + if err := rsa.VerifyPSS(pubKey, hashAlgorithm.CryptoHash(), hash, remoteKeySignature, pssOpts); err != nil { + return errKeySignatureMismatch + } + + return nil + } + + // Otherwise use PKCS#1 v1.5 + if rsa.VerifyPKCS1v15(pubKey, hashAlgorithm.CryptoHash(), hash, remoteKeySignature) != nil { + return errKeySignatureMismatch + } + + return nil + } + + return errKeySignatureVerifyUnimplemented +} + +func loadCerts(rawCertificates [][]byte) ([]*x509.Certificate, error) { + if len(rawCertificates) == 0 { + return nil, errLengthMismatch + } + + certs := make([]*x509.Certificate, 0, len(rawCertificates)) + for _, rawCert := range rawCertificates { + cert, err := x509.ParseCertificate(rawCert) + if err != nil { + return nil, err + } + certs = append(certs, cert) + } + + return certs, nil +} + +func verifyClientCert( + rawCertificates [][]byte, + roots *x509.CertPool, + certSignatureSchemes []signaturehash.Algorithm, +) (chains [][]*x509.Certificate, err error) { + certificate, err := loadCerts(rawCertificates) + if err != nil { + return nil, err + } + intermediateCAPool := x509.NewCertPool() + for _, cert := range certificate[1:] { + intermediateCAPool.AddCert(cert) + } + opts := x509.VerifyOptions{ + Roots: roots, + CurrentTime: time.Now(), + Intermediates: intermediateCAPool, + KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + } + + chains, err = certificate[0].Verify(opts) + if err != nil { + return nil, err + } + + // Validate certificate signature algorithms if specified. + // At least one chain must use only allowed signature algorithms. + if len(certSignatureSchemes) > 0 && len(chains) > 0 { + var validChainFound bool + for _, chain := range chains { + if err := validateCertificateSignatureAlgorithms(chain, certSignatureSchemes); err == nil { + validChainFound = true + + break + } + } + if !validChainFound { + return nil, errInvalidCertificateSignatureAlgorithm + } + } + + return chains, nil +} + +func verifyServerCert( + rawCertificates [][]byte, + roots *x509.CertPool, + serverName string, + certSignatureSchemes []signaturehash.Algorithm, +) (chains [][]*x509.Certificate, err error) { + certificate, err := loadCerts(rawCertificates) + if err != nil { + return nil, err + } + intermediateCAPool := x509.NewCertPool() + for _, cert := range certificate[1:] { + intermediateCAPool.AddCert(cert) + } + opts := x509.VerifyOptions{ + Roots: roots, + CurrentTime: time.Now(), + DNSName: serverName, + Intermediates: intermediateCAPool, + } + + chains, err = certificate[0].Verify(opts) + if err != nil { + return nil, err + } + + // Validate certificate signature algorithms if specified. + // At least one chain must use only allowed signature algorithms. + if len(certSignatureSchemes) > 0 && len(chains) > 0 { + var validChainFound bool + for _, chain := range chains { + if err := validateCertificateSignatureAlgorithms(chain, certSignatureSchemes); err == nil { + validChainFound = true + + break + } + } + if !validChainFound { + return nil, errInvalidCertificateSignatureAlgorithm + } + } + + return chains, nil +} + +// validateCertificateSignatureAlgorithms validates that all certificates in the chain +// use signature algorithms that are in the allowed list. This implements the +// signature_algorithms_cert extension validation per RFC 8446 Section 4.2.3. +func validateCertificateSignatureAlgorithms( + certs []*x509.Certificate, + allowedAlgorithms []signaturehash.Algorithm, +) error { + if len(allowedAlgorithms) == 0 { + // No restrictions specified + return nil + } + + // Validate each certificate's signature algorithm (except the root, which we trust) + for i := 0; i < len(certs)-1; i++ { + cert := certs[i] + certAlg, err := signaturehash.FromCertificate(cert) + if err != nil { + return err + } + + // Check if this algorithm is in the allowed list + found := false + for _, allowed := range allowedAlgorithms { + if certAlg.Hash == allowed.Hash && certAlg.Signature == allowed.Signature { + found = true + + break + } + } + + if !found { + return errInvalidCertificateSignatureAlgorithm + } + } + + return nil +} diff --git a/vendor/github.com/pion/dtls/v3/dtls.go b/vendor/github.com/pion/dtls/v3/dtls.go new file mode 100644 index 0000000..6d52957 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/dtls.go @@ -0,0 +1,5 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +// Package dtls implements Datagram Transport Layer Security (DTLS) 1.2 +package dtls diff --git a/vendor/github.com/pion/dtls/v3/errors.go b/vendor/github.com/pion/dtls/v3/errors.go new file mode 100644 index 0000000..0db0de6 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/errors.go @@ -0,0 +1,308 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package dtls + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "os" + + "github.com/pion/dtls/v3/pkg/protocol" + "github.com/pion/dtls/v3/pkg/protocol/alert" +) + +// Typed errors. +var ( + ErrConnClosed = &FatalError{Err: errors.New("conn is closed")} //nolint:err113 + + errDeadlineExceeded = &TimeoutError{Err: fmt.Errorf("read/write timeout: %w", context.DeadlineExceeded)} + errInvalidContentType = &TemporaryError{Err: errors.New("invalid content type")} //nolint:err113 + + //nolint:err113 + errBufferTooSmall = &TemporaryError{Err: errors.New("buffer is too small")} + //nolint:err113 + errContextUnsupported = &TemporaryError{Err: errors.New("context is not supported for ExportKeyingMaterial")} + //nolint:err113 + errHandshakeInProgress = &TemporaryError{Err: errors.New("handshake is in progress")} + //nolint:err113 + errReservedExportKeyingMaterial = &TemporaryError{ + Err: errors.New("ExportKeyingMaterial can not be used with a reserved label"), + } + //nolint:err113 + errApplicationDataEpochZero = &TemporaryError{Err: errors.New("ApplicationData with epoch of 0")} + //nolint:err113 + errUnhandledContextType = &TemporaryError{Err: errors.New("unhandled contentType")} + + //nolint:err113 + errCertificateVerifyNoCertificate = &FatalError{ + Err: errors.New("client sent certificate verify but we have no certificate to verify"), + } + //nolint:err113 + errCipherSuiteNoIntersection = &FatalError{Err: errors.New("client+server do not support any shared cipher suites")} + //nolint:err113 + errClientCertificateNotVerified = &FatalError{Err: errors.New("client sent certificate but did not verify it")} + //nolint:err113 + errClientCertificateRequired = &FatalError{Err: errors.New("server required client verification, but got none")} + //nolint:err113 + errClientNoMatchingSRTPProfile = &FatalError{Err: errors.New("server responded with SRTP Profile we do not support")} + //nolint:err113 + errClientRequiredButNoServerEMS = &FatalError{ + Err: errors.New("client required Extended Master Secret extension, but server does not support it"), + } + //nolint:err113 + errCookieMismatch = &FatalError{Err: errors.New("client+server cookie does not match")} + //nolint:err113 + errIdentityNoPSK = &FatalError{Err: errors.New("PSK Identity Hint provided but PSK is nil")} + //nolint:err113 + errInvalidCertificate = &FatalError{Err: errors.New("no certificate provided")} + //nolint:err113 + errInvalidCipherSuite = &FatalError{Err: errors.New("invalid or unknown cipher suite")} + //nolint:err113 + errInvalidClientAuthType = &FatalError{Err: errors.New("invalid client auth type")} + //nolint:err113 + errInvalidECDSASignature = &FatalError{Err: errors.New("ECDSA signature contained zero or negative values")} + //nolint:err113 + errInvalidPrivateKey = &FatalError{Err: errors.New("invalid private key type")} + //nolint:err113 + errInvalidSignatureAlgorithm = &FatalError{Err: errors.New("invalid signature algorithm")} + //nolint:err113 + errInvalidExtendedMasterSecretType = &FatalError{Err: errors.New("invalid extended master secret type")} + //nolint:err113 + errInvalidCertificateSignatureAlgorithm = &FatalError{ + Err: errors.New("certificate uses a signature algorithm that is not allowed"), + } + //nolint:err113 + errKeySignatureMismatch = &FatalError{Err: errors.New("expected and actual key signature do not match")} + //nolint:err113 + errInvalidCertificateOID = &FatalError{Err: errors.New("certificate OID does not match signature algorithm")} + //nolint:err113 + errNilNextConn = &FatalError{Err: errors.New("Conn can not be created with a nil nextConn")} + //nolint:err113 + errNoAvailableCipherSuites = &FatalError{ + Err: errors.New("connection can not be created, no CipherSuites satisfy this Config"), + } + //nolint:err113 + errNoAvailablePSKCipherSuite = &FatalError{ + Err: errors.New("connection can not be created, pre-shared key present but no compatible CipherSuite"), + } + //nolint:err113 + errNoAvailableCertificateCipherSuite = &FatalError{ + Err: errors.New("connection can not be created, certificate present but no compatible CipherSuite"), + } + //nolint:err113 + errNoAvailableSignatureSchemes = &FatalError{ + Err: errors.New("connection can not be created, no SignatureScheme satisfy this Config"), + } + //nolint:err113 + errNoCertificates = &FatalError{Err: errors.New("no certificates configured")} + //nolint:err113 + errNoConfigProvided = &FatalError{Err: errors.New("no config provided")} + //nolint:err113 + errNoSupportedEllipticCurves = &FatalError{ + Err: errors.New("client requested zero or more elliptic curves that are not supported by the server"), + } + //nolint:err113 + errUnsupportedProtocolVersion = &FatalError{Err: errors.New("unsupported protocol version")} + //nolint:err113 + errPSKAndIdentityMustBeSetForClient = &FatalError{ + Err: errors.New("PSK and PSK Identity Hint must both be set for client"), + } + //nolint:err113 + errRequestedButNoSRTPExtension = &FatalError{ + Err: errors.New("SRTP support was requested but server did not respond with use_srtp extension"), + } + //nolint:err113 + errServerNoMatchingSRTPProfile = &FatalError{Err: errors.New("client requested SRTP but we have no matching profiles")} + //nolint:err113 + errServerRequiredButNoClientEMS = &FatalError{ + Err: errors.New("server requires the Extended Master Secret extension, but the client does not support it"), + } + //nolint:err113 + errVerifyDataMismatch = &FatalError{Err: errors.New("expected and actual verify data does not match")} + //nolint:err113 + errNotAcceptableCertificateChain = &FatalError{Err: errors.New("certificate chain is not signed by an acceptable CA")} + + //nolint:err113 + errInvalidFlight = &InternalError{Err: errors.New("invalid flight number")} + //nolint:err113 + errKeySignatureGenerateUnimplemented = &InternalError{ + Err: errors.New("unable to generate key signature, unimplemented"), + } + //nolint:err113 + errKeySignatureVerifyUnimplemented = &InternalError{Err: errors.New("unable to verify key signature, unimplemented")} + //nolint:err113 + errLengthMismatch = &InternalError{Err: errors.New("data length and declared length do not match")} + //nolint:err113 + errSequenceNumberOverflow = &InternalError{Err: errors.New("sequence number overflow")} + //nolint:err113 + errInvalidFSMTransition = &InternalError{Err: errors.New("invalid state machine transition")} + //nolint:err113 + errFailedToAccessPoolReadBuffer = &InternalError{Err: errors.New("failed to access pool read buffer")} + //nolint:err113 + errFragmentBufferOverflow = &InternalError{Err: errors.New("fragment buffer overflow")} + + //nolint:err113 + errEmptyCertificates = &FatalError{Err: errors.New("certificates option requires at least one certificate")} + //nolint:err113 + errEmptyCipherSuites = &FatalError{Err: errors.New("cipher suites option requires at least one cipher suite")} + //nolint:err113 + errNilCustomCipherSuites = &FatalError{Err: errors.New("custom cipher suites option requires a non-nil function")} + //nolint:err113 + errEmptySignatureSchemes = &FatalError{Err: errors.New("signature schemes option requires at least one scheme")} + //nolint:err113 + errEmptyCertificateSignatureSchemes = &FatalError{ + Err: errors.New("certificate signature schemes option requires at least one scheme"), + } + //nolint:err113 + errEmptySRTPProtectionProfiles = &FatalError{ + Err: errors.New("SRTP protection profiles option requires at least one profile"), + } + //nolint:err113 + errInvalidFlightInterval = &FatalError{Err: errors.New("flight interval must be positive")} + //nolint:err113 + errNilPSKCallback = &FatalError{Err: errors.New("PSK option requires a non-nil callback")} + //nolint:err113 + errNilVerifyPeerCertificate = &FatalError{ + Err: errors.New("verify peer certificate option requires a non-nil callback"), + } + //nolint:err113 + errNilVerifyConnection = &FatalError{Err: errors.New("verify connection option requires a non-nil callback")} + //nolint:err113 + errInvalidMTU = &FatalError{Err: errors.New("MTU must be positive")} + //nolint:err113 + errInvalidReplayProtectionWindow = &FatalError{Err: errors.New("replay protection window must be non-negative")} + //nolint:err113 + errEmptySupportedProtocols = &FatalError{ + Err: errors.New("supported protocols option requires at least one protocol"), + } + //nolint:err113 + errEmptyEllipticCurves = &FatalError{Err: errors.New("elliptic curves option requires at least one curve")} + //nolint:err113 + errNilGetClientCertificate = &FatalError{ + Err: errors.New("get client certificate option requires a non-nil callback"), + } + //nolint:err113 + errNilConnectionIDGenerator = &FatalError{ + Err: errors.New("connection ID generator option requires a non-nil function"), + } + //nolint:err113 + errNilPaddingLengthGenerator = &FatalError{ + Err: errors.New("padding length generator option requires a non-nil function"), + } + //nolint:err113 + errNilHelloRandomBytesGenerator = &FatalError{ + Err: errors.New("hello random bytes generator option requires a non-nil function"), + } + //nolint:err113 + errNilClientHelloMessageHook = &FatalError{ + Err: errors.New("client hello message hook option requires a non-nil function"), + } + //nolint:err113 + errNilGetCertificate = &FatalError{Err: errors.New("get certificate option requires a non-nil callback")} + //nolint:err113 + errNilServerHelloMessageHook = &FatalError{ + Err: errors.New("server hello message hook option requires a non-nil function"), + } + //nolint:err113 + errNilCertificateRequestMessageHook = &FatalError{ + Err: errors.New("certificate request message hook option requires a non-nil function"), + } + //nolint:err113 + errNilOnConnectionAttempt = &FatalError{ + Err: errors.New("on connection attempt option requires a non-nil callback"), + } +) + +// FatalError indicates that the DTLS connection is no longer available. +// It is mainly caused by wrong configuration of server or client. +type FatalError = protocol.FatalError + +// InternalError indicates and internal error caused by the implementation, +// and the DTLS connection is no longer available. +// It is mainly caused by bugs or tried to use unimplemented features. +type InternalError = protocol.InternalError + +// TemporaryError indicates that the DTLS connection is still available, but the request was failed temporary. +type TemporaryError = protocol.TemporaryError + +// TimeoutError indicates that the request was timed out. +type TimeoutError = protocol.TimeoutError + +// HandshakeError indicates that the handshake failed. +type HandshakeError = protocol.HandshakeError + +// errInvalidCipherSuite indicates an attempt at using an unsupported cipher suite. +type invalidCipherSuiteError struct { + id CipherSuiteID +} + +func (e *invalidCipherSuiteError) Error() string { + return fmt.Sprintf("CipherSuite with id(%d) is not valid", e.id) +} + +func (e *invalidCipherSuiteError) Is(err error) bool { + var other *invalidCipherSuiteError + if errors.As(err, &other) { + return e.id == other.id + } + + return false +} + +// errAlert wraps DTLS alert notification as an error. +type alertError struct { + *alert.Alert +} + +func (e *alertError) Error() string { + return fmt.Sprintf("alert: %s", e.Alert.String()) +} + +func (e *alertError) IsFatalOrCloseNotify() bool { + return e.Level == alert.Fatal || e.Description == alert.CloseNotify +} + +func (e *alertError) Is(err error) bool { + var other *alertError + if errors.As(err, &other) { + return e.Level == other.Level && e.Description == other.Description + } + + return false +} + +// netError translates an error from underlying Conn to corresponding net.Error. +func netError(err error) error { + switch { + case errors.Is(err, io.EOF), errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded): + // Return io.EOF and context errors as is. + return err + } + + var ( + ne net.Error + opError *net.OpError + se *os.SyscallError + ) + + if errors.As(err, &opError) { //nolint:nestif + if errors.As(opError, &se) { + if se.Timeout() { + return &TimeoutError{Err: err} + } + if isOpErrorTemporary(se) { + return &TemporaryError{Err: err} + } + } + } + + if errors.As(err, &ne) { + return err + } + + return &FatalError{Err: err} +} diff --git a/vendor/github.com/pion/dtls/v3/errors_errno.go b/vendor/github.com/pion/dtls/v3/errors_errno.go new file mode 100644 index 0000000..804b057 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/errors_errno.go @@ -0,0 +1,22 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +//go:build aix || darwin || dragonfly || freebsd || linux || nacl || nacljs || netbsd || openbsd || solaris || windows +// +build aix darwin dragonfly freebsd linux nacl nacljs netbsd openbsd solaris windows + +// For systems having syscall.Errno. +// Update build targets by following command: +// $ grep -R ECONN $(go env GOROOT)/src/syscall/zerrors_*.go \ +// | tr "." "_" | cut -d"_" -f"2" | sort | uniq + +package dtls + +import ( + "errors" + "os" + "syscall" +) + +func isOpErrorTemporary(err *os.SyscallError) bool { + return errors.Is(err.Err, syscall.ECONNREFUSED) +} diff --git a/vendor/github.com/pion/dtls/v3/errors_noerrno.go b/vendor/github.com/pion/dtls/v3/errors_noerrno.go new file mode 100644 index 0000000..7969dde --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/errors_noerrno.go @@ -0,0 +1,18 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +//go:build !aix && !darwin && !dragonfly && !freebsd && !linux && !nacl && !nacljs && !netbsd && !openbsd && !solaris && !windows +// +build !aix,!darwin,!dragonfly,!freebsd,!linux,!nacl,!nacljs,!netbsd,!openbsd,!solaris,!windows + +// For systems without syscall.Errno. +// Build targets must be inverse of errors_errno.go + +package dtls + +import ( + "os" +) + +func isOpErrorTemporary(err *os.SyscallError) bool { + return false +} diff --git a/vendor/github.com/pion/dtls/v3/flight.go b/vendor/github.com/pion/dtls/v3/flight.go new file mode 100644 index 0000000..0544177 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/flight.go @@ -0,0 +1,104 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package dtls + +/* + DTLS messages are grouped into a series of message flights, according + to the diagrams below. Although each flight of messages may consist + of a number of messages, they should be viewed as monolithic for the + purpose of timeout and retransmission. + https://tools.ietf.org/html/rfc4347#section-4.2.4 + + Message flights for full handshake: + + Client Server + ------ ------ + Waiting Flight 0 + + ClientHello --------> Flight 1 + + <------- HelloVerifyRequest Flight 2 + + ClientHello --------> Flight 3 + + ServerHello \ + Certificate* \ + ServerKeyExchange* Flight 4 + CertificateRequest* / + <-------- ServerHelloDone / + + Certificate* \ + ClientKeyExchange \ + CertificateVerify* Flight 5 + [ChangeCipherSpec] / + Finished --------> / + + [ChangeCipherSpec] \ Flight 6 + <-------- Finished / + + Message flights for session-resuming handshake (no cookie exchange): + + Client Server + ------ ------ + Waiting Flight 0 + + ClientHello --------> Flight 1 + + ServerHello \ + [ChangeCipherSpec] Flight 4b + <-------- Finished / + + [ChangeCipherSpec] \ Flight 5b + Finished --------> / + + [ChangeCipherSpec] \ Flight 6 + <-------- Finished / +*/ + +type flightVal uint8 + +const ( + flight0 flightVal = iota + 1 + flight1 + flight2 + flight3 + flight4 + flight4b + flight5 + flight5b + flight6 +) + +func (f flightVal) String() string { //nolint:cyclop + switch f { + case flight0: + return "Flight 0" + case flight1: + return "Flight 1" + case flight2: + return "Flight 2" + case flight3: + return "Flight 3" + case flight4: + return "Flight 4" + case flight4b: + return "Flight 4b" + case flight5: + return "Flight 5" + case flight5b: + return "Flight 5b" + case flight6: + return "Flight 6" + default: + return "Invalid Flight" + } +} + +func (f flightVal) isLastSendFlight() bool { + return f == flight6 || f == flight5b +} + +func (f flightVal) isLastRecvFlight() bool { + return f == flight5 || f == flight4b +} diff --git a/vendor/github.com/pion/dtls/v3/flight0handler.go b/vendor/github.com/pion/dtls/v3/flight0handler.go new file mode 100644 index 0000000..1a9452c --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/flight0handler.go @@ -0,0 +1,189 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package dtls + +import ( + "context" + "crypto/rand" + + "github.com/pion/dtls/v3/pkg/crypto/elliptic" + "github.com/pion/dtls/v3/pkg/protocol" + "github.com/pion/dtls/v3/pkg/protocol/alert" + "github.com/pion/dtls/v3/pkg/protocol/extension" + "github.com/pion/dtls/v3/pkg/protocol/handshake" +) + +// renegotiationInfoSCSV is TLS_EMPTY_RENEGOTIATION_INFO_SCSV defined in RFC 5746. +// https://datatracker.ietf.org/doc/html/rfc5746#section-3.3. +const renegotiationInfoSCSV uint16 = 0x00ff + +//nolint:cyclop,gocognit +func flight0Parse( + _ context.Context, + _ flightConn, + state *State, + cache *handshakeCache, + cfg *handshakeConfig, +) (flightVal, *alert.Alert, error) { + seq, msgs, ok := cache.fullPullMap(0, state.cipherSuite, + handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false}, + ) + if !ok { + // No valid message received. Keep reading + return 0, nil, nil + } + + // Connection Identifiers must be negotiated afresh on session resumption. + // https://datatracker.ietf.org/doc/html/rfc9146#name-the-connection_id-extension + state.setLocalConnectionID(nil) + state.remoteConnectionID = nil + + state.handshakeRecvSequence = seq + + var clientHello *handshake.MessageClientHello + + // Validate type + if clientHello, ok = msgs[handshake.TypeClientHello].(*handshake.MessageClientHello); !ok { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil + } + + if !clientHello.Version.Equal(protocol.Version1_2) { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.ProtocolVersion}, errUnsupportedProtocolVersion + } + + state.remoteRandom = clientHello.Random + + cipherSuites := []CipherSuite{} + for _, id := range clientHello.CipherSuiteIDs { + if id == renegotiationInfoSCSV { + state.remoteSupportsRenegotiation = true + + continue + } + if c := cipherSuiteForID(CipherSuiteID(id), cfg.customCipherSuites); c != nil { + cipherSuites = append(cipherSuites, c) + } + } + + if state.cipherSuite, ok = findMatchingCipherSuite(cipherSuites, cfg.localCipherSuites); !ok { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errCipherSuiteNoIntersection + } + + for _, val := range clientHello.Extensions { + switch ext := val.(type) { + case *extension.SupportedEllipticCurves: + if len(ext.EllipticCurves) == 0 { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errNoSupportedEllipticCurves + } + state.namedCurve = ext.EllipticCurves[0] + case *extension.UseSRTP: + profile, ok := findMatchingSRTPProfile(cfg.localSRTPProtectionProfiles, ext.ProtectionProfiles) + if !ok { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errServerNoMatchingSRTPProfile + } + state.setSRTPProtectionProfile(profile) + state.remoteSRTPMasterKeyIdentifier = ext.MasterKeyIdentifier + case *extension.UseExtendedMasterSecret: + if cfg.extendedMasterSecret != DisableExtendedMasterSecret { + state.extendedMasterSecret = true + } + case *extension.ServerName: + state.serverName = ext.ServerName // remote server name + case *extension.RenegotiationInfo: + state.remoteSupportsRenegotiation = true + case *extension.ALPN: + state.peerSupportedProtocols = ext.ProtocolNameList + case *extension.ConnectionID: + // Only set connection ID to be sent if server supports connection + // IDs. + if cfg.connectionIDGenerator != nil { + state.remoteConnectionID = ext.CID + } + case *extension.SignatureAlgorithmsCert: + // Store the client's certificate signature schemes for later validation + state.remoteCertSignatureSchemes = ext.SignatureHashAlgorithms + } + } + + // If the client doesn't support connection IDs, the server should not + // expect one to be sent. + if state.remoteConnectionID == nil { + state.setLocalConnectionID(nil) + } + + if cfg.extendedMasterSecret == RequireExtendedMasterSecret && !state.extendedMasterSecret { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errServerRequiredButNoClientEMS + } + + if state.localKeypair == nil { + var err error + state.localKeypair, err = elliptic.GenerateKeypair(state.namedCurve) + if err != nil { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, err + } + } + + nextFlight := flight2 + + if cfg.insecureSkipHelloVerify { + nextFlight = flight4 + } + + return handleHelloResume(clientHello.SessionID, state, cfg, nextFlight) +} + +func handleHelloResume( + sessionID []byte, + state *State, + cfg *handshakeConfig, + next flightVal, +) (flightVal, *alert.Alert, error) { + if len(sessionID) > 0 && cfg.sessionStore != nil { + if s, err := cfg.sessionStore.Get(sessionID); err != nil { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } else if s.ID != nil { + cfg.log.Tracef("[handshake] resume session: %x", sessionID) + + state.SessionID = sessionID + state.masterSecret = s.Secret + + if err := state.initCipherSuite(); err != nil { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + + clientRandom := state.localRandom.MarshalFixed() + cfg.writeKeyLog(keyLogLabelTLS12, clientRandom[:], state.masterSecret) + + return flight4b, nil, nil + } + } + + return next, nil, nil +} + +func flight0Generate( + _ flightConn, + state *State, + _ *handshakeCache, + cfg *handshakeConfig, +) ([]*packet, *alert.Alert, error) { + // Initialize + if !cfg.insecureSkipHelloVerify { + state.cookie = make([]byte, cookieLength) + if _, err := rand.Read(state.cookie); err != nil { + return nil, nil, err + } + } + + var zeroEpoch uint16 + state.localEpoch.Store(zeroEpoch) + state.remoteEpoch.Store(zeroEpoch) + state.namedCurve = defaultNamedCurve + + if err := state.localRandom.Populate(); err != nil { + return nil, nil, err + } + + return nil, nil, nil +} diff --git a/vendor/github.com/pion/dtls/v3/flight1handler.go b/vendor/github.com/pion/dtls/v3/flight1handler.go new file mode 100644 index 0000000..cc40684 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/flight1handler.go @@ -0,0 +1,189 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package dtls + +import ( + "context" + + "github.com/pion/dtls/v3/pkg/crypto/elliptic" + "github.com/pion/dtls/v3/pkg/protocol" + "github.com/pion/dtls/v3/pkg/protocol/alert" + "github.com/pion/dtls/v3/pkg/protocol/extension" + "github.com/pion/dtls/v3/pkg/protocol/handshake" + "github.com/pion/dtls/v3/pkg/protocol/recordlayer" +) + +func flight1Parse( + ctx context.Context, + conn flightConn, + state *State, + cache *handshakeCache, + cfg *handshakeConfig, +) (flightVal, *alert.Alert, error) { + // HelloVerifyRequest can be skipped by the server, + // so allow ServerHello during flight1 also + seq, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite, + handshakeCachePullRule{handshake.TypeHelloVerifyRequest, cfg.initialEpoch, false, true}, + handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, true}, + ) + if !ok { + // No valid message received. Keep reading + return 0, nil, nil + } + + if _, ok := msgs[handshake.TypeServerHello]; ok { + // Flight1 and flight2 were skipped. + // Parse as flight3. + return flight3Parse(ctx, conn, state, cache, cfg) + } + + if h, ok := msgs[handshake.TypeHelloVerifyRequest].(*handshake.MessageHelloVerifyRequest); ok { + // DTLS 1.2 clients must not assume that the server will use the protocol version + // specified in HelloVerifyRequest message. RFC 6347 Section 4.2.1 + if !h.Version.Equal(protocol.Version1_0) && !h.Version.Equal(protocol.Version1_2) { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.ProtocolVersion}, errUnsupportedProtocolVersion + } + state.cookie = append([]byte{}, h.Cookie...) + state.handshakeRecvSequence = seq + + return flight3, nil, nil + } + + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil +} + +//nolint:cyclop +func flight1Generate( + conn flightConn, + state *State, + _ *handshakeCache, + cfg *handshakeConfig, +) ([]*packet, *alert.Alert, error) { + var zeroEpoch uint16 + state.localEpoch.Store(zeroEpoch) + state.remoteEpoch.Store(zeroEpoch) + state.namedCurve = defaultNamedCurve + state.cookie = nil + + if err := state.localRandom.Populate(); err != nil { + return nil, nil, err + } + + if cfg.helloRandomBytesGenerator != nil { + state.localRandom.RandomBytes = cfg.helloRandomBytesGenerator() + } + + extensions := []extension.Extension{ + &extension.SupportedSignatureAlgorithms{ + SignatureHashAlgorithms: cfg.localSignatureSchemes, + }, + &extension.RenegotiationInfo{ + RenegotiatedConnection: 0, + }, + } + + if len(cfg.localCertSignatureSchemes) > 0 { + extensions = append(extensions, &extension.SignatureAlgorithmsCert{ + SignatureHashAlgorithms: cfg.localCertSignatureSchemes, + }) + } + + var setEllipticCurveCryptographyClientHelloExtensions bool + for _, c := range cfg.localCipherSuites { + if c.ECC() { + setEllipticCurveCryptographyClientHelloExtensions = true + + break + } + } + + if setEllipticCurveCryptographyClientHelloExtensions { + extensions = append(extensions, []extension.Extension{ + &extension.SupportedEllipticCurves{ + EllipticCurves: cfg.ellipticCurves, + }, + &extension.SupportedPointFormats{ + PointFormats: []elliptic.CurvePointFormat{elliptic.CurvePointFormatUncompressed}, + }, + }...) + } + + if len(cfg.localSRTPProtectionProfiles) > 0 { + extensions = append(extensions, &extension.UseSRTP{ + ProtectionProfiles: cfg.localSRTPProtectionProfiles, + MasterKeyIdentifier: cfg.localSRTPMasterKeyIdentifier, + }) + } + + if cfg.extendedMasterSecret == RequestExtendedMasterSecret || + cfg.extendedMasterSecret == RequireExtendedMasterSecret { + extensions = append(extensions, &extension.UseExtendedMasterSecret{ + Supported: true, + }) + } + + if len(cfg.serverName) > 0 { + extensions = append(extensions, &extension.ServerName{ServerName: cfg.serverName}) + } + + if len(cfg.supportedProtocols) > 0 { + extensions = append(extensions, &extension.ALPN{ProtocolNameList: cfg.supportedProtocols}) + } + + if cfg.sessionStore != nil { + cfg.log.Tracef("[handshake] try to resume session") + if s, err := cfg.sessionStore.Get(conn.sessionKey()); err != nil { + return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } else if s.ID != nil { + cfg.log.Tracef("[handshake] get saved session: %x", s.ID) + + state.SessionID = s.ID + state.masterSecret = s.Secret + } + } + + // If we have a connection ID generator, use it. The CID may be zero length, + // in which case we are just requesting that the server send us a CID to + // use. + if cfg.connectionIDGenerator != nil { + state.setLocalConnectionID(cfg.connectionIDGenerator()) + // The presence of a generator indicates support for connection IDs. We + // use the presence of a non-nil local CID in flight 3 to determine + // whether we send a CID in the second ClientHello, so we convert any + // nil CID returned by a generator to []byte{}. + if state.getLocalConnectionID() == nil { + state.setLocalConnectionID([]byte{}) + } + extensions = append(extensions, &extension.ConnectionID{CID: state.getLocalConnectionID()}) + } + + clientHello := &handshake.MessageClientHello{ + Version: protocol.Version1_2, + SessionID: state.SessionID, + Cookie: state.cookie, + Random: state.localRandom, + CipherSuiteIDs: cipherSuiteIDs(cfg.localCipherSuites), + CompressionMethods: defaultCompressionMethods(), + Extensions: extensions, + } + + var content handshake.Handshake + + if cfg.clientHelloMessageHook != nil { + content = handshake.Handshake{Message: cfg.clientHelloMessageHook(*clientHello)} + } else { + content = handshake.Handshake{Message: clientHello} + } + + return []*packet{ + { + record: &recordlayer.RecordLayer{ + Header: recordlayer.Header{ + Version: protocol.Version1_2, + }, + Content: &content, + }, + }, + }, nil, nil +} diff --git a/vendor/github.com/pion/dtls/v3/flight2handler.go b/vendor/github.com/pion/dtls/v3/flight2handler.go new file mode 100644 index 0000000..bfe8670 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/flight2handler.go @@ -0,0 +1,77 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package dtls + +import ( + "bytes" + "context" + + "github.com/pion/dtls/v3/pkg/protocol" + "github.com/pion/dtls/v3/pkg/protocol/alert" + "github.com/pion/dtls/v3/pkg/protocol/handshake" + "github.com/pion/dtls/v3/pkg/protocol/recordlayer" +) + +func flight2Parse( + ctx context.Context, + c flightConn, + state *State, + cache *handshakeCache, + cfg *handshakeConfig, +) (flightVal, *alert.Alert, error) { + seq, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite, + handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false}, + ) + if !ok { + // Client may retransmit the first ClientHello when HelloVerifyRequest is dropped. + // Parse as flight 0 in this case. + return flight0Parse(ctx, c, state, cache, cfg) + } + state.handshakeRecvSequence = seq + + var clientHello *handshake.MessageClientHello + + // Validate type + if clientHello, ok = msgs[handshake.TypeClientHello].(*handshake.MessageClientHello); !ok { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil + } + + if !clientHello.Version.Equal(protocol.Version1_2) { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.ProtocolVersion}, errUnsupportedProtocolVersion + } + + if len(clientHello.Cookie) == 0 { + return 0, nil, nil + } + if !bytes.Equal(state.cookie, clientHello.Cookie) { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.AccessDenied}, errCookieMismatch + } + + return flight4, nil, nil +} + +func flight2Generate( + _ flightConn, + state *State, + _ *handshakeCache, + _ *handshakeConfig, +) ([]*packet, *alert.Alert, error) { + state.handshakeSendSequence = 0 + + return []*packet{ + { + record: &recordlayer.RecordLayer{ + Header: recordlayer.Header{ + Version: protocol.Version1_2, + }, + Content: &handshake.Handshake{ + Message: &handshake.MessageHelloVerifyRequest{ + Version: protocol.Version1_2, + Cookie: state.cookie, + }, + }, + }, + }, + }, nil, nil +} diff --git a/vendor/github.com/pion/dtls/v3/flight3handler.go b/vendor/github.com/pion/dtls/v3/flight3handler.go new file mode 100644 index 0000000..169f81e --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/flight3handler.go @@ -0,0 +1,363 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package dtls + +import ( + "bytes" + "context" + + "github.com/pion/dtls/v3/internal/ciphersuite/types" + "github.com/pion/dtls/v3/pkg/crypto/elliptic" + "github.com/pion/dtls/v3/pkg/crypto/prf" + "github.com/pion/dtls/v3/pkg/protocol" + "github.com/pion/dtls/v3/pkg/protocol/alert" + "github.com/pion/dtls/v3/pkg/protocol/extension" + "github.com/pion/dtls/v3/pkg/protocol/handshake" + "github.com/pion/dtls/v3/pkg/protocol/recordlayer" +) + +//nolint:gocognit,gocyclo,maintidx,cyclop +func flight3Parse( + ctx context.Context, + conn flightConn, + state *State, + cache *handshakeCache, + cfg *handshakeConfig, +) (flightVal, *alert.Alert, error) { + // Clients may receive multiple HelloVerifyRequest messages with different cookies. + // Clients SHOULD handle this by sending a new ClientHello with a cookie in response + // to the new HelloVerifyRequest. RFC 6347 Section 4.2.1 + seq, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite, + handshakeCachePullRule{handshake.TypeHelloVerifyRequest, cfg.initialEpoch, false, true}, + ) + if ok { + if h, msgOk := msgs[handshake.TypeHelloVerifyRequest].(*handshake.MessageHelloVerifyRequest); msgOk { + // DTLS 1.2 clients must not assume that the server will use the protocol version + // specified in HelloVerifyRequest message. RFC 6347 Section 4.2.1 + if !h.Version.Equal(protocol.Version1_0) && !h.Version.Equal(protocol.Version1_2) { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.ProtocolVersion}, errUnsupportedProtocolVersion + } + state.cookie = append([]byte{}, h.Cookie...) + state.handshakeRecvSequence = seq + + return flight3, nil, nil + } + } + + _, msgs, ok = cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite, + handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false}, + ) + if !ok { + // Don't have enough messages. Keep reading + return 0, nil, nil + } + + if serverHelloMsg, msgOk := msgs[handshake.TypeServerHello].(*handshake.MessageServerHello); msgOk { //nolint:nestif + if !serverHelloMsg.Version.Equal(protocol.Version1_2) { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.ProtocolVersion}, errUnsupportedProtocolVersion + } + for _, v := range serverHelloMsg.Extensions { + switch ext := v.(type) { + case *extension.UseSRTP: + profile, found := findMatchingSRTPProfile(ext.ProtectionProfiles, cfg.localSRTPProtectionProfiles) + if !found { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, errClientNoMatchingSRTPProfile + } + state.setSRTPProtectionProfile(profile) + state.remoteSRTPMasterKeyIdentifier = ext.MasterKeyIdentifier + case *extension.UseExtendedMasterSecret: + if cfg.extendedMasterSecret != DisableExtendedMasterSecret { + state.extendedMasterSecret = true + } + case *extension.ALPN: + if len(ext.ProtocolNameList) > 1 { // This should be exactly 1, the zero case is handle when unmarshalling + return 0, &alert.Alert{ + Level: alert.Fatal, + Description: alert.InternalError, + }, extension.ErrALPNInvalidFormat // Meh, internal error? + } + state.NegotiatedProtocol = ext.ProtocolNameList[0] + case *extension.ConnectionID: + // Only set connection ID to be sent if client supports connection + // IDs. + if cfg.connectionIDGenerator != nil { + state.remoteConnectionID = ext.CID + } + } + } + // If the server doesn't support connection IDs, the client should not + // expect one to be sent. + if state.remoteConnectionID == nil { + state.setLocalConnectionID(nil) + } + + if cfg.extendedMasterSecret == RequireExtendedMasterSecret && !state.extendedMasterSecret { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errClientRequiredButNoServerEMS + } + if len(cfg.localSRTPProtectionProfiles) > 0 && state.getSRTPProtectionProfile() == 0 { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errRequestedButNoSRTPExtension + } + + remoteCipherSuite := cipherSuiteForID(CipherSuiteID(*serverHelloMsg.CipherSuiteID), cfg.customCipherSuites) + if remoteCipherSuite == nil { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errCipherSuiteNoIntersection + } + + selectedCipherSuite, found := findMatchingCipherSuite([]CipherSuite{remoteCipherSuite}, cfg.localCipherSuites) + if !found { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errInvalidCipherSuite + } + + state.cipherSuite = selectedCipherSuite + state.remoteRandom = serverHelloMsg.Random + cfg.log.Tracef("[handshake] use cipher suite: %s", selectedCipherSuite.String()) + + if len(serverHelloMsg.SessionID) > 0 && bytes.Equal(state.SessionID, serverHelloMsg.SessionID) { + return handleResumption(ctx, conn, state, cache, cfg) + } + + if len(state.SessionID) > 0 { + cfg.log.Tracef("[handshake] clean old session : %s", state.SessionID) + if err := cfg.sessionStore.Del(state.SessionID); err != nil { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + } + + if cfg.sessionStore == nil { + state.SessionID = []byte{} + } else { + state.SessionID = serverHelloMsg.SessionID + } + + state.masterSecret = []byte{} + } + + if cfg.localPSKCallback != nil { + seq, msgs, ok = cache.fullPullMap(state.handshakeRecvSequence+1, state.cipherSuite, + handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, true}, + handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false}, + ) + } else { + seq, msgs, ok = cache.fullPullMap(state.handshakeRecvSequence+1, state.cipherSuite, + handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, false, true}, + handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false}, + handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, true}, + handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false}, + ) + } + if !ok { + // Don't have enough messages. Keep reading + return 0, nil, nil + } + state.handshakeRecvSequence = seq + + if h, ok := msgs[handshake.TypeCertificate].(*handshake.MessageCertificate); ok { + state.PeerCertificates = h.Certificate + } else if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeCertificate { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.NoCertificate}, errInvalidCertificate + } + + if h, ok := msgs[handshake.TypeServerKeyExchange].(*handshake.MessageServerKeyExchange); ok { + alertPtr, err := handleServerKeyExchange(conn, state, cfg, h) + if err != nil { + return 0, alertPtr, err + } + } + + if creq, ok := msgs[handshake.TypeCertificateRequest].(*handshake.MessageCertificateRequest); ok { + state.remoteCertRequestAlgs = creq.SignatureHashAlgorithms + state.remoteRequestedCertificate = true + } + + return flight5, nil, nil +} + +func handleResumption( + ctx context.Context, + c flightConn, + state *State, + cache *handshakeCache, + cfg *handshakeConfig, +) (flightVal, *alert.Alert, error) { + if err := state.initCipherSuite(); err != nil { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + + // Now, encrypted packets can be handled + if err := c.handleQueuedPackets(ctx); err != nil { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + + _, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence+1, state.cipherSuite, + handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, false, false}, + ) + if !ok { + // No valid message received. Keep reading + return 0, nil, nil + } + + var finished *handshake.MessageFinished + if finished, ok = msgs[handshake.TypeFinished].(*handshake.MessageFinished); !ok { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil + } + plainText := cache.pullAndMerge( + handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false}, + handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false}, + ) + + expectedVerifyData, err := prf.VerifyDataServer(state.masterSecret, plainText, state.cipherSuite.HashFunc()) + if err != nil { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + if !bytes.Equal(expectedVerifyData, finished.VerifyData) { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errVerifyDataMismatch + } + + clientRandom := state.localRandom.MarshalFixed() + cfg.writeKeyLog(keyLogLabelTLS12, clientRandom[:], state.masterSecret) + + return flight5b, nil, nil +} + +//nolint:cyclop +func handleServerKeyExchange( + _ flightConn, + state *State, + cfg *handshakeConfig, + keyExchangeMessage *handshake.MessageServerKeyExchange, +) (*alert.Alert, error) { + var err error + if state.cipherSuite == nil { + return &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errInvalidCipherSuite + } + if cfg.localPSKCallback != nil { //nolint:nestif + var psk []byte + if psk, err = cfg.localPSKCallback(keyExchangeMessage.IdentityHint); err != nil { + return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + state.IdentityHint = keyExchangeMessage.IdentityHint + switch state.cipherSuite.KeyExchangeAlgorithm() { + case types.KeyExchangeAlgorithmPsk: + state.preMasterSecret = prf.PSKPreMasterSecret(psk) + case (types.KeyExchangeAlgorithmEcdhe | types.KeyExchangeAlgorithmPsk): + if state.localKeypair, err = elliptic.GenerateKeypair(keyExchangeMessage.NamedCurve); err != nil { + return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + state.preMasterSecret, err = prf.EcdhePSKPreMasterSecret( + psk, + keyExchangeMessage.PublicKey, + state.localKeypair.PrivateKey, + state.localKeypair.Curve, + ) + if err != nil { + return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + default: + return &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errInvalidCipherSuite + } + } else { + if state.localKeypair, err = elliptic.GenerateKeypair(keyExchangeMessage.NamedCurve); err != nil { + return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + + if state.preMasterSecret, err = prf.PreMasterSecret( + keyExchangeMessage.PublicKey, + state.localKeypair.PrivateKey, + state.localKeypair.Curve, + ); err != nil { + return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + } + + return nil, nil //nolint:nilnil +} + +func flight3Generate( + _ flightConn, + state *State, + _ *handshakeCache, + cfg *handshakeConfig, +) ([]*packet, *alert.Alert, error) { + extensions := []extension.Extension{ + &extension.SupportedSignatureAlgorithms{ + SignatureHashAlgorithms: cfg.localSignatureSchemes, + }, + &extension.RenegotiationInfo{ + RenegotiatedConnection: 0, + }, + } + + if len(cfg.localCertSignatureSchemes) > 0 { + extensions = append(extensions, &extension.SignatureAlgorithmsCert{ + SignatureHashAlgorithms: cfg.localCertSignatureSchemes, + }) + } + + if state.namedCurve != 0 { + extensions = append(extensions, []extension.Extension{ + &extension.SupportedEllipticCurves{ + EllipticCurves: cfg.ellipticCurves, + }, + &extension.SupportedPointFormats{ + PointFormats: []elliptic.CurvePointFormat{elliptic.CurvePointFormatUncompressed}, + }, + }...) + } + + if len(cfg.localSRTPProtectionProfiles) > 0 { + extensions = append(extensions, &extension.UseSRTP{ + ProtectionProfiles: cfg.localSRTPProtectionProfiles, + }) + } + + if cfg.extendedMasterSecret == RequestExtendedMasterSecret || + cfg.extendedMasterSecret == RequireExtendedMasterSecret { + extensions = append(extensions, &extension.UseExtendedMasterSecret{ + Supported: true, + }) + } + + if len(cfg.serverName) > 0 { + extensions = append(extensions, &extension.ServerName{ServerName: cfg.serverName}) + } + + if len(cfg.supportedProtocols) > 0 { + extensions = append(extensions, &extension.ALPN{ProtocolNameList: cfg.supportedProtocols}) + } + + // If we sent a connection ID on the first ClientHello, send it on the + // second. + if state.getLocalConnectionID() != nil { + extensions = append(extensions, &extension.ConnectionID{CID: state.getLocalConnectionID()}) + } + + clientHello := &handshake.MessageClientHello{ + Version: protocol.Version1_2, + SessionID: state.SessionID, + Cookie: state.cookie, + Random: state.localRandom, + CipherSuiteIDs: cipherSuiteIDs(cfg.localCipherSuites), + CompressionMethods: defaultCompressionMethods(), + Extensions: extensions, + } + + var content handshake.Handshake + + if cfg.clientHelloMessageHook != nil { + content = handshake.Handshake{Message: cfg.clientHelloMessageHook(*clientHello)} + } else { + content = handshake.Handshake{Message: clientHello} + } + + return []*packet{ + { + record: &recordlayer.RecordLayer{ + Header: recordlayer.Header{ + Version: protocol.Version1_2, + }, + Content: &content, + }, + }, + }, nil, nil +} diff --git a/vendor/github.com/pion/dtls/v3/flight4bhandler.go b/vendor/github.com/pion/dtls/v3/flight4bhandler.go new file mode 100644 index 0000000..6233fff --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/flight4bhandler.go @@ -0,0 +1,163 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package dtls + +import ( + "bytes" + "context" + + "github.com/pion/dtls/v3/pkg/crypto/prf" + "github.com/pion/dtls/v3/pkg/protocol" + "github.com/pion/dtls/v3/pkg/protocol/alert" + "github.com/pion/dtls/v3/pkg/protocol/extension" + "github.com/pion/dtls/v3/pkg/protocol/handshake" + "github.com/pion/dtls/v3/pkg/protocol/recordlayer" +) + +func flight4bParse( + _ context.Context, + _ flightConn, + state *State, + cache *handshakeCache, + cfg *handshakeConfig, +) (flightVal, *alert.Alert, error) { + _, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite, + handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, true, false}, + ) + if !ok { + // No valid message received. Keep reading + return 0, nil, nil + } + + var finished *handshake.MessageFinished + if finished, ok = msgs[handshake.TypeFinished].(*handshake.MessageFinished); !ok { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil + } + + plainText := cache.pullAndMerge( + handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false}, + handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false}, + handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, false, false}, + ) + + expectedVerifyData, err := prf.VerifyDataClient(state.masterSecret, plainText, state.cipherSuite.HashFunc()) + if err != nil { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + if !bytes.Equal(expectedVerifyData, finished.VerifyData) { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errVerifyDataMismatch + } + + // Other party may re-transmit the last flight. Keep state to be flight4b. + return flight4b, nil, nil +} + +//nolint:cyclop +func flight4bGenerate( + _ flightConn, + state *State, + cache *handshakeCache, + cfg *handshakeConfig, +) ([]*packet, *alert.Alert, error) { + var pkts []*packet + + extensions := []extension.Extension{&extension.RenegotiationInfo{ + RenegotiatedConnection: 0, + }} + if (cfg.extendedMasterSecret == RequestExtendedMasterSecret || + cfg.extendedMasterSecret == RequireExtendedMasterSecret) && state.extendedMasterSecret { + extensions = append(extensions, &extension.UseExtendedMasterSecret{ + Supported: true, + }) + } + if state.getSRTPProtectionProfile() != 0 { + extensions = append(extensions, &extension.UseSRTP{ + ProtectionProfiles: []SRTPProtectionProfile{state.getSRTPProtectionProfile()}, + MasterKeyIdentifier: cfg.localSRTPMasterKeyIdentifier, + }) + } + + selectedProto, err := extension.ALPNProtocolSelection(cfg.supportedProtocols, state.peerSupportedProtocols) + if err != nil { + return nil, &alert.Alert{Level: alert.Fatal, Description: alert.NoApplicationProtocol}, err + } + if selectedProto != "" { + extensions = append(extensions, &extension.ALPN{ + ProtocolNameList: []string{selectedProto}, + }) + state.NegotiatedProtocol = selectedProto + } + + cipherSuiteID := uint16(state.cipherSuite.ID()) + var serverHello handshake.Handshake + + serverHelloMessage := &handshake.MessageServerHello{ + Version: protocol.Version1_2, + Random: state.localRandom, + SessionID: state.SessionID, + CipherSuiteID: &cipherSuiteID, + CompressionMethod: defaultCompressionMethods()[0], + Extensions: extensions, + } + + if cfg.serverHelloMessageHook != nil { + serverHello = handshake.Handshake{Message: cfg.serverHelloMessageHook(*serverHelloMessage)} + } else { + serverHello = handshake.Handshake{Message: serverHelloMessage} + } + + serverHello.Header.MessageSequence = uint16(state.handshakeSendSequence) //nolint:gosec // G115 + + if len(state.localVerifyData) == 0 { + plainText := cache.pullAndMerge( + handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false}, + ) + raw, err := serverHello.Marshal() + if err != nil { + return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + plainText = append(plainText, raw...) + + state.localVerifyData, err = prf.VerifyDataServer(state.masterSecret, plainText, state.cipherSuite.HashFunc()) + if err != nil { + return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + } + + pkts = append(pkts, + &packet{ + record: &recordlayer.RecordLayer{ + Header: recordlayer.Header{ + Version: protocol.Version1_2, + }, + Content: &serverHello, + }, + }, + &packet{ + record: &recordlayer.RecordLayer{ + Header: recordlayer.Header{ + Version: protocol.Version1_2, + }, + Content: &protocol.ChangeCipherSpec{}, + }, + }, + &packet{ + record: &recordlayer.RecordLayer{ + Header: recordlayer.Header{ + Version: protocol.Version1_2, + Epoch: 1, + }, + Content: &handshake.Handshake{ + Message: &handshake.MessageFinished{ + VerifyData: state.localVerifyData, + }, + }, + }, + shouldEncrypt: true, + resetLocalSequenceNumber: true, + }, + ) + + return pkts, nil, nil +} diff --git a/vendor/github.com/pion/dtls/v3/flight4handler.go b/vendor/github.com/pion/dtls/v3/flight4handler.go new file mode 100644 index 0000000..4b5fc61 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/flight4handler.go @@ -0,0 +1,500 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package dtls + +import ( + "context" + "crypto" + "crypto/rand" + "crypto/x509" + + "github.com/pion/dtls/v3/internal/ciphersuite" + "github.com/pion/dtls/v3/pkg/crypto/clientcertificate" + "github.com/pion/dtls/v3/pkg/crypto/elliptic" + "github.com/pion/dtls/v3/pkg/crypto/prf" + "github.com/pion/dtls/v3/pkg/crypto/signaturehash" + "github.com/pion/dtls/v3/pkg/protocol" + "github.com/pion/dtls/v3/pkg/protocol/alert" + "github.com/pion/dtls/v3/pkg/protocol/extension" + "github.com/pion/dtls/v3/pkg/protocol/handshake" + "github.com/pion/dtls/v3/pkg/protocol/recordlayer" +) + +//nolint:gocognit,gocyclo,lll,cyclop,maintidx +func flight4Parse( + ctx context.Context, + conn flightConn, + state *State, + cache *handshakeCache, + cfg *handshakeConfig, +) (flightVal, *alert.Alert, error) { + seq, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite, + handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, true, true}, + handshakeCachePullRule{handshake.TypeClientKeyExchange, cfg.initialEpoch, true, false}, + handshakeCachePullRule{handshake.TypeCertificateVerify, cfg.initialEpoch, true, true}, + ) + if !ok { + // No valid message received. Keep reading + return 0, nil, nil + } + + // Validate type + var clientKeyExchange *handshake.MessageClientKeyExchange + if clientKeyExchange, ok = msgs[handshake.TypeClientKeyExchange].(*handshake.MessageClientKeyExchange); !ok { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil + } + + if h, hasCert := msgs[handshake.TypeCertificate].(*handshake.MessageCertificate); hasCert { + state.PeerCertificates = h.Certificate + // If the client offer its certificate, just disable session resumption. + // Otherwise, we have to store the certificate identitfication and expire time. + // And we have to check whether this certificate expired, revoked or changed. + // + // https://curl.se/docs/CVE-2016-5419.html + state.SessionID = nil + } + + //nolint:nestif + if verify, hasVerify := msgs[handshake.TypeCertificateVerify].(*handshake.MessageCertificateVerify); hasVerify { + if state.PeerCertificates == nil { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.NoCertificate}, errCertificateVerifyNoCertificate + } + + plainText := cache.pullAndMerge( + handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false}, + handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false}, + handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, false, false}, + handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false}, + handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, false}, + handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false}, + handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, true, false}, + handshakeCachePullRule{handshake.TypeClientKeyExchange, cfg.initialEpoch, true, false}, + ) + + // Verify that the pair of hash algorithm and signiture is listed. + var validSignatureScheme bool + for _, ss := range cfg.localSignatureSchemes { + if ss.Hash == verify.HashAlgorithm && ss.Signature == verify.SignatureAlgorithm { + validSignatureScheme = true + + break + } + } + if !validSignatureScheme { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errNoAvailableSignatureSchemes + } + + if err := verifyCertificateVerify( + plainText, + verify.HashAlgorithm, + verify.SignatureAlgorithm, + verify.Signature, + state.PeerCertificates, + ); err != nil { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err + } + var chains [][]*x509.Certificate + var err error + var verified bool + if cfg.clientAuth >= VerifyClientCertIfGiven { + // Use cert-specific algorithms if present, otherwise fall back to signature_algorithms per RFC 8446 + certAlgs := cfg.localCertSignatureSchemes + if len(certAlgs) == 0 { + certAlgs = cfg.localSignatureSchemes + } + if chains, err = verifyClientCert(state.PeerCertificates, cfg.clientCAs, certAlgs); err != nil { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err + } + verified = true + } + if cfg.verifyPeerCertificate != nil { + if err := cfg.verifyPeerCertificate(state.PeerCertificates, chains); err != nil { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err + } + } + state.peerCertificatesVerified = verified + } else if state.PeerCertificates != nil { + // A certificate was received, but we haven't seen a CertificateVerify + // keep reading until we receive one + return 0, nil, nil + } + + if !state.cipherSuite.IsInitialized() { //nolint:nestif + serverRandom := state.localRandom.MarshalFixed() + clientRandom := state.remoteRandom.MarshalFixed() + + var err error + var preMasterSecret []byte + if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypePreSharedKey { + var psk []byte + if psk, err = cfg.localPSKCallback(clientKeyExchange.IdentityHint); err != nil { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + state.IdentityHint = clientKeyExchange.IdentityHint + switch state.cipherSuite.KeyExchangeAlgorithm() { + case CipherSuiteKeyExchangeAlgorithmPsk: + preMasterSecret = prf.PSKPreMasterSecret(psk) + case (CipherSuiteKeyExchangeAlgorithmPsk | CipherSuiteKeyExchangeAlgorithmEcdhe): + if preMasterSecret, err = prf.EcdhePSKPreMasterSecret( + psk, + clientKeyExchange.PublicKey, + state.localKeypair.PrivateKey, + state.localKeypair.Curve, + ); err != nil { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + default: + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, errInvalidCipherSuite + } + } else { + preMasterSecret, err = prf.PreMasterSecret( + clientKeyExchange.PublicKey, + state.localKeypair.PrivateKey, + state.localKeypair.Curve, + ) + if err != nil { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, err + } + } + + if state.extendedMasterSecret { + var sessionHash []byte + sessionHash, err = cache.sessionHash(state.cipherSuite.HashFunc(), cfg.initialEpoch) + if err != nil { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + + state.masterSecret, err = prf.ExtendedMasterSecret(preMasterSecret, sessionHash, state.cipherSuite.HashFunc()) + if err != nil { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + } else { + state.masterSecret, err = prf.MasterSecret( + preMasterSecret, + clientRandom[:], + serverRandom[:], + state.cipherSuite.HashFunc(), + ) + if err != nil { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + } + + if err := state.cipherSuite.Init(state.masterSecret, clientRandom[:], serverRandom[:], false); err != nil { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + cfg.writeKeyLog(keyLogLabelTLS12, clientRandom[:], state.masterSecret) + } + + if len(state.SessionID) > 0 { + s := Session{ + ID: state.SessionID, + Secret: state.masterSecret, + } + cfg.log.Tracef("[handshake] save new session: %x", s.ID) + if err := cfg.sessionStore.Set(state.SessionID, s); err != nil { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + } + + // Now, encrypted packets can be handled + if err := conn.handleQueuedPackets(ctx); err != nil { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + + seq, msgs, ok = cache.fullPullMap(seq, state.cipherSuite, + handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, true, false}, + ) + if !ok { + // No valid message received. Keep reading + return 0, nil, nil + } + state.handshakeRecvSequence = seq + + if _, ok = msgs[handshake.TypeFinished].(*handshake.MessageFinished); !ok { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil + } + + if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeAnonymous { //nolint:nestif + if cfg.verifyConnection != nil { + stateClone, err := state.clone() + if err != nil { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + if err := cfg.verifyConnection(stateClone); err != nil { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err + } + } + + return flight6, nil, nil + } + + switch cfg.clientAuth { + case RequireAnyClientCert: + if state.PeerCertificates == nil { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.NoCertificate}, errClientCertificateRequired + } + case VerifyClientCertIfGiven: + if state.PeerCertificates != nil && !state.peerCertificatesVerified { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, errClientCertificateNotVerified + } + case RequireAndVerifyClientCert: + if state.PeerCertificates == nil { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.NoCertificate}, errClientCertificateRequired + } + if !state.peerCertificatesVerified { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, errClientCertificateNotVerified + } + case NoClientCert, RequestClientCert: + // go to flight6 + } + if cfg.verifyConnection != nil { + stateClone, err := state.clone() + if err != nil { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + if err := cfg.verifyConnection(stateClone); err != nil { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err + } + } + + return flight6, nil, nil +} + +//nolint:gocognit,cyclop,maintidx +func flight4Generate( + _ flightConn, + state *State, + _ *handshakeCache, + cfg *handshakeConfig, +) ([]*packet, *alert.Alert, error) { + extensions := []extension.Extension{} + + if (cfg.extendedMasterSecret == RequestExtendedMasterSecret || + cfg.extendedMasterSecret == RequireExtendedMasterSecret) && state.extendedMasterSecret { + extensions = append(extensions, &extension.UseExtendedMasterSecret{ + Supported: true, + }) + } + if state.getSRTPProtectionProfile() != 0 { + extensions = append(extensions, &extension.UseSRTP{ + ProtectionProfiles: []SRTPProtectionProfile{state.getSRTPProtectionProfile()}, + MasterKeyIdentifier: cfg.localSRTPMasterKeyIdentifier, + }) + } + if state.remoteSupportsRenegotiation { + extensions = append(extensions, &extension.RenegotiationInfo{ + RenegotiatedConnection: 0, + }) + } + if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeCertificate { + extensions = append(extensions, &extension.SupportedPointFormats{ + PointFormats: []elliptic.CurvePointFormat{elliptic.CurvePointFormatUncompressed}, + }) + } + + selectedProto, err := extension.ALPNProtocolSelection(cfg.supportedProtocols, state.peerSupportedProtocols) + if err != nil { + return nil, &alert.Alert{Level: alert.Fatal, Description: alert.NoApplicationProtocol}, err + } + if selectedProto != "" { + extensions = append(extensions, &extension.ALPN{ + ProtocolNameList: []string{selectedProto}, + }) + state.NegotiatedProtocol = selectedProto + } + + // If we have a connection ID generator, we are willing to use connection + // IDs. We already know whether the client supports connection IDs from + // parsing the ClientHello, so avoid setting local connection ID if the + // client won't send it. + if cfg.connectionIDGenerator != nil && state.remoteConnectionID != nil { + state.setLocalConnectionID(cfg.connectionIDGenerator()) + extensions = append(extensions, &extension.ConnectionID{CID: state.getLocalConnectionID()}) + } + + var pkts []*packet + cipherSuiteID := uint16(state.cipherSuite.ID()) + + if cfg.sessionStore != nil { + state.SessionID = make([]byte, sessionLength) + if _, err := rand.Read(state.SessionID); err != nil { + return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + } + + serverHello := &handshake.MessageServerHello{ + Version: protocol.Version1_2, + Random: state.localRandom, + SessionID: state.SessionID, + CipherSuiteID: &cipherSuiteID, + CompressionMethod: defaultCompressionMethods()[0], + Extensions: extensions, + } + + var content handshake.Handshake + + if cfg.serverHelloMessageHook != nil { + content = handshake.Handshake{Message: cfg.serverHelloMessageHook(*serverHello)} + } else { + content = handshake.Handshake{Message: serverHello} + } + + pkts = append(pkts, &packet{ + record: &recordlayer.RecordLayer{ + Header: recordlayer.Header{ + Version: protocol.Version1_2, + }, + Content: &content, + }, + }) + + switch { + case state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeCertificate: + certificate, err := cfg.getCertificate(&ClientHelloInfo{ + ServerName: state.serverName, + CipherSuites: []ciphersuite.ID{state.cipherSuite.ID()}, + RandomBytes: state.remoteRandom.RandomBytes, + }) + if err != nil { + return nil, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, err + } + + pkts = append(pkts, &packet{ + record: &recordlayer.RecordLayer{ + Header: recordlayer.Header{ + Version: protocol.Version1_2, + }, + Content: &handshake.Handshake{ + Message: &handshake.MessageCertificate{ + Certificate: certificate.Certificate, + }, + }, + }, + }) + + serverRandom := state.localRandom.MarshalFixed() + clientRandom := state.remoteRandom.MarshalFixed() + + signer, ok := certificate.PrivateKey.(crypto.Signer) + if !ok { + return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, errInvalidPrivateKey + } + + // Find compatible signature scheme + signatureHashAlgo, err := signaturehash.SelectSignatureScheme(cfg.localSignatureSchemes, signer) + if err != nil { + return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, err + } + + signature, err := generateKeySignature( + clientRandom[:], + serverRandom[:], + state.localKeypair.PublicKey, + state.namedCurve, + signer, + signatureHashAlgo.Hash, + signatureHashAlgo.Signature, + ) + if err != nil { + return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + state.localKeySignature = signature + + pkts = append(pkts, &packet{ + record: &recordlayer.RecordLayer{ + Header: recordlayer.Header{ + Version: protocol.Version1_2, + }, + Content: &handshake.Handshake{ + Message: &handshake.MessageServerKeyExchange{ + EllipticCurveType: elliptic.CurveTypeNamedCurve, + NamedCurve: state.namedCurve, + PublicKey: state.localKeypair.PublicKey, + HashAlgorithm: signatureHashAlgo.Hash, + SignatureAlgorithm: signatureHashAlgo.Signature, + Signature: state.localKeySignature, + }, + }, + }, + }) + + if cfg.clientAuth > NoClientCert { + // An empty list of certificateAuthorities signals to + // the client that it may send any certificate in response + // to our request. When we know the CAs we trust, then + // we can send them down, so that the client can choose + // an appropriate certificate to give to us. + var certificateAuthorities [][]byte + if cfg.clientCAs != nil { + // nolint:staticcheck // ignoring tlsCert.RootCAs.Subjects is deprecated ERR + // because cert does not come from SystemCertPool and it's ok if certificate + // authorities is empty. + certificateAuthorities = cfg.clientCAs.Subjects() + } + + certReq := &handshake.MessageCertificateRequest{ + CertificateTypes: []clientcertificate.Type{clientcertificate.RSASign, clientcertificate.ECDSASign}, + SignatureHashAlgorithms: cfg.localSignatureSchemes, + CertificateAuthoritiesNames: certificateAuthorities, + } + + var content handshake.Handshake + + if cfg.certificateRequestMessageHook != nil { + content = handshake.Handshake{Message: cfg.certificateRequestMessageHook(*certReq)} + } else { + content = handshake.Handshake{Message: certReq} + } + + pkts = append(pkts, &packet{ + record: &recordlayer.RecordLayer{ + Header: recordlayer.Header{ + Version: protocol.Version1_2, + }, + Content: &content, + }, + }) + } + case cfg.localPSKIdentityHint != nil || + state.cipherSuite.KeyExchangeAlgorithm().Has(CipherSuiteKeyExchangeAlgorithmEcdhe): + // To help the client in selecting which identity to use, the server + // can provide a "PSK identity hint" in the ServerKeyExchange message. + // If no hint is provided and cipher suite doesn't use elliptic curve, + // the ServerKeyExchange message is omitted. + // + // https://tools.ietf.org/html/rfc4279#section-2 + srvExchange := &handshake.MessageServerKeyExchange{ + IdentityHint: cfg.localPSKIdentityHint, + } + if state.cipherSuite.KeyExchangeAlgorithm().Has(CipherSuiteKeyExchangeAlgorithmEcdhe) { + srvExchange.EllipticCurveType = elliptic.CurveTypeNamedCurve + srvExchange.NamedCurve = state.namedCurve + srvExchange.PublicKey = state.localKeypair.PublicKey + } + pkts = append(pkts, &packet{ + record: &recordlayer.RecordLayer{ + Header: recordlayer.Header{ + Version: protocol.Version1_2, + }, + Content: &handshake.Handshake{ + Message: srvExchange, + }, + }, + }) + } + + pkts = append(pkts, &packet{ + record: &recordlayer.RecordLayer{ + Header: recordlayer.Header{ + Version: protocol.Version1_2, + }, + Content: &handshake.Handshake{ + Message: &handshake.MessageServerHelloDone{}, + }, + }, + }) + + return pkts, nil, nil +} diff --git a/vendor/github.com/pion/dtls/v3/flight5bhandler.go b/vendor/github.com/pion/dtls/v3/flight5bhandler.go new file mode 100644 index 0000000..f9c3be6 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/flight5bhandler.go @@ -0,0 +1,89 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package dtls + +import ( + "context" + + "github.com/pion/dtls/v3/pkg/crypto/prf" + "github.com/pion/dtls/v3/pkg/protocol" + "github.com/pion/dtls/v3/pkg/protocol/alert" + "github.com/pion/dtls/v3/pkg/protocol/handshake" + "github.com/pion/dtls/v3/pkg/protocol/recordlayer" +) + +func flight5bParse( + _ context.Context, + _ flightConn, + state *State, + cache *handshakeCache, + cfg *handshakeConfig, +) (flightVal, *alert.Alert, error) { + _, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence-1, state.cipherSuite, + handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, false, false}, + ) + if !ok { + // No valid message received. Keep reading + return 0, nil, nil + } + + if _, ok = msgs[handshake.TypeFinished].(*handshake.MessageFinished); !ok { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil + } + + // Other party may re-transmit the last flight. Keep state to be flight5b. + return flight5b, nil, nil +} + +func flight5bGenerate( + _ flightConn, + state *State, + cache *handshakeCache, + cfg *handshakeConfig, +) ([]*packet, *alert.Alert, error) { //nolint:gocognit + var pkts []*packet + + pkts = append(pkts, + &packet{ + record: &recordlayer.RecordLayer{ + Header: recordlayer.Header{ + Version: protocol.Version1_2, + }, + Content: &protocol.ChangeCipherSpec{}, + }, + }) + + if len(state.localVerifyData) == 0 { + plainText := cache.pullAndMerge( + handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false}, + handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false}, + handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, false, false}, + ) + + var err error + state.localVerifyData, err = prf.VerifyDataClient(state.masterSecret, plainText, state.cipherSuite.HashFunc()) + if err != nil { + return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + } + + pkts = append(pkts, + &packet{ + record: &recordlayer.RecordLayer{ + Header: recordlayer.Header{ + Version: protocol.Version1_2, + Epoch: 1, + }, + Content: &handshake.Handshake{ + Message: &handshake.MessageFinished{ + VerifyData: state.localVerifyData, + }, + }, + }, + shouldEncrypt: true, + resetLocalSequenceNumber: true, + }) + + return pkts, nil, nil +} diff --git a/vendor/github.com/pion/dtls/v3/flight5handler.go b/vendor/github.com/pion/dtls/v3/flight5handler.go new file mode 100644 index 0000000..5ab4364 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/flight5handler.go @@ -0,0 +1,410 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package dtls + +import ( + "bytes" + "context" + "crypto" + "crypto/x509" + + "github.com/pion/dtls/v3/pkg/crypto/prf" + "github.com/pion/dtls/v3/pkg/crypto/signaturehash" + "github.com/pion/dtls/v3/pkg/protocol" + "github.com/pion/dtls/v3/pkg/protocol/alert" + "github.com/pion/dtls/v3/pkg/protocol/handshake" + "github.com/pion/dtls/v3/pkg/protocol/recordlayer" +) + +func flight5Parse( + _ context.Context, + conn flightConn, + state *State, + cache *handshakeCache, + cfg *handshakeConfig, +) (flightVal, *alert.Alert, error) { + _, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite, + handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, false, false}, + ) + if !ok { + // No valid message received. Keep reading + return 0, nil, nil + } + + var finished *handshake.MessageFinished + if finished, ok = msgs[handshake.TypeFinished].(*handshake.MessageFinished); !ok { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil + } + plainText := cache.pullAndMerge( + handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false}, + handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false}, + handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, false, false}, + handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false}, + handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, false}, + handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false}, + handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, true, false}, + handshakeCachePullRule{handshake.TypeClientKeyExchange, cfg.initialEpoch, true, false}, + handshakeCachePullRule{handshake.TypeCertificateVerify, cfg.initialEpoch, true, false}, + handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, true, false}, + ) + + expectedVerifyData, err := prf.VerifyDataServer(state.masterSecret, plainText, state.cipherSuite.HashFunc()) + if err != nil { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + if !bytes.Equal(expectedVerifyData, finished.VerifyData) { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errVerifyDataMismatch + } + + if len(state.SessionID) > 0 { + s := Session{ + ID: state.SessionID, + Secret: state.masterSecret, + } + cfg.log.Tracef("[handshake] save new session: %x", s.ID) + if err := cfg.sessionStore.Set(conn.sessionKey(), s); err != nil { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + } + + return flight5, nil, nil +} + +//nolint:gocognit,cyclop,maintidx +func flight5Generate( + conn flightConn, + state *State, + cache *handshakeCache, + cfg *handshakeConfig, +) ([]*packet, *alert.Alert, error) { + var signer crypto.Signer + var pkts []*packet + if state.remoteRequestedCertificate { //nolint:nestif + _, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence-2, state.cipherSuite, + handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, false}) + if !ok { + return nil, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errClientCertificateRequired + } + reqInfo := CertificateRequestInfo{} + if r, ok2 := msgs[handshake.TypeCertificateRequest].(*handshake.MessageCertificateRequest); ok2 { + reqInfo.AcceptableCAs = r.CertificateAuthoritiesNames + } else { + return nil, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errClientCertificateRequired + } + certificate, err := cfg.getClientCertificate(&reqInfo) + if err != nil { + return nil, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, err + } + if certificate == nil { + return nil, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errNotAcceptableCertificateChain + } + if certificate.Certificate != nil { + signer, ok = certificate.PrivateKey.(crypto.Signer) + if !ok { + return nil, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errInvalidPrivateKey + } + } + pkts = append(pkts, + &packet{ + record: &recordlayer.RecordLayer{ + Header: recordlayer.Header{ + Version: protocol.Version1_2, + }, + Content: &handshake.Handshake{ + Message: &handshake.MessageCertificate{ + Certificate: certificate.Certificate, + }, + }, + }, + }) + } + + clientKeyExchange := &handshake.MessageClientKeyExchange{} + if cfg.localPSKCallback == nil { + clientKeyExchange.PublicKey = state.localKeypair.PublicKey + } else { + clientKeyExchange.IdentityHint = cfg.localPSKIdentityHint + } + if state != nil && state.localKeypair != nil && len(state.localKeypair.PublicKey) > 0 { + clientKeyExchange.PublicKey = state.localKeypair.PublicKey + } + + pkts = append(pkts, + &packet{ + record: &recordlayer.RecordLayer{ + Header: recordlayer.Header{ + Version: protocol.Version1_2, + }, + Content: &handshake.Handshake{ + Message: clientKeyExchange, + }, + }, + }) + + serverKeyExchangeData := cache.pullAndMerge( + handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false}, + ) + + serverKeyExchange := &handshake.MessageServerKeyExchange{} + + // handshakeMessageServerKeyExchange is optional for PSK + if len(serverKeyExchangeData) == 0 { + alertPtr, err := handleServerKeyExchange(conn, state, cfg, &handshake.MessageServerKeyExchange{}) + if err != nil { + return nil, alertPtr, err + } + } else { + rawHandshake := &handshake.Handshake{ + KeyExchangeAlgorithm: state.cipherSuite.KeyExchangeAlgorithm(), + } + err := rawHandshake.Unmarshal(serverKeyExchangeData) + if err != nil { + return nil, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, err + } + + switch h := rawHandshake.Message.(type) { + case *handshake.MessageServerKeyExchange: + serverKeyExchange = h + default: + return nil, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, errInvalidContentType + } + } + + // Append not-yet-sent packets + merged := []byte{} + seqPred := uint16(state.handshakeSendSequence) //nolint:gosec // G115 + for _, p := range pkts { + h, ok := p.record.Content.(*handshake.Handshake) + if !ok { + return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, errInvalidContentType + } + h.Header.MessageSequence = seqPred + seqPred++ + raw, err := h.Marshal() + if err != nil { + return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + merged = append(merged, raw...) + } + + if alertPtr, err := initializeCipherSuite(state, cache, cfg, serverKeyExchange, merged); err != nil { + return nil, alertPtr, err + } + + // If the client has sent a certificate with signing ability, a digitally-signed + // CertificateVerify message is sent to explicitly verify possession of the + // private key in the certificate. + if state.remoteRequestedCertificate && signer != nil { + plainText := append(cache.pullAndMerge( + handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false}, + handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false}, + handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, false, false}, + handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false}, + handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, false}, + handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false}, + handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, true, false}, + handshakeCachePullRule{handshake.TypeClientKeyExchange, cfg.initialEpoch, true, false}, + ), merged...) + + // Find compatible signature scheme + + signatureHashAlgo, err := signaturehash.SelectSignatureScheme(state.remoteCertRequestAlgs, signer) + if err != nil { + return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, err + } + + certVerify, err := generateCertificateVerify(plainText, signer, signatureHashAlgo.Hash, signatureHashAlgo.Signature) + if err != nil { + return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + state.localCertificatesVerify = certVerify + + pkt := &packet{ + record: &recordlayer.RecordLayer{ + Header: recordlayer.Header{ + Version: protocol.Version1_2, + }, + Content: &handshake.Handshake{ + Message: &handshake.MessageCertificateVerify{ + HashAlgorithm: signatureHashAlgo.Hash, + SignatureAlgorithm: signatureHashAlgo.Signature, + Signature: state.localCertificatesVerify, + }, + }, + }, + } + pkts = append(pkts, pkt) + + h, ok := pkt.record.Content.(*handshake.Handshake) + if !ok { + return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, errInvalidContentType + } + h.Header.MessageSequence = seqPred + // seqPred++ // this is the last use of seqPred + raw, err := h.Marshal() + if err != nil { + return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + merged = append(merged, raw...) + } + + pkts = append(pkts, + &packet{ + record: &recordlayer.RecordLayer{ + Header: recordlayer.Header{ + Version: protocol.Version1_2, + }, + Content: &protocol.ChangeCipherSpec{}, + }, + }) + + if len(state.localVerifyData) == 0 { + plainText := cache.pullAndMerge( + handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false}, + handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false}, + handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, false, false}, + handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false}, + handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, false}, + handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false}, + handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, true, false}, + handshakeCachePullRule{handshake.TypeClientKeyExchange, cfg.initialEpoch, true, false}, + handshakeCachePullRule{handshake.TypeCertificateVerify, cfg.initialEpoch, true, false}, + handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, true, false}, + ) + + var err error + state.localVerifyData, err = prf.VerifyDataClient( + state.masterSecret, + append(plainText, merged...), + state.cipherSuite.HashFunc(), + ) + if err != nil { + return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + } + + pkts = append(pkts, + &packet{ + record: &recordlayer.RecordLayer{ + Header: recordlayer.Header{ + Version: protocol.Version1_2, + Epoch: 1, + }, + Content: &handshake.Handshake{ + Message: &handshake.MessageFinished{ + VerifyData: state.localVerifyData, + }, + }, + }, + shouldWrapCID: len(state.remoteConnectionID) > 0, + shouldEncrypt: true, + resetLocalSequenceNumber: true, + }) + + return pkts, nil, nil +} + +//nolint:gocognit,cyclop +func initializeCipherSuite( + state *State, + cache *handshakeCache, + cfg *handshakeConfig, + handshakeKeyExchange *handshake.MessageServerKeyExchange, + sendingPlainText []byte, +) (*alert.Alert, error) { + if state.cipherSuite.IsInitialized() { + return nil, nil //nolint + } + + clientRandom := state.localRandom.MarshalFixed() + serverRandom := state.remoteRandom.MarshalFixed() + + var err error + + if state.extendedMasterSecret { + var sessionHash []byte + sessionHash, err = cache.sessionHash(state.cipherSuite.HashFunc(), cfg.initialEpoch, sendingPlainText) + if err != nil { + return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + + state.masterSecret, err = prf.ExtendedMasterSecret(state.preMasterSecret, sessionHash, state.cipherSuite.HashFunc()) + if err != nil { + return &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, err + } + } else { + state.masterSecret, err = prf.MasterSecret( + state.preMasterSecret, + clientRandom[:], + serverRandom[:], + state.cipherSuite.HashFunc(), + ) + if err != nil { + return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + } + + if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeCertificate { //nolint:nestif + // Verify that the pair of hash algorithm and signiture is listed. + var validSignatureScheme bool + for _, ss := range cfg.localSignatureSchemes { + if ss.Hash == handshakeKeyExchange.HashAlgorithm && ss.Signature == handshakeKeyExchange.SignatureAlgorithm { + validSignatureScheme = true + + break + } + } + if !validSignatureScheme { + return &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errNoAvailableSignatureSchemes + } + + expectedMsg := valueKeyMessage( + clientRandom[:], + serverRandom[:], + handshakeKeyExchange.PublicKey, + handshakeKeyExchange.NamedCurve, + ) + if err = verifyKeySignature( + expectedMsg, + handshakeKeyExchange.Signature, + handshakeKeyExchange.HashAlgorithm, + handshakeKeyExchange.SignatureAlgorithm, + state.PeerCertificates, + ); err != nil { + return &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err + } + var chains [][]*x509.Certificate + if !cfg.insecureSkipVerify { + certAlgs := cfg.localCertSignatureSchemes + if len(certAlgs) == 0 { + certAlgs = cfg.localSignatureSchemes + } + if chains, err = verifyServerCert(state.PeerCertificates, cfg.rootCAs, cfg.serverName, certAlgs); err != nil { + return &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err + } + } + if cfg.verifyPeerCertificate != nil { + if err = cfg.verifyPeerCertificate(state.PeerCertificates, chains); err != nil { + return &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err + } + } + } + if cfg.verifyConnection != nil { + stateClone, errC := state.clone() + if errC != nil { + return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, errC + } + if errC = cfg.verifyConnection(stateClone); errC != nil { + return &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, errC + } + } + + if err = state.cipherSuite.Init(state.masterSecret, clientRandom[:], serverRandom[:], true); err != nil { + return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + + cfg.writeKeyLog(keyLogLabelTLS12, clientRandom[:], state.masterSecret) + + return nil, nil //nolint +} diff --git a/vendor/github.com/pion/dtls/v3/flight6handler.go b/vendor/github.com/pion/dtls/v3/flight6handler.go new file mode 100644 index 0000000..bc9e818 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/flight6handler.go @@ -0,0 +1,98 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package dtls + +import ( + "context" + + "github.com/pion/dtls/v3/pkg/crypto/prf" + "github.com/pion/dtls/v3/pkg/protocol" + "github.com/pion/dtls/v3/pkg/protocol/alert" + "github.com/pion/dtls/v3/pkg/protocol/handshake" + "github.com/pion/dtls/v3/pkg/protocol/recordlayer" +) + +func flight6Parse( + _ context.Context, + _ flightConn, + state *State, + cache *handshakeCache, + cfg *handshakeConfig, +) (flightVal, *alert.Alert, error) { + _, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence-1, state.cipherSuite, + handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, true, false}, + ) + if !ok { + // No valid message received. Keep reading + return 0, nil, nil + } + + if _, ok = msgs[handshake.TypeFinished].(*handshake.MessageFinished); !ok { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil + } + + // Other party may re-transmit the last flight. Keep state to be flight6. + return flight6, nil, nil +} + +func flight6Generate( + _ flightConn, + state *State, + cache *handshakeCache, + cfg *handshakeConfig, +) ([]*packet, *alert.Alert, error) { + var pkts []*packet + + pkts = append(pkts, + &packet{ + record: &recordlayer.RecordLayer{ + Header: recordlayer.Header{ + Version: protocol.Version1_2, + }, + Content: &protocol.ChangeCipherSpec{}, + }, + }) + + if len(state.localVerifyData) == 0 { + plainText := cache.pullAndMerge( + handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false}, + handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false}, + handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, false, false}, + handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false}, + handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, false}, + handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false}, + handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, true, false}, + handshakeCachePullRule{handshake.TypeClientKeyExchange, cfg.initialEpoch, true, false}, + handshakeCachePullRule{handshake.TypeCertificateVerify, cfg.initialEpoch, true, false}, + handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, true, false}, + ) + + var err error + state.localVerifyData, err = prf.VerifyDataServer(state.masterSecret, plainText, state.cipherSuite.HashFunc()) + if err != nil { + return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + } + + pkts = append(pkts, + &packet{ + record: &recordlayer.RecordLayer{ + Header: recordlayer.Header{ + Version: protocol.Version1_2, + Epoch: 1, + }, + Content: &handshake.Handshake{ + Message: &handshake.MessageFinished{ + VerifyData: state.localVerifyData, + }, + }, + }, + shouldWrapCID: len(state.remoteConnectionID) > 0, + shouldEncrypt: true, + resetLocalSequenceNumber: true, + }, + ) + + return pkts, nil, nil +} diff --git a/vendor/github.com/pion/dtls/v3/flighthandler.go b/vendor/github.com/pion/dtls/v3/flighthandler.go new file mode 100644 index 0000000..0a79a07 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/flighthandler.go @@ -0,0 +1,74 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package dtls + +import ( + "context" + + "github.com/pion/dtls/v3/pkg/protocol/alert" +) + +// Parse received handshakes and return next flightVal. +type flightParser func( + context.Context, + flightConn, + *State, + *handshakeCache, + *handshakeConfig, +) (flightVal, *alert.Alert, error) + +// Generate flights. +type flightGenerator func(flightConn, *State, *handshakeCache, *handshakeConfig) ([]*packet, *alert.Alert, error) + +func (f flightVal) getFlightParser() (flightParser, error) { //nolint:cyclop + switch f { + case flight0: + return flight0Parse, nil + case flight1: + return flight1Parse, nil + case flight2: + return flight2Parse, nil + case flight3: + return flight3Parse, nil + case flight4: + return flight4Parse, nil + case flight4b: + return flight4bParse, nil + case flight5: + return flight5Parse, nil + case flight5b: + return flight5bParse, nil + case flight6: + return flight6Parse, nil + default: + return nil, errInvalidFlight + } +} + +func (f flightVal) getFlightGenerator() (gen flightGenerator, retransmit bool, err error) { //nolint:cyclop + switch f { + case flight0: + return flight0Generate, true, nil + case flight1: + return flight1Generate, true, nil + case flight2: + // https://tools.ietf.org/html/rfc6347#section-3.2.1 + // HelloVerifyRequests must not be retransmitted. + return flight2Generate, false, nil + case flight3: + return flight3Generate, true, nil + case flight4: + return flight4Generate, true, nil + case flight4b: + return flight4bGenerate, true, nil + case flight5: + return flight5Generate, true, nil + case flight5b: + return flight5bGenerate, true, nil + case flight6: + return flight6Generate, true, nil + default: + return nil, false, errInvalidFlight + } +} diff --git a/vendor/github.com/pion/dtls/v3/fragment_buffer.go b/vendor/github.com/pion/dtls/v3/fragment_buffer.go new file mode 100644 index 0000000..926f12d --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/fragment_buffer.go @@ -0,0 +1,151 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package dtls + +import ( + "github.com/pion/dtls/v3/pkg/protocol" + "github.com/pion/dtls/v3/pkg/protocol/handshake" + "github.com/pion/dtls/v3/pkg/protocol/recordlayer" +) + +const ( + // 2 megabytes. + fragmentBufferMaxSize = 2000000 + fragmentBufferMaxCount = 1000 +) + +type fragment struct { + recordLayerHeader recordlayer.Header + handshakeHeader handshake.Header + data []byte +} + +type fragments struct { + fragmentByOffset map[uint32]*fragment + fragmentsLength uint32 + handshakeLength uint32 +} + +type fragmentBuffer struct { + // map of MessageSequenceNumbers that hold slices of fragments + cache map[uint16]*fragments + + currentMessageSequenceNumber uint16 + + totalBufferSize int + totalFragmentCount int +} + +func newFragmentBuffer() *fragmentBuffer { + return &fragmentBuffer{cache: map[uint16]*fragments{}} +} + +// current total size of buffer. +func (f *fragmentBuffer) size() int { + return f.totalBufferSize +} + +// Attempts to push a DTLS packet to the fragmentBuffer +// when it returns true it means the fragmentBuffer has inserted and the buffer shouldn't be handled +// when an error returns it is fatal, and the DTLS connection should be stopped. +func (f *fragmentBuffer) push(buf []byte) (isHandshake, isRetransmit bool, err error) { //nolint:cyclop + if f.size()+len(buf) >= fragmentBufferMaxSize || f.totalFragmentCount >= fragmentBufferMaxCount { + return false, false, errFragmentBufferOverflow + } + + recordLayerHeader := recordlayer.Header{} + if err := recordLayerHeader.Unmarshal(buf); err != nil { + return false, false, err + } + + // fragment isn't a handshake, we don't need to handle it + if recordLayerHeader.ContentType != protocol.ContentTypeHandshake { + return false, false, nil + } + + frag := new(fragment) + for buf = buf[recordlayer.FixedHeaderSize:]; len(buf) != 0; frag = new(fragment) { //nolint:gosec // G602 + if err := frag.handshakeHeader.Unmarshal(buf); err != nil { + return false, false, err + } + + // Fragment is a retransmission. We have already assembled it before successfully + isRetransmit = frag.handshakeHeader.FragmentOffset == 0 && + frag.handshakeHeader.MessageSequence < f.currentMessageSequenceNumber + + end := int(handshake.HeaderLength + frag.handshakeHeader.FragmentLength) + if end > len(buf) { + return false, false, errBufferTooSmall + } + if frag.handshakeHeader.MessageSequence < f.currentMessageSequenceNumber { + buf = buf[end:] + + continue + } + + messageFragments, ok := f.cache[frag.handshakeHeader.MessageSequence] + if !ok { + messageFragments = &fragments{ + fragmentByOffset: map[uint32]*fragment{}, handshakeLength: frag.handshakeHeader.Length, + } + f.cache[frag.handshakeHeader.MessageSequence] = messageFragments + } + + // Discard all headers, when rebuilding the packet we will re-build + frag.data = append([]byte{}, buf[handshake.HeaderLength:end]...) + frag.recordLayerHeader = recordLayerHeader + + if _, ok = messageFragments.fragmentByOffset[frag.handshakeHeader.FragmentOffset]; !ok { + messageFragments.fragmentByOffset[frag.handshakeHeader.FragmentOffset] = frag + messageFragments.fragmentsLength += frag.handshakeHeader.FragmentLength + f.totalBufferSize += int(frag.handshakeHeader.FragmentLength) + f.totalFragmentCount++ + } + buf = buf[end:] + } + + return true, isRetransmit, nil +} + +func (f *fragmentBuffer) pop() (content []byte, epoch uint16) { + frags, ok := f.cache[f.currentMessageSequenceNumber] + if !ok { + return nil, 0 + } + + if frags.fragmentsLength != frags.handshakeLength { + return nil, 0 + } + + var rawMessage []byte + targetOffset := uint32(0) + for i := 0; i < len(frags.fragmentByOffset) && targetOffset < frags.handshakeLength; i++ { + if frag, ok := frags.fragmentByOffset[targetOffset]; ok { + rawMessage = append(rawMessage, frag.data...) + targetOffset = frag.handshakeHeader.FragmentOffset + frag.handshakeHeader.FragmentLength + } else { + return nil, 0 + } + } + + if int(frags.handshakeLength) != len(rawMessage) { + return nil, 0 + } + + firstHeader := frags.fragmentByOffset[0].handshakeHeader + firstHeader.FragmentOffset = 0 + firstHeader.FragmentLength = firstHeader.Length + + rawHeader, _ := firstHeader.Marshal() + + messageEpoch := frags.fragmentByOffset[0].recordLayerHeader.Epoch + + f.totalBufferSize -= int(frags.fragmentsLength) + f.totalFragmentCount -= len(frags.fragmentByOffset) + + delete(f.cache, f.currentMessageSequenceNumber) + f.currentMessageSequenceNumber++ + + return append(rawHeader, rawMessage...), messageEpoch +} diff --git a/vendor/github.com/pion/dtls/v3/handshake_cache.go b/vendor/github.com/pion/dtls/v3/handshake_cache.go new file mode 100644 index 0000000..53e9515 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/handshake_cache.go @@ -0,0 +1,185 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package dtls + +import ( + "sync" + + "github.com/pion/dtls/v3/pkg/crypto/prf" + "github.com/pion/dtls/v3/pkg/protocol/handshake" +) + +type handshakeCacheItem struct { + typ handshake.Type + isClient bool + epoch uint16 + messageSequence uint16 + data []byte +} + +type handshakeCachePullRule struct { + typ handshake.Type + epoch uint16 + isClient bool + optional bool +} + +type handshakeCache struct { + cache []*handshakeCacheItem + mu sync.Mutex +} + +func newHandshakeCache() *handshakeCache { + return &handshakeCache{} +} + +func (h *handshakeCache) push(data []byte, epoch, messageSequence uint16, typ handshake.Type, isClient bool) { + h.mu.Lock() + defer h.mu.Unlock() + + h.cache = append(h.cache, &handshakeCacheItem{ + data: append([]byte{}, data...), + epoch: epoch, + messageSequence: messageSequence, + typ: typ, + isClient: isClient, + }) +} + +// returns a list handshakes that match the requested rules +// the list will contain null entries for rules that can't be satisfied +// multiple entries may match a rule, but only the last match is returned (ie ClientHello with cookies). +func (h *handshakeCache) pull(rules ...handshakeCachePullRule) []*handshakeCacheItem { + h.mu.Lock() + defer h.mu.Unlock() + + out := make([]*handshakeCacheItem, len(rules)) + for i, r := range rules { + for _, c := range h.cache { + if c.typ == r.typ && c.isClient == r.isClient && c.epoch == r.epoch { + switch { + case out[i] == nil: + out[i] = c + case out[i].messageSequence < c.messageSequence: + out[i] = c + } + } + } + } + + return out +} + +// fullPullMap pulls all handshakes between rules[0] to rules[len(rules)-1] as map. +// +//nolint:cyclop +func (h *handshakeCache) fullPullMap( + startSeq int, + cipherSuite CipherSuite, + rules ...handshakeCachePullRule, +) (int, map[handshake.Type]handshake.Message, bool) { + h.mu.Lock() + defer h.mu.Unlock() + + ci := make(map[handshake.Type]*handshakeCacheItem) + for _, rule := range rules { + var item *handshakeCacheItem + for _, c := range h.cache { + if c.typ == rule.typ && c.isClient == rule.isClient && c.epoch == rule.epoch { + switch { + case item == nil: + item = c + case item.messageSequence < c.messageSequence: + item = c + } + } + } + if !rule.optional && item == nil { + // Missing mandatory message. + return startSeq, nil, false + } + ci[rule.typ] = item + } + out := make(map[handshake.Type]handshake.Message) + seq := startSeq + ok := false + for _, r := range rules { + typ := r.typ + i := ci[typ] + if i == nil { + continue + } + var keyExchangeAlgorithm CipherSuiteKeyExchangeAlgorithm + if cipherSuite != nil { + keyExchangeAlgorithm = cipherSuite.KeyExchangeAlgorithm() + } + rawHandshake := &handshake.Handshake{ + KeyExchangeAlgorithm: keyExchangeAlgorithm, + } + if err := rawHandshake.Unmarshal(i.data); err != nil { + return startSeq, nil, false + } + if uint16(seq) != rawHandshake.Header.MessageSequence { //nolint:gosec // G115 + // There is a gap. Some messages are not arrived. + return startSeq, nil, false + } + seq++ + ok = true + out[typ] = rawHandshake.Message + } + if !ok { + return seq, nil, false + } + + return seq, out, true +} + +// pullAndMerge calls pull and then merges the results, ignoring any null entries. +func (h *handshakeCache) pullAndMerge(rules ...handshakeCachePullRule) []byte { + merged := []byte{} + + for _, p := range h.pull(rules...) { + if p != nil { + merged = append(merged, p.data...) + } + } + + return merged +} + +// sessionHash returns the session hash for Extended Master Secret support +// https://tools.ietf.org/html/draft-ietf-tls-session-hash-06#section-4 +func (h *handshakeCache) sessionHash(hf prf.HashFunc, epoch uint16, additional ...[]byte) ([]byte, error) { + merged := []byte{} + + // Order defined by https://tools.ietf.org/html/rfc5246#section-7.3 + handshakeBuffer := h.pull( + handshakeCachePullRule{handshake.TypeClientHello, epoch, true, false}, + handshakeCachePullRule{handshake.TypeServerHello, epoch, false, false}, + handshakeCachePullRule{handshake.TypeCertificate, epoch, false, false}, + handshakeCachePullRule{handshake.TypeServerKeyExchange, epoch, false, false}, + handshakeCachePullRule{handshake.TypeCertificateRequest, epoch, false, false}, + handshakeCachePullRule{handshake.TypeServerHelloDone, epoch, false, false}, + handshakeCachePullRule{handshake.TypeCertificate, epoch, true, false}, + handshakeCachePullRule{handshake.TypeClientKeyExchange, epoch, true, false}, + ) + + for _, p := range handshakeBuffer { + if p == nil { + continue + } + + merged = append(merged, p.data...) + } + for _, a := range additional { + merged = append(merged, a...) + } + + hash := hf() + if _, err := hash.Write(merged); err != nil { + return []byte{}, err + } + + return hash.Sum(nil), nil +} diff --git a/vendor/github.com/pion/dtls/v3/handshaker.go b/vendor/github.com/pion/dtls/v3/handshaker.go new file mode 100644 index 0000000..74fbfbe --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/handshaker.go @@ -0,0 +1,364 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package dtls + +import ( + "context" + "crypto/tls" + "crypto/x509" + "fmt" + "io" + "sync" + "time" + + "github.com/pion/dtls/v3/pkg/crypto/elliptic" + "github.com/pion/dtls/v3/pkg/crypto/signaturehash" + "github.com/pion/dtls/v3/pkg/protocol/alert" + "github.com/pion/dtls/v3/pkg/protocol/handshake" + "github.com/pion/logging" +) + +// [RFC6347 Section-4.2.4] +// +-----------+ +// +---> | PREPARING | <--------------------+ +// | +-----------+ | +// | | | +// | | Buffer next flight | +// | | | +// | \|/ | +// | +-----------+ | +// | | SENDING |<------------------+ | Send +// | +-----------+ | | HelloRequest +// Receive | | | | +// next | | Send flight | | or +// flight | +--------+ | | +// | | | Set retransmit timer | | Receive +// | | \|/ | | HelloRequest +// | | +-----------+ | | Send +// +--)--| WAITING |-------------------+ | ClientHello +// | | +-----------+ Timer expires | | +// | | | | | +// | | +------------------------+ | +// Receive | | Send Read retransmit | +// last | | last | +// flight | | flight | +// | | | +// \|/\|/ | +// +-----------+ | +// | FINISHED | -------------------------------+ +// +-----------+ +// | /|\ +// | | +// +---+ +// Read retransmit +// Retransmit last flight + +type handshakeState uint8 + +const ( + handshakeErrored handshakeState = iota + handshakePreparing + handshakeSending + handshakeWaiting + handshakeFinished +) + +func (s handshakeState) String() string { + switch s { + case handshakeErrored: + return "Errored" + case handshakePreparing: + return "Preparing" + case handshakeSending: + return "Sending" + case handshakeWaiting: + return "Waiting" + case handshakeFinished: + return "Finished" + default: + return "Unknown" + } +} + +type handshakeFSM struct { + currentFlight flightVal + flights []*packet + retransmit bool + retransmitInterval time.Duration + state *State + cache *handshakeCache + cfg *handshakeConfig + closed chan struct{} +} + +type handshakeConfig struct { + localPSKCallback PSKCallback + localPSKIdentityHint []byte + localCipherSuites []CipherSuite // Available CipherSuites + localSignatureSchemes []signaturehash.Algorithm // Available signature schemes + localCertSignatureSchemes []signaturehash.Algorithm // Available signature schemes for certificates + extendedMasterSecret ExtendedMasterSecretType // Policy for the Extended Master Support extension + localSRTPProtectionProfiles []SRTPProtectionProfile // Available SRTPProtectionProfiles, if empty no SRTP support + localSRTPMasterKeyIdentifier []byte + serverName string + supportedProtocols []string + clientAuth ClientAuthType // If we are a client should we request a client certificate + localCertificates []tls.Certificate + nameToCertificate map[string]*tls.Certificate + insecureSkipVerify bool + verifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error + verifyConnection func(*State) error + sessionStore SessionStore + rootCAs *x509.CertPool + clientCAs *x509.CertPool + initialRetransmitInterval time.Duration + disableRetransmitBackoff bool + customCipherSuites func() []CipherSuite + ellipticCurves []elliptic.Curve + insecureSkipHelloVerify bool + connectionIDGenerator func() []byte + helloRandomBytesGenerator func() [handshake.RandomBytesLength]byte + + onFlightState func(flightVal, handshakeState) + log logging.LeveledLogger + keyLogWriter io.Writer + + localGetCertificate func(*ClientHelloInfo) (*tls.Certificate, error) + localGetClientCertificate func(*CertificateRequestInfo) (*tls.Certificate, error) + + initialEpoch uint16 + + mu sync.Mutex + + clientHelloMessageHook func(handshake.MessageClientHello) handshake.Message + serverHelloMessageHook func(handshake.MessageServerHello) handshake.Message + certificateRequestMessageHook func(handshake.MessageCertificateRequest) handshake.Message + + resumeState *State +} + +type flightConn interface { + notify(ctx context.Context, level alert.Level, desc alert.Description) error + writePackets(context.Context, []*packet) error + recvHandshake() <-chan recvHandshakeState + setLocalEpoch(epoch uint16) + handleQueuedPackets(context.Context) error + sessionKey() []byte +} + +func (c *handshakeConfig) writeKeyLog(label string, clientRandom, secret []byte) { + if c.keyLogWriter == nil { + return + } + c.mu.Lock() + defer c.mu.Unlock() + _, err := fmt.Fprintf(c.keyLogWriter, "%s %x %x\n", label, clientRandom, secret) + if err != nil { + c.log.Debugf("failed to write key log file: %s", err) + } +} + +func srvCliStr(isClient bool) string { + if isClient { + return "client" + } + + return "server" +} + +func newHandshakeFSM( + s *State, cache *handshakeCache, cfg *handshakeConfig, + initialFlight flightVal, +) *handshakeFSM { + return &handshakeFSM{ + currentFlight: initialFlight, + state: s, + cache: cache, + cfg: cfg, + retransmitInterval: cfg.initialRetransmitInterval, + closed: make(chan struct{}), + } +} + +func (s *handshakeFSM) Run(ctx context.Context, conn flightConn, initialState handshakeState) error { + state := initialState + defer func() { + close(s.closed) + }() + for { + s.cfg.log.Tracef("[handshake:%s] %s: %s", srvCliStr(s.state.isClient), s.currentFlight.String(), state.String()) + if s.cfg.onFlightState != nil { + s.cfg.onFlightState(s.currentFlight, state) + } + var err error + switch state { + case handshakePreparing: + state, err = s.prepare(ctx, conn) + case handshakeSending: + state, err = s.send(ctx, conn) + case handshakeWaiting: + state, err = s.wait(ctx, conn) + case handshakeFinished: + state, err = s.finish(ctx, conn) + default: + return errInvalidFSMTransition + } + if err != nil { + return err + } + } +} + +func (s *handshakeFSM) Done() <-chan struct{} { + return s.closed +} + +func (s *handshakeFSM) prepare(ctx context.Context, conn flightConn) (handshakeState, error) { + s.flights = nil + // Prepare flights + var ( + dtlsAlert *alert.Alert + err error + pkts []*packet + ) + gen, retransmit, errFlight := s.currentFlight.getFlightGenerator() + if errFlight != nil { + err = errFlight + dtlsAlert = &alert.Alert{Level: alert.Fatal, Description: alert.InternalError} + } else { + pkts, dtlsAlert, err = gen(conn, s.state, s.cache, s.cfg) + s.retransmit = retransmit + } + if dtlsAlert != nil { + if alertErr := conn.notify(ctx, dtlsAlert.Level, dtlsAlert.Description); alertErr != nil { + if err != nil { + err = alertErr + } + } + } + if err != nil { + return handshakeErrored, err + } + + s.flights = pkts + epoch := s.cfg.initialEpoch + nextEpoch := epoch + for _, p := range s.flights { + p.record.Header.Epoch += epoch + if p.record.Header.Epoch > nextEpoch { + nextEpoch = p.record.Header.Epoch + } + if h, ok := p.record.Content.(*handshake.Handshake); ok { + h.Header.MessageSequence = uint16(s.state.handshakeSendSequence) //nolint:gosec // G115 + s.state.handshakeSendSequence++ + } + } + if epoch != nextEpoch { + s.cfg.log.Tracef("[handshake:%s] -> changeCipherSpec (epoch: %d)", srvCliStr(s.state.isClient), nextEpoch) + conn.setLocalEpoch(nextEpoch) + } + + return handshakeSending, nil +} + +func (s *handshakeFSM) send(ctx context.Context, c flightConn) (handshakeState, error) { + // Send flights + if err := c.writePackets(ctx, s.flights); err != nil { + return handshakeErrored, err + } + + if s.currentFlight.isLastSendFlight() { + return handshakeFinished, nil + } + + return handshakeWaiting, nil +} + +func (s *handshakeFSM) wait(ctx context.Context, conn flightConn) (handshakeState, error) { //nolint:gocognit,cyclop + parse, errFlight := s.currentFlight.getFlightParser() + if errFlight != nil { + if alertErr := conn.notify(ctx, alert.Fatal, alert.InternalError); alertErr != nil { + return handshakeErrored, alertErr + } + + return handshakeErrored, errFlight + } + + retransmitTimer := time.NewTimer(s.retransmitInterval) + for { + select { + case state := <-conn.recvHandshake(): + if state.isRetransmit { + close(state.done) + // ignore incoming retransmit hints, only rely on the timer-driven path below + // https://github.com/pion/dtls/issues/758 + continue + } + + nextFlight, alert, err := parse(ctx, conn, s.state, s.cache, s.cfg) + s.retransmitInterval = s.cfg.initialRetransmitInterval + close(state.done) + if alert != nil { + if alertErr := conn.notify(ctx, alert.Level, alert.Description); alertErr != nil { + if err != nil { + err = alertErr + } + } + } + if err != nil { + return handshakeErrored, err + } + if nextFlight == 0 { + break + } + s.cfg.log.Tracef( + "[handshake:%s] %s -> %s", + srvCliStr(s.state.isClient), + s.currentFlight.String(), + nextFlight.String(), + ) + if nextFlight.isLastRecvFlight() && s.currentFlight == nextFlight { + return handshakeFinished, nil + } + s.currentFlight = nextFlight + + return handshakePreparing, nil + + case <-retransmitTimer.C: + if !s.retransmit { + return handshakeWaiting, nil + } + + // RFC 4347 4.2.4.1: + // Implementations SHOULD use an initial timer value of 1 second (the minimum defined in RFC 2988 [RFC2988]) + // and double the value at each retransmission, up to no less than the RFC 2988 maximum of 60 seconds. + if !s.cfg.disableRetransmitBackoff { + s.retransmitInterval *= 2 + } + if s.retransmitInterval > time.Second*60 { + s.retransmitInterval = time.Second * 60 + } + + return handshakeSending, nil + case <-ctx.Done(): + s.retransmitInterval = s.cfg.initialRetransmitInterval + + return handshakeErrored, ctx.Err() + } + } +} + +func (s *handshakeFSM) finish(ctx context.Context, c flightConn) (handshakeState, error) { + select { + case state := <-c.recvHandshake(): + close(state.done) + if s.state.isClient { + return handshakeFinished, nil + } else { + return handshakeSending, nil + } + case <-ctx.Done(): + return handshakeErrored, ctx.Err() + } +} diff --git a/vendor/github.com/pion/dtls/v3/internal/ciphersuite/aes_128_ccm.go b/vendor/github.com/pion/dtls/v3/internal/ciphersuite/aes_128_ccm.go new file mode 100644 index 0000000..9731ae4 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/internal/ciphersuite/aes_128_ccm.go @@ -0,0 +1,41 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package ciphersuite + +import ( + "github.com/pion/dtls/v3/pkg/crypto/ciphersuite" + "github.com/pion/dtls/v3/pkg/crypto/clientcertificate" +) + +// Aes128Ccm is a base class used by multiple AES-CCM Ciphers. +type Aes128Ccm struct { + AesCcm +} + +func newAes128Ccm( + clientCertificateType clientcertificate.Type, + id ID, + psk bool, + cryptoCCMTagLen ciphersuite.CCMTagLen, + keyExchangeAlgorithm KeyExchangeAlgorithm, + ecc bool, +) *Aes128Ccm { + return &Aes128Ccm{ + AesCcm: AesCcm{ + clientCertificateType: clientCertificateType, + id: id, + psk: psk, + cryptoCCMTagLen: cryptoCCMTagLen, + keyExchangeAlgorithm: keyExchangeAlgorithm, + ecc: ecc, + }, + } +} + +// Init initializes the internal Cipher with keying material. +func (c *Aes128Ccm) Init(masterSecret, clientRandom, serverRandom []byte, isClient bool) error { + const prfKeyLen = 16 + + return c.AesCcm.Init(masterSecret, clientRandom, serverRandom, isClient, prfKeyLen) +} diff --git a/vendor/github.com/pion/dtls/v3/internal/ciphersuite/aes_256_ccm.go b/vendor/github.com/pion/dtls/v3/internal/ciphersuite/aes_256_ccm.go new file mode 100644 index 0000000..c0e1a0b --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/internal/ciphersuite/aes_256_ccm.go @@ -0,0 +1,41 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package ciphersuite + +import ( + "github.com/pion/dtls/v3/pkg/crypto/ciphersuite" + "github.com/pion/dtls/v3/pkg/crypto/clientcertificate" +) + +// Aes256Ccm is a base class used by multiple AES-CCM Ciphers. +type Aes256Ccm struct { + AesCcm +} + +func newAes256Ccm( + clientCertificateType clientcertificate.Type, + id ID, + psk bool, + cryptoCCMTagLen ciphersuite.CCMTagLen, + keyExchangeAlgorithm KeyExchangeAlgorithm, + ecc bool, +) *Aes256Ccm { + return &Aes256Ccm{ + AesCcm: AesCcm{ + clientCertificateType: clientCertificateType, + id: id, + psk: psk, + cryptoCCMTagLen: cryptoCCMTagLen, + keyExchangeAlgorithm: keyExchangeAlgorithm, + ecc: ecc, + }, + } +} + +// Init initializes the internal Cipher with keying material. +func (c *Aes256Ccm) Init(masterSecret, clientRandom, serverRandom []byte, isClient bool) error { + const prfKeyLen = 32 + + return c.AesCcm.Init(masterSecret, clientRandom, serverRandom, isClient, prfKeyLen) +} diff --git a/vendor/github.com/pion/dtls/v3/internal/ciphersuite/aes_ccm.go b/vendor/github.com/pion/dtls/v3/internal/ciphersuite/aes_ccm.go new file mode 100644 index 0000000..5ec1f57 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/internal/ciphersuite/aes_ccm.go @@ -0,0 +1,120 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package ciphersuite + +import ( + "crypto/sha256" + "fmt" + "hash" + "sync/atomic" + + "github.com/pion/dtls/v3/pkg/crypto/ciphersuite" + "github.com/pion/dtls/v3/pkg/crypto/clientcertificate" + "github.com/pion/dtls/v3/pkg/crypto/prf" + "github.com/pion/dtls/v3/pkg/protocol/recordlayer" +) + +// AesCcm is a base class used by multiple AES-CCM Ciphers. +type AesCcm struct { + ccm atomic.Value // *cryptoCCM + clientCertificateType clientcertificate.Type + id ID + psk bool + keyExchangeAlgorithm KeyExchangeAlgorithm + cryptoCCMTagLen ciphersuite.CCMTagLen + ecc bool +} + +// CertificateType returns what type of certificate this CipherSuite exchanges. +func (c *AesCcm) CertificateType() clientcertificate.Type { + return c.clientCertificateType +} + +// ID returns the ID of the CipherSuite. +func (c *AesCcm) ID() ID { + return c.id +} + +func (c *AesCcm) String() string { + return c.id.String() +} + +// ECC uses Elliptic Curve Cryptography. +func (c *AesCcm) ECC() bool { + return c.ecc +} + +// KeyExchangeAlgorithm controls what key exchange algorithm is using during the handshake. +func (c *AesCcm) KeyExchangeAlgorithm() KeyExchangeAlgorithm { + return c.keyExchangeAlgorithm +} + +// HashFunc returns the hashing func for this CipherSuite. +func (c *AesCcm) HashFunc() func() hash.Hash { + return sha256.New +} + +// AuthenticationType controls what authentication method is using during the handshake. +func (c *AesCcm) AuthenticationType() AuthenticationType { + if c.psk { + return AuthenticationTypePreSharedKey + } + + return AuthenticationTypeCertificate +} + +// IsInitialized returns if the CipherSuite has keying material and can +// encrypt/decrypt packets. +func (c *AesCcm) IsInitialized() bool { + return c.ccm.Load() != nil +} + +// Init initializes the internal Cipher with keying material. +func (c *AesCcm) Init(masterSecret, clientRandom, serverRandom []byte, isClient bool, prfKeyLen int) error { + const ( + prfMacLen = 0 + prfIvLen = 4 + ) + + keys, err := prf.GenerateEncryptionKeys( + masterSecret, clientRandom, serverRandom, prfMacLen, prfKeyLen, prfIvLen, c.HashFunc(), + ) + if err != nil { + return err + } + + var ccm *ciphersuite.CCM + if isClient { + ccm, err = ciphersuite.NewCCM( + c.cryptoCCMTagLen, keys.ClientWriteKey, keys.ClientWriteIV, keys.ServerWriteKey, keys.ServerWriteIV, + ) + } else { + ccm, err = ciphersuite.NewCCM( + c.cryptoCCMTagLen, keys.ServerWriteKey, keys.ServerWriteIV, keys.ClientWriteKey, keys.ClientWriteIV, + ) + } + c.ccm.Store(ccm) + + return err +} + +// Encrypt encrypts a single TLS RecordLayer. +func (c *AesCcm) Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error) { + cipherSuite, ok := c.ccm.Load().(*ciphersuite.CCM) + if !ok { + return nil, fmt.Errorf("%w, unable to encrypt", errCipherSuiteNotInit) + } + + return cipherSuite.Encrypt(pkt, raw) +} + +// Decrypt decrypts a single TLS RecordLayer. +func (c *AesCcm) Decrypt(h recordlayer.Header, raw []byte) ([]byte, error) { + cipherSuite, ok := c.ccm.Load().(*ciphersuite.CCM) + if !ok { + return nil, fmt.Errorf("%w, unable to decrypt", errCipherSuiteNotInit) + } + + return cipherSuite.Decrypt(h, raw) +} diff --git a/vendor/github.com/pion/dtls/v3/internal/ciphersuite/ciphersuite.go b/vendor/github.com/pion/dtls/v3/internal/ciphersuite/ciphersuite.go new file mode 100644 index 0000000..c27e319 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/internal/ciphersuite/ciphersuite.go @@ -0,0 +1,100 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +// Package ciphersuite provides TLS Ciphers as registered with the IANA +// https://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml#tls-parameters-4 +package ciphersuite + +import ( + "errors" + "fmt" + + "github.com/pion/dtls/v3/internal/ciphersuite/types" + "github.com/pion/dtls/v3/pkg/protocol" +) + +//nolint:err113 +var errCipherSuiteNotInit = &protocol.TemporaryError{Err: errors.New("CipherSuite has not been initialized")} + +// ID is an ID for our supported CipherSuites. +type ID uint16 + +func (i ID) String() string { //nolint:cyclop + switch i { + case TLS_ECDHE_ECDSA_WITH_AES_128_CCM: + return "TLS_ECDHE_ECDSA_WITH_AES_128_CCM" + case TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8: + return "TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8" + case TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: + return "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256" + case TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256: + return "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256" + case TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA: + return "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA" + case TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA: + return "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA" + case TLS_PSK_WITH_AES_128_CCM: + return "TLS_PSK_WITH_AES_128_CCM" + case TLS_PSK_WITH_AES_128_CCM_8: + return "TLS_PSK_WITH_AES_128_CCM_8" + case TLS_PSK_WITH_AES_256_CCM_8: + return "TLS_PSK_WITH_AES_256_CCM_8" + case TLS_PSK_WITH_AES_128_GCM_SHA256: + return "TLS_PSK_WITH_AES_128_GCM_SHA256" + case TLS_PSK_WITH_AES_128_CBC_SHA256: + return "TLS_PSK_WITH_AES_128_CBC_SHA256" + case TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384: + return "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384" + case TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384: + return "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384" + case TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256: + return "TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256" + default: + return fmt.Sprintf("unknown(%v)", uint16(i)) + } +} + +// Supported Cipher Suites. +const ( + // AES-128-CCM. + TLS_ECDHE_ECDSA_WITH_AES_128_CCM ID = 0xc0ac // nolint: revive,staticcheck + TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 ID = 0xc0ae // nolint: revive,staticcheck + + // AES-128-GCM-SHA256. + TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 ID = 0xc02b // nolint: revive,staticcheck + TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 ID = 0xc02f // nolint: revive,staticcheck + + TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 ID = 0xc02c // nolint: revive,staticcheck + TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 ID = 0xc030 // nolint: revive,staticcheck + // AES-256-CBC-SHA. + TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA ID = 0xc00a // nolint: revive,staticcheck + TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA ID = 0xc014 // nolint: revive,staticcheck + + TLS_PSK_WITH_AES_128_CCM ID = 0xc0a4 // nolint: revive,staticcheck + TLS_PSK_WITH_AES_128_CCM_8 ID = 0xc0a8 // nolint: revive,staticcheck + TLS_PSK_WITH_AES_256_CCM_8 ID = 0xc0a9 // nolint: revive,staticcheck + TLS_PSK_WITH_AES_128_GCM_SHA256 ID = 0x00a8 // nolint: revive,staticcheck + TLS_PSK_WITH_AES_128_CBC_SHA256 ID = 0x00ae // nolint: revive,staticcheck + + TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256 ID = 0xC037 // nolint: revive,staticcheck +) + +// AuthenticationType controls what authentication method is using during the handshake. +type AuthenticationType = types.AuthenticationType + +// AuthenticationType Enums. +const ( + AuthenticationTypeCertificate AuthenticationType = types.AuthenticationTypeCertificate + AuthenticationTypePreSharedKey AuthenticationType = types.AuthenticationTypePreSharedKey + AuthenticationTypeAnonymous AuthenticationType = types.AuthenticationTypeAnonymous +) + +// KeyExchangeAlgorithm controls what exchange algorithm was chosen. +type KeyExchangeAlgorithm = types.KeyExchangeAlgorithm + +// KeyExchangeAlgorithm Bitmask. +const ( + KeyExchangeAlgorithmNone KeyExchangeAlgorithm = types.KeyExchangeAlgorithmNone + KeyExchangeAlgorithmPsk KeyExchangeAlgorithm = types.KeyExchangeAlgorithmPsk + KeyExchangeAlgorithmEcdhe KeyExchangeAlgorithm = types.KeyExchangeAlgorithmEcdhe +) diff --git a/vendor/github.com/pion/dtls/v3/internal/ciphersuite/tls_ecdhe_ecdsa_with_aes_128_ccm.go b/vendor/github.com/pion/dtls/v3/internal/ciphersuite/tls_ecdhe_ecdsa_with_aes_128_ccm.go new file mode 100644 index 0000000..07b38c7 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/internal/ciphersuite/tls_ecdhe_ecdsa_with_aes_128_ccm.go @@ -0,0 +1,21 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package ciphersuite + +import ( + "github.com/pion/dtls/v3/pkg/crypto/ciphersuite" + "github.com/pion/dtls/v3/pkg/crypto/clientcertificate" +) + +// NewTLSEcdheEcdsaWithAes128Ccm constructs a TLS_ECDHE_ECDSA_WITH_AES_128_CCM Cipher. +func NewTLSEcdheEcdsaWithAes128Ccm() *Aes128Ccm { + return newAes128Ccm( + clientcertificate.ECDSASign, + TLS_ECDHE_ECDSA_WITH_AES_128_CCM, + false, + ciphersuite.CCMTagLength, + KeyExchangeAlgorithmEcdhe, + true, + ) +} diff --git a/vendor/github.com/pion/dtls/v3/internal/ciphersuite/tls_ecdhe_ecdsa_with_aes_128_ccm8.go b/vendor/github.com/pion/dtls/v3/internal/ciphersuite/tls_ecdhe_ecdsa_with_aes_128_ccm8.go new file mode 100644 index 0000000..4a02551 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/internal/ciphersuite/tls_ecdhe_ecdsa_with_aes_128_ccm8.go @@ -0,0 +1,21 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package ciphersuite + +import ( + "github.com/pion/dtls/v3/pkg/crypto/ciphersuite" + "github.com/pion/dtls/v3/pkg/crypto/clientcertificate" +) + +// NewTLSEcdheEcdsaWithAes128Ccm8 creates a new TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 CipherSuite. +func NewTLSEcdheEcdsaWithAes128Ccm8() *Aes128Ccm { + return newAes128Ccm( + clientcertificate.ECDSASign, + TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8, + false, + ciphersuite.CCMTagLength8, + KeyExchangeAlgorithmEcdhe, + true, + ) +} diff --git a/vendor/github.com/pion/dtls/v3/internal/ciphersuite/tls_ecdhe_ecdsa_with_aes_128_gcm_sha256.go b/vendor/github.com/pion/dtls/v3/internal/ciphersuite/tls_ecdhe_ecdsa_with_aes_128_gcm_sha256.go new file mode 100644 index 0000000..6f38618 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/internal/ciphersuite/tls_ecdhe_ecdsa_with_aes_128_gcm_sha256.go @@ -0,0 +1,116 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package ciphersuite + +import ( + "crypto/sha256" + "fmt" + "hash" + "sync/atomic" + + "github.com/pion/dtls/v3/pkg/crypto/ciphersuite" + "github.com/pion/dtls/v3/pkg/crypto/clientcertificate" + "github.com/pion/dtls/v3/pkg/crypto/prf" + "github.com/pion/dtls/v3/pkg/protocol/recordlayer" +) + +// TLSEcdheEcdsaWithAes128GcmSha256 represents a TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 CipherSuite. +type TLSEcdheEcdsaWithAes128GcmSha256 struct { + gcm atomic.Value // *cryptoGCM +} + +// CertificateType returns what type of certficate this CipherSuite exchanges. +func (c *TLSEcdheEcdsaWithAes128GcmSha256) CertificateType() clientcertificate.Type { + return clientcertificate.ECDSASign +} + +// KeyExchangeAlgorithm controls what key exchange algorithm is using during the handshake. +func (c *TLSEcdheEcdsaWithAes128GcmSha256) KeyExchangeAlgorithm() KeyExchangeAlgorithm { + return KeyExchangeAlgorithmEcdhe +} + +// ECC uses Elliptic Curve Cryptography. +func (c *TLSEcdheEcdsaWithAes128GcmSha256) ECC() bool { + return true +} + +// ID returns the ID of the CipherSuite. +func (c *TLSEcdheEcdsaWithAes128GcmSha256) ID() ID { + return TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 +} + +func (c *TLSEcdheEcdsaWithAes128GcmSha256) String() string { + return "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256" +} + +// HashFunc returns the hashing func for this CipherSuite. +func (c *TLSEcdheEcdsaWithAes128GcmSha256) HashFunc() func() hash.Hash { + return sha256.New +} + +// AuthenticationType controls what authentication method is using during the handshake. +func (c *TLSEcdheEcdsaWithAes128GcmSha256) AuthenticationType() AuthenticationType { + return AuthenticationTypeCertificate +} + +// IsInitialized returns if the CipherSuite has keying material and can +// encrypt/decrypt packets. +func (c *TLSEcdheEcdsaWithAes128GcmSha256) IsInitialized() bool { + return c.gcm.Load() != nil +} + +func (c *TLSEcdheEcdsaWithAes128GcmSha256) init( + masterSecret, clientRandom, serverRandom []byte, + isClient bool, + prfMacLen, prfKeyLen, prfIvLen int, + hashFunc func() hash.Hash, +) error { + keys, err := prf.GenerateEncryptionKeys( + masterSecret, clientRandom, serverRandom, prfMacLen, prfKeyLen, prfIvLen, hashFunc, + ) + if err != nil { + return err + } + + var gcm *ciphersuite.GCM + if isClient { + gcm, err = ciphersuite.NewGCM(keys.ClientWriteKey, keys.ClientWriteIV, keys.ServerWriteKey, keys.ServerWriteIV) + } else { + gcm, err = ciphersuite.NewGCM(keys.ServerWriteKey, keys.ServerWriteIV, keys.ClientWriteKey, keys.ClientWriteIV) + } + c.gcm.Store(gcm) + + return err +} + +// Init initializes the internal Cipher with keying material. +func (c *TLSEcdheEcdsaWithAes128GcmSha256) Init(masterSecret, clientRandom, serverRandom []byte, isClient bool) error { + const ( + prfMacLen = 0 + prfKeyLen = 16 + prfIvLen = 4 + ) + + return c.init(masterSecret, clientRandom, serverRandom, isClient, prfMacLen, prfKeyLen, prfIvLen, c.HashFunc()) +} + +// Encrypt encrypts a single TLS RecordLayer. +func (c *TLSEcdheEcdsaWithAes128GcmSha256) Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error) { + cipherSuite, ok := c.gcm.Load().(*ciphersuite.GCM) + if !ok { + return nil, fmt.Errorf("%w, unable to encrypt", errCipherSuiteNotInit) + } + + return cipherSuite.Encrypt(pkt, raw) +} + +// Decrypt decrypts a single TLS RecordLayer. +func (c *TLSEcdheEcdsaWithAes128GcmSha256) Decrypt(h recordlayer.Header, raw []byte) ([]byte, error) { + cipherSuite, ok := c.gcm.Load().(*ciphersuite.GCM) + if !ok { + return nil, fmt.Errorf("%w, unable to decrypt", errCipherSuiteNotInit) + } + + return cipherSuite.Decrypt(h, raw) +} diff --git a/vendor/github.com/pion/dtls/v3/internal/ciphersuite/tls_ecdhe_ecdsa_with_aes_256_cbc_sha.go b/vendor/github.com/pion/dtls/v3/internal/ciphersuite/tls_ecdhe_ecdsa_with_aes_256_cbc_sha.go new file mode 100644 index 0000000..0dec9af --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/internal/ciphersuite/tls_ecdhe_ecdsa_with_aes_256_cbc_sha.go @@ -0,0 +1,116 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package ciphersuite + +import ( + "crypto/sha1" //nolint: gosec,gci + "crypto/sha256" + "fmt" + "hash" + "sync/atomic" + + "github.com/pion/dtls/v3/pkg/crypto/ciphersuite" + "github.com/pion/dtls/v3/pkg/crypto/clientcertificate" + "github.com/pion/dtls/v3/pkg/crypto/prf" + "github.com/pion/dtls/v3/pkg/protocol/recordlayer" +) + +// TLSEcdheEcdsaWithAes256CbcSha represents a TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA CipherSuite. +type TLSEcdheEcdsaWithAes256CbcSha struct { + cbc atomic.Value // *cryptoCBC +} + +// CertificateType returns what type of certficate this CipherSuite exchanges. +func (c *TLSEcdheEcdsaWithAes256CbcSha) CertificateType() clientcertificate.Type { + return clientcertificate.ECDSASign +} + +// KeyExchangeAlgorithm controls what key exchange algorithm is using during the handshake. +func (c *TLSEcdheEcdsaWithAes256CbcSha) KeyExchangeAlgorithm() KeyExchangeAlgorithm { + return KeyExchangeAlgorithmEcdhe +} + +// ECC uses Elliptic Curve Cryptography. +func (c *TLSEcdheEcdsaWithAes256CbcSha) ECC() bool { + return true +} + +// ID returns the ID of the CipherSuite. +func (c *TLSEcdheEcdsaWithAes256CbcSha) ID() ID { + return TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA +} + +func (c *TLSEcdheEcdsaWithAes256CbcSha) String() string { + return "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA" +} + +// HashFunc returns the hashing func for this CipherSuite. +func (c *TLSEcdheEcdsaWithAes256CbcSha) HashFunc() func() hash.Hash { + return sha256.New +} + +// AuthenticationType controls what authentication method is using during the handshake. +func (c *TLSEcdheEcdsaWithAes256CbcSha) AuthenticationType() AuthenticationType { + return AuthenticationTypeCertificate +} + +// IsInitialized returns if the CipherSuite has keying material and can +// encrypt/decrypt packets. +func (c *TLSEcdheEcdsaWithAes256CbcSha) IsInitialized() bool { + return c.cbc.Load() != nil +} + +// Init initializes the internal Cipher with keying material. +func (c *TLSEcdheEcdsaWithAes256CbcSha) Init(masterSecret, clientRandom, serverRandom []byte, isClient bool) error { + const ( + prfMacLen = 20 + prfKeyLen = 32 + prfIvLen = 16 + ) + + keys, err := prf.GenerateEncryptionKeys( + masterSecret, clientRandom, serverRandom, prfMacLen, prfKeyLen, prfIvLen, c.HashFunc(), + ) + if err != nil { + return err + } + + var cbc *ciphersuite.CBC + if isClient { + cbc, err = ciphersuite.NewCBC( + keys.ClientWriteKey, keys.ClientWriteIV, keys.ClientMACKey, + keys.ServerWriteKey, keys.ServerWriteIV, keys.ServerMACKey, + sha1.New, + ) + } else { + cbc, err = ciphersuite.NewCBC( + keys.ServerWriteKey, keys.ServerWriteIV, keys.ServerMACKey, + keys.ClientWriteKey, keys.ClientWriteIV, keys.ClientMACKey, + sha1.New, + ) + } + c.cbc.Store(cbc) + + return err +} + +// Encrypt encrypts a single TLS RecordLayer. +func (c *TLSEcdheEcdsaWithAes256CbcSha) Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error) { + cipherSuite, ok := c.cbc.Load().(*ciphersuite.CBC) + if !ok { + return nil, fmt.Errorf("%w, unable to encrypt", errCipherSuiteNotInit) + } + + return cipherSuite.Encrypt(pkt, raw) +} + +// Decrypt decrypts a single TLS RecordLayer. +func (c *TLSEcdheEcdsaWithAes256CbcSha) Decrypt(h recordlayer.Header, raw []byte) ([]byte, error) { + cipherSuite, ok := c.cbc.Load().(*ciphersuite.CBC) + if !ok { + return nil, fmt.Errorf("%w, unable to decrypt", errCipherSuiteNotInit) + } + + return cipherSuite.Decrypt(h, raw) +} diff --git a/vendor/github.com/pion/dtls/v3/internal/ciphersuite/tls_ecdhe_ecdsa_with_aes_256_gcm_sha384.go b/vendor/github.com/pion/dtls/v3/internal/ciphersuite/tls_ecdhe_ecdsa_with_aes_256_gcm_sha384.go new file mode 100644 index 0000000..6b29120 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/internal/ciphersuite/tls_ecdhe_ecdsa_with_aes_256_gcm_sha384.go @@ -0,0 +1,39 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package ciphersuite + +import ( + "crypto/sha512" + "hash" +) + +// TLSEcdheEcdsaWithAes256GcmSha384 represents a TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 CipherSuite. +type TLSEcdheEcdsaWithAes256GcmSha384 struct { + TLSEcdheEcdsaWithAes128GcmSha256 +} + +// ID returns the ID of the CipherSuite. +func (c *TLSEcdheEcdsaWithAes256GcmSha384) ID() ID { + return TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 +} + +func (c *TLSEcdheEcdsaWithAes256GcmSha384) String() string { + return "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384" +} + +// HashFunc returns the hashing func for this CipherSuite. +func (c *TLSEcdheEcdsaWithAes256GcmSha384) HashFunc() func() hash.Hash { + return sha512.New384 +} + +// Init initializes the internal Cipher with keying material. +func (c *TLSEcdheEcdsaWithAes256GcmSha384) Init(masterSecret, clientRandom, serverRandom []byte, isClient bool) error { + const ( + prfMacLen = 0 + prfKeyLen = 32 + prfIvLen = 4 + ) + + return c.init(masterSecret, clientRandom, serverRandom, isClient, prfMacLen, prfKeyLen, prfIvLen, c.HashFunc()) +} diff --git a/vendor/github.com/pion/dtls/v3/internal/ciphersuite/tls_ecdhe_psk_with_aes_128_cbc_sha256.go b/vendor/github.com/pion/dtls/v3/internal/ciphersuite/tls_ecdhe_psk_with_aes_128_cbc_sha256.go new file mode 100644 index 0000000..dbf9b4b --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/internal/ciphersuite/tls_ecdhe_psk_with_aes_128_cbc_sha256.go @@ -0,0 +1,120 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package ciphersuite + +import ( + "crypto/sha256" + "fmt" + "hash" + "sync/atomic" + + "github.com/pion/dtls/v3/pkg/crypto/ciphersuite" + "github.com/pion/dtls/v3/pkg/crypto/clientcertificate" + "github.com/pion/dtls/v3/pkg/crypto/prf" + "github.com/pion/dtls/v3/pkg/protocol/recordlayer" +) + +// TLSEcdhePskWithAes128CbcSha256 implements the TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256 CipherSuite. +type TLSEcdhePskWithAes128CbcSha256 struct { + cbc atomic.Value // *cryptoCBC +} + +// NewTLSEcdhePskWithAes128CbcSha256 creates TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256 cipher. +func NewTLSEcdhePskWithAes128CbcSha256() *TLSEcdhePskWithAes128CbcSha256 { + return &TLSEcdhePskWithAes128CbcSha256{} +} + +// CertificateType returns what type of certificate this CipherSuite exchanges. +func (c *TLSEcdhePskWithAes128CbcSha256) CertificateType() clientcertificate.Type { + return clientcertificate.Type(0) +} + +// KeyExchangeAlgorithm controls what key exchange algorithm is using during the handshake. +func (c *TLSEcdhePskWithAes128CbcSha256) KeyExchangeAlgorithm() KeyExchangeAlgorithm { + return (KeyExchangeAlgorithmPsk | KeyExchangeAlgorithmEcdhe) +} + +// ECC uses Elliptic Curve Cryptography. +func (c *TLSEcdhePskWithAes128CbcSha256) ECC() bool { + return true +} + +// ID returns the ID of the CipherSuite. +func (c *TLSEcdhePskWithAes128CbcSha256) ID() ID { + return TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256 +} + +func (c *TLSEcdhePskWithAes128CbcSha256) String() string { + return "TLS-ECDHE-PSK-WITH-AES-128-CBC-SHA256" +} + +// HashFunc returns the hashing func for this CipherSuite. +func (c *TLSEcdhePskWithAes128CbcSha256) HashFunc() func() hash.Hash { + return sha256.New +} + +// AuthenticationType controls what authentication method is using during the handshake. +func (c *TLSEcdhePskWithAes128CbcSha256) AuthenticationType() AuthenticationType { + return AuthenticationTypePreSharedKey +} + +// IsInitialized returns if the CipherSuite has keying material and can +// encrypt/decrypt packets. +func (c *TLSEcdhePskWithAes128CbcSha256) IsInitialized() bool { + return c.cbc.Load() != nil +} + +// Init initializes the internal Cipher with keying material. +func (c *TLSEcdhePskWithAes128CbcSha256) Init(masterSecret, clientRandom, serverRandom []byte, isClient bool) error { + const ( + prfMacLen = 32 + prfKeyLen = 16 + prfIvLen = 16 + ) + + keys, err := prf.GenerateEncryptionKeys( + masterSecret, clientRandom, serverRandom, prfMacLen, prfKeyLen, prfIvLen, c.HashFunc(), + ) + if err != nil { + return err + } + + var cbc *ciphersuite.CBC + if isClient { + cbc, err = ciphersuite.NewCBC( + keys.ClientWriteKey, keys.ClientWriteIV, keys.ClientMACKey, + keys.ServerWriteKey, keys.ServerWriteIV, keys.ServerMACKey, + c.HashFunc(), + ) + } else { + cbc, err = ciphersuite.NewCBC( + keys.ServerWriteKey, keys.ServerWriteIV, keys.ServerMACKey, + keys.ClientWriteKey, keys.ClientWriteIV, keys.ClientMACKey, + c.HashFunc(), + ) + } + c.cbc.Store(cbc) + + return err +} + +// Encrypt encrypts a single TLS RecordLayer. +func (c *TLSEcdhePskWithAes128CbcSha256) Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error) { + cipherSuite, ok := c.cbc.Load().(*ciphersuite.CBC) + if !ok { // !c.isInitialized() + return nil, fmt.Errorf("%w, unable to encrypt", errCipherSuiteNotInit) + } + + return cipherSuite.Encrypt(pkt, raw) +} + +// Decrypt decrypts a single TLS RecordLayer. +func (c *TLSEcdhePskWithAes128CbcSha256) Decrypt(h recordlayer.Header, raw []byte) ([]byte, error) { + cipherSuite, ok := c.cbc.Load().(*ciphersuite.CBC) + if !ok { // !c.isInitialized() + return nil, fmt.Errorf("%w, unable to decrypt", errCipherSuiteNotInit) + } + + return cipherSuite.Decrypt(h, raw) +} diff --git a/vendor/github.com/pion/dtls/v3/internal/ciphersuite/tls_ecdhe_rsa_with_aes_128_gcm_sha256.go b/vendor/github.com/pion/dtls/v3/internal/ciphersuite/tls_ecdhe_rsa_with_aes_128_gcm_sha256.go new file mode 100644 index 0000000..bdc052a --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/internal/ciphersuite/tls_ecdhe_rsa_with_aes_128_gcm_sha256.go @@ -0,0 +1,25 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package ciphersuite + +import "github.com/pion/dtls/v3/pkg/crypto/clientcertificate" + +// TLSEcdheRsaWithAes128GcmSha256 implements the TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 CipherSuite. +type TLSEcdheRsaWithAes128GcmSha256 struct { + TLSEcdheEcdsaWithAes128GcmSha256 +} + +// CertificateType returns what type of certificate this CipherSuite exchanges. +func (c *TLSEcdheRsaWithAes128GcmSha256) CertificateType() clientcertificate.Type { + return clientcertificate.RSASign +} + +// ID returns the ID of the CipherSuite. +func (c *TLSEcdheRsaWithAes128GcmSha256) ID() ID { + return TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 +} + +func (c *TLSEcdheRsaWithAes128GcmSha256) String() string { + return "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256" +} diff --git a/vendor/github.com/pion/dtls/v3/internal/ciphersuite/tls_ecdhe_rsa_with_aes_256_cbc_sha.go b/vendor/github.com/pion/dtls/v3/internal/ciphersuite/tls_ecdhe_rsa_with_aes_256_cbc_sha.go new file mode 100644 index 0000000..b268d1f --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/internal/ciphersuite/tls_ecdhe_rsa_with_aes_256_cbc_sha.go @@ -0,0 +1,25 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package ciphersuite + +import "github.com/pion/dtls/v3/pkg/crypto/clientcertificate" + +// TLSEcdheRsaWithAes256CbcSha implements the TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA CipherSuite. +type TLSEcdheRsaWithAes256CbcSha struct { + TLSEcdheEcdsaWithAes256CbcSha +} + +// CertificateType returns what type of certificate this CipherSuite exchanges. +func (c *TLSEcdheRsaWithAes256CbcSha) CertificateType() clientcertificate.Type { + return clientcertificate.RSASign +} + +// ID returns the ID of the CipherSuite. +func (c *TLSEcdheRsaWithAes256CbcSha) ID() ID { + return TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA +} + +func (c *TLSEcdheRsaWithAes256CbcSha) String() string { + return "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA" +} diff --git a/vendor/github.com/pion/dtls/v3/internal/ciphersuite/tls_ecdhe_rsa_with_aes_256_gcm_sha384.go b/vendor/github.com/pion/dtls/v3/internal/ciphersuite/tls_ecdhe_rsa_with_aes_256_gcm_sha384.go new file mode 100644 index 0000000..36bb507 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/internal/ciphersuite/tls_ecdhe_rsa_with_aes_256_gcm_sha384.go @@ -0,0 +1,25 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package ciphersuite + +import "github.com/pion/dtls/v3/pkg/crypto/clientcertificate" + +// TLSEcdheRsaWithAes256GcmSha384 implements the TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 CipherSuite. +type TLSEcdheRsaWithAes256GcmSha384 struct { + TLSEcdheEcdsaWithAes256GcmSha384 +} + +// CertificateType returns what type of certificate this CipherSuite exchanges. +func (c *TLSEcdheRsaWithAes256GcmSha384) CertificateType() clientcertificate.Type { + return clientcertificate.RSASign +} + +// ID returns the ID of the CipherSuite. +func (c *TLSEcdheRsaWithAes256GcmSha384) ID() ID { + return TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 +} + +func (c *TLSEcdheRsaWithAes256GcmSha384) String() string { + return "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384" +} diff --git a/vendor/github.com/pion/dtls/v3/internal/ciphersuite/tls_psk_with_aes_128_cbc_sha256.go b/vendor/github.com/pion/dtls/v3/internal/ciphersuite/tls_psk_with_aes_128_cbc_sha256.go new file mode 100644 index 0000000..16a7625 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/internal/ciphersuite/tls_psk_with_aes_128_cbc_sha256.go @@ -0,0 +1,115 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package ciphersuite + +import ( + "crypto/sha256" + "fmt" + "hash" + "sync/atomic" + + "github.com/pion/dtls/v3/pkg/crypto/ciphersuite" + "github.com/pion/dtls/v3/pkg/crypto/clientcertificate" + "github.com/pion/dtls/v3/pkg/crypto/prf" + "github.com/pion/dtls/v3/pkg/protocol/recordlayer" +) + +// TLSPskWithAes128CbcSha256 implements the TLS_PSK_WITH_AES_128_CBC_SHA256 CipherSuite. +type TLSPskWithAes128CbcSha256 struct { + cbc atomic.Value // *cryptoCBC +} + +// CertificateType returns what type of certificate this CipherSuite exchanges. +func (c *TLSPskWithAes128CbcSha256) CertificateType() clientcertificate.Type { + return clientcertificate.Type(0) +} + +// KeyExchangeAlgorithm controls what key exchange algorithm is using during the handshake. +func (c *TLSPskWithAes128CbcSha256) KeyExchangeAlgorithm() KeyExchangeAlgorithm { + return KeyExchangeAlgorithmPsk +} + +// ECC uses Elliptic Curve Cryptography. +func (c *TLSPskWithAes128CbcSha256) ECC() bool { + return false +} + +// ID returns the ID of the CipherSuite. +func (c *TLSPskWithAes128CbcSha256) ID() ID { + return TLS_PSK_WITH_AES_128_CBC_SHA256 +} + +func (c *TLSPskWithAes128CbcSha256) String() string { + return "TLS_PSK_WITH_AES_128_CBC_SHA256" +} + +// HashFunc returns the hashing func for this CipherSuite. +func (c *TLSPskWithAes128CbcSha256) HashFunc() func() hash.Hash { + return sha256.New +} + +// AuthenticationType controls what authentication method is using during the handshake. +func (c *TLSPskWithAes128CbcSha256) AuthenticationType() AuthenticationType { + return AuthenticationTypePreSharedKey +} + +// IsInitialized returns if the CipherSuite has keying material and can +// encrypt/decrypt packets. +func (c *TLSPskWithAes128CbcSha256) IsInitialized() bool { + return c.cbc.Load() != nil +} + +// Init initializes the internal Cipher with keying material. +func (c *TLSPskWithAes128CbcSha256) Init(masterSecret, clientRandom, serverRandom []byte, isClient bool) error { + const ( + prfMacLen = 32 + prfKeyLen = 16 + prfIvLen = 16 + ) + + keys, err := prf.GenerateEncryptionKeys( + masterSecret, clientRandom, serverRandom, prfMacLen, prfKeyLen, prfIvLen, c.HashFunc(), + ) + if err != nil { + return err + } + + var cbc *ciphersuite.CBC + if isClient { + cbc, err = ciphersuite.NewCBC( + keys.ClientWriteKey, keys.ClientWriteIV, keys.ClientMACKey, + keys.ServerWriteKey, keys.ServerWriteIV, keys.ServerMACKey, + c.HashFunc(), + ) + } else { + cbc, err = ciphersuite.NewCBC( + keys.ServerWriteKey, keys.ServerWriteIV, keys.ServerMACKey, + keys.ClientWriteKey, keys.ClientWriteIV, keys.ClientMACKey, + c.HashFunc(), + ) + } + c.cbc.Store(cbc) + + return err +} + +// Encrypt encrypts a single TLS RecordLayer. +func (c *TLSPskWithAes128CbcSha256) Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error) { + cipherSuite, ok := c.cbc.Load().(*ciphersuite.CBC) + if !ok { + return nil, fmt.Errorf("%w, unable to encrypt", errCipherSuiteNotInit) + } + + return cipherSuite.Encrypt(pkt, raw) +} + +// Decrypt decrypts a single TLS RecordLayer. +func (c *TLSPskWithAes128CbcSha256) Decrypt(h recordlayer.Header, raw []byte) ([]byte, error) { + cipherSuite, ok := c.cbc.Load().(*ciphersuite.CBC) + if !ok { + return nil, fmt.Errorf("%w, unable to decrypt", errCipherSuiteNotInit) + } + + return cipherSuite.Decrypt(h, raw) +} diff --git a/vendor/github.com/pion/dtls/v3/internal/ciphersuite/tls_psk_with_aes_128_ccm.go b/vendor/github.com/pion/dtls/v3/internal/ciphersuite/tls_psk_with_aes_128_ccm.go new file mode 100644 index 0000000..a4fda1c --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/internal/ciphersuite/tls_psk_with_aes_128_ccm.go @@ -0,0 +1,21 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package ciphersuite + +import ( + "github.com/pion/dtls/v3/pkg/crypto/ciphersuite" + "github.com/pion/dtls/v3/pkg/crypto/clientcertificate" +) + +// NewTLSPskWithAes128Ccm returns the TLS_PSK_WITH_AES_128_CCM CipherSuite. +func NewTLSPskWithAes128Ccm() *Aes128Ccm { + return newAes128Ccm( + clientcertificate.Type(0), + TLS_PSK_WITH_AES_128_CCM, + true, + ciphersuite.CCMTagLength, + KeyExchangeAlgorithmPsk, + false, + ) +} diff --git a/vendor/github.com/pion/dtls/v3/internal/ciphersuite/tls_psk_with_aes_128_ccm8.go b/vendor/github.com/pion/dtls/v3/internal/ciphersuite/tls_psk_with_aes_128_ccm8.go new file mode 100644 index 0000000..66170f3 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/internal/ciphersuite/tls_psk_with_aes_128_ccm8.go @@ -0,0 +1,21 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package ciphersuite + +import ( + "github.com/pion/dtls/v3/pkg/crypto/ciphersuite" + "github.com/pion/dtls/v3/pkg/crypto/clientcertificate" +) + +// NewTLSPskWithAes128Ccm8 returns the TLS_PSK_WITH_AES_128_CCM_8 CipherSuite. +func NewTLSPskWithAes128Ccm8() *Aes128Ccm { + return newAes128Ccm( + clientcertificate.Type(0), + TLS_PSK_WITH_AES_128_CCM_8, + true, + ciphersuite.CCMTagLength8, + KeyExchangeAlgorithmPsk, + false, + ) +} diff --git a/vendor/github.com/pion/dtls/v3/internal/ciphersuite/tls_psk_with_aes_128_gcm_sha256.go b/vendor/github.com/pion/dtls/v3/internal/ciphersuite/tls_psk_with_aes_128_gcm_sha256.go new file mode 100644 index 0000000..724ccf7 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/internal/ciphersuite/tls_psk_with_aes_128_gcm_sha256.go @@ -0,0 +1,35 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package ciphersuite + +import "github.com/pion/dtls/v3/pkg/crypto/clientcertificate" + +// TLSPskWithAes128GcmSha256 implements the TLS_PSK_WITH_AES_128_GCM_SHA256 CipherSuite. +type TLSPskWithAes128GcmSha256 struct { + TLSEcdheEcdsaWithAes128GcmSha256 +} + +// CertificateType returns what type of certificate this CipherSuite exchanges. +func (c *TLSPskWithAes128GcmSha256) CertificateType() clientcertificate.Type { + return clientcertificate.Type(0) +} + +// KeyExchangeAlgorithm controls what key exchange algorithm is using during the handshake. +func (c *TLSPskWithAes128GcmSha256) KeyExchangeAlgorithm() KeyExchangeAlgorithm { + return KeyExchangeAlgorithmPsk +} + +// ID returns the ID of the CipherSuite. +func (c *TLSPskWithAes128GcmSha256) ID() ID { + return TLS_PSK_WITH_AES_128_GCM_SHA256 +} + +func (c *TLSPskWithAes128GcmSha256) String() string { + return "TLS_PSK_WITH_AES_128_GCM_SHA256" +} + +// AuthenticationType controls what authentication method is using during the handshake. +func (c *TLSPskWithAes128GcmSha256) AuthenticationType() AuthenticationType { + return AuthenticationTypePreSharedKey +} diff --git a/vendor/github.com/pion/dtls/v3/internal/ciphersuite/tls_psk_with_aes_256_ccm8.go b/vendor/github.com/pion/dtls/v3/internal/ciphersuite/tls_psk_with_aes_256_ccm8.go new file mode 100644 index 0000000..770feed --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/internal/ciphersuite/tls_psk_with_aes_256_ccm8.go @@ -0,0 +1,21 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package ciphersuite + +import ( + "github.com/pion/dtls/v3/pkg/crypto/ciphersuite" + "github.com/pion/dtls/v3/pkg/crypto/clientcertificate" +) + +// NewTLSPskWithAes256Ccm8 returns the TLS_PSK_WITH_AES_256_CCM_8 CipherSuite. +func NewTLSPskWithAes256Ccm8() *Aes256Ccm { + return newAes256Ccm( + clientcertificate.Type(0), + TLS_PSK_WITH_AES_256_CCM_8, + true, + ciphersuite.CCMTagLength8, + KeyExchangeAlgorithmPsk, + false, + ) +} diff --git a/vendor/github.com/pion/dtls/v3/internal/ciphersuite/types/authentication_type.go b/vendor/github.com/pion/dtls/v3/internal/ciphersuite/types/authentication_type.go new file mode 100644 index 0000000..681a853 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/internal/ciphersuite/types/authentication_type.go @@ -0,0 +1,14 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package types + +// AuthenticationType controls what authentication method is using during the handshake. +type AuthenticationType int + +// AuthenticationType Enums. +const ( + AuthenticationTypeCertificate AuthenticationType = iota + 1 + AuthenticationTypePreSharedKey + AuthenticationTypeAnonymous +) diff --git a/vendor/github.com/pion/dtls/v3/internal/ciphersuite/types/key_exchange_algorithm.go b/vendor/github.com/pion/dtls/v3/internal/ciphersuite/types/key_exchange_algorithm.go new file mode 100644 index 0000000..5e34aaa --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/internal/ciphersuite/types/key_exchange_algorithm.go @@ -0,0 +1,20 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +// Package types provides types for TLS Ciphers +package types // nolint:revive + +// KeyExchangeAlgorithm controls what exchange algorithm was chosen. +type KeyExchangeAlgorithm int + +// KeyExchangeAlgorithm Bitmask. +const ( + KeyExchangeAlgorithmNone KeyExchangeAlgorithm = 0 + KeyExchangeAlgorithmPsk KeyExchangeAlgorithm = iota << 1 + KeyExchangeAlgorithmEcdhe +) + +// Has check if keyExchangeAlgorithm is supported. +func (a KeyExchangeAlgorithm) Has(v KeyExchangeAlgorithm) bool { + return (a & v) == v +} diff --git a/vendor/github.com/pion/dtls/v3/internal/closer/closer.go b/vendor/github.com/pion/dtls/v3/internal/closer/closer.go new file mode 100644 index 0000000..4110948 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/internal/closer/closer.go @@ -0,0 +1,50 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +// Package closer provides signaling channel for shutdown +package closer + +import ( + "context" +) + +// Closer allows for each signaling a channel for shutdown. +type Closer struct { + ctx context.Context //nolint:containedctx + closeFunc func() +} + +// NewCloser creates a new instance of Closer. +func NewCloser() *Closer { + ctx, closeFunc := context.WithCancel(context.Background()) + + return &Closer{ + ctx: ctx, + closeFunc: closeFunc, + } +} + +// NewCloserWithParent creates a new instance of Closer with a parent context. +func NewCloserWithParent(ctx context.Context) *Closer { + ctx, closeFunc := context.WithCancel(ctx) + + return &Closer{ + ctx: ctx, + closeFunc: closeFunc, + } +} + +// Done returns a channel signaling when it is done. +func (c *Closer) Done() <-chan struct{} { + return c.ctx.Done() +} + +// Err returns an error of the context. +func (c *Closer) Err() error { + return c.ctx.Err() +} + +// Close sends a signal to trigger the ctx done channel. +func (c *Closer) Close() { + c.closeFunc() +} diff --git a/vendor/github.com/pion/dtls/v3/internal/net/buffer.go b/vendor/github.com/pion/dtls/v3/internal/net/buffer.go new file mode 100644 index 0000000..6ac67c9 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/internal/net/buffer.go @@ -0,0 +1,242 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +// Package net implements DTLS specific networking primitives. +// NOTE: this package is an adaption of pion/transport/packetio that allows for +// storing a remote address alongside each packet in the buffer and implements +// relevant methods of net.PacketConn. If possible, the updates made in this +// repository will be reflected back upstream. If not, it is likely that this +// will be moved to a public package in this repository. +// +// This package was migrated from pion/transport/packetio at +// https://github.com/pion/transport/commit/6890c795c807a617c054149eee40a69d7fdfbfdb +package net + +import ( + "bytes" + "errors" + "io" + "net" + "sync" + "time" + + "github.com/pion/transport/v4/deadline" +) + +// ErrTimeout indicates that deadline was reached before operation could be +// completed. +var ErrTimeout = errors.New("buffer: i/o timeout") + +// AddrPacket is a packet payload and the associated remote address from which +// it was received. +type AddrPacket struct { + addr net.Addr + data bytes.Buffer +} + +// PacketBuffer is a circular buffer for network packets. Each slot in the +// buffer contains the remote address from which the packet was received, as +// well as the packet data. +type PacketBuffer struct { + mutex sync.Mutex + + packets []AddrPacket + write, read int + + // full indicates whether the buffer is full, which is needed to distinguish + // when the write pointer and read pointer are at the same index. + full bool + + notify chan struct{} + closed bool + + readDeadline *deadline.Deadline +} + +// NewPacketBuffer creates a new PacketBuffer. +func NewPacketBuffer() *PacketBuffer { + return &PacketBuffer{ + readDeadline: deadline.New(), + // In the narrow context in which this package is currently used, there + // will always be at least one packet written to the buffer. Therefore, + // we opt to allocate with size of 1 during construction, rather than + // waiting until that first packet is written. + packets: make([]AddrPacket, 1), + full: false, + } +} + +// WriteTo writes a single packet to the buffer. The supplied address will +// remain associated with the packet. +func (b *PacketBuffer) WriteTo(pkt []byte, addr net.Addr) (int, error) { + b.mutex.Lock() + + if b.closed { + b.mutex.Unlock() + + return 0, io.ErrClosedPipe + } + + var notify chan struct{} + if b.notify != nil { + notify = b.notify + b.notify = nil + } + + // Check to see if we are full. + if b.full { + // If so, grow AddrPacket buffer. + var newSize int + if len(b.packets) < 128 { + // Double the number of packets. + newSize = len(b.packets) * 2 + } else { + // Increase the number of packets by 25%. + newSize = 5 * len(b.packets) / 4 + } + newBuf := make([]AddrPacket, newSize) + var n int + if b.read < b.write { + n = copy(newBuf, b.packets[b.read:b.write]) + } else { + n = copy(newBuf, b.packets[b.read:]) + n += copy(newBuf[n:], b.packets[:b.write]) + } + + b.packets = newBuf + + // Update read/write pointers and mark buffer as not full. + b.read = 0 + b.write = n + b.full = false + } + + // Store the packet at the write pointer. + packet := &b.packets[b.write] + packet.data.Reset() + n, err := packet.data.Write(pkt) + if err != nil { + b.mutex.Unlock() + + return n, err + } + packet.addr = addr + + // Increment write pointer. + b.write++ + + // If the write pointer is equal to the length of the buffer, wrap around. + if len(b.packets) == b.write { + b.write = 0 + } + + // If a write resulted in making write and read pointers equivalent, then we + // are full. + if b.write == b.read { + b.full = true + } + + b.mutex.Unlock() + + if notify != nil { + close(notify) + } + + return n, nil +} + +// ReadFrom reads a single packet from the buffer, or blocks until one is +// available. +func (b *PacketBuffer) ReadFrom(packet []byte) (n int, addr net.Addr, err error) { //nolint:cyclop + select { + case <-b.readDeadline.Done(): + return 0, nil, ErrTimeout + default: + } + + for { + b.mutex.Lock() + + if b.read != b.write || b.full { + ap := b.packets[b.read] + if len(packet) < ap.data.Len() { + b.mutex.Unlock() + + return 0, nil, io.ErrShortBuffer + } + + // Copy packet data from buffer. + n, err := ap.data.Read(packet) + if err != nil { + b.mutex.Unlock() + + return n, nil, err + } + + // Advance read pointer. + b.read++ + if len(b.packets) == b.read { + b.read = 0 + } + + // If we were full before reading and have successfully read, we are + // no longer full. + if b.full { + b.full = false + } + + b.mutex.Unlock() + + return n, ap.addr, nil + } + + if b.closed { + b.mutex.Unlock() + + return 0, nil, io.EOF + } + + if b.notify == nil { + b.notify = make(chan struct{}) + } + notify := b.notify + b.mutex.Unlock() + + select { + case <-b.readDeadline.Done(): + return 0, nil, ErrTimeout + case <-notify: + } + } +} + +// Close closes the buffer, allowing unread packets to be read, but erroring on +// any new writes. +func (b *PacketBuffer) Close() (err error) { + b.mutex.Lock() + + if b.closed { + b.mutex.Unlock() + + return nil + } + + notify := b.notify + b.notify = nil + b.closed = true + + b.mutex.Unlock() + + if notify != nil { + close(notify) + } + + return nil +} + +// SetReadDeadline sets the read deadline for the buffer. +func (b *PacketBuffer) SetReadDeadline(t time.Time) error { + b.readDeadline.Set(t) + + return nil +} diff --git a/vendor/github.com/pion/dtls/v3/internal/net/udp/packet_conn.go b/vendor/github.com/pion/dtls/v3/internal/net/udp/packet_conn.go new file mode 100644 index 0000000..1824f2b --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/internal/net/udp/packet_conn.go @@ -0,0 +1,413 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +// Package udp implements DTLS specific UDP networking primitives. +// NOTE: this package is an adaption of pion/transport/udp that allows for +// routing datagrams based on identifiers other than the remote address. The +// primary use case for this functionality is routing based on DTLS connection +// IDs. In order to allow for consumers of this package to treat connections as +// generic net.PackageConn, routing and identitier establishment is based on +// custom introspecion of datagrams, rather than direct intervention by +// consumers. If possible, the updates made in this repository will be reflected +// back upstream. If not, it is likely that this will be moved to a public +// package in this repository. +// +// This package was migrated from pion/transport/udp at +// https://github.com/pion/transport/commit/6890c795c807a617c054149eee40a69d7fdfbfdb +package udp + +import ( + "context" + "errors" + "net" + "sync" + "sync/atomic" + "time" + + idtlsnet "github.com/pion/dtls/v3/internal/net" + dtlsnet "github.com/pion/dtls/v3/pkg/net" + "github.com/pion/transport/v4/deadline" +) + +const ( + receiveMTU = 8192 + defaultListenBacklog = 128 // same as Linux default +) + +// Typed errors. +var ( + ErrClosedListener = errors.New("udp: listener closed") + ErrListenQueueExceeded = errors.New("udp: listen queue exceeded") +) + +// listener augments a connection-oriented Listener over a UDP PacketConn. +type listener struct { + pConn *net.UDPConn + + accepting atomic.Value // bool + acceptCh chan *PacketConn + doneCh chan struct{} + doneOnce sync.Once + acceptFilter func([]byte) bool + datagramRouter func([]byte) (string, bool) + connIdentifier func([]byte) (string, bool) + + connLock sync.Mutex + conns map[string]*PacketConn + connWG sync.WaitGroup + + readWG sync.WaitGroup + errClose atomic.Value // error + + readDoneCh chan struct{} + errRead atomic.Value // error +} + +// Accept waits for and returns the next connection to the listener. +func (l *listener) Accept() (net.PacketConn, net.Addr, error) { + select { + case c := <-l.acceptCh: + l.connWG.Add(1) + + return c, c.raddr, nil + + case <-l.readDoneCh: + err, _ := l.errRead.Load().(error) + + return nil, nil, err + + case <-l.doneCh: + return nil, nil, ErrClosedListener + } +} + +// Close closes the listener. +// Any blocked Accept operations will be unblocked and return errors. +func (l *listener) Close() error { + var err error + l.doneOnce.Do(func() { + l.accepting.Store(false) + close(l.doneCh) + + l.connLock.Lock() + // Close unaccepted connections + lclose: + for { + select { + case c := <-l.acceptCh: + close(c.doneCh) + // If we have an alternate identifier, remove it from the connection + // map. + if id := c.id.Load(); id != nil { + delete(l.conns, id.(string)) //nolint:forcetypeassert + } + // If we haven't already removed the remote address, remove it + // from the connection map. + if c.rmraddr.Load() == nil { + delete(l.conns, c.raddr.String()) + c.rmraddr.Store(true) + } + default: + break lclose + } + } + nConns := len(l.conns) + l.connLock.Unlock() + + l.connWG.Done() + + if nConns == 0 { + // Wait if this is the final connection. + l.readWG.Wait() + if errClose, ok := l.errClose.Load().(error); ok { + err = errClose + } + } else { + err = nil + } + }) + + return err +} + +// Addr returns the listener's network address. +func (l *listener) Addr() net.Addr { + return l.pConn.LocalAddr() +} + +// ListenConfig stores options for listening to an address. +type ListenConfig struct { + // Backlog defines the maximum length of the queue of pending + // connections. It is equivalent of the backlog argument of + // POSIX listen function. + // If a connection request arrives when the queue is full, + // the request will be silently discarded, unlike TCP. + // Set zero to use default value 128 which is same as Linux default. + Backlog int + + // AcceptFilter determines whether the new conn should be made for + // the incoming packet. If not set, any packet creates new conn. + AcceptFilter func([]byte) bool + + // DatagramRouter routes an incoming datagram to a connection by extracting + // an identifier from the its paylod + DatagramRouter func([]byte) (string, bool) + + // ConnectionIdentifier extracts an identifier from an outgoing packet. If + // the identifier is not already associated with the connection, it will be + // added. + ConnectionIdentifier func([]byte) (string, bool) +} + +// Listen creates a new listener based on the ListenConfig. +func (lc *ListenConfig) Listen(network string, laddr *net.UDPAddr) (dtlsnet.PacketListener, error) { + if lc.Backlog == 0 { + lc.Backlog = defaultListenBacklog + } + + conn, err := net.ListenUDP(network, laddr) + if err != nil { + return nil, err + } + + packetListener := &listener{ + pConn: conn, + acceptCh: make(chan *PacketConn, lc.Backlog), + conns: make(map[string]*PacketConn), + doneCh: make(chan struct{}), + acceptFilter: lc.AcceptFilter, + datagramRouter: lc.DatagramRouter, + connIdentifier: lc.ConnectionIdentifier, + readDoneCh: make(chan struct{}), + } + + packetListener.accepting.Store(true) + packetListener.connWG.Add(1) + packetListener.readWG.Add(2) // wait readLoop and Close execution routine + + go packetListener.readLoop() + go func() { + packetListener.connWG.Wait() + if err := packetListener.pConn.Close(); err != nil { + packetListener.errClose.Store(err) + } + packetListener.readWG.Done() + }() + + return packetListener, nil +} + +// Listen creates a new listener using default ListenConfig. +func Listen(network string, laddr *net.UDPAddr) (dtlsnet.PacketListener, error) { + return (&ListenConfig{}).Listen(network, laddr) +} + +// readLoop dispatches packets to the proper connection, creating a new one if +// necessary, until all connections are closed. +func (l *listener) readLoop() { + defer l.readWG.Done() + defer close(l.readDoneCh) + + buf := make([]byte, receiveMTU) + + for { + n, raddr, err := l.pConn.ReadFrom(buf) + if err != nil { + l.errRead.Store(err) + + return + } + conn, ok, err := l.getConn(raddr, buf[:n]) + if err != nil { + continue + } + if ok { + _, _ = conn.buffer.WriteTo(buf[:n], raddr) + } + } +} + +// getConn gets an existing connection or creates a new one. +func (l *listener) getConn(raddr net.Addr, buf []byte) (*PacketConn, bool, error) { //nolint:cyclop + l.connLock.Lock() + defer l.connLock.Unlock() + // If we have a custom resolver, use it. + if l.datagramRouter != nil { + if id, ok := l.datagramRouter(buf); ok { + if conn, ok := l.conns[id]; ok { + return conn, true, nil + } + } + } + + // If we don't have a custom resolver, or we were unable to find an + // associated connection, fall back to remote address. + conn, ok := l.conns[raddr.String()] + if !ok { + if isAccepting, ok := l.accepting.Load().(bool); !isAccepting || !ok { + return nil, false, ErrClosedListener + } + if l.acceptFilter != nil { + if !l.acceptFilter(buf) { + return nil, false, nil + } + } + conn = l.newPacketConn(raddr) + select { + case l.acceptCh <- conn: + l.conns[raddr.String()] = conn + default: + return nil, false, ErrListenQueueExceeded + } + } + + return conn, true, nil +} + +// PacketConn is a net.PacketConn implementation that is able to dictate its +// routing ID via an alternate identifier from its remote address. Internal +// buffering is performed for reads, and writes are passed through to the +// underlying net.PacketConn. +type PacketConn struct { + listener *listener + + raddr net.Addr + rmraddr atomic.Value // bool + id atomic.Value // string + + buffer *idtlsnet.PacketBuffer + + doneCh chan struct{} + doneOnce sync.Once + + writeDeadline *deadline.Deadline +} + +// newPacketConn constructs a new PacketConn. +func (l *listener) newPacketConn(raddr net.Addr) *PacketConn { + return &PacketConn{ + listener: l, + raddr: raddr, + buffer: idtlsnet.NewPacketBuffer(), + doneCh: make(chan struct{}), + writeDeadline: deadline.New(), + } +} + +// ReadFrom reads a single packet payload and its associated remote address from +// the underlying buffer. +func (c *PacketConn) ReadFrom(buff []byte) (int, net.Addr, error) { + return c.buffer.ReadFrom(buff) +} + +// WriteTo writes len(payload) bytes from payload to the specified address. +func (c *PacketConn) WriteTo(payload []byte, addr net.Addr) (n int, err error) { + // If we have a connection identifier, check to see if the outgoing packet + // sets it. + if c.listener.connIdentifier != nil { + id := c.id.Load() + // Only update establish identifier if we haven't already done so. + if id == nil { + candidate, ok := c.listener.connIdentifier(payload) + // If we have an identifier, add entry to connection map. + if ok { + c.listener.connLock.Lock() + c.listener.conns[candidate] = c + c.listener.connLock.Unlock() + c.id.Store(candidate) + } + } + // If we are writing to a remote address that differs from the initial, + // we have an alternate identifier established, and we haven't already + // freed the remote address, free the remote address to be used by + // another connection. + // Note: this strategy results in holding onto a remote address after it + // is potentially no longer in use by the client. However, releasing + // earlier means that we could miss some packets that should have been + // routed to this connection. Ideally, we would drop the connection + // entry for the remote address as soon as the client starts sending + // using an alternate identifier, but in practice this proves + // challenging because any client could spoof a connection identifier, + // resulting in the remote address entry being dropped prior to the + // "real" client transitioning to sending using the alternate + // identifier. + if id != nil && c.rmraddr.Load() == nil && addr.String() != c.raddr.String() { + c.listener.connLock.Lock() + delete(c.listener.conns, c.raddr.String()) + c.rmraddr.Store(true) + c.listener.connLock.Unlock() + } + } + + select { + case <-c.writeDeadline.Done(): + return 0, context.DeadlineExceeded + default: + } + + return c.listener.pConn.WriteTo(payload, addr) +} + +// Close closes the conn and releases any Read calls. +func (c *PacketConn) Close() error { + var err error + c.doneOnce.Do(func() { + c.listener.connWG.Done() + close(c.doneCh) + c.listener.connLock.Lock() + // If we have an alternate identifier, remove it from the connection + // map. + if id := c.id.Load(); id != nil { + delete(c.listener.conns, id.(string)) //nolint:forcetypeassert + } + // If we haven't already removed the remote address, remove it from the + // connection map. + if c.rmraddr.Load() == nil { + delete(c.listener.conns, c.raddr.String()) + c.rmraddr.Store(true) + } + nConns := len(c.listener.conns) + c.listener.connLock.Unlock() + + if isAccepting, ok := c.listener.accepting.Load().(bool); nConns == 0 && !isAccepting && ok { + // Wait if this is the final connection + c.listener.readWG.Wait() + if errClose, ok := c.listener.errClose.Load().(error); ok { + err = errClose + } + } else { + err = nil + } + + if errBuf := c.buffer.Close(); errBuf != nil && err == nil { + err = errBuf + } + }) + + return err +} + +// LocalAddr implements net.PacketConn.LocalAddr. +func (c *PacketConn) LocalAddr() net.Addr { + return c.listener.pConn.LocalAddr() +} + +// SetDeadline implements net.PacketConn.SetDeadline. +func (c *PacketConn) SetDeadline(t time.Time) error { + c.writeDeadline.Set(t) + + return c.SetReadDeadline(t) +} + +// SetReadDeadline implements net.PacketConn.SetReadDeadline. +func (c *PacketConn) SetReadDeadline(t time.Time) error { + return c.buffer.SetReadDeadline(t) +} + +// SetWriteDeadline implements net.PacketConn.SetWriteDeadline. +func (c *PacketConn) SetWriteDeadline(t time.Time) error { + c.writeDeadline.Set(t) + // Write deadline of underlying connection should not be changed + // since the connection can be shared. + return nil +} diff --git a/vendor/github.com/pion/dtls/v3/internal/util/util.go b/vendor/github.com/pion/dtls/v3/internal/util/util.go new file mode 100644 index 0000000..8806643 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/internal/util/util.go @@ -0,0 +1,53 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +// Package util contains small helpers used across the repo +package util // nolint:revive + +import ( + "encoding/binary" + + "golang.org/x/crypto/cryptobyte" +) + +// BigEndianUint24 returns the value of a big endian uint24. +func BigEndianUint24(raw []byte) uint32 { + if len(raw) < 3 { + return 0 + } + + rawCopy := make([]byte, 4) + copy(rawCopy[1:], raw) + + return binary.BigEndian.Uint32(rawCopy) +} + +// PutBigEndianUint24 encodes a uint24 and places into out. +func PutBigEndianUint24(out []byte, in uint32) { + tmp := make([]byte, 4) + binary.BigEndian.PutUint32(tmp, in) + copy(out, tmp[1:]) +} + +// PutBigEndianUint48 encodes a uint64 and places into out. +func PutBigEndianUint48(out []byte, in uint64) { + tmp := make([]byte, 8) + binary.BigEndian.PutUint64(tmp, in) + copy(out, tmp[2:]) +} + +// Max returns the larger value. +func Max(a, b int) int { + if a > b { + return a + } + + return b +} + +// AddUint48 appends a big-endian, 48-bit value to the byte string. +// Remove if / when https://github.com/golang/crypto/pull/265 is merged +// upstream. +func AddUint48(b *cryptobyte.Builder, v uint64) { + b.AddBytes([]byte{byte(v >> 40), byte(v >> 32), byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}) +} diff --git a/vendor/github.com/pion/dtls/v3/listener.go b/vendor/github.com/pion/dtls/v3/listener.go new file mode 100644 index 0000000..84c768e --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/listener.go @@ -0,0 +1,115 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package dtls + +import ( + "net" + + "github.com/pion/dtls/v3/internal/net/udp" + dtlsnet "github.com/pion/dtls/v3/pkg/net" + "github.com/pion/dtls/v3/pkg/protocol" + "github.com/pion/dtls/v3/pkg/protocol/recordlayer" +) + +// Listen creates a DTLS listener. +// +// Deprecated: Use ListenWithOptions instead. +func Listen(network string, laddr *net.UDPAddr, config *Config) (net.Listener, error) { + if err := validateConfig(config); err != nil { + return nil, err + } + + lc := udp.ListenConfig{ + AcceptFilter: func(packet []byte) bool { + pkts, err := recordlayer.UnpackDatagram(packet) + if err != nil || len(pkts) < 1 { + return false + } + h := &recordlayer.Header{} + if err := h.Unmarshal(pkts[0]); err != nil { + return false + } + + return h.ContentType == protocol.ContentTypeHandshake + }, + } + // If connection ID support is enabled, then they must be supported in + // routing. + if config.ConnectionIDGenerator != nil { + lc.DatagramRouter = cidDatagramRouter(len(config.ConnectionIDGenerator())) + lc.ConnectionIdentifier = cidConnIdentifier() + } + parent, err := lc.Listen(network, laddr) + if err != nil { + return nil, err + } + + return &listener{ + config: config, + parent: parent, + }, nil +} + +// ListenWithOptions creates a DTLS listener. +func ListenWithOptions(network string, laddr *net.UDPAddr, opts ...ServerOption) (net.Listener, error) { + config, err := buildServerConfig(opts...) + if err != nil { + return nil, err + } + + return Listen(network, laddr, config) +} + +// NewListener creates a DTLS listener which accepts connections from an inner Listener. +// +// Deprecated: Use NewListenerWithOptions instead. +func NewListener(inner dtlsnet.PacketListener, config *Config) (net.Listener, error) { + if err := validateConfig(config); err != nil { + return nil, err + } + + return &listener{ + config: config, + parent: inner, + }, nil +} + +// NewListenerWithOptions creates a DTLS listener which accepts connections from an inner Listener. +func NewListenerWithOptions(inner dtlsnet.PacketListener, opts ...ServerOption) (net.Listener, error) { + config, err := buildServerConfig(opts...) + if err != nil { + return nil, err + } + + return NewListener(inner, config) +} + +// listener represents a DTLS listener. +type listener struct { + config *Config + parent dtlsnet.PacketListener +} + +// Accept waits for and returns the next connection to the listener. +// You have to either close or read on all connection that are created. +func (l *listener) Accept() (net.Conn, error) { + c, raddr, err := l.parent.Accept() + if err != nil { + return nil, err + } + + return serverWithConfig(c, raddr, l.config) +} + +// Close closes the listener. +// Any blocked Accept operations will be unblocked and return errors. +// Already Accepted connections are not closed. +func (l *listener) Close() error { + return l.parent.Close() +} + +// Addr returns the listener's network address. +func (l *listener) Addr() net.Addr { + return l.parent.Addr() +} diff --git a/vendor/github.com/pion/dtls/v3/options.go b/vendor/github.com/pion/dtls/v3/options.go new file mode 100644 index 0000000..f3d076f --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/options.go @@ -0,0 +1,656 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package dtls + +import ( + "crypto/tls" + "crypto/x509" + "io" + "net" + "time" + + "github.com/pion/dtls/v3/pkg/crypto/elliptic" + "github.com/pion/dtls/v3/pkg/protocol/handshake" + "github.com/pion/logging" +) + +// ServerOption configures a DTLS server. +type ServerOption interface { + applyServer(*dtlsConfig) error +} + +// ClientOption configures a DTLS client. +type ClientOption interface { + applyClient(*dtlsConfig) error +} + +// Option is an option that can be used with both client and server. +// This is used for options that apply to both sides of a connection, +// such as in the Resume function where the side is determined at runtime. +type Option interface { + ServerOption + ClientOption +} + +// defensiveCopy copies a slice. This prevents the caller from mutating +// the config after construction. Returns empty slice if input is empty. +func defensiveCopy[T any](t ...T) []T { + return append([]T{}, t...) +} + +// dtlsConfig is the internal configuration structure. +// This will eventually replace the exported Config struct. +type dtlsConfig struct { //nolint:dupl + certificates []tls.Certificate + cipherSuites []CipherSuiteID + customCipherSuites func() []CipherSuite + signatureSchemes []tls.SignatureScheme + certificateSignatureSchemes []tls.SignatureScheme + srtpProtectionProfiles []SRTPProtectionProfile + srtpMasterKeyIdentifier []byte + clientAuth ClientAuthType + extendedMasterSecret ExtendedMasterSecretType + flightInterval time.Duration + disableRetransmitBackoff bool + psk PSKCallback + pskIdentityHint []byte + insecureSkipVerify bool + insecureHashes bool + verifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error + verifyConnection func(*State) error + rootCAs *x509.CertPool + clientCAs *x509.CertPool + serverName string + loggerFactory logging.LoggerFactory + mtu int + replayProtectionWindow int + keyLogWriter io.Writer + sessionStore SessionStore + supportedProtocols []string + ellipticCurves []elliptic.Curve + getCertificate func(*ClientHelloInfo) (*tls.Certificate, error) + getClientCertificate func(*CertificateRequestInfo) (*tls.Certificate, error) + insecureSkipVerifyHello bool + connectionIDGenerator func() []byte + paddingLengthGenerator func(uint) uint + helloRandomBytesGenerator func() [handshake.RandomBytesLength]byte + clientHelloMessageHook func(handshake.MessageClientHello) handshake.Message + serverHelloMessageHook func(handshake.MessageServerHello) handshake.Message + certificateRequestMessageHook func(handshake.MessageCertificateRequest) handshake.Message + onConnectionAttempt func(net.Addr) error +} + +// applyDefaults applies default values to the config. +func (c *dtlsConfig) applyDefaults() { + c.extendedMasterSecret = RequestExtendedMasterSecret + c.flightInterval = time.Second + c.mtu = defaultMTU + c.replayProtectionWindow = defaultReplayProtectionWindow +} + +// toConfig converts internal dtlsConfig to the exported Config struct. +// This is for backward compatibility and will be removed when Config is deprecated. +// All slice fields are copied to ensure immutability. +func (c *dtlsConfig) toConfig() *Config { + config := &Config{ + CustomCipherSuites: c.customCipherSuites, + ClientAuth: c.clientAuth, + ExtendedMasterSecret: c.extendedMasterSecret, + FlightInterval: c.flightInterval, + DisableRetransmitBackoff: c.disableRetransmitBackoff, + PSK: c.psk, + InsecureSkipVerify: c.insecureSkipVerify, + InsecureHashes: c.insecureHashes, + VerifyPeerCertificate: c.verifyPeerCertificate, + VerifyConnection: c.verifyConnection, + RootCAs: c.rootCAs, + ClientCAs: c.clientCAs, + ServerName: c.serverName, + LoggerFactory: c.loggerFactory, + MTU: c.mtu, + ReplayProtectionWindow: c.replayProtectionWindow, + KeyLogWriter: c.keyLogWriter, + SessionStore: c.sessionStore, + GetCertificate: c.getCertificate, + GetClientCertificate: c.getClientCertificate, + InsecureSkipVerifyHello: c.insecureSkipVerifyHello, + ConnectionIDGenerator: c.connectionIDGenerator, + PaddingLengthGenerator: c.paddingLengthGenerator, + HelloRandomBytesGenerator: c.helloRandomBytesGenerator, + ClientHelloMessageHook: c.clientHelloMessageHook, + ServerHelloMessageHook: c.serverHelloMessageHook, + CertificateRequestMessageHook: c.certificateRequestMessageHook, + OnConnectionAttempt: c.onConnectionAttempt, + } + + if len(c.certificates) > 0 { + config.Certificates = append([]tls.Certificate(nil), c.certificates...) + } + if len(c.cipherSuites) > 0 { + config.CipherSuites = append([]CipherSuiteID(nil), c.cipherSuites...) + } + if len(c.signatureSchemes) > 0 { + config.SignatureSchemes = append([]tls.SignatureScheme(nil), c.signatureSchemes...) + } + if len(c.certificateSignatureSchemes) > 0 { + config.CertificateSignatureSchemes = append([]tls.SignatureScheme(nil), c.certificateSignatureSchemes...) + } + if len(c.srtpProtectionProfiles) > 0 { + config.SRTPProtectionProfiles = append([]SRTPProtectionProfile(nil), c.srtpProtectionProfiles...) + } + if len(c.srtpMasterKeyIdentifier) > 0 { + config.SRTPMasterKeyIdentifier = append([]byte(nil), c.srtpMasterKeyIdentifier...) + } + if len(c.pskIdentityHint) > 0 { + config.PSKIdentityHint = append([]byte(nil), c.pskIdentityHint...) + } + if len(c.supportedProtocols) > 0 { + config.SupportedProtocols = append([]string(nil), c.supportedProtocols...) + } + if len(c.ellipticCurves) > 0 { + config.EllipticCurves = append([]elliptic.Curve(nil), c.ellipticCurves...) + } + + return config +} + +// buildConfig builds a Config from the provided options, for mixed client/server cases. +func buildConfig(opts ...Option) (*Config, error) { + cfg := &dtlsConfig{} + cfg.applyDefaults() + + for _, opt := range opts { + if err := opt.applyServer(cfg); err != nil { + return nil, err + } + } + + return cfg.toConfig(), nil +} + +// buildServerConfig builds a Config for server from the provided options. +func buildServerConfig(opts ...ServerOption) (*Config, error) { + cfg := &dtlsConfig{} + cfg.applyDefaults() + + for _, opt := range opts { + if err := opt.applyServer(cfg); err != nil { + return nil, err + } + } + + return cfg.toConfig(), nil +} + +// buildClientConfig builds a Config for client from the provided options. +func buildClientConfig(opts ...ClientOption) (*Config, error) { + cfg := &dtlsConfig{} + cfg.applyDefaults() + + for _, opt := range opts { + if err := opt.applyClient(cfg); err != nil { + return nil, err + } + } + + return cfg.toConfig(), nil +} + +// sharedOption wraps an apply function that works for both client and server. +// This eliminates code duplication for options that behave identically on both sides. +type sharedOption func(*dtlsConfig) error + +func (o sharedOption) applyServer(c *dtlsConfig) error { return o(c) } +func (o sharedOption) applyClient(c *dtlsConfig) error { return o(c) } + +// WithCertificates sets the certificate chain to present to the other side of the connection. +// For functional options, an explicitly empty slice is not allowed. +func WithCertificates(certs ...tls.Certificate) Option { + return sharedOption(func(c *dtlsConfig) error { + if len(certs) == 0 { + return errEmptyCertificates + } + c.certificates = defensiveCopy(certs...) + + return nil + }) +} + +// WithCipherSuites sets the supported cipher suites. +// For functional options, an explicitly empty slice is not allowed. +func WithCipherSuites(suites ...CipherSuiteID) Option { + return sharedOption(func(c *dtlsConfig) error { + if len(suites) == 0 { + return errEmptyCipherSuites + } + c.cipherSuites = defensiveCopy(suites...) + + return nil + }) +} + +// WithCustomCipherSuites sets the custom cipher suites provider. +// Returns an error if the provider is nil. +func WithCustomCipherSuites(fn func() []CipherSuite) Option { + return sharedOption(func(c *dtlsConfig) error { + if fn == nil { + return errNilCustomCipherSuites + } + c.customCipherSuites = fn + + return nil + }) +} + +// WithSignatureSchemes sets the signature schemes. +// For functional options, an explicitly empty slice is not allowed. +func WithSignatureSchemes(schemes ...tls.SignatureScheme) Option { + return sharedOption(func(c *dtlsConfig) error { + if len(schemes) == 0 { + return errEmptySignatureSchemes + } + c.signatureSchemes = defensiveCopy(schemes...) + + return nil + }) +} + +// WithCertificateSignatureSchemes sets the signature and hash schemes that may be used +// in digital signatures for X.509 certificates. If not set, the signature_algorithms_cert +// extension is not sent, and SignatureSchemes is used for both handshake signatures and +// certificate chain validation, as specified in RFC 8446 Section 4.2.3. +// For functional options, an explicitly empty slice is not allowed. +func WithCertificateSignatureSchemes(schemes ...tls.SignatureScheme) Option { + return sharedOption(func(c *dtlsConfig) error { + if len(schemes) == 0 { + return errEmptyCertificateSignatureSchemes + } + c.certificateSignatureSchemes = defensiveCopy(schemes...) + + return nil + }) +} + +// WithSRTPProtectionProfiles sets the SRTP protection profiles. +// For functional options, an explicitly empty slice is not allowed. +func WithSRTPProtectionProfiles(profiles ...SRTPProtectionProfile) Option { + return sharedOption(func(c *dtlsConfig) error { + if len(profiles) == 0 { + return errEmptySRTPProtectionProfiles + } + c.srtpProtectionProfiles = defensiveCopy(profiles...) + + return nil + }) +} + +// WithSRTPMasterKeyIdentifier sets the SRTP master key identifier. +func WithSRTPMasterKeyIdentifier(identifier []byte) Option { + return sharedOption(func(c *dtlsConfig) error { + c.srtpMasterKeyIdentifier = defensiveCopy(identifier...) + + return nil + }) +} + +// WithExtendedMasterSecret sets the extended master secret policy. +// Returns an error if the type is invalid. +func WithExtendedMasterSecret(ems ExtendedMasterSecretType) Option { + return sharedOption(func(c *dtlsConfig) error { + if ems < RequestExtendedMasterSecret || ems > DisableExtendedMasterSecret { + return errInvalidExtendedMasterSecretType + } + c.extendedMasterSecret = ems + + return nil + }) +} + +// WithFlightInterval sets the flight interval for handshake messages. +// Returns an error if the interval is not positive. +func WithFlightInterval(interval time.Duration) Option { + return sharedOption(func(c *dtlsConfig) error { + if interval <= 0 { + return errInvalidFlightInterval + } + c.flightInterval = interval + + return nil + }) +} + +// WithDisableRetransmitBackoff disables retransmit backoff. +func WithDisableRetransmitBackoff(disable bool) Option { + return sharedOption(func(c *dtlsConfig) error { + c.disableRetransmitBackoff = disable + + return nil + }) +} + +// WithPSK sets the pre-shared key callback. +// Returns an error if the callback is nil. +func WithPSK(callback PSKCallback) Option { + return sharedOption(func(c *dtlsConfig) error { + if callback == nil { + return errNilPSKCallback + } + c.psk = callback + + return nil + }) +} + +// WithPSKIdentityHint sets the PSK identity hint. +func WithPSKIdentityHint(hint []byte) Option { + return sharedOption(func(c *dtlsConfig) error { + c.pskIdentityHint = defensiveCopy(hint...) + + return nil + }) +} + +// WithInsecureSkipVerify skips certificate verification. +// This should only be used for testing. +func WithInsecureSkipVerify(skip bool) Option { + return sharedOption(func(c *dtlsConfig) error { + c.insecureSkipVerify = skip + + return nil + }) +} + +// WithInsecureHashes allows the use of insecure hash algorithms. +func WithInsecureHashes(allow bool) Option { + return sharedOption(func(c *dtlsConfig) error { + c.insecureHashes = allow + + return nil + }) +} + +// WithVerifyPeerCertificate sets the peer certificate verification callback. +// Returns an error if the callback is nil. +func WithVerifyPeerCertificate(fn func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error) Option { + return sharedOption(func(c *dtlsConfig) error { + if fn == nil { + return errNilVerifyPeerCertificate + } + c.verifyPeerCertificate = fn + + return nil + }) +} + +// WithVerifyConnection sets the connection verification callback. +// Returns an error if the callback is nil. +func WithVerifyConnection(fn func(*State) error) Option { + return sharedOption(func(c *dtlsConfig) error { + if fn == nil { + return errNilVerifyConnection + } + c.verifyConnection = fn + + return nil + }) +} + +// WithRootCAs sets the root certificate authorities. +func WithRootCAs(pool *x509.CertPool) Option { + return sharedOption(func(c *dtlsConfig) error { + c.rootCAs = pool + + return nil + }) +} + +// WithServerName sets the server name for certificate verification. +func WithServerName(name string) Option { + return sharedOption(func(c *dtlsConfig) error { + c.serverName = name + + return nil + }) +} + +// WithLoggerFactory sets the logger factory for creating loggers. +func WithLoggerFactory(factory logging.LoggerFactory) Option { + return sharedOption(func(c *dtlsConfig) error { + c.loggerFactory = factory + + return nil + }) +} + +// WithMTU sets the maximum transmission unit. +// Returns an error if the MTU is not positive. +func WithMTU(mtu int) Option { + return sharedOption(func(c *dtlsConfig) error { + if mtu <= 0 { + return errInvalidMTU + } + c.mtu = mtu + + return nil + }) +} + +// WithReplayProtectionWindow sets the replay protection window size. +// Returns an error if the window size is negative. +func WithReplayProtectionWindow(window int) Option { + return sharedOption(func(c *dtlsConfig) error { + if window < 0 { + return errInvalidReplayProtectionWindow + } + c.replayProtectionWindow = window + + return nil + }) +} + +// WithKeyLogWriter sets the key log writer for debugging. +// Use of KeyLogWriter compromises security and should only be used for debugging. +func WithKeyLogWriter(writer io.Writer) Option { + return sharedOption(func(c *dtlsConfig) error { + c.keyLogWriter = writer + + return nil + }) +} + +// WithSessionStore sets the session store for resumption. +func WithSessionStore(store SessionStore) Option { + return sharedOption(func(c *dtlsConfig) error { + c.sessionStore = store + + return nil + }) +} + +// WithSupportedProtocols sets the supported application protocols for ALPN. +// For functional options, an explicitly empty slice is not allowed. +func WithSupportedProtocols(protocols ...string) Option { + return sharedOption(func(c *dtlsConfig) error { + if len(protocols) == 0 { + return errEmptySupportedProtocols + } + c.supportedProtocols = defensiveCopy(protocols...) + + return nil + }) +} + +// WithEllipticCurves sets the elliptic curves. +// For functional options, an explicitly empty slice is not allowed. +func WithEllipticCurves(curves ...elliptic.Curve) Option { + return sharedOption(func(c *dtlsConfig) error { + if len(curves) == 0 { + return errEmptyEllipticCurves + } + c.ellipticCurves = defensiveCopy(curves...) + + return nil + }) +} + +// WithGetClientCertificate sets the client certificate getter callback. +// Returns an error if the callback is nil. +func WithGetClientCertificate(fn func(*CertificateRequestInfo) (*tls.Certificate, error)) Option { + return sharedOption(func(c *dtlsConfig) error { + if fn == nil { + return errNilGetClientCertificate + } + c.getClientCertificate = fn + + return nil + }) +} + +// WithConnectionIDGenerator sets the connection ID generator. +// Returns an error if the generator is nil. +func WithConnectionIDGenerator(fn func() []byte) Option { + return sharedOption(func(c *dtlsConfig) error { + if fn == nil { + return errNilConnectionIDGenerator + } + c.connectionIDGenerator = fn + + return nil + }) +} + +// WithPaddingLengthGenerator sets the padding length generator. +// Returns an error if the generator is nil. +func WithPaddingLengthGenerator(fn func(uint) uint) Option { + return sharedOption(func(c *dtlsConfig) error { + if fn == nil { + return errNilPaddingLengthGenerator + } + c.paddingLengthGenerator = fn + + return nil + }) +} + +// WithHelloRandomBytesGenerator sets the hello random bytes generator. +// Returns an error if the generator is nil. +func WithHelloRandomBytesGenerator(fn func() [handshake.RandomBytesLength]byte) Option { + return sharedOption(func(c *dtlsConfig) error { + if fn == nil { + return errNilHelloRandomBytesGenerator + } + c.helloRandomBytesGenerator = fn + + return nil + }) +} + +// WithClientHelloMessageHook sets the client hello message hook. +// Returns an error if the hook is nil. +func WithClientHelloMessageHook(fn func(handshake.MessageClientHello) handshake.Message) Option { + return sharedOption(func(c *dtlsConfig) error { + if fn == nil { + return errNilClientHelloMessageHook + } + c.clientHelloMessageHook = fn + + return nil + }) +} + +// serverOnlyOption wraps an apply function for server-only options. +type serverOnlyOption func(*dtlsConfig) error + +func (o serverOnlyOption) applyServer(c *dtlsConfig) error { return o(c) } + +// WithClientAuth sets the client authentication policy. +// Returns an error if the type is invalid. +// This option is only applicable to servers. +func WithClientAuth(auth ClientAuthType) ServerOption { + return serverOnlyOption(func(c *dtlsConfig) error { + if auth < NoClientCert || auth > RequireAndVerifyClientCert { + return errInvalidClientAuthType + } + c.clientAuth = auth + + return nil + }) +} + +// WithClientCAs sets the client certificate authorities. +// This option is only applicable to servers. +func WithClientCAs(pool *x509.CertPool) ServerOption { + return serverOnlyOption(func(c *dtlsConfig) error { + c.clientCAs = pool + + return nil + }) +} + +// WithGetCertificate sets the certificate getter callback. +// Returns an error if the callback is nil. +// This option is only applicable to servers. +func WithGetCertificate(fn func(*ClientHelloInfo) (*tls.Certificate, error)) ServerOption { + return serverOnlyOption(func(c *dtlsConfig) error { + if fn == nil { + return errNilGetCertificate + } + c.getCertificate = fn + + return nil + }) +} + +// WithInsecureSkipVerifyHello skips hello verify phase on the server. +// This has implication on DoS attack resistance. +// This option is only applicable to servers. +func WithInsecureSkipVerifyHello(skip bool) ServerOption { + return serverOnlyOption(func(c *dtlsConfig) error { + c.insecureSkipVerifyHello = skip + + return nil + }) +} + +// WithServerHelloMessageHook sets the server hello message hook. +// Returns an error if the hook is nil. +// This option is only applicable to servers. +func WithServerHelloMessageHook(fn func(handshake.MessageServerHello) handshake.Message) ServerOption { + return serverOnlyOption(func(c *dtlsConfig) error { + if fn == nil { + return errNilServerHelloMessageHook + } + c.serverHelloMessageHook = fn + + return nil + }) +} + +// WithCertificateRequestMessageHook sets the certificate request message hook. +// Returns an error if the hook is nil. +// This option is only applicable to servers. +func WithCertificateRequestMessageHook(fn func(handshake.MessageCertificateRequest) handshake.Message) ServerOption { + return serverOnlyOption(func(c *dtlsConfig) error { + if fn == nil { + return errNilCertificateRequestMessageHook + } + c.certificateRequestMessageHook = fn + + return nil + }) +} + +// WithOnConnectionAttempt sets the connection attempt callback. +// Returns an error if the callback is nil. +// This option is only applicable to servers. +func WithOnConnectionAttempt(fn func(net.Addr) error) ServerOption { + return serverOnlyOption(func(c *dtlsConfig) error { + if fn == nil { + return errNilOnConnectionAttempt + } + c.onConnectionAttempt = fn + + return nil + }) +} diff --git a/vendor/github.com/pion/dtls/v3/packet.go b/vendor/github.com/pion/dtls/v3/packet.go new file mode 100644 index 0000000..c458e65 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/packet.go @@ -0,0 +1,15 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package dtls + +import ( + "github.com/pion/dtls/v3/pkg/protocol/recordlayer" +) + +type packet struct { + record *recordlayer.RecordLayer + shouldEncrypt bool + shouldWrapCID bool + resetLocalSequenceNumber bool +} diff --git a/vendor/github.com/pion/dtls/v3/pkg/crypto/ccm/ccm.go b/vendor/github.com/pion/dtls/v3/pkg/crypto/ccm/ccm.go new file mode 100644 index 0000000..13a4242 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/crypto/ccm/ccm.go @@ -0,0 +1,260 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +// Package ccm implements a CCM, Counter with CBC-MAC +// as per RFC 3610. +// +// See https://tools.ietf.org/html/rfc3610 +// +// This code was lifted from https://github.com/bocajim/dtls/blob/a3300364a283fcb490d28a93d7fcfa7ba437fbbe/ccm/ccm.go +// and as such was not written by the Pions authors. Like Pions this +// code is licensed under MIT. +// +// A request for including CCM into the Go standard library +// can be found as issue #27484 on the https://github.com/golang/go/ +// repository. +package ccm + +import ( + "crypto/cipher" + "crypto/subtle" + "encoding/binary" + "errors" + "math" +) + +// ccm represents a Counter with CBC-MAC with a specific key. +type ccm struct { + b cipher.Block + M uint8 + L uint8 +} + +const ccmBlockSize = 16 + +// CCM is a block cipher in Counter with CBC-MAC mode. +// Providing authenticated encryption with associated data via the cipher.AEAD interface. +type CCM interface { + cipher.AEAD + // MaxLength returns the maxium length of plaintext in calls to Seal. + // The maximum length of ciphertext in calls to Open is MaxLength()+Overhead(). + // The maximum length is related to CCM's `L` parameter (15-noncesize) and + // is 1<<(8*L) - 1 (but also limited by the maxium size of an int). + MaxLength() int +} + +var ( + errInvalidBlockSize = errors.New("ccm: NewCCM requires 128-bit block cipher") + errInvalidTagSize = errors.New("ccm: tagsize must be 4, 6, 8, 10, 12, 14, or 16") + errInvalidNonceSize = errors.New("ccm: invalid nonce size") +) + +// NewCCM returns the given 128-bit block cipher wrapped in CCM. +// The tagsize must be an even integer between 4 and 16 inclusive +// and is used as CCM's `M` parameter. +// The noncesize must be an integer between 7 and 13 inclusive, +// 15-noncesize is used as CCM's `L` parameter. +func NewCCM(b cipher.Block, tagsize, noncesize int) (CCM, error) { + if b.BlockSize() != ccmBlockSize { + return nil, errInvalidBlockSize + } + if tagsize < 4 || tagsize > 16 || tagsize&1 != 0 { + return nil, errInvalidTagSize + } + lensize := 15 - noncesize + if lensize < 2 || lensize > 8 { + return nil, errInvalidNonceSize + } + c := &ccm{b: b, M: uint8(tagsize), L: uint8(lensize)} //nolint:gosec // G114 + + return c, nil +} + +func (c *ccm) NonceSize() int { return 15 - int(c.L) } +func (c *ccm) Overhead() int { return int(c.M) } +func (c *ccm) MaxLength() int { return maxlen(c.L, c.Overhead()) } + +func maxlen(l uint8, tagsize int) int { + mLen := (uint64(1) << (8 * l)) - 1 + if m64 := uint64(math.MaxInt64) - uint64(tagsize); l > 8 || mLen > m64 { //nolint:gosec // G114 + mLen = m64 // The maximum lentgh on a 64bit arch + } + if mLen != uint64(int(mLen)) { //nolint:gosec // G114 + return math.MaxInt32 - tagsize // We have only 32bit int's + } + + return int(mLen) //nolint:gosec // G114 +} + +// MaxNonceLength returns the maximum nonce length for a given plaintext length. +// A return value <= 0 indicates that plaintext length is too large for +// any nonce length. +func MaxNonceLength(pdatalen int) int { + const tagsize = 16 + for L := 2; L <= 8; L++ { + if maxlen(uint8(L), tagsize) >= pdatalen { //nolint:gosec // G115 + return 15 - L + } + } + + return 0 +} + +func (c *ccm) cbcRound(mac, data []byte) { + for i := 0; i < ccmBlockSize; i++ { + mac[i] ^= data[i] + } + c.b.Encrypt(mac, mac) +} + +func (c *ccm) cbcData(mac, data []byte) { + for len(data) >= ccmBlockSize { + c.cbcRound(mac, data[:ccmBlockSize]) + data = data[ccmBlockSize:] + } + if len(data) > 0 { + var block [ccmBlockSize]byte + copy(block[:], data) + c.cbcRound(mac, block[:]) + } +} + +var errPlaintextTooLong = errors.New("ccm: plaintext too large") + +func (c *ccm) tag(nonce, plaintext, adata []byte) ([]byte, error) { + var mac [ccmBlockSize]byte + + if len(adata) > 0 { + mac[0] |= 1 << 6 + } + mac[0] |= (c.M - 2) << 2 + mac[0] |= c.L - 1 + if len(nonce) != c.NonceSize() { + return nil, errInvalidNonceSize + } + if len(plaintext) > c.MaxLength() { + return nil, errPlaintextTooLong + } + binary.BigEndian.PutUint64(mac[ccmBlockSize-8:], uint64(len(plaintext))) + copy(mac[1:ccmBlockSize-c.L], nonce) + c.b.Encrypt(mac[:], mac[:]) + + var block [ccmBlockSize]byte + if adataLength := uint64(len(adata)); adataLength > 0 { //nolint:nestif + // First adata block includes adata length + i := 2 + if adataLength <= 0xfeff { + binary.BigEndian.PutUint16(block[:i], uint16(adataLength)) + } else { + binary.BigEndian.PutUint16(block[0:2], 0xfeff) + if adataLength < uint64(1<<32) { + i = 2 + 4 + binary.BigEndian.PutUint32(block[2:i], uint32(adataLength)) //nolint:gosec // G115 + } else { + i = 2 + 8 + binary.BigEndian.PutUint64(block[2:i], adataLength) + } + } + i = copy(block[i:], adata) + c.cbcRound(mac[:], block[:]) + c.cbcData(mac[:], adata[i:]) + } + + if len(plaintext) > 0 { + c.cbcData(mac[:], plaintext) + } + + return mac[:c.M], nil +} + +// sliceForAppend takes a slice and a requested number of bytes. It returns a +// slice with the contents of the given slice followed by that many bytes and a +// second slice that aliases into it and contains only the extra bytes. If the +// original slice has sufficient capacity then no allocation is performed. +// From crypto/cipher/gcm.go +// . +func sliceForAppend(in []byte, n int) (head, tail []byte) { + if total := len(in) + n; cap(in) >= total { + head = in[:total] + } else { + head = make([]byte, total) + copy(head, in) + } + tail = head[len(in):] + + return +} + +// Seal encrypts and authenticates plaintext, authenticates the +// additional data and appends the result to dst, returning the updated +// slice. The nonce must be NonceSize() bytes long and unique for all +// time, for a given key. +// The plaintext must be no longer than MaxLength() bytes long. +// +// The plaintext and dst may alias exactly or not at all. +func (c *ccm) Seal(dst, nonce, plaintext, adata []byte) []byte { + tag, err := c.tag(nonce, plaintext, adata) + if err != nil { + // The cipher.AEAD interface doesn't allow for an error return. + panic(err) // nolint + } + + var iv, s0 [ccmBlockSize]byte + iv[0] = c.L - 1 + copy(iv[1:ccmBlockSize-c.L], nonce) + c.b.Encrypt(s0[:], iv[:]) + for i := 0; i < int(c.M); i++ { + tag[i] ^= s0[i] + } + iv[len(iv)-1] |= 1 + stream := cipher.NewCTR(c.b, iv[:]) + ret, out := sliceForAppend(dst, len(plaintext)+int(c.M)) + stream.XORKeyStream(out, plaintext) + copy(out[len(plaintext):], tag) + + return ret +} + +var ( + errOpen = errors.New("ccm: message authentication failed") + errCiphertextTooShort = errors.New("ccm: ciphertext too short") + errCiphertextTooLong = errors.New("ccm: ciphertext too long") +) + +func (c *ccm) Open(dst, nonce, ciphertext, adata []byte) ([]byte, error) { + if len(ciphertext) < int(c.M) { + return nil, errCiphertextTooShort + } + if len(ciphertext) > c.MaxLength()+c.Overhead() { + return nil, errCiphertextTooLong + } + + tag := make([]byte, int(c.M)) + copy(tag, ciphertext[len(ciphertext)-int(c.M):]) + ciphertextWithoutTag := ciphertext[:len(ciphertext)-int(c.M)] + + var iv, s0 [ccmBlockSize]byte + iv[0] = c.L - 1 + copy(iv[1:ccmBlockSize-c.L], nonce) + c.b.Encrypt(s0[:], iv[:]) + for i := 0; i < int(c.M); i++ { + tag[i] ^= s0[i] + } + iv[len(iv)-1] |= 1 + stream := cipher.NewCTR(c.b, iv[:]) + + // Cannot decrypt directly to dst since we're not supposed to + // reveal the plaintext to the caller if authentication fails. + plaintext := make([]byte, len(ciphertextWithoutTag)) + stream.XORKeyStream(plaintext, ciphertextWithoutTag) + expectedTag, err := c.tag(nonce, plaintext, adata) + if err != nil { + return nil, err + } + + if subtle.ConstantTimeCompare(tag, expectedTag) != 1 { + return nil, errOpen + } + + return append(dst, plaintext...), nil +} diff --git a/vendor/github.com/pion/dtls/v3/pkg/crypto/ciphersuite/README.md b/vendor/github.com/pion/dtls/v3/pkg/crypto/ciphersuite/README.md new file mode 100644 index 0000000..6d0c4c1 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/crypto/ciphersuite/README.md @@ -0,0 +1,128 @@ +# Ciphersuite Package + +This package provides DTLS cipher suite implementations for GCM, CCM, and CBC modes. + +## Benchmarking + +The package includes comprehensive benchmarks for all cipher operations across multiple payload sizes. + +**Note:** Benchmarks are excluded from regular test runs using build tags. You must specify `-tags=bench` to run them. + +### Running all ciphersuite benchmarks + +```bash +go test -tags=bench -bench=. -benchmem +``` + +### Running a specific benchmark + +- GCM benchmarks only: + +```bash +go test -tags=bench -bench=BenchmarkGCM -benchmem +``` + +- GCM `Encrypt` benchmark only: + +```bash +go test -tags=bench -bench=BenchmarkGCMEncrypt -benchmem +``` + +- GCM `Decrypt` benchmark only: + +```bash +go test -tags=bench -bench=BenchmarkGCMDecrypt -benchmem +``` + +- CCM benchmarks only: + +```bash +go test -tags=bench -bench=BenchmarkCCM -benchmem +``` + +- CCM `Encrypt` benchmark only: + +```bash +go test -tags=bench -bench=BenchmarkCCMEncrypt -benchmem +``` + +- CCM `Decrypt` benchmark only: + +```bash +go test -tags=bench -bench=BenchmarkCCMDecrypt -benchmem +``` + +- CBC benchmarks only: + +```bash +go test -tags=bench -bench=BenchmarkCBC -benchmem +``` + +- CBC `Encrypt` benchmark only: + +```bash +go test -tags=bench -bench=BenchmarkCBCEncrypt -benchmem +``` + +- CBC `Decrypt` benchmark only: + +```bash +go test -tags=bench -bench=BenchmarkCBCDecrypt -benchmem +``` + +- All ciphers, with 1KB payloads only + +```bash +go test -tags=bench -bench=/1KB -benchmem +``` + +- All ciphers, with 16B payloads only + +```bash +go test -tags=bench -bench=/16B -benchmem +``` + +### Benchmark Options + +Increase benchmark time for more accurate results: + +```bash +go test -tags=bench -bench=BenchmarkGCM -benchmem -benchtime=5s +``` + +Run benchmarks multiple times: + +```bash +go test -tags=bench -bench=BenchmarkGCM -benchmem -count=5 +``` + +### Understanding Results + +Example output: + +``` +BenchmarkGCMEncrypt/016B-8 5895367 202.6 ns/op 78.99 MB/s 160 B/op 5 allocs/op +``` + +- `5895367`: Number of iterations +- `202.6 ns/op`: Time per operation +- `78.99 MB/s`: Throughput +- `160 B/op`: Bytes allocated per operation +- `5 allocs/op`: Number of allocations per operation + + +## Profiling + +Generate CPU profile: + +```bash +go test -tags=bench -bench=BenchmarkGCMEncrypt -benchmem -cpuprofile=cpu.prof +go tool pprof -top cpu.prof +``` + +Generate memory profile: + +```bash +go test -tags=bench -bench=BenchmarkGCMEncrypt -benchmem -memprofile=mem.prof +go tool pprof -top -alloc_objects mem.prof +``` \ No newline at end of file diff --git a/vendor/github.com/pion/dtls/v3/pkg/crypto/ciphersuite/cbc.go b/vendor/github.com/pion/dtls/v3/pkg/crypto/ciphersuite/cbc.go new file mode 100644 index 0000000..262490b --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/crypto/ciphersuite/cbc.go @@ -0,0 +1,250 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package ciphersuite + +import ( //nolint:gci + "crypto/aes" + "crypto/cipher" + "crypto/hmac" + "crypto/rand" + "encoding/binary" + "hash" + + "github.com/pion/dtls/v3/internal/util" + "github.com/pion/dtls/v3/pkg/crypto/prf" + "github.com/pion/dtls/v3/pkg/protocol" + "github.com/pion/dtls/v3/pkg/protocol/recordlayer" + "golang.org/x/crypto/cryptobyte" +) + +// block ciphers using cipher block chaining. +type cbcMode interface { + cipher.BlockMode + SetIV([]byte) +} + +// CBC Provides an API to Encrypt/Decrypt DTLS 1.2 Packets. +type CBC struct { + writeCBC, readCBC cbcMode + writeMac, readMac []byte + h prf.HashFunc +} + +// NewCBC creates a DTLS CBC Cipher. +func NewCBC( + localKey, localWriteIV, localMac, remoteKey, remoteWriteIV, remoteMac []byte, + hashFunc prf.HashFunc, +) (*CBC, error) { + writeBlock, err := aes.NewCipher(localKey) + if err != nil { + return nil, err + } + + readBlock, err := aes.NewCipher(remoteKey) + if err != nil { + return nil, err + } + + writeCBC, ok := cipher.NewCBCEncrypter(writeBlock, localWriteIV).(cbcMode) + if !ok { + return nil, errFailedToCast + } + + readCBC, ok := cipher.NewCBCDecrypter(readBlock, remoteWriteIV).(cbcMode) + if !ok { + return nil, errFailedToCast + } + + return &CBC{ + writeCBC: writeCBC, + writeMac: localMac, + + readCBC: readCBC, + readMac: remoteMac, + h: hashFunc, + }, nil +} + +// Encrypt encrypt a DTLS RecordLayer message. +func (c *CBC) Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error) { + payload := raw[pkt.Header.Size():] + raw = raw[:pkt.Header.Size()] + blockSize := c.writeCBC.BlockSize() + + // Generate + Append MAC + h := pkt.Header + + var err error + var mac []byte + if h.ContentType == protocol.ContentTypeConnectionID { + mac, err = c.hmacCID(h.Epoch, h.SequenceNumber, h.Version, payload, c.writeMac, c.h, h.ConnectionID) + } else { + mac, err = c.hmac(h.Epoch, h.SequenceNumber, h.ContentType, h.Version, payload, c.writeMac, c.h) + } + if err != nil { + return nil, err + } + payload = append(payload, mac...) + + // Generate + Append padding + padding := make([]byte, blockSize-len(payload)%blockSize) + paddingLen := len(padding) + for i := 0; i < paddingLen; i++ { + padding[i] = byte(paddingLen - 1) + } + payload = append(payload, padding...) + + // Generate IV + iv := make([]byte, blockSize) + if _, err := rand.Read(iv); err != nil { + return nil, err + } + + // Set IV + Encrypt + Prepend IV + c.writeCBC.SetIV(iv) + c.writeCBC.CryptBlocks(payload, payload) + payload = append(iv, payload...) //nolint:makezero // todo: FIX + + // Prepend unencrypted header with encrypted payload + raw = append(raw, payload...) + + // Update recordLayer size to include IV+MAC+Padding + binary.BigEndian.PutUint16(raw[pkt.Header.Size()-2:], uint16(len(raw)-pkt.Header.Size())) //nolint:gosec //G115 + + return raw, nil +} + +// Decrypt decrypts a DTLS RecordLayer message. +func (c *CBC) Decrypt(header recordlayer.Header, in []byte) ([]byte, error) { + blockSize := c.readCBC.BlockSize() + mac := c.h() + + if err := header.Unmarshal(in); err != nil { + return nil, err + } + body := in[header.Size():] + + switch { + case header.ContentType == protocol.ContentTypeChangeCipherSpec: + // Nothing to encrypt with ChangeCipherSpec + return in, nil + case len(body)%blockSize != 0 || len(body) < blockSize+util.Max(mac.Size()+1, blockSize): + return nil, errNotEnoughRoomForNonce + } + + // Set + remove per record IV + c.readCBC.SetIV(body[:blockSize]) + body = body[blockSize:] + + // Decrypt + c.readCBC.CryptBlocks(body, body) + + // Padding+MAC needs to be checked in constant time + // Otherwise we reveal information about the level of correctness + paddingLen, paddingGood := examinePadding(body) + if paddingGood != 255 { + return nil, errInvalidMAC + } + + macSize := mac.Size() + if len(body) < macSize { + return nil, errInvalidMAC + } + + dataEnd := len(body) - macSize - paddingLen + + expectedMAC := body[dataEnd : dataEnd+macSize] + var err error + var actualMAC []byte + if header.ContentType == protocol.ContentTypeConnectionID { + actualMAC, err = c.hmacCID( + header.Epoch, header.SequenceNumber, header.Version, body[:dataEnd], c.readMac, c.h, header.ConnectionID, + ) + } else { + actualMAC, err = c.hmac( + header.Epoch, header.SequenceNumber, header.ContentType, header.Version, body[:dataEnd], c.readMac, c.h, + ) + } + // Compute Local MAC and compare + if err != nil || !hmac.Equal(actualMAC, expectedMAC) { + return nil, errInvalidMAC + } + + return append(in[:header.Size()], body[:dataEnd]...), nil +} + +func (c *CBC) hmac( + epoch uint16, + sequenceNumber uint64, + contentType protocol.ContentType, + protocolVersion protocol.Version, + payload []byte, + key []byte, + hf func() hash.Hash, +) ([]byte, error) { + hmacHash := hmac.New(hf, key) + + msg := make([]byte, 13) + + binary.BigEndian.PutUint16(msg, epoch) + util.PutBigEndianUint48(msg[2:], sequenceNumber) + msg[8] = byte(contentType) + msg[9] = protocolVersion.Major + msg[10] = protocolVersion.Minor + binary.BigEndian.PutUint16(msg[11:], uint16(len(payload))) //nolint:gosec //G115 + + if _, err := hmacHash.Write(msg); err != nil { + return nil, err + } + if _, err := hmacHash.Write(payload); err != nil { + return nil, err + } + + return hmacHash.Sum(nil), nil +} + +// hmacCID calculates a MAC according to +// https://datatracker.ietf.org/doc/html/rfc9146#section-5.1 +func (c *CBC) hmacCID( + epoch uint16, + sequenceNumber uint64, + protocolVersion protocol.Version, + payload []byte, + key []byte, + hf func() hash.Hash, + cid []byte, +) ([]byte, error) { + // Must unmarshal inner plaintext in orde to perform MAC. + ip := &recordlayer.InnerPlaintext{} + if err := ip.Unmarshal(payload); err != nil { + return nil, err + } + + hmacHash := hmac.New(hf, key) + + var msg cryptobyte.Builder + + msg.AddUint64(seqNumPlaceholder) + msg.AddUint8(uint8(protocol.ContentTypeConnectionID)) + msg.AddUint8(uint8(len(cid))) //nolint:gosec //G115 + msg.AddUint8(uint8(protocol.ContentTypeConnectionID)) + msg.AddUint8(protocolVersion.Major) + msg.AddUint8(protocolVersion.Minor) + msg.AddUint16(epoch) + util.AddUint48(&msg, sequenceNumber) + msg.AddBytes(cid) + msg.AddUint16(uint16(len(payload))) //nolint:gosec //G115 + msg.AddBytes(ip.Content) + msg.AddUint8(uint8(ip.RealType)) + msg.AddBytes(make([]byte, ip.Zeros)) + + if _, err := hmacHash.Write(msg.BytesOrPanic()); err != nil { + return nil, err + } + if _, err := hmacHash.Write(payload); err != nil { + return nil, err + } + + return hmacHash.Sum(nil), nil +} diff --git a/vendor/github.com/pion/dtls/v3/pkg/crypto/ciphersuite/ccm.go b/vendor/github.com/pion/dtls/v3/pkg/crypto/ciphersuite/ccm.go new file mode 100644 index 0000000..723f2ce --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/crypto/ciphersuite/ccm.go @@ -0,0 +1,68 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package ciphersuite + +import ( + "crypto/aes" + + "github.com/pion/dtls/v3/pkg/crypto/ccm" + "github.com/pion/dtls/v3/pkg/protocol/recordlayer" +) + +// CCMTagLen is the length of Authentication Tag. +type CCMTagLen int + +// CCM Enums. +const ( + CCMTagLength8 CCMTagLen = 8 + CCMTagLength CCMTagLen = 16 + ccmNonceLength = 12 +) + +// CCM Provides an API to Encrypt/Decrypt DTLS 1.2 Packets. +type CCM struct { + aead *aead +} + +// NewCCM creates a DTLS GCM Cipher. +func NewCCM(tagLen CCMTagLen, localKey, localWriteIV, remoteKey, remoteWriteIV []byte) (*CCM, error) { + localBlock, err := aes.NewCipher(localKey) + if err != nil { + return nil, err + } + localCCM, err := ccm.NewCCM(localBlock, int(tagLen), ccmNonceLength) + if err != nil { + return nil, err + } + + remoteBlock, err := aes.NewCipher(remoteKey) + if err != nil { + return nil, err + } + remoteCCM, err := ccm.NewCCM(remoteBlock, int(tagLen), ccmNonceLength) + if err != nil { + return nil, err + } + + return &CCM{ + aead: newAEAD( + localCCM, + localWriteIV, + remoteCCM, + remoteWriteIV, + ccmNonceLength, + int(tagLen), + ), + }, nil +} + +// Encrypt encrypt a DTLS RecordLayer message. +func (c *CCM) Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error) { + return c.aead.encrypt(pkt, raw) +} + +// Decrypt decrypts a DTLS RecordLayer message. +func (c *CCM) Decrypt(header recordlayer.Header, in []byte) ([]byte, error) { + return c.aead.decrypt(header, in) +} diff --git a/vendor/github.com/pion/dtls/v3/pkg/crypto/ciphersuite/ciphersuite.go b/vendor/github.com/pion/dtls/v3/pkg/crypto/ciphersuite/ciphersuite.go new file mode 100644 index 0000000..a1873c1 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/crypto/ciphersuite/ciphersuite.go @@ -0,0 +1,226 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +// Package ciphersuite provides the crypto operations needed for a DTLS CipherSuite +package ciphersuite + +import ( + "crypto/cipher" + "encoding/binary" + "errors" + "fmt" + "sync" + + "github.com/pion/dtls/v3/internal/util" + "github.com/pion/dtls/v3/pkg/protocol" + "github.com/pion/dtls/v3/pkg/protocol/recordlayer" + "golang.org/x/crypto/cryptobyte" +) + +const ( + // 8 bytes of 0xff. + // https://datatracker.ietf.org/doc/html/rfc9146#name-record-payload-protection + seqNumPlaceholder = 0xffffffffffffffff +) + +var ( + //nolint:err113 + errNotEnoughRoomForNonce = &protocol.InternalError{Err: errors.New("buffer not long enough to contain nonce")} + //nolint:err113 + errDecryptPacket = &protocol.TemporaryError{Err: errors.New("failed to decrypt packet")} + //nolint:err113 + errInvalidMAC = &protocol.TemporaryError{Err: errors.New("invalid mac")} + //nolint:err113 + errFailedToCast = &protocol.FatalError{Err: errors.New("failed to cast")} +) + +// aead provides a generic API to Encrypt/Decrypt DTLS 1.2 Packets. +type aead struct { + localAEAD cipher.AEAD + remoteAEAD cipher.AEAD + localWriteIV []byte + remoteWriteIV []byte + nonceLength int + tagLength int + + // buffer pool for (fixed-size) nonces. + nonceBufferPool sync.Pool +} + +// newAEAD creates a generic DTLS AEAD-based Cipher. +func newAEAD( + localAEAD cipher.AEAD, + localWriteIV []byte, + remoteAEAD cipher.AEAD, + remoteWriteIV []byte, + nonceLength int, + tagLength int, +) *aead { + return &aead{ + localAEAD: localAEAD, + localWriteIV: localWriteIV, + remoteAEAD: remoteAEAD, + remoteWriteIV: remoteWriteIV, + nonceLength: nonceLength, + tagLength: tagLength, + nonceBufferPool: sync.Pool{ + New: func() any { + b := make([]byte, nonceLength) + return &b // nolint:nlreturn + }, + }, + } +} + +// encrypt encrypts a DTLS RecordLayer message. +func (a *aead) encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error) { + payload := raw[pkt.Header.Size():] + raw = raw[:pkt.Header.Size()] + + // Get nonce buffer from pool + noncePtr := a.nonceBufferPool.Get().(*[]byte) // nolint:forcetypeassert + nonce := *noncePtr + + copy(nonce, a.localWriteIV[:4]) + + // https://www.rfc-editor.org/rfc/rfc9325#name-nonce-reuse-in-tls-12 + seq64 := (uint64(pkt.Header.Epoch) << 48) | (pkt.Header.SequenceNumber & 0x0000ffffffffffff) + binary.BigEndian.PutUint64(nonce[4:], seq64) + + var additionalData []byte + if pkt.Header.ContentType == protocol.ContentTypeConnectionID { + additionalData = generateAEADAdditionalDataCID(&pkt.Header, len(payload)) + } else { + additionalData = generateAEADAdditionalData(&pkt.Header, len(payload)) + } + finalSize := len(raw) + 8 + len(payload) + a.tagLength + r := make([]byte, finalSize) + copy(r, raw) + copy(r[len(raw):], nonce[4:]) + + a.localAEAD.Seal(r[len(raw)+8:len(raw)+8], nonce, payload, additionalData) + + // Update recordLayer size to include explicit nonce + binary.BigEndian.PutUint16(r[pkt.Header.Size()-2:], uint16(len(r)-pkt.Header.Size())) //nolint:gosec //G115 + + // Return nonce buffer to pool + a.nonceBufferPool.Put(noncePtr) + + return r, nil +} + +// decrypt decrypts a DTLS RecordLayer message. +func (a *aead) decrypt(header recordlayer.Header, in []byte) ([]byte, error) { + err := header.Unmarshal(in) + switch { + case err != nil: + return nil, err + case header.ContentType == protocol.ContentTypeChangeCipherSpec: + // Nothing to encrypt with ChangeCipherSpec + return in, nil + case len(in) <= (8 + header.Size()): + return nil, errNotEnoughRoomForNonce + } + + // Get nonce buffer from pool + noncePtr := a.nonceBufferPool.Get().(*[]byte) // nolint:forcetypeassert + nonce := *noncePtr + + copy(nonce[:4], a.remoteWriteIV[:4]) + copy(nonce[4:], in[header.Size():header.Size()+8]) + out := in[header.Size()+8:] + + var additionalData []byte + if header.ContentType == protocol.ContentTypeConnectionID { + additionalData = generateAEADAdditionalDataCID(&header, len(out)-a.tagLength) + } else { + additionalData = generateAEADAdditionalData(&header, len(out)-a.tagLength) + } + out, err = a.remoteAEAD.Open(out[:0], nonce, out, additionalData) + if err != nil { + // Return nonce buffer to pool + a.nonceBufferPool.Put(noncePtr) + + return nil, fmt.Errorf("%w: %v", errDecryptPacket, err) //nolint:errorlint + } + + // Return nonce buffer to pool + a.nonceBufferPool.Put(noncePtr) + + return append(in[:header.Size()], out...), nil +} + +func generateAEADAdditionalData(h *recordlayer.Header, payloadLen int) []byte { + var additionalData [13]byte + + // SequenceNumber MUST be set first + // we only want uint48, clobbering an extra 2 (using uint64, Golang doesn't have uint48) + binary.BigEndian.PutUint64(additionalData[:], h.SequenceNumber) + binary.BigEndian.PutUint16(additionalData[:], h.Epoch) + additionalData[8] = byte(h.ContentType) + additionalData[9] = h.Version.Major + additionalData[10] = h.Version.Minor + //nolint:gosec //G115 + binary.BigEndian.PutUint16(additionalData[len(additionalData)-2:], uint16(payloadLen)) + + return additionalData[:] +} + +// generateAEADAdditionalDataCID generates additional data for AEAD ciphers +// according to https://datatracker.ietf.org/doc/html/rfc9146#name-aead-ciphers +func generateAEADAdditionalDataCID(h *recordlayer.Header, payloadLen int) []byte { + var builder cryptobyte.Builder + + builder.AddUint64(seqNumPlaceholder) + builder.AddUint8(uint8(protocol.ContentTypeConnectionID)) + builder.AddUint8(uint8(len(h.ConnectionID))) //nolint:gosec //G115 + builder.AddUint8(uint8(protocol.ContentTypeConnectionID)) + builder.AddUint8(h.Version.Major) + builder.AddUint8(h.Version.Minor) + builder.AddUint16(h.Epoch) + util.AddUint48(&builder, h.SequenceNumber) + builder.AddBytes(h.ConnectionID) + builder.AddUint16(uint16(payloadLen)) //nolint:gosec //G115 + + return builder.BytesOrPanic() +} + +// examinePadding returns, in constant time, the length of the padding to remove +// from the end of payload. It also returns a byte which is equal to 255 if the +// padding was valid and 0 otherwise. See RFC 2246, Section 6.2.3.2. +// +// https://github.com/golang/go/blob/039c2081d1178f90a8fa2f4e6958693129f8de33/src/crypto/tls/conn.go#L245 +func examinePadding(payload []byte) (toRemove int, good byte) { + if len(payload) < 1 { + return 0, 0 + } + + paddingLen := payload[len(payload)-1] + t := uint(len(payload)-1) - uint(paddingLen) //nolint:gosec //G115 + // if len(payload) >= (paddingLen - 1) then the MSB of t is zero + good = byte(int32(^t) >> 31) //nolint:gosec //G115 + + // The maximum possible padding length plus the actual length field + toCheck := min( + // The length of the padded data is public, so we can use an if here + 256, len(payload)) + + for i := 0; i < toCheck; i++ { + t := uint(paddingLen) - uint(i) //nolint:gosec //G115 + // if i <= paddingLen then the MSB of t is zero + mask := byte(int32(^t) >> 31) //nolint:gosec //G115 + b := payload[len(payload)-1-i] + good &^= mask&paddingLen ^ mask&b + } + + // We AND together the bits of good and replicate the result across + // all the bits. + good &= good << 4 + good &= good << 2 + good &= good << 1 + good = uint8(int8(good) >> 7) //nolint:gosec //G115 + + toRemove = int(paddingLen) + 1 + + return toRemove, good +} diff --git a/vendor/github.com/pion/dtls/v3/pkg/crypto/ciphersuite/gcm.go b/vendor/github.com/pion/dtls/v3/pkg/crypto/ciphersuite/gcm.go new file mode 100644 index 0000000..13ce2f9 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/crypto/ciphersuite/gcm.go @@ -0,0 +1,63 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package ciphersuite + +import ( + "crypto/aes" + "crypto/cipher" + + "github.com/pion/dtls/v3/pkg/protocol/recordlayer" +) + +const ( + gcmTagLength = 16 + gcmNonceLength = 12 +) + +// GCM Provides an API to Encrypt/Decrypt DTLS 1.2 Packets. +type GCM struct { + aead *aead +} + +// NewGCM creates a DTLS GCM Cipher. +func NewGCM(localKey, localWriteIV, remoteKey, remoteWriteIV []byte) (*GCM, error) { + localBlock, err := aes.NewCipher(localKey) + if err != nil { + return nil, err + } + localGCM, err := cipher.NewGCM(localBlock) + if err != nil { + return nil, err + } + + remoteBlock, err := aes.NewCipher(remoteKey) + if err != nil { + return nil, err + } + remoteGCM, err := cipher.NewGCM(remoteBlock) + if err != nil { + return nil, err + } + + return &GCM{ + aead: newAEAD( + localGCM, + localWriteIV, + remoteGCM, + remoteWriteIV, + gcmNonceLength, + gcmTagLength, + ), + }, nil +} + +// Encrypt encrypts a DTLS RecordLayer message. +func (g *GCM) Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error) { + return g.aead.encrypt(pkt, raw) +} + +// Decrypt decrypts a DTLS RecordLayer message. +func (g *GCM) Decrypt(header recordlayer.Header, in []byte) ([]byte, error) { + return g.aead.decrypt(header, in) +} diff --git a/vendor/github.com/pion/dtls/v3/pkg/crypto/clientcertificate/client_certificate.go b/vendor/github.com/pion/dtls/v3/pkg/crypto/clientcertificate/client_certificate.go new file mode 100644 index 0000000..2b75267 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/crypto/clientcertificate/client_certificate.go @@ -0,0 +1,25 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +// Package clientcertificate provides all the support Client Certificate types +package clientcertificate + +// Type is used to communicate what +// type of certificate is being transported +// +// https://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml#tls-parameters-2 +type Type byte + +// ClientCertificateType enums. +const ( + RSASign Type = 1 + ECDSASign Type = 64 +) + +// Types returns all valid ClientCertificate Types. +func Types() map[Type]bool { + return map[Type]bool{ + RSASign: true, + ECDSASign: true, + } +} diff --git a/vendor/github.com/pion/dtls/v3/pkg/crypto/elliptic/elliptic.go b/vendor/github.com/pion/dtls/v3/pkg/crypto/elliptic/elliptic.go new file mode 100644 index 0000000..99dbadd --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/crypto/elliptic/elliptic.go @@ -0,0 +1,121 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +// Package elliptic provides elliptic curve cryptography for DTLS +package elliptic + +import ( + "crypto/ecdh" + "crypto/rand" + "errors" + "fmt" +) + +var errInvalidNamedCurve = errors.New("invalid named curve") + +// CurvePointFormat is used to represent the IANA registered curve points +// +// https://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-9 +type CurvePointFormat byte + +// CurvePointFormat enums. +const ( + CurvePointFormatUncompressed CurvePointFormat = 0 +) + +// Keypair is a Curve with a Private/Public Keypair. +type Keypair struct { + Curve Curve + PublicKey []byte + PrivateKey []byte +} + +// CurveType is used to represent the IANA registered curve types for TLS +// +// https://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml#tls-parameters-10 +type CurveType byte + +// CurveType enums. +const ( + CurveTypeNamedCurve CurveType = 0x03 +) + +// CurveTypes returns all known curves. +func CurveTypes() map[CurveType]struct{} { + return map[CurveType]struct{}{ + CurveTypeNamedCurve: {}, + } +} + +// Curve is used to represent the IANA registered curves for TLS +// +// https://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-8 +type Curve uint16 + +// Curve enums. +const ( + P256 Curve = 0x0017 + P384 Curve = 0x0018 + X25519 Curve = 0x001d + // X25519MLKEM768 + // https://pkg.go.dev/crypto/internal/fips140/mlkem + // https://datatracker.ietf.org/doc/draft-ietf-tls-hybrid-design/ + // https://datatracker.ietf.org/doc/draft-ietf-tls-ecdhe-mlkem/ +) + +func (c Curve) String() string { + switch c { + case P256: + return "P-256" + case P384: + return "P-384" + case X25519: + return "X25519" + } + + return fmt.Sprintf("%#x", uint16(c)) +} + +// Curves returns all curves we implement. +func Curves() map[Curve]bool { + return map[Curve]bool{ + X25519: true, + P256: true, + P384: true, + } +} + +// GenerateKeypair generates a keypair for the given Curve. +func GenerateKeypair(curve Curve) (*Keypair, error) { + ec, err := curve.toECDH() + if err != nil { + return nil, err + } + + sk, err := ec.GenerateKey(rand.Reader) + if err != nil { + return nil, err + } + + pk := sk.PublicKey() + + return &Keypair{ + Curve: curve, + PublicKey: pk.Bytes(), // NIST: SEC1 uncompressed (04||X||Y); X25519: 32 bytes + PrivateKey: sk.Bytes(), // Scalar suitable for ecdh.NewPrivateKey + }, nil +} + +// toECDH returns the crypto/ecdh curve for our enum. +func (c Curve) toECDH() (ecdh.Curve, error) { + switch c { + case X25519: + return ecdh.X25519(), nil + case P256: + return ecdh.P256(), nil + case P384: + return ecdh.P384(), nil + default: + return nil, errInvalidNamedCurve + } +} diff --git a/vendor/github.com/pion/dtls/v3/pkg/crypto/fingerprint/fingerprint.go b/vendor/github.com/pion/dtls/v3/pkg/crypto/fingerprint/fingerprint.go new file mode 100644 index 0000000..7025efc --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/crypto/fingerprint/fingerprint.go @@ -0,0 +1,53 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +// Package fingerprint provides a helper to create fingerprint string from certificate +package fingerprint + +import ( + "crypto" + "crypto/x509" + "errors" + "fmt" +) + +var ( + errHashUnavailable = errors.New("fingerprint: hash algorithm is not linked into the binary") + errInvalidFingerprintLength = errors.New("fingerprint: invalid fingerprint length") +) + +// Fingerprint creates a fingerprint for a certificate using the specified hash algorithm. +func Fingerprint(cert *x509.Certificate, algo crypto.Hash) (string, error) { + if !algo.Available() { + return "", errHashUnavailable + } + h := algo.New() + for i := 0; i < len(cert.Raw); { + n, _ := h.Write(cert.Raw[i:]) + // Hash.Writer is specified to be never returning an error. + // https://golang.org/pkg/hash/#Hash + i += n + } + digest := fmt.Appendf(nil, "%x", h.Sum(nil)) + + digestlen := len(digest) + if digestlen == 0 { + return "", nil + } + if digestlen%2 != 0 { + return "", errInvalidFingerprintLength + } + res := make([]byte, digestlen>>1+digestlen-1) + + pos := 0 + for i, c := range digest { + res[pos] = c + pos++ + if (i)%2 != 0 && i < digestlen-1 { + res[pos] = byte(':') + pos++ + } + } + + return string(res), nil +} diff --git a/vendor/github.com/pion/dtls/v3/pkg/crypto/fingerprint/hash.go b/vendor/github.com/pion/dtls/v3/pkg/crypto/fingerprint/hash.go new file mode 100644 index 0000000..5b07a0c --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/crypto/fingerprint/hash.go @@ -0,0 +1,43 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package fingerprint + +import ( + "crypto" + "errors" + "strings" +) + +var errInvalidHashAlgorithm = errors.New("fingerprint: invalid hash algorithm") + +func nameToHash() map[string]crypto.Hash { + return map[string]crypto.Hash{ + "md5": crypto.MD5, // [RFC3279] + "sha-1": crypto.SHA1, // [RFC3279] + "sha-224": crypto.SHA224, // [RFC4055] + "sha-256": crypto.SHA256, // [RFC4055] + "sha-384": crypto.SHA384, // [RFC4055] + "sha-512": crypto.SHA512, // [RFC4055] + } +} + +// HashFromString allows looking up a hash algorithm by it's string representation. +func HashFromString(s string) (crypto.Hash, error) { + if h, ok := nameToHash()[strings.ToLower(s)]; ok { + return h, nil + } + + return 0, errInvalidHashAlgorithm +} + +// StringFromHash allows looking up a string representation of the crypto.Hash. +func StringFromHash(hash crypto.Hash) (string, error) { + for s, h := range nameToHash() { + if h == hash { + return s, nil + } + } + + return "", errInvalidHashAlgorithm +} diff --git a/vendor/github.com/pion/dtls/v3/pkg/crypto/hash/hash.go b/vendor/github.com/pion/dtls/v3/pkg/crypto/hash/hash.go new file mode 100644 index 0000000..e3a066f --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/crypto/hash/hash.go @@ -0,0 +1,157 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +// Package hash provides TLS HashAlgorithm as defined in TLS 1.2 +package hash + +import ( //nolint:gci + "crypto" + "crypto/md5" //nolint:gosec + "crypto/sha1" //nolint:gosec + "crypto/sha256" + "crypto/sha512" +) + +// Algorithm is used to indicate the hash algorithm used +// https://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml#tls-parameters-18 +type Algorithm uint16 + +// Supported hash algorithms. +const ( + None Algorithm = 0 // Blacklisted + MD5 Algorithm = 1 // Blacklisted + SHA1 Algorithm = 2 // Blacklisted + SHA224 Algorithm = 3 + SHA256 Algorithm = 4 + SHA384 Algorithm = 5 + SHA512 Algorithm = 6 + Ed25519 Algorithm = 8 +) + +// String makes hashAlgorithm printable. +func (a Algorithm) String() string { + switch a { + case None: + return "none" + case MD5: + return "md5" // [RFC3279] + case SHA1: + return "sha-1" // [RFC3279] + case SHA224: + return "sha-224" // [RFC4055] + case SHA256: + return "sha-256" // [RFC4055] + case SHA384: + return "sha-384" // [RFC4055] + case SHA512: + return "sha-512" // [RFC4055] + case Ed25519: + return "null" + default: + return "unknown or unsupported hash algorithm" + } +} + +// Digest performs a digest on the passed value. +func (a Algorithm) Digest(b []byte) []byte { + switch a { + case None: + return nil + case MD5: + hash := md5.Sum(b) // #nosec + + return hash[:] + case SHA1: + hash := sha1.Sum(b) // #nosec + + return hash[:] + case SHA224: + hash := sha256.Sum224(b) + + return hash[:] + case SHA256: + hash := sha256.Sum256(b) + + return hash[:] + case SHA384: + hash := sha512.Sum384(b) + + return hash[:] + case SHA512: + hash := sha512.Sum512(b) + + return hash[:] + default: + return nil + } +} + +// Insecure returns if the given HashAlgorithm is considered secure in DTLS 1.2 +// . +func (a Algorithm) Insecure() bool { + switch a { + case None, MD5, SHA1: + return true + default: + return false + } +} + +// CryptoHash returns the crypto.Hash implementation for the given HashAlgorithm. +func (a Algorithm) CryptoHash() crypto.Hash { + switch a { + case None: + return crypto.Hash(0) + case MD5: + return crypto.MD5 + case SHA1: + return crypto.SHA1 + case SHA224: + return crypto.SHA224 + case SHA256: + return crypto.SHA256 + case SHA384: + return crypto.SHA384 + case SHA512: + return crypto.SHA512 + case Ed25519: + return crypto.Hash(0) + default: + return crypto.Hash(0) + } +} + +// Algorithms returns all the supported Hash Algorithms. +func Algorithms() map[Algorithm]struct{} { + return map[Algorithm]struct{}{ + None: {}, + MD5: {}, + SHA1: {}, + SHA224: {}, + SHA256: {}, + SHA384: {}, + SHA512: {}, + Ed25519: {}, + } +} + +// ExtractHashFromPSS extracts the hash algorithm from an RSA-PSS SignatureScheme value. +// This handles TLS 1.3 PSS schemes. +// Returns None if the scheme is not a recognized PSS scheme. +func ExtractHashFromPSS(pssScheme uint16) Algorithm { + // Note: We can't import signature package here due to circular dependency, + // so we use the raw values. These correspond to: + // 0x0804 = RSA_PSS_RSAE_SHA256, 0x0809 = RSA_PSS_PSS_SHA256 + // 0x0805 = RSA_PSS_RSAE_SHA384, 0x080a = RSA_PSS_PSS_SHA384 + // 0x0806 = RSA_PSS_RSAE_SHA512, 0x080b = RSA_PSS_PSS_SHA512 + switch pssScheme { + case 0x0804, 0x0809: + return SHA256 + case 0x0805, 0x080a: + return SHA384 + case 0x0806, 0x080b: + return SHA512 + default: + return None + } +} diff --git a/vendor/github.com/pion/dtls/v3/pkg/crypto/prf/prf.go b/vendor/github.com/pion/dtls/v3/pkg/crypto/prf/prf.go new file mode 100644 index 0000000..d75ec1c --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/crypto/prf/prf.go @@ -0,0 +1,264 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +// Package prf implements TLS 1.2 Pseudorandom functions +package prf + +import ( //nolint:gci + "crypto/ecdh" + "crypto/hmac" + "encoding/binary" + "errors" + "fmt" + "hash" + "math" + + "github.com/pion/dtls/v3/pkg/crypto/elliptic" + "github.com/pion/dtls/v3/pkg/protocol" +) + +const ( + masterSecretLabel = "master secret" + extendedMasterSecretLabel = "extended master secret" + keyExpansionLabel = "key expansion" + verifyDataClientLabel = "client finished" + verifyDataServerLabel = "server finished" +) + +// HashFunc allows callers to decide what hash is used in PRF. +type HashFunc func() hash.Hash + +// EncryptionKeys is all the state needed for a TLS CipherSuite. +type EncryptionKeys struct { + MasterSecret []byte + ClientMACKey []byte + ServerMACKey []byte + ClientWriteKey []byte + ServerWriteKey []byte + ClientWriteIV []byte + ServerWriteIV []byte +} + +var errInvalidNamedCurve = &protocol.FatalError{Err: errors.New("invalid named curve")} //nolint:err113 + +func (e *EncryptionKeys) String() string { + return fmt.Sprintf(`encryptionKeys: +- masterSecret: %#v +- clientMACKey: %#v +- serverMACKey: %#v +- clientWriteKey: %#v +- serverWriteKey: %#v +- clientWriteIV: %#v +- serverWriteIV: %#v +`, + e.MasterSecret, + e.ClientMACKey, + e.ServerMACKey, + e.ClientWriteKey, + e.ServerWriteKey, + e.ClientWriteIV, + e.ServerWriteIV) +} + +// PSKPreMasterSecret generates the PSK Premaster Secret +// The premaster secret is formed as follows: if the PSK is N octets +// long, concatenate a uint16 with the value N, N zero octets, a second +// uint16 with the value N, and the PSK itself. +// +// https://tools.ietf.org/html/rfc4279#section-2 +func PSKPreMasterSecret(psk []byte) []byte { + pskLen := uint16(len(psk)) //nolint:gosec // G115 + + out := append(make([]byte, 2+pskLen+2), psk...) + binary.BigEndian.PutUint16(out, pskLen) + binary.BigEndian.PutUint16(out[2+pskLen:], pskLen) + + return out +} + +// EcdhePSKPreMasterSecret implements TLS 1.2 Premaster Secret generation given a psk, a keypair and a curve +// +// https://datatracker.ietf.org/doc/html/rfc5489#section-2 +func EcdhePSKPreMasterSecret(psk, publicKey, privateKey []byte, curve elliptic.Curve) ([]byte, error) { + preMasterSecret, err := PreMasterSecret(publicKey, privateKey, curve) + if err != nil { + return nil, err + } + out := make([]byte, 2+len(preMasterSecret)+2+len(psk)) + + // write preMasterSecret length + offset := 0 + binary.BigEndian.PutUint16(out[offset:], uint16(len(preMasterSecret))) //nolint:gosec // G115 + offset += 2 + + // write preMasterSecret + copy(out[offset:], preMasterSecret) + offset += len(preMasterSecret) + + // write psk length + binary.BigEndian.PutUint16(out[offset:], uint16(len(psk))) //nolint:gosec // G115 + offset += 2 + + // write psk + copy(out[offset:], psk) + + return out, nil +} + +// PreMasterSecret implements TLS 1.2 Premaster Secret generation given a keypair and a curve. +func PreMasterSecret(publicKey, privateKey []byte, curve elliptic.Curve) ([]byte, error) { + var ec ecdh.Curve + + switch curve { + case elliptic.X25519: + ec = ecdh.X25519() + case elliptic.P256: + ec = ecdh.P256() + case elliptic.P384: + ec = ecdh.P384() + default: + return nil, errInvalidNamedCurve + } + + sk, err := ec.NewPrivateKey(privateKey) + if err != nil { + return nil, err + } + + pk, err := ec.NewPublicKey(publicKey) // NIST: SEC1 uncompressed; X25519: 32-byte u + if err != nil { + return nil, err + } + + return sk.ECDH(pk) +} + +// PHash is PRF is the SHA-256 hash function is used for all cipher suites +// defined in this TLS 1.2 document and in TLS documents published prior to this +// document when TLS 1.2 is negotiated. New cipher suites MUST explicitly +// specify a PRF and, in general, SHOULD use the TLS PRF with SHA-256 or a +// stronger standard hash function. +// +// P_hash(secret, seed) = HMAC_hash(secret, A(1) + seed) + +// HMAC_hash(secret, A(2) + seed) + +// HMAC_hash(secret, A(3) + seed) + ... +// +// A() is defined as: +// +// A(0) = seed +// A(i) = HMAC_hash(secret, A(i-1)) +// +// P_hash can be iterated as many times as necessary to produce the +// required quantity of data. For example, if P_SHA256 is being used to +// create 80 bytes of data, it will have to be iterated three times +// (through A(3)), creating 96 bytes of output data; the last 16 bytes +// of the final iteration will then be discarded, leaving 80 bytes of +// output data. +// +// https://tools.ietf.org/html/rfc4346w +func PHash(secret, seed []byte, requestedLength int, hashFunc HashFunc) ([]byte, error) { + hmacSHA256 := func(key, data []byte) ([]byte, error) { + mac := hmac.New(hashFunc, key) + if _, err := mac.Write(data); err != nil { + return nil, err + } + + return mac.Sum(nil), nil + } + + var err error + lastRound := seed + out := []byte{} + + iterations := int(math.Ceil(float64(requestedLength) / float64(hashFunc().Size()))) + for i := 0; i < iterations; i++ { + lastRound, err = hmacSHA256(secret, lastRound) + if err != nil { + return nil, err + } + withSecret, err := hmacSHA256(secret, append(lastRound, seed...)) + if err != nil { + return nil, err + } + out = append(out, withSecret...) + } + + return out[:requestedLength], nil +} + +// ExtendedMasterSecret generates a Extended MasterSecret as defined in +// https://tools.ietf.org/html/rfc7627 +func ExtendedMasterSecret(preMasterSecret, sessionHash []byte, h HashFunc) ([]byte, error) { + seed := append([]byte(extendedMasterSecretLabel), sessionHash...) + + return PHash(preMasterSecret, seed, 48, h) +} + +// MasterSecret generates a TLS 1.2 MasterSecret. +func MasterSecret(preMasterSecret, clientRandom, serverRandom []byte, h HashFunc) ([]byte, error) { + seed := append(append([]byte(masterSecretLabel), clientRandom...), serverRandom...) + + return PHash(preMasterSecret, seed, 48, h) +} + +// GenerateEncryptionKeys is the final step TLS 1.2 PRF. Given all state generated so far generates +// the final keys need for encryption. +func GenerateEncryptionKeys( + masterSecret, clientRandom, serverRandom []byte, + macLen, keyLen, ivLen int, + h HashFunc, +) (*EncryptionKeys, error) { + seed := append(append([]byte(keyExpansionLabel), serverRandom...), clientRandom...) + keyMaterial, err := PHash(masterSecret, seed, (2*macLen)+(2*keyLen)+(2*ivLen), h) + if err != nil { + return nil, err + } + + clientMACKey := keyMaterial[:macLen] + keyMaterial = keyMaterial[macLen:] + + serverMACKey := keyMaterial[:macLen] + keyMaterial = keyMaterial[macLen:] + + clientWriteKey := keyMaterial[:keyLen] + keyMaterial = keyMaterial[keyLen:] + + serverWriteKey := keyMaterial[:keyLen] + keyMaterial = keyMaterial[keyLen:] + + clientWriteIV := keyMaterial[:ivLen] + keyMaterial = keyMaterial[ivLen:] + + serverWriteIV := keyMaterial[:ivLen] + + return &EncryptionKeys{ + MasterSecret: masterSecret, + ClientMACKey: clientMACKey, + ServerMACKey: serverMACKey, + ClientWriteKey: clientWriteKey, + ServerWriteKey: serverWriteKey, + ClientWriteIV: clientWriteIV, + ServerWriteIV: serverWriteIV, + }, nil +} + +func prfVerifyData(masterSecret, handshakeBodies []byte, label string, hashFunc HashFunc) ([]byte, error) { + h := hashFunc() + if _, err := h.Write(handshakeBodies); err != nil { + return nil, err + } + + seed := append([]byte(label), h.Sum(nil)...) + + return PHash(masterSecret, seed, 12, hashFunc) +} + +// VerifyDataClient is caled on the Client Side to either verify or generate the VerifyData message. +func VerifyDataClient(masterSecret, handshakeBodies []byte, h HashFunc) ([]byte, error) { + return prfVerifyData(masterSecret, handshakeBodies, verifyDataClientLabel, h) +} + +// VerifyDataServer is caled on the Server Side to either verify or generate the VerifyData message. +func VerifyDataServer(masterSecret, handshakeBodies []byte, h HashFunc) ([]byte, error) { + return prfVerifyData(masterSecret, handshakeBodies, verifyDataServerLabel, h) +} diff --git a/vendor/github.com/pion/dtls/v3/pkg/crypto/signature/signature.go b/vendor/github.com/pion/dtls/v3/pkg/crypto/signature/signature.go new file mode 100644 index 0000000..b2f3535 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/crypto/signature/signature.go @@ -0,0 +1,83 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +// Package signature provides our implemented Signature Algorithms +package signature + +import "github.com/pion/dtls/v3/pkg/crypto/hash" + +// Algorithm as defined in TLS 1.2 +// https://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml#tls-parameters-16 +type Algorithm uint16 + +// SignatureAlgorithm enums. +const ( + Anonymous Algorithm = 0 + RSA Algorithm = 1 + ECDSA Algorithm = 3 + Ed25519 Algorithm = 7 + + // RSA-PSS (DTLS 1.3 only) - full SignatureScheme values + // https://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml#tls-signaturescheme + RSA_PSS_RSAE_SHA256 Algorithm = 0x0804 // nolint: staticcheck + RSA_PSS_RSAE_SHA384 Algorithm = 0x0805 // nolint: staticcheck + RSA_PSS_RSAE_SHA512 Algorithm = 0x0806 // nolint: revive,staticcheck + RSA_PSS_PSS_SHA256 Algorithm = 0x0809 // nolint: revive,staticcheck + RSA_PSS_PSS_SHA384 Algorithm = 0x080a // nolint: revive,staticcheck + RSA_PSS_PSS_SHA512 Algorithm = 0x080b // nolint: revive,staticcheck +) + +// Algorithms returns all implemented Signature Algorithms. +func Algorithms() map[Algorithm]struct{} { + return map[Algorithm]struct{}{ + Anonymous: {}, + RSA: {}, + ECDSA: {}, + Ed25519: {}, + RSA_PSS_RSAE_SHA256: {}, + RSA_PSS_RSAE_SHA384: {}, + RSA_PSS_RSAE_SHA512: {}, + RSA_PSS_PSS_SHA256: {}, + RSA_PSS_PSS_SHA384: {}, + RSA_PSS_PSS_SHA512: {}, + } +} + +// IsPSS returns true if the algorithm is an RSA-PSS signature scheme. +// It's tempting to check for range between 0x0804 and 0x080b, but 0x0807 is Ed25519 +// and 0x0808 is Ed448, which are NOT PSS, so we check specific values instead. +func (a Algorithm) IsPSS() bool { + return a == RSA_PSS_RSAE_SHA256 || + a == RSA_PSS_RSAE_SHA384 || + a == RSA_PSS_RSAE_SHA512 || + a == RSA_PSS_PSS_SHA256 || + a == RSA_PSS_PSS_SHA384 || + a == RSA_PSS_PSS_SHA512 +} + +// IsUnsupported returns true if the algorithm is a signature scheme that is +// not supported by pion/dtls. +func (a Algorithm) IsUnsupported() bool { + // Skip RSA_PSS_PSS schemes (0x0809-0x080b). We parse them for interoperability + // but don't negotiate them to avoid unnecessary complexity for certificates that + // don't exist in practice. This follows the pragmatic approach of Go's crypto/tls + // and BoringSSL: target real-world WebPKI use cases rather than RFC completeness. + return a == RSA_PSS_PSS_SHA256 || + a == RSA_PSS_PSS_SHA384 || + a == RSA_PSS_PSS_SHA512 +} + +// GetPSSHash returns the hash algorithm associated with an RSA-PSS signature scheme. +// Returns hash.None if the algorithm is not an RSA-PSS scheme. +func (a Algorithm) GetPSSHash() hash.Algorithm { + switch a { + case RSA_PSS_RSAE_SHA256, RSA_PSS_PSS_SHA256: + return hash.SHA256 + case RSA_PSS_RSAE_SHA384, RSA_PSS_PSS_SHA384: + return hash.SHA384 + case RSA_PSS_RSAE_SHA512, RSA_PSS_PSS_SHA512: + return hash.SHA512 + default: + return hash.None + } +} diff --git a/vendor/github.com/pion/dtls/v3/pkg/crypto/signaturehash/errors.go b/vendor/github.com/pion/dtls/v3/pkg/crypto/signaturehash/errors.go new file mode 100644 index 0000000..6e11fbe --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/crypto/signaturehash/errors.go @@ -0,0 +1,13 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package signaturehash + +import "errors" + +var ( + errNoAvailableSignatureSchemes = errors.New("connection can not be created, no SignatureScheme satisfy this Config") + errInvalidSignatureAlgorithm = errors.New("invalid signature algorithm") + errInvalidHashAlgorithm = errors.New("invalid hash algorithm") + errInvalidPrivateKey = errors.New("invalid private key type") +) diff --git a/vendor/github.com/pion/dtls/v3/pkg/crypto/signaturehash/select_13.go b/vendor/github.com/pion/dtls/v3/pkg/crypto/signaturehash/select_13.go new file mode 100644 index 0000000..2c22e15 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/crypto/signaturehash/select_13.go @@ -0,0 +1,32 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package signaturehash + +import ( + "crypto" +) + +// selectSignatureScheme13 returns most preferred and compatible scheme. +// It's compatible with all DTLS versions up to and including 1.3. +func selectSignatureScheme13(sigs []Algorithm, privateKey crypto.PrivateKey, is13 bool) (Algorithm, error) { + signer, ok := privateKey.(crypto.Signer) + if !ok { + return Algorithm{}, errInvalidPrivateKey + } + for _, ss := range sigs { + // Skip PSS schemes for DTLS 1.2 (PSS is only supported in DTLS 1.3) + if !is13 && ss.Signature.IsPSS() { + continue + } + // Skip schemes understood but not supported by pion/dtls. + if ss.Signature.IsUnsupported() { + continue + } + if ss.isCompatible(signer) { + return ss, nil + } + } + + return Algorithm{}, errNoAvailableSignatureSchemes +} diff --git a/vendor/github.com/pion/dtls/v3/pkg/crypto/signaturehash/signaturehash.go b/vendor/github.com/pion/dtls/v3/pkg/crypto/signaturehash/signaturehash.go new file mode 100644 index 0000000..3a138a3 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/crypto/signaturehash/signaturehash.go @@ -0,0 +1,187 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +// Package signaturehash provides the SignatureHashAlgorithm as defined in TLS 1.2 +package signaturehash + +import ( + "crypto" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "fmt" + + "github.com/pion/dtls/v3/pkg/crypto/hash" + "github.com/pion/dtls/v3/pkg/crypto/signature" +) + +// Algorithm is a signature/hash algorithm pairs which may be used in +// digital signatures. +// +// https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1 +type Algorithm struct { + Hash hash.Algorithm + Signature signature.Algorithm +} + +// Algorithms returns signature algorithms compatible with DTLS 1.2 / TLS 1.2. +// This excludes TLS 1.3-specific schemes like RSA-PSS to ensure compatibility +// with implementations like OpenSSL that don't recognize TLS 1.3 signature +// algorithm IDs in DTLS 1.2 handshakes. +// +// IMPORTANT: order in this slice determines priority used by SelectSignatureScheme. +// +// Order follows industry standard preference (ECDSA-first) as used by OpenSSL, +// BoringSSL, Firefox, Chrome, and other major TLS implementations. +func Algorithms() []Algorithm { + return []Algorithm{ + // ECDSA schemes (modern, efficient - industry standard preference) + {hash.SHA256, signature.ECDSA}, + {hash.SHA384, signature.ECDSA}, + {hash.SHA512, signature.ECDSA}, + + // Ed25519 + {hash.Ed25519, signature.Ed25519}, + + // RSA PKCS#1 v1.5 schemes (legacy, DTLS 1.2) + {hash.SHA256, signature.RSA}, + {hash.SHA384, signature.RSA}, + {hash.SHA512, signature.RSA}, + } +} + +// SelectSignatureScheme returns most preferred and compatible scheme for DTLS <= 1.2. +func SelectSignatureScheme(sigs []Algorithm, privateKey crypto.PrivateKey) (Algorithm, error) { + return selectSignatureScheme13(sigs, privateKey, false) +} + +// SelectSignatureScheme13 returns most preferred and compatible scheme for DTLS 1.3. +func SelectSignatureScheme13(sigs []Algorithm, privateKey crypto.PrivateKey) (Algorithm, error) { + return selectSignatureScheme13(sigs, privateKey, true) +} + +// isCompatible checks that given private key is compatible with the signature scheme. +func (a *Algorithm) isCompatible(signer crypto.Signer) bool { + switch signer.Public().(type) { + case ed25519.PublicKey: + return a.Signature == signature.Ed25519 + case *ecdsa.PublicKey: + return a.Signature == signature.ECDSA + case *rsa.PublicKey: + // RSA keys are compatible with both PKCS#1 v1.5 and PSS signatures + return a.Signature == signature.RSA || a.Signature.IsPSS() + default: + return false + } +} + +// ParseSignatureSchemes translates []tls.SignatureScheme to []signatureHashAlgorithm. +// It returns default signature scheme list if no SignatureScheme is passed. +// This function handles both TLS 1.2 byte-split encoding and TLS 1.3 PSS full uint16 schemes. +// +// For DTLS 1.2 / TLS 1.2, this returns Algorithms() which excludes TLS 1.3-specific +// schemes like RSA-PSS for compatibility with implementations like OpenSSL. +// When DTLS 1.3 is implemented, use Algorithms13() or create ParseSignatureSchemes13(). +func ParseSignatureSchemes(sigs []tls.SignatureScheme, insecureHashes bool) ([]Algorithm, error) { + if len(sigs) == 0 { + return Algorithms(), nil + } + out := []Algorithm{} + for _, ss := range sigs { + hashAlg, sigAlg, err := parseSignatureScheme(ss) + if err != nil { + return nil, err + } + + if hashAlg.Insecure() && !insecureHashes { + continue + } + + out = append(out, Algorithm{ + Hash: hashAlg, + Signature: sigAlg, + }) + } + + if len(out) == 0 { + return nil, errNoAvailableSignatureSchemes + } + + return out, nil +} + +// FromCertificate maps x509.SignatureAlgorithm to the corresponding Algorithm type. +func FromCertificate(cert *x509.Certificate) (Algorithm, error) { //nolint:cyclop + var hashAlg hash.Algorithm + var sigAlg signature.Algorithm + + switch cert.SignatureAlgorithm { + case x509.SHA256WithRSA, x509.SHA256WithRSAPSS: + hashAlg = hash.SHA256 + sigAlg = signature.RSA + case x509.SHA384WithRSA, x509.SHA384WithRSAPSS: + hashAlg = hash.SHA384 + sigAlg = signature.RSA + case x509.SHA512WithRSA, x509.SHA512WithRSAPSS: + hashAlg = hash.SHA512 + sigAlg = signature.RSA + case x509.ECDSAWithSHA256: + hashAlg = hash.SHA256 + sigAlg = signature.ECDSA + case x509.ECDSAWithSHA384: + hashAlg = hash.SHA384 + sigAlg = signature.ECDSA + case x509.ECDSAWithSHA512: + hashAlg = hash.SHA512 + sigAlg = signature.ECDSA + case x509.PureEd25519: + hashAlg = hash.None // Ed25519 doesn't use a separate hash + sigAlg = signature.Ed25519 + case x509.SHA1WithRSA: + hashAlg = hash.SHA1 + sigAlg = signature.RSA + case x509.ECDSAWithSHA1: + hashAlg = hash.SHA1 + sigAlg = signature.ECDSA + default: + return Algorithm{}, errInvalidSignatureAlgorithm + } + + return Algorithm{Hash: hashAlg, Signature: sigAlg}, nil +} + +// parseSignatureScheme translates a tls.SignatureScheme to a hash.Algorithm +// and signature.Algorithm. It returns default signature scheme list if no +// SignatureScheme is passed. This function handles both TLS 1.2 byte-split +// encoding and TLS 1.3 PSS full uint16 schemes. +func parseSignatureScheme(sigScheme tls.SignatureScheme) (hash.Algorithm, signature.Algorithm, error) { + var sigAlg signature.Algorithm + var hashAlg hash.Algorithm + + if signature.Algorithm(sigScheme).IsPSS() { + // TLS 1.3 PSS scheme - full uint16 is the signature algorithm + sigAlg = signature.Algorithm(sigScheme) + hashAlg = hash.ExtractHashFromPSS(uint16(sigScheme)) + if hashAlg == hash.None { + return 0, 0, fmt.Errorf("SignatureScheme %04x: %w", sigScheme, errInvalidHashAlgorithm) + } + } else { + // TLS 1.2 style - split into hash (high byte) and signature (low byte) + sigAlg = signature.Algorithm(sigScheme & 0xFF) + hashAlg = hash.Algorithm(sigScheme >> 8) + } + + // Validate signature algorithm + if _, ok := signature.Algorithms()[sigAlg]; !ok { + return 0, 0, fmt.Errorf("SignatureScheme %04x: %w", sigScheme, errInvalidSignatureAlgorithm) + } + + // Validate hash algorithm + if _, ok := hash.Algorithms()[hashAlg]; !ok || (ok && hashAlg == hash.None) { + return 0, 0, fmt.Errorf("SignatureScheme %04x: %w", sigScheme, errInvalidHashAlgorithm) + } + + return hashAlg, sigAlg, nil +} diff --git a/vendor/github.com/pion/dtls/v3/pkg/crypto/signaturehash/signaturehash_13.go b/vendor/github.com/pion/dtls/v3/pkg/crypto/signaturehash/signaturehash_13.go new file mode 100644 index 0000000..f71b13b --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/crypto/signaturehash/signaturehash_13.go @@ -0,0 +1,48 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package signaturehash + +import ( + "github.com/pion/dtls/v3/pkg/crypto/hash" + "github.com/pion/dtls/v3/pkg/crypto/signature" +) + +// Algorithms13 returns signature algorithms compatible with DTLS 1.3. This. +// includes DTLS 1.3-specific schemes like RSA-PSS in addition to DTLS 1.2 schemes. +// +// IMPORTANT: order in this slice determines priority used by SelectSignatureScheme13. +// +// Order follows industry standard preference (ECDSA-first) as used by OpenSSL, +// BoringSSL, Firefox, Chrome, and other major TLS 1.3 implementations. +func Algorithms13() []Algorithm { + return []Algorithm{ + // ECDSA schemes (modern, efficient - industry standard preference) + {hash.SHA256, signature.ECDSA}, + {hash.SHA384, signature.ECDSA}, + {hash.SHA512, signature.ECDSA}, + + // Ed25519 + {hash.Ed25519, signature.Ed25519}, + + // RSA-PSS RSAE schemes (TLS 1.3 / DTLS 1.3 compatible with standard RSA certs) + // Note: We only offer RSA_PSS_RSAE variants (0x0804-0x0806), not RSA_PSS_PSS + // (0x0809-0x080b). RSA-PSS certificates with OID id-RSASSA-PSS are virtually + // unused in the real world and are not allowed by the CA/Browser Forum Baseline + // Requirements for WebPKI. We avoid unnecessary complexity for certificates that + // don't exist in practice, following the pragmatic approach of Go's crypto/tls + // and BoringSSL: target real-world WebPKI use cases rather than RFC completeness. + // RSA_PSS_PSS schemes are parsed for wire-format compatibility but never negotiated. + {hash.SHA256, signature.RSA_PSS_RSAE_SHA256}, + {hash.SHA384, signature.RSA_PSS_RSAE_SHA384}, + {hash.SHA512, signature.RSA_PSS_RSAE_SHA512}, + // {hash.SHA256, signature.RSA_PSS_PSS_SHA256}, + // {hash.SHA384, signature.RSA_PSS_PSS_SHA384}, + // {hash.SHA512, signature.RSA_PSS_PSS_SHA512}, + + // RSA PKCS#1 v1.5 schemes (backward compatibility with DTLS 1.2) + {hash.SHA256, signature.RSA}, + {hash.SHA384, signature.RSA}, + {hash.SHA512, signature.RSA}, + } +} diff --git a/vendor/github.com/pion/dtls/v3/pkg/net/net.go b/vendor/github.com/pion/dtls/v3/pkg/net/net.go new file mode 100644 index 0000000..2c9518e --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/net/net.go @@ -0,0 +1,111 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +// Package net defines packet-oriented primitives that are compatible with net +// in the standard library. +package net + +import ( + "net" + "time" +) + +// A PacketListener is the same as net.Listener but returns a net.PacketConn on +// Accept() rather than a net.Conn. +// +// Multiple goroutines may invoke methods on a PacketListener simultaneously. +type PacketListener interface { + // Accept waits for and returns the next connection to the listener. + Accept() (net.PacketConn, net.Addr, error) + + // Close closes the listener. + // Any blocked Accept operations will be unblocked and return errors. + Close() error + + // Addr returns the listener's network address. + Addr() net.Addr +} + +// PacketListenerFromListener converts a net.Listener into a +// dtlsnet.PacketListener. +func PacketListenerFromListener(l net.Listener) PacketListener { + return &packetListenerWrapper{ + l: l, + } +} + +// packetListenerWrapper wraps a net.Listener and implements +// dtlsnet.PacketListener. +type packetListenerWrapper struct { + l net.Listener +} + +// Accept calls Accept on the underlying net.Listener and converts the returned +// net.Conn into a net.PacketConn. +func (p *packetListenerWrapper) Accept() (net.PacketConn, net.Addr, error) { + c, err := p.l.Accept() + if err != nil { + return PacketConnFromConn(c), nil, err + } + + return PacketConnFromConn(c), c.RemoteAddr(), nil +} + +// Close closes the underlying net.Listener. +func (p *packetListenerWrapper) Close() error { + return p.l.Close() +} + +// Addr returns the address of the underlying net.Listener. +func (p *packetListenerWrapper) Addr() net.Addr { + return p.l.Addr() +} + +// PacketConnFromConn converts a net.Conn into a net.PacketConn. +func PacketConnFromConn(conn net.Conn) net.PacketConn { + return &packetConnWrapper{conn} +} + +// packetConnWrapper wraps a net.Conn and implements net.PacketConn. +type packetConnWrapper struct { + conn net.Conn +} + +// ReadFrom reads from the underlying net.Conn and returns its remote address. +func (p *packetConnWrapper) ReadFrom(b []byte) (int, net.Addr, error) { + n, err := p.conn.Read(b) + + return n, p.conn.RemoteAddr(), err +} + +// WriteTo writes to the underlying net.Conn. +func (p *packetConnWrapper) WriteTo(b []byte, _ net.Addr) (int, error) { + n, err := p.conn.Write(b) + + return n, err +} + +// Close closes the underlying net.Conn. +func (p *packetConnWrapper) Close() error { + return p.conn.Close() +} + +// LocalAddr returns the local address of the underlying net.Conn. +func (p *packetConnWrapper) LocalAddr() net.Addr { + return p.conn.LocalAddr() +} + +// SetDeadline sets the deadline on the underlying net.Conn. +func (p *packetConnWrapper) SetDeadline(t time.Time) error { + return p.conn.SetDeadline(t) +} + +// SetReadDeadline sets the read deadline on the underlying net.Conn. +func (p *packetConnWrapper) SetReadDeadline(t time.Time) error { + return p.conn.SetReadDeadline(t) +} + +// SetWriteDeadline sets the write deadline on the underlying net.Conn. +func (p *packetConnWrapper) SetWriteDeadline(t time.Time) error { + return p.conn.SetWriteDeadline(t) +} diff --git a/vendor/github.com/pion/dtls/v3/pkg/protocol/alert/alert.go b/vendor/github.com/pion/dtls/v3/pkg/protocol/alert/alert.go new file mode 100644 index 0000000..c317567 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/protocol/alert/alert.go @@ -0,0 +1,167 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +// Package alert implements TLS alert protocol https://tools.ietf.org/html/rfc5246#section-7.2 +package alert + +import ( + "errors" + "fmt" + + "github.com/pion/dtls/v3/pkg/protocol" +) + +var errBufferTooSmall = &protocol.TemporaryError{Err: errors.New("buffer is too small")} //nolint:err113 + +// Level is the level of the TLS Alert. +type Level byte + +// Level enums. +const ( + Warning Level = 1 + Fatal Level = 2 +) + +func (l Level) String() string { + switch l { + case Warning: + return "Warning" + case Fatal: + return "Fatal" + default: + return "Invalid alert level" + } +} + +// Description is the extended info of the TLS Alert. +type Description byte + +// Description enums. +const ( + CloseNotify Description = 0 + UnexpectedMessage Description = 10 + BadRecordMac Description = 20 + DecryptionFailed Description = 21 + RecordOverflow Description = 22 + DecompressionFailure Description = 30 + HandshakeFailure Description = 40 + NoCertificate Description = 41 + BadCertificate Description = 42 + UnsupportedCertificate Description = 43 + CertificateRevoked Description = 44 + CertificateExpired Description = 45 + CertificateUnknown Description = 46 + IllegalParameter Description = 47 + UnknownCA Description = 48 + AccessDenied Description = 49 + DecodeError Description = 50 + DecryptError Description = 51 + ExportRestriction Description = 60 + ProtocolVersion Description = 70 + InsufficientSecurity Description = 71 + InternalError Description = 80 + UserCanceled Description = 90 + NoRenegotiation Description = 100 + UnsupportedExtension Description = 110 + NoApplicationProtocol Description = 120 +) + +func (d Description) String() string { //nolint:cyclop + switch d { + case CloseNotify: + return "CloseNotify" + case UnexpectedMessage: + return "UnexpectedMessage" + case BadRecordMac: + return "BadRecordMac" + case DecryptionFailed: + return "DecryptionFailed" + case RecordOverflow: + return "RecordOverflow" + case DecompressionFailure: + return "DecompressionFailure" + case HandshakeFailure: + return "HandshakeFailure" + case NoCertificate: + return "NoCertificate" + case BadCertificate: + return "BadCertificate" + case UnsupportedCertificate: + return "UnsupportedCertificate" + case CertificateRevoked: + return "CertificateRevoked" + case CertificateExpired: + return "CertificateExpired" + case CertificateUnknown: + return "CertificateUnknown" + case IllegalParameter: + return "IllegalParameter" + case UnknownCA: + return "UnknownCA" + case AccessDenied: + return "AccessDenied" + case DecodeError: + return "DecodeError" + case DecryptError: + return "DecryptError" + case ExportRestriction: + return "ExportRestriction" + case ProtocolVersion: + return "ProtocolVersion" + case InsufficientSecurity: + return "InsufficientSecurity" + case InternalError: + return "InternalError" + case UserCanceled: + return "UserCanceled" + case NoRenegotiation: + return "NoRenegotiation" + case UnsupportedExtension: + return "UnsupportedExtension" + case NoApplicationProtocol: + return "NoApplicationProtocol" + default: + return "Invalid alert description" + } +} + +// Alert is one of the content types supported by the TLS record layer. +// Alert messages convey the severity of the message +// (warning or fatal) and a description of the alert. Alert messages +// with a level of fatal result in the immediate termination of the +// connection. In this case, other connections corresponding to the +// session may continue, but the session identifier MUST be invalidated, +// preventing the failed session from being used to establish new +// connections. Like other messages, alert messages are encrypted and +// compressed, as specified by the current connection state. +// https://tools.ietf.org/html/rfc5246#section-7.2 +type Alert struct { + Level Level + Description Description +} + +// ContentType returns the ContentType of this Content. +func (a Alert) ContentType() protocol.ContentType { + return protocol.ContentTypeAlert +} + +// Marshal returns the encoded alert. +func (a *Alert) Marshal() ([]byte, error) { + return []byte{byte(a.Level), byte(a.Description)}, nil +} + +// Unmarshal populates the alert from binary data. +func (a *Alert) Unmarshal(data []byte) error { + if len(data) != 2 { + return errBufferTooSmall + } + + a.Level = Level(data[0]) + a.Description = Description(data[1]) + + return nil +} + +func (a *Alert) String() string { + return fmt.Sprintf("Alert %s: %s", a.Level, a.Description) +} diff --git a/vendor/github.com/pion/dtls/v3/pkg/protocol/application_data.go b/vendor/github.com/pion/dtls/v3/pkg/protocol/application_data.go new file mode 100644 index 0000000..f5d8153 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/protocol/application_data.go @@ -0,0 +1,30 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package protocol + +// ApplicationData messages are carried by the record layer and are +// fragmented, compressed, and encrypted based on the current connection +// state. The messages are treated as transparent data to the record +// layer. +// https://tools.ietf.org/html/rfc5246#section-10 +type ApplicationData struct { + Data []byte +} + +// ContentType returns the ContentType of this content. +func (a ApplicationData) ContentType() ContentType { + return ContentTypeApplicationData +} + +// Marshal encodes the ApplicationData to binary. +func (a *ApplicationData) Marshal() ([]byte, error) { + return append([]byte{}, a.Data...), nil +} + +// Unmarshal populates the ApplicationData from binary. +func (a *ApplicationData) Unmarshal(data []byte) error { + a.Data = append([]byte{}, data...) + + return nil +} diff --git a/vendor/github.com/pion/dtls/v3/pkg/protocol/change_cipher_spec.go b/vendor/github.com/pion/dtls/v3/pkg/protocol/change_cipher_spec.go new file mode 100644 index 0000000..e8b18de --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/protocol/change_cipher_spec.go @@ -0,0 +1,30 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package protocol + +// ChangeCipherSpec protocol exists to signal transitions in +// ciphering strategies. The protocol consists of a single message, +// which is encrypted and compressed under the current (not the pending) +// connection state. The message consists of a single byte of value 1. +// https://tools.ietf.org/html/rfc5246#section-7.1 +type ChangeCipherSpec struct{} + +// ContentType returns the ContentType of this content. +func (c ChangeCipherSpec) ContentType() ContentType { + return ContentTypeChangeCipherSpec +} + +// Marshal encodes the ChangeCipherSpec to binary. +func (c *ChangeCipherSpec) Marshal() ([]byte, error) { + return []byte{0x01}, nil +} + +// Unmarshal populates the ChangeCipherSpec from binary. +func (c *ChangeCipherSpec) Unmarshal(data []byte) error { + if len(data) == 1 && data[0] == 0x01 { + return nil + } + + return errInvalidCipherSpec +} diff --git a/vendor/github.com/pion/dtls/v3/pkg/protocol/compression_method.go b/vendor/github.com/pion/dtls/v3/pkg/protocol/compression_method.go new file mode 100644 index 0000000..1dd4af1 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/protocol/compression_method.go @@ -0,0 +1,53 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package protocol + +// CompressionMethodID is the ID for a CompressionMethod. +type CompressionMethodID byte + +const ( + compressionMethodNull CompressionMethodID = 0 +) + +// CompressionMethod represents a TLS Compression Method. +type CompressionMethod struct { + ID CompressionMethodID +} + +// CompressionMethods returns all supported CompressionMethods. +func CompressionMethods() map[CompressionMethodID]*CompressionMethod { + return map[CompressionMethodID]*CompressionMethod{ + compressionMethodNull: {ID: compressionMethodNull}, + } +} + +// DecodeCompressionMethods the given compression methods. +func DecodeCompressionMethods(buf []byte) ([]*CompressionMethod, error) { + if len(buf) < 1 { + return nil, errBufferTooSmall + } + compressionMethodsCount := int(buf[0]) + c := []*CompressionMethod{} + for i := 0; i < compressionMethodsCount; i++ { + if len(buf) <= i+1 { + return nil, errBufferTooSmall + } + id := CompressionMethodID(buf[i+1]) + if compressionMethod, ok := CompressionMethods()[id]; ok { + c = append(c, compressionMethod) + } + } + + return c, nil +} + +// EncodeCompressionMethods the given compression methods. +func EncodeCompressionMethods(c []*CompressionMethod) []byte { + out := []byte{byte(len(c))} + for i := len(c); i > 0; i-- { + out = append(out, byte(c[i-1].ID)) + } + + return out +} diff --git a/vendor/github.com/pion/dtls/v3/pkg/protocol/content.go b/vendor/github.com/pion/dtls/v3/pkg/protocol/content.go new file mode 100644 index 0000000..58bbdc5 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/protocol/content.go @@ -0,0 +1,25 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package protocol + +// ContentType represents the IANA Registered ContentTypes +// +// https://tools.ietf.org/html/rfc4346#section-6.2.1 +type ContentType uint8 + +// ContentType enums. +const ( + ContentTypeChangeCipherSpec ContentType = 20 + ContentTypeAlert ContentType = 21 + ContentTypeHandshake ContentType = 22 + ContentTypeApplicationData ContentType = 23 + ContentTypeConnectionID ContentType = 25 +) + +// Content is the top level distinguisher for a DTLS Datagram. +type Content interface { + ContentType() ContentType + Marshal() ([]byte, error) + Unmarshal(data []byte) error +} diff --git a/vendor/github.com/pion/dtls/v3/pkg/protocol/errors.go b/vendor/github.com/pion/dtls/v3/pkg/protocol/errors.go new file mode 100644 index 0000000..6e69a75 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/protocol/errors.go @@ -0,0 +1,112 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package protocol + +import ( + "errors" + "fmt" + "net" +) + +var ( + errBufferTooSmall = &TemporaryError{Err: errors.New("buffer is too small")} //nolint:err113 + errInvalidCipherSpec = &FatalError{Err: errors.New("cipher spec invalid")} //nolint:err113 +) + +// FatalError indicates that the DTLS connection is no longer available. +// It is mainly caused by wrong configuration of server or client. +type FatalError struct { + Err error +} + +// InternalError indicates and internal error caused by the implementation, +// and the DTLS connection is no longer available. +// It is mainly caused by bugs or tried to use unimplemented features. +type InternalError struct { + Err error +} + +// TemporaryError indicates that the DTLS connection is still available, but the request was failed temporary. +type TemporaryError struct { + Err error +} + +// TimeoutError indicates that the request was timed out. +type TimeoutError struct { + Err error +} + +// HandshakeError indicates that the handshake failed. +type HandshakeError struct { + Err error +} + +// Timeout implements net.Error.Timeout(). +func (*FatalError) Timeout() bool { return false } + +// Temporary implements net.Error.Temporary(). +func (*FatalError) Temporary() bool { return false } + +// Unwrap implements Go1.13 error unwrapper. +func (e *FatalError) Unwrap() error { return e.Err } + +func (e *FatalError) Error() string { return fmt.Sprintf("dtls fatal: %v", e.Err) } + +// Timeout implements net.Error.Timeout(). +func (*InternalError) Timeout() bool { return false } + +// Temporary implements net.Error.Temporary(). +func (*InternalError) Temporary() bool { return false } + +// Unwrap implements Go1.13 error unwrapper. +func (e *InternalError) Unwrap() error { return e.Err } + +func (e *InternalError) Error() string { return fmt.Sprintf("dtls internal: %v", e.Err) } + +// Timeout implements net.Error.Timeout(). +func (*TemporaryError) Timeout() bool { return false } + +// Temporary implements net.Error.Temporary(). +func (*TemporaryError) Temporary() bool { return true } + +// Unwrap implements Go1.13 error unwrapper. +func (e *TemporaryError) Unwrap() error { return e.Err } + +func (e *TemporaryError) Error() string { return fmt.Sprintf("dtls temporary: %v", e.Err) } + +// Timeout implements net.Error.Timeout(). +func (*TimeoutError) Timeout() bool { return true } + +// Temporary implements net.Error.Temporary(). +func (*TimeoutError) Temporary() bool { return true } + +// Unwrap implements Go1.13 error unwrapper. +func (e *TimeoutError) Unwrap() error { return e.Err } + +func (e *TimeoutError) Error() string { return fmt.Sprintf("dtls timeout: %v", e.Err) } + +// Timeout implements net.Error.Timeout(). +func (e *HandshakeError) Timeout() bool { + var netErr net.Error + if errors.As(e.Err, &netErr) { + return netErr.Timeout() + } + + return false +} + +// Temporary implements net.Error.Temporary(). +func (e *HandshakeError) Temporary() bool { + var netErr net.Error + if errors.As(e.Err, &netErr) { + return netErr.Temporary() //nolint + } + + return false +} + +// Unwrap implements Go1.13 error unwrapper. +func (e *HandshakeError) Unwrap() error { return e.Err } + +func (e *HandshakeError) Error() string { return fmt.Sprintf("handshake error: %v", e.Err) } diff --git a/vendor/github.com/pion/dtls/v3/pkg/protocol/extension/alpn.go b/vendor/github.com/pion/dtls/v3/pkg/protocol/extension/alpn.go new file mode 100644 index 0000000..28ea844 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/protocol/extension/alpn.go @@ -0,0 +1,83 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package extension + +import ( + "slices" + + "golang.org/x/crypto/cryptobyte" +) + +// ALPN is a TLS extension for application-layer protocol negotiation within +// the TLS handshake. +// +// https://tools.ietf.org/html/rfc7301 +type ALPN struct { + ProtocolNameList []string +} + +// TypeValue returns the extension TypeValue. +func (a ALPN) TypeValue() TypeValue { + return ALPNTypeValue +} + +// Marshal encodes the extension. +func (a *ALPN) Marshal() ([]byte, error) { + var builder cryptobyte.Builder + builder.AddUint16(uint16(a.TypeValue())) + builder.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, proto := range a.ProtocolNameList { + p := proto // Satisfy range scope lint + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes([]byte(p)) + }) + } + }) + }) + + return builder.Bytes() +} + +// Unmarshal populates the extension from encoded data. +func (a *ALPN) Unmarshal(data []byte) error { + val := cryptobyte.String(data) + + var extension uint16 + val.ReadUint16(&extension) + if TypeValue(extension) != a.TypeValue() { + return errInvalidExtensionType + } + + var extData cryptobyte.String + val.ReadUint16LengthPrefixed(&extData) + + var protoList cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() { + return ErrALPNInvalidFormat + } + for !protoList.Empty() { + var proto cryptobyte.String + if !protoList.ReadUint8LengthPrefixed(&proto) || proto.Empty() { + return ErrALPNInvalidFormat + } + a.ProtocolNameList = append(a.ProtocolNameList, string(proto)) + } + + return nil +} + +// ALPNProtocolSelection negotiates a shared protocol according to #3.2 of rfc7301. +func ALPNProtocolSelection(supportedProtocols, peerSupportedProtocols []string) (string, error) { + if len(supportedProtocols) == 0 || len(peerSupportedProtocols) == 0 { + return "", nil + } + for _, s := range supportedProtocols { + if slices.Contains(peerSupportedProtocols, s) { + return s, nil + } + } + + return "", errALPNNoAppProto +} diff --git a/vendor/github.com/pion/dtls/v3/pkg/protocol/extension/connection_id.go b/vendor/github.com/pion/dtls/v3/pkg/protocol/extension/connection_id.go new file mode 100644 index 0000000..fa27a7d --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/protocol/extension/connection_id.go @@ -0,0 +1,61 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package extension + +import ( + "golang.org/x/crypto/cryptobyte" +) + +// ConnectionID is a DTLS extension that provides an alternative to IP address +// and port for session association. +// +// https://tools.ietf.org/html/rfc9146 +type ConnectionID struct { + // A zero-length connection ID indicates for a client or server that + // negotiated connection IDs from the peer will be sent but there is no need + // to respond with one + CID []byte // variable length +} + +// TypeValue returns the extension TypeValue. +func (c ConnectionID) TypeValue() TypeValue { + return ConnectionIDTypeValue +} + +// Marshal encodes the extension. +func (c *ConnectionID) Marshal() ([]byte, error) { + var b cryptobyte.Builder + b.AddUint16(uint16(c.TypeValue())) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(c.CID) + }) + }) + + return b.Bytes() +} + +// Unmarshal populates the extension from encoded data. +func (c *ConnectionID) Unmarshal(data []byte) error { + val := cryptobyte.String(data) + var extension uint16 + val.ReadUint16(&extension) + if TypeValue(extension) != c.TypeValue() { + return errInvalidExtensionType + } + + var extData cryptobyte.String + val.ReadUint16LengthPrefixed(&extData) + + var cid cryptobyte.String + if !extData.ReadUint8LengthPrefixed(&cid) { + return errInvalidCIDFormat + } + c.CID = make([]byte, len(cid)) + if !cid.CopyBytes(c.CID) { + return errInvalidCIDFormat + } + + return nil +} diff --git a/vendor/github.com/pion/dtls/v3/pkg/protocol/extension/cookie.go b/vendor/github.com/pion/dtls/v3/pkg/protocol/extension/cookie.go new file mode 100644 index 0000000..a71bb8d --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/protocol/extension/cookie.go @@ -0,0 +1,61 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package extension + +import ( + "golang.org/x/crypto/cryptobyte" +) + +const maxCookieSize = 0xffff - 2 + +// CookieExt implements the cookie extension in DTLS 1.3. +// See RFC 8446 section 4.2.2. Cookie. +type CookieExt struct { + Cookie []byte +} + +// TypeValue returns the extension TypeValue. +func (c CookieExt) TypeValue() TypeValue { + return CookieTypeValue +} + +// Marshal encodes the extension. +func (c *CookieExt) Marshal() ([]byte, error) { + cookieLength := len(c.Cookie) + if cookieLength == 0 || cookieLength > maxCookieSize { + return nil, errCookieExtFormat + } + var b cryptobyte.Builder + b.AddUint16(uint16(c.TypeValue())) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(c.Cookie) + }) + }) + + return b.Bytes() +} + +// Unmarshal populates the extension from encoded data. +func (c *CookieExt) Unmarshal(data []byte) error { //nolint:cyclop + val := cryptobyte.String(data) + var extension uint16 + if !val.ReadUint16(&extension) || TypeValue(extension) != c.TypeValue() { + return errInvalidExtensionType + } + + var extData cryptobyte.String + if !val.ReadUint16LengthPrefixed(&extData) { + return errBufferTooSmall + } + + var cookie cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&cookie) || cookie.Empty() || len(cookie) > maxCookieSize { + return errCookieExtFormat + } + + c.Cookie = append([]byte(nil), cookie...) + + return nil +} diff --git a/vendor/github.com/pion/dtls/v3/pkg/protocol/extension/errors.go b/vendor/github.com/pion/dtls/v3/pkg/protocol/extension/errors.go new file mode 100644 index 0000000..f9a0163 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/protocol/extension/errors.go @@ -0,0 +1,62 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package extension + +import ( + "errors" + + "github.com/pion/dtls/v3/pkg/protocol" +) + +var ( + // ErrALPNInvalidFormat is raised when the ALPN format is invalid. + ErrALPNInvalidFormat = &protocol.FatalError{ + Err: errors.New("invalid alpn format"), //nolint:err113 + } + errALPNNoAppProto = &protocol.FatalError{ + Err: errors.New("no application protocol"), //nolint:err113 + } + errBufferTooSmall = &protocol.TemporaryError{ + Err: errors.New("buffer is too small"), //nolint:err113 + } + errInvalidExtensionType = &protocol.FatalError{ + Err: errors.New("invalid extension type"), //nolint:err113 + } + errInvalidSNIFormat = &protocol.FatalError{ + Err: errors.New("invalid server name format"), //nolint:err113 + } + errInvalidCIDFormat = &protocol.FatalError{ + Err: errors.New("invalid connection ID format"), //nolint:err113 + } + errLengthMismatch = &protocol.InternalError{ + Err: errors.New("data length and declared length do not match"), //nolint:err113 + } + errMasterKeyIdentifierTooLarge = &protocol.FatalError{ + Err: errors.New("master key identifier is over 255 bytes"), //nolint:err113 + } + errPreSharedKeyFormat = &protocol.FatalError{ + Err: errors.New("invalid Pre-Shared Key extension format"), //nolint:err113 + } + errPskKeyExchangeModesFormat = &protocol.FatalError{ + Err: errors.New("invalid Pre-Shared Key Exchange Modes extension format"), //nolint:err113 + } + errNoPskKeyExchangeMode = &protocol.InternalError{ + Err: errors.New("no mode set for the Pre-Shared Key Exchange Modes extension"), //nolint:err113 + } + errCookieExtFormat = &protocol.FatalError{ + Err: errors.New("invalid cookie format"), //nolint:err113 + } + errInvalidKeyShareFormat = &protocol.FatalError{ + Err: errors.New("invalid key_share format"), //nolint:err113 + } + errDuplicateKeyShare = &protocol.FatalError{ + Err: errors.New("duplicate key_share group"), //nolint:err113 + } + errInvalidSupportedVersionsFormat = &protocol.FatalError{ + Err: errors.New("invalid supported_versions format"), //nolint:err113 + } + errInvalidDTLSVersion = &protocol.InternalError{ + Err: errors.New("invalid dtls version was provided"), //nolint:err113 + } +) diff --git a/vendor/github.com/pion/dtls/v3/pkg/protocol/extension/extension.go b/vendor/github.com/pion/dtls/v3/pkg/protocol/extension/extension.go new file mode 100644 index 0000000..4810c08 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/protocol/extension/extension.go @@ -0,0 +1,148 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +// Package extension implements the extension values in the ClientHello/ServerHello +package extension + +import ( + "encoding/binary" + + "github.com/pion/dtls/v3/pkg/crypto/hash" + "github.com/pion/dtls/v3/pkg/crypto/signature" +) + +// TypeValue is the 2 byte value for a TLS Extension as registered in the IANA +// +// https://www.iana.org/assignments/tls-extensiontype-values/tls-extensiontype-values.xhtml +type TypeValue uint16 + +// TypeValue constants. +const ( + ServerNameTypeValue TypeValue = 0 + // In DTLS 1.3, this extension in renamed to "supported_groups". + SupportedEllipticCurvesTypeValue TypeValue = 10 + SupportedPointFormatsTypeValue TypeValue = 11 + SupportedSignatureAlgorithmsTypeValue TypeValue = 13 + UseSRTPTypeValue TypeValue = 14 + ALPNTypeValue TypeValue = 16 + UseExtendedMasterSecretTypeValue TypeValue = 23 + PreSharedKeyValue TypeValue = 41 + SupportedVersionsTypeValue TypeValue = 43 + CookieTypeValue TypeValue = 44 + PskKeyExchangeModesTypeValue TypeValue = 45 + SignatureAlgorithmsCertTypeValue TypeValue = 50 + KeyShareTypeValue TypeValue = 51 + ConnectionIDTypeValue TypeValue = 54 + RenegotiationInfoTypeValue TypeValue = 65281 +) + +// Extension represents a single TLS extension. +type Extension interface { + Marshal() ([]byte, error) + Unmarshal(data []byte) error + TypeValue() TypeValue +} + +// Unmarshal many extensions at once. +func Unmarshal(buf []byte) ([]Extension, error) { //nolint:cyclop + switch { + case len(buf) == 0: + return []Extension{}, nil + case len(buf) < 2: + return nil, errBufferTooSmall + } + + declaredLen := binary.BigEndian.Uint16(buf) + if len(buf)-2 != int(declaredLen) { + return nil, errLengthMismatch + } + + extensions := []Extension{} + unmarshalAndAppend := func(data []byte, e Extension) error { + err := e.Unmarshal(data) + if err != nil { + return err + } + extensions = append(extensions, e) + + return nil + } + + for offset := 2; offset < len(buf); { + bufView := buf[offset:] //nolint:gosec // offset bounded by loop condition + if len(bufView) < 2 { + return nil, errBufferTooSmall + } + + var err error + switch TypeValue(binary.BigEndian.Uint16(bufView)) { + case ServerNameTypeValue: + err = unmarshalAndAppend(bufView, &ServerName{}) + case SupportedEllipticCurvesTypeValue: + err = unmarshalAndAppend(bufView, &SupportedEllipticCurves{}) + case SupportedPointFormatsTypeValue: + err = unmarshalAndAppend(bufView, &SupportedPointFormats{}) + case SupportedSignatureAlgorithmsTypeValue: + err = unmarshalAndAppend(bufView, &SupportedSignatureAlgorithms{}) + case SignatureAlgorithmsCertTypeValue: + err = unmarshalAndAppend(bufView, &SignatureAlgorithmsCert{}) + case UseSRTPTypeValue: + err = unmarshalAndAppend(bufView, &UseSRTP{}) + case ALPNTypeValue: + err = unmarshalAndAppend(bufView, &ALPN{}) + case UseExtendedMasterSecretTypeValue: + err = unmarshalAndAppend(bufView, &UseExtendedMasterSecret{}) + case RenegotiationInfoTypeValue: + err = unmarshalAndAppend(bufView, &RenegotiationInfo{}) + case ConnectionIDTypeValue: + err = unmarshalAndAppend(bufView, &ConnectionID{}) + case SupportedVersionsTypeValue: + err = unmarshalAndAppend(bufView, &SupportedVersions{}) + case KeyShareTypeValue: + err = unmarshalAndAppend(bufView, &KeyShare{}) + case CookieTypeValue: + err = unmarshalAndAppend(bufView, &CookieExt{}) + default: + } + + if err != nil { + return nil, err + } + if len(bufView) < 4 { + return nil, errBufferTooSmall + } + extensionLength := binary.BigEndian.Uint16(bufView[2:]) + offset += (4 + int(extensionLength)) + } + + return extensions, nil +} + +// Marshal many extensions at once. +func Marshal(e []Extension) ([]byte, error) { + extensions := []byte{} + for _, e := range e { + raw, err := e.Marshal() + if err != nil { + return nil, err + } + extensions = append(extensions, raw...) + } + out := []byte{0x00, 0x00} + binary.BigEndian.PutUint16(out, uint16(len(extensions))) //nolint:gosec // G115 + + return append(out, extensions...), nil +} + +// parseSignatureScheme parses a signature scheme from a uint16 value. +// It handles both TLS 1.2 style (hash byte + signature byte) and TLS 1.3 style (full uint16 PSS schemes). +// Returns the hash algorithm and signature algorithm. +func parseSignatureScheme(scheme uint16) (hash.Algorithm, signature.Algorithm) { + if signature.Algorithm(scheme).IsPSS() { + // TLS 1.3 PSS scheme - full uint16 is the signature algorithm + return hash.ExtractHashFromPSS(scheme), signature.Algorithm(scheme) + } + + // TLS 1.2 style - split into hash (high byte) and signature (low byte) + return hash.Algorithm(scheme >> 8), signature.Algorithm(scheme & 0xFF) +} diff --git a/vendor/github.com/pion/dtls/v3/pkg/protocol/extension/generic_signature_hash_extension.go b/vendor/github.com/pion/dtls/v3/pkg/protocol/extension/generic_signature_hash_extension.go new file mode 100644 index 0000000..6553c55 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/protocol/extension/generic_signature_hash_extension.go @@ -0,0 +1,81 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +// Package extension implements the extension values in the ClientHello/ServerHello +package extension + +import ( + "github.com/pion/dtls/v3/pkg/crypto/hash" + "github.com/pion/dtls/v3/pkg/crypto/signature" + "github.com/pion/dtls/v3/pkg/crypto/signaturehash" + "golang.org/x/crypto/cryptobyte" +) + +// marshalGenericSignatureHashAlgorithm encodes the extension. +// This supports hybrid encoding: TLS 1.3 PSS schemes are encoded as full uint16, +// while TLS 1.2 schemes use hash (high byte) + signature (low byte) encoding. +func marshalGenericSignatureHashAlgorithm(typeValue TypeValue, sigHashAlgs []signaturehash.Algorithm) ([]byte, error) { + var builder cryptobyte.Builder + builder.AddUint16(uint16(typeValue)) + builder.AddUint16LengthPrefixed(func(extBuilder *cryptobyte.Builder) { + extBuilder.AddUint16LengthPrefixed(func(algBuilder *cryptobyte.Builder) { + for _, v := range sigHashAlgs { + // For PSS schemes, write the full uint16 SignatureScheme value + // For other schemes, write hash (high byte) + signature (low byte) in TLS 1.2 style + if v.Signature.IsPSS() { + // TLS 1.3 PSS: full uint16 is the signature scheme + algBuilder.AddUint16(uint16(v.Signature)) + } else { + // TLS 1.2 style: hash byte + signature byte + algBuilder.AddUint8(byte(v.Hash)) + algBuilder.AddUint8(byte(v.Signature)) + } + } + }) + }) + + return builder.Bytes() +} + +// unmarshalGenericSignatureAlgorithm populates the extension from encoded data. +// This supports hybrid encoding: detects TLS 1.3 PSS schemes +// and handles them as full uint16, while TLS 1.2 schemes use byte-split encoding. +func unmarshalGenericSignatureHashAlgorithm(typeValue TypeValue, data []byte, dst *[]signaturehash.Algorithm) error { + val := cryptobyte.String(data) + var extension uint16 + if !val.ReadUint16(&extension) || TypeValue(extension) != typeValue { + return errInvalidExtensionType + } + + var extData cryptobyte.String + if !val.ReadUint16LengthPrefixed(&extData) { + return errBufferTooSmall + } + + var algData cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&algData) { + return errLengthMismatch + } + + for !algData.Empty() { + var scheme uint16 + if !algData.ReadUint16(&scheme) { + return errLengthMismatch + } + + // Parse the signature scheme (handles both TLS 1.2 and TLS 1.3 PSS encoding) + supportedHashAlgorithm, supportedSignatureAlgorithm := parseSignatureScheme(scheme) + + // Validate both hash and signature algorithms + if _, ok := hash.Algorithms()[supportedHashAlgorithm]; ok { + if _, ok := signature.Algorithms()[supportedSignatureAlgorithm]; ok { + *dst = append(*dst, signaturehash.Algorithm{ + Hash: supportedHashAlgorithm, + Signature: supportedSignatureAlgorithm, + }) + } + } + } + + return nil +} diff --git a/vendor/github.com/pion/dtls/v3/pkg/protocol/extension/key_share.go b/vendor/github.com/pion/dtls/v3/pkg/protocol/extension/key_share.go new file mode 100644 index 0000000..78c2de1 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/protocol/extension/key_share.go @@ -0,0 +1,179 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package extension + +import ( + "github.com/pion/dtls/v3/pkg/crypto/elliptic" + "golang.org/x/crypto/cryptobyte" +) + +type KeyShareEntry struct { + Group elliptic.Curve + KeyExchange []byte +} + +// KeyShare represents the "key_share" extension. Only one of the fields can be used at a time. +// See RFC 8446 section 4.2.8. +type KeyShare struct { + ClientShares []KeyShareEntry // ClientHello + ServerShare *KeyShareEntry // ServerHello + SelectedGroup *elliptic.Curve // HelloRetryRequest +} + +func (k KeyShare) TypeValue() TypeValue { return KeyShareTypeValue } + +// Marshal encodes the extension. +func (k *KeyShare) Marshal() ([]byte, error) { //nolint:cyclop + hasClientShares := k.ClientShares != nil // vector MAY be empty + hasServerShare := k.ServerShare != nil + hasHelloRetryRequest := k.SelectedGroup != nil + + // there must be exactly one context. + if hasTooManyContexts(hasClientShares, hasServerShare, hasHelloRetryRequest) { + return nil, errInvalidKeyShareFormat + } + + var builder cryptobyte.Builder + + builder.AddUint16(uint16(k.TypeValue())) + + if hasClientShares { + seenGroups := map[elliptic.Curve]struct{}{} + for _, e := range k.ClientShares { + if _, ok := seenGroups[e.Group]; ok { + return nil, errDuplicateKeyShare + } + + seenGroups[e.Group] = struct{}{} + + if l := len(e.KeyExchange); l == 0 || l > 0xffff { + return nil, errInvalidKeyShareFormat + } + } + } + + if hasServerShare { + if l := len(k.ServerShare.KeyExchange); l == 0 || l > 0xffff { + return nil, errInvalidKeyShareFormat + } + } + + builder.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + switch { + case hasHelloRetryRequest: + // KeyShareHelloRetryRequest { NamedGroup selected_group; } + b.AddUint16(uint16(*k.SelectedGroup)) + + case hasServerShare: + // KeyShareServerHello { KeyShareEntry server_share; } + addKeyShareEntry(b, *k.ServerShare) + + default: + // KeyShareClientHello { KeyShareEntry client_shares<0..2^16-1>; } + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, e := range k.ClientShares { + addKeyShareEntry(b, e) + } + }) + } + }) + + return builder.Bytes() +} + +// Unmarshal decodes the extension. +func (k *KeyShare) Unmarshal(data []byte) error { //nolint:cyclop + val := cryptobyte.String(data) + var extData cryptobyte.String + + var ext uint16 + if !val.ReadUint16(&ext) || TypeValue(ext) != k.TypeValue() { + return errInvalidExtensionType + } + if !val.ReadUint16LengthPrefixed(&extData) { + return errBufferTooSmall + } + if extData.Empty() { + return errInvalidKeyShareFormat + } + + k.ClientShares, k.ServerShare, k.SelectedGroup = nil, nil, nil + + peek := extData + var vecLen uint16 + // ClientHello: client_shares is a uint16-length-prefixed vector. + if peek.ReadUint16(&vecLen) && int(vecLen) == len(peek) { //nolint:nestif + seenGroups := map[elliptic.Curve]struct{}{} + for !peek.Empty() { + var entry KeyShareEntry + var groupU16 uint16 + var raw cryptobyte.String + + if !peek.ReadUint16(&groupU16) || !peek.ReadUint16LengthPrefixed(&raw) || len(raw) == 0 { + return errInvalidKeyShareFormat + } + + group := elliptic.Curve(groupU16) + + if _, ok := seenGroups[group]; ok { + return errDuplicateKeyShare + } + + seenGroups[group] = struct{}{} + + entry.Group = group + entry.KeyExchange = append([]byte(nil), raw...) + k.ClientShares = append(k.ClientShares, entry) + } + + // consume vector (2 bytes length + vecLen) + if !extData.Skip(2 + int(vecLen)) { + return errInvalidKeyShareFormat + } + + return nil + } + + // HelloRetryRequest: exactly 2 bytes = selected_group + if len(extData) == 2 { + var groupU16 uint16 + if !extData.ReadUint16(&groupU16) { + return errInvalidKeyShareFormat + } + + group := elliptic.Curve(groupU16) + if elliptic.Curves()[group] { + k.SelectedGroup = &group + } + + return nil + } + + // ServerHello: exactly one KeyShareEntry and no trailing bytes + var groupU16 uint16 + var raw cryptobyte.String + + if !extData.ReadUint16(&groupU16) || !extData.ReadUint16LengthPrefixed(&raw) || !extData.Empty() || len(raw) == 0 { + return errInvalidKeyShareFormat + } + + group := elliptic.Curve(groupU16) + share := KeyShareEntry{Group: group, KeyExchange: append([]byte(nil), raw...)} + k.ServerShare = &share + + return nil +} + +func addKeyShareEntry(b *cryptobyte.Builder, e KeyShareEntry) { + b.AddUint16(uint16(e.Group)) + + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(e.KeyExchange) + }) +} + +// hasTooManyContexts is used in Marshal(). It returns whether the KeyShare struct has more than exactly one context. +func hasTooManyContexts(a bool, b bool, c bool) bool { + return (a && b) || (a && c) || (b && c) +} diff --git a/vendor/github.com/pion/dtls/v3/pkg/protocol/extension/pre_shared_key.go b/vendor/github.com/pion/dtls/v3/pkg/protocol/extension/pre_shared_key.go new file mode 100644 index 0000000..73ff695 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/protocol/extension/pre_shared_key.go @@ -0,0 +1,144 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package extension + +import ( + "golang.org/x/crypto/cryptobyte" +) + +// PreSharedKey represents the "pre_shared_key" extension for DTLS 1.3. +// This extension is used in both ClientHello and ServerHello messages, +// but only the relevant fields should be populated for each context. +// See RFC 8446 section 4.2.11. +// +// https://datatracker.ietf.org/doc/html/rfc8446#section-4.2.11 +type PreSharedKey struct { + // ClientHello only - offered PSK identities + Identities []PskIdentity + // ClientHello only - binder values associated with a PSK identity + Binders []PskBinderEntry + // ServerHello only - index of selected identity + SelectedIdentity uint16 +} + +// PskIdentity represents the PSK identitiy in the "pre_shared_key" extension +// for DTLS 1.3. +type PskIdentity struct { + Identity []byte + ObfuscatedTicketAge uint32 +} + +// PskBinderEntry represents the binder related to a PSK identity in the +// "pre_shared_key" extension for DTLS 1.3. +type PskBinderEntry []byte + +const minPSKBinderSize = 32 + +// TypeValue returns the extension TypeValue. +func (p PreSharedKey) TypeValue() TypeValue { + return PreSharedKeyValue +} + +// Marshal encodes the extension. +func (p *PreSharedKey) Marshal() ([]byte, error) { + var out cryptobyte.Builder + out.AddUint16(uint16(p.TypeValue())) + + // ServerHello + if len(p.Identities) == 0 || len(p.Binders) == 0 { + out.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16(p.SelectedIdentity) + }) + + return out.Bytes() + } + + // ClientHello + out.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, pskIdentity := range p.Identities { + if len(pskIdentity.Identity) == 0 { + b.SetError(errPreSharedKeyFormat) + } + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(pskIdentity.Identity) + }) + b.AddUint32(pskIdentity.ObfuscatedTicketAge) + } + }) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, binder := range p.Binders { + if len(binder) < minPSKBinderSize { + b.SetError(errPreSharedKeyFormat) + } + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(binder) + }) + } + }) + }) + + return out.Bytes() +} + +// Unmarshal populates the extension from encoded data. +func (p *PreSharedKey) Unmarshal(data []byte) error { //nolint:cyclop + val := cryptobyte.String(data) + var extension uint16 + if !val.ReadUint16(&extension) || TypeValue(extension) != p.TypeValue() { + return errInvalidExtensionType + } + + var extData cryptobyte.String + if !val.ReadUint16LengthPrefixed(&extData) { + return errBufferTooSmall + } + + // ServerHello + if len(extData) == 2 { + var selected uint16 + if !extData.ReadUint16(&selected) { + return errPreSharedKeyFormat + } + p.SelectedIdentity = selected + + return nil + } + + // ClientHello + var identities cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&identities) || identities.Empty() { + return errPreSharedKeyFormat + } + + for !identities.Empty() { + var identity cryptobyte.String + var ticket uint32 + if !identities.ReadUint16LengthPrefixed(&identity) || !identities.ReadUint32(&ticket) || identity.Empty() { + return errPreSharedKeyFormat + } + p.Identities = append(p.Identities, PskIdentity{identity, ticket}) + } + + var binders cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&binders) || binders.Empty() { + return errPreSharedKeyFormat + } + + for !binders.Empty() { + var binder cryptobyte.String + if !binders.ReadUint8LengthPrefixed(&binder) || len(binder) < minPSKBinderSize { + return errPreSharedKeyFormat + } + p.Binders = append(p.Binders, PskBinderEntry(binder)) + } + + // Ensure there is one binder value per identity in list, + // and check for trailing bytes. + if len(p.Binders) != len(p.Identities) || !extData.Empty() { + return errPreSharedKeyFormat + } + + return nil +} diff --git a/vendor/github.com/pion/dtls/v3/pkg/protocol/extension/psk_key_exchange_modes.go b/vendor/github.com/pion/dtls/v3/pkg/protocol/extension/psk_key_exchange_modes.go new file mode 100644 index 0000000..59d8cf1 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/protocol/extension/psk_key_exchange_modes.go @@ -0,0 +1,80 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package extension + +import ( + "golang.org/x/crypto/cryptobyte" +) + +// PskKeyExchangeModes implements the PskKeyExchangeModes extension in DTLS 1.3. +// See RFC 8446 section 4.2.9. Pre-Shared Key Exchange Modes. +// +// https://datatracker.ietf.org/doc/html/rfc8446#section-4.2.9 +type PskKeyExchangeModes struct { + KeModes []PskKeyExchangeMode +} + +type PskKeyExchangeMode uint8 + +// TypeValue constants. +const ( + PskKe PskKeyExchangeMode = 0 + PskDheKe PskKeyExchangeMode = 1 +) + +// TypeValue returns the extension TypeValue. +func (p PskKeyExchangeModes) TypeValue() TypeValue { + return PskKeyExchangeModesTypeValue +} + +// Marshal encodes the extension. +func (p *PskKeyExchangeModes) Marshal() ([]byte, error) { + if len(p.KeModes) == 0 { + return nil, errNoPskKeyExchangeMode + } + + var out cryptobyte.Builder + out.AddUint16(uint16(p.TypeValue())) + + out.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + for _, keM := range p.KeModes { + b.AddUint8(uint8(keM)) + } + }) + }) + + return out.Bytes() +} + +// Unmarshal populates the extension from encoded data. +func (p *PskKeyExchangeModes) Unmarshal(data []byte) error { //nolint:cyclop + val := cryptobyte.String(data) + var extension uint16 + if !val.ReadUint16(&extension) || TypeValue(extension) != p.TypeValue() { + return errInvalidExtensionType + } + + var extData cryptobyte.String + if !val.ReadUint16LengthPrefixed(&extData) { + return errBufferTooSmall + } + + var strModes cryptobyte.String + + if !extData.ReadUint8LengthPrefixed(&strModes) { + return errPskKeyExchangeModesFormat + } + if strModes.Empty() { + return errPskKeyExchangeModesFormat + } + + p.KeModes = make([]PskKeyExchangeMode, 0) + + for _, mode := range strModes { + p.KeModes = append(p.KeModes, PskKeyExchangeMode(mode)) + } + + return nil +} diff --git a/vendor/github.com/pion/dtls/v3/pkg/protocol/extension/renegotiation_info.go b/vendor/github.com/pion/dtls/v3/pkg/protocol/extension/renegotiation_info.go new file mode 100644 index 0000000..be7160e --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/protocol/extension/renegotiation_info.go @@ -0,0 +1,47 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package extension + +import "encoding/binary" + +const ( + renegotiationInfoHeaderSize = 5 +) + +// RenegotiationInfo allows a Client/Server to +// communicate their renegotation support +// +// https://tools.ietf.org/html/rfc5746 +type RenegotiationInfo struct { + RenegotiatedConnection uint8 +} + +// TypeValue returns the extension TypeValue. +func (r RenegotiationInfo) TypeValue() TypeValue { + return RenegotiationInfoTypeValue +} + +// Marshal encodes the extension. +func (r *RenegotiationInfo) Marshal() ([]byte, error) { + out := make([]byte, renegotiationInfoHeaderSize) + + binary.BigEndian.PutUint16(out, uint16(r.TypeValue())) + binary.BigEndian.PutUint16(out[2:], uint16(1)) // length + out[4] = r.RenegotiatedConnection + + return out, nil +} + +// Unmarshal populates the extension from encoded data. +func (r *RenegotiationInfo) Unmarshal(data []byte) error { + if len(data) < renegotiationInfoHeaderSize { + return errBufferTooSmall + } else if TypeValue(binary.BigEndian.Uint16(data)) != r.TypeValue() { + return errInvalidExtensionType + } + + r.RenegotiatedConnection = data[4] + + return nil +} diff --git a/vendor/github.com/pion/dtls/v3/pkg/protocol/extension/server_name.go b/vendor/github.com/pion/dtls/v3/pkg/protocol/extension/server_name.go new file mode 100644 index 0000000..098e228 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/protocol/extension/server_name.go @@ -0,0 +1,83 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package extension + +import ( + "strings" + + "golang.org/x/crypto/cryptobyte" +) + +const serverNameTypeDNSHostName = 0 + +// ServerName allows the client to inform the server the specific +// name it wishes to contact. Useful if multiple DNS names resolve +// to one IP +// +// https://tools.ietf.org/html/rfc6066#section-3 +type ServerName struct { + ServerName string +} + +// TypeValue returns the extension TypeValue. +func (s ServerName) TypeValue() TypeValue { + return ServerNameTypeValue +} + +// Marshal encodes the extension. +func (s *ServerName) Marshal() ([]byte, error) { + var b cryptobyte.Builder + b.AddUint16(uint16(s.TypeValue())) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8(serverNameTypeDNSHostName) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes([]byte(s.ServerName)) + }) + }) + }) + + return b.Bytes() +} + +// Unmarshal populates the extension from encoded data. +func (s *ServerName) Unmarshal(data []byte) error { //nolint:cyclop + val := cryptobyte.String(data) + var extension uint16 + val.ReadUint16(&extension) + if TypeValue(extension) != s.TypeValue() { + return errInvalidExtensionType + } + + var extData cryptobyte.String + val.ReadUint16LengthPrefixed(&extData) + + var nameList cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&nameList) || nameList.Empty() { + return errInvalidSNIFormat + } + for !nameList.Empty() { + var nameType uint8 + var serverName cryptobyte.String + if !nameList.ReadUint8(&nameType) || + !nameList.ReadUint16LengthPrefixed(&serverName) || + serverName.Empty() { + return errInvalidSNIFormat + } + if nameType != serverNameTypeDNSHostName { + continue + } + if len(s.ServerName) != 0 { + // Multiple names of the same name_type are prohibited. + return errInvalidSNIFormat + } + s.ServerName = string(serverName) + // An SNI value may not include a trailing dot. + if strings.HasSuffix(s.ServerName, ".") { + return errInvalidSNIFormat + } + } + + return nil +} diff --git a/vendor/github.com/pion/dtls/v3/pkg/protocol/extension/signature_algorithms_cert.go b/vendor/github.com/pion/dtls/v3/pkg/protocol/extension/signature_algorithms_cert.go new file mode 100644 index 0000000..f2881ec --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/protocol/extension/signature_algorithms_cert.go @@ -0,0 +1,43 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package extension + +import ( + "github.com/pion/dtls/v3/pkg/crypto/signaturehash" +) + +// SignatureAlgorithmsCert allows a Client/Server to indicate which signature algorithms +// may be used in digital signatures for X.509 certificates. +// This is separate from signature_algorithms which applies to handshake signatures. +// +// RFC 8446 Section 4.2.3: +// "TLS 1.2 implementations SHOULD also process this extension. +// If present, the signature_algorithms_cert extension SHALL be treated as being +// equivalent to signature_algorithms for the purposes of certificate chain validation." +// +// https://tools.ietf.org/html/rfc8446#section-4.2.3 +type SignatureAlgorithmsCert struct { + SignatureHashAlgorithms []signaturehash.Algorithm +} + +// TypeValue returns the extension TypeValue. +func (s SignatureAlgorithmsCert) TypeValue() TypeValue { + return SignatureAlgorithmsCertTypeValue +} + +// Marshal encodes the extension. +// This supports hybrid encoding: TLS 1.3 PSS schemes are encoded as full uint16, +// while TLS 1.2 schemes use hash (high byte) + signature (low byte) encoding. +func (s *SignatureAlgorithmsCert) Marshal() ([]byte, error) { + return marshalGenericSignatureHashAlgorithm(s.TypeValue(), s.SignatureHashAlgorithms) +} + +// Unmarshal populates the extension from encoded data. +// This supports hybrid encoding: detects TLS 1.3 PSS schemes +// and handles them as full uint16, while TLS 1.2 schemes use byte-split encoding. +func (s *SignatureAlgorithmsCert) Unmarshal(data []byte) error { + s.SignatureHashAlgorithms = []signaturehash.Algorithm{} + + return unmarshalGenericSignatureHashAlgorithm(s.TypeValue(), data, &s.SignatureHashAlgorithms) +} diff --git a/vendor/github.com/pion/dtls/v3/pkg/protocol/extension/srtp_protection_profile.go b/vendor/github.com/pion/dtls/v3/pkg/protocol/extension/srtp_protection_profile.go new file mode 100644 index 0000000..07510a2 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/protocol/extension/srtp_protection_profile.go @@ -0,0 +1,32 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package extension + +// SRTPProtectionProfile defines the parameters and options that are in effect for the SRTP processing +// https://tools.ietf.org/html/rfc5764#section-4.1.2 +type SRTPProtectionProfile uint16 + +const ( + SRTP_AES128_CM_HMAC_SHA1_80 SRTPProtectionProfile = 0x0001 // nolint + SRTP_AES128_CM_HMAC_SHA1_32 SRTPProtectionProfile = 0x0002 // nolint + SRTP_AES256_CM_SHA1_80 SRTPProtectionProfile = 0x0003 // nolint + SRTP_AES256_CM_SHA1_32 SRTPProtectionProfile = 0x0004 // nolint + SRTP_NULL_HMAC_SHA1_80 SRTPProtectionProfile = 0x0005 // nolint + SRTP_NULL_HMAC_SHA1_32 SRTPProtectionProfile = 0x0006 // nolint + SRTP_AEAD_AES_128_GCM SRTPProtectionProfile = 0x0007 // nolint + SRTP_AEAD_AES_256_GCM SRTPProtectionProfile = 0x0008 // nolint +) + +func srtpProtectionProfiles() map[SRTPProtectionProfile]bool { + return map[SRTPProtectionProfile]bool{ + SRTP_AES128_CM_HMAC_SHA1_80: true, + SRTP_AES128_CM_HMAC_SHA1_32: true, + SRTP_AES256_CM_SHA1_80: true, + SRTP_AES256_CM_SHA1_32: true, + SRTP_NULL_HMAC_SHA1_80: true, + SRTP_NULL_HMAC_SHA1_32: true, + SRTP_AEAD_AES_128_GCM: true, + SRTP_AEAD_AES_256_GCM: true, + } +} diff --git a/vendor/github.com/pion/dtls/v3/pkg/protocol/extension/supported_elliptic_curves.go b/vendor/github.com/pion/dtls/v3/pkg/protocol/extension/supported_elliptic_curves.go new file mode 100644 index 0000000..fd0cf61 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/protocol/extension/supported_elliptic_curves.go @@ -0,0 +1,69 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package extension + +import ( + "encoding/binary" + + "github.com/pion/dtls/v3/pkg/crypto/elliptic" +) + +const ( + supportedGroupsHeaderSize = 6 +) + +// SupportedEllipticCurves allows a Client/Server to communicate +// what curves they both support +// +// https://tools.ietf.org/html/rfc8422#section-5.1.1 +// +// In DTLS 1.3, this extension in renamed to "supported_groups". +// https://datatracker.ietf.org/doc/html/rfc8446#section-4.2.7 +type SupportedEllipticCurves struct { + EllipticCurves []elliptic.Curve +} + +// TypeValue returns the extension TypeValue. +func (s SupportedEllipticCurves) TypeValue() TypeValue { + return SupportedEllipticCurvesTypeValue +} + +// Marshal encodes the extension. +func (s *SupportedEllipticCurves) Marshal() ([]byte, error) { + out := make([]byte, supportedGroupsHeaderSize) + + binary.BigEndian.PutUint16(out, uint16(s.TypeValue())) + binary.BigEndian.PutUint16(out[2:], uint16(2+(len(s.EllipticCurves)*2))) //nolint:gosec // G115 + binary.BigEndian.PutUint16(out[4:], uint16(len(s.EllipticCurves)*2)) //nolint:gosec // G115 + + for _, v := range s.EllipticCurves { + out = append(out, []byte{0x00, 0x00}...) //nolint:makezero // todo: fix + binary.BigEndian.PutUint16(out[len(out)-2:], uint16(v)) + } + + return out, nil +} + +// Unmarshal populates the extension from encoded data. +func (s *SupportedEllipticCurves) Unmarshal(data []byte) error { + if len(data) <= supportedGroupsHeaderSize { + return errBufferTooSmall + } else if TypeValue(binary.BigEndian.Uint16(data)) != s.TypeValue() { + return errInvalidExtensionType + } + + groupCount := int(binary.BigEndian.Uint16(data[4:]) / 2) + if supportedGroupsHeaderSize+(groupCount*2) > len(data) { + return errLengthMismatch + } + + for i := 0; i < groupCount; i++ { + supportedGroupID := elliptic.Curve(binary.BigEndian.Uint16(data[(supportedGroupsHeaderSize + (i * 2)):])) + if _, ok := elliptic.Curves()[supportedGroupID]; ok { + s.EllipticCurves = append(s.EllipticCurves, supportedGroupID) + } + } + + return nil +} diff --git a/vendor/github.com/pion/dtls/v3/pkg/protocol/extension/supported_point_formats.go b/vendor/github.com/pion/dtls/v3/pkg/protocol/extension/supported_point_formats.go new file mode 100644 index 0000000..7120727 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/protocol/extension/supported_point_formats.go @@ -0,0 +1,69 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package extension + +import ( + "encoding/binary" + + "github.com/pion/dtls/v3/pkg/crypto/elliptic" +) + +const ( + supportedPointFormatsSize = 5 +) + +// SupportedPointFormats allows a Client/Server to negotiate +// the EllipticCurvePointFormats +// +// https://tools.ietf.org/html/rfc4492#section-5.1.2 +type SupportedPointFormats struct { + PointFormats []elliptic.CurvePointFormat +} + +// TypeValue returns the extension TypeValue. +func (s SupportedPointFormats) TypeValue() TypeValue { + return SupportedPointFormatsTypeValue +} + +// Marshal encodes the extension. +func (s *SupportedPointFormats) Marshal() ([]byte, error) { + out := make([]byte, supportedPointFormatsSize) + + binary.BigEndian.PutUint16(out, uint16(s.TypeValue())) + binary.BigEndian.PutUint16(out[2:], uint16(1+(len(s.PointFormats)))) //nolint:gosec // G115 + out[4] = byte(len(s.PointFormats)) + + for _, v := range s.PointFormats { + out = append(out, byte(v)) //nolint:makezero // todo: fix + } + + return out, nil +} + +// Unmarshal populates the extension from encoded data. +func (s *SupportedPointFormats) Unmarshal(data []byte) error { + if len(data) <= supportedPointFormatsSize { + return errBufferTooSmall + } + + if TypeValue(binary.BigEndian.Uint16(data)) != s.TypeValue() { + return errInvalidExtensionType + } + + pointFormatCount := int(data[4]) + if supportedPointFormatsSize+pointFormatCount > len(data) { + return errLengthMismatch + } + + for i := 0; i < pointFormatCount; i++ { + p := elliptic.CurvePointFormat(data[supportedPointFormatsSize+i]) + switch p { + case elliptic.CurvePointFormatUncompressed: + s.PointFormats = append(s.PointFormats, p) + default: + } + } + + return nil +} diff --git a/vendor/github.com/pion/dtls/v3/pkg/protocol/extension/supported_signature_algorithms.go b/vendor/github.com/pion/dtls/v3/pkg/protocol/extension/supported_signature_algorithms.go new file mode 100644 index 0000000..de4efe6 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/protocol/extension/supported_signature_algorithms.go @@ -0,0 +1,37 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package extension + +import ( + "github.com/pion/dtls/v3/pkg/crypto/signaturehash" +) + +// SupportedSignatureAlgorithms allows a Client/Server to +// negotiate what SignatureHash Algorithms they both support +// +// https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1 +type SupportedSignatureAlgorithms struct { + SignatureHashAlgorithms []signaturehash.Algorithm +} + +// TypeValue returns the extension TypeValue. +func (s SupportedSignatureAlgorithms) TypeValue() TypeValue { + return SupportedSignatureAlgorithmsTypeValue +} + +// Marshal encodes the extension. +// This supports hybrid encoding: TLS 1.3 PSS schemes are encoded as full uint16, +// while TLS 1.2 schemes use hash (high byte) + signature (low byte) encoding. +func (s *SupportedSignatureAlgorithms) Marshal() ([]byte, error) { + return marshalGenericSignatureHashAlgorithm(s.TypeValue(), s.SignatureHashAlgorithms) +} + +// Unmarshal populates the extension from encoded data. +// This supports hybrid encoding: detects TLS 1.3 PSS schemes +// and handles them as full uint16, while TLS 1.2 schemes use byte-split encoding. +func (s *SupportedSignatureAlgorithms) Unmarshal(data []byte) error { + s.SignatureHashAlgorithms = []signaturehash.Algorithm{} + + return unmarshalGenericSignatureHashAlgorithm(s.TypeValue(), data, &s.SignatureHashAlgorithms) +} diff --git a/vendor/github.com/pion/dtls/v3/pkg/protocol/extension/supported_versions.go b/vendor/github.com/pion/dtls/v3/pkg/protocol/extension/supported_versions.go new file mode 100644 index 0000000..1145769 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/protocol/extension/supported_versions.go @@ -0,0 +1,129 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package extension + +import ( + "github.com/pion/dtls/v3/pkg/protocol" + "golang.org/x/crypto/cryptobyte" +) + +// SupportedVersions is a TLS extension used by the client to indicate +// which versions of TLS it supports and by the server to indicate which +// version it is using. +// +// https://datatracker.ietf.org/doc/html/rfc8446#section-4.2.1 +type SupportedVersions struct { + // ClientHello's preference-ordered list. + Versions []protocol.Version +} + +func (s SupportedVersions) TypeValue() TypeValue { return SupportedVersionsTypeValue } + +// Marshal encodes the extension without carrying negotiation state. +func (s *SupportedVersions) Marshal() ([]byte, error) { + if len(s.Versions) == 0 { + return nil, errInvalidSupportedVersionsFormat + } + + totalBytes := len(s.Versions) * 2 + + // The 2..254 bound is defined in the following: + // https://datatracker.ietf.org/doc/html/rfc8446#section-4.2.1 + if totalBytes < 2 || totalBytes > 254 { + return nil, errInvalidSupportedVersionsFormat + } + + // We're only checking for *valid* versions, not to be confused with supported versions. + // Error on invalid versions to protect against malformed messages/DOS attacks. + for _, v := range s.Versions { + if !protocol.IsValidVersion(v) { + return nil, errInvalidDTLSVersion + } + } + + var builder cryptobyte.Builder + + builder.AddUint16(uint16(s.TypeValue())) + builder.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + if len(s.Versions) == 1 { + // in the case that there's only one version, the do not add the length (uint8). + b.AddUint8(s.Versions[0].Major) + b.AddUint8(s.Versions[0].Minor) + + return + } + + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + for _, v := range s.Versions { + b.AddUint8(v.Major) + b.AddUint8(v.Minor) + } + }) + }) + + return builder.Bytes() +} + +// Unmarshal parses either the ClientHello list or the ServerHello/HelloRetryRequest single value. +// Any version not recognized is discarded. +func (s *SupportedVersions) Unmarshal(data []byte) error { //nolint:cyclop + val := cryptobyte.String(data) + var extData cryptobyte.String + + var extension uint16 + val.ReadUint16(&extension) + if TypeValue(extension) != s.TypeValue() { + return errInvalidExtensionType + } + + if !val.ReadUint16LengthPrefixed(&extData) { + return errBufferTooSmall + } + + if extData.Empty() { + return errInvalidSupportedVersionsFormat + } + + // Try ClientHello list: versions<2..254> (1-byte length, then pairs) + peek := extData + var listLen uint8 + if peek.ReadUint8(&listLen) && int(listLen) == len(peek) && listLen >= 2 && (listLen%2) == 0 { + s.Versions = s.Versions[:0] + + for !peek.Empty() { + var major, minor uint8 + if !peek.ReadUint8(&major) || !peek.ReadUint8(&minor) { + return errInvalidSupportedVersionsFormat + } + + // We're only checking for *valid* versions, not to be confused with supported versions. + if protocol.IsValidBytes(major, minor) { + s.Versions = append(s.Versions, protocol.Version{Major: major, Minor: minor}) + } + } + + if !extData.Skip(1 + int(listLen)) { + return errInvalidSupportedVersionsFormat + } + + return nil + } + + // Otherwise, expect ServerHello/HelloRetryRequest selected_version, which should be exactly 2 bytes. + if len(extData) != 2 { + return errInvalidSupportedVersionsFormat + } + + var major, minor uint8 + if !extData.ReadUint8(&major) || !extData.ReadUint8(&minor) { + return errInvalidSupportedVersionsFormat + } + + // We're only checking for *valid* versions, not to be confused with supported versions. + if protocol.IsValidBytes(major, minor) { + s.Versions = append(s.Versions[:0], protocol.Version{Major: major, Minor: minor}) + } + + return nil +} diff --git a/vendor/github.com/pion/dtls/v3/pkg/protocol/extension/use_master_secret.go b/vendor/github.com/pion/dtls/v3/pkg/protocol/extension/use_master_secret.go new file mode 100644 index 0000000..1fd4409 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/protocol/extension/use_master_secret.go @@ -0,0 +1,49 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package extension + +import "encoding/binary" + +const ( + useExtendedMasterSecretHeaderSize = 4 +) + +// UseExtendedMasterSecret defines a TLS extension that contextually binds the +// master secret to a log of the full handshake that computes it, thus +// preventing MITM attacks. +type UseExtendedMasterSecret struct { + Supported bool +} + +// TypeValue returns the extension TypeValue. +func (u UseExtendedMasterSecret) TypeValue() TypeValue { + return UseExtendedMasterSecretTypeValue +} + +// Marshal encodes the extension. +func (u *UseExtendedMasterSecret) Marshal() ([]byte, error) { + if !u.Supported { + return []byte{}, nil + } + + out := make([]byte, useExtendedMasterSecretHeaderSize) + + binary.BigEndian.PutUint16(out, uint16(u.TypeValue())) + binary.BigEndian.PutUint16(out[2:], uint16(0)) // length + + return out, nil +} + +// Unmarshal populates the extension from encoded data. +func (u *UseExtendedMasterSecret) Unmarshal(data []byte) error { + if len(data) < useExtendedMasterSecretHeaderSize { + return errBufferTooSmall + } else if TypeValue(binary.BigEndian.Uint16(data)) != u.TypeValue() { + return errInvalidExtensionType + } + + u.Supported = true + + return nil +} diff --git a/vendor/github.com/pion/dtls/v3/pkg/protocol/extension/use_srtp.go b/vendor/github.com/pion/dtls/v3/pkg/protocol/extension/use_srtp.go new file mode 100644 index 0000000..1cae27e --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/protocol/extension/use_srtp.go @@ -0,0 +1,86 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package extension + +import ( + "encoding/binary" +) + +const ( + useSRTPHeaderSize = 6 +) + +// UseSRTP allows a Client/Server to negotiate what SRTPProtectionProfiles +// they both support +// +// https://tools.ietf.org/html/rfc8422 +type UseSRTP struct { + ProtectionProfiles []SRTPProtectionProfile + MasterKeyIdentifier []byte +} + +// TypeValue returns the extension TypeValue. +func (u UseSRTP) TypeValue() TypeValue { + return UseSRTPTypeValue +} + +// Marshal encodes the extension. +func (u *UseSRTP) Marshal() ([]byte, error) { + out := make([]byte, useSRTPHeaderSize) + + binary.BigEndian.PutUint16(out, uint16(u.TypeValue())) + //nolint:gosec // G115 + binary.BigEndian.PutUint16( + out[2:], + uint16(2+(len(u.ProtectionProfiles)*2)+ /* MKI Length */ 1+len(u.MasterKeyIdentifier)), + ) + binary.BigEndian.PutUint16(out[4:], uint16(len(u.ProtectionProfiles)*2)) //nolint:gosec // G115 + + for _, v := range u.ProtectionProfiles { + out = append(out, []byte{0x00, 0x00}...) //nolint:makezero // todo: fix + binary.BigEndian.PutUint16(out[len(out)-2:], uint16(v)) + } + if len(u.MasterKeyIdentifier) > 255 { + return nil, errMasterKeyIdentifierTooLarge + } + + out = append(out, byte(len(u.MasterKeyIdentifier))) //nolint:makezero // todo: fix + out = append(out, u.MasterKeyIdentifier...) //nolint:makezero // todo: fix + + return out, nil +} + +// Unmarshal populates the extension from encoded data. +func (u *UseSRTP) Unmarshal(data []byte) error { + if len(data) <= useSRTPHeaderSize { + return errBufferTooSmall + } else if TypeValue(binary.BigEndian.Uint16(data)) != u.TypeValue() { + return errInvalidExtensionType + } + + profileCount := int(binary.BigEndian.Uint16(data[4:]) / 2) + masterKeyIdentifierIndex := supportedGroupsHeaderSize + (profileCount * 2) + if masterKeyIdentifierIndex+1 > len(data) { + return errLengthMismatch + } + + for i := 0; i < profileCount; i++ { + supportedProfile := SRTPProtectionProfile(binary.BigEndian.Uint16(data[(useSRTPHeaderSize + (i * 2)):])) + if _, ok := srtpProtectionProfiles()[supportedProfile]; ok { + u.ProtectionProfiles = append(u.ProtectionProfiles, supportedProfile) + } + } + + masterKeyIdentifierLen := int(data[masterKeyIdentifierIndex]) + if masterKeyIdentifierIndex+masterKeyIdentifierLen >= len(data) { + return errLengthMismatch + } + + u.MasterKeyIdentifier = append( + []byte{}, + data[masterKeyIdentifierIndex+1:masterKeyIdentifierIndex+1+masterKeyIdentifierLen]..., + ) + + return nil +} diff --git a/vendor/github.com/pion/dtls/v3/pkg/protocol/handshake/cipher_suite.go b/vendor/github.com/pion/dtls/v3/pkg/protocol/handshake/cipher_suite.go new file mode 100644 index 0000000..3b9fbe4 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/protocol/handshake/cipher_suite.go @@ -0,0 +1,34 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package handshake + +import "encoding/binary" + +func decodeCipherSuiteIDs(buf []byte) ([]uint16, error) { + if len(buf) < 2 { + return nil, errBufferTooSmall + } + cipherSuitesCount := int(binary.BigEndian.Uint16(buf[0:])) / 2 + rtrn := make([]uint16, cipherSuitesCount) + for i := 0; i < cipherSuitesCount; i++ { + if len(buf) < (i*2 + 4) { + return nil, errBufferTooSmall + } + + rtrn[i] = binary.BigEndian.Uint16(buf[(i*2)+2:]) + } + + return rtrn, nil +} + +func encodeCipherSuiteIDs(cipherSuiteIDs []uint16) []byte { + out := []byte{0x00, 0x00} + binary.BigEndian.PutUint16(out[len(out)-2:], uint16(len(cipherSuiteIDs)*2)) //nolint:gosec // G115 + for _, id := range cipherSuiteIDs { + out = append(out, []byte{0x00, 0x00}...) + binary.BigEndian.PutUint16(out[len(out)-2:], id) + } + + return out +} diff --git a/vendor/github.com/pion/dtls/v3/pkg/protocol/handshake/errors.go b/vendor/github.com/pion/dtls/v3/pkg/protocol/handshake/errors.go new file mode 100644 index 0000000..8ac4621 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/protocol/handshake/errors.go @@ -0,0 +1,74 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package handshake + +import ( + "errors" + + "github.com/pion/dtls/v3/pkg/protocol" +) + +// Typed errors. +var ( + errUnableToMarshalFragmented = &protocol.InternalError{ + Err: errors.New("unable to marshal fragmented handshakes"), //nolint:err113 + } + errHandshakeMessageUnset = &protocol.InternalError{ + Err: errors.New("handshake message unset, unable to marshal"), //nolint:err113 + } + errBufferTooSmall = &protocol.TemporaryError{ + Err: errors.New("buffer is too small"), //nolint:err113 + } + errLengthMismatch = &protocol.InternalError{ + Err: errors.New("data length and declared length do not match"), //nolint:err113 + } + errInvalidClientKeyExchange = &protocol.FatalError{ + Err: errors.New("unable to determine if ClientKeyExchange is a public key or PSK Identity"), //nolint:err113 + } + errInvalidHashAlgorithm = &protocol.FatalError{ + Err: errors.New("invalid hash algorithm"), //nolint:err113 + } + errInvalidSignatureAlgorithm = &protocol.FatalError{ + Err: errors.New("invalid signature algorithm"), //nolint:err113 + } + errCookieTooLong = &protocol.FatalError{ + Err: errors.New("cookie must not be longer then 255 bytes"), //nolint:err113 + } + errInvalidEllipticCurveType = &protocol.FatalError{ + Err: errors.New("invalid or unknown elliptic curve type"), //nolint:err113 + } + errInvalidNamedCurve = &protocol.FatalError{ + Err: errors.New("invalid named curve"), //nolint:err113 + } + errCipherSuiteUnset = &protocol.FatalError{ + Err: errors.New("server hello can not be created without a cipher suite"), //nolint:err113 + } + errCompressionMethodUnset = &protocol.FatalError{ + Err: errors.New("server hello can not be created without a compression method"), //nolint:err113 + } + errInvalidCompressionMethod = &protocol.FatalError{ + Err: errors.New("invalid or unknown compression method"), //nolint:err113 + } + errNotImplemented = &protocol.InternalError{ + Err: errors.New("feature has not been implemented yet"), //nolint:err113 + } + errInvalidCertificateRequestContext = &protocol.FatalError{ + Err: errors.New("invalid certificate request context"), //nolint:err113 + } + errInvalidCertificateEntry = &protocol.FatalError{ + Err: errors.New("invalid certificate entry"), //nolint:err113 + } + errCertificateRequestContextTooLong = &protocol.FatalError{ + Err: errors.New("certificate request context must not be longer than 255 bytes"), //nolint:err113 + } + errCertificateListTooLong = &protocol.FatalError{ + Err: errors.New("certificate list must not be longer than 2^24-1 bytes"), //nolint:err113 + } + errInvalidExtensionsLength = &protocol.FatalError{ + Err: errors.New("extensions data must be between 2 and 2^16-1 bytes"), //nolint:err113 + } + errMissingSignatureAlgorithmsExtension = &protocol.FatalError{ + Err: errors.New("signature_algorithms extension is required in CertificateRequest"), //nolint:err113 + } +) diff --git a/vendor/github.com/pion/dtls/v3/pkg/protocol/handshake/handshake.go b/vendor/github.com/pion/dtls/v3/pkg/protocol/handshake/handshake.go new file mode 100644 index 0000000..6d40a5a --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/protocol/handshake/handshake.go @@ -0,0 +1,152 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +// Package handshake provides the DTLS wire protocol for handshakes +package handshake + +import ( + "github.com/pion/dtls/v3/internal/ciphersuite/types" + "github.com/pion/dtls/v3/internal/util" + "github.com/pion/dtls/v3/pkg/protocol" +) + +// Type is the unique identifier for each handshake message +// https://tools.ietf.org/html/rfc5246#section-7.4 +type Type uint8 + +// Types of DTLS Handshake messages we know about. +const ( + TypeHelloRequest Type = 0 + TypeClientHello Type = 1 + TypeServerHello Type = 2 + TypeHelloVerifyRequest Type = 3 + TypeCertificate Type = 11 + TypeServerKeyExchange Type = 12 + TypeCertificateRequest Type = 13 + TypeServerHelloDone Type = 14 + TypeCertificateVerify Type = 15 + TypeClientKeyExchange Type = 16 + TypeFinished Type = 20 +) + +// String returns the string representation of this type. +func (t Type) String() string { //nolint:cyclop + switch t { + case TypeHelloRequest: + return "HelloRequest" + case TypeClientHello: + return "ClientHello" + case TypeServerHello: + return "ServerHello" + case TypeHelloVerifyRequest: + return "HelloVerifyRequest" + case TypeCertificate: + return "TypeCertificate" + case TypeServerKeyExchange: + return "ServerKeyExchange" + case TypeCertificateRequest: + return "CertificateRequest" + case TypeServerHelloDone: + return "ServerHelloDone" + case TypeCertificateVerify: + return "CertificateVerify" + case TypeClientKeyExchange: + return "ClientKeyExchange" + case TypeFinished: + return "Finished" + } + + return "" +} + +// Message is the body of a Handshake datagram. +type Message interface { + Marshal() ([]byte, error) + Unmarshal(data []byte) error + Type() Type +} + +// Handshake protocol is responsible for selecting a cipher spec and +// generating a master secret, which together comprise the primary +// cryptographic parameters associated with a secure session. The +// handshake protocol can also optionally authenticate parties who have +// certificates signed by a trusted certificate authority. +// https://tools.ietf.org/html/rfc5246#section-7.3 +type Handshake struct { + Header Header + Message Message + + KeyExchangeAlgorithm types.KeyExchangeAlgorithm +} + +// ContentType returns what kind of content this message is carying. +func (h Handshake) ContentType() protocol.ContentType { + return protocol.ContentTypeHandshake +} + +// Marshal encodes a handshake into a binary message. +func (h *Handshake) Marshal() ([]byte, error) { + if h.Message == nil { + return nil, errHandshakeMessageUnset + } else if h.Header.FragmentOffset != 0 { + return nil, errUnableToMarshalFragmented + } + + msg, err := h.Message.Marshal() + if err != nil { + return nil, err + } + + h.Header.Length = uint32(len(msg)) //nolint:gosec // G115 + h.Header.FragmentLength = h.Header.Length + h.Header.Type = h.Message.Type() + header, err := h.Header.Marshal() + if err != nil { + return nil, err + } + + return append(header, msg...), nil +} + +// Unmarshal decodes a handshake from a binary message. +func (h *Handshake) Unmarshal(data []byte) error { //nolint:cyclop + if err := h.Header.Unmarshal(data); err != nil { + return err + } + + reportedLen := util.BigEndianUint24(data[1:]) + if uint32(len(data)-HeaderLength) != reportedLen { //nolint:gosec // G115 + return errLengthMismatch + } else if reportedLen != h.Header.FragmentLength { + return errLengthMismatch + } + + switch Type(data[0]) { + case TypeHelloRequest: + return errNotImplemented + case TypeClientHello: + h.Message = &MessageClientHello{} + case TypeHelloVerifyRequest: + h.Message = &MessageHelloVerifyRequest{} + case TypeServerHello: + h.Message = &MessageServerHello{} + case TypeCertificate: + h.Message = &MessageCertificate{} + case TypeServerKeyExchange: + h.Message = &MessageServerKeyExchange{KeyExchangeAlgorithm: h.KeyExchangeAlgorithm} + case TypeCertificateRequest: + h.Message = &MessageCertificateRequest{} + case TypeServerHelloDone: + h.Message = &MessageServerHelloDone{} + case TypeClientKeyExchange: + h.Message = &MessageClientKeyExchange{KeyExchangeAlgorithm: h.KeyExchangeAlgorithm} + case TypeFinished: + h.Message = &MessageFinished{} + case TypeCertificateVerify: + h.Message = &MessageCertificateVerify{} + default: + return errNotImplemented + } + + return h.Message.Unmarshal(data[HeaderLength:]) +} diff --git a/vendor/github.com/pion/dtls/v3/pkg/protocol/handshake/header.go b/vendor/github.com/pion/dtls/v3/pkg/protocol/handshake/header.go new file mode 100644 index 0000000..befeed1 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/protocol/handshake/header.go @@ -0,0 +1,55 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package handshake + +import ( + "encoding/binary" + + "github.com/pion/dtls/v3/internal/util" +) + +// HeaderLength msg_len for Handshake messages assumes an extra +// 12 bytes for sequence, fragment and version information vs TLS. +const HeaderLength = 12 + +// Header is the static first 12 bytes of each RecordLayer +// of type Handshake. These fields allow us to support message loss, reordering, and +// message fragmentation, +// +// https://tools.ietf.org/html/rfc6347#section-4.2.2 +type Header struct { + Type Type + Length uint32 // uint24 in spec + MessageSequence uint16 + FragmentOffset uint32 // uint24 in spec + FragmentLength uint32 // uint24 in spec +} + +// Marshal encodes the Header. +func (h *Header) Marshal() ([]byte, error) { + out := make([]byte, HeaderLength) + + out[0] = byte(h.Type) + util.PutBigEndianUint24(out[1:], h.Length) + binary.BigEndian.PutUint16(out[4:], h.MessageSequence) + util.PutBigEndianUint24(out[6:], h.FragmentOffset) + util.PutBigEndianUint24(out[9:], h.FragmentLength) + + return out, nil +} + +// Unmarshal populates the header from encoded data. +func (h *Header) Unmarshal(data []byte) error { + if len(data) < HeaderLength { + return errBufferTooSmall + } + + h.Type = Type(data[0]) + h.Length = util.BigEndianUint24(data[1:]) + h.MessageSequence = binary.BigEndian.Uint16(data[4:]) + h.FragmentOffset = util.BigEndianUint24(data[6:]) + h.FragmentLength = util.BigEndianUint24(data[9:]) + + return nil +} diff --git a/vendor/github.com/pion/dtls/v3/pkg/protocol/handshake/message_certificate.go b/vendor/github.com/pion/dtls/v3/pkg/protocol/handshake/message_certificate.go new file mode 100644 index 0000000..8da5237 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/protocol/handshake/message_certificate.go @@ -0,0 +1,74 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package handshake + +import ( + "github.com/pion/dtls/v3/internal/util" +) + +// MessageCertificate is a DTLS Handshake Message +// it can contain either a Client or Server Certificate +// +// https://tools.ietf.org/html/rfc5246#section-7.4.2 +type MessageCertificate struct { + Certificate [][]byte +} + +// Type returns the Handshake Type. +func (m MessageCertificate) Type() Type { + return TypeCertificate +} + +const ( + handshakeMessageCertificateLengthFieldSize = 3 +) + +// Marshal encodes the Handshake. +func (m *MessageCertificate) Marshal() ([]byte, error) { + out := make([]byte, handshakeMessageCertificateLengthFieldSize) + + for _, r := range m.Certificate { + // Certificate Length + //nolint:makezero // todo: fix + out = append(out, make([]byte, handshakeMessageCertificateLengthFieldSize)...) + //nolint:gosec // G115 + util.PutBigEndianUint24(out[len(out)-handshakeMessageCertificateLengthFieldSize:], uint32(len(r))) + + // Certificate body + out = append(out, append([]byte{}, r...)...) //nolint:makezero // todo: fix + } + + // Total Payload Size + util.PutBigEndianUint24(out[0:], uint32(len(out[handshakeMessageCertificateLengthFieldSize:]))) //nolint:gosec //G115 + + return out, nil +} + +// Unmarshal populates the message from encoded data. +func (m *MessageCertificate) Unmarshal(data []byte) error { + if len(data) < handshakeMessageCertificateLengthFieldSize { + return errBufferTooSmall + } + + if certificateBodyLen := int(util.BigEndianUint24( + data, + )); certificateBodyLen+handshakeMessageCertificateLengthFieldSize != len(data) { + return errLengthMismatch + } + + offset := handshakeMessageCertificateLengthFieldSize + for offset < len(data) { + certificateLen := int(util.BigEndianUint24(data[offset:])) + offset += handshakeMessageCertificateLengthFieldSize + + if offset+certificateLen > len(data) { + return errLengthMismatch + } + + m.Certificate = append(m.Certificate, append([]byte{}, data[offset:offset+certificateLen]...)) + offset += certificateLen + } + + return nil +} diff --git a/vendor/github.com/pion/dtls/v3/pkg/protocol/handshake/message_certificate_13.go b/vendor/github.com/pion/dtls/v3/pkg/protocol/handshake/message_certificate_13.go new file mode 100644 index 0000000..36e7b4d --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/protocol/handshake/message_certificate_13.go @@ -0,0 +1,198 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package handshake + +import ( + "encoding/binary" + + "github.com/pion/dtls/v3/internal/util" + "github.com/pion/dtls/v3/pkg/protocol/extension" + "golang.org/x/crypto/cryptobyte" +) + +// CertificateEntry13 represents a single certificate entry in the DTLS 1.3 Certificate message. +// Each entry contains certificate data and optional per-certificate extensions. +// +// https://datatracker.ietf.org/doc/html/rfc8446#section-4.4.2 +type CertificateEntry13 struct { + // CertificateData contains the DER-encoded X.509 certificate. + // Can be empty for certain contexts (e.g., RawPublicKey mode). + CertificateData []byte + + // Extensions contains per-certificate extensions. + // Examples: OCSP status, SignedCertificateTimestamp, etc. + Extensions []extension.Extension +} + +// MessageCertificate13 represents the Certificate handshake message for DTLS 1.3. +// This message is used to transport the certificate chain and associated extensions. +// +// https://datatracker.ietf.org/doc/html/rfc8446#section-4.4.2 +type MessageCertificate13 struct { + // CertificateRequestContext is an opaque value that binds this certificate + // to a specific CertificateRequest (for client certificates) or is empty + // for server certificates. + CertificateRequestContext []byte + + // CertificateList contains the certificate chain with each entry having + // optional per-certificate extensions. + CertificateList []CertificateEntry13 +} + +// Type returns the handshake message type. +func (m MessageCertificate13) Type() Type { + return TypeCertificate +} + +const ( + maxUint24 = 0xffffff + cert13ContextLengthFieldSize = 1 + cert13ContextMaxLength = 255 + cert13CertLengthFieldSize = 3 + cert13ExtLengthFieldSize = 2 +) + +// Marshal encodes the MessageCertificate13 into its wire format. +// +// Wire format: +// +// [1 byte] certificate_request_context length +// [0-255] certificate_request_context data +// [3 bytes] certificate_list length +// For each certificate: +// [3 bytes] cert_data length +// [variable] cert_data (DER certificate) +// [2 bytes] extensions length (from extension.Marshal) +// [variable] extensions data +func (m *MessageCertificate13) Marshal() ([]byte, error) { + // Validate certificate_request_context length + if len(m.CertificateRequestContext) > cert13ContextMaxLength { + return nil, errCertificateRequestContextTooLong + } + + // Start with certificate_request_context (1-byte length prefix) + out := []byte{byte(len(m.CertificateRequestContext))} + out = append(out, m.CertificateRequestContext...) + + // Build certificate_list + certificateList := []byte{} + for _, entry := range m.CertificateList { + // Add cert_data as a 3-byte length prefix + certDataLen := len(entry.CertificateData) + if certDataLen == 0 || certDataLen > maxUint24 { + return nil, errInvalidCertificateEntry + } + certDataLenBytes := make([]byte, cert13CertLengthFieldSize) + util.PutBigEndianUint24(certDataLenBytes, uint32(certDataLen)) //nolint:gosec // G115 + certificateList = append(certificateList, certDataLenBytes...) + certificateList = append(certificateList, entry.CertificateData...) + + // Marshal extensions (includes a 2-byte length prefix) + extensionsData, err := extension.Marshal(entry.Extensions) + if err != nil { + return nil, err + } + certificateList = append(certificateList, extensionsData...) + + // Check size of certificate_list is still within bounds + if len(certificateList) > maxUint24 { + return nil, errCertificateListTooLong + } + } + + // Add certificate_list with 3-byte length prefix + certificateListLenBytes := make([]byte, cert13CertLengthFieldSize) + util.PutBigEndianUint24(certificateListLenBytes, uint32(len(certificateList))) //nolint:gosec // G115 + out = append(out, certificateListLenBytes...) + out = append(out, certificateList...) + + return out, nil +} + +// parseCertificate13Entry parses a single certificate entry from the cryptobyte string. +func parseCertificate13Entry(str *cryptobyte.String) (*CertificateEntry13, error) { + // Read cert_data with 3-byte length prefix + var certData cryptobyte.String + if !str.ReadUint24LengthPrefixed(&certData) { + return nil, errInvalidCertificateEntry + } + + // Validate cert_data length is in valid range <1..2^24-1> + if len(certData) == 0 { + return nil, errInvalidCertificateEntry + } + + // Copy cert_data to avoid aliasing issues + certDataBytes := make([]byte, len(certData)) + copy(certDataBytes, certData) + + // Validate extensions length (2-byte length prefix + up to 2^16-1 bytes of data) + if len(*str) < cert13ExtLengthFieldSize { + return nil, errInvalidCertificateEntry + } + + // Read extensions length to validate we have enough data + extensionsLen := binary.BigEndian.Uint16([]byte(*str)[:cert13ExtLengthFieldSize]) + if len(*str) < cert13ExtLengthFieldSize+int(extensionsLen) { + return nil, errInvalidCertificateEntry + } + + // Unmarshal extensions data + extensionsData := []byte(*str)[:cert13ExtLengthFieldSize+int(extensionsLen)] + extensions, err := extension.Unmarshal(extensionsData) + if err != nil { + return nil, err + } + + // Advance the cryptobyte.String's position + if !str.Skip(cert13ExtLengthFieldSize + int(extensionsLen)) { + return nil, errInvalidCertificateEntry + } + + return &CertificateEntry13{ + CertificateData: certDataBytes, + Extensions: extensions, + }, nil +} + +// Unmarshal decodes the MessageCertificate13 from its wire format. +func (m *MessageCertificate13) Unmarshal(data []byte) error { + // Validate minimum data length + if len(data) < cert13ContextLengthFieldSize+cert13CertLengthFieldSize { + return errBufferTooSmall + } + + str := cryptobyte.String(data) + + // Read certificate_request_context with 1-byte length prefix + var contextData cryptobyte.String + if !str.ReadUint8LengthPrefixed(&contextData) { + return errInvalidCertificateRequestContext + } + m.CertificateRequestContext = make([]byte, len(contextData)) + copy(m.CertificateRequestContext, contextData) + + // Read certificate_list with 3-byte length prefix + var certificateListData cryptobyte.String + if !str.ReadUint24LengthPrefixed(&certificateListData) { + return errInvalidCertificateEntry + } + + // Ensure no trailing data + if len(str) != 0 { + return errLengthMismatch + } + + // Parse certificate_list + m.CertificateList = []CertificateEntry13{} + for len(certificateListData) > 0 { + entry, err := parseCertificate13Entry(&certificateListData) + if err != nil { + return err + } + m.CertificateList = append(m.CertificateList, *entry) + } + + return nil +} diff --git a/vendor/github.com/pion/dtls/v3/pkg/protocol/handshake/message_certificate_request.go b/vendor/github.com/pion/dtls/v3/pkg/protocol/handshake/message_certificate_request.go new file mode 100644 index 0000000..0efae1e --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/protocol/handshake/message_certificate_request.go @@ -0,0 +1,145 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package handshake + +import ( + "encoding/binary" + + "github.com/pion/dtls/v3/pkg/crypto/clientcertificate" + "github.com/pion/dtls/v3/pkg/crypto/hash" + "github.com/pion/dtls/v3/pkg/crypto/signature" + "github.com/pion/dtls/v3/pkg/crypto/signaturehash" +) + +/* +MessageCertificateRequest is so a non-anonymous server can optionally +request a certificate from the client, if appropriate for the selected cipher +suite. This message, if sent, will immediately follow the ServerKeyExchange +message (if it is sent; otherwise, this message follows the +server's Certificate message). + +https://tools.ietf.org/html/rfc5246#section-7.4.4 +*/ +type MessageCertificateRequest struct { + CertificateTypes []clientcertificate.Type + SignatureHashAlgorithms []signaturehash.Algorithm + CertificateAuthoritiesNames [][]byte +} + +const ( + messageCertificateRequestMinLength = 5 +) + +// Type returns the Handshake Type. +func (m MessageCertificateRequest) Type() Type { + return TypeCertificateRequest +} + +// Marshal encodes the Handshake. +func (m *MessageCertificateRequest) Marshal() ([]byte, error) { + out := []byte{byte(len(m.CertificateTypes))} + for _, v := range m.CertificateTypes { + out = append(out, byte(v)) + } + + out = append(out, []byte{0x00, 0x00}...) + binary.BigEndian.PutUint16(out[len(out)-2:], uint16(len(m.SignatureHashAlgorithms)*2)) //nolint:gosec //G115 + for _, v := range m.SignatureHashAlgorithms { + out = append(out, byte(v.Hash)) + out = append(out, byte(v.Signature)) + } + + // Distinguished Names + casLength := 0 + for _, ca := range m.CertificateAuthoritiesNames { + casLength += len(ca) + 2 + } + out = append(out, []byte{0x00, 0x00}...) + binary.BigEndian.PutUint16(out[len(out)-2:], uint16(casLength)) //nolint:gosec //G115 + if casLength > 0 { + for _, ca := range m.CertificateAuthoritiesNames { + out = append(out, []byte{0x00, 0x00}...) + binary.BigEndian.PutUint16(out[len(out)-2:], uint16(len(ca))) //nolint:gosec //G115 + out = append(out, ca...) + } + } + + return out, nil +} + +// Unmarshal populates the message from encoded data. +func (m *MessageCertificateRequest) Unmarshal(data []byte) error { //nolint:cyclop + if len(data) < messageCertificateRequestMinLength { + return errBufferTooSmall + } + + offset := 0 + certificateTypesLength := int(data[0]) + offset++ + + if (offset + certificateTypesLength) > len(data) { + return errBufferTooSmall + } + + for i := 0; i < certificateTypesLength; i++ { + certType := clientcertificate.Type(data[offset+i]) + if _, ok := clientcertificate.Types()[certType]; ok { + m.CertificateTypes = append(m.CertificateTypes, certType) + } + } + offset += certificateTypesLength + if len(data) < offset+2 { + return errBufferTooSmall + } + signatureHashAlgorithmsLength := int(binary.BigEndian.Uint16(data[offset:])) + offset += 2 + + if (offset + signatureHashAlgorithmsLength) > len(data) { + return errBufferTooSmall + } + + for i := 0; i < signatureHashAlgorithmsLength; i += 2 { + if len(data) < (offset + i + 2) { + return errBufferTooSmall + } + h := hash.Algorithm(data[offset+i]) + s := signature.Algorithm(data[offset+i+1]) + + if _, ok := hash.Algorithms()[h]; !ok { + continue + } else if _, ok := signature.Algorithms()[s]; !ok { + continue + } + m.SignatureHashAlgorithms = append(m.SignatureHashAlgorithms, signaturehash.Algorithm{Signature: s, Hash: h}) + } + + offset += signatureHashAlgorithmsLength + if len(data) < offset+2 { + return errBufferTooSmall + } + casLength := int(binary.BigEndian.Uint16(data[offset:])) + offset += 2 + if (offset + casLength) > len(data) { + return errBufferTooSmall + } + cas := make([]byte, casLength) + copy(cas, data[offset:offset+casLength]) + m.CertificateAuthoritiesNames = nil + for len(cas) > 0 { + if len(cas) < 2 { + return errBufferTooSmall + } + caLen := binary.BigEndian.Uint16(cas) + cas = cas[2:] + + if len(cas) < int(caLen) { + return errBufferTooSmall + } + + m.CertificateAuthoritiesNames = append(m.CertificateAuthoritiesNames, cas[:caLen]) + cas = cas[caLen:] + } + + return nil +} diff --git a/vendor/github.com/pion/dtls/v3/pkg/protocol/handshake/message_certificate_request_13.go b/vendor/github.com/pion/dtls/v3/pkg/protocol/handshake/message_certificate_request_13.go new file mode 100644 index 0000000..3563b0e --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/protocol/handshake/message_certificate_request_13.go @@ -0,0 +1,134 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package handshake + +import ( + "encoding/binary" + + "github.com/pion/dtls/v3/pkg/protocol/extension" + "golang.org/x/crypto/cryptobyte" +) + +// MessageCertificateRequest13 represents the CertificateRequest handshake message for DTLS 1.3. +// This message is used by the server to request a certificate from the client. +// +// https://datatracker.ietf.org/doc/html/rfc8446#section-4.3.2 +type MessageCertificateRequest13 struct { + // CertificateRequestContext is an opaque value that the server creates + // to bind the client's certificate to the handshake context. + CertificateRequestContext []byte + + // Extensions contains the list of extensions. + // The signature_algorithms extension is REQUIRED per RFC 8446. + Extensions []extension.Extension +} + +// Type returns the handshake message type. +func (m MessageCertificateRequest13) Type() Type { + return TypeCertificateRequest +} + +const ( + maxUint16 = 0xffff + certReq13ContextMaxLength = 255 + certReq13MinLength = 3 +) + +// Marshal encodes the MessageCertificateRequest13 into its wire format. +// +// Wire format: +// +// [1 byte] certificate_request_context length +// [0-255] certificate_request_context data +// [2 bytes] extensions length (from extension.Marshal) +// [variable] extensions data +func (m *MessageCertificateRequest13) Marshal() ([]byte, error) { + // Validate certificate_request_context length + if len(m.CertificateRequestContext) > certReq13ContextMaxLength { + return nil, errCertificateRequestContextTooLong + } + + // Validate that signature_algorithms extension is present (required by RFC 8446) + hasSignatureAlgorithms := false + for _, ext := range m.Extensions { + if ext.TypeValue() == extension.SupportedSignatureAlgorithmsTypeValue { + hasSignatureAlgorithms = true + + break + } + } + if !hasSignatureAlgorithms { + return nil, errMissingSignatureAlgorithmsExtension + } + + var builder cryptobyte.Builder + + // Add certificate_request_context (1-byte length prefix) + builder.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.CertificateRequestContext) + }) + + // Marshal extensions (includes 2-byte length prefix, like in TLS 1.2) + extensionsData, err := extension.Marshal(m.Extensions) + if err != nil { + return nil, err + } + // Validate extensions length is in valid range <2..2^16-1> + if len(extensionsData) < 2 || len(extensionsData) > maxUint16 { + return nil, errInvalidExtensionsLength + } + builder.AddBytes(extensionsData) + + return builder.Bytes() +} + +// Unmarshal decodes the MessageCertificateRequest13 from its wire format. +func (m *MessageCertificateRequest13) Unmarshal(data []byte) error { + // Validate minimum data length + if len(data) < certReq13MinLength { + return errBufferTooSmall + } + + str := cryptobyte.String(data) + + // Read certificate_request_context + var contextData cryptobyte.String + if !str.ReadUint8LengthPrefixed(&contextData) { + return errInvalidCertificateRequestContext + } + m.CertificateRequestContext = make([]byte, len(contextData)) + copy(m.CertificateRequestContext, contextData) + + // Read extensions length (2 bytes) + if len(str) < 2 { + return errInvalidExtensionsLength + } + extensionsLen := binary.BigEndian.Uint16(str[:2]) + + // Validate we have exactly extensionsLen bytes remaining after the length field + if len(str[2:]) != int(extensionsLen) { + return errLengthMismatch + } + + var err error + m.Extensions, err = extension.Unmarshal([]byte(str)) + if err != nil { + return err + } + + // Validate that signature_algorithms extension is present (required by RFC 8446) + hasSignatureAlgorithms := false + for _, ext := range m.Extensions { + if ext.TypeValue() == extension.SupportedSignatureAlgorithmsTypeValue { + hasSignatureAlgorithms = true + + break + } + } + if !hasSignatureAlgorithms { + return errMissingSignatureAlgorithmsExtension + } + + return nil +} diff --git a/vendor/github.com/pion/dtls/v3/pkg/protocol/handshake/message_certificate_verify.go b/vendor/github.com/pion/dtls/v3/pkg/protocol/handshake/message_certificate_verify.go new file mode 100644 index 0000000..3f53e95 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/protocol/handshake/message_certificate_verify.go @@ -0,0 +1,66 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package handshake + +import ( + "encoding/binary" + + "github.com/pion/dtls/v3/pkg/crypto/hash" + "github.com/pion/dtls/v3/pkg/crypto/signature" +) + +// MessageCertificateVerify provide explicit verification of a +// client certificate. +// +// https://tools.ietf.org/html/rfc5246#section-7.4.8 +type MessageCertificateVerify struct { + HashAlgorithm hash.Algorithm + SignatureAlgorithm signature.Algorithm + Signature []byte +} + +const handshakeMessageCertificateVerifyMinLength = 4 + +// Type returns the Handshake Type. +func (m MessageCertificateVerify) Type() Type { + return TypeCertificateVerify +} + +// Marshal encodes the Handshake. +func (m *MessageCertificateVerify) Marshal() ([]byte, error) { + out := make([]byte, 1+1+2+len(m.Signature)) + + out[0] = byte(m.HashAlgorithm) + out[1] = byte(m.SignatureAlgorithm) + binary.BigEndian.PutUint16(out[2:], uint16(len(m.Signature))) //nolint:gosec // G115 + copy(out[4:], m.Signature) + + return out, nil +} + +// Unmarshal populates the message from encoded data. +func (m *MessageCertificateVerify) Unmarshal(data []byte) error { + if len(data) < handshakeMessageCertificateVerifyMinLength { + return errBufferTooSmall + } + + m.HashAlgorithm = hash.Algorithm(data[0]) + if _, ok := hash.Algorithms()[m.HashAlgorithm]; !ok { + return errInvalidHashAlgorithm + } + + m.SignatureAlgorithm = signature.Algorithm(data[1]) + if _, ok := signature.Algorithms()[m.SignatureAlgorithm]; !ok { + return errInvalidSignatureAlgorithm + } + + signatureLength := int(binary.BigEndian.Uint16(data[2:])) + if (signatureLength + 4) != len(data) { + return errBufferTooSmall + } + + m.Signature = append([]byte{}, data[4:]...) + + return nil +} diff --git a/vendor/github.com/pion/dtls/v3/pkg/protocol/handshake/message_client_hello.go b/vendor/github.com/pion/dtls/v3/pkg/protocol/handshake/message_client_hello.go new file mode 100644 index 0000000..49baeaa --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/protocol/handshake/message_client_hello.go @@ -0,0 +1,141 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package handshake + +import ( + "encoding/binary" + + "github.com/pion/dtls/v3/pkg/protocol" + "github.com/pion/dtls/v3/pkg/protocol/extension" +) + +/* +MessageClientHello is for when a client first connects to a server it is +required to send the client hello as its first message. The client can also send a +client hello in response to a hello request or on its own +initiative in order to renegotiate the security parameters in an +existing connection. +*/ +type MessageClientHello struct { + Version protocol.Version + Random Random + Cookie []byte + + SessionID []byte + + CipherSuiteIDs []uint16 + CompressionMethods []*protocol.CompressionMethod + Extensions []extension.Extension +} + +const handshakeMessageClientHelloVariableWidthStart = 34 + +// Type returns the Handshake Type. +func (m MessageClientHello) Type() Type { + return TypeClientHello +} + +// Marshal encodes the Handshake. +func (m *MessageClientHello) Marshal() ([]byte, error) { + if len(m.Cookie) > 255 { + return nil, errCookieTooLong + } + + out := make([]byte, handshakeMessageClientHelloVariableWidthStart) + out[0], out[1] = m.Version.Major, m.Version.Minor //nolint:gosec // out is initialized with length 34 + + rand := m.Random.MarshalFixed() + copy(out[2:], rand[:]) + + out = append(out, byte(len(m.SessionID))) //nolint:makezero // todo: fix + out = append(out, m.SessionID...) //nolint:makezero // todo: fix + + out = append(out, byte(len(m.Cookie))) //nolint:makezero // todo: fix + out = append(out, m.Cookie...) //nolint:makezero // todo: fix + out = append(out, encodeCipherSuiteIDs(m.CipherSuiteIDs)...) //nolint:makezero // todo: fix + out = append(out, protocol.EncodeCompressionMethods(m.CompressionMethods)...) //nolint:makezero // todo: fix + + extensions, err := extension.Marshal(m.Extensions) + if err != nil { + return nil, err + } + + return append(out, extensions...), nil //nolint:makezero // todo: fix +} + +// Unmarshal populates the message from encoded data. +func (m *MessageClientHello) Unmarshal(data []byte) error { //nolint:cyclop + if len(data) < 2+RandomLength { + return errBufferTooSmall + } + + m.Version.Major = data[0] + m.Version.Minor = data[1] + + var random [RandomLength]byte + copy(random[:], data[2:]) + m.Random.UnmarshalFixed(random) + + // rest of packet has variable width sections + currOffset := handshakeMessageClientHelloVariableWidthStart + + currOffset++ + if len(data) <= currOffset { + return errBufferTooSmall + } + n := int(data[currOffset-1]) + if len(data) <= currOffset+n { + return errBufferTooSmall + } + m.SessionID = append([]byte{}, data[currOffset:currOffset+n]...) + currOffset += len(m.SessionID) + + currOffset++ + if len(data) <= currOffset { + return errBufferTooSmall + } + n = int(data[currOffset-1]) + if len(data) <= currOffset+n { + return errBufferTooSmall + } + m.Cookie = append([]byte{}, data[currOffset:currOffset+n]...) + currOffset += len(m.Cookie) + + // Cipher Suites + if len(data) < currOffset { + return errBufferTooSmall + } + cipherSuiteIDs, err := decodeCipherSuiteIDs(data[currOffset:]) + if err != nil { + return err + } + m.CipherSuiteIDs = cipherSuiteIDs + if len(data) < currOffset+2 { + return errBufferTooSmall + } + currOffset += int(binary.BigEndian.Uint16(data[currOffset:])) + 2 + + // Compression Methods + if len(data) < currOffset { + return errBufferTooSmall + } + compressionMethods, err := protocol.DecodeCompressionMethods(data[currOffset:]) + if err != nil { + return err + } + m.CompressionMethods = compressionMethods + if len(data) < currOffset { + return errBufferTooSmall + } + currOffset += int(data[currOffset]) + 1 + + // Extensions + extensions, err := extension.Unmarshal(data[currOffset:]) + if err != nil { + return err + } + m.Extensions = extensions + + return nil +} diff --git a/vendor/github.com/pion/dtls/v3/pkg/protocol/handshake/message_client_key_exchange.go b/vendor/github.com/pion/dtls/v3/pkg/protocol/handshake/message_client_key_exchange.go new file mode 100644 index 0000000..2208945 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/protocol/handshake/message_client_key_exchange.go @@ -0,0 +1,81 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package handshake + +import ( + "encoding/binary" + + "github.com/pion/dtls/v3/internal/ciphersuite/types" +) + +// MessageClientKeyExchange is a DTLS Handshake Message +// With this message, the premaster secret is set, either by direct +// transmission of the RSA-encrypted secret or by the transmission of +// Diffie-Hellman parameters that will allow each side to agree upon +// the same premaster secret. +// +// https://tools.ietf.org/html/rfc5246#section-7.4.7 +type MessageClientKeyExchange struct { + IdentityHint []byte + PublicKey []byte + + // for unmarshaling + KeyExchangeAlgorithm types.KeyExchangeAlgorithm +} + +// Type returns the Handshake Type. +func (m MessageClientKeyExchange) Type() Type { + return TypeClientKeyExchange +} + +// Marshal encodes the Handshake. +func (m *MessageClientKeyExchange) Marshal() (out []byte, err error) { + if m.IdentityHint == nil && m.PublicKey == nil { + return nil, errInvalidClientKeyExchange + } + + if m.IdentityHint != nil { + out = append([]byte{0x00, 0x00}, m.IdentityHint...) + binary.BigEndian.PutUint16(out, uint16(len(out)-2)) //nolint:gosec // G115 + } + + if m.PublicKey != nil { + out = append(out, byte(len(m.PublicKey))) + out = append(out, m.PublicKey...) + } + + return out, nil +} + +// Unmarshal populates the message from encoded data. +func (m *MessageClientKeyExchange) Unmarshal(data []byte) error { + switch { + case len(data) < 2: + return errBufferTooSmall + case m.KeyExchangeAlgorithm == types.KeyExchangeAlgorithmNone: + return errCipherSuiteUnset + } + + offset := 0 + if m.KeyExchangeAlgorithm.Has(types.KeyExchangeAlgorithmPsk) { + pskLength := int(binary.BigEndian.Uint16(data)) + if pskLength > len(data)-2 { + return errBufferTooSmall + } + + m.IdentityHint = append([]byte{}, data[2:pskLength+2]...) + offset += pskLength + 2 + } + + if m.KeyExchangeAlgorithm.Has(types.KeyExchangeAlgorithmEcdhe) { + publicKeyLength := int(data[offset]) + if publicKeyLength > len(data)-1-offset { + return errBufferTooSmall + } + + m.PublicKey = append([]byte{}, data[offset+1:]...) + } + + return nil +} diff --git a/vendor/github.com/pion/dtls/v3/pkg/protocol/handshake/message_finished.go b/vendor/github.com/pion/dtls/v3/pkg/protocol/handshake/message_finished.go new file mode 100644 index 0000000..6362eac --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/protocol/handshake/message_finished.go @@ -0,0 +1,31 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package handshake + +// MessageFinished is a DTLS Handshake Message +// this message is the first one protected with the just +// negotiated algorithms, keys, and secrets. Recipients of Finished +// messages MUST verify that the contents are correct. +// +// https://tools.ietf.org/html/rfc5246#section-7.4.9 +type MessageFinished struct { + VerifyData []byte +} + +// Type returns the Handshake Type. +func (m MessageFinished) Type() Type { + return TypeFinished +} + +// Marshal encodes the Handshake. +func (m *MessageFinished) Marshal() ([]byte, error) { + return append([]byte{}, m.VerifyData...), nil +} + +// Unmarshal populates the message from encoded data. +func (m *MessageFinished) Unmarshal(data []byte) error { + m.VerifyData = append([]byte{}, data...) + + return nil +} diff --git a/vendor/github.com/pion/dtls/v3/pkg/protocol/handshake/message_hello_verify_request.go b/vendor/github.com/pion/dtls/v3/pkg/protocol/handshake/message_hello_verify_request.go new file mode 100644 index 0000000..796ea65 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/protocol/handshake/message_hello_verify_request.go @@ -0,0 +1,66 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package handshake + +import ( + "github.com/pion/dtls/v3/pkg/protocol" +) + +// MessageHelloVerifyRequest is as follows: +// +// struct { +// ProtocolVersion server_version; +// opaque cookie<0..2^8-1>; +// } HelloVerifyRequest; +// +// The HelloVerifyRequest message type is hello_verify_request(3). +// +// When the client sends its ClientHello message to the server, the server +// MAY respond with a HelloVerifyRequest message. This message contains +// a stateless cookie generated using the technique of [PHOTURIS]. The +// client MUST retransmit the ClientHello with the cookie added. +// +// https://tools.ietf.org/html/rfc6347#section-4.2.1 +type MessageHelloVerifyRequest struct { + Version protocol.Version + Cookie []byte +} + +// Type returns the Handshake Type. +func (m MessageHelloVerifyRequest) Type() Type { + return TypeHelloVerifyRequest +} + +// Marshal encodes the Handshake. +func (m *MessageHelloVerifyRequest) Marshal() ([]byte, error) { + if len(m.Cookie) > 255 { + return nil, errCookieTooLong + } + + out := make([]byte, 3+len(m.Cookie)) + out[0] = m.Version.Major + out[1] = m.Version.Minor + out[2] = byte(len(m.Cookie)) + copy(out[3:], m.Cookie) + + return out, nil +} + +// Unmarshal populates the message from encoded data. +func (m *MessageHelloVerifyRequest) Unmarshal(data []byte) error { + if len(data) < 3 { + return errBufferTooSmall + } + m.Version.Major = data[0] + m.Version.Minor = data[1] + cookieLength := int(data[2]) + if len(data) < cookieLength+3 { + return errBufferTooSmall + } + m.Cookie = make([]byte, cookieLength) + + copy(m.Cookie, data[3:3+cookieLength]) + + return nil +} diff --git a/vendor/github.com/pion/dtls/v3/pkg/protocol/handshake/message_server_hello.go b/vendor/github.com/pion/dtls/v3/pkg/protocol/handshake/message_server_hello.go new file mode 100644 index 0000000..400e15c --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/protocol/handshake/message_server_hello.go @@ -0,0 +1,124 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package handshake + +import ( + "encoding/binary" + + "github.com/pion/dtls/v3/pkg/protocol" + "github.com/pion/dtls/v3/pkg/protocol/extension" +) + +// MessageServerHello is sent in response to a ClientHello +// message when it was able to find an acceptable set of algorithms. +// If it cannot find such a match, it will respond with a handshake +// failure alert. +// +// https://tools.ietf.org/html/rfc5246#section-7.4.1.3 +type MessageServerHello struct { + Version protocol.Version + Random Random + + SessionID []byte + + CipherSuiteID *uint16 + CompressionMethod *protocol.CompressionMethod + Extensions []extension.Extension +} + +const messageServerHelloVariableWidthStart = 2 + RandomLength + +// Type returns the Handshake Type. +func (m MessageServerHello) Type() Type { + return TypeServerHello +} + +// Marshal encodes the Handshake. +func (m *MessageServerHello) Marshal() ([]byte, error) { + if m.CipherSuiteID == nil { + return nil, errCipherSuiteUnset + } else if m.CompressionMethod == nil { + return nil, errCompressionMethodUnset + } + + out := make([]byte, messageServerHelloVariableWidthStart) + out[0] = m.Version.Major + out[1] = m.Version.Minor + + rand := m.Random.MarshalFixed() + copy(out[2:], rand[:]) + + out = append(out, byte(len(m.SessionID))) //nolint:makezero // todo: fix + out = append(out, m.SessionID...) //nolint:makezero // todo: fix + + out = append(out, []byte{0x00, 0x00}...) //nolint:makezero // todo: fix + binary.BigEndian.PutUint16(out[len(out)-2:], *m.CipherSuiteID) + + out = append(out, byte(m.CompressionMethod.ID)) //nolint:makezero // todo: fix + + extensions, err := extension.Marshal(m.Extensions) + if err != nil { + return nil, err + } + + return append(out, extensions...), nil //nolint:makezero // todo: fix +} + +// Unmarshal populates the message from encoded data. +func (m *MessageServerHello) Unmarshal(data []byte) error { + if len(data) < 2+RandomLength { + return errBufferTooSmall + } + + m.Version.Major = data[0] + m.Version.Minor = data[1] + + var random [RandomLength]byte + copy(random[:], data[2:]) + m.Random.UnmarshalFixed(random) + + currOffset := messageServerHelloVariableWidthStart + currOffset++ + if len(data) <= currOffset { + return errBufferTooSmall + } + + n := int(data[currOffset-1]) + if len(data) <= currOffset+n { + return errBufferTooSmall + } + m.SessionID = append([]byte{}, data[currOffset:currOffset+n]...) + currOffset += len(m.SessionID) + + if len(data) < currOffset+2 { + return errBufferTooSmall + } + m.CipherSuiteID = new(uint16) + *m.CipherSuiteID = binary.BigEndian.Uint16(data[currOffset:]) + currOffset += 2 + + if len(data) <= currOffset { + return errBufferTooSmall + } + if compressionMethod, ok := protocol.CompressionMethods()[protocol.CompressionMethodID(data[currOffset])]; ok { + m.CompressionMethod = compressionMethod + currOffset++ + } else { + return errInvalidCompressionMethod + } + + if len(data) <= currOffset { + m.Extensions = []extension.Extension{} + + return nil + } + + extensions, err := extension.Unmarshal(data[currOffset:]) + if err != nil { + return err + } + m.Extensions = extensions + + return nil +} diff --git a/vendor/github.com/pion/dtls/v3/pkg/protocol/handshake/message_server_hello_done.go b/vendor/github.com/pion/dtls/v3/pkg/protocol/handshake/message_server_hello_done.go new file mode 100644 index 0000000..87cc9ad --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/protocol/handshake/message_server_hello_done.go @@ -0,0 +1,24 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package handshake + +// MessageServerHelloDone is final non-encrypted message from server +// this communicates server has sent all its handshake messages and next +// should be MessageFinished. +type MessageServerHelloDone struct{} + +// Type returns the Handshake Type. +func (m MessageServerHelloDone) Type() Type { + return TypeServerHelloDone +} + +// Marshal encodes the Handshake. +func (m *MessageServerHelloDone) Marshal() ([]byte, error) { + return []byte{}, nil +} + +// Unmarshal populates the message from encoded data. +func (m *MessageServerHelloDone) Unmarshal([]byte) error { + return nil +} diff --git a/vendor/github.com/pion/dtls/v3/pkg/protocol/handshake/message_server_key_exchange.go b/vendor/github.com/pion/dtls/v3/pkg/protocol/handshake/message_server_key_exchange.go new file mode 100644 index 0000000..93f6231 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/protocol/handshake/message_server_key_exchange.go @@ -0,0 +1,150 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package handshake + +import ( + "encoding/binary" + + "github.com/pion/dtls/v3/internal/ciphersuite/types" + "github.com/pion/dtls/v3/pkg/crypto/elliptic" + "github.com/pion/dtls/v3/pkg/crypto/hash" + "github.com/pion/dtls/v3/pkg/crypto/signature" +) + +// MessageServerKeyExchange supports ECDH and PSK. +type MessageServerKeyExchange struct { + IdentityHint []byte + + EllipticCurveType elliptic.CurveType + NamedCurve elliptic.Curve + PublicKey []byte + HashAlgorithm hash.Algorithm + SignatureAlgorithm signature.Algorithm + Signature []byte + + // for unmarshaling + KeyExchangeAlgorithm types.KeyExchangeAlgorithm +} + +// Type returns the Handshake Type. +func (m MessageServerKeyExchange) Type() Type { + return TypeServerKeyExchange +} + +// Marshal encodes the Handshake. +func (m *MessageServerKeyExchange) Marshal() ([]byte, error) { //nolint:cyclop + var out []byte + if m.IdentityHint != nil { + out = append([]byte{0x00, 0x00}, m.IdentityHint...) + binary.BigEndian.PutUint16(out, uint16(len(out)-2)) //nolint:gosec //G115 + } + + if m.EllipticCurveType == 0 || len(m.PublicKey) == 0 { + return out, nil + } + out = append(out, byte(m.EllipticCurveType), 0x00, 0x00) + binary.BigEndian.PutUint16(out[len(out)-2:], uint16(m.NamedCurve)) + + out = append(out, byte(len(m.PublicKey))) + out = append(out, m.PublicKey...) + switch { + case m.HashAlgorithm != hash.None && len(m.Signature) == 0: + return nil, errInvalidHashAlgorithm + case m.HashAlgorithm == hash.None && len(m.Signature) > 0: + return nil, errInvalidHashAlgorithm + case m.SignatureAlgorithm == signature.Anonymous && (m.HashAlgorithm != hash.None || len(m.Signature) > 0): + return nil, errInvalidSignatureAlgorithm + case m.SignatureAlgorithm == signature.Anonymous: + return out, nil + } + + out = append(out, []byte{byte(m.HashAlgorithm), byte(m.SignatureAlgorithm), 0x00, 0x00}...) + binary.BigEndian.PutUint16(out[len(out)-2:], uint16(len(m.Signature))) //nolint:gosec // G115 + out = append(out, m.Signature...) + + return out, nil +} + +// Unmarshal populates the message from encoded data. +func (m *MessageServerKeyExchange) Unmarshal(data []byte) error { //nolint:cyclop + switch { + case len(data) < 2: + return errBufferTooSmall + case m.KeyExchangeAlgorithm == types.KeyExchangeAlgorithmNone: + return errCipherSuiteUnset + } + + hintLength := binary.BigEndian.Uint16(data) + if int(hintLength) <= len(data)-2 && m.KeyExchangeAlgorithm.Has(types.KeyExchangeAlgorithmPsk) { + m.IdentityHint = append([]byte{}, data[2:2+hintLength]...) + data = data[2+hintLength:] + } + if m.KeyExchangeAlgorithm == types.KeyExchangeAlgorithmPsk { + if len(data) == 0 { + return nil + } + + return errLengthMismatch + } + + if !m.KeyExchangeAlgorithm.Has(types.KeyExchangeAlgorithmEcdhe) { + return errLengthMismatch + } + + if _, ok := elliptic.CurveTypes()[elliptic.CurveType(data[0])]; ok { + m.EllipticCurveType = elliptic.CurveType(data[0]) + } else { + return errInvalidEllipticCurveType + } + + if len(data[1:]) < 2 { + return errBufferTooSmall + } + m.NamedCurve = elliptic.Curve(binary.BigEndian.Uint16(data[1:3])) + if _, ok := elliptic.Curves()[m.NamedCurve]; !ok { + return errInvalidNamedCurve + } + if len(data) < 4 { + return errBufferTooSmall + } + + publicKeyLength := int(data[3]) + offset := 4 + publicKeyLength + if len(data) < offset { + return errBufferTooSmall + } + m.PublicKey = append([]byte{}, data[4:offset]...) + + // Anon connection doesn't contains hashAlgorithm, signatureAlgorithm, signature + if len(data) == offset { + return nil + } else if len(data) <= offset { + return errBufferTooSmall + } + + m.HashAlgorithm = hash.Algorithm(data[offset]) + if _, ok := hash.Algorithms()[m.HashAlgorithm]; !ok { + return errInvalidHashAlgorithm + } + offset++ + if len(data) <= offset { + return errBufferTooSmall + } + m.SignatureAlgorithm = signature.Algorithm(data[offset]) + if _, ok := signature.Algorithms()[m.SignatureAlgorithm]; !ok { + return errInvalidSignatureAlgorithm + } + offset++ + if len(data) < offset+2 { + return errBufferTooSmall + } + signatureLength := int(binary.BigEndian.Uint16(data[offset:])) + offset += 2 + if len(data) < offset+signatureLength { + return errBufferTooSmall + } + m.Signature = append([]byte{}, data[offset:offset+signatureLength]...) + + return nil +} diff --git a/vendor/github.com/pion/dtls/v3/pkg/protocol/handshake/random.go b/vendor/github.com/pion/dtls/v3/pkg/protocol/handshake/random.go new file mode 100644 index 0000000..3efa31a --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/protocol/handshake/random.go @@ -0,0 +1,52 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package handshake + +import ( + "crypto/rand" + "encoding/binary" + "time" +) + +// Consts for Random in Handshake. +const ( + RandomBytesLength = 28 + RandomLength = RandomBytesLength + 4 +) + +// Random value that is used in ClientHello and ServerHello +// +// https://tools.ietf.org/html/rfc4346#section-7.4.1.2 +type Random struct { + GMTUnixTime time.Time + RandomBytes [RandomBytesLength]byte +} + +// MarshalFixed encodes the Handshake. +func (r *Random) MarshalFixed() [RandomLength]byte { + var out [RandomLength]byte + + binary.BigEndian.PutUint32(out[0:], uint32(r.GMTUnixTime.Unix())) //nolint:gosec // G115 + copy(out[4:], r.RandomBytes[:]) + + return out +} + +// UnmarshalFixed populates the message from encoded data. +func (r *Random) UnmarshalFixed(data [RandomLength]byte) { + r.GMTUnixTime = time.Unix(int64(binary.BigEndian.Uint32(data[0:])), 0) + copy(r.RandomBytes[:], data[4:]) +} + +// Populate fills the handshakeRandom with random values +// may be called multiple times. +func (r *Random) Populate() error { + r.GMTUnixTime = time.Now() + + tmp := make([]byte, RandomBytesLength) + _, err := rand.Read(tmp) + copy(r.RandomBytes[:], tmp) + + return err +} diff --git a/vendor/github.com/pion/dtls/v3/pkg/protocol/recordlayer/errors.go b/vendor/github.com/pion/dtls/v3/pkg/protocol/recordlayer/errors.go new file mode 100644 index 0000000..6f3f114 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/protocol/recordlayer/errors.go @@ -0,0 +1,24 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +// Package recordlayer implements the TLS Record Layer https://tools.ietf.org/html/rfc5246#section-6 +package recordlayer + +import ( + "errors" + + "github.com/pion/dtls/v3/pkg/protocol" +) + +var ( + // ErrInvalidPacketLength is returned when the packet length too small + // or declared length do not match. + ErrInvalidPacketLength = &protocol.TemporaryError{ + Err: errors.New("packet length and declared length do not match"), //nolint:err113 + } + + errBufferTooSmall = &protocol.TemporaryError{Err: errors.New("buffer is too small")} //nolint:err113 + errSequenceNumberOverflow = &protocol.InternalError{Err: errors.New("sequence number overflow")} //nolint:err113 + errUnsupportedProtocolVersion = &protocol.FatalError{Err: errors.New("unsupported protocol version")} //nolint:err113 + errInvalidContentType = &protocol.TemporaryError{Err: errors.New("invalid content type")} //nolint:err113 +) diff --git a/vendor/github.com/pion/dtls/v3/pkg/protocol/recordlayer/header.go b/vendor/github.com/pion/dtls/v3/pkg/protocol/recordlayer/header.go new file mode 100644 index 0000000..d1255f8 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/protocol/recordlayer/header.go @@ -0,0 +1,86 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package recordlayer + +import ( + "encoding/binary" + + "github.com/pion/dtls/v3/internal/util" + "github.com/pion/dtls/v3/pkg/protocol" +) + +// Header implements a TLS RecordLayer header. +type Header struct { + ContentType protocol.ContentType + ContentLen uint16 + Version protocol.Version + Epoch uint16 + SequenceNumber uint64 // uint48 in spec + + // Optional Fields + ConnectionID []byte +} + +// RecordLayer enums. +const ( + // FixedHeaderSize is the size of a DTLS record header when connection IDs + // are not in use. + FixedHeaderSize = 13 + MaxSequenceNumber = 0x0000FFFFFFFFFFFF +) + +// Marshal encodes a TLS RecordLayer Header to binary. +func (h *Header) Marshal() ([]byte, error) { + if h.SequenceNumber > MaxSequenceNumber { + return nil, errSequenceNumberOverflow + } + + hs := FixedHeaderSize + len(h.ConnectionID) + + out := make([]byte, hs) + out[0] = byte(h.ContentType) + out[1] = h.Version.Major + out[2] = h.Version.Minor + binary.BigEndian.PutUint16(out[3:], h.Epoch) + util.PutBigEndianUint48(out[5:], h.SequenceNumber) + copy(out[11:11+len(h.ConnectionID)], h.ConnectionID) + binary.BigEndian.PutUint16(out[hs-2:], h.ContentLen) + + return out, nil +} + +// Unmarshal populates a TLS RecordLayer Header from binary. +func (h *Header) Unmarshal(data []byte) error { + if len(data) < FixedHeaderSize { + return errBufferTooSmall + } + h.ContentType = protocol.ContentType(data[0]) + if h.ContentType == protocol.ContentTypeConnectionID { + // If a CID was expected the ConnectionID should have been initialized. + if len(data) < FixedHeaderSize+len(h.ConnectionID) { + return errBufferTooSmall + } + h.ConnectionID = data[11 : 11+len(h.ConnectionID)] + } + + h.Version.Major = data[1] + h.Version.Minor = data[2] + h.Epoch = binary.BigEndian.Uint16(data[3:]) + + // SequenceNumber is stored as uint48, make into uint64 + seqCopy := make([]byte, 8) + copy(seqCopy[2:], data[5:11]) + h.SequenceNumber = binary.BigEndian.Uint64(seqCopy) + + if !h.Version.Equal(protocol.Version1_0) && !h.Version.Equal(protocol.Version1_2) { + return errUnsupportedProtocolVersion + } + + return nil +} + +// Size returns the total size of the header. +func (h *Header) Size() int { + return FixedHeaderSize + len(h.ConnectionID) +} diff --git a/vendor/github.com/pion/dtls/v3/pkg/protocol/recordlayer/inner_plaintext.go b/vendor/github.com/pion/dtls/v3/pkg/protocol/recordlayer/inner_plaintext.go new file mode 100644 index 0000000..353ebb1 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/protocol/recordlayer/inner_plaintext.go @@ -0,0 +1,49 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package recordlayer + +import ( + "github.com/pion/dtls/v3/pkg/protocol" + "golang.org/x/crypto/cryptobyte" +) + +// InnerPlaintext implements DTLSInnerPlaintext +// +// https://datatracker.ietf.org/doc/html/rfc9146#name-record-layer-extensions +type InnerPlaintext struct { + Content []byte + RealType protocol.ContentType + Zeros uint +} + +// Marshal encodes a DTLS InnerPlaintext to binary. +func (p *InnerPlaintext) Marshal() ([]byte, error) { + var out cryptobyte.Builder + out.AddBytes(p.Content) + out.AddUint8(uint8(p.RealType)) + out.AddBytes(make([]byte, p.Zeros)) + + return out.Bytes() +} + +// Unmarshal populates a DTLS InnerPlaintext from binary. +func (p *InnerPlaintext) Unmarshal(data []byte) error { + // Process in reverse + i := len(data) - 1 + for i >= 0 { + if data[i] != 0 { + p.Zeros = uint(len(data) - 1 - i) //nolint:gosec // G115 + + break + } + i-- + } + if i == 0 { + return errBufferTooSmall + } + p.RealType = protocol.ContentType(data[i]) + p.Content = append([]byte{}, data[:i]...) + + return nil +} diff --git a/vendor/github.com/pion/dtls/v3/pkg/protocol/recordlayer/recordlayer.go b/vendor/github.com/pion/dtls/v3/pkg/protocol/recordlayer/recordlayer.go new file mode 100644 index 0000000..a9a456f --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/protocol/recordlayer/recordlayer.go @@ -0,0 +1,145 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package recordlayer + +import ( + "encoding/binary" + + "github.com/pion/dtls/v3/pkg/protocol" + "github.com/pion/dtls/v3/pkg/protocol/alert" + "github.com/pion/dtls/v3/pkg/protocol/handshake" +) + +// DTLS fixed size record layer header when Connection IDs are not in-use. + +// --------------------------------- +// | Type | Version | Epoch | +// --------------------------------- +// | Epoch | Sequence Number | +// --------------------------------- +// | Sequence Number | Length | +// --------------------------------- +// | Length | Fragment... | +// --------------------------------- + +// fixedHeaderLenIdx is the index at which the record layer content length is +// specified in a fixed length header (i.e. one that does not include a +// Connection ID). +const fixedHeaderLenIdx = 11 + +// RecordLayer which handles all data transport. +// The record layer is assumed to sit directly on top of some +// reliable transport such as TCP. The record layer can carry four types of content: +// +// 1. Handshake messages—used for algorithm negotiation and key establishment. +// 2. ChangeCipherSpec messages—really part of the handshake but technically a separate kind of message. +// 3. Alert messages—used to signal that errors have occurred +// 4. Application layer data +// +// The DTLS record layer is extremely similar to that of TLS 1.1. The +// only change is the inclusion of an explicit sequence number in the +// record. This sequence number allows the recipient to correctly +// verify the TLS MAC. +// +// https://tools.ietf.org/html/rfc4347#section-4.1 +type RecordLayer struct { + Header Header + Content protocol.Content +} + +// Marshal encodes the RecordLayer to binary. +func (r *RecordLayer) Marshal() ([]byte, error) { + contentRaw, err := r.Content.Marshal() + if err != nil { + return nil, err + } + + r.Header.ContentLen = uint16(len(contentRaw)) //nolint:gosec // G115 + r.Header.ContentType = r.Content.ContentType() + + headerRaw, err := r.Header.Marshal() + if err != nil { + return nil, err + } + + return append(headerRaw, contentRaw...), nil +} + +// Unmarshal populates the RecordLayer from binary. +func (r *RecordLayer) Unmarshal(data []byte) error { + if err := r.Header.Unmarshal(data); err != nil { + return err + } + + switch r.Header.ContentType { + case protocol.ContentTypeChangeCipherSpec: + r.Content = &protocol.ChangeCipherSpec{} + case protocol.ContentTypeAlert: + r.Content = &alert.Alert{} + case protocol.ContentTypeHandshake: + r.Content = &handshake.Handshake{} + case protocol.ContentTypeApplicationData: + r.Content = &protocol.ApplicationData{} + default: + return errInvalidContentType + } + + return r.Content.Unmarshal(data[r.Header.Size()+len(r.Header.ConnectionID):]) +} + +// UnpackDatagram extracts all RecordLayer messages from a single datagram. +// Note that as with TLS, multiple handshake messages may be placed in +// the same DTLS record, provided that there is room and that they are +// part of the same flight. Thus, there are two acceptable ways to pack +// two DTLS messages into the same datagram: in the same record or in +// separate records. +// https://tools.ietf.org/html/rfc6347#section-4.2.3 +func UnpackDatagram(buf []byte) ([][]byte, error) { + out := [][]byte{} + + for offset := 0; len(buf) != offset; { + if len(buf)-offset <= FixedHeaderSize { + return nil, ErrInvalidPacketLength + } + + pktLen := (FixedHeaderSize + int(binary.BigEndian.Uint16(buf[offset+11:]))) + if offset+pktLen > len(buf) { + return nil, ErrInvalidPacketLength + } + + out = append(out, buf[offset:offset+pktLen]) + offset += pktLen + } + + return out, nil +} + +// ContentAwareUnpackDatagram is the same as UnpackDatagram but considers the +// presence of a connection identifier if the record is of content type +// tls12_cid. +func ContentAwareUnpackDatagram(buf []byte, cidLength int) ([][]byte, error) { + out := [][]byte{} + + for offset := 0; len(buf) != offset; { + headerSize := FixedHeaderSize + lenIdx := fixedHeaderLenIdx + if protocol.ContentType(buf[offset]) == protocol.ContentTypeConnectionID { + headerSize += cidLength + lenIdx += cidLength + } + if len(buf)-offset <= headerSize { + return nil, ErrInvalidPacketLength + } + + pktLen := (headerSize + int(binary.BigEndian.Uint16(buf[offset+lenIdx:]))) + if offset+pktLen > len(buf) { + return nil, ErrInvalidPacketLength + } + + out = append(out, buf[offset:offset+pktLen]) + offset += pktLen + } + + return out, nil +} diff --git a/vendor/github.com/pion/dtls/v3/pkg/protocol/version.go b/vendor/github.com/pion/dtls/v3/pkg/protocol/version.go new file mode 100644 index 0000000..bfb3d63 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/pkg/protocol/version.go @@ -0,0 +1,53 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +// Package protocol provides the DTLS wire format +package protocol + +// Version enums. +var ( + Version1_0 = Version{Major: 0xfe, Minor: 0xff} //nolint:gochecknoglobals + Version1_2 = Version{Major: 0xfe, Minor: 0xfd} //nolint:gochecknoglobals + Version1_3 = Version{Major: 0xfe, Minor: 0xfc} //nolint:gochecknoglobals +) + +// Version is the minor/major value in the RecordLayer +// and ClientHello/ServerHello +// +// https://tools.ietf.org/html/rfc4346#section-6.2.1 +type Version struct { + Major, Minor uint8 +} + +// Equal determines if two protocol versions are equal. +func (v Version) Equal(x Version) bool { + return v.Major == x.Major && v.Minor == x.Minor +} + +// IsSupportedBytes returns true if it's supported by Pion. Only DTLS 1.2 is currently supported. +// DTLS 1.3 is a work in progress and is currently being implemented. +func IsSupportedBytes(major uint8, minor uint8) bool { + return major == 0xfe && (minor == 0xfd || minor == 0xfc) +} + +// IsSupportedVersion returns true if it's supported by Pion. Only DTLS 1.2 is currently supported. +// DTLS 1.3 is a work in progress and is currently being implemented. +func IsSupportedVersion(v Version) bool { + return v.Equal(Version1_2) || v.Equal(Version1_3) +} + +// IsValidBytes returns true if the bytes represent a valid DTLS version as defined in RFC9147 below. +// Note that this is not the same as whether it's *supported* by Pion. Please see IsSupportedBytes() for more info. +// +// https://tools.ietf.org/html/rfc9147#section-5.3 (see legacy_version) +func IsValidBytes(major uint8, minor uint8) bool { + return major == 0xfe && (minor == 0xff || minor == 0xfd || minor == 0xfc) +} + +// IsValidVersion returns true if the bytes represent a valid DTLS version as defined in RFC9147 below. +// Note that this is not the same as whether it's *supported* by Pion. Please see IsSupportedBytes() for more info. +// / +// https://tools.ietf.org/html/rfc9147#section-5.3 (see legacy_version) +func IsValidVersion(v Version) bool { + return v.Equal(Version1_0) || v.Equal(Version1_2) || v.Equal(Version1_3) +} diff --git a/vendor/github.com/pion/dtls/v3/renovate.json b/vendor/github.com/pion/dtls/v3/renovate.json new file mode 100644 index 0000000..f1bb98c --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/renovate.json @@ -0,0 +1,6 @@ +{ + "$schema": "https://docs.renovatebot.com/renovate-schema.json", + "extends": [ + "github>pion/renovate-config" + ] +} diff --git a/vendor/github.com/pion/dtls/v3/resume.go b/vendor/github.com/pion/dtls/v3/resume.go new file mode 100644 index 0000000..c0f9c81 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/resume.go @@ -0,0 +1,37 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package dtls + +import ( + "net" +) + +// Resume imports an already established dtls connection using a specific dtls state. +// +// Deprecated: Use ResumeWithOptions instead. +func Resume(state *State, conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) { + if err := state.initCipherSuite(); err != nil { + return nil, err + } + + if config == nil { + return nil, errNoConfigProvided + } + + if err := validateConfig(config); err != nil { + return nil, err + } + + return createConn(conn, rAddr, config, state.isClient, state) +} + +// ResumeWithOptions imports an already established dtls connection using a specific dtls state. +func ResumeWithOptions(state *State, conn net.PacketConn, rAddr net.Addr, opts ...Option) (*Conn, error) { + config, err := buildConfig(opts...) + if err != nil { + return nil, err + } + + return Resume(state, conn, rAddr, config) +} diff --git a/vendor/github.com/pion/dtls/v3/session.go b/vendor/github.com/pion/dtls/v3/session.go new file mode 100644 index 0000000..6113a88 --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/session.go @@ -0,0 +1,24 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package dtls + +// Session store data needed in resumption. +type Session struct { + // ID store session id + ID []byte + // Secret store session master secret + Secret []byte +} + +// SessionStore defines methods needed for session resumption. +type SessionStore interface { + // Set save a session. + // For client, use server name as key. + // For server, use session id. + Set(key []byte, s Session) error + // Get fetch a session. + Get(key []byte) (Session, error) + // Del clean saved session. + Del(key []byte) error +} diff --git a/vendor/github.com/pion/dtls/v3/srtp_protection_profile.go b/vendor/github.com/pion/dtls/v3/srtp_protection_profile.go new file mode 100644 index 0000000..eb0927c --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/srtp_protection_profile.go @@ -0,0 +1,21 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package dtls + +import "github.com/pion/dtls/v3/pkg/protocol/extension" + +// SRTPProtectionProfile defines the parameters and options that are in effect for the SRTP processing +// https://tools.ietf.org/html/rfc5764#section-4.1.2 +type SRTPProtectionProfile = extension.SRTPProtectionProfile + +const ( + SRTP_AES128_CM_HMAC_SHA1_80 SRTPProtectionProfile = extension.SRTP_AES128_CM_HMAC_SHA1_80 // nolint: revive,staticcheck + SRTP_AES128_CM_HMAC_SHA1_32 SRTPProtectionProfile = extension.SRTP_AES128_CM_HMAC_SHA1_32 // nolint: revive,staticcheck + SRTP_AES256_CM_SHA1_80 SRTPProtectionProfile = extension.SRTP_AES256_CM_SHA1_80 // nolint: revive,staticcheck + SRTP_AES256_CM_SHA1_32 SRTPProtectionProfile = extension.SRTP_AES256_CM_SHA1_32 // nolint: revive,staticcheck + SRTP_NULL_HMAC_SHA1_80 SRTPProtectionProfile = extension.SRTP_NULL_HMAC_SHA1_80 // nolint: revive,staticcheck + SRTP_NULL_HMAC_SHA1_32 SRTPProtectionProfile = extension.SRTP_NULL_HMAC_SHA1_32 // nolint: revive,staticcheck + SRTP_AEAD_AES_128_GCM SRTPProtectionProfile = extension.SRTP_AEAD_AES_128_GCM // nolint: revive,staticcheck + SRTP_AEAD_AES_256_GCM SRTPProtectionProfile = extension.SRTP_AEAD_AES_256_GCM // nolint: revive,staticcheck +) diff --git a/vendor/github.com/pion/dtls/v3/state.go b/vendor/github.com/pion/dtls/v3/state.go new file mode 100644 index 0000000..c5fbfcf --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/state.go @@ -0,0 +1,303 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package dtls + +import ( + "bytes" + "encoding/gob" + "errors" + "sync/atomic" + + "github.com/pion/dtls/v3/pkg/crypto/elliptic" + "github.com/pion/dtls/v3/pkg/crypto/prf" + "github.com/pion/dtls/v3/pkg/crypto/signaturehash" + "github.com/pion/dtls/v3/pkg/protocol/handshake" + "github.com/pion/transport/v4/replaydetector" +) + +// State holds the dtls connection state and implements both encoding.BinaryMarshaler and +// encoding.BinaryUnmarshaler. +type State struct { + localEpoch, remoteEpoch atomic.Value + localSequenceNumber []uint64 // uint48 + localRandom, remoteRandom handshake.Random + masterSecret []byte + cipherSuite CipherSuite // nil if a cipherSuite hasn't been chosen + CipherSuiteID CipherSuiteID + + remoteSupportsRenegotiation bool // True when Client Hello contained renegotiation extension + + srtpProtectionProfile atomic.Value // Negotiated SRTPProtectionProfile + remoteSRTPMasterKeyIdentifier []byte + + PeerCertificates [][]byte + IdentityHint []byte + SessionID []byte + + // Connection Identifiers must be negotiated afresh on session resumption. + // https://datatracker.ietf.org/doc/html/rfc9146#name-the-connection_id-extension + + // localConnectionID is the locally generated connection ID that is expected + // to be received from the remote endpoint. + // For a server, this is the connection ID sent in ServerHello. + // For a client, this is the connection ID sent in the ClientHello. + localConnectionID atomic.Value + // remoteConnectionID is the connection ID that the remote endpoint + // specifies should be sent. + // For a server, this is the connection ID received in the ClientHello. + // For a client, this is the connection ID received in the ServerHello. + remoteConnectionID []byte + + isClient bool + + preMasterSecret []byte + extendedMasterSecret bool + + namedCurve elliptic.Curve + localKeypair *elliptic.Keypair + cookie []byte + handshakeSendSequence int + handshakeRecvSequence int + serverName string + remoteCertRequestAlgs []signaturehash.Algorithm + remoteCertSignatureSchemes []signaturehash.Algorithm // signature_algorithms_cert from peer + remoteRequestedCertificate bool // Did we get a CertificateRequest + localCertificatesVerify []byte // cache CertificateVerify + localVerifyData []byte // cached VerifyData + localKeySignature []byte // cached keySignature + peerCertificatesVerified bool + + replayDetector []replaydetector.ReplayDetector + + peerSupportedProtocols []string + NegotiatedProtocol string +} + +type serializedState struct { + LocalEpoch uint16 + RemoteEpoch uint16 + LocalRandom [handshake.RandomLength]byte + RemoteRandom [handshake.RandomLength]byte + CipherSuiteID uint16 + MasterSecret []byte + SequenceNumber uint64 + SRTPProtectionProfile uint16 + PeerCertificates [][]byte + IdentityHint []byte + SessionID []byte + LocalConnectionID []byte + RemoteConnectionID []byte + IsClient bool + NegotiatedProtocol string +} + +var errCipherSuiteNotSet = &InternalError{Err: errors.New("cipher suite not set")} //nolint:err113 + +func (s *State) clone() (*State, error) { + serialized, err := s.serialize() + if err != nil { + return nil, err + } + state := &State{} + state.deserialize(*serialized) + + return state, err +} + +func (s *State) serialize() (*serializedState, error) { + if s.cipherSuite == nil { + return nil, errCipherSuiteNotSet + } + cipherSuiteID := uint16(s.cipherSuite.ID()) + + // Marshal random values + localRnd := s.localRandom.MarshalFixed() + remoteRnd := s.remoteRandom.MarshalFixed() + + epoch := s.getLocalEpoch() + + return &serializedState{ + LocalEpoch: s.getLocalEpoch(), + RemoteEpoch: s.getRemoteEpoch(), + CipherSuiteID: cipherSuiteID, + MasterSecret: s.masterSecret, + SequenceNumber: atomic.LoadUint64(&s.localSequenceNumber[epoch]), + LocalRandom: localRnd, + RemoteRandom: remoteRnd, + SRTPProtectionProfile: uint16(s.getSRTPProtectionProfile()), + PeerCertificates: s.PeerCertificates, + IdentityHint: s.IdentityHint, + SessionID: s.SessionID, + LocalConnectionID: s.getLocalConnectionID(), + RemoteConnectionID: s.remoteConnectionID, + IsClient: s.isClient, + NegotiatedProtocol: s.NegotiatedProtocol, + }, nil +} + +func (s *State) deserialize(serialized serializedState) { + // Set epoch values + epoch := serialized.LocalEpoch + s.localEpoch.Store(serialized.LocalEpoch) + s.remoteEpoch.Store(serialized.RemoteEpoch) + + for len(s.localSequenceNumber) <= int(epoch) { + s.localSequenceNumber = append(s.localSequenceNumber, uint64(0)) + } + + // Set random values + localRandom := &handshake.Random{} + localRandom.UnmarshalFixed(serialized.LocalRandom) + s.localRandom = *localRandom + + remoteRandom := &handshake.Random{} + remoteRandom.UnmarshalFixed(serialized.RemoteRandom) + s.remoteRandom = *remoteRandom + + s.isClient = serialized.IsClient + + // Set master secret + s.masterSecret = serialized.MasterSecret + + // Set cipher suite + s.CipherSuiteID = CipherSuiteID(serialized.CipherSuiteID) + s.cipherSuite = cipherSuiteForID(s.CipherSuiteID, nil) + + atomic.StoreUint64(&s.localSequenceNumber[epoch], serialized.SequenceNumber) + s.setSRTPProtectionProfile(SRTPProtectionProfile(serialized.SRTPProtectionProfile)) + + // Set remote certificate + s.PeerCertificates = serialized.PeerCertificates + + s.IdentityHint = serialized.IdentityHint + + // Set local and remote connection IDs + s.setLocalConnectionID(serialized.LocalConnectionID) + s.remoteConnectionID = serialized.RemoteConnectionID + + s.SessionID = serialized.SessionID + + s.NegotiatedProtocol = serialized.NegotiatedProtocol +} + +func (s *State) initCipherSuite() error { + if s.cipherSuite.IsInitialized() { + return nil + } + + localRandom := s.localRandom.MarshalFixed() + remoteRandom := s.remoteRandom.MarshalFixed() + + var err error + if s.isClient { + err = s.cipherSuite.Init(s.masterSecret, localRandom[:], remoteRandom[:], true) + } else { + err = s.cipherSuite.Init(s.masterSecret, remoteRandom[:], localRandom[:], false) + } + if err != nil { + return err + } + + return nil +} + +// MarshalBinary is a binary.BinaryMarshaler.MarshalBinary implementation. +func (s *State) MarshalBinary() ([]byte, error) { + serialized, err := s.serialize() + if err != nil { + return nil, err + } + + var buf bytes.Buffer + enc := gob.NewEncoder(&buf) + if err := enc.Encode(*serialized); err != nil { + return nil, err + } + + return buf.Bytes(), nil +} + +// UnmarshalBinary is a binary.BinaryUnmarshaler.UnmarshalBinary implementation. +func (s *State) UnmarshalBinary(data []byte) error { + enc := gob.NewDecoder(bytes.NewBuffer(data)) + var serialized serializedState + if err := enc.Decode(&serialized); err != nil { + return err + } + + s.deserialize(serialized) + + return s.initCipherSuite() +} + +// ExportKeyingMaterial returns length bytes of exported key material in a new +// slice as defined in RFC 5705. +// This allows protocols to use DTLS for key establishment, but +// then use some of the keying material for their own purposes. +func (s *State) ExportKeyingMaterial(label string, context []byte, length int) ([]byte, error) { + if s.getLocalEpoch() == 0 { + return nil, errHandshakeInProgress + } else if len(context) != 0 { + return nil, errContextUnsupported + } else if _, ok := invalidKeyingLabels()[label]; ok { + return nil, errReservedExportKeyingMaterial + } + + localRandom := s.localRandom.MarshalFixed() + remoteRandom := s.remoteRandom.MarshalFixed() + + seed := []byte(label) + if s.isClient { + seed = append(append(seed, localRandom[:]...), remoteRandom[:]...) + } else { + seed = append(append(seed, remoteRandom[:]...), localRandom[:]...) + } + + return prf.PHash(s.masterSecret, seed, length, s.cipherSuite.HashFunc()) +} + +func (s *State) getRemoteEpoch() uint16 { + if remoteEpoch, ok := s.remoteEpoch.Load().(uint16); ok { + return remoteEpoch + } + + return 0 +} + +func (s *State) getLocalEpoch() uint16 { + if localEpoch, ok := s.localEpoch.Load().(uint16); ok { + return localEpoch + } + + return 0 +} + +func (s *State) setSRTPProtectionProfile(profile SRTPProtectionProfile) { + s.srtpProtectionProfile.Store(profile) +} + +func (s *State) getSRTPProtectionProfile() SRTPProtectionProfile { + if val, ok := s.srtpProtectionProfile.Load().(SRTPProtectionProfile); ok { + return val + } + + return 0 +} + +func (s *State) getLocalConnectionID() []byte { + if val, ok := s.localConnectionID.Load().([]byte); ok { + return val + } + + return nil +} + +func (s *State) setLocalConnectionID(v []byte) { + s.localConnectionID.Store(v) +} + +// RemoteRandomBytes returns the remote client hello random bytes. +func (s *State) RemoteRandomBytes() [handshake.RandomBytesLength]byte { + return s.remoteRandom.RandomBytes +} diff --git a/vendor/github.com/pion/dtls/v3/util.go b/vendor/github.com/pion/dtls/v3/util.go new file mode 100644 index 0000000..7b5cb0e --- /dev/null +++ b/vendor/github.com/pion/dtls/v3/util.go @@ -0,0 +1,40 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package dtls + +import "slices" + +func findMatchingSRTPProfile(a, b []SRTPProtectionProfile) (SRTPProtectionProfile, bool) { + for _, aProfile := range a { + if slices.Contains(b, aProfile) { + return aProfile, true + } + } + + return 0, false +} + +func findMatchingCipherSuite(a, b []CipherSuite) (CipherSuite, bool) { + for _, aSuite := range a { + for _, bSuite := range b { + if aSuite.ID() == bSuite.ID() { + return aSuite, true + } + } + } + + return nil, false +} + +func splitBytes(bytes []byte, splitLen int) [][]byte { + splitBytes := make([][]byte, 0) + numBytes := len(bytes) + for i := 0; i < numBytes; i += splitLen { + j := min(i+splitLen, numBytes) + + splitBytes = append(splitBytes, bytes[i:j]) + } + + return splitBytes +} diff --git a/vendor/github.com/pion/ice/v4/.gitignore b/vendor/github.com/pion/ice/v4/.gitignore new file mode 100644 index 0000000..2394557 --- /dev/null +++ b/vendor/github.com/pion/ice/v4/.gitignore @@ -0,0 +1,28 @@ +# SPDX-FileCopyrightText: 2026 The Pion community +# SPDX-License-Identifier: MIT + +### JetBrains IDE ### +##################### +.idea/ + +### Emacs Temporary Files ### +############################# +*~ + +### Folders ### +############### +bin/ +vendor/ +node_modules/ + +### Files ### +############# +*.ivf +*.ogg +tags +cover.out +*.sw[poe] +*.wasm +examples/sfu-ws/cert.pem +examples/sfu-ws/key.pem +wasm_exec.js diff --git a/vendor/github.com/pion/ice/v4/.golangci.yml b/vendor/github.com/pion/ice/v4/.golangci.yml new file mode 100644 index 0000000..1fbb8db --- /dev/null +++ b/vendor/github.com/pion/ice/v4/.golangci.yml @@ -0,0 +1,148 @@ +# SPDX-FileCopyrightText: 2026 The Pion community +# SPDX-License-Identifier: MIT + +version: "2" +linters: + enable: + - asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers + - bidichk # Checks for dangerous unicode character sequences + - bodyclose # checks whether HTTP response body is closed successfully + - containedctx # containedctx is a linter that detects struct contained context.Context field + - contextcheck # check the function whether use a non-inherited context + - cyclop # checks function and package cyclomatic complexity + - decorder # check declaration order and count of types, constants, variables and functions + - dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f()) + - dupl # Tool for code clone detection + - durationcheck # check for two durations multiplied together + - err113 # Golang linter to check the errors handling expressions + - errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases + - errchkjson # Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occations, where the check for the returned error can be omitted. + - errname # Checks that sentinel errors are prefixed with the `Err` and error types are suffixed with the `Error`. + - errorlint # errorlint is a linter for that can be used to find code that will cause problems with the error wrapping scheme introduced in Go 1.13. + - exhaustive # check exhaustiveness of enum switch statements + - forbidigo # Forbids identifiers + - forcetypeassert # finds forced type assertions + - gochecknoglobals # Checks that no globals are present in Go code + - gocognit # Computes and checks the cognitive complexity of functions + - goconst # Finds repeated strings that could be replaced by a constant + - gocritic # The most opinionated Go source code linter + - gocyclo # Computes and checks the cyclomatic complexity of functions + - godot # Check if comments end in a period + - godox # Tool for detection of FIXME, TODO and other comment keywords + - goheader # Checks is file header matches to pattern + - gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod. + - goprintffuncname # Checks that printf-like functions are named with `f` at the end + - gosec # Inspects source code for security problems + - govet # Vet examines Go source code and reports suspicious constructs, such as Printf calls whose arguments do not align with the format string + - grouper # An analyzer to analyze expression groups. + - importas # Enforces consistent import aliases + - ineffassign # Detects when assignments to existing variables are not used + - lll # Reports long lines + - maintidx # maintidx measures the maintainability index of each function. + - makezero # Finds slice declarations with non-zero initial length + - misspell # Finds commonly misspelled English words in comments + - modernize # Replace and suggests simplifications to code + - nakedret # Finds naked returns in functions greater than a specified function length + - nestif # Reports deeply nested if statements + - nilerr # Finds the code that returns nil even if it checks that the error is not nil. + - nilnil # Checks that there is no simultaneous return of `nil` error and an invalid value. + - nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity + - noctx # noctx finds sending http request without context.Context + - predeclared # find code that shadows one of Go's predeclared identifiers + - revive # golint replacement, finds style mistakes + - staticcheck # Staticcheck is a go vet on steroids, applying a ton of static analysis checks + - tagliatelle # Checks the struct tags. + - thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers + - unconvert # Remove unnecessary type conversions + - unparam # Reports unused function parameters + - unused # Checks Go code for unused constants, variables, functions and types + - varnamelen # checks that the length of a variable's name matches its scope + - wastedassign # wastedassign finds wasted assignment statements + - whitespace # Tool for detection of leading and trailing whitespace + disable: + - depguard # Go linter that checks if package imports are in a list of acceptable packages + - funlen # Tool for detection of long functions + - gochecknoinits # Checks that no init functions are present in Go code + - gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations. + - interfacebloat # A linter that checks length of interface. + - ireturn # Accept Interfaces, Return Concrete Types + - mnd # An analyzer to detect magic numbers + - nolintlint # Reports ill-formed or insufficient nolint directives + - paralleltest # paralleltest detects missing usage of t.Parallel() method in your Go test + - prealloc # Finds slice declarations that could potentially be preallocated + - promlinter # Check Prometheus metrics naming via promlint + - rowserrcheck # checks whether Err of rows is checked successfully + - sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed. + - testpackage # linter that makes you use a separate _test package + - tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes + - wrapcheck # Checks that errors returned from external packages are wrapped + - wsl # Whitespace Linter - Forces you to use empty lines! + settings: + staticcheck: + checks: + - all + - -QF1008 # "could remove embedded field", to keep it explicit! + - -QF1003 # "could use tagged switch on enum", Cases conflicts with exhaustive! + exhaustive: + default-signifies-exhaustive: true + forbidigo: + forbid: + - pattern: ^fmt.Print(f|ln)?$ + - pattern: ^log.(Panic|Fatal|Print)(f|ln)?$ + - pattern: ^os.Exit$ + - pattern: ^panic$ + - pattern: ^print(ln)?$ + - pattern: ^testing.T.(Error|Errorf|Fatal|Fatalf|Fail|FailNow)$ + pkg: ^testing$ + msg: use testify/assert instead + analyze-types: true + gomodguard: + blocked: + modules: + - github.com/pkg/errors: + recommendations: + - errors + govet: + enable: + - shadow + revive: + rules: + # Prefer 'any' type alias over 'interface{}' for Go 1.18+ compatibility + - name: use-any + severity: warning + disabled: false + misspell: + locale: US + varnamelen: + max-distance: 12 + min-name-length: 2 + ignore-type-assert-ok: true + ignore-map-index-ok: true + ignore-chan-recv-ok: true + ignore-decls: + - i int + - n int + - w io.Writer + - r io.Reader + - b []byte + exclusions: + generated: lax + rules: + - linters: + - forbidigo + - gocognit + path: (examples|main\.go) + - linters: + - gocognit + path: _test\.go + - linters: + - forbidigo + path: cmd +formatters: + enable: + - gci # Gci control golang package import order and make it always deterministic. + - gofmt # Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification + - gofumpt # Gofumpt checks whether code was gofumpt-ed. + - goimports # Goimports does everything that gofmt does. Additionally it checks unused imports + exclusions: + generated: lax diff --git a/vendor/github.com/pion/ice/v4/.goreleaser.yml b/vendor/github.com/pion/ice/v4/.goreleaser.yml new file mode 100644 index 0000000..8577d86 --- /dev/null +++ b/vendor/github.com/pion/ice/v4/.goreleaser.yml @@ -0,0 +1,5 @@ +# SPDX-FileCopyrightText: 2026 The Pion community +# SPDX-License-Identifier: MIT + +builds: +- skip: true diff --git a/vendor/github.com/pion/ice/v4/LICENSE b/vendor/github.com/pion/ice/v4/LICENSE new file mode 100644 index 0000000..d96df05 --- /dev/null +++ b/vendor/github.com/pion/ice/v4/LICENSE @@ -0,0 +1,9 @@ +MIT License + +Copyright (c) 2026 The Pion community + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/vendor/github.com/pion/ice/v4/README.md b/vendor/github.com/pion/ice/v4/README.md new file mode 100644 index 0000000..5171a41 --- /dev/null +++ b/vendor/github.com/pion/ice/v4/README.md @@ -0,0 +1,34 @@ +

+
+ Pion ICE +
+

+

A Go implementation of ICE

+

+ Pion ICE + join us on Discord Follow us on Bluesky +
+ GitHub Workflow Status + Go Reference + Coverage Status + Go Report Card + License: MIT +

+
+ +### Roadmap +The library is used as a part of our WebRTC implementation. Please refer to that [roadmap](https://github.com/pion/webrtc/issues/9) to track our major milestones. + +### Community +Pion has an active community on the [Discord](https://discord.gg/PngbdqpFbt). + +Follow the [Pion Bluesky](https://bsky.app/profile/pion.ly) or [Pion Twitter](https://twitter.com/_pion) for project updates and important WebRTC news. + +We are always looking to support **your projects**. Please reach out if you have something to build! +If you need commercial support or don't want to use public methods you can contact us at [team@pion.ly](mailto:team@pion.ly) + +### Contributing +Check out the [contributing wiki](https://github.com/pion/webrtc/wiki/Contributing) to join the group of amazing people making this project possible + +### License +MIT License - see [LICENSE](LICENSE) for full text diff --git a/vendor/github.com/pion/ice/v4/active_tcp.go b/vendor/github.com/pion/ice/v4/active_tcp.go new file mode 100644 index 0000000..cf82a14 --- /dev/null +++ b/vendor/github.com/pion/ice/v4/active_tcp.go @@ -0,0 +1,211 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package ice + +import ( + "context" + "io" + "net" + "net/netip" + "sync/atomic" + "time" + + "github.com/pion/logging" + "github.com/pion/transport/v4/packetio" +) + +type activeTCPConn struct { + readBuffer, writeBuffer *packetio.Buffer + localAddr, remoteAddr atomic.Value + conn atomic.Value // stores net.Conn + closed atomic.Bool +} + +func newActiveTCPConn( + ctx context.Context, + localAddress string, + remoteAddress netip.AddrPort, + log logging.LeveledLogger, +) (a *activeTCPConn) { + a = &activeTCPConn{ + readBuffer: packetio.NewBuffer(), + writeBuffer: packetio.NewBuffer(), + } + + laddr, err := getTCPAddrOnInterface(localAddress) + if err != nil { + a.closed.Store(true) + log.Infof("Failed to dial TCP address %s: %v", remoteAddress, err) + + return a + } + a.localAddr.Store(laddr) + + go func() { + defer func() { + a.closed.Store(true) + }() + + dialer := &net.Dialer{ + LocalAddr: laddr, + } + conn, err := dialer.DialContext(ctx, "tcp", remoteAddress.String()) + if err != nil { + log.Infof("Failed to dial TCP address %s: %v", remoteAddress, err) + + return + } + a.conn.Store(conn) + a.remoteAddr.Store(conn.RemoteAddr()) + + go func() { + buff := make([]byte, receiveMTU) + + for !a.closed.Load() { + n, err := readStreamingPacket(conn, buff) + if err != nil { + log.Infof("Failed to read streaming packet: %s", err) + + break + } + + if _, err := a.readBuffer.Write(buff[:n]); err != nil { + log.Infof("Failed to write to buffer: %s", err) + + break + } + } + }() + + buff := make([]byte, receiveMTU) + + for !a.closed.Load() { + n, err := a.writeBuffer.Read(buff) + if err != nil { + log.Infof("Failed to read from buffer: %s", err) + + break + } + + if _, err = writeStreamingPacket(conn, buff[:n]); err != nil { + log.Infof("Failed to write streaming packet: %s", err) + + break + } + } + + if err := conn.Close(); err != nil { + log.Infof("Failed to close connection: %s", err) + } + }() + + return a +} + +func (a *activeTCPConn) ReadFrom(buff []byte) (n int, srcAddr net.Addr, err error) { + if a.closed.Load() { + return 0, nil, io.ErrClosedPipe + } + + n, err = a.readBuffer.Read(buff) + // RemoteAddr is assuredly set *after* we can read from the buffer + srcAddr = a.RemoteAddr() + + return +} + +func (a *activeTCPConn) WriteTo(buff []byte, _ net.Addr) (n int, err error) { + if a.closed.Load() { + return 0, io.ErrClosedPipe + } + + return a.writeBuffer.Write(buff) +} + +func (a *activeTCPConn) Close() error { + a.closed.Store(true) + _ = a.readBuffer.Close() + _ = a.writeBuffer.Close() + if c, ok := a.conn.Load().(net.Conn); ok { + _ = c.Close() + } + + return nil +} + +func (a *activeTCPConn) LocalAddr() net.Addr { + if v, ok := a.localAddr.Load().(*net.TCPAddr); ok { + return v + } + + return &net.TCPAddr{} +} + +// RemoteAddr returns the remote address of the connection which is only +// set once a background goroutine has successfully dialed. That means +// this may return ":0" for the address prior to that happening. If this +// becomes an issue, we can introduce a synchronization point between Dial +// and these methods. +func (a *activeTCPConn) RemoteAddr() net.Addr { + if v, ok := a.remoteAddr.Load().(*net.TCPAddr); ok { + return v + } + + return &net.TCPAddr{} +} + +func (a *activeTCPConn) SetDeadline(t time.Time) error { + if a.closed.Load() { + return io.EOF + } + if c, ok := a.conn.Load().(net.Conn); ok { + return c.SetDeadline(t) + } + + return io.EOF +} + +func (a *activeTCPConn) SetReadDeadline(t time.Time) error { + if a.closed.Load() { + return io.EOF + } + if c, ok := a.conn.Load().(net.Conn); ok { + return c.SetReadDeadline(t) + } + + return io.EOF +} + +func (a *activeTCPConn) SetWriteDeadline(t time.Time) error { + if a.closed.Load() { + return io.EOF + } + if c, ok := a.conn.Load().(net.Conn); ok { + return c.SetWriteDeadline(t) + } + + return io.EOF +} + +func getTCPAddrOnInterface(address string) (*net.TCPAddr, error) { + addr, err := net.ResolveTCPAddr("tcp", address) + if err != nil { + return nil, err + } + + l, err := net.ListenTCP("tcp", addr) + if err != nil { + return nil, err + } + defer func() { + _ = l.Close() + }() + + tcpAddr, ok := l.Addr().(*net.TCPAddr) + if !ok { + return nil, errInvalidAddress + } + + return tcpAddr, nil +} diff --git a/vendor/github.com/pion/ice/v4/addr.go b/vendor/github.com/pion/ice/v4/addr.go new file mode 100644 index 0000000..3b9f3af --- /dev/null +++ b/vendor/github.com/pion/ice/v4/addr.go @@ -0,0 +1,144 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package ice + +import ( + "fmt" + "net" + "net/netip" +) + +func addrWithOptionalZone(addr netip.Addr, zone string) netip.Addr { + if zone == "" { + return addr + } + if addr.Is6() && (addr.IsLinkLocalUnicast() || addr.IsLinkLocalMulticast()) { + return addr.WithZone(zone) + } + + return addr +} + +// parseAddrFromIface should only be used when it's known the address belongs to that interface. +// e.g. it's LocalAddress on a listener. +func parseAddrFromIface(in net.Addr, ifcName string) (netip.Addr, int, NetworkType, error) { + addr, port, nt, err := parseAddr(in) + if err != nil { + return netip.Addr{}, 0, 0, err + } + if _, ok := in.(*net.IPNet); ok { + // net.IPNet does not have a Zone but we provide it from the interface + addr = addrWithOptionalZone(addr, ifcName) + } + + return addr, port, nt, nil +} + +func parseAddr(in net.Addr) (netip.Addr, int, NetworkType, error) { + host := func(ip net.IP, zone string) (netip.Addr, int, NetworkType, error) { + a, err := ipAddrToNetIP(ip, zone) + if err != nil { + return netip.Addr{}, 0, 0, err + } + + return a, 0, 0, nil + } + + sock := func(ip net.IP, zone string, port int, v4, v6 NetworkType) (netip.Addr, int, NetworkType, error) { + a, err := ipAddrToNetIP(ip, zone) + if err != nil { + return netip.Addr{}, 0, 0, err + } + + nt := v6 + if a.Is4() { + nt = v4 + } + + return a, port, nt, nil + } + + switch a := in.(type) { + case *net.IPNet: + return host(a.IP, "") + case *net.IPAddr: + return host(a.IP, a.Zone) + case *net.UDPAddr: + return sock(a.IP, a.Zone, a.Port, NetworkTypeUDP4, NetworkTypeUDP6) + case *net.TCPAddr: + return sock(a.IP, a.Zone, a.Port, NetworkTypeTCP4, NetworkTypeTCP6) + default: + return netip.Addr{}, 0, 0, addrParseError{in} + } +} + +type addrParseError struct { + addr net.Addr +} + +func (e addrParseError) Error() string { + return fmt.Sprintf("do not know how to parse address type %T", e.addr) +} + +type ipConvertError struct { + ip []byte +} + +func (e ipConvertError) Error() string { + return fmt.Sprintf("failed to convert IP '%s' to netip.Addr", e.ip) +} + +func ipAddrToNetIP(ip []byte, zone string) (netip.Addr, error) { + netIPAddr, ok := netip.AddrFromSlice(ip) + if !ok { + return netip.Addr{}, ipConvertError{ip} + } + // we'd rather have an IPv4-mapped IPv6 become IPv4 so that it is usable. + netIPAddr = netIPAddr.Unmap() + netIPAddr = addrWithOptionalZone(netIPAddr, zone) + + return netIPAddr, nil +} + +func createAddr(network NetworkType, ip netip.Addr, port int) net.Addr { + switch { + case network.IsTCP(): + return &net.TCPAddr{IP: ip.AsSlice(), Port: port, Zone: ip.Zone()} + default: + return &net.UDPAddr{IP: ip.AsSlice(), Port: port, Zone: ip.Zone()} + } +} + +func addrEqual(a, b net.Addr) bool { + aIP, aPort, aType, aErr := parseAddr(a) + if aErr != nil { + return false + } + + bIP, bPort, bType, bErr := parseAddr(b) + if bErr != nil { + return false + } + + return aType == bType && aIP.Compare(bIP) == 0 && aPort == bPort +} + +// AddrPort is an IP and a port number. +type AddrPort [18]byte + +func toAddrPort(addr net.Addr) AddrPort { + var ap AddrPort + switch addr := addr.(type) { + case *net.UDPAddr: + copy(ap[:16], addr.IP.To16()) + ap[16] = uint8(addr.Port >> 8) //nolint:gosec // G115 false positive + ap[17] = uint8(addr.Port) //nolint:gosec // G115 false positive + case *net.TCPAddr: + copy(ap[:16], addr.IP.To16()) + ap[16] = uint8(addr.Port >> 8) //nolint:gosec // G115 false positive + ap[17] = uint8(addr.Port) //nolint:gosec // G115 false positive + } + + return ap +} diff --git a/vendor/github.com/pion/ice/v4/agent.go b/vendor/github.com/pion/ice/v4/agent.go new file mode 100644 index 0000000..ff312c4 --- /dev/null +++ b/vendor/github.com/pion/ice/v4/agent.go @@ -0,0 +1,2053 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +// Package ice implements the Interactive Connectivity Establishment (ICE) +// protocol defined in rfc5245. +package ice + +import ( + "context" + "fmt" + "math" + "net" + "net/netip" + "slices" + "strings" + "sync" + "sync/atomic" + "time" + + stunx "github.com/pion/ice/v4/internal/stun" + "github.com/pion/ice/v4/internal/taskloop" + "github.com/pion/logging" + "github.com/pion/mdns/v2" + "github.com/pion/stun/v3" + "github.com/pion/transport/v4" + "github.com/pion/transport/v4/packetio" + "github.com/pion/transport/v4/stdnet" + "github.com/pion/transport/v4/vnet" + "github.com/pion/turn/v4" + "golang.org/x/net/proxy" +) + +type bindingRequest struct { + timestamp time.Time + transactionID [stun.TransactionIDSize]byte + destination net.Addr + isUseCandidate bool + nominationValue *uint32 // Tracks nomination value for renomination requests +} + +// Agent represents the ICE agent. +type Agent struct { + loop *taskloop.Loop + + // constructed is set to true after the agent is fully initialized. + // Options can check this flag to reject updates that are only valid during construction. + constructed bool + + onConnectionStateChangeHdlr atomic.Value // func(ConnectionState) + onSelectedCandidatePairChangeHdlr atomic.Value // func(Candidate, Candidate) + onCandidateHdlr atomic.Value // func(Candidate) + + onConnected chan struct{} + onConnectedOnce sync.Once + + // Force candidate to be contacted immediately (instead of waiting for task ticker) + forceCandidateContact chan bool + + tieBreaker uint64 + lite bool + + connectionState ConnectionState + gatheringState GatheringState + + mDNSMode MulticastDNSMode + mDNSName string + mDNSConn *mdns.Conn + + muHaveStarted sync.Mutex + startedCh <-chan struct{} + startedFn func() + isControlling atomic.Bool + + maxBindingRequests uint16 + + hostAcceptanceMinWait time.Duration + srflxAcceptanceMinWait time.Duration + prflxAcceptanceMinWait time.Duration + relayAcceptanceMinWait time.Duration + stunGatherTimeout time.Duration + + tcpPriorityOffset uint16 + disableActiveTCP bool + + portMin uint16 + portMax uint16 + + candidateTypes []CandidateType + + // How long connectivity checks can fail before the ICE Agent + // goes to disconnected + disconnectedTimeout time.Duration + + // How long connectivity checks can fail before the ICE Agent + // goes to failed + failedTimeout time.Duration + + // How often should we send keepalive packets? + // 0 means never + keepaliveInterval time.Duration + + // How often should we run our internal taskLoop to check for state changes when connecting + checkInterval time.Duration + + localUfrag string + localPwd string + localCandidates map[NetworkType][]Candidate + + remoteUfrag string + remotePwd string + remoteCandidates map[NetworkType][]Candidate + + checklist []*CandidatePair + nextPairID uint64 + pairsByID map[uint64]*CandidatePair + + selectorLock sync.RWMutex + selector pairCandidateSelector + + selectedPair atomic.Value // *CandidatePair + + urls []*stun.URI + networkTypes []NetworkType + addressRewriteRules []AddressRewriteRule + + buf *packetio.Buffer + + // LRU of outbound Binding request Transaction IDs + pendingBindingRequests []bindingRequest + + // Address rewrite (1:1) IP mapping + addressRewriteMapper *addressRewriteMapper + + // Callback that allows user to implement custom behavior + // for STUN Binding Requests + userBindingRequestHandler func(m *stun.Message, local, remote Candidate, pair *CandidatePair) bool + + gatherCandidateCancel func() + gatherCandidateDone chan struct{} + + connectionStateNotifier *handlerNotifier + candidateNotifier *handlerNotifier + selectedCandidatePairNotifier *handlerNotifier + + loggerFactory logging.LoggerFactory + log logging.LeveledLogger + + net transport.Net + tcpMux TCPMux + udpMux UDPMux + udpMuxSrflx UniversalUDPMux + + interfaceFilter func(string) (keep bool) + ipFilter func(net.IP) (keep bool) + remoteIPFilter func(net.IP) (keep bool) + includeLoopback bool + + insecureSkipVerify bool + + proxyDialer proxy.Dialer + + enableUseCandidateCheckPriority bool + + // Renomination support + enableRenomination bool + nominationValueGenerator func() uint32 + nominationAttribute stun.AttrType + + // Continual gathering support + continualGatheringPolicy ContinualGatheringPolicy + networkMonitorInterval time.Duration + lastKnownInterfaces map[string]netip.Addr // map[iface+ip] for deduplication + + // Automatic renomination + automaticRenomination bool + renominationInterval time.Duration + lastRenominationTime time.Time + + turnClientFactory func(*turn.ClientConfig) (turnClient, error) +} + +// NewAgent creates a new Agent. +// +// Deprecated: use NewAgentWithOptions instead. +func NewAgent(config *AgentConfig) (*Agent, error) { + return newAgentFromConfig(config) +} + +// NewAgentWithOptions creates a new Agent with options only. +func NewAgentWithOptions(opts ...AgentOption) (*Agent, error) { + return newAgentFromConfig(&AgentConfig{}, opts...) +} + +func newAgentFromConfig(config *AgentConfig, opts ...AgentOption) (*Agent, error) { + if config == nil { + config = &AgentConfig{} + } + + agent, err := createAgentBase(config) + if err != nil { + return nil, err + } + + agent.localUfrag = config.LocalUfrag + agent.localPwd = config.LocalPwd + if config.NAT1To1IPs != nil { + if err := validateLegacyNAT1To1IPs(config.NAT1To1IPs); err != nil { + return nil, err + } + + typ := CandidateTypeHost + if config.NAT1To1IPCandidateType != CandidateTypeUnspecified { + typ = config.NAT1To1IPCandidateType + } + + rules, err := legacyNAT1To1Rules(config.NAT1To1IPs, typ) + if err != nil { + return nil, err + } + agent.addressRewriteRules = rules + } + + return newAgentWithConfig(agent, opts...) +} + +func validateLegacyNAT1To1IPs(ips []string) error { + var hasIPv4CatchAll, hasIPv6CatchAll bool + + for _, mapping := range ips { + trimmed := strings.TrimSpace(mapping) + var err error + hasIPv4CatchAll, hasIPv6CatchAll, err = validateLegacyNAT1To1Entry(trimmed, hasIPv4CatchAll, hasIPv6CatchAll) + if err != nil { + return err + } + } + + return nil +} + +func validateLegacyNAT1To1Entry(mapping string, hasIPv4CatchAll, hasIPv6CatchAll bool) (bool, bool, error) { + if mapping == "" { + return hasIPv4CatchAll, hasIPv6CatchAll, nil + } + + parts := strings.Split(mapping, "/") + if len(parts) == 0 || len(parts) > 2 { + return hasIPv4CatchAll, hasIPv6CatchAll, ErrInvalidNAT1To1IPMapping + } + + _, isIPv4, err := validateIPString(parts[0]) + if err != nil { + return hasIPv4CatchAll, hasIPv6CatchAll, err + } + + if len(parts) == 2 { + if _, _, err := validateIPString(strings.TrimSpace(parts[1])); err != nil { + return hasIPv4CatchAll, hasIPv6CatchAll, err + } + + return hasIPv4CatchAll, hasIPv6CatchAll, nil + } + + if isIPv4 { + if hasIPv4CatchAll { + return hasIPv4CatchAll, hasIPv6CatchAll, ErrInvalidNAT1To1IPMapping + } + + return true, hasIPv6CatchAll, nil + } + + if hasIPv6CatchAll { + return hasIPv4CatchAll, hasIPv6CatchAll, ErrInvalidNAT1To1IPMapping + } + + return hasIPv4CatchAll, true, nil +} + +func legacyNAT1To1Rules(ips []string, candidateType CandidateType) ([]AddressRewriteRule, error) { + var rules []AddressRewriteRule + + for _, mapping := range ips { + trimmed := strings.TrimSpace(mapping) + if trimmed == "" { + continue + } + + parts := strings.Split(trimmed, "/") + switch len(parts) { + case 1: + rules = append(rules, AddressRewriteRule{ + External: []string{parts[0]}, + AsCandidateType: candidateType, + }) + case 2: + ext := strings.TrimSpace(parts[0]) + local := strings.TrimSpace(parts[1]) + if ext == "" || local == "" { + return nil, ErrInvalidNAT1To1IPMapping + } + + if _, _, err := validateIPString(ext); err != nil { + return nil, err + } + if _, _, err := validateIPString(local); err != nil { + return nil, err + } + + rules = append(rules, AddressRewriteRule{ + External: []string{ext}, + Local: local, + AsCandidateType: candidateType, + }) + default: + return nil, ErrInvalidNAT1To1IPMapping + } + } + + return rules, nil +} + +func createAgentBase(config *AgentConfig) (*Agent, error) { + if config.PortMax < config.PortMin { + return nil, ErrPort + } + + mDNSName, mDNSMode, err := setupMDNSConfig(config) + if err != nil { + return nil, err + } + + loggerFactory := config.LoggerFactory + if loggerFactory == nil { + loggerFactory = logging.NewDefaultLoggerFactory() + } + log := loggerFactory.NewLogger("ice") + + startedCtx, startedFn := context.WithCancel(context.Background()) + + agent := &Agent{ + tieBreaker: globalMathRandomGenerator.Uint64(), + lite: config.Lite, + gatheringState: GatheringStateNew, + connectionState: ConnectionStateNew, + localCandidates: make(map[NetworkType][]Candidate), + remoteCandidates: make(map[NetworkType][]Candidate), + pairsByID: make(map[uint64]*CandidatePair), + urls: config.Urls, + networkTypes: config.NetworkTypes, + onConnected: make(chan struct{}), + buf: packetio.NewBuffer(), + startedCh: startedCtx.Done(), + startedFn: startedFn, + portMin: config.PortMin, + portMax: config.PortMax, + loggerFactory: loggerFactory, + log: log, + net: config.Net, + proxyDialer: config.ProxyDialer, + tcpMux: config.TCPMux, + udpMux: config.UDPMux, + udpMuxSrflx: config.UDPMuxSrflx, + mDNSMode: mDNSMode, + mDNSName: mDNSName, + gatherCandidateCancel: func() {}, + forceCandidateContact: make(chan bool, 1), + interfaceFilter: config.InterfaceFilter, + ipFilter: config.IPFilter, + remoteIPFilter: config.RemoteIPFilter, + insecureSkipVerify: config.InsecureSkipVerify, + includeLoopback: config.IncludeLoopback, + disableActiveTCP: config.DisableActiveTCP, + userBindingRequestHandler: config.BindingRequestHandler, + enableUseCandidateCheckPriority: config.EnableUseCandidateCheckPriority, + enableRenomination: false, + nominationValueGenerator: nil, + nominationAttribute: stun.AttrType(0x0030), // Default value + continualGatheringPolicy: GatherOnce, // Default to GatherOnce + networkMonitorInterval: 2 * time.Second, + lastKnownInterfaces: make(map[string]netip.Addr), + automaticRenomination: false, + renominationInterval: 3 * time.Second, // Default matching libwebrtc + turnClientFactory: defaultTurnClient, + } + + config.initWithDefaults(agent) + + return agent, nil +} + +func applyAddressRewriteMapping(agent *Agent) error { + mapper, err := newAddressRewriteMapper(agent.addressRewriteRules) + if err != nil { + return err + } + + agent.addressRewriteMapper = mapper + if agent.addressRewriteMapper == nil { + return nil + } + + if agent.addressRewriteMapper.hasCandidateType(CandidateTypeHost) { + // for mDNS QueryAndGather we never advertise rewritten host IPs to avoid + // leaking local addresses, this matches the legacy NAT1:1 behavior. + if agent.mDNSMode == MulticastDNSModeQueryAndGather { + return ErrMulticastDNSWithNAT1To1IPMapping + } + // surface misconfiguration when host candidates are disabled but a host + // rewrite rule was provided. + if !containsCandidateType(CandidateTypeHost, agent.candidateTypes) { + return ErrIneffectiveNAT1To1IPMappingHost + } + } + + if agent.addressRewriteMapper.hasCandidateType(CandidateTypeServerReflexive) { + // surface misconfiguration when srflx candidates are disabled but a srflx + // rewrite rule was provided. + if !containsCandidateType(CandidateTypeServerReflexive, agent.candidateTypes) { + return ErrIneffectiveNAT1To1IPMappingSrflx + } + } + + return nil +} + +// setupMDNSConfig validates and returns mDNS configuration. +func setupMDNSConfig(config *AgentConfig) (string, MulticastDNSMode, error) { + mDNSName := config.MulticastDNSHostName + if mDNSName == "" { + var err error + if mDNSName, err = generateMulticastDNSName(); err != nil { + return "", 0, err + } + } + + if !strings.HasSuffix(mDNSName, ".local") || len(strings.Split(mDNSName, ".")) != 2 { + return "", 0, ErrInvalidMulticastDNSHostName + } + + mDNSMode := config.MulticastDNSMode + if mDNSMode == 0 { + mDNSMode = MulticastDNSModeQueryOnly + } + + return mDNSName, mDNSMode, nil +} + +// newAgentWithConfig finalizes a pre-configured agent with optional overrides. +// +//nolint:gocognit,cyclop +func newAgentWithConfig(agent *Agent, opts ...AgentOption) (*Agent, error) { + var err error + + for _, opt := range opts { + if err = opt(agent); err != nil { + return nil, err + } + } + + agent.connectionStateNotifier = &handlerNotifier{ + connectionStateFunc: agent.onConnectionStateChange, + done: make(chan struct{}), + } + agent.candidateNotifier = &handlerNotifier{candidateFunc: agent.onCandidate, done: make(chan struct{})} + agent.selectedCandidatePairNotifier = &handlerNotifier{ + candidatePairFunc: agent.onSelectedCandidatePairChange, + done: make(chan struct{}), + } + + if agent.net == nil { + agent.net, err = stdnet.NewNet() + if err != nil { + return nil, fmt.Errorf("failed to create network: %w", err) + } + } else if _, isVirtual := agent.net.(*vnet.Net); isVirtual { + agent.log.Warn("Virtual network is enabled") + if agent.mDNSMode != MulticastDNSModeDisabled { + agent.log.Warn("Virtual network does not support mDNS yet") + } + } + + localIfcs, _, err := localInterfaces( + agent.net, + agent.interfaceFilter, + agent.ipFilter, + agent.networkTypes, + agent.includeLoopback, + ) + if err != nil { + return nil, fmt.Errorf("error getting local interfaces: %w", err) + } + + mDNSLocalAddress := mDNSLocalAddressFromTCPMux(agent.tcpMux, agent.networkTypes) + + // Opportunistic mDNS: If we can't open the connection, that's ok: we + // can continue without it. + if agent.mDNSConn, agent.mDNSMode, err = createMulticastDNS( + agent.net, + agent.networkTypes, + localIfcs, + agent.includeLoopback, + mDNSLocalAddress, + agent.mDNSMode, + agent.mDNSName, + agent.log, + agent.loggerFactory, + ); err != nil { + agent.log.Warnf("Failed to initialize mDNS %s: %v", agent.mDNSName, err) + } + + // Make sure the buffer doesn't grow indefinitely. + // NOTE: We actually won't get anywhere close to this limit. + // SRTP will constantly read from the endpoint and drop packets if it's full. + agent.buf.SetLimitSize(maxBufferSize) + + if agent.lite && (len(agent.candidateTypes) != 1 || agent.candidateTypes[0] != CandidateTypeHost) { + agent.closeMulticastConn() + + return nil, ErrLiteUsingNonHostCandidates + } + + if len(agent.urls) > 0 && + !containsCandidateType(CandidateTypeServerReflexive, agent.candidateTypes) && + !containsCandidateType(CandidateTypeRelay, agent.candidateTypes) { + agent.closeMulticastConn() + + return nil, ErrUselessUrlsProvided + } + + if err = applyAddressRewriteMapping(agent); err != nil { + agent.closeMulticastConn() + + return nil, err + } + + agent.loop = taskloop.New(func() { + agent.gatherCandidateCancel() + if agent.gatherCandidateDone != nil { + <-agent.gatherCandidateDone + } + + agent.removeUfragFromMux() + agent.deleteAllCandidates() + agent.startedFn() + + if err := agent.buf.Close(); err != nil { + agent.log.Warnf("Failed to close buffer: %v", err) + } + + agent.closeMulticastConn() + agent.updateConnectionState(ConnectionStateClosed) + }) + + // Restart is also used to initialize the agent for the first time + if err := agent.Restart(agent.localUfrag, agent.localPwd); err != nil { + agent.closeMulticastConn() + _ = agent.Close() + + return nil, err + } + + agent.constructed = true + + return agent, nil +} + +func mDNSLocalAddressFromTCPMux(tcpMux TCPMux, networkTypes []NetworkType) net.IP { + if tcpMux == nil || !allNetworkTypesTCP(networkTypes) { + return nil + } + + tcpAddr, ok := localTCPAddrFromMux(tcpMux) + if !ok { + return nil + } + + localAddr, ok := mDNSLocalAddressFromIP(tcpAddr.IP) + if !ok { + return nil + } + + return localAddr +} + +func allNetworkTypesTCP(networkTypes []NetworkType) bool { + if len(networkTypes) == 0 { + return false + } + + for _, networkType := range networkTypes { + if !networkType.IsTCP() { + return false + } + } + + return true +} + +func localTCPAddrFromMux(tcpMux TCPMux) (*net.TCPAddr, bool) { + addrProvider, ok := tcpMux.(interface{ LocalAddr() net.Addr }) + if !ok { + return nil, false + } + + tcpAddr, ok := addrProvider.LocalAddr().(*net.TCPAddr) + if !ok || tcpAddr.IP == nil || tcpAddr.IP.IsUnspecified() { + return nil, false + } + + return tcpAddr, true +} + +func mDNSLocalAddressFromIP(ip net.IP) (net.IP, bool) { + parsed, ok := netip.AddrFromSlice(ip) + if !ok { + return nil, false + } + + parsed = parsed.Unmap() + if parsed.Is6() && (parsed.IsLinkLocalUnicast() || parsed.IsLinkLocalMulticast()) { + // mdns.Config.LocalAddress has no zone support for link-local IPv6. + return nil, false + } + + return parsed.AsSlice(), true +} + +func (a *Agent) startConnectivityChecks(isControlling bool, remoteUfrag, remotePwd string) error { + a.muHaveStarted.Lock() + defer a.muHaveStarted.Unlock() + select { + case <-a.startedCh: + return ErrMultipleStart + default: + } + if err := a.SetRemoteCredentials(remoteUfrag, remotePwd); err != nil { //nolint:contextcheck + return err + } + + a.log.Debugf("Started agent: isControlling? %t, remoteUfrag: %q, remotePwd: %q", isControlling, remoteUfrag, remotePwd) + + return a.loop.Run(a.loop, func(_ context.Context) { + a.isControlling.Store(isControlling) + a.remoteUfrag = remoteUfrag + a.remotePwd = remotePwd + a.setSelector() + + a.startedFn() + + a.updateConnectionState(ConnectionStateChecking) + + a.requestConnectivityCheck() + go a.connectivityChecks() //nolint:contextcheck + }) +} + +func (a *Agent) connectivityChecks() { //nolint:cyclop + lastConnectionState := ConnectionState(0) + checkingDuration := time.Time{} + + contact := func() { + if err := a.loop.Run(a.loop, func(_ context.Context) { + defer func() { + lastConnectionState = a.connectionState + }() + + switch a.connectionState { + case ConnectionStateFailed: + // The connection is currently failed so don't send any checks + // In the future it may be restarted though + return + case ConnectionStateChecking: + // We have just entered checking for the first time so update our checking timer + if lastConnectionState != a.connectionState { + checkingDuration = time.Now() + } + + // We have been in checking longer then Disconnect+Failed timeout, set the connection to Failed + if time.Since(checkingDuration) > a.disconnectedTimeout+a.failedTimeout { + a.updateConnectionState(ConnectionStateFailed) + + return + } + default: + } + + a.getSelector().ContactCandidates() + }); err != nil { + a.log.Warnf("Failed to start connectivity checks: %v", err) + } + } + + timer := time.NewTimer(math.MaxInt64) + timer.Stop() + + for { + interval := defaultKeepaliveInterval + + updateInterval := func(x time.Duration) { + if x != 0 && (interval == 0 || interval > x) { + interval = x + } + } + + switch lastConnectionState { + case ConnectionStateNew, ConnectionStateChecking: // While connecting, check candidates more frequently + updateInterval(a.checkInterval) + case ConnectionStateConnected, ConnectionStateDisconnected: + updateInterval(a.keepaliveInterval) + default: + } + // Ensure we run our task loop as quickly as the minimum of our various configured timeouts + updateInterval(a.disconnectedTimeout) + updateInterval(a.failedTimeout) + + timer.Reset(interval) + + select { + case <-a.forceCandidateContact: + if !timer.Stop() { + <-timer.C + } + contact() + case <-timer.C: + contact() + case <-a.loop.Done(): + timer.Stop() + + return + } + } +} + +func (a *Agent) updateConnectionState(newState ConnectionState) { + if a.connectionState != newState { + // Connection has gone to failed, release all gathered candidates + if newState == ConnectionStateFailed { + a.removeUfragFromMux() + a.checklist = make([]*CandidatePair, 0) + a.pairsByID = make(map[uint64]*CandidatePair) + a.pendingBindingRequests = make([]bindingRequest, 0) + a.setSelectedPair(nil) + a.deleteAllCandidates() + } + + a.log.Infof("Setting new connection state: %s", newState) + a.connectionState = newState + a.connectionStateNotifier.EnqueueConnectionState(newState) + } +} + +func (a *Agent) setSelectedPair(pair *CandidatePair) { + if pair == nil { + var nilPair *CandidatePair + a.selectedPair.Store(nilPair) + a.log.Tracef("Unset selected candidate pair") + + return + } + + pair.nominated = true + a.selectedPair.Store(pair) + a.log.Tracef("Set selected candidate pair: %s", pair) + + // Signal connected: notify any Connect() calls waiting on onConnected + a.onConnectedOnce.Do(func() { close(a.onConnected) }) + + // Update connection state to Connected and notify state change handlers + a.updateConnectionState(ConnectionStateConnected) + + // Notify when the selected candidate pair changes + a.selectedCandidatePairNotifier.EnqueueSelectedCandidatePair(pair) +} + +func (a *Agent) pingAllCandidates() { + a.log.Trace("Pinging all candidates") + + if len(a.checklist) == 0 { + a.log.Warn("Failed to ping without candidate pairs. Connection is not possible yet.") + } + + for _, p := range a.checklist { + if p.state == CandidatePairStateWaiting { + p.state = CandidatePairStateInProgress + } else if p.state != CandidatePairStateInProgress { + continue + } + + if p.bindingRequestCount > a.maxBindingRequests { + a.log.Tracef("Maximum requests reached for pair %s, marking it as failed", p) + p.state = CandidatePairStateFailed + } else { + a.getSelector().PingCandidate(p.Local, p.Remote) + p.bindingRequestCount++ + } + } +} + +// keepAliveCandidatesForRenomination pings all candidate pairs to keep them tested +// and ready for automatic renomination. Unlike pingAllCandidates, this: +// - Pings pairs in succeeded state to keep RTT measurements fresh +// - Ignores maxBindingRequests limit (we want to keep testing alternate paths) +// - Only pings pairs that are not failed. +func (a *Agent) keepAliveCandidatesForRenomination() { + a.log.Trace("Keep alive candidates for automatic renomination") + + if len(a.checklist) == 0 { + return + } + + for _, pair := range a.checklist { + switch pair.state { + case CandidatePairStateFailed: + // Skip failed pairs + continue + case CandidatePairStateWaiting: + // Transition waiting pairs to in-progress + pair.state = CandidatePairStateInProgress + case CandidatePairStateInProgress, CandidatePairStateSucceeded: + // Continue pinging in-progress and succeeded pairs + } + + // Ping all non-failed pairs (including succeeded ones) + // to keep RTT measurements fresh for renomination decisions + a.getSelector().PingCandidate(pair.Local, pair.Remote) + } +} + +func (a *Agent) getBestAvailableCandidatePair() *CandidatePair { + var best *CandidatePair + for _, p := range a.checklist { + if p.state == CandidatePairStateFailed { + continue + } + + if best == nil { + best = p + } else if best.priority() < p.priority() { + best = p + } + } + + return best +} + +func (a *Agent) getBestValidCandidatePair() *CandidatePair { + var best *CandidatePair + for _, p := range a.checklist { + if p.state != CandidatePairStateSucceeded { + continue + } + + if best == nil { + best = p + } else if best.priority() < p.priority() { + best = p + } + } + + return best +} + +func (a *Agent) addPair(local, remote Candidate) *CandidatePair { + a.nextPairID++ + p := newCandidatePair(local, remote, a.isControlling.Load()) + p.id = a.nextPairID + a.checklist = append(a.checklist, p) + a.pairsByID[p.id] = p + + return p +} + +func (a *Agent) findPair(local, remote Candidate) *CandidatePair { + for _, p := range a.checklist { + if p.Local.Equal(local) && p.Remote.Equal(remote) { + return p + } + } + + return nil +} + +// validateSelectedPair checks if the selected pair is (still) valid +// Note: the caller should hold the agent lock. +func (a *Agent) validateSelectedPair() bool { + selectedPair := a.getSelectedPair() + if selectedPair == nil { + return false + } + + disconnectedTime := time.Since(selectedPair.Remote.LastReceived()) + + // Only allow transitions to failed if a.failedTimeout is non-zero + totalTimeToFailure := a.failedTimeout + if totalTimeToFailure != 0 { + totalTimeToFailure += a.disconnectedTimeout + } + + a.updateConnectionState(a.connectionStateForDisconnection(disconnectedTime, totalTimeToFailure)) + + return true +} + +func (a *Agent) connectionStateForDisconnection( + disconnectedTime time.Duration, + totalTimeToFailure time.Duration, +) ConnectionState { + disconnected := a.disconnectedTimeout != 0 && disconnectedTime > a.disconnectedTimeout + failed := totalTimeToFailure != 0 && disconnectedTime > totalTimeToFailure + + switch { + case failed: + if disconnected && a.connectionState != ConnectionStateDisconnected && a.connectionState != ConnectionStateFailed { + // If we never reported disconnected but both thresholds are already exceeded, + // emit disconnected first so callers can observe both transitions. + return ConnectionStateDisconnected + } + + return ConnectionStateFailed + case disconnected: + return ConnectionStateDisconnected + default: + return ConnectionStateConnected + } +} + +// checkKeepalive sends STUN Binding Indications to the selected pair +// if no packet has been sent on that pair in the last keepaliveInterval +// Note: the caller should hold the agent lock. +func (a *Agent) checkKeepalive() { + selectedPair := a.getSelectedPair() + if selectedPair == nil { + return + } + + if a.keepaliveInterval != 0 { + // We use binding request instead of indication to support refresh consent schemas + // see https://tools.ietf.org/html/rfc7675 + a.getSelector().PingCandidate(selectedPair.Local, selectedPair.Remote) + } +} + +// AddRemoteCandidate adds a new remote candidate. +func (a *Agent) AddRemoteCandidate(cand Candidate) error { + if cand == nil { + return nil + } + + // TCP Candidates with TCP type active will probe server passive ones, so + // no need to do anything with them. + if cand.TCPType() == TCPTypeActive { + a.log.Infof("Ignoring remote candidate with tcpType active: %s", cand) + + return nil + } + + // If we have a mDNS Candidate lets fully resolve it before adding it locally + if cand.Type() == CandidateTypeHost && strings.HasSuffix(cand.Address(), ".local") { + if a.mDNSMode == MulticastDNSModeDisabled { + a.log.Warnf("Remote mDNS candidate added, but mDNS is disabled: (%s)", cand.Address()) + + return nil + } + + hostCandidate, ok := cand.(*CandidateHost) + if !ok { + return ErrAddressParseFailed + } + + go a.resolveAndAddMulticastCandidate(hostCandidate) + + return nil + } + + go func() { + if err := a.loop.Run(a.loop, func(_ context.Context) { + // nolint: contextcheck + a.addRemoteCandidate(cand) + }); err != nil { + a.log.Warnf("Failed to add remote candidate %s: %v", cand.Address(), err) + + return + } + }() + + return nil +} + +func (a *Agent) resolveAndAddMulticastCandidate(cand *CandidateHost) { + if a.mDNSConn == nil { + return + } + + ctx, cancel := context.WithTimeout(a.loop, a.mDNSQueryTimeout()) + defer cancel() + + _, src, err := a.mDNSConn.QueryAddr(ctx, cand.Address()) + if err != nil { + a.log.Warnf("Failed to discover mDNS candidate %s: %v", cand.Address(), err) + + return + } + + if err = cand.setIPAddr(src); err != nil { + a.log.Warnf("Failed to discover mDNS candidate %s: %v", cand.Address(), err) + + return + } + + if err = a.loop.Run(a.loop, func(_ context.Context) { + // nolint: contextcheck + a.addRemoteCandidate(cand) + }); err != nil { + a.log.Warnf("Failed to add mDNS candidate %s: %v", cand.Address(), err) + + return + } +} + +func (a *Agent) mDNSQueryTimeout() time.Duration { + if a.stunGatherTimeout > 0 { + return a.stunGatherTimeout + } + + return defaultSTUNGatherTimeout +} + +func (a *Agent) requestConnectivityCheck() { + select { + case a.forceCandidateContact <- true: + default: + } +} + +func (a *Agent) addRemotePassiveTCPCandidate(remoteCandidate Candidate) { + _, localIPs, err := localInterfaces( + a.net, + a.interfaceFilter, + a.ipFilter, + []NetworkType{remoteCandidate.NetworkType()}, + a.includeLoopback, + ) + if err != nil { + a.log.Warnf("Failed to iterate local interfaces, host candidates will not be gathered %s", err) + + return + } + + for i := range localIPs { + ip, _, _, err := parseAddr(remoteCandidate.addr()) + if err != nil { + a.log.Warnf("Failed to parse address: %s; error: %s", remoteCandidate.addr(), err) + + continue + } + + dialIP := remoteDialIPForLocalInterface(ip, localIPs[i].addr) + + conn := newActiveTCPConn( + a.loop, + net.JoinHostPort(localIPs[i].addr.String(), "0"), + netip.AddrPortFrom(dialIP, uint16(remoteCandidate.Port())), //nolint:gosec // G115, no overflow, a port + a.log, + ) + + tcpAddr, ok := conn.LocalAddr().(*net.TCPAddr) + if !ok { + closeConnAndLog(conn, a.log, "Failed to create Active ICE-TCP Candidate: %v", errInvalidAddress) + + continue + } + + localCandidate, err := NewCandidateHost(&CandidateHostConfig{ + Network: remoteCandidate.NetworkType().String(), + Address: localIPs[i].addr.String(), + Port: tcpAddr.Port, + Component: ComponentRTP, + TCPType: TCPTypeActive, + }) + if err != nil { + closeConnAndLog(conn, a.log, "Failed to create Active ICE-TCP Candidate: %v", err) + + continue + } + + localCandidate.start(a, conn, a.startedCh) + a.localCandidates[localCandidate.NetworkType()] = append( + a.localCandidates[localCandidate.NetworkType()], + localCandidate, + ) + a.candidateNotifier.EnqueueCandidate(localCandidate) + + a.addPair(localCandidate, remoteCandidate) + } +} + +func remoteDialIPForLocalInterface(remoteIP, localIP netip.Addr) netip.Addr { + if remoteIP.Is6() && + remoteIP.Zone() == "" && + (remoteIP.IsLinkLocalUnicast() || remoteIP.IsLinkLocalMulticast()) { + if zone := localIP.Zone(); zone != "" { + return remoteIP.WithZone(zone) + } + } + + return remoteIP +} + +// addRemoteCandidate assumes you are holding the lock (must be execute using a.run). +// Returns true when the candidate is accepted (including duplicates). +func (a *Agent) addRemoteCandidate(cand Candidate) bool { //nolint:cyclop + if !a.shouldAcceptRemoteCandidate(cand) { + return false + } + + if len(a.networkTypes) > 0 && !slices.Contains(a.networkTypes, cand.NetworkType()) { + a.log.Infof("Ignoring remote candidate with disabled network type %s: %s", cand.NetworkType(), cand) + + return false + } + + set := a.remoteCandidates[cand.NetworkType()] + + for _, candidate := range set { + if candidate.Equal(cand) { + return true + } + } + + acceptRemotePassiveTCPCandidate := false + // Assert that TCP4 or TCP6 is a enabled NetworkType locally + if !a.disableActiveTCP && cand.TCPType() == TCPTypePassive { + for _, networkType := range a.networkTypes { + if cand.NetworkType() == networkType { + acceptRemotePassiveTCPCandidate = true + } + } + } + + if acceptRemotePassiveTCPCandidate { + a.addRemotePassiveTCPCandidate(cand) + } + + set = append(set, cand) + a.remoteCandidates[cand.NetworkType()] = set + + if cand.TCPType() != TCPTypePassive { + if localCandidates, ok := a.localCandidates[cand.NetworkType()]; ok { + for _, localCandidate := range localCandidates { + a.addPair(localCandidate, cand) + } + } + } + + a.requestConnectivityCheck() + + return true +} + +func (a *Agent) shouldAcceptRemoteCandidate(cand Candidate) bool { + if a.remoteIPFilter == nil { + return true + } + + ipAddr, _, _, err := parseAddr(cand.addr()) + if err != nil { + a.log.Warnf("Ignoring remote candidate with unparsable address %q: %v", cand.addr(), err) + + return false + } + + if !a.remoteIPFilter(ipAddr.AsSlice()) { + a.log.Warnf("Ignoring remote candidate filtered by remote IP policy: %s", cand) + + return false + } + + return true +} + +func (a *Agent) addCandidate(ctx context.Context, cand Candidate, candidateConn net.PacketConn) error { + if err := ctx.Err(); err != nil { + return err + } + + return a.loop.Run(ctx, func(context.Context) { + set := a.localCandidates[cand.NetworkType()] + for _, candidate := range set { + if candidate.Equal(cand) { + a.log.Debugf("Ignore duplicate candidate: %s", cand) + if err := cand.close(); err != nil { + a.log.Warnf("Failed to close duplicate candidate: %v", err) + } + if err := candidateConn.Close(); err != nil { + a.log.Warnf("Failed to close duplicate candidate connection: %v", err) + } + + return + } + } + + a.setCandidateExtensions(cand) + cand.start(a, candidateConn, a.startedCh) + + set = append(set, cand) + a.localCandidates[cand.NetworkType()] = set + + if remoteCandidates, ok := a.remoteCandidates[cand.NetworkType()]; ok { + for _, remoteCandidate := range remoteCandidates { + a.addPair(cand, remoteCandidate) + } + } + + a.requestConnectivityCheck() + + if !cand.filterForLocationTracking() { + a.candidateNotifier.EnqueueCandidate(cand) + } + }) +} + +func (a *Agent) setCandidateExtensions(cand Candidate) { + err := cand.AddExtension(CandidateExtension{ + Key: "ufrag", + Value: a.localUfrag, + }) + if err != nil { + a.log.Errorf("Failed to add ufrag extension to candidate: %v", err) + } +} + +// GetRemoteCandidates returns the remote candidates. +func (a *Agent) GetRemoteCandidates() ([]Candidate, error) { + var res []Candidate + + err := a.loop.Run(a.loop, func(_ context.Context) { + var candidates []Candidate + for _, set := range a.remoteCandidates { + candidates = append(candidates, set...) + } + res = candidates + }) + if err != nil { + return nil, err + } + + return res, nil +} + +// GetLocalCandidates returns the local candidates. +func (a *Agent) GetLocalCandidates() ([]Candidate, error) { + var res []Candidate + + err := a.loop.Run(a.loop, func(_ context.Context) { + var candidates []Candidate + for _, set := range a.localCandidates { + for _, c := range set { + if c.filterForLocationTracking() { + continue + } + candidates = append(candidates, c) + } + } + res = candidates + }) + if err != nil { + return nil, err + } + + return res, nil +} + +// GetGatheringState returns the current gathering state of the Agent. +func (a *Agent) GetGatheringState() (GatheringState, error) { + var state GatheringState + err := a.loop.Run(a.loop, func(_ context.Context) { + state = a.gatheringState + }) + if err != nil { + return GatheringStateUnknown, err + } + + return state, nil +} + +// GetLocalUserCredentials returns the local user credentials. +func (a *Agent) GetLocalUserCredentials() (frag string, pwd string, err error) { + valSet := make(chan struct{}) + err = a.loop.Run(a.loop, func(_ context.Context) { + frag = a.localUfrag + pwd = a.localPwd + close(valSet) + }) + + if err == nil { + <-valSet + } + + return +} + +// GetRemoteUserCredentials returns the remote user credentials. +func (a *Agent) GetRemoteUserCredentials() (frag string, pwd string, err error) { + valSet := make(chan struct{}) + err = a.loop.Run(a.loop, func(_ context.Context) { + frag = a.remoteUfrag + pwd = a.remotePwd + close(valSet) + }) + + if err == nil { + <-valSet + } + + return +} + +func (a *Agent) removeUfragFromMux() { + if a.tcpMux != nil { + a.tcpMux.RemoveConnByUfrag(a.localUfrag) + } + if a.udpMux != nil { + a.udpMux.RemoveConnByUfrag(a.localUfrag) + } + if a.udpMuxSrflx != nil { + a.udpMuxSrflx.RemoveConnByUfrag(a.localUfrag) + } +} + +// Close cleans up the Agent. +func (a *Agent) Close() error { + return a.close(false) +} + +// GracefulClose cleans up the Agent and waits for any goroutines it started +// to complete. This is only safe to call outside of Agent callbacks or if in a callback, +// in its own goroutine. +func (a *Agent) GracefulClose() error { + return a.close(true) +} + +func (a *Agent) close(graceful bool) error { + // the loop is safe to wait on no matter what + a.loop.Close() + + // but we are in less control of the notifiers, so we will + // pass through `graceful`. + a.connectionStateNotifier.Close(graceful) + a.candidateNotifier.Close(graceful) + a.selectedCandidatePairNotifier.Close(graceful) + + return nil +} + +// Remove all candidates. This closes any listening sockets +// and removes both the local and remote candidate lists. +// +// This is used for restarts, failures and on close. +func (a *Agent) deleteAllCandidates() { + for net, cs := range a.localCandidates { + for _, c := range cs { + if err := c.close(); err != nil { + a.log.Warnf("Failed to close candidate %s: %v", c, err) + } + } + delete(a.localCandidates, net) + } + for net, cs := range a.remoteCandidates { + for _, c := range cs { + if err := c.close(); err != nil { + a.log.Warnf("Failed to close candidate %s: %v", c, err) + } + } + delete(a.remoteCandidates, net) + } +} + +func (a *Agent) findRemoteCandidate(networkType NetworkType, addr net.Addr) Candidate { + ip, port, _, err := parseAddr(addr) + if err != nil { + a.log.Warnf("Failed to parse address: %s; error: %s", addr, err) + + return nil + } + + set := a.remoteCandidates[networkType] + for _, c := range set { + if c.Address() == ip.String() && c.Port() == port { + return c + } + } + + return nil +} + +func (a *Agent) sendBindingRequest(msg *stun.Message, local, remote Candidate) { + a.log.Tracef("Ping STUN from %s to %s", local, remote) + + // Extract nomination value if present + var nominationValue *uint32 + var nomination NominationAttribute + if err := nomination.GetFromWithType(msg, a.nominationAttribute); err == nil { + nominationValue = &nomination.Value + } + + a.invalidatePendingBindingRequests(time.Now()) + a.pendingBindingRequests = append(a.pendingBindingRequests, bindingRequest{ + timestamp: time.Now(), + transactionID: msg.TransactionID, + destination: remote.addr(), + isUseCandidate: msg.Contains(stun.AttrUseCandidate), + nominationValue: nominationValue, + }) + + if pair := a.findPair(local, remote); pair != nil { + pair.UpdateRequestSent() + } else { + a.log.Warnf("Failed to find pair for add binding request from %s to %s", local, remote) + } + a.sendSTUN(msg, local, remote) +} + +func (a *Agent) sendBindingSuccess(m *stun.Message, local, remote Candidate) { + base := remote + + ip, port, _, err := parseAddr(base.addr()) + if err != nil { + a.log.Warnf("Failed to parse address: %s; error: %s", base.addr(), err) + + return + } + + if out, err := stun.Build(m, stun.BindingSuccess, + &stun.XORMappedAddress{ + IP: ip.AsSlice(), + Port: port, + }, + stun.NewShortTermIntegrity(a.localPwd), + stun.Fingerprint, + ); err != nil { + a.log.Warnf("Failed to handle inbound ICE from: %s to: %s error: %s", local, remote, err) + } else { + if pair := a.findPair(local, remote); pair != nil { + pair.UpdateResponseSent() + } else { + a.log.Warnf("Failed to find pair for add binding response from %s to %s", local, remote) + } + a.sendSTUN(out, local, remote) + } +} + +// Removes pending binding requests that are over maxBindingRequestTimeout old +// +// Let HTO be the transaction timeout, which SHOULD be 2*RTT if +// RTT is known or 500 ms otherwise. +// https://tools.ietf.org/html/rfc8445#appendix-B.1 +func (a *Agent) invalidatePendingBindingRequests(filterTime time.Time) { + initialSize := len(a.pendingBindingRequests) + + temp := a.pendingBindingRequests[:0] + for _, bindingRequest := range a.pendingBindingRequests { + if filterTime.Sub(bindingRequest.timestamp) < maxBindingRequestTimeout { + temp = append(temp, bindingRequest) + } + } + + a.pendingBindingRequests = temp + if bindRequestsRemoved := initialSize - len(a.pendingBindingRequests); bindRequestsRemoved > 0 { + a.log.Tracef("Discarded %d binding requests because they expired", bindRequestsRemoved) + } +} + +// Assert that the passed TransactionID is in our pendingBindingRequests and returns the destination +// If the bindingRequest was valid remove it from our pending cache. +func (a *Agent) handleInboundBindingSuccess(id [stun.TransactionIDSize]byte) (bool, *bindingRequest, time.Duration) { + a.invalidatePendingBindingRequests(time.Now()) + for i := range a.pendingBindingRequests { + if a.pendingBindingRequests[i].transactionID == id { + validBindingRequest := a.pendingBindingRequests[i] + a.pendingBindingRequests = append(a.pendingBindingRequests[:i], a.pendingBindingRequests[i+1:]...) + + return true, &validBindingRequest, time.Since(validBindingRequest.timestamp) + } + } + + return false, nil, 0 +} + +func (a *Agent) handleRoleConflict(msg *stun.Message, local, remote Candidate, remoteTieBreaker *AttrControl) { + localIsGreaterOrEqual := a.tieBreaker >= remoteTieBreaker.Tiebreaker + a.log.Warnf("Role conflict local and remote same role(%s), localIsGreaterOrEqual(%t)", a.role(), localIsGreaterOrEqual) + + // https://datatracker.ietf.org/doc/html/rfc8445#section-7.3.1.1 + // An agent MUST examine the Binding request for either the ICE- + // CONTROLLING or ICE-CONTROLLED attribute. It MUST follow these + // procedures: + + // If the agent's tiebreaker value is larger than or equal to the contents of the ICE-CONTROLLING attribute + // If the agent's tiebreaker value is less than the contents of the ICE-CONTROLLED attribute + // the agent generates a Binding error response + if (a.isControlling.Load() && localIsGreaterOrEqual) || (!a.isControlling.Load() && !localIsGreaterOrEqual) { + if roleConflictMsg, err := stun.Build(msg, stun.BindingError, + stun.ErrorCodeAttribute{ + Code: stun.CodeRoleConflict, + Reason: []byte("Role Conflict"), + }, + stun.NewShortTermIntegrity(a.localPwd), + stun.Fingerprint, + ); err != nil { + a.log.Warnf("Failed to generate Role Conflict message from: %s to: %s error: %s", local, remote, err) + } else { + a.sendSTUN(roleConflictMsg, local, remote) + } + } else { + a.isControlling.Store(!a.isControlling.Load()) + a.setSelector() + } +} + +// handleInbound processes STUN traffic from a remote candidate. +func (a *Agent) handleInbound(msg *stun.Message, local Candidate, remote net.Addr) { + if msg == nil || local == nil { + return + } + + if !canHandleInbound(msg) { + a.log.Tracef("Unhandled STUN from %s to %s class(%s) method(%s)", remote, local, msg.Type.Class, msg.Type.Method) + + return + } + + remoteCandidate := a.findRemoteCandidate(local.NetworkType(), remote) + + switch msg.Type.Class { + case stun.ClassSuccessResponse: + if !a.handleInboundResponse(remoteCandidate, local, remote, msg) { + return + } + case stun.ClassRequest: + var ok bool + if remoteCandidate, ok = a.handleInboundRequest(remoteCandidate, local, remote, msg); !ok { + return + } + default: + } + + if remoteCandidate != nil { + remoteCandidate.seen(false) + } +} + +func canHandleInbound(msg *stun.Message) bool { + return msg.Type.Method == stun.MethodBinding && + (msg.Type.Class == stun.ClassSuccessResponse || + msg.Type.Class == stun.ClassRequest || + msg.Type.Class == stun.ClassIndication) +} + +func (a *Agent) handleInboundResponse( + remoteCandidate, local Candidate, remote net.Addr, msg *stun.Message, +) bool { + if err := stun.MessageIntegrity([]byte(a.remotePwd)).Check(msg); err != nil { + a.log.Warnf("Discard success response with broken integrity from (%s), %v", remote, err) + + return false + } + + if remoteCandidate == nil { + a.log.Warnf("Discard success message from (%s), no such remote", remote) + + return false + } + + a.getSelector().HandleSuccessResponse(msg, local, remoteCandidate, remote) + + return true +} + +func (a *Agent) handleInboundRequest( + remoteCandidate, local Candidate, remote net.Addr, msg *stun.Message, +) (remoteCand Candidate, ok bool) { + a.log.Tracef( + "Inbound STUN (Request) from %s to %s, useCandidate: %v", + remote, + local, + msg.Contains(stun.AttrUseCandidate), + ) + + if err := stunx.AssertUsername(msg, a.localUfrag+":"+a.remoteUfrag); err != nil { + a.log.Warnf("Discard request with wrong username from (%s), %v", remote, err) + + return nil, false + } else if err := stun.MessageIntegrity([]byte(a.localPwd)).Check(msg); err != nil { + a.log.Warnf("Discard request with broken integrity from (%s), %v", remote, err) + + return nil, false + } + + if remoteCandidate == nil { + ip, port, networkType, err := parseAddr(remote) + if err != nil { + a.log.Errorf("Failed to create parse remote net.Addr when creating remote prflx candidate: %s", err) + + return nil, false + } + + prflxCandidateConfig := CandidatePeerReflexiveConfig{ + Network: networkType.String(), + Address: ip.String(), + Port: port, + Component: local.Component(), + RelAddr: "", + RelPort: 0, + } + + prflxCandidate, err := NewCandidatePeerReflexive(&prflxCandidateConfig) + if err != nil { + a.log.Errorf("Failed to create new remote prflx candidate (%s)", err) + + return nil, false + } + remoteCandidate = prflxCandidate + + a.log.Debugf("Adding a new peer-reflexive candidate: %s ", remote) + if !a.addRemoteCandidate(remoteCandidate) { + return nil, false + } + } + + // Support Remotes that don't set a TIE-BREAKER. Not standards compliant, but + // keeping to maintain backwards compat + remoteTieBreaker := &AttrControl{} + if err := remoteTieBreaker.GetFrom(msg); err == nil && remoteTieBreaker.Role == a.role() { + a.handleRoleConflict(msg, local, remoteCandidate, remoteTieBreaker) + + return nil, false + } + + a.getSelector().HandleBindingRequest(msg, local, remoteCandidate) + + return remoteCandidate, true +} + +// validateNonSTUNTraffic processes non STUN traffic from a remote candidate, +// and returns true if it is an actual remote candidate. +func (a *Agent) validateNonSTUNTraffic(local Candidate, remote net.Addr) (Candidate, bool) { + var remoteCandidate Candidate + if err := a.loop.Run(local.context(), func(context.Context) { + remoteCandidate = a.findRemoteCandidate(local.NetworkType(), remote) + if remoteCandidate != nil { + remoteCandidate.seen(false) + } + }); err != nil { + a.log.Warnf("Failed to validate remote candidate: %v", err) + } + + return remoteCandidate, remoteCandidate != nil +} + +// GetSelectedCandidatePair returns the selected pair or nil if there is none. +func (a *Agent) GetSelectedCandidatePair() (*CandidatePair, error) { + selectedPair := a.getSelectedPair() + if selectedPair == nil { + return nil, nil //nolint:nilnil + } + + local, err := selectedPair.Local.copy() + if err != nil { + return nil, err + } + + remote, err := selectedPair.Remote.copy() + if err != nil { + return nil, err + } + + return &CandidatePair{Local: local, Remote: remote}, nil +} + +func (a *Agent) getSelectedPair() *CandidatePair { + if selectedPair, ok := a.selectedPair.Load().(*CandidatePair); ok { + return selectedPair + } + + return nil +} + +func (a *Agent) closeMulticastConn() { + if a.mDNSConn != nil { + if err := a.mDNSConn.Close(); err != nil { + a.log.Warnf("Failed to close mDNS Conn: %v", err) + } + } +} + +// SetRemoteCredentials sets the credentials of the remote agent. +func (a *Agent) SetRemoteCredentials(remoteUfrag, remotePwd string) error { + switch { + case remoteUfrag == "": + return ErrRemoteUfragEmpty + case remotePwd == "": + return ErrRemotePwdEmpty + } + + return a.loop.Run(a.loop, func(_ context.Context) { + a.remoteUfrag = remoteUfrag + a.remotePwd = remotePwd + }) +} + +// UpdateOptions applies the given options to the agent at runtime. +// Only a subset of options can be updated after agent creation: +// - WithUrls: updates STUN/TURN server URLs (takes effect on next GatherCandidates call) +// +// Returns an error if the agent is closed or if an unsupported option is provided. +func (a *Agent) UpdateOptions(opts ...AgentOption) error { + var optErr error + + err := a.loop.Run(a.loop, func(_ context.Context) { + for _, opt := range opts { + if optErr = opt(a); optErr != nil { + return + } + } + }) + if err != nil { + return err + } + + return optErr +} + +// Restart restarts the ICE Agent with the provided ufrag/pwd +// If no ufrag/pwd is provided the Agent will generate one itself +// +// If there is a gatherer routine currently running, Restart will +// cancel it. +// After a Restart, the user must then call GatherCandidates explicitly +// to start generating new ones. +func (a *Agent) Restart(ufrag, pwd string) error { //nolint:cyclop + if ufrag == "" { + var err error + ufrag, err = generateUFrag() + if err != nil { + return err + } + } + if pwd == "" { + var err error + pwd, err = generatePwd() + if err != nil { + return err + } + } + + if len([]rune(ufrag))*8 < 24 { + return ErrLocalUfragInsufficientBits + } + if len([]rune(pwd))*8 < 128 { + return ErrLocalPwdInsufficientBits + } + + var err error + if runErr := a.loop.Run(a.loop, func(_ context.Context) { + if a.gatheringState == GatheringStateGathering { + a.gatherCandidateCancel() + } + + // Clear all agent needed to take back to fresh state + a.removeUfragFromMux() + a.localUfrag = ufrag + a.localPwd = pwd + a.remoteUfrag = "" + a.remotePwd = "" + a.gatheringState = GatheringStateNew + a.checklist = make([]*CandidatePair, 0) + a.pairsByID = make(map[uint64]*CandidatePair) + a.pendingBindingRequests = make([]bindingRequest, 0) + a.setSelectedPair(nil) + a.deleteAllCandidates() + a.setSelector() + + // Restart is used by NewAgent. Accept/Connect should be used to move to checking + // for new Agents + if a.connectionState != ConnectionStateNew { + a.updateConnectionState(ConnectionStateChecking) + } + }); runErr != nil { + return runErr + } + + return err +} + +func (a *Agent) setGatheringState(newState GatheringState) error { + done := make(chan struct{}) + if err := a.loop.Run(a.loop, func(context.Context) { + if a.gatheringState != newState && newState == GatheringStateComplete { + a.candidateNotifier.EnqueueCandidate(nil) + } + + a.gatheringState = newState + close(done) + }); err != nil { + return err + } + + <-done + + return nil +} + +func (a *Agent) needsToCheckPriorityOnNominated() bool { + return !a.lite || a.enableUseCandidateCheckPriority +} + +func (a *Agent) role() Role { + if a.isControlling.Load() { + return Controlling + } + + return Controlled +} + +func (a *Agent) setSelector() { + a.selectorLock.Lock() + defer a.selectorLock.Unlock() + + var s pairCandidateSelector + if a.isControlling.Load() { + s = &controllingSelector{agent: a, log: a.log} + } else { + s = &controlledSelector{agent: a, log: a.log} + } + if a.lite { + s = &liteSelector{pairCandidateSelector: s} + } + + s.Start() + a.selector = s +} + +func (a *Agent) getSelector() pairCandidateSelector { + a.selectorLock.Lock() + defer a.selectorLock.Unlock() + + return a.selector +} + +// getNominationValue returns a nomination value if generator is available, otherwise 0. +func (a *Agent) getNominationValue() uint32 { + if a.nominationValueGenerator != nil { + return a.nominationValueGenerator() + } + + return 0 +} + +// RenominateCandidate allows the controlling ICE agent to nominate a new candidate pair. +// This implements the continuous renomination feature from draft-thatcher-ice-renomination-01. +func (a *Agent) RenominateCandidate(local, remote Candidate) error { + if !a.isControlling.Load() { + return ErrOnlyControllingAgentCanRenominate + } + + if !a.enableRenomination { + return ErrRenominationNotEnabled + } + + // Find the candidate pair + pair := a.findPair(local, remote) + if pair == nil { + return ErrCandidatePairNotFound + } + + // Send nomination with custom attribute + return a.sendNominationRequest(pair, a.getNominationValue()) +} + +// sendNominationRequest sends a nomination request with custom nomination value. +func (a *Agent) sendNominationRequest(pair *CandidatePair, nominationValue uint32) error { + attributes := []stun.Setter{ + stun.TransactionID, + stun.NewUsername(a.remoteUfrag + ":" + a.localUfrag), + UseCandidate(), + AttrControlling(a.tieBreaker), + PriorityAttr(pair.Local.Priority()), + stun.NewShortTermIntegrity(a.remotePwd), + stun.Fingerprint, + } + + // Add nomination attribute if renomination is enabled and value > 0 + if a.enableRenomination && nominationValue > 0 { + attributes = append(attributes, NominationSetter{ + Value: nominationValue, + AttrType: a.nominationAttribute, + }) + a.log.Tracef("Sending renomination request from %s to %s with nomination value %d", + pair.Local, pair.Remote, nominationValue) + } + + msg, err := stun.Build(append([]stun.Setter{stun.BindingRequest}, attributes...)...) + if err != nil { + return fmt.Errorf("failed to build nomination request: %w", err) + } + + a.sendBindingRequest(msg, pair.Local, pair.Remote) + + return nil +} + +// evaluateCandidatePairQuality calculates a quality score for a candidate pair. +// Higher scores indicate better quality. The score considers: +// - Candidate types (host > srflx > relay) +// - RTT (lower is better) +// - Connection stability. +func (a *Agent) evaluateCandidatePairQuality(pair *CandidatePair) float64 { //nolint:cyclop + if pair == nil || pair.state != CandidatePairStateSucceeded { + return 0 + } + + score := float64(0) + + // Type preference scoring (host=100, srflx=50, prflx=30, relay=10) + localTypeScore := float64(0) + switch pair.Local.Type() { + case CandidateTypeHost: + localTypeScore = 100 + case CandidateTypeServerReflexive: + localTypeScore = 50 + case CandidateTypePeerReflexive: + localTypeScore = 30 + case CandidateTypeRelay: + localTypeScore = 10 + case CandidateTypeUnspecified: + localTypeScore = 0 + } + + remoteTypeScore := float64(0) + switch pair.Remote.Type() { + case CandidateTypeHost: + remoteTypeScore = 100 + case CandidateTypeServerReflexive: + remoteTypeScore = 50 + case CandidateTypePeerReflexive: + remoteTypeScore = 30 + case CandidateTypeRelay: + remoteTypeScore = 10 + case CandidateTypeUnspecified: + remoteTypeScore = 0 + } + + // Combined type score (average of local and remote) + score += (localTypeScore + remoteTypeScore) / 2 + + // RTT scoring (convert to penalty, lower RTT = higher score) + // Use current RTT if available, otherwise assume high latency + rtt := pair.CurrentRoundTripTime() + if rtt > 0 { + // Convert RTT to Duration for cleaner calculation + rttDuration := time.Duration(rtt * float64(time.Second)) + rttMs := float64(rttDuration / time.Millisecond) + if rttMs < 1 { + rttMs = 1 // Minimum 1ms to avoid log(0) + } + // Subtract RTT penalty (logarithmic to reduce impact of very high RTTs) + score -= math.Log10(rttMs) * 10 + } else { + // No RTT data available, apply moderate penalty + score -= 30 + } + + // Boost score if pair has been stable (received responses recently) + if pair.ResponsesReceived() > 0 { + lastResponse := pair.LastResponseReceivedAt() + if !lastResponse.IsZero() && time.Since(lastResponse) < 5*time.Second { + score += 20 // Stability bonus + } + } + + return score +} + +// shouldRenominate determines if automatic renomination should occur. +// It compares the current selected pair with a candidate pair and decides +// if switching would provide significant benefit. +func (a *Agent) shouldRenominate(current, candidate *CandidatePair) bool { //nolint:cyclop + if current == nil || candidate == nil || current.equal(candidate) || candidate.state != CandidatePairStateSucceeded { + return false + } + + // Type-based switching (always prefer direct over relay) + currentIsRelay := current.Local.Type() == CandidateTypeRelay || + current.Remote.Type() == CandidateTypeRelay + candidateIsDirect := candidate.Local.Type() == CandidateTypeHost && + candidate.Remote.Type() == CandidateTypeHost + + if currentIsRelay && candidateIsDirect { + a.log.Debugf("Should renominate: relay -> direct connection available") + + return true + } + + // RTT-based switching (must improve by at least 10ms) + currentRTT := current.CurrentRoundTripTime() + candidateRTT := candidate.CurrentRoundTripTime() + + // Only compare RTT if both values are valid + if currentRTT > 0 && candidateRTT > 0 { + currentRTTDuration := time.Duration(currentRTT * float64(time.Second)) + candidateRTTDuration := time.Duration(candidateRTT * float64(time.Second)) + rttImprovement := currentRTTDuration - candidateRTTDuration + + if rttImprovement > 10*time.Millisecond { + a.log.Debugf("Should renominate: RTT improvement of %v", rttImprovement) + + return true + } + } + + // Quality score comparison (must improve by at least 15%) + currentScore := a.evaluateCandidatePairQuality(current) + candidateScore := a.evaluateCandidatePairQuality(candidate) + + if candidateScore > currentScore*1.15 { + a.log.Debugf("Should renominate: quality score improved from %.2f to %.2f", + currentScore, candidateScore) + + return true + } + + return false +} + +// findBestCandidatePair finds the best available candidate pair based on quality assessment. +func (a *Agent) findBestCandidatePair() *CandidatePair { + var best *CandidatePair + bestScore := float64(-math.MaxFloat64) + + for _, pair := range a.checklist { + if pair.state != CandidatePairStateSucceeded { + continue + } + + score := a.evaluateCandidatePairQuality(pair) + if score > bestScore { + bestScore = score + best = pair + } + } + + return best +} diff --git a/vendor/github.com/pion/ice/v4/agent_config.go b/vendor/github.com/pion/ice/v4/agent_config.go new file mode 100644 index 0000000..71584d9 --- /dev/null +++ b/vendor/github.com/pion/ice/v4/agent_config.go @@ -0,0 +1,308 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package ice + +import ( + "net" + "time" + + "github.com/pion/logging" + "github.com/pion/stun/v3" + "github.com/pion/transport/v4" + "golang.org/x/net/proxy" +) + +const ( + // defaultCheckInterval is the interval at which the agent performs candidate checks in the connecting phase. + defaultCheckInterval = 200 * time.Millisecond + + // keepaliveInterval used to keep candidates alive. + defaultKeepaliveInterval = 2 * time.Second + + // defaultDisconnectedTimeout is the default time till an Agent transitions disconnected. + defaultDisconnectedTimeout = 5 * time.Second + + // defaultFailedTimeout is the default time till an Agent transitions to failed after disconnected. + defaultFailedTimeout = 25 * time.Second + + // defaultHostAcceptanceMinWait is the wait time before nominating a host candidate. + defaultHostAcceptanceMinWait = 0 + + // defaultSrflxAcceptanceMinWait is the wait time before nominating a srflx candidate. + defaultSrflxAcceptanceMinWait = 500 * time.Millisecond + + // defaultPrflxAcceptanceMinWait is the wait time before nominating a prflx candidate. + defaultPrflxAcceptanceMinWait = 1000 * time.Millisecond + + // defaultRelayAcceptanceMinWait is the wait time before nominating a relay candidate. + defaultRelayAcceptanceMinWait = 2000 * time.Millisecond + + // defaultRelayOnlyAcceptanceMinWait is the wait time before nominating with a relay only candidate. + defaultRelayOnlyAcceptanceMinWait = time.Duration(0) + + // defaultSTUNGatherTimeout is the wait time for STUN responses. + defaultSTUNGatherTimeout = 5 * time.Second + + // defaultMaxBindingRequests is the maximum number of binding requests before considering a pair failed. + defaultMaxBindingRequests = 7 + + // TCPPriorityOffset is a number which is subtracted from the default (UDP) candidate type preference + // for host, srflx and prfx candidate types. + defaultTCPPriorityOffset = 27 + + // maxBufferSize is the number of bytes that can be buffered before we start to error. + maxBufferSize = 1000 * 1000 // 1MB + + // maxBindingRequestTimeout is the wait time before binding requests can be deleted. + maxBindingRequestTimeout = 4000 * time.Millisecond +) + +func defaultCandidateTypes() []CandidateType { + return []CandidateType{CandidateTypeHost, CandidateTypeServerReflexive, CandidateTypeRelay} +} + +func defaultRelayAcceptanceMinWaitFor(candidateTypes []CandidateType) time.Duration { + if len(candidateTypes) == 1 && candidateTypes[0] == CandidateTypeRelay { + return defaultRelayOnlyAcceptanceMinWait + } + + return defaultRelayAcceptanceMinWait +} + +// AgentConfig collects the arguments to ice.Agent construction into +// a single structure, for future-proofness of the interface. +// +// Deprecated: use NewAgentWithOptions instead. +type AgentConfig struct { + Urls []*stun.URI + + // PortMin and PortMax are optional. Leave them 0 for the default UDP port allocation strategy. + PortMin uint16 + PortMax uint16 + + // LocalUfrag and LocalPwd values used to perform connectivity + // checks. The values MUST be unguessable, with at least 128 bits of + // random number generator output used to generate the password, and + // at least 24 bits of output to generate the username fragment. + LocalUfrag string + LocalPwd string + + // MulticastDNSMode controls mDNS behavior for the ICE agent + MulticastDNSMode MulticastDNSMode + + // MulticastDNSHostName controls the hostname for this agent. If none is specified a random one will be generated + MulticastDNSHostName string + + // DisconnectedTimeout defaults to 5 seconds when this property is nil. + // If the duration is 0, the ICE Agent will never go to disconnected + DisconnectedTimeout *time.Duration + + // FailedTimeout defaults to 25 seconds when this property is nil. + // If the duration is 0, we will never go to failed. + FailedTimeout *time.Duration + + // KeepaliveInterval determines how often should we send ICE + // keepalives (should be less then connectiontimeout above) + // when this is nil, it defaults to 2 seconds. + // A keepalive interval of 0 means we never send keepalive packets + KeepaliveInterval *time.Duration + + // CheckInterval controls how often our task loop runs when in the + // connecting state. + CheckInterval *time.Duration + + // NetworkTypes is an optional configuration for disabling or enabling + // support for specific network types. + NetworkTypes []NetworkType + + // CandidateTypes is an optional configuration for disabling or enabling + // support for specific candidate types. + CandidateTypes []CandidateType + + LoggerFactory logging.LoggerFactory + + // MaxBindingRequests is the max amount of binding requests the agent will send + // over a candidate pair for validation or nomination, if after MaxBindingRequests + // the candidate is yet to answer a binding request or a nomination we set the pair as failed + MaxBindingRequests *uint16 + + // Lite agents do not perform connectivity check and only provide host candidates. + Lite bool + + // NAT1To1IPCandidateType is used along with NAT1To1IPs to specify which candidate type + // the 1:1 NAT IP addresses should be mapped to. + // If unspecified or CandidateTypeHost, NAT1To1IPs are used to replace host candidate IPs. + // If CandidateTypeServerReflexive, it will insert a srflx candidate (as if it was derived + // from a STUN server) with its port number being the one for the actual host candidate. + // Other values will result in an error. + // + // Deprecated: use WithAddressRewriteRules with an explicit host or srflx rule instead. + // This field will be removed in a future major release. + NAT1To1IPCandidateType CandidateType + + // NAT1To1IPs contains a list of public IP addresses that are to be used as a host + // candidate or srflx candidate. This is used typically for servers that are behind + // 1:1 D-NAT (e.g. AWS EC2 instances) and to eliminate the need of server reflexive + // candidate gathering. + // + // Deprecated: use WithAddressRewriteRules with an explicit host or srflx rule instead. + // This field will be removed in a future major release. + NAT1To1IPs []string + + // HostAcceptanceMinWait specify a minimum wait time before selecting host candidates + HostAcceptanceMinWait *time.Duration + // SrflxAcceptanceMinWait specify a minimum wait time before selecting srflx candidates + SrflxAcceptanceMinWait *time.Duration + // PrflxAcceptanceMinWait specify a minimum wait time before selecting prflx candidates + PrflxAcceptanceMinWait *time.Duration + // RelayAcceptanceMinWait specify a minimum wait time before selecting relay candidates + RelayAcceptanceMinWait *time.Duration + // STUNGatherTimeout specify a minimum wait time for STUN responses + STUNGatherTimeout *time.Duration + + // Net is the our abstracted network interface for internal development purpose only + // (see https://github.com/pion/transport) + Net transport.Net + + // InterfaceFilter is a function that you can use in order to whitelist or blacklist + // the interfaces which are used to gather ICE candidates. + InterfaceFilter func(string) (keep bool) + + // IPFilter is a function that you can use in order to whitelist or blacklist + // the ips which are used to gather ICE candidates. + IPFilter func(net.IP) (keep bool) + + // RemoteIPFilter is a function that you can use in order to whitelist or blacklist + // remote candidate IP addresses before they are added to the agent. + RemoteIPFilter func(net.IP) (keep bool) + + // InsecureSkipVerify controls if self-signed certificates are accepted when connecting + // to TURN servers via TLS or DTLS + InsecureSkipVerify bool + + // TCPMux will be used for multiplexing incoming TCP connections for ICE TCP. + // Currently only passive candidates are supported. This functionality is + // experimental and the API might change in the future. + TCPMux TCPMux + + // UDPMux is used for multiplexing multiple incoming UDP connections on a single port + // when this is set, the agent ignores PortMin and PortMax configurations and will + // defer to UDPMux for incoming connections + UDPMux UDPMux + + // UDPMuxSrflx is used for multiplexing multiple incoming UDP connections of server reflexive candidates + // on a single port when this is set, the agent ignores PortMin and PortMax configurations and will + // defer to UDPMuxSrflx for incoming connections + // It embeds UDPMux to do the actual connection multiplexing + UDPMuxSrflx UniversalUDPMux + + // Proxy Dialer is a dialer that should be implemented by the user based on golang.org/x/net/proxy + // dial interface in order to support corporate proxies + ProxyDialer proxy.Dialer + + // Deprecated: AcceptAggressiveNomination always enabled. + AcceptAggressiveNomination bool + + // Include loopback addresses in the candidate list. + IncludeLoopback bool + + // TCPPriorityOffset is a number which is subtracted from the default (UDP) candidate type preference + // for host, srflx and prfx candidate types. It helps to configure relative preference of UDP candidates + // against TCP ones. Relay candidates for TCP and UDP are always 0 and not affected by this setting. + // When this is nil, defaultTCPPriorityOffset is used. + TCPPriorityOffset *uint16 + + // DisableActiveTCP can be used to disable Active TCP candidates. Otherwise when TCP is enabled + // Active TCP candidates will be created when a new passive TCP remote candidate is added. + DisableActiveTCP bool + + // BindingRequestHandler allows applications to perform logic on incoming STUN Binding Requests + // This was implemented to allow users to + // * Log incoming Binding Requests for debugging + // * Implement draft-thatcher-ice-renomination + // * Implement custom CandidatePair switching logic + BindingRequestHandler func(m *stun.Message, local, remote Candidate, pair *CandidatePair) bool + + // EnableUseCandidateCheckPriority can be used to enable checking for equal or higher priority to + // switch selected candidate pair if the peer requests USE-CANDIDATE and agent is a lite agent. + // This is disabled by default, i. e. when peer requests USE-CANDIDATE, the selected pair will be + // switched to that irrespective of relative priority between current selected pair + // and priority of the pair being switched to. + EnableUseCandidateCheckPriority bool +} + +// initWithDefaults populates an agent and falls back to defaults if fields are unset. +func (config *AgentConfig) initWithDefaults(agent *Agent) { //nolint:cyclop + if config.MaxBindingRequests == nil { + agent.maxBindingRequests = defaultMaxBindingRequests + } else { + agent.maxBindingRequests = *config.MaxBindingRequests + } + + if config.HostAcceptanceMinWait == nil { + agent.hostAcceptanceMinWait = defaultHostAcceptanceMinWait + } else { + agent.hostAcceptanceMinWait = *config.HostAcceptanceMinWait + } + + if config.SrflxAcceptanceMinWait == nil { + agent.srflxAcceptanceMinWait = defaultSrflxAcceptanceMinWait + } else { + agent.srflxAcceptanceMinWait = *config.SrflxAcceptanceMinWait + } + + if config.PrflxAcceptanceMinWait == nil { + agent.prflxAcceptanceMinWait = defaultPrflxAcceptanceMinWait + } else { + agent.prflxAcceptanceMinWait = *config.PrflxAcceptanceMinWait + } + + if config.RelayAcceptanceMinWait == nil { + agent.relayAcceptanceMinWait = defaultRelayAcceptanceMinWaitFor(config.CandidateTypes) + } else { + agent.relayAcceptanceMinWait = *config.RelayAcceptanceMinWait + } + + if config.STUNGatherTimeout == nil { + agent.stunGatherTimeout = defaultSTUNGatherTimeout + } else { + agent.stunGatherTimeout = *config.STUNGatherTimeout + } + + if config.TCPPriorityOffset == nil { + agent.tcpPriorityOffset = defaultTCPPriorityOffset + } else { + agent.tcpPriorityOffset = *config.TCPPriorityOffset + } + + if config.DisconnectedTimeout == nil { + agent.disconnectedTimeout = defaultDisconnectedTimeout + } else { + agent.disconnectedTimeout = *config.DisconnectedTimeout + } + + if config.FailedTimeout == nil { + agent.failedTimeout = defaultFailedTimeout + } else { + agent.failedTimeout = *config.FailedTimeout + } + + if config.KeepaliveInterval == nil { + agent.keepaliveInterval = defaultKeepaliveInterval + } else { + agent.keepaliveInterval = *config.KeepaliveInterval + } + + if config.CheckInterval == nil { + agent.checkInterval = defaultCheckInterval + } else { + agent.checkInterval = *config.CheckInterval + } + + if len(config.CandidateTypes) == 0 { + agent.candidateTypes = defaultCandidateTypes() + } else { + agent.candidateTypes = config.CandidateTypes + } +} diff --git a/vendor/github.com/pion/ice/v4/agent_handlers.go b/vendor/github.com/pion/ice/v4/agent_handlers.go new file mode 100644 index 0000000..95ee925 --- /dev/null +++ b/vendor/github.com/pion/ice/v4/agent_handlers.go @@ -0,0 +1,190 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package ice + +import "sync" + +// OnConnectionStateChange sets a handler that is fired when the connection state changes. +func (a *Agent) OnConnectionStateChange(f func(ConnectionState)) error { + a.onConnectionStateChangeHdlr.Store(f) + + return nil +} + +// OnSelectedCandidatePairChange sets a handler that is fired when the final candidate. +// pair is selected. +func (a *Agent) OnSelectedCandidatePairChange(f func(Candidate, Candidate)) error { + a.onSelectedCandidatePairChangeHdlr.Store(f) + + return nil +} + +// OnCandidate sets a handler that is fired when new candidates gathered. When +// the gathering process complete the last candidate is nil. +func (a *Agent) OnCandidate(f func(Candidate)) error { + a.onCandidateHdlr.Store(f) + + return nil +} + +func (a *Agent) onSelectedCandidatePairChange(p *CandidatePair) { + if h, ok := a.onSelectedCandidatePairChangeHdlr.Load().(func(Candidate, Candidate)); ok && h != nil { + h(p.Local, p.Remote) + } +} + +func (a *Agent) onCandidate(c Candidate) { + if onCandidateHdlr, ok := a.onCandidateHdlr.Load().(func(Candidate)); ok && onCandidateHdlr != nil { + onCandidateHdlr(c) + } +} + +func (a *Agent) onConnectionStateChange(s ConnectionState) { + if hdlr, ok := a.onConnectionStateChangeHdlr.Load().(func(ConnectionState)); ok && hdlr != nil { + hdlr(s) + } +} + +type handlerNotifier struct { + sync.Mutex + running bool + notifiers sync.WaitGroup + + connectionStates []ConnectionState + connectionStateFunc func(ConnectionState) + + candidates []Candidate + candidateFunc func(Candidate) + + selectedCandidatePairs []*CandidatePair + candidatePairFunc func(*CandidatePair) + + // State for closing + done chan struct{} +} + +func (h *handlerNotifier) Close(graceful bool) { + if graceful { + // if we were closed ungracefully before, we now + // want ot wait. + defer h.notifiers.Wait() + } + + h.Lock() + + select { + case <-h.done: + h.Unlock() + + return + default: + } + close(h.done) + h.Unlock() +} + +func (h *handlerNotifier) EnqueueConnectionState(state ConnectionState) { + h.Lock() + defer h.Unlock() + + select { + case <-h.done: + return + default: + } + + notify := func() { + defer h.notifiers.Done() + for { + h.Lock() + if len(h.connectionStates) == 0 { + h.running = false + h.Unlock() + + return + } + notification := h.connectionStates[0] + h.connectionStates = h.connectionStates[1:] + h.Unlock() + h.connectionStateFunc(notification) + } + } + + h.connectionStates = append(h.connectionStates, state) + if !h.running { + h.running = true + h.notifiers.Add(1) + go notify() + } +} + +func (h *handlerNotifier) EnqueueCandidate(cand Candidate) { + h.Lock() + defer h.Unlock() + + select { + case <-h.done: + return + default: + } + + notify := func() { + defer h.notifiers.Done() + for { + h.Lock() + if len(h.candidates) == 0 { + h.running = false + h.Unlock() + + return + } + notification := h.candidates[0] + h.candidates = h.candidates[1:] + h.Unlock() + h.candidateFunc(notification) + } + } + + h.candidates = append(h.candidates, cand) + if !h.running { + h.running = true + h.notifiers.Add(1) + go notify() + } +} + +func (h *handlerNotifier) EnqueueSelectedCandidatePair(pair *CandidatePair) { + h.Lock() + defer h.Unlock() + + select { + case <-h.done: + return + default: + } + + notify := func() { + defer h.notifiers.Done() + for { + h.Lock() + if len(h.selectedCandidatePairs) == 0 { + h.running = false + h.Unlock() + + return + } + notification := h.selectedCandidatePairs[0] + h.selectedCandidatePairs = h.selectedCandidatePairs[1:] + h.Unlock() + h.candidatePairFunc(notification) + } + } + + h.selectedCandidatePairs = append(h.selectedCandidatePairs, pair) + if !h.running { + h.running = true + h.notifiers.Add(1) + go notify() + } +} diff --git a/vendor/github.com/pion/ice/v4/agent_options.go b/vendor/github.com/pion/ice/v4/agent_options.go new file mode 100644 index 0000000..fc97d0b --- /dev/null +++ b/vendor/github.com/pion/ice/v4/agent_options.go @@ -0,0 +1,976 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package ice + +import ( + "fmt" + "net" + "sort" + "strings" + "sync/atomic" + "time" + + "github.com/pion/logging" + "github.com/pion/stun/v3" + "github.com/pion/transport/v4" + "golang.org/x/net/proxy" +) + +// AgentOption represents a function that can be used to configure an Agent. +type AgentOption func(*Agent) error + +// NominationValueGenerator is a function that generates nomination values for renomination. +type NominationValueGenerator func() uint32 + +// DefaultNominationValueGenerator returns a generator that starts at 1 and increments for each call. +// This provides a simple, monotonically increasing sequence suitable for renomination. +func DefaultNominationValueGenerator() NominationValueGenerator { + var counter atomic.Uint32 + + return func() uint32 { + return counter.Add(1) + } +} + +// WithAddressRewriteRules appends the provided address rewrite (1:1) rules to the agent's +// existing configuration. Each `AddressRewriteRule` can limit the mapping to a specific +// interface (`Iface`), local address (`Local`), CIDR block (`CIDR`), or subset +// of network types (`Networks`), allowing fine-grained control over which local +// addresses are replaced with the supplied external IPs. +// Use `Mode` to control whether a rule replaces the original candidate (default for +// host) or appends additional candidates (default for other types). +// +// Rules are evaluated in the order they are added; for each candidate type + +// local address, explicit `Local` matches win immediately. Otherwise, the most +// specific catch-all is chosen (iface+CIDR > iface-only > CIDR-only > global), +// with declaration order breaking ties at the same specificity. `Iface` (when +// set) must also match. This lets you layer specificity (e.g., iface+CIDR, then +// iface-only, then global) while still keeping rule order meaningful. +// Overlapping rules in the same scope are logged as warnings. +func WithAddressRewriteRules(rules ...AddressRewriteRule) AgentOption { + return func(agent *Agent) error { + if agent.constructed { + return ErrAgentOptionNotUpdatable + } + + return appendAddressRewriteRules(agent, rules...) + } +} + +func warnOnAddressRewriteConflicts(agent *Agent) { + if agent == nil || agent.log == nil { + return + } + + for _, conflict := range findAddressRewriteRuleConflicts(agent.addressRewriteRules) { + scope := conflict.scope + scopeSummary := fmt.Sprintf( + "candidate=%s iface=%s cidr=%s networks=%s local=%s", + scope.candidateType.String(), + emptyScopeValue(scope.iface), + emptyScopeValue(scope.cidr), + emptyScopeValue(scope.networksKey), + scope.localKey, + ) + + message := fmt.Sprintf( + "detected overlapping address rewrite rule (%s): existing external IPs [%s], additional external IP %s", + scopeSummary, + strings.Join(conflict.existingExternalIPs, ", "), + conflict.conflictingExternal, + ) + + agent.log.Warn(message) + } +} + +func emptyScopeValue(v string) string { + if v == "" { + return "*" + } + + return v +} + +func appendAddressRewriteRules(agent *Agent, rules ...AddressRewriteRule) error { + if len(rules) == 0 { + return nil + } + + sanitized := make([]AddressRewriteRule, 0, len(rules)) + for _, rule := range rules { + normalized, err := sanitizeAddressRewriteRule(rule) + if err != nil { + return err + } + + sanitized = append(sanitized, normalized) + } + + agent.addressRewriteRules = append(agent.addressRewriteRules, sanitized...) + warnOnAddressRewriteConflicts(agent) + + return nil +} + +func sanitizeAddressRewriteRule(rule AddressRewriteRule) (AddressRewriteRule, error) { + cleaned, err := sanitizeExternalIPs(rule.External) + if err != nil { + return AddressRewriteRule{}, err + } + + normalized := rule + normalized.External = cleaned + normalized.Local = strings.TrimSpace(rule.Local) + if normalized.Local != "" { + if _, _, err := validateIPString(normalized.Local); err != nil { + return AddressRewriteRule{}, err + } + } + switch normalized.Mode { + case addressRewriteModeUnspecified: + normalized.Mode = defaultAddressRewriteMode(normalized.AsCandidateType) + case AddressRewriteReplace, AddressRewriteAppend: + default: + return AddressRewriteRule{}, ErrInvalidNAT1To1IPMapping + } + if len(rule.Networks) > 0 { + normalized.Networks = append([]NetworkType(nil), rule.Networks...) + } + + return normalized, nil +} + +func defaultAddressRewriteMode(candidateType CandidateType) AddressRewriteMode { + if candidateType == CandidateTypeUnspecified || candidateType == CandidateTypeHost { + return AddressRewriteReplace + } + + return AddressRewriteAppend +} + +func sanitizeExternalIPs(ips []string) ([]string, error) { + seen := make(map[string]struct{}, len(ips)) + sanitized := make([]string, 0, len(ips)) + + for _, raw := range ips { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + continue + } + + if _, ok := seen[trimmed]; ok { + continue + } + + if strings.Contains(trimmed, "/") { + return nil, ErrInvalidNAT1To1IPMapping + } + + if _, _, err := validateIPString(trimmed); err != nil { + return nil, err + } + + seen[trimmed] = struct{}{} + sanitized = append(sanitized, trimmed) + } + + if len(sanitized) == 0 { + return nil, ErrInvalidNAT1To1IPMapping + } + + return sanitized, nil +} + +type addressRewriteScopeKey struct { + candidateType CandidateType + iface string + cidr string + networksKey string + localKey string +} + +type addressRewriteConflict struct { + scope addressRewriteScopeKey + existingExternalIPs []string + conflictingExternal string +} + +func findAddressRewriteRuleConflicts(rules []AddressRewriteRule) []addressRewriteConflict { + conflicts := make([]addressRewriteConflict, 0) + scopeState := make(map[addressRewriteScopeKey]map[string]struct{}) + + for _, rule := range rules { + candidateType := rule.AsCandidateType + if candidateType == CandidateTypeUnspecified { + candidateType = CandidateTypeHost + } + + networksKey := "*" + if len(rule.Networks) > 0 { + names := make([]string, len(rule.Networks)) + for i, network := range rule.Networks { + names[i] = network.String() + } + sort.Strings(names) + networksKey = strings.Join(names, ",") + } + + externalEntries := enumerateAddressRewriteExternalEntries(rule) + for _, entry := range externalEntries { + key := addressRewriteScopeKey{ + candidateType: candidateType, + iface: rule.Iface, + cidr: rule.CIDR, + networksKey: networksKey, + localKey: entry.localScopeKey, + } + + existing := scopeState[key] + if existing == nil { + existing = make(map[string]struct{}) + scopeState[key] = existing + } + + if len(existing) > 0 { + if _, ok := existing[entry.externalIP]; !ok { + conflicts = append(conflicts, addressRewriteConflict{ + scope: key, + existingExternalIPs: mapKeys(existing), + conflictingExternal: entry.externalIP, + }) + } + } + + existing[entry.externalIP] = struct{}{} + } + } + + return conflicts +} + +type addressRewriteExternalEntry struct { + externalIP string + localScopeKey string +} + +func enumerateAddressRewriteExternalEntries(rule AddressRewriteRule) []addressRewriteExternalEntry { + if len(rule.External) == 0 { + return nil + } + + entries := make([]addressRewriteExternalEntry, 0, len(rule.External)) + localScope := deriveAddressRewriteLocalScopeKey(rule.Local) + + for _, mapping := range rule.External { + if mapping == "" { + continue + } + + external := strings.TrimSpace(mapping) + if external == "" { + continue + } + + scopeKey := localScope + if scopeKey == "" { + scopeKey = deriveAddressRewriteFamilyScopeKey(external) + } + + entries = append(entries, addressRewriteExternalEntry{ + externalIP: external, + localScopeKey: scopeKey, + }) + } + + return entries +} + +func deriveAddressRewriteLocalScopeKey(local string) string { + local = strings.TrimSpace(local) + if local == "" { + return "" + } + + ip, _, err := validateIPString(local) + if err != nil { + return "family:unknown" + } + + if ip.To4() != nil { + return "family:ipv4" + } + + return "family:ipv6" +} + +func deriveAddressRewriteFamilyScopeKey(ipStr string) string { + ip, _, err := validateIPString(ipStr) + if err != nil { + return "family:unknown" + } + + if ip.To4() != nil { + return "family:ipv4" + } + + return "family:ipv6" +} + +func mapKeys(m map[string]struct{}) []string { + if len(m) == 0 { + return nil + } + + keys := make([]string, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + sort.Strings(keys) + + return keys +} + +// WithICELite configures whether the agent operates in lite mode. +// Lite agents do not perform connectivity checks and only provide host candidates. +func WithICELite(lite bool) AgentOption { + return func(a *Agent) error { + if a.constructed { + return ErrAgentOptionNotUpdatable + } + + a.lite = lite + + return nil + } +} + +// WithUrls sets the STUN/TURN server URLs used by the agent. +func WithUrls(urls []*stun.URI) AgentOption { + return func(a *Agent) error { + if len(urls) == 0 { + a.urls = nil + + return nil + } + + cloned := make([]*stun.URI, len(urls)) + copy(cloned, urls) + a.urls = cloned + + return nil + } +} + +// WithPortRange sets the UDP port range for host candidates. +func WithPortRange(portMin, portMax uint16) AgentOption { + return func(a *Agent) error { + if a.constructed { + return ErrAgentOptionNotUpdatable + } + + a.portMin = portMin + a.portMax = portMax + + return nil + } +} + +// WithDisconnectedTimeout sets the duration before the agent transitions to disconnected state. +// A timeout of 0 disables the transition. +func WithDisconnectedTimeout(timeout time.Duration) AgentOption { + return func(a *Agent) error { + if a.constructed { + return ErrAgentOptionNotUpdatable + } + + a.disconnectedTimeout = timeout + + return nil + } +} + +// WithFailedTimeout sets the duration before the agent transitions to failed state after disconnected. +// A timeout of 0 disables the transition. +func WithFailedTimeout(timeout time.Duration) AgentOption { + return func(a *Agent) error { + if a.constructed { + return ErrAgentOptionNotUpdatable + } + + a.failedTimeout = timeout + + return nil + } +} + +// WithKeepaliveInterval sets how often ICE keepalive packets are sent. +// An interval of 0 disables keepalives. +func WithKeepaliveInterval(interval time.Duration) AgentOption { + return func(a *Agent) error { + if a.constructed { + return ErrAgentOptionNotUpdatable + } + + a.keepaliveInterval = interval + + return nil + } +} + +// WithHostAcceptanceMinWait sets the minimum wait before selecting host candidates. +func WithHostAcceptanceMinWait(wait time.Duration) AgentOption { + return func(a *Agent) error { + if a.constructed { + return ErrAgentOptionNotUpdatable + } + + a.hostAcceptanceMinWait = wait + + return nil + } +} + +// WithSrflxAcceptanceMinWait sets the minimum wait before selecting srflx candidates. +func WithSrflxAcceptanceMinWait(wait time.Duration) AgentOption { + return func(a *Agent) error { + if a.constructed { + return ErrAgentOptionNotUpdatable + } + + a.srflxAcceptanceMinWait = wait + + return nil + } +} + +// WithPrflxAcceptanceMinWait sets the minimum wait before selecting prflx candidates. +func WithPrflxAcceptanceMinWait(wait time.Duration) AgentOption { + return func(a *Agent) error { + if a.constructed { + return ErrAgentOptionNotUpdatable + } + + a.prflxAcceptanceMinWait = wait + + return nil + } +} + +// WithRelayAcceptanceMinWait sets the minimum wait before selecting relay candidates. +func WithRelayAcceptanceMinWait(wait time.Duration) AgentOption { + return func(a *Agent) error { + if a.constructed { + return ErrAgentOptionNotUpdatable + } + + a.relayAcceptanceMinWait = wait + + return nil + } +} + +// WithSTUNGatherTimeout sets the STUN gather timeout. +func WithSTUNGatherTimeout(timeout time.Duration) AgentOption { + return func(a *Agent) error { + if a.constructed { + return ErrAgentOptionNotUpdatable + } + + a.stunGatherTimeout = timeout + + return nil + } +} + +// WithIPFilter sets a filter for IP addresses used during candidate gathering. +func WithIPFilter(filter func(net.IP) bool) AgentOption { + return func(a *Agent) error { + if a.constructed { + return ErrAgentOptionNotUpdatable + } + + a.ipFilter = filter + + return nil + } +} + +// WithRemoteIPFilter sets a filter for remote candidate IP addresses. +// Candidates for which this function returns false are ignored. +func WithRemoteIPFilter(filter func(net.IP) bool) AgentOption { + return func(a *Agent) error { + if a.constructed { + return ErrAgentOptionNotUpdatable + } + + a.remoteIPFilter = filter + + return nil + } +} + +// WithNet sets the underlying network implementation for the agent. +func WithNet(net transport.Net) AgentOption { + return func(a *Agent) error { + if a.constructed { + return ErrAgentOptionNotUpdatable + } + + a.net = net + + return nil + } +} + +// WithMulticastDNSMode configures mDNS behavior for the agent. +func WithMulticastDNSMode(mode MulticastDNSMode) AgentOption { + return func(a *Agent) error { + if a.constructed { + return ErrAgentOptionNotUpdatable + } + + a.mDNSMode = mode + + return nil + } +} + +// WithMulticastDNSHostName sets the mDNS host name used by the agent. +func WithMulticastDNSHostName(hostName string) AgentOption { + return func(a *Agent) error { + if a.constructed { + return ErrAgentOptionNotUpdatable + } + + if !strings.HasSuffix(hostName, ".local") || len(strings.Split(hostName, ".")) != 2 { + return ErrInvalidMulticastDNSHostName + } + + a.mDNSName = hostName + + return nil + } +} + +// WithLocalCredentials sets the local ICE username fragment and password used during Restart. +// If empty strings are provided, the agent will generate values during Restart. +func WithLocalCredentials(ufrag, pwd string) AgentOption { + return func(a *Agent) error { //nolint:varnamelen + if a.constructed { + return ErrAgentOptionNotUpdatable + } + + if ufrag != "" && len([]rune(ufrag))*8 < 24 { + return ErrLocalUfragInsufficientBits + } + if pwd != "" && len([]rune(pwd))*8 < 128 { + return ErrLocalPwdInsufficientBits + } + + a.localUfrag = ufrag + a.localPwd = pwd + + return nil + } +} + +// WithTCPMux sets the TCP mux for ICE TCP multiplexing. +func WithTCPMux(tcpMux TCPMux) AgentOption { + return func(a *Agent) error { + if a.constructed { + return ErrAgentOptionNotUpdatable + } + + a.tcpMux = tcpMux + + return nil + } +} + +// WithUDPMux sets the UDP mux used for multiplexing host candidates. +func WithUDPMux(udpMux UDPMux) AgentOption { + return func(a *Agent) error { + if a.constructed { + return ErrAgentOptionNotUpdatable + } + + a.udpMux = udpMux + + return nil + } +} + +// WithUDPMuxSrflx sets the UDP mux for server reflexive candidates. +func WithUDPMuxSrflx(udpMuxSrflx UniversalUDPMux) AgentOption { + return func(a *Agent) error { + if a.constructed { + return ErrAgentOptionNotUpdatable + } + + a.udpMuxSrflx = udpMuxSrflx + + return nil + } +} + +// WithProxyDialer sets the proxy dialer used for TURN over TCP/TLS/DTLS connections. +func WithProxyDialer(dialer proxy.Dialer) AgentOption { + return func(a *Agent) error { + if a.constructed { + return ErrAgentOptionNotUpdatable + } + + a.proxyDialer = dialer + + return nil + } +} + +// WithMaxBindingRequests sets the maximum number of binding requests before considering a pair failed. +func WithMaxBindingRequests(limit uint16) AgentOption { + return func(a *Agent) error { + if a.constructed { + return ErrAgentOptionNotUpdatable + } + + a.maxBindingRequests = limit + + return nil + } +} + +// WithCheckInterval sets how often the agent runs connectivity checks while connecting. +func WithCheckInterval(interval time.Duration) AgentOption { + return func(a *Agent) error { + if a.constructed { + return ErrAgentOptionNotUpdatable + } + + a.checkInterval = interval + + return nil + } +} + +// WithRenomination enables ICE renomination as described in draft-thatcher-ice-renomination-01. +// When enabled, the controlling agent can renominate candidate pairs multiple times +// and the controlled agent follows "last nomination wins" rule. +// +// The generator parameter specifies how nomination values are generated. +// Use DefaultNominationValueGenerator() for a simple incrementing counter, +// or provide a custom generator for more complex scenarios. +// +// Example: +// +// agent, err := NewAgentWithOptions(config, WithRenomination(DefaultNominationValueGenerator())) +func WithRenomination(generator NominationValueGenerator) AgentOption { + return func(a *Agent) error { + if a.constructed { + return ErrAgentOptionNotUpdatable + } + + if generator == nil { + return ErrInvalidNominationValueGenerator + } + a.enableRenomination = true + a.nominationValueGenerator = generator + + return nil + } +} + +// WithNominationAttribute sets the STUN attribute type to use for ICE renomination. +// The default value is 0x0030. This can be configured until the attribute is officially +// assigned by IANA for draft-thatcher-ice-renomination. +// +// This option returns an error if the provided attribute type is invalid. +// Currently, validation ensures the attribute is not 0x0000 (reserved). +// Additional validation may be added in the future. +func WithNominationAttribute(attrType uint16) AgentOption { + return func(a *Agent) error { + if a.constructed { + return ErrAgentOptionNotUpdatable + } + + // Basic validation: ensure it's not the reserved 0x0000 + if attrType == 0x0000 { + return ErrInvalidNominationAttribute + } + + a.nominationAttribute = stun.AttrType(attrType) + + return nil + } +} + +// WithIncludeLoopback includes loopback addresses in the candidate list. +// By default, loopback addresses are excluded. +// +// Example: +// +// agent, err := NewAgentWithOptions(WithIncludeLoopback()) +func WithIncludeLoopback() AgentOption { + return func(a *Agent) error { + if a.constructed { + return ErrAgentOptionNotUpdatable + } + + a.includeLoopback = true + + return nil + } +} + +// WithTCPPriorityOffset sets a number which is subtracted from the default (UDP) candidate type preference +// for host, srflx and prfx candidate types. It helps to configure relative preference of UDP candidates +// against TCP ones. Relay candidates for TCP and UDP are always 0 and not affected by this setting. +// When not set, defaultTCPPriorityOffset (27) is used. +// +// Example: +// +// agent, err := NewAgentWithOptions(WithTCPPriorityOffset(50)) +func WithTCPPriorityOffset(offset uint16) AgentOption { + return func(a *Agent) error { + if a.constructed { + return ErrAgentOptionNotUpdatable + } + + a.tcpPriorityOffset = offset + + return nil + } +} + +// WithDisableActiveTCP disables Active TCP candidates. +// When TCP is enabled, Active TCP candidates will be created when a new passive TCP remote candidate is added +// unless this option is used. +// +// Example: +// +// agent, err := NewAgentWithOptions(WithDisableActiveTCP()) +func WithDisableActiveTCP() AgentOption { + return func(a *Agent) error { + if a.constructed { + return ErrAgentOptionNotUpdatable + } + + a.disableActiveTCP = true + + return nil + } +} + +// WithBindingRequestHandler sets a handler to allow applications to perform logic on incoming STUN Binding Requests. +// This was implemented to allow users to: +// - Log incoming Binding Requests for debugging +// - Implement draft-thatcher-ice-renomination +// - Implement custom CandidatePair switching logic +// +// Example: +// +// handler := func(m *stun.Message, local, remote Candidate, pair *CandidatePair) bool { +// log.Printf("Binding request from %s to %s", remote.Address(), local.Address()) +// return true // Accept the request +// } +// agent, err := NewAgentWithOptions(WithBindingRequestHandler(handler)) +func WithBindingRequestHandler( + handler func(m *stun.Message, local, remote Candidate, pair *CandidatePair) bool, +) AgentOption { + return func(a *Agent) error { + if a.constructed { + return ErrAgentOptionNotUpdatable + } + + a.userBindingRequestHandler = handler + + return nil + } +} + +// WithEnableUseCandidateCheckPriority enables checking for equal or higher priority when +// switching selected candidate pair if the peer requests USE-CANDIDATE and agent is a lite agent. +// This is disabled by default, i.e. when peer requests USE-CANDIDATE, the selected pair will be +// switched to that irrespective of relative priority between current selected pair +// and priority of the pair being switched to. +// +// Example: +// +// agent, err := NewAgentWithOptions(WithEnableUseCandidateCheckPriority()) +func WithEnableUseCandidateCheckPriority() AgentOption { + return func(a *Agent) error { + if a.constructed { + return ErrAgentOptionNotUpdatable + } + + a.enableUseCandidateCheckPriority = true + + return nil + } +} + +// WithContinualGatheringPolicy sets the continual gathering policy for the agent. +// When set to GatherContinually, the agent will continuously monitor network interfaces +// and gather new candidates as they become available. +// When set to GatherOnce (default), gathering completes after the initial phase. +// +// Example: +// +// agent, err := NewAgentWithOptions(WithContinualGatheringPolicy(GatherContinually)) +func WithContinualGatheringPolicy(policy ContinualGatheringPolicy) AgentOption { + return func(a *Agent) error { + if a.constructed { + return ErrAgentOptionNotUpdatable + } + + a.continualGatheringPolicy = policy + + return nil + } +} + +// WithNetworkMonitorInterval sets the interval at which the agent checks for network interface changes +// when using GatherContinually policy. This option only has effect when used with +// WithContinualGatheringPolicy(GatherContinually). +// Default is 2 seconds if not specified. +// +// Example: +// +// agent, err := NewAgentWithOptions( +// WithContinualGatheringPolicy(GatherContinually), +// WithNetworkMonitorInterval(5 * time.Second), +// ) +func WithNetworkMonitorInterval(interval time.Duration) AgentOption { + return func(a *Agent) error { + if a.constructed { + return ErrAgentOptionNotUpdatable + } + + if interval <= 0 { + return ErrInvalidNetworkMonitorInterval + } + a.networkMonitorInterval = interval + + return nil + } +} + +// WithNetworkTypes sets the enabled network types for candidate gathering. +// By default, all network types are enabled. +// +// Example: +// +// agent, err := NewAgentWithOptions( +// WithNetworkTypes([]NetworkType{NetworkTypeUDP4, NetworkTypeUDP6}), +// ) +func WithNetworkTypes(networkTypes []NetworkType) AgentOption { + return func(a *Agent) error { + if a.constructed { + return ErrAgentOptionNotUpdatable + } + + a.networkTypes = networkTypes + + return nil + } +} + +// WithCandidateTypes sets the enabled candidate types for gathering. +// By default, host, server reflexive, and relay candidates are enabled. +// +// Example: +// +// agent, err := NewAgentWithOptions( +// WithCandidateTypes([]CandidateType{CandidateTypeHost, CandidateTypeServerReflexive}), +// ) +func WithCandidateTypes(candidateTypes []CandidateType) AgentOption { + return func(a *Agent) error { + if a.constructed { + return ErrAgentOptionNotUpdatable + } + + a.candidateTypes = candidateTypes + + return nil + } +} + +// WithAutomaticRenomination enables automatic renomination of candidate pairs +// when better pairs become available after initial connection establishment. +// This feature requires renomination to be enabled and both agents to support it. +// +// When enabled, the controlling agent will periodically evaluate candidate pairs +// and renominate if a significantly better pair is found (e.g., switching from +// relay to direct connection, or when RTT improves significantly). +// +// The interval parameter specifies the minimum time to wait after connection +// before considering automatic renomination. If set to 0, it defaults to 3 seconds. +// +// Example: +// +// agent, err := NewAgentWithOptions( +// WithRenomination(DefaultNominationValueGenerator()), +// WithAutomaticRenomination(3*time.Second), +// ) +func WithAutomaticRenomination(interval time.Duration) AgentOption { + return func(a *Agent) error { + if a.constructed { + return ErrAgentOptionNotUpdatable + } + + a.automaticRenomination = true + if interval > 0 { + a.renominationInterval = interval + } + // Note: renomination must be enabled separately via WithRenomination + return nil + } +} + +// WithInterfaceFilter sets a filter function to whitelist or blacklist network interfaces +// for ICE candidate gathering. +// +// The filter function receives the interface name and should return true to keep the interface, +// or false to exclude it. +// +// Example: +// +// // Only use interfaces starting with "eth" +// agent, err := NewAgentWithOptions( +// WithInterfaceFilter(func(interfaceName string) bool { +// return len(interfaceName) >= 3 && interfaceName[:3] == "eth" +// }), +// ) +func WithInterfaceFilter(filter func(string) bool) AgentOption { + return func(a *Agent) error { + if a.constructed { + return ErrAgentOptionNotUpdatable + } + + a.interfaceFilter = filter + + return nil + } +} + +// WithLoggerFactory sets the logger factory for the agent. +// +// Example: +// +// import "github.com/pion/logging" +// +// loggerFactory := logging.NewDefaultLoggerFactory() +// loggerFactory.DefaultLogLevel = logging.LogLevelDebug +// agent, err := NewAgentWithOptions(WithLoggerFactory(loggerFactory)) +func WithLoggerFactory(loggerFactory logging.LoggerFactory) AgentOption { + return func(a *Agent) error { + if a.constructed { + return ErrAgentOptionNotUpdatable + } + + a.log = loggerFactory.NewLogger("ice") + + return nil + } +} diff --git a/vendor/github.com/pion/ice/v4/agent_stats.go b/vendor/github.com/pion/ice/v4/agent_stats.go new file mode 100644 index 0000000..4129db1 --- /dev/null +++ b/vendor/github.com/pion/ice/v4/agent_stats.go @@ -0,0 +1,169 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package ice + +import ( + "context" + "time" +) + +// GetCandidatePairsStats returns a list of candidate pair stats. +func (a *Agent) GetCandidatePairsStats() []CandidatePairStats { + var res []CandidatePairStats + err := a.loop.Run(a.loop, func(_ context.Context) { + result := make([]CandidatePairStats, 0, len(a.checklist)) + for _, cp := range a.checklist { + stat := CandidatePairStats{ + Timestamp: time.Now(), + LocalCandidateID: cp.Local.ID(), + RemoteCandidateID: cp.Remote.ID(), + State: cp.state, + Nominated: cp.nominated, + PacketsSent: cp.PacketsSent(), + PacketsReceived: cp.PacketsReceived(), + BytesSent: cp.BytesSent(), + BytesReceived: cp.BytesReceived(), + LastPacketSentTimestamp: cp.LastPacketSentAt(), + LastPacketReceivedTimestamp: cp.LastPacketReceivedAt(), + FirstRequestTimestamp: cp.FirstRequestSentAt(), + LastRequestTimestamp: cp.LastRequestSentAt(), + FirstResponseTimestamp: cp.FirstResponseReceivedAt(), + LastResponseTimestamp: cp.LastResponseReceivedAt(), + FirstRequestReceivedTimestamp: cp.FirstRequestReceivedAt(), + LastRequestReceivedTimestamp: cp.LastRequestReceivedAt(), + + TotalRoundTripTime: cp.TotalRoundTripTime(), + CurrentRoundTripTime: cp.CurrentRoundTripTime(), + // AvailableOutgoingBitrate float64 + // AvailableIncomingBitrate float64 + // CircuitBreakerTriggerCount uint32 + RequestsReceived: cp.RequestsReceived(), + RequestsSent: cp.RequestsSent(), + ResponsesReceived: cp.ResponsesReceived(), + ResponsesSent: cp.ResponsesSent(), + // RetransmissionsReceived uint64 + // RetransmissionsSent uint64 + // ConsentRequestsSent uint64 + // ConsentExpiredTimestamp time.Time + } + result = append(result, stat) + } + res = result + }) + if err != nil { + a.log.Errorf("Failed to get candidate pairs stats: %v", err) + + return []CandidatePairStats{} + } + + return res +} + +// GetSelectedCandidatePairStats returns a candidate pair stats for selected candidate pair. +// Returns false if there is no selected pair. +func (a *Agent) GetSelectedCandidatePairStats() (CandidatePairStats, bool) { + isAvailable := false + var res CandidatePairStats + err := a.loop.Run(a.loop, func(_ context.Context) { + sp := a.getSelectedPair() + if sp == nil { + return + } + + isAvailable = true + res = CandidatePairStats{ + Timestamp: time.Now(), + LocalCandidateID: sp.Local.ID(), + RemoteCandidateID: sp.Remote.ID(), + State: sp.state, + Nominated: sp.nominated, + PacketsSent: sp.PacketsSent(), + PacketsReceived: sp.PacketsReceived(), + BytesSent: sp.BytesSent(), + BytesReceived: sp.BytesReceived(), + LastPacketSentTimestamp: sp.LastPacketSentAt(), + LastPacketReceivedTimestamp: sp.LastPacketReceivedAt(), + // FirstRequestTimestamp time.Time + // LastRequestTimestamp time.Time + // LastResponseTimestamp time.Time + TotalRoundTripTime: sp.TotalRoundTripTime(), + CurrentRoundTripTime: sp.CurrentRoundTripTime(), + // AvailableOutgoingBitrate float64 + // AvailableIncomingBitrate float64 + // CircuitBreakerTriggerCount uint32 + // RequestsReceived uint64 + // RequestsSent uint64 + ResponsesReceived: sp.ResponsesReceived(), + // ResponsesSent uint64 + // RetransmissionsReceived uint64 + // RetransmissionsSent uint64 + // ConsentRequestsSent uint64 + // ConsentExpiredTimestamp time.Time + } + }) + if err != nil { + a.log.Errorf("Failed to get selected candidate pair stats: %v", err) + + return CandidatePairStats{}, false + } + + return res, isAvailable +} + +// GetLocalCandidatesStats returns a list of local candidates stats. +func (a *Agent) GetLocalCandidatesStats() []CandidateStats { + return a.getCandidatesStats(true) +} + +// GetRemoteCandidatesStats returns a list of remote candidates stats. +func (a *Agent) GetRemoteCandidatesStats() []CandidateStats { + return a.getCandidatesStats(false) +} + +// getCandidatesStats returns a list of candidates stats. +func (a *Agent) getCandidatesStats(isLocal bool) []CandidateStats { + var res []CandidateStats + err := a.loop.Run(a.loop, func(_ context.Context) { + var candidateMap map[NetworkType][]Candidate + if isLocal { + candidateMap = a.localCandidates + } else { + candidateMap = a.remoteCandidates + } + + result := make([]CandidateStats, 0, len(candidateMap)) + for networkType, candidate := range candidateMap { + for _, cand := range candidate { + relayProtocol := "" + + if isLocal && cand.Type() == CandidateTypeRelay { + if cRelay, ok := cand.(*CandidateRelay); ok { + relayProtocol = cRelay.RelayProtocol() + } + } + + stat := CandidateStats{ + Timestamp: time.Now(), + ID: cand.ID(), + NetworkType: networkType, + IP: cand.Address(), + Port: cand.Port(), + CandidateType: cand.Type(), + Priority: cand.Priority(), + // URL string + RelayProtocol: relayProtocol, + } + result = append(result, stat) + } + } + res = result + }) + if err != nil { + a.log.Errorf("Failed to get candidate pair stats: %v", err) + + return []CandidateStats{} + } + + return res +} diff --git a/vendor/github.com/pion/ice/v4/candidate.go b/vendor/github.com/pion/ice/v4/candidate.go new file mode 100644 index 0000000..dafe07c --- /dev/null +++ b/vendor/github.com/pion/ice/v4/candidate.go @@ -0,0 +1,95 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package ice + +import ( + "context" + "net" + "time" +) + +const ( + receiveMTU = 8192 + defaultLocalPreference = 65535 + + // ComponentRTP indicates that the candidate is used for RTP. + ComponentRTP uint16 = 1 + // ComponentRTCP indicates that the candidate is used for RTCP. + ComponentRTCP +) + +// Candidate represents an ICE candidate. +type Candidate interface { + // An arbitrary string used in the freezing algorithm to + // group similar candidates. It is the same for two candidates that + // have the same type, base IP address, protocol (UDP, TCP, etc.), + // and STUN or TURN server. + Foundation() string + + // ID is a unique identifier for just this candidate + // Unlike the foundation this is different for each candidate + ID() string + + // A component is a piece of a data stream. + // An example is one for RTP, and one for RTCP + Component() uint16 + SetComponent(uint16) + + // The last time this candidate received traffic + LastReceived() time.Time + + // The last time this candidate sent traffic + LastSent() time.Time + + NetworkType() NetworkType + Address() string + Port() int + + Priority() uint32 + + // A transport address related to a + // candidate, which is useful for diagnostics and other purposes + RelatedAddress() *CandidateRelatedAddress + + // Extensions returns a copy of all extension attributes associated with the ICECandidate. + // In the order of insertion, *(key value). + // Extension attributes are defined in RFC 5245, Section 15.1: + // https://datatracker.ietf.org/doc/html/rfc5245#section-15.1 + //. + Extensions() []CandidateExtension + // GetExtension returns the value of the extension attribute associated with the ICECandidate. + // Extension attributes are defined in RFC 5245, Section 15.1: + // https://datatracker.ietf.org/doc/html/rfc5245#section-15.1 + //. + GetExtension(key string) (value CandidateExtension, ok bool) + // AddExtension adds an extension attribute to the ICECandidate. + // If an extension with the same key already exists, it will be overwritten. + // Extension attributes are defined in RFC 5245, Section 15.1: + AddExtension(extension CandidateExtension) error + // RemoveExtension removes an extension attribute from the ICECandidate. + // Extension attributes are defined in RFC 5245, Section 15.1: + RemoveExtension(key string) (ok bool) + + String() string + Type() CandidateType + TCPType() TCPType + + Equal(other Candidate) bool + + // DeepEqual same as Equal, But it also compares the candidate extensions. + DeepEqual(other Candidate) bool + + Marshal() string + + addr() net.Addr + filterForLocationTracking() bool + agent() *Agent + context() context.Context + + close() error + copy() (Candidate, error) + seen(outbound bool) + start(a *Agent, conn net.PacketConn, initializedCh <-chan struct{}) + writeTo(raw []byte, dst Candidate) (int, error) +} diff --git a/vendor/github.com/pion/ice/v4/candidate_base.go b/vendor/github.com/pion/ice/v4/candidate_base.go new file mode 100644 index 0000000..47e20ec --- /dev/null +++ b/vendor/github.com/pion/ice/v4/candidate_base.go @@ -0,0 +1,1097 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package ice + +import ( + "context" + "errors" + "fmt" + "hash/crc32" + "io" + "net" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/pion/stun/v3" +) + +type candidateBase struct { + id string + networkType NetworkType + candidateType CandidateType + + component uint16 + address string + port int + relatedAddress *CandidateRelatedAddress + tcpType TCPType + + resolvedAddr net.Addr + + lastSent atomic.Int64 + lastReceived atomic.Int64 + conn net.PacketConn + + currAgent *Agent + closeCh chan struct{} + closedCh chan struct{} + + foundationOverride string + priorityOverride uint32 + + relayLocalPreference uint16 + + remoteCandidateCaches map[AddrPort]Candidate + isLocationTracked bool + extensions []CandidateExtension +} + +// Save a time reference to calculate monotonic time for candidate last sent/received. +// nolint: gochecknoglobals +var timeRef = time.Now() + +// getMonoNanos returns the monotonic nanoseconds of a time t since timeRef. +func getMonoNanos(t time.Time) int64 { + return t.Sub(timeRef).Nanoseconds() +} + +// getMonoTime returns a time.Time based on monotonic nanos since timeRef. +func getMonoTime(nanos int64) time.Time { + return timeRef.Add(time.Duration(nanos)) +} + +// Done implements context.Context. +func (c *candidateBase) Done() <-chan struct{} { + return c.closeCh +} + +// Err implements context.Context. +func (c *candidateBase) Err() error { + select { + case <-c.closedCh: + return ErrRunCanceled + default: + return nil + } +} + +// Deadline implements context.Context. +func (c *candidateBase) Deadline() (deadline time.Time, ok bool) { + return time.Time{}, false +} + +// Value implements context.Context. +func (c *candidateBase) Value(any) any { + return nil +} + +// setWriteDeadline is used by upper layers to push write deadlines down to the +// underlying packet connection. +func (c *candidateBase) setWriteDeadline(t time.Time) error { + if c.conn == nil { + return nil + } + + return c.conn.SetWriteDeadline(t) +} + +// ID returns Candidate ID. +func (c *candidateBase) ID() string { + return c.id +} + +func (c *candidateBase) Foundation() string { + if c.foundationOverride != "" { + return c.foundationOverride + } + + return fmt.Sprintf("%d", crc32.ChecksumIEEE([]byte(c.Type().String()+c.address+c.networkType.String()))) +} + +// Address returns Candidate Address. +func (c *candidateBase) Address() string { + return c.address +} + +// Port returns Candidate Port. +func (c *candidateBase) Port() int { + return c.port +} + +// Type returns candidate type. +func (c *candidateBase) Type() CandidateType { + return c.candidateType +} + +// NetworkType returns candidate NetworkType. +func (c *candidateBase) NetworkType() NetworkType { + return c.networkType +} + +// Component returns candidate component. +func (c *candidateBase) Component() uint16 { + return c.component +} + +func (c *candidateBase) SetComponent(component uint16) { + c.component = component +} + +// LocalPreference returns the local preference for this candidate. +func (c *candidateBase) LocalPreference() uint16 { //nolint:cyclop + if c.candidateType == CandidateTypeRelay { + return c.relayLocalPreference + } + + if c.NetworkType().IsTCP() { + // RFC 6544, section 4.2 + // + // In Section 4.1.2.1 of [RFC5245], a recommended formula for UDP ICE + // candidate prioritization is defined. For TCP candidates, the same + // formula and candidate type preferences SHOULD be used, and the + // RECOMMENDED type preferences for the new candidate types defined in + // this document (see Section 5) are 105 for NAT-assisted candidates and + // 75 for UDP-tunneled candidates. + // + // (...) + // + // With TCP candidates, the local preference part of the recommended + // priority formula is updated to also include the directionality + // (active, passive, or simultaneous-open) of the TCP connection. The + // RECOMMENDED local preference is then defined as: + // + // local preference = (2^13) * direction-pref + other-pref + // + // The direction-pref MUST be between 0 and 7 (both inclusive), with 7 + // being the most preferred. The other-pref MUST be between 0 and 8191 + // (both inclusive), with 8191 being the most preferred. It is + // RECOMMENDED that the host, UDP-tunneled, and relayed TCP candidates + // have the direction-pref assigned as follows: 6 for active, 4 for + // passive, and 2 for S-O. For the NAT-assisted and server reflexive + // candidates, the RECOMMENDED values are: 6 for S-O, 4 for active, and + // 2 for passive. + // + // (...) + // + // If any two candidates have the same type-preference and direction- + // pref, they MUST have a unique other-pref. With this specification, + // this usually only happens with multi-homed hosts, in which case + // other-pref is the preference for the particular IP address from which + // the candidate was obtained. When there is only a single IP address, + // this value SHOULD be set to the maximum allowed value (8191). + var otherPref uint16 = 8191 + + directionPref := func() uint16 { + switch c.Type() { + case CandidateTypeHost, CandidateTypeRelay: + switch c.tcpType { + case TCPTypeActive: + return 6 + case TCPTypePassive: + return 4 + case TCPTypeSimultaneousOpen: + return 2 + case TCPTypeUnspecified: + return 0 + } + case CandidateTypePeerReflexive, CandidateTypeServerReflexive: + switch c.tcpType { + case TCPTypeSimultaneousOpen: + return 6 + case TCPTypeActive: + return 4 + case TCPTypePassive: + return 2 + case TCPTypeUnspecified: + return 0 + } + case CandidateTypeUnspecified: + return 0 + } + + return 0 + }() + + return (1<<13)*directionPref + otherPref + } + + return defaultLocalPreference +} + +// RelatedAddress returns *CandidateRelatedAddress. +func (c *candidateBase) RelatedAddress() *CandidateRelatedAddress { + return c.relatedAddress +} + +func (c *candidateBase) TCPType() TCPType { + return c.tcpType +} + +// start runs the candidate using the provided connection. +func (c *candidateBase) start(a *Agent, conn net.PacketConn, initializedCh <-chan struct{}) { + if c.conn != nil { + c.agent().log.Warn("Can't start already started candidateBase") + + return + } + c.currAgent = a + c.conn = conn + c.closeCh = make(chan struct{}) + c.closedCh = make(chan struct{}) + + go c.recvLoop(initializedCh) +} + +var bufferPool = sync.Pool{ // nolint:gochecknoglobals + New: func() any { + return make([]byte, receiveMTU) + }, +} + +func (c *candidateBase) recvLoop(initializedCh <-chan struct{}) { + agent := c.agent() + + defer close(c.closedCh) + + select { + case <-initializedCh: + case <-c.closeCh: + return + } + + bufferPoolBuffer := bufferPool.Get() + defer bufferPool.Put(bufferPoolBuffer) + buf, ok := bufferPoolBuffer.([]byte) + if !ok { + return + } + + for { + n, srcAddr, err := c.conn.ReadFrom(buf) + if err != nil { + if !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) { + agent.log.Warnf("Failed to read from candidate %s: %v", c, err) + } + + return + } + + c.handleInboundPacket(buf[:n], srcAddr) + } +} + +func (c *candidateBase) validateSTUNTrafficCache(addr net.Addr) bool { + if candidate, ok := c.remoteCandidateCaches[toAddrPort(addr)]; ok { + candidate.seen(false) + + return true + } + + return false +} + +func (c *candidateBase) addRemoteCandidateCache(candidate Candidate, srcAddr net.Addr) { + if c.validateSTUNTrafficCache(srcAddr) { + return + } + c.remoteCandidateCaches[toAddrPort(srcAddr)] = candidate +} + +func (c *candidateBase) handleInboundPacket(buf []byte, srcAddr net.Addr) { + agent := c.agent() + + if stun.IsMessage(buf) { + msg := &stun.Message{ + Raw: make([]byte, len(buf)), + } + + // Explicitly copy raw buffer so Message can own the memory. + copy(msg.Raw, buf) + + if err := msg.Decode(); err != nil { + agent.log.Warnf("Failed to handle decode ICE from %s to %s: %v", c.addr(), srcAddr, err) + + return + } + + if err := agent.loop.Run(c, func(_ context.Context) { + // nolint: contextcheck + agent.handleInbound(msg, c, srcAddr) + }); err != nil { + agent.log.Warnf("Failed to handle message: %v", err) + } + + return + } + + if !c.validateSTUNTrafficCache(srcAddr) { + remoteCandidate, valid := agent.validateNonSTUNTraffic(c, srcAddr) //nolint:contextcheck + if !valid { + agent.log.Warnf("Discarded message from %s, not a valid remote candidate", c.addr()) + + return + } + c.addRemoteCandidateCache(remoteCandidate, srcAddr) + } + + // Note: This will return packetio.ErrFull if the buffer ever manages to fill up. + n, err := agent.buf.Write(buf) + if err != nil { + agent.log.Warnf("Failed to write packet: %s", err) + + return + } + + // Add received application bytes to the currently selected candidate pair. + if n > 0 { + if sp := agent.getSelectedPair(); sp != nil { + sp.UpdatePacketReceived(n) + } + } +} + +// close stops the recvLoop. +func (c *candidateBase) close() error { + // If conn has never been started will be nil + if c.Done() == nil { + return nil + } + + // Assert that conn has not already been closed + select { + case <-c.Done(): + return nil + default: + } + + var firstErr error + + // Unblock recvLoop + close(c.closeCh) + if err := c.conn.SetDeadline(time.Now()); err != nil { + firstErr = err + } + + // Close the conn + if err := c.conn.Close(); err != nil && firstErr == nil { + firstErr = err + } + + if firstErr != nil { + return firstErr + } + + // Wait until the recvLoop is closed + <-c.closedCh + + return nil +} + +func (c *candidateBase) writeTo(raw []byte, dst Candidate) (int, error) { + n, err := c.conn.WriteTo(raw, dst.addr()) + if err != nil { + // If the connection is closed, we should return the error + if errors.Is(err, io.ErrClosedPipe) { + return n, err + } + c.agent().log.Infof("Failed to send packet: %v", err) + + return n, nil + } + c.seen(true) + + return n, nil +} + +// TypePreference returns the type preference for this candidate. +func (c *candidateBase) TypePreference() uint16 { + pref := c.Type().Preference() + if pref == 0 { + return 0 + } + + if c.NetworkType().IsTCP() { + var tcpPriorityOffset uint16 = defaultTCPPriorityOffset + if c.agent() != nil { + tcpPriorityOffset = c.agent().tcpPriorityOffset + } + + pref -= tcpPriorityOffset + } + + return pref +} + +// Priority computes the priority for this ICE Candidate +// See: https://www.rfc-editor.org/rfc/rfc8445#section-5.1.2.1 +func (c *candidateBase) Priority() uint32 { + if c.priorityOverride != 0 { + return c.priorityOverride + } + + // The local preference MUST be an integer from 0 (lowest preference) to + // 65535 (highest preference) inclusive. When there is only a single IP + // address, this value SHOULD be set to 65535. If there are multiple + // candidates for a particular component for a particular data stream + // that have the same type, the local preference MUST be unique for each + // one. + + return (1<<24)*uint32(c.TypePreference()) + + (1<<8)*uint32(c.LocalPreference()) + + (1<<0)*uint32(256-c.Component()) +} + +// Equal is used to compare two candidateBases. +func (c *candidateBase) Equal(other Candidate) bool { + if c.addr() != other.addr() { + if c.addr() == nil || other.addr() == nil { + return false + } + if !addrEqual(c.addr(), other.addr()) { + return false + } + } + + return c.NetworkType() == other.NetworkType() && + c.Type() == other.Type() && + c.Address() == other.Address() && + c.Port() == other.Port() && + c.TCPType() == other.TCPType() && + c.RelatedAddress().Equal(other.RelatedAddress()) +} + +// DeepEqual is same as Equal but also compares the extensions. +func (c *candidateBase) DeepEqual(other Candidate) bool { + return c.Equal(other) && c.extensionsEqual(other.Extensions()) +} + +// String makes the candidateBase printable. +func (c *candidateBase) String() string { + return fmt.Sprintf( + "%s %s %s%s (resolved: %v)", + c.NetworkType(), + c.Type(), + net.JoinHostPort(c.Address(), strconv.Itoa(c.Port())), + c.relatedAddress, + c.resolvedAddr, + ) +} + +// LastReceived returns a time.Time indicating the last time +// this candidate was received. +func (c *candidateBase) LastReceived() time.Time { + if lastReceived := c.lastReceived.Load(); lastReceived != 0 { + return getMonoTime(lastReceived) + } + + return time.Time{} +} + +func (c *candidateBase) setLastReceived(t time.Time) { + c.lastReceived.Store(getMonoNanos(t)) +} + +// LastSent returns a time.Time indicating the last time +// this candidate was sent. +func (c *candidateBase) LastSent() time.Time { + if lastSent := c.lastSent.Load(); lastSent != 0 { + return getMonoTime(lastSent) + } + + return time.Time{} +} + +func (c *candidateBase) setLastSent(t time.Time) { + c.lastSent.Store(getMonoNanos(t)) +} + +func (c *candidateBase) seen(outbound bool) { + if outbound { + c.setLastSent(time.Now()) + } else { + c.setLastReceived(time.Now()) + } +} + +func (c *candidateBase) addr() net.Addr { + return c.resolvedAddr +} + +func (c *candidateBase) filterForLocationTracking() bool { + return c.isLocationTracked +} + +func (c *candidateBase) agent() *Agent { + return c.currAgent +} + +func (c *candidateBase) context() context.Context { + return c +} + +func (c *candidateBase) copy() (Candidate, error) { + return UnmarshalCandidate(c.Marshal()) +} + +func removeZoneIDFromAddress(addr string) string { + if before, _, ok := strings.Cut(addr, "%"); ok { + return before + } + + return addr +} + +// Marshal returns the string representation of the ICECandidate. +func (c *candidateBase) Marshal() string { + val := c.Foundation() + if val == " " { + val = "" + } + + val = fmt.Sprintf("%s %d %s %d %s %d typ %s", + val, + c.Component(), + c.NetworkType().NetworkShort(), + c.Priority(), + removeZoneIDFromAddress(c.Address()), + c.Port(), + c.Type()) + + if r := c.RelatedAddress(); r != nil && r.Address != "" && r.Port != 0 { + val = fmt.Sprintf("%s raddr %s rport %d", + val, + r.Address, + r.Port) + } + + extensions := c.marshalExtensions() + + if extensions != "" { + val = fmt.Sprintf("%s %s", val, extensions) + } + + return val +} + +// CandidateExtension represents a single candidate extension +// as defined in https://tools.ietf.org/html/rfc5245#section-15.1 +// . +type CandidateExtension struct { + Key string + Value string +} + +func (c *candidateBase) Extensions() []CandidateExtension { + tcpType := c.TCPType() + hasTCPType := 0 + if tcpType != TCPTypeUnspecified { + hasTCPType = 1 + } + + extensions := make([]CandidateExtension, len(c.extensions)+hasTCPType) + // We store the TCPType in c.tcpType, but we need to return it as an extension. + if hasTCPType == 1 { + extensions[0] = CandidateExtension{ + Key: "tcptype", + Value: tcpType.String(), + } + } + + copy(extensions[hasTCPType:], c.extensions) + + return extensions +} + +// Get returns the value of the given key if it exists. +func (c *candidateBase) GetExtension(key string) (CandidateExtension, bool) { + extension := CandidateExtension{Key: key} + + for i := range c.extensions { + if c.extensions[i].Key == key { + extension.Value = c.extensions[i].Value + + return extension, true + } + } + + // TCPType was manually set. + if key == "tcptype" && c.TCPType() != TCPTypeUnspecified { //nolint:goconst + extension.Value = c.TCPType().String() + + return extension, true + } + + return extension, false +} + +func (c *candidateBase) AddExtension(ext CandidateExtension) error { + if ext.Key == "tcptype" { + tcpType := NewTCPType(ext.Value) + if tcpType == TCPTypeUnspecified { + return fmt.Errorf("%w: invalid or unsupported TCPtype %s", errParseTCPType, ext.Value) + } + + c.tcpType = tcpType + + return nil + } + + if ext.Key == "" { + return fmt.Errorf("%w: key is empty", errParseExtension) + } + + // per spec, Extensions aren't explicitly unique, we only set the first one. + // If the exteion is set multiple times. + for i := range c.extensions { + if c.extensions[i].Key == ext.Key { + c.extensions[i] = ext + + return nil + } + } + + c.extensions = append(c.extensions, ext) + + return nil +} + +func (c *candidateBase) RemoveExtension(key string) (ok bool) { + if key == "tcptype" { + c.tcpType = TCPTypeUnspecified + ok = true + } + + for i := range c.extensions { + if c.extensions[i].Key == key { + c.extensions = append(c.extensions[:i], c.extensions[i+1:]...) + ok = true + + break + } + } + + return ok +} + +// marshalExtensions returns the string representation of the candidate extensions. +func (c *candidateBase) marshalExtensions() string { + value := "" + exts := c.Extensions() + + for i := range exts { + if value != "" { + value += " " + } + + value += exts[i].Key + " " + exts[i].Value + } + + return value +} + +// Equal returns true if the candidate extensions are equal. +func (c *candidateBase) extensionsEqual(other []CandidateExtension) bool { + freq1 := make(map[CandidateExtension]int) + freq2 := make(map[CandidateExtension]int) + + if len(c.extensions) != len(other) { + return false + } + + if len(c.extensions) == 0 { + return true + } + + if len(c.extensions) == 1 { + return c.extensions[0] == other[0] + } + + for i := range c.extensions { + freq1[c.extensions[i]]++ + freq2[other[i]]++ + } + + for k, v := range freq1 { + if freq2[k] != v { + return false + } + } + + return true +} + +func (c *candidateBase) setExtensions(extensions []CandidateExtension) { + c.extensions = extensions +} + +// UnmarshalCandidate Parses a candidate from a string +// https://datatracker.ietf.org/doc/html/rfc5245#section-15.1 +func UnmarshalCandidate(raw string) (Candidate, error) { //nolint:cyclop + // Handle candidates with the "candidate:" prefix as defined in RFC 5245 section 15.1. + raw = strings.TrimPrefix(raw, "candidate:") + + pos := 0 + // foundation ( 1*32ice-char ) But we allow for empty foundation, + foundation, pos, err := readCandidateCharToken(raw, pos, 32) + if err != nil { + return nil, fmt.Errorf("%w: %v in %s", errParseFoundation, err, raw) //nolint:errorlint // we wrap the error + } + + // Empty foundation, not RFC 8445 compliant but seen in the wild + if foundation == "" { + foundation = " " + } + + if pos >= len(raw) { + return nil, fmt.Errorf("%w: expected component in %s", errAttributeTooShortICECandidate, raw) + } + + // component-id ( 1*5DIGIT ) + component, pos, err := readCandidateDigitToken(raw, pos, 5) + if err != nil { + return nil, fmt.Errorf("%w: %v in %s", errParseComponent, err, raw) //nolint:errorlint // we wrap the error + } + + if pos >= len(raw) { + return nil, fmt.Errorf("%w: expected transport in %s", errAttributeTooShortICECandidate, raw) + } + + // transport ( "UDP" / transport-extension ; from RFC 3261 ) SP + protocol, pos := readCandidateStringToken(raw, pos) + + if pos >= len(raw) { + return nil, fmt.Errorf("%w: expected priority in %s", errAttributeTooShortICECandidate, raw) + } + + // priority ( 1*10DIGIT ) SP + priority, pos, err := readCandidateDigitToken(raw, pos, 10) + if err != nil { + return nil, fmt.Errorf("%w: %v in %s", errParsePriority, err, raw) //nolint:errorlint // we wrap the error + } + + if pos >= len(raw) { + return nil, fmt.Errorf("%w: expected address in %s", errAttributeTooShortICECandidate, raw) + } + + // connection-address SP ;from RFC 4566 + address, pos := readCandidateStringToken(raw, pos) + + // Remove IPv6 ZoneID: https://github.com/pion/ice/pull/704 + address = removeZoneIDFromAddress(address) + + if pos >= len(raw) { + return nil, fmt.Errorf("%w: expected port in %s", errAttributeTooShortICECandidate, raw) + } + + // port from RFC 4566 + port, pos, err := readCandidatePort(raw, pos) + if err != nil { + return nil, fmt.Errorf("%w: %v in %s", errParsePort, err, raw) //nolint:errorlint // we wrap the error + } + + // "typ" SP + typeKey, pos := readCandidateStringToken(raw, pos) + if typeKey != "typ" { + return nil, fmt.Errorf("%w (%s)", ErrUnknownCandidateTyp, typeKey) + } + + if pos >= len(raw) { + return nil, fmt.Errorf("%w: expected candidate type in %s", errAttributeTooShortICECandidate, raw) + } + + // SP cand-type ("host" / "srflx" / "prflx" / "relay") + typ, pos := readCandidateStringToken(raw, pos) + + raddr, rport, pos, err := tryReadRelativeAddrs(raw, pos) + if err != nil { + return nil, err + } + + tcpType := TCPTypeUnspecified + var extensions []CandidateExtension + var tcpTypeRaw string + + if pos < len(raw) { + extensions, tcpTypeRaw, err = unmarshalCandidateExtensions(raw[pos:]) + if err != nil { + return nil, fmt.Errorf("%w: %v", errParseExtension, err) //nolint:errorlint // we wrap the error + } + + if tcpTypeRaw != "" { + tcpType = NewTCPType(tcpTypeRaw) + if tcpType == TCPTypeUnspecified { + return nil, fmt.Errorf("%w: invalid or unsupported TCPtype %s", errParseTCPType, tcpTypeRaw) + } + } + } + + // this code is ugly because we can't break backwards compatibility + // with the old way of parsing candidates + switch typ { + case "host": + candidate, err := NewCandidateHost(&CandidateHostConfig{ + "", + protocol, + address, + port, + uint16(component), //nolint:gosec // G115 no overflow we read 5 digits + uint32(priority), //nolint:gosec // G115 no overflow we read 5 digits + foundation, + tcpType, + false, + }) + if err != nil { + return nil, err + } + + candidate.setExtensions(extensions) + + return candidate, nil + case "srflx": + candidate, err := NewCandidateServerReflexive(&CandidateServerReflexiveConfig{ + "", + protocol, + address, + port, + uint16(component), //nolint:gosec // G115 no overflow we read 5 digits + uint32(priority), //nolint:gosec // G115 no overflow we read 5 digits + foundation, + raddr, + rport, + }) + if err != nil { + return nil, err + } + + candidate.setExtensions(extensions) + + return candidate, nil + case "prflx": + candidate, err := NewCandidatePeerReflexive(&CandidatePeerReflexiveConfig{ + "", + protocol, + address, + port, + uint16(component), //nolint:gosec // G115 no overflow we read 5 digits + uint32(priority), //nolint:gosec // G115 no overflow we read 5 digits + foundation, + raddr, + rport, + }) + if err != nil { + return nil, err + } + + candidate.setExtensions(extensions) + + return candidate, nil + case "relay": + candidate, err := NewCandidateRelay(&CandidateRelayConfig{ + "", + protocol, + address, + port, + uint16(component), //nolint:gosec // G115 no overflow we read 5 digits + uint32(priority), //nolint:gosec // G115 no overflow we read 5 digits + foundation, + raddr, + rport, + "", + nil, + }) + if err != nil { + return nil, err + } + + candidate.setExtensions(extensions) + + return candidate, nil + default: + return nil, fmt.Errorf("%w (%s)", ErrUnknownCandidateTyp, typ) + } +} + +// Read an ice-char token from the raw string +// ice-char = ALPHA / DIGIT / "+" / "/" +// stop reading when a space is encountered or the end of the string. +func readCandidateCharToken(raw string, start int, limit int) (string, int, error) { //nolint:cyclop + for i, char := range raw[start:] { + if char == 0x20 { // SP + return raw[start : start+i], start + i + 1, nil + } + + if i == limit { + //nolint: err113 // handled by caller + return "", 0, fmt.Errorf("token too long: %s expected 1x%d", raw[start:start+i], limit) + } + + if (char < 'A' || char > 'Z') && + (char < 'a' || char > 'z') && + (char < '0' || char > '9') && + char != '+' && char != '/' { + return "", 0, fmt.Errorf("invalid ice-char token: %c", char) //nolint: err113 // handled by caller + } + } + + return raw[start:], len(raw), nil +} + +// Read an ice string token from the raw string until a space is encountered +// Or the end of the string, we imply that ice string are UTF-8 encoded. +func readCandidateStringToken(raw string, start int) (string, int) { + for i, char := range raw[start:] { + if char == 0x20 { // SP + return raw[start : start+i], start + i + 1 + } + } + + return raw[start:], len(raw) +} + +// Read a digit token from the raw string +// stop reading when a space is encountered or the end of the string. +func readCandidateDigitToken(raw string, start, limit int) (int, int, error) { + var val int + for i, char := range raw[start:] { + if char == 0x20 { // SP + return val, start + i + 1, nil + } + + if i == limit { + //nolint: err113 // handled by caller + return 0, 0, fmt.Errorf("token too long: %s expected 1x%d", raw[start:start+i], limit) + } + + if char < '0' || char > '9' { + return 0, 0, fmt.Errorf("invalid digit token: %c", char) //nolint: err113 // handled by caller + } + + val = val*10 + int(char-'0') + } + + return val, len(raw), nil +} + +// Read and validate RFC 4566 port from the raw string. +func readCandidatePort(raw string, start int) (int, int, error) { + port, pos, err := readCandidateDigitToken(raw, start, 5) + if err != nil { + return 0, 0, err + } + + if port > 65535 { + return 0, 0, fmt.Errorf("invalid RFC 4566 port %d", port) //nolint: err113 // handled by caller + } + + return port, pos, nil +} + +// Read a byte-string token from the raw string +// As defined in RFC 4566 1*(%x01-09/%x0B-0C/%x0E-FF) ;any byte except NUL, CR, or LF +// we imply that extensions byte-string are UTF-8 encoded. +func readCandidateByteString(raw string, start int) (string, int, error) { + for i, char := range raw[start:] { + if char == 0x20 { // SP + return raw[start : start+i], start + i + 1, nil + } + + // 1*(%x01-09/%x0B-0C/%x0E-FF) + if (char < 0x01 || char > 0x09) && + (char < 0x0B || char > 0x0C) && + (char < 0x0E || char > 0xFF) { + return "", 0, fmt.Errorf("invalid byte-string character: %c", char) //nolint: err113 // handled by caller + } + } + + return raw[start:], len(raw), nil +} + +// Read and validate raddr and rport from the raw string +// [SP rel-addr] [SP rel-port] +// defined in https://datatracker.ietf.org/doc/html/rfc5245#section-15.1 +// . +func tryReadRelativeAddrs(raw string, start int) (raddr string, rport, pos int, err error) { + key, pos := readCandidateStringToken(raw, start) + + if key != "raddr" { + return "", 0, start, nil + } + + if pos >= len(raw) { + return "", 0, 0, fmt.Errorf("%w: expected raddr value in %s", errParseRelatedAddr, raw) + } + + raddr, pos = readCandidateStringToken(raw, pos) + + if pos >= len(raw) { + return "", 0, 0, fmt.Errorf("%w: expected rport in %s", errParseRelatedAddr, raw) + } + + key, pos = readCandidateStringToken(raw, pos) + if key != "rport" { + return "", 0, 0, fmt.Errorf("%w: expected rport in %s", errParseRelatedAddr, raw) + } + + if pos >= len(raw) { + return "", 0, 0, fmt.Errorf("%w: expected rport value in %s", errParseRelatedAddr, raw) + } + + rport, pos, err = readCandidatePort(raw, pos) + if err != nil { + return "", 0, 0, fmt.Errorf("%w: %v", errParseRelatedAddr, err) //nolint:errorlint // we wrap the error + } + + return raddr, rport, pos, nil +} + +// UnmarshalCandidateExtensions parses the candidate extensions from the raw string. +// *(SP extension-att-name SP extension-att-value) +// Where extension-att-name, and extension-att-value are byte-strings +// as defined in https://tools.ietf.org/html/rfc5245#section-15.1 +func unmarshalCandidateExtensions(raw string) (extensions []CandidateExtension, rawTCPTypeRaw string, err error) { + extensions = make([]CandidateExtension, 0) + + if raw == "" { + return extensions, "", nil + } + + if raw[0] == 0x20 { // SP + return extensions, "", fmt.Errorf("%w: unexpected space %s", errParseExtension, raw) + } + + for i := 0; i < len(raw); { + key, next, err := readCandidateByteString(raw, i) + if err != nil { + return extensions, "", fmt.Errorf( + "%w: failed to read key %v", errParseExtension, err, //nolint: errorlint // we wrap the error + ) + } + i = next + + // while not spec-compliant, we allow for empty values, as seen in the wild + var value string + if i < len(raw) { + value, next, err = readCandidateByteString(raw, i) + if err != nil { + return extensions, "", fmt.Errorf( + "%w: failed to read value %v", errParseExtension, err, //nolint: errorlint // we are wrapping the error + ) + } + i = next + } + + if key == "tcptype" { + rawTCPTypeRaw = value + + continue + } + + extensions = append(extensions, CandidateExtension{key, value}) + } + + return extensions, rawTCPTypeRaw, nil +} diff --git a/vendor/github.com/pion/ice/v4/candidate_host.go b/vendor/github.com/pion/ice/v4/candidate_host.go new file mode 100644 index 0000000..d49e5d2 --- /dev/null +++ b/vendor/github.com/pion/ice/v4/candidate_host.go @@ -0,0 +1,82 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package ice + +import ( + "net/netip" + "strings" +) + +// CandidateHost is a candidate of type host. +type CandidateHost struct { + candidateBase + + network string +} + +// CandidateHostConfig is the config required to create a new CandidateHost. +type CandidateHostConfig struct { + CandidateID string + Network string + Address string + Port int + Component uint16 + Priority uint32 + Foundation string + TCPType TCPType + IsLocationTracked bool +} + +// NewCandidateHost creates a new host candidate. +func NewCandidateHost(config *CandidateHostConfig) (*CandidateHost, error) { + candidateID := config.CandidateID + + if candidateID == "" { + candidateID = globalCandidateIDGenerator.Generate() + } + + candidateHost := &CandidateHost{ + candidateBase: candidateBase{ + id: candidateID, + address: config.Address, + candidateType: CandidateTypeHost, + component: config.Component, + port: config.Port, + tcpType: config.TCPType, + foundationOverride: config.Foundation, + priorityOverride: config.Priority, + remoteCandidateCaches: map[AddrPort]Candidate{}, + isLocationTracked: config.IsLocationTracked, + }, + network: config.Network, + } + + if !strings.HasSuffix(config.Address, ".local") && !strings.HasSuffix(config.Address, ".invalid") { + ipAddr, err := netip.ParseAddr(config.Address) + if err != nil { + return nil, err + } + + if err := candidateHost.setIPAddr(ipAddr); err != nil { + return nil, err + } + } else { + // Until mDNS candidate is resolved assume it is UDPv4 + candidateHost.candidateBase.networkType = NetworkTypeUDP4 + } + + return candidateHost, nil +} + +func (c *CandidateHost) setIPAddr(addr netip.Addr) error { + networkType, err := determineNetworkType(c.network, addr) + if err != nil { + return err + } + + c.candidateBase.networkType = networkType + c.candidateBase.resolvedAddr = createAddr(networkType, addr, c.port) + + return nil +} diff --git a/vendor/github.com/pion/ice/v4/candidate_peer_reflexive.go b/vendor/github.com/pion/ice/v4/candidate_peer_reflexive.go new file mode 100644 index 0000000..e4e06f8 --- /dev/null +++ b/vendor/github.com/pion/ice/v4/candidate_peer_reflexive.go @@ -0,0 +1,66 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +// Package ice ... +// +//nolint:dupl +package ice + +import ( + "net/netip" +) + +// CandidatePeerReflexive ... +type CandidatePeerReflexive struct { + candidateBase +} + +// CandidatePeerReflexiveConfig is the config required to create a new CandidatePeerReflexive. +type CandidatePeerReflexiveConfig struct { + CandidateID string + Network string + Address string + Port int + Component uint16 + Priority uint32 + Foundation string + RelAddr string + RelPort int +} + +// NewCandidatePeerReflexive creates a new peer reflective candidate. +func NewCandidatePeerReflexive(config *CandidatePeerReflexiveConfig) (*CandidatePeerReflexive, error) { + ipAddr, err := netip.ParseAddr(config.Address) + if err != nil { + return nil, err + } + + networkType, err := determineNetworkType(config.Network, ipAddr) + if err != nil { + return nil, err + } + + candidateID := config.CandidateID + if candidateID == "" { + candidateID = globalCandidateIDGenerator.Generate() + } + + return &CandidatePeerReflexive{ + candidateBase: candidateBase{ + id: candidateID, + networkType: networkType, + candidateType: CandidateTypePeerReflexive, + address: config.Address, + port: config.Port, + resolvedAddr: createAddr(networkType, ipAddr, config.Port), + component: config.Component, + foundationOverride: config.Foundation, + priorityOverride: config.Priority, + relatedAddress: &CandidateRelatedAddress{ + Address: config.RelAddr, + Port: config.RelPort, + }, + remoteCandidateCaches: map[AddrPort]Candidate{}, + }, + }, nil +} diff --git a/vendor/github.com/pion/ice/v4/candidate_relay.go b/vendor/github.com/pion/ice/v4/candidate_relay.go new file mode 100644 index 0000000..3adaebe --- /dev/null +++ b/vendor/github.com/pion/ice/v4/candidate_relay.go @@ -0,0 +1,129 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package ice + +import ( + "net" + "net/netip" +) + +const ( + // These preference values come from libwebrtc + //nolint:lll + // https://source.chromium.org/chromium/chromium/src/+/main:third_party/webrtc/p2p/base/p2p_constants.h;l=126;drc=bf712ec1a13783224debb691ba88ad5c15b93194 + preferenceRelayTLS = 0 + preferenceRelayTCP = 1 + preferenceRelayDTLS = 2 + preferenceRelayUDP = 3 +) + +// CandidateRelay ... +type CandidateRelay struct { + candidateBase + + relayProtocol string + onClose func() error +} + +// CandidateRelayConfig is the config required to create a new CandidateRelay. +type CandidateRelayConfig struct { + CandidateID string + Network string + Address string + Port int + Component uint16 + Priority uint32 + Foundation string + RelAddr string + RelPort int + RelayProtocol string + OnClose func() error +} + +// NewCandidateRelay creates a new relay candidate. +func NewCandidateRelay(config *CandidateRelayConfig) (*CandidateRelay, error) { + candidateID := config.CandidateID + + if candidateID == "" { + candidateID = globalCandidateIDGenerator.Generate() + } + + ipAddr, err := netip.ParseAddr(config.Address) + if err != nil { + return nil, err + } + + networkType, err := determineNetworkType(config.Network, ipAddr) + if err != nil { + return nil, err + } + + return &CandidateRelay{ + candidateBase: candidateBase{ + id: candidateID, + networkType: networkType, + candidateType: CandidateTypeRelay, + address: config.Address, + port: config.Port, + resolvedAddr: &net.UDPAddr{ + IP: ipAddr.AsSlice(), + Port: config.Port, + Zone: ipAddr.Zone(), + }, + component: config.Component, + foundationOverride: config.Foundation, + priorityOverride: config.Priority, + relatedAddress: &CandidateRelatedAddress{ + Address: config.RelAddr, + Port: config.RelPort, + }, + relayLocalPreference: relayProtocolPreference(config.RelayProtocol), + remoteCandidateCaches: map[AddrPort]Candidate{}, + }, + relayProtocol: config.RelayProtocol, + onClose: config.OnClose, + }, nil +} + +// RelayProtocol returns the protocol used between the endpoint and the relay server. +func (c *CandidateRelay) RelayProtocol() string { + return c.relayProtocol +} + +func (c *CandidateRelay) close() error { + err := c.candidateBase.close() + if c.onClose != nil { + err = c.onClose() + c.onClose = nil + } + + return err +} + +func (c *CandidateRelay) copy() (Candidate, error) { + cc, err := c.candidateBase.copy() + if err != nil { + return nil, err + } + + if ccr, ok := cc.(*CandidateRelay); ok { + ccr.relayProtocol = c.relayProtocol + } + + return cc, nil +} + +// relayProtocolPreference returns the preference for the relay protocol. +func relayProtocolPreference(relayProtocol string) uint16 { + switch relayProtocol { + case relayProtocolTLS: + return preferenceRelayTLS + case tcp: + return preferenceRelayTCP + case relayProtocolDTLS: + return preferenceRelayDTLS + default: + return preferenceRelayUDP + } +} diff --git a/vendor/github.com/pion/ice/v4/candidate_server_reflexive.go b/vendor/github.com/pion/ice/v4/candidate_server_reflexive.go new file mode 100644 index 0000000..7e2eb9d --- /dev/null +++ b/vendor/github.com/pion/ice/v4/candidate_server_reflexive.go @@ -0,0 +1,68 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package ice + +import ( + "net" + "net/netip" +) + +// CandidateServerReflexive ... +type CandidateServerReflexive struct { + candidateBase +} + +// CandidateServerReflexiveConfig is the config required to create a new CandidateServerReflexive. +type CandidateServerReflexiveConfig struct { + CandidateID string + Network string + Address string + Port int + Component uint16 + Priority uint32 + Foundation string + RelAddr string + RelPort int +} + +// NewCandidateServerReflexive creates a new server reflective candidate. +func NewCandidateServerReflexive(config *CandidateServerReflexiveConfig) (*CandidateServerReflexive, error) { + ipAddr, err := netip.ParseAddr(config.Address) + if err != nil { + return nil, err + } + + networkType, err := determineNetworkType(config.Network, ipAddr) + if err != nil { + return nil, err + } + + candidateID := config.CandidateID + if candidateID == "" { + candidateID = globalCandidateIDGenerator.Generate() + } + + return &CandidateServerReflexive{ + candidateBase: candidateBase{ + id: candidateID, + networkType: networkType, + candidateType: CandidateTypeServerReflexive, + address: config.Address, + port: config.Port, + resolvedAddr: &net.UDPAddr{ + IP: ipAddr.AsSlice(), + Port: config.Port, + Zone: ipAddr.Zone(), + }, + component: config.Component, + foundationOverride: config.Foundation, + priorityOverride: config.Priority, + relatedAddress: &CandidateRelatedAddress{ + Address: config.RelAddr, + Port: config.RelPort, + }, + remoteCandidateCaches: map[AddrPort]Candidate{}, + }, + }, nil +} diff --git a/vendor/github.com/pion/ice/v4/candidatepair.go b/vendor/github.com/pion/ice/v4/candidatepair.go new file mode 100644 index 0000000..c0a7daf --- /dev/null +++ b/vendor/github.com/pion/ice/v4/candidatepair.go @@ -0,0 +1,335 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package ice + +import ( + "fmt" + "sync/atomic" + "time" + + "github.com/pion/stun/v3" +) + +func newCandidatePair(local, remote Candidate, controlling bool) *CandidatePair { + return &CandidatePair{ + iceRoleControlling: controlling, + Remote: remote, + Local: local, + state: CandidatePairStateWaiting, + } +} + +// CandidatePair is a combination of a local and remote candidate. +type CandidatePair struct { + id uint64 + iceRoleControlling bool + Remote Candidate + Local Candidate + bindingRequestCount uint16 + state CandidatePairState + nominated bool + nominateOnBindingSuccess bool + + // stats + currentRoundTripTime int64 // in ns + totalRoundTripTime int64 // in ns + + packetsSent uint32 + packetsReceived uint32 + bytesSent uint64 + bytesReceived uint64 + lastPacketSentAt atomic.Value // time.Time + lastPacketReceivedAt atomic.Value // time.Time + + requestsReceived uint64 + requestsSent uint64 + responsesReceived uint64 + responsesSent uint64 + + firstRequestSentAt atomic.Value // time.Time + lastRequestSentAt atomic.Value // time.Time + firstResponseReceivedAt atomic.Value // time.Time + lastResponseReceivedAt atomic.Value // time.Time + firstRequestReceivedAt atomic.Value // time.Time + lastRequestReceivedAt atomic.Value // time.Time +} + +func (p *CandidatePair) String() string { + if p == nil { + return "" + } + + return fmt.Sprintf( + "prio %d (local, prio %d) %s <-> %s (remote, prio %d), state: %s, nominated: %v, nominateOnBindingSuccess: %v", + p.priority(), + p.Local.Priority(), + p.Local, + p.Remote, + p.Remote.Priority(), + p.state, + p.nominated, + p.nominateOnBindingSuccess, + ) +} + +func (p *CandidatePair) equal(other *CandidatePair) bool { + if p == nil && other == nil { + return true + } + if p == nil || other == nil { + return false + } + + return p.Local.Equal(other.Local) && p.Remote.Equal(other.Remote) +} + +// RFC 5245 - 5.7.2. Computing Pair Priority and Ordering Pairs +// Let G be the priority for the candidate provided by the controlling +// agent. Let D be the priority for the candidate provided by the +// controlled agent. +// pair priority = 2^32*MIN(G,D) + 2*MAX(G,D) + (G>D?1:0). +func (p *CandidatePair) priority() uint64 { + var g, d uint32 //nolint:varnamelen // clearer to use g and d here + if p.iceRoleControlling { + g = p.Local.Priority() + d = p.Remote.Priority() + } else { + g = p.Remote.Priority() + d = p.Local.Priority() + } + + // Just implement these here rather + // than fooling around with the math package + localMin := func(x, y uint32) uint64 { + if x < y { + return uint64(x) + } + + return uint64(y) + } + localMax := func(x, y uint32) uint64 { + if x > y { + return uint64(x) + } + + return uint64(y) + } + cmp := func(x, y uint32) uint64 { + if x > y { + return uint64(1) + } + + return uint64(0) + } + + // 1<<32 overflows uint32; and if both g && d are + // maxUint32, this result would overflow uint64 + return (1<<32-1)*localMin(g, d) + 2*localMax(g, d) + cmp(g, d) +} + +func (p *CandidatePair) Write(b []byte) (int, error) { + return p.Local.writeTo(b, p.Remote) +} + +func (a *Agent) sendSTUN(msg *stun.Message, local, remote Candidate) { + _, err := local.writeTo(msg.Raw, remote) + if err != nil { + a.log.Tracef("Failed to send STUN message: %s", err) + } +} + +// UpdateRoundTripTime sets the current round time of this pair and +// accumulates total round trip time and responses received. +func (p *CandidatePair) UpdateRoundTripTime(rtt time.Duration) { + rttNs := rtt.Nanoseconds() + atomic.StoreInt64(&p.currentRoundTripTime, rttNs) + atomic.AddInt64(&p.totalRoundTripTime, rttNs) + atomic.AddUint64(&p.responsesReceived, 1) + + now := time.Now() + p.firstResponseReceivedAt.CompareAndSwap(nil, now) + p.lastResponseReceivedAt.Store(now) +} + +// CurrentRoundTripTime returns the current round trip time in seconds +// https://www.w3.org/TR/webrtc-stats/#dom-rtcicecandidatepairstats-currentroundtriptime +func (p *CandidatePair) CurrentRoundTripTime() float64 { + return time.Duration(atomic.LoadInt64(&p.currentRoundTripTime)).Seconds() +} + +// TotalRoundTripTime returns the current round trip time in seconds +// https://www.w3.org/TR/webrtc-stats/#dom-rtcicecandidatepairstats-totalroundtriptime +func (p *CandidatePair) TotalRoundTripTime() float64 { + return time.Duration(atomic.LoadInt64(&p.totalRoundTripTime)).Seconds() +} + +// RequestsReceived returns the total number of connectivity checks received +// https://www.w3.org/TR/webrtc-stats/#dom-rtcicecandidatepairstats-requestsreceived +func (p *CandidatePair) RequestsReceived() uint64 { + return atomic.LoadUint64(&p.requestsReceived) +} + +// RequestsSent returns the total number of connectivity checks sent +// https://www.w3.org/TR/webrtc-stats/#dom-rtcicecandidatepairstats-requestssent +func (p *CandidatePair) RequestsSent() uint64 { + return atomic.LoadUint64(&p.requestsSent) +} + +// ResponsesReceived returns the total number of connectivity responses received +// https://www.w3.org/TR/webrtc-stats/#dom-rtcicecandidatepairstats-responsesreceived +func (p *CandidatePair) ResponsesReceived() uint64 { + return atomic.LoadUint64(&p.responsesReceived) +} + +// ResponsesSent returns the total number of connectivity responses sent +// https://www.w3.org/TR/webrtc-stats/#dom-rtcicecandidatepairstats-responsessent +func (p *CandidatePair) ResponsesSent() uint64 { + return atomic.LoadUint64(&p.responsesSent) +} + +// PacketsSent returns total application (non-STUN) packets sent on this pair. +func (p *CandidatePair) PacketsSent() uint32 { + return atomic.LoadUint32(&p.packetsSent) +} + +// PacketsReceived returns total application (non-STUN) packets received on this pair. +func (p *CandidatePair) PacketsReceived() uint32 { + return atomic.LoadUint32(&p.packetsReceived) +} + +// BytesSent returns total application bytes sent on this pair. +func (p *CandidatePair) BytesSent() uint64 { + return atomic.LoadUint64(&p.bytesSent) +} + +// BytesReceived returns total application bytes received on this pair. +func (p *CandidatePair) BytesReceived() uint64 { + return atomic.LoadUint64(&p.bytesReceived) +} + +// LastPacketSentAt returns the timestamp of the last application packet sent. +func (p *CandidatePair) LastPacketSentAt() time.Time { + if v, ok := p.lastPacketSentAt.Load().(time.Time); ok { + return v + } + + return time.Time{} +} + +// LastPacketReceivedAt returns the timestamp of the last application packet received. +func (p *CandidatePair) LastPacketReceivedAt() time.Time { + if v, ok := p.lastPacketReceivedAt.Load().(time.Time); ok { + return v + } + + return time.Time{} +} + +// UpdatePacketSent increments packet/byte counters and updates timestamp for a sent application packet. +func (p *CandidatePair) UpdatePacketSent(n int) { + if n <= 0 { + return + } + + atomic.AddUint32(&p.packetsSent, 1) + atomic.AddUint64(&p.bytesSent, uint64(n)) // #nosec G115 -- n > 0 validated above + p.lastPacketSentAt.Store(time.Now()) +} + +// UpdatePacketReceived increments packet/byte counters and updates timestamp for a received application packet. +func (p *CandidatePair) UpdatePacketReceived(n int) { + if n <= 0 { + return + } + + atomic.AddUint32(&p.packetsReceived, 1) + atomic.AddUint64(&p.bytesReceived, uint64(n)) // #nosec G115 -- n > 0 validated above + p.lastPacketReceivedAt.Store(time.Now()) +} + +// FirstRequestSentAt returns the timestamp of the first connectivity check sent. +func (p *CandidatePair) FirstRequestSentAt() time.Time { + if v, ok := p.firstRequestSentAt.Load().(time.Time); ok { + return v + } + + return time.Time{} +} + +// LastRequestSentAt returns the timestamp of the last connectivity check sent. +func (p *CandidatePair) LastRequestSentAt() time.Time { + if v, ok := p.lastRequestSentAt.Load().(time.Time); ok { + return v + } + + return time.Time{} +} + +// Deprecated: use FirstResponseReceivedAt +// FirstReponseReceivedAt returns the timestamp of the first connectivity response received. +func (p *CandidatePair) FirstReponseReceivedAt() time.Time { + return p.FirstResponseReceivedAt() +} + +// FirstResponseReceivedAt returns the timestamp of the first connectivity response received. +func (p *CandidatePair) FirstResponseReceivedAt() time.Time { + if v, ok := p.firstResponseReceivedAt.Load().(time.Time); ok { + return v + } + + return time.Time{} +} + +// LastResponseReceivedAt returns the timestamp of the last connectivity response received. +func (p *CandidatePair) LastResponseReceivedAt() time.Time { + if v, ok := p.lastResponseReceivedAt.Load().(time.Time); ok { + return v + } + + return time.Time{} +} + +// FirstRequestReceivedAt returns the timestamp of the first connectivity check received. +func (p *CandidatePair) FirstRequestReceivedAt() time.Time { + if v, ok := p.firstRequestReceivedAt.Load().(time.Time); ok { + return v + } + + return time.Time{} +} + +// LastRequestReceivedAt returns the timestamp of the last connectivity check received. +func (p *CandidatePair) LastRequestReceivedAt() time.Time { + if v, ok := p.lastRequestReceivedAt.Load().(time.Time); ok { + return v + } + + return time.Time{} +} + +// UpdateRequestSent increments the number of requests sent and updates the timestamp. +func (p *CandidatePair) UpdateRequestSent() { + atomic.AddUint64(&p.requestsSent, 1) + now := time.Now() + p.firstRequestSentAt.CompareAndSwap(nil, now) + p.lastRequestSentAt.Store(now) +} + +// UpdateResponseSent increments the number of responses sent. +func (p *CandidatePair) UpdateResponseSent() { + atomic.AddUint64(&p.responsesSent, 1) +} + +// UpdateRequestReceived increments the number of requests received and updates the timestamp. +func (p *CandidatePair) UpdateRequestReceived() { + atomic.AddUint64(&p.requestsReceived, 1) + now := time.Now() + p.firstRequestReceivedAt.CompareAndSwap(nil, now) + p.lastRequestReceivedAt.Store(now) +} + +// ID returns the unique identifier for this candidate pair. +func (p *CandidatePair) ID() uint64 { + return p.id +} diff --git a/vendor/github.com/pion/ice/v4/candidatepair_state.go b/vendor/github.com/pion/ice/v4/candidatepair_state.go new file mode 100644 index 0000000..70ce2d5 --- /dev/null +++ b/vendor/github.com/pion/ice/v4/candidatepair_state.go @@ -0,0 +1,41 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package ice + +// CandidatePairState represent the ICE candidate pair state. +type CandidatePairState int + +const ( + // CandidatePairStateWaiting means a check has not been performed for + // this pair. + CandidatePairStateWaiting CandidatePairState = iota + 1 + + // CandidatePairStateInProgress means a check has been sent for this pair, + // but the transaction is in progress. + CandidatePairStateInProgress + + // CandidatePairStateFailed means a check for this pair was already done + // and failed, either never producing any response or producing an unrecoverable + // failure response. + CandidatePairStateFailed + + // CandidatePairStateSucceeded means a check for this pair was already + // done and produced a successful result. + CandidatePairStateSucceeded +) + +func (c CandidatePairState) String() string { + switch c { + case CandidatePairStateWaiting: + return "waiting" + case CandidatePairStateInProgress: + return "in-progress" + case CandidatePairStateFailed: + return "failed" + case CandidatePairStateSucceeded: + return "succeeded" + } + + return "Unknown candidate pair state" +} diff --git a/vendor/github.com/pion/ice/v4/candidaterelatedaddress.go b/vendor/github.com/pion/ice/v4/candidaterelatedaddress.go new file mode 100644 index 0000000..8ab014f --- /dev/null +++ b/vendor/github.com/pion/ice/v4/candidaterelatedaddress.go @@ -0,0 +1,34 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package ice + +import "fmt" + +// CandidateRelatedAddress convey transport addresses related to the +// candidate, useful for diagnostics and other purposes. +type CandidateRelatedAddress struct { + Address string + Port int +} + +// String makes CandidateRelatedAddress printable. +func (c *CandidateRelatedAddress) String() string { + if c == nil { + return "" + } + + return fmt.Sprintf(" related %s:%d", c.Address, c.Port) +} + +// Equal allows comparing two CandidateRelatedAddresses. +// The CandidateRelatedAddress are allowed to be nil. +func (c *CandidateRelatedAddress) Equal(other *CandidateRelatedAddress) bool { + if c == nil && other == nil { + return true + } + + return c != nil && other != nil && + c.Address == other.Address && + c.Port == other.Port +} diff --git a/vendor/github.com/pion/ice/v4/candidatetype.go b/vendor/github.com/pion/ice/v4/candidatetype.go new file mode 100644 index 0000000..ac1a731 --- /dev/null +++ b/vendor/github.com/pion/ice/v4/candidatetype.go @@ -0,0 +1,65 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package ice + +import "slices" + +// CandidateType represents the type of candidate. +type CandidateType byte + +// CandidateType enum. +const ( + CandidateTypeUnspecified CandidateType = iota + CandidateTypeHost + CandidateTypeServerReflexive + CandidateTypePeerReflexive + CandidateTypeRelay +) + +// String makes CandidateType printable. +func (c CandidateType) String() string { + switch c { + case CandidateTypeHost: + return "host" + case CandidateTypeServerReflexive: + return "srflx" + case CandidateTypePeerReflexive: + return "prflx" + case CandidateTypeRelay: + return "relay" + case CandidateTypeUnspecified: + return "Unknown candidate type" + } + + return "Unknown candidate type" +} + +// Preference returns the preference weight of a CandidateType +// +// 4.1.2.2. Guidelines for Choosing Type and Local Preferences +// The RECOMMENDED values are 126 for host candidates, 100 +// for server reflexive candidates, 110 for peer reflexive candidates, +// and 0 for relayed candidates. +func (c CandidateType) Preference() uint16 { + switch c { + case CandidateTypeHost: + return 126 + case CandidateTypePeerReflexive: + return 110 + case CandidateTypeServerReflexive: + return 100 + case CandidateTypeRelay, CandidateTypeUnspecified: + return 0 + } + + return 0 +} + +func containsCandidateType(candidateType CandidateType, candidateTypeList []CandidateType) bool { + if candidateTypeList == nil { + return false + } + + return slices.Contains(candidateTypeList, candidateType) +} diff --git a/vendor/github.com/pion/ice/v4/codecov.yml b/vendor/github.com/pion/ice/v4/codecov.yml new file mode 100644 index 0000000..b9639c2 --- /dev/null +++ b/vendor/github.com/pion/ice/v4/codecov.yml @@ -0,0 +1,22 @@ +# +# DO NOT EDIT THIS FILE +# +# It is automatically copied from https://github.com/pion/.goassets repository. +# +# SPDX-FileCopyrightText: 2026 The Pion community +# SPDX-License-Identifier: MIT + +coverage: + status: + project: + default: + # Allow decreasing 2% of total coverage to avoid noise. + threshold: 2% + patch: + default: + target: 70% + only_pulls: true + +ignore: + - "examples/*" + - "examples/**/*" diff --git a/vendor/github.com/pion/ice/v4/errors.go b/vendor/github.com/pion/ice/v4/errors.go new file mode 100644 index 0000000..1d0e0d4 --- /dev/null +++ b/vendor/github.com/pion/ice/v4/errors.go @@ -0,0 +1,203 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package ice + +import ( + "errors" + + "github.com/pion/ice/v4/internal/taskloop" +) + +var ( + // ErrUnknownType indicates an error with Unknown info. + ErrUnknownType = errors.New("Unknown") + + // ErrSchemeType indicates the scheme type could not be parsed. + ErrSchemeType = errors.New("unknown scheme type") + + // ErrSTUNQuery indicates query arguments are provided in a STUN URL. + ErrSTUNQuery = errors.New("queries not supported in STUN address") + + // ErrInvalidQuery indicates an malformed query is provided. + ErrInvalidQuery = errors.New("invalid query") + + // ErrHost indicates malformed hostname is provided. + ErrHost = errors.New("invalid hostname") + + // ErrPort indicates malformed port is provided. + ErrPort = errors.New("invalid port") + + // ErrLocalUfragInsufficientBits indicates local username fragment insufficient bits are provided. + // Have to be at least 24 bits long. + ErrLocalUfragInsufficientBits = errors.New("local username fragment is less than 24 bits long") + + // ErrLocalPwdInsufficientBits indicates local password insufficient bits are provided. + // Have to be at least 128 bits long. + ErrLocalPwdInsufficientBits = errors.New("local password is less than 128 bits long") + + // ErrProtoType indicates an unsupported transport type was provided. + ErrProtoType = errors.New("invalid transport protocol type") + + // ErrClosed indicates the agent is closed. + ErrClosed = taskloop.ErrClosed + + // ErrNoCandidatePairs indicates agent does not have a valid candidate pair. + ErrNoCandidatePairs = errors.New("no candidate pairs available") + + // ErrCanceledByCaller indicates agent connection was canceled by the caller. + ErrCanceledByCaller = errors.New("connecting canceled by caller") + + // ErrMultipleStart indicates agent was started twice. + ErrMultipleStart = errors.New("attempted to start agent twice") + + // ErrRemoteUfragEmpty indicates agent was started with an empty remote ufrag. + ErrRemoteUfragEmpty = errors.New("remote ufrag is empty") + + // ErrRemotePwdEmpty indicates agent was started with an empty remote pwd. + ErrRemotePwdEmpty = errors.New("remote pwd is empty") + + // ErrNoOnCandidateHandler indicates agent was started without OnCandidate. + ErrNoOnCandidateHandler = errors.New("no OnCandidate provided") + + // ErrMultipleGatherAttempted indicates GatherCandidates has been called multiple times. + ErrMultipleGatherAttempted = errors.New("attempting to gather candidates during gathering state") + + // ErrUsernameEmpty indicates agent was give TURN URL with an empty Username. + ErrUsernameEmpty = errors.New("username is empty") + + // ErrPasswordEmpty indicates agent was give TURN URL with an empty Password. + ErrPasswordEmpty = errors.New("password is empty") + + // ErrAddressParseFailed indicates we were unable to parse a candidate address. + ErrAddressParseFailed = errors.New("failed to parse address") + + // ErrLiteUsingNonHostCandidates indicates non host candidates were selected for a lite agent. + ErrLiteUsingNonHostCandidates = errors.New("lite agents must only use host candidates") + + // ErrUselessUrlsProvided indicates that one or more URL was provided to the agent but no host + // candidate required them. + ErrUselessUrlsProvided = errors.New("agent does not need URL with selected candidate types") + + // ErrUnsupportedNAT1To1IPCandidateType indicates that the specified NAT1To1IPCandidateType is + // unsupported. + // + // Deprecated: use ErrUnsupportedAddressRewriteCandidateType instead. May still be returned + // when configuring address rewrite rules while NAT1:1 compatibility remains. + ErrUnsupportedNAT1To1IPCandidateType = errors.New("unsupported address rewrite candidate type") + // ErrUnsupportedAddressRewriteCandidateType is an alias for ErrUnsupportedNAT1To1IPCandidateType. + ErrUnsupportedAddressRewriteCandidateType = ErrUnsupportedNAT1To1IPCandidateType + + // ErrInvalidNAT1To1IPMapping indicates that the given 1:1 NAT IP mapping is invalid. + // + // Deprecated: use ErrInvalidAddressRewriteMapping instead. May still be returned by + // WithAddressRewriteRules while NAT1:1 compatibility remains. + ErrInvalidNAT1To1IPMapping = errors.New("invalid address rewrite mapping") + // ErrInvalidAddressRewriteMapping is an alias for ErrInvalidNAT1To1IPMapping. + ErrInvalidAddressRewriteMapping = ErrInvalidNAT1To1IPMapping + + // ErrExternalMappedIPNotFound in address rewrite mapping. + // + // Kept for compatibility; current code paths treat "no externals" via match state and + // no longer return this error. + ErrExternalMappedIPNotFound = errors.New("external mapped IP not found") + + // ErrMulticastDNSWithNAT1To1IPMapping indicates that the mDNS gathering cannot be used along + // with 1:1 NAT IP mapping for host candidate. + // + // Deprecated: use ErrMulticastDNSWithAddressRewrite instead. May still be returned by + // WithAddressRewriteRules while NAT1:1 compatibility remains. + ErrMulticastDNSWithNAT1To1IPMapping = errors.New( + "mDNS gathering cannot be used with address rewrite for host candidate", + ) + // ErrMulticastDNSWithAddressRewrite is an alias for ErrMulticastDNSWithNAT1To1IPMapping. + ErrMulticastDNSWithAddressRewrite = ErrMulticastDNSWithNAT1To1IPMapping + + // ErrIneffectiveNAT1To1IPMappingHost indicates that 1:1 NAT IP mapping for host candidate is + // requested, but the host candidate type is disabled. + // + // Deprecated: use ErrIneffectiveAddressRewriteHost instead. May still be returned by + // WithAddressRewriteRules while NAT1:1 compatibility remains. + ErrIneffectiveNAT1To1IPMappingHost = errors.New("address rewrite for host candidate ineffective") + // ErrIneffectiveAddressRewriteHost is an alias for ErrIneffectiveNAT1To1IPMappingHost. + ErrIneffectiveAddressRewriteHost = ErrIneffectiveNAT1To1IPMappingHost + + // ErrIneffectiveNAT1To1IPMappingSrflx indicates that 1:1 NAT IP mapping for srflx candidate is + // requested, but the srflx candidate type is disabled. + // + // Deprecated: use ErrIneffectiveAddressRewriteSrflx instead. May still be returned by + // WithAddressRewriteRules while NAT1:1 compatibility remains. + ErrIneffectiveNAT1To1IPMappingSrflx = errors.New("address rewrite for srflx candidate ineffective") + // ErrIneffectiveAddressRewriteSrflx is an alias for ErrIneffectiveNAT1To1IPMappingSrflx. + ErrIneffectiveAddressRewriteSrflx = ErrIneffectiveNAT1To1IPMappingSrflx + + // ErrInvalidMulticastDNSHostName indicates an invalid MulticastDNSHostName. + ErrInvalidMulticastDNSHostName = errors.New( + "invalid mDNS HostName, must end with .local and can only contain a single '.'", + ) + + // ErrRunCanceled indicates a run operation was canceled by its individual done. + ErrRunCanceled = errors.New("run was canceled by done") + + // ErrTCPRemoteAddrAlreadyExists indicates we already have the connection with same remote addr. + ErrTCPRemoteAddrAlreadyExists = errors.New("conn with same remote addr already exists") + + // ErrUnknownCandidateTyp indicates that a candidate had a unknown type value. + ErrUnknownCandidateTyp = errors.New("unknown candidate typ") + + // ErrDetermineNetworkType indicates that the NetworkType was not able to be parsed. + ErrDetermineNetworkType = errors.New("unable to determine networkType") + + // ErrOnlyControllingAgentCanRenominate indicates that only controlling agent can renominate. + ErrOnlyControllingAgentCanRenominate = errors.New("only controlling agent can renominate") + + // ErrRenominationNotEnabled indicates that renomination is not enabled. + ErrRenominationNotEnabled = errors.New("renomination is not enabled") + + // ErrCandidatePairNotFound indicates that candidate pair was not found. + ErrCandidatePairNotFound = errors.New("candidate pair not found") + + // ErrCandidatePairNotSucceeded indicates that candidate pair is not in succeeded state. + ErrCandidatePairNotSucceeded = errors.New("candidate pair not in succeeded state") + + // ErrInvalidNominationAttribute indicates an invalid nomination attribute type was provided. + ErrInvalidNominationAttribute = errors.New("invalid nomination attribute type") + + // ErrInvalidNominationValueGenerator indicates a nil nomination value generator was provided. + ErrInvalidNominationValueGenerator = errors.New("nomination value generator cannot be nil") + + // ErrInvalidNetworkMonitorInterval indicates an invalid network monitor interval was provided. + ErrInvalidNetworkMonitorInterval = errors.New("network monitor interval must be greater than 0") + + // ErrAgentOptionNotUpdatable indicates an option cannot be updated after construction. + ErrAgentOptionNotUpdatable = errors.New("option can only be set during agent construction") + + errAttributeTooShortICECandidate = errors.New("attribute not long enough to be ICE candidate") + errClosingConnection = errors.New("failed to close connection") + errConnectionAddrAlreadyExist = errors.New("connection with same remote address already exists") + errGetXorMappedAddrResponse = errors.New("failed to get XOR-MAPPED-ADDRESS response") + errInvalidAddress = errors.New("invalid address") + errNoTCPMuxAvailable = errors.New("no TCP mux is available") + errNotImplemented = errors.New("not implemented yet") + errNoUDPMuxAvailable = errors.New("no UDP mux is available") + errNoXorAddrMapping = errors.New("no address mapping") + errParseFoundation = errors.New("failed to parse foundation") + errParseComponent = errors.New("failed to parse component") + errParsePort = errors.New("failed to parse port") + errParsePriority = errors.New("failed to parse priority") + errParseRelatedAddr = errors.New("failed to parse related addresses") + errParseExtension = errors.New("failed to parse extension") + errParseTCPType = errors.New("failed to parse TCP type") + errUDPMuxDisabled = errors.New("UDPMux is not enabled") + errUnknownRole = errors.New("unknown role") + errWrite = errors.New("failed to write") + errWriteSTUNMessage = errors.New("failed to send STUN message") + errWriteSTUNMessageToIceConn = errors.New("failed to write STUN message to ICE connection") + errXORMappedAddrTimeout = errors.New("timeout while waiting for XORMappedAddr") + errFailedToCastUDPAddr = errors.New("failed to cast net.Addr to net.UDPAddr") + errInvalidIPAddress = errors.New("invalid ip address") + + // UDPMuxDefault should not listen on unspecified address, but to keep backward compatibility, don't return error now. + // will be used in the future. + // errListenUnspecified = errors.New("can't listen on unspecified address"). +) diff --git a/vendor/github.com/pion/ice/v4/external_ip_mapper.go b/vendor/github.com/pion/ice/v4/external_ip_mapper.go new file mode 100644 index 0000000..975d8cd --- /dev/null +++ b/vendor/github.com/pion/ice/v4/external_ip_mapper.go @@ -0,0 +1,456 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package ice + +import ( + "fmt" + "net" + "strings" +) + +// AddressRewriteMode controls whether a rule replaces or appends candidates. +type AddressRewriteMode int + +const ( + addressRewriteModeUnspecified AddressRewriteMode = iota + AddressRewriteReplace + AddressRewriteAppend +) + +// AddressRewriteRule represents a rule for remapping candidate addresses. +type AddressRewriteRule struct { + // External are the 1:1 external addresses to advertise for this rule. + // For replace mode, an empty list is treated as "drop the matched local + // address" (no candidate emitted). For append mode, an empty list is a + // no-op: the original candidate is kept. + // Empty External rules are intentional: + // - Mode AddressRewriteReplace drops the matched candidate (deny-list style). + // - Mode AddressRewriteAppend keeps the original candidate and adds nothing, + // which is useful when you combine a catch-all replace with per-interface + // allow rules. + External []string + // Local optionally pins this rule to a specific local address. When set, + // external IPs map to that address regardless of IP family. When empty, + // External acts as a catch-all for the family implied by the local scope + // (CIDR when set, otherwise the external IP family). + Local string + // Iface is the optional interface name to limit the rule to, empty = any. + Iface string + // CIDR is the optional CIDR to limit the rule to, empty = any. + CIDR string + // AsCandidateType is the candidate type to publish as for this rule. Defaults to host + // when unspecified. Supported values: host, server reflexive, relay. + AsCandidateType CandidateType + // Mode controls whether we replace the original candidate or append extra + // candidates. + // + // If Mode is zero, the default is: + // - CandidateTypeHost -> AddressRewriteReplace + // - CandidateTypeServerReflexive, CandidateTypeRelay -> AddressRewriteAppend + // For replace mode, a match with zero external IPs removes the candidate. + // For append mode, a match with zero external IPs leaves the original + // candidate untouched. + Mode AddressRewriteMode + // Networks is the optional networks to limit the rule to, nil/empty = all. + Networks []NetworkType +} + +func validateIPString(ipStr string) (net.IP, bool, error) { + ip := net.ParseIP(ipStr) + if ip == nil { + return nil, false, ErrInvalidNAT1To1IPMapping + } + + return ip, (ip.To4() != nil), nil +} + +// ipMapping holds the mapping of local and external IP address +// +// for a particular IP family. +type ipMapping struct { + ipSole []net.IP // When non-empty, these are the catch-all external IPs for one local IP family + ipMap map[string][]net.IP // Local-to-external IP mapping (k: local, v: external IPs) + valid bool // If not set any external IP, valid is false + catchAllSet bool +} + +func newIPMapping() ipMapping { + return ipMapping{ + ipMap: make(map[string][]net.IP), + } +} + +func (m *ipMapping) addSoleIP(ip net.IP) { + m.ipSole = append(m.ipSole, ip) + m.valid = true + m.catchAllSet = true +} + +func addExternalMappings( + external []string, + ruleMapping *addressRewriteRuleMapping, + hasLocalAddr bool, + localAddr net.IP, + localIsIPv4 bool, +) (bool, error) { + added := false + + for _, raw := range external { + extIPStr := strings.TrimSpace(raw) + ipPair := strings.Split(extIPStr, "/") + if len(ipPair) != 1 { + return false, ErrInvalidNAT1To1IPMapping + } + + extIP, isExtIPv4, err := validateIPString(ipPair[0]) + if err != nil { + return false, err + } + + targetLocalIPv4 := isExtIPv4 + if hasLocalAddr { + targetLocalIPv4 = localIsIPv4 + } else if ruleMapping.cidr != nil { + targetLocalIPv4 = ruleMapping.cidr.IP.To4() != nil + } + + if !ruleMapping.isFamilyAllowed(targetLocalIPv4) { + continue + } + + ruleMapping.addImplicitMapping(extIP, targetLocalIPv4, hasLocalAddr, localAddr) + added = true + } + + return added, nil +} + +func maybeMarkEmptyMapping( + ruleMapping *addressRewriteRuleMapping, + added bool, + hasLocalAddr bool, + localIsIPv4 bool, + localAddr net.IP, +) { + if added { + return + } + + if hasLocalAddr { + if ruleMapping.isFamilyAllowed(localIsIPv4) { + family := ruleMapping.mappingForFamily(localIsIPv4) + family.ipMap[localAddr.String()] = nil + family.valid = true + } + + return + } + + if ruleMapping.allowIPv4 { + ruleMapping.ipv4Mapping.valid = true + ruleMapping.ipv4Mapping.catchAllSet = true + } + if ruleMapping.allowIPv6 { + ruleMapping.ipv6Mapping.valid = true + ruleMapping.ipv6Mapping.catchAllSet = true + } +} + +func (m *ipMapping) addIPMapping(locIP, extIP net.IP) { + locIPStr := locIP.String() + + m.ipMap[locIPStr] = append(m.ipMap[locIPStr], extIP) + m.valid = true +} + +func cloneIPs(src []net.IP) []net.IP { + if len(src) == 0 { + return nil + } + + cloned := make([]net.IP, 0, len(src)) + for _, ip := range src { + if ip == nil { + continue + } + copied := make(net.IP, len(ip)) + copy(copied, ip) + cloned = append(cloned, copied) + } + + return cloned +} + +func (m *ipMapping) findExternalIPs(locIP net.IP) []net.IP { + if !m.valid { + return nil + } + + if m.ipMap != nil { + if extIPs, ok := m.ipMap[locIP.String()]; ok && len(extIPs) > 0 { + return cloneIPs(extIPs) + } + } + + if len(m.ipSole) > 0 { + return cloneIPs(m.ipSole) + } + + return nil +} + +type addressRewriteRuleMapping struct { + rule AddressRewriteRule + mode AddressRewriteMode + ipv4Mapping ipMapping + ipv6Mapping ipMapping + cidr *net.IPNet + allowIPv4 bool + allowIPv6 bool +} + +func (m *addressRewriteRuleMapping) hasMappings() bool { + return m.ipv4Mapping.valid || m.ipv6Mapping.valid +} + +func (m *addressRewriteRuleMapping) mappingForFamily(isIPv4 bool) *ipMapping { + if isIPv4 { + return &m.ipv4Mapping + } + + return &m.ipv6Mapping +} + +func (m *addressRewriteRuleMapping) isFamilyAllowed(isLocalIPv4 bool) bool { + if isLocalIPv4 { + return m.allowIPv4 + } + + return m.allowIPv6 +} + +func (m *addressRewriteRuleMapping) addImplicitMapping( + extIP net.IP, + isLocalIPv4 bool, + hasLocalAddr bool, + localAddr net.IP, +) { + mapping := m.mappingForFamily(isLocalIPv4) + if hasLocalAddr { + mapping.addIPMapping(localAddr, extIP) + } else { + mapping.addSoleIP(extIP) + } +} + +type addressRewriteMapper struct { + rulesByCandidateType map[CandidateType][]*addressRewriteRuleMapping +} + +//nolint:gocognit,gocyclo,cyclop +func newAddressRewriteMapper(rules []AddressRewriteRule) (*addressRewriteMapper, error) { + if len(rules) == 0 { + return nil, nil //nolint:nilnil + } + + mapper := &addressRewriteMapper{ + rulesByCandidateType: make(map[CandidateType][]*addressRewriteRuleMapping), + } + + for _, rule := range rules { + candidateType := rule.AsCandidateType + if candidateType == CandidateTypeUnspecified { + candidateType = CandidateTypeHost + } + if candidateType == CandidateTypePeerReflexive { + return nil, ErrUnsupportedNAT1To1IPCandidateType + } + + mode := rule.Mode + if mode == addressRewriteModeUnspecified { + mode = defaultAddressRewriteMode(candidateType) + } + + ruleMapping := &addressRewriteRuleMapping{ + rule: rule, + mode: mode, + ipv4Mapping: newIPMapping(), + ipv6Mapping: newIPMapping(), + allowIPv4: true, + allowIPv6: true, + } + + if len(rule.Networks) > 0 { + ruleMapping.allowIPv4 = false + ruleMapping.allowIPv6 = false + for _, network := range rule.Networks { + if network.IsIPv4() { + ruleMapping.allowIPv4 = true + } + if network.IsIPv6() { + ruleMapping.allowIPv6 = true + } + } + if !ruleMapping.allowIPv4 && !ruleMapping.allowIPv6 { + continue + } + } + if rule.CIDR != "" { + _, ipNet, err := net.ParseCIDR(rule.CIDR) + if err != nil { + return nil, ErrInvalidNAT1To1IPMapping + } + ruleMapping.cidr = ipNet + } + + var ( + localAddr net.IP + localIsIPv4 bool + hasLocalAddr bool + err error + ) + if trimmedLocal := strings.TrimSpace(rule.Local); trimmedLocal != "" { + localAddr, localIsIPv4, err = validateIPString(trimmedLocal) + if err != nil { + return nil, err + } + hasLocalAddr = true + + if ruleMapping.cidr != nil && !ruleMapping.cidr.Contains(localAddr) { + return nil, fmt.Errorf("%w: Invalid local IP is outside CIDR", ErrInvalidNAT1To1IPMapping) + } + } + + added, mapErr := addExternalMappings(rule.External, ruleMapping, hasLocalAddr, localAddr, localIsIPv4) + if mapErr != nil { + return nil, mapErr + } + maybeMarkEmptyMapping(ruleMapping, added, hasLocalAddr, localIsIPv4, localAddr) + + if ruleMapping.hasMappings() { + mapper.rulesByCandidateType[candidateType] = append(mapper.rulesByCandidateType[candidateType], ruleMapping) + } + } + + if len(mapper.rulesByCandidateType) == 0 { + return nil, nil //nolint:nilnil + } + + return mapper, nil +} + +func (m *addressRewriteMapper) hasCandidateType(candidateType CandidateType) bool { + rules := m.rulesByCandidateType[candidateType] + for _, rule := range rules { + if rule.hasMappings() { + return true + } + } + + return false +} + +func (m *addressRewriteMapper) shouldReplace(candidateType CandidateType) bool { + for _, rule := range m.rulesByCandidateType[candidateType] { + if rule.mode == AddressRewriteReplace { + return true + } + } + + return false +} + +func (m *addressRewriteMapper) findExternalIPs( + candidateType CandidateType, + localIPStr string, + iface string, +) ([]net.IP, bool, AddressRewriteMode, error) { + locIP, isLocIPv4, err := validateIPString(localIPStr) + if err != nil { + return nil, false, addressRewriteModeUnspecified, err + } + + rules := m.rulesByCandidateType[candidateType] + ips, matched, mode := evaluateRewriteRules(rules, locIP, isLocIPv4, iface) + + return ips, matched, mode, nil +} + +func ruleMappingForLookup( + rule *addressRewriteRuleMapping, + locIP net.IP, + isLocIPv4 bool, + iface string, +) (*ipMapping, bool) { + if rule.rule.Iface != "" && rule.rule.Iface != iface { + return nil, false + } + if rule.cidr != nil && !rule.cidr.Contains(locIP) { + return nil, false + } + + ipMapping := rule.mappingForFamily(isLocIPv4) + if !ipMapping.valid { + return nil, false + } + + return ipMapping, true +} + +func catchAllSpecificity(rule *addressRewriteRuleMapping, iface string) int { + spec := 0 + if rule.rule.Iface != "" { + spec += 2 + if rule.cidr != nil { + spec++ + } + } else if iface == "" && rule.cidr != nil { + spec++ + } + + return spec +} + +func evaluateRewriteRules( + rules []*addressRewriteRuleMapping, + locIP net.IP, + isLocIPv4 bool, + iface string, +) (ips []net.IP, matched bool, mode AddressRewriteMode) { + var ( + catchAll []net.IP + catchAllMode AddressRewriteMode + hasCatchAll bool + bestSpec = -1 + ) + + for _, rule := range rules { + ipMapping, ok := ruleMappingForLookup(rule, locIP, isLocIPv4, iface) + if !ok { + continue + } + + if explicit, ok := ipMapping.ipMap[locIP.String()]; ok { + cloned := cloneIPs(explicit) + + return cloned, true, rule.mode + } + + if ipMapping.catchAllSet { + spec := catchAllSpecificity(rule, iface) + if !hasCatchAll || spec > bestSpec { + catchAll = cloneIPs(ipMapping.ipSole) + catchAllMode = rule.mode + hasCatchAll = true + bestSpec = spec + } + } + } + + if hasCatchAll { + return catchAll, true, catchAllMode + } + + return nil, false, addressRewriteModeUnspecified +} diff --git a/vendor/github.com/pion/ice/v4/gather.go b/vendor/github.com/pion/ice/v4/gather.go new file mode 100644 index 0000000..3210866 --- /dev/null +++ b/vendor/github.com/pion/ice/v4/gather.go @@ -0,0 +1,1415 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package ice + +import ( + "context" + "crypto/tls" + "io" + "net" + "net/netip" + "reflect" + "strconv" + "sync" + "time" + + "github.com/pion/dtls/v3" + "github.com/pion/ice/v4/internal/fakenet" + stunx "github.com/pion/ice/v4/internal/stun" + "github.com/pion/logging" + "github.com/pion/stun/v3" + "github.com/pion/transport/v4/stdnet" + "github.com/pion/turn/v4" +) + +type turnClient interface { + Listen() error + Allocate() (net.PacketConn, error) + Close() +} + +func defaultTurnClient(cfg *turn.ClientConfig) (turnClient, error) { + return turn.NewClient(cfg) +} + +func configuredNetworkTypes(networkTypes []NetworkType) []NetworkType { + if len(networkTypes) == 0 { + return supportedNetworkTypes() + } + + return networkTypes +} + +func effectiveURLProtoType(url stun.URI) stun.ProtoType { + if url.Proto != stun.ProtoTypeUnknown { + return url.Proto + } + + switch url.Scheme { + case stun.SchemeTypeSTUN, stun.SchemeTypeTURN: + return stun.ProtoTypeUDP + case stun.SchemeTypeSTUNS, stun.SchemeTypeTURNS: + return stun.ProtoTypeTCP + default: + return stun.ProtoTypeUnknown + } +} + +func urlSupportsSrflxGathering(url stun.URI) bool { + if effectiveURLProtoType(url) != stun.ProtoTypeUDP { + return false + } + + return url.Scheme == stun.SchemeTypeSTUN || url.Scheme == stun.SchemeTypeTURN +} + +func relayNetworkTypesForURL(url stun.URI, networkTypes []NetworkType) []NetworkType { + proto := effectiveURLProtoType(url) + switch proto { + case stun.ProtoTypeUDP: + res := []NetworkType{} + for _, networkType := range configuredNetworkTypes(networkTypes) { + if networkType.IsUDP() { + res = append(res, networkType) + } + } + + return res + case stun.ProtoTypeTCP: + res := []NetworkType{} + for _, networkType := range configuredNetworkTypes(networkTypes) { + if networkType.IsTCP() { + res = append(res, networkType) + } + } + + return res + default: + return nil + } +} + +// Close a net.Conn and log if we have a failure. +func closeConnAndLog(c io.Closer, log logging.LeveledLogger, msg string, args ...any) { + if c == nil || (reflect.ValueOf(c).Kind() == reflect.Ptr && reflect.ValueOf(c).IsNil()) { + log.Warnf("Connection is not allocated: "+msg, args...) + + return + } + + log.Warnf(msg, args...) + if err := c.Close(); err != nil { + log.Warnf("Failed to close connection: %v", err) + } +} + +// GatherCandidates initiates the trickle based gathering process. +func (a *Agent) GatherCandidates() error { + var gatherErr error + + if runErr := a.loop.Run(a.loop, func(ctx context.Context) { + if a.gatheringState != GatheringStateNew { + gatherErr = ErrMultipleGatherAttempted + + return + } else if a.onCandidateHdlr.Load() == nil { + gatherErr = ErrNoOnCandidateHandler + + return + } + + a.gatherCandidateCancel() // Cancel previous gathering routine + ctx, cancel := context.WithCancel(ctx) + a.gatherCandidateCancel = cancel + done := make(chan struct{}) + a.gatherCandidateDone = done + + go a.gatherCandidates(ctx, done) + }); runErr != nil { + return runErr + } + + return gatherErr +} + +func (a *Agent) gatherCandidates(ctx context.Context, done chan struct{}) { //nolint:cyclop + defer close(done) + if err := a.setGatheringState(GatheringStateGathering); err != nil { //nolint:contextcheck + a.log.Warnf("Failed to set gatheringState to GatheringStateGathering: %v", err) + + return + } + + a.gatherCandidatesInternal(ctx) + + switch a.continualGatheringPolicy { + case GatherOnce: + if err := a.setGatheringState(GatheringStateComplete); err != nil { //nolint:contextcheck + a.log.Warnf("Failed to set gatheringState to GatheringStateComplete: %v", err) + } + case GatherContinually: + // Initialize known interfaces before starting monitoring + _, addrs, err := localInterfaces( + a.net, + a.interfaceFilter, + a.ipFilter, + a.networkTypes, + a.includeLoopback, + ) + if err != nil { + a.log.Warnf("Failed to get initial interfaces for monitoring: %v", err) + } else { + for _, info := range addrs { + a.lastKnownInterfaces[info.addr.String()] = info.addr + } + a.log.Infof("Initialized network monitoring with %d IP addresses", len(addrs)) + } + go a.startNetworkMonitoring(ctx) + } +} + +func (a *Agent) shouldRewriteCandidateType(candidateType CandidateType) bool { + return a.addressRewriteMapper != nil && a.addressRewriteMapper.hasCandidateType(candidateType) +} + +func (a *Agent) shouldRewriteHostCandidates() bool { + return a.mDNSMode != MulticastDNSModeQueryAndGather && a.shouldRewriteCandidateType(CandidateTypeHost) +} + +func (a *Agent) applyHostAddressRewrite(addr netip.Addr, mappedAddrs []netip.Addr, iface string) ([]netip.Addr, bool) { + mappedIPs, matched, mode, innerErr := a.addressRewriteMapper.findExternalIPs( + CandidateTypeHost, + addr.String(), + iface, + ) + if innerErr != nil { + a.log.Warnf("Address rewrite mapping is enabled but no external IP is found for %s", addr.String()) + + return mappedAddrs, true + } + if !matched { + return mappedAddrs, true + } + + if mode == AddressRewriteReplace { + mappedAddrs = mappedAddrs[:0] + } + mappedAddrs = appendHostMappedAddrs(mappedAddrs, mappedIPs, addr, a.log) + if len(mappedAddrs) == 0 && mode == AddressRewriteReplace { + a.log.Warnf("Address rewrite mapping is enabled but produced no usable external IP for %s", addr.String()) + + return mappedAddrs, false + } + + return mappedAddrs, true +} + +func appendHostMappedAddrs( + mappedAddrs []netip.Addr, + mappedIPs []net.IP, + addr netip.Addr, + log logging.LeveledLogger, +) []netip.Addr { + for _, mappedIP := range mappedIPs { + conv, ok := netip.AddrFromSlice(mappedIP) + if !ok { + log.Warnf("failed to convert mapped external IP to netip.Addr'%s'", addr.String()) + + continue + } + // we'd rather have an IPv4-mapped IPv6 become IPv4 so that it is usable + mappedAddrs = append(mappedAddrs, conv.Unmap()) + } + + return mappedAddrs +} + +func (a *Agent) applyHostRewriteForUDPMux(candidateIPs []net.IP, udpAddr *net.UDPAddr) ([]net.IP, bool) { + mappedIPs, matched, mode, err := a.addressRewriteMapper.findExternalIPs(CandidateTypeHost, udpAddr.IP.String(), "") + if err != nil { + a.log.Warnf("Address rewrite mapping is enabled but failed for %s: %v", udpAddr.IP.String(), err) + + return candidateIPs, false + } + if !matched { + return candidateIPs, true + } + if len(mappedIPs) == 0 { + if mode == AddressRewriteReplace { + return candidateIPs, false + } + + return candidateIPs, true + } + if mode == AddressRewriteReplace { + candidateIPs = candidateIPs[:0] + } + + return append(candidateIPs, mappedIPs...), true +} + +// gatherCandidatesInternal performs the actual candidate gathering for all configured types. +func (a *Agent) gatherCandidatesInternal(ctx context.Context) { + var wg sync.WaitGroup + for _, t := range a.candidateTypes { + switch t { + case CandidateTypeHost: + wg.Add(1) + go func() { + a.gatherCandidatesLocal(ctx, a.networkTypes) + wg.Done() + }() + case CandidateTypeServerReflexive: + a.gatherServerReflexiveCandidates(ctx, &wg) + case CandidateTypeRelay: + wg.Add(1) + go func() { + a.gatherCandidatesRelay(ctx, a.urls) + wg.Done() + }() + case CandidateTypePeerReflexive, CandidateTypeUnspecified: + } + } + + // Block until all STUN and TURN URLs have been gathered (or timed out) + wg.Wait() +} + +func (a *Agent) gatherServerReflexiveCandidates(ctx context.Context, wg *sync.WaitGroup) { + replaceSrflx := a.addressRewriteMapper != nil && a.addressRewriteMapper.shouldReplace(CandidateTypeServerReflexive) + if !replaceSrflx { + wg.Add(1) + go func() { + if a.udpMuxSrflx != nil { + a.gatherCandidatesSrflxUDPMux(ctx, a.urls, a.networkTypes) + } else { + a.gatherCandidatesSrflx(ctx, a.urls, a.networkTypes) + } + wg.Done() + }() + } + if a.addressRewriteMapper != nil && a.addressRewriteMapper.hasCandidateType(CandidateTypeServerReflexive) { + wg.Add(1) + go func() { + a.gatherCandidatesSrflxMapped(ctx, a.networkTypes) + wg.Done() + }() + } +} + +//nolint:gocognit,gocyclo,cyclop,maintidx +func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []NetworkType) { + networks := map[string]struct{}{} + for _, networkType := range networkTypes { + if networkType.IsTCP() { + networks[tcp] = struct{}{} + } else { + networks[udp] = struct{}{} + } + } + + // When UDPMux is enabled, skip other UDP candidates + if a.udpMux != nil { + if err := a.gatherCandidatesLocalUDPMux(ctx); err != nil { + a.log.Warnf("Failed to create host candidate for UDPMux: %s", err) + } + delete(networks, udp) + } + + _, localAddrs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, networkTypes, a.includeLoopback) + if err != nil { + a.log.Warnf("Failed to iterate local interfaces, host candidates will not be gathered %s", err) + + return + } + + for _, info := range localAddrs { + addr := info.addr + ifaceName := info.iface + mappedAddrs := []netip.Addr{addr} + if a.shouldRewriteHostCandidates() { + var ok bool + mappedAddrs, ok = a.applyHostAddressRewrite(addr, mappedAddrs, ifaceName) + if !ok { + continue + } + } + + for mappedIdx, mappedIP := range mappedAddrs { + address := mappedIP.String() + var isLocationTracked bool + if a.mDNSMode == MulticastDNSModeQueryAndGather { + address = a.mDNSName + } else { + // Here, we are not doing multicast gathering, so we will need to skip this address so + // that we don't accidentally reveal location tracking information. Otherwise, the + // case above hides the IP behind an mDNS address. + isLocationTracked = shouldFilterLocationTrackedIP(mappedIP) + } + + for network := range networks { + // TCPMux maintains a single listener per interface. Avoid duplicating passive TCP candidates + // for additional mapped IPs until connection sharing is supported. + if network == tcp && mappedIdx > 0 { + continue + } + + type connAndPort struct { + conn net.PacketConn + port int + } + var ( + conns []connAndPort + tcpType TCPType + ) + + switch network { + case tcp: + if a.tcpMux == nil { + continue + } + + // Only advertise TCP candidates for addresses that the mux listener is actually + // bound to. When the listener is bound to a specific IP, exposing other interface + // addresses would generate unreachable passive candidates and can stall active + // TCP connect attempts. + if addrProvider, ok := a.tcpMux.(interface{ LocalAddr() net.Addr }); ok { + if muxAddr, ok := addrProvider.LocalAddr().(*net.TCPAddr); ok { + if ip := muxAddr.IP; ip != nil && !ip.IsUnspecified() && !ip.Equal(addr.AsSlice()) { + continue + } + } + } + + // Handle ICE TCP passive mode + var muxConns []net.PacketConn + if multi, ok := a.tcpMux.(AllConnsGetter); ok { + a.log.Debugf("GetAllConns by ufrag: %s", a.localUfrag) + // Note: this is missing zone for IPv6 by just grabbing the IP slice + muxConns, err = multi.GetAllConns(a.localUfrag, mappedIP.Is6(), addr.AsSlice()) + if err != nil { + a.log.Warnf("Failed to get all TCP connections by ufrag: %s %s %s", network, addr, a.localUfrag) + + continue + } + } else { + a.log.Debugf("GetConn by ufrag: %s", a.localUfrag) + // Note: this is missing zone for IPv6 by just grabbing the IP slice + conn, err := a.tcpMux.GetConnByUfrag(a.localUfrag, mappedIP.Is6(), addr.AsSlice()) + if err != nil { + a.log.Warnf("Failed to get TCP connections by ufrag: %s %s %s", network, addr, a.localUfrag) + + continue + } + muxConns = []net.PacketConn{conn} + } + + // Extract the port for each PacketConn we got. + for _, conn := range muxConns { + if tcpConn, ok := conn.LocalAddr().(*net.TCPAddr); ok { + conns = append(conns, connAndPort{conn, tcpConn.Port}) + } else { + a.log.Warnf("Failed to get port of connection from TCPMux: %s %s %s", network, addr, a.localUfrag) + } + } + if len(conns) == 0 { + // Didn't succeed with any, try the next network. + continue + } + tcpType = TCPTypePassive + // Is there a way to verify that the listen address is even + // accessible from the current interface. + case udp: + conn, err := listenUDPInPortRange(a.net, a.log, int(a.portMax), int(a.portMin), network, &net.UDPAddr{ + IP: addr.AsSlice(), + Port: 0, + Zone: addr.Zone(), + }) + if err != nil { + a.log.Warnf("Failed to listen %s %s", network, addr) + + continue + } + + if udpConn, ok := conn.LocalAddr().(*net.UDPAddr); ok { + conns = append(conns, connAndPort{conn, udpConn.Port}) + } else { + a.log.Warnf("Failed to get port of UDPAddr from ListenUDPInPortRange: %s %s %s", network, addr, a.localUfrag) + + continue + } + } + + for _, connAndPort := range conns { + hostConfig := CandidateHostConfig{ + Network: network, + Address: address, + Port: connAndPort.port, + Component: ComponentRTP, + TCPType: tcpType, + // we will still process this candidate so that we start up the right + // listeners. + IsLocationTracked: isLocationTracked, + } + + candidateHost, err := NewCandidateHost(&hostConfig) + + if err == nil && a.mDNSMode == MulticastDNSModeQueryAndGather { + err = candidateHost.setIPAddr(addr) + } + + if err != nil { + closeConnAndLog( + connAndPort.conn, + a.log, + "failed to create host candidate: %s %s %d: %v", + network, mappedIP, + connAndPort.port, + err, + ) + + continue + } + + if err := a.addCandidate(ctx, candidateHost, connAndPort.conn); err != nil { + if closeErr := candidateHost.close(); closeErr != nil { + a.log.Warnf("Failed to close candidate: %v", closeErr) + } + a.log.Warnf("Failed to append to localCandidates and run onCandidateHdlr: %v", err) + } + } + } + } + } +} + +// shouldFilterLocationTrackedIP returns if this candidate IP should be filtered out from +// any candidate publishing/notification for location tracking reasons. +func shouldFilterLocationTrackedIP(candidateIP netip.Addr) bool { + // https://tools.ietf.org/html/rfc8445#section-5.1.1.1 + // Similarly, when host candidates corresponding to + // an IPv6 address generated using a mechanism that prevents location + // tracking are gathered, then host candidates corresponding to IPv6 + // link-local addresses [RFC4291] MUST NOT be gathered. + return candidateIP.Is6() && (candidateIP.IsLinkLocalUnicast() || candidateIP.IsLinkLocalMulticast()) +} + +// shouldFilterLocationTracked returns if this candidate IP should be filtered out from +// any candidate publishing/notification for location tracking reasons. +func shouldFilterLocationTracked(candidateIP net.IP) bool { + addr, ok := netip.AddrFromSlice(candidateIP) + if !ok { + return false + } + + return shouldFilterLocationTrackedIP(addr) +} + +func (a *Agent) gatherCandidatesLocalUDPMux(ctx context.Context) error { //nolint:gocognit,cyclop + if a.udpMux == nil { + return errUDPMuxDisabled + } + + localAddresses := a.udpMux.GetListenAddresses() + existingConfigs := make(map[CandidateHostConfig]struct{}) + + for _, addr := range localAddresses { + udpAddr, ok := addr.(*net.UDPAddr) + if !ok { + return errInvalidAddress + } + candidateIPs := []net.IP{udpAddr.IP} + + if _, ok := a.udpMux.(*UDPMuxDefault); ok && !a.includeLoopback && udpAddr.IP.IsLoopback() { + // Unlike MultiUDPMux Default, UDPMuxDefault doesn't have + // a separate param to include loopback, so we respect agent config + continue + } + + if a.shouldRewriteHostCandidates() { + var ok bool + candidateIPs, ok = a.applyHostRewriteForUDPMux(candidateIPs, udpAddr) + if !ok { + continue + } + } + + for _, candidateIP := range candidateIPs { + var address string + var isLocationTracked bool + if a.mDNSMode == MulticastDNSModeQueryAndGather { + address = a.mDNSName + } else { + address = candidateIP.String() + // Here, we are not doing multicast gathering, so we will need to skip this address so + // that we don't accidentally reveal location tracking information. Otherwise, the + // case above hides the IP behind an mDNS address. + isLocationTracked = shouldFilterLocationTracked(candidateIP) + } + + hostConfig := CandidateHostConfig{ + Network: udp, + Address: address, + Port: udpAddr.Port, + Component: ComponentRTP, + IsLocationTracked: isLocationTracked, + } + + // Detect a duplicate candidate before calling addCandidate(). + // otherwise, addCandidate() detects the duplicate candidate + // and close its connection, invalidating all candidates + // that share the same connection. + if _, ok := existingConfigs[hostConfig]; ok { + continue + } + + conn, err := a.udpMux.GetConn(a.localUfrag, udpAddr) + if err != nil { + return err + } + + c, err := NewCandidateHost(&hostConfig) + if err != nil { + closeConnAndLog(conn, a.log, "failed to create host mux candidate: %s %d: %v", candidateIP, udpAddr.Port, err) + + continue + } + + if err := a.addCandidate(ctx, c, conn); err != nil { + if closeErr := c.close(); closeErr != nil { + a.log.Warnf("Failed to close candidate: %v", closeErr) + } + + closeConnAndLog(conn, a.log, "failed to add candidate: %s %d: %v", candidateIP, udpAddr.Port, err) + + continue + } + + existingConfigs[hostConfig] = struct{}{} + } + } + + return nil +} + +func (a *Agent) gatherCandidatesSrflxMapped(ctx context.Context, networkTypes []NetworkType) { //nolint:gocognit,cyclop + var wg sync.WaitGroup + defer wg.Wait() + + _, ifaces, _ := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, networkTypes, a.includeLoopback) + + for _, networkType := range networkTypes { + if networkType.IsTCP() { + continue + } + + network := networkType.String() + wg.Add(1) + go func() { + defer wg.Done() + + conn, err := listenUDPInPortRange( + a.net, + a.log, + int(a.portMax), + int(a.portMin), + network, + &net.UDPAddr{IP: nil, Port: 0}, + ) + if err != nil { + a.log.Warnf("Failed to listen %s: %v", network, err) + + return + } + + lAddr, ok := conn.LocalAddr().(*net.UDPAddr) + if !ok { + closeConnAndLog(conn, a.log, "Address rewrite mapping is enabled but LocalAddr is not a UDPAddr") + + return + } + + addresses, ok := a.resolveSrflxAddresses(lAddr.IP, findIfaceForIP(ifaces, lAddr.IP)) + if !ok { + closeConnAndLog( + conn, a.log, "Address rewrite mapping did not provide usable external IPs for %s", lAddr.IP.String(), + ) + + return + } + + for idx, mappedIP := range addresses { + currentConn := conn + currentAddr := lAddr + if idx > 0 { + newConn, listenErr := listenUDPInPortRange( + a.net, + a.log, + int(a.portMax), + int(a.portMin), + network, + &net.UDPAddr{IP: lAddr.IP, Port: 0}, + ) + if listenErr != nil { + closeConnAndLog(newConn, a.log, "Failed to listen %s for additional srflx mapping: %v", network, listenErr) + + return + } + currentConn = newConn + var ok bool + currentAddr, ok = currentConn.LocalAddr().(*net.UDPAddr) + if !ok { + closeConnAndLog(currentConn, a.log, "Address rewrite mapping is enabled but LocalAddr is not a UDPAddr") + + return + } + } + + if shouldFilterLocationTracked(mappedIP) { + closeConnAndLog(currentConn, a.log, "external IP is somehow filtered for location tracking reasons %s", mappedIP) + + continue + } + + srflxConfig := CandidateServerReflexiveConfig{ + Network: network, + Address: mappedIP.String(), + Port: currentAddr.Port, + Component: ComponentRTP, + RelAddr: currentAddr.IP.String(), + RelPort: currentAddr.Port, + } + c, err := NewCandidateServerReflexive(&srflxConfig) + if err != nil { + closeConnAndLog(currentConn, a.log, "failed to create server reflexive candidate: %s %s %d: %v", + network, + mappedIP.String(), + currentAddr.Port, + err) + + continue + } + + if err := a.addCandidate(ctx, c, currentConn); err != nil { + if closeErr := c.close(); closeErr != nil { + a.log.Warnf("Failed to close candidate: %v", closeErr) + } + a.log.Warnf("Failed to append to localCandidates and run onCandidateHdlr: %v", err) + closeConnAndLog( + currentConn, + a.log, + "closing srflx conn after addCandidate failure: %v", + err, + ) + } + } + }() + } +} + +//nolint:gocognit,cyclop +func (a *Agent) gatherCandidatesSrflxUDPMux(ctx context.Context, urls []*stun.URI, networkTypes []NetworkType) { + var wg sync.WaitGroup + defer wg.Wait() + + for _, networkType := range networkTypes { + if networkType.IsTCP() { + continue + } + + for i := range urls { + if !urlSupportsSrflxGathering(*urls[i]) { + continue + } + + for _, listenAddr := range a.udpMuxSrflx.GetListenAddresses() { + udpAddr, ok := listenAddr.(*net.UDPAddr) + if !ok { + a.log.Warn("Failed to cast udpMuxSrflx listen address to UDPAddr") + + continue + } + wg.Add(1) + go func(url stun.URI, network string, localAddr *net.UDPAddr) { + defer wg.Done() + + hostPort := net.JoinHostPort(url.Host, strconv.Itoa(url.Port)) + serverAddr, err := a.net.ResolveUDPAddr(network, hostPort) + if err != nil { + a.log.Debugf("Failed to resolve STUN host: %s %s: %v", network, hostPort, err) + + return + } + + if shouldFilterLocationTracked(serverAddr.IP) { + a.log.Warnf("STUN host %s is somehow filtered for location tracking reasons", hostPort) + + return + } + + xorAddr, err := a.udpMuxSrflx.GetXORMappedAddr(serverAddr, a.stunGatherTimeout) + if err != nil { + a.log.Warnf("Failed get server reflexive address %s %s: %v", network, url, err) + + return + } + + conn, err := a.udpMuxSrflx.GetConnForURL(a.localUfrag, url.String(), localAddr) + if err != nil { + a.log.Warnf("Failed to find connection in UDPMuxSrflx %s %s: %v", network, url, err) + + return + } + + ip := xorAddr.IP + port := xorAddr.Port + + srflxConfig := CandidateServerReflexiveConfig{ + Network: network, + Address: ip.String(), + Port: port, + Component: ComponentRTP, + RelAddr: localAddr.IP.String(), + RelPort: localAddr.Port, + } + c, err := NewCandidateServerReflexive(&srflxConfig) + if err != nil { + closeConnAndLog(conn, a.log, "failed to create server reflexive candidate: %s %s %d: %v", network, ip, port, err) + + return + } + + if err := a.addCandidate(ctx, c, conn); err != nil { + if closeErr := c.close(); closeErr != nil { + a.log.Warnf("Failed to close candidate: %v", closeErr) + } + a.log.Warnf("Failed to append to localCandidates and run onCandidateHdlr: %v", err) + } + }(*urls[i], networkType.String(), udpAddr) + } + } + } +} + +//nolint:cyclop,gocognit +func (a *Agent) gatherCandidatesSrflx(ctx context.Context, urls []*stun.URI, networkTypes []NetworkType) { + var wg sync.WaitGroup + defer wg.Wait() + + useFilteredLocalAddrs := a.interfaceFilter != nil || a.ipFilter != nil + localAddrs := []ifaceAddr{} + if useFilteredLocalAddrs { + _, addrs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, networkTypes, a.includeLoopback) + if err != nil { + a.log.Warnf("Failed to iterate local interfaces, srflx candidates will not be gathered %s", err) + + return + } + localAddrs = addrs + } + + gatherForURL := func(url stun.URI, network string, listenAddr *net.UDPAddr) { + defer wg.Done() + + hostPort := net.JoinHostPort(url.Host, strconv.Itoa(url.Port)) + serverAddr, err := a.net.ResolveUDPAddr(network, hostPort) + if err != nil { + a.log.Debugf("Failed to resolve STUN host: %s %s: %v", network, hostPort, err) + + return + } + + if shouldFilterLocationTracked(serverAddr.IP) { + a.log.Warnf("STUN host %s is somehow filtered for location tracking reasons", hostPort) + + return + } + + conn, err := listenUDPInPortRange( + a.net, + a.log, + int(a.portMax), + int(a.portMin), + network, + listenAddr, + ) + if err != nil { + closeConnAndLog(conn, a.log, "failed to listen for %s: %v", serverAddr.String(), err) + + return + } + // If the agent closes midway through the connection + // we end it early to prevent close delay. + cancelCtx, cancelFunc := context.WithCancel(ctx) + defer cancelFunc() + go func() { + select { + case <-cancelCtx.Done(): + return + case <-a.loop.Done(): + _ = conn.Close() + } + }() + + xorAddr, err := stunx.GetXORMappedAddr(conn, serverAddr, a.stunGatherTimeout) + if err != nil { + closeConnAndLog(conn, a.log, "failed to get server reflexive address %s %s: %v", network, url, err) + + return + } + + ip := xorAddr.IP + port := xorAddr.Port + + lAddr := conn.LocalAddr().(*net.UDPAddr) //nolint:forcetypeassert + srflxConfig := CandidateServerReflexiveConfig{ + Network: network, + Address: ip.String(), + Port: port, + Component: ComponentRTP, + RelAddr: lAddr.IP.String(), + RelPort: lAddr.Port, + } + c, err := NewCandidateServerReflexive(&srflxConfig) + if err != nil { + closeConnAndLog(conn, a.log, "failed to create server reflexive candidate: %s %s %d: %v", network, ip, port, err) + + return + } + + if err := a.addCandidate(ctx, c, conn); err != nil { + if closeErr := c.close(); closeErr != nil { + a.log.Warnf("Failed to close candidate: %v", closeErr) + } + a.log.Warnf("Failed to append to localCandidates and run onCandidateHdlr: %v", err) + } + } + + for _, networkType := range networkTypes { + if networkType.IsTCP() { + continue + } + + for i := range urls { + if !urlSupportsSrflxGathering(*urls[i]) { + continue + } + + if !useFilteredLocalAddrs { + wg.Add(1) + go gatherForURL(*urls[i], networkType.String(), &net.UDPAddr{IP: nil, Port: 0}) + + continue + } + + for j := range localAddrs { + if networkType.IsIPv4() && localAddrs[j].addr.Is6() { + continue + } + if networkType.IsIPv6() && !localAddrs[j].addr.Is6() { + continue + } + + wg.Add(1) + go gatherForURL( + *urls[i], + networkType.String(), + &net.UDPAddr{IP: localAddrs[j].addr.AsSlice(), Zone: localAddrs[j].addr.Zone(), Port: 0}, + ) + } + } + } +} + +//nolint:maintidx,gocognit,gocyclo,cyclop +func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*stun.URI) { + var wg sync.WaitGroup + defer wg.Wait() + _, ifaces, _ := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, a.networkTypes, a.includeLoopback) + + useFilteredLocalAddrs := a.interfaceFilter != nil || a.ipFilter != nil + localAddrs := []ifaceAddr{} + if useFilteredLocalAddrs { + localAddrs = append(localAddrs, ifaces...) + } + + for _, url := range urls { + switch { + case url.Scheme != stun.SchemeTypeTURN && url.Scheme != stun.SchemeTypeTURNS: + continue + case url.Username == "": + a.log.Errorf("Failed to gather relay candidates: %v", ErrUsernameEmpty) + + return + case url.Password == "": + a.log.Errorf("Failed to gather relay candidates: %v", ErrPasswordEmpty) + + return + } + + networkTypes := relayNetworkTypesForURL(*url, a.networkTypes) + if len(networkTypes) == 0 { + continue + } + + for _, networkType := range networkTypes { + // IPv6 TURN support is not finished yet, so skip for now. + if networkType.IsIPv6() { + continue + } + + network := networkType.String() + urlProto := effectiveURLProtoType(*url) + + bindAddrs := []string{} + if !useFilteredLocalAddrs { // nolint:nestif + if networkType.IsIPv6() { + bindAddrs = append(bindAddrs, "[::]:0") + } else { + bindAddrs = append(bindAddrs, "0.0.0.0:0") + } + } else { + for i := range localAddrs { + if networkType.IsIPv4() && localAddrs[i].addr.Is6() { + continue + } + if networkType.IsIPv6() && !localAddrs[i].addr.Is6() { + continue + } + + bindAddrs = append(bindAddrs, net.JoinHostPort(localAddrs[i].addr.String(), "0")) + } + } + + for _, localBindAddr := range bindAddrs { + wg.Add(1) + go func(url stun.URI, network string, urlProto stun.ProtoType, localBindAddr string) { + defer wg.Done() + + turnServerAddr := net.JoinHostPort(url.Host, strconv.Itoa(url.Port)) + var ( + locConn net.PacketConn + err error + relAddr string + relPort int + relayProtocol string + ) + + switch { + case urlProto == stun.ProtoTypeUDP && url.Scheme == stun.SchemeTypeTURN: + if locConn, err = a.net.ListenPacket(network, localBindAddr); err != nil { + a.log.Warnf("Failed to listen %s: %v", network, err) + + return + } + + relAddr = locConn.LocalAddr().(*net.UDPAddr).IP.String() //nolint:forcetypeassert + relPort = locConn.LocalAddr().(*net.UDPAddr).Port //nolint:forcetypeassert + relayProtocol = udp + case a.proxyDialer != nil && urlProto == stun.ProtoTypeTCP && + (url.Scheme == stun.SchemeTypeTURN || url.Scheme == stun.SchemeTypeTURNS): + conn, connectErr := a.proxyDialer.Dial(network, turnServerAddr) + if connectErr != nil { + a.log.Warnf("Failed to dial TCP address %s via proxy dialer: %v", turnServerAddr, connectErr) + + return + } + + relAddr = conn.LocalAddr().(*net.TCPAddr).IP.String() //nolint:forcetypeassert + relPort = conn.LocalAddr().(*net.TCPAddr).Port //nolint:forcetypeassert + if url.Scheme == stun.SchemeTypeTURN { + relayProtocol = tcp + } else if url.Scheme == stun.SchemeTypeTURNS { + relayProtocol = "tls" + } + locConn = turn.NewSTUNConn(conn) + + case urlProto == stun.ProtoTypeTCP && url.Scheme == stun.SchemeTypeTURN: + tcpAddr, connectErr := a.net.ResolveTCPAddr(network, turnServerAddr) + if connectErr != nil { + a.log.Warnf("Failed to resolve TCP address %s: %v", turnServerAddr, connectErr) + + return + } + + conn, connectErr := a.net.DialTCP(network, nil, tcpAddr) + if connectErr != nil { + a.log.Warnf("Failed to dial TCP address %s: %v", turnServerAddr, connectErr) + + return + } + + relAddr = conn.LocalAddr().(*net.TCPAddr).IP.String() //nolint:forcetypeassert + relPort = conn.LocalAddr().(*net.TCPAddr).Port //nolint:forcetypeassert + relayProtocol = tcp + locConn = turn.NewSTUNConn(conn) + case urlProto == stun.ProtoTypeUDP && url.Scheme == stun.SchemeTypeTURNS: + udpAddr, connectErr := a.net.ResolveUDPAddr(network, turnServerAddr) + if connectErr != nil { + a.log.Warnf("Failed to resolve UDP address %s: %v", turnServerAddr, connectErr) + + return + } + + udpConn, dialErr := a.net.DialUDP(network, nil, udpAddr) + if dialErr != nil { + a.log.Warnf("Failed to dial DTLS address %s: %v", turnServerAddr, dialErr) + + return + } + + conn, connectErr := dtls.ClientWithOptions(&fakenet.PacketConn{Conn: udpConn}, udpConn.RemoteAddr(), + dtls.WithServerName(url.Host), + dtls.WithInsecureSkipVerify(a.insecureSkipVerify), //nolint:gosec + dtls.WithLoggerFactory(a.loggerFactory), + ) + if connectErr != nil { + a.log.Warnf("Failed to create DTLS client: %v", turnServerAddr, connectErr) + if closeErr := udpConn.Close(); closeErr != nil { + a.log.Errorf("Failed to close relay connection: %v", closeErr) + } + + return + } + + if connectErr = conn.HandshakeContext(ctx); connectErr != nil { + a.log.Warnf("Failed to create DTLS client: %v", turnServerAddr, connectErr) + if closeErr := conn.Close(); closeErr != nil { + a.log.Errorf("Failed to close relay connection: %v", closeErr) + } + + return + } + + relAddr = conn.LocalAddr().(*net.UDPAddr).IP.String() //nolint:forcetypeassert + relPort = conn.LocalAddr().(*net.UDPAddr).Port //nolint:forcetypeassert + relayProtocol = relayProtocolDTLS + locConn = &fakenet.PacketConn{Conn: conn} + case urlProto == stun.ProtoTypeTCP && url.Scheme == stun.SchemeTypeTURNS: + tcpAddr, resolvErr := a.net.ResolveTCPAddr(network, turnServerAddr) + if resolvErr != nil { + a.log.Warnf("Failed to resolve relay address %s: %v", turnServerAddr, resolvErr) + + return + } + + tcpConn, dialErr := a.net.DialTCP(network, nil, tcpAddr) + if dialErr != nil { + a.log.Warnf("Failed to connect to relay: %v", dialErr) + + return + } + + conn := tls.Client(tcpConn, &tls.Config{ + ServerName: url.Host, + InsecureSkipVerify: a.insecureSkipVerify, //nolint:gosec + }) + + if hsErr := conn.HandshakeContext(ctx); hsErr != nil { + if closeErr := tcpConn.Close(); closeErr != nil { + a.log.Errorf("Failed to close relay connection: %v", closeErr) + } + a.log.Warnf("Failed to connect to relay: %v", hsErr) + + return + } + + relAddr = conn.LocalAddr().(*net.TCPAddr).IP.String() //nolint:forcetypeassert + relPort = conn.LocalAddr().(*net.TCPAddr).Port //nolint:forcetypeassert + relayProtocol = relayProtocolTLS + locConn = turn.NewSTUNConn(conn) + default: + a.log.Warnf("Unable to handle URL in gatherCandidatesRelay %v", url) + + return + } + + factory := a.turnClientFactory + if factory == nil { + factory = defaultTurnClient + } + + client, err := factory(&turn.ClientConfig{ + TURNServerAddr: turnServerAddr, + Conn: locConn, + Username: url.Username, + Password: url.Password, + LoggerFactory: a.loggerFactory, + Net: a.net, + }) + if err != nil { + closeConnAndLog(locConn, a.log, "failed to create new TURN client %s %s", turnServerAddr, err) + + return + } + + if err = client.Listen(); err != nil { + client.Close() + closeConnAndLog(locConn, a.log, "failed to listen on TURN client %s %s", turnServerAddr, err) + + return + } + + relayConn, err := client.Allocate() + if err != nil { + client.Close() + closeConnAndLog(locConn, a.log, "failed to allocate on TURN client %s %s", turnServerAddr, err) + + return + } + + closeRelayConn := func() { + if relayConErr := relayConn.Close(); relayConErr != nil { + a.log.Warnf("Failed to close relay %v", relayConErr) + } + } + + rAddr := relayConn.LocalAddr().(*net.UDPAddr) //nolint:forcetypeassert + if shouldFilterLocationTracked(rAddr.IP) { + closeRelayConn() + client.Close() + closeConnAndLog(locConn, a.log, + "TURN address %s is somehow filtered for location tracking reasons", rAddr.IP) + + return + } + + a.addRelayCandidates(ctx, relayEndpoint{ + network: network, + address: rAddr.IP, + port: rAddr.Port, + relAddr: relAddr, + relPort: relPort, + iface: findIfaceForIP(ifaces, net.ParseIP(relAddr)), + protocol: relayProtocol, + conn: relayConn, + onClose: func() error { + client.Close() + + return locConn.Close() + }, + closeConn: closeRelayConn, + }) + }(*url, network, urlProto, localBindAddr) + } + } + } +} + +type relayEndpoint struct { + network string + address net.IP + port int + relAddr string + relPort int + protocol string + iface string + conn net.PacketConn + onClose func() error + closeConn func() +} + +func (a *Agent) resolveRelayAddresses(ep relayEndpoint) ([]net.IP, bool) { + addresses := []net.IP{ep.address} + if !a.shouldRewriteCandidateType(CandidateTypeRelay) { + return addresses, true + } + + mappedIPs, matched, mode, err := a.addressRewriteMapper.findExternalIPs( + CandidateTypeRelay, + ep.relAddr, + ep.iface, + ) + if err != nil { + return nil, false + } + if !matched { + return addresses, true + } + if len(mappedIPs) == 0 { + if mode == AddressRewriteReplace { + a.log.Warnf("Address rewrite mapping returned no external relay addresses for %s", ep.relAddr) + + return nil, false + } + + return addresses, true + } + if mode == AddressRewriteReplace { + return mappedIPs, true + } + + return append(addresses, mappedIPs...), true +} + +func (a *Agent) resolveSrflxAddresses(localIP net.IP, iface string) ([]net.IP, bool) { + addresses := []net.IP{localIP} + if !a.shouldRewriteCandidateType(CandidateTypeServerReflexive) { + return addresses, true + } + + mappedIPs, matched, mode, err := a.addressRewriteMapper.findExternalIPs( + CandidateTypeServerReflexive, + localIP.String(), + iface, + ) + if err != nil { + a.log.Warnf("Address rewrite mapping is enabled but no external IP is found for %s: %v", localIP.String(), err) + + return nil, false + } + + if !matched { + return addresses, true + } + + if len(mappedIPs) == 0 { + if mode == AddressRewriteReplace { + return nil, false + } + + return addresses, true + } + + if mode == AddressRewriteReplace { + return mappedIPs, true + } + + return mappedIPs, true +} + +func findIfaceForIP(ifaces []ifaceAddr, ip net.IP) string { + if ip == nil { + return "" + } + for _, info := range ifaces { + if info.addr.String() == ip.String() { + return info.iface + } + } + + return "" +} + +func (a *Agent) createRelayCandidate(ctx context.Context, ep relayEndpoint, ip net.IP, onClose func() error) error { + relayConfig := CandidateRelayConfig{ + Network: ep.network, + Component: ComponentRTP, + Address: ip.String(), + Port: ep.port, + RelAddr: ep.relAddr, + RelPort: ep.relPort, + RelayProtocol: ep.protocol, + OnClose: onClose, + } + candidate, err := NewCandidateRelay(&relayConfig) + if err != nil { + a.log.Warnf("failed to create relay candidate: %s %d: %v", ip, ep.port, err) + + return err + } + + if err := a.addCandidate(ctx, candidate, ep.conn); err != nil { + if closeErr := candidate.close(); closeErr != nil { + a.log.Warnf("Failed to close candidate: %v", closeErr) + } + a.log.Warnf("Failed to append to localCandidates and run onCandidateHdlr: %v", err) + + return err + } + + return nil +} + +func (a *Agent) addRelayCandidates(ctx context.Context, ep relayEndpoint) { + if ep.conn == nil || ep.address == nil { + return + } + + addresses, ok := a.resolveRelayAddresses(ep) + if !ok { + return + } + + for idx, ip := range addresses { + onClose := ep.onClose + if idx > 0 { + onClose = nil + } + + if err := a.createRelayCandidate(ctx, ep, ip, onClose); err != nil { + if idx == 0 { + if ep.closeConn != nil { + ep.closeConn() + } + + return + } + + a.log.Warnf("failed to create additional relay candidate for %s: %v", ip, err) + + continue + } + } +} + +// startNetworkMonitoring starts a goroutine that periodically checks for network changes +// and re-gathers candidates when changes are detected. This is only used with GatherContinually policy. +func (a *Agent) startNetworkMonitoring(ctx context.Context) { + ticker := time.NewTicker(a.networkMonitorInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + if a.detectNetworkChanges() { + a.gatherCandidatesInternal(ctx) + } + } + } +} + +// detectNetworkChanges checks if the network interfaces have changed since the last check. +func (a *Agent) detectNetworkChanges() bool { + // Try to refresh interfaces if using stdnet + if stdNet, ok := a.net.(*stdnet.Net); ok { + if err := stdNet.UpdateInterfaces(); err != nil { + a.log.Warnf("Failed to update interfaces: %v", err) + } + } + + _, currentAddrs, err := localInterfaces( + a.net, + a.interfaceFilter, + a.ipFilter, + a.networkTypes, + a.includeLoopback, + ) + if err != nil { + a.log.Warnf("Failed to get local interfaces during network monitoring: %v", err) + + return false + } + + currentInterfaces := make(map[string]netip.Addr) + for _, info := range currentAddrs { + key := info.addr.String() + currentInterfaces[key] = info.addr + } + + hasAdditions := false + + for key, addr := range currentInterfaces { + if _, exists := a.lastKnownInterfaces[key]; !exists { + a.log.Infof("New IP address detected: %s", addr) + hasAdditions = true + } + } + + a.lastKnownInterfaces = currentInterfaces + + return hasAdditions +} diff --git a/vendor/github.com/pion/ice/v4/ice.go b/vendor/github.com/pion/ice/v4/ice.go new file mode 100644 index 0000000..9278fcb --- /dev/null +++ b/vendor/github.com/pion/ice/v4/ice.go @@ -0,0 +1,110 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package ice + +// ConnectionState is an enum showing the state of a ICE Connection. +type ConnectionState int + +// List of supported States. +const ( + // ConnectionStateUnknown represents an unknown state. + ConnectionStateUnknown ConnectionState = iota + + // ConnectionStateNew ICE agent is gathering addresses. + ConnectionStateNew + + // ConnectionStateChecking ICE agent has been given local and remote candidates, and is attempting to find a match. + ConnectionStateChecking + + // ConnectionStateConnected ICE agent has a pairing, but is still checking other pairs. + ConnectionStateConnected + + // ConnectionStateCompleted ICE agent has finished. + ConnectionStateCompleted + + // ConnectionStateFailed ICE agent never could successfully connect. + ConnectionStateFailed + + // ConnectionStateDisconnected ICE agent connected successfully, but has entered a failed state. + ConnectionStateDisconnected + + // ConnectionStateClosed ICE agent has finished and is no longer handling requests. + ConnectionStateClosed +) + +func (c ConnectionState) String() string { + switch c { + case ConnectionStateNew: + return "New" + case ConnectionStateChecking: + return "Checking" + case ConnectionStateConnected: + return "Connected" + case ConnectionStateCompleted: + return "Completed" + case ConnectionStateFailed: + return "Failed" + case ConnectionStateDisconnected: + return "Disconnected" + case ConnectionStateClosed: + return "Closed" + default: + return "Invalid" + } +} + +// GatheringState describes the state of the candidate gathering process. +type GatheringState int + +const ( + // GatheringStateUnknown represents an unknown state. + GatheringStateUnknown GatheringState = iota + + // GatheringStateNew indicates candidate gathering is not yet started. + GatheringStateNew + + // GatheringStateGathering indicates candidate gathering is ongoing. + GatheringStateGathering + + // GatheringStateComplete indicates candidate gathering has been completed. + GatheringStateComplete +) + +func (t GatheringState) String() string { + switch t { + case GatheringStateNew: + return "new" + case GatheringStateGathering: + return "gathering" + case GatheringStateComplete: + return "complete" + default: + return ErrUnknownType.Error() + } +} + +// ContinualGatheringPolicy defines the behavior for gathering ICE candidates. +type ContinualGatheringPolicy int + +const ( + GatherOnce ContinualGatheringPolicy = iota + GatherContinually +) + +func (c ContinualGatheringPolicy) String() string { + switch c { + case GatherOnce: + return "gather_once" + case GatherContinually: + return "gather_continually" + default: + return unknownStr + } +} + +const ( + unknownStr = "unknown" + relayProtocolDTLS = "dtls" + relayProtocolTLS = "tls" +) diff --git a/vendor/github.com/pion/ice/v4/icecontrol.go b/vendor/github.com/pion/ice/v4/icecontrol.go new file mode 100644 index 0000000..d159942 --- /dev/null +++ b/vendor/github.com/pion/ice/v4/icecontrol.go @@ -0,0 +1,96 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package ice + +import ( + "encoding/binary" + + "github.com/pion/stun/v3" +) + +// tiebreaker is common helper for ICE-{CONTROLLED,CONTROLLING} +// and represents the so-called tiebreaker number. +type tiebreaker uint64 + +const tiebreakerSize = 8 // 64 bit + +// AddToAs adds tiebreaker value to m as t attribute. +func (a tiebreaker) AddToAs(m *stun.Message, t stun.AttrType) error { + v := make([]byte, tiebreakerSize) + binary.BigEndian.PutUint64(v, uint64(a)) + m.Add(t, v) + + return nil +} + +// GetFromAs decodes tiebreaker value in message getting it as for t type. +func (a *tiebreaker) GetFromAs(m *stun.Message, t stun.AttrType) error { + v, err := m.Get(t) + if err != nil { + return err + } + if err = stun.CheckSize(t, len(v), tiebreakerSize); err != nil { + return err + } + *a = tiebreaker(binary.BigEndian.Uint64(v)) + + return nil +} + +// AttrControlled represents ICE-CONTROLLED attribute. +type AttrControlled uint64 + +// AddTo adds ICE-CONTROLLED to message. +func (c AttrControlled) AddTo(m *stun.Message) error { + return tiebreaker(c).AddToAs(m, stun.AttrICEControlled) +} + +// GetFrom decodes ICE-CONTROLLED from message. +func (c *AttrControlled) GetFrom(m *stun.Message) error { + return (*tiebreaker)(c).GetFromAs(m, stun.AttrICEControlled) +} + +// AttrControlling represents ICE-CONTROLLING attribute. +type AttrControlling uint64 + +// AddTo adds ICE-CONTROLLING to message. +func (c AttrControlling) AddTo(m *stun.Message) error { + return tiebreaker(c).AddToAs(m, stun.AttrICEControlling) +} + +// GetFrom decodes ICE-CONTROLLING from message. +func (c *AttrControlling) GetFrom(m *stun.Message) error { + return (*tiebreaker)(c).GetFromAs(m, stun.AttrICEControlling) +} + +// AttrControl is helper that wraps ICE-{CONTROLLED,CONTROLLING}. +type AttrControl struct { + Role Role + Tiebreaker uint64 +} + +// AddTo adds ICE-CONTROLLED or ICE-CONTROLLING attribute depending on Role. +func (c AttrControl) AddTo(m *stun.Message) error { + if c.Role == Controlling { + return tiebreaker(c.Tiebreaker).AddToAs(m, stun.AttrICEControlling) + } + + return tiebreaker(c.Tiebreaker).AddToAs(m, stun.AttrICEControlled) +} + +// GetFrom decodes Role and Tiebreaker value from message. +func (c *AttrControl) GetFrom(m *stun.Message) error { + if m.Contains(stun.AttrICEControlling) { + c.Role = Controlling + + return (*tiebreaker)(&c.Tiebreaker).GetFromAs(m, stun.AttrICEControlling) + } + if m.Contains(stun.AttrICEControlled) { + c.Role = Controlled + + return (*tiebreaker)(&c.Tiebreaker).GetFromAs(m, stun.AttrICEControlled) + } + + return stun.ErrAttributeNotFound +} diff --git a/vendor/github.com/pion/ice/v4/internal/atomic/atomic.go b/vendor/github.com/pion/ice/v4/internal/atomic/atomic.go new file mode 100644 index 0000000..d155d18 --- /dev/null +++ b/vendor/github.com/pion/ice/v4/internal/atomic/atomic.go @@ -0,0 +1,24 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +// Package atomic contains custom atomic types +package atomic + +import "sync/atomic" + +// Error is an atomic error. +type Error struct { + v atomic.Value +} + +// Store updates the value of the atomic variable. +func (a *Error) Store(err error) { + a.v.Store(struct{ error }{err}) +} + +// Load retrieves the current value of the atomic variable. +func (a *Error) Load() error { + err, _ := a.v.Load().(struct{ error }) + + return err.error +} diff --git a/vendor/github.com/pion/ice/v4/internal/fakenet/mock_conn.go b/vendor/github.com/pion/ice/v4/internal/fakenet/mock_conn.go new file mode 100644 index 0000000..005143b --- /dev/null +++ b/vendor/github.com/pion/ice/v4/internal/fakenet/mock_conn.go @@ -0,0 +1,22 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +//go:build !js + +package fakenet + +import ( + "net" + "time" +) + +// MockPacketConn for tests. +type MockPacketConn struct{} + +func (m *MockPacketConn) ReadFrom([]byte) (n int, addr net.Addr, err error) { return 0, nil, nil } //nolint:revive +func (m *MockPacketConn) WriteTo([]byte, net.Addr) (n int, err error) { return 0, nil } //nolint:revive +func (m *MockPacketConn) Close() error { return nil } //nolint:revive +func (m *MockPacketConn) LocalAddr() net.Addr { return nil } //nolint:revive +func (m *MockPacketConn) SetDeadline(time.Time) error { return nil } //nolint:revive +func (m *MockPacketConn) SetReadDeadline(time.Time) error { return nil } //nolint:revive +func (m *MockPacketConn) SetWriteDeadline(time.Time) error { return nil } //nolint:revive diff --git a/vendor/github.com/pion/ice/v4/internal/fakenet/packet_conn.go b/vendor/github.com/pion/ice/v4/internal/fakenet/packet_conn.go new file mode 100644 index 0000000..1ab2a22 --- /dev/null +++ b/vendor/github.com/pion/ice/v4/internal/fakenet/packet_conn.go @@ -0,0 +1,30 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +// Package fakenet contains fake network abstractions +package fakenet + +import ( + "net" +) + +// Compile-time assertion. +var _ net.PacketConn = (*PacketConn)(nil) + +// PacketConn wraps a net.Conn and emulates net.PacketConn. +type PacketConn struct { + net.Conn +} + +// ReadFrom reads a packet from the connection. +func (f *PacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + n, err = f.Conn.Read(p) + addr = f.Conn.RemoteAddr() + + return +} + +// WriteTo writes a packet with payload p to addr. +func (f *PacketConn) WriteTo(p []byte, _ net.Addr) (int, error) { + return f.Conn.Write(p) +} diff --git a/vendor/github.com/pion/ice/v4/internal/netutil/errno_unix.go b/vendor/github.com/pion/ice/v4/internal/netutil/errno_unix.go new file mode 100644 index 0000000..901655c --- /dev/null +++ b/vendor/github.com/pion/ice/v4/internal/netutil/errno_unix.go @@ -0,0 +1,23 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +//go:build !windows + +// Package netutil provides network-related helpers. +package netutil + +import ( + "errors" + "syscall" +) + +// IsAddrUnavailable reports whether err indicates that the address +// is unavailable (as opposed to a specific port being busy). +func IsAddrUnavailable(err error) bool { + var errno syscall.Errno + if errors.As(err, &errno) { + return errno == syscall.EADDRNOTAVAIL + } + + return false +} diff --git a/vendor/github.com/pion/ice/v4/internal/netutil/errno_windows.go b/vendor/github.com/pion/ice/v4/internal/netutil/errno_windows.go new file mode 100644 index 0000000..3a06e2d --- /dev/null +++ b/vendor/github.com/pion/ice/v4/internal/netutil/errno_windows.go @@ -0,0 +1,26 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +//go:build windows + +package netutil + +import ( + "errors" + "syscall" +) + +// Go's syscall.EADDRNOTAVAIL is an invented POSIX-compat constant that does not +// match the raw Winsock errno returned by the kernel, so we check both. +const wsaeaddrnotavail syscall.Errno = 10049 + +// IsAddrUnavailable reports whether err indicates that the address +// is unavailable (as opposed to a specific port being busy). +func IsAddrUnavailable(err error) bool { + var errno syscall.Errno + if errors.As(err, &errno) { + return errno == syscall.EADDRNOTAVAIL || errno == wsaeaddrnotavail + } + + return false +} diff --git a/vendor/github.com/pion/ice/v4/internal/stun/stun.go b/vendor/github.com/pion/ice/v4/internal/stun/stun.go new file mode 100644 index 0000000..0402863 --- /dev/null +++ b/vendor/github.com/pion/ice/v4/internal/stun/stun.go @@ -0,0 +1,72 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +// Package stun contains ICE specific STUN code +package stun + +import ( + "errors" + "fmt" + "net" + "time" + + "github.com/pion/stun/v3" +) + +var ( + errGetXorMappedAddrResponse = errors.New("failed to get XOR-MAPPED-ADDRESS response") + errMismatchUsername = errors.New("username mismatch") +) + +// GetXORMappedAddr initiates a STUN requests to serverAddr using conn, reads the response and returns +// the XORMappedAddress returned by the STUN server. +func GetXORMappedAddr(conn net.PacketConn, serverAddr net.Addr, timeout time.Duration) (*stun.XORMappedAddress, error) { + if timeout > 0 { + if err := conn.SetReadDeadline(time.Now().Add(timeout)); err != nil { + return nil, err + } + + // Reset timeout after completion + defer conn.SetReadDeadline(time.Time{}) //nolint:errcheck + } + + req, err := stun.Build(stun.BindingRequest, stun.TransactionID) + if err != nil { + return nil, err + } + + if _, err = conn.WriteTo(req.Raw, serverAddr); err != nil { + return nil, err + } + + const maxMessageSize = 1280 + buf := make([]byte, maxMessageSize) + n, _, err := conn.ReadFrom(buf) + if err != nil { + return nil, err + } + + res := &stun.Message{Raw: buf[:n]} + if err = res.Decode(); err != nil { + return nil, err + } + + var addr stun.XORMappedAddress + if err = addr.GetFrom(res); err != nil { + return nil, fmt.Errorf("%w: %v", errGetXorMappedAddrResponse, err) //nolint:errorlint + } + + return &addr, nil +} + +// AssertUsername checks that the given STUN message m has a USERNAME attribute with a given value. +func AssertUsername(m *stun.Message, expectedUsername string) error { + var username stun.Username + if err := username.GetFrom(m); err != nil { + return err + } else if string(username) != expectedUsername { + return fmt.Errorf("%w expected(%x) actual(%x)", errMismatchUsername, expectedUsername, string(username)) + } + + return nil +} diff --git a/vendor/github.com/pion/ice/v4/internal/taskloop/taskloop.go b/vendor/github.com/pion/ice/v4/internal/taskloop/taskloop.go new file mode 100644 index 0000000..c2a3b33 --- /dev/null +++ b/vendor/github.com/pion/ice/v4/internal/taskloop/taskloop.go @@ -0,0 +1,123 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +// Package taskloop implements a task loop to run +// tasks sequentially in a separate Goroutine. +package taskloop + +import ( + "context" + "errors" + "time" + + atomicx "github.com/pion/ice/v4/internal/atomic" +) + +// ErrClosed indicates that the loop has been stopped. +var ErrClosed = errors.New("the agent is closed") + +type task struct { + fn func(context.Context) + done chan struct{} +} + +// Loop runs submitted task serially in a dedicated Goroutine. +type Loop struct { + tasks chan task + + // State for closing + done chan struct{} + taskLoopDone chan struct{} + err atomicx.Error +} + +// New creates and starts a new task loop. +func New(onClose func()) *Loop { + l := &Loop{ + tasks: make(chan task), + done: make(chan struct{}), + taskLoopDone: make(chan struct{}), + } + + go l.runLoop(onClose) + + return l +} + +// runLoop handles registered tasks and agent close. +func (l *Loop) runLoop(onClose func()) { + defer func() { + onClose() + close(l.taskLoopDone) + }() + + for { + select { + case <-l.done: + return + case t := <-l.tasks: + t.fn(l) + close(t.done) + } + } +} + +// Close stops the loop after finishing the execution of the current task. +// Other pending tasks will not be executed. +func (l *Loop) Close() { + if err := l.Err(); err != nil { + return + } + + l.err.Store(ErrClosed) + + close(l.done) + <-l.taskLoopDone +} + +// Run serially executes the submitted callback. +// Blocking tasks must be cancelable by context. +func (l *Loop) Run(ctx context.Context, t func(context.Context)) error { + if err := l.Err(); err != nil { + return err + } + done := make(chan struct{}) + select { + case <-ctx.Done(): + return ctx.Err() + case <-l.done: + return ErrClosed + case l.tasks <- task{t, done}: + <-done + + return nil + } +} + +// The following methods implement context.Context for TaskLoop + +// Done returns a channel that's closed when the task loop has been stopped. +func (l *Loop) Done() <-chan struct{} { + return l.done +} + +// Err returns nil if the task loop is still running. +// Otherwise it return errClosed if the loop has been closed/stopped. +func (l *Loop) Err() error { + select { + case <-l.done: + return ErrClosed + default: + return nil + } +} + +// Deadline returns the no valid time as task loops have no deadline. +func (l *Loop) Deadline() (deadline time.Time, ok bool) { + return time.Time{}, false +} + +// Value is not supported for task loops. +func (l *Loop) Value(any) any { + return nil +} diff --git a/vendor/github.com/pion/ice/v4/mdns.go b/vendor/github.com/pion/ice/v4/mdns.go new file mode 100644 index 0000000..72010c7 --- /dev/null +++ b/vendor/github.com/pion/ice/v4/mdns.go @@ -0,0 +1,147 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package ice + +import ( + "net" + + "github.com/google/uuid" + "github.com/pion/logging" + "github.com/pion/mdns/v2" + "github.com/pion/transport/v4" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" +) + +// MulticastDNSMode represents the different Multicast modes ICE can run in. +type MulticastDNSMode byte + +// MulticastDNSMode enum. +const ( + // MulticastDNSModeDisabled means remote mDNS candidates will be discarded, and local host candidates will use IPs. + MulticastDNSModeDisabled MulticastDNSMode = iota + 1 + + // MulticastDNSModeQueryOnly means remote mDNS candidates will be accepted, and local host candidates will use IPs. + MulticastDNSModeQueryOnly + + // MulticastDNSModeQueryAndGather means remote mDNS candidates will be accepted, + // and local host candidates will use mDNS. + MulticastDNSModeQueryAndGather +) + +func generateMulticastDNSName() (string, error) { + // https://tools.ietf.org/id/draft-ietf-rtcweb-mdns-ice-candidates-02.html#gathering + // The unique name MUST consist of a version 4 UUID as defined in [RFC4122], followed by “.local”. + u, err := uuid.NewRandom() + + return u.String() + ".local", err +} + +//nolint:cyclop +func createMulticastDNS( + netTransport transport.Net, + networkTypes []NetworkType, + interfaces []*transport.Interface, + includeLoopback bool, + localAddress net.IP, + mDNSMode MulticastDNSMode, + mDNSName string, + log logging.LeveledLogger, + loggerFactory logging.LoggerFactory, +) (*mdns.Conn, MulticastDNSMode, error) { + if mDNSMode == MulticastDNSModeDisabled { + return nil, mDNSMode, nil + } + + var useV4, useV6 bool + if len(networkTypes) == 0 { + useV4 = true + useV6 = true + } else { + for _, nt := range networkTypes { + if nt.IsIPv4() { + useV4 = true + + continue + } + if nt.IsIPv6() { + useV6 = true + } + } + } + + addr4, mdnsErr := netTransport.ResolveUDPAddr("udp4", mdns.DefaultAddressIPv4) + if mdnsErr != nil { + return nil, mDNSMode, mdnsErr + } + addr6, mdnsErr := netTransport.ResolveUDPAddr("udp6", mdns.DefaultAddressIPv6) + if mdnsErr != nil { + return nil, mDNSMode, mdnsErr + } + + var pktConnV4 *ipv4.PacketConn + var mdns4Err error + if useV4 { + var l transport.UDPConn + l, mdns4Err = netTransport.ListenUDP("udp4", addr4) + if mdns4Err != nil { + // If ICE fails to start MulticastDNS server just warn the user and continue + log.Errorf("Failed to enable mDNS over IPv4: (%s)", mdns4Err) + + return nil, MulticastDNSModeDisabled, nil + } + pktConnV4 = ipv4.NewPacketConn(l) + } + + var pktConnV6 *ipv6.PacketConn + var mdns6Err error + if useV6 { + var l transport.UDPConn + l, mdns6Err = netTransport.ListenUDP("udp6", addr6) + if mdns6Err != nil { + log.Errorf("Failed to enable mDNS over IPv6: (%s)", mdns6Err) + + return nil, MulticastDNSModeDisabled, nil + } + pktConnV6 = ipv6.NewPacketConn(l) + } + + if mdns4Err != nil && mdns6Err != nil { + // If ICE fails to start MulticastDNS server just warn the user and continue + log.Errorf("Failed to enable mDNS, continuing in mDNS disabled mode") + //nolint:nilerr + return nil, MulticastDNSModeDisabled, nil + } + var ifcs []net.Interface + if interfaces != nil { + ifcs = make([]net.Interface, 0, len(ifcs)) + for _, ifc := range interfaces { + ifcs = append(ifcs, ifc.Interface) + } + } + + switch mDNSMode { + case MulticastDNSModeQueryOnly: + conn, err := mdns.Server(pktConnV4, pktConnV6, &mdns.Config{ + Interfaces: ifcs, + IncludeLoopback: includeLoopback, + LocalAddress: localAddress, + LoggerFactory: loggerFactory, + }) + + return conn, mDNSMode, err + case MulticastDNSModeQueryAndGather: + conn, err := mdns.Server(pktConnV4, pktConnV6, &mdns.Config{ + Interfaces: ifcs, + IncludeLoopback: includeLoopback, + LocalAddress: localAddress, + LocalNames: []string{mDNSName}, + LoggerFactory: loggerFactory, + }) + + return conn, mDNSMode, err + default: + return nil, mDNSMode, nil + } +} diff --git a/vendor/github.com/pion/ice/v4/net.go b/vendor/github.com/pion/ice/v4/net.go new file mode 100644 index 0000000..1518ffa --- /dev/null +++ b/vendor/github.com/pion/ice/v4/net.go @@ -0,0 +1,180 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package ice + +import ( + "net" + "net/netip" + + "github.com/pion/ice/v4/internal/netutil" + "github.com/pion/logging" + "github.com/pion/transport/v4" +) + +type ifaceAddr struct { + addr netip.Addr + iface string +} + +// The conditions of invalidation written below are defined in +// https://tools.ietf.org/html/rfc8445#section-5.1.1.1 +// It is partial because the link-local check is done later in various gather local +// candidate methods which conditionally accept IPv6 based on usage of mDNS or not. +func isSupportedIPv6Partial(ip net.IP) bool { + if len(ip) != net.IPv6len || + // Deprecated IPv4-compatible IPv6 addresses [RFC4291] and IPv6 site- + // local unicast addresses [RFC3879] MUST NOT be included in the + // address candidates. + isZeros(ip[0:12]) || // !(IPv4-compatible IPv6) + ip[0] == 0xfe && ip[1]&0xc0 == 0xc0 { // !(IPv6 site-local unicast) + return false + } + + return true +} + +func isZeros(ip net.IP) bool { + for i := range ip { + if ip[i] != 0 { + return false + } + } + + return true +} + +//nolint:gocognit,cyclop +func localInterfaces( + n transport.Net, + interfaceFilter func(string) (keep bool), + ipFilter func(net.IP) (keep bool), + networkTypes []NetworkType, + includeLoopback bool, +) ([]*transport.Interface, []ifaceAddr, error) { + ipAddrs := []ifaceAddr{} + ifaces, err := n.Interfaces() + if err != nil { + return nil, ipAddrs, err + } + + filteredIfaces := make([]*transport.Interface, 0, len(ifaces)) + + var ipV4Requested, ipv6Requested bool + if len(networkTypes) == 0 { + ipV4Requested = true + ipv6Requested = true + } else { + for _, typ := range networkTypes { + if typ.IsIPv4() { + ipV4Requested = true + } + + if typ.IsIPv6() { + ipv6Requested = true + } + } + } + + for _, iface := range ifaces { + if iface.Flags&net.FlagUp == 0 { + continue // Interface down + } + if (iface.Flags&net.FlagLoopback != 0) && !includeLoopback { + continue // Loopback interface + } + + if interfaceFilter != nil && !interfaceFilter(iface.Name) { + continue + } + + ifaceAddrs, err := iface.Addrs() + if err != nil { + continue + } + + atLeastOneAddr := false + for _, addr := range ifaceAddrs { + ipAddr, _, _, err := parseAddrFromIface(addr, iface.Name) + if err != nil || (ipAddr.IsLoopback() && !includeLoopback) { + continue + } + if ipAddr.Is6() { + if !ipv6Requested { + continue + } else if !isSupportedIPv6Partial(ipAddr.AsSlice()) { + continue + } + } else if !ipV4Requested { + continue + } + + if ipFilter != nil && !ipFilter(ipAddr.AsSlice()) { + continue + } + + atLeastOneAddr = true + ipAddrs = append(ipAddrs, ifaceAddr{addr: ipAddr, iface: iface.Name}) + } + + if atLeastOneAddr { + ifaceCopy := iface + filteredIfaces = append(filteredIfaces, ifaceCopy) + } + } + + return filteredIfaces, ipAddrs, nil +} + +//nolint:cyclop +func listenUDPInPortRange( + netTransport transport.Net, + log logging.LeveledLogger, + portMax, portMin int, + network string, + lAddr *net.UDPAddr, +) (transport.UDPConn, error) { + if (lAddr.Port != 0) || ((portMin == 0) && (portMax == 0)) { + return netTransport.ListenUDP(network, lAddr) + } + + if portMin == 0 { + portMin = 1024 // Start at 1024 which is non-privileged + } + + if portMax == 0 { + portMax = 0xFFFF + } + + if portMin > portMax { + return nil, ErrPort + } + + portStart := globalMathRandomGenerator.Intn(portMax-portMin+1) + portMin + portCurrent := portStart + for { + addr := &net.UDPAddr{ + IP: lAddr.IP, + Zone: lAddr.Zone, + Port: portCurrent, + } + + c, e := netTransport.ListenUDP(network, addr) + if e == nil { + return c, e //nolint:nilerr + } + log.Debugf("Failed to listen %s: %v", lAddr.String(), e) + if netutil.IsAddrUnavailable(e) { + return nil, e + } + portCurrent++ + if portCurrent > portMax { + portCurrent = portMin + } + if portCurrent == portStart { + break + } + } + + return nil, ErrPort +} diff --git a/vendor/github.com/pion/ice/v4/networktype.go b/vendor/github.com/pion/ice/v4/networktype.go new file mode 100644 index 0000000..674d1b8 --- /dev/null +++ b/vendor/github.com/pion/ice/v4/networktype.go @@ -0,0 +1,142 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package ice + +import ( + "fmt" + "net/netip" + "strings" +) + +const ( + udp = "udp" + tcp = "tcp" + udp4 = "udp4" + udp6 = "udp6" + tcp4 = "tcp4" + tcp6 = "tcp6" +) + +func supportedNetworkTypes() []NetworkType { + return []NetworkType{ + NetworkTypeUDP4, + NetworkTypeUDP6, + NetworkTypeTCP4, + NetworkTypeTCP6, + } +} + +// NetworkType represents the type of network. +type NetworkType int + +const ( + // NetworkTypeUDP4 indicates UDP over IPv4. + NetworkTypeUDP4 NetworkType = iota + 1 + + // NetworkTypeUDP6 indicates UDP over IPv6. + NetworkTypeUDP6 + + // NetworkTypeTCP4 indicates TCP over IPv4. + NetworkTypeTCP4 + + // NetworkTypeTCP6 indicates TCP over IPv6. + NetworkTypeTCP6 +) + +func (t NetworkType) String() string { + switch t { + case NetworkTypeUDP4: + return udp4 + case NetworkTypeUDP6: + return udp6 + case NetworkTypeTCP4: + return tcp4 + case NetworkTypeTCP6: + return tcp6 + default: + return ErrUnknownType.Error() + } +} + +// IsUDP returns true when network is UDP4 or UDP6. +func (t NetworkType) IsUDP() bool { + return t == NetworkTypeUDP4 || t == NetworkTypeUDP6 +} + +// IsTCP returns true when network is TCP4 or TCP6. +func (t NetworkType) IsTCP() bool { + return t == NetworkTypeTCP4 || t == NetworkTypeTCP6 +} + +// NetworkShort returns the short network description. +func (t NetworkType) NetworkShort() string { + switch t { + case NetworkTypeUDP4, NetworkTypeUDP6: + return udp + case NetworkTypeTCP4, NetworkTypeTCP6: + return tcp + default: + return ErrUnknownType.Error() + } +} + +// IsReliable returns true if the network is reliable. +func (t NetworkType) IsReliable() bool { + switch t { + case NetworkTypeUDP4, NetworkTypeUDP6: + return false + case NetworkTypeTCP4, NetworkTypeTCP6: + return true + } + + return false +} + +// IsIPv4 returns whether the network type is IPv4 or not. +func (t NetworkType) IsIPv4() bool { + switch t { + case NetworkTypeUDP4, NetworkTypeTCP4: + return true + case NetworkTypeUDP6, NetworkTypeTCP6: + return false + } + + return false +} + +// IsIPv6 returns whether the network type is IPv6 or not. +func (t NetworkType) IsIPv6() bool { + switch t { + case NetworkTypeUDP4, NetworkTypeTCP4: + return false + case NetworkTypeUDP6, NetworkTypeTCP6: + return true + } + + return false +} + +// determineNetworkType determines the type of network based on +// the short network string and an IP address. +func determineNetworkType(network string, ip netip.Addr) (NetworkType, error) { + // we'd rather have an IPv4-mapped IPv6 become IPv4 so that it is usable. + ip = ip.Unmap() + switch { + case strings.HasPrefix(strings.ToLower(network), udp): + if ip.Is4() { + return NetworkTypeUDP4, nil + } + + return NetworkTypeUDP6, nil + + case strings.HasPrefix(strings.ToLower(network), tcp): + if ip.Is4() { + return NetworkTypeTCP4, nil + } + + return NetworkTypeTCP6, nil + } + + return NetworkType(0), fmt.Errorf("%w from %s %s", ErrDetermineNetworkType, network, ip) +} diff --git a/vendor/github.com/pion/ice/v4/priority.go b/vendor/github.com/pion/ice/v4/priority.go new file mode 100644 index 0000000..51f40b9 --- /dev/null +++ b/vendor/github.com/pion/ice/v4/priority.go @@ -0,0 +1,38 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package ice + +import ( + "encoding/binary" + + "github.com/pion/stun/v3" +) + +// PriorityAttr represents PRIORITY attribute. +type PriorityAttr uint32 + +const prioritySize = 4 // 32 bit + +// AddTo adds PRIORITY attribute to message. +func (p PriorityAttr) AddTo(m *stun.Message) error { + v := make([]byte, prioritySize) + binary.BigEndian.PutUint32(v, uint32(p)) + m.Add(stun.AttrPriority, v) + + return nil +} + +// GetFrom decodes PRIORITY attribute from message. +func (p *PriorityAttr) GetFrom(m *stun.Message) error { + v, err := m.Get(stun.AttrPriority) + if err != nil { + return err + } + if err = stun.CheckSize(stun.AttrPriority, len(v), prioritySize); err != nil { + return err + } + *p = PriorityAttr(binary.BigEndian.Uint32(v)) + + return nil +} diff --git a/vendor/github.com/pion/ice/v4/rand.go b/vendor/github.com/pion/ice/v4/rand.go new file mode 100644 index 0000000..f5e27bc --- /dev/null +++ b/vendor/github.com/pion/ice/v4/rand.go @@ -0,0 +1,56 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package ice + +import "github.com/pion/randutil" + +const ( + runesAlpha = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" + runesDigit = "0123456789" + runesCandidateIDFoundation = runesAlpha + runesDigit + "+/" + + lenUFrag = 16 + lenPwd = 32 +) + +// Seeding random generator each time limits number of generated sequence to 31-bits, +// and causes collision on low time accuracy environments. +// Use global random generator seeded by crypto grade random. +var ( + globalMathRandomGenerator = randutil.NewMathRandomGenerator() //nolint:gochecknoglobals + globalCandidateIDGenerator = candidateIDGenerator{globalMathRandomGenerator} //nolint:gochecknoglobals +) + +// candidateIDGenerator is a random candidate ID generator. +// Candidate ID is used in SDP and always shared to the other peer. +// It doesn't require cryptographic random. +type candidateIDGenerator struct { + randutil.MathRandomGenerator +} + +func newCandidateIDGenerator() *candidateIDGenerator { + return &candidateIDGenerator{ + randutil.NewMathRandomGenerator(), + } +} + +func (g *candidateIDGenerator) Generate() string { + // https://tools.ietf.org/html/rfc5245#section-15.1 + // candidate-id = "candidate" ":" foundation + // foundation = 1*32ice-char + // ice-char = ALPHA / DIGIT / "+" / "/" + return "candidate:" + g.MathRandomGenerator.GenerateString(32, runesCandidateIDFoundation) +} + +// generatePwd generates ICE pwd. +// This internally uses generateCryptoRandomString. +func generatePwd() (string, error) { + return randutil.GenerateCryptoRandomString(lenPwd, runesAlpha) +} + +// generateUFrag generates ICE user fragment. +// This internally uses generateCryptoRandomString. +func generateUFrag() (string, error) { + return randutil.GenerateCryptoRandomString(lenUFrag, runesAlpha) +} diff --git a/vendor/github.com/pion/ice/v4/renomination.go b/vendor/github.com/pion/ice/v4/renomination.go new file mode 100644 index 0000000..a0674f0 --- /dev/null +++ b/vendor/github.com/pion/ice/v4/renomination.go @@ -0,0 +1,86 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package ice + +import ( + "fmt" + + "github.com/pion/stun/v3" +) + +// Default STUN Nomination attribute type for ICE renomination. +// Following the specification draft-thatcher-ice-renomination-01. +const ( + // DefaultNominationAttribute represents the default STUN Nomination attribute. + // This is a custom attribute for ICE renomination support. + // This value can be overridden via AgentConfig.NominationAttribute. + DefaultNominationAttribute stun.AttrType = 0x0030 // Using a value in the reserved range +) + +// NominationAttribute represents a STUN Nomination attribute. +type NominationAttribute struct { + Value uint32 +} + +// GetFrom decodes a Nomination attribute from a STUN message. +func (a *NominationAttribute) GetFrom(m *stun.Message) error { + return a.GetFromWithType(m, DefaultNominationAttribute) +} + +// GetFromWithType decodes a Nomination attribute from a STUN message using a specific attribute type. +func (a *NominationAttribute) GetFromWithType(m *stun.Message, attrType stun.AttrType) error { + v, err := m.Get(attrType) + if err != nil { + return err + } + if len(v) < 4 { + return stun.ErrAttributeSizeInvalid + } + + // Extract 24-bit value from the last 3 bytes + a.Value = uint32(v[1])<<16 | uint32(v[2])<<8 | uint32(v[3]) + + return nil +} + +// AddTo adds a Nomination attribute to a STUN message. +func (a NominationAttribute) AddTo(m *stun.Message) error { + return a.AddToWithType(m, DefaultNominationAttribute) +} + +// AddToWithType adds a Nomination attribute to a STUN message using a specific attribute type. +func (a NominationAttribute) AddToWithType(m *stun.Message, attrType stun.AttrType) error { + // Store as 4 bytes with first byte as 0 + v := make([]byte, 4) + v[1] = byte(a.Value >> 16) //nolint:gosec + v[2] = byte(a.Value >> 8) //nolint:gosec + v[3] = byte(a.Value) //nolint:gosec + + m.Add(attrType, v) + + return nil +} + +// String returns string representation of the nomination attribute. +func (a NominationAttribute) String() string { + return fmt.Sprintf("NOMINATION: %d", a.Value) +} + +// Nomination creates a new STUN nomination attribute. +func Nomination(value uint32) NominationAttribute { + return NominationAttribute{Value: value} +} + +// NominationSetter is a STUN setter for nomination attribute with configurable type. +type NominationSetter struct { + Value uint32 + AttrType stun.AttrType +} + +// AddTo adds a Nomination attribute to a STUN message using the configured attribute type. +func (n NominationSetter) AddTo(m *stun.Message) error { + attr := NominationAttribute{Value: n.Value} + + return attr.AddToWithType(m, n.AttrType) +} diff --git a/vendor/github.com/pion/ice/v4/renovate.json b/vendor/github.com/pion/ice/v4/renovate.json new file mode 100644 index 0000000..f1bb98c --- /dev/null +++ b/vendor/github.com/pion/ice/v4/renovate.json @@ -0,0 +1,6 @@ +{ + "$schema": "https://docs.renovatebot.com/renovate-schema.json", + "extends": [ + "github>pion/renovate-config" + ] +} diff --git a/vendor/github.com/pion/ice/v4/role.go b/vendor/github.com/pion/ice/v4/role.go new file mode 100644 index 0000000..7719f62 --- /dev/null +++ b/vendor/github.com/pion/ice/v4/role.go @@ -0,0 +1,47 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package ice + +import ( + "fmt" +) + +// Role represents ICE agent role, which can be controlling or controlled. +type Role byte + +// Possible ICE agent roles. +const ( + Controlling Role = iota + Controlled +) + +// UnmarshalText implements TextUnmarshaler. +func (r *Role) UnmarshalText(text []byte) error { + switch string(text) { + case "controlling": + *r = Controlling + case "controlled": + *r = Controlled + default: + return fmt.Errorf("%w %q", errUnknownRole, text) + } + + return nil +} + +// MarshalText implements TextMarshaler. +func (r Role) MarshalText() (text []byte, err error) { + return []byte(r.String()), nil +} + +func (r Role) String() string { + switch r { + case Controlling: + return "controlling" + case Controlled: + return "controlled" + default: + return "unknown" + } +} diff --git a/vendor/github.com/pion/ice/v4/selection.go b/vendor/github.com/pion/ice/v4/selection.go new file mode 100644 index 0000000..efed335 --- /dev/null +++ b/vendor/github.com/pion/ice/v4/selection.go @@ -0,0 +1,488 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package ice + +import ( + "net" + "time" + + "github.com/pion/logging" + "github.com/pion/stun/v3" +) + +type pairCandidateSelector interface { + Start() + ContactCandidates() + PingCandidate(local, remote Candidate) + HandleSuccessResponse(m *stun.Message, local, remote Candidate, remoteAddr net.Addr) + HandleBindingRequest(m *stun.Message, local, remote Candidate) +} + +type controllingSelector struct { + startTime time.Time + agent *Agent + nominatedPair *CandidatePair + log logging.LeveledLogger +} + +func (s *controllingSelector) Start() { + s.startTime = time.Now() + s.nominatedPair = nil +} + +func (s *controllingSelector) isNominatable(c Candidate) bool { + switch { + case c.Type() == CandidateTypeHost: + return time.Since(s.startTime).Nanoseconds() > s.agent.hostAcceptanceMinWait.Nanoseconds() + case c.Type() == CandidateTypeServerReflexive: + return time.Since(s.startTime).Nanoseconds() > s.agent.srflxAcceptanceMinWait.Nanoseconds() + case c.Type() == CandidateTypePeerReflexive: + return time.Since(s.startTime).Nanoseconds() > s.agent.prflxAcceptanceMinWait.Nanoseconds() + case c.Type() == CandidateTypeRelay: + return time.Since(s.startTime).Nanoseconds() > s.agent.relayAcceptanceMinWait.Nanoseconds() + } + + s.log.Errorf("Invalid candidate type: %s", c.Type()) + + return false +} + +func (s *controllingSelector) ContactCandidates() { + switch { + case s.agent.getSelectedPair() != nil: + if s.agent.validateSelectedPair() { + s.log.Trace("Checking keepalive") + s.agent.checkKeepalive() + + // If automatic renomination is enabled, continuously ping all candidate pairs + // to keep them tested with fresh RTT measurements for switching decisions + if s.agent.automaticRenomination && s.agent.enableRenomination { + s.agent.keepAliveCandidatesForRenomination() + } + + s.checkForAutomaticRenomination() + } + case s.nominatedPair != nil: + s.nominatePair(s.nominatedPair) + default: + p := s.agent.getBestValidCandidatePair() + if p != nil && s.isNominatable(p.Local) && s.isNominatable(p.Remote) { + s.log.Tracef("Nominatable pair found, nominating (%s, %s)", p.Local, p.Remote) + p.nominated = true + s.nominatedPair = p + s.nominatePair(p) + + return + } + s.agent.pingAllCandidates() + } +} + +func (s *controllingSelector) nominatePair(pair *CandidatePair) { + // The controlling agent MUST include the USE-CANDIDATE attribute in + // order to nominate a candidate pair (Section 8.1.1). The controlled + // agent MUST NOT include the USE-CANDIDATE attribute in a Binding + // request. + msg, err := stun.Build(stun.BindingRequest, stun.TransactionID, + stun.NewUsername(s.agent.remoteUfrag+":"+s.agent.localUfrag), + UseCandidate(), + AttrControlling(s.agent.tieBreaker), + PriorityAttr(pair.Local.Priority()), + stun.NewShortTermIntegrity(s.agent.remotePwd), + stun.Fingerprint, + ) + if err != nil { + s.log.Error(err.Error()) + + return + } + + s.log.Tracef("Ping STUN (nominate candidate pair) from %s to %s", pair.Local, pair.Remote) + s.agent.sendBindingRequest(msg, pair.Local, pair.Remote) +} + +func (s *controllingSelector) HandleBindingRequest(message *stun.Message, local, remote Candidate) { //nolint:cyclop + s.agent.sendBindingSuccess(message, local, remote) + + pair := s.agent.findPair(local, remote) + + if pair == nil { + pair = s.agent.addPair(local, remote) + pair.UpdateRequestReceived() + + return + } + pair.UpdateRequestReceived() + + if pair.state == CandidatePairStateSucceeded && s.nominatedPair == nil && s.agent.getSelectedPair() == nil { + bestPair := s.agent.getBestAvailableCandidatePair() + if bestPair == nil { + s.log.Tracef("No best pair available") + } else if bestPair.equal(pair) && s.isNominatable(pair.Local) && s.isNominatable(pair.Remote) { + s.log.Tracef( + "The candidate (%s, %s) is the best candidate available, marking it as nominated", + pair.Local, + pair.Remote, + ) + s.nominatedPair = pair + s.nominatePair(pair) + } + } + + if s.agent.userBindingRequestHandler != nil { + if shouldSwitch := s.agent.userBindingRequestHandler(message, local, remote, pair); shouldSwitch { + s.agent.setSelectedPair(pair) + } + } +} + +func (s *controllingSelector) HandleSuccessResponse(m *stun.Message, local, remote Candidate, remoteAddr net.Addr) { + ok, pendingRequest, rtt := s.agent.handleInboundBindingSuccess(m.TransactionID) + if !ok { + s.log.Warnf("Discard success response from (%s), unknown TransactionID 0x%x", remote, m.TransactionID) + + return + } + + transactionAddr := pendingRequest.destination + + // Assert that NAT is not symmetric + // https://tools.ietf.org/html/rfc8445#section-7.2.5.2.1 + if !addrEqual(transactionAddr, remoteAddr) { + s.log.Debugf( + "Discard message: transaction source and destination does not match expected(%s), actual(%s)", + transactionAddr, + remote, + ) + + return + } + + s.log.Tracef("Inbound STUN (SuccessResponse) from %s to %s", remote, local) + pair := s.agent.findPair(local, remote) + + if pair == nil { + // This shouldn't happen + s.log.Error("Success response from invalid candidate pair") + + return + } + + pair.state = CandidatePairStateSucceeded + s.log.Tracef("Found valid candidate pair: %s", pair) + + // Handle nomination/renomination + if pendingRequest.isUseCandidate { + selectedPair := s.agent.getSelectedPair() + + // If this is a renomination request (has nomination value), always update the selected pair + // If it's a standard nomination (no value), only set if no pair is selected yet + if pendingRequest.nominationValue != nil { + s.log.Infof("Renomination success response received for pair %s (nomination value: %d), switching to this pair", + pair, *pendingRequest.nominationValue) + s.agent.setSelectedPair(pair) + } else if selectedPair == nil { + s.agent.setSelectedPair(pair) + } + } + + pair.UpdateRoundTripTime(rtt) +} + +func (s *controllingSelector) PingCandidate(local, remote Candidate) { + msg, err := stun.Build(stun.BindingRequest, stun.TransactionID, + stun.NewUsername(s.agent.remoteUfrag+":"+s.agent.localUfrag), + AttrControlling(s.agent.tieBreaker), + PriorityAttr(local.Priority()), + stun.NewShortTermIntegrity(s.agent.remotePwd), + stun.Fingerprint, + ) + if err != nil { + s.log.Error(err.Error()) + + return + } + + s.agent.sendBindingRequest(msg, local, remote) +} + +// checkForAutomaticRenomination evaluates if automatic renomination should occur. +// This is called periodically when the agent is in connected state and automatic +// renomination is enabled. +func (s *controllingSelector) checkForAutomaticRenomination() { + if !s.agent.automaticRenomination || !s.agent.enableRenomination { + s.log.Tracef("Automatic renomination check skipped: automaticRenomination=%v, enableRenomination=%v", + s.agent.automaticRenomination, s.agent.enableRenomination) + + return + } + + timeSinceStart := time.Since(s.startTime) + if timeSinceStart < s.agent.renominationInterval { + s.log.Tracef("Automatic renomination check skipped: not enough time since start (%v < %v)", + timeSinceStart, s.agent.renominationInterval) + + return + } + + if !s.agent.lastRenominationTime.IsZero() { + timeSinceLastRenomination := time.Since(s.agent.lastRenominationTime) + if timeSinceLastRenomination < s.agent.renominationInterval { + s.log.Tracef("Automatic renomination check skipped: too soon since last renomination (%v < %v)", + timeSinceLastRenomination, s.agent.renominationInterval) + + return + } + } + + currentPair := s.agent.getSelectedPair() + if currentPair == nil { + s.log.Tracef("Automatic renomination check skipped: no current selected pair") + + return + } + + bestPair := s.agent.findBestCandidatePair() + if bestPair == nil { + s.log.Tracef("Automatic renomination check skipped: no best pair found") + + return + } + + s.log.Debugf("Evaluating automatic renomination: current=%s (RTT=%.2fms), best=%s (RTT=%.2fms)", + currentPair, currentPair.CurrentRoundTripTime()*1000, + bestPair, bestPair.CurrentRoundTripTime()*1000) + + if s.agent.shouldRenominate(currentPair, bestPair) { + s.log.Infof("Automatic renomination triggered: switching from %s to %s", + currentPair, bestPair) + + // Update last renomination time to prevent rapid renominations + s.agent.lastRenominationTime = time.Now() + + if err := s.agent.RenominateCandidate(bestPair.Local, bestPair.Remote); err != nil { + s.log.Errorf("Failed to trigger automatic renomination: %v", err) + } + } else { + s.log.Debugf("Automatic renomination not warranted") + } +} + +type controlledSelector struct { + agent *Agent + log logging.LeveledLogger + lastNomination *uint32 // For renomination: tracks highest nomination value seen +} + +func (s *controlledSelector) Start() { + s.lastNomination = nil +} + +// shouldAcceptNomination checks if a nomination should be accepted based on renomination rules. +func (s *controlledSelector) shouldAcceptNomination(nominationValue *uint32) bool { + // If no nomination value, accept normally (standard ICE nomination) + if nominationValue == nil { + return true + } + + // If nomination value is present, controlling side is using renomination + // Apply "last nomination wins" rule + + if s.lastNomination == nil || *nominationValue > *s.lastNomination { + s.lastNomination = nominationValue + s.log.Tracef("Accepting nomination with value %d", *nominationValue) + + return true + } + + s.log.Tracef("Rejecting nomination value %d (current is %d)", *nominationValue, *s.lastNomination) + + return false +} + +// shouldSwitchSelectedPair determines if we should switch to a new nominated pair. +// Returns true if the switch should occur, false otherwise. +func (s *controlledSelector) shouldSwitchSelectedPair(pair, selectedPair *CandidatePair, nominationValue *uint32) bool { + switch { + case selectedPair == nil: + // No current selection, accept the nomination + return true + case selectedPair == pair: + // Same pair, no change needed + return false + case nominationValue != nil: + // Renomination is in use (nomination value present) + // Accept the switch based on nomination value alone, not priority + // The shouldAcceptNomination check already validated this is a valid renomination + s.log.Debugf("Accepting renomination to pair %s (nomination value: %d)", pair, *nominationValue) + + return true + } + + // Standard ICE nomination without renomination - apply priority rules + // Only switch if we don't check priority, OR new pair has strictly higher priority + return !s.agent.needsToCheckPriorityOnNominated() || + selectedPair.priority() < pair.priority() +} + +func (s *controlledSelector) ContactCandidates() { + if s.agent.getSelectedPair() != nil { + if s.agent.validateSelectedPair() { + s.log.Trace("Checking keepalive") + s.agent.checkKeepalive() + } + } else { + s.agent.pingAllCandidates() + } +} + +func (s *controlledSelector) PingCandidate(local, remote Candidate) { + msg, err := stun.Build(stun.BindingRequest, stun.TransactionID, + stun.NewUsername(s.agent.remoteUfrag+":"+s.agent.localUfrag), + AttrControlled(s.agent.tieBreaker), + PriorityAttr(local.Priority()), + stun.NewShortTermIntegrity(s.agent.remotePwd), + stun.Fingerprint, + ) + if err != nil { + s.log.Error(err.Error()) + + return + } + + s.agent.sendBindingRequest(msg, local, remote) +} + +func (s *controlledSelector) HandleSuccessResponse(m *stun.Message, local, remote Candidate, remoteAddr net.Addr) { + //nolint:godox + // TODO according to the standard we should specifically answer a failed nomination: + // https://tools.ietf.org/html/rfc8445#section-7.3.1.5 + // If the controlled agent does not accept the request from the + // controlling agent, the controlled agent MUST reject the nomination + // request with an appropriate error code response (e.g., 400) + // [RFC5389]. + + ok, pendingRequest, rtt := s.agent.handleInboundBindingSuccess(m.TransactionID) + if !ok { + s.log.Warnf("Discard message from (%s), unknown TransactionID 0x%x", remote, m.TransactionID) + + return + } + + transactionAddr := pendingRequest.destination + + // Assert that NAT is not symmetric + // https://tools.ietf.org/html/rfc8445#section-7.2.5.2.1 + if !addrEqual(transactionAddr, remoteAddr) { + s.log.Debugf( + "Discard message: transaction source and destination does not match expected(%s), actual(%s)", + transactionAddr, + remote, + ) + + return + } + + s.log.Tracef("Inbound STUN (SuccessResponse) from %s to %s", remote, local) + + pair := s.agent.findPair(local, remote) + if pair == nil { + // This shouldn't happen + s.log.Error("Success response from invalid candidate pair") + + return + } + + pair.state = CandidatePairStateSucceeded + s.log.Tracef("Found valid candidate pair: %s", pair) + if pair.nominateOnBindingSuccess { + if selectedPair := s.agent.getSelectedPair(); selectedPair == nil || + (selectedPair != pair && + (!s.agent.needsToCheckPriorityOnNominated() || selectedPair.priority() <= pair.priority())) { + s.agent.setSelectedPair(pair) + } else if selectedPair != pair { + s.log.Tracef("Ignore nominate new pair %s, already nominated pair %s", pair, selectedPair) + } + } + + pair.UpdateRoundTripTime(rtt) +} + +func (s *controlledSelector) HandleBindingRequest(message *stun.Message, local, remote Candidate) { //nolint:cyclop + pair := s.agent.findPair(local, remote) + if pair == nil { + pair = s.agent.addPair(local, remote) + } + pair.UpdateRequestReceived() + + if message.Contains(stun.AttrUseCandidate) || message.Contains(s.agent.nominationAttribute) { //nolint:nestif + // https://tools.ietf.org/html/rfc8445#section-7.3.1.5 + + // Check for renomination attribute + var nominationValue *uint32 + var nomination NominationAttribute + if err := nomination.GetFromWithType(message, s.agent.nominationAttribute); err == nil { + nominationValue = &nomination.Value + s.log.Tracef("Received nomination with value %d", nomination.Value) + } + + // Check if we should accept this nomination based on renomination rules + if !s.shouldAcceptNomination(nominationValue) { + s.log.Tracef("Rejecting nomination request due to renomination rules") + s.agent.sendBindingSuccess(message, local, remote) + + return + } + + if pair.state == CandidatePairStateSucceeded { + // If the state of this pair is Succeeded, it means that the check + // previously sent by this pair produced a successful response and + // generated a valid pair (Section 7.2.5.3.2). The agent sets the + // nominated flag value of the valid pair to true. + selectedPair := s.agent.getSelectedPair() + if s.shouldSwitchSelectedPair(pair, selectedPair, nominationValue) { + s.log.Tracef("Accepting nomination for pair %s", pair) + s.agent.setSelectedPair(pair) + } else { + s.log.Tracef("Ignore nominate new pair %s, already nominated pair %s", pair, selectedPair) + } + } else { + // If the received Binding request triggered a new check to be + // enqueued in the triggered-check queue (Section 7.3.1.4), once the + // check is sent and if it generates a successful response, and + // generates a valid pair, the agent sets the nominated flag of the + // pair to true. If the request fails (Section 7.2.5.2), the agent + // MUST remove the candidate pair from the valid list, set the + // candidate pair state to Failed, and set the checklist state to + // Failed. + pair.nominateOnBindingSuccess = true + } + } + + s.agent.sendBindingSuccess(message, local, remote) + s.PingCandidate(local, remote) + + if s.agent.userBindingRequestHandler != nil { + if shouldSwitch := s.agent.userBindingRequestHandler(message, local, remote, pair); shouldSwitch { + s.agent.setSelectedPair(pair) + } + } +} + +type liteSelector struct { + pairCandidateSelector +} + +// A lite selector should not contact candidates. +func (s *liteSelector) ContactCandidates() { + if _, ok := s.pairCandidateSelector.(*controllingSelector); ok { + //nolint:godox + // https://github.com/pion/ice/issues/96 + // TODO: implement lite controlling agent. For now falling back to full agent. + // This only happens if both peers are lite. See RFC 8445 S6.1.1 and S6.2 + s.pairCandidateSelector.ContactCandidates() + } else if v, ok := s.pairCandidateSelector.(*controlledSelector); ok { + v.agent.validateSelectedPair() + } +} diff --git a/vendor/github.com/pion/ice/v4/sped.go b/vendor/github.com/pion/ice/v4/sped.go new file mode 100644 index 0000000..cea3be7 --- /dev/null +++ b/vendor/github.com/pion/ice/v4/sped.go @@ -0,0 +1,71 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package ice + +import ( + "encoding/binary" + + "github.com/pion/stun/v3" +) + +// DtlsInStunAttribute is a STUN attribute for carrying DTLS embedded in STUN. +type DtlsInStunAttribute []byte + +// AddTo adds DTLS-in-STUN attribute to message. +func (d DtlsInStunAttribute) AddTo(m *stun.Message) error { + m.Add(stun.AttrDtlsInStun, d) + + return nil +} + +// GetFrom decodes DTLS-in-STUN attribute from message. +func (d *DtlsInStunAttribute) GetFrom(m *stun.Message) error { + v, err := m.Get(stun.AttrDtlsInStun) + if err != nil { + return err + } + + *d = v + + return nil +} + +// DtlsInStunAckAttribute is a STUN attribute for acknowledging the receipt +// of DTLS packets (embedded in STUN or without embedding). +type DtlsInStunAckAttribute []uint32 + +// Acks are 32 bit values, the attribute can carry up to four of these. +const ackSizeBytes, ackSizeValues = 32, 4 + +// AddTo adds DTLS-in-STUN-ACK attribute to message. +func (a DtlsInStunAckAttribute) AddTo(m *stun.Message) error { + if len(a) > ackSizeValues { + return stun.ErrAttributeSizeInvalid + } + v := make([]byte, len(a)*4) + for i, ack := range a { + binary.BigEndian.PutUint32(v[i*4:], ack) + } + m.Add(stun.AttrDtlsInStunAck, v) + + return nil +} + +// GetFrom decodes DTLS-in-STUN-ACK attribute from message. +func (a *DtlsInStunAckAttribute) GetFrom(m *stun.Message) error { + v, err := m.Get(stun.AttrDtlsInStunAck) + if err != nil { + return err + } + if len(v) > ackSizeBytes || len(v)%4 != 0 { + return stun.ErrAttributeSizeInvalid + } + u := make([]uint32, len(v)/4) + for i := range u { + u[i] = binary.BigEndian.Uint32(v[i*4 : (i+1)*4]) + } + *a = DtlsInStunAckAttribute(u) + + return nil +} diff --git a/vendor/github.com/pion/ice/v4/stats.go b/vendor/github.com/pion/ice/v4/stats.go new file mode 100644 index 0000000..4b5ad01 --- /dev/null +++ b/vendor/github.com/pion/ice/v4/stats.go @@ -0,0 +1,219 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package ice + +import ( + "time" +) + +// CandidatePairStats contains ICE candidate pair statistics. +type CandidatePairStats struct { + // Timestamp is the timestamp associated with this object. + Timestamp time.Time + + // LocalCandidateID is the ID of the local candidate + LocalCandidateID string + + // RemoteCandidateID is the ID of the remote candidate + RemoteCandidateID string + + // State represents the state of the checklist for the local and remote + // candidates in a pair. + State CandidatePairState + + // Nominated is true when this valid pair that should be used for media + // if it is the highest-priority one amongst those whose nominated flag is set + Nominated bool + + // PacketsSent represents the total number of packets sent on this candidate pair. + PacketsSent uint32 + + // PacketsReceived represents the total number of packets received on this candidate pair. + PacketsReceived uint32 + + // BytesSent represents the total number of payload bytes sent on this candidate pair + // not including headers or padding. + BytesSent uint64 + + // BytesReceived represents the total number of payload bytes received on this candidate pair + // not including headers or padding. + BytesReceived uint64 + + // LastPacketSentTimestamp represents the timestamp at which the last packet was + // sent on this particular candidate pair, excluding STUN packets. + LastPacketSentTimestamp time.Time + + // LastPacketReceivedTimestamp represents the timestamp at which the last packet + // was received on this particular candidate pair, excluding STUN packets. + LastPacketReceivedTimestamp time.Time + + // FirstRequestTimestamp represents the timestamp at which the first STUN request + // was sent on this particular candidate pair. + FirstRequestTimestamp time.Time + + // LastRequestTimestamp represents the timestamp at which the last STUN request + // was sent on this particular candidate pair. The average interval between two + // consecutive connectivity checks sent can be calculated with + // (LastRequestTimestamp - FirstRequestTimestamp) / RequestsSent. + LastRequestTimestamp time.Time + + // FirstResponseTimestamp represents the timestamp at which the first STUN response + // was received on this particular candidate pair. + FirstResponseTimestamp time.Time + + // LastResponseTimestamp represents the timestamp at which the last STUN response + // was received on this particular candidate pair. + LastResponseTimestamp time.Time + + // FirstRequestReceivedTimestamp represents the timestamp at which the first + // connectivity check request was received. + FirstRequestReceivedTimestamp time.Time + + // LastRequestReceivedTimestamp represents the timestamp at which the last + // connectivity check request was received. + LastRequestReceivedTimestamp time.Time + + // TotalRoundTripTime represents the sum of all round trip time measurements + // in seconds since the beginning of the session, based on STUN connectivity + // check responses (ResponsesReceived), including those that reply to requests + // that are sent in order to verify consent. The average round trip time can + // be computed from TotalRoundTripTime by dividing it by ResponsesReceived. + TotalRoundTripTime float64 + + // CurrentRoundTripTime represents the latest round trip time measured in seconds, + // computed from both STUN connectivity checks, including those that are sent + // for consent verification. + CurrentRoundTripTime float64 + + // AvailableOutgoingBitrate is calculated by the underlying congestion control + // by combining the available bitrate for all the outgoing RTP streams using + // this candidate pair. The bitrate measurement does not count the size of the + // IP or other transport layers like TCP or UDP. It is similar to the TIAS defined + // in RFC 3890, i.e., it is measured in bits per second and the bitrate is calculated + // over a 1 second window. + AvailableOutgoingBitrate float64 + + // AvailableIncomingBitrate is calculated by the underlying congestion control + // by combining the available bitrate for all the incoming RTP streams using + // this candidate pair. The bitrate measurement does not count the size of the + // IP or other transport layers like TCP or UDP. It is similar to the TIAS defined + // in RFC 3890, i.e., it is measured in bits per second and the bitrate is + // calculated over a 1 second window. + AvailableIncomingBitrate float64 + + // CircuitBreakerTriggerCount represents the number of times the circuit breaker + // is triggered for this particular 5-tuple, ceasing transmission. + CircuitBreakerTriggerCount uint32 + + // RequestsReceived represents the total number of connectivity check requests + // received (including retransmissions). It is impossible for the receiver to + // tell whether the request was sent in order to check connectivity or check + // consent, so all connectivity checks requests are counted here. + RequestsReceived uint64 + + // RequestsSent represents the total number of connectivity check requests + // sent (not including retransmissions). + RequestsSent uint64 + + // ResponsesReceived represents the total number of connectivity check responses received. + ResponsesReceived uint64 + + // ResponsesSent represents the total number of connectivity check responses sent. + // Since we cannot distinguish connectivity check requests and consent requests, + // all responses are counted. + ResponsesSent uint64 + + // RetransmissionsReceived represents the total number of connectivity check + // request retransmissions received. + RetransmissionsReceived uint64 + + // RetransmissionsSent represents the total number of connectivity check + // request retransmissions sent. + RetransmissionsSent uint64 + + // ConsentRequestsSent represents the total number of consent requests sent. + ConsentRequestsSent uint64 + + // ConsentExpiredTimestamp represents the timestamp at which the latest valid + // STUN binding response expired. + ConsentExpiredTimestamp time.Time +} + +// CandidatePairInfo is a snapshot of a candidate pair's state. +// Use the ID with Conn.WriteToPair() to write to this specific pair. +type CandidatePairInfo struct { + // ID is the unique identifier for this candidate pair. + // Use this with Conn.WriteToPair() to write to this pair. + ID uint64 + + // LocalCandidateType is the type of the local candidate (host, srflx, prflx, relay). + LocalCandidateType CandidateType + + // RemoteCandidateType is the type of the remote candidate (host, srflx, prflx, relay). + RemoteCandidateType CandidateType + + // State is the current state of the candidate pair. + State CandidatePairState + + // Nominated indicates whether this pair has been nominated. + Nominated bool + + // CurrentRoundTripTime is the latest RTT measurement. + CurrentRoundTripTime time.Duration + + // RenominationQuality is a score indicating the pair's quality (higher is better). + // Considers candidate types (host > srflx > relay), measured RTT, and stability. + RenominationQuality float64 +} + +// CandidateStats contains ICE candidate statistics related to the ICETransport objects. +type CandidateStats struct { + // Timestamp is the timestamp associated with this object. + Timestamp time.Time + + // ID is the candidate ID + ID string + + // NetworkType represents the type of network interface used by the base of a + // local candidate (the address the ICE agent sends from). Only present for + // local candidates; it's not possible to know what type of network interface + // a remote candidate is using. + // + // Note: + // This stat only tells you about the network interface used by the first "hop"; + // it's possible that a connection will be bottlenecked by another type of network. + // For example, when using Wi-Fi tethering, the networkType of the relevant candidate + // would be "wifi", even when the next hop is over a cellular connection. + NetworkType NetworkType + + // IP is the IP address of the candidate, allowing for IPv4 addresses and + // IPv6 addresses, but fully qualified domain names (FQDNs) are not allowed. + IP string + + // Port is the port number of the candidate. + Port int + + // CandidateType is the "Type" field of the ICECandidate. + CandidateType CandidateType + + // Priority is the "Priority" field of the ICECandidate. + Priority uint32 + + // URL is the URL of the TURN or STUN server indicated in the that translated + // this IP address. It is the URL address surfaced in an PeerConnectionICEEvent. + URL string + + // RelayProtocol is the protocol used by the endpoint to communicate with the + // TURN server. This is only present for local candidates. Valid values for + // the TURN URL protocol is one of UDP, TCP, or TLS. + RelayProtocol string + + // Deleted is true if the candidate has been deleted/freed. For host candidates, + // this means that any network resources (typically a socket) associated with the + // candidate have been released. For TURN candidates, this means the TURN allocation + // is no longer active. + // + // Only defined for local candidates. For remote candidates, this property is not applicable. + Deleted bool +} diff --git a/vendor/github.com/pion/ice/v4/tcp_mux.go b/vendor/github.com/pion/ice/v4/tcp_mux.go new file mode 100644 index 0000000..a7a956a --- /dev/null +++ b/vendor/github.com/pion/ice/v4/tcp_mux.go @@ -0,0 +1,476 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package ice + +import ( + "encoding/binary" + "errors" + "io" + "net" + "strings" + "sync" + "time" + + "github.com/pion/logging" + "github.com/pion/stun/v3" +) + +// ErrGetTransportAddress can't convert net.Addr to underlying type (UDPAddr or TCPAddr). +var ErrGetTransportAddress = errors.New("failed to get local transport address") + +// TCPMux is allows grouping multiple TCP net.Conns and using them like UDP +// net.PacketConns. The main implementation of this is TCPMuxDefault, and this +// interface exists to allow mocking in tests. +type TCPMux interface { + io.Closer + GetConnByUfrag(ufrag string, isIPv6 bool, local net.IP) (net.PacketConn, error) + RemoveConnByUfrag(ufrag string) +} + +type ipAddr string + +// TCPMuxDefault muxes TCP net.Conns into net.PacketConns and groups them by +// Ufrag. It is a default implementation of TCPMux interface. +type TCPMuxDefault struct { + params *TCPMuxParams + closed bool + + // connsIPv4 and connsIPv6 are maps of all tcpPacketConns indexed by ufrag and local address + connsIPv4, connsIPv6 map[string]map[ipAddr]*tcpPacketConn + + mu sync.Mutex + wg sync.WaitGroup +} + +// TCPMuxParams are parameters for TCPMux. +type TCPMuxParams struct { + Listener net.Listener + Logger logging.LeveledLogger + ReadBufferSize int + + // Maximum buffer size for write op. 0 means no write buffer, the write op will block until the whole packet is written + // if the write buffer is full, the subsequent write packet will be dropped until it has enough space. + // a default 4MB is recommended. + WriteBufferSize int + + // A new established connection will be removed if the first STUN binding request is not received within this timeout, + // avoiding the client with bad network or attacker to create a lot of empty connections. + // Default 30s timeout will be used if not set. + FirstStunBindTimeout time.Duration + + // TCPMux will create connection from STUN binding request with an unknown username, if + // the connection is not used in the timeout, it will be removed to avoid resource leak / attack. + // Default 30s timeout will be used if not set. + AliveDurationForConnFromStun time.Duration +} + +// NewTCPMuxDefault creates a new instance of TCPMuxDefault. +func NewTCPMuxDefault(params TCPMuxParams) *TCPMuxDefault { + if params.Logger == nil { + params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice") + } + + if params.FirstStunBindTimeout == 0 { + params.FirstStunBindTimeout = 30 * time.Second + } + + if params.AliveDurationForConnFromStun == 0 { + params.AliveDurationForConnFromStun = 30 * time.Second + } + + mux := &TCPMuxDefault{ + params: ¶ms, + + connsIPv4: map[string]map[ipAddr]*tcpPacketConn{}, + connsIPv6: map[string]map[ipAddr]*tcpPacketConn{}, + } + + mux.wg.Add(1) + go func() { + defer mux.wg.Done() + mux.start() + }() + + return mux +} + +func (m *TCPMuxDefault) start() { + m.params.Logger.Infof("Listening TCP on %s", m.params.Listener.Addr()) + for { + conn, err := m.params.Listener.Accept() + if err != nil { + m.params.Logger.Infof("Error accepting connection: %s", err) + + return + } + + m.params.Logger.Debugf("Accepted connection from: %s to %s", conn.RemoteAddr(), conn.LocalAddr()) + + m.wg.Add(1) + go func() { + defer m.wg.Done() + m.handleConn(conn) + }() + } +} + +// LocalAddr returns the listening address of this TCPMuxDefault. +func (m *TCPMuxDefault) LocalAddr() net.Addr { + return m.params.Listener.Addr() +} + +// GetConnByUfrag retrieves an existing or creates a new net.PacketConn. +func (m *TCPMuxDefault) GetConnByUfrag(ufrag string, isIPv6 bool, local net.IP) (net.PacketConn, error) { + m.mu.Lock() + defer m.mu.Unlock() + + if m.closed { + return nil, io.ErrClosedPipe + } + + if conn, ok := m.getConn(ufrag, isIPv6, local); ok { + conn.ClearAliveTimer() + + return conn, nil + } + + return m.createConn(ufrag, isIPv6, local, false) +} + +func (m *TCPMuxDefault) createConn(ufrag string, isIPv6 bool, local net.IP, fromStun bool) (*tcpPacketConn, error) { + addr, ok := m.LocalAddr().(*net.TCPAddr) + if !ok { + return nil, ErrGetTransportAddress + } + localAddr := *addr + // Note: this is missing zone for IPv6 + localAddr.IP = local + + var alive time.Duration + if fromStun { + alive = m.params.AliveDurationForConnFromStun + } + + conn := newTCPPacketConn(tcpPacketParams{ + ReadBuffer: m.params.ReadBufferSize, + WriteBuffer: m.params.WriteBufferSize, + LocalAddr: &localAddr, + Logger: m.params.Logger, + AliveDuration: alive, + }) + + var conns map[ipAddr]*tcpPacketConn + if isIPv6 { + if conns, ok = m.connsIPv6[ufrag]; !ok { + conns = make(map[ipAddr]*tcpPacketConn) + m.connsIPv6[ufrag] = conns + } + } else { + if conns, ok = m.connsIPv4[ufrag]; !ok { + conns = make(map[ipAddr]*tcpPacketConn) + m.connsIPv4[ufrag] = conns + } + } + // Note: this is missing zone for IPv6 + connKey := ipAddr(local.String()) + conns[connKey] = conn + + m.wg.Add(1) + go func() { + defer m.wg.Done() + <-conn.CloseChannel() + m.removeConnByUfragAndLocalHost(ufrag, connKey) + }() + + return conn, nil +} + +func (m *TCPMuxDefault) closeAndLogError(closer io.Closer) { + err := closer.Close() + if err != nil { + m.params.Logger.Warnf("Error closing connection: %s", err) + } +} + +func (m *TCPMuxDefault) handleConn(conn net.Conn) { //nolint:cyclop + buf := make([]byte, 512) + + if m.params.FirstStunBindTimeout > 0 { + if err := conn.SetReadDeadline(time.Now().Add(m.params.FirstStunBindTimeout)); err != nil { + m.params.Logger.Warnf( + "Failed to set read deadline for first STUN message: %s to %s , err: %s", + conn.RemoteAddr(), + conn.LocalAddr(), + err, + ) + } + } + n, err := readStreamingPacket(conn, buf) + if err != nil { + if errors.Is(err, io.ErrShortBuffer) { + m.params.Logger.Warnf("Buffer too small for first packet from %s: %s", conn.RemoteAddr(), err) + } else { + m.params.Logger.Warnf("Error reading first packet from %s: %s", conn.RemoteAddr(), err) + } + m.closeAndLogError(conn) + + return + } + if err = conn.SetReadDeadline(time.Time{}); err != nil { + m.params.Logger.Warnf("Failed to reset read deadline from %s: %s", conn.RemoteAddr(), err) + } + + buf = buf[:n] + + msg := &stun.Message{ + Raw: make([]byte, len(buf)), + } + // Explicitly copy raw buffer so Message can own the memory. + copy(msg.Raw, buf) + if err = msg.Decode(); err != nil { + m.closeAndLogError(conn) + m.params.Logger.Warnf("Failed to handle decode ICE from %s to %s: %v", conn.RemoteAddr(), conn.LocalAddr(), err) + + return + } + + if m == nil || msg.Type.Method != stun.MethodBinding { // Not a STUN + m.closeAndLogError(conn) + m.params.Logger.Warnf("Not a STUN message from %s to %s", conn.RemoteAddr(), conn.LocalAddr()) + + return + } + + for _, attr := range msg.Attributes { + m.params.Logger.Debugf("Message attribute: %s", attr.String()) + } + + attr, err := msg.Get(stun.AttrUsername) + if err != nil { + m.closeAndLogError(conn) + m.params.Logger.Warnf( + "No Username attribute in STUN message from %s to %s", + conn.RemoteAddr(), + conn.LocalAddr(), + ) + + return + } + + ufrag := strings.Split(string(attr), ":")[0] + m.params.Logger.Debugf("Ufrag: %s", ufrag) + + host, _, err := net.SplitHostPort(conn.RemoteAddr().String()) + if err != nil { + m.closeAndLogError(conn) + m.params.Logger.Warnf( + "Failed to get host in STUN message from %s to %s", + conn.RemoteAddr(), + conn.LocalAddr(), + ) + + return + } + + isIPv6 := net.ParseIP(host).To4() == nil + + localAddr, ok := conn.LocalAddr().(*net.TCPAddr) + if !ok { + m.closeAndLogError(conn) + m.params.Logger.Warnf( + "Failed to get local tcp address in STUN message from %s to %s", + conn.RemoteAddr(), + conn.LocalAddr(), + ) + + return + } + m.mu.Lock() + + packetConn, ok := m.getConn(ufrag, isIPv6, localAddr.IP) + if !ok { + packetConn, err = m.createConn(ufrag, isIPv6, localAddr.IP, true) + if err != nil { + m.mu.Unlock() + m.closeAndLogError(conn) + m.params.Logger.Warnf( + "Failed to create packetConn for STUN message from %s to %s", + conn.RemoteAddr(), + conn.LocalAddr(), + ) + + return + } + } + m.mu.Unlock() + + if err := packetConn.AddConn(conn, buf); err != nil { + m.closeAndLogError(conn) + m.params.Logger.Warnf( + "Error adding conn to tcpPacketConn from %s to %s: %s", + conn.RemoteAddr(), + conn.LocalAddr(), + err, + ) + + return + } +} + +// Close closes the listener and waits for all goroutines to exit. +func (m *TCPMuxDefault) Close() error { + m.mu.Lock() + m.closed = true + + for _, conns := range m.connsIPv4 { + for _, conn := range conns { + m.closeAndLogError(conn) + } + } + for _, conns := range m.connsIPv6 { + for _, conn := range conns { + m.closeAndLogError(conn) + } + } + + m.connsIPv4 = map[string]map[ipAddr]*tcpPacketConn{} + m.connsIPv6 = map[string]map[ipAddr]*tcpPacketConn{} + + err := m.params.Listener.Close() + + m.mu.Unlock() + + m.wg.Wait() + + return err +} + +// RemoveConnByUfrag closes and removes a net.PacketConn by Ufrag. +func (m *TCPMuxDefault) RemoveConnByUfrag(ufrag string) { + removedConns := make([]*tcpPacketConn, 0, 4) + + // Keep lock section small to avoid deadlock with conn lock + m.mu.Lock() + if conns, ok := m.connsIPv4[ufrag]; ok { + delete(m.connsIPv4, ufrag) + for _, conn := range conns { + removedConns = append(removedConns, conn) + } + } + if conns, ok := m.connsIPv6[ufrag]; ok { + delete(m.connsIPv6, ufrag) + for _, conn := range conns { + removedConns = append(removedConns, conn) + } + } + + m.mu.Unlock() + + // Close the connections outside the critical section to avoid + // deadlocking TCP mux if (*tcpPacketConn).Close() blocks. + for _, conn := range removedConns { + m.closeAndLogError(conn) + } +} + +func (m *TCPMuxDefault) removeConnByUfragAndLocalHost(ufrag string, localIPAddr ipAddr) { + removedConns := make([]*tcpPacketConn, 0, 4) + + // Keep lock section small to avoid deadlock with conn lock + m.mu.Lock() + if conns, ok := m.connsIPv4[ufrag]; ok { + if conn, ok := conns[localIPAddr]; ok { + delete(conns, localIPAddr) + if len(conns) == 0 { + delete(m.connsIPv4, ufrag) + } + removedConns = append(removedConns, conn) + } + } + if conns, ok := m.connsIPv6[ufrag]; ok { + if conn, ok := conns[localIPAddr]; ok { + delete(conns, localIPAddr) + if len(conns) == 0 { + delete(m.connsIPv6, ufrag) + } + removedConns = append(removedConns, conn) + } + } + m.mu.Unlock() + + // Close the connections outside the critical section to avoid + // deadlocking TCP mux if (*tcpPacketConn).Close() blocks. + for _, conn := range removedConns { + m.closeAndLogError(conn) + } +} + +func (m *TCPMuxDefault) getConn(ufrag string, isIPv6 bool, local net.IP) (val *tcpPacketConn, ok bool) { + var conns map[ipAddr]*tcpPacketConn + if isIPv6 { + conns, ok = m.connsIPv6[ufrag] + } else { + conns, ok = m.connsIPv4[ufrag] + } + if conns != nil { + // Note: this is missing zone for IPv6 + connKey := ipAddr(local.String()) + val, ok = conns[connKey] + } + + return +} + +const streamingPacketHeaderLen = 2 + +// readStreamingPacket reads 1 packet from stream +// read packet bytes https://tools.ietf.org/html/rfc4571#section-2 +// 2-byte length header prepends each packet: +// +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// ----------------------------------------------------------------- +// | LENGTH | RTP or RTCP packet ... | +// ----------------------------------------------------------------- +func readStreamingPacket(conn net.Conn, buf []byte) (int, error) { + header := make([]byte, streamingPacketHeaderLen) + var bytesRead, n int + var err error + + for bytesRead < streamingPacketHeaderLen { + if n, err = conn.Read(header[bytesRead:streamingPacketHeaderLen]); err != nil { + return 0, err + } + bytesRead += n + } + + length := int(binary.BigEndian.Uint16(header)) + + if length > cap(buf) { + return length, io.ErrShortBuffer + } + + bytesRead = 0 + for bytesRead < length { + if n, err = conn.Read(buf[bytesRead:length]); err != nil { + return 0, err + } + bytesRead += n + } + + return bytesRead, nil +} + +func writeStreamingPacket(conn net.Conn, buf []byte) (int, error) { + bufCopy := make([]byte, streamingPacketHeaderLen+len(buf)) + binary.BigEndian.PutUint16(bufCopy, uint16(len(buf))) //nolint:gosec // G115 + copy(bufCopy[2:], buf) + + n, err := conn.Write(bufCopy) + if err != nil { + return 0, err + } + + return n - streamingPacketHeaderLen, nil +} diff --git a/vendor/github.com/pion/ice/v4/tcp_mux_multi.go b/vendor/github.com/pion/ice/v4/tcp_mux_multi.go new file mode 100644 index 0000000..669205f --- /dev/null +++ b/vendor/github.com/pion/ice/v4/tcp_mux_multi.go @@ -0,0 +1,86 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package ice + +import ( + "net" +) + +// AllConnsGetter allows multiple fixed TCP ports to be used, +// each of which is multiplexed like TCPMux. AllConnsGetter also acts as +// a TCPMux, in which case it will return a single connection for one +// of the ports. +type AllConnsGetter interface { + GetAllConns(ufrag string, isIPv6 bool, localIP net.IP) ([]net.PacketConn, error) +} + +// MultiTCPMuxDefault implements both TCPMux and AllConnsGetter, +// allowing users to pass multiple TCPMux instances to the ICE agent +// configuration. +type MultiTCPMuxDefault struct { + muxes []TCPMux +} + +// NewMultiTCPMuxDefault creates an instance of MultiTCPMuxDefault that +// uses the provided TCPMux instances. +func NewMultiTCPMuxDefault(muxes ...TCPMux) *MultiTCPMuxDefault { + return &MultiTCPMuxDefault{ + muxes: muxes, + } +} + +// GetConnByUfrag returns a PacketConn given the connection's ufrag, network and local address +// creates the connection if an existing one can't be found. This, unlike +// GetAllConns, will only return a single PacketConn from the first mux that was +// passed in to NewMultiTCPMuxDefault. +func (m *MultiTCPMuxDefault) GetConnByUfrag(ufrag string, isIPv6 bool, local net.IP) (net.PacketConn, error) { + // NOTE: We always use the first element here in order to maintain the + // behavior of using an existing connection if one exists. + if len(m.muxes) == 0 { + return nil, errNoTCPMuxAvailable + } + + return m.muxes[0].GetConnByUfrag(ufrag, isIPv6, local) +} + +// RemoveConnByUfrag stops and removes the muxed packet connection +// from all underlying TCPMux instances. +func (m *MultiTCPMuxDefault) RemoveConnByUfrag(ufrag string) { + for _, mux := range m.muxes { + mux.RemoveConnByUfrag(ufrag) + } +} + +// GetAllConns returns a PacketConn for each underlying TCPMux. +func (m *MultiTCPMuxDefault) GetAllConns(ufrag string, isIPv6 bool, local net.IP) ([]net.PacketConn, error) { + if len(m.muxes) == 0 { + // Make sure that we either return at least one connection or an error. + return nil, errNoTCPMuxAvailable + } + var conns []net.PacketConn + for _, mux := range m.muxes { + conn, err := mux.GetConnByUfrag(ufrag, isIPv6, local) + if err != nil { + // For now, this implementation is all or none. + return nil, err + } + if conn != nil { + conns = append(conns, conn) + } + } + + return conns, nil +} + +// Close the multi mux, no further connections could be created. +func (m *MultiTCPMuxDefault) Close() error { + var err error + for _, mux := range m.muxes { + if e := mux.Close(); e != nil { + err = e + } + } + + return err +} diff --git a/vendor/github.com/pion/ice/v4/tcp_packet_conn.go b/vendor/github.com/pion/ice/v4/tcp_packet_conn.go new file mode 100644 index 0000000..388b4d1 --- /dev/null +++ b/vendor/github.com/pion/ice/v4/tcp_packet_conn.go @@ -0,0 +1,383 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package ice + +import ( + "errors" + "fmt" + "io" + "net" + "sync" + "sync/atomic" + "time" + + "github.com/pion/logging" + "github.com/pion/transport/v4/packetio" +) + +type bufferedConn struct { + net.Conn + buf *packetio.Buffer + logger logging.LeveledLogger + closed int32 +} + +func newBufferedConn(conn net.Conn, bufSize int, logger logging.LeveledLogger) net.Conn { + buf := packetio.NewBuffer() + if bufSize > 0 { + buf.SetLimitSize(bufSize) + } + + bc := &bufferedConn{ + Conn: conn, + buf: buf, + logger: logger, + } + + go bc.writeProcess() + + return bc +} + +func (bc *bufferedConn) Write(b []byte) (int, error) { + n, err := bc.buf.Write(b) + if err != nil { + return n, err + } + + return n, nil +} + +func (bc *bufferedConn) writeProcess() { + pktBuf := make([]byte, receiveMTU) + for atomic.LoadInt32(&bc.closed) == 0 { + n, err := bc.buf.Read(pktBuf) + if errors.Is(err, io.EOF) { + return + } + + if err != nil { + bc.logger.Warnf("Failed to read from buffer: %s", err) + + continue + } + + if _, err := bc.Conn.Write(pktBuf[:n]); err != nil { + bc.logger.Warnf("Failed to write: %s", err) + + continue + } + } +} + +func (bc *bufferedConn) Close() error { + atomic.StoreInt32(&bc.closed, 1) + _ = bc.buf.Close() + + return bc.Conn.Close() +} + +type tcpPacketConn struct { + params *tcpPacketParams + + // conns is a map of net.Conns indexed by remote net.Addr.String() + conns map[string]net.Conn + + recvChan chan streamingPacket + + mu sync.Mutex + wg sync.WaitGroup + closedChan chan struct{} + closeOnce sync.Once + aliveTimer *time.Timer +} + +type streamingPacket struct { + Data []byte + RAddr net.Addr + Err error +} + +type tcpPacketParams struct { + ReadBuffer int + LocalAddr net.Addr + Logger logging.LeveledLogger + WriteBuffer int + AliveDuration time.Duration +} + +func newTCPPacketConn(params tcpPacketParams) *tcpPacketConn { + packet := &tcpPacketConn{ + params: ¶ms, + + conns: map[string]net.Conn{}, + + recvChan: make(chan streamingPacket, params.ReadBuffer), + closedChan: make(chan struct{}), + } + + if params.AliveDuration > 0 { + packet.aliveTimer = time.AfterFunc(params.AliveDuration, func() { + packet.params.Logger.Warn("close tcp packet conn by alive timeout") + _ = packet.Close() + }) + } + + return packet +} + +func (t *tcpPacketConn) ClearAliveTimer() { + t.mu.Lock() + if t.aliveTimer != nil { + t.aliveTimer.Stop() + } + t.mu.Unlock() +} + +func (t *tcpPacketConn) AddConn(conn net.Conn, firstPacketData []byte) error { + t.params.Logger.Infof( + "Added connection: %s remote %s to local %s", + conn.RemoteAddr().Network(), + conn.RemoteAddr(), + conn.LocalAddr(), + ) + + t.mu.Lock() + defer t.mu.Unlock() + + select { + case <-t.closedChan: + return io.ErrClosedPipe + default: + } + + if _, ok := t.conns[conn.RemoteAddr().String()]; ok { + return fmt.Errorf("%w: %s", errConnectionAddrAlreadyExist, conn.RemoteAddr().String()) + } + + if t.params.WriteBuffer > 0 { + conn = newBufferedConn(conn, t.params.WriteBuffer, t.params.Logger) + } + t.conns[conn.RemoteAddr().String()] = conn + + t.wg.Add(1) + go func() { + defer t.wg.Done() + if firstPacketData != nil { + select { + case <-t.closedChan: + // NOTE: recvChan can fill up and never drain in edge + // cases while closing a connection, which can cause the + // packetConn to never finish closing. Bail out early + // here to prevent that. + return + case t.recvChan <- streamingPacket{firstPacketData, conn.RemoteAddr(), nil}: + } + } + t.startReading(conn) + }() + + return nil +} + +func (t *tcpPacketConn) startReading(conn net.Conn) { + buf := make([]byte, receiveMTU) + + for { + n, err := readStreamingPacket(conn, buf) + if err != nil { + t.params.Logger.Warnf("Failed to read streaming packet: %s", err) + last := t.removeConn(conn) + // Only propagate connection closure errors if no other open connection exists. + if last || (!errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed)) { + t.handleRecv(streamingPacket{nil, conn.RemoteAddr(), err}) + } + + return + } + + data := make([]byte, n) + copy(data, buf[:n]) + + t.handleRecv(streamingPacket{data, conn.RemoteAddr(), nil}) + } +} + +func (t *tcpPacketConn) handleRecv(pkt streamingPacket) { + t.mu.Lock() + + recvChan := t.recvChan + if t.isClosed() { + recvChan = nil + } + + t.mu.Unlock() + + select { + case recvChan <- pkt: + case <-t.closedChan: + } +} + +func (t *tcpPacketConn) isClosed() bool { + select { + case <-t.closedChan: + return true + default: + return false + } +} + +// WriteTo is for passive and s-o candidates. +func (t *tcpPacketConn) ReadFrom(b []byte) (n int, rAddr net.Addr, err error) { + pkt, ok := <-t.recvChan + + if !ok { + return 0, nil, io.ErrClosedPipe + } + + if pkt.Err != nil { + return 0, pkt.RAddr, pkt.Err + } + + if cap(b) < len(pkt.Data) { + return 0, pkt.RAddr, io.ErrShortBuffer + } + + n = len(pkt.Data) + copy(b, pkt.Data[:n]) + + return n, pkt.RAddr, err +} + +// WriteTo is for active and s-o candidates. +func (t *tcpPacketConn) WriteTo(buf []byte, rAddr net.Addr) (n int, err error) { + t.mu.Lock() + conn, ok := t.conns[rAddr.String()] + t.mu.Unlock() + + if !ok { + return 0, io.ErrClosedPipe + } + + n, err = writeStreamingPacket(conn, buf) + if err != nil { + t.params.Logger.Tracef("%w %s", errWrite, rAddr) + + return n, err + } + + return n, err +} + +func (t *tcpPacketConn) closeAndLogError(closer io.Closer) { + err := closer.Close() + if err != nil { + t.params.Logger.Warnf("%v: %s", errClosingConnection, err) + } +} + +func (t *tcpPacketConn) removeConn(conn net.Conn) bool { + t.mu.Lock() + defer t.mu.Unlock() + + t.closeAndLogError(conn) + + // wait for some time to flush pending writes + _ = conn.SetWriteDeadline(time.Now().Add(5 * time.Second)) + // read deadline as well just in case + _ = conn.SetReadDeadline(time.Now().Add(5 * time.Second)) + + delete(t.conns, conn.RemoteAddr().String()) + + return len(t.conns) == 0 +} + +func (t *tcpPacketConn) Close() error { + t.mu.Lock() + + var shouldCloseRecvChan bool + t.closeOnce.Do(func() { + close(t.closedChan) + shouldCloseRecvChan = true + if t.aliveTimer != nil { + t.aliveTimer.Stop() + } + }) + + for _, conn := range t.conns { + t.closeAndLogError(conn) + + // wait for some time to flush pending writes + _ = conn.SetWriteDeadline(time.Now().Add(5 * time.Second)) + // read deadline as well just in case + _ = conn.SetReadDeadline(time.Now().Add(5 * time.Second)) + + delete(t.conns, conn.RemoteAddr().String()) + } + + t.mu.Unlock() + + t.wg.Wait() + + if shouldCloseRecvChan { + close(t.recvChan) + } + + return nil +} + +func (t *tcpPacketConn) LocalAddr() net.Addr { + return t.params.LocalAddr +} + +func (t *tcpPacketConn) SetDeadline(d time.Time) error { + t.mu.Lock() + defer t.mu.Unlock() + + var err error + for _, conn := range t.conns { + if setErr := conn.SetDeadline(d); err == nil && setErr != nil { + err = setErr + } + } + + return err +} + +func (t *tcpPacketConn) SetReadDeadline(d time.Time) error { + t.mu.Lock() + defer t.mu.Unlock() + + var err error + for _, conn := range t.conns { + if setErr := conn.SetReadDeadline(d); err == nil && setErr != nil { + err = setErr + } + } + + return err +} + +func (t *tcpPacketConn) SetWriteDeadline(d time.Time) error { + t.mu.Lock() + defer t.mu.Unlock() + + var err error + for _, conn := range t.conns { + if setErr := conn.SetWriteDeadline(d); err == nil && setErr != nil { + err = setErr + } + } + + return err +} + +func (t *tcpPacketConn) CloseChannel() <-chan struct{} { + return t.closedChan +} + +func (t *tcpPacketConn) String() string { + return fmt.Sprintf("tcpPacketConn{LocalAddr: %s}", t.params.LocalAddr) +} diff --git a/vendor/github.com/pion/ice/v4/tcptype.go b/vendor/github.com/pion/ice/v4/tcptype.go new file mode 100644 index 0000000..cf9951c --- /dev/null +++ b/vendor/github.com/pion/ice/v4/tcptype.go @@ -0,0 +1,51 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package ice + +import "strings" + +// TCPType is the type of ICE TCP candidate as described in +// https://tools.ietf.org/html/rfc6544#section-4.5 +type TCPType int + +const ( + // TCPTypeUnspecified is the default value. For example UDP candidates do not + // need this field. + TCPTypeUnspecified TCPType = iota + // TCPTypeActive is active TCP candidate, which initiates TCP connections. + TCPTypeActive + // TCPTypePassive is passive TCP candidate, only accepts TCP connections. + TCPTypePassive + // TCPTypeSimultaneousOpen is like active and passive at the same time. + TCPTypeSimultaneousOpen +) + +// NewTCPType creates a new TCPType from string. +func NewTCPType(value string) TCPType { + switch strings.ToLower(value) { + case "active": + return TCPTypeActive + case "passive": + return TCPTypePassive + case "so": + return TCPTypeSimultaneousOpen + default: + return TCPTypeUnspecified + } +} + +func (t TCPType) String() string { + switch t { + case TCPTypeUnspecified: + return "" + case TCPTypeActive: + return "active" + case TCPTypePassive: + return "passive" + case TCPTypeSimultaneousOpen: + return "so" + default: + return ErrUnknownType.Error() + } +} diff --git a/vendor/github.com/pion/ice/v4/transport.go b/vendor/github.com/pion/ice/v4/transport.go new file mode 100644 index 0000000..40ac049 --- /dev/null +++ b/vendor/github.com/pion/ice/v4/transport.go @@ -0,0 +1,240 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package ice + +import ( + "context" + "net" + "sync/atomic" + "time" + + "github.com/pion/stun/v3" +) + +// Dial connects to the remote agent, acting as the controlling ice agent. +// Dial blocks until at least one ice candidate pair has successfully connected. +func (a *Agent) Dial(ctx context.Context, remoteUfrag, remotePwd string) (*Conn, error) { + return a.connect(ctx, true, remoteUfrag, remotePwd) +} + +// Accept connects to the remote agent, acting as the controlled ice agent. +// Accept blocks until at least one ice candidate pair has successfully connected. +func (a *Agent) Accept(ctx context.Context, remoteUfrag, remotePwd string) (*Conn, error) { + return a.connect(ctx, false, remoteUfrag, remotePwd) +} + +// Conn represents the ICE connection. +// At the moment the lifetime of the Conn is equal to the Agent. +type Conn struct { + bytesReceived atomic.Uint64 + bytesSent atomic.Uint64 + agent *Agent +} + +// BytesSent returns the number of bytes sent. +func (c *Conn) BytesSent() uint64 { + return c.bytesSent.Load() +} + +// BytesReceived returns the number of bytes received. +func (c *Conn) BytesReceived() uint64 { + return c.bytesReceived.Load() +} + +func (a *Agent) connect(ctx context.Context, isControlling bool, remoteUfrag, remotePwd string) (*Conn, error) { + err := a.loop.Err() + if err != nil { + return nil, err + } + err = a.startConnectivityChecks(isControlling, remoteUfrag, remotePwd) //nolint:contextcheck + if err != nil { + return nil, err + } + + // Block until pair selected + select { + case <-a.loop.Done(): + return nil, a.loop.Err() + case <-ctx.Done(): + return nil, ErrCanceledByCaller + case <-a.onConnected: + } + + return &Conn{ + agent: a, + }, nil +} + +// Read implements the Conn Read method. +func (c *Conn) Read(p []byte) (int, error) { + err := c.agent.loop.Err() + if err != nil { + return 0, err + } + + n, err := c.agent.buf.Read(p) + c.bytesReceived.Add(uint64(n)) //nolint:gosec // G115 + + return n, err +} + +// Write implements the Conn Write method. +func (c *Conn) Write(packet []byte) (int, error) { + err := c.agent.loop.Err() + if err != nil { + return 0, err + } + + if stun.IsMessage(packet) { + return 0, errWriteSTUNMessageToIceConn + } + + pair := c.agent.getSelectedPair() + if pair == nil { + if err = c.agent.loop.Run(c.agent.loop, func(_ context.Context) { + pair = c.agent.getBestValidCandidatePair() + }); err != nil { + return 0, err + } + + if pair == nil { + return 0, err + } + } + + // Write application data via the selected pair and update stats with actual bytes written. + n, err := pair.Write(packet) + if n > 0 { + c.bytesSent.Add(uint64(n)) + pair.UpdatePacketSent(n) + } + + return n, err +} + +// GetCandidatePairsInfo returns snapshot information for all candidate pairs. +// Use the returned ID with WriteToPair() to write to a specific pair. +func (c *Conn) GetCandidatePairsInfo() []CandidatePairInfo { + var pairs []CandidatePairInfo + + err := c.agent.loop.Run(c.agent.loop, func(_ context.Context) { + pairs = make([]CandidatePairInfo, 0, len(c.agent.checklist)) + for _, cp := range c.agent.checklist { + pairs = append(pairs, CandidatePairInfo{ + ID: cp.id, + LocalCandidateType: cp.Local.Type(), + RemoteCandidateType: cp.Remote.Type(), + State: cp.state, + Nominated: cp.nominated, + CurrentRoundTripTime: time.Duration(atomic.LoadInt64(&cp.currentRoundTripTime)), + RenominationQuality: c.agent.evaluateCandidatePairQuality(cp), + }) + } + }) + if err != nil { + return nil + } + + return pairs +} + +// WriteToPair writes packet to a specific candidate pair identified by its ID. +// Returns ErrCandidatePairNotFound if the pair ID is not found. +// Returns ErrCandidatePairNotSucceeded if the pair is not in Succeeded state. +// This is useful for sending packets over alternate paths +// even if they are not nominated. +func (c *Conn) WriteToPair(pairID uint64, packet []byte) (int, error) { + if err := c.agent.loop.Err(); err != nil { + return 0, err + } + + if stun.IsMessage(packet) { + return 0, errWriteSTUNMessageToIceConn + } + + var pair *CandidatePair + var lookupErr error + + if err := c.agent.loop.Run(c.agent.loop, func(_ context.Context) { + pair = c.agent.pairsByID[pairID] + if pair == nil { + lookupErr = ErrCandidatePairNotFound + + return + } + if pair.state != CandidatePairStateSucceeded { + lookupErr = ErrCandidatePairNotSucceeded + } + }); err != nil { + return 0, err + } + + if lookupErr != nil { + return 0, lookupErr + } + + n, err := pair.Write(packet) + if n > 0 { + pair.UpdatePacketSent(n) + } + + return n, err +} + +// Close implements the Conn Close method. It is used to close +// the connection. Any calls to Read and Write will be unblocked and return an error. +func (c *Conn) Close() error { + return c.agent.Close() +} + +// LocalAddr returns the local address of the current selected pair or nil if there is none. +func (c *Conn) LocalAddr() net.Addr { + pair := c.agent.getSelectedPair() + if pair == nil { + return nil + } + + return pair.Local.addr() +} + +// RemoteAddr returns the remote address of the current selected pair or nil if there is none. +func (c *Conn) RemoteAddr() net.Addr { + pair := c.agent.getSelectedPair() + if pair == nil { + return nil + } + + return pair.Remote.addr() +} + +// SetDeadline sets both read and write deadlines on the underlying ICE connection. +func (c *Conn) SetDeadline(t time.Time) error { + if err := c.SetReadDeadline(t); err != nil { + return err + } + + return c.SetWriteDeadline(t) +} + +// SetReadDeadline sets the read deadline on the packet buffer used for application data. +func (c *Conn) SetReadDeadline(t time.Time) error { + return c.agent.buf.SetReadDeadline(t) +} + +// SetWriteDeadline sets the write deadline on the currently selected local candidate connection. +// The deadline applies to the selected candidate pair and will affect all traffic over that pair. +func (c *Conn) SetWriteDeadline(t time.Time) error { + pair := c.agent.getSelectedPair() + if pair == nil || pair.Local == nil { + return nil + } + + if d, ok := pair.Local.(interface { + setWriteDeadline(time.Time) error + }); ok { + return d.setWriteDeadline(t) + } + + return nil +} diff --git a/vendor/github.com/pion/ice/v4/udp_mux.go b/vendor/github.com/pion/ice/v4/udp_mux.go new file mode 100644 index 0000000..2077981 --- /dev/null +++ b/vendor/github.com/pion/ice/v4/udp_mux.go @@ -0,0 +1,409 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package ice + +import ( + "errors" + "io" + "net" + "net/netip" + "os" + "strings" + "sync" + + "github.com/pion/logging" + "github.com/pion/stun/v3" + "github.com/pion/transport/v4" + "github.com/pion/transport/v4/stdnet" +) + +// UDPMux allows multiple connections to go over a single UDP port. +type UDPMux interface { + io.Closer + GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) + RemoveConnByUfrag(ufrag string) + GetListenAddresses() []net.Addr +} + +// UDPMuxDefault is an implementation of the interface. +type UDPMuxDefault struct { + params UDPMuxParams + + closedChan chan struct{} + closeOnce sync.Once + + // connsIPv4 and connsIPv6 are maps of all udpMuxedConn indexed by ufrag|network|candidateType + connsIPv4, connsIPv6 map[string]*udpMuxedConn + + addressMapMu sync.RWMutex + addressMap map[ipPort]*udpMuxedConn + + // Buffer pool to recycle buffers for net.UDPAddr encodes/decodes + pool *sync.Pool + + mu sync.Mutex + + // For UDP connection listen at unspecified address + localAddrsForUnspecified []net.Addr +} + +// UDPMuxParams are parameters for UDPMux. +type UDPMuxParams struct { + Logger logging.LeveledLogger + UDPConn net.PacketConn + UDPConnString string + + // Required for gathering local addresses + // in case a un UDPConn is passed which does not + // bind to a specific local address. + Net transport.Net +} + +// NewUDPMuxDefault creates an implementation of UDPMux. +func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault { //nolint:cyclop + if params.Logger == nil { + params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice") + } + + var localAddrsForUnspecified []net.Addr + if udpAddr, ok := params.UDPConn.LocalAddr().(*net.UDPAddr); !ok { //nolint:nestif + params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", params.UDPConn.LocalAddr()) + } else if ok && udpAddr.IP.IsUnspecified() { + // For unspecified addresses, the correct behavior is to return errListenUnspecified, but + // it will break the applications that are already using unspecified UDP connection + // with UDPMuxDefault, so print a warn log and create a local address list for mux. + params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead") + var networks []NetworkType + switch { + case udpAddr.IP.To4() != nil: + networks = []NetworkType{NetworkTypeUDP4} + + case udpAddr.IP.To16() != nil: + networks = []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6} + + default: + params.Logger.Errorf("LocalAddr expected IPV4 or IPV6, got %T", params.UDPConn.LocalAddr()) + } + if len(networks) > 0 { + if params.Net == nil { + var err error + if params.Net, err = stdnet.NewNet(); err != nil { + params.Logger.Errorf("Failed to get create network: %v", err) + } + } + + _, addrs, err := localInterfaces(params.Net, nil, nil, networks, true) + if err == nil { + localAddrsForUnspecified = make([]net.Addr, len(addrs)) + for i, addr := range addrs { + localAddrsForUnspecified[i] = &net.UDPAddr{ + IP: addr.addr.AsSlice(), + Port: udpAddr.Port, + Zone: addr.addr.Zone(), + } + } + } else { + params.Logger.Errorf("Failed to get local interfaces for unspecified addr: %v", err) + } + } + } + params.UDPConnString = params.UDPConn.LocalAddr().String() + + mux := &UDPMuxDefault{ + addressMap: map[ipPort]*udpMuxedConn{}, + params: params, + connsIPv4: make(map[string]*udpMuxedConn), + connsIPv6: make(map[string]*udpMuxedConn), + closedChan: make(chan struct{}, 1), + pool: &sync.Pool{ + New: func() any { + // Big enough buffer to fit both packet and address + return newBufferHolder(receiveMTU) + }, + }, + localAddrsForUnspecified: localAddrsForUnspecified, + } + + go mux.connWorker() + + return mux +} + +// LocalAddr returns the listening address of this UDPMuxDefault. +func (m *UDPMuxDefault) LocalAddr() net.Addr { + return m.params.UDPConn.LocalAddr() +} + +// GetListenAddresses returns the list of addresses that this mux is listening on. +func (m *UDPMuxDefault) GetListenAddresses() []net.Addr { + if len(m.localAddrsForUnspecified) > 0 { + return m.localAddrsForUnspecified + } + + return []net.Addr{m.LocalAddr()} +} + +// GetConn returns a PacketConn given the connection's ufrag and network address. +// creates the connection if an existing one can't be found. +func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) { + // don't check addr for mux using unspecified address + if len(m.localAddrsForUnspecified) == 0 && m.params.UDPConnString != addr.String() { + return nil, errInvalidAddress + } + + var isIPv6 bool + if udpAddr, _ := addr.(*net.UDPAddr); udpAddr != nil && udpAddr.IP.To4() == nil { + isIPv6 = true + } + m.mu.Lock() + defer m.mu.Unlock() + + if m.IsClosed() { + return nil, io.ErrClosedPipe + } + + if conn, ok := m.getConn(ufrag, isIPv6); ok { + return conn, nil + } + + c := m.createMuxedConn(ufrag) + go func() { + <-c.CloseChannel() + m.RemoveConnByUfrag(ufrag) + }() + + if isIPv6 { + m.connsIPv6[ufrag] = c + } else { + m.connsIPv4[ufrag] = c + } + + return c, nil +} + +// RemoveConnByUfrag stops and removes the muxed packet connection. +func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) { + removedConns := make([]*udpMuxedConn, 0, 2) + + // Keep lock section small to avoid deadlock with conn lock. + m.mu.Lock() + if c, ok := m.connsIPv4[ufrag]; ok { + delete(m.connsIPv4, ufrag) + removedConns = append(removedConns, c) + } + if c, ok := m.connsIPv6[ufrag]; ok { + delete(m.connsIPv6, ufrag) + removedConns = append(removedConns, c) + } + m.mu.Unlock() + + if len(removedConns) == 0 { + // No need to lock if no connection was found. + return + } + + m.addressMapMu.Lock() + defer m.addressMapMu.Unlock() + + for _, c := range removedConns { + addresses := c.getAddresses() + for _, addr := range addresses { + delete(m.addressMap, addr) + } + } +} + +// IsClosed returns true if the mux had been closed. +func (m *UDPMuxDefault) IsClosed() bool { + select { + case <-m.closedChan: + return true + default: + return false + } +} + +// Close the mux, no further connections could be created. +func (m *UDPMuxDefault) Close() error { + var err error + m.closeOnce.Do(func() { + m.mu.Lock() + defer m.mu.Unlock() + + for _, c := range m.connsIPv4 { + _ = c.Close() + } + for _, c := range m.connsIPv6 { + _ = c.Close() + } + + m.connsIPv4 = make(map[string]*udpMuxedConn) + m.connsIPv6 = make(map[string]*udpMuxedConn) + + close(m.closedChan) + + _ = m.params.UDPConn.Close() + }) + + return err +} + +func (m *UDPMuxDefault) writeTo(buf []byte, rAddr net.Addr) (n int, err error) { + return m.params.UDPConn.WriteTo(buf, rAddr) +} + +func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr ipPort) { + if m.IsClosed() { + return + } + + m.addressMapMu.Lock() + defer m.addressMapMu.Unlock() + + existing, ok := m.addressMap[addr] + if ok { + existing.removeAddress(addr) + } + m.addressMap[addr] = conn + + m.params.Logger.Debugf("Registered %s for %s", addr.addr.String(), conn.params.Key) +} + +func (m *UDPMuxDefault) createMuxedConn(key string) *udpMuxedConn { + c := newUDPMuxedConn(&udpMuxedConnParams{ + Mux: m, + Key: key, + AddrPool: m.pool, + LocalAddr: m.LocalAddr(), + Logger: m.params.Logger, + }) + + return c +} + +func (m *UDPMuxDefault) connWorker() { //nolint:cyclop + logger := m.params.Logger + + defer func() { + _ = m.Close() + }() + + buf := make([]byte, receiveMTU) + for { + n, addr, err := m.params.UDPConn.ReadFrom(buf) + if m.IsClosed() { + return + } else if err != nil { + if os.IsTimeout(err) { + continue + } else if !errors.Is(err, io.EOF) { + logger.Errorf("Failed to read UDP packet: %v", err) + } + + return + } + + netUDPAddr, ok := addr.(*net.UDPAddr) + if !ok { + logger.Errorf("Underlying PacketConn did not return a UDPAddr") + + return + } + udpAddr, err := newIPPort(netUDPAddr.IP, netUDPAddr.Zone, uint16(netUDPAddr.Port)) //nolint:gosec + if err != nil { + logger.Errorf("Failed to create a new IP/Port host pair") + + return + } + + // If we have already seen this address dispatch to the appropriate destination + m.addressMapMu.Lock() + destinationConn := m.addressMap[udpAddr] + m.addressMapMu.Unlock() + + // If we haven't seen this address before but is a STUN packet lookup by ufrag + if destinationConn == nil && stun.IsMessage(buf[:n]) { + msg := &stun.Message{ + Raw: append([]byte{}, buf[:n]...), + } + + if err = msg.Decode(); err != nil { + m.params.Logger.Warnf("Failed to handle decode ICE from %s: %v", addr.String(), err) + + continue + } + + attr, stunAttrErr := msg.Get(stun.AttrUsername) + if stunAttrErr != nil { + m.params.Logger.Warnf("No Username attribute in STUN message from %s", addr.String()) + + continue + } + + ufrag := strings.Split(string(attr), ":")[0] + isIPv6 := netUDPAddr.IP.To4() == nil + + m.mu.Lock() + destinationConn, _ = m.getConn(ufrag, isIPv6) + m.mu.Unlock() + } + + if destinationConn == nil { + m.params.Logger.Tracef("Dropping packet from %s, addr: %s", udpAddr.addr, addr) + + continue + } + + if err = destinationConn.writePacket(buf[:n], netUDPAddr); err != nil { + m.params.Logger.Errorf("Failed to write packet: %v", err) + } + } +} + +func (m *UDPMuxDefault) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, ok bool) { + if isIPv6 { + val, ok = m.connsIPv6[ufrag] + } else { + val, ok = m.connsIPv4[ufrag] + } + + return +} + +type bufferHolder struct { + next *bufferHolder + buf []byte + addr *net.UDPAddr +} + +func newBufferHolder(size int) *bufferHolder { + return &bufferHolder{ + buf: make([]byte, size), + } +} + +func (b *bufferHolder) reset() { + b.next = nil + b.addr = nil +} + +type ipPort struct { + addr netip.Addr + port uint16 +} + +// newIPPort create a custom type of address based on netip.Addr and +// port. The underlying ip address passed is converted to IPv6 format +// to simplify ip address handling. +func newIPPort(ip net.IP, zone string, port uint16) (ipPort, error) { + n, ok := netip.AddrFromSlice(ip.To16()) + if !ok { + return ipPort{}, errInvalidIPAddress + } + + return ipPort{ + addr: n.WithZone(zone), + port: port, + }, nil +} diff --git a/vendor/github.com/pion/ice/v4/udp_mux_multi.go b/vendor/github.com/pion/ice/v4/udp_mux_multi.go new file mode 100644 index 0000000..04a90b1 --- /dev/null +++ b/vendor/github.com/pion/ice/v4/udp_mux_multi.go @@ -0,0 +1,238 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package ice + +import ( + "fmt" + "net" + + "github.com/pion/logging" + "github.com/pion/transport/v4" + "github.com/pion/transport/v4/stdnet" +) + +// MultiUDPMuxDefault implements both UDPMux and AllConnsGetter, +// allowing users to pass multiple UDPMux instances to the ICE agent +// configuration. +type MultiUDPMuxDefault struct { + muxes []UDPMux + localAddrToMux map[string]UDPMux +} + +// NewMultiUDPMuxDefault creates an instance of MultiUDPMuxDefault that +// uses the provided UDPMux instances. +func NewMultiUDPMuxDefault(muxes ...UDPMux) *MultiUDPMuxDefault { + addrToMux := make(map[string]UDPMux) + for _, mux := range muxes { + for _, addr := range mux.GetListenAddresses() { + addrToMux[addr.String()] = mux + } + } + + return &MultiUDPMuxDefault{ + muxes: muxes, + localAddrToMux: addrToMux, + } +} + +// GetConn returns a PacketConn given the connection's ufrag and network +// creates the connection if an existing one can't be found. +func (m *MultiUDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) { + mux, ok := m.localAddrToMux[addr.String()] + if !ok { + return nil, errNoUDPMuxAvailable + } + + return mux.GetConn(ufrag, addr) +} + +// RemoveConnByUfrag stops and removes the muxed packet connection +// from all underlying UDPMux instances. +func (m *MultiUDPMuxDefault) RemoveConnByUfrag(ufrag string) { + for _, mux := range m.muxes { + mux.RemoveConnByUfrag(ufrag) + } +} + +// Close the multi mux, no further connections could be created. +func (m *MultiUDPMuxDefault) Close() error { + var err error + for _, mux := range m.muxes { + if e := mux.Close(); e != nil { + err = e + } + } + + return err +} + +// GetListenAddresses returns the list of addresses that this mux is listening on. +func (m *MultiUDPMuxDefault) GetListenAddresses() []net.Addr { + addrs := make([]net.Addr, 0, len(m.localAddrToMux)) + for _, mux := range m.muxes { + addrs = append(addrs, mux.GetListenAddresses()...) + } + + return addrs +} + +// NewMultiUDPMuxFromPort creates an instance of MultiUDPMuxDefault that +// listen all interfaces on the provided port. +func NewMultiUDPMuxFromPort(port int, opts ...UDPMuxFromPortOption) (*MultiUDPMuxDefault, error) { //nolint:cyclop + params := multiUDPMuxFromPortParam{ + networks: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6}, + } + for _, opt := range opts { + opt.apply(¶ms) + } + + if params.net == nil { + var err error + if params.net, err = stdnet.NewNet(); err != nil { + return nil, fmt.Errorf("failed to get create network: %w", err) + } + } + + _, addrs, err := localInterfaces(params.net, params.ifFilter, params.ipFilter, params.networks, params.includeLoopback) + if err != nil { + return nil, err + } + + conns := make([]net.PacketConn, 0, len(addrs)) + for _, addr := range addrs { + conn, listenErr := params.net.ListenUDP("udp", &net.UDPAddr{ + IP: addr.addr.AsSlice(), + Port: port, + Zone: addr.addr.Zone(), + }) + if listenErr != nil { + err = listenErr + + break + } + if params.readBufferSize > 0 { + _ = conn.SetReadBuffer(params.readBufferSize) + } + if params.writeBufferSize > 0 { + _ = conn.SetWriteBuffer(params.writeBufferSize) + } + conns = append(conns, conn) + } + + if err != nil { + for _, conn := range conns { + _ = conn.Close() + } + + return nil, err + } + + muxes := make([]UDPMux, 0, len(conns)) + for _, conn := range conns { + mux := NewUDPMuxDefault(UDPMuxParams{ + Logger: params.logger, + UDPConn: conn, + Net: params.net, + }) + muxes = append(muxes, mux) + } + + return NewMultiUDPMuxDefault(muxes...), nil +} + +// UDPMuxFromPortOption provide options for NewMultiUDPMuxFromPort. +type UDPMuxFromPortOption interface { + apply(*multiUDPMuxFromPortParam) +} + +type multiUDPMuxFromPortParam struct { + ifFilter func(string) (keep bool) + ipFilter func(ip net.IP) (keep bool) + networks []NetworkType + readBufferSize int + writeBufferSize int + logger logging.LeveledLogger + includeLoopback bool + net transport.Net +} + +type udpMuxFromPortOption struct { + f func(*multiUDPMuxFromPortParam) +} + +func (o *udpMuxFromPortOption) apply(p *multiUDPMuxFromPortParam) { + o.f(p) +} + +// UDPMuxFromPortWithInterfaceFilter set the filter to filter out interfaces that should not be used. +func UDPMuxFromPortWithInterfaceFilter(f func(string) (keep bool)) UDPMuxFromPortOption { + return &udpMuxFromPortOption{ + f: func(p *multiUDPMuxFromPortParam) { + p.ifFilter = f + }, + } +} + +// UDPMuxFromPortWithIPFilter set the filter to filter out IP addresses that should not be used. +func UDPMuxFromPortWithIPFilter(f func(ip net.IP) (keep bool)) UDPMuxFromPortOption { + return &udpMuxFromPortOption{ + f: func(p *multiUDPMuxFromPortParam) { + p.ipFilter = f + }, + } +} + +// UDPMuxFromPortWithNetworks set the networks that should be used. default is both IPv4 and IPv6. +func UDPMuxFromPortWithNetworks(networks ...NetworkType) UDPMuxFromPortOption { + return &udpMuxFromPortOption{ + f: func(p *multiUDPMuxFromPortParam) { + p.networks = networks + }, + } +} + +// UDPMuxFromPortWithReadBufferSize set the UDP connection read buffer size. +func UDPMuxFromPortWithReadBufferSize(size int) UDPMuxFromPortOption { + return &udpMuxFromPortOption{ + f: func(p *multiUDPMuxFromPortParam) { + p.readBufferSize = size + }, + } +} + +// UDPMuxFromPortWithWriteBufferSize set the UDP connection write buffer size. +func UDPMuxFromPortWithWriteBufferSize(size int) UDPMuxFromPortOption { + return &udpMuxFromPortOption{ + f: func(p *multiUDPMuxFromPortParam) { + p.writeBufferSize = size + }, + } +} + +// UDPMuxFromPortWithLogger set the logger for the created UDPMux. +func UDPMuxFromPortWithLogger(logger logging.LeveledLogger) UDPMuxFromPortOption { + return &udpMuxFromPortOption{ + f: func(p *multiUDPMuxFromPortParam) { + p.logger = logger + }, + } +} + +// UDPMuxFromPortWithLoopback set loopback interface should be included. +func UDPMuxFromPortWithLoopback() UDPMuxFromPortOption { + return &udpMuxFromPortOption{ + f: func(p *multiUDPMuxFromPortParam) { + p.includeLoopback = true + }, + } +} + +// UDPMuxFromPortWithNet sets the network transport to use. +func UDPMuxFromPortWithNet(n transport.Net) UDPMuxFromPortOption { + return &udpMuxFromPortOption{ + f: func(p *multiUDPMuxFromPortParam) { + p.net = n + }, + } +} diff --git a/vendor/github.com/pion/ice/v4/udp_mux_universal.go b/vendor/github.com/pion/ice/v4/udp_mux_universal.go new file mode 100644 index 0000000..54d0aaa --- /dev/null +++ b/vendor/github.com/pion/ice/v4/udp_mux_universal.go @@ -0,0 +1,281 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package ice + +import ( + "fmt" + "net" + "time" + + "github.com/pion/logging" + "github.com/pion/stun/v3" + "github.com/pion/transport/v4" +) + +// UniversalUDPMux allows multiple connections to go over a single UDP port for +// host, server reflexive and relayed candidates. +// Actual connection muxing is happening in the UDPMux. +type UniversalUDPMux interface { + UDPMux + GetXORMappedAddr(stunAddr net.Addr, deadline time.Duration) (*stun.XORMappedAddress, error) + GetRelayedAddr(turnAddr net.Addr, deadline time.Duration) (*net.Addr, error) + GetConnForURL(ufrag string, url string, addr net.Addr) (net.PacketConn, error) +} + +// UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn overriding ReadFrom. +// It the passes packets to the UDPMux that does the actual connection muxing. +type UniversalUDPMuxDefault struct { + *UDPMuxDefault + params UniversalUDPMuxParams + + // Since we have a shared socket, for srflx candidates it makes sense + // to have a shared mapped address across all the agents + // stun.XORMappedAddress indexed by the STUN server addr + xorMappedMap map[string]*xorMapped +} + +// UniversalUDPMuxParams are parameters for UniversalUDPMux server reflexive. +type UniversalUDPMuxParams struct { + Logger logging.LeveledLogger + UDPConn net.PacketConn + XORMappedAddrCacheTTL time.Duration + Net transport.Net +} + +// NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux. +func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDefault { + if params.Logger == nil { + params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice") + } + if params.XORMappedAddrCacheTTL == 0 { + params.XORMappedAddrCacheTTL = time.Second * 25 + } + + mux := &UniversalUDPMuxDefault{ + params: params, + xorMappedMap: make(map[string]*xorMapped), + } + + // Wrap UDP connection, process server reflexive messages + // before they are passed to the UDPMux connection handler (connWorker) + mux.params.UDPConn = &udpConn{ + PacketConn: params.UDPConn, + mux: mux, + logger: params.Logger, + } + + // Embed UDPMux + udpMuxParams := UDPMuxParams{ + Logger: params.Logger, + UDPConn: mux.params.UDPConn, + Net: mux.params.Net, + } + mux.UDPMuxDefault = NewUDPMuxDefault(udpMuxParams) + + return mux +} + +// udpConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets. +type udpConn struct { + net.PacketConn + mux *UniversalUDPMuxDefault + logger logging.LeveledLogger +} + +// GetRelayedAddr creates relayed connection to the given TURN service and returns the relayed addr. +// Not implemented yet. +func (m *UniversalUDPMuxDefault) GetRelayedAddr(net.Addr, time.Duration) (*net.Addr, error) { + return nil, errNotImplemented +} + +// GetConnForURL add uniques to the muxed connection by concatenating ufrag and URL +// (e.g. STUN URL) to be able to support multiple STUN/TURN servers +// and return a unique connection per server. +func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, addr net.Addr) (net.PacketConn, error) { + return m.UDPMuxDefault.GetConn(fmt.Sprintf("%s%s", ufrag, url), addr) +} + +// ReadFrom is called by UDPMux connWorker and handles packets coming from the STUN server discovering a mapped address. +// It passes processed packets further to the UDPMux (maybe this is not really necessary). +func (c *udpConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + n, addr, err = c.PacketConn.ReadFrom(p) + if err != nil { + return n, addr, err + } + + if stun.IsMessage(p[:n]) { //nolint:nestif + msg := &stun.Message{ + Raw: append([]byte{}, p[:n]...), + } + + if err = msg.Decode(); err != nil { + c.logger.Warnf("Failed to handle decode ICE from %s: %v", addr.String(), err) + + return n, addr, nil + } + + udpAddr, ok := addr.(*net.UDPAddr) + if !ok { + // Message about this err will be logged in the UDPMux + return n, addr, err + } + + if c.mux.isXORMappedResponse(msg, udpAddr.String()) { + err = c.mux.handleXORMappedResponse(udpAddr, msg) + if err != nil { + c.logger.Debugf("%w: %v", errGetXorMappedAddrResponse, err) + err = nil + } + + return n, addr, err + } + } + + return n, addr, err +} + +// isXORMappedResponse indicates whether the message is a XORMappedAddress and is coming from the known STUN server. +func (m *UniversalUDPMuxDefault) isXORMappedResponse(msg *stun.Message, stunAddr string) bool { + m.mu.Lock() + defer m.mu.Unlock() + // Check first if it is a STUN server address, + // because remote peer can also send similar messages but as a BindingSuccess. + _, ok := m.xorMappedMap[stunAddr] + _, err := msg.Get(stun.AttrXORMappedAddress) + + return err == nil && ok +} + +// handleXORMappedResponse parses response from the STUN server, extracts XORMappedAddress attribute. +// and set the mapped address for the server. +func (m *UniversalUDPMuxDefault) handleXORMappedResponse(stunAddr *net.UDPAddr, msg *stun.Message) error { + m.mu.Lock() + defer m.mu.Unlock() + + mappedAddr, ok := m.xorMappedMap[stunAddr.String()] + if !ok { + return errNoXorAddrMapping + } + + var addr stun.XORMappedAddress + if err := addr.GetFrom(msg); err != nil { + return err + } + + m.xorMappedMap[stunAddr.String()] = mappedAddr + mappedAddr.SetAddr(&addr) + + return nil +} + +// GetXORMappedAddr returns *stun.XORMappedAddress if already present for a given STUN server. +// Makes a STUN binding request to discover mapped address otherwise. +// Blocks until the stun.XORMappedAddress has been discovered or deadline. +// Method is safe for concurrent use. +func (m *UniversalUDPMuxDefault) GetXORMappedAddr( + serverAddr net.Addr, + deadline time.Duration, +) (*stun.XORMappedAddress, error) { + m.mu.Lock() + mappedAddr, ok := m.xorMappedMap[serverAddr.String()] + // If we already have a mapping for this STUN server (address already received) + // and if it is not too old we return it without making a new request to STUN server + if ok { + if mappedAddr.expired() { + mappedAddr.closeWaiters() + delete(m.xorMappedMap, serverAddr.String()) + ok = false + } else if mappedAddr.pending() { + ok = false + } + } + m.mu.Unlock() + if ok { + return mappedAddr.addr, nil + } + + // Otherwise, make a STUN request to discover the address + // or wait for already sent request to complete + waitAddrReceived, err := m.writeSTUN(serverAddr) + if err != nil { + return nil, fmt.Errorf("%w: %s", errWriteSTUNMessage, err) //nolint:errorlint + } + + // Block until response was handled by the connWorker routine and XORMappedAddress was updated + select { + case <-waitAddrReceived: + // When channel closed, addr was obtained + m.mu.Lock() + mappedAddr := *m.xorMappedMap[serverAddr.String()] + m.mu.Unlock() + if mappedAddr.addr == nil { + return nil, errNoXorAddrMapping + } + + return mappedAddr.addr, nil + case <-time.After(deadline): + return nil, errXORMappedAddrTimeout + } +} + +// writeSTUN sends a STUN request via UDP conn. +// +// The returned channel is closed when the STUN response has been received. +// Method is safe for concurrent use. +func (m *UniversalUDPMuxDefault) writeSTUN(serverAddr net.Addr) (chan struct{}, error) { + m.mu.Lock() + defer m.mu.Unlock() + + // If record present in the map, we already sent a STUN request, + // just wait when waitAddrReceived will be closed + addrMap, ok := m.xorMappedMap[serverAddr.String()] + if !ok { + addrMap = &xorMapped{ + expiresAt: time.Now().Add(m.params.XORMappedAddrCacheTTL), + waitAddrReceived: make(chan struct{}), + } + m.xorMappedMap[serverAddr.String()] = addrMap + } + + req, err := stun.Build(stun.BindingRequest, stun.TransactionID) + if err != nil { + return nil, err + } + + if _, err = m.params.UDPConn.WriteTo(req.Raw, serverAddr); err != nil { + return nil, err + } + + return addrMap.waitAddrReceived, nil +} + +type xorMapped struct { + addr *stun.XORMappedAddress + waitAddrReceived chan struct{} + expiresAt time.Time +} + +func (a *xorMapped) closeWaiters() { + select { + case <-a.waitAddrReceived: + // Notify was close, ok, that means we received duplicate response just exit + break + default: + // Notify tha twe have a new addr + close(a.waitAddrReceived) + } +} + +func (a *xorMapped) pending() bool { + return a.addr == nil +} + +func (a *xorMapped) expired() bool { + return a.expiresAt.Before(time.Now()) +} + +func (a *xorMapped) SetAddr(addr *stun.XORMappedAddress) { + a.addr = addr + a.closeWaiters() +} diff --git a/vendor/github.com/pion/ice/v4/udp_muxed_conn.go b/vendor/github.com/pion/ice/v4/udp_muxed_conn.go new file mode 100644 index 0000000..4f4fd41 --- /dev/null +++ b/vendor/github.com/pion/ice/v4/udp_muxed_conn.go @@ -0,0 +1,251 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package ice + +import ( + "io" + "net" + "slices" + "sync" + "time" + + "github.com/pion/logging" +) + +type udpMuxedConnState int + +const ( + udpMuxedConnOpen udpMuxedConnState = iota + udpMuxedConnWaiting + udpMuxedConnClosed +) + +type udpMuxedConnParams struct { + Mux *UDPMuxDefault + AddrPool *sync.Pool + Key string + LocalAddr net.Addr + Logger logging.LeveledLogger +} + +// udpMuxedConn represents a logical packet conn for a single remote as identified by ufrag. +type udpMuxedConn struct { + params *udpMuxedConnParams + // Remote addresses that we have sent to on this conn + addresses []ipPort + + // FIFO queue holding incoming packets + bufHead, bufTail *bufferHolder + notify chan struct{} + closedChan chan struct{} + state udpMuxedConnState + mu sync.Mutex +} + +func newUDPMuxedConn(params *udpMuxedConnParams) *udpMuxedConn { + return &udpMuxedConn{ + params: params, + notify: make(chan struct{}, 1), + closedChan: make(chan struct{}), + } +} + +func (c *udpMuxedConn) ReadFrom(b []byte) (n int, rAddr net.Addr, err error) { + for { + c.mu.Lock() + if c.bufTail != nil { + pkt := c.bufTail + c.bufTail = pkt.next + + if pkt == c.bufHead { + c.bufHead = nil + } + c.mu.Unlock() + + if len(b) < len(pkt.buf) { + err = io.ErrShortBuffer + } else { + n = copy(b, pkt.buf) + rAddr = pkt.addr + } + + pkt.reset() + c.params.AddrPool.Put(pkt) + + return n, rAddr, err + } + + if c.state == udpMuxedConnClosed { + c.mu.Unlock() + + return 0, nil, io.EOF + } + + c.state = udpMuxedConnWaiting + c.mu.Unlock() + + select { + case <-c.notify: + case <-c.closedChan: + return 0, nil, io.EOF + } + } +} + +func (c *udpMuxedConn) WriteTo(buf []byte, rAddr net.Addr) (n int, err error) { + if c.isClosed() { + return 0, io.ErrClosedPipe + } + // Each time we write to a new address, we'll register it with the mux + netUDPAddr, ok := rAddr.(*net.UDPAddr) + if !ok { + return 0, errFailedToCastUDPAddr + } + + port := netUDPAddr.Port + if port < 0 || port > 0xFFFF { + return 0, ErrPort + } + ipAndPort, err := newIPPort(netUDPAddr.IP, netUDPAddr.Zone, uint16(port)) + if err != nil { + return 0, err + } + if !c.containsAddress(ipAndPort) { + c.addAddress(ipAndPort) + } + + return c.params.Mux.writeTo(buf, rAddr) +} + +func (c *udpMuxedConn) LocalAddr() net.Addr { + return c.params.LocalAddr +} + +func (c *udpMuxedConn) SetDeadline(time.Time) error { + return nil +} + +func (c *udpMuxedConn) SetReadDeadline(time.Time) error { + return nil +} + +func (c *udpMuxedConn) SetWriteDeadline(time.Time) error { + return nil +} + +func (c *udpMuxedConn) CloseChannel() <-chan struct{} { + return c.closedChan +} + +func (c *udpMuxedConn) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + if c.state != udpMuxedConnClosed { + for pkt := c.bufTail; pkt != nil; { + next := pkt.next + + pkt.reset() + c.params.AddrPool.Put(pkt) + + pkt = next + } + c.bufHead = nil + c.bufTail = nil + + c.state = udpMuxedConnClosed + close(c.closedChan) + } + + return nil +} + +func (c *udpMuxedConn) isClosed() bool { + c.mu.Lock() + defer c.mu.Unlock() + + return c.state == udpMuxedConnClosed +} + +func (c *udpMuxedConn) getAddresses() []ipPort { + c.mu.Lock() + defer c.mu.Unlock() + addresses := make([]ipPort, len(c.addresses)) + copy(addresses, c.addresses) + + return addresses +} + +func (c *udpMuxedConn) addAddress(addr ipPort) { + c.mu.Lock() + c.addresses = append(c.addresses, addr) + c.mu.Unlock() + + // Map it on mux + c.params.Mux.registerConnForAddress(c, addr) +} + +func (c *udpMuxedConn) removeAddress(addr ipPort) { + c.mu.Lock() + defer c.mu.Unlock() + + newAddresses := make([]ipPort, 0, len(c.addresses)) + for _, a := range c.addresses { + if a != addr { + newAddresses = append(newAddresses, a) + } + } + + c.addresses = newAddresses +} + +func (c *udpMuxedConn) containsAddress(addr ipPort) bool { + c.mu.Lock() + defer c.mu.Unlock() + + return slices.Contains(c.addresses, addr) +} + +func (c *udpMuxedConn) writePacket(data []byte, addr *net.UDPAddr) error { + pkt := c.params.AddrPool.Get().(*bufferHolder) //nolint:forcetypeassert + if cap(pkt.buf) < len(data) { + c.params.AddrPool.Put(pkt) + + return io.ErrShortBuffer + } + + pkt.buf = append(pkt.buf[:0], data...) + pkt.addr = addr + + c.mu.Lock() + if c.state == udpMuxedConnClosed { + c.mu.Unlock() + + pkt.reset() + c.params.AddrPool.Put(pkt) + + return io.ErrClosedPipe + } + + if c.bufHead != nil { + c.bufHead.next = pkt + } + c.bufHead = pkt + + if c.bufTail == nil { + c.bufTail = pkt + } + + state := c.state + c.state = udpMuxedConnOpen + c.mu.Unlock() + + if state == udpMuxedConnWaiting { + select { + case c.notify <- struct{}{}: + default: + } + } + + return nil +} diff --git a/vendor/github.com/pion/ice/v4/url.go b/vendor/github.com/pion/ice/v4/url.go new file mode 100644 index 0000000..f5b0205 --- /dev/null +++ b/vendor/github.com/pion/ice/v4/url.go @@ -0,0 +1,82 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package ice + +import "github.com/pion/stun/v3" + +type ( + // URL represents a STUN (rfc7064) or TURN (rfc7065) URI. + // + // Deprecated: Please use pion/stun.URI. + URL = stun.URI + + // ProtoType indicates the transport protocol type that is used in the ice.URL + // structure. + // + // Deprecated: TPlease use pion/stun.ProtoType. + ProtoType = stun.ProtoType + + // SchemeType indicates the type of server used in the ice.URL structure. + // + // Deprecated: Please use pion/stun.SchemeType. + SchemeType = stun.SchemeType +) + +const ( + // SchemeTypeSTUN indicates the URL represents a STUN server. + // + // Deprecated: Please use pion/stun.SchemeTypeSTUN. + SchemeTypeSTUN = stun.SchemeTypeSTUN + + // SchemeTypeSTUNS indicates the URL represents a STUNS (secure) server. + // + // Deprecated: Please use pion/stun.SchemeTypeSTUNS. + SchemeTypeSTUNS = stun.SchemeTypeSTUNS + + // SchemeTypeTURN indicates the URL represents a TURN server. + // + // Deprecated: Please use pion/stun.SchemeTypeTURN. + SchemeTypeTURN = stun.SchemeTypeTURN + + // SchemeTypeTURNS indicates the URL represents a TURNS (secure) server. + // + // Deprecated: Please use pion/stun.SchemeTypeTURNS. + SchemeTypeTURNS = stun.SchemeTypeTURNS +) + +const ( + // ProtoTypeUDP indicates the URL uses a UDP transport. + // + // Deprecated: Please use pion/stun.ProtoTypeUDP. + ProtoTypeUDP = stun.ProtoTypeUDP + + // ProtoTypeTCP indicates the URL uses a TCP transport. + // + // Deprecated: Please use pion/stun.ProtoTypeTCP. + ProtoTypeTCP = stun.ProtoTypeTCP +) + +// Unknown represents and unknown ProtoType or SchemeType. +// +// Deprecated: Please use pion/stun.SchemeTypeUnknown or pion/stun.ProtoTypeUnknown. +const Unknown = 0 + +// ParseURL parses a STUN or TURN urls following the ABNF syntax described in. +// https://tools.ietf.org/html/rfc7064 and https://tools.ietf.org/html/rfc7065 +// respectively. +// +// Deprecated: Please use pion/stun.ParseURI. +var ParseURL = stun.ParseURI //nolint:gochecknoglobals + +// NewSchemeType defines a procedure for creating a new SchemeType from a raw. +// string naming the scheme type. +// +// Deprecated: Please use pion/stun.NewSchemeType. +var NewSchemeType = stun.NewSchemeType //nolint:gochecknoglobals + +// NewProtoType defines a procedure for creating a new ProtoType from a raw. +// string naming the transport protocol type. +// +// Deprecated: Please use pion/stun.NewProtoType. +var NewProtoType = stun.NewProtoType //nolint:gochecknoglobals diff --git a/vendor/github.com/pion/ice/v4/usecandidate.go b/vendor/github.com/pion/ice/v4/usecandidate.go new file mode 100644 index 0000000..f098792 --- /dev/null +++ b/vendor/github.com/pion/ice/v4/usecandidate.go @@ -0,0 +1,28 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package ice + +import "github.com/pion/stun/v3" + +// UseCandidateAttr represents USE-CANDIDATE attribute. +type UseCandidateAttr struct{} + +// AddTo adds USE-CANDIDATE attribute to message. +func (UseCandidateAttr) AddTo(m *stun.Message) error { + m.Add(stun.AttrUseCandidate, nil) + + return nil +} + +// IsSet returns true if USE-CANDIDATE attribute is set. +func (UseCandidateAttr) IsSet(m *stun.Message) bool { + _, err := m.Get(stun.AttrUseCandidate) + + return err == nil +} + +// UseCandidate is shorthand for UseCandidateAttr. +func UseCandidate() UseCandidateAttr { + return UseCandidateAttr{} +} diff --git a/vendor/github.com/pion/interceptor/.gitignore b/vendor/github.com/pion/interceptor/.gitignore new file mode 100644 index 0000000..2394557 --- /dev/null +++ b/vendor/github.com/pion/interceptor/.gitignore @@ -0,0 +1,28 @@ +# SPDX-FileCopyrightText: 2026 The Pion community +# SPDX-License-Identifier: MIT + +### JetBrains IDE ### +##################### +.idea/ + +### Emacs Temporary Files ### +############################# +*~ + +### Folders ### +############### +bin/ +vendor/ +node_modules/ + +### Files ### +############# +*.ivf +*.ogg +tags +cover.out +*.sw[poe] +*.wasm +examples/sfu-ws/cert.pem +examples/sfu-ws/key.pem +wasm_exec.js diff --git a/vendor/github.com/pion/interceptor/.golangci.yml b/vendor/github.com/pion/interceptor/.golangci.yml new file mode 100644 index 0000000..43af4c3 --- /dev/null +++ b/vendor/github.com/pion/interceptor/.golangci.yml @@ -0,0 +1,147 @@ +# SPDX-FileCopyrightText: 2026 The Pion community +# SPDX-License-Identifier: MIT + +version: "2" +linters: + enable: + - asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers + - bidichk # Checks for dangerous unicode character sequences + - bodyclose # checks whether HTTP response body is closed successfully + - containedctx # containedctx is a linter that detects struct contained context.Context field + - contextcheck # check the function whether use a non-inherited context + - cyclop # checks function and package cyclomatic complexity + - decorder # check declaration order and count of types, constants, variables and functions + - dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f()) + - dupl # Tool for code clone detection + - durationcheck # check for two durations multiplied together + - err113 # Golang linter to check the errors handling expressions + - errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases + - errchkjson # Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occations, where the check for the returned error can be omitted. + - errname # Checks that sentinel errors are prefixed with the `Err` and error types are suffixed with the `Error`. + - errorlint # errorlint is a linter for that can be used to find code that will cause problems with the error wrapping scheme introduced in Go 1.13. + - exhaustive # check exhaustiveness of enum switch statements + - forbidigo # Forbids identifiers + - forcetypeassert # finds forced type assertions + - gochecknoglobals # Checks that no globals are present in Go code + - gocognit # Computes and checks the cognitive complexity of functions + - goconst # Finds repeated strings that could be replaced by a constant + - gocritic # The most opinionated Go source code linter + - gocyclo # Computes and checks the cyclomatic complexity of functions + - godot # Check if comments end in a period + - godox # Tool for detection of FIXME, TODO and other comment keywords + - goheader # Checks is file header matches to pattern + - gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod. + - goprintffuncname # Checks that printf-like functions are named with `f` at the end + - gosec # Inspects source code for security problems + - govet # Vet examines Go source code and reports suspicious constructs, such as Printf calls whose arguments do not align with the format string + - grouper # An analyzer to analyze expression groups. + - importas # Enforces consistent import aliases + - ineffassign # Detects when assignments to existing variables are not used + - lll # Reports long lines + - maintidx # maintidx measures the maintainability index of each function. + - makezero # Finds slice declarations with non-zero initial length + - misspell # Finds commonly misspelled English words in comments + - nakedret # Finds naked returns in functions greater than a specified function length + - nestif # Reports deeply nested if statements + - nilerr # Finds the code that returns nil even if it checks that the error is not nil. + - nilnil # Checks that there is no simultaneous return of `nil` error and an invalid value. + - nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity + - noctx # noctx finds sending http request without context.Context + - predeclared # find code that shadows one of Go's predeclared identifiers + - revive # golint replacement, finds style mistakes + - staticcheck # Staticcheck is a go vet on steroids, applying a ton of static analysis checks + - tagliatelle # Checks the struct tags. + - thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers + - unconvert # Remove unnecessary type conversions + - unparam # Reports unused function parameters + - unused # Checks Go code for unused constants, variables, functions and types + - varnamelen # checks that the length of a variable's name matches its scope + - wastedassign # wastedassign finds wasted assignment statements + - whitespace # Tool for detection of leading and trailing whitespace + disable: + - depguard # Go linter that checks if package imports are in a list of acceptable packages + - funlen # Tool for detection of long functions + - gochecknoinits # Checks that no init functions are present in Go code + - gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations. + - interfacebloat # A linter that checks length of interface. + - ireturn # Accept Interfaces, Return Concrete Types + - mnd # An analyzer to detect magic numbers + - nolintlint # Reports ill-formed or insufficient nolint directives + - paralleltest # paralleltest detects missing usage of t.Parallel() method in your Go test + - prealloc # Finds slice declarations that could potentially be preallocated + - promlinter # Check Prometheus metrics naming via promlint + - rowserrcheck # checks whether Err of rows is checked successfully + - sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed. + - testpackage # linter that makes you use a separate _test package + - tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes + - wrapcheck # Checks that errors returned from external packages are wrapped + - wsl # Whitespace Linter - Forces you to use empty lines! + settings: + staticcheck: + checks: + - all + - -QF1008 # "could remove embedded field", to keep it explicit! + - -QF1003 # "could use tagged switch on enum", Cases conflicts with exhaustive! + exhaustive: + default-signifies-exhaustive: true + forbidigo: + forbid: + - pattern: ^fmt.Print(f|ln)?$ + - pattern: ^log.(Panic|Fatal|Print)(f|ln)?$ + - pattern: ^os.Exit$ + - pattern: ^panic$ + - pattern: ^print(ln)?$ + - pattern: ^testing.T.(Error|Errorf|Fatal|Fatalf|Fail|FailNow)$ + pkg: ^testing$ + msg: use testify/assert instead + analyze-types: true + gomodguard: + blocked: + modules: + - github.com/pkg/errors: + recommendations: + - errors + govet: + enable: + - shadow + revive: + rules: + # Prefer 'any' type alias over 'interface{}' for Go 1.18+ compatibility + - name: use-any + severity: warning + disabled: false + misspell: + locale: US + varnamelen: + max-distance: 12 + min-name-length: 2 + ignore-type-assert-ok: true + ignore-map-index-ok: true + ignore-chan-recv-ok: true + ignore-decls: + - i int + - n int + - w io.Writer + - r io.Reader + - b []byte + exclusions: + generated: lax + rules: + - linters: + - forbidigo + - gocognit + path: (examples|main\.go) + - linters: + - gocognit + path: _test\.go + - linters: + - forbidigo + path: cmd +formatters: + enable: + - gci # Gci control golang package import order and make it always deterministic. + - gofmt # Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification + - gofumpt # Gofumpt checks whether code was gofumpt-ed. + - goimports # Goimports does everything that gofmt does. Additionally it checks unused imports + exclusions: + generated: lax diff --git a/vendor/github.com/pion/interceptor/.goreleaser.yml b/vendor/github.com/pion/interceptor/.goreleaser.yml new file mode 100644 index 0000000..8577d86 --- /dev/null +++ b/vendor/github.com/pion/interceptor/.goreleaser.yml @@ -0,0 +1,5 @@ +# SPDX-FileCopyrightText: 2026 The Pion community +# SPDX-License-Identifier: MIT + +builds: +- skip: true diff --git a/vendor/github.com/pion/interceptor/LICENSE b/vendor/github.com/pion/interceptor/LICENSE new file mode 100644 index 0000000..d96df05 --- /dev/null +++ b/vendor/github.com/pion/interceptor/LICENSE @@ -0,0 +1,9 @@ +MIT License + +Copyright (c) 2026 The Pion community + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/vendor/github.com/pion/interceptor/README.md b/vendor/github.com/pion/interceptor/README.md new file mode 100644 index 0000000..46a9ca6 --- /dev/null +++ b/vendor/github.com/pion/interceptor/README.md @@ -0,0 +1,84 @@ +

+
+ Pion Interceptor +
+

+

RTP and RTCP processors for building real time communications

+

+ Pion Interceptor + join us on Discord Follow us on Bluesky +
+ GitHub Workflow Status + Go Reference + Coverage Status + Go Report Card + License: MIT +

+
+ +Interceptor is a framework for building RTP/RTCP communication software. This framework defines +a interface that each interceptor must satisfy. These interceptors are then run sequentially. We +also then provide common interceptors that will be useful for building RTC software. + +This package was built for [pion/webrtc](https://github.com/pion/webrtc), but we designed it to be consumable +by anyone. With the following tenets in mind. + +* Useful defaults. Each interceptor will be configured to give you a good default experience. +* Unblock unique use cases. New use cases are what is driving WebRTC, we want to empower them. +* Encourage modification. Add your own interceptors without forking. Mixing with the ones we provide. +* Empower learning. This code base should be useful to read and learn even if you aren't using Pion. + +### Current Interceptors +* [NACK Generator/Responder](https://github.com/pion/interceptor/tree/master/pkg/nack) +* [Sender and Receiver Reports](https://github.com/pion/interceptor/tree/master/pkg/report) +* [Transport Wide Congestion Control Feedback](https://github.com/pion/interceptor/tree/master/pkg/twcc) +* [Packet Dump](https://github.com/pion/interceptor/tree/master/pkg/packetdump) +* [Google Congestion Control](https://github.com/pion/interceptor/tree/master/pkg/gcc) +* [Stats](https://github.com/pion/interceptor/tree/master/pkg/stats) A [webrtc-stats](https://www.w3.org/TR/webrtc-stats/) compliant statistics generation +* [Interval PLI](https://github.com/pion/interceptor/tree/master/pkg/intervalpli) Generate PLI on a interval. Useful when no decoder is available. +* [FlexFec](https://github.com/pion/interceptor/tree/master/pkg/flexfec) – [FlexFEC-03](https://datatracker.ietf.org/doc/html/draft-ietf-payload-flexible-fec-scheme-03) encoder implementation + +### Planned Interceptors +* Bandwidth Estimation + - [NADA](https://tools.ietf.org/html/rfc8698) +* JitterBuffer, re-order packets and wait for arrival +* [RTCP Feedback for Congestion Control](https://datatracker.ietf.org/doc/html/rfc8888) the standardized alternative to TWCC. + +### Interceptor Public API +The public interface is defined in [interceptor.go](https://github.com/pion/interceptor/blob/master/interceptor.go). +The methods you need to satisy are broken up into 4 groups. + +* `BindRTCPWriter` and `BindRTCPReader` allow you to inspect/modify RTCP traffic. +* `BindLocalStream` and `BindRemoteStream` notify you of a new SSRC stream and allow you to inspect/modify. +* `UnbindLocalStream` and `UnbindRemoteStream` notify you when a SSRC stream has been removed +* `Close` called when the interceptor is closed. + +Interceptors also pass Attributes between each other. These are a collection of key/value pairs and are useful for storing metadata +or caching. + +[noop.go](https://github.com/pion/interceptor/blob/master/noop.go) is an interceptor that satisfies this interface, but does nothing. +You can embed this interceptor as a starting point so you only need to define exactly what you need. + +[chain.go]( https://github.com/pion/interceptor/blob/master/chain.go) is used to combine multiple interceptors into one. They are called +sequentially as the packet moves through them. + +### Examples +The [examples](https://github.com/pion/interceptor/blob/master/examples) directory provides some basic examples. If you need more please file an issue! +You should also look in [pion/webrtc](https://github.com/pion/webrtc) for real world examples. + +### Roadmap +The library is used as a part of our WebRTC implementation. Please refer to that [roadmap](https://github.com/pion/webrtc/issues/9) to track our major milestones. + +### Community +Pion has an active community on the [Discord](https://discord.gg/PngbdqpFbt). + +Follow the [Pion Bluesky](https://bsky.app/profile/pion.ly) or [Pion Twitter](https://twitter.com/_pion) for project updates and important WebRTC news. + +We are always looking to support **your projects**. Please reach out if you have something to build! +If you need commercial support or don't want to use public methods you can contact us at [team@pion.ly](mailto:team@pion.ly) + +### Contributing +Check out the [contributing wiki](https://github.com/pion/webrtc/wiki/Contributing) to join the group of amazing people making this project possible + +### License +MIT License - see [LICENSE](LICENSE) for full text diff --git a/vendor/github.com/pion/interceptor/attributes.go b/vendor/github.com/pion/interceptor/attributes.go new file mode 100644 index 0000000..1eb4f10 --- /dev/null +++ b/vendor/github.com/pion/interceptor/attributes.go @@ -0,0 +1,72 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package interceptor + +import ( + "errors" + + "github.com/pion/rtcp" + "github.com/pion/rtp" +) + +type unmarshaledDataKeyType int + +const ( + rtpHeaderKey unmarshaledDataKeyType = iota + rtcpPacketsKey +) + +var errInvalidType = errors.New("found value of invalid type in attributes map") + +// Attributes are a generic key/value store used by interceptors. +type Attributes map[any]any + +// Get returns the attribute associated with key. +func (a Attributes) Get(key any) any { + return a[key] +} + +// Set sets the attribute associated with key to the given value. +func (a Attributes) Set(key any, val any) { + a[key] = val +} + +// GetRTPHeader gets the RTP header if present. If it is not present, it will be +// unmarshalled from the raw byte slice and stored in the attributes. +func (a Attributes) GetRTPHeader(raw []byte) (*rtp.Header, error) { + if val, ok := a[rtpHeaderKey]; ok { + if header, ok := val.(*rtp.Header); ok { + return header, nil + } + + return nil, errInvalidType + } + header := &rtp.Header{} + if _, err := header.Unmarshal(raw); err != nil { + return nil, err + } + a[rtpHeaderKey] = header + + return header, nil +} + +// GetRTCPPackets gets the RTCP packets if present. If the packet slice is not +// present, it will be unmarshalled from the raw byte slice and stored in the +// attributes. +func (a Attributes) GetRTCPPackets(raw []byte) ([]rtcp.Packet, error) { + if val, ok := a[rtcpPacketsKey]; ok { + if packets, ok := val.([]rtcp.Packet); ok { + return packets, nil + } + + return nil, errInvalidType + } + pkts, err := rtcp.Unmarshal(raw) + if err != nil { + return nil, err + } + a[rtcpPacketsKey] = pkts + + return pkts, nil +} diff --git a/vendor/github.com/pion/interceptor/chain.go b/vendor/github.com/pion/interceptor/chain.go new file mode 100644 index 0000000..95dcdb9 --- /dev/null +++ b/vendor/github.com/pion/interceptor/chain.go @@ -0,0 +1,79 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package interceptor + +// Chain is an interceptor that runs all child interceptors in order. +type Chain struct { + interceptors []Interceptor +} + +// NewChain returns a new Chain interceptor. +func NewChain(interceptors []Interceptor) *Chain { + return &Chain{interceptors: interceptors} +} + +// BindRTCPReader lets you modify any incoming RTCP packets. It is called once per sender/receiver, however this might +// change in the future. The returned method will be called once per packet batch. +func (i *Chain) BindRTCPReader(reader RTCPReader) RTCPReader { + for _, interceptor := range i.interceptors { + reader = interceptor.BindRTCPReader(reader) + } + + return reader +} + +// BindRTCPWriter lets you modify any outgoing RTCP packets. It is called once per PeerConnection. The returned method +// will be called once per packet batch. +func (i *Chain) BindRTCPWriter(writer RTCPWriter) RTCPWriter { + for _, interceptor := range i.interceptors { + writer = interceptor.BindRTCPWriter(writer) + } + + return writer +} + +// BindLocalStream lets you modify any outgoing RTP packets. It is called once for per LocalStream. The returned method +// will be called once per rtp packet. +func (i *Chain) BindLocalStream(ctx *StreamInfo, writer RTPWriter) RTPWriter { + for _, interceptor := range i.interceptors { + writer = interceptor.BindLocalStream(ctx, writer) + } + + return writer +} + +// UnbindLocalStream is called when the Stream is removed. It can be used to clean up any data related to that track. +func (i *Chain) UnbindLocalStream(ctx *StreamInfo) { + for _, interceptor := range i.interceptors { + interceptor.UnbindLocalStream(ctx) + } +} + +// BindRemoteStream lets you modify any incoming RTP packets. +// It is called once for per RemoteStream. The returned method +// will be called once per rtp packet. +func (i *Chain) BindRemoteStream(ctx *StreamInfo, reader RTPReader) RTPReader { + for _, interceptor := range i.interceptors { + reader = interceptor.BindRemoteStream(ctx, reader) + } + + return reader +} + +// UnbindRemoteStream is called when the Stream is removed. It can be used to clean up any data related to that track. +func (i *Chain) UnbindRemoteStream(ctx *StreamInfo) { + for _, interceptor := range i.interceptors { + interceptor.UnbindRemoteStream(ctx) + } +} + +// Close closes the Interceptor, cleaning up any data if necessary. +func (i *Chain) Close() error { + var errs []error + for _, interceptor := range i.interceptors { + errs = append(errs, interceptor.Close()) + } + + return flattenErrs(errs) +} diff --git a/vendor/github.com/pion/interceptor/codecov.yml b/vendor/github.com/pion/interceptor/codecov.yml new file mode 100644 index 0000000..b9639c2 --- /dev/null +++ b/vendor/github.com/pion/interceptor/codecov.yml @@ -0,0 +1,22 @@ +# +# DO NOT EDIT THIS FILE +# +# It is automatically copied from https://github.com/pion/.goassets repository. +# +# SPDX-FileCopyrightText: 2026 The Pion community +# SPDX-License-Identifier: MIT + +coverage: + status: + project: + default: + # Allow decreasing 2% of total coverage to avoid noise. + threshold: 2% + patch: + default: + target: 70% + only_pulls: true + +ignore: + - "examples/*" + - "examples/**/*" diff --git a/vendor/github.com/pion/interceptor/errors.go b/vendor/github.com/pion/interceptor/errors.go new file mode 100644 index 0000000..d074ff7 --- /dev/null +++ b/vendor/github.com/pion/interceptor/errors.go @@ -0,0 +1,56 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package interceptor + +import ( + "errors" + "strings" +) + +func flattenErrs(errs []error) error { + errs2 := []error{} + for _, e := range errs { + if e != nil { + errs2 = append(errs2, e) + } + } + if len(errs2) == 0 { + return nil + } + + return multiError(errs2) +} + +type multiError []error //nolint + +func (me multiError) Error() string { + var errstrings []string + + for _, err := range me { + if err != nil { + errstrings = append(errstrings, err.Error()) + } + } + + if len(errstrings) == 0 { + return "multiError must contain multiple error but is empty" + } + + return strings.Join(errstrings, "\n") +} + +func (me multiError) Is(err error) bool { + for _, e := range me { + if errors.Is(e, err) { + return true + } + if me2, ok := e.(multiError); ok { //nolint + if me2.Is(err) { + return true + } + } + } + + return false +} diff --git a/vendor/github.com/pion/interceptor/interceptor.go b/vendor/github.com/pion/interceptor/interceptor.go new file mode 100644 index 0000000..e54220f --- /dev/null +++ b/vendor/github.com/pion/interceptor/interceptor.go @@ -0,0 +1,103 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +// Package interceptor contains the Interceptor interface, with some useful interceptors that should be safe to use +// in most cases. +package interceptor + +import ( + "io" + + "github.com/pion/rtcp" + "github.com/pion/rtp" +) + +// Factory provides an interface for constructing interceptors. +type Factory interface { + NewInterceptor(id string) (Interceptor, error) +} + +// Interceptor can be used to add functionality to you PeerConnections by modifying any incoming/outgoing rtp/rtcp +// packets, or sending your own packets as needed. +type Interceptor interface { + // BindRTCPReader lets you modify any incoming RTCP packets. It is called once per sender/receiver, however this might + // change in the future. The returned method will be called once per packet batch. + BindRTCPReader(reader RTCPReader) RTCPReader + + // BindRTCPWriter lets you modify any outgoing RTCP packets. It is called once per PeerConnection. The returned method + // will be called once per packet batch. + BindRTCPWriter(writer RTCPWriter) RTCPWriter + + // BindLocalStream lets you modify any outgoing RTP packets. It is called once for per LocalStream. The returned method + // will be called once per rtp packet. + BindLocalStream(info *StreamInfo, writer RTPWriter) RTPWriter + + // UnbindLocalStream is called when the Stream is removed. It can be used to clean up any data related to that track. + UnbindLocalStream(info *StreamInfo) + + // BindRemoteStream lets you modify any incoming RTP packets. + // It is called once for per RemoteStream. The returned method + // will be called once per rtp packet. + BindRemoteStream(info *StreamInfo, reader RTPReader) RTPReader + + // UnbindRemoteStream is called when the Stream is removed. It can be used to clean up any data related to that track. + UnbindRemoteStream(info *StreamInfo) + + io.Closer +} + +// RTPWriter is used by Interceptor.BindLocalStream. +type RTPWriter interface { + // Write a rtp packet + Write(header *rtp.Header, payload []byte, attributes Attributes) (int, error) +} + +// RTPReader is used by Interceptor.BindRemoteStream. +type RTPReader interface { + // Read a rtp packet + Read([]byte, Attributes) (int, Attributes, error) +} + +// RTCPWriter is used by Interceptor.BindRTCPWriter. +type RTCPWriter interface { + // Write a batch of rtcp packets + Write(pkts []rtcp.Packet, attributes Attributes) (int, error) +} + +// RTCPReader is used by Interceptor.BindRTCPReader. +type RTCPReader interface { + // Read a batch of rtcp packets + Read([]byte, Attributes) (int, Attributes, error) +} + +// RTPWriterFunc is an adapter for RTPWrite interface. +type RTPWriterFunc func(header *rtp.Header, payload []byte, attributes Attributes) (int, error) + +// RTPReaderFunc is an adapter for RTPReader interface. +type RTPReaderFunc func([]byte, Attributes) (int, Attributes, error) + +// RTCPWriterFunc is an adapter for RTCPWriter interface. +type RTCPWriterFunc func(pkts []rtcp.Packet, attributes Attributes) (int, error) + +// RTCPReaderFunc is an adapter for RTCPReader interface. +type RTCPReaderFunc func([]byte, Attributes) (int, Attributes, error) + +// Write a rtp packet. +func (f RTPWriterFunc) Write(header *rtp.Header, payload []byte, attributes Attributes) (int, error) { + return f(header, payload, attributes) +} + +// Read a rtp packet. +func (f RTPReaderFunc) Read(b []byte, a Attributes) (int, Attributes, error) { + return f(b, a) +} + +// Write a batch of rtcp packets. +func (f RTCPWriterFunc) Write(pkts []rtcp.Packet, attributes Attributes) (int, error) { + return f(pkts, attributes) +} + +// Read a batch of rtcp packets. +func (f RTCPReaderFunc) Read(b []byte, a Attributes) (int, Attributes, error) { + return f(b, a) +} diff --git a/vendor/github.com/pion/interceptor/internal/ntp/ntp.go b/vendor/github.com/pion/interceptor/internal/ntp/ntp.go new file mode 100644 index 0000000..4d8a949 --- /dev/null +++ b/vendor/github.com/pion/interceptor/internal/ntp/ntp.go @@ -0,0 +1,47 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +// Package ntp provides conversion methods between time.Time and NTP timestamps +// stored in uint64 +package ntp + +import ( + "time" +) + +// ToNTP converts a time.Time oboject to an uint64 NTP timestamp. +func ToNTP(t time.Time) uint64 { + // seconds since 1st January 1900 + s := (float64(t.UnixNano()) / 1000000000) + 2208988800 + + // higher 32 bits are the integer part, lower 32 bits are the fractional part + integerPart := uint32(s) + fractionalPart := uint32((s - float64(integerPart)) * 0xFFFFFFFF) + + return uint64(integerPart)<<32 | uint64(fractionalPart) //nolint:gosec // G115 +} + +// ToNTP32 converts a time.Time object to a uint32 NTP timestamp. +func ToNTP32(t time.Time) uint32 { + return uint32(ToNTP(t) >> 16) //nolint:gosec // G115 +} + +// ToTime converts a uint64 NTP timestamps to a time.Time object. +func ToTime(t uint64) time.Time { + seconds := (t & 0xFFFFFFFF00000000) >> 32 + fractional := float64(t&0x00000000FFFFFFFF) / float64(0xFFFFFFFF) + //nolint:gosec // G115 + d := time.Duration(seconds)*time.Second + time.Duration(fractional*1e9)*time.Nanosecond + + return time.Unix(0, 0).Add(-2208988800 * time.Second).Add(d) +} + +// ToTime32 converts a uint32 NTP timestamp to a time.Time object, using the +// highest 16 bit of the reference to recover the lost bits. The low 16 bits are +// not recovered. +func ToTime32(t uint32, reference time.Time) time.Time { + referenceNTP := ToNTP(reference) & 0xFFFF000000000000 + tu64 := ((uint64(t) << 16) & 0x0000FFFFFFFF0000) | referenceNTP + + return ToTime(tu64) +} diff --git a/vendor/github.com/pion/interceptor/internal/rtpbuffer/errors.go b/vendor/github.com/pion/interceptor/internal/rtpbuffer/errors.go new file mode 100644 index 0000000..8ba12ad --- /dev/null +++ b/vendor/github.com/pion/interceptor/internal/rtpbuffer/errors.go @@ -0,0 +1,16 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package rtpbuffer + +import "errors" + +// ErrInvalidSize is returned by newReceiveLog/newRTPBuffer, when an incorrect buffer size is supplied. +var ErrInvalidSize = errors.New("invalid buffer size") + +var ( + errPacketReleased = errors.New("could not retain packet, already released") + errFailedToCastHeaderPool = errors.New("could not access header pool, failed cast") + errFailedToCastPayloadPool = errors.New("could not access payload pool, failed cast") + errPaddingOverflow = errors.New("padding size exceeds payload size") +) diff --git a/vendor/github.com/pion/interceptor/internal/rtpbuffer/packet_factory.go b/vendor/github.com/pion/interceptor/internal/rtpbuffer/packet_factory.go new file mode 100644 index 0000000..a261a2e --- /dev/null +++ b/vendor/github.com/pion/interceptor/internal/rtpbuffer/packet_factory.go @@ -0,0 +1,149 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package rtpbuffer + +import ( + "encoding/binary" + "io" + "sync" + + "github.com/pion/rtp" +) + +const rtxSsrcByteLength = 2 + +// PacketFactory allows custom logic around the handle of RTP Packets before they added to the RTPBuffer. +// The NoOpPacketFactory doesn't copy packets, while the RetainablePacket will take a copy before adding. +type PacketFactory interface { + NewPacket(header *rtp.Header, payload []byte, rtxSsrc uint32, rtxPayloadType uint8) (*RetainablePacket, error) +} + +// PacketFactoryCopy is PacketFactory that takes a copy of packets when added to the RTPBuffer. +type PacketFactoryCopy struct { + headerPool *sync.Pool + payloadPool *sync.Pool + rtxSequencer rtp.Sequencer +} + +// NewPacketFactoryCopy constructs a PacketFactory that takes a copy of packets when added to the RTPBuffer. +func NewPacketFactoryCopy() *PacketFactoryCopy { + return &PacketFactoryCopy{ + headerPool: &sync.Pool{ + New: func() any { + return &rtp.Header{} + }, + }, + payloadPool: &sync.Pool{ + New: func() any { + buf := make([]byte, maxPayloadLen) + + return &buf + }, + }, + rtxSequencer: rtp.NewRandomSequencer(), + } +} + +// NewPacket constructs a new RetainablePacket that can be added to the RTPBuffer. +// +//nolint:cyclop +func (m *PacketFactoryCopy) NewPacket( + header *rtp.Header, payload []byte, rtxSsrc uint32, rtxPayloadType uint8, +) (*RetainablePacket, error) { + if len(payload) > maxPayloadLen { + return nil, io.ErrShortBuffer + } + + retainablePacket := &RetainablePacket{ + onRelease: m.releasePacket, + sequenceNumber: header.SequenceNumber, + // new packets have retain count of 1 + count: 1, + } + + var ok bool + retainablePacket.header, ok = m.headerPool.Get().(*rtp.Header) + if !ok { + return nil, errFailedToCastHeaderPool + } + + *retainablePacket.header = header.Clone() + + if payload != nil { + retainablePacket.buffer, ok = m.payloadPool.Get().(*[]byte) + if !ok { + return nil, errFailedToCastPayloadPool + } + if rtxSsrc != 0 && rtxPayloadType != 0 { + size := copy((*retainablePacket.buffer)[rtxSsrcByteLength:], payload) + retainablePacket.payload = (*retainablePacket.buffer)[:size+rtxSsrcByteLength] + } else { + size := copy(*retainablePacket.buffer, payload) + retainablePacket.payload = (*retainablePacket.buffer)[:size] + } + } + + if rtxSsrc != 0 && rtxPayloadType != 0 { //nolint:nestif + if payload == nil { + retainablePacket.buffer, ok = m.payloadPool.Get().(*[]byte) + if !ok { + return nil, errFailedToCastPayloadPool + } + retainablePacket.payload = (*retainablePacket.buffer)[:rtxSsrcByteLength] + } + // Write the original sequence number at the beginning of the payload. + binary.BigEndian.PutUint16(retainablePacket.payload, retainablePacket.header.SequenceNumber) + + // Rewrite the SSRC. + retainablePacket.header.SSRC = rtxSsrc + // Rewrite the payload type. + retainablePacket.header.PayloadType = rtxPayloadType + // Rewrite the sequence number. + retainablePacket.header.SequenceNumber = m.rtxSequencer.NextSequenceNumber() + // Remove padding if present. + if retainablePacket.header.Padding { + // Older versions of pion/rtp didn't have the Header.PaddingSize field and as a workaround + // users had to add padding to the payload. We need to handle this case here. + if retainablePacket.header.PaddingSize == 0 && len(retainablePacket.payload) > 0 { + paddingLength := int(retainablePacket.payload[len(retainablePacket.payload)-1]) + if paddingLength > len(retainablePacket.payload) { + return nil, errPaddingOverflow + } + retainablePacket.payload = (*retainablePacket.buffer)[:len(retainablePacket.payload)-paddingLength] + } + + retainablePacket.header.Padding = false + retainablePacket.header.PaddingSize = 0 + } + } + + return retainablePacket, nil +} + +func (m *PacketFactoryCopy) releasePacket(header *rtp.Header, payload *[]byte) { + m.headerPool.Put(header) + if payload != nil { + m.payloadPool.Put(payload) + } +} + +// PacketFactoryNoOp is a PacketFactory implementation that doesn't copy packets. +type PacketFactoryNoOp struct{} + +// NewPacket constructs a new RetainablePacket that can be added to the RTPBuffer. +func (f *PacketFactoryNoOp) NewPacket( + header *rtp.Header, payload []byte, _ uint32, _ uint8, +) (*RetainablePacket, error) { + return &RetainablePacket{ + onRelease: f.releasePacket, + count: 1, + header: header, + payload: payload, + sequenceNumber: header.SequenceNumber, + }, nil +} + +func (f *PacketFactoryNoOp) releasePacket(_ *rtp.Header, _ *[]byte) { + // no-op +} diff --git a/vendor/github.com/pion/interceptor/internal/rtpbuffer/retainable_packet.go b/vendor/github.com/pion/interceptor/internal/rtpbuffer/retainable_packet.go new file mode 100644 index 0000000..adb65b1 --- /dev/null +++ b/vendor/github.com/pion/interceptor/internal/rtpbuffer/retainable_packet.go @@ -0,0 +1,62 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package rtpbuffer + +import ( + "sync" + + "github.com/pion/rtp" +) + +// RetainablePacket is a referenced counted RTP packet. +type RetainablePacket struct { + onRelease func(*rtp.Header, *[]byte) + + countMu sync.Mutex + count int + + header *rtp.Header + buffer *[]byte + payload []byte + + sequenceNumber uint16 +} + +// Header returns the RTP Header of the RetainablePacket. +func (p *RetainablePacket) Header() *rtp.Header { + return p.header +} + +// Payload returns the RTP Payload of the RetainablePacket. +func (p *RetainablePacket) Payload() []byte { + return p.payload +} + +// Retain increases the reference count of the RetainablePacket. +func (p *RetainablePacket) Retain() error { + p.countMu.Lock() + defer p.countMu.Unlock() + if p.count == 0 { + // already released + return errPacketReleased + } + p.count++ + + return nil +} + +// Release decreases the reference count of the RetainablePacket and frees if needed. +func (p *RetainablePacket) Release() { + p.countMu.Lock() + defer p.countMu.Unlock() + p.count-- + + if p.count == 0 { + // release back to pool + p.onRelease(p.header, p.buffer) + p.header = nil + p.buffer = nil + p.payload = nil + } +} diff --git a/vendor/github.com/pion/interceptor/internal/rtpbuffer/rtpbuffer.go b/vendor/github.com/pion/interceptor/internal/rtpbuffer/rtpbuffer.go new file mode 100644 index 0000000..2501646 --- /dev/null +++ b/vendor/github.com/pion/interceptor/internal/rtpbuffer/rtpbuffer.go @@ -0,0 +1,107 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +// Package rtpbuffer provides a buffer for storing RTP packets +package rtpbuffer + +import ( + "fmt" +) + +const ( + // Uint16SizeHalf is half of a math.Uint16. + Uint16SizeHalf = 1 << 15 + + maxPayloadLen = 1460 +) + +// RTPBuffer stores RTP packets and allows custom logic +// around the lifetime of them via the PacketFactory. +type RTPBuffer struct { + packets []*RetainablePacket + size uint16 + highestAdded uint16 + started bool +} + +// NewRTPBuffer constructs a new RTPBuffer. +func NewRTPBuffer(size uint16) (*RTPBuffer, error) { + allowedSizes := make([]uint16, 0) + correctSize := false + for i := 0; i < 16; i++ { + if size == 1<= Uint16SizeHalf { + return nil + } + + if diff >= r.size { + return nil + } + + pkt := r.packets[seq%r.size] + if pkt != nil { + if pkt.sequenceNumber != seq { + return nil + } + // already released + if err := pkt.Retain(); err != nil { + return nil + } + } + + return pkt +} diff --git a/vendor/github.com/pion/interceptor/internal/sequencenumber/unwrapper.go b/vendor/github.com/pion/interceptor/internal/sequencenumber/unwrapper.go new file mode 100644 index 0000000..c6e466d --- /dev/null +++ b/vendor/github.com/pion/interceptor/internal/sequencenumber/unwrapper.go @@ -0,0 +1,48 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +// Package sequencenumber provides a sequence number unwrapper +package sequencenumber + +const ( + maxSequenceNumberPlusOne = int64(65536) + breakpoint = 32768 // half of max uint16 +) + +// Unwrapper stores an unwrapped sequence number. +type Unwrapper struct { + init bool + lastUnwrapped int64 +} + +func isNewer(value, previous uint16) bool { + if value-previous == breakpoint { + return value > previous + } + + return value != previous && (value-previous) < breakpoint +} + +// Unwrap unwraps the next sequencenumber. +func (u *Unwrapper) Unwrap(i uint16) int64 { + if !u.init { + u.init = true + u.lastUnwrapped = int64(i) + + return u.lastUnwrapped + } + + lastWrapped := uint16(u.lastUnwrapped) //nolint:gosec // G115 + delta := int64(i - lastWrapped) + if isNewer(i, lastWrapped) { + if delta < 0 { + delta += maxSequenceNumberPlusOne + } + } else if delta > 0 && u.lastUnwrapped+delta-maxSequenceNumberPlusOne >= 0 { + delta -= maxSequenceNumberPlusOne + } + + u.lastUnwrapped += delta + + return u.lastUnwrapped +} diff --git a/vendor/github.com/pion/interceptor/noop.go b/vendor/github.com/pion/interceptor/noop.go new file mode 100644 index 0000000..bdc9c96 --- /dev/null +++ b/vendor/github.com/pion/interceptor/noop.go @@ -0,0 +1,44 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package interceptor + +// NoOp is an Interceptor that does not modify any packets. It can embedded in other interceptors, so it's +// possible to implement only a subset of the methods. +type NoOp struct{} + +// BindRTCPReader lets you modify any incoming RTCP packets. It is called once per sender/receiver, however this might +// change in the future. The returned method will be called once per packet batch. +func (i *NoOp) BindRTCPReader(reader RTCPReader) RTCPReader { + return reader +} + +// BindRTCPWriter lets you modify any outgoing RTCP packets. It is called once per PeerConnection. The returned method +// will be called once per packet batch. +func (i *NoOp) BindRTCPWriter(writer RTCPWriter) RTCPWriter { + return writer +} + +// BindLocalStream lets you modify any outgoing RTP packets. It is called once for per LocalStream. The returned method +// will be called once per rtp packet. +func (i *NoOp) BindLocalStream(_ *StreamInfo, writer RTPWriter) RTPWriter { + return writer +} + +// UnbindLocalStream is called when the Stream is removed. It can be used to clean up any data related to that track. +func (i *NoOp) UnbindLocalStream(_ *StreamInfo) {} + +// BindRemoteStream lets you modify any incoming RTP packets. +// It is called once for per RemoteStream. The returned method +// will be called once per rtp packet. +func (i *NoOp) BindRemoteStream(_ *StreamInfo, reader RTPReader) RTPReader { + return reader +} + +// UnbindRemoteStream is called when the Stream is removed. It can be used to clean up any data related to that track. +func (i *NoOp) UnbindRemoteStream(_ *StreamInfo) {} + +// Close closes the Interceptor, cleaning up any data if necessary. +func (i *NoOp) Close() error { + return nil +} diff --git a/vendor/github.com/pion/interceptor/pkg/flexfec/encoder_interceptor.go b/vendor/github.com/pion/interceptor/pkg/flexfec/encoder_interceptor.go new file mode 100644 index 0000000..0b3f709 --- /dev/null +++ b/vendor/github.com/pion/interceptor/pkg/flexfec/encoder_interceptor.go @@ -0,0 +1,128 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package flexfec + +import ( + "errors" + "sync" + + "github.com/pion/interceptor" + "github.com/pion/rtp" +) + +// streamState holds the state for a single stream. +type streamState struct { + mu sync.Mutex + flexFecEncoder FlexEncoder + packetBuffer []rtp.Packet +} + +// FecInterceptor implements FlexFec. +type FecInterceptor struct { + interceptor.NoOp + mu sync.Mutex + streams map[uint32]*streamState + numMediaPackets uint32 + numFecPackets uint32 + encoderFactory EncoderFactory +} + +// FecInterceptorFactory creates new FecInterceptors. +type FecInterceptorFactory struct { + opts []FecOption +} + +// NewFecInterceptor returns a new Fec interceptor factory. +func NewFecInterceptor(opts ...FecOption) (*FecInterceptorFactory, error) { + return &FecInterceptorFactory{opts: opts}, nil +} + +// NewInterceptor constructs a new FecInterceptor. +func (r *FecInterceptorFactory) NewInterceptor(_ string) (interceptor.Interceptor, error) { + interceptor := &FecInterceptor{ + streams: make(map[uint32]*streamState), + numMediaPackets: 5, + numFecPackets: 2, + encoderFactory: FlexEncoder03Factory{}, + } + + for _, opt := range r.opts { + if err := opt(interceptor); err != nil { + return nil, err + } + } + + return interceptor, nil +} + +// UnbindLocalStream removes the stream state for a specific SSRC. +func (r *FecInterceptor) UnbindLocalStream(info *interceptor.StreamInfo) { + r.mu.Lock() + defer r.mu.Unlock() + + delete(r.streams, info.SSRC) +} + +// BindLocalStream lets you modify any outgoing RTP packets. It is called once for per LocalStream. The returned method +// will be called once per rtp packet. +func (r *FecInterceptor) BindLocalStream( + info *interceptor.StreamInfo, writer interceptor.RTPWriter, +) interceptor.RTPWriter { + if info.PayloadTypeForwardErrorCorrection == 0 || info.SSRCForwardErrorCorrection == 0 { + return writer + } + + mediaSSRC := info.SSRC + + r.mu.Lock() + stream := &streamState{ + // Chromium supports version flexfec-03 of existing draft, this is the one we will configure by default + // although we should support configuring the latest (flexfec-20) as well. + flexFecEncoder: r.encoderFactory.NewEncoder(info.PayloadTypeForwardErrorCorrection, info.SSRCForwardErrorCorrection), + packetBuffer: make([]rtp.Packet, 0), + } + r.streams[mediaSSRC] = stream + r.mu.Unlock() + + return interceptor.RTPWriterFunc( + func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) { + // Ignore non-media packets + if header.SSRC != mediaSSRC { + return writer.Write(header, payload, attributes) + } + + var fecPackets []rtp.Packet + stream.mu.Lock() + stream.packetBuffer = append(stream.packetBuffer, rtp.Packet{ + Header: *header, + Payload: payload, + }) + + // Check if we have enough packets to generate FEC + if len(stream.packetBuffer) == int(r.numMediaPackets) { + fecPackets = stream.flexFecEncoder.EncodeFec(stream.packetBuffer, r.numFecPackets) + // Reset the packet buffer now that we've sent the corresponding FEC packets. + stream.packetBuffer = nil + } + stream.mu.Unlock() + + var errs []error + result, err := writer.Write(header, payload, attributes) + if err != nil { + errs = append(errs, err) + } + + for _, packet := range fecPackets { + header := packet.Header + + _, err = writer.Write(&header, packet.Payload, attributes) + if err != nil { + errs = append(errs, err) + } + } + + return result, errors.Join(errs...) + }, + ) +} diff --git a/vendor/github.com/pion/interceptor/pkg/flexfec/flexfec_coverage.go b/vendor/github.com/pion/interceptor/pkg/flexfec/flexfec_coverage.go new file mode 100644 index 0000000..9660f5a --- /dev/null +++ b/vendor/github.com/pion/interceptor/pkg/flexfec/flexfec_coverage.go @@ -0,0 +1,177 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package flexfec + +import ( + "github.com/pion/interceptor/pkg/flexfec/util" + "github.com/pion/rtp" +) + +// Maximum number of media packets that can be protected by a single FEC packet. +// We are not supporting the possibility of having an FEC packet protect multiple +// SSRC source packets for now. +// https://datatracker.ietf.org/doc/html/rfc8627#section-4.2.2.1 +const ( + MaxMediaPackets uint32 = 110 + MaxFecPackets uint32 = MaxMediaPackets +) + +// ProtectionCoverage defines the map of RTP packets that individual Fec packets protect. +type ProtectionCoverage struct { + // Array of masks, each mask capable of covering up to maxMediaPkts = 110. + // A mask is represented as a grouping of bytes where each individual bit + // represents the coverage for the media packet at the corresponding index. + packetMasks [MaxFecPackets]util.BitArray + numFecPackets uint32 + numMediaPackets uint32 + mediaPackets []rtp.Packet +} + +// NewCoverage returns a new ProtectionCoverage object. numFecPackets represents the number of +// Fec packets that we will be generating to cover the list of mediaPackets. This allows us to know +// how big the underlying map should be. +func NewCoverage(mediaPackets []rtp.Packet, numFecPackets uint32) *ProtectionCoverage { + numMediaPackets := uint32(len(mediaPackets)) //nolint:gosec // G115 + + // Basic sanity checks + if numMediaPackets <= 0 || numMediaPackets > MaxMediaPackets { + return nil + } + + // We allocate the biggest array of bitmasks that respects the max constraints. + var packetMasks [MaxFecPackets]util.BitArray + for i := 0; i < int(MaxFecPackets); i++ { + packetMasks[i] = util.BitArray{} + } + + coverage := &ProtectionCoverage{ + packetMasks: packetMasks, + numFecPackets: 0, + numMediaPackets: 0, + mediaPackets: nil, + } + + coverage.UpdateCoverage(mediaPackets, numFecPackets) + + return coverage +} + +// UpdateCoverage updates the ProtectionCoverage object with new bitmasks accounting for the numFecPackets +// we want to use to protect the batch media packets. +func (p *ProtectionCoverage) UpdateCoverage(mediaPackets []rtp.Packet, numFecPackets uint32) { + numMediaPackets := uint32(len(mediaPackets)) //nolint:gosec // G115 + + // Basic sanity checks + if numMediaPackets <= 0 || numMediaPackets > MaxMediaPackets { + return + } + + p.mediaPackets = mediaPackets + + if numFecPackets == p.numFecPackets && numMediaPackets == p.numMediaPackets { + // We have the same number of FEC packets covering the same number of media packets, we can simply + // reuse the previous coverage map with the updated media packets. + return + } + + p.numFecPackets = numFecPackets + p.numMediaPackets = numMediaPackets + + // The number of FEC packets and/or the number of packets has changed, we need to update the coverage map + // to reflect these new values. + p.resetCoverage() + + // Generate FEC bit mask where numFecPackets FEC packets are covering numMediaPackets Media packets. + // In the packetMasks array, each FEC packet is represented by a single BitArray, each bit in a given BitArray + // corresponds to a specific Media packet. + // Ex: Row I, Col J is set to 1 -> FEC packet I will protect media packet J. + for fecPacketIndex := uint32(0); fecPacketIndex < numFecPackets; fecPacketIndex++ { + // We use an interleaved method to determine coverage. Given N FEC packets, Media packet X will be + // covered by FEC packet X % N. + coveredMediaPacketIndex := fecPacketIndex + for coveredMediaPacketIndex < numMediaPackets { + p.packetMasks[fecPacketIndex].SetBit(coveredMediaPacketIndex) + coveredMediaPacketIndex += numFecPackets + } + } +} + +// ResetCoverage clears the underlying map so that we can reuse it for new batches of RTP packets. +func (p *ProtectionCoverage) resetCoverage() { + for i := uint32(0); i < MaxFecPackets; i++ { + p.packetMasks[i].Reset() + } +} + +// GetCoveredBy returns an iterator over RTP packets that are protected by the specified Fec packet index. +func (p *ProtectionCoverage) GetCoveredBy(fecPacketIndex uint32) *util.MediaPacketIterator { + coverage := make([]uint32, 0, p.numMediaPackets) + for mediaPacketIndex := uint32(0); mediaPacketIndex < p.numMediaPackets; mediaPacketIndex++ { + if p.packetMasks[fecPacketIndex].GetBit(mediaPacketIndex) == 1 { + coverage = append(coverage, mediaPacketIndex) + } + } + + return util.NewMediaPacketIterator(p.mediaPackets, coverage) +} + +// ExtractMask1 returns the first section of the bitmask as defined by the FEC header. +// https://datatracker.ietf.org/doc/html/rfc8627#section-4.2.2.1 +func (p *ProtectionCoverage) ExtractMask1(fecPacketIndex uint32) uint16 { + return extractMask1(p.packetMasks[fecPacketIndex]) +} + +// ExtractMask2 returns the second section of the bitmask as defined by the FEC header. +// https://datatracker.ietf.org/doc/html/rfc8627#section-4.2.2.1 +func (p *ProtectionCoverage) ExtractMask2(fecPacketIndex uint32) uint32 { + return extractMask2(p.packetMasks[fecPacketIndex]) +} + +// ExtractMask3 returns the third section of the bitmask as defined by the FEC header. +// https://datatracker.ietf.org/doc/html/rfc8627#section-4.2.2.1 +func (p *ProtectionCoverage) ExtractMask3(fecPacketIndex uint32) uint64 { + return extractMask3(p.packetMasks[fecPacketIndex]) +} + +// ExtractMask3_03 returns the third section of the bitmask as defined by the FEC header. +// https://datatracker.ietf.org/doc/html/draft-ietf-payload-flexible-fec-scheme-03#section-4.2 +func (p *ProtectionCoverage) ExtractMask3_03(fecPacketIndex uint32) uint64 { + return extractMask3_03(p.packetMasks[fecPacketIndex]) +} + +func extractMask1(mask util.BitArray) uint16 { + // We get the first 16 bits (64 - 16 -> shift by 48) and we shift once more for K field + mask1 := mask.Lo >> 49 + + return uint16(mask1) //nolint:gosec // G115 +} + +func extractMask2(mask util.BitArray) uint32 { + // We remove the first 15 bits + mask2 := mask.Lo << 15 + // We get the first 31 bits (64 - 32 -> shift by 32) and we shift once more for K field + mask2 >>= 33 + + return uint32(mask2) //nolint:gosec +} + +func extractMask3(mask util.BitArray) uint64 { + // We remove the first 46 bits + maskLo := mask.Lo << 46 + maskHi := mask.Hi >> 18 + mask3 := maskLo | maskHi + + return mask3 +} + +func extractMask3_03(mask util.BitArray) uint64 { + // We remove the first 46 bits + maskLo := mask.Lo << 46 + maskHi := mask.Hi >> 18 + mask3 := maskLo | maskHi + // We shift once for the K bit. + mask3 >>= 1 + + return mask3 +} diff --git a/vendor/github.com/pion/interceptor/pkg/flexfec/flexfec_decoder_03.go b/vendor/github.com/pion/interceptor/pkg/flexfec/flexfec_decoder_03.go new file mode 100644 index 0000000..2706a2c --- /dev/null +++ b/vendor/github.com/pion/interceptor/pkg/flexfec/flexfec_decoder_03.go @@ -0,0 +1,432 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +// Package flexfec implements FlexFEC-03 to recover missing RTP packets due to packet loss. +// https://datatracker.ietf.org/doc/html/draft-ietf-payload-flexible-fec-scheme-03 +package flexfec + +import ( + "encoding/binary" + "errors" + "fmt" + "sort" + + "github.com/pion/logging" + "github.com/pion/rtp" +) + +// Static errors for the flexfec package. +var ( + errPacketTruncated = errors.New("packet truncated") + errRetransmissionBitSet = errors.New("packet with retransmission bit set not supported") + errInflexibleGeneratorMatrix = errors.New("packet with inflexible generator matrix not supported") + errMultipleSSRCProtection = errors.New("multiple ssrc protection not supported") + errLastOptionalMaskKBitSetToFalse = errors.New("k-bit of last optional mask is set to false") +) + +// fecDecoder is a WIP implementation decoder used for testing purposes. +type fecDecoder struct { + logger logging.LeveledLogger + ssrc uint32 + protectedStreamSSRC uint32 + maxMediaPackets int + maxFECPackets int + recoveredPackets []rtp.Packet + receivedFECPackets []fecPacketState +} + +func newFECDecoder(ssrc uint32, protectedStreamSSRC uint32, loggerFactory logging.LoggerFactory) *fecDecoder { + return &fecDecoder{ + logger: loggerFactory.NewLogger("fec_decoder"), + ssrc: ssrc, + protectedStreamSSRC: protectedStreamSSRC, + maxMediaPackets: 100, + maxFECPackets: 100, + recoveredPackets: make([]rtp.Packet, 0), + receivedFECPackets: make([]fecPacketState, 0), + } +} + +func (d *fecDecoder) DecodeFec(receivedPacket rtp.Packet) []rtp.Packet { + if len(d.recoveredPackets) == d.maxMediaPackets { + backRecoveredPacket := d.recoveredPackets[len(d.recoveredPackets)-1] + if backRecoveredPacket.SSRC == receivedPacket.SSRC { + seqDiffVal := seqDiff(receivedPacket.SequenceNumber, backRecoveredPacket.SequenceNumber) + if seqDiffVal > uint16(d.maxMediaPackets) { //nolint:gosec + d.logger.Info("big gap in media sequence numbers - resetting buffers") + d.recoveredPackets = nil + d.receivedFECPackets = nil + } + } + } + + d.insertPacket(receivedPacket) + + return d.attemptRecovery() +} + +func (d *fecDecoder) insertPacket(receivedPkt rtp.Packet) { + // Discard old FEC packets such that the sequence numbers in + // `received_fec_packets_` span at most 1/2 of the sequence number space. + // This is important for keeping `received_fec_packets_` sorted, and may + // also reduce the possibility of incorrect decoding due to sequence number + // wrap-around. + if len(d.receivedFECPackets) > 0 && receivedPkt.SSRC == d.ssrc { + toRemove := 0 + for _, fecPkt := range d.receivedFECPackets { + if abs(int(receivedPkt.SequenceNumber)-int(fecPkt.packet.SequenceNumber)) > 0x3fff { + toRemove++ + } else { + // No need to keep iterating, since |received_fec_packets_| is sorted. + break + } + } + if toRemove > 0 { + d.receivedFECPackets = d.receivedFECPackets[toRemove:] + } + } + + switch receivedPkt.SSRC { + case d.ssrc: + d.insertFECPacket(receivedPkt) + case d.protectedStreamSSRC: + d.insertMediaPacket(receivedPkt) + } + + d.discardOldRecoveredPackets() +} + +func (d *fecDecoder) insertMediaPacket(receivedPkt rtp.Packet) { + for _, recoveredPacket := range d.recoveredPackets { + if recoveredPacket.SequenceNumber == receivedPkt.SequenceNumber { + return + } + } + + d.recoveredPackets = append(d.recoveredPackets, receivedPkt) + sort.Slice(d.recoveredPackets, func(i, j int) bool { + return isNewerSeq(d.recoveredPackets[i].SequenceNumber, d.recoveredPackets[j].SequenceNumber) + }) + d.updateCoveringFecPackets(receivedPkt) +} + +func (d *fecDecoder) updateCoveringFecPackets(receivedPkt rtp.Packet) { + for _, fecPkt := range d.receivedFECPackets { + for _, protectedPacket := range fecPkt.protectedPackets { + if protectedPacket.seq == receivedPkt.SequenceNumber { + protectedPacket.packet = &receivedPkt + } + } + } +} + +func (d *fecDecoder) insertFECPacket(fecPkt rtp.Packet) { //nolint:cyclop + for _, existingFECPacket := range d.receivedFECPackets { + if existingFECPacket.packet.SequenceNumber == fecPkt.SequenceNumber { + return + } + } + + fec, err := parseFlexFEC03Header(fecPkt.Payload) + if err != nil { + d.logger.Errorf("failed to parse flexfec03 header: %v", err) + + return + } + + if fec.protectedSSRC != d.protectedStreamSSRC { + d.logger.Errorf("fec is protecting unknown ssrc, expected %d, got %d", fec.protectedSSRC, d.protectedStreamSSRC) + + return + } + + protectedSeqs := decodeMask(uint64(fec.mask0), 15, fec.seqNumBase) + if fec.mask1 != 0 { + protectedSeqs = append(protectedSeqs, decodeMask(uint64(fec.mask1), 31, fec.seqNumBase+15)...) + } + if fec.mask2 != 0 { + protectedSeqs = append(protectedSeqs, decodeMask(fec.mask2, 63, fec.seqNumBase+46)...) + } + + if len(protectedSeqs) == 0 { + d.logger.Warn("empty fec packet mask") + + return + } + + protectedPackets := make([]*protectedPacket, 0, len(protectedSeqs)) + protectedSeqIt := 0 + recoveredPacketIt := 0 + + for protectedSeqIt < len(protectedSeqs) && recoveredPacketIt < len(d.recoveredPackets) { + switch { + case isNewerSeq(protectedSeqs[protectedSeqIt], d.recoveredPackets[recoveredPacketIt].SequenceNumber): + protectedPackets = append(protectedPackets, &protectedPacket{ + seq: protectedSeqs[protectedSeqIt], + packet: nil, + }) + protectedSeqIt++ + case isNewerSeq(d.recoveredPackets[recoveredPacketIt].SequenceNumber, protectedSeqs[protectedSeqIt]): + recoveredPacketIt++ + default: + protectedPackets = append(protectedPackets, &protectedPacket{ + seq: protectedSeqs[protectedSeqIt], + packet: &d.recoveredPackets[recoveredPacketIt], + }) + protectedSeqIt++ + recoveredPacketIt++ + } + } + + for protectedSeqIt < len(protectedSeqs) { + protectedPackets = append(protectedPackets, &protectedPacket{ + seq: protectedSeqs[protectedSeqIt], + packet: nil, + }) + protectedSeqIt++ + } + d.receivedFECPackets = append(d.receivedFECPackets, fecPacketState{ + packet: fecPkt, + flexFec: fec, + protectedPackets: protectedPackets, + }) + + sort.Slice(d.receivedFECPackets, func(i, j int) bool { + return isNewerSeq(d.receivedFECPackets[i].packet.SequenceNumber, d.receivedFECPackets[j].packet.SequenceNumber) + }) + + if len(d.receivedFECPackets) > d.maxFECPackets { + d.receivedFECPackets = d.receivedFECPackets[1:] + } +} + +func (d *fecDecoder) attemptRecovery() []rtp.Packet { + recoveredPackets := make([]rtp.Packet, 0) + for { + packetsRecovered := 0 + for _, fecPkt := range d.receivedFECPackets { + packetsMissing := 0 + for _, pkt := range fecPkt.protectedPackets { + if pkt.packet == nil { + packetsMissing++ + if packetsMissing > 1 { + break + } + } + } + + if packetsMissing != 1 { + continue + } + + recovered, err := d.recoverPacket(&fecPkt) //nolint:gosec + if err != nil { + d.logger.Errorf("failed to recover packet: %v", err) + } + + recoveredPackets = append(recoveredPackets, recovered) + d.recoveredPackets = append(d.recoveredPackets, recovered) + sort.Slice(d.recoveredPackets, func(i, j int) bool { + return isNewerSeq(d.recoveredPackets[i].SequenceNumber, d.recoveredPackets[j].SequenceNumber) + }) + + d.updateCoveringFecPackets(recovered) + d.discardOldRecoveredPackets() + packetsRecovered++ + } + + if packetsRecovered == 0 { + break + } + } + + return recoveredPackets +} + +func (d *fecDecoder) recoverPacket(fec *fecPacketState) (rtp.Packet, error) { + // https://datatracker.ietf.org/doc/html/draft-ietf-payload-flexible-fec-scheme-03#section-6.3.2 + + // 2. For the repair packet in T, extract the FEC bit string as the + // first 80 bits of the FEC header. + headerRecovery := make([]byte, 12) + copy(headerRecovery, fec.packet.Payload[:10]) + + var seqnum uint16 + for _, protectedPacket := range fec.protectedPackets { + if protectedPacket.packet != nil { + // 1. For each of the source packets that are successfully received in + // T, compute the 80-bit string by concatenating the first 64 bits + // of their RTP header and the unsigned network-ordered 16-bit + // representation of their length in bytes minus 12. + receivedHeader, err := protectedPacket.packet.Header.Marshal() + if err != nil { + return rtp.Packet{}, fmt.Errorf("marshal received header: %w", err) + } + binary.BigEndian.PutUint16(receivedHeader[2:4], uint16(protectedPacket.packet.MarshalSize()-12)) //nolint:gosec + for i := 0; i < 8; i++ { + headerRecovery[i] ^= receivedHeader[i] + } + } else { + seqnum = protectedPacket.seq + } + } + + // set version to 2 + headerRecovery[0] |= 0x80 + headerRecovery[0] &= 0xbf + payloadLength := binary.BigEndian.Uint16(headerRecovery[2:4]) + binary.BigEndian.PutUint16(headerRecovery[2:4], seqnum) + binary.BigEndian.PutUint32(headerRecovery[8:12], d.protectedStreamSSRC) + + payloadRecovery := make([]byte, payloadLength) + copy(payloadRecovery, fec.flexFec.payload) + for _, protectedPacket := range fec.protectedPackets { + if protectedPacket.packet != nil { + packet, err := protectedPacket.packet.Marshal() + if err != nil { + return rtp.Packet{}, fmt.Errorf("marshal protected packet: %w", err) + } + for i := 0; i < min(int(payloadLength), len(packet)-12); i++ { + payloadRecovery[i] ^= packet[12+i] + } + } + } + + headerRecovery = append(headerRecovery, payloadRecovery...) //nolint:makezero + + var packet rtp.Packet + err := packet.Unmarshal(headerRecovery) + if err != nil { + return rtp.Packet{}, fmt.Errorf("unmarshal recovered: %w", err) + } + + return packet, nil +} + +func (d *fecDecoder) discardOldRecoveredPackets() { + const limit = 192 + if len(d.recoveredPackets) > limit { + d.recoveredPackets = d.recoveredPackets[len(d.recoveredPackets)-192:] + } +} + +func decodeMask(mask uint64, bitCount uint16, seqNumBase uint16) []uint16 { + res := make([]uint16, 0) + for i := uint16(0); i < bitCount; i++ { + if (mask>>(bitCount-1-i))&1 == 1 { + res = append(res, seqNumBase+i) + } + } + + return res +} + +type fecPacketState struct { + packet rtp.Packet + flexFec flexFec + protectedPackets []*protectedPacket +} + +type flexFec struct { + protectedSSRC uint32 + seqNumBase uint16 + mask0 uint16 + mask1 uint32 + mask2 uint64 + payload []byte +} + +type protectedPacket struct { + seq uint16 + packet *rtp.Packet +} + +func parseFlexFEC03Header(data []byte) (flexFec, error) { + if len(data) < 20 { + return flexFec{}, fmt.Errorf("%w: length %d", errPacketTruncated, len(data)) + } + + rBit := (data[0] & 0x80) != 0 + if rBit { + return flexFec{}, errRetransmissionBitSet + } + + fBit := (data[0] & 0x40) != 0 + if fBit { + return flexFec{}, errInflexibleGeneratorMatrix + } + + ssrcCount := data[8] + if ssrcCount != 1 { + return flexFec{}, fmt.Errorf("%w: count %d", errMultipleSSRCProtection, ssrcCount) + } + + protectedSSRC := binary.BigEndian.Uint32(data[12:]) + seqNumBase := binary.BigEndian.Uint16(data[16:]) + rawPacketMask := data[18:] + var payload []byte + + kBit0 := (rawPacketMask[0] & 0x80) != 0 + maskPart0 := binary.BigEndian.Uint16(rawPacketMask[0:2]) & 0x7FFF + var maskPart1 uint32 + var maskPart2 uint64 + + if kBit0 { //nolint:nestif + payload = rawPacketMask[2:] + } else { + if len(data) < 24 { + return flexFec{}, fmt.Errorf("%w: length %d", errPacketTruncated, len(data)) + } + + kBit1 := (rawPacketMask[2] & 0x80) != 0 + maskPart1 = binary.BigEndian.Uint32(rawPacketMask[2:]) & 0x7FFFFFFF + + if kBit1 { + payload = rawPacketMask[6:] + } else { + if len(data) < 32 { + return flexFec{}, fmt.Errorf("%w: length %d", errPacketTruncated, len(data)) + } + + kBit2 := (rawPacketMask[6] & 0x80) != 0 + maskPart2 = binary.BigEndian.Uint64(rawPacketMask[6:]) & 0x7FFFFFFFFFFFFFFF + + if kBit2 { + payload = rawPacketMask[14:] + } else { + return flexFec{}, errLastOptionalMaskKBitSetToFalse + } + } + } + + return flexFec{ + protectedSSRC: protectedSSRC, + seqNumBase: seqNumBase, + mask0: maskPart0, + mask1: maskPart1, + mask2: maskPart2, + payload: payload, + }, nil +} + +func seqDiff(a, b uint16) uint16 { + return min(a-b, b-a) +} + +func abs(x int) int { + if x >= 0 { + return x + } + + return -x +} + +func isNewerSeq(prevValue, value uint16) bool { + // half-way mark + breakpoint := uint16(0x8000) + if value-prevValue == breakpoint { + return value > prevValue + } + + return value != prevValue && (value-prevValue) < breakpoint +} diff --git a/vendor/github.com/pion/interceptor/pkg/flexfec/flexfec_encoder.go b/vendor/github.com/pion/interceptor/pkg/flexfec/flexfec_encoder.go new file mode 100644 index 0000000..cec1d59 --- /dev/null +++ b/vendor/github.com/pion/interceptor/pkg/flexfec/flexfec_encoder.go @@ -0,0 +1,209 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +// Package flexfec implements FlexFEC to recover missing RTP packets due to packet loss. +// https://datatracker.ietf.org/doc/html/rfc8627 +package flexfec + +import ( + "encoding/binary" + + "github.com/pion/interceptor/pkg/flexfec/util" + "github.com/pion/rtp" +) + +const ( + // BaseRTPHeaderSize represents the minium RTP packet header size in bytes. + BaseRTPHeaderSize = 12 + // BaseFecHeaderSize represents the minium FEC payload's header size including the + // required first mask. + BaseFecHeaderSize = 12 +) + +// EncoderFactory is an interface for generic FEC encoders. +type EncoderFactory interface { + NewEncoder(payloadType uint8, ssrc uint32) FlexEncoder +} + +// FlexEncoder is the interface that FecInterceptor uses to encode Fec packets. +type FlexEncoder interface { + EncodeFec(mediaPackets []rtp.Packet, numFecPackets uint32) []rtp.Packet +} + +// FlexEncoder20 implementation is WIP, contains bugs and no tests. Check out FlexEncoder03. +type FlexEncoder20 struct { + fecBaseSn uint16 + payloadType uint8 + ssrc uint32 + coverage *ProtectionCoverage +} + +// NewFlexEncoder returns a new FlexEncoder20. +// FlexEncoder20 implementation is WIP, contains bugs and no tests. Check out FlexEncoder03. +func NewFlexEncoder(payloadType uint8, ssrc uint32) *FlexEncoder20 { + return &FlexEncoder20{ + payloadType: payloadType, + ssrc: ssrc, + fecBaseSn: uint16(1000), + } +} + +// EncodeFec returns a list of generated RTP packets with FEC payloads that protect the specified mediaPackets. +// This method does not account for missing RTP packets in the mediaPackets array nor does it account for +// them being passed out of order. +func (flex *FlexEncoder20) EncodeFec(mediaPackets []rtp.Packet, numFecPackets uint32) []rtp.Packet { + // Start by defining which FEC packets cover which media packets + if flex.coverage == nil { + flex.coverage = NewCoverage(mediaPackets, numFecPackets) + } else { + flex.coverage.UpdateCoverage(mediaPackets, numFecPackets) + } + + if flex.coverage == nil { + return nil + } + + // Generate FEC payloads + fecPackets := make([]rtp.Packet, numFecPackets) + for fecPacketIndex := uint32(0); fecPacketIndex < numFecPackets; fecPacketIndex++ { + fecPackets[fecPacketIndex] = flex.encodeFlexFecPacket(fecPacketIndex, mediaPackets[0].SequenceNumber) + } + + return fecPackets +} + +func (flex *FlexEncoder20) encodeFlexFecPacket(fecPacketIndex uint32, mediaBaseSn uint16) rtp.Packet { + mediaPacketsIt := flex.coverage.GetCoveredBy(fecPacketIndex) + flexFecHeader := flex.encodeFlexFecHeader( + mediaPacketsIt, + flex.coverage.ExtractMask1(fecPacketIndex), + flex.coverage.ExtractMask2(fecPacketIndex), + flex.coverage.ExtractMask3(fecPacketIndex), + mediaBaseSn, + ) + flexFecRepairPayload := flex.encodeFlexFecRepairPayload(mediaPacketsIt.Reset()) + + packet := rtp.Packet{ + Header: rtp.Header{ + Version: 2, + Padding: false, + Extension: false, + Marker: false, + PayloadType: flex.payloadType, + SequenceNumber: flex.fecBaseSn, + Timestamp: 54243243, + SSRC: flex.ssrc, + CSRC: []uint32{}, + }, + Payload: append(flexFecHeader, flexFecRepairPayload...), + } + flex.fecBaseSn++ + + return packet +} + +func (flex *FlexEncoder20) encodeFlexFecHeader( + mediaPackets *util.MediaPacketIterator, + mask1 uint16, + optionalMask2 uint32, + optionalMask3 uint64, + mediaBaseSn uint16, +) []byte { + /* + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + |0|0|P|X| CC |M| PT recovery | length recovery | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | TS recovery | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | SN base_i |k| Mask [0-14] | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + |k| Mask [15-45] (optional) | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Mask [46-109] (optional) | + | | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | ... next SN base and Mask for CSRC_i in CSRC list ... | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + : Repair "Payload" follows FEC Header : + : : + */ + + // Get header size - This depends on the size of the bitmask. + headerSize := BaseFecHeaderSize + if optionalMask2 > 0 { + headerSize += 4 + } + if optionalMask3 > 0 { + headerSize += 8 + } + + // Allocate the FlexFec header + flexFecHeader := make([]byte, headerSize) + + // XOR the relevant fields for the header + // TO DO - CHECK TO SEE IF THE MARSHALTO() call works with this. + tmpMediaPacketBuf := make([]byte, headerSize) + for mediaPackets.HasNext() { + mediaPacket := mediaPackets.Next() + n, err := mediaPacket.MarshalTo(tmpMediaPacketBuf) + + if n == 0 || err != nil { + return nil + } + + // XOR the first 2 bytes of the header: V, P, X, CC, M, PT fields + flexFecHeader[0] ^= tmpMediaPacketBuf[0] + flexFecHeader[1] ^= tmpMediaPacketBuf[1] + + // XOR the length recovery field + lengthRecoveryVal := uint16(mediaPacket.MarshalSize() - BaseRTPHeaderSize) //nolint:gosec // G115 + flexFecHeader[2] ^= uint8(lengthRecoveryVal >> 8) //nolint:gosec // G115 + flexFecHeader[3] ^= uint8(lengthRecoveryVal) //nolint:gosec // G115 + + // XOR the 5th to 8th bytes of the header: the timestamp field + flexFecHeader[4] ^= flexFecHeader[4] + flexFecHeader[5] ^= flexFecHeader[5] + flexFecHeader[6] ^= flexFecHeader[6] + flexFecHeader[7] ^= flexFecHeader[7] + } + + // Write the base SN for the batch of media packets + binary.BigEndian.PutUint16(flexFecHeader[8:10], mediaBaseSn) + + // Write the bitmasks to the header + binary.BigEndian.PutUint16(flexFecHeader[10:12], mask1) + + if optionalMask2 > 0 { + binary.BigEndian.PutUint32(flexFecHeader[12:16], optionalMask2) + flexFecHeader[10] |= 0b10000000 + } + if optionalMask3 > 0 { + binary.BigEndian.PutUint64(flexFecHeader[16:24], optionalMask3) + flexFecHeader[12] |= 0b10000000 + } + + return flexFecHeader +} + +func (flex *FlexEncoder20) encodeFlexFecRepairPayload(mediaPackets *util.MediaPacketIterator) []byte { + flexFecPayload := make([]byte, len(mediaPackets.First().Payload)) + + for mediaPackets.HasNext() { + mediaPacketPayload := mediaPackets.Next().Payload + + if len(flexFecPayload) < len(mediaPacketPayload) { + // Expected FEC packet payload is bigger that what we can currently store, + // we need to resize. + flexFecPayloadTmp := make([]byte, len(mediaPacketPayload)) + copy(flexFecPayloadTmp, flexFecPayload) + flexFecPayload = flexFecPayloadTmp + } + for byteIndex := 0; byteIndex < len(mediaPacketPayload); byteIndex++ { + flexFecPayload[byteIndex] ^= mediaPacketPayload[byteIndex] + } + } + + return flexFecPayload +} diff --git a/vendor/github.com/pion/interceptor/pkg/flexfec/flexfec_encoder_03.go b/vendor/github.com/pion/interceptor/pkg/flexfec/flexfec_encoder_03.go new file mode 100644 index 0000000..b73416a --- /dev/null +++ b/vendor/github.com/pion/interceptor/pkg/flexfec/flexfec_encoder_03.go @@ -0,0 +1,241 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +// Package flexfec implements FlexFEC to recover missing RTP packets due to packet loss. +// https://datatracker.ietf.org/doc/html/rfc8627 +package flexfec + +import ( + "encoding/binary" + "sync" + + "github.com/pion/rtp" +) + +const ( + // BaseFec03HeaderSize represents the minium FEC payload's header size including the + // required first mask. + BaseFec03HeaderSize = 20 + // maxRTPPacketSize represents the maximum size of an RTP packet buffer. + // This is a reasonable upper bound for typical RTP packets. + maxRTPPacketSize = 1500 +) + +var bufferPool = sync.Pool{ //nolint:gochecknoglobals + New: func() any { + b := make([]byte, maxRTPPacketSize) + + return &b + }, +} + +// FlexEncoder03 implements the Fec encoding mechanism for the "Flex" variant of FlexFec. +type FlexEncoder03 struct { + fecBaseSn uint16 + payloadType uint8 + ssrc uint32 + coverage *ProtectionCoverage +} + +// FlexEncoder03Factory is a factory for FlexFEC-03 encoders. +type FlexEncoder03Factory struct{} + +// NewEncoder creates new FlexFEC-03 encoder. +func (f FlexEncoder03Factory) NewEncoder(payloadType uint8, ssrc uint32) FlexEncoder { + return NewFlexEncoder03(payloadType, ssrc) +} + +// NewFlexEncoder03 creates new FlexFEC-03 encoder. +func NewFlexEncoder03(payloadType uint8, ssrc uint32) *FlexEncoder03 { + return &FlexEncoder03{ + payloadType: payloadType, + ssrc: ssrc, + fecBaseSn: uint16(1000), + } +} + +// EncodeFec returns a list of generated RTP packets with FEC payloads that protect the specified mediaPackets. +// This method returns nil in case of missing RTP packets in the mediaPackets array or packets passed out of order. +func (flex *FlexEncoder03) EncodeFec(mediaPackets []rtp.Packet, numFecPackets uint32) []rtp.Packet { + // Check if mediaPackets is empty + if len(mediaPackets) == 0 { + return nil + } + + // Check if RTP packets are in order by comparing sequence numbers + for i := 1; i < len(mediaPackets); i++ { + if mediaPackets[i].SequenceNumber != mediaPackets[i-1].SequenceNumber+1 { + // Packets are not in order or there are missing packets + return nil + } + } + + // Start by defining which FEC packets cover which media packets + if flex.coverage == nil { + flex.coverage = NewCoverage(mediaPackets, numFecPackets) + } else { + flex.coverage.UpdateCoverage(mediaPackets, numFecPackets) + } + + if flex.coverage == nil { + return nil + } + + // Generate FEC payloads + fecPackets := make([]rtp.Packet, 0, numFecPackets) + for fecPacketIndex := uint32(0); fecPacketIndex < numFecPackets; fecPacketIndex++ { + fecPacket, ok := flex.encodeFlexFecPacket(fecPacketIndex, mediaPackets[0].SequenceNumber) + if ok { + fecPackets = append(fecPackets, fecPacket) + } + } + + return fecPackets +} + +//nolint:cyclop +func (flex *FlexEncoder03) encodeFlexFecPacket(fecPacketIndex uint32, mediaBaseSn uint16) (rtp.Packet, bool) { + mediaPackets := flex.coverage.GetCoveredBy(fecPacketIndex) + mask1 := flex.coverage.ExtractMask1(fecPacketIndex) + optionalMask2 := flex.coverage.ExtractMask2(fecPacketIndex) + optionalMask3 := flex.coverage.ExtractMask3_03(fecPacketIndex) + + /* + FlexFEC Header Format: + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + |0|0| P|X| CC |M| PT recovery | length recovery | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | TS recovery | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | SSRCCount | reserved | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | SSRC_i | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | SN base_i |k| Mask [0-14] | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + |k| Mask [15-45] (optional) | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + |k| | + +-+ Mask [46-108] (optional) | + | | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | ... next in SSRC_i ... | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + */ + + if !mediaPackets.HasNext() { + return rtp.Packet{}, false + } + + // Get header size - This depends on the size of the bitmask. + headerSize := BaseFec03HeaderSize + if optionalMask2 > 0 || optionalMask3 > 0 { + headerSize += 4 + } + if optionalMask3 > 0 { + headerSize += 8 + } + + // Find the maximum payload size among all media packets + maxPayloadSize := 0 + for mediaPackets.HasNext() { + maxPayloadSize = max(maxPayloadSize, mediaPackets.Next().MarshalSize()-BaseRTPHeaderSize) + } + mediaPackets.Reset() + + flexFecPayload := make([]byte, headerSize+maxPayloadSize) + + flexFecHeader := flexFecPayload[:headerSize] + flexFecRepairPayload := flexFecPayload[headerSize : headerSize+maxPayloadSize] + + bufferFromPool := bufferPool.Get().(*[]byte) //nolint:forcetypeassert + defer bufferPool.Put(bufferFromPool) + tmpMediaPacketBuf := *bufferFromPool + + for mediaPackets.HasNext() { + mediaPacket := mediaPackets.Next() + packetSize := mediaPacket.MarshalSize() + if packetSize > len(tmpMediaPacketBuf) { + // Packet is too large for our fixed buffer, fallback to dynamic allocation + tmpMediaPacketBuf = make([]byte, packetSize) + } + + n, err := mediaPacket.MarshalTo(tmpMediaPacketBuf[:packetSize]) + if n == 0 || err != nil { + return rtp.Packet{}, false + } + + // XOR the first 2 bytes of the header: V, P, X, CC, M, PT fields + flexFecHeader[0] ^= tmpMediaPacketBuf[0] + flexFecHeader[1] ^= tmpMediaPacketBuf[1] + + // Clear the first 2 bits + flexFecHeader[0] &= 0b00111111 + + // XOR the length recovery field + lengthRecoveryVal := uint16(mediaPacket.MarshalSize() - BaseRTPHeaderSize) //nolint:gosec // G115 + flexFecHeader[2] ^= uint8(lengthRecoveryVal >> 8) //nolint:gosec // G115 + flexFecHeader[3] ^= uint8(lengthRecoveryVal) //nolint:gosec // G115 + + // XOR the 5th to 8th bytes of the header: the timestamp field + flexFecHeader[4] ^= tmpMediaPacketBuf[4] + flexFecHeader[5] ^= tmpMediaPacketBuf[5] + flexFecHeader[6] ^= tmpMediaPacketBuf[6] + flexFecHeader[7] ^= tmpMediaPacketBuf[7] + + // Process FlexFEC Repair Payload (bytes after RTP header) + for byteIndex := 0; byteIndex < packetSize-BaseRTPHeaderSize; byteIndex++ { + flexFecRepairPayload[byteIndex] ^= tmpMediaPacketBuf[byteIndex+BaseRTPHeaderSize] + } + } + + // Write the SSRC count + flexFecHeader[8] = 1 + + // Write 0s in reserved + flexFecHeader[9] = 0 + flexFecHeader[10] = 0 + flexFecHeader[11] = 0 + + // Write the SSRC of media packets protected by this FEC packet + binary.BigEndian.PutUint32(flexFecHeader[12:16], mediaPackets.First().SSRC) + + // Write the base SN for the batch of media packets + binary.BigEndian.PutUint16(flexFecHeader[16:18], mediaBaseSn) + + // Write the bitmasks to the header + binary.BigEndian.PutUint16(flexFecHeader[18:20], mask1) + + if optionalMask2 == 0 && optionalMask3 == 0 { + flexFecHeader[18] |= 0b10000000 + } else { + binary.BigEndian.PutUint32(flexFecHeader[20:24], optionalMask2) + + if optionalMask3 == 0 { + flexFecHeader[20] |= 0b10000000 + } else { + binary.BigEndian.PutUint64(flexFecHeader[24:32], optionalMask3) + flexFecHeader[24] |= 0b10000000 + } + } + + packet := rtp.Packet{ + Header: rtp.Header{ + Version: 2, + Padding: false, + Extension: false, + Marker: false, + PayloadType: flex.payloadType, + SequenceNumber: flex.fecBaseSn, + Timestamp: 54243243, + SSRC: flex.ssrc, + CSRC: []uint32{}, + }, + Payload: flexFecPayload, + } + flex.fecBaseSn++ + + return packet, true +} diff --git a/vendor/github.com/pion/interceptor/pkg/flexfec/option.go b/vendor/github.com/pion/interceptor/pkg/flexfec/option.go new file mode 100644 index 0000000..344ebe5 --- /dev/null +++ b/vendor/github.com/pion/interceptor/pkg/flexfec/option.go @@ -0,0 +1,34 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package flexfec + +// FecOption can be used to set initial options on Fec encoder interceptors. +type FecOption func(d *FecInterceptor) error + +// NumMediaPackets sets the number of media packets to accumulate before generating another FEC packets batch. +func NumMediaPackets(numMediaPackets uint32) FecOption { + return func(f *FecInterceptor) error { + f.numMediaPackets = numMediaPackets + + return nil + } +} + +// NumFECPackets sets the number of FEC packets to generate for each batch of media packets. +func NumFECPackets(numFecPackets uint32) FecOption { + return func(f *FecInterceptor) error { + f.numFecPackets = numFecPackets + + return nil + } +} + +// FECEncoderFactory sets the custom factory for constructing the FEC Encoders. +func FECEncoderFactory(factory EncoderFactory) FecOption { + return func(f *FecInterceptor) error { + f.encoderFactory = factory + + return nil + } +} diff --git a/vendor/github.com/pion/interceptor/pkg/flexfec/util/bitarray.go b/vendor/github.com/pion/interceptor/pkg/flexfec/util/bitarray.go new file mode 100644 index 0000000..f3eb43a --- /dev/null +++ b/vendor/github.com/pion/interceptor/pkg/flexfec/util/bitarray.go @@ -0,0 +1,47 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +// Package util implements utilities to better support Fec decoding / encoding. +package util + +// BitArray provides support for bitmask manipulations. +type BitArray struct { + Lo uint64 // leftmost 64 bits + Hi uint64 // rightmost 64 bits +} + +// SetBit sets a bit to the specified bit value on the bitmask. +func (b *BitArray) SetBit(bitIndex uint32) { + if bitIndex < 64 { + b.Lo |= uint64(0b1) << (63 - bitIndex) + } else { + hiBitIndex := bitIndex - 64 + b.Hi |= uint64(0b1) << (63 - hiBitIndex) + } +} + +// Reset clears the bitmask. +func (b *BitArray) Reset() { + b.Lo = 0 + b.Hi = 0 +} + +// GetBit returns the bit value at a specified index of the bitmask. +func (b *BitArray) GetBit(bitIndex uint32) uint8 { + if bitIndex < 64 { + result := (b.Lo & (uint64(0b1) << (63 - bitIndex))) + if result > 0 { + return 1 + } + + return 0 + } + + hiBitIndex := bitIndex - 64 + result := (b.Hi & (uint64(0b1) << (63 - hiBitIndex))) + if result > 0 { + return 1 + } + + return 0 +} diff --git a/vendor/github.com/pion/interceptor/pkg/flexfec/util/media_packet_iterator.go b/vendor/github.com/pion/interceptor/pkg/flexfec/util/media_packet_iterator.go new file mode 100644 index 0000000..0813c9c --- /dev/null +++ b/vendor/github.com/pion/interceptor/pkg/flexfec/util/media_packet_iterator.go @@ -0,0 +1,57 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +// nolint: revive, staticcheck +package util + +import "github.com/pion/rtp" + +// MediaPacketIterator supports iterating through a list of media packets protected by +// a specific Fec packet. +type MediaPacketIterator struct { + mediaPackets []rtp.Packet + coveredIndices []uint32 + nextIndex int +} + +// NewMediaPacketIterator returns a new MediaPacketIterator. +func NewMediaPacketIterator(mediaPackets []rtp.Packet, coveredIndices []uint32) *MediaPacketIterator { + return &MediaPacketIterator{ + mediaPackets: mediaPackets, + coveredIndices: coveredIndices, + nextIndex: 0, + } +} + +// Reset sets the starting iterating index back to 0. +func (m *MediaPacketIterator) Reset() *MediaPacketIterator { + m.nextIndex = 0 + + return m +} + +// HasNext indicates whether or not there are more media packets +// that can be iterated through. +func (m *MediaPacketIterator) HasNext() bool { + return m.nextIndex < len(m.coveredIndices) +} + +// Next returns the next media packet to iterate through. +func (m *MediaPacketIterator) Next() *rtp.Packet { + if m.nextIndex == len(m.coveredIndices) { + return nil + } + packet := m.mediaPackets[m.coveredIndices[m.nextIndex]] + m.nextIndex++ + + return &packet +} + +// First returns the first media packet to iterate through. +func (m *MediaPacketIterator) First() *rtp.Packet { + if len(m.coveredIndices) == 0 { + return nil + } + + return &m.mediaPackets[m.coveredIndices[0]] +} diff --git a/vendor/github.com/pion/interceptor/pkg/nack/errors.go b/vendor/github.com/pion/interceptor/pkg/nack/errors.go new file mode 100644 index 0000000..bc1c432 --- /dev/null +++ b/vendor/github.com/pion/interceptor/pkg/nack/errors.go @@ -0,0 +1,9 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package nack + +import "github.com/pion/interceptor/internal/rtpbuffer" + +// ErrInvalidSize is returned by newReceiveLog/newRTPBuffer, when an incorrect buffer size is supplied. +var ErrInvalidSize = rtpbuffer.ErrInvalidSize diff --git a/vendor/github.com/pion/interceptor/pkg/nack/generator_interceptor.go b/vendor/github.com/pion/interceptor/pkg/nack/generator_interceptor.go new file mode 100644 index 0000000..9ed6d69 --- /dev/null +++ b/vendor/github.com/pion/interceptor/pkg/nack/generator_interceptor.go @@ -0,0 +1,244 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package nack + +import ( + "math/rand" + "slices" + "sync" + "time" + + "github.com/pion/interceptor" + "github.com/pion/logging" + "github.com/pion/rtcp" +) + +// GeneratorInterceptorFactory is a interceptor.Factory for a GeneratorInterceptor. +type GeneratorInterceptorFactory struct { + opts []GeneratorOption +} + +// NewInterceptor constructs a new ReceiverInterceptor. +func (g *GeneratorInterceptorFactory) NewInterceptor(_ string) (interceptor.Interceptor, error) { + generatorInterceptor := &GeneratorInterceptor{ + streamsFilter: streamSupportNack, + size: 512, + skipLastN: 0, + maxNacksPerPacket: 0, + interval: time.Millisecond * 100, + receiveLogs: map[uint32]*receiveLog{}, + nackCountLogs: map[uint32]map[uint16]uint16{}, + close: make(chan struct{}), + } + + for _, opt := range g.opts { + if err := opt(generatorInterceptor); err != nil { + return nil, err + } + } + + if generatorInterceptor.loggerFactory == nil { + generatorInterceptor.loggerFactory = logging.NewDefaultLoggerFactory() + } + if generatorInterceptor.log == nil { + generatorInterceptor.log = generatorInterceptor.loggerFactory.NewLogger("nack_generator") + } + + if _, err := newReceiveLog(generatorInterceptor.size); err != nil { + return nil, err + } + + return generatorInterceptor, nil +} + +// GeneratorInterceptor interceptor generates nack feedback messages. +type GeneratorInterceptor struct { + interceptor.NoOp + streamsFilter func(info *interceptor.StreamInfo) bool + size uint16 + skipLastN uint16 + maxNacksPerPacket uint16 + interval time.Duration + m sync.Mutex + wg sync.WaitGroup + close chan struct{} + log logging.LeveledLogger + loggerFactory logging.LoggerFactory + nackCountLogs map[uint32]map[uint16]uint16 + + receiveLogs map[uint32]*receiveLog + receiveLogsMu sync.Mutex +} + +// NewGeneratorInterceptor returns a new GeneratorInterceptorFactory. +func NewGeneratorInterceptor(opts ...GeneratorOption) (*GeneratorInterceptorFactory, error) { + return &GeneratorInterceptorFactory{opts}, nil +} + +// BindRTCPWriter lets you modify any outgoing RTCP packets. It is called once per PeerConnection. +// The returned method will be called once per packet batch. +func (n *GeneratorInterceptor) BindRTCPWriter(writer interceptor.RTCPWriter) interceptor.RTCPWriter { + n.m.Lock() + defer n.m.Unlock() + + if n.isClosed() { + return writer + } + + n.wg.Add(1) + + go n.loop(writer) + + return writer +} + +// BindRemoteStream lets you modify any incoming RTP packets. It is called once for per RemoteStream. +// The returned method will be called once per rtp packet. +func (n *GeneratorInterceptor) BindRemoteStream( + info *interceptor.StreamInfo, reader interceptor.RTPReader, +) interceptor.RTPReader { + if !n.streamsFilter(info) { + return reader + } + + // error is already checked in NewGeneratorInterceptor + receiveLog, _ := newReceiveLog(n.size) + n.receiveLogsMu.Lock() + n.receiveLogs[info.SSRC] = receiveLog + n.receiveLogsMu.Unlock() + + return interceptor.RTPReaderFunc(func(b []byte, a interceptor.Attributes) (int, interceptor.Attributes, error) { + i, attr, err := reader.Read(b, a) + if err != nil { + return 0, nil, err + } + + if attr == nil { + attr = make(interceptor.Attributes) + } + header, err := attr.GetRTPHeader(b[:i]) + if err != nil { + return 0, nil, err + } + receiveLog.add(header.SequenceNumber) + + return i, attr, nil + }) +} + +// UnbindRemoteStream is called when the Stream is removed. It can be used to clean up any data related to that track. +func (n *GeneratorInterceptor) UnbindRemoteStream(info *interceptor.StreamInfo) { + n.receiveLogsMu.Lock() + delete(n.receiveLogs, info.SSRC) + // the count logs must also be dropped for the specific SSRC. + delete(n.nackCountLogs, info.SSRC) + n.receiveLogsMu.Unlock() +} + +// Close closes the interceptor. +func (n *GeneratorInterceptor) Close() error { + defer n.wg.Wait() + n.m.Lock() + defer n.m.Unlock() + + if !n.isClosed() { + close(n.close) + } + + return nil +} + +// nolint:gocognit,cyclop +func (n *GeneratorInterceptor) loop(rtcpWriter interceptor.RTCPWriter) { + defer n.wg.Done() + + senderSSRC := rand.Uint32() // #nosec + + missingPacketSeqNums := make([]uint16, n.size) + filteredMissingPacket := make([]uint16, n.size) + + ticker := time.NewTicker(n.interval) + defer ticker.Stop() + for { + select { + case <-ticker.C: + // save NACKs to send without holding the mutex during Write + var toSend []rtcp.Packet + + n.receiveLogsMu.Lock() + for ssrc, receiveLog := range n.receiveLogs { + missing := receiveLog.missingSeqNumbers(n.skipLastN, missingPacketSeqNums) + + if len(missing) == 0 || n.nackCountLogs[ssrc] == nil { + n.nackCountLogs[ssrc] = map[uint16]uint16{} + } + if len(missing) == 0 { + continue + } + + var nack *rtcp.TransportLayerNack + + count := 0 + if n.maxNacksPerPacket > 0 { + for _, missingSeq := range missing { + if n.nackCountLogs[ssrc][missingSeq] < n.maxNacksPerPacket { + filteredMissingPacket[count] = missingSeq + count++ + } + n.nackCountLogs[ssrc][missingSeq]++ + } + + if count == 0 { + continue + } + + nack = &rtcp.TransportLayerNack{ + SenderSSRC: senderSSRC, + MediaSSRC: ssrc, + Nacks: rtcp.NackPairsFromSequenceNumbers(filteredMissingPacket[:count]), + } + } else { + nack = &rtcp.TransportLayerNack{ + SenderSSRC: senderSSRC, + MediaSSRC: ssrc, + Nacks: rtcp.NackPairsFromSequenceNumbers(missing), + } + } + + for nackSeq := range n.nackCountLogs[ssrc] { + if !slices.Contains(missing, nackSeq) { + delete(n.nackCountLogs[ssrc], nackSeq) + } + } + + // clean up the count log for the ssrc if it's empty + if len(n.nackCountLogs[ssrc]) == 0 { + delete(n.nackCountLogs, ssrc) + } + + toSend = append(toSend, nack) + } + n.receiveLogsMu.Unlock() + + // send RTCP without holding receiveLogsMu + for _, pkt := range toSend { + if _, err := rtcpWriter.Write([]rtcp.Packet{pkt}, interceptor.Attributes{}); err != nil { + n.log.Warnf("failed sending nack: %+v", err) + } + } + + case <-n.close: + return + } + } +} + +func (n *GeneratorInterceptor) isClosed() bool { + select { + case <-n.close: + return true + default: + return false + } +} diff --git a/vendor/github.com/pion/interceptor/pkg/nack/generator_option.go b/vendor/github.com/pion/interceptor/pkg/nack/generator_option.go new file mode 100644 index 0000000..9a80d7b --- /dev/null +++ b/vendor/github.com/pion/interceptor/pkg/nack/generator_option.go @@ -0,0 +1,81 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package nack + +import ( + "time" + + "github.com/pion/interceptor" + "github.com/pion/logging" +) + +// GeneratorOption can be used to configure GeneratorInterceptor. +type GeneratorOption func(r *GeneratorInterceptor) error + +// GeneratorSize sets the size of the interceptor. +// Size must be one of: 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768. +func GeneratorSize(size uint16) GeneratorOption { + return func(r *GeneratorInterceptor) error { + r.size = size + + return nil + } +} + +// GeneratorSkipLastN sets the number of packets (n-1 packets before the last received packets) +// +// to ignore when generating nack requests. +func GeneratorSkipLastN(skipLastN uint16) GeneratorOption { + return func(r *GeneratorInterceptor) error { + r.skipLastN = skipLastN + + return nil + } +} + +// GeneratorMaxNacksPerPacket sets the maximum number of NACKs sent per missing packet, e.g. if set to 2, a missing +// packet will only be NACKed at most twice. If set to 0 (default), max number of NACKs is unlimited. +func GeneratorMaxNacksPerPacket(maxNacks uint16) GeneratorOption { + return func(r *GeneratorInterceptor) error { + r.maxNacksPerPacket = maxNacks + + return nil + } +} + +// GeneratorLog sets a logger for the interceptor. +func GeneratorLog(log logging.LeveledLogger) GeneratorOption { + return func(r *GeneratorInterceptor) error { + r.log = log + + return nil + } +} + +// WithGeneratorLoggerFactory sets a logger factory for the interceptor. +func WithGeneratorLoggerFactory(loggerFactory logging.LoggerFactory) GeneratorOption { + return func(r *GeneratorInterceptor) error { + r.loggerFactory = loggerFactory + + return nil + } +} + +// GeneratorInterval sets the nack send interval for the interceptor. +func GeneratorInterval(interval time.Duration) GeneratorOption { + return func(r *GeneratorInterceptor) error { + r.interval = interval + + return nil + } +} + +// GeneratorStreamsFilter sets filter for generator streams. +func GeneratorStreamsFilter(filter func(info *interceptor.StreamInfo) bool) GeneratorOption { + return func(r *GeneratorInterceptor) error { + r.streamsFilter = filter + + return nil + } +} diff --git a/vendor/github.com/pion/interceptor/pkg/nack/nack.go b/vendor/github.com/pion/interceptor/pkg/nack/nack.go new file mode 100644 index 0000000..6d560a2 --- /dev/null +++ b/vendor/github.com/pion/interceptor/pkg/nack/nack.go @@ -0,0 +1,17 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +// Package nack provides interceptors to implement sending and receiving negative acknowledgements +package nack + +import "github.com/pion/interceptor" + +func streamSupportNack(info *interceptor.StreamInfo) bool { + for _, fb := range info.RTCPFeedback { + if fb.Type == "nack" && fb.Parameter == "" { + return true + } + } + + return false +} diff --git a/vendor/github.com/pion/interceptor/pkg/nack/receive_log.go b/vendor/github.com/pion/interceptor/pkg/nack/receive_log.go new file mode 100644 index 0000000..88ea991 --- /dev/null +++ b/vendor/github.com/pion/interceptor/pkg/nack/receive_log.go @@ -0,0 +1,144 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package nack + +import ( + "fmt" + "sync" + + "github.com/pion/interceptor/internal/rtpbuffer" +) + +type receiveLog struct { + packets []uint64 + size uint16 + end uint16 + started bool + lastConsecutive uint16 + m sync.RWMutex +} + +func newReceiveLog(size uint16) (*receiveLog, error) { + allowedSizes := make([]uint16, 0) + correctSize := false + for i := 6; i < 16; i++ { + if size == 1< end (with counting for rollovers) + for i := s.end + 1; i != seq; i++ { + // clear packets between end and seq (these may contain packets from a "size" ago) + s.delReceived(i) + } + s.end = seq + + if s.lastConsecutive+1 == seq { + s.lastConsecutive = seq + } else if seq-s.lastConsecutive > s.size { + s.lastConsecutive = seq - s.size + s.fixLastConsecutive() // there might be valid packets at the beginning of the buffer now + } + case s.lastConsecutive+1 == seq: + // negative diff, seq < end (with counting for rollovers) + s.lastConsecutive = seq + s.fixLastConsecutive() // there might be other valid packets after seq + } + + s.setReceived(seq) +} + +func (s *receiveLog) get(seq uint16) bool { + s.m.RLock() + defer s.m.RUnlock() + + diff := s.end - seq + if diff >= rtpbuffer.Uint16SizeHalf { + return false + } + + if diff >= s.size { + return false + } + + return s.getReceived(seq) +} + +func (s *receiveLog) missingSeqNumbers(skipLastN uint16, missingPacketSeqNums []uint16) []uint16 { + s.m.RLock() + defer s.m.RUnlock() + + until := s.end - skipLastN + if until-s.lastConsecutive >= rtpbuffer.Uint16SizeHalf { + // until < s.lastConsecutive (counting for rollover) + return nil + } + + c := 0 + for i := s.lastConsecutive + 1; i != until+1; i++ { + if !s.getReceived(i) { + missingPacketSeqNums[c] = i + c++ + } + } + + return missingPacketSeqNums[:c] +} + +func (s *receiveLog) setReceived(seq uint16) { + pos := seq % s.size + s.packets[pos/64] |= 1 << (pos % 64) +} + +func (s *receiveLog) delReceived(seq uint16) { + pos := seq % s.size + s.packets[pos/64] &^= 1 << (pos % 64) +} + +func (s *receiveLog) getReceived(seq uint16) bool { + pos := seq % s.size + + return (s.packets[pos/64] & (1 << (pos % 64))) != 0 +} + +func (s *receiveLog) fixLastConsecutive() { + i := s.lastConsecutive + 1 + for ; i != s.end+1 && s.getReceived(i); i++ { //nolint:revive + // find all consecutive packets + } + + s.lastConsecutive = i - 1 +} diff --git a/vendor/github.com/pion/interceptor/pkg/nack/responder_interceptor.go b/vendor/github.com/pion/interceptor/pkg/nack/responder_interceptor.go new file mode 100644 index 0000000..4aa6247 --- /dev/null +++ b/vendor/github.com/pion/interceptor/pkg/nack/responder_interceptor.go @@ -0,0 +1,178 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package nack + +import ( + "sync" + + "github.com/pion/interceptor" + "github.com/pion/interceptor/internal/rtpbuffer" + "github.com/pion/logging" + "github.com/pion/rtcp" + "github.com/pion/rtp" +) + +// ResponderInterceptorFactory is a interceptor.Factory for a ResponderInterceptor. +type ResponderInterceptorFactory struct { + opts []ResponderOption +} + +// NewInterceptor constructs a new ResponderInterceptor. +func (r *ResponderInterceptorFactory) NewInterceptor(_ string) (interceptor.Interceptor, error) { + responderInterceptor := &ResponderInterceptor{ + streamsFilter: streamSupportNack, + size: 1024, + streams: map[uint32]*localStream{}, + } + + for _, opt := range r.opts { + if err := opt(responderInterceptor); err != nil { + return nil, err + } + } + + if responderInterceptor.loggerFactory == nil { + responderInterceptor.loggerFactory = logging.NewDefaultLoggerFactory() + } + if responderInterceptor.log == nil { + responderInterceptor.log = responderInterceptor.loggerFactory.NewLogger("nack_responder") + } + if responderInterceptor.packetFactory == nil { + responderInterceptor.packetFactory = rtpbuffer.NewPacketFactoryCopy() + } + + if _, err := rtpbuffer.NewRTPBuffer(responderInterceptor.size); err != nil { + return nil, err + } + + return responderInterceptor, nil +} + +// ResponderInterceptor responds to nack feedback messages. +type ResponderInterceptor struct { + interceptor.NoOp + streamsFilter func(info *interceptor.StreamInfo) bool + size uint16 + log logging.LeveledLogger + loggerFactory logging.LoggerFactory + packetFactory rtpbuffer.PacketFactory + + streams map[uint32]*localStream + streamsMu sync.Mutex +} + +type localStream struct { + rtpBuffer *rtpbuffer.RTPBuffer + rtpBufferMutex sync.RWMutex + rtpWriter interceptor.RTPWriter +} + +// NewResponderInterceptor returns a new ResponderInterceptorFactor. +func NewResponderInterceptor(opts ...ResponderOption) (*ResponderInterceptorFactory, error) { + return &ResponderInterceptorFactory{opts}, nil +} + +// BindRTCPReader lets you modify any incoming RTCP packets. It is called once per sender/receiver, however this might +// change in the future. The returned method will be called once per packet batch. +func (n *ResponderInterceptor) BindRTCPReader(reader interceptor.RTCPReader) interceptor.RTCPReader { + return interceptor.RTCPReaderFunc(func(b []byte, a interceptor.Attributes) (int, interceptor.Attributes, error) { + i, attr, err := reader.Read(b, a) + if err != nil { + return 0, nil, err + } + + if attr == nil { + attr = make(interceptor.Attributes) + } + pkts, err := attr.GetRTCPPackets(b[:i]) + if err != nil { + return 0, nil, err + } + for _, rtcpPacket := range pkts { + nack, ok := rtcpPacket.(*rtcp.TransportLayerNack) + if !ok { + continue + } + + go n.resendPackets(nack) + } + + return i, attr, err + }) +} + +// BindLocalStream lets you modify any outgoing RTP packets. It is called once for per LocalStream. +// The returned method will be called once per rtp packet. +func (n *ResponderInterceptor) BindLocalStream( + info *interceptor.StreamInfo, writer interceptor.RTPWriter, +) interceptor.RTPWriter { + if !n.streamsFilter(info) { + return writer + } + + // error is already checked in NewGeneratorInterceptor + rtpBuffer, _ := rtpbuffer.NewRTPBuffer(n.size) + stream := &localStream{ + rtpBuffer: rtpBuffer, + rtpWriter: writer, + } + n.streamsMu.Lock() + n.streams[info.SSRC] = stream + n.streamsMu.Unlock() + + return interceptor.RTPWriterFunc( + func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) { + // If this packet doesn't belong to the main SSRC, do not add it to rtpBuffer + if header.SSRC != info.SSRC { + return writer.Write(header, payload, attributes) + } + + pkt, err := n.packetFactory.NewPacket(header, payload, info.SSRCRetransmission, info.PayloadTypeRetransmission) + if err != nil { + return 0, err + } + + stream.rtpBufferMutex.Lock() + stream.rtpBuffer.Add(pkt) + stream.rtpBufferMutex.Unlock() + + return writer.Write(header, payload, attributes) + }, + ) +} + +// UnbindLocalStream is called when the Stream is removed. It can be used to clean up any data related to that track. +func (n *ResponderInterceptor) UnbindLocalStream(info *interceptor.StreamInfo) { + n.streamsMu.Lock() + delete(n.streams, info.SSRC) + n.streamsMu.Unlock() +} + +func (n *ResponderInterceptor) resendPackets(nack *rtcp.TransportLayerNack) { + n.streamsMu.Lock() + stream, ok := n.streams[nack.MediaSSRC] + n.streamsMu.Unlock() + if !ok { + return + } + + for i := range nack.Nacks { + nack.Nacks[i].Range(func(seq uint16) bool { + // save the packet under the buffer lock + stream.rtpBufferMutex.Lock() + p := stream.rtpBuffer.Get(seq) + stream.rtpBufferMutex.Unlock() + + if p != nil { + // send without holding rtpBufferMutex + if _, err := stream.rtpWriter.Write(p.Header(), p.Payload(), interceptor.Attributes{}); err != nil { + n.log.Warnf("failed resending nacked packet: %+v", err) + } + p.Release() + } + + return true + }) + } +} diff --git a/vendor/github.com/pion/interceptor/pkg/nack/responder_option.go b/vendor/github.com/pion/interceptor/pkg/nack/responder_option.go new file mode 100644 index 0000000..95b24f0 --- /dev/null +++ b/vendor/github.com/pion/interceptor/pkg/nack/responder_option.go @@ -0,0 +1,60 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package nack + +import ( + "github.com/pion/interceptor" + "github.com/pion/interceptor/internal/rtpbuffer" + "github.com/pion/logging" +) + +// ResponderOption can be used to configure ResponderInterceptor. +type ResponderOption func(s *ResponderInterceptor) error + +// ResponderSize sets the size of the interceptor. +// Size must be one of: 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768. +func ResponderSize(size uint16) ResponderOption { + return func(r *ResponderInterceptor) error { + r.size = size + + return nil + } +} + +// ResponderLog sets a logger for the interceptor. +func ResponderLog(log logging.LeveledLogger) ResponderOption { + return func(r *ResponderInterceptor) error { + r.log = log + + return nil + } +} + +// WithResponderLoggerFactory sets a logger factory for the interceptor. +func WithResponderLoggerFactory(loggerFactory logging.LoggerFactory) ResponderOption { + return func(r *ResponderInterceptor) error { + r.loggerFactory = loggerFactory + + return nil + } +} + +// DisableCopy bypasses copy of underlying packets. It should be used when +// you are not re-using underlying buffers of packets that have been written. +func DisableCopy() ResponderOption { + return func(s *ResponderInterceptor) error { + s.packetFactory = &rtpbuffer.PacketFactoryNoOp{} + + return nil + } +} + +// ResponderStreamsFilter sets filter for local streams. +func ResponderStreamsFilter(filter func(info *interceptor.StreamInfo) bool) ResponderOption { + return func(r *ResponderInterceptor) error { + r.streamsFilter = filter + + return nil + } +} diff --git a/vendor/github.com/pion/interceptor/pkg/report/receiver_interceptor.go b/vendor/github.com/pion/interceptor/pkg/report/receiver_interceptor.go new file mode 100644 index 0000000..d1c0c8a --- /dev/null +++ b/vendor/github.com/pion/interceptor/pkg/report/receiver_interceptor.go @@ -0,0 +1,195 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package report + +import ( + "sync" + "time" + + "github.com/pion/interceptor" + "github.com/pion/logging" + "github.com/pion/rtcp" +) + +// ReceiverInterceptorFactory is a interceptor.Factory for a ReceiverInterceptor. +type ReceiverInterceptorFactory struct { + opts []ReceiverOption +} + +// NewInterceptor constructs a new ReceiverInterceptor. +func (r *ReceiverInterceptorFactory) NewInterceptor(_ string) (interceptor.Interceptor, error) { + receiverInterceptor := &ReceiverInterceptor{ + interval: 1 * time.Second, + now: time.Now, + close: make(chan struct{}), + } + + for _, opt := range r.opts { + if err := opt(receiverInterceptor); err != nil { + return nil, err + } + } + + if receiverInterceptor.loggerFactory == nil { + receiverInterceptor.loggerFactory = logging.NewDefaultLoggerFactory() + } + if receiverInterceptor.log == nil { + receiverInterceptor.log = receiverInterceptor.loggerFactory.NewLogger("receiver_interceptor") + } + + return receiverInterceptor, nil +} + +// NewReceiverInterceptor returns a new ReceiverInterceptorFactory. +func NewReceiverInterceptor(opts ...ReceiverOption) (*ReceiverInterceptorFactory, error) { + return &ReceiverInterceptorFactory{opts}, nil +} + +// ReceiverInterceptor interceptor generates receiver reports. +type ReceiverInterceptor struct { + interceptor.NoOp + interval time.Duration + now func() time.Time + streams sync.Map + log logging.LeveledLogger + loggerFactory logging.LoggerFactory + m sync.Mutex + wg sync.WaitGroup + close chan struct{} +} + +func (r *ReceiverInterceptor) isClosed() bool { + select { + case <-r.close: + return true + default: + return false + } +} + +// Close closes the interceptor. +func (r *ReceiverInterceptor) Close() error { + defer r.wg.Wait() + r.m.Lock() + defer r.m.Unlock() + + if !r.isClosed() { + close(r.close) + } + + return nil +} + +// BindRTCPWriter lets you modify any outgoing RTCP packets. It is called once per PeerConnection. The returned method +// will be called once per packet batch. +func (r *ReceiverInterceptor) BindRTCPWriter(writer interceptor.RTCPWriter) interceptor.RTCPWriter { + r.m.Lock() + defer r.m.Unlock() + + if r.isClosed() { + return writer + } + + r.wg.Add(1) + + go r.loop(writer) + + return writer +} + +func (r *ReceiverInterceptor) loop(rtcpWriter interceptor.RTCPWriter) { + defer r.wg.Done() + + ticker := time.NewTicker(r.interval) + defer ticker.Stop() + for { + select { + case <-ticker.C: + now := r.now() + r.streams.Range(func(_, value any) bool { + if stream, ok := value.(*receiverStream); !ok { + r.log.Warnf("failed to cast ReceiverInterceptor stream") + } else if _, err := rtcpWriter.Write( + []rtcp.Packet{stream.generateReport(now)}, interceptor.Attributes{}, + ); err != nil { + r.log.Warnf("failed sending: %+v", err) + } + + return true + }) + + case <-r.close: + return + } + } +} + +// BindRemoteStream lets you modify any incoming RTP packets. It is called once for per RemoteStream. +// The returned method will be called once per rtp packet. +func (r *ReceiverInterceptor) BindRemoteStream( + info *interceptor.StreamInfo, reader interceptor.RTPReader, +) interceptor.RTPReader { + stream := newReceiverStream(info.SSRC, info.ClockRate) + r.streams.Store(info.SSRC, stream) + + return interceptor.RTPReaderFunc(func(b []byte, a interceptor.Attributes) (int, interceptor.Attributes, error) { + i, attr, err := reader.Read(b, a) + if err != nil { + return 0, nil, err + } + + if attr == nil { + attr = make(interceptor.Attributes) + } + header, err := attr.GetRTPHeader(b[:i]) + if err != nil { + return 0, nil, err + } + + stream.processRTP(r.now(), header) + + return i, attr, nil + }) +} + +// UnbindRemoteStream is called when the Stream is removed. It can be used to clean up any data related to that track. +func (r *ReceiverInterceptor) UnbindRemoteStream(info *interceptor.StreamInfo) { + r.streams.Delete(info.SSRC) +} + +// BindRTCPReader lets you modify any incoming RTCP packets. It is called once per sender/receiver, however this might +// change in the future. The returned method will be called once per packet batch. +func (r *ReceiverInterceptor) BindRTCPReader(reader interceptor.RTCPReader) interceptor.RTCPReader { + return interceptor.RTCPReaderFunc(func(b []byte, a interceptor.Attributes) (int, interceptor.Attributes, error) { + i, attr, err := reader.Read(b, a) + if err != nil { + return 0, nil, err + } + + if attr == nil { + attr = make(interceptor.Attributes) + } + pkts, err := attr.GetRTCPPackets(b[:i]) + if err != nil { + return 0, nil, err + } + + for _, pkt := range pkts { + if sr, ok := (pkt).(*rtcp.SenderReport); ok { + value, ok := r.streams.Load(sr.SSRC) + if !ok { + continue + } + + if stream, ok := value.(*receiverStream); !ok { + r.log.Warnf("failed to cast ReceiverInterceptor stream") + } else { + stream.processSenderReport(r.now(), sr) + } + } + } + + return i, attr, nil + }) +} diff --git a/vendor/github.com/pion/interceptor/pkg/report/receiver_option.go b/vendor/github.com/pion/interceptor/pkg/report/receiver_option.go new file mode 100644 index 0000000..5693b8b --- /dev/null +++ b/vendor/github.com/pion/interceptor/pkg/report/receiver_option.go @@ -0,0 +1,49 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package report + +import ( + "time" + + "github.com/pion/logging" +) + +// ReceiverOption can be used to configure ReceiverInterceptor. +type ReceiverOption func(r *ReceiverInterceptor) error + +// ReceiverLog sets a logger for the interceptor. +func ReceiverLog(log logging.LeveledLogger) ReceiverOption { + return func(r *ReceiverInterceptor) error { + r.log = log + + return nil + } +} + +// WithReceiverLoggerFactory sets a logger factory for the interceptor. +func WithReceiverLoggerFactory(loggerFactory logging.LoggerFactory) ReceiverOption { + return func(r *ReceiverInterceptor) error { + r.loggerFactory = loggerFactory + + return nil + } +} + +// ReceiverInterval sets send interval for the interceptor. +func ReceiverInterval(interval time.Duration) ReceiverOption { + return func(r *ReceiverInterceptor) error { + r.interval = interval + + return nil + } +} + +// ReceiverNow sets an alternative for the time.Now function. +func ReceiverNow(f func() time.Time) ReceiverOption { + return func(r *ReceiverInterceptor) error { + r.now = f + + return nil + } +} diff --git a/vendor/github.com/pion/interceptor/pkg/report/receiver_stream.go b/vendor/github.com/pion/interceptor/pkg/report/receiver_stream.go new file mode 100644 index 0000000..c652a25 --- /dev/null +++ b/vendor/github.com/pion/interceptor/pkg/report/receiver_stream.go @@ -0,0 +1,174 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package report + +import ( + "math/rand" + "sync" + "time" + + "github.com/pion/rtcp" + "github.com/pion/rtp" +) + +const ( + // packetsPerHistoryEntry represents how many packets are in the bitmask for + // each entry in the `packets` slice in the receiver stream. Because we use + // a uint64, we can keep track of 64 packets per entry. + packetsPerHistoryEntry = 64 +) + +type receiverStream struct { + ssrc uint32 + receiverSSRC uint32 + clockRate float64 + + m sync.Mutex + size uint16 + packets []uint64 + started bool + seqnumCycles uint16 + lastSeqnum uint16 + lastReportSeqnum uint16 + lastRTPTimeRTP uint32 + lastRTPTimeTime time.Time + jitter float64 + lastSenderReport uint32 + lastSenderReportTime time.Time + totalLost uint32 +} + +func newReceiverStream(ssrc uint32, clockRate uint32) *receiverStream { + receiverSSRC := rand.Uint32() // #nosec + + return &receiverStream{ + ssrc: ssrc, + receiverSSRC: receiverSSRC, + clockRate: float64(clockRate), + size: 128, + packets: make([]uint64, 128), + } +} + +func (stream *receiverStream) processRTP(now time.Time, pktHeader *rtp.Header) { + stream.m.Lock() + defer stream.m.Unlock() + + //nolint:nestif + if !stream.started { // first frame + stream.started = true + stream.setReceived(pktHeader.SequenceNumber) + stream.lastSeqnum = pktHeader.SequenceNumber + stream.lastReportSeqnum = pktHeader.SequenceNumber - 1 + stream.lastRTPTimeRTP = pktHeader.Timestamp + stream.lastRTPTimeTime = now + } else { // following frames + stream.setReceived(pktHeader.SequenceNumber) + + diff := pktHeader.SequenceNumber - stream.lastSeqnum + if diff > 0 && diff < (1<<15) { + // wrap around + if pktHeader.SequenceNumber < stream.lastSeqnum { + stream.seqnumCycles++ + } + + // set missing packets as missing + for i := stream.lastSeqnum + 1; i != pktHeader.SequenceNumber; i++ { + stream.delReceived(i) + } + + stream.lastSeqnum = pktHeader.SequenceNumber + } + + // compute jitter + // https://tools.ietf.org/html/rfc3550#page-39 + D := now.Sub(stream.lastRTPTimeTime).Seconds()*stream.clockRate - + (float64(pktHeader.Timestamp) - float64(stream.lastRTPTimeRTP)) + if D < 0 { + D = -D + } + stream.jitter += (D - stream.jitter) / 16 + stream.lastRTPTimeRTP = pktHeader.Timestamp + stream.lastRTPTimeTime = now + } +} + +func (stream *receiverStream) setReceived(seq uint16) { + pos := seq % (stream.size * packetsPerHistoryEntry) + stream.packets[pos/packetsPerHistoryEntry] |= 1 << (pos % packetsPerHistoryEntry) +} + +func (stream *receiverStream) delReceived(seq uint16) { + pos := seq % (stream.size * packetsPerHistoryEntry) + stream.packets[pos/packetsPerHistoryEntry] &^= 1 << (pos % packetsPerHistoryEntry) +} + +func (stream *receiverStream) getReceived(seq uint16) bool { + pos := seq % (stream.size * packetsPerHistoryEntry) + + return (stream.packets[pos/packetsPerHistoryEntry] & (1 << (pos % packetsPerHistoryEntry))) != 0 +} + +func (stream *receiverStream) processSenderReport(now time.Time, sr *rtcp.SenderReport) { + stream.m.Lock() + defer stream.m.Unlock() + + stream.lastSenderReport = uint32(sr.NTPTime >> 16) //nolint:gosec // G115 + stream.lastSenderReportTime = now +} + +func (stream *receiverStream) generateReport(now time.Time) *rtcp.ReceiverReport { + stream.m.Lock() + defer stream.m.Unlock() + + totalSinceReport := stream.lastSeqnum - stream.lastReportSeqnum + totalLostSinceReport := func() uint32 { + if stream.lastSeqnum == stream.lastReportSeqnum { + return 0 + } + + ret := uint32(0) + for i := stream.lastReportSeqnum + 1; i != stream.lastSeqnum; i++ { + if !stream.getReceived(i) { + ret++ + } + } + + return ret + }() + stream.totalLost += totalLostSinceReport + + // allow up to 24 bits + if totalLostSinceReport > 0xFFFFFF { + totalLostSinceReport = 0xFFFFFF + } + if stream.totalLost > 0xFFFFFF { + stream.totalLost = 0xFFFFFF + } + + receiverReport := &rtcp.ReceiverReport{ + SSRC: stream.receiverSSRC, + Reports: []rtcp.ReceptionReport{ + { + SSRC: stream.ssrc, + LastSequenceNumber: uint32(stream.seqnumCycles)<<16 | uint32(stream.lastSeqnum), + LastSenderReport: stream.lastSenderReport, + FractionLost: uint8(float64(totalLostSinceReport*256) / float64(totalSinceReport)), + TotalLost: stream.totalLost, + Delay: func() uint32 { + if stream.lastSenderReportTime.IsZero() { + return 0 + } + + return uint32(now.Sub(stream.lastSenderReportTime).Seconds() * 65536) + }(), + Jitter: uint32(stream.jitter), + }, + }, + } + + stream.lastReportSeqnum = stream.lastSeqnum + + return receiverReport +} diff --git a/vendor/github.com/pion/interceptor/pkg/report/report.go b/vendor/github.com/pion/interceptor/pkg/report/report.go new file mode 100644 index 0000000..4338a49 --- /dev/null +++ b/vendor/github.com/pion/interceptor/pkg/report/report.go @@ -0,0 +1,5 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +// Package report provides interceptors to implement sending sender and receiver reports. +package report diff --git a/vendor/github.com/pion/interceptor/pkg/report/sender_interceptor.go b/vendor/github.com/pion/interceptor/pkg/report/sender_interceptor.go new file mode 100644 index 0000000..ea1139b --- /dev/null +++ b/vendor/github.com/pion/interceptor/pkg/report/sender_interceptor.go @@ -0,0 +1,162 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package report + +import ( + "sync" + "time" + + "github.com/pion/interceptor" + "github.com/pion/logging" + "github.com/pion/rtcp" + "github.com/pion/rtp" +) + +// TickerFactory is a factory to create new tickers. +type TickerFactory func(d time.Duration) Ticker + +// SenderInterceptorFactory is a interceptor.Factory for a SenderInterceptor. +type SenderInterceptorFactory struct { + opts []SenderOption +} + +// NewInterceptor constructs a new SenderInterceptor. +func (s *SenderInterceptorFactory) NewInterceptor(_ string) (interceptor.Interceptor, error) { + senderInterceptor := &SenderInterceptor{ + interval: 1 * time.Second, + now: time.Now, + newTicker: func(d time.Duration) Ticker { + return &timeTicker{time.NewTicker(d)} + }, + close: make(chan struct{}), + } + + for _, opt := range s.opts { + if err := opt(senderInterceptor); err != nil { + return nil, err + } + } + + if senderInterceptor.loggerFactory == nil { + senderInterceptor.loggerFactory = logging.NewDefaultLoggerFactory() + } + if senderInterceptor.log == nil { + senderInterceptor.log = senderInterceptor.loggerFactory.NewLogger("sender_interceptor") + } + + return senderInterceptor, nil +} + +// NewSenderInterceptor returns a new SenderInterceptorFactory. +func NewSenderInterceptor(opts ...SenderOption) (*SenderInterceptorFactory, error) { + return &SenderInterceptorFactory{opts}, nil +} + +// SenderInterceptor interceptor generates sender reports. +type SenderInterceptor struct { + interceptor.NoOp + interval time.Duration + now func() time.Time + newTicker TickerFactory + streams sync.Map + log logging.LeveledLogger + loggerFactory logging.LoggerFactory + m sync.Mutex + wg sync.WaitGroup + close chan struct{} + started chan struct{} + + useLatestPacket bool +} + +func (s *SenderInterceptor) isClosed() bool { + select { + case <-s.close: + return true + default: + return false + } +} + +// Close closes the interceptor. +func (s *SenderInterceptor) Close() error { + defer s.wg.Wait() + s.m.Lock() + defer s.m.Unlock() + + if !s.isClosed() { + close(s.close) + } + + return nil +} + +// BindRTCPWriter lets you modify any outgoing RTCP packets. It is called once per PeerConnection. The returned method +// will be called once per packet batch. +func (s *SenderInterceptor) BindRTCPWriter(writer interceptor.RTCPWriter) interceptor.RTCPWriter { + s.m.Lock() + defer s.m.Unlock() + + if s.isClosed() { + return writer + } + + s.wg.Add(1) + + go s.loop(writer) + + return writer +} + +func (s *SenderInterceptor) loop(rtcpWriter interceptor.RTCPWriter) { + defer s.wg.Done() + + ticker := s.newTicker(s.interval) + defer ticker.Stop() + if s.started != nil { + // This lets us synchronize in tests to know whether the loop has begun or not. + // It only happens if started was initialized, which should not occur in non-tests. + close(s.started) + } + for { + select { + case <-ticker.Ch(): + now := s.now() + s.streams.Range(func(_, value any) bool { + if stream, ok := value.(*senderStream); !ok { + s.log.Warnf("failed to cast SenderInterceptor stream") + } else if _, err := rtcpWriter.Write( + []rtcp.Packet{stream.generateReport(now)}, interceptor.Attributes{}, + ); err != nil { + s.log.Warnf("failed sending: %+v", err) + } + + return true + }) + + case <-s.close: + return + } + } +} + +// BindLocalStream lets you modify any outgoing RTP packets. It is called once for per LocalStream. The returned method +// will be called once per rtp packet. +func (s *SenderInterceptor) BindLocalStream( + info *interceptor.StreamInfo, writer interceptor.RTPWriter, +) interceptor.RTPWriter { + stream := newSenderStream(info.SSRC, info.ClockRate, s.useLatestPacket) + s.streams.Store(info.SSRC, stream) + + return interceptor.RTPWriterFunc(func(header *rtp.Header, payload []byte, a interceptor.Attributes) (int, error) { + stream.processRTP(s.now(), header, payload) + + return writer.Write(header, payload, a) + }) +} + +// UnbindLocalStream is called when the Stream is removed. It can be used to clean up any data related to that track. +func (s *SenderInterceptor) UnbindLocalStream(info *interceptor.StreamInfo) { + s.streams.Delete(info.SSRC) +} diff --git a/vendor/github.com/pion/interceptor/pkg/report/sender_option.go b/vendor/github.com/pion/interceptor/pkg/report/sender_option.go new file mode 100644 index 0000000..1170582 --- /dev/null +++ b/vendor/github.com/pion/interceptor/pkg/report/sender_option.go @@ -0,0 +1,78 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package report + +import ( + "time" + + "github.com/pion/logging" +) + +// SenderOption can be used to configure SenderInterceptor. +type SenderOption func(r *SenderInterceptor) error + +// SenderLog sets a logger for the interceptor. +func SenderLog(log logging.LeveledLogger) SenderOption { + return func(r *SenderInterceptor) error { + r.log = log + + return nil + } +} + +// WithSenderLoggerFactory sets a logger factory for the interceptor. +func WithSenderLoggerFactory(loggerFactory logging.LoggerFactory) SenderOption { + return func(r *SenderInterceptor) error { + r.loggerFactory = loggerFactory + + return nil + } +} + +// SenderInterval sets send interval for the interceptor. +func SenderInterval(interval time.Duration) SenderOption { + return func(r *SenderInterceptor) error { + r.interval = interval + + return nil + } +} + +// SenderNow sets an alternative for the time.Now function. +func SenderNow(f func() time.Time) SenderOption { + return func(r *SenderInterceptor) error { + r.now = f + + return nil + } +} + +// SenderTicker sets an alternative for the time.NewTicker function. +func SenderTicker(f TickerFactory) SenderOption { + return func(r *SenderInterceptor) error { + r.newTicker = f + + return nil + } +} + +// SenderUseLatestPacket sets the interceptor to always use the latest packet, even +// if it appears to be out-of-order. +func SenderUseLatestPacket() SenderOption { + return func(r *SenderInterceptor) error { + r.useLatestPacket = true + + return nil + } +} + +// enableStartTracking is used by tests to synchronize whether the loop() has begun +// and it's safe to start sending ticks to the ticker. +func enableStartTracking(startedCh chan struct{}) SenderOption { + return func(r *SenderInterceptor) error { + r.started = startedCh + + return nil + } +} diff --git a/vendor/github.com/pion/interceptor/pkg/report/sender_stream.go b/vendor/github.com/pion/interceptor/pkg/report/sender_stream.go new file mode 100644 index 0000000..a637354 --- /dev/null +++ b/vendor/github.com/pion/interceptor/pkg/report/sender_stream.go @@ -0,0 +1,69 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package report + +import ( + "sync" + "time" + + "github.com/pion/interceptor/internal/ntp" + "github.com/pion/rtcp" + "github.com/pion/rtp" +) + +type senderStream struct { + ssrc uint32 + clockRate float64 + m sync.Mutex + + useLatestPacket bool + + // data from rtp packets + lastRTPTimeRTP uint32 + lastRTPTimeTime time.Time + lastRTPSN uint16 + packetCount uint32 + octetCount uint32 +} + +func newSenderStream(ssrc uint32, clockRate uint32, useLatestPacket bool) *senderStream { + return &senderStream{ + ssrc: ssrc, + clockRate: float64(clockRate), + useLatestPacket: useLatestPacket, + } +} + +func (stream *senderStream) processRTP(now time.Time, header *rtp.Header, payload []byte) { + stream.m.Lock() + defer stream.m.Unlock() + + diff := header.SequenceNumber - stream.lastRTPSN + if stream.useLatestPacket || stream.packetCount == 0 || (diff > 0 && diff < (1<<15)) { + // Told to consider every packet, or this was the first packet, or it's in-order + stream.lastRTPSN = header.SequenceNumber + // update only on first packet of a frame to ensure sender report does not get affected by + // processing delay of pushing a large frame which could span multiple packets + if header.Timestamp != stream.lastRTPTimeRTP { + stream.lastRTPTimeRTP = header.Timestamp + stream.lastRTPTimeTime = now + } + } + + stream.packetCount++ + stream.octetCount += uint32(len(payload)) //nolint:gosec // G115 +} + +func (stream *senderStream) generateReport(now time.Time) *rtcp.SenderReport { + stream.m.Lock() + defer stream.m.Unlock() + + return &rtcp.SenderReport{ + SSRC: stream.ssrc, + NTPTime: ntp.ToNTP(now), + RTPTime: stream.lastRTPTimeRTP + uint32(now.Sub(stream.lastRTPTimeTime).Seconds()*stream.clockRate), + PacketCount: stream.packetCount, + OctetCount: stream.octetCount, + } +} diff --git a/vendor/github.com/pion/interceptor/pkg/report/ticker.go b/vendor/github.com/pion/interceptor/pkg/report/ticker.go new file mode 100644 index 0000000..a6fef16 --- /dev/null +++ b/vendor/github.com/pion/interceptor/pkg/report/ticker.go @@ -0,0 +1,20 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package report + +import "time" + +// Ticker is an interface for *time.Ticker for use with the SenderTicker option. +type Ticker interface { + Ch() <-chan time.Time + Stop() +} + +type timeTicker struct { + *time.Ticker +} + +func (t *timeTicker) Ch() <-chan time.Time { + return t.C +} diff --git a/vendor/github.com/pion/interceptor/pkg/rfc8888/interceptor.go b/vendor/github.com/pion/interceptor/pkg/rfc8888/interceptor.go new file mode 100644 index 0000000..f513a1a --- /dev/null +++ b/vendor/github.com/pion/interceptor/pkg/rfc8888/interceptor.go @@ -0,0 +1,197 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +// Package rfc8888 provides an interceptor that generates congestion control +// feedback reports as defined by RFC 8888. +package rfc8888 + +import ( + "sync" + "time" + + "github.com/pion/interceptor" + "github.com/pion/logging" + "github.com/pion/rtcp" +) + +// TickerFactory is a factory to create new tickers. +type TickerFactory func(d time.Duration) ticker + +// SenderInterceptorFactory is a interceptor.Factory for a SenderInterceptor. +type SenderInterceptorFactory struct { + opts []Option +} + +// NewInterceptor constructs a new SenderInterceptor. +func (s *SenderInterceptorFactory) NewInterceptor(_ string) (interceptor.Interceptor, error) { + senderInterceptor := &SenderInterceptor{ + NoOp: interceptor.NoOp{}, + lock: sync.Mutex{}, + wg: sync.WaitGroup{}, + recorder: NewRecorder(), + interval: 100 * time.Millisecond, + maxReportSize: 1200, + packetChan: make(chan packet), + newTicker: func(d time.Duration) ticker { + return &timeTicker{time.NewTicker(d)} + }, + now: time.Now, + close: make(chan struct{}), + } + for _, opt := range s.opts { + err := opt(senderInterceptor) + if err != nil { + return nil, err + } + } + + if senderInterceptor.loggerFactory == nil { + senderInterceptor.loggerFactory = logging.NewDefaultLoggerFactory() + } + if senderInterceptor.log == nil { + senderInterceptor.log = senderInterceptor.loggerFactory.NewLogger("rfc8888_interceptor") + } + + return senderInterceptor, nil +} + +// NewSenderInterceptor returns a new SenderInterceptorFactory configured with the given options. +func NewSenderInterceptor(opts ...Option) (*SenderInterceptorFactory, error) { + return &SenderInterceptorFactory{opts: opts}, nil +} + +// SenderInterceptor sends congestion control feedback as specified in RFC 8888. +type SenderInterceptor struct { + interceptor.NoOp + log logging.LeveledLogger + loggerFactory logging.LoggerFactory + lock sync.Mutex + wg sync.WaitGroup + recorder *Recorder + interval time.Duration + maxReportSize int64 + packetChan chan packet + newTicker TickerFactory + now func() time.Time + close chan struct{} +} + +type packet struct { + arrival time.Time + ssrc uint32 + sequenceNumber uint16 + ecn uint8 +} + +// BindRTCPWriter lets you modify any outgoing RTCP packets. It is called once per PeerConnection. The returned method +// will be called once per packet batch. +func (s *SenderInterceptor) BindRTCPWriter(writer interceptor.RTCPWriter) interceptor.RTCPWriter { + s.lock.Lock() + defer s.lock.Unlock() + + if s.isClosed() { + return writer + } + + s.wg.Add(1) + go s.loop(writer) + + return writer +} + +// BindRemoteStream lets you modify any incoming RTP packets. +// It is called once for per RemoteStream. The returned method +// will be called once per rtp packet.. +func (s *SenderInterceptor) BindRemoteStream( + _ *interceptor.StreamInfo, reader interceptor.RTPReader, +) interceptor.RTPReader { + return interceptor.RTPReaderFunc(func(b []byte, a interceptor.Attributes) (int, interceptor.Attributes, error) { + i, attr, err := reader.Read(b, a) + if err != nil { + return 0, nil, err + } + + if attr == nil { + attr = make(interceptor.Attributes) + } + header, err := attr.GetRTPHeader(b[:i]) + if err != nil { + return 0, nil, err + } + + p := packet{ + arrival: s.now(), + ssrc: header.SSRC, + sequenceNumber: header.SequenceNumber, + ecn: 0, // ECN is not supported (yet). + } + s.packetChan <- p + + return i, attr, nil + }) +} + +// Close closes the interceptor. +func (s *SenderInterceptor) Close() error { + s.log.Trace("close") + defer s.wg.Wait() + + if !s.isClosed() { + close(s.close) + } + + return nil +} + +func (s *SenderInterceptor) isClosed() bool { + select { + case <-s.close: + return true + default: + return false + } +} + +func (s *SenderInterceptor) loop(writer interceptor.RTCPWriter) { + defer s.wg.Done() + + select { + case <-s.close: + return + case pkt := <-s.packetChan: + s.log.Tracef("got first packet: %v", pkt) + s.recorder.AddPacket(pkt.arrival, pkt.ssrc, pkt.sequenceNumber, pkt.ecn) + } + + s.log.Trace("start loop") + t := s.newTicker(s.interval) + for { + select { + case <-s.close: + t.Stop() + + return + + case pkt := <-s.packetChan: + s.log.Tracef("got packet: %v", pkt) + s.recorder.AddPacket(pkt.arrival, pkt.ssrc, pkt.sequenceNumber, pkt.ecn) + + case <-t.Ch(): + now := s.now() + s.log.Tracef("report triggered at %v", now) + if writer == nil { + s.log.Trace("no writer added, continue") + + continue + } + pkts := s.recorder.BuildReport(now, int(s.maxReportSize)) + if pkts == nil { + continue + } + s.log.Tracef("got report: %v", pkts) + if _, err := writer.Write([]rtcp.Packet{pkts}, nil); err != nil { + s.log.Error(err.Error()) + } + } + } +} diff --git a/vendor/github.com/pion/interceptor/pkg/rfc8888/option.go b/vendor/github.com/pion/interceptor/pkg/rfc8888/option.go new file mode 100644 index 0000000..ca70d3e --- /dev/null +++ b/vendor/github.com/pion/interceptor/pkg/rfc8888/option.go @@ -0,0 +1,49 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package rfc8888 + +import ( + "time" + + "github.com/pion/logging" +) + +// An Option is a function that can be used to configure a SenderInterceptor. +type Option func(*SenderInterceptor) error + +// SenderTicker sets an alternative for time.Ticker. +func SenderTicker(f TickerFactory) Option { + return func(i *SenderInterceptor) error { + i.newTicker = f + + return nil + } +} + +// SenderNow sets an alternative for the time.Now function. +func SenderNow(f func() time.Time) Option { + return func(i *SenderInterceptor) error { + i.now = f + + return nil + } +} + +// SendInterval sets the feedback send interval for the interceptor. +func SendInterval(interval time.Duration) Option { + return func(s *SenderInterceptor) error { + s.interval = interval + + return nil + } +} + +// WithLoggerFactory sets the logger factory for the interceptor. +func WithLoggerFactory(loggerFactory logging.LoggerFactory) Option { + return func(i *SenderInterceptor) error { + i.loggerFactory = loggerFactory + + return nil + } +} diff --git a/vendor/github.com/pion/interceptor/pkg/rfc8888/recorder.go b/vendor/github.com/pion/interceptor/pkg/rfc8888/recorder.go new file mode 100644 index 0000000..35a65cf --- /dev/null +++ b/vendor/github.com/pion/interceptor/pkg/rfc8888/recorder.go @@ -0,0 +1,60 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package rfc8888 + +import ( + "time" + + "github.com/pion/interceptor/internal/ntp" + "github.com/pion/rtcp" +) + +type packetReport struct { + arrivalTime time.Time + ecn uint8 +} + +// Recorder records incoming RTP packets and their arrival times. Recorder can +// be used to create feedback reports as defined by RFC 8888. +type Recorder struct { + ssrc uint32 + streams map[uint32]*streamLog +} + +// NewRecorder creates a new Recorder. +func NewRecorder() *Recorder { + return &Recorder{ + streams: map[uint32]*streamLog{}, + } +} + +// AddPacket writes a packet to the underlying stream. +func (r *Recorder) AddPacket(ts time.Time, ssrc uint32, seq uint16, ecn uint8) { + stream, ok := r.streams[ssrc] + if !ok { + stream = newStreamLog(ssrc) + r.streams[ssrc] = stream + } + stream.add(ts, seq, ecn) +} + +// BuildReport creates a new rtcp.CCFeedbackReport containing all packets that +// were added by AddPacket and missing packets. +func (r *Recorder) BuildReport(now time.Time, maxSize int) *rtcp.CCFeedbackReport { + report := &rtcp.CCFeedbackReport{ + SenderSSRC: r.ssrc, + ReportBlocks: []rtcp.CCFeedbackReportBlock{}, + ReportTimestamp: ntp.ToNTP32(now), + } + + maxReportBlocks := (maxSize - 12 - (8 * len(r.streams))) / 2 + maxReportBlocksPerStream := maxReportBlocks / len(r.streams) + + for _, log := range r.streams { + block := log.metricsAfter(now, int64(maxReportBlocksPerStream)) + report.ReportBlocks = append(report.ReportBlocks, block) + } + + return report +} diff --git a/vendor/github.com/pion/interceptor/pkg/rfc8888/stream_log.go b/vendor/github.com/pion/interceptor/pkg/rfc8888/stream_log.go new file mode 100644 index 0000000..d95decc --- /dev/null +++ b/vendor/github.com/pion/interceptor/pkg/rfc8888/stream_log.go @@ -0,0 +1,113 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package rfc8888 + +import ( + "time" + + "github.com/pion/interceptor/internal/sequencenumber" + "github.com/pion/rtcp" +) + +const maxReportsPerReportBlock = 16384 + +type streamLog struct { + ssrc uint32 + sequence sequencenumber.Unwrapper + init bool + nextSequenceNumberToReport int64 // next to report + lastSequenceNumberReceived int64 // highest received + log map[int64]*packetReport +} + +func newStreamLog(ssrc uint32) *streamLog { + return &streamLog{ + ssrc: ssrc, + sequence: sequencenumber.Unwrapper{}, + init: false, + nextSequenceNumberToReport: 0, + lastSequenceNumberReceived: 0, + log: map[int64]*packetReport{}, + } +} + +func (l *streamLog) add(ts time.Time, sequenceNumber uint16, ecn uint8) { + unwrappedSequenceNumber := l.sequence.Unwrap(sequenceNumber) + if !l.init { + l.init = true + l.nextSequenceNumberToReport = unwrappedSequenceNumber + } + l.log[unwrappedSequenceNumber] = &packetReport{ + arrivalTime: ts, + ecn: ecn, + } + if l.lastSequenceNumberReceived < unwrappedSequenceNumber { + l.lastSequenceNumberReceived = unwrappedSequenceNumber + } +} + +// metricsAfter iterates over all packets order of their sequence number. +// Packets are removed until the first loss is detected. +func (l *streamLog) metricsAfter(reference time.Time, maxReportBlocks int64) rtcp.CCFeedbackReportBlock { + if len(l.log) == 0 { + return rtcp.CCFeedbackReportBlock{ + MediaSSRC: l.ssrc, + BeginSequence: uint16(l.nextSequenceNumberToReport), //nolint:gosec // G115 + MetricBlocks: []rtcp.CCFeedbackMetricBlock{}, + } + } + numReports := l.lastSequenceNumberReceived - l.nextSequenceNumberToReport + 1 + if numReports > maxReportBlocks { + numReports = maxReportBlocks + l.nextSequenceNumberToReport = l.lastSequenceNumberReceived - maxReportBlocks + 1 + } + metricBlocks := make([]rtcp.CCFeedbackMetricBlock, numReports) + offset := l.nextSequenceNumberToReport + lastReceived := l.nextSequenceNumberToReport + gapDetected := false + for i := offset; i <= l.lastSequenceNumberReceived; i++ { //nolint:varnamelen // i int64 + received := false + ecn := uint8(0) + ato := uint16(0) + if report, ok := l.log[i]; ok { + received = true + ecn = report.ecn + ato = getArrivalTimeOffset(reference, report.arrivalTime) + } + metricBlocks[i-offset] = rtcp.CCFeedbackMetricBlock{ + Received: received, + ECN: rtcp.ECN(ecn), + ArrivalTimeOffset: ato, + } + + if !gapDetected { + if received && i == l.nextSequenceNumberToReport { + delete(l.log, i) + l.nextSequenceNumberToReport++ + lastReceived = i + } + if i > lastReceived+1 { + gapDetected = true + } + } + } + + return rtcp.CCFeedbackReportBlock{ + MediaSSRC: l.ssrc, + BeginSequence: uint16(offset), //nolint:gosec // G115 + MetricBlocks: metricBlocks, + } +} + +func getArrivalTimeOffset(base time.Time, arrival time.Time) uint16 { + if base.Before(arrival) { + return 0x1FFF + } + ato := uint16(base.Sub(arrival).Seconds() * 1024.0) + if ato > 0x1FFD { + return 0x1FFE + } + + return ato +} diff --git a/vendor/github.com/pion/interceptor/pkg/rfc8888/ticker.go b/vendor/github.com/pion/interceptor/pkg/rfc8888/ticker.go new file mode 100644 index 0000000..97806bd --- /dev/null +++ b/vendor/github.com/pion/interceptor/pkg/rfc8888/ticker.go @@ -0,0 +1,19 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package rfc8888 + +import "time" + +type ticker interface { + Ch() <-chan time.Time + Stop() +} + +type timeTicker struct { + *time.Ticker +} + +func (t *timeTicker) Ch() <-chan time.Time { + return t.C +} diff --git a/vendor/github.com/pion/interceptor/pkg/stats/interceptor.go b/vendor/github.com/pion/interceptor/pkg/stats/interceptor.go new file mode 100644 index 0000000..de2226f --- /dev/null +++ b/vendor/github.com/pion/interceptor/pkg/stats/interceptor.go @@ -0,0 +1,246 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +// Package stats provides an interceptor that records RTP/RTCP stream statistics +package stats + +import ( + "sync" + "time" + + "github.com/pion/interceptor" + "github.com/pion/logging" + "github.com/pion/rtcp" + "github.com/pion/rtp" +) + +// Option can be used to configure the stats interceptor. +type Option func(*Interceptor) error + +// SetRecorderFactory sets the factory that is used to create new stats +// recorders for new streams. +func SetRecorderFactory(f RecorderFactory) Option { + return func(i *Interceptor) error { + i.RecorderFactory = f + + return nil + } +} + +// SetNowFunc sets the function the interceptor uses to get a current timestamp. +// This is mostly useful for testing. +func SetNowFunc(now func() time.Time) Option { + return func(i *Interceptor) error { + i.now = now + + return nil + } +} + +// WithLoggerFactory sets the logger factory for the interceptor. +func WithLoggerFactory(loggerFactory logging.LoggerFactory) Option { + return func(i *Interceptor) error { + i.loggerFactory = loggerFactory + + return nil + } +} + +// Getter returns the most recent stats of a stream. +type Getter interface { + Get(ssrc uint32) *Stats +} + +// NewPeerConnectionCallback receives a new StatsGetter for a newly created +// PeerConnection. +type NewPeerConnectionCallback func(string, Getter) + +// InterceptorFactory is a interceptor.Factory for a stats Interceptor. +type InterceptorFactory struct { + opts []Option + addPeerConnection NewPeerConnectionCallback +} + +// NewInterceptor creates a new InterceptorFactory. +func NewInterceptor(opts ...Option) (*InterceptorFactory, error) { + return &InterceptorFactory{ + opts: opts, + addPeerConnection: nil, + }, nil +} + +// OnNewPeerConnection sets the callback that is called when a new +// PeerConnection is created. +func (r *InterceptorFactory) OnNewPeerConnection(cb NewPeerConnectionCallback) { + r.addPeerConnection = cb +} + +// NewInterceptor creates a new Interceptor. +func (r *InterceptorFactory) NewInterceptor(id string) (interceptor.Interceptor, error) { + interceptor := &Interceptor{ + NoOp: interceptor.NoOp{}, + now: time.Now, + lock: sync.Mutex{}, + recorders: map[uint32]Recorder{}, + wg: sync.WaitGroup{}, + } + for _, opt := range r.opts { + if err := opt(interceptor); err != nil { + return nil, err + } + } + + if interceptor.loggerFactory == nil { + interceptor.loggerFactory = logging.NewDefaultLoggerFactory() + } + if interceptor.RecorderFactory == nil { + interceptor.RecorderFactory = func(ssrc uint32, clockRate float64) Recorder { + return newRecorder(ssrc, clockRate, interceptor.loggerFactory) + } + } + + if r.addPeerConnection != nil { + r.addPeerConnection(id, interceptor) + } + + return interceptor, nil +} + +// Recorder is the interface of a statistics recorder. +type Recorder interface { + QueueIncomingRTP(ts time.Time, buf []byte, attr interceptor.Attributes) + QueueIncomingRTCP(ts time.Time, buf []byte, attr interceptor.Attributes) + QueueOutgoingRTP(ts time.Time, header *rtp.Header, payload []byte, attr interceptor.Attributes) + QueueOutgoingRTCP(ts time.Time, pkts []rtcp.Packet, attr interceptor.Attributes) + GetStats() Stats + Stop() + Start() +} + +// RecorderFactory creates new Recorders to be used by the interceptor. +type RecorderFactory func(ssrc uint32, clockRate float64) Recorder + +// Interceptor is the interceptor that collects stream stats. +type Interceptor struct { + interceptor.NoOp + now func() time.Time + lock sync.Mutex + RecorderFactory RecorderFactory + recorders map[uint32]Recorder + wg sync.WaitGroup + loggerFactory logging.LoggerFactory +} + +// Get returns the statistics for the stream with ssrc. +func (r *Interceptor) Get(ssrc uint32) *Stats { + r.lock.Lock() + defer r.lock.Unlock() + if rec, ok := r.recorders[ssrc]; ok { + stats := rec.GetStats() + + return &stats + } + + return nil +} + +func (r *Interceptor) getRecorder(ssrc uint32, clockRate float64) Recorder { + r.lock.Lock() + defer r.lock.Unlock() + if rec, ok := r.recorders[ssrc]; ok { + return rec + } + rec := r.RecorderFactory(ssrc, clockRate) + r.wg.Add(1) + go func() { + defer r.wg.Done() + rec.Start() + }() + r.recorders[ssrc] = rec + + return rec +} + +// Close closes the interceptor and associated stats recorders. +func (r *Interceptor) Close() error { + defer r.wg.Wait() + + r.lock.Lock() + defer r.lock.Unlock() + + for _, r := range r.recorders { + r.Stop() + } + + return nil +} + +// BindRTCPReader lets you modify any incoming RTCP packets. It is called once per sender/receiver, however this might +// change in the future. The returned method will be called once per packet batch. +func (r *Interceptor) BindRTCPReader(reader interceptor.RTCPReader) interceptor.RTCPReader { + return interceptor.RTCPReaderFunc( + func(bytes []byte, attributes interceptor.Attributes) (int, interceptor.Attributes, error) { + n, attattributes, err := reader.Read(bytes, attributes) + if err != nil { + return 0, attattributes, err + } + r.lock.Lock() + for _, recorder := range r.recorders { + recorder.QueueIncomingRTCP(r.now(), bytes[:n], attributes) + } + r.lock.Unlock() + + return n, attattributes, err + }, + ) +} + +// BindRTCPWriter lets you modify any outgoing RTCP packets. It is called once per PeerConnection. The returned method +// will be called once per packet batch. +func (r *Interceptor) BindRTCPWriter(writer interceptor.RTCPWriter) interceptor.RTCPWriter { + return interceptor.RTCPWriterFunc(func(pkts []rtcp.Packet, attributes interceptor.Attributes) (int, error) { + r.lock.Lock() + for _, recorder := range r.recorders { + recorder.QueueOutgoingRTCP(r.now(), pkts, attributes) + } + r.lock.Unlock() + + return writer.Write(pkts, attributes) + }) +} + +// BindLocalStream lets you modify any outgoing RTP packets. It is called once for per LocalStream. +// The returned method will be called once per rtp packet. +func (r *Interceptor) BindLocalStream( + info *interceptor.StreamInfo, writer interceptor.RTPWriter, +) interceptor.RTPWriter { + recorder := r.getRecorder(info.SSRC, float64(info.ClockRate)) + + return interceptor.RTPWriterFunc( + func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) { + recorder.QueueOutgoingRTP(r.now(), header, payload, attributes) + + return writer.Write(header, payload, attributes) + }, + ) +} + +// BindRemoteStream lets you modify any incoming RTP packets. It is called once for per RemoteStream. +// The returned method will be called once per rtp packet. +func (r *Interceptor) BindRemoteStream( + info *interceptor.StreamInfo, reader interceptor.RTPReader, +) interceptor.RTPReader { + recorder := r.getRecorder(info.SSRC, float64(info.ClockRate)) + + return interceptor.RTPReaderFunc( + func(bytes []byte, attributes interceptor.Attributes) (int, interceptor.Attributes, error) { + n, attributes, err := reader.Read(bytes, attributes) + if err != nil { + return 0, nil, err + } + recorder.QueueIncomingRTP(r.now(), bytes[:n], attributes) + + return n, attributes, nil + }, + ) +} diff --git a/vendor/github.com/pion/interceptor/pkg/stats/received_stats.go b/vendor/github.com/pion/interceptor/pkg/stats/received_stats.go new file mode 100644 index 0000000..2db485f --- /dev/null +++ b/vendor/github.com/pion/interceptor/pkg/stats/received_stats.go @@ -0,0 +1,74 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package stats + +import ( + "fmt" + "time" +) + +// ReceivedRTPStreamStats contains common receiver stats of RTP streams. +type ReceivedRTPStreamStats struct { + PacketsReceived uint64 + PacketsLost int64 + Jitter float64 +} + +// String returns a string representation of ReceivedRTPStreamStats. +func (s ReceivedRTPStreamStats) String() string { + out := fmt.Sprintf("\tPacketsReceived: %v\n", s.PacketsReceived) + out += fmt.Sprintf("\tPacketsLost: %v\n", s.PacketsLost) + out += fmt.Sprintf("\tJitter: %v\n", s.Jitter) + + return out +} + +// InboundRTPStreamStats contains stats of inbound RTP streams. +type InboundRTPStreamStats struct { + ReceivedRTPStreamStats + + LastPacketReceivedTimestamp time.Time + HeaderBytesReceived uint64 + BytesReceived uint64 + FIRCount uint32 + PLICount uint32 + NACKCount uint32 +} + +// String returns a string representation of InboundRTPStreamStats. +func (s InboundRTPStreamStats) String() string { + out := "InboundRTPStreamStats:\n" + out += s.ReceivedRTPStreamStats.String() + out += fmt.Sprintf("\tLastPacketReceivedTimestamp: %v\n", s.LastPacketReceivedTimestamp) + out += fmt.Sprintf("\tHeaderBytesReceived: %v\n", s.HeaderBytesReceived) + out += fmt.Sprintf("\tBytesReceived: %v\n", s.BytesReceived) + out += fmt.Sprintf("\tFIRCount: %v\n", s.FIRCount) + out += fmt.Sprintf("\tPLICount: %v\n", s.PLICount) + out += fmt.Sprintf("\tNACKCount: %v\n", s.NACKCount) + + return out +} + +// RemoteInboundRTPStreamStats contains stats of inbound RTP streams of the +// remote peer. +type RemoteInboundRTPStreamStats struct { + ReceivedRTPStreamStats + + RoundTripTime time.Duration + TotalRoundTripTime time.Duration + FractionLost float64 + RoundTripTimeMeasurements uint64 +} + +// String returns a string representation of RemoteInboundRTPStreamStats. +func (s RemoteInboundRTPStreamStats) String() string { + out := "RemoteInboundRTPStreamStats:\n" + out += s.ReceivedRTPStreamStats.String() + out += fmt.Sprintf("\tRoundTripTime: %v\n", s.RoundTripTime) + out += fmt.Sprintf("\tTotalRoundTripTime: %v\n", s.TotalRoundTripTime) + out += fmt.Sprintf("\tFractionLost: %v\n", s.FractionLost) + out += fmt.Sprintf("\tRoundTripTimeMeasurements: %v\n", s.RoundTripTimeMeasurements) + + return out +} diff --git a/vendor/github.com/pion/interceptor/pkg/stats/sent_stats.go b/vendor/github.com/pion/interceptor/pkg/stats/sent_stats.go new file mode 100644 index 0000000..b96bf7b --- /dev/null +++ b/vendor/github.com/pion/interceptor/pkg/stats/sent_stats.go @@ -0,0 +1,70 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package stats + +import ( + "fmt" + "time" +) + +// SentRTPStreamStats contains common sender stats of RTP streams. +type SentRTPStreamStats struct { + PacketsSent uint64 + BytesSent uint64 +} + +// String returns a string representation of SentRTPStreamStats. +func (s SentRTPStreamStats) String() string { + out := fmt.Sprintf("\tPacketsSent: %v\n", s.PacketsSent) + out += fmt.Sprintf("\tBytesSent: %v\n", s.BytesSent) + + return out +} + +// OutboundRTPStreamStats contains stats of outbound RTP streams. +type OutboundRTPStreamStats struct { + SentRTPStreamStats + + HeaderBytesSent uint64 + NACKCount uint32 + FIRCount uint32 + PLICount uint32 +} + +// String returns a string representation of OutboundRTPStreamStats. +func (s OutboundRTPStreamStats) String() string { + out := "OutboundRTPStreamStats\n" + out += s.SentRTPStreamStats.String() + out += fmt.Sprintf("\tHeaderBytesSent: %v\n", s.HeaderBytesSent) + out += fmt.Sprintf("\tNACKCount: %v\n", s.NACKCount) + out += fmt.Sprintf("\tFIRCount: %v\n", s.FIRCount) + out += fmt.Sprintf("\tPLICount: %v\n", s.PLICount) + + return out +} + +// RemoteOutboundRTPStreamStats contains stats of outbound RTP streams of the +// remote peer. +type RemoteOutboundRTPStreamStats struct { + SentRTPStreamStats + + RemoteTimeStamp time.Time + ReportsSent uint64 + RoundTripTime time.Duration + TotalRoundTripTime time.Duration + RoundTripTimeMeasurements uint64 +} + +// String returns a string representation of RemoteOutboundRTPStreamStats. +func (s RemoteOutboundRTPStreamStats) String() string { + out := "RemoteOutboundRTPStreamStats:\n" + out += s.SentRTPStreamStats.String() + out += fmt.Sprintf("\tRemoteTimeStamp: %v\n", s.RemoteTimeStamp) + out += fmt.Sprintf("\tReportsSent: %v\n", s.ReportsSent) + out += fmt.Sprintf("\tRoundTripTime: %v\n", s.RoundTripTime) + out += fmt.Sprintf("\tTotalRoundTripTime: %v\n", s.TotalRoundTripTime) + out += fmt.Sprintf("\tRoundTripTimeMeasurements: %v\n", s.RoundTripTimeMeasurements) + + return out +} diff --git a/vendor/github.com/pion/interceptor/pkg/stats/stats_recorder.go b/vendor/github.com/pion/interceptor/pkg/stats/stats_recorder.go new file mode 100644 index 0000000..29242cc --- /dev/null +++ b/vendor/github.com/pion/interceptor/pkg/stats/stats_recorder.go @@ -0,0 +1,406 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package stats + +import ( + "slices" + "sync" + "sync/atomic" + "time" + + "github.com/pion/interceptor" + "github.com/pion/interceptor/internal/ntp" + "github.com/pion/interceptor/internal/sequencenumber" + "github.com/pion/logging" + "github.com/pion/rtcp" + "github.com/pion/rtp" +) + +// Stats contains all the available statistics of RTP streams. +type Stats struct { + InboundRTPStreamStats + OutboundRTPStreamStats + RemoteInboundRTPStreamStats + RemoteOutboundRTPStreamStats +} + +type internalStats struct { + inboundSequencerNumber sequencenumber.Unwrapper + inboundSequenceNumberInitialized bool + inboundFirstSequenceNumber int64 + inboundHighestSequenceNumber int64 + + inboundLastArrivalInitialized bool + inboundLastArrival time.Time + inboundLastArrivalRTP uint32 + inboundLastTransit int + + remoteInboundFirstSequenceNumberInitialized bool + remoteInboundFirstSequenceNumber int64 + + lastSenderReports []uint64 + + lastReceiverReferenceTimes []uint64 + + InboundRTPStreamStats + OutboundRTPStreamStats + + RemoteInboundRTPStreamStats + RemoteOutboundRTPStreamStats +} + +type incomingRTP struct { + ts time.Time + header rtp.Header + payloadLen int + attr interceptor.Attributes +} + +type incomingRTCP struct { + ts time.Time + pkts []rtcp.Packet + attr interceptor.Attributes +} + +type outgoingRTP struct { + ts time.Time + header rtp.Header + payloadLen int + attr interceptor.Attributes +} + +type outgoingRTCP struct { + ts time.Time + pkts []rtcp.Packet + attr interceptor.Attributes +} + +type recorder struct { + logger logging.LeveledLogger + + ssrc uint32 + clockRate float64 + + maxLastSenderReports int + maxLastReceiverReferenceTimes int + + latestStats *internalStats + ms *sync.Mutex // Locks latestStats + running uint32 +} + +func newRecorder(ssrc uint32, clockRate float64, loggerFactory logging.LoggerFactory) *recorder { + return &recorder{ + logger: loggerFactory.NewLogger("stats_recorder"), + ssrc: ssrc, + clockRate: clockRate, + maxLastSenderReports: 5, + maxLastReceiverReferenceTimes: 5, + latestStats: &internalStats{}, + ms: &sync.Mutex{}, + } +} + +func (r *recorder) Stop() { + atomic.StoreUint32(&r.running, 0) +} + +func (r *recorder) GetStats() Stats { + r.ms.Lock() + defer r.ms.Unlock() + + return Stats{ + InboundRTPStreamStats: r.latestStats.InboundRTPStreamStats, + OutboundRTPStreamStats: r.latestStats.OutboundRTPStreamStats, + RemoteInboundRTPStreamStats: r.latestStats.RemoteInboundRTPStreamStats, + RemoteOutboundRTPStreamStats: r.latestStats.RemoteOutboundRTPStreamStats, + } +} + +func (r *recorder) recordIncomingRTP(latestStats internalStats, incoming *incomingRTP) internalStats { + if incoming.header.SSRC != r.ssrc { + return latestStats + } + sequenceNumber := latestStats.inboundSequencerNumber.Unwrap(incoming.header.SequenceNumber) + if !latestStats.inboundSequenceNumberInitialized { + latestStats.inboundFirstSequenceNumber = sequenceNumber + latestStats.inboundSequenceNumberInitialized = true + } + if sequenceNumber > latestStats.inboundHighestSequenceNumber { + latestStats.inboundHighestSequenceNumber = sequenceNumber + } + + latestStats.InboundRTPStreamStats.PacketsReceived++ + expectedPackets := latestStats.inboundHighestSequenceNumber - latestStats.inboundFirstSequenceNumber + 1 + //nolint:gosec // G115 + latestStats.InboundRTPStreamStats.PacketsLost = expectedPackets - + int64(latestStats.InboundRTPStreamStats.PacketsReceived) + + if !latestStats.inboundLastArrivalInitialized { + latestStats.inboundLastArrival = incoming.ts + latestStats.inboundLastArrivalRTP = incoming.header.Timestamp + latestStats.inboundLastArrivalInitialized = true + } else { + rtpUnitsSinceLastArrival := incoming.ts.Sub(latestStats.inboundLastArrival).Seconds() * r.clockRate + arrival := latestStats.inboundLastArrivalRTP + uint32(rtpUnitsSinceLastArrival) + transit := int(arrival) - int(incoming.header.Timestamp) + d := transit - latestStats.inboundLastTransit + if d < 0 { + d = -d + } + dSec := float64(d) / r.clockRate + latestStats.inboundLastTransit = transit + latestStats.InboundRTPStreamStats.Jitter += (1.0 / 16.0) * (dSec - latestStats.InboundRTPStreamStats.Jitter) + latestStats.inboundLastArrival = incoming.ts + latestStats.inboundLastArrivalRTP = incoming.header.Timestamp + } + + latestStats.LastPacketReceivedTimestamp = incoming.ts + latestStats.HeaderBytesReceived += uint64(incoming.header.MarshalSize()) //nolint:gosec // G115 + latestStats.BytesReceived += uint64(incoming.header.MarshalSize() + incoming.payloadLen) //nolint:gosec // G115 + + return latestStats +} + +//nolint:cyclop +func (r *recorder) recordOutgoingRTCP(latestStats internalStats, v *outgoingRTCP) internalStats { + for _, pkt := range v.pkts { + // The SSRC check is performed for most of the cases but not all. The + // reason is that ReceiverReferenceTimeReportBlocks don't have + // destination SSRCs but must still be recorded. + switch rtcpPkt := pkt.(type) { + case *rtcp.FullIntraRequest: + if !contains(pkt.DestinationSSRC(), r.ssrc) { + r.logger.Debugf("skipping outgoing RTCP pkt: %v", pkt) + + continue + } + latestStats.InboundRTPStreamStats.FIRCount++ + case *rtcp.PictureLossIndication: + if !contains(pkt.DestinationSSRC(), r.ssrc) { + r.logger.Debugf("skipping outgoing RTCP pkt: %v", pkt) + + continue + } + latestStats.InboundRTPStreamStats.PLICount++ + case *rtcp.TransportLayerNack: + if !contains(pkt.DestinationSSRC(), r.ssrc) { + r.logger.Debugf("skipping outgoing RTCP pkt: %v", pkt) + + continue + } + latestStats.InboundRTPStreamStats.NACKCount++ + case *rtcp.SenderReport: + if !contains(pkt.DestinationSSRC(), r.ssrc) { + r.logger.Debugf("skipping outgoing RTCP pkt: %v", pkt) + + continue + } + latestStats.lastSenderReports = append(latestStats.lastSenderReports, rtcpPkt.NTPTime) + if len(latestStats.lastSenderReports) > r.maxLastSenderReports { + latestStats.lastSenderReports = latestStats.lastSenderReports[len( + latestStats.lastSenderReports, + )-r.maxLastSenderReports:] + } + case *rtcp.ExtendedReport: + for _, block := range rtcpPkt.Reports { + if xr, ok := block.(*rtcp.ReceiverReferenceTimeReportBlock); ok { + latestStats.lastReceiverReferenceTimes = append(latestStats.lastReceiverReferenceTimes, xr.NTPTimestamp) + if len(latestStats.lastReceiverReferenceTimes) > r.maxLastReceiverReferenceTimes { + latestStats.lastReceiverReferenceTimes = latestStats.lastReceiverReferenceTimes[len( + latestStats.lastReceiverReferenceTimes, + )-r.maxLastReceiverReferenceTimes:] + } + } + } + } + } + + return latestStats +} + +func (r *recorder) recordOutgoingRTP(latestStats internalStats, v *outgoingRTP) internalStats { + if v.header.SSRC != r.ssrc { + return latestStats + } + headerSize := v.header.MarshalSize() + latestStats.OutboundRTPStreamStats.PacketsSent++ + latestStats.OutboundRTPStreamStats.BytesSent += uint64(headerSize + v.payloadLen) //nolint:gosec // G115 + latestStats.HeaderBytesSent += uint64(headerSize) //nolint:gosec // G115 + if !latestStats.remoteInboundFirstSequenceNumberInitialized { + latestStats.remoteInboundFirstSequenceNumber = int64(v.header.SequenceNumber) + latestStats.remoteInboundFirstSequenceNumberInitialized = true + } + + return latestStats +} + +func (r *recorder) recordIncomingRR(latestStats internalStats, pkt *rtcp.ReceiverReport, ts time.Time) internalStats { + for _, report := range pkt.Reports { + if latestStats.remoteInboundFirstSequenceNumberInitialized { + cycles := uint64(report.LastSequenceNumber&0xFFFF0000) >> 16 + nr := uint64(report.LastSequenceNumber & 0x0000FFFF) + highest := cycles*(0xFFFF+1) + nr + //nolint:gosec // G115 + expected := int64(highest) - latestStats.remoteInboundFirstSequenceNumber + 1 + received := max(expected-int64(report.TotalLost), 0) + //nolint:gosec // G115 + latestStats.RemoteInboundRTPStreamStats.PacketsReceived = uint64(received) + } + latestStats.RemoteInboundRTPStreamStats.PacketsLost = int64(report.TotalLost) + latestStats.RemoteInboundRTPStreamStats.Jitter = float64(report.Jitter) / r.clockRate + + if report.Delay != 0 && report.LastSenderReport != 0 { + for i := min(r.maxLastSenderReports, len(latestStats.lastSenderReports)) - 1; i >= 0; i-- { + lastReport := latestStats.lastSenderReports[i] + if (lastReport&0x0000FFFFFFFF0000)>>16 == uint64(report.LastSenderReport) { + dlsr := time.Duration(float64(report.Delay) / 65536.0 * float64(time.Second)) + latestStats.RemoteInboundRTPStreamStats.RoundTripTime = (ts.Add(-dlsr)).Sub(ntp.ToTime(lastReport)) + latestStats.RemoteInboundRTPStreamStats.TotalRoundTripTime += latestStats.RemoteInboundRTPStreamStats.RoundTripTime + latestStats.RemoteInboundRTPStreamStats.RoundTripTimeMeasurements++ + + break + } + } + } + latestStats.FractionLost = float64(report.FractionLost) / 256.0 + } + + return latestStats +} + +func (r *recorder) recordIncomingXR(latestStats internalStats, pkt *rtcp.ExtendedReport, ts time.Time) internalStats { + for _, report := range pkt.Reports { + if xr, ok := report.(*rtcp.DLRRReportBlock); ok { + for _, xrReport := range xr.Reports { + if xrReport.LastRR != 0 && xrReport.DLRR != 0 { + for i := min(r.maxLastReceiverReferenceTimes, len(latestStats.lastReceiverReferenceTimes)) - 1; i >= 0; i-- { + lastRR := latestStats.lastReceiverReferenceTimes[i] + if (lastRR&0x0000FFFFFFFF0000)>>16 == uint64(xrReport.LastRR) { + dlrr := time.Duration(float64(xrReport.DLRR) / 65536.0 * float64(time.Second)) + latestStats.RemoteOutboundRTPStreamStats.RoundTripTime = (ts.Add(-dlrr)).Sub(ntp.ToTime(lastRR)) + //nolint:lll + latestStats.RemoteOutboundRTPStreamStats.TotalRoundTripTime += latestStats.RemoteOutboundRTPStreamStats.RoundTripTime + latestStats.RemoteOutboundRTPStreamStats.RoundTripTimeMeasurements++ + } + } + } + } + } + } + + return latestStats +} + +func contains(ls []uint32, e uint32) bool { + return slices.Contains(ls, e) +} + +func (r *recorder) recordIncomingRTCP(latestStats internalStats, incoming *incomingRTCP) internalStats { + for _, pkt := range incoming.pkts { + if !contains(pkt.DestinationSSRC(), r.ssrc) { + r.logger.Debugf("skipping incoming RTCP pkt: %v", pkt) + + continue + } + switch pkt := pkt.(type) { + case *rtcp.TransportLayerNack: + latestStats.OutboundRTPStreamStats.NACKCount++ + case *rtcp.FullIntraRequest: + latestStats.OutboundRTPStreamStats.FIRCount++ + case *rtcp.PictureLossIndication: + latestStats.OutboundRTPStreamStats.PLICount++ + case *rtcp.ReceiverReport: + latestStats = r.recordIncomingRR(latestStats, pkt, incoming.ts) + case *rtcp.SenderReport: + latestStats.RemoteOutboundRTPStreamStats.PacketsSent = uint64(pkt.PacketCount) + latestStats.RemoteOutboundRTPStreamStats.BytesSent = uint64(pkt.OctetCount) + latestStats.RemoteTimeStamp = ntp.ToTime(pkt.NTPTime) + latestStats.ReportsSent++ + + case *rtcp.ExtendedReport: + return r.recordIncomingXR(latestStats, pkt, incoming.ts) + } + } + + return latestStats +} + +func (r *recorder) Start() { + atomic.StoreUint32(&r.running, 1) +} + +func (r *recorder) QueueIncomingRTP(ts time.Time, buf []byte, attr interceptor.Attributes) { + if atomic.LoadUint32(&r.running) == 0 { + return + } + if attr == nil { + attr = make(interceptor.Attributes) + } + header, err := attr.GetRTPHeader(buf) + if err != nil { + r.logger.Warnf("failed to get RTP Header, skipping incoming RTP packet in stats calculation: %v", err) + + return + } + hdr := header.Clone() + r.ms.Lock() + *r.latestStats = r.recordIncomingRTP(*r.latestStats, &incomingRTP{ + ts: ts, + header: hdr, + payloadLen: len(buf) - hdr.MarshalSize(), + attr: attr, + }) + r.ms.Unlock() +} + +func (r *recorder) QueueIncomingRTCP(ts time.Time, buf []byte, attr interceptor.Attributes) { + if atomic.LoadUint32(&r.running) == 0 { + return + } + if attr == nil { + attr = make(interceptor.Attributes) + } + pkts, err := attr.GetRTCPPackets(buf) + if err != nil { + r.logger.Warnf("failed to get RTCP packets, skipping incoming RTCP packet in stats calculation: %v", err) + + return + } + r.ms.Lock() + *r.latestStats = r.recordIncomingRTCP(*r.latestStats, &incomingRTCP{ + ts: ts, + pkts: pkts, + attr: attr, + }) + r.ms.Unlock() +} + +func (r *recorder) QueueOutgoingRTP(ts time.Time, header *rtp.Header, payload []byte, attr interceptor.Attributes) { + if atomic.LoadUint32(&r.running) == 0 { + return + } + hdr := header.Clone() + r.ms.Lock() + *r.latestStats = r.recordOutgoingRTP(*r.latestStats, &outgoingRTP{ + ts: ts, + header: hdr, + payloadLen: len(payload), + attr: attr, + }) + r.ms.Unlock() +} + +func (r *recorder) QueueOutgoingRTCP(ts time.Time, pkts []rtcp.Packet, attr interceptor.Attributes) { + if atomic.LoadUint32(&r.running) == 0 { + return + } + r.ms.Lock() + *r.latestStats = r.recordOutgoingRTCP(*r.latestStats, &outgoingRTCP{ + ts: ts, + pkts: pkts, + attr: attr, + }) + r.ms.Unlock() +} diff --git a/vendor/github.com/pion/interceptor/pkg/twcc/arrival_time_map.go b/vendor/github.com/pion/interceptor/pkg/twcc/arrival_time_map.go new file mode 100644 index 0000000..3b0b394 --- /dev/null +++ b/vendor/github.com/pion/interceptor/pkg/twcc/arrival_time_map.go @@ -0,0 +1,204 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package twcc + +const ( + minCapacity = 128 + maxNumberOfPackets = 1 << 15 +) + +// packetArrivalTimeMap is adapted from Chrome's implementation of TWCC, and keeps track +// of the arrival times of packets. It is used by the TWCC interceptor to build feedback +// packets. +// See https://source.chromium.org/chromium/chromium/src/+/refs/heads/main:third_party/webrtc/modules/remote_bitrate_estimator/packet_arrival_map.h;drc=b5cd13bb6d5d157a5fbe3628b2dd1c1e106203c6 +// +//nolint:lll +type packetArrivalTimeMap struct { + // arrivalTimes is a circular buffer, where the packet with sequence number sn is stored + // in slot sn % len(arrivalTimes) + arrivalTimes []int64 + + // The unwrapped sequence numbers for the range of valid sequence numbers in arrivalTimes. + // beginSequenceNumber is inclusive, and endSequenceNumber is exclusive. + beginSequenceNumber, endSequenceNumber int64 +} + +// AddPacket records the fact that the packet with sequence number sequenceNumber arrived +// at arrivalTime. +func (m *packetArrivalTimeMap) AddPacket(sequenceNumber int64, arrivalTime int64) { + if m.arrivalTimes == nil { + // First packet + m.reallocate(minCapacity) + m.beginSequenceNumber = sequenceNumber + m.endSequenceNumber = sequenceNumber + 1 + m.arrivalTimes[m.index(sequenceNumber)] = arrivalTime + + return + } + + if sequenceNumber >= m.beginSequenceNumber && sequenceNumber < m.endSequenceNumber { + // The packet is within the buffer, no need to resize. + m.arrivalTimes[m.index(sequenceNumber)] = arrivalTime + + return + } + + if sequenceNumber < m.beginSequenceNumber { + // The packet goes before the current buffer. Expand to add packet, + // but only if it fits within the maximum number of packets. + newSize := int(m.endSequenceNumber - sequenceNumber) + if newSize > maxNumberOfPackets { + // Don't expand the buffer back for this packet, as it would remove newer received + // packets. + return + } + m.adjustToSize(newSize) + m.arrivalTimes[m.index(sequenceNumber)] = arrivalTime + m.setNotReceived(sequenceNumber+1, m.beginSequenceNumber) + m.beginSequenceNumber = sequenceNumber + + return + } + + // The packet goes after the buffer. + newEndSequenceNumber := sequenceNumber + 1 + + if newEndSequenceNumber >= m.endSequenceNumber+maxNumberOfPackets { + // All old packets have to be removed. + m.beginSequenceNumber = sequenceNumber + m.endSequenceNumber = newEndSequenceNumber + m.arrivalTimes[m.index(sequenceNumber)] = arrivalTime + + return + } + + if m.beginSequenceNumber < newEndSequenceNumber-maxNumberOfPackets { + // Remove oldest entries. + m.beginSequenceNumber = newEndSequenceNumber - maxNumberOfPackets + } + + m.adjustToSize(int(newEndSequenceNumber - m.beginSequenceNumber)) + + // Packets can be received out of order. If this isn't the next expected packet, + // add enough placeholders to fill the gap. + m.setNotReceived(m.endSequenceNumber, sequenceNumber) + m.endSequenceNumber = newEndSequenceNumber + m.arrivalTimes[m.index(sequenceNumber)] = arrivalTime +} + +func (m *packetArrivalTimeMap) setNotReceived(startInclusive, endExclusive int64) { + for sn := startInclusive; sn < endExclusive; sn++ { + m.arrivalTimes[m.index(sn)] = -1 + } +} + +// BeginSequenceNumber returns the first valid sequence number in the map. +func (m *packetArrivalTimeMap) BeginSequenceNumber() int64 { + return m.beginSequenceNumber +} + +// EndSequenceNumber returns the first sequence number after the last valid sequence number in the map. +func (m *packetArrivalTimeMap) EndSequenceNumber() int64 { + return m.endSequenceNumber +} + +// FindNextAtOrAfter returns the sequence number and timestamp of the first received packet that has a sequence number +// greater or equal to sequenceNumber. +func (m *packetArrivalTimeMap) FindNextAtOrAfter(sequenceNumber int64) ( + int64, int64, bool, +) { + for seq := m.Clamp(sequenceNumber); seq < m.endSequenceNumber; seq++ { + if arrivalTime := m.get(seq); arrivalTime >= 0 { + return seq, arrivalTime, true + } + } + + return -1, -1, false +} + +// EraseTo erases all elements from the beginning of the map until sequenceNumber. +func (m *packetArrivalTimeMap) EraseTo(sequenceNumber int64) { + if sequenceNumber < m.beginSequenceNumber { + return + } + if sequenceNumber >= m.endSequenceNumber { + // Erase all. + m.beginSequenceNumber = m.endSequenceNumber + + return + } + // Remove some + m.beginSequenceNumber = sequenceNumber + m.adjustToSize(int(m.endSequenceNumber - m.beginSequenceNumber)) +} + +// RemoveOldPackets removes packets from the beginning of the map as long as they are before +// sequenceNumber and with an age older than arrivalTimeLimit. +func (m *packetArrivalTimeMap) RemoveOldPackets(sequenceNumber int64, arrivalTimeLimit int64) { + checkTo := min(sequenceNumber, m.endSequenceNumber) + for m.beginSequenceNumber < checkTo && m.get(m.beginSequenceNumber) <= arrivalTimeLimit { + m.beginSequenceNumber++ + } + m.adjustToSize(int(m.endSequenceNumber - m.beginSequenceNumber)) +} + +// HasReceived returns whether a packet with the sequence number has been received. +func (m *packetArrivalTimeMap) HasReceived(sequenceNumber int64) bool { + return m.get(sequenceNumber) >= 0 +} + +// Clamp returns sequenceNumber clamped to [beginSequenceNumber, endSequenceNumber]. +func (m *packetArrivalTimeMap) Clamp(sequenceNumber int64) int64 { + if sequenceNumber < m.beginSequenceNumber { + return m.beginSequenceNumber + } + if m.endSequenceNumber < sequenceNumber { + return m.endSequenceNumber + } + + return sequenceNumber +} + +func (m *packetArrivalTimeMap) get(sequenceNumber int64) int64 { + if sequenceNumber < m.beginSequenceNumber || sequenceNumber >= m.endSequenceNumber { + return -1 + } + + return m.arrivalTimes[m.index(sequenceNumber)] +} + +func (m *packetArrivalTimeMap) index(sequenceNumber int64) int { + // Sequence number might be negative, and we always guarantee that arrivalTimes + // length is a power of 2, so it's easier to use "&" instead of "%" + return int(sequenceNumber & int64(m.capacity()-1)) +} + +func (m *packetArrivalTimeMap) adjustToSize(newSize int) { + if newSize > m.capacity() { + newCapacity := m.capacity() + for newCapacity < newSize { + newCapacity *= 2 + } + m.reallocate(newCapacity) + } + if m.capacity() > max(minCapacity, newSize*4) { + newCapacity := m.capacity() + for newCapacity >= 2*max(newSize, minCapacity) { + newCapacity /= 2 + } + m.reallocate(newCapacity) + } +} + +func (m *packetArrivalTimeMap) capacity() int { + return len(m.arrivalTimes) +} + +func (m *packetArrivalTimeMap) reallocate(newCapacity int) { + newBuffer := make([]int64, newCapacity) + for sn := m.beginSequenceNumber; sn < m.endSequenceNumber; sn++ { + newBuffer[int(sn&(int64(newCapacity-1)))] = m.get(sn) + } + m.arrivalTimes = newBuffer +} diff --git a/vendor/github.com/pion/interceptor/pkg/twcc/header_extension_interceptor.go b/vendor/github.com/pion/interceptor/pkg/twcc/header_extension_interceptor.go new file mode 100644 index 0000000..143aa5e --- /dev/null +++ b/vendor/github.com/pion/interceptor/pkg/twcc/header_extension_interceptor.go @@ -0,0 +1,74 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package twcc + +import ( + "errors" + "sync/atomic" + + "github.com/pion/interceptor" + "github.com/pion/rtp" +) + +var errHeaderIsNil = errors.New("header is nil") + +// HeaderExtensionInterceptorFactory is a interceptor.Factory for a HeaderExtensionInterceptor. +type HeaderExtensionInterceptorFactory struct{} + +// NewInterceptor constructs a new HeaderExtensionInterceptor. +func (h *HeaderExtensionInterceptorFactory) NewInterceptor(_ string) (interceptor.Interceptor, error) { + return &HeaderExtensionInterceptor{}, nil +} + +// NewHeaderExtensionInterceptor returns a HeaderExtensionInterceptorFactory. +func NewHeaderExtensionInterceptor() (*HeaderExtensionInterceptorFactory, error) { + return &HeaderExtensionInterceptorFactory{}, nil +} + +// HeaderExtensionInterceptor adds transport wide sequence numbers as header extension to each RTP packet. +type HeaderExtensionInterceptor struct { + interceptor.NoOp + nextSequenceNr uint32 +} + +const transportCCURI = "http://www.ietf.org/id/draft-holmer-rmcat-transport-wide-cc-extensions-01" + +// BindLocalStream returns a writer that adds a rtp.TransportCCExtension +// header with increasing sequence numbers to each outgoing packet. +func (h *HeaderExtensionInterceptor) BindLocalStream( + info *interceptor.StreamInfo, + writer interceptor.RTPWriter, +) interceptor.RTPWriter { + var hdrExtID uint8 + for _, e := range info.RTPHeaderExtensions { + if e.URI == transportCCURI { + hdrExtID = uint8(e.ID) //nolint:gosec // G115 + + break + } + } + if hdrExtID == 0 { // Don't add header extension if ID is 0, because 0 is an invalid extension ID + return writer + } + + return interceptor.RTPWriterFunc( + func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) { + sequenceNumber := atomic.AddUint32(&h.nextSequenceNr, 1) - 1 + //nolint:gosec // G115 + tcc, err := (&rtp.TransportCCExtension{TransportSequence: uint16(sequenceNumber)}).Marshal() + if err != nil { + return 0, err + } + if header == nil { + return 0, errHeaderIsNil + } + err = header.SetExtension(hdrExtID, tcc) + if err != nil { + return 0, err + } + + return writer.Write(header, payload, attributes) + }, + ) +} diff --git a/vendor/github.com/pion/interceptor/pkg/twcc/sender_interceptor.go b/vendor/github.com/pion/interceptor/pkg/twcc/sender_interceptor.go new file mode 100644 index 0000000..782dd1b --- /dev/null +++ b/vendor/github.com/pion/interceptor/pkg/twcc/sender_interceptor.go @@ -0,0 +1,234 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package twcc + +import ( + "errors" + "math/rand" + "sync" + "time" + + "github.com/pion/interceptor" + "github.com/pion/logging" + "github.com/pion/rtp" +) + +// SenderInterceptorFactory is a interceptor.Factory for a SenderInterceptor. +type SenderInterceptorFactory struct { + opts []Option +} + +var errClosed = errors.New("interceptor is closed") + +// NewInterceptor constructs a new SenderInterceptor. +func (s *SenderInterceptorFactory) NewInterceptor(_ string) (interceptor.Interceptor, error) { + senderInterceptor := &SenderInterceptor{ + packetChan: make(chan packet), + close: make(chan struct{}), + interval: 100 * time.Millisecond, + startTime: time.Now(), + } + + for _, opt := range s.opts { + err := opt(senderInterceptor) + if err != nil { + return nil, err + } + } + + if senderInterceptor.loggerFactory == nil { + senderInterceptor.loggerFactory = logging.NewDefaultLoggerFactory() + } + if senderInterceptor.log == nil { + senderInterceptor.log = senderInterceptor.loggerFactory.NewLogger("twcc_sender_interceptor") + } + + return senderInterceptor, nil +} + +// NewSenderInterceptor returns a new SenderInterceptorFactory configured with the given options. +func NewSenderInterceptor(opts ...Option) (*SenderInterceptorFactory, error) { + return &SenderInterceptorFactory{opts: opts}, nil +} + +// SenderInterceptor sends transport wide congestion control reports as specified in: +// https://datatracker.ietf.org/doc/html/draft-holmer-rmcat-transport-wide-cc-extensions-01 +type SenderInterceptor struct { + interceptor.NoOp + + log logging.LeveledLogger + loggerFactory logging.LoggerFactory + + m sync.Mutex + wg sync.WaitGroup + close chan struct{} + + interval time.Duration + startTime time.Time + + recorder *Recorder + packetChan chan packet +} + +// An Option is a function that can be used to configure a SenderInterceptor. +type Option func(*SenderInterceptor) error + +// SendInterval sets the interval at which the interceptor +// will send new feedback reports. +func SendInterval(interval time.Duration) Option { + return func(s *SenderInterceptor) error { + s.interval = interval + + return nil + } +} + +// WithLoggerFactory sets the logger factory for the interceptor. +func WithLoggerFactory(loggerFactory logging.LoggerFactory) Option { + return func(s *SenderInterceptor) error { + s.loggerFactory = loggerFactory + + return nil + } +} + +// BindRTCPWriter lets you modify any outgoing RTCP packets. It is called once per PeerConnection. The returned method +// will be called once per packet batch. +func (s *SenderInterceptor) BindRTCPWriter(writer interceptor.RTCPWriter) interceptor.RTCPWriter { + s.m.Lock() + defer s.m.Unlock() + + s.recorder = NewRecorder(rand.Uint32()) // #nosec + + if s.isClosed() { + return writer + } + + s.wg.Add(1) + + go s.loop(writer) + + return writer +} + +type packet struct { + hdr *rtp.Header + sequenceNumber uint16 + arrivalTime int64 + ssrc uint32 +} + +// BindRemoteStream lets you modify any incoming RTP packets. +// It is called once for per RemoteStream. The returned method +// will be called once per rtp packet. +// +//nolint:cyclop +func (s *SenderInterceptor) BindRemoteStream( + info *interceptor.StreamInfo, reader interceptor.RTPReader, +) interceptor.RTPReader { + var hdrExtID uint8 + for _, e := range info.RTPHeaderExtensions { + if e.URI == transportCCURI { + hdrExtID = uint8(e.ID) //nolint:gosec // G115 + + break + } + } + if hdrExtID == 0 { // Don't try to read header extension if ID is 0, because 0 is an invalid extension ID + return reader + } + + return interceptor.RTPReaderFunc( + func(buf []byte, attributes interceptor.Attributes) (int, interceptor.Attributes, error) { + i, attr, err := reader.Read(buf, attributes) + if err != nil { + return 0, nil, err + } + + if attr == nil { + attr = make(interceptor.Attributes) + } + header, err := attr.GetRTPHeader(buf[:i]) + if err != nil { + return 0, nil, err + } + var tccExt rtp.TransportCCExtension + if ext := header.GetExtension(hdrExtID); ext != nil { + err = tccExt.Unmarshal(ext) + if err != nil { + return 0, nil, err + } + + p := packet{ + hdr: header, + sequenceNumber: tccExt.TransportSequence, + arrivalTime: time.Since(s.startTime).Microseconds(), + ssrc: info.SSRC, + } + select { + case <-s.close: + return 0, nil, errClosed + case s.packetChan <- p: + } + } + + return i, attr, nil + }, + ) +} + +// Close closes the interceptor. +func (s *SenderInterceptor) Close() error { + defer s.wg.Wait() + s.m.Lock() + defer s.m.Unlock() + + if !s.isClosed() { + close(s.close) + } + + return nil +} + +func (s *SenderInterceptor) isClosed() bool { + select { + case <-s.close: + return true + default: + return false + } +} + +func (s *SenderInterceptor) loop(writer interceptor.RTCPWriter) { + defer s.wg.Done() + + select { + case <-s.close: + return + case p := <-s.packetChan: + s.recorder.Record(p.ssrc, p.sequenceNumber, p.arrivalTime) + } + + ticker := time.NewTicker(s.interval) + for { + select { + case <-s.close: + ticker.Stop() + + return + case p := <-s.packetChan: + s.recorder.Record(p.ssrc, p.sequenceNumber, p.arrivalTime) + + case <-ticker.C: + // build and send twcc + pkts := s.recorder.BuildFeedbackPacket() + if len(pkts) == 0 { + continue + } + if _, err := writer.Write(pkts, nil); err != nil { + s.log.Error(err.Error()) + } + } + } +} diff --git a/vendor/github.com/pion/interceptor/pkg/twcc/twcc.go b/vendor/github.com/pion/interceptor/pkg/twcc/twcc.go new file mode 100644 index 0000000..2648e85 --- /dev/null +++ b/vendor/github.com/pion/interceptor/pkg/twcc/twcc.go @@ -0,0 +1,349 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +// Package twcc provides interceptors to implement transport wide congestion control. +package twcc + +import ( + "math" + + "github.com/pion/interceptor/internal/sequencenumber" + "github.com/pion/rtcp" +) + +const ( + packetWindowMicroseconds = 500_000 + maxMissingSequenceNumbers = 0x7FFE +) + +// Recorder records incoming RTP packets and their delays and creates +// transport wide congestion control feedback reports as specified in +// https://datatracker.ietf.org/doc/html/draft-holmer-rmcat-transport-wide-cc-extensions-01 +type Recorder struct { + arrivalTimeMap packetArrivalTimeMap + + sequenceUnwrapper sequencenumber.Unwrapper + + // startSequenceNumber is the first sequence number that will be included in the the + // next feedback packet. + startSequenceNumber *int64 + + senderSSRC uint32 + mediaSSRC uint32 + fbPktCnt uint8 + + packetsHeld int +} + +// NewRecorder creates a new Recorder which uses the given senderSSRC in the created +// feedback packets. +func NewRecorder(senderSSRC uint32) *Recorder { + return &Recorder{ + senderSSRC: senderSSRC, + } +} + +// Record marks a packet with mediaSSRC and a transport wide sequence number sequenceNumber as received at arrivalTime. +func (r *Recorder) Record(mediaSSRC uint32, sequenceNumber uint16, arrivalTime int64) { + r.mediaSSRC = mediaSSRC + + // "Unwrap" the sequence number to get a monotonically increasing sequence number that + // won't wrap around after math.MaxUint16. + unwrappedSN := r.sequenceUnwrapper.Unwrap(sequenceNumber) + r.maybeCullOldPackets(unwrappedSN, arrivalTime) + if r.startSequenceNumber == nil || unwrappedSN < *r.startSequenceNumber { + r.startSequenceNumber = &unwrappedSN + } + + // We are only interested in the first time a packet is received. + if r.arrivalTimeMap.HasReceived(unwrappedSN) { + return + } + + r.arrivalTimeMap.AddPacket(unwrappedSN, arrivalTime) + r.packetsHeld++ + + // Limit the range of sequence numbers to send feedback for. + if *r.startSequenceNumber < r.arrivalTimeMap.BeginSequenceNumber() { + sn := r.arrivalTimeMap.BeginSequenceNumber() + r.startSequenceNumber = &sn + } +} + +func (r *Recorder) maybeCullOldPackets(sequenceNumber int64, arrivalTime int64) { + if r.startSequenceNumber != nil && *r.startSequenceNumber >= r.arrivalTimeMap.EndSequenceNumber() && + arrivalTime >= packetWindowMicroseconds { + r.arrivalTimeMap.RemoveOldPackets(sequenceNumber, arrivalTime-packetWindowMicroseconds) + } +} + +// PacketsHeld returns the number of received packets currently held by the recorder. +func (r *Recorder) PacketsHeld() int { + return r.packetsHeld +} + +// BuildFeedbackPacket creates a new RTCP packet containing a TWCC feedback report. +func (r *Recorder) BuildFeedbackPacket() []rtcp.Packet { + if r.startSequenceNumber == nil { + return nil + } + + endSN := r.arrivalTimeMap.EndSequenceNumber() + var feedbacks []rtcp.Packet + for *r.startSequenceNumber < endSN { + feedback := r.maybeBuildFeedbackPacket(*r.startSequenceNumber, endSN) + if feedback == nil { + break + } + feedbacks = append(feedbacks, feedback.getRTCP()) + + // NOTE: we don't erase packets from the history in case they need to be resent + // after a reordering. They will be removed instead in Record when they get too + // old. + } + r.packetsHeld = 0 + + return feedbacks +} + +// maybeBuildFeedbackPacket builds a feedback packet starting from startSN (inclusive) until +// endSN (exclusive). +func (r *Recorder) maybeBuildFeedbackPacket(beginSeqNumInclusive, endSeqNumExclusive int64) *feedback { + // NOTE: The logic of this method is inspired by the implementation in Chrome. + // See https://source.chromium.org/chromium/chromium/src/+/refs/heads/main:third_party/webrtc/modules/remote_bitrate_estimator/remote_estimator_proxy.cc;l=276;drc=b5cd13bb6d5d157a5fbe3628b2dd1c1e106203c6 + //nolint:lll + startSNInclusive, endSNExclusive := r.arrivalTimeMap.Clamp(beginSeqNumInclusive), r.arrivalTimeMap.Clamp(endSeqNumExclusive) + + // Create feedback on demand, as we don't yet know if there are packets in the range that have been + // received. + var fb *feedback + + nextSequenceNumber := beginSeqNumInclusive + + for seq := startSNInclusive; seq < endSNExclusive; seq++ { + foundSeq, arrivalTime, ok := r.arrivalTimeMap.FindNextAtOrAfter(seq) + seq = foundSeq + if !ok || seq >= endSNExclusive { + break + } + + if fb == nil { + fb = newFeedback(r.senderSSRC, r.mediaSSRC, r.fbPktCnt) + r.fbPktCnt++ + + // It should be possible to add seq to this new packet. + // If the difference between seq and beginSeqNumInclusive is too large, discard + // reporting too old missing packets. + baseSequenceNumber := max(beginSeqNumInclusive, seq-maxMissingSequenceNumbers) + + // baseSequenceNumber is the expected first sequence number. This is known, + // but we may not have actually received it, so the base time should be the time + // of the first received packet in the feedback. + fb.setBase(uint16(baseSequenceNumber), arrivalTime) //nolint:gosec // G115 + + if !fb.addReceived(uint16(seq), arrivalTime) { //nolint:gosec // G115 + // Could not add a single received packet to the feedback. + // This is unexpected to actually occur, but if it does, we'll + // try again after skipping any missing packets. + // NOTE: It's fine that we already incremented fbPktCnt, as in essence + // we did actually "skip" a feedback (and this matches Chrome's behavior). + r.startSequenceNumber = &seq + + return nil + } + } else if !fb.addReceived(uint16(seq), arrivalTime) { //nolint:gosec // G115 + // Could not add timestamp. Packet may be full. Return + // and try again with a fresh packet. + break + } + + nextSequenceNumber = seq + 1 + } + + r.startSequenceNumber = &nextSequenceNumber + + return fb +} + +type feedback struct { + rtcp *rtcp.TransportLayerCC + baseSequenceNumber uint16 + refTimestamp64MS int64 + lastTimestampUS int64 + nextSequenceNumber uint16 + sequenceNumberCount uint16 + len int + lastChunk chunk + chunks []rtcp.PacketStatusChunk + deltas []*rtcp.RecvDelta +} + +func newFeedback(senderSSRC, mediaSSRC uint32, count uint8) *feedback { + return &feedback{ + rtcp: &rtcp.TransportLayerCC{ + SenderSSRC: senderSSRC, + MediaSSRC: mediaSSRC, + FbPktCount: count, + }, + } +} + +func (f *feedback) setBase(sequenceNumber uint16, timeUS int64) { + f.baseSequenceNumber = sequenceNumber + f.nextSequenceNumber = f.baseSequenceNumber + f.refTimestamp64MS = timeUS / 64e3 + f.lastTimestampUS = f.refTimestamp64MS * 64e3 +} + +func (f *feedback) getRTCP() *rtcp.TransportLayerCC { + f.rtcp.PacketStatusCount = f.sequenceNumberCount + f.rtcp.ReferenceTime = uint32(f.refTimestamp64MS) //nolint:gosec // G115 + f.rtcp.BaseSequenceNumber = f.baseSequenceNumber + for len(f.lastChunk.deltas) > 0 { + f.chunks = append(f.chunks, f.lastChunk.encode()) + } + f.rtcp.PacketChunks = append(f.rtcp.PacketChunks, f.chunks...) + f.rtcp.RecvDeltas = f.deltas + + // 4 bytes header + 16 bytes twcc header + 2 bytes for each chunk + length of deltas + padLen := 20 + len(f.rtcp.PacketChunks)*2 + f.len + padding := padLen%4 != 0 + for padLen%4 != 0 { + padLen++ + } + f.rtcp.Header = rtcp.Header{ + Count: rtcp.FormatTCC, + Type: rtcp.TypeTransportSpecificFeedback, + Padding: padding, + Length: uint16((padLen / 4) - 1), //nolint:gosec // G115 + } + + return f.rtcp +} + +func (f *feedback) addReceived(sequenceNumber uint16, timestampUS int64) bool { + deltaUS := timestampUS - f.lastTimestampUS + var delta250US int64 + if deltaUS >= 0 { + delta250US = (deltaUS + rtcp.TypeTCCDeltaScaleFactor/2) / rtcp.TypeTCCDeltaScaleFactor + } else { + delta250US = (deltaUS - rtcp.TypeTCCDeltaScaleFactor/2) / rtcp.TypeTCCDeltaScaleFactor + } + // delta doesn't fit into 16 bit, need to create new packet + if delta250US < math.MinInt16 || delta250US > math.MaxInt16 { + return false + } + deltaUSRounded := delta250US * rtcp.TypeTCCDeltaScaleFactor + + for ; f.nextSequenceNumber != sequenceNumber; f.nextSequenceNumber++ { + if !f.lastChunk.canAdd(rtcp.TypeTCCPacketNotReceived) { + f.chunks = append(f.chunks, f.lastChunk.encode()) + } + f.lastChunk.add(rtcp.TypeTCCPacketNotReceived) + f.sequenceNumberCount++ + } + + var recvDelta uint16 + switch { + case delta250US >= 0 && delta250US <= 0xff: + f.len++ + recvDelta = rtcp.TypeTCCPacketReceivedSmallDelta + default: + f.len += 2 + recvDelta = rtcp.TypeTCCPacketReceivedLargeDelta + } + + if !f.lastChunk.canAdd(recvDelta) { + f.chunks = append(f.chunks, f.lastChunk.encode()) + } + f.lastChunk.add(recvDelta) + f.deltas = append(f.deltas, &rtcp.RecvDelta{ + Type: recvDelta, + Delta: deltaUSRounded, + }) + f.lastTimestampUS += deltaUSRounded + f.sequenceNumberCount++ + f.nextSequenceNumber++ + + return true +} + +const ( + maxRunLengthCap = 0x1fff // 13 bits + maxOneBitCap = 14 // bits + maxTwoBitCap = 7 // bits +) + +type chunk struct { + hasLargeDelta bool + hasDifferentTypes bool + deltas []uint16 +} + +func (c *chunk) canAdd(delta uint16) bool { + if len(c.deltas) < maxTwoBitCap { + return true + } + if len(c.deltas) < maxOneBitCap && !c.hasLargeDelta && delta != rtcp.TypeTCCPacketReceivedLargeDelta { + return true + } + if len(c.deltas) < maxRunLengthCap && !c.hasDifferentTypes && delta == c.deltas[0] { + return true + } + + return false +} + +func (c *chunk) add(delta uint16) { + c.deltas = append(c.deltas, delta) + c.hasLargeDelta = c.hasLargeDelta || delta == rtcp.TypeTCCPacketReceivedLargeDelta + c.hasDifferentTypes = c.hasDifferentTypes || delta != c.deltas[0] +} + +func (c *chunk) encode() rtcp.PacketStatusChunk { + if !c.hasDifferentTypes { + defer c.reset() + + return &rtcp.RunLengthChunk{ + PacketStatusSymbol: c.deltas[0], + RunLength: uint16(len(c.deltas)), //nolint:gosec // G115 + } + } + if len(c.deltas) == maxOneBitCap { + defer c.reset() + + return &rtcp.StatusVectorChunk{ + SymbolSize: rtcp.TypeTCCSymbolSizeOneBit, + SymbolList: c.deltas, + } + } + + minCap := min(maxTwoBitCap, len(c.deltas)) + svc := &rtcp.StatusVectorChunk{ + SymbolSize: rtcp.TypeTCCSymbolSizeTwoBit, + SymbolList: c.deltas[:minCap], + } + c.deltas = c.deltas[minCap:] + c.hasDifferentTypes = false + c.hasLargeDelta = false + + if len(c.deltas) > 0 { + tmp := c.deltas[0] + for _, d := range c.deltas { + if tmp != d { + c.hasDifferentTypes = true + } + if d == rtcp.TypeTCCPacketReceivedLargeDelta { + c.hasLargeDelta = true + } + } + } + + return svc +} + +func (c *chunk) reset() { + c.deltas = []uint16{} + c.hasLargeDelta = false + c.hasDifferentTypes = false +} diff --git a/vendor/github.com/pion/interceptor/registry.go b/vendor/github.com/pion/interceptor/registry.go new file mode 100644 index 0000000..f0b6ab9 --- /dev/null +++ b/vendor/github.com/pion/interceptor/registry.go @@ -0,0 +1,33 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package interceptor + +// Registry is a collector for interceptors. +type Registry struct { + factories []Factory +} + +// Add adds a new Interceptor to the registry. +func (r *Registry) Add(f Factory) { + r.factories = append(r.factories, f) +} + +// Build constructs a single Interceptor from a InterceptorRegistry. +func (r *Registry) Build(id string) (Interceptor, error) { + if len(r.factories) == 0 { + return &NoOp{}, nil + } + + interceptors := make([]Interceptor, 0, len(r.factories)) + for _, f := range r.factories { + i, err := f.NewInterceptor(id) + if err != nil { + return nil, err + } + + interceptors = append(interceptors, i) + } + + return NewChain(interceptors), nil +} diff --git a/vendor/github.com/pion/interceptor/renovate.json b/vendor/github.com/pion/interceptor/renovate.json new file mode 100644 index 0000000..f1bb98c --- /dev/null +++ b/vendor/github.com/pion/interceptor/renovate.json @@ -0,0 +1,6 @@ +{ + "$schema": "https://docs.renovatebot.com/renovate-schema.json", + "extends": [ + "github>pion/renovate-config" + ] +} diff --git a/vendor/github.com/pion/interceptor/streaminfo.go b/vendor/github.com/pion/interceptor/streaminfo.go new file mode 100644 index 0000000..bcf3133 --- /dev/null +++ b/vendor/github.com/pion/interceptor/streaminfo.go @@ -0,0 +1,41 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package interceptor + +// RTPHeaderExtension represents a negotiated RFC5285 RTP header extension. +type RTPHeaderExtension struct { + URI string + ID int +} + +// StreamInfo is the Context passed when a StreamLocal or StreamRemote has been Binded or Unbinded. +type StreamInfo struct { + ID string + Attributes Attributes + SSRC uint32 + SSRCRetransmission uint32 + SSRCForwardErrorCorrection uint32 + PayloadType uint8 + PayloadTypeRetransmission uint8 + PayloadTypeForwardErrorCorrection uint8 + RTPHeaderExtensions []RTPHeaderExtension + MimeType string + ClockRate uint32 + Channels uint16 + SDPFmtpLine string + RTCPFeedback []RTCPFeedback +} + +// RTCPFeedback signals the connection to use additional RTCP packet types. +// https://draft.ortc.org/#dom-rtcrtcpfeedback +type RTCPFeedback struct { + // Type is the type of feedback. + // see: https://draft.ortc.org/#dom-rtcrtcpfeedback + // valid: ack, ccm, nack, goog-remb, transport-cc + Type string + + // The parameter value depends on the type. + // For example, type="nack" parameter="pli" will send Picture Loss Indicator packets. + Parameter string +} diff --git a/vendor/github.com/pion/logging/.gitignore b/vendor/github.com/pion/logging/.gitignore new file mode 100644 index 0000000..6e2f206 --- /dev/null +++ b/vendor/github.com/pion/logging/.gitignore @@ -0,0 +1,28 @@ +# SPDX-FileCopyrightText: 2023 The Pion community +# SPDX-License-Identifier: MIT + +### JetBrains IDE ### +##################### +.idea/ + +### Emacs Temporary Files ### +############################# +*~ + +### Folders ### +############### +bin/ +vendor/ +node_modules/ + +### Files ### +############# +*.ivf +*.ogg +tags +cover.out +*.sw[poe] +*.wasm +examples/sfu-ws/cert.pem +examples/sfu-ws/key.pem +wasm_exec.js diff --git a/vendor/github.com/pion/logging/.golangci.yml b/vendor/github.com/pion/logging/.golangci.yml new file mode 100644 index 0000000..59edee2 --- /dev/null +++ b/vendor/github.com/pion/logging/.golangci.yml @@ -0,0 +1,151 @@ +# SPDX-FileCopyrightText: 2023 The Pion community +# SPDX-License-Identifier: MIT + +run: + timeout: 5m + +linters-settings: + govet: + enable: + - shadow + misspell: + locale: US + exhaustive: + default-signifies-exhaustive: true + gomodguard: + blocked: + modules: + - github.com/pkg/errors: + recommendations: + - errors + forbidigo: + analyze-types: true + forbid: + - ^fmt.Print(f|ln)?$ + - ^log.(Panic|Fatal|Print)(f|ln)?$ + - ^os.Exit$ + - ^panic$ + - ^print(ln)?$ + - p: ^testing.T.(Error|Errorf|Fatal|Fatalf|Fail|FailNow)$ + pkg: ^testing$ + msg: "use testify/assert instead" + varnamelen: + max-distance: 12 + min-name-length: 2 + ignore-type-assert-ok: true + ignore-map-index-ok: true + ignore-chan-recv-ok: true + ignore-decls: + - i int + - n int + - w io.Writer + - r io.Reader + - b []byte + revive: + rules: + # Prefer 'any' type alias over 'interface{}' for Go 1.18+ compatibility + - name: use-any + severity: warning + disabled: false + +linters: + enable: + - asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers + - bidichk # Checks for dangerous unicode character sequences + - bodyclose # checks whether HTTP response body is closed successfully + - containedctx # containedctx is a linter that detects struct contained context.Context field + - contextcheck # check the function whether use a non-inherited context + - cyclop # checks function and package cyclomatic complexity + - decorder # check declaration order and count of types, constants, variables and functions + - dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f()) + - dupl # Tool for code clone detection + - durationcheck # check for two durations multiplied together + - err113 # Golang linter to check the errors handling expressions + - errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases + - errchkjson # Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occations, where the check for the returned error can be omitted. + - errname # Checks that sentinel errors are prefixed with the `Err` and error types are suffixed with the `Error`. + - errorlint # errorlint is a linter for that can be used to find code that will cause problems with the error wrapping scheme introduced in Go 1.13. + - exhaustive # check exhaustiveness of enum switch statements + - exportloopref # checks for pointers to enclosing loop variables + - forbidigo # Forbids identifiers + - forcetypeassert # finds forced type assertions + - gci # Gci control golang package import order and make it always deterministic. + - gochecknoglobals # Checks that no globals are present in Go code + - gocognit # Computes and checks the cognitive complexity of functions + - goconst # Finds repeated strings that could be replaced by a constant + - gocritic # The most opinionated Go source code linter + - gocyclo # Computes and checks the cyclomatic complexity of functions + - godot # Check if comments end in a period + - godox # Tool for detection of FIXME, TODO and other comment keywords + - gofmt # Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification + - gofumpt # Gofumpt checks whether code was gofumpt-ed. + - goheader # Checks is file header matches to pattern + - goimports # Goimports does everything that gofmt does. Additionally it checks unused imports + - gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod. + - goprintffuncname # Checks that printf-like functions are named with `f` at the end + - gosec # Inspects source code for security problems + - gosimple # Linter for Go source code that specializes in simplifying a code + - govet # Vet examines Go source code and reports suspicious constructs, such as Printf calls whose arguments do not align with the format string + - grouper # An analyzer to analyze expression groups. + - importas # Enforces consistent import aliases + - ineffassign # Detects when assignments to existing variables are not used + - lll # Reports long lines + - maintidx # maintidx measures the maintainability index of each function. + - makezero # Finds slice declarations with non-zero initial length + - misspell # Finds commonly misspelled English words in comments + - nakedret # Finds naked returns in functions greater than a specified function length + - nestif # Reports deeply nested if statements + - nilerr # Finds the code that returns nil even if it checks that the error is not nil. + - nilnil # Checks that there is no simultaneous return of `nil` error and an invalid value. + - nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity + - noctx # noctx finds sending http request without context.Context + - predeclared # find code that shadows one of Go's predeclared identifiers + - revive # golint replacement, finds style mistakes + - staticcheck # Staticcheck is a go vet on steroids, applying a ton of static analysis checks + - stylecheck # Stylecheck is a replacement for golint + - tagliatelle # Checks the struct tags. + - tenv # tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17 + - thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers + - typecheck # Like the front-end of a Go compiler, parses and type-checks Go code + - unconvert # Remove unnecessary type conversions + - unparam # Reports unused function parameters + - unused # Checks Go code for unused constants, variables, functions and types + - varnamelen # checks that the length of a variable's name matches its scope + - wastedassign # wastedassign finds wasted assignment statements + - whitespace # Tool for detection of leading and trailing whitespace + disable: + - depguard # Go linter that checks if package imports are in a list of acceptable packages + - funlen # Tool for detection of long functions + - gochecknoinits # Checks that no init functions are present in Go code + - gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations. + - interfacebloat # A linter that checks length of interface. + - ireturn # Accept Interfaces, Return Concrete Types + - mnd # An analyzer to detect magic numbers + - nolintlint # Reports ill-formed or insufficient nolint directives + - paralleltest # paralleltest detects missing usage of t.Parallel() method in your Go test + - prealloc # Finds slice declarations that could potentially be preallocated + - promlinter # Check Prometheus metrics naming via promlint + - rowserrcheck # checks whether Err of rows is checked successfully + - sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed. + - testpackage # linter that makes you use a separate _test package + - tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes + - wrapcheck # Checks that errors returned from external packages are wrapped + - wsl # Whitespace Linter - Forces you to use empty lines! + +issues: + exclude-use-default: false + exclude-dirs-use-default: false + exclude-rules: + # Allow complex tests and examples, better to be self contained + - path: (examples|main\.go) + linters: + - gocognit + - forbidigo + - path: _test\.go + linters: + - gocognit + + # Allow forbidden identifiers in CLI commands + - path: cmd + linters: + - forbidigo diff --git a/vendor/github.com/pion/logging/.goreleaser.yml b/vendor/github.com/pion/logging/.goreleaser.yml new file mode 100644 index 0000000..30093e9 --- /dev/null +++ b/vendor/github.com/pion/logging/.goreleaser.yml @@ -0,0 +1,5 @@ +# SPDX-FileCopyrightText: 2023 The Pion community +# SPDX-License-Identifier: MIT + +builds: +- skip: true diff --git a/vendor/github.com/pion/logging/LICENSE b/vendor/github.com/pion/logging/LICENSE new file mode 100644 index 0000000..491caf6 --- /dev/null +++ b/vendor/github.com/pion/logging/LICENSE @@ -0,0 +1,9 @@ +MIT License + +Copyright (c) 2023 The Pion community + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/vendor/github.com/pion/logging/README.md b/vendor/github.com/pion/logging/README.md new file mode 100644 index 0000000..20ae889 --- /dev/null +++ b/vendor/github.com/pion/logging/README.md @@ -0,0 +1,34 @@ +

+
+ Pion Logging +
+

+

The Pion logging library

+

+ Pion transport + join us on Discord Follow us on Bluesky +
+ GitHub Workflow Status + Go Reference + Coverage Status + Go Report Card + License: MIT +

+
+ +### Roadmap +The library is used as a part of our WebRTC implementation. Please refer to that [roadmap](https://github.com/pion/webrtc/issues/9) to track our major milestones. + +### Community +Pion has an active community on the [Discord](https://discord.gg/PngbdqpFbt). + +Follow the [Pion Bluesky](https://bsky.app/profile/pion.ly) or [Pion Twitter](https://twitter.com/_pion) for project updates and important WebRTC news. + +We are always looking to support **your projects**. Please reach out if you have something to build! +If you need commercial support or don't want to use public methods you can contact us at [team@pion.ly](mailto:team@pion.ly) + +### Contributing +Check out the [contributing wiki](https://github.com/pion/webrtc/wiki/Contributing) to join the group of amazing people making this project possible + +### License +MIT License - see [LICENSE](LICENSE) for full text diff --git a/vendor/github.com/pion/logging/codecov.yml b/vendor/github.com/pion/logging/codecov.yml new file mode 100644 index 0000000..263e4d4 --- /dev/null +++ b/vendor/github.com/pion/logging/codecov.yml @@ -0,0 +1,22 @@ +# +# DO NOT EDIT THIS FILE +# +# It is automatically copied from https://github.com/pion/.goassets repository. +# +# SPDX-FileCopyrightText: 2023 The Pion community +# SPDX-License-Identifier: MIT + +coverage: + status: + project: + default: + # Allow decreasing 2% of total coverage to avoid noise. + threshold: 2% + patch: + default: + target: 70% + only_pulls: true + +ignore: + - "examples/*" + - "examples/**/*" diff --git a/vendor/github.com/pion/logging/logger.go b/vendor/github.com/pion/logging/logger.go new file mode 100644 index 0000000..b23aaa1 --- /dev/null +++ b/vendor/github.com/pion/logging/logger.go @@ -0,0 +1,244 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +// Package logging provides the logging library used by Pion +package logging + +import ( + "fmt" + "io" + "log" + "os" + "strings" + "sync" +) + +// Use this abstraction to ensure thread-safe access to the logger's io.Writer. +// (which could change at runtime). +type loggerWriter struct { + sync.RWMutex + output io.Writer +} + +func (lw *loggerWriter) SetOutput(output io.Writer) { + lw.Lock() + defer lw.Unlock() + lw.output = output +} + +func (lw *loggerWriter) Write(data []byte) (int, error) { + lw.RLock() + defer lw.RUnlock() + + return lw.output.Write(data) +} + +// DefaultLeveledLogger encapsulates functionality for providing logging at. +// user-defined levels. +type DefaultLeveledLogger struct { + level LogLevel + writer *loggerWriter + trace *log.Logger + debug *log.Logger + info *log.Logger + warn *log.Logger + err *log.Logger +} + +// WithTraceLogger is a chainable configuration function which sets the +// Trace-level logger. +func (ll *DefaultLeveledLogger) WithTraceLogger(log *log.Logger) *DefaultLeveledLogger { + ll.trace = log + + return ll +} + +// WithDebugLogger is a chainable configuration function which sets the +// Debug-level logger. +func (ll *DefaultLeveledLogger) WithDebugLogger(log *log.Logger) *DefaultLeveledLogger { + ll.debug = log + + return ll +} + +// WithInfoLogger is a chainable configuration function which sets the +// Info-level logger. +func (ll *DefaultLeveledLogger) WithInfoLogger(log *log.Logger) *DefaultLeveledLogger { + ll.info = log + + return ll +} + +// WithWarnLogger is a chainable configuration function which sets the +// Warn-level logger. +func (ll *DefaultLeveledLogger) WithWarnLogger(log *log.Logger) *DefaultLeveledLogger { + ll.warn = log + + return ll +} + +// WithErrorLogger is a chainable configuration function which sets the +// Error-level logger. +func (ll *DefaultLeveledLogger) WithErrorLogger(log *log.Logger) *DefaultLeveledLogger { + ll.err = log + + return ll +} + +// WithOutput is a chainable configuration function which sets the logger's +// logging output to the supplied io.Writer. +func (ll *DefaultLeveledLogger) WithOutput(output io.Writer) *DefaultLeveledLogger { + ll.writer.SetOutput(output) + + return ll +} + +func (ll *DefaultLeveledLogger) logf(logger *log.Logger, level LogLevel, format string, args ...any) { + if ll.level.Get() < level { + return + } + + callDepth := 3 // this frame + wrapper func + caller + msg := fmt.Sprintf(format, args...) + if err := logger.Output(callDepth, msg); err != nil { + fmt.Fprintf(os.Stderr, "Unable to log: %s", err) + } +} + +// SetLevel sets the logger's logging level. +func (ll *DefaultLeveledLogger) SetLevel(newLevel LogLevel) { + ll.level.Set(newLevel) +} + +// Trace emits the preformatted message if the logger is at or below LogLevelTrace. +func (ll *DefaultLeveledLogger) Trace(msg string) { + ll.logf(ll.trace, LogLevelTrace, msg) // nolint: govet +} + +// Tracef formats and emits a message if the logger is at or below LogLevelTrace. +func (ll *DefaultLeveledLogger) Tracef(format string, args ...any) { + ll.logf(ll.trace, LogLevelTrace, format, args...) +} + +// Debug emits the preformatted message if the logger is at or below LogLevelDebug. +func (ll *DefaultLeveledLogger) Debug(msg string) { + ll.logf(ll.debug, LogLevelDebug, msg) // nolint: govet +} + +// Debugf formats and emits a message if the logger is at or below LogLevelDebug. +func (ll *DefaultLeveledLogger) Debugf(format string, args ...any) { + ll.logf(ll.debug, LogLevelDebug, format, args...) +} + +// Info emits the preformatted message if the logger is at or below LogLevelInfo. +func (ll *DefaultLeveledLogger) Info(msg string) { + ll.logf(ll.info, LogLevelInfo, msg) // nolint: govet +} + +// Infof formats and emits a message if the logger is at or below LogLevelInfo. +func (ll *DefaultLeveledLogger) Infof(format string, args ...any) { + ll.logf(ll.info, LogLevelInfo, format, args...) +} + +// Warn emits the preformatted message if the logger is at or below LogLevelWarn. +func (ll *DefaultLeveledLogger) Warn(msg string) { + ll.logf(ll.warn, LogLevelWarn, msg) // nolint: govet +} + +// Warnf formats and emits a message if the logger is at or below LogLevelWarn. +func (ll *DefaultLeveledLogger) Warnf(format string, args ...any) { + ll.logf(ll.warn, LogLevelWarn, format, args...) +} + +// Error emits the preformatted message if the logger is at or below LogLevelError. +func (ll *DefaultLeveledLogger) Error(msg string) { + ll.logf(ll.err, LogLevelError, msg) // nolint: govet +} + +// Errorf formats and emits a message if the logger is at or below LogLevelError. +func (ll *DefaultLeveledLogger) Errorf(format string, args ...any) { + ll.logf(ll.err, LogLevelError, format, args...) +} + +// NewDefaultLeveledLoggerForScope returns a configured LeveledLogger. +func NewDefaultLeveledLoggerForScope(scope string, level LogLevel, writer io.Writer) *DefaultLeveledLogger { + if writer == nil { + writer = os.Stderr + } + logger := &DefaultLeveledLogger{ + writer: &loggerWriter{output: writer}, + level: level, + } + + return logger. + WithTraceLogger(log.New(logger.writer, fmt.Sprintf("%s TRACE: ", scope), log.Lmicroseconds|log.Lshortfile)). + WithDebugLogger(log.New(logger.writer, fmt.Sprintf("%s DEBUG: ", scope), log.Lmicroseconds|log.Lshortfile)). + WithInfoLogger(log.New(logger.writer, fmt.Sprintf("%s INFO: ", scope), log.LstdFlags)). + WithWarnLogger(log.New(logger.writer, fmt.Sprintf("%s WARNING: ", scope), log.LstdFlags)). + WithErrorLogger(log.New(logger.writer, fmt.Sprintf("%s ERROR: ", scope), log.LstdFlags)) +} + +// DefaultLoggerFactory define levels by scopes and creates new DefaultLeveledLogger. +type DefaultLoggerFactory struct { + Writer io.Writer + DefaultLogLevel LogLevel + ScopeLevels map[string]LogLevel +} + +// NewDefaultLoggerFactory creates a new DefaultLoggerFactory. +func NewDefaultLoggerFactory() *DefaultLoggerFactory { + factory := DefaultLoggerFactory{} + factory.DefaultLogLevel = LogLevelError + factory.ScopeLevels = make(map[string]LogLevel) + factory.Writer = os.Stderr + + logLevels := map[string]LogLevel{ + "DISABLE": LogLevelDisabled, + "ERROR": LogLevelError, + "WARN": LogLevelWarn, + "INFO": LogLevelInfo, + "DEBUG": LogLevelDebug, + "TRACE": LogLevelTrace, + } + + for name, level := range logLevels { + env := os.Getenv(fmt.Sprintf("PION_LOG_%s", name)) + + if env == "" { + env = os.Getenv(fmt.Sprintf("PIONS_LOG_%s", name)) + } + + if env == "" { + continue + } + + if strings.ToLower(env) == "all" { + if factory.DefaultLogLevel < level { + factory.DefaultLogLevel = level + } + + continue + } + + scopes := strings.Split(strings.ToLower(env), ",") + for _, scope := range scopes { + factory.ScopeLevels[scope] = level + } + } + + return &factory +} + +// NewLogger returns a configured LeveledLogger for the given, argsscope. +func (f *DefaultLoggerFactory) NewLogger(scope string) LeveledLogger { + logLevel := f.DefaultLogLevel + if f.ScopeLevels != nil { + scopeLevel, found := f.ScopeLevels[scope] + + if found { + logLevel = scopeLevel + } + } + + return NewDefaultLeveledLoggerForScope(scope, logLevel, f.Writer) +} diff --git a/vendor/github.com/pion/logging/renovate.json b/vendor/github.com/pion/logging/renovate.json new file mode 100644 index 0000000..f1bb98c --- /dev/null +++ b/vendor/github.com/pion/logging/renovate.json @@ -0,0 +1,6 @@ +{ + "$schema": "https://docs.renovatebot.com/renovate-schema.json", + "extends": [ + "github>pion/renovate-config" + ] +} diff --git a/vendor/github.com/pion/logging/scoped.go b/vendor/github.com/pion/logging/scoped.go new file mode 100644 index 0000000..aac518e --- /dev/null +++ b/vendor/github.com/pion/logging/scoped.go @@ -0,0 +1,75 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package logging + +import ( + "sync/atomic" +) + +// LogLevel represents the level at which the logger will emit log messages. +type LogLevel int32 + +// Set updates the LogLevel to the supplied value. +func (ll *LogLevel) Set(newLevel LogLevel) { + atomic.StoreInt32((*int32)(ll), int32(newLevel)) +} + +// Get retrieves the current LogLevel value. +func (ll *LogLevel) Get() LogLevel { + return LogLevel(atomic.LoadInt32((*int32)(ll))) +} + +func (ll LogLevel) String() string { + switch ll { + case LogLevelDisabled: + return "Disabled" + case LogLevelError: + return "Error" + case LogLevelWarn: + return "Warn" + case LogLevelInfo: + return "Info" + case LogLevelDebug: + return "Debug" + case LogLevelTrace: + return "Trace" + default: + return "UNKNOWN" + } +} + +const ( + // LogLevelDisabled completely disables logging of any events. + LogLevelDisabled LogLevel = iota + // LogLevelError is for fatal errors which should be handled by user code, + // but are logged to ensure that they are seen. + LogLevelError + // LogLevelWarn is for logging abnormal, but non-fatal library operation. + LogLevelWarn + // LogLevelInfo is for logging normal library operation (e.g. state transitions, etc.). + LogLevelInfo + // LogLevelDebug is for logging low-level library information (e.g. internal operations). + LogLevelDebug + // LogLevelTrace is for logging very low-level library information (e.g. network traces). + LogLevelTrace +) + +// LeveledLogger is the basic pion Logger interface. +type LeveledLogger interface { + Trace(msg string) + Tracef(format string, args ...any) + Debug(msg string) + Debugf(format string, args ...any) + Info(msg string) + Infof(format string, args ...any) + Warn(msg string) + Warnf(format string, args ...any) + Error(msg string) + Errorf(format string, args ...any) +} + +// LoggerFactory is the basic pion LoggerFactory interface. +type LoggerFactory interface { + NewLogger(scope string) LeveledLogger +} diff --git a/vendor/github.com/pion/mdns/v2/.gitignore b/vendor/github.com/pion/mdns/v2/.gitignore new file mode 100644 index 0000000..6e2f206 --- /dev/null +++ b/vendor/github.com/pion/mdns/v2/.gitignore @@ -0,0 +1,28 @@ +# SPDX-FileCopyrightText: 2023 The Pion community +# SPDX-License-Identifier: MIT + +### JetBrains IDE ### +##################### +.idea/ + +### Emacs Temporary Files ### +############################# +*~ + +### Folders ### +############### +bin/ +vendor/ +node_modules/ + +### Files ### +############# +*.ivf +*.ogg +tags +cover.out +*.sw[poe] +*.wasm +examples/sfu-ws/cert.pem +examples/sfu-ws/key.pem +wasm_exec.js diff --git a/vendor/github.com/pion/mdns/v2/.golangci.yml b/vendor/github.com/pion/mdns/v2/.golangci.yml new file mode 100644 index 0000000..4b4025f --- /dev/null +++ b/vendor/github.com/pion/mdns/v2/.golangci.yml @@ -0,0 +1,147 @@ +# SPDX-FileCopyrightText: 2023 The Pion community +# SPDX-License-Identifier: MIT + +version: "2" +linters: + enable: + - asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers + - bidichk # Checks for dangerous unicode character sequences + - bodyclose # checks whether HTTP response body is closed successfully + - containedctx # containedctx is a linter that detects struct contained context.Context field + - contextcheck # check the function whether use a non-inherited context + - cyclop # checks function and package cyclomatic complexity + - decorder # check declaration order and count of types, constants, variables and functions + - dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f()) + - dupl # Tool for code clone detection + - durationcheck # check for two durations multiplied together + - err113 # Golang linter to check the errors handling expressions + - errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases + - errchkjson # Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occations, where the check for the returned error can be omitted. + - errname # Checks that sentinel errors are prefixed with the `Err` and error types are suffixed with the `Error`. + - errorlint # errorlint is a linter for that can be used to find code that will cause problems with the error wrapping scheme introduced in Go 1.13. + - exhaustive # check exhaustiveness of enum switch statements + - forbidigo # Forbids identifiers + - forcetypeassert # finds forced type assertions + - gochecknoglobals # Checks that no globals are present in Go code + - gocognit # Computes and checks the cognitive complexity of functions + - goconst # Finds repeated strings that could be replaced by a constant + - gocritic # The most opinionated Go source code linter + - gocyclo # Computes and checks the cyclomatic complexity of functions + - godot # Check if comments end in a period + - godox # Tool for detection of FIXME, TODO and other comment keywords + - goheader # Checks is file header matches to pattern + - gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod. + - goprintffuncname # Checks that printf-like functions are named with `f` at the end + - gosec # Inspects source code for security problems + - govet # Vet examines Go source code and reports suspicious constructs, such as Printf calls whose arguments do not align with the format string + - grouper # An analyzer to analyze expression groups. + - importas # Enforces consistent import aliases + - ineffassign # Detects when assignments to existing variables are not used + - lll # Reports long lines + - maintidx # maintidx measures the maintainability index of each function. + - makezero # Finds slice declarations with non-zero initial length + - misspell # Finds commonly misspelled English words in comments + - nakedret # Finds naked returns in functions greater than a specified function length + - nestif # Reports deeply nested if statements + - nilerr # Finds the code that returns nil even if it checks that the error is not nil. + - nilnil # Checks that there is no simultaneous return of `nil` error and an invalid value. + - nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity + - noctx # noctx finds sending http request without context.Context + - predeclared # find code that shadows one of Go's predeclared identifiers + - revive # golint replacement, finds style mistakes + - staticcheck # Staticcheck is a go vet on steroids, applying a ton of static analysis checks + - tagliatelle # Checks the struct tags. + - thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers + - unconvert # Remove unnecessary type conversions + - unparam # Reports unused function parameters + - unused # Checks Go code for unused constants, variables, functions and types + - varnamelen # checks that the length of a variable's name matches its scope + - wastedassign # wastedassign finds wasted assignment statements + - whitespace # Tool for detection of leading and trailing whitespace + disable: + - depguard # Go linter that checks if package imports are in a list of acceptable packages + - funlen # Tool for detection of long functions + - gochecknoinits # Checks that no init functions are present in Go code + - gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations. + - interfacebloat # A linter that checks length of interface. + - ireturn # Accept Interfaces, Return Concrete Types + - mnd # An analyzer to detect magic numbers + - nolintlint # Reports ill-formed or insufficient nolint directives + - paralleltest # paralleltest detects missing usage of t.Parallel() method in your Go test + - prealloc # Finds slice declarations that could potentially be preallocated + - promlinter # Check Prometheus metrics naming via promlint + - rowserrcheck # checks whether Err of rows is checked successfully + - sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed. + - testpackage # linter that makes you use a separate _test package + - tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes + - wrapcheck # Checks that errors returned from external packages are wrapped + - wsl # Whitespace Linter - Forces you to use empty lines! + settings: + staticcheck: + checks: + - all + - -QF1008 # "could remove embedded field", to keep it explicit! + - -QF1003 # "could use tagged switch on enum", Cases conflicts with exhaustive! + exhaustive: + default-signifies-exhaustive: true + forbidigo: + forbid: + - pattern: ^fmt.Print(f|ln)?$ + - pattern: ^log.(Panic|Fatal|Print)(f|ln)?$ + - pattern: ^os.Exit$ + - pattern: ^panic$ + - pattern: ^print(ln)?$ + - pattern: ^testing.T.(Error|Errorf|Fatal|Fatalf|Fail|FailNow)$ + pkg: ^testing$ + msg: use testify/assert instead + analyze-types: true + gomodguard: + blocked: + modules: + - github.com/pkg/errors: + recommendations: + - errors + govet: + enable: + - shadow + revive: + rules: + # Prefer 'any' type alias over 'interface{}' for Go 1.18+ compatibility + - name: use-any + severity: warning + disabled: false + misspell: + locale: US + varnamelen: + max-distance: 12 + min-name-length: 2 + ignore-type-assert-ok: true + ignore-map-index-ok: true + ignore-chan-recv-ok: true + ignore-decls: + - i int + - n int + - w io.Writer + - r io.Reader + - b []byte + exclusions: + generated: lax + rules: + - linters: + - forbidigo + - gocognit + path: (examples|main\.go) + - linters: + - gocognit + path: _test\.go + - linters: + - forbidigo + path: cmd +formatters: + enable: + - gci # Gci control golang package import order and make it always deterministic. + - gofmt # Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification + - gofumpt # Gofumpt checks whether code was gofumpt-ed. + - goimports # Goimports does everything that gofmt does. Additionally it checks unused imports + exclusions: + generated: lax diff --git a/vendor/github.com/pion/mdns/v2/.goreleaser.yml b/vendor/github.com/pion/mdns/v2/.goreleaser.yml new file mode 100644 index 0000000..30093e9 --- /dev/null +++ b/vendor/github.com/pion/mdns/v2/.goreleaser.yml @@ -0,0 +1,5 @@ +# SPDX-FileCopyrightText: 2023 The Pion community +# SPDX-License-Identifier: MIT + +builds: +- skip: true diff --git a/vendor/github.com/pion/mdns/v2/LICENSE b/vendor/github.com/pion/mdns/v2/LICENSE new file mode 100644 index 0000000..ab60297 --- /dev/null +++ b/vendor/github.com/pion/mdns/v2/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2018 + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/vendor/github.com/pion/mdns/v2/README.md b/vendor/github.com/pion/mdns/v2/README.md new file mode 100644 index 0000000..8b94e04 --- /dev/null +++ b/vendor/github.com/pion/mdns/v2/README.md @@ -0,0 +1,72 @@ +

+
+ Pion mDNS +
+

+

A Go implementation of mDNS

+

+ Pion mDNS + join us on Discord Follow us on Bluesky +
+ GitHub Workflow Status + Go Reference + Coverage Status + Go Report Card + License: MIT +

+
+ +Go mDNS implementation. The original user is Pion WebRTC, but we would love to see it work for everyone. + +### Running Server +For a mDNS server that responds to queries for `pion-test.local` +```sh +go run examples/server/main.go +``` + +For a mDNS server that responds to queries for `pion-test.local` with a given address +```sh +go run examples/server/publish_ip/main.go -ip=[IP] +``` +If you don't set the `ip` parameter, "1.2.3.4" will be used instead. + + +### Running Client +To query using Pion you can run the `query` example +```sh +go run examples/query/main.go +``` + +You can use the macOS client +``` +dns-sd -q pion-test.local +``` + +Or the avahi client +``` +avahi-resolve -a pion-test.local +``` + +### RFCs +#### Implemented +- **RFC 6762** [Multicast DNS][rfc6762] +- **draft-ietf-rtcweb-mdns-ice-candidates-02** [Using Multicast DNS to protect privacy when exposing ICE candidates](https://datatracker.ietf.org/doc/html/draft-ietf-rtcweb-mdns-ice-candidates-02.html) + +[rfc6762]: https://tools.ietf.org/html/rfc6762 + +### Roadmap +The library is used as a part of our WebRTC implementation. Please refer to that [roadmap](https://github.com/pion/webrtc/issues/9) to track our major milestones. + +### Community +Pion has an active community on the [Discord](https://discord.gg/PngbdqpFbt). + +Follow the [Pion Bluesky](https://bsky.app/profile/pion.ly) or [Pion Twitter](https://twitter.com/_pion) for project updates and important WebRTC news. + +We are always looking to support **your projects**. Please reach out if you have something to build! +If you need commercial support or don't want to use public methods you can contact us at [team@pion.ly](mailto:team@pion.ly) + +### Contributing +Check out the [contributing wiki](https://github.com/pion/webrtc/wiki/Contributing) to join the group of amazing people making this project possible + +### License +MIT License - see [LICENSE](LICENSE) for full text \ No newline at end of file diff --git a/vendor/github.com/pion/mdns/v2/codecov.yml b/vendor/github.com/pion/mdns/v2/codecov.yml new file mode 100644 index 0000000..263e4d4 --- /dev/null +++ b/vendor/github.com/pion/mdns/v2/codecov.yml @@ -0,0 +1,22 @@ +# +# DO NOT EDIT THIS FILE +# +# It is automatically copied from https://github.com/pion/.goassets repository. +# +# SPDX-FileCopyrightText: 2023 The Pion community +# SPDX-License-Identifier: MIT + +coverage: + status: + project: + default: + # Allow decreasing 2% of total coverage to avoid noise. + threshold: 2% + patch: + default: + target: 70% + only_pulls: true + +ignore: + - "examples/*" + - "examples/**/*" diff --git a/vendor/github.com/pion/mdns/v2/config.go b/vendor/github.com/pion/mdns/v2/config.go new file mode 100644 index 0000000..1356fab --- /dev/null +++ b/vendor/github.com/pion/mdns/v2/config.go @@ -0,0 +1,49 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package mdns + +import ( + "net" + "time" + + "github.com/pion/logging" +) + +const ( + // DefaultAddressIPv4 is the default used by mDNS + // and in most cases should be the address that the + // ipv4.PacketConn passed to Server is bound to. + DefaultAddressIPv4 = "224.0.0.0:5353" + + // DefaultAddressIPv6 is the default IPv6 address used + // by mDNS and in most cases should be the address that + // the ipv6.PacketConn passed to Server is bound to. + DefaultAddressIPv6 = "[FF02::]:5353" +) + +// Config is used to configure a mDNS client or server. +type Config struct { + // Name is the name of the client/server used for logging purposes. + Name string + + // QueryInterval controls how often we sends Queries until we + // get a response for the requested name + QueryInterval time.Duration + + // LocalNames are the names that we will generate answers for + // when we get questions + LocalNames []string + + // LocalAddress will override the published address with the given IP + // when set. Otherwise, the automatically determined address will be used. + LocalAddress net.IP + + LoggerFactory logging.LoggerFactory + + // IncludeLoopback will include loopback interfaces to be eligble for queries and answers. + IncludeLoopback bool + + // Interfaces will override the interfaces used for queries and answers. + Interfaces []net.Interface +} diff --git a/vendor/github.com/pion/mdns/v2/conn.go b/vendor/github.com/pion/mdns/v2/conn.go new file mode 100644 index 0000000..8662602 --- /dev/null +++ b/vendor/github.com/pion/mdns/v2/conn.go @@ -0,0 +1,1302 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package mdns + +import ( + "context" + "errors" + "fmt" + "net" + "net/netip" + "strings" + "sync" + "time" + + "github.com/pion/logging" + "golang.org/x/net/dns/dnsmessage" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" +) + +// Conn represents a mDNS Server. +type Conn struct { + mu sync.RWMutex + name string + log logging.LeveledLogger + + multicastPktConnV4 ipPacketConn + multicastPktConnV6 ipPacketConn + dstAddr4 *net.UDPAddr + dstAddr6 *net.UDPAddr + + unicastPktConnV4 ipPacketConn + unicastPktConnV6 ipPacketConn + + queryInterval time.Duration + localNames []string + queries []*query + ifaces map[int]netInterface + + closed chan any +} + +type query struct { + nameWithSuffix string + queryResultChan chan queryResult +} + +type queryResult struct { + answer dnsmessage.ResourceHeader + addr netip.Addr +} + +const ( + defaultQueryInterval = time.Second + destinationAddress4 = "224.0.0.251:5353" + destinationAddress6 = "[FF02::FB]:5353" + maxMessageRecords = 3 + responseTTL = 120 + // maxPacketSize is the maximum size of a mdns packet. + // From RFC 6762: + // Even when fragmentation is used, a Multicast DNS packet, including IP + // and UDP headers, MUST NOT exceed 9000 bytes. + // https://datatracker.ietf.org/doc/html/rfc6762#section-17 + maxPacketSize = 9000 +) + +var ( + errNoPositiveMTUFound = errors.New("no positive MTU found") + errNoPacketConn = errors.New("must supply at least a multicast IPv4 or IPv6 PacketConn") + errNoUsableInterfaces = errors.New("no usable interfaces found for mDNS") + errFailedToClose = errors.New("failed to close mDNS Conn") + errFailedToDecodeAddrFromAResource = errors.New("failed to decode netip.Addr from A type Resource") + errFailedToDecodeAddrFromAAAAResource = errors.New("failed to decode netip.Addr from AAAA type Resource") + errUnhandledAnswerHeaderType = errors.New("header for Answer had unhandled type") +) + +type netInterface struct { + net.Interface + ipAddrs []netip.Addr + supportsV4 bool + supportsV6 bool +} + +// Server establishes a mDNS connection over an existing conn. +// Either one or both of the multicast packet conns should be provided. +// The presence of each IP type of PacketConn will dictate what kinds +// of questions are sent for queries. That is, if an ipv6.PacketConn is +// provided, then AAAA questions will be sent. A questions will only be +// sent if an ipv4.PacketConn is also provided. In the future, we may +// add a QueryAddr method that allows specifying this more clearly. +// +//nolint:gocognit,gocyclo,cyclop,maintidx +func Server( + multicastPktConnV4 *ipv4.PacketConn, + multicastPktConnV6 *ipv6.PacketConn, + config *Config, +) (*Conn, error) { + if config == nil { + return nil, errNilConfig + } + loggerFactory := config.LoggerFactory + if loggerFactory == nil { + loggerFactory = logging.NewDefaultLoggerFactory() + } + log := loggerFactory.NewLogger("mdns") + + conn := &Conn{ + queryInterval: defaultQueryInterval, + log: log, + closed: make(chan any), + } + conn.name = config.Name + if conn.name == "" { + conn.name = fmt.Sprintf("%p", &conn) + } + + if multicastPktConnV4 == nil && multicastPktConnV6 == nil { + return nil, errNoPacketConn + } + + ifaces := config.Interfaces + if ifaces == nil { + var err error + ifaces, err = net.Interfaces() + if err != nil { + return nil, err + } + } + + var unicastPktConnV4 *ipv4.PacketConn + { + addr4, err := net.ResolveUDPAddr("udp4", "0.0.0.0:0") + if err != nil { + return nil, err + } + + unicastConnV4, err := net.ListenUDP("udp4", addr4) + if err != nil { + log.Warnf( + "[%s] failed to listen on unicast IPv4 %s: %s; will not be able to receive unicast responses on IPv4", + conn.name, addr4, err, + ) + } else { + unicastPktConnV4 = ipv4.NewPacketConn(unicastConnV4) + } + } + + var unicastPktConnV6 *ipv6.PacketConn + { + addr6, err := net.ResolveUDPAddr("udp6", "[::]:") + if err != nil { + return nil, err + } + + unicastConnV6, err := net.ListenUDP("udp6", addr6) + if err != nil { + log.Warnf( + "[%s] failed to listen on unicast IPv6 %s: %s; will not be able to receive unicast responses on IPv6", + conn.name, addr6, err, + ) + } else { + unicastPktConnV6 = ipv6.NewPacketConn(unicastConnV6) + } + } + + multicastGroup4 := net.IPv4(224, 0, 0, 251) + multicastGroupAddr4 := &net.UDPAddr{IP: multicastGroup4} + + // FF02::FB + multicastGroup6 := net.IP{0xff, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xfb} + multicastGroupAddr6 := &net.UDPAddr{IP: multicastGroup6} + + inboundBufferSize := 0 + joinErrCount := 0 + ifacesToUse := make(map[int]netInterface, len(ifaces)) + for i := range ifaces { + ifc := ifaces[i] + if !config.IncludeLoopback && ifc.Flags&net.FlagLoopback == net.FlagLoopback { + continue + } + if ifc.Flags&net.FlagUp == 0 { + continue + } + + addrs, err := ifc.Addrs() + if err != nil { + continue + } + var supportsV4, supportsV6 bool + ifcIPAddrs := make([]netip.Addr, 0, len(addrs)) + for _, addr := range addrs { + var ipToConv net.IP + switch addr := addr.(type) { + case *net.IPNet: + ipToConv = addr.IP + case *net.IPAddr: + ipToConv = addr.IP + default: + continue + } + + ipAddr, ok := netip.AddrFromSlice(ipToConv) + if !ok { + continue + } + if multicastPktConnV4 != nil { + // don't want mapping since we also support IPv4/A + ipAddr = ipAddr.Unmap() + } + ipAddr = addrWithOptionalZone(ipAddr, ifc.Name) + + if ipAddr.Is6() && !ipAddr.Is4In6() { + supportsV6 = true + } else { + // we'll claim we support v4 but defer if we send it or not + // based on IPv4-to-IPv6 mapping rules later (search for Is4In6 below) + supportsV4 = true + } + ifcIPAddrs = append(ifcIPAddrs, ipAddr) + } + if !supportsV4 && !supportsV6 { + continue + } + + var atLeastOneJoin bool + if supportsV4 && multicastPktConnV4 != nil { + if err := multicastPktConnV4.JoinGroup(&ifc, multicastGroupAddr4); err == nil { + atLeastOneJoin = true + } + } + if supportsV6 && multicastPktConnV6 != nil { + if err := multicastPktConnV6.JoinGroup(&ifc, multicastGroupAddr6); err == nil { + atLeastOneJoin = true + } + } + if !atLeastOneJoin { + joinErrCount++ + + continue + } + + ifacesToUse[ifc.Index] = netInterface{ + Interface: ifc, + ipAddrs: ifcIPAddrs, + supportsV4: supportsV4, + supportsV6: supportsV6, + } + if ifc.MTU > inboundBufferSize { + inboundBufferSize = ifc.MTU + } + } + + if len(ifacesToUse) == 0 { + return nil, errNoUsableInterfaces + } + if inboundBufferSize == 0 { + return nil, errNoPositiveMTUFound + } + if inboundBufferSize > maxPacketSize { + inboundBufferSize = maxPacketSize + } + if joinErrCount >= len(ifaces) { + return nil, errJoiningMulticastGroup + } + + dstAddr4, err := net.ResolveUDPAddr("udp4", destinationAddress4) + if err != nil { + return nil, err + } + + dstAddr6, err := net.ResolveUDPAddr("udp6", destinationAddress6) + if err != nil { + return nil, err + } + + var localNames []string + for _, l := range config.LocalNames { + localNames = append(localNames, l+".") + } + + conn.dstAddr4 = dstAddr4 + conn.dstAddr6 = dstAddr6 + conn.localNames = localNames + conn.ifaces = ifacesToUse + + if config.QueryInterval != 0 { + conn.queryInterval = config.QueryInterval + } + + if multicastPktConnV4 != nil { + if err := multicastPktConnV4.SetControlMessage(ipv4.FlagInterface, true); err != nil { + conn.log.Warnf( + "[%s] failed to SetControlMessage(ipv4.FlagInterface) on multicast IPv4 PacketConn %v", + conn.name, err, + ) + } + if err := multicastPktConnV4.SetControlMessage(ipv4.FlagDst, true); err != nil { + conn.log.Warnf("[%s] failed to SetControlMessage(ipv4.FlagDst) on multicast IPv4 PacketConn %v", conn.name, err) + } + conn.multicastPktConnV4 = ipPacketConn4{conn.name, multicastPktConnV4, log} + } + if multicastPktConnV6 != nil { + if err := multicastPktConnV6.SetControlMessage(ipv6.FlagInterface, true); err != nil { + conn.log.Warnf( + "[%s] failed to SetControlMessage(ipv6.FlagInterface) on multicast IPv6 PacketConn %v", + conn.name, err, + ) + } + if err := multicastPktConnV6.SetControlMessage(ipv6.FlagDst, true); err != nil { + conn.log.Warnf( + "[%s] failed to SetControlMessage(ipv6.FlagInterface) on multicast IPv6 PacketConn %v", + conn.name, err, + ) + } + conn.multicastPktConnV6 = ipPacketConn6{conn.name, multicastPktConnV6, log} + } + if unicastPktConnV4 != nil { + if err := unicastPktConnV4.SetControlMessage(ipv4.FlagInterface, true); err != nil { + conn.log.Warnf("[%s] failed to SetControlMessage(ipv4.FlagInterface) on unicast IPv4 PacketConn %v", conn.name, err) + } + if err := unicastPktConnV4.SetControlMessage(ipv4.FlagDst, true); err != nil { + conn.log.Warnf("[%s] failed to SetControlMessage(ipv4.FlagInterface) on unicast IPv4 PacketConn %v", conn.name, err) + } + conn.unicastPktConnV4 = ipPacketConn4{conn.name, unicastPktConnV4, log} + } + if unicastPktConnV6 != nil { + if err := unicastPktConnV6.SetControlMessage(ipv6.FlagInterface, true); err != nil { + conn.log.Warnf("[%s] failed to SetControlMessage(ipv6.FlagInterface) on unicast IPv6 PacketConn %v", conn.name, err) + } + if err := unicastPktConnV6.SetControlMessage(ipv6.FlagDst, true); err != nil { + conn.log.Warnf("[%s] failed to SetControlMessage(ipv6.FlagInterface) on unicast IPv6 PacketConn %v", conn.name, err) + } + conn.unicastPktConnV6 = ipPacketConn6{conn.name, unicastPktConnV6, log} + } + + if config.IncludeLoopback { //nolint:nestif + // this is an efficient way for us to send ourselves a message faster instead of it going + // further out into the network stack. + if multicastPktConnV4 != nil { + if err := multicastPktConnV4.SetMulticastLoopback(true); err != nil { + conn.log.Warnf( + //nolint:lll + "[%s] failed to SetMulticastLoopback(true) on multicast IPv4 PacketConn %v; this may cause inefficient network path c.name,communications", + conn.name, err, + ) + } + } + if multicastPktConnV6 != nil { + if err := multicastPktConnV6.SetMulticastLoopback(true); err != nil { + conn.log.Warnf( + //nolint:lll + "[%s] failed to SetMulticastLoopback(true) on multicast IPv6 PacketConn %v; this may cause inefficient network path c.name,communications", + conn.name, err, + ) + } + } + if unicastPktConnV4 != nil { + if err := unicastPktConnV4.SetMulticastLoopback(true); err != nil { + conn.log.Warnf( + //nolint:lll + "[%s] failed to SetMulticastLoopback(true) on unicast IPv4 PacketConn %v; this may cause inefficient network path c.name,communications", + conn.name, err, + ) + } + } + if unicastPktConnV6 != nil { + if err := unicastPktConnV6.SetMulticastLoopback(true); err != nil { + conn.log.Warnf( + //nolint:lll + "[%s] failed to SetMulticastLoopback(true) on unicast IPv6 PacketConn %v; this may cause inefficient network path c.name,communications", + conn.name, err, + ) + } + } + } + + // https://www.rfc-editor.org/rfc/rfc6762.html#section-17 + // Multicast DNS messages carried by UDP may be up to the IP MTU of the + // physical interface, less the space required for the IP header (20 + // bytes for IPv4; 40 bytes for IPv6) and the UDP header (8 bytes). + started := make(chan struct{}) + go conn.start(started, inboundBufferSize-20-8, config) + <-started + + return conn, nil +} + +// Close closes the mDNS Conn. +func (c *Conn) Close() error { //nolint:cyclop + select { + case <-c.closed: + return nil + default: + } + + // Once on go1.20, can use errors.Join + var errs []error + if c.multicastPktConnV4 != nil { + if err := c.multicastPktConnV4.Close(); err != nil { + errs = append(errs, err) + } + } + + if c.multicastPktConnV6 != nil { + if err := c.multicastPktConnV6.Close(); err != nil { + errs = append(errs, err) + } + } + + if c.unicastPktConnV4 != nil { + if err := c.unicastPktConnV4.Close(); err != nil { + errs = append(errs, err) + } + } + + if c.unicastPktConnV6 != nil { + if err := c.unicastPktConnV6.Close(); err != nil { + errs = append(errs, err) + } + } + + if len(errs) == 0 { + <-c.closed + + return nil + } + + rtrn := errFailedToClose + for _, err := range errs { + rtrn = fmt.Errorf("%w\n%w", err, rtrn) + } + + return rtrn +} + +// Query sends mDNS Queries for the following name until +// either the Context is canceled/expires or we get a result +// +// Deprecated: Use QueryAddr instead as it supports the easier to use netip.Addr. +func (c *Conn) Query(ctx context.Context, name string) (dnsmessage.ResourceHeader, net.Addr, error) { + header, addr, err := c.QueryAddr(ctx, name) + if err != nil { + return header, nil, err + } + + return header, &net.IPAddr{ + IP: addr.AsSlice(), + Zone: addr.Zone(), + }, nil +} + +// QueryAddr sends mDNS Queries for the following name until +// either the Context is canceled/expires or we get a result. +func (c *Conn) QueryAddr(ctx context.Context, name string) (dnsmessage.ResourceHeader, netip.Addr, error) { + select { + case <-c.closed: + return dnsmessage.ResourceHeader{}, netip.Addr{}, errConnectionClosed + default: + } + + nameWithSuffix := name + "." + + queryChan := make(chan queryResult, 1) + query := &query{nameWithSuffix, queryChan} + c.mu.Lock() + c.queries = append(c.queries, query) + c.mu.Unlock() + + defer func() { + c.mu.Lock() + defer c.mu.Unlock() + for i := len(c.queries) - 1; i >= 0; i-- { + if c.queries[i] == query { + c.queries = append(c.queries[:i], c.queries[i+1:]...) + } + } + }() + + ticker := time.NewTicker(c.queryInterval) + defer ticker.Stop() + + c.sendQuestion(nameWithSuffix) + for { + select { + case <-ticker.C: + c.sendQuestion(nameWithSuffix) + case <-c.closed: + return dnsmessage.ResourceHeader{}, netip.Addr{}, errConnectionClosed + case res := <-queryChan: + // Given https://datatracker.ietf.org/doc/html/draft-ietf-mmusic-mdns-ice-candidates#section-3.2.2-2 + // An ICE agent SHOULD ignore candidates where the hostname resolution returns more than one IP address. + // + // We will take the first we receive which could result in a race between two suitable addresses where + // one is better than the other (e.g. localhost vs LAN). + return res.answer, res.addr, nil + case <-ctx.Done(): + return dnsmessage.ResourceHeader{}, netip.Addr{}, errContextElapsed + } + } +} + +type ipToBytesError struct { + addr netip.Addr + expectedType string +} + +func (err ipToBytesError) Error() string { + return fmt.Sprintf("ip (%s) is not %s", err.addr, err.expectedType) +} + +// assumes ipv4-to-ipv6 mapping has been checked. +func ipv4ToBytes(ipAddr netip.Addr) ([4]byte, error) { + if !ipAddr.Is4() { + return [4]byte{}, ipToBytesError{ipAddr, "IPv4"} + } + + md, err := ipAddr.MarshalBinary() + if err != nil { + return [4]byte{}, err + } + + // net.IPs are stored in big endian / network byte order + var out [4]byte + copy(out[:], md) + + return out, nil +} + +// assumes ipv4-to-ipv6 mapping has been checked. +func ipv6ToBytes(ipAddr netip.Addr) ([16]byte, error) { + if !ipAddr.Is6() { + return [16]byte{}, ipToBytesError{ipAddr, "IPv6"} + } + md, err := ipAddr.MarshalBinary() + if err != nil { + return [16]byte{}, err + } + + // net.IPs are stored in big endian / network byte order + var out [16]byte + copy(out[:], md) + + return out, nil +} + +type ipToAddrError struct { + ip []byte +} + +func (err ipToAddrError) Error() string { + return fmt.Sprintf("failed to convert ip address '%s' to netip.Addr", err.ip) +} + +func interfaceForRemote(remote string) (*netip.Addr, error) { + conn, err := net.Dial("udp", remote) //nolint: noctx + if err != nil { + return nil, err + } + + localAddr, ok := conn.LocalAddr().(*net.UDPAddr) + if !ok { + return nil, errFailedCast + } + + if err := conn.Close(); err != nil { + return nil, err + } + + ipAddr, ok := netip.AddrFromSlice(localAddr.IP) + if !ok { + return nil, ipToAddrError{localAddr.IP} + } + ipAddr = addrWithOptionalZone(ipAddr, localAddr.Zone) + + return &ipAddr, nil +} + +type writeType byte + +const ( + writeTypeQuestion writeType = iota + writeTypeAnswer +) + +func (c *Conn) sendQuestion(name string) { + packedName, err := dnsmessage.NewName(name) + if err != nil { + c.log.Warnf("[%s] failed to construct mDNS packet %v", c.name, err) + + return + } + + // https://datatracker.ietf.org/doc/html/draft-ietf-rtcweb-mdns-ice-candidates-04#section-3.2.1 + // + // 2. Otherwise, resolve the candidate using mDNS. The ICE agent + // SHOULD set the unicast-response bit of the corresponding mDNS + // query message; this minimizes multicast traffic, as the response + // is probably only useful to the querying node. + // + // 18.12. Repurposing of Top Bit of qclass in Question Section + // + // In the Question Section of a Multicast DNS query, the top bit of the + // qclass field is used to indicate that unicast responses are preferred + // for this particular question. (See Section 5.4.) + // + // We'll follow this up sending on our unicast based packet connections so that we can + // get a unicast response back. + msg := dnsmessage.Message{ + Header: dnsmessage.Header{}, + } + + // limit what we ask for based on what IPv is available. In the future, + // this could be an option since there's no reason you cannot get an + // A record on an IPv6 sourced question and vice versa. + if c.multicastPktConnV4 != nil { + msg.Questions = append(msg.Questions, dnsmessage.Question{ + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassINET | (1 << 15), + Name: packedName, + }) + } + if c.multicastPktConnV6 != nil { + msg.Questions = append(msg.Questions, dnsmessage.Question{ + Type: dnsmessage.TypeAAAA, + Class: dnsmessage.ClassINET | (1 << 15), + Name: packedName, + }) + } + + rawQuery, err := msg.Pack() + if err != nil { + c.log.Warnf("[%s] failed to construct mDNS packet %v", c.name, err) + + return + } + + c.writeToSocket(-1, rawQuery, false, false, writeTypeQuestion, nil) +} + +//nolint:gocognit,gocyclo,cyclop +func (c *Conn) writeToSocket( + ifIndex int, + b []byte, + hasLoopbackData bool, + hasIPv6Zone bool, + wType writeType, + unicastDst *net.UDPAddr, +) { + var dst4, dst6 net.Addr + if wType == writeTypeAnswer { //nolint:nestif + if unicastDst == nil { + dst4 = c.dstAddr4 + dst6 = c.dstAddr6 + } else { + if unicastDst.IP.To4() == nil { + dst6 = unicastDst + } else { + dst4 = unicastDst + } + } + } + + if ifIndex != -1 { //nolint:nestif + if wType == writeTypeQuestion { + c.log.Errorf("[%s] Unexpected question using specific interface index %d; dropping question", c.name, ifIndex) + + return + } + + ifc, ok := c.ifaces[ifIndex] + if !ok { + c.log.Warnf("[%s] no interface for %d", c.name, ifIndex) + + return + } + if hasLoopbackData && ifc.Flags&net.FlagLoopback == 0 { + // avoid accidentally tricking the destination that itself is the same as us + c.log.Debugf("[%s] interface is not loopback %d", c.name, ifIndex) + + return + } + + c.log.Debugf("[%s] writing answer to IPv4: %v, IPv6: %v", c.name, dst4, dst6) + + if ifc.supportsV4 && c.multicastPktConnV4 != nil && dst4 != nil { + if !hasIPv6Zone { + if _, err := c.multicastPktConnV4.WriteTo(b, &ifc.Interface, nil, dst4); err != nil { + c.log.Warnf("[%s] failed to send mDNS packet on IPv4 interface %d: %v", c.name, ifIndex, err) + } + } else { + c.log.Debugf("[%s] refusing to send mDNS packet with IPv6 zone over IPv4", c.name) + } + } + if ifc.supportsV6 && c.multicastPktConnV6 != nil && dst6 != nil { + if _, err := c.multicastPktConnV6.WriteTo(b, &ifc.Interface, nil, dst6); err != nil { + c.log.Warnf("[%s] failed to send mDNS packet on IPv6 interface %d: %v", c.name, ifIndex, err) + } + } + + return + } + for ifcIdx := range c.ifaces { + ifc := c.ifaces[ifcIdx] + if hasLoopbackData { + c.log.Debugf("[%s] Refusing to send loopback data with non-specific interface", c.name) + + continue + } + + if wType == writeTypeQuestion { //nolint:nestif + // we'll write via unicast if we can in case the responder chooses to respond to the address the request + // came from (i.e. not respecting unicast-response bit). If we were to use the multicast packet + // conn here, we'd be writing from a specific multicast address which won't be able to receive unicast + // traffic (it only works when listening on 0.0.0.0/[::]). + if c.unicastPktConnV4 == nil && c.unicastPktConnV6 == nil { + c.log.Debugf("[%s] writing question to multicast IPv4/6 %s", c.name, c.dstAddr4) + if ifc.supportsV4 && c.multicastPktConnV4 != nil { + if _, err := c.multicastPktConnV4.WriteTo(b, &ifc.Interface, nil, c.dstAddr4); err != nil { + c.log.Warnf("[%s] failed to send mDNS packet (multicast) on IPv4 interface %d: %v", c.name, ifc.Index, err) + } + } + if ifc.supportsV6 && c.multicastPktConnV6 != nil { + if _, err := c.multicastPktConnV6.WriteTo(b, &ifc.Interface, nil, c.dstAddr6); err != nil { + c.log.Warnf("[%s] failed to send mDNS packet (multicast) on IPv6 interface %d: %v", c.name, ifc.Index, err) + } + } + } + if ifc.supportsV4 && c.unicastPktConnV4 != nil { + c.log.Debugf("[%s] writing question to unicast IPv4 %s", c.name, c.dstAddr4) + if _, err := c.unicastPktConnV4.WriteTo(b, &ifc.Interface, nil, c.dstAddr4); err != nil { + c.log.Warnf("[%s] failed to send mDNS packet (unicast) on interface %d: %v", c.name, ifc.Index, err) + } + } + if ifc.supportsV6 && c.unicastPktConnV6 != nil { + c.log.Debugf("[%s] writing question to unicast IPv6 %s", c.name, c.dstAddr6) + if _, err := c.unicastPktConnV6.WriteTo(b, &ifc.Interface, nil, c.dstAddr6); err != nil { + c.log.Warnf("[%s] failed to send mDNS packet (unicast) on interface %d: %v", c.name, ifc.Index, err) + } + } + } else { + c.log.Debugf("[%s] writing answer to IPv4: %v, IPv6: %v", c.name, dst4, dst6) + + if ifc.supportsV4 && c.multicastPktConnV4 != nil && dst4 != nil { + if !hasIPv6Zone { + if _, err := c.multicastPktConnV4.WriteTo(b, &ifc.Interface, nil, dst4); err != nil { + c.log.Warnf("[%s] failed to send mDNS packet (multicast) on IPv4 interface %d: %v", c.name, ifIndex, err) + } + } else { + c.log.Debugf("[%s] refusing to send mDNS packet with IPv6 zone over IPv4", c.name) + } + } + if ifc.supportsV6 && c.multicastPktConnV6 != nil && dst6 != nil { + if _, err := c.multicastPktConnV6.WriteTo(b, &ifc.Interface, nil, dst6); err != nil { + c.log.Warnf("[%s] failed to send mDNS packet (multicast) on IPv6 interface %d: %v", c.name, ifIndex, err) + } + } + } + } +} + +func createAnswer(id uint16, question dnsmessage.Question, addr netip.Addr, + isUnicast bool, +) (dnsmessage.Message, error) { + packedName, err := dnsmessage.NewName(question.Name.String()) + if err != nil { + return dnsmessage.Message{}, err + } + + msg := dnsmessage.Message{ + Header: dnsmessage.Header{ + ID: id, + Response: true, + Authoritative: true, + }, + Answers: []dnsmessage.Resource{ + { + Header: dnsmessage.ResourceHeader{ + Class: dnsmessage.ClassINET, + Name: packedName, + TTL: responseTTL, + }, + }, + }, + } + + // include question in answer if specified for this answer (such as unicast: Spec 6.7.) + if isUnicast { + msg.Questions = []dnsmessage.Question{question} + } + + if addr.Is4() { + ipBuf, err := ipv4ToBytes(addr) + if err != nil { + return dnsmessage.Message{}, err + } + msg.Answers[0].Header.Type = dnsmessage.TypeA + msg.Answers[0].Body = &dnsmessage.AResource{ + A: ipBuf, + } + } else if addr.Is6() { + // we will lose the zone here, but the receiver can reconstruct it + ipBuf, err := ipv6ToBytes(addr) + if err != nil { + return dnsmessage.Message{}, err + } + msg.Answers[0].Header.Type = dnsmessage.TypeAAAA + msg.Answers[0].Body = &dnsmessage.AAAAResource{ + AAAA: ipBuf, + } + } + + return msg, nil +} + +func (c *Conn) sendAnswer(queryID uint16, question dnsmessage.Question, ifIndex int, result netip.Addr, + dst *net.UDPAddr, isUnicast bool, +) { + answer, err := createAnswer(queryID, question, result, isUnicast) + if err != nil { + c.log.Warnf("[%s] failed to create mDNS answer %v", c.name, err) + + return + } + + rawAnswer, err := answer.Pack() + if err != nil { + c.log.Warnf("[%s] failed to construct mDNS packet %v", c.name, err) + + return + } + + c.writeToSocket( + ifIndex, + rawAnswer, + result.IsLoopback(), + result.Is6() && result.Zone() != "", + writeTypeAnswer, + dst, + ) +} + +type ipControlMessage struct { + IfIndex int + Dst net.IP +} + +type ipPacketConn interface { + ReadFrom(b []byte) (n int, cm *ipControlMessage, src net.Addr, err error) + WriteTo(b []byte, via *net.Interface, cm *ipControlMessage, dst net.Addr) (n int, err error) + Close() error +} + +type ipPacketConn4 struct { + name string + conn *ipv4.PacketConn + log logging.LeveledLogger +} + +func (c ipPacketConn4) ReadFrom(b []byte) (n int, cm *ipControlMessage, src net.Addr, err error) { + n, cm4, src, err := c.conn.ReadFrom(b) + if err != nil || cm4 == nil { + return n, nil, src, err + } + + return n, &ipControlMessage{IfIndex: cm4.IfIndex, Dst: cm4.Dst}, src, err +} + +func (c ipPacketConn4) WriteTo(b []byte, via *net.Interface, cm *ipControlMessage, dst net.Addr) (n int, err error) { + var cm4 *ipv4.ControlMessage + if cm != nil { + cm4 = &ipv4.ControlMessage{ + IfIndex: cm.IfIndex, + } + } + if err := c.conn.SetMulticastInterface(via); err != nil { + c.log.Warnf("[%s] failed to set multicast interface for %d: %v", c.name, via.Index, err) + + return 0, err + } + + return c.conn.WriteTo(b, cm4, dst) +} + +func (c ipPacketConn4) Close() error { + return c.conn.Close() +} + +type ipPacketConn6 struct { + name string + conn *ipv6.PacketConn + log logging.LeveledLogger +} + +func (c ipPacketConn6) ReadFrom(b []byte) (n int, cm *ipControlMessage, src net.Addr, err error) { + n, cm6, src, err := c.conn.ReadFrom(b) + if err != nil || cm6 == nil { + return n, nil, src, err + } + + return n, &ipControlMessage{IfIndex: cm6.IfIndex, Dst: cm6.Dst}, src, err +} + +func (c ipPacketConn6) WriteTo(b []byte, via *net.Interface, cm *ipControlMessage, dst net.Addr) (n int, err error) { + var cm6 *ipv6.ControlMessage + if cm != nil { + cm6 = &ipv6.ControlMessage{ + IfIndex: cm.IfIndex, + } + } + if err := c.conn.SetMulticastInterface(via); err != nil { + c.log.Warnf("[%s] failed to set multicast interface for %d: %v", c.name, via.Index, err) + + return 0, err + } + + return c.conn.WriteTo(b, cm6, dst) +} + +func (c ipPacketConn6) Close() error { + return c.conn.Close() +} + +//nolint:gocognit,gocyclo,cyclop,maintidx +func (c *Conn) readLoop(name string, pktConn ipPacketConn, inboundBufferSize int, config *Config) { + b := make([]byte, inboundBufferSize) + + for { + n, cm, src, err := pktConn.ReadFrom(b) + if err != nil { + if errors.Is(err, net.ErrClosed) { + return + } + c.log.Warnf("[%s] failed to ReadFrom %q %v", c.name, src, err) + + continue + } + c.log.Debugf("[%s] got read on %s from %s", c.name, name, src) + + var ifIndex int + var pktDst net.IP + if cm != nil { + ifIndex = cm.IfIndex + pktDst = cm.Dst + } else { + ifIndex = -1 + } + srcAddr, ok := src.(*net.UDPAddr) + if !ok { + c.log.Warnf("[%s] expected source address %s to be UDP but got %", c.name, src, src) + + continue + } + + func() { + var msg dnsmessage.Message + err := msg.Unpack(b[:n]) + if err != nil { + c.log.Warnf("[%s] failed to parse mDNS packet %v", c.name, err) + + return + } + + // Questions are often echoed with answers, therefore + // If we have more questions than answers it is a question we might need to respond to + if len(msg.Questions) > len(msg.Answers) { //nolint:nestif + for _, question := range msg.Questions { + if question.Type != dnsmessage.TypeA && question.Type != dnsmessage.TypeAAAA { + continue + } + + // https://datatracker.ietf.org/doc/html/rfc6762#section-6 + // The destination UDP port in all Multicast DNS responses MUST be 5353, + // and the destination address MUST be the mDNS IPv4 link-local + // multicast address 224.0.0.251 or its IPv6 equivalent FF02::FB, except + // when generating a reply to a query that explicitly requested a + // unicast response + isQU := (question.Class & (1 << 15)) != 0 // via the unicast-response bit + isLegacy := srcAddr.Port != 5353 // by virtue of being a legacy query (Section 6.7) + isDirect := len(pktDst) != 0 && + !pktDst.Equal(c.dstAddr4.IP) && + !pktDst.Equal(c.dstAddr6.IP) // by virtue of being a direct unicast query + shouldReplyUnicast := isQU || isLegacy || isDirect + var dst *net.UDPAddr + if shouldReplyUnicast { + dst = srcAddr + } + + queryWantsV4 := question.Type == dnsmessage.TypeA + + for _, localName := range c.localNames { + if strings.EqualFold(localName, question.Name.String()) { //nolint:nestif + var localAddress *netip.Addr + if config.LocalAddress != nil { + // this means the LocalAddress does not support link-local since + // we have no zone to set here. + ipAddr, ok := netip.AddrFromSlice(config.LocalAddress) + if !ok { + c.log.Warnf("[%s] failed to convert config.LocalAddress '%s' to netip.Addr", c.name, config.LocalAddress) + + continue + } + if c.multicastPktConnV4 != nil { + // don't want mapping since we also support IPv4/A + ipAddr = ipAddr.Unmap() + } + localAddress = &ipAddr + } else { + // prefer the address of the interface if we know its index, but otherwise + // derive it from the address we read from. We do this because even if + // multicast loopback is in use or we send from a loopback interface, + // there are still cases where the IP packet will contain the wrong + // source IP (e.g. a LAN interface). + // For example, we can have a packet that has: + // Source: 192.168.65.3 + // Destination: 224.0.0.251 + // Interface Index: 1 + // Interface Addresses @ 1: [127.0.0.1/8 ::1/128] + if ifIndex != -1 { + ifc, ok := c.ifaces[ifIndex] + if !ok { + c.log.Warnf("[%s] no interface for %d", c.name, ifIndex) + + return + } + var selectedAddrs []netip.Addr + for _, addr := range ifc.ipAddrs { + addrCopy := addr + + // match up respective IP types based on question + if queryWantsV4 { + if addrCopy.Is4In6() { + // we may allow 4-in-6, but the question wants an A record + addrCopy = addrCopy.Unmap() + } + if !addrCopy.Is4() { + continue + } + } else { // queryWantsV6 + if !addrCopy.Is6() { + continue + } + if !isSupportedIPv6(addrCopy, c.multicastPktConnV4 == nil) { + c.log.Debugf("[%s] interface %d address not a supported IPv6 address %s", c.name, ifIndex, &addrCopy) + + continue + } + } + + selectedAddrs = append(selectedAddrs, addrCopy) + } + if len(selectedAddrs) == 0 { + c.log.Debugf( + "[%s] failed to find suitable IP for interface %d; deriving address from source address c.name,instead", + c.name, ifIndex, + ) + } else { + // choose the best match + var choice *netip.Addr + for _, option := range selectedAddrs { + optCopy := option + if option.Is4() { + // select first + choice = &optCopy + + break + } + // we're okay with 4In6 for now but ideally we get a an actual IPv6. + // Maybe in the future we never want this but it does look like Docker + // can route IPv4 over IPv6. + if choice == nil || !optCopy.Is4In6() { + choice = &optCopy + } + if !optCopy.Is4In6() { + break + } + // otherwise keep searching for an actual IPv6 + } + localAddress = choice + } + } + if ifIndex == -1 || localAddress == nil { + localAddress, err = interfaceForRemote(src.String()) + if err != nil { + c.log.Warnf("[%s] failed to get local interface to communicate with %s: %v", c.name, src.String(), err) + + continue + } + } + } + if queryWantsV4 { + if !localAddress.Is4() { + c.log.Debugf( + "[%s] have IPv6 address %s to respond with but question is for A not c.name,AAAA", + c.name, localAddress, + ) + + continue + } + } else { + if !localAddress.Is6() { + c.log.Debugf( + "[%s] have IPv4 address %s to respond with but question is for AAAA not c.name,A", + c.name, localAddress, + ) + + continue + } + if !isSupportedIPv6(*localAddress, c.multicastPktConnV4 == nil) { + c.log.Debugf("[%s] got local interface address but not a supported IPv6 address %v", c.name, localAddress) + + continue + } + } + + if dst != nil && len(dst.IP) == net.IPv4len && + localAddress.Is6() && + localAddress.Zone() != "" && + (localAddress.IsLinkLocalUnicast() || localAddress.IsLinkLocalMulticast()) { + // This case happens when multicast v4 picks up an AAAA question that has a zone + // in the address. Since we cannot send this zone over DNS (it's meaningless), + // the other side can only infer this via the response interface on the other + // side (some IPv6 interface). + c.log.Debugf("[%s] refusing to send link-local address %s to an IPv4 destination %s", c.name, localAddress, dst) + + continue + } + c.log.Debugf( + "[%s] sending response for %s on ifc %d of %s to %s", + c.name, question.Name, ifIndex, *localAddress, dst, + ) + c.sendAnswer(msg.Header.ID, question, ifIndex, *localAddress, dst, shouldReplyUnicast) + } + } + } + } else { + for _, answer := range msg.Answers { + if answer.Header.Type != dnsmessage.TypeA && answer.Header.Type != dnsmessage.TypeAAAA { + continue + } + + c.mu.Lock() + queries := make([]*query, len(c.queries)) + copy(queries, c.queries) + c.mu.Unlock() + + var answered []*query + for _, query := range queries { + queryCopy := query + if strings.EqualFold(queryCopy.nameWithSuffix, answer.Header.Name.String()) { + addr, err := addrFromAnswer(answer) + if err != nil { + c.log.Warnf("[%s] failed to parse mDNS answer %v", c.name, err) + + return + } + + resultAddr := *addr + // DNS records don't contain IPv6 zones. + // We're trusting that since we're on the same link, that we will only + // be sent link-local addresses from that source's interface's address. + // If it's not present, we're out of luck since we cannot rely on the + // interface zone to be the same as the source's. + resultAddr = addrWithOptionalZone(resultAddr, srcAddr.Zone) + + select { + case queryCopy.queryResultChan <- queryResult{answer.Header, resultAddr}: + answered = append(answered, queryCopy) + default: + } + } + } + + c.mu.Lock() + for queryIdx := len(c.queries) - 1; queryIdx >= 0; queryIdx-- { + for answerIdx := len(answered) - 1; answerIdx >= 0; answerIdx-- { + if c.queries[queryIdx] == answered[answerIdx] { + c.queries = append(c.queries[:queryIdx], c.queries[queryIdx+1:]...) + answered = append(answered[:answerIdx], answered[answerIdx+1:]...) + queryIdx-- + + break + } + } + } + c.mu.Unlock() + } + } + }() + } +} + +func (c *Conn) start(started chan<- struct{}, inboundBufferSize int, config *Config) { + defer func() { + c.mu.Lock() + defer c.mu.Unlock() + close(c.closed) + }() + + var numReaders int + readerStarted := make(chan struct{}) + readerEnded := make(chan struct{}) + + if c.multicastPktConnV4 != nil { + numReaders++ + go func() { + defer func() { + readerEnded <- struct{}{} + }() + readerStarted <- struct{}{} + c.readLoop("multi4", c.multicastPktConnV4, inboundBufferSize, config) + }() + } + if c.multicastPktConnV6 != nil { + numReaders++ + go func() { + defer func() { + readerEnded <- struct{}{} + }() + readerStarted <- struct{}{} + c.readLoop("multi6", c.multicastPktConnV6, inboundBufferSize, config) + }() + } + if c.unicastPktConnV4 != nil { + numReaders++ + go func() { + defer func() { + readerEnded <- struct{}{} + }() + readerStarted <- struct{}{} + c.readLoop("uni4", c.unicastPktConnV4, inboundBufferSize, config) + }() + } + if c.unicastPktConnV6 != nil { + numReaders++ + go func() { + defer func() { + readerEnded <- struct{}{} + }() + readerStarted <- struct{}{} + c.readLoop("uni6", c.unicastPktConnV6, inboundBufferSize, config) + }() + } + for i := 0; i < numReaders; i++ { + <-readerStarted + } + close(started) + for i := 0; i < numReaders; i++ { + <-readerEnded + } +} + +func addrFromAnswer(answer dnsmessage.Resource) (*netip.Addr, error) { + switch answer.Header.Type { + case dnsmessage.TypeA: + if a, ok := answer.Body.(*dnsmessage.AResource); ok { + addr, ok := netip.AddrFromSlice(a.A[:]) + if ok { + addr = addr.Unmap() // do not want 4-in-6 + + return &addr, nil + } + } + + return nil, errFailedToDecodeAddrFromAResource + case dnsmessage.TypeAAAA: + if a, ok := answer.Body.(*dnsmessage.AAAAResource); ok { + addr, ok := netip.AddrFromSlice(a.AAAA[:]) + if ok { + return &addr, nil + } + } + + return nil, errFailedToDecodeAddrFromAAAAResource + default: + return nil, errUnhandledAnswerHeaderType + } +} + +func isSupportedIPv6(addr netip.Addr, ipv6Only bool) bool { + if !addr.Is6() { + return false + } + // IPv4-mapped-IPv6 addresses cannot be connected to unless + // unmapped. + if !ipv6Only && addr.Is4In6() { + return false + } + + return true +} + +func addrWithOptionalZone(addr netip.Addr, zone string) netip.Addr { + if zone == "" { + return addr + } + if addr.Is6() && (addr.IsLinkLocalUnicast() || addr.IsLinkLocalMulticast()) { + return addr.WithZone(zone) + } + + return addr +} diff --git a/vendor/github.com/pion/mdns/v2/errors.go b/vendor/github.com/pion/mdns/v2/errors.go new file mode 100644 index 0000000..f065e16 --- /dev/null +++ b/vendor/github.com/pion/mdns/v2/errors.go @@ -0,0 +1,14 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package mdns + +import "errors" + +var ( + errJoiningMulticastGroup = errors.New("mDNS: failed to join multicast group") + errConnectionClosed = errors.New("mDNS: connection is closed") + errContextElapsed = errors.New("mDNS: context has elapsed") + errNilConfig = errors.New("mDNS: config must not be nil") + errFailedCast = errors.New("mDNS: failed to cast listener to UDPAddr") +) diff --git a/vendor/github.com/pion/mdns/v2/mdns.go b/vendor/github.com/pion/mdns/v2/mdns.go new file mode 100644 index 0000000..a76b16b --- /dev/null +++ b/vendor/github.com/pion/mdns/v2/mdns.go @@ -0,0 +1,5 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +// Package mdns implements mDNS (multicast DNS) +package mdns diff --git a/vendor/github.com/pion/mdns/v2/renovate.json b/vendor/github.com/pion/mdns/v2/renovate.json new file mode 100644 index 0000000..f1bb98c --- /dev/null +++ b/vendor/github.com/pion/mdns/v2/renovate.json @@ -0,0 +1,6 @@ +{ + "$schema": "https://docs.renovatebot.com/renovate-schema.json", + "extends": [ + "github>pion/renovate-config" + ] +} diff --git a/vendor/github.com/pion/rtcp/.gitignore b/vendor/github.com/pion/rtcp/.gitignore new file mode 100644 index 0000000..6e2f206 --- /dev/null +++ b/vendor/github.com/pion/rtcp/.gitignore @@ -0,0 +1,28 @@ +# SPDX-FileCopyrightText: 2023 The Pion community +# SPDX-License-Identifier: MIT + +### JetBrains IDE ### +##################### +.idea/ + +### Emacs Temporary Files ### +############################# +*~ + +### Folders ### +############### +bin/ +vendor/ +node_modules/ + +### Files ### +############# +*.ivf +*.ogg +tags +cover.out +*.sw[poe] +*.wasm +examples/sfu-ws/cert.pem +examples/sfu-ws/key.pem +wasm_exec.js diff --git a/vendor/github.com/pion/rtcp/.golangci.yml b/vendor/github.com/pion/rtcp/.golangci.yml new file mode 100644 index 0000000..4b4025f --- /dev/null +++ b/vendor/github.com/pion/rtcp/.golangci.yml @@ -0,0 +1,147 @@ +# SPDX-FileCopyrightText: 2023 The Pion community +# SPDX-License-Identifier: MIT + +version: "2" +linters: + enable: + - asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers + - bidichk # Checks for dangerous unicode character sequences + - bodyclose # checks whether HTTP response body is closed successfully + - containedctx # containedctx is a linter that detects struct contained context.Context field + - contextcheck # check the function whether use a non-inherited context + - cyclop # checks function and package cyclomatic complexity + - decorder # check declaration order and count of types, constants, variables and functions + - dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f()) + - dupl # Tool for code clone detection + - durationcheck # check for two durations multiplied together + - err113 # Golang linter to check the errors handling expressions + - errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases + - errchkjson # Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occations, where the check for the returned error can be omitted. + - errname # Checks that sentinel errors are prefixed with the `Err` and error types are suffixed with the `Error`. + - errorlint # errorlint is a linter for that can be used to find code that will cause problems with the error wrapping scheme introduced in Go 1.13. + - exhaustive # check exhaustiveness of enum switch statements + - forbidigo # Forbids identifiers + - forcetypeassert # finds forced type assertions + - gochecknoglobals # Checks that no globals are present in Go code + - gocognit # Computes and checks the cognitive complexity of functions + - goconst # Finds repeated strings that could be replaced by a constant + - gocritic # The most opinionated Go source code linter + - gocyclo # Computes and checks the cyclomatic complexity of functions + - godot # Check if comments end in a period + - godox # Tool for detection of FIXME, TODO and other comment keywords + - goheader # Checks is file header matches to pattern + - gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod. + - goprintffuncname # Checks that printf-like functions are named with `f` at the end + - gosec # Inspects source code for security problems + - govet # Vet examines Go source code and reports suspicious constructs, such as Printf calls whose arguments do not align with the format string + - grouper # An analyzer to analyze expression groups. + - importas # Enforces consistent import aliases + - ineffassign # Detects when assignments to existing variables are not used + - lll # Reports long lines + - maintidx # maintidx measures the maintainability index of each function. + - makezero # Finds slice declarations with non-zero initial length + - misspell # Finds commonly misspelled English words in comments + - nakedret # Finds naked returns in functions greater than a specified function length + - nestif # Reports deeply nested if statements + - nilerr # Finds the code that returns nil even if it checks that the error is not nil. + - nilnil # Checks that there is no simultaneous return of `nil` error and an invalid value. + - nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity + - noctx # noctx finds sending http request without context.Context + - predeclared # find code that shadows one of Go's predeclared identifiers + - revive # golint replacement, finds style mistakes + - staticcheck # Staticcheck is a go vet on steroids, applying a ton of static analysis checks + - tagliatelle # Checks the struct tags. + - thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers + - unconvert # Remove unnecessary type conversions + - unparam # Reports unused function parameters + - unused # Checks Go code for unused constants, variables, functions and types + - varnamelen # checks that the length of a variable's name matches its scope + - wastedassign # wastedassign finds wasted assignment statements + - whitespace # Tool for detection of leading and trailing whitespace + disable: + - depguard # Go linter that checks if package imports are in a list of acceptable packages + - funlen # Tool for detection of long functions + - gochecknoinits # Checks that no init functions are present in Go code + - gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations. + - interfacebloat # A linter that checks length of interface. + - ireturn # Accept Interfaces, Return Concrete Types + - mnd # An analyzer to detect magic numbers + - nolintlint # Reports ill-formed or insufficient nolint directives + - paralleltest # paralleltest detects missing usage of t.Parallel() method in your Go test + - prealloc # Finds slice declarations that could potentially be preallocated + - promlinter # Check Prometheus metrics naming via promlint + - rowserrcheck # checks whether Err of rows is checked successfully + - sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed. + - testpackage # linter that makes you use a separate _test package + - tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes + - wrapcheck # Checks that errors returned from external packages are wrapped + - wsl # Whitespace Linter - Forces you to use empty lines! + settings: + staticcheck: + checks: + - all + - -QF1008 # "could remove embedded field", to keep it explicit! + - -QF1003 # "could use tagged switch on enum", Cases conflicts with exhaustive! + exhaustive: + default-signifies-exhaustive: true + forbidigo: + forbid: + - pattern: ^fmt.Print(f|ln)?$ + - pattern: ^log.(Panic|Fatal|Print)(f|ln)?$ + - pattern: ^os.Exit$ + - pattern: ^panic$ + - pattern: ^print(ln)?$ + - pattern: ^testing.T.(Error|Errorf|Fatal|Fatalf|Fail|FailNow)$ + pkg: ^testing$ + msg: use testify/assert instead + analyze-types: true + gomodguard: + blocked: + modules: + - github.com/pkg/errors: + recommendations: + - errors + govet: + enable: + - shadow + revive: + rules: + # Prefer 'any' type alias over 'interface{}' for Go 1.18+ compatibility + - name: use-any + severity: warning + disabled: false + misspell: + locale: US + varnamelen: + max-distance: 12 + min-name-length: 2 + ignore-type-assert-ok: true + ignore-map-index-ok: true + ignore-chan-recv-ok: true + ignore-decls: + - i int + - n int + - w io.Writer + - r io.Reader + - b []byte + exclusions: + generated: lax + rules: + - linters: + - forbidigo + - gocognit + path: (examples|main\.go) + - linters: + - gocognit + path: _test\.go + - linters: + - forbidigo + path: cmd +formatters: + enable: + - gci # Gci control golang package import order and make it always deterministic. + - gofmt # Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification + - gofumpt # Gofumpt checks whether code was gofumpt-ed. + - goimports # Goimports does everything that gofmt does. Additionally it checks unused imports + exclusions: + generated: lax diff --git a/vendor/github.com/pion/rtcp/.goreleaser.yml b/vendor/github.com/pion/rtcp/.goreleaser.yml new file mode 100644 index 0000000..30093e9 --- /dev/null +++ b/vendor/github.com/pion/rtcp/.goreleaser.yml @@ -0,0 +1,5 @@ +# SPDX-FileCopyrightText: 2023 The Pion community +# SPDX-License-Identifier: MIT + +builds: +- skip: true diff --git a/vendor/github.com/pion/rtcp/LICENSE b/vendor/github.com/pion/rtcp/LICENSE new file mode 100644 index 0000000..491caf6 --- /dev/null +++ b/vendor/github.com/pion/rtcp/LICENSE @@ -0,0 +1,9 @@ +MIT License + +Copyright (c) 2023 The Pion community + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/vendor/github.com/pion/rtcp/README.md b/vendor/github.com/pion/rtcp/README.md new file mode 100644 index 0000000..c4a43cd --- /dev/null +++ b/vendor/github.com/pion/rtcp/README.md @@ -0,0 +1,37 @@ +

+
+ Pion RTCP +
+

+

A Go implementation of RTCP

+

+ Pion RTCP + Sourcegraph Widget + join us on Discord Follow us on Bluesky +
+ GitHub Workflow Status + Go Reference + Coverage Status + Go Report Card + License: MIT +

+
+ +See [DESIGN.md](DESIGN.md) for an overview of features and future goals. + +### Roadmap +The library is used as a part of our WebRTC implementation. Please refer to that [roadmap](https://github.com/pion/webrtc/issues/9) to track our major milestones. + +### Community +Pion has an active community on the [Discord](https://discord.gg/PngbdqpFbt). + +Follow the [Pion Bluesky](https://bsky.app/profile/pion.ly) or [Pion Twitter](https://twitter.com/_pion) for project updates and important WebRTC news. + +We are always looking to support **your projects**. Please reach out if you have something to build! +If you need commercial support or don't want to use public methods you can contact us at [team@pion.ly](mailto:team@pion.ly) + +### Contributing +Check out the [contributing wiki](https://github.com/pion/webrtc/wiki/Contributing) to join the group of amazing people making this project possible + +### License +MIT License - see [LICENSE](LICENSE) for full text diff --git a/vendor/github.com/pion/rtcp/application_defined.go b/vendor/github.com/pion/rtcp/application_defined.go new file mode 100644 index 0000000..ca5f844 --- /dev/null +++ b/vendor/github.com/pion/rtcp/application_defined.go @@ -0,0 +1,123 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package rtcp + +import ( + "encoding/binary" +) + +// ApplicationDefined represents an RTCP application-defined packet. +type ApplicationDefined struct { + SubType uint8 + SSRC uint32 + Name string + Data []byte +} + +// DestinationSSRC returns the SSRC value for this packet. +func (a ApplicationDefined) DestinationSSRC() []uint32 { + return []uint32{a.SSRC} +} + +// Marshal serializes the application-defined struct into a byte slice with padding. +func (a ApplicationDefined) Marshal() ([]byte, error) { + dataLength := len(a.Data) + if dataLength > 0xFFFF-12 { + return nil, errAppDefinedDataTooLarge + } + if len(a.Name) != 4 { + return nil, errAppDefinedInvalidName + } + // Calculate the padding size to be added to make the packet length a multiple of 4 bytes. + paddingSize := 4 - (dataLength % 4) + if paddingSize == 4 { + paddingSize = 0 + } + + packetSize := a.MarshalSize() + header := Header{ + Type: TypeApplicationDefined, + Length: uint16((packetSize / 4) - 1), //nolint:gosec // G115 + Padding: paddingSize != 0, + Count: a.SubType, + } + + headerBytes, err := header.Marshal() + if err != nil { + return nil, err + } + + rawPacket := make([]byte, packetSize) + copy(rawPacket, headerBytes) + binary.BigEndian.PutUint32(rawPacket[4:8], a.SSRC) + copy(rawPacket[8:12], a.Name) + copy(rawPacket[12:], a.Data) + + // Add padding if necessary. + if paddingSize > 0 { + for i := 0; i < paddingSize; i++ { + rawPacket[12+dataLength+i] = byte(paddingSize) + } + } + + return rawPacket, nil +} + +// Unmarshal parses the given raw packet into an application-defined struct, handling padding. +func (a *ApplicationDefined) Unmarshal(rawPacket []byte) error { + /* + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + |V=2|P| subtype | PT=APP=204 | length | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | SSRC/CSRC | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | name (ASCII) | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | application-dependent data ... + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + */ + header := Header{} + err := header.Unmarshal(rawPacket) + if err != nil { + return err + } + if len(rawPacket) < 12 { + return errPacketTooShort + } + + if int(header.Length+1)*4 != len(rawPacket) { + return errAppDefinedInvalidLength + } + + a.SubType = header.Count + a.SSRC = binary.BigEndian.Uint32(rawPacket[4:8]) + a.Name = string(rawPacket[8:12]) + + // Check for padding. + paddingSize := 0 + if header.Padding { + paddingSize = int(rawPacket[len(rawPacket)-1]) + if paddingSize > len(rawPacket)-12 { + return errWrongPadding + } + } + + a.Data = rawPacket[12 : len(rawPacket)-paddingSize] + + return nil +} + +// MarshalSize returns the size of the packet once marshaled. +func (a *ApplicationDefined) MarshalSize() int { + dataLength := len(a.Data) + // Calculate the padding size to be added to make the packet length a multiple of 4 bytes. + paddingSize := 4 - (dataLength % 4) + if paddingSize == 4 { + paddingSize = 0 + } + + return 12 + dataLength + paddingSize +} diff --git a/vendor/github.com/pion/rtcp/codecov.yml b/vendor/github.com/pion/rtcp/codecov.yml new file mode 100644 index 0000000..263e4d4 --- /dev/null +++ b/vendor/github.com/pion/rtcp/codecov.yml @@ -0,0 +1,22 @@ +# +# DO NOT EDIT THIS FILE +# +# It is automatically copied from https://github.com/pion/.goassets repository. +# +# SPDX-FileCopyrightText: 2023 The Pion community +# SPDX-License-Identifier: MIT + +coverage: + status: + project: + default: + # Allow decreasing 2% of total coverage to avoid noise. + threshold: 2% + patch: + default: + target: 70% + only_pulls: true + +ignore: + - "examples/*" + - "examples/**/*" diff --git a/vendor/github.com/pion/rtcp/compound_packet.go b/vendor/github.com/pion/rtcp/compound_packet.go new file mode 100644 index 0000000..1752b64 --- /dev/null +++ b/vendor/github.com/pion/rtcp/compound_packet.go @@ -0,0 +1,167 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package rtcp + +import ( + "fmt" + "strings" +) + +// A CompoundPacket is a collection of RTCP packets transmitted as a single packet with +// the underlying protocol (for example UDP). +// +// To maximize the resolution of receiption statistics, the first Packet in a CompoundPacket +// must always be either a SenderReport or a ReceiverReport. This is true even if no data +// has been sent or received, in which case an empty ReceiverReport must be sent, and even +// if the only other RTCP packet in the compound packet is a Goodbye. +// +// Next, a SourceDescription containing a CNAME item must be included in each CompoundPacket +// to identify the source and to begin associating media for purposes such as lip-sync. +// +// Other RTCP packet types may follow in any order. Packet types may appear more than once. +type CompoundPacket []Packet + +// Validate returns an error if this is not an RFC-compliant CompoundPacket. +// +//nolint:cyclop +func (c CompoundPacket) Validate() error { + if len(c) == 0 { + return errEmptyCompound + } + + // SenderReport and ReceiverReport are the only types that + // are allowed to be the first packet in a compound datagram + switch c[0].(type) { + case *SenderReport, *ReceiverReport: + // ok + default: + return errBadFirstPacket + } + + for _, pkt := range c[1:] { + switch p := pkt.(type) { + // If the number of RecetpionReports exceeds 31 additional ReceiverReports + // can be included here. + case *ReceiverReport: + continue + + // A SourceDescription containing a CNAME must be included in every + // CompoundPacket. + case *SourceDescription: + var hasCNAME bool + for _, c := range p.Chunks { + for _, it := range c.Items { + if it.Type == SDESCNAME { + hasCNAME = true + } + } + } + + if !hasCNAME { + return errMissingCNAME + } + + return nil + + // Other packets are not permitted before the CNAME + default: + return errPacketBeforeCNAME + } + } + + // CNAME never reached + return errMissingCNAME +} + +// CNAME returns the CNAME that *must* be present in every CompoundPacket. +func (c CompoundPacket) CNAME() (string, error) { + var err error + + if len(c) < 1 { + return "", errEmptyCompound + } + + for _, pkt := range c[1:] { + sdes, ok := pkt.(*SourceDescription) + if ok { + for _, c := range sdes.Chunks { + for _, it := range c.Items { + if it.Type == SDESCNAME { + return it.Text, err + } + } + } + } else { + _, ok := pkt.(*ReceiverReport) + if !ok { + err = errPacketBeforeCNAME + } + } + } + + return "", errMissingCNAME +} + +// Marshal encodes the CompoundPacket as binary. +func (c CompoundPacket) Marshal() ([]byte, error) { + if err := c.Validate(); err != nil { + return nil, err + } + + p := []Packet(c) + + return Marshal(p) +} + +// MarshalSize returns the size of the packet once marshaled. +func (c CompoundPacket) MarshalSize() int { + l := 0 + for _, p := range c { + l += p.MarshalSize() + } + + return l +} + +// Unmarshal decodes a CompoundPacket from binary. +func (c *CompoundPacket) Unmarshal(rawData []byte) error { + out := make(CompoundPacket, 0) + for len(rawData) != 0 { + p, processed, err := unmarshal(rawData) + if err != nil { + return err + } + + out = append(out, p) + rawData = rawData[processed:] + } + *c = out + + return c.Validate() +} + +// DestinationSSRC returns the synchronization sources associated with this +// CompoundPacket's reception report. +func (c CompoundPacket) DestinationSSRC() []uint32 { + if len(c) == 0 { + return nil + } + + return c[0].DestinationSSRC() +} + +func (c CompoundPacket) String() string { + out := "CompoundPacket\n" + for _, p := range c { + stringer, canString := p.(fmt.Stringer) + if canString { + out += stringer.String() + } else { + out += stringify(p) + } + } + out = strings.TrimSuffix(strings.ReplaceAll(out, "\n", "\n\t"), "\t") + + return out +} diff --git a/vendor/github.com/pion/rtcp/doc.go b/vendor/github.com/pion/rtcp/doc.go new file mode 100644 index 0000000..22f06cc --- /dev/null +++ b/vendor/github.com/pion/rtcp/doc.go @@ -0,0 +1,42 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +/* +Package rtcp implements encoding and decoding of RTCP packets according to RFCs 3550 and 5506. + +RTCP is a sister protocol of the Real-time Transport Protocol (RTP). Its basic functionality +and packet structure is defined in RFC 3550. RTCP provides out-of-band statistics and control +information for an RTP session. It partners with RTP in the delivery and packaging of multimedia data, +but does not transport any media data itself. + +The primary function of RTCP is to provide feedback on the quality of service (QoS) +in media distribution by periodically sending statistics information such as transmitted octet +and packet counts, packet loss, packet delay variation, and round-trip delay time to participants +in a streaming multimedia session. An application may use this information to control quality of +service parameters, perhaps by limiting flow, or using a different codec. + +Decoding RTCP packets: + + pkts, err := rtcp.Unmarshal(rtcpData) + // ... + for _, pkt := range pkts { + switch p := pkt.(type) { + case *rtcp.CompoundPacket: + ... + case *rtcp.PictureLossIndication: + ... + default: + ... + } + } + +Encoding RTCP packets: + + pkt := &rtcp.PictureLossIndication{ + SenderSSRC: senderSSRC, + MediaSSRC: mediaSSRC + } + pliData, err := pkt.Marshal() + // ... +*/ +package rtcp diff --git a/vendor/github.com/pion/rtcp/errors.go b/vendor/github.com/pion/rtcp/errors.go new file mode 100644 index 0000000..05e6847 --- /dev/null +++ b/vendor/github.com/pion/rtcp/errors.go @@ -0,0 +1,41 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package rtcp + +import "errors" + +var ( + errWrongMarshalSize = errors.New("rtcp: wrong marshal size") + errInvalidTotalLost = errors.New("rtcp: invalid total lost count") + errInvalidHeader = errors.New("rtcp: invalid header") + errEmptyCompound = errors.New("rtcp: empty compound packet") + errBadFirstPacket = errors.New("rtcp: first packet in compound must be SR or RR") + errMissingCNAME = errors.New("rtcp: compound missing SourceDescription with CNAME") + errPacketBeforeCNAME = errors.New("rtcp: feedback packet seen before CNAME") + errTooManyReports = errors.New("rtcp: too many reports") + errTooManyChunks = errors.New("rtcp: too many chunks") + errTooManySources = errors.New("rtcp: too many sources") + errPacketTooShort = errors.New("rtcp: packet too short") + errWrongType = errors.New("rtcp: wrong packet type") + errSDESTextTooLong = errors.New("rtcp: sdes must be < 255 octets long") + errSDESMissingType = errors.New("rtcp: sdes item missing type") + errReasonTooLong = errors.New("rtcp: reason must be < 255 octets long") + errBadVersion = errors.New("rtcp: invalid packet version") + errBadLength = errors.New("rtcp: invalid packet length") + errWrongPadding = errors.New("rtcp: invalid padding value") + errWrongFeedbackType = errors.New("rtcp: wrong feedback message type") + errWrongPayloadType = errors.New("rtcp: wrong payload type") + errHeaderTooSmall = errors.New("rtcp: header length is too small") + errSSRCMustBeZero = errors.New("rtcp: media SSRC must be 0") + errMissingREMBidentifier = errors.New("missing REMB identifier") + errSSRCNumAndLengthMismatch = errors.New("SSRC num and length do not match") + errInvalidSizeOrStartIndex = errors.New("invalid size or startIndex") + errInvalidBitrate = errors.New("invalid bitrate") + errWrongChunkType = errors.New("rtcp: wrong chunk type") + errBadStructMemberType = errors.New("rtcp: struct contains unexpected member type") + errBadReadParameter = errors.New("rtcp: cannot read into non-pointer") + errAppDefinedInvalidLength = errors.New("rtcp: application defined type invalid length") + errAppDefinedDataTooLarge = errors.New("rtcp: application defined data is too large") + errAppDefinedInvalidName = errors.New("rtcp: application defined name must be 4 ASCII chars") +) diff --git a/vendor/github.com/pion/rtcp/extended_report.go b/vendor/github.com/pion/rtcp/extended_report.go new file mode 100644 index 0000000..7b025c9 --- /dev/null +++ b/vendor/github.com/pion/rtcp/extended_report.go @@ -0,0 +1,677 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package rtcp + +import ( + "fmt" +) + +// The ExtendedReport packet is an Implementation of RTCP Extended +// Reports defined in RFC 3611. It is used to convey detailed +// information about an RTP stream. Each packet contains one or +// more report blocks, each of which conveys a different kind of +// information. +// +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// |V=2|P|reserved | PT=XR=207 | length | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | SSRC | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// : report blocks : +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// . +type ExtendedReport struct { + SenderSSRC uint32 `fmt:"0x%X"` + Reports []ReportBlock +} + +// ReportBlock represents a single report within an ExtendedReport +// packet. +type ReportBlock interface { + DestinationSSRC() []uint32 + setupBlockHeader() + unpackBlockHeader() +} + +// TypeSpecificField as described in RFC 3611 section 4.5. In typical +// cases, users of ExtendedReports shouldn't need to access this, +// and should instead use the corresponding fields in the actual +// report blocks themselves. +type TypeSpecificField uint8 + +// XRHeader defines the common fields that must appear at the start +// of each report block. In typical cases, users of ExtendedReports +// shouldn't need to access this. For locally-constructed report +// blocks, these values will not be accurate until the corresponding +// packet is marshaled. +type XRHeader struct { + BlockType BlockTypeType + TypeSpecific TypeSpecificField `fmt:"0x%X"` + BlockLength uint16 +} + +// BlockTypeType specifies the type of report in a report block. +type BlockTypeType uint8 + +// Extended Report block types from RFC 3611. +const ( + LossRLEReportBlockType = 1 // RFC 3611, section 4.1 + DuplicateRLEReportBlockType = 2 // RFC 3611, section 4.2 + PacketReceiptTimesReportBlockType = 3 // RFC 3611, section 4.3 + ReceiverReferenceTimeReportBlockType = 4 // RFC 3611, section 4.4 + DLRRReportBlockType = 5 // RFC 3611, section 4.5 + StatisticsSummaryReportBlockType = 6 // RFC 3611, section 4.6 + VoIPMetricsReportBlockType = 7 // RFC 3611, section 4.7 +) + +// String converts the Extended report block types into readable strings. +func (t BlockTypeType) String() string { + switch t { + case LossRLEReportBlockType: + return "LossRLEReportBlockType" + case DuplicateRLEReportBlockType: + return "DuplicateRLEReportBlockType" + case PacketReceiptTimesReportBlockType: + return "PacketReceiptTimesReportBlockType" + case ReceiverReferenceTimeReportBlockType: + return "ReceiverReferenceTimeReportBlockType" + case DLRRReportBlockType: + return "DLRRReportBlockType" + case StatisticsSummaryReportBlockType: + return "StatisticsSummaryReportBlockType" + case VoIPMetricsReportBlockType: + return "VoIPMetricsReportBlockType" + } + + return fmt.Sprintf("invalid value %d", t) +} + +// rleReportBlock defines the common structure used by both +// Loss RLE report blocks (RFC 3611 §4.1) and Duplicate RLE +// report blocks (RFC 3611 §4.2). +// +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | BT = 1 or 2 | rsvd. | T | block length | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | SSRC of source | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | begin_seq | end_seq | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | chunk 1 | chunk 2 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// : ... : +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | chunk n-1 | chunk n | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// . +type rleReportBlock struct { + XRHeader + T uint8 `encoding:"omit"` + SSRC uint32 `fmt:"0x%X"` + BeginSeq uint16 + EndSeq uint16 + Chunks []Chunk +} + +// Chunk as defined in RFC 3611, section 4.1. These represent information +// about packet losses and packet duplication. They have three representations: +// +// Run Length Chunk: +// +// 0 1 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// |C|R| run length | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// +// Bit Vector Chunk: +// +// 0 1 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// |C| bit vector | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// +// Terminating Null Chunk: +// +// 0 1 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// |0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0| +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +type Chunk uint16 + +// LossRLEReportBlock is used to report information about packet +// losses, as described in RFC 3611, section 4.1. +type LossRLEReportBlock rleReportBlock + +// DestinationSSRC returns an array of SSRC values that this report block refers to. +func (b *LossRLEReportBlock) DestinationSSRC() []uint32 { + return []uint32{b.SSRC} +} + +func (b *LossRLEReportBlock) setupBlockHeader() { + b.XRHeader.BlockType = LossRLEReportBlockType + b.XRHeader.TypeSpecific = TypeSpecificField(b.T & 0x0F) + b.XRHeader.BlockLength = uint16(wireSize(b)/4 - 1) //nolint:gosec // G115 +} + +func (b *LossRLEReportBlock) unpackBlockHeader() { + b.T = uint8(b.XRHeader.TypeSpecific) & 0x0F +} + +// DuplicateRLEReportBlock is used to report information about packet +// duplication, as described in RFC 3611, section 4.1. +type DuplicateRLEReportBlock rleReportBlock + +// DestinationSSRC returns an array of SSRC values that this report block refers to. +func (b *DuplicateRLEReportBlock) DestinationSSRC() []uint32 { + return []uint32{b.SSRC} +} + +func (b *DuplicateRLEReportBlock) setupBlockHeader() { + b.XRHeader.BlockType = DuplicateRLEReportBlockType + b.XRHeader.TypeSpecific = TypeSpecificField(b.T & 0x0F) + b.XRHeader.BlockLength = uint16(wireSize(b)/4 - 1) //nolint:gosec // G115 +} + +func (b *DuplicateRLEReportBlock) unpackBlockHeader() { + b.T = uint8(b.XRHeader.TypeSpecific) & 0x0F +} + +// ChunkType enumerates the three kinds of chunks described in RFC 3611 section 4.1. +type ChunkType uint8 + +// These are the valid values that ChunkType can assume. +const ( + RunLengthChunkType = 0 + BitVectorChunkType = 1 + TerminatingNullChunkType = 2 +) + +func (c Chunk) String() string { + switch c.Type() { + case RunLengthChunkType: + runType, _ := c.RunType() + + return fmt.Sprintf("[RunLength type=%d, length=%d]", runType, c.Value()) + case BitVectorChunkType: + return fmt.Sprintf("[BitVector 0b%015b]", c.Value()) + case TerminatingNullChunkType: + return "[TerminatingNull]" + } + + return fmt.Sprintf("[0x%X]", uint16(c)) +} + +// Type returns the ChunkType that this Chunk represents. +func (c Chunk) Type() ChunkType { + if c == 0 { + return TerminatingNullChunkType + } + + return ChunkType(c >> 15) //nolint:gosec // G115 +} + +// RunType returns the RunType that this Chunk represents. It is +// only valid if ChunkType is RunLengthChunkType. +func (c Chunk) RunType() (uint, error) { + if c.Type() != RunLengthChunkType { + return 0, errWrongChunkType + } + + return uint((c >> 14) & 0x01), nil +} + +// Value returns the value represented in this Chunk. +func (c Chunk) Value() uint { + switch c.Type() { + case RunLengthChunkType: + return uint(c & 0x3FFF) + case BitVectorChunkType: + return uint(c & 0x7FFF) + case TerminatingNullChunkType: + return 0 + } + + return uint(c) +} + +// PacketReceiptTimesReportBlock represents a Packet Receipt Times +// report block, as described in RFC 3611 section 4.3. +// +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | BT=3 | rsvd. | T | block length | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | SSRC of source | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | begin_seq | end_seq | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Receipt time of packet begin_seq | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Receipt time of packet (begin_seq + 1) mod 65536 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// : ... : +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Receipt time of packet (end_seq - 1) mod 65536 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// . +type PacketReceiptTimesReportBlock struct { + XRHeader + T uint8 `encoding:"omit"` + SSRC uint32 `fmt:"0x%X"` + BeginSeq uint16 + EndSeq uint16 + ReceiptTime []uint32 +} + +// DestinationSSRC returns an array of SSRC values that this report block refers to. +func (b *PacketReceiptTimesReportBlock) DestinationSSRC() []uint32 { + return []uint32{b.SSRC} +} + +func (b *PacketReceiptTimesReportBlock) setupBlockHeader() { + b.XRHeader.BlockType = PacketReceiptTimesReportBlockType + b.XRHeader.TypeSpecific = TypeSpecificField(b.T & 0x0F) + b.XRHeader.BlockLength = uint16(wireSize(b)/4 - 1) //nolint:gosec // G115 +} + +func (b *PacketReceiptTimesReportBlock) unpackBlockHeader() { + b.T = uint8(b.XRHeader.TypeSpecific) & 0x0F +} + +// ReceiverReferenceTimeReportBlock encodes a Receiver Reference Time +// report block as described in RFC 3611 section 4.4. +// +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | BT=4 | reserved | block length = 2 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | NTP timestamp, most significant word | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | NTP timestamp, least significant word | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// . +type ReceiverReferenceTimeReportBlock struct { + XRHeader + NTPTimestamp uint64 +} + +// DestinationSSRC returns an array of SSRC values that this report block refers to. +func (b *ReceiverReferenceTimeReportBlock) DestinationSSRC() []uint32 { + return []uint32{} +} + +func (b *ReceiverReferenceTimeReportBlock) setupBlockHeader() { + b.XRHeader.BlockType = ReceiverReferenceTimeReportBlockType + b.XRHeader.TypeSpecific = 0 + b.XRHeader.BlockLength = uint16(wireSize(b)/4 - 1) //nolint:gosec // G115 +} + +func (b *ReceiverReferenceTimeReportBlock) unpackBlockHeader() { +} + +// DLRRReportBlock encodes a DLRR Report Block as described in +// RFC 3611 section 4.5. +// +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | BT=5 | reserved | block length | +// +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ +// | SSRC_1 (SSRC of first receiver) | sub- +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ block +// | last RR (LRR) | 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | delay since last RR (DLRR) | +// +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ +// | SSRC_2 (SSRC of second receiver) | sub- +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ block +// : ... : 2 +// +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ +// . +type DLRRReportBlock struct { + XRHeader + Reports []DLRRReport +} + +// DLRRReport encodes a single report inside a DLRRReportBlock. +type DLRRReport struct { + SSRC uint32 `fmt:"0x%X"` + LastRR uint32 + DLRR uint32 +} + +// DestinationSSRC returns an array of SSRC values that this report block refers to. +func (b *DLRRReportBlock) DestinationSSRC() []uint32 { + ssrc := make([]uint32, len(b.Reports)) + for i, r := range b.Reports { + ssrc[i] = r.SSRC + } + + return ssrc +} + +func (b *DLRRReportBlock) setupBlockHeader() { + b.XRHeader.BlockType = DLRRReportBlockType + b.XRHeader.TypeSpecific = 0 + b.XRHeader.BlockLength = uint16(wireSize(b)/4 - 1) //nolint:gosec // G115 +} + +func (b *DLRRReportBlock) unpackBlockHeader() { +} + +// StatisticsSummaryReportBlock encodes a Statistics Summary Report +// Block as described in RFC 3611, section 4.6. +// +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | BT=6 |L|D|J|ToH|rsvd.| block length = 9 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | SSRC of source | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | begin_seq | end_seq | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | lost_packets | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | dup_packets | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | min_jitter | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | max_jitter | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | mean_jitter | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | dev_jitter | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | min_ttl_or_hl | max_ttl_or_hl |mean_ttl_or_hl | dev_ttl_or_hl | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// . +type StatisticsSummaryReportBlock struct { + XRHeader + LossReports bool `encoding:"omit"` + DuplicateReports bool `encoding:"omit"` + JitterReports bool `encoding:"omit"` + TTLorHopLimit TTLorHopLimitType `encoding:"omit"` + SSRC uint32 `fmt:"0x%X"` + BeginSeq uint16 + EndSeq uint16 + LostPackets uint32 + DupPackets uint32 + MinJitter uint32 + MaxJitter uint32 + MeanJitter uint32 + DevJitter uint32 + MinTTLOrHL uint8 + MaxTTLOrHL uint8 + MeanTTLOrHL uint8 + DevTTLOrHL uint8 +} + +// TTLorHopLimitType encodes values for the ToH field in +// a StatisticsSummaryReportBlock. +type TTLorHopLimitType uint8 + +// Values for TTLorHopLimitType. +const ( + ToHMissing = 0 + ToHIPv4 = 1 + ToHIPv6 = 2 +) + +func (t TTLorHopLimitType) String() string { + switch t { + case ToHMissing: + return "[ToH Missing]" + case ToHIPv4: + return "[ToH = IPv4]" + case ToHIPv6: + return "[ToH = IPv6]" + } + + return "[ToH Flag is Invalid]" +} + +// DestinationSSRC returns an array of SSRC values that this report block refers to. +func (b *StatisticsSummaryReportBlock) DestinationSSRC() []uint32 { + return []uint32{b.SSRC} +} + +func (b *StatisticsSummaryReportBlock) setupBlockHeader() { + b.XRHeader.BlockType = StatisticsSummaryReportBlockType + b.XRHeader.TypeSpecific = 0x00 + if b.LossReports { + b.XRHeader.TypeSpecific |= 0x80 + } + if b.DuplicateReports { + b.XRHeader.TypeSpecific |= 0x40 + } + if b.JitterReports { + b.XRHeader.TypeSpecific |= 0x20 + } + b.XRHeader.TypeSpecific |= TypeSpecificField((b.TTLorHopLimit & 0x03) << 3) + b.XRHeader.BlockLength = uint16(wireSize(b)/4 - 1) //nolint:gosec // G115 +} + +func (b *StatisticsSummaryReportBlock) unpackBlockHeader() { + b.LossReports = b.XRHeader.TypeSpecific&0x80 != 0 + b.DuplicateReports = b.XRHeader.TypeSpecific&0x40 != 0 + b.JitterReports = b.XRHeader.TypeSpecific&0x20 != 0 + b.TTLorHopLimit = TTLorHopLimitType((b.XRHeader.TypeSpecific & 0x18) >> 3) +} + +// VoIPMetricsReportBlock encodes a VoIP Metrics Report Block as described +// in RFC 3611, section 4.7. +// +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | BT=7 | reserved | block length = 8 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | SSRC of source | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | loss rate | discard rate | burst density | gap density | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | burst duration | gap duration | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | round trip delay | end system delay | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | signal level | noise level | RERL | Gmin | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | R factor | ext. R factor | MOS-LQ | MOS-CQ | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | RX config | reserved | JB nominal | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | JB maximum | JB abs max | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// . +type VoIPMetricsReportBlock struct { + XRHeader + SSRC uint32 `fmt:"0x%X"` + LossRate uint8 + DiscardRate uint8 + BurstDensity uint8 + GapDensity uint8 + BurstDuration uint16 + GapDuration uint16 + RoundTripDelay uint16 + EndSystemDelay uint16 + SignalLevel uint8 + NoiseLevel uint8 + RERL uint8 + Gmin uint8 + RFactor uint8 + ExtRFactor uint8 + MOSLQ uint8 + MOSCQ uint8 + RXConfig uint8 + _ uint8 + JBNominal uint16 + JBMaximum uint16 + JBAbsMax uint16 +} + +// DestinationSSRC returns an array of SSRC values that this report block refers to. +func (b *VoIPMetricsReportBlock) DestinationSSRC() []uint32 { + return []uint32{b.SSRC} +} + +func (b *VoIPMetricsReportBlock) setupBlockHeader() { + b.XRHeader.BlockType = VoIPMetricsReportBlockType + b.XRHeader.TypeSpecific = 0 + b.XRHeader.BlockLength = uint16(wireSize(b)/4 - 1) //nolint:gosec // G115 +} + +func (b *VoIPMetricsReportBlock) unpackBlockHeader() { +} + +// UnknownReportBlock is used to store bytes for any report block +// that has an unknown Report Block Type. +type UnknownReportBlock struct { + XRHeader + Bytes []byte +} + +// DestinationSSRC returns an array of SSRC values that this report block refers to. +func (b *UnknownReportBlock) DestinationSSRC() []uint32 { + return []uint32{} +} + +func (b *UnknownReportBlock) setupBlockHeader() { + b.XRHeader.BlockLength = uint16(wireSize(b)/4 - 1) //nolint:gosec // G115 +} + +func (b *UnknownReportBlock) unpackBlockHeader() { +} + +// MarshalSize returns the size of the packet once marshaled. +func (x ExtendedReport) MarshalSize() int { + return wireSize(x) +} + +// Marshal encodes the ExtendedReport in binary. +func (x ExtendedReport) Marshal() ([]byte, error) { + for _, p := range x.Reports { + p.setupBlockHeader() + } + + length := wireSize(x) + + // RTCP Header + header := Header{ + Type: TypeExtendedReport, + Length: uint16(length / 4), //nolint:gosec // G115 + } + headerBuffer, err := header.Marshal() + if err != nil { + return []byte{}, err + } + length += len(headerBuffer) + + rawPacket := make([]byte, length) + buffer := packetBuffer{bytes: rawPacket} + + err = buffer.write(headerBuffer) + if err != nil { + return []byte{}, err + } + err = buffer.write(x) + if err != nil { + return []byte{}, err + } + + return rawPacket, nil +} + +// Unmarshal decodes the ExtendedReport from binary. +// +//nolint:cyclop +func (x *ExtendedReport) Unmarshal(b []byte) error { + var header Header + if err := header.Unmarshal(b); err != nil { + return err + } + if header.Type != TypeExtendedReport { + return errWrongType + } + + buffer := packetBuffer{bytes: b[headerLength:]} + err := buffer.read(&x.SenderSSRC) + if err != nil { + return err + } + + for len(buffer.bytes) > 0 { + var block ReportBlock + + headerBuffer := buffer + xrHeader := XRHeader{} + err = headerBuffer.read(&xrHeader) + if err != nil { + return err + } + + switch xrHeader.BlockType { + case LossRLEReportBlockType: + block = new(LossRLEReportBlock) + case DuplicateRLEReportBlockType: + block = new(DuplicateRLEReportBlock) + case PacketReceiptTimesReportBlockType: + block = new(PacketReceiptTimesReportBlock) + case ReceiverReferenceTimeReportBlockType: + block = new(ReceiverReferenceTimeReportBlock) + case DLRRReportBlockType: + block = new(DLRRReportBlock) + case StatisticsSummaryReportBlockType: + block = new(StatisticsSummaryReportBlock) + case VoIPMetricsReportBlockType: + block = new(VoIPMetricsReportBlock) + default: + block = new(UnknownReportBlock) + } + + // We need to limit the amount of data available to + // this block to the actual length of the block + blockLength := (int(xrHeader.BlockLength) + 1) * 4 + blockBuffer := buffer.split(blockLength) + err = blockBuffer.read(block) + if err != nil { + return err + } + block.unpackBlockHeader() + x.Reports = append(x.Reports, block) + } + + return nil +} + +// DestinationSSRC returns an array of SSRC values that this packet refers to. +func (x *ExtendedReport) DestinationSSRC() []uint32 { + ssrc := make([]uint32, 0, len(x.Reports)+1) + ssrc = append(ssrc, x.SenderSSRC) + for _, p := range x.Reports { + ssrc = append(ssrc, p.DestinationSSRC()...) + } + + return ssrc +} + +func (x *ExtendedReport) String() string { + return stringify(x) +} diff --git a/vendor/github.com/pion/rtcp/full_intra_request.go b/vendor/github.com/pion/rtcp/full_intra_request.go new file mode 100644 index 0000000..0d93c5c --- /dev/null +++ b/vendor/github.com/pion/rtcp/full_intra_request.go @@ -0,0 +1,119 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package rtcp + +import ( + "encoding/binary" + "fmt" +) + +// A FIREntry is a (SSRC, seqno) pair, as carried by FullIntraRequest. +type FIREntry struct { + SSRC uint32 + SequenceNumber uint8 +} + +// The FullIntraRequest packet is used to reliably request an Intra frame +// in a video stream. See RFC 5104 Section 3.5.1. This is not for loss +// recovery, which should use PictureLossIndication (PLI) instead. +type FullIntraRequest struct { + SenderSSRC uint32 + MediaSSRC uint32 + + FIR []FIREntry +} + +const ( + firOffset = 8 +) + +var _ Packet = (*FullIntraRequest)(nil) + +// Marshal encodes the FullIntraRequest. +func (p FullIntraRequest) Marshal() ([]byte, error) { + rawPacket := make([]byte, firOffset+(len(p.FIR)*8)) + binary.BigEndian.PutUint32(rawPacket, p.SenderSSRC) + binary.BigEndian.PutUint32(rawPacket[4:], p.MediaSSRC) + for i, fir := range p.FIR { + binary.BigEndian.PutUint32(rawPacket[firOffset+8*i:], fir.SSRC) + rawPacket[firOffset+8*i+4] = fir.SequenceNumber + } + h := p.Header() + hData, err := h.Marshal() + if err != nil { + return nil, err + } + + return append(hData, rawPacket...), nil +} + +// Unmarshal decodes the TransportLayerNack. +func (p *FullIntraRequest) Unmarshal(rawPacket []byte) error { + if len(rawPacket) < (headerLength + ssrcLength) { + return errPacketTooShort + } + + var header Header + if err := header.Unmarshal(rawPacket); err != nil { + return err + } + + if len(rawPacket) < (headerLength + int(4*header.Length)) { + return errPacketTooShort + } + + if header.Type != TypePayloadSpecificFeedback || header.Count != FormatFIR { + return errWrongType + } + + // The FCI field MUST contain one or more FIR entries + if 4*header.Length-firOffset <= 0 || (4*header.Length)%8 != 0 { + return errBadLength + } + + p.SenderSSRC = binary.BigEndian.Uint32(rawPacket[headerLength:]) + p.MediaSSRC = binary.BigEndian.Uint32(rawPacket[headerLength+ssrcLength:]) + for i := headerLength + firOffset; i < (headerLength + int(header.Length*4)); i += 8 { + p.FIR = append(p.FIR, FIREntry{ + binary.BigEndian.Uint32(rawPacket[i:]), + rawPacket[i+4], + }) + } + + return nil +} + +// Header returns the Header associated with this packet. +func (p *FullIntraRequest) Header() Header { + return Header{ + Count: FormatFIR, + Type: TypePayloadSpecificFeedback, + Length: uint16((p.MarshalSize() / 4) - 1), //nolint:gosec // G115 + } +} + +// MarshalSize returns the size of the packet once marshaled. +func (p *FullIntraRequest) MarshalSize() int { + return headerLength + firOffset + len(p.FIR)*8 +} + +func (p *FullIntraRequest) String() string { + out := fmt.Sprintf("FullIntraRequest %x %x", + p.SenderSSRC, p.MediaSSRC) + for _, e := range p.FIR { + out += fmt.Sprintf(" (%x %v)", e.SSRC, e.SequenceNumber) + } + + return out +} + +// DestinationSSRC returns an array of SSRC values that this packet refers to. +func (p *FullIntraRequest) DestinationSSRC() []uint32 { + ssrcs := make([]uint32, 0, len(p.FIR)) + for _, entry := range p.FIR { + ssrcs = append(ssrcs, entry.SSRC) + } + + return ssrcs +} diff --git a/vendor/github.com/pion/rtcp/goodbye.go b/vendor/github.com/pion/rtcp/goodbye.go new file mode 100644 index 0000000..c647743 --- /dev/null +++ b/vendor/github.com/pion/rtcp/goodbye.go @@ -0,0 +1,164 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package rtcp + +import ( + "encoding/binary" + "fmt" +) + +// The Goodbye packet indicates that one or more sources are no longer active. +type Goodbye struct { + // The SSRC/CSRC identifiers that are no longer active + Sources []uint32 + // Optional text indicating the reason for leaving, e.g., "camera malfunction" or "RTP loop detected" + Reason string +} + +// Marshal encodes the Goodbye packet in binary. +func (g Goodbye) Marshal() ([]byte, error) { + /* + * 0 1 2 3 + * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * |V=2|P| SC | PT=BYE=203 | length | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | SSRC/CSRC | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * : ... : + * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + * (opt) | length | reason for leaving ... + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + */ + + rawPacket := make([]byte, g.MarshalSize()) + packetBody := rawPacket[headerLength:] + + if len(g.Sources) > countMax { + return nil, errTooManySources + } + + for i, s := range g.Sources { + binary.BigEndian.PutUint32(packetBody[i*ssrcLength:], s) + } + + if g.Reason != "" { + reason := []byte(g.Reason) + + if len(reason) > sdesMaxOctetCount { + return nil, errReasonTooLong + } + + reasonOffset := len(g.Sources) * ssrcLength + packetBody[reasonOffset] = uint8(len(reason)) //nolint:gosec // G115 + copy(packetBody[reasonOffset+1:], reason) + } + + hData, err := g.Header().Marshal() + if err != nil { + return nil, err + } + copy(rawPacket, hData) + + return rawPacket, nil +} + +// Unmarshal decodes the Goodbye packet from binary. +func (g *Goodbye) Unmarshal(rawPacket []byte) error { + /* + * 0 1 2 3 + * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * |V=2|P| SC | PT=BYE=203 | length | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | SSRC/CSRC | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * : ... : + * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + * (opt) | length | reason for leaving ... + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + */ + + var header Header + if err := header.Unmarshal(rawPacket); err != nil { + return err + } + + if header.Type != TypeGoodbye { + return errWrongType + } + + if getPadding(len(rawPacket)) != 0 { + return errPacketTooShort + } + + g.Sources = make([]uint32, header.Count) + + reasonOffset := int(headerLength + header.Count*ssrcLength) + if reasonOffset > len(rawPacket) { + return errPacketTooShort + } + + for i := 0; i < int(header.Count); i++ { + offset := headerLength + i*ssrcLength + + g.Sources[i] = binary.BigEndian.Uint32(rawPacket[offset:]) + } + + if reasonOffset < len(rawPacket) { + reasonLen := int(rawPacket[reasonOffset]) + reasonEnd := reasonOffset + 1 + reasonLen + + if reasonEnd > len(rawPacket) { + return errPacketTooShort + } + + g.Reason = string(rawPacket[reasonOffset+1 : reasonEnd]) + } + + return nil +} + +// Header returns the Header associated with this packet. +func (g *Goodbye) Header() Header { + return Header{ + Padding: false, + Count: uint8(len(g.Sources)), //nolint:gosec //G115 + Type: TypeGoodbye, + Length: uint16((g.MarshalSize() / 4) - 1), //nolint:gosec //G115 + } +} + +// MarshalSize returns the size of the packet once marshaled. +func (g *Goodbye) MarshalSize() int { + srcsLength := len(g.Sources) * ssrcLength + // reason is optional + reasonLength := len(g.Reason) + if reasonLength > 0 { + reasonLength++ + } + + l := headerLength + srcsLength + reasonLength + + // align to 32-bit boundary + return l + getPadding(l) +} + +// DestinationSSRC returns an array of SSRC values that this packet refers to. +func (g *Goodbye) DestinationSSRC() []uint32 { + out := make([]uint32, len(g.Sources)) + copy(out, g.Sources) + + return out +} + +func (g Goodbye) String() string { + out := "Goodbye\n" + for i, s := range g.Sources { + out += fmt.Sprintf("\tSource %d: %x\n", i, s) + } + out += fmt.Sprintf("\tReason: %s\n", g.Reason) + + return out +} diff --git a/vendor/github.com/pion/rtcp/header.go b/vendor/github.com/pion/rtcp/header.go new file mode 100644 index 0000000..f392e14 --- /dev/null +++ b/vendor/github.com/pion/rtcp/header.go @@ -0,0 +1,150 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package rtcp + +import ( + "encoding/binary" +) + +// PacketType specifies the type of an RTCP packet. +type PacketType uint8 + +// RTCP packet types registered with IANA. See: +// +// https://www.iana.org/assignments/rtp-parameters/rtp-parameters.xhtml#rtp-parameters-4 +const ( + TypeSenderReport PacketType = 200 // RFC 3550, 6.4.1 + TypeReceiverReport PacketType = 201 // RFC 3550, 6.4.2 + TypeSourceDescription PacketType = 202 // RFC 3550, 6.5 + TypeGoodbye PacketType = 203 // RFC 3550, 6.6 + TypeApplicationDefined PacketType = 204 // RFC 3550, 6.7 (unimplemented) + TypeTransportSpecificFeedback PacketType = 205 // RFC 4585, 6051 + TypePayloadSpecificFeedback PacketType = 206 // RFC 4585, 6.3 + TypeExtendedReport PacketType = 207 // RFC 3611 + +) + +// Transport and Payload specific feedback messages overload the count field to act as a message type. +// those are listed here. +const ( + FormatSLI uint8 = 2 + FormatPLI uint8 = 1 + FormatFIR uint8 = 4 + FormatTLN uint8 = 1 + FormatRRR uint8 = 5 + FormatCCFB uint8 = 11 + FormatREMB uint8 = 15 + + // https://tools.ietf.org/html/draft-holmer-rmcat-transport-wide-cc-extensions-01#page-5 + FormatTCC uint8 = 15 +) + +func (p PacketType) String() string { + switch p { + case TypeSenderReport: + return "SR" + case TypeReceiverReport: + return "RR" + case TypeSourceDescription: + return "SDES" + case TypeGoodbye: + return "BYE" + case TypeApplicationDefined: + return "APP" + case TypeTransportSpecificFeedback: + return "TSFB" + case TypePayloadSpecificFeedback: + return "PSFB" + case TypeExtendedReport: + return "XR" + default: + return string(p) + } +} + +const rtpVersion = 2 + +// A Header is the common header shared by all RTCP packets. +type Header struct { + // If the padding bit is set, this individual RTCP packet contains + // some additional padding octets at the end which are not part of + // the control information but are included in the length field. + Padding bool + // The number of reception reports, sources contained or FMT in this packet (depending on the Type) + Count uint8 + // The RTCP packet type for this packet + Type PacketType + // The length of this RTCP packet in 32-bit words minus one, + // including the header and any padding. + Length uint16 +} + +const ( + headerLength = 4 + versionShift = 6 + versionMask = 0x3 + paddingShift = 5 + paddingMask = 0x1 + countShift = 0 + countMask = 0x1f + countMax = (1 << 5) - 1 +) + +// Marshal encodes the Header in binary. +func (h Header) Marshal() ([]byte, error) { + /* + * 0 1 2 3 + * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * |V=2|P| RC | PT=SR=200 | length | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + */ + rawPacket := make([]byte, headerLength) + + rawPacket[0] |= rtpVersion << versionShift + + if h.Padding { + rawPacket[0] |= 1 << paddingShift + } + + if h.Count > 31 { + return nil, errInvalidHeader + } + rawPacket[0] |= h.Count << countShift + + rawPacket[1] = uint8(h.Type) + + binary.BigEndian.PutUint16(rawPacket[2:], h.Length) + + return rawPacket, nil +} + +// Unmarshal decodes the Header from binary. +func (h *Header) Unmarshal(rawPacket []byte) error { + if len(rawPacket) < headerLength { + return errPacketTooShort + } + + /* + * 0 1 2 3 + * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * |V=2|P| RC | PT | length | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + */ + + version := rawPacket[0] >> versionShift & versionMask + if version != rtpVersion { + return errBadVersion + } + + h.Padding = (rawPacket[0] >> paddingShift & paddingMask) > 0 + h.Count = rawPacket[0] >> countShift & countMask + + h.Type = PacketType(rawPacket[1]) + + h.Length = binary.BigEndian.Uint16(rawPacket[2:]) + + return nil +} diff --git a/vendor/github.com/pion/rtcp/packet.go b/vendor/github.com/pion/rtcp/packet.go new file mode 100644 index 0000000..637d9cd --- /dev/null +++ b/vendor/github.com/pion/rtcp/packet.go @@ -0,0 +1,131 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package rtcp + +// Packet represents an RTCP packet, a protocol used for out-of-band statistics +// and control information for an RTP session. +type Packet interface { + // DestinationSSRC returns an array of SSRC values that this packet refers to. + DestinationSSRC() []uint32 + + Marshal() ([]byte, error) + Unmarshal(rawPacket []byte) error + MarshalSize() int +} + +// Unmarshal takes an entire udp datagram (which may consist of multiple RTCP packets) and +// returns the unmarshaled packets it contains. +// +// If this is a reduced-size RTCP packet a feedback packet (Goodbye, SliceLossIndication, etc) +// will be returned. Otherwise, the underlying type of the returned packet will be +// CompoundPacket. +func Unmarshal(rawData []byte) ([]Packet, error) { + var packets []Packet + for len(rawData) != 0 { + p, processed, err := unmarshal(rawData) + if err != nil { + return nil, err + } + + packets = append(packets, p) + rawData = rawData[processed:] + } + + switch len(packets) { + // Empty packet + case 0: + return nil, errInvalidHeader + // Multiple Packets + default: + return packets, nil + } +} + +// Marshal takes an array of Packets and serializes them to a single buffer. +func Marshal(packets []Packet) ([]byte, error) { + out := make([]byte, 0) + for _, p := range packets { + data, err := p.Marshal() + if err != nil { + return nil, err + } + out = append(out, data...) + } + + return out, nil +} + +// unmarshal is a factory which pulls the first RTCP packet from a bytestream, +// and returns it's parsed representation, and the amount of data that was processed. +// +//nolint:cyclop +func unmarshal(rawData []byte) (packet Packet, bytesprocessed int, err error) { + var header Header + + err = header.Unmarshal(rawData) + if err != nil { + return nil, 0, err + } + + bytesprocessed = int(header.Length+1) * 4 + if bytesprocessed > len(rawData) { + return nil, 0, errPacketTooShort + } + inPacket := rawData[:bytesprocessed] + + switch header.Type { + case TypeSenderReport: + packet = new(SenderReport) + + case TypeReceiverReport: + packet = new(ReceiverReport) + + case TypeSourceDescription: + packet = new(SourceDescription) + + case TypeGoodbye: + packet = new(Goodbye) + + case TypeTransportSpecificFeedback: + switch header.Count { + case FormatTLN: + packet = new(TransportLayerNack) + case FormatRRR: + packet = new(RapidResynchronizationRequest) + case FormatTCC: + packet = new(TransportLayerCC) + case FormatCCFB: + packet = new(CCFeedbackReport) + default: + packet = new(RawPacket) + } + + case TypePayloadSpecificFeedback: + switch header.Count { + case FormatPLI: + packet = new(PictureLossIndication) + case FormatSLI: + packet = new(SliceLossIndication) + case FormatREMB: + packet = new(ReceiverEstimatedMaximumBitrate) + case FormatFIR: + packet = new(FullIntraRequest) + default: + packet = new(RawPacket) + } + + case TypeExtendedReport: + packet = new(ExtendedReport) + + case TypeApplicationDefined: + packet = new(ApplicationDefined) + + default: + packet = new(RawPacket) + } + + err = packet.Unmarshal(inPacket) + + return packet, bytesprocessed, err +} diff --git a/vendor/github.com/pion/rtcp/packet_buffer.go b/vendor/github.com/pion/rtcp/packet_buffer.go new file mode 100644 index 0000000..5efaad1 --- /dev/null +++ b/vendor/github.com/pion/rtcp/packet_buffer.go @@ -0,0 +1,270 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package rtcp + +import ( + "encoding/binary" + "reflect" + "unsafe" +) + +// These functions implement an introspective structure +// serializer/deserializer, designed to allow RTCP packet +// Structs to be self-describing. They currently work with +// fields of type uint8, uint16, uint32, and uint64 (and +// types derived from them). +// +// - Unexported fields will take up space in the encoded +// array, but wil be set to zero when written, and ignore +// when read. +// +// - Fields that are marked with the tag `encoding:"omit"` +// will be ignored when reading and writing data. +// +// For example: +// +// type Example struct { +// A uint32 +// B bool `encoding:"omit"` +// _ uint64 +// C uint16 +// } +// +// "A" will be encoded as four bytes, in network order. "B" +// will not be encoded at all. The anonymous uint64 will +// encode as 8 bytes of value "0", followed by two bytes +// encoding "C" in network order. + +type packetBuffer struct { + bytes []byte +} + +const omit = "omit" + +// Writes the structure passed to into the buffer that +// PacketBuffer is initialized with. This function will +// modify the PacketBuffer.bytes slice to exclude those +// bytes that have been written into. +// +//nolint:gocognit,cyclop +func (b *packetBuffer) write(v any) error { + value := reflect.ValueOf(v) + + // Indirect is safe to call on non-pointers, and + // will simply return the same value in such cases + value = reflect.Indirect(value) + + switch value.Kind() { + case reflect.Uint8: + if len(b.bytes) < 1 { + return errWrongMarshalSize + } + if value.CanInterface() { + b.bytes[0] = byte(value.Uint()) + } + b.bytes = b.bytes[1:] + case reflect.Uint16: + if len(b.bytes) < 2 { + return errWrongMarshalSize + } + if value.CanInterface() { + binary.BigEndian.PutUint16(b.bytes, uint16(value.Uint())) //nolint:gosec // G115 + } + b.bytes = b.bytes[2:] + case reflect.Uint32: + if len(b.bytes) < 4 { + return errWrongMarshalSize + } + if value.CanInterface() { + binary.BigEndian.PutUint32(b.bytes, uint32(value.Uint())) //nolint:gosec // G115 + } + b.bytes = b.bytes[4:] + case reflect.Uint64: + if len(b.bytes) < 8 { + return errWrongMarshalSize + } + if value.CanInterface() { + binary.BigEndian.PutUint64(b.bytes, value.Uint()) + } + b.bytes = b.bytes[8:] + case reflect.Slice: + for i := 0; i < value.Len(); i++ { + if value.Index(i).CanInterface() { + if err := b.write(value.Index(i).Interface()); err != nil { + return err + } + } else { + b.bytes = b.bytes[value.Index(i).Type().Size():] + } + } + case reflect.Struct: + for i := 0; i < value.NumField(); i++ { + encoding := value.Type().Field(i).Tag.Get("encoding") + if encoding == omit { + continue + } + if value.Field(i).CanInterface() { + if err := b.write(value.Field(i).Interface()); err != nil { + return err + } + } else { + advance := int(value.Field(i).Type().Size()) + if len(b.bytes) < advance { + return errWrongMarshalSize + } + b.bytes = b.bytes[advance:] + } + } + default: + return errBadStructMemberType + } + + return nil +} + +// Reads bytes from the buffer as necessary to populate +// the structure passed as a parameter. This function will +// modify the PacketBuffer.bytes slice to exclude those +// bytes that have already been read. +// +//nolint:gocognit,cyclop +func (b *packetBuffer) read(v any) error { + ptr := reflect.ValueOf(v) + if ptr.Kind() != reflect.Ptr { + return errBadReadParameter + } + value := reflect.Indirect(ptr) + + // If this is an interface, we need to make it concrete before using it + if value.Kind() == reflect.Interface { + value = reflect.ValueOf(value.Interface()) + } + value = reflect.Indirect(value) + + switch value.Kind() { + case reflect.Uint8: + if len(b.bytes) < 1 { + return errWrongMarshalSize + } + value.SetUint(uint64(b.bytes[0])) + b.bytes = b.bytes[1:] + + case reflect.Uint16: + if len(b.bytes) < 2 { + return errWrongMarshalSize + } + value.SetUint(uint64(binary.BigEndian.Uint16(b.bytes))) + b.bytes = b.bytes[2:] + + case reflect.Uint32: + if len(b.bytes) < 4 { + return errWrongMarshalSize + } + value.SetUint(uint64(binary.BigEndian.Uint32(b.bytes))) + b.bytes = b.bytes[4:] + + case reflect.Uint64: + if len(b.bytes) < 8 { + return errWrongMarshalSize + } + value.SetUint(binary.BigEndian.Uint64(b.bytes)) + b.bytes = b.bytes[8:] + + case reflect.Slice: + // If we encounter a slice, we consume the rest of the data + // in the buffer and load it into the slice. + for len(b.bytes) > 0 { + newElementPtr := reflect.New(value.Type().Elem()) + if err := b.read(newElementPtr.Interface()); err != nil { + return err + } + if value.CanSet() { + value.Set(reflect.Append(value, reflect.Indirect(newElementPtr))) + } + } + + case reflect.Struct: + for i := 0; i < value.NumField(); i++ { + encoding := value.Type().Field(i).Tag.Get("encoding") + if encoding == omit { + continue + } + if value.Field(i).CanInterface() { + field := value.Field(i) + newFieldPtr := reflect.NewAt( + //nolint:gosec // This is the only way to get a typed pointer to a structure's field + field.Type(), unsafe.Pointer(field.UnsafeAddr()), + ) + if err := b.read(newFieldPtr.Interface()); err != nil { + return err + } + } else { + advance := int(value.Field(i).Type().Size()) + if len(b.bytes) < advance { + return errWrongMarshalSize + } + b.bytes = b.bytes[advance:] + } + } + + default: + return errBadStructMemberType + } + + return nil +} + +// Consumes `size` bytes and returns them as an +// independent PacketBuffer. +func (b *packetBuffer) split(size int) packetBuffer { + if size > len(b.bytes) { + size = len(b.bytes) + } + newBuffer := packetBuffer{bytes: b.bytes[:size]} + + b.bytes = b.bytes[size:] + + return newBuffer +} + +// Returns the size that a structure will encode into. +// This fuction doesn't check that Write() will succeed, +// and may return unexpectedly large results for those +// structures that Write() will fail on. +func wireSize(v any) int { + value := reflect.ValueOf(v) + // Indirect is safe to call on non-pointers, and + // will simply return the same value in such cases + value = reflect.Indirect(value) + size := int(0) + + switch value.Kind() { + case reflect.Slice: + for i := 0; i < value.Len(); i++ { + if value.Index(i).CanInterface() { + size += wireSize(value.Index(i).Interface()) + } else { + size += int(value.Index(i).Type().Size()) + } + } + + case reflect.Struct: + for i := 0; i < value.NumField(); i++ { + encoding := value.Type().Field(i).Tag.Get("encoding") + if encoding == omit { + continue + } + if value.Field(i).CanInterface() { + size += wireSize(value.Field(i).Interface()) + } else { + size += int(value.Field(i).Type().Size()) + } + } + + default: + size = int(value.Type().Size()) + } + + return size +} diff --git a/vendor/github.com/pion/rtcp/packet_stringifier.go b/vendor/github.com/pion/rtcp/packet_stringifier.go new file mode 100644 index 0000000..8536f2f --- /dev/null +++ b/vendor/github.com/pion/rtcp/packet_stringifier.go @@ -0,0 +1,112 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package rtcp + +import ( + "fmt" + "reflect" +) + +/* +Converts an RTCP Packet into a human-readable format. The Packets +themselves can control the presentation as follows: + + - Fields of a type that have a String() method will be formatted + with that String method (which should not emit '\n' characters) + + - Otherwise, fields with a tag containing a "fmt" string will use that + format when serializing the value. For example, to format an SSRC + value as base 16 insted of base 10: + + type ExamplePacket struct { + LocalSSRC uint32 `fmt:"0x%X"` + RemotsSSRCs []uint32 `fmt:"%X"` + } + +- If no fmt string is present, "%+v" is used by default + +The intention of this stringify() function is to simplify creation +of String() methods on new packet types, as it provides a simple +baseline implementation that works well in the majority of cases. +*/ +func stringify(p Packet) string { + value := reflect.Indirect(reflect.ValueOf(p)) + + return formatField(value.Type().String(), "", p, "") +} + +//nolint:gocognit,cyclop +func formatField(name string, format string, f any, indent string) string { + out := indent + value := reflect.ValueOf(f) + + if !value.IsValid() { + return fmt.Sprintf("%s%s: \n", out, name) + } + + isPacket := reflect.TypeOf(f).Implements(reflect.TypeOf((*Packet)(nil)).Elem()) + + // Resolve pointers to their underlying values + if value.Type().Kind() == reflect.Ptr && !value.IsNil() { + underlying := reflect.Indirect(value) + if underlying.IsValid() { + value = underlying + } + } + + // If the field type has a custom String method, use that + // (unless we're a packet, since we want to avoid recursing + // back into this function if the Packet's String() method + // uses it) + if stringMethod := value.MethodByName("String"); !isPacket && stringMethod.IsValid() { + out += fmt.Sprintf("%s: %s\n", name, stringMethod.Call([]reflect.Value{})) + + return out + } + + switch value.Kind() { + case reflect.Struct: + out += fmt.Sprintf("%s:\n", name) + for i := 0; i < value.NumField(); i++ { + if value.Field(i).CanInterface() { + format = value.Type().Field(i).Tag.Get("fmt") + if format == "" { + format = "%+v" + } + out += formatField(value.Type().Field(i).Name, format, value.Field(i).Interface(), indent+"\t") + } + } + case reflect.Slice: + childKind := value.Type().Elem().Kind() + _, hasStringMethod := value.Type().Elem().MethodByName("String") + if hasStringMethod || childKind == reflect.Struct || childKind == reflect.Ptr || + childKind == reflect.Interface || childKind == reflect.Slice { + out += fmt.Sprintf("%s:\n", name) + for i := 0; i < value.Len(); i++ { + childName := fmt.Sprint(i) + // Since interfaces can hold different types of things, we add the + // most specific type name to the name to make it clear what the + // subsequent fields represent. + if value.Index(i).Kind() == reflect.Interface { + childName += fmt.Sprintf(" (%s)", reflect.Indirect(reflect.ValueOf(value.Index(i).Interface())).Type()) + } + if value.Index(i).CanInterface() { + out += formatField(childName, format, value.Index(i).Interface(), indent+"\t") + } + } + + return out + } + + // If we didn't take care of stringing the value already, we fall through to the + // generic case. This will print slices of basic types on a single line. + fallthrough + default: + if value.CanInterface() { + out += fmt.Sprintf("%s: "+format+"\n", name, value.Interface()) + } + } + + return out +} diff --git a/vendor/github.com/pion/rtcp/picture_loss_indication.go b/vendor/github.com/pion/rtcp/picture_loss_indication.go new file mode 100644 index 0000000..17379cd --- /dev/null +++ b/vendor/github.com/pion/rtcp/picture_loss_indication.go @@ -0,0 +1,95 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package rtcp + +import ( + "encoding/binary" + "fmt" +) + +// The PictureLossIndication packet informs the encoder about the loss of an undefined amount of +// coded video data belonging to one or more pictures. +type PictureLossIndication struct { + // SSRC of sender + SenderSSRC uint32 + + // SSRC where the loss was experienced + MediaSSRC uint32 +} + +const ( + pliLength = 2 +) + +// Marshal encodes the PictureLossIndication in binary. +func (p PictureLossIndication) Marshal() ([]byte, error) { + /* + * PLI does not require parameters. Therefore, the length field MUST be + * 2, and there MUST NOT be any Feedback Control Information. + * + * The semantics of this FB message is independent of the payload type. + */ + rawPacket := make([]byte, p.MarshalSize()) + packetBody := rawPacket[headerLength:] + + binary.BigEndian.PutUint32(packetBody, p.SenderSSRC) + binary.BigEndian.PutUint32(packetBody[4:], p.MediaSSRC) + + h := Header{ + Count: FormatPLI, + Type: TypePayloadSpecificFeedback, + Length: pliLength, + } + hData, err := h.Marshal() + if err != nil { + return nil, err + } + copy(rawPacket, hData) + + return rawPacket, nil +} + +// Unmarshal decodes the PictureLossIndication from binary. +func (p *PictureLossIndication) Unmarshal(rawPacket []byte) error { + if len(rawPacket) < (headerLength + (ssrcLength * 2)) { + return errPacketTooShort + } + + var h Header + if err := h.Unmarshal(rawPacket); err != nil { + return err + } + + if h.Type != TypePayloadSpecificFeedback || h.Count != FormatPLI { + return errWrongType + } + + p.SenderSSRC = binary.BigEndian.Uint32(rawPacket[headerLength:]) + p.MediaSSRC = binary.BigEndian.Uint32(rawPacket[headerLength+ssrcLength:]) + + return nil +} + +// Header returns the Header associated with this packet. +func (p *PictureLossIndication) Header() Header { + return Header{ + Count: FormatPLI, + Type: TypePayloadSpecificFeedback, + Length: pliLength, + } +} + +// MarshalSize returns the size of the packet once marshaled. +func (p *PictureLossIndication) MarshalSize() int { + return headerLength + ssrcLength*2 +} + +func (p *PictureLossIndication) String() string { + return fmt.Sprintf("PictureLossIndication %x %x", p.SenderSSRC, p.MediaSSRC) +} + +// DestinationSSRC returns an array of SSRC values that this packet refers to. +func (p *PictureLossIndication) DestinationSSRC() []uint32 { + return []uint32{p.MediaSSRC} +} diff --git a/vendor/github.com/pion/rtcp/rapid_resynchronization_request.go b/vendor/github.com/pion/rtcp/rapid_resynchronization_request.go new file mode 100644 index 0000000..d422033 --- /dev/null +++ b/vendor/github.com/pion/rtcp/rapid_resynchronization_request.go @@ -0,0 +1,96 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package rtcp + +import ( + "encoding/binary" + "fmt" +) + +// The RapidResynchronizationRequest packet informs the encoder about the loss of +// an undefined amount of coded video data belonging to one or more pictures. +type RapidResynchronizationRequest struct { + // SSRC of sender + SenderSSRC uint32 + + // SSRC of the media source + MediaSSRC uint32 +} + +// RapidResynchronisationRequest is provided as RFC 6051 spells resynchronization with an s. +// We provide both names to be consistent with other RFCs which spell resynchronization with a z. +type RapidResynchronisationRequest = RapidResynchronizationRequest + +const ( + rrrLength = 2 + rrrHeaderLength = ssrcLength * 2 + rrrMediaOffset = 4 +) + +// Marshal encodes the RapidResynchronizationRequest in binary. +func (p RapidResynchronizationRequest) Marshal() ([]byte, error) { + /* + * RRR does not require parameters. Therefore, the length field MUST be + * 2, and there MUST NOT be any Feedback Control Information. + * + * The semantics of this FB message is independent of the payload type. + */ + rawPacket := make([]byte, p.MarshalSize()) + packetBody := rawPacket[headerLength:] + + binary.BigEndian.PutUint32(packetBody, p.SenderSSRC) + binary.BigEndian.PutUint32(packetBody[rrrMediaOffset:], p.MediaSSRC) + + hData, err := p.Header().Marshal() + if err != nil { + return nil, err + } + copy(rawPacket, hData) + + return rawPacket, nil +} + +// Unmarshal decodes the RapidResynchronizationRequest from binary. +func (p *RapidResynchronizationRequest) Unmarshal(rawPacket []byte) error { + if len(rawPacket) < (headerLength + (ssrcLength * 2)) { + return errPacketTooShort + } + + var h Header + if err := h.Unmarshal(rawPacket); err != nil { + return err + } + + if h.Type != TypeTransportSpecificFeedback || h.Count != FormatRRR { + return errWrongType + } + + p.SenderSSRC = binary.BigEndian.Uint32(rawPacket[headerLength:]) + p.MediaSSRC = binary.BigEndian.Uint32(rawPacket[headerLength+ssrcLength:]) + + return nil +} + +// MarshalSize returns the size of the packet once marshaled. +func (p *RapidResynchronizationRequest) MarshalSize() int { + return headerLength + rrrHeaderLength +} + +// Header returns the Header associated with this packet. +func (p *RapidResynchronizationRequest) Header() Header { + return Header{ + Count: FormatRRR, + Type: TypeTransportSpecificFeedback, + Length: rrrLength, + } +} + +// DestinationSSRC returns an array of SSRC values that this packet refers to. +func (p *RapidResynchronizationRequest) DestinationSSRC() []uint32 { + return []uint32{p.MediaSSRC} +} + +func (p *RapidResynchronizationRequest) String() string { + return fmt.Sprintf("RapidResynchronizationRequest %x %x", p.SenderSSRC, p.MediaSSRC) +} diff --git a/vendor/github.com/pion/rtcp/raw_packet.go b/vendor/github.com/pion/rtcp/raw_packet.go new file mode 100644 index 0000000..71ac152 --- /dev/null +++ b/vendor/github.com/pion/rtcp/raw_packet.go @@ -0,0 +1,53 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package rtcp + +import "fmt" + +// RawPacket represents an unparsed RTCP packet. It's returned by Unmarshal when +// a packet with an unknown type is encountered. +type RawPacket []byte + +// Marshal encodes the packet in binary. +func (r RawPacket) Marshal() ([]byte, error) { + return r, nil +} + +// Unmarshal decodes the packet from binary. +func (r *RawPacket) Unmarshal(b []byte) error { + if len(b) < (headerLength) { + return errPacketTooShort + } + *r = b + + var h Header + + return h.Unmarshal(b) +} + +// Header returns the Header associated with this packet. +func (r RawPacket) Header() Header { + var h Header + if err := h.Unmarshal(r); err != nil { + return Header{} + } + + return h +} + +// DestinationSSRC returns an array of SSRC values that this packet refers to. +func (r *RawPacket) DestinationSSRC() []uint32 { + return []uint32{} +} + +func (r RawPacket) String() string { + out := fmt.Sprintf("RawPacket: %v", ([]byte)(r)) + + return out +} + +// MarshalSize returns the size of the packet once marshaled. +func (r RawPacket) MarshalSize() int { + return len(r) +} diff --git a/vendor/github.com/pion/rtcp/receiver_estimated_maximum_bitrate.go b/vendor/github.com/pion/rtcp/receiver_estimated_maximum_bitrate.go new file mode 100644 index 0000000..cb6cdae --- /dev/null +++ b/vendor/github.com/pion/rtcp/receiver_estimated_maximum_bitrate.go @@ -0,0 +1,287 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package rtcp + +import ( + "bytes" + "encoding/binary" + "fmt" + "math" +) + +// ReceiverEstimatedMaximumBitrate contains the receiver's estimated maximum bitrate. +// see: https://tools.ietf.org/html/draft-alvestrand-rmcat-remb-03 +type ReceiverEstimatedMaximumBitrate struct { + // SSRC of sender + SenderSSRC uint32 + + // Estimated maximum bitrate + Bitrate float32 + + // SSRC entries which this packet applies to + SSRCs []uint32 +} + +// Marshal serializes the packet and returns a byte slice. +func (p ReceiverEstimatedMaximumBitrate) Marshal() (buf []byte, err error) { + // Allocate a buffer of the exact output size. + buf = make([]byte, p.MarshalSize()) + + // Write to our buffer. + n, err := p.MarshalTo(buf) + if err != nil { + return nil, err + } + + // This will always be true but just to be safe. + if n != len(buf) { + return nil, errWrongMarshalSize + } + + return buf, nil +} + +// MarshalSize returns the size of the packet once marshaled. +func (p ReceiverEstimatedMaximumBitrate) MarshalSize() int { + return 20 + 4*len(p.SSRCs) +} + +// MarshalTo serializes the packet to the given byte slice. +func (p ReceiverEstimatedMaximumBitrate) MarshalTo(buf []byte) (n int, err error) { + const bitratemax = 0x3FFFFp+63 + /* + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + |V=2|P| FMT=15 | PT=206 | length | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | SSRC of packet sender | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | SSRC of media source | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Unique identifier 'R' 'E' 'M' 'B' | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Num SSRC | BR Exp | BR Mantissa | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | SSRC feedback | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | ... | + */ + + size := p.MarshalSize() + if len(buf) < size { + return 0, errPacketTooShort + } + + buf[0] = 143 // v=2, p=0, fmt=15 + buf[1] = 206 + + // Length of this packet in 32-bit words minus one. + length := uint16((p.MarshalSize() / 4) - 1) //nolint:gosec // G115 + binary.BigEndian.PutUint16(buf[2:4], length) + + binary.BigEndian.PutUint32(buf[4:8], p.SenderSSRC) + binary.BigEndian.PutUint32(buf[8:12], 0) // always zero + + // ALL HAIL REMB + buf[12] = 'R' + buf[13] = 'E' + buf[14] = 'M' + buf[15] = 'B' + + // Write the length of the ssrcs to follow at the end + buf[16] = byte(len(p.SSRCs)) + + exp := 0 + bitrate := p.Bitrate + + if bitrate >= bitratemax { + bitrate = bitratemax + } + + if bitrate < 0 { + return 0, errInvalidBitrate + } + + for bitrate >= (1 << 18) { + bitrate /= 2.0 + exp++ + } + + if exp >= (1 << 6) { + return 0, errInvalidBitrate + } + + mantissa := uint(math.Floor(float64(bitrate))) + + // We can't quite use the binary package because + // a) it's a uint24 and b) the exponent is only 6-bits + // Just trust me; this is big-endian encoding. + buf[17] = byte(exp<<2) | byte(mantissa>>16) + buf[18] = byte(mantissa >> 8) + buf[19] = byte(mantissa) + + // Write the SSRCs at the very end. + n = 20 + for _, ssrc := range p.SSRCs { + binary.BigEndian.PutUint32(buf[n:n+4], ssrc) + n += 4 + } + + return n, nil +} + +// Unmarshal reads a REMB packet from the given byte slice. +// +//nolint:cyclop +func (p *ReceiverEstimatedMaximumBitrate) Unmarshal(buf []byte) (err error) { + const mantissamax = 0x7FFFFF + /* + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + |V=2|P| FMT=15 | PT=206 | length | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | SSRC of packet sender | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | SSRC of media source | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Unique identifier 'R' 'E' 'M' 'B' | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Num SSRC | BR Exp | BR Mantissa | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | SSRC feedback | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | ... | + */ + + // 20 bytes is the size of the packet with no SSRCs + if len(buf) < 20 { + return errPacketTooShort + } + + // version must be 2 + version := buf[0] >> 6 + if version != 2 { + return fmt.Errorf("%w expected(2) actual(%d)", errBadVersion, version) + } + + // padding must be unset + padding := (buf[0] >> 5) & 1 + if padding != 0 { + return fmt.Errorf("%w expected(0) actual(%d)", errWrongPadding, padding) + } + + // fmt must be 15 + fmtVal := buf[0] & 31 + if fmtVal != 15 { + return fmt.Errorf("%w expected(15) actual(%d)", errWrongFeedbackType, fmtVal) + } + + // Must be payload specific feedback + if buf[1] != 206 { + return fmt.Errorf("%w expected(206) actual(%d)", errWrongPayloadType, buf[1]) + } + + // length is the number of 32-bit words, minus 1 + length := binary.BigEndian.Uint16(buf[2:4]) + size := int((length + 1) * 4) + + // There's not way this could be legit + if size < 20 { + return errHeaderTooSmall + } + + // Make sure the buffer is large enough. + if len(buf) < size { + return errPacketTooShort + } + + // The sender SSRC is 32-bits + p.SenderSSRC = binary.BigEndian.Uint32(buf[4:8]) + + // The destination SSRC must be 0 + media := binary.BigEndian.Uint32(buf[8:12]) + if media != 0 { + return errSSRCMustBeZero + } + + // REMB rules all around me + if !bytes.Equal(buf[12:16], []byte{'R', 'E', 'M', 'B'}) { + return errMissingREMBidentifier + } + + // The next byte is the number of SSRC entries at the end. + num := int(buf[16]) + + // Now we know the expected size, make sure they match. + if size != 20+4*num { + return errSSRCNumAndLengthMismatch + } + + // Get the 6-bit exponent value. + exp := buf[17] >> 2 + exp += 127 // bias for IEEE754 + exp += 23 // IEEE754 biases the decimal to the left, abs-send-time biases it to the right + + // The remaining 2-bits plus the next 16-bits are the mantissa. + mantissa := uint32(buf[17]&3)<<16 | uint32(buf[18])<<8 | uint32(buf[19]) + + if mantissa != 0 { + // ieee754 requires an implicit leading bit + for (mantissa & (mantissamax + 1)) == 0 { + exp-- + mantissa *= 2 + } + } + + // bitrate = mantissa * 2^exp + p.Bitrate = math.Float32frombits((uint32(exp) << 23) | (mantissa & mantissamax)) + + // Clear any existing SSRCs + p.SSRCs = nil + + // Loop over and parse the SSRC entires at the end. + // We already verified that size == num * 4 + for n := 20; n < size; n += 4 { + ssrc := binary.BigEndian.Uint32(buf[n : n+4]) + p.SSRCs = append(p.SSRCs, ssrc) + } + + return nil +} + +// Header returns the Header associated with this packet. +func (p *ReceiverEstimatedMaximumBitrate) Header() Header { + return Header{ + Count: FormatREMB, + Type: TypePayloadSpecificFeedback, + Length: uint16((p.MarshalSize() / 4) - 1), //nolint:gosec // G115 + } +} + +// String prints the REMB packet in a human-readable format. +func (p *ReceiverEstimatedMaximumBitrate) String() string { + // Keep a table of powers to units for fast conversion. + bitUnits := []string{"b", "Kb", "Mb", "Gb", "Tb", "Pb", "Eb"} + + // Do some unit conversions because b/s is far too difficult to read. + bitrate := p.Bitrate + powers := 0 + + // Keep dividing the bitrate until it's under 1000 + for bitrate >= 1000.0 && powers < len(bitUnits) { + bitrate /= 1000.0 + powers++ + } + + unit := bitUnits[powers] + + return fmt.Sprintf("ReceiverEstimatedMaximumBitrate %x %.2f %s/s", p.SenderSSRC, bitrate, unit) +} + +// DestinationSSRC returns an array of SSRC values that this packet refers to. +func (p *ReceiverEstimatedMaximumBitrate) DestinationSSRC() []uint32 { + return p.SSRCs +} diff --git a/vendor/github.com/pion/rtcp/receiver_report.go b/vendor/github.com/pion/rtcp/receiver_report.go new file mode 100644 index 0000000..84c682d --- /dev/null +++ b/vendor/github.com/pion/rtcp/receiver_report.go @@ -0,0 +1,199 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package rtcp + +import ( + "encoding/binary" + "fmt" +) + +// A ReceiverReport (RR) packet provides reception quality feedback for an RTP stream. +type ReceiverReport struct { + // The synchronization source identifier for the originator of this RR packet. + SSRC uint32 + // Zero or more reception report blocks depending on the number of other + // sources heard by this sender since the last report. Each reception report + // block conveys statistics on the reception of RTP packets from a + // single synchronization source. + Reports []ReceptionReport + // Extension contains additional, payload-specific information that needs to + // be reported regularly about the receiver. + ProfileExtensions []byte +} + +const ( + ssrcLength = 4 + rrSSRCOffset = headerLength + rrReportOffset = rrSSRCOffset + ssrcLength +) + +// Marshal encodes the ReceiverReport in binary. +func (r ReceiverReport) Marshal() ([]byte, error) { + /* + * 0 1 2 3 + * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * header |V=2|P| RC | PT=RR=201 | length | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | SSRC of packet sender | + * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + * report | SSRC_1 (SSRC of first source) | + * block +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * 1 | fraction lost | cumulative number of packets lost | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | extended highest sequence number received | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | interarrival jitter | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | last SR (LSR) | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | delay since last SR (DLSR) | + * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + * report | SSRC_2 (SSRC of second source) | + * block +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * 2 : ... : + * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + * | profile-specific extensions | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + */ + + rawPacket := make([]byte, r.MarshalSize()) + packetBody := rawPacket[headerLength:] + + binary.BigEndian.PutUint32(packetBody, r.SSRC) + + for i, rp := range r.Reports { + data, err := rp.Marshal() + if err != nil { + return nil, err + } + offset := ssrcLength + receptionReportLength*i + copy(packetBody[offset:], data) + } + + if len(r.Reports) > countMax { + return nil, errTooManyReports + } + + pe := make([]byte, len(r.ProfileExtensions)) + copy(pe, r.ProfileExtensions) + + // if the length of the profile extensions isn't devisible + // by 4, we need to pad the end. + for (len(pe) & 0x3) != 0 { + pe = append(pe, 0) //nolint:makezero + } + + rawPacket = append(rawPacket, pe...) //nolint:makezero + + hData, err := r.Header().Marshal() + if err != nil { + return nil, err + } + copy(rawPacket, hData) + + return rawPacket, nil +} + +// Unmarshal decodes the ReceiverReport from binary. +func (r *ReceiverReport) Unmarshal(rawPacket []byte) error { + /* + * 0 1 2 3 + * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * header |V=2|P| RC | PT=RR=201 | length | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | SSRC of packet sender | + * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + * report | SSRC_1 (SSRC of first source) | + * block +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * 1 | fraction lost | cumulative number of packets lost | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | extended highest sequence number received | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | interarrival jitter | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | last SR (LSR) | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | delay since last SR (DLSR) | + * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + * report | SSRC_2 (SSRC of second source) | + * block +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * 2 : ... : + * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + * | profile-specific extensions | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + */ + + if len(rawPacket) < (headerLength + ssrcLength) { + return errPacketTooShort + } + + var header Header + if err := header.Unmarshal(rawPacket); err != nil { + return err + } + + if header.Type != TypeReceiverReport { + return errWrongType + } + + r.SSRC = binary.BigEndian.Uint32(rawPacket[rrSSRCOffset:]) + + for i := rrReportOffset; i < len(rawPacket) && len(r.Reports) < int(header.Count); i += receptionReportLength { + var rr ReceptionReport + if err := rr.Unmarshal(rawPacket[i:]); err != nil { + return err + } + r.Reports = append(r.Reports, rr) + } + r.ProfileExtensions = rawPacket[rrReportOffset+(len(r.Reports)*receptionReportLength):] + + //nolint:gosec // G115 + if uint8(len(r.Reports)) != header.Count { + return errInvalidHeader + } + + return nil +} + +// MarshalSize returns the size of the packet once marshaled. +func (r *ReceiverReport) MarshalSize() int { + repsLength := 0 + for _, rep := range r.Reports { + repsLength += rep.len() + } + + return headerLength + ssrcLength + repsLength +} + +// Header returns the Header associated with this packet. +func (r *ReceiverReport) Header() Header { + return Header{ + Count: uint8(len(r.Reports)), //nolint:gosec // G115 + Type: TypeReceiverReport, + Length: uint16((r.MarshalSize()/4)-1) + uint16(getPadding(len(r.ProfileExtensions))), //nolint:gosec // G115 + } +} + +// DestinationSSRC returns an array of SSRC values that this packet refers to. +func (r *ReceiverReport) DestinationSSRC() []uint32 { + out := make([]uint32, len(r.Reports)) + for i, v := range r.Reports { + out[i] = v.SSRC + } + + return out +} + +func (r ReceiverReport) String() string { + out := fmt.Sprintf("ReceiverReport from %x\n", r.SSRC) + out += "\tSSRC \tLost\tLastSequence\n" + for _, i := range r.Reports { + out += fmt.Sprintf("\t%x\t%d/%d\t%d\n", i.SSRC, i.FractionLost, i.TotalLost, i.LastSequenceNumber) + } + out += fmt.Sprintf("\tProfile Extension Data: %v\n", r.ProfileExtensions) + + return out +} diff --git a/vendor/github.com/pion/rtcp/reception_report.go b/vendor/github.com/pion/rtcp/reception_report.go new file mode 100644 index 0000000..f2b4548 --- /dev/null +++ b/vendor/github.com/pion/rtcp/reception_report.go @@ -0,0 +1,133 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package rtcp + +import "encoding/binary" + +// A ReceptionReport block conveys statistics on the reception of RTP packets +// from a single synchronization source. +type ReceptionReport struct { + // The SSRC identifier of the source to which the information in this + // reception report block pertains. + SSRC uint32 + // The fraction of RTP data packets from source SSRC lost since the + // previous SR or RR packet was sent, expressed as a fixed point + // number with the binary point at the left edge of the field. + FractionLost uint8 + // The total number of RTP data packets from source SSRC that have + // been lost since the beginning of reception. + TotalLost uint32 + // The low 16 bits contain the highest sequence number received in an + // RTP data packet from source SSRC, and the most significant 16 + // bits extend that sequence number with the corresponding count of + // sequence number cycles. + LastSequenceNumber uint32 + // An estimate of the statistical variance of the RTP data packet + // interarrival time, measured in timestamp units and expressed as an + // unsigned integer. + Jitter uint32 + // The middle 32 bits out of 64 in the NTP timestamp received as part of + // the most recent RTCP sender report (SR) packet from source SSRC. If no + // SR has been received yet, the field is set to zero. + LastSenderReport uint32 + // The delay, expressed in units of 1/65536 seconds, between receiving the + // last SR packet from source SSRC and sending this reception report block. + // If no SR packet has been received yet from SSRC, the field is set to zero. + Delay uint32 +} + +const ( + receptionReportLength = 24 + fractionLostOffset = 4 + totalLostOffset = 5 + lastSeqOffset = 8 + jitterOffset = 12 + lastSROffset = 16 + delayOffset = 20 +) + +// Marshal encodes the ReceptionReport in binary. +func (r ReceptionReport) Marshal() ([]byte, error) { + /* + * 0 1 2 3 + * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + * | SSRC | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | fraction lost | cumulative number of packets lost | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | extended highest sequence number received | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | interarrival jitter | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | last SR (LSR) | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | delay since last SR (DLSR) | + * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + */ + + rawPacket := make([]byte, receptionReportLength) + + binary.BigEndian.PutUint32(rawPacket, r.SSRC) + + rawPacket[fractionLostOffset] = r.FractionLost + + // pack TotalLost into 24 bits + if r.TotalLost >= (1 << 25) { + return nil, errInvalidTotalLost + } + tlBytes := rawPacket[totalLostOffset:] + tlBytes[0] = byte(r.TotalLost >> 16) + tlBytes[1] = byte(r.TotalLost >> 8) + tlBytes[2] = byte(r.TotalLost) + + binary.BigEndian.PutUint32(rawPacket[lastSeqOffset:], r.LastSequenceNumber) + binary.BigEndian.PutUint32(rawPacket[jitterOffset:], r.Jitter) + binary.BigEndian.PutUint32(rawPacket[lastSROffset:], r.LastSenderReport) + binary.BigEndian.PutUint32(rawPacket[delayOffset:], r.Delay) + + return rawPacket, nil +} + +// Unmarshal decodes the ReceptionReport from binary. +func (r *ReceptionReport) Unmarshal(rawPacket []byte) error { + if len(rawPacket) < receptionReportLength { + return errPacketTooShort + } + + /* + * 0 1 2 3 + * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + * | SSRC | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | fraction lost | cumulative number of packets lost | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | extended highest sequence number received | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | interarrival jitter | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | last SR (LSR) | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | delay since last SR (DLSR) | + * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + */ + + r.SSRC = binary.BigEndian.Uint32(rawPacket) + r.FractionLost = rawPacket[fractionLostOffset] + + tlBytes := rawPacket[totalLostOffset:] + r.TotalLost = uint32(tlBytes[2]) | uint32(tlBytes[1])<<8 | uint32(tlBytes[0])<<16 + + r.LastSequenceNumber = binary.BigEndian.Uint32(rawPacket[lastSeqOffset:]) + r.Jitter = binary.BigEndian.Uint32(rawPacket[jitterOffset:]) + r.LastSenderReport = binary.BigEndian.Uint32(rawPacket[lastSROffset:]) + r.Delay = binary.BigEndian.Uint32(rawPacket[delayOffset:]) + + return nil +} + +func (r *ReceptionReport) len() int { + return receptionReportLength +} diff --git a/vendor/github.com/pion/rtcp/renovate.json b/vendor/github.com/pion/rtcp/renovate.json new file mode 100644 index 0000000..f1bb98c --- /dev/null +++ b/vendor/github.com/pion/rtcp/renovate.json @@ -0,0 +1,6 @@ +{ + "$schema": "https://docs.renovatebot.com/renovate-schema.json", + "extends": [ + "github>pion/renovate-config" + ] +} diff --git a/vendor/github.com/pion/rtcp/rfc8888.go b/vendor/github.com/pion/rtcp/rfc8888.go new file mode 100644 index 0000000..ce5dfb1 --- /dev/null +++ b/vendor/github.com/pion/rtcp/rfc8888.go @@ -0,0 +1,370 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package rtcp + +import ( + "encoding/binary" + "errors" + "fmt" + "math" +) + +// https://www.rfc-editor.org/rfc/rfc8888.html#name-rtcp-congestion-control-fee +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// |V=2|P| FMT=11 | PT = 205 | length | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | SSRC of RTCP packet sender | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | SSRC of 1st RTP Stream | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | begin_seq | num_reports | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// |R|ECN| Arrival time offset | ... . +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// . . +// . . +// . . +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | SSRC of nth RTP Stream | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | begin_seq | num_reports | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// |R|ECN| Arrival time offset | ... | +// . . +// . . +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Report Timestamp (32 bits) | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + +var ( + errReportBlockLength = errors.New("feedback report blocks must be at least 8 bytes") + errIncorrectNumReports = errors.New("feedback report block contains less reports than num_reports") + errMetricBlockLength = errors.New("feedback report metric blocks must be exactly 2 bytes") +) + +// ECN represents the two ECN bits. +type ECN uint8 + +const ( + //nolint:misspell + // ECNNonECT signals Non ECN-Capable Transport, Non-ECT. + ECNNonECT ECN = iota // 00 + + //nolint:misspell + // ECNECT1 signals ECN Capable Transport, ECT(0). + ECNECT1 // 01 + + //nolint:misspell + // ECNECT0 signals ECN Capable Transport, ECT(1). + ECNECT0 // 10 + + // ECNCE signals ECN Congestion Encountered, CE. + ECNCE // 11 +) + +func (e ECN) String() string { + switch e { + case ECNNonECT: + //nolint:misspell + return "Non-ECT (00)" + case ECNECT0: + //nolint:misspell + return "ECT(0) (01)" + case ECNECT1: + //nolint:misspell + return "ECT(1) (10)" + case ECNCE: + //nolint:misspell + return "CE (11)" + } + + return "invalid ECN value" +} + +const ( + reportTimestampLength = 4 + reportBlockOffset = 8 +) + +// CCFeedbackReport is a Congestion Control Feedback Report as defined in +// https://www.rfc-editor.org/rfc/rfc8888.html#name-rtcp-congestion-control-fee +type CCFeedbackReport struct { + // SSRC of sender + SenderSSRC uint32 + + // Report Blocks + ReportBlocks []CCFeedbackReportBlock + + // Basetime + ReportTimestamp uint32 +} + +// DestinationSSRC returns an array of SSRC values that this packet refers to. +func (b CCFeedbackReport) DestinationSSRC() []uint32 { + ssrcs := make([]uint32, len(b.ReportBlocks)) + for i, block := range b.ReportBlocks { + ssrcs[i] = block.MediaSSRC + } + + return ssrcs +} + +// Len returns the length of the report in bytes. +func (b *CCFeedbackReport) Len() int { + return b.MarshalSize() +} + +// MarshalSize returns the size of the packet once marshaled. +func (b *CCFeedbackReport) MarshalSize() int { + n := 0 + for _, block := range b.ReportBlocks { + n += block.len() + } + + return reportBlockOffset + n + reportTimestampLength +} + +// Header returns the Header associated with this packet. +func (b *CCFeedbackReport) Header() Header { + return Header{ + Padding: false, + Count: FormatCCFB, + Type: TypeTransportSpecificFeedback, + Length: uint16(b.MarshalSize()/4 - 1), //nolint:gosec // G115 + } +} + +// Marshal encodes the Congestion Control Feedback Report in binary. +func (b CCFeedbackReport) Marshal() ([]byte, error) { + header := b.Header() + headerBuf, err := header.Marshal() + if err != nil { + return nil, err + } + length := 4 * (header.Length + 1) + buf := make([]byte, length) + copy(buf[:headerLength], headerBuf) + binary.BigEndian.PutUint32(buf[headerLength:], b.SenderSSRC) + offset := reportBlockOffset + for _, block := range b.ReportBlocks { + b, err := block.marshal() + if err != nil { + return nil, err + } + copy(buf[offset:], b) + offset += block.len() + } + + binary.BigEndian.PutUint32(buf[offset:], b.ReportTimestamp) + + return buf, nil +} + +func (b CCFeedbackReport) String() string { + out := fmt.Sprintf("CCFB:\n\tHeader %v\n", b.Header()) + out += fmt.Sprintf("CCFB:\n\tSender SSRC %d\n", b.SenderSSRC) + out += fmt.Sprintf("\tReport Timestamp %d\n", b.ReportTimestamp) + out += "\tFeedback Reports \n" + for _, report := range b.ReportBlocks { + out += fmt.Sprintf("%v ", report) + } + out += "\n" + + return out +} + +// Unmarshal decodes the Congestion Control Feedback Report from binary. +func (b *CCFeedbackReport) Unmarshal(rawPacket []byte) error { + if len(rawPacket) < headerLength+ssrcLength+reportTimestampLength { + return errPacketTooShort + } + + var h Header + if err := h.Unmarshal(rawPacket); err != nil { + return err + } + if h.Type != TypeTransportSpecificFeedback { + return errWrongType + } + + b.SenderSSRC = binary.BigEndian.Uint32(rawPacket[headerLength:]) + + reportTimestampOffset := len(rawPacket) - reportTimestampLength + b.ReportTimestamp = binary.BigEndian.Uint32(rawPacket[reportTimestampOffset:]) + + offset := reportBlockOffset + b.ReportBlocks = []CCFeedbackReportBlock{} + for offset < reportTimestampOffset { + var block CCFeedbackReportBlock + if err := block.unmarshal(rawPacket[offset:]); err != nil { + return err + } + b.ReportBlocks = append(b.ReportBlocks, block) + offset += block.len() + } + + return nil +} + +const ( + ssrcOffset = 0 + beginSequenceOffset = 4 + numReportsOffset = 6 + reportsOffset = 8 + + maxMetricBlocks = 16384 +) + +// CCFeedbackReportBlock is a Feedback Report Block. +type CCFeedbackReportBlock struct { + // SSRC of the RTP stream on which this block is reporting + MediaSSRC uint32 + BeginSequence uint16 + MetricBlocks []CCFeedbackMetricBlock +} + +// len returns the length of the report block in bytes. +func (b *CCFeedbackReportBlock) len() int { + n := len(b.MetricBlocks) + if n%2 != 0 { + n++ + } + + return reportsOffset + 2*n +} + +func (b CCFeedbackReportBlock) String() string { + out := fmt.Sprintf("\tReport Block Media SSRC %d\n", b.MediaSSRC) + out += fmt.Sprintf("\tReport Begin Sequence Nr %d\n", b.BeginSequence) + out += fmt.Sprintf("\tReport length %d\n\t", len(b.MetricBlocks)) + for i, block := range b.MetricBlocks { + //nolint:gosec // G115 + out += fmt.Sprintf( + "{nr: %d, rx: %v, ts: %v, ecn: %v} ", + b.BeginSequence+uint16(i), + block.Received, + block.ArrivalTimeOffset, + block.ECN, + ) + } + out += "\n" + + return out +} + +// marshal encodes the Congestion Control Feedback Report Block in binary. +func (b CCFeedbackReportBlock) marshal() ([]byte, error) { + if len(b.MetricBlocks) > maxMetricBlocks { + return nil, errTooManyReports + } + + buf := make([]byte, b.len()) + binary.BigEndian.PutUint32(buf[ssrcOffset:], b.MediaSSRC) + binary.BigEndian.PutUint16(buf[beginSequenceOffset:], b.BeginSequence) + + length := uint16(len(b.MetricBlocks)) //nolint:gosec // G115 + + binary.BigEndian.PutUint16(buf[numReportsOffset:], length) + + for i, block := range b.MetricBlocks { + b, err := block.marshal() + if err != nil { + return nil, err + } + copy(buf[reportsOffset+i*2:], b) + } + + return buf, nil +} + +// Unmarshal decodes the Congestion Control Feedback Report Block from binary. +func (b *CCFeedbackReportBlock) unmarshal(rawPacket []byte) error { + if len(rawPacket) < reportsOffset { + return errReportBlockLength + } + b.MediaSSRC = binary.BigEndian.Uint32(rawPacket[:beginSequenceOffset]) + b.BeginSequence = binary.BigEndian.Uint16(rawPacket[beginSequenceOffset:numReportsOffset]) + numReports := int(binary.BigEndian.Uint16(rawPacket[numReportsOffset:])) + if numReports == 0 { + return nil + } + + if numReports > math.MaxUint16 { + return errIncorrectNumReports + } + + if len(rawPacket) < reportsOffset+numReports*2 { + return errIncorrectNumReports + } + + b.MetricBlocks = make([]CCFeedbackMetricBlock, numReports) + for i := int(0); i < numReports; i++ { + var mb CCFeedbackMetricBlock + offset := reportsOffset + 2*i + if err := mb.unmarshal(rawPacket[offset : offset+2]); err != nil { + return err + } + b.MetricBlocks[i] = mb + } + + return nil +} + +const ( + metricBlockLength = 2 +) + +// CCFeedbackMetricBlock is a Feedback Metric Block. +type CCFeedbackMetricBlock struct { + Received bool + ECN ECN + + // Offset in 1/1024 seconds before Report Timestamp + ArrivalTimeOffset uint16 +} + +// Marshal encodes the Congestion Control Feedback Metric Block in binary. +func (b CCFeedbackMetricBlock) marshal() ([]byte, error) { + buf := make([]byte, 2) + r := uint16(0) + if b.Received { + r = 1 + } + dst, err := setNBitsOfUint16(0, 1, 0, r) + if err != nil { + return nil, err + } + dst, err = setNBitsOfUint16(dst, 2, 1, uint16(b.ECN)) + if err != nil { + return nil, err + } + dst, err = setNBitsOfUint16(dst, 13, 3, b.ArrivalTimeOffset) + if err != nil { + return nil, err + } + + binary.BigEndian.PutUint16(buf, dst) + + return buf, nil +} + +// Unmarshal decodes the Congestion Control Feedback Metric Block from binary. +func (b *CCFeedbackMetricBlock) unmarshal(rawPacket []byte) error { + if len(rawPacket) != metricBlockLength { + return errMetricBlockLength + } + b.Received = rawPacket[0]&0x80 != 0 + if !b.Received { + b.ECN = ECNNonECT + b.ArrivalTimeOffset = 0 + + return nil + } + b.ECN = ECN(rawPacket[0] >> 5 & 0x03) + b.ArrivalTimeOffset = binary.BigEndian.Uint16(rawPacket) & 0x1FFF + + return nil +} diff --git a/vendor/github.com/pion/rtcp/sender_report.go b/vendor/github.com/pion/rtcp/sender_report.go new file mode 100644 index 0000000..0ad8699 --- /dev/null +++ b/vendor/github.com/pion/rtcp/sender_report.go @@ -0,0 +1,265 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package rtcp + +import ( + "encoding/binary" + "fmt" +) + +// A SenderReport (SR) packet provides reception quality feedback for an RTP stream. +type SenderReport struct { + // The synchronization source identifier for the originator of this SR packet. + SSRC uint32 + // The wallclock time when this report was sent so that it may be used in + // combination with timestamps returned in reception reports from other + // receivers to measure round-trip propagation to those receivers. + NTPTime uint64 + // Corresponds to the same time as the NTP timestamp (above), but in + // the same units and with the same random offset as the RTP + // timestamps in data packets. This correspondence may be used for + // intra- and inter-media synchronization for sources whose NTP + // timestamps are synchronized, and may be used by media-independent + // receivers to estimate the nominal RTP clock frequency. + RTPTime uint32 + // The total number of RTP data packets transmitted by the sender + // since starting transmission up until the time this SR packet was + // generated. + PacketCount uint32 + // The total number of payload octets (i.e., not including header or + // padding) transmitted in RTP data packets by the sender since + // starting transmission up until the time this SR packet was + // generated. + OctetCount uint32 + // Zero or more reception report blocks depending on the number of other + // sources heard by this sender since the last report. Each reception report + // block conveys statistics on the reception of RTP packets from a + // single synchronization source. + Reports []ReceptionReport + // ProfileExtensions contains additional, payload-specific information that needs to + // be reported regularly about the sender. + ProfileExtensions []byte +} + +const ( + srHeaderLength = 24 + srSSRCOffset = 0 + srNTPOffset = srSSRCOffset + ssrcLength + ntpTimeLength = 8 + srRTPOffset = srNTPOffset + ntpTimeLength + rtpTimeLength = 4 + srPacketCountOffset = srRTPOffset + rtpTimeLength + srPacketCountLength = 4 + srOctetCountOffset = srPacketCountOffset + srPacketCountLength + srOctetCountLength = 4 + srReportOffset = srOctetCountOffset + srOctetCountLength +) + +// Marshal encodes the SenderReport in binary. +func (r SenderReport) Marshal() ([]byte, error) { + /* + * 0 1 2 3 + * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * header |V=2|P| RC | PT=SR=200 | length | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | SSRC of sender | + * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + * sender | NTP timestamp, most significant word | + * info +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | NTP timestamp, least significant word | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | RTP timestamp | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | sender's packet count | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | sender's octet count | + * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + * report | SSRC_1 (SSRC of first source) | + * block +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * 1 | fraction lost | cumulative number of packets lost | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | extended highest sequence number received | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | interarrival jitter | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | last SR (LSR) | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | delay since last SR (DLSR) | + * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + * report | SSRC_2 (SSRC of second source) | + * block +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * 2 : ... : + * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + * | profile-specific extensions | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + */ + + rawPacket := make([]byte, r.MarshalSize()) + packetBody := rawPacket[headerLength:] + + binary.BigEndian.PutUint32(packetBody[srSSRCOffset:], r.SSRC) + binary.BigEndian.PutUint64(packetBody[srNTPOffset:], r.NTPTime) + binary.BigEndian.PutUint32(packetBody[srRTPOffset:], r.RTPTime) + binary.BigEndian.PutUint32(packetBody[srPacketCountOffset:], r.PacketCount) + binary.BigEndian.PutUint32(packetBody[srOctetCountOffset:], r.OctetCount) + + offset := srHeaderLength + for _, rp := range r.Reports { + data, err := rp.Marshal() + if err != nil { + return nil, err + } + copy(packetBody[offset:], data) + offset += receptionReportLength + } + + if len(r.Reports) > countMax { + return nil, errTooManyReports + } + + copy(packetBody[offset:], r.ProfileExtensions) + + hData, err := r.Header().Marshal() + if err != nil { + return nil, err + } + copy(rawPacket, hData) + + return rawPacket, nil +} + +// Unmarshal decodes the SenderReport from binary. +func (r *SenderReport) Unmarshal(rawPacket []byte) error { + /* + * 0 1 2 3 + * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * header |V=2|P| RC | PT=SR=200 | length | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | SSRC of sender | + * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + * sender | NTP timestamp, most significant word | + * info +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | NTP timestamp, least significant word | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | RTP timestamp | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | sender's packet count | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | sender's octet count | + * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + * report | SSRC_1 (SSRC of first source) | + * block +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * 1 | fraction lost | cumulative number of packets lost | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | extended highest sequence number received | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | interarrival jitter | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | last SR (LSR) | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | delay since last SR (DLSR) | + * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + * report | SSRC_2 (SSRC of second source) | + * block +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * 2 : ... : + * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + * | profile-specific extensions | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + */ + + if len(rawPacket) < (headerLength + srHeaderLength) { + return errPacketTooShort + } + + var header Header + if err := header.Unmarshal(rawPacket); err != nil { + return err + } + + if header.Type != TypeSenderReport { + return errWrongType + } + + packetBody := rawPacket[headerLength:] + + r.SSRC = binary.BigEndian.Uint32(packetBody[srSSRCOffset:]) + r.NTPTime = binary.BigEndian.Uint64(packetBody[srNTPOffset:]) + r.RTPTime = binary.BigEndian.Uint32(packetBody[srRTPOffset:]) + r.PacketCount = binary.BigEndian.Uint32(packetBody[srPacketCountOffset:]) + r.OctetCount = binary.BigEndian.Uint32(packetBody[srOctetCountOffset:]) + + offset := srReportOffset + for i := 0; i < int(header.Count); i++ { + rrEnd := offset + receptionReportLength + if rrEnd > len(packetBody) { + return errPacketTooShort + } + rrBody := packetBody[offset : offset+receptionReportLength] + offset = rrEnd + + var rr ReceptionReport + if err := rr.Unmarshal(rrBody); err != nil { + return err + } + r.Reports = append(r.Reports, rr) + } + + if offset < len(packetBody) { + r.ProfileExtensions = packetBody[offset:] + } + + if uint8(len(r.Reports)) != header.Count { //nolint:gosec // G115 + return errInvalidHeader + } + + return nil +} + +// DestinationSSRC returns an array of SSRC values that this packet refers to. +func (r *SenderReport) DestinationSSRC() []uint32 { + out := make([]uint32, len(r.Reports)+1) + for i, v := range r.Reports { + out[i] = v.SSRC + } + out[len(r.Reports)] = r.SSRC + + return out +} + +// MarshalSize returns the size of the packet once marshaled. +func (r *SenderReport) MarshalSize() int { + repsLength := 0 + for _, rep := range r.Reports { + repsLength += rep.len() + } + + return headerLength + srHeaderLength + repsLength + len(r.ProfileExtensions) +} + +// Header returns the Header associated with this packet. +func (r *SenderReport) Header() Header { + return Header{ + Count: uint8(len(r.Reports)), //nolint:gosec // G115 + Type: TypeSenderReport, + Length: uint16((r.MarshalSize() / 4) - 1), //nolint:gosec // G115 + } +} + +func (r SenderReport) String() string { + out := fmt.Sprintf("SenderReport from %x\n", r.SSRC) + out += fmt.Sprintf("\tNTPTime:\t%d\n", r.NTPTime) + out += fmt.Sprintf("\tRTPTIme:\t%d\n", r.RTPTime) + out += fmt.Sprintf("\tPacketCount:\t%d\n", r.PacketCount) + out += fmt.Sprintf("\tOctetCount:\t%d\n", r.OctetCount) + + out += "\tSSRC \tLost\tLastSequence\n" + for _, i := range r.Reports { + out += fmt.Sprintf("\t%x\t%d/%d\t%d\n", i.SSRC, i.FractionLost, i.TotalLost, i.LastSequenceNumber) + } + out += fmt.Sprintf("\tProfile Extension Data: %v\n", r.ProfileExtensions) + + return out +} diff --git a/vendor/github.com/pion/rtcp/slice_loss_indication.go b/vendor/github.com/pion/rtcp/slice_loss_indication.go new file mode 100644 index 0000000..43e2d80 --- /dev/null +++ b/vendor/github.com/pion/rtcp/slice_loss_indication.go @@ -0,0 +1,118 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package rtcp + +import ( + "encoding/binary" + "fmt" + "math" +) + +// SLIEntry represents a single entry to the SLI packet's +// list of lost slices. +type SLIEntry struct { + // ID of first lost slice + First uint16 + + // Number of lost slices + Number uint16 + + // ID of related picture + Picture uint8 +} + +// The SliceLossIndication packet informs the encoder about the loss of a picture slice. +type SliceLossIndication struct { + // SSRC of sender + SenderSSRC uint32 + + // SSRC of the media source + MediaSSRC uint32 + + SLI []SLIEntry +} + +const ( + sliLength = 2 + sliOffset = 8 +) + +// Marshal encodes the SliceLossIndication in binary. +func (p SliceLossIndication) Marshal() ([]byte, error) { + if len(p.SLI)+sliLength > math.MaxUint8 { + return nil, errTooManyReports + } + + rawPacket := make([]byte, sliOffset+(len(p.SLI)*4)) + binary.BigEndian.PutUint32(rawPacket, p.SenderSSRC) + binary.BigEndian.PutUint32(rawPacket[4:], p.MediaSSRC) + for i, s := range p.SLI { + sli := ((uint32(s.First) & 0x1FFF) << 19) | + ((uint32(s.Number) & 0x1FFF) << 6) | + (uint32(s.Picture) & 0x3F) + binary.BigEndian.PutUint32(rawPacket[sliOffset+(4*i):], sli) + } + hData, err := p.Header().Marshal() + if err != nil { + return nil, err + } + + return append(hData, rawPacket...), nil +} + +// Unmarshal decodes the SliceLossIndication from binary. +func (p *SliceLossIndication) Unmarshal(rawPacket []byte) error { + if len(rawPacket) < (headerLength + ssrcLength) { + return errPacketTooShort + } + + var header Header + if err := header.Unmarshal(rawPacket); err != nil { + return err + } + + if len(rawPacket) < (headerLength + int(4*header.Length)) { + return errPacketTooShort + } + + if header.Type != TypeTransportSpecificFeedback || header.Count != FormatSLI { + return errWrongType + } + + p.SenderSSRC = binary.BigEndian.Uint32(rawPacket[headerLength:]) + p.MediaSSRC = binary.BigEndian.Uint32(rawPacket[headerLength+ssrcLength:]) + for i := headerLength + sliOffset; i < (headerLength + int(header.Length*4)); i += 4 { + sli := binary.BigEndian.Uint32(rawPacket[i:]) + p.SLI = append(p.SLI, SLIEntry{ + First: uint16((sli >> 19) & 0x1FFF), //nolint:gosec // G115 + Number: uint16((sli >> 6) & 0x1FFF), //nolint:gosec // G115 + Picture: uint8(sli & 0x3F), //nolint:gosec // G115 + }) + } + + return nil +} + +// MarshalSize returns the size of the packet once marshaled. +func (p *SliceLossIndication) MarshalSize() int { + return headerLength + sliOffset + (len(p.SLI) * 4) +} + +// Header returns the Header associated with this packet. +func (p *SliceLossIndication) Header() Header { + return Header{ + Count: FormatSLI, + Type: TypeTransportSpecificFeedback, + Length: uint16((p.MarshalSize() / 4) - 1), //nolint:gosec // G115 + } +} + +func (p *SliceLossIndication) String() string { + return fmt.Sprintf("SliceLossIndication %x %x %+v", p.SenderSSRC, p.MediaSSRC, p.SLI) +} + +// DestinationSSRC returns an array of SSRC values that this packet refers to. +func (p *SliceLossIndication) DestinationSSRC() []uint32 { + return []uint32{p.MediaSSRC} +} diff --git a/vendor/github.com/pion/rtcp/source_description.go b/vendor/github.com/pion/rtcp/source_description.go new file mode 100644 index 0000000..6d6f2a0 --- /dev/null +++ b/vendor/github.com/pion/rtcp/source_description.go @@ -0,0 +1,374 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package rtcp + +import ( + "encoding/binary" + "fmt" +) + +// SDESType is the item type used in the RTCP SDES control packet. +type SDESType uint8 + +// RTP SDES item types registered with IANA. +// See: https://www.iana.org/assignments/rtp-parameters/rtp-parameters.xhtml#rtp-parameters-5 +// . +const ( + SDESEnd SDESType = iota // end of SDES list RFC 3550, 6.5 + SDESCNAME // canonical name RFC 3550, 6.5.1 + SDESName // user name RFC 3550, 6.5.2 + SDESEmail // user's electronic mail address RFC 3550, 6.5.3 + SDESPhone // user's phone number RFC 3550, 6.5.4 + SDESLocation // geographic user location RFC 3550, 6.5.5 + SDESTool // name of application or tool RFC 3550, 6.5.6 + SDESNote // notice about the source RFC 3550, 6.5.7 + SDESPrivate // private extensions RFC 3550, 6.5.8 (not implemented) +) + +//nolint:cyclop +func (s SDESType) String() string { + switch s { + case SDESEnd: + return "END" + case SDESCNAME: + return "CNAME" + case SDESName: + return "NAME" + case SDESEmail: + return "EMAIL" + case SDESPhone: + return "PHONE" + case SDESLocation: + return "LOC" + case SDESTool: + return "TOOL" + case SDESNote: + return "NOTE" + case SDESPrivate: + return "PRIV" + default: + return string(s) + } +} + +const ( + sdesSourceLen = 4 + sdesTypeLen = 1 + sdesTypeOffset = 0 + sdesOctetCountLen = 1 + sdesOctetCountOffset = 1 + sdesMaxOctetCount = (1 << 8) - 1 + sdesTextOffset = 2 +) + +// A SourceDescription (SDES) packet describes the sources in an RTP stream. +type SourceDescription struct { + Chunks []SourceDescriptionChunk +} + +// NewCNAMESourceDescription creates a new SourceDescription with a single CNAME item. +func NewCNAMESourceDescription(ssrc uint32, cname string) *SourceDescription { + return &SourceDescription{ + Chunks: []SourceDescriptionChunk{{ + Source: ssrc, + Items: []SourceDescriptionItem{{ + Type: SDESCNAME, + Text: cname, + }}, + }}, + } +} + +// Marshal encodes the SourceDescription in binary. +func (s SourceDescription) Marshal() ([]byte, error) { + /* + * 0 1 2 3 + * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * header |V=2|P| SC | PT=SDES=202 | length | + * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + * chunk | SSRC/CSRC_1 | + * 1 +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | SDES items | + * | ... | + * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + * chunk | SSRC/CSRC_2 | + * 2 +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | SDES items | + * | ... | + * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + */ + + rawPacket := make([]byte, s.MarshalSize()) + packetBody := rawPacket[headerLength:] + + chunkOffset := 0 + for _, c := range s.Chunks { + data, err := c.Marshal() + if err != nil { + return nil, err + } + copy(packetBody[chunkOffset:], data) + chunkOffset += len(data) + } + + if len(s.Chunks) > countMax { + return nil, errTooManyChunks + } + + hData, err := s.Header().Marshal() + if err != nil { + return nil, err + } + copy(rawPacket, hData) + + return rawPacket, nil +} + +// Unmarshal decodes the SourceDescription from binary. +func (s *SourceDescription) Unmarshal(rawPacket []byte) error { + /* + * 0 1 2 3 + * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * header |V=2|P| SC | PT=SDES=202 | length | + * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + * chunk | SSRC/CSRC_1 | + * 1 +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | SDES items | + * | ... | + * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + * chunk | SSRC/CSRC_2 | + * 2 +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | SDES items | + * | ... | + * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + */ + + var header Header + if err := header.Unmarshal(rawPacket); err != nil { + return err + } + + if header.Type != TypeSourceDescription { + return errWrongType + } + + for i := headerLength; i < len(rawPacket); { + var chunk SourceDescriptionChunk + if err := chunk.Unmarshal(rawPacket[i:]); err != nil { + return err + } + s.Chunks = append(s.Chunks, chunk) + + i += chunk.len() + } + + if len(s.Chunks) != int(header.Count) { + return errInvalidHeader + } + + return nil +} + +// MarshalSize returns the size of the packet once marshaled. +func (s *SourceDescription) MarshalSize() int { + chunksLength := 0 + for _, c := range s.Chunks { + chunksLength += c.len() + } + + return headerLength + chunksLength +} + +// Header returns the Header associated with this packet. +func (s *SourceDescription) Header() Header { + return Header{ + Count: uint8(len(s.Chunks)), //nolint:gosec // G115 + Type: TypeSourceDescription, + Length: uint16((s.MarshalSize() / 4) - 1), //nolint:gosec // G115 + } +} + +// A SourceDescriptionChunk contains items describing a single RTP source. +type SourceDescriptionChunk struct { + // The source (ssrc) or contributing source (csrc) identifier this packet describes + Source uint32 + Items []SourceDescriptionItem +} + +// Marshal encodes the SourceDescriptionChunk in binary. +func (s SourceDescriptionChunk) Marshal() ([]byte, error) { + /* + * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + * | SSRC/CSRC_1 | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | SDES items | + * | ... | + * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + */ + + rawPacket := make([]byte, sdesSourceLen) + binary.BigEndian.PutUint32(rawPacket, s.Source) + + for _, it := range s.Items { + data, err := it.Marshal() + if err != nil { + return nil, err + } + rawPacket = append(rawPacket, data...) //nolint:makezero + } + + // The list of items in each chunk MUST be terminated by one or more null octets + rawPacket = append(rawPacket, uint8(SDESEnd)) //nolint:makezero + + // additional null octets MUST be included if needed to pad until the next 32-bit boundary + rawPacket = append(rawPacket, make([]byte, getPadding(len(rawPacket)))...) //nolint:makezero + + return rawPacket, nil +} + +// Unmarshal decodes the SourceDescriptionChunk from binary. +func (s *SourceDescriptionChunk) Unmarshal(rawPacket []byte) error { + /* + * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + * | SSRC/CSRC_1 | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | SDES items | + * | ... | + * +=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + */ + + if len(rawPacket) < (sdesSourceLen + sdesTypeLen) { + return errPacketTooShort + } + + s.Source = binary.BigEndian.Uint32(rawPacket) + + for i := 4; i < len(rawPacket); { + if pktType := SDESType(rawPacket[i]); pktType == SDESEnd { + return nil + } + + var it SourceDescriptionItem + if err := it.Unmarshal(rawPacket[i:]); err != nil { + return err + } + s.Items = append(s.Items, it) + i += it.Len() + } + + return errPacketTooShort +} + +func (s SourceDescriptionChunk) len() int { + chunkLen := sdesSourceLen + for _, it := range s.Items { + chunkLen += it.Len() + } + chunkLen += sdesTypeLen // for terminating null octet + + // align to 32-bit boundary + chunkLen += getPadding(chunkLen) + + return chunkLen +} + +// A SourceDescriptionItem is a part of a SourceDescription that describes a stream. +type SourceDescriptionItem struct { + // The type identifier for this item. eg, SDESCNAME for canonical name description. + // + // Type zero or SDESEnd is interpreted as the end of an item list and cannot be used. + Type SDESType + // Text is a unicode text blob associated with the item. Its meaning varies based on the item's Type. + Text string +} + +// Len returns the length of the SourceDescriptionItem when encoded as binary. +func (s SourceDescriptionItem) Len() int { + /* + * 0 1 2 3 + * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | CNAME=1 | length | user and domain name ... + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + */ + return sdesTypeLen + sdesOctetCountLen + len([]byte(s.Text)) +} + +// Marshal encodes the SourceDescriptionItem in binary. +func (s SourceDescriptionItem) Marshal() ([]byte, error) { + /* + * 0 1 2 3 + * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | CNAME=1 | length | user and domain name ... + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + */ + + if s.Type == SDESEnd { + return nil, errSDESMissingType + } + + rawPacket := make([]byte, sdesTypeLen+sdesOctetCountLen) + + rawPacket[sdesTypeOffset] = uint8(s.Type) + + txtBytes := []byte(s.Text) + octetCount := len(txtBytes) + if octetCount > sdesMaxOctetCount { + return nil, errSDESTextTooLong + } + rawPacket[sdesOctetCountOffset] = uint8(octetCount) + + rawPacket = append(rawPacket, txtBytes...) //nolint:makezero + + return rawPacket, nil +} + +// Unmarshal decodes the SourceDescriptionItem from binary. +func (s *SourceDescriptionItem) Unmarshal(rawPacket []byte) error { + /* + * 0 1 2 3 + * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | CNAME=1 | length | user and domain name ... + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + */ + + if len(rawPacket) < (sdesTypeLen + sdesOctetCountLen) { + return errPacketTooShort + } + + s.Type = SDESType(rawPacket[sdesTypeOffset]) + + octetCount := int(rawPacket[sdesOctetCountOffset]) + if sdesTextOffset+octetCount > len(rawPacket) { + return errPacketTooShort + } + + txtBytes := rawPacket[sdesTextOffset : sdesTextOffset+octetCount] + s.Text = string(txtBytes) + + return nil +} + +// DestinationSSRC returns an array of SSRC values that this packet refers to. +func (s *SourceDescription) DestinationSSRC() []uint32 { + out := make([]uint32, len(s.Chunks)) + for i, v := range s.Chunks { + out[i] = v.Source + } + + return out +} + +func (s *SourceDescription) String() string { + out := "Source Description:\n" + for _, c := range s.Chunks { + out += fmt.Sprintf("\t%x: %s\n", c.Source, c.Items) + } + + return out +} diff --git a/vendor/github.com/pion/rtcp/transport_layer_cc.go b/vendor/github.com/pion/rtcp/transport_layer_cc.go new file mode 100644 index 0000000..de464cb --- /dev/null +++ b/vendor/github.com/pion/rtcp/transport_layer_cc.go @@ -0,0 +1,585 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package rtcp + +// Author: adwpc + +import ( + "encoding/binary" + "errors" + "fmt" + "math" +) + +// https://tools.ietf.org/html/draft-holmer-rmcat-transport-wide-cc-extensions-01#page-5 +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// |V=2|P| FMT=15 | PT=205 | length | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | SSRC of packet sender | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | SSRC of media source | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | base sequence number | packet status count | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | reference time | fb pkt. count | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | packet chunk | packet chunk | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// . . +// . . +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | packet chunk | recv delta | recv delta | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// . . +// . . +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | recv delta | recv delta | zero padding | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + +// for packet status chunk. +const ( + // type of packet status chunk. + TypeTCCRunLengthChunk = 0 + TypeTCCStatusVectorChunk = 1 + + // len of packet status chunk. + packetStatusChunkLength = 2 +) + +// type of packet status symbol and recv delta. +const ( + // https://tools.ietf.org/html/draft-holmer-rmcat-transport-wide-cc-extensions-01#section-3.1.1 + TypeTCCPacketNotReceived = uint16(iota) + TypeTCCPacketReceivedSmallDelta + TypeTCCPacketReceivedLargeDelta + // https://tools.ietf.org/html/draft-holmer-rmcat-transport-wide-cc-extensions-01#page-7 + // see Example 2: "packet received, w/o recv delta". + TypeTCCPacketReceivedWithoutDelta +) + +// for status vector chunk. +const ( + // https://tools.ietf.org/html/draft-holmer-rmcat-transport-wide-cc-extensions-01#section-3.1.4 + TypeTCCSymbolSizeOneBit = 0 + TypeTCCSymbolSizeTwoBit = 1 + + // Notice: RFC is wrong: "packet received" (0) and "packet not received" (1) + // if S == TypeTCCSymbolSizeOneBit, symbol list will be: TypeTCCPacketNotReceived TypeTCCPacketReceivedSmallDelta + // if S == TypeTCCSymbolSizeTwoBit, symbol list will be same as above: + //. +) + +func numOfBitsOfSymbolSize() map[uint16]uint16 { + return map[uint16]uint16{ + TypeTCCSymbolSizeOneBit: 1, + TypeTCCSymbolSizeTwoBit: 2, + } +} + +var ( + errPacketStatusChunkLength = errors.New("packet status chunk must be 2 bytes") + errDeltaExceedLimit = errors.New("delta exceed limit") +) + +// PacketStatusChunk has two kinds: +// RunLengthChunk and StatusVectorChunk. +type PacketStatusChunk interface { + Marshal() ([]byte, error) + Unmarshal(rawPacket []byte) error +} + +// RunLengthChunk T=TypeTCCRunLengthChunk +// 0 1 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// |T| S | Run Length | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// . +type RunLengthChunk struct { + PacketStatusChunk + + // T = TypeTCCRunLengthChunk + Type uint16 + + // S: type of packet status + // kind: TypeTCCPacketNotReceived or... + PacketStatusSymbol uint16 + + // RunLength: count of S + RunLength uint16 +} + +// Marshal .. +func (r RunLengthChunk) Marshal() ([]byte, error) { + chunk := make([]byte, 2) + + // append 1 bit '0' + dst, err := setNBitsOfUint16(0, 1, 0, 0) + if err != nil { + return nil, err + } + + // append 2 bit PacketStatusSymbol + dst, err = setNBitsOfUint16(dst, 2, 1, r.PacketStatusSymbol) + if err != nil { + return nil, err + } + + // append 13 bit RunLength + dst, err = setNBitsOfUint16(dst, 13, 3, r.RunLength) + if err != nil { + return nil, err + } + + binary.BigEndian.PutUint16(chunk, dst) + + return chunk, nil +} + +// Unmarshal .. +func (r *RunLengthChunk) Unmarshal(rawPacket []byte) error { + if len(rawPacket) != packetStatusChunkLength { + return errPacketStatusChunkLength + } + + // record type + r.Type = TypeTCCRunLengthChunk + + // get PacketStatusSymbol + // r.PacketStatusSymbol = uint16(rawPacket[0] >> 5 & 0x03) + r.PacketStatusSymbol = getNBitsFromByte(rawPacket[0], 1, 2) + + // get RunLength + // r.RunLength = uint16(rawPacket[0]&0x1F)*256 + uint16(rawPacket[1]) + r.RunLength = getNBitsFromByte(rawPacket[0], 3, 5)<<8 + uint16(rawPacket[1]) + + return nil +} + +// StatusVectorChunk T=typeStatusVecotrChunk +// 0 1 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// |T|S| symbol list | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// . +type StatusVectorChunk struct { + PacketStatusChunk + // T = TypeTCCRunLengthChunk + Type uint16 + + // TypeTCCSymbolSizeOneBit or TypeTCCSymbolSizeTwoBit + SymbolSize uint16 + + // when SymbolSize = TypeTCCSymbolSizeOneBit, SymbolList is 14*1bit: + // TypeTCCSymbolListPacketReceived or TypeTCCSymbolListPacketNotReceived + // when SymbolSize = TypeTCCSymbolSizeTwoBit, SymbolList is 7*2bit: + // TypeTCCPacketNotReceived TypeTCCPacketReceivedSmallDelta TypeTCCPacketReceivedLargeDelta or typePacketReserved + SymbolList []uint16 +} + +// Marshal .. +func (r StatusVectorChunk) Marshal() ([]byte, error) { + chunk := make([]byte, 2) + + // set first bit '1' + dst, err := setNBitsOfUint16(0, 1, 0, 1) + if err != nil { + return nil, err + } + + // set second bit SymbolSize + dst, err = setNBitsOfUint16(dst, 1, 1, r.SymbolSize) + if err != nil { + return nil, err + } + + numOfBits := numOfBitsOfSymbolSize()[r.SymbolSize] + // append 14 bit SymbolList + for i, s := range r.SymbolList { + index := numOfBits*uint16(i) + 2 //nolint:gosec // G115 + dst, err = setNBitsOfUint16(dst, numOfBits, index, s) + if err != nil { + return nil, err + } + } + + binary.BigEndian.PutUint16(chunk, dst) + // set SymbolList(bit8-15) + // chunk[1] = uint8(r.SymbolList) & 0x0f + return chunk, nil +} + +// Unmarshal .. +func (r *StatusVectorChunk) Unmarshal(rawPacket []byte) error { + if len(rawPacket) != packetStatusChunkLength { + return errPacketStatusChunkLength + } + + r.Type = TypeTCCStatusVectorChunk + r.SymbolSize = getNBitsFromByte(rawPacket[0], 1, 1) + + if r.SymbolSize == TypeTCCSymbolSizeOneBit { + for i := uint16(0); i < 6; i++ { + r.SymbolList = append(r.SymbolList, getNBitsFromByte(rawPacket[0], 2+i, 1)) + } + for i := uint16(0); i < 8; i++ { + r.SymbolList = append(r.SymbolList, getNBitsFromByte(rawPacket[1], i, 1)) + } + + return nil + } + if r.SymbolSize == TypeTCCSymbolSizeTwoBit { + for i := uint16(0); i < 3; i++ { + r.SymbolList = append(r.SymbolList, getNBitsFromByte(rawPacket[0], 2+i*2, 2)) + } + for i := uint16(0); i < 4; i++ { + r.SymbolList = append(r.SymbolList, getNBitsFromByte(rawPacket[1], i*2, 2)) + } + + return nil + } + + r.SymbolSize = getNBitsFromByte(rawPacket[0], 2, 6)<<8 + uint16(rawPacket[1]) + + return nil +} + +const ( + // TypeTCCDeltaScaleFactor https://tools.ietf.org/html/draft-holmer-rmcat-transport-wide-cc-extensions-01#section-3.1.5 + TypeTCCDeltaScaleFactor = 250 +) + +// RecvDelta are represented as multiples of 250us +// small delta is 1 byte: [0,63.75]ms = [0, 63750]us = [0, 255]*250us +// big delta is 2 bytes: [-8192.0, 8191.75]ms = [-8192000, 8191750]us = [-32768, 32767]*250us +// https://tools.ietf.org/html/draft-holmer-rmcat-transport-wide-cc-extensions-01#section-3.1.5 +type RecvDelta struct { + Type uint16 + // us + Delta int64 +} + +// Marshal .. +func (r RecvDelta) Marshal() ([]byte, error) { + delta := r.Delta / TypeTCCDeltaScaleFactor + + // small delta + if r.Type == TypeTCCPacketReceivedSmallDelta && delta >= 0 && delta <= math.MaxUint8 { + deltaChunk := make([]byte, 1) + deltaChunk[0] = byte(delta) + + return deltaChunk, nil + } + + // big delta + if r.Type == TypeTCCPacketReceivedLargeDelta && delta >= math.MinInt16 && delta <= math.MaxInt16 { + deltaChunk := make([]byte, 2) + binary.BigEndian.PutUint16(deltaChunk, uint16(delta)) + + return deltaChunk, nil + } + + // overflow + return nil, errDeltaExceedLimit +} + +// Unmarshal .. +func (r *RecvDelta) Unmarshal(rawPacket []byte) error { + chunkLen := len(rawPacket) + + // must be 1 or 2 bytes + if chunkLen != 1 && chunkLen != 2 { + return errDeltaExceedLimit + } + + if chunkLen == 1 { + r.Type = TypeTCCPacketReceivedSmallDelta + r.Delta = TypeTCCDeltaScaleFactor * int64(rawPacket[0]) + + return nil + } + + r.Type = TypeTCCPacketReceivedLargeDelta + r.Delta = TypeTCCDeltaScaleFactor * int64(int16(binary.BigEndian.Uint16(rawPacket))) //nolint:gosec // G115 + + return nil +} + +const ( + // the offset after header. + baseSequenceNumberOffset = 8 + packetStatusCountOffset = 10 + referenceTimeOffset = 12 + fbPktCountOffset = 15 + packetChunkOffset = 16 +) + +// TransportLayerCC for sender-BWE +// https://tools.ietf.org/html/draft-holmer-rmcat-transport-wide-cc-extensions-01#page-5 +type TransportLayerCC struct { + // header + Header Header + + // SSRC of sender + SenderSSRC uint32 + + // SSRC of the media source + MediaSSRC uint32 + + // Transport wide sequence of rtp extension + BaseSequenceNumber uint16 + + // PacketStatusCount + PacketStatusCount uint16 + + // ReferenceTime + ReferenceTime uint32 + + // FbPktCount + FbPktCount uint8 + + // PacketChunks + PacketChunks []PacketStatusChunk + + // RecvDeltas + RecvDeltas []*RecvDelta +} + +// Header returns the Header associated with this packet. +// func (t *TransportLayerCC) Header() Header { +// return t.Header +// return Header{ +// Padding: true, +// Count: FormatTCC, +// Type: TypeTCCTransportSpecificFeedback, +// // https://tools.ietf.org/html/rfc4585#page-33 +// Length: uint16((t.len() / 4) - 1), +// } +// } + +func (t *TransportLayerCC) packetLen() uint16 { + //nolint:gocognit,cyclop + n := uint16(headerLength + packetChunkOffset + len(t.PacketChunks)*2) //nolint:gosec // G115 + for _, d := range t.RecvDeltas { + if d.Type == TypeTCCPacketReceivedSmallDelta { + n++ + } else { + n += 2 + } + } + + return n +} + +// Len return total bytes with padding. +func (t *TransportLayerCC) Len() uint16 { + return uint16(t.MarshalSize()) //nolint:gosec // G115 +} + +// MarshalSize returns the size of the packet once marshaled. +func (t *TransportLayerCC) MarshalSize() int { + n := t.packetLen() + // has padding + if n%4 != 0 { + n = (n/4 + 1) * 4 + } + + return int(n) +} + +func (t TransportLayerCC) String() string { + out := fmt.Sprintf("TransportLayerCC:\n\tHeader %v\n", t.Header) + out += fmt.Sprintf("TransportLayerCC:\n\tSender Ssrc %d\n", t.SenderSSRC) + out += fmt.Sprintf("\tMedia Ssrc %d\n", t.MediaSSRC) + out += fmt.Sprintf("\tBase Sequence Number %d\n", t.BaseSequenceNumber) + out += fmt.Sprintf("\tStatus Count %d\n", t.PacketStatusCount) + out += fmt.Sprintf("\tReference Time %d\n", t.ReferenceTime) + out += fmt.Sprintf("\tFeedback Packet Count %d\n", t.FbPktCount) + out += "\tPacketChunks " + for _, chunk := range t.PacketChunks { + out += fmt.Sprintf("%+v ", chunk) + } + out += "\n\tRecvDeltas " + for _, delta := range t.RecvDeltas { + out += fmt.Sprintf("%+v ", delta) + } + out += "\n" + + return out +} + +// Marshal encodes the TransportLayerCC in binary. +func (t TransportLayerCC) Marshal() ([]byte, error) { + header, err := t.Header.Marshal() + if err != nil { + return nil, err + } + + payload := make([]byte, t.MarshalSize()-headerLength) + binary.BigEndian.PutUint32(payload, t.SenderSSRC) + binary.BigEndian.PutUint32(payload[4:], t.MediaSSRC) + binary.BigEndian.PutUint16(payload[baseSequenceNumberOffset:], t.BaseSequenceNumber) + binary.BigEndian.PutUint16(payload[packetStatusCountOffset:], t.PacketStatusCount) + ReferenceTimeAndFbPktCount := appendNBitsToUint32(0, 24, t.ReferenceTime) + ReferenceTimeAndFbPktCount = appendNBitsToUint32(ReferenceTimeAndFbPktCount, 8, uint32(t.FbPktCount)) + binary.BigEndian.PutUint32(payload[referenceTimeOffset:], ReferenceTimeAndFbPktCount) + + for i, chunk := range t.PacketChunks { + b, err := chunk.Marshal() + if err != nil { + return nil, err + } + copy(payload[packetChunkOffset+i*2:], b) + } + + recvDeltaOffset := packetChunkOffset + len(t.PacketChunks)*2 + var i int + for _, delta := range t.RecvDeltas { + b, err := delta.Marshal() + if err == nil { + copy(payload[recvDeltaOffset+i:], b) + i++ + if delta.Type == TypeTCCPacketReceivedLargeDelta { + i++ + } + } + } + + if t.Header.Padding { + payload[len(payload)-1] = uint8(t.MarshalSize() - int(t.packetLen())) //nolint:gosec // G115 + } + + return append(header, payload...), nil +} + +// Unmarshal .. +// +//nolint:gocognit,cyclop +func (t *TransportLayerCC) Unmarshal(rawPacket []byte) error { + if len(rawPacket) < (headerLength + ssrcLength) { + return errPacketTooShort + } + + if err := t.Header.Unmarshal(rawPacket); err != nil { + return err + } + + // https://tools.ietf.org/html/rfc4585#page-33 + // header's length + payload's length + totalLength := 4 * (t.Header.Length + 1) + + if totalLength < headerLength+packetChunkOffset { + return errPacketTooShort + } + + if len(rawPacket) < int(totalLength) { + return errPacketTooShort + } + + if t.Header.Type != TypeTransportSpecificFeedback || t.Header.Count != FormatTCC { + return errWrongType + } + + t.SenderSSRC = binary.BigEndian.Uint32(rawPacket[headerLength:]) + t.MediaSSRC = binary.BigEndian.Uint32(rawPacket[headerLength+ssrcLength:]) + t.BaseSequenceNumber = binary.BigEndian.Uint16(rawPacket[headerLength+baseSequenceNumberOffset:]) + t.PacketStatusCount = binary.BigEndian.Uint16(rawPacket[headerLength+packetStatusCountOffset:]) + t.ReferenceTime = get24BitsFromBytes(rawPacket[headerLength+referenceTimeOffset : headerLength+referenceTimeOffset+3]) + t.FbPktCount = rawPacket[headerLength+fbPktCountOffset] + + packetStatusPos := uint16(headerLength + packetChunkOffset) + var processedPacketNum uint16 + for processedPacketNum < t.PacketStatusCount { + if packetStatusPos+packetStatusChunkLength >= totalLength { + return errPacketTooShort + } + typ := getNBitsFromByte(rawPacket[packetStatusPos : packetStatusPos+1][0], 0, 1) + var iPacketStatus PacketStatusChunk + switch typ { + case TypeTCCRunLengthChunk: + packetStatus := &RunLengthChunk{Type: typ} + iPacketStatus = packetStatus + err := packetStatus.Unmarshal(rawPacket[packetStatusPos : packetStatusPos+2]) + if err != nil { + return err + } + + packetNumberToProcess := localMin(t.PacketStatusCount-processedPacketNum, packetStatus.RunLength) + if packetStatus.PacketStatusSymbol == TypeTCCPacketReceivedSmallDelta || + packetStatus.PacketStatusSymbol == TypeTCCPacketReceivedLargeDelta { + for j := uint16(0); j < packetNumberToProcess; j++ { + t.RecvDeltas = append(t.RecvDeltas, &RecvDelta{Type: packetStatus.PacketStatusSymbol}) + } + } + processedPacketNum += packetNumberToProcess + case TypeTCCStatusVectorChunk: + packetStatus := &StatusVectorChunk{Type: typ} + iPacketStatus = packetStatus + err := packetStatus.Unmarshal(rawPacket[packetStatusPos : packetStatusPos+2]) + if err != nil { + return err + } + if packetStatus.SymbolSize == TypeTCCSymbolSizeOneBit { + for j := 0; j < len(packetStatus.SymbolList); j++ { + if packetStatus.SymbolList[j] == TypeTCCPacketReceivedSmallDelta { + t.RecvDeltas = append(t.RecvDeltas, &RecvDelta{Type: TypeTCCPacketReceivedSmallDelta}) + } + } + } + if packetStatus.SymbolSize == TypeTCCSymbolSizeTwoBit { + for j := 0; j < len(packetStatus.SymbolList); j++ { + if packetStatus.SymbolList[j] == TypeTCCPacketReceivedSmallDelta || + packetStatus.SymbolList[j] == TypeTCCPacketReceivedLargeDelta { + t.RecvDeltas = append(t.RecvDeltas, &RecvDelta{Type: packetStatus.SymbolList[j]}) + } + } + } + processedPacketNum += uint16(len(packetStatus.SymbolList)) //nolint:gosec // G115 + } + packetStatusPos += packetStatusChunkLength + t.PacketChunks = append(t.PacketChunks, iPacketStatus) + } + + recvDeltasPos := packetStatusPos + for _, delta := range t.RecvDeltas { + if delta.Type == TypeTCCPacketReceivedSmallDelta { + if recvDeltasPos+1 > totalLength { + return errPacketTooShort + } + err := delta.Unmarshal(rawPacket[recvDeltasPos : recvDeltasPos+1]) + if err != nil { + return err + } + recvDeltasPos++ + } + if delta.Type == TypeTCCPacketReceivedLargeDelta { + if recvDeltasPos+2 > totalLength { + return errPacketTooShort + } + err := delta.Unmarshal(rawPacket[recvDeltasPos : recvDeltasPos+2]) + if err != nil { + return err + } + recvDeltasPos += 2 + } + } + + return nil +} + +// DestinationSSRC returns an array of SSRC values that this packet refers to. +func (t TransportLayerCC) DestinationSSRC() []uint32 { + return []uint32{t.MediaSSRC} +} + +func localMin(x, y uint16) uint16 { + if x < y { + return x + } + + return y +} diff --git a/vendor/github.com/pion/rtcp/transport_layer_nack.go b/vendor/github.com/pion/rtcp/transport_layer_nack.go new file mode 100644 index 0000000..c0e3b8f --- /dev/null +++ b/vendor/github.com/pion/rtcp/transport_layer_nack.go @@ -0,0 +1,187 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package rtcp + +import ( + "encoding/binary" + "fmt" + "math" +) + +// PacketBitmap shouldn't be used like a normal integral, +// so it's type is masked here. Access it with PacketList(). +type PacketBitmap uint16 + +// NackPair is a wire-representation of a collection of +// Lost RTP packets. +type NackPair struct { + // ID of lost packets + PacketID uint16 + + // Bitmask of following lost packets + LostPackets PacketBitmap +} + +// The TransportLayerNack packet informs the encoder about the loss of a transport packet +// IETF RFC 4585, Section 6.2.1 +// https://tools.ietf.org/html/rfc4585#section-6.2.1 +type TransportLayerNack struct { + // SSRC of sender + SenderSSRC uint32 + + // SSRC of the media source + MediaSSRC uint32 + + Nacks []NackPair +} + +// NackPairsFromSequenceNumbers generates a slice of NackPair from a list of SequenceNumbers +// This handles generating the proper values for PacketID/LostPackets. +func NackPairsFromSequenceNumbers(sequenceNumbers []uint16) (pairs []NackPair) { + if len(sequenceNumbers) == 0 { + return []NackPair{} + } + + nackPair := &NackPair{PacketID: sequenceNumbers[0]} + for i := 1; i < len(sequenceNumbers); i++ { + m := sequenceNumbers[i] + + if m-nackPair.PacketID > 16 { + pairs = append(pairs, *nackPair) + nackPair = &NackPair{PacketID: m} + + continue + } + + nackPair.LostPackets |= 1 << (m - nackPair.PacketID - 1) + } + pairs = append(pairs, *nackPair) + + return +} + +// Range calls f sequentially for each sequence number covered by n. +// If f returns false, Range stops the iteration. +func (n *NackPair) Range(f func(seqno uint16) bool) { + more := f(n.PacketID) + if !more { + return + } + + b := n.LostPackets + for i := uint16(0); b != 0; i++ { + if (b & (1 << i)) != 0 { + b &^= (1 << i) + more = f(n.PacketID + i + 1) + if !more { + return + } + } + } +} + +// PacketList returns a list of Nack'd packets that's referenced by a NackPair. +func (n *NackPair) PacketList() []uint16 { + out := make([]uint16, 0, 17) + n.Range(func(seqno uint16) bool { + out = append(out, seqno) + + return true + }) + + return out +} + +const ( + tlnLength = 2 + nackOffset = 8 +) + +// Marshal encodes the TransportLayerNack in binary. +func (p TransportLayerNack) Marshal() ([]byte, error) { + if len(p.Nacks)+tlnLength > math.MaxUint8 { + return nil, errTooManyReports + } + + rawPacket := make([]byte, nackOffset+(len(p.Nacks)*4)) + binary.BigEndian.PutUint32(rawPacket, p.SenderSSRC) + binary.BigEndian.PutUint32(rawPacket[4:], p.MediaSSRC) + for i := 0; i < len(p.Nacks); i++ { + binary.BigEndian.PutUint16(rawPacket[nackOffset+(4*i):], p.Nacks[i].PacketID) + binary.BigEndian.PutUint16(rawPacket[nackOffset+(4*i)+2:], uint16(p.Nacks[i].LostPackets)) + } + h := p.Header() + hData, err := h.Marshal() + if err != nil { + return nil, err + } + + return append(hData, rawPacket...), nil +} + +// Unmarshal decodes the TransportLayerNack from binary. +func (p *TransportLayerNack) Unmarshal(rawPacket []byte) error { + if len(rawPacket) < (headerLength + ssrcLength) { + return errPacketTooShort + } + + var header Header + if err := header.Unmarshal(rawPacket); err != nil { + return err + } + + if len(rawPacket) < (headerLength + int(4*header.Length)) { + return errPacketTooShort + } + + if header.Type != TypeTransportSpecificFeedback || header.Count != FormatTLN { + return errWrongType + } + + // The FCI field MUST contain at least one and MAY contain more than one Generic NACK + if 4*header.Length <= nackOffset { + return errBadLength + } + + p.SenderSSRC = binary.BigEndian.Uint32(rawPacket[headerLength:]) + p.MediaSSRC = binary.BigEndian.Uint32(rawPacket[headerLength+ssrcLength:]) + for i := headerLength + nackOffset; i < (headerLength + int(header.Length*4)); i += 4 { + p.Nacks = append(p.Nacks, NackPair{ + binary.BigEndian.Uint16(rawPacket[i:]), + PacketBitmap(binary.BigEndian.Uint16(rawPacket[i+2:])), + }) + } + + return nil +} + +// MarshalSize returns the size of the packet once marshaled. +func (p *TransportLayerNack) MarshalSize() int { + return headerLength + nackOffset + (len(p.Nacks) * 4) +} + +// Header returns the Header associated with this packet. +func (p *TransportLayerNack) Header() Header { + return Header{ + Count: FormatTLN, + Type: TypeTransportSpecificFeedback, + Length: uint16((p.MarshalSize() / 4) - 1), //nolint:gosec // G115 + } +} + +func (p TransportLayerNack) String() string { + out := fmt.Sprintf("TransportLayerNack from %x\n", p.SenderSSRC) + out += fmt.Sprintf("\tMedia Ssrc %x\n", p.MediaSSRC) + out += "\tID\tLostPackets\n" + for _, i := range p.Nacks { + out += fmt.Sprintf("\t%d\t%b\n", i.PacketID, i.LostPackets) + } + + return out +} + +// DestinationSSRC returns an array of SSRC values that this packet refers to. +func (p *TransportLayerNack) DestinationSSRC() []uint32 { + return []uint32{p.MediaSSRC} +} diff --git a/vendor/github.com/pion/rtcp/util.go b/vendor/github.com/pion/rtcp/util.go new file mode 100644 index 0000000..705b290 --- /dev/null +++ b/vendor/github.com/pion/rtcp/util.go @@ -0,0 +1,43 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package rtcp + +// getPadding Returns the padding required to make the length a multiple of 4. +func getPadding(packetLen int) int { + if packetLen%4 == 0 { + return 0 + } + + return 4 - (packetLen % 4) +} + +// setNBitsOfUint16 will truncate the value to size, left-shift to startIndex position and set. +func setNBitsOfUint16(src, size, startIndex, val uint16) (uint16, error) { + if startIndex+size > 16 { + return 0, errInvalidSizeOrStartIndex + } + + // truncate val to size bits + val &= (1 << size) - 1 + + return src | (val << (16 - size - startIndex)), nil +} + +// appendBit32 will left-shift and append n bits of val. +func appendNBitsToUint32(src, n, val uint32) uint32 { + return (src << n) | (val & (0xFFFFFFFF >> (32 - n))) +} + +// getNBit get n bits from 1 byte, begin with a position. +func getNBitsFromByte(b byte, begin, n uint16) uint16 { + endShift := 8 - (begin + n) + mask := (0xFF >> begin) & uint8(0xFF<> endShift +} + +// get24BitFromBytes get 24bits from `[3]byte` slice. +func get24BitsFromBytes(b []byte) uint32 { + return uint32(b[0])<<16 + uint32(b[1])<<8 + uint32(b[2]) +} diff --git a/vendor/github.com/pion/rtp/codecs/av1_depacketizer.go b/vendor/github.com/pion/rtp/codecs/av1_depacketizer.go new file mode 100644 index 0000000..8b04e49 --- /dev/null +++ b/vendor/github.com/pion/rtp/codecs/av1_depacketizer.go @@ -0,0 +1,197 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package codecs + +import ( + "fmt" + + "github.com/pion/rtp/codecs/av1/obu" +) + +// AV1Depacketizer is a AV1 RTP Packet depacketizer. +// Reads AV1 packets from a RTP stream and outputs AV1 low overhead bitstream. +type AV1Depacketizer struct { + // holds the fragmented OBU from the previous packet. + buffer []byte + + // Z, Y, N are flags from the AV1 Aggregation Header. + Z, Y, N bool + + videoDepacketizer +} + +func (d *AV1Depacketizer) appendOBUWithCalculatedSize( + buff []byte, + obuHeader *obu.Header, + obuBuffer []byte, + payloadOffset int, +) []byte { + obuPayloadSize := len(obuBuffer) - payloadOffset + buff = append(buff, obuHeader.Marshal()...) + buff = append(buff, obu.WriteToLeb128(uint(obuPayloadSize))...) // nolint: gosec // G104 + buff = append(buff, obuBuffer[payloadOffset:]...) + + return buff +} + +// Unmarshal parses an AV1 RTP payload into its constituent OBUs stream with obu_size_field, +// It assumes that the payload is in order (e.g. the caller is responsible for reordering RTP packets). +// If the last OBU in the payload is fragmented, it will be stored in the buffer until the +// it is completed. +// +//nolint:gocognit,cyclop +func (d *AV1Depacketizer) Unmarshal(payload []byte) (buff []byte, err error) { + buff = make([]byte, 0) + + if len(payload) <= 1 { + return nil, errShortPacket + } + + // |Z|Y| W |N|-|-|-| + obuZ := (av1ZMask & payload[0]) != 0 // Z + obuY := (av1YMask & payload[0]) != 0 // Y + obuCount := (av1WMask & payload[0]) >> 4 // W + obuN := (av1NMask & payload[0]) != 0 // N + d.Z = obuZ + d.Y = obuY + d.N = obuN + if obuN { + d.buffer = nil + } + + // Make sure we clear the buffer if Z is not 0. + if !obuZ && len(d.buffer) > 0 { + d.buffer = nil + } + + obuOffset := 0 + for offset := 1; offset < len(payload); obuOffset++ { + isFirst := obuOffset == 0 + isLast := obuCount != 0 && obuOffset == int(obuCount)-1 + + // https://aomediacodec.github.io/av1-rtp-spec/#44-av1-aggregation-header + // W: two bit field that describes the number of OBU elements in the packet. + // This field MUST be set equal to 0 or equal to the number of OBU elements contained in the packet. + // If set to 0, each OBU element MUST be preceded by a length field. If not set to 0 + // (i.e., W = 1, 2 or 3) the last OBU element MUST NOT be preceded by a length field. + var lengthField, n int + if obuCount == 0 || !isLast { + obuSizeVal, nVal, err := obu.ReadLeb128(payload[offset:]) //nolint:gosec //guard from loop + lengthField = int(obuSizeVal) //nolint:gosec // G115 false positive + n = int(nVal) //nolint:gosec // G115 false positive + if err != nil { + return nil, err + } + + offset += n + if obuCount == 0 && offset+lengthField == len(payload) { + isLast = true + } + } else { + // https://aomediacodec.github.io/av1-rtp-spec/#44-av1-aggregation-header + // Length of the last OBU element = + // length of the RTP payload + // - length of aggregation header + // - length of previous OBU elements including length fields + lengthField = len(payload) - offset + } + + if offset+lengthField > len(payload) { + return nil, fmt.Errorf( + "%w: OBU size %d + %d offset exceeds payload length %d", + errShortPacket, lengthField, offset, len(payload), + ) + } + + var obuBuffer []byte + if isFirst && obuZ { + // We lost the first fragment of the OBU + // We drop the buffer and continue + if len(d.buffer) == 0 { + if isLast { + break + } + offset += lengthField + + continue + } + + obuBuffer = make([]byte, len(d.buffer)+lengthField) + + copy(obuBuffer, d.buffer) + copy(obuBuffer[len(d.buffer):], payload[offset:offset+lengthField]) + d.buffer = nil + } else { + obuBuffer = payload[offset : offset+lengthField] + } + offset += lengthField + + if isLast && obuY { + d.buffer = obuBuffer + + break + } + + if len(obuBuffer) == 0 { + continue + } + + obuHeader, err := obu.ParseOBUHeader(obuBuffer) + if err != nil { + return nil, err + } + + // The temporal delimiter OBU, if present, SHOULD be removed when transmitting, + // and MUST be ignored by receivers. Tile list OBUs are not supported. + // They SHOULD be removed when transmitted, and MUST be ignored by receivers. + // https://aomediacodec.github.io/av1-rtp-spec/#5-packetization-rules + if obuHeader.Type == obu.OBUTemporalDelimiter || obuHeader.Type == obu.OBUTileList { + continue + } + + // obu_has_size_field should be set to 0 for AV1 RTP packets. + // But we still check it to be sure, if we get obu size we just use it, instead of calculating it. + if obuHeader.HasSizeField { + obuSize, n, err := obu.ReadLeb128(obuBuffer[obuHeader.Size():]) + if err != nil { + return nil, err + } + + // Ignore obu_size_field if it is present and doesn't match the calculated size. + sizeFromOBUSize := obuHeader.Size() + int(obuSize) + int(n) //nolint:gosec + if lengthField != sizeFromOBUSize { + payloadOffset := obuHeader.Size() + int(n) //nolint:gosec // n is small, LEB128. + buff = d.appendOBUWithCalculatedSize(buff, obuHeader, obuBuffer, payloadOffset) + } else { + buff = append(buff, obuBuffer...) + } + } else { + obuHeader.HasSizeField = true + buff = d.appendOBUWithCalculatedSize(buff, obuHeader, obuBuffer, obuHeader.Size()) + } + + if isLast { + break + } + } + + if obuCount != 0 && obuOffset != int(obuCount-1) { + return nil, fmt.Errorf( + "%w: OBU count %d does not match number of OBUs %d", + errShortPacket, obuCount, obuOffset, + ) + } + + return buff, nil +} + +// IsPartitionHead returns true if Z in the AV1 Aggregation Header +// is set to 0. +func (d *AV1Depacketizer) IsPartitionHead(payload []byte) bool { + if len(payload) == 0 { + return false + } + + return (payload[0] & av1ZMask) == 0 +} diff --git a/vendor/github.com/pion/rtp/codecs/av1_packet.go b/vendor/github.com/pion/rtp/codecs/av1_packet.go new file mode 100644 index 0000000..797d4fe --- /dev/null +++ b/vendor/github.com/pion/rtp/codecs/av1_packet.go @@ -0,0 +1,386 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package codecs + +import ( + "github.com/pion/rtp/codecs/av1/obu" +) + +const ( + av1ZMask = byte(0b10000000) + av1ZBitshift = 7 + + av1YMask = byte(0b01000000) + av1YBitshift = 6 + + av1WMask = byte(0b00110000) + av1WBitshift = 4 + + av1NMask = byte(0b00001000) + av1NBitshift = 3 +) + +// AV1Payloader payloads AV1 packets. +type AV1Payloader struct{} + +// Payload implements AV1 RTP payloader. +// Reads from a open_bitstream_unit (OBU) framing stream as defined in +// 5.3. https://aomediacodec.github.io/av1-spec/av1-spec.pdf#page=39 +// Returns AV1 RTP packets https://aomediacodec.github.io/av1-rtp-spec/ +// The payload is fragmented into multiple packets, each packet is a valid AV1 RTP payload. +// nolint:cyclop +func (p *AV1Payloader) Payload(mtu uint16, payload []byte) (payloads [][]byte) { + // 2 is the minimum MTU for AV1 (aggregate header + 1 byte) + if mtu <= 1 || len(payload) == 0 { + return payloads + } + + // We maximize the use of the W field in the AV1 aggregation header + // to minimize the need for explicit length fields for each OBU. + // To achieve this, we temporarily hold the OBU payload before adding it to a packet. + // Since we can't determine in advance whether the next OBU should be included in the same packet + // or start a new one, we also can't know ahead of time if an OBU is the last in the current packet. + var currentOBUPayload []byte + var currentPacketOBUHeader *obu.ExtensionHeader + obusInPacket := 0 + newSequence := false + startWithNewPacket := false + + for offset := 0; offset < len(payload); { + obuHeader, err := obu.ParseOBUHeader(payload[offset:]) + if err != nil { + break + } + + offset += obuHeader.Size() + // if ( obu_has_size_field ) { + // obu_size leb128() + // } else { + // obu_size = sz - 1 - obu_extension_flag + // } + var obuSize int + if obuHeader.HasSizeField { + obuSizeValue, n, err := obu.ReadLeb128(payload[offset:]) + if err != nil { + break + } + + offset += int(n) //nolint:gosec // G115, leb128 size is a signle digit + obuSize = int(obuSizeValue) //nolint:gosec // G115, Leb128 is capped at 4 bytes + } else { + obuSize = len(payload) - offset + } + + // Each RTP packet MUST NOT contain OBUs that belong to different temporal units. + // If a sequence header OBU is present in an RTP packet, then it SHOULD be the first OBU in the packet. + // https://aomediacodec.github.io/av1-rtp-spec/#5-packetization-rules + needNewPacket := obuHeader.Type == obu.OBUTemporalDelimiter || obuHeader.Type == obu.OBUSequenceHeader + // If more than one OBU contained in an RTP packet has an OBU extension header, + // then the values of the temporal_id and spatial_id MUST be the same in all such OBUs in the RTP packet. + if !needNewPacket && obuHeader.ExtensionHeader != nil && currentPacketOBUHeader != nil { + needNewPacket = obuHeader.ExtensionHeader.SpatialID != currentPacketOBUHeader.SpatialID || + obuHeader.ExtensionHeader.TemporalID != currentPacketOBUHeader.TemporalID + } + + if obuHeader.ExtensionHeader != nil { + currentPacketOBUHeader = obuHeader.ExtensionHeader + } + + if obuSize > len(payload)-offset { + break + } + + if len(currentOBUPayload) > 0 { + payloads, obusInPacket = p.appendOBUPayload( + payloads, + currentOBUPayload, + newSequence, + needNewPacket, + startWithNewPacket, + int(mtu), + obusInPacket, + ) + currentOBUPayload = nil + startWithNewPacket = needNewPacket + + if needNewPacket { + newSequence = false + currentPacketOBUHeader = nil + } + } + + // The temporal delimiter OBU, if present, SHOULD be removed when transmitting, + // and MUST be ignored by receivers. Tile list OBUs are not supported. + // They SHOULD be removed when transmitted, and MUST be ignored by receivers. + // https://aomediacodec.github.io/av1-rtp-spec/#5-packetization-rules + if obuHeader.Type == obu.OBUTileList || obuHeader.Type == obu.OBUTemporalDelimiter { + offset += obuSize + + continue + } + + currentOBUPayload = make([]byte, obuSize+obuHeader.Size()) + // The AV1 specification allows OBUs to have an optional size field called obu_size + // (also leb128 encoded), signaled by the obu_has_size_field flag in the OBU header. + // To minimize overhead, the obu_has_size_field flag SHOULD be set to zero in all OBUs. + // https://aomediacodec.github.io/av1-rtp-spec/#45-payload-structure + obuHeader.HasSizeField = false + copy(currentOBUPayload, obuHeader.Marshal()) + //nolint:gosec // G115 we validate the size of the payload + copy(currentOBUPayload[obuHeader.Size():], payload[offset:offset+obuSize]) + offset += obuSize + newSequence = obuHeader.Type == obu.OBUSequenceHeader + } + + if len(currentOBUPayload) > 0 { + payloads, _ = p.appendOBUPayload( + payloads, + currentOBUPayload, + newSequence, + true, + startWithNewPacket, + int(mtu), + obusInPacket, + ) + } + + return payloads +} + +//nolint:cyclop +func (p *AV1Payloader) appendOBUPayload( + payloads [][]byte, + obuPayload []byte, + isNewVideoSequence, isLast, startWithNewPacket bool, + mtu, currentOBUCount int, +) ([][]byte, int) { + currentPayload := len(payloads) - 1 + freeSpace := 0 + if currentPayload >= 0 { + freeSpace = mtu - len(payloads[currentPayload]) + } + + if currentPayload < 0 || freeSpace <= 0 || startWithNewPacket { + payload := make([]byte, 1, mtu) + if isNewVideoSequence { + payload[0] |= 1 << av1NBitshift + } + + payloads = append(payloads, payload) + currentPayload = len(payloads) - 1 + // MTU - aggregation header + freeSpace = mtu - 1 + currentOBUCount = 0 + } + + remaining := len(obuPayload) + // How much to write to the current packet. + toWrite := min(remaining, freeSpace) + + // W: two bit field that describes the number of OBU elements in the packet. + // This field MUST be set equal to 0 or equal to the number of OBU elements contained in the packet. + // If set to 0, each OBU element MUST be preceded by a length field. If not set to 0 (i.e., W = 1, 2 or 3) + // the last OBU element MUST NOT be preceded by a length field. + // https://aomediacodec.github.io/av1-rtp-spec/#44-av1-aggregation-header + shouldUseWField := (isLast || toWrite >= freeSpace) && currentOBUCount < 3 + switch { + case shouldUseWField: + payloads[currentPayload][0] |= byte((currentOBUCount+1)<= 2: + // 2 bytes is the minimum size for OBUs with length field. + // [1 byte for the length field] [1 byte for the OBU] + //nolint:gosec // G115 false positive + toWrite = p.computeWriteSize(toWrite, freeSpace) + lengthField := obu.WriteToLeb128(uint(toWrite)) //nolint:gosec // G115 false positive + payloads[currentPayload] = append(payloads[currentPayload], lengthField...) + payloads[currentPayload] = append(payloads[currentPayload], obuPayload[:toWrite]...) + currentOBUCount++ + default: + // If we can't fit any more OBUs in the current packet (only 1 byte left and W=0) + toWrite = 0 + } + + obuPayload = obuPayload[toWrite:] + remaining -= toWrite + + // Handle fragments. + for remaining > 0 { + // New packet with empty aggregation header. + payload := make([]byte, 1, mtu) + payloads = append(payloads, payload) + currentPayload++ + + // Append the Y bit to the previous packet. And Z bit to the current packet. + // If we wrote some bytes to the previous packet. + // Handles an edge case where the previous packet has only one byte remaining, + // while the W field is not used. This results in insufficient space + // for a one-byte length field and a one-byte OBU. + // So we don't write anything to the initial packet. + if toWrite != 0 { + payloads[currentPayload-1][0] |= av1YMask + payloads[currentPayload][0] |= av1ZMask + } + + toWrite = min(remaining, + // MTU - aggregation header + mtu-1) + + // Last OBU in the current packet, Or this whole packet is a fragment. + if isLast || remaining >= mtu-1 { + payloads[currentPayload][0] |= 1 << av1WBitshift + } else { + toWrite = p.computeWriteSize(toWrite, mtu-1) + lengthField := obu.WriteToLeb128(uint(toWrite)) //nolint:gosec // G115 false positive + payloads[currentPayload] = append(payloads[currentPayload], lengthField...) + } + + payloads[currentPayload] = append(payloads[currentPayload], obuPayload[:toWrite]...) + obuPayload = obuPayload[toWrite:] + remaining -= toWrite + currentOBUCount = 1 + } + + return payloads, currentOBUCount +} + +// Measure the maximum write size for a payload with leb128 encoding added. +func (p *AV1Payloader) computeWriteSize(wantToWrite, canWrite int) int { + leb128Size, isAtEdge := p.leb128Size(wantToWrite) + if canWrite >= wantToWrite+leb128Size { + return wantToWrite + } + + // Handle edge case where subtracting one from the leb128 size + // results in a smaller leb128 size that can fit in the remaining space. + if isAtEdge && canWrite >= wantToWrite+leb128Size-1 { + return wantToWrite - 1 + } + + return wantToWrite - leb128Size +} + +func (p *AV1Payloader) leb128Size(leb128 int) (size int, isAtEge bool) { + switch { + case leb128 >= 268435456: // 2^28 + return 5, leb128 == 268435456 + case leb128 >= 2097152: // 2^21 + return 4, leb128 == 2097152 + case leb128 >= 16384: // 2^14 + return 3, leb128 == 16384 + case leb128 >= 128: // 2^7 + return 2, leb128 == 128 + default: + return 1, false + } +} + +// AV1Packet represents a depacketized AV1 RTP Packet +/* +* 0 1 2 3 4 5 6 7 +* +-+-+-+-+-+-+-+-+ +* |Z|Y| W |N|-|-|-| +* +-+-+-+-+-+-+-+-+ +**/ +// https://aomediacodec.github.io/av1-rtp-spec/#44-av1-aggregation-header +// +// Deprecated: Use AV1Depacketizer instead. +type AV1Packet struct { + // Z: MUST be set to 1 if the first OBU element is an + // OBU fragment that is a continuation of an OBU fragment + // from the previous packet, and MUST be set to 0 otherwise. + Z bool + + // Y: MUST be set to 1 if the last OBU element is an OBU fragment + // that will continue in the next packet, and MUST be set to 0 otherwise. + Y bool + + // W: two bit field that describes the number of OBU elements in the packet. + // This field MUST be set equal to 0 or equal to the number of OBU elements + // contained in the packet. If set to 0, each OBU element MUST be preceded by + // a length field. If not set to 0 (i.e., W = 1, 2 or 3) the last OBU element + // MUST NOT be preceded by a length field. Instead, the length of the last OBU + // element contained in the packet can be calculated as follows: + // Length of the last OBU element = + // length of the RTP payload + // - length of aggregation header + // - length of previous OBU elements including length fields + W byte + + // N: MUST be set to 1 if the packet is the first packet of a coded video sequence, and MUST be set to 0 otherwise. + N bool + + // Each AV1 RTP Packet is a collection of OBU Elements. Each OBU Element may be a full OBU, or just a fragment of one. + // AV1Frame provides the tools to construct a collection of OBUs from a collection of OBU Elements + OBUElements [][]byte + + // zeroAllocation prevents populating the OBUElements field + zeroAllocation bool +} + +// Unmarshal parses the passed byte slice and stores the result in the AV1Packet this method is called upon. +func (p *AV1Packet) Unmarshal(payload []byte) ([]byte, error) { + if payload == nil { + return nil, errNilPacket + } else if len(payload) < 2 { + return nil, errShortPacket + } + + p.Z = ((payload[0] & av1ZMask) >> av1ZBitshift) != 0 + p.Y = ((payload[0] & av1YMask) >> av1YBitshift) != 0 + p.N = ((payload[0] & av1NMask) >> av1NBitshift) != 0 + p.W = (payload[0] & av1WMask) >> av1WBitshift + + if p.Z && p.N { + return nil, errIsKeyframeAndFragment + } + + if !p.zeroAllocation { + obuElements, err := p.parseBody(payload[1:]) + if err != nil { + return nil, err + } + p.OBUElements = obuElements + } + + return payload[1:], nil +} + +func (p *AV1Packet) parseBody(payload []byte) ([][]byte, error) { + if p.OBUElements != nil { + return p.OBUElements, nil + } + + obuElements := [][]byte{} + + var obuElementLength, bytesRead uint + currentIndex := uint(0) + for i := 1; ; i++ { + if currentIndex == uint(len(payload)) { + break + } + + // If W bit is set the last OBU Element will have no length header + if byte(i) == p.W { + bytesRead = 0 + obuElementLength = uint(len(payload)) - currentIndex + } else { + var err error + obuElementLength, bytesRead, err = obu.ReadLeb128(payload[currentIndex:]) + if err != nil { + return nil, err + } + } + + currentIndex += bytesRead + if uint(len(payload)) < currentIndex+obuElementLength { + return nil, errShortPacket + } + obuElements = append(obuElements, payload[currentIndex:currentIndex+obuElementLength]) + currentIndex += obuElementLength + } + + return obuElements, nil +} diff --git a/vendor/github.com/pion/rtp/codecs/codecs.go b/vendor/github.com/pion/rtp/codecs/codecs.go new file mode 100644 index 0000000..42caf1e --- /dev/null +++ b/vendor/github.com/pion/rtp/codecs/codecs.go @@ -0,0 +1,5 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +// Package codecs implements codec specific RTP payloader/depayloaders +package codecs diff --git a/vendor/github.com/pion/rtp/codecs/common.go b/vendor/github.com/pion/rtp/codecs/common.go new file mode 100644 index 0000000..41ff6f3 --- /dev/null +++ b/vendor/github.com/pion/rtp/codecs/common.go @@ -0,0 +1,40 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package codecs + +func minInt(a, b int) int { + if a < b { + return a + } + + return b +} + +// audioDepacketizer is a mixin for audio codec depacketizers. +type audioDepacketizer struct{} + +func (d *audioDepacketizer) IsPartitionTail(_ bool, _ []byte) bool { + return true +} + +func (d *audioDepacketizer) IsPartitionHead(_ []byte) bool { + return true +} + +// videoDepacketizer is a mixin for video codec depacketizers. +type videoDepacketizer struct { + zeroAllocation bool +} + +func (d *videoDepacketizer) IsPartitionTail(marker bool, _ []byte) bool { + return marker +} + +// SetZeroAllocation enables Zero Allocation mode for the depacketizer +// By default the Depacketizers will allocate as they parse. These allocations +// are needed for Metadata and other optional values. If you don't need this information +// enabling SetZeroAllocation gives you higher performance at a reduced feature set. +func (d *videoDepacketizer) SetZeroAllocation(zeroAllocation bool) { + d.zeroAllocation = zeroAllocation +} diff --git a/vendor/github.com/pion/rtp/codecs/error.go b/vendor/github.com/pion/rtp/codecs/error.go new file mode 100644 index 0000000..81a26e0 --- /dev/null +++ b/vendor/github.com/pion/rtp/codecs/error.go @@ -0,0 +1,19 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package codecs + +import "errors" + +var ( + errShortPacket = errors.New("packet is not large enough") + errNilPacket = errors.New("invalid nil packet") + errTooManyPDiff = errors.New("too many PDiff") + errTooManySpatialLayers = errors.New("too many spatial layers") + errUnhandledNALUType = errors.New("NALU Type is unhandled") + + // AV1 Errors. + errIsKeyframeAndFragment = errors.New( + "bits Z and N are set. Not possible to have OBU be tail fragment and be keyframe", + ) +) diff --git a/vendor/github.com/pion/rtp/codecs/g711_packet.go b/vendor/github.com/pion/rtp/codecs/g711_packet.go new file mode 100644 index 0000000..61022c0 --- /dev/null +++ b/vendor/github.com/pion/rtp/codecs/g711_packet.go @@ -0,0 +1,26 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package codecs + +// G711Payloader payloads G711 packets. +type G711Payloader struct{} + +// Payload fragments an G711 packet across one or more byte arrays. +func (p *G711Payloader) Payload(mtu uint16, payload []byte) [][]byte { + var out [][]byte + if payload == nil || mtu == 0 { + return out + } + + for len(payload) > int(mtu) { + o := make([]byte, mtu) + copy(o, payload[:mtu]) + payload = payload[mtu:] + out = append(out, o) + } + o := make([]byte, len(payload)) + copy(o, payload) + + return append(out, o) +} diff --git a/vendor/github.com/pion/rtp/codecs/g722_packet.go b/vendor/github.com/pion/rtp/codecs/g722_packet.go new file mode 100644 index 0000000..5569df8 --- /dev/null +++ b/vendor/github.com/pion/rtp/codecs/g722_packet.go @@ -0,0 +1,26 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package codecs + +// G722Payloader payloads G722 packets. +type G722Payloader struct{} + +// Payload fragments an G722 packet across one or more byte arrays. +func (p *G722Payloader) Payload(mtu uint16, payload []byte) [][]byte { + var out [][]byte + if payload == nil || mtu == 0 { + return out + } + + for len(payload) > int(mtu) { + o := make([]byte, mtu) + copy(o, payload[:mtu]) + payload = payload[mtu:] + out = append(out, o) + } + o := make([]byte, len(payload)) + copy(o, payload) + + return append(out, o) +} diff --git a/vendor/github.com/pion/rtp/codecs/h264_packet.go b/vendor/github.com/pion/rtp/codecs/h264_packet.go new file mode 100644 index 0000000..c011d4f --- /dev/null +++ b/vendor/github.com/pion/rtp/codecs/h264_packet.go @@ -0,0 +1,340 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package codecs + +import ( + "bytes" + "encoding/binary" + "fmt" +) + +// H264Payloader payloads H264 packets. +type H264Payloader struct { + spsNalu, ppsNalu []byte + DisableStapA bool +} + +const ( + stapaNALUType = 24 + fuaNALUType = 28 + fubNALUType = 29 + spsNALUType = 7 + ppsNALUType = 8 + audNALUType = 9 + fillerNALUType = 12 + + fuaHeaderSize = 2 + stapaHeaderSize = 1 + stapaNALULengthSize = 2 + + naluTypeBitmask = 0x1F + naluRefIdcBitmask = 0x60 + fuStartBitmask = 0x80 + fuEndBitmask = 0x40 + + outputStapAHeader = 0x78 +) + +// nolint:gochecknoglobals +var ( + naluStartCode = []byte{0x00, 0x00, 0x01} + annexbNALUStartCode = []byte{0x00, 0x00, 0x00, 0x01} +) + +func emitNalus(nals []byte, emit func([]byte)) { + // look for 3-byte NALU start code + start := bytes.Index(nals, naluStartCode) + offset := 3 + + if start == -1 { + // no start code, emit the whole buffer + emit(nals) + + return + } + + length := len(nals) + + for start < length { + // look for the next NALU start (end of this NALU) + end := bytes.Index(nals[start+offset:], naluStartCode) + if end == -1 { + // no more NALUs, emit the rest of the buffer + emit(nals[start+offset:]) + + break + } + + // next NALU start + nextStart := start + offset + end + + // check if the next NALU is actually a 4-byte start code + endIs4Byte := nals[nextStart-1] == 0 + if endIs4Byte { + nextStart-- + } + + emit(nals[start+offset : nextStart]) + + start = nextStart + + if endIs4Byte { + offset = 4 + } else { + offset = 3 + } + } +} + +// Payload fragments a H264 packet across one or more byte arrays. +func (p *H264Payloader) Payload(mtu uint16, payload []byte) [][]byte { //nolint:cyclop + var payloads [][]byte + if len(payload) == 0 { + return payloads + } + + emitNalus(payload, func(nalu []byte) { + if len(nalu) == 0 { + return + } + + naluType := nalu[0] & naluTypeBitmask + naluRefIdc := nalu[0] & naluRefIdcBitmask + + switch { + case naluType == audNALUType || naluType == fillerNALUType: + return + case naluType == spsNALUType: + if !p.DisableStapA { + p.spsNalu = nalu + + return + } + case naluType == ppsNALUType: + if !p.DisableStapA { + p.ppsNalu = nalu + + return + } + case !p.DisableStapA && p.spsNalu != nil && p.ppsNalu != nil: + // Pack current NALU with SPS and PPS as STAP-A + spsLen := make([]byte, 2) + binary.BigEndian.PutUint16(spsLen, uint16(len(p.spsNalu))) // nolint: gosec // G115 + + ppsLen := make([]byte, 2) + binary.BigEndian.PutUint16(ppsLen, uint16(len(p.ppsNalu))) // nolint: gosec // G115 + + stapANalu := []byte{outputStapAHeader} + stapANalu = append(stapANalu, spsLen...) + stapANalu = append(stapANalu, p.spsNalu...) + stapANalu = append(stapANalu, ppsLen...) + stapANalu = append(stapANalu, p.ppsNalu...) + if len(stapANalu) <= int(mtu) { + out := make([]byte, len(stapANalu)) + copy(out, stapANalu) + payloads = append(payloads, out) + } + + p.spsNalu = nil + p.ppsNalu = nil + } + + // Single NALU + if len(nalu) <= int(mtu) { + out := make([]byte, len(nalu)) + copy(out, nalu) + payloads = append(payloads, out) + + return + } + + // FU-A + maxFragmentSize := int(mtu) - fuaHeaderSize + + // The FU payload consists of fragments of the payload of the fragmented + // NAL unit so that if the fragmentation unit payloads of consecutive + // FUs are sequentially concatenated, the payload of the fragmented NAL + // unit can be reconstructed. The NAL unit type octet of the fragmented + // NAL unit is not included as such in the fragmentation unit payload, + // but rather the information of the NAL unit type octet of the + // fragmented NAL unit is conveyed in the F and NRI fields of the FU + // indicator octet of the fragmentation unit and in the type field of + // the FU header. An FU payload MAY have any number of octets and MAY + // be empty. + + // According to the RFC, the first octet is skipped due to redundant information + naluIndex := 1 + naluLength := len(nalu) - naluIndex + naluRemaining := naluLength + + if minInt(maxFragmentSize, naluRemaining) <= 0 { + return + } + + for naluRemaining > 0 { + currentFragmentSize := minInt(maxFragmentSize, naluRemaining) + out := make([]byte, fuaHeaderSize+currentFragmentSize) + + // +---------------+ + // |0|1|2|3|4|5|6|7| + // +-+-+-+-+-+-+-+-+ + // |F|NRI| Type | + // +---------------+ + out[0] = fuaNALUType + out[0] |= naluRefIdc + + // +---------------+ + // |0|1|2|3|4|5|6|7| + // +-+-+-+-+-+-+-+-+ + // |S|E|R| Type | + // +---------------+ + + out[1] = naluType + if naluRemaining == naluLength { + // Set start bit + out[1] |= 1 << 7 + } else if naluRemaining-currentFragmentSize == 0 { + // Set end bit + out[1] |= 1 << 6 + } + + copy(out[fuaHeaderSize:], nalu[naluIndex:naluIndex+currentFragmentSize]) + payloads = append(payloads, out) + + naluRemaining -= currentFragmentSize + naluIndex += currentFragmentSize + } + }) + + return payloads +} + +// H264Packet represents the H264 header that is stored in the payload of an RTP Packet. +type H264Packet struct { + IsAVC bool + fuaBuffer []byte + + videoDepacketizer +} + +func (p *H264Packet) doPackaging(buf, nalu []byte) []byte { + if p.IsAVC { + buf = binary.BigEndian.AppendUint32(buf, uint32(len(nalu))) // nolint: gosec // G115 false positive + buf = append(buf, nalu...) + + return buf + } + + buf = append(buf, annexbNALUStartCode...) + buf = append(buf, nalu...) + + return buf +} + +// IsDetectedFinalPacketInSequence returns true of the packet passed in has the +// marker bit set indicated the end of a packet sequence. +func (p *H264Packet) IsDetectedFinalPacketInSequence(rtpPacketMarketBit bool) bool { + return rtpPacketMarketBit +} + +// Unmarshal parses the passed byte slice and stores the result in the H264Packet this method is called upon. +func (p *H264Packet) Unmarshal(payload []byte) ([]byte, error) { + if p.zeroAllocation { + return payload, nil + } + + return p.parseBody(payload) +} + +func (p *H264Packet) parseBody(payload []byte) ([]byte, error) { //nolint:cyclop + if len(payload) == 0 { + return nil, fmt.Errorf("%w: %d <=0", errShortPacket, len(payload)) + } + + // NALU Types + // https://tools.ietf.org/html/rfc6184#section-5.4 + naluType := payload[0] & naluTypeBitmask + switch { + case naluType > 0 && naluType < 24: + return p.doPackaging(nil, payload), nil + + case naluType == stapaNALUType: + currOffset := int(stapaHeaderSize) + result := []byte{} + for currOffset < len(payload) { + naluSizeBytes := payload[currOffset:] + if len(naluSizeBytes) < stapaNALULengthSize { + break + } + naluSize := int(binary.BigEndian.Uint16(naluSizeBytes)) + currOffset += stapaNALULengthSize + + if len(payload) < currOffset+naluSize { + return nil, fmt.Errorf( + "%w STAP-A declared size(%d) is larger than buffer(%d)", + errShortPacket, + naluSize, + len(payload)-currOffset, + ) + } + + result = p.doPackaging(result, payload[currOffset:currOffset+naluSize]) + currOffset += naluSize + } + + return result, nil + + case naluType == fuaNALUType: + if len(payload) < fuaHeaderSize { + return nil, errShortPacket + } + + if p.fuaBuffer == nil { + p.fuaBuffer = []byte{} + } + + p.fuaBuffer = append(p.fuaBuffer, payload[fuaHeaderSize:]...) + + if payload[1]&fuEndBitmask != 0 { + naluRefIdc := payload[0] & naluRefIdcBitmask + fragmentedNaluType := payload[1] & naluTypeBitmask + + nalu := append([]byte{}, naluRefIdc|fragmentedNaluType) + nalu = append(nalu, p.fuaBuffer...) + p.fuaBuffer = nil + + return p.doPackaging(nil, nalu), nil + } + + return []byte{}, nil + } + + return nil, fmt.Errorf("%w: %d", errUnhandledNALUType, naluType) +} + +// H264PartitionHeadChecker checks H264 partition head. +// +// Deprecated: replaced by H264Packet.IsPartitionHead(). +type H264PartitionHeadChecker struct{} + +// IsPartitionHead checks if this is the head of a packetized nalu stream. +// +// Deprecated: replaced by H264Packet.IsPartitionHead(). +func (*H264PartitionHeadChecker) IsPartitionHead(packet []byte) bool { + return (&H264Packet{}).IsPartitionHead(packet) +} + +// IsPartitionHead checks if this is the head of a packetized nalu stream. +func (*H264Packet) IsPartitionHead(payload []byte) bool { + if len(payload) < 2 { + return false + } + + if payload[0]&naluTypeBitmask == fuaNALUType || + payload[0]&naluTypeBitmask == fubNALUType { + return payload[1]&fuStartBitmask != 0 + } + + return true +} diff --git a/vendor/github.com/pion/rtp/codecs/h265_packet.go b/vendor/github.com/pion/rtp/codecs/h265_packet.go new file mode 100644 index 0000000..38a8bbb --- /dev/null +++ b/vendor/github.com/pion/rtp/codecs/h265_packet.go @@ -0,0 +1,1812 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package codecs + +import ( + "encoding/binary" + "errors" + "fmt" +) + +// +// Errors +// + +var ( + errH265CorruptedPacket = errors.New("corrupted h265 packet") + errInvalidH265PacketType = errors.New("invalid h265 packet type") + errMissingDonl = errors.New("expecting all aggregated packets to have DONL values") + errDonlOutOfOrder = errors.New("expecting aggregation packets to have increasing DONL values") + errDondTooLarge = errors.New("expecint DONL difference between packets to be no more than 256") + errExpectFragmentationStartUnit = errors.New("expecting a fragmentation start unit") + errH265PACIPHESTooLong = errors.New("expecting a PHES field shorter than 32 bytes") +) + +// +// Network Abstraction Unit Header implementation +// + +const ( + // sizeof(uint16). + h265NaluHeaderSize = 2 + // sizeof(uint16). + h265NaluDonlSize = 2 + // https://datatracker.ietf.org/doc/html/rfc7798#section-4.4.2 + h265NaluAggregationPacketType = 48 + // https://datatracker.ietf.org/doc/html/rfc7798#section-4.4.3 + h265NaluFragmentationUnitType = 49 + // https://datatracker.ietf.org/doc/html/rfc7798#section-4.4.4 + h265NaluPACIPacketType = 50 + h265AggregatedPacketMaxSize = ^uint16(0) + h265AggregatedPacketLengthSize = 2 +) + +// H265NALUHeader is a H265 NAL Unit Header. +// https://datatracker.ietf.org/doc/html/rfc7798#section-1.1.4 +// +// +---------------+---------------+ +// |0|1|2|3|4|5|6|7|0|1|2|3|4|5|6|7| +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// |F| Type | LayerID | TID | +// +-------------+-----------------+ +// +// . +type H265NALUHeader uint16 + +func newH265NALUHeader(highByte, lowByte uint8) H265NALUHeader { + return H265NALUHeader((uint16(highByte) << 8) | uint16(lowByte)) +} + +// F is the forbidden bit, should always be 0. +func (h H265NALUHeader) F() bool { + return (uint16(h) >> 15) != 0 +} + +// Type of NAL Unit. +func (h H265NALUHeader) Type() uint8 { + // 01111110 00000000 + const mask = 0b01111110 << 8 + + return uint8((uint16(h) & mask) >> (8 + 1)) // nolint: gosec // G115 false positive +} + +// IsTypeVCLUnit returns whether or not the NAL Unit type is a VCL NAL unit. +func (h H265NALUHeader) IsTypeVCLUnit() bool { + // Type is coded on 6 bits + const msbMask = 0b00100000 + + return (h.Type() & msbMask) == 0 +} + +// LayerID should always be 0 in non-3D HEVC context. +func (h H265NALUHeader) LayerID() uint8 { + // 00000001 11111000 + const mask = (0b00000001 << 8) | 0b11111000 + + return uint8((uint16(h) & mask) >> 3) // nolint: gosec // G115 false positive +} + +// TID is the temporal identifier of the NAL unit +1. +func (h H265NALUHeader) TID() uint8 { + const mask = 0b00000111 + + return uint8(uint16(h) & mask) // nolint: gosec // G115 false positive +} + +// IsAggregationPacket returns whether or not the packet is an Aggregation packet. +func (h H265NALUHeader) IsAggregationPacket() bool { + return h.Type() == h265NaluAggregationPacketType +} + +// IsFragmentationUnit returns whether or not the packet is a Fragmentation Unit packet. +func (h H265NALUHeader) IsFragmentationUnit() bool { + return h.Type() == h265NaluFragmentationUnitType +} + +// IsPACIPacket returns whether or not the packet is a PACI packet. +func (h H265NALUHeader) IsPACIPacket() bool { + return h.Type() == h265NaluPACIPacketType +} + +// +// Single NAL Unit Packet implementation +// + +// H265SingleNALUnitPacket represents a NALU packet, containing exactly one NAL unit. +// +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | PayloadHdr | DONL (conditional) | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | | +// | NAL unit payload data | +// | | +// | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | :...OPTIONAL RTP padding | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// +// Reference: https://datatracker.ietf.org/doc/html/rfc7798#section-4.4.1 +type h265SingleNALUnitPacket struct { + // payloadHeader is the header of the H265 packet. + payloadHeader H265NALUHeader + // donl is a 16-bit field, that may or may not be present. + donl *uint16 + // payload of the fragmentation unit. + payload []byte +} + +// PayloadHeader returns the NALU header of the packet. +func (p *h265SingleNALUnitPacket) PayloadHeader() H265NALUHeader { + return p.payloadHeader +} + +// DONL returns the DONL of the packet. +func (p *h265SingleNALUnitPacket) DONL() *uint16 { + return p.donl +} + +// Payload returns the Fragmentation Unit packet payload. +func (p *h265SingleNALUnitPacket) Payload() []byte { + return p.payload +} + +func (p *h265SingleNALUnitPacket) wireSize() int { + size := h265NaluHeaderSize + if p.donl != nil { + size += h265NaluDonlSize + } + size += len(p.payload) + + return size +} + +func parseH265SingleNalUnitPacket(buf []byte, withDONL bool) (*h265SingleNALUnitPacket, error) { + if buf == nil { + return nil, errNilPacket + } + + minSize := h265NaluHeaderSize + + if withDONL { + minSize += h265NaluDonlSize + } + + if len(buf) <= minSize { + return nil, fmt.Errorf("%w: %d <= %v", errShortPacket, len(buf), minSize) + } + + payloadHeader := newH265NALUHeader(buf[0], buf[1]) + + if payloadHeader.F() { + return nil, errH265CorruptedPacket + } + + if payloadHeader.IsFragmentationUnit() || payloadHeader.IsPACIPacket() || payloadHeader.IsAggregationPacket() { + return nil, errInvalidH265PacketType + } + + var donl *uint16 + + buf = buf[2:] + + if withDONL { + donlValue := binary.BigEndian.Uint16(buf[:2]) + donl = &donlValue + buf = buf[2:] + } + + packet := h265SingleNALUnitPacket{ + payloadHeader, + donl, + buf, + } + + return &packet, nil +} + +func (p *h265SingleNALUnitPacket) isH265Packet() {} + +func (p *h265SingleNALUnitPacket) header() H265NALUHeader { + return p.payloadHeader +} + +func (p *h265SingleNALUnitPacket) toAnnexB(buf []byte) []byte { + buf = append(buf, annexbNALUStartCode...) + + donl := p.donl + p.donl = nil + buf = p.serialize(buf) + p.donl = donl + + return buf +} + +func (p *h265SingleNALUnitPacket) serialize(buf []byte) []byte { + buf = binary.BigEndian.AppendUint16(buf, uint16(p.payloadHeader)) + + if p.donl != nil { + buf = binary.BigEndian.AppendUint16(buf, *p.donl) + } + + buf = append(buf, p.payload...) + + return buf +} + +// H265SingleNALUnitPacket represents a NALU packet, containing exactly one NAL unit. +// +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | PayloadHdr | DONL (conditional) | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | | +// | NAL unit payload data | +// | | +// | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | :...OPTIONAL RTP padding | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// +// Reference: https://datatracker.ietf.org/doc/html/rfc7798#section-4.4.1 +// +// Deprecated: replaced with a private type instead, will be removed in a future release. +type H265SingleNALUnitPacket struct { + // payloadHeader is the header of the H265 packet. + payloadHeader H265NALUHeader + // donl is a 16-bit field, that may or may not be present. + donl *uint16 + // payload of the fragmentation unit. + payload []byte + + mightNeedDONL bool +} + +// WithDONL can be called to specify whether or not DONL might be parsed. +// DONL may need to be parsed if `sprop-max-don-diff` is greater than 0 on the RTP stream. +func (p *H265SingleNALUnitPacket) WithDONL(value bool) { + p.mightNeedDONL = value +} + +// Unmarshal parses the passed byte slice and stores the result in the H265SingleNALUnitPacket +// this method is called upon. +func (p *H265SingleNALUnitPacket) Unmarshal(payload []byte) ([]byte, error) { + parsed, err := parseH265SingleNalUnitPacket(payload, p.mightNeedDONL) + if err != nil { + return nil, err + } + p.payloadHeader = parsed.payloadHeader + p.donl = parsed.donl + p.payload = parsed.payload + + return nil, nil +} + +// PayloadHeader returns the NALU header of the packet. +func (p *H265SingleNALUnitPacket) PayloadHeader() H265NALUHeader { + return p.payloadHeader +} + +// DONL returns the DONL of the packet. +func (p *H265SingleNALUnitPacket) DONL() *uint16 { + return p.donl +} + +// Payload returns the Fragmentation Unit packet payload. +func (p *H265SingleNALUnitPacket) Payload() []byte { + return p.payload +} + +// +// Aggregation Packets implementation +// + +// h265AggregationPacket represents an Aggregation packet. +// +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | PayloadHdr (Type=48) | | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | +// | | +// | two or more aggregation units | +// | | +// | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | :...OPTIONAL RTP padding | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// +// Reference: https://datatracker.ietf.org/doc/html/rfc7798#section-4.4.2 +type h265AggregationPacket struct { + payloadHeader H265NALUHeader + donl *uint16 + payload []byte +} + +// PayloadHeader returns the NALU header of the packet. +func (p *h265AggregationPacket) PayloadHeader() H265NALUHeader { + return p.payloadHeader +} + +// DONL returns the DONL of the packet. +func (p *h265AggregationPacket) DONL() *uint16 { + return p.donl +} + +// Payload returns the Fragmentation Unit packet payload. +func (p *h265AggregationPacket) Payload() []byte { + return p.payload +} + +func (p *h265AggregationPacket) isH265Packet() {} + +func (p *h265AggregationPacket) header() H265NALUHeader { + return p.payloadHeader +} + +func (p *h265AggregationPacket) serialize(buf []byte) []byte { + buf = binary.BigEndian.AppendUint16(buf, uint16(p.payloadHeader)) + + if p.donl != nil { + buf = binary.BigEndian.AppendUint16(buf, *p.donl) + } + + buf = append(buf, p.payload...) + + return buf +} + +func parseH265AggregationPacket(buf []byte, withDONL bool) (*h265AggregationPacket, error) { + // header + 2 length fields + minSize := h265NaluHeaderSize + (h265AggregatedPacketLengthSize * 2) + payloadStart := h265NaluHeaderSize + + if withDONL { + payloadStart += h265NaluDonlSize + minSize += h265NaluDonlSize + } + + if len(buf) < minSize { + return nil, errShortPacket + } + + header := H265NALUHeader(binary.BigEndian.Uint16(buf[0:2])) + + if !header.IsAggregationPacket() { + return nil, errInvalidH265PacketType + } + + var donl *uint16 + + if withDONL { + donlValue := binary.BigEndian.Uint16(buf[2:4]) + donl = &donlValue + } + + payload := buf[payloadStart:] + + packet := h265AggregationPacket{ + header, + donl, + payload, + } + + return &packet, nil +} + +// returns whether this NALU can even fit inside an AP with another NALU. +func canAggregateH265(mtu uint16, packet *h265SingleNALUnitPacket) bool { + // must leave enough space for the AP header, optionally its DONL field, 2 length headers, a 2nd AU's header + // and a second packet's DOND field + return packet.wireSize()+(h265AggregatedPacketLengthSize*2)+h265NaluHeaderSize+1 <= int(mtu) +} + +// returns whether inserting a new packet will make this list of packets too big to aggregate within the MTU. +func shouldAggregateH265Now(mtu uint16, packets []h265SingleNALUnitPacket, newPacket h265SingleNALUnitPacket) bool { + if len(packets) < 1 { + return false + } + // AP header + each AU's size field + totalSize := h265NaluHeaderSize + ((len(packets) + 1) * h265AggregatedPacketLengthSize) + hasDonl := packets[0].donl != nil + // first AU's DONL field + if hasDonl { + totalSize += 2 + } + + if hasDonl && newPacket.donl == nil { + return true + } + + for _, p := range packets { + totalSize += p.wireSize() + // individual AUs have their DONL fields replaced with DOND (1 byte) + if hasDonl { + totalSize -= 1 + } + } + + totalSize += newPacket.wireSize() + if hasDonl { + totalSize -= 1 + } + + return totalSize > int(mtu) +} + +// Reference: https://datatracker.ietf.org/doc/html/rfc7798#section-4.4.2 +// nolint: cyclop // hot path +func newH265AggregationPacket(packets []h265SingleNALUnitPacket) (*h265AggregationPacket, error) { + if packets == nil { + return nil, errNilPacket + } + if len(packets) < 2 { + return nil, errNotEnoughPackets + } + + donlExpected := packets[0].donl != nil + var aggrDonl *uint16 + if donlExpected { + aggrDonlVal := *packets[0].donl + aggrDonl = &aggrDonlVal + } + + header := uint16(0) + header |= h265NaluAggregationPacketType << 9 + + firstPacket := packets[0] + if firstPacket.wireSize() > int(h265AggregatedPacketMaxSize) { + return nil, errPacketTooLarge + } + + fBit := firstPacket.payloadHeader.F() + layerID := firstPacket.payloadHeader.LayerID() + tid := firstPacket.payloadHeader.TID() + + payload := make([]byte, 0) + + lastDonl := packets[0].donl + for i, packet := range packets { + if donlExpected && packet.donl == nil { + return nil, errMissingDonl + } + if i > 0 && packet.donl != nil { + // the DOND field plus 1 specifies the difference between + // the decoding order number values of the current aggregated NAL unit + // and the preceding aggregated NAL unit in the same AP. + dond := int(*packet.donl) - int(*lastDonl) - 1 + if dond < 0 { + return nil, errDonlOutOfOrder + } + if dond > int(^uint8(0)) { + return nil, errDondTooLarge + } + payload = append(payload, uint8(dond)) + lastDonl = packet.donl + } + // following AUs' DONs are derived from the DOND field + packet.donl = nil + + if packet.wireSize() > int(h265AggregatedPacketMaxSize) { + return nil, errPacketTooLarge + } + + if packet.payloadHeader.F() { + fBit = true + } + pLayerID := packet.payloadHeader.LayerID() + if pLayerID < layerID { + layerID = pLayerID + } + pTid := packet.payloadHeader.TID() + if pTid < tid { + tid = pTid + } + + // nolint: gosec // Already checked for max size + payload = binary.BigEndian.AppendUint16(payload, uint16(packet.wireSize())) + + payload = packet.serialize(payload) + } + + header |= uint16(tid) + header |= uint16(layerID) << 3 + + if fBit { + header |= uint16(0b1) << 15 + } + + packet := h265AggregationPacket{ + H265NALUHeader(header), + aggrDonl, + payload, + } + + return &packet, nil +} + +func splitH265AggregationPacket(packet h265AggregationPacket) ([]h265SingleNALUnitPacket, error) { // nolint:cyclop + curDonl := packet.donl + packets := make([]h265SingleNALUnitPacket, 0) + payload := packet.payload + + i := 0 + for len(payload) > 0 { + minSize := h265AggregatedPacketLengthSize + + // DOND is present starting on 2nd AU + if curDonl != nil && i > 0 { + minSize += 1 + } + + if len(payload) < minSize { + return nil, errShortPacket + } + + var donl *uint16 + if curDonl != nil { + if i == 0 { + donl = curDonl + } else { + donlValue := *curDonl + uint16(payload[0]) + 1 + donl = &donlValue + curDonl = &donlValue + payload = payload[1:] + } + } + + curLen := binary.BigEndian.Uint16(payload[0:2]) + if len(payload[2:]) < int(curLen) { + return nil, errShortPacket + } + + parsed, err := parseH265SingleNalUnitPacket(payload[2:2+curLen], false) + if err != nil { + return nil, err + } + + if curDonl != nil { + parsed.donl = donl + } + packets = append(packets, *parsed) + payload = payload[2+curLen:] + + i++ + } + if len(packets) < 2 { + return nil, errNotEnoughPackets + } + + return packets, nil +} + +// H265AggregationUnitFirst represent the First Aggregation Unit in an AP. +// +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// : DONL (conditional) | NALU size | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | NALU size | | +// +-+-+-+-+-+-+-+-+ NAL unit | +// | | +// | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | : +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// +// Reference: https://datatracker.ietf.org/doc/html/rfc7798#section-4.4.2 +// +// Deprecated: replaced with a private type instead, will be removed in a future release. +type H265AggregationUnitFirst struct { + donl *uint16 + nalUnitSize uint16 + nalUnit []byte +} + +// DONL field, when present, specifies the value of the 16 least +// significant bits of the decoding order number of the aggregated NAL +// unit. +func (u H265AggregationUnitFirst) DONL() *uint16 { + return u.donl +} + +// NALUSize represents the size, in bytes, of the NalUnit. +func (u H265AggregationUnitFirst) NALUSize() uint16 { + return u.nalUnitSize +} + +// NalUnit payload. +func (u H265AggregationUnitFirst) NalUnit() []byte { + return u.nalUnit +} + +// H265AggregationUnit represent the an Aggregation Unit in an AP, which is not the first one. +// +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// : DOND (cond) | NALU size | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | | +// | NAL unit | +// | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | : +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// +// Reference: https://datatracker.ietf.org/doc/html/rfc7798#section-4.4.2 +// +// Deprecated: replaced with a private type instead, will be removed in a future release. +type H265AggregationUnit struct { + dond *uint8 + nalUnitSize uint16 + nalUnit []byte +} + +// DOND field plus 1 specifies the difference between +// the decoding order number values of the current aggregated NAL unit +// and the preceding aggregated NAL unit in the same AP. +func (u H265AggregationUnit) DOND() *uint8 { + return u.dond +} + +// NALUSize represents the size, in bytes, of the NalUnit. +func (u H265AggregationUnit) NALUSize() uint16 { + return u.nalUnitSize +} + +// NalUnit payload. +func (u H265AggregationUnit) NalUnit() []byte { + return u.nalUnit +} + +// H265AggregationPacket represents an Aggregation packet. +// +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | PayloadHdr (Type=48) | | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | +// | | +// | two or more aggregation units | +// | | +// | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | :...OPTIONAL RTP padding | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// +// Reference: https://datatracker.ietf.org/doc/html/rfc7798#section-4.4.2 +// +// Deprecated: replaced with a private type instead, will be removed in a future release. +type H265AggregationPacket struct { + payloadHeader H265NALUHeader + firstUnit *H265AggregationUnitFirst + otherUnits []H265AggregationUnit + + mightNeedDONL bool +} + +// WithDONL can be called to specify whether or not DONL might be parsed. +// DONL may need to be parsed if `sprop-max-don-diff` is greater than 0 on the RTP stream. +func (p *H265AggregationPacket) WithDONL(value bool) { + p.mightNeedDONL = value +} + +// Unmarshal parses the passed byte slice and stores the result in the H265AggregationPacket this method is called upon. +func (p *H265AggregationPacket) Unmarshal(payload []byte) ([]byte, error) { //nolint:cyclop + // sizeof(headers) + minSize := h265NaluHeaderSize + (h265AggregatedPacketLengthSize * 2) + + if p.mightNeedDONL { + minSize += h265NaluDonlSize + } + + if payload == nil { + return nil, errNilPacket + } else if len(payload) <= minSize { + return nil, fmt.Errorf("%w: %d <= %v", errShortPacket, len(payload), minSize) + } + + payloadHeader := newH265NALUHeader(payload[0], payload[1]) + if payloadHeader.F() { + return nil, errH265CorruptedPacket + } + if !payloadHeader.IsAggregationPacket() { + return nil, errInvalidH265PacketType + } + p.payloadHeader = payloadHeader + + // First parse the first aggregation unit + payload = payload[2:] + firstUnit := &H265AggregationUnitFirst{} + + if p.mightNeedDONL { + if len(payload) < 2 { + return nil, errShortPacket + } + + donl := binary.BigEndian.Uint16(payload[0:2]) + firstUnit.donl = &donl + + payload = payload[2:] + } + if len(payload) < 2 { + return nil, errShortPacket + } + firstUnit.nalUnitSize = binary.BigEndian.Uint16(payload[0:2]) + payload = payload[2:] + + if len(payload) < int(firstUnit.nalUnitSize) { + return nil, errShortPacket + } + + firstUnit.nalUnit = payload[:firstUnit.nalUnitSize] + payload = payload[firstUnit.nalUnitSize:] + + // Parse remaining Aggregation Units + var units []H265AggregationUnit + for { + unit := H265AggregationUnit{} + + if p.mightNeedDONL { + if len(payload) < 1 { + break + } + + dond := payload[0] + unit.dond = &dond + + payload = payload[1:] + } + + if len(payload) < 2 { + break + } + unit.nalUnitSize = binary.BigEndian.Uint16(payload[0:2]) + payload = payload[2:] + + if len(payload) < int(unit.nalUnitSize) { + return nil, errShortPacket + } + + unit.nalUnit = payload[:unit.nalUnitSize] + payload = payload[unit.nalUnitSize:] + + units = append(units, unit) + } + + // There need to be **at least** two Aggregation Units (first + another one) + if len(units) < 1 { + return nil, errShortPacket + } + + p.firstUnit = firstUnit + p.otherUnits = units + + return nil, nil +} + +// FirstUnit returns the first Aggregated Unit of the packet. +func (p *H265AggregationPacket) FirstUnit() *H265AggregationUnitFirst { + return p.firstUnit +} + +// OtherUnits returns the all the other Aggregated Unit of the packet (excluding the first one). +func (p *H265AggregationPacket) OtherUnits() []H265AggregationUnit { + return p.otherUnits +} + +// +// Fragmentation Unit implementation +// + +const ( + // sizeof(uint8). + h265FragmentationUnitHeaderSize = 1 +) + +// H265FragmentationUnitHeader is a H265 FU Header. +// +// +---------------+ +// |0|1|2|3|4|5|6|7| +// +-+-+-+-+-+-+-+-+ +// |S|E| FuType | +// +---------------+ +// +// . +type H265FragmentationUnitHeader uint8 + +func newH265FragmentationUnitHeader( + payloadHeader H265NALUHeader, + s, e bool, //nolint:unparam +) H265FragmentationUnitHeader { + header := payloadHeader.Type() + if s { + header |= 0b1 << 7 + } + if e { + header |= 0b1 << 6 + } + + return H265FragmentationUnitHeader(header) +} + +// S represents the start of a fragmented NAL unit. +func (h H265FragmentationUnitHeader) S() bool { + const mask = 0b10000000 + + return ((h & mask) >> 7) != 0 +} + +// E represents the end of a fragmented NAL unit. +func (h H265FragmentationUnitHeader) E() bool { + const mask = 0b01000000 + + return ((h & mask) >> 6) != 0 +} + +// FuType MUST be equal to the field Type of the fragmented NAL unit. +func (h H265FragmentationUnitHeader) FuType() uint8 { + const mask = 0b00111111 + + return uint8(h) & mask +} + +type h265FragmentationPacket struct { + payloadHeader H265NALUHeader + fuHeader H265FragmentationUnitHeader + donl *uint16 + payload []byte +} + +func (p *h265FragmentationPacket) isH265Packet() {} + +func (p *h265FragmentationPacket) header() H265NALUHeader { + return p.payloadHeader +} + +func (p *h265FragmentationPacket) serialize(buf []byte) []byte { + buf = binary.BigEndian.AppendUint16(buf, uint16(p.payloadHeader)) + buf = append(buf, byte(p.fuHeader)) + + if p.donl != nil { + buf = binary.BigEndian.AppendUint16(buf, *p.donl) + } + + buf = append(buf, p.payload...) + + return buf +} + +func parseH265FragmentationPacket(payload []byte, withDONL bool) (*h265FragmentationPacket, error) { + minSize := h265NaluHeaderSize + h265FragmentationUnitHeaderSize + payloadStart := h265NaluHeaderSize + h265FragmentationUnitHeaderSize + + if withDONL { + minSize += h265NaluDonlSize + payloadStart += h265NaluDonlSize + } + + if len(payload) < minSize { + return nil, errShortPacket + } + + header := H265NALUHeader(binary.BigEndian.Uint16(payload[0:2])) + + if !header.IsFragmentationUnit() { + return nil, errInvalidH265PacketType + } + + var donl *uint16 + if withDONL { + donlVal := binary.BigEndian.Uint16(payload[3:5]) + donl = &donlVal + } + + packet := h265FragmentationPacket{ + header, + H265FragmentationUnitHeader(payload[2]), + donl, + payload[payloadStart:], + } + + return &packet, nil +} + +// Replaces the original header's type with 49, while keeping other fields. +func newH265FragmentationPacketHeader(payloadHeader H265NALUHeader) H265NALUHeader { + typeMask := ^uint16(0b01111110_00000000) + + return H265NALUHeader((uint16(payloadHeader) & typeMask) | (h265NaluFragmentationUnitType << 9)) +} + +// Replaces the FU's payload header's type with the FU Header's type, while keeping other fields. +func rebuildH265FragmentationPacketHeader( + payloadHeader H265NALUHeader, + fuHeader H265FragmentationUnitHeader, +) H265NALUHeader { + typeMask := ^uint16(0b01111110_00000000) + origType := uint8(fuHeader) & 0b00111111 + + return H265NALUHeader((uint16(payloadHeader) & typeMask) | (uint16(origType) << 9)) +} + +// Splits a H265SingleNALUnitPacket into many FU packets. +// +// Errors if the packet would result in a single FU packet. +// +// The P bit is not set in any case. +func newH265FragmentationPackets(mtu uint16, packet *h265SingleNALUnitPacket) ([]h265FragmentationPacket, error) { + if packet == nil { + return nil, errNilPacket + } + + // size of Header, FU header and (optionally) the DONL + overheadSize := 3 + if packet.donl != nil { + overheadSize += 2 + } + + sliceSize := int(mtu) - overheadSize + + if len(packet.payload) <= sliceSize { + return nil, errShortPacket + } + + packets := make([]h265FragmentationPacket, 0) + header := newH265FragmentationPacketHeader(packet.payloadHeader) + + fuPayload := packet.payload + + firstPacket := h265FragmentationPacket{ + payloadHeader: header, + fuHeader: newH265FragmentationUnitHeader(packet.payloadHeader, true, false), + donl: packet.donl, + payload: fuPayload[:sliceSize], + } + packets = append(packets, firstPacket) + fuPayload = fuPayload[sliceSize:] + + for len(fuPayload) > sliceSize { + p := h265FragmentationPacket{ + payloadHeader: header, + fuHeader: newH265FragmentationUnitHeader(packet.payloadHeader, false, false), + donl: nil, + payload: fuPayload[:sliceSize], + } + packets = append(packets, p) + + fuPayload = fuPayload[sliceSize:] + } + + lastPacket := h265FragmentationPacket{ + payloadHeader: header, + fuHeader: newH265FragmentationUnitHeader(packet.payloadHeader, false, true), + donl: nil, + payload: fuPayload, + } + packets = append(packets, lastPacket) + + return packets, nil +} + +func rebuildH265FragmentationPackets(packets []h265FragmentationPacket) (*h265SingleNALUnitPacket, error) { + if len(packets) < 2 { + return nil, errNotEnoughPackets + } + + if !packets[0].fuHeader.S() { + return nil, errFirstFragmentationUnitMissing + } + if !packets[len(packets)-1].fuHeader.E() { + return nil, errLastFragmentationUnitMissing + } + + payload := make([]byte, 0) + for _, fu := range packets { + payload = append(payload, fu.payload...) + } + + rebuilt := h265SingleNALUnitPacket{ + payloadHeader: rebuildH265FragmentationPacketHeader(packets[0].payloadHeader, packets[0].fuHeader), + donl: packets[0].donl, + payload: payload, + } + + return &rebuilt, nil +} + +// H265FragmentationUnitPacket represents a single Fragmentation Unit packet. +// +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | PayloadHdr (Type=49) | FU header | DONL (cond) | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-| +// | DONL (cond) | | +// |-+-+-+-+-+-+-+-+ | +// | FU payload | +// | | +// | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | :...OPTIONAL RTP padding | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// +// Reference: https://datatracker.ietf.org/doc/html/rfc7798#section-4.4.3 +// +// Deprecated: replaced with a private type instead, will be removed in a future release. +type H265FragmentationUnitPacket struct { + // payloadHeader is the header of the H265 packet. + payloadHeader H265NALUHeader + // fuHeader is the header of the fragmentation unit + fuHeader H265FragmentationUnitHeader + // donl is a 16-bit field, that may or may not be present. + donl *uint16 + // payload of the fragmentation unit. + payload []byte + + mightNeedDONL bool +} + +// WithDONL can be called to specify whether or not DONL might be parsed. +// DONL may need to be parsed if `sprop-max-don-diff` is greater than 0 on the RTP stream. +func (p *H265FragmentationUnitPacket) WithDONL(value bool) { + p.mightNeedDONL = value +} + +// Unmarshal parses the passed byte slice and stores the result in the H265FragmentationUnitPacket +// this method is called upon. +func (p *H265FragmentationUnitPacket) Unmarshal(payload []byte) ([]byte, error) { + parsed, err := parseH265FragmentationPacket(payload, p.mightNeedDONL) + if err != nil { + return nil, err + } + + p.payloadHeader = parsed.payloadHeader + p.fuHeader = parsed.fuHeader + p.donl = parsed.donl + p.payload = parsed.payload + + return nil, nil +} + +// PayloadHeader returns the NALU header of the packet. +func (p *H265FragmentationUnitPacket) PayloadHeader() H265NALUHeader { + return p.payloadHeader +} + +// FuHeader returns the Fragmentation Unit Header of the packet. +func (p *H265FragmentationUnitPacket) FuHeader() H265FragmentationUnitHeader { + return p.fuHeader +} + +// DONL returns the DONL of the packet. +func (p *H265FragmentationUnitPacket) DONL() *uint16 { + return p.donl +} + +// Payload returns the Fragmentation Unit packet payload. +func (p *H265FragmentationUnitPacket) Payload() []byte { + return p.payload +} + +// H265FragmentationPacket represents a Fragmentation packet, which contains one or more Fragmentation Units. +// +// Deprecated: replaced with a private type instead, will be removed in a future release. +type H265FragmentationPacket struct { + payloadHeader H265NALUHeader + donl *uint16 + units []*H265FragmentationUnitPacket + payload []byte +} + +// NewH265FragmentationPacket creates a H265FragmentationPacket. +func NewH265FragmentationPacket(startUnit *H265FragmentationUnitPacket) *H265FragmentationPacket { + return &H265FragmentationPacket{ + payloadHeader: (startUnit.payloadHeader & 0x81FF) | (H265NALUHeader(startUnit.FuHeader().FuType()) << 9), + donl: startUnit.donl, + units: []*H265FragmentationUnitPacket{startUnit}, + } +} + +// PayloadHeader returns the NALU header of the packet. +func (p *H265FragmentationPacket) PayloadHeader() H265NALUHeader { + return p.payloadHeader +} + +// DONL returns the DONL of the packet. +func (p *H265FragmentationPacket) DONL() *uint16 { + return p.donl +} + +// Payload returns the Fragmentation packet payload. +func (p *H265FragmentationPacket) Payload() []byte { + return p.payload +} + +// +// PACI implementation +// + +// paciHeaderFields is the few fields after the payload header of a PACI packet +// +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | PayloadHdr (Type=50) |A| cType | PHSsize |F0..2|Y| +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Payload Header Extension Structure (PHES) | +// |=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=| +// | | +// | PACI payload: NAL unit | +// | . . . | +// | | +// | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | :...OPTIONAL RTP padding | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +type paciHeaderFields uint16 + +func (h *paciHeaderFields) A() bool { + return (uint16(*h) & 0b1 << 15) != 0 +} + +func (h *paciHeaderFields) CType() uint8 { + mask := uint16(0b111111) << 9 + + return uint8((uint16(*h) & mask) >> 9) // nolint:gosec // G115 false positive +} + +func (h *paciHeaderFields) PHSize() uint8 { + mask := uint16(0b11111) << 4 + + return uint8((uint16(*h) & mask) >> 4) // nolint:gosec // G115 false positive +} + +func (h *paciHeaderFields) F0() bool { + return (uint16(*h) & 0b1 << 3) != 0 +} + +func (h *paciHeaderFields) F1() bool { + return (uint16(*h) & 0b1 << 2) != 0 +} + +func (h *paciHeaderFields) F2() bool { + return (uint16(*h) & 0b1 << 1) != 0 +} + +func (h *paciHeaderFields) Y() bool { + return (uint16(*h) & 0b1) != 0 +} + +// H265PACIPacket represents a single H265 PACI packet. +// +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | PayloadHdr (Type=50) |A| cType | PHSsize |F0..2|Y| +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Payload Header Extension Structure (PHES) | +// |=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=| +// | | +// | PACI payload: NAL unit | +// | . . . | +// | | +// | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | :...OPTIONAL RTP padding | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// +// Reference: https://datatracker.ietf.org/doc/html/rfc7798#section-4.4.4 +type H265PACIPacket struct { + // payloadHeader is the header of the H265 packet. + payloadHeader H265NALUHeader + + // Field which holds value for `A`, `cType`, `PHSsize`, `F0`, `F1`, `F2` and `Y` fields. + paciHeaderFields + + // phes is a header extension, of byte length `PHSsize` + phes []byte + + // Payload contains NAL units & optional padding + payload isH265Packet +} + +// PayloadHeader returns the NAL Unit Header. +func (p *H265PACIPacket) PayloadHeader() H265NALUHeader { + return p.payloadHeader +} + +func (p *H265PACIPacket) PHSsize() uint8 { + return p.paciHeaderFields.PHSize() +} + +// PHES contains header extensions. Its size is indicated by PHSsize. +func (p *H265PACIPacket) PHES() []byte { + return p.phes +} + +// Payload is a single NALU or NALU-like struct, without its header. +func (p *H265PACIPacket) Payload() []byte { + return p.payload.serialize(make([]byte, 0))[2:] +} + +// TSCI returns the Temporal Scalability Control Information extension, if present. +func (p *H265PACIPacket) TSCI() *H265TSCI { + if !p.F0() || p.PHSsize() < 3 || len(p.phes) < 3 { + return nil + } + + tsci := H265TSCI((uint32(p.phes[0]) << 16) | (uint32(p.phes[1]) << 8) | uint32(p.phes[2])) + + return &tsci +} + +func rebuildPACIHeader(header H265NALUHeader, paciFields paciHeaderFields) H265NALUHeader { + f := uint16(0) + if paciFields.A() { + f = 1 + } + pType := paciFields.CType() + layerID := header.LayerID() + tid := header.TID() + + return H265NALUHeader( + (f << 15) | + (uint16(pType) << 9) | + (uint16(layerID) << 3) | + (uint16(tid)), + ) +} + +func parseH265PACIPacket(buf []byte, withDONL bool) (*H265PACIPacket, error) { // nolint: cyclop + minSize := h265NaluHeaderSize + 2 + if buf == nil { + return nil, errNilPacket + } + if len(buf) < minSize { + return nil, errShortPacket + } + header := H265NALUHeader(binary.BigEndian.Uint16(buf[0:2])) + + if header.Type() != h265NaluPACIPacketType { + return nil, errInvalidH265PacketType + } + + paciFields := paciHeaderFields(binary.BigEndian.Uint16(buf[2:4])) + + // a PACI packet cannot be inside another PACI packet + if paciFields.CType() == h265NaluPACIPacketType { + return nil, errInvalidH265PacketType + } + + if len(buf) < minSize+int(paciFields.PHSize()) { + return nil, errShortPacket + } + + payloadStart := 4 + paciFields.PHSize() + + phes := buf[4:payloadStart] + + innerNalu := buf[payloadStart:] + + var innerPacket isH265Packet + + switch paciFields.CType() { + case h265NaluAggregationPacketType: + minLength := h265NaluHeaderSize + h265AggregatedPacketLengthSize*2 + // DONL field + 1 DOND field + if withDONL { + minLength += h265NaluDonlSize + 1 + } + if len(innerNalu) < minLength { + return nil, errShortPacket + } + var donl *uint16 + innerPayloadStart := 0 + if withDONL { + donlVal := binary.BigEndian.Uint16(innerNalu[0:2]) + donl = &donlVal + innerPayloadStart += h265NaluDonlSize + } + + innerPacket = &h265AggregationPacket{ + payloadHeader: rebuildPACIHeader(header, paciFields), + donl: donl, + payload: innerNalu[innerPayloadStart:], + } + case h265NaluFragmentationUnitType: + // header + fuHeader + minLength := h265NaluHeaderSize + 1 + if withDONL { + minLength += h265NaluDonlSize + } + if len(innerNalu) < minLength { + return nil, errShortPacket + } + var donl *uint16 + innerPayloadStart := 1 + if withDONL { + donlVal := binary.BigEndian.Uint16(innerNalu[1:3]) + donl = &donlVal + innerPayloadStart += h265NaluDonlSize + } + innerPacket = &h265FragmentationPacket{ + payloadHeader: rebuildPACIHeader(header, paciFields), + fuHeader: H265FragmentationUnitHeader(innerNalu[0]), + donl: donl, + payload: innerNalu[innerPayloadStart:], + } + default: + // header + fuHeader + minLength := h265NaluHeaderSize + if withDONL { + minLength += h265NaluDonlSize + } + if len(innerNalu) < minLength { + return nil, errShortPacket + } + var donl *uint16 + innerPayloadStart := 0 + if withDONL { + donlVal := binary.BigEndian.Uint16(innerNalu[0:2]) + donl = &donlVal + innerPayloadStart += h265NaluDonlSize + } + innerPacket = &h265SingleNALUnitPacket{ + payloadHeader: rebuildPACIHeader(header, paciFields), + donl: donl, + payload: innerNalu[innerPayloadStart:], + } + } + + packet := H265PACIPacket{ + header, + paciFields, + phes, + innerPacket, + } + + return &packet, nil +} + +func newH265PACIPacketHeaders(originalHeader H265NALUHeader, phes []byte) (*H265NALUHeader, *paciHeaderFields, error) { + if len(phes) >= 32 { + return nil, nil, errH265PACIPHESTooLong + } + newHeader := H265NALUHeader( + uint16(h265NaluPACIPacketType)<<9 | + uint16(originalHeader.LayerID())<<3 | + uint16(originalHeader.TID()), + ) + a := uint16(0) + if originalHeader.F() { + a = 1 + } + f0 := uint16(0) + if len(phes) > 0 { + f0 = 1 + } + headerFields := paciHeaderFields( + (a << 15) | + (uint16(originalHeader.Type()) << 9) | + (uint16(len(phes)) << 4) | // nolint: gosec // G115 false positive + (f0 << 3), + ) + + return &newHeader, &headerFields, nil +} + +func newH265PACIPacket(inner isH265Packet) (*H265PACIPacket, error) { + _, ok := inner.(*H265PACIPacket) + if ok { + return nil, errInvalidH265PacketType + } + + header, headerFields, err := newH265PACIPacketHeaders(inner.header(), nil) + if err != nil { + return nil, err + } + + packet := H265PACIPacket{ + payloadHeader: *header, + paciHeaderFields: *headerFields, + phes: nil, + payload: inner, + } + + return &packet, nil +} + +// Unmarshal parses the passed byte slice and stores the result in the H265PACIPacket this method is called upon. +func (p *H265PACIPacket) Unmarshal(payload []byte) ([]byte, error) { + // Bad behavior, no DONL parsing + packet, err := parseH265PACIPacket(payload, false) + if err != nil { + return nil, err + } + + p.payloadHeader = packet.payloadHeader + p.paciHeaderFields = packet.paciHeaderFields + p.phes = packet.phes + p.payload = packet.payload + + return nil, nil +} + +func (p *H265PACIPacket) isH265Packet() {} + +func (p *H265PACIPacket) header() H265NALUHeader { + return p.payloadHeader +} + +func (p *H265PACIPacket) serialize(buf []byte) []byte { + buf = binary.BigEndian.AppendUint16(buf, uint16(p.payloadHeader)) + + buf = binary.BigEndian.AppendUint16(buf, uint16(p.paciHeaderFields)) + + if len(p.phes) > 0 { + buf = append(buf, p.phes...) + } + + fragment, ok := p.payload.(*h265FragmentationPacket) + if ok { + buf = append(buf, byte(fragment.fuHeader)) + if fragment.donl != nil { + buf = binary.BigEndian.AppendUint16(buf, *fragment.donl) + } + buf = append(buf, fragment.payload...) + } + + aggregation, ok := p.payload.(*h265AggregationPacket) + if ok { + if aggregation.donl != nil { + buf = binary.BigEndian.AppendUint16(buf, *aggregation.donl) + } + buf = append(buf, aggregation.payload...) + } + + single, ok := p.payload.(*h265SingleNALUnitPacket) + if ok { + if single.donl != nil { + buf = binary.BigEndian.AppendUint16(buf, *single.donl) + } + buf = append(buf, single.payload...) + } + + return buf +} + +// +// Temporal Scalability Control Information +// + +// H265TSCI is a Temporal Scalability Control Information header extension. +// Reference: https://datatracker.ietf.org/doc/html/rfc7798#section-4.5 +type H265TSCI uint32 + +// TL0PICIDX see RFC7798 for more details. +func (h H265TSCI) TL0PICIDX() uint8 { + const m1 = 0xFFFF0000 + const m2 = 0xFF00 + + return uint8((((h & m1) >> 16) & m2) >> 8) // nolint: gosec // G115 false positive +} + +// IrapPicID see RFC7798 for more details. +func (h H265TSCI) IrapPicID() uint8 { + const m1 = 0xFFFF0000 + const m2 = 0x00FF + + return uint8(((h & m1) >> 16) & m2) // nolint: gosec // G115 false positive +} + +// S see RFC7798 for more details. +func (h H265TSCI) S() bool { + const m1 = 0xFF00 + const m2 = 0b10000000 + + return (uint8((h&m1)>>8) & m2) != 0 // nolint: gosec // G115 false positive +} + +// E see RFC7798 for more details. +func (h H265TSCI) E() bool { + const m1 = 0xFF00 + const m2 = 0b01000000 + + return (uint8((h&m1)>>8) & m2) != 0 // nolint: gosec // G115 false positive +} + +// RES see RFC7798 for more details. +func (h H265TSCI) RES() uint8 { + const m1 = 0xFF00 + const m2 = 0b00111111 + + return uint8((h&m1)>>8) & m2 // nolint: gosec // G115 false positive +} + +// +// H265 Packet interface +// + +type isH265Packet interface { + isH265Packet() + header() H265NALUHeader + serialize([]byte) []byte +} + +var ( + _ isH265Packet = (*h265FragmentationPacket)(nil) + _ isH265Packet = (*H265PACIPacket)(nil) + _ isH265Packet = (*h265SingleNALUnitPacket)(nil) + _ isH265Packet = (*h265AggregationPacket)(nil) +) + +// +// Packet implementation +// + +// H265Depacketizer unmarshals an H265 RTP stream into an Annex-B one. +type H265Depacketizer struct { + hasDonl bool + partials []h265FragmentationPacket + + videoDepacketizer +} + +func (d *H265Depacketizer) handleSingleUnit(output []byte, single h265SingleNALUnitPacket) []byte { + d.partials = d.partials[:0] + output = single.toAnnexB(output) + + return output +} + +func (d *H265Depacketizer) handleAggregationUnit(output []byte, aggregation h265AggregationPacket) ([]byte, error) { + d.partials = d.partials[:0] + aggregated, err := splitH265AggregationPacket(aggregation) + if err != nil { + return nil, err + } + + for _, p := range aggregated { + output = p.toAnnexB(output) + } + + return output, nil +} + +func (d *H265Depacketizer) handleFragmentationUnit(output []byte, fragment h265FragmentationPacket) ([]byte, error) { + if fragment.fuHeader.E() { // nolint: nestif + if len(d.partials) == 0 { + return output, nil + } + + d.partials = append(d.partials, fragment) + + rebuilt, err := rebuildH265FragmentationPackets(d.partials) + if err != nil { + return nil, err + } + output = d.handleSingleUnit(output, *rebuilt) + d.partials = d.partials[:0] + + return output, nil + } else { + // discard lost partial fragments + if fragment.fuHeader.S() { + d.partials = d.partials[:0] + } else if len(d.partials) == 0 { + return nil, errExpectFragmentationStartUnit + } + + d.partials = append(d.partials, fragment) + + return nil, nil + } +} + +func (d *H265Depacketizer) Unmarshal(payload []byte) ([]byte, error) { // nolint:cyclop, gocognit + if len(payload) < h265NaluHeaderSize { + return nil, errShortPacket + } + + header := H265NALUHeader(binary.BigEndian.Uint16(payload[0:2])) + + output := make([]byte, 0) + + switch { + case header.IsFragmentationUnit(): + parseDonl := len(d.partials) == 0 && d.hasDonl + fragment, err := parseH265FragmentationPacket(payload, parseDonl) + if err != nil { + return nil, err + } + output, err = d.handleFragmentationUnit(output, *fragment) + if err != nil { + return nil, err + } + case header.IsAggregationPacket(): + aggregation, err := parseH265AggregationPacket(payload, d.hasDonl) + if err != nil { + return nil, err + } + output, err = d.handleAggregationUnit(output, *aggregation) + if err != nil { + return nil, err + } + case header.IsPACIPacket(): + paci, err := parseH265PACIPacket(payload, d.hasDonl) + if err != nil { + return nil, err + } + fragment, ok := paci.payload.(*h265FragmentationPacket) + if ok { + output, err = d.handleFragmentationUnit(output, *fragment) + if err != nil { + return nil, err + } + } + aggregation, ok := paci.payload.(*h265AggregationPacket) + if ok { + output, err = d.handleAggregationUnit(output, *aggregation) + if err != nil { + return nil, err + } + } + single, ok := paci.payload.(*h265SingleNALUnitPacket) + if ok { + output = d.handleSingleUnit(output, *single) + } + default: + single, err := parseH265SingleNalUnitPacket(payload, d.hasDonl) + if err != nil { + return nil, err + } + output = d.handleSingleUnit(output, *single) + } + + return output, nil +} + +func (d *H265Depacketizer) IsPartitionHead(payload []byte) bool { + if len(payload) < 2 { + return false + } + header := H265NALUHeader(binary.BigEndian.Uint16(payload[0:2])) + if header.IsFragmentationUnit() { + if len(payload) < 3 { + return false + } + fuHeader := H265FragmentationUnitHeader(payload[2]) + + return fuHeader.S() + } + + return true +} + +func (d *H265Depacketizer) IsPartitionTail(marker bool, payload []byte) bool { + if len(payload) < 3 { + return marker + } + header := H265NALUHeader(binary.BigEndian.Uint16(payload[0:2])) + if !header.IsFragmentationUnit() { + return marker + } + fuHeader := H265FragmentationUnitHeader(payload[2]) + + return fuHeader.E() +} + +// H265Packet represents a H265 packet, stored in the payload of an RTP packet. +// +// Deprecated: Use H265Depacketizer instead. +type H265Packet struct { + packet isH265Packet + + H265Depacketizer +} + +// WithDONL can be called to specify whether or not DONL might be parsed. +// DONL may need to be parsed if `sprop-max-don-diff` is greater than 0 on the RTP stream. +func (p *H265Packet) WithDONL(value bool) { + p.H265Depacketizer.hasDonl = value +} + +// Packet returns the populated packet. +// Must be casted to one of: +// - *H265SingleNALUnitPacket +// - *H265FragmentationUnitPacket +// - *H265AggregationPacket +// - *H265PACIPacket. +// +// Deprecated: will always return nil. +func (p *H265Packet) Packet() isH265Packet { + return p.packet +} + +// H265Payloader payloads H265 packets. +type H265Payloader struct { + // Deprecated: Has no effect. + AddDONL bool + SkipAggregation bool +} + +// Payload fragments a H265 packet across one or more byte arrays. +func (p *H265Payloader) Payload(mtu uint16, payload []byte) [][]byte { // nolint:cyclop + // SampleBuilder reuses the payload buffer so this is required + tmp := make([]byte, len(payload)) + copy(tmp, payload) + payload = tmp + + var payloads [][]byte + naluBuffer := make([]h265SingleNALUnitPacket, 0) + + flushBuffer := func() { + switch len(naluBuffer) { + case 0: + return + case 1: + packetized := naluBuffer[0].serialize(make([]byte, 0, naluBuffer[0].wireSize())) + naluBuffer = naluBuffer[:0] + payloads = append(payloads, packetized) + default: + aggrPacket, err := newH265AggregationPacket(naluBuffer) + naluBuffer = naluBuffer[:0] + if err != nil { + return + } + packetized := aggrPacket.serialize(make([]byte, 0)) + payloads = append(payloads, packetized) + } + } + + emitNalus(payload, func(nalu []byte) { + if len(nalu) < h265NaluHeaderSize { + return + } + + header := H265NALUHeader(binary.BigEndian.Uint16(nalu[0:2])) + + if header.IsAggregationPacket() || + header.IsFragmentationUnit() || + header.IsPACIPacket() { + return + } + + packet := h265SingleNALUnitPacket{ + header, + nil, + nalu[2:], + } + + if len(nalu) > int(mtu) { // nolint: nestif + flushBuffer() + fragments, err := newH265FragmentationPackets(mtu, &packet) + if err != nil { + return + } + for _, fragment := range fragments { + payloads = append(payloads, fragment.serialize(make([]byte, 0))) + } + } else { + if p.SkipAggregation { + payloads = append(payloads, nalu) + + return + } + if len(naluBuffer) == 0 { + if canAggregateH265(mtu, &packet) { + naluBuffer = append(naluBuffer, packet) + } else { + payloads = append(payloads, nalu) + } + } else { + // can't fit any more packets, just send what we have and make current first in buffer + if shouldAggregateH265Now(mtu, naluBuffer, packet) { + flushBuffer() + } + naluBuffer = append(naluBuffer, packet) + } + } + }) + + flushBuffer() + + return payloads +} diff --git a/vendor/github.com/pion/rtp/codecs/h266_packet.go b/vendor/github.com/pion/rtp/codecs/h266_packet.go new file mode 100644 index 0000000..d2098bc --- /dev/null +++ b/vendor/github.com/pion/rtp/codecs/h266_packet.go @@ -0,0 +1,785 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package codecs + +import ( + "bytes" + "encoding/binary" + "errors" +) + +var ( + errNalCorrupted = errors.New("NAL could not be parsed to one of known types") + errInvalidNalType = errors.New("NAL types 28 and 29 are reserved for RTP streams") + errPacketTooLarge = errors.New("packet passed in is larger than 65535 bytes") + errNotEnoughPackets = errors.New("aggregation and fragmentation packets requires at least 2 packets") + errFirstFragmentationUnitMissing = errors.New("expecting the first fragmentation packet") + errLastFragmentationUnitMissing = errors.New("expecting the last fragmentation packet") +) + +const ( + // sizeof(uint16). + h266NaluHeaderSize = 2 + // sizeof(uint16). + h266NaluDonlSize = 2 + // https://datatracker.ietf.org/doc/html/rfc9328#section-4.3.2 + h266NaluAggregationPacketType = 28 + // https://datatracker.ietf.org/doc/html/rfc9328#section-4.3.3 + h266NaluFragmentationUnitType = 29 + h266AggregatedPacketMaxSize = ^uint16(0) + h266AggregatedPacketLengthSize = 2 +) + +func emitH266Nalus(nals []byte, emit func([]byte)) { + // look for 3-byte NALU start code + start := bytes.Index(nals, naluStartCode) + offset := 3 + + if start == -1 { + // no start code, emit the whole buffer + emit(nals) + + return + } + + length := len(nals) + + for start < length { + // look for the next NALU start (end of this NALU) + end := bytes.Index(nals[start+offset:], naluStartCode) + if end == -1 { + // no more NALUs, emit the rest of the buffer + emit(nals[start+offset:]) + + break + } + + // next NALU start + nextStart := start + offset + end + + // check if the next NALU is actually a 4-byte start code + endIs4Byte := nals[nextStart-1] == 0 + if endIs4Byte { + nextStart-- + } + + emit(nals[start+offset : nextStart]) + + start = nextStart + + if endIs4Byte { + offset = 4 + } else { + offset = 3 + } + } +} + +type isH266Packet interface { + isH266Packet() + // write the packet in its wire format + packetize([]byte) []byte +} + +// h266NALUHeader is an H266 NAL Unit Header. +// https://datatracker.ietf.org/doc/html/rfc9328#section-1.1.4 +// +// +---------------+---------------+ +// |0|1|2|3|4|5|6|7|0|1|2|3|4|5|6|7| +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// |F|Z| LayerID | Type | TID | +// +---------------+---------------+ +type h266NALUHeader uint16 + +func newH266NALUHeader(highByte, lowByte uint8) h266NALUHeader { + return h266NALUHeader((uint16(highByte) << 8) | uint16(lowByte)) +} + +// F is the forbidden bit, should always be 0. +func (h h266NALUHeader) F() bool { + return (uint16(h) >> 15) != 0 +} + +// Z is a reserved bit, should always be 0. +func (h h266NALUHeader) Z() bool { + const mask = 0b01000000 << 8 + + return (uint16(h) & mask) != 0 +} + +// Type of NAL Unit. +func (h h266NALUHeader) Type() uint8 { + const mask = 0b11111000 + + return uint8((h & mask) >> 3) // nolint: gosec // G115 false positive +} + +// IsTypeVCLUnit returns whether or not the NAL Unit type is a VCL NAL unit. +func (h h266NALUHeader) IsTypeVCLUnit() bool { + // Section 7.4.2.2 http://www.itu.int/rec/T-REC-H.266 + return (h.Type() <= 11) +} + +func (h h266NALUHeader) LayerID() uint8 { + // 00111111 00000000 + const mask = 0b00111111 << 8 + + return uint8((uint16(h) & mask) >> 8) // nolint: gosec // G115 false positive +} + +func (h h266NALUHeader) TID() uint8 { + const mask = 0b00000111 + + return uint8(uint16(h) & mask) // nolint: gosec // G115 false positive +} + +// IsAggregationPacket returns whether or not the packet is an Aggregation packet. +func (h h266NALUHeader) IsAggregationPacket() bool { + return h.Type() == h266NaluAggregationPacketType +} + +// IsFragmentationUnit returns whether or not the packet is a Fragmentation Unit packet. +func (h h266NALUHeader) IsFragmentationUnit() bool { + return h.Type() == h266NaluFragmentationUnitType +} + +// h266SingleNALUnitPacket represents a NALU packet, containing exactly one NAL unit. +// +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | PayloadHdr | DONL (conditional) | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | | +// | NAL unit payload data | +// | | +// | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | :...OPTIONAL RTP padding | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// +// Reference: https://datatracker.ietf.org/doc/html/rfc7798#section-4.4.1 +type h266SingleNALUnitPacket struct { + // payloadHeader is the header of the H266 packet. + payloadHeader h266NALUHeader + // donl is a 16-bit field, that may or may not be present. + donl *uint16 + // payload of the NAL unit. + payload []byte +} + +func (p *h266SingleNALUnitPacket) wireSize() int { + donlSize := 0 + if p.donl != nil { + donlSize = 2 + } + + return h266NaluHeaderSize + donlSize + len(p.payload) +} + +func (p h266SingleNALUnitPacket) isH266Packet() {} + +func (p *h266SingleNALUnitPacket) packetize(buf []byte) []byte { + buf = binary.BigEndian.AppendUint16(buf, uint16(p.payloadHeader)) + + if p.donl != nil { + buf = binary.BigEndian.AppendUint16(buf, *p.donl) + } + + buf = append(buf, p.payload...) + + return buf +} + +// Aggregation Packet implementation + +// h266AggregationPacket is a single H266 aggregation packet. +// +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | PayloadHdr (Type=28) | | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | +// | | +// | two or more aggregation units | +// | | +// | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | :...OPTIONAL RTP padding | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// +// Reference: https://datatracker.ietf.org/doc/html/rfc9328#section-4.3.2 +type h266AggregationPacket struct { + payloadHeader h266NALUHeader + donl *uint16 + payload []byte +} + +// returns whether this NALU can even fit inside an AP with another NALU. +func canAggregate(mtu uint16, packet *h266SingleNALUnitPacket) bool { + // must leave enough space for the AP header, optionally its DONL field, 2 length headers and a 2nd AU's header + return packet.wireSize()+(h266AggregatedPacketLengthSize*2)+h266NaluHeaderSize <= int(mtu) +} + +// returns whether inserting a new packet will make this list of packets too big to aggregate within the MTU. +func shouldAggregateNow(mtu uint16, packets []h266SingleNALUnitPacket, newPacket h266SingleNALUnitPacket) bool { + if len(packets) < 1 { + return false + } + // AP header + each AU's size field + totalSize := h266NaluHeaderSize + ((len(packets) + 1) * h266AggregatedPacketLengthSize) + hasDonl := packets[0].donl != nil + // first AU's DONL field + if hasDonl { + totalSize += 2 + } + for _, p := range packets { + totalSize += p.wireSize() + // individual AUs have their DONL fields removed + if hasDonl { + totalSize -= 2 + } + } + totalSize += newPacket.wireSize() + if hasDonl { + totalSize -= 2 + } + + return totalSize > int(mtu) +} + +// Reference: https://datatracker.ietf.org/doc/html/rfc9328#section-4.3.2 +func newH266AggregationPacket(packets []h266SingleNALUnitPacket) (*h266AggregationPacket, error) { + if packets == nil { + return nil, errNilPacket + } + if len(packets) < 2 { + return nil, errNotEnoughPackets + } + + header := uint16(0) + // type 28 + header |= 28 << 3 + + firstPacket := packets[0] + if firstPacket.wireSize() > int(h266AggregatedPacketMaxSize) { + return nil, errPacketTooLarge + } + + fBit := firstPacket.payloadHeader.F() + layerID := firstPacket.payloadHeader.LayerID() + tid := firstPacket.payloadHeader.TID() + + payload := make([]byte, 0) + + for _, packet := range packets { + // following AUs' DONs are derived as the previous AU's DON + 1 + packet.donl = nil + + if packet.wireSize() > int(h266AggregatedPacketMaxSize) { + return nil, errPacketTooLarge + } + + if packet.payloadHeader.F() { + fBit = true + } + pLayerID := packet.payloadHeader.LayerID() + if pLayerID < layerID { + layerID = pLayerID + } + pTid := packet.payloadHeader.TID() + if pTid < tid { + tid = pTid + } + + // nolint: gosec // Already checked for max size + payload = binary.BigEndian.AppendUint16(payload, uint16(packet.wireSize())) + + payload = packet.packetize(payload) + } + + header |= uint16(tid) + header |= uint16(layerID) << 8 + + if fBit { + header |= uint16(0b1) << 15 + } + + packet := h266AggregationPacket{ + h266NALUHeader(header), + firstPacket.donl, + payload, + } + + return &packet, nil +} + +func splitH266AggregationPacket(packet h266AggregationPacket) ([]h266SingleNALUnitPacket, error) { + curDonl := packet.donl + packets := make([]h266SingleNALUnitPacket, 0) + payload := packet.payload + for len(payload) > 0 { + if len(payload) < 2 { + return nil, errShortPacket + } + curLen := binary.BigEndian.Uint16(payload) + if len(payload[2:]) < int(curLen) { + return nil, errShortPacket + } + + parsed, err := parseH266Packet(payload[2:2+curLen], false) + if err != nil { + return nil, err + } + p, ok := parsed.(*h266SingleNALUnitPacket) + if !ok { + return nil, errInvalidNalType + } + if curDonl != nil { + nextDonl := *curDonl + 1 + p.donl = curDonl + curDonl = &nextDonl + } + packets = append(packets, *p) + payload = payload[2+curLen:] + } + if len(packets) < 2 { + return nil, errNotEnoughPackets + } + + return packets, nil +} + +func (p *h266AggregationPacket) isH266Packet() {} + +func (p *h266AggregationPacket) packetize(buf []byte) []byte { + buf = binary.BigEndian.AppendUint16(buf, uint16(p.payloadHeader)) + + if p.donl != nil { + buf = binary.BigEndian.AppendUint16(buf, *p.donl) + } + + buf = append(buf, p.payload...) + + return buf +} + +// Fragmentation Unit implementation + +// h266FragmentationUnitHeader is the header for each H266FragmentationPacket. +// +// +---------------+ +// |0|1|2|3|4|5|6|7| +// +-+-+-+-+-+-+-+-+ +// |S|E|P| FuType | +// +---------------+ +type h266FragmentationUnitHeader uint8 + +func newH266FragmentationUnitHeader( + payloadHeader h266NALUHeader, + s, e, p bool, //nolint:unparam +) h266FragmentationUnitHeader { + header := payloadHeader.Type() + if s { + header |= 0b1 << 7 + } + if e { + header |= 0b1 << 6 + } + if p { + header |= 0b1 << 5 + } + + return h266FragmentationUnitHeader(header) +} + +// S represents the start of a fragmented NAL unit. +func (h h266FragmentationUnitHeader) S() bool { + const mask = 0b10000000 + + return (h & mask) != 0 +} + +// E represents the end of a fragmented NAL unit. +func (h h266FragmentationUnitHeader) E() bool { + const mask = 0b01000000 + + return (h & mask) != 0 +} + +// P indicates the last FU of the last VCL NAL unit of a coded picture. +func (h h266FragmentationUnitHeader) P() bool { + const mask = 0b00100000 + + return (h & mask) != 0 +} + +// FuType MUST be equal to the field Type of the fragmented NAL unit. +func (h h266FragmentationUnitHeader) FuType() uint8 { + const mask = 0b00011111 + + return uint8(h) & mask +} + +// h266FragmentationPacket is a single H266 Fragmentation Unit. +// +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | PayloadHdr (Type=29) | FU header | DONL (cond) | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-| +// | DONL (cond) | | +// |-+-+-+-+-+-+-+-+ | +// | FU payload | +// | | +// | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | :...OPTIONAL RTP padding | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// +// Reference: https://datatracker.ietf.org/doc/html/rfc9328#section-4.3.3 +type h266FragmentationPacket struct { + payloadHeader h266NALUHeader + fuHeader h266FragmentationUnitHeader + donl *uint16 + payload []byte +} + +// Replaces the original header's type with 29, while keeping other fields. +func newH266FragmentationPacketHeader(payloadHeader h266NALUHeader) h266NALUHeader { + typeMask := ^uint16(0b11111000) + + return h266NALUHeader((uint16(payloadHeader) & typeMask) | (h266NaluFragmentationUnitType << 3)) +} + +// Replaces the FU's payload header's type with the FU Header's type, while keeping other fields. +func rebuildH266FragmentationPacketHeader( + payloadHeader h266NALUHeader, + fuHeader h266FragmentationUnitHeader, +) h266NALUHeader { + typeMask := ^uint16(0b11111000) + origType := uint8(fuHeader) & 0b00011111 + + return h266NALUHeader((uint16(payloadHeader) & typeMask) | (uint16(origType) << 3)) +} + +// Splits a H266SingleNALUnitPacket into many FU packets. +// +// Errors if the packet would result in a single FU packet. +// +// The P bit is not set in any case. +func newH266FragmentationPackets(mtu uint16, packet *h266SingleNALUnitPacket) ([]h266FragmentationPacket, error) { + if packet == nil { + return nil, errNilPacket + } + + // size of Header, FU header and (optionally) the DONL + overheadSize := 3 + if packet.donl != nil { + overheadSize += 2 + } + + sliceSize := int(mtu) - overheadSize + + if len(packet.payload) <= sliceSize { + return nil, errShortPacket + } + + packets := make([]h266FragmentationPacket, 0) + header := newH266FragmentationPacketHeader(packet.payloadHeader) + + fuPayload := packet.payload + + firstPacket := h266FragmentationPacket{ + payloadHeader: header, + fuHeader: newH266FragmentationUnitHeader(packet.payloadHeader, true, false, false), + donl: packet.donl, + payload: fuPayload[:sliceSize], + } + packets = append(packets, firstPacket) + fuPayload = fuPayload[sliceSize:] + + for len(fuPayload) > sliceSize { + p := h266FragmentationPacket{ + payloadHeader: header, + fuHeader: newH266FragmentationUnitHeader(packet.payloadHeader, false, false, false), + donl: nil, + payload: fuPayload[:sliceSize], + } + packets = append(packets, p) + + fuPayload = fuPayload[sliceSize:] + } + + lastPacket := h266FragmentationPacket{ + payloadHeader: header, + fuHeader: newH266FragmentationUnitHeader(packet.payloadHeader, false, true, false), + donl: nil, + payload: fuPayload, + } + packets = append(packets, lastPacket) + + return packets, nil +} + +func rebuildH266FragmentationPackets(packets []h266FragmentationPacket) (*h266SingleNALUnitPacket, error) { + if len(packets) < 2 { + return nil, errNotEnoughPackets + } + + if !packets[0].fuHeader.S() { + return nil, errFirstFragmentationUnitMissing + } + if !packets[len(packets)-1].fuHeader.E() { + return nil, errLastFragmentationUnitMissing + } + + payload := make([]byte, 0) + for _, fu := range packets { + payload = append(payload, fu.payload...) + } + + rebuilt := h266SingleNALUnitPacket{ + payloadHeader: rebuildH266FragmentationPacketHeader(packets[0].payloadHeader, packets[0].fuHeader), + donl: packets[0].donl, + payload: payload, + } + + return &rebuilt, nil +} + +func (p *h266FragmentationPacket) isH266Packet() {} + +func (p *h266FragmentationPacket) packetize(buf []byte) []byte { + buf = binary.BigEndian.AppendUint16(buf, uint16(p.payloadHeader)) + buf = append(buf, uint8(p.fuHeader)) + + if p.donl != nil { + buf = binary.BigEndian.AppendUint16(buf, *p.donl) + } + + buf = append(buf, p.payload...) + + return buf +} + +func parseH266Packet(buf []byte, hasDonl bool) (isH266Packet, error) { // nolint:cyclop + if buf == nil { + return nil, errNilPacket + } + minLength := h266NaluHeaderSize + payloadStart := h265NaluHeaderSize + donlStart := h266NaluHeaderSize + + if hasDonl { + payloadStart += h266NaluDonlSize + minLength += h266NaluDonlSize + } + + if len(buf) < minLength { + return nil, errShortPacket + } + + header := newH266NALUHeader(buf[0], buf[1]) + + // take into account FuPacket + if header.IsFragmentationUnit() { + payloadStart += 1 + donlStart += 1 + minLength += 1 + } + + if len(buf) < minLength { + return nil, errShortPacket + } + + var donl *uint16 + if hasDonl { + donlVal := binary.BigEndian.Uint16(buf[donlStart : donlStart+2]) + donl = &donlVal + } + + switch { + case header.IsAggregationPacket(): + packet := &h266AggregationPacket{ + payloadHeader: header, + donl: donl, + payload: buf[payloadStart:], + } + + return packet, nil + case header.IsFragmentationUnit(): + packet := &h266FragmentationPacket{ + payloadHeader: header, + fuHeader: h266FragmentationUnitHeader(buf[2]), + donl: donl, + payload: buf[payloadStart:], + } + + return packet, nil + default: + packet := &h266SingleNALUnitPacket{ + payloadHeader: header, + donl: donl, + payload: buf[payloadStart:], + } + + return packet, nil + } +} + +type H266Depacketizer struct { + hasDonl bool + partials []h266FragmentationPacket +} + +func (d *H266Depacketizer) Unmarshal(packet []byte) ([]byte, error) { //nolint: cyclop + if packet == nil { + return nil, errNilPacket + } + if len(packet) < 2 { + return nil, errShortPacket + } + + parsedHeader := newH266NALUHeader(packet[0], packet[1]) + + // we are expecting another FU but only the first FU of a series has the DONL field present + isFrag := parsedHeader.IsFragmentationUnit() + parseDonl := d.hasDonl && ((len(d.partials) == 0 && isFrag) || !isFrag) + + parsed, err := parseH266Packet(packet, parseDonl) + if err != nil { + return nil, err + } + output := make([]byte, 0) + + fragment, ok := parsed.(*h266FragmentationPacket) + + if ok { // nolint:nestif + if fragment.fuHeader.E() { + d.partials = append(d.partials, *fragment) + output = append(output, annexbNALUStartCode...) + + rebuilt, err := rebuildH266FragmentationPackets(d.partials) + if err != nil { + return nil, err + } + rebuilt.donl = nil + output = append(output, rebuilt.packetize(make([]byte, 0))...) + d.partials = d.partials[:0] + + return output, nil + } else { + // discard lost partial fragments + if fragment.fuHeader.S() { + d.partials = d.partials[:0] + } else if len(d.partials) == 0 { + return nil, errExpectFragmentationStartUnit + } + + d.partials = append(d.partials, *fragment) + + return nil, nil + } + } + + d.partials = d.partials[:0] + + aggregation, ok := parsed.(*h266AggregationPacket) + if ok { + aggregated, err := splitH266AggregationPacket(*aggregation) + if err != nil { + return nil, err + } + for _, p := range aggregated { + output = append(output, annexbNALUStartCode...) + p.donl = nil + output = p.packetize(output) + } + + return output, nil + } + + output = append(output, annexbNALUStartCode...) + single, ok := parsed.(*h266SingleNALUnitPacket) + if !ok { + return nil, errNalCorrupted + } + + single.donl = nil + output = single.packetize(output) + + return output, nil +} + +type H266Packetizer struct { + naluBuffer []h266SingleNALUnitPacket +} + +func (p *H266Packetizer) Payload(mtu uint16, payload []byte) [][]byte { //nolint: cyclop + var payloads [][]byte + + flushBuffer := func() { + switch len(p.naluBuffer) { + case 0: + return + case 1: + packetized := p.naluBuffer[0].packetize(make([]byte, 0)) + p.naluBuffer = p.naluBuffer[:0] + payloads = append(payloads, packetized) + default: + aggrPacket, err := newH266AggregationPacket(p.naluBuffer) + p.naluBuffer = p.naluBuffer[:0] + if err != nil { + return + } + packetized := aggrPacket.packetize(make([]byte, 0)) + payloads = append(payloads, packetized) + } + } + + emitH266Nalus(payload, func(nalu []byte) { + if len(nalu) < h266NaluHeaderSize { + return + } + + parsedPacket, err := parseH266Packet(nalu, false) + if err != nil { + return + } + + // ignores RFC9328 packets + packet, ok := parsedPacket.(*h266SingleNALUnitPacket) + if !ok { + return + } + + if len(nalu) > int(mtu) { //nolint:nestif + flushBuffer() + fragments, err := newH266FragmentationPackets(mtu, packet) + if err != nil { + return + } + for _, f := range fragments { + packetized := f.packetize(make([]byte, 0)) + payloads = append(payloads, packetized) + } + } else { + if len(p.naluBuffer) == 0 { + if canAggregate(mtu, packet) { + p.naluBuffer = append(p.naluBuffer, *packet) + } else { + payloads = append(payloads, nalu) + } + } else { + // can't fit any more packets, just send what we have and make current first in buffer + if shouldAggregateNow(mtu, p.naluBuffer, *packet) { + flushBuffer() + } + p.naluBuffer = append(p.naluBuffer, *packet) + } + } + }) + + flushBuffer() + + return payloads +} diff --git a/vendor/github.com/pion/rtp/codecs/opus_packet.go b/vendor/github.com/pion/rtp/codecs/opus_packet.go new file mode 100644 index 0000000..8d09f58 --- /dev/null +++ b/vendor/github.com/pion/rtp/codecs/opus_packet.go @@ -0,0 +1,51 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package codecs + +// OpusPayloader payloads Opus packets. +type OpusPayloader struct{} + +// Payload fragments an Opus packet across one or more byte arrays. +func (p *OpusPayloader) Payload(_ uint16, payload []byte) [][]byte { + if payload == nil { + return [][]byte{} + } + + out := make([]byte, len(payload)) + copy(out, payload) + + return [][]byte{out} +} + +// OpusPacket represents the Opus header that is stored in the payload of an RTP Packet. +type OpusPacket struct { + Payload []byte + + audioDepacketizer +} + +// Unmarshal parses the passed byte slice and stores the result in the OpusPacket this method is called upon. +func (p *OpusPacket) Unmarshal(packet []byte) ([]byte, error) { + if packet == nil { + return nil, errNilPacket + } else if len(packet) == 0 { + return nil, errShortPacket + } + + p.Payload = packet + + return packet, nil +} + +// OpusPartitionHeadChecker checks Opus partition head. +// +// Deprecated: replaced by OpusPacket.IsPartitionHead(). +type OpusPartitionHeadChecker struct{} + +// IsPartitionHead checks whether if this is a head of the Opus partition. +// +// Deprecated: replaced by OpusPacket.IsPartitionHead(). +func (*OpusPartitionHeadChecker) IsPartitionHead(packet []byte) bool { + return (&OpusPacket{}).IsPartitionHead(packet) +} diff --git a/vendor/github.com/pion/rtp/codecs/vp8_packet.go b/vendor/github.com/pion/rtp/codecs/vp8_packet.go new file mode 100644 index 0000000..8538098 --- /dev/null +++ b/vendor/github.com/pion/rtp/codecs/vp8_packet.go @@ -0,0 +1,237 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package codecs + +// VP8Payloader payloads VP8 packets. +type VP8Payloader struct { + EnablePictureID bool + pictureID uint16 +} + +const ( + vp8HeaderSize = 1 +) + +// Payload fragments a VP8 packet across one or more byte arrays. +func (p *VP8Payloader) Payload(mtu uint16, payload []byte) [][]byte { //nolint:cyclop + /* + * https://tools.ietf.org/html/rfc7741#section-4.2 + * + * 0 1 2 3 4 5 6 7 + * +-+-+-+-+-+-+-+-+ + * |X|R|N|S|R| PID | (REQUIRED) + * +-+-+-+-+-+-+-+-+ + * X: |I|L|T|K| RSV | (OPTIONAL) + * +-+-+-+-+-+-+-+-+ + * I: |M| PictureID | (OPTIONAL) + * +-+-+-+-+-+-+-+-+ + * L: | TL0PICIDX | (OPTIONAL) + * +-+-+-+-+-+-+-+-+ + * T/K: |TID|Y| KEYIDX | (OPTIONAL) + * +-+-+-+-+-+-+-+-+ + * S: Start of VP8 partition. SHOULD be set to 1 when the first payload + * octet of the RTP packet is the beginning of a new VP8 partition, + * and MUST NOT be 1 otherwise. The S bit MUST be set to 1 for the + * first packet of each encoded frame. + */ + + usingHeaderSize := vp8HeaderSize + if p.EnablePictureID { + switch { + case p.pictureID == 0: + case p.pictureID < 128: + usingHeaderSize = vp8HeaderSize + 2 + default: + usingHeaderSize = vp8HeaderSize + 3 + } + } + + maxFragmentSize := int(mtu) - usingHeaderSize + + payloadData := payload + payloadDataRemaining := len(payload) + + payloadDataIndex := 0 + var payloads [][]byte + + // Make sure the fragment/payload size is correct + if minInt(maxFragmentSize, payloadDataRemaining) <= 0 { + return payloads + } + first := true + for payloadDataRemaining > 0 { + currentFragmentSize := minInt(maxFragmentSize, payloadDataRemaining) + out := make([]byte, usingHeaderSize+currentFragmentSize) + + if first { + out[0] = 0x10 + first = false + } + if p.EnablePictureID { + switch usingHeaderSize { + case vp8HeaderSize: + case vp8HeaderSize + 2: + out[0] |= 0x80 + out[1] |= 0x80 + out[2] |= uint8(p.pictureID & 0x7F) // nolint: gosec // G115 false positive + case vp8HeaderSize + 3: + out[0] |= 0x80 + out[1] |= 0x80 + out[2] |= 0x80 | uint8((p.pictureID>>8)&0x7F) // nolint: gosec // G115 false positive + out[3] |= uint8(p.pictureID & 0xFF) // nolint: gosec // G115 false positive + } + } + + copy(out[usingHeaderSize:], payloadData[payloadDataIndex:payloadDataIndex+currentFragmentSize]) + payloads = append(payloads, out) + + payloadDataRemaining -= currentFragmentSize + payloadDataIndex += currentFragmentSize + } + + p.pictureID++ + p.pictureID &= 0x7FFF + + return payloads +} + +// VP8Packet represents the VP8 header that is stored in the payload of an RTP Packet. +type VP8Packet struct { + // Required Header + X uint8 /* extended control bits present */ + N uint8 /* when set to 1 this frame can be discarded */ + S uint8 /* start of VP8 partition */ + PID uint8 /* partition index */ + + // Extended control bits + I uint8 /* 1 if PictureID is present */ + L uint8 /* 1 if TL0PICIDX is present */ + T uint8 /* 1 if TID is present */ + K uint8 /* 1 if KEYIDX is present */ + + // Optional extension + PictureID uint16 /* 8 or 16 bits, picture ID */ + TL0PICIDX uint8 /* 8 bits temporal level zero index */ + TID uint8 /* 2 bits temporal layer index */ + Y uint8 /* 1 bit layer sync bit */ + KEYIDX uint8 /* 5 bits temporal key frame index */ + + Payload []byte + + videoDepacketizer +} + +// Unmarshal parses the passed byte slice and stores the result in the VP8Packet this method is called upon. +func (p *VP8Packet) Unmarshal(payload []byte) ([]byte, error) { //nolint:gocognit,cyclop + if payload == nil { + return nil, errNilPacket + } + + payloadLen := len(payload) + + payloadIndex := 0 + + if payloadIndex >= payloadLen { + return nil, errShortPacket + } + p.X = (payload[payloadIndex] & 0x80) >> 7 + p.N = (payload[payloadIndex] & 0x20) >> 5 + p.S = (payload[payloadIndex] & 0x10) >> 4 + p.PID = payload[payloadIndex] & 0x07 + + payloadIndex++ + + if p.X == 1 { + if payloadIndex >= payloadLen { + return nil, errShortPacket + } + p.I = (payload[payloadIndex] & 0x80) >> 7 + p.L = (payload[payloadIndex] & 0x40) >> 6 + p.T = (payload[payloadIndex] & 0x20) >> 5 + p.K = (payload[payloadIndex] & 0x10) >> 4 + payloadIndex++ + } else { + p.I = 0 + p.L = 0 + p.T = 0 + p.K = 0 + } + + // nolint: nestif + if p.I == 1 { // PID present? + if payloadIndex >= payloadLen { + return nil, errShortPacket + } + if payload[payloadIndex]&0x80 > 0 { // M == 1, PID is 16bit + if payloadIndex+1 >= payloadLen { + return nil, errShortPacket + } + p.PictureID = (uint16(payload[payloadIndex]&0x7F) << 8) | uint16(payload[payloadIndex+1]) + payloadIndex += 2 + } else { + p.PictureID = uint16(payload[payloadIndex]) + payloadIndex++ + } + } else { + p.PictureID = 0 + } + + if p.L == 1 { + if payloadIndex >= payloadLen { + return nil, errShortPacket + } + p.TL0PICIDX = payload[payloadIndex] + payloadIndex++ + } else { + p.TL0PICIDX = 0 + } + + if p.T == 1 || p.K == 1 { // nolint: nestif + if payloadIndex >= payloadLen { + return nil, errShortPacket + } + if p.T == 1 { + p.TID = payload[payloadIndex] >> 6 //nolint:gosec // guarded by first if + p.Y = (payload[payloadIndex] >> 5) & 0x1 //nolint:gosec // guarded by first if + } else { + p.TID = 0 + p.Y = 0 + } + if p.K == 1 { + p.KEYIDX = payload[payloadIndex] & 0x1F //nolint:gosec // guarded by first if + } else { + p.KEYIDX = 0 + } + payloadIndex++ + } else { + p.TID = 0 + p.Y = 0 + p.KEYIDX = 0 + } + + p.Payload = payload[payloadIndex:] + + return p.Payload, nil +} + +// VP8PartitionHeadChecker checks VP8 partition head +// +// Deprecated: replaced by VP8Packet.IsPartitionHead(). +type VP8PartitionHeadChecker struct{} + +// IsPartitionHead checks whether if this is a head of the VP8 partition. +// +// Deprecated: replaced by VP8Packet.IsPartitionHead(). +func (*VP8PartitionHeadChecker) IsPartitionHead(packet []byte) bool { + return (&VP8Packet{}).IsPartitionHead(packet) +} + +// IsPartitionHead checks whether if this is a head of the VP8 partition. +func (*VP8Packet) IsPartitionHead(payload []byte) bool { + if len(payload) < 1 { + return false + } + + return (payload[0] & 0x10) != 0 +} diff --git a/vendor/github.com/pion/rtp/codecs/vp9/bits.go b/vendor/github.com/pion/rtp/codecs/vp9/bits.go new file mode 100644 index 0000000..7bfa47b --- /dev/null +++ b/vendor/github.com/pion/rtp/codecs/vp9/bits.go @@ -0,0 +1,68 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package vp9 + +import "errors" + +var errNotEnoughBits = errors.New("not enough bits") + +func hasSpace(buf []byte, pos int, n int) error { + if n > ((len(buf) * 8) - pos) { + return errNotEnoughBits + } + + return nil +} + +func readFlag(buf []byte, pos *int) (bool, error) { + err := hasSpace(buf, *pos, 1) + if err != nil { + return false, err + } + + return readFlagUnsafe(buf, pos), nil +} + +func readFlagUnsafe(buf []byte, pos *int) bool { + b := (buf[*pos>>0x03] >> (7 - (*pos & 0x07))) & 0x01 + *pos++ + + return b == 1 +} + +func readBits(buf []byte, pos *int, n int) (uint64, error) { + err := hasSpace(buf, *pos, n) + if err != nil { + return 0, err + } + + return readBitsUnsafe(buf, pos, n), nil +} + +func readBitsUnsafe(buf []byte, pos *int, n int) uint64 { + res := 8 - (*pos & 0x07) + if n < res { + bits := uint64((buf[*pos>>0x03] >> (res - n)) & (1<>0x03] & (1<= 8 { + bits = (bits << 8) | uint64(buf[*pos>>0x03]) + *pos += 8 + n -= 8 + } + + if n > 0 { + bits = (bits << n) | uint64(buf[*pos>>0x03]>>(8-n)) + *pos += n + } + + return bits +} diff --git a/vendor/github.com/pion/rtp/codecs/vp9/header.go b/vendor/github.com/pion/rtp/codecs/vp9/header.go new file mode 100644 index 0000000..138fe95 --- /dev/null +++ b/vendor/github.com/pion/rtp/codecs/vp9/header.go @@ -0,0 +1,225 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +// Package vp9 contains a VP9 header parser. +package vp9 + +import ( + "errors" +) + +var ( + errInvalidFrameMarker = errors.New("invalid frame marker") + errWrongFrameSyncByte0 = errors.New("wrong frame_sync_byte_0") + errWrongFrameSyncByte1 = errors.New("wrong frame_sync_byte_1") + errWrongFrameSyncByte2 = errors.New("wrong frame_sync_byte_2") +) + +// HeaderColorConfig is the color_config member of an header. +type HeaderColorConfig struct { + TenOrTwelveBit bool + BitDepth uint8 + ColorSpace uint8 + ColorRange bool + SubsamplingX bool + SubsamplingY bool +} + +func (c *HeaderColorConfig) unmarshal(profile uint8, buf []byte, pos *int) error { // nolint:cyclop + if profile >= 2 { + var err error + c.TenOrTwelveBit, err = readFlag(buf, pos) + if err != nil { + return err + } + + if c.TenOrTwelveBit { + c.BitDepth = 12 + } else { + c.BitDepth = 10 + } + } else { + c.BitDepth = 8 + } + + tmp, err := readBits(buf, pos, 3) + if err != nil { + return err + } + c.ColorSpace = uint8(tmp) // nolint: gosec // G115, no overflow we read 3 bits + + if c.ColorSpace != 7 { // nolint: nestif + var err error + c.ColorRange, err = readFlag(buf, pos) + if err != nil { + return err + } + + if profile == 1 || profile == 3 { + err := hasSpace(buf, *pos, 3) + if err != nil { + return err + } + + c.SubsamplingX = readFlagUnsafe(buf, pos) + c.SubsamplingY = readFlagUnsafe(buf, pos) + *pos++ + } else { + c.SubsamplingX = true + c.SubsamplingY = true + } + } else { + c.ColorRange = true + + if profile == 1 || profile == 3 { + c.SubsamplingX = false + c.SubsamplingY = false + + err := hasSpace(buf, *pos, 1) + if err != nil { + return err + } + *pos++ + } + } + + return nil +} + +// HeaderFrameSize is the frame_size member of an header. +type HeaderFrameSize struct { + FrameWidthMinus1 uint16 + FrameHeightMinus1 uint16 +} + +func (s *HeaderFrameSize) unmarshal(buf []byte, pos *int) error { + err := hasSpace(buf, *pos, 32) + if err != nil { + return err + } + + s.FrameWidthMinus1 = uint16(readBitsUnsafe(buf, pos, 16)) // nolint: gosec // G115 no overflow, we read 16 bits + s.FrameHeightMinus1 = uint16(readBitsUnsafe(buf, pos, 16)) // nolint: gosec // G115 + + return nil +} + +// Header is a VP9 Frame header. +// Specification: +// https://storage.googleapis.com/downloads.webmproject.org/docs/vp9/vp9-bitstream-specification-v0.6-20160331-draft.pdf +type Header struct { + Profile uint8 + ShowExistingFrame bool + FrameToShowMapIdx uint8 + NonKeyFrame bool + ShowFrame bool + ErrorResilientMode bool + ColorConfig *HeaderColorConfig + FrameSize *HeaderFrameSize +} + +// Unmarshal decodes a Header. +func (h *Header) Unmarshal(buf []byte) error { //nolint:cyclop + pos := 0 + + err := hasSpace(buf, pos, 4) + if err != nil { + return err + } + + frameMarker := readBitsUnsafe(buf, &pos, 2) + if frameMarker != 2 { + return errInvalidFrameMarker + } + + profileLowBit := uint8(readBitsUnsafe(buf, &pos, 1)) // nolint: gosec // no overflow, we read 1 bit + profileHighBit := uint8(readBitsUnsafe(buf, &pos, 1)) // nolint: gosec // G115 + h.Profile = profileHighBit<<1 + profileLowBit + + if h.Profile == 3 { + err = hasSpace(buf, pos, 1) + if err != nil { + return err + } + pos++ + } + + h.ShowExistingFrame, err = readFlag(buf, &pos) + if err != nil { + return err + } + + if h.ShowExistingFrame { + var tmp uint64 + tmp, err = readBits(buf, &pos, 3) + if err != nil { + return err + } + h.FrameToShowMapIdx = uint8(tmp) // nolint: gosec // no overflow, we read 3 bits + + return nil + } + + err = hasSpace(buf, pos, 3) + if err != nil { + return err + } + + h.NonKeyFrame = readFlagUnsafe(buf, &pos) + h.ShowFrame = readFlagUnsafe(buf, &pos) + h.ErrorResilientMode = readFlagUnsafe(buf, &pos) + + if !h.NonKeyFrame { // nolint: nestif + err := hasSpace(buf, pos, 24) + if err != nil { + return err + } + + frameSyncByte0 := uint8(readBitsUnsafe(buf, &pos, 8)) // nolint: gosec // no overflow, we read 8 bits + if frameSyncByte0 != 0x49 { + return errWrongFrameSyncByte0 + } + + frameSyncByte1 := uint8(readBitsUnsafe(buf, &pos, 8)) // nolint: gosec // no overflow, we read 8 bits + if frameSyncByte1 != 0x83 { + return errWrongFrameSyncByte1 + } + + frameSyncByte2 := uint8(readBitsUnsafe(buf, &pos, 8)) // nolint: gosec // no overflow, we read 8 bits + if frameSyncByte2 != 0x42 { + return errWrongFrameSyncByte2 + } + + h.ColorConfig = &HeaderColorConfig{} + err = h.ColorConfig.unmarshal(h.Profile, buf, &pos) + if err != nil { + return err + } + + h.FrameSize = &HeaderFrameSize{} + err = h.FrameSize.unmarshal(buf, &pos) + if err != nil { + return err + } + } + + return nil +} + +// Width returns the video width. +func (h Header) Width() uint16 { + if h.FrameSize == nil { + return 0 + } + + return h.FrameSize.FrameWidthMinus1 + 1 +} + +// Height returns the video height. +func (h Header) Height() uint16 { + if h.FrameSize == nil { + return 0 + } + + return h.FrameSize.FrameHeightMinus1 + 1 +} diff --git a/vendor/github.com/pion/rtp/codecs/vp9_packet.go b/vendor/github.com/pion/rtp/codecs/vp9_packet.go new file mode 100644 index 0000000..fbb2d07 --- /dev/null +++ b/vendor/github.com/pion/rtp/codecs/vp9_packet.go @@ -0,0 +1,522 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package codecs + +import ( + "github.com/pion/randutil" + "github.com/pion/rtp/codecs/vp9" +) + +// Use global random generator to properly seed by crypto grade random. +var globalMathRandomGenerator = randutil.NewMathRandomGenerator() // nolint:gochecknoglobals + +// VP9Payloader payloads VP9 packets. +type VP9Payloader struct { + // whether to use flexible mode or non-flexible mode. + FlexibleMode bool + + // InitialPictureIDFn is a function that returns random initial picture ID. + InitialPictureIDFn func() uint16 + + pictureID uint16 + initialized bool +} + +const ( + maxSpatialLayers = 5 + maxVP9RefPics = 3 +) + +// Payload fragments an VP9 packet across one or more byte arrays. +func (p *VP9Payloader) Payload(mtu uint16, payload []byte) [][]byte { + if !p.initialized { + if p.InitialPictureIDFn == nil { + p.InitialPictureIDFn = func() uint16 { + return uint16(globalMathRandomGenerator.Intn(0x7FFF)) // nolint: gosec + } + } + p.pictureID = p.InitialPictureIDFn() & 0x7FFF + p.initialized = true + } + + var payloads [][]byte + if p.FlexibleMode { + payloads = p.payloadFlexible(mtu, payload) + } else { + payloads = p.payloadNonFlexible(mtu, payload) + } + + p.pictureID++ + if p.pictureID >= 0x8000 { + p.pictureID = 0 + } + + return payloads +} + +func (p *VP9Payloader) payloadFlexible(mtu uint16, payload []byte) [][]byte { + /* + * Flexible mode (F=1) + * 0 1 2 3 4 5 6 7 + * +-+-+-+-+-+-+-+-+ + * |I|P|L|F|B|E|V|Z| (REQUIRED) + * +-+-+-+-+-+-+-+-+ + * I: |M| PICTURE ID | (REQUIRED) + * +-+-+-+-+-+-+-+-+ + * M: | EXTENDED PID | (RECOMMENDED) + * +-+-+-+-+-+-+-+-+ + * L: | TID |U| SID |D| (CONDITIONALLY RECOMMENDED) + * +-+-+-+-+-+-+-+-+ -\ + * P,F: | P_DIFF |N| (CONDITIONALLY REQUIRED) - up to 3 times + * +-+-+-+-+-+-+-+-+ -/ + * V: | SS | + * | .. | + * +-+-+-+-+-+-+-+-+ + */ + + headerSize := 3 + maxFragmentSize := int(mtu) - headerSize + payloadDataRemaining := len(payload) + payloadDataIndex := 0 + var payloads [][]byte + + if minInt(maxFragmentSize, payloadDataRemaining) <= 0 { + return [][]byte{} + } + + for payloadDataRemaining > 0 { + currentFragmentSize := minInt(maxFragmentSize, payloadDataRemaining) + out := make([]byte, headerSize+currentFragmentSize) + + out[0] = 0x90 // F=1, I=1 + if payloadDataIndex == 0 { + out[0] |= 0x08 // B=1 + } + if payloadDataRemaining == currentFragmentSize { + out[0] |= 0x04 // E=1 + } + + out[1] = byte(p.pictureID>>8) | 0x80 + out[2] = byte(p.pictureID) + + copy(out[headerSize:], payload[payloadDataIndex:payloadDataIndex+currentFragmentSize]) + payloads = append(payloads, out) + + payloadDataRemaining -= currentFragmentSize + payloadDataIndex += currentFragmentSize + } + + return payloads +} + +func (p *VP9Payloader) payloadNonFlexible(mtu uint16, payload []byte) [][]byte { //nolint:cyclop + /* + * Non-flexible mode (F=0) + * 0 1 2 3 4 5 6 7 + * +-+-+-+-+-+-+-+-+ + * |I|P|L|F|B|E|V|Z| (REQUIRED) + * +-+-+-+-+-+-+-+-+ + * I: |M| PICTURE ID | (RECOMMENDED) + * +-+-+-+-+-+-+-+-+ + * M: | EXTENDED PID | (RECOMMENDED) + * +-+-+-+-+-+-+-+-+ + * L: | TID |U| SID |D| (CONDITIONALLY RECOMMENDED) + * +-+-+-+-+-+-+-+-+ + * | TL0PICIDX | (CONDITIONALLY REQUIRED) + * +-+-+-+-+-+-+-+-+ + * V: | SS | + * | .. | + * +-+-+-+-+-+-+-+-+ + */ + + var header vp9.Header + err := header.Unmarshal(payload) + if err != nil { + return [][]byte{} + } + + payloadDataRemaining := len(payload) + payloadDataIndex := 0 + var payloads [][]byte + + for payloadDataRemaining > 0 { + var headerSize int + if !header.NonKeyFrame && payloadDataIndex == 0 { + headerSize = 3 + 8 + } else { + headerSize = 3 + } + + maxFragmentSize := int(mtu) - headerSize + currentFragmentSize := minInt(maxFragmentSize, payloadDataRemaining) + if currentFragmentSize <= 0 { + return [][]byte{} + } + + out := make([]byte, headerSize+currentFragmentSize) + + out[0] = 0x80 | 0x01 // I=1, Z=1 + + if header.NonKeyFrame { + out[0] |= 0x40 // P=1 + } + if payloadDataIndex == 0 { + out[0] |= 0x08 // B=1 + } + if payloadDataRemaining == currentFragmentSize { + out[0] |= 0x04 // E=1 + } + + out[1] = byte(p.pictureID>>8) | 0x80 + out[2] = byte(p.pictureID) + off := 3 + + if !header.NonKeyFrame && payloadDataIndex == 0 { + out[0] |= 0x02 // V=1 + out[off] = 0x10 | 0x08 // N_S=0, Y=1, G=1 + off++ + + width := header.Width() + out[off] = byte(width >> 8) + off++ + out[off] = byte(width & 0xFF) + off++ + + height := header.Height() + out[off] = byte(height >> 8) + off++ + out[off] = byte(height & 0xFF) + off++ + + out[off] = 0x01 // N_G=1 + off++ + + out[off] = 1<<4 | 1<<2 // TID=0, U=1, R=1 + off++ + + out[off] = 0x01 // P_DIFF=1 + } + + copy(out[headerSize:], payload[payloadDataIndex:payloadDataIndex+currentFragmentSize]) + payloads = append(payloads, out) + + payloadDataRemaining -= currentFragmentSize + payloadDataIndex += currentFragmentSize + } + + return payloads +} + +// VP9Packet represents the VP9 header that is stored in the payload of an RTP Packet. +type VP9Packet struct { + // Required header + I bool // PictureID is present + P bool // Inter-picture predicted frame + L bool // Layer indices is present + F bool // Flexible mode + B bool // Start of a frame + E bool // End of a frame + V bool // Scalability structure (SS) data present + Z bool // Not a reference frame for upper spatial layers + + // Recommended headers + PictureID uint16 // 7 or 16 bits, picture ID + + // Conditionally recommended headers + TID uint8 // Temporal layer ID + U bool // Switching up point + SID uint8 // Spatial layer ID + D bool // Inter-layer dependency used + + // Conditionally required headers + PDiff []uint8 // Reference index (F=1) + TL0PICIDX uint8 // Temporal layer zero index (F=0) + + // Scalability structure headers + NS uint8 // N_S + 1 indicates the number of spatial layers present in the VP9 stream + Y bool // Each spatial layer's frame resolution present + G bool // PG description present flag. + NG uint8 // N_G indicates the number of pictures in a Picture Group (PG) + Width []uint16 + Height []uint16 + PGTID []uint8 // Temporal layer ID of pictures in a Picture Group + PGU []bool // Switching up point of pictures in a Picture Group + PGPDiff [][]uint8 // Reference indecies of pictures in a Picture Group + + Payload []byte + + videoDepacketizer +} + +// Unmarshal parses the passed byte slice and stores the result in the VP9Packet this method is called upon. +func (p *VP9Packet) Unmarshal(packet []byte) ([]byte, error) { // nolint:cyclop + if packet == nil { + return nil, errNilPacket + } + if len(packet) < 1 { + return nil, errShortPacket + } + + p.I = packet[0]&0x80 != 0 + p.P = packet[0]&0x40 != 0 + p.L = packet[0]&0x20 != 0 + p.F = packet[0]&0x10 != 0 + p.B = packet[0]&0x08 != 0 + p.E = packet[0]&0x04 != 0 + p.V = packet[0]&0x02 != 0 + p.Z = packet[0]&0x01 != 0 + + pos := 1 + var err error + + if p.I { + pos, err = p.parsePictureID(packet, pos) + if err != nil { + return nil, err + } + } + + if p.L { + pos, err = p.parseLayerInfo(packet, pos) + if err != nil { + return nil, err + } + } + + if p.F && p.P { + pos, err = p.parseRefIndices(packet, pos) + if err != nil { + return nil, err + } + } + + if p.V { + pos, err = p.parseSSData(packet, pos) + if err != nil { + return nil, err + } + } + + p.Payload = packet[pos:] + + return p.Payload, nil +} + +// Picture ID: +/* +* +-+-+-+-+-+-+-+-+ +* I: |M| PICTURE ID | M:0 => picture id is 7 bits. +* +-+-+-+-+-+-+-+-+ M:1 => picture id is 15 bits. +* M: | EXTENDED PID | +* +-+-+-+-+-+-+-+-+ +**/ +// . +func (p *VP9Packet) parsePictureID(packet []byte, pos int) (int, error) { + if len(packet) <= pos { + return pos, errShortPacket + } + + p.PictureID = uint16(packet[pos] & 0x7F) + if packet[pos]&0x80 != 0 { + pos++ + if len(packet) <= pos { + return pos, errShortPacket + } + p.PictureID = p.PictureID<<8 | uint16(packet[pos]) + } + pos++ + + return pos, nil +} + +func (p *VP9Packet) parseLayerInfo(packet []byte, pos int) (int, error) { + pos, err := p.parseLayerInfoCommon(packet, pos) + if err != nil { + return pos, err + } + + if p.F { + return pos, nil + } + + return p.parseLayerInfoNonFlexibleMode(packet, pos) +} + +// Layer indices (flexible mode): +/* +* +-+-+-+-+-+-+-+-+ +* L: | T |U| S |D| +* +-+-+-+-+-+-+-+-+ +**/ +// . +func (p *VP9Packet) parseLayerInfoCommon(packet []byte, pos int) (int, error) { + if len(packet) <= pos { + return pos, errShortPacket + } + + p.TID = packet[pos] >> 5 + p.U = packet[pos]&0x10 != 0 + p.SID = (packet[pos] >> 1) & 0x7 + p.D = packet[pos]&0x01 != 0 + + if p.SID >= maxSpatialLayers { + return pos, errTooManySpatialLayers + } + + pos++ + + return pos, nil +} + +// Layer indices (non-flexible mode): +/* +* +-+-+-+-+-+-+-+-+ +* L: | T |U| S |D| +* +-+-+-+-+-+-+-+-+ +* | TL0PICIDX | +* +-+-+-+-+-+-+-+-+ +**/ +// . +func (p *VP9Packet) parseLayerInfoNonFlexibleMode(packet []byte, pos int) (int, error) { + if len(packet) <= pos { + return pos, errShortPacket + } + + p.TL0PICIDX = packet[pos] + pos++ + + return pos, nil +} + +// Reference indices: . +/* +* +-+-+-+-+-+-+-+-+ P=1,F=1: At least one reference index +* P,F: | P_DIFF |N| up to 3 times has to be specified. +* +-+-+-+-+-+-+-+-+ N=1: An additional P_DIFF follows +* current P_DIFF. +* +**/ +// . +func (p *VP9Packet) parseRefIndices(packet []byte, pos int) (int, error) { + for { + if len(packet) <= pos { + return pos, errShortPacket + } + p.PDiff = append(p.PDiff, packet[pos]>>1) + if packet[pos]&0x01 == 0 { + break + } + if len(p.PDiff) >= maxVP9RefPics { + return pos, errTooManyPDiff + } + pos++ + } + pos++ + + return pos, nil +} + +// Scalability structure (SS): +/* +* +-+-+-+-+-+-+-+-+ +* V: | N_S |Y|G|-|-|-| +* +-+-+-+-+-+-+-+-+ -| +* Y: | WIDTH | (OPTIONAL) . +* + . +* | | (OPTIONAL) . +* +-+-+-+-+-+-+-+-+ . N_S + 1 times +* | HEIGHT | (OPTIONAL) . +* + . +* | | (OPTIONAL) . +* +-+-+-+-+-+-+-+-+ -| +* G: | N_G | (OPTIONAL) +* +-+-+-+-+-+-+-+-+ -| +* N_G: | T |U| R |-|-| (OPTIONAL) . +* +-+-+-+-+-+-+-+-+ -| . N_G times +* | P_DIFF | (OPTIONAL) . R times . +* +-+-+-+-+-+-+-+-+ -| -| +**/ +// . +func (p *VP9Packet) parseSSData(packet []byte, pos int) (int, error) { // nolint: cyclop + if len(packet) <= pos { + return pos, errShortPacket + } + + p.NS = packet[pos] >> 5 + p.Y = packet[pos]&0x10 != 0 + p.G = packet[pos]&0x8 != 0 + pos++ + + NS := p.NS + 1 + p.NG = 0 + + if p.Y { + p.Width = make([]uint16, NS) + p.Height = make([]uint16, NS) + for i := 0; i < int(NS); i++ { + if len(packet) <= (pos + 3) { + return pos, errShortPacket + } + + p.Width[i] = uint16(packet[pos])<<8 | uint16(packet[pos+1]) + pos += 2 + p.Height[i] = uint16(packet[pos])<<8 | uint16(packet[pos+1]) + pos += 2 + } + } + + if p.G { + if len(packet) <= pos { + return pos, errShortPacket + } + + p.NG = packet[pos] + pos++ + } + + for i := 0; i < int(p.NG); i++ { + if len(packet) <= pos { + return pos, errShortPacket + } + + p.PGTID = append(p.PGTID, packet[pos]>>5) + p.PGU = append(p.PGU, packet[pos]&0x10 != 0) + R := (packet[pos] >> 2) & 0x3 + pos++ + + p.PGPDiff = append(p.PGPDiff, []uint8{}) + + if len(packet) <= (pos + int(R) - 1) { + return pos, errShortPacket + } + + for j := 0; j < int(R); j++ { + p.PGPDiff[i] = append(p.PGPDiff[i], packet[pos]) + pos++ + } + } + + return pos, nil +} + +// VP9PartitionHeadChecker checks VP9 partition head. +// +// Deprecated: replaced by VP9Packet.IsPartitionHead(). +type VP9PartitionHeadChecker struct{} + +// IsPartitionHead checks whether if this is a head of the VP9 partition. +// +// Deprecated: replaced by VP9Packet.IsPartitionHead(). +func (*VP9PartitionHeadChecker) IsPartitionHead(packet []byte) bool { + return (&VP9Packet{}).IsPartitionHead(packet) +} + +// IsPartitionHead checks whether if this is a head of the VP9 partition. +func (*VP9Packet) IsPartitionHead(payload []byte) bool { + if len(payload) < 1 { + return false + } + + return (payload[0] & 0x08) != 0 +} diff --git a/vendor/github.com/pion/sctp/.gitignore b/vendor/github.com/pion/sctp/.gitignore new file mode 100644 index 0000000..2394557 --- /dev/null +++ b/vendor/github.com/pion/sctp/.gitignore @@ -0,0 +1,28 @@ +# SPDX-FileCopyrightText: 2026 The Pion community +# SPDX-License-Identifier: MIT + +### JetBrains IDE ### +##################### +.idea/ + +### Emacs Temporary Files ### +############################# +*~ + +### Folders ### +############### +bin/ +vendor/ +node_modules/ + +### Files ### +############# +*.ivf +*.ogg +tags +cover.out +*.sw[poe] +*.wasm +examples/sfu-ws/cert.pem +examples/sfu-ws/key.pem +wasm_exec.js diff --git a/vendor/github.com/pion/sctp/.golangci.yml b/vendor/github.com/pion/sctp/.golangci.yml new file mode 100644 index 0000000..1fbb8db --- /dev/null +++ b/vendor/github.com/pion/sctp/.golangci.yml @@ -0,0 +1,148 @@ +# SPDX-FileCopyrightText: 2026 The Pion community +# SPDX-License-Identifier: MIT + +version: "2" +linters: + enable: + - asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers + - bidichk # Checks for dangerous unicode character sequences + - bodyclose # checks whether HTTP response body is closed successfully + - containedctx # containedctx is a linter that detects struct contained context.Context field + - contextcheck # check the function whether use a non-inherited context + - cyclop # checks function and package cyclomatic complexity + - decorder # check declaration order and count of types, constants, variables and functions + - dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f()) + - dupl # Tool for code clone detection + - durationcheck # check for two durations multiplied together + - err113 # Golang linter to check the errors handling expressions + - errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases + - errchkjson # Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occations, where the check for the returned error can be omitted. + - errname # Checks that sentinel errors are prefixed with the `Err` and error types are suffixed with the `Error`. + - errorlint # errorlint is a linter for that can be used to find code that will cause problems with the error wrapping scheme introduced in Go 1.13. + - exhaustive # check exhaustiveness of enum switch statements + - forbidigo # Forbids identifiers + - forcetypeassert # finds forced type assertions + - gochecknoglobals # Checks that no globals are present in Go code + - gocognit # Computes and checks the cognitive complexity of functions + - goconst # Finds repeated strings that could be replaced by a constant + - gocritic # The most opinionated Go source code linter + - gocyclo # Computes and checks the cyclomatic complexity of functions + - godot # Check if comments end in a period + - godox # Tool for detection of FIXME, TODO and other comment keywords + - goheader # Checks is file header matches to pattern + - gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod. + - goprintffuncname # Checks that printf-like functions are named with `f` at the end + - gosec # Inspects source code for security problems + - govet # Vet examines Go source code and reports suspicious constructs, such as Printf calls whose arguments do not align with the format string + - grouper # An analyzer to analyze expression groups. + - importas # Enforces consistent import aliases + - ineffassign # Detects when assignments to existing variables are not used + - lll # Reports long lines + - maintidx # maintidx measures the maintainability index of each function. + - makezero # Finds slice declarations with non-zero initial length + - misspell # Finds commonly misspelled English words in comments + - modernize # Replace and suggests simplifications to code + - nakedret # Finds naked returns in functions greater than a specified function length + - nestif # Reports deeply nested if statements + - nilerr # Finds the code that returns nil even if it checks that the error is not nil. + - nilnil # Checks that there is no simultaneous return of `nil` error and an invalid value. + - nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity + - noctx # noctx finds sending http request without context.Context + - predeclared # find code that shadows one of Go's predeclared identifiers + - revive # golint replacement, finds style mistakes + - staticcheck # Staticcheck is a go vet on steroids, applying a ton of static analysis checks + - tagliatelle # Checks the struct tags. + - thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers + - unconvert # Remove unnecessary type conversions + - unparam # Reports unused function parameters + - unused # Checks Go code for unused constants, variables, functions and types + - varnamelen # checks that the length of a variable's name matches its scope + - wastedassign # wastedassign finds wasted assignment statements + - whitespace # Tool for detection of leading and trailing whitespace + disable: + - depguard # Go linter that checks if package imports are in a list of acceptable packages + - funlen # Tool for detection of long functions + - gochecknoinits # Checks that no init functions are present in Go code + - gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations. + - interfacebloat # A linter that checks length of interface. + - ireturn # Accept Interfaces, Return Concrete Types + - mnd # An analyzer to detect magic numbers + - nolintlint # Reports ill-formed or insufficient nolint directives + - paralleltest # paralleltest detects missing usage of t.Parallel() method in your Go test + - prealloc # Finds slice declarations that could potentially be preallocated + - promlinter # Check Prometheus metrics naming via promlint + - rowserrcheck # checks whether Err of rows is checked successfully + - sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed. + - testpackage # linter that makes you use a separate _test package + - tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes + - wrapcheck # Checks that errors returned from external packages are wrapped + - wsl # Whitespace Linter - Forces you to use empty lines! + settings: + staticcheck: + checks: + - all + - -QF1008 # "could remove embedded field", to keep it explicit! + - -QF1003 # "could use tagged switch on enum", Cases conflicts with exhaustive! + exhaustive: + default-signifies-exhaustive: true + forbidigo: + forbid: + - pattern: ^fmt.Print(f|ln)?$ + - pattern: ^log.(Panic|Fatal|Print)(f|ln)?$ + - pattern: ^os.Exit$ + - pattern: ^panic$ + - pattern: ^print(ln)?$ + - pattern: ^testing.T.(Error|Errorf|Fatal|Fatalf|Fail|FailNow)$ + pkg: ^testing$ + msg: use testify/assert instead + analyze-types: true + gomodguard: + blocked: + modules: + - github.com/pkg/errors: + recommendations: + - errors + govet: + enable: + - shadow + revive: + rules: + # Prefer 'any' type alias over 'interface{}' for Go 1.18+ compatibility + - name: use-any + severity: warning + disabled: false + misspell: + locale: US + varnamelen: + max-distance: 12 + min-name-length: 2 + ignore-type-assert-ok: true + ignore-map-index-ok: true + ignore-chan-recv-ok: true + ignore-decls: + - i int + - n int + - w io.Writer + - r io.Reader + - b []byte + exclusions: + generated: lax + rules: + - linters: + - forbidigo + - gocognit + path: (examples|main\.go) + - linters: + - gocognit + path: _test\.go + - linters: + - forbidigo + path: cmd +formatters: + enable: + - gci # Gci control golang package import order and make it always deterministic. + - gofmt # Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification + - gofumpt # Gofumpt checks whether code was gofumpt-ed. + - goimports # Goimports does everything that gofmt does. Additionally it checks unused imports + exclusions: + generated: lax diff --git a/vendor/github.com/pion/sctp/.goreleaser.yml b/vendor/github.com/pion/sctp/.goreleaser.yml new file mode 100644 index 0000000..8577d86 --- /dev/null +++ b/vendor/github.com/pion/sctp/.goreleaser.yml @@ -0,0 +1,5 @@ +# SPDX-FileCopyrightText: 2026 The Pion community +# SPDX-License-Identifier: MIT + +builds: +- skip: true diff --git a/vendor/github.com/pion/sctp/LICENSE b/vendor/github.com/pion/sctp/LICENSE new file mode 100644 index 0000000..d96df05 --- /dev/null +++ b/vendor/github.com/pion/sctp/LICENSE @@ -0,0 +1,9 @@ +MIT License + +Copyright (c) 2026 The Pion community + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/vendor/github.com/pion/sctp/README.md b/vendor/github.com/pion/sctp/README.md new file mode 100644 index 0000000..ef33fb5 --- /dev/null +++ b/vendor/github.com/pion/sctp/README.md @@ -0,0 +1,64 @@ +

+
+ Pion SCTP +
+

+

A Go implementation of SCTP

+

+ Pion SCTP + join us on Discord Follow us on Bluesky +
+ GitHub Workflow Status + Go Reference + Coverage Status + Go Report Card + License: MIT +

+
+ +### Implemented +- [RFC 6525](https://www.rfc-editor.org/rfc/rfc6525.html) — Stream Control Transmission Protocol (SCTP) Stream Reconfiguration +- [RFC 3758](https://www.rfc-editor.org/rfc/rfc3758.html) — Stream Control Transmission Protocol (SCTP) Partial Reliability Extension +- [RFC 5061](https://www.rfc-editor.org/rfc/rfc5061.html) — Stream Control Transmission Protocol (SCTP) Dynamic Address Reconfiguration +- [RFC 4895](https://www.rfc-editor.org/rfc/rfc4895.html) — Authenticated Chunks for the Stream Control Transmission Protocol (SCTP) +- [RFC 1982](https://www.rfc-editor.org/rfc/rfc1982.html) — Serial Number Arithmetic + +### Partial implementations +Pion only implements the subset of RFC 4960 that is required for WebRTC. + +- [RFC 4960](https://www.rfc-editor.org/rfc/rfc4960.html) — Stream Control Transmission Protocol [Obsoleted by 9260, above] +- [RFC 2960](https://www.rfc-editor.org/rfc/rfc2960.html) — Stream Control Transmission Protocol [Obsoleted by 4960, above] + +The update to [RFC 9260](https://www.rfc-editor.org/rfc/rfc9260) — Stream Control Transmission Protocol is currently a [work in progress](https://github.com/pion/sctp/issues/402). + +### Potential future implementations +Ideally, we would like to add the following features as part of a [v2 refresh](https://github.com/pion/sctp/issues/314): + +Feature | Reference | Progress +--- | --- | --- +RACK (tail loss probing) | [Paper](https://icnp20.cs.ucr.edu/proceedings/nipaa/RACK%20for%20SCTP.pdf), [Comment](https://github.com/pion/sctp/issues/206#issuecomment-968265853)| [In review](https://github.com/pion/sctp/pull/390) +Adaptive burst mitigation | [Paper, see section 5A](https://icnp20.cs.ucr.edu/proceedings/nipaa/RACK%20for%20SCTP.pdf)| [In review](https://github.com/pion/sctp/pull/394) +Update to RFC 9260 | [Parent issue](https://github.com/pion/sctp/issues/402) | [In progress](https://github.com/pion/sctp/issues/402) +Implement RFC 8260 | [Issue](https://github.com/pion/sctp/issues/435) | In progress (no PR available yet) +Blocking writes | [1](https://github.com/pion/sctp/issues/77), [2](https://github.com/pion/sctp/issues/357) | [Potentially in progress](https://github.com/pion/sctp/issues/357#issuecomment-3382050767) +association.listener (and better docs) | [1](https://github.com/pion/sctp/issues/74), [2](https://github.com/pion/sctp/issues/173) | Not started, [blocked by above](https://github.com/pion/sctp/issues/74#issuecomment-545550714) + +RFCs of interest: +- [RFC 9438](https://datatracker.ietf.org/doc/rfc9438/) as it addresses the low utilization problem of [RFC 4960](https://www.rfc-editor.org/rfc/rfc4960.html) in fast long-distance networks as mentioned [here](https://github.com/pion/sctp/issues/218#issuecomment-3329690797). + +### Roadmap +The library is used as a part of our WebRTC implementation. Please refer to that [roadmap](https://github.com/pion/webrtc/issues/9) to track our major milestones. + +### Community +Pion has an active community on the [Discord](https://discord.gg/PngbdqpFbt). + +Follow the [Pion Bluesky](https://bsky.app/profile/pion.ly) or [Pion Twitter](https://twitter.com/_pion) for project updates and important WebRTC news. + +We are always looking to support **your projects**. Please reach out if you have something to build! +If you need commercial support or don't want to use public methods you can contact us at [team@pion.ly](mailto:team@pion.ly) + +### Contributing +Check out the [contributing wiki](https://github.com/pion/webrtc/wiki/Contributing) to join the group of amazing people making this project possible + +### License +MIT License - see [LICENSE](LICENSE) for full text diff --git a/vendor/github.com/pion/sctp/ack_timer.go b/vendor/github.com/pion/sctp/ack_timer.go new file mode 100644 index 0000000..868cbc0 --- /dev/null +++ b/vendor/github.com/pion/sctp/ack_timer.go @@ -0,0 +1,106 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sctp + +import ( + "math" + "sync" + "time" +) + +const ( + ackInterval time.Duration = 200 * time.Millisecond +) + +// ackTimerObserver is the inteface to an ack timer observer. +type ackTimerObserver interface { + onAckTimeout() +} + +type ackTimerState uint8 + +const ( + ackTimerStopped ackTimerState = iota + ackTimerStarted + ackTimerClosed +) + +// ackTimer provides the retnransmission timer conforms with RFC 4960 Sec 6.3.1. +type ackTimer struct { + timer *time.Timer + observer ackTimerObserver + mutex sync.Mutex + state ackTimerState + pending uint8 +} + +// newAckTimer creates a new acknowledgement timer used to enable delayed ack. +func newAckTimer(observer ackTimerObserver) *ackTimer { + t := &ackTimer{observer: observer} + t.timer = time.AfterFunc(math.MaxInt64, t.timeout) + t.timer.Stop() + + return t +} + +func (t *ackTimer) timeout() { + t.mutex.Lock() + if t.pending--; t.pending == 0 && t.state == ackTimerStarted { + t.state = ackTimerStopped + defer t.observer.onAckTimeout() + } + t.mutex.Unlock() +} + +// start starts the timer. +func (t *ackTimer) start() bool { + t.mutex.Lock() + defer t.mutex.Unlock() + + // this timer is already closed or already running + if t.state != ackTimerStopped { + return false + } + + t.state = ackTimerStarted + t.pending++ + t.timer.Reset(ackInterval) + + return true +} + +// stops the timer. this is similar to stop() but subsequent start() call +// will fail (the timer is no longer usable). +func (t *ackTimer) stop() { + t.mutex.Lock() + defer t.mutex.Unlock() + + if t.state == ackTimerStarted { + if t.timer.Stop() { + t.pending-- + } + t.state = ackTimerStopped + } +} + +// closes the timer. this is similar to stop() but subsequent start() call +// will fail (the timer is no longer usable). +func (t *ackTimer) close() { + t.mutex.Lock() + defer t.mutex.Unlock() + + if t.state == ackTimerStarted && t.timer.Stop() { + t.pending-- + } + t.state = ackTimerClosed +} + +// isRunning tests if the timer is running. +// Debug purpose only. +func (t *ackTimer) isRunning() bool { + t.mutex.Lock() + defer t.mutex.Unlock() + + return t.state == ackTimerStarted +} diff --git a/vendor/github.com/pion/sctp/association.go b/vendor/github.com/pion/sctp/association.go new file mode 100644 index 0000000..4d1d353 --- /dev/null +++ b/vendor/github.com/pion/sctp/association.go @@ -0,0 +1,4348 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sctp + +import ( + "bytes" + "context" + "encoding/binary" + "errors" + "fmt" + "io" + "math" + "net" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/pion/logging" + "github.com/pion/randutil" + "github.com/pion/transport/v4/deadline" +) + +// Port 5000 shows up in examples for SDPs used by WebRTC. Since this implementation +// assumes it will be used by DTLS over UDP, the port is only meaningful for de-multiplexing +// but more-so verification. +// Example usage: https://www.rfc-editor.org/rfc/rfc8841.html#section-13.1-2 +const defaultSCTPSrcDstPort = 5000 + +// Use global random generator to properly seed by crypto grade random. +var globalMathRandomGenerator = randutil.NewMathRandomGenerator() // nolint:gochecknoglobals + +// Generates a non-zero Initiate tag. +func generateInitiateTag() uint32 { + for { + if u := globalMathRandomGenerator.Uint32(); u != 0 { + return u + } + } +} + +// Association errors. +var ( + ErrChunk = errors.New("abort chunk, with following errors") + ErrShutdownNonEstablished = errors.New("shutdown called in non-established state") + ErrAssociationClosedBeforeConn = errors.New("association closed before connecting") + ErrAssociationClosed = errors.New("association closed") + ErrSilentlyDiscard = errors.New("silently discard") + ErrInitNotStoredToSend = errors.New("the init not stored to send") + ErrCookieEchoNotStoredToSend = errors.New("cookieEcho not stored to send") + ErrSCTPPacketSourcePortZero = errors.New("sctp packet must not have a source port of 0") + ErrSCTPPacketDestinationPortZero = errors.New("sctp packet must not have a destination port of 0") + ErrInitChunkBundled = errors.New("init chunk must not be bundled with any other chunk") + ErrInitChunkVerifyTagNotZero = errors.New( + "init chunk expects a verification tag of 0 on the packet when out-of-the-blue", + ) + ErrHandleInitState = errors.New("todo: handle Init when in state") + ErrInitAckNoCookie = errors.New("no cookie in InitAck") + ErrInflightQueueTSNPop = errors.New("unable to be popped from inflight queue TSN") + ErrTSNRequestNotExist = errors.New("requested non-existent TSN") + ErrResetPacketInStateNotExist = errors.New("sending reset packet in non-established state") + ErrParamterType = errors.New("unexpected parameter type") + ErrPayloadDataStateNotExist = errors.New("sending payload data in non-established state") + ErrChunkTypeUnhandled = errors.New("unhandled chunk type") + ErrHandshakeInitAck = errors.New("handshake failed (INIT ACK)") + ErrHandshakeCookieEcho = errors.New("handshake failed (COOKIE ECHO)") + ErrTooManyReconfigRequests = errors.New("too many outstanding reconfig requests") +) + +const ( + receiveMTU uint32 = 8192 // MTU for inbound packet (from DTLS) + initialMTU uint32 = 1228 // initial MTU for outgoing packets (to DTLS) + initialRecvBufSize uint32 = 1024 * 1024 + commonHeaderSize uint32 = 12 + dataChunkHeaderSize uint32 = 16 + defaultMaxMessageSize uint32 = 65536 +) + +// association state enums. +const ( + closed uint32 = iota + cookieWait + cookieEchoed + established + shutdownAckSent + shutdownPending + shutdownReceived + shutdownSent +) + +// retransmission timer IDs. +const ( + timerT1Init int = iota + timerT1Cookie + timerT2Shutdown + timerT3RTX + timerReconfig +) + +// ack mode (for testing). +const ( + ackModeNormal int = iota + ackModeNoDelay + ackModeAlwaysDelay +) + +// ack transmission state. +const ( + ackStateIdle int = iota // ack timer is off + ackStateImmediate // will send ack immediately + ackStateDelay // ack timer is on (ack is being delayed) +) + +// other constants. +const ( + acceptChSize = 16 + // avgChunkSize is an estimate of the average chunk size. There is no theory behind + // this estimate. + avgChunkSize = 500 + // minTSNOffset is the minimum offset over the cummulative TSN that we will enqueue + // irrespective of the receive buffer size + // see getMaxTSNOffset. + minTSNOffset = 2000 + // maxTSNOffset is the maximum offset over the cummulative TSN that we will enqueue + // irrespective of the receive buffer size + // see getMaxTSNOffset. + maxTSNOffset = 40000 + // maxReconfigRequests is the maximum number of reconfig requests we will keep outstanding. + maxReconfigRequests = 1000 + + // TLR Adaptive burst mitigation uses quarter-MTU units. + // 1 MTU == 4 units, 0.25 MTU == 1 unit. + tlrUnitsPerMTU = 4 + + // Default burst limits. + tlrBurstDefaultFirstRTT = 16 // 4.0 MTU + tlrBurstDefaultLaterRTT = 8 // 2.0 MTU + + // Minimum burst limits. + tlrBurstMinFirstRTT = 8 // 2.0 MTU + tlrBurstMinLaterRTT = 5 // 1.25 MTU + + // Adaptation steps. + tlrBurstStepDownFirstRTT = 4 // reduce by 1.0 MTU + tlrBurstStepDownLaterRTT = 1 // reduce by 0.25 MTU + + tlrGoodOpsResetThreshold = 16 +) + +func getAssociationStateString(assoc uint32) string { + switch assoc { + case closed: + return "Closed" + case cookieWait: + return "CookieWait" + case cookieEchoed: + return "CookieEchoed" + case established: + return "Established" + case shutdownPending: + return "ShutdownPending" + case shutdownSent: + return "ShutdownSent" + case shutdownReceived: + return "ShutdownReceived" + case shutdownAckSent: + return "ShutdownAckSent" + default: + return fmt.Sprintf("Invalid association state %d", assoc) + } +} + +// Association represents an SCTP association +// 13.2. Parameters Necessary per Association (i.e., the TCB) +// +// Peer : Tag value to be sent in every packet and is received +// Verification: in the INIT or INIT ACK chunk. +// Tag : +// State : A state variable indicating what state the association +// : is in, i.e., COOKIE-WAIT, COOKIE-ECHOED, ESTABLISHED, +// : SHUTDOWN-PENDING, SHUTDOWN-SENT, SHUTDOWN-RECEIVED, +// : SHUTDOWN-ACK-SENT. +// +// Note: No "CLOSED" state is illustrated since if a +// association is "CLOSED" its TCB SHOULD be removed. +// Note: By nature of an Association being constructed with one net.Conn, +// it is not a multi-home supporting implementation of SCTP. +type Association struct { + bytesReceived uint64 + bytesSent uint64 + + lock sync.RWMutex + + netConn net.Conn + + peerVerificationTag uint32 + myVerificationTag uint32 + state uint32 + initialTSN uint32 + myNextTSN uint32 // nextTSN + minTSN2MeasureRTT uint32 // for RTT measurement + willSendForwardTSN bool + willRetransmitFast bool + willRetransmitReconfig bool + + willSendShutdown bool + willSendShutdownAck bool + willSendShutdownComplete bool + + willSendAbort bool + willSendAbortCause errorCause + abortSentOnce sync.Once + abortSentCh chan struct{} + + // Reconfig + myNextRSN uint32 + reconfigs map[uint32]*chunkReconfig + reconfigRequests map[uint32]*paramOutgoingResetRequest + + // Non-RFC internal data + sourcePort uint16 + destinationPort uint16 + myMaxNumInboundStreams uint16 + myMaxNumOutboundStreams uint16 + myCookie *paramStateCookie + payloadQueue *receivePayloadQueue + inflightQueue *payloadQueue + pendingQueue *pendingQueue + controlQueue *controlQueue + mtu uint32 + maxPayloadSize uint32 // max DATA chunk payload size + srtt atomic.Value // type float64 + cumulativeTSNAckPoint uint32 + advancedPeerTSNAckPoint uint32 + useForwardTSN bool + sendZeroChecksum bool + recvZeroChecksum bool + + // Congestion control parameters + maxReceiveBufferSize uint32 + maxMessageSize uint32 + cwnd uint32 // my congestion window size + rwnd uint32 // calculated peer's receiver windows size + ssthresh uint32 // slow start threshold + partialBytesAcked uint32 + inFastRecovery bool + fastRecoverExitPoint uint32 + minCwnd uint32 // Minimum congestion window + fastRtxWnd uint32 // Send window for fast retransmit + cwndCAStep uint32 // Step of congestion window increase at Congestion Avoidance + + // RTX & Ack timer + rtoMgr *rtoManager + t1Init *rtxTimer + t1Cookie *rtxTimer + t2Shutdown *rtxTimer + t3RTX *rtxTimer + tReconfig *rtxTimer + ackTimer *ackTimer + + // RACK / TLP state + rack rackSettings // rack configurable options + rackReoWnd time.Duration // dynamic reordering window + rackMinRTT time.Duration // min observed RTT + rackDeliveredTime time.Time // send time of most recently delivered original chunk + rackHighestDeliveredOrigTSN uint32 + rackReorderingSeen bool // ever observed reordering for this association + rackKeepInflatedRecoveries int // keep inflated reoWnd for 16 loss recoveries + // RACK xmit-time ordered list + rackHead *chunkPayloadData + rackTail *chunkPayloadData + + // Unified timer for RACK and PTO driven by a single goroutine. + // Deadlines are protected with timerMu. + timerMu sync.Mutex + timerUpdateCh chan struct{} + rackDeadline time.Time + ptoDeadline time.Time + + // Chunks stored for retransmission + storedInit *chunkInit + storedCookieEcho *chunkCookieEcho + + streams map[uint16]*Stream + acceptCh chan *Stream + readLoopCloseCh chan struct{} + awakeWriteLoopCh chan struct{} + closeWriteLoopCh chan struct{} + handshakeCompletedCh chan error + + closeWriteLoopOnce sync.Once + + // local error + silentError error + + ackState int + ackMode int // for testing + + // stats + stats *associationStats + + // per inbound packet context + delayedAckTriggered bool + immediateAckTriggered bool + + blockWrite bool + writePending bool + writeNotify chan struct{} + + name string + log logging.LeveledLogger + + // Adaptive burst mitigation variables + tlrActive bool + tlrFirstRTT bool // first RTT of this TLR operation + tlrHadAdditionalLoss bool + tlrEndTSN uint32 // recovery is done when cumAck >= tlrEndTSN + + tlrBurstFirstRTTUnits int64 // quarter-MTU units + tlrBurstLaterRTTUnits int64 // quarter-MTU units + + tlrGoodOps uint32 // count of TLR ops completed w/o additional loss + tlrStartTime time.Time // time of first recovery RTT +} + +type snapConfig struct { + // Local and remote SCTP init to use for SNAP + localInit []byte + remoteInit []byte +} + +// Config collects the arguments to createAssociation construction into +// a single structure. +type Config struct { + LoggerFactory logging.LoggerFactory + Name string + NetConn net.Conn + + BlockWrite bool + EnableZeroChecksum bool + MTU uint32 + + // congestion control configuration + MaxReceiveBufferSize uint32 + MaxMessageSize uint32 + // RTOMax is the maximum retransmission timeout in milliseconds + RTOMax float64 + // Minimum congestion window + MinCwnd uint32 + // Send window for fast retransmit + FastRtxWnd uint32 + // Step of congestion window increase at Congestion Avoidance + CwndCAStep uint32 + + // RACK config options + rack rackSettings + + // SNAP/sctp-init + snapConfig *snapConfig +} + +// Server accepts a SCTP stream over a conn. +// +// Deprecated: Use ServerWithOptions instead. +func Server(config Config) (*Association, error) { + return ServerWithOptions(config) +} + +// ServerWithOptions accepts a SCTP stream over a conn. +func ServerWithOptions(opts ...ServerOption) (*Association, error) { + assoc, err := createServerAssociation(opts...) + if err != nil { + return nil, err + } + assoc.initServer() + + select { + case err := <-assoc.handshakeCompletedCh: + if err != nil { + return nil, err + } + + return assoc, nil + case <-assoc.readLoopCloseCh: + return nil, ErrAssociationClosedBeforeConn + } +} + +// Client opens a SCTP stream over a conn. +// +// Deprecated: Use ClientWithOptions instead. +func Client(config Config) (*Association, error) { + return ClientWithOptions(config) +} + +// ClientWithOptions opens a SCTP stream over a conn. +func ClientWithOptions(opts ...ClientOption) (*Association, error) { + return createClientWithOptionsWithContext(context.Background(), opts...) +} + +func createClientWithContext(ctx context.Context, config Config) (*Association, error) { + return createClientWithOptionsWithContext(ctx, config) +} + +func createSNAPAssociation(config *Config) (*Association, error) { + // SNAP, aka sctp-init in the SDP. + remote := &chunkInit{} + err := remote.unmarshal(config.snapConfig.remoteInit) + if err != nil { + return nil, err + } + local := &chunkInit{} + err = local.unmarshal(config.snapConfig.localInit) + if err != nil { + return nil, err + } + assoc := createAssociationFromConfigWithTsn(config, local.initialTSN) + assoc.initWithOutOfBandTokens(local, remote) + + return assoc, nil +} + +func createClientWithOptionsWithContext(ctx context.Context, opts ...ClientOption) (*Association, error) { + config, err := buildClientConfig(opts...) + if err != nil { + return nil, err + } + if config.snapConfig != nil && len(config.snapConfig.remoteInit) != 0 && len(config.snapConfig.localInit) != 0 { + return createSNAPAssociation(config) + } + assoc, err := createClientAssociation(opts...) + if err != nil { + return nil, err + } + + assoc.initClient() + + select { + case <-ctx.Done(): + assoc.log.Errorf("[%s] client handshake canceled: state=%s", assoc.name, getAssociationStateString(assoc.getState())) + assoc.Close() // nolint:errcheck,gosec + + return nil, ctx.Err() + case err := <-assoc.handshakeCompletedCh: + if err != nil { + return nil, err + } + + return assoc, nil + case <-assoc.readLoopCloseCh: + return nil, ErrAssociationClosedBeforeConn + } +} + +// applyDefaults applies default values to the config. +func (c *Config) applyDefaults() { + if c.LoggerFactory == nil { + c.LoggerFactory = logging.NewDefaultLoggerFactory() + } + if c.MaxReceiveBufferSize == 0 { + c.MaxReceiveBufferSize = initialRecvBufSize + } + if c.MaxMessageSize == 0 { + c.MaxMessageSize = defaultMaxMessageSize + } + if c.MTU == 0 { + c.MTU = initialMTU + } +} + +func createServerAssociation(opts ...ServerOption) (*Association, error) { + cfg, err := buildServerConfig(opts...) + if err != nil { + return nil, err + } + + return createAssociationFromConfig(cfg) +} + +func (a *Association) initServer() { + a.lock.Lock() + defer a.lock.Unlock() + + go a.readLoop() + go a.writeLoop() +} + +// applyServer allows the exported Config to act as a ServerOption. +func (c Config) applyServer(cfg *Config) error { //nolint:dupl,cyclop + if c.LoggerFactory != nil { + cfg.LoggerFactory = c.LoggerFactory + } + if c.Name != "" { + cfg.Name = c.Name + } + if c.NetConn != nil { + cfg.NetConn = c.NetConn + } + + cfg.BlockWrite = c.BlockWrite + cfg.EnableZeroChecksum = c.EnableZeroChecksum + + if c.MTU != 0 { + cfg.MTU = c.MTU + } + if c.MaxReceiveBufferSize != 0 { + cfg.MaxReceiveBufferSize = c.MaxReceiveBufferSize + } + if c.MaxMessageSize != 0 { + cfg.MaxMessageSize = c.MaxMessageSize + } + if c.RTOMax != 0 { + cfg.RTOMax = c.RTOMax + } + if c.MinCwnd != 0 { + cfg.MinCwnd = c.MinCwnd + } + if c.FastRtxWnd != 0 { + cfg.FastRtxWnd = c.FastRtxWnd + } + if c.CwndCAStep != 0 { + cfg.CwndCAStep = c.CwndCAStep + } + + cfg.rack = c.rack + + return nil +} + +func buildServerConfig(opts ...ServerOption) (*Config, error) { + cfg := &Config{} + cfg.applyDefaults() + + for _, opt := range opts { + if opt == nil { + continue + } + if err := opt.applyServer(cfg); err != nil { + return nil, err + } + } + + cfg.applyDefaults() + + if cfg.NetConn == nil { + return nil, errNilNetConn + } + + return cfg, nil +} + +func createClientAssociation(opts ...ClientOption) (*Association, error) { + cfg, err := buildClientConfig(opts...) + if err != nil { + return nil, err + } + + return createAssociationFromConfig(cfg) +} + +func (a *Association) initClient() { + a.lock.Lock() + defer a.lock.Unlock() + + go a.readLoop() + go a.writeLoop() + + init := &chunkInit{} + init.initialTSN = a.myNextTSN + init.numOutboundStreams = a.myMaxNumOutboundStreams + init.numInboundStreams = a.myMaxNumInboundStreams + init.initiateTag = a.myVerificationTag + init.advertisedReceiverWindowCredit = a.maxReceiveBufferSize + setSupportedExtensions(&init.chunkInitCommon) + + if a.recvZeroChecksum { + init.params = append(init.params, ¶mZeroChecksumAcceptable{edmid: dtlsErrorDetectionMethod}) + } + + a.storedInit = init + + err := a.sendInit() + if err != nil { + a.log.Errorf("[%s] failed to send init: %s", a.name, err.Error()) + } + + // After sending the INIT chunk, "A" starts the T1-init timer and enters the COOKIE-WAIT state. + // Note: ideally we would set state after the timer starts but since we don't do this in an atomic + // set + timer-start, it's safer to just set the state first so that we don't have a timer expiration + // race. + a.setState(cookieWait) + a.t1Init.start(a.rtoMgr.getRTO()) +} + +// applyClient allows the exported Config to act as a ClientOption. +// this is currently the same as applyServer. +func (c Config) applyClient(cfg *Config) error { //nolint:dupl,cyclop + if c.LoggerFactory != nil { + cfg.LoggerFactory = c.LoggerFactory + } + if c.Name != "" { + cfg.Name = c.Name + } + if c.NetConn != nil { + cfg.NetConn = c.NetConn + } + + cfg.BlockWrite = c.BlockWrite + cfg.EnableZeroChecksum = c.EnableZeroChecksum + + if c.MTU != 0 { + cfg.MTU = c.MTU + } + if c.MaxReceiveBufferSize != 0 { + cfg.MaxReceiveBufferSize = c.MaxReceiveBufferSize + } + if c.MaxMessageSize != 0 { + cfg.MaxMessageSize = c.MaxMessageSize + } + if c.RTOMax != 0 { + cfg.RTOMax = c.RTOMax + } + if c.MinCwnd != 0 { + cfg.MinCwnd = c.MinCwnd + } + if c.FastRtxWnd != 0 { + cfg.FastRtxWnd = c.FastRtxWnd + } + if c.CwndCAStep != 0 { + cfg.CwndCAStep = c.CwndCAStep + } + + cfg.rack = c.rack + + cfg.snapConfig = c.snapConfig + + return nil +} + +func buildClientConfig(opts ...ClientOption) (*Config, error) { + cfg := &Config{} + cfg.applyDefaults() + + for _, opt := range opts { + if opt == nil { + continue + } + if err := opt.applyClient(cfg); err != nil { + return nil, err + } + } + + cfg.applyDefaults() + + if cfg.NetConn == nil { + return nil, errNilNetConn + } + + return cfg, nil +} + +func createAssociationFromConfig(cfg *Config) (*Association, error) { + tsn := globalMathRandomGenerator.Uint32() + + return createAssociationFromConfigWithTsn(cfg, tsn), nil +} + +func createAssociationFromConfigWithTsn(cfg *Config, tsn uint32) *Association { + maxReceiveBufferSize := cfg.MaxReceiveBufferSize + if maxReceiveBufferSize == 0 { + maxReceiveBufferSize = initialRecvBufSize + } + + maxMessageSize := cfg.MaxMessageSize + if maxMessageSize == 0 { + maxMessageSize = defaultMaxMessageSize + } + + mtu := cfg.MTU + if mtu == 0 { + mtu = initialMTU + } + + rtoMax := cfg.RTOMax + + assoc := &Association{ + netConn: cfg.NetConn, + maxReceiveBufferSize: maxReceiveBufferSize, + maxMessageSize: maxMessageSize, + minCwnd: cfg.MinCwnd, + fastRtxWnd: cfg.FastRtxWnd, + cwndCAStep: cfg.CwndCAStep, + + myMaxNumOutboundStreams: math.MaxUint16, + myMaxNumInboundStreams: math.MaxUint16, + + payloadQueue: newReceivePayloadQueue(getMaxTSNOffset(maxReceiveBufferSize)), + inflightQueue: newPayloadQueue(), + pendingQueue: newPendingQueue(), + controlQueue: newControlQueue(), + mtu: mtu, + maxPayloadSize: mtu - (commonHeaderSize + dataChunkHeaderSize), + myVerificationTag: generateInitiateTag(), + initialTSN: tsn, + myNextTSN: tsn, + myNextRSN: tsn, + minTSN2MeasureRTT: tsn, + state: closed, + rtoMgr: newRTOManager(rtoMax), + streams: map[uint16]*Stream{}, + reconfigs: map[uint32]*chunkReconfig{}, + reconfigRequests: map[uint32]*paramOutgoingResetRequest{}, + acceptCh: make(chan *Stream, acceptChSize), + readLoopCloseCh: make(chan struct{}), + awakeWriteLoopCh: make(chan struct{}, 1), + closeWriteLoopCh: make(chan struct{}), + handshakeCompletedCh: make(chan error), + cumulativeTSNAckPoint: tsn - 1, + advancedPeerTSNAckPoint: tsn - 1, + recvZeroChecksum: cfg.EnableZeroChecksum, + silentError: ErrSilentlyDiscard, + stats: &associationStats{}, + log: cfg.LoggerFactory.NewLogger("sctp"), + name: cfg.Name, + blockWrite: cfg.BlockWrite, + writeNotify: make(chan struct{}, 1), + abortSentCh: make(chan struct{}), + } + + // adaptive burst mitigation defaults + assoc.tlrBurstFirstRTTUnits = tlrBurstDefaultFirstRTT + assoc.tlrBurstLaterRTTUnits = tlrBurstDefaultLaterRTT + + // RACK defaults + assoc.rack.rackWCDelAck = cfg.rack.rackWCDelAck + if assoc.rack.rackWCDelAck == 0 { + assoc.rack.rackWCDelAck = 200 * time.Millisecond // WCDelAckT, RACK for SCTP section 2C + } + + assoc.rack.rackMinRTTWnd = cfg.rack.rackMinRTTWnd + if assoc.rack.rackMinRTTWnd == nil { + assoc.rack.rackMinRTTWnd = newWindowedMin(30 * time.Second) + } + + assoc.timerUpdateCh = make(chan struct{}, 1) + go assoc.timerLoop() + + assoc.rack.rackReoWndFloor = cfg.rack.rackReoWndFloor // optional floor; usually 0 + assoc.rackKeepInflatedRecoveries = 0 + + if assoc.name == "" { + assoc.name = fmt.Sprintf("%p", assoc) + } + + assoc.setCWND(min32(4*assoc.MTU(), max32(2*assoc.MTU(), 4380))) + assoc.log.Tracef("[%s] updated cwnd=%d ssthresh=%d inflight=%d (INI)", + assoc.name, assoc.CWND(), assoc.ssthresh, assoc.inflightQueue.getNumBytes()) + + assoc.srtt.Store(float64(0)) + assoc.t1Init = newRTXTimer(timerT1Init, assoc, maxInitRetrans, rtoMax) + assoc.t1Cookie = newRTXTimer(timerT1Cookie, assoc, maxInitRetrans, rtoMax) + assoc.t2Shutdown = newRTXTimer(timerT2Shutdown, assoc, noMaxRetrans, rtoMax) + assoc.t3RTX = newRTXTimer(timerT3RTX, assoc, noMaxRetrans, rtoMax) + assoc.tReconfig = newRTXTimer(timerReconfig, assoc, noMaxRetrans, rtoMax) + assoc.ackTimer = newAckTimer(assoc) + + return assoc +} + +func (a *Association) initWithOutOfBandTokens(localInit *chunkInit, remoteInit *chunkInit) { + a.lock.Lock() + defer a.lock.Unlock() + + go a.readLoop() + go a.writeLoop() + + a.payloadQueue.init(remoteInit.initialTSN - 1) + a.myMaxNumInboundStreams = min16(localInit.numInboundStreams, remoteInit.numInboundStreams) + a.myMaxNumOutboundStreams = min16(localInit.numOutboundStreams, remoteInit.numOutboundStreams) + a.setRWND(remoteInit.advertisedReceiverWindowCredit) + a.peerVerificationTag = remoteInit.initiateTag + a.sourcePort = defaultSCTPSrcDstPort + a.destinationPort = defaultSCTPSrcDstPort + for _, param := range remoteInit.params { + switch v := param.(type) { // nolint:gocritic + case *paramSupportedExtensions: + for _, t := range v.ChunkTypes { + if t == ctForwardTSN { + a.log.Debugf("[%s] use ForwardTSN (on init)", a.name) + a.useForwardTSN = true + } + } + case *paramZeroChecksumAcceptable: + a.sendZeroChecksum = v.edmid == dtlsErrorDetectionMethod + } + } + + if !a.useForwardTSN { + a.log.Warnf("[%s] not using ForwardTSN (on init)", a.name) + } + + a.ssthresh = a.RWND() + + a.setState(established) +} + +// caller must hold a.lock. +func (a *Association) sendInit() error { + a.log.Debugf("[%s] sending INIT", a.name) + if a.storedInit == nil { + return ErrInitNotStoredToSend + } + + outbound := &packet{} + outbound.verificationTag = 0 + a.sourcePort = defaultSCTPSrcDstPort + a.destinationPort = defaultSCTPSrcDstPort + outbound.sourcePort = a.sourcePort + outbound.destinationPort = a.destinationPort + + outbound.chunks = []chunk{a.storedInit} + + a.controlQueue.push(outbound) + a.awakeWriteLoop() + + return nil +} + +// caller must hold a.lock. +func (a *Association) sendCookieEcho() error { + if a.storedCookieEcho == nil { + return ErrCookieEchoNotStoredToSend + } + + a.log.Debugf("[%s] sending COOKIE-ECHO", a.name) + + outbound := &packet{} + outbound.verificationTag = a.peerVerificationTag + outbound.sourcePort = a.sourcePort + outbound.destinationPort = a.destinationPort + outbound.chunks = []chunk{a.storedCookieEcho} + + a.controlQueue.push(outbound) + a.awakeWriteLoop() + + return nil +} + +// Shutdown initiates the shutdown sequence. The method blocks until the +// shutdown sequence is completed and the connection is closed, or until the +// passed context is done, in which case the context's error is returned. +func (a *Association) Shutdown(ctx context.Context) error { + a.log.Debugf("[%s] closing association..", a.name) + + state := a.getState() + + if state != established { + return fmt.Errorf("%w: shutdown %s", ErrShutdownNonEstablished, a.name) + } + + // Attempt a graceful shutdown. + a.setState(shutdownPending) + + a.lock.Lock() + + if a.inflightQueue.size() == 0 { + // No more outstanding, send shutdown. + a.willSendShutdown = true + a.awakeWriteLoop() + a.setState(shutdownSent) + } + + a.lock.Unlock() + + select { + case <-a.closeWriteLoopCh: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +// Close ends the SCTP Association and cleans up any state. +func (a *Association) Close() error { + a.log.Debugf("[%s] closing association..", a.name) + + err := a.close() + + // Wait for readLoop to end + <-a.readLoopCloseCh + + a.log.Debugf("[%s] association closed", a.name) + a.log.Debugf("[%s] stats nPackets (in) : %d", a.name, a.stats.getNumPacketsReceived()) + a.log.Debugf("[%s] stats nPackets (out) : %d", a.name, a.stats.getNumPacketsSent()) + a.log.Debugf("[%s] stats nDATAs (in) : %d", a.name, a.stats.getNumDATAs()) + a.log.Debugf("[%s] stats nSACKs (in) : %d", a.name, a.stats.getNumSACKsReceived()) + a.log.Debugf("[%s] stats nSACKs (out) : %d", a.name, a.stats.getNumSACKsSent()) + a.log.Debugf("[%s] stats nT3Timeouts : %d", a.name, a.stats.getNumT3Timeouts()) + a.log.Debugf("[%s] stats nAckTimeouts: %d", a.name, a.stats.getNumAckTimeouts()) + a.log.Debugf("[%s] stats nFastRetrans: %d", a.name, a.stats.getNumFastRetrans()) + + return err +} + +func (a *Association) close() error { + a.log.Debugf("[%s] closing association..", a.name) + + a.setState(closed) + + err := a.netConn.Close() + + a.closeAllTimers() + + // awake writeLoop to exit + a.closeWriteLoopOnce.Do(func() { close(a.closeWriteLoopCh) }) + + return err +} + +// Abort sends the abort packet with user initiated abort and immediately +// closes the connection. +func (a *Association) Abort(reason string) { + a.log.Debugf("[%s] aborting association: %s", a.name, reason) + + a.lock.Lock() + + a.willSendAbort = true + a.willSendAbortCause = &errorCauseUserInitiatedAbort{ + upperLayerAbortReason: []byte(reason), + } + + a.lock.Unlock() + + flushTimeout := 200 * time.Millisecond + + // short bound for abort flush. + _ = a.netConn.SetWriteDeadline(time.Now().Add(flushTimeout)) + a.awakeWriteLoop() + + // Give writeLoop a chance to write the ABORT before we force readLoop to exit + // (readLoop exit closes closeWriteLoopCh and can race the ABORT send). + select { + case <-a.abortSentCh: + case <-time.After(flushTimeout): + } + + // unblock readLoop even if the underlying connection is half-open. + // We want Abort to return promptly during shutdown. + _ = a.netConn.SetReadDeadline(time.Now()) + + // Wait for readLoop to end + <-a.readLoopCloseCh + + // Ensure ABORT write was at least attempted before returning (bounded). + select { + case <-a.abortSentCh: + case <-time.After(flushTimeout): + } +} + +func (a *Association) closeAllTimers() { + // Close all retransmission & ack timers + a.t1Init.close() + a.t1Cookie.close() + a.t2Shutdown.close() + a.t3RTX.close() + a.tReconfig.close() + a.ackTimer.close() + a.stopRackTimer() + a.stopPTOTimer() +} + +func (a *Association) readLoop() { + var closeErr error + defer func() { + // also stop writeLoop, otherwise writeLoop can be leaked + // if connection is lost when there is no writing packet. + a.closeWriteLoopOnce.Do(func() { close(a.closeWriteLoopCh) }) + + a.lock.Lock() + a.setState(closed) + for _, s := range a.streams { + a.unregisterStream(s, closeErr) + } + a.lock.Unlock() + close(a.acceptCh) + close(a.readLoopCloseCh) + + a.log.Debugf("[%s] association closed", a.name) + a.log.Debugf("[%s] stats nDATAs (in) : %d", a.name, a.stats.getNumDATAs()) + a.log.Debugf("[%s] stats nSACKs (in) : %d", a.name, a.stats.getNumSACKsReceived()) + a.log.Debugf("[%s] stats nT3Timeouts : %d", a.name, a.stats.getNumT3Timeouts()) + a.log.Debugf("[%s] stats nAckTimeouts: %d", a.name, a.stats.getNumAckTimeouts()) + a.log.Debugf("[%s] stats nFastRetrans: %d", a.name, a.stats.getNumFastRetrans()) + }() + + a.log.Debugf("[%s] readLoop entered", a.name) + buffer := make([]byte, receiveMTU) + + for { + n, err := a.netConn.Read(buffer) + if err != nil { + closeErr = err + + break + } + // Make a buffer sized to what we read, then copy the data we + // read from the underlying transport. We do this because the + // user data is passed to the reassembly queue without + // copying. + inbound := make([]byte, n) + copy(inbound, buffer[:n]) + atomic.AddUint64(&a.bytesReceived, uint64(n)) //nolint:gosec // G115 + if err = a.handleInbound(inbound); err != nil { + closeErr = err + + break + } + } + + a.log.Debugf("[%s] readLoop exited %s", a.name, closeErr) +} + +func (a *Association) writeLoop() { // nolint:cyclop + a.log.Debugf("[%s] writeLoop entered", a.name) + defer a.log.Debugf("[%s] writeLoop exited", a.name) + +loop: + for { + rawPackets, ok := a.gatherOutbound() + + for _, raw := range rawPackets { + isAbortPacket := len(raw) > int(commonHeaderSize) && raw[commonHeaderSize] == byte(ctAbort) + _, err := a.netConn.Write(raw) + if isAbortPacket { + a.abortSentOnce.Do(func() { close(a.abortSentCh) }) + } + if err != nil { + if !errors.Is(err, io.EOF) { + a.log.Warnf("[%s] failed to write packets on netConn: %v", a.name, err) + } + a.log.Debugf("[%s] writeLoop ended", a.name) + + break loop + } + atomic.AddUint64(&a.bytesSent, uint64(len(raw))) + a.stats.incPacketsSent() + } + + if !ok { + if err := a.close(); err != nil { + a.log.Warnf("[%s] failed to close association: %v", a.name, err) + } + + return + } + + select { + case <-a.awakeWriteLoopCh: + case <-a.closeWriteLoopCh: + a.lock.Lock() + abortPending := a.willSendAbort + a.lock.Unlock() + if abortPending { + // If an ABORT is pending, prefer sending it even if readLoop has + // already ended and closed closeWriteLoopCh. + continue + } + + break loop + } + } + + a.setState(closed) + a.closeAllTimers() +} + +func (a *Association) awakeWriteLoop() { + select { + case a.awakeWriteLoopCh <- struct{}{}: + default: + } +} + +func (a *Association) isBlockWrite() bool { + return a.blockWrite +} + +// Mark the association is writable and unblock the waiting write, +// the caller should hold the association write lock. +func (a *Association) notifyBlockWritable() { + a.writePending = false + select { + case a.writeNotify <- struct{}{}: + default: + } +} + +// unregisterStream un-registers a stream from the association +// The caller should hold the association write lock. +func (a *Association) unregisterStream(s *Stream, err error) { + s.lock.Lock() + defer s.lock.Unlock() + + delete(a.streams, s.streamIdentifier) + s.readErr = err + s.readNotifier.Broadcast() +} + +func chunkMandatoryChecksum(cc []chunk) bool { + for _, c := range cc { + switch c.(type) { + case *chunkInit, *chunkCookieEcho: + return true + } + } + + return false +} + +func (a *Association) marshalPacket(p *packet) ([]byte, error) { + return p.marshal(!a.sendZeroChecksum || chunkMandatoryChecksum(p.chunks)) +} + +func (a *Association) unmarshalPacket(raw []byte) (*packet, error) { + p := &packet{} + if err := p.unmarshal(!a.recvZeroChecksum, raw); err != nil { + return nil, err + } + + return p, nil +} + +// handleInbound parses incoming raw packets. +func (a *Association) handleInbound(raw []byte) error { + pkt, err := a.unmarshalPacket(raw) + if err != nil { + a.log.Warnf("[%s] unable to parse SCTP packet %s", a.name, err) + + return nil + } + + if err := checkPacket(pkt); err != nil { + a.log.Warnf("[%s] failed validating packet %s", a.name, err) + + return nil + } + + a.handleChunksStart() + + for _, c := range pkt.chunks { + if err := a.handleChunk(pkt, c); err != nil { + return err + } + } + + a.handleChunksEnd() + + return nil +} + +// The caller should hold the lock. +func (a *Association) gatherDataPacketsToRetransmit(rawPackets [][]byte, budgetUnits *int64, consumed *bool) [][]byte { + for _, p := range a.getDataPacketsToRetransmit(budgetUnits, consumed) { + raw, err := a.marshalPacket(p) + if err != nil { + a.log.Warnf("[%s] failed to serialize a DATA packet to be retransmitted", a.name) + + continue + } + rawPackets = append(rawPackets, raw) + } + + return rawPackets +} + +// The caller should hold the lock. +// +//nolint:cyclop +func (a *Association) gatherOutboundDataAndReconfigPackets( + rawPackets [][]byte, + budgetUnits *int64, + consumed *bool, +) [][]byte { + // Pop unsent data chunks from the pending queue to send as much as + // cwnd and rwnd allow. + chunks, sisToReset := a.popPendingDataChunksToSend(budgetUnits, consumed) + + if len(chunks) > 0 { + // Start timer. (noop if already started) + a.log.Tracef("[%s] T3-rtx timer start (pt1)", a.name) + a.t3RTX.start(a.rtoMgr.getRTO()) + for _, p := range a.bundleDataChunksIntoPackets(chunks) { + raw, err := a.marshalPacket(p) + if err != nil { + a.log.Warnf("[%s] failed to serialize a DATA packet", a.name) + + continue + } + rawPackets = append(rawPackets, raw) + } + // RFC 8985 (RACK) schedule PTO on new data transmission + a.schedulePTOAfterSendLocked() + } + + if len(sisToReset) > 0 || a.willRetransmitReconfig { //nolint:nestif + if a.willRetransmitReconfig { + a.willRetransmitReconfig = false + a.log.Debugf("[%s] retransmit %d RECONFIG chunk(s)", a.name, len(a.reconfigs)) + for _, c := range a.reconfigs { + p := a.createPacket([]chunk{c}) + raw, err := a.marshalPacket(p) + if err != nil { + a.log.Warnf("[%s] failed to serialize a RECONFIG packet to be retransmitted", a.name) + } else { + rawPackets = append(rawPackets, raw) + } + } + } + + if len(sisToReset) > 0 { + rsn := a.generateNextRSN() + tsn := a.myNextTSN - 1 + c := &chunkReconfig{ + paramA: ¶mOutgoingResetRequest{ + reconfigRequestSequenceNumber: rsn, + senderLastTSN: tsn, + streamIdentifiers: sisToReset, + }, + } + a.reconfigs[rsn] = c // store in the map for retransmission + a.log.Debugf("[%s] sending RECONFIG: rsn=%d tsn=%d streams=%v", + a.name, rsn, a.myNextTSN-1, sisToReset) + p := a.createPacket([]chunk{c}) + raw, err := a.marshalPacket(p) + if err != nil { + a.log.Warnf("[%s] failed to serialize a RECONFIG packet to be transmitted", a.name) + } else { + rawPackets = append(rawPackets, raw) + } + } + + if len(a.reconfigs) > 0 { + a.tReconfig.start(a.rtoMgr.getRTO()) + } + } + + return rawPackets +} + +// The caller should hold the lock. +// +//nolint:cyclop +func (a *Association) gatherOutboundFastRetransmissionPackets( //nolint:gocognit + rawPackets [][]byte, + budgetScaled *int64, + consumed *bool, +) [][]byte { + if !a.willRetransmitFast { + return rawPackets + } + a.willRetransmitFast = false + + toFastRetrans := []*chunkPayloadData{} + fastRetransSize := int(commonHeaderSize) + fastRetransWnd := int(max(a.MTU(), a.fastRtxWnd)) + now := time.Now() + + // MTU bundling + burst budgeting tracker + bytesInPacket := 0 + stopBundling := false + + for i := 0; ; i++ { + chunkPayload, ok := a.inflightQueue.get(a.cumulativeTSNAckPoint + uint32(i) + 1) //nolint:gosec // G115 + if !ok { + break // end of pending data + } + + if chunkPayload.acked || chunkPayload.abandoned() { + continue + } + + if chunkPayload.nSent > 1 || chunkPayload.missIndicator < 3 { + continue + } + + // include padding for sizing. + chunkBytes := int(dataChunkHeaderSize) + len(chunkPayload.userData) + chunkBytes += getPadding(chunkBytes) + + // fast retransmit window cap + if fastRetransWnd < fastRetransSize+chunkBytes { + break + } + + // MTU bundling + burst budget before mutating + for { + addBytes := chunkBytes + + if bytesInPacket == 0 { + addBytes += int(commonHeaderSize) + if addBytes > int(a.MTU()) { + stopBundling = true + + break + } + } else if bytesInPacket+chunkBytes > int(a.MTU()) { + // start a new packet and retry this same chunk as first in packet + bytesInPacket = 0 + + continue + } + + if !a.tlrAllowSendLocked(budgetScaled, consumed, addBytes) { + // budget exhausted, stop selecting any more fast-rtx chunks + stopBundling = true + + break + } + + if bytesInPacket == 0 { + bytesInPacket = int(commonHeaderSize) + } + bytesInPacket += chunkBytes + + break + } + + if stopBundling { + break + } + + fastRetransSize += chunkBytes + a.stats.incFastRetrans() + + // Update for retransmission + chunkPayload.nSent++ + chunkPayload.since = now + a.rackRemove(chunkPayload) + a.rackInsert(chunkPayload) + + a.checkPartialReliabilityStatus(chunkPayload) + toFastRetrans = append(toFastRetrans, chunkPayload) + a.log.Tracef("[%s] fast-retransmit: tsn=%d sent=%d htna=%d", + a.name, chunkPayload.tsn, chunkPayload.nSent, a.fastRecoverExitPoint) + } + + if len(toFastRetrans) == 0 { + return rawPackets + } + + for _, p := range a.bundleDataChunksIntoPackets(toFastRetrans) { + raw, err := a.marshalPacket(p) + if err != nil { + a.log.Warnf("[%s] failed to serialize a DATA packet to be fast-retransmitted", a.name) + + continue + } + rawPackets = append(rawPackets, raw) + } + + return rawPackets +} + +// The caller should hold the lock. +func (a *Association) gatherOutboundSackPackets(rawPackets [][]byte) [][]byte { + if a.ackState == ackStateImmediate { + a.ackState = ackStateIdle + sack := a.createSelectiveAckChunk() + a.stats.incSACKsSent() + a.log.Debugf("[%s] sending SACK: %s", a.name, sack) + raw, err := a.marshalPacket(a.createPacket([]chunk{sack})) + if err != nil { + a.log.Warnf("[%s] failed to serialize a SACK packet", a.name) + } else { + rawPackets = append(rawPackets, raw) + } + } + + return rawPackets +} + +// The caller should hold the lock. +func (a *Association) gatherOutboundForwardTSNPackets(rawPackets [][]byte) [][]byte { + if a.willSendForwardTSN { + a.willSendForwardTSN = false + if sna32GT(a.advancedPeerTSNAckPoint, a.cumulativeTSNAckPoint) { + fwdtsn := a.createForwardTSN() + raw, err := a.marshalPacket(a.createPacket([]chunk{fwdtsn})) + if err != nil { + a.log.Warnf("[%s] failed to serialize a Forward TSN packet", a.name) + } else { + rawPackets = append(rawPackets, raw) + } + } + } + + return rawPackets +} + +func (a *Association) gatherOutboundShutdownPackets(rawPackets [][]byte) ([][]byte, bool) { + ok := true + + switch { + case a.willSendShutdown: + a.willSendShutdown = false + + shutdown := &chunkShutdown{ + cumulativeTSNAck: a.cumulativeTSNAckPoint, + } + + raw, err := a.marshalPacket(a.createPacket([]chunk{shutdown})) + if err != nil { + a.log.Warnf("[%s] failed to serialize a Shutdown packet", a.name) + } else { + a.t2Shutdown.start(a.rtoMgr.getRTO()) + rawPackets = append(rawPackets, raw) + } + case a.willSendShutdownAck: + a.willSendShutdownAck = false + + shutdownAck := &chunkShutdownAck{} + + raw, err := a.marshalPacket(a.createPacket([]chunk{shutdownAck})) + if err != nil { + a.log.Warnf("[%s] failed to serialize a ShutdownAck packet", a.name) + } else { + a.t2Shutdown.start(a.rtoMgr.getRTO()) + rawPackets = append(rawPackets, raw) + } + case a.willSendShutdownComplete: + a.willSendShutdownComplete = false + + shutdownComplete := &chunkShutdownComplete{} + + raw, err := a.marshalPacket(a.createPacket([]chunk{shutdownComplete})) + if err != nil { + a.log.Warnf("[%s] failed to serialize a ShutdownComplete packet", a.name) + } else { + rawPackets = append(rawPackets, raw) + ok = false + } + } + + return rawPackets, ok +} + +func (a *Association) gatherAbortPacket() ([]byte, error) { + cause := a.willSendAbortCause + + a.willSendAbort = false + a.willSendAbortCause = nil + + abort := &chunkAbort{} + + if cause != nil { + abort.errorCauses = []errorCause{cause} + } + + raw, err := a.marshalPacket(a.createPacket([]chunk{abort})) + + return raw, err +} + +// gatherOutbound gathers outgoing packets. The returned bool value set to +// false means the association should be closed down after the final send. +func (a *Association) gatherOutbound() ([][]byte, bool) { + a.lock.Lock() + defer a.lock.Unlock() + + if a.willSendAbort { + pkt, err := a.gatherAbortPacket() + if err != nil { + a.log.Warnf("[%s] failed to serialize an abort packet", a.name) + // If we can't marshal the ABORT, no write will occur, but Abort() may + // still be waiting on abortSentCh. Signal completion of the ABORT attempt + // (even though it failed) to avoid unnecessary delays. + a.abortSentOnce.Do(func() { close(a.abortSentCh) }) + + return nil, false + } + + return [][]byte{pkt}, false + } + + rawPackets := [][]byte{} + + if a.controlQueue.size() > 0 { + for _, p := range a.controlQueue.popAll() { + raw, err := a.marshalPacket(p) + if err != nil { + a.log.Warnf("[%s] failed to serialize a control packet", a.name) + + continue + } + rawPackets = append(rawPackets, raw) + } + } + + state := a.getState() + + ok := true + + switch state { + case established: + budgetUnits := a.tlrCurrentBurstBudgetScaledLocked() + consumed := false + + rawPackets = a.gatherDataPacketsToRetransmit(rawPackets, &budgetUnits, &consumed) + rawPackets = a.gatherOutboundDataAndReconfigPackets(rawPackets, &budgetUnits, &consumed) + rawPackets = a.gatherOutboundFastRetransmissionPackets(rawPackets, &budgetUnits, &consumed) + + // control traffic shouldn't be limited. + rawPackets = a.gatherOutboundSackPackets(rawPackets) + rawPackets = a.gatherOutboundForwardTSNPackets(rawPackets) + case shutdownPending, shutdownSent, shutdownReceived: + budgetUnits := a.tlrCurrentBurstBudgetScaledLocked() + consumed := false + + rawPackets = a.gatherDataPacketsToRetransmit(rawPackets, &budgetUnits, &consumed) + rawPackets = a.gatherOutboundFastRetransmissionPackets(rawPackets, &budgetUnits, &consumed) + + rawPackets = a.gatherOutboundSackPackets(rawPackets) + rawPackets, ok = a.gatherOutboundShutdownPackets(rawPackets) + case shutdownAckSent: + rawPackets, ok = a.gatherOutboundShutdownPackets(rawPackets) + } + + return rawPackets, ok +} + +func checkPacket(pkt *packet) error { + // All packets must adhere to these rules + + // This is the SCTP sender's port number. It can be used by the + // receiver in combination with the source IP address, the SCTP + // destination port, and possibly the destination IP address to + // identify the association to which this packet belongs. The port + // number 0 MUST NOT be used. + if pkt.sourcePort == 0 { + return ErrSCTPPacketSourcePortZero + } + + // This is the SCTP port number to which this packet is destined. + // The receiving host will use this port number to de-multiplex the + // SCTP packet to the correct receiving endpoint/application. The + // port number 0 MUST NOT be used. + if pkt.destinationPort == 0 { + return ErrSCTPPacketDestinationPortZero + } + + // Check values on the packet that are specific to a particular chunk type + for _, c := range pkt.chunks { + switch c.(type) { // nolint:gocritic + case *chunkInit: + // An INIT or INIT ACK chunk MUST NOT be bundled with any other chunk. + // They MUST be the only chunks present in the SCTP packets that carry + // them. + if len(pkt.chunks) != 1 { + return ErrInitChunkBundled + } + + // A packet containing an INIT chunk MUST have a zero Verification + // Tag. + if pkt.verificationTag != 0 { + return ErrInitChunkVerifyTagNotZero + } + } + } + + return nil +} + +func min16(a, b uint16) uint16 { + if a < b { + return a + } + + return b +} + +func max32(a, b uint32) uint32 { + if a > b { + return a + } + + return b +} + +func min32(a, b uint32) uint32 { + if a < b { + return a + } + + return b +} + +// peerLastTSN return last received cumulative TSN. +func (a *Association) peerLastTSN() uint32 { + return a.payloadQueue.getcumulativeTSN() +} + +// setState atomically sets the state of the Association. +// The caller should hold the lock. +func (a *Association) setState(newState uint32) { + oldState := atomic.SwapUint32(&a.state, newState) + if newState != oldState { + a.log.Debugf("[%s] state change: '%s' => '%s'", + a.name, + getAssociationStateString(oldState), + getAssociationStateString(newState)) + } +} + +// getState atomically returns the state of the Association. +func (a *Association) getState() uint32 { + return atomic.LoadUint32(&a.state) +} + +// BytesSent returns the number of bytes sent. +func (a *Association) BytesSent() uint64 { + return atomic.LoadUint64(&a.bytesSent) +} + +// BytesReceived returns the number of bytes received. +func (a *Association) BytesReceived() uint64 { + return atomic.LoadUint64(&a.bytesReceived) +} + +// MTU returns the association's current MTU. +func (a *Association) MTU() uint32 { + return atomic.LoadUint32(&a.mtu) +} + +// CWND returns the association's current congestion window (cwnd). +func (a *Association) CWND() uint32 { + return atomic.LoadUint32(&a.cwnd) +} + +func (a *Association) setCWND(cwnd uint32) { + if cwnd < a.minCwnd { + cwnd = a.minCwnd + } + atomic.StoreUint32(&a.cwnd, cwnd) +} + +// RWND returns the association's current receiver window (rwnd). +func (a *Association) RWND() uint32 { + return atomic.LoadUint32(&a.rwnd) +} + +func (a *Association) setRWND(rwnd uint32) { + atomic.StoreUint32(&a.rwnd, rwnd) +} + +// SRTT returns the latest smoothed round-trip time (srrt). +func (a *Association) SRTT() float64 { + return a.srtt.Load().(float64) //nolint:forcetypeassert +} + +// getMaxTSNOffset returns the maximum offset over the current cummulative TSN that +// we are willing to enqueue. This ensures that we keep the bytes utilized in the receive +// buffer within a small multiple of the user provided max receive buffer size. +func getMaxTSNOffset(maxReceiveBufferSize uint32) uint32 { + // 4 is a magic number here. There is no theory behind this. + offset := min(max((maxReceiveBufferSize*4)/avgChunkSize, minTSNOffset), maxTSNOffset) + + return offset +} + +func setSupportedExtensions(init *chunkInitCommon) { + // nolint:godox + // TODO RFC5061 https://tools.ietf.org/html/rfc6525#section-5.2 + // An implementation supporting this (Supported Extensions Parameter) + // extension MUST list the ASCONF, the ASCONF-ACK, and the AUTH chunks + // in its INIT and INIT-ACK parameters. + init.params = append(init.params, ¶mSupportedExtensions{ + ChunkTypes: []chunkType{ctReconfig, ctForwardTSN}, + }) +} + +// The caller should hold the lock. +// +//nolint:cyclop +func (a *Association) handleInit(pkt *packet, initChunk *chunkInit) ([]*packet, error) { + state := a.getState() + a.log.Debugf("[%s] chunkInit received in state '%s'", a.name, getAssociationStateString(state)) + + // https://tools.ietf.org/html/rfc4960#section-5.2.1 + // Upon receipt of an INIT in the COOKIE-WAIT state, an endpoint MUST + // respond with an INIT ACK using the same parameters it sent in its + // original INIT chunk (including its Initiate Tag, unchanged). When + // responding, the endpoint MUST send the INIT ACK back to the same + // address that the original INIT (sent by this endpoint) was sent. + + if state != closed && state != cookieWait && state != cookieEchoed { + // 5.2.2. Unexpected INIT in States Other than CLOSED, COOKIE-ECHOED, + // COOKIE-WAIT, and SHUTDOWN-ACK-SENT + return nil, fmt.Errorf("%w: %s", ErrHandleInitState, getAssociationStateString(state)) + } + + // NOTE: Setting these prior to a reception of a COOKIE ECHO chunk containing + // our cookie is not compliant with https://www.rfc-editor.org/rfc/rfc9260#section-5.1-2.2.3. + // It makes us more vulnerable to resource attacks, albeit minimally so. + // https://www.rfc-editor.org/rfc/rfc9260#sec_handle_stream_parameters + a.myMaxNumInboundStreams = min16(initChunk.numInboundStreams, a.myMaxNumInboundStreams) + a.myMaxNumOutboundStreams = min16(initChunk.numOutboundStreams, a.myMaxNumOutboundStreams) + a.peerVerificationTag = initChunk.initiateTag + a.sourcePort = pkt.destinationPort + a.destinationPort = pkt.sourcePort + + // 13.2 This is the last TSN received in sequence. This value + // is set initially by taking the peer's initial TSN, + // received in the INIT or INIT ACK chunk, and + // subtracting one from it. + a.payloadQueue.init(initChunk.initialTSN - 1) + + a.setRWND(initChunk.advertisedReceiverWindowCredit) + a.log.Debugf("[%s] initial rwnd=%d", a.name, a.RWND()) + + for _, param := range initChunk.params { + switch v := param.(type) { // nolint:gocritic + case *paramSupportedExtensions: + for _, t := range v.ChunkTypes { + if t == ctForwardTSN { + a.log.Debugf("[%s] use ForwardTSN (on init)", a.name) + a.useForwardTSN = true + } + } + case *paramZeroChecksumAcceptable: + a.sendZeroChecksum = v.edmid == dtlsErrorDetectionMethod + } + } + + if !a.useForwardTSN { + a.log.Warnf("[%s] not using ForwardTSN (on init)", a.name) + } + + outbound := &packet{} + outbound.verificationTag = a.peerVerificationTag + outbound.sourcePort = a.sourcePort + outbound.destinationPort = a.destinationPort + + initAck := &chunkInitAck{} + a.log.Debug("sending INIT ACK") + + initAck.initialTSN = a.myNextTSN + initAck.numOutboundStreams = a.myMaxNumOutboundStreams + initAck.numInboundStreams = a.myMaxNumInboundStreams + initAck.initiateTag = a.myVerificationTag + initAck.advertisedReceiverWindowCredit = a.maxReceiveBufferSize + + if a.myCookie == nil { + var err error + // NOTE: This generation process is not compliant with + // 5.1.3. Generating State Cookie (https://www.rfc-editor.org/rfc/rfc4960#section-5.1.3) + if a.myCookie, err = newRandomStateCookie(); err != nil { + return nil, err + } + } + + initAck.params = []param{a.myCookie} + + if a.recvZeroChecksum { + initAck.params = append(initAck.params, ¶mZeroChecksumAcceptable{edmid: dtlsErrorDetectionMethod}) + } + a.log.Debugf("[%s] sendZeroChecksum=%t (on init)", a.name, a.sendZeroChecksum) + + setSupportedExtensions(&initAck.chunkInitCommon) + + outbound.chunks = []chunk{initAck} + + return pack(outbound), nil +} + +// The caller should hold the lock. +func (a *Association) handleInitAck(pkt *packet, initChunkAck *chunkInitAck) error { //nolint:cyclop + state := a.getState() + a.log.Debugf("[%s] chunkInitAck received in state '%s'", a.name, getAssociationStateString(state)) + if state != cookieWait { + // RFC 4960 + // 5.2.3. Unexpected INIT ACK + // If an INIT ACK is received by an endpoint in any state other than the + // COOKIE-WAIT state, the endpoint should discard the INIT ACK chunk. + // An unexpected INIT ACK usually indicates the processing of an old or + // duplicated INIT chunk. + return nil + } + + a.myMaxNumInboundStreams = min16(initChunkAck.numInboundStreams, a.myMaxNumInboundStreams) + a.myMaxNumOutboundStreams = min16(initChunkAck.numOutboundStreams, a.myMaxNumOutboundStreams) + a.peerVerificationTag = initChunkAck.initiateTag + a.payloadQueue.init(initChunkAck.initialTSN - 1) + if a.sourcePort != pkt.destinationPort || + a.destinationPort != pkt.sourcePort { + a.log.Warnf("[%s] handleInitAck: port mismatch", a.name) + + return nil + } + + a.setRWND(initChunkAck.advertisedReceiverWindowCredit) + a.log.Debugf("[%s] initial rwnd=%d", a.name, a.RWND()) + + // RFC 4960 Sec 7.2.1 + // o The initial value of ssthresh MAY be arbitrarily high (for + // example, implementations MAY use the size of the receiver + // advertised window). + a.ssthresh = a.RWND() + a.log.Tracef("[%s] updated cwnd=%d ssthresh=%d inflight=%d (INI)", + a.name, a.CWND(), a.ssthresh, a.inflightQueue.getNumBytes()) + + a.t1Init.stop() + a.storedInit = nil + + var cookieParam *paramStateCookie + for _, param := range initChunkAck.params { + switch v := param.(type) { + case *paramStateCookie: + cookieParam = v + case *paramSupportedExtensions: + for _, t := range v.ChunkTypes { + if t == ctForwardTSN { + a.log.Debugf("[%s] use ForwardTSN (on initAck)", a.name) + a.useForwardTSN = true + } + } + case *paramZeroChecksumAcceptable: + a.sendZeroChecksum = v.edmid == dtlsErrorDetectionMethod + } + } + + a.log.Debugf("[%s] sendZeroChecksum=%t (on initAck)", a.name, a.sendZeroChecksum) + + if !a.useForwardTSN { + a.log.Warnf("[%s] not using ForwardTSN (on initAck)", a.name) + } + if cookieParam == nil { + return ErrInitAckNoCookie + } + + a.storedCookieEcho = &chunkCookieEcho{} + a.storedCookieEcho.cookie = cookieParam.cookie + + err := a.sendCookieEcho() + if err != nil { + a.log.Errorf("[%s] failed to send init: %s", a.name, err.Error()) + } + + a.t1Cookie.start(a.rtoMgr.getRTO()) + a.setState(cookieEchoed) + + return nil +} + +// The caller should hold the lock. +func (a *Association) handleHeartbeat(c *chunkHeartbeat) []*packet { + a.log.Tracef("[%s] chunkHeartbeat", a.name) + + if len(c.params) == 0 { + a.log.Warnf("[%s] Heartbeat without ParamHeartbeatInfo (no params)", a.name) + + return nil + } + + info, ok := c.params[0].(*paramHeartbeatInfo) + if !ok { + a.log.Warnf("[%s] Heartbeat without ParamHeartbeatInfo (got %T)", a.name, c.params[0]) + + return nil + } + + return pack(&packet{ + verificationTag: a.peerVerificationTag, + sourcePort: a.sourcePort, + destinationPort: a.destinationPort, + chunks: []chunk{&chunkHeartbeatAck{ + params: []param{ + ¶mHeartbeatInfo{ + heartbeatInformation: info.heartbeatInformation, + }, + }, + }}, + }) +} + +// The caller should hold the lock. +func (a *Association) handleHeartbeatAck(c *chunkHeartbeatAck) { + a.log.Tracef("[%s] chunkHeartbeatAck", a.name) + + if len(c.params) == 0 { + return + } + + info, ok := c.params[0].(*paramHeartbeatInfo) + if !ok { + a.log.Warnf("[%s] HeartbeatAck without ParamHeartbeatInfo", a.name) + + return + } + + // active RTT probe: if heartbeatInformation is exactly 8 bytes, treat it + // as a big-endian unix nano timestamp. + if len(info.heartbeatInformation) == 8 { + ns := binary.BigEndian.Uint64(info.heartbeatInformation) + if ns > math.MaxInt64 { + // Malformed or future-unsafe value; ignore this heartbeat-ack. + a.log.Warnf("[%s] HB RTT: timestamp overflows int64, ignoring", a.name) + + return + } + + sentNanos := int64(ns) + sent := time.Unix(0, sentNanos) + now := time.Now() + + if !sent.IsZero() && !now.Before(sent) { + rttMs := now.Sub(sent).Seconds() * 1000.0 + srtt := a.rtoMgr.setNewRTT(rttMs) + a.srtt.Store(srtt) + + a.rack.rackMinRTTWnd.Push(now, now.Sub(sent)) + + a.log.Tracef("[%s] HB RTT: measured=%.3fms srtt=%.3fms rto=%.3fms", + a.name, rttMs, srtt, a.rtoMgr.getRTO()) + } + } +} + +// The caller should hold the lock. +func (a *Association) handleCookieEcho(cookieEcho *chunkCookieEcho) []*packet { + state := a.getState() + a.log.Debugf("[%s] COOKIE-ECHO received in state '%s'", a.name, getAssociationStateString(state)) + + if a.myCookie == nil { + a.log.Debugf("[%s] COOKIE-ECHO received before initialization", a.name) + + return nil + } + switch state { + default: + return nil + case established: + if !bytes.Equal(a.myCookie.cookie, cookieEcho.cookie) { + return nil + } + case closed, cookieWait, cookieEchoed: + if !bytes.Equal(a.myCookie.cookie, cookieEcho.cookie) { + return nil + } + + // RFC wise, these do not seem to belong here, but removing them + // causes TestCookieEchoRetransmission to break + a.t1Init.stop() + a.storedInit = nil + + a.t1Cookie.stop() + a.storedCookieEcho = nil + + a.setState(established) + if !a.completeHandshake(nil) { + return nil + } + } + + p := &packet{ + verificationTag: a.peerVerificationTag, + sourcePort: a.sourcePort, + destinationPort: a.destinationPort, + chunks: []chunk{&chunkCookieAck{}}, + } + + return pack(p) +} + +// The caller should hold the lock. +func (a *Association) handleCookieAck() { + state := a.getState() + a.log.Debugf("[%s] COOKIE-ACK received in state '%s'", a.name, getAssociationStateString(state)) + if state != cookieEchoed { + // RFC 4960 + // 5.2.5. Handle Duplicate COOKIE-ACK. + // At any state other than COOKIE-ECHOED, an endpoint should silently + // discard a received COOKIE ACK chunk. + return + } + + a.t1Cookie.stop() + a.storedCookieEcho = nil + + a.setState(established) + a.completeHandshake(nil) +} + +// The caller should hold the lock. +func (a *Association) handleData(chunkPayload *chunkPayloadData) []*packet { + a.log.Tracef("[%s] DATA: tsn=%d immediateSack=%v len=%d", + a.name, chunkPayload.tsn, chunkPayload.immediateSack, len(chunkPayload.userData)) + a.stats.incDATAs() + + canPush := a.payloadQueue.canPush(chunkPayload.tsn) + if canPush { //nolint:nestif + stream := a.getOrCreateStream(chunkPayload.streamIdentifier, true, PayloadTypeUnknown) + if stream == nil { + // silently discard the data. (sender will retry on T3-rtx timeout) + // see pion/sctp#30 + a.log.Debugf("[%s] discard %d", a.name, chunkPayload.streamSequenceNumber) + + return nil + } + + if a.getMyReceiverWindowCredit() > 0 { + // Pass the new chunk to stream level as soon as it arrives + a.payloadQueue.push(chunkPayload.tsn) + stream.handleData(chunkPayload) + } else { + // Receive buffer is full + lastTSN, ok := a.payloadQueue.getLastTSNReceived() + if ok && sna32LT(chunkPayload.tsn, lastTSN) { + a.log.Debugf( + "[%s] receive buffer full, but accepted as this is a missing chunk with tsn=%d ssn=%d", + a.name, chunkPayload.tsn, chunkPayload.streamSequenceNumber, + ) + a.payloadQueue.push(chunkPayload.tsn) + stream.handleData(chunkPayload) + } else { + a.log.Debugf( + "[%s] receive buffer full. dropping DATA with tsn=%d ssn=%d", + a.name, chunkPayload.tsn, chunkPayload.streamSequenceNumber, + ) + } + } + } + + // Upon the reception of a new DATA chunk, an endpoint shall examine the + // continuity of the TSNs received. If the endpoint detects a gap in + // the received DATA chunk sequence, it SHOULD send a SACK with Gap Ack + // Blocks immediately. The data receiver continues sending a SACK after + // receipt of each SCTP packet that doesn't fill the gap. + // https://datatracker.ietf.org/doc/html/rfc4960#section-6.7 + expectedTSN := a.peerLastTSN() + 1 + gapDetected := sna32GT(chunkPayload.tsn, expectedTSN) + + sackNow := chunkPayload.immediateSack || gapDetected + + return a.handlePeerLastTSNAndAcknowledgement(sackNow) +} + +// A common routine for handleData and handleForwardTSN routines +// The caller should hold the lock. +func (a *Association) handlePeerLastTSNAndAcknowledgement(sackImmediately bool) []*packet { //nolint:cyclop + var reply []*packet + + // Try to advance peerLastTSN + + // From RFC 3758 Sec 3.6: + // .. and then MUST further advance its cumulative TSN point locally + // if possible + // Meaning, if peerLastTSN+1 points to a chunk that is received, + // advance peerLastTSN until peerLastTSN+1 points to unreceived chunk. + for { + if popOk := a.payloadQueue.pop(false); !popOk { + break + } + + for _, rstReq := range a.reconfigRequests { + resp := a.resetStreamsIfAny(rstReq) + if resp != nil { + a.log.Debugf("[%s] RESET RESPONSE: %+v", a.name, resp) + reply = append(reply, resp) + } + } + } + + hasPacketLoss := (a.payloadQueue.size() > 0) + if hasPacketLoss { + a.log.Tracef("[%s] packetloss: %s", a.name, a.payloadQueue.getGapAckBlocksString()) + } + + // RFC 4960 $6.7: SHOULD ack immediately when detecting a gap. + if sackImmediately || hasPacketLoss || a.ackMode == ackModeNoDelay { + a.immediateAckTriggered = true + + return reply + } + + if a.ackMode == ackModeAlwaysDelay || (a.ackMode == ackModeNormal && a.ackState != ackStateImmediate) { + if a.ackState == ackStateIdle { + a.delayedAckTriggered = true + } else { + a.immediateAckTriggered = true + } + + return reply + } + + a.immediateAckTriggered = true + + return reply +} + +// The caller should hold the lock. +func (a *Association) getMyReceiverWindowCredit() uint32 { + var bytesQueued uint32 + for _, s := range a.streams { + bytesQueued += uint32(s.getNumBytesInReassemblyQueue()) //nolint:gosec // G115 + } + + if bytesQueued >= a.maxReceiveBufferSize { + return 0 + } + + return a.maxReceiveBufferSize - bytesQueued +} + +// OpenStream opens a stream. +func (a *Association) OpenStream( + streamIdentifier uint16, + defaultPayloadType PayloadProtocolIdentifier, +) (*Stream, error) { + a.lock.Lock() + defer a.lock.Unlock() + + switch a.getState() { + case shutdownAckSent, shutdownPending, shutdownReceived, shutdownSent, closed: + return nil, ErrAssociationClosed + } + + return a.getOrCreateStream(streamIdentifier, false, defaultPayloadType), nil +} + +// AcceptStream accepts a stream. +func (a *Association) AcceptStream() (*Stream, error) { + s, ok := <-a.acceptCh + if !ok { + return nil, io.EOF // no more incoming streams + } + + return s, nil +} + +// createStream creates a stream. The caller should hold the lock and check no stream exists for this id. +func (a *Association) createStream(streamIdentifier uint16, accept bool) *Stream { + stream := &Stream{ + association: a, + streamIdentifier: streamIdentifier, + reassemblyQueue: newReassemblyQueue(streamIdentifier), + log: a.log, + name: fmt.Sprintf("%d:%s", streamIdentifier, a.name), + writeDeadline: deadline.New(), + } + + stream.readNotifier = sync.NewCond(&stream.lock) + + if accept { + select { + case a.acceptCh <- stream: + a.streams[streamIdentifier] = stream + a.log.Debugf("[%s] accepted a new stream (streamIdentifier: %d)", + a.name, streamIdentifier) + default: + a.log.Debugf("[%s] dropped a new stream (acceptCh size: %d)", + a.name, len(a.acceptCh)) + + return nil + } + } else { + a.streams[streamIdentifier] = stream + } + + return stream +} + +// getOrCreateStream gets or creates a stream. The caller should hold the lock. +func (a *Association) getOrCreateStream( + streamIdentifier uint16, + accept bool, + defaultPayloadType PayloadProtocolIdentifier, +) *Stream { + if s, ok := a.streams[streamIdentifier]; ok { + s.SetDefaultPayloadType(defaultPayloadType) + + return s + } + + s := a.createStream(streamIdentifier, accept) + if s != nil { + s.SetDefaultPayloadType(defaultPayloadType) + } + + return s +} + +// The caller should hold the lock. +// +//nolint:gocognit,cyclop +func (a *Association) processSelectiveAck(selectiveAckChunk *chunkSelectiveAck) ( + bytesAckedPerStream map[uint16]int, + htna uint32, + newestDeliveredSendTime time.Time, + newestDeliveredOrigTSN uint32, + deliveredFound bool, + err error, +) { + bytesAckedPerStream = map[uint16]int{} + now := time.Now() // capture the time for this SACK + + // New ack point, so pop all ACKed packets from inflightQueue + // We add 1 because the "currentAckPoint" has already been popped from the inflight queue + // For the first SACK we take care of this by setting the ackpoint to cumAck - 1 + for idx := a.cumulativeTSNAckPoint + 1; sna32LTE(idx, selectiveAckChunk.cumulativeTSNAck); idx++ { + chunkPayload, ok := a.inflightQueue.pop(idx) + if !ok { + return nil, 0, time.Time{}, 0, false, fmt.Errorf("%w: %v", ErrInflightQueueTSNPop, idx) + } + + // RACK: remove from xmit-time list since it's delivered + a.rackRemove(chunkPayload) + + if !chunkPayload.acked { //nolint:nestif + // RFC 4960 sec 6.3.2. Retransmission Timer Rules + // R3) Whenever a SACK is received that acknowledges the DATA chunk + // with the earliest outstanding TSN for that address, restart the + // T3-rtx timer for that address with its current RTO (if there is + // still outstanding data on that address). + if idx == a.cumulativeTSNAckPoint+1 { + // T3 timer needs to be reset. Stop it for now. + a.t3RTX.stop() + } + + nBytesAcked := len(chunkPayload.userData) + + // Sum the number of bytes acknowledged per stream + if amount, ok := bytesAckedPerStream[chunkPayload.streamIdentifier]; ok { + bytesAckedPerStream[chunkPayload.streamIdentifier] = amount + nBytesAcked + } else { + bytesAckedPerStream[chunkPayload.streamIdentifier] = nBytesAcked + } + + // RFC 4960 sec 6.3.1. RTO Calculation + // C4) When data is in flight and when allowed by rule C5 below, a new + // RTT measurement MUST be made each round trip. Furthermore, new + // RTT measurements SHOULD be made no more than once per round trip + // for a given destination transport address. + // C5) Karn's algorithm: RTT measurements MUST NOT be made using + // packets that were retransmitted (and thus for which it is + // ambiguous whether the reply was for the first instance of the + // chunk or for a later instance) + if sna32GTE(chunkPayload.tsn, a.minTSN2MeasureRTT) { + // Only original transmissions for classic RTT measurement (Karn's rule) + if chunkPayload.nSent == 1 { + a.minTSN2MeasureRTT = a.myNextTSN + rtt := now.Sub(chunkPayload.since).Seconds() * 1000.0 + srtt := a.rtoMgr.setNewRTT(rtt) + a.srtt.Store(srtt) + + // use a window to determine minRtt instead of a global min + // as the RTT can fluctuate, which can cause problems if going from a + // high RTT to a low RTT. + a.rack.rackMinRTTWnd.Push(now, now.Sub(chunkPayload.since)) + + a.log.Tracef("[%s] SACK: measured-rtt=%f srtt=%f new-rto=%f", + a.name, rtt, srtt, a.rtoMgr.getRTO()) + } + } + + // RFC 8985 (RACK) sec 5.2: RACK.segment is the most recently sent + // segment that has been delivered, including retransmissions. + if chunkPayload.since.After(newestDeliveredSendTime) { + newestDeliveredSendTime = chunkPayload.since + newestDeliveredOrigTSN = chunkPayload.tsn + deliveredFound = true + } + } + + if a.inFastRecovery && chunkPayload.tsn == a.fastRecoverExitPoint { + a.log.Debugf("[%s] exit fast-recovery", a.name) + a.inFastRecovery = false + } + } + + htna = selectiveAckChunk.cumulativeTSNAck + + // Mark selectively acknowledged chunks as "acked" + for _, g := range selectiveAckChunk.gapAckBlocks { + for i := g.start; i <= g.end; i++ { + tsn := selectiveAckChunk.cumulativeTSNAck + uint32(i) + chunkPayload, ok := a.inflightQueue.get(tsn) + if !ok { + return nil, 0, time.Time{}, 0, false, fmt.Errorf("%w: %v", ErrTSNRequestNotExist, tsn) + } + + // RACK: remove from xmit-time list since it's delivered + a.rackRemove(chunkPayload) + + if !chunkPayload.acked { //nolint:nestif + nBytesAcked := a.inflightQueue.markAsAcked(tsn) + + // Sum the number of bytes acknowledged per stream + if amount, ok := bytesAckedPerStream[chunkPayload.streamIdentifier]; ok { + bytesAckedPerStream[chunkPayload.streamIdentifier] = amount + nBytesAcked + } else { + bytesAckedPerStream[chunkPayload.streamIdentifier] = nBytesAcked + } + + a.log.Tracef("[%s] tsn=%d has been sacked", a.name, chunkPayload.tsn) + + // RTT / RTO and RACK updates + if sna32GTE(chunkPayload.tsn, a.minTSN2MeasureRTT) { + // Only original transmissions for classic RTT measurement + if chunkPayload.nSent == 1 { + a.minTSN2MeasureRTT = a.myNextTSN + rtt := now.Sub(chunkPayload.since).Seconds() * 1000.0 + srtt := a.rtoMgr.setNewRTT(rtt) + a.srtt.Store(srtt) + + a.rack.rackMinRTTWnd.Push(now, now.Sub(chunkPayload.since)) + + a.log.Tracef("[%s] SACK: measured-rtt=%f srtt=%f new-rto=%f", + a.name, rtt, srtt, a.rtoMgr.getRTO()) + } + } + + if chunkPayload.since.After(newestDeliveredSendTime) { + newestDeliveredSendTime = chunkPayload.since + newestDeliveredOrigTSN = chunkPayload.tsn + deliveredFound = true + } + } + + if sna32LT(htna, tsn) { + htna = tsn + } + } + } + + return bytesAckedPerStream, htna, newestDeliveredSendTime, newestDeliveredOrigTSN, deliveredFound, nil +} + +// The caller should hold the lock. +func (a *Association) onCumulativeTSNAckPointAdvanced(totalBytesAcked int) { + // RFC 4960, sec 6.3.2. Retransmission Timer Rules + // R2) Whenever all outstanding data sent to an address have been + // acknowledged, turn off the T3-rtx timer of that address. + if a.inflightQueue.size() == 0 { + a.log.Tracef("[%s] SACK: no more packet in-flight (pending=%d)", a.name, a.pendingQueue.size()) + a.t3RTX.stop() + a.stopPTOTimer() + a.stopRackTimer() + } else { + a.log.Tracef("[%s] T3-rtx timer start (pt2)", a.name) + a.t3RTX.start(a.rtoMgr.getRTO()) + } + + // Update congestion control parameters + if a.CWND() <= a.ssthresh { //nolint:nestif + // RFC 4960, sec 7.2.1. Slow-Start + // o When cwnd is less than or equal to ssthresh, an SCTP endpoint MUST + // use the slow-start algorithm to increase cwnd only if the current + // congestion window is being fully utilized, an incoming SACK + // advances the Cumulative TSN Ack Point, and the data sender is not + // in Fast Recovery. Only when these three conditions are met can + // the cwnd be increased; otherwise, the cwnd MUST not be increased. + // If these conditions are met, then cwnd MUST be increased by, at + // most, the lesser of 1) the total size of the previously + // outstanding DATA chunk(s) acknowledged, and 2) the destination's + // path MTU. + if !a.inFastRecovery && + a.pendingQueue.size() > 0 { + a.setCWND(a.CWND() + min32(uint32(totalBytesAcked), a.CWND())) //nolint:gosec // G115 + // a.cwnd += min32(uint32(totalBytesAcked), a.MTU()) // SCTP way (slow) + a.log.Tracef("[%s] updated cwnd=%d ssthresh=%d acked=%d (SS)", + a.name, a.CWND(), a.ssthresh, totalBytesAcked) + } else { + a.log.Tracef("[%s] cwnd did not grow: cwnd=%d ssthresh=%d acked=%d FR=%v pending=%d", + a.name, a.CWND(), a.ssthresh, totalBytesAcked, a.inFastRecovery, a.pendingQueue.size()) + } + } else { + // RFC 4960, sec 7.2.2. Congestion Avoidance + // o Whenever cwnd is greater than ssthresh, upon each SACK arrival + // that advances the Cumulative TSN Ack Point, increase + // partial_bytes_acked by the total number of bytes of all new chunks + // acknowledged in that SACK including chunks acknowledged by the new + // Cumulative TSN Ack and by Gap Ack Blocks. + a.partialBytesAcked += uint32(totalBytesAcked) //nolint:gosec // G115 + + // o When partial_bytes_acked is equal to or greater than cwnd and + // before the arrival of the SACK the sender had cwnd or more bytes + // of data outstanding (i.e., before arrival of the SACK, flight size + // was greater than or equal to cwnd), increase cwnd by MTU, and + // reset partial_bytes_acked to (partial_bytes_acked - cwnd). + if a.partialBytesAcked >= a.CWND() && a.pendingQueue.size() > 0 { + a.partialBytesAcked -= a.CWND() + step := max(a.MTU(), a.cwndCAStep) + a.setCWND(a.CWND() + step) + a.log.Tracef("[%s] updated cwnd=%d ssthresh=%d acked=%d (CA)", + a.name, a.CWND(), a.ssthresh, totalBytesAcked) + } + } +} + +// The caller should hold the lock. +// +//nolint:cyclop +func (a *Association) processFastRetransmission( //nolint:gocognit + cumTSNAckPoint uint32, + gapAckBlocks []gapAckBlock, + htna uint32, + cumTSNAckPointAdvanced bool, +) error { + // HTNA algorithm - RFC 4960 Sec 7.2.4 + // Increment missIndicator of each chunks that the SACK reported missing + // when either of the following is met: + // a) Not in fast-recovery + // miss indications are incremented only for missing TSNs prior to the + // highest TSN newly acknowledged in the SACK. + // b) In fast-recovery AND the Cumulative TSN Ack Point advanced + // the miss indications are incremented for all TSNs reported missing + // in the SACK. + //nolint:nestif + if !a.inFastRecovery || + (a.inFastRecovery && cumTSNAckPointAdvanced) { + var maxTSN uint32 + if !a.inFastRecovery { + // a) increment only for missing TSNs prior to the HTNA + maxTSN = htna + } else { + // b) increment for all TSNs reported missing + maxTSN = cumTSNAckPoint + if len(gapAckBlocks) > 0 { + maxTSN += uint32(gapAckBlocks[len(gapAckBlocks)-1].end) + } + } + + for tsn := cumTSNAckPoint + 1; sna32LT(tsn, maxTSN); tsn++ { + c, ok := a.inflightQueue.get(tsn) + if !ok { + return fmt.Errorf("%w: %v", ErrTSNRequestNotExist, tsn) + } + if !c.acked && !c.abandoned() && c.missIndicator < 3 { + c.missIndicator++ + if c.missIndicator == 3 { + if a.tlrActive { + a.tlrApplyAdditionalLossLocked(time.Now()) + } + + if !a.inFastRecovery { + // 2) If not in Fast Recovery, adjust the ssthresh and cwnd of the + // destination address(es) to which the missing DATA chunks were + // last sent, according to the formula described in Section 7.2.3. + a.inFastRecovery = true + a.fastRecoverExitPoint = htna + a.ssthresh = max32(a.CWND()/2, 4*a.MTU()) + a.setCWND(a.ssthresh) + a.partialBytesAcked = 0 + a.willRetransmitFast = true + + a.log.Tracef("[%s] updated cwnd=%d ssthresh=%d inflight=%d (FR)", + a.name, a.CWND(), a.ssthresh, a.inflightQueue.getNumBytes()) + } + } + } + } + } + + if a.inFastRecovery && cumTSNAckPointAdvanced { + a.willRetransmitFast = true + } + + return nil +} + +// The caller should hold the lock. +// +//nolint:cyclop +func (a *Association) handleSack(selectiveAckChunk *chunkSelectiveAck) error { + a.log.Tracef( + "[%s] SACK: cumTSN=%d a_rwnd=%d", + a.name, selectiveAckChunk.cumulativeTSNAck, selectiveAckChunk.advertisedReceiverWindowCredit, + ) + state := a.getState() + if state != established && state != shutdownPending && state != shutdownReceived { + return nil + } + + a.stats.incSACKsReceived() + + if sna32GT(a.cumulativeTSNAckPoint, selectiveAckChunk.cumulativeTSNAck) { + // RFC 4960 sec 6.2.1. Processing a Received SACK + // D) + // i) If Cumulative TSN Ack is less than the Cumulative TSN Ack + // Point, then drop the SACK. Since Cumulative TSN Ack is + // monotonically increasing, a SACK whose Cumulative TSN Ack is + // less than the Cumulative TSN Ack Point indicates an out-of- + // order SACK. + + a.log.Debugf("[%s] SACK Cumulative ACK %v is older than ACK point %v", + a.name, + selectiveAckChunk.cumulativeTSNAck, + a.cumulativeTSNAckPoint) + + return nil + } + + // Process selective ack + bytesAckedPerStream, htna, + newestDeliveredSendTime, newestDeliveredOrigTSN, + deliveredFound, err := a.processSelectiveAck(selectiveAckChunk) + if err != nil { + return err + } + + var totalBytesAcked int + for _, nBytesAcked := range bytesAckedPerStream { + totalBytesAcked += nBytesAcked + } + + cumTSNAckPointAdvanced := false + if sna32LT(a.cumulativeTSNAckPoint, selectiveAckChunk.cumulativeTSNAck) { + a.log.Tracef("[%s] SACK: cumTSN advanced: %d -> %d", + a.name, + a.cumulativeTSNAckPoint, + selectiveAckChunk.cumulativeTSNAck) + + a.cumulativeTSNAckPoint = selectiveAckChunk.cumulativeTSNAck + cumTSNAckPointAdvanced = true + a.onCumulativeTSNAckPointAdvanced(totalBytesAcked) + } + + for si, nBytesAcked := range bytesAckedPerStream { + if s, ok := a.streams[si]; ok { + a.lock.Unlock() + s.onBufferReleased(nBytesAcked) + a.lock.Lock() + } + } + + // New rwnd value + // RFC 4960 sec 6.2.1. Processing a Received SACK + // D) + // ii) Set rwnd equal to the newly received a_rwnd minus the number + // of bytes still outstanding after processing the Cumulative + // TSN Ack and the Gap Ack Blocks. + + // bytes acked were already subtracted by markAsAcked() method + bytesOutstanding := uint32(a.inflightQueue.getNumBytes()) //nolint:gosec // G115 + if bytesOutstanding >= selectiveAckChunk.advertisedReceiverWindowCredit { + a.setRWND(0) + } else { + a.setRWND(selectiveAckChunk.advertisedReceiverWindowCredit - bytesOutstanding) + } + + err = a.processFastRetransmission( + selectiveAckChunk.cumulativeTSNAck, selectiveAckChunk.gapAckBlocks, htna, cumTSNAckPointAdvanced, + ) + if err != nil { + return err + } + + if a.useForwardTSN { + // RFC 3758 Sec 3.5 C1 + if sna32LT(a.advancedPeerTSNAckPoint, a.cumulativeTSNAckPoint) { + a.advancedPeerTSNAckPoint = a.cumulativeTSNAckPoint + } + + // RFC 3758 Sec 3.5 C2 + for i := a.advancedPeerTSNAckPoint + 1; ; i++ { + c, ok := a.inflightQueue.get(i) + if !ok { + break + } + if !c.abandoned() { + break + } + a.advancedPeerTSNAckPoint = i + } + + // RFC 3758 Sec 3.5 C3 + if sna32GT(a.advancedPeerTSNAckPoint, a.cumulativeTSNAckPoint) { + a.willSendForwardTSN = true + } + a.awakeWriteLoop() + } + + a.postprocessSack(state, cumTSNAckPointAdvanced) + + // RACK + a.onRackAfterSACK(deliveredFound, newestDeliveredSendTime, newestDeliveredOrigTSN, selectiveAckChunk) + + // adaptive burst mitigation + ackProgress := cumTSNAckPointAdvanced || deliveredFound + a.tlrMaybeFinishLocked(ackProgress) + + return nil +} + +// The caller must hold the lock. This method was only added because the +// linter was complaining about the "cognitive complexity" of handleSack. +func (a *Association) postprocessSack(state uint32, shouldAwakeWriteLoop bool) { + switch { + case a.inflightQueue.size() > 0: + // Start timer. (noop if already started) + a.log.Tracef("[%s] T3-rtx timer start (pt3)", a.name) + a.t3RTX.start(a.rtoMgr.getRTO()) + case state == shutdownPending: + // No more outstanding, send shutdown. + shouldAwakeWriteLoop = true + a.willSendShutdown = true + a.setState(shutdownSent) + case state == shutdownReceived: + // No more outstanding, send shutdown ack. + shouldAwakeWriteLoop = true + a.willSendShutdownAck = true + a.setState(shutdownAckSent) + } + + if shouldAwakeWriteLoop { + a.awakeWriteLoop() + } +} + +// The caller should hold the lock. +func (a *Association) handleShutdown(_ *chunkShutdown) { + state := a.getState() + + switch state { + case established: + if a.inflightQueue.size() > 0 { + a.setState(shutdownReceived) + } else { + // No more outstanding, send shutdown ack. + a.willSendShutdownAck = true + a.setState(shutdownAckSent) + + a.awakeWriteLoop() + } + + // a.cumulativeTSNAckPoint = c.cumulativeTSNAck + case shutdownSent: + a.willSendShutdownAck = true + a.setState(shutdownAckSent) + + a.awakeWriteLoop() + } +} + +// The caller should hold the lock. +func (a *Association) handleShutdownAck(_ *chunkShutdownAck) { + state := a.getState() + if state == shutdownSent || state == shutdownAckSent { + a.t2Shutdown.stop() + a.willSendShutdownComplete = true + + a.awakeWriteLoop() + } +} + +func (a *Association) handleShutdownComplete(_ *chunkShutdownComplete) error { + state := a.getState() + if state == shutdownAckSent { + a.t2Shutdown.stop() + + return a.close() + } + + return nil +} + +func (a *Association) handleAbort(c *chunkAbort) error { + var errStr strings.Builder + for _, e := range c.errorCauses { + fmt.Fprintf(&errStr, "(%s)", e) + } + + _ = a.close() + + return fmt.Errorf("[%s] %w: %s", a.name, ErrChunk, errStr.String()) +} + +// createForwardTSN generates ForwardTSN chunk. +// This method will be be called if useForwardTSN is set to false. +// The caller should hold the lock. +func (a *Association) createForwardTSN() *chunkForwardTSN { + // RFC 3758 Sec 3.5 C4 + streamMap := map[uint16]uint16{} // to report only once per SI + for i := a.cumulativeTSNAckPoint + 1; sna32LTE(i, a.advancedPeerTSNAckPoint); i++ { + c, ok := a.inflightQueue.get(i) + if !ok { + break + } + + ssn, ok := streamMap[c.streamIdentifier] + if !ok { + streamMap[c.streamIdentifier] = c.streamSequenceNumber + } else if sna16LT(ssn, c.streamSequenceNumber) { + // to report only once with greatest SSN + streamMap[c.streamIdentifier] = c.streamSequenceNumber + } + } + + fwdtsn := &chunkForwardTSN{ + newCumulativeTSN: a.advancedPeerTSNAckPoint, + streams: []chunkForwardTSNStream{}, + } + + var streamStr strings.Builder + for si, ssn := range streamMap { + fmt.Fprintf(&streamStr, "(si=%d ssn=%d)", si, ssn) + fwdtsn.streams = append(fwdtsn.streams, chunkForwardTSNStream{ + identifier: si, + sequence: ssn, + }) + } + a.log.Tracef( + "[%s] building fwdtsn: newCumulativeTSN=%d cumTSN=%d - %s", + a.name, fwdtsn.newCumulativeTSN, a.cumulativeTSNAckPoint, streamStr.String(), + ) + + return fwdtsn +} + +// createPacket wraps chunks in a packet. +// The caller should hold the read lock. +func (a *Association) createPacket(cs []chunk) *packet { + return &packet{ + verificationTag: a.peerVerificationTag, + sourcePort: a.sourcePort, + destinationPort: a.destinationPort, + chunks: cs, + } +} + +// The caller should hold the lock. +func (a *Association) handleReconfig(reconfigChunk *chunkReconfig) ([]*packet, error) { + a.log.Tracef("[%s] handleReconfig", a.name) + + pp := make([]*packet, 0) + + pkt, err := a.handleReconfigParam(reconfigChunk.paramA) + if err != nil { + return nil, err + } + if pkt != nil { + pp = append(pp, pkt) + } + + if reconfigChunk.paramB != nil { + pkt, err = a.handleReconfigParam(reconfigChunk.paramB) + if err != nil { + return nil, err + } + if pkt != nil { + pp = append(pp, pkt) + } + } + + return pp, nil +} + +// The caller should hold the lock. +func (a *Association) handleForwardTSN(chunkTSN *chunkForwardTSN) []*packet { + a.log.Tracef("[%s] FwdTSN: %s", a.name, chunkTSN.String()) + + if !a.useForwardTSN { + a.log.Warn("[%s] received FwdTSN but not enabled") + // Return an error chunk + cerr := &chunkError{ + errorCauses: []errorCause{&errorCauseUnrecognizedChunkType{}}, + } + outbound := &packet{} + outbound.verificationTag = a.peerVerificationTag + outbound.sourcePort = a.sourcePort + outbound.destinationPort = a.destinationPort + outbound.chunks = []chunk{cerr} + + return []*packet{outbound} + } + + // From RFC 3758 Sec 3.6: + // Note, if the "New Cumulative TSN" value carried in the arrived + // FORWARD TSN chunk is found to be behind or at the current cumulative + // TSN point, the data receiver MUST treat this FORWARD TSN as out-of- + // date and MUST NOT update its Cumulative TSN. The receiver SHOULD + // send a SACK to its peer (the sender of the FORWARD TSN) since such a + // duplicate may indicate the previous SACK was lost in the network. + + a.log.Tracef("[%s] should send ack? newCumTSN=%d peerLastTSN=%d", + a.name, chunkTSN.newCumulativeTSN, a.peerLastTSN()) + if sna32LTE(chunkTSN.newCumulativeTSN, a.peerLastTSN()) { + a.log.Tracef("[%s] sending ack on Forward TSN", a.name) + a.ackState = ackStateImmediate + a.ackTimer.stop() + a.awakeWriteLoop() + + return nil + } + + // From RFC 3758 Sec 3.6: + // the receiver MUST perform the same TSN handling, including duplicate + // detection, gap detection, SACK generation, cumulative TSN + // advancement, etc. as defined in RFC 2960 [2]---with the following + // exceptions and additions. + + // When a FORWARD TSN chunk arrives, the data receiver MUST first update + // its cumulative TSN point to the value carried in the FORWARD TSN + // chunk, + + // Advance peerLastTSN + for sna32LT(a.peerLastTSN(), chunkTSN.newCumulativeTSN) { + a.payloadQueue.pop(true) // may not exist + } + + // Report new peerLastTSN value and abandoned largest SSN value to + // corresponding streams so that the abandoned chunks can be removed + // from the reassemblyQueue. + for _, forwarded := range chunkTSN.streams { + if s, ok := a.streams[forwarded.identifier]; ok { + s.handleForwardTSNForOrdered(forwarded.sequence) + } + } + + // TSN may be forewared for unordered chunks. ForwardTSN chunk does not + // report which stream identifier it skipped for unordered chunks. + // Therefore, we need to broadcast this event to all existing streams for + // unordered chunks. + // See https://github.com/pion/sctp/issues/106 + for _, s := range a.streams { + s.handleForwardTSNForUnordered(chunkTSN.newCumulativeTSN) + } + + return a.handlePeerLastTSNAndAcknowledgement(false) +} + +func (a *Association) sendResetRequest(streamIdentifier uint16) error { + a.lock.Lock() + defer a.lock.Unlock() + + state := a.getState() + if state != established { + return fmt.Errorf("%w: state=%s", ErrResetPacketInStateNotExist, + getAssociationStateString(state)) + } + + // Create DATA chunk which only contains valid stream identifier with + // nil userData and use it as a EOS from the stream. + c := &chunkPayloadData{ + streamIdentifier: streamIdentifier, + beginningFragment: true, + endingFragment: true, + userData: nil, + } + + a.pendingQueue.push(c) + a.awakeWriteLoop() + + return nil +} + +// The caller should hold the lock. +func (a *Association) handleReconfigParam(raw param) (*packet, error) { + switch par := raw.(type) { + case *paramOutgoingResetRequest: + a.log.Tracef("[%s] handleReconfigParam (OutgoingResetRequest)", a.name) + if a.peerLastTSN() < par.senderLastTSN && len(a.reconfigRequests) >= maxReconfigRequests { + // We have too many reconfig requests outstanding. Drop the request and let + // the peer retransmit. A well behaved peer should only have 1 outstanding + // reconfig request. + // + // RFC 6525: https://www.rfc-editor.org/rfc/rfc6525.html#section-5.1.1 + // At any given time, there MUST NOT be more than one request in flight. + // So, if the Re-configuration Timer is running and the RE-CONFIG chunk + // contains at least one request parameter, the chunk MUST be buffered. + // chrome: + // https://chromium.googlesource.com/external/webrtc/+/refs/heads/main/net/dcsctp/socket/stream_reset_handler.cc#271 + return nil, fmt.Errorf("%w: %d", ErrTooManyReconfigRequests, len(a.reconfigRequests)) + } + a.reconfigRequests[par.reconfigRequestSequenceNumber] = par + resp := a.resetStreamsIfAny(par) + if resp != nil { + return resp, nil + } + + return nil, nil //nolint:nilnil + case *paramReconfigResponse: + a.log.Tracef("[%s] handleReconfigParam (ReconfigResponse)", a.name) + if par.result == reconfigResultInProgress { + // RFC 6525: https://www.rfc-editor.org/rfc/rfc6525.html#section-5.2.7 + // + // If the Result field indicates "In progress", the timer for the + // Re-configuration Request Sequence Number is started again. If + // the timer runs out, the RE-CONFIG chunk MUST be retransmitted + // but the corresponding error counters MUST NOT be incremented. + if _, ok := a.reconfigs[par.reconfigResponseSequenceNumber]; ok { + a.tReconfig.stop() + a.tReconfig.start(a.rtoMgr.getRTO()) + } + + return nil, nil //nolint:nilnil + } + delete(a.reconfigs, par.reconfigResponseSequenceNumber) + if len(a.reconfigs) == 0 { + a.tReconfig.stop() + } + + return nil, nil //nolint:nilnil + default: + return nil, fmt.Errorf("%w: %t", ErrParamterType, par) + } +} + +// The caller should hold the lock. +func (a *Association) resetStreamsIfAny(resetRequest *paramOutgoingResetRequest) *packet { + result := reconfigResultSuccessPerformed + if sna32LTE(resetRequest.senderLastTSN, a.peerLastTSN()) { + a.log.Debugf("[%s] resetStream(): senderLastTSN=%d <= peerLastTSN=%d", + a.name, resetRequest.senderLastTSN, a.peerLastTSN()) + for _, id := range resetRequest.streamIdentifiers { + s, ok := a.streams[id] + if !ok { + continue + } + a.lock.Unlock() + s.onInboundStreamReset() + a.lock.Lock() + a.log.Debugf("[%s] deleting stream %d", a.name, id) + delete(a.streams, s.streamIdentifier) + } + delete(a.reconfigRequests, resetRequest.reconfigRequestSequenceNumber) + } else { + a.log.Debugf("[%s] resetStream(): senderLastTSN=%d > peerLastTSN=%d", + a.name, resetRequest.senderLastTSN, a.peerLastTSN()) + result = reconfigResultInProgress + } + + return a.createPacket([]chunk{&chunkReconfig{ + paramA: ¶mReconfigResponse{ + reconfigResponseSequenceNumber: resetRequest.reconfigRequestSequenceNumber, + result: result, + }, + }}) +} + +// Move the chunk peeked with a.pendingQueue.peek() to the inflightQueue. +// The caller should hold the lock. +func (a *Association) movePendingDataChunkToInflightQueue(chunkPayload *chunkPayloadData) { + if err := a.pendingQueue.pop(chunkPayload); err != nil { + a.log.Errorf("[%s] failed to pop from pending queue: %s", a.name, err.Error()) + } + + if chunkPayload.endingFragment { + chunkPayload.setAllInflight() + } + + // Assign TSN and original send time + chunkPayload.tsn = a.generateNextTSN() + chunkPayload.since = time.Now() + chunkPayload.nSent = 1 + + a.checkPartialReliabilityStatus(chunkPayload) + + a.log.Tracef( + "[%s] sending ppi=%d tsn=%d ssn=%d sent=%d len=%d (%v,%v)", + a.name, + chunkPayload.payloadType, + chunkPayload.tsn, + chunkPayload.streamSequenceNumber, + chunkPayload.nSent, + len(chunkPayload.userData), + chunkPayload.beginningFragment, + chunkPayload.endingFragment, + ) + + a.inflightQueue.pushNoCheck(chunkPayload) + + // RACK: track outstanding original transmissions by send time. + a.rackInsert(chunkPayload) +} + +// popPendingDataChunksToSend pops chunks from the pending queues as many as +// the cwnd and rwnd allows to send. +// The caller should hold the lock. +// +//nolint:cyclop +func (a *Association) popPendingDataChunksToSend( //nolint:cyclop,gocognit + budgetScaled *int64, + consumed *bool, +) ([]*chunkPayloadData, []uint16) { + chunks := []*chunkPayloadData{} + var sisToReset []uint16 // stream indentifiers to reset + + // track current packet size for MTU bundling so budgeting is accurate. + bytesInPacket := 0 + + if a.pendingQueue.size() > 0 { //nolint:nestif + // RFC 4960 sec 6.1. Transmission of DATA Chunks + // A) At any given time, the data sender MUST NOT transmit new data to + // any destination transport address if its peer's rwnd indicates + // that the peer has no buffer space (i.e., rwnd is 0; see Section + // 6.2.1). However, regardless of the value of rwnd (including if it + // is 0), the data sender can always have one DATA chunk in flight to + // the receiver if allowed by cwnd (see rule B, below). + + for { + chunkPayload := a.pendingQueue.peek() + if chunkPayload == nil { + break // no more pending data + } + + dataLen := uint32(len(chunkPayload.userData)) //nolint:gosec // G115 + if dataLen == 0 { + sisToReset = append(sisToReset, chunkPayload.streamIdentifier) + err := a.pendingQueue.pop(chunkPayload) + if err != nil { + a.log.Errorf("failed to pop from pending queue: %s", err.Error()) + } + + continue + } + + if uint32(a.inflightQueue.getNumBytes())+dataLen > a.CWND() { //nolint:gosec // G115 + break // would exceeds cwnd + } + + if dataLen > a.RWND() { + break // no more rwnd + } + + // compute current DATA chunk size including padding. + chunkBytes := int(dataChunkHeaderSize) + len(chunkPayload.userData) + chunkBytes += getPadding(chunkBytes) + + // ensure MTU bundling matches bundleDataChunksIntoPackets(). + addBytes := chunkBytes + if bytesInPacket == 0 { + addBytes += int(commonHeaderSize) + if addBytes > int(a.MTU()) { + break + } + + // reserve budget for common header + first chunk. + if !a.tlrAllowSendLocked(budgetScaled, consumed, addBytes) { + break + } + + bytesInPacket = int(commonHeaderSize) + } else { + // if it doesn't fit, start a new packet and retry same chunk. + if bytesInPacket+chunkBytes > int(a.MTU()) { + bytesInPacket = 0 + + continue + } + + // reserve budget for the additional chunk bytes. + if !a.tlrAllowSendLocked(budgetScaled, consumed, chunkBytes) { + break + } + } + + a.setRWND(a.RWND() - dataLen) + + a.movePendingDataChunkToInflightQueue(chunkPayload) + chunks = append(chunks, chunkPayload) + bytesInPacket += chunkBytes + } + + // allow one DATA chunk if nothing is inflight to the receiver. + if len(chunks) == 0 && a.inflightQueue.size() == 0 { + // Send zero window probe + c := a.pendingQueue.peek() + if c != nil && len(c.userData) > 0 { + // probe is a new packet: common header + chunk bytes. + chunkBytes := int(dataChunkHeaderSize) + len(c.userData) + chunkBytes += getPadding(chunkBytes) + addBytes := int(commonHeaderSize) + chunkBytes + + if addBytes <= int(a.MTU()) && a.tlrAllowSendLocked(budgetScaled, consumed, addBytes) { + a.movePendingDataChunkToInflightQueue(c) + chunks = append(chunks, c) + } + } + } + } + + if a.blockWrite && len(chunks) > 0 && a.pendingQueue.size() == 0 { + a.log.Tracef("[%s] all pending data have been sent, notify writable", a.name) + a.notifyBlockWritable() + } + + return chunks, sisToReset +} + +// bundleDataChunksIntoPackets packs DATA chunks into packets. It tries to bundle +// DATA chunks into a packet so long as the resulting packet size does not exceed +// the path MTU. +// The caller should hold the lock. +func (a *Association) bundleDataChunksIntoPackets(chunks []*chunkPayloadData) []*packet { + packets := []*packet{} + chunksToSend := []chunk{} + bytesInPacket := int(commonHeaderSize) + + for _, chunkPayload := range chunks { + // RFC 4960 sec 6.1. Transmission of DATA Chunks + // Multiple DATA chunks committed for transmission MAY be bundled in a + // single packet. Furthermore, DATA chunks being retransmitted MAY be + // bundled with new DATA chunks, as long as the resulting packet size + // does not exceed the path MTU. + chunkSizeInPacket := int(dataChunkHeaderSize) + len(chunkPayload.userData) + chunkSizeInPacket += getPadding(chunkSizeInPacket) + if bytesInPacket+chunkSizeInPacket > int(a.MTU()) { + packets = append(packets, a.createPacket(chunksToSend)) + chunksToSend = []chunk{} + bytesInPacket = int(commonHeaderSize) + } + chunksToSend = append(chunksToSend, chunkPayload) + bytesInPacket += chunkSizeInPacket + } + + if len(chunksToSend) > 0 { + packets = append(packets, a.createPacket(chunksToSend)) + } + + return packets +} + +// sendPayloadData sends the data chunks. +func (a *Association) sendPayloadData(ctx context.Context, chunks []*chunkPayloadData) error { + a.lock.Lock() + + state := a.getState() + if state != established { + a.lock.Unlock() + + return fmt.Errorf("%w: state=%s", ErrPayloadDataStateNotExist, + getAssociationStateString(state)) + } + + if a.blockWrite { + for a.writePending { + a.lock.Unlock() + select { + case <-ctx.Done(): + return ctx.Err() + case <-a.writeNotify: + a.lock.Lock() + } + } + a.writePending = true + } + + // Push the chunks into the pending queue first. + for _, c := range chunks { + a.pendingQueue.push(c) + } + + a.lock.Unlock() + a.awakeWriteLoop() + + return nil +} + +// The caller should hold the lock. +func (a *Association) checkPartialReliabilityStatus(chunkPayload *chunkPayloadData) { + if !a.useForwardTSN { + return + } + + // draft-ietf-rtcweb-data-protocol-09.txt section 6 + // 6. Procedures + // All Data Channel Establishment Protocol messages MUST be sent using + // ordered delivery and reliable transmission. + // + if chunkPayload.payloadType == PayloadTypeWebRTCDCEP { + return + } + + // PR-SCTP + if stream, ok := a.streams[chunkPayload.streamIdentifier]; ok { //nolint:nestif + stream.lock.RLock() + if stream.reliabilityType == ReliabilityTypeRexmit { + if chunkPayload.nSent >= stream.reliabilityValue { + chunkPayload.setAbandoned(true) + a.rackRemove(chunkPayload) + a.log.Tracef( + "[%s] marked as abandoned: tsn=%d ppi=%d (remix: %d)", + a.name, chunkPayload.tsn, chunkPayload.payloadType, chunkPayload.nSent, + ) + } + } else if stream.reliabilityType == ReliabilityTypeTimed { + elapsed := int64(time.Since(chunkPayload.since).Seconds() * 1000) + if elapsed >= int64(stream.reliabilityValue) { + chunkPayload.setAbandoned(true) + a.rackRemove(chunkPayload) + a.log.Tracef( + "[%s] marked as abandoned: tsn=%d ppi=%d (timed: %d)", + a.name, chunkPayload.tsn, chunkPayload.payloadType, elapsed, + ) + } + } + stream.lock.RUnlock() + } else { + // Remote has reset its send side of the stream, we can still send data. + a.log.Tracef("[%s] stream %d not found, remote reset", a.name, chunkPayload.streamIdentifier) + } +} + +// getDataPacketsToRetransmit is called when T3-rtx is timed out and retransmit outstanding data chunks +// that are not acked or abandoned yet. +// The caller should hold the lock. +func (a *Association) getDataPacketsToRetransmit(budgetScaled *int64, consumed *bool) []*packet { //nolint:cyclop + awnd := min32(a.CWND(), a.RWND()) + chunks := []*chunkPayloadData{} + var bytesToSend int + currRtxTimestamp := time.Now() + + bytesInPacket := 0 + + for i := 0; ; i++ { + chunkPayload, ok := a.inflightQueue.get(a.cumulativeTSNAckPoint + uint32(i) + 1) //nolint:gosec // G115 + if !ok { + break // end of pending data + } + + if !chunkPayload.retransmit { + continue + } + + if i == 0 && int(a.RWND()) < len(chunkPayload.userData) { + // allow as zero window probe + } else if bytesToSend+len(chunkPayload.userData) > int(awnd) { + break + } + + chunkBytes := int(dataChunkHeaderSize) + len(chunkPayload.userData) + chunkBytes += getPadding(chunkBytes) + + // retry as first chunk in a new packet if needed. + for { + addBytes := chunkBytes + if bytesInPacket == 0 { + addBytes += int(commonHeaderSize) + if addBytes > int(a.MTU()) { + return a.bundleDataChunksIntoPackets(chunks) + } + } else if bytesInPacket+chunkBytes > int(a.MTU()) { + bytesInPacket = 0 + + continue + } + + // burst budget gate before mutating the chunk. + if !a.tlrAllowSendLocked(budgetScaled, consumed, addBytes) { + return a.bundleDataChunksIntoPackets(chunks) + } + + if bytesInPacket == 0 { + bytesInPacket = int(commonHeaderSize) + } + bytesInPacket += chunkBytes + + break + } + + chunkPayload.retransmit = false + bytesToSend += len(chunkPayload.userData) + + // Update for retransmission + chunkPayload.nSent++ + chunkPayload.since = currRtxTimestamp + a.rackRemove(chunkPayload) + a.rackInsert(chunkPayload) + + a.checkPartialReliabilityStatus(chunkPayload) + + a.log.Tracef( + "[%s] retransmitting tsn=%d ssn=%d sent=%d", + a.name, chunkPayload.tsn, chunkPayload.streamSequenceNumber, chunkPayload.nSent, + ) + + chunks = append(chunks, chunkPayload) + } + + return a.bundleDataChunksIntoPackets(chunks) +} + +// generateNextTSN returns the myNextTSN and increases it. The caller should hold the lock. +// The caller should hold the lock. +func (a *Association) generateNextTSN() uint32 { + tsn := a.myNextTSN + a.myNextTSN++ + + return tsn +} + +// generateNextRSN returns the myNextRSN and increases it. The caller should hold the lock. +// The caller should hold the lock. +func (a *Association) generateNextRSN() uint32 { + rsn := a.myNextRSN + a.myNextRSN++ + + return rsn +} + +func (a *Association) createSelectiveAckChunk() *chunkSelectiveAck { + sack := &chunkSelectiveAck{} + sack.cumulativeTSNAck = a.peerLastTSN() + sack.advertisedReceiverWindowCredit = a.getMyReceiverWindowCredit() + sack.duplicateTSN = a.payloadQueue.popDuplicates() + sack.gapAckBlocks = a.payloadQueue.getGapAckBlocks() + + return sack +} + +func pack(p *packet) []*packet { + return []*packet{p} +} + +func (a *Association) handleChunksStart() { + a.lock.Lock() + defer a.lock.Unlock() + + a.stats.incPacketsReceived() + + a.delayedAckTriggered = false + a.immediateAckTriggered = false +} + +func (a *Association) handleChunksEnd() { + a.lock.Lock() + defer a.lock.Unlock() + + if a.immediateAckTriggered { + a.ackState = ackStateImmediate + a.ackTimer.stop() + a.awakeWriteLoop() + } else if a.delayedAckTriggered { + // Will send delayed ack in the next ack timeout + a.ackState = ackStateDelay + a.ackTimer.start() + } +} + +func (a *Association) handleChunk(receivedPacket *packet, receivedChunk chunk) error { //nolint:cyclop + a.lock.Lock() + defer a.lock.Unlock() + + var packets []*packet + var err error + + if _, err = receivedChunk.check(); err != nil { + a.log.Errorf("[%s] failed validating chunk: %s ", a.name, err) + + return nil + } + + isAbort := false + + switch receivedChunk := receivedChunk.(type) { + // Note: We do not do the following for chunkInit, chunkInitAck, and chunkCookieEcho: + // If an endpoint receives an INIT, INIT ACK, or COOKIE ECHO chunk but decides not to establish the + // new association due to missing mandatory parameters in the received INIT or INIT ACK chunk, invalid + // parameter values, or lack of local resources, it SHOULD respond with an ABORT chunk. + + case *chunkInit: + packets, err = a.handleInit(receivedPacket, receivedChunk) + + case *chunkInitAck: + err = a.handleInitAck(receivedPacket, receivedChunk) + + case *chunkAbort: + isAbort = true + err = a.handleAbort(receivedChunk) + + case *chunkError: + var errStr strings.Builder + for _, e := range receivedChunk.errorCauses { + fmt.Fprintf(&errStr, "(%s)", e) + } + a.log.Debugf("[%s] Error chunk, with following errors: %s", a.name, errStr.String()) + + case *chunkHeartbeat: + packets = a.handleHeartbeat(receivedChunk) + + case *chunkHeartbeatAck: + a.handleHeartbeatAck(receivedChunk) + + case *chunkCookieEcho: + packets = a.handleCookieEcho(receivedChunk) + + case *chunkCookieAck: + a.handleCookieAck() + + case *chunkPayloadData: + packets = a.handleData(receivedChunk) + + case *chunkSelectiveAck: + err = a.handleSack(receivedChunk) + + case *chunkReconfig: + packets, err = a.handleReconfig(receivedChunk) + + case *chunkForwardTSN: + packets = a.handleForwardTSN(receivedChunk) + + case *chunkShutdown: + a.handleShutdown(receivedChunk) + case *chunkShutdownAck: + a.handleShutdownAck(receivedChunk) + case *chunkShutdownComplete: + err = a.handleShutdownComplete(receivedChunk) + + default: + err = ErrChunkTypeUnhandled + } + + // Log and return, the only condition that is fatal is a ABORT chunk + if err != nil { + if isAbort { + return err + } + + a.log.Errorf("Failed to handle chunk: %v", err) + + return nil + } + + if len(packets) > 0 { + a.controlQueue.pushAll(packets) + a.awakeWriteLoop() + } + + return nil +} + +func (a *Association) onRetransmissionTimeout(id int, nRtos uint) { //nolint:cyclop + a.lock.Lock() + defer a.lock.Unlock() + + if id == timerT1Init { + err := a.sendInit() + if err != nil { + a.log.Debugf("[%s] failed to retransmit init (nRtos=%d): %v", a.name, nRtos, err) + } + + return + } + + if id == timerT1Cookie { + err := a.sendCookieEcho() + if err != nil { + a.log.Debugf("[%s] failed to retransmit cookie-echo (nRtos=%d): %v", a.name, nRtos, err) + } + + return + } + + if id == timerT2Shutdown { + a.log.Debugf("[%s] retransmission of shutdown timeout (nRtos=%d): %v", a.name, nRtos) + state := a.getState() + + switch state { + case shutdownSent: + a.willSendShutdown = true + a.awakeWriteLoop() + case shutdownAckSent: + a.willSendShutdownAck = true + a.awakeWriteLoop() + } + } + + if id == timerT3RTX { //nolint:nestif + a.stats.incT3Timeouts() + + // RFC 4960 sec 6.3.3 + // E1) For the destination address for which the timer expires, adjust + // its ssthresh with rules defined in Section 7.2.3 and set the + // cwnd <- MTU. + // RFC 4960 sec 7.2.3 + // When the T3-rtx timer expires on an address, SCTP should perform slow + // start by: + // ssthresh = max(cwnd/2, 4*MTU) + // cwnd = 1*MTU + + a.ssthresh = max32(a.CWND()/2, 4*a.MTU()) + a.setCWND(a.MTU()) + a.log.Tracef("[%s] updated cwnd=%d ssthresh=%d inflight=%d (RTO)", + a.name, a.CWND(), a.ssthresh, a.inflightQueue.getNumBytes()) + // If not in Fast Recovery, enter Fast Recovery and mark the highest outstanding TSN as the Fast Recovery exit point. + // When a SACK acknowledges all TSNs up to and including this exit point, Fast Recovery is exited. + // https://www.rfc-editor.org/rfc/rfc4960#section-7.2.4 + // https://www.rfc-editor.org/rfc/rfc9260.html#section-7.2.4 + if a.inFastRecovery { + a.inFastRecovery = false + a.willRetransmitFast = false + a.fastRecoverExitPoint = 0 + a.partialBytesAcked = 0 + a.log.Debugf("[%s] exit fast-recovery (RTO)", a.name) + } + + // RFC 3758 sec 3.5 + // A5) Any time the T3-rtx timer expires, on any destination, the sender + // SHOULD try to advance the "Advanced.Peer.Ack.Point" by following + // the procedures outlined in C2 - C5. + if a.useForwardTSN { + // RFC 3758 Sec 3.5 C2 + for i := a.advancedPeerTSNAckPoint + 1; ; i++ { + c, ok := a.inflightQueue.get(i) + if !ok { + break + } + if !c.abandoned() { + break + } + a.advancedPeerTSNAckPoint = i + } + + // RFC 3758 Sec 3.5 C3 + if sna32GT(a.advancedPeerTSNAckPoint, a.cumulativeTSNAckPoint) { + a.willSendForwardTSN = true + } + } + + a.log.Debugf("[%s] T3-rtx timed out: nRtos=%d cwnd=%d ssthresh=%d", a.name, nRtos, a.CWND(), a.ssthresh) + + /* + a.log.Debugf(" - advancedPeerTSNAckPoint=%d", a.advancedPeerTSNAckPoint) + a.log.Debugf(" - cumulativeTSNAckPoint=%d", a.cumulativeTSNAckPoint) + a.inflightQueue.updateSortedKeys() + for i, tsn := range a.inflightQueue.sorted { + if c, ok := a.inflightQueue.get(tsn); ok { + a.log.Debugf(" - [%d] tsn=%d acked=%v abandoned=%v (%v,%v) len=%d", + i, c.tsn, c.acked, c.abandoned(), c.beginningFragment, c.endingFragment, len(c.userData)) + } + } + */ + + a.inflightQueue.markAllToRetrasmit() + a.awakeWriteLoop() + + return + } + + if id == timerReconfig { + a.willRetransmitReconfig = true + a.awakeWriteLoop() + } +} + +func (a *Association) onRetransmissionFailure(id int) { + a.lock.Lock() + defer a.lock.Unlock() + + if id == timerT1Init { + a.log.Errorf("[%s] retransmission failure: T1-init", a.name) + a.completeHandshake(ErrHandshakeInitAck) + + return + } + + if id == timerT1Cookie { + a.log.Errorf("[%s] retransmission failure: T1-cookie", a.name) + a.completeHandshake(ErrHandshakeCookieEcho) + + return + } + + if id == timerT2Shutdown { + a.log.Errorf("[%s] retransmission failure: T2-shutdown", a.name) + + return + } + + if id == timerT3RTX { + // T3-rtx timer will not fail by design + // Justifications: + // * ICE would fail if the connectivity is lost + // * WebRTC spec is not clear how this incident should be reported to ULP + a.log.Errorf("[%s] retransmission failure: T3-rtx (DATA)", a.name) + + return + } +} + +func (a *Association) onAckTimeout() { + a.lock.Lock() + defer a.lock.Unlock() + + a.log.Tracef("[%s] ack timed out (ackState: %d)", a.name, a.ackState) + a.stats.incAckTimeouts() + + a.ackState = ackStateImmediate + a.awakeWriteLoop() +} + +// BufferedAmount returns total amount (in bytes) of currently buffered user data. +func (a *Association) BufferedAmount() int { + a.lock.RLock() + defer a.lock.RUnlock() + + return a.pendingQueue.getNumBytes() + a.inflightQueue.getNumBytes() +} + +// MaxMessageSize returns the maximum message size you can send. +func (a *Association) MaxMessageSize() uint32 { + return atomic.LoadUint32(&a.maxMessageSize) +} + +// SetMaxMessageSize sets the maximum message size you can send. +func (a *Association) SetMaxMessageSize(maxMsgSize uint32) { + atomic.StoreUint32(&a.maxMessageSize, maxMsgSize) +} + +// completeHandshake sends the given error to handshakeCompletedCh unless the read/write +// side of the association closes before that can happen. It returns whether it was able +// to send on the channel or not. +func (a *Association) completeHandshake(handshakeErr error) bool { + select { + // Note: This is a future place where the user could be notified (COMMUNICATION UP) + case a.handshakeCompletedCh <- handshakeErr: + return true + case <-a.closeWriteLoopCh: // check the read/write sides for closure + case <-a.readLoopCloseCh: + } + + return false +} + +func (a *Association) pokeTimerLoop() { + // enqueue a single wake-up without blocking. + select { + case a.timerUpdateCh <- struct{}{}: + default: + } +} + +func (a *Association) startRackTimer(dur time.Duration) { + a.timerMu.Lock() + + if dur <= 0 { + a.rackDeadline = time.Time{} + } else { + a.rackDeadline = time.Now().Add(dur) + } + + a.timerMu.Unlock() + + a.pokeTimerLoop() +} + +func (a *Association) stopRackTimer() { + a.timerMu.Lock() + a.rackDeadline = time.Time{} + a.timerMu.Unlock() + + a.pokeTimerLoop() +} + +func (a *Association) startPTOTimer(dur time.Duration) { + a.timerMu.Lock() + + if dur <= 0 { + a.ptoDeadline = time.Time{} + } else { + a.ptoDeadline = time.Now().Add(dur) + } + + a.timerMu.Unlock() + + a.pokeTimerLoop() +} + +func (a *Association) stopPTOTimer() { + a.timerMu.Lock() + a.ptoDeadline = time.Time{} + a.timerMu.Unlock() + + a.pokeTimerLoop() +} + +// drainTimer safely stops a timer and drains its channel if needed. +func drainTimer(t *time.Timer) { + if !t.Stop() { + select { + case <-t.C: + default: + } + } +} + +// timerLoop runs one goroutine per association for RACK and PTO deadlines. +// this only runs if RACK is enabled. +func (a *Association) timerLoop() { //nolint:gocognit,cyclop + // begin with a disarmed timer. + timer := time.NewTimer(time.Hour) + drainTimer(timer) + armed := false + + for { + // compute the earliest non-zero deadline. + a.timerMu.Lock() + rackDeadline := a.rackDeadline + ptoDeadline := a.ptoDeadline + a.timerMu.Unlock() + + var next time.Time + switch { + case rackDeadline.IsZero(): + next = ptoDeadline + case ptoDeadline.IsZero(): + next = rackDeadline + default: + if rackDeadline.Before(ptoDeadline) { + next = rackDeadline + } else { + next = ptoDeadline + } + } + + if next.IsZero() { + if armed { + drainTimer(timer) + armed = false + } + } else { + d := time.Until(next) + + if d <= 0 { + d = time.Nanosecond + } + + if armed { + drainTimer(timer) + } + + timer.Reset(d) + armed = true + } + + select { + case <-a.closeWriteLoopCh: + if armed { + drainTimer(timer) + } + + return + + case <-a.timerUpdateCh: + // re-compute deadlines and (re)arm in next loop iteration. + + case <-timer.C: + armed = false + + // snapshot & clear due deadlines before firing to avoid races with re-arms. + currTime := time.Now() + var fireRack, firePTO bool + + a.timerMu.Lock() + + if !a.rackDeadline.IsZero() && !currTime.Before(a.rackDeadline) { + fireRack = true + a.rackDeadline = time.Time{} + } + + if !a.ptoDeadline.IsZero() && !currTime.Before(a.ptoDeadline) { + firePTO = true + a.ptoDeadline = time.Time{} + } + + a.timerMu.Unlock() + + // fire callbacks without holding timerMu. + if fireRack { + a.onRackTimeout() + } + + if firePTO { + a.onPTOTimer() + } + } + } +} + +// onRackAfterSACK implements the RACK logic (RACK for SCTP section 2A/B, section 3) and TLP scheduling (section 2C). +func (a *Association) onRackAfterSACK( // nolint:gocognit,cyclop,gocyclo + deliveredFound bool, + newestDeliveredSendTime time.Time, + newestDeliveredOrigTSN uint32, + sack *chunkSelectiveAck, +) { + // store the current time for when we check if it's needed in step 2 (whether we should maintain ReoWND) + currTime := time.Now() + + // 1) Update highest delivered original TSN for reordering detection (section 2B) + if deliveredFound { + if sna32LT(a.rackHighestDeliveredOrigTSN, newestDeliveredOrigTSN) { + a.rackHighestDeliveredOrigTSN = newestDeliveredOrigTSN + } else { + // ACK of an original TSN below the high-watermark -> reordering observed + a.rackReorderingSeen = true + } + if newestDeliveredSendTime.After(a.rackDeliveredTime) { + a.rackDeliveredTime = newestDeliveredSendTime + } + } + + // 2) Maintain ReoWND (RACK for SCTP section 2B) + if minRTT := a.rack.rackMinRTTWnd.Min(currTime); minRTT > 0 { + a.rackMinRTT = minRTT + } + + var base time.Duration + if a.rackMinRTT > 0 { + base = max(a.rackMinRTT/4, a.rack.rackReoWndFloor) + } + + // Suppress during recovery if no reordering ever seen; else (re)initialize from base if zero. + if !a.rackReorderingSeen && (a.inFastRecovery || a.t3RTX.isRunning()) { + a.rackReoWnd = 0 + } else if a.rackReoWnd == 0 && base > 0 { + a.rackReoWnd = base + } + + // DSACK-style inflation using SCTP duplicate TSNs (RACK for SCTP section 3 noting SCTP + // natively reports duplicates + RACK for SCTP section 2B policy) + if len(sack.duplicateTSN) > 0 && a.rackMinRTT > 0 { + a.rackReoWnd += max(a.rackMinRTT/4, a.rack.rackReoWndFloor) + // keep inflated for 16 loss recoveries before reset + a.rackKeepInflatedRecoveries = 16 + a.log.Tracef("[%s] RACK: DSACK/dupTSN seen, inflate reoWnd to %v", a.name, a.rackReoWnd) + } + + // decrement the keep inflated counter when we leave recovery + if !a.inFastRecovery && a.rackKeepInflatedRecoveries > 0 { + a.rackKeepInflatedRecoveries-- + if a.rackKeepInflatedRecoveries == 0 && a.rackMinRTT > 0 { + a.rackReoWnd = a.rackMinRTT / 4 + } + } + + // RFC 8985: the reordering window MUST be bounded by SRTT. + if srttMs := a.SRTT(); srttMs > 0 { + if srttDur := time.Duration(srttMs * 1e6); a.rackReoWnd > srttDur { + a.rackReoWnd = srttDur + } + } + + // 3) Loss marking on ACK: any outstanding chunk whose (send_time + reoWnd) < newestDeliveredSendTime + // is lost (RACK for SCTP section 2A) + if !a.rackDeliveredTime.IsZero() { //nolint:nestif + marked := false + + for chunk := a.rackHead; chunk != nil; { + next := chunk.rackNext // save in case we remove c + + // but clean up if they exist. + if chunk.acked || chunk.abandoned() { + a.rackRemove(chunk) + chunk = next + + continue + } + + if chunk.retransmit || chunk.nSent > 1 { + // Either already scheduled for retransmit or not an original send: + // skip but keep in list in case it's still outstanding. + chunk = next + + continue + } + + // Ordered by original send time. If this one is too new, + // all later ones are even newer -> short-circuit. + if !chunk.since.Add(a.rackReoWnd).Before(a.rackDeliveredTime) { + break + } + + // Mark as lost by RACK. + chunk.retransmit = true + marked = true + + // Remove from xmit-time list: we no longer need RACK for this TSN. + a.rackRemove(chunk) + + a.log.Tracef("[%s] RACK: mark lost tsn=%d (sent=%v, delivered=%v, reoWnd=%v)", + a.name, chunk.tsn, chunk.since, a.rackDeliveredTime, a.rackReoWnd) + + chunk = next + } + + if marked { + // loss detected during active TLR so we must reduce burst + if a.tlrActive { + a.tlrApplyAdditionalLossLocked(currTime) + } + + a.awakeWriteLoop() + } + } + + // 4) Arm the RACK timer if there are still outstanding but not-yet-overdue chunks (RACK for SCTP section 2A) + if a.rackHead != nil && !a.rackDeliveredTime.IsZero() { + // RackRTT = RTT of the most recently delivered packet + rackRTT := max(time.Since(a.rackDeliveredTime), time.Duration(0)) + a.startRackTimer(rackRTT + a.rackReoWnd) // RACK for SCTP section 2A + } else { + a.stopRackTimer() + } + + // 5) Re/schedule Tail Loss Probe (PTO) (RACK for SCTP section 2C) + // Triggered when new data is sent or cum-ack advances; we approximate by scheduling on every SACK that advanced + if a.inflightQueue.size() == 0 { + a.stopPTOTimer() + + return + } + + var pto time.Duration + srttMs := a.SRTT() + if srttMs > 0 { + srtt := time.Duration(srttMs * 1e6) + extra := 2 * time.Millisecond + + if a.inflightQueue.size() == 1 { + extra = a.rack.rackWCDelAck // 200ms for single outstanding, else 2ms + } + + pto = 2*srtt + extra + } else { + pto = time.Second // no RTT yet + } + + a.startPTOTimer(pto) +} + +// schedulePTOAfterSendLocked starts/restarts the PTO timer when new data is transmitted. +// Caller must hold a.lock. +func (a *Association) schedulePTOAfterSendLocked() { + if a.inflightQueue.size() == 0 { + a.stopPTOTimer() + + return + } + + var pto time.Duration + if srttMs := a.SRTT(); srttMs > 0 { + srtt := time.Duration(srttMs * 1e6) + extra := 2 * time.Millisecond + + if a.inflightQueue.size() == 1 { + extra = a.rack.rackWCDelAck + } + + pto = 2*srtt + extra + } else { + pto = time.Second + } + + a.startPTOTimer(pto) +} + +// onRackTimeout is fired to avoid waiting for the next ACK. +func (a *Association) onRackTimeout() { + a.lock.Lock() + defer a.lock.Unlock() + a.onRackTimeoutLocked() +} + +func (a *Association) onRackTimeoutLocked() { //nolint:cyclop + if a.rackDeliveredTime.IsZero() { + return + } + + marked := false + + for chunk := a.rackHead; chunk != nil; { + next := chunk.rackNext + + if chunk.acked || chunk.abandoned() { + a.rackRemove(chunk) + chunk = next + + continue + } + if chunk.retransmit || chunk.nSent > 1 { + chunk = next + + continue + } + + if !chunk.since.Add(a.rackReoWnd).Before(a.rackDeliveredTime) { + // too new, later ones are newer so we can skip. + break + } + + chunk.retransmit = true + marked = true + a.rackRemove(chunk) + + a.log.Tracef("[%s] RACK timer: mark lost tsn=%d", a.name, chunk.tsn) + + chunk = next + } + + if marked { + // loss detected during active TLR so we must reduce burst + if a.tlrActive { + a.tlrApplyAdditionalLossLocked(time.Now()) + } + + a.awakeWriteLoop() + } +} + +func (a *Association) onPTOTimer() { + a.lock.Lock() + defer a.lock.Unlock() + a.onPTOTimerLocked() +} + +func (a *Association) onPTOTimerLocked() { + // if nothing is inflight, PTO should not drive TLR. + // use PTO as a chance to probe RTT via HEARTBEAT instead of retransmitting DATA. + if a.inflightQueue.size() == 0 { + a.stopPTOTimer() + a.log.Tracef("[%s] PTO idle: sending active HEARTBEAT for RTT probe", a.name) + a.sendActiveHeartbeatLocked() + + return + } + + currTime := time.Now() + + if !a.tlrActive { + a.tlrBeginLocked() + } else { + a.tlrApplyAdditionalLossLocked(currTime) + } + + // If we have unsent data, PTO should just wake the writer. + if a.pendingQueue.size() > 0 { + a.awakeWriteLoop() + + return + } + + // otherwise retransmit most recently sent in-flight DATA. + var latest *chunkPayloadData + for i := uint32(0); ; i++ { + c, ok := a.inflightQueue.get(a.cumulativeTSNAckPoint + i + 1) + if !ok { + break + } + + if c.acked || c.abandoned() { + continue + } + + latest = c + } + + if latest != nil && !latest.retransmit { + latest.retransmit = true + a.log.Tracef("[%s] PTO fired: probe tsn=%d", a.name, latest.tsn) + a.awakeWriteLoop() + } +} + +func (a *Association) rackInsert(c *chunkPayloadData) { + if c == nil || c.rackInList { + return + } + + if a.rackTail != nil { + a.rackTail.rackNext = c + c.rackPrev = a.rackTail + } else { + a.rackHead = c + } + a.rackTail = c + c.rackInList = true +} + +func (a *Association) rackRemove(chunk *chunkPayloadData) { + if chunk == nil || !chunk.rackInList { + return + } + + if prev := chunk.rackPrev; prev != nil { + prev.rackNext = chunk.rackNext + } else { + a.rackHead = chunk.rackNext + } + + if next := chunk.rackNext; next != nil { + next.rackPrev = chunk.rackPrev + } else { + a.rackTail = chunk.rackPrev + } + + chunk.rackPrev = nil + chunk.rackNext = nil + chunk.rackInList = false +} + +// caller must hold a.lock. +func (a *Association) tlrFirstRTTDurationLocked() time.Duration { + // Use SRTT when available; fall back to a safe default. + if srttMs := a.SRTT(); srttMs > 0 { + return time.Duration(srttMs * 1e6) + } + + return time.Second +} + +// caller must hold a.lock. +func (a *Association) tlrUpdatePhaseLocked(currTime time.Time) { + if !a.tlrActive || !a.tlrFirstRTT { + return + } + if a.tlrStartTime.IsZero() { + return + } + + if currTime.Sub(a.tlrStartTime) >= a.tlrFirstRTTDurationLocked() { + a.tlrFirstRTT = false + } +} + +// caller must hold a.lock. +func (a *Association) tlrCurrentBurstUnitsLocked() int64 { + if !a.tlrActive { + return 0 + } + + a.tlrUpdatePhaseLocked(time.Now()) + + if a.tlrFirstRTT { + return a.tlrBurstFirstRTTUnits + } + + return a.tlrBurstLaterRTTUnits +} + +// caller must hold a.lock. +// Returns remaining burst budget in "scaled bytes": bytes * 4 (quarter-MTU precision). +func (a *Association) tlrCurrentBurstBudgetScaledLocked() int64 { + if !a.tlrActive { + return 0 + } + + units := a.tlrCurrentBurstUnitsLocked() + + return units * int64(a.MTU()) +} + +// caller must hold a.lock. +func (a *Association) tlrHighestOutstandingTSNLocked() (uint32, bool) { + var last uint32 + found := false + + for i := uint32(0); ; i++ { + tsn := a.cumulativeTSNAckPoint + i + 1 + _, ok := a.inflightQueue.get(tsn) + if !ok { + break + } + last = tsn + found = true + } + + return last, found +} + +// caller must hold a.lock. +func (a *Association) tlrBeginLocked() { + currTime := time.Now() + + a.tlrActive = true + a.tlrFirstRTT = true + a.tlrHadAdditionalLoss = false + a.tlrStartTime = currTime + + if endTSN, ok := a.tlrHighestOutstandingTSNLocked(); ok { + a.tlrEndTSN = endTSN + } else { + a.tlrEndTSN = a.cumulativeTSNAckPoint + } +} + +// caller must hold a.lock. +func (a *Association) tlrApplyAdditionalLossLocked(currTime time.Time) { + if !a.tlrActive { + return + } + + // Decide whether we're still within the first recovery RTT window. + a.tlrUpdatePhaseLocked(currTime) + + a.tlrHadAdditionalLoss = true + a.tlrGoodOps = 0 + + if a.tlrFirstRTT { + // Loss during first recovery RTT => initial burst too high. + a.tlrBurstFirstRTTUnits -= tlrBurstStepDownFirstRTT + if a.tlrBurstFirstRTTUnits < tlrBurstMinFirstRTT { + a.tlrBurstFirstRTTUnits = tlrBurstMinFirstRTT + } + } else { + // Loss during later RTTs => increasing rate too high. + a.tlrBurstLaterRTTUnits -= tlrBurstStepDownLaterRTT + if a.tlrBurstLaterRTTUnits < tlrBurstMinLaterRTT { + a.tlrBurstLaterRTTUnits = tlrBurstMinLaterRTT + } + } +} + +// caller must hold a.lock. +func (a *Association) tlrMaybeFinishLocked(ackProgress bool) { + if !a.tlrActive { + return + } + + // determine if we should move from the first RTT burst to later RTT burst. + if a.tlrFirstRTT && ackProgress { + a.tlrFirstRTT = false + } + + // finish once cumulatively ACKed through the tail we were recovering. + if sna32GTE(a.cumulativeTSNAckPoint, a.tlrEndTSN) { + if !a.tlrHadAdditionalLoss { + a.tlrGoodOps++ + if a.tlrGoodOps >= tlrGoodOpsResetThreshold { + a.tlrBurstFirstRTTUnits = tlrBurstDefaultFirstRTT + a.tlrBurstLaterRTTUnits = tlrBurstDefaultLaterRTT + a.tlrGoodOps = 0 + } + } else { + a.tlrGoodOps = 0 + } + + a.tlrActive = false + a.tlrFirstRTT = false + a.tlrHadAdditionalLoss = false + a.tlrEndTSN = 0 + } +} + +// caller must hold a.lock. +// "budgetScaled" is remaining burst budget in (bytes*4) scale. +// "consumed" allows the first send in a burst. +func (a *Association) tlrAllowSendLocked(budgetScaled *int64, consumed *bool, estBytes int) bool { + if !a.tlrActive || budgetScaled == nil || consumed == nil { + return true + } + if estBytes <= 0 { + return true + } + + needScaled := int64(estBytes) * tlrUnitsPerMTU // bytes*4 + if *consumed && *budgetScaled < needScaled { + return false + } + + *budgetScaled -= needScaled + if *budgetScaled < 0 { + *budgetScaled = 0 + } + *consumed = true + + return true +} + +// ActiveHeartbeat sends a HEARTBEAT chunk on the association to perform an +// on-demand RTT measurement without application payload. +// +// It is safe to call from outside; it will take the association lock and +// be a no-op if the association is not established. +func (a *Association) ActiveHeartbeat() { + a.lock.Lock() + defer a.lock.Unlock() + + if a.getState() != established { + return + } + + a.sendActiveHeartbeatLocked() +} + +// caller must hold a.lock. +func (a *Association) sendActiveHeartbeatLocked() { + now := time.Now().UnixNano() + buf := make([]byte, 8) + binary.BigEndian.PutUint64(buf, uint64(now)) //nolint:gosec // time.now() will never be negative + + info := ¶mHeartbeatInfo{heartbeatInformation: buf} + + hb := &chunkHeartbeat{ + chunkHeader: chunkHeader{ + typ: ctHeartbeat, + flags: 0, + }, + params: []param{info}, + } + + a.controlQueue.push(&packet{ + verificationTag: a.peerVerificationTag, + sourcePort: a.sourcePort, + destinationPort: a.destinationPort, + chunks: []chunk{hb}, + }) + a.awakeWriteLoop() +} + +// GenerateOutOfBandToken generates an out-of-band connection token (i.e. a +// serialized SCTP INIT chunk) for use with SNAP. +func GenerateOutOfBandToken(opts ...ClientOption) ([]byte, error) { + config := &Config{} + config.applyDefaults() + + for _, opt := range opts { + if opt == nil { + continue + } + if err := opt.applyClient(config); err != nil { + return nil, err + } + } + + config.applyDefaults() + + init := &chunkInit{} + init.initialTSN = globalMathRandomGenerator.Uint32() + init.numOutboundStreams = math.MaxUint16 + init.numInboundStreams = math.MaxUint16 + init.initiateTag = generateInitiateTag() + init.advertisedReceiverWindowCredit = config.MaxReceiveBufferSize + setSupportedExtensions(&init.chunkInitCommon) + + if config.EnableZeroChecksum { + init.params = append(init.params, ¶mZeroChecksumAcceptable{edmid: dtlsErrorDetectionMethod}) + } + _, err := init.check() + if err != nil { + return nil, err + } + + return init.marshal() +} diff --git a/vendor/github.com/pion/sctp/association_options.go b/vendor/github.com/pion/sctp/association_options.go new file mode 100644 index 0000000..4e91fb1 --- /dev/null +++ b/vendor/github.com/pion/sctp/association_options.go @@ -0,0 +1,180 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sctp + +import ( + "net" + + "github.com/pion/logging" +) + +// ServerOption configures a Server. +type ServerOption interface { + applyServer(*Config) error +} + +// ClientOption configures a Client. +type ClientOption interface { + applyClient(*Config) error +} + +// AssociationOption applies to both client and server. +type AssociationOption interface { + ServerOption + ClientOption +} + +// sharedOption wraps an apply function that works for both client and server. +type sharedOption func(*Config) error + +func (o sharedOption) applyServer(c *Config) error { return o(c) } +func (o sharedOption) applyClient(c *Config) error { return o(c) } + +// WithLoggerFactory sets the logger factory for the association. +func WithLoggerFactory(loggerFactory logging.LoggerFactory) AssociationOption { + return sharedOption(func(c *Config) error { + if loggerFactory == nil { + return errNilLoggerFactory // add if you don't already have it + } + c.LoggerFactory = loggerFactory + + return nil + }) +} + +// WithName sets the name of the association. +func WithName(name string) AssociationOption { + return sharedOption(func(c *Config) error { + c.Name = name + + return nil + }) +} + +// WithNetConn sets the net.Conn used by the association. +func WithNetConn(conn net.Conn) AssociationOption { + return sharedOption(func(c *Config) error { + if conn == nil { + return errNilNetConn + } + c.NetConn = conn + + return nil + }) +} + +// WithBlockWrite sets whether the association should use blocking writes. +// By default this is false. +func WithBlockWrite(b bool) AssociationOption { + return sharedOption(func(c *Config) error { + c.BlockWrite = b + + return nil + }) +} + +// WithEnableZeroChecksum sets whether the association should accept zero as a valid checksum. +// By default this is false. +func WithEnableZeroChecksum(b bool) AssociationOption { + return sharedOption(func(c *Config) error { + c.EnableZeroChecksum = b + + return nil + }) +} + +// WithMTU sets the MTU size for the association. +// By default this is 1228. +func WithMTU(size uint32) AssociationOption { + return sharedOption(func(c *Config) error { + if size == 0 { + return errZeroMTUOption + } + c.MTU = size + + return nil + }) +} + +// Congestion control options // + +// WithMaxReceiveBufferSize sets the maximum receive buffer size for the association. +// By default this is 1024 * 1024 = 1048576. +func WithMaxReceiveBufferSize(size uint32) AssociationOption { + return sharedOption(func(c *Config) error { + if size == 0 { + return errZeroMaxReceiveBufferOption + } + c.MaxReceiveBufferSize = size + + return nil + }) +} + +// WithMaxMessageSize sets the maximum message size for the association. +// By default this is 65536. +func WithMaxMessageSize(size uint32) AssociationOption { + return sharedOption(func(c *Config) error { + if size == 0 { + return errZeroMaxMessageSize + } + c.MaxMessageSize = size + + return nil + }) +} + +// WithRTOMax sets the max retransmission timeout in ms for the association. +func WithRTOMax(rtoMax float64) AssociationOption { + return sharedOption(func(c *Config) error { + if rtoMax <= 0 { + return errInvalidRTOMax + } + c.RTOMax = rtoMax + + return nil + }) +} + +// WithMinCwnd sets the minimum congestion window for the association. +func WithMinCwnd(minCwnd uint32) AssociationOption { + return sharedOption(func(c *Config) error { + c.MinCwnd = minCwnd + + return nil + }) +} + +// WithFastRtxWnd sets the fast retransmission window for the association. +func WithFastRtxWnd(fastRtxWnd uint32) AssociationOption { + return sharedOption(func(c *Config) error { + c.FastRtxWnd = fastRtxWnd + + return nil + }) +} + +// WithCwndCAStep sets the congestion window congestion avoidance step for the association. +func WithCwndCAStep(cwndCAStep uint32) AssociationOption { + return sharedOption(func(c *Config) error { + c.CwndCAStep = cwndCAStep + + return nil + }) +} + +// WithSNAP enables SNAP, https://datatracker.ietf.org/doc/draft-hancke-tsvwg-snap/. +func WithSNAP(localSctpInit []byte, remoteSctpInit []byte) AssociationOption { + return sharedOption(func(c *Config) error { + if len(localSctpInit) == 0 || len(remoteSctpInit) == 0 { + return errInvalidSnapToken + } + c.snapConfig = &snapConfig{ + localInit: localSctpInit, + remoteInit: remoteSctpInit, + } + + return nil + }) +} diff --git a/vendor/github.com/pion/sctp/association_rack_options.go b/vendor/github.com/pion/sctp/association_rack_options.go new file mode 100644 index 0000000..d53be6b --- /dev/null +++ b/vendor/github.com/pion/sctp/association_rack_options.go @@ -0,0 +1,85 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sctp + +import ( + "time" +) + +// rackSettings holds the optional RACK related settings for an association. +type rackSettings struct { + // Optional: size of window used to determine minimum RTT for RACK (defaults to 30s) + rackMinRTTWnd *windowedMin + + // Optional: cap the minimum reordering window: 0 = use quarter-RTT + rackReoWndFloor time.Duration + + // Optional: receiver worst-case delayed-ACK for PTO when only one packet is in flight + rackWCDelAck time.Duration +} + +// AssociationRACKOption represents a function that can be used to configure an Association's RACK options. +type AssociationRACKOption func(*rackSettings) error + +// RACK config options // + +// WithRackMinRTTWnd sets the length of the local minimum window used to determine the minRTT. +// By default this is 30 seconds. +func WithRackMinRTTWnd(rackMinRTTWnd time.Duration) AssociationRACKOption { + return func(a *rackSettings) error { + if rackMinRTTWnd <= 0 { + return errInvalidRackMinRTTWnd + } + a.rackMinRTTWnd = newWindowedMin(rackMinRTTWnd) + + return nil + } +} + +// WithRackReoWndFloor sets the RACK reordering window floor for the association. +// By default this is 0. +func WithRackReoWndFloor(rackReoWndFloor time.Duration) AssociationRACKOption { + return func(a *rackSettings) error { + if rackReoWndFloor < 0 { + return errInvalidRackReoWndFloor + } + a.rackReoWndFloor = rackReoWndFloor + + return nil + } +} + +// WithRackWCDelAck sets the receiver worst-case delayed-ACK for PTO when only 1 packet is in flight. +// By default this is 200 ms. +func WithRackWCDelAck(rackWCDelAck time.Duration) AssociationRACKOption { + return func(a *rackSettings) error { + if rackWCDelAck <= 0 { + return errInvalidRackWcDelAck + } + a.rackWCDelAck = rackWCDelAck + + return nil + } +} + +// WithRACKOptions configures optional RACK settings using the above options. +// This also creates the new windowedMin slice used for tracking the minRTT. +func WithRACKOptions(opts ...AssociationRACKOption) AssociationOption { + return sharedOption(func(c *Config) error { + cfg := c.rack + for _, opt := range opts { + if opt == nil { + continue + } + + if err := opt(&cfg); err != nil { + return err + } + } + + c.rack = cfg + + return nil + }) +} diff --git a/vendor/github.com/pion/sctp/association_stats.go b/vendor/github.com/pion/sctp/association_stats.go new file mode 100644 index 0000000..0e3e23b --- /dev/null +++ b/vendor/github.com/pion/sctp/association_stats.go @@ -0,0 +1,94 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sctp + +import ( + "sync/atomic" +) + +type associationStats struct { + nPacketsReceived uint64 + nPacketsSent uint64 + nDATAs uint64 + nSACKsReceived uint64 + nSACKsSent uint64 + nT3Timeouts uint64 + nAckTimeouts uint64 + nFastRetrans uint64 +} + +func (s *associationStats) incPacketsReceived() { + atomic.AddUint64(&s.nPacketsReceived, 1) +} + +func (s *associationStats) getNumPacketsReceived() uint64 { + return atomic.LoadUint64(&s.nPacketsReceived) +} + +func (s *associationStats) incPacketsSent() { + atomic.AddUint64(&s.nPacketsSent, 1) +} + +func (s *associationStats) getNumPacketsSent() uint64 { + return atomic.LoadUint64(&s.nPacketsSent) +} + +func (s *associationStats) incDATAs() { + atomic.AddUint64(&s.nDATAs, 1) +} + +func (s *associationStats) getNumDATAs() uint64 { + return atomic.LoadUint64(&s.nDATAs) +} + +func (s *associationStats) incSACKsReceived() { + atomic.AddUint64(&s.nSACKsReceived, 1) +} + +func (s *associationStats) getNumSACKsReceived() uint64 { + return atomic.LoadUint64(&s.nSACKsReceived) +} + +func (s *associationStats) incSACKsSent() { + atomic.AddUint64(&s.nSACKsSent, 1) +} + +func (s *associationStats) getNumSACKsSent() uint64 { + return atomic.LoadUint64(&s.nSACKsSent) +} + +func (s *associationStats) incT3Timeouts() { + atomic.AddUint64(&s.nT3Timeouts, 1) +} + +func (s *associationStats) getNumT3Timeouts() uint64 { + return atomic.LoadUint64(&s.nT3Timeouts) +} + +func (s *associationStats) incAckTimeouts() { + atomic.AddUint64(&s.nAckTimeouts, 1) +} + +func (s *associationStats) getNumAckTimeouts() uint64 { + return atomic.LoadUint64(&s.nAckTimeouts) +} + +func (s *associationStats) incFastRetrans() { + atomic.AddUint64(&s.nFastRetrans, 1) +} + +func (s *associationStats) getNumFastRetrans() uint64 { + return atomic.LoadUint64(&s.nFastRetrans) +} + +func (s *associationStats) reset() { + atomic.StoreUint64(&s.nPacketsReceived, 0) + atomic.StoreUint64(&s.nPacketsSent, 0) + atomic.StoreUint64(&s.nDATAs, 0) + atomic.StoreUint64(&s.nSACKsReceived, 0) + atomic.StoreUint64(&s.nSACKsSent, 0) + atomic.StoreUint64(&s.nT3Timeouts, 0) + atomic.StoreUint64(&s.nAckTimeouts, 0) + atomic.StoreUint64(&s.nFastRetrans, 0) +} diff --git a/vendor/github.com/pion/sctp/chunk.go b/vendor/github.com/pion/sctp/chunk.go new file mode 100644 index 0000000..8337f9a --- /dev/null +++ b/vendor/github.com/pion/sctp/chunk.go @@ -0,0 +1,12 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sctp + +type chunk interface { + unmarshal(raw []byte) error + marshal() ([]byte, error) + check() (bool, error) + + valueLength() int +} diff --git a/vendor/github.com/pion/sctp/chunk_abort.go b/vendor/github.com/pion/sctp/chunk_abort.go new file mode 100644 index 0000000..bff8639 --- /dev/null +++ b/vendor/github.com/pion/sctp/chunk_abort.go @@ -0,0 +1,96 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sctp // nolint:dupl + +import ( + "errors" + "fmt" + "strings" +) + +/* +Abort represents an SCTP Chunk of type ABORT + +The ABORT chunk is sent to the peer of an association to close the +association. The ABORT chunk may contain Cause Parameters to inform +the receiver about the reason of the abort. DATA chunks MUST NOT be +bundled with ABORT. Control chunks (except for INIT, INIT ACK, and +SHUTDOWN COMPLETE) MAY be bundled with an ABORT, but they MUST be +placed before the ABORT in the SCTP packet or they will be ignored by +the receiver. + + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Type = 6 |Reserved |T| Length | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | | + | zero or more Error Causes | + | | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +*/ +type chunkAbort struct { + chunkHeader + errorCauses []errorCause +} + +// Abort chunk errors. +var ( + ErrChunkTypeNotAbort = errors.New("ChunkType is not of type ABORT") + ErrBuildAbortChunkFailed = errors.New("failed build Abort Chunk") +) + +func (a *chunkAbort) unmarshal(raw []byte) error { + if err := a.chunkHeader.unmarshal(raw); err != nil { + return err + } + + if a.typ != ctAbort { + return fmt.Errorf("%w: actually is %s", ErrChunkTypeNotAbort, a.typ.String()) + } + + offset := chunkHeaderSize + for len(raw)-offset >= 4 { + e, err := buildErrorCause(raw[offset:]) + if err != nil { + return fmt.Errorf("%w: %v", ErrBuildAbortChunkFailed, err) //nolint:errorlint + } + + offset += int(e.length()) + a.errorCauses = append(a.errorCauses, e) + } + + return nil +} + +func (a *chunkAbort) marshal() ([]byte, error) { + a.chunkHeader.typ = ctAbort + a.flags = 0x00 + a.raw = []byte{} + for _, ec := range a.errorCauses { + raw, err := ec.marshal() + if err != nil { + return nil, err + } + a.raw = append(a.raw, raw...) + } + + return a.chunkHeader.marshal() +} + +func (a *chunkAbort) check() (abort bool, err error) { + return false, nil +} + +// String makes chunkAbort printable. +func (a *chunkAbort) String() string { + var res strings.Builder + res.WriteString(a.chunkHeader.String()) + + for _, cause := range a.errorCauses { + fmt.Fprintf(&res, "\n - %s", cause) + } + + return res.String() +} diff --git a/vendor/github.com/pion/sctp/chunk_cookie_ack.go b/vendor/github.com/pion/sctp/chunk_cookie_ack.go new file mode 100644 index 0000000..5ca2820 --- /dev/null +++ b/vendor/github.com/pion/sctp/chunk_cookie_ack.go @@ -0,0 +1,54 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sctp + +import ( + "errors" + "fmt" +) + +/* +chunkCookieAck represents an SCTP Chunk of type chunkCookieAck + + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Type = 11 |Chunk Flags | Length = 4 | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +*/ +type chunkCookieAck struct { + chunkHeader +} + +// Cookie ack chunk errors. +var ( + ErrChunkTypeNotCookieAck = errors.New("ChunkType is not of type COOKIEACK") +) + +func (c *chunkCookieAck) unmarshal(raw []byte) error { + if err := c.chunkHeader.unmarshal(raw); err != nil { + return err + } + + if c.typ != ctCookieAck { + return fmt.Errorf("%w: actually is %s", ErrChunkTypeNotCookieAck, c.typ.String()) + } + + return nil +} + +func (c *chunkCookieAck) marshal() ([]byte, error) { + c.chunkHeader.typ = ctCookieAck + + return c.chunkHeader.marshal() +} + +func (c *chunkCookieAck) check() (abort bool, err error) { + return false, nil +} + +// String makes chunkCookieAck printable. +func (c *chunkCookieAck) String() string { + return c.chunkHeader.String() +} diff --git a/vendor/github.com/pion/sctp/chunk_cookie_echo.go b/vendor/github.com/pion/sctp/chunk_cookie_echo.go new file mode 100644 index 0000000..c8f77ed --- /dev/null +++ b/vendor/github.com/pion/sctp/chunk_cookie_echo.go @@ -0,0 +1,55 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sctp + +import ( + "errors" + "fmt" +) + +/* +CookieEcho represents an SCTP Chunk of type CookieEcho + + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Type = 10 |Chunk Flags | Length | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Cookie | + | | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +*/ +type chunkCookieEcho struct { + chunkHeader + cookie []byte +} + +// Cookie echo chunk errors. +var ( + ErrChunkTypeNotCookieEcho = errors.New("ChunkType is not of type COOKIEECHO") +) + +func (c *chunkCookieEcho) unmarshal(raw []byte) error { + if err := c.chunkHeader.unmarshal(raw); err != nil { + return err + } + + if c.typ != ctCookieEcho { + return fmt.Errorf("%w: actually is %s", ErrChunkTypeNotCookieEcho, c.typ.String()) + } + c.cookie = c.raw + + return nil +} + +func (c *chunkCookieEcho) marshal() ([]byte, error) { + c.chunkHeader.typ = ctCookieEcho + c.chunkHeader.raw = c.cookie + + return c.chunkHeader.marshal() +} + +func (c *chunkCookieEcho) check() (abort bool, err error) { + return false, nil +} diff --git a/vendor/github.com/pion/sctp/chunk_error.go b/vendor/github.com/pion/sctp/chunk_error.go new file mode 100644 index 0000000..35de536 --- /dev/null +++ b/vendor/github.com/pion/sctp/chunk_error.go @@ -0,0 +1,103 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sctp // nolint:dupl + +import ( + "errors" + "fmt" + "strings" +) + +/* +Operation Error (ERROR) (9) + +An endpoint sends this chunk to its peer endpoint to notify it of +certain error conditions. It contains one or more error causes. An +Operation Error is not considered fatal in and of itself, but may be +used with an ERROR chunk to report a fatal condition. It has the +following parameters: + + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Type = 9 | Chunk Flags | Length | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + \ \ + / one or more Error Causes / + \ \ + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + +Chunk Flags: 8 bits + + Set to 0 on transmit and ignored on receipt. + +Length: 16 bits (unsigned integer) + + Set to the size of the chunk in bytes, including the chunk header + and all the Error Cause fields present. +*/ +type chunkError struct { + chunkHeader + errorCauses []errorCause +} + +// Error chunk errors. +var ( + ErrChunkTypeNotCtError = errors.New("ChunkType is not of type ctError") + ErrBuildErrorChunkFailed = errors.New("failed build Error Chunk") +) + +func (a *chunkError) unmarshal(raw []byte) error { + if err := a.chunkHeader.unmarshal(raw); err != nil { + return err + } + + if a.typ != ctError { + return fmt.Errorf("%w, actually is %s", ErrChunkTypeNotCtError, a.typ.String()) + } + + offset := chunkHeaderSize + for len(raw)-offset >= 4 { + e, err := buildErrorCause(raw[offset:]) + if err != nil { + return fmt.Errorf("%w: %v", ErrBuildErrorChunkFailed, err) //nolint:errorlint + } + + offset += int(e.length()) + a.errorCauses = append(a.errorCauses, e) + } + + return nil +} + +func (a *chunkError) marshal() ([]byte, error) { + a.chunkHeader.typ = ctError + a.flags = 0x00 + a.raw = []byte{} + for _, ec := range a.errorCauses { + raw, err := ec.marshal() + if err != nil { + return nil, err + } + a.raw = append(a.raw, raw...) + } + + return a.chunkHeader.marshal() +} + +func (a *chunkError) check() (abort bool, err error) { + return false, nil +} + +// String makes chunkError printable. +func (a *chunkError) String() string { + var res strings.Builder + res.WriteString(a.chunkHeader.String()) + + for _, cause := range a.errorCauses { + fmt.Fprintf(&res, "\n - %s", cause) + } + + return res.String() +} diff --git a/vendor/github.com/pion/sctp/chunk_forward_tsn.go b/vendor/github.com/pion/sctp/chunk_forward_tsn.go new file mode 100644 index 0000000..1bd5811 --- /dev/null +++ b/vendor/github.com/pion/sctp/chunk_forward_tsn.go @@ -0,0 +1,155 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sctp + +import ( + "encoding/binary" + "errors" + "fmt" + "strings" +) + +// This chunk shall be used by the data sender to inform the data +// receiver to adjust its cumulative received TSN point forward because +// some missing TSNs are associated with data chunks that SHOULD NOT be +// transmitted or retransmitted by the sender. +// +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Type = 192 | Flags = 0x00 | Length = Variable | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | New Cumulative TSN | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Stream-1 | Stream Sequence-1 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// \ / +// / \ +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Stream-N | Stream Sequence-N | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + +type chunkForwardTSN struct { + chunkHeader + + // This indicates the new cumulative TSN to the data receiver. Upon + // the reception of this value, the data receiver MUST consider + // any missing TSNs earlier than or equal to this value as received, + // and stop reporting them as gaps in any subsequent SACKs. + newCumulativeTSN uint32 + + streams []chunkForwardTSNStream +} + +const ( + newCumulativeTSNLength = 4 + forwardTSNStreamLength = 4 +) + +// Forward TSN chunk errors. +var ( + ErrMarshalStreamFailed = errors.New("failed to marshal stream") + ErrChunkTooShort = errors.New("chunk too short") +) + +func (c *chunkForwardTSN) unmarshal(raw []byte) error { + if err := c.chunkHeader.unmarshal(raw); err != nil { + return err + } + + if len(c.raw) < newCumulativeTSNLength { + return ErrChunkTooShort + } + + c.newCumulativeTSN = binary.BigEndian.Uint32(c.raw[0:]) + + offset := newCumulativeTSNLength + remaining := len(c.raw) - offset + for remaining > 0 { + s := chunkForwardTSNStream{} + + if err := s.unmarshal(c.raw[offset:]); err != nil { + return fmt.Errorf("%w: %v", ErrMarshalStreamFailed, err) //nolint:errorlint + } + + c.streams = append(c.streams, s) + + offset += s.length() + remaining -= s.length() + } + + return nil +} + +func (c *chunkForwardTSN) marshal() ([]byte, error) { + out := make([]byte, newCumulativeTSNLength) + binary.BigEndian.PutUint32(out[0:], c.newCumulativeTSN) + + for _, s := range c.streams { + b, err := s.marshal() + if err != nil { + return nil, fmt.Errorf("%w: %v", ErrMarshalStreamFailed, err) //nolint:errorlint + } + out = append(out, b...) //nolint:makezero // TODO: fix + } + + c.typ = ctForwardTSN + c.raw = out + + return c.chunkHeader.marshal() +} + +func (c *chunkForwardTSN) check() (abort bool, err error) { + return true, nil +} + +// String makes chunkForwardTSN printable. +func (c *chunkForwardTSN) String() string { + var res strings.Builder + fmt.Fprintf(&res, "New Cumulative TSN: %d\n", c.newCumulativeTSN) + for _, s := range c.streams { + fmt.Fprintf(&res, " - si=%d, ssn=%d\n", s.identifier, s.sequence) + } + + return res.String() +} + +type chunkForwardTSNStream struct { + // This field holds a stream number that was skipped by this + // FWD-TSN. + identifier uint16 + + // This field holds the sequence number associated with the stream + // that was skipped. The stream sequence field holds the largest + // stream sequence number in this stream being skipped. The receiver + // of the FWD-TSN's can use the Stream-N and Stream Sequence-N fields + // to enable delivery of any stranded TSN's that remain on the stream + // re-ordering queues. This field MUST NOT report TSN's corresponding + // to DATA chunks that are marked as unordered. For ordered DATA + // chunks this field MUST be filled in. + sequence uint16 +} + +func (s *chunkForwardTSNStream) length() int { + return forwardTSNStreamLength +} + +func (s *chunkForwardTSNStream) unmarshal(raw []byte) error { + if len(raw) < forwardTSNStreamLength { + return ErrChunkTooShort + } + s.identifier = binary.BigEndian.Uint16(raw[0:]) + s.sequence = binary.BigEndian.Uint16(raw[2:]) + + return nil +} + +func (s *chunkForwardTSNStream) marshal() ([]byte, error) { // nolint:unparam + out := make([]byte, forwardTSNStreamLength) + + binary.BigEndian.PutUint16(out[0:], s.identifier) + binary.BigEndian.PutUint16(out[2:], s.sequence) + + return out, nil +} diff --git a/vendor/github.com/pion/sctp/chunk_heartbeat.go b/vendor/github.com/pion/sctp/chunk_heartbeat.go new file mode 100644 index 0000000..de99a06 --- /dev/null +++ b/vendor/github.com/pion/sctp/chunk_heartbeat.go @@ -0,0 +1,116 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sctp + +import ( + "errors" + "fmt" +) + +/* +chunkHeartbeat represents an SCTP Chunk of type HEARTBEAT (RFC 9260 section 3.3.6) + +An endpoint sends this chunk to probe reachability of a destination address. +The chunk MUST contain exactly one variable-length parameter: + +Variable Parameters Status Type Value +------------------------------------------------------------- +Heartbeat Info Mandatory 1 + +nolint:godot +*/ +type chunkHeartbeat struct { + chunkHeader + params []param +} + +// Heartbeat chunk errors. +var ( + ErrChunkTypeNotHeartbeat = errors.New("ChunkType is not of type HEARTBEAT") + ErrHeartbeatNotLongEnoughInfo = errors.New("heartbeat is not long enough to contain Heartbeat Info") + ErrParseParamTypeFailed = errors.New("failed to parse param type") + ErrHeartbeatParam = errors.New("heartbeat should only have HEARTBEAT param") + ErrHeartbeatChunkUnmarshal = errors.New("failed unmarshalling param in Heartbeat Chunk") + ErrHeartbeatExtraNonZero = errors.New("heartbeat has non-zero trailing bytes after last parameter") + ErrHeartbeatMarshalNoInfo = errors.New("heartbeat marshal requires exactly one Heartbeat Info parameter") +) + +func (h *chunkHeartbeat) unmarshal(raw []byte) error { //nolint:cyclop + if err := h.chunkHeader.unmarshal(raw); err != nil { + return err + } + + if h.typ != ctHeartbeat { + return fmt.Errorf("%w: actually is %s", ErrChunkTypeNotHeartbeat, h.typ.String()) + } + + // if the body is completely empty, accept it but don't populate params. + if len(h.raw) == 0 { + return nil + } + + // need at least a parameter header present (TLV: 4 bytes minimum). + if len(h.raw) < initOptionalVarHeaderLength { + return fmt.Errorf("%w: %d", ErrHeartbeatNotLongEnoughInfo, len(h.raw)) + } + + pType, err := parseParamType(h.raw) + if err != nil { + return fmt.Errorf("%w: %v", ErrParseParamTypeFailed, err) //nolint:errorlint + } + if pType != heartbeatInfo { + return fmt.Errorf("%w: instead have %s", ErrHeartbeatParam, pType.String()) + } + + var pHeader paramHeader + if e := pHeader.unmarshal(h.raw); e != nil { + return fmt.Errorf("%w: %v", ErrParseParamTypeFailed, e) //nolint:errorlint + } + + plen := pHeader.length() + if plen < initOptionalVarHeaderLength || plen > len(h.raw) { + return ErrHeartbeatNotLongEnoughInfo + } + + p, err := buildParam(pType, h.raw[:plen]) + if err != nil { + return fmt.Errorf("%w: %v", ErrHeartbeatChunkUnmarshal, err) //nolint:errorlint + } + h.params = append(h.params, p) + + // any trailing bytes beyond the single param must be all zeros. + if rem := h.raw[plen:]; len(rem) > 0 && !allZero(rem) { + return ErrHeartbeatExtraNonZero + } + + return nil +} + +func (h *chunkHeartbeat) Marshal() ([]byte, error) { + // exactly one Heartbeat Info param is required. + if len(h.params) != 1 { + return nil, ErrHeartbeatMarshalNoInfo + } + + // enforce correct concrete type via type assertion (param interface has no type getter). + if _, ok := h.params[0].(*paramHeartbeatInfo); !ok { + return nil, ErrHeartbeatParam + } + + pp, err := h.params[0].marshal() + if err != nil { + return nil, fmt.Errorf("%w: %v", ErrHeartbeatChunkUnmarshal, err) //nolint:errorlint + } + + // single TLV, no inter-parameter padding within the chunk body. + h.chunkHeader.typ = ctHeartbeat + h.chunkHeader.flags = 0 // sender MUST set to 0 + h.chunkHeader.raw = append([]byte(nil), pp...) + + return h.chunkHeader.marshal() +} + +func (h *chunkHeartbeat) check() (abort bool, err error) { + return false, nil +} diff --git a/vendor/github.com/pion/sctp/chunk_heartbeat_ack.go b/vendor/github.com/pion/sctp/chunk_heartbeat_ack.go new file mode 100644 index 0000000..d367169 --- /dev/null +++ b/vendor/github.com/pion/sctp/chunk_heartbeat_ack.go @@ -0,0 +1,146 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sctp + +import ( + "errors" + "fmt" +) + +/* +chunkHeartbeatAck represents an SCTP Chunk of type HEARTBEAT ACK + +An endpoint should send this chunk to its peer endpoint as a response +to a HEARTBEAT chunk (see Section 8.3). A HEARTBEAT ACK is always +sent to the source IP address of the IP datagram containing the +HEARTBEAT chunk to which this ack is responding. + +The parameter field contains a variable-length opaque data structure. + + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Type = 5 | Chunk Flags | Heartbeat Ack Length | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | | + | Heartbeat Information TLV (Variable-Length) | + | | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + +Defined as a variable-length parameter using the format described +in Section 3.2.1, i.e.: + +Variable Parameters Status Type Value +------------------------------------------------------------- +Heartbeat Info Mandatory 1 +. +*/ +type chunkHeartbeatAck struct { + chunkHeader + params []param +} + +// Heartbeat ack chunk errors. +var ( + // Deprecated: this error is no longer used but is kept for compatibility. + ErrUnimplemented = errors.New("unimplemented") + ErrChunkTypeNotHeartbeatAck = errors.New("chunk type is not of type HEARTBEAT ACK") + ErrHeartbeatAckParams = errors.New("heartbeat Ack must have one param") + ErrHeartbeatAckNotHeartbeatInfo = errors.New("heartbeat Ack must have one param, and it should be a HeartbeatInfo") + ErrHeartbeatAckMarshalParam = errors.New("unable to marshal parameter for Heartbeat Ack") +) + +func (h *chunkHeartbeatAck) unmarshal(raw []byte) error { //nolint:cyclop + if err := h.chunkHeader.unmarshal(raw); err != nil { + return err + } + + if h.typ != ctHeartbeatAck { + return fmt.Errorf("%w %s", ErrChunkTypeNotHeartbeatAck, h.typ.String()) + } + + // allow for an empty heartbeat: no RTT info -> ActiveHeartbeat just won't update SRTT. + if len(h.raw) == 0 { + h.params = nil + + return nil + } + + if len(h.raw) < initOptionalVarHeaderLength { + return fmt.Errorf("%w: %d", ErrHeartbeatAckParams, len(h.raw)) + } + + pType, err := parseParamType(h.raw) + if err != nil { + return fmt.Errorf("%w: %v", ErrHeartbeatAckParams, err) //nolint:errorlint + } + if pType != heartbeatInfo { + return fmt.Errorf("%w: instead have %s", ErrHeartbeatAckNotHeartbeatInfo, pType.String()) + } + + var pHeader paramHeader + if e := pHeader.unmarshal(h.raw); e != nil { + return fmt.Errorf("%w: %v", ErrHeartbeatAckParams, e) //nolint:errorlint + } + plen := pHeader.length() + if plen < initOptionalVarHeaderLength || plen > len(h.raw) { + return fmt.Errorf("%w: %d", ErrHeartbeatAckParams, plen) + } + + p, err := buildParam(pType, h.raw[:plen]) + if err != nil { + return fmt.Errorf("%w: %v", ErrHeartbeatAckMarshalParam, err) //nolint:errorlint + } + h.params = []param{p} + + // Any trailing bytes beyond the single param must be zero. + if rem := h.raw[plen:]; len(rem) > 0 && !allZero(rem) { + return ErrHeartbeatExtraNonZero + } + + return nil +} + +func (h *chunkHeartbeatAck) marshal() ([]byte, error) { + if len(h.params) != 1 { + return nil, ErrHeartbeatAckParams + } + + switch h.params[0].(type) { + case *paramHeartbeatInfo: + // ParamHeartbeatInfo is valid + default: + return nil, ErrHeartbeatAckNotHeartbeatInfo + } + + out := make([]byte, 0) + for idx, p := range h.params { + pp, err := p.marshal() + if err != nil { + return nil, fmt.Errorf("%w: %v", ErrHeartbeatAckMarshalParam, err) //nolint:errorlint + } + + out = append(out, pp...) + + // Chunks (including Type, Length, and Value fields) are padded out + // by the sender with all zero bytes to be a multiple of 4 bytes + // long. This padding MUST NOT be more than 3 bytes in total. The + // Chunk Length value does not include terminating padding of the + // chunk. *However, it does include padding of any variable-length + // parameter except the last parameter in the chunk.* The receiver + // MUST ignore the padding. + if idx != len(h.params)-1 { + out = padByte(out, getPadding(len(pp))) + } + } + + h.chunkHeader.typ = ctHeartbeatAck + h.chunkHeader.raw = out + + return h.chunkHeader.marshal() +} + +func (h *chunkHeartbeatAck) check() (abort bool, err error) { + return false, nil +} diff --git a/vendor/github.com/pion/sctp/chunk_init.go b/vendor/github.com/pion/sctp/chunk_init.go new file mode 100644 index 0000000..a6568b3 --- /dev/null +++ b/vendor/github.com/pion/sctp/chunk_init.go @@ -0,0 +1,143 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sctp // nolint:dupl + +import ( + "errors" + "fmt" +) + +/* +Init represents an SCTP Chunk of type INIT + +See chunkInitCommon for the fixed headers + + Variable Parameters Status Type Value + ------------------------------------------------------------- + IPv4 IP (Note 1) Optional 5 + IPv6 IP (Note 1) Optional 6 + Cookie Preservative Optional 9 + Reserved for ECN Capable (Note 2) Optional 32768 (0x8000) + Host Name IP (Note 3) Optional 11 + Supported IP Types (Note 4) Optional 12 +*/ +type chunkInit struct { + chunkHeader + chunkInitCommon +} + +// Init chunk errors. +var ( + ErrChunkTypeNotTypeInit = errors.New("ChunkType is not of type INIT") + ErrChunkValueNotLongEnough = errors.New("chunk Value isn't long enough for mandatory parameters exp") + ErrChunkTypeInitFlagZero = errors.New("ChunkType of type INIT flags must be all 0") + ErrChunkTypeInitUnmarshalFailed = errors.New("failed to unmarshal INIT body") + ErrChunkTypeInitMarshalFailed = errors.New("failed marshaling INIT common data") + ErrChunkTypeInitInitateTagZero = errors.New("ChunkType of type INIT ACK InitiateTag must not be 0") + ErrInitInboundStreamRequestZero = errors.New("INIT ACK inbound stream request must be > 0") + ErrInitOutboundStreamRequestZero = errors.New("INIT ACK outbound stream request must be > 0") + ErrInitAdvertisedReceiver1500 = errors.New("INIT ACK Advertised Receiver Window Credit (a_rwnd) must be >= 1500") + ErrInitUnknownParam = errors.New("INIT with unknown param") +) + +func (i *chunkInit) unmarshal(raw []byte) error { + if err := i.chunkHeader.unmarshal(raw); err != nil { + return err + } + + if i.typ != ctInit { + return fmt.Errorf("%w: actually is %s", ErrChunkTypeNotTypeInit, i.typ.String()) + } else if len(i.raw) < initChunkMinLength { + return fmt.Errorf("%w: %d actual: %d", ErrChunkValueNotLongEnough, initChunkMinLength, len(i.raw)) + } + + // The Chunk Flags field in INIT is reserved, and all bits in it should + // be set to 0 by the sender and ignored by the receiver. The sequence + // of parameters within an INIT can be processed in any order. + if i.flags != 0 { + return ErrChunkTypeInitFlagZero + } + + if err := i.chunkInitCommon.unmarshal(i.raw); err != nil { + return fmt.Errorf("%w: %v", ErrChunkTypeInitUnmarshalFailed, err) //nolint:errorlint + } + + return nil +} + +func (i *chunkInit) marshal() ([]byte, error) { + initShared, err := i.chunkInitCommon.marshal() + if err != nil { + return nil, fmt.Errorf("%w: %v", ErrChunkTypeInitMarshalFailed, err) //nolint:errorlint + } + + i.chunkHeader.typ = ctInit + i.chunkHeader.raw = initShared + + return i.chunkHeader.marshal() +} + +func (i *chunkInit) check() (abort bool, err error) { + // The receiver of the INIT (the responding end) records the value of + // the Initiate Tag parameter. This value MUST be placed into the + // Verification Tag field of every SCTP packet that the receiver of + // the INIT transmits within this association. + // + // The Initiate Tag is allowed to have any value except 0. See + // Section 5.3.1 for more on the selection of the tag value. + // + // If the value of the Initiate Tag in a received INIT chunk is found + // to be 0, the receiver MUST treat it as an error and close the + // association by transmitting an ABORT. + if i.initiateTag == 0 { + return true, ErrChunkTypeInitInitateTagZero + } + + // Defines the maximum number of streams the sender of this INIT + // chunk allows the peer end to create in this association. The + // value 0 MUST NOT be used. + // + // Note: There is no negotiation of the actual number of streams but + // instead the two endpoints will use the min(requested, offered). + // See Section 5.1.1 for details. + // + // Note: A receiver of an INIT with the MIS value of 0 SHOULD abort + // the association. + if i.numInboundStreams == 0 { + return true, ErrInitInboundStreamRequestZero + } + + // Defines the number of outbound streams the sender of this INIT + // chunk wishes to create in this association. The value of 0 MUST + // NOT be used. + // + // Note: A receiver of an INIT with the OS value set to 0 SHOULD + // abort the association. + + if i.numOutboundStreams == 0 { + return true, ErrInitOutboundStreamRequestZero + } + + // An SCTP receiver MUST be able to receive a minimum of 1500 bytes in + // one SCTP packet. This means that an SCTP endpoint MUST NOT indicate + // less than 1500 bytes in its initial a_rwnd sent in the INIT or INIT + // ACK. + if i.advertisedReceiverWindowCredit < 1500 { + return true, ErrInitAdvertisedReceiver1500 + } + + for _, p := range i.unrecognizedParams { + if p.unrecognizedAction == paramHeaderUnrecognizedActionStop || + p.unrecognizedAction == paramHeaderUnrecognizedActionStopAndReport { + return true, ErrInitUnknownParam + } + } + + return false, nil +} + +// String makes chunkInit printable. +func (i *chunkInit) String() string { + return fmt.Sprintf("%s\n%s", i.chunkHeader, i.chunkInitCommon) +} diff --git a/vendor/github.com/pion/sctp/chunk_init_ack.go b/vendor/github.com/pion/sctp/chunk_init_ack.go new file mode 100644 index 0000000..7312cc1 --- /dev/null +++ b/vendor/github.com/pion/sctp/chunk_init_ack.go @@ -0,0 +1,145 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sctp // nolint:dupl + +import ( + "errors" + "fmt" +) + +/* +chunkInitAck represents an SCTP Chunk of type INIT ACK + +See chunkInitCommon for the fixed headers + + Variable Parameters Status Type Value + ------------------------------------------------------------- + State Cookie Mandatory 7 + IPv4 IP (Note 1) Optional 5 + IPv6 IP (Note 1) Optional 6 + Unrecognized Parameter Optional 8 + Reserved for ECN Capable (Note 2) Optional 32768 (0x8000) + Host Name IP (Note 3) Optional 11 +*/ +type chunkInitAck struct { + chunkHeader + chunkInitCommon +} + +// Init ack chunk errors. +var ( + ErrChunkTypeNotInitAck = errors.New("ChunkType is not of type INIT ACK") + ErrChunkNotLongEnoughForParams = errors.New("chunk Value isn't long enough for mandatory parameters exp") + ErrChunkTypeInitAckFlagZero = errors.New("ChunkType of type INIT ACK flags must be all 0") + ErrInitAckUnmarshalFailed = errors.New("failed to unmarshal INIT body") + ErrInitCommonDataMarshalFailed = errors.New("failed marshaling INIT common data") + ErrChunkTypeInitAckInitateTagZero = errors.New("ChunkType of type INIT ACK InitiateTag must not be 0") + ErrInitAckInboundStreamRequestZero = errors.New("INIT ACK inbound stream request must be > 0") + ErrInitAckOutboundStreamRequestZero = errors.New("INIT ACK outbound stream request must be > 0") + ErrInitAckAdvertisedReceiver1500 = errors.New("INIT ACK Advertised Receiver Window Credit (a_rwnd) must be >= 1500") +) + +func (i *chunkInitAck) unmarshal(raw []byte) error { + if err := i.chunkHeader.unmarshal(raw); err != nil { + return err + } + + if i.typ != ctInitAck { + return fmt.Errorf("%w: actually is %s", ErrChunkTypeNotInitAck, i.typ.String()) + } else if len(i.raw) < initChunkMinLength { + return fmt.Errorf("%w: %d actual: %d", ErrChunkNotLongEnoughForParams, initChunkMinLength, len(i.raw)) + } + + // The Chunk Flags field in INIT is reserved, and all bits in it should + // be set to 0 by the sender and ignored by the receiver. The sequence + // of parameters within an INIT can be processed in any order. + if i.flags != 0 { + return ErrChunkTypeInitAckFlagZero + } + + if err := i.chunkInitCommon.unmarshal(i.raw); err != nil { + return fmt.Errorf("%w: %v", ErrInitAckUnmarshalFailed, err) //nolint:errorlint + } + + return nil +} + +func (i *chunkInitAck) marshal() ([]byte, error) { + initShared, err := i.chunkInitCommon.marshal() + if err != nil { + return nil, fmt.Errorf("%w: %v", ErrInitCommonDataMarshalFailed, err) //nolint:errorlint + } + + i.chunkHeader.typ = ctInitAck + i.chunkHeader.raw = initShared + + return i.chunkHeader.marshal() +} + +func (i *chunkInitAck) check() (abort bool, err error) { + // The receiver of the INIT ACK records the value of the Initiate Tag + // parameter. This value MUST be placed into the Verification Tag + // field of every SCTP packet that the INIT ACK receiver transmits + // within this association. + // + // The Initiate Tag MUST NOT take the value 0. See Section 5.3.1 for + // more on the selection of the Initiate Tag value. + // + // If the value of the Initiate Tag in a received INIT ACK chunk is + // found to be 0, the receiver MUST destroy the association + // discarding its TCB. The receiver MAY send an ABORT for debugging + // purpose. + if i.initiateTag == 0 { + abort = true + + return abort, ErrChunkTypeInitAckInitateTagZero + } + + // Defines the maximum number of streams the sender of this INIT ACK + // chunk allows the peer end to create in this association. The + // value 0 MUST NOT be used. + // + // Note: There is no negotiation of the actual number of streams but + // instead the two endpoints will use the min(requested, offered). + // See Section 5.1.1 for details. + // + // Note: A receiver of an INIT ACK with the MIS value set to 0 SHOULD + // destroy the association discarding its TCB. + if i.numInboundStreams == 0 { + abort = true + + return abort, ErrInitAckInboundStreamRequestZero + } + + // Defines the number of outbound streams the sender of this INIT ACK + // chunk wishes to create in this association. The value of 0 MUST + // NOT be used, and the value MUST NOT be greater than the MIS value + // sent in the INIT chunk. + // + // Note: A receiver of an INIT ACK with the OS value set to 0 SHOULD + // destroy the association discarding its TCB. + + if i.numOutboundStreams == 0 { + abort = true + + return abort, ErrInitAckOutboundStreamRequestZero + } + + // An SCTP receiver MUST be able to receive a minimum of 1500 bytes in + // one SCTP packet. This means that an SCTP endpoint MUST NOT indicate + // less than 1500 bytes in its initial a_rwnd sent in the INIT or INIT + // ACK. + if i.advertisedReceiverWindowCredit < 1500 { + abort = true + + return abort, ErrInitAckAdvertisedReceiver1500 + } + + return false, nil +} + +// String makes chunkInitAck printable. +func (i *chunkInitAck) String() string { + return fmt.Sprintf("%s\n%s", i.chunkHeader, i.chunkInitCommon) +} diff --git a/vendor/github.com/pion/sctp/chunk_init_common.go b/vendor/github.com/pion/sctp/chunk_init_common.go new file mode 100644 index 0000000..cf3aab8 --- /dev/null +++ b/vendor/github.com/pion/sctp/chunk_init_common.go @@ -0,0 +1,181 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sctp + +import ( + "encoding/binary" + "errors" + "fmt" + "strings" +) + +/* +chunkInitCommon represents an SCTP Chunk body of type INIT and INIT ACK + + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Type = 1 | Chunk Flags | Chunk Length | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Initiate Tag | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Advertised Receiver Window Credit (a_rwnd) | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Number of Outbound Streams | Number of Inbound Streams | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Initial TSN | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| | +| Optional/Variable-Length Parameters | +| | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + +The INIT chunk contains the following parameters. Unless otherwise +noted, each parameter MUST only be included once in the INIT chunk. + +Fixed Parameters Status +---------------------------------------------- +Initiate Tag Mandatory +Advertised Receiver Window Credit Mandatory +Number of Outbound Streams Mandatory +Number of Inbound Streams Mandatory +Initial TSN Mandatory +*/ + +type chunkInitCommon struct { + initiateTag uint32 + advertisedReceiverWindowCredit uint32 + numOutboundStreams uint16 + numInboundStreams uint16 + initialTSN uint32 + params []param + unrecognizedParams []paramHeader +} + +const ( + initChunkMinLength = 16 + initOptionalVarHeaderLength = 4 +) + +// Init chunk errors. +var ( + ErrInitChunkParseParamTypeFailed = errors.New("failed to parse param type") + ErrInitAckMarshalParam = errors.New("unable to marshal parameter for INIT/INITACK") +) + +func (i *chunkInitCommon) unmarshal(raw []byte) error { + i.initiateTag = binary.BigEndian.Uint32(raw[0:]) + i.advertisedReceiverWindowCredit = binary.BigEndian.Uint32(raw[4:]) + i.numOutboundStreams = binary.BigEndian.Uint16(raw[8:]) + i.numInboundStreams = binary.BigEndian.Uint16(raw[10:]) + i.initialTSN = binary.BigEndian.Uint32(raw[12:]) + + // https://tools.ietf.org/html/rfc4960#section-3.2.1 + // + // Chunk values of SCTP control chunks consist of a chunk-type-specific + // header of required fields, followed by zero or more parameters. The + // optional and variable-length parameters contained in a chunk are + // defined in a Type-Length-Value format as shown below. + // + // 0 1 2 3 + // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + // | Parameter Type | Parameter Length | + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + // | | + // | Parameter Value | + // | | + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + + offset := initChunkMinLength + remaining := len(raw) - offset + for remaining > 0 { + if remaining > initOptionalVarHeaderLength { + var pHeader paramHeader + if err := pHeader.unmarshal(raw[offset:]); err != nil { + return fmt.Errorf("%w: %v", ErrInitChunkParseParamTypeFailed, err) //nolint:errorlint + } + + p, err := buildParam(pHeader.typ, raw[offset:]) + if err != nil { + i.unrecognizedParams = append(i.unrecognizedParams, pHeader) + } else { + i.params = append(i.params, p) + } + + padding := getPadding(pHeader.length()) + offset += pHeader.length() + padding + remaining -= pHeader.length() + padding + } else { + break + } + } + + return nil +} + +func (i *chunkInitCommon) marshal() ([]byte, error) { + out := make([]byte, initChunkMinLength) + binary.BigEndian.PutUint32(out[0:], i.initiateTag) + binary.BigEndian.PutUint32(out[4:], i.advertisedReceiverWindowCredit) + binary.BigEndian.PutUint16(out[8:], i.numOutboundStreams) + binary.BigEndian.PutUint16(out[10:], i.numInboundStreams) + binary.BigEndian.PutUint32(out[12:], i.initialTSN) + for idx, p := range i.params { + pp, err := p.marshal() + if err != nil { + return nil, fmt.Errorf("%w: %v", ErrInitAckMarshalParam, err) //nolint:errorlint + } + + out = append(out, pp...) //nolint:makezero // TODO: fix + + // Chunks (including Type, Length, and Value fields) are padded out + // by the sender with all zero bytes to be a multiple of 4 bytes + // long. This padding MUST NOT be more than 3 bytes in total. The + // Chunk Length value does not include terminating padding of the + // chunk. *However, it does include padding of any variable-length + // parameter except the last parameter in the chunk.* The receiver + // MUST ignore the padding. + if idx != len(i.params)-1 { + out = padByte(out, getPadding(len(pp))) + } + } + + return out, nil +} + +// String makes chunkInitCommon printable. +func (i chunkInitCommon) String() string { + format := `initiateTag: %d + advertisedReceiverWindowCredit: %d + numOutboundStreams: %d + numInboundStreams: %d + initialTSN: %d` + + var res strings.Builder + fmt.Fprintf(&res, format, + i.initiateTag, + i.advertisedReceiverWindowCredit, + i.numOutboundStreams, + i.numInboundStreams, + i.initialTSN, + ) + + for i, param := range i.params { + fmt.Fprintf(&res, "Param %d:\n %s", i, param) + } + + return res.String() +} + +// allZero returns true if every byte is 0x00. +func allZero(b []byte) bool { + for _, v := range b { + if v != 0 { + return false + } + } + + return true +} diff --git a/vendor/github.com/pion/sctp/chunk_payload_data.go b/vendor/github.com/pion/sctp/chunk_payload_data.go new file mode 100644 index 0000000..dfd8ef7 --- /dev/null +++ b/vendor/github.com/pion/sctp/chunk_payload_data.go @@ -0,0 +1,219 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sctp + +import ( + "encoding/binary" + "errors" + "fmt" + "time" +) + +/* +chunkPayloadData represents an SCTP Chunk of type DATA + + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Type = 0 | Reserved|U|B|E| Length | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | TSN | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Stream Identifier S | Stream Sequence Number n | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Payload Protocol Identifier | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | | + | User Data (seq n of Stream S) | + | | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + +An unfragmented user message shall have both the B and E bits set to +'1'. Setting both B and E bits to '0' indicates a middle fragment of +a multi-fragment user message, as summarized in the following table: + + B E Description + ============================================================ + | 1 0 | First piece of a fragmented user message | + +----------------------------------------------------------+ + | 0 0 | Middle piece of a fragmented user message | + +----------------------------------------------------------+ + | 0 1 | Last piece of a fragmented user message | + +----------------------------------------------------------+ + | 1 1 | Unfragmented message | + ============================================================ + | Table 1: Fragment Description Flags | + ============================================================ +*/ +type chunkPayloadData struct { + chunkHeader + + unordered bool + beginningFragment bool + endingFragment bool + immediateSack bool + + tsn uint32 + streamIdentifier uint16 + streamSequenceNumber uint16 + payloadType PayloadProtocolIdentifier + userData []byte + + // Whether this data chunk was acknowledged (received by peer) + acked bool + missIndicator uint32 + + // Partial-reliability parameters used only by sender + since time.Time + nSent uint32 // number of transmission made for this chunk + _abandoned bool + _allInflight bool // valid only with the first fragment + + // Retransmission flag set when T1-RTX timeout occurred and this + // chunk is still in the inflight queue + retransmit bool + + head *chunkPayloadData // link to the head of the fragment + + rackPrev *chunkPayloadData + rackNext *chunkPayloadData + rackInList bool +} + +const ( + payloadDataEndingFragmentBitmask = 1 + payloadDataBeginingFragmentBitmask = 2 + payloadDataUnorderedBitmask = 4 + payloadDataImmediateSACK = 8 + + payloadDataHeaderSize = 12 +) + +// PayloadProtocolIdentifier is an enum for DataChannel payload types. +type PayloadProtocolIdentifier uint32 + +// PayloadProtocolIdentifier enums +// https://www.iana.org/assignments/sctp-parameters/sctp-parameters.xhtml#sctp-parameters-25 +const ( + PayloadTypeUnknown PayloadProtocolIdentifier = 0 + PayloadTypeWebRTCDCEP PayloadProtocolIdentifier = 50 + PayloadTypeWebRTCString PayloadProtocolIdentifier = 51 + PayloadTypeWebRTCBinary PayloadProtocolIdentifier = 53 + PayloadTypeWebRTCStringEmpty PayloadProtocolIdentifier = 56 + PayloadTypeWebRTCBinaryEmpty PayloadProtocolIdentifier = 57 +) + +// Data chunk errors. +var ( + ErrChunkPayloadSmall = errors.New("packet is smaller than the header size") +) + +func (p PayloadProtocolIdentifier) String() string { + switch p { + case PayloadTypeWebRTCDCEP: + return "WebRTC DCEP" + case PayloadTypeWebRTCString: + return "WebRTC String" + case PayloadTypeWebRTCBinary: + return "WebRTC Binary" + case PayloadTypeWebRTCStringEmpty: + return "WebRTC String (Empty)" + case PayloadTypeWebRTCBinaryEmpty: + return "WebRTC Binary (Empty)" + default: + return fmt.Sprintf("Unknown Payload Protocol Identifier: %d", p) + } +} + +func (p *chunkPayloadData) unmarshal(raw []byte) error { + if err := p.chunkHeader.unmarshal(raw); err != nil { + return err + } + + p.immediateSack = p.flags&payloadDataImmediateSACK != 0 + p.unordered = p.flags&payloadDataUnorderedBitmask != 0 + p.beginningFragment = p.flags&payloadDataBeginingFragmentBitmask != 0 + p.endingFragment = p.flags&payloadDataEndingFragmentBitmask != 0 + + if len(p.raw) < payloadDataHeaderSize { + return ErrChunkPayloadSmall + } + p.tsn = binary.BigEndian.Uint32(p.raw[0:]) + p.streamIdentifier = binary.BigEndian.Uint16(p.raw[4:]) + p.streamSequenceNumber = binary.BigEndian.Uint16(p.raw[6:]) + p.payloadType = PayloadProtocolIdentifier(binary.BigEndian.Uint32(p.raw[8:])) + p.userData = p.raw[payloadDataHeaderSize:] + + return nil +} + +func (p *chunkPayloadData) marshal() ([]byte, error) { + payRaw := make([]byte, payloadDataHeaderSize+len(p.userData)) + + binary.BigEndian.PutUint32(payRaw[0:], p.tsn) + binary.BigEndian.PutUint16(payRaw[4:], p.streamIdentifier) + binary.BigEndian.PutUint16(payRaw[6:], p.streamSequenceNumber) + binary.BigEndian.PutUint32(payRaw[8:], uint32(p.payloadType)) + copy(payRaw[payloadDataHeaderSize:], p.userData) + + flags := uint8(0) + if p.endingFragment { + flags = 1 + } + if p.beginningFragment { + flags |= 1 << 1 + } + if p.unordered { + flags |= 1 << 2 + } + if p.immediateSack { + flags |= 1 << 3 + } + + p.chunkHeader.flags = flags + p.chunkHeader.typ = ctPayloadData + p.chunkHeader.raw = payRaw + + return p.chunkHeader.marshal() +} + +func (p *chunkPayloadData) check() (abort bool, err error) { + return false, nil +} + +// String makes chunkPayloadData printable. +func (p *chunkPayloadData) String() string { + return fmt.Sprintf("%s\n%d", p.chunkHeader, p.tsn) +} + +func (p *chunkPayloadData) abandoned() bool { + if p.head != nil { + return p.head._abandoned && p.head._allInflight + } + + return p._abandoned && p._allInflight +} + +func (p *chunkPayloadData) setAbandoned(abandoned bool) { + if p.head != nil { + p.head._abandoned = abandoned + + return + } + p._abandoned = abandoned +} + +func (p *chunkPayloadData) setAllInflight() { + if p.endingFragment { + if p.head != nil { + p.head._allInflight = true + } else { + p._allInflight = true + } + } +} + +func (p *chunkPayloadData) isFragmented() bool { + return p.head != nil || !p.beginningFragment || !p.endingFragment +} diff --git a/vendor/github.com/pion/sctp/chunk_reconfig.go b/vendor/github.com/pion/sctp/chunk_reconfig.go new file mode 100644 index 0000000..d219677 --- /dev/null +++ b/vendor/github.com/pion/sctp/chunk_reconfig.go @@ -0,0 +1,110 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sctp + +import ( + "errors" + "fmt" +) + +// https://tools.ietf.org/html/rfc6525#section-3.1 +// chunkReconfig represents an SCTP Chunk used to reconfigure streams. +// +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Type = 130 | Chunk Flags | Chunk Length | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// \ \ +// / Re-configuration Parameter / +// \ \ +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// \ \ +// / Re-configuration Parameter (optional) / +// \ \ +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + +type chunkReconfig struct { + chunkHeader + paramA param + paramB param +} + +// Reconfigure chunk errors. +var ( + ErrChunkParseParamTypeFailed = errors.New("failed to parse param type") + ErrChunkMarshalParamAReconfigFailed = errors.New("unable to marshal parameter A for reconfig") + ErrChunkMarshalParamBReconfigFailed = errors.New("unable to marshal parameter B for reconfig") +) + +func (c *chunkReconfig) unmarshal(raw []byte) error { + if err := c.chunkHeader.unmarshal(raw); err != nil { + return err + } + pType, err := parseParamType(c.raw) + if err != nil { + return fmt.Errorf("%w: %v", ErrChunkParseParamTypeFailed, err) //nolint:errorlint + } + a, err := buildParam(pType, c.raw) + if err != nil { + return err + } + c.paramA = a + + padding := getPadding(a.length()) + offset := a.length() + padding + if len(c.raw) > offset { + pType, err := parseParamType(c.raw[offset:]) + if err != nil { + return fmt.Errorf("%w: %v", ErrChunkParseParamTypeFailed, err) //nolint:errorlint + } + b, err := buildParam(pType, c.raw[offset:]) + if err != nil { + return err + } + c.paramB = b + } + + return nil +} + +func (c *chunkReconfig) marshal() ([]byte, error) { + out, err := c.paramA.marshal() + if err != nil { + return nil, fmt.Errorf("%w: %v", ErrChunkMarshalParamAReconfigFailed, err) //nolint:errorlint + } + if c.paramB != nil { + // Pad param A + out = padByte(out, getPadding(len(out))) + + outB, err := c.paramB.marshal() + if err != nil { + return nil, fmt.Errorf("%w: %v", ErrChunkMarshalParamBReconfigFailed, err) //nolint:errorlint + } + + out = append(out, outB...) + } + + c.typ = ctReconfig + c.raw = out + + return c.chunkHeader.marshal() +} + +func (c *chunkReconfig) check() (abort bool, err error) { + // nolint:godox + // TODO: check allowed combinations: + // https://tools.ietf.org/html/rfc6525#section-3.1 + return true, nil +} + +// String makes chunkReconfig printable. +func (c *chunkReconfig) String() string { + res := fmt.Sprintf("Param A:\n %s", c.paramA) + if c.paramB != nil { + res += fmt.Sprintf("Param B:\n %s", c.paramB) + } + + return res +} diff --git a/vendor/github.com/pion/sctp/chunk_selective_ack.go b/vendor/github.com/pion/sctp/chunk_selective_ack.go new file mode 100644 index 0000000..4111341 --- /dev/null +++ b/vendor/github.com/pion/sctp/chunk_selective_ack.go @@ -0,0 +1,152 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sctp + +import ( + "encoding/binary" + "errors" + "fmt" +) + +/* +chunkSelectiveAck represents an SCTP Chunk of type SACK + +This chunk is sent to the peer endpoint to acknowledge received DATA +chunks and to inform the peer endpoint of gaps in the received +subsequences of DATA chunks as represented by their TSNs. +0 1 2 3 +0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Type = 3 |Chunk Flags | Chunk Length | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Cumulative TSN Ack | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Advertised Receiver Window Credit (a_rwnd) | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Number of Gap Ack Blocks = N | Number of Duplicate TSNs = X | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Gap Ack Block #1 Start | Gap Ack Block #1 End | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +/ / +\ ... \ +/ / ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Gap Ack Block #N Start | Gap Ack Block #N End | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Duplicate TSN 1 | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +/ / +\ ... \ +/ / ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Duplicate TSN X | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +*/ + +type gapAckBlock struct { + start uint16 + end uint16 +} + +// Selective ack chunk errors. +var ( + ErrChunkTypeNotSack = errors.New("ChunkType is not of type SACK") + ErrSackSizeNotLargeEnoughInfo = errors.New("SACK Chunk size is not large enough to contain header") + ErrSackSizeNotMatchPredicted = errors.New("SACK Chunk size does not match predicted amount from header values") +) + +// String makes gapAckBlock printable. +func (g gapAckBlock) String() string { + return fmt.Sprintf("%d - %d", g.start, g.end) +} + +type chunkSelectiveAck struct { + chunkHeader + cumulativeTSNAck uint32 + advertisedReceiverWindowCredit uint32 + gapAckBlocks []gapAckBlock + duplicateTSN []uint32 +} + +const ( + selectiveAckHeaderSize = 12 +) + +func (s *chunkSelectiveAck) unmarshal(raw []byte) error { + if err := s.chunkHeader.unmarshal(raw); err != nil { + return err + } + + if s.typ != ctSack { + return fmt.Errorf("%w: actually is %s", ErrChunkTypeNotSack, s.typ.String()) + } + + if len(s.raw) < selectiveAckHeaderSize { + return fmt.Errorf("%w: %v remaining, needs %v bytes", ErrSackSizeNotLargeEnoughInfo, + len(s.raw), selectiveAckHeaderSize) + } + + s.cumulativeTSNAck = binary.BigEndian.Uint32(s.raw[0:]) + s.advertisedReceiverWindowCredit = binary.BigEndian.Uint32(s.raw[4:]) + s.gapAckBlocks = make([]gapAckBlock, binary.BigEndian.Uint16(s.raw[8:])) + s.duplicateTSN = make([]uint32, binary.BigEndian.Uint16(s.raw[10:])) + + if len(s.raw) != selectiveAckHeaderSize+(4*len(s.gapAckBlocks)+(4*len(s.duplicateTSN))) { + return ErrSackSizeNotMatchPredicted + } + + offset := selectiveAckHeaderSize + for i := range s.gapAckBlocks { + s.gapAckBlocks[i].start = binary.BigEndian.Uint16(s.raw[offset:]) + s.gapAckBlocks[i].end = binary.BigEndian.Uint16(s.raw[offset+2:]) + offset += 4 + } + for i := range s.duplicateTSN { + s.duplicateTSN[i] = binary.BigEndian.Uint32(s.raw[offset:]) + offset += 4 + } + + return nil +} + +func (s *chunkSelectiveAck) marshal() ([]byte, error) { + sackRaw := make([]byte, selectiveAckHeaderSize+(4*len(s.gapAckBlocks)+(4*len(s.duplicateTSN)))) + binary.BigEndian.PutUint32(sackRaw[0:], s.cumulativeTSNAck) + binary.BigEndian.PutUint32(sackRaw[4:], s.advertisedReceiverWindowCredit) + binary.BigEndian.PutUint16(sackRaw[8:], uint16(len(s.gapAckBlocks))) //nolint:gosec // G115 + binary.BigEndian.PutUint16(sackRaw[10:], uint16(len(s.duplicateTSN))) //nolint:gosec // G115 + offset := selectiveAckHeaderSize + for _, g := range s.gapAckBlocks { + binary.BigEndian.PutUint16(sackRaw[offset:], g.start) + binary.BigEndian.PutUint16(sackRaw[offset+2:], g.end) + offset += 4 + } + for _, t := range s.duplicateTSN { + binary.BigEndian.PutUint32(sackRaw[offset:], t) + offset += 4 + } + + s.chunkHeader.typ = ctSack + s.chunkHeader.raw = sackRaw + + return s.chunkHeader.marshal() +} + +func (s *chunkSelectiveAck) check() (abort bool, err error) { + return false, nil +} + +// String makes chunkSelectiveAck printable. +func (s *chunkSelectiveAck) String() string { + res := fmt.Sprintf("SACK cumTsnAck=%d arwnd=%d dupTsn=%d", + s.cumulativeTSNAck, + s.advertisedReceiverWindowCredit, + s.duplicateTSN) + + for _, gap := range s.gapAckBlocks { + res = fmt.Sprintf("%s\n gap ack: %s", res, gap) + } + + return res +} diff --git a/vendor/github.com/pion/sctp/chunk_shutdown.go b/vendor/github.com/pion/sctp/chunk_shutdown.go new file mode 100644 index 0000000..dba680a --- /dev/null +++ b/vendor/github.com/pion/sctp/chunk_shutdown.go @@ -0,0 +1,73 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sctp + +import ( + "encoding/binary" + "errors" + "fmt" +) + +/* +chunkShutdown represents an SCTP Chunk of type chunkShutdown + +0 1 2 3 +0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Type = 7 | Chunk Flags | Length = 8 | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Cumulative TSN Ack | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+. +*/ +type chunkShutdown struct { + chunkHeader + cumulativeTSNAck uint32 +} + +const ( + cumulativeTSNAckLength = 4 +) + +// Shutdown chunk errors. +var ( + ErrInvalidChunkSize = errors.New("invalid chunk size") + ErrChunkTypeNotShutdown = errors.New("ChunkType is not of type SHUTDOWN") +) + +func (c *chunkShutdown) unmarshal(raw []byte) error { + if err := c.chunkHeader.unmarshal(raw); err != nil { + return err + } + + if c.typ != ctShutdown { + return fmt.Errorf("%w: actually is %s", ErrChunkTypeNotShutdown, c.typ.String()) + } + + if len(c.raw) != cumulativeTSNAckLength { + return ErrInvalidChunkSize + } + + c.cumulativeTSNAck = binary.BigEndian.Uint32(c.raw[0:]) + + return nil +} + +func (c *chunkShutdown) marshal() ([]byte, error) { + out := make([]byte, cumulativeTSNAckLength) + binary.BigEndian.PutUint32(out[0:], c.cumulativeTSNAck) + + c.typ = ctShutdown + c.raw = out + + return c.chunkHeader.marshal() +} + +func (c *chunkShutdown) check() (abort bool, err error) { + return false, nil +} + +// String makes chunkShutdown printable. +func (c *chunkShutdown) String() string { + return c.chunkHeader.String() +} diff --git a/vendor/github.com/pion/sctp/chunk_shutdown_ack.go b/vendor/github.com/pion/sctp/chunk_shutdown_ack.go new file mode 100644 index 0000000..3387f2b --- /dev/null +++ b/vendor/github.com/pion/sctp/chunk_shutdown_ack.go @@ -0,0 +1,54 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sctp + +import ( + "errors" + "fmt" +) + +/* +chunkShutdownAck represents an SCTP Chunk of type chunkShutdownAck + +0 1 2 3 +0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Type = 8 | Chunk Flags | Length = 4 | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+. +*/ +type chunkShutdownAck struct { + chunkHeader +} + +// Shutdown ack chunk errors. +var ( + ErrChunkTypeNotShutdownAck = errors.New("ChunkType is not of type SHUTDOWN-ACK") +) + +func (c *chunkShutdownAck) unmarshal(raw []byte) error { + if err := c.chunkHeader.unmarshal(raw); err != nil { + return err + } + + if c.typ != ctShutdownAck { + return fmt.Errorf("%w: actually is %s", ErrChunkTypeNotShutdownAck, c.typ.String()) + } + + return nil +} + +func (c *chunkShutdownAck) marshal() ([]byte, error) { + c.typ = ctShutdownAck + + return c.chunkHeader.marshal() +} + +func (c *chunkShutdownAck) check() (abort bool, err error) { + return false, nil +} + +// String makes chunkShutdownAck printable. +func (c *chunkShutdownAck) String() string { + return c.chunkHeader.String() +} diff --git a/vendor/github.com/pion/sctp/chunk_shutdown_complete.go b/vendor/github.com/pion/sctp/chunk_shutdown_complete.go new file mode 100644 index 0000000..e589aa2 --- /dev/null +++ b/vendor/github.com/pion/sctp/chunk_shutdown_complete.go @@ -0,0 +1,54 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sctp + +import ( + "errors" + "fmt" +) + +/* +chunkShutdownComplete represents an SCTP Chunk of type chunkShutdownComplete + +0 1 2 3 +0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Type = 14 |Reserved |T| Length = 4 | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+. +*/ +type chunkShutdownComplete struct { + chunkHeader +} + +// Shutdown complete chunk errors. +var ( + ErrChunkTypeNotShutdownComplete = errors.New("ChunkType is not of type SHUTDOWN-COMPLETE") +) + +func (c *chunkShutdownComplete) unmarshal(raw []byte) error { + if err := c.chunkHeader.unmarshal(raw); err != nil { + return err + } + + if c.typ != ctShutdownComplete { + return fmt.Errorf("%w: actually is %s", ErrChunkTypeNotShutdownComplete, c.typ.String()) + } + + return nil +} + +func (c *chunkShutdownComplete) marshal() ([]byte, error) { + c.typ = ctShutdownComplete + + return c.chunkHeader.marshal() +} + +func (c *chunkShutdownComplete) check() (abort bool, err error) { + return false, nil +} + +// String makes chunkShutdownComplete printable. +func (c *chunkShutdownComplete) String() string { + return c.chunkHeader.String() +} diff --git a/vendor/github.com/pion/sctp/chunkheader.go b/vendor/github.com/pion/sctp/chunkheader.go new file mode 100644 index 0000000..9f3c91b --- /dev/null +++ b/vendor/github.com/pion/sctp/chunkheader.go @@ -0,0 +1,105 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sctp + +import ( + "encoding/binary" + "errors" + "fmt" +) + +/* +chunkHeader represents a SCTP Chunk header, defined in https://tools.ietf.org/html/rfc4960#section-3.2 +The figure below illustrates the field format for the chunks to be +transmitted in the SCTP packet. Each chunk is formatted with a Chunk +Type field, a chunk-specific Flag field, a Chunk Length field, and a +Value field. + + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Chunk Type | Chunk Flags | Chunk Length | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | | + | Chunk Value | + | | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +*/ +type chunkHeader struct { + typ chunkType + flags byte + raw []byte +} + +const ( + chunkHeaderSize = 4 +) + +// SCTP chunk header errors. +var ( + ErrChunkHeaderTooSmall = errors.New("raw is too small for a SCTP chunk") + ErrChunkHeaderNotEnoughSpace = errors.New("not enough data left in SCTP packet to satisfy requested length") + ErrChunkHeaderPaddingNonZero = errors.New("chunk padding is non-zero at offset") +) + +func (c *chunkHeader) unmarshal(raw []byte) error { + if len(raw) < chunkHeaderSize { + return fmt.Errorf( + "%w: raw only %d bytes, %d is the minimum length", + ErrChunkHeaderTooSmall, len(raw), chunkHeaderSize, + ) + } + + c.typ = chunkType(raw[0]) + c.flags = raw[1] + length := binary.BigEndian.Uint16(raw[2:]) + + // Length includes Chunk header + valueLength := int(length - chunkHeaderSize) + lengthAfterValue := len(raw) - (chunkHeaderSize + valueLength) + + if lengthAfterValue < 0 { + return fmt.Errorf("%w: remain %d req %d ", ErrChunkHeaderNotEnoughSpace, valueLength, len(raw)-chunkHeaderSize) + } else if lengthAfterValue < 4 { + // https://tools.ietf.org/html/rfc4960#section-3.2 + // The Chunk Length field does not count any chunk padding. + // Chunks (including Type, Length, and Value fields) are padded out + // by the sender with all zero bytes to be a multiple of 4 bytes + // long. This padding MUST NOT be more than 3 bytes in total. The + // Chunk Length value does not include terminating padding of the + // chunk. However, it does include padding of any variable-length + // parameter except the last parameter in the chunk. The receiver + // MUST ignore the padding. + for i := lengthAfterValue; i > 0; i-- { + paddingOffset := chunkHeaderSize + valueLength + (i - 1) + if raw[paddingOffset] != 0 { + return fmt.Errorf("%w: %d ", ErrChunkHeaderPaddingNonZero, paddingOffset) + } + } + } + + c.raw = raw[chunkHeaderSize : chunkHeaderSize+valueLength] + + return nil +} + +func (c *chunkHeader) marshal() ([]byte, error) { + raw := make([]byte, 4+len(c.raw)) + + raw[0] = uint8(c.typ) + raw[1] = c.flags + binary.BigEndian.PutUint16(raw[2:], uint16(len(c.raw)+chunkHeaderSize)) //nolint:gosec // G115 + copy(raw[4:], c.raw) + + return raw, nil +} + +func (c *chunkHeader) valueLength() int { + return len(c.raw) +} + +// String makes chunkHeader printable. +func (c chunkHeader) String() string { + return c.typ.String() +} diff --git a/vendor/github.com/pion/sctp/chunktype.go b/vendor/github.com/pion/sctp/chunktype.go new file mode 100644 index 0000000..3fe0aa3 --- /dev/null +++ b/vendor/github.com/pion/sctp/chunktype.go @@ -0,0 +1,70 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sctp + +import "fmt" + +// chunkType is an enum for SCTP Chunk Type field +// This field identifies the type of information contained in the +// Chunk Value field. +type chunkType uint8 + +// List of known chunkType enums. +const ( + ctPayloadData chunkType = 0 + ctInit chunkType = 1 + ctInitAck chunkType = 2 + ctSack chunkType = 3 + ctHeartbeat chunkType = 4 + ctHeartbeatAck chunkType = 5 + ctAbort chunkType = 6 + ctShutdown chunkType = 7 + ctShutdownAck chunkType = 8 + ctError chunkType = 9 + ctCookieEcho chunkType = 10 + ctCookieAck chunkType = 11 + ctCWR chunkType = 13 + ctShutdownComplete chunkType = 14 + ctReconfig chunkType = 130 + ctForwardTSN chunkType = 192 +) + +func (c chunkType) String() string { //nolint:cyclop + switch c { + case ctPayloadData: + return "DATA" + case ctInit: + return "INIT" + case ctInitAck: + return "INIT-ACK" + case ctSack: + return "SACK" + case ctHeartbeat: + return "HEARTBEAT" + case ctHeartbeatAck: + return "HEARTBEAT-ACK" + case ctAbort: + return "ABORT" + case ctShutdown: + return "SHUTDOWN" + case ctShutdownAck: + return "SHUTDOWN-ACK" + case ctError: + return "ERROR" + case ctCookieEcho: + return "COOKIE-ECHO" + case ctCookieAck: + return "COOKIE-ACK" + case ctCWR: + return "ECNE" // Explicit Congestion Notification Echo + case ctShutdownComplete: + return "SHUTDOWN-COMPLETE" + case ctReconfig: + return "RECONFIG" // Re-configuration + case ctForwardTSN: + return "FORWARD-TSN" + default: + return fmt.Sprintf("Unknown ChunkType: %d", c) + } +} diff --git a/vendor/github.com/pion/sctp/codecov.yml b/vendor/github.com/pion/sctp/codecov.yml new file mode 100644 index 0000000..b9639c2 --- /dev/null +++ b/vendor/github.com/pion/sctp/codecov.yml @@ -0,0 +1,22 @@ +# +# DO NOT EDIT THIS FILE +# +# It is automatically copied from https://github.com/pion/.goassets repository. +# +# SPDX-FileCopyrightText: 2026 The Pion community +# SPDX-License-Identifier: MIT + +coverage: + status: + project: + default: + # Allow decreasing 2% of total coverage to avoid noise. + threshold: 2% + patch: + default: + target: 70% + only_pulls: true + +ignore: + - "examples/*" + - "examples/**/*" diff --git a/vendor/github.com/pion/sctp/control_queue.go b/vendor/github.com/pion/sctp/control_queue.go new file mode 100644 index 0000000..e377bb1 --- /dev/null +++ b/vendor/github.com/pion/sctp/control_queue.go @@ -0,0 +1,33 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sctp + +// control queue + +type controlQueue struct { + queue []*packet +} + +func newControlQueue() *controlQueue { + return &controlQueue{queue: []*packet{}} +} + +func (q *controlQueue) push(c *packet) { + q.queue = append(q.queue, c) +} + +func (q *controlQueue) pushAll(packets []*packet) { + q.queue = append(q.queue, packets...) +} + +func (q *controlQueue) popAll() []*packet { + packets := q.queue + q.queue = []*packet{} + + return packets +} + +func (q *controlQueue) size() int { + return len(q.queue) +} diff --git a/vendor/github.com/pion/sctp/error_cause.go b/vendor/github.com/pion/sctp/error_cause.go new file mode 100644 index 0000000..aaca3a1 --- /dev/null +++ b/vendor/github.com/pion/sctp/error_cause.go @@ -0,0 +1,101 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sctp + +import ( + "encoding/binary" + "errors" + "fmt" +) + +// errorCauseCode is a cause code that appears in either a ERROR or ABORT chunk. +type errorCauseCode uint16 + +type errorCause interface { + unmarshal([]byte) error + marshal() ([]byte, error) + length() uint16 + String() string + + errorCauseCode() errorCauseCode +} + +// Error and abort chunk errors. +var ( + ErrBuildErrorCaseHandle = errors.New("BuildErrorCause does not handle") +) + +// buildErrorCause delegates the building of a error cause from raw bytes to the correct structure. +func buildErrorCause(raw []byte) (errorCause, error) { + var errCause errorCause + + c := errorCauseCode(binary.BigEndian.Uint16(raw[0:])) + switch c { + case invalidMandatoryParameter: + errCause = &errorCauseInvalidMandatoryParameter{} + case unrecognizedChunkType: + errCause = &errorCauseUnrecognizedChunkType{} + case protocolViolation: + errCause = &errorCauseProtocolViolation{} + case userInitiatedAbort: + errCause = &errorCauseUserInitiatedAbort{} + default: + return nil, fmt.Errorf("%w: %s", ErrBuildErrorCaseHandle, c.String()) + } + + if err := errCause.unmarshal(raw); err != nil { + return nil, err + } + + return errCause, nil +} + +const ( + invalidStreamIdentifier errorCauseCode = 1 + missingMandatoryParameter errorCauseCode = 2 + staleCookieError errorCauseCode = 3 + outOfResource errorCauseCode = 4 + unresolvableAddress errorCauseCode = 5 + unrecognizedChunkType errorCauseCode = 6 + invalidMandatoryParameter errorCauseCode = 7 + unrecognizedParameters errorCauseCode = 8 + noUserData errorCauseCode = 9 + cookieReceivedWhileShuttingDown errorCauseCode = 10 + restartOfAnAssociationWithNewAddresses errorCauseCode = 11 + userInitiatedAbort errorCauseCode = 12 + protocolViolation errorCauseCode = 13 +) + +func (e errorCauseCode) String() string { //nolint:cyclop + switch e { + case invalidStreamIdentifier: + return "Invalid Stream Identifier" + case missingMandatoryParameter: + return "Missing Mandatory Parameter" + case staleCookieError: + return "Stale Cookie Error" + case outOfResource: + return "Out Of Resource" + case unresolvableAddress: + return "Unresolvable IP" + case unrecognizedChunkType: + return "Unrecognized Chunk Type" + case invalidMandatoryParameter: + return "Invalid Mandatory Parameter" + case unrecognizedParameters: + return "Unrecognized Parameters" + case noUserData: + return "No User Data" + case cookieReceivedWhileShuttingDown: + return "Cookie Received While Shutting Down" + case restartOfAnAssociationWithNewAddresses: + return "Restart Of An Association With New Addresses" + case userInitiatedAbort: + return "User Initiated Abort" + case protocolViolation: + return "Protocol Violation" + default: + return fmt.Sprintf("Unknown CauseCode: %d", e) + } +} diff --git a/vendor/github.com/pion/sctp/error_cause_header.go b/vendor/github.com/pion/sctp/error_cause_header.go new file mode 100644 index 0000000..f8c556f --- /dev/null +++ b/vendor/github.com/pion/sctp/error_cause_header.go @@ -0,0 +1,58 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sctp + +import ( + "encoding/binary" + "errors" +) + +// errorCauseHeader represents the shared header that is shared by all error causes. +type errorCauseHeader struct { + code errorCauseCode + len uint16 + raw []byte +} + +const ( + errorCauseHeaderLength = 4 +) + +// ErrInvalidSCTPChunk is returned when an SCTP chunk is invalid. +var ErrInvalidSCTPChunk = errors.New("invalid SCTP chunk") + +func (e *errorCauseHeader) marshal() ([]byte, error) { + e.len = uint16(len(e.raw)) + uint16(errorCauseHeaderLength) //nolint:gosec // G115 + raw := make([]byte, e.len) + binary.BigEndian.PutUint16(raw[0:], uint16(e.code)) + binary.BigEndian.PutUint16(raw[2:], e.len) + copy(raw[errorCauseHeaderLength:], e.raw) + + return raw, nil +} + +func (e *errorCauseHeader) unmarshal(raw []byte) error { + e.code = errorCauseCode(binary.BigEndian.Uint16(raw[0:])) + e.len = binary.BigEndian.Uint16(raw[2:]) + if e.len < errorCauseHeaderLength || int(e.len) > len(raw) { + return ErrInvalidSCTPChunk + } + valueLength := e.len - errorCauseHeaderLength + e.raw = raw[errorCauseHeaderLength : errorCauseHeaderLength+valueLength] + + return nil +} + +func (e *errorCauseHeader) length() uint16 { + return e.len +} + +func (e *errorCauseHeader) errorCauseCode() errorCauseCode { + return e.code +} + +// String makes errorCauseHeader printable. +func (e errorCauseHeader) String() string { + return e.code.String() +} diff --git a/vendor/github.com/pion/sctp/error_cause_invalid_mandatory_parameter.go b/vendor/github.com/pion/sctp/error_cause_invalid_mandatory_parameter.go new file mode 100644 index 0000000..b46bb4e --- /dev/null +++ b/vendor/github.com/pion/sctp/error_cause_invalid_mandatory_parameter.go @@ -0,0 +1,22 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sctp + +// errorCauseInvalidMandatoryParameter represents an SCTP error cause. +type errorCauseInvalidMandatoryParameter struct { + errorCauseHeader +} + +func (e *errorCauseInvalidMandatoryParameter) marshal() ([]byte, error) { + return e.errorCauseHeader.marshal() +} + +func (e *errorCauseInvalidMandatoryParameter) unmarshal(raw []byte) error { + return e.errorCauseHeader.unmarshal(raw) +} + +// String makes errorCauseInvalidMandatoryParameter printable. +func (e *errorCauseInvalidMandatoryParameter) String() string { + return e.errorCauseHeader.String() +} diff --git a/vendor/github.com/pion/sctp/error_cause_protocol_violation.go b/vendor/github.com/pion/sctp/error_cause_protocol_violation.go new file mode 100644 index 0000000..d625fdd --- /dev/null +++ b/vendor/github.com/pion/sctp/error_cause_protocol_violation.go @@ -0,0 +1,58 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sctp + +import ( + "errors" + "fmt" +) + +/* +This error cause MAY be included in ABORT chunks that are sent +because an SCTP endpoint detects a protocol violation of the peer +that is not covered by the error causes described in Section 3.3.10.1 +to Section 3.3.10.12. An implementation MAY provide additional +information specifying what kind of protocol violation has been +detected. + + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Cause Code=13 | Cause Length=Variable | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + / Additional Information / + \ \ + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +*/ +type errorCauseProtocolViolation struct { + errorCauseHeader + additionalInformation []byte +} + +// Abort chunk errors. +var ( + ErrProtocolViolationUnmarshal = errors.New("unable to unmarshal Protocol Violation error") +) + +func (e *errorCauseProtocolViolation) marshal() ([]byte, error) { + e.raw = e.additionalInformation + + return e.errorCauseHeader.marshal() +} + +func (e *errorCauseProtocolViolation) unmarshal(raw []byte) error { + err := e.errorCauseHeader.unmarshal(raw) + if err != nil { + return fmt.Errorf("%w: %v", ErrProtocolViolationUnmarshal, err) //nolint:errorlint + } + + e.additionalInformation = e.raw + + return nil +} + +// String makes errorCauseProtocolViolation printable. +func (e *errorCauseProtocolViolation) String() string { + return fmt.Sprintf("%s: %s", e.errorCauseHeader, e.additionalInformation) +} diff --git a/vendor/github.com/pion/sctp/error_cause_unrecognized_chunk_type.go b/vendor/github.com/pion/sctp/error_cause_unrecognized_chunk_type.go new file mode 100644 index 0000000..d641e50 --- /dev/null +++ b/vendor/github.com/pion/sctp/error_cause_unrecognized_chunk_type.go @@ -0,0 +1,33 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sctp + +// errorCauseUnrecognizedChunkType represents an SCTP error cause. +type errorCauseUnrecognizedChunkType struct { + errorCauseHeader + unrecognizedChunk []byte +} + +func (e *errorCauseUnrecognizedChunkType) marshal() ([]byte, error) { + e.code = unrecognizedChunkType + e.errorCauseHeader.raw = e.unrecognizedChunk + + return e.errorCauseHeader.marshal() +} + +func (e *errorCauseUnrecognizedChunkType) unmarshal(raw []byte) error { + err := e.errorCauseHeader.unmarshal(raw) + if err != nil { + return err + } + + e.unrecognizedChunk = e.errorCauseHeader.raw + + return nil +} + +// String makes errorCauseUnrecognizedChunkType printable. +func (e *errorCauseUnrecognizedChunkType) String() string { + return e.errorCauseHeader.String() +} diff --git a/vendor/github.com/pion/sctp/error_cause_user_initiated_abort.go b/vendor/github.com/pion/sctp/error_cause_user_initiated_abort.go new file mode 100644 index 0000000..5efbb9e --- /dev/null +++ b/vendor/github.com/pion/sctp/error_cause_user_initiated_abort.go @@ -0,0 +1,51 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sctp + +import ( + "fmt" +) + +/* +This error cause MAY be included in ABORT chunks that are sent +because of an upper-layer request. The upper layer can specify an +Upper Layer Abort Reason that is transported by SCTP transparently +and MAY be delivered to the upper-layer protocol at the peer. + + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Cause Code=12 | Cause Length=Variable | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + / Upper Layer Abort Reason / + \ \ + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +*/ +type errorCauseUserInitiatedAbort struct { + errorCauseHeader + upperLayerAbortReason []byte +} + +func (e *errorCauseUserInitiatedAbort) marshal() ([]byte, error) { + e.code = userInitiatedAbort + e.errorCauseHeader.raw = e.upperLayerAbortReason + + return e.errorCauseHeader.marshal() +} + +func (e *errorCauseUserInitiatedAbort) unmarshal(raw []byte) error { + err := e.errorCauseHeader.unmarshal(raw) + if err != nil { + return err + } + + e.upperLayerAbortReason = e.errorCauseHeader.raw + + return nil +} + +// String makes errorCauseUserInitiatedAbort printable. +func (e *errorCauseUserInitiatedAbort) String() string { + return fmt.Sprintf("%s: %s", e.errorCauseHeader.String(), e.upperLayerAbortReason) +} diff --git a/vendor/github.com/pion/sctp/errors.go b/vendor/github.com/pion/sctp/errors.go new file mode 100644 index 0000000..dd2ab1e --- /dev/null +++ b/vendor/github.com/pion/sctp/errors.go @@ -0,0 +1,39 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sctp + +import ( + "errors" +) + +var ( + errNilNetConn = errors.New("netConn must not be nil") + errNilLoggerFactory = errors.New("loggerFactory must not be nil") + + // errZeroMTUOption indicates that the MTU option was set to zero. + errZeroMTUOption = errors.New("MTU option cannot be set to zero") + + // errZeroMaxReceiveBufferOption indicates that the MTU option was set to zero. + errZeroMaxReceiveBufferOption = errors.New("MaxReceiveBuffer option cannot be set to zero") + + // errZeroMaxMessageSize indicates that the MTU option was set to zero. + errZeroMaxMessageSize = errors.New("MaxMessageSize option cannot be set to zero") + + // errInvalidRTOMax indicates that the RTO max was set to 0 or a negative value. + errInvalidRTOMax = errors.New("RTO max was set to <= 0") + + // errInvalidRackMinRTTWnd indicates the length of the local minimum window used to determine the + // minRTT was set to <= 0. + errInvalidRackMinRTTWnd = errors.New("RackMinRTT was set to <= 0") + + // errInvalidRackReoWndFloor indicates the length of the RACK reordering window floor was set to < 0. + errInvalidRackReoWndFloor = errors.New("RackReoWndFloor was set to < 0") + + // errInvalidRackWcDelAck indicates the receiver worst-case delayed-ACK for PTO when only 1 packet in flight + // was set to < 0. + errInvalidRackWcDelAck = errors.New("RackWcDelAck was set to <= 0") + + // errInvalidSnapToken indicates a SNAP token that is not parseable. + errInvalidSnapToken = errors.New("SNAP token is invalid") +) diff --git a/vendor/github.com/pion/sctp/packet.go b/vendor/github.com/pion/sctp/packet.go new file mode 100644 index 0000000..31dec63 --- /dev/null +++ b/vendor/github.com/pion/sctp/packet.go @@ -0,0 +1,249 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sctp + +import ( + "encoding/binary" + "errors" + "fmt" + "hash/crc32" + "strings" +) + +// Create the crc32 table we'll use for the checksum. +var castagnoliTable = crc32.MakeTable(crc32.Castagnoli) // nolint:gochecknoglobals + +// Allocate and zero this data once. +// We need to use it for the checksum and don't want to allocate/clear each time. +var fourZeroes [4]byte // nolint:gochecknoglobals + +/* +Packet represents an SCTP packet, defined in https://tools.ietf.org/html/rfc4960#section-3 +An SCTP packet is composed of a common header and chunks. A chunk +contains either control information or user data. + + SCTP Packet Format + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Common Header | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Chunk #1 | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | ... | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Chunk #n | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + + SCTP Common Header Format + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Source Value Number | Destination Value Number | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Verification Tag | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Checksum | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +*/ +type packet struct { + sourcePort uint16 + destinationPort uint16 + verificationTag uint32 + chunks []chunk +} + +const ( + packetHeaderSize = 12 +) + +// SCTP packet errors. +var ( + ErrPacketRawTooSmall = errors.New("raw is smaller than the minimum length for a SCTP packet") + ErrParseSCTPChunkNotEnoughData = errors.New("unable to parse SCTP chunk, not enough data for complete header") + ErrUnmarshalUnknownChunkType = errors.New("failed to unmarshal, contains unknown chunk type") + ErrChecksumMismatch = errors.New("checksum mismatch theirs") +) + +func (p *packet) unmarshal(doChecksum bool, raw []byte) error { //nolint:cyclop + if len(raw) < packetHeaderSize { + return fmt.Errorf("%w: raw only %d bytes, %d is the minimum length", ErrPacketRawTooSmall, len(raw), packetHeaderSize) + } + + offset := packetHeaderSize + + // Check if doing CRC32c is required. + // Without having SCTP AUTH implemented, this depends only on the type + // og the first chunk. + if offset+chunkHeaderSize <= len(raw) { + switch chunkType(raw[offset]) { + case ctInit, ctCookieEcho: + doChecksum = true + default: + } + } + theirChecksum := binary.LittleEndian.Uint32(raw[8:]) + if theirChecksum != 0 || doChecksum { + ourChecksum := generatePacketChecksum(raw) + if theirChecksum != ourChecksum { + return fmt.Errorf("%w: %d ours: %d", ErrChecksumMismatch, theirChecksum, ourChecksum) + } + } + + p.sourcePort = binary.BigEndian.Uint16(raw[0:]) + p.destinationPort = binary.BigEndian.Uint16(raw[2:]) + p.verificationTag = binary.BigEndian.Uint32(raw[4:]) + + for offset < len(raw) { + // guaranteed to be safe by loop condition + remaining := raw[offset:] // nolint:gosec + + // must have at least a full chunk header to continue. + if len(remaining) < chunkHeaderSize { + return fmt.Errorf("%w: offset %d remaining %d", ErrParseSCTPChunkNotEnoughData, offset, len(remaining)) + } + + ctype := chunkType(remaining[0]) + + var dataChunk chunk + switch ctype { + case ctInit: + dataChunk = &chunkInit{} + case ctInitAck: + dataChunk = &chunkInitAck{} + case ctAbort: + dataChunk = &chunkAbort{} + case ctCookieEcho: + dataChunk = &chunkCookieEcho{} + case ctCookieAck: + dataChunk = &chunkCookieAck{} + case ctHeartbeat: + dataChunk = &chunkHeartbeat{} + case ctPayloadData: + dataChunk = &chunkPayloadData{} + case ctSack: + dataChunk = &chunkSelectiveAck{} + case ctReconfig: + dataChunk = &chunkReconfig{} + case ctForwardTSN: + dataChunk = &chunkForwardTSN{} + case ctError: + dataChunk = &chunkError{} + case ctShutdown: + dataChunk = &chunkShutdown{} + case ctShutdownAck: + dataChunk = &chunkShutdownAck{} + case ctShutdownComplete: + dataChunk = &chunkShutdownComplete{} + default: + return fmt.Errorf("%w: %s", ErrUnmarshalUnknownChunkType, ctype.String()) + } + + if err := dataChunk.unmarshal(remaining); err != nil { + return err + } + + p.chunks = append(p.chunks, dataChunk) + chunkValuePadding := getPadding(dataChunk.valueLength()) + offset += chunkHeaderSize + dataChunk.valueLength() + chunkValuePadding + } + + // if we overshot then should error. + if offset != len(raw) { + if offset > len(raw) { + overshoot := offset - len(raw) + + return fmt.Errorf("%w: parsed past end of buffer by %d bytes (offset %d, length %d)", + ErrParseSCTPChunkNotEnoughData, overshoot, offset, len(raw)) + } + + remaining := len(raw) - offset + + return fmt.Errorf("%w: unparsed data remaining: %d bytes (offset %d, length %d)", + ErrParseSCTPChunkNotEnoughData, remaining, offset, len(raw)) + } + + return nil +} + +func (p *packet) marshal(doChecksum bool) ([]byte, error) { + raw := make([]byte, packetHeaderSize) + + // Populate static headers + // 8-12 is Checksum which will be populated when packet is complete + binary.BigEndian.PutUint16(raw[0:], p.sourcePort) + binary.BigEndian.PutUint16(raw[2:], p.destinationPort) + binary.BigEndian.PutUint32(raw[4:], p.verificationTag) + + // Populate chunks + for _, c := range p.chunks { + chunkRaw, err := c.marshal() + if err != nil { + return nil, err + } + raw = append(raw, chunkRaw...) //nolint:makezero // todo:fix + + paddingNeeded := getPadding(len(raw)) + if paddingNeeded != 0 { + raw = append(raw, make([]byte, paddingNeeded)...) //nolint:makezero // todo:fix + } + } + + if doChecksum { + // golang CRC32C uses reflected input and reflected output, the + // net result of this is to have the bytes flipped compared to + // the non reflected variant that the spec expects. + // + // Use LittleEndian.PutUint32 to avoid flipping the bytes in to + // the spec compliant checksum order + binary.LittleEndian.PutUint32(raw[8:], generatePacketChecksum(raw)) + } + + return raw, nil +} + +func generatePacketChecksum(raw []byte) (sum uint32) { + // Fastest way to do a crc32 without allocating. + sum = crc32.Update(sum, castagnoliTable, raw[0:8]) + sum = crc32.Update(sum, castagnoliTable, fourZeroes[:]) + sum = crc32.Update(sum, castagnoliTable, raw[12:]) + + return sum +} + +// String makes packet printable. +func (p *packet) String() string { + format := `Packet: + sourcePort: %d + destinationPort: %d + verificationTag: %d + ` + var res strings.Builder + fmt.Fprintf(&res, format, + p.sourcePort, + p.destinationPort, + p.verificationTag, + ) + for i, chunk := range p.chunks { + fmt.Fprintf(&res, "Chunk %d:\n %s", i, chunk) + } + + return res.String() +} + +// TryMarshalUnmarshal attempts to marshal and unmarshal a message. Added for fuzzing. +func TryMarshalUnmarshal(msg []byte) int { + p := &packet{} + err := p.unmarshal(false, msg) + if err != nil { + return 0 + } + + _, err = p.marshal(false) + if err != nil { + return 0 + } + + return 1 +} diff --git a/vendor/github.com/pion/sctp/param.go b/vendor/github.com/pion/sctp/param.go new file mode 100644 index 0000000..230a466 --- /dev/null +++ b/vendor/github.com/pion/sctp/param.go @@ -0,0 +1,46 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sctp + +import ( + "errors" + "fmt" +) + +type param interface { + marshal() ([]byte, error) + length() int +} + +// ErrParamTypeUnhandled is returned if unknown parameter type is specified. +var ErrParamTypeUnhandled = errors.New("unhandled ParamType") + +func buildParam(typeParam paramType, rawParam []byte) (param, error) { //nolint:cyclop + switch typeParam { + case forwardTSNSupp: + return (¶mForwardTSNSupported{}).unmarshal(rawParam) + case supportedExt: + return (¶mSupportedExtensions{}).unmarshal(rawParam) + case ecnCapable: + return (¶mECNCapable{}).unmarshal(rawParam) + case random: + return (¶mRandom{}).unmarshal(rawParam) + case reqHMACAlgo: + return (¶mRequestedHMACAlgorithm{}).unmarshal(rawParam) + case chunkList: + return (¶mChunkList{}).unmarshal(rawParam) + case stateCookie: + return (¶mStateCookie{}).unmarshal(rawParam) + case heartbeatInfo: + return (¶mHeartbeatInfo{}).unmarshal(rawParam) + case outSSNResetReq: + return (¶mOutgoingResetRequest{}).unmarshal(rawParam) + case reconfigResp: + return (¶mReconfigResponse{}).unmarshal(rawParam) + case zeroChecksumAcceptable: + return (¶mZeroChecksumAcceptable{}).unmarshal(rawParam) + default: + return nil, fmt.Errorf("%w: %v", ErrParamTypeUnhandled, typeParam) + } +} diff --git a/vendor/github.com/pion/sctp/param_chunk_list.go b/vendor/github.com/pion/sctp/param_chunk_list.go new file mode 100644 index 0000000..8f2afe0 --- /dev/null +++ b/vendor/github.com/pion/sctp/param_chunk_list.go @@ -0,0 +1,31 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sctp + +type paramChunkList struct { + paramHeader + chunkTypes []chunkType +} + +func (c *paramChunkList) marshal() ([]byte, error) { + c.typ = chunkList + c.raw = make([]byte, len(c.chunkTypes)) + for i, t := range c.chunkTypes { + c.raw[i] = byte(t) + } + + return c.paramHeader.marshal() +} + +func (c *paramChunkList) unmarshal(raw []byte) (param, error) { + err := c.paramHeader.unmarshal(raw) + if err != nil { + return nil, err + } + for _, t := range c.raw { + c.chunkTypes = append(c.chunkTypes, chunkType(t)) + } + + return c, nil +} diff --git a/vendor/github.com/pion/sctp/param_ecn_capable.go b/vendor/github.com/pion/sctp/param_ecn_capable.go new file mode 100644 index 0000000..689c55e --- /dev/null +++ b/vendor/github.com/pion/sctp/param_ecn_capable.go @@ -0,0 +1,24 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sctp + +type paramECNCapable struct { + paramHeader +} + +func (r *paramECNCapable) marshal() ([]byte, error) { + r.typ = ecnCapable + r.raw = []byte{} + + return r.paramHeader.marshal() +} + +func (r *paramECNCapable) unmarshal(raw []byte) (param, error) { + err := r.paramHeader.unmarshal(raw) + if err != nil { + return nil, err + } + + return r, nil +} diff --git a/vendor/github.com/pion/sctp/param_forward_tsn_supported.go b/vendor/github.com/pion/sctp/param_forward_tsn_supported.go new file mode 100644 index 0000000..276efeb --- /dev/null +++ b/vendor/github.com/pion/sctp/param_forward_tsn_supported.go @@ -0,0 +1,33 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sctp + +// At the initialization of the association, the sender of the INIT or +// INIT ACK chunk MAY include this OPTIONAL parameter to inform its peer +// that it is able to support the Forward TSN chunk +// +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Parameter Type = 49152 | Parameter Length = 4 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + +type paramForwardTSNSupported struct { + paramHeader +} + +func (f *paramForwardTSNSupported) marshal() ([]byte, error) { + f.typ = forwardTSNSupp + f.raw = []byte{} + + return f.paramHeader.marshal() +} + +func (f *paramForwardTSNSupported) unmarshal(raw []byte) (param, error) { + err := f.paramHeader.unmarshal(raw) + if err != nil { + return nil, err + } + + return f, nil +} diff --git a/vendor/github.com/pion/sctp/param_heartbeat_info.go b/vendor/github.com/pion/sctp/param_heartbeat_info.go new file mode 100644 index 0000000..8339bad --- /dev/null +++ b/vendor/github.com/pion/sctp/param_heartbeat_info.go @@ -0,0 +1,26 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sctp + +type paramHeartbeatInfo struct { + paramHeader + heartbeatInformation []byte +} + +func (h *paramHeartbeatInfo) marshal() ([]byte, error) { + h.typ = heartbeatInfo + h.raw = h.heartbeatInformation + + return h.paramHeader.marshal() +} + +func (h *paramHeartbeatInfo) unmarshal(raw []byte) (param, error) { + err := h.paramHeader.unmarshal(raw) + if err != nil { + return nil, err + } + h.heartbeatInformation = h.raw + + return h, nil +} diff --git a/vendor/github.com/pion/sctp/param_outgoing_reset_request.go b/vendor/github.com/pion/sctp/param_outgoing_reset_request.go new file mode 100644 index 0000000..0fb7007 --- /dev/null +++ b/vendor/github.com/pion/sctp/param_outgoing_reset_request.go @@ -0,0 +1,95 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sctp + +import ( + "encoding/binary" + "errors" +) + +const ( + paramOutgoingResetRequestStreamIdentifiersOffset = 12 +) + +// This parameter is used by the sender to request the reset of some or +// all outgoing streams. +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Parameter Type = 13 | Parameter Length = 16 + 2 * N | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Re-configuration Request Sequence Number | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Re-configuration Response Sequence Number | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Sender's Last Assigned TSN | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Stream Number 1 (optional) | Stream Number 2 (optional) | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// / ...... / +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Stream Number N-1 (optional) | Stream Number N (optional) | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + +type paramOutgoingResetRequest struct { + paramHeader + // reconfigRequestSequenceNumber is used to identify the request. It is a monotonically + // increasing number that is initialized to the same value as the + // initial TSN. It is increased by 1 whenever sending a new Re- + // configuration Request Parameter. + reconfigRequestSequenceNumber uint32 + // When this Outgoing SSN Reset Request Parameter is sent in response + // to an Incoming SSN Reset Request Parameter, this parameter is also + // an implicit response to the incoming request. This field then + // holds the Re-configuration Request Sequence Number of the incoming + // request. In other cases, it holds the next expected + // Re-configuration Request Sequence Number minus 1. + reconfigResponseSequenceNumber uint32 + // This value holds the next TSN minus 1 -- in other words, the last + // TSN that this sender assigned. + senderLastTSN uint32 + // This optional field, if included, is used to indicate specific + // streams that are to be reset. If no streams are listed, then all + // streams are to be reset. + streamIdentifiers []uint16 +} + +// Outgoing reset request parameter errors. +var ( + ErrSSNResetRequestParamTooShort = errors.New("outgoing SSN reset request parameter too short") +) + +func (r *paramOutgoingResetRequest) marshal() ([]byte, error) { + r.typ = outSSNResetReq + r.raw = make([]byte, paramOutgoingResetRequestStreamIdentifiersOffset+2*len(r.streamIdentifiers)) + binary.BigEndian.PutUint32(r.raw, r.reconfigRequestSequenceNumber) + binary.BigEndian.PutUint32(r.raw[4:], r.reconfigResponseSequenceNumber) + binary.BigEndian.PutUint32(r.raw[8:], r.senderLastTSN) + for i, sID := range r.streamIdentifiers { + binary.BigEndian.PutUint16(r.raw[paramOutgoingResetRequestStreamIdentifiersOffset+2*i:], sID) + } + + return r.paramHeader.marshal() +} + +func (r *paramOutgoingResetRequest) unmarshal(raw []byte) (param, error) { + err := r.paramHeader.unmarshal(raw) + if err != nil { + return nil, err + } + if len(r.raw) < paramOutgoingResetRequestStreamIdentifiersOffset { + return nil, ErrSSNResetRequestParamTooShort + } + r.reconfigRequestSequenceNumber = binary.BigEndian.Uint32(r.raw) + r.reconfigResponseSequenceNumber = binary.BigEndian.Uint32(r.raw[4:]) + r.senderLastTSN = binary.BigEndian.Uint32(r.raw[8:]) + + lim := (len(r.raw) - paramOutgoingResetRequestStreamIdentifiersOffset) / 2 + r.streamIdentifiers = make([]uint16, lim) + for i := range lim { + r.streamIdentifiers[i] = binary.BigEndian.Uint16(r.raw[paramOutgoingResetRequestStreamIdentifiersOffset+2*i:]) + } + + return r, nil +} diff --git a/vendor/github.com/pion/sctp/param_random.go b/vendor/github.com/pion/sctp/param_random.go new file mode 100644 index 0000000..e873d09 --- /dev/null +++ b/vendor/github.com/pion/sctp/param_random.go @@ -0,0 +1,26 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sctp + +type paramRandom struct { + paramHeader + randomData []byte +} + +func (r *paramRandom) marshal() ([]byte, error) { + r.typ = random + r.raw = r.randomData + + return r.paramHeader.marshal() +} + +func (r *paramRandom) unmarshal(raw []byte) (param, error) { + err := r.paramHeader.unmarshal(raw) + if err != nil { + return nil, err + } + r.randomData = r.raw + + return r, nil +} diff --git a/vendor/github.com/pion/sctp/param_reconfig_response.go b/vendor/github.com/pion/sctp/param_reconfig_response.go new file mode 100644 index 0000000..3740a56 --- /dev/null +++ b/vendor/github.com/pion/sctp/param_reconfig_response.go @@ -0,0 +1,98 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sctp + +import ( + "encoding/binary" + "errors" + "fmt" +) + +// This parameter is used by the receiver of a Re-configuration Request +// Parameter to respond to the request. +// +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Parameter Type = 16 | Parameter Length | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Re-configuration Response Sequence Number | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Result | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Sender's Next TSN (optional) | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Receiver's Next TSN (optional) | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + +type paramReconfigResponse struct { + paramHeader + // This value is copied from the request parameter and is used by the + // receiver of the Re-configuration Response Parameter to tie the + // response to the request. + reconfigResponseSequenceNumber uint32 + // This value describes the result of the processing of the request. + result reconfigResult +} + +type reconfigResult uint32 + +const ( + reconfigResultSuccessNOP reconfigResult = 0 + reconfigResultSuccessPerformed reconfigResult = 1 + reconfigResultDenied reconfigResult = 2 + reconfigResultErrorWrongSSN reconfigResult = 3 + reconfigResultErrorRequestAlreadyInProgress reconfigResult = 4 + reconfigResultErrorBadSequenceNumber reconfigResult = 5 + reconfigResultInProgress reconfigResult = 6 +) + +// Reconfiguration response errors. +var ( + ErrReconfigRespParamTooShort = errors.New("reconfig response parameter too short") +) + +func (t reconfigResult) String() string { + switch t { + case reconfigResultSuccessNOP: + return "0: Success - Nothing to do" + case reconfigResultSuccessPerformed: + return "1: Success - Performed" + case reconfigResultDenied: + return "2: Denied" + case reconfigResultErrorWrongSSN: + return "3: Error - Wrong SSN" + case reconfigResultErrorRequestAlreadyInProgress: + return "4: Error - Request already in progress" + case reconfigResultErrorBadSequenceNumber: + return "5: Error - Bad Sequence Number" + case reconfigResultInProgress: + return "6: In progress" + default: + return fmt.Sprintf("Unknown reconfigResult: %d", t) + } +} + +func (r *paramReconfigResponse) marshal() ([]byte, error) { + r.typ = reconfigResp + r.raw = make([]byte, 8) + binary.BigEndian.PutUint32(r.raw, r.reconfigResponseSequenceNumber) + binary.BigEndian.PutUint32(r.raw[4:], uint32(r.result)) + + return r.paramHeader.marshal() +} + +func (r *paramReconfigResponse) unmarshal(raw []byte) (param, error) { + err := r.paramHeader.unmarshal(raw) + if err != nil { + return nil, err + } + if len(r.raw) < 8 { + return nil, ErrReconfigRespParamTooShort + } + r.reconfigResponseSequenceNumber = binary.BigEndian.Uint32(r.raw) + r.result = reconfigResult(binary.BigEndian.Uint32(r.raw[4:])) + + return r, nil +} diff --git a/vendor/github.com/pion/sctp/param_requested_hmac_algorithm.go b/vendor/github.com/pion/sctp/param_requested_hmac_algorithm.go new file mode 100644 index 0000000..14756f7 --- /dev/null +++ b/vendor/github.com/pion/sctp/param_requested_hmac_algorithm.go @@ -0,0 +1,84 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sctp + +import ( + "encoding/binary" + "errors" + "fmt" +) + +type hmacAlgorithm uint16 + +const ( + hmacResv1 hmacAlgorithm = 0 + hmacSHA128 hmacAlgorithm = 1 + hmacResv2 hmacAlgorithm = 2 + hmacSHA256 hmacAlgorithm = 3 +) + +// ErrInvalidAlgorithmType is returned if unknown auth algorithm is specified. +var ErrInvalidAlgorithmType = errors.New("invalid algorithm type") + +// ErrInvalidChunkLength is returned if the chunk length is invalid. +var ErrInvalidChunkLength = errors.New("invalid chunk length") + +func (c hmacAlgorithm) String() string { + switch c { + case hmacResv1: + return "HMAC Reserved (0x00)" + case hmacSHA128: + return "HMAC SHA-128" + case hmacResv2: + return "HMAC Reserved (0x02)" + case hmacSHA256: + return "HMAC SHA-256" + default: + return fmt.Sprintf("Unknown HMAC Algorithm type: %d", c) + } +} + +type paramRequestedHMACAlgorithm struct { + paramHeader + availableAlgorithms []hmacAlgorithm +} + +func (r *paramRequestedHMACAlgorithm) marshal() ([]byte, error) { + r.typ = reqHMACAlgo + r.raw = make([]byte, len(r.availableAlgorithms)*2) + i := 0 + for _, a := range r.availableAlgorithms { + binary.BigEndian.PutUint16(r.raw[i:], uint16(a)) + i += 2 + } + + return r.paramHeader.marshal() +} + +func (r *paramRequestedHMACAlgorithm) unmarshal(raw []byte) (param, error) { + err := r.paramHeader.unmarshal(raw) + if err != nil { + return nil, err + } + if len(r.raw)%2 == 1 { + return nil, ErrInvalidChunkLength + } + + i := 0 + for i < len(r.raw) { + a := hmacAlgorithm(binary.BigEndian.Uint16(r.raw[i:])) + switch a { + case hmacSHA128: + fallthrough + case hmacSHA256: + r.availableAlgorithms = append(r.availableAlgorithms, a) + default: + return nil, fmt.Errorf("%w: %v", ErrInvalidAlgorithmType, a) + } + + i += 2 + } + + return r, nil +} diff --git a/vendor/github.com/pion/sctp/param_state_cookie.go b/vendor/github.com/pion/sctp/param_state_cookie.go new file mode 100644 index 0000000..b39ba56 --- /dev/null +++ b/vendor/github.com/pion/sctp/param_state_cookie.go @@ -0,0 +1,51 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sctp + +import ( + "crypto/rand" + "fmt" +) + +type paramStateCookie struct { + paramHeader + cookie []byte +} + +func newRandomStateCookie() (*paramStateCookie, error) { + randCookie := make([]byte, 32) + _, err := rand.Read(randCookie) + // crypto/rand.Read returns n == len(b) if and only if err == nil. + if err != nil { + return nil, err + } + + s := ¶mStateCookie{ + cookie: randCookie, + } + + return s, nil +} + +func (s *paramStateCookie) marshal() ([]byte, error) { + s.typ = stateCookie + s.raw = s.cookie + + return s.paramHeader.marshal() +} + +func (s *paramStateCookie) unmarshal(raw []byte) (param, error) { + err := s.paramHeader.unmarshal(raw) + if err != nil { + return nil, err + } + s.cookie = s.raw + + return s, nil +} + +// String makes paramStateCookie printable. +func (s *paramStateCookie) String() string { + return fmt.Sprintf("%s: %s", s.paramHeader, s.cookie) +} diff --git a/vendor/github.com/pion/sctp/param_supported_extensions.go b/vendor/github.com/pion/sctp/param_supported_extensions.go new file mode 100644 index 0000000..88ec9d1 --- /dev/null +++ b/vendor/github.com/pion/sctp/param_supported_extensions.go @@ -0,0 +1,32 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sctp + +type paramSupportedExtensions struct { + paramHeader + ChunkTypes []chunkType +} + +func (s *paramSupportedExtensions) marshal() ([]byte, error) { + s.typ = supportedExt + s.raw = make([]byte, len(s.ChunkTypes)) + for i, c := range s.ChunkTypes { + s.raw[i] = byte(c) + } + + return s.paramHeader.marshal() +} + +func (s *paramSupportedExtensions) unmarshal(raw []byte) (param, error) { + err := s.paramHeader.unmarshal(raw) + if err != nil { + return nil, err + } + + for _, t := range s.raw { + s.ChunkTypes = append(s.ChunkTypes, chunkType(t)) + } + + return s, nil +} diff --git a/vendor/github.com/pion/sctp/param_zero_checksum.go b/vendor/github.com/pion/sctp/param_zero_checksum.go new file mode 100644 index 0000000..b1dd211 --- /dev/null +++ b/vendor/github.com/pion/sctp/param_zero_checksum.go @@ -0,0 +1,58 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sctp + +import ( + "encoding/binary" + "errors" +) + +// This parameter is used to inform the receiver that a sender is willing to +// accept zero as checksum if some other error detection method is used +// instead. +// +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Type = 0x8001 (suggested) | Length = 8 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Error Detection Method Identifier (EDMID) | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + +type paramZeroChecksumAcceptable struct { + paramHeader + // The Error Detection Method Identifier (EDMID) specifies an alternate + // error detection method the sender of this parameter is willing to use for + // received packets. + edmid uint32 +} + +// Zero Checksum parameter error. +var ( + ErrZeroChecksumParamTooShort = errors.New("zero checksum parameter too short") +) + +const ( + dtlsErrorDetectionMethod uint32 = 1 +) + +func (r *paramZeroChecksumAcceptable) marshal() ([]byte, error) { + r.typ = zeroChecksumAcceptable + r.raw = make([]byte, 4) + binary.BigEndian.PutUint32(r.raw, r.edmid) + + return r.paramHeader.marshal() +} + +func (r *paramZeroChecksumAcceptable) unmarshal(raw []byte) (param, error) { + err := r.paramHeader.unmarshal(raw) + if err != nil { + return nil, err + } + if len(r.raw) < 4 { + return nil, ErrZeroChecksumParamTooShort + } + r.edmid = binary.BigEndian.Uint32(r.raw) + + return r, nil +} diff --git a/vendor/github.com/pion/sctp/paramheader.go b/vendor/github.com/pion/sctp/paramheader.go new file mode 100644 index 0000000..b876402 --- /dev/null +++ b/vendor/github.com/pion/sctp/paramheader.go @@ -0,0 +1,107 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sctp + +import ( + "encoding/binary" + "encoding/hex" + "errors" + "fmt" +) + +type paramHeaderUnrecognizedAction byte + +type paramHeader struct { + typ paramType + unrecognizedAction paramHeaderUnrecognizedAction + len int + raw []byte +} + +/* + The Parameter Types are encoded such that the highest-order 2 bits specify + the action that is taken if the processing endpoint does not recognize the + Parameter Type. + + 00 - Stop processing this parameter and do not process any further parameters within this chunk. + + 01 - Stop processing this parameter, do not process any further parameters within this chunk, and + report the unrecognized parameter, as described in Section 3.2.2. + + 10 - Skip this parameter and continue processing. + + 11 - Skip this parameter and continue processing, but report the unrecognized + parameter, as described in Section 3.2.2. + + https://www.rfc-editor.org/rfc/rfc9260.html#section-3.2.1 +*/ + +const ( + paramHeaderUnrecognizedActionMask = 0b11000000 + paramHeaderUnrecognizedActionStop paramHeaderUnrecognizedAction = 0b00000000 + paramHeaderUnrecognizedActionStopAndReport paramHeaderUnrecognizedAction = 0b01000000 + paramHeaderUnrecognizedActionSkip paramHeaderUnrecognizedAction = 0b10000000 + paramHeaderUnrecognizedActionSkipAndReport paramHeaderUnrecognizedAction = 0b11000000 + + paramHeaderLength = 4 +) + +// Parameter header parse errors. +var ( + ErrParamHeaderTooShort = errors.New("param header too short") + ErrParamHeaderSelfReportedLengthShorter = errors.New("param self reported length is shorter than header length") + ErrParamHeaderSelfReportedLengthLonger = errors.New("param self reported length is longer than header length") + ErrParamHeaderParseFailed = errors.New("failed to parse param type") +) + +func (p *paramHeader) marshal() ([]byte, error) { + paramLengthPlusHeader := paramHeaderLength + len(p.raw) + + rawParam := make([]byte, paramLengthPlusHeader) + binary.BigEndian.PutUint16(rawParam[0:], uint16(p.typ)) + binary.BigEndian.PutUint16(rawParam[2:], uint16(paramLengthPlusHeader)) //nolint:gosec // G115 + copy(rawParam[paramHeaderLength:], p.raw) + + return rawParam, nil +} + +func (p *paramHeader) unmarshal(raw []byte) error { + if len(raw) < paramHeaderLength { + return ErrParamHeaderTooShort + } + + paramLengthPlusHeader := binary.BigEndian.Uint16(raw[2:]) + if int(paramLengthPlusHeader) < paramHeaderLength { + return fmt.Errorf( + "%w: param self reported length (%d) shorter than header length (%d)", + ErrParamHeaderSelfReportedLengthShorter, int(paramLengthPlusHeader), paramHeaderLength, + ) + } + if len(raw) < int(paramLengthPlusHeader) { + return fmt.Errorf( + "%w: param length (%d) shorter than its self reported length (%d)", + ErrParamHeaderSelfReportedLengthLonger, len(raw), int(paramLengthPlusHeader), + ) + } + + typ, err := parseParamType(raw[0:]) + if err != nil { + return fmt.Errorf("%w: %v", ErrParamHeaderParseFailed, err) //nolint:errorlint + } + p.typ = typ + p.unrecognizedAction = paramHeaderUnrecognizedAction(raw[0] & paramHeaderUnrecognizedActionMask) + p.raw = raw[paramHeaderLength:paramLengthPlusHeader] + p.len = int(paramLengthPlusHeader) + + return nil +} + +func (p *paramHeader) length() int { + return p.len +} + +// String makes paramHeader printable. +func (p paramHeader) String() string { + return fmt.Sprintf("%s (%d): %s", p.typ, p.len, hex.Dump(p.raw)) +} diff --git a/vendor/github.com/pion/sctp/paramtype.go b/vendor/github.com/pion/sctp/paramtype.go new file mode 100644 index 0000000..8ba0eaa --- /dev/null +++ b/vendor/github.com/pion/sctp/paramtype.go @@ -0,0 +1,120 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sctp + +import ( + "encoding/binary" + "errors" + "fmt" +) + +// paramType represents a SCTP INIT/INITACK parameter. +type paramType uint16 + +const ( + heartbeatInfo paramType = 1 // Heartbeat Info [RFC9260] + ipV4Addr paramType = 5 // IPv4 IP [RFC9260] + ipV6Addr paramType = 6 // IPv6 IP [RFC9260] + stateCookie paramType = 7 // State Cookie [RFC9260] + unrecognizedParam paramType = 8 // Unrecognized Parameters [RFC9260] + cookiePreservative paramType = 9 // Cookie Preservative [RFC9260] + hostNameAddr paramType = 11 // Host Name Address [RFC9260] + supportedAddrTypes paramType = 12 // Supported IP Types [RFC9260] + outSSNResetReq paramType = 13 // Outgoing SSN Reset Request Parameter [RFC6525] + incSSNResetReq paramType = 14 // Incoming SSN Reset Request Parameter [RFC6525] + ssnTSNResetReq paramType = 15 // SSN/TSN Reset Request Parameter [RFC6525] + reconfigResp paramType = 16 // Re-configuration Response Parameter [RFC6525] + addOutStreamsReq paramType = 17 // Add Outgoing Streams Request Parameter [RFC6525] + addIncStreamsReq paramType = 18 // Add Incoming Streams Request Parameter [RFC6525] + ecnCapable paramType = 32768 // ECN Capable (0x8000) [RFC2960] + zeroChecksumAcceptable paramType = 32769 // Zero Checksum Acceptable [draft-ietf-tsvwg-sctp-zero-checksum-00] + random paramType = 32770 // Random (0x8002) [RFC4895] + chunkList paramType = 32771 // Chunk List (0x8003) [RFC4895] + reqHMACAlgo paramType = 32772 // Requested HMAC Algorithm Parameter (0x8004) [RFC4895] + padding paramType = 32773 // Padding (0x8005) + supportedExt paramType = 32776 // Supported Extensions (0x8008) [RFC5061] + forwardTSNSupp paramType = 49152 // Forward TSN supported (0xC000) [RFC3758] + addIPAddr paramType = 49153 // Add IP Address (0xC001) [RFC5061] + delIPAddr paramType = 49154 // Delete IP Address (0xC002) [RFC5061] + errClauseInd paramType = 49155 // Error Cause Indication (0xC003) [RFC5061] + setPriAddr paramType = 49156 // Set Primary IP (0xC004) [RFC5061] + successInd paramType = 49157 // Success Indication (0xC005) [RFC5061] + adaptLayerInd paramType = 49158 // Adaptation Layer Indication (0xC006) [RFC5061] +) + +// Parameter packet errors. +var ( + ErrParamPacketTooShort = errors.New("packet too short") +) + +func parseParamType(raw []byte) (paramType, error) { + if len(raw) < 2 { + return paramType(0), ErrParamPacketTooShort + } + + return paramType(binary.BigEndian.Uint16(raw)), nil +} + +func (p paramType) String() string { //nolint:cyclop + switch p { + case heartbeatInfo: + return "Heartbeat Info" + case ipV4Addr: + return "IPv4 IP" + case ipV6Addr: + return "IPv6 IP" + case stateCookie: + return "State Cookie" + case unrecognizedParam: + return "Unrecognized Parameters" + case cookiePreservative: + return "Cookie Preservative" + case hostNameAddr: + return "Host Name Address" + case supportedAddrTypes: + return "Supported IP Types" + case outSSNResetReq: + return "Outgoing SSN Reset Request Parameter" + case incSSNResetReq: + return "Incoming SSN Reset Request Parameter" + case ssnTSNResetReq: + return "SSN/TSN Reset Request Parameter" + case reconfigResp: + return "Re-configuration Response Parameter" + case addOutStreamsReq: + return "Add Outgoing Streams Request Parameter" + case addIncStreamsReq: + return "Add Incoming Streams Request Parameter" + case ecnCapable: + return "ECN Capable" + case zeroChecksumAcceptable: + return "Zero Checksum Acceptable" + case random: + return "Random" + case chunkList: + return "Chunk List" + case reqHMACAlgo: + return "Requested HMAC Algorithm Parameter" + case padding: + return "Padding" + case supportedExt: + return "Supported Extensions" + case forwardTSNSupp: + return "Forward TSN supported" + case addIPAddr: + return "Add IP Address" + case delIPAddr: + return "Delete IP Address" + case errClauseInd: + return "Error Cause Indication" + case setPriAddr: + return "Set Primary IP" + case successInd: + return "Success Indication" + case adaptLayerInd: + return "Adaptation Layer Indication" + default: + return fmt.Sprintf("Unknown ParamType: %d", p) + } +} diff --git a/vendor/github.com/pion/sctp/payload_queue.go b/vendor/github.com/pion/sctp/payload_queue.go new file mode 100644 index 0000000..622c0fd --- /dev/null +++ b/vendor/github.com/pion/sctp/payload_queue.go @@ -0,0 +1,75 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sctp + +type payloadQueue struct { + chunks *queue[*chunkPayloadData] + nBytes int +} + +func newPayloadQueue() *payloadQueue { + return &payloadQueue{chunks: newQueue[*chunkPayloadData](128)} +} + +func (q *payloadQueue) pushNoCheck(p *chunkPayloadData) { + q.chunks.PushBack(p) + q.nBytes += len(p.userData) +} + +// pop pops only if the oldest chunk's TSN matches the given TSN. +func (q *payloadQueue) pop(tsn uint32) (*chunkPayloadData, bool) { + if q.chunks.Len() > 0 && tsn == q.chunks.Front().tsn { + c := q.chunks.PopFront() + q.nBytes -= len(c.userData) + + return c, true + } + + return nil, false +} + +// get returns reference to chunkPayloadData with the given TSN value. +func (q *payloadQueue) get(tsn uint32) (*chunkPayloadData, bool) { + length := q.chunks.Len() + if length == 0 { + return nil, false + } + head := q.chunks.Front().tsn + if tsn < head || int(tsn-head) >= length { + return nil, false + } + + return q.chunks.At(int(tsn - head)), true +} + +func (q *payloadQueue) markAsAcked(tsn uint32) int { + var nBytesAcked int + if c, ok := q.get(tsn); ok { + c.acked = true + c.retransmit = false + nBytesAcked = len(c.userData) + q.nBytes -= nBytesAcked + c.userData = []byte{} + } + + return nBytesAcked +} + +func (q *payloadQueue) markAllToRetrasmit() { + for i := 0; i < q.chunks.Len(); i++ { + c := q.chunks.At(i) + if c.acked || c.abandoned() { + continue + } + c.retransmit = true + } +} + +func (q *payloadQueue) getNumBytes() int { + return q.nBytes +} + +func (q *payloadQueue) size() int { + return q.chunks.Len() +} diff --git a/vendor/github.com/pion/sctp/pending_queue.go b/vendor/github.com/pion/sctp/pending_queue.go new file mode 100644 index 0000000..be150f0 --- /dev/null +++ b/vendor/github.com/pion/sctp/pending_queue.go @@ -0,0 +1,164 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sctp + +import ( + "errors" +) + +// pendingBaseQueue + +type pendingBaseQueue struct { + queue []*chunkPayloadData +} + +func newPendingBaseQueue() *pendingBaseQueue { + return &pendingBaseQueue{queue: []*chunkPayloadData{}} +} + +func (q *pendingBaseQueue) push(c *chunkPayloadData) { + q.queue = append(q.queue, c) +} + +func (q *pendingBaseQueue) pop() *chunkPayloadData { + if len(q.queue) == 0 { + return nil + } + + c := q.queue[0] + q.queue[0] = nil + + if len(q.queue) == 0 { + q.queue = nil + } else { + q.queue = q.queue[1:] + } + + return c +} + +func (q *pendingBaseQueue) get(i int) *chunkPayloadData { + if len(q.queue) == 0 || i < 0 || i >= len(q.queue) { + return nil + } + + return q.queue[i] +} + +func (q *pendingBaseQueue) size() int { + return len(q.queue) +} + +// pendingQueue + +type pendingQueue struct { + unorderedQueue *pendingBaseQueue + orderedQueue *pendingBaseQueue + nBytes int + selected bool + unorderedIsSelected bool +} + +// Pending queue errors. +var ( + ErrUnexpectedChunkPoppedUnordered = errors.New("unexpected chunk popped (unordered)") + ErrUnexpectedChunkPoppedOrdered = errors.New("unexpected chunk popped (ordered)") + ErrUnexpectedQState = errors.New("unexpected q state (should've been selected)") + + // Deprecated: use ErrUnexpectedChunkPoppedUnordered. + ErrUnexpectedChuckPoppedUnordered = ErrUnexpectedChunkPoppedUnordered + // Deprecated: use ErrUnexpectedChunkPoppedOrdered. + ErrUnexpectedChuckPoppedOrdered = ErrUnexpectedChunkPoppedOrdered +) + +func newPendingQueue() *pendingQueue { + return &pendingQueue{ + unorderedQueue: newPendingBaseQueue(), + orderedQueue: newPendingBaseQueue(), + } +} + +func (q *pendingQueue) push(c *chunkPayloadData) { + if c.unordered { + q.unorderedQueue.push(c) + } else { + q.orderedQueue.push(c) + } + q.nBytes += len(c.userData) +} + +func (q *pendingQueue) peek() *chunkPayloadData { + if q.selected { + if q.unorderedIsSelected { + return q.unorderedQueue.get(0) + } + + return q.orderedQueue.get(0) + } + + if c := q.unorderedQueue.get(0); c != nil { + return c + } + + return q.orderedQueue.get(0) +} + +func (q *pendingQueue) pop(chunkPayload *chunkPayloadData) error { //nolint:cyclop + if q.selected { //nolint:nestif + var popped *chunkPayloadData + if q.unorderedIsSelected { + popped = q.unorderedQueue.pop() + if popped != chunkPayload { + return ErrUnexpectedChunkPoppedUnordered + } + } else { + popped = q.orderedQueue.pop() + if popped != chunkPayload { + return ErrUnexpectedChunkPoppedOrdered + } + } + if popped.endingFragment { + q.selected = false + } + } else { + if !chunkPayload.beginningFragment { + return ErrUnexpectedQState + } + if chunkPayload.unordered { + popped := q.unorderedQueue.pop() + if popped != chunkPayload { + return ErrUnexpectedChunkPoppedUnordered + } + if !popped.endingFragment { + q.selected = true + q.unorderedIsSelected = true + } + } else { + popped := q.orderedQueue.pop() + if popped != chunkPayload { + return ErrUnexpectedChunkPoppedOrdered + } + if !popped.endingFragment { + q.selected = true + q.unorderedIsSelected = false + } + } + } + + // guard against negative values (should never happen, but just in case). + q.nBytes -= len(chunkPayload.userData) + if q.nBytes < 0 { + q.nBytes = 0 + } + + return nil +} + +func (q *pendingQueue) getNumBytes() int { + return q.nBytes +} + +func (q *pendingQueue) size() int { + return q.unorderedQueue.size() + q.orderedQueue.size() +} diff --git a/vendor/github.com/pion/sctp/queue.go b/vendor/github.com/pion/sctp/queue.go new file mode 100644 index 0000000..69f0edb --- /dev/null +++ b/vendor/github.com/pion/sctp/queue.go @@ -0,0 +1,75 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sctp + +type queue[T any] struct { + buf []T + head int + tail int + count int +} + +const minCap = 16 + +func newQueue[T any](capacity int) *queue[T] { + queueCap := minCap + for queueCap < capacity { + queueCap <<= 1 + } + + return &queue[T]{ + buf: make([]T, queueCap), + } +} + +func (q *queue[T]) Len() int { + return q.count +} + +func (q *queue[T]) PushBack(ele T) { + q.growIfFull() + q.buf[q.tail] = ele + q.tail = (q.tail + 1) % len(q.buf) + q.count++ +} + +func (q *queue[T]) PopFront() T { + ele := q.buf[q.head] + var zeroVal T + q.buf[q.head] = zeroVal + q.head = (q.head + 1) % len(q.buf) + q.count-- + + return ele +} + +func (q *queue[T]) Front() T { + return q.buf[q.head] +} + +func (q *queue[T]) Back() T { + return q.buf[(q.tail-1+len(q.buf))%len(q.buf)] +} + +func (q *queue[T]) At(i int) T { + return q.buf[(q.head+i)%(len(q.buf))] +} + +func (q *queue[T]) growIfFull() { + if q.count < len(q.buf) { + return + } + + newBuf := make([]T, q.count<<1) + if q.tail > q.head { + copy(newBuf, q.buf[q.head:q.tail]) + } else { + n := copy(newBuf, q.buf[q.head:]) + copy(newBuf[n:], q.buf[:q.tail]) + } + + q.head = 0 + q.tail = q.count + q.buf = newBuf +} diff --git a/vendor/github.com/pion/sctp/reassembly_queue.go b/vendor/github.com/pion/sctp/reassembly_queue.go new file mode 100644 index 0000000..998824e --- /dev/null +++ b/vendor/github.com/pion/sctp/reassembly_queue.go @@ -0,0 +1,383 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sctp + +import ( + "errors" + "io" + "sort" + "sync/atomic" +) + +func sortChunksByTSN(a []*chunkPayloadData) { + sort.Slice(a, func(i, j int) bool { + return sna32LT(a[i].tsn, a[j].tsn) + }) +} + +func sortChunksBySSN(a []*chunkSet) { + sort.Slice(a, func(i, j int) bool { + return sna16LT(a[i].ssn, a[j].ssn) + }) +} + +// chunkSet is a set of chunks that share the same SSN. +type chunkSet struct { + ssn uint16 // used only with the ordered chunks + ppi PayloadProtocolIdentifier + chunks []*chunkPayloadData +} + +func newChunkSet(ssn uint16, ppi PayloadProtocolIdentifier) *chunkSet { + return &chunkSet{ + ssn: ssn, + ppi: ppi, + chunks: []*chunkPayloadData{}, + } +} + +func (set *chunkSet) push(chunk *chunkPayloadData) bool { + // check if dup + for _, c := range set.chunks { + if c.tsn == chunk.tsn { + return false + } + } + + // append and sort + set.chunks = append(set.chunks, chunk) + sortChunksByTSN(set.chunks) + + // Check if we now have a complete set + complete := set.isComplete() + + return complete +} + +func (set *chunkSet) isComplete() bool { + // Condition for complete set + // 0. Has at least one chunk. + // 1. Begins with beginningFragment set to true + // 2. Ends with endingFragment set to true + // 3. TSN monotinically increase by 1 from beginning to end + + // 0. + nChunks := len(set.chunks) + if nChunks == 0 { + return false + } + + // 1. + if !set.chunks[0].beginningFragment { + return false + } + + // 2. + if !set.chunks[nChunks-1].endingFragment { + return false + } + + // 3. + var lastTSN uint32 + for i, chunk := range set.chunks { + if i > 0 { + // Fragments must have contiguous TSN + // From RFC 4960 Section 3.3.1: + // When a user message is fragmented into multiple chunks, the TSNs are + // used by the receiver to reassemble the message. This means that the + // TSNs for each fragment of a fragmented user message MUST be strictly + // sequential. + if chunk.tsn != lastTSN+1 { + // mid or end fragment is missing + return false + } + } + + lastTSN = chunk.tsn + } + + return true +} + +type reassemblyQueue struct { + si uint16 + nextSSN uint16 // expected SSN for next ordered chunk + ordered []*chunkSet + unordered []*chunkSet + unorderedChunks []*chunkPayloadData + nBytes uint64 +} + +var errTryAgain = errors.New("try again") + +func newReassemblyQueue(si uint16) *reassemblyQueue { + // From RFC 4960 Sec 6.5: + // The Stream Sequence Number in all the streams MUST start from 0 when + // the association is established. Also, when the Stream Sequence + // Number reaches the value 65535 the next Stream Sequence Number MUST + // be set to 0. + return &reassemblyQueue{ + si: si, + nextSSN: 0, // From RFC 4960 Sec 6.5: + ordered: make([]*chunkSet, 0), + unordered: make([]*chunkSet, 0), + } +} + +func (r *reassemblyQueue) push(chunk *chunkPayloadData) bool { //nolint:cyclop + var cset *chunkSet + + if chunk.streamIdentifier != r.si { + return false + } + + if chunk.unordered { + // First, insert into unorderedChunks array + r.unorderedChunks = append(r.unorderedChunks, chunk) + atomic.AddUint64(&r.nBytes, uint64(len(chunk.userData))) + sortChunksByTSN(r.unorderedChunks) + + // Scan unorderedChunks that are contiguous (in TSN) + cset = r.findCompleteUnorderedChunkSet() + + // If found, append the complete set to the unordered array + if cset != nil { + r.unordered = append(r.unordered, cset) + + return true + } + + return false + } + + // This is an ordered chunk + + if sna16LT(chunk.streamSequenceNumber, r.nextSSN) { + return false + } + + // Check if a fragmented chunkSet with the fragmented SSN already exists + if chunk.isFragmented() { + for _, set := range r.ordered { + // nolint:godox + // TODO: add caution around SSN wrapping here... this helps only a little bit + // by ensuring we don't add to an unfragmented cset (1 chunk). There's + // a case where if the SSN does wrap around, we may see the same SSN + // for a different chunk. + + // nolint:godox + // TODO: this slice can get pretty big; it may be worth maintaining a map + // for O(1) lookups at the cost of 2x memory. + if set.ssn == chunk.streamSequenceNumber && set.chunks[0].isFragmented() { + cset = set + + break + } + } + } + + // If not found, create a new chunkSet + if cset == nil { + cset = newChunkSet(chunk.streamSequenceNumber, chunk.payloadType) + r.ordered = append(r.ordered, cset) + if !chunk.unordered { + sortChunksBySSN(r.ordered) + } + } + + atomic.AddUint64(&r.nBytes, uint64(len(chunk.userData))) + + return cset.push(chunk) +} + +func (r *reassemblyQueue) findCompleteUnorderedChunkSet() *chunkSet { + startIdx := -1 + nChunks := 0 + var lastTSN uint32 + var found bool + + for i, chunk := range r.unorderedChunks { + // seek beigining + if chunk.beginningFragment { + startIdx = i + nChunks = 1 + lastTSN = chunk.tsn + + if chunk.endingFragment { + found = true + + break + } + + continue + } + + if startIdx < 0 { + continue + } + + // Check if contiguous in TSN + if chunk.tsn != lastTSN+1 { + startIdx = -1 + + continue + } + + lastTSN = chunk.tsn + nChunks++ + + if chunk.endingFragment { + found = true + + break + } + } + + if !found { + return nil + } + + // Extract the range of chunks + var chunks []*chunkPayloadData + chunks = append(chunks, r.unorderedChunks[startIdx:startIdx+nChunks]...) + + r.unorderedChunks = append( + r.unorderedChunks[:startIdx], + r.unorderedChunks[startIdx+nChunks:]...) + + chunkSet := newChunkSet(0, chunks[0].payloadType) + chunkSet.chunks = chunks + + return chunkSet +} + +func (r *reassemblyQueue) isReadable() bool { + // Check unordered first + if len(r.unordered) > 0 { + // The chunk sets in r.unordered should all be complete. + return true + } + + // Check ordered sets + if len(r.ordered) > 0 { + cset := r.ordered[0] + if cset.isComplete() { + if sna16LTE(cset.ssn, r.nextSSN) { + return true + } + } + } + + return false +} + +func (r *reassemblyQueue) read(buf []byte) (int, PayloadProtocolIdentifier, error) { // nolint: cyclop + var ( + cset *chunkSet + isUnordered bool + nTotal int + err error + ) + + switch { + case len(r.unordered) > 0: + cset = r.unordered[0] + isUnordered = true + case len(r.ordered) > 0: + cset = r.ordered[0] + if !cset.isComplete() { + return 0, 0, errTryAgain + } + if sna16GT(cset.ssn, r.nextSSN) { + return 0, 0, errTryAgain + } + default: + return 0, 0, errTryAgain + } + + for _, c := range cset.chunks { + if len(buf)-nTotal < len(c.userData) { + err = io.ErrShortBuffer + } else { + copy(buf[nTotal:], c.userData) + } + + nTotal += len(c.userData) + } + + switch { + case err != nil: + return nTotal, 0, err + case isUnordered: + r.unordered = r.unordered[1:] + default: + r.ordered = r.ordered[1:] + if cset.ssn == r.nextSSN { + r.nextSSN++ + } + } + + r.subtractNumBytes(nTotal) + + return nTotal, cset.ppi, err +} + +func (r *reassemblyQueue) forwardTSNForOrdered(lastSSN uint16) { + // Use lastSSN to locate a chunkSet then remove it if the set has + // not been complete + keep := []*chunkSet{} + for _, set := range r.ordered { + if sna16LTE(set.ssn, lastSSN) { + if !set.isComplete() { + // drop the set + for _, c := range set.chunks { + r.subtractNumBytes(len(c.userData)) + } + + continue + } + } + keep = append(keep, set) + } + r.ordered = keep + + // Finally, forward nextSSN + if sna16LTE(r.nextSSN, lastSSN) { + r.nextSSN = lastSSN + 1 + } +} + +func (r *reassemblyQueue) forwardTSNForUnordered(newCumulativeTSN uint32) { + // Remove all fragments in the unordered sets that contains chunks + // equal to or older than `newCumulativeTSN`. + // We know all sets in the r.unordered are complete ones. + // Just remove chunks that are equal to or older than newCumulativeTSN + // from the unorderedChunks + lastIdx := -1 + for i, c := range r.unorderedChunks { + if sna32GT(c.tsn, newCumulativeTSN) { + break + } + lastIdx = i + } + if lastIdx >= 0 { + for _, c := range r.unorderedChunks[0 : lastIdx+1] { + r.subtractNumBytes(len(c.userData)) + } + r.unorderedChunks = r.unorderedChunks[lastIdx+1:] + } +} + +func (r *reassemblyQueue) subtractNumBytes(nBytes int) { + cur := atomic.LoadUint64(&r.nBytes) + if int(cur) >= nBytes { //nolint:gosec // G115 + atomic.AddUint64(&r.nBytes, -uint64(nBytes)) //nolint:gosec // G115 + } else { + atomic.StoreUint64(&r.nBytes, 0) + } +} + +func (r *reassemblyQueue) getNumBytes() int { + return int(atomic.LoadUint64(&r.nBytes)) //nolint:gosec // G115 +} diff --git a/vendor/github.com/pion/sctp/receive_payload_queue.go b/vendor/github.com/pion/sctp/receive_payload_queue.go new file mode 100644 index 0000000..676e69c --- /dev/null +++ b/vendor/github.com/pion/sctp/receive_payload_queue.go @@ -0,0 +1,205 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sctp + +import ( + "fmt" + "math/bits" + "strings" +) + +type receivePayloadQueue struct { + tailTSN uint32 + chunkSize int + tsnBitmask []uint64 + dupTSN []uint32 + maxTSNOffset uint32 + + cumulativeTSN uint32 +} + +func newReceivePayloadQueue(maxTSNOffset uint32) *receivePayloadQueue { + maxTSNOffset = ((maxTSNOffset + 63) / 64) * 64 + + return &receivePayloadQueue{ + tsnBitmask: make([]uint64, maxTSNOffset/64), + maxTSNOffset: maxTSNOffset, + } +} + +func (q *receivePayloadQueue) init(cumulativeTSN uint32) { + q.cumulativeTSN = cumulativeTSN + q.tailTSN = cumulativeTSN + q.chunkSize = 0 + for i := range q.tsnBitmask { + q.tsnBitmask[i] = 0 + } + q.dupTSN = q.dupTSN[:0] +} + +func (q *receivePayloadQueue) hasChunk(tsn uint32) bool { + if q.chunkSize == 0 || sna32LTE(tsn, q.cumulativeTSN) || sna32GT(tsn, q.tailTSN) { + return false + } + + index, offset := int(tsn/64)%len(q.tsnBitmask), tsn%64 + + return q.tsnBitmask[index]&(1<> uint64(start)) //nolint:gosec // G115 + + return i + start, i+start < end +} + +func getFirstZeroBit(val uint64, start, end int) (int, bool) { + return getFirstNonZeroBit(^val, start, end) +} diff --git a/vendor/github.com/pion/sctp/renovate.json b/vendor/github.com/pion/sctp/renovate.json new file mode 100644 index 0000000..f1bb98c --- /dev/null +++ b/vendor/github.com/pion/sctp/renovate.json @@ -0,0 +1,6 @@ +{ + "$schema": "https://docs.renovatebot.com/renovate-schema.json", + "extends": [ + "github>pion/renovate-config" + ] +} diff --git a/vendor/github.com/pion/sctp/rtx_timer.go b/vendor/github.com/pion/sctp/rtx_timer.go new file mode 100644 index 0000000..f795b7f --- /dev/null +++ b/vendor/github.com/pion/sctp/rtx_timer.go @@ -0,0 +1,257 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sctp + +import ( + "math" + "sync" + "time" +) + +const ( + // RTO.Initial in msec. + rtoInitial float64 = 1.0 * 1000 + + // RTO.Min in msec. + rtoMin float64 = 1.0 * 1000 + + // RTO.Max in msec. + defaultRTOMax float64 = 60.0 * 1000 + + // RTO.Alpha. + rtoAlpha float64 = 0.125 + + // RTO.Beta. + rtoBeta float64 = 0.25 + + // Max.Init.Retransmits. + maxInitRetrans uint = 8 + + // Path.Max.Retrans. + pathMaxRetrans uint = 5 + + noMaxRetrans uint = 0 +) + +// rtoManager manages Rtx timeout values. +// This is an implementation of RFC 4960 sec 6.3.1. +type rtoManager struct { + srtt float64 + rttvar float64 + rto float64 + noUpdate bool + mutex sync.RWMutex + rtoMax float64 +} + +// newRTOManager creates a new rtoManager. +func newRTOManager(rtoMax float64) *rtoManager { + mgr := rtoManager{ + rto: rtoInitial, + rtoMax: rtoMax, + } + if mgr.rtoMax == 0 { + mgr.rtoMax = defaultRTOMax + } + + return &mgr +} + +// setNewRTT takes a newly measured RTT then adjust the RTO in msec. +func (m *rtoManager) setNewRTT(rtt float64) float64 { + m.mutex.Lock() + defer m.mutex.Unlock() + + if m.noUpdate { + return m.srtt + } + + if m.srtt == 0 { + // First measurement + m.srtt = rtt + m.rttvar = rtt / 2 + } else { + // Subsequent rtt measurement + m.rttvar = (1-rtoBeta)*m.rttvar + rtoBeta*(math.Abs(m.srtt-rtt)) + m.srtt = (1-rtoAlpha)*m.srtt + rtoAlpha*rtt + } + m.rto = math.Min(math.Max(m.srtt+4*m.rttvar, rtoMin), m.rtoMax) + + return m.srtt +} + +// getRTO simply returns the current RTO in msec. +func (m *rtoManager) getRTO() float64 { + m.mutex.RLock() + defer m.mutex.RUnlock() + + return m.rto +} + +// reset resets the RTO variables to the initial values. +func (m *rtoManager) reset() { + m.mutex.Lock() + defer m.mutex.Unlock() + + if m.noUpdate { + return + } + + m.srtt = 0 + m.rttvar = 0 + m.rto = rtoInitial +} + +// set RTO value for testing. +func (m *rtoManager) setRTO(rto float64, noUpdate bool) { + m.mutex.Lock() + defer m.mutex.Unlock() + + m.rto = rto + m.noUpdate = noUpdate +} + +// rtxTimerObserver is the inteface to a timer observer. +// NOTE: Observers MUST NOT call start() or stop() method on rtxTimer +// from within these callbacks. +type rtxTimerObserver interface { + onRetransmissionTimeout(timerID int, n uint) + onRetransmissionFailure(timerID int) +} + +type rtxTimerState uint8 + +const ( + rtxTimerStopped rtxTimerState = iota + rtxTimerStarted + rtxTimerClosed +) + +// rtxTimer provides the retnransmission timer conforms with RFC 4960 Sec 6.3.1. +type rtxTimer struct { + timer *time.Timer + observer rtxTimerObserver + id int + maxRetrans uint + rtoMax float64 + mutex sync.Mutex + rto float64 + nRtos uint + state rtxTimerState + pending uint8 +} + +// newRTXTimer creates a new retransmission timer. +// if maxRetrans is set to 0, it will keep retransmitting until stop() is called. +// (it will never make onRetransmissionFailure() callback. +func newRTXTimer(id int, observer rtxTimerObserver, maxRetrans uint, + rtoMax float64, +) *rtxTimer { + timer := rtxTimer{ + id: id, + observer: observer, + maxRetrans: maxRetrans, + rtoMax: rtoMax, + } + if timer.rtoMax == 0 { + timer.rtoMax = defaultRTOMax + } + timer.timer = time.AfterFunc(math.MaxInt64, timer.timeout) + timer.timer.Stop() + + return &timer +} + +func (t *rtxTimer) calculateNextTimeout() time.Duration { + timeout := calculateNextTimeout(t.rto, t.nRtos, t.rtoMax) + + return time.Duration(timeout) * time.Millisecond +} + +func (t *rtxTimer) timeout() { + t.mutex.Lock() + if t.pending--; t.pending == 0 && t.state == rtxTimerStarted { + if t.nRtos++; t.maxRetrans == 0 || t.nRtos <= t.maxRetrans { + t.timer.Reset(t.calculateNextTimeout()) + t.pending++ + defer t.observer.onRetransmissionTimeout(t.id, t.nRtos) + } else { + t.state = rtxTimerStopped + defer t.observer.onRetransmissionFailure(t.id) + } + } + t.mutex.Unlock() +} + +// start starts the timer. +func (t *rtxTimer) start(rto float64) bool { + t.mutex.Lock() + defer t.mutex.Unlock() + + // this timer is already closed or aleady running + if t.state != rtxTimerStopped { + return false + } + + // Note: rto value is intentionally not capped by RTO.Min to allow + // fast timeout for the tests. Non-test code should pass in the + // rto generated by rtoManager getRTO() method which caps the + // value at RTO.Min or at RTO.Max. + t.rto = rto + t.nRtos = 0 + t.state = rtxTimerStarted + t.pending++ + t.timer.Reset(t.calculateNextTimeout()) + + return true +} + +// stop stops the timer. +func (t *rtxTimer) stop() { + t.mutex.Lock() + defer t.mutex.Unlock() + + if t.state == rtxTimerStarted { + if t.timer.Stop() { + t.pending-- + } + t.state = rtxTimerStopped + } +} + +// closes the timer. this is similar to stop() but subsequent start() call +// will fail (the timer is no longer usable). +func (t *rtxTimer) close() { + t.mutex.Lock() + defer t.mutex.Unlock() + + if t.state == rtxTimerStarted && t.timer.Stop() { + t.pending-- + } + t.state = rtxTimerClosed +} + +// isRunning tests if the timer is running. +// Debug purpose only. +func (t *rtxTimer) isRunning() bool { + t.mutex.Lock() + defer t.mutex.Unlock() + + return t.state == rtxTimerStarted +} + +func calculateNextTimeout(rto float64, nRtos uint, rtoMax float64) float64 { + // RFC 4096 sec 6.3.3. Handle T3-rtx Expiration + // E2) For the destination address for which the timer expires, set RTO + // <- RTO * 2 ("back off the timer"). The maximum value discussed + // in rule C7 above (RTO.max) may be used to provide an upper bound + // to this doubling operation. + if nRtos < 31 { + m := 1 << nRtos + + return math.Min(rto*float64(m), rtoMax) + } + + return rtoMax +} diff --git a/vendor/github.com/pion/sctp/sctp.go b/vendor/github.com/pion/sctp/sctp.go new file mode 100644 index 0000000..ec9c706 --- /dev/null +++ b/vendor/github.com/pion/sctp/sctp.go @@ -0,0 +1,5 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +// Package sctp implements the SCTP spec +package sctp diff --git a/vendor/github.com/pion/sctp/stream.go b/vendor/github.com/pion/sctp/stream.go new file mode 100644 index 0000000..4fd2382 --- /dev/null +++ b/vendor/github.com/pion/sctp/stream.go @@ -0,0 +1,514 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sctp + +import ( + "errors" + "fmt" + "io" + "os" + "sync" + "sync/atomic" + "time" + + "github.com/pion/logging" + "github.com/pion/transport/v4/deadline" +) + +const ( + // ReliabilityTypeReliable is used for reliable transmission. + ReliabilityTypeReliable byte = 0 + // ReliabilityTypeRexmit is used for partial reliability by retransmission count. + ReliabilityTypeRexmit byte = 1 + // ReliabilityTypeTimed is used for partial reliability by retransmission duration. + ReliabilityTypeTimed byte = 2 +) + +// StreamState is an enum for SCTP Stream state field +// This field identifies the state of stream. +type StreamState int + +// StreamState enums. +const ( + StreamStateOpen StreamState = iota // Stream object starts with StreamStateOpen + StreamStateClosing // Outgoing stream is being reset + StreamStateClosed // Stream has been closed +) + +func (ss StreamState) String() string { + switch ss { + case StreamStateOpen: + return "open" + case StreamStateClosing: + return "closing" + case StreamStateClosed: + return "closed" + } + + return "unknown" +} + +// SCTP stream errors. +var ( + ErrOutboundPacketTooLarge = errors.New("outbound packet larger than maximum message size") + ErrStreamClosed = errors.New("stream closed") + ErrReadDeadlineExceeded = fmt.Errorf("read deadline exceeded: %w", os.ErrDeadlineExceeded) +) + +// Stream represents an SCTP stream. +type Stream struct { + association *Association + lock sync.RWMutex + streamIdentifier uint16 + defaultPayloadType PayloadProtocolIdentifier + reassemblyQueue *reassemblyQueue + sequenceNumber uint16 + readNotifier *sync.Cond + readErr error + readTimeoutCancel chan struct{} + writeDeadline *deadline.Deadline + writeLock sync.Mutex + unordered bool + reliabilityType byte + reliabilityValue uint32 + bufferedAmount uint64 + bufferedAmountLow uint64 + onBufferedAmountLow func() + state StreamState + log logging.LeveledLogger + name string +} + +// StreamIdentifier returns the Stream identifier associated to the stream. +func (s *Stream) StreamIdentifier() uint16 { + s.lock.RLock() + defer s.lock.RUnlock() + + return s.streamIdentifier +} + +// SetDefaultPayloadType sets the default payload type used by Write. +func (s *Stream) SetDefaultPayloadType(defaultPayloadType PayloadProtocolIdentifier) { + atomic.StoreUint32((*uint32)(&s.defaultPayloadType), uint32(defaultPayloadType)) +} + +// SetReliabilityParams sets reliability parameters for this stream. +func (s *Stream) SetReliabilityParams(unordered bool, relType byte, relVal uint32) { + s.lock.Lock() + defer s.lock.Unlock() + + s.setReliabilityParams(unordered, relType, relVal) +} + +// setReliabilityParams sets reliability parameters for this stream. +// The caller should hold the lock. +func (s *Stream) setReliabilityParams(unordered bool, relType byte, relVal uint32) { + s.log.Debugf("[%s] reliability params: ordered=%v type=%d value=%d", + s.name, !unordered, relType, relVal) + s.unordered = unordered + s.reliabilityType = relType + s.reliabilityValue = relVal +} + +// Read reads a packet of len(p) bytes, dropping the Payload Protocol Identifier. +// Returns EOF when the stream is reset or an error if the stream is closed +// otherwise. +func (s *Stream) Read(p []byte) (int, error) { + n, _, err := s.ReadSCTP(p) + + return n, err +} + +// ReadSCTP reads a packet of len(payload) bytes and returns the associated Payload +// Protocol Identifier. +// Returns EOF when the stream is reset or an error if the stream is closed +// otherwise. +func (s *Stream) ReadSCTP(payload []byte) (int, PayloadProtocolIdentifier, error) { + s.lock.Lock() + defer s.lock.Unlock() + + defer func() { + // close readTimeoutCancel if the current read timeout routine is no longer effective + if s.readTimeoutCancel != nil && s.readErr != nil { + close(s.readTimeoutCancel) + s.readTimeoutCancel = nil + } + }() + + for { + n, ppi, err := s.reassemblyQueue.read(payload) + if err == nil || errors.Is(err, io.ErrShortBuffer) { + return n, ppi, err + } + + if s.readErr != nil { + return 0, PayloadProtocolIdentifier(0), s.readErr + } + + s.readNotifier.Wait() + } +} + +// SetReadDeadline sets the read deadline in an identical way to net.Conn. +func (s *Stream) SetReadDeadline(deadline time.Time) error { + s.lock.Lock() + defer s.lock.Unlock() + + if s.readTimeoutCancel != nil { + close(s.readTimeoutCancel) + s.readTimeoutCancel = nil + } + + if s.readErr != nil { + if !errors.Is(s.readErr, ErrReadDeadlineExceeded) { + return nil + } + s.readErr = nil + } + + if !deadline.IsZero() { + s.readTimeoutCancel = make(chan struct{}) + + go func(readTimeoutCancel chan struct{}) { + t := time.NewTimer(time.Until(deadline)) + select { + case <-readTimeoutCancel: + t.Stop() + + return + case <-t.C: + select { + case <-readTimeoutCancel: + return + default: + } + s.lock.Lock() + if s.readErr == nil { + s.readErr = ErrReadDeadlineExceeded + } + s.readTimeoutCancel = nil + s.lock.Unlock() + + s.readNotifier.Signal() + } + }(s.readTimeoutCancel) + } + + return nil +} + +func (s *Stream) handleData(pd *chunkPayloadData) { + s.lock.Lock() + defer s.lock.Unlock() + + var readable bool + if s.reassemblyQueue.push(pd) { + readable = s.reassemblyQueue.isReadable() + s.log.Debugf("[%s] reassemblyQueue readable=%v", s.name, readable) + if readable { + s.log.Debugf("[%s] readNotifier.signal()", s.name) + s.readNotifier.Signal() + s.log.Debugf("[%s] readNotifier.signal() done", s.name) + } + } +} + +func (s *Stream) handleForwardTSNForOrdered(ssn uint16) { + var readable bool + + func() { + s.lock.Lock() + defer s.lock.Unlock() + + if s.unordered { + return // unordered chunks are handled by handleForwardUnordered method + } + + // Remove all chunks older than or equal to the new TSN from + // the reassemblyQueue. + s.reassemblyQueue.forwardTSNForOrdered(ssn) + readable = s.reassemblyQueue.isReadable() + }() + + // Notify the reader asynchronously if there's a data chunk to read. + if readable { + s.readNotifier.Signal() + } +} + +func (s *Stream) handleForwardTSNForUnordered(newCumulativeTSN uint32) { + var readable bool + + func() { + s.lock.Lock() + defer s.lock.Unlock() + + if !s.unordered { + return // ordered chunks are handled by handleForwardTSNOrdered method + } + + // Remove all chunks older than or equal to the new TSN from + // the reassemblyQueue. + s.reassemblyQueue.forwardTSNForUnordered(newCumulativeTSN) + readable = s.reassemblyQueue.isReadable() + }() + + // Notify the reader asynchronously if there's a data chunk to read. + if readable { + s.readNotifier.Signal() + } +} + +// Write writes len(payload) bytes from payload with the default Payload Protocol Identifier. +func (s *Stream) Write(payload []byte) (n int, err error) { + ppi := PayloadProtocolIdentifier(atomic.LoadUint32((*uint32)(&s.defaultPayloadType))) + + return s.WriteSCTP(payload, ppi) +} + +// WriteSCTP writes len(payload) bytes from payload to the DTLS connection. +func (s *Stream) WriteSCTP(payload []byte, ppi PayloadProtocolIdentifier) (int, error) { + maxMessageSize := s.association.MaxMessageSize() + if len(payload) > int(maxMessageSize) { + return 0, fmt.Errorf("%w: %v", ErrOutboundPacketTooLarge, maxMessageSize) + } + + if s.State() != StreamStateOpen { + return 0, ErrStreamClosed + } + + // the send could fail if the association is blocked for writing (timeout), it will left a hole + // in the stream sequence number space, so we need to lock the write to avoid concurrent send and decrement + // the sequence number in case of failure + if s.association.isBlockWrite() { + s.writeLock.Lock() + } + chunks, unordered := s.packetize(payload, ppi) + n := len(payload) + err := s.association.sendPayloadData(s.writeDeadline, chunks) + if err != nil { + s.lock.Lock() + s.bufferedAmount -= uint64(n) + if !unordered { + s.sequenceNumber-- + } + s.lock.Unlock() + n = 0 + } + if s.association.isBlockWrite() { + s.writeLock.Unlock() + } + + return n, err +} + +// SetWriteDeadline sets the write deadline in an identical way to net.Conn, +// it will only work for blocking writes. +func (s *Stream) SetWriteDeadline(deadline time.Time) error { + s.writeDeadline.Set(deadline) + + return nil +} + +// SetDeadline sets the read and write deadlines in an identical way to net.Conn. +func (s *Stream) SetDeadline(t time.Time) error { + if err := s.SetReadDeadline(t); err != nil { + return err + } + + return s.SetWriteDeadline(t) +} + +func (s *Stream) packetize(raw []byte, ppi PayloadProtocolIdentifier) ([]*chunkPayloadData, bool) { + s.lock.Lock() + defer s.lock.Unlock() + + offset := uint32(0) + remaining := uint32(len(raw)) //nolint:gosec // G115 + + // From draft-ietf-rtcweb-data-protocol-09, section 6: + // All Data Channel Establishment Protocol messages MUST be sent using + // ordered delivery and reliable transmission. + unordered := ppi != PayloadTypeWebRTCDCEP && s.unordered + + var chunks []*chunkPayloadData + var head *chunkPayloadData + for remaining != 0 { + fragmentSize := min32(s.association.maxPayloadSize, remaining) + + // Copy the userdata since we'll have to store it until acked + // and the caller may re-use the buffer in the mean time + userData := make([]byte, fragmentSize) + copy(userData, raw[offset:offset+fragmentSize]) + + chunk := &chunkPayloadData{ + streamIdentifier: s.streamIdentifier, + userData: userData, + unordered: unordered, + beginningFragment: offset == 0, + endingFragment: remaining-fragmentSize == 0, + immediateSack: false, + payloadType: ppi, + streamSequenceNumber: s.sequenceNumber, + head: head, + } + + if head == nil { + head = chunk + } + + chunks = append(chunks, chunk) + + remaining -= fragmentSize + offset += fragmentSize + } + + // RFC 4960 Sec 6.6 + // Note: When transmitting ordered and unordered data, an endpoint does + // not increment its Stream Sequence Number when transmitting a DATA + // chunk with U flag set to 1. + if !unordered { + s.sequenceNumber++ + } + + s.bufferedAmount += uint64(len(raw)) + s.log.Tracef("[%s] bufferedAmount = %d", s.name, s.bufferedAmount) + + return chunks, unordered +} + +// Close closes the write-direction of the stream. +// Future calls to Write are not permitted after calling Close. +func (s *Stream) Close() error { + if sid, resetOutbound := func() (uint16, bool) { + s.lock.Lock() + defer s.lock.Unlock() + + s.log.Debugf("[%s] Close: state=%s", s.name, s.state.String()) + + if s.state == StreamStateOpen { + if s.readErr == nil { + s.state = StreamStateClosing + } else { + s.state = StreamStateClosed + } + s.log.Debugf("[%s] state change: open => %s", s.name, s.state.String()) + + return s.streamIdentifier, true + } + + return s.streamIdentifier, false + }(); resetOutbound { + // Reset the outgoing stream + // https://tools.ietf.org/html/rfc6525 + return s.association.sendResetRequest(sid) + } + + return nil +} + +// BufferedAmount returns the number of bytes of data currently queued to be sent over this stream. +func (s *Stream) BufferedAmount() uint64 { + s.lock.RLock() + defer s.lock.RUnlock() + + return s.bufferedAmount +} + +// BufferedAmountLowThreshold returns the number of bytes of buffered outgoing data that is +// considered "low." Defaults to 0. +func (s *Stream) BufferedAmountLowThreshold() uint64 { + s.lock.RLock() + defer s.lock.RUnlock() + + return s.bufferedAmountLow +} + +// SetBufferedAmountLowThreshold is used to update the threshold. +// See BufferedAmountLowThreshold(). +func (s *Stream) SetBufferedAmountLowThreshold(th uint64) { + s.lock.Lock() + defer s.lock.Unlock() + + s.bufferedAmountLow = th +} + +// OnBufferedAmountLow sets the callback handler which would be called when the number of +// bytes of outgoing data buffered is lower than the threshold. +func (s *Stream) OnBufferedAmountLow(f func()) { + s.lock.Lock() + defer s.lock.Unlock() + + s.onBufferedAmountLow = f +} + +// This method is called by association's readLoop (go-)routine to notify this stream +// of the specified amount of outgoing data has been delivered to the peer. +func (s *Stream) onBufferReleased(nBytesReleased int) { + if nBytesReleased <= 0 { + return + } + + s.lock.Lock() + + fromAmount := s.bufferedAmount + + if s.bufferedAmount < uint64(nBytesReleased) { + s.bufferedAmount = 0 + s.log.Errorf("[%s] released buffer size %d should be <= %d", + s.name, nBytesReleased, s.bufferedAmount) + } else { + s.bufferedAmount -= uint64(nBytesReleased) + } + + s.log.Tracef("[%s] bufferedAmount = %d", s.name, s.bufferedAmount) + + if s.onBufferedAmountLow != nil && fromAmount > s.bufferedAmountLow && s.bufferedAmount <= s.bufferedAmountLow { + f := s.onBufferedAmountLow + s.lock.Unlock() + f() + + return + } + + s.lock.Unlock() +} + +func (s *Stream) getNumBytesInReassemblyQueue() int { + // No lock is required as it reads the size with atomic load function. + return s.reassemblyQueue.getNumBytes() +} + +func (s *Stream) onInboundStreamReset() { + s.lock.Lock() + defer s.lock.Unlock() + + s.log.Debugf("[%s] onInboundStreamReset: state=%s", s.name, s.state.String()) + + // No more inbound data to read. Unblock the read with io.EOF. + // This should cause DCEP layer (datachannel package) to call Close() which + // will reset outgoing stream also. + + // See RFC 8831 section 6.7: + // if one side decides to close the data channel, it resets the corresponding + // outgoing stream. When the peer sees that an incoming stream was + // reset, it also resets its corresponding outgoing stream. Once this + // is completed, the data channel is closed. + + s.readErr = io.EOF + s.readNotifier.Broadcast() + + if s.state == StreamStateClosing { + s.log.Debugf("[%s] state change: closing => closed", s.name) + s.state = StreamStateClosed + } +} + +// State return the stream state. +func (s *Stream) State() StreamState { + s.lock.RLock() + defer s.lock.RUnlock() + + return s.state +} diff --git a/vendor/github.com/pion/sctp/util.go b/vendor/github.com/pion/sctp/util.go new file mode 100644 index 0000000..8a4a36e --- /dev/null +++ b/vendor/github.com/pion/sctp/util.go @@ -0,0 +1,62 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sctp + +const ( + paddingMultiple = 4 +) + +func getPadding(l int) int { + return (paddingMultiple - (l % paddingMultiple)) % paddingMultiple +} + +func padByte(in []byte, cnt int) []byte { + if cnt < 0 { + cnt = 0 + } + padding := make([]byte, cnt) + + return append(in, padding...) +} + +// Serial Number Arithmetic (RFC 1982). +func sna32LT(i1, i2 uint32) bool { + return (i1 < i2 && i2-i1 < 1<<31) || (i1 > i2 && i1-i2 > 1<<31) +} + +func sna32LTE(i1, i2 uint32) bool { + return i1 == i2 || sna32LT(i1, i2) +} + +func sna32GT(i1, i2 uint32) bool { + return (i1 < i2 && (i2-i1) >= 1<<31) || (i1 > i2 && (i1-i2) <= 1<<31) +} + +func sna32GTE(i1, i2 uint32) bool { + return i1 == i2 || sna32GT(i1, i2) +} + +func sna32EQ(i1, i2 uint32) bool { + return i1 == i2 +} + +func sna16LT(i1, i2 uint16) bool { + return (i1 < i2 && (i2-i1) < 1<<15) || (i1 > i2 && (i1-i2) > 1<<15) +} + +func sna16LTE(i1, i2 uint16) bool { + return i1 == i2 || sna16LT(i1, i2) +} + +func sna16GT(i1, i2 uint16) bool { + return (i1 < i2 && (i2-i1) >= 1<<15) || (i1 > i2 && (i1-i2) <= 1<<15) +} + +func sna16GTE(i1, i2 uint16) bool { + return i1 == i2 || sna16GT(i1, i2) +} + +func sna16EQ(i1, i2 uint16) bool { + return i1 == i2 +} diff --git a/vendor/github.com/pion/sctp/windowedmin.go b/vendor/github.com/pion/sctp/windowedmin.go new file mode 100644 index 0000000..eadb78c --- /dev/null +++ b/vendor/github.com/pion/sctp/windowedmin.go @@ -0,0 +1,81 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sctp + +import ( + "sort" + "time" +) + +// windowedMin maintains a monotonic deque of (time,value) to answer +// the minimum over a sliding window efficiently. +// Not thread-safe; caller must synchronize (Association already does). +type windowedMin struct { + rackMinRTTWnd time.Duration + deque []entry +} + +type entry struct { + t time.Time + v time.Duration +} + +func newWindowedMin(window time.Duration) *windowedMin { + if window <= 0 { + window = 30 * time.Second + } + + return &windowedMin{rackMinRTTWnd: window} +} + +// prune removes elements older than (now - wnd). +func (window *windowedMin) prune(now time.Time) { + if len(window.deque) == 0 { + return + } + + cutoff := now.Add(-window.rackMinRTTWnd) + + firstValidTSAfterCutoff := sort.Search(len(window.deque), func(i int) bool { + return !window.deque[i].t.Before(cutoff) // no builtin func for >= cutoff time + }) + + if firstValidTSAfterCutoff > 0 { + window.deque = window.deque[firstValidTSAfterCutoff:] + } +} + +// Push inserts a new sample and preserves monotonic non-increasing values. +// It maintains minimum values by removing larger entries. +func (window *windowedMin) Push(now time.Time, v time.Duration) { + window.prune(now) + + for i := len(window.deque); i > 0 && window.deque[i-1].v >= v; i-- { + window.deque = window.deque[:i-1] + } + + window.deque = append( + window.deque, + entry{ + t: now, + v: v, + }, + ) +} + +// Min returns the minimum value in the current window or 0 if empty. +func (window *windowedMin) Min(now time.Time) time.Duration { + window.prune(now) + + if len(window.deque) == 0 { + return 0 + } + + return window.deque[0].v +} + +// Len is only for tests/diagnostics. +func (window *windowedMin) Len() int { + return len(window.deque) +} diff --git a/vendor/github.com/pion/sdp/v3/.gitignore b/vendor/github.com/pion/sdp/v3/.gitignore new file mode 100644 index 0000000..2394557 --- /dev/null +++ b/vendor/github.com/pion/sdp/v3/.gitignore @@ -0,0 +1,28 @@ +# SPDX-FileCopyrightText: 2026 The Pion community +# SPDX-License-Identifier: MIT + +### JetBrains IDE ### +##################### +.idea/ + +### Emacs Temporary Files ### +############################# +*~ + +### Folders ### +############### +bin/ +vendor/ +node_modules/ + +### Files ### +############# +*.ivf +*.ogg +tags +cover.out +*.sw[poe] +*.wasm +examples/sfu-ws/cert.pem +examples/sfu-ws/key.pem +wasm_exec.js diff --git a/vendor/github.com/pion/sdp/v3/.golangci.yml b/vendor/github.com/pion/sdp/v3/.golangci.yml new file mode 100644 index 0000000..43af4c3 --- /dev/null +++ b/vendor/github.com/pion/sdp/v3/.golangci.yml @@ -0,0 +1,147 @@ +# SPDX-FileCopyrightText: 2026 The Pion community +# SPDX-License-Identifier: MIT + +version: "2" +linters: + enable: + - asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers + - bidichk # Checks for dangerous unicode character sequences + - bodyclose # checks whether HTTP response body is closed successfully + - containedctx # containedctx is a linter that detects struct contained context.Context field + - contextcheck # check the function whether use a non-inherited context + - cyclop # checks function and package cyclomatic complexity + - decorder # check declaration order and count of types, constants, variables and functions + - dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f()) + - dupl # Tool for code clone detection + - durationcheck # check for two durations multiplied together + - err113 # Golang linter to check the errors handling expressions + - errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases + - errchkjson # Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occations, where the check for the returned error can be omitted. + - errname # Checks that sentinel errors are prefixed with the `Err` and error types are suffixed with the `Error`. + - errorlint # errorlint is a linter for that can be used to find code that will cause problems with the error wrapping scheme introduced in Go 1.13. + - exhaustive # check exhaustiveness of enum switch statements + - forbidigo # Forbids identifiers + - forcetypeassert # finds forced type assertions + - gochecknoglobals # Checks that no globals are present in Go code + - gocognit # Computes and checks the cognitive complexity of functions + - goconst # Finds repeated strings that could be replaced by a constant + - gocritic # The most opinionated Go source code linter + - gocyclo # Computes and checks the cyclomatic complexity of functions + - godot # Check if comments end in a period + - godox # Tool for detection of FIXME, TODO and other comment keywords + - goheader # Checks is file header matches to pattern + - gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod. + - goprintffuncname # Checks that printf-like functions are named with `f` at the end + - gosec # Inspects source code for security problems + - govet # Vet examines Go source code and reports suspicious constructs, such as Printf calls whose arguments do not align with the format string + - grouper # An analyzer to analyze expression groups. + - importas # Enforces consistent import aliases + - ineffassign # Detects when assignments to existing variables are not used + - lll # Reports long lines + - maintidx # maintidx measures the maintainability index of each function. + - makezero # Finds slice declarations with non-zero initial length + - misspell # Finds commonly misspelled English words in comments + - nakedret # Finds naked returns in functions greater than a specified function length + - nestif # Reports deeply nested if statements + - nilerr # Finds the code that returns nil even if it checks that the error is not nil. + - nilnil # Checks that there is no simultaneous return of `nil` error and an invalid value. + - nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity + - noctx # noctx finds sending http request without context.Context + - predeclared # find code that shadows one of Go's predeclared identifiers + - revive # golint replacement, finds style mistakes + - staticcheck # Staticcheck is a go vet on steroids, applying a ton of static analysis checks + - tagliatelle # Checks the struct tags. + - thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers + - unconvert # Remove unnecessary type conversions + - unparam # Reports unused function parameters + - unused # Checks Go code for unused constants, variables, functions and types + - varnamelen # checks that the length of a variable's name matches its scope + - wastedassign # wastedassign finds wasted assignment statements + - whitespace # Tool for detection of leading and trailing whitespace + disable: + - depguard # Go linter that checks if package imports are in a list of acceptable packages + - funlen # Tool for detection of long functions + - gochecknoinits # Checks that no init functions are present in Go code + - gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations. + - interfacebloat # A linter that checks length of interface. + - ireturn # Accept Interfaces, Return Concrete Types + - mnd # An analyzer to detect magic numbers + - nolintlint # Reports ill-formed or insufficient nolint directives + - paralleltest # paralleltest detects missing usage of t.Parallel() method in your Go test + - prealloc # Finds slice declarations that could potentially be preallocated + - promlinter # Check Prometheus metrics naming via promlint + - rowserrcheck # checks whether Err of rows is checked successfully + - sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed. + - testpackage # linter that makes you use a separate _test package + - tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes + - wrapcheck # Checks that errors returned from external packages are wrapped + - wsl # Whitespace Linter - Forces you to use empty lines! + settings: + staticcheck: + checks: + - all + - -QF1008 # "could remove embedded field", to keep it explicit! + - -QF1003 # "could use tagged switch on enum", Cases conflicts with exhaustive! + exhaustive: + default-signifies-exhaustive: true + forbidigo: + forbid: + - pattern: ^fmt.Print(f|ln)?$ + - pattern: ^log.(Panic|Fatal|Print)(f|ln)?$ + - pattern: ^os.Exit$ + - pattern: ^panic$ + - pattern: ^print(ln)?$ + - pattern: ^testing.T.(Error|Errorf|Fatal|Fatalf|Fail|FailNow)$ + pkg: ^testing$ + msg: use testify/assert instead + analyze-types: true + gomodguard: + blocked: + modules: + - github.com/pkg/errors: + recommendations: + - errors + govet: + enable: + - shadow + revive: + rules: + # Prefer 'any' type alias over 'interface{}' for Go 1.18+ compatibility + - name: use-any + severity: warning + disabled: false + misspell: + locale: US + varnamelen: + max-distance: 12 + min-name-length: 2 + ignore-type-assert-ok: true + ignore-map-index-ok: true + ignore-chan-recv-ok: true + ignore-decls: + - i int + - n int + - w io.Writer + - r io.Reader + - b []byte + exclusions: + generated: lax + rules: + - linters: + - forbidigo + - gocognit + path: (examples|main\.go) + - linters: + - gocognit + path: _test\.go + - linters: + - forbidigo + path: cmd +formatters: + enable: + - gci # Gci control golang package import order and make it always deterministic. + - gofmt # Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification + - gofumpt # Gofumpt checks whether code was gofumpt-ed. + - goimports # Goimports does everything that gofmt does. Additionally it checks unused imports + exclusions: + generated: lax diff --git a/vendor/github.com/pion/sdp/v3/.goreleaser.yml b/vendor/github.com/pion/sdp/v3/.goreleaser.yml new file mode 100644 index 0000000..8577d86 --- /dev/null +++ b/vendor/github.com/pion/sdp/v3/.goreleaser.yml @@ -0,0 +1,5 @@ +# SPDX-FileCopyrightText: 2026 The Pion community +# SPDX-License-Identifier: MIT + +builds: +- skip: true diff --git a/vendor/github.com/pion/sdp/v3/LICENSE b/vendor/github.com/pion/sdp/v3/LICENSE new file mode 100644 index 0000000..d96df05 --- /dev/null +++ b/vendor/github.com/pion/sdp/v3/LICENSE @@ -0,0 +1,9 @@ +MIT License + +Copyright (c) 2026 The Pion community + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/vendor/github.com/pion/sdp/v3/README.md b/vendor/github.com/pion/sdp/v3/README.md new file mode 100644 index 0000000..32fa30d --- /dev/null +++ b/vendor/github.com/pion/sdp/v3/README.md @@ -0,0 +1,35 @@ +

+
+ Pion SDP +
+

+

A Go implementation of the SDP

+

+ Pion SDP + Sourcegraph Widget + join us on Discord Follow us on Bluesky +
+ GitHub Workflow Status + Go Reference + Coverage Status + Go Report Card + License: MIT +

+
+ +### Roadmap +The library is used as a part of our WebRTC implementation. Please refer to that [roadmap](https://github.com/pion/webrtc/issues/9) to track our major milestones. + +### Community +Pion has an active community on the [Discord](https://discord.gg/PngbdqpFbt). + +Follow the [Pion Bluesky](https://bsky.app/profile/pion.ly) or [Pion Twitter](https://twitter.com/_pion) for project updates and important WebRTC news. + +We are always looking to support **your projects**. Please reach out if you have something to build! +If you need commercial support or don't want to use public methods you can contact us at [team@pion.ly](mailto:team@pion.ly) + +### Contributing +Check out the [contributing wiki](https://github.com/pion/webrtc/wiki/Contributing) to join the group of amazing people making this project possible + +### License +MIT License - see [LICENSE](LICENSE) for full text diff --git a/vendor/github.com/pion/sdp/v3/base_lexer.go b/vendor/github.com/pion/sdp/v3/base_lexer.go new file mode 100644 index 0000000..0261c37 --- /dev/null +++ b/vendor/github.com/pion/sdp/v3/base_lexer.go @@ -0,0 +1,233 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sdp + +import ( + "errors" + "fmt" + "io" + "slices" + "strconv" +) + +var errDocumentStart = errors.New("already on document start") + +type syntaxError struct { + s string + i int +} + +func (e syntaxError) Error() string { + if e.i < 0 { + e.i = 0 + } + + return fmt.Sprintf("sdp: syntax error at pos %d: %s", e.i, strconv.QuoteToASCII(e.s[e.i:e.i+1])) +} + +type baseLexer struct { + value string + pos int +} + +func (l baseLexer) syntaxError() error { + return syntaxError{s: l.value, i: l.pos - 1} +} + +func (l *baseLexer) unreadByte() error { + if l.pos <= 0 { + return errDocumentStart + } + l.pos-- + + return nil +} + +func (l *baseLexer) readByte() (byte, error) { + if l.pos >= len(l.value) { + return byte(0), io.EOF + } + ch := l.value[l.pos] + l.pos++ + + return ch, nil +} + +func (l *baseLexer) nextLine() error { + for { + ch, err := l.readByte() + if errors.Is(err, io.EOF) { + return nil + } else if err != nil { + return err + } + if !isNewline(ch) { + return l.unreadByte() + } + } +} + +func (l *baseLexer) readWhitespace() error { //notlint:cyclop + for { + ch, err := l.readByte() + if errors.Is(err, io.EOF) { + return nil + } else if err != nil { + return err + } + if !isWhitespace(ch) { + return l.unreadByte() + } + } +} + +func (l *baseLexer) readUint64Field() (i uint64, err error) { //nolint:cyclop + for { + ch, err := l.readByte() + if errors.Is(err, io.EOF) && i > 0 { + break + } else if err != nil { + return i, err + } + + if isNewline(ch) { + if err := l.unreadByte(); err != nil { + return i, err + } + + break + } + + if isWhitespace(ch) { + if err := l.readWhitespace(); err != nil { + return i, err + } + + break + } + + switch ch { + case '0': + i *= 10 + case '1': + i = i*10 + 1 + case '2': + i = i*10 + 2 + case '3': + i = i*10 + 3 + case '4': + i = i*10 + 4 + case '5': + i = i*10 + 5 + case '6': + i = i*10 + 6 + case '7': + i = i*10 + 7 + case '8': + i = i*10 + 8 + case '9': + i = i*10 + 9 + default: + return i, l.syntaxError() + } + } + + return i, nil +} + +// Returns next field on this line or empty string if no more fields on line. +func (l *baseLexer) readField() (string, error) { + start := l.pos + var stop int + for { + stop = l.pos + ch, err := l.readByte() + if errors.Is(err, io.EOF) && stop > start { + break + } else if err != nil { + return "", err + } + + if isNewline(ch) { + if err := l.unreadByte(); err != nil { + return "", err + } + + break + } + + if isWhitespace(ch) { + if err := l.readWhitespace(); err != nil { + return "", err + } + + break + } + } + + return l.value[start:stop], nil +} + +func (l *lexer) readRequiredField() (string, error) { + field, err := l.readField() + if err != nil { + return "", err + } + + if field == "" { + return "", errFieldMissing + } + + return field, nil +} + +// Returns symbols until line end. +func (l *baseLexer) readLine() (string, error) { + start := l.pos + trim := 1 + for { + ch, err := l.readByte() + if err != nil { + return "", err + } + if ch == '\r' { + trim++ + } + if ch == '\n' { + return l.value[start : l.pos-trim], nil + } + } +} + +func (l *baseLexer) readType() (byte, error) { + for { + firstByte, err := l.readByte() + if err != nil { + return 0, err + } + + if isNewline(firstByte) { + continue + } + + secondByte, err := l.readByte() + if err != nil { + return 0, err + } + + if secondByte != '=' { + return firstByte, l.syntaxError() + } + + return firstByte, nil + } +} + +func isNewline(ch byte) bool { return ch == '\n' || ch == '\r' } + +func isWhitespace(ch byte) bool { return ch == ' ' || ch == '\t' } + +func anyOf(element string, data ...string) bool { + return slices.Contains(data, element) +} diff --git a/vendor/github.com/pion/sdp/v3/codecov.yml b/vendor/github.com/pion/sdp/v3/codecov.yml new file mode 100644 index 0000000..b9639c2 --- /dev/null +++ b/vendor/github.com/pion/sdp/v3/codecov.yml @@ -0,0 +1,22 @@ +# +# DO NOT EDIT THIS FILE +# +# It is automatically copied from https://github.com/pion/.goassets repository. +# +# SPDX-FileCopyrightText: 2026 The Pion community +# SPDX-License-Identifier: MIT + +coverage: + status: + project: + default: + # Allow decreasing 2% of total coverage to avoid noise. + threshold: 2% + patch: + default: + target: 70% + only_pulls: true + +ignore: + - "examples/*" + - "examples/**/*" diff --git a/vendor/github.com/pion/sdp/v3/common_description.go b/vendor/github.com/pion/sdp/v3/common_description.go new file mode 100644 index 0000000..523195c --- /dev/null +++ b/vendor/github.com/pion/sdp/v3/common_description.go @@ -0,0 +1,190 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sdp + +import ( + "strconv" +) + +// Information describes the "i=" field which provides textual information +// about the session. +type Information string + +func (i Information) String() string { + return stringFromMarshal(i.marshalInto, i.marshalSize) +} + +func (i Information) marshalInto(b []byte) []byte { + return append(b, i...) +} + +func (i Information) marshalSize() (size int) { + return len(i) +} + +// ConnectionInformation defines the representation for the "c=" field +// containing connection data. +type ConnectionInformation struct { + NetworkType string + AddressType string + Address *Address +} + +func (c ConnectionInformation) String() string { + return stringFromMarshal(c.marshalInto, c.marshalSize) +} + +func (c ConnectionInformation) marshalInto(b []byte) []byte { + b = append(append(b, c.NetworkType...), ' ') + b = append(b, c.AddressType...) + + if c.Address != nil { + b = append(b, ' ') + b = c.Address.marshalInto(b) + } + + return b +} + +func (c ConnectionInformation) marshalSize() (size int) { + size = len(c.NetworkType) + size += 1 + len(c.AddressType) + if c.Address != nil { + size += 1 + c.Address.marshalSize() + } + + return +} + +// Address desribes a structured address token from within the "c=" field. +type Address struct { + Address string + TTL *int + Range *int +} + +func (c *Address) String() string { + return stringFromMarshal(c.marshalInto, c.marshalSize) +} + +func (c *Address) marshalInto(b []byte) []byte { + b = append(b, c.Address...) + if c.TTL != nil { + b = append(b, '/') + b = strconv.AppendInt(b, int64(*c.TTL), 10) + } + if c.Range != nil { + b = append(b, '/') + b = strconv.AppendInt(b, int64(*c.Range), 10) + } + + return b +} + +func (c Address) marshalSize() (size int) { + size = len(c.Address) + if c.TTL != nil { + size += 1 + lenUint(uint64(*c.TTL)) //nolint:gosec // G115 + } + if c.Range != nil { + size += 1 + lenUint(uint64(*c.Range)) //nolint:gosec // G115 + } + + return +} + +// Bandwidth describes an optional field which denotes the proposed bandwidth +// to be used by the session or media. +type Bandwidth struct { + Experimental bool + Type string + Bandwidth uint64 +} + +func (b Bandwidth) String() string { + return stringFromMarshal(b.marshalInto, b.marshalSize) +} + +func (b Bandwidth) marshalInto(d []byte) []byte { + if b.Experimental { + d = append(d, "X-"...) + } + d = append(append(d, b.Type...), ':') + + return strconv.AppendUint(d, b.Bandwidth, 10) +} + +func (b Bandwidth) marshalSize() (size int) { + if b.Experimental { + size += 2 + } + + size += len(b.Type) + 1 + lenUint(b.Bandwidth) + + return +} + +// EncryptionKey describes the "k=" which conveys encryption key information. +type EncryptionKey string + +func (e EncryptionKey) String() string { + return stringFromMarshal(e.marshalInto, e.marshalSize) +} + +func (e EncryptionKey) marshalInto(b []byte) []byte { + return append(b, e...) +} + +func (e EncryptionKey) marshalSize() (size int) { + return len(e) +} + +// Attribute describes the "a=" field which represents the primary means for +// extending SDP. +type Attribute struct { + Key string + Value string +} + +// NewPropertyAttribute constructs a new attribute. +func NewPropertyAttribute(key string) Attribute { + return Attribute{ + Key: key, + } +} + +// NewAttribute constructs a new attribute. +func NewAttribute(key, value string) Attribute { + return Attribute{ + Key: key, + Value: value, + } +} + +func (a Attribute) String() string { + return stringFromMarshal(a.marshalInto, a.marshalSize) +} + +func (a Attribute) marshalInto(b []byte) []byte { + b = append(b, a.Key...) + if len(a.Value) > 0 { + b = append(append(b, ':'), a.Value...) + } + + return b +} + +func (a Attribute) marshalSize() (size int) { + size = len(a.Key) + if len(a.Value) > 0 { + size += 1 + len(a.Value) + } + + return size +} + +// IsICECandidate returns true if the attribute key equals "candidate". +func (a Attribute) IsICECandidate() bool { + return a.Key == "candidate" +} diff --git a/vendor/github.com/pion/sdp/v3/direction.go b/vendor/github.com/pion/sdp/v3/direction.go new file mode 100644 index 0000000..2f80df9 --- /dev/null +++ b/vendor/github.com/pion/sdp/v3/direction.go @@ -0,0 +1,62 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sdp + +import "errors" + +// Direction is a marker for transmission directon of an endpoint. +type Direction int + +const ( + // DirectionSendRecv is for bidirectional communication. + DirectionSendRecv Direction = iota + 1 + // DirectionSendOnly is for outgoing communication. + DirectionSendOnly + // DirectionRecvOnly is for incoming communication. + DirectionRecvOnly + // DirectionInactive is for no communication. + DirectionInactive +) + +const ( + directionSendRecvStr = "sendrecv" + directionSendOnlyStr = "sendonly" + directionRecvOnlyStr = "recvonly" + directionInactiveStr = "inactive" + directionUnknownStr = "" +) + +var errDirectionString = errors.New("invalid direction string") + +// NewDirection defines a procedure for creating a new direction from a raw +// string. +func NewDirection(raw string) (Direction, error) { + switch raw { + case directionSendRecvStr: + return DirectionSendRecv, nil + case directionSendOnlyStr: + return DirectionSendOnly, nil + case directionRecvOnlyStr: + return DirectionRecvOnly, nil + case directionInactiveStr: + return DirectionInactive, nil + default: + return Direction(unknown), errDirectionString + } +} + +func (t Direction) String() string { + switch t { + case DirectionSendRecv: + return directionSendRecvStr + case DirectionSendOnly: + return directionSendOnlyStr + case DirectionRecvOnly: + return directionRecvOnlyStr + case DirectionInactive: + return directionInactiveStr + default: + return directionUnknownStr + } +} diff --git a/vendor/github.com/pion/sdp/v3/extmap.go b/vendor/github.com/pion/sdp/v3/extmap.go new file mode 100644 index 0000000..fdd1d81 --- /dev/null +++ b/vendor/github.com/pion/sdp/v3/extmap.go @@ -0,0 +1,113 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sdp + +import ( + "fmt" + "net/url" + "strconv" + "strings" +) + +// Default ext values. +const ( + DefExtMapValueABSSendTime = 1 + DefExtMapValueTransportCC = 2 + DefExtMapValueSDESMid = 3 + DefExtMapValueSDESRTPStreamID = 4 + + ABSSendTimeURI = "http://www.webrtc.org/experiments/rtp-hdrext/abs-send-time" + TransportCCURI = "http://www.ietf.org/id/draft-holmer-rmcat-transport-wide-cc-extensions-01" + SDESMidURI = "urn:ietf:params:rtp-hdrext:sdes:mid" + SDESRTPStreamIDURI = "urn:ietf:params:rtp-hdrext:sdes:rtp-stream-id" + SDESRepairRTPStreamIDURI = "urn:ietf:params:rtp-hdrext:sdes:repaired-rtp-stream-id" + AudioLevelURI = "urn:ietf:params:rtp-hdrext:ssrc-audio-level" +) + +// ExtMap represents the activation of a single RTP header extension. +type ExtMap struct { + Value int + Direction Direction + URI *url.URL + ExtAttr *string +} + +// Clone converts this object to an Attribute. +func (e *ExtMap) Clone() Attribute { + return Attribute{Key: "extmap", Value: e.string()} +} + +// Unmarshal creates an Extmap from a string. +func (e *ExtMap) Unmarshal(raw string) error { + parts := strings.SplitN(raw, ":", 2) + if len(parts) != 2 { + return fmt.Errorf("%w: %v", errSyntaxError, raw) + } + + fields := strings.Fields(parts[1]) + if len(fields) < 2 { + return fmt.Errorf("%w: %v", errSyntaxError, raw) + } + + valdir := strings.Split(fields[0], "/") + value, err := strconv.ParseInt(valdir[0], 10, 64) + if (value < 1) || (value > 246) { + return fmt.Errorf("%w: %v -- extmap key must be in the range 1-256", errSyntaxError, valdir[0]) + } + if err != nil { + return fmt.Errorf("%w: %v", errSyntaxError, valdir[0]) + } + + var direction Direction + if len(valdir) == 2 { + direction, err = NewDirection(valdir[1]) + if err != nil { + return err + } + } + + uri, err := url.Parse(fields[1]) + if err != nil { + return err + } + + if len(fields) == 3 { + tmp := fields[2] + e.ExtAttr = &tmp + } + + e.Value = int(value) + e.Direction = direction + e.URI = uri + + return nil +} + +// Marshal creates a string from an ExtMap. +func (e *ExtMap) Marshal() string { + return e.Name() + ":" + e.string() +} + +func (e *ExtMap) string() string { + output := fmt.Sprintf("%d", e.Value) + dirstring := e.Direction.String() + if dirstring != directionUnknownStr { + output += "/" + dirstring + } + + if e.URI != nil { + output += " " + e.URI.String() + } + + if e.ExtAttr != nil { + output += " " + *e.ExtAttr + } + + return output +} + +// Name returns the constant name of this object. +func (e *ExtMap) Name() string { + return "extmap" +} diff --git a/vendor/github.com/pion/sdp/v3/jsep.go b/vendor/github.com/pion/sdp/v3/jsep.go new file mode 100644 index 0000000..6ab9127 --- /dev/null +++ b/vendor/github.com/pion/sdp/v3/jsep.go @@ -0,0 +1,258 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sdp + +import ( + "fmt" + "net/url" + "strconv" + "time" +) + +// Constants for SDP attributes used in JSEP. +const ( + AttrKeyCandidate = "candidate" + AttrKeyEndOfCandidates = "end-of-candidates" + AttrKeyIdentity = "identity" + AttrKeyGroup = "group" + AttrKeySSRC = "ssrc" + AttrKeySSRCGroup = "ssrc-group" + AttrKeyMsid = "msid" + AttrKeyMsidSemantic = "msid-semantic" + AttrKeyConnectionSetup = "setup" + AttrKeyMID = "mid" + AttrKeyICELite = "ice-lite" + AttrKeyICEOptions = "ice-options" + AttrKeyRTCPMux = "rtcp-mux" + AttrKeyRTCPRsize = "rtcp-rsize" + AttrKeyInactive = "inactive" + AttrKeyRecvOnly = "recvonly" + AttrKeySendOnly = "sendonly" + AttrKeySendRecv = "sendrecv" + AttrKeyExtMap = "extmap" + AttrKeyExtMapAllowMixed = "extmap-allow-mixed" + AttrKeyCryptex = "cryptex" +) + +// Constants for semantic tokens used in JSEP. +const ( + SemanticTokenLipSynchronization = "LS" + SemanticTokenFlowIdentification = "FID" + SemanticTokenForwardErrorCorrection = "FEC" + // https://datatracker.ietf.org/doc/html/rfc5956#section-4.1 + SemanticTokenForwardErrorCorrectionFramework = "FEC-FR" + SemanticTokenWebRTCMediaStreams = "WMS" +) + +// Constants for extmap key. +const ( + ExtMapValueTransportCC = 3 +) + +func extMapURI() map[int]string { + return map[int]string{ + ExtMapValueTransportCC: "http://www.ietf.org/id/draft-holmer-rmcat-transport-wide-cc-extensions-01", + } +} + +// API to match draft-ietf-rtcweb-jsep +// Move to webrtc or its own package? + +// NewJSEPSessionDescription creates a new SessionDescription with +// some settings that are required by the JSEP spec. +// +// Note: Since v2.4.0, session ID has been fixed to use crypto random according to +// +// JSEP spec, so that NewJSEPSessionDescription now returns error as a second +// return value. +func NewJSEPSessionDescription(identity bool) (*SessionDescription, error) { + sid, err := newSessionID() + if err != nil { + return nil, err + } + descr := &SessionDescription{ + Version: 0, + Origin: Origin{ + Username: "-", + SessionID: sid, + SessionVersion: uint64(time.Now().Unix()), //nolint:gosec // G115 + NetworkType: "IN", + AddressType: "IP4", + UnicastAddress: "0.0.0.0", + }, + SessionName: "-", + TimeDescriptions: []TimeDescription{ + { + Timing: Timing{ + StartTime: 0, + StopTime: 0, + }, + RepeatTimes: nil, + }, + }, + Attributes: []Attribute{ + // "Attribute(ice-options:trickle)", // TODO: implement trickle ICE + }, + } + + if identity { + descr.WithPropertyAttribute(AttrKeyIdentity) + } + + return descr, nil +} + +// WithPropertyAttribute adds a property attribute 'a=key' to the session description. +func (s *SessionDescription) WithPropertyAttribute(key string) *SessionDescription { + s.Attributes = append(s.Attributes, NewPropertyAttribute(key)) + + return s +} + +// WithValueAttribute adds a value attribute 'a=key:value' to the session description. +func (s *SessionDescription) WithValueAttribute(key, value string) *SessionDescription { + s.Attributes = append(s.Attributes, NewAttribute(key, value)) + + return s +} + +// addOrUpdateICEOption adds or updates the ice-options attribute with the given value. +func (s *SessionDescription) addOrUpdateICEOption(value string) *SessionDescription { + for i := range s.Attributes { + if s.Attributes[i].Key == AttrKeyICEOptions { + prefix := " " + if s.Attributes[i].Value == "" { + prefix = "" + } + + s.Attributes[i].Value += prefix + value + + return s + } + } + + return s.WithValueAttribute(AttrKeyICEOptions, value) +} + +// WithICETrickleAdvertised advertises ICE trickle support in the session description. +// See https://datatracker.ietf.org/doc/html/rfc9429#section-5.2.1 +func (s *SessionDescription) WithICETrickleAdvertised() *SessionDescription { + return s.addOrUpdateICEOption("trickle") +} + +// WithICERenomination advertises ICE renomination support in the session description. +// See https://datatracker.ietf.org/doc/html/draft-thatcher-ice-renomination-01#section-3 +func (s *SessionDescription) WithICERenomination() *SessionDescription { + return s.addOrUpdateICEOption("renomination") +} + +// WithFingerprint adds a fingerprint to the session description. +func (s *SessionDescription) WithFingerprint(algorithm, value string) *SessionDescription { + return s.WithValueAttribute("fingerprint", algorithm+" "+value) +} + +// WithMedia adds a media description to the session description. +func (s *SessionDescription) WithMedia(md *MediaDescription) *SessionDescription { + s.MediaDescriptions = append(s.MediaDescriptions, md) + + return s +} + +// NewJSEPMediaDescription creates a new MediaName with +// some settings that are required by the JSEP spec. +func NewJSEPMediaDescription(codecType string, _ []string) *MediaDescription { + return &MediaDescription{ + MediaName: MediaName{ + Media: codecType, + Port: RangedPort{Value: 9}, + Protos: []string{"UDP", "TLS", "RTP", "SAVPF"}, + }, + ConnectionInformation: &ConnectionInformation{ + NetworkType: "IN", + AddressType: "IP4", + Address: &Address{ + Address: "0.0.0.0", + }, + }, + } +} + +// WithPropertyAttribute adds a property attribute 'a=key' to the media description. +func (d *MediaDescription) WithPropertyAttribute(key string) *MediaDescription { + d.Attributes = append(d.Attributes, NewPropertyAttribute(key)) + + return d +} + +// WithValueAttribute adds a value attribute 'a=key:value' to the media description. +func (d *MediaDescription) WithValueAttribute(key, value string) *MediaDescription { + d.Attributes = append(d.Attributes, NewAttribute(key, value)) + + return d +} + +// WithFingerprint adds a fingerprint to the media description. +func (d *MediaDescription) WithFingerprint(algorithm, value string) *MediaDescription { + return d.WithValueAttribute("fingerprint", algorithm+" "+value) +} + +// WithICECredentials adds ICE credentials to the media description. +func (d *MediaDescription) WithICECredentials(username, password string) *MediaDescription { + return d. + WithValueAttribute("ice-ufrag", username). + WithValueAttribute("ice-pwd", password) +} + +// WithCodec adds codec information to the media description. +func (d *MediaDescription) WithCodec( + payloadType uint8, + name string, + clockrate uint32, + channels uint16, + fmtp string, +) *MediaDescription { + d.MediaName.Formats = append(d.MediaName.Formats, strconv.Itoa(int(payloadType))) + rtpmap := fmt.Sprintf("%d %s/%d", payloadType, name, clockrate) + if channels > 0 { + rtpmap += fmt.Sprintf("/%d", channels) + } + d.WithValueAttribute("rtpmap", rtpmap) + if fmtp != "" { + d.WithValueAttribute("fmtp", fmt.Sprintf("%d %s", payloadType, fmtp)) + } + + return d +} + +// WithMediaSource adds media source information to the media description. +func (d *MediaDescription) WithMediaSource(ssrc uint32, cname, streamLabel, label string) *MediaDescription { + return d. + WithValueAttribute("ssrc", fmt.Sprintf("%d cname:%s", ssrc, cname)). // Deprecated but not phased out? + WithValueAttribute("ssrc", fmt.Sprintf("%d msid:%s %s", ssrc, streamLabel, label)). + WithValueAttribute("ssrc", fmt.Sprintf("%d mslabel:%s", ssrc, streamLabel)). // Deprecated but not phased out? + WithValueAttribute("ssrc", fmt.Sprintf("%d label:%s", ssrc, label)) // Deprecated but not phased out? +} + +// WithCandidate adds an ICE candidate to the media description. +// +// Deprecated: use WithICECandidate instead. +func (d *MediaDescription) WithCandidate(value string) *MediaDescription { + return d.WithValueAttribute("candidate", value) +} + +// WithExtMap adds an extmap to the media description. +func (d *MediaDescription) WithExtMap(e ExtMap) *MediaDescription { + return d.WithPropertyAttribute(e.Marshal()) +} + +// WithTransportCCExtMap adds an extmap to the media description. +func (d *MediaDescription) WithTransportCCExtMap() *MediaDescription { + uri, _ := url.Parse(extMapURI()[ExtMapValueTransportCC]) + e := ExtMap{ + Value: ExtMapValueTransportCC, + URI: uri, + } + + return d.WithExtMap(e) +} diff --git a/vendor/github.com/pion/sdp/v3/marshal.go b/vendor/github.com/pion/sdp/v3/marshal.go new file mode 100644 index 0000000..178184a --- /dev/null +++ b/vendor/github.com/pion/sdp/v3/marshal.go @@ -0,0 +1,242 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sdp + +// Marshal takes a SDP struct to text +// https://tools.ietf.org/html/rfc4566#section-5 +// Session description +// +// v= (protocol version) +// o= (originator and session identifier) +// s= (session name) +// i=* (session information) +// u=* (URI of description) +// e=* (email address) +// p=* (phone number) +// c=* (connection information -- not required if included in +// all media) +// b=* (zero or more bandwidth information lines) +// One or more time descriptions ("t=" and "r=" lines; see below) +// z=* (time zone adjustments) +// k=* (encryption key) +// a=* (zero or more session attribute lines) +// Zero or more media descriptions +// +// Time description +// +// t= (time the session is active) +// r=* (zero or more repeat times) +// +// Media description, if present +// +// m= (media name and transport address) +// i=* (media title) +// c=* (connection information -- optional if included at +// session level) +// b=* (zero or more bandwidth information lines) +// k=* (encryption key) +// a=* (zero or more media attribute lines) +func (s *SessionDescription) Marshal() ([]byte, error) { //nolint:cyclop + marsh := make(marshaller, 0, s.MarshalSize()) + + marsh.addKeyValue("v=", s.Version.marshalInto) + marsh.addKeyValue("o=", s.Origin.marshalInto) + marsh.addKeyValue("s=", s.SessionName.marshalInto) + + if s.SessionInformation != nil { + marsh.addKeyValue("i=", s.SessionInformation.marshalInto) + } + + if s.URI != nil { + marsh = append(marsh, "u="...) + marsh = append(marsh, s.URI.String()...) + marsh = append(marsh, "\r\n"...) + } + + if s.EmailAddress != nil { + marsh.addKeyValue("e=", s.EmailAddress.marshalInto) + } + + if s.PhoneNumber != nil { + marsh.addKeyValue("p=", s.PhoneNumber.marshalInto) + } + + if s.ConnectionInformation != nil { + marsh.addKeyValue("c=", s.ConnectionInformation.marshalInto) + } + + for _, b := range s.Bandwidth { + marsh.addKeyValue("b=", b.marshalInto) + } + + for _, td := range s.TimeDescriptions { + marsh.addKeyValue("t=", td.Timing.marshalInto) + for _, r := range td.RepeatTimes { + marsh.addKeyValue("r=", r.marshalInto) + } + } + + if len(s.TimeZones) > 0 { + marsh = append(marsh, "z="...) + for i, z := range s.TimeZones { + if i > 0 { + marsh = append(marsh, ' ') + } + marsh = z.marshalInto(marsh) + } + marsh = append(marsh, "\r\n"...) + } + + if s.EncryptionKey != nil { + marsh.addKeyValue("k=", s.EncryptionKey.marshalInto) + } + + for _, a := range s.Attributes { + marsh.addKeyValue("a=", a.marshalInto) + } + + for _, md := range s.MediaDescriptions { + marsh.addKeyValue("m=", md.MediaName.marshalInto) + + if md.MediaTitle != nil { + marsh.addKeyValue("i=", md.MediaTitle.marshalInto) + } + + if md.ConnectionInformation != nil { + marsh.addKeyValue("c=", md.ConnectionInformation.marshalInto) + } + + for _, b := range md.Bandwidth { + marsh.addKeyValue("b=", b.marshalInto) + } + + if md.EncryptionKey != nil { + marsh.addKeyValue("k=", md.EncryptionKey.marshalInto) + } + + for _, a := range md.Attributes { + marsh.addKeyValue("a=", a.marshalInto) + } + } + + return marsh, nil +} + +// `$type=` and CRLF size. +const lineBaseSize = 4 + +// MarshalSize returns the size of the SessionDescription once marshaled. +func (s *SessionDescription) MarshalSize() (marshalSize int) { //nolint:cyclop + marshalSize += lineBaseSize + s.Version.marshalSize() + marshalSize += lineBaseSize + s.Origin.marshalSize() + marshalSize += lineBaseSize + s.SessionName.marshalSize() + + if s.SessionInformation != nil { + marshalSize += lineBaseSize + s.SessionInformation.marshalSize() + } + + if s.URI != nil { + marshalSize += lineBaseSize + len(s.URI.String()) + } + + if s.EmailAddress != nil { + marshalSize += lineBaseSize + s.EmailAddress.marshalSize() + } + + if s.PhoneNumber != nil { + marshalSize += lineBaseSize + s.PhoneNumber.marshalSize() + } + + if s.ConnectionInformation != nil { + marshalSize += lineBaseSize + s.ConnectionInformation.marshalSize() + } + + for _, b := range s.Bandwidth { + marshalSize += lineBaseSize + b.marshalSize() + } + + for _, td := range s.TimeDescriptions { + marshalSize += lineBaseSize + td.Timing.marshalSize() + for _, r := range td.RepeatTimes { + marshalSize += lineBaseSize + r.marshalSize() + } + } + + if len(s.TimeZones) > 0 { + marshalSize += lineBaseSize + + for i, z := range s.TimeZones { + if i > 0 { + marshalSize++ + } + marshalSize += z.marshalSize() + } + } + + if s.EncryptionKey != nil { + marshalSize += lineBaseSize + s.EncryptionKey.marshalSize() + } + + for _, a := range s.Attributes { + marshalSize += lineBaseSize + a.marshalSize() + } + + for _, md := range s.MediaDescriptions { + marshalSize += lineBaseSize + md.MediaName.marshalSize() + if md.MediaTitle != nil { + marshalSize += lineBaseSize + md.MediaTitle.marshalSize() + } + if md.ConnectionInformation != nil { + marshalSize += lineBaseSize + md.ConnectionInformation.marshalSize() + } + + for _, b := range md.Bandwidth { + marshalSize += lineBaseSize + b.marshalSize() + } + + if md.EncryptionKey != nil { + marshalSize += lineBaseSize + md.EncryptionKey.marshalSize() + } + + for _, a := range md.Attributes { + marshalSize += lineBaseSize + a.marshalSize() + } + } + + return marshalSize +} + +// marshaller contains state during marshaling. +type marshaller []byte + +func (m *marshaller) addKeyValue(key string, value func([]byte) []byte) { + *m = append(*m, key...) + *m = value(*m) + *m = append(*m, "\r\n"...) +} + +func lenUint(i uint64) (count int) { + if i == 0 { + return 1 + } + + for i != 0 { + i /= 10 + count++ + } + + return +} + +func lenInt(i int64) (count int) { + if i < 0 { + return lenUint(uint64(-i)) + 1 + } + + return lenUint(uint64(i)) +} + +func stringFromMarshal(marshalFunc func([]byte) []byte, sizeFunc func() int) string { + return string(marshalFunc(make([]byte, 0, sizeFunc()))) +} diff --git a/vendor/github.com/pion/sdp/v3/media_description.go b/vendor/github.com/pion/sdp/v3/media_description.go new file mode 100644 index 0000000..a0014b7 --- /dev/null +++ b/vendor/github.com/pion/sdp/v3/media_description.go @@ -0,0 +1,132 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sdp + +import ( + "strconv" +) + +// MediaDescription represents a media type. +// https://tools.ietf.org/html/rfc4566#section-5.14 +type MediaDescription struct { + // m= / ... + // https://tools.ietf.org/html/rfc4566#section-5.14 + MediaName MediaName + + // i= + // https://tools.ietf.org/html/rfc4566#section-5.4 + MediaTitle *Information + + // c= + // https://tools.ietf.org/html/rfc4566#section-5.7 + ConnectionInformation *ConnectionInformation + + // b=: + // https://tools.ietf.org/html/rfc4566#section-5.8 + Bandwidth []Bandwidth + + // k= + // k=: + // https://tools.ietf.org/html/rfc4566#section-5.12 + EncryptionKey *EncryptionKey + + // a= + // a=: + // https://tools.ietf.org/html/rfc4566#section-5.13 + Attributes []Attribute +} + +// Attribute returns the value of an attribute and if it exists. +func (d *MediaDescription) Attribute(key string) (string, bool) { + for _, a := range d.Attributes { + if a.Key == key { + return a.Value, true + } + } + + return "", false +} + +// RangedPort supports special format for the media field "m=" port value. If +// it may be necessary to specify multiple transport ports, the protocol allows +// to write it as: / where number of ports is a an +// offsetting range. +type RangedPort struct { + Value int + Range *int +} + +func (p *RangedPort) String() string { + output := strconv.Itoa(p.Value) + if p.Range != nil { + output += "/" + strconv.Itoa(*p.Range) + } + + return output +} + +func (p RangedPort) marshalInto(b []byte) []byte { + b = strconv.AppendInt(b, int64(p.Value), 10) + if p.Range != nil { + b = append(b, '/') + b = strconv.AppendInt(b, int64(*p.Range), 10) + } + + return b +} + +func (p RangedPort) marshalSize() (size int) { + size = lenInt(int64(p.Value)) + if p.Range != nil { + size += 1 + lenInt(int64(*p.Range)) + } + + return +} + +// MediaName describes the "m=" field storage structure. +type MediaName struct { + Media string + Port RangedPort + Protos []string + Formats []string +} + +func (m MediaName) String() string { + return stringFromMarshal(m.marshalInto, m.marshalSize) +} + +func (m MediaName) marshalInto(b []byte) []byte { + appendList := func(list []string, sep byte) { + for i, p := range list { + if i != 0 && i != len(list) { + b = append(b, sep) + } + b = append(b, p...) + } + } + + b = append(append(b, m.Media...), ' ') + b = append(m.Port.marshalInto(b), ' ') + appendList(m.Protos, '/') + b = append(b, ' ') + appendList(m.Formats, ' ') + + return b +} + +func (m MediaName) marshalSize() (size int) { + listSize := func(list []string) { + for _, p := range list { + size += 1 + len(p) + } + } + + size = len(m.Media) + size += 1 + m.Port.marshalSize() + listSize(m.Protos) + listSize(m.Formats) + + return size +} diff --git a/vendor/github.com/pion/sdp/v3/renovate.json b/vendor/github.com/pion/sdp/v3/renovate.json new file mode 100644 index 0000000..f1bb98c --- /dev/null +++ b/vendor/github.com/pion/sdp/v3/renovate.json @@ -0,0 +1,6 @@ +{ + "$schema": "https://docs.renovatebot.com/renovate-schema.json", + "extends": [ + "github>pion/renovate-config" + ] +} diff --git a/vendor/github.com/pion/sdp/v3/sdp.go b/vendor/github.com/pion/sdp/v3/sdp.go new file mode 100644 index 0000000..9472e8a --- /dev/null +++ b/vendor/github.com/pion/sdp/v3/sdp.go @@ -0,0 +1,5 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +// Package sdp implements Session Description Protocol (SDP) +package sdp diff --git a/vendor/github.com/pion/sdp/v3/session_description.go b/vendor/github.com/pion/sdp/v3/session_description.go new file mode 100644 index 0000000..2a89386 --- /dev/null +++ b/vendor/github.com/pion/sdp/v3/session_description.go @@ -0,0 +1,204 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sdp + +import ( + "net/url" + "strconv" +) + +// SessionDescription is a a well-defined format for conveying sufficient +// information to discover and participate in a multimedia session. +type SessionDescription struct { + // v=0 + // https://tools.ietf.org/html/rfc4566#section-5.1 + Version Version + + // o= + // https://tools.ietf.org/html/rfc4566#section-5.2 + Origin Origin + + // s= + // https://tools.ietf.org/html/rfc4566#section-5.3 + SessionName SessionName + + // i= + // https://tools.ietf.org/html/rfc4566#section-5.4 + SessionInformation *Information + + // u= + // https://tools.ietf.org/html/rfc4566#section-5.5 + URI *url.URL + + // e= + // https://tools.ietf.org/html/rfc4566#section-5.6 + EmailAddress *EmailAddress + + // p= + // https://tools.ietf.org/html/rfc4566#section-5.6 + PhoneNumber *PhoneNumber + + // c= + // https://tools.ietf.org/html/rfc4566#section-5.7 + ConnectionInformation *ConnectionInformation + + // b=: + // https://tools.ietf.org/html/rfc4566#section-5.8 + Bandwidth []Bandwidth + + // https://tools.ietf.org/html/rfc4566#section-5.9 + // https://tools.ietf.org/html/rfc4566#section-5.10 + TimeDescriptions []TimeDescription + + // z= ... + // https://tools.ietf.org/html/rfc4566#section-5.11 + TimeZones []TimeZone + + // k= + // k=: + // https://tools.ietf.org/html/rfc4566#section-5.12 + EncryptionKey *EncryptionKey + + // a= + // a=: + // https://tools.ietf.org/html/rfc4566#section-5.13 + Attributes []Attribute + + // https://tools.ietf.org/html/rfc4566#section-5.14 + MediaDescriptions []*MediaDescription +} + +// Attribute returns the value of an attribute and if it exists. +func (s *SessionDescription) Attribute(key string) (string, bool) { + for _, a := range s.Attributes { + if a.Key == key { + return a.Value, true + } + } + + return "", false +} + +// Version describes the value provided by the "v=" field which gives +// the version of the Session Description Protocol. +type Version int + +func (v Version) String() string { + return stringFromMarshal(v.marshalInto, v.marshalSize) +} + +func (v Version) marshalInto(b []byte) []byte { + return strconv.AppendInt(b, int64(v), 10) +} + +func (v Version) marshalSize() (size int) { + return lenInt(int64(v)) +} + +// Origin defines the structure for the "o=" field which provides the +// originator of the session plus a session identifier and version number. +type Origin struct { + Username string + SessionID uint64 + SessionVersion uint64 + NetworkType string + AddressType string + UnicastAddress string +} + +func (o Origin) String() string { + return stringFromMarshal(o.marshalInto, o.marshalSize) +} + +func (o Origin) marshalInto(b []byte) []byte { + b = append(append(b, o.Username...), ' ') + b = append(strconv.AppendUint(b, o.SessionID, 10), ' ') + b = append(strconv.AppendUint(b, o.SessionVersion, 10), ' ') + b = append(append(b, o.NetworkType...), ' ') + b = append(append(b, o.AddressType...), ' ') + + return append(b, o.UnicastAddress...) +} + +func (o Origin) marshalSize() (size int) { + return len(o.Username) + + lenUint(o.SessionID) + + lenUint(o.SessionVersion) + + len(o.NetworkType) + + len(o.AddressType) + + len(o.UnicastAddress) + + 5 +} + +// SessionName describes a structured representations for the "s=" field +// and is the textual session name. +type SessionName string + +func (s SessionName) String() string { + return stringFromMarshal(s.marshalInto, s.marshalSize) +} + +func (s SessionName) marshalInto(b []byte) []byte { + return append(b, s...) +} + +func (s SessionName) marshalSize() (size int) { + return len(s) +} + +// EmailAddress describes a structured representations for the "e=" line +// which specifies email contact information for the person responsible for +// the conference. +type EmailAddress string + +func (e EmailAddress) String() string { + return stringFromMarshal(e.marshalInto, e.marshalSize) +} + +func (e EmailAddress) marshalInto(b []byte) []byte { + return append(b, e...) +} + +func (e EmailAddress) marshalSize() (size int) { + return len(e) +} + +// PhoneNumber describes a structured representations for the "p=" line +// specify phone contact information for the person responsible for the +// conference. +type PhoneNumber string + +func (p PhoneNumber) String() string { + return stringFromMarshal(p.marshalInto, p.marshalSize) +} + +func (p PhoneNumber) marshalInto(b []byte) []byte { + return append(b, p...) +} + +func (p PhoneNumber) marshalSize() (size int) { + return len(p) +} + +// TimeZone defines the structured object for "z=" line which describes +// repeated sessions scheduling. +type TimeZone struct { + AdjustmentTime uint64 + Offset int64 +} + +func (z TimeZone) String() string { + return stringFromMarshal(z.marshalInto, z.marshalSize) +} + +func (z TimeZone) marshalInto(b []byte) []byte { + b = strconv.AppendUint(b, z.AdjustmentTime, 10) + b = append(b, ' ') + + return strconv.AppendInt(b, z.Offset, 10) +} + +func (z TimeZone) marshalSize() (size int) { + return lenUint(z.AdjustmentTime) + 1 + lenInt(z.Offset) +} diff --git a/vendor/github.com/pion/sdp/v3/time_description.go b/vendor/github.com/pion/sdp/v3/time_description.go new file mode 100644 index 0000000..bc66498 --- /dev/null +++ b/vendor/github.com/pion/sdp/v3/time_description.go @@ -0,0 +1,76 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sdp + +import ( + "strconv" +) + +// TimeDescription describes "t=", "r=" fields of the session description +// which are used to specify the start and stop times for a session as well as +// repeat intervals and durations for the scheduled session. +type TimeDescription struct { + // t= + // https://tools.ietf.org/html/rfc4566#section-5.9 + Timing Timing + + // r= + // https://tools.ietf.org/html/rfc4566#section-5.10 + RepeatTimes []RepeatTime +} + +// Timing defines the "t=" field's structured representation for the start and +// stop times. +type Timing struct { + StartTime uint64 + StopTime uint64 +} + +func (t Timing) String() string { + return stringFromMarshal(t.marshalInto, t.marshalSize) +} + +func (t Timing) marshalInto(b []byte) []byte { + b = append(strconv.AppendUint(b, t.StartTime, 10), ' ') + + return strconv.AppendUint(b, t.StopTime, 10) +} + +func (t Timing) marshalSize() (size int) { + return lenUint(t.StartTime) + 1 + lenUint(t.StopTime) +} + +// RepeatTime describes the "r=" fields of the session description which +// represents the intervals and durations for repeated scheduled sessions. +type RepeatTime struct { + Interval int64 + Duration int64 + Offsets []int64 +} + +func (r RepeatTime) String() string { + return stringFromMarshal(r.marshalInto, r.marshalSize) +} + +func (r RepeatTime) marshalInto(b []byte) []byte { + b = strconv.AppendInt(b, r.Interval, 10) + b = append(b, ' ') + b = strconv.AppendInt(b, r.Duration, 10) + for _, value := range r.Offsets { + b = append(b, ' ') + b = strconv.AppendInt(b, value, 10) + } + + return b +} + +func (r RepeatTime) marshalSize() (size int) { + size = lenInt(r.Interval) + size += 1 + lenInt(r.Duration) + for _, o := range r.Offsets { + size += 1 + lenInt(o) + } + + return +} diff --git a/vendor/github.com/pion/sdp/v3/unmarshal.go b/vendor/github.com/pion/sdp/v3/unmarshal.go new file mode 100644 index 0000000..b423437 --- /dev/null +++ b/vendor/github.com/pion/sdp/v3/unmarshal.go @@ -0,0 +1,1050 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sdp + +import ( + "errors" + "fmt" + "net/url" + "strconv" + "strings" + "sync" +) + +var ( + errSDPInvalidSyntax = errors.New("sdp: invalid syntax") + errSDPInvalidNumericValue = errors.New("sdp: invalid numeric value") + errSDPInvalidValue = errors.New("sdp: invalid value") + errSDPInvalidPortValue = errors.New("sdp: invalid port value") + errSDPCacheInvalid = errors.New("sdp: invalid cache") + + //nolint: gochecknoglobals + unmarshalCachePool = sync.Pool{ + New: func() any { + return &unmarshalCache{} + }, + } +) + +// UnmarshalString is the primary function that deserializes the session description +// message and stores it inside of a structured SessionDescription object. +// +// The States Transition Table describes the computation flow between functions +// (namely s1, s2, s3, ...) for a parsing procedure that complies with the +// specifications laid out by the rfc4566#section-5 as well as by JavaScript +// Session Establishment Protocol draft. Links: +// +// https://tools.ietf.org/html/rfc4566#section-5 +// https://tools.ietf.org/html/draft-ietf-rtcweb-jsep-24 +// +// https://tools.ietf.org/html/rfc4566#section-5 +// Session description +// +// v= (protocol version) +// o= (originator and session identifier) +// s= (session name) +// i=* (session information) +// u=* (URI of description) +// e=* (email address) +// p=* (phone number) +// c=* (connection information -- not required if included in +// all media) +// b=* (zero or more bandwidth information lines) +// One or more time descriptions ("t=" and "r=" lines; see below) +// z=* (time zone adjustments) +// k=* (encryption key) +// a=* (zero or more session attribute lines) +// Zero or more media descriptions +// +// Time description +// +// t= (time the session is active) +// r=* (zero or more repeat times) +// +// Media description, if present +// +// m= (media name and transport address) +// i=* (media title) +// c=* (connection information -- optional if included at +// session level) +// b=* (zero or more bandwidth information lines) +// k=* (encryption key) +// a=* (zero or more media attribute lines) +// +// In order to generate the following state table and draw subsequent +// deterministic finite-state automota ("DFA") the following regex was used to +// derive the DFA: +// +// vosi?u?e?p?c?b*(tr*)+z?k?a*(mi?c?b*k?a*)* +// +// possible place and state to exit: +// +// ** * * * ** * * * * +// 99 1 1 1 11 1 1 1 1 +// 3 1 1 26 5 5 4 4 +// +// Please pay close attention to the `k`, and `a` parsing states. In the table +// below in order to distinguish between the states belonging to the media +// description as opposed to the session description, the states are marked +// with an asterisk ("a*", "k*"). +// +--------+----+-------+----+-----+----+-----+---+----+----+---+---+-----+---+---+----+---+----+ +// | STATES | a* | a*,k* | a | a,k | b | b,c | e | i | m | o | p | r,t | s | t | u | v | z | +// +--------+----+-------+----+-----+----+-----+---+----+----+---+---+-----+---+---+----+---+----+ +// | s1 | | | | | | | | | | | | | | | | 2 | | +// | s2 | | | | | | | | | | 3 | | | | | | | | +// | s3 | | | | | | | | | | | | | 4 | | | | | +// | s4 | | | | | | 5 | 6 | 7 | | | 8 | | | 9 | 10 | | | +// | s5 | | | | | 5 | | | | | | | | | 9 | | | | +// | s6 | | | | | | 5 | | | | | 8 | | | 9 | | | | +// | s7 | | | | | | 5 | 6 | | | | 8 | | | 9 | 10 | | | +// | s8 | | | | | | 5 | | | | | | | | 9 | | | | +// | s9 | | | | 11 | | | | | 12 | | | 9 | | | | | 13 | +// | s10 | | | | | | 5 | 6 | | | | 8 | | | 9 | | | | +// | s11 | | | 11 | | | | | | 12 | | | | | | | | | +// | s12 | | 14 | | | | 15 | | 16 | 12 | | | | | | | | | +// | s13 | | | | 11 | | | | | 12 | | | | | | | | | +// | s14 | 14 | | | | | | | | 12 | | | | | | | | | +// | s15 | | 14 | | | 15 | | | | 12 | | | | | | | | | +// | s16 | | 14 | | | | 15 | | | 12 | | | | | | | | | +// +--------+----+-------+----+-----+----+-----+---+----+----+---+---+-----+---+---+----+---+----+ . +func (s *SessionDescription) UnmarshalString(value string) error { + var ok bool + lex := new(lexer) + if lex.cache, ok = unmarshalCachePool.Get().(*unmarshalCache); !ok { + return errSDPCacheInvalid + } + defer unmarshalCachePool.Put(lex.cache) + + lex.cache.reset() + lex.desc = s + lex.value = value + + for state := s1; state != nil; { + var err error + state, err = state(lex) + if err != nil { + return err + } + } + + s.Attributes = lex.cache.cloneSessionAttributes() + populateMediaAttributes(lex.cache, lex.desc) + + return nil +} + +// Unmarshal converts the value into a []byte and then calls UnmarshalString. +// Callers should use the more performant UnmarshalString. +func (s *SessionDescription) Unmarshal(value []byte) error { + return s.UnmarshalString(string(value)) +} + +func s1(l *lexer) (stateFn, error) { + return l.handleType(func(key byte) stateFn { + if key == 'v' { + return unmarshalProtocolVersion + } + + return nil + }) +} + +func s2(l *lexer) (stateFn, error) { + return l.handleType(func(key byte) stateFn { + if key == 'o' { + return unmarshalOrigin + } + + return nil + }) +} + +func s3(l *lexer) (stateFn, error) { + return l.handleType(func(key byte) stateFn { + if key == 's' { + return unmarshalSessionName + } + + return nil + }) +} + +func s4(l *lexer) (stateFn, error) { + return l.handleType(func(key byte) stateFn { + switch key { + case 'i': + return unmarshalSessionInformation + case 'u': + return unmarshalURI + case 'e': + return unmarshalEmail + case 'p': + return unmarshalPhone + case 'c': + return unmarshalSessionConnectionInformation + case 'b': + return unmarshalSessionBandwidth + case 't': + return unmarshalTiming + } + + return nil + }) +} + +func s5(l *lexer) (stateFn, error) { + return l.handleType(func(key byte) stateFn { + switch key { + case 'b': + return unmarshalSessionBandwidth + case 't': + return unmarshalTiming + } + + return nil + }) +} + +func s6(l *lexer) (stateFn, error) { + return l.handleType(func(key byte) stateFn { + switch key { + case 'p': + return unmarshalPhone + case 'c': + return unmarshalSessionConnectionInformation + case 'b': + return unmarshalSessionBandwidth + case 't': + return unmarshalTiming + } + + return nil + }) +} + +func s7(l *lexer) (stateFn, error) { + return l.handleType(func(key byte) stateFn { + switch key { + case 'u': + return unmarshalURI + case 'e': + return unmarshalEmail + case 'p': + return unmarshalPhone + case 'c': + return unmarshalSessionConnectionInformation + case 'b': + return unmarshalSessionBandwidth + case 't': + return unmarshalTiming + } + + return nil + }) +} + +func s8(l *lexer) (stateFn, error) { + return l.handleType(func(key byte) stateFn { + switch key { + case 'c': + return unmarshalSessionConnectionInformation + case 'b': + return unmarshalSessionBandwidth + case 't': + return unmarshalTiming + } + + return nil + }) +} + +func s9(l *lexer) (stateFn, error) { + return l.handleType(func(key byte) stateFn { + switch key { + case 'z': + return unmarshalTimeZones + case 'k': + return unmarshalSessionEncryptionKey + case 'a': + return unmarshalSessionAttribute + case 'r': + return unmarshalRepeatTimes + case 't': + return unmarshalTiming + case 'm': + return unmarshalMediaDescription + } + + return nil + }) +} + +func s10(l *lexer) (stateFn, error) { + return l.handleType(func(key byte) stateFn { + switch key { + case 'e': + return unmarshalEmail + case 'p': + return unmarshalPhone + case 'c': + return unmarshalSessionConnectionInformation + case 'b': + return unmarshalSessionBandwidth + case 't': + return unmarshalTiming + } + + return nil + }) +} + +func s11(l *lexer) (stateFn, error) { + return l.handleType(func(key byte) stateFn { + switch key { + case 'a': + return unmarshalSessionAttribute + case 'm': + return unmarshalMediaDescription + } + + return nil + }) +} + +func s12(l *lexer) (stateFn, error) { + return l.handleType(func(key byte) stateFn { + switch key { + case 'a': + return unmarshalMediaAttribute + case 'k': + return unmarshalMediaEncryptionKey + case 'b': + return unmarshalMediaBandwidth + case 'c': + return unmarshalMediaConnectionInformation + case 'i': + return unmarshalMediaTitle + case 'm': + return unmarshalMediaDescription + } + + return nil + }) +} + +func s13(l *lexer) (stateFn, error) { + return l.handleType(func(key byte) stateFn { + switch key { + case 'a': + return unmarshalSessionAttribute + case 'k': + return unmarshalSessionEncryptionKey + case 'm': + return unmarshalMediaDescription + } + + return nil + }) +} + +func s14(l *lexer) (stateFn, error) { + return l.handleType(func(key byte) stateFn { + switch key { + case 'a': + return unmarshalMediaAttribute + case 'k': + // Non-spec ordering + return unmarshalMediaEncryptionKey + case 'b': + // Non-spec ordering + return unmarshalMediaBandwidth + case 'c': + // Non-spec ordering + return unmarshalMediaConnectionInformation + case 'i': + // Non-spec ordering + return unmarshalMediaTitle + case 'm': + return unmarshalMediaDescription + } + + return nil + }) +} + +func s15(l *lexer) (stateFn, error) { + return l.handleType(func(key byte) stateFn { + switch key { + case 'a': + return unmarshalMediaAttribute + case 'k': + return unmarshalMediaEncryptionKey + case 'b': + return unmarshalMediaBandwidth + case 'c': + return unmarshalMediaConnectionInformation + case 'i': + // Non-spec ordering + return unmarshalMediaTitle + case 'm': + return unmarshalMediaDescription + } + + return nil + }) +} + +func s16(l *lexer) (stateFn, error) { + return l.handleType(func(key byte) stateFn { + switch key { + case 'a': + return unmarshalMediaAttribute + case 'k': + return unmarshalMediaEncryptionKey + case 'c': + return unmarshalMediaConnectionInformation + case 'b': + return unmarshalMediaBandwidth + case 'i': + // Non-spec ordering + return unmarshalMediaTitle + case 'm': + return unmarshalMediaDescription + } + + return nil + }) +} + +func unmarshalProtocolVersion(l *lexer) (stateFn, error) { + version, err := l.readUint64Field() + if err != nil { + return nil, err + } + + // As off the latest draft of the rfc this value is required to be 0. + // https://tools.ietf.org/html/draft-ietf-rtcweb-jsep-24#section-5.8.1 + if version != 0 { + return nil, fmt.Errorf("%w `%v`", errSDPInvalidValue, version) + } + + if err := l.nextLine(); err != nil { + return nil, err + } + + return s2, nil +} + +func unmarshalOrigin(lex *lexer) (stateFn, error) { + var err error + + lex.desc.Origin.Username, err = lex.readField() + if err != nil { + return nil, err + } + + lex.desc.Origin.SessionID, err = lex.readUint64Field() + if err != nil { + return nil, err + } + + lex.desc.Origin.SessionVersion, err = lex.readUint64Field() + if err != nil { + return nil, err + } + + lex.desc.Origin.NetworkType, err = lex.readField() + if err != nil { + return nil, err + } + + // Set according to currently registered with IANA + // https://tools.ietf.org/html/rfc4566#section-8.2.6 + if !anyOf(lex.desc.Origin.NetworkType, "IN") { + return nil, fmt.Errorf("%w `%v`", errSDPInvalidValue, lex.desc.Origin.NetworkType) + } + + // Handle potentially missing AddressType field + err = handleAddressType(lex) + if err != nil { + return nil, err + } + + // Set according to currently registered with IANA + // https://tools.ietf.org/html/rfc4566#section-8.2.7 + if !anyOf(lex.desc.Origin.AddressType, "IP4", "IP6") { + return nil, fmt.Errorf("%w `%v`", errSDPInvalidValue, lex.desc.Origin.AddressType) + } + + // Handle potentially missing UnicastAddress field + err = handleUnicastAddress(lex) + if err != nil { + return nil, err + } + + if err := lex.nextLine(); err != nil { + return nil, err + } + + return s3, nil +} + +// handleAddressType processes AddressType field with graceful handling for missing fields. +func handleAddressType(lex *lexer) error { + addressType, err := lex.readRequiredField() + if err != nil { + if errors.Is(err, errFieldMissing) { + // Field missing - use defaults for camera compatibility + lex.desc.Origin.AddressType = "IP4" + lex.desc.Origin.UnicastAddress = "0.0.0.0" + + return nil + } + + return err + } + + lex.desc.Origin.AddressType = addressType + + return nil +} + +// handleUnicastAddress processes UnicastAddress field with graceful handling for missing fields. +func handleUnicastAddress(lex *lexer) error { + unicastAddress, err := lex.readRequiredField() + if err != nil { + if errors.Is(err, errFieldMissing) { + // Use appropriate default based on address type + if lex.desc.Origin.AddressType == "IP6" { + lex.desc.Origin.UnicastAddress = "::" + } else { + lex.desc.Origin.UnicastAddress = "0.0.0.0" + } + + return nil + } + + return err + } + + lex.desc.Origin.UnicastAddress = unicastAddress + + return nil +} + +func unmarshalSessionName(l *lexer) (stateFn, error) { + value, err := l.readLine() + if err != nil { + return nil, err + } + + l.desc.SessionName = SessionName(value) + + return s4, nil +} + +func unmarshalSessionInformation(l *lexer) (stateFn, error) { + value, err := l.readLine() + if err != nil { + return nil, err + } + + sessionInformation := Information(value) + l.desc.SessionInformation = &sessionInformation + + return s7, nil +} + +func unmarshalURI(l *lexer) (stateFn, error) { + value, err := l.readLine() + if err != nil { + return nil, err + } + + l.desc.URI, err = url.Parse(value) + if err != nil { + return nil, err + } + + return s10, nil +} + +func unmarshalEmail(l *lexer) (stateFn, error) { + value, err := l.readLine() + if err != nil { + return nil, err + } + + emailAddress := EmailAddress(value) + l.desc.EmailAddress = &emailAddress + + return s6, nil +} + +func unmarshalPhone(l *lexer) (stateFn, error) { + value, err := l.readLine() + if err != nil { + return nil, err + } + + phoneNumber := PhoneNumber(value) + l.desc.PhoneNumber = &phoneNumber + + return s8, nil +} + +func unmarshalSessionConnectionInformation(l *lexer) (stateFn, error) { + var err error + l.desc.ConnectionInformation, err = l.unmarshalConnectionInformation() + if err != nil { + return nil, err + } + + return s5, nil +} + +func (l *lexer) unmarshalConnectionInformation() (*ConnectionInformation, error) { + var err error + var connInfo ConnectionInformation + + connInfo.NetworkType, err = l.readField() + if err != nil { + return nil, err + } + + // Set according to currently registered with IANA + // https://tools.ietf.org/html/rfc4566#section-8.2.6 + if !anyOf(connInfo.NetworkType, "IN") { + return nil, fmt.Errorf("%w `%v`", errSDPInvalidValue, connInfo.NetworkType) + } + + connInfo.AddressType, err = l.readField() + if err != nil { + return nil, err + } + + // Set according to currently registered with IANA + // https://tools.ietf.org/html/rfc4566#section-8.2.7 + if !anyOf(connInfo.AddressType, "IP4", "IP6") { + return nil, fmt.Errorf("%w `%v`", errSDPInvalidValue, connInfo.AddressType) + } + + address, err := l.readField() + if err != nil { + return nil, err + } + + if address != "" { + connInfo.Address = new(Address) + connInfo.Address.Address = address + } + + if err := l.nextLine(); err != nil { + return nil, err + } + + return &connInfo, nil +} + +func unmarshalSessionBandwidth(l *lexer) (stateFn, error) { + value, err := l.readLine() + if err != nil { + return nil, err + } + + bandwidth, err := unmarshalBandwidth(value) + if err != nil { + return nil, fmt.Errorf("%w `b=%v`", errSDPInvalidValue, value) + } + l.desc.Bandwidth = append(l.desc.Bandwidth, *bandwidth) + + return s5, nil +} + +func unmarshalBandwidth(value string) (*Bandwidth, error) { + parts := strings.Split(value, ":") + if len(parts) != 2 { + return nil, fmt.Errorf("%w `b=%v`", errSDPInvalidValue, parts) + } + + experimental := strings.HasPrefix(parts[0], "X-") + if experimental { + parts[0] = strings.TrimPrefix(parts[0], "X-") + } else if !anyOf(parts[0], "CT", "AS", "TIAS", "RS", "RR") { + // Set according to currently registered with IANA + // https://tools.ietf.org/html/rfc4566#section-5.8 + // https://tools.ietf.org/html/rfc3890#section-6.2 + // https://tools.ietf.org/html/rfc3556#section-2 + return nil, fmt.Errorf("%w `%v`", errSDPInvalidValue, parts[0]) + } + + bandwidth, err := strconv.ParseUint(parts[1], 10, 64) + if err != nil { + return nil, fmt.Errorf("%w `%v`", errSDPInvalidNumericValue, parts[1]) + } + + return &Bandwidth{ + Experimental: experimental, + Type: parts[0], + Bandwidth: bandwidth, + }, nil +} + +func unmarshalTiming(lex *lexer) (stateFn, error) { + var err error + var td TimeDescription + + td.Timing.StartTime, err = lex.readUint64Field() + if err != nil { + return nil, err + } + + td.Timing.StopTime, err = lex.readUint64Field() + if err != nil { + return nil, err + } + + if err := lex.nextLine(); err != nil { + return nil, err + } + + lex.desc.TimeDescriptions = append(lex.desc.TimeDescriptions, td) + + return s9, nil +} + +func unmarshalRepeatTimes(lex *lexer) (stateFn, error) { + var err error + var newRepeatTime RepeatTime + + latestTimeDesc := &lex.desc.TimeDescriptions[len(lex.desc.TimeDescriptions)-1] + + field, err := lex.readField() + if err != nil { + return nil, err + } + + newRepeatTime.Interval, err = parseTimeUnits(field) + if err != nil { + return nil, fmt.Errorf("%w `%v`", errSDPInvalidValue, field) + } + + field, err = lex.readField() + if err != nil { + return nil, err + } + + newRepeatTime.Duration, err = parseTimeUnits(field) + if err != nil { + return nil, fmt.Errorf("%w `%v`", errSDPInvalidValue, field) + } + + for { + field, err := lex.readField() + if err != nil { + return nil, err + } + if field == "" { + break + } + offset, err := parseTimeUnits(field) + if err != nil { + return nil, fmt.Errorf("%w `%v`", errSDPInvalidValue, field) + } + newRepeatTime.Offsets = append(newRepeatTime.Offsets, offset) + } + + if err := lex.nextLine(); err != nil { + return nil, err + } + + latestTimeDesc.RepeatTimes = append(latestTimeDesc.RepeatTimes, newRepeatTime) + + return s9, nil +} + +func unmarshalTimeZones(lex *lexer) (stateFn, error) { + // These fields are transimitted in pairs + // z= .... + // so we are making sure that there are actually multiple of 2 total. + for { + var err error + var timeZone TimeZone + + timeZone.AdjustmentTime, err = lex.readUint64Field() + if err != nil { + return nil, err + } + + offset, err := lex.readField() + if err != nil { + return nil, err + } + + if offset == "" { + break + } + + timeZone.Offset, err = parseTimeUnits(offset) + if err != nil { + return nil, err + } + + lex.desc.TimeZones = append(lex.desc.TimeZones, timeZone) + } + + if err := lex.nextLine(); err != nil { + return nil, err + } + + return s13, nil +} + +func unmarshalSessionEncryptionKey(l *lexer) (stateFn, error) { + value, err := l.readLine() + if err != nil { + return nil, err + } + + encryptionKey := EncryptionKey(value) + l.desc.EncryptionKey = &encryptionKey + + return s11, nil +} + +func unmarshalSessionAttribute(l *lexer) (stateFn, error) { + value, err := l.readLine() + if err != nil { + return nil, err + } + + i := strings.IndexRune(value, ':') + a := l.cache.getSessionAttribute() + if i > 0 { + a.Key = value[:i] + a.Value = value[i+1:] + } else { + a.Key = value + } + + return s11, nil +} + +func unmarshalMediaDescription(lex *lexer) (stateFn, error) { //nolint:cyclop + populateMediaAttributes(lex.cache, lex.desc) + var newMediaDesc MediaDescription + + // + field, err := lex.readField() + if err != nil { + return nil, err + } + + // Set according to currently registered with IANA + // https://tools.ietf.org/html/rfc4566#section-5.14 + if !anyOf(field, "audio", "video", "text", "application", "message") { + return nil, fmt.Errorf("%w `%v`", errSDPInvalidValue, field) + } + newMediaDesc.MediaName.Media = field + + // + field, err = lex.readField() + if err != nil { + return nil, err + } + parts := strings.Split(field, "/") + newMediaDesc.MediaName.Port.Value, err = parsePort(parts[0]) + if err != nil { + return nil, fmt.Errorf("%w `%v`", errSDPInvalidPortValue, parts[0]) + } + + if len(parts) > 1 { + var portRange int + portRange, err = strconv.Atoi(parts[1]) + if err != nil { + return nil, fmt.Errorf("%w `%v`", errSDPInvalidValue, parts) + } + newMediaDesc.MediaName.Port.Range = &portRange + } + + // + field, err = lex.readField() + if err != nil { + return nil, err + } + + // Set according to currently registered with IANA + // https://tools.ietf.org/html/rfc4566#section-5.14 + // https://tools.ietf.org/html/rfc4975#section-8.1 + for _, proto := range strings.Split(field, "/") { + if !anyOf( + proto, + "UDP", + "RTP", + "AVP", + "SAVP", + "SAVPF", + "TLS", + "DTLS", + "SCTP", + "AVPF", + "TCP", + "MSRP", + "BFCP", + "UDT", + "IX", + "MRCPv2", + "FEC", + ) { + return nil, fmt.Errorf("%w `%v`", errSDPInvalidNumericValue, field) + } + newMediaDesc.MediaName.Protos = append(newMediaDesc.MediaName.Protos, proto) + } + + // ... + for { + field, err = lex.readField() + if err != nil { + return nil, err + } + if field == "" { + break + } + newMediaDesc.MediaName.Formats = append(newMediaDesc.MediaName.Formats, field) + } + + if err := lex.nextLine(); err != nil { + return nil, err + } + + lex.desc.MediaDescriptions = append(lex.desc.MediaDescriptions, &newMediaDesc) + + return s12, nil +} + +func unmarshalMediaTitle(l *lexer) (stateFn, error) { + value, err := l.readLine() + if err != nil { + return nil, err + } + + latestMediaDesc := l.desc.MediaDescriptions[len(l.desc.MediaDescriptions)-1] + mediaTitle := Information(value) + latestMediaDesc.MediaTitle = &mediaTitle + + return s16, nil +} + +func unmarshalMediaConnectionInformation(l *lexer) (stateFn, error) { + var err error + latestMediaDesc := l.desc.MediaDescriptions[len(l.desc.MediaDescriptions)-1] + latestMediaDesc.ConnectionInformation, err = l.unmarshalConnectionInformation() + if err != nil { + return nil, err + } + + return s15, nil +} + +func unmarshalMediaBandwidth(l *lexer) (stateFn, error) { + value, err := l.readLine() + if err != nil { + return nil, err + } + + latestMediaDesc := l.desc.MediaDescriptions[len(l.desc.MediaDescriptions)-1] + bandwidth, err := unmarshalBandwidth(value) + if err != nil { + return nil, fmt.Errorf("%w `b=%v`", errSDPInvalidSyntax, value) + } + latestMediaDesc.Bandwidth = append(latestMediaDesc.Bandwidth, *bandwidth) + + return s15, nil +} + +func unmarshalMediaEncryptionKey(l *lexer) (stateFn, error) { + value, err := l.readLine() + if err != nil { + return nil, err + } + + latestMediaDesc := l.desc.MediaDescriptions[len(l.desc.MediaDescriptions)-1] + encryptionKey := EncryptionKey(value) + latestMediaDesc.EncryptionKey = &encryptionKey + + return s14, nil +} + +func unmarshalMediaAttribute(l *lexer) (stateFn, error) { + value, err := l.readLine() + if err != nil { + return nil, err + } + + i := strings.IndexRune(value, ':') + a := l.cache.getMediaAttribute() + if i > 0 { + a.Key = value[:i] + a.Value = value[i+1:] + } else { + a.Key = value + } + + return s14, nil +} + +func parseTimeUnits(value string) (num int64, err error) { + if len(value) == 0 { + return 0, fmt.Errorf("%w `%v`", errSDPInvalidValue, value) + } + k := timeShorthand(value[len(value)-1]) + if k > 0 { + num, err = strconv.ParseInt(value[:len(value)-1], 10, 64) + } else { + k = 1 + num, err = strconv.ParseInt(value, 10, 64) + } + if err != nil { + return 0, fmt.Errorf("%w `%v`", errSDPInvalidValue, value) + } + + return num * k, nil +} + +func timeShorthand(b byte) int64 { + // Some time offsets in the protocol can be provided with a shorthand + // notation. This code ensures to convert it to NTP timestamp format. + switch b { + case 'd': // days + return 86400 + case 'h': // hours + return 3600 + case 'm': // minutes + return 60 + case 's': // seconds (allowed for completeness) + return 1 + default: + return 0 + } +} + +func parsePort(value string) (int, error) { + port, err := strconv.Atoi(value) + if err != nil { + return 0, fmt.Errorf("%w `%v`", errSDPInvalidPortValue, value) + } + + if port < 0 || port > 65535 { + return 0, fmt.Errorf("%w -- out of range `%v`", errSDPInvalidPortValue, port) + } + + return port, nil +} + +func populateMediaAttributes(c *unmarshalCache, s *SessionDescription) { + if len(s.MediaDescriptions) != 0 { + lastMediaDesc := s.MediaDescriptions[len(s.MediaDescriptions)-1] + lastMediaDesc.Attributes = c.cloneMediaAttributes() + } +} diff --git a/vendor/github.com/pion/sdp/v3/unmarshal_cache.go b/vendor/github.com/pion/sdp/v3/unmarshal_cache.go new file mode 100644 index 0000000..3d3e59a --- /dev/null +++ b/vendor/github.com/pion/sdp/v3/unmarshal_cache.go @@ -0,0 +1,48 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sdp + +type unmarshalCache struct { + sessionAttributes []Attribute + mediaAttributes []Attribute +} + +func (c *unmarshalCache) reset() { + c.sessionAttributes = c.sessionAttributes[:0] + c.mediaAttributes = c.mediaAttributes[:0] +} + +func (c *unmarshalCache) getSessionAttribute() *Attribute { + c.sessionAttributes = append(c.sessionAttributes, Attribute{}) + + return &c.sessionAttributes[len(c.sessionAttributes)-1] +} + +func (c *unmarshalCache) cloneSessionAttributes() []Attribute { + if len(c.sessionAttributes) == 0 { + return nil + } + s := make([]Attribute, len(c.sessionAttributes)) + copy(s, c.sessionAttributes) + c.sessionAttributes = c.sessionAttributes[:0] + + return s +} + +func (c *unmarshalCache) getMediaAttribute() *Attribute { + c.mediaAttributes = append(c.mediaAttributes, Attribute{}) + + return &c.mediaAttributes[len(c.mediaAttributes)-1] +} + +func (c *unmarshalCache) cloneMediaAttributes() []Attribute { + if len(c.mediaAttributes) == 0 { + return nil + } + s := make([]Attribute, len(c.mediaAttributes)) + copy(s, c.mediaAttributes) + c.mediaAttributes = c.mediaAttributes[:0] + + return s +} diff --git a/vendor/github.com/pion/sdp/v3/util.go b/vendor/github.com/pion/sdp/v3/util.go new file mode 100644 index 0000000..b6f6a98 --- /dev/null +++ b/vendor/github.com/pion/sdp/v3/util.go @@ -0,0 +1,390 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package sdp + +import ( + "errors" + "fmt" + "io" + "slices" + "sort" + "strconv" + "strings" + + "github.com/pion/randutil" +) + +const ( + attributeKey = "a=" +) + +var ( + errExtractCodecRtpmap = errors.New("could not extract codec from rtpmap") + errExtractCodecFmtp = errors.New("could not extract codec from fmtp") + errExtractCodecRtcpFb = errors.New("could not extract codec from rtcp-fb") + errPayloadTypeNotFound = errors.New("payload type not found") + errCodecNotFound = errors.New("codec not found") + errSyntaxError = errors.New("SyntaxError") + errFieldMissing = errors.New("field missing") +) + +// ConnectionRole indicates which of the end points should initiate the connection establishment. +type ConnectionRole int + +const ( + // ConnectionRoleActive indicates the endpoint will initiate an outgoing connection. + ConnectionRoleActive ConnectionRole = iota + 1 + + // ConnectionRolePassive indicates the endpoint will accept an incoming connection. + ConnectionRolePassive + + // ConnectionRoleActpass indicates the endpoint is willing to accept an incoming connection or + // to initiate an outgoing connection. + ConnectionRoleActpass + + // ConnectionRoleHoldconn indicates the endpoint does not want the connection to be established for the time being. + ConnectionRoleHoldconn +) + +func (t ConnectionRole) String() string { + switch t { + case ConnectionRoleActive: + return "active" + case ConnectionRolePassive: + return "passive" + case ConnectionRoleActpass: + return "actpass" + case ConnectionRoleHoldconn: + return "holdconn" + default: + return "Unknown" + } +} + +func newSessionID() (uint64, error) { + // https://tools.ietf.org/html/draft-ietf-rtcweb-jsep-26#section-5.2.1 + // Session ID is recommended to be constructed by generating a 64-bit + // quantity with the highest bit set to zero and the remaining 63-bits + // being cryptographically random. + id, err := randutil.CryptoUint64() + + return id & (^(uint64(1) << 63)), err +} + +// Codec represents a codec. +type Codec struct { + PayloadType uint8 + Name string + ClockRate uint32 + EncodingParameters string + Fmtp string + RTCPFeedback []string +} + +const ( + unknown = iota +) + +func (c Codec) String() string { + return fmt.Sprintf( + "%d %s/%d/%s (%s) [%s]", + c.PayloadType, + c.Name, + c.ClockRate, + c.EncodingParameters, + c.Fmtp, + strings.Join(c.RTCPFeedback, ", "), + ) +} + +func (c *Codec) appendRTCPFeedback(rtcpFeedback string) { + if slices.Contains(c.RTCPFeedback, rtcpFeedback) { + return + } + + c.RTCPFeedback = append(c.RTCPFeedback, rtcpFeedback) +} + +func parseRtpmap(rtpmap string) (Codec, error) { + var codec Codec + parsingFailed := errExtractCodecRtpmap + + // a=rtpmap: /[/] + split := strings.Split(rtpmap, " ") + if len(split) != 2 { + return codec, parsingFailed + } + + ptSplit := strings.Split(split[0], ":") + if len(ptSplit) != 2 { + return codec, parsingFailed + } + + ptInt, err := strconv.ParseUint(ptSplit[1], 10, 8) + if err != nil { + return codec, parsingFailed + } + + codec.PayloadType = uint8(ptInt) + + split = strings.Split(split[1], "/") + codec.Name = split[0] + parts := len(split) + if parts > 1 { + rate, err := strconv.ParseUint(split[1], 10, 32) + if err != nil { + return codec, parsingFailed + } + codec.ClockRate = uint32(rate) + } + if parts > 2 { + codec.EncodingParameters = split[2] + } + + return codec, nil +} + +func parseFmtp(fmtp string) (Codec, error) { + var codec Codec + parsingFailed := errExtractCodecFmtp + + // a=fmtp: + split := strings.SplitN(fmtp, " ", 2) + if len(split) != 2 { + return codec, parsingFailed + } + + formatParams := split[1] + + split = strings.Split(split[0], ":") + if len(split) != 2 { + return codec, parsingFailed + } + + ptInt, err := strconv.ParseUint(split[1], 10, 8) + if err != nil { + return codec, parsingFailed + } + + codec.PayloadType = uint8(ptInt) + codec.Fmtp = formatParams + + return codec, nil +} + +func parseRtcpFb(rtcpFb string) (codec Codec, isWildcard bool, err error) { + var ptInt uint64 + err = errExtractCodecRtcpFb + + // a=ftcp-fb: [] + split := strings.SplitN(rtcpFb, " ", 2) + if len(split) != 2 { + return + } + + ptSplit := strings.Split(split[0], ":") + if len(ptSplit) != 2 { + return + } + + isWildcard = ptSplit[1] == "*" + if !isWildcard { + ptInt, err = strconv.ParseUint(ptSplit[1], 10, 8) + if err != nil { + return + } + + codec.PayloadType = uint8(ptInt) + } + + codec.RTCPFeedback = append(codec.RTCPFeedback, split[1]) + + return codec, isWildcard, nil +} + +func mergeCodecs(codec Codec, codecs map[uint8]Codec) { + savedCodec := codecs[codec.PayloadType] + + if savedCodec.PayloadType == 0 { + savedCodec.PayloadType = codec.PayloadType + } + if savedCodec.Name == "" { + savedCodec.Name = codec.Name + } + if savedCodec.ClockRate == 0 { + savedCodec.ClockRate = codec.ClockRate + } + if savedCodec.EncodingParameters == "" { + savedCodec.EncodingParameters = codec.EncodingParameters + } + if savedCodec.Fmtp == "" { + savedCodec.Fmtp = codec.Fmtp + } + savedCodec.RTCPFeedback = append(savedCodec.RTCPFeedback, codec.RTCPFeedback...) + + codecs[savedCodec.PayloadType] = savedCodec +} + +func (s *SessionDescription) buildCodecMap() map[uint8]Codec { //nolint:cyclop + codecs := map[uint8]Codec{ + // static codecs that do not require a rtpmap + 0: { + PayloadType: 0, + Name: "PCMU", + ClockRate: 8000, + }, + 8: { + PayloadType: 8, + Name: "PCMA", + ClockRate: 8000, + }, + 9: { + PayloadType: 9, + Name: "G722", + ClockRate: 8000, + }, + } + + wildcardRTCPFeedback := []string{} + for _, m := range s.MediaDescriptions { + for _, a := range m.Attributes { + attr := a.String() + switch { + case strings.HasPrefix(attr, "rtpmap:"): + codec, err := parseRtpmap(attr) + if err == nil { + mergeCodecs(codec, codecs) + } + case strings.HasPrefix(attr, "fmtp:"): + codec, err := parseFmtp(attr) + if err == nil { + mergeCodecs(codec, codecs) + } + case strings.HasPrefix(attr, "rtcp-fb:"): + codec, isWildcard, err := parseRtcpFb(attr) + switch { + case err != nil: + case isWildcard: + wildcardRTCPFeedback = append(wildcardRTCPFeedback, codec.RTCPFeedback...) + default: + mergeCodecs(codec, codecs) + } + } + } + } + + for i, codec := range codecs { + for _, newRTCPFeedback := range wildcardRTCPFeedback { + codec.appendRTCPFeedback(newRTCPFeedback) + } + + codecs[i] = codec + } + + return codecs +} + +func equivalentFmtp(want, got string) bool { + wantSplit := strings.Split(want, ";") + gotSplit := strings.Split(got, ";") + + if len(wantSplit) != len(gotSplit) { + return false + } + + sort.Strings(wantSplit) + sort.Strings(gotSplit) + + for i, wantPart := range wantSplit { + wantPart = strings.TrimSpace(wantPart) + gotPart := strings.TrimSpace(gotSplit[i]) + if gotPart != wantPart { + return false + } + } + + return true +} + +func codecsMatch(wanted, got Codec) bool { + if wanted.Name != "" && !strings.EqualFold(wanted.Name, got.Name) { + return false + } + if wanted.ClockRate != 0 && wanted.ClockRate != got.ClockRate { + return false + } + if wanted.EncodingParameters != "" && wanted.EncodingParameters != got.EncodingParameters { + return false + } + if wanted.Fmtp != "" && !equivalentFmtp(wanted.Fmtp, got.Fmtp) { + return false + } + + return true +} + +// GetCodecForPayloadType scans the SessionDescription for the given payload type and returns the codec. +func (s *SessionDescription) GetCodecForPayloadType(payloadType uint8) (Codec, error) { + codecs := s.buildCodecMap() + + codec, ok := codecs[payloadType] + if ok { + return codec, nil + } + + return codec, errPayloadTypeNotFound +} + +func (s *SessionDescription) GetCodecsForPayloadTypes(payloadTypes []uint8) ([]Codec, error) { + codecs := s.buildCodecMap() + + result := make([]Codec, 0, len(payloadTypes)) + for _, payloadType := range payloadTypes { + codec, ok := codecs[payloadType] + if ok { + result = append(result, codec) + } + } + + return result, nil +} + +// GetPayloadTypeForCodec scans the SessionDescription for a codec that matches the provided codec +// as closely as possible and returns its payload type. +func (s *SessionDescription) GetPayloadTypeForCodec(wanted Codec) (uint8, error) { + codecs := s.buildCodecMap() + + for payloadType, codec := range codecs { + if codecsMatch(wanted, codec) { + return payloadType, nil + } + } + + return 0, errCodecNotFound +} + +type stateFn func(*lexer) (stateFn, error) + +type lexer struct { + desc *SessionDescription + cache *unmarshalCache + baseLexer +} + +type keyToState func(key byte) stateFn + +func (l *lexer) handleType(fn keyToState) (stateFn, error) { + key, err := l.readType() + if errors.Is(err, io.EOF) && key == 0 { + return nil, nil //nolint:nilnil + } else if err != nil { + return nil, err + } + + if res := fn(key); res != nil { + return res, nil + } + + return nil, l.syntaxError() +} diff --git a/vendor/github.com/pion/srtp/v3/.gitignore b/vendor/github.com/pion/srtp/v3/.gitignore new file mode 100644 index 0000000..6e2f206 --- /dev/null +++ b/vendor/github.com/pion/srtp/v3/.gitignore @@ -0,0 +1,28 @@ +# SPDX-FileCopyrightText: 2023 The Pion community +# SPDX-License-Identifier: MIT + +### JetBrains IDE ### +##################### +.idea/ + +### Emacs Temporary Files ### +############################# +*~ + +### Folders ### +############### +bin/ +vendor/ +node_modules/ + +### Files ### +############# +*.ivf +*.ogg +tags +cover.out +*.sw[poe] +*.wasm +examples/sfu-ws/cert.pem +examples/sfu-ws/key.pem +wasm_exec.js diff --git a/vendor/github.com/pion/srtp/v3/.golangci.yml b/vendor/github.com/pion/srtp/v3/.golangci.yml new file mode 100644 index 0000000..4b4025f --- /dev/null +++ b/vendor/github.com/pion/srtp/v3/.golangci.yml @@ -0,0 +1,147 @@ +# SPDX-FileCopyrightText: 2023 The Pion community +# SPDX-License-Identifier: MIT + +version: "2" +linters: + enable: + - asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers + - bidichk # Checks for dangerous unicode character sequences + - bodyclose # checks whether HTTP response body is closed successfully + - containedctx # containedctx is a linter that detects struct contained context.Context field + - contextcheck # check the function whether use a non-inherited context + - cyclop # checks function and package cyclomatic complexity + - decorder # check declaration order and count of types, constants, variables and functions + - dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f()) + - dupl # Tool for code clone detection + - durationcheck # check for two durations multiplied together + - err113 # Golang linter to check the errors handling expressions + - errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases + - errchkjson # Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occations, where the check for the returned error can be omitted. + - errname # Checks that sentinel errors are prefixed with the `Err` and error types are suffixed with the `Error`. + - errorlint # errorlint is a linter for that can be used to find code that will cause problems with the error wrapping scheme introduced in Go 1.13. + - exhaustive # check exhaustiveness of enum switch statements + - forbidigo # Forbids identifiers + - forcetypeassert # finds forced type assertions + - gochecknoglobals # Checks that no globals are present in Go code + - gocognit # Computes and checks the cognitive complexity of functions + - goconst # Finds repeated strings that could be replaced by a constant + - gocritic # The most opinionated Go source code linter + - gocyclo # Computes and checks the cyclomatic complexity of functions + - godot # Check if comments end in a period + - godox # Tool for detection of FIXME, TODO and other comment keywords + - goheader # Checks is file header matches to pattern + - gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod. + - goprintffuncname # Checks that printf-like functions are named with `f` at the end + - gosec # Inspects source code for security problems + - govet # Vet examines Go source code and reports suspicious constructs, such as Printf calls whose arguments do not align with the format string + - grouper # An analyzer to analyze expression groups. + - importas # Enforces consistent import aliases + - ineffassign # Detects when assignments to existing variables are not used + - lll # Reports long lines + - maintidx # maintidx measures the maintainability index of each function. + - makezero # Finds slice declarations with non-zero initial length + - misspell # Finds commonly misspelled English words in comments + - nakedret # Finds naked returns in functions greater than a specified function length + - nestif # Reports deeply nested if statements + - nilerr # Finds the code that returns nil even if it checks that the error is not nil. + - nilnil # Checks that there is no simultaneous return of `nil` error and an invalid value. + - nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity + - noctx # noctx finds sending http request without context.Context + - predeclared # find code that shadows one of Go's predeclared identifiers + - revive # golint replacement, finds style mistakes + - staticcheck # Staticcheck is a go vet on steroids, applying a ton of static analysis checks + - tagliatelle # Checks the struct tags. + - thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers + - unconvert # Remove unnecessary type conversions + - unparam # Reports unused function parameters + - unused # Checks Go code for unused constants, variables, functions and types + - varnamelen # checks that the length of a variable's name matches its scope + - wastedassign # wastedassign finds wasted assignment statements + - whitespace # Tool for detection of leading and trailing whitespace + disable: + - depguard # Go linter that checks if package imports are in a list of acceptable packages + - funlen # Tool for detection of long functions + - gochecknoinits # Checks that no init functions are present in Go code + - gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations. + - interfacebloat # A linter that checks length of interface. + - ireturn # Accept Interfaces, Return Concrete Types + - mnd # An analyzer to detect magic numbers + - nolintlint # Reports ill-formed or insufficient nolint directives + - paralleltest # paralleltest detects missing usage of t.Parallel() method in your Go test + - prealloc # Finds slice declarations that could potentially be preallocated + - promlinter # Check Prometheus metrics naming via promlint + - rowserrcheck # checks whether Err of rows is checked successfully + - sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed. + - testpackage # linter that makes you use a separate _test package + - tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes + - wrapcheck # Checks that errors returned from external packages are wrapped + - wsl # Whitespace Linter - Forces you to use empty lines! + settings: + staticcheck: + checks: + - all + - -QF1008 # "could remove embedded field", to keep it explicit! + - -QF1003 # "could use tagged switch on enum", Cases conflicts with exhaustive! + exhaustive: + default-signifies-exhaustive: true + forbidigo: + forbid: + - pattern: ^fmt.Print(f|ln)?$ + - pattern: ^log.(Panic|Fatal|Print)(f|ln)?$ + - pattern: ^os.Exit$ + - pattern: ^panic$ + - pattern: ^print(ln)?$ + - pattern: ^testing.T.(Error|Errorf|Fatal|Fatalf|Fail|FailNow)$ + pkg: ^testing$ + msg: use testify/assert instead + analyze-types: true + gomodguard: + blocked: + modules: + - github.com/pkg/errors: + recommendations: + - errors + govet: + enable: + - shadow + revive: + rules: + # Prefer 'any' type alias over 'interface{}' for Go 1.18+ compatibility + - name: use-any + severity: warning + disabled: false + misspell: + locale: US + varnamelen: + max-distance: 12 + min-name-length: 2 + ignore-type-assert-ok: true + ignore-map-index-ok: true + ignore-chan-recv-ok: true + ignore-decls: + - i int + - n int + - w io.Writer + - r io.Reader + - b []byte + exclusions: + generated: lax + rules: + - linters: + - forbidigo + - gocognit + path: (examples|main\.go) + - linters: + - gocognit + path: _test\.go + - linters: + - forbidigo + path: cmd +formatters: + enable: + - gci # Gci control golang package import order and make it always deterministic. + - gofmt # Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification + - gofumpt # Gofumpt checks whether code was gofumpt-ed. + - goimports # Goimports does everything that gofmt does. Additionally it checks unused imports + exclusions: + generated: lax diff --git a/vendor/github.com/pion/srtp/v3/.goreleaser.yml b/vendor/github.com/pion/srtp/v3/.goreleaser.yml new file mode 100644 index 0000000..30093e9 --- /dev/null +++ b/vendor/github.com/pion/srtp/v3/.goreleaser.yml @@ -0,0 +1,5 @@ +# SPDX-FileCopyrightText: 2023 The Pion community +# SPDX-License-Identifier: MIT + +builds: +- skip: true diff --git a/vendor/github.com/pion/srtp/v3/LICENSE b/vendor/github.com/pion/srtp/v3/LICENSE new file mode 100644 index 0000000..491caf6 --- /dev/null +++ b/vendor/github.com/pion/srtp/v3/LICENSE @@ -0,0 +1,9 @@ +MIT License + +Copyright (c) 2023 The Pion community + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/vendor/github.com/pion/srtp/v3/README.md b/vendor/github.com/pion/srtp/v3/README.md new file mode 100644 index 0000000..10ede23 --- /dev/null +++ b/vendor/github.com/pion/srtp/v3/README.md @@ -0,0 +1,35 @@ +

+
+ Pion SRTP +
+

+

A Go implementation of SRTP

+

+ Pion SRTP + Sourcegraph Widget + join us on Discord Follow us on Bluesky +
+ GitHub Workflow Status + Go Reference + Coverage Status + Go Report Card + License: MIT +

+
+ +### Roadmap +The library is used as a part of our WebRTC implementation. Please refer to that [roadmap](https://github.com/pion/webrtc/issues/9) to track our major milestones. + +### Community +Pion has an active community on the [Discord](https://discord.gg/PngbdqpFbt). + +Follow the [Pion Bluesky](https://bsky.app/profile/pion.ly) or [Pion Twitter](https://twitter.com/_pion) for project updates and important WebRTC news. + +We are always looking to support **your projects**. Please reach out if you have something to build! +If you need commercial support or don't want to use public methods you can contact us at [team@pion.ly](mailto:team@pion.ly) + +### Contributing +Check out the [contributing wiki](https://github.com/pion/webrtc/wiki/Contributing) to join the group of amazing people making this project possible + +### License +MIT License - see [LICENSE](LICENSE) for full text diff --git a/vendor/github.com/pion/srtp/v3/codecov.yml b/vendor/github.com/pion/srtp/v3/codecov.yml new file mode 100644 index 0000000..263e4d4 --- /dev/null +++ b/vendor/github.com/pion/srtp/v3/codecov.yml @@ -0,0 +1,22 @@ +# +# DO NOT EDIT THIS FILE +# +# It is automatically copied from https://github.com/pion/.goassets repository. +# +# SPDX-FileCopyrightText: 2023 The Pion community +# SPDX-License-Identifier: MIT + +coverage: + status: + project: + default: + # Allow decreasing 2% of total coverage to avoid noise. + threshold: 2% + patch: + default: + target: 70% + only_pulls: true + +ignore: + - "examples/*" + - "examples/**/*" diff --git a/vendor/github.com/pion/srtp/v3/context.go b/vendor/github.com/pion/srtp/v3/context.go new file mode 100644 index 0000000..957f329 --- /dev/null +++ b/vendor/github.com/pion/srtp/v3/context.go @@ -0,0 +1,418 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package srtp + +import ( + "bytes" + "fmt" + + "github.com/pion/transport/v4/replaydetector" +) + +const ( + labelSRTPEncryption = 0x00 + labelSRTPAuthenticationTag = 0x01 + labelSRTPSalt = 0x02 + + labelSRTCPEncryption = 0x03 + labelSRTCPAuthenticationTag = 0x04 + labelSRTCPSalt = 0x05 + + maxSequenceNumber = 65535 + maxROC = (1 << 32) - 1 + + seqNumMedian = 1 << 15 + seqNumMax = 1 << 16 +) + +// Encrypt/Decrypt state for a single SRTP SSRC. +type srtpSSRCState struct { + ssrc uint32 + rolloverHasProcessed bool + index uint64 + replayDetector replaydetector.ReplayDetector +} + +// Encrypt/Decrypt state for a single SRTCP SSRC. +type srtcpSSRCState struct { + srtcpIndex uint32 + ssrc uint32 + replayDetector replaydetector.ReplayDetector +} + +// RCCMode is the mode of Roll-over Counter Carrying Transform from RFC 4771. +type RCCMode int + +const ( + // RCCModeNone is the default mode. + RCCModeNone RCCMode = iota + // RCCMode1 is RCCm1 mode from RFC 4771. In this mode ROC and truncated auth tag is sent every R-th packet, + // and no auth tag in other ones. This mode is not supported by pion/srtp. + RCCMode1 + // RCCMode2 is RCCm2 mode from RFC 4771. In this mode ROC and truncated auth tag is sent every R-th packet, + // and full auth tag in other ones. This mode is supported for AES-CM and NULL profiles only. + RCCMode2 + // RCCMode3 is RCCm3 mode from RFC 4771. In this mode ROC is sent every R-th packet (without truncated auth tag), + // and no auth tag in other ones. This mode is supported for AES-GCM profiles only. + RCCMode3 +) + +// CryptexMode is the mode of Cryptex support for SRTP packets from RFC 9335. +type CryptexMode int + +const ( + // CryptexModeDisabled (default) disables Cryptex support. Received Cryptex SRTP packets with encrypted + // CSRCs and header extensions will be rejected with an error. + CryptexModeDisabled CryptexMode = 0 + // CryptexModeEnabled enables Cryptex support when SRTP packets are encrypted. Received SRTP packets + // with unencrypted CSRCs and header extensions will be accepted and decrypted. + CryptexModeEnabled CryptexMode = 1 + // CryptexModeRequired enables Cryptex support when SRTP packets are encrypted. Received SRTP packets + // with unencrypted CSRCs and header extensions will be rejected with an error. + CryptexModeRequired CryptexMode = 2 +) + +// Context represents a SRTP cryptographic context. +// Context can only be used for one-way operations. +// it must either used ONLY for encryption or ONLY for decryption. +// Note that Context does not provide any concurrency protection: +// access to a Context from multiple goroutines requires external +// synchronization. +type Context struct { + cipher srtpCipher + + srtpSSRCStates map[uint32]*srtpSSRCState + srtcpSSRCStates map[uint32]*srtcpSSRCState + + newSRTCPReplayDetector func() replaydetector.ReplayDetector + newSRTPReplayDetector func() replaydetector.ReplayDetector + + profile ProtectionProfile + + // Master Key Identifier used for encrypting RTP/RTCP packets. Set to nil if MKI is not enabled. + sendMKI []byte + // Master Key Identifier to cipher mapping. Used for decrypting packets. Empty if MKI is not enabled. + mkis map[string]srtpCipher + + encryptSRTP bool + encryptSRTCP bool + + rccMode RCCMode + rocTransmitRate uint16 + + authTagRTPLen *int + + cryptexMode CryptexMode +} + +// CreateContext creates a new SRTP Context. +// +// CreateContext receives variable number of ContextOption-s. +// Passing multiple options which set the same parameter let the last one valid. +// Following example create SRTP Context with replay protection with window size of 256. +// +// decCtx, err := srtp.CreateContext(key, salt, profile, srtp.SRTPReplayProtection(256)) +func CreateContext( + masterKey, masterSalt []byte, + profile ProtectionProfile, + opts ...ContextOption, +) (c *Context, err error) { + c = &Context{ + srtpSSRCStates: map[uint32]*srtpSSRCState{}, + srtcpSSRCStates: map[uint32]*srtcpSSRCState{}, + profile: profile, + mkis: map[string]srtpCipher{}, + } + + for _, o := range append( + []ContextOption{ // Default options + SRTPNoReplayProtection(), + SRTCPNoReplayProtection(), + SRTPEncryption(), + SRTCPEncryption(), + }, + opts..., // User specified options + ) { + if errOpt := o(c); errOpt != nil { + return nil, errOpt + } + } + + if err = c.checkRCCMode(); err != nil { + return nil, err + } + + if c.authTagRTPLen != nil { + var authKeyLen int + authKeyLen, err = c.profile.AuthKeyLen() + if err != nil { + return nil, err + } + if *c.authTagRTPLen > authKeyLen { + return nil, errTooLongSRTPAuthTag + } + } + + c.cipher, err = c.createCipher(c.sendMKI, masterKey, masterSalt, c.encryptSRTP, c.encryptSRTCP) + if err != nil { + return nil, err + } + if len(c.sendMKI) != 0 { + c.mkis[string(c.sendMKI)] = c.cipher + } + + return c, nil +} + +// AddCipherForMKI adds new MKI with associated masker key and salt. +// Context must be created with MasterKeyIndicator option +// to enable MKI support. MKI must be unique and have the same length as the one used for creating Context. +// Operation is not thread-safe, you need to provide synchronization with decrypting packets. +func (c *Context) AddCipherForMKI(mki, masterKey, masterSalt []byte) error { + if len(c.mkis) == 0 { + return errMKIIsNotEnabled + } + if len(mki) == 0 || len(mki) != len(c.sendMKI) { + return errInvalidMKILength + } + if _, ok := c.mkis[string(mki)]; ok { + return errMKIAlreadyInUse + } + + cipher, err := c.createCipher(mki, masterKey, masterSalt, c.encryptSRTP, c.encryptSRTCP) + if err != nil { + return err + } + c.mkis[string(mki)] = cipher + + return nil +} + +func (c *Context) createCipher(mki, masterKey, masterSalt []byte, encryptSRTP, encryptSRTCP bool) (srtpCipher, error) { + keyLen, err := c.profile.KeyLen() + if err != nil { + return nil, err + } + + saltLen, err := c.profile.SaltLen() + if err != nil { + return nil, err + } + + if masterKeyLen := len(masterKey); masterKeyLen != keyLen { + return nil, fmt.Errorf("%w expected(%d) actual(%d)", errShortSrtpMasterKey, keyLen, masterKey) + } else if masterSaltLen := len(masterSalt); masterSaltLen != saltLen { + return nil, fmt.Errorf("%w expected(%d) actual(%d)", errShortSrtpMasterSalt, saltLen, masterSaltLen) + } + + profileWithArgs := protectionProfileWithArgs{ + ProtectionProfile: c.profile, + authTagRTPLen: c.authTagRTPLen, + } + + useCryptex := c.cryptexMode != CryptexModeDisabled && encryptSRTP + switch c.profile { + case ProtectionProfileAeadAes128Gcm, ProtectionProfileAeadAes256Gcm: + return newSrtpCipherAeadAesGcm(profileWithArgs, masterKey, masterSalt, mki, encryptSRTP, encryptSRTCP, useCryptex) + case ProtectionProfileAes128CmHmacSha1_32, + ProtectionProfileAes128CmHmacSha1_80, + ProtectionProfileAes256CmHmacSha1_32, + ProtectionProfileAes256CmHmacSha1_80: + return newSrtpCipherAesCmHmacSha1(profileWithArgs, masterKey, masterSalt, mki, encryptSRTP, encryptSRTCP, useCryptex) + case ProtectionProfileNullHmacSha1_32, ProtectionProfileNullHmacSha1_80: + return newSrtpCipherAesCmHmacSha1(profileWithArgs, masterKey, masterSalt, mki, false, false, false) + default: + return nil, fmt.Errorf("%w: %#v", errNoSuchSRTPProfile, c.profile) + } +} + +// RemoveMKI removes one of MKIs. You cannot remove last MKI and one used for encrypting RTP/RTCP packets. +// Operation is not thread-safe, you need to provide synchronization with decrypting packets. +func (c *Context) RemoveMKI(mki []byte) error { + if _, ok := c.mkis[string(mki)]; !ok { + return ErrMKINotFound + } + if bytes.Equal(mki, c.sendMKI) { + return errMKIAlreadyInUse + } + delete(c.mkis, string(mki)) + + return nil +} + +// SetSendMKI switches MKI and cipher used for encrypting RTP/RTCP packets. +// Operation is not thread-safe, you need to provide synchronization with encrypting packets. +func (c *Context) SetSendMKI(mki []byte) error { + cipher, ok := c.mkis[string(mki)] + if !ok { + return ErrMKINotFound + } + c.sendMKI = mki + c.cipher = cipher + + return nil +} + +// https://tools.ietf.org/html/rfc3550#appendix-A.1 +func (s *srtpSSRCState) nextRolloverCount(sequenceNumber uint16) (roc uint32, diff int64, overflow bool) { + seq := int32(sequenceNumber) + localRoc := uint32(s.index >> 16) //nolint:gosec // G115 + localSeq := int32(s.index & (seqNumMax - 1)) //nolint:gosec // G115 + + guessRoc := localRoc + var difference int32 + + if s.rolloverHasProcessed { //nolint:nestif + // When localROC is equal to 0, and entering seq-localSeq > seqNumMedian + // judgment, it will cause guessRoc calculation error + if s.index > seqNumMedian { + if localSeq < seqNumMedian { + if seq-localSeq > seqNumMedian { + guessRoc = localRoc - 1 + difference = seq - localSeq - seqNumMax + } else { + guessRoc = localRoc + difference = seq - localSeq + } + } else { + if localSeq-seqNumMedian > seq { + guessRoc = localRoc + 1 + difference = seq - localSeq + seqNumMax + } else { + guessRoc = localRoc + difference = seq - localSeq + } + } + } else { + // localRoc is equal to 0 + difference = seq - localSeq + } + } + + return guessRoc, int64(difference), (guessRoc == 0 && localRoc == maxROC) +} + +func (s *srtpSSRCState) updateRolloverCount(sequenceNumber uint16, difference int64, hasRemoteRoc bool, + remoteRoc uint32, +) { + switch { + case hasRemoteRoc: + s.index = (uint64(remoteRoc) << 16) | uint64(sequenceNumber) + s.rolloverHasProcessed = true + case !s.rolloverHasProcessed: + s.index |= uint64(sequenceNumber) + s.rolloverHasProcessed = true + case difference > 0: + s.index += uint64(difference) + } +} + +func (c *Context) getSRTPSSRCState(ssrc uint32) *srtpSSRCState { + s, ok := c.srtpSSRCStates[ssrc] + if ok { + return s + } + + s = &srtpSSRCState{ + ssrc: ssrc, + replayDetector: c.newSRTPReplayDetector(), + } + c.srtpSSRCStates[ssrc] = s + + return s +} + +func (c *Context) getSRTCPSSRCState(ssrc uint32) *srtcpSSRCState { + s, ok := c.srtcpSSRCStates[ssrc] + if ok { + return s + } + + s = &srtcpSSRCState{ + ssrc: ssrc, + replayDetector: c.newSRTCPReplayDetector(), + } + c.srtcpSSRCStates[ssrc] = s + + return s +} + +// ROC returns SRTP rollover counter value of specified SSRC. +func (c *Context) ROC(ssrc uint32) (uint32, bool) { + s, ok := c.srtpSSRCStates[ssrc] + if !ok { + return 0, false + } + + return uint32(s.index >> 16), true //nolint:gosec // G115 +} + +// SetROC sets SRTP rollover counter value of specified SSRC. +func (c *Context) SetROC(ssrc uint32, roc uint32) { + s := c.getSRTPSSRCState(ssrc) + s.index = uint64(roc) << 16 + s.rolloverHasProcessed = false +} + +// Index returns SRTCP index value of specified SSRC. +func (c *Context) Index(ssrc uint32) (uint32, bool) { + s, ok := c.srtcpSSRCStates[ssrc] + if !ok { + return 0, false + } + + return s.srtcpIndex, true +} + +// SetIndex sets SRTCP index value of specified SSRC. +func (c *Context) SetIndex(ssrc uint32, index uint32) { + s := c.getSRTCPSSRCState(ssrc) + s.srtcpIndex = index % (maxSRTCPIndex + 1) +} + +//nolint:cyclop +func (c *Context) checkRCCMode() error { + if c.rccMode == RCCModeNone { + return nil + } + + if c.rocTransmitRate == 0 { + return errZeroRocTransmitRate + } + + switch c.profile { + case ProtectionProfileAeadAes128Gcm, ProtectionProfileAeadAes256Gcm: + // AEAD profiles support RCCMode3 only + if c.rccMode != RCCMode3 { + return errUnsupportedRccMode + } + + case ProtectionProfileAes128CmHmacSha1_32, + ProtectionProfileAes256CmHmacSha1_32, + ProtectionProfileNullHmacSha1_32: + if c.authTagRTPLen == nil { + // ROC completely replaces auth tag for _32 profiles. If you really want to use 4-byte + // SRTP auth tag with RCC, use SRTPAuthenticationTagLength(4) option. + return errTooShortSRTPAuthTag + } + + fallthrough // Checks below are common for _32 and _80 profiles. + + case ProtectionProfileAes128CmHmacSha1_80, + ProtectionProfileAes256CmHmacSha1_80, + ProtectionProfileNullHmacSha1_80: + // AES-CM and NULL profiles support RCCMode2 only + if c.rccMode != RCCMode2 { + return errUnsupportedRccMode + } + if c.authTagRTPLen != nil && *c.authTagRTPLen < 4 { + return errTooShortSRTPAuthTag + } + + default: + return errUnsupportedRccMode + } + + return nil +} diff --git a/vendor/github.com/pion/srtp/v3/crypto.go b/vendor/github.com/pion/srtp/v3/crypto.go new file mode 100644 index 0000000..323dcb4 --- /dev/null +++ b/vendor/github.com/pion/srtp/v3/crypto.go @@ -0,0 +1,62 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package srtp + +import ( + "crypto/cipher" + "sync" + + "github.com/pion/transport/v4/utils/xor" +) + +// incrementCTR increments a big-endian integer of arbitrary size. +func incrementCTR(ctr []byte) { + for i := len(ctr) - 1; i >= 0; i-- { + ctr[i]++ + if ctr[i] != 0 { + break + } + } +} + +const xorBufferSize = 32 + +var xorBufferPool = sync.Pool{ // nolint:gochecknoglobals + New: func() any { + return make([]byte, xorBufferSize) + }, +} + +// xorBytesCTR performs CTR encryption and decryption. +// It is equivalent to cipher.NewCTR followed by XORKeyStream. +func xorBytesCTR(block cipher.Block, iv []byte, dst, src []byte) error { + if len(iv) != block.BlockSize() || (len(iv)+block.BlockSize()) > xorBufferSize { + return errBadIVLength + } + + xorBuf := xorBufferPool.Get() + defer xorBufferPool.Put(xorBuf) + buffer, ok := xorBuf.([]byte) + if !ok { + return errFailedTypeAssertion + } + + ctr := buffer[:len(iv)] + copy(ctr, iv) + bs := block.BlockSize() + stream := buffer[len(iv) : len(iv)+bs] + + i := 0 + for i < len(src) { + block.Encrypt(stream, ctr) + incrementCTR(ctr) + n := xor.XorBytes(dst[i:], src[i:], stream) + if n == 0 { + break + } + i += n + } + + return nil +} diff --git a/vendor/github.com/pion/srtp/v3/errors.go b/vendor/github.com/pion/srtp/v3/errors.go new file mode 100644 index 0000000..36b7416 --- /dev/null +++ b/vendor/github.com/pion/srtp/v3/errors.go @@ -0,0 +1,63 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package srtp + +import ( + "errors" + "fmt" +) + +var ( + // ErrFailedToVerifyAuthTag is returned when decryption fails due to invalid authentication tag. + ErrFailedToVerifyAuthTag = errors.New("failed to verify auth tag") + // ErrMKINotFound is returned when decryption fails due to unknown MKI value in packet. + ErrMKINotFound = errors.New("MKI not found") + + errDuplicated = errors.New("duplicated packet") + errShortSrtpMasterKey = errors.New("SRTP master key is not long enough") + errShortSrtpMasterSalt = errors.New("SRTP master salt is not long enough") + errNoSuchSRTPProfile = errors.New("no such SRTP Profile") + errNonZeroKDRNotSupported = errors.New("indexOverKdr > 0 is not supported yet") + errExporterWrongLabel = errors.New("exporter called with wrong label") + errNoConfig = errors.New("no config provided") + errNoConn = errors.New("no conn provided") + errTooShortRTP = errors.New("packet is too short to be RTP packet") + errTooShortRTCP = errors.New("packet is too short to be RTCP packet") + errPayloadDiffers = errors.New("payload differs") + errStartedChannelUsedIncorrectly = errors.New("started channel used incorrectly, should only be closed") + errBadIVLength = errors.New("bad iv length in xorBytesCTR") + errExceededMaxPackets = errors.New("exceeded the maximum number of packets") + errMKIAlreadyInUse = errors.New("MKI already in use") + errMKIIsNotEnabled = errors.New("MKI is not enabled") + errInvalidMKILength = errors.New("invalid MKI length") + errTooLongSRTPAuthTag = errors.New("SRTP auth tag is too long") + errTooShortSRTPAuthTag = errors.New("SRTP auth tag is too short") + + errStreamNotInited = errors.New("stream has not been inited, unable to close") + errStreamAlreadyClosed = errors.New("stream is already closed") + errStreamAlreadyInited = errors.New("stream is already inited") + errFailedTypeAssertion = errors.New("failed to cast child") + + errZeroRocTransmitRate = errors.New("ROC transmit rate is zero") + errUnsupportedRccMode = errors.New("unsupported RCC mode") + + errUnsupportedHeaderExtension = errors.New("unsupported header extension") + errHeaderLengthMismatch = errors.New("header length mismatch") + errUnencryptedHeaderExtAndCSRCs = errors.New("unencrypted header extensions and CSRCs are not allowed") + errCryptexDisabled = errors.New("cryptex is disabled") +) + +type duplicatedError struct { + Proto string // srtp or srtcp + SSRC uint32 + Index uint32 // sequence number or index +} + +func (e *duplicatedError) Error() string { + return fmt.Sprintf("%s ssrc=%d index=%d: %v", e.Proto, e.SSRC, e.Index, errDuplicated) +} + +func (e *duplicatedError) Unwrap() error { + return errDuplicated +} diff --git a/vendor/github.com/pion/srtp/v3/key_derivation.go b/vendor/github.com/pion/srtp/v3/key_derivation.go new file mode 100644 index 0000000..945b569 --- /dev/null +++ b/vendor/github.com/pion/srtp/v3/key_derivation.go @@ -0,0 +1,74 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package srtp + +import ( + "crypto/aes" + "encoding/binary" +) + +func aesCmKeyDerivation(label byte, masterKey, masterSalt []byte, indexOverKdr int, outLen int) ([]byte, error) { + if indexOverKdr != 0 { + // 24-bit "index DIV kdr" must be xored to prf input. + return nil, errNonZeroKDRNotSupported + } + + // https://tools.ietf.org/html/rfc3711#appendix-B.3 + // The input block for AES-CM is generated by exclusive-oring the master salt with the + // concatenation of the encryption key label 0x00 with (index DIV kdr), + // - index is 'rollover count' and DIV is 'divided by' + + nMasterSalt := len(masterSalt) + + prfIn := make([]byte, 16) + copy(prfIn[:nMasterSalt], masterSalt) + + prfIn[7] ^= label + + // The resulting value is then AES encrypted using the master key to get the cipher key. + block, err := aes.NewCipher(masterKey) + if err != nil { + return nil, err + } + + nBlockSize := block.BlockSize() + out := make([]byte, ((outLen+nBlockSize-1)/nBlockSize)*nBlockSize) + var i uint16 + for n := 0; n < outLen; n += nBlockSize { + binary.BigEndian.PutUint16(prfIn[len(prfIn)-2:], i) + block.Encrypt(out[n:n+nBlockSize], prfIn) + i++ + } + + return out[:outLen], nil +} + +// Generate IV https://tools.ietf.org/html/rfc3711#section-4.1.1 +// where the 128-bit integer value IV SHALL be defined by the SSRC, the +// SRTP packet index i, and the SRTP session salting key k_s, as below. +// - ROC = a 32-bit unsigned rollover counter (ROC), which records how many +// - times the 16-bit RTP sequence number has been reset to zero after +// - passing through 65,535 +// i = 2^16 * ROC + SEQ +// IV = (salt*2 ^ 16) | (ssrc*2 ^ 64) | (i*2 ^ 16). +func generateCounter( + sequenceNumber uint16, + rolloverCounter uint32, + ssrc uint32, sessionSalt []byte, +) (counter [16]byte) { + copy(counter[:], sessionSalt) + + counter[4] ^= byte(ssrc >> 24) + counter[5] ^= byte(ssrc >> 16) + counter[6] ^= byte(ssrc >> 8) + counter[7] ^= byte(ssrc) + counter[8] ^= byte(rolloverCounter >> 24) + counter[9] ^= byte(rolloverCounter >> 16) + counter[10] ^= byte(rolloverCounter >> 8) + counter[11] ^= byte(rolloverCounter) + counter[12] ^= byte(sequenceNumber >> 8) + counter[13] ^= byte(sequenceNumber) + + return counter +} diff --git a/vendor/github.com/pion/srtp/v3/keying.go b/vendor/github.com/pion/srtp/v3/keying.go new file mode 100644 index 0000000..c9dc183 --- /dev/null +++ b/vendor/github.com/pion/srtp/v3/keying.go @@ -0,0 +1,59 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package srtp + +const labelExtractorDtlsSrtp = "EXTRACTOR-dtls_srtp" + +// KeyingMaterialExporter allows package SRTP to extract keying material. +type KeyingMaterialExporter interface { + ExportKeyingMaterial(label string, context []byte, length int) ([]byte, error) +} + +// ExtractSessionKeysFromDTLS allows setting the Config SessionKeys by +// extracting them from DTLS. This behavior is defined in RFC5764: +// https://tools.ietf.org/html/rfc5764 +func (c *Config) ExtractSessionKeysFromDTLS(exporter KeyingMaterialExporter, isClient bool) error { + keyLen, err := c.Profile.KeyLen() + if err != nil { + return err + } + + saltLen, err := c.Profile.SaltLen() + if err != nil { + return err + } + + keyingMaterial, err := exporter.ExportKeyingMaterial(labelExtractorDtlsSrtp, nil, (keyLen*2)+(saltLen*2)) + if err != nil { + return err + } + + offset := 0 + clientWriteKey := append([]byte{}, keyingMaterial[offset:offset+keyLen]...) + offset += keyLen + + serverWriteKey := append([]byte{}, keyingMaterial[offset:offset+keyLen]...) + offset += keyLen + + clientWriteKey = append(clientWriteKey, keyingMaterial[offset:offset+saltLen]...) + offset += saltLen + + serverWriteKey = append(serverWriteKey, keyingMaterial[offset:offset+saltLen]...) + + if isClient { + c.Keys.LocalMasterKey = clientWriteKey[0:keyLen] + c.Keys.LocalMasterSalt = clientWriteKey[keyLen:] + c.Keys.RemoteMasterKey = serverWriteKey[0:keyLen] + c.Keys.RemoteMasterSalt = serverWriteKey[keyLen:] + + return nil + } + + c.Keys.LocalMasterKey = serverWriteKey[0:keyLen] + c.Keys.LocalMasterSalt = serverWriteKey[keyLen:] + c.Keys.RemoteMasterKey = clientWriteKey[0:keyLen] + c.Keys.RemoteMasterSalt = clientWriteKey[keyLen:] + + return nil +} diff --git a/vendor/github.com/pion/srtp/v3/option.go b/vendor/github.com/pion/srtp/v3/option.go new file mode 100644 index 0000000..1736a33 --- /dev/null +++ b/vendor/github.com/pion/srtp/v3/option.go @@ -0,0 +1,185 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package srtp + +import ( + "github.com/pion/transport/v4/replaydetector" +) + +// ContextOption represents option of Context using the functional options pattern. +type ContextOption func(*Context) error + +// SRTPReplayProtection sets SRTP replay protection window size. +func SRTPReplayProtection(windowSize uint) ContextOption { // nolint:revive + return func(c *Context) error { + c.newSRTPReplayDetector = func() replaydetector.ReplayDetector { + return replaydetector.New(windowSize, maxROC<<16|maxSequenceNumber) + } + + return nil + } +} + +// SRTCPReplayProtection sets SRTCP replay protection window size. +func SRTCPReplayProtection(windowSize uint) ContextOption { + return func(c *Context) error { + c.newSRTCPReplayDetector = func() replaydetector.ReplayDetector { + return replaydetector.New(windowSize, maxSRTCPIndex) + } + + return nil + } +} + +// SRTPNoReplayProtection disables SRTP replay protection. +func SRTPNoReplayProtection() ContextOption { // nolint:revive + return func(c *Context) error { + c.newSRTPReplayDetector = func() replaydetector.ReplayDetector { + return &nopReplayDetector{} + } + + return nil + } +} + +// SRTCPNoReplayProtection disables SRTCP replay protection. +func SRTCPNoReplayProtection() ContextOption { + return func(c *Context) error { + c.newSRTCPReplayDetector = func() replaydetector.ReplayDetector { + return &nopReplayDetector{} + } + + return nil + } +} + +// SRTPReplayDetectorFactory sets custom SRTP replay detector. +func SRTPReplayDetectorFactory(fn func() replaydetector.ReplayDetector) ContextOption { // nolint:revive + return func(c *Context) error { + c.newSRTPReplayDetector = fn + + return nil + } +} + +// SRTCPReplayDetectorFactory sets custom SRTCP replay detector. +func SRTCPReplayDetectorFactory(fn func() replaydetector.ReplayDetector) ContextOption { + return func(c *Context) error { + c.newSRTCPReplayDetector = fn + + return nil + } +} + +type nopReplayDetector struct{} + +func (s *nopReplayDetector) Check(uint64) (func() bool, bool) { + return func() bool { return true }, true +} + +// MasterKeyIndicator sets RTP/RTCP MKI for the initial master key. Array passed as an argument will be +// copied as-is to encrypted SRTP/SRTCP packets, so it must be of proper length and in Big Endian format. +// All MKIs added later using Context.AddCipherForMKI must have the same length as the one used here. +func MasterKeyIndicator(mki []byte) ContextOption { + return func(c *Context) error { + if len(mki) > 0 { + c.sendMKI = make([]byte, len(mki)) + copy(c.sendMKI, mki) + } + + return nil + } +} + +// SRTPEncryption enables SRTP encryption. +func SRTPEncryption() ContextOption { // nolint:revive + return func(c *Context) error { + c.encryptSRTP = true + + return nil + } +} + +// SRTPNoEncryption disables SRTP encryption. +// This option is useful when you want to use NullCipher for SRTP and keep authentication only. +// It simplifies debugging and testing, but it is not recommended for production use. +// +// Note: you can also use SRTPAuthenticationTagLength(0) to disable authentication tag too. +func SRTPNoEncryption() ContextOption { // nolint:revive + return func(c *Context) error { + c.encryptSRTP = false + + return nil + } +} + +// SRTCPEncryption enables SRTCP encryption. +func SRTCPEncryption() ContextOption { + return func(c *Context) error { + c.encryptSRTCP = true + + return nil + } +} + +// SRTCPNoEncryption disables SRTCP encryption. +// This option is useful when you want to use NullCipher for SRTCP and keep authentication only. +// It simplifies debugging and testing, but it is not recommended for production use. +func SRTCPNoEncryption() ContextOption { + return func(c *Context) error { + c.encryptSRTCP = false + + return nil + } +} + +// RolloverCounterCarryingTransform enables Rollover Counter Carrying Transform from RFC 4771. +// ROC value is sent in Authentication Tag of SRTP packets every rocTransmitRate packets. +// +// RFC 4771 defines 3 RCC modes. pion/srtp supports mode RCCm2 for AES-CM and NULL profiles, +// and mode RCCm3 for AES-GCM (AEAD) profiles. +// +// From RFC 4771: "[For modes RCCm1 and and RCCm3] the length of the MAC is shorter than the length +// of the authentication tag. To achieve the same (or less) MAC forgery success probability on all +// packets when using RCCm1 or RCCm2, as with the default integrity transform in RFC 3711, +// the tag-length must be set to 14 octets, which means that the length of MAC_tr is 10 octets." +// +// Protection profiles ProtectionProfile*CmHmacSha1_32 uses 4-byte SRTP auth tag, so in RCCm2 mode +// SRTP packets with ROC will not be integrity protected. +// +// You can increase the length of the authentication tag using SRTPAuthenticationTagLength option +// to mitigate this issue. +func RolloverCounterCarryingTransform(mode RCCMode, rocTransmitRate uint16) ContextOption { + return func(c *Context) error { + c.rccMode = mode + c.rocTransmitRate = rocTransmitRate + + return nil + } +} + +// SRTPAuthenticationTagLength sets length of SRTP authentication tag in bytes for AES-CM protection +// profiles. Decreasing the length of the authentication tag is not recommended for production use, +// as it decreases integrity protection. +// +// Zero value means that there is no authentication tag, what may be useful for debugging and testing. +// +// This option is ignored for AEAD profiles. +func SRTPAuthenticationTagLength(authTagRTPLen int) ContextOption { // nolint:revive + return func(c *Context) error { + c.authTagRTPLen = &authTagRTPLen + + return nil + } +} + +// Cryptex allows to enable Cryptex mechanism to completely encrypt RTP Header Extensions and Contributing +// Sources, as defined in RFC 9335. +func Cryptex(cryptexMode CryptexMode) ContextOption { + return func(c *Context) error { + c.cryptexMode = cryptexMode + + return nil + } +} diff --git a/vendor/github.com/pion/srtp/v3/protection_profile.go b/vendor/github.com/pion/srtp/v3/protection_profile.go new file mode 100644 index 0000000..181da22 --- /dev/null +++ b/vendor/github.com/pion/srtp/v3/protection_profile.go @@ -0,0 +1,167 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package srtp + +import "fmt" + +// ProtectionProfile specifies Cipher and AuthTag details, similar to TLS cipher suite. +type ProtectionProfile uint16 + +// Supported protection profiles +// See https://www.iana.org/assignments/srtp-protection/srtp-protection.xhtml +// +// AES128_CM_HMAC_SHA1_80 and AES128_CM_HMAC_SHA1_32 are valid SRTP profiles, +// but they do not have an DTLS-SRTP Protection Profiles ID assigned +// in RFC 5764. They were in earlier draft of this RFC: +// https://datatracker.ietf.org/doc/html/draft-ietf-avt-dtls-srtp-03#section-4.1.2 +// Their IDs are now marked as reserved in the IANA registry. Despite this Chrome supports them: +// https://chromium.googlesource.com/chromium/deps/libsrtp/+/84122798bb16927b1e676bd4f938a6e48e5bf2fe/srtp/include/srtp.h#694 +// +// Null profiles disable encryption, they are used for debugging and testing. +// They are not recommended for production use. +// Use of them is equivalent to using ProtectionProfileAes128CmHmacSha1_NN +// profile with SRTPNoEncryption and SRTCPNoEncryption options. +// +//nolint:lll +const ( + ProtectionProfileAes128CmHmacSha1_80 ProtectionProfile = 0x0001 + ProtectionProfileAes128CmHmacSha1_32 ProtectionProfile = 0x0002 + ProtectionProfileAes256CmHmacSha1_80 ProtectionProfile = 0x0003 + ProtectionProfileAes256CmHmacSha1_32 ProtectionProfile = 0x0004 + ProtectionProfileNullHmacSha1_80 ProtectionProfile = 0x0005 + ProtectionProfileNullHmacSha1_32 ProtectionProfile = 0x0006 + ProtectionProfileAeadAes128Gcm ProtectionProfile = 0x0007 + ProtectionProfileAeadAes256Gcm ProtectionProfile = 0x0008 +) + +// KeyLen returns length of encryption key in bytes. +// For all profiles except NullHmacSha1_32 and NullHmacSha1_80 is +// also the length of the session key. +func (p ProtectionProfile) KeyLen() (int, error) { + switch p { + case ProtectionProfileAes128CmHmacSha1_32, + ProtectionProfileAes128CmHmacSha1_80, + ProtectionProfileAeadAes128Gcm, + ProtectionProfileNullHmacSha1_32, + ProtectionProfileNullHmacSha1_80: + return 16, nil + case ProtectionProfileAeadAes256Gcm, ProtectionProfileAes256CmHmacSha1_32, ProtectionProfileAes256CmHmacSha1_80: + return 32, nil + default: + return 0, fmt.Errorf("%w: %#v", errNoSuchSRTPProfile, p) + } +} + +// SaltLen returns length of salt key in bytes. +// For all profiles except NullHmacSha1_32 and NullHmacSha1_80 +// is also the length of the session salt. +func (p ProtectionProfile) SaltLen() (int, error) { + switch p { + case ProtectionProfileAes128CmHmacSha1_32, + ProtectionProfileAes128CmHmacSha1_80, + ProtectionProfileAes256CmHmacSha1_32, + ProtectionProfileAes256CmHmacSha1_80, + ProtectionProfileNullHmacSha1_32, + ProtectionProfileNullHmacSha1_80: + return 14, nil + case ProtectionProfileAeadAes128Gcm, ProtectionProfileAeadAes256Gcm: + return 12, nil + default: + return 0, fmt.Errorf("%w: %#v", errNoSuchSRTPProfile, p) + } +} + +// AuthTagRTPLen returns length of RTP authentication tag in bytes for AES protection profiles. +// For AEAD ones it returns zero. +func (p ProtectionProfile) AuthTagRTPLen() (int, error) { + switch p { + case ProtectionProfileAes128CmHmacSha1_80, ProtectionProfileAes256CmHmacSha1_80, ProtectionProfileNullHmacSha1_80: + return 10, nil + case ProtectionProfileAes128CmHmacSha1_32, ProtectionProfileAes256CmHmacSha1_32, ProtectionProfileNullHmacSha1_32: + return 4, nil + case ProtectionProfileAeadAes128Gcm, ProtectionProfileAeadAes256Gcm: + return 0, nil + default: + return 0, fmt.Errorf("%w: %#v", errNoSuchSRTPProfile, p) + } +} + +// AuthTagRTCPLen returns length of RTCP authentication tag in bytes for AES protection profiles. +// +// For AEAD ones it returns zero. +func (p ProtectionProfile) AuthTagRTCPLen() (int, error) { + switch p { + case ProtectionProfileAes128CmHmacSha1_32, + ProtectionProfileAes128CmHmacSha1_80, + ProtectionProfileAes256CmHmacSha1_32, + ProtectionProfileAes256CmHmacSha1_80, + ProtectionProfileNullHmacSha1_32, + ProtectionProfileNullHmacSha1_80: + return 10, nil + case ProtectionProfileAeadAes128Gcm, ProtectionProfileAeadAes256Gcm: + return 0, nil + default: + return 0, fmt.Errorf("%w: %#v", errNoSuchSRTPProfile, p) + } +} + +// AEADAuthTagLen returns length of authentication tag in bytes for AEAD protection profiles. +// For AES ones it returns zero. +func (p ProtectionProfile) AEADAuthTagLen() (int, error) { + switch p { + case ProtectionProfileAes128CmHmacSha1_32, + ProtectionProfileAes128CmHmacSha1_80, + ProtectionProfileAes256CmHmacSha1_32, + ProtectionProfileAes256CmHmacSha1_80, + ProtectionProfileNullHmacSha1_32, + ProtectionProfileNullHmacSha1_80: + return 0, nil + case ProtectionProfileAeadAes128Gcm, ProtectionProfileAeadAes256Gcm: + return 16, nil + default: + return 0, fmt.Errorf("%w: %#v", errNoSuchSRTPProfile, p) + } +} + +// AuthKeyLen returns length of authentication key in bytes for AES protection profiles. +// For AEAD ones it returns zero. +func (p ProtectionProfile) AuthKeyLen() (int, error) { + switch p { + case ProtectionProfileAes128CmHmacSha1_32, + ProtectionProfileAes128CmHmacSha1_80, + ProtectionProfileAes256CmHmacSha1_32, + ProtectionProfileAes256CmHmacSha1_80, + ProtectionProfileNullHmacSha1_32, + ProtectionProfileNullHmacSha1_80: + return 20, nil + case ProtectionProfileAeadAes128Gcm, ProtectionProfileAeadAes256Gcm: + return 0, nil + default: + return 0, fmt.Errorf("%w: %#v", errNoSuchSRTPProfile, p) + } +} + +// String returns the name of the protection profile. +func (p ProtectionProfile) String() string { + switch p { + case ProtectionProfileAes128CmHmacSha1_80: + return "SRTP_AES128_CM_HMAC_SHA1_80" + case ProtectionProfileAes128CmHmacSha1_32: + return "SRTP_AES128_CM_HMAC_SHA1_32" + case ProtectionProfileAes256CmHmacSha1_80: + return "SRTP_AES256_CM_HMAC_SHA1_80" + case ProtectionProfileAes256CmHmacSha1_32: + return "SRTP_AES256_CM_HMAC_SHA1_32" + case ProtectionProfileAeadAes128Gcm: + return "SRTP_AEAD_AES_128_GCM" + case ProtectionProfileAeadAes256Gcm: + return "SRTP_AEAD_AES_256_GCM" + case ProtectionProfileNullHmacSha1_80: + return "SRTP_NULL_HMAC_SHA1_80" + case ProtectionProfileNullHmacSha1_32: + return "SRTP_NULL_HMAC_SHA1_32" + default: + return fmt.Sprintf("Unknown SRTP profile: %#v", p) + } +} diff --git a/vendor/github.com/pion/srtp/v3/protection_profile_with_args.go b/vendor/github.com/pion/srtp/v3/protection_profile_with_args.go new file mode 100644 index 0000000..a0e08be --- /dev/null +++ b/vendor/github.com/pion/srtp/v3/protection_profile_with_args.go @@ -0,0 +1,21 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package srtp + +// protectionProfileWithArgs is a wrapper around ProtectionProfile that allows to +// specify additional arguments for the profile. +type protectionProfileWithArgs struct { + ProtectionProfile + authTagRTPLen *int +} + +// AuthTagRTPLen returns length of RTP authentication tag in bytes for AES protection profiles. +// For AEAD ones it returns zero. +func (p protectionProfileWithArgs) AuthTagRTPLen() (int, error) { + if p.authTagRTPLen != nil { + return *p.authTagRTPLen, nil + } + + return p.ProtectionProfile.AuthTagRTPLen() +} diff --git a/vendor/github.com/pion/srtp/v3/renovate.json b/vendor/github.com/pion/srtp/v3/renovate.json new file mode 100644 index 0000000..f1bb98c --- /dev/null +++ b/vendor/github.com/pion/srtp/v3/renovate.json @@ -0,0 +1,6 @@ +{ + "$schema": "https://docs.renovatebot.com/renovate-schema.json", + "extends": [ + "github>pion/renovate-config" + ] +} diff --git a/vendor/github.com/pion/srtp/v3/session.go b/vendor/github.com/pion/srtp/v3/session.go new file mode 100644 index 0000000..ad7eba8 --- /dev/null +++ b/vendor/github.com/pion/srtp/v3/session.go @@ -0,0 +1,168 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package srtp + +import ( + "errors" + "io" + "net" + "sync" + "time" + + "github.com/pion/logging" + "github.com/pion/transport/v4/packetio" +) + +type streamSession interface { + Close() error + write([]byte) (int, error) + decrypt([]byte) error +} + +type session struct { + localContextMutex sync.Mutex + localContext, remoteContext *Context + localOptions, remoteOptions []ContextOption + + newStream chan readStream + acceptStreamTimeout time.Time + + started chan any + closed chan any + + readStreamsClosed bool + readStreams map[uint32]readStream + readStreamsLock sync.Mutex + + log logging.LeveledLogger + bufferFactory func(packetType packetio.BufferPacketType, ssrc uint32) io.ReadWriteCloser + + nextConn net.Conn +} + +// Config is used to configure a session. +// You can provide either a KeyingMaterialExporter to export keys +// or directly pass the keys themselves. +// After a Config is passed to a session it must not be modified. +type Config struct { + Keys SessionKeys + Profile ProtectionProfile + BufferFactory func(packetType packetio.BufferPacketType, ssrc uint32) io.ReadWriteCloser + LoggerFactory logging.LoggerFactory + AcceptStreamTimeout time.Time + + // List of local/remote context options. + // ReplayProtection is enabled on remote context by default. + // Default replay protection window size is 64. + LocalOptions, RemoteOptions []ContextOption +} + +// SessionKeys bundles the keys required to setup an SRTP session. +type SessionKeys struct { + LocalMasterKey []byte + LocalMasterSalt []byte + RemoteMasterKey []byte + RemoteMasterSalt []byte +} + +func (s *session) getOrCreateReadStream(ssrc uint32, child streamSession, proto func() readStream) (readStream, bool) { + s.readStreamsLock.Lock() + defer s.readStreamsLock.Unlock() + + if s.readStreamsClosed { + return nil, false + } + + rStream, ok := s.readStreams[ssrc] + if ok { + return rStream, false + } + + // Create the readStream. + rStream = proto() + + if err := rStream.init(child, ssrc); err != nil { + return nil, false + } + + s.readStreams[ssrc] = rStream + + return rStream, true +} + +func (s *session) removeReadStream(ssrc uint32) { + s.readStreamsLock.Lock() + defer s.readStreamsLock.Unlock() + + if s.readStreamsClosed { + return + } + + delete(s.readStreams, ssrc) +} + +func (s *session) close() error { + if s.nextConn == nil { + return nil + } else if err := s.nextConn.Close(); err != nil { + return err + } + + <-s.closed + + return nil +} + +func (s *session) start( + localMasterKey, localMasterSalt, remoteMasterKey, remoteMasterSalt []byte, + profile ProtectionProfile, + child streamSession, +) error { + var err error + s.localContext, err = CreateContext(localMasterKey, localMasterSalt, profile, s.localOptions...) + if err != nil { + return err + } + + s.remoteContext, err = CreateContext(remoteMasterKey, remoteMasterSalt, profile, s.remoteOptions...) + if err != nil { + return err + } + + if err = s.nextConn.SetReadDeadline(s.acceptStreamTimeout); err != nil { + return err + } + + go func() { + defer func() { + close(s.newStream) + + s.readStreamsLock.Lock() + s.readStreamsClosed = true + s.readStreamsLock.Unlock() + close(s.closed) + }() + + b := make([]byte, 8192) + for { + var i int + i, err = s.nextConn.Read(b) + if err != nil { + if !errors.Is(err, io.EOF) { + s.log.Error(err.Error()) + } + + return + } + + if err = child.decrypt(b[:i]); err != nil { + s.log.Info(err.Error()) + } + } + }() + + close(s.started) + + return nil +} diff --git a/vendor/github.com/pion/srtp/v3/session_srtcp.go b/vendor/github.com/pion/srtp/v3/session_srtcp.go new file mode 100644 index 0000000..753576f --- /dev/null +++ b/vendor/github.com/pion/srtp/v3/session_srtcp.go @@ -0,0 +1,193 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package srtp + +import ( + "net" + "time" + + "github.com/pion/logging" + "github.com/pion/rtcp" +) + +const defaultSessionSRTCPReplayProtectionWindow = 64 + +// SessionSRTCP implements io.ReadWriteCloser and provides a bi-directional SRTCP session +// SRTCP itself does not have a design like this, but it is common in most applications +// for local/remote to each have their own keying material. This provides those patterns +// instead of making everyone re-implement. +type SessionSRTCP struct { + session + writeStream *WriteStreamSRTCP +} + +// NewSessionSRTCP creates a SRTCP session using conn as the underlying transport. +func NewSessionSRTCP(conn net.Conn, config *Config) (*SessionSRTCP, error) { //nolint:dupl + if config == nil { + return nil, errNoConfig + } else if conn == nil { + return nil, errNoConn + } + + loggerFactory := config.LoggerFactory + if loggerFactory == nil { + loggerFactory = logging.NewDefaultLoggerFactory() + } + + localOpts := append( + []ContextOption{}, + config.LocalOptions..., + ) + remoteOpts := append( + []ContextOption{ + // Default options + SRTCPReplayProtection(defaultSessionSRTCPReplayProtectionWindow), + }, + config.RemoteOptions..., + ) + + srtcpSession := &SessionSRTCP{ + session: session{ + nextConn: conn, + localOptions: localOpts, + remoteOptions: remoteOpts, + readStreams: map[uint32]readStream{}, + newStream: make(chan readStream), + acceptStreamTimeout: config.AcceptStreamTimeout, + started: make(chan any), + closed: make(chan any), + bufferFactory: config.BufferFactory, + log: loggerFactory.NewLogger("srtp"), + }, + } + srtcpSession.writeStream = &WriteStreamSRTCP{srtcpSession} + + err := srtcpSession.session.start( + config.Keys.LocalMasterKey, config.Keys.LocalMasterSalt, + config.Keys.RemoteMasterKey, config.Keys.RemoteMasterSalt, + config.Profile, + srtcpSession, + ) + if err != nil { + return nil, err + } + + return srtcpSession, nil +} + +// OpenWriteStream returns the global write stream for the Session. +func (s *SessionSRTCP) OpenWriteStream() (*WriteStreamSRTCP, error) { + return s.writeStream, nil +} + +// OpenReadStream opens a read stream for the given SSRC, it can be used +// if you want a certain SSRC, but don't want to wait for AcceptStream. +func (s *SessionSRTCP) OpenReadStream(ssrc uint32) (*ReadStreamSRTCP, error) { + r, _ := s.session.getOrCreateReadStream(ssrc, s, newReadStreamSRTCP) + + if readStream, ok := r.(*ReadStreamSRTCP); ok { + return readStream, nil + } + + return nil, errFailedTypeAssertion +} + +// AcceptStream returns a stream to handle RTCP for a single SSRC. +func (s *SessionSRTCP) AcceptStream() (*ReadStreamSRTCP, uint32, error) { + stream, ok := <-s.newStream + if !ok { + return nil, 0, errStreamAlreadyClosed + } + + readStream, ok := stream.(*ReadStreamSRTCP) + if !ok { + return nil, 0, errFailedTypeAssertion + } + + return readStream, stream.GetSSRC(), nil +} + +// Close ends the session. +func (s *SessionSRTCP) Close() error { + return s.session.close() +} + +// Private + +func (s *SessionSRTCP) write(buf []byte) (int, error) { + if _, ok := <-s.session.started; ok { + return 0, errStartedChannelUsedIncorrectly + } + + ibuf := bufferpool.Get() + defer bufferpool.Put(ibuf) + + s.session.localContextMutex.Lock() + encrypted, err := s.localContext.EncryptRTCP(ibuf.([]byte), buf, nil) //nolint:forcetypeassert + s.session.localContextMutex.Unlock() + + if err != nil { + return 0, err + } + + return s.session.nextConn.Write(encrypted) +} + +func (s *SessionSRTCP) setWriteDeadline(t time.Time) error { + return s.session.nextConn.SetWriteDeadline(t) +} + +// create a list of Destination SSRCs +// that's a superset of all Destinations in the slice. +func destinationSSRC(pkts []rtcp.Packet) []uint32 { + ssrcSet := make(map[uint32]struct{}) + for _, p := range pkts { + for _, ssrc := range p.DestinationSSRC() { + ssrcSet[ssrc] = struct{}{} + } + } + + out := make([]uint32, 0, len(ssrcSet)) + for ssrc := range ssrcSet { + out = append(out, ssrc) + } + + return out +} + +func (s *SessionSRTCP) decrypt(buf []byte) error { + decrypted, err := s.remoteContext.DecryptRTCP(buf, buf, nil) + if err != nil { + return err + } + + pkt, err := rtcp.Unmarshal(decrypted) + if err != nil { + return err + } + + for _, ssrc := range destinationSSRC(pkt) { + r, isNew := s.session.getOrCreateReadStream(ssrc, s, newReadStreamSRTCP) + if r == nil { + return nil // Session has been closed + } else if isNew { + if !s.session.acceptStreamTimeout.IsZero() { + _ = s.session.nextConn.SetReadDeadline(time.Time{}) + } + s.session.newStream <- r // Notify AcceptStream + } + + readStream, ok := r.(*ReadStreamSRTCP) + if !ok { + return errFailedTypeAssertion + } + + _, err = readStream.write(decrypted) + if err != nil { + return err + } + } + + return nil +} diff --git a/vendor/github.com/pion/srtp/v3/session_srtp.go b/vendor/github.com/pion/srtp/v3/session_srtp.go new file mode 100644 index 0000000..73ff253 --- /dev/null +++ b/vendor/github.com/pion/srtp/v3/session_srtp.go @@ -0,0 +1,213 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package srtp + +import ( + "net" + "sync" + "time" + + "github.com/pion/logging" + "github.com/pion/rtp" +) + +const defaultSessionSRTPReplayProtectionWindow = 64 + +// SessionSRTP implements io.ReadWriteCloser and provides a bi-directional SRTP session +// SRTP itself does not have a design like this, but it is common in most applications +// for local/remote to each have their own keying material. This provides those patterns +// instead of making everyone re-implement. +type SessionSRTP struct { + session + writeStream *WriteStreamSRTP +} + +// NewSessionSRTP creates a SRTP session using conn as the underlying transport. +func NewSessionSRTP(conn net.Conn, config *Config) (*SessionSRTP, error) { //nolint:dupl + if config == nil { + return nil, errNoConfig + } else if conn == nil { + return nil, errNoConn + } + + loggerFactory := config.LoggerFactory + if loggerFactory == nil { + loggerFactory = logging.NewDefaultLoggerFactory() + } + + localOpts := append( + []ContextOption{}, + config.LocalOptions..., + ) + remoteOpts := append( + []ContextOption{ + // Default options + SRTPReplayProtection(defaultSessionSRTPReplayProtectionWindow), + }, + config.RemoteOptions..., + ) + + srtpSession := &SessionSRTP{ + session: session{ + nextConn: conn, + localOptions: localOpts, + remoteOptions: remoteOpts, + readStreams: map[uint32]readStream{}, + newStream: make(chan readStream), + acceptStreamTimeout: config.AcceptStreamTimeout, + started: make(chan any), + closed: make(chan any), + bufferFactory: config.BufferFactory, + log: loggerFactory.NewLogger("srtp"), + }, + } + srtpSession.writeStream = &WriteStreamSRTP{srtpSession} + + err := srtpSession.session.start( + config.Keys.LocalMasterKey, config.Keys.LocalMasterSalt, + config.Keys.RemoteMasterKey, config.Keys.RemoteMasterSalt, + config.Profile, + srtpSession, + ) + if err != nil { + return nil, err + } + + return srtpSession, nil +} + +// OpenWriteStream returns the global write stream for the Session. +func (s *SessionSRTP) OpenWriteStream() (*WriteStreamSRTP, error) { + return s.writeStream, nil +} + +// OpenReadStream opens a read stream for the given SSRC, it can be used +// if you want a certain SSRC, but don't want to wait for AcceptStream. +func (s *SessionSRTP) OpenReadStream(ssrc uint32) (*ReadStreamSRTP, error) { + r, _ := s.session.getOrCreateReadStream(ssrc, s, newReadStreamSRTP) + + if readStream, ok := r.(*ReadStreamSRTP); ok { + return readStream, nil + } + + return nil, errFailedTypeAssertion +} + +// AcceptStream returns a stream to handle RTCP for a single SSRC. +func (s *SessionSRTP) AcceptStream() (*ReadStreamSRTP, uint32, error) { + stream, ok := <-s.newStream + if !ok { + return nil, 0, errStreamAlreadyClosed + } + + readStream, ok := stream.(*ReadStreamSRTP) + if !ok { + return nil, 0, errFailedTypeAssertion + } + + return readStream, stream.GetSSRC(), nil +} + +// Close ends the session. +func (s *SessionSRTP) Close() error { + return s.session.close() +} + +func (s *SessionSRTP) write(b []byte) (int, error) { + packet := &rtp.Packet{} + + if err := packet.Unmarshal(b); err != nil { + return 0, err + } + + return s.writeRTP(&packet.Header, packet.Payload) +} + +// bufferpool is a global pool of buffers used for encrypted packets in +// writeRTP below. Since it's global, buffers can be shared between +// different sessions, which amortizes the cost of allocating the pool. +// +// 1472 is the maximum Ethernet UDP payload. We give ourselves 20 bytes +// of slack for any authentication tags, which is more than enough for +// either CTR or GCM. If the buffer is too small, no harm, it will just +// get expanded by growBuffer. +var bufferpool = sync.Pool{ // nolint:gochecknoglobals + New: func() any { + return make([]byte, 1492) + }, +} + +func (s *SessionSRTP) writeRTP(header *rtp.Header, payload []byte) (int, error) { + if _, ok := <-s.session.started; ok { + return 0, errStartedChannelUsedIncorrectly + } + + // encryptRTP will either return our buffer, or, if it is too + // small, allocate a new buffer itself. In either case, it is + // safe to put the buffer back into the pool, but only after + // nextConn.Write has returned. + ibuf := bufferpool.Get() + defer bufferpool.Put(ibuf) + + buf := ibuf.([]byte) // nolint:forcetypeassert + headerLen, marshalSize := rtp.HeaderAndPacketMarshalSize(header, payload) // nolint:staticcheck + if len(buf) < marshalSize+20 { + // The buffer is too small, so we need to allocate a new one. Add 20 bytes for auth tag like + // for bufferpool above. + buf = make([]byte, marshalSize+20) + } + _, err := rtp.MarshalPacketTo(buf, header, payload) // nolint:staticcheck + if err != nil { + return 0, err + } + + s.session.localContextMutex.Lock() + encrypted, err := s.localContext.encryptRTP(buf, header, headerLen, buf[:marshalSize]) + s.session.localContextMutex.Unlock() + + if err != nil { + return 0, err + } + + return s.session.nextConn.Write(encrypted) +} + +func (s *SessionSRTP) setWriteDeadline(t time.Time) error { + return s.session.nextConn.SetWriteDeadline(t) +} + +func (s *SessionSRTP) decrypt(buf []byte) error { + header := &rtp.Header{} + headerLen, err := header.Unmarshal(buf) + if err != nil { + return err + } + + r, isNew := s.session.getOrCreateReadStream(header.SSRC, s, newReadStreamSRTP) + if r == nil { + return nil // Session has been closed + } else if isNew { + if !s.session.acceptStreamTimeout.IsZero() { + _ = s.session.nextConn.SetReadDeadline(time.Time{}) + } + s.session.newStream <- r // Notify AcceptStream + } + + readStream, ok := r.(*ReadStreamSRTP) + if !ok { + return errFailedTypeAssertion + } + + decrypted, err := s.remoteContext.decryptRTP(buf, buf, header, headerLen) + if err != nil { + return err + } + + _, err = readStream.write(decrypted) + if err != nil { + return err + } + + return nil +} diff --git a/vendor/github.com/pion/srtp/v3/srtcp.go b/vendor/github.com/pion/srtp/v3/srtcp.go new file mode 100644 index 0000000..6f70508 --- /dev/null +++ b/vendor/github.com/pion/srtp/v3/srtcp.go @@ -0,0 +1,122 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package srtp + +import ( + "encoding/binary" + "fmt" + + "github.com/pion/rtcp" +) + +/* +Simplified structure of SRTCP Packets: +- RTCP Header +- Payload +- AEAD Auth Tag - used by AEAD profiles only +- E flag and SRTCP Index +- MKI (optional) +- Auth Tag - used by non-AEAD profiles only +*/ + +const ( + maxSRTCPIndex = 0x7FFFFFFF + + srtcpHeaderSize = 8 + srtcpIndexSize = 4 + srtcpEncryptionFlag = 0x80 +) + +func (c *Context) decryptRTCP(dst, encrypted []byte) ([]byte, error) { + authTagLen, err := c.cipher.AuthTagRTCPLen() + if err != nil { + return nil, err + } + aeadAuthTagLen, err := c.cipher.AEADAuthTagLen() + if err != nil { + return nil, err + } + mkiLen := len(c.sendMKI) + + // Verify that encrypted packet is long enough + if len(encrypted) < (srtcpHeaderSize + aeadAuthTagLen + srtcpIndexSize + mkiLen + authTagLen) { + return nil, fmt.Errorf("%w: %d", errTooShortRTCP, len(encrypted)) + } + + index := c.cipher.getRTCPIndex(encrypted) + ssrc := binary.BigEndian.Uint32(encrypted[4:]) + + s := c.getSRTCPSSRCState(ssrc) + markAsValid, ok := s.replayDetector.Check(uint64(index)) + if !ok { + return nil, &duplicatedError{Proto: "srtcp", SSRC: ssrc, Index: index} + } + + cipher := c.cipher + if len(c.mkis) > 0 { + // Find cipher for MKI + actualMKI := encrypted[len(encrypted)-mkiLen-authTagLen : len(encrypted)-authTagLen] + cipher, ok = c.mkis[string(actualMKI)] + if !ok { + return nil, ErrMKINotFound + } + } + + out, err := cipher.decryptRTCP(dst, encrypted, index, ssrc) + if err != nil { + return nil, err + } + + markAsValid() + + return out, nil +} + +// DecryptRTCP decrypts a buffer that contains a RTCP packet. +func (c *Context) DecryptRTCP(dst, encrypted []byte, header *rtcp.Header) ([]byte, error) { + if header == nil { + header = &rtcp.Header{} + } + + if err := header.Unmarshal(encrypted); err != nil { + return nil, err + } + + return c.decryptRTCP(dst, encrypted) +} + +func (c *Context) encryptRTCP(dst, decrypted []byte) ([]byte, error) { + if len(decrypted) < srtcpHeaderSize { + return nil, fmt.Errorf("%w: %d", errTooShortRTCP, len(decrypted)) + } + + ssrc := binary.BigEndian.Uint32(decrypted[4:]) + ssrcState := c.getSRTCPSSRCState(ssrc) + + if ssrcState.srtcpIndex >= maxSRTCPIndex { + // ... when 2^48 SRTP packets or 2^31 SRTCP packets have been secured with the same key + // (whichever occurs before), the key management MUST be called to provide new master key(s) + // (previously stored and used keys MUST NOT be used again), or the session MUST be terminated. + // https://www.rfc-editor.org/rfc/rfc3711#section-9.2 + return nil, errExceededMaxPackets + } + + // We roll over early because MSB is used for marking as encrypted + ssrcState.srtcpIndex++ + + return c.cipher.encryptRTCP(dst, decrypted, ssrcState.srtcpIndex, ssrc) +} + +// EncryptRTCP Encrypts a RTCP packet. +func (c *Context) EncryptRTCP(dst, decrypted []byte, header *rtcp.Header) ([]byte, error) { + if header == nil { + header = &rtcp.Header{} + } + + if err := header.Unmarshal(decrypted); err != nil { + return nil, err + } + + return c.encryptRTCP(dst, decrypted) +} diff --git a/vendor/github.com/pion/srtp/v3/srtp.go b/vendor/github.com/pion/srtp/v3/srtp.go new file mode 100644 index 0000000..044f72d --- /dev/null +++ b/vendor/github.com/pion/srtp/v3/srtp.go @@ -0,0 +1,185 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +// Package srtp implements Secure Real-time Transport Protocol +package srtp + +import ( + "encoding/binary" + "fmt" + + "github.com/pion/rtp" +) + +/* +Simplified structure of SRTP Packets: +- RTP Header (with optional RTP Header Extension) +- Payload (with optional padding) +- AEAD Auth Tag - used by AEAD profiles only +- MKI (optional) +- Auth Tag - used by non-AEAD profiles only. When RCC is used with AEAD profiles, the ROC is sent here. +*/ + +func (c *Context) decryptRTP(dst, ciphertext []byte, header *rtp.Header, headerLen int) ([]byte, error) { + authTagLen, err := c.cipher.AuthTagRTPLen() + if err != nil { + return nil, err + } + aeadAuthTagLen, err := c.cipher.AEADAuthTagLen() + if err != nil { + return nil, err + } + mkiLen := len(c.sendMKI) + + var hasRocInPacket bool + hasRocInPacket, authTagLen = c.hasROCInPacket(header, authTagLen) + + // Verify that encrypted packet is long enough + if len(ciphertext) < (headerLen + aeadAuthTagLen + mkiLen + authTagLen) { + return nil, fmt.Errorf("%w: %d", errTooShortRTP, len(ciphertext)) + } + + ssrcState := c.getSRTPSSRCState(header.SSRC) + + var roc uint32 + var diff int64 + var index uint64 + if !hasRocInPacket { + // The ROC is not sent in the packet. We need to guess it. + roc, diff, _ = ssrcState.nextRolloverCount(header.SequenceNumber) + index = (uint64(roc) << 16) | uint64(header.SequenceNumber) + } else { + // Extract ROC from the packet. The ROC is sent in the first 4 bytes of the auth tag. + roc = binary.BigEndian.Uint32(ciphertext[len(ciphertext)-authTagLen:]) + index = (uint64(roc) << 16) | uint64(header.SequenceNumber) + diff = int64(ssrcState.index) - int64(index) //nolint:gosec + } + + markAsValid, ok := ssrcState.replayDetector.Check(index) + if !ok { + return nil, &duplicatedError{ + Proto: "srtp", SSRC: header.SSRC, Index: uint32(header.SequenceNumber), + } + } + + err = c.checkCryptex(header) + if err != nil { + return nil, err + } + + cipher := c.cipher + if len(c.mkis) > 0 { + // Find cipher for MKI + actualMKI := ciphertext[len(ciphertext)-mkiLen-authTagLen : len(ciphertext)-authTagLen] + cipher, ok = c.mkis[string(actualMKI)] + if !ok { + return nil, ErrMKINotFound + } + } + + dst = growBufferSize(dst, len(ciphertext)-authTagLen-mkiLen) + + dst, err = cipher.decryptRTP(dst, ciphertext, header, headerLen, roc, hasRocInPacket) + if err != nil { + return nil, err + } + + markAsValid() + ssrcState.updateRolloverCount(header.SequenceNumber, diff, hasRocInPacket, roc) + + return dst, nil +} + +// DecryptRTP decrypts a RTP packet with an encrypted payload. +func (c *Context) DecryptRTP(dst, encrypted []byte, header *rtp.Header) ([]byte, error) { + if header == nil { + header = &rtp.Header{} + } + + headerLen, err := header.Unmarshal(encrypted) + if err != nil { + return nil, err + } + + return c.decryptRTP(dst, encrypted, header, headerLen) +} + +// EncryptRTP marshals and encrypts an RTP packet, writing to the dst buffer provided. +// If the dst buffer does not have the capacity to hold `len(plaintext) + 10` bytes, +// a new one will be allocated and returned. +// If a rtp.Header is provided, it will be Unmarshaled using the plaintext. +func (c *Context) EncryptRTP(dst []byte, plaintext []byte, header *rtp.Header) ([]byte, error) { + if header == nil { + header = &rtp.Header{} + } + + headerLen, err := header.Unmarshal(plaintext) + if err != nil { + return nil, err + } + + return c.encryptRTP(dst, header, headerLen, plaintext) +} + +// encryptRTP marshals and encrypts an RTP packet, writing to the dst buffer provided. +// If the dst buffer does not have the capacity, a new one will be allocated and returned. +// Similar to above but faster because it can avoid unmarshaling the header and marshaling the payload. +func (c *Context) encryptRTP(dst []byte, header *rtp.Header, headerLen int, plaintext []byte, +) (ciphertext []byte, err error) { + // RFC 9335, section 5.1: This mechanism [Cryptex] MUST NOT be used with header extensions other than + // the variety described in [RFC8285]. + if c.cryptexMode != CryptexModeDisabled && header.Extension && + header.ExtensionProfile != rtp.ExtensionProfileOneByte && + header.ExtensionProfile != rtp.ExtensionProfileTwoByte { + return nil, errUnsupportedHeaderExtension + } + + s := c.getSRTPSSRCState(header.SSRC) + roc, diff, ovf := s.nextRolloverCount(header.SequenceNumber) + if ovf { + // ... when 2^48 SRTP packets or 2^31 SRTCP packets have been secured with the same key + // (whichever occurs before), the key management MUST be called to provide new master key(s) + // (previously stored and used keys MUST NOT be used again), or the session MUST be terminated. + // https://www.rfc-editor.org/rfc/rfc3711#section-9.2 + return nil, errExceededMaxPackets + } + s.updateRolloverCount(header.SequenceNumber, diff, false, 0) + + rocInPacket := c.rccMode != RCCModeNone && header.SequenceNumber%c.rocTransmitRate == 0 + + return c.cipher.encryptRTP(dst, header, headerLen, plaintext, roc, rocInPacket) +} + +func (c *Context) hasROCInPacket(header *rtp.Header, authTagLen int) (bool, int) { + hasRocInPacket := false + switch c.rccMode { + case RCCMode2: + // This mode is supported for AES-CM and NULL profiles only. The ROC is sent in the first 4 bytes of the auth tag. + hasRocInPacket = header.SequenceNumber%c.rocTransmitRate == 0 + case RCCMode3: + // This mode is supported for AES-GCM only. The ROC is sent as 4-byte auth tag. + hasRocInPacket = header.SequenceNumber%c.rocTransmitRate == 0 + if hasRocInPacket { + authTagLen = 4 + } + default: + } + + return hasRocInPacket, authTagLen +} + +func (c *Context) checkCryptex(header *rtp.Header) error { + switch c.cryptexMode { + case CryptexModeDisabled: + if isCryptexPacket(header) { + return errCryptexDisabled + } + case CryptexModeRequired: + if (header.Extension || len(header.CSRC) > 0) && !isCryptexPacket(header) { + return errUnencryptedHeaderExtAndCSRCs + } + default: + } + + return nil +} diff --git a/vendor/github.com/pion/srtp/v3/srtp_cipher.go b/vendor/github.com/pion/srtp/v3/srtp_cipher.go new file mode 100644 index 0000000..3464e3e --- /dev/null +++ b/vendor/github.com/pion/srtp/v3/srtp_cipher.go @@ -0,0 +1,50 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package srtp + +import "github.com/pion/rtp" + +// cipher represents a implementation of one +// of the SRTP Specific ciphers. +type srtpCipher interface { + // AuthTagRTPLen/AuthTagRTCPLen return auth key length of the cipher. + // See the note below. + AuthTagRTPLen() (int, error) + AuthTagRTCPLen() (int, error) + // AEADAuthTagLen returns AEAD auth key length of the cipher. + // See the note below. + AEADAuthTagLen() (int, error) + getRTCPIndex([]byte) uint32 + + encryptRTP([]byte, *rtp.Header, int, []byte, uint32, bool) ([]byte, error) + encryptRTCP([]byte, []byte, uint32, uint32) ([]byte, error) + + decryptRTP([]byte, []byte, *rtp.Header, int, uint32, bool) ([]byte, error) + decryptRTCP([]byte, []byte, uint32, uint32) ([]byte, error) +} + +/* +NOTE: Auth tag and AEAD auth tag are placed at the different position in SRTCP + +In non-AEAD cipher, the authentication tag is placed *after* the ESRTCP word +(Encrypted-flag and SRTCP index). + +> AES_128_CM_HMAC_SHA1_80 +> | RTCP Header | Encrypted payload |E| SRTCP Index | Auth tag | +> ^ |----------| +> | ^ +> | authTagLen=10 +> aeadAuthTagLen=0 + +In AEAD cipher, the AEAD authentication tag is embedded in the ciphertext. +It is *before* the ESRTCP word (Encrypted-flag and SRTCP index). + +> AEAD_AES_128_GCM +> | RTCP Header | Encrypted payload | AEAD auth tag |E| SRTCP Index | +> |---------------| ^ +> ^ authTagLen=0 +> aeadAuthTagLen=16 + +See https://tools.ietf.org/html/rfc7714 for the full specifications. +*/ diff --git a/vendor/github.com/pion/srtp/v3/srtp_cipher_aead_aes_gcm.go b/vendor/github.com/pion/srtp/v3/srtp_cipher_aead_aes_gcm.go new file mode 100644 index 0000000..9320d40 --- /dev/null +++ b/vendor/github.com/pion/srtp/v3/srtp_cipher_aead_aes_gcm.go @@ -0,0 +1,392 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package srtp + +import ( + "crypto/aes" + "crypto/cipher" + "encoding/binary" + "fmt" + + "github.com/pion/rtp" +) + +type srtpCipherAeadAesGcm struct { + protectionProfileWithArgs + + srtpCipher, srtcpCipher cipher.AEAD + + srtpSessionSalt, srtcpSessionSalt []byte + + mki []byte + + srtpEncrypted, srtcpEncrypted bool + + useCryptex bool +} + +func newSrtpCipherAeadAesGcm( + profile protectionProfileWithArgs, + masterKey, masterSalt, mki []byte, + encryptSRTP, encryptSRTCP, useCryptex bool, +) (*srtpCipherAeadAesGcm, error) { + srtpCipher := &srtpCipherAeadAesGcm{ + protectionProfileWithArgs: profile, + srtpEncrypted: encryptSRTP, + srtcpEncrypted: encryptSRTCP, + useCryptex: useCryptex, + } + + srtpSessionKey, err := aesCmKeyDerivation(labelSRTPEncryption, masterKey, masterSalt, 0, len(masterKey)) + if err != nil { + return nil, err + } + + srtpBlock, err := aes.NewCipher(srtpSessionKey) + if err != nil { + return nil, err + } + + srtpCipher.srtpCipher, err = cipher.NewGCM(srtpBlock) + if err != nil { + return nil, err + } + + srtcpSessionKey, err := aesCmKeyDerivation(labelSRTCPEncryption, masterKey, masterSalt, 0, len(masterKey)) + if err != nil { + return nil, err + } + + srtcpBlock, err := aes.NewCipher(srtcpSessionKey) + if err != nil { + return nil, err + } + + srtpCipher.srtcpCipher, err = cipher.NewGCM(srtcpBlock) + if err != nil { + return nil, err + } + + if srtpCipher.srtpSessionSalt, err = aesCmKeyDerivation( + labelSRTPSalt, masterKey, masterSalt, 0, len(masterSalt), + ); err != nil { + return nil, err + } else if srtpCipher.srtcpSessionSalt, err = aesCmKeyDerivation( + labelSRTCPSalt, masterKey, masterSalt, 0, len(masterSalt), + ); err != nil { + return nil, err + } + + mkiLen := len(mki) + if mkiLen > 0 { + srtpCipher.mki = make([]byte, mkiLen) + copy(srtpCipher.mki, mki) + } + + return srtpCipher, nil +} + +func (s *srtpCipherAeadAesGcm) encryptRTP( + dst []byte, + header *rtp.Header, + headerLen int, + plaintext []byte, + roc uint32, + rocInAuthTag bool, +) (ciphertext []byte, err error) { + // Grow the given buffer to fit the output. + authTagLen, err := s.AEADAuthTagLen() + if err != nil { + return nil, err + } + payloadLen := len(plaintext) - headerLen + authPartLen := headerLen + payloadLen + authTagLen + dstLen := authPartLen + len(s.mki) + if rocInAuthTag { + dstLen += 4 + } + + insertEmptyExtHdr := needsEmptyExtensionHeader(s.useCryptex, header) + if insertEmptyExtHdr { + dstLen += extensionHeaderSize + } + + dst = growBufferSize(dst, dstLen) + sameBuffer := isSameBuffer(dst, plaintext) + + if insertEmptyExtHdr { + plaintext = insertEmptyExtensionHeader(dst, plaintext, sameBuffer, header) + sameBuffer = true + headerLen += extensionHeaderSize + } + + err = s.doEncryptRTP(dst, header, headerLen, plaintext, roc, rocInAuthTag, sameBuffer, payloadLen, authPartLen) + if err != nil { + return nil, err + } + + return dst, nil +} + +func (s *srtpCipherAeadAesGcm) doEncryptRTP(dst []byte, header *rtp.Header, headerLen int, plaintext []byte, roc uint32, + rocInAuthTag bool, sameBuffer bool, payloadLen int, authPartLen int, +) error { + iv := s.rtpInitializationVector(header, roc) + encrypt := func(dst, plaintext []byte, headerLen int) error { + s.srtpCipher.Seal(dst[headerLen:headerLen], iv[:], plaintext[headerLen:], plaintext[:headerLen]) + + return nil + } + + switch { + case s.useCryptex && header.Extension: + err := encryptCryptexRTP(dst, plaintext, sameBuffer, header, encrypt) + if err != nil { + return err + } + case s.srtpEncrypted: + // Copy the header unencrypted. + if !sameBuffer { + copy(dst, plaintext[:headerLen]) + } + s.srtpCipher.Seal(dst[headerLen:headerLen], iv[:], plaintext[headerLen:], dst[:headerLen]) + default: + clearLen := headerLen + payloadLen + if !sameBuffer { + copy(dst, plaintext) + } + s.srtpCipher.Seal(dst[clearLen:clearLen], iv[:], nil, dst[:clearLen]) + } + + // Add MKI after the encrypted payload + if len(s.mki) > 0 { + copy(dst[authPartLen:], s.mki) + } + + if rocInAuthTag { + binary.BigEndian.PutUint32(dst[len(dst)-4:], roc) + } + + return nil +} + +func (s *srtpCipherAeadAesGcm) decryptRTP( + dst, ciphertext []byte, + header *rtp.Header, + headerLen int, + roc uint32, + rocInAuthTag bool, +) ([]byte, error) { + // Grow the given buffer to fit the output. + authTagLen, err := s.AEADAuthTagLen() + if err != nil { + return nil, err + } + rocLen := 0 + if rocInAuthTag { + rocLen = 4 + } + nDst := len(ciphertext) - authTagLen - len(s.mki) - rocLen + if nDst < headerLen { + // Size of ciphertext is shorter than AEAD auth tag len. + return nil, ErrFailedToVerifyAuthTag + } + dst = growBufferSize(dst, nDst) + sameBuffer := isSameBuffer(dst, ciphertext) + + nEnd := len(ciphertext) - len(s.mki) - rocLen + + err = s.doDecryptRTP(dst, ciphertext, header, headerLen, roc, sameBuffer, nEnd, authTagLen) + if err != nil { + return nil, err + } + + return dst, nil +} + +func (s *srtpCipherAeadAesGcm) doDecryptRTP(dst, ciphertext []byte, header *rtp.Header, headerLen int, roc uint32, + sameBuffer bool, nEnd int, authTagLen int, +) error { + iv := s.rtpInitializationVector(header, roc) + decrypt := func(dst, ciphertext []byte, headerLen int) error { + _, err := s.srtpCipher.Open(dst[headerLen:headerLen], iv[:], ciphertext[headerLen:nEnd], ciphertext[:headerLen]) + + return err + } + + switch { + case isCryptexPacket(header): + err := decryptCryptexRTP(dst, ciphertext, sameBuffer, header, headerLen, decrypt) + if err != nil { + return fmt.Errorf("%w: %w", ErrFailedToVerifyAuthTag, err) + } + case s.srtpEncrypted: + if err := decrypt(dst, ciphertext[:nEnd], headerLen); err != nil { + return fmt.Errorf("%w: %w", ErrFailedToVerifyAuthTag, err) + } + // Copy the header unencrypted. + if !sameBuffer { + copy(dst[:headerLen], ciphertext[:headerLen]) + } + default: + nDataEnd := nEnd - authTagLen + if _, err := s.srtpCipher.Open( + nil, iv[:], ciphertext[nDataEnd:nEnd], ciphertext[:nDataEnd], + ); err != nil { + return fmt.Errorf("%w: %w", ErrFailedToVerifyAuthTag, err) + } + // Copy the header and payload unencrypted. + if !sameBuffer { + copy(dst, ciphertext[:nDataEnd]) + } + } + + return nil +} + +func (s *srtpCipherAeadAesGcm) encryptRTCP(dst, decrypted []byte, srtcpIndex uint32, ssrc uint32) ([]byte, error) { + authTagLen, err := s.AEADAuthTagLen() + if err != nil { + return nil, err + } + aadPos := len(decrypted) + authTagLen + // Grow the given buffer to fit the output. + dst = growBufferSize(dst, aadPos+srtcpIndexSize+len(s.mki)) + sameBuffer := isSameBuffer(dst, decrypted) + + iv := s.rtcpInitializationVector(srtcpIndex, ssrc) + if s.srtcpEncrypted { + aad := s.rtcpAdditionalAuthenticatedData(decrypted, srtcpIndex) + if !sameBuffer { + // Copy the header unencrypted. + copy(dst[:srtcpHeaderSize], decrypted[:srtcpHeaderSize]) + } + // Copy index to the proper place. + copy(dst[aadPos:aadPos+srtcpIndexSize], aad[8:12]) + s.srtcpCipher.Seal(dst[srtcpHeaderSize:srtcpHeaderSize], iv[:], decrypted[srtcpHeaderSize:], aad[:]) + } else { + // Copy the packet unencrypted. + if !sameBuffer { + copy(dst, decrypted) + } + // Append the SRTCP index to the end of the packet - this will form the AAD. + binary.BigEndian.PutUint32(dst[len(decrypted):], srtcpIndex) + // Generate the authentication tag. + tag := make([]byte, authTagLen) + s.srtcpCipher.Seal(tag[0:0], iv[:], nil, dst[:len(decrypted)+srtcpIndexSize]) + // Copy index to the proper place. + copy(dst[aadPos:], dst[len(decrypted):len(decrypted)+srtcpIndexSize]) + // Copy the auth tag after RTCP payload. + copy(dst[len(decrypted):], tag) + } + + copy(dst[aadPos+srtcpIndexSize:], s.mki) + + return dst, nil +} + +func (s *srtpCipherAeadAesGcm) decryptRTCP(dst, encrypted []byte, srtcpIndex, ssrc uint32) ([]byte, error) { + aadPos := len(encrypted) - srtcpIndexSize - len(s.mki) + // Grow the given buffer to fit the output. + authTagLen, err := s.AEADAuthTagLen() + if err != nil { + return nil, err + } + nDst := aadPos - authTagLen + if nDst < 0 { + // Size of ciphertext is shorter than AEAD auth tag len. + return nil, ErrFailedToVerifyAuthTag + } + dst = growBufferSize(dst, nDst) + sameBuffer := isSameBuffer(dst, encrypted) + + isEncrypted := encrypted[aadPos]&srtcpEncryptionFlag != 0 + iv := s.rtcpInitializationVector(srtcpIndex, ssrc) + if isEncrypted { + aad := s.rtcpAdditionalAuthenticatedData(encrypted, srtcpIndex) + if _, err := s.srtcpCipher.Open(dst[srtcpHeaderSize:srtcpHeaderSize], iv[:], encrypted[srtcpHeaderSize:aadPos], + aad[:]); err != nil { + return nil, fmt.Errorf("%w: %w", ErrFailedToVerifyAuthTag, err) + } + } else { + // Prepare AAD for received packet. + dataEnd := aadPos - authTagLen + aad := make([]byte, dataEnd+4) + copy(aad, encrypted[:dataEnd]) + copy(aad[dataEnd:], encrypted[aadPos:aadPos+4]) + // Verify the auth tag. + if _, err := s.srtcpCipher.Open(nil, iv[:], encrypted[dataEnd:aadPos], aad); err != nil { + return nil, fmt.Errorf("%w: %w", ErrFailedToVerifyAuthTag, err) + } + // Copy the unencrypted payload. + if !sameBuffer { + copy(dst[srtcpHeaderSize:], encrypted[srtcpHeaderSize:dataEnd]) + } + } + + // Copy the header unencrypted. + if !sameBuffer { + copy(dst[:srtcpHeaderSize], encrypted[:srtcpHeaderSize]) + } + + return dst, nil +} + +// The 12-octet IV used by AES-GCM SRTP is formed by first concatenating +// 2 octets of zeroes, the 4-octet SSRC, the 4-octet rollover counter +// (ROC), and the 2-octet sequence number (SEQ). The resulting 12-octet +// value is then XORed to the 12-octet salt to form the 12-octet IV. +// +// https://tools.ietf.org/html/rfc7714#section-8.1 +func (s *srtpCipherAeadAesGcm) rtpInitializationVector(header *rtp.Header, roc uint32) [12]byte { + var iv [12]byte + binary.BigEndian.PutUint32(iv[2:], header.SSRC) + binary.BigEndian.PutUint32(iv[6:], roc) + binary.BigEndian.PutUint16(iv[10:], header.SequenceNumber) + + for i := range iv { + iv[i] ^= s.srtpSessionSalt[i] + } + + return iv +} + +// The 12-octet IV used by AES-GCM SRTCP is formed by first +// concatenating 2 octets of zeroes, the 4-octet SSRC identifier, +// 2 octets of zeroes, a single "0" bit, and the 31-bit SRTCP index. +// The resulting 12-octet value is then XORed to the 12-octet salt to +// form the 12-octet IV. +// +// https://tools.ietf.org/html/rfc7714#section-9.1 +func (s *srtpCipherAeadAesGcm) rtcpInitializationVector(srtcpIndex uint32, ssrc uint32) [12]byte { + var iv [12]byte + + binary.BigEndian.PutUint32(iv[2:], ssrc) + binary.BigEndian.PutUint32(iv[8:], srtcpIndex) + + for i := range iv { + iv[i] ^= s.srtcpSessionSalt[i] + } + + return iv +} + +// In an SRTCP packet, a 1-bit Encryption flag is prepended to the +// 31-bit SRTCP index to form a 32-bit value we shall call the +// "ESRTCP word" +// +// https://tools.ietf.org/html/rfc7714#section-17 +func (s *srtpCipherAeadAesGcm) rtcpAdditionalAuthenticatedData(rtcpPacket []byte, srtcpIndex uint32) [12]byte { + var aad [12]byte + + copy(aad[:], rtcpPacket[:8]) + binary.BigEndian.PutUint32(aad[8:], srtcpIndex) + aad[8] |= srtcpEncryptionFlag + + return aad +} + +func (s *srtpCipherAeadAesGcm) getRTCPIndex(in []byte) uint32 { + return binary.BigEndian.Uint32(in[len(in)-len(s.mki)-srtcpIndexSize:]) &^ (srtcpEncryptionFlag << 24) +} diff --git a/vendor/github.com/pion/srtp/v3/srtp_cipher_aes_cm_hmac_sha1.go b/vendor/github.com/pion/srtp/v3/srtp_cipher_aes_cm_hmac_sha1.go new file mode 100644 index 0000000..1283ce1 --- /dev/null +++ b/vendor/github.com/pion/srtp/v3/srtp_cipher_aes_cm_hmac_sha1.go @@ -0,0 +1,438 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package srtp + +import ( //nolint:gci + "crypto/aes" + "crypto/cipher" + "crypto/hmac" + "crypto/sha1" //nolint:gosec + "crypto/subtle" + "encoding/binary" + "hash" + + "github.com/pion/rtp" +) + +type srtpCipherAesCmHmacSha1 struct { + protectionProfileWithArgs + + srtpSessionSalt []byte + srtpSessionAuth hash.Hash + srtpBlock cipher.Block + srtpEncrypted bool + + srtcpSessionSalt []byte + srtcpSessionAuth hash.Hash + srtcpBlock cipher.Block + srtcpEncrypted bool + + mki []byte + + useCryptex bool +} + +//nolint:cyclop +func newSrtpCipherAesCmHmacSha1( + profile protectionProfileWithArgs, + masterKey, masterSalt, mki []byte, + encryptSRTP, encryptSRTCP, useCryptex bool, +) (*srtpCipherAesCmHmacSha1, error) { + switch profile.ProtectionProfile { + case ProtectionProfileNullHmacSha1_80, ProtectionProfileNullHmacSha1_32: + encryptSRTP = false + encryptSRTCP = false + default: + } + + srtpCipher := &srtpCipherAesCmHmacSha1{ + protectionProfileWithArgs: profile, + srtpEncrypted: encryptSRTP, + srtcpEncrypted: encryptSRTCP, + useCryptex: useCryptex, + } + + srtpSessionKey, err := aesCmKeyDerivation(labelSRTPEncryption, masterKey, masterSalt, 0, len(masterKey)) + if err != nil { + return nil, err + } else if srtpCipher.srtpBlock, err = aes.NewCipher(srtpSessionKey); err != nil { + return nil, err + } + + srtcpSessionKey, err := aesCmKeyDerivation(labelSRTCPEncryption, masterKey, masterSalt, 0, len(masterKey)) + if err != nil { + return nil, err + } else if srtpCipher.srtcpBlock, err = aes.NewCipher(srtcpSessionKey); err != nil { + return nil, err + } + + if srtpCipher.srtpSessionSalt, err = aesCmKeyDerivation( + labelSRTPSalt, masterKey, masterSalt, 0, len(masterSalt), + ); err != nil { + return nil, err + } else if srtpCipher.srtcpSessionSalt, err = aesCmKeyDerivation( + labelSRTCPSalt, masterKey, masterSalt, 0, len(masterSalt), + ); err != nil { + return nil, err + } + + authKeyLen, err := profile.AuthKeyLen() + if err != nil { + return nil, err + } + + srtpSessionAuthTag, err := aesCmKeyDerivation(labelSRTPAuthenticationTag, masterKey, masterSalt, 0, authKeyLen) + if err != nil { + return nil, err + } + + srtcpSessionAuthTag, err := aesCmKeyDerivation(labelSRTCPAuthenticationTag, masterKey, masterSalt, 0, authKeyLen) + if err != nil { + return nil, err + } + + srtpCipher.srtcpSessionAuth = hmac.New(sha1.New, srtcpSessionAuthTag) + srtpCipher.srtpSessionAuth = hmac.New(sha1.New, srtpSessionAuthTag) + + mkiLen := len(mki) + if mkiLen > 0 { + srtpCipher.mki = make([]byte, mkiLen) + copy(srtpCipher.mki, mki) + } + + return srtpCipher, nil +} + +func (s *srtpCipherAesCmHmacSha1) encryptRTP( + dst []byte, + header *rtp.Header, + headerLen int, + plaintext []byte, + roc uint32, + rocInAuthTag bool, +) (ciphertext []byte, err error) { + // Grow the given buffer to fit the output. + authTagLen, err := s.AuthTagRTPLen() + if err != nil { + return nil, err + } + payloadLen := len(plaintext) - headerLen + dstLen := headerLen + payloadLen + len(s.mki) + authTagLen + + insertEmptyExtHdr := needsEmptyExtensionHeader(s.useCryptex, header) + if insertEmptyExtHdr { + dstLen += extensionHeaderSize + } + + dst = growBufferSize(dst, dstLen) + sameBuffer := isSameBuffer(dst, plaintext) + + if insertEmptyExtHdr { + // Insert an empty extension header to plaintext using dst buffer. After this operation dst is used as the + // plaintext buffer for next operations. + plaintext = insertEmptyExtensionHeader(dst, plaintext, sameBuffer, header) + sameBuffer = true + headerLen += extensionHeaderSize + } + + err = s.doEncryptRTP(dst, header, headerLen, plaintext, roc, rocInAuthTag, sameBuffer, payloadLen) + if err != nil { + return nil, err + } + + return dst, nil +} + +func (s *srtpCipherAesCmHmacSha1) doEncryptRTP(dst []byte, header *rtp.Header, headerLen int, plaintext []byte, + roc uint32, rocInAuthTag bool, sameBuffer bool, payloadLen int, +) error { + encrypt := func(dst, plaintext []byte, headerLen int) error { + counter := generateCounter(header.SequenceNumber, roc, header.SSRC, s.srtpSessionSalt) + + return xorBytesCTR(s.srtpBlock, counter[:], dst[headerLen:], plaintext[headerLen:]) + } + + var err error + switch { + case s.useCryptex && header.Extension: + err = encryptCryptexRTP(dst, plaintext, sameBuffer, header, encrypt) + case s.srtpEncrypted: + // Copy the header unencrypted. + if !sameBuffer { + copy(dst, plaintext[:headerLen]) + } + // Encrypt the payload + err = encrypt(dst, plaintext, headerLen) + case !sameBuffer: + copy(dst, plaintext) + default: + } + if err != nil { + return err + } + n := headerLen + payloadLen + + // Generate the auth tag. + authTag, err := s.generateSrtpAuthTag(dst[:n], roc, rocInAuthTag) + if err != nil { + return err + } + + // Append the MKI (if used) + if len(s.mki) > 0 { + copy(dst[n:], s.mki) + n += len(s.mki) + } + + // Write the auth tag to the dest. + copy(dst[n:], authTag) + + return nil +} + +func (s *srtpCipherAesCmHmacSha1) decryptRTP( + dst, ciphertext []byte, + header *rtp.Header, + headerLen int, + roc uint32, + rocInAuthTag bool, +) ([]byte, error) { + // Split the auth tag and the cipher text into two parts. + authTagLen, err := s.AuthTagRTPLen() + if err != nil { + return nil, err + } + + // Split the auth tag and the cipher text into two parts. + actualTag := ciphertext[len(ciphertext)-authTagLen:] + ciphertext = ciphertext[:len(ciphertext)-len(s.mki)-authTagLen] + + // Generate the auth tag we expect to see from the ciphertext. + expectedTag, err := s.generateSrtpAuthTag(ciphertext, roc, rocInAuthTag) + if err != nil { + return nil, err + } + + // See if the auth tag actually matches. + // We use a constant time comparison to prevent timing attacks. + if subtle.ConstantTimeCompare(actualTag, expectedTag) != 1 { + return nil, ErrFailedToVerifyAuthTag + } + + sameBuffer := isSameBuffer(dst, ciphertext) + + err = s.doDecryptRTP(dst, ciphertext, header, headerLen, roc, sameBuffer) + if err != nil { + return nil, err + } + + return dst, nil +} + +func (s *srtpCipherAesCmHmacSha1) doDecryptRTP(dst, ciphertext []byte, header *rtp.Header, headerLen int, roc uint32, + sameBuffer bool, +) error { + decrypt := func(dst, ciphertext []byte, headerLen int) error { + counter := generateCounter(header.SequenceNumber, roc, header.SSRC, s.srtpSessionSalt) + + return xorBytesCTR(s.srtpBlock, counter[:], dst[headerLen:], ciphertext[headerLen:]) + } + + switch { + case isCryptexPacket(header): + err := decryptCryptexRTP(dst, ciphertext, sameBuffer, header, headerLen, decrypt) + if err != nil { + return err + } + case s.srtpEncrypted: + // Write the plaintext header to the destination buffer. + if !sameBuffer { + copy(dst, ciphertext[:headerLen]) + } + + // Decrypt the ciphertext for the payload. + err := decrypt(dst, ciphertext, headerLen) + if err != nil { + return err + } + case !sameBuffer: + copy(dst, ciphertext) + default: + } + + return nil +} + +func (s *srtpCipherAesCmHmacSha1) encryptRTCP(dst, decrypted []byte, srtcpIndex uint32, ssrc uint32) ([]byte, error) { + authTagLen, err := s.AuthTagRTCPLen() + if err != nil { + return nil, err + } + mkiLen := len(s.mki) + decryptedLen := len(decrypted) + encryptedLen := decryptedLen + authTagLen + mkiLen + srtcpIndexSize + + dst = growBufferSize(dst, encryptedLen) + sameBuffer := isSameBuffer(dst, decrypted) + + if !sameBuffer { + copy(dst, decrypted[:srtcpHeaderSize]) // Copy the first 8 bytes (RTCP header) + } + + // Encrypt everything after header + if s.srtcpEncrypted { + counter := generateCounter(uint16(srtcpIndex&0xffff), srtcpIndex>>16, ssrc, s.srtcpSessionSalt) //nolint:gosec // G115 + if err = xorBytesCTR(s.srtcpBlock, counter[:], dst[srtcpHeaderSize:], decrypted[srtcpHeaderSize:]); err != nil { + return nil, err + } + + // Add SRTCP Index and set Encryption bit + binary.BigEndian.PutUint32(dst[decryptedLen:], srtcpIndex) + dst[decryptedLen] |= srtcpEncryptionFlag + } else { + // Copy the decrypted payload as is + if !sameBuffer { + copy(dst[srtcpHeaderSize:], decrypted[srtcpHeaderSize:]) + } + + // Add SRTCP Index with Encryption bit cleared + binary.BigEndian.PutUint32(dst[decryptedLen:], srtcpIndex) + } + + n := decryptedLen + srtcpIndexSize + + // Generate the authentication tag + authTag, err := s.generateSrtcpAuthTag(dst[:n]) + if err != nil { + return nil, err + } + + // Include the MKI if provided + if len(s.mki) > 0 { + copy(dst[n:], s.mki) + n += mkiLen + } + + // Append the auth tag at the end of the buffer + copy(dst[n:], authTag) + + return dst, nil +} + +func (s *srtpCipherAesCmHmacSha1) decryptRTCP(dst, encrypted []byte, index, ssrc uint32) ([]byte, error) { + authTagLen, err := s.AuthTagRTCPLen() + if err != nil { + return nil, err + } + mkiLen := len(s.mki) + encryptedLen := len(encrypted) + decryptedLen := encryptedLen - (authTagLen + mkiLen + srtcpIndexSize) + if decryptedLen < 8 { + return nil, errTooShortRTCP + } + + expectedTag, err := s.generateSrtcpAuthTag(encrypted[:encryptedLen-mkiLen-authTagLen]) + if err != nil { + return nil, err + } + + actualTag := encrypted[encryptedLen-authTagLen:] + if subtle.ConstantTimeCompare(actualTag, expectedTag) != 1 { + return nil, ErrFailedToVerifyAuthTag + } + + dst = growBufferSize(dst, decryptedLen) + sameBuffer := isSameBuffer(dst, encrypted) + + if !sameBuffer { + copy(dst, encrypted[:srtcpHeaderSize]) // Copy the first 8 bytes (RTCP header) + } + + isEncrypted := encrypted[decryptedLen]&srtcpEncryptionFlag != 0 + if isEncrypted { + counter := generateCounter(uint16(index&0xffff), index>>16, ssrc, s.srtcpSessionSalt) //nolint:gosec // G115 + err = xorBytesCTR(s.srtcpBlock, counter[:], dst[srtcpHeaderSize:], encrypted[srtcpHeaderSize:decryptedLen]) + } else if !sameBuffer { + copy(dst[srtcpHeaderSize:], encrypted[srtcpHeaderSize:]) + } + + return dst, err +} + +func (s *srtpCipherAesCmHmacSha1) generateSrtpAuthTag(buf []byte, roc uint32, rocInAuthTag bool) ([]byte, error) { + // https://tools.ietf.org/html/rfc3711#section-4.2 + // In the case of SRTP, M SHALL consist of the Authenticated + // Portion of the packet (as specified in Figure 1) concatenated with + // the ROC, M = Authenticated Portion || ROC; + // + // The pre-defined authentication transform for SRTP is HMAC-SHA1 + // [RFC2104]. With HMAC-SHA1, the SRTP_PREFIX_LENGTH (Figure 3) SHALL + // be 0. For SRTP (respectively SRTCP), the HMAC SHALL be applied to + // the session authentication key and M as specified above, i.e., + // HMAC(k_a, M). The HMAC output SHALL then be truncated to the n_tag + // left-most bits. + // - Authenticated portion of the packet is everything BEFORE MKI + // - k_a is the session message authentication key + // - n_tag is the bit-length of the output authentication tag + s.srtpSessionAuth.Reset() + + if _, err := s.srtpSessionAuth.Write(buf); err != nil { + return nil, err + } + + // For SRTP only, we need to hash the rollover counter as well. + rocRaw := [4]byte{} + binary.BigEndian.PutUint32(rocRaw[:], roc) + + _, err := s.srtpSessionAuth.Write(rocRaw[:]) + if err != nil { + return nil, err + } + + // Truncate the hash to the size indicated by the profile + authTagLen, err := s.AuthTagRTPLen() + if err != nil { + return nil, err + } + + var authTag []byte + if rocInAuthTag { + authTag = append(authTag, rocRaw[:]...) + } + + return s.srtpSessionAuth.Sum(authTag)[0:authTagLen], nil +} + +func (s *srtpCipherAesCmHmacSha1) generateSrtcpAuthTag(buf []byte) ([]byte, error) { + // https://tools.ietf.org/html/rfc3711#section-4.2 + // + // The pre-defined authentication transform for SRTP is HMAC-SHA1 + // [RFC2104]. With HMAC-SHA1, the SRTP_PREFIX_LENGTH (Figure 3) SHALL + // be 0. For SRTP (respectively SRTCP), the HMAC SHALL be applied to + // the session authentication key and M as specified above, i.e., + // HMAC(k_a, M). The HMAC output SHALL then be truncated to the n_tag + // left-most bits. + // - Authenticated portion of the packet is everything BEFORE MKI + // - k_a is the session message authentication key + // - n_tag is the bit-length of the output authentication tag + s.srtcpSessionAuth.Reset() + + if _, err := s.srtcpSessionAuth.Write(buf); err != nil { + return nil, err + } + authTagLen, err := s.AuthTagRTCPLen() + if err != nil { + return nil, err + } + + return s.srtcpSessionAuth.Sum(nil)[0:authTagLen], nil +} + +func (s *srtpCipherAesCmHmacSha1) getRTCPIndex(in []byte) uint32 { + authTagLen, _ := s.AuthTagRTCPLen() + tailOffset := len(in) - (authTagLen + srtcpIndexSize + len(s.mki)) + srtcpIndexBuffer := in[tailOffset : tailOffset+srtcpIndexSize] + + return binary.BigEndian.Uint32(srtcpIndexBuffer) &^ (1 << 31) +} diff --git a/vendor/github.com/pion/srtp/v3/srtp_cryptex.go b/vendor/github.com/pion/srtp/v3/srtp_cryptex.go new file mode 100644 index 0000000..acc150c --- /dev/null +++ b/vendor/github.com/pion/srtp/v3/srtp_cryptex.go @@ -0,0 +1,189 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package srtp + +import ( + "encoding/binary" + + "github.com/pion/rtp" +) + +/* +RFC 9335: Completely Encrypting RTP Header Extensions and Contributing Sources + +Section 6.2. Encryption Procedure + +When this mechanism [Cryptex] is active, the SRTP packet is protected as follows: + + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+<+ + |V=2|P|X| CC |M| PT | sequence number | | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | + | timestamp | | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | + | synchronization source (SSRC) identifier | | ++>+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ | +| | contributing source (CSRC) identifiers | | +| | .... | | ++>+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | +X | 0xC0 or 0xC2 | 0xDE | length | | ++>+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | +| | RFC 8285 header extensions | | +| +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | +| | payload ... | | +| | +-------------------------------+ | +| | | RTP padding | RTP pad count | | ++>+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+<+ +| ~ SRTP Master Key Identifier (MKI) (OPTIONAL) ~ | +| +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | +| : authentication tag (RECOMMENDED) : | +| +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | +| | ++- Encrypted Portion Authenticated Portion ---+ +Figure 1: A Protected SRTP Packet +Note that, as required by [RFC8285], the 4 bytes at the start of the extension block are not encrypted. + +Specifically, the Encrypted Portion MUST include any CSRC identifiers, any RTP header extension +(except for the first 4 bytes), and the RTP payload. +*/ + +const ( + minSrtpHeaderSize = 12 // Minimum size of the SRTP header (12 bytes for RTP header without CSRCs and extensions) + extensionHeaderSize = 4 // Size of the header extension (4 bytes for profile and length fields) +) + +func isCryptexPacket(header *rtp.Header) bool { + return header.Extension && + (header.ExtensionProfile == rtp.CryptexProfileOneByte || header.ExtensionProfile == rtp.CryptexProfileTwoByte) +} + +func moveHeaderExtensionBeforeCSRCs(header *rtp.Header, buf []byte) { + if len(header.CSRC) == 0 || !header.Extension { + return + } + + var tmp [extensionHeaderSize]byte + csrcLen := len(header.CSRC) * 4 + copy(tmp[:], buf[minSrtpHeaderSize+csrcLen:minSrtpHeaderSize+csrcLen+extensionHeaderSize]) + copy(buf[minSrtpHeaderSize+extensionHeaderSize:], buf[minSrtpHeaderSize:minSrtpHeaderSize+csrcLen]) + copy(buf[minSrtpHeaderSize:], tmp[:]) +} + +func moveCSRCsBeforeHeaderExtension(header *rtp.Header, buf []byte) { + if len(header.CSRC) == 0 || !header.Extension { + return + } + + var tmp [extensionHeaderSize]byte + csrcLen := len(header.CSRC) * 4 + copy(tmp[:], buf[minSrtpHeaderSize:minSrtpHeaderSize+extensionHeaderSize]) + copy(buf[minSrtpHeaderSize:], + buf[minSrtpHeaderSize+extensionHeaderSize:minSrtpHeaderSize+csrcLen+extensionHeaderSize]) + copy(buf[minSrtpHeaderSize+csrcLen:], tmp[:]) +} + +func encryptCryptexRTP(dst, plaintext []byte, sameBuffer bool, header *rtp.Header, + encrypt func(dst, plaintext []byte, headerLen int) error, +) error { + moveHeaderExtensionBeforeCSRCs(header, plaintext) + + // Update Header Extension Profile to Cryptex one + if header.ExtensionProfile == rtp.ExtensionProfileOneByte { + binary.BigEndian.PutUint16(plaintext[minSrtpHeaderSize:], rtp.CryptexProfileOneByte) + } else { + binary.BigEndian.PutUint16(plaintext[minSrtpHeaderSize:], rtp.CryptexProfileTwoByte) + } + + err := encrypt(dst, plaintext, minSrtpHeaderSize+extensionHeaderSize) + if err != nil { + binary.BigEndian.PutUint16(plaintext[minSrtpHeaderSize:], header.ExtensionProfile) + moveCSRCsBeforeHeaderExtension(header, plaintext) + + return err + } + + if !sameBuffer { + copy(dst, plaintext[:minSrtpHeaderSize+extensionHeaderSize]) + binary.BigEndian.PutUint16(plaintext[minSrtpHeaderSize:], header.ExtensionProfile) + moveCSRCsBeforeHeaderExtension(header, plaintext) + } + moveCSRCsBeforeHeaderExtension(header, dst) + + return nil +} + +func decryptCryptexRTP(dst, ciphertext []byte, sameBuffer bool, header *rtp.Header, headerLen int, + decrypt func(dst, ciphertext []byte, headerLen int) error, +) error { + moveHeaderExtensionBeforeCSRCs(header, ciphertext) + err := decrypt(dst, ciphertext, minSrtpHeaderSize+extensionHeaderSize) + if err != nil { + moveCSRCsBeforeHeaderExtension(header, ciphertext) + + return err + } + + if !sameBuffer { + copy(dst, ciphertext[:minSrtpHeaderSize+extensionHeaderSize]) + moveCSRCsBeforeHeaderExtension(header, dst) + } + moveCSRCsBeforeHeaderExtension(header, ciphertext) + + // Update Header Extension Profile + offset := minSrtpHeaderSize + len(header.CSRC)*4 + if header.ExtensionProfile == rtp.CryptexProfileOneByte { + binary.BigEndian.PutUint16(dst[offset:], rtp.ExtensionProfileOneByte) + } else { + binary.BigEndian.PutUint16(dst[offset:], rtp.ExtensionProfileTwoByte) + } + + // Unmarshal decrypted header extension. + n, err := header.Unmarshal(dst) + if err != nil { + return err + } + if n != headerLen { + return errHeaderLengthMismatch + } + + return nil +} + +// RFC 9335, section 5.1: If the packet contains CSRCs but no header extensions, an empty extension block +// consisting of the 0xC0DE tag and a 16-bit length field set to zero (explicitly permitted by [RFC3550]) +// MUST be appended, and the X bit MUST be set to 1 to indicate an extension block is present. + +func needsEmptyExtensionHeader(useCryptex bool, header *rtp.Header) bool { + return useCryptex && len(header.CSRC) > 0 && !header.Extension +} + +// insertEmptyExtensionHeader inserts an empty extension header into the RTP packet. It assumes that the dst is big +// enough to hold extra data. +func insertEmptyExtensionHeader(dst, plaintext []byte, sameBuffer bool, header *rtp.Header) []byte { + header.Extension = true + header.ExtensionProfile = rtp.ExtensionProfileOneByte + header.Extensions = nil + + var emptyExtHdr [extensionHeaderSize]byte + binary.BigEndian.PutUint16(emptyExtHdr[:], rtp.ExtensionProfileOneByte) + + offset := minSrtpHeaderSize + len(header.CSRC)*4 + plaintextLen := len(plaintext) + if sameBuffer { + plaintext = plaintext[:plaintextLen+extensionHeaderSize] + copy(plaintext[offset+extensionHeaderSize:], plaintext[offset:plaintextLen]) + copy(plaintext[offset:], emptyExtHdr[:]) + } else { + newPlaintext := dst[:plaintextLen+extensionHeaderSize] + copy(newPlaintext, plaintext[:offset]) + copy(newPlaintext[offset:], emptyExtHdr[:]) + copy(newPlaintext[offset+extensionHeaderSize:], plaintext[offset:plaintextLen]) + plaintext = newPlaintext + } + + plaintext[0] |= 0x10 // Set the X bit to indicate an extension block is present + + return plaintext +} diff --git a/vendor/github.com/pion/srtp/v3/stream.go b/vendor/github.com/pion/srtp/v3/stream.go new file mode 100644 index 0000000..5f9c58a --- /dev/null +++ b/vendor/github.com/pion/srtp/v3/stream.go @@ -0,0 +1,11 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package srtp + +type readStream interface { + init(child streamSession, ssrc uint32) error + + Read(buf []byte) (int, error) + GetSSRC() uint32 +} diff --git a/vendor/github.com/pion/srtp/v3/stream_srtcp.go b/vendor/github.com/pion/srtp/v3/stream_srtcp.go new file mode 100644 index 0000000..87f4d98 --- /dev/null +++ b/vendor/github.com/pion/srtp/v3/stream_srtcp.go @@ -0,0 +1,162 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package srtp + +import ( + "errors" + "io" + "sync" + "time" + + "github.com/pion/rtcp" + "github.com/pion/transport/v4/packetio" +) + +// Limit the buffer size to 100KB. +const srtcpBufferSize = 100 * 1000 + +// ReadStreamSRTCP handles decryption for a single RTCP SSRC. +type ReadStreamSRTCP struct { + mu sync.Mutex + + isClosed chan bool + + session *SessionSRTCP + ssrc uint32 + isInited bool + + buffer io.ReadWriteCloser +} + +func (r *ReadStreamSRTCP) write(buf []byte) (n int, err error) { + n, err = r.buffer.Write(buf) + + if errors.Is(err, packetio.ErrFull) { + // Silently drop data when the buffer is full. + return len(buf), nil + } + + return n, err +} + +// Used by getOrCreateReadStream. +func newReadStreamSRTCP() readStream { + return &ReadStreamSRTCP{} +} + +// ReadRTCP reads and decrypts full RTCP packet and its header from the nextConn. +func (r *ReadStreamSRTCP) ReadRTCP(buf []byte) (int, *rtcp.Header, error) { + n, err := r.Read(buf) + if err != nil { + return 0, nil, err + } + + header := &rtcp.Header{} + err = header.Unmarshal(buf[:n]) + if err != nil { + return 0, nil, err + } + + return n, header, nil +} + +// Read reads and decrypts full RTCP packet from the nextConn. +func (r *ReadStreamSRTCP) Read(buf []byte) (int, error) { + return r.buffer.Read(buf) +} + +// SetReadDeadline sets the deadline for the Read operation. +// Setting to zero means no deadline. +func (r *ReadStreamSRTCP) SetReadDeadline(t time.Time) error { + if b, ok := r.buffer.(interface { + SetReadDeadline(time.Time) error + }); ok { + return b.SetReadDeadline(t) + } + + return nil +} + +// Close removes the ReadStream from the session and cleans up any associated state. +func (r *ReadStreamSRTCP) Close() error { + r.mu.Lock() + defer r.mu.Unlock() + + if !r.isInited { + return errStreamNotInited + } + + select { + case <-r.isClosed: + return errStreamAlreadyClosed + default: + err := r.buffer.Close() + if err != nil { + return err + } + + r.session.removeReadStream(r.ssrc) + + return nil + } +} + +func (r *ReadStreamSRTCP) init(child streamSession, ssrc uint32) error { + sessionSRTCP, ok := child.(*SessionSRTCP) + + r.mu.Lock() + defer r.mu.Unlock() + if !ok { + return errFailedTypeAssertion + } else if r.isInited { + return errStreamAlreadyInited + } + + r.session = sessionSRTCP + r.ssrc = ssrc + r.isInited = true + r.isClosed = make(chan bool) + + if r.session.bufferFactory != nil { + r.buffer = r.session.bufferFactory(packetio.RTCPBufferPacket, ssrc) + } else { + // Create a buffer and limit it to 100KB + buff := packetio.NewBuffer() + buff.SetLimitSize(srtcpBufferSize) + r.buffer = buff + } + + return nil +} + +// GetSSRC returns the SSRC we are demuxing for. +func (r *ReadStreamSRTCP) GetSSRC() uint32 { + return r.ssrc +} + +// WriteStreamSRTCP is stream for a single Session that is used to encrypt RTCP. +type WriteStreamSRTCP struct { + session *SessionSRTCP +} + +// WriteRTCP encrypts a RTCP header and its payload to the nextConn. +func (w *WriteStreamSRTCP) WriteRTCP(header *rtcp.Header, payload []byte) (int, error) { + headerRaw, err := header.Marshal() + if err != nil { + return 0, err + } + + return w.session.write(append(headerRaw, payload...)) +} + +// Write encrypts and writes a full RTCP packets to the nextConn. +func (w *WriteStreamSRTCP) Write(b []byte) (int, error) { + return w.session.write(b) +} + +// SetWriteDeadline sets the deadline for the Write operation. +// Setting to zero means no deadline. +func (w *WriteStreamSRTCP) SetWriteDeadline(t time.Time) error { + return w.session.setWriteDeadline(t) +} diff --git a/vendor/github.com/pion/srtp/v3/stream_srtp.go b/vendor/github.com/pion/srtp/v3/stream_srtp.go new file mode 100644 index 0000000..dd54984 --- /dev/null +++ b/vendor/github.com/pion/srtp/v3/stream_srtp.go @@ -0,0 +1,184 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package srtp + +import ( + "errors" + "io" + "slices" + "sync" + "time" + + "github.com/pion/rtp" + "github.com/pion/transport/v4/packetio" +) + +// Limit the buffer size to 1MB. +const srtpBufferSize = 1000 * 1000 + +// ReadStreamSRTP handles decryption for a single RTP SSRC. +type ReadStreamSRTP struct { + mu sync.Mutex + + isClosed chan bool + + session *SessionSRTP + ssrc uint32 + isInited bool + + buffer io.ReadWriteCloser + peekedPackets [][]byte +} + +// Used by getOrCreateReadStream. +func newReadStreamSRTP() readStream { + return &ReadStreamSRTP{} +} + +func (r *ReadStreamSRTP) init(child streamSession, ssrc uint32) error { + sessionSRTP, ok := child.(*SessionSRTP) + + r.mu.Lock() + defer r.mu.Unlock() + + if !ok { + return errFailedTypeAssertion + } else if r.isInited { + return errStreamAlreadyInited + } + + r.session = sessionSRTP + r.ssrc = ssrc + r.isInited = true + r.isClosed = make(chan bool) + + // Create a buffer with a 1MB limit + if r.session.bufferFactory != nil { + r.buffer = r.session.bufferFactory(packetio.RTPBufferPacket, ssrc) + } else { + buff := packetio.NewBuffer() + buff.SetLimitSize(srtpBufferSize) + r.buffer = buff + } + + return nil +} + +func (r *ReadStreamSRTP) write(buf []byte) (n int, err error) { + n, err = r.buffer.Write(buf) + + if errors.Is(err, packetio.ErrFull) { + // Silently drop data when the buffer is full. + return len(buf), nil + } + + return n, err +} + +// Peek reads and decrypts full RTP packet from the nextConn. +// It is then buffered so that a call to `Read` will return it. +func (r *ReadStreamSRTP) Peek(buf []byte) (n int, err error) { + n, err = r.buffer.Read(buf) + if err == nil { + r.peekedPackets = append(r.peekedPackets, slices.Clone(buf[:n])) + } + + return +} + +// Read reads and decrypts full RTP packet from the nextConn. +func (r *ReadStreamSRTP) Read(buf []byte) (int, error) { + if len(r.peekedPackets) != 0 { + if len(r.peekedPackets[0]) > len(buf) { + return 0, io.ErrShortBuffer + } + + n := len(r.peekedPackets[0]) + copy(buf, r.peekedPackets[0]) + r.peekedPackets = r.peekedPackets[1:] + + return n, nil + } + + return r.buffer.Read(buf) +} + +// ReadRTP reads and decrypts full RTP packet and its header from the nextConn. +func (r *ReadStreamSRTP) ReadRTP(buf []byte) (int, *rtp.Header, error) { + n, err := r.Read(buf) + if err != nil { + return 0, nil, err + } + + header := &rtp.Header{} + + _, err = header.Unmarshal(buf[:n]) + if err != nil { + return 0, nil, err + } + + return n, header, nil +} + +// SetReadDeadline sets the deadline for the Read operation. +// Setting to zero means no deadline. +func (r *ReadStreamSRTP) SetReadDeadline(t time.Time) error { + if b, ok := r.buffer.(interface { + SetReadDeadline(time.Time) error + }); ok { + return b.SetReadDeadline(t) + } + + return nil +} + +// Close removes the ReadStream from the session and cleans up any associated state. +func (r *ReadStreamSRTP) Close() error { + r.mu.Lock() + defer r.mu.Unlock() + + if !r.isInited { + return errStreamNotInited + } + + select { + case <-r.isClosed: + return errStreamAlreadyClosed + default: + err := r.buffer.Close() + if err != nil { + return err + } + + r.session.removeReadStream(r.ssrc) + + return nil + } +} + +// GetSSRC returns the SSRC we are demuxing for. +func (r *ReadStreamSRTP) GetSSRC() uint32 { + return r.ssrc +} + +// WriteStreamSRTP is stream for a single Session that is used to encrypt RTP. +type WriteStreamSRTP struct { + session *SessionSRTP +} + +// WriteRTP encrypts a RTP packet and writes to the connection. +func (w *WriteStreamSRTP) WriteRTP(header *rtp.Header, payload []byte) (int, error) { + return w.session.writeRTP(header, payload) +} + +// Write encrypts and writes a full RTP packets to the nextConn. +func (w *WriteStreamSRTP) Write(b []byte) (int, error) { + return w.session.write(b) +} + +// SetWriteDeadline sets the deadline for the Write operation. +// Setting to zero means no deadline. +func (w *WriteStreamSRTP) SetWriteDeadline(t time.Time) error { + return w.session.setWriteDeadline(t) +} diff --git a/vendor/github.com/pion/srtp/v3/util.go b/vendor/github.com/pion/srtp/v3/util.go new file mode 100644 index 0000000..8411601 --- /dev/null +++ b/vendor/github.com/pion/srtp/v3/util.go @@ -0,0 +1,39 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package srtp + +import ( + "unsafe" +) + +// growBufferSize grows the buffer size to the given number of bytes. +func growBufferSize(buf []byte, size int) []byte { + if size <= cap(buf) { + return buf[:size] + } + + buf2 := make([]byte, size) + copy(buf2, buf) + + return buf2 +} + +// isSameBuffer returns true if slices a and b share the same underlying buffer. +func isSameBuffer(a, b []byte) bool { + // If both are nil, they are technically the same (no buffer) + if a == nil && b == nil { + return true + } + + // If either is nil, or both have 0 capacity, they can't share backing buffer + if cap(a) == 0 || cap(b) == 0 { + return false + } + + // Create a slice of length 1 from each if possible + aPtr := unsafe.Pointer(&a[:1][0]) // nolint:gosec + bPtr := unsafe.Pointer(&b[:1][0]) // nolint:gosec + + return aPtr == bPtr +} diff --git a/vendor/github.com/pion/stun/v3/.gitignore b/vendor/github.com/pion/stun/v3/.gitignore new file mode 100644 index 0000000..6e2f206 --- /dev/null +++ b/vendor/github.com/pion/stun/v3/.gitignore @@ -0,0 +1,28 @@ +# SPDX-FileCopyrightText: 2023 The Pion community +# SPDX-License-Identifier: MIT + +### JetBrains IDE ### +##################### +.idea/ + +### Emacs Temporary Files ### +############################# +*~ + +### Folders ### +############### +bin/ +vendor/ +node_modules/ + +### Files ### +############# +*.ivf +*.ogg +tags +cover.out +*.sw[poe] +*.wasm +examples/sfu-ws/cert.pem +examples/sfu-ws/key.pem +wasm_exec.js diff --git a/vendor/github.com/pion/stun/v3/.golangci.yml b/vendor/github.com/pion/stun/v3/.golangci.yml new file mode 100644 index 0000000..4b4025f --- /dev/null +++ b/vendor/github.com/pion/stun/v3/.golangci.yml @@ -0,0 +1,147 @@ +# SPDX-FileCopyrightText: 2023 The Pion community +# SPDX-License-Identifier: MIT + +version: "2" +linters: + enable: + - asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers + - bidichk # Checks for dangerous unicode character sequences + - bodyclose # checks whether HTTP response body is closed successfully + - containedctx # containedctx is a linter that detects struct contained context.Context field + - contextcheck # check the function whether use a non-inherited context + - cyclop # checks function and package cyclomatic complexity + - decorder # check declaration order and count of types, constants, variables and functions + - dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f()) + - dupl # Tool for code clone detection + - durationcheck # check for two durations multiplied together + - err113 # Golang linter to check the errors handling expressions + - errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases + - errchkjson # Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occations, where the check for the returned error can be omitted. + - errname # Checks that sentinel errors are prefixed with the `Err` and error types are suffixed with the `Error`. + - errorlint # errorlint is a linter for that can be used to find code that will cause problems with the error wrapping scheme introduced in Go 1.13. + - exhaustive # check exhaustiveness of enum switch statements + - forbidigo # Forbids identifiers + - forcetypeassert # finds forced type assertions + - gochecknoglobals # Checks that no globals are present in Go code + - gocognit # Computes and checks the cognitive complexity of functions + - goconst # Finds repeated strings that could be replaced by a constant + - gocritic # The most opinionated Go source code linter + - gocyclo # Computes and checks the cyclomatic complexity of functions + - godot # Check if comments end in a period + - godox # Tool for detection of FIXME, TODO and other comment keywords + - goheader # Checks is file header matches to pattern + - gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod. + - goprintffuncname # Checks that printf-like functions are named with `f` at the end + - gosec # Inspects source code for security problems + - govet # Vet examines Go source code and reports suspicious constructs, such as Printf calls whose arguments do not align with the format string + - grouper # An analyzer to analyze expression groups. + - importas # Enforces consistent import aliases + - ineffassign # Detects when assignments to existing variables are not used + - lll # Reports long lines + - maintidx # maintidx measures the maintainability index of each function. + - makezero # Finds slice declarations with non-zero initial length + - misspell # Finds commonly misspelled English words in comments + - nakedret # Finds naked returns in functions greater than a specified function length + - nestif # Reports deeply nested if statements + - nilerr # Finds the code that returns nil even if it checks that the error is not nil. + - nilnil # Checks that there is no simultaneous return of `nil` error and an invalid value. + - nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity + - noctx # noctx finds sending http request without context.Context + - predeclared # find code that shadows one of Go's predeclared identifiers + - revive # golint replacement, finds style mistakes + - staticcheck # Staticcheck is a go vet on steroids, applying a ton of static analysis checks + - tagliatelle # Checks the struct tags. + - thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers + - unconvert # Remove unnecessary type conversions + - unparam # Reports unused function parameters + - unused # Checks Go code for unused constants, variables, functions and types + - varnamelen # checks that the length of a variable's name matches its scope + - wastedassign # wastedassign finds wasted assignment statements + - whitespace # Tool for detection of leading and trailing whitespace + disable: + - depguard # Go linter that checks if package imports are in a list of acceptable packages + - funlen # Tool for detection of long functions + - gochecknoinits # Checks that no init functions are present in Go code + - gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations. + - interfacebloat # A linter that checks length of interface. + - ireturn # Accept Interfaces, Return Concrete Types + - mnd # An analyzer to detect magic numbers + - nolintlint # Reports ill-formed or insufficient nolint directives + - paralleltest # paralleltest detects missing usage of t.Parallel() method in your Go test + - prealloc # Finds slice declarations that could potentially be preallocated + - promlinter # Check Prometheus metrics naming via promlint + - rowserrcheck # checks whether Err of rows is checked successfully + - sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed. + - testpackage # linter that makes you use a separate _test package + - tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes + - wrapcheck # Checks that errors returned from external packages are wrapped + - wsl # Whitespace Linter - Forces you to use empty lines! + settings: + staticcheck: + checks: + - all + - -QF1008 # "could remove embedded field", to keep it explicit! + - -QF1003 # "could use tagged switch on enum", Cases conflicts with exhaustive! + exhaustive: + default-signifies-exhaustive: true + forbidigo: + forbid: + - pattern: ^fmt.Print(f|ln)?$ + - pattern: ^log.(Panic|Fatal|Print)(f|ln)?$ + - pattern: ^os.Exit$ + - pattern: ^panic$ + - pattern: ^print(ln)?$ + - pattern: ^testing.T.(Error|Errorf|Fatal|Fatalf|Fail|FailNow)$ + pkg: ^testing$ + msg: use testify/assert instead + analyze-types: true + gomodguard: + blocked: + modules: + - github.com/pkg/errors: + recommendations: + - errors + govet: + enable: + - shadow + revive: + rules: + # Prefer 'any' type alias over 'interface{}' for Go 1.18+ compatibility + - name: use-any + severity: warning + disabled: false + misspell: + locale: US + varnamelen: + max-distance: 12 + min-name-length: 2 + ignore-type-assert-ok: true + ignore-map-index-ok: true + ignore-chan-recv-ok: true + ignore-decls: + - i int + - n int + - w io.Writer + - r io.Reader + - b []byte + exclusions: + generated: lax + rules: + - linters: + - forbidigo + - gocognit + path: (examples|main\.go) + - linters: + - gocognit + path: _test\.go + - linters: + - forbidigo + path: cmd +formatters: + enable: + - gci # Gci control golang package import order and make it always deterministic. + - gofmt # Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification + - gofumpt # Gofumpt checks whether code was gofumpt-ed. + - goimports # Goimports does everything that gofmt does. Additionally it checks unused imports + exclusions: + generated: lax diff --git a/vendor/github.com/pion/stun/v3/.goreleaser.yml b/vendor/github.com/pion/stun/v3/.goreleaser.yml new file mode 100644 index 0000000..30093e9 --- /dev/null +++ b/vendor/github.com/pion/stun/v3/.goreleaser.yml @@ -0,0 +1,5 @@ +# SPDX-FileCopyrightText: 2023 The Pion community +# SPDX-License-Identifier: MIT + +builds: +- skip: true diff --git a/vendor/github.com/pion/stun/v3/LICENSE b/vendor/github.com/pion/stun/v3/LICENSE new file mode 100644 index 0000000..491caf6 --- /dev/null +++ b/vendor/github.com/pion/stun/v3/LICENSE @@ -0,0 +1,9 @@ +MIT License + +Copyright (c) 2023 The Pion community + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/vendor/github.com/pion/stun/v3/Makefile b/vendor/github.com/pion/stun/v3/Makefile new file mode 100644 index 0000000..eaaefa5 --- /dev/null +++ b/vendor/github.com/pion/stun/v3/Makefile @@ -0,0 +1,37 @@ +# SPDX-FileCopyrightText: 2023 The Pion community +# SPDX-License-Identifier: MIT + +VERSION := $(shell git describe --tags | sed -e 's/^v//g' | awk -F "-" '{print $$1}') +ITERATION := $(shell git describe --tags --long | awk -F "-" '{print $$2}') +GO_VERSION=$(shell gobuild -v) +GO := $(or $(GOROOT),/usr/lib/go)/bin/go +PROCS := $(shell nproc) +cores: + @echo "cores: $(PROCS)" +bench: + go test -bench . +bench-record: + $(GO) test -bench . > "benchmarks/stun-go-$(GO_VERSION).txt" +lint: + @golangci-lint run ./... + @echo "ok" +escape: + @echo "Not escapes, except autogenerated:" + @go build -gcflags '-m -l' 2>&1 \ + | grep -v "" \ + | grep escapes +format: + goimports -w . +bench-compare: + go test -bench . > bench.go-16 + go-tip test -bench . > bench.go-tip + @benchcmp bench.go-16 bench.go-tip +install: + go get gortc.io/api + go get -u github.com/golangci/golangci-lint/cmd/golangci-lint +test-integration: + @cd e2e && bash ./test.sh +prepush: test lint test-integration +test: + @./go.test.sh +clean: diff --git a/vendor/github.com/pion/stun/v3/README.md b/vendor/github.com/pion/stun/v3/README.md new file mode 100644 index 0000000..06e1491 --- /dev/null +++ b/vendor/github.com/pion/stun/v3/README.md @@ -0,0 +1,186 @@ +

+
+ Pion STUN +
+

+

A Go implementation of STUN

+

+ Pion stun + join us on Discord Follow us on Bluesky +
+ GitHub Workflow Status + Go Reference + Coverage Status + Go Report Card + License: MIT +

+
+ +Package `stun` implements Session Traversal Utilities for NAT (STUN) ([RFC 5389][rfc5389]) +protocol and [client](https://pkg.go.dev/github.com/pion/stun#Client) with no external dependencies and zero allocations in hot paths. +Client [supports](https://pkg.go.dev/github.com/pion/stun#WithRTO) automatic request retransmissions. + +### Example +You can get your current IP address from any STUN server by sending +binding request. See more idiomatic example at `cmd/stun-client`. +```go +package main + +import ( + "fmt" + + "github.com/pion/stun" +) + +func main() { + // Parse a STUN URI + u, err := stun.ParseURI("stun:stun.l.google.com:19302") + if err != nil { + panic(err) + } + + // Creating a "connection" to STUN server. + c, err := stun.DialURI(u, &stun.DialConfig{}) + if err != nil { + panic(err) + } + // Building binding request with random transaction id. + message := stun.MustBuild(stun.TransactionID, stun.BindingRequest) + // Sending request to STUN server, waiting for response message. + if err := c.Do(message, func(res stun.Event) { + if res.Error != nil { + panic(res.Error) + } + // Decoding XOR-MAPPED-ADDRESS attribute from message. + var xorAddr stun.XORMappedAddress + if err := xorAddr.GetFrom(res.Message); err != nil { + panic(err) + } + fmt.Println("your IP is", xorAddr.IP) + }); err != nil { + panic(err) + } +} +``` + +### RFCs +#### Implemented +- **RFC 5389**: [Session Traversal Utilities for NAT (STUN)][rfc5389] +- **RFC 5769**: [Test Vectors for Session Traversal Utilities for NAT (STUN)][rfc5769] +- **RFC 6062**: [Traversal Using Relays around NAT (TURN) Extensions for TCP Allocations][rfc6062] +- **RFC 7064**: [URI Scheme for the Session Traversal Utilities for NAT (STUN) Protocol][rfc7064] +- **RFC 7065**: [Traversal Using Relays around NAT (TURN) Uniform Resource Identifiers][rfc7065] +- **RFC 5780**: [NAT Behavior Discovery Using Session Traversal Utilities for NAT (STUN)][rfc5780] via [cmd/stun-nat-behaviour](cmd/stun-nat-behaviour) +- (TLS-over-)TCP client support + +#### Planned +- **RFC 5389**: [ALTERNATE-SERVER](https://tools.ietf.org/html/rfc5389#section-11) support [#48](https://github.com/pion/stun/issues/48) + +#### Compatability notes + +[RFC 5389][rfc5389] obsoletes [RFC 3489][rfc3489], so implementation was ignored by purpose, however, +[RFC 3489][rfc3489] can be easily implemented as separate package. + +[rfc3489]: https://tools.ietf.org/html/rfc3489 +[rfc5389]: https://tools.ietf.org/html/rfc5389 +[rfc5769]: https://tools.ietf.org/html/rfc5769 +[rfc5780]: https://tools.ietf.org/html/rfc5780 +[rfc6062]: https://tools.ietf.org/html/rfc6062 +[rfc7064]: https://tools.ietf.org/html/rfc7064 +[rfc7065]: https://tools.ietf.org/html/rfc7065 + +### Stability +Package is currently stable, no backward incompatible changes are expected +with exception of critical bugs or security fixes. + +Additional attributes are unlikely to be implemented in scope of stun package, +the only exception is constants for attribute or message types. + +### Requirements +Go 1.12 is currently supported and tested in CI. + +### Testing +Client behavior is tested and verified in many ways: + * End-To-End with long-term credentials + * **coturn**: The coturn [server](https://github.com/coturn/coturn/wiki/turnserver) (linux) + * Bunch of code static checkers (linters) + * Standard unit-tests with coverage reporting (linux {amd64, **arm**64}, windows and darwin) + * Explicit API backward compatibility [check](https://github.com/gortc/api), see `api` directory + +See [TeamCity project](https://tc.gortc.io/project.html?projectId=stun&guest=1) and `e2e` directory +for more information. Also the Wireshark `.pcap` files are available for e2e test in +artifacts for build. + +### Benchmarks +Intel(R) Core(TM) i7-8700K: + +``` +version: 1.22.2 +goos: linux +goarch: amd64 +pkg: github.com/pion/stun +PASS +benchmark iter time/iter throughput bytes alloc allocs +--------- ---- --------- ---------- ----------- ------ +BenchmarkMappedAddress_AddTo-12 32489450 38.30 ns/op 0 B/op 0 allocs/op +BenchmarkAlternateServer_AddTo-12 31230991 39.00 ns/op 0 B/op 0 allocs/op +BenchmarkAgent_GC-12 431390 2918.00 ns/op 0 B/op 0 allocs/op +BenchmarkAgent_Process-12 35901940 36.20 ns/op 0 B/op 0 allocs/op +BenchmarkMessage_GetNotFound-12 242004358 5.19 ns/op 0 B/op 0 allocs/op +BenchmarkMessage_Get-12 230520343 5.21 ns/op 0 B/op 0 allocs/op +BenchmarkClient_Do-12 1282231 943.00 ns/op 0 B/op 0 allocs/op +BenchmarkErrorCode_AddTo-12 16318916 75.50 ns/op 0 B/op 0 allocs/op +BenchmarkErrorCodeAttribute_AddTo-12 21584140 54.80 ns/op 0 B/op 0 allocs/op +BenchmarkErrorCodeAttribute_GetFrom-12 100000000 11.10 ns/op 0 B/op 0 allocs/op +BenchmarkFingerprint_AddTo-12 19368768 64.00 ns/op 687.81 MB/s 0 B/op 0 allocs/op +BenchmarkFingerprint_Check-12 24167007 49.10 ns/op 1057.99 MB/s 0 B/op 0 allocs/op +BenchmarkBuildOverhead/Build-12 5486252 224.00 ns/op 0 B/op 0 allocs/op +BenchmarkBuildOverhead/BuildNonPointer-12 2496544 517.00 ns/op 100 B/op 4 allocs/op +BenchmarkBuildOverhead/Raw-12 6652118 181.00 ns/op 0 B/op 0 allocs/op +BenchmarkMessage_ForEach-12 28254212 35.90 ns/op 0 B/op 0 allocs/op +BenchmarkMessageIntegrity_AddTo-12 1000000 1179.00 ns/op 16.96 MB/s 0 B/op 0 allocs/op +BenchmarkMessageIntegrity_Check-12 975954 1219.00 ns/op 26.24 MB/s 0 B/op 0 allocs/op +BenchmarkMessage_Write-12 41040598 30.40 ns/op 922.13 MB/s 0 B/op 0 allocs/op +BenchmarkMessageType_Value-12 1000000000 0.53 ns/op 0 B/op 0 allocs/op +BenchmarkMessage_WriteTo-12 94942935 11.30 ns/op 0 B/op 0 allocs/op +BenchmarkMessage_ReadFrom-12 43437718 29.30 ns/op 682.87 MB/s 0 B/op 0 allocs/op +BenchmarkMessage_ReadBytes-12 74693397 15.90 ns/op 1257.42 MB/s 0 B/op 0 allocs/op +BenchmarkIsMessage-12 1000000000 1.20 ns/op 16653.64 MB/s 0 B/op 0 allocs/op +BenchmarkMessage_NewTransactionID-12 521121 2450.00 ns/op 0 B/op 0 allocs/op +BenchmarkMessageFull-12 5389495 221.00 ns/op 0 B/op 0 allocs/op +BenchmarkMessageFullHardcore-12 12715876 94.40 ns/op 0 B/op 0 allocs/op +BenchmarkMessage_WriteHeader-12 100000000 11.60 ns/op 0 B/op 0 allocs/op +BenchmarkMessage_CloneTo-12 30199020 41.80 ns/op 1626.66 MB/s 0 B/op 0 allocs/op +BenchmarkMessage_AddTo-12 415257625 2.97 ns/op 0 B/op 0 allocs/op +BenchmarkDecode-12 49573747 23.60 ns/op 0 B/op 0 allocs/op +BenchmarkUsername_AddTo-12 56282674 22.50 ns/op 0 B/op 0 allocs/op +BenchmarkUsername_GetFrom-12 100000000 10.10 ns/op 0 B/op 0 allocs/op +BenchmarkNonce_AddTo-12 39419097 35.80 ns/op 0 B/op 0 allocs/op +BenchmarkNonce_AddTo_BadLength-12 196291666 6.04 ns/op 0 B/op 0 allocs/op +BenchmarkNonce_GetFrom-12 120857732 9.93 ns/op 0 B/op 0 allocs/op +BenchmarkUnknownAttributes/AddTo-12 28881430 37.20 ns/op 0 B/op 0 allocs/op +BenchmarkUnknownAttributes/GetFrom-12 64907534 19.80 ns/op 0 B/op 0 allocs/op +BenchmarkXOR-12 32868506 32.20 ns/op 31836.66 MB/s +BenchmarkXORSafe-12 5185776 234.00 ns/op 4378.74 MB/s +BenchmarkXORFast-12 30975679 32.50 ns/op 31525.28 MB/s +BenchmarkXORMappedAddress_AddTo-12 21518028 54.50 ns/op 0 B/op 0 allocs/op +BenchmarkXORMappedAddress_GetFrom-12 35597667 34.40 ns/op 0 B/op 0 allocs/op +ok github.com/pion/stun 60.973s +``` + +### Roadmap +The library is used as a part of our WebRTC implementation. Please refer to that [roadmap](https://github.com/pion/webrtc/issues/9) to track our major milestones. + +### Community +Pion has an active community on the [Discord](https://discord.gg/PngbdqpFbt). + +Follow the [Pion Bluesky](https://bsky.app/profile/pion.ly) or [Pion Twitter](https://twitter.com/_pion) for project updates and important WebRTC news. + +We are always looking to support **your projects**. Please reach out if you have something to build! +If you need commercial support or don't want to use public methods you can contact us at [team@pion.ly](mailto:team@pion.ly) + +### Contributing +Check out the [contributing wiki](https://github.com/pion/webrtc/wiki/Contributing) to join the group of amazing people making this project possible + +### License +MIT License - see [LICENSE](LICENSE) for full text diff --git a/vendor/github.com/pion/stun/v3/addr.go b/vendor/github.com/pion/stun/v3/addr.go new file mode 100644 index 0000000..37cbf52 --- /dev/null +++ b/vendor/github.com/pion/stun/v3/addr.go @@ -0,0 +1,168 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package stun + +import ( + "fmt" + "io" + "net" + "strconv" +) + +// MappedAddress represents MAPPED-ADDRESS attribute. +// +// This attribute is used only by servers for achieving backwards +// compatibility with RFC 3489 clients. +// +// RFC 5389 Section 15.1. +type MappedAddress struct { + IP net.IP + Port int +} + +// AlternateServer represents ALTERNATE-SERVER attribute. +// +// RFC 5389 Section 15.11. +type AlternateServer struct { + IP net.IP + Port int +} + +// ResponseOrigin represents RESPONSE-ORIGIN attribute. +// +// RFC 5780 Section 7.3. +type ResponseOrigin struct { + IP net.IP + Port int +} + +// OtherAddress represents OTHER-ADDRESS attribute. +// +// RFC 5780 Section 7.4. +type OtherAddress struct { + IP net.IP + Port int +} + +// AddTo adds ALTERNATE-SERVER attribute to message. +func (s *AlternateServer) AddTo(m *Message) error { + a := (*MappedAddress)(s) + + return a.AddToAs(m, AttrAlternateServer) +} + +// GetFrom decodes ALTERNATE-SERVER from message. +func (s *AlternateServer) GetFrom(m *Message) error { + a := (*MappedAddress)(s) + + return a.GetFromAs(m, AttrAlternateServer) +} + +func (a MappedAddress) String() string { + return net.JoinHostPort(a.IP.String(), strconv.Itoa(a.Port)) +} + +// GetFromAs decodes MAPPED-ADDRESS value in message m as an attribute of type t. +func (a *MappedAddress) GetFromAs(m *Message, t AttrType) error { + value, err := m.Get(t) + if err != nil { + return err + } + if len(value) <= 4 { + return io.ErrUnexpectedEOF + } + family := bin.Uint16(value[0:2]) + if family != familyIPv6 && family != familyIPv4 { + return newDecodeErr("xor-mapped address", "family", + fmt.Sprintf("bad value %d", family), + ) + } + ipLen := net.IPv4len + if family == familyIPv6 { + ipLen = net.IPv6len + } + // Ensuring len(a.IP) == ipLen and reusing a.IP. + if len(a.IP) < ipLen { + a.IP = make(net.IP, ipLen) + } else { + a.IP = a.IP[:ipLen] + for i := range a.IP { + a.IP[i] = 0 + } + } + a.Port = int(bin.Uint16(value[2:4])) + copy(a.IP, value[4:]) + + return nil +} + +// AddToAs adds MAPPED-ADDRESS value to m as t attribute. +func (a *MappedAddress) AddToAs(msg *Message, attrType AttrType) error { + var ( + family = familyIPv4 + ip = a.IP + ) + if len(a.IP) == net.IPv6len { + if isIPv4(ip) { + ip = ip[12:16] // like in ip.To4() + } else { + family = familyIPv6 + } + } else if len(ip) != net.IPv4len { + return ErrBadIPLength + } + value := make([]byte, 128) + bin.PutUint16(value[0:2], family) + bin.PutUint16(value[2:4], uint16(a.Port)) //nolint:gosec //G115 + copy(value[4:], ip) + msg.Add(attrType, value[:4+len(ip)]) + + return nil +} + +// AddTo adds MAPPED-ADDRESS to message. +func (a *MappedAddress) AddTo(m *Message) error { + return a.AddToAs(m, AttrMappedAddress) +} + +// GetFrom decodes MAPPED-ADDRESS from message. +func (a *MappedAddress) GetFrom(m *Message) error { + return a.GetFromAs(m, AttrMappedAddress) +} + +// AddTo adds OTHER-ADDRESS attribute to message. +func (o *OtherAddress) AddTo(m *Message) error { + a := (*MappedAddress)(o) + + return a.AddToAs(m, AttrOtherAddress) +} + +// GetFrom decodes OTHER-ADDRESS from message. +func (o *OtherAddress) GetFrom(m *Message) error { + a := (*MappedAddress)(o) + + return a.GetFromAs(m, AttrOtherAddress) +} + +func (o OtherAddress) String() string { + return net.JoinHostPort(o.IP.String(), strconv.Itoa(o.Port)) +} + +// AddTo adds RESPONSE-ORIGIN attribute to message. +func (o *ResponseOrigin) AddTo(m *Message) error { + a := (*MappedAddress)(o) + + return a.AddToAs(m, AttrResponseOrigin) +} + +// GetFrom decodes RESPONSE-ORIGIN from message. +func (o *ResponseOrigin) GetFrom(m *Message) error { + a := (*MappedAddress)(o) + + return a.GetFromAs(m, AttrResponseOrigin) +} + +func (o ResponseOrigin) String() string { + return net.JoinHostPort(o.IP.String(), strconv.Itoa(o.Port)) +} diff --git a/vendor/github.com/pion/stun/v3/agent.go b/vendor/github.com/pion/stun/v3/agent.go new file mode 100644 index 0000000..4f07dce --- /dev/null +++ b/vendor/github.com/pion/stun/v3/agent.go @@ -0,0 +1,245 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package stun + +import ( + "errors" + "sync" + "time" +) + +// NoopHandler just discards any event. +func NoopHandler() Handler { + return func(Event) {} +} + +// NewAgent initializes and returns new Agent with provided handler. +// If h is nil, the NoopHandler will be used. +func NewAgent(h Handler) *Agent { + if h == nil { + h = NoopHandler() + } + a := &Agent{ + transactions: make(map[transactionID]agentTransaction), + handler: h, + } + + return a +} + +// Agent is low-level abstraction over transaction list that +// handles concurrency (all calls are goroutine-safe) and +// time outs (via Collect call). +type Agent struct { + // transactions is map of transactions that are currently + // in progress. Event handling is done in such way when + // transaction is unregistered before agentTransaction access, + // minimizing mux lock and protecting agentTransaction from + // data races via unexpected concurrent access. + transactions map[transactionID]agentTransaction + closed bool // all calls are invalid if true + mux sync.Mutex // protects transactions and closed + handler Handler // handles transactions +} + +// Handler handles state changes of transaction. +// +// Handler is called on transaction state change. +// Usage of e is valid only during call, user must +// copy needed fields explicitly. +type Handler func(e Event) + +// Event is passed to Handler describing the transaction event. +// Do not reuse outside Handler. +type Event struct { + TransactionID [TransactionIDSize]byte + Message *Message + Error error +} + +// agentTransaction represents transaction in progress. +// Concurrent access is invalid. +type agentTransaction struct { + id transactionID + deadline time.Time +} + +var ( + // ErrTransactionStopped indicates that transaction was manually stopped. + ErrTransactionStopped = errors.New("transaction is stopped") + // ErrTransactionNotExists indicates that agent failed to find transaction. + ErrTransactionNotExists = errors.New("transaction not exists") + // ErrTransactionExists indicates that transaction with same id is already + // registered. + ErrTransactionExists = errors.New("transaction exists with same id") +) + +// StopWithError removes transaction from list and calls handler with +// provided error. Can return ErrTransactionNotExists and ErrAgentClosed. +func (a *Agent) StopWithError(id [TransactionIDSize]byte, err error) error { + a.mux.Lock() + if a.closed { + a.mux.Unlock() + + return ErrAgentClosed + } + t, exists := a.transactions[id] + delete(a.transactions, id) + h := a.handler + a.mux.Unlock() + if !exists { + return ErrTransactionNotExists + } + h(Event{ + TransactionID: t.id, + Error: err, + }) + + return nil +} + +// Stop stops transaction by id with ErrTransactionStopped, blocking +// until handler returns. +func (a *Agent) Stop(id [TransactionIDSize]byte) error { + return a.StopWithError(id, ErrTransactionStopped) +} + +// ErrAgentClosed indicates that agent is in closed state and is unable +// to handle transactions. +var ErrAgentClosed = errors.New("agent is closed") + +// Start registers transaction with provided id and deadline. +// Could return ErrAgentClosed, ErrTransactionExists. +// +// Agent handler is guaranteed to be eventually called. +func (a *Agent) Start(id [TransactionIDSize]byte, deadline time.Time) error { + a.mux.Lock() + defer a.mux.Unlock() + if a.closed { + return ErrAgentClosed + } + _, exists := a.transactions[id] + if exists { + return ErrTransactionExists + } + a.transactions[id] = agentTransaction{ + id: id, + deadline: deadline, + } + + return nil +} + +// agentCollectCap is initial capacity for Agent.Collect slices, +// sufficient to make function zero-alloc in most cases. +const agentCollectCap = 100 + +// ErrTransactionTimeOut indicates that transaction has reached deadline. +var ErrTransactionTimeOut = errors.New("transaction is timed out") + +// Collect terminates all transactions that have deadline before provided +// time, blocking until all handlers will process ErrTransactionTimeOut. +// Will return ErrAgentClosed if agent is already closed. +// +// It is safe to call Collect concurrently but makes no sense. +func (a *Agent) Collect(gcTime time.Time) error { + toRemove := make([]transactionID, 0, agentCollectCap) + a.mux.Lock() + if a.closed { + // Doing nothing if agent is closed. + // All transactions should be already closed + // during Close() call. + a.mux.Unlock() + + return ErrAgentClosed + } + // Adding all transactions with deadline before gcTime + // to toCall and toRemove slices. + // No allocs if there are less than agentCollectCap + // timed out transactions. + for id, t := range a.transactions { + if t.deadline.Before(gcTime) { + toRemove = append(toRemove, id) + } + } + // Un-registering timed out transactions. + for _, id := range toRemove { + delete(a.transactions, id) + } + // Calling handler does not require locked mutex, + // reducing lock time. + h := a.handler + a.mux.Unlock() + // Sending ErrTransactionTimeOut to handler for all transactions, + // blocking until last one. + event := Event{ + Error: ErrTransactionTimeOut, + } + for _, id := range toRemove { + event.TransactionID = id + h(event) + } + + return nil +} + +// Process incoming message, synchronously passing it to handler. +func (a *Agent) Process(m *Message) error { + event := Event{ + TransactionID: m.TransactionID, + Message: m, + } + a.mux.Lock() + if a.closed { + a.mux.Unlock() + + return ErrAgentClosed + } + h := a.handler + delete(a.transactions, m.TransactionID) + a.mux.Unlock() + h(event) + + return nil +} + +// SetHandler sets agent handler to h. +func (a *Agent) SetHandler(h Handler) error { + a.mux.Lock() + if a.closed { + a.mux.Unlock() + + return ErrAgentClosed + } + a.handler = h + a.mux.Unlock() + + return nil +} + +// Close terminates all transactions with ErrAgentClosed and renders Agent to +// closed state. +func (a *Agent) Close() error { + e := Event{ + Error: ErrAgentClosed, + } + a.mux.Lock() + if a.closed { + a.mux.Unlock() + + return ErrAgentClosed + } + for _, t := range a.transactions { + e.TransactionID = t.id + a.handler(e) + } + a.transactions = nil + a.closed = true + a.handler = nil + a.mux.Unlock() + + return nil +} + +type transactionID [TransactionIDSize]byte diff --git a/vendor/github.com/pion/stun/v3/attributes.go b/vendor/github.com/pion/stun/v3/attributes.go new file mode 100644 index 0000000..139a071 --- /dev/null +++ b/vendor/github.com/pion/stun/v3/attributes.go @@ -0,0 +1,277 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package stun + +import ( + "errors" + "fmt" +) + +// Attributes is list of message attributes. +type Attributes []RawAttribute + +// Get returns first attribute from list by the type. +// If attribute is present the RawAttribute is returned and the +// boolean is true. Otherwise the returned RawAttribute will be +// empty and boolean will be false. +func (a Attributes) Get(t AttrType) (RawAttribute, bool) { + for _, candidate := range a { + if candidate.Type == t { + return candidate, true + } + } + + return RawAttribute{}, false +} + +// AttrType is attribute type. +type AttrType uint16 + +// Required returns true if type is from comprehension-required range (0x0000-0x7FFF). +func (t AttrType) Required() bool { + return t <= 0x7FFF +} + +// Optional returns true if type is from comprehension-optional range (0x8000-0xFFFF). +func (t AttrType) Optional() bool { + return t >= 0x8000 +} + +// Attributes from comprehension-required range (0x0000-0x7FFF). +const ( + AttrMappedAddress AttrType = 0x0001 // MAPPED-ADDRESS + AttrUsername AttrType = 0x0006 // USERNAME + AttrMessageIntegrity AttrType = 0x0008 // MESSAGE-INTEGRITY + AttrErrorCode AttrType = 0x0009 // ERROR-CODE + AttrUnknownAttributes AttrType = 0x000A // UNKNOWN-ATTRIBUTES + AttrRealm AttrType = 0x0014 // REALM + AttrNonce AttrType = 0x0015 // NONCE + AttrXORMappedAddress AttrType = 0x0020 // XOR-MAPPED-ADDRESS +) + +// Attributes from comprehension-optional range (0x8000-0xFFFF). +const ( + AttrSoftware AttrType = 0x8022 // SOFTWARE + AttrAlternateServer AttrType = 0x8023 // ALTERNATE-SERVER + AttrFingerprint AttrType = 0x8028 // FINGERPRINT +) + +// Attributes from RFC 5245 ICE. +const ( + AttrPriority AttrType = 0x0024 // PRIORITY + AttrUseCandidate AttrType = 0x0025 // USE-CANDIDATE + AttrICEControlled AttrType = 0x8029 // ICE-CONTROLLED + AttrICEControlling AttrType = 0x802A // ICE-CONTROLLING +) + +// Attributes from RFC 5766 TURN. +const ( + AttrChannelNumber AttrType = 0x000C // CHANNEL-NUMBER + AttrLifetime AttrType = 0x000D // LIFETIME + AttrXORPeerAddress AttrType = 0x0012 // XOR-PEER-ADDRESS + AttrData AttrType = 0x0013 // DATA + AttrXORRelayedAddress AttrType = 0x0016 // XOR-RELAYED-ADDRESS + AttrEvenPort AttrType = 0x0018 // EVEN-PORT + AttrRequestedTransport AttrType = 0x0019 // REQUESTED-TRANSPORT + AttrDontFragment AttrType = 0x001A // DONT-FRAGMENT + AttrReservationToken AttrType = 0x0022 // RESERVATION-TOKEN +) + +// Attributes from RFC 5780 NAT Behavior Discovery. +const ( + AttrChangeRequest AttrType = 0x0003 // CHANGE-REQUEST + AttrPadding AttrType = 0x0026 // PADDING + AttrResponsePort AttrType = 0x0027 // RESPONSE-PORT + AttrCacheTimeout AttrType = 0x8027 // CACHE-TIMEOUT + AttrResponseOrigin AttrType = 0x802b // RESPONSE-ORIGIN + AttrOtherAddress AttrType = 0x802C // OTHER-ADDRESS +) + +// Attributes from RFC 3489, removed by RFC 5389, +// +// but still used by RFC5389-implementing software like Vovida.org, reTURNServer, etc. +const ( + AttrSourceAddress AttrType = 0x0004 // SOURCE-ADDRESS + AttrChangedAddress AttrType = 0x0005 // CHANGED-ADDRESS +) + +// Attributes from RFC 6062 TURN Extensions for TCP Allocations. +const ( + AttrConnectionID AttrType = 0x002a // CONNECTION-ID +) + +// Attributes from RFC 6156 TURN IPv6. +const ( + AttrRequestedAddressFamily AttrType = 0x0017 // REQUESTED-ADDRESS-FAMILY +) + +// Attributes from An Origin Attribute for the STUN Protocol. +const ( + AttrOrigin AttrType = 0x802F +) + +// Attributes from RFC 8489 STUN. +const ( + AttrMessageIntegritySHA256 AttrType = 0x001C // MESSAGE-INTEGRITY-SHA256 + AttrPasswordAlgorithm AttrType = 0x001D // PASSWORD-ALGORITHM + AttrUserhash AttrType = 0x001E // USERHASH + AttrPasswordAlgorithms AttrType = 0x8002 // PASSWORD-ALGORITHMS + AttrAlternateDomain AttrType = 0x8003 // ALTERNATE-DOMAIN +) + +// Attributes from SPED. +const ( + AttrDtlsInStun AttrType = 0xC070 + AttrDtlsInStunAck AttrType = 0xC071 +) + +// Value returns uint16 representation of attribute type. +func (t AttrType) Value() uint16 { + return uint16(t) +} + +func attrNames() map[AttrType]string { + return map[AttrType]string{ + AttrMappedAddress: "MAPPED-ADDRESS", + AttrUsername: "USERNAME", + AttrErrorCode: "ERROR-CODE", + AttrMessageIntegrity: "MESSAGE-INTEGRITY", + AttrUnknownAttributes: "UNKNOWN-ATTRIBUTES", + AttrRealm: "REALM", + AttrNonce: "NONCE", + AttrXORMappedAddress: "XOR-MAPPED-ADDRESS", + AttrSoftware: "SOFTWARE", + AttrAlternateServer: "ALTERNATE-SERVER", + AttrFingerprint: "FINGERPRINT", + AttrPriority: "PRIORITY", + AttrUseCandidate: "USE-CANDIDATE", + AttrICEControlled: "ICE-CONTROLLED", + AttrICEControlling: "ICE-CONTROLLING", + AttrChannelNumber: "CHANNEL-NUMBER", + AttrLifetime: "LIFETIME", + AttrXORPeerAddress: "XOR-PEER-ADDRESS", + AttrData: "DATA", + AttrXORRelayedAddress: "XOR-RELAYED-ADDRESS", + AttrEvenPort: "EVEN-PORT", + AttrRequestedTransport: "REQUESTED-TRANSPORT", + AttrDontFragment: "DONT-FRAGMENT", + AttrReservationToken: "RESERVATION-TOKEN", + AttrConnectionID: "CONNECTION-ID", + AttrRequestedAddressFamily: "REQUESTED-ADDRESS-FAMILY", + AttrMessageIntegritySHA256: "MESSAGE-INTEGRITY-SHA256", + AttrPasswordAlgorithm: "PASSWORD-ALGORITHM", + AttrUserhash: "USERHASH", + AttrPasswordAlgorithms: "PASSWORD-ALGORITHMS", + AttrAlternateDomain: "ALTERNATE-DOMAIN", + AttrDtlsInStun: "DTLS-IN-STUN", + AttrDtlsInStunAck: "DTLS-IN-STUN-ACKNOWLEDGEMENT", + } +} + +func (t AttrType) String() string { + s, ok := attrNames()[t] + if !ok { + // Just return hex representation of unknown attribute type. + return fmt.Sprintf("0x%x", uint16(t)) + } + + return s +} + +// Known returns true if AttrType is known and implemented +// by this library. +func (t AttrType) Known() bool { + _, valid := attrNames()[t] + + return valid +} + +// RawAttribute is a Type-Length-Value (TLV) object that +// can be added to a STUN message. Attributes are divided into two +// types: comprehension-required and comprehension-optional. STUN +// agents can safely ignore comprehension-optional attributes they +// don't understand, but cannot successfully process a message if it +// contains comprehension-required attributes that are not +// understood. +type RawAttribute struct { + Type AttrType + Length uint16 // ignored while encoding + Value []byte +} + +// AddTo implements Setter, adding attribute as a.Type with a.Value and ignoring +// the Length field. +func (a RawAttribute) AddTo(m *Message) error { + m.Add(a.Type, a.Value) + + return nil +} + +// Equal returns true if a == b. +func (a RawAttribute) Equal(b RawAttribute) bool { + if a.Type != b.Type { + return false + } + if a.Length != b.Length { + return false + } + if len(b.Value) != len(a.Value) { + return false + } + for i, v := range a.Value { + if b.Value[i] != v { + return false + } + } + + return true +} + +func (a RawAttribute) String() string { + return fmt.Sprintf("%s: 0x%x", a.Type, a.Value) +} + +// ErrAttributeNotFound means that attribute with provided attribute +// type does not exist in message. +var ErrAttributeNotFound = errors.New("attribute not found") + +// Get returns byte slice that represents attribute value, +// if there is no attribute with such type, +// ErrAttributeNotFound is returned. +func (m *Message) Get(t AttrType) ([]byte, error) { + v, ok := m.Attributes.Get(t) + if !ok { + return nil, ErrAttributeNotFound + } + + return v.Value, nil +} + +// STUN aligns attributes on 32-bit boundaries, attributes whose content +// is not a multiple of 4 bytes are padded with 1, 2, or 3 bytes of +// padding so that its value contains a multiple of 4 bytes. The +// padding bits are ignored, and may be any value. +// +// https://tools.ietf.org/html/rfc5389#section-15 +const padding = 4 + +func nearestPaddedValueLength(l int) int { + n := padding * (l / padding) + if n < l { + n += padding + } + + return n +} + +// This method converts an uint16 value to AttrType. If it finds an old attribute +// type value, it also translates it to the new value to enable backward +// compatibility. (See: https://github.com/pion/stun/issues/21) +func compatAttrType(val uint16) AttrType { + if val == 0x8020 { // draft-ietf-behave-rfc3489bis-02, MS-TURN + return AttrXORMappedAddress // new: 0x0020 (from draft-ietf-behave-rfc3489bis-03 on) + } + + return AttrType(val) +} diff --git a/vendor/github.com/pion/stun/v3/attributes_debug.go b/vendor/github.com/pion/stun/v3/attributes_debug.go new file mode 100644 index 0000000..836d79f --- /dev/null +++ b/vendor/github.com/pion/stun/v3/attributes_debug.go @@ -0,0 +1,37 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +//go:build debug +// +build debug + +package stun + +import "fmt" + +// AttrOverflowErr occurs when len(v) > Max. +type AttrOverflowErr struct { + Type AttrType + Max int + Got int +} + +func (e AttrOverflowErr) Error() string { + return fmt.Sprintf("incorrect length of %s attribute: %d exceeds maximum %d", + e.Type, e.Got, e.Max, + ) +} + +// AttrLengthErr means that length for attribute is invalid. +type AttrLengthErr struct { + Attr AttrType + Got int + Expected int +} + +func (e AttrLengthErr) Error() string { + return fmt.Sprintf("incorrect length of %s attribute: got %d, expected %d", + e.Attr, + e.Got, + e.Expected, + ) +} diff --git a/vendor/github.com/pion/stun/v3/checks.go b/vendor/github.com/pion/stun/v3/checks.go new file mode 100644 index 0000000..04716a4 --- /dev/null +++ b/vendor/github.com/pion/stun/v3/checks.go @@ -0,0 +1,57 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +//go:build !debug +// +build !debug + +package stun + +import ( + "errors" + + "github.com/pion/stun/v3/internal/hmac" +) + +// CheckSize returns ErrAttrSizeInvalid if got is not equal to expected. +func CheckSize(_ AttrType, got, expected int) error { + if got == expected { + return nil + } + + return ErrAttributeSizeInvalid +} + +func checkHMAC(got, expected []byte) error { + if hmac.Equal(got, expected) { + return nil + } + + return ErrIntegrityMismatch +} + +func checkFingerprint(got, expected uint32) error { + if got == expected { + return nil + } + + return ErrFingerprintMismatch +} + +// IsAttrSizeInvalid returns true if error means that attribute size is invalid. +func IsAttrSizeInvalid(err error) bool { + return errors.Is(err, ErrAttributeSizeInvalid) +} + +// CheckOverflow returns ErrAttributeSizeOverflow if got is bigger that max. +func CheckOverflow(_ AttrType, got, maxVal int) error { + if got <= maxVal { + return nil + } + + return ErrAttributeSizeOverflow +} + +// IsAttrSizeOverflow returns true if error means that attribute size is too big. +func IsAttrSizeOverflow(err error) bool { + return errors.Is(err, ErrAttributeSizeOverflow) +} diff --git a/vendor/github.com/pion/stun/v3/checks_debug.go b/vendor/github.com/pion/stun/v3/checks_debug.go new file mode 100644 index 0000000..fe236c7 --- /dev/null +++ b/vendor/github.com/pion/stun/v3/checks_debug.go @@ -0,0 +1,65 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +//go:build debug +// +build debug + +package stun + +import "github.com/pion/stun/v3/internal/hmac" + +// CheckSize returns *AttrLengthError if got is not equal to expected. +func CheckSize(a AttrType, got, expected int) error { + if got == expected { + return nil + } + return &AttrLengthErr{ + Got: got, + Expected: expected, + Attr: a, + } +} + +func checkHMAC(got, expected []byte) error { + if hmac.Equal(got, expected) { + return nil + } + return &IntegrityErr{ + Expected: expected, + Actual: got, + } +} + +func checkFingerprint(got, expected uint32) error { + if got == expected { + return nil + } + return &CRCMismatch{ + Actual: got, + Expected: expected, + } +} + +// IsAttrSizeInvalid returns true if error means that attribute size is invalid. +func IsAttrSizeInvalid(err error) bool { + _, ok := err.(*AttrLengthErr) + return ok +} + +// CheckOverflow returns *AttrOverflowErr if got is bigger that max. +func CheckOverflow(t AttrType, got, max int) error { + if got <= max { + return nil + } + return &AttrOverflowErr{ + Type: t, + Got: got, + Max: max, + } +} + +// IsAttrSizeOverflow returns true if error means that attribute size is too big. +func IsAttrSizeOverflow(err error) bool { + _, ok := err.(*AttrOverflowErr) + return ok +} diff --git a/vendor/github.com/pion/stun/v3/client.go b/vendor/github.com/pion/stun/v3/client.go new file mode 100644 index 0000000..9a9d0dc --- /dev/null +++ b/vendor/github.com/pion/stun/v3/client.go @@ -0,0 +1,745 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package stun + +import ( + "crypto/tls" + "errors" + "fmt" + "io" + "log" + "net" + "runtime" + "strconv" + "sync" + "sync/atomic" + "time" + + "github.com/pion/dtls/v3" + "github.com/pion/transport/v4" + "github.com/pion/transport/v4/stdnet" +) + +// ErrUnsupportedURI is an error thrown if the user passes an unsupported STUN or TURN URI. +var ErrUnsupportedURI = fmt.Errorf("invalid schema or transport") + +// Dial connects to the address on the named network and then +// initializes Client on that connection, returning error if any. +func Dial(network, address string) (*Client, error) { + conn, err := net.Dial(network, address) //nolint: noctx + if err != nil { + return nil, err + } + + return NewClient(conn) +} + +// DialConfig is used to pass configuration to DialURI(). +type DialConfig struct { + DTLSConfig dtls.Config + TLSConfig tls.Config + + Net transport.Net +} + +// DialURI connect to the STUN/TURN URI and then +// initializes Client on that connection, returning error if any. +func DialURI(uri *URI, cfg *DialConfig) (*Client, error) { //nolint:cyclop + var conn Connection + var err error + + nw := cfg.Net + if nw == nil { + nw, err = stdnet.NewNet() + if err != nil { + return nil, fmt.Errorf("failed to create net: %w", err) + } + } + + addr := net.JoinHostPort(uri.Host, strconv.Itoa(uri.Port)) + + switch { + case uri.Scheme == SchemeTypeSTUN: + if conn, err = nw.Dial("udp", addr); err != nil { + return nil, fmt.Errorf("failed to listen: %w", err) + } + + case uri.Scheme == SchemeTypeTURN: + network := "udp" //nolint:goconst + if uri.Proto == ProtoTypeTCP { + network = "tcp" //nolint:goconst + } + + if conn, err = nw.Dial(network, addr); err != nil { + return nil, fmt.Errorf("failed to dial: %w", err) + } + + case uri.Scheme == SchemeTypeTURNS && uri.Proto == ProtoTypeUDP: + dtlsCfg := cfg.DTLSConfig // Copy + dtlsCfg.ServerName = uri.Host + + udpAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, fmt.Errorf("failed to resolve UDPAddr: %w", err) + } + + udpConn, err := nw.DialUDP("udp", nil, udpAddr) + if err != nil { + return nil, fmt.Errorf("failed to dial: %w", err) + } + + if conn, err = dtls.Client(udpConn, udpConn.RemoteAddr(), &dtlsCfg); err != nil { + return nil, fmt.Errorf("failed to connect to '%s': %w", addr, err) + } + + case (uri.Scheme == SchemeTypeTURNS || uri.Scheme == SchemeTypeSTUNS) && uri.Proto == ProtoTypeTCP: + tlsCfg := cfg.TLSConfig //nolint:govet, copylocks + tlsCfg.ServerName = uri.Host + + tcpConn, err := nw.Dial("tcp", addr) + if err != nil { + return nil, fmt.Errorf("failed to dial: %w", err) + } + + conn = tls.Client(tcpConn, &tlsCfg) + + default: + return nil, ErrUnsupportedURI + } + + return NewClient(conn) +} + +// ErrNoConnection means that ClientOptions.Connection is nil. +var ErrNoConnection = errors.New("no connection provided") + +// ClientOption sets some client option. +type ClientOption func(c *Client) + +// WithHandler sets client handler which is called if Agent emits the Event +// with TransactionID that is not currently registered by Client. +// Useful for handling Data indications from TURN server. +func WithHandler(h Handler) ClientOption { + return func(c *Client) { + c.handler = h + } +} + +// WithRTO sets client RTO as defined in STUN RFC. +func WithRTO(rto time.Duration) ClientOption { + return func(c *Client) { + c.rto = int64(rto) + } +} + +// WithClock sets Clock of client, the source of current time. +// Also clock is passed to default collector if set. +func WithClock(clock Clock) ClientOption { + return func(c *Client) { + c.clock = clock + } +} + +// WithTimeoutRate sets RTO timer minimum resolution. +func WithTimeoutRate(d time.Duration) ClientOption { + return func(c *Client) { + c.rtoRate = d + } +} + +// WithAgent sets client STUN agent. +// +// Defaults to agent implementation in current package, +// see agent.go. +func WithAgent(a ClientAgent) ClientOption { + return func(c *Client) { + c.a = a + } +} + +// WithCollector rests client timeout collector, the implementation +// of ticker which calls function on each tick. +func WithCollector(coll Collector) ClientOption { + return func(c *Client) { + c.collector = coll + } +} + +// WithNoConnClose prevents client from closing underlying connection when +// the Close() method is called. +func WithNoConnClose() ClientOption { + return func(c *Client) { + c.closeConn = false + } +} + +// WithNoRetransmit disables retransmissions and sets RTO to +// defaultMaxAttempts * defaultRTO which will be effectively time out +// if not set. +// +// Useful for TCP connections where transport handles RTO. +func WithNoRetransmit(c *Client) { + c.maxAttempts = 0 + if c.rto == 0 { + c.rto = defaultMaxAttempts * int64(defaultRTO) + } +} + +const ( + defaultTimeoutRate = time.Millisecond * 5 + defaultRTO = time.Millisecond * 300 + defaultMaxAttempts = 7 +) + +// NewClient initializes new Client from provided options, +// starting internal goroutines and using default options fields +// if necessary. Call Close method after using Client to close conn and +// release resources. +// +// The conn will be closed on Close call. Use WithNoConnClose option to +// prevent that. +// +// Note that user should handle the protocol multiplexing, client does not +// provide any API for it, so if you need to read application data, wrap the +// connection with your (de-)multiplexer and pass the wrapper as conn. +func NewClient(conn Connection, options ...ClientOption) (*Client, error) { + client := &Client{ + close: make(chan struct{}), + c: conn, + clock: systemClock(), + rto: int64(defaultRTO), + rtoRate: defaultTimeoutRate, + t: make(map[transactionID]*clientTransaction, 100), + maxAttempts: defaultMaxAttempts, + closeConn: true, + } + for _, o := range options { + o(client) + } + if client.c == nil { + return nil, ErrNoConnection + } + if client.a == nil { + client.a = NewAgent(nil) + } + if err := client.a.SetHandler(client.handleAgentCallback); err != nil { + return nil, err + } + if client.collector == nil { + client.collector = &tickerCollector{ + close: make(chan struct{}), + clock: client.clock, + } + } + if err := client.collector.Start(client.rtoRate, func(t time.Time) { + closedOrPanic(client.a.Collect(t)) + }); err != nil { + return nil, err + } + client.wg.Add(1) + go client.readUntilClosed() + runtime.SetFinalizer(client, clientFinalizer) + + return client, nil +} + +func clientFinalizer(c *Client) { + if c == nil { + return + } + err := c.Close() + if errors.Is(err, ErrClientClosed) { + return + } + if err == nil { + log.Println("client: called finalizer on non-closed client") // nolint + + return + } + log.Println("client: called finalizer on non-closed client:", err) // nolint +} + +// Connection wraps Reader, Writer and Closer interfaces. +type Connection interface { + io.Reader + io.Writer + io.Closer +} + +// ClientAgent is Agent implementation that is used by Client to +// process transactions. +type ClientAgent interface { + Process(*Message) error + Close() error + Start(id [TransactionIDSize]byte, deadline time.Time) error + Stop(id [TransactionIDSize]byte) error + Collect(time.Time) error + SetHandler(h Handler) error +} + +// Client simulates "connection" to STUN server. +type Client struct { + rto int64 // time.Duration + a ClientAgent + c Connection + close chan struct{} + rtoRate time.Duration + maxAttempts int32 + closed bool + closeConn bool // should call c.Close() while closing + wg sync.WaitGroup + clock Clock + handler Handler + collector Collector + t map[transactionID]*clientTransaction + + // mux guards closed and t + mux sync.RWMutex +} + +// clientTransaction represents transaction in progress. +// If transaction is succeed or failed, f will be called +// provided by event. +// Concurrent access is invalid. +type clientTransaction struct { + id transactionID + attempt int32 + calls int32 + h Handler + start time.Time + rto time.Duration + raw []byte +} + +func (t *clientTransaction) handle(e Event) { + if atomic.AddInt32(&t.calls, 1) == 1 { + t.h(e) + } +} + +var clientTransactionPool = &sync.Pool{ //nolint:gochecknoglobals + New: func() any { + return &clientTransaction{ + raw: make([]byte, 1500), + } + }, +} + +func acquireClientTransaction() *clientTransaction { + return clientTransactionPool.Get().(*clientTransaction) //nolint:forcetypeassert +} + +func putClientTransaction(t *clientTransaction) { + t.raw = t.raw[:0] + t.start = time.Time{} + t.attempt = 0 + t.id = transactionID{} + clientTransactionPool.Put(t) +} + +func (t *clientTransaction) nextTimeout(now time.Time) time.Time { + return now.Add(time.Duration(t.attempt+1) * t.rto) +} + +// start registers transaction. +// +// Could return ErrClientClosed, ErrTransactionExists. +func (c *Client) start(t *clientTransaction) error { + c.mux.Lock() + defer c.mux.Unlock() + if c.closed { + return ErrClientClosed + } + _, exists := c.t[t.id] + if exists { + return ErrTransactionExists + } + c.t[t.id] = t + + return nil +} + +// Clock abstracts the source of current time. +type Clock interface { + Now() time.Time +} + +type systemClockService struct{} + +func (systemClockService) Now() time.Time { return time.Now() } + +func systemClock() systemClockService { + return systemClockService{} +} + +// SetRTO sets current RTO value. +func (c *Client) SetRTO(rto time.Duration) { + atomic.StoreInt64(&c.rto, int64(rto)) +} + +// StopErr occurs when Client fails to stop transaction while +// processing error. +// +//nolint:errname +type StopErr struct { + Err error // value returned by Stop() + Cause error // error that caused Stop() call +} + +func (e StopErr) Error() string { + return fmt.Sprintf("error while stopping due to %s: %s", sprintErr(e.Cause), sprintErr(e.Err)) +} + +// CloseErr indicates client close failure. +// +//nolint:errname +type CloseErr struct { + AgentErr error + ConnectionErr error +} + +func sprintErr(err error) string { + if err == nil { + return "" //nolint:goconst + } + + return err.Error() +} + +func (c CloseErr) Error() string { + return fmt.Sprintf("failed to close: %s (connection), %s (agent)", sprintErr(c.ConnectionErr), sprintErr(c.AgentErr)) +} + +func (c *Client) readUntilClosed() { + defer c.wg.Done() + m := new(Message) + m.Raw = make([]byte, 1024) + for { + select { + case <-c.close: + return + default: + } + _, err := m.ReadFrom(c.c) + if err == nil { + if pErr := c.a.Process(m); errors.Is(pErr, ErrAgentClosed) { + return + } + } + } +} + +func closedOrPanic(err error) { + if err == nil || errors.Is(err, ErrAgentClosed) { + return + } + panic(err) //nolint +} + +type tickerCollector struct { + close chan struct{} + wg sync.WaitGroup + clock Clock +} + +// Collector calls function f with constant rate. +// +// The simple Collector is ticker which calls function on each tick. +type Collector interface { + Start(rate time.Duration, f func(now time.Time)) error + Close() error +} + +func (a *tickerCollector) Start(rate time.Duration, f func(now time.Time)) error { + t := time.NewTicker(rate) + a.wg.Add(1) + go func() { + defer a.wg.Done() + for { + select { + case <-a.close: + t.Stop() + + return + case <-t.C: + f(a.clock.Now()) + } + } + }() + + return nil +} + +func (a *tickerCollector) Close() error { + close(a.close) + a.wg.Wait() + + return nil +} + +// ErrClientClosed indicates that client is closed. +var ErrClientClosed = errors.New("client is closed") + +// Close stops internal connection and agent, returning CloseErr on error. +func (c *Client) Close() error { + if err := c.checkInit(); err != nil { + return err + } + c.mux.Lock() + if c.closed { + c.mux.Unlock() + + return ErrClientClosed + } + c.closed = true + c.mux.Unlock() + if closeErr := c.collector.Close(); closeErr != nil { + return closeErr + } + var connErr error + agentErr := c.a.Close() + if c.closeConn { + connErr = c.c.Close() + } + close(c.close) + c.wg.Wait() + if agentErr == nil && connErr == nil { + return nil + } + + return CloseErr{ + AgentErr: agentErr, + ConnectionErr: connErr, + } +} + +// Indicate sends indication m to server. Shorthand to Start call +// with zero deadline and callback. +func (c *Client) Indicate(m *Message) error { + return c.Start(m, nil) +} + +// callbackWaitHandler blocks on wait() call until callback is called. +type callbackWaitHandler struct { + handler Handler + callback func(event Event) + cond *sync.Cond + processed bool +} + +func (s *callbackWaitHandler) HandleEvent(e Event) { + s.cond.L.Lock() + if s.callback == nil { + panic("s.callback is nil") //nolint + } + s.callback(e) + s.processed = true + s.cond.Broadcast() + s.cond.L.Unlock() +} + +func (s *callbackWaitHandler) wait() { + s.cond.L.Lock() + for !s.processed { + s.cond.Wait() + } + s.processed = false + s.callback = nil + s.cond.L.Unlock() +} + +func (s *callbackWaitHandler) setCallback(f func(event Event)) { + if f == nil { + panic("f is nil") //nolint + } + s.cond.L.Lock() + s.callback = f + if s.handler == nil { + s.handler = s.HandleEvent + } + s.cond.L.Unlock() +} + +var callbackWaitHandlerPool = sync.Pool{ //nolint:gochecknoglobals + New: func() any { + return &callbackWaitHandler{ + cond: sync.NewCond(new(sync.Mutex)), + } + }, +} + +// ErrClientNotInitialized means that client connection or agent is nil. +var ErrClientNotInitialized = errors.New("client not initialized") + +func (c *Client) checkInit() error { + if c == nil || c.c == nil || c.a == nil || c.close == nil { + return ErrClientNotInitialized + } + + return nil +} + +// Do is Start wrapper that waits until callback is called. If no callback +// provided, Indicate is called instead. +// +// Do has cpu overhead due to blocking, see BenchmarkClient_Do. +// Use Start method for less overhead. +func (c *Client) Do(m *Message, f func(Event)) error { + if err := c.checkInit(); err != nil { + return err + } + if f == nil { + return c.Indicate(m) + } + h := callbackWaitHandlerPool.Get().(*callbackWaitHandler) //nolint:forcetypeassert + h.setCallback(f) + defer func() { + callbackWaitHandlerPool.Put(h) + }() + if err := c.Start(m, h.handler); err != nil { + return err + } + h.wait() + + return nil +} + +func (c *Client) delete(id transactionID) { + c.mux.Lock() + if c.t != nil { + delete(c.t, id) + } + c.mux.Unlock() +} + +type buffer struct { + buf []byte +} + +var bufferPool = &sync.Pool{ //nolint:gochecknoglobals + New: func() any { + return &buffer{buf: make([]byte, 2048)} + }, +} + +func (c *Client) handleAgentCallback(event Event) { //nolint:cyclop + c.mux.Lock() + if c.closed { + c.mux.Unlock() + + return + } + transaction, found := c.t[event.TransactionID] + if found { + delete(c.t, transaction.id) + } + c.mux.Unlock() + if !found { + if c.handler != nil && !errors.Is(event.Error, ErrTransactionStopped) { + c.handler(event) + } + // Ignoring. + return + } + if atomic.LoadInt32(&c.maxAttempts) <= transaction.attempt || event.Error == nil { + // Transaction completed. + transaction.handle(event) + putClientTransaction(transaction) + + return + } + // Doing re-transmission. + transaction.attempt++ + buff := bufferPool.Get().(*buffer) //nolint:forcetypeassert + buff.buf = buff.buf[:copy(buff.buf[:cap(buff.buf)], transaction.raw)] + defer bufferPool.Put(buff) + var ( + now = c.clock.Now() + timeOut = transaction.nextTimeout(now) + id = transaction.id + ) + // Starting client transaction. + if startErr := c.start(transaction); startErr != nil { + c.delete(id) + event.Error = startErr + transaction.handle(event) + putClientTransaction(transaction) + + return + } + // Starting agent transaction. + if startErr := c.a.Start(id, timeOut); startErr != nil { + c.delete(id) + event.Error = startErr + transaction.handle(event) + putClientTransaction(transaction) + + return + } + // Writing message to connection again. + _, writeErr := c.c.Write(buff.buf) + if writeErr != nil { + c.delete(id) + event.Error = writeErr + // Stopping agent transaction instead of waiting until it's deadline. + // This will call handleAgentCallback with "ErrTransactionStopped" error + // which will be ignored. + if stopErr := c.a.Stop(id); stopErr != nil { + // Failed to stop agent transaction. Wrapping the error in StopError. + event.Error = StopErr{ + Err: stopErr, + Cause: writeErr, + } + } + transaction.handle(event) + putClientTransaction(transaction) + + return + } +} + +// Start starts transaction (if h set) and writes message to server, handler +// is called asynchronously. +func (c *Client) Start(msg *Message, handler Handler) error { + if err := c.checkInit(); err != nil { + return err + } + c.mux.RLock() + closed := c.closed + c.mux.RUnlock() + if closed { + return ErrClientClosed + } + if handler != nil { + // Starting transaction only if h is set. Useful for indications. + t := acquireClientTransaction() + t.id = msg.TransactionID + t.start = c.clock.Now() + t.h = handler + t.rto = time.Duration(atomic.LoadInt64(&c.rto)) + t.attempt = 0 + t.raw = append(t.raw[:0], msg.Raw...) + t.calls = 0 + d := t.nextTimeout(t.start) + if err := c.start(t); err != nil { + return err + } + if err := c.a.Start(msg.TransactionID, d); err != nil { + return err + } + } + _, err := msg.WriteTo(c.c) + if err != nil && handler != nil { + c.delete(msg.TransactionID) + // Stopping transaction instead of waiting until deadline. + if stopErr := c.a.Stop(msg.TransactionID); stopErr != nil { + return StopErr{ + Err: stopErr, + Cause: err, + } + } + } + + return err +} diff --git a/vendor/github.com/pion/stun/v3/codecov.yml b/vendor/github.com/pion/stun/v3/codecov.yml new file mode 100644 index 0000000..263e4d4 --- /dev/null +++ b/vendor/github.com/pion/stun/v3/codecov.yml @@ -0,0 +1,22 @@ +# +# DO NOT EDIT THIS FILE +# +# It is automatically copied from https://github.com/pion/.goassets repository. +# +# SPDX-FileCopyrightText: 2023 The Pion community +# SPDX-License-Identifier: MIT + +coverage: + status: + project: + default: + # Allow decreasing 2% of total coverage to avoid noise. + threshold: 2% + patch: + default: + target: 70% + only_pulls: true + +ignore: + - "examples/*" + - "examples/**/*" diff --git a/vendor/github.com/pion/stun/v3/errorcode.go b/vendor/github.com/pion/stun/v3/errorcode.go new file mode 100644 index 0000000..5cc90a0 --- /dev/null +++ b/vendor/github.com/pion/stun/v3/errorcode.go @@ -0,0 +1,181 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package stun + +import ( + "errors" + "fmt" + "io" +) + +// ErrorCodeAttribute represents ERROR-CODE attribute. +// +// RFC 5389 Section 15.6. +type ErrorCodeAttribute struct { + Code ErrorCode + Reason []byte +} + +func (c ErrorCodeAttribute) String() string { + return fmt.Sprintf("%d: %s", c.Code, c.Reason) +} + +// constants for ERROR-CODE encoding. +const ( + errorCodeReasonStart = 4 + errorCodeClassByte = 2 + errorCodeNumberByte = 3 + errorCodeReasonMaxB = 763 + errorCodeModulo = 100 +) + +// AddTo adds ERROR-CODE to m. +func (c ErrorCodeAttribute) AddTo(msg *Message) error { + value := make([]byte, 0, errorCodeReasonStart+errorCodeReasonMaxB) + if err := CheckOverflow(AttrErrorCode, + len(c.Reason)+errorCodeReasonStart, + errorCodeReasonMaxB+errorCodeReasonStart, + ); err != nil { + return err + } + value = value[:errorCodeReasonStart+len(c.Reason)] + number := byte(c.Code % errorCodeModulo) // error code modulo 100 + class := byte(c.Code / errorCodeModulo) // hundred digit + value[errorCodeClassByte] = class + value[errorCodeNumberByte] = number + copy(value[errorCodeReasonStart:], c.Reason) + msg.Add(AttrErrorCode, value) + + return nil +} + +// GetFrom decodes ERROR-CODE from m. Reason is valid until m.Raw is valid. +func (c *ErrorCodeAttribute) GetFrom(m *Message) error { + value, err := m.Get(AttrErrorCode) + if err != nil { + return err + } + if len(value) < errorCodeReasonStart { + return io.ErrUnexpectedEOF + } + var ( + class = uint16(value[errorCodeClassByte]) + number = uint16(value[errorCodeNumberByte]) + code = int(class*errorCodeModulo + number) + ) + c.Code = ErrorCode(code) + c.Reason = value[errorCodeReasonStart:] + + return nil +} + +// ErrorCode is code for ERROR-CODE attribute. +type ErrorCode int + +// ErrNoDefaultReason means that default reason for provided error code +// is not defined in RFC. +var ErrNoDefaultReason = errors.New("no default reason for ErrorCode") + +// AddTo adds ERROR-CODE with default reason to m. If there +// is no default reason, returns ErrNoDefaultReason. +func (c ErrorCode) AddTo(m *Message) error { + reason := errorReasons[c] + if reason == nil { + return ErrNoDefaultReason + } + a := &ErrorCodeAttribute{ + Code: c, + Reason: reason, + } + + return a.AddTo(m) +} + +// Possible error codes. +const ( + CodeTryAlternate ErrorCode = 300 + CodeBadRequest ErrorCode = 400 + CodeUnauthorized ErrorCode = 401 + CodeUnknownAttribute ErrorCode = 420 + CodeStaleNonce ErrorCode = 438 + CodeRoleConflict ErrorCode = 487 + CodeServerError ErrorCode = 500 +) + +// DEPRECATED constants. +const ( + // DEPRECATED, use CodeUnauthorized. + CodeUnauthorised = CodeUnauthorized +) + +// Error codes from RFC 5766. +// +// RFC 5766 Section 15. +const ( + CodeForbidden ErrorCode = 403 // Forbidden + CodeAllocMismatch ErrorCode = 437 // Allocation Mismatch + CodeWrongCredentials ErrorCode = 441 // Wrong Credentials + CodeUnsupportedTransProto ErrorCode = 442 // Unsupported Transport Protocol + CodeAllocQuotaReached ErrorCode = 486 // Allocation Quota Reached + CodeInsufficientCapacity ErrorCode = 508 // Insufficient Capacity +) + +// Error codes from RFC 6062. +// +// RFC 6062 Section 6.3. +const ( + CodeConnAlreadyExists ErrorCode = 446 + CodeConnTimeoutOrFailure ErrorCode = 447 +) + +// Error codes from RFC 6156. +// +// RFC 6156 Section 10.2. +const ( + CodeAddrFamilyNotSupported ErrorCode = 440 // Address Family not Supported + CodePeerAddrFamilyMismatch ErrorCode = 443 // Peer Address Family Mismatch +) + +//nolint:gochecknoglobals +var errorReasons = map[ErrorCode][]byte{ + CodeTryAlternate: []byte("Try Alternate"), + CodeBadRequest: []byte("Bad Request"), + CodeUnauthorized: []byte("Unauthorized"), + CodeUnknownAttribute: []byte("Unknown Attribute"), + CodeStaleNonce: []byte("Stale Nonce"), + CodeServerError: []byte("Server Error"), + CodeRoleConflict: []byte("Role Conflict"), + + // RFC 5766. + CodeForbidden: []byte("Forbidden"), + CodeAllocMismatch: []byte("Allocation Mismatch"), + CodeWrongCredentials: []byte("Wrong Credentials"), + CodeUnsupportedTransProto: []byte("Unsupported Transport Protocol"), + CodeAllocQuotaReached: []byte("Allocation Quota Reached"), + CodeInsufficientCapacity: []byte("Insufficient Capacity"), + + // RFC 6062. + CodeConnAlreadyExists: []byte("Connection Already Exists"), + CodeConnTimeoutOrFailure: []byte("Connection Timeout or Failure"), + + // RFC 6156. + CodeAddrFamilyNotSupported: []byte("Address Family not Supported"), + CodePeerAddrFamilyMismatch: []byte("Peer Address Family Mismatch"), +} + +// TurnError represents an error from a TURN response. +type TurnError struct { + StunMessageType MessageType + ErrorCodeAttr ErrorCodeAttribute +} + +// Error returns the formatted TURN error message. +func (e TurnError) Error() string { + return fmt.Sprintf("%s (error %s)", e.StunMessageType, e.ErrorCodeAttr.String()) +} + +// String returns the error message as a string. +func (e TurnError) String() string { + return e.Error() +} diff --git a/vendor/github.com/pion/stun/v3/errors.go b/vendor/github.com/pion/stun/v3/errors.go new file mode 100644 index 0000000..d5f59ed --- /dev/null +++ b/vendor/github.com/pion/stun/v3/errors.go @@ -0,0 +1,66 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package stun + +import "errors" + +// DecodeErr records an error and place when it is occurred. +// +//nolint:errname +type DecodeErr struct { + Place DecodeErrPlace + Message string +} + +// IsInvalidCookie returns true if error means that magic cookie +// value is invalid. +func (e DecodeErr) IsInvalidCookie() bool { + return e.Place == DecodeErrPlace{"message", "cookie"} +} + +// IsPlaceParent reports if error place parent is p. +func (e DecodeErr) IsPlaceParent(p string) bool { + return e.Place.Parent == p +} + +// IsPlaceChildren reports if error place children is c. +func (e DecodeErr) IsPlaceChildren(c string) bool { + return e.Place.Children == c +} + +// IsPlace reports if error place is p. +func (e DecodeErr) IsPlace(p DecodeErrPlace) bool { + return e.Place == p +} + +// DecodeErrPlace records a place where error is occurred. +type DecodeErrPlace struct { + Parent string + Children string +} + +func (p DecodeErrPlace) String() string { + return p.Parent + "/" + p.Children +} + +func (e DecodeErr) Error() string { + return "BadFormat for " + e.Place.String() + ": " + e.Message +} + +func newDecodeErr(parent, children, message string) *DecodeErr { + return &DecodeErr{ + Place: DecodeErrPlace{Parent: parent, Children: children}, + Message: message, + } +} + +func newAttrDecodeErr(children, message string) *DecodeErr { + return newDecodeErr("attribute", children, message) +} + +// ErrAttributeSizeInvalid means that decoded attribute size is invalid. +var ErrAttributeSizeInvalid = errors.New("attribute size is invalid") + +// ErrAttributeSizeOverflow means that decoded attribute size is too big. +var ErrAttributeSizeOverflow = errors.New("attribute size overflow") diff --git a/vendor/github.com/pion/stun/v3/fingerprint.go b/vendor/github.com/pion/stun/v3/fingerprint.go new file mode 100644 index 0000000..3c33bdd --- /dev/null +++ b/vendor/github.com/pion/stun/v3/fingerprint.go @@ -0,0 +1,72 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package stun + +import ( + "errors" + "hash/crc32" +) + +// FingerprintAttr represents FINGERPRINT attribute. +// +// RFC 5389 Section 15.5. +type FingerprintAttr struct{} + +// ErrFingerprintMismatch means that computed fingerprint differs from expected. +var ErrFingerprintMismatch = errors.New("fingerprint check failed") + +// Fingerprint is shorthand for FingerprintAttr. +// +// Example: +// +// m := New() +// Fingerprint.AddTo(m) +var Fingerprint FingerprintAttr //nolint:gochecknoglobals + +const ( + fingerprintXORValue uint32 = 0x5354554e //nolint:staticcheck + fingerprintSize = 4 // 32 bit +) + +// FingerprintValue returns CRC-32 of b XOR-ed by 0x5354554e. +// +// The value of the attribute is computed as the CRC-32 of the STUN message +// up to (but excluding) the FINGERPRINT attribute itself, XOR'ed with +// the 32-bit value 0x5354554e (the XOR helps in cases where an +// application packet is also using CRC-32 in it). +func FingerprintValue(b []byte) uint32 { + return crc32.ChecksumIEEE(b) ^ fingerprintXORValue // XOR +} + +// AddTo adds fingerprint to message. +func (FingerprintAttr) AddTo(m *Message) error { + l := m.Length + // length in header should include size of fingerprint attribute + m.Length += fingerprintSize + attributeHeaderSize // increasing length + m.WriteLength() // writing Length to Raw + b := make([]byte, fingerprintSize) + val := FingerprintValue(m.Raw) + bin.PutUint32(b, val) + m.Length = l + m.Add(AttrFingerprint, b) + + return nil +} + +// Check reads fingerprint value from m and checks it, returning error if any. +// Can return *AttrLengthErr, ErrAttributeNotFound, and *CRCMismatch. +func (FingerprintAttr) Check(m *Message) error { + b, err := m.Get(AttrFingerprint) + if err != nil { + return err + } + if err = CheckSize(AttrFingerprint, len(b), fingerprintSize); err != nil { + return err + } + val := bin.Uint32(b) + attrStart := len(m.Raw) - (fingerprintSize + attributeHeaderSize) + expected := FingerprintValue(m.Raw[:attrStart]) + + return checkFingerprint(val, expected) +} diff --git a/vendor/github.com/pion/stun/v3/fingerprint_debug.go b/vendor/github.com/pion/stun/v3/fingerprint_debug.go new file mode 100644 index 0000000..0e3471d --- /dev/null +++ b/vendor/github.com/pion/stun/v3/fingerprint_debug.go @@ -0,0 +1,22 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +//go:build debug +// +build debug + +package stun + +import "fmt" + +// CRCMismatch represents CRC check error. +type CRCMismatch struct { + Expected uint32 + Actual uint32 +} + +func (m CRCMismatch) Error() string { + return fmt.Sprintf("CRC mismatch: %x (expected) != %x (actual)", + m.Expected, + m.Actual, + ) +} diff --git a/vendor/github.com/pion/stun/v3/helpers.go b/vendor/github.com/pion/stun/v3/helpers.go new file mode 100644 index 0000000..75b2f3f --- /dev/null +++ b/vendor/github.com/pion/stun/v3/helpers.go @@ -0,0 +1,115 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package stun + +// Interfaces that are implemented by message attributes, shorthands for them, +// or helpers for message fields as type or transaction id. +type ( + // Setter sets *Message attribute. + Setter interface { + AddTo(m *Message) error + } + // Getter parses attribute from *Message. + Getter interface { + GetFrom(m *Message) error + } + // Checker checks *Message attribute. + Checker interface { + Check(m *Message) error + } +) + +// Build resets message and applies setters to it in batch, returning on +// first error. To prevent allocations, pass pointers to values. +// +// Example: +// +// var ( +// t = BindingRequest +// username = NewUsername("username") +// nonce = NewNonce("nonce") +// realm = NewRealm("example.org") +// ) +// m := new(Message) +// m.Build(t, username, nonce, realm) // 4 allocations +// m.Build(&t, &username, &nonce, &realm) // 0 allocations +// +// See BenchmarkBuildOverhead. +func (m *Message) Build(setters ...Setter) error { + m.Reset() + m.WriteHeader() + for _, s := range setters { + if err := s.AddTo(m); err != nil { + return err + } + } + + return nil +} + +// Check applies checkers to message in batch, returning on first error. +func (m *Message) Check(checkers ...Checker) error { + for _, c := range checkers { + if err := c.Check(m); err != nil { + return err + } + } + + return nil +} + +// Parse applies getters to message in batch, returning on first error. +func (m *Message) Parse(getters ...Getter) error { + for _, c := range getters { + if err := c.GetFrom(m); err != nil { + return err + } + } + + return nil +} + +// MustBuild wraps Build call and panics on error. +func MustBuild(setters ...Setter) *Message { + m, err := Build(setters...) + if err != nil { + panic(err) //nolint + } + + return m +} + +// Build wraps Message.Build method. +func Build(setters ...Setter) (*Message, error) { + m := new(Message) + if err := m.Build(setters...); err != nil { + return nil, err + } + + return m, nil +} + +// ForEach is helper that iterates over message attributes allowing to call +// Getter in f callback to get all attributes of type t and returning on first +// f error. +// +// The m.Get method inside f will be returning next attribute on each f call. +// Does not error if there are no results. +func (m *Message) ForEach(t AttrType, f func(m *Message) error) error { + attrs := m.Attributes + defer func() { + m.Attributes = attrs + }() + for i, a := range attrs { + if a.Type != t { + continue + } + m.Attributes = attrs[i:] + if err := f(m); err != nil { + return err + } + } + + return nil +} diff --git a/vendor/github.com/pion/stun/v3/integrity.go b/vendor/github.com/pion/stun/v3/integrity.go new file mode 100644 index 0000000..7ea5151 --- /dev/null +++ b/vendor/github.com/pion/stun/v3/integrity.go @@ -0,0 +1,130 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package stun + +import ( + "crypto/md5" //nolint:gosec + "crypto/sha1" //nolint:gosec + "errors" + "fmt" + "strings" + + "github.com/pion/stun/v3/internal/hmac" +) + +// separator for credentials. +const credentialsSep = ":" + +// NewLongTermIntegrity returns new MessageIntegrity with key for long-term +// credentials. Password, username, and realm must be SASL-prepared. +func NewLongTermIntegrity(username, realm, password string) MessageIntegrity { + k := strings.Join([]string{username, realm, password}, credentialsSep) + h := md5.New() //nolint:gosec + fmt.Fprint(h, k) //nolint:errcheck + + return MessageIntegrity(h.Sum(nil)) +} + +// NewShortTermIntegrity returns new MessageIntegrity with key for short-term +// credentials. Password must be SASL-prepared. +func NewShortTermIntegrity(password string) MessageIntegrity { + return MessageIntegrity(password) +} + +// MessageIntegrity represents MESSAGE-INTEGRITY attribute. +// +// AddTo and Check methods are using zero-allocation version of hmac, see +// newHMAC function and internal/hmac/pool.go. +// +// RFC 5389 Section 15.4. +type MessageIntegrity []byte + +func newHMAC(key, message, buf []byte) []byte { + mac := hmac.AcquireSHA1(key) + writeOrPanic(mac, message) + defer hmac.PutSHA1(mac) + + return mac.Sum(buf) +} + +func (i MessageIntegrity) String() string { + return fmt.Sprintf("KEY: 0x%x", []byte(i)) +} + +const messageIntegritySize = 20 + +// ErrFingerprintBeforeIntegrity means that FINGERPRINT attribute is already in +// message, so MESSAGE-INTEGRITY attribute cannot be added. +var ErrFingerprintBeforeIntegrity = errors.New("FINGERPRINT before MESSAGE-INTEGRITY attribute") + +// AddTo adds MESSAGE-INTEGRITY attribute to message. +// +// CPU costly, see BenchmarkMessageIntegrity_AddTo. +func (i MessageIntegrity) AddTo(msg *Message) error { + for _, a := range msg.Attributes { + // Message should not contain FINGERPRINT attribute + // before MESSAGE-INTEGRITY. + if a.Type == AttrFingerprint { + return ErrFingerprintBeforeIntegrity + } + } + // The text used as input to HMAC is the STUN message, + // including the header, up to and including the attribute preceding the + // MESSAGE-INTEGRITY attribute. + length := msg.Length + // Adjusting m.Length to contain MESSAGE-INTEGRITY TLV. + msg.Length += messageIntegritySize + attributeHeaderSize + msg.WriteLength() // writing length to m.Raw + v := newHMAC(i, msg.Raw, msg.Raw[len(msg.Raw):]) // calculating HMAC for adjusted m.Raw + msg.Length = length // changing m.Length back + + // Copy hmac value to temporary variable to protect it from resetting + // while processing m.Add call. + vBuf := make([]byte, sha1.Size) + copy(vBuf, v) + + msg.Add(AttrMessageIntegrity, vBuf) + + return nil +} + +// ErrIntegrityMismatch means that computed HMAC differs from expected. +var ErrIntegrityMismatch = errors.New("integrity check failed") + +// Check checks MESSAGE-INTEGRITY attribute. +// +// CPU costly, see BenchmarkMessageIntegrity_Check. +func (i MessageIntegrity) Check(msg *Message) error { + val, err := msg.Get(AttrMessageIntegrity) + if err != nil { + return err + } + + // Adjusting length in header to match m.Raw that was + // used when computing HMAC. + var ( + length = msg.Length + afterIntegrity = false + sizeReduced int + ) + for _, a := range msg.Attributes { + if afterIntegrity { + sizeReduced += nearestPaddedValueLength(int(a.Length)) + sizeReduced += attributeHeaderSize + } + if a.Type == AttrMessageIntegrity { + afterIntegrity = true + } + } + msg.Length -= uint32(sizeReduced) //nolint:gosec // G115 + msg.WriteLength() + // startOfHMAC should be first byte of integrity attribute. + startOfHMAC := messageHeaderSize + msg.Length - (attributeHeaderSize + messageIntegritySize) + b := msg.Raw[:startOfHMAC] // data before integrity attribute + expected := newHMAC(i, b, msg.Raw[len(msg.Raw):]) + msg.Length = length + msg.WriteLength() // writing length back + + return checkHMAC(val, expected) +} diff --git a/vendor/github.com/pion/stun/v3/integrity_debug.go b/vendor/github.com/pion/stun/v3/integrity_debug.go new file mode 100644 index 0000000..27fd0e2 --- /dev/null +++ b/vendor/github.com/pion/stun/v3/integrity_debug.go @@ -0,0 +1,22 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +//go:build debug +// +build debug + +package stun + +import "fmt" + +// IntegrityErr occurs when computed HMAC differs from expected. +type IntegrityErr struct { + Expected []byte + Actual []byte +} + +func (i *IntegrityErr) Error() string { + return fmt.Sprintf( + "Integrity check failed: 0x%x (expected) !- 0x%x (actual)", + i.Expected, i.Actual, + ) +} diff --git a/vendor/github.com/pion/stun/v3/internal/hmac/hmac.go b/vendor/github.com/pion/stun/v3/internal/hmac/hmac.go new file mode 100644 index 0000000..6caddbe --- /dev/null +++ b/vendor/github.com/pion/stun/v3/internal/hmac/hmac.go @@ -0,0 +1,158 @@ +// SPDX-FileCopyrightText: 2009 The Go Authors. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +/* +Package hmac implements the Keyed-Hash Message Authentication Code (HMAC) as +defined in U.S. Federal Information Processing Standards Publication 198. +An HMAC is a cryptographic hash that uses a key to sign a message. +The receiver verifies the hash by recomputing it using the same key. + +Receivers should be careful to use Equal to compare MACs in order to avoid +timing side-channels: + + // ValidMAC reports whether messageMAC is a valid HMAC tag for message. + func ValidMAC(message, messageMAC, key []byte) bool { + mac := hmac.New(sha256.New, key) + mac.Write(message) + expectedMAC := mac.Sum(nil) + return hmac.Equal(messageMAC, expectedMAC) + } +*/ +package hmac + +import ( + "crypto/subtle" + "hash" +) + +// FIPS 198-1: +// https://csrc.nist.gov/publications/fips/fips198-1/FIPS-198-1_final.pdf + +// key is zero padded to the block size of the hash function +// ipad = 0x36 byte repeated for key length +// opad = 0x5c byte repeated for key length +// hmac = H([key ^ opad] H([key ^ ipad] text)) + +// Marshalable is the combination of encoding.BinaryMarshaler and +// encoding.BinaryUnmarshaler. Their method definitions are repeated here to +// avoid a dependency on the encoding package. +type marshalable interface { + MarshalBinary() ([]byte, error) + UnmarshalBinary([]byte) error +} + +type hmac struct { + opad, ipad []byte + outer, inner hash.Hash + + // If marshaled is true, then opad and ipad do not contain a padded + // copy of the key, but rather the marshaled state of outer/inner after + // opad/ipad has been fed into it. + marshaled bool +} + +func (h *hmac) Sum(in []byte) []byte { + origLen := len(in) + in = h.inner.Sum(in) + + if h.marshaled { + if err := h.outer.(marshalable).UnmarshalBinary(h.opad); err != nil { //nolint:forcetypeassert + panic(err) //nolint + } + } else { + h.outer.Reset() + h.outer.Write(h.opad) //nolint:errcheck,gosec + } + h.outer.Write(in[origLen:]) //nolint:errcheck,gosec + + return h.outer.Sum(in[:origLen]) +} + +func (h *hmac) Write(p []byte) (n int, err error) { + return h.inner.Write(p) +} + +func (h *hmac) Size() int { return h.outer.Size() } +func (h *hmac) BlockSize() int { return h.inner.BlockSize() } + +func (h *hmac) Reset() { + if h.marshaled { + if err := h.inner.(marshalable).UnmarshalBinary(h.ipad); err != nil { //nolint:forcetypeassert + panic(err) //nolint + } + + return + } + + h.inner.Reset() + h.inner.Write(h.ipad) //nolint:errcheck,gosec + + // If the underlying hash is marshalable, we can save some time by + // saving a copy of the hash state now, and restoring it on future + // calls to Reset and Sum instead of writing ipad/opad every time. + // + // If either hash is unmarshalable for whatever reason, + // it's safe to bail out here. + marshalableInner, innerOK := h.inner.(marshalable) + if !innerOK { + return + } + marshalableOuter, outerOK := h.outer.(marshalable) + if !outerOK { + return + } + + imarshal, err := marshalableInner.MarshalBinary() + if err != nil { + return + } + + h.outer.Reset() + h.outer.Write(h.opad) //nolint:errcheck,gosec + omarshal, err := marshalableOuter.MarshalBinary() + if err != nil { + return + } + + // Marshaling succeeded; save the marshaled state for later + h.ipad = imarshal + h.opad = omarshal + h.marshaled = true +} + +// New returns a new HMAC hash using the given hash.Hash type and key. +// Note that unlike other hash implementations in the standard library, +// the returned Hash does not implement encoding.BinaryMarshaler +// or encoding.BinaryUnmarshaler. +func New(h func() hash.Hash, key []byte) hash.Hash { + hm := new(hmac) + hm.outer = h() + hm.inner = h() + blocksize := hm.inner.BlockSize() + hm.ipad = make([]byte, blocksize) + hm.opad = make([]byte, blocksize) + if len(key) > blocksize { + // If key is too big, hash it. + hm.outer.Write(key) //nolint:errcheck,gosec + key = hm.outer.Sum(nil) + } + copy(hm.ipad, key) + copy(hm.opad, key) + for i := range hm.ipad { + hm.ipad[i] ^= 0x36 + } + for i := range hm.opad { + hm.opad[i] ^= 0x5c + } + hm.inner.Write(hm.ipad) //nolint:errcheck,gosec + + return hm +} + +// Equal compares two MACs for equality without leaking timing information. +func Equal(mac1, mac2 []byte) bool { + // We don't have to be constant time if the lengths of the MACs are + // different as that suggests that a completely different hash function + // was used. + return subtle.ConstantTimeCompare(mac1, mac2) == 1 +} diff --git a/vendor/github.com/pion/stun/v3/internal/hmac/pool.go b/vendor/github.com/pion/stun/v3/internal/hmac/pool.go new file mode 100644 index 0000000..2694d74 --- /dev/null +++ b/vendor/github.com/pion/stun/v3/internal/hmac/pool.go @@ -0,0 +1,96 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package hmac + +import ( + "crypto/sha1" //nolint:gosec + "crypto/sha256" + "hash" + "sync" +) + +func (h *hmac) resetTo(key []byte) { + h.outer.Reset() + h.inner.Reset() + blocksize := h.inner.BlockSize() + + // Reset size and zero of ipad and opad. + h.ipad = append(h.ipad[:0], make([]byte, blocksize)...) + h.opad = append(h.opad[:0], make([]byte, blocksize)...) + + if len(key) > blocksize { + // If key is too big, hash it. + h.outer.Write(key) //nolint:errcheck,gosec + key = h.outer.Sum(nil) + } + copy(h.ipad, key) + copy(h.opad, key) + for i := range h.ipad { + h.ipad[i] ^= 0x36 + } + for i := range h.opad { + h.opad[i] ^= 0x5c + } + h.inner.Write(h.ipad) //nolint:errcheck,gosec + + h.marshaled = false +} + +var hmacSHA1Pool = &sync.Pool{ //nolint:gochecknoglobals + New: func() any { + h := New(sha1.New, make([]byte, sha1.BlockSize)) + + return h + }, +} + +// AcquireSHA1 returns new HMAC from pool. +func AcquireSHA1(key []byte) hash.Hash { + h := hmacSHA1Pool.Get().(*hmac) //nolint:forcetypeassert + assertHMACSize(h, sha1.Size, sha1.BlockSize) + h.resetTo(key) + + return h +} + +// PutSHA1 puts h to pool. +func PutSHA1(h hash.Hash) { + hm := h.(*hmac) //nolint:forcetypeassert + assertHMACSize(hm, sha1.Size, sha1.BlockSize) + hmacSHA1Pool.Put(hm) +} + +var hmacSHA256Pool = &sync.Pool{ //nolint:gochecknoglobals + New: func() any { + h := New(sha256.New, make([]byte, sha256.BlockSize)) + + return h + }, +} + +// AcquireSHA256 returns new HMAC from SHA256 pool. +func AcquireSHA256(key []byte) hash.Hash { + h := hmacSHA256Pool.Get().(*hmac) //nolint:forcetypeassert + assertHMACSize(h, sha256.Size, sha256.BlockSize) + h.resetTo(key) + + return h +} + +// PutSHA256 puts h to SHA256 pool. +func PutSHA256(h hash.Hash) { + hm := h.(*hmac) //nolint:forcetypeassert + assertHMACSize(hm, sha256.Size, sha256.BlockSize) + hmacSHA256Pool.Put(hm) +} + +// assertHMACSize panics if h.size != size or h.blocksize != blocksize. +// +// Put and Acquire functions are internal functions to project, so +// checking it via such assert is optimal. +func assertHMACSize(h *hmac, size, blocksize int) { //nolint:unparam + if h.Size() != size || h.BlockSize() != blocksize { + panic("BUG: hmac size invalid") //nolint + } +} diff --git a/vendor/github.com/pion/stun/v3/internal/hmac/vendor.sh b/vendor/github.com/pion/stun/v3/internal/hmac/vendor.sh new file mode 100644 index 0000000..190d2b9 --- /dev/null +++ b/vendor/github.com/pion/stun/v3/internal/hmac/vendor.sh @@ -0,0 +1,7 @@ +#!/bin/env bash + +# SPDX-FileCopyrightText: 2023 The Pion community +# SPDX-License-Identifier: MIT + +cp -v $GOROOT/src/crypto/hmac/{hmac,hmac_test}.go . +git diff {hmac,hmac_test}.go diff --git a/vendor/github.com/pion/stun/v3/message.go b/vendor/github.com/pion/stun/v3/message.go new file mode 100644 index 0000000..ed15e5e --- /dev/null +++ b/vendor/github.com/pion/stun/v3/message.go @@ -0,0 +1,651 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package stun + +import ( + "crypto/rand" + "encoding/base64" + "errors" + "fmt" + "io" +) + +const ( + // magicCookie is fixed value that aids in distinguishing STUN packets + // from packets of other protocols when STUN is multiplexed with those + // other protocols on the same Port. + // + // The magic cookie field MUST contain the fixed value 0x2112A442 in + // network byte order. + // + // Defined in "STUN Message Structure", section 6. + magicCookie = 0x2112A442 + attributeHeaderSize = 4 + messageHeaderSize = 20 + + // TransactionIDSize is length of transaction id array (in bytes). + TransactionIDSize = 12 // 96 bit +) + +// NewTransactionID returns new random transaction ID using crypto/rand +// as source. +func NewTransactionID() (b [TransactionIDSize]byte) { + readFullOrPanic(rand.Reader, b[:]) + + return b +} + +// IsMessage returns true if b looks like STUN message. +// Useful for multiplexing. IsMessage does not guarantee +// that decoding will be successful. +func IsMessage(b []byte) bool { + return len(b) >= messageHeaderSize && bin.Uint32(b[4:8]) == magicCookie +} + +// New returns *Message with pre-allocated Raw. +func New() *Message { + const defaultRawCapacity = 120 + + return &Message{ + Raw: make([]byte, messageHeaderSize, defaultRawCapacity), + } +} + +// ErrDecodeToNil occurs on Decode(data, nil) call. +var ErrDecodeToNil = errors.New("attempt to decode to nil message") + +// Decode decodes Message from data to m, returning error if any. +func Decode(data []byte, m *Message) error { + if m == nil { + return ErrDecodeToNil + } + m.Raw = append(m.Raw[:0], data...) + + return m.Decode() +} + +// Message represents a single STUN packet. It uses aggressive internal +// buffering to enable zero-allocation encoding and decoding, +// so there are some usage constraints: +// +// Message, its fields, results of m.Get or any attribute a.GetFrom +// are valid only until Message.Raw is not modified. +type Message struct { + Type MessageType + Length uint32 // len(Raw) not including header + TransactionID [TransactionIDSize]byte + Attributes Attributes + Raw []byte +} + +// MarshalBinary implements the encoding.BinaryMarshaler interface. +func (m Message) MarshalBinary() (data []byte, err error) { + // We can't return m.Raw, allocation is expected by implicit interface + // contract induced by other implementations. + b := make([]byte, len(m.Raw)) + copy(b, m.Raw) + + return b, nil +} + +// UnmarshalBinary implements the encoding.BinaryUnmarshaler interface. +func (m *Message) UnmarshalBinary(data []byte) error { + // We can't retain data, copy is expected by interface contract. + m.Raw = append(m.Raw[:0], data...) + + return m.Decode() +} + +// GobEncode implements the gob.GobEncoder interface. +func (m Message) GobEncode() ([]byte, error) { + return m.MarshalBinary() +} + +// GobDecode implements the gob.GobDecoder interface. +func (m *Message) GobDecode(data []byte) error { + return m.UnmarshalBinary(data) +} + +// AddTo sets b.TransactionID to m.TransactionID. +// +// Implements Setter to aid in crafting responses. +func (m *Message) AddTo(b *Message) error { + b.TransactionID = m.TransactionID + b.WriteTransactionID() + + return nil +} + +// NewTransactionID sets m.TransactionID to random value from crypto/rand +// and returns error if any. +func (m *Message) NewTransactionID() error { + _, err := io.ReadFull(rand.Reader, m.TransactionID[:]) + if err == nil { + m.WriteTransactionID() + } + + return err +} + +func (m *Message) String() string { + tID := base64.StdEncoding.EncodeToString(m.TransactionID[:]) + aInfo := "" + for k, a := range m.Attributes { + aInfo += fmt.Sprintf("attr%d=%s ", k, a.Type) + } + + return fmt.Sprintf("%s l=%d attrs=%d id=%s, %s", m.Type, m.Length, len(m.Attributes), tID, aInfo) +} + +// Reset resets Message, attributes and underlying buffer length. +func (m *Message) Reset() { + m.Raw = m.Raw[:0] + m.Length = 0 + m.Attributes = m.Attributes[:0] +} + +// grow ensures that internal buffer has n length. +func (m *Message) grow(n int) { + if len(m.Raw) >= n { + return + } + if cap(m.Raw) >= n { + m.Raw = m.Raw[:n] + + return + } + m.Raw = append(m.Raw, make([]byte, n-len(m.Raw))...) +} + +// Add appends new attribute to message. Not goroutine-safe. +// +// Value of attribute is copied to internal buffer so +// it is safe to reuse v. +func (m *Message) Add(attrType AttrType, val []byte) { + // Allocating buffer for TLV (type-length-value). + // T = t, L = len(v), V = v. + // m.Raw will look like: + // [0:20] <- message header + // [20:20+m.Length] <- existing message attributes + // [20+m.Length:20+m.Length+len(v) + 4] <- allocated buffer for new TLV + // [first:last] <- same as previous + // [0 1|2 3|4 4 + len(v)] <- mapping for allocated buffer + // T L V + allocSize := attributeHeaderSize + len(val) // ~ len(TLV) = len(TL) + len(V) + first := messageHeaderSize + int(m.Length) // first byte number + last := first + allocSize // last byte number + m.grow(last) // growing cap(Raw) to fit TLV + m.Raw = m.Raw[:last] // now len(Raw) = last + //nolint:gosec // G115 + m.Length += uint32(allocSize) // rendering length change + + // Sub-slicing internal buffer to simplify encoding. + buf := m.Raw[first:last] // slice for TLV + value := buf[attributeHeaderSize:] // slice for V + attr := RawAttribute{ + Type: attrType, // T + //nolint:gosec // G115 + Length: uint16(len(val)), // L + Value: value, // V + } + + // Encoding attribute TLV to allocated buffer. + bin.PutUint16(buf[0:2], attr.Type.Value()) // T + bin.PutUint16(buf[2:4], attr.Length) // L + copy(value, val) // V + + // Checking that attribute value needs padding. + if attr.Length%padding != 0 { + // Performing padding. + bytesToAdd := nearestPaddedValueLength(len(val)) - len(val) + last += bytesToAdd + m.grow(last) + // setting all padding bytes to zero + // to prevent data leak from previous + // data in next bytesToAdd bytes + buf = m.Raw[last-bytesToAdd : last] + for i := range buf { + buf[i] = 0 + } + m.Raw = m.Raw[:last] // increasing buffer length + //nolint:gosec // G115 + m.Length += uint32(bytesToAdd) // rendering length change + } + m.Attributes = append(m.Attributes, attr) + m.WriteLength() +} + +func attrSliceEqual(a, b Attributes) bool { + for _, attr := range a { + found := false + for _, attrB := range b { + if attrB.Type != attr.Type { + continue + } + if attrB.Equal(attr) { + found = true + + break + } + } + if !found { + return false + } + } + + return true +} + +func attrEqual(attrA, attrB Attributes) bool { + if attrA == nil && attrB == nil { + return true + } + if attrA == nil || attrB == nil { + return false + } + if len(attrA) != len(attrB) { + return false + } + if !attrSliceEqual(attrA, attrB) { + return false + } + if !attrSliceEqual(attrB, attrA) { + return false + } + + return true +} + +// Equal returns true if Message msg equals to m. +// Ignores m.Raw. +func (m *Message) Equal(msg *Message) bool { + if m == nil && msg == nil { + return true + } + if m == nil || msg == nil { + return false + } + if m.Type != msg.Type { + return false + } + if m.TransactionID != msg.TransactionID { + return false + } + if m.Length != msg.Length { + return false + } + if !attrEqual(m.Attributes, msg.Attributes) { + return false + } + + return true +} + +// WriteLength writes m.Length to m.Raw. +func (m *Message) WriteLength() { + m.grow(4) + bin.PutUint16(m.Raw[2:4], uint16(m.Length)) //nolint:gosec // G115 +} + +// WriteHeader writes header to underlying buffer. Not goroutine-safe. +func (m *Message) WriteHeader() { + m.grow(messageHeaderSize) + _ = m.Raw[:messageHeaderSize] // early bounds check to guarantee safety of writes below + + m.WriteType() + m.WriteLength() + bin.PutUint32(m.Raw[4:8], magicCookie) // magic cookie + copy(m.Raw[8:messageHeaderSize], m.TransactionID[:]) // transaction ID +} + +// WriteTransactionID writes m.TransactionID to m.Raw. +func (m *Message) WriteTransactionID() { + copy(m.Raw[8:messageHeaderSize], m.TransactionID[:]) // transaction ID +} + +// WriteAttributes encodes all m.Attributes to m. +func (m *Message) WriteAttributes() { + attributes := m.Attributes + m.Attributes = attributes[:0] + for _, a := range attributes { + m.Add(a.Type, a.Value) + } + m.Attributes = attributes +} + +// WriteType writes m.Type to m.Raw. +func (m *Message) WriteType() { + m.grow(2) + bin.PutUint16(m.Raw[0:2], m.Type.Value()) // message type +} + +// SetType sets m.Type and writes it to m.Raw. +func (m *Message) SetType(t MessageType) { + m.Type = t + m.WriteType() +} + +// Encode re-encodes message into m.Raw. +func (m *Message) Encode() { + m.Raw = m.Raw[:0] + m.WriteHeader() + m.Length = 0 + m.WriteAttributes() +} + +// WriteTo implements WriterTo via calling Write(m.Raw) on w and returning +// call result. +func (m *Message) WriteTo(w io.Writer) (int64, error) { + n, err := w.Write(m.Raw) + + return int64(n), err +} + +// ReadFrom implements ReaderFrom. Reads message from r into m.Raw, +// Decodes it and return error if any. If m.Raw is too small, will return +// ErrUnexpectedEOF, ErrUnexpectedHeaderEOF or *DecodeErr. +// +// Can return *DecodeErr while decoding too. +func (m *Message) ReadFrom(r io.Reader) (int64, error) { + tBuf := m.Raw[:cap(m.Raw)] + var ( + n int + err error + ) + if n, err = r.Read(tBuf); err != nil { + return int64(n), err + } + m.Raw = tBuf[:n] + + return int64(n), m.Decode() +} + +// ErrUnexpectedHeaderEOF means that there were not enough bytes in +// m.Raw to read header. +var ErrUnexpectedHeaderEOF = errors.New("unexpected EOF: not enough bytes to read header") + +// Decode decodes m.Raw into m. +func (m *Message) Decode() error { + // decoding message header + buf := m.Raw + if len(buf) < messageHeaderSize { + return ErrUnexpectedHeaderEOF + } + var ( + msgType = bin.Uint16(buf[0:2]) // first 2 bytes + size = int(bin.Uint16(buf[2:4])) // second 2 bytes + cookie = bin.Uint32(buf[4:8]) // last 4 bytes + fullSize = messageHeaderSize + size // len(m.Raw) + ) + if cookie != magicCookie { + msg := fmt.Sprintf("%x is invalid magic cookie (should be %x)", cookie, magicCookie) + + return newDecodeErr("message", "cookie", msg) + } + if len(buf) < fullSize { + msg := fmt.Sprintf("buffer length %d is less than %d (expected message size)", len(buf), fullSize) + + return newAttrDecodeErr("message", msg) + } + // saving header data + m.Type.ReadValue(msgType) + m.Length = uint32(size) //nolint:gosec // G115 + copy(m.TransactionID[:], buf[8:messageHeaderSize]) + + m.Attributes = m.Attributes[:0] + var ( + offset = 0 + b = buf[messageHeaderSize:fullSize] + ) + for offset < size { + // checking that we have enough bytes to read header + if len(b) < attributeHeaderSize { + msg := fmt.Sprintf("buffer length %d is less than %d (expected header size)", len(b), attributeHeaderSize) + + return newAttrDecodeErr("header", msg) + } + var ( + attr = RawAttribute{ + Type: compatAttrType(bin.Uint16(b[0:2])), // first 2 bytes + Length: bin.Uint16(b[2:4]), // second 2 bytes + } + aL = int(attr.Length) // attribute length + aBuffL = nearestPaddedValueLength(aL) // expected buffer length (with padding) + ) + b = b[attributeHeaderSize:] // slicing again to simplify value read + offset += attributeHeaderSize + if len(b) < aBuffL { // checking size + msg := fmt.Sprintf("buffer length %d is less than %d (expected value size for %s)", len(b), aBuffL, attr.Type) + + return newAttrDecodeErr("value", msg) + } + attr.Value = b[:aL] + offset += aBuffL + b = b[aBuffL:] + + m.Attributes = append(m.Attributes, attr) + } + + return nil +} + +// Write decodes message and return error if any. +// +// Any error is unrecoverable, but message could be partially decoded. +func (m *Message) Write(tBuf []byte) (int, error) { + m.Raw = append(m.Raw[:0], tBuf...) + + return len(tBuf), m.Decode() +} + +// CloneTo clones m to b securing any further m mutations. +func (m *Message) CloneTo(b *Message) error { + b.Raw = append(b.Raw[:0], m.Raw...) + + return b.Decode() +} + +// MessageClass is 8-bit representation of 2-bit class of STUN Message Class. +type MessageClass byte + +// Possible values for message class in STUN Message Type. +const ( + ClassRequest MessageClass = 0x00 // 0b00 + ClassIndication MessageClass = 0x01 // 0b01 + ClassSuccessResponse MessageClass = 0x02 // 0b10 + ClassErrorResponse MessageClass = 0x03 // 0b11 +) + +// Common STUN message types. +var ( + // Binding request message type. + BindingRequest = NewType(MethodBinding, ClassRequest) //nolint:gochecknoglobals + // Binding success response message type. + BindingSuccess = NewType(MethodBinding, ClassSuccessResponse) //nolint:gochecknoglobals + // Binding error response message type. + BindingError = NewType(MethodBinding, ClassErrorResponse) //nolint:gochecknoglobals +) + +func (c MessageClass) String() string { + switch c { + case ClassRequest: + return "request" + case ClassIndication: + return "indication" + case ClassSuccessResponse: + return "success response" + case ClassErrorResponse: + return "error response" + default: + panic("unknown message class") //nolint + } +} + +// Method is uint16 representation of 12-bit STUN method. +type Method uint16 + +// Possible methods for STUN Message. +const ( + MethodBinding Method = 0x001 + MethodAllocate Method = 0x003 + MethodRefresh Method = 0x004 + MethodSend Method = 0x006 + MethodData Method = 0x007 + MethodCreatePermission Method = 0x008 + MethodChannelBind Method = 0x009 +) + +// Methods from RFC 6062. +const ( + MethodConnect Method = 0x000a + MethodConnectionBind Method = 0x000b + MethodConnectionAttempt Method = 0x000c +) + +func methodName() map[Method]string { + return map[Method]string{ + MethodBinding: "Binding", + MethodAllocate: "Allocate", + MethodRefresh: "Refresh", + MethodSend: "Send", + MethodData: "Data", + MethodCreatePermission: "CreatePermission", + MethodChannelBind: "ChannelBind", + + // RFC 6062. + MethodConnect: "Connect", + MethodConnectionBind: "ConnectionBind", + MethodConnectionAttempt: "ConnectionAttempt", + } +} + +func (m Method) String() string { + s, ok := methodName()[m] + if !ok { + // Falling back to hex representation. + s = fmt.Sprintf("0x%x", uint16(m)) + } + + return s +} + +// MessageType is STUN Message Type Field. +type MessageType struct { + Method Method // e.g. binding + Class MessageClass // e.g. request +} + +// AddTo sets m type to t. +func (t MessageType) AddTo(m *Message) error { + m.SetType(t) + + return nil +} + +// NewType returns new message type with provided method and class. +func NewType(method Method, class MessageClass) MessageType { + return MessageType{ + Method: method, + Class: class, + } +} + +const ( + methodABits = 0xf // 0b0000000000001111 + methodBBits = 0x70 // 0b0000000001110000 + methodDBits = 0xf80 // 0b0000111110000000 + + methodBShift = 1 + methodDShift = 2 + + firstBit = 0x1 + secondBit = 0x2 + + c0Bit = firstBit + c1Bit = secondBit + + classC0Shift = 4 + classC1Shift = 7 +) + +// Value returns bit representation of messageType. +func (t MessageType) Value() uint16 { + // 0 1 + // 2 3 4 5 6 7 8 9 0 1 2 3 4 5 + // +--+--+-+-+-+-+-+-+-+-+-+-+-+-+ + // |M |M |M|M|M|C|M|M|M|C|M|M|M|M| + // |11|10|9|8|7|1|6|5|4|0|3|2|1|0| + // +--+--+-+-+-+-+-+-+-+-+-+-+-+-+ + // Figure 3: Format of STUN Message Type Field + + // Warning: Abandon all hope ye who enter here. + // Splitting M into A(M0-M3), B(M4-M6), D(M7-M11). + msg := uint16(t.Method) + a := msg & methodABits // A = M * 0b0000000000001111 (right 4 bits) + b := msg & methodBBits // B = M * 0b0000000001110000 (3 bits after A) + d := msg & methodDBits // D = M * 0b0000111110000000 (5 bits after B) + + // Shifting to add "holes" for C0 (at 4 bit) and C1 (8 bit). + msg = a + (b << methodBShift) + (d << methodDShift) + + // C0 is zero bit of C, C1 is first bit. + // C0 = C * 0b01, C1 = (C * 0b10) >> 1 + // Ct = C0 << 4 + C1 << 8. + // Optimizations: "((C * 0b10) >> 1) << 8" as "(C * 0b10) << 7" + // We need C0 shifted by 4, and C1 by 8 to fit "11" and "7" positions + // (see figure 3). + c := uint16(t.Class) + c0 := (c & c0Bit) << classC0Shift + c1 := (c & c1Bit) << classC1Shift + class := c0 + c1 + + return msg + class +} + +// ReadValue decodes uint16 into MessageType. +func (t *MessageType) ReadValue(v uint16) { + // Decoding class. + // We are taking first bit from v >> 4 and second from v >> 7. + c0 := (v >> classC0Shift) & c0Bit + c1 := (v >> classC1Shift) & c1Bit + class := c0 + c1 + t.Class = MessageClass(class) + + // Decoding method. + a := v & methodABits // A(M0-M3) + b := (v >> methodBShift) & methodBBits // B(M4-M6) + d := (v >> methodDShift) & methodDBits // D(M7-M11) + m := a + b + d + t.Method = Method(m) +} + +func (t MessageType) String() string { + return fmt.Sprintf("%s %s", t.Method, t.Class) +} + +// Contains return true if message contain t attribute. +func (m *Message) Contains(t AttrType) bool { + for _, a := range m.Attributes { + if a.Type == t { + return true + } + } + + return false +} + +type transactionIDValueSetter [TransactionIDSize]byte + +// NewTransactionIDSetter returns new Setter that sets message transaction id +// to provided value. +func NewTransactionIDSetter(value [TransactionIDSize]byte) Setter { + return transactionIDValueSetter(value) +} + +func (t transactionIDValueSetter) AddTo(m *Message) error { + m.TransactionID = t + m.WriteTransactionID() + + return nil +} diff --git a/vendor/github.com/pion/stun/v3/renovate.json b/vendor/github.com/pion/stun/v3/renovate.json new file mode 100644 index 0000000..f1bb98c --- /dev/null +++ b/vendor/github.com/pion/stun/v3/renovate.json @@ -0,0 +1,6 @@ +{ + "$schema": "https://docs.renovatebot.com/renovate-schema.json", + "extends": [ + "github>pion/renovate-config" + ] +} diff --git a/vendor/github.com/pion/stun/v3/stun.go b/vendor/github.com/pion/stun/v3/stun.go new file mode 100644 index 0000000..26c2a53 --- /dev/null +++ b/vendor/github.com/pion/stun/v3/stun.go @@ -0,0 +1,56 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +// Package stun implements Session Traversal Utilities for NAT (STUN) RFC 5389. +// +// The stun package is intended to use by package that implements extension +// to STUN (e.g. TURN) or client/server applications. +// +// Most methods are designed to be zero allocations. If it is not enough, +// low-level methods are available. On other hand, there are helpers that +// reduce code repeat. +// +// See examples for Message for basic usage, or https://github.com/pion/turn +// package for example of stun extension implementation. +package stun + +import ( + "encoding/binary" + "io" +) + +// bin is shorthand to binary.BigEndian. +var bin = binary.BigEndian //nolint:gochecknoglobals + +func readFullOrPanic(r io.Reader, v []byte) int { + n, err := io.ReadFull(r, v) + if err != nil { + panic(err) //nolint + } + + return n +} + +func writeOrPanic(w io.Writer, v []byte) int { + n, err := w.Write(v) + if err != nil { + panic(err) //nolint + } + + return n +} + +// IANA assigned ports for "stun" protocol. +const ( + DefaultPort = 3478 + DefaultTLSPort = 5349 +) + +type transactionIDSetter struct{} + +func (transactionIDSetter) AddTo(m *Message) error { + return m.NewTransactionID() +} + +// TransactionID is Setter for m.TransactionID. +var TransactionID Setter = transactionIDSetter{} //nolint:gochecknoglobals diff --git a/vendor/github.com/pion/stun/v3/textattrs.go b/vendor/github.com/pion/stun/v3/textattrs.go new file mode 100644 index 0000000..b9626f0 --- /dev/null +++ b/vendor/github.com/pion/stun/v3/textattrs.go @@ -0,0 +1,134 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package stun + +// NewUsername returns Username with provided value. +func NewUsername(username string) Username { + return Username(username) +} + +// Username represents USERNAME attribute. +// +// RFC 5389 Section 15.3. +type Username []byte + +func (u Username) String() string { + return string(u) +} + +const maxUsernameB = 513 + +// AddTo adds USERNAME attribute to message. +func (u Username) AddTo(m *Message) error { + return TextAttribute(u).AddToAs(m, AttrUsername, maxUsernameB) +} + +// GetFrom gets USERNAME from message. +func (u *Username) GetFrom(m *Message) error { + return (*TextAttribute)(u).GetFromAs(m, AttrUsername) +} + +// NewRealm returns Realm with provided value. +// Must be SASL-prepared. +func NewRealm(realm string) Realm { + return Realm(realm) +} + +// Realm represents REALM attribute. +// +// RFC 5389 Section 15.7. +type Realm []byte + +func (n Realm) String() string { + return string(n) +} + +const maxRealmB = 763 + +// AddTo adds NONCE to message. +func (n Realm) AddTo(m *Message) error { + return TextAttribute(n).AddToAs(m, AttrRealm, maxRealmB) +} + +// GetFrom gets REALM from message. +func (n *Realm) GetFrom(m *Message) error { + return (*TextAttribute)(n).GetFromAs(m, AttrRealm) +} + +const softwareRawMaxB = 763 + +// Software is SOFTWARE attribute. +// +// RFC 5389 Section 15.10. +type Software []byte + +func (s Software) String() string { + return string(s) +} + +// NewSoftware returns *Software from string. +func NewSoftware(software string) Software { + return Software(software) +} + +// AddTo adds Software attribute to m. +func (s Software) AddTo(m *Message) error { + return TextAttribute(s).AddToAs(m, AttrSoftware, softwareRawMaxB) +} + +// GetFrom decodes Software from m. +func (s *Software) GetFrom(m *Message) error { + return (*TextAttribute)(s).GetFromAs(m, AttrSoftware) +} + +// Nonce represents NONCE attribute. +// +// RFC 5389 Section 15.8. +type Nonce []byte + +// NewNonce returns new Nonce from string. +func NewNonce(nonce string) Nonce { + return Nonce(nonce) +} + +func (n Nonce) String() string { + return string(n) +} + +const maxNonceB = 763 + +// AddTo adds NONCE to message. +func (n Nonce) AddTo(m *Message) error { + return TextAttribute(n).AddToAs(m, AttrNonce, maxNonceB) +} + +// GetFrom gets NONCE from message. +func (n *Nonce) GetFrom(m *Message) error { + return (*TextAttribute)(n).GetFromAs(m, AttrNonce) +} + +// TextAttribute is helper for adding and getting text attributes. +type TextAttribute []byte + +// AddToAs adds attribute with type t to m, checking maximum length. If maxLen +// is less than 0, no check is performed. +func (v TextAttribute) AddToAs(m *Message, t AttrType, maxLen int) error { + if err := CheckOverflow(t, len(v), maxLen); err != nil { + return err + } + m.Add(t, v) + + return nil +} + +// GetFromAs gets t attribute from m and appends its value to reseted v. +func (v *TextAttribute) GetFromAs(m *Message, t AttrType) error { + a, err := m.Get(t) + if err != nil { + return err + } + *v = a + + return nil +} diff --git a/vendor/github.com/pion/stun/v3/uattrs.go b/vendor/github.com/pion/stun/v3/uattrs.go new file mode 100644 index 0000000..238cd84 --- /dev/null +++ b/vendor/github.com/pion/stun/v3/uattrs.go @@ -0,0 +1,69 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package stun + +import "errors" + +// UnknownAttributes represents UNKNOWN-ATTRIBUTES attribute. +// +// RFC 5389 Section 15.9. +type UnknownAttributes []AttrType + +func (a UnknownAttributes) String() string { + s := "" + if len(a) == 0 { + return "" + } + last := len(a) - 1 + for i, t := range a { + s += t.String() + if i != last { + s += ", " + } + } + + return s +} + +// type size is 16 bit. +const attrTypeSize = 4 + +// AddTo adds UNKNOWN-ATTRIBUTES attribute to message. +func (a UnknownAttributes) AddTo(m *Message) error { + v := make([]byte, 0, attrTypeSize*20) // 20 should be enough + // If len(a.Types) > 20, there will be allocations. + for i, t := range a { + v = append(v, 0, 0, 0, 0) // 4 times by 0 (16 bits) + first := attrTypeSize * i + last := first + attrTypeSize + bin.PutUint16(v[first:last], t.Value()) + } + m.Add(AttrUnknownAttributes, v) + + return nil +} + +// ErrBadUnknownAttrsSize means that UNKNOWN-ATTRIBUTES attribute value +// has invalid length. +var ErrBadUnknownAttrsSize = errors.New("bad UNKNOWN-ATTRIBUTES size") + +// GetFrom parses UNKNOWN-ATTRIBUTES from message. +func (a *UnknownAttributes) GetFrom(m *Message) error { + v, err := m.Get(AttrUnknownAttributes) + if err != nil { + return err + } + if len(v)%attrTypeSize != 0 { + return ErrBadUnknownAttrsSize + } + *a = (*a)[:0] + first := 0 + for first < len(v) { + last := first + attrTypeSize + *a = append(*a, AttrType(bin.Uint16(v[first:last]))) + first = last + } + + return nil +} diff --git a/vendor/github.com/pion/stun/v3/uri.go b/vendor/github.com/pion/stun/v3/uri.go new file mode 100644 index 0000000..b9f7cbf --- /dev/null +++ b/vendor/github.com/pion/stun/v3/uri.go @@ -0,0 +1,264 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package stun + +import ( + "errors" + "net" + "net/url" + "strconv" +) + +var ( + // ErrUnknownType indicates an error with Unknown info. + ErrUnknownType = errors.New("Unknown") + + // ErrSchemeType indicates the scheme type could not be parsed. + ErrSchemeType = errors.New("unknown scheme type") + + // ErrSTUNQuery indicates query arguments are provided in a STUN URL. + ErrSTUNQuery = errors.New("queries not supported in stun address") + + // ErrInvalidQuery indicates an malformed query is provided. + ErrInvalidQuery = errors.New("invalid query") + + // ErrHost indicates malformed hostname is provided. + ErrHost = errors.New("invalid hostname") + + // ErrPort indicates malformed port is provided. + ErrPort = errors.New("invalid port") + + // ErrProtoType indicates an unsupported transport type was provided. + ErrProtoType = errors.New("invalid transport protocol type") +) + +// SchemeType indicates the type of server used in the ice.URL structure. +type SchemeType int + +const ( + // SchemeTypeUnknown indicates an unknown or unsupported scheme. + SchemeTypeUnknown SchemeType = iota + + // SchemeTypeSTUN indicates the URL represents a STUN server. + SchemeTypeSTUN + + // SchemeTypeSTUNS indicates the URL represents a STUNS (secure) server. + SchemeTypeSTUNS + + // SchemeTypeTURN indicates the URL represents a TURN server. + SchemeTypeTURN + + // SchemeTypeTURNS indicates the URL represents a TURNS (secure) server. + SchemeTypeTURNS +) + +// NewSchemeType defines a procedure for creating a new SchemeType from a raw +// string naming the scheme type. +func NewSchemeType(raw string) SchemeType { + switch raw { + case "stun": + return SchemeTypeSTUN + case "stuns": + return SchemeTypeSTUNS + case "turn": + return SchemeTypeTURN + case "turns": + return SchemeTypeTURNS + default: + return SchemeTypeUnknown + } +} + +func (t SchemeType) String() string { + switch t { + case SchemeTypeSTUN: + return "stun" + case SchemeTypeSTUNS: + return "stuns" + case SchemeTypeTURN: + return "turn" + case SchemeTypeTURNS: + return "turns" + default: + return ErrUnknownType.Error() + } +} + +// ProtoType indicates the transport protocol type that is used in the ice.URL +// structure. +type ProtoType int + +const ( + // ProtoTypeUnknown indicates an unknown or unsupported protocol. + ProtoTypeUnknown ProtoType = iota + + // ProtoTypeUDP indicates the URL uses a UDP transport. + ProtoTypeUDP + + // ProtoTypeTCP indicates the URL uses a TCP transport. + ProtoTypeTCP +) + +// NewProtoType defines a procedure for creating a new ProtoType from a raw +// string naming the transport protocol type. +func NewProtoType(raw string) ProtoType { + switch raw { + case "udp": //nolint:goconst + return ProtoTypeUDP + case "tcp": //nolint:goconst + return ProtoTypeTCP + default: + return ProtoTypeUnknown + } +} + +func (t ProtoType) String() string { + switch t { + case ProtoTypeUDP: + return "udp" + case ProtoTypeTCP: + return "tcp" + default: + return ErrUnknownType.Error() + } +} + +// URI represents a STUN (rfc7064) or TURN (rfc7065) URI. +type URI struct { + Scheme SchemeType + Host string + Port int + Username string + Password string + Proto ProtoType +} + +// ParseURI parses a STUN or TURN urls following the ABNF syntax described in +// https://tools.ietf.org/html/rfc7064 and https://tools.ietf.org/html/rfc7065 +// respectively. +func ParseURI(raw string) (*URI, error) { //nolint:gocognit,cyclop + rawParts, err := url.Parse(raw) + if err != nil { + return nil, err + } + + var uri URI + uri.Scheme = NewSchemeType(rawParts.Scheme) + if uri.Scheme == SchemeTypeUnknown { + return nil, ErrSchemeType + } + + var rawPort string + if uri.Host, rawPort, err = net.SplitHostPort(rawParts.Opaque); err != nil { //nolint:nestif + var e *net.AddrError + if errors.As(err, &e) { + if e.Err == "missing port in address" { + nextRawURL := uri.Scheme.String() + ":" + rawParts.Opaque + switch uri.Scheme { + case SchemeTypeSTUN, SchemeTypeTURN: + nextRawURL += ":3478" + if rawParts.RawQuery != "" { + nextRawURL += "?" + rawParts.RawQuery + } + + return ParseURI(nextRawURL) + case SchemeTypeSTUNS, SchemeTypeTURNS: + nextRawURL += ":5349" + if rawParts.RawQuery != "" { + nextRawURL += "?" + rawParts.RawQuery + } + + return ParseURI(nextRawURL) + default: + return nil, ErrSchemeType + } + } + } + + return nil, err + } + + if uri.Host == "" { + return nil, ErrHost + } + + if uri.Port, err = strconv.Atoi(rawPort); err != nil { + return nil, ErrPort + } + + switch uri.Scheme { + case SchemeTypeSTUN: + qArgs, err := url.ParseQuery(rawParts.RawQuery) + if err != nil || len(qArgs) > 0 { + return nil, ErrSTUNQuery + } + uri.Proto = ProtoTypeUDP + case SchemeTypeSTUNS: + qArgs, err := url.ParseQuery(rawParts.RawQuery) + if err != nil || len(qArgs) > 0 { + return nil, ErrSTUNQuery + } + uri.Proto = ProtoTypeTCP + case SchemeTypeTURN: + proto, err := parseProto(rawParts.RawQuery) + if err != nil { + return nil, err + } + + uri.Proto = proto + if uri.Proto == ProtoTypeUnknown { + uri.Proto = ProtoTypeUDP + } + case SchemeTypeTURNS: + proto, err := parseProto(rawParts.RawQuery) + if err != nil { + return nil, err + } + + uri.Proto = proto + if uri.Proto == ProtoTypeUnknown { + uri.Proto = ProtoTypeTCP + } + + case SchemeTypeUnknown: + } + + return &uri, nil +} + +func parseProto(raw string) (ProtoType, error) { + qArgs, err := url.ParseQuery(raw) + if err != nil || len(qArgs) > 1 { + return ProtoTypeUnknown, ErrInvalidQuery + } + + var proto ProtoType + if rawProto := qArgs.Get("transport"); rawProto != "" { + if proto = NewProtoType(rawProto); proto == ProtoTypeUnknown { + return ProtoTypeUnknown, ErrProtoType + } + + return proto, nil + } + + if len(qArgs) > 0 { + return ProtoTypeUnknown, ErrInvalidQuery + } + + return proto, nil +} + +func (u URI) String() string { + rawURL := u.Scheme.String() + ":" + net.JoinHostPort(u.Host, strconv.Itoa(u.Port)) + if u.Scheme == SchemeTypeTURN || u.Scheme == SchemeTypeTURNS { + rawURL += "?transport=" + u.Proto.String() + } + + return rawURL +} + +// IsSecure returns whether the this URL's scheme describes secure scheme or not. +func (u URI) IsSecure() bool { + return u.Scheme == SchemeTypeSTUNS || u.Scheme == SchemeTypeTURNS +} diff --git a/vendor/github.com/pion/stun/v3/xoraddr.go b/vendor/github.com/pion/stun/v3/xoraddr.go new file mode 100644 index 0000000..5c1a193 --- /dev/null +++ b/vendor/github.com/pion/stun/v3/xoraddr.go @@ -0,0 +1,152 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package stun + +import ( + "errors" + "fmt" + "io" + "net" + "strconv" + + "github.com/pion/transport/v4/utils/xor" +) + +const ( + familyIPv4 uint16 = 0x01 + familyIPv6 uint16 = 0x02 +) + +// XORMappedAddress implements XOR-MAPPED-ADDRESS attribute. +// +// RFC 5389 Section 15.2. +type XORMappedAddress struct { + IP net.IP + Port int +} + +func (a XORMappedAddress) String() string { + return net.JoinHostPort(a.IP.String(), strconv.Itoa(a.Port)) +} + +// isIPv4 returns true if ip with len of net.IPv6Len seems to be ipv4. +func isIPv4(ip net.IP) bool { + // Optimized for performance. Copied from net.IP.To4. + return isZeros(ip[0:10]) && ip[10] == 0xff && ip[11] == 0xff +} + +// Is p all zeros? +func isZeros(p net.IP) bool { + for i := 0; i < len(p); i++ { + if p[i] != 0 { + return false + } + } + + return true +} + +// ErrBadIPLength means that len(IP) is not net.{IPv6len,IPv4len}. +var ErrBadIPLength = errors.New("invalid length of IP value") + +// AddToAs adds XOR-MAPPED-ADDRESS value to msg as attr attribute. +func (a XORMappedAddress) AddToAs(msg *Message, attr AttrType) error { + var ( + family = familyIPv4 + ip = a.IP + ) + if len(a.IP) == net.IPv6len { + if isIPv4(ip) { + ip = ip[12:16] // like in ip.To4() + } else { + family = familyIPv6 + } + } else if len(ip) != net.IPv4len { + return ErrBadIPLength + } + value := make([]byte, 32+128) + value[0] = 0 // first 8 bits are zeroes + xorValue := make([]byte, net.IPv6len) + copy(xorValue[4:], msg.TransactionID[:]) + bin.PutUint32(xorValue[0:4], magicCookie) + bin.PutUint16(value[0:2], family) + bin.PutUint16(value[2:4], uint16(a.Port^magicCookie>>16)) //nolint:gosec // G115, false positive, port + xor.XorBytes(value[4:4+len(ip)], ip, xorValue) + msg.Add(attr, value[:4+len(ip)]) + + return nil +} + +// AddTo adds XOR-MAPPED-ADDRESS to m. Can return ErrBadIPLength +// if len(a.IP) is invalid. +func (a XORMappedAddress) AddTo(m *Message) error { + return a.AddToAs(m, AttrXORMappedAddress) +} + +// GetFromAs decodes XOR-MAPPED-ADDRESS attribute value in message +// getting it as for attr type. +func (a *XORMappedAddress) GetFromAs(msg *Message, attr AttrType) error { + value, err := msg.Get(attr) + if err != nil { + return err + } + family := bin.Uint16(value[0:2]) + if family != familyIPv6 && family != familyIPv4 { + return newDecodeErr("xor-mapped address", "family", + fmt.Sprintf("bad value %d", family), + ) + } + ipLen := net.IPv4len + if family == familyIPv6 { + ipLen = net.IPv6len + } + // Ensuring len(a.IP) == ipLen and reusing a.IP. + if len(a.IP) < ipLen { + a.IP = make(net.IP, ipLen) + } else { + a.IP = a.IP[:ipLen] + for i := range a.IP { + a.IP[i] = 0 + } + } + + if len(value) <= 4 { + return io.ErrUnexpectedEOF + } + if err := CheckOverflow(attr, len(value[4:]), len(a.IP)); err != nil { + return err + } + a.Port = int(bin.Uint16(value[2:4])) ^ (magicCookie >> 16) + xorValue := make([]byte, 4+TransactionIDSize) + bin.PutUint32(xorValue[0:4], magicCookie) + copy(xorValue[4:], msg.TransactionID[:]) + xor.XorBytes(a.IP, value[4:], xorValue) + + return nil +} + +// GetFrom decodes XOR-MAPPED-ADDRESS attribute in message and returns +// error if any. While decoding, a.IP is reused if possible and can be +// rendered to invalid state (e.g. if a.IP was set to IPv6 and then +// IPv4 value were decoded into it), be careful. +// +// Example: +// +// expectedIP := net.ParseIP("213.141.156.236") +// expectedIP.String() // 213.141.156.236, 16 bytes, first 12 of them are zeroes +// expectedPort := 21254 +// addr := &XORMappedAddress{ +// IP: expectedIP, +// Port: expectedPort, +// } +// // addr were added to message that is decoded as newMessage +// // ... +// +// addr.GetFrom(newMessage) +// addr.IP.String() // 213.141.156.236, net.IPv4Len +// expectedIP.String() // d58d:9cec::ffff:d58d:9cec, 16 bytes, first 4 are IPv4 +// // now we have len(expectedIP) = 16 and len(addr.IP) = 4. +func (a *XORMappedAddress) GetFrom(m *Message) error { + return a.GetFromAs(m, AttrXORMappedAddress) +} diff --git a/vendor/github.com/pion/transport/v4/.gitignore b/vendor/github.com/pion/transport/v4/.gitignore new file mode 100644 index 0000000..6e2f206 --- /dev/null +++ b/vendor/github.com/pion/transport/v4/.gitignore @@ -0,0 +1,28 @@ +# SPDX-FileCopyrightText: 2023 The Pion community +# SPDX-License-Identifier: MIT + +### JetBrains IDE ### +##################### +.idea/ + +### Emacs Temporary Files ### +############################# +*~ + +### Folders ### +############### +bin/ +vendor/ +node_modules/ + +### Files ### +############# +*.ivf +*.ogg +tags +cover.out +*.sw[poe] +*.wasm +examples/sfu-ws/cert.pem +examples/sfu-ws/key.pem +wasm_exec.js diff --git a/vendor/github.com/pion/transport/v4/.golangci.yml b/vendor/github.com/pion/transport/v4/.golangci.yml new file mode 100644 index 0000000..4b4025f --- /dev/null +++ b/vendor/github.com/pion/transport/v4/.golangci.yml @@ -0,0 +1,147 @@ +# SPDX-FileCopyrightText: 2023 The Pion community +# SPDX-License-Identifier: MIT + +version: "2" +linters: + enable: + - asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers + - bidichk # Checks for dangerous unicode character sequences + - bodyclose # checks whether HTTP response body is closed successfully + - containedctx # containedctx is a linter that detects struct contained context.Context field + - contextcheck # check the function whether use a non-inherited context + - cyclop # checks function and package cyclomatic complexity + - decorder # check declaration order and count of types, constants, variables and functions + - dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f()) + - dupl # Tool for code clone detection + - durationcheck # check for two durations multiplied together + - err113 # Golang linter to check the errors handling expressions + - errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases + - errchkjson # Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occations, where the check for the returned error can be omitted. + - errname # Checks that sentinel errors are prefixed with the `Err` and error types are suffixed with the `Error`. + - errorlint # errorlint is a linter for that can be used to find code that will cause problems with the error wrapping scheme introduced in Go 1.13. + - exhaustive # check exhaustiveness of enum switch statements + - forbidigo # Forbids identifiers + - forcetypeassert # finds forced type assertions + - gochecknoglobals # Checks that no globals are present in Go code + - gocognit # Computes and checks the cognitive complexity of functions + - goconst # Finds repeated strings that could be replaced by a constant + - gocritic # The most opinionated Go source code linter + - gocyclo # Computes and checks the cyclomatic complexity of functions + - godot # Check if comments end in a period + - godox # Tool for detection of FIXME, TODO and other comment keywords + - goheader # Checks is file header matches to pattern + - gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod. + - goprintffuncname # Checks that printf-like functions are named with `f` at the end + - gosec # Inspects source code for security problems + - govet # Vet examines Go source code and reports suspicious constructs, such as Printf calls whose arguments do not align with the format string + - grouper # An analyzer to analyze expression groups. + - importas # Enforces consistent import aliases + - ineffassign # Detects when assignments to existing variables are not used + - lll # Reports long lines + - maintidx # maintidx measures the maintainability index of each function. + - makezero # Finds slice declarations with non-zero initial length + - misspell # Finds commonly misspelled English words in comments + - nakedret # Finds naked returns in functions greater than a specified function length + - nestif # Reports deeply nested if statements + - nilerr # Finds the code that returns nil even if it checks that the error is not nil. + - nilnil # Checks that there is no simultaneous return of `nil` error and an invalid value. + - nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity + - noctx # noctx finds sending http request without context.Context + - predeclared # find code that shadows one of Go's predeclared identifiers + - revive # golint replacement, finds style mistakes + - staticcheck # Staticcheck is a go vet on steroids, applying a ton of static analysis checks + - tagliatelle # Checks the struct tags. + - thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers + - unconvert # Remove unnecessary type conversions + - unparam # Reports unused function parameters + - unused # Checks Go code for unused constants, variables, functions and types + - varnamelen # checks that the length of a variable's name matches its scope + - wastedassign # wastedassign finds wasted assignment statements + - whitespace # Tool for detection of leading and trailing whitespace + disable: + - depguard # Go linter that checks if package imports are in a list of acceptable packages + - funlen # Tool for detection of long functions + - gochecknoinits # Checks that no init functions are present in Go code + - gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations. + - interfacebloat # A linter that checks length of interface. + - ireturn # Accept Interfaces, Return Concrete Types + - mnd # An analyzer to detect magic numbers + - nolintlint # Reports ill-formed or insufficient nolint directives + - paralleltest # paralleltest detects missing usage of t.Parallel() method in your Go test + - prealloc # Finds slice declarations that could potentially be preallocated + - promlinter # Check Prometheus metrics naming via promlint + - rowserrcheck # checks whether Err of rows is checked successfully + - sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed. + - testpackage # linter that makes you use a separate _test package + - tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes + - wrapcheck # Checks that errors returned from external packages are wrapped + - wsl # Whitespace Linter - Forces you to use empty lines! + settings: + staticcheck: + checks: + - all + - -QF1008 # "could remove embedded field", to keep it explicit! + - -QF1003 # "could use tagged switch on enum", Cases conflicts with exhaustive! + exhaustive: + default-signifies-exhaustive: true + forbidigo: + forbid: + - pattern: ^fmt.Print(f|ln)?$ + - pattern: ^log.(Panic|Fatal|Print)(f|ln)?$ + - pattern: ^os.Exit$ + - pattern: ^panic$ + - pattern: ^print(ln)?$ + - pattern: ^testing.T.(Error|Errorf|Fatal|Fatalf|Fail|FailNow)$ + pkg: ^testing$ + msg: use testify/assert instead + analyze-types: true + gomodguard: + blocked: + modules: + - github.com/pkg/errors: + recommendations: + - errors + govet: + enable: + - shadow + revive: + rules: + # Prefer 'any' type alias over 'interface{}' for Go 1.18+ compatibility + - name: use-any + severity: warning + disabled: false + misspell: + locale: US + varnamelen: + max-distance: 12 + min-name-length: 2 + ignore-type-assert-ok: true + ignore-map-index-ok: true + ignore-chan-recv-ok: true + ignore-decls: + - i int + - n int + - w io.Writer + - r io.Reader + - b []byte + exclusions: + generated: lax + rules: + - linters: + - forbidigo + - gocognit + path: (examples|main\.go) + - linters: + - gocognit + path: _test\.go + - linters: + - forbidigo + path: cmd +formatters: + enable: + - gci # Gci control golang package import order and make it always deterministic. + - gofmt # Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification + - gofumpt # Gofumpt checks whether code was gofumpt-ed. + - goimports # Goimports does everything that gofmt does. Additionally it checks unused imports + exclusions: + generated: lax diff --git a/vendor/github.com/pion/transport/v4/.goreleaser.yml b/vendor/github.com/pion/transport/v4/.goreleaser.yml new file mode 100644 index 0000000..30093e9 --- /dev/null +++ b/vendor/github.com/pion/transport/v4/.goreleaser.yml @@ -0,0 +1,5 @@ +# SPDX-FileCopyrightText: 2023 The Pion community +# SPDX-License-Identifier: MIT + +builds: +- skip: true diff --git a/vendor/github.com/pion/transport/v4/LICENSE b/vendor/github.com/pion/transport/v4/LICENSE new file mode 100644 index 0000000..491caf6 --- /dev/null +++ b/vendor/github.com/pion/transport/v4/LICENSE @@ -0,0 +1,9 @@ +MIT License + +Copyright (c) 2023 The Pion community + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/vendor/github.com/pion/transport/v4/README.md b/vendor/github.com/pion/transport/v4/README.md new file mode 100644 index 0000000..0005e27 --- /dev/null +++ b/vendor/github.com/pion/transport/v4/README.md @@ -0,0 +1,34 @@ +

+
+ Pion Transport +
+

+

Transport testing for Pion

+

+ Pion transport + join us on Discord Follow us on Bluesky +
+ GitHub Workflow Status + Go Reference + Coverage Status + Go Report Card + License: MIT +

+
+ +### Roadmap +The library is used as a part of our WebRTC implementation. Please refer to that [roadmap](https://github.com/pion/webrtc/issues/9) to track our major milestones. + +### Community +Pion has an active community on the [Discord](https://discord.gg/PngbdqpFbt). + +Follow the [Pion Bluesky](https://bsky.app/profile/pion.ly) or [Pion Twitter](https://twitter.com/_pion) for project updates and important WebRTC news. + +We are always looking to support **your projects**. Please reach out if you have something to build! +If you need commercial support or don't want to use public methods you can contact us at [team@pion.ly](mailto:team@pion.ly) + +### Contributing +Check out the [contributing wiki](https://github.com/pion/webrtc/wiki/Contributing) to join the group of amazing people making this project possible + +### License +MIT License - see [LICENSE](LICENSE) for full text diff --git a/vendor/github.com/pion/transport/v4/codecov.yml b/vendor/github.com/pion/transport/v4/codecov.yml new file mode 100644 index 0000000..263e4d4 --- /dev/null +++ b/vendor/github.com/pion/transport/v4/codecov.yml @@ -0,0 +1,22 @@ +# +# DO NOT EDIT THIS FILE +# +# It is automatically copied from https://github.com/pion/.goassets repository. +# +# SPDX-FileCopyrightText: 2023 The Pion community +# SPDX-License-Identifier: MIT + +coverage: + status: + project: + default: + # Allow decreasing 2% of total coverage to avoid noise. + threshold: 2% + patch: + default: + target: 70% + only_pulls: true + +ignore: + - "examples/*" + - "examples/**/*" diff --git a/vendor/github.com/pion/transport/v4/deadline/deadline.go b/vendor/github.com/pion/transport/v4/deadline/deadline.go new file mode 100644 index 0000000..6f1fd0c --- /dev/null +++ b/vendor/github.com/pion/transport/v4/deadline/deadline.go @@ -0,0 +1,130 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +// Package deadline provides deadline timer used to implement +// net.Conn compatible connection +package deadline + +import ( + "context" + "sync" + "time" +) + +type deadlineState uint8 + +const ( + deadlineStopped deadlineState = iota + deadlineStarted + deadlineExceeded +) + +var _ context.Context = (*Deadline)(nil) + +// Deadline signals updatable deadline timer. +// Also, it implements context.Context. +type Deadline struct { + mu sync.RWMutex + timer timer + done chan struct{} + deadline time.Time + state deadlineState + pending uint8 +} + +// New creates new deadline timer. +func New() *Deadline { + return &Deadline{ + done: make(chan struct{}), + } +} + +func (d *Deadline) timeout() { + d.mu.Lock() + if d.pending--; d.pending != 0 || d.state != deadlineStarted { + d.mu.Unlock() + + return + } + + d.state = deadlineExceeded + done := d.done + d.mu.Unlock() + + close(done) +} + +// Set new deadline. Zero value means no deadline. +func (d *Deadline) Set(setTo time.Time) { + d.mu.Lock() + defer d.mu.Unlock() + + if d.state == deadlineStarted && d.timer.Stop() { + d.pending-- + } + + d.deadline = setTo + d.pending++ + + if d.state == deadlineExceeded { + d.done = make(chan struct{}) + } + + if setTo.IsZero() { + d.pending-- + d.state = deadlineStopped + + return + } + + if dur := time.Until(setTo); dur > 0 { + d.state = deadlineStarted + if d.timer == nil { + d.timer = afterFunc(dur, d.timeout) + } else { + d.timer.Reset(dur) + } + + return + } + + d.pending-- + d.state = deadlineExceeded + close(d.done) +} + +// Done receives deadline signal. +func (d *Deadline) Done() <-chan struct{} { + d.mu.RLock() + defer d.mu.RUnlock() + + return d.done +} + +// Err returns context.DeadlineExceeded if the deadline is exceeded. +// Otherwise, it returns nil. +func (d *Deadline) Err() error { + d.mu.RLock() + defer d.mu.RUnlock() + if d.state == deadlineExceeded { + return context.DeadlineExceeded + } + + return nil +} + +// Deadline returns current deadline. +func (d *Deadline) Deadline() (time.Time, bool) { + d.mu.RLock() + defer d.mu.RUnlock() + if d.deadline.IsZero() { + return d.deadline, false + } + + return d.deadline, true +} + +// Value returns nil. +func (d *Deadline) Value(any) any { + return nil +} diff --git a/vendor/github.com/pion/transport/v4/deadline/timer.go b/vendor/github.com/pion/transport/v4/deadline/timer.go new file mode 100644 index 0000000..5a39724 --- /dev/null +++ b/vendor/github.com/pion/transport/v4/deadline/timer.go @@ -0,0 +1,13 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package deadline + +import ( + "time" +) + +type timer interface { + Stop() bool + Reset(time.Duration) bool +} diff --git a/vendor/github.com/pion/transport/v4/deadline/timer_generic.go b/vendor/github.com/pion/transport/v4/deadline/timer_generic.go new file mode 100644 index 0000000..0c8f87c --- /dev/null +++ b/vendor/github.com/pion/transport/v4/deadline/timer_generic.go @@ -0,0 +1,15 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +//go:build !js +// +build !js + +package deadline + +import ( + "time" +) + +func afterFunc(d time.Duration, f func()) timer { + return time.AfterFunc(d, f) +} diff --git a/vendor/github.com/pion/transport/v4/deadline/timer_js.go b/vendor/github.com/pion/transport/v4/deadline/timer_js.go new file mode 100644 index 0000000..b77e31e --- /dev/null +++ b/vendor/github.com/pion/transport/v4/deadline/timer_js.go @@ -0,0 +1,67 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +//go:build js +// +build js + +package deadline + +import ( + "sync" + "time" +) + +// jsTimer is a timer utility for wasm with a working Reset function. +type jsTimer struct { + f func() + mu sync.Mutex + timer *time.Timer + version uint64 + started bool +} + +func afterFunc(d time.Duration, f func()) timer { + t := &jsTimer{f: f} + t.Reset(d) + return t +} + +func (t *jsTimer) Stop() bool { + t.mu.Lock() + defer t.mu.Unlock() + + t.version++ + t.timer.Stop() + + started := t.started + t.started = false + return started +} + +func (t *jsTimer) Reset(d time.Duration) bool { + t.mu.Lock() + defer t.mu.Unlock() + + if t.timer != nil { + t.timer.Stop() + } + + t.version++ + version := t.version + t.timer = time.AfterFunc(d, func() { + t.mu.Lock() + if version != t.version { + t.mu.Unlock() + return + } + + t.started = false + t.mu.Unlock() + + t.f() + }) + + started := t.started + t.started = true + return started +} diff --git a/vendor/github.com/pion/transport/v4/net.go b/vendor/github.com/pion/transport/v4/net.go new file mode 100644 index 0000000..1de02ed --- /dev/null +++ b/vendor/github.com/pion/transport/v4/net.go @@ -0,0 +1,451 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +// Package transport implements various networking related +// functions used throughout the Pion modules. +package transport + +import ( + "context" + "errors" + "io" + "net" + "time" +) + +var ( + // ErrNoAddressAssigned ... + ErrNoAddressAssigned = errors.New("no address assigned") + // ErrNotSupported ... + ErrNotSupported = errors.New("not supported yey") + // ErrInterfaceNotFound ... + ErrInterfaceNotFound = errors.New("interface not found") + // ErrNotUDPAddress ... + ErrNotUDPAddress = errors.New("not a UDP address") +) + +// Net is an interface providing common networking functions which are +// similar to the functions provided by standard net package. +type Net interface { + // ListenPacket announces on the local network address. + // + // The network must be "udp", "udp4", "udp6", "unixgram", or an IP + // transport. The IP transports are "ip", "ip4", or "ip6" followed by + // a colon and a literal protocol number or a protocol name, as in + // "ip:1" or "ip:icmp". + // + // For UDP and IP networks, if the host in the address parameter is + // empty or a literal unspecified IP address, ListenPacket listens on + // all available IP addresses of the local system except multicast IP + // addresses. + // To only use IPv4, use network "udp4" or "ip4:proto". + // The address can use a host name, but this is not recommended, + // because it will create a listener for at most one of the host's IP + // addresses. + // If the port in the address parameter is empty or "0", as in + // "127.0.0.1:" or "[::1]:0", a port number is automatically chosen. + // The LocalAddr method of PacketConn can be used to discover the + // chosen port. + // + // See func Dial for a description of the network and address + // parameters. + // + // ListenPacket uses context.Background internally; to specify the context, use + // ListenConfig.ListenPacket. + ListenPacket(network string, address string) (net.PacketConn, error) + + // ListenUDP acts like ListenPacket for UDP networks. + // + // The network must be a UDP network name; see func Dial for details. + // + // If the IP field of laddr is nil or an unspecified IP address, + // ListenUDP listens on all available IP addresses of the local system + // except multicast IP addresses. + // If the Port field of laddr is 0, a port number is automatically + // chosen. + ListenUDP(network string, locAddr *net.UDPAddr) (UDPConn, error) + + // ListenTCP acts like Listen for TCP networks. + // + // The network must be a TCP network name; see func Dial for details. + // + // If the IP field of laddr is nil or an unspecified IP address, + // ListenTCP listens on all available unicast and anycast IP addresses + // of the local system. + // If the Port field of laddr is 0, a port number is automatically + // chosen. + ListenTCP(network string, laddr *net.TCPAddr) (TCPListener, error) + + // Dial connects to the address on the named network. + // + // Known networks are "tcp", "tcp4" (IPv4-only), "tcp6" (IPv6-only), + // "udp", "udp4" (IPv4-only), "udp6" (IPv6-only), "ip", "ip4" + // (IPv4-only), "ip6" (IPv6-only), "unix", "unixgram" and + // "unixpacket". + // + // For TCP and UDP networks, the address has the form "host:port". + // The host must be a literal IP address, or a host name that can be + // resolved to IP addresses. + // The port must be a literal port number or a service name. + // If the host is a literal IPv6 address it must be enclosed in square + // brackets, as in "[2001:db8::1]:80" or "[fe80::1%zone]:80". + // The zone specifies the scope of the literal IPv6 address as defined + // in RFC 4007. + // The functions JoinHostPort and SplitHostPort manipulate a pair of + // host and port in this form. + // When using TCP, and the host resolves to multiple IP addresses, + // Dial will try each IP address in order until one succeeds. + // + // Examples: + // + // Dial("tcp", "golang.org:http") + // Dial("tcp", "192.0.2.1:http") + // Dial("tcp", "198.51.100.1:80") + // Dial("udp", "[2001:db8::1]:domain") + // Dial("udp", "[fe80::1%lo0]:53") + // Dial("tcp", ":80") + // + // For IP networks, the network must be "ip", "ip4" or "ip6" followed + // by a colon and a literal protocol number or a protocol name, and + // the address has the form "host". The host must be a literal IP + // address or a literal IPv6 address with zone. + // It depends on each operating system how the operating system + // behaves with a non-well known protocol number such as "0" or "255". + // + // Examples: + // + // Dial("ip4:1", "192.0.2.1") + // Dial("ip6:ipv6-icmp", "2001:db8::1") + // Dial("ip6:58", "fe80::1%lo0") + // + // For TCP, UDP and IP networks, if the host is empty or a literal + // unspecified IP address, as in ":80", "0.0.0.0:80" or "[::]:80" for + // TCP and UDP, "", "0.0.0.0" or "::" for IP, the local system is + // assumed. + // + // For Unix networks, the address must be a file system path. + Dial(network, address string) (net.Conn, error) + + // DialUDP acts like Dial for UDP networks. + // + // The network must be a UDP network name; see func Dial for details. + // + // If laddr is nil, a local address is automatically chosen. + // If the IP field of raddr is nil or an unspecified IP address, the + // local system is assumed. + DialUDP(network string, laddr, raddr *net.UDPAddr) (UDPConn, error) + + // DialTCP acts like Dial for TCP networks. + // + // The network must be a TCP network name; see func Dial for details. + // + // If laddr is nil, a local address is automatically chosen. + // If the IP field of raddr is nil or an unspecified IP address, the + // local system is assumed. + DialTCP(network string, laddr, raddr *net.TCPAddr) (TCPConn, error) + + // ResolveIPAddr returns an address of IP end point. + // + // The network must be an IP network name. + // + // If the host in the address parameter is not a literal IP address, + // ResolveIPAddr resolves the address to an address of IP end point. + // Otherwise, it parses the address as a literal IP address. + // The address parameter can use a host name, but this is not + // recommended, because it will return at most one of the host name's + // IP addresses. + // + // See func Dial for a description of the network and address + // parameters. + ResolveIPAddr(network, address string) (*net.IPAddr, error) + + // ResolveUDPAddr returns an address of UDP end point. + // + // The network must be a UDP network name. + // + // If the host in the address parameter is not a literal IP address or + // the port is not a literal port number, ResolveUDPAddr resolves the + // address to an address of UDP end point. + // Otherwise, it parses the address as a pair of literal IP address + // and port number. + // The address parameter can use a host name, but this is not + // recommended, because it will return at most one of the host name's + // IP addresses. + // + // See func Dial for a description of the network and address + // parameters. + ResolveUDPAddr(network, address string) (*net.UDPAddr, error) + + // ResolveTCPAddr returns an address of TCP end point. + // + // The network must be a TCP network name. + // + // If the host in the address parameter is not a literal IP address or + // the port is not a literal port number, ResolveTCPAddr resolves the + // address to an address of TCP end point. + // Otherwise, it parses the address as a pair of literal IP address + // and port number. + // The address parameter can use a host name, but this is not + // recommended, because it will return at most one of the host name's + // IP addresses. + // + // See func Dial for a description of the network and address + // parameters. + ResolveTCPAddr(network, address string) (*net.TCPAddr, error) + + // Interfaces returns a list of the system's network interfaces. + Interfaces() ([]*Interface, error) + + // InterfaceByIndex returns the interface specified by index. + // + // On Solaris, it returns one of the logical network interfaces + // sharing the logical data link; for more precision use + // InterfaceByName. + InterfaceByIndex(index int) (*Interface, error) + + // InterfaceByName returns the interface specified by name. + InterfaceByName(name string) (*Interface, error) + + // The following functions are extensions to Go's standard net package + + CreateDialer(dialer *net.Dialer) Dialer + CreateListenConfig(listenerConfig *net.ListenConfig) ListenConfig +} + +// Dialer is identical to net.Dialer excepts that its methods +// (Dial, DialContext) are overridden to use the Net interface. +// Use vnet.CreateDialer() to create an instance of this Dialer. +type Dialer interface { + Dial(network, address string) (net.Conn, error) +} + +// ListenConfig is identical to net.ListenConfig except that its methods +// (Listen, ListenPacket) are overridden to use the Net interface. +// Use vnet.Create:ListenConfig() to create an instance of this ListenConfig. +type ListenConfig interface { + Listen(ctx context.Context, network, address string) (net.Listener, error) + ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) +} + +// UDPConn is packet-oriented connection for UDP. +type UDPConn interface { + // Close closes the connection. + // Any blocked Read or Write operations will be unblocked and return errors. + Close() error + + // LocalAddr returns the local network address, if known. + LocalAddr() net.Addr + + // RemoteAddr returns the remote network address, if known. + RemoteAddr() net.Addr + + // SetDeadline sets the read and write deadlines associated + // with the connection. It is equivalent to calling both + // SetReadDeadline and SetWriteDeadline. + // + // A deadline is an absolute time after which I/O operations + // fail instead of blocking. The deadline applies to all future + // and pending I/O, not just the immediately following call to + // Read or Write. After a deadline has been exceeded, the + // connection can be refreshed by setting a deadline in the future. + // + // If the deadline is exceeded a call to Read or Write or to other + // I/O methods will return an error that wraps os.ErrDeadlineExceeded. + // This can be tested using errors.Is(err, os.ErrDeadlineExceeded). + // The error's Timeout method will return true, but note that there + // are other possible errors for which the Timeout method will + // return true even if the deadline has not been exceeded. + // + // An idle timeout can be implemented by repeatedly extending + // the deadline after successful Read or Write calls. + // + // A zero value for t means I/O operations will not time out. + SetDeadline(t time.Time) error + + // SetReadDeadline sets the deadline for future Read calls + // and any currently-blocked Read call. + // A zero value for t means Read will not time out. + SetReadDeadline(t time.Time) error + + // SetWriteDeadline sets the deadline for future Write calls + // and any currently-blocked Write call. + // Even if write times out, it may return n > 0, indicating that + // some of the data was successfully written. + // A zero value for t means Write will not time out. + SetWriteDeadline(t time.Time) error + + // SetReadBuffer sets the size of the operating system's + // receive buffer associated with the connection. + SetReadBuffer(bytes int) error + + // SetWriteBuffer sets the size of the operating system's + // transmit buffer associated with the connection. + SetWriteBuffer(bytes int) error + + // Read reads data from the connection. + // Read can be made to time out and return an error after a fixed + // time limit; see SetDeadline and SetReadDeadline. + Read(b []byte) (n int, err error) + + // ReadFrom reads a packet from the connection, + // copying the payload into p. It returns the number of + // bytes copied into p and the return address that + // was on the packet. + // It returns the number of bytes read (0 <= n <= len(p)) + // and any error encountered. Callers should always process + // the n > 0 bytes returned before considering the error err. + // ReadFrom can be made to time out and return an error after a + // fixed time limit; see SetDeadline and SetReadDeadline. + ReadFrom(p []byte) (n int, addr net.Addr, err error) + + // ReadFromUDP acts like ReadFrom but returns a UDPAddr. + ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error) + + // ReadMsgUDP reads a message from c, copying the payload into b and + // the associated out-of-band data into oob. It returns the number of + // bytes copied into b, the number of bytes copied into oob, the flags + // that were set on the message and the source address of the message. + // + // The packages golang.org/x/net/ipv4 and golang.org/x/net/ipv6 can be + // used to manipulate IP-level socket options in oob. + ReadMsgUDP(b, oob []byte) (n, oobn, flags int, addr *net.UDPAddr, err error) + + // Write writes data to the connection. + // Write can be made to time out and return an error after a fixed + // time limit; see SetDeadline and SetWriteDeadline. + Write(b []byte) (n int, err error) + + // WriteTo writes a packet with payload p to addr. + // WriteTo can be made to time out and return an Error after a + // fixed time limit; see SetDeadline and SetWriteDeadline. + // On packet-oriented connections, write timeouts are rare. + WriteTo(p []byte, addr net.Addr) (n int, err error) + + // WriteToUDP acts like WriteTo but takes a UDPAddr. + WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) + + // WriteMsgUDP writes a message to addr via c if c isn't connected, or + // to c's remote address if c is connected (in which case addr must be + // nil). The payload is copied from b and the associated out-of-band + // data is copied from oob. It returns the number of payload and + // out-of-band bytes written. + // + // The packages golang.org/x/net/ipv4 and golang.org/x/net/ipv6 can be + // used to manipulate IP-level socket options in oob. + WriteMsgUDP(b, oob []byte, addr *net.UDPAddr) (n, oobn int, err error) +} + +// TCPConn is an interface for TCP network connections. +type TCPConn interface { + net.Conn + + // CloseRead shuts down the reading side of the TCP connection. + // Most callers should just use Close. + CloseRead() error + + // CloseWrite shuts down the writing side of the TCP connection. + // Most callers should just use Close. + CloseWrite() error + + // ReadFrom implements the io.ReaderFrom ReadFrom method. + ReadFrom(r io.Reader) (int64, error) + + // SetLinger sets the behavior of Close on a connection which still + // has data waiting to be sent or to be acknowledged. + // + // If sec < 0 (the default), the operating system finishes sending the + // data in the background. + // + // If sec == 0, the operating system discards any unsent or + // unacknowledged data. + // + // If sec > 0, the data is sent in the background as with sec < 0. On + // some operating systems after sec seconds have elapsed any remaining + // unsent data may be discarded. + SetLinger(sec int) error + + // SetKeepAlive sets whether the operating system should send + // keep-alive messages on the connection. + SetKeepAlive(keepalive bool) error + + // SetKeepAlivePeriod sets period between keep-alives. + SetKeepAlivePeriod(d time.Duration) error + + // SetNoDelay controls whether the operating system should delay + // packet transmission in hopes of sending fewer packets (Nagle's + // algorithm). The default is true (no delay), meaning that data is + // sent as soon as possible after a Write. + SetNoDelay(noDelay bool) error + + // SetWriteBuffer sets the size of the operating system's + // transmit buffer associated with the connection. + SetWriteBuffer(bytes int) error + + // SetReadBuffer sets the size of the operating system's + // receive buffer associated with the connection. + SetReadBuffer(bytes int) error +} + +// TCPListener is a TCP network listener. Clients should typically +// use variables of type Listener instead of assuming TCP. +type TCPListener interface { + net.Listener + + // AcceptTCP accepts the next incoming call and returns the new + // connection. + AcceptTCP() (TCPConn, error) + + // SetDeadline sets the deadline associated with the listener. + // A zero time value disables the deadline. + SetDeadline(t time.Time) error +} + +// Interface wraps a standard net.Interfaces and its assigned addresses. +type Interface struct { + net.Interface + addrs []net.Addr +} + +// NewInterface creates a new interface based of a standard net.Interface. +func NewInterface(ifc net.Interface) *Interface { + return &Interface{ + Interface: ifc, + addrs: nil, + } +} + +// AddAddress adds a new address to the interface. +func (ifc *Interface) AddAddress(addr net.Addr) { + ifc.addrs = append(ifc.addrs, addr) +} + +// RemoveAddress removes an address from the interface. +func (ifc *Interface) RemoveAddress(ip net.IP) bool { + for i, addr := range ifc.addrs { + var addrIP net.IP + switch a := addr.(type) { + case *net.IPNet: + addrIP = a.IP + case *net.IPAddr: + addrIP = a.IP + default: + continue + } + if addrIP.Equal(ip) { + ifc.addrs = append(ifc.addrs[:i], ifc.addrs[i+1:]...) + + return true + } + } + + return false +} + +// Addrs returns a slice of configured addresses on the interface. +func (ifc *Interface) Addrs() ([]net.Addr, error) { + if len(ifc.addrs) == 0 { + return nil, ErrNoAddressAssigned + } + + return ifc.addrs, nil +} diff --git a/vendor/github.com/pion/transport/v4/netctx/conn.go b/vendor/github.com/pion/transport/v4/netctx/conn.go new file mode 100644 index 0000000..79506c2 --- /dev/null +++ b/vendor/github.com/pion/transport/v4/netctx/conn.go @@ -0,0 +1,191 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +// Package netctx wraps common net interfaces using context.Context. +package netctx + +import ( + "context" + "errors" + "io" + "net" + "sync" + "sync/atomic" + "time" +) + +// ErrClosing is returned on Write to closed connection. +var ErrClosing = errors.New("use of closed network connection") + +// Reader is an interface for context controlled reader. +type Reader interface { + ReadContext(context.Context, []byte) (int, error) +} + +// Writer is an interface for context controlled writer. +type Writer interface { + WriteContext(context.Context, []byte) (int, error) +} + +// ReadWriter is a composite of ReadWriter. +type ReadWriter interface { + Reader + Writer +} + +// Conn is a wrapper of net.Conn using context.Context. +type Conn interface { + Reader + Writer + io.Closer + LocalAddr() net.Addr + RemoteAddr() net.Addr + Conn() net.Conn +} + +type conn struct { + nextConn net.Conn + closed chan struct{} + closeOnce sync.Once + readMu sync.Mutex + writeMu sync.Mutex +} + +var veryOld = time.Unix(0, 1) //nolint:gochecknoglobals + +// NewConn creates a new Conn wrapping given net.Conn. +func NewConn(netConn net.Conn) Conn { + c := &conn{ + nextConn: netConn, + closed: make(chan struct{}), + } + + return c +} + +// ReadContext reads data from the connection. +// Unlike net.Conn.Read(), the provided context is used to control timeout. +func (c *conn) ReadContext(ctx context.Context, b []byte) (int, error) { //nolint:cyclop + c.readMu.Lock() + defer c.readMu.Unlock() + + select { + case <-c.closed: + return 0, net.ErrClosed + default: + } + + done := make(chan struct{}) + var wg sync.WaitGroup + var errSetDeadline atomic.Value + wg.Add(1) + go func() { + defer wg.Done() + select { + case <-ctx.Done(): + // context canceled + if err := c.nextConn.SetReadDeadline(veryOld); err != nil { + errSetDeadline.Store(err) + + return + } + <-done + if err := c.nextConn.SetReadDeadline(time.Time{}); err != nil { + errSetDeadline.Store(err) + } + case <-done: + } + }() + + n, err := c.nextConn.Read(b) + + close(done) + wg.Wait() + if e := ctx.Err(); e != nil && n == 0 { + err = e + } + if err2, ok := errSetDeadline.Load().(error); ok && err == nil && err2 != nil { + err = err2 + } + + return n, err +} + +// WriteContext writes data to the connection. +// Unlike net.Conn.Write(), the provided context is used to control timeout. +func (c *conn) WriteContext(ctx context.Context, b []byte) (int, error) { //nolint:cyclop + c.writeMu.Lock() + defer c.writeMu.Unlock() + + select { + case <-c.closed: + return 0, ErrClosing + default: + } + + done := make(chan struct{}) + var wg sync.WaitGroup + var errSetDeadline atomic.Value + wg.Add(1) + go func() { + defer wg.Done() + select { + case <-ctx.Done(): + // context canceled + if err := c.nextConn.SetWriteDeadline(veryOld); err != nil { + errSetDeadline.Store(err) + + return + } + <-done + if err := c.nextConn.SetWriteDeadline(time.Time{}); err != nil { + errSetDeadline.Store(err) + } + case <-done: + } + }() + + n, err := c.nextConn.Write(b) + + close(done) + wg.Wait() + if e := ctx.Err(); e != nil && n == 0 { + err = e + } + if err2, ok := errSetDeadline.Load().(error); ok && err == nil && err2 != nil { + err = err2 + } + + return n, err +} + +// Close closes the connection. +// Any blocked ReadContext or WriteContext operations will be unblocked and +// return errors. +func (c *conn) Close() error { + err := c.nextConn.Close() + c.closeOnce.Do(func() { + c.writeMu.Lock() + c.readMu.Lock() + close(c.closed) + c.readMu.Unlock() + c.writeMu.Unlock() + }) + + return err +} + +// LocalAddr returns the local network address, if known. +func (c *conn) LocalAddr() net.Addr { + return c.nextConn.LocalAddr() +} + +// LocalAddr returns the local network address, if known. +func (c *conn) RemoteAddr() net.Addr { + return c.nextConn.RemoteAddr() +} + +// Conn returns the underlying net.Conn. +func (c *conn) Conn() net.Conn { + return c.nextConn +} diff --git a/vendor/github.com/pion/transport/v4/netctx/packetconn.go b/vendor/github.com/pion/transport/v4/netctx/packetconn.go new file mode 100644 index 0000000..feb5910 --- /dev/null +++ b/vendor/github.com/pion/transport/v4/netctx/packetconn.go @@ -0,0 +1,181 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package netctx + +import ( + "context" + "io" + "net" + "sync" + "sync/atomic" + "time" +) + +// ReaderFrom is an interface for context controlled packet reader. +type ReaderFrom interface { + ReadFromContext(context.Context, []byte) (int, net.Addr, error) +} + +// WriterTo is an interface for context controlled packet writer. +type WriterTo interface { + WriteToContext(context.Context, []byte, net.Addr) (int, error) +} + +// PacketConn is a wrapper of net.PacketConn using context.Context. +type PacketConn interface { + ReaderFrom + WriterTo + io.Closer + LocalAddr() net.Addr + Conn() net.PacketConn +} + +type packetConn struct { + nextConn net.PacketConn + closed chan struct{} + closeOnce sync.Once + readMu sync.Mutex + writeMu sync.Mutex +} + +// NewPacketConn creates a new PacketConn wrapping the given net.PacketConn. +func NewPacketConn(pconn net.PacketConn) PacketConn { + p := &packetConn{ + nextConn: pconn, + closed: make(chan struct{}), + } + + return p +} + +// ReadFromContext reads a packet from the connection, +// copying the payload into p. It returns the number of +// bytes copied into p and the return address that +// was on the packet. +// It returns the number of bytes read (0 <= n <= len(p)) +// and any error encountered. Callers should always process +// the n > 0 bytes returned before considering the error err. +// Unlike net.PacketConn.ReadFrom(), the provided context is +// used to control timeout. +func (p *packetConn) ReadFromContext(ctx context.Context, b []byte) (int, net.Addr, error) { //nolint:cyclop + p.readMu.Lock() + defer p.readMu.Unlock() + + select { + case <-p.closed: + return 0, nil, net.ErrClosed + default: + } + + done := make(chan struct{}) + var wg sync.WaitGroup + var errSetDeadline atomic.Value + wg.Add(1) + go func() { + defer wg.Done() + select { + case <-ctx.Done(): + // context canceled + if err := p.nextConn.SetReadDeadline(veryOld); err != nil { + errSetDeadline.Store(err) + + return + } + <-done + if err := p.nextConn.SetReadDeadline(time.Time{}); err != nil { + errSetDeadline.Store(err) + } + case <-done: + } + }() + + n, raddr, err := p.nextConn.ReadFrom(b) + + close(done) + wg.Wait() + if e := ctx.Err(); e != nil && n == 0 { + err = e + } + if err2, ok := errSetDeadline.Load().(error); ok && err == nil && err2 != nil { + err = err2 + } + + return n, raddr, err +} + +// WriteToContext writes a packet with payload p to addr. +// Unlike net.PacketConn.WriteTo(), the provided context +// is used to control timeout. +// On packet-oriented connections, write timeouts are rare. +func (p *packetConn) WriteToContext(ctx context.Context, b []byte, raddr net.Addr) (int, error) { //nolint:cyclop + p.writeMu.Lock() + defer p.writeMu.Unlock() + + select { + case <-p.closed: + return 0, ErrClosing + default: + } + + done := make(chan struct{}) + var wg sync.WaitGroup + var errSetDeadline atomic.Value + wg.Add(1) + go func() { + defer wg.Done() + select { + case <-ctx.Done(): + // context canceled + if err := p.nextConn.SetWriteDeadline(veryOld); err != nil { + errSetDeadline.Store(err) + + return + } + <-done + if err := p.nextConn.SetWriteDeadline(time.Time{}); err != nil { + errSetDeadline.Store(err) + } + case <-done: + } + }() + + n, err := p.nextConn.WriteTo(b, raddr) + + close(done) + wg.Wait() + if e := ctx.Err(); e != nil && n == 0 { + err = e + } + if err2, ok := errSetDeadline.Load().(error); ok && err == nil && err2 != nil { + err = err2 + } + + return n, err +} + +// Close closes the connection. +// Any blocked ReadFromContext or WriteToContext operations will be unblocked +// and return errors. +func (p *packetConn) Close() error { + err := p.nextConn.Close() + p.closeOnce.Do(func() { + p.writeMu.Lock() + p.readMu.Lock() + close(p.closed) + p.readMu.Unlock() + p.writeMu.Unlock() + }) + + return err +} + +// LocalAddr returns the local network address, if known. +func (p *packetConn) LocalAddr() net.Addr { + return p.nextConn.LocalAddr() +} + +// Conn returns the underlying net.PacketConn. +func (p *packetConn) Conn() net.PacketConn { + return p.nextConn +} diff --git a/vendor/github.com/pion/transport/v4/netctx/pipe.go b/vendor/github.com/pion/transport/v4/netctx/pipe.go new file mode 100644 index 0000000..12d4f72 --- /dev/null +++ b/vendor/github.com/pion/transport/v4/netctx/pipe.go @@ -0,0 +1,15 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package netctx + +import ( + "net" +) + +// Pipe creates piped pair of Conn. +func Pipe() (Conn, Conn) { + ca, cb := net.Pipe() + + return NewConn(ca), NewConn(cb) +} diff --git a/vendor/github.com/pion/transport/v4/packetio/buffer.go b/vendor/github.com/pion/transport/v4/packetio/buffer.go new file mode 100644 index 0000000..f7d24cf --- /dev/null +++ b/vendor/github.com/pion/transport/v4/packetio/buffer.go @@ -0,0 +1,342 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +// Package packetio provides packet buffer +package packetio + +import ( + "errors" + "io" + "sync" + "time" + + "github.com/pion/transport/v4/deadline" +) + +var errPacketTooBig = errors.New("packet too big") + +// BufferPacketType allow the Buffer to know which packet protocol is writing. +type BufferPacketType int + +const ( + // RTPBufferPacket indicates the Buffer that is handling RTP packets. + RTPBufferPacket BufferPacketType = 1 + // RTCPBufferPacket indicates the Buffer that is handling RTCP packets. + RTCPBufferPacket BufferPacketType = 2 +) + +// Buffer allows writing packets to an intermediate buffer, which can then be read form. +// This is verify similar to bytes.Buffer but avoids combining multiple writes into a single read. +type Buffer struct { + mutex sync.Mutex + + // this is a circular buffer. If head <= tail, then the useful + // data is in the interval [head, tail[. If tail < head, then + // the useful data is the union of [head, len[ and [0, tail[. + // In order to avoid ambiguity when head = tail, we always leave + // an unused byte in the buffer. + data []byte + head, tail int + + notify chan struct{} + closed bool + + count int + limitCount, limitSize int + + readDeadline *deadline.Deadline +} + +const ( + minSize = 2048 + cutoffSize = 128 * 1024 + maxSize = 4 * 1024 * 1024 +) + +// NewBuffer creates a new Buffer. +func NewBuffer() *Buffer { + return &Buffer{ + notify: make(chan struct{}, 1), + readDeadline: deadline.New(), + } +} + +// available returns true if the buffer is large enough to fit a packet +// of the given size, taking overhead into account. +func (b *Buffer) available(size int) bool { + available := b.head - b.tail + if available <= 0 { + available += len(b.data) + } + // we interpret head=tail as empty, so always keep a byte free + if size+2+1 > available { + return false + } + + return true +} + +// grow increases the size of the buffer. If it returns nil, then the +// buffer has been grown. It returns ErrFull if hits a limit. +func (b *Buffer) grow() error { + var newSize int + if len(b.data) < cutoffSize { + newSize = 2 * len(b.data) + } else { + newSize = 5 * len(b.data) / 4 + } + if newSize < minSize { + newSize = minSize + } + if (b.limitSize <= 0 || sizeHardLimit) && newSize > maxSize { + newSize = maxSize + } + + // one byte slack + if b.limitSize > 0 && newSize > b.limitSize+1 { + newSize = b.limitSize + 1 + } + + if newSize <= len(b.data) { + return ErrFull + } + + newData := make([]byte, newSize) + + var n int + if b.head <= b.tail { + // data was contiguous + n = copy(newData, b.data[b.head:b.tail]) + } else { + // data was discontinuous + n = copy(newData, b.data[b.head:]) + n += copy(newData[n:], b.data[:b.tail]) + } + b.head = 0 + b.tail = n + b.data = newData + + return nil +} + +// Write appends a copy of the packet data to the buffer. +// Returns ErrFull if the packet doesn't fit. +// +// Note that the packet size is limited to 65536 bytes since v0.11.0 due to the internal data structure. +func (b *Buffer) Write(packet []byte) (int, error) { //nolint:cyclop + if len(packet) >= 0x10000 { + return 0, errPacketTooBig + } + + b.mutex.Lock() + + if b.closed { + b.mutex.Unlock() + + return 0, io.ErrClosedPipe + } + + if (b.limitCount > 0 && b.count >= b.limitCount) || + (b.limitSize > 0 && b.size()+2+len(packet) > b.limitSize) { + b.mutex.Unlock() + + return 0, ErrFull + } + + // grow the buffer until the packet fits + for !b.available(len(packet)) { + err := b.grow() + if err != nil { + b.mutex.Unlock() + + return 0, err + } + } + + // store the length of the packet + b.data[b.tail] = uint8(len(packet) >> 8) //nolint:gosec + b.tail++ + if b.tail >= len(b.data) { + b.tail = 0 + } + b.data[b.tail] = uint8(len(packet)) //nolint:gosec + b.tail++ + if b.tail >= len(b.data) { + b.tail = 0 + } + + // store the packet + n := copy(b.data[b.tail:], packet) + b.tail += n + if b.tail >= len(b.data) { + // we reached the end, wrap around + m := copy(b.data, packet[n:]) + b.tail = m + } + b.count++ + + select { + case b.notify <- struct{}{}: + default: + } + b.mutex.Unlock() + + return len(packet), nil +} + +// Read populates the given byte slice, returning the number of bytes read. +// Blocks until data is available or the buffer is closed. +// Returns io.ErrShortBuffer is the packet is too small to copy the Write. +// Returns io.EOF if the buffer is closed. +func (b *Buffer) Read(packet []byte) (n int, err error) { //nolint:gocognit,cyclop + // Return immediately if the deadline is already exceeded. + select { + case <-b.readDeadline.Done(): + return 0, &netError{ErrTimeout, true, true} + default: + } + + for { + b.mutex.Lock() + + if b.head != b.tail { //nolint:nestif + // decode the packet size + n1 := b.data[b.head] + b.head++ + if b.head >= len(b.data) { + b.head = 0 + } + n2 := b.data[b.head] + b.head++ + if b.head >= len(b.data) { + b.head = 0 + } + count := int((uint16(n1) << 8) | uint16(n2)) + + // determine the number of bytes we'll actually copy + copied := count + if copied > len(packet) { + copied = len(packet) + } + + // copy the data + if b.head+copied < len(b.data) { + copy(packet, b.data[b.head:b.head+copied]) + } else { + k := copy(packet, b.data[b.head:]) + copy(packet[k:], b.data[:copied-k]) + } + + // advance head, discarding any data that wasn't copied + b.head += count + if b.head >= len(b.data) { + b.head -= len(b.data) + } + + if b.head == b.tail { + // the buffer is empty, reset to beginning + // in order to improve cache locality. + b.head = 0 + b.tail = 0 + } + + b.count-- + b.mutex.Unlock() + + if copied < count { + return copied, io.ErrShortBuffer + } + + return copied, nil + } + + if b.closed { + b.mutex.Unlock() + + return 0, io.EOF + } + b.mutex.Unlock() + + select { + case <-b.readDeadline.Done(): + return 0, &netError{ErrTimeout, true, true} + case <-b.notify: + } + } +} + +// Close the buffer, unblocking any pending reads. +// Data in the buffer can still be read, Read will return io.EOF only when empty. +func (b *Buffer) Close() (err error) { + b.mutex.Lock() + + if b.closed { + b.mutex.Unlock() + + return nil + } + + b.closed = true + close(b.notify) + b.mutex.Unlock() + + return nil +} + +// Count returns the number of packets in the buffer. +func (b *Buffer) Count() int { + b.mutex.Lock() + defer b.mutex.Unlock() + + return b.count +} + +// SetLimitCount controls the maximum number of packets that can be buffered. +// Causes Write to return ErrFull when this limit is reached. +// A zero value will disable this limit. +func (b *Buffer) SetLimitCount(limit int) { + b.mutex.Lock() + defer b.mutex.Unlock() + + b.limitCount = limit +} + +// Size returns the total byte size of packets in the buffer, including +// a small amount of administrative overhead. +func (b *Buffer) Size() int { + b.mutex.Lock() + defer b.mutex.Unlock() + + return b.size() +} + +func (b *Buffer) size() int { + size := b.tail - b.head + if size < 0 { + size += len(b.data) + } + + return size +} + +// SetLimitSize controls the maximum number of bytes that can be buffered. +// Causes Write to return ErrFull when this limit is reached. +// A zero value means 4MB since v0.11.0. +// +// User can set packetioSizeHardLimit build tag to enable 4MB hard limit. +// When packetioSizeHardLimit build tag is set, SetLimitSize exceeding +// the hard limit will be silently discarded. +func (b *Buffer) SetLimitSize(limit int) { + b.mutex.Lock() + defer b.mutex.Unlock() + + b.limitSize = limit +} + +// SetReadDeadline sets the deadline for the Read operation. +// Setting to zero means no deadline. +func (b *Buffer) SetReadDeadline(t time.Time) error { + b.readDeadline.Set(t) + + return nil +} diff --git a/vendor/github.com/pion/transport/v4/packetio/errors.go b/vendor/github.com/pion/transport/v4/packetio/errors.go new file mode 100644 index 0000000..eb45a78 --- /dev/null +++ b/vendor/github.com/pion/transport/v4/packetio/errors.go @@ -0,0 +1,30 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package packetio + +import ( + "errors" +) + +// netError implements net.Error. +type netError struct { + error + timeout, temporary bool +} + +func (e *netError) Timeout() bool { + return e.timeout +} + +func (e *netError) Temporary() bool { + return e.temporary +} + +var ( + // ErrFull is returned when the buffer has hit the configured limits. + ErrFull = errors.New("packetio.Buffer is full, discarding write") + + // ErrTimeout is returned when a deadline has expired. + ErrTimeout = errors.New("i/o timeout") +) diff --git a/vendor/github.com/pion/transport/v4/packetio/hardlimit.go b/vendor/github.com/pion/transport/v4/packetio/hardlimit.go new file mode 100644 index 0000000..8058e47 --- /dev/null +++ b/vendor/github.com/pion/transport/v4/packetio/hardlimit.go @@ -0,0 +1,9 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +//go:build packetioSizeHardlimit +// +build packetioSizeHardlimit + +package packetio + +const sizeHardLimit = true diff --git a/vendor/github.com/pion/transport/v4/packetio/no_hardlimit.go b/vendor/github.com/pion/transport/v4/packetio/no_hardlimit.go new file mode 100644 index 0000000..a59e259 --- /dev/null +++ b/vendor/github.com/pion/transport/v4/packetio/no_hardlimit.go @@ -0,0 +1,9 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +//go:build !packetioSizeHardlimit +// +build !packetioSizeHardlimit + +package packetio + +const sizeHardLimit = false diff --git a/vendor/github.com/pion/transport/v4/renovate.json b/vendor/github.com/pion/transport/v4/renovate.json new file mode 100644 index 0000000..f1bb98c --- /dev/null +++ b/vendor/github.com/pion/transport/v4/renovate.json @@ -0,0 +1,6 @@ +{ + "$schema": "https://docs.renovatebot.com/renovate-schema.json", + "extends": [ + "github>pion/renovate-config" + ] +} diff --git a/vendor/github.com/pion/transport/v4/replaydetector/fixedbig.go b/vendor/github.com/pion/transport/v4/replaydetector/fixedbig.go new file mode 100644 index 0000000..8a5655d --- /dev/null +++ b/vendor/github.com/pion/transport/v4/replaydetector/fixedbig.go @@ -0,0 +1,84 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package replaydetector + +import ( + "fmt" +) + +// fixedBigInt is the fix-sized multi-word integer. +type fixedBigInt struct { + bits []uint64 + n uint + msbMask uint64 +} + +// newFixedBigInt creates a new fix-sized multi-word int. +func newFixedBigInt(n uint) *fixedBigInt { + chunkSize := (n + 63) / 64 + if chunkSize == 0 { + chunkSize = 1 + } + + return &fixedBigInt{ + bits: make([]uint64, chunkSize), + n: n, + msbMask: (1 << (64 - n%64)) - 1, + } +} + +// Lsh is the left shift operation. +func (s *fixedBigInt) Lsh(n uint) { //nolint:varnamelen + if n == 0 { + return + } + nChunk := int(n / 64) //nolint:gosec + nN := n % 64 + + for i := len(s.bits) - 1; i >= 0; i-- { + var carry uint64 + if i-nChunk >= 0 { + carry = s.bits[i-nChunk] << nN + if i-nChunk-1 >= 0 { + carry |= s.bits[i-nChunk-1] >> (64 - nN) + } + } + s.bits[i] = (s.bits[i] << n) | carry + } + s.bits[len(s.bits)-1] &= s.msbMask +} + +// Bit returns i-th bit of the fixedBigInt. +func (s *fixedBigInt) Bit(i uint) uint { + if i >= s.n { + return 0 + } + chunk := i / 64 + pos := i % 64 + if s.bits[chunk]&(1<= s.n { + return + } + chunk := i / 64 + pos := i % 64 + s.bits[chunk] |= 1 << pos +} + +// String returns string representation of fixedBigInt. +func (s *fixedBigInt) String() string { + var out string + for i := len(s.bits) - 1; i >= 0; i-- { + out += fmt.Sprintf("%016X", s.bits[i]) + } + + return out +} diff --git a/vendor/github.com/pion/transport/v4/replaydetector/replaydetector.go b/vendor/github.com/pion/transport/v4/replaydetector/replaydetector.go new file mode 100644 index 0000000..a35a0c6 --- /dev/null +++ b/vendor/github.com/pion/transport/v4/replaydetector/replaydetector.go @@ -0,0 +1,136 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +// Package replaydetector provides packet replay detection algorithm. +package replaydetector + +// ReplayDetector is the interface of sequence replay detector. +type ReplayDetector interface { + // Check returns true if given sequence number is not replayed. + // Call accept() to mark the packet is received properly. + // The return value of accept() indicates whether the accepted packet is + // has the latest observed sequence number. + Check(seq uint64) (accept func() bool, ok bool) +} + +// nop is a no-op func that is returned in the case that Check() fails. +func nop() bool { + return false +} + +type slidingWindowDetector struct { + latestSeq uint64 + maxSeq uint64 + windowSize uint + mask *fixedBigInt +} + +// New creates ReplayDetector. +// Created ReplayDetector doesn't allow wrapping. +// It can handle monotonically increasing sequence number up to +// full 64bit number. It is suitable for DTLS replay protection. +func New(windowSize uint, maxSeq uint64) ReplayDetector { + return &slidingWindowDetector{ + maxSeq: maxSeq, + windowSize: windowSize, + mask: newFixedBigInt(windowSize), + } +} + +func (d *slidingWindowDetector) Check(seq uint64) (func() bool, bool) { + if seq > d.maxSeq { + // Exceeded upper limit. + return nop, false + } + + if seq <= d.latestSeq { + if d.latestSeq >= uint64(d.windowSize)+seq { + return nop, false + } + if d.mask.Bit(uint(d.latestSeq-seq)) != 0 { + // The sequence number is duplicated. + return nop, false + } + } + + return func() bool { + latest := seq == 0 + if seq > d.latestSeq { + // Update the head of the window. + d.mask.Lsh(uint(seq - d.latestSeq)) + d.latestSeq = seq + latest = true + } + diff := (d.latestSeq - seq) % d.maxSeq + d.mask.SetBit(uint(diff)) + + return latest + }, true +} + +// WithWrap creates ReplayDetector allowing sequence wrapping. +// This is suitable for short bit width counter like SRTP and SRTCP. +func WithWrap(windowSize uint, maxSeq uint64) ReplayDetector { + return &wrappedSlidingWindowDetector{ + maxSeq: maxSeq, + windowSize: windowSize, + mask: newFixedBigInt(windowSize), + } +} + +type wrappedSlidingWindowDetector struct { + latestSeq uint64 + maxSeq uint64 + windowSize uint + mask *fixedBigInt + init bool +} + +func (d *wrappedSlidingWindowDetector) Check(seq uint64) (func() bool, bool) { + if seq > d.maxSeq { + // Exceeded upper limit. + return nop, false + } + if !d.init { + if seq != 0 { + d.latestSeq = seq - 1 + } else { + d.latestSeq = d.maxSeq + } + d.init = true + } + + diff := int64(d.latestSeq) - int64(seq) //nolint:gosec // GG115 TODO check + // Wrap the number. + if diff > int64(d.maxSeq)/2 { //nolint:gosec // GG115 TODO check + diff -= int64(d.maxSeq + 1) //nolint:gosec // GG115 TODO check + } else if diff <= -int64(d.maxSeq)/2 { //nolint:gosec // GG115 TODO check + diff += int64(d.maxSeq + 1) //nolint:gosec // GG115 TODO check + } + + if diff >= int64(d.windowSize) { //nolint:gosec // GG115 TODO check + // Too old. + return nop, false + } + if diff >= 0 { + if d.mask.Bit(uint(diff)) != 0 { + // The sequence number is duplicated. + return nop, false + } + } + + return func() bool { + latest := false + if diff < 0 { + // Update the head of the window. + d.mask.Lsh(uint(-diff)) + d.latestSeq = seq + latest = true + d.mask.SetBit(0) + } else { + d.mask.SetBit(uint(diff)) + } + + return latest + }, true +} diff --git a/vendor/github.com/pion/transport/v4/stdnet/net.go b/vendor/github.com/pion/transport/v4/stdnet/net.go new file mode 100644 index 0000000..5b2976a --- /dev/null +++ b/vendor/github.com/pion/transport/v4/stdnet/net.go @@ -0,0 +1,186 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +// Package stdnet implements the transport.Net interface +// using methods from Go's standard net package. +package stdnet + +import ( + "context" + "fmt" + "net" + + "github.com/pion/transport/v4" + "github.com/wlynxg/anet" +) + +const ( + lo0String = "lo0String" + udpString = "udp" +) + +// Net is an implementation of the net.Net interface +// based on functions of the standard net package. +type Net struct { + interfaces []*transport.Interface +} + +// NewNet creates a new StdNet instance. +func NewNet() (*Net, error) { + n := &Net{} + + return n, n.UpdateInterfaces() +} + +// Compile-time assertion. +var _ transport.Net = &Net{} + +// UpdateInterfaces updates the internal list of network interfaces +// and associated addresses. +func (n *Net) UpdateInterfaces() error { + ifs := []*transport.Interface{} + + oifs, err := anet.Interfaces() + if err != nil { + return err + } + + for i := range oifs { + ifc := transport.NewInterface(oifs[i]) + + addrs, err := anet.InterfaceAddrsByInterface(&oifs[i]) + if err != nil { + return err + } + + for _, addr := range addrs { + ifc.AddAddress(addr) + } + + ifs = append(ifs, ifc) + } + + n.interfaces = ifs + + return nil +} + +// Interfaces returns a slice of interfaces which are available on the +// system. +func (n *Net) Interfaces() ([]*transport.Interface, error) { + return n.interfaces, nil +} + +// InterfaceByIndex returns the interface specified by index. +// +// On Solaris, it returns one of the logical network interfaces +// sharing the logical data link; for more precision use +// InterfaceByName. +func (n *Net) InterfaceByIndex(index int) (*transport.Interface, error) { + for _, ifc := range n.interfaces { + if ifc.Index == index { + return ifc, nil + } + } + + return nil, fmt.Errorf("%w: index=%d", transport.ErrInterfaceNotFound, index) +} + +// InterfaceByName returns the interface specified by name. +func (n *Net) InterfaceByName(name string) (*transport.Interface, error) { + for _, ifc := range n.interfaces { + if ifc.Name == name { + return ifc, nil + } + } + + return nil, fmt.Errorf("%w: %s", transport.ErrInterfaceNotFound, name) +} + +// ListenPacket announces on the local network address. +func (n *Net) ListenPacket(network string, address string) (net.PacketConn, error) { + return net.ListenPacket(network, address) //nolint: noctx +} + +// ListenUDP acts like ListenPacket for UDP networks. +func (n *Net) ListenUDP(network string, locAddr *net.UDPAddr) (transport.UDPConn, error) { + return net.ListenUDP(network, locAddr) +} + +// Dial connects to the address on the named network. +func (n *Net) Dial(network, address string) (net.Conn, error) { + return net.Dial(network, address) //nolint: noctx +} + +// DialUDP acts like Dial for UDP networks. +func (n *Net) DialUDP(network string, laddr, raddr *net.UDPAddr) (transport.UDPConn, error) { + return net.DialUDP(network, laddr, raddr) +} + +// ResolveIPAddr returns an address of IP end point. +func (n *Net) ResolveIPAddr(network, address string) (*net.IPAddr, error) { + return net.ResolveIPAddr(network, address) +} + +// ResolveUDPAddr returns an address of UDP end point. +func (n *Net) ResolveUDPAddr(network, address string) (*net.UDPAddr, error) { + return net.ResolveUDPAddr(network, address) +} + +// ResolveTCPAddr returns an address of TCP end point. +func (n *Net) ResolveTCPAddr(network, address string) (*net.TCPAddr, error) { + return net.ResolveTCPAddr(network, address) +} + +// DialTCP acts like Dial for TCP networks. +func (n *Net) DialTCP(network string, laddr, raddr *net.TCPAddr) (transport.TCPConn, error) { + return net.DialTCP(network, laddr, raddr) +} + +// ListenTCP acts like Listen for TCP networks. +func (n *Net) ListenTCP(network string, laddr *net.TCPAddr) (transport.TCPListener, error) { + l, err := net.ListenTCP(network, laddr) + if err != nil { + return nil, err + } + + return tcpListener{l}, nil +} + +type tcpListener struct { + *net.TCPListener +} + +func (l tcpListener) AcceptTCP() (transport.TCPConn, error) { + return l.TCPListener.AcceptTCP() +} + +type stdDialer struct { + *net.Dialer +} + +func (d stdDialer) Dial(network, address string) (net.Conn, error) { + return d.Dialer.Dial(network, address) +} + +// CreateDialer creates an instance of vnet.Dialer. +func (n *Net) CreateDialer(d *net.Dialer) transport.Dialer { + return stdDialer{d} +} + +type stdListenConfig struct { + *net.ListenConfig +} + +func (d stdListenConfig) Listen(ctx context.Context, network, address string) (net.Listener, error) { + return d.ListenConfig.Listen(ctx, network, address) +} + +func (d stdListenConfig) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) { + return d.ListenConfig.ListenPacket(ctx, network, address) +} + +// CreateListenConfig creates an instance of vnet.ListenConfig. +func (n *Net) CreateListenConfig(d *net.ListenConfig) transport.ListenConfig { + return stdListenConfig{d} +} diff --git a/vendor/github.com/pion/transport/v4/utils/xor/xor_arm.go b/vendor/github.com/pion/transport/v4/utils/xor/xor_arm.go new file mode 100644 index 0000000..25d6b72 --- /dev/null +++ b/vendor/github.com/pion/transport/v4/utils/xor/xor_arm.go @@ -0,0 +1,60 @@ +// SPDX-FileCopyrightText: 2022 The Pion community +// SPDX-License-Identifier: MIT + +//go:build !gccgo +// +build !gccgo + +// Package xor provides utility functions used by other Pion +// packages. ARM arch. +package xor + +import ( + "unsafe" + + "golang.org/x/sys/cpu" +) + +const wordSize = int(unsafe.Sizeof(uintptr(0))) // nolint:gosec +var hasNEON = cpu.ARM.HasNEON // nolint:gochecknoglobals + +func isAligned(a *byte) bool { + return uintptr(unsafe.Pointer(a))%uintptr(wordSize) == 0 +} + +// XorBytes xors the bytes in a and b. The destination should have enough +// space, otherwise xorBytes will panic. Returns the number of bytes xor'd. +// +//revive:disable-next-line +func XorBytes(dst, a, b []byte) int { + n := len(a) + if len(b) < n { + n = len(b) + } + if n == 0 { + return 0 + } + // make sure dst has enough space + _ = dst[n-1] + + if hasNEON { + xorBytesNEON32(&dst[0], &a[0], &b[0], n) + } else if isAligned(&dst[0]) && isAligned(&a[0]) && isAligned(&b[0]) { + xorBytesARM32(&dst[0], &a[0], &b[0], n) + } else { + safeXORBytes(dst, a, b, n) + } + return n +} + +// n needs to be smaller or equal than the length of a and b. +func safeXORBytes(dst, a, b []byte, n int) { + for i := 0; i < n; i++ { + dst[i] = a[i] ^ b[i] + } +} + +//go:noescape +func xorBytesARM32(dst, a, b *byte, n int) + +//go:noescape +func xorBytesNEON32(dst, a, b *byte, n int) diff --git a/vendor/github.com/pion/transport/v4/utils/xor/xor_arm.s b/vendor/github.com/pion/transport/v4/utils/xor/xor_arm.s new file mode 100644 index 0000000..5e52a2d --- /dev/null +++ b/vendor/github.com/pion/transport/v4/utils/xor/xor_arm.s @@ -0,0 +1,116 @@ +// SPDX-FileCopyrightText: 2022 The Pion community +// SPDX-License-Identifier: MIT + +// go:build !gccgo +// +build !gccgo + +#include "textflag.h" + +// func xorBytesARM32(dst, a, b *byte, n int) +TEXT ·xorBytesARM32(SB), NOSPLIT|NOFRAME, $0 + MOVW dst+0(FP), R0 + MOVW a+4(FP), R1 + MOVW b+8(FP), R2 + MOVW n+12(FP), R3 + CMP $4, R3 + BLT less_than4 + +loop_4: + MOVW.P 4(R1), R4 + MOVW.P 4(R2), R5 + EOR R4, R5, R5 + MOVW.P R5, 4(R0) + + SUB $4, R3 + CMP $4, R3 + BGE loop_4 + +less_than4: + CMP $2, R3 + BLT less_than2 + MOVH.P 2(R1), R4 + MOVH.P 2(R2), R5 + EOR R4, R5, R5 + MOVH.P R5, 2(R0) + + SUB $2, R3 + +less_than2: + CMP $0, R3 + BEQ end + MOVB (R1), R4 + MOVB (R2), R5 + EOR R4, R5, R5 + MOVB R5, (R0) +end: + RET + +// func xorBytesNEON32(dst, a, b *byte, n int) +TEXT ·xorBytesNEON32(SB), NOSPLIT|NOFRAME, $0 + MOVW dst+0(FP), R0 + MOVW a+4(FP), R1 + MOVW b+8(FP), R2 + MOVW n+12(FP), R3 + CMP $32, R3 + BLT less_than32 + +loop_32: + WORD $0xF421020D // vld1.u8 {q0, q1}, [r1]! + WORD $0xF422420D // vld1.u8 {q2, q3}, [r2]! + WORD $0xF3004154 // veor q2, q0, q2 + WORD $0xF3026156 // veor q3, q1, q3 + WORD $0xF400420D // vst1.u8 {q2, q3}, [r0]! + + SUB $32, R3 + CMP $32, R3 + BGE loop_32 + +less_than32: + CMP $16, R3 + BLT less_than16 + WORD $0xF4210A0D // vld1.u8 q0, [r1]! + WORD $0xF4222A0D // vld1.u8 q1, [r2]! + WORD $0xF3002152 // veor q1, q0, q1 + WORD $0xF4002A0D // vst1.u8 {q1}, [r0]! + + SUB $16, R3 + +less_than16: + CMP $8, R3 + BLT less_than8 + WORD $0xF421070D // vld1.u8 d0, [r1]! + WORD $0xF422170D // vld1.u8 d1, [r2]! + WORD $0xF3001111 // veor d1, d0, d1 + WORD $0xF400170D // vst1.u8 {d1}, [r0]! + + SUB $8, R3 + +less_than8: + CMP $4, R3 + BLT less_than4 + MOVW.P 4(R1), R4 + MOVW.P 4(R2), R5 + EOR R4, R5, R5 + MOVW.P R5, 4(R0) + + SUB $4, R3 + +less_than4: + CMP $2, R3 + BLT less_than2 + MOVH.P 2(R1), R4 + MOVH.P 2(R2), R5 + EOR R4, R5, R5 + MOVH.P R5, 2(R0) + + SUB $2, R3 + +less_than2: + CMP $0, R3 + BEQ end + MOVB (R1), R4 + MOVB (R2), R5 + EOR R4, R5, R5 + MOVB R5, (R0) +end: + RET diff --git a/vendor/github.com/pion/transport/v4/utils/xor/xor_generic.go b/vendor/github.com/pion/transport/v4/utils/xor/xor_generic.go new file mode 100644 index 0000000..690549a --- /dev/null +++ b/vendor/github.com/pion/transport/v4/utils/xor/xor_generic.go @@ -0,0 +1,20 @@ +// SPDX-FileCopyrightText: 2013 The Go Authors. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause +// SPDX-FileCopyrightText: 2024 The Pion community +// SPDX-License-Identifier: MIT + +//go:build go1.20 && !arm && !gccgo + +// Package xor provides the XorBytes function. +package xor + +import ( + "crypto/subtle" +) + +// XorBytes calls [crypto/suble.XORBytes]. +// +//revive:disable-next-line +func XorBytes(dst, a, b []byte) int { + return subtle.XORBytes(dst, a, b) +} diff --git a/vendor/github.com/pion/transport/v4/utils/xor/xor_old.go b/vendor/github.com/pion/transport/v4/utils/xor/xor_old.go new file mode 100644 index 0000000..f46f4d9 --- /dev/null +++ b/vendor/github.com/pion/transport/v4/utils/xor/xor_old.go @@ -0,0 +1,77 @@ +// SPDX-FileCopyrightText: 2013 The Go Authors. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause +// SPDX-FileCopyrightText: 2022 The Pion community +// SPDX-License-Identifier: MIT + +//go:build (!go1.20 && !arm) || gccgo + +// Package xor provides the XorBytes function. +// This version is only used on Go up to version 1.19. +package xor + +import ( + "runtime" + "unsafe" +) + +const ( + wordSize = int(unsafe.Sizeof(uintptr(0))) // nolint:gosec + supportsUnaligned = runtime.GOARCH == "386" || runtime.GOARCH == "amd64" || runtime.GOARCH == "arm64" || runtime.GOARCH == "ppc64" || runtime.GOARCH == "ppc64le" || runtime.GOARCH == "s390x" // nolint:gochecknoglobals +) + +func isAligned(a *byte) bool { + return uintptr(unsafe.Pointer(a))%uintptr(wordSize) == 0 +} + +// XorBytes xors the bytes in a and b. The destination should have enough +// space, otherwise xorBytes will panic. Returns the number of bytes xor'd. +// +//revive:disable-next-line +func XorBytes(dst, a, b []byte) int { + n := len(a) + if len(b) < n { + n = len(b) + } + if n == 0 { + return 0 + } + + switch { + case supportsUnaligned: + fastXORBytes(dst, a, b, n) + case isAligned(&dst[0]) && isAligned(&a[0]) && isAligned(&b[0]): + fastXORBytes(dst, a, b, n) + default: + safeXORBytes(dst, a, b, n) + } + return n +} + +// fastXORBytes xors in bulk. It only works on architectures that +// support unaligned read/writes. +// n needs to be smaller or equal than the length of a and b. +func fastXORBytes(dst, a, b []byte, n int) { + // Assert dst has enough space + _ = dst[n-1] + + w := n / wordSize + if w > 0 { + dw := *(*[]uintptr)(unsafe.Pointer(&dst)) // nolint:gosec + aw := *(*[]uintptr)(unsafe.Pointer(&a)) // nolint:gosec + bw := *(*[]uintptr)(unsafe.Pointer(&b)) // nolint:gosec + for i := 0; i < w; i++ { + dw[i] = aw[i] ^ bw[i] + } + } + + for i := (n - n%wordSize); i < n; i++ { + dst[i] = a[i] ^ b[i] + } +} + +// n needs to be smaller or equal than the length of a and b. +func safeXORBytes(dst, a, b []byte, n int) { + for i := 0; i < n; i++ { + dst[i] = a[i] ^ b[i] + } +} diff --git a/vendor/github.com/pion/transport/v4/vnet/.gitignore b/vendor/github.com/pion/transport/v4/vnet/.gitignore new file mode 100644 index 0000000..f2eef3e --- /dev/null +++ b/vendor/github.com/pion/transport/v4/vnet/.gitignore @@ -0,0 +1,4 @@ +# SPDX-FileCopyrightText: 2023 The Pion community +# SPDX-License-Identifier: MIT + +*.sw[poe] diff --git a/vendor/github.com/pion/transport/v4/vnet/README.md b/vendor/github.com/pion/transport/v4/vnet/README.md new file mode 100644 index 0000000..313afc0 --- /dev/null +++ b/vendor/github.com/pion/transport/v4/vnet/README.md @@ -0,0 +1,231 @@ +# vnet +A virtual network layer for pion. + +## Overview + +### Goals +* To make NAT traversal tests easy. +* To emulate packet impairment at application level for testing. +* To monitor packets at specified arbitrary interfaces. + +### Features +* Configurable virtual LAN and WAN +* Virtually hosted ICE servers + +### Virtual network components + +#### Top View +``` + ...................................... + : Virtual Network (vnet) : + : : + +-------+ * 1 +----+ +--------+ : + | :App |------------>|:Net|--o<-----|:Router | : + +-------+ +----+ | | : + +-----------+ * 1 +----+ | | : + |:STUNServer|-------->|:Net|--o<-----| | : + +-----------+ +----+ | | : + +-----------+ * 1 +----+ | | : + |:TURNServer|-------->|:Net|--o<-----| | : + +-----------+ +----+ [1] | | : + : 1 | | 1 <> : + : +---<>| |<>----+ [2] : + : | +--------+ | : + To form | *| v 0..1 : + a subnet tree | o [3] +-----+ : + : | ^ |:NAT | : + : | | +-----+ : + : +-------+ : + ...................................... + Note: + o: NIC (Network Interface Controller) + [1]: Net implements NIC interface. + [2]: Root router has no NAT. All child routers have a NAT always. + [3]: Router implements NIC interface for accesses from the + parent router. +``` + +#### Net +Net provides 3 interfaces: +* Configuration API (direct) +* Network API via Net (equivalent to net.Xxx()) +* Router access via NIC interface +``` + (Pion module/app, ICE servers, etc.) + +-----------+ + | :App | + +-----------+ + * | + | <> + 1 v + +---------+ 1 * +-----------+ 1 * +-----------+ 1 * +------+ + ..| :Router |----+------>o--| :Net |<>------|:Interface |<>------|:Addr | + +---------+ | NIC +-----------+ +-----------+ +------+ + | <> (transport.Interface) (net.Addr) + | + | * +-----------+ 1 * +-----------+ 1 * +------+ + +------>o--| :Router |<>------|:Interface |<>------|:Addr | + NIC +-----------+ +-----------+ +------+ + <> (transport.Interface) (net.Addr) +``` + +> The instance of `Net` will be the one passed around the project. +> Net class has public methods for configuration and for application use. + + +## Implementation + +### Design Policy +* Each pion package should have config object which has `Net` (of type `transport.Net`) property. + - Just like how we distribute `LoggerFactory` throughout the pion project. +* DNS => a simple dictionary (global)? +* Each Net has routing capability (a goroutine) +* Use interface provided net package as much as possible +* Routers are connected in a tree structure (no loop is allowed) + - To simplify routing + - Easy to control / monitor (stats, etc) +* Root router has no NAT (== Internet / WAN) +* Non-root router has a NAT always +* When a Net is instantiated, it will automatically add `lo0` and `eth0` interface, and `lo0` will have one IP address, 127.0.0.1. (this is not used in pion/ice, however) +* When a Net is added to a router, the router automatically assign an IP address for `eth0` interface. + - For simplicity +* User data won't fragment, but optionally drop chunk larger than MTU +* IPv6 is not supported + +### Basic steps for setting up virtual network +1. Create a root router (WAN) +1. Create child routers and add to its parent (forms a tree, don't create a loop!) +1. Add instances of Net to each routers +1. Call Stop(), or Stop(), on the top router, which propagates all other routers + +#### Example: WAN with one endpoint (vnet) +```go +import ( + "net" + + "github.com/pion/transport" + "github.com/pion/transport/vnet" + "github.com/pion/logging" +) + +// Create WAN (a root router). +wan, err := vnet.NewRouter(&RouterConfig{ + CIDR: "0.0.0.0/0", + LoggerFactory: logging.NewDefaultLoggerFactory(), +}) + +// Create a network. +// You can specify static IPs for the instance of Net to use. If not specified, +// router will assign an IP address that is contained in the router's CIDR. +nw := vnet.NewNet(&vnet.NetConfig{ + StaticIPs: []string{"27.1.2.3"}, +}) + +// Add the network to the router. +// The router will assign an IP address to `nw`. +if err = wan.AddNet(nw); err != nil { + // handle error +} + +// Start router. +// This will start internal goroutine to route packets. +// If you set child routers (using AddRouter), the call on the root router +// will start the rest of routers for you. +if err = wan.Start(); err != nil { + // handle error +} + +// +// Your application runs here using `nw`. +// + +// Stop the router. +// This will stop all internal Go routines in the router tree. +// (No need to call Stop() on child routers) +if err = wan.Stop(); err != nil { + // handle error +} +``` + +#### Example of how to pass around the instance of vnet.Net +The instance of vnet.Net wraps a subset of net package to enable operations +on the virtual network. Your project must be able to pass the instance to +all your routines that do network operation with net package. A typical way +is to use a config param to create your instances with the virtual network +instance (`nw` in the above example) like this: + +```go +type AgentConfig struct { + : + Net: transport.Net, +} + +type Agent struct { + : + net: transport.Net, +} + +func NetAgent(config *AgentConfig) *Agent { + if config.Net == nil { + config.Net = vnet.NewNet() + } + + return &Agent { + : + net: config.Net, + } +} +``` + +```go +// a.net is the instance of vnet.Net class +func (a *Agent) listenUDP(...) error { + conn, err := a.net.ListenPacket(udpString, ...) + if err != nil { + return nil, err + } + : +} +``` + +### Compatibility and Support Status + +|`net`
(built-in) |`vnet` |Note | +|:--- |:--- |:--- | +| net.Interfaces() | a.net.Interfaces() | | +| net.InterfaceByName() | a.net.InterfaceByName() | | +| net.ResolveUDPAddr() | a.net.ResolveUDPAddr() | | +| net.ListenPacket() | a.net.ListenPacket() | | +| net.ListenUDP() | a.net.ListenUDP() | ListenPacket() is recommended | +| net.Listen() | a.net.Listen() | TODO) | +| net.ListenTCP() | (not supported) | Listen() would be recommended | +| net.Dial() | a.net.Dial() | | +| net.DialUDP() | a.net.DialUDP() | | +| net.DialTCP() | (not supported) | | +| net.Interface | transport.Interface | | +| net.PacketConn | (use it as-is) | | +| net.UDPConn | transport.UDPConn | | +| net.TCPConn | transport.TCPConn | TODO: Use net.Conn in your code | +| net.Dialer | transport.Dialer | Use a.net.CreateDialer() to create it.
The use of vnet.Dialer is currently experimental. | + +> `a.net` is an instance of Net class, and types are defined under the package name `vnet` + +> Most of other `interface` types in net package can be used as is. + +> Please post a github issue when other types/methods need to be added to vnet/vnet.Net. + +## TODO / Next Step +* Implement TCP (TCPConn, Listen) +* Support of IPv6 +* Write a bunch of examples for building virtual networks. +* Add network impairment features (on Router) + - Introduce latency / jitter + - Packet filtering handler (allow selectively drop packets, etc.) +* Add statistics data retrieval + - Total number of packets forward by each router + - Total number of packet loss + - Total number of connection failure (TCP) + +## References +* [Comparing Simulated Packet Loss and RealWorld Network Congestion](https://www.riverbed.com/document/fpo/WhitePaper-Riverbed-SimulatedPacketLoss.pdf) +* [wireguard-go using GVisor's netstack](https://github.com/WireGuard/wireguard-go/tree/master/tun/netstack) \ No newline at end of file diff --git a/vendor/github.com/pion/transport/v4/vnet/chunk.go b/vendor/github.com/pion/transport/v4/vnet/chunk.go new file mode 100644 index 0000000..d1cb4c8 --- /dev/null +++ b/vendor/github.com/pion/transport/v4/vnet/chunk.go @@ -0,0 +1,303 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package vnet + +import ( + "fmt" + "net" + "strconv" + "strings" + "sync/atomic" + "time" +) + +type tcpFlag uint8 + +const ( + tcpFIN tcpFlag = 0x01 + tcpSYN tcpFlag = 0x02 + tcpRST tcpFlag = 0x04 + tcpPSH tcpFlag = 0x08 + tcpACK tcpFlag = 0x10 +) + +func (f tcpFlag) String() string { + var sa []string + if f&tcpFIN != 0 { + sa = append(sa, "FIN") + } + if f&tcpSYN != 0 { + sa = append(sa, "SYN") + } + if f&tcpRST != 0 { + sa = append(sa, "RST") + } + if f&tcpPSH != 0 { + sa = append(sa, "PSH") + } + if f&tcpACK != 0 { + sa = append(sa, "ACK") + } + + return strings.Join(sa, "-") +} + +// Generate a base36-encoded unique tag +// See: https://play.golang.org/p/0ZaAID1q-HN +var assignChunkTag = func() func() string { //nolint:gochecknoglobals + var tagCtr uint64 + + return func() string { + n := atomic.AddUint64(&tagCtr, 1) + + return strconv.FormatUint(n, 36) + } +}() + +// Chunk represents a packet passed around in the vnet. +type Chunk interface { + setTimestamp() time.Time // used by router + getTimestamp() time.Time // used by router + getSourceIP() net.IP // used by router + getDestinationIP() net.IP // used by router + setSourceAddr(address string) error // used by nat + setDestinationAddr(address string) error // used by nat + + SourceAddr() net.Addr + DestinationAddr() net.Addr + UserData() []byte + Tag() string + Clone() Chunk + Network() string // returns "udp" or "tcp" + String() string +} + +type chunkIP struct { + timestamp time.Time + sourceIP net.IP + destinationIP net.IP + tag string + duplicate bool +} + +func (c *chunkIP) setTimestamp() time.Time { + c.timestamp = time.Now() + + return c.timestamp +} + +func (c *chunkIP) getTimestamp() time.Time { + return c.timestamp +} + +func (c *chunkIP) getDestinationIP() net.IP { + return c.destinationIP +} + +func (c *chunkIP) getSourceIP() net.IP { + return c.sourceIP +} + +func (c *chunkIP) Tag() string { + return c.tag +} + +func (c *chunkIP) markDuplicate() { + c.duplicate = true +} + +func (c *chunkIP) isDuplicate() bool { + return c.duplicate +} + +type chunkUDP struct { + chunkIP + sourcePort int + destinationPort int + userData []byte +} + +func newChunkUDP(srcAddr, dstAddr *net.UDPAddr) *chunkUDP { + return &chunkUDP{ + chunkIP: chunkIP{ + sourceIP: srcAddr.IP, + destinationIP: dstAddr.IP, + tag: assignChunkTag(), + }, + sourcePort: srcAddr.Port, + destinationPort: dstAddr.Port, + } +} + +func (c *chunkUDP) SourceAddr() net.Addr { + return &net.UDPAddr{ + IP: c.sourceIP, + Port: c.sourcePort, + } +} + +func (c *chunkUDP) DestinationAddr() net.Addr { + return &net.UDPAddr{ + IP: c.destinationIP, + Port: c.destinationPort, + } +} + +func (c *chunkUDP) UserData() []byte { + return c.userData +} + +func (c *chunkUDP) Clone() Chunk { + var userData []byte + if c.userData != nil { + userData = make([]byte, len(c.userData)) + copy(userData, c.userData) + } + + return &chunkUDP{ + chunkIP: chunkIP{ + timestamp: c.timestamp, + sourceIP: c.sourceIP, + destinationIP: c.destinationIP, + tag: c.tag, + }, + sourcePort: c.sourcePort, + destinationPort: c.destinationPort, + userData: userData, + } +} + +func (c *chunkUDP) Network() string { + return udp +} + +func (c *chunkUDP) String() string { + src := c.SourceAddr() + dst := c.DestinationAddr() + + return fmt.Sprintf("%s chunk %s %s => %s", + src.Network(), + c.tag, + src.String(), + dst.String(), + ) +} + +func (c *chunkUDP) setSourceAddr(address string) error { + addr, err := net.ResolveUDPAddr(udp, address) + if err != nil { + return err + } + c.sourceIP = addr.IP + c.sourcePort = addr.Port + + return nil +} + +func (c *chunkUDP) setDestinationAddr(address string) error { + addr, err := net.ResolveUDPAddr(udp, address) + if err != nil { + return err + } + c.destinationIP = addr.IP + c.destinationPort = addr.Port + + return nil +} + +type chunkTCP struct { + chunkIP + sourcePort int + destinationPort int + flags tcpFlag // control bits + userData []byte // only with PSH flag + // seq uint32 // always starts with 0 + // ack uint32 // always starts with 0 +} + +func newChunkTCP(srcAddr, dstAddr *net.TCPAddr, flags tcpFlag) *chunkTCP { + return &chunkTCP{ + chunkIP: chunkIP{ + sourceIP: srcAddr.IP, + destinationIP: dstAddr.IP, + tag: assignChunkTag(), + }, + sourcePort: srcAddr.Port, + destinationPort: dstAddr.Port, + flags: flags, + } +} + +func (c *chunkTCP) SourceAddr() net.Addr { + return &net.TCPAddr{ + IP: c.sourceIP, + Port: c.sourcePort, + } +} + +func (c *chunkTCP) DestinationAddr() net.Addr { + return &net.TCPAddr{ + IP: c.destinationIP, + Port: c.destinationPort, + } +} + +func (c *chunkTCP) UserData() []byte { + return c.userData +} + +func (c *chunkTCP) Clone() Chunk { + userData := make([]byte, len(c.userData)) + copy(userData, c.userData) + + return &chunkTCP{ + chunkIP: chunkIP{ + timestamp: c.timestamp, + sourceIP: c.sourceIP, + destinationIP: c.destinationIP, + }, + sourcePort: c.sourcePort, + destinationPort: c.destinationPort, + userData: userData, + } +} + +func (c *chunkTCP) Network() string { + return "tcp" +} + +func (c *chunkTCP) String() string { + src := c.SourceAddr() + dst := c.DestinationAddr() + + return fmt.Sprintf("%s %s chunk %s %s => %s", + src.Network(), + c.flags.String(), + c.tag, + src.String(), + dst.String(), + ) +} + +func (c *chunkTCP) setSourceAddr(address string) error { + addr, err := net.ResolveTCPAddr("tcp", address) + if err != nil { + return err + } + c.sourceIP = addr.IP + c.sourcePort = addr.Port + + return nil +} + +func (c *chunkTCP) setDestinationAddr(address string) error { + addr, err := net.ResolveTCPAddr("tcp", address) + if err != nil { + return err + } + c.destinationIP = addr.IP + c.destinationPort = addr.Port + + return nil +} diff --git a/vendor/github.com/pion/transport/v4/vnet/chunk_queue.go b/vendor/github.com/pion/transport/v4/vnet/chunk_queue.go new file mode 100644 index 0000000..f9cd811 --- /dev/null +++ b/vendor/github.com/pion/transport/v4/vnet/chunk_queue.go @@ -0,0 +1,69 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package vnet + +import ( + "sync" +) + +type chunkQueue struct { + chunks []Chunk + maxSize int // 0 or negative value: unlimited + maxBytes int // 0 or negative value: unlimited + currentBytes int + mutex sync.RWMutex +} + +func newChunkQueue(maxSize int, maxBytes int) *chunkQueue { + return &chunkQueue{ + chunks: []Chunk{}, + maxSize: maxSize, + maxBytes: maxBytes, + currentBytes: 0, + mutex: sync.RWMutex{}, + } +} + +func (q *chunkQueue) push(c Chunk) bool { + q.mutex.Lock() + defer q.mutex.Unlock() + + if q.maxSize > 0 && len(q.chunks) >= q.maxSize { + return false // dropped + } + if q.maxBytes > 0 && q.currentBytes+len(c.UserData()) >= q.maxBytes { + return false + } + + q.currentBytes += len(c.UserData()) + q.chunks = append(q.chunks, c) + + return true +} + +func (q *chunkQueue) pop() (Chunk, bool) { + q.mutex.Lock() + defer q.mutex.Unlock() + + if len(q.chunks) == 0 { + return nil, false + } + + c := q.chunks[0] + q.chunks = q.chunks[1:] + q.currentBytes -= len(c.UserData()) + + return c, true +} + +func (q *chunkQueue) peek() Chunk { + q.mutex.RLock() + defer q.mutex.RUnlock() + + if len(q.chunks) == 0 { + return nil + } + + return q.chunks[0] +} diff --git a/vendor/github.com/pion/transport/v4/vnet/conn.go b/vendor/github.com/pion/transport/v4/vnet/conn.go new file mode 100644 index 0000000..4684efb --- /dev/null +++ b/vendor/github.com/pion/transport/v4/vnet/conn.go @@ -0,0 +1,302 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package vnet + +import ( + "errors" + "fmt" + "io" + "math" + "net" + "sync" + "time" + + "github.com/pion/transport/v4" +) + +const ( + maxReadQueueSize = 1024 +) + +var ( + errObsCannotBeNil = errors.New("obs cannot be nil") + errUseClosedNetworkConn = errors.New("use of closed network connection") + errAddrNotUDPAddr = errors.New("addr is not a net.UDPAddr") + errLocAddr = errors.New("something went wrong with locAddr") + errAlreadyClosed = errors.New("already closed") + errNoRemAddr = errors.New("no remAddr defined") +) + +// vNet implements this. +type connObserver interface { + write(c Chunk) error + onClosed(addr net.Addr) + determineSourceIP(locIP, dstIP net.IP) net.IP +} + +// UDPConn is the implementation of the Conn and PacketConn interfaces for UDP network connections. +// compatible with net.PacketConn and net.Conn. +type UDPConn struct { + locAddr *net.UDPAddr // read-only + remAddr *net.UDPAddr // read-only + obs connObserver // read-only + readCh chan Chunk // thread-safe + closed bool // requires mutex + mu sync.Mutex // to mutex closed flag + readTimer *time.Timer // thread-safe +} + +var _ transport.UDPConn = &UDPConn{} + +func newUDPConn(locAddr, remAddr *net.UDPAddr, obs connObserver) (*UDPConn, error) { + if obs == nil { + return nil, errObsCannotBeNil + } + + return &UDPConn{ + locAddr: locAddr, + remAddr: remAddr, + obs: obs, + readCh: make(chan Chunk, maxReadQueueSize), + readTimer: time.NewTimer(time.Duration(math.MaxInt64)), + }, nil +} + +// Close closes the connection. +// Any blocked ReadFrom or WriteTo operations will be unblocked and return errors. +func (c *UDPConn) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + + if c.closed { + return errAlreadyClosed + } + c.closed = true + close(c.readCh) + + c.obs.onClosed(c.locAddr) + + return nil +} + +// LocalAddr returns the local network address. +func (c *UDPConn) LocalAddr() net.Addr { + return c.locAddr +} + +// RemoteAddr returns the remote network address. +func (c *UDPConn) RemoteAddr() net.Addr { + return c.remAddr +} + +// SetDeadline sets the read and write deadlines associated +// with the connection. It is equivalent to calling both +// SetReadDeadline and SetWriteDeadline. +// +// A deadline is an absolute time after which I/O operations +// fail with a timeout (see type Error) instead of +// blocking. The deadline applies to all future and pending +// I/O, not just the immediately following call to ReadFrom or +// WriteTo. After a deadline has been exceeded, the connection +// can be refreshed by setting a deadline in the future. +// +// An idle timeout can be implemented by repeatedly extending +// the deadline after successful ReadFrom or WriteTo calls. +// +// A zero value for t means I/O operations will not time out. +func (c *UDPConn) SetDeadline(t time.Time) error { + return c.SetReadDeadline(t) +} + +// SetReadDeadline sets the deadline for future ReadFrom calls +// and any currently-blocked ReadFrom call. +// A zero value for t means ReadFrom will not time out. +func (c *UDPConn) SetReadDeadline(t time.Time) error { + var d time.Duration + if t.IsZero() { + d = time.Duration(math.MaxInt64) + } else { + d = time.Until(t) + } + c.readTimer.Reset(d) + + return nil +} + +// SetWriteDeadline sets the deadline for future WriteTo calls +// and any currently-blocked WriteTo call. +// Even if write times out, it may return n > 0, indicating that +// some of the data was successfully written. +// A zero value for t means WriteTo will not time out. +func (c *UDPConn) SetWriteDeadline(time.Time) error { + // Write never blocks. + return nil +} + +// Read reads data from the connection. +// Read can be made to time out and return an Error with Timeout() == true +// after a fixed time limit; see SetDeadline and SetReadDeadline. +func (c *UDPConn) Read(b []byte) (int, error) { + n, _, err := c.ReadFrom(b) + + return n, err +} + +// ReadFrom reads a packet from the connection, +// copying the payload into p. It returns the number of +// bytes copied into p and the return address that +// was on the packet. +// It returns the number of bytes read (0 <= n <= len(p)) +// and any error encountered. Callers should always process +// the n > 0 bytes returned before considering the error err. +// ReadFrom can be made to time out and return +// an Error with Timeout() == true after a fixed time limit; +// see SetDeadline and SetReadDeadline. +func (c *UDPConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { +loop: + for { + select { + case chunk, ok := <-c.readCh: + if !ok { + break loop + } + var err error + n := copy(p, chunk.UserData()) + addr := chunk.SourceAddr() + if n < len(chunk.UserData()) { + err = io.ErrShortBuffer + } + + if c.remAddr != nil { + if addr.String() != c.remAddr.String() { + break // discard (shouldn't happen) + } + } + + return n, addr, err + + case <-c.readTimer.C: + return 0, nil, &net.OpError{ + Op: "read", + Net: c.locAddr.Network(), + Addr: c.locAddr, + Err: newTimeoutError("i/o timeout"), + } + } + } + + return 0, nil, &net.OpError{ + Op: "read", + Net: c.locAddr.Network(), + Addr: c.locAddr, + Err: errUseClosedNetworkConn, + } +} + +// ReadFromUDP acts like ReadFrom but returns a UDPAddr. +func (c *UDPConn) ReadFromUDP(b []byte) (int, *net.UDPAddr, error) { + n, addr, err := c.ReadFrom(b) + + udpAddr, ok := addr.(*net.UDPAddr) + if !ok { + return -1, nil, fmt.Errorf("%w: %s", transport.ErrNotUDPAddress, addr) + } + + return n, udpAddr, err +} + +// ReadMsgUDP reads a message from c, copying the payload into b and +// the associated out-of-band data into oob. It returns the number of +// bytes copied into b, the number of bytes copied into oob, the flags +// that were set on the message and the source address of the message. +// +// The packages golang.org/x/net/ipv4 and golang.org/x/net/ipv6 can be +// used to manipulate IP-level socket options in oob. +func (c *UDPConn) ReadMsgUDP([]byte, []byte) (n, oobn, flags int, addr *net.UDPAddr, err error) { + return -1, -1, -1, nil, transport.ErrNotSupported +} + +// Write writes data to the connection. +// Write can be made to time out and return an Error with Timeout() == true +// after a fixed time limit; see SetDeadline and SetWriteDeadline. +func (c *UDPConn) Write(b []byte) (int, error) { + if c.remAddr == nil { + return 0, errNoRemAddr + } + + return c.WriteTo(b, c.remAddr) +} + +// WriteTo writes a packet with payload to addr. +// WriteTo can be made to time out and return +// an Error with Timeout() == true after a fixed time limit; +// see SetDeadline and SetWriteDeadline. +// On packet-oriented connections, write timeouts are rare. +func (c *UDPConn) WriteTo(payload []byte, addr net.Addr) (n int, err error) { + dstAddr, ok := addr.(*net.UDPAddr) + if !ok { + return 0, errAddrNotUDPAddr + } + + srcIP := c.obs.determineSourceIP(c.locAddr.IP, dstAddr.IP) + if srcIP == nil { + return 0, errLocAddr + } + srcAddr := &net.UDPAddr{ + IP: srcIP, + Port: c.locAddr.Port, + } + + chunk := newChunkUDP(srcAddr, dstAddr) + chunk.userData = make([]byte, len(payload)) + copy(chunk.userData, payload) + if err := c.obs.write(chunk); err != nil { + return 0, err + } + + return len(payload), nil +} + +// WriteToUDP acts like WriteTo but takes a UDPAddr. +func (c *UDPConn) WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) { + return c.WriteTo(b, addr) +} + +// WriteMsgUDP writes a message to addr via c if c isn't connected, or +// to c's remote address if c is connected (in which case addr must be +// nil). The payload is copied from b and the associated out-of-band +// data is copied from oob. It returns the number of payload and +// out-of-band bytes written. +// +// The packages golang.org/x/net/ipv4 and golang.org/x/net/ipv6 can be +// used to manipulate IP-level socket options in oob. +func (c *UDPConn) WriteMsgUDP([]byte, []byte, *net.UDPAddr) (n, oobn int, err error) { + return -1, -1, transport.ErrNotSupported +} + +// SetReadBuffer sets the size of the operating system's +// receive buffer associated with the connection. +func (c *UDPConn) SetReadBuffer(int) error { + return transport.ErrNotSupported +} + +// SetWriteBuffer sets the size of the operating system's +// transmit buffer associated with the connection. +func (c *UDPConn) SetWriteBuffer(int) error { + return transport.ErrNotSupported +} + +func (c *UDPConn) onInboundChunk(chunk Chunk) { + c.mu.Lock() + defer c.mu.Unlock() + + if c.closed { + return + } + + select { + case c.readCh <- chunk: + default: + } +} diff --git a/vendor/github.com/pion/transport/v4/vnet/conn_map.go b/vendor/github.com/pion/transport/v4/vnet/conn_map.go new file mode 100644 index 0000000..5e9a1ed --- /dev/null +++ b/vendor/github.com/pion/transport/v4/vnet/conn_map.go @@ -0,0 +1,143 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package vnet + +import ( + "errors" + "net" + "sync" +) + +var ( + errAddressAlreadyInUse = errors.New("address already in use") + errNoSuchUDPConn = errors.New("no such UDPConn") + errCannotRemoveUnspecifiedIP = errors.New("cannot remove unspecified IP by the specified IP") +) + +type udpConnMap struct { + portMap map[int][]*UDPConn + mutex sync.RWMutex +} + +func newUDPConnMap() *udpConnMap { + return &udpConnMap{ + portMap: map[int][]*UDPConn{}, + } +} + +func (m *udpConnMap) insert(conn *UDPConn) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + udpAddr := conn.LocalAddr().(*net.UDPAddr) //nolint:forcetypeassert + + // check if the port has a listener + conns, ok := m.portMap[udpAddr.Port] + if ok { + if udpAddr.IP.IsUnspecified() { + return errAddressAlreadyInUse + } + + for _, conn := range conns { + laddr := conn.LocalAddr().(*net.UDPAddr) //nolint:forcetypeassert + if laddr.IP.IsUnspecified() || laddr.IP.Equal(udpAddr.IP) { + return errAddressAlreadyInUse + } + } + + conns = append(conns, conn) + } else { + conns = []*UDPConn{conn} + } + + m.portMap[udpAddr.Port] = conns + + return nil +} + +func (m *udpConnMap) find(addr net.Addr) (*UDPConn, bool) { + m.mutex.Lock() // could be RLock, but we have delete() op + defer m.mutex.Unlock() + + udpAddr := addr.(*net.UDPAddr) //nolint:forcetypeassert + + if conns, ok := m.portMap[udpAddr.Port]; ok { + if udpAddr.IP.IsUnspecified() { + // pick the first one appears in the iteration + if len(conns) == 0 { + // This can't happen! + delete(m.portMap, udpAddr.Port) + + return nil, false + } + + return conns[0], true + } + + for _, conn := range conns { + laddr := conn.LocalAddr().(*net.UDPAddr) //nolint:forcetypeassert + if laddr.IP.IsUnspecified() || laddr.IP.Equal(udpAddr.IP) { + return conn, ok + } + } + } + + return nil, false +} + +func (m *udpConnMap) delete(addr net.Addr) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + udpAddr := addr.(*net.UDPAddr) //nolint:forcetypeassert + + conns, ok := m.portMap[udpAddr.Port] + if !ok { + return errNoSuchUDPConn + } + + if udpAddr.IP.IsUnspecified() { + // remove all from this port + delete(m.portMap, udpAddr.Port) + + return nil + } + + newConns := []*UDPConn{} + + for _, conn := range conns { + laddr := conn.LocalAddr().(*net.UDPAddr) //nolint:forcetypeassert + if laddr.IP.IsUnspecified() { + // This can't happen! + return errCannotRemoveUnspecifiedIP + } + + if laddr.IP.Equal(udpAddr.IP) { + continue + } + + newConns = append(newConns, conn) + } + + if len(newConns) == 0 { + delete(m.portMap, udpAddr.Port) + } else { + m.portMap[udpAddr.Port] = newConns + } + + return nil +} + +// size returns the number of UDPConns (UDP listeners). +func (m *udpConnMap) size() int { + m.mutex.RLock() + defer m.mutex.RUnlock() + + n := 0 + for _, conns := range m.portMap { + n += len(conns) + } + + return n +} diff --git a/vendor/github.com/pion/transport/v4/vnet/delay_filter.go b/vendor/github.com/pion/transport/v4/vnet/delay_filter.go new file mode 100644 index 0000000..89f1830 --- /dev/null +++ b/vendor/github.com/pion/transport/v4/vnet/delay_filter.go @@ -0,0 +1,186 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package vnet + +import ( + "context" + "errors" + "sync" + "sync/atomic" + "time" +) + +// ErrInvalidDelay indicates an invalid (negative) delay duration was provided. +var ErrInvalidDelay = errors.New("delay must be non-negative") + +type delayFilterConfig struct { + delay time.Duration +} + +// DelayFilterOption configures DelayFilter creation. +type DelayFilterOption func(*delayFilterConfig) error + +// WithDelay sets the initial delay applied by the filter. +func WithDelay(delay time.Duration) DelayFilterOption { + return func(cfg *delayFilterConfig) error { + if delay < 0 { + return ErrInvalidDelay + } + cfg.delay = delay + + return nil + } +} + +// DelayFilter delays inbound packets by the given delay. Automatically starts +// processing when created and runs until Close() is called. +type DelayFilter struct { + NIC + delay atomic.Int64 // atomic field - stores time.Duration as int64 + push chan struct{} + queue *chunkQueue + done chan struct{} + wg sync.WaitGroup +} + +type timedChunk struct { + Chunk + deadline time.Time +} + +// NewDelayFilter creates and starts a new DelayFilter with the given NIC and options. +func NewDelayFilter(nic NIC, opts ...DelayFilterOption) (*DelayFilter, error) { + cfg := delayFilterConfig{} + + for _, opt := range opts { + if opt == nil { + continue + } + if err := opt(&cfg); err != nil { + return nil, err + } + } + + delayFilter := &DelayFilter{ + NIC: nic, + push: make(chan struct{}), + queue: newChunkQueue(0, 0), + done: make(chan struct{}), + } + + delayFilter.delay.Store(int64(cfg.delay)) + + // Start processing automatically + delayFilter.wg.Add(1) + go delayFilter.run() + + return delayFilter, nil +} + +// SetDelay atomically updates the delay. +func (f *DelayFilter) SetDelay(newDelay time.Duration) { + f.delay.Store(int64(newDelay)) +} + +func (f *DelayFilter) getDelay() time.Duration { + return time.Duration(f.delay.Load()) +} + +func (f *DelayFilter) onInboundChunk(c Chunk) { + f.queue.push(timedChunk{ + Chunk: c, + deadline: time.Now().Add(f.getDelay()), + }) + f.push <- struct{}{} +} + +// run processes the delayed packets queue until Close() is called. +func (f *DelayFilter) run() { + defer f.wg.Done() + + timer := time.NewTimer(0) + defer timer.Stop() + + for { + select { + case <-f.done: + f.drainRemainingPackets() + + return + + case <-f.push: + f.updateTimerForNextPacket(timer) + + case now := <-timer.C: + f.processReadyPackets(now) + f.scheduleNextPacketTimer(timer) + } + } +} + +// drainRemainingPackets sends all remaining packets immediately during shutdown. +func (f *DelayFilter) drainRemainingPackets() { + for { + next, ok := f.queue.pop() + if !ok { + break + } + if chunk, ok := next.(timedChunk); ok { + f.NIC.onInboundChunk(chunk.Chunk) + } + } +} + +// updateTimerForNextPacket updates the timer when a new packet arrives. +func (f *DelayFilter) updateTimerForNextPacket(timer *time.Timer) { + next := f.queue.peek() + if next != nil { + if chunk, ok := next.(timedChunk); ok { + if !timer.Stop() { + <-timer.C + } + timer.Reset(time.Until(chunk.deadline)) + } + } +} + +// processReadyPackets processes all packets that are ready to be sent. +func (f *DelayFilter) processReadyPackets(now time.Time) { + for { + next := f.queue.peek() + if next == nil { + break + } + if chunk, ok := next.(timedChunk); ok && !chunk.deadline.After(now) { + _, _ = f.queue.pop() // We already have the item from peek() + f.NIC.onInboundChunk(chunk.Chunk) + } else { + break + } + } +} + +// scheduleNextPacketTimer schedules the timer for the next packet to be processed. +func (f *DelayFilter) scheduleNextPacketTimer(timer *time.Timer) { + next := f.queue.peek() + if next == nil { + timer.Reset(time.Minute) // Long timeout when queue is empty + } else if chunk, ok := next.(timedChunk); ok { + timer.Reset(time.Until(chunk.deadline)) + } +} + +// Run is provided for backward compatibility. The DelayFilter now starts +// automatically when created, so this method is a no-op. +func (f *DelayFilter) Run(_ context.Context) { + // DelayFilter now starts automatically in NewDelayFilter, so this is a no-op +} + +// Close stops the DelayFilter and waits for graceful shutdown. +func (f *DelayFilter) Close() error { + close(f.done) + f.wg.Wait() + + return nil +} diff --git a/vendor/github.com/pion/transport/v4/vnet/duplication_filter.go b/vendor/github.com/pion/transport/v4/vnet/duplication_filter.go new file mode 100644 index 0000000..892d2c8 --- /dev/null +++ b/vendor/github.com/pion/transport/v4/vnet/duplication_filter.go @@ -0,0 +1,373 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package vnet + +import ( + "errors" + "math" + "math/rand" + "sync" + "time" +) + +const ( + defaultDuplicationBurstMultiplier = 10.0 + duplicationBucketSize = time.Millisecond +) + +var ( + // errInvalidDuplicationProbability indicates the configured duplication probability is outside [0, 1]. + errInvalidDuplicationProbability = errors.New("duplication probability must be between 0 and 1 inclusive") + // errInvalidDuplicationBurstProbability indicates the configured burst probability is outside [0, 1]. + errInvalidDuplicationBurstProbability = errors.New("duplication burst probability must be between 0 and 1 inclusive") + // errInvalidDuplicationBurstMultiplier indicates the configured burst multiplier is invalid. + errInvalidDuplicationBurstMultiplier = errors.New("duplication burst multiplier must be at least 1") + // errInvalidDuplicationDelayRange indicates the configured delay range is invalid. + errInvalidDuplicationDelayRange = errors.New("duplication delay range must satisfy 0 <= min <= max") + // errInvalidDuplicationBurstDuration indicates the burst duration is invalid. + errInvalidDuplicationBurstDuration = errors.New("duplication burst duration must be non-negative") + // errInvalidDuplicationRouter indicates a nil router was provided when constructing the filter. + errInvalidDuplicationRouter = errors.New("duplication filter requires a non-nil router reference") +) + +type duplicationConfig struct { + prob float64 + burstStartProb float64 + burstDuration time.Duration + burstMultiplier float64 + minExtraDelay time.Duration + maxExtraDelay time.Duration + seed *int64 +} + +// DuplicationOption configures a DuplicationFilter. +type DuplicationOption func(*duplicationConfig) error + +// WithDuplicationProbability sets the base duplication probability. +func WithDuplicationProbability(prob float64) DuplicationOption { + return func(cfg *duplicationConfig) error { + if prob < 0 || prob > 1 { + return errInvalidDuplicationProbability + } + + cfg.prob = prob + + return nil + } +} + +// WithDuplicationBurstProbability sets the probability that a burst window starts. Bursts are +// triggered probabilistically per packet when outside a burst window, creating time-based +// windows of elevated duplication. For deterministic burst cadences, seed the filter and +// control the burst timing externally. +func WithDuplicationBurstProbability(prob float64) DuplicationOption { + return func(cfg *duplicationConfig) error { + if prob < 0 || prob > 1 { + return errInvalidDuplicationBurstProbability + } + + cfg.burstStartProb = prob + + return nil + } +} + +// WithDuplicationBurstDuration configures how long burst mode stays active once triggered. +func WithDuplicationBurstDuration(duration time.Duration) DuplicationOption { + return func(cfg *duplicationConfig) error { + if duration < 0 { + return errInvalidDuplicationBurstDuration + } + + cfg.burstDuration = duration + + return nil + } +} + +// WithDuplicationBurstMultiplier adjusts how aggressively probability increases during a burst window. +func WithDuplicationBurstMultiplier(multiplier float64) DuplicationOption { + return func(cfg *duplicationConfig) error { + if multiplier < 1 { + return errInvalidDuplicationBurstMultiplier + } + + cfg.burstMultiplier = multiplier + + return nil + } +} + +// WithDuplicationExtraDelay sets the range for additional delay applied to duplicates. The +// selected delay is uniform across the inclusive range [minDelay, maxDelay]. +func WithDuplicationExtraDelay(minDelay, maxDelay time.Duration) DuplicationOption { + return func(cfg *duplicationConfig) error { + if minDelay < 0 || maxDelay < 0 || maxDelay < minDelay { + return errInvalidDuplicationDelayRange + } + + cfg.minExtraDelay = minDelay + cfg.maxExtraDelay = maxDelay + + return nil + } +} + +// WithDuplicationImmediate is a convenience that configures duplicates to be delivered +// without any extra delay (equivalent to WithDuplicationExtraDelay(0, 0)). +func WithDuplicationImmediate() DuplicationOption { + return WithDuplicationExtraDelay(0, 0) +} + +// WithDuplicationSeed sets the random seed used by the duplication filter. +func WithDuplicationSeed(seed int64) DuplicationOption { + return func(cfg *duplicationConfig) error { + cfg.seed = new(int64) + *cfg.seed = seed + + return nil + } +} + +// DuplicationFilter duplicates chunks that traverse a router according to the supplied configuration. +// When chaining with other filters, register duplication ahead of loss or latency filters to better +// emulate how duplicates typically occur before drop or jitter on real networks. +// +// Note: Call Close() to cancel pending delayed duplicates and prevent goroutine leaks when +// shutting down. Routers do not automatically close registered duplication filters so applications +// should wire Close() into their lifecycle (e.g., along with Router.Stop()). The filter is safe +// for concurrent use by multiple goroutines. +// +// Note: Duplicates re-enter the router and may be reordered relative to the original if other +// filters add jitter. Configure minExtraDelay appropriately to maintain ordering guarantees. +type DuplicationFilter struct { + router *Router + cfg duplicationConfig + mu sync.Mutex + rng *rand.Rand + burstEnd time.Time + now func() time.Time + timers map[*time.Timer]struct{} + closed bool + // bucketed scheduling to reduce timer churn + buckets map[int64]*dupBucket // key: fireAt in UnixNano aligned to duplicationBucketSize +} + +type dupBucket struct { + timer *time.Timer + chunks []Chunk +} + +// NewDuplicationFilter constructs a new DuplicationFilter bound to the provided router. +func NewDuplicationFilter(router *Router, opts ...DuplicationOption) (*DuplicationFilter, error) { + if router == nil { + return nil, errInvalidDuplicationRouter + } + + cfg := duplicationConfig{burstMultiplier: defaultDuplicationBurstMultiplier} + + for _, opt := range opts { + if opt == nil { + continue + } + if err := opt(&cfg); err != nil { + return nil, err + } + } + + if err := validateDuplicationConfig(&cfg); err != nil { + return nil, err + } + + rng := newRNG(cfg.seed) + + return &DuplicationFilter{ + router: router, + cfg: cfg, + rng: rng, + now: time.Now, + timers: make(map[*time.Timer]struct{}), + buckets: make(map[int64]*dupBucket), + }, nil +} + +// ChunkFilter returns a ChunkFilter that can be registered with Router.AddChunkFilter. +func (f *DuplicationFilter) ChunkFilter() ChunkFilter { + return func(c Chunk) bool { + if chunkIsDuplicate(c) { + return true + } + + delay, shouldDup := f.shouldDuplicate() + if shouldDup { + clone := c.Clone() + f.scheduleDuplicate(clone, delay) + } + + return true + } +} + +func (f *DuplicationFilter) shouldDuplicate() (time.Duration, bool) { + f.mu.Lock() + defer f.mu.Unlock() + + if f.closed { + return 0, false + } + + now := f.now() + probability := f.cfg.prob + + if f.cfg.burstDuration > 0 && f.cfg.burstStartProb > 0 { + if now.After(f.burstEnd) { + if f.rng.Float64() < f.cfg.burstStartProb { + f.burstEnd = now.Add(f.cfg.burstDuration) + } + } + + if now.Before(f.burstEnd) { + probability = math.Min(1.0, probability*f.cfg.burstMultiplier) + } + } + + if f.rng.Float64() >= probability { + return 0, false + } + + // compute delay: uniform distribution over [min, max]. + delay := f.cfg.minExtraDelay + if f.cfg.maxExtraDelay > f.cfg.minExtraDelay { + extra := f.cfg.maxExtraDelay - f.cfg.minExtraDelay + delay += time.Duration(f.rng.Int63n(int64(extra) + 1)) + } + + return delay, true +} + +func (f *DuplicationFilter) scheduleDuplicate(dup Chunk, delay time.Duration) { + markChunkDuplicate(dup) + + f.mu.Lock() + if f.closed { + f.mu.Unlock() + + return + } + + // bucketed scheduling, we group duplicates into fixed windows. + // this is to avoid creating a timer for each duplication. + now := f.now() + deadline := now.Add(delay) + bucketN := int64(duplicationBucketSize) + deadlineN := deadline.UnixNano() + // we round up to the next bucket boundary to avoid early delivery + fireAtN := ((deadlineN + bucketN - 1) / bucketN) * bucketN + + bucket, ok := f.buckets[fireAtN] + if !ok { + fireAt := time.Unix(0, fireAtN) + wait := max(fireAt.Sub(now), 0) + bucket = &dupBucket{} + bucket.timer = time.AfterFunc(wait, func() { + f.onBucketFired(fireAtN) + }) + f.buckets[fireAtN] = bucket + f.timers[bucket.timer] = struct{}{} + } + + bucket.chunks = append(bucket.chunks, dup) + f.mu.Unlock() +} + +func (f *DuplicationFilter) onBucketFired(key int64) { + f.mu.Lock() + if f.closed { + if bucket, ok := f.buckets[key]; ok { + delete(f.timers, bucket.timer) + delete(f.buckets, key) + } + f.mu.Unlock() + + return + } + + bucket, ok := f.buckets[key] + if ok { + delete(f.timers, bucket.timer) + delete(f.buckets, key) + } + chunks := bucket.chunks + f.mu.Unlock() + + for i := 0; i < len(chunks); i++ { + f.router.push(chunks[i]) + } +} + +// Close cancels all pending duplicate deliveries and prevents future duplications. +func (f *DuplicationFilter) Close() error { + f.mu.Lock() + if f.closed { + f.mu.Unlock() + + return nil + } + + f.closed = true + timers := make([]*time.Timer, 0, len(f.timers)) + for timer := range f.timers { + timers = append(timers, timer) + } + f.mu.Unlock() + + for _, timer := range timers { + timer.Stop() + } + + return nil +} + +func validateDuplicationConfig(cfg *duplicationConfig) error { + if cfg.prob < 0 || cfg.prob > 1 { + return errInvalidDuplicationProbability + } + if cfg.burstStartProb < 0 || cfg.burstStartProb > 1 { + return errInvalidDuplicationBurstProbability + } + if cfg.burstMultiplier < 1 { + return errInvalidDuplicationBurstMultiplier + } + if cfg.burstDuration < 0 { + return errInvalidDuplicationBurstDuration + } + if cfg.minExtraDelay < 0 || cfg.maxExtraDelay < 0 || cfg.maxExtraDelay < cfg.minExtraDelay { + return errInvalidDuplicationDelayRange + } + + return nil +} + +func chunkIsDuplicate(c Chunk) bool { + type duplicateChecker interface { + isDuplicate() bool + } + + // a small cheat for test 100% test cov :) + if checker, ok := c.(duplicateChecker); ok && checker.isDuplicate() { + return true + } + + return false +} + +func markChunkDuplicate(c Chunk) { + type duplicateMarker interface { + markDuplicate() + } + + if marker, ok := c.(duplicateMarker); ok { + marker.markDuplicate() + } +} diff --git a/vendor/github.com/pion/transport/v4/vnet/errors.go b/vendor/github.com/pion/transport/v4/vnet/errors.go new file mode 100644 index 0000000..22c7c2d --- /dev/null +++ b/vendor/github.com/pion/transport/v4/vnet/errors.go @@ -0,0 +1,22 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package vnet + +type timeoutError struct { + msg string +} + +func newTimeoutError(msg string) error { + return &timeoutError{ + msg: msg, + } +} + +func (e *timeoutError) Error() string { + return e.msg +} + +func (e *timeoutError) Timeout() bool { + return true +} diff --git a/vendor/github.com/pion/transport/v4/vnet/loss_filter.go b/vendor/github.com/pion/transport/v4/vnet/loss_filter.go new file mode 100644 index 0000000..d702529 --- /dev/null +++ b/vendor/github.com/pion/transport/v4/vnet/loss_filter.go @@ -0,0 +1,298 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package vnet + +import ( + "errors" + "math/rand" + "sync" + "time" +) + +// Static errors for better error handling. +var ( + ErrInvalidChance = errors.New("chance must be between 0 and 100 inclusive") + ErrInvalidShuffleBlockSize = errors.New("shuffleBlockSize must be greater than 0") +) + +type LossFilterHandler interface { + shouldDrop() bool + setLossRate(chance int, resetImmediately bool) +} + +// LossFilter is a wrapper around NICs, that drops some of the packets passed to +// onInboundChunk. +type LossFilter struct { + NIC + LossFilterHandler +} + +// lossFilterConfig holds the configuration for creating a LossFilter. +type lossFilterConfig struct { + nic NIC + chance int + handler LossFilterHandler + shuffleBlockSize int + seed *int64 +} + +// LossFilterOption represents a configuration option for LossFilter creation. +type LossFilterOption func(cfg *lossFilterConfig) error + +// WithLossHandler sets a custom loss handler for the LossFilter. +// This option takes precedence over WithShuffleLossHandler if both are provided. +func WithLossHandler(handler LossFilterHandler) LossFilterOption { + return func(cfg *lossFilterConfig) error { + cfg.handler = handler + + return nil + } +} + +// WithShuffleLossHandler configures the LossFilter to use deterministic shuffle-based packet loss +// with the specified block size. When set, for every blockSize packets, it guarantees that the +// number of packets dropped equals round(blockSize * chance / 100), where chance is a percentage (0-100). +func WithShuffleLossHandler(blockSize int) LossFilterOption { + return func(cfg *lossFilterConfig) error { + if blockSize < 1 { + return ErrInvalidShuffleBlockSize + } + cfg.shuffleBlockSize = blockSize + + return nil + } +} + +// WithLossSeed sets the random seed used by the loss filter for deterministic behavior. +// When a seed is provided (including seed==0), both random loss and shuffle-based loss will +// produce reproducible results. +// If no seed is provided (nil), the filter uses time-based seeding for non-deterministic behavior. +func WithLossSeed(seed int64) LossFilterOption { + return func(cfg *lossFilterConfig) error { + cfg.seed = new(int64) + *cfg.seed = seed + + return nil + } +} + +// lossHandle drops packets with configurable behavior: random or deterministic shuffle-based. +// When shuffleBlockSize is 0, it uses pure random dropping. +// When shuffleBlockSize > 0, it uses deterministic shuffle-based dropping for better distribution. +type lossHandle struct { + // percentage (0-100) - used in random mode, stored for consistency in shuffle mode + chance int + mutex sync.RWMutex + // seeded random number generator + rng *rand.Rand + + // Shuffle mode fields (only used when shuffleBlockSize > 0) + shuffleBlockSize int + blockIdx int + shuffledBlock []bool + // current number of drops per block (calculated from chance percentage) + currentDrops int + pendingDrops int +} + +// calculateDropsPerBlock calculates the number of packets to drop per block based on percentage chance. +// Uses rounding: (chance * blockSize + 50) / 100. +func calculateDropsPerBlock(chancePercent int, blockSize int) int { + return (chancePercent*blockSize + 50) / 100 +} + +// newRNG creates a new random number generator. If seed is nil, uses time-based seeding. +// A seed of 0 is treated as a valid deterministic seed (not time-based). +func newRNG(seed *int64) *rand.Rand { + if seed == nil { + // nolint:gosec // weak rand is intended + return rand.New(rand.NewSource(time.Now().UnixNano())) + } + // nolint:gosec // weak rand is intended + return rand.New(rand.NewSource(*seed)) +} + +// newRandomLossHandle creates a new lossHandle for random packet dropping. +func newRandomLossHandle(chance int, rng *rand.Rand) *lossHandle { + return &lossHandle{ + chance: chance, + shuffleBlockSize: 0, // 0 means random mode + rng: rng, + } +} + +// newShuffleLossHandle creates a new lossHandle for shuffle-based packet loss. +func newShuffleLossHandle(chance, shuffleBlockSize int, rng *rand.Rand) *lossHandle { + dropsPerBlock := calculateDropsPerBlock(chance, shuffleBlockSize) + handler := &lossHandle{ + chance: chance, + shuffleBlockSize: shuffleBlockSize, + shuffledBlock: make([]bool, shuffleBlockSize), + currentDrops: dropsPerBlock, + pendingDrops: dropsPerBlock, + rng: rng, + } + + for i := 0; i < handler.currentDrops; i++ { + handler.shuffledBlock[i] = true + } + + handler.shuffleBlock() + + return handler +} + +func (r *lossHandle) shouldDrop() bool { + if r.shuffleBlockSize > 0 { + return r.shouldDropShuffle() + } + + r.mutex.Lock() + chance := r.chance + result := r.rng.Intn(100) < chance + r.mutex.Unlock() + + return result +} + +func (r *lossHandle) shouldDropShuffle() bool { + r.mutex.Lock() + defer r.mutex.Unlock() + + if r.blockIdx == len(r.shuffledBlock) { + r.shuffleBlock() + } + + res := r.shuffledBlock[r.blockIdx] + r.blockIdx++ + + return res +} + +func (r *lossHandle) setLossRate(chance int, resetImmediately bool) { + if r.shuffleBlockSize > 0 { + r.setLossRateShuffle(chance, resetImmediately) + } else { + r.mutex.Lock() + defer r.mutex.Unlock() + r.chance = chance + } +} + +func (r *lossHandle) setLossRateShuffle(chance int, resetImmediately bool) { + r.mutex.Lock() + defer r.mutex.Unlock() + + r.chance = chance // store percentage for consistency + r.pendingDrops = calculateDropsPerBlock(chance, r.shuffleBlockSize) + + if resetImmediately { + r.shuffleBlock() + } +} + +// shuffleBlock shuffles the current block using the RNG. +// This method must be called while holding mutex to ensure thread-safe RNG access. +func (r *lossHandle) shuffleBlock() { + // Update shuffled block to match pending drops count + for idx := 0; idx < len(r.shuffledBlock); idx++ { + switch { + case r.pendingDrops == r.currentDrops: + goto shuffleComplete + case r.pendingDrops > r.currentDrops && !r.shuffledBlock[idx]: + r.shuffledBlock[idx] = true + r.currentDrops++ + case r.pendingDrops < r.currentDrops && r.shuffledBlock[idx]: + r.shuffledBlock[idx] = false + r.currentDrops-- + } + } + +shuffleComplete: + r.rng.Shuffle(len(r.shuffledBlock), func(i, j int) { + r.shuffledBlock[i], r.shuffledBlock[j] = r.shuffledBlock[j], r.shuffledBlock[i] + }) + r.blockIdx = 0 +} + +// NewLossFilter creates a new LossFilter that drops every packet with a +// probability of chance/100. You can provide custom options to override the +// default behavior. This follows the Pion options pattern for extensibility. +// +// Option precedence: If WithLossHandler is provided, it takes precedence and any +// WithShuffleLossHandler option will be ignored. +func NewLossFilter(nic NIC, chance int, options ...LossFilterOption) (*LossFilter, error) { + if !validateChance(chance) { + return nil, ErrInvalidChance + } + + // Initialize config with defaults + cfg := &lossFilterConfig{ + nic: nic, + chance: chance, + shuffleBlockSize: 0, // 0 means random mode + } + + for _, option := range options { + if option == nil { + continue + } + if err := option(cfg); err != nil { + return nil, err + } + } + + // Create handler based on config + // Precedence: WithLossHandler > WithShuffleLossHandler > default random handler + var lossHandler LossFilterHandler + + switch { + case cfg.handler != nil: + // Use provided handler (WithLossHandler takes precedence over WithShuffleLossHandler) + cfg.handler.setLossRate(cfg.chance, false) + lossHandler = cfg.handler + case cfg.shuffleBlockSize > 0: + // Create shuffle handler with seed from config if available + lossHandler = newShuffleLossHandle(cfg.chance, cfg.shuffleBlockSize, newRNG(cfg.seed)) + default: + // Random mode - create handler with seed from config if available + lossHandler = newRandomLossHandle(cfg.chance, newRNG(cfg.seed)) + } + + lossFilter := &LossFilter{ + NIC: nic, + LossFilterHandler: lossHandler, + } + + return lossFilter, nil +} + +func (f *LossFilter) onInboundChunk(c Chunk) { + if f.LossFilterHandler.shouldDrop() { + return + } + + f.NIC.onInboundChunk(c) +} + +// SetLossRate sets the loss rate for the loss filter. +// The chance parameter is an integer out of 100. +// The resetImmediately parameter is a boolean that indicates whether to reset the loss rate immediately. +// If resetImmediately is true, the loss rate will be reset immediately. +// If resetImmediately is false, the loss rate will be reset after the next shuffle for shuffle-based handlers. +// Note that for random loss handlers (when shuffleBlockSize is 0), the loss rate will be reset immediately +// regardless of the resetImmediately parameter. +func (f *LossFilter) SetLossRate(chance int, resetImmediately bool) error { + if !validateChance(chance) { + return ErrInvalidChance + } + + f.LossFilterHandler.setLossRate(chance, resetImmediately) + + return nil +} + +func validateChance(chance int) bool { + return chance >= 0 && chance <= 100 +} diff --git a/vendor/github.com/pion/transport/v4/vnet/nat.go b/vendor/github.com/pion/transport/v4/vnet/nat.go new file mode 100644 index 0000000..f4722af --- /dev/null +++ b/vendor/github.com/pion/transport/v4/vnet/nat.go @@ -0,0 +1,346 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package vnet + +import ( + "errors" + "fmt" + "net" + "sync" + "time" + + "github.com/pion/logging" +) + +var ( + errNATRequriesMapping = errors.New("1:1 NAT requires more than one mapping") + errMismatchLengthIP = errors.New("length mismtach between mappedIPs and localIPs") + errNonUDPTranslationNotSupported = errors.New("non-udp translation is not supported yet") + errNoAssociatedLocalAddress = errors.New("no associated local address") + errNoNATBindingFound = errors.New("no NAT binding found") + errHasNoPermission = errors.New("has no permission") +) + +// EndpointDependencyType defines a type of behavioral dependendency on the +// remote endpoint's IP address or port number. This is used for the two +// kinds of behaviors: +// - Port mapping behavior +// - Filtering behavior +// +// See: https://tools.ietf.org/html/rfc4787 +type EndpointDependencyType uint8 + +const ( + // EndpointIndependent means the behavior is independent of the endpoint's address or port. + EndpointIndependent EndpointDependencyType = iota + // EndpointAddrDependent means the behavior is dependent on the endpoint's address. + EndpointAddrDependent + // EndpointAddrPortDependent means the behavior is dependent on the endpoint's address and port. + EndpointAddrPortDependent +) + +// NATMode defines basic behavior of the NAT. +type NATMode uint8 + +const ( + // NATModeNormal means the NAT behaves as a standard NAPT (RFC 2663). + NATModeNormal NATMode = iota + // NATModeNAT1To1 exhibits 1:1 DNAT where the external IP address is statically mapped to + // a specific local IP address with port number is preserved always between them. + // When this mode is selected, MappingBehavior, FilteringBehavior, PortPreservation and + // MappingLifeTime of NATType are ignored. + NATModeNAT1To1 +) + +const ( + defaultNATMappingLifeTime = 30 * time.Second +) + +// NATType has a set of parameters that define the behavior of NAT. +type NATType struct { + Mode NATMode + MappingBehavior EndpointDependencyType + FilteringBehavior EndpointDependencyType + Hairpinning bool // Not implemented yet + PortPreservation bool // Not implemented yet + MappingLifeTime time.Duration +} + +type natConfig struct { + name string + natType NATType + mappedIPs []net.IP // mapped IPv4 + localIPs []net.IP // local IPv4, required only when the mode is NATModeNAT1To1 + loggerFactory logging.LoggerFactory +} + +type mapping struct { + proto string // "udp" or "tcp" + local string // ":" + mapped string // ":" + bound string // key: "[[:]]" + filters map[string]struct{} // key: "[[:]]" + expires time.Time // time to expire +} + +type networkAddressTranslator struct { + name string + natType NATType + mappedIPs []net.IP // mapped IPv4 + localIPs []net.IP // local IPv4, required only when the mode is NATModeNAT1To1 + outboundMap map[string]*mapping // key: "::[:remote-ip[:remote-port]] + inboundMap map[string]*mapping // key: "::" + udpPortCounter int + mutex sync.RWMutex + log logging.LeveledLogger +} + +func newNAT(config *natConfig) (*networkAddressTranslator, error) { + natType := config.natType + + if natType.Mode == NATModeNAT1To1 { + // 1:1 NAT behavior + natType.MappingBehavior = EndpointIndependent + natType.FilteringBehavior = EndpointIndependent + natType.PortPreservation = true + natType.MappingLifeTime = 0 + + if len(config.mappedIPs) == 0 { + return nil, errNATRequriesMapping + } + if len(config.mappedIPs) != len(config.localIPs) { + return nil, errMismatchLengthIP + } + } else { + // Normal (NAPT) behavior + natType.Mode = NATModeNormal + if natType.MappingLifeTime == 0 { + natType.MappingLifeTime = defaultNATMappingLifeTime + } + } + + return &networkAddressTranslator{ + name: config.name, + natType: natType, + mappedIPs: config.mappedIPs, + localIPs: config.localIPs, + outboundMap: map[string]*mapping{}, + inboundMap: map[string]*mapping{}, + log: config.loggerFactory.NewLogger("vnet"), + }, nil +} + +func (n *networkAddressTranslator) getPairedMappedIP(locIP net.IP) net.IP { + for i, ip := range n.localIPs { + if ip.Equal(locIP) { + return n.mappedIPs[i] + } + } + + return nil +} + +func (n *networkAddressTranslator) getPairedLocalIP(mappedIP net.IP) net.IP { + for i, ip := range n.mappedIPs { + if ip.Equal(mappedIP) { + return n.localIPs[i] + } + } + + return nil +} + +func (n *networkAddressTranslator) translateOutbound(from Chunk) (Chunk, error) { //nolint:cyclop + n.mutex.Lock() + defer n.mutex.Unlock() + + to := from.Clone() + + if from.Network() == udp { //nolint:nestif + if n.natType.Mode == NATModeNAT1To1 { + // 1:1 NAT behavior + srcAddr := from.SourceAddr().(*net.UDPAddr) //nolint:forcetypeassert + srcIP := n.getPairedMappedIP(srcAddr.IP) + if srcIP == nil { + n.log.Debugf("[%s] drop outbound chunk %s with not route", n.name, from.String()) + + return nil, nil // nolint:nilnil + } + srcPort := srcAddr.Port + if err := to.setSourceAddr(fmt.Sprintf("%s:%d", srcIP.String(), srcPort)); err != nil { + return nil, err + } + } else { + // Normal (NAPT) behavior + var bound, filterKey string + switch n.natType.MappingBehavior { + case EndpointIndependent: + bound = "" + case EndpointAddrDependent: + bound = from.getDestinationIP().String() + case EndpointAddrPortDependent: + bound = from.DestinationAddr().String() + } + + switch n.natType.FilteringBehavior { + case EndpointIndependent: + filterKey = "" + case EndpointAddrDependent: + filterKey = from.getDestinationIP().String() + case EndpointAddrPortDependent: + filterKey = from.DestinationAddr().String() + } + + oKey := fmt.Sprintf("udp:%s:%s", from.SourceAddr().String(), bound) + + mapp := n.findOutboundMapping(oKey) + if mapp == nil { + // Create a new mapping + mappedPort := 0xC000 + n.udpPortCounter + n.udpPortCounter++ + + mapp = &mapping{ + proto: from.SourceAddr().Network(), + local: from.SourceAddr().String(), + bound: bound, + mapped: fmt.Sprintf("%s:%d", n.mappedIPs[0].String(), mappedPort), + filters: map[string]struct{}{}, + expires: time.Now().Add(n.natType.MappingLifeTime), + } + + n.outboundMap[oKey] = mapp + + iKey := fmt.Sprintf("udp:%s", mapp.mapped) + + n.log.Debugf("[%s] created a new NAT binding oKey=%s iKey=%s", + n.name, + oKey, + iKey) + + mapp.filters[filterKey] = struct{}{} + n.log.Debugf("[%s] permit access from %s to %s", n.name, filterKey, mapp.mapped) + n.inboundMap[iKey] = mapp + } else if _, ok := mapp.filters[filterKey]; !ok { + n.log.Debugf("[%s] permit access from %s to %s", n.name, filterKey, mapp.mapped) + mapp.filters[filterKey] = struct{}{} + } + + if err := to.setSourceAddr(mapp.mapped); err != nil { + return nil, err + } + } + + n.log.Debugf("[%s] translate outbound chunk from %s to %s", n.name, from.String(), to.String()) + + return to, nil + } + + return nil, errNonUDPTranslationNotSupported +} + +func (n *networkAddressTranslator) translateInbound(from Chunk) (Chunk, error) { //nolint:cyclop + n.mutex.Lock() + defer n.mutex.Unlock() + + to := from.Clone() + + if from.Network() == udp { //nolint:nestif + if n.natType.Mode == NATModeNAT1To1 { + // 1:1 NAT behavior + dstAddr := from.DestinationAddr().(*net.UDPAddr) //nolint:forcetypeassert + dstIP := n.getPairedLocalIP(dstAddr.IP) + if dstIP == nil { + return nil, fmt.Errorf("drop %s as %w", from.String(), errNoAssociatedLocalAddress) + } + dstPort := from.DestinationAddr().(*net.UDPAddr).Port //nolint:forcetypeassert + if err := to.setDestinationAddr(fmt.Sprintf("%s:%d", dstIP, dstPort)); err != nil { + return nil, err + } + } else { + // Normal (NAPT) behavior + iKey := fmt.Sprintf("udp:%s", from.DestinationAddr().String()) + mapping := n.findInboundMapping(iKey) + if mapping == nil { + return nil, fmt.Errorf("drop %s as %w", from.String(), errNoNATBindingFound) + } + + var filterKey string + switch n.natType.FilteringBehavior { + case EndpointIndependent: + filterKey = "" + case EndpointAddrDependent: + filterKey = from.getSourceIP().String() + case EndpointAddrPortDependent: + filterKey = from.SourceAddr().String() + } + + if _, ok := mapping.filters[filterKey]; !ok { + return nil, fmt.Errorf("drop %s as the remote %s %w", from.String(), filterKey, errHasNoPermission) + } + + // See RFC 4847 Section 4.3. Mapping Refresh + // a) Inbound refresh may be useful for applications with no outgoing + // UDP traffic. However, allowing inbound refresh may allow an + // external attacker or misbehaving application to keep a mapping + // alive indefinitely. This may be a security risk. Also, if the + // process is repeated with different ports, over time, it could + // use up all the ports on the NAT. + + if err := to.setDestinationAddr(mapping.local); err != nil { + return nil, err + } + } + + n.log.Debugf("[%s] translate inbound chunk from %s to %s", n.name, from.String(), to.String()) + + return to, nil + } + + return nil, errNonUDPTranslationNotSupported +} + +// caller must hold the mutex. +func (n *networkAddressTranslator) findOutboundMapping(oKey string) *mapping { + now := time.Now() + + m, ok := n.outboundMap[oKey] + if ok { + // check if this mapping is expired + if now.After(m.expires) { + n.removeMapping(m) + m = nil // expired + } else { + m.expires = time.Now().Add(n.natType.MappingLifeTime) + } + } + + return m +} + +// caller must hold the mutex. +func (n *networkAddressTranslator) findInboundMapping(iKey string) *mapping { + now := time.Now() + m, ok := n.inboundMap[iKey] + if !ok { + return nil + } + + // check if this mapping is expired + if now.After(m.expires) { + n.removeMapping(m) + + return nil + } + + return m +} + +// caller must hold the mutex. +func (n *networkAddressTranslator) removeMapping(m *mapping) { + oKey := fmt.Sprintf("%s:%s:%s", m.proto, m.local, m.bound) + iKey := fmt.Sprintf("%s:%s", m.proto, m.mapped) + + delete(n.outboundMap, oKey) + delete(n.inboundMap, iKey) +} diff --git a/vendor/github.com/pion/transport/v4/vnet/net.go b/vendor/github.com/pion/transport/v4/vnet/net.go new file mode 100644 index 0000000..589fe4b --- /dev/null +++ b/vendor/github.com/pion/transport/v4/vnet/net.go @@ -0,0 +1,681 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package vnet + +import ( + "context" + "encoding/binary" + "errors" + "fmt" + "math/rand" + "net" + "strconv" + "strings" + "sync" + + "github.com/pion/transport/v4" +) + +const ( + lo0String = "lo0String" + udp = "udp" + udp4 = "udp4" +) + +var ( + macAddrCounter uint64 = 0xBEEFED910200 //nolint:gochecknoglobals + errNoInterface = errors.New("no interface is available") + errUnexpectedNetwork = errors.New("unexpected network") + errCantAssignRequestedAddr = errors.New("can't assign requested address") + errUnknownNetwork = errors.New("unknown network") + errNoRouterLinked = errors.New("no router linked") + errInvalidPortNumber = errors.New("invalid port number") + errUnexpectedTypeSwitchFailure = errors.New("unexpected type-switch failure") + errBindFailedFor = errors.New("bind failed for") + errEndPortLessThanStart = errors.New("end port is less than the start") + errPortSpaceExhausted = errors.New("port space exhausted") +) + +func newMACAddress() net.HardwareAddr { + b := make([]byte, 8) + binary.BigEndian.PutUint64(b, macAddrCounter) + macAddrCounter++ + + return b[2:] +} + +// Net represents a local network stack equivalent to a set of layers from NIC +// up to the transport (UDP / TCP) layer. +type Net struct { + interfaces []*transport.Interface // read-only + staticIPs []net.IP // read-only + router *Router // read-only + udpConns *udpConnMap // read-only + mutex sync.RWMutex +} + +// Compile-time assertion. +var _ transport.Net = &Net{} + +func (v *Net) _getInterfaces() ([]*transport.Interface, error) { + if len(v.interfaces) == 0 { + return nil, errNoInterface + } + + return v.interfaces, nil +} + +// Interfaces returns a list of the system's network interfaces. +func (v *Net) Interfaces() ([]*transport.Interface, error) { + v.mutex.RLock() + defer v.mutex.RUnlock() + + return v._getInterfaces() +} + +// caller must hold the mutex (read). +func (v *Net) _getInterface(ifName string) (*transport.Interface, error) { + ifs, err := v._getInterfaces() + if err != nil { + return nil, err + } + for _, ifc := range ifs { + if ifc.Name == ifName { + return ifc, nil + } + } + + return nil, fmt.Errorf("%w: %s", transport.ErrInterfaceNotFound, ifName) +} + +func (v *Net) getInterface(ifName string) (*transport.Interface, error) { + v.mutex.RLock() + defer v.mutex.RUnlock() + + return v._getInterface(ifName) +} + +// InterfaceByIndex returns the interface specified by index. +// +// On Solaris, it returns one of the logical network interfaces +// sharing the logical data link; for more precision use +// InterfaceByName. +func (v *Net) InterfaceByIndex(index int) (*transport.Interface, error) { + for _, ifc := range v.interfaces { + if ifc.Index == index { + return ifc, nil + } + } + + return nil, fmt.Errorf("%w: index=%d", transport.ErrInterfaceNotFound, index) +} + +// InterfaceByName returns the interface specified by name. +func (v *Net) InterfaceByName(ifName string) (*transport.Interface, error) { + return v.getInterface(ifName) +} + +// caller must hold the mutex. +func (v *Net) getAllIPAddrs(ipv6 bool) []net.IP { + ips := []net.IP{} + + for _, ifc := range v.interfaces { + addrs, err := ifc.Addrs() + if err != nil { + continue + } + + for _, addr := range addrs { + var ip net.IP + if ipNet, ok := addr.(*net.IPNet); ok { + ip = ipNet.IP + } else if ipAddr, ok := addr.(*net.IPAddr); ok { + ip = ipAddr.IP + } else { + continue + } + + if !ipv6 { + if ip.To4() != nil { + ips = append(ips, ip) + } + } + } + } + + return ips +} + +func (v *Net) setRouter(r *Router) error { + v.mutex.Lock() + defer v.mutex.Unlock() + + v.router = r + + return nil +} + +// AddAddress adds an address to an interface and registers it for routing. +// This method can be called before or after the router has started. +func (v *Net) AddAddress(ifName string, addr *net.IPNet) error { + v.mutex.Lock() + defer v.mutex.Unlock() + + ifc, err := v._getInterface(ifName) + if err != nil { + return err + } + ifc.AddAddress(addr) + + if v.router != nil { + v.router.mutex.Lock() + defer v.router.mutex.Unlock() + + return v.router.addIPToNIC(v, addr.IP) + } + + return nil +} + +// RemoveAddress removes an address from an interface and unregisters it from routing. +// This method can be called before or after the router has started. +func (v *Net) RemoveAddress(ifName string, ip net.IP) error { + v.mutex.Lock() + defer v.mutex.Unlock() + + ifc, err := v._getInterface(ifName) + if err != nil { + return err + } + ifc.RemoveAddress(ip) + + if v.router != nil { + v.router.mutex.Lock() + defer v.router.mutex.Unlock() + v.router.removeIPFromNIC(ip) + } + + return nil +} + +func (v *Net) onInboundChunk(c Chunk) { + v.mutex.Lock() + defer v.mutex.Unlock() + + if c.Network() == udp { + if conn, ok := v.udpConns.find(c.DestinationAddr()); ok { + conn.onInboundChunk(c) + } + } +} + +// caller must hold the mutex. +func (v *Net) _dialUDP(network string, locAddr, remAddr *net.UDPAddr) (transport.UDPConn, error) { //nolint:cyclop + // validate network + if network != udp && network != udp4 { + return nil, fmt.Errorf("%w: %s", errUnexpectedNetwork, network) + } + + if locAddr == nil { + locAddr = &net.UDPAddr{ + IP: net.IPv4zero, + } + } else if locAddr.IP == nil { + locAddr.IP = net.IPv4zero + } + + // validate address. do we have that address? + if !v.hasIPAddr(locAddr.IP) { + return nil, &net.OpError{ + Op: "listen", + Net: network, + Addr: locAddr, + Err: fmt.Errorf("bind: %w", errCantAssignRequestedAddr), + } + } + + if locAddr.Port == 0 { + // choose randomly from the range between 5000 and 5999 + port, err := v.assignPort(locAddr.IP, 5000, 5999) + if err != nil { + return nil, &net.OpError{ + Op: "listen", + Net: network, + Addr: locAddr, + Err: err, + } + } + locAddr.Port = port + } else if _, ok := v.udpConns.find(locAddr); ok { + return nil, &net.OpError{ + Op: "listen", + Net: network, + Addr: locAddr, + Err: fmt.Errorf("bind: %w", errAddressAlreadyInUse), + } + } + + conn, err := newUDPConn(locAddr, remAddr, v) + if err != nil { + return nil, err + } + + err = v.udpConns.insert(conn) + if err != nil { + return nil, err + } + + return conn, nil +} + +// ListenPacket announces on the local network address. +func (v *Net) ListenPacket(network string, address string) (net.PacketConn, error) { + v.mutex.Lock() + defer v.mutex.Unlock() + + locAddr, err := v.ResolveUDPAddr(network, address) + if err != nil { + return nil, err + } + + return v._dialUDP(network, locAddr, nil) +} + +// ListenUDP acts like ListenPacket for UDP networks. +func (v *Net) ListenUDP(network string, locAddr *net.UDPAddr) (transport.UDPConn, error) { + v.mutex.Lock() + defer v.mutex.Unlock() + + return v._dialUDP(network, locAddr, nil) +} + +// DialUDP acts like Dial for UDP networks. +func (v *Net) DialUDP(network string, locAddr, remAddr *net.UDPAddr) (transport.UDPConn, error) { + v.mutex.Lock() + defer v.mutex.Unlock() + + return v._dialUDP(network, locAddr, remAddr) +} + +// Dial connects to the address on the named network. +func (v *Net) Dial(network string, address string) (net.Conn, error) { + v.mutex.Lock() + defer v.mutex.Unlock() + + remAddr, err := v.ResolveUDPAddr(network, address) + if err != nil { + return nil, err + } + + // Determine source address + srcIP := v.determineSourceIP(nil, remAddr.IP) + + locAddr := &net.UDPAddr{IP: srcIP, Port: 0} + + return v._dialUDP(network, locAddr, remAddr) +} + +// ResolveIPAddr returns an address of IP end point. +func (v *Net) ResolveIPAddr(_, address string) (*net.IPAddr, error) { + var err error + + // Check if host is a domain name + ip := net.ParseIP(address) + if ip == nil { //nolint:nestif + address = strings.ToLower(address) + if address == "localhost" { + ip = net.IPv4(127, 0, 0, 1) + } else { + // host is a domain name. resolve IP address by the name + if v.router == nil { + return nil, errNoRouterLinked + } + + ip, err = v.router.resolver.lookUp(address) + if err != nil { + return nil, err + } + } + } + + return &net.IPAddr{ + IP: ip, + }, nil +} + +// ResolveUDPAddr returns an address of UDP end point. +func (v *Net) ResolveUDPAddr(network, address string) (*net.UDPAddr, error) { + if network != udp && network != udp4 { + return nil, fmt.Errorf("%w %s", errUnknownNetwork, network) + } + + host, sPort, err := net.SplitHostPort(address) + if err != nil { + return nil, err + } + + ipAddress, err := v.ResolveIPAddr("ip", host) + if err != nil { + return nil, err + } + + port, err := strconv.Atoi(sPort) + if err != nil { + return nil, errInvalidPortNumber + } + + udpAddr := &net.UDPAddr{ + IP: ipAddress.IP, + Zone: ipAddress.Zone, + Port: port, + } + + return udpAddr, nil +} + +// ResolveTCPAddr returns an address of TCP end point. +func (v *Net) ResolveTCPAddr(network, address string) (*net.TCPAddr, error) { + if network != udp && network != "udp4" { + return nil, fmt.Errorf("%w %s", errUnknownNetwork, network) + } + + host, sPort, err := net.SplitHostPort(address) + if err != nil { + return nil, err + } + + ipAddr, err := v.ResolveIPAddr("ip", host) + if err != nil { + return nil, err + } + + port, err := strconv.Atoi(sPort) + if err != nil { + return nil, errInvalidPortNumber + } + + udpAddr := &net.TCPAddr{ + IP: ipAddr.IP, + Zone: ipAddr.Zone, + Port: port, + } + + return udpAddr, nil +} + +func (v *Net) write(chunk Chunk) error { + if chunk.Network() == udp { //nolint:nestif + if udp, ok := chunk.(*chunkUDP); ok { + if chunk.getDestinationIP().IsLoopback() { + if conn, ok := v.udpConns.find(udp.DestinationAddr()); ok { + conn.onInboundChunk(udp) + } + + return nil + } + } else { + return errUnexpectedTypeSwitchFailure + } + } + + if v.router == nil { + return errNoRouterLinked + } + + v.router.push(chunk) + + return nil +} + +func (v *Net) onClosed(addr net.Addr) { + if addr.Network() == udp { + //nolint:errcheck + v.udpConns.delete(addr) // #nosec + } +} + +// This method determines the srcIP based on the dstIP when locIP +// is any IP address ("0.0.0.0" or "::"). If locIP is a non-any addr, +// this method simply returns locIP. +// caller must hold the mutex. +func (v *Net) determineSourceIP(locIP, dstIP net.IP) net.IP { //nolint:cyclop + if locIP != nil && !locIP.IsUnspecified() { + return locIP + } + + var srcIP net.IP + + if dstIP.IsLoopback() { //nolint:nestif + srcIP = net.ParseIP("127.0.0.1") + } else { + ifc, err2 := v._getInterface("eth0") + if err2 != nil { + return nil + } + + addrs, err2 := ifc.Addrs() + if err2 != nil { + return nil + } + + if len(addrs) == 0 { + return nil + } + + var findIPv4 bool + if locIP != nil { + findIPv4 = (locIP.To4() != nil) + } else { + findIPv4 = (dstIP.To4() != nil) + } + + for _, addr := range addrs { + ip := addr.(*net.IPNet).IP //nolint:forcetypeassert + if findIPv4 { + if ip.To4() != nil { + srcIP = ip + + break + } + } else { + if ip.To4() == nil { + srcIP = ip + + break + } + } + } + } + + return srcIP +} + +// caller must hold the mutex. +func (v *Net) hasIPAddr(ip net.IP) bool { //nolint:gocognit,cyclop + for _, ifc := range v.interfaces { + if addrs, err := ifc.Addrs(); err == nil { //nolint:nestif + for _, addr := range addrs { + var locIP net.IP + if ipNet, ok := addr.(*net.IPNet); ok { + locIP = ipNet.IP + } else if ipAddr, ok := addr.(*net.IPAddr); ok { + locIP = ipAddr.IP + } else { + continue + } + + switch ip.String() { + case "0.0.0.0": + if locIP.To4() != nil { + return true + } + case "::": + if locIP.To4() == nil { + return true + } + default: + if locIP.Equal(ip) { + return true + } + } + } + } + } + + return false +} + +// caller must hold the mutex. +func (v *Net) allocateLocalAddr(ip net.IP, port int) error { + // gather local IP addresses to bind + var ips []net.IP + if ip.IsUnspecified() { + ips = v.getAllIPAddrs(ip.To4() == nil) + } else if v.hasIPAddr(ip) { + ips = []net.IP{ip} + } + + if len(ips) == 0 { + return fmt.Errorf("%w %s", errBindFailedFor, ip.String()) + } + + // check if all these transport addresses are not in use + for _, ip2 := range ips { + addr := &net.UDPAddr{ + IP: ip2, + Port: port, + } + if _, ok := v.udpConns.find(addr); ok { + return &net.OpError{ + Op: "bind", + Net: udp, + Addr: addr, + Err: fmt.Errorf("bind: %w", errAddressAlreadyInUse), + } + } + } + + return nil +} + +// caller must hold the mutex. +func (v *Net) assignPort(ip net.IP, start, end int) (int, error) { + // choose randomly from the range between start and end (inclusive) + if end < start { + return -1, errEndPortLessThanStart + } + + space := end + 1 - start + offset := rand.Intn(space) //nolint:gosec + for i := 0; i < space; i++ { + port := ((offset + i) % space) + start + + err := v.allocateLocalAddr(ip, port) + if err == nil { + return port, nil + } + } + + return -1, errPortSpaceExhausted +} + +func (v *Net) getStaticIPs() []net.IP { + return v.staticIPs +} + +// NetConfig is a bag of configuration parameters passed to NewNet(). +type NetConfig struct { + // StaticIPs is an array of static IP addresses to be assigned for this Net. + // If no static IP address is given, the router will automatically assign + // an IP address. + StaticIPs []string +} + +// NewNet creates an instance of a virtual network. +// +// By design, it always have lo0 and eth0 interfaces. +// The lo0 has the address 127.0.0.1 assigned by default. +// IP address for eth0 will be assigned when this Net is added to a router. +func NewNet(config *NetConfig) (*Net, error) { + lo0 := transport.NewInterface(net.Interface{ + Index: 1, + MTU: 16384, + Name: lo0String, + HardwareAddr: nil, + Flags: net.FlagUp | net.FlagLoopback | net.FlagMulticast, + }) + lo0.AddAddress(&net.IPNet{ + IP: net.ParseIP("127.0.0.1"), + Mask: net.CIDRMask(8, 32), + }) + + eth0 := transport.NewInterface(net.Interface{ + Index: 2, + MTU: 1500, + Name: "eth0", + HardwareAddr: newMACAddress(), + Flags: net.FlagUp | net.FlagMulticast, + }) + + var staticIPs []net.IP + for _, ipStr := range config.StaticIPs { + if ip := net.ParseIP(ipStr); ip != nil { + staticIPs = append(staticIPs, ip) + } + } + + return &Net{ + interfaces: []*transport.Interface{lo0, eth0}, + staticIPs: staticIPs, + udpConns: newUDPConnMap(), + }, nil +} + +// DialTCP acts like Dial for TCP networks. +func (v *Net) DialTCP(string, *net.TCPAddr, *net.TCPAddr) (transport.TCPConn, error) { + return nil, transport.ErrNotSupported +} + +// ListenTCP acts like Listen for TCP networks. +func (v *Net) ListenTCP(string, *net.TCPAddr) (transport.TCPListener, error) { + return nil, transport.ErrNotSupported +} + +// CreateDialer creates an instance of vnet.Dialer. +func (v *Net) CreateDialer(d *net.Dialer) transport.Dialer { + return &dialer{ + dialer: d, + net: v, + } +} + +type dialer struct { + dialer *net.Dialer + net *Net +} + +func (d *dialer) Dial(network, address string) (net.Conn, error) { + return d.net.Dial(network, address) +} + +// CreateListenConfig creates an instance of vnet.ListenConfig. +func (v *Net) CreateListenConfig(l *net.ListenConfig) transport.ListenConfig { + return &listenConfig{ + listenConfig: l, + net: v, + } +} + +type listenConfig struct { + listenConfig *net.ListenConfig + net *Net +} + +func (l *listenConfig) Listen(ctx context.Context, network, address string) (net.Listener, error) { + return l.listenConfig.Listen(ctx, network, address) +} + +func (l *listenConfig) ListenPacket(_ context.Context, network, address string) (net.PacketConn, error) { + return l.net.ListenPacket(network, address) +} diff --git a/vendor/github.com/pion/transport/v4/vnet/queue.go b/vendor/github.com/pion/transport/v4/vnet/queue.go new file mode 100644 index 0000000..0ce3d47 --- /dev/null +++ b/vendor/github.com/pion/transport/v4/vnet/queue.go @@ -0,0 +1,104 @@ +// SPDX-FileCopyrightText: 2025 The Pion community +// SPDX-License-Identifier: MIT + +package vnet + +import ( + "sync" + "time" +) + +type Discipline interface { + push(Chunk) + pop() Chunk + empty() bool + next() time.Time +} + +type Queue struct { + NIC + data Discipline + chunkCh chan Chunk + closed bool + close chan struct{} + wg sync.WaitGroup + lock sync.Mutex +} + +func NewQueue(n NIC, d Discipline) (*Queue, error) { + q := &Queue{ + NIC: n, + data: d, + chunkCh: make(chan Chunk), + closed: false, + close: make(chan struct{}), + wg: sync.WaitGroup{}, + lock: sync.Mutex{}, + } + q.wg.Add(1) + go q.run() + + return q, nil +} + +func (q *Queue) onInboundChunk(c Chunk) { + select { + case q.chunkCh <- c: + case <-q.close: + + return + } +} + +func (q *Queue) run() { + defer q.wg.Done() + for { + if !q.schedule() { + return + } + } +} + +func (q *Queue) schedule() bool { + q.lock.Lock() + if q.closed { + q.lock.Unlock() + + return false + } + q.lock.Unlock() + + var timer <-chan time.Time + + if !q.data.empty() { + next := q.data.next() + timer = time.After(time.Until(next)) + } + + select { + case chunk := <-q.chunkCh: + q.data.push(chunk) + case <-timer: + chunk := q.data.pop() + if chunk != nil { + q.NIC.onInboundChunk(chunk) + } + case <-q.close: + return false + } + + return true +} + +func (q *Queue) Close() error { + defer q.wg.Wait() + q.lock.Lock() + defer q.lock.Unlock() + if q.closed { + return nil + } + q.closed = true + close(q.close) + + return nil +} diff --git a/vendor/github.com/pion/transport/v4/vnet/resolver.go b/vendor/github.com/pion/transport/v4/vnet/resolver.go new file mode 100644 index 0000000..3054a4d --- /dev/null +++ b/vendor/github.com/pion/transport/v4/vnet/resolver.go @@ -0,0 +1,95 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package vnet + +import ( + "errors" + "fmt" + "net" + "sync" + + "github.com/pion/logging" +) + +var ( + errHostnameEmpty = errors.New("host name must not be empty") + errFailedToParseIPAddr = errors.New("failed to parse IP address") +) + +type resolverConfig struct { + LoggerFactory logging.LoggerFactory +} + +type resolver struct { + parent *resolver // read-only + hosts map[string]net.IP // requires mutex + mutex sync.RWMutex // thread-safe + log logging.LeveledLogger // read-only +} + +func newResolver(config *resolverConfig) *resolver { + r := &resolver{ + hosts: map[string]net.IP{}, + log: config.LoggerFactory.NewLogger("vnet"), + } + + if err := r.addHost("localhost", "127.0.0.1"); err != nil { + r.log.Warn("failed to add localhost to resolver") + } + + return r +} + +func (r *resolver) setParent(parent *resolver) { + r.mutex.Lock() + defer r.mutex.Unlock() + + r.parent = parent +} + +func (r *resolver) addHost(name string, ipAddr string) error { + r.mutex.Lock() + defer r.mutex.Unlock() + + if len(name) == 0 { + return errHostnameEmpty + } + ip := net.ParseIP(ipAddr) + if ip == nil { + return fmt.Errorf("%w \"%s\"", errFailedToParseIPAddr, ipAddr) + } + r.hosts[name] = ip + + return nil +} + +func (r *resolver) lookUp(hostName string) (net.IP, error) { + ip := func() net.IP { + r.mutex.RLock() + defer r.mutex.RUnlock() + + if ip2, ok := r.hosts[hostName]; ok { + return ip2 + } + + return nil + }() + if ip != nil { + return ip, nil + } + + // mutex must be unlocked before calling into parent resolver + + if r.parent != nil { + return r.parent.lookUp(hostName) + } + + return nil, &net.DNSError{ + Err: "host not found", + Name: hostName, + Server: "vnet resolver", + IsTimeout: false, + IsTemporary: false, + } +} diff --git a/vendor/github.com/pion/transport/v4/vnet/router.go b/vendor/github.com/pion/transport/v4/vnet/router.go new file mode 100644 index 0000000..2512e99 --- /dev/null +++ b/vendor/github.com/pion/transport/v4/vnet/router.go @@ -0,0 +1,640 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package vnet + +import ( + "errors" + "fmt" + "math/rand" + "net" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/pion/logging" + "github.com/pion/transport/v4" +) + +const ( + defaultRouterQueueSize = 0 // unlimited +) + +var ( + errInvalidLocalIPinStaticIPs = errors.New("invalid local IP in StaticIPs") + errLocalIPBeyondStaticIPsSubset = errors.New("mapped in StaticIPs is beyond subnet") + errLocalIPNoStaticsIPsAssociated = errors.New("all StaticIPs must have associated local IPs") + errRouterAlreadyStarted = errors.New("router already started") + errRouterAlreadyStopped = errors.New("router already stopped") + errStaticIPisBeyondSubnet = errors.New("static IP is beyond subnet") + errAddressSpaceExhausted = errors.New("address space exhausted") + errNoIPAddrEth0 = errors.New("no IP address is assigned for eth0") +) + +// Generate a unique router name. +var assignRouterName = func() func() string { //nolint:gochecknoglobals + var routerIDCtr uint64 + + return func() string { + n := atomic.AddUint64(&routerIDCtr, 1) + + return fmt.Sprintf("router%d", n) + } +}() + +// RouterConfig ... +type RouterConfig struct { + // Name of router. If not specified, a unique name will be assigned. + Name string + // CIDR notation, like "192.0.2.0/24" + CIDR string + // StaticIPs is an array of static IP addresses to be assigned for this router. + // If no static IP address is given, the router will automatically assign + // an IP address. + // This will be ignored if this router is the root. + StaticIPs []string + // Internal queue size + QueueSize int + // Effective only when this router has a parent router + NATType *NATType + // Minimum Delay + MinDelay time.Duration + // Max Jitter + MaxJitter time.Duration + // Logger factory + LoggerFactory logging.LoggerFactory +} + +// NIC is a network interface controller that interfaces Router. +type NIC interface { + getInterface(ifName string) (*transport.Interface, error) + onInboundChunk(c Chunk) + getStaticIPs() []net.IP + setRouter(r *Router) error +} + +// ChunkFilter is a handler users can add to filter chunks. +// If the filter returns false, the packet will be dropped. +type ChunkFilter func(c Chunk) bool + +// Router ... +type Router struct { + name string // read-only + interfaces []*transport.Interface // read-only + ipv4Net *net.IPNet // read-only + staticIPs []net.IP // read-only + staticLocalIPs map[string]net.IP // read-only, + lastID byte // requires mutex [x], used to assign the last digit of IPv4 address + queue *chunkQueue // read-only + parent *Router // read-only + children []*Router // read-only + natType *NATType // read-only + nat *networkAddressTranslator // read-only + nics map[string]NIC // read-only + stopFunc func() // requires mutex [x] + resolver *resolver // read-only + chunkFilters []ChunkFilter // requires mutex [x] + minDelay time.Duration // requires mutex [x] + maxJitter time.Duration // requires mutex [x] + mutex sync.RWMutex // thread-safe + pushCh chan struct{} // writer requires mutex + loggerFactory logging.LoggerFactory // read-only + log logging.LeveledLogger // read-only +} + +// NewRouter ... +func NewRouter(config *RouterConfig) (*Router, error) { //nolint:cyclop + loggerFactory := config.LoggerFactory + log := loggerFactory.NewLogger("vnet") + + _, ipv4Net, err := net.ParseCIDR(config.CIDR) + if err != nil { + return nil, err + } + + queueSize := defaultRouterQueueSize + if config.QueueSize > 0 { + queueSize = config.QueueSize + } + + // set up network interface, lo0 + lo0 := transport.NewInterface(net.Interface{ + Index: 1, + MTU: 16384, + Name: lo0String, + HardwareAddr: nil, + Flags: net.FlagUp | net.FlagLoopback | net.FlagMulticast, + }) + lo0.AddAddress(&net.IPAddr{IP: net.ParseIP("127.0.0.1"), Zone: ""}) + + // set up network interface, eth0 + eth0 := transport.NewInterface(net.Interface{ + Index: 2, + MTU: 1500, + Name: "eth0", + HardwareAddr: newMACAddress(), + Flags: net.FlagUp | net.FlagMulticast, + }) + + // local host name resolver + resolver := newResolver(&resolverConfig{ + LoggerFactory: config.LoggerFactory, + }) + + name := config.Name + if len(name) == 0 { + name = assignRouterName() + } + + var staticIPs []net.IP + staticLocalIPs := map[string]net.IP{} + for _, ipStr := range config.StaticIPs { + ipPair := strings.Split(ipStr, "/") + if ip := net.ParseIP(ipPair[0]); ip != nil { //nolint:nestif + if len(ipPair) > 1 { + locIP := net.ParseIP(ipPair[1]) + if locIP == nil { + return nil, errInvalidLocalIPinStaticIPs + } + if !ipv4Net.Contains(locIP) { + return nil, fmt.Errorf("local IP %s %w", locIP.String(), errLocalIPBeyondStaticIPsSubset) + } + staticLocalIPs[ip.String()] = locIP + } + staticIPs = append(staticIPs, ip) + } + } + + if nStaticLocal := len(staticLocalIPs); nStaticLocal > 0 { + if nStaticLocal != len(staticIPs) { + return nil, errLocalIPNoStaticsIPsAssociated + } + } + + return &Router{ + name: name, + interfaces: []*transport.Interface{lo0, eth0}, + ipv4Net: ipv4Net, + staticIPs: staticIPs, + staticLocalIPs: staticLocalIPs, + queue: newChunkQueue(queueSize, 0), + natType: config.NATType, + nics: map[string]NIC{}, + resolver: resolver, + minDelay: config.MinDelay, + maxJitter: config.MaxJitter, + pushCh: make(chan struct{}, 1), + loggerFactory: loggerFactory, + log: log, + }, nil +} + +// caller must hold the mutex. +func (r *Router) getInterfaces() ([]*transport.Interface, error) { + if len(r.interfaces) == 0 { + return nil, fmt.Errorf("%w is available", errNoInterface) + } + + return r.interfaces, nil +} + +func (r *Router) getInterface(ifName string) (*transport.Interface, error) { + r.mutex.RLock() + defer r.mutex.RUnlock() + + ifs, err := r.getInterfaces() + if err != nil { + return nil, err + } + for _, ifc := range ifs { + if ifc.Name == ifName { + return ifc, nil + } + } + + return nil, fmt.Errorf("%w: %s", transport.ErrInterfaceNotFound, ifName) +} + +// Start ... +func (r *Router) Start() error { //nolint:cyclop + r.mutex.Lock() + defer r.mutex.Unlock() + + if r.stopFunc != nil { + return errRouterAlreadyStarted + } + + cancelCh := make(chan struct{}) + + go func() { + loop: + for { + duration, err := r.processChunks() + if err != nil { + r.log.Errorf("[%s] %s", r.name, err.Error()) + + break + } + + if duration <= 0 { + select { + case <-r.pushCh: + case <-cancelCh: + break loop + } + } else { + t := time.NewTimer(duration) + select { + case <-t.C: + case <-cancelCh: + break loop + } + } + } + }() + + r.stopFunc = func() { + close(cancelCh) + } + + for _, child := range r.children { + if err := child.Start(); err != nil { + return err + } + } + + return nil +} + +// Stop ... +func (r *Router) Stop() error { + r.mutex.Lock() + defer r.mutex.Unlock() + + if r.stopFunc == nil { + return errRouterAlreadyStopped + } + + for _, router := range r.children { + r.mutex.Unlock() + err := router.Stop() + r.mutex.Lock() + + if err != nil { + return err + } + } + + r.stopFunc() + r.stopFunc = nil + + return nil +} + +// caller must hold the mutex. +func (r *Router) addNIC(nic NIC) error { + ifc, err := nic.getInterface("eth0") + if err != nil { + return err + } + + var ips []net.IP + + if ips = nic.getStaticIPs(); len(ips) == 0 { + // assign an IP address + ip, err2 := r.assignIPAddress() + if err2 != nil { + return err2 + } + ips = append(ips, ip) + } + + for _, ip := range ips { + if !r.ipv4Net.Contains(ip) { + return fmt.Errorf("%w: %s", errStaticIPisBeyondSubnet, r.ipv4Net.String()) + } + + ifc.AddAddress(&net.IPNet{ + IP: ip, + Mask: r.ipv4Net.Mask, + }) + + r.nics[ip.String()] = nic + } + + return nic.setRouter(r) +} + +// caller must hold the mutex. +func (r *Router) addIPToNIC(nic NIC, ip net.IP) error { + if !r.ipv4Net.Contains(ip) { + return fmt.Errorf("%w: %s", errStaticIPisBeyondSubnet, r.ipv4Net.String()) + } + r.nics[ip.String()] = nic + + return nil +} + +// caller must hold the mutex. +func (r *Router) removeIPFromNIC(ip net.IP) { + delete(r.nics, ip.String()) +} + +// AddRouter adds a child Router. +func (r *Router) AddRouter(router *Router) error { + r.mutex.Lock() + defer r.mutex.Unlock() + + // Router is a NIC. Add it as a NIC so that packets are routed to this child + // router. + err := r.addNIC(router) + if err != nil { + return err + } + + if err = router.setRouter(r); err != nil { + return err + } + + r.children = append(r.children, router) + + return nil +} + +// AddChildRouter is like AddRouter, but does not add the child routers NIC to +// the parent. This has to be done manually by calling AddNet, which allows to +// use a wrapper around the subrouters NIC. +// AddNet MUST be called before AddChildRouter. +func (r *Router) AddChildRouter(router *Router) error { + r.mutex.Lock() + defer r.mutex.Unlock() + if err := router.setRouter(r); err != nil { + return err + } + + r.children = append(r.children, router) + + return nil +} + +// AddNet ... +func (r *Router) AddNet(nic NIC) error { + r.mutex.Lock() + defer r.mutex.Unlock() + + return r.addNIC(nic) +} + +// AddHost adds a mapping of hostname and an IP address to the local resolver. +func (r *Router) AddHost(hostName string, ipAddr string) error { + return r.resolver.addHost(hostName, ipAddr) +} + +// AddChunkFilter adds a filter for chunks traversing this router. +// You may add more than one filter. The filters are called in the order of this method call. +// If a chunk is dropped by a filter, subsequent filter will not receive the chunk. +func (r *Router) AddChunkFilter(filter ChunkFilter) { + r.mutex.Lock() + defer r.mutex.Unlock() + + r.chunkFilters = append(r.chunkFilters, filter) +} + +// caller should hold the mutex. +func (r *Router) assignIPAddress() (net.IP, error) { + // See: https://stackoverflow.com/questions/14915188/ip-address-ending-with-zero + + if r.lastID == 0xfe { + return nil, errAddressSpaceExhausted + } + + ip := make(net.IP, 4) + copy(ip, r.ipv4Net.IP[:3]) + r.lastID++ + ip[3] = r.lastID //nolint:gosec // IPv4 address is always 4 bytes + + return ip, nil +} + +func (r *Router) push(c Chunk) { + r.mutex.Lock() + defer r.mutex.Unlock() + + r.log.Debugf("[%s] route %s", r.name, c.String()) + if r.stopFunc != nil { + c.setTimestamp() + if r.queue.push(c) { + select { + case r.pushCh <- struct{}{}: + default: + } + } else { + r.log.Warnf("[%s] queue was full. dropped a chunk", r.name) + } + } +} + +func (r *Router) processChunks() (time.Duration, error) { //nolint:cyclop + r.mutex.Lock() + defer r.mutex.Unlock() + + // Introduce jitter by delaying the processing of chunks. + if r.maxJitter > 0 { + jitter := time.Duration(rand.Int63n(int64(r.maxJitter))) //nolint:gosec + time.Sleep(jitter) + } + + // cutOff + // v min delay + // |<--->| + // +------------:-- + // |OOOOOOXXXXX : --> time + // +------------:-- + // |<--->| now + // due + + enteredAt := time.Now() + cutOff := enteredAt.Add(-r.minDelay) + + var duration time.Duration // the next sleep duration + + for { + duration = 0 + + chunk := r.queue.peek() + if chunk == nil { + break // no more chunk in the queue + } + + // check timestamp to find if the chunk is due + if chunk.getTimestamp().After(cutOff) { + // There is one or more chunk in the queue but none of them are due. + // Calculate the next sleep duration here. + nextExpire := chunk.getTimestamp().Add(r.minDelay) + duration = nextExpire.Sub(enteredAt) + + break + } + + var ok bool + if chunk, ok = r.queue.pop(); !ok { + break // no more chunk in the queue + } + + blocked := false + for i := 0; i < len(r.chunkFilters); i++ { + filter := r.chunkFilters[i] + if !filter(chunk) { + blocked = true + + break + } + } + if blocked { + continue // discard + } + + dstIP := chunk.getDestinationIP() + + // check if the destination is in our subnet + if r.ipv4Net.Contains(dstIP) { + // search for the destination NIC + var nic NIC + if nic, ok = r.nics[dstIP.String()]; !ok { + // NIC not found. drop it. + r.log.Debugf("[%s] %s unreachable", r.name, chunk.String()) + + continue + } + + // found the NIC, forward the chunk to the NIC. + // call to NIC must unlock mutex + r.mutex.Unlock() + nic.onInboundChunk(chunk) + r.mutex.Lock() + + continue + } + + // the destination is outside of this subnet + // is this WAN? + if r.parent == nil { + // this WAN. No route for this chunk + r.log.Debugf("[%s] no route found for %s", r.name, chunk.String()) + + continue + } + + // Pass it to the parent via NAT + toParent, err := r.nat.translateOutbound(chunk) + if err != nil { + return 0, err + } + + if toParent == nil { + continue + } + + //nolint:godox + /* FIXME: this implementation would introduce a duplicate packet! + if r.nat.natType.Hairpinning { + hairpinned, err := r.nat.translateInbound(toParent) + if err != nil { + r.log.Warnf("[%s] %s", r.name, err.Error()) + } else { + go func() { + r.push(hairpinned) + }() + } + } + */ + + // call to parent router mutex unlock mutex + r.mutex.Unlock() + r.parent.push(toParent) + r.mutex.Lock() + } + + return duration, nil +} + +// caller must hold the mutex. +func (r *Router) setRouter(parent *Router) error { //nolint:cyclop + r.parent = parent + r.resolver.setParent(parent.resolver) + + // when this method is called, one or more IP address has already been assigned by + // the parent router. + ifc, err := r.getInterface("eth0") + if err != nil { + return err + } + + addrs, _ := ifc.Addrs() + if len(addrs) == 0 { + return errNoIPAddrEth0 + } + + mappedIPs := []net.IP{} + localIPs := []net.IP{} + + for _, ifcAddr := range addrs { + var ip net.IP + switch addr := ifcAddr.(type) { + case *net.IPNet: + ip = addr.IP + case *net.IPAddr: // Do we really need this case? + ip = addr.IP + default: + } + + if ip == nil { + continue + } + + mappedIPs = append(mappedIPs, ip) + + if locIP := r.staticLocalIPs[ip.String()]; locIP != nil { + localIPs = append(localIPs, locIP) + } + } + + // Set up NAT here + if r.natType == nil { + r.natType = &NATType{ + MappingBehavior: EndpointIndependent, + FilteringBehavior: EndpointAddrPortDependent, + Hairpinning: false, + PortPreservation: false, + MappingLifeTime: 30 * time.Second, + } + } + r.nat, err = newNAT(&natConfig{ + name: r.name, + natType: *r.natType, + mappedIPs: mappedIPs, + localIPs: localIPs, + loggerFactory: r.loggerFactory, + }) + if err != nil { + return err + } + + return nil +} + +func (r *Router) onInboundChunk(c Chunk) { + fromParent, err := r.nat.translateInbound(c) + if err != nil { + r.log.Warnf("[%s] %s", r.name, err.Error()) + + return + } + + r.push(fromParent) +} + +func (r *Router) getStaticIPs() []net.IP { + return r.staticIPs +} diff --git a/vendor/github.com/pion/transport/v4/vnet/tbf_queue.go b/vendor/github.com/pion/transport/v4/vnet/tbf_queue.go new file mode 100644 index 0000000..c187a70 --- /dev/null +++ b/vendor/github.com/pion/transport/v4/vnet/tbf_queue.go @@ -0,0 +1,94 @@ +// SPDX-FileCopyrightText: 2025 The Pion community +// SPDX-License-Identifier: MIT + +package vnet + +import ( + "sync/atomic" + "time" + + "golang.org/x/time/rate" +) + +var _ Discipline = (*TBFQueue)(nil) + +type TBFQueue struct { + limiter *rate.Limiter + chunks []Chunk + maxSize atomic.Int64 + currentSize int +} + +// NewTBFQueue creates a new Token Bucket Filter queue with initial rate r in +// bit per second, burst size b in bytes and queue size s in bytes. +func NewTBFQueue(r int, b int, s int64) *TBFQueue { + q := &TBFQueue{ + limiter: rate.NewLimiter(rate.Limit(r), b*8), + chunks: []Chunk{}, + maxSize: atomic.Int64{}, + currentSize: 0, + } + q.maxSize.Store(s) + + return q +} + +// SetRate updates the rate to r bit per second. +func (t *TBFQueue) SetRate(r int) { + t.limiter.SetLimit(rate.Limit(r)) +} + +// SetBurst updates the max burst size to b bytes. +func (t *TBFQueue) SetBurst(b int) { + t.limiter.SetBurst(b * 8) +} + +func (t *TBFQueue) SetSize(s int64) { + t.maxSize.Store(s) +} + +// empty implements discipline. +func (t *TBFQueue) empty() bool { + return len(t.chunks) == 0 +} + +// next implements discipline. +func (t *TBFQueue) next() time.Time { + if t.empty() { + return time.Time{} + } + now := time.Now() + if t.limiter.TokensAt(now) > 8*float64(len(t.chunks[0].UserData())) { + return now + } + res := t.limiter.ReserveN(now, 8*len(t.chunks[0].UserData())) + delay := res.Delay() + res.Cancel() + + return now.Add(delay) +} + +// pop implements discipline. +func (t *TBFQueue) pop() (chunk Chunk) { + if t.empty() { + return nil + } + if !t.limiter.AllowN(time.Now(), 8*len(t.chunks[0].UserData())) { + return nil + } + chunk, t.chunks = t.chunks[0], t.chunks[1:] + t.currentSize -= len(chunk.UserData()) + + return chunk +} + +// push implements discipline. +func (t *TBFQueue) push(chunk Chunk) { + maxSize := int(t.maxSize.Load()) + if t.currentSize+len(chunk.UserData()) > maxSize { + // drop chunk because queue is full + return + } + t.currentSize += len(chunk.UserData()) + t.chunks = append(t.chunks, chunk) +} diff --git a/vendor/github.com/pion/transport/v4/vnet/udpproxy.go b/vendor/github.com/pion/transport/v4/vnet/udpproxy.go new file mode 100644 index 0000000..628f2fb --- /dev/null +++ b/vendor/github.com/pion/transport/v4/vnet/udpproxy.go @@ -0,0 +1,226 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package vnet + +import ( + "context" + "net" + "sync" + "time" +) + +// UDPProxy is a proxy between real server(net.UDPConn) and vnet.UDPConn. +// +// High level design: +// +// .............................................. +// : Virtual Network (vnet) : +// : : +// +-------+ * 1 +----+ +--------+ : +// | :App |------------>|:Net|--o<-----|:Router | ............................. +// +-------+ +----+ | | : UDPProxy : +// : | | +----+ +---------+ +---------+ +--------+ +// : | |--->o--|:Net|-->o-| vnet. |-->o-| net. |--->-| :Real | +// : | | +----+ | UDPConn | | UDPConn | | Server | +// : | | : +---------+ +---------+ +--------+ +// : | | ............................: +// : +--------+ : +// ............................................... +type UDPProxy struct { + // The router bind to. + router *Router + + // Each vnet source, bind to a real socket to server. + // key is real server addr, which is net.Addr + // value is *aUDPProxyWorker + workers sync.Map + + // For each endpoint, we never know when to start and stop proxy, + // so we stop the endpoint when timeout. + timeout time.Duration + + // For utest, to mock the target real server. + // Optional, use the address of received client packet. + mockRealServerAddr *net.UDPAddr +} + +// NewProxy create a proxy, the router for this proxy belongs/bind to. If need to proxy for +// please create a new proxy for each router. For all addresses we proxy, we will create a +// vnet.Net in this router and proxy all packets. +func NewProxy(router *Router) (*UDPProxy, error) { + v := &UDPProxy{router: router, timeout: 2 * time.Minute} + + return v, nil +} + +// Close the proxy, stop all workers. +func (v *UDPProxy) Close() error { + v.workers.Range(func(_, value any) bool { + _ = value.(*aUDPProxyWorker).Close() //nolint:forcetypeassert + + return true + }) + + return nil +} + +// Proxy starts a worker for server, ignore if already started. +func (v *UDPProxy) Proxy(client *Net, server *net.UDPAddr) error { + // Note that even if the worker exists, it's also ok to create a same worker, + // because the router will use the last one, and the real server will see a address + // change event after we switch to the next worker. + if _, ok := v.workers.Load(server.String()); ok { + // nolint:godox // TODO: Need to restart the stopped worker? + return nil + } + + // Not exists, create a new one. + worker := &aUDPProxyWorker{ + router: v.router, mockRealServerAddr: v.mockRealServerAddr, + } + + // Create context for cleanup. + var ctx context.Context + ctx, worker.ctxDisposeCancel = context.WithCancel(context.Background()) + + v.workers.Store(server.String(), worker) + + return worker.Proxy(ctx, client, server) +} + +// A proxy worker for a specified proxy server. +type aUDPProxyWorker struct { + router *Router + mockRealServerAddr *net.UDPAddr + + // Each vnet source, bind to a real socket to server. + // key is vnet client addr, which is net.Addr + // value is *net.UDPConn + endpoints sync.Map + + // For cleanup. + ctxDisposeCancel context.CancelFunc + wg sync.WaitGroup +} + +func (v *aUDPProxyWorker) Close() error { + // Notify all goroutines to dispose. + v.ctxDisposeCancel() + + // Wait for all goroutines quit. + v.wg.Wait() + + return nil +} + +func (v *aUDPProxyWorker) Proxy(ctx context.Context, _ *Net, serverAddr *net.UDPAddr) error { // nolint:gocognit,cyclop + // Create vnet for real server by serverAddr. + nw, err := NewNet(&NetConfig{ + StaticIPs: []string{serverAddr.IP.String()}, + }) + if err != nil { + return err + } + + if err = v.router.AddNet(nw); err != nil { + return err + } + + // We must create a "same" vnet.UDPConn as the net.UDPConn, + // which has the same ip:port, to copy packets between them. + vnetSocket, err := nw.ListenUDP("udp4", serverAddr) + if err != nil { + return err + } + + // User stop proxy, we should close the socket. + go func() { + <-ctx.Done() + _ = vnetSocket.Close() + }() + + // Got new vnet client, start a new endpoint. + findEndpointBy := func(addr net.Addr) (*net.UDPConn, error) { + // Exists binding. + if value, ok := v.endpoints.Load(addr.String()); ok { + // Exists endpoint, reuse it. + return value.(*net.UDPConn), nil //nolint:forcetypeassert + } + + // The real server we proxy to, for utest to mock it. + realAddr := serverAddr + if v.mockRealServerAddr != nil { + realAddr = v.mockRealServerAddr + } + + // Got new vnet client, create new endpoint. + realSocket, err := net.DialUDP("udp4", nil, realAddr) + if err != nil { + return nil, err + } + + // User stop proxy, we should close the socket. + go func() { + <-ctx.Done() + _ = realSocket.Close() + }() + + // Bind address. + v.endpoints.Store(addr.String(), realSocket) + + // Got packet from real serverAddr, we should proxy it to vnet. + v.wg.Add(1) + go func(vnetClientAddr net.Addr) { + defer v.wg.Done() + + buf := make([]byte, 1500) + for { + n, _, err := realSocket.ReadFrom(buf) + if err != nil { + return + } + + if n <= 0 { + continue // Drop packet + } + + if _, err := vnetSocket.WriteTo(buf[:n], vnetClientAddr); err != nil { + return + } + } + }(addr) + + return realSocket, nil + } + + // Start a proxy goroutine. + v.wg.Add(1) + go func() { + defer v.wg.Done() + + buf := make([]byte, 1500) + + for { + n, addr, err := vnetSocket.ReadFrom(buf) + if err != nil { + return + } + + if n <= 0 || addr == nil { + continue // Drop packet + } + + realSocket, err := findEndpointBy(addr) + if err != nil { + continue // Drop packet. + } + + if _, err := realSocket.Write(buf[:n]); err != nil { + return + } + } + }() + + return nil +} diff --git a/vendor/github.com/pion/transport/v4/vnet/udpproxy_direct.go b/vendor/github.com/pion/transport/v4/vnet/udpproxy_direct.go new file mode 100644 index 0000000..5f72c5e --- /dev/null +++ b/vendor/github.com/pion/transport/v4/vnet/udpproxy_direct.go @@ -0,0 +1,54 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package vnet + +import ( + "fmt" + "net" +) + +// Deliver directly send packet to vnet or real-server. +// For example, we can use this API to simulate the REPLAY ATTACK. +func (v *UDPProxy) Deliver(sourceAddr, destAddr net.Addr, b []byte) (nn int, err error) { + v.workers.Range(func(_, value any) bool { + worker, ok := value.(*aUDPProxyWorker) + if !ok { + return false + } + + if nn, err = worker.Deliver(sourceAddr, destAddr, b); err != nil { + return false // Fail, abort. + } else if nn == len(b) { + return false // Done. + } + + return true // Deliver by next worker. + }) + + return +} + +func (v *aUDPProxyWorker) Deliver(sourceAddr, _ net.Addr, b []byte) (nn int, err error) { + addr, ok := sourceAddr.(*net.UDPAddr) + if !ok { + return 0, fmt.Errorf("invalid addr %v", sourceAddr) // nolint:err113 + } + + // nolint:godox // TODO: Support deliver packet from real server to vnet. + // If packet is from vnet, proxy to real server. + var realSocket *net.UDPConn + value, ok := v.endpoints.Load(addr.String()) + if !ok { + return 0, nil + } + + realSocket = value.(*net.UDPConn) // nolint:forcetypeassert + + // Send to real server. + if _, err := realSocket.Write(b); err != nil { + return 0, err + } + + return len(b), nil +} diff --git a/vendor/github.com/pion/transport/v4/vnet/vnet.go b/vendor/github.com/pion/transport/v4/vnet/vnet.go new file mode 100644 index 0000000..1d244f0 --- /dev/null +++ b/vendor/github.com/pion/transport/v4/vnet/vnet.go @@ -0,0 +1,5 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +// Package vnet provides a virtual network layer for pion +package vnet diff --git a/vendor/github.com/pion/turn/v4/.gitignore b/vendor/github.com/pion/turn/v4/.gitignore new file mode 100644 index 0000000..6e2f206 --- /dev/null +++ b/vendor/github.com/pion/turn/v4/.gitignore @@ -0,0 +1,28 @@ +# SPDX-FileCopyrightText: 2023 The Pion community +# SPDX-License-Identifier: MIT + +### JetBrains IDE ### +##################### +.idea/ + +### Emacs Temporary Files ### +############################# +*~ + +### Folders ### +############### +bin/ +vendor/ +node_modules/ + +### Files ### +############# +*.ivf +*.ogg +tags +cover.out +*.sw[poe] +*.wasm +examples/sfu-ws/cert.pem +examples/sfu-ws/key.pem +wasm_exec.js diff --git a/vendor/github.com/pion/turn/v4/.golangci.yml b/vendor/github.com/pion/turn/v4/.golangci.yml new file mode 100644 index 0000000..4b4025f --- /dev/null +++ b/vendor/github.com/pion/turn/v4/.golangci.yml @@ -0,0 +1,147 @@ +# SPDX-FileCopyrightText: 2023 The Pion community +# SPDX-License-Identifier: MIT + +version: "2" +linters: + enable: + - asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers + - bidichk # Checks for dangerous unicode character sequences + - bodyclose # checks whether HTTP response body is closed successfully + - containedctx # containedctx is a linter that detects struct contained context.Context field + - contextcheck # check the function whether use a non-inherited context + - cyclop # checks function and package cyclomatic complexity + - decorder # check declaration order and count of types, constants, variables and functions + - dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f()) + - dupl # Tool for code clone detection + - durationcheck # check for two durations multiplied together + - err113 # Golang linter to check the errors handling expressions + - errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases + - errchkjson # Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occations, where the check for the returned error can be omitted. + - errname # Checks that sentinel errors are prefixed with the `Err` and error types are suffixed with the `Error`. + - errorlint # errorlint is a linter for that can be used to find code that will cause problems with the error wrapping scheme introduced in Go 1.13. + - exhaustive # check exhaustiveness of enum switch statements + - forbidigo # Forbids identifiers + - forcetypeassert # finds forced type assertions + - gochecknoglobals # Checks that no globals are present in Go code + - gocognit # Computes and checks the cognitive complexity of functions + - goconst # Finds repeated strings that could be replaced by a constant + - gocritic # The most opinionated Go source code linter + - gocyclo # Computes and checks the cyclomatic complexity of functions + - godot # Check if comments end in a period + - godox # Tool for detection of FIXME, TODO and other comment keywords + - goheader # Checks is file header matches to pattern + - gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod. + - goprintffuncname # Checks that printf-like functions are named with `f` at the end + - gosec # Inspects source code for security problems + - govet # Vet examines Go source code and reports suspicious constructs, such as Printf calls whose arguments do not align with the format string + - grouper # An analyzer to analyze expression groups. + - importas # Enforces consistent import aliases + - ineffassign # Detects when assignments to existing variables are not used + - lll # Reports long lines + - maintidx # maintidx measures the maintainability index of each function. + - makezero # Finds slice declarations with non-zero initial length + - misspell # Finds commonly misspelled English words in comments + - nakedret # Finds naked returns in functions greater than a specified function length + - nestif # Reports deeply nested if statements + - nilerr # Finds the code that returns nil even if it checks that the error is not nil. + - nilnil # Checks that there is no simultaneous return of `nil` error and an invalid value. + - nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity + - noctx # noctx finds sending http request without context.Context + - predeclared # find code that shadows one of Go's predeclared identifiers + - revive # golint replacement, finds style mistakes + - staticcheck # Staticcheck is a go vet on steroids, applying a ton of static analysis checks + - tagliatelle # Checks the struct tags. + - thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers + - unconvert # Remove unnecessary type conversions + - unparam # Reports unused function parameters + - unused # Checks Go code for unused constants, variables, functions and types + - varnamelen # checks that the length of a variable's name matches its scope + - wastedassign # wastedassign finds wasted assignment statements + - whitespace # Tool for detection of leading and trailing whitespace + disable: + - depguard # Go linter that checks if package imports are in a list of acceptable packages + - funlen # Tool for detection of long functions + - gochecknoinits # Checks that no init functions are present in Go code + - gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations. + - interfacebloat # A linter that checks length of interface. + - ireturn # Accept Interfaces, Return Concrete Types + - mnd # An analyzer to detect magic numbers + - nolintlint # Reports ill-formed or insufficient nolint directives + - paralleltest # paralleltest detects missing usage of t.Parallel() method in your Go test + - prealloc # Finds slice declarations that could potentially be preallocated + - promlinter # Check Prometheus metrics naming via promlint + - rowserrcheck # checks whether Err of rows is checked successfully + - sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed. + - testpackage # linter that makes you use a separate _test package + - tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes + - wrapcheck # Checks that errors returned from external packages are wrapped + - wsl # Whitespace Linter - Forces you to use empty lines! + settings: + staticcheck: + checks: + - all + - -QF1008 # "could remove embedded field", to keep it explicit! + - -QF1003 # "could use tagged switch on enum", Cases conflicts with exhaustive! + exhaustive: + default-signifies-exhaustive: true + forbidigo: + forbid: + - pattern: ^fmt.Print(f|ln)?$ + - pattern: ^log.(Panic|Fatal|Print)(f|ln)?$ + - pattern: ^os.Exit$ + - pattern: ^panic$ + - pattern: ^print(ln)?$ + - pattern: ^testing.T.(Error|Errorf|Fatal|Fatalf|Fail|FailNow)$ + pkg: ^testing$ + msg: use testify/assert instead + analyze-types: true + gomodguard: + blocked: + modules: + - github.com/pkg/errors: + recommendations: + - errors + govet: + enable: + - shadow + revive: + rules: + # Prefer 'any' type alias over 'interface{}' for Go 1.18+ compatibility + - name: use-any + severity: warning + disabled: false + misspell: + locale: US + varnamelen: + max-distance: 12 + min-name-length: 2 + ignore-type-assert-ok: true + ignore-map-index-ok: true + ignore-chan-recv-ok: true + ignore-decls: + - i int + - n int + - w io.Writer + - r io.Reader + - b []byte + exclusions: + generated: lax + rules: + - linters: + - forbidigo + - gocognit + path: (examples|main\.go) + - linters: + - gocognit + path: _test\.go + - linters: + - forbidigo + path: cmd +formatters: + enable: + - gci # Gci control golang package import order and make it always deterministic. + - gofmt # Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification + - gofumpt # Gofumpt checks whether code was gofumpt-ed. + - goimports # Goimports does everything that gofmt does. Additionally it checks unused imports + exclusions: + generated: lax diff --git a/vendor/github.com/pion/turn/v4/.goreleaser.yml b/vendor/github.com/pion/turn/v4/.goreleaser.yml new file mode 100644 index 0000000..30093e9 --- /dev/null +++ b/vendor/github.com/pion/turn/v4/.goreleaser.yml @@ -0,0 +1,5 @@ +# SPDX-FileCopyrightText: 2023 The Pion community +# SPDX-License-Identifier: MIT + +builds: +- skip: true diff --git a/vendor/github.com/pion/turn/v4/LICENSE b/vendor/github.com/pion/turn/v4/LICENSE new file mode 100644 index 0000000..491caf6 --- /dev/null +++ b/vendor/github.com/pion/turn/v4/LICENSE @@ -0,0 +1,9 @@ +MIT License + +Copyright (c) 2023 The Pion community + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/vendor/github.com/pion/turn/v4/README.md b/vendor/github.com/pion/turn/v4/README.md new file mode 100644 index 0000000..9c8d703 --- /dev/null +++ b/vendor/github.com/pion/turn/v4/README.md @@ -0,0 +1,93 @@ +

+ Pion TURN +
+ Pion TURN +
+

+

A toolkit for building TURN clients and servers in Go

+

+ Pion TURN + join us on Discord Follow us on Bluesky + +
+ GitHub Workflow Status + Go Reference + Coverage Status + Go Report Card + License: MIT +

+
+ +Pion TURN is a Go toolkit for building TURN servers and clients. We wrote it to solve problems we had when building RTC projects. + +* **Deployable** - Use modern tooling of the Go ecosystem. Stop generating config files. +* **Embeddable** - Include `pion/turn` in your existing applications. No need to manage another service. +* **Extendable** - TURN as an API so you can easily integrate with your existing monitoring and metrics. +* **Maintainable** - `pion/turn` is simple and well documented. Designed for learning and easy debugging. +* **Portable** - Quickly deploy to multiple architectures/platforms just by setting an environment variable. +* **Safe** - Stability and safety is important for network services. Go provides everything we need. +* **Scalable** - Create allocations and mutate state at runtime. Designed to make scaling easy. + +### Using +`pion/turn` is an API for building STUN/TURN clients and servers, not a binary you deploy then configure. It may require copying our examples and +making minor modifications to fit your need, no knowledge of Go is required however. You may be able to download the pre-made binaries of our examples +if you wish to get started quickly. + +The advantage of this is that you don't need to deal with complicated config files, or custom APIs to modify the state of Pion TURN. +After you instantiate an instance of a Pion TURN server or client you interact with it like any library. The quickest way to get started is to look at the +[examples](examples) or [GoDoc](https://godoc.org/github.com/pion/turn) + +### Examples +We try to cover most common use cases in [examples](examples). If more examples could be helpful please file an issue, we are always looking +to expand and improve `pion/turn` to make it easier for developers. + +To build any example you just need to run `go build` in the directory of the example you care about. +It is also very easy to [cross compile](https://dave.cheney.net/2015/08/22/cross-compilation-with-go-1-5) Go programs. + +You can also see `pion/turn` usage in [pion/ice](https://github.com/pion/ice) + +### FAQ + +Also take a look at the [Pion WebRTC FAQ](https://github.com/pion/webrtc/wiki/FAQ) + +#### Will pion/turn also act as a STUN server? +Yes. + +#### How do I implement token-based authentication? +Replace the username with a token in the [AuthHandler](https://github.com/pion/turn/blob/6d0ff435910870eb9024b18321b93b61844fcfec/examples/turn-server/simple/main.go#L49). +The password sent by the client can be any non-empty string, as long as it matches that used by the [GenerateAuthKey](https://github.com/pion/turn/blob/6d0ff435910870eb9024b18321b93b61844fcfec/examples/turn-server/simple/main.go#L41) +function. + +#### Will WebRTC prioritize using STUN over TURN? +Yes. + +### RFCs +#### Implemented +* **RFC 5389**: [Session Traversal Utilities for NAT (STUN)][rfc5389] +* **RFC 5766**: [Traversal Using Relays around NAT (TURN): Relay Extensions to Session Traversal Utilities for NAT (STUN)][rfc5766] + +#### Planned +* **RFC 6062**: [Traversal Using Relays around NAT (TURN) Extensions for TCP Allocations][rfc6062] +* **RFC 6156**: [Traversal Using Relays around NAT (TURN) Extension for IPv6][rfc6156] + +[rfc5389]: https://tools.ietf.org/html/rfc5389 +[rfc5766]: https://tools.ietf.org/html/rfc5766 +[rfc6062]: https://tools.ietf.org/html/rfc6062 +[rfc6156]: https://tools.ietf.org/html/rfc6156 + +### Roadmap +The library is used as a part of our WebRTC implementation. Please refer to that [roadmap](https://github.com/pion/webrtc/issues/9) to track our major milestones. + +### Community +Pion has an active community on the [Discord](https://discord.gg/PngbdqpFbt). + +Follow the [Pion Bluesky](https://bsky.app/profile/pion.ly) or [Pion Twitter](https://twitter.com/_pion) for project updates and important WebRTC news. + +We are always looking to support **your projects**. Please reach out if you have something to build! +If you need commercial support or don't want to use public methods you can contact us at [team@pion.ly](mailto:team@pion.ly) + +### Contributing +Check out the [contributing wiki](https://github.com/pion/webrtc/wiki/Contributing) to join the group of amazing people making this project possible + +### License +MIT License - see [LICENSE](LICENSE) for full text diff --git a/vendor/github.com/pion/turn/v4/client.go b/vendor/github.com/pion/turn/v4/client.go new file mode 100644 index 0000000..4f2bdb2 --- /dev/null +++ b/vendor/github.com/pion/turn/v4/client.go @@ -0,0 +1,711 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package turn + +import ( + b64 "encoding/base64" + "fmt" + "math" + "net" + "sync" + "time" + + "github.com/pion/logging" + "github.com/pion/stun/v3" + "github.com/pion/transport/v4" + "github.com/pion/transport/v4/stdnet" + "github.com/pion/turn/v4/internal/client" + "github.com/pion/turn/v4/internal/proto" +) + +const ( + defaultRTO = 200 * time.Millisecond + maxRtxCount = 7 // Total 7 requests (Rc) + maxDataBufferSize = math.MaxUint16 // Message size limit for Chromium +) + +// interval [msec] +// 0: 0 ms +500 +// 1: 500 ms +1000 +// 2: 1500 ms +2000 +// 3: 3500 ms +4000 +// 4: 7500 ms +8000 +// 5: 15500 ms +16000 +// 6: 31500 ms +32000 +// -: 63500 ms failed + +// ClientConfig is a bag of config parameters for Client. +type ClientConfig struct { + STUNServerAddr string // STUN server address (e.g. "stun.abc.com:3478") + TURNServerAddr string // TURN server address (e.g. "turn.abc.com:3478") + Username string + Password string + Realm string + Software string + RTO time.Duration + Conn net.PacketConn // Listening socket (net.PacketConn) + Net transport.Net + LoggerFactory logging.LoggerFactory +} + +// Client is a STUN server client. +type Client struct { + conn net.PacketConn // Read-only + net transport.Net // Read-only + stunServerAddr net.Addr // Read-only + turnServerAddr net.Addr // Read-only + + username stun.Username // Read-only + password string // Read-only + realm stun.Realm // Read-only + integrity stun.MessageIntegrity // Read-only + software stun.Software // Read-only + trMap *client.TransactionMap // Thread-safe + rto time.Duration // Read-only + relayedConn *client.UDPConn // Protected by mutex *** + tcpAllocation *client.TCPAllocation // Protected by mutex *** + allocTryLock client.TryLock // Thread-safe + listenTryLock client.TryLock // Thread-safe + mutex sync.RWMutex // Thread-safe + mutexTrMap sync.Mutex // Thread-safe + log logging.LeveledLogger // Read-only +} + +// NewClient returns a new Client instance. listeningAddress is the address and port to listen on, +// default "0.0.0.0:0". +func NewClient(config *ClientConfig) (*Client, error) { + loggerFactory := config.LoggerFactory + if loggerFactory == nil { + loggerFactory = logging.NewDefaultLoggerFactory() + } + + log := loggerFactory.NewLogger("turnc") + + if config.Conn == nil { + return nil, errNilConn + } + + rto := defaultRTO + if config.RTO > 0 { + rto = config.RTO + } + + if config.Net == nil { + n, err := stdnet.NewNet() + if err != nil { + return nil, err + } + config.Net = n + } + + var stunServ, turnServ net.Addr + var err error + + if len(config.STUNServerAddr) > 0 { + stunServ, err = config.Net.ResolveUDPAddr("udp4", config.STUNServerAddr) + if err != nil { + return nil, err + } + + log.Debugf("Resolved STUN server %s to %s", config.STUNServerAddr, stunServ) + } + + if len(config.TURNServerAddr) > 0 { + turnServ, err = config.Net.ResolveUDPAddr("udp4", config.TURNServerAddr) + if err != nil { + return nil, err + } + + log.Debugf("Resolved TURN server %s to %s", config.TURNServerAddr, turnServ) + } + + client := &Client{ + conn: config.Conn, + stunServerAddr: stunServ, + turnServerAddr: turnServ, + username: stun.NewUsername(config.Username), + password: config.Password, + realm: stun.NewRealm(config.Realm), + software: stun.NewSoftware(config.Software), + trMap: client.NewTransactionMap(), + net: config.Net, + rto: rto, + log: log, + } + + return client, nil +} + +// TURNServerAddr return the TURN server address. +func (c *Client) TURNServerAddr() net.Addr { + return c.turnServerAddr +} + +// STUNServerAddr return the STUN server address. +func (c *Client) STUNServerAddr() net.Addr { + return c.stunServerAddr +} + +// Username returns username. +func (c *Client) Username() stun.Username { + return c.username +} + +// Realm return realm. +func (c *Client) Realm() stun.Realm { + return c.realm +} + +// WriteTo sends data to the specified destination using the base socket. +func (c *Client) WriteTo(data []byte, to net.Addr) (int, error) { + return c.conn.WriteTo(data, to) +} + +// Listen will have this client start listening on the conn provided via the config. +// This is optional. If not used, you will need to call HandleInbound method +// to supply incoming data, instead. +func (c *Client) Listen() error { + if err := c.listenTryLock.Lock(); err != nil { + return fmt.Errorf("%w: %s", errAlreadyListening, err.Error()) + } + + go func() { + buf := make([]byte, maxDataBufferSize) + for { + n, from, err := c.conn.ReadFrom(buf) + if err != nil { + c.log.Debugf("Failed to read: %s. Exiting loop", err) + + break + } + + _, err = c.HandleInbound(buf[:n], from) + if err != nil { + c.log.Debugf("Failed to handle inbound message: %s. Exiting loop", err) + + break + } + } + + c.listenTryLock.Unlock() + }() + + return nil +} + +// Close closes this client. +func (c *Client) Close() { + c.mutexTrMap.Lock() + defer c.mutexTrMap.Unlock() + + c.trMap.CloseAndDeleteAll() +} + +// TransactionID & Base64: https://play.golang.org/p/EEgmJDI971P + +// SendBindingRequestTo sends a new STUN request to the given transport address. +func (c *Client) SendBindingRequestTo(to net.Addr) (net.Addr, error) { + attrs := []stun.Setter{stun.TransactionID, stun.BindingRequest} + if len(c.software) > 0 { + attrs = append(attrs, c.software) + } + + msg, err := stun.Build(attrs...) + if err != nil { + return nil, err + } + trRes, err := c.PerformTransaction(msg, to, false) + if err != nil { + return nil, err + } + + var reflAddr stun.XORMappedAddress + if err := reflAddr.GetFrom(trRes.Msg); err != nil { + return nil, err + } + + return &net.UDPAddr{ + IP: reflAddr.IP, + Port: reflAddr.Port, + }, nil +} + +// SendBindingRequest sends a new STUN request to the STUN server. +func (c *Client) SendBindingRequest() (net.Addr, error) { + if c.stunServerAddr == nil { + return nil, errSTUNServerAddressNotSet + } + + return c.SendBindingRequestTo(c.stunServerAddr) +} + +func (c *Client) sendAllocateRequest(protocol proto.Protocol) ( //nolint:cyclop + proto.RelayedAddress, + proto.Lifetime, + stun.Nonce, + error, +) { + var relayed proto.RelayedAddress + var lifetime proto.Lifetime + var nonce stun.Nonce + + msg, err := stun.Build( + stun.TransactionID, + stun.NewType(stun.MethodAllocate, stun.ClassRequest), + proto.RequestedTransport{Protocol: protocol}, + stun.Fingerprint, + ) + if err != nil { + return relayed, lifetime, nonce, err + } + + trRes, err := c.PerformTransaction(msg, c.turnServerAddr, false) + if err != nil { + return relayed, lifetime, nonce, err + } + + res := trRes.Msg + + // Anonymous allocate failed, trying to authenticate. + if err = nonce.GetFrom(res); err != nil { + return relayed, lifetime, nonce, err + } + if err = c.realm.GetFrom(res); err != nil { + return relayed, lifetime, nonce, err + } + c.realm = append([]byte(nil), c.realm...) + c.integrity = stun.NewLongTermIntegrity( + c.username.String(), c.realm.String(), c.password, + ) + // Trying to authorize. + msg, err = stun.Build( + stun.TransactionID, + stun.NewType(stun.MethodAllocate, stun.ClassRequest), + proto.RequestedTransport{Protocol: protocol}, + &c.username, + &c.realm, + &nonce, + &c.integrity, + stun.Fingerprint, + ) + if err != nil { + return relayed, lifetime, nonce, err + } + + trRes, err = c.PerformTransaction(msg, c.turnServerAddr, false) + if err != nil { + return relayed, lifetime, nonce, err + } + res = trRes.Msg + + if res.Type.Class == stun.ClassErrorResponse { + var code stun.ErrorCodeAttribute + if err = code.GetFrom(res); err == nil { + turnError := &stun.TurnError{ + StunMessageType: res.Type, + ErrorCodeAttr: code, + } + + return relayed, lifetime, nonce, turnError + } + + return relayed, lifetime, nonce, fmt.Errorf("%s", res.Type) //nolint:err113 + } + + // Getting relayed addresses from response. + if err := relayed.GetFrom(res); err != nil { + return relayed, lifetime, nonce, err + } + + // Getting lifetime from response + if err := lifetime.GetFrom(res); err != nil { + return relayed, lifetime, nonce, err + } + + return relayed, lifetime, nonce, nil +} + +// Allocate sends a TURN allocation request to the given transport address. +func (c *Client) Allocate() (net.PacketConn, error) { + if err := c.allocTryLock.Lock(); err != nil { + return nil, fmt.Errorf("%w: %s", errOneAllocateOnly, err.Error()) + } + defer c.allocTryLock.Unlock() + + relayedConn := c.relayedUDPConn() + if relayedConn != nil { + return nil, fmt.Errorf("%w: %s", errAlreadyAllocated, relayedConn.LocalAddr().String()) + } + + relayed, lifetime, nonce, err := c.sendAllocateRequest(proto.ProtoUDP) + if err != nil { + return nil, err + } + + relayedAddr := &net.UDPAddr{ + IP: relayed.IP, + Port: relayed.Port, + } + + relayedConn = client.NewUDPConn(&client.AllocationConfig{ + Client: c, + RelayedAddr: relayedAddr, + ServerAddr: c.turnServerAddr, + Realm: c.realm, + Username: c.username, + Integrity: c.integrity, + Nonce: nonce, + Lifetime: lifetime.Duration, + Net: c.net, + Log: c.log, + }) + c.setRelayedUDPConn(relayedConn) + + return relayedConn, nil +} + +// AllocateTCP creates a new TCP allocation at the TURN server. +func (c *Client) AllocateTCP() (*client.TCPAllocation, error) { + if err := c.allocTryLock.Lock(); err != nil { + return nil, fmt.Errorf("%w: %s", errOneAllocateOnly, err.Error()) + } + defer c.allocTryLock.Unlock() + + allocation := c.getTCPAllocation() + if allocation != nil { + return nil, fmt.Errorf("%w: %s", errAlreadyAllocated, allocation.Addr()) + } + + relayed, lifetime, nonce, err := c.sendAllocateRequest(proto.ProtoTCP) + if err != nil { + return nil, err + } + + relayedAddr := &net.TCPAddr{ + IP: relayed.IP, + Port: relayed.Port, + } + + allocation = client.NewTCPAllocation(&client.AllocationConfig{ + Client: c, + RelayedAddr: relayedAddr, + ServerAddr: c.turnServerAddr, + Realm: c.realm, + Username: c.username, + Integrity: c.integrity, + Nonce: nonce, + Lifetime: lifetime.Duration, + Net: c.net, + Log: c.log, + }) + + c.setTCPAllocation(allocation) + + return allocation, nil +} + +// CreatePermission Issues a CreatePermission request for the supplied addresses +// as described in https://datatracker.ietf.org/doc/html/rfc5766#section-9 +func (c *Client) CreatePermission(addrs ...net.Addr) error { + if conn := c.relayedUDPConn(); conn != nil { + if err := conn.CreatePermissions(addrs...); err != nil { + return err + } + } + + if allocation := c.getTCPAllocation(); allocation != nil { + if err := allocation.CreatePermissions(addrs...); err != nil { + return err + } + } + + return nil +} + +// PerformTransaction performs STUN transaction. +func (c *Client) PerformTransaction(msg *stun.Message, to net.Addr, ignoreResult bool) (client.TransactionResult, + error, +) { + trKey := b64.StdEncoding.EncodeToString(msg.TransactionID[:]) + + raw := make([]byte, len(msg.Raw)) + copy(raw, msg.Raw) + + tr := client.NewTransaction(&client.TransactionConfig{ + Key: trKey, + Raw: raw, + To: to, + Interval: c.rto, + IgnoreResult: ignoreResult, + }) + + c.trMap.Insert(trKey, tr) + + c.log.Tracef("Start %s transaction %s to %s", msg.Type, trKey, tr.To) + _, err := c.conn.WriteTo(tr.Raw, to) + if err != nil { + return client.TransactionResult{}, err + } + + tr.StartRtxTimer(c.onRtxTimeout) + + // If ignoreResult is true, get the transaction going and return immediately + if ignoreResult { + return client.TransactionResult{}, nil + } + + res := tr.WaitForResult() + if res.Err != nil { + return res, res.Err + } + + return res, nil +} + +// OnDeallocated is called when de-allocation of relay address has been complete. +// (Called by UDPConn). +func (c *Client) OnDeallocated(net.Addr) { + c.setRelayedUDPConn(nil) + c.setTCPAllocation(nil) +} + +// HandleInbound handles data received. +// This method handles incoming packet de-multiplex it by the source address +// and the types of the message. +// This return a boolean (handled or not) and if there was an error. +// Caller should check if the packet was handled by this client or not. +// If not handled, it is assumed that the packet is application data. +// If an error is returned, the caller should discard the packet regardless. +func (c *Client) HandleInbound(data []byte, from net.Addr) (bool, error) { + // +-------------------+-------------------------------+ + // | Return Values | | + // +-------------------+ Meaning / Action | + // | handled | error | | + // |=========+=========+===============================+ + // | false | nil | Handle the packet as app data | + // |---------+---------+-------------------------------+ + // | true | nil | Nothing to do | + // |---------+---------+-------------------------------+ + // | false | error | (shouldn't happen) | + // |---------+---------+-------------------------------+ + // | true | error | Error occurred while handling | + // +---------+---------+-------------------------------+ + // Possible causes of the error: + // - Malformed packet (parse error) + // - STUN message was a request + // - Non-STUN message from the STUN server + + switch { + case stun.IsMessage(data): + return true, c.handleSTUNMessage(data, from) + case proto.IsChannelData(data): + return true, c.handleChannelData(data) + case c.stunServerAddr != nil && from.String() == c.stunServerAddr.String(): + // Received from STUN server but it is not a STUN message + return true, errNonSTUNMessage + default: + // Assume, this is an application data + c.log.Tracef("Ignoring non-STUN/TURN packet") + } + + return false, nil +} + +func (c *Client) handleSTUNMessage(data []byte, from net.Addr) error { //nolint:cyclop + raw := make([]byte, len(data)) + copy(raw, data) + + msg := &stun.Message{Raw: raw} + if err := msg.Decode(); err != nil { + return fmt.Errorf("%w: %s", errFailedToDecodeSTUN, err.Error()) + } + + if msg.Type.Class == stun.ClassRequest { + return fmt.Errorf("%w : %s", errUnexpectedSTUNRequestMessage, msg.String()) + } + + if msg.Type.Class == stun.ClassIndication { // nolint:nestif + switch msg.Type.Method { + case stun.MethodData: + var peerAddr proto.PeerAddress + if err := peerAddr.GetFrom(msg); err != nil { + return err + } + from = &net.UDPAddr{ + IP: peerAddr.IP, + Port: peerAddr.Port, + } + + var data proto.Data + if err := data.GetFrom(msg); err != nil { + return err + } + + c.log.Tracef("Data indication received from %s", from) + + relayedConn := c.relayedUDPConn() + if relayedConn == nil { + c.log.Debug("No relayed conn allocated") + + return nil // Silently discard + } + relayedConn.HandleInbound(data, from) + case stun.MethodConnectionAttempt: + var peerAddr proto.PeerAddress + if err := peerAddr.GetFrom(msg); err != nil { + return err + } + + addr := &net.TCPAddr{ + IP: peerAddr.IP, + Port: peerAddr.Port, + } + + var cid proto.ConnectionID + if err := cid.GetFrom(msg); err != nil { + return err + } + + c.log.Debugf("Connection attempt from %s", addr) + + allocation := c.getTCPAllocation() + if allocation == nil { + c.log.Debug("No TCP allocation exists") + + return nil // Silently discard + } + + allocation.HandleConnectionAttempt(addr, cid) + default: + c.log.Debug("Received unsupported STUN method") + } + + return nil + } + + // This is a STUN response message (transactional) + // The type is either: + // - stun.ClassSuccessResponse + // - stun.ClassErrorResponse + + trKey := b64.StdEncoding.EncodeToString(msg.TransactionID[:]) + + c.mutexTrMap.Lock() + tr, ok := c.trMap.Find(trKey) + if !ok { + c.mutexTrMap.Unlock() + // Silently discard + c.log.Debugf("No transaction for %s", msg) + + return nil + } + + // End the transaction + tr.StopRtxTimer() + c.trMap.Delete(trKey) + c.mutexTrMap.Unlock() + + if !tr.WriteResult(client.TransactionResult{ + Msg: msg, + From: from, + Retries: tr.Retries(), + }) { + c.log.Debugf("No listener for %s", msg) + } + + return nil +} + +func (c *Client) handleChannelData(data []byte) error { + chData := &proto.ChannelData{ + Raw: make([]byte, len(data)), + } + copy(chData.Raw, data) + if err := chData.Decode(); err != nil { + return err + } + + relayedConn := c.relayedUDPConn() + if relayedConn == nil { + c.log.Debug("No relayed conn allocated") + + return nil // Silently discard + } + + addr, ok := relayedConn.FindAddrByChannelNumber(uint16(chData.Number)) + if !ok { + return fmt.Errorf("%w: %d", errChannelBindNotFound, int(chData.Number)) + } + + c.log.Tracef("Channel data received from %s (ch=%d)", addr.String(), int(chData.Number)) + + relayedConn.HandleInbound(chData.Data, addr) + + return nil +} + +func (c *Client) onRtxTimeout(trKey string, nRtx int) { + c.mutexTrMap.Lock() + defer c.mutexTrMap.Unlock() + + tr, ok := c.trMap.Find(trKey) + if !ok { + return // Already gone + } + + if nRtx == maxRtxCount { + // All retransmissions failed + c.trMap.Delete(trKey) + if !tr.WriteResult(client.TransactionResult{ + Err: fmt.Errorf("%w %s", errAllRetransmissionsFailed, trKey), + }) { + c.log.Debug("No listener for transaction") + } + + return + } + + c.log.Tracef("Retransmitting transaction %s to %s (nRtx=%d)", + trKey, tr.To, nRtx) + _, err := c.conn.WriteTo(tr.Raw, tr.To) + if err != nil { + c.trMap.Delete(trKey) + if !tr.WriteResult(client.TransactionResult{ + Err: fmt.Errorf("%w %s", errFailedToRetransmitTransaction, trKey), + }) { + c.log.Debug("No listener for transaction") + } + + return + } + tr.StartRtxTimer(c.onRtxTimeout) +} + +func (c *Client) setRelayedUDPConn(conn *client.UDPConn) { + c.mutex.Lock() + defer c.mutex.Unlock() + + c.relayedConn = conn +} + +func (c *Client) relayedUDPConn() *client.UDPConn { + c.mutex.RLock() + defer c.mutex.RUnlock() + + return c.relayedConn +} + +func (c *Client) setTCPAllocation(alloc *client.TCPAllocation) { + c.mutex.Lock() + defer c.mutex.Unlock() + + c.tcpAllocation = alloc +} + +func (c *Client) getTCPAllocation() *client.TCPAllocation { + c.mutex.RLock() + defer c.mutex.RUnlock() + + return c.tcpAllocation +} diff --git a/vendor/github.com/pion/turn/v4/codecov.yml b/vendor/github.com/pion/turn/v4/codecov.yml new file mode 100644 index 0000000..263e4d4 --- /dev/null +++ b/vendor/github.com/pion/turn/v4/codecov.yml @@ -0,0 +1,22 @@ +# +# DO NOT EDIT THIS FILE +# +# It is automatically copied from https://github.com/pion/.goassets repository. +# +# SPDX-FileCopyrightText: 2023 The Pion community +# SPDX-License-Identifier: MIT + +coverage: + status: + project: + default: + # Allow decreasing 2% of total coverage to avoid noise. + threshold: 2% + patch: + default: + target: 70% + only_pulls: true + +ignore: + - "examples/*" + - "examples/**/*" diff --git a/vendor/github.com/pion/turn/v4/errors.go b/vendor/github.com/pion/turn/v4/errors.go new file mode 100644 index 0000000..3ebd26a --- /dev/null +++ b/vendor/github.com/pion/turn/v4/errors.go @@ -0,0 +1,32 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package turn + +import "errors" + +var ( + errRelayAddressInvalid = errors.New("turn: RelayAddress must be valid IP to use RelayAddressGeneratorStatic") + errNoAvailableConns = errors.New("turn: PacketConnConfigs and ConnConfigs are empty, unable to proceed") + errConnUnset = errors.New("turn: PacketConnConfig must have a non-nil Conn") + errListenerUnset = errors.New("turn: ListenerConfig must have a non-nil Listener") + errListeningAddressInvalid = errors.New("turn: RelayAddressGenerator has invalid ListeningAddress") + errRelayAddressGeneratorUnset = errors.New("turn: RelayAddressGenerator in RelayConfig is unset") + errMaxRetriesExceeded = errors.New("turn: max retries exceeded") + errMaxPortNotZero = errors.New("turn: MaxPort must be not 0") + errMinPortNotZero = errors.New("turn: MaxPort must be not 0") + errNilConn = errors.New("turn: conn cannot not be nil") + errTODO = errors.New("turn: TODO") + errAlreadyListening = errors.New("turn: already listening") + errFailedToClose = errors.New("turn: Server failed to close") + errFailedToRetransmitTransaction = errors.New("turn: failed to retransmit transaction") + errAllRetransmissionsFailed = errors.New("all retransmissions failed for") + errChannelBindNotFound = errors.New("no binding found for channel") + errSTUNServerAddressNotSet = errors.New("STUN server address is not set for the client") + errOneAllocateOnly = errors.New("only one Allocate() caller is allowed") + errAlreadyAllocated = errors.New("already allocated") + errNonSTUNMessage = errors.New("non-STUN message from STUN server") + errFailedToDecodeSTUN = errors.New("failed to decode STUN message") + errUnexpectedSTUNRequestMessage = errors.New("unexpected STUN request message") + errRelayAddressGeneratorNil = errors.New("RelayAddressGenerator is nil") +) diff --git a/vendor/github.com/pion/turn/v4/internal/allocation/allocation.go b/vendor/github.com/pion/turn/v4/internal/allocation/allocation.go new file mode 100644 index 0000000..8fb4752 --- /dev/null +++ b/vendor/github.com/pion/turn/v4/internal/allocation/allocation.go @@ -0,0 +1,355 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +// Package allocation contains all CRUD operations for allocations +package allocation + +import ( + "net" + "sync" + "sync/atomic" + "time" + + "github.com/pion/logging" + "github.com/pion/stun/v3" + "github.com/pion/turn/v4/internal/ipnet" + "github.com/pion/turn/v4/internal/proto" +) + +type allocationResponse struct { + transactionID [stun.TransactionIDSize]byte + responseAttrs []stun.Setter +} + +// Allocation is tied to a FiveTuple and relays traffic +// use CreateAllocation and GetAllocation to operate. +type Allocation struct { + RelayAddr net.Addr + Protocol Protocol + TurnSocket net.PacketConn + RelaySocket net.PacketConn + fiveTuple *FiveTuple + permissionsLock sync.RWMutex + permissions map[string]*Permission + channelBindingsLock sync.RWMutex + channelBindings []*ChannelBind + lifetimeTimer *time.Timer + closed chan any + username, realm string + eventHandler EventHandler + log logging.LeveledLogger + + // Some clients (Firefox or others using resiprocate's nICE lib) may retry allocation + // with same 5 tuple when received 413, for compatible with these clients, + // cache for response lost and client retry to implement 'stateless stack approach' + // See: https://datatracker.ietf.org/doc/html/rfc5766#section-6.2 + responseCache atomic.Value // *allocationResponse +} + +// NewAllocation creates a new instance of NewAllocation. +func NewAllocation( + turnSocket net.PacketConn, + fiveTuple *FiveTuple, + eventHandler EventHandler, + log logging.LeveledLogger, +) *Allocation { + return &Allocation{ + TurnSocket: turnSocket, + fiveTuple: fiveTuple, + permissions: make(map[string]*Permission, 64), + closed: make(chan any), + eventHandler: eventHandler, + log: log, + } +} + +// GetPermission gets the Permission from the allocation. +func (a *Allocation) GetPermission(addr net.Addr) *Permission { + a.permissionsLock.RLock() + defer a.permissionsLock.RUnlock() + + return a.permissions[ipnet.FingerprintAddr(addr)] +} + +// AddPermission adds a new permission to the allocation. +func (a *Allocation) AddPermission(perms *Permission) { + fingerprint := ipnet.FingerprintAddr(perms.Addr) + + a.permissionsLock.RLock() + existedPermission, ok := a.permissions[fingerprint] + a.permissionsLock.RUnlock() + + if ok { + existedPermission.refresh(permissionTimeout) + + return + } + + perms.allocation = a + a.permissionsLock.Lock() + a.permissions[fingerprint] = perms + a.permissionsLock.Unlock() + + if a.eventHandler.OnPermissionCreated != nil { + if u, ok := perms.Addr.(*net.UDPAddr); ok { + a.eventHandler.OnPermissionCreated(a.fiveTuple.SrcAddr, a.fiveTuple.DstAddr, + a.fiveTuple.Protocol.String(), a.username, a.realm, + a.RelayAddr, u.IP) + } + } + + perms.start(permissionTimeout) +} + +// RemovePermission removes the net.Addr's fingerprint from the allocation's permissions. +func (a *Allocation) RemovePermission(addr net.Addr) { + a.permissionsLock.Lock() + defer a.permissionsLock.Unlock() + delete(a.permissions, ipnet.FingerprintAddr(addr)) + + if a.eventHandler.OnPermissionDeleted != nil { + if u, ok := addr.(*net.UDPAddr); ok { + a.eventHandler.OnPermissionDeleted(a.fiveTuple.SrcAddr, a.fiveTuple.DstAddr, + a.fiveTuple.Protocol.String(), a.username, a.realm, + a.RelayAddr, u.IP) + } + } +} + +// ListPermissions returns the permissions associated with an allocation. +func (a *Allocation) ListPermissions() []*Permission { + ps := []*Permission{} + a.permissionsLock.RLock() + defer a.permissionsLock.RUnlock() + for _, p := range a.permissions { + ps = append(ps, p) + } + + return ps +} + +// AddChannelBind adds a new ChannelBind to the allocation, it also updates the +// permissions needed for this ChannelBind. +func (a *Allocation) AddChannelBind(chanBind *ChannelBind, lifetime time.Duration) error { + // Check that this channel id isn't bound to another transport address, and + // that this transport address isn't bound to another channel number. + channelByNumber := a.GetChannelByNumber(chanBind.Number) + + if channelByNumber != a.GetChannelByAddr(chanBind.Peer) { + return errSameChannelDifferentPeer + } + + // Add or refresh this channel. + if channelByNumber == nil { + a.channelBindingsLock.Lock() + defer a.channelBindingsLock.Unlock() + + chanBind.allocation = a + a.channelBindings = append(a.channelBindings, chanBind) + chanBind.start(lifetime) + + // Channel binds also refresh permissions. + a.AddPermission(NewPermission(chanBind.Peer, a.log)) + + if a.eventHandler.OnChannelCreated != nil { + a.eventHandler.OnChannelCreated(a.fiveTuple.SrcAddr, a.fiveTuple.DstAddr, + a.fiveTuple.Protocol.String(), a.username, a.realm, + a.RelayAddr, chanBind.Peer, uint16(chanBind.Number)) + } + } else { + channelByNumber.refresh(lifetime) + + // Channel binds also refresh permissions. + a.AddPermission(NewPermission(channelByNumber.Peer, a.log)) + } + + return nil +} + +// RemoveChannelBind removes the ChannelBind from this allocation by id. +func (a *Allocation) RemoveChannelBind(number proto.ChannelNumber) bool { + a.channelBindingsLock.Lock() + defer a.channelBindingsLock.Unlock() + + for i := len(a.channelBindings) - 1; i >= 0; i-- { + if a.channelBindings[i].Number == number { + if a.eventHandler.OnChannelDeleted != nil { + a.eventHandler.OnChannelDeleted(a.fiveTuple.SrcAddr, a.fiveTuple.DstAddr, + a.fiveTuple.Protocol.String(), a.username, a.realm, + a.RelayAddr, a.channelBindings[i].Peer, uint16(a.channelBindings[i].Number)) + } + + a.channelBindings = append(a.channelBindings[:i], a.channelBindings[i+1:]...) + + return true + } + } + + return false +} + +// GetChannelByNumber gets the ChannelBind from this allocation by id. +func (a *Allocation) GetChannelByNumber(number proto.ChannelNumber) *ChannelBind { + a.channelBindingsLock.RLock() + defer a.channelBindingsLock.RUnlock() + for _, cb := range a.channelBindings { + if cb.Number == number { + return cb + } + } + + return nil +} + +// GetChannelByAddr gets the ChannelBind from this allocation by net.Addr. +func (a *Allocation) GetChannelByAddr(addr net.Addr) *ChannelBind { + a.channelBindingsLock.RLock() + defer a.channelBindingsLock.RUnlock() + for _, cb := range a.channelBindings { + if ipnet.AddrEqual(cb.Peer, addr) { + return cb + } + } + + return nil +} + +// ListChannelBindings returns the channel bindings associated with an allocation. +func (a *Allocation) ListChannelBindings() []*ChannelBind { + cs := []*ChannelBind{} + a.channelBindingsLock.RLock() + defer a.channelBindingsLock.RUnlock() + cs = append(cs, a.channelBindings...) + + return cs +} + +// Refresh updates the allocations lifetime. +func (a *Allocation) Refresh(lifetime time.Duration) { + if !a.lifetimeTimer.Reset(lifetime) { + a.log.Errorf("Failed to reset allocation timer for %v", a.fiveTuple) + } +} + +// SetResponseCache cache allocation response for retransmit allocation request. +func (a *Allocation) SetResponseCache(transactionID [stun.TransactionIDSize]byte, attrs []stun.Setter) { + a.responseCache.Store(&allocationResponse{ + transactionID: transactionID, + responseAttrs: attrs, + }) +} + +// GetResponseCache return response cache for retransmit allocation request. +func (a *Allocation) GetResponseCache() (id [stun.TransactionIDSize]byte, attrs []stun.Setter) { + if res, ok := a.responseCache.Load().(*allocationResponse); ok && res != nil { + id, attrs = res.transactionID, res.responseAttrs + } + + return +} + +// Close closes the allocation. +func (a *Allocation) Close() error { + select { + case <-a.closed: + return nil + default: + } + close(a.closed) + + a.lifetimeTimer.Stop() + + for _, p := range a.ListPermissions() { + a.RemovePermission(p.Addr) + p.lifetimeTimer.Stop() + } + + for _, c := range a.ListChannelBindings() { + a.RemoveChannelBind(c.Number) + c.lifetimeTimer.Stop() + } + + return a.RelaySocket.Close() +} + +// https://tools.ietf.org/html/rfc5766#section-10.3 +// When the server receives a UDP datagram at a currently allocated +// relayed transport address, the server looks up the allocation +// associated with the relayed transport address. The server then +// checks to see whether the set of permissions for the allocation allow +// the relaying of the UDP datagram as described in Section 8. +// +// If relaying is permitted, then the server checks if there is a +// channel bound to the peer that sent the UDP datagram (see +// Section 11). If a channel is bound, then processing proceeds as +// described in Section 11.7. +// +// If relaying is permitted but no channel is bound to the peer, then +// the server forms and sends a Data indication. The Data indication +// MUST contain both an XOR-PEER-ADDRESS and a DATA attribute. The DATA +// attribute is set to the value of the 'data octets' field from the +// datagram, and the XOR-PEER-ADDRESS attribute is set to the source +// transport address of the received UDP datagram. The Data indication +// is then sent on the 5-tuple associated with the allocation. + +const rtpMTU = 1600 + +func (a *Allocation) packetHandler(manager *Manager) { + buffer := make([]byte, rtpMTU) + + for { + n, srcAddr, err := a.RelaySocket.ReadFrom(buffer) + if err != nil { + manager.DeleteAllocation(a.fiveTuple) + + return + } + + a.log.Debugf("Relay socket %s received %d bytes from %s", + a.RelaySocket.LocalAddr(), + n, + srcAddr) + + if channel := a.GetChannelByAddr(srcAddr); channel != nil { // nolint:nestif + channelData := &proto.ChannelData{ + Data: buffer[:n], + Number: channel.Number, + } + channelData.Encode() + + if _, err = a.TurnSocket.WriteTo(channelData.Raw, a.fiveTuple.SrcAddr); err != nil { + a.log.Errorf("Failed to send ChannelData from allocation %v %v", srcAddr, err) + } + } else if p := a.GetPermission(srcAddr); p != nil { + udpAddr, ok := srcAddr.(*net.UDPAddr) + if !ok { + a.log.Errorf("Failed to send DataIndication from allocation %v %v", srcAddr, err) + + return + } + + peerAddressAttr := proto.PeerAddress{IP: udpAddr.IP, Port: udpAddr.Port} + dataAttr := proto.Data(buffer[:n]) + + msg, err := stun.Build( + stun.TransactionID, + stun.NewType(stun.MethodData, stun.ClassIndication), + peerAddressAttr, + dataAttr, + ) + if err != nil { + a.log.Errorf("Failed to send DataIndication from allocation %v %v", srcAddr, err) + + return + } + a.log.Debugf("Relaying message from %s to client at %s", + srcAddr, + a.fiveTuple.SrcAddr) + if _, err = a.TurnSocket.WriteTo(msg.Raw, a.fiveTuple.SrcAddr); err != nil { + a.log.Errorf("Failed to send DataIndication from allocation %v %v", srcAddr, err) + } + } else { + a.log.Infof("No Permission or Channel exists for %v on allocation %v", srcAddr, a.RelayAddr) + } + } +} diff --git a/vendor/github.com/pion/turn/v4/internal/allocation/allocation_manager.go b/vendor/github.com/pion/turn/v4/internal/allocation/allocation_manager.go new file mode 100644 index 0000000..02d11d0 --- /dev/null +++ b/vendor/github.com/pion/turn/v4/internal/allocation/allocation_manager.go @@ -0,0 +1,246 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package allocation + +import ( + "fmt" + "net" + "sync" + "time" + + "github.com/pion/logging" +) + +// ManagerConfig a bag of config params for Manager. +type ManagerConfig struct { + LeveledLogger logging.LeveledLogger + AllocatePacketConn func(network string, requestedPort int) (net.PacketConn, net.Addr, error) + AllocateConn func(network string, requestedPort int) (net.Conn, net.Addr, error) + PermissionHandler func(sourceAddr net.Addr, peerIP net.IP) bool + EventHandler EventHandler +} + +type reservation struct { + token string + port int +} + +// Manager is used to hold active allocations. +type Manager struct { + lock sync.RWMutex + log logging.LeveledLogger + + allocations map[FiveTupleFingerprint]*Allocation + reservations []*reservation + + allocatePacketConn func(network string, requestedPort int) (net.PacketConn, net.Addr, error) + allocateConn func(network string, requestedPort int) (net.Conn, net.Addr, error) + permissionHandler func(sourceAddr net.Addr, peerIP net.IP) bool + EventHandler EventHandler +} + +// NewManager creates a new instance of Manager. +func NewManager(config ManagerConfig) (*Manager, error) { + switch { + case config.AllocatePacketConn == nil: + return nil, errAllocatePacketConnMustBeSet + case config.AllocateConn == nil: + return nil, errAllocateConnMustBeSet + case config.LeveledLogger == nil: + return nil, errLeveledLoggerMustBeSet + } + + return &Manager{ + log: config.LeveledLogger, + allocations: make(map[FiveTupleFingerprint]*Allocation, 64), + allocatePacketConn: config.AllocatePacketConn, + allocateConn: config.AllocateConn, + permissionHandler: config.PermissionHandler, + EventHandler: config.EventHandler, + }, nil +} + +// GetAllocation fetches the allocation matching the passed FiveTuple. +func (m *Manager) GetAllocation(fiveTuple *FiveTuple) *Allocation { + m.lock.RLock() + defer m.lock.RUnlock() + + return m.allocations[fiveTuple.Fingerprint()] +} + +// AllocationCount returns the number of existing allocations. +func (m *Manager) AllocationCount() int { + m.lock.RLock() + defer m.lock.RUnlock() + + return len(m.allocations) +} + +// Close closes the manager and closes all allocations it manages. +func (m *Manager) Close() error { + m.lock.Lock() + defer m.lock.Unlock() + + for _, a := range m.allocations { + if err := a.Close(); err != nil { + return err + } + } + + return nil +} + +// CreateAllocation creates a new allocation and starts relaying. +func (m *Manager) CreateAllocation( + fiveTuple *FiveTuple, + turnSocket net.PacketConn, + requestedPort int, + lifetime time.Duration, + username, realm string, +) (*Allocation, error) { + switch { + case fiveTuple == nil: + return nil, errNilFiveTuple + case fiveTuple.SrcAddr == nil: + return nil, errNilFiveTupleSrcAddr + case fiveTuple.DstAddr == nil: + return nil, errNilFiveTupleDstAddr + case turnSocket == nil: + return nil, errNilTurnSocket + case lifetime == 0: + return nil, errLifetimeZero + } + + if alloc := m.GetAllocation(fiveTuple); alloc != nil { + return nil, fmt.Errorf("%w: %v", errDupeFiveTuple, fiveTuple) + } + alloc := NewAllocation(turnSocket, fiveTuple, m.EventHandler, m.log) + alloc.username = username + alloc.realm = realm + + conn, relayAddr, err := m.allocatePacketConn("udp4", requestedPort) + if err != nil { + return nil, err + } + + alloc.RelaySocket = conn + alloc.RelayAddr = relayAddr + + m.log.Debugf("Listening on relay address: %s", alloc.RelayAddr) + + alloc.lifetimeTimer = time.AfterFunc(lifetime, func() { + m.DeleteAllocation(alloc.fiveTuple) + }) + + m.lock.Lock() + m.allocations[fiveTuple.Fingerprint()] = alloc + m.lock.Unlock() + + if m.EventHandler.OnAllocationCreated != nil { + m.EventHandler.OnAllocationCreated(fiveTuple.SrcAddr, fiveTuple.DstAddr, + fiveTuple.Protocol.String(), username, realm, relayAddr, requestedPort) + } + + go alloc.packetHandler(m) + + return alloc, nil +} + +// DeleteAllocation removes an allocation. +func (m *Manager) DeleteAllocation(fiveTuple *FiveTuple) { + fingerprint := fiveTuple.Fingerprint() + + m.lock.Lock() + allocation := m.allocations[fingerprint] + delete(m.allocations, fingerprint) + m.lock.Unlock() + + if allocation == nil { + return + } + + if err := allocation.Close(); err != nil { + m.log.Errorf("Failed to close allocation: %v", err) + } + + if m.EventHandler.OnAllocationDeleted != nil { + m.EventHandler.OnAllocationDeleted(fiveTuple.SrcAddr, fiveTuple.DstAddr, + fiveTuple.Protocol.String(), allocation.username, allocation.realm) + } +} + +// CreateReservation stores the reservation for the token+port. +func (m *Manager) CreateReservation(reservationToken string, port int) { + time.AfterFunc(30*time.Second, func() { + m.lock.Lock() + defer m.lock.Unlock() + for i := len(m.reservations) - 1; i >= 0; i-- { + if m.reservations[i].token == reservationToken { + m.reservations = append(m.reservations[:i], m.reservations[i+1:]...) + + return + } + } + }) + + m.lock.Lock() + m.reservations = append(m.reservations, &reservation{ + token: reservationToken, + port: port, + }) + m.lock.Unlock() +} + +// GetReservation returns the port for a given reservation if it exists. +func (m *Manager) GetReservation(reservationToken string) (int, bool) { + m.lock.RLock() + defer m.lock.RUnlock() + + for _, r := range m.reservations { + if r.token == reservationToken { + return r.port, true + } + } + + return 0, false +} + +// GetRandomEvenPort returns a random un-allocated udp4 port. +func (m *Manager) GetRandomEvenPort() (int, error) { + for i := 0; i < 128; i++ { + conn, addr, err := m.allocatePacketConn("udp4", 0) + if err != nil { + return 0, err + } + udpAddr, ok := addr.(*net.UDPAddr) + err = conn.Close() + if err != nil { + return 0, err + } + + if !ok { + return 0, errFailedToCastUDPAddr + } + if udpAddr.Port%2 == 0 { + return udpAddr.Port, nil + } + } + + return 0, errFailedToAllocateEvenPort +} + +// GrantPermission handles permission requests by calling the permission handler callback +// associated with the TURN server listener socket. +func (m *Manager) GrantPermission(sourceAddr net.Addr, peerIP net.IP) error { + // No permission handler: open + if m.permissionHandler == nil { + return nil + } + + if m.permissionHandler(sourceAddr, peerIP) { + return nil + } + + return errAdminProhibited +} diff --git a/vendor/github.com/pion/turn/v4/internal/allocation/channel_bind.go b/vendor/github.com/pion/turn/v4/internal/allocation/channel_bind.go new file mode 100644 index 0000000..6ad9b46 --- /dev/null +++ b/vendor/github.com/pion/turn/v4/internal/allocation/channel_bind.go @@ -0,0 +1,46 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package allocation + +import ( + "net" + "time" + + "github.com/pion/logging" + "github.com/pion/turn/v4/internal/proto" +) + +// ChannelBind represents a TURN Channel +// See: https://tools.ietf.org/html/rfc5766#section-2.5 +type ChannelBind struct { + Peer net.Addr + Number proto.ChannelNumber + + allocation *Allocation + lifetimeTimer *time.Timer + log logging.LeveledLogger +} + +// NewChannelBind creates a new ChannelBind. +func NewChannelBind(number proto.ChannelNumber, peer net.Addr, log logging.LeveledLogger) *ChannelBind { + return &ChannelBind{ + Number: number, + Peer: peer, + log: log, + } +} + +func (c *ChannelBind) start(lifetime time.Duration) { + c.lifetimeTimer = time.AfterFunc(lifetime, func() { + if !c.allocation.RemoveChannelBind(c.Number) { + c.log.Errorf("Failed to remove ChannelBind for %v %x %v", c.Number, c.Peer, c.allocation.fiveTuple) + } + }) +} + +func (c *ChannelBind) refresh(lifetime time.Duration) { + if !c.lifetimeTimer.Reset(lifetime) { + c.log.Errorf("Failed to reset ChannelBind timer for %v %x %v", c.Number, c.Peer, c.allocation.fiveTuple) + } +} diff --git a/vendor/github.com/pion/turn/v4/internal/allocation/errors.go b/vendor/github.com/pion/turn/v4/internal/allocation/errors.go new file mode 100644 index 0000000..584073c --- /dev/null +++ b/vendor/github.com/pion/turn/v4/internal/allocation/errors.go @@ -0,0 +1,22 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package allocation + +import "errors" + +var ( + errAllocatePacketConnMustBeSet = errors.New("AllocatePacketConn must be set") + errAllocateConnMustBeSet = errors.New("AllocateConn must be set") + errLeveledLoggerMustBeSet = errors.New("LeveledLogger must be set") + errSameChannelDifferentPeer = errors.New("you cannot use the same channel number with different peer") + errNilFiveTuple = errors.New("allocations must not be created with nil FivTuple") + errNilFiveTupleSrcAddr = errors.New("allocations must not be created with nil FiveTuple.SrcAddr") + errNilFiveTupleDstAddr = errors.New("allocations must not be created with nil FiveTuple.DstAddr") + errNilTurnSocket = errors.New("allocations must not be created with nil turnSocket") + errLifetimeZero = errors.New("allocations must not be created with a lifetime of 0") + errDupeFiveTuple = errors.New("allocation attempt created with duplicate FiveTuple") + errFailedToCastUDPAddr = errors.New("failed to cast net.Addr to *net.UDPAddr") + errFailedToAllocateEvenPort = errors.New("failed to allocate an even port") + errAdminProhibited = errors.New("permission request administratively prohibited") +) diff --git a/vendor/github.com/pion/turn/v4/internal/allocation/event_handler.go b/vendor/github.com/pion/turn/v4/internal/allocation/event_handler.go new file mode 100644 index 0000000..1d9bd3f --- /dev/null +++ b/vendor/github.com/pion/turn/v4/internal/allocation/event_handler.go @@ -0,0 +1,47 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package allocation + +import ( + "net" +) + +// EventHandler is a set of callbacks that the server will call at certain hook points during an +// allocation's lifecycle. All events are reported with the context that identifies the allocation +// triggering the event (source and destination address, protocol, username and realm used for +// authenticating the allocation), plus additional callback specific parameters. It is OK to handle +// only a subset of the callbacks. +type EventHandler struct { + // OnAuth is called after an authentication request has been processed with the TURN method + // triggering the authentication request (either "Allocate", "Refresh" "CreatePermission", + // or "ChannelBind"), and the verdict is the authentication result. + OnAuth func(srcAddr, dstAddr net.Addr, protocol, username, realm string, method string, verdict bool) + // OnAllocationCreated is called after a new allocation has been made. The relayAddr + // argument specifies the relay address and requestedPort is the port requested by the + // client (if any). + OnAllocationCreated func(srcAddr, dstAddr net.Addr, protocol, username, realm string, + relayAddr net.Addr, requestedPort int) + // OnAllocationDeleted is called after an allocation has been removed. + OnAllocationDeleted func(srcAddr, dstAddr net.Addr, protocol, username, realm string) + // OnAllocationError is called when the readloop hdndling an allocation exits with an + // error with an error message. + OnAllocationError func(srcAddr, dstAddr net.Addr, protocol, message string) + // OnPermissionCreated is called after a new permission has been made to an IP address. + OnPermissionCreated func(srcAddr, dstAddr net.Addr, protocol, username, realm string, + relayAddr net.Addr, peer net.IP) + // OnPermissionDeleted is called after a permission for a given IP address has been + // removed. + OnPermissionDeleted func(srcAddr, dstAddr net.Addr, protocol, username, realm string, + relayAddr net.Addr, peer net.IP) + // OnChannelCreated is called after a new channel has been made. The relay address, the + // peer address and the channel number can be used to uniquely identify the channel + // created. + OnChannelCreated func(srcAddr, dstAddr net.Addr, protocol, username, realm string, + relayAddr, peer net.Addr, channelNumber uint16) + // OnChannelDeleted is called after a channel has been removed from the server. The relay + // address, the peer address and the channel number can be used to uniquely identify the + // channel deleted. + OnChannelDeleted func(srcAddr, dstAddr net.Addr, protocol, username, realm string, + relayAddr, peer net.Addr, channelNumber uint16) +} diff --git a/vendor/github.com/pion/turn/v4/internal/allocation/five_tuple.go b/vendor/github.com/pion/turn/v4/internal/allocation/five_tuple.go new file mode 100644 index 0000000..b9eba87 --- /dev/null +++ b/vendor/github.com/pion/turn/v4/internal/allocation/five_tuple.go @@ -0,0 +1,75 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package allocation + +import ( + "net" +) + +// Protocol is an enum for relay protocol. +type Protocol uint8 + +// Network protocols for relay. +const ( + UDP Protocol = iota + TCP +) + +func (p Protocol) String() string { + switch p { + case UDP: + return "UDP" + case TCP: + return "TCP" + default: + return "" + } +} + +// FiveTuple is the combination (client IP address and port, server IP +// address and port, and transport protocol (currently one of UDP, +// TCP, or TLS)) used to communicate between the client and the +// server. The 5-tuple uniquely identifies this communication +// stream. The 5-tuple also uniquely identifies the Allocation on +// the server. +type FiveTuple struct { + Protocol + SrcAddr, DstAddr net.Addr +} + +// Equal asserts if two FiveTuples are equal. +func (f *FiveTuple) Equal(b *FiveTuple) bool { + return f.Fingerprint() == b.Fingerprint() +} + +// FiveTupleFingerprint is a comparable representation of a FiveTuple. +type FiveTupleFingerprint struct { + srcIP, dstIP [16]byte + srcPort, dstPort uint16 + protocol Protocol +} + +// Fingerprint is the identity of a FiveTuple. +func (f *FiveTuple) Fingerprint() (fp FiveTupleFingerprint) { + srcIP, srcPort := netAddrIPAndPort(f.SrcAddr) + copy(fp.srcIP[:], srcIP) + fp.srcPort = srcPort + dstIP, dstPort := netAddrIPAndPort(f.DstAddr) + copy(fp.dstIP[:], dstIP) + fp.dstPort = dstPort + fp.protocol = f.Protocol + + return +} + +func netAddrIPAndPort(addr net.Addr) (net.IP, uint16) { + switch a := addr.(type) { + case *net.UDPAddr: + return a.IP.To16(), uint16(a.Port) // nolint:gosec // G115 + case *net.TCPAddr: + return a.IP.To16(), uint16(a.Port) // nolint:gosec // G115 + default: + return nil, 0 + } +} diff --git a/vendor/github.com/pion/turn/v4/internal/allocation/permission.go b/vendor/github.com/pion/turn/v4/internal/allocation/permission.go new file mode 100644 index 0000000..7b02adc --- /dev/null +++ b/vendor/github.com/pion/turn/v4/internal/allocation/permission.go @@ -0,0 +1,43 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package allocation + +import ( + "net" + "time" + + "github.com/pion/logging" +) + +const permissionTimeout = time.Duration(5) * time.Minute + +// Permission represents a TURN permission. TURN permissions mimic the address-restricted +// filtering mechanism of NATs that comply with [RFC4787]. +// See: https://tools.ietf.org/html/rfc5766#section-2.3 +type Permission struct { + Addr net.Addr + allocation *Allocation + lifetimeTimer *time.Timer + log logging.LeveledLogger +} + +// NewPermission create a new Permission. +func NewPermission(addr net.Addr, log logging.LeveledLogger) *Permission { + return &Permission{ + Addr: addr, + log: log, + } +} + +func (p *Permission) start(lifetime time.Duration) { + p.lifetimeTimer = time.AfterFunc(lifetime, func() { + p.allocation.RemovePermission(p.Addr) + }) +} + +func (p *Permission) refresh(lifetime time.Duration) { + if !p.lifetimeTimer.Reset(lifetime) { + p.log.Errorf("Failed to reset permission timer for %v %v", p.Addr, p.allocation.fiveTuple) + } +} diff --git a/vendor/github.com/pion/turn/v4/internal/client/allocation.go b/vendor/github.com/pion/turn/v4/internal/client/allocation.go new file mode 100644 index 0000000..3fccb57 --- /dev/null +++ b/vendor/github.com/pion/turn/v4/internal/client/allocation.go @@ -0,0 +1,197 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package client + +import ( + "errors" + "fmt" + "net" + "sync" + "time" + + "github.com/pion/logging" + "github.com/pion/stun/v3" + "github.com/pion/transport/v4" + "github.com/pion/turn/v4/internal/proto" +) + +// AllocationConfig is a set of configuration params use by NewUDPConn and NewTCPAllocation. +type AllocationConfig struct { + Client Client + RelayedAddr net.Addr + ServerAddr net.Addr + Integrity stun.MessageIntegrity + Nonce stun.Nonce + Username stun.Username + Realm stun.Realm + Lifetime time.Duration + Net transport.Net + Log logging.LeveledLogger +} + +type allocation struct { + client Client // Read-only + relayedAddr net.Addr // Read-only + serverAddr net.Addr // Read-only + permMap *permissionMap // Thread-safe + integrity stun.MessageIntegrity // Read-only + username stun.Username // Read-only + realm stun.Realm // Read-only + _nonce stun.Nonce // Needs mutex x + _lifetime time.Duration // Needs mutex x + net transport.Net // Thread-safe + refreshAllocTimer *PeriodicTimer // Thread-safe + refreshPermsTimer *PeriodicTimer // Thread-safe + readTimer *time.Timer // Thread-safe + mutex sync.RWMutex // Thread-safe + log logging.LeveledLogger // Read-only +} + +func (a *allocation) setNonceFromMsg(msg *stun.Message) { + // Update nonce + var nonce stun.Nonce + if err := nonce.GetFrom(msg); err == nil { + a.setNonce(nonce) + a.log.Debug("Refresh allocation: 438, got new nonce.") + } else { + a.log.Warn("Refresh allocation: 438 but no nonce.") + } +} + +func (a *allocation) refreshAllocation(lifetime time.Duration, dontWait bool) error { + msg, err := stun.Build( + stun.TransactionID, + stun.NewType(stun.MethodRefresh, stun.ClassRequest), + proto.Lifetime{Duration: lifetime}, + a.username, + a.realm, + a.nonce(), + a.integrity, + stun.Fingerprint, + ) + if err != nil { + return fmt.Errorf("%w: %s", errFailedToBuildRefreshRequest, err.Error()) + } + + a.log.Debugf("Send refresh request (dontWait=%v)", dontWait) + trRes, err := a.client.PerformTransaction(msg, a.serverAddr, dontWait) + if err != nil { + return fmt.Errorf("%w: %s", errFailedToRefreshAllocation, err.Error()) + } + + if dontWait { + a.log.Debug("Refresh request sent") + + return nil + } + + a.log.Debug("Refresh request sent, and waiting response") + + res := trRes.Msg + if res.Type.Class == stun.ClassErrorResponse { + var code stun.ErrorCodeAttribute + if err = code.GetFrom(res); err == nil { + if code.Code == stun.CodeStaleNonce { + a.setNonceFromMsg(res) + + return errTryAgain + } + + return err + } + + return fmt.Errorf("%s", res.Type) //nolint:err113 + } + + // Getting lifetime from response + var updatedLifetime proto.Lifetime + if err := updatedLifetime.GetFrom(res); err != nil { + return fmt.Errorf("%w: %s", errFailedToGetLifetime, err.Error()) + } + + a.setLifetime(updatedLifetime.Duration) + a.log.Debugf("Updated lifetime: %d seconds", int(a.lifetime().Seconds())) + + return nil +} + +func (a *allocation) refreshPermissions() error { + addrs := a.permMap.addrs() + if len(addrs) == 0 { + a.log.Debug("No permission to refresh") + + return nil + } + if err := a.CreatePermissions(addrs...); err != nil { + if errors.Is(err, errTryAgain) { + return errTryAgain + } + a.log.Errorf("Fail to refresh permissions: %s", err) + + return err + } + a.log.Debug("Refresh permissions successful") + + return nil +} + +func (a *allocation) onRefreshTimers(id int) { + a.log.Debugf("Refresh timer %d expired", id) + switch id { + case timerIDRefreshAlloc: + var err error + lifetime := a.lifetime() + // Limit the max retries on errTryAgain to 3 + // when stale nonce returns, sencond retry should succeed + for i := 0; i < maxRetryAttempts; i++ { + err = a.refreshAllocation(lifetime, false) + if !errors.Is(err, errTryAgain) { + break + } + } + if err != nil { + a.log.Warnf("Failed to refresh allocation: %s", err) + } + case timerIDRefreshPerms: + var err error + for i := 0; i < maxRetryAttempts; i++ { + err = a.refreshPermissions() + if !errors.Is(err, errTryAgain) { + break + } + } + if err != nil { + a.log.Warnf("Failed to refresh permissions: %s", err) + } + } +} + +func (a *allocation) nonce() stun.Nonce { + a.mutex.RLock() + defer a.mutex.RUnlock() + + return a._nonce +} + +func (a *allocation) setNonce(nonce stun.Nonce) { + a.mutex.Lock() + defer a.mutex.Unlock() + + a.log.Debugf("Set new nonce with %d bytes", len(nonce)) + a._nonce = nonce +} + +func (a *allocation) lifetime() time.Duration { + a.mutex.RLock() + defer a.mutex.RUnlock() + + return a._lifetime +} + +func (a *allocation) setLifetime(lifetime time.Duration) { + a.mutex.Lock() + defer a.mutex.Unlock() + + a._lifetime = lifetime +} diff --git a/vendor/github.com/pion/turn/v4/internal/client/binding.go b/vendor/github.com/pion/turn/v4/internal/client/binding.go new file mode 100644 index 0000000..0720d03 --- /dev/null +++ b/vendor/github.com/pion/turn/v4/internal/client/binding.go @@ -0,0 +1,179 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package client + +import ( + "net" + "sync" + "sync/atomic" + "time" +) + +// Channel number: +// +// 0x4000 through 0x7FFF: These values are the allowed channel +// numbers (16,383 possible values). +const ( + minChannelNumber uint16 = 0x4000 + maxChannelNumber uint16 = 0x7fff +) + +type bindingState int32 + +const ( + bindingStateIdle bindingState = iota + bindingStateRequest + bindingStateReady + bindingStateRefresh + bindingStateFailed +) + +type binding struct { + number uint16 // Read-only + st bindingState // Thread-safe (atomic op) + addr net.Addr // Read-only + mgr *bindingManager // Read-only + muBind sync.Mutex // Thread-safe, for ChannelBind ops + _refreshedAt time.Time // Protected by mutex + mutex sync.RWMutex // Thread-safe +} + +func (b *binding) setState(state bindingState) { + atomic.StoreInt32((*int32)(&b.st), int32(state)) +} + +func (b *binding) state() bindingState { + return bindingState(atomic.LoadInt32((*int32)(&b.st))) +} + +func (b *binding) setRefreshedAt(at time.Time) { + b.mutex.Lock() + defer b.mutex.Unlock() + + b._refreshedAt = at +} + +func (b *binding) refreshedAt() time.Time { + b.mutex.RLock() + defer b.mutex.RUnlock() + + return b._refreshedAt +} + +func (b *binding) ok() bool { + state := b.state() + + return state == bindingStateReady || state == bindingStateRefresh +} + +// Thread-safe binding map. +type bindingManager struct { + chanMap map[uint16]*binding + addrMap map[string]*binding + next uint16 + mutex sync.RWMutex +} + +func newBindingManager() *bindingManager { + return &bindingManager{ + chanMap: map[uint16]*binding{}, + addrMap: map[string]*binding{}, + next: minChannelNumber, + } +} + +func (mgr *bindingManager) assignChannelNumber() uint16 { + n := mgr.next + if mgr.next == maxChannelNumber { + mgr.next = minChannelNumber + } else { + mgr.next++ + } + + return n +} + +func (mgr *bindingManager) create(addr net.Addr) *binding { + mgr.mutex.Lock() + defer mgr.mutex.Unlock() + + b := &binding{ + number: mgr.assignChannelNumber(), + addr: addr, + mgr: mgr, + _refreshedAt: time.Now(), + } + + mgr.chanMap[b.number] = b + mgr.addrMap[b.addr.String()] = b + + return b +} + +func (mgr *bindingManager) findByAddr(addr net.Addr) (*binding, bool) { + mgr.mutex.RLock() + defer mgr.mutex.RUnlock() + + b, ok := mgr.addrMap[addr.String()] + + return b, ok +} + +func (mgr *bindingManager) findByNumber(number uint16) (*binding, bool) { + mgr.mutex.RLock() + defer mgr.mutex.RUnlock() + + b, ok := mgr.chanMap[number] + + return b, ok +} + +func (mgr *bindingManager) deleteByAddr(addr net.Addr) bool { + mgr.mutex.Lock() + defer mgr.mutex.Unlock() + + b, ok := mgr.addrMap[addr.String()] + if !ok { + return false + } + + delete(mgr.addrMap, addr.String()) + delete(mgr.chanMap, b.number) + + return true +} + +func (mgr *bindingManager) deleteByNumber(number uint16) bool { + mgr.mutex.Lock() + defer mgr.mutex.Unlock() + + b, ok := mgr.chanMap[number] + if !ok { + return false + } + + delete(mgr.addrMap, b.addr.String()) + delete(mgr.chanMap, number) + + return true +} + +func (mgr *bindingManager) size() int { + mgr.mutex.RLock() + defer mgr.mutex.RUnlock() + + return len(mgr.chanMap) +} + +func (mgr *bindingManager) all() []*binding { + mgr.mutex.RLock() + defer mgr.mutex.RUnlock() + + list := make([]*binding, 0, len(mgr.chanMap)) + for _, b := range mgr.chanMap { + list = append(list, b) + } + + return list +} diff --git a/vendor/github.com/pion/turn/v4/internal/client/client.go b/vendor/github.com/pion/turn/v4/internal/client/client.go new file mode 100644 index 0000000..45041c9 --- /dev/null +++ b/vendor/github.com/pion/turn/v4/internal/client/client.go @@ -0,0 +1,18 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +// Package client implements the API for a TURN client +package client + +import ( + "net" + + "github.com/pion/stun/v3" +) + +// Client is an interface for the public turn.Client in order to break cyclic dependencies. +type Client interface { + WriteTo(data []byte, to net.Addr) (int, error) + PerformTransaction(msg *stun.Message, to net.Addr, dontWait bool) (TransactionResult, error) + OnDeallocated(relayedAddr net.Addr) +} diff --git a/vendor/github.com/pion/turn/v4/internal/client/errors.go b/vendor/github.com/pion/turn/v4/internal/client/errors.go new file mode 100644 index 0000000..ff40d2e --- /dev/null +++ b/vendor/github.com/pion/turn/v4/internal/client/errors.go @@ -0,0 +1,43 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package client + +import ( + "errors" +) + +var ( + errFake = errors.New("fake error") + errTryAgain = errors.New("try again") + errClosed = errors.New("use of closed network connection") + errTCPAddrCast = errors.New("addr is not a TCP address") + errUDPAddrCast = errors.New("addr is not a UDP address") + errAlreadyClosed = errors.New("already closed") + errDoubleLock = errors.New("try-lock is already locked") + errTransactionClosed = errors.New("transaction closed") + errWaitForResultOnNonResultTransaction = errors.New("WaitForResult called on non-result transaction") + errFailedToBuildRefreshRequest = errors.New("failed to build refresh request") + errFailedToRefreshAllocation = errors.New("failed to refresh allocation") + errFailedToGetLifetime = errors.New("failed to get lifetime from refresh response") + errInvalidTURNAddress = errors.New("invalid TURN server address") + errUnexpectedSTUNRequestMessage = errors.New("unexpected STUN request message") +) + +type timeoutError struct { + msg string +} + +func newTimeoutError(msg string) error { + return &timeoutError{ + msg: msg, + } +} + +func (e *timeoutError) Error() string { + return e.msg +} + +func (e *timeoutError) Timeout() bool { + return true +} diff --git a/vendor/github.com/pion/turn/v4/internal/client/periodic_timer.go b/vendor/github.com/pion/turn/v4/internal/client/periodic_timer.go new file mode 100644 index 0000000..9e94912 --- /dev/null +++ b/vendor/github.com/pion/turn/v4/internal/client/periodic_timer.go @@ -0,0 +1,85 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package client + +import ( + "sync" + "time" +) + +// PeriodicTimerTimeoutHandler is a handler called on timeout. +type PeriodicTimerTimeoutHandler func(timerID int) + +// PeriodicTimer is a periodic timer. +type PeriodicTimer struct { + id int + interval time.Duration + timeoutHandler PeriodicTimerTimeoutHandler + stopFunc func() + mutex sync.RWMutex +} + +// NewPeriodicTimer create a new timer. +func NewPeriodicTimer(id int, timeoutHandler PeriodicTimerTimeoutHandler, interval time.Duration) *PeriodicTimer { + return &PeriodicTimer{ + id: id, + interval: interval, + timeoutHandler: timeoutHandler, + } +} + +// Start starts the timer. +func (t *PeriodicTimer) Start() bool { + t.mutex.Lock() + defer t.mutex.Unlock() + + // This is a noop if the timer is always running + if t.stopFunc != nil { + return false + } + + cancelCh := make(chan struct{}) + + go func() { + canceling := false + + for !canceling { + timer := time.NewTimer(t.interval) + + select { + case <-timer.C: + t.timeoutHandler(t.id) + case <-cancelCh: + canceling = true + timer.Stop() + } + } + }() + + t.stopFunc = func() { + close(cancelCh) + } + + return true +} + +// Stop stops the timer. +func (t *PeriodicTimer) Stop() { + t.mutex.Lock() + defer t.mutex.Unlock() + + if t.stopFunc != nil { + t.stopFunc() + t.stopFunc = nil + } +} + +// IsRunning tests if the timer is running. +// Debug purpose only. +func (t *PeriodicTimer) IsRunning() bool { + t.mutex.RLock() + defer t.mutex.RUnlock() + + return (t.stopFunc != nil) +} diff --git a/vendor/github.com/pion/turn/v4/internal/client/permission.go b/vendor/github.com/pion/turn/v4/internal/client/permission.go new file mode 100644 index 0000000..d6708d8 --- /dev/null +++ b/vendor/github.com/pion/turn/v4/internal/client/permission.go @@ -0,0 +1,80 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package client + +import ( + "net" + "sync" + "sync/atomic" + + "github.com/pion/turn/v4/internal/ipnet" +) + +type permState int32 + +const ( + permStateIdle permState = iota + permStatePermitted +) + +type permission struct { + addr net.Addr + st permState // Thread-safe (atomic op) + mutex sync.RWMutex // Thread-safe +} + +func (p *permission) setState(state permState) { + atomic.StoreInt32((*int32)(&p.st), int32(state)) +} + +func (p *permission) state() permState { + return permState(atomic.LoadInt32((*int32)(&p.st))) +} + +// Thread-safe permission map. +type permissionMap struct { + permMap map[string]*permission + mutex sync.RWMutex +} + +func (m *permissionMap) insert(addr net.Addr, p *permission) bool { + m.mutex.Lock() + defer m.mutex.Unlock() + p.addr = addr + m.permMap[ipnet.FingerprintAddr(addr)] = p + + return true +} + +func (m *permissionMap) find(addr net.Addr) (*permission, bool) { + m.mutex.RLock() + defer m.mutex.RUnlock() + p, ok := m.permMap[ipnet.FingerprintAddr(addr)] + + return p, ok +} + +func (m *permissionMap) delete(addr net.Addr) { + m.mutex.Lock() + defer m.mutex.Unlock() + delete(m.permMap, ipnet.FingerprintAddr(addr)) +} + +func (m *permissionMap) addrs() []net.Addr { + m.mutex.RLock() + defer m.mutex.RUnlock() + + addrs := []net.Addr{} + for _, p := range m.permMap { + addrs = append(addrs, p.addr) + } + + return addrs +} + +func newPermissionMap() *permissionMap { + return &permissionMap{ + permMap: map[string]*permission{}, + } +} diff --git a/vendor/github.com/pion/turn/v4/internal/client/tcp_alloc.go b/vendor/github.com/pion/turn/v4/internal/client/tcp_alloc.go new file mode 100644 index 0000000..8ada3ca --- /dev/null +++ b/vendor/github.com/pion/turn/v4/internal/client/tcp_alloc.go @@ -0,0 +1,381 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package client + +import ( + "encoding/binary" + "errors" + "fmt" + "math" + "net" + "time" + + "github.com/pion/stun/v3" + "github.com/pion/transport/v4" + "github.com/pion/turn/v4/internal/proto" +) + +var ( + _ transport.TCPListener = (*TCPAllocation)(nil) // Includes type check for net.Listener + _ transport.Dialer = (*TCPAllocation)(nil) +) + +func noDeadline() time.Time { + return time.Time{} +} + +// TCPAllocation is an active TCP allocation on the TURN server +// as specified by RFC 6062. +// The allocation can be used to Dial/Accept relayed outgoing/incoming TCP connections. +type TCPAllocation struct { + connAttemptCh chan *connectionAttempt + acceptTimer *time.Timer + allocation +} + +// NewTCPAllocation creates a new instance of TCPConn. +func NewTCPAllocation(config *AllocationConfig) *TCPAllocation { + alloc := &TCPAllocation{ + connAttemptCh: make(chan *connectionAttempt, 10), + acceptTimer: time.NewTimer(time.Duration(math.MaxInt64)), + allocation: allocation{ + client: config.Client, + relayedAddr: config.RelayedAddr, + serverAddr: config.ServerAddr, + username: config.Username, + realm: config.Realm, + permMap: newPermissionMap(), + integrity: config.Integrity, + _nonce: config.Nonce, + _lifetime: config.Lifetime, + net: config.Net, + log: config.Log, + }, + } + + alloc.log.Debugf("Initial lifetime: %d seconds", int(alloc.lifetime().Seconds())) + + alloc.refreshAllocTimer = NewPeriodicTimer( + timerIDRefreshAlloc, + alloc.onRefreshTimers, + alloc.lifetime()/2, + ) + + alloc.refreshPermsTimer = NewPeriodicTimer( + timerIDRefreshPerms, + alloc.onRefreshTimers, + permRefreshInterval, + ) + + if alloc.refreshAllocTimer.Start() { + alloc.log.Debug("Started refreshAllocTimer") + } + if alloc.refreshPermsTimer.Start() { + alloc.log.Debug("Started refreshPermsTimer") + } + + return alloc +} + +// Connect sends a Connect request to the turn server and returns a chosen connection ID. +func (a *TCPAllocation) Connect(peer net.Addr) (proto.ConnectionID, error) { + setters := []stun.Setter{ + stun.TransactionID, + stun.NewType(stun.MethodConnect, stun.ClassRequest), + addr2PeerAddress(peer), + a.username, + a.realm, + a.nonce(), + a.integrity, + stun.Fingerprint, + } + + msg, err := stun.Build(setters...) + if err != nil { + return 0, err + } + + a.log.Debugf("Send connect request (peer=%v)", peer) + trRes, err := a.client.PerformTransaction(msg, a.serverAddr, false) + if err != nil { + return 0, err + } + + res := trRes.Msg + + if res.Type.Class == stun.ClassErrorResponse { + var code stun.ErrorCodeAttribute + if err = code.GetFrom(res); err == nil { + return 0, fmt.Errorf("%s (error %s)", res.Type, code) //nolint // dynamic errors + } + + return 0, fmt.Errorf("%s", res.Type) //nolint // dynamic errors + } + + var cid proto.ConnectionID + if err := cid.GetFrom(res); err != nil { + return 0, err + } + + a.log.Debugf("Connect request successful (cid=%v)", cid) + + return cid, nil +} + +// Dial connects to the address on the named network. +func (a *TCPAllocation) Dial(network, rAddrStr string) (net.Conn, error) { + rAddr, err := net.ResolveTCPAddr(network, rAddrStr) + if err != nil { + return nil, err + } + + return a.DialTCP(network, nil, rAddr) +} + +// DialWithConn connects to the address on the named network with an already existing connection. +// The provided connection must be an already connected TCP connection to the TURN server. +func (a *TCPAllocation) DialWithConn(conn net.Conn, network, rAddrStr string) (*TCPConn, error) { + rAddr, err := net.ResolveTCPAddr(network, rAddrStr) + if err != nil { + return nil, err + } + + return a.DialTCPWithConn(conn, network, rAddr) +} + +// DialTCP acts like Dial for TCP networks. +func (a *TCPAllocation) DialTCP(network string, lAddr, rAddr *net.TCPAddr) (*TCPConn, error) { + var rAddrServer *net.TCPAddr + if addr, ok := a.serverAddr.(*net.TCPAddr); ok { + rAddrServer = &net.TCPAddr{ + IP: addr.IP, + Port: addr.Port, + } + } else if addr, ok := a.serverAddr.(*net.UDPAddr); ok { + rAddrServer = &net.TCPAddr{ + IP: addr.IP, + Port: addr.Port, + } + } else { + return nil, errInvalidTURNAddress + } + + conn, err := a.net.DialTCP(network, lAddr, rAddrServer) + if err != nil { + return nil, err + } + + dataConn, err := a.DialTCPWithConn(conn, network, rAddr) + if err != nil { + conn.Close() //nolint:errcheck,gosec + } + + return dataConn, err +} + +// DialTCPWithConn acts like DialWithConn for TCP networks. +func (a *TCPAllocation) DialTCPWithConn(conn net.Conn, _ string, rAddr *net.TCPAddr) (*TCPConn, error) { + var err error + + // Check if we have a permission for the destination IP addr + perm, ok := a.permMap.find(rAddr) + if !ok { + perm = &permission{} + a.permMap.insert(rAddr, perm) + } + + for i := 0; i < maxRetryAttempts; i++ { + if err = a.createPermission(perm, rAddr); !errors.Is(err, errTryAgain) { + break + } + } + if err != nil { + return nil, err + } + + // Send connect request if haven't done so. + cid, err := a.Connect(rAddr) + if err != nil { + return nil, err + } + + tcpConn, ok := conn.(transport.TCPConn) + if !ok { + return nil, errTCPAddrCast + } + + dataConn := &TCPConn{ + TCPConn: tcpConn, + ConnectionID: cid, + remoteAddress: rAddr, + allocation: a, + } + + if err := a.BindConnection(dataConn, cid); err != nil { + return nil, fmt.Errorf("failed to bind connection: %w", err) + } + + return dataConn, nil +} + +// BindConnection associates the provided connection. +func (a *TCPAllocation) BindConnection(dataConn *TCPConn, cid proto.ConnectionID) error { //nolint:cyclop + msg, err := stun.Build( + stun.TransactionID, + stun.NewType(stun.MethodConnectionBind, stun.ClassRequest), + cid, + a.username, + a.realm, + a.nonce(), + a.integrity, + stun.Fingerprint, + ) + if err != nil { + return err + } + + a.log.Debugf("Send connectionBind request (cid=%v)", cid) + + _, err = dataConn.Write(msg.Raw) + if err != nil { + return err + } + + // Read exactly one STUN message, any data after belongs to the user + b := make([]byte, stunHeaderSize) + n, err := dataConn.Read(b) + if n != stunHeaderSize { + return errIncompleteTURNFrame + } else if err != nil { + return err + } + + if !stun.IsMessage(b) { + return errInvalidTURNFrame + } + + datagramSize := binary.BigEndian.Uint16(b[2:4]) + stunHeaderSize + raw := make([]byte, datagramSize) + copy(raw, b) + _, err = dataConn.Read(raw[stunHeaderSize:]) + if err != nil { + return err + } + res := &stun.Message{Raw: raw} + if err = res.Decode(); err != nil { + return fmt.Errorf("failed to decode STUN message: %w", err) + } + + switch res.Type.Class { + case stun.ClassErrorResponse: + var code stun.ErrorCodeAttribute + if err = code.GetFrom(res); err == nil { + return fmt.Errorf("%s (error %s)", res.Type, code) //nolint // dynamic errors + } + + return fmt.Errorf("%s", res.Type) //nolint // dynamic errors + case stun.ClassSuccessResponse: + a.log.Debug("Successful connectionBind request") + + return nil + default: + return fmt.Errorf("%w: %s", errUnexpectedSTUNRequestMessage, res.String()) + } +} + +// Accept waits for and returns the next connection to the listener. +func (a *TCPAllocation) Accept() (net.Conn, error) { + return a.AcceptTCP() +} + +// AcceptTCP accepts the next incoming call and returns the new connection. +func (a *TCPAllocation) AcceptTCP() (transport.TCPConn, error) { + addr, err := net.ResolveTCPAddr("tcp4", a.serverAddr.String()) + if err != nil { + return nil, err + } + + tcpConn, err := a.net.DialTCP("tcp", nil, addr) + if err != nil { + return nil, err + } + + dataConn, err := a.AcceptTCPWithConn(tcpConn) + if err != nil { + tcpConn.Close() //nolint:errcheck,gosec + } + + return dataConn, err +} + +// AcceptTCPWithConn accepts the next incoming call and returns the new connection. +func (a *TCPAllocation) AcceptTCPWithConn(conn net.Conn) (*TCPConn, error) { + select { + case attempt := <-a.connAttemptCh: + + tcpConn, ok := conn.(transport.TCPConn) + if !ok { + return nil, errTCPAddrCast + } + + dataConn := &TCPConn{ + TCPConn: tcpConn, + ConnectionID: attempt.cid, + remoteAddress: attempt.from, + allocation: a, + } + + if err := a.BindConnection(dataConn, attempt.cid); err != nil { + return nil, fmt.Errorf("failed to bind connection: %w", err) + } + + return dataConn, nil + case <-a.acceptTimer.C: + return nil, &net.OpError{ + Op: "accept", + Net: a.Addr().Network(), + Addr: a.Addr(), + Err: newTimeoutError("i/o timeout"), + } + } +} + +// SetDeadline sets the deadline associated with the listener. A zero time value disables the deadline. +func (a *TCPAllocation) SetDeadline(t time.Time) error { + var d time.Duration + if t.Equal(noDeadline()) { + d = time.Duration(math.MaxInt64) + } else { + d = time.Until(t) + } + a.acceptTimer.Reset(d) + + return nil +} + +// Close releases the allocation +// Any blocked Accept operations will be unblocked and return errors. +// Any opened connection via Dial/Accept will be closed. +func (a *TCPAllocation) Close() error { + a.refreshAllocTimer.Stop() + a.refreshPermsTimer.Stop() + + a.client.OnDeallocated(a.relayedAddr) + + return a.refreshAllocation(0, true /* dontWait=true */) +} + +// Addr returns the relayed address of the allocation. +func (a *TCPAllocation) Addr() net.Addr { + return a.relayedAddr +} + +// HandleConnectionAttempt is called by the TURN client +// when it receives a ConnectionAttempt indication. +func (a *TCPAllocation) HandleConnectionAttempt(from *net.TCPAddr, cid proto.ConnectionID) { + a.connAttemptCh <- &connectionAttempt{ + from: from, + cid: cid, + } +} diff --git a/vendor/github.com/pion/turn/v4/internal/client/tcp_conn.go b/vendor/github.com/pion/turn/v4/internal/client/tcp_conn.go new file mode 100644 index 0000000..ae2985e --- /dev/null +++ b/vendor/github.com/pion/turn/v4/internal/client/tcp_conn.go @@ -0,0 +1,49 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package client + +import ( + "errors" + "net" + + "github.com/pion/transport/v4" + "github.com/pion/turn/v4/internal/proto" +) + +var ( + errInvalidTURNFrame = errors.New("data is not a valid TURN frame, no STUN or ChannelData found") + errIncompleteTURNFrame = errors.New("data contains incomplete STUN or TURN frame") +) + +const ( + stunHeaderSize = 20 +) + +var _ transport.TCPConn = (*TCPConn)(nil) // Includes type check for net.Conn + +// TCPConn wraps a transport.TCPConn and returns the allocations relayed +// transport address in response to TCPConn.LocalAddress(). +type TCPConn struct { + transport.TCPConn + remoteAddress *net.TCPAddr + allocation *TCPAllocation + ConnectionID proto.ConnectionID +} + +type connectionAttempt struct { + from *net.TCPAddr + cid proto.ConnectionID +} + +// LocalAddr returns the local network address. +// The Addr returned is shared by all invocations of LocalAddr, so do not modify it. +func (c *TCPConn) LocalAddr() net.Addr { + return c.allocation.Addr() +} + +// RemoteAddr returns the remote network address. +// The Addr returned is shared by all invocations of RemoteAddr, so do not modify it. +func (c *TCPConn) RemoteAddr() net.Addr { + return c.remoteAddress +} diff --git a/vendor/github.com/pion/turn/v4/internal/client/transaction.go b/vendor/github.com/pion/turn/v4/internal/client/transaction.go new file mode 100644 index 0000000..744df66 --- /dev/null +++ b/vendor/github.com/pion/turn/v4/internal/client/transaction.go @@ -0,0 +1,191 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package client + +import ( + "net" + "sync" + "time" + + "github.com/pion/stun/v3" +) + +const ( + maxRtxInterval time.Duration = 1600 * time.Millisecond +) + +// TransactionResult is a bag of result values of a transaction. +type TransactionResult struct { + Msg *stun.Message + From net.Addr + Retries int + Err error +} + +// TransactionConfig is a set of config params used by NewTransaction. +type TransactionConfig struct { + Key string + Raw []byte + To net.Addr + Interval time.Duration + IgnoreResult bool // True to throw away the result of this transaction (it will not be readable using WaitForResult) +} + +// Transaction represents a transaction. +type Transaction struct { + Key string // Read-only + Raw []byte // Read-only + To net.Addr // Read-only + nRtx int // Modified only by the timer thread + interval time.Duration // Modified only by the timer thread + timer *time.Timer // Thread-safe, set only by the creator, and stopper + resultCh chan TransactionResult // Thread-safe + mutex sync.RWMutex +} + +// NewTransaction creates a new instance of Transaction. +func NewTransaction(config *TransactionConfig) *Transaction { + var resultCh chan TransactionResult + if !config.IgnoreResult { + resultCh = make(chan TransactionResult) + } + + return &Transaction{ + Key: config.Key, // Read-only + Raw: config.Raw, // Read-only + To: config.To, // Read-only + interval: config.Interval, // Modified only by the timer thread + resultCh: resultCh, // Thread-safe + } +} + +// StartRtxTimer starts the transaction timer. +func (t *Transaction) StartRtxTimer(onTimeout func(trKey string, nRtx int)) { + t.mutex.Lock() + defer t.mutex.Unlock() + + t.timer = time.AfterFunc(t.interval, func() { + t.mutex.Lock() + t.nRtx++ + nRtx := t.nRtx + t.interval *= 2 + if t.interval > maxRtxInterval { + t.interval = maxRtxInterval + } + t.mutex.Unlock() + onTimeout(t.Key, nRtx) + }) +} + +// StopRtxTimer stop the transaction timer. +func (t *Transaction) StopRtxTimer() { + t.mutex.Lock() + defer t.mutex.Unlock() + + if t.timer != nil { + t.timer.Stop() + } +} + +// WriteResult writes the result to the result channel. +func (t *Transaction) WriteResult(res TransactionResult) bool { + if t.resultCh == nil { + return false + } + + t.resultCh <- res + + return true +} + +// WaitForResult waits for the transaction result. +func (t *Transaction) WaitForResult() TransactionResult { + if t.resultCh == nil { + return TransactionResult{ + Err: errWaitForResultOnNonResultTransaction, + } + } + + result, ok := <-t.resultCh + if !ok { + result.Err = errTransactionClosed + } + + return result +} + +// Close closes the transaction. +func (t *Transaction) Close() { + if t.resultCh != nil { + close(t.resultCh) + } +} + +// Retries returns the number of retransmission it has made. +func (t *Transaction) Retries() int { + t.mutex.RLock() + defer t.mutex.RUnlock() + + return t.nRtx +} + +// TransactionMap is a thread-safe transaction map. +type TransactionMap struct { + trMap map[string]*Transaction + mutex sync.RWMutex +} + +// NewTransactionMap create a new instance of the transaction map. +func NewTransactionMap() *TransactionMap { + return &TransactionMap{ + trMap: map[string]*Transaction{}, + } +} + +// Insert inserts a transaction to the map. +func (m *TransactionMap) Insert(key string, tr *Transaction) bool { + m.mutex.Lock() + defer m.mutex.Unlock() + + m.trMap[key] = tr + + return true +} + +// Find looks up a transaction by its key. +func (m *TransactionMap) Find(key string) (*Transaction, bool) { + m.mutex.RLock() + defer m.mutex.RUnlock() + + tr, ok := m.trMap[key] + + return tr, ok +} + +// Delete deletes a transaction by its key. +func (m *TransactionMap) Delete(key string) { + m.mutex.Lock() + defer m.mutex.Unlock() + + delete(m.trMap, key) +} + +// CloseAndDeleteAll closes and deletes all transactions. +func (m *TransactionMap) CloseAndDeleteAll() { + m.mutex.Lock() + defer m.mutex.Unlock() + + for trKey, tr := range m.trMap { + tr.Close() + delete(m.trMap, trKey) + } +} + +// Size returns the length of the transaction map. +func (m *TransactionMap) Size() int { + m.mutex.RLock() + defer m.mutex.RUnlock() + + return len(m.trMap) +} diff --git a/vendor/github.com/pion/turn/v4/internal/client/trylock.go b/vendor/github.com/pion/turn/v4/internal/client/trylock.go new file mode 100644 index 0000000..8380b0c --- /dev/null +++ b/vendor/github.com/pion/turn/v4/internal/client/trylock.go @@ -0,0 +1,28 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package client + +import ( + "sync/atomic" +) + +// TryLock implement the classic "try-lock" operation. +type TryLock struct { + n int32 +} + +// Lock tries to lock the try-lock. If successful, it returns true. +// Otherwise, it returns false immediately. +func (c *TryLock) Lock() error { + if !atomic.CompareAndSwapInt32(&c.n, 0, 1) { + return errDoubleLock + } + + return nil +} + +// Unlock unlocks the try-lock. +func (c *TryLock) Unlock() { + atomic.StoreInt32(&c.n, 0) +} diff --git a/vendor/github.com/pion/turn/v4/internal/client/udp_conn.go b/vendor/github.com/pion/turn/v4/internal/client/udp_conn.go new file mode 100644 index 0000000..7f5e1d8 --- /dev/null +++ b/vendor/github.com/pion/turn/v4/internal/client/udp_conn.go @@ -0,0 +1,491 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +// Package client implements the API for a TURN client +package client + +import ( + "errors" + "fmt" + "io" + "math" + "net" + "time" + + "github.com/pion/stun/v3" + "github.com/pion/turn/v4/internal/proto" +) + +const ( + maxReadQueueSize = 1024 + permRefreshInterval = 120 * time.Second + bindingRefreshInterval = 5 * time.Minute + bindingCheckInterval = 30 * time.Second + maxRetryAttempts = 3 +) + +const ( + timerIDRefreshAlloc int = iota + timerIDRefreshPerms + timerIDCheckBindings +) + +type inboundData struct { + data []byte + from net.Addr +} + +// UDPConn is the implementation of the Conn and PacketConn interfaces for UDP network connections. +// compatible with net.PacketConn and net.Conn. +type UDPConn struct { + bindingMgr *bindingManager // Thread-safe + checkBindingsTimer *PeriodicTimer // Thread-safe + readCh chan *inboundData // Thread-safe + closeCh chan struct{} // Thread-safe + allocation +} + +// NewUDPConn creates a new instance of UDPConn. +func NewUDPConn(config *AllocationConfig) *UDPConn { + conn := &UDPConn{ + bindingMgr: newBindingManager(), + readCh: make(chan *inboundData, maxReadQueueSize), + closeCh: make(chan struct{}), + allocation: allocation{ + client: config.Client, + relayedAddr: config.RelayedAddr, + serverAddr: config.ServerAddr, + readTimer: time.NewTimer(time.Duration(math.MaxInt64)), + permMap: newPermissionMap(), + username: config.Username, + realm: config.Realm, + integrity: config.Integrity, + _nonce: config.Nonce, + _lifetime: config.Lifetime, + net: config.Net, + log: config.Log, + }, + } + + conn.log.Debugf("Initial lifetime: %d seconds", int(conn.lifetime().Seconds())) + + conn.refreshAllocTimer = NewPeriodicTimer( + timerIDRefreshAlloc, + conn.onRefreshTimers, + conn.lifetime()/2, + ) + + conn.refreshPermsTimer = NewPeriodicTimer( + timerIDRefreshPerms, + conn.onRefreshTimers, + permRefreshInterval, + ) + + conn.checkBindingsTimer = NewPeriodicTimer( + timerIDCheckBindings, + func(timerID int) { + for _, bound := range conn.bindingMgr.all() { + conn.maybeBind(bound) + } + }, + bindingCheckInterval, + ) + + if conn.refreshAllocTimer.Start() { + conn.log.Debugf("Started refresh allocation timer") + } + if conn.refreshPermsTimer.Start() { + conn.log.Debugf("Started refresh permission timer") + } + if conn.checkBindingsTimer.Start() { + conn.log.Debugf("Started check bindings timer") + } + + return conn +} + +// ReadFrom reads a packet from the connection, +// copying the payload into p. It returns the number of +// bytes copied into p and the return address that +// was on the packet. +// It returns the number of bytes read (0 <= n <= len(p)) +// and any error encountered. Callers should always process +// the n > 0 bytes returned before considering the error err. +// ReadFrom can be made to time out and return +// an Error with Timeout() == true after a fixed time limit; +// see SetDeadline and SetReadDeadline. +func (c *UDPConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + for { + select { + case ibData := <-c.readCh: + n := copy(p, ibData.data) + if n < len(ibData.data) { + return 0, nil, io.ErrShortBuffer + } + + return n, ibData.from, nil + + case <-c.readTimer.C: + return 0, nil, &net.OpError{ + Op: "read", + Net: c.LocalAddr().Network(), + Addr: c.LocalAddr(), + Err: newTimeoutError("i/o timeout"), + } + + case <-c.closeCh: + return 0, nil, &net.OpError{ + Op: "read", + Net: c.LocalAddr().Network(), + Addr: c.LocalAddr(), + Err: errClosed, + } + } + } +} + +func (a *allocation) createPermission(perm *permission, addr net.Addr) error { + perm.mutex.Lock() + defer perm.mutex.Unlock() + + if perm.state() == permStateIdle { + // Punch a hole! (this would block a bit..) + if err := a.CreatePermissions(addr); err != nil { + a.permMap.delete(addr) + + return err + } + perm.setState(permStatePermitted) + } + + return nil +} + +// WriteTo writes a packet with payload to addr. +// WriteTo can be made to time out and return +// an Error with Timeout() == true after a fixed time limit; +// see SetDeadline and SetWriteDeadline. +// On packet-oriented connections, write timeouts are rare. +func (c *UDPConn) WriteTo(payload []byte, addr net.Addr) (int, error) { //nolint:gocognit,cyclop + var err error + _, ok := addr.(*net.UDPAddr) + if !ok { + return 0, errUDPAddrCast + } + + // Check if we have a permission for the destination IP addr + perm, ok := c.permMap.find(addr) + if !ok { + perm = &permission{} + c.permMap.insert(addr, perm) + } + + for i := 0; i < maxRetryAttempts; i++ { + // c.createPermission() would block, per destination IP (, or perm), + // until the perm state becomes "requested". Purpose of this is to + // guarantee the order of packets (within the same perm). + // Note that CreatePermission transaction may not be complete before + // all the data transmission. This is done assuming that the request + // will be most likely successful and we can tolerate some loss of + // UDP packet (or reorder), inorder to minimize the latency in most cases. + if err = c.createPermission(perm, addr); !errors.Is(err, errTryAgain) { + break + } + } + if err != nil { + return 0, err + } + + // Bind channel + bound, ok := c.bindingMgr.findByAddr(addr) + if !ok { + bound = c.bindingMgr.create(addr) + } + + //nolint:nestif + if !bound.ok() { + // Try to establish an initial binding with the server. + // Writes still occur via indications meanwhile. + c.maybeBind(bound) + + // Send data using SendIndication + peerAddr := addr2PeerAddress(addr) + var msg *stun.Message + msg, err = stun.Build( + stun.TransactionID, + stun.NewType(stun.MethodSend, stun.ClassIndication), + proto.Data(payload), + peerAddr, + stun.Fingerprint, + ) + if err != nil { + return 0, err + } + + return c.client.WriteTo(msg.Raw, c.serverAddr) + } + + // Binding is ready beyond this point, so send over it. + _, err = c.sendChannelData(payload, bound.number) + if err != nil { + return 0, err + } + + return len(payload), nil +} + +// Close closes the connection. +// Any blocked ReadFrom or WriteTo operations will be unblocked and return errors. +func (c *UDPConn) Close() error { + c.refreshAllocTimer.Stop() + c.refreshPermsTimer.Stop() + c.checkBindingsTimer.Stop() + + select { + case <-c.closeCh: + return errAlreadyClosed + default: + close(c.closeCh) + } + + c.client.OnDeallocated(c.relayedAddr) + + return c.refreshAllocation(0, true /* dontWait=true */) +} + +// LocalAddr returns the local network address. +func (c *UDPConn) LocalAddr() net.Addr { + return c.relayedAddr +} + +// SetDeadline sets the read and write deadlines associated +// with the connection. It is equivalent to calling both +// SetReadDeadline and SetWriteDeadline. +// +// A deadline is an absolute time after which I/O operations +// fail with a timeout (see type Error) instead of +// blocking. The deadline applies to all future and pending +// I/O, not just the immediately following call to ReadFrom or +// WriteTo. After a deadline has been exceeded, the connection +// can be refreshed by setting a deadline in the future. +// +// An idle timeout can be implemented by repeatedly extending +// the deadline after successful ReadFrom or WriteTo calls. +// +// A zero value for t means I/O operations will not time out. +func (c *UDPConn) SetDeadline(t time.Time) error { + return c.SetReadDeadline(t) +} + +// SetReadDeadline sets the deadline for future ReadFrom calls +// and any currently-blocked ReadFrom call. +// A zero value for t means ReadFrom will not time out. +func (c *UDPConn) SetReadDeadline(t time.Time) error { + var d time.Duration + if t.Equal(noDeadline()) { + d = time.Duration(math.MaxInt64) + } else { + d = time.Until(t) + } + c.readTimer.Reset(d) + + return nil +} + +// SetWriteDeadline sets the deadline for future WriteTo calls +// and any currently-blocked WriteTo call. +// Even if write times out, it may return n > 0, indicating that +// some of the data was successfully written. +// A zero value for t means WriteTo will not time out. +func (c *UDPConn) SetWriteDeadline(time.Time) error { + // Write never blocks. + return nil +} + +func addr2PeerAddress(addr net.Addr) proto.PeerAddress { + var peerAddr proto.PeerAddress + switch a := addr.(type) { + case *net.UDPAddr: + peerAddr.IP = a.IP + peerAddr.Port = a.Port + case *net.TCPAddr: + peerAddr.IP = a.IP + peerAddr.Port = a.Port + } + + return peerAddr +} + +// CreatePermissions Issues a CreatePermission request for the supplied addresses +// as described in https://datatracker.ietf.org/doc/html/rfc5766#section-9 +func (a *allocation) CreatePermissions(addrs ...net.Addr) error { + setters := []stun.Setter{ + stun.TransactionID, + stun.NewType(stun.MethodCreatePermission, stun.ClassRequest), + } + + for _, addr := range addrs { + setters = append(setters, addr2PeerAddress(addr)) + } + + setters = append(setters, + a.username, + a.realm, + a.nonce(), + a.integrity, + stun.Fingerprint) + + msg, err := stun.Build(setters...) + if err != nil { + return err + } + + trRes, err := a.client.PerformTransaction(msg, a.serverAddr, false) + if err != nil { + return err + } + + res := trRes.Msg + + if res.Type.Class == stun.ClassErrorResponse { + var code stun.ErrorCodeAttribute + if err = code.GetFrom(res); err == nil { + if code.Code == stun.CodeStaleNonce { + a.setNonceFromMsg(res) + + return errTryAgain + } + + turnError := &stun.TurnError{ + StunMessageType: res.Type, + ErrorCodeAttr: code, + } + + return turnError + } + + return fmt.Errorf("%s", res.Type) //nolint // dynamic errors + } + + return nil +} + +// HandleInbound passes inbound data in UDPConn. +func (c *UDPConn) HandleInbound(data []byte, from net.Addr) { + // Copy data + copied := make([]byte, len(data)) + copy(copied, data) + + select { + case c.readCh <- &inboundData{data: copied, from: from}: + default: + c.log.Warnf("Receive buffer full") + } +} + +// FindAddrByChannelNumber returns a peer address associated with the +// channel number on this UDPConn. +func (c *UDPConn) FindAddrByChannelNumber(chNum uint16) (net.Addr, bool) { + b, ok := c.bindingMgr.findByNumber(chNum) + if !ok { + return nil, false + } + + return b.addr, true +} + +func (c *UDPConn) maybeBind(bound *binding) { + bind := func() { + var err error + for i := 0; i < maxRetryAttempts; i++ { + if err = c.bind(bound); !errors.Is(err, errTryAgain) { + break + } + } + if err != nil { + c.log.Warnf("Failed to bind channel %d: %s", bound.number, err) + bound.setState(bindingStateFailed) + + return + } + bound.setRefreshedAt(time.Now()) + bound.setState(bindingStateReady) + } + + // Block only callers with the same binding until + // the binding transaction has been complete + bound.muBind.Lock() + defer bound.muBind.Unlock() + + state := bound.state() + switch { + case state == bindingStateIdle: + bound.setState(bindingStateRequest) + case state == bindingStateReady && time.Since(bound.refreshedAt()) > bindingRefreshInterval: + bound.setState(bindingStateRefresh) + default: + return + } + + // Establish binding with the server if eligible + // with regard to cases right above. + go bind() +} + +func (c *UDPConn) bind(bound *binding) error { + setters := []stun.Setter{ + stun.TransactionID, + stun.NewType(stun.MethodChannelBind, stun.ClassRequest), + addr2PeerAddress(bound.addr), + proto.ChannelNumber(bound.number), + c.username, + c.realm, + c.nonce(), + c.integrity, + stun.Fingerprint, + } + + msg, err := stun.Build(setters...) + if err != nil { + return err + } + + trRes, err := c.client.PerformTransaction(msg, c.serverAddr, false) + if err != nil { + c.bindingMgr.deleteByAddr(bound.addr) + + return err + } + + res := trRes.Msg + if res.Type.Class == stun.ClassErrorResponse { + var code stun.ErrorCodeAttribute + if err = code.GetFrom(res); err == nil { + if code.Code == stun.CodeStaleNonce { + c.setNonceFromMsg(res) + + return errTryAgain + } + } + return fmt.Errorf("unexpected response type %s", res.Type) //nolint // dynamic errors + } + + c.log.Debugf("Channel binding successful: %s %d", bound.addr, bound.number) + + // Success. + return nil +} + +func (c *UDPConn) sendChannelData(data []byte, chNum uint16) (int, error) { + chData := &proto.ChannelData{ + Data: data, + Number: proto.ChannelNumber(chNum), + } + chData.Encode() + _, err := c.client.WriteTo(chData.Raw, c.serverAddr) + if err != nil { + return 0, err + } + + return len(data), nil +} diff --git a/vendor/github.com/pion/turn/v4/internal/ipnet/util.go b/vendor/github.com/pion/turn/v4/internal/ipnet/util.go new file mode 100644 index 0000000..9753ef3 --- /dev/null +++ b/vendor/github.com/pion/turn/v4/internal/ipnet/util.go @@ -0,0 +1,56 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +// Package ipnet contains helper functions around net and IP +package ipnet + +import ( + "errors" + "net" +) + +var errFailedToCastAddr = errors.New("failed to cast net.Addr to *net.UDPAddr or *net.TCPAddr") + +// AddrIPPort extracts the IP and Port from a net.Addr. +func AddrIPPort(a net.Addr) (net.IP, int, error) { + aUDP, ok := a.(*net.UDPAddr) + if ok { + return aUDP.IP, aUDP.Port, nil + } + + aTCP, ok := a.(*net.TCPAddr) + if ok { + return aTCP.IP, aTCP.Port, nil + } + + return nil, 0, errFailedToCastAddr +} + +// AddrEqual asserts that two net.Addrs are equal +// Currently only supports UDP but will be extended in the future to support others. +func AddrEqual(a, b net.Addr) bool { + aUDP, ok := a.(*net.UDPAddr) + if !ok { + return false + } + + bUDP, ok := b.(*net.UDPAddr) + if !ok { + return false + } + + return aUDP.IP.Equal(bUDP.IP) && aUDP.Port == bUDP.Port +} + +// FingerprintAddr generates a fingerprint from net.UDPAddr or net.TCPAddr's +// which can be used for indexing maps. +func FingerprintAddr(addr net.Addr) string { + switch a := addr.(type) { + case *net.UDPAddr: + return a.IP.String() + case *net.TCPAddr: // Do we really need this case? + return a.IP.String() + } + + return "" // Should never happen +} diff --git a/vendor/github.com/pion/turn/v4/internal/proto/addr.go b/vendor/github.com/pion/turn/v4/internal/proto/addr.go new file mode 100644 index 0000000..38b3027 --- /dev/null +++ b/vendor/github.com/pion/turn/v4/internal/proto/addr.go @@ -0,0 +1,70 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package proto + +import ( + "fmt" + "net" +) + +// Addr is ip:port. +type Addr struct { + IP net.IP + Port int +} + +// Network implements net.Addr. +func (Addr) Network() string { return "turn" } + +// FromUDPAddr sets addr to UDPAddr. +func (a *Addr) FromUDPAddr(n *net.UDPAddr) { + a.IP = n.IP + a.Port = n.Port +} + +// Equal returns true if b == a. +func (a Addr) Equal(b Addr) bool { + if a.Port != b.Port { + return false + } + + return a.IP.Equal(b.IP) +} + +// EqualIP returns true if a and b have equal IP addresses. +func (a Addr) EqualIP(b Addr) bool { + return a.IP.Equal(b.IP) +} + +func (a Addr) String() string { + return fmt.Sprintf("%s:%d", a.IP, a.Port) +} + +// FiveTuple represents 5-TUPLE value. +type FiveTuple struct { + Client Addr + Server Addr + Proto Protocol +} + +func (t FiveTuple) String() string { + return fmt.Sprintf("%s->%s (%s)", + t.Client, t.Server, t.Proto, + ) +} + +// Equal returns true if b == t. +func (t FiveTuple) Equal(b FiveTuple) bool { + if t.Proto != b.Proto { + return false + } + if !t.Client.Equal(b.Client) { + return false + } + if !t.Server.Equal(b.Server) { + return false + } + + return true +} diff --git a/vendor/github.com/pion/turn/v4/internal/proto/chandata.go b/vendor/github.com/pion/turn/v4/internal/proto/chandata.go new file mode 100644 index 0000000..4c9bb26 --- /dev/null +++ b/vendor/github.com/pion/turn/v4/internal/proto/chandata.go @@ -0,0 +1,147 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package proto + +import ( + "bytes" + "encoding/binary" + "errors" + "io" +) + +// ChannelData represents The ChannelData Message. +// +// See RFC 5766 Section 11.4. +type ChannelData struct { + Data []byte // Can be sub slice of Raw + Length int // Ignored while encoding, len(Data) is used + Number ChannelNumber + Raw []byte +} + +// Equal returns true if compareTo == c. +func (c *ChannelData) Equal(compareTo *ChannelData) bool { + if c == nil && compareTo == nil { + return true + } + if c == nil || compareTo == nil { + return false + } + if c.Number != compareTo.Number { + return false + } + if len(c.Data) != len(compareTo.Data) { + return false + } + + return bytes.Equal(c.Data, compareTo.Data) +} + +// Grow ensures that internal buffer will fit v more bytes and +// increases it capacity if necessary. +// +// Similar to stun.Message.grow method. +func (c *ChannelData) grow(v int) { + n := len(c.Raw) + v + for cap(c.Raw) < n { + c.Raw = append(c.Raw, 0) + } + c.Raw = c.Raw[:n] +} + +// Reset resets Length, Data and Raw length. +func (c *ChannelData) Reset() { + c.Raw = c.Raw[:0] + c.Length = 0 + c.Data = c.Data[:0] +} + +// Encode encodes ChannelData Message to Raw. +func (c *ChannelData) Encode() { + c.Raw = c.Raw[:0] + c.WriteHeader() + c.Raw = append(c.Raw, c.Data...) + padded := nearestPaddedValueLength(len(c.Raw)) + if bytesToAdd := padded - len(c.Raw); bytesToAdd > 0 { + for i := 0; i < bytesToAdd; i++ { + c.Raw = append(c.Raw, 0) + } + } +} + +const padding = 4 + +func nearestPaddedValueLength(l int) int { + n := padding * (l / padding) + if n < l { + n += padding + } + + return n +} + +// WriteHeader writes channel number and length. +func (c *ChannelData) WriteHeader() { + if len(c.Raw) < channelDataHeaderSize { + // Making WriteHeader call valid even when c.Raw + // is nil or len(c.Raw) is less than needed for header. + c.grow(channelDataHeaderSize) + } + // Early bounds check to guarantee safety of writes below. + _ = c.Raw[:channelDataHeaderSize] + binary.BigEndian.PutUint16(c.Raw[:channelDataNumberSize], uint16(c.Number)) + binary.BigEndian.PutUint16(c.Raw[channelDataNumberSize:channelDataHeaderSize], + uint16(len(c.Data)), // nolint:gosec // G115 + ) +} + +// ErrBadChannelDataLength means that channel data length is not equal +// to actual data length. +var ErrBadChannelDataLength = errors.New("channelData length != len(Data)") + +// Decode decodes The ChannelData Message from Raw. +func (c *ChannelData) Decode() error { + buf := c.Raw + if len(buf) < channelDataHeaderSize { + return io.ErrUnexpectedEOF + } + num := binary.BigEndian.Uint16(buf[:channelDataNumberSize]) + c.Number = ChannelNumber(num) + l := binary.BigEndian.Uint16(buf[channelDataNumberSize:channelDataHeaderSize]) + c.Data = buf[channelDataHeaderSize:] + c.Length = int(l) + if !c.Number.Valid() { + return ErrInvalidChannelNumber + } + if int(l) < len(c.Data) { + c.Data = c.Data[:int(l)] + } + if int(l) > len(buf[channelDataHeaderSize:]) { + return ErrBadChannelDataLength + } + + return nil +} + +const ( + channelDataLengthSize = 2 + channelDataNumberSize = channelDataLengthSize + channelDataHeaderSize = channelDataLengthSize + channelDataNumberSize +) + +// IsChannelData returns true if buf looks like the ChannelData Message. +func IsChannelData(buf []byte) bool { + if len(buf) < channelDataHeaderSize { + return false + } + + if int(binary.BigEndian.Uint16(buf[channelDataNumberSize:channelDataHeaderSize])) > len(buf[channelDataHeaderSize:]) { + return false + } + + // Quick check for channel number. + num := binary.BigEndian.Uint16(buf[0:channelNumberSize]) + + return isChannelNumberValid(num) +} diff --git a/vendor/github.com/pion/turn/v4/internal/proto/chann.go b/vendor/github.com/pion/turn/v4/internal/proto/chann.go new file mode 100644 index 0000000..da017bf --- /dev/null +++ b/vendor/github.com/pion/turn/v4/internal/proto/chann.go @@ -0,0 +1,71 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package proto + +import ( + "encoding/binary" + "errors" + "strconv" + + "github.com/pion/stun/v3" +) + +// ChannelNumber represents CHANNEL-NUMBER attribute. +// +// The CHANNEL-NUMBER attribute contains the number of the channel. +// +// RFC 5766 Section 14.1. +type ChannelNumber uint16 // Encoded as uint16 + +func (n ChannelNumber) String() string { return strconv.Itoa(int(n)) } + +// 16 bits of uint + 16 bits of RFFU = 0. +const channelNumberSize = 4 + +// AddTo adds CHANNEL-NUMBER to message. +func (n ChannelNumber) AddTo(m *stun.Message) error { + v := make([]byte, channelNumberSize) + binary.BigEndian.PutUint16(v[:2], uint16(n)) + // v[2:4] are zeroes (RFFU = 0) + m.Add(stun.AttrChannelNumber, v) + + return nil +} + +// GetFrom decodes CHANNEL-NUMBER from message. +func (n *ChannelNumber) GetFrom(m *stun.Message) error { + v, err := m.Get(stun.AttrChannelNumber) + if err != nil { + return err + } + if err = stun.CheckSize(stun.AttrChannelNumber, len(v), channelNumberSize); err != nil { + return err + } + _ = v[channelNumberSize-1] // Asserting length + *n = ChannelNumber(binary.BigEndian.Uint16(v[:2])) + // v[2:4] is RFFU and equals to 0. + return nil +} + +// See https://tools.ietf.org/html/rfc5766#section-11: +// +// 0x4000 through 0x7FFF: These values are the allowed channel +// numbers (16,383 possible values). +const ( + MinChannelNumber = 0x4000 + MaxChannelNumber = 0x7FFF +) + +// ErrInvalidChannelNumber means that channel number is not valid as by RFC 5766 Section 11. +var ErrInvalidChannelNumber = errors.New("channel number not in [0x4000, 0x7FFF]") + +// isChannelNumberValid returns true if c in [0x4000, 0x7FFF]. +func isChannelNumberValid(c uint16) bool { + return c >= MinChannelNumber && c <= MaxChannelNumber +} + +// Valid returns true if channel number has correct value that complies RFC 5766 Section 11 range. +func (n ChannelNumber) Valid() bool { + return isChannelNumberValid(uint16(n)) +} diff --git a/vendor/github.com/pion/turn/v4/internal/proto/connection_id.go b/vendor/github.com/pion/turn/v4/internal/proto/connection_id.go new file mode 100644 index 0000000..d95d9b0 --- /dev/null +++ b/vendor/github.com/pion/turn/v4/internal/proto/connection_id.go @@ -0,0 +1,44 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package proto + +import ( + "encoding/binary" + + "github.com/pion/stun/v3" +) + +// ConnectionID represents CONNECTION-ID attribute. +// +// The CONNECTION-ID attribute uniquely identifies a peer data +// connection. It is a 32-bit unsigned integral value. +// +// RFC 6062 Section 6.2.1. +type ConnectionID uint32 + +const connectionIDSize = 4 // uint32: 4 bytes, 32 bits + +// AddTo adds CONNECTION-ID to message. +func (c ConnectionID) AddTo(m *stun.Message) error { + v := make([]byte, lifetimeSize) + binary.BigEndian.PutUint32(v, uint32(c)) + m.Add(stun.AttrConnectionID, v) + + return nil +} + +// GetFrom decodes CONNECTION-ID from message. +func (c *ConnectionID) GetFrom(m *stun.Message) error { + v, err := m.Get(stun.AttrConnectionID) + if err != nil { + return err + } + if err = stun.CheckSize(stun.AttrConnectionID, len(v), connectionIDSize); err != nil { + return err + } + _ = v[connectionIDSize-1] // Asserting length + *(*uint32)(c) = binary.BigEndian.Uint32(v) + + return nil +} diff --git a/vendor/github.com/pion/turn/v4/internal/proto/data.go b/vendor/github.com/pion/turn/v4/internal/proto/data.go new file mode 100644 index 0000000..ea2f0c8 --- /dev/null +++ b/vendor/github.com/pion/turn/v4/internal/proto/data.go @@ -0,0 +1,35 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package proto + +import "github.com/pion/stun/v3" + +// Data represents DATA attribute. +// +// The DATA attribute is present in all Send and Data indications. The +// value portion of this attribute is variable length and consists of +// the application data (that is, the data that would immediately follow +// the UDP header if the data was been sent directly between the client +// and the peer). +// +// RFC 5766 Section 14.4. +type Data []byte + +// AddTo adds DATA to message. +func (d Data) AddTo(m *stun.Message) error { + m.Add(stun.AttrData, d) + + return nil +} + +// GetFrom decodes DATA from message. +func (d *Data) GetFrom(m *stun.Message) error { + v, err := m.Get(stun.AttrData) + if err != nil { + return err + } + *d = v + + return nil +} diff --git a/vendor/github.com/pion/turn/v4/internal/proto/dontfrag.go b/vendor/github.com/pion/turn/v4/internal/proto/dontfrag.go new file mode 100644 index 0000000..d46ae8f --- /dev/null +++ b/vendor/github.com/pion/turn/v4/internal/proto/dontfrag.go @@ -0,0 +1,48 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package proto + +import ( + "github.com/pion/stun/v3" +) + +// DontFragmentAttr is a deprecated alias for DontFragment +// Deprecated: Please use DontFragment. +type DontFragmentAttr = DontFragment + +// DontFragment represents DONT-FRAGMENT attribute. +// +// This attribute is used by the client to request that the server set +// the DF (Don't Fragment) bit in the IP header when relaying the +// application data onward to the peer. This attribute has no value +// part and thus the attribute length field is 0. +// +// RFC 5766 Section 14.8. +type DontFragment struct{} + +const dontFragmentSize = 0 + +// AddTo adds DONT-FRAGMENT attribute to message. +func (DontFragment) AddTo(m *stun.Message) error { + m.Add(stun.AttrDontFragment, nil) + + return nil +} + +// GetFrom decodes DONT-FRAGMENT from message. +func (d *DontFragment) GetFrom(m *stun.Message) error { + v, err := m.Get(stun.AttrDontFragment) + if err != nil { + return err + } + + return stun.CheckSize(stun.AttrDontFragment, len(v), dontFragmentSize) +} + +// IsSet returns true if DONT-FRAGMENT attribute is set. +func (DontFragment) IsSet(m *stun.Message) bool { + _, err := m.Get(stun.AttrDontFragment) + + return err == nil +} diff --git a/vendor/github.com/pion/turn/v4/internal/proto/evenport.go b/vendor/github.com/pion/turn/v4/internal/proto/evenport.go new file mode 100644 index 0000000..6989ab8 --- /dev/null +++ b/vendor/github.com/pion/turn/v4/internal/proto/evenport.go @@ -0,0 +1,61 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package proto + +import "github.com/pion/stun/v3" + +// EvenPort represents EVEN-PORT attribute. +// +// This attribute allows the client to request that the port in the +// relayed transport address be even, and (optionally) that the server +// reserve the next-higher port number. +// +// RFC 5766 Section 14.6. +type EvenPort struct { + // ReservePort means that the server is requested to reserve + // the next-higher port number (on the same IP address) + // for a subsequent allocation. + ReservePort bool +} + +func (p EvenPort) String() string { + if p.ReservePort { + return "reserve: true" + } + + return "reserve: false" +} + +const ( + evenPortSize = 1 + firstBitSet = (1 << 8) - 1 // 0b100000000 +) + +// AddTo adds EVEN-PORT to message. +func (p EvenPort) AddTo(m *stun.Message) error { + v := make([]byte, evenPortSize) + if p.ReservePort { + // Set first bit to 1. + v[0] = firstBitSet + } + m.Add(stun.AttrEvenPort, v) + + return nil +} + +// GetFrom decodes EVEN-PORT from message. +func (p *EvenPort) GetFrom(m *stun.Message) error { + v, err := m.Get(stun.AttrEvenPort) + if err != nil { + return err + } + if err = stun.CheckSize(stun.AttrEvenPort, len(v), evenPortSize); err != nil { + return err + } + if v[0]&firstBitSet > 0 { + p.ReservePort = true + } + + return nil +} diff --git a/vendor/github.com/pion/turn/v4/internal/proto/lifetime.go b/vendor/github.com/pion/turn/v4/internal/proto/lifetime.go new file mode 100644 index 0000000..37b86f8 --- /dev/null +++ b/vendor/github.com/pion/turn/v4/internal/proto/lifetime.go @@ -0,0 +1,57 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package proto + +import ( + "encoding/binary" + "time" + + "github.com/pion/stun/v3" +) + +// DefaultLifetime in RFC 5766 is 10 minutes. +// +// RFC 5766 Section 2.2. +const DefaultLifetime = time.Minute * 10 + +// Lifetime represents LIFETIME attribute. +// +// The LIFETIME attribute represents the duration for which the server +// will maintain an allocation in the absence of a refresh. The value +// portion of this attribute is 4-bytes long and consists of a 32-bit +// unsigned integral value representing the number of seconds remaining +// until expiration. +// +// RFC 5766 Section 14.2. +type Lifetime struct { + time.Duration +} + +// Seconds in uint32. +const lifetimeSize = 4 // 4 bytes, 32 bits + +// AddTo adds LIFETIME to message. +func (l Lifetime) AddTo(m *stun.Message) error { + v := make([]byte, lifetimeSize) + binary.BigEndian.PutUint32(v, uint32(l.Seconds())) + m.Add(stun.AttrLifetime, v) + + return nil +} + +// GetFrom decodes LIFETIME from message. +func (l *Lifetime) GetFrom(m *stun.Message) error { + v, err := m.Get(stun.AttrLifetime) + if err != nil { + return err + } + if err = stun.CheckSize(stun.AttrLifetime, len(v), lifetimeSize); err != nil { + return err + } + _ = v[lifetimeSize-1] // Asserting length + seconds := binary.BigEndian.Uint32(v) + l.Duration = time.Second * time.Duration(seconds) + + return nil +} diff --git a/vendor/github.com/pion/turn/v4/internal/proto/peeraddr.go b/vendor/github.com/pion/turn/v4/internal/proto/peeraddr.go new file mode 100644 index 0000000..a919bd1 --- /dev/null +++ b/vendor/github.com/pion/turn/v4/internal/proto/peeraddr.go @@ -0,0 +1,45 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package proto + +import ( + "net" + + "github.com/pion/stun/v3" +) + +// PeerAddress implements XOR-PEER-ADDRESS attribute. +// +// The XOR-PEER-ADDRESS specifies the address and port of the peer as +// seen from the TURN server. (For example, the peer's server-reflexive +// transport address if the peer is behind a NAT.) +// +// RFC 5766 Section 14.3. +type PeerAddress struct { + IP net.IP + Port int +} + +func (a PeerAddress) String() string { + return stun.XORMappedAddress(a).String() +} + +// AddTo adds XOR-PEER-ADDRESS to message. +func (a PeerAddress) AddTo(m *stun.Message) error { + return stun.XORMappedAddress(a).AddToAs(m, stun.AttrXORPeerAddress) +} + +// GetFrom decodes XOR-PEER-ADDRESS from message. +func (a *PeerAddress) GetFrom(m *stun.Message) error { + return (*stun.XORMappedAddress)(a).GetFromAs(m, stun.AttrXORPeerAddress) +} + +// XORPeerAddress implements XOR-PEER-ADDRESS attribute. +// +// The XOR-PEER-ADDRESS specifies the address and port of the peer as +// seen from the TURN server. (For example, the peer's server-reflexive +// transport address if the peer is behind a NAT.) +// +// RFC 5766 Section 14.3. +type XORPeerAddress = PeerAddress diff --git a/vendor/github.com/pion/turn/v4/internal/proto/proto.go b/vendor/github.com/pion/turn/v4/internal/proto/proto.go new file mode 100644 index 0000000..170cf7f --- /dev/null +++ b/vendor/github.com/pion/turn/v4/internal/proto/proto.go @@ -0,0 +1,31 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +// Package proto implements RFC 5766 Traversal Using Relays around NAT. +package proto + +import ( + "github.com/pion/stun/v3" +) + +// Default ports for TURN from RFC 5766 Section 4. +const ( + // DefaultPort for TURN is same as STUN. + DefaultPort = stun.DefaultPort + // DefaultTLSPort is for TURN over TLS and is same as STUN. + DefaultTLSPort = stun.DefaultTLSPort +) + +// CreatePermissionRequest is shorthand for create permission request type. +func CreatePermissionRequest() stun.MessageType { + return stun.NewType(stun.MethodCreatePermission, stun.ClassRequest) +} + +// AllocateRequest is shorthand for allocation request message type. +func AllocateRequest() stun.MessageType { return stun.NewType(stun.MethodAllocate, stun.ClassRequest) } + +// SendIndication is shorthand for send indication message type. +func SendIndication() stun.MessageType { return stun.NewType(stun.MethodSend, stun.ClassIndication) } + +// RefreshRequest is shorthand for refresh request message type. +func RefreshRequest() stun.MessageType { return stun.NewType(stun.MethodRefresh, stun.ClassRequest) } diff --git a/vendor/github.com/pion/turn/v4/internal/proto/relayedaddr.go b/vendor/github.com/pion/turn/v4/internal/proto/relayedaddr.go new file mode 100644 index 0000000..1d22ade --- /dev/null +++ b/vendor/github.com/pion/turn/v4/internal/proto/relayedaddr.go @@ -0,0 +1,43 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package proto + +import ( + "net" + + "github.com/pion/stun/v3" +) + +// RelayedAddress implements XOR-RELAYED-ADDRESS attribute. +// +// It specifies the address and port that the server allocated to the +// client. It is encoded in the same way as XOR-MAPPED-ADDRESS. +// +// RFC 5766 Section 14.5. +type RelayedAddress struct { + IP net.IP + Port int +} + +func (a RelayedAddress) String() string { + return stun.XORMappedAddress(a).String() +} + +// AddTo adds XOR-PEER-ADDRESS to message. +func (a RelayedAddress) AddTo(m *stun.Message) error { + return stun.XORMappedAddress(a).AddToAs(m, stun.AttrXORRelayedAddress) +} + +// GetFrom decodes XOR-PEER-ADDRESS from message. +func (a *RelayedAddress) GetFrom(m *stun.Message) error { + return (*stun.XORMappedAddress)(a).GetFromAs(m, stun.AttrXORRelayedAddress) +} + +// XORRelayedAddress implements XOR-RELAYED-ADDRESS attribute. +// +// It specifies the address and port that the server allocated to the +// client. It is encoded in the same way as XOR-MAPPED-ADDRESS. +// +// RFC 5766 Section 14.5. +type XORRelayedAddress = RelayedAddress diff --git a/vendor/github.com/pion/turn/v4/internal/proto/reqfamily.go b/vendor/github.com/pion/turn/v4/internal/proto/reqfamily.go new file mode 100644 index 0000000..d00c302 --- /dev/null +++ b/vendor/github.com/pion/turn/v4/internal/proto/reqfamily.go @@ -0,0 +1,66 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package proto + +import ( + "errors" + + "github.com/pion/stun/v3" +) + +// RequestedAddressFamily represents the REQUESTED-ADDRESS-FAMILY Attribute as +// defined in RFC 6156 Section 4.1.1. +type RequestedAddressFamily byte + +const requestedFamilySize = 4 + +var errInvalidRequestedFamilyValue = errors.New("invalid value for requested family attribute") + +// GetFrom decodes REQUESTED-ADDRESS-FAMILY from message. +func (f *RequestedAddressFamily) GetFrom(m *stun.Message) error { + v, err := m.Get(stun.AttrRequestedAddressFamily) + if err != nil { + return err + } + if err = stun.CheckSize(stun.AttrRequestedAddressFamily, len(v), requestedFamilySize); err != nil { + return err + } + switch v[0] { + case byte(RequestedFamilyIPv4), byte(RequestedFamilyIPv6): + *f = RequestedAddressFamily(v[0]) + default: + return errInvalidRequestedFamilyValue + } + + return nil +} + +func (f RequestedAddressFamily) String() string { + switch f { + case RequestedFamilyIPv4: + return "IPv4" + case RequestedFamilyIPv6: + return "IPv6" + default: + return "unknown" + } +} + +// AddTo adds REQUESTED-ADDRESS-FAMILY to message. +func (f RequestedAddressFamily) AddTo(m *stun.Message) error { + v := make([]byte, requestedFamilySize) + v[0] = byte(f) + // b[1:4] is RFFU = 0. + // The RFFU field MUST be set to zero on transmission and MUST be + // ignored on reception. It is reserved for future uses. + m.Add(stun.AttrRequestedAddressFamily, v) + + return nil +} + +// Values for RequestedAddressFamily as defined in RFC 6156 Section 4.1.1. +const ( + RequestedFamilyIPv4 RequestedAddressFamily = 0x01 + RequestedFamilyIPv6 RequestedAddressFamily = 0x02 +) diff --git a/vendor/github.com/pion/turn/v4/internal/proto/reqtrans.go b/vendor/github.com/pion/turn/v4/internal/proto/reqtrans.go new file mode 100644 index 0000000..b907e38 --- /dev/null +++ b/vendor/github.com/pion/turn/v4/internal/proto/reqtrans.go @@ -0,0 +1,74 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package proto + +import ( + "strconv" + + "github.com/pion/stun/v3" +) + +// Protocol is IANA assigned protocol number. +type Protocol byte + +const ( + // ProtoTCP is IANA assigned protocol number for TCP. + ProtoTCP Protocol = 6 + // ProtoUDP is IANA assigned protocol number for UDP. + ProtoUDP Protocol = 17 +) + +func (p Protocol) String() string { + switch p { + case ProtoTCP: + return "TCP" + case ProtoUDP: + return "UDP" + default: + return strconv.Itoa(int(p)) + } +} + +// RequestedTransport represents REQUESTED-TRANSPORT attribute. +// +// This attribute is used by the client to request a specific transport +// protocol for the allocated transport address. RFC 5766 only allows the use of +// code point 17 (User Datagram Protocol). +// +// RFC 5766 Section 14.7. +type RequestedTransport struct { + Protocol Protocol +} + +func (t RequestedTransport) String() string { + return "protocol: " + t.Protocol.String() +} + +const requestedTransportSize = 4 + +// AddTo adds REQUESTED-TRANSPORT to message. +func (t RequestedTransport) AddTo(m *stun.Message) error { + v := make([]byte, requestedTransportSize) + v[0] = byte(t.Protocol) + // b[1:4] is RFFU = 0. + // The RFFU field MUST be set to zero on transmission and MUST be + // ignored on reception. It is reserved for future uses. + m.Add(stun.AttrRequestedTransport, v) + + return nil +} + +// GetFrom decodes REQUESTED-TRANSPORT from message. +func (t *RequestedTransport) GetFrom(m *stun.Message) error { + v, err := m.Get(stun.AttrRequestedTransport) + if err != nil { + return err + } + if err = stun.CheckSize(stun.AttrRequestedTransport, len(v), requestedTransportSize); err != nil { + return err + } + t.Protocol = Protocol(v[0]) + + return nil +} diff --git a/vendor/github.com/pion/turn/v4/internal/proto/rsrvtoken.go b/vendor/github.com/pion/turn/v4/internal/proto/rsrvtoken.go new file mode 100644 index 0000000..aed3886 --- /dev/null +++ b/vendor/github.com/pion/turn/v4/internal/proto/rsrvtoken.go @@ -0,0 +1,44 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package proto + +import "github.com/pion/stun/v3" + +// ReservationToken represents RESERVATION-TOKEN attribute. +// +// The RESERVATION-TOKEN attribute contains a token that uniquely +// identifies a relayed transport address being held in reserve by the +// server. The server includes this attribute in a success response to +// tell the client about the token, and the client includes this +// attribute in a subsequent Allocate request to request the server use +// that relayed transport address for the allocation. +// +// RFC 5766 Section 14.9. +type ReservationToken []byte + +const reservationTokenSize = 8 // 8 bytes + +// AddTo adds RESERVATION-TOKEN to message. +func (t ReservationToken) AddTo(m *stun.Message) error { + if err := stun.CheckSize(stun.AttrReservationToken, len(t), reservationTokenSize); err != nil { + return err + } + m.Add(stun.AttrReservationToken, t) + + return nil +} + +// GetFrom decodes RESERVATION-TOKEN from message. +func (t *ReservationToken) GetFrom(m *stun.Message) error { + v, err := m.Get(stun.AttrReservationToken) + if err != nil { + return err + } + if err = stun.CheckSize(stun.AttrReservationToken, len(v), reservationTokenSize); err != nil { + return err + } + *t = v + + return nil +} diff --git a/vendor/github.com/pion/turn/v4/internal/server/base36.go b/vendor/github.com/pion/turn/v4/internal/server/base36.go new file mode 100644 index 0000000..ea1121a --- /dev/null +++ b/vendor/github.com/pion/turn/v4/internal/server/base36.go @@ -0,0 +1,63 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package server + +import ( + "math/big" + "strings" +) + +// Base36 alphabet for encoding. +const base36Alphabet = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ" + +// EncodeBase36 converts bytes to base36 string using big.Int for arbitrary length. +func encodeBase36(data []byte) string { + if len(data) == 0 { + return "" + } + + num := new(big.Int).SetBytes(data) + if num.Cmp(big.NewInt(0)) == 0 { + return "0" + } + + base := big.NewInt(36) + buf := make([]byte, 0, len(data)*2) + remainder := new(big.Int) + for num.Cmp(big.NewInt(0)) > 0 { + num.DivMod(num, base, remainder) + buf = append(buf, base36Alphabet[remainder.Int64()]) + } + + for i, j := 0, len(buf)-1; i < j; i, j = i+1, j-1 { + buf[i], buf[j] = buf[j], buf[i] + } + + return string(buf) +} + +// DecodeBase36 converts base36 string back to bytes using big.Int for arbitrary length. +func decodeBase36(encoded string) []byte { + if encoded == "" { + return []byte{} + } + + if encoded == "0" { + return []byte{0} + } + + num := big.NewInt(0) + base := big.NewInt(36) + + for _, char := range strings.ToUpper(encoded) { + digit := strings.IndexRune(base36Alphabet, char) + if digit == -1 { + return nil + } + num.Mul(num, base) + num.Add(num, big.NewInt(int64(digit))) + } + + return num.Bytes() +} diff --git a/vendor/github.com/pion/turn/v4/internal/server/errors.go b/vendor/github.com/pion/turn/v4/internal/server/errors.go new file mode 100644 index 0000000..ea6da83 --- /dev/null +++ b/vendor/github.com/pion/turn/v4/internal/server/errors.go @@ -0,0 +1,29 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package server + +import "errors" + +var ( + errFailedToGenerateNonce = errors.New("failed to generate nonce") + errInvalidNonce = errors.New("invalid nonce") + errFailedToSendError = errors.New("failed to send error message") + errNoSuchUser = errors.New("no such user exists") + errUnexpectedClass = errors.New("unexpected class") + errUnexpectedMethod = errors.New("unexpected method") + errFailedToHandle = errors.New("failed to handle") + errUnhandledSTUNPacket = errors.New("unhandled STUN packet") + errUnableToHandleChannelData = errors.New("unable to handle ChannelData") + errFailedToCreateSTUNPacket = errors.New("failed to create stun message from packet") + errFailedToCreateChannelData = errors.New("failed to create channel data from packet") + errRelayAlreadyAllocatedForFiveTuple = errors.New("relay already allocated for 5-TUPLE") + errUnsupportedTransportProtocol = errors.New("RequestedTransport must be UDP or TCP") + errNoDontFragmentSupport = errors.New("no support for DONT-FRAGMENT") + errRequestWithReservationTokenAndEvenPort = errors.New("Request must not contain RESERVATION-TOKEN and EVEN-PORT") + errNoAllocationFound = errors.New("no allocation found") + errNoPermission = errors.New("unable to handle send-indication, no permission added") + errShortWrite = errors.New("packet write smaller than packet") + errNoSuchChannelBind = errors.New("no such channel bind") + errFailedWriteSocket = errors.New("failed writing to socket") +) diff --git a/vendor/github.com/pion/turn/v4/internal/server/nonce.go b/vendor/github.com/pion/turn/v4/internal/server/nonce.go new file mode 100644 index 0000000..3b432a7 --- /dev/null +++ b/vendor/github.com/pion/turn/v4/internal/server/nonce.go @@ -0,0 +1,71 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package server + +import ( + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "encoding/binary" + "encoding/hex" + "fmt" + "time" +) + +const ( + nonceLifetime = time.Hour // See: https://tools.ietf.org/html/rfc5766#section-4 + nonceLength = 40 + nonceKeyLength = 64 +) + +// NewNonceHash creates a NonceHash. +func NewNonceHash() (NonceManager, error) { + key := make([]byte, nonceKeyLength) + if _, err := rand.Read(key); err != nil { + return nil, err + } + + return &NonceHash{key}, nil +} + +// NonceHash is used to create and verify nonces. +type NonceHash struct { + key []byte +} + +// Generate a nonce. +func (n *NonceHash) Generate() (string, error) { + nonce := make([]byte, 8, nonceLength) + binary.BigEndian.PutUint64(nonce, uint64(time.Now().UnixMilli())) // nolint:gosec // G115 + + hash := hmac.New(sha256.New, n.key) + if _, err := hash.Write(nonce[:8]); err != nil { + return "", fmt.Errorf("%w: %v", errFailedToGenerateNonce, err) //nolint:errorlint + } + nonce = hash.Sum(nonce) + + return hex.EncodeToString(nonce), nil +} + +// Validate checks that nonce is signed and is not expired. +func (n *NonceHash) Validate(nonce string) error { + b, err := hex.DecodeString(nonce) + if err != nil || len(b) != nonceLength { + return fmt.Errorf("%w: %v", errInvalidNonce, err) //nolint:errorlint + } + + if ts := time.UnixMilli(int64(binary.BigEndian.Uint64(b))); time.Since(ts) > nonceLifetime { // nolint:gosec // G115 + return errInvalidNonce + } + + hash := hmac.New(sha256.New, n.key) + if _, err = hash.Write(b[:8]); err != nil { + return fmt.Errorf("%w: %v", errInvalidNonce, err) //nolint:errorlint + } + if !hmac.Equal(b[8:], hash.Sum(nil)) { + return errInvalidNonce + } + + return nil +} diff --git a/vendor/github.com/pion/turn/v4/internal/server/server.go b/vendor/github.com/pion/turn/v4/internal/server/server.go new file mode 100644 index 0000000..1bbfb0c --- /dev/null +++ b/vendor/github.com/pion/turn/v4/internal/server/server.go @@ -0,0 +1,135 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +// Package server implements the private API to implement a TURN server +package server + +import ( + "fmt" + "net" + "time" + + "github.com/pion/logging" + "github.com/pion/stun/v3" + "github.com/pion/turn/v4/internal/allocation" + "github.com/pion/turn/v4/internal/proto" +) + +// Request contains all the state needed to process a single incoming datagram. +type Request struct { + // Current Request State + Conn net.PacketConn + SrcAddr net.Addr + Buff []byte + + // Server State + AllocationManager *allocation.Manager + NonceHash NonceManager + + // User Configuration + AuthHandler func(username string, realm string, srcAddr net.Addr) (key []byte, ok bool) + + // Quota Handler + QuotaHandler func(username string, realm string, srcAddr net.Addr) (ok bool) + + Log logging.LeveledLogger + Realm string + ChannelBindTimeout time.Duration +} + +// HandleRequest processes the give Request. +func HandleRequest(r Request) error { + r.Log.Debugf("Received %d bytes of udp from %s on %s", len(r.Buff), r.SrcAddr, r.Conn.LocalAddr()) + + if proto.IsChannelData(r.Buff) { + return handleDataPacket(r) + } + + return handleTURNPacket(r) +} + +func handleDataPacket(req Request) error { + req.Log.Debugf("Received DataPacket from %s", req.SrcAddr.String()) + c := proto.ChannelData{Raw: req.Buff} + if err := c.Decode(); err != nil { + return fmt.Errorf("%w: %v", errFailedToCreateChannelData, err) //nolint:errorlint + } + + err := handleChannelData(req, &c) + if err != nil { + err = fmt.Errorf("%w from %v: %v", errUnableToHandleChannelData, req.SrcAddr, err) //nolint:errorlint + } + + return err +} + +func handleTURNPacket(req Request) error { + req.Log.Debug("Handling TURN packet") + stunMsg := &stun.Message{Raw: append([]byte{}, req.Buff...)} + if err := stunMsg.Decode(); err != nil { + // nolint:errorlint + return fmt.Errorf("%w: %v", errFailedToCreateSTUNPacket, err) + } + + handler, err := getMessageHandler(stunMsg.Type.Class, stunMsg.Type.Method) + if err != nil { + // nolint:errorlint + return fmt.Errorf( + "%w %v-%v from %v: %v", + errUnhandledSTUNPacket, + stunMsg.Type.Method, + stunMsg.Type.Class, + req.SrcAddr, + err, + ) + } + + err = handler(req, stunMsg) + if err != nil { + // nolint:errorlint + return fmt.Errorf( + "%w %v-%v from %v: %v", + errFailedToHandle, + stunMsg.Type.Method, + stunMsg.Type.Class, + req.SrcAddr, + err, + ) + } + + return nil +} + +func getMessageHandler(class stun.MessageClass, method stun.Method) ( // nolint:cyclop + func(req Request, stunMsg *stun.Message) error, + error, +) { + switch class { + case stun.ClassIndication: + switch method { + case stun.MethodSend: + return handleSendIndication, nil + default: + return nil, fmt.Errorf("%w: %s", errUnexpectedMethod, method) + } + + case stun.ClassRequest: + switch method { + case stun.MethodAllocate: + return handleAllocateRequest, nil + case stun.MethodRefresh: + return handleRefreshRequest, nil + case stun.MethodCreatePermission: + return handleCreatePermissionRequest, nil + case stun.MethodChannelBind: + return handleChannelBindRequest, nil + case stun.MethodBinding: + return handleBindingRequest, nil + default: + return nil, fmt.Errorf("%w: %s", errUnexpectedMethod, method) + } + + default: + return nil, fmt.Errorf("%w: %s", errUnexpectedClass, class) + } +} diff --git a/vendor/github.com/pion/turn/v4/internal/server/short_nonce.go b/vendor/github.com/pion/turn/v4/internal/server/short_nonce.go new file mode 100644 index 0000000..fc8b28c --- /dev/null +++ b/vendor/github.com/pion/turn/v4/internal/server/short_nonce.go @@ -0,0 +1,132 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package server + +import ( + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "encoding/binary" + "fmt" + "time" +) + +// NonceManager interface that both implementations satisfy. +type NonceManager interface { + Generate() (string, error) + Validate(nonce string) error +} + +const ( + shortNonceLifetime = time.Hour // Same as original + shortNonceKeyLength = 64 // Same as original + shortNonceTimestampLen = 4 // 6 bytes for timestamp (minutes) - optimal size + shortNonceMinHMACLen = 2 // Minimum HMAC length for security + shortNonceMaxHMACLen = 32 // Maximum HMAC length (full SHA256) + defaultNonceHMACLen = 12 // Default HMAC length +) + +// NewShortNonceHash creates a ShortNonceHash. The hmacLen argument specifies the number of HMAC +// bytes to include (2-32 bytes). The total nonce size will be 4 + hmacLen bytes, default hmaclen +// is 12 bytes. The 4 bytes timestamp gives about ~8000 years before nonces would start to repeat +// (safe until year 10,135). +func NewShortNonceHash(hmacLen int) (NonceManager, error) { + if hmacLen == 0 { + hmacLen = defaultNonceHMACLen + } + + if hmacLen < shortNonceMinHMACLen || hmacLen > shortNonceMaxHMACLen { + return nil, errFailedToGenerateNonce + } + + key := make([]byte, shortNonceKeyLength) + if _, err := rand.Read(key); err != nil { + return nil, fmt.Errorf("%w: %w", errFailedToGenerateNonce, err) + } + + return &ShortNonceHash{ + key: key, + hmacLen: hmacLen, + }, nil +} + +// ShortNonceHash is used to create and verify short nonces. +type ShortNonceHash struct { + key []byte + hmacLen int +} + +// Generate a short nonce (4 + hmacLen bytes encoded as base36). +func (s *ShortNonceHash) Generate() (string, error) { + timestampMinutes := time.Now().Unix() / 60 + + // Convert to bytes and trim to 4 bytes. This safely handles the conversion since we know + // current values fit in 4 bytes until year 10,135. + timestampBytes8 := make([]byte, 8) + binary.BigEndian.PutUint64(timestampBytes8, uint64(timestampMinutes)) // nolint:gosec // G115 + timestampBytes := timestampBytes8[4:] + + hash := hmac.New(sha256.New, s.key) + if _, err := hash.Write(timestampBytes); err != nil { + return "", fmt.Errorf("%w: %w", errFailedToGenerateNonce, err) + } + fullHMAC := hash.Sum(nil) + truncatedHMAC := fullHMAC[:s.hmacLen] + + totalLen := shortNonceTimestampLen + s.hmacLen + nonce := make([]byte, totalLen) + copy(nonce[:shortNonceTimestampLen], timestampBytes) + copy(nonce[shortNonceTimestampLen:], truncatedHMAC) + + return encodeBase36(nonce), nil +} + +// Validate checks that nonce is signed and is not expired. +func (s *ShortNonceHash) Validate(nonce string) error { + nonceBytes := decodeBase36(nonce) + if nonceBytes == nil { + return errInvalidNonce + } + + expectedLen := shortNonceTimestampLen + s.hmacLen + if len(nonceBytes) != expectedLen { + // Pad with leadnign zeros if leading zeros were stripped during encoding/decoding. + if len(nonceBytes) < expectedLen { + padded := make([]byte, expectedLen) + copy(padded[expectedLen-len(nonceBytes):], nonceBytes) + nonceBytes = padded + } else { + return errInvalidNonce + } + } + + timestampBytes := nonceBytes[:shortNonceTimestampLen] + receivedHMAC := nonceBytes[shortNonceTimestampLen:] + timestampMinutes := int64(binary.BigEndian.Uint32(timestampBytes)) + + // Check if nonce is expired (older than 1 hour). + currentMinutes := time.Now().Unix() / 60 + if currentMinutes < timestampMinutes { + return errInvalidNonce + } + + ageMinutes := currentMinutes - timestampMinutes + if ageMinutes > 60 { + return errInvalidNonce + } + + // Recompute HMAC and compare. + hash := hmac.New(sha256.New, s.key) + if _, err := hash.Write(timestampBytes); err != nil { + return fmt.Errorf("%w: %w", errInvalidNonce, err) + } + + expectedHMAC := hash.Sum(nil)[:s.hmacLen] + + if !hmac.Equal(receivedHMAC, expectedHMAC) { + return errInvalidNonce + } + + return nil +} diff --git a/vendor/github.com/pion/turn/v4/internal/server/stun.go b/vendor/github.com/pion/turn/v4/internal/server/stun.go new file mode 100644 index 0000000..1880bf1 --- /dev/null +++ b/vendor/github.com/pion/turn/v4/internal/server/stun.go @@ -0,0 +1,25 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package server + +import ( + "github.com/pion/stun/v3" + "github.com/pion/turn/v4/internal/ipnet" +) + +func handleBindingRequest(req Request, stunMsg *stun.Message) error { + req.Log.Debugf("Received BindingRequest from %s", req.SrcAddr) + + ip, port, err := ipnet.AddrIPPort(req.SrcAddr) + if err != nil { + return err + } + + attrs := buildMsg(stunMsg.TransactionID, stun.BindingSuccess, &stun.XORMappedAddress{ + IP: ip, + Port: port, + }, stun.Fingerprint) + + return buildAndSend(req.Conn, req.SrcAddr, attrs...) +} diff --git a/vendor/github.com/pion/turn/v4/internal/server/turn.go b/vendor/github.com/pion/turn/v4/internal/server/turn.go new file mode 100644 index 0000000..60dd3a6 --- /dev/null +++ b/vendor/github.com/pion/turn/v4/internal/server/turn.go @@ -0,0 +1,461 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package server + +import ( + "fmt" + "net" + + "github.com/pion/randutil" + "github.com/pion/stun/v3" + "github.com/pion/turn/v4/internal/allocation" + "github.com/pion/turn/v4/internal/ipnet" + "github.com/pion/turn/v4/internal/proto" +) + +const runesAlpha = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" + +// See: https://tools.ietf.org/html/rfc5766#section-6.2 +// . +func handleAllocateRequest(req Request, stunMsg *stun.Message) error { //nolint:cyclop + req.Log.Debugf("Received AllocateRequest from %s", req.SrcAddr) + + // 1. The server MUST require that the request be authenticated. This + // authentication MUST be done using the long-term credential + // mechanism of [https://tools.ietf.org/html/rfc5389#section-10.2.2] + // unless the client and server agree to use another mechanism through + // some procedure outside the scope of this document. + messageIntegrity, hasAuth, err := authenticateRequest(req, stunMsg, stun.MethodAllocate) + if !hasAuth { + return err + } + + fiveTuple := &allocation.FiveTuple{ + SrcAddr: req.SrcAddr, + DstAddr: req.Conn.LocalAddr(), + Protocol: allocation.UDP, + } + requestedPort := 0 + reservationToken := "" + + badRequestMsg := buildMsg( + stunMsg.TransactionID, + stun.NewType(stun.MethodAllocate, stun.ClassErrorResponse), + &stun.ErrorCodeAttribute{Code: stun.CodeBadRequest}, + ) + insufficientCapacityMsg := buildMsg( + stunMsg.TransactionID, + stun.NewType(stun.MethodAllocate, stun.ClassErrorResponse), + &stun.ErrorCodeAttribute{Code: stun.CodeInsufficientCapacity}, + ) + + // 2. The server checks if the 5-tuple is currently in use by an + // existing allocation. If yes, the server rejects the request with + // a 437 (Allocation Mismatch) error. + if alloc := req.AllocationManager.GetAllocation(fiveTuple); alloc != nil { + id, attrs := alloc.GetResponseCache() + if id != stunMsg.TransactionID { + msg := buildMsg( + stunMsg.TransactionID, + stun.NewType(stun.MethodAllocate, stun.ClassErrorResponse), + &stun.ErrorCodeAttribute{Code: stun.CodeAllocMismatch}, + ) + + return buildAndSendErr(req.Conn, req.SrcAddr, errRelayAlreadyAllocatedForFiveTuple, msg...) + } + // A retry allocation + msg := buildMsg( + stunMsg.TransactionID, + stun.NewType(stun.MethodAllocate, stun.ClassSuccessResponse), + append(attrs, messageIntegrity)..., + ) + + return buildAndSend(req.Conn, req.SrcAddr, msg...) + } + + // 3. The server checks if the request contains a REQUESTED-TRANSPORT + // attribute. If the REQUESTED-TRANSPORT attribute is not included + // or is malformed, the server rejects the request with a 400 (Bad + // Request) error. Otherwise, if the attribute is included but + // specifies a protocol other that UDP/TCP, the server rejects the + // request with a 442 (Unsupported Transport Protocol) error. + var requestedTransport proto.RequestedTransport + if err = requestedTransport.GetFrom(stunMsg); err != nil { + return buildAndSendErr(req.Conn, req.SrcAddr, err, badRequestMsg...) + } else if requestedTransport.Protocol != proto.ProtoUDP && requestedTransport.Protocol != proto.ProtoTCP { + msg := buildMsg( + stunMsg.TransactionID, + stun.NewType(stun.MethodAllocate, stun.ClassErrorResponse), + &stun.ErrorCodeAttribute{Code: stun.CodeUnsupportedTransProto}, + ) + + return buildAndSendErr(req.Conn, req.SrcAddr, errUnsupportedTransportProtocol, msg...) + } + + // 4. The request may contain a DONT-FRAGMENT attribute. If it does, + // but the server does not support sending UDP datagrams with the DF + // bit set to 1 (see Section 12), then the server treats the DONT- + // FRAGMENT attribute in the Allocate request as an unknown + // comprehension-required attribute. + if stunMsg.Contains(stun.AttrDontFragment) { + msg := buildMsg( + stunMsg.TransactionID, + stun.NewType(stun.MethodAllocate, stun.ClassErrorResponse), + &stun.ErrorCodeAttribute{Code: stun.CodeUnknownAttribute}, + &stun.UnknownAttributes{stun.AttrDontFragment}, + ) + + return buildAndSendErr(req.Conn, req.SrcAddr, errNoDontFragmentSupport, msg...) + } + + // 5. The server checks if the request contains a RESERVATION-TOKEN + // attribute. If yes, and the request also contains an EVEN-PORT + // attribute, then the server rejects the request with a 400 (Bad + // Request) error. Otherwise, it checks to see if the token is + // valid (i.e., the token is in range and has not expired and the + // corresponding relayed transport address is still available). If + // the token is not valid for some reason, the server rejects the + // request with a 508 (Insufficient Capacity) error. + var reservationTokenAttr proto.ReservationToken + if err = reservationTokenAttr.GetFrom(stunMsg); err == nil { + var evenPort proto.EvenPort + if err = evenPort.GetFrom(stunMsg); err == nil { + return buildAndSendErr(req.Conn, req.SrcAddr, errRequestWithReservationTokenAndEvenPort, badRequestMsg...) + } + } + + // 6. The server checks if the request contains an EVEN-PORT attribute. + // If yes, then the server checks that it can satisfy the request + // (i.e., can allocate a relayed transport address as described + // below). If the server cannot satisfy the request, then the + // server rejects the request with a 508 (Insufficient Capacity) + // error. + var evenPort proto.EvenPort + if err = evenPort.GetFrom(stunMsg); err == nil { + var randomPort int + randomPort, err = req.AllocationManager.GetRandomEvenPort() + if err != nil { + return buildAndSendErr(req.Conn, req.SrcAddr, err, insufficientCapacityMsg...) + } + requestedPort = randomPort + reservationToken, err = randutil.GenerateCryptoRandomString(8, runesAlpha) + if err != nil { + return err + } + } + + // Parse realm and username (already checked in authenticateRequest) + realmAttr := &stun.Realm{} + _ = realmAttr.GetFrom(stunMsg) + usernameAttr := &stun.Username{} + _ = usernameAttr.GetFrom(stunMsg) + + // 7. At any point, the server MAY choose to reject the request with a + // 486 (Allocation Quota Reached) error if it feels the client is + // trying to exceed some locally defined allocation quota. The + // server is free to define this allocation quota any way it wishes, + // but SHOULD define it based on the username used to authenticate + // the request, and not on the client's transport address. + if req.QuotaHandler != nil && !req.QuotaHandler(usernameAttr.String(), realmAttr.String(), req.SrcAddr) { + quotaReachedMsg := buildMsg(stunMsg.TransactionID, + stun.NewType(stun.MethodAllocate, stun.ClassErrorResponse), + &stun.ErrorCodeAttribute{Code: stun.CodeAllocQuotaReached}) + + return buildAndSend(req.Conn, req.SrcAddr, quotaReachedMsg...) + } + + // 8. Also at any point, the server MAY choose to reject the request + // with a 300 (Try Alternate) error if it wishes to redirect the + // client to a different server. The use of this error code and + // attribute follow the specification in [RFC5389]. + lifetimeDuration := allocationLifeTime(stunMsg) + alloc, err := req.AllocationManager.CreateAllocation( + fiveTuple, + req.Conn, + requestedPort, + lifetimeDuration, + usernameAttr.String(), + realmAttr.String(), + ) + if err != nil { + return buildAndSendErr(req.Conn, req.SrcAddr, err, insufficientCapacityMsg...) + } + + // Once the allocation is created, the server replies with a success + // response. + // The success response contains: + // * An XOR-RELAYED-ADDRESS attribute containing the relayed transport + // address. + // * A LIFETIME attribute containing the current value of the time-to- + // expiry timer. + // * A RESERVATION-TOKEN attribute (if a second relayed transport + // address was reserved). + // * An XOR-MAPPED-ADDRESS attribute containing the client's IP address + // and port (from the 5-tuple). + + srcIP, srcPort, err := ipnet.AddrIPPort(req.SrcAddr) + if err != nil { + return buildAndSendErr(req.Conn, req.SrcAddr, err, badRequestMsg...) + } + + relayIP, relayPort, err := ipnet.AddrIPPort(alloc.RelayAddr) + if err != nil { + return buildAndSendErr(req.Conn, req.SrcAddr, err, badRequestMsg...) + } + + responseAttrs := []stun.Setter{ + &proto.RelayedAddress{ + IP: relayIP, + Port: relayPort, + }, + &proto.Lifetime{ + Duration: lifetimeDuration, + }, + &stun.XORMappedAddress{ + IP: srcIP, + Port: srcPort, + }, + } + + if reservationToken != "" { + req.AllocationManager.CreateReservation(reservationToken, relayPort) + responseAttrs = append(responseAttrs, proto.ReservationToken([]byte(reservationToken))) + } + + msg := buildMsg( + stunMsg.TransactionID, + stun.NewType(stun.MethodAllocate, stun.ClassSuccessResponse), + append(responseAttrs, messageIntegrity)..., + ) + alloc.SetResponseCache(stunMsg.TransactionID, responseAttrs) + + return buildAndSend(req.Conn, req.SrcAddr, msg...) +} + +func handleRefreshRequest(req Request, stunMsg *stun.Message) error { + req.Log.Debugf("Received RefreshRequest from %s", req.SrcAddr) + + messageIntegrity, hasAuth, err := authenticateRequest(req, stunMsg, stun.MethodRefresh) + if !hasAuth { + return err + } + + lifetimeDuration := allocationLifeTime(stunMsg) + fiveTuple := &allocation.FiveTuple{ + SrcAddr: req.SrcAddr, + DstAddr: req.Conn.LocalAddr(), + Protocol: allocation.UDP, + } + + if lifetimeDuration != 0 { + a := req.AllocationManager.GetAllocation(fiveTuple) + + if a == nil { + return fmt.Errorf("%w %v:%v", errNoAllocationFound, req.SrcAddr, req.Conn.LocalAddr()) + } + a.Refresh(lifetimeDuration) + } else { + req.AllocationManager.DeleteAllocation(fiveTuple) + } + + return buildAndSend( + req.Conn, + req.SrcAddr, + buildMsg( + stunMsg.TransactionID, + stun.NewType(stun.MethodRefresh, stun.ClassSuccessResponse), + []stun.Setter{ + &proto.Lifetime{ + Duration: lifetimeDuration, + }, + messageIntegrity, + }..., + )..., + ) +} + +func handleCreatePermissionRequest(req Request, stunMsg *stun.Message) error { + req.Log.Debugf("Received CreatePermission from %s", req.SrcAddr) + + alloc := req.AllocationManager.GetAllocation(&allocation.FiveTuple{ + SrcAddr: req.SrcAddr, + DstAddr: req.Conn.LocalAddr(), + Protocol: allocation.UDP, + }) + if alloc == nil { + return fmt.Errorf("%w %v:%v", errNoAllocationFound, req.SrcAddr, req.Conn.LocalAddr()) + } + + messageIntegrity, hasAuth, err := authenticateRequest(req, stunMsg, stun.MethodCreatePermission) + if !hasAuth { + return err + } + + addCount := 0 + + if err := stunMsg.ForEach(stun.AttrXORPeerAddress, func(m *stun.Message) error { + var peerAddress proto.PeerAddress + if err := peerAddress.GetFrom(m); err != nil { + return err + } + + if err := req.AllocationManager.GrantPermission(req.SrcAddr, peerAddress.IP); err != nil { + req.Log.Infof("permission denied for client %s to peer %s", req.SrcAddr, peerAddress.IP) + + return err + } + + req.Log.Debugf("Adding permission for %s", fmt.Sprintf("%s:%d", + peerAddress.IP, peerAddress.Port)) + + alloc.AddPermission(allocation.NewPermission( + &net.UDPAddr{ + IP: peerAddress.IP, + Port: peerAddress.Port, + }, + req.Log, + )) + addCount++ + + return nil + }); err != nil { + addCount = 0 + } + + respClass := stun.ClassSuccessResponse + if addCount == 0 { + respClass = stun.ClassErrorResponse + } + + return buildAndSend( + req.Conn, + req.SrcAddr, + buildMsg(stunMsg.TransactionID, stun.NewType(stun.MethodCreatePermission, respClass), + []stun.Setter{messageIntegrity}...)..., + ) +} + +func handleSendIndication(req Request, stunMsg *stun.Message) error { + req.Log.Debugf("Received SendIndication from %s", req.SrcAddr) + alloc := req.AllocationManager.GetAllocation(&allocation.FiveTuple{ + SrcAddr: req.SrcAddr, + DstAddr: req.Conn.LocalAddr(), + Protocol: allocation.UDP, + }) + if alloc == nil { + return fmt.Errorf("%w %v:%v", errNoAllocationFound, req.SrcAddr, req.Conn.LocalAddr()) + } + + dataAttr := proto.Data{} + if err := dataAttr.GetFrom(stunMsg); err != nil { + return err + } + + peerAddress := proto.PeerAddress{} + if err := peerAddress.GetFrom(stunMsg); err != nil { + return err + } + + msgDst := &net.UDPAddr{IP: peerAddress.IP, Port: peerAddress.Port} + if perm := alloc.GetPermission(msgDst); perm == nil { + return fmt.Errorf("%w: %v", errNoPermission, msgDst) + } + + l, err := alloc.RelaySocket.WriteTo(dataAttr, msgDst) + if err != nil { + return fmt.Errorf("%w: %s", errFailedWriteSocket, err.Error()) + } else if l != len(dataAttr) { + return fmt.Errorf("%w %d != %d (expected)", errShortWrite, l, len(dataAttr)) + } + + return err +} + +func handleChannelBindRequest(req Request, stunMsg *stun.Message) error { + req.Log.Debugf("Received ChannelBindRequest from %s", req.SrcAddr) + + alloc := req.AllocationManager.GetAllocation(&allocation.FiveTuple{ + SrcAddr: req.SrcAddr, + DstAddr: req.Conn.LocalAddr(), + Protocol: allocation.UDP, + }) + if alloc == nil { + return fmt.Errorf("%w %v:%v", errNoAllocationFound, req.SrcAddr, req.Conn.LocalAddr()) + } + + badRequestMsg := buildMsg( + stunMsg.TransactionID, + stun.NewType(stun.MethodChannelBind, stun.ClassErrorResponse), + &stun.ErrorCodeAttribute{Code: stun.CodeBadRequest}, + ) + + messageIntegrity, hasAuth, err := authenticateRequest(req, stunMsg, stun.MethodChannelBind) + if !hasAuth { + return err + } + + var channel proto.ChannelNumber + if err = channel.GetFrom(stunMsg); err != nil { + return buildAndSendErr(req.Conn, req.SrcAddr, err, badRequestMsg...) + } + + peerAddr := proto.PeerAddress{} + if err = peerAddr.GetFrom(stunMsg); err != nil { + return buildAndSendErr(req.Conn, req.SrcAddr, err, badRequestMsg...) + } + + if err = req.AllocationManager.GrantPermission(req.SrcAddr, peerAddr.IP); err != nil { + req.Log.Infof("permission denied for client %s to peer %s", req.SrcAddr, peerAddr.IP) + + unauthorizedRequestMsg := buildMsg(stunMsg.TransactionID, + stun.NewType(stun.MethodChannelBind, stun.ClassErrorResponse), + &stun.ErrorCodeAttribute{Code: stun.CodeUnauthorized}) + + return buildAndSendErr(req.Conn, req.SrcAddr, err, unauthorizedRequestMsg...) + } + + req.Log.Debugf("Binding channel %d to %s", channel, peerAddr) + err = alloc.AddChannelBind(allocation.NewChannelBind( + channel, + &net.UDPAddr{IP: peerAddr.IP, Port: peerAddr.Port}, + req.Log, + ), req.ChannelBindTimeout) + if err != nil { + return buildAndSendErr(req.Conn, req.SrcAddr, err, badRequestMsg...) + } + + return buildAndSend( + req.Conn, + req.SrcAddr, + buildMsg(stunMsg.TransactionID, stun.NewType(stun.MethodChannelBind, stun.ClassSuccessResponse), + []stun.Setter{messageIntegrity}...)..., + ) +} + +func handleChannelData(req Request, channelData *proto.ChannelData) error { + req.Log.Debugf("Received ChannelData from %s", req.SrcAddr) + + alloc := req.AllocationManager.GetAllocation(&allocation.FiveTuple{ + SrcAddr: req.SrcAddr, + DstAddr: req.Conn.LocalAddr(), + Protocol: allocation.UDP, + }) + if alloc == nil { + return fmt.Errorf("%w %v:%v", errNoAllocationFound, req.SrcAddr, req.Conn.LocalAddr()) + } + + channel := alloc.GetChannelByNumber(channelData.Number) + if channel == nil { + return fmt.Errorf("%w %x", errNoSuchChannelBind, uint16(channelData.Number)) + } + + l, err := alloc.RelaySocket.WriteTo(channelData.Data, channel.Peer) + if err != nil { + return fmt.Errorf("%w: %s", errFailedWriteSocket, err.Error()) + } else if l != len(channelData.Data) { + return fmt.Errorf("%w %d != %d (expected)", errShortWrite, l, len(channelData.Data)) + } + + return nil +} diff --git a/vendor/github.com/pion/turn/v4/internal/server/util.go b/vendor/github.com/pion/turn/v4/internal/server/util.go new file mode 100644 index 0000000..65ce103 --- /dev/null +++ b/vendor/github.com/pion/turn/v4/internal/server/util.go @@ -0,0 +1,163 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package server + +import ( + "errors" + "fmt" + "net" + "time" + + "github.com/pion/stun/v3" + "github.com/pion/turn/v4/internal/proto" +) + +const ( + // See: https://tools.ietf.org/html/rfc5766#section-6.2 defines 3600 seconds recommendation. + maximumAllocationLifetime = time.Hour +) + +func buildAndSend(conn net.PacketConn, dst net.Addr, attrs ...stun.Setter) error { + msg, err := stun.Build(attrs...) + if err != nil { + return err + } + _, err = conn.WriteTo(msg.Raw, dst) + if errors.Is(err, net.ErrClosed) { + return nil + } + + return err +} + +// Send a STUN packet and return the original error to the caller. +func buildAndSendErr(conn net.PacketConn, dst net.Addr, err error, attrs ...stun.Setter) error { + if sendErr := buildAndSend(conn, dst, attrs...); sendErr != nil { + err = fmt.Errorf("%w %v %v", errFailedToSendError, sendErr, err) //nolint:errorlint + } + + return err +} + +func buildMsg( + transactionID [stun.TransactionIDSize]byte, + msgType stun.MessageType, + additional ...stun.Setter, +) []stun.Setter { + return append([]stun.Setter{&stun.Message{TransactionID: transactionID}, msgType}, additional...) +} + +func authenticateRequest(req Request, stunMsg *stun.Message, callingMethod stun.Method) ( + stun.MessageIntegrity, + bool, + error, +) { + respondWithNonce := func(responseCode stun.ErrorCode) (stun.MessageIntegrity, bool, error) { + nonce, err := req.NonceHash.Generate() + if err != nil { + return nil, false, err + } + + return nil, false, buildAndSend(req.Conn, req.SrcAddr, buildMsg(stunMsg.TransactionID, + stun.NewType(callingMethod, stun.ClassErrorResponse), + &stun.ErrorCodeAttribute{Code: responseCode}, + stun.NewNonce(nonce), + stun.NewRealm(req.Realm), + )...) + } + + if !stunMsg.Contains(stun.AttrMessageIntegrity) { + return respondWithNonce(stun.CodeUnauthorized) + } + + nonceAttr := &stun.Nonce{} + usernameAttr := &stun.Username{} + realmAttr := &stun.Realm{} + badRequestMsg := buildMsg( + stunMsg.TransactionID, + stun.NewType(callingMethod, stun.ClassErrorResponse), + &stun.ErrorCodeAttribute{Code: stun.CodeBadRequest}, + ) + + // No Auth handler is set, server is running in STUN only mode + // Respond with 400 so clients don't retry. + if req.AuthHandler == nil { + sendErr := buildAndSend(req.Conn, req.SrcAddr, badRequestMsg...) + + return nil, false, sendErr + } + + if err := nonceAttr.GetFrom(stunMsg); err != nil { + return nil, false, buildAndSendErr(req.Conn, req.SrcAddr, err, badRequestMsg...) + } + + // Assert Nonce is signed and is not expired. + if err := req.NonceHash.Validate(nonceAttr.String()); err != nil { + return respondWithNonce(stun.CodeStaleNonce) + } + + if err := realmAttr.GetFrom(stunMsg); err != nil { + return nil, false, buildAndSendErr(req.Conn, req.SrcAddr, err, badRequestMsg...) + } else if err := usernameAttr.GetFrom(stunMsg); err != nil { + return nil, false, buildAndSendErr(req.Conn, req.SrcAddr, err, badRequestMsg...) + } + + ourKey, ok := req.AuthHandler(usernameAttr.String(), realmAttr.String(), req.SrcAddr) + if !ok { + return nil, false, buildAndSendErr( + req.Conn, + req.SrcAddr, + fmt.Errorf("%w %s", errNoSuchUser, usernameAttr.String()), + badRequestMsg..., + ) + } + + if err := stun.MessageIntegrity(ourKey).Check(stunMsg); err != nil { + genAuthEvent(req, stunMsg, callingMethod, false) + + return nil, false, buildAndSendErr(req.Conn, req.SrcAddr, err, badRequestMsg...) + } + + genAuthEvent(req, stunMsg, callingMethod, true) + + return stun.MessageIntegrity(ourKey), true, nil +} + +func genAuthEvent(req Request, stunMsg *stun.Message, callingMethod stun.Method, verdict bool) { + if req.AllocationManager.EventHandler.OnAuth == nil { + return + } + + realmAttr := &stun.Realm{} + if err := realmAttr.GetFrom(stunMsg); err != nil { + return + } + + usernameAttr := &stun.Username{} + if err := usernameAttr.GetFrom(stunMsg); err != nil { + return + } + + transportAttr := &proto.RequestedTransport{} + if err := transportAttr.GetFrom(stunMsg); err != nil { + transportAttr = &proto.RequestedTransport{Protocol: proto.ProtoUDP} + } + + req.AllocationManager.EventHandler.OnAuth(req.SrcAddr, req.Conn.LocalAddr(), + transportAttr.Protocol.String(), usernameAttr.String(), realmAttr.String(), + callingMethod.String(), verdict) +} + +func allocationLifeTime(m *stun.Message) time.Duration { + lifetimeDuration := proto.DefaultLifetime + + var lifetime proto.Lifetime + if err := lifetime.GetFrom(m); err == nil { + if lifetime.Duration < maximumAllocationLifetime { + lifetimeDuration = lifetime.Duration + } + } + + return lifetimeDuration +} diff --git a/vendor/github.com/pion/turn/v4/lt_cred.go b/vendor/github.com/pion/turn/v4/lt_cred.go new file mode 100644 index 0000000..65576be --- /dev/null +++ b/vendor/github.com/pion/turn/v4/lt_cred.go @@ -0,0 +1,118 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package turn + +import ( + "crypto/hmac" + "crypto/sha1" //nolint:gosec,gci + "encoding/base64" + "net" + "strconv" + "strings" + "time" + + "github.com/pion/logging" +) + +// GenerateLongTermCredentials can be used to create credentials valid for [duration] time. +func GenerateLongTermCredentials(sharedSecret string, duration time.Duration) (string, string, error) { + t := time.Now().Add(duration).Unix() + username := strconv.FormatInt(t, 10) + password, err := longTermCredentials(username, sharedSecret) + + return username, password, err +} + +// GenerateLongTermTURNRESTCredentials can be used to create credentials valid for [duration] time. +func GenerateLongTermTURNRESTCredentials(sharedSecret string, user string, duration time.Duration) ( + string, + string, + error, +) { + t := time.Now().Add(duration).Unix() + timestamp := strconv.FormatInt(t, 10) + username := timestamp + ":" + user + password, err := longTermCredentials(username, sharedSecret) + + return username, password, err +} + +func longTermCredentials(username string, sharedSecret string) (string, error) { + mac := hmac.New(sha1.New, []byte(sharedSecret)) + _, err := mac.Write([]byte(username)) + if err != nil { + return "", err // Not sure if this will ever happen + } + password := mac.Sum(nil) + + return base64.StdEncoding.EncodeToString(password), nil +} + +// NewLongTermAuthHandler returns a turn.AuthAuthHandler used with Long Term (or Time Windowed) Credentials. +// See: https://datatracker.ietf.org/doc/html/rfc8489#section-9.2 +// . +func NewLongTermAuthHandler(sharedSecret string, logger logging.LeveledLogger) AuthHandler { + if logger == nil { + logger = logging.NewDefaultLoggerFactory().NewLogger("turn") + } + + return func(username, realm string, srcAddr net.Addr) (key []byte, ok bool) { + logger.Tracef("Authentication username=%q realm=%q srcAddr=%v", username, realm, srcAddr) + t, err := strconv.Atoi(username) + if err != nil { + logger.Errorf("Invalid time-windowed username %q", username) + + return nil, false + } + if int64(t) < time.Now().Unix() { + logger.Errorf("Expired time-windowed username %q", username) + + return nil, false + } + password, err := longTermCredentials(username, sharedSecret) + if err != nil { + logger.Error(err.Error()) + + return nil, false + } + + return GenerateAuthKey(username, realm, password), true + } +} + +// LongTermTURNRESTAuthHandler returns a turn.AuthAuthHandler that can be used to authenticate +// time-windowed ephemeral credentials generated by the TURN REST API as described in +// https://datatracker.ietf.org/doc/html/draft-uberti-behave-turn-rest-00 +// +// The supported format of is timestamp:username, where username is an arbitrary user id and the +// timestamp specifies the expiry of the credential. +func LongTermTURNRESTAuthHandler(sharedSecret string, logger logging.LeveledLogger) AuthHandler { + if logger == nil { + logger = logging.NewDefaultLoggerFactory().NewLogger("turn") + } + + return func(username, realm string, srcAddr net.Addr) (key []byte, ok bool) { + logger.Tracef("Authentication username=%q realm=%q srcAddr=%v", username, realm, srcAddr) + timestamp := strings.Split(username, ":")[0] + t, err := strconv.Atoi(timestamp) + if err != nil { + logger.Errorf("Invalid time-windowed username %q", username) + + return nil, false + } + if int64(t) < time.Now().Unix() { + logger.Errorf("Expired time-windowed username %q", username) + + return nil, false + } + password, err := longTermCredentials(username, sharedSecret) + if err != nil { + logger.Error(err.Error()) + + return nil, false + } + + return GenerateAuthKey(username, realm, password), true + } +} diff --git a/vendor/github.com/pion/turn/v4/relay_address_generator_none.go b/vendor/github.com/pion/turn/v4/relay_address_generator_none.go new file mode 100644 index 0000000..f549af7 --- /dev/null +++ b/vendor/github.com/pion/turn/v4/relay_address_generator_none.go @@ -0,0 +1,59 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package turn + +import ( + "fmt" + "net" + "strconv" + + "github.com/pion/transport/v4" + "github.com/pion/transport/v4/stdnet" +) + +// RelayAddressGeneratorNone returns the listener with no modifications. +type RelayAddressGeneratorNone struct { + // Address is passed to Listen/ListenPacket when creating the Relay + Address string + + Net transport.Net +} + +// Validate is called on server startup and confirms the RelayAddressGenerator is properly configured. +func (r *RelayAddressGeneratorNone) Validate() error { + if r.Net == nil { + var err error + r.Net, err = stdnet.NewNet() + if err != nil { + return fmt.Errorf("failed to create network: %w", err) + } + } + + if r.Address == "" { + return errListeningAddressInvalid + } + + return nil +} + +// AllocatePacketConn generates a new PacketConn to receive traffic on and the IP/Port +// to populate the allocation response with. +func (r *RelayAddressGeneratorNone) AllocatePacketConn(network string, requestedPort int) ( + net.PacketConn, + net.Addr, + error, +) { + conn, err := r.Net.ListenPacket(network, r.Address+":"+strconv.Itoa(requestedPort)) // nolint: noctx + if err != nil { + return nil, nil, err + } + + return conn, conn.LocalAddr(), nil +} + +// AllocateConn generates a new Conn to receive traffic on and the IP/Port +// to populate the allocation response with. +func (r *RelayAddressGeneratorNone) AllocateConn(string, int) (net.Conn, net.Addr, error) { + return nil, nil, errTODO +} diff --git a/vendor/github.com/pion/turn/v4/relay_address_generator_range.go b/vendor/github.com/pion/turn/v4/relay_address_generator_range.go new file mode 100644 index 0000000..38201c6 --- /dev/null +++ b/vendor/github.com/pion/turn/v4/relay_address_generator_range.go @@ -0,0 +1,116 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package turn + +import ( + "fmt" + "net" + + "github.com/pion/randutil" + "github.com/pion/transport/v4" + "github.com/pion/transport/v4/stdnet" +) + +// RelayAddressGeneratorPortRange can be used to only allocate connections inside a defined port range. +// Similar to the RelayAddressGeneratorStatic a static ip address can be set. +type RelayAddressGeneratorPortRange struct { + // RelayAddress is the IP returned to the user when the relay is created + RelayAddress net.IP + + // MinPort the minimum port to allocate + MinPort uint16 + // MaxPort the maximum (inclusive) port to allocate + MaxPort uint16 + + // MaxRetries the amount of tries to allocate a random port in the defined range + MaxRetries int + + // Rand the random source of numbers + Rand randutil.MathRandomGenerator + + // Address is passed to Listen/ListenPacket when creating the Relay + Address string + + Net transport.Net +} + +// Validate is called on server startup and confirms the RelayAddressGenerator is properly configured. +func (r *RelayAddressGeneratorPortRange) Validate() error { + if r.Net == nil { + var err error + r.Net, err = stdnet.NewNet() + if err != nil { + return fmt.Errorf("failed to create network: %w", err) + } + } + + if r.Rand == nil { + r.Rand = randutil.NewMathRandomGenerator() + } + + if r.MaxRetries == 0 { + r.MaxRetries = 10 + } + + switch { + case r.MinPort == 0: + return errMinPortNotZero + case r.MaxPort == 0: + return errMaxPortNotZero + case r.RelayAddress == nil: + return errRelayAddressInvalid + case r.Address == "": + return errListeningAddressInvalid + default: + return nil + } +} + +// AllocatePacketConn generates a new PacketConn to receive traffic on and the IP/Port +// to populate the allocation response with. +func (r *RelayAddressGeneratorPortRange) AllocatePacketConn( + network string, + requestedPort int, +) (net.PacketConn, net.Addr, error) { + if requestedPort != 0 { + conn, err := r.Net.ListenPacket(network, fmt.Sprintf("%s:%d", r.Address, requestedPort)) // nolint: noctx + if err != nil { + return nil, nil, err + } + + relayAddr, ok := conn.LocalAddr().(*net.UDPAddr) + if !ok { + return nil, nil, errNilConn + } + + relayAddr.IP = r.RelayAddress + + return conn, relayAddr, nil + } + + for try := 0; try < r.MaxRetries; try++ { + port := r.MinPort + uint16(r.Rand.Intn(int((r.MaxPort+1)-r.MinPort))) // nolint:gosec // G115 false positive + conn, err := r.Net.ListenPacket(network, fmt.Sprintf("%s:%d", r.Address, port)) // nolint: noctx + if err != nil { + continue + } + + relayAddr, ok := conn.LocalAddr().(*net.UDPAddr) + if !ok { + return nil, nil, errNilConn + } + + relayAddr.IP = r.RelayAddress + + return conn, relayAddr, nil + } + + return nil, nil, errMaxRetriesExceeded +} + +// AllocateConn generates a new Conn to receive traffic on and the IP/Port +// to populate the allocation response with. +func (r *RelayAddressGeneratorPortRange) AllocateConn(string, int) (net.Conn, net.Addr, error) { + return nil, nil, errTODO +} diff --git a/vendor/github.com/pion/turn/v4/relay_address_generator_static.go b/vendor/github.com/pion/turn/v4/relay_address_generator_static.go new file mode 100644 index 0000000..a448363 --- /dev/null +++ b/vendor/github.com/pion/turn/v4/relay_address_generator_static.go @@ -0,0 +1,73 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package turn + +import ( + "fmt" + "net" + "strconv" + + "github.com/pion/transport/v4" + "github.com/pion/transport/v4/stdnet" +) + +// RelayAddressGeneratorStatic can be used to return static IP address each time a relay is created. +// This can be used when you have a single static IP address that you want to use. +type RelayAddressGeneratorStatic struct { + // RelayAddress is the IP returned to the user when the relay is created + RelayAddress net.IP + + // Address is passed to Listen/ListenPacket when creating the Relay + Address string + + Net transport.Net +} + +// Validate is called on server startup and confirms the RelayAddressGenerator is properly configured. +func (r *RelayAddressGeneratorStatic) Validate() error { + if r.Net == nil { + var err error + r.Net, err = stdnet.NewNet() + if err != nil { + return fmt.Errorf("failed to create network: %w", err) + } + } + + switch { + case r.RelayAddress == nil: + return errRelayAddressInvalid + case r.Address == "": + return errListeningAddressInvalid + default: + return nil + } +} + +// AllocatePacketConn generates a new PacketConn to receive traffic on and the IP/Port +// to populate the allocation response with. +func (r *RelayAddressGeneratorStatic) AllocatePacketConn( + network string, + requestedPort int, +) (net.PacketConn, net.Addr, error) { + conn, err := r.Net.ListenPacket(network, r.Address+":"+strconv.Itoa(requestedPort)) // nolint: noctx + if err != nil { + return nil, nil, err + } + + // Replace actual listening IP with the user requested one of RelayAddressGeneratorStatic + relayAddr, ok := conn.LocalAddr().(*net.UDPAddr) + if !ok { + return nil, nil, errNilConn + } + + relayAddr.IP = r.RelayAddress + + return conn, relayAddr, nil +} + +// AllocateConn generates a new Conn to receive traffic on and the IP/Port +// to populate the allocation response with. +func (r *RelayAddressGeneratorStatic) AllocateConn(string, int) (net.Conn, net.Addr, error) { + return nil, nil, errTODO +} diff --git a/vendor/github.com/pion/turn/v4/renovate.json b/vendor/github.com/pion/turn/v4/renovate.json new file mode 100644 index 0000000..f1bb98c --- /dev/null +++ b/vendor/github.com/pion/turn/v4/renovate.json @@ -0,0 +1,6 @@ +{ + "$schema": "https://docs.renovatebot.com/renovate-schema.json", + "extends": [ + "github>pion/renovate-config" + ] +} diff --git a/vendor/github.com/pion/turn/v4/server.go b/vendor/github.com/pion/turn/v4/server.go new file mode 100644 index 0000000..eaf6fdb --- /dev/null +++ b/vendor/github.com/pion/turn/v4/server.go @@ -0,0 +1,248 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +// Package turn contains the public API for pion/turn, a toolkit for building TURN clients and servers +package turn + +import ( + "errors" + "fmt" + "net" + "time" + + "github.com/pion/logging" + "github.com/pion/turn/v4/internal/allocation" + "github.com/pion/turn/v4/internal/proto" + "github.com/pion/turn/v4/internal/server" +) + +const ( + defaultInboundMTU = 1600 +) + +// Server is an instance of the Pion TURN Server. +type Server struct { + log logging.LeveledLogger + authHandler AuthHandler + quotaHandler QuotaHandler + realm string + channelBindTimeout time.Duration + nonceHash server.NonceManager + eventHandler EventHandler + + packetConnConfigs []PacketConnConfig + listenerConfigs []ListenerConfig + allocationManagers []*allocation.Manager + inboundMTU int +} + +// NewServer creates the Pion TURN server. +func NewServer(config ServerConfig) (*Server, error) { //nolint:gocognit,cyclop + if err := config.validate(); err != nil { + return nil, err + } + + loggerFactory := config.LoggerFactory + if loggerFactory == nil { + loggerFactory = logging.NewDefaultLoggerFactory() + } + + mtu := defaultInboundMTU + if config.InboundMTU != 0 { + mtu = config.InboundMTU + } + + nonceHash, err := server.NewShortNonceHash(0) + if err != nil { + return nil, err + } + + server := &Server{ + log: loggerFactory.NewLogger("turn"), + authHandler: config.AuthHandler, + quotaHandler: config.QuotaHandler, + realm: config.Realm, + channelBindTimeout: config.ChannelBindTimeout, + packetConnConfigs: config.PacketConnConfigs, + listenerConfigs: config.ListenerConfigs, + nonceHash: nonceHash, + inboundMTU: mtu, + eventHandler: config.EventHandler, + } + + if server.channelBindTimeout == 0 { + server.channelBindTimeout = proto.DefaultLifetime + } + + for _, cfg := range server.packetConnConfigs { + am, err := server.createAllocationManager(cfg.RelayAddressGenerator, cfg.PermissionHandler) + if err != nil { + return nil, fmt.Errorf("failed to create AllocationManager: %w", err) + } + + go func(cfg PacketConnConfig, am *allocation.Manager) { + server.readLoop(cfg.PacketConn, am) + + if err := am.Close(); err != nil { + server.log.Errorf("Failed to close AllocationManager: %s", err) + } + }(cfg, am) + } + + for _, cfg := range server.listenerConfigs { + am, err := server.createAllocationManager(cfg.RelayAddressGenerator, cfg.PermissionHandler) + if err != nil { + return nil, fmt.Errorf("failed to create AllocationManager: %w", err) + } + + go func(cfg ListenerConfig, am *allocation.Manager) { + server.readListener(cfg.Listener, am) + + if err := am.Close(); err != nil { + server.log.Errorf("Failed to close AllocationManager: %s", err) + } + }(cfg, am) + } + + return server, nil +} + +// AllocationCount returns the number of active allocations. +// It can be used to drain the server before closing. +func (s *Server) AllocationCount() int { + allocs := 0 + for _, am := range s.allocationManagers { + allocs += am.AllocationCount() + } + + return allocs +} + +// Close stops the TURN Server. +// It cleans up any associated state and closes all connections it is managing. +func (s *Server) Close() error { + var errors []error + + for _, cfg := range s.packetConnConfigs { + if err := cfg.PacketConn.Close(); err != nil { + errors = append(errors, err) + } + } + + for _, cfg := range s.listenerConfigs { + if err := cfg.Listener.Close(); err != nil { + errors = append(errors, err) + } + } + + if len(errors) == 0 { + return nil + } + + err := errFailedToClose + for _, e := range errors { + err = fmt.Errorf("%s; close error (%w) ", err, e) //nolint:errorlint + } + + return err +} + +func (s *Server) readListener(l net.Listener, am *allocation.Manager) { + for { + conn, err := l.Accept() + if err != nil { + s.log.Debugf("Failed to accept: %s", err) + + return + } + + go func() { + s.readLoop(NewSTUNConn(conn), am) + + // Delete allocation + am.DeleteAllocation(&allocation.FiveTuple{ + Protocol: allocation.UDP, // fixed UDP + SrcAddr: conn.RemoteAddr(), + DstAddr: conn.LocalAddr(), + }) + + if err := conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { + s.log.Errorf("Failed to close conn: %s", err) + } + }() + } +} + +type nilAddressGenerator struct{} + +func (n *nilAddressGenerator) Validate() error { return errRelayAddressGeneratorNil } + +func (n *nilAddressGenerator) AllocatePacketConn(string, int) (net.PacketConn, net.Addr, error) { + return nil, nil, errRelayAddressGeneratorNil +} + +func (n *nilAddressGenerator) AllocateConn(string, int) (net.Conn, net.Addr, error) { + return nil, nil, errRelayAddressGeneratorNil +} + +func (s *Server) createAllocationManager( + addrGenerator RelayAddressGenerator, + handler PermissionHandler, +) (*allocation.Manager, error) { + if handler == nil { + handler = DefaultPermissionHandler + } + if addrGenerator == nil { + addrGenerator = &nilAddressGenerator{} + } + + am, err := allocation.NewManager(allocation.ManagerConfig{ + AllocatePacketConn: addrGenerator.AllocatePacketConn, + AllocateConn: addrGenerator.AllocateConn, + PermissionHandler: handler, + EventHandler: s.eventHandler, + LeveledLogger: s.log, + }) + if err != nil { + return am, err + } + + s.allocationManagers = append(s.allocationManagers, am) + + return am, err +} + +func (s *Server) readLoop(conn net.PacketConn, allocationManager *allocation.Manager) { + buf := make([]byte, s.inboundMTU) + for { + n, addr, err := conn.ReadFrom(buf) + switch { + case err != nil: + s.log.Debugf("Exit read loop on error: %s", err) + + return + case n >= s.inboundMTU: + s.log.Debugf("Read bytes exceeded MTU, packet is possibly truncated") + + continue + } + + if err := server.HandleRequest(server.Request{ + Conn: conn, + SrcAddr: addr, + Buff: buf[:n], + Log: s.log, + AuthHandler: s.authHandler, + QuotaHandler: s.quotaHandler, + Realm: s.realm, + AllocationManager: allocationManager, + ChannelBindTimeout: s.channelBindTimeout, + NonceHash: s.nonceHash, + }); err != nil { + if s.eventHandler.OnAllocationError != nil { + s.eventHandler.OnAllocationError(addr, conn.LocalAddr(), allocation.UDP.String(), err.Error()) + } + s.log.Debugf("Failed to handle datagram: %v", err) + } + } +} diff --git a/vendor/github.com/pion/turn/v4/server_config.go b/vendor/github.com/pion/turn/v4/server_config.go new file mode 100644 index 0000000..fcddabc --- /dev/null +++ b/vendor/github.com/pion/turn/v4/server_config.go @@ -0,0 +1,170 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package turn + +import ( + "crypto/md5" //nolint:gosec,gci + "fmt" + "net" + "strings" + "time" + + "github.com/pion/logging" + "github.com/pion/turn/v4/internal/allocation" +) + +// RelayAddressGenerator is used to generate a RelayAddress when creating an allocation. +// You can use one of the provided ones or provide your own. +type RelayAddressGenerator interface { + // Validate confirms that the RelayAddressGenerator is properly initialized + Validate() error + + // Allocate a PacketConn (UDP) RelayAddress + AllocatePacketConn(network string, requestedPort int) (net.PacketConn, net.Addr, error) + + // Allocate a Conn (TCP) RelayAddress + AllocateConn(network string, requestedPort int) (net.Conn, net.Addr, error) +} + +// PermissionHandler is a callback to filter incoming CreatePermission and ChannelBindRequest +// requests based on the client IP address and port and the peer IP address the client intends to +// connect to. If the client is behind a NAT then the filter acts on the server reflexive +// ("mapped") address instead of the real client IP address and port. Note that TURN permissions +// are per-allocation and per-peer-IP-address, to mimic the address-restricted filtering mechanism +// of NATs that comply with [RFC4787], see https://tools.ietf.org/html/rfc5766#section-2.3. +type PermissionHandler func(clientAddr net.Addr, peerIP net.IP) (ok bool) + +// DefaultPermissionHandler is convince function that grants permission to all peers. +func DefaultPermissionHandler(net.Addr, net.IP) (ok bool) { + return true +} + +// PacketConnConfig is a single net.PacketConn to listen/write on. +// This will be used for UDP listeners. +type PacketConnConfig struct { + PacketConn net.PacketConn + + // When an allocation is generated the RelayAddressGenerator + // creates the net.PacketConn and returns the IP/Port it is available at + RelayAddressGenerator RelayAddressGenerator + + // PermissionHandler is a callback to filter peer addresses. Can be set as nil, in which + // case the DefaultPermissionHandler is automatically instantiated to admit all peer + // connections + PermissionHandler PermissionHandler +} + +func (c *PacketConnConfig) validate() error { + if c.PacketConn == nil { + return errConnUnset + } + + if c.RelayAddressGenerator != nil { + if err := c.RelayAddressGenerator.Validate(); err != nil { + return err + } + } + + return nil +} + +// ListenerConfig is a single net.Listener to accept connections on. +// This will be used for TCP, TLS and DTLS listeners. +type ListenerConfig struct { + Listener net.Listener + + // When an allocation is generated the RelayAddressGenerator + // creates the net.PacketConn and returns the IP/Port it is available at + RelayAddressGenerator RelayAddressGenerator + + // PermissionHandler is a callback to filter peer addresses. Can be set as nil, in which + // case the DefaultPermissionHandler is automatically instantiated to admit all peer + // connections + PermissionHandler PermissionHandler +} + +func (c *ListenerConfig) validate() error { + if c.Listener == nil { + return errListenerUnset + } + + if c.RelayAddressGenerator == nil { + return errRelayAddressGeneratorUnset + } + + return c.RelayAddressGenerator.Validate() +} + +// AuthHandler is a callback used to handle incoming auth requests, +// allowing users to customize Pion TURN with custom behavior. +type AuthHandler func(username, realm string, srcAddr net.Addr) (key []byte, ok bool) + +// GenerateAuthKey is a convenience function to easily generate keys in the format used by AuthHandler. +func GenerateAuthKey(username, realm, password string) []byte { + // #nosec + h := md5.New() + fmt.Fprint(h, strings.Join([]string{username, realm, password}, ":")) // nolint: errcheck + + return h.Sum(nil) +} + +// EventHandler is a set of callbacks that the server will call at certain hook points during an +// allocation's lifecycle. +type EventHandler = allocation.EventHandler + +// QuotaHandler is a callback allows allocations to be rejected when a per-user quota is +// exceeded. If the callback returns true the allocation request is accepted, otherwise it is +// rejected and a 486 (Allocation Quota Reached) error is returned to the user. +type QuotaHandler func(username, realm string, srcAddr net.Addr) (ok bool) + +// ServerConfig configures the Pion TURN Server. +type ServerConfig struct { + // PacketConnConfigs and ListenerConfigs are a list of all the turn listeners + // Each listener can have custom behavior around the creation of Relays + PacketConnConfigs []PacketConnConfig + ListenerConfigs []ListenerConfig + + // LoggerFactory must be set for logging from this server. + LoggerFactory logging.LoggerFactory + + // Realm sets the realm for this server + Realm string + + // AuthHandler is a callback used to handle incoming auth requests, + // allowing users to customize Pion TURN with custom behavior + AuthHandler AuthHandler + + // QuotaHandler is a callback used to reject new allocations when a + // per-user quota is exceeded. + QuotaHandler QuotaHandler + + // EventHandlers is a set of callbacks for tracking allocation lifecycle. + EventHandler EventHandler + + // ChannelBindTimeout sets the lifetime of channel binding. Defaults to 10 minutes. + ChannelBindTimeout time.Duration + + // Sets the server inbound MTU(Maximum transmition unit). Defaults to 1600 bytes. + InboundMTU int +} + +func (s *ServerConfig) validate() error { + if len(s.PacketConnConfigs) == 0 && len(s.ListenerConfigs) == 0 { + return errNoAvailableConns + } + + for _, s := range s.PacketConnConfigs { + if err := s.validate(); err != nil { + return err + } + } + + for _, s := range s.ListenerConfigs { + if err := s.validate(); err != nil { + return err + } + } + + return nil +} diff --git a/vendor/github.com/pion/turn/v4/stun_conn.go b/vendor/github.com/pion/turn/v4/stun_conn.go new file mode 100644 index 0000000..fd1f1ae --- /dev/null +++ b/vendor/github.com/pion/turn/v4/stun_conn.go @@ -0,0 +1,128 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package turn + +import ( + "encoding/binary" + "errors" + "net" + "time" + + "github.com/pion/stun/v3" + "github.com/pion/turn/v4/internal/proto" +) + +var ( + errInvalidTURNFrame = errors.New("data is not a valid TURN frame, no STUN or ChannelData found") + errIncompleteTURNFrame = errors.New("data contains incomplete STUN or TURN frame") +) + +// STUNConn wraps a net.Conn and implements +// net.PacketConn by being STUN aware and +// packetizing the stream. +type STUNConn struct { + nextConn net.Conn + buff []byte +} + +const ( + stunHeaderSize = 20 + + channelDataLengthSize = 2 + channelDataNumberSize = channelDataLengthSize + channelDataHeaderSize = channelDataLengthSize + channelDataNumberSize + channelDataPadding = 4 +) + +// Given a buffer give the last offset of the TURN frame +// If the buffer isn't a valid STUN or ChannelData packet, +// or the length doesn't match return false. +func consumeSingleTURNFrame(b []byte) (int, error) { + // Too short to determine if ChannelData or STUN + if len(b) < 9 { + return 0, errIncompleteTURNFrame + } + + var datagramSize uint16 + switch { + case stun.IsMessage(b): + datagramSize = binary.BigEndian.Uint16(b[2:4]) + stunHeaderSize + case proto.ChannelNumber(binary.BigEndian.Uint16(b[0:2])).Valid(): + datagramSize = binary.BigEndian.Uint16(b[channelDataNumberSize:channelDataHeaderSize]) + if paddingOverflow := (datagramSize + channelDataPadding) % channelDataPadding; paddingOverflow != 0 { + datagramSize = (datagramSize + channelDataPadding) - paddingOverflow + } + + datagramSize += channelDataHeaderSize + case len(b) < stunHeaderSize: + return 0, errIncompleteTURNFrame + default: + return 0, errInvalidTURNFrame + } + + if len(b) < int(datagramSize) { + return 0, errIncompleteTURNFrame + } + + return int(datagramSize), nil +} + +// ReadFrom implements ReadFrom from net.PacketConn. +func (s *STUNConn) ReadFrom(payload []byte) (n int, addr net.Addr, err error) { + // First pass any buffered data from previous reads + n, err = consumeSingleTURNFrame(s.buff) + if errors.Is(err, errInvalidTURNFrame) { + return 0, nil, err + } else if err == nil { + copy(payload, s.buff[:n]) + s.buff = s.buff[n:] + + return n, s.nextConn.RemoteAddr(), nil + } + + // Then read from the nextConn, appending to our buff + n, err = s.nextConn.Read(payload) + if err != nil { + return 0, nil, err + } + + s.buff = append(s.buff, append([]byte{}, payload[:n]...)...) + + return s.ReadFrom(payload) +} + +// WriteTo implements WriteTo from net.PacketConn. +func (s *STUNConn) WriteTo(payload []byte, _ net.Addr) (n int, err error) { + return s.nextConn.Write(payload) +} + +// Close implements Close from net.PacketConn. +func (s *STUNConn) Close() error { + return s.nextConn.Close() +} + +// LocalAddr implements LocalAddr from net.PacketConn. +func (s *STUNConn) LocalAddr() net.Addr { + return s.nextConn.LocalAddr() +} + +// SetDeadline implements SetDeadline from net.PacketConn. +func (s *STUNConn) SetDeadline(t time.Time) error { + return s.nextConn.SetDeadline(t) +} + +// SetReadDeadline implements SetReadDeadline from net.PacketConn. +func (s *STUNConn) SetReadDeadline(t time.Time) error { + return s.nextConn.SetReadDeadline(t) +} + +// SetWriteDeadline implements SetWriteDeadline from net.PacketConn. +func (s *STUNConn) SetWriteDeadline(t time.Time) error { + return s.nextConn.SetWriteDeadline(t) +} + +// NewSTUNConn creates a STUNConn. +func NewSTUNConn(nextConn net.Conn) *STUNConn { + return &STUNConn{nextConn: nextConn} +} diff --git a/vendor/github.com/pion/webrtc/v4/.codacy.yaml b/vendor/github.com/pion/webrtc/v4/.codacy.yaml new file mode 100644 index 0000000..4e7406f --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/.codacy.yaml @@ -0,0 +1,6 @@ +--- +# SPDX-FileCopyrightText: 2026 The Pion community +# SPDX-License-Identifier: MIT + +exclude_paths: + - examples/examples.json diff --git a/vendor/github.com/pion/webrtc/v4/.eslintrc.json b/vendor/github.com/pion/webrtc/v4/.eslintrc.json new file mode 100644 index 0000000..a755cdb --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/.eslintrc.json @@ -0,0 +1,3 @@ +{ + "extends": ["standard"] +} diff --git a/vendor/github.com/pion/webrtc/v4/.gitignore b/vendor/github.com/pion/webrtc/v4/.gitignore new file mode 100644 index 0000000..2394557 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/.gitignore @@ -0,0 +1,28 @@ +# SPDX-FileCopyrightText: 2026 The Pion community +# SPDX-License-Identifier: MIT + +### JetBrains IDE ### +##################### +.idea/ + +### Emacs Temporary Files ### +############################# +*~ + +### Folders ### +############### +bin/ +vendor/ +node_modules/ + +### Files ### +############# +*.ivf +*.ogg +tags +cover.out +*.sw[poe] +*.wasm +examples/sfu-ws/cert.pem +examples/sfu-ws/key.pem +wasm_exec.js diff --git a/vendor/github.com/pion/webrtc/v4/.golangci.yml b/vendor/github.com/pion/webrtc/v4/.golangci.yml new file mode 100644 index 0000000..1fbb8db --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/.golangci.yml @@ -0,0 +1,148 @@ +# SPDX-FileCopyrightText: 2026 The Pion community +# SPDX-License-Identifier: MIT + +version: "2" +linters: + enable: + - asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers + - bidichk # Checks for dangerous unicode character sequences + - bodyclose # checks whether HTTP response body is closed successfully + - containedctx # containedctx is a linter that detects struct contained context.Context field + - contextcheck # check the function whether use a non-inherited context + - cyclop # checks function and package cyclomatic complexity + - decorder # check declaration order and count of types, constants, variables and functions + - dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f()) + - dupl # Tool for code clone detection + - durationcheck # check for two durations multiplied together + - err113 # Golang linter to check the errors handling expressions + - errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases + - errchkjson # Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occations, where the check for the returned error can be omitted. + - errname # Checks that sentinel errors are prefixed with the `Err` and error types are suffixed with the `Error`. + - errorlint # errorlint is a linter for that can be used to find code that will cause problems with the error wrapping scheme introduced in Go 1.13. + - exhaustive # check exhaustiveness of enum switch statements + - forbidigo # Forbids identifiers + - forcetypeassert # finds forced type assertions + - gochecknoglobals # Checks that no globals are present in Go code + - gocognit # Computes and checks the cognitive complexity of functions + - goconst # Finds repeated strings that could be replaced by a constant + - gocritic # The most opinionated Go source code linter + - gocyclo # Computes and checks the cyclomatic complexity of functions + - godot # Check if comments end in a period + - godox # Tool for detection of FIXME, TODO and other comment keywords + - goheader # Checks is file header matches to pattern + - gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod. + - goprintffuncname # Checks that printf-like functions are named with `f` at the end + - gosec # Inspects source code for security problems + - govet # Vet examines Go source code and reports suspicious constructs, such as Printf calls whose arguments do not align with the format string + - grouper # An analyzer to analyze expression groups. + - importas # Enforces consistent import aliases + - ineffassign # Detects when assignments to existing variables are not used + - lll # Reports long lines + - maintidx # maintidx measures the maintainability index of each function. + - makezero # Finds slice declarations with non-zero initial length + - misspell # Finds commonly misspelled English words in comments + - modernize # Replace and suggests simplifications to code + - nakedret # Finds naked returns in functions greater than a specified function length + - nestif # Reports deeply nested if statements + - nilerr # Finds the code that returns nil even if it checks that the error is not nil. + - nilnil # Checks that there is no simultaneous return of `nil` error and an invalid value. + - nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity + - noctx # noctx finds sending http request without context.Context + - predeclared # find code that shadows one of Go's predeclared identifiers + - revive # golint replacement, finds style mistakes + - staticcheck # Staticcheck is a go vet on steroids, applying a ton of static analysis checks + - tagliatelle # Checks the struct tags. + - thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers + - unconvert # Remove unnecessary type conversions + - unparam # Reports unused function parameters + - unused # Checks Go code for unused constants, variables, functions and types + - varnamelen # checks that the length of a variable's name matches its scope + - wastedassign # wastedassign finds wasted assignment statements + - whitespace # Tool for detection of leading and trailing whitespace + disable: + - depguard # Go linter that checks if package imports are in a list of acceptable packages + - funlen # Tool for detection of long functions + - gochecknoinits # Checks that no init functions are present in Go code + - gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations. + - interfacebloat # A linter that checks length of interface. + - ireturn # Accept Interfaces, Return Concrete Types + - mnd # An analyzer to detect magic numbers + - nolintlint # Reports ill-formed or insufficient nolint directives + - paralleltest # paralleltest detects missing usage of t.Parallel() method in your Go test + - prealloc # Finds slice declarations that could potentially be preallocated + - promlinter # Check Prometheus metrics naming via promlint + - rowserrcheck # checks whether Err of rows is checked successfully + - sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed. + - testpackage # linter that makes you use a separate _test package + - tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes + - wrapcheck # Checks that errors returned from external packages are wrapped + - wsl # Whitespace Linter - Forces you to use empty lines! + settings: + staticcheck: + checks: + - all + - -QF1008 # "could remove embedded field", to keep it explicit! + - -QF1003 # "could use tagged switch on enum", Cases conflicts with exhaustive! + exhaustive: + default-signifies-exhaustive: true + forbidigo: + forbid: + - pattern: ^fmt.Print(f|ln)?$ + - pattern: ^log.(Panic|Fatal|Print)(f|ln)?$ + - pattern: ^os.Exit$ + - pattern: ^panic$ + - pattern: ^print(ln)?$ + - pattern: ^testing.T.(Error|Errorf|Fatal|Fatalf|Fail|FailNow)$ + pkg: ^testing$ + msg: use testify/assert instead + analyze-types: true + gomodguard: + blocked: + modules: + - github.com/pkg/errors: + recommendations: + - errors + govet: + enable: + - shadow + revive: + rules: + # Prefer 'any' type alias over 'interface{}' for Go 1.18+ compatibility + - name: use-any + severity: warning + disabled: false + misspell: + locale: US + varnamelen: + max-distance: 12 + min-name-length: 2 + ignore-type-assert-ok: true + ignore-map-index-ok: true + ignore-chan-recv-ok: true + ignore-decls: + - i int + - n int + - w io.Writer + - r io.Reader + - b []byte + exclusions: + generated: lax + rules: + - linters: + - forbidigo + - gocognit + path: (examples|main\.go) + - linters: + - gocognit + path: _test\.go + - linters: + - forbidigo + path: cmd +formatters: + enable: + - gci # Gci control golang package import order and make it always deterministic. + - gofmt # Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification + - gofumpt # Gofumpt checks whether code was gofumpt-ed. + - goimports # Goimports does everything that gofmt does. Additionally it checks unused imports + exclusions: + generated: lax diff --git a/vendor/github.com/pion/webrtc/v4/.goreleaser.yml b/vendor/github.com/pion/webrtc/v4/.goreleaser.yml new file mode 100644 index 0000000..8577d86 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/.goreleaser.yml @@ -0,0 +1,5 @@ +# SPDX-FileCopyrightText: 2026 The Pion community +# SPDX-License-Identifier: MIT + +builds: +- skip: true diff --git a/vendor/github.com/pion/webrtc/v4/DESIGN.md b/vendor/github.com/pion/webrtc/v4/DESIGN.md new file mode 100644 index 0000000..45ca1ac --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/DESIGN.md @@ -0,0 +1,43 @@ +

+ Design +

+WebRTC is a powerful, but complicated technology you can build amazing things with, it comes with a steep learning curve though. +Using WebRTC in the browser is easy, but outside the browser is more of a challenge. There are multiple libraries, and they all have +varying levels of quality. Most are also difficult to build, and depend on libraries that aren't available in repos or portable. + +Pion WebRTC aims to solve all that! Built in native Go you should be able to send and receive media and text from anywhere with minimal headache. +These are the design principals that drive Pion WebRTC and hopefully convince you it is worth a try. + +### Portable +Pion WebRTC is written in Go and extremely portable. Anywhere Golang runs, Pion WebRTC should work as well! Instead of dealing with complicated +cross-compiling of multiple libraries, you now can run anywhere with one `go build` + +### Flexible +When possible we leave all decisions to the user. When choice is possible (like what logging library is used) we defer to the developer. + +### Simple API +If you know how to use WebRTC in your browser, you know how to use Pion WebRTC. +We try our best just to duplicate the Javascript API, so your code can look the same everywhere. + +If this is your first time using WebRTC, don't worry! We have multiple [examples](https://github.com/pion/webrtc/tree/master/examples) and [GoDoc](https://pkg.go.dev/github.com/pion/webrtc/v4) + +### Bring your own media +Pion WebRTC doesn't make any assumptions about where your audio, video or text come from. You can use FFmpeg, GStreamer, MLT or just serve a video file. +This library only serves to transport, not create media. + +### Safe +Golang provides a great foundation to build safe network services. +Especially when running a networked service that is highly concurrent bugs can be devastating. + +### Readable +If code comes from an RFC we try to make sure everything is commented with a link to the spec. +This makes learning and debugging easier, this WebRTC library was written to also serve as a guide for others. + +### Tested +Every commit is tested via travis-ci Go provides fantastic facilities for testing, and more will be added as time goes on. + +### Shared libraries +Every Pion project is built using shared libraries, allowing others to review and reuse our libraries. + +### Community +The most important part of Pion is the community. This projects only exist because of individual contributions. We aim to be radically open and do everything we can to support those that make Pion possible. diff --git a/vendor/github.com/pion/webrtc/v4/LICENSE b/vendor/github.com/pion/webrtc/v4/LICENSE new file mode 100644 index 0000000..d96df05 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/LICENSE @@ -0,0 +1,9 @@ +MIT License + +Copyright (c) 2026 The Pion community + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/vendor/github.com/pion/webrtc/v4/README.md b/vendor/github.com/pion/webrtc/v4/README.md new file mode 100644 index 0000000..97bbe70 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/README.md @@ -0,0 +1,139 @@ +

+ Pion WebRTC +
+ Pion WebRTC +
+

+

A pure Go implementation of the WebRTC API

+

+ Pion WebRTC + Sourcegraph Widget + join us on Discord Follow us on Bluesky Twitter Widget + +
+ GitHub Workflow Status + Go Reference + Coverage Status + Go Report Card + License: MIT +

+
+ +### New Release + +Pion WebRTC v4.0.0 has been released! See the [release notes](https://github.com/pion/webrtc/wiki/Release-WebRTC@v4.0.0) to learn about new features and breaking changes. + +If you aren't able to upgrade yet check the [tags](https://github.com/pion/webrtc/tags) for the latest `v3` release. + +We would love your feedback! Please create GitHub issues or Join the [Discord](https://discord.gg/PngbdqpFbt) to follow development and speak with the maintainers. + +----- + +### Usage +[Go Modules](https://blog.golang.org/using-go-modules) are mandatory for using Pion WebRTC. So make sure you set `export GO111MODULE=on`, and explicitly specify `/v4` (or an earlier version) when importing. + + +**[example applications](examples/README.md)** contains code samples of common things people build with Pion WebRTC. + +**[example-webrtc-applications](https://github.com/pion/example-webrtc-applications)** contains more full featured examples that use 3rd party libraries. + +**[awesome-pion](https://github.com/pion/awesome-pion)** contains projects that have used Pion, and serve as real world examples of usage. + +**[GoDoc](https://pkg.go.dev/github.com/pion/webrtc/v4)** is an auto generated API reference. All our Public APIs are commented. + +**[FAQ](https://github.com/pion/webrtc/wiki/FAQ)** has answers to common questions. If you have a question not covered please ask in [Discord](https://discord.gg/PngbdqpFbt) we are always looking to expand it. + +Now go build something awesome! Here are some **ideas** to get your creative juices flowing: +* Send a video file to multiple browser in real time for perfectly synchronized movie watching. +* Send a webcam on an embedded device to your browser with no additional server required! +* Securely send data between two servers, without using pub/sub. +* Record your webcam and do special effects server side. +* Build a conferencing application that processes audio/video and make decisions off of it. +* Remotely control a robots and stream its cameras in realtime. + +### Need Help? +Check out [WebRTC for the Curious](https://webrtcforthecurious.com). A book about WebRTC in depth, not just about the APIs. +Learn the full details of ICE, SCTP, DTLS, SRTP, and how they work together to make up the WebRTC stack. This is also a great +resource if you are trying to debug. Learn the tools of the trade and how to approach WebRTC issues. This book is vendor +agnostic and will not have any Pion specific information. + +Pion has an active community on [Discord](https://discord.gg/PngbdqpFbt). Please ask for help about anything, questions don't have to be Pion specific! +Come share your interesting project you are working on. We are here to support you. + +One of the maintainers of Pion [Sean-Der](https://github.com/sean-der) is available to help. Schedule at [siobud.com/meeting](https://siobud.com/meeting) +He is available to talk about Pion or general WebRTC questions, feel free to reach out about anything! + +### Features +#### PeerConnection API +* Go implementation of [webrtc-pc](https://w3c.github.io/webrtc-pc/) and [webrtc-stats](https://www.w3.org/TR/webrtc-stats/) +* DataChannels +* Send/Receive audio and video +* Renegotiation +* Plan-B and Unified Plan +* [SettingEngine](https://pkg.go.dev/github.com/pion/webrtc/v4#SettingEngine) for Pion specific extensions + + +#### Connectivity +* Full ICE Agent +* ICE Restart +* Trickle ICE +* STUN +* TURN (UDP, TCP, DTLS and TLS) +* mDNS candidates + +#### DataChannels +* Ordered/Unordered +* Lossy/Lossless + +#### Media +* API with direct RTP/RTCP access +* Opus, PCM, H264, VP8 and VP9 packetizer +* API also allows developer to pass their own packetizer +* IVF, Ogg, H264 and Matroska provided for easy sending and saving +* [getUserMedia](https://github.com/pion/mediadevices) implementation (Requires Cgo) +* Easy integration with x264, libvpx, GStreamer and ffmpeg. +* [Simulcast](https://github.com/pion/webrtc/tree/master/examples/simulcast) +* [SVC](https://github.com/pion/rtp/blob/master/codecs/vp9_packet.go#L138) +* [NACK](https://github.com/pion/interceptor/pull/4) +* [Sender/Receiver Reports](https://github.com/pion/interceptor/tree/master/pkg/report) +* [Transport Wide Congestion Control Feedback](https://github.com/pion/interceptor/tree/master/pkg/twcc) +* [Bandwidth Estimation](https://github.com/pion/webrtc/tree/master/examples/bandwidth-estimation-from-disk) + +#### Security +* TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 and TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA for DTLS v1.2 +* SRTP_AEAD_AES_256_GCM and SRTP_AES128_CM_HMAC_SHA1_80 for SRTP +* Hardware acceleration available for GCM suites + +#### Pure Go +* No Cgo usage +* Wide platform support + * Windows, macOS, Linux, FreeBSD + * iOS, Android + * [WASM](https://github.com/pion/webrtc/wiki/WebAssembly-Development-and-Testing) see [examples](examples/README.md#webassembly) + * 386, amd64, arm, mips, ppc64 +* Easy to build *Numbers generated on Intel(R) Core(TM) i5-2520M CPU @ 2.50GHz* + * **Time to build examples/play-from-disk** - 0.66s user 0.20s system 306% cpu 0.279 total + * **Time to run entire test suite** - 25.60s user 9.40s system 45% cpu 1:16.69 total +* Tools to measure performance [provided](https://github.com/pion/rtsp-bench) + +### Roadmap +The library is in active development, please refer to the [roadmap](https://github.com/pion/webrtc/issues/9) to track our major milestones. +We also maintain a list of [Big Ideas](https://github.com/pion/webrtc/wiki/Big-Ideas) these are things we want to build but don't have a clear plan or the resources yet. +If you are looking to get involved this is a great place to get started! We would also love to hear your ideas! Even if you can't implement it yourself, it could inspire others. + +### Sponsoring +Work on Pion's congestion control and bandwidth estimation was funded through the [User-Operated Internet](https://nlnet.nl/useroperated/) fund, a fund established by [NLnet](https://nlnet.nl/) made possible by financial support from the [PKT Community](https://pkt.cash/)/[The Network Steward](https://pkt.cash/network-steward) and stichting [Technology Commons Trust](https://technologycommons.org/). + +### Community +Pion has an active community on the [Discord](https://discord.gg/PngbdqpFbt). + +Follow the [Pion Bluesky](https://bsky.app/profile/pion.ly) or [Pion Twitter](https://twitter.com/_pion) for project updates and important WebRTC news. + +We are always looking to support **your projects**. Please reach out if you have something to build! +If you need commercial support or don't want to use public methods you can contact us at [team@pion.ly](mailto:team@pion.ly) + +### Contributing +Check out the [contributing wiki](https://github.com/pion/webrtc/wiki/Contributing) to join the group of amazing people making this project possible + +### License +MIT License - see [LICENSE](LICENSE) for full text diff --git a/vendor/github.com/pion/webrtc/v4/api.go b/vendor/github.com/pion/webrtc/v4/api.go new file mode 100644 index 0000000..1f7b4a5 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/api.go @@ -0,0 +1,96 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +//go:build !js + +package webrtc + +import ( + "github.com/pion/interceptor" + "github.com/pion/logging" +) + +// API allows configuration of a PeerConnection +// with APIs that are available in the standard. This +// lets you set custom behavior via the SettingEngine, configure +// codecs via the MediaEngine and define custom media behaviors via +// Interceptors. +type API struct { + settingEngine *SettingEngine + mediaEngine *MediaEngine + interceptorRegistry *interceptor.Registry + + interceptor interceptor.Interceptor // Generated per PeerConnection +} + +// NewAPI Creates a new API object for keeping semi-global settings to WebRTC objects +// +// It uses the default Codecs and Interceptors unless you customize them +// using WithMediaEngine and WithInterceptorRegistry respectively. +func NewAPI(options ...func(*API)) *API { + api := &API{ + interceptor: &interceptor.NoOp{}, + settingEngine: &SettingEngine{}, + } + + for _, o := range options { + o(api) + } + + if api.settingEngine.LoggerFactory == nil { + api.settingEngine.LoggerFactory = logging.NewDefaultLoggerFactory() + } + + logger := api.settingEngine.LoggerFactory.NewLogger("api") + + if api.mediaEngine == nil { + api.mediaEngine = &MediaEngine{} + err := api.mediaEngine.RegisterDefaultCodecs() + if err != nil { + logger.Errorf("Failed to register default codecs %s", err) + } + } + + if api.interceptorRegistry == nil { + api.interceptorRegistry = &interceptor.Registry{} + err := RegisterDefaultInterceptorsWithOptions(api.mediaEngine, api.interceptorRegistry, + WithInterceptorLoggerFactory(api.settingEngine.LoggerFactory)) + if err != nil { + logger.Errorf("Failed to register default interceptors %s", err) + } + } + + return api +} + +// WithMediaEngine allows providing a MediaEngine to the API. +// Settings can be changed after passing the engine to an API. +// When a PeerConnection is created the MediaEngine is copied +// and no more changes can be made. +func WithMediaEngine(m *MediaEngine) func(a *API) { + return func(a *API) { + a.mediaEngine = m + if a.mediaEngine == nil { + a.mediaEngine = &MediaEngine{} + } + } +} + +// WithSettingEngine allows providing a SettingEngine to the API. +// Settings should not be changed after passing the engine to an API. +func WithSettingEngine(s SettingEngine) func(a *API) { + return func(a *API) { + a.settingEngine = &s + } +} + +// WithInterceptorRegistry allows providing Interceptors to the API. +// Settings should not be changed after passing the registry to an API. +func WithInterceptorRegistry(ir *interceptor.Registry) func(a *API) { + return func(a *API) { + a.interceptorRegistry = ir + if a.interceptorRegistry == nil { + a.interceptorRegistry = &interceptor.Registry{} + } + } +} diff --git a/vendor/github.com/pion/webrtc/v4/api_js.go b/vendor/github.com/pion/webrtc/v4/api_js.go new file mode 100644 index 0000000..cca17b4 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/api_js.go @@ -0,0 +1,35 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +//go:build js && wasm +// +build js,wasm + +package webrtc + +// API bundles the global functions of the WebRTC and ORTC API. +type API struct { + settingEngine *SettingEngine +} + +// NewAPI Creates a new API object for keeping semi-global settings to WebRTC objects +func NewAPI(options ...func(*API)) *API { + a := &API{} + + for _, o := range options { + o(a) + } + + if a.settingEngine == nil { + a.settingEngine = &SettingEngine{} + } + + return a +} + +// WithSettingEngine allows providing a SettingEngine to the API. +// Settings should not be changed after passing the engine to an API. +func WithSettingEngine(s SettingEngine) func(a *API) { + return func(a *API) { + a.settingEngine = &s + } +} diff --git a/vendor/github.com/pion/webrtc/v4/bundlepolicy.go b/vendor/github.com/pion/webrtc/v4/bundlepolicy.go new file mode 100644 index 0000000..13411b0 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/bundlepolicy.go @@ -0,0 +1,85 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package webrtc + +import ( + "encoding/json" +) + +// BundlePolicy affects which media tracks are negotiated if the remote +// endpoint is not bundle-aware, and what ICE candidates are gathered. If the +// remote endpoint is bundle-aware, all media tracks and data channels are +// bundled onto the same transport. +type BundlePolicy int + +const ( + // BundlePolicyUnknown is the enum's zero-value. + BundlePolicyUnknown BundlePolicy = iota + + // BundlePolicyBalanced indicates to gather ICE candidates for each + // media type in use (audio, video, and data). If the remote endpoint is + // not bundle-aware, negotiate only one audio and video track on separate + // transports. + BundlePolicyBalanced + + // BundlePolicyMaxCompat indicates to gather ICE candidates for each + // track. If the remote endpoint is not bundle-aware, negotiate all media + // tracks on separate transports. + BundlePolicyMaxCompat + + // BundlePolicyMaxBundle indicates to gather ICE candidates for only + // one track. If the remote endpoint is not bundle-aware, negotiate only + // one media track. + BundlePolicyMaxBundle +) + +// This is done this way because of a linter. +const ( + bundlePolicyBalancedStr = "balanced" + bundlePolicyMaxCompatStr = "max-compat" + bundlePolicyMaxBundleStr = "max-bundle" +) + +func newBundlePolicy(raw string) BundlePolicy { + switch raw { + case bundlePolicyBalancedStr: + return BundlePolicyBalanced + case bundlePolicyMaxCompatStr: + return BundlePolicyMaxCompat + case bundlePolicyMaxBundleStr: + return BundlePolicyMaxBundle + default: + return BundlePolicyUnknown + } +} + +func (t BundlePolicy) String() string { + switch t { + case BundlePolicyBalanced: + return bundlePolicyBalancedStr + case BundlePolicyMaxCompat: + return bundlePolicyMaxCompatStr + case BundlePolicyMaxBundle: + return bundlePolicyMaxBundleStr + default: + return ErrUnknownType.Error() + } +} + +// UnmarshalJSON parses the JSON-encoded data and stores the result. +func (t *BundlePolicy) UnmarshalJSON(b []byte) error { + var val string + if err := json.Unmarshal(b, &val); err != nil { + return err + } + + *t = newBundlePolicy(val) + + return nil +} + +// MarshalJSON returns the JSON encoding. +func (t BundlePolicy) MarshalJSON() ([]byte, error) { + return json.Marshal(t.String()) +} diff --git a/vendor/github.com/pion/webrtc/v4/certificate.go b/vendor/github.com/pion/webrtc/v4/certificate.go new file mode 100644 index 0000000..9ad5e13 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/certificate.go @@ -0,0 +1,262 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package webrtc + +import ( + "crypto" + "crypto/ecdsa" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/base64" + "encoding/pem" + "fmt" + "math/big" + "strings" + "time" + + "github.com/pion/dtls/v3/pkg/crypto/fingerprint" + "github.com/pion/webrtc/v4/pkg/rtcerr" +) + +// Certificate represents a x509Cert used to authenticate WebRTC communications. +type Certificate struct { + privateKey crypto.PrivateKey + x509Cert *x509.Certificate + statsID string +} + +// NewCertificate generates a new x509 compliant Certificate to be used +// by DTLS for encrypting data sent over the wire. This method differs from +// GenerateCertificate by allowing to specify a template x509.Certificate to +// be used in order to define certificate parameters. +func NewCertificate(key crypto.PrivateKey, tpl x509.Certificate) (*Certificate, error) { + var err error + var certDER []byte + switch sk := key.(type) { + case *rsa.PrivateKey: + pk := sk.Public() + tpl.SignatureAlgorithm = x509.SHA256WithRSA + certDER, err = x509.CreateCertificate(rand.Reader, &tpl, &tpl, pk, sk) + if err != nil { + return nil, &rtcerr.UnknownError{Err: err} + } + case *ecdsa.PrivateKey: + pk := sk.Public() + tpl.SignatureAlgorithm = x509.ECDSAWithSHA256 + certDER, err = x509.CreateCertificate(rand.Reader, &tpl, &tpl, pk, sk) + if err != nil { + return nil, &rtcerr.UnknownError{Err: err} + } + default: + return nil, &rtcerr.NotSupportedError{Err: ErrPrivateKeyType} + } + + cert, err := x509.ParseCertificate(certDER) + if err != nil { + return nil, &rtcerr.UnknownError{Err: err} + } + + return &Certificate{ + privateKey: key, + x509Cert: cert, + statsID: fmt.Sprintf("certificate-%d", time.Now().UnixNano()), + }, nil +} + +// Equals determines if two certificates are identical by comparing both the +// secretKeys and x509Certificates. +func (c Certificate) Equals(cert Certificate) bool { + switch cSK := c.privateKey.(type) { + case *rsa.PrivateKey: + if oSK, ok := cert.privateKey.(*rsa.PrivateKey); ok { + if cSK.N.Cmp(oSK.N) != 0 { + return false + } + + return c.x509Cert.Equal(cert.x509Cert) + } + + return false + case *ecdsa.PrivateKey: + if oSK, ok := cert.privateKey.(*ecdsa.PrivateKey); ok { + if cSK.X.Cmp(oSK.X) != 0 || cSK.Y.Cmp(oSK.Y) != 0 { + return false + } + + return c.x509Cert.Equal(cert.x509Cert) + } + + return false + default: + return false + } +} + +// Expires returns the timestamp after which this certificate is no longer valid. +func (c Certificate) Expires() time.Time { + if c.x509Cert == nil { + return time.Time{} + } + + return c.x509Cert.NotAfter +} + +// GetFingerprints returns the list of certificate fingerprints, one of which +// is computed with the digest algorithm used in the certificate signature. +func (c Certificate) GetFingerprints() ([]DTLSFingerprint, error) { + fingerprintAlgorithms := []crypto.Hash{crypto.SHA256} + res := make([]DTLSFingerprint, len(fingerprintAlgorithms)) + + i := 0 + for _, algo := range fingerprintAlgorithms { + name, err := fingerprint.StringFromHash(algo) + if err != nil { + // nolint + return nil, fmt.Errorf("%w: %v", ErrFailedToGenerateCertificateFingerprint, err) + } + value, err := fingerprint.Fingerprint(c.x509Cert, algo) + if err != nil { + // nolint + return nil, fmt.Errorf("%w: %v", ErrFailedToGenerateCertificateFingerprint, err) + } + res[i] = DTLSFingerprint{ + Algorithm: name, + Value: value, + } + } + + return res[:i+1], nil +} + +// GenerateCertificate causes the creation of an X.509 certificate and +// corresponding private key. +func GenerateCertificate(secretKey crypto.PrivateKey) (*Certificate, error) { + // Max random value, a 130-bits integer, i.e 2^130 - 1 + maxBigInt := new(big.Int) + /* #nosec */ + maxBigInt.Exp(big.NewInt(2), big.NewInt(130), nil).Sub(maxBigInt, big.NewInt(1)) + /* #nosec */ + serialNumber, err := rand.Int(rand.Reader, maxBigInt) + if err != nil { + return nil, &rtcerr.UnknownError{Err: err} + } + + return NewCertificate(secretKey, x509.Certificate{ + Issuer: pkix.Name{CommonName: generatedCertificateOrigin}, + NotBefore: time.Now().AddDate(0, 0, -1), + NotAfter: time.Now().AddDate(0, 1, -1), + SerialNumber: serialNumber, + Version: 2, + Subject: pkix.Name{CommonName: generatedCertificateOrigin}, + }) +} + +// CertificateFromX509 creates a new WebRTC Certificate from a given PrivateKey and Certificate +// +// This can be used if you want to share a certificate across multiple PeerConnections. +func CertificateFromX509(privateKey crypto.PrivateKey, certificate *x509.Certificate) Certificate { + return Certificate{privateKey, certificate, fmt.Sprintf("certificate-%d", time.Now().UnixNano())} +} + +func (c Certificate) collectStats(report *statsReportCollector) error { + report.Collecting() + + fingerPrintAlgo, err := c.GetFingerprints() + if err != nil { + return err + } + + base64Certificate := base64.RawURLEncoding.EncodeToString(c.x509Cert.Raw) + + stats := CertificateStats{ + Timestamp: statsTimestampFrom(time.Now()), + Type: StatsTypeCertificate, + ID: c.statsID, + Fingerprint: fingerPrintAlgo[0].Value, + FingerprintAlgorithm: fingerPrintAlgo[0].Algorithm, + Base64Certificate: base64Certificate, + IssuerCertificateID: c.x509Cert.Issuer.String(), + } + + report.Collect(stats.ID, stats) + + return nil +} + +// CertificateFromPEM creates a fresh certificate based on a string containing +// pem blocks fort the private key and x509 certificate. +func CertificateFromPEM(pems string) (*Certificate, error) { //nolint: cyclop + var cert *x509.Certificate + var privateKey crypto.PrivateKey + + var block *pem.Block + more := []byte(pems) + for { + var err error + block, more = pem.Decode(more) + if block == nil { + break + } + + // decode & parse the certificate + switch block.Type { + case "CERTIFICATE": + if cert != nil { + return nil, errCertificatePEMMultipleCert + } + cert, err = x509.ParseCertificate(block.Bytes) + // If parsing failed using block.Bytes, then parse the bytes as base64 and try again + if err != nil { + var n int + certBytes := make([]byte, base64.StdEncoding.DecodedLen(len(block.Bytes))) + n, err = base64.StdEncoding.Decode(certBytes, block.Bytes) + if err == nil { + cert, err = x509.ParseCertificate(certBytes[:n]) + } + } + case "PRIVATE KEY": + if privateKey != nil { + return nil, errCertificatePEMMultiplePriv + } + privateKey, err = x509.ParsePKCS8PrivateKey(block.Bytes) + } + + // Report errors from parsing either the private key or the certificate + if err != nil { + return nil, fmt.Errorf("failed to decode %s: %w", block.Type, err) + } + } + + if cert == nil || privateKey == nil { + return nil, errCertificatePEMMissing + } + + ret := CertificateFromX509(privateKey, cert) + + return &ret, nil +} + +// PEM returns the certificate encoded as two pem block: once for the X509 +// certificate and the other for the private key. +func (c Certificate) PEM() (string, error) { + // First write the X509 certificate + var builder strings.Builder + err := pem.Encode(&builder, &pem.Block{Type: "CERTIFICATE", Bytes: c.x509Cert.Raw}) + if err != nil { + return "", fmt.Errorf("failed to pem encode the X certificate: %w", err) + } + // Next write the private key + privBytes, err := x509.MarshalPKCS8PrivateKey(c.privateKey) + if err != nil { + return "", fmt.Errorf("failed to marshal private key: %w", err) + } + err = pem.Encode(&builder, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}) + if err != nil { + return "", fmt.Errorf("failed to encode private key: %w", err) + } + + return builder.String(), nil +} diff --git a/vendor/github.com/pion/webrtc/v4/codecov.yml b/vendor/github.com/pion/webrtc/v4/codecov.yml new file mode 100644 index 0000000..b9639c2 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/codecov.yml @@ -0,0 +1,22 @@ +# +# DO NOT EDIT THIS FILE +# +# It is automatically copied from https://github.com/pion/.goassets repository. +# +# SPDX-FileCopyrightText: 2026 The Pion community +# SPDX-License-Identifier: MIT + +coverage: + status: + project: + default: + # Allow decreasing 2% of total coverage to avoid noise. + threshold: 2% + patch: + default: + target: 70% + only_pulls: true + +ignore: + - "examples/*" + - "examples/**/*" diff --git a/vendor/github.com/pion/webrtc/v4/configuration.go b/vendor/github.com/pion/webrtc/v4/configuration.go new file mode 100644 index 0000000..3254dd0 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/configuration.go @@ -0,0 +1,58 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +//go:build !js + +package webrtc + +// A Configuration defines how peer-to-peer communication via PeerConnection +// is established or re-established. +// Configurations may be set up once and reused across multiple connections. +// Configurations are treated as readonly. As long as they are unmodified, +// they are safe for concurrent use. +type Configuration struct { + // ICEServers defines a slice describing servers available to be used by + // ICE, such as STUN and TURN servers. + ICEServers []ICEServer `json:"iceServers,omitempty"` + + // ICETransportPolicy indicates which candidates the ICEAgent is allowed + // to use. + ICETransportPolicy ICETransportPolicy `json:"iceTransportPolicy,omitempty"` + + // BundlePolicy indicates which media-bundling policy to use when gathering + // ICE candidates. + BundlePolicy BundlePolicy `json:"bundlePolicy,omitempty"` + + // RTCPMuxPolicy indicates which rtcp-mux policy to use when gathering ICE + // candidates. + RTCPMuxPolicy RTCPMuxPolicy `json:"rtcpMuxPolicy,omitempty"` + + // PeerIdentity sets the target peer identity for the PeerConnection. + // The PeerConnection will not establish a connection to a remote peer + // unless it can be successfully authenticated with the provided name. + PeerIdentity string `json:"peerIdentity,omitempty"` + + // Certificates describes a set of certificates that the PeerConnection + // uses to authenticate. Valid values for this parameter are created + // through calls to the GenerateCertificate function. Although any given + // DTLS connection will use only one certificate, this attribute allows the + // caller to provide multiple certificates that support different + // algorithms. The final certificate will be selected based on the DTLS + // handshake, which establishes which certificates are allowed. The + // PeerConnection implementation selects which of the certificates is + // used for a given connection; how certificates are selected is outside + // the scope of this specification. If this value is absent, then a default + // set of certificates is generated for each PeerConnection instance. + Certificates []Certificate `json:"certificates,omitempty"` + + // ICECandidatePoolSize describes the size of the prefetched ICE pool. + ICECandidatePoolSize uint8 `json:"iceCandidatePoolSize,omitempty"` + + // SDPSemantics controls the type of SDP offers accepted by and + // SDP answers generated by the PeerConnection. + SDPSemantics SDPSemantics `json:"sdpSemantics,omitempty"` + + // AlwaysNegotiateDataChannels specifies whether the application prefers + // to always negotiate data channels in the initial SDP offer. + AlwaysNegotiateDataChannels bool `json:"alwaysNegotiateDataChannels,omitempty"` +} diff --git a/vendor/github.com/pion/webrtc/v4/configuration_common.go b/vendor/github.com/pion/webrtc/v4/configuration_common.go new file mode 100644 index 0000000..8c8181b --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/configuration_common.go @@ -0,0 +1,28 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package webrtc + +import "strings" + +// getICEServers side-steps the strict parsing mode of the ice package +// (as defined in https://tools.ietf.org/html/rfc7064) by copying and then +// stripping any erroneous queries from "stun(s):" URLs before parsing. +func (c Configuration) getICEServers() []ICEServer { + iceServers := append([]ICEServer{}, c.ICEServers...) + + for iceServersIndex := range iceServers { + iceServers[iceServersIndex].URLs = append([]string{}, iceServers[iceServersIndex].URLs...) + + for urlsIndex, rawURL := range iceServers[iceServersIndex].URLs { + if strings.HasPrefix(rawURL, "stun") { + // strip the query from "stun(s):" if present + parts := strings.Split(rawURL, "?") + rawURL = parts[0] + } + iceServers[iceServersIndex].URLs[urlsIndex] = rawURL + } + } + + return iceServers +} diff --git a/vendor/github.com/pion/webrtc/v4/configuration_js.go b/vendor/github.com/pion/webrtc/v4/configuration_js.go new file mode 100644 index 0000000..5a88f0d --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/configuration_js.go @@ -0,0 +1,45 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +//go:build js && wasm +// +build js,wasm + +package webrtc + +// Configuration defines a set of parameters to configure how the +// peer-to-peer communication via PeerConnection is established or +// re-established. +type Configuration struct { + // ICEServers defines a slice describing servers available to be used by + // ICE, such as STUN and TURN servers. + ICEServers []ICEServer + + // ICETransportPolicy indicates which candidates the ICEAgent is allowed + // to use. + ICETransportPolicy ICETransportPolicy + + // BundlePolicy indicates which media-bundling policy to use when gathering + // ICE candidates. + BundlePolicy BundlePolicy + + // RTCPMuxPolicy indicates which rtcp-mux policy to use when gathering ICE + // candidates. + RTCPMuxPolicy RTCPMuxPolicy + + // PeerIdentity sets the target peer identity for the PeerConnection. + // The PeerConnection will not establish a connection to a remote peer + // unless it can be successfully authenticated with the provided name. + PeerIdentity string + + // Certificates are not supported in the JavaScript/Wasm bindings. + // Certificates []Certificate + + // ICECandidatePoolSize describes the size of the prefetched ICE pool. + ICECandidatePoolSize uint8 + + // AlwaysNegotiateDataChannels specifies whether the application prefers + // to always negotiate data channels in the initial SDP offer. + AlwaysNegotiateDataChannels bool + + Certificates []Certificate `json:"certificates,omitempty"` +} diff --git a/vendor/github.com/pion/webrtc/v4/constants.go b/vendor/github.com/pion/webrtc/v4/constants.go new file mode 100644 index 0000000..caf24c1 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/constants.go @@ -0,0 +1,68 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package webrtc + +import ( + "math" + + "github.com/pion/dtls/v3" +) + +const ( + // default as the standard ethernet MTU + // can be overwritten with SettingEngine.SetReceiveMTU(). + receiveMTU = 1500 + + // simulcastProbeCount is the amount of RTP Packets + // that handleUndeclaredSSRC will read and try to dispatch from + // mid and rid values. + simulcastProbeCount = 10 + + // simulcastMaxProbeRoutines is how many active routines can be used to probe + // If the total amount of incoming SSRCes exceeds this new requests will be ignored. + simulcastMaxProbeRoutines = 25 + + // Default Max SCTP Message Size is the largest single DataChannel + // message we can send or accept. This default was chosen to match FireFox. + defaultMaxSCTPMessageSize = 1073741823 + + // If a DataChannel Max Message Size isn't declared by the Remote(max-message-size) + // this is the value we default to. This value was chosen because it was the behavior + // of Pion before max-message-size was implemented. + sctpMaxMessageSizeUnsetValue = math.MaxUint16 + + mediaSectionApplication = "application" + + sdpAttributeRid = "rid" + + sdpAttributeSimulcast = "simulcast" + + outboundMTU = 1200 + + rtpPayloadTypeBitmask = 0x7F + + incomingUnhandledRTPSsrc = "Incoming unhandled RTP ssrc(%d), OnTrack will not be fired. %v" + + useReadSimulcast = "Use ReadSimulcast(rid) instead of Read() when multiple tracks are present" + + generatedCertificateOrigin = "WebRTC" + + // AttributeRtxPayloadType is the interceptor attribute added when Read() + // returns an RTX packet containing the RTX stream payload type. + AttributeRtxPayloadType = "rtx_payload_type" + // AttributeRtxSsrc is the interceptor attribute added when Read() + // returns an RTX packet containing the RTX stream SSRC. + AttributeRtxSsrc = "rtx_ssrc" + // AttributeRtxSequenceNumber is the interceptor attribute added when + // Read() returns an RTX packet containing the RTX stream sequence number. + AttributeRtxSequenceNumber = "rtx_sequence_number" +) + +func defaultSrtpProtectionProfiles() []dtls.SRTPProtectionProfile { + return []dtls.SRTPProtectionProfile{ + dtls.SRTP_AEAD_AES_256_GCM, + dtls.SRTP_AEAD_AES_128_GCM, + dtls.SRTP_AES128_CM_HMAC_SHA1_80, + } +} diff --git a/vendor/github.com/pion/webrtc/v4/datachannel.go b/vendor/github.com/pion/webrtc/v4/datachannel.go new file mode 100644 index 0000000..c3c3ebd --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/datachannel.go @@ -0,0 +1,768 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +//go:build !js + +package webrtc + +import ( + "errors" + "fmt" + "io" + "sync" + "sync/atomic" + "time" + + "github.com/pion/datachannel" + "github.com/pion/logging" + "github.com/pion/webrtc/v4/pkg/rtcerr" +) + +var errSCTPNotEstablished = errors.New("SCTP not established") + +// DataChannel represents a WebRTC DataChannel +// The DataChannel interface represents a network channel +// which can be used for bidirectional peer-to-peer transfers of arbitrary data. +type DataChannel struct { + mu sync.RWMutex + + statsID string + label string + ordered bool + maxPacketLifeTime *uint16 + maxRetransmits *uint16 + protocol string + negotiated bool + id *uint16 + readyState atomic.Value // DataChannelState + bufferedAmountLowThreshold uint64 + detachCalled bool + readLoopActive chan struct{} + isGracefulClosed bool + + // The binaryType represents attribute MUST, on getting, return the value to + // which it was last set. On setting, if the new value is either the string + // "blob" or the string "arraybuffer", then set the IDL attribute to this + // new value. Otherwise, throw a SyntaxError. When an DataChannel object + // is created, the binaryType attribute MUST be initialized to the string + // "blob". This attribute controls how binary data is exposed to scripts. + // binaryType string + + onMessageHandler func(DataChannelMessage) + openHandlerOnce sync.Once + onOpenHandler func() + dialHandlerOnce sync.Once + onDialHandler func() + onCloseHandler func() + onBufferedAmountLow func() + onErrorHandler func(error) + + sctpTransport *SCTPTransport + dataChannel *datachannel.DataChannel + + // A reference to the associated api object used by this datachannel + api *API + log logging.LeveledLogger +} + +// NewDataChannel creates a new DataChannel. +// This constructor is part of the ORTC API. It is not +// meant to be used together with the basic WebRTC API. +func (api *API) NewDataChannel(transport *SCTPTransport, params *DataChannelParameters) (*DataChannel, error) { + d, err := api.newDataChannel(params, nil, api.settingEngine.LoggerFactory.NewLogger("ortc")) + if err != nil { + return nil, err + } + + err = d.open(transport) + if err != nil { + return nil, err + } + + return d, nil +} + +// newDataChannel is an internal constructor for the data channel used to +// create the DataChannel object before the networking is set up. +func (api *API) newDataChannel( + params *DataChannelParameters, + sctpTransport *SCTPTransport, + log logging.LeveledLogger, +) (*DataChannel, error) { + // https://w3c.github.io/webrtc-pc/#peer-to-peer-data-api (Step #5) + if len(params.Label) > 65535 { + return nil, &rtcerr.TypeError{Err: ErrStringSizeLimit} + } + + dataChannel := &DataChannel{ + sctpTransport: sctpTransport, + statsID: fmt.Sprintf("DataChannel-%d", time.Now().UnixNano()), + label: params.Label, + protocol: params.Protocol, + negotiated: params.Negotiated, + id: params.ID, + ordered: params.Ordered, + maxPacketLifeTime: params.MaxPacketLifeTime, + maxRetransmits: params.MaxRetransmits, + api: api, + log: log, + } + + dataChannel.setReadyState(DataChannelStateConnecting) + + return dataChannel, nil +} + +// open opens the datachannel over the sctp transport. +func (d *DataChannel) open(sctpTransport *SCTPTransport) error { //nolint:cyclop + association := sctpTransport.association() + if association == nil { + return errSCTPNotEstablished + } + + d.mu.Lock() + if d.sctpTransport != nil { // already open + d.mu.Unlock() + + return nil + } + d.sctpTransport = sctpTransport + var channelType datachannel.ChannelType + var reliabilityParameter uint32 + + switch { + case d.maxPacketLifeTime == nil && d.maxRetransmits == nil: + if d.ordered { + channelType = datachannel.ChannelTypeReliable + } else { + channelType = datachannel.ChannelTypeReliableUnordered + } + + case d.maxRetransmits != nil: + reliabilityParameter = uint32(*d.maxRetransmits) + if d.ordered { + channelType = datachannel.ChannelTypePartialReliableRexmit + } else { + channelType = datachannel.ChannelTypePartialReliableRexmitUnordered + } + default: + reliabilityParameter = uint32(*d.maxPacketLifeTime) + if d.ordered { + channelType = datachannel.ChannelTypePartialReliableTimed + } else { + channelType = datachannel.ChannelTypePartialReliableTimedUnordered + } + } + + cfg := &datachannel.Config{ + ChannelType: channelType, + Priority: datachannel.ChannelPriorityNormal, + ReliabilityParameter: reliabilityParameter, + Label: d.label, + Protocol: d.protocol, + Negotiated: d.negotiated, + LoggerFactory: d.api.settingEngine.LoggerFactory, + } + + if d.id == nil { + // avoid holding lock when generating ID, since id generation locks + d.mu.Unlock() + var dcID *uint16 + err := d.sctpTransport.generateAndSetDataChannelID(d.sctpTransport.dtlsTransport.role(), &dcID) + if err != nil { + return err + } + d.mu.Lock() + d.id = dcID + } + dc, err := datachannel.Dial(association, *d.id, cfg) + if err != nil { + d.mu.Unlock() + + return err + } + + // bufferedAmountLowThreshold and onBufferedAmountLow might be set earlier + dc.SetBufferedAmountLowThreshold(d.bufferedAmountLowThreshold) + dc.OnBufferedAmountLow(d.onBufferedAmountLow) + d.mu.Unlock() + + d.onDial() + d.handleOpen(dc, false, d.negotiated) + + return nil +} + +// Transport returns the SCTPTransport instance the DataChannel is sending over. +func (d *DataChannel) Transport() *SCTPTransport { + d.mu.RLock() + defer d.mu.RUnlock() + + return d.sctpTransport +} + +// After onOpen is complete check that the user called detach +// and provide an error message if the call was missed. +func (d *DataChannel) checkDetachAfterOpen() { + d.mu.RLock() + defer d.mu.RUnlock() + + if d.api.settingEngine.detach.DataChannels && !d.detachCalled { + d.log.Warn("webrtc.DetachDataChannels() enabled but didn't Detach, call Detach from OnOpen") + } +} + +// OnOpen sets an event handler which is invoked when +// the underlying data transport has been established (or re-established). +func (d *DataChannel) OnOpen(f func()) { + d.mu.Lock() + d.openHandlerOnce = sync.Once{} + d.onOpenHandler = f + d.mu.Unlock() + + if d.ReadyState() == DataChannelStateOpen { + // If the data channel is already open, call the handler immediately. + go d.openHandlerOnce.Do(func() { + f() + d.checkDetachAfterOpen() + }) + } +} + +func (d *DataChannel) onOpen() { + d.mu.RLock() + handler := d.onOpenHandler + if d.isGracefulClosed { + d.mu.RUnlock() + + return + } + d.mu.RUnlock() + + if handler != nil { + go d.openHandlerOnce.Do(func() { + handler() + d.checkDetachAfterOpen() + }) + } +} + +// OnDial sets an event handler which is invoked when the +// peer has been dialed, but before said peer has responded. +func (d *DataChannel) OnDial(f func()) { + d.mu.Lock() + d.dialHandlerOnce = sync.Once{} + d.onDialHandler = f + d.mu.Unlock() + + if d.ReadyState() == DataChannelStateOpen { + // If the data channel is already open, call the handler immediately. + go d.dialHandlerOnce.Do(f) + } +} + +func (d *DataChannel) onDial() { + d.mu.RLock() + handler := d.onDialHandler + if d.isGracefulClosed { + d.mu.RUnlock() + + return + } + d.mu.RUnlock() + + if handler != nil { + go d.dialHandlerOnce.Do(handler) + } +} + +// OnClose sets an event handler which is invoked when +// the underlying data transport has been closed. +// Note: Due to backwards compatibility, there is a chance that +// OnClose can be called, even if the GracefulClose is used. +// If this is the case for you, you can deregister OnClose +// prior to GracefulClose. +func (d *DataChannel) OnClose(f func()) { + d.mu.Lock() + defer d.mu.Unlock() + d.onCloseHandler = f +} + +func (d *DataChannel) onClose() { + d.mu.RLock() + handler := d.onCloseHandler + d.mu.RUnlock() + + if handler != nil { + go handler() + } +} + +// OnMessage sets an event handler which is invoked on a binary +// message arrival over the sctp transport from a remote peer. +// OnMessage can currently receive messages up to 16384 bytes +// in size. Check out the detach API if you want to use larger +// message sizes. Note that browser support for larger messages +// is also limited. +func (d *DataChannel) OnMessage(f func(msg DataChannelMessage)) { + d.mu.Lock() + defer d.mu.Unlock() + d.onMessageHandler = f +} + +func (d *DataChannel) onMessage(msg DataChannelMessage) { + d.mu.RLock() + handler := d.onMessageHandler + if d.isGracefulClosed { + d.mu.RUnlock() + + return + } + d.mu.RUnlock() + + if handler == nil { + return + } + handler(msg) +} + +func (d *DataChannel) handleOpen(dc *datachannel.DataChannel, isRemote, isAlreadyNegotiated bool) { + d.mu.Lock() + if d.isGracefulClosed { // The channel was closed during the connecting state + d.mu.Unlock() + if err := dc.Close(); err != nil { + d.log.Errorf("Failed to close DataChannel that was closed during connecting state %v", err.Error()) + } + d.onClose() + + return + } + d.dataChannel = dc + bufferedAmountLowThreshold := d.bufferedAmountLowThreshold + onBufferedAmountLow := d.onBufferedAmountLow + d.mu.Unlock() + d.setReadyState(DataChannelStateOpen) + + // Fire the OnOpen handler immediately not using pion/datachannel + // * detached datachannels have no read loop, the user needs to read and query themselves + // * remote datachannels should fire OnOpened. This isn't spec compliant, but we can't break behavior yet + // * already negotiated datachannels should fire OnOpened + if d.api.settingEngine.detach.DataChannels || isRemote || isAlreadyNegotiated { + // bufferedAmountLowThreshold and onBufferedAmountLow might be set earlier + d.dataChannel.SetBufferedAmountLowThreshold(bufferedAmountLowThreshold) + d.dataChannel.OnBufferedAmountLow(onBufferedAmountLow) + d.onOpen() + } else { + dc.OnOpen(func() { + d.onOpen() + }) + } + + d.mu.Lock() + defer d.mu.Unlock() + + if d.isGracefulClosed { + return + } + + if !d.api.settingEngine.detach.DataChannels { + d.readLoopActive = make(chan struct{}) + go d.readLoop() + } +} + +// OnError sets an event handler which is invoked when +// the underlying data transport cannot be read. +func (d *DataChannel) OnError(f func(err error)) { + d.mu.Lock() + defer d.mu.Unlock() + d.onErrorHandler = f +} + +func (d *DataChannel) onError(err error) { + d.mu.RLock() + handler := d.onErrorHandler + if d.isGracefulClosed { + d.mu.RUnlock() + + return + } + d.mu.RUnlock() + + if handler != nil { + go handler(err) + } +} + +func (d *DataChannel) readLoop() { + defer func() { + d.mu.Lock() + readLoopActive := d.readLoopActive + d.mu.Unlock() + defer close(readLoopActive) + }() + + buffer := make([]byte, sctpMaxMessageSizeUnsetValue) + for { + n, isString, err := d.dataChannel.ReadDataChannel(buffer) + if err != nil { + if errors.Is(err, io.ErrShortBuffer) { + if int64(n) < int64(d.api.settingEngine.getSCTPMaxMessageSize()) { + buffer = append(buffer, make([]byte, len(buffer))...) // nolint + + continue + } + + d.log.Errorf( + "Incoming DataChannel message larger then Max Message size %v", + d.api.settingEngine.getSCTPMaxMessageSize(), + ) + } + + d.setReadyState(DataChannelStateClosed) + if !errors.Is(err, io.EOF) { + d.onError(err) + } + d.onClose() + + return + } + + d.onMessage(DataChannelMessage{ + Data: append([]byte{}, buffer[:n]...), + IsString: isString, + }) + } +} + +// Send sends the binary message to the DataChannel peer. +func (d *DataChannel) Send(data []byte) error { + err := d.ensureOpen() + if err != nil { + return err + } + + _, err = d.dataChannel.WriteDataChannel(data, false) + + return err +} + +// SendText sends the text message to the DataChannel peer. +func (d *DataChannel) SendText(s string) error { + err := d.ensureOpen() + if err != nil { + return err + } + + _, err = d.dataChannel.WriteDataChannel([]byte(s), true) + + return err +} + +func (d *DataChannel) ensureOpen() error { + d.mu.RLock() + defer d.mu.RUnlock() + if d.ReadyState() != DataChannelStateOpen { + return io.ErrClosedPipe + } + + return nil +} + +// Detach allows you to detach the underlying datachannel. +// This provides an idiomatic API to work with +// (`io.ReadWriteCloser` with its `.Read()` and `.Write()` methods, +// as opposed to `.Send()` and `.OnMessage`), +// however it disables the OnMessage callback. +// Before calling Detach you have to enable this behavior by calling +// webrtc.DetachDataChannels(). Combining detached and normal data channels +// is not supported. +// Please refer to the data-channels-detach example and the +// pion/datachannel documentation for the correct way to handle the +// resulting DataChannel object. +func (d *DataChannel) Detach() (datachannel.ReadWriteCloser, error) { + return d.DetachWithDeadline() +} + +// DetachWithDeadline allows you to detach the underlying datachannel. +// It is the same as Detach but returns a ReadWriteCloserDeadliner. +func (d *DataChannel) DetachWithDeadline() (datachannel.ReadWriteCloserDeadliner, error) { + d.mu.Lock() + + if !d.api.settingEngine.detach.DataChannels { + d.mu.Unlock() + + return nil, errDetachNotEnabled + } + + if d.dataChannel == nil { + d.mu.Unlock() + + return nil, errDetachBeforeOpened + } + + d.detachCalled = true + + dataChannel := d.dataChannel + d.mu.Unlock() + + // Remove the reference from SCTPTransport so that the datachannel + // can be garbage collected on close + d.sctpTransport.lock.Lock() + n := len(d.sctpTransport.dataChannels) + j := 0 + for i := range n { + if d == d.sctpTransport.dataChannels[i] { + continue + } + d.sctpTransport.dataChannels[j] = d.sctpTransport.dataChannels[i] + j++ + } + for i := j; i < n; i++ { + d.sctpTransport.dataChannels[i] = nil + } + d.sctpTransport.dataChannels = d.sctpTransport.dataChannels[:j] + d.sctpTransport.lock.Unlock() + + return dataChannel, nil +} + +// Close Closes the DataChannel. It may be called regardless of whether +// the DataChannel object was created by this peer or the remote peer. +func (d *DataChannel) Close() error { + return d.close(false) +} + +// GracefulClose Closes the DataChannel. It may be called regardless of whether +// the DataChannel object was created by this peer or the remote peer. It also waits +// for any goroutines it started to complete. This is only safe to call outside of +// DataChannel callbacks or if in a callback, in its own goroutine. +func (d *DataChannel) GracefulClose() error { + return d.close(true) +} + +// Normally, close only stops writes from happening, so graceful=true +// will wait for reads to be finished based on underlying SCTP association +// closure or a SCTP reset stream from the other side. This is safe to call +// with graceful=true after tearing down a PeerConnection but not +// necessarily before. For example, if you used a vnet and dropped all packets +// right before closing the DataChannel, you'd need never see a reset stream. +func (d *DataChannel) close(shouldGracefullyClose bool) error { + d.mu.Lock() + d.isGracefulClosed = true + readLoopActive := d.readLoopActive + if shouldGracefullyClose && readLoopActive != nil { + defer func() { + <-readLoopActive + }() + } + haveSctpTransport := d.dataChannel != nil + d.mu.Unlock() + + if d.ReadyState() == DataChannelStateClosed { + return nil + } + + d.setReadyState(DataChannelStateClosing) + if !haveSctpTransport { + return nil + } + + return d.dataChannel.Close() +} + +// Label represents a label that can be used to distinguish this +// DataChannel object from other DataChannel objects. Scripts are +// allowed to create multiple DataChannel objects with the same label. +func (d *DataChannel) Label() string { + d.mu.RLock() + defer d.mu.RUnlock() + + return d.label +} + +// Ordered returns true if the DataChannel is ordered, and false if +// out-of-order delivery is allowed. +func (d *DataChannel) Ordered() bool { + d.mu.RLock() + defer d.mu.RUnlock() + + return d.ordered +} + +// MaxPacketLifeTime represents the length of the time window (msec) during +// which transmissions and retransmissions may occur in unreliable mode. +func (d *DataChannel) MaxPacketLifeTime() *uint16 { + d.mu.RLock() + defer d.mu.RUnlock() + + return d.maxPacketLifeTime +} + +// MaxRetransmits represents the maximum number of retransmissions that are +// attempted in unreliable mode. +func (d *DataChannel) MaxRetransmits() *uint16 { + d.mu.RLock() + defer d.mu.RUnlock() + + return d.maxRetransmits +} + +// Protocol represents the name of the sub-protocol used with this +// DataChannel. +func (d *DataChannel) Protocol() string { + d.mu.RLock() + defer d.mu.RUnlock() + + return d.protocol +} + +// Negotiated represents whether this DataChannel was negotiated by the +// application (true), or not (false). +func (d *DataChannel) Negotiated() bool { + d.mu.RLock() + defer d.mu.RUnlock() + + return d.negotiated +} + +// ID represents the ID for this DataChannel. The value is initially +// null, which is what will be returned if the ID was not provided at +// channel creation time, and the DTLS role of the SCTP transport has not +// yet been negotiated. Otherwise, it will return the ID that was either +// selected by the script or generated. After the ID is set to a non-null +// value, it will not change. +func (d *DataChannel) ID() *uint16 { + d.mu.RLock() + defer d.mu.RUnlock() + + return d.id +} + +// ReadyState represents the state of the DataChannel object. +func (d *DataChannel) ReadyState() DataChannelState { + if v, ok := d.readyState.Load().(DataChannelState); ok { + return v + } + + return DataChannelState(0) +} + +// BufferedAmount represents the number of bytes of application data +// (UTF-8 text and binary data) that have been queued using send(). Even +// though the data transmission can occur in parallel, the returned value +// MUST NOT be decreased before the current task yielded back to the event +// loop to prevent race conditions. The value does not include framing +// overhead incurred by the protocol, or buffering done by the operating +// system or network hardware. The value of BufferedAmount slot will only +// increase with each call to the send() method as long as the ReadyState is +// open; however, BufferedAmount does not reset to zero once the channel +// closes. +func (d *DataChannel) BufferedAmount() uint64 { + d.mu.RLock() + defer d.mu.RUnlock() + + if d.dataChannel == nil { + return 0 + } + + return d.dataChannel.BufferedAmount() +} + +// BufferedAmountLowThreshold represents the threshold at which the +// bufferedAmount is considered to be low. When the bufferedAmount decreases +// from above this threshold to equal or below it, the bufferedamountlow +// event fires. BufferedAmountLowThreshold is initially zero on each new +// DataChannel, but the application may change its value at any time. +// The threshold is set to 0 by default. +func (d *DataChannel) BufferedAmountLowThreshold() uint64 { + d.mu.RLock() + defer d.mu.RUnlock() + + if d.dataChannel == nil { + return d.bufferedAmountLowThreshold + } + + return d.dataChannel.BufferedAmountLowThreshold() +} + +// SetBufferedAmountLowThreshold is used to update the threshold. +// See BufferedAmountLowThreshold(). +func (d *DataChannel) SetBufferedAmountLowThreshold(th uint64) { + d.mu.Lock() + defer d.mu.Unlock() + + d.bufferedAmountLowThreshold = th + + if d.dataChannel != nil { + d.dataChannel.SetBufferedAmountLowThreshold(th) + } +} + +// OnBufferedAmountLow sets an event handler which is invoked when +// the number of bytes of outgoing data becomes lower than or equal to the +// BufferedAmountLowThreshold. +func (d *DataChannel) OnBufferedAmountLow(f func()) { + d.mu.Lock() + defer d.mu.Unlock() + + onBufferedAmountLow := d.makeBufferedAmountLowHandler(f) + d.onBufferedAmountLow = onBufferedAmountLow + + if d.dataChannel != nil { + d.dataChannel.OnBufferedAmountLow(onBufferedAmountLow) + } +} + +func (d *DataChannel) makeBufferedAmountLowHandler(f func()) func() { + return func() { + go func() { + if d.ReadyState() != DataChannelStateOpen { + return + } + + f() + }() + } +} + +func (d *DataChannel) getStatsID() string { + d.mu.Lock() + defer d.mu.Unlock() + + return d.statsID +} + +func (d *DataChannel) collectStats(collector *statsReportCollector) { + collector.Collecting() + + d.mu.Lock() + defer d.mu.Unlock() + + stats := DataChannelStats{ + Timestamp: statsTimestampNow(), + Type: StatsTypeDataChannel, + ID: d.statsID, + Label: d.label, + Protocol: d.protocol, + // TransportID string `json:"transportId"` + State: d.ReadyState(), + } + + if d.id != nil { + stats.DataChannelIdentifier = int32(*d.id) + } + + if d.dataChannel != nil { + stats.MessagesSent = d.dataChannel.MessagesSent() + stats.BytesSent = d.dataChannel.BytesSent() + stats.MessagesReceived = d.dataChannel.MessagesReceived() + stats.BytesReceived = d.dataChannel.BytesReceived() + } + + collector.Collect(stats.ID, stats) +} + +func (d *DataChannel) setReadyState(r DataChannelState) { + d.readyState.Store(r) +} diff --git a/vendor/github.com/pion/webrtc/v4/datachannel_js.go b/vendor/github.com/pion/webrtc/v4/datachannel_js.go new file mode 100644 index 0000000..2d5336b --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/datachannel_js.go @@ -0,0 +1,374 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +//go:build js && wasm +// +build js,wasm + +package webrtc + +import ( + "errors" + "fmt" + "syscall/js" + + "github.com/pion/datachannel" +) + +const dataChannelBufferSize = 16384 // Lowest common denominator among browsers + +// DataChannel represents a WebRTC DataChannel +// The DataChannel interface represents a network channel +// which can be used for bidirectional peer-to-peer transfers of arbitrary data +type DataChannel struct { + // Pointer to the underlying JavaScript RTCPeerConnection object. + underlying js.Value + + // Keep track of handlers/callbacks so we can call Release as required by the + // syscall/js API. Initially nil. + onOpenHandler *js.Func + onCloseHandler *js.Func + onClosingHandler *js.Func + onMessageHandler *js.Func + onBufferedAmountLow *js.Func + onErrorHandler *js.Func + + // A reference to the associated api object used by this datachannel + api *API +} + +// JSValue returns the underlying RTCDataChannel +func (d *DataChannel) JSValue() js.Value { + return d.underlying +} + +// OnOpen sets an event handler which is invoked when +// the underlying data transport has been established (or re-established). +func (d *DataChannel) OnOpen(f func()) { + if d.onOpenHandler != nil { + oldHandler := d.onOpenHandler + defer oldHandler.Release() + } + onOpenHandler := js.FuncOf(func(this js.Value, args []js.Value) any { + go f() + return js.Undefined() + }) + d.onOpenHandler = &onOpenHandler + d.underlying.Set("onopen", onOpenHandler) +} + +// OnClose sets an event handler which is invoked when +// the underlying data transport has been closed. +func (d *DataChannel) OnClose(f func()) { + if d.onCloseHandler != nil { + oldHandler := d.onCloseHandler + defer oldHandler.Release() + } + onCloseHandler := js.FuncOf(func(this js.Value, args []js.Value) any { + go f() + return js.Undefined() + }) + d.onCloseHandler = &onCloseHandler + d.underlying.Set("onclose", onCloseHandler) +} + +// FYI `OnClosing` is not implemented in the non-JS version of Pion. + +func (d *DataChannel) OnClosing(f func()) { + if d.onClosingHandler != nil { + oldHandler := d.onClosingHandler + defer oldHandler.Release() + } + onClosingHandler := js.FuncOf(func(this js.Value, args []js.Value) any { + go f() + return js.Undefined() + }) + d.onClosingHandler = &onClosingHandler + d.underlying.Set("onclosing", onClosingHandler) +} + +func (d *DataChannel) OnError(f func(err error)) { + if d.onErrorHandler != nil { + oldHandler := d.onErrorHandler + defer oldHandler.Release() + } + onErrorHandler := js.FuncOf(func(this js.Value, args []js.Value) any { + event := args[0] + errorObj := event.Get("error") + // FYI RTCError has some extra properties, e.g. `errorDetail`: + // https://developer.mozilla.org/en-US/docs/Web/API/RTCDataChannel/error_event + errorMessage := errorObj.Get("message").String() + go f(errors.New(errorMessage)) + return js.Undefined() + }) + d.onErrorHandler = &onErrorHandler + d.underlying.Set("onerror", onErrorHandler) +} + +// OnMessage sets an event handler which is invoked on a binary message arrival +// from a remote peer. Note that browsers may place limitations on message size. +func (d *DataChannel) OnMessage(f func(msg DataChannelMessage)) { + if d.onMessageHandler != nil { + oldHandler := d.onMessageHandler + defer oldHandler.Release() + } + onMessageHandler := js.FuncOf(func(this js.Value, args []js.Value) any { + // pion/webrtc/projects/15 + data := args[0].Get("data") + go func() { + // valueToDataChannelMessage may block when handling 'Blob' data + // so we need to call it from a new routine. See: + // https://pkg.go.dev/syscall/js#FuncOf + msg := valueToDataChannelMessage(data) + f(msg) + }() + return js.Undefined() + }) + d.onMessageHandler = &onMessageHandler + d.underlying.Set("onmessage", onMessageHandler) +} + +// Send sends the binary message to the DataChannel peer +func (d *DataChannel) Send(data []byte) (err error) { + defer func() { + if e := recover(); e != nil { + err = recoveryToError(e) + } + }() + array := js.Global().Get("Uint8Array").New(len(data)) + js.CopyBytesToJS(array, data) + d.underlying.Call("send", array) + return nil +} + +// SendText sends the text message to the DataChannel peer +func (d *DataChannel) SendText(s string) (err error) { + defer func() { + if e := recover(); e != nil { + err = recoveryToError(e) + } + }() + d.underlying.Call("send", s) + return nil +} + +// Detach allows you to detach the underlying datachannel. This provides +// an idiomatic API to work with, however it disables the OnMessage callback. +// Before calling Detach you have to enable this behavior by calling +// webrtc.DetachDataChannels(). Combining detached and normal data channels +// is not supported. +// Please refer to the data-channels-detach example and the +// pion/datachannel documentation for the correct way to handle the +// resulting DataChannel object. +func (d *DataChannel) Detach() (datachannel.ReadWriteCloser, error) { + if !d.api.settingEngine.detach.DataChannels { + return nil, fmt.Errorf("enable detaching by calling webrtc.DetachDataChannels()") + } + + detached := newDetachedDataChannel(d) + return detached, nil +} + +// Close Closes the DataChannel. It may be called regardless of whether +// the DataChannel object was created by this peer or the remote peer. +func (d *DataChannel) Close() (err error) { + defer func() { + if e := recover(); e != nil { + err = recoveryToError(e) + } + }() + + d.underlying.Call("close") + + // Release any handlers as required by the syscall/js API. + if d.onOpenHandler != nil { + d.onOpenHandler.Release() + } + if d.onCloseHandler != nil { + d.onCloseHandler.Release() + } + if d.onClosingHandler != nil { + d.onClosingHandler.Release() + } + if d.onMessageHandler != nil { + d.onMessageHandler.Release() + } + if d.onBufferedAmountLow != nil { + d.onBufferedAmountLow.Release() + } + if d.onErrorHandler != nil { + d.onErrorHandler.Release() + } + + return nil +} + +// Label represents a label that can be used to distinguish this +// DataChannel object from other DataChannel objects. Scripts are +// allowed to create multiple DataChannel objects with the same label. +func (d *DataChannel) Label() string { + return d.underlying.Get("label").String() +} + +// Ordered represents if the DataChannel is ordered, and false if +// out-of-order delivery is allowed. +func (d *DataChannel) Ordered() bool { + ordered := d.underlying.Get("ordered") + if ordered.IsUndefined() { + return true // default is true + } + return ordered.Bool() +} + +// MaxPacketLifeTime represents the length of the time window (msec) during +// which transmissions and retransmissions may occur in unreliable mode. +func (d *DataChannel) MaxPacketLifeTime() *uint16 { + if !d.underlying.Get("maxPacketLifeTime").IsUndefined() { + return valueToUint16Pointer(d.underlying.Get("maxPacketLifeTime")) + } + + // See https://bugs.chromium.org/p/chromium/issues/detail?id=696681 + // Chrome calls this "maxRetransmitTime" + return valueToUint16Pointer(d.underlying.Get("maxRetransmitTime")) +} + +// MaxRetransmits represents the maximum number of retransmissions that are +// attempted in unreliable mode. +func (d *DataChannel) MaxRetransmits() *uint16 { + return valueToUint16Pointer(d.underlying.Get("maxRetransmits")) +} + +// Protocol represents the name of the sub-protocol used with this +// DataChannel. +func (d *DataChannel) Protocol() string { + return d.underlying.Get("protocol").String() +} + +// Negotiated represents whether this DataChannel was negotiated by the +// application (true), or not (false). +func (d *DataChannel) Negotiated() bool { + return d.underlying.Get("negotiated").Bool() +} + +// ID represents the ID for this DataChannel. The value is initially +// null, which is what will be returned if the ID was not provided at +// channel creation time. Otherwise, it will return the ID that was either +// selected by the script or generated. After the ID is set to a non-null +// value, it will not change. +func (d *DataChannel) ID() *uint16 { + return valueToUint16Pointer(d.underlying.Get("id")) +} + +// ReadyState represents the state of the DataChannel object. +func (d *DataChannel) ReadyState() DataChannelState { + return newDataChannelState(d.underlying.Get("readyState").String()) +} + +// BufferedAmount represents the number of bytes of application data +// (UTF-8 text and binary data) that have been queued using send(). Even +// though the data transmission can occur in parallel, the returned value +// MUST NOT be decreased before the current task yielded back to the event +// loop to prevent race conditions. The value does not include framing +// overhead incurred by the protocol, or buffering done by the operating +// system or network hardware. The value of BufferedAmount slot will only +// increase with each call to the send() method as long as the ReadyState is +// open; however, BufferedAmount does not reset to zero once the channel +// closes. +func (d *DataChannel) BufferedAmount() uint64 { + return uint64(d.underlying.Get("bufferedAmount").Int()) +} + +// BufferedAmountLowThreshold represents the threshold at which the +// bufferedAmount is considered to be low. When the bufferedAmount decreases +// from above this threshold to equal or below it, the bufferedamountlow +// event fires. BufferedAmountLowThreshold is initially zero on each new +// DataChannel, but the application may change its value at any time. +func (d *DataChannel) BufferedAmountLowThreshold() uint64 { + return uint64(d.underlying.Get("bufferedAmountLowThreshold").Int()) +} + +// SetBufferedAmountLowThreshold is used to update the threshold. +// See BufferedAmountLowThreshold(). +func (d *DataChannel) SetBufferedAmountLowThreshold(th uint64) { + d.underlying.Set("bufferedAmountLowThreshold", th) +} + +// OnBufferedAmountLow sets an event handler which is invoked when +// the number of bytes of outgoing data becomes lower than or equal to the +// BufferedAmountLowThreshold. +func (d *DataChannel) OnBufferedAmountLow(f func()) { + if d.onBufferedAmountLow != nil { + oldHandler := d.onBufferedAmountLow + defer oldHandler.Release() + } + onBufferedAmountLow := js.FuncOf(func(this js.Value, args []js.Value) any { + if d.ReadyState() != DataChannelStateOpen { + return js.Undefined() + } + + go f() + return js.Undefined() + }) + d.onBufferedAmountLow = &onBufferedAmountLow + d.underlying.Set("onbufferedamountlow", onBufferedAmountLow) +} + +// valueToDataChannelMessage converts the given value to a DataChannelMessage. +// val should be obtained from MessageEvent.data where MessageEvent is received +// via the RTCDataChannel.onmessage callback. +func valueToDataChannelMessage(val js.Value) DataChannelMessage { + // If val is of type string, the conversion is straightforward. + if val.Type() == js.TypeString { + return DataChannelMessage{ + IsString: true, + Data: []byte(val.String()), + } + } + + // For other types, we need to first determine val.constructor.name. + constructorName := val.Get("constructor").Get("name").String() + var data []byte + switch constructorName { + case "Uint8Array": + // We can easily convert Uint8Array to []byte + data = uint8ArrayValueToBytes(val) + case "Blob": + // Convert the Blob to an ArrayBuffer and then convert the ArrayBuffer + // to a Uint8Array. + // See: https://developer.mozilla.org/en-US/docs/Web/API/Blob + + // The JavaScript API for reading from the Blob is asynchronous. We use a + // channel to signal when reading is done. + reader := js.Global().Get("FileReader").New() + doneChan := make(chan struct{}) + reader.Call("addEventListener", "loadend", js.FuncOf(func(this js.Value, args []js.Value) any { + go func() { + // Signal that the FileReader is done reading/loading by sending through + // the doneChan. + doneChan <- struct{}{} + }() + return js.Undefined() + })) + + reader.Call("readAsArrayBuffer", val) + + // Wait for the FileReader to finish reading/loading. + <-doneChan + + // At this point buffer.result is a typed array, which we know how to + // handle. + buffer := reader.Get("result") + uint8Array := js.Global().Get("Uint8Array").New(buffer) + data = uint8ArrayValueToBytes(uint8Array) + default: + // Assume we have an ArrayBufferView type which we can convert to a + // Uint8Array in JavaScript. + // See: https://developer.mozilla.org/en-US/docs/Web/API/ArrayBufferView + uint8Array := js.Global().Get("Uint8Array").New(val) + data = uint8ArrayValueToBytes(uint8Array) + } + + return DataChannelMessage{ + IsString: false, + Data: data, + } +} diff --git a/vendor/github.com/pion/webrtc/v4/datachannel_js_detach.go b/vendor/github.com/pion/webrtc/v4/datachannel_js_detach.go new file mode 100644 index 0000000..c45fb41 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/datachannel_js_detach.go @@ -0,0 +1,75 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +//go:build js && wasm +// +build js,wasm + +package webrtc + +import ( + "errors" +) + +type detachedDataChannel struct { + dc *DataChannel + + read chan DataChannelMessage + done chan struct{} +} + +func newDetachedDataChannel(dc *DataChannel) *detachedDataChannel { + read := make(chan DataChannelMessage) + done := make(chan struct{}) + + // Wire up callbacks + dc.OnMessage(func(msg DataChannelMessage) { + read <- msg // pion/webrtc/projects/15 + }) + + // pion/webrtc/projects/15 + + return &detachedDataChannel{ + dc: dc, + read: read, + done: done, + } +} + +func (c *detachedDataChannel) Read(p []byte) (int, error) { + n, _, err := c.ReadDataChannel(p) + return n, err +} + +func (c *detachedDataChannel) ReadDataChannel(p []byte) (int, bool, error) { + select { + case <-c.done: + return 0, false, errors.New("Reader closed") + case msg := <-c.read: + n := copy(p, msg.Data) + if n < len(msg.Data) { + return n, msg.IsString, errors.New("Read buffer to small") + } + return n, msg.IsString, nil + } +} + +func (c *detachedDataChannel) Write(p []byte) (n int, err error) { + return c.WriteDataChannel(p, false) +} + +func (c *detachedDataChannel) WriteDataChannel(p []byte, isString bool) (n int, err error) { + if isString { + err = c.dc.SendText(string(p)) + return len(p), err + } + + err = c.dc.Send(p) + + return len(p), err +} + +func (c *detachedDataChannel) Close() error { + close(c.done) + + return c.dc.Close() +} diff --git a/vendor/github.com/pion/webrtc/v4/datachannelinit.go b/vendor/github.com/pion/webrtc/v4/datachannelinit.go new file mode 100644 index 0000000..b8726ff --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/datachannelinit.go @@ -0,0 +1,36 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package webrtc + +// DataChannelInit can be used to configure properties of the underlying +// channel such as data reliability. +type DataChannelInit struct { + // Ordered indicates if data is allowed to be delivered out of order. The + // default value of true, guarantees that data will be delivered in order. + Ordered *bool + + // MaxPacketLifeTime limits the time (in milliseconds) during which the + // channel will transmit or retransmit data if not acknowledged. This value + // may be clamped if it exceeds the maximum value supported. + MaxPacketLifeTime *uint16 + + // MaxRetransmits limits the number of times a channel will retransmit data + // if not successfully delivered. This value may be clamped if it exceeds + // the maximum value supported. + MaxRetransmits *uint16 + + // Protocol describes the subprotocol name used for this channel. + Protocol *string + + // Negotiated describes if the data channel is created by the local peer or + // the remote peer. The default value of false tells the user agent to + // announce the channel in-band and instruct the other peer to dispatch a + // corresponding DataChannel. If set to true, it is up to the application + // to negotiate the channel and create an DataChannel with the same id + // at the other peer. + Negotiated *bool + + // ID overrides the default selection of ID for this channel. + ID *uint16 +} diff --git a/vendor/github.com/pion/webrtc/v4/datachannelmessage.go b/vendor/github.com/pion/webrtc/v4/datachannelmessage.go new file mode 100644 index 0000000..66cf536 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/datachannelmessage.go @@ -0,0 +1,13 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package webrtc + +// DataChannelMessage represents a message received from the +// data channel. IsString will be set to true if the incoming +// message is of the string type. Otherwise the message is of +// a binary type. +type DataChannelMessage struct { + IsString bool + Data []byte +} diff --git a/vendor/github.com/pion/webrtc/v4/datachannelparameters.go b/vendor/github.com/pion/webrtc/v4/datachannelparameters.go new file mode 100644 index 0000000..89d4eff --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/datachannelparameters.go @@ -0,0 +1,15 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package webrtc + +// DataChannelParameters describes the configuration of the DataChannel. +type DataChannelParameters struct { + Label string `json:"label"` + Protocol string `json:"protocol"` + ID *uint16 `json:"id"` + Ordered bool `json:"ordered"` + MaxPacketLifeTime *uint16 `json:"maxPacketLifeTime"` + MaxRetransmits *uint16 `json:"maxRetransmits"` + Negotiated bool `json:"negotiated"` +} diff --git a/vendor/github.com/pion/webrtc/v4/datachannelstate.go b/vendor/github.com/pion/webrtc/v4/datachannelstate.go new file mode 100644 index 0000000..f19db27 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/datachannelstate.go @@ -0,0 +1,79 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package webrtc + +// DataChannelState indicates the state of a data channel. +type DataChannelState int + +const ( + // DataChannelStateUnknown is the enum's zero-value. + DataChannelStateUnknown DataChannelState = iota + + // DataChannelStateConnecting indicates that the data channel is being + // established. This is the initial state of DataChannel, whether created + // with CreateDataChannel, or dispatched as a part of an DataChannelEvent. + DataChannelStateConnecting + + // DataChannelStateOpen indicates that the underlying data transport is + // established and communication is possible. + DataChannelStateOpen + + // DataChannelStateClosing indicates that the procedure to close down the + // underlying data transport has started. + DataChannelStateClosing + + // DataChannelStateClosed indicates that the underlying data transport + // has been closed or could not be established. + DataChannelStateClosed +) + +// This is done this way because of a linter. +const ( + dataChannelStateConnectingStr = "connecting" + dataChannelStateOpenStr = "open" + dataChannelStateClosingStr = "closing" + dataChannelStateClosedStr = "closed" +) + +func newDataChannelState(raw string) DataChannelState { + switch raw { + case dataChannelStateConnectingStr: + return DataChannelStateConnecting + case dataChannelStateOpenStr: + return DataChannelStateOpen + case dataChannelStateClosingStr: + return DataChannelStateClosing + case dataChannelStateClosedStr: + return DataChannelStateClosed + default: + return DataChannelStateUnknown + } +} + +func (t DataChannelState) String() string { + switch t { + case DataChannelStateConnecting: + return dataChannelStateConnectingStr + case DataChannelStateOpen: + return dataChannelStateOpenStr + case DataChannelStateClosing: + return dataChannelStateClosingStr + case DataChannelStateClosed: + return dataChannelStateClosedStr + default: + return ErrUnknownType.Error() + } +} + +// MarshalText implements encoding.TextMarshaler. +func (t DataChannelState) MarshalText() ([]byte, error) { + return []byte(t.String()), nil +} + +// UnmarshalText implements encoding.TextUnmarshaler. +func (t *DataChannelState) UnmarshalText(b []byte) error { + *t = newDataChannelState(string(b)) + + return nil +} diff --git a/vendor/github.com/pion/webrtc/v4/dtlsfingerprint.go b/vendor/github.com/pion/webrtc/v4/dtlsfingerprint.go new file mode 100644 index 0000000..749c52a --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/dtlsfingerprint.go @@ -0,0 +1,17 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package webrtc + +// DTLSFingerprint specifies the hash function algorithm and certificate +// fingerprint as described in https://tools.ietf.org/html/rfc4572. +type DTLSFingerprint struct { + // Algorithm specifies one of the hash function algorithms defined in + // the 'Hash function Textual Names' registry. + Algorithm string `json:"algorithm"` + + // Value specifies the value of the certificate fingerprint in lowercase + // hex string as expressed utilizing the syntax of 'fingerprint' in + // https://tools.ietf.org/html/rfc4572#section-5. + Value string `json:"value"` +} diff --git a/vendor/github.com/pion/webrtc/v4/dtlsparameters.go b/vendor/github.com/pion/webrtc/v4/dtlsparameters.go new file mode 100644 index 0000000..ea7f269 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/dtlsparameters.go @@ -0,0 +1,10 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package webrtc + +// DTLSParameters holds information relating to DTLS configuration. +type DTLSParameters struct { + Role DTLSRole `json:"role"` + Fingerprints []DTLSFingerprint `json:"fingerprints"` +} diff --git a/vendor/github.com/pion/webrtc/v4/dtlsrole.go b/vendor/github.com/pion/webrtc/v4/dtlsrole.go new file mode 100644 index 0000000..24e0b1b --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/dtlsrole.go @@ -0,0 +1,97 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package webrtc + +import ( + "github.com/pion/sdp/v3" +) + +// DTLSRole indicates the role of the DTLS transport. +type DTLSRole byte + +const ( + // DTLSRoleUnknown is the enum's zero-value. + DTLSRoleUnknown DTLSRole = iota + + // DTLSRoleAuto defines the DTLS role is determined based on + // the resolved ICE role: the ICE controlled role acts as the DTLS + // client and the ICE controlling role acts as the DTLS server. + DTLSRoleAuto + + // DTLSRoleClient defines the DTLS client role. + DTLSRoleClient + + // DTLSRoleServer defines the DTLS server role. + DTLSRoleServer +) + +const ( + // https://tools.ietf.org/html/rfc5763 + /* + The answerer MUST use either a + setup attribute value of setup:active or setup:passive. Note that + if the answerer uses setup:passive, then the DTLS handshake will + not begin until the answerer is received, which adds additional + latency. setup:active allows the answer and the DTLS handshake to + occur in parallel. Thus, setup:active is RECOMMENDED. + */ + defaultDtlsRoleAnswer = DTLSRoleClient + /* + The endpoint that is the offerer MUST use the setup attribute + value of setup:actpass and be prepared to receive a client_hello + before it receives the answer. + */ + defaultDtlsRoleOffer = DTLSRoleAuto +) + +func (r DTLSRole) String() string { + switch r { + case DTLSRoleAuto: + return "auto" + case DTLSRoleClient: + return "client" + case DTLSRoleServer: + return "server" + default: + return ErrUnknownType.Error() + } +} + +// Extract the dtls role from a session description. The decision is made from +// the first role we we parse. If no role can be found we return DTLSRoleAuto. +func dtlsRoleFromSDP(sessionDescription *sdp.SessionDescription) DTLSRole { + if sessionDescription == nil { + return DTLSRoleAuto + } + + for _, mediaSection := range sessionDescription.MediaDescriptions { + for _, attribute := range mediaSection.Attributes { + if attribute.Key == "setup" { + switch attribute.Value { + case sdp.ConnectionRoleActive.String(): + return DTLSRoleClient + case sdp.ConnectionRolePassive.String(): + return DTLSRoleServer + default: + return DTLSRoleAuto + } + } + } + } + + return DTLSRoleAuto +} + +func connectionRoleFromDtlsRole(d DTLSRole) sdp.ConnectionRole { + switch d { + case DTLSRoleClient: + return sdp.ConnectionRoleActive + case DTLSRoleServer: + return sdp.ConnectionRolePassive + case DTLSRoleAuto: + return sdp.ConnectionRoleActpass + default: + return sdp.ConnectionRole(0) + } +} diff --git a/vendor/github.com/pion/webrtc/v4/dtlstransport.go b/vendor/github.com/pion/webrtc/v4/dtlstransport.go new file mode 100644 index 0000000..c9276a1 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/dtlstransport.go @@ -0,0 +1,717 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +//go:build !js + +package webrtc + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "errors" + "fmt" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/pion/dtls/v3" + "github.com/pion/dtls/v3/pkg/crypto/fingerprint" + "github.com/pion/interceptor" + "github.com/pion/logging" + "github.com/pion/rtcp" + "github.com/pion/srtp/v3" + "github.com/pion/webrtc/v4/internal/mux" + "github.com/pion/webrtc/v4/internal/util" + "github.com/pion/webrtc/v4/pkg/rtcerr" +) + +// DTLSTransport allows an application access to information about the DTLS +// transport over which RTP and RTCP packets are sent and received by +// RTPSender and RTPReceiver, as well other data such as SCTP packets sent +// and received by data channels. +type DTLSTransport struct { + lock sync.RWMutex + + iceTransport *ICETransport + certificates []Certificate + remoteParameters DTLSParameters + remoteCertificate []byte + state DTLSTransportState + srtpProtectionProfile srtp.ProtectionProfile + + onStateChangeHandler func(DTLSTransportState) + internalOnCloseHandler func() + + conn *dtls.Conn + + srtpSession, srtcpSession atomic.Value + srtpEndpoint, srtcpEndpoint *mux.Endpoint + simulcastStreams []simulcastStreamPair + srtpReady chan struct{} + + dtlsMatcher mux.MatchFunc + + api *API + log logging.LeveledLogger +} + +type simulcastStreamPair struct { + srtp *srtp.ReadStreamSRTP + srtcp *srtp.ReadStreamSRTCP +} + +type streamsForSSRCResult struct { + rtpReadStream *srtp.ReadStreamSRTP + rtpInterceptor interceptor.RTPReader + rtcpReadStream *srtp.ReadStreamSRTCP + rtcpInterceptor interceptor.RTCPReader +} + +// NewDTLSTransport creates a new DTLSTransport. +// This constructor is part of the ORTC API. It is not +// meant to be used together with the basic WebRTC API. +func (api *API) NewDTLSTransport(transport *ICETransport, certificates []Certificate) (*DTLSTransport, error) { + trans := &DTLSTransport{ + iceTransport: transport, + api: api, + state: DTLSTransportStateNew, + dtlsMatcher: mux.MatchDTLS, + srtpReady: make(chan struct{}), + log: api.settingEngine.LoggerFactory.NewLogger("DTLSTransport"), + } + + if len(certificates) > 0 { + now := time.Now() + for _, x509Cert := range certificates { + if !x509Cert.Expires().IsZero() && now.After(x509Cert.Expires()) { + return nil, &rtcerr.InvalidAccessError{Err: ErrCertificateExpired} + } + trans.certificates = append(trans.certificates, x509Cert) + } + } else { + sk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return nil, &rtcerr.UnknownError{Err: err} + } + certificate, err := GenerateCertificate(sk) + if err != nil { + return nil, err + } + trans.certificates = []Certificate{*certificate} + } + + return trans, nil +} + +// ICETransport returns the currently-configured *ICETransport or nil +// if one has not been configured. +func (t *DTLSTransport) ICETransport() *ICETransport { + t.lock.RLock() + defer t.lock.RUnlock() + + return t.iceTransport +} + +// onStateChange requires the caller holds the lock. +func (t *DTLSTransport) onStateChange(state DTLSTransportState) { + t.state = state + handler := t.onStateChangeHandler + if handler != nil { + handler(state) + } +} + +// OnStateChange sets a handler that is fired when the DTLS +// connection state changes. +func (t *DTLSTransport) OnStateChange(f func(DTLSTransportState)) { + t.lock.Lock() + defer t.lock.Unlock() + t.onStateChangeHandler = f +} + +// State returns the current dtls transport state. +func (t *DTLSTransport) State() DTLSTransportState { + t.lock.RLock() + defer t.lock.RUnlock() + + return t.state +} + +// WriteRTCP sends a user provided RTCP packet to the connected peer. If no peer is connected the +// packet is discarded. +func (t *DTLSTransport) WriteRTCP(pkts []rtcp.Packet) (int, error) { + raw, err := rtcp.Marshal(pkts) + if err != nil { + return 0, err + } + + srtcpSession, err := t.getSRTCPSession() + if err != nil { + return 0, err + } + + writeStream, err := srtcpSession.OpenWriteStream() + if err != nil { + // nolint + return 0, fmt.Errorf("%w: %v", errPeerConnWriteRTCPOpenWriteStream, err) + } + + return writeStream.Write(raw) +} + +// GetLocalParameters returns the DTLS parameters of the local DTLSTransport upon construction. +func (t *DTLSTransport) GetLocalParameters() (DTLSParameters, error) { + fingerprints := []DTLSFingerprint{} + + for _, c := range t.certificates { + prints, err := c.GetFingerprints() + if err != nil { + return DTLSParameters{}, err + } + + fingerprints = append(fingerprints, prints...) + } + + return DTLSParameters{ + Role: DTLSRoleAuto, // always returns the default role + Fingerprints: fingerprints, + }, nil +} + +// GetRemoteCertificate returns the certificate chain in use by the remote side +// returns an empty list prior to selection of the remote certificate. +func (t *DTLSTransport) GetRemoteCertificate() []byte { + t.lock.RLock() + defer t.lock.RUnlock() + + return t.remoteCertificate +} + +func (t *DTLSTransport) startSRTP() error { + srtpConfig := &srtp.Config{ + Profile: t.srtpProtectionProfile, + BufferFactory: t.api.settingEngine.BufferFactory, + LoggerFactory: t.api.settingEngine.LoggerFactory, + } + if t.api.settingEngine.replayProtection.SRTP != nil { + srtpConfig.RemoteOptions = append( + srtpConfig.RemoteOptions, + srtp.SRTPReplayProtection(*t.api.settingEngine.replayProtection.SRTP), + ) + } + + if t.api.settingEngine.disableSRTPReplayProtection { + srtpConfig.RemoteOptions = append( + srtpConfig.RemoteOptions, + srtp.SRTPNoReplayProtection(), + ) + } + + if t.api.settingEngine.replayProtection.SRTCP != nil { + srtpConfig.RemoteOptions = append( + srtpConfig.RemoteOptions, + srtp.SRTCPReplayProtection(*t.api.settingEngine.replayProtection.SRTCP), + ) + } + + if t.api.settingEngine.disableSRTCPReplayProtection { + srtpConfig.RemoteOptions = append( + srtpConfig.RemoteOptions, + srtp.SRTCPNoReplayProtection(), + ) + } + + connState, ok := t.conn.ConnectionState() + if !ok { + // nolint + return fmt.Errorf("%w: Failed to get DTLS ConnectionState", errDtlsKeyExtractionFailed) + } + + err := srtpConfig.ExtractSessionKeysFromDTLS(&connState, t.role() == DTLSRoleClient) + if err != nil { + // nolint + return fmt.Errorf("%w: %v", errDtlsKeyExtractionFailed, err) + } + + srtpSession, err := srtp.NewSessionSRTP(t.srtpEndpoint, srtpConfig) + if err != nil { + // nolint + return fmt.Errorf("%w: %v", errFailedToStartSRTP, err) + } + + srtcpSession, err := srtp.NewSessionSRTCP(t.srtcpEndpoint, srtpConfig) + if err != nil { + // nolint + return fmt.Errorf("%w: %v", errFailedToStartSRTCP, err) + } + + t.srtpSession.Store(srtpSession) + t.srtcpSession.Store(srtcpSession) + close(t.srtpReady) + + return nil +} + +func (t *DTLSTransport) getSRTPSession() (*srtp.SessionSRTP, error) { + if value, ok := t.srtpSession.Load().(*srtp.SessionSRTP); ok { + return value, nil + } + + return nil, errDtlsTransportNotStarted +} + +func (t *DTLSTransport) getSRTCPSession() (*srtp.SessionSRTCP, error) { + if value, ok := t.srtcpSession.Load().(*srtp.SessionSRTCP); ok { + return value, nil + } + + return nil, errDtlsTransportNotStarted +} + +func (t *DTLSTransport) role() DTLSRole { + // If remote has an explicit role use the inverse + switch t.remoteParameters.Role { + case DTLSRoleClient: + return DTLSRoleServer + case DTLSRoleServer: + return DTLSRoleClient + default: + } + + // If SettingEngine has an explicit role + switch t.api.settingEngine.answeringDTLSRole { + case DTLSRoleServer: + return DTLSRoleServer + case DTLSRoleClient: + return DTLSRoleClient + default: + } + + // Remote was auto and no explicit role was configured via SettingEngine + if t.iceTransport.Role() == ICERoleControlling { + return DTLSRoleServer + } + + return defaultDtlsRoleAnswer +} + +// Start DTLS transport negotiation with the parameters of the remote DTLS transport. +func (t *DTLSTransport) Start(remoteParameters DTLSParameters) error { + role, certificate, err := t.prepareStart(remoteParameters) + if err != nil { + return err + } + + dtlsEndpoint := t.iceTransport.newEndpoint(mux.MatchDTLS) + dtlsEndpoint.SetOnClose(t.internalOnCloseHandler) + + sharedOpts := t.dtlsSharedOptions(certificate) + + dtlsConn, err := t.connectDTLS(dtlsEndpoint, role, sharedOpts) + if err != nil { + dtlsEndpoint.SetOnClose(nil) + _ = dtlsEndpoint.Close() + + return t.failStart(err) + } + + if err = t.handshakeDTLS(dtlsConn); err != nil { + dtlsEndpoint.SetOnClose(nil) + _ = dtlsConn.Close() + + return t.failStart(err) + } + + if err = t.completeStart(dtlsConn); err != nil { + dtlsEndpoint.SetOnClose(nil) + _ = dtlsConn.Close() + + return err + } + + return nil +} + +func (t *DTLSTransport) prepareStart(remoteParameters DTLSParameters) (DTLSRole, tls.Certificate, error) { + t.lock.Lock() + defer t.lock.Unlock() + + if err := t.ensureICEConn(); err != nil { + return DTLSRole(0), tls.Certificate{}, err + } + + if t.state != DTLSTransportStateNew { + return DTLSRole(0), tls.Certificate{}, &rtcerr.InvalidStateError{ + Err: fmt.Errorf("%w: %s", errInvalidDTLSStart, t.state), + } + } + + t.srtpEndpoint = t.iceTransport.newEndpoint(mux.MatchSRTP) + t.srtcpEndpoint = t.iceTransport.newEndpoint(mux.MatchSRTCP) + t.remoteParameters = remoteParameters + + cert := t.certificates[0] + t.onStateChange(DTLSTransportStateConnecting) + + return t.role(), tls.Certificate{ + Certificate: [][]byte{cert.x509Cert.Raw}, + PrivateKey: cert.privateKey, + }, nil +} + +func (t *DTLSTransport) dtlsSharedOptions(certificate tls.Certificate) []dtls.Option { + sharedOpts := []dtls.Option{ + dtls.WithCertificates(certificate), + dtls.WithSRTPProtectionProfiles(t.srtpProtectionProfiles()...), + dtls.WithExtendedMasterSecret(t.api.settingEngine.dtls.extendedMasterSecret), + dtls.WithInsecureSkipVerify(!t.api.settingEngine.dtls.disableInsecureSkipVerify), + dtls.WithLoggerFactory(t.api.settingEngine.LoggerFactory), + dtls.WithVerifyPeerCertificate(t.verifyPeerCertificateFunc()), + } + + if t.api.settingEngine.dtls.customCipherSuites != nil { + sharedOpts = append( + sharedOpts, + dtls.WithCustomCipherSuites(t.api.settingEngine.dtls.customCipherSuites), + ) + } + + if t.api.settingEngine.dtls.retransmissionInterval > 0 { + sharedOpts = append( + sharedOpts, + dtls.WithFlightInterval(t.api.settingEngine.dtls.retransmissionInterval), + ) + } + + if t.api.settingEngine.replayProtection.DTLS != nil { + sharedOpts = append( + sharedOpts, + dtls.WithReplayProtectionWindow(int(*t.api.settingEngine.replayProtection.DTLS)), //nolint:gosec // G115 + ) + } + + if t.api.settingEngine.dtls.cipherSuites != nil { + sharedOpts = append( + sharedOpts, + dtls.WithCipherSuites(t.api.settingEngine.dtls.cipherSuites...), + ) + } + + if len(t.api.settingEngine.dtls.ellipticCurves) > 0 { + sharedOpts = append( + sharedOpts, + dtls.WithEllipticCurves(t.api.settingEngine.dtls.ellipticCurves...), + ) + } + + if t.api.settingEngine.dtls.rootCAs != nil { + sharedOpts = append(sharedOpts, dtls.WithRootCAs(t.api.settingEngine.dtls.rootCAs)) + } + + if t.api.settingEngine.dtls.keyLogWriter != nil { + sharedOpts = append(sharedOpts, dtls.WithKeyLogWriter(t.api.settingEngine.dtls.keyLogWriter)) + } + + if len(t.api.settingEngine.dtls.supportedProtocols) > 0 { + sharedOpts = append( + sharedOpts, + dtls.WithSupportedProtocols(t.api.settingEngine.dtls.supportedProtocols...), + ) + } + + return sharedOpts +} + +func (t *DTLSTransport) srtpProtectionProfiles() []dtls.SRTPProtectionProfile { + if len(t.api.settingEngine.srtpProtectionProfiles) > 0 { + return t.api.settingEngine.srtpProtectionProfiles + } + + return defaultSrtpProtectionProfiles() +} + +func (t *DTLSTransport) verifyPeerCertificateFunc() func([][]byte, [][]*x509.Certificate) error { + return func(rawCerts [][]byte, _ [][]*x509.Certificate) error { + if len(rawCerts) == 0 { + return errNoRemoteCertificate + } + + t.lock.Lock() + defer t.lock.Unlock() + t.remoteCertificate = rawCerts[0] + + if t.api.settingEngine.disableCertificateFingerprintVerification { + return nil + } + + parsedRemoteCert, err := x509.ParseCertificate(t.remoteCertificate) + if err != nil { + return err + } + + return t.validateFingerPrint(parsedRemoteCert) + } +} + +func (t *DTLSTransport) connectDTLS( + dtlsEndpoint *mux.Endpoint, + role DTLSRole, + sharedOpts []dtls.Option, +) (*dtls.Conn, error) { + if role == DTLSRoleClient { + clientOpts := t.toDTLSClientOptions(sharedOpts) + + return dtls.ClientWithOptions( + dtlsEndpoint, + dtlsEndpoint.RemoteAddr(), + clientOpts..., + ) + } + + serverOpts := t.toDTLSServerOptions(sharedOpts) + + return dtls.ServerWithOptions( + dtlsEndpoint, + dtlsEndpoint.RemoteAddr(), + serverOpts..., + ) +} + +func (t *DTLSTransport) toDTLSServerOptions(sharedOpts []dtls.Option) []dtls.ServerOption { + serverOpts := make([]dtls.ServerOption, 0, len(sharedOpts)+5) + for _, opt := range sharedOpts { + serverOpts = append(serverOpts, opt) + } + + clientAuth := dtls.RequireAnyClientCert + if t.api.settingEngine.dtls.clientAuth != nil { + clientAuth = *t.api.settingEngine.dtls.clientAuth + } + + serverOpts = append(serverOpts, + dtls.WithClientAuth(clientAuth), + dtls.WithClientCAs(t.api.settingEngine.dtls.clientCAs), + dtls.WithInsecureSkipVerifyHello(t.api.settingEngine.dtls.insecureSkipHelloVerify), + ) + + if t.api.settingEngine.dtls.serverHelloMessageHook != nil { + serverOpts = append( + serverOpts, + dtls.WithServerHelloMessageHook(t.api.settingEngine.dtls.serverHelloMessageHook), + ) + } + + if t.api.settingEngine.dtls.certificateRequestMessageHook != nil { + serverOpts = append( + serverOpts, + dtls.WithCertificateRequestMessageHook(t.api.settingEngine.dtls.certificateRequestMessageHook), + ) + } + + return serverOpts +} + +func (t *DTLSTransport) toDTLSClientOptions(sharedOpts []dtls.Option) []dtls.ClientOption { + clientOpts := make([]dtls.ClientOption, 0, len(sharedOpts)+1) + for _, opt := range sharedOpts { + clientOpts = append(clientOpts, opt) + } + + if t.api.settingEngine.dtls.clientHelloMessageHook != nil { + clientOpts = append( + clientOpts, + dtls.WithClientHelloMessageHook(t.api.settingEngine.dtls.clientHelloMessageHook), + ) + } + + return clientOpts +} + +func (t *DTLSTransport) handshakeDTLS(dtlsConn *dtls.Conn) error { + if t.api.settingEngine.dtls.connectContextMaker == nil { + return dtlsConn.Handshake() + } + + handshakeCtx, cancel := t.api.settingEngine.dtls.connectContextMaker() + if cancel != nil { + defer cancel() + } + + return dtlsConn.HandshakeContext(handshakeCtx) +} + +func (t *DTLSTransport) completeStart(dtlsConn *dtls.Conn) error { + srtpProtectionProfile, err := srtpProtectionProfileFromDTLSConn(dtlsConn) + + t.lock.Lock() + defer t.lock.Unlock() + + if err != nil { + t.onStateChange(DTLSTransportStateFailed) + + return err + } + + t.srtpProtectionProfile = srtpProtectionProfile + t.conn = dtlsConn + t.onStateChange(DTLSTransportStateConnected) + + return t.startSRTP() +} + +func (t *DTLSTransport) failStart(err error) error { + t.lock.Lock() + defer t.lock.Unlock() + t.onStateChange(DTLSTransportStateFailed) + + return err +} + +func srtpProtectionProfileFromDTLSConn(dtlsConn *dtls.Conn) (srtp.ProtectionProfile, error) { + srtpProfile, ok := dtlsConn.SelectedSRTPProtectionProfile() + if !ok { + return 0, ErrNoSRTPProtectionProfile + } + + return srtpProtectionProfileFromDTLS(srtpProfile) +} + +func srtpProtectionProfileFromDTLS(srtpProfile dtls.SRTPProtectionProfile) (srtp.ProtectionProfile, error) { + switch srtpProfile { + case dtls.SRTP_AEAD_AES_128_GCM: + return srtp.ProtectionProfileAeadAes128Gcm, nil + case dtls.SRTP_AEAD_AES_256_GCM: + return srtp.ProtectionProfileAeadAes256Gcm, nil + case dtls.SRTP_AES128_CM_HMAC_SHA1_80: + return srtp.ProtectionProfileAes128CmHmacSha1_80, nil + case dtls.SRTP_NULL_HMAC_SHA1_80: + return srtp.ProtectionProfileNullHmacSha1_80, nil + default: + return 0, ErrNoSRTPProtectionProfile + } +} + +// Stop stops and closes the DTLSTransport object. +func (t *DTLSTransport) Stop() error { + t.lock.Lock() + defer t.lock.Unlock() + + // Try closing everything and collect the errors + var closeErrs []error + + if srtpSession, err := t.getSRTPSession(); err == nil && srtpSession != nil { + closeErrs = append(closeErrs, srtpSession.Close()) + } + + if srtcpSession, err := t.getSRTCPSession(); err == nil && srtcpSession != nil { + closeErrs = append(closeErrs, srtcpSession.Close()) + } + + for i := range t.simulcastStreams { + closeErrs = append(closeErrs, t.simulcastStreams[i].srtp.Close()) + closeErrs = append(closeErrs, t.simulcastStreams[i].srtcp.Close()) + } + + if t.conn != nil { + // dtls connection may be closed on sctp close. + if err := t.conn.Close(); err != nil && !errors.Is(err, dtls.ErrConnClosed) { + closeErrs = append(closeErrs, err) + } + } + t.onStateChange(DTLSTransportStateClosed) + + return util.FlattenErrs(closeErrs) +} + +func (t *DTLSTransport) validateFingerPrint(remoteCert *x509.Certificate) error { + for _, fp := range t.remoteParameters.Fingerprints { + hashAlgo, err := fingerprint.HashFromString(fp.Algorithm) + if err != nil { + return err + } + + remoteValue, err := fingerprint.Fingerprint(remoteCert, hashAlgo) + if err != nil { + return err + } + + if strings.EqualFold(remoteValue, fp.Value) { + return nil + } + } + + return errNoMatchingCertificateFingerprint +} + +func (t *DTLSTransport) ensureICEConn() error { + if t.iceTransport == nil { + return errICEConnectionNotStarted + } + + return nil +} + +func (t *DTLSTransport) storeSimulcastStream( + srtpReadStream *srtp.ReadStreamSRTP, + srtcpReadStream *srtp.ReadStreamSRTCP, +) { + t.lock.Lock() + defer t.lock.Unlock() + + t.simulcastStreams = append(t.simulcastStreams, simulcastStreamPair{srtpReadStream, srtcpReadStream}) +} + +func (t *DTLSTransport) streamsForSSRC( + ssrc SSRC, + streamInfo interceptor.StreamInfo, +) (*streamsForSSRCResult, error) { + srtpSession, err := t.getSRTPSession() + if err != nil { + return nil, err + } + + rtpReadStream, err := srtpSession.OpenReadStream(uint32(ssrc)) + if err != nil { + return nil, err + } + + rtpInterceptor := t.api.interceptor.BindRemoteStream( + &streamInfo, + interceptor.RTPReaderFunc( + func(in []byte, a interceptor.Attributes) (n int, attributes interceptor.Attributes, err error) { + n, err = rtpReadStream.Read(in) + + return n, a, err + }, + ), + ) + + srtcpSession, err := t.getSRTCPSession() + if err != nil { + return nil, err + } + + rtcpReadStream, err := srtcpSession.OpenReadStream(uint32(ssrc)) + if err != nil { + return nil, err + } + + rtcpInterceptor := t.api.interceptor.BindRTCPReader(interceptor.RTCPReaderFunc( + func(in []byte, a interceptor.Attributes) (n int, attributes interceptor.Attributes, err error) { + n, err = rtcpReadStream.Read(in) + + return n, a, err + }), + ) + + return &streamsForSSRCResult{ + rtpReadStream: rtpReadStream, + rtpInterceptor: rtpInterceptor, + rtcpReadStream: rtcpReadStream, + rtcpInterceptor: rtcpInterceptor, + }, nil +} diff --git a/vendor/github.com/pion/webrtc/v4/dtlstransport_js.go b/vendor/github.com/pion/webrtc/v4/dtlstransport_js.go new file mode 100644 index 0000000..3d85e80 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/dtlstransport_js.go @@ -0,0 +1,65 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +//go:build js && wasm +// +build js,wasm + +package webrtc + +import "syscall/js" + +// DTLSTransport allows an application access to information about the DTLS +// transport over which RTP and RTCP packets are sent and received by +// RTPSender and RTPReceiver, as well other data such as SCTP packets sent +// and received by data channels. +type DTLSTransport struct { + // Pointer to the underlying JavaScript DTLSTransport object. + underlying js.Value +} + +// JSValue returns the underlying RTCDtlsTransport +func (r *DTLSTransport) JSValue() js.Value { + return r.underlying +} + +// ICETransport returns the currently-configured *ICETransport or nil +// if one has not been configured +func (r *DTLSTransport) ICETransport() *ICETransport { + underlying := r.underlying.Get("iceTransport") + if underlying.IsNull() || underlying.IsUndefined() { + return nil + } + + return &ICETransport{ + underlying: underlying, + } +} + +func (t *DTLSTransport) GetRemoteCertificate() []byte { + if t.underlying.IsNull() || t.underlying.IsUndefined() { + return nil + } + + // Firefox does not support getRemoteCertificates: https://bugzilla.mozilla.org/show_bug.cgi?id=1805446 + jsGet := t.underlying.Get("getRemoteCertificates") + if jsGet.IsUndefined() || jsGet.IsNull() { + return nil + } + + jsCerts := t.underlying.Call("getRemoteCertificates") + if jsCerts.Length() == 0 { + return nil + } + + buf := jsCerts.Index(0) + u8 := js.Global().Get("Uint8Array").New(buf) + + if u8.Length() == 0 { + return nil + } + + cert := make([]byte, u8.Length()) + js.CopyBytesToGo(cert, u8) + + return cert +} diff --git a/vendor/github.com/pion/webrtc/v4/dtlstransportstate.go b/vendor/github.com/pion/webrtc/v4/dtlstransportstate.go new file mode 100644 index 0000000..1b9b3e4 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/dtlstransportstate.go @@ -0,0 +1,89 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package webrtc + +// DTLSTransportState indicates the DTLS transport establishment state. +type DTLSTransportState int + +const ( + // DTLSTransportStateUnknown is the enum's zero-value. + DTLSTransportStateUnknown DTLSTransportState = iota + + // DTLSTransportStateNew indicates that DTLS has not started negotiating + // yet. + DTLSTransportStateNew + + // DTLSTransportStateConnecting indicates that DTLS is in the process of + // negotiating a secure connection and verifying the remote fingerprint. + DTLSTransportStateConnecting + + // DTLSTransportStateConnected indicates that DTLS has completed + // negotiation of a secure connection and verified the remote fingerprint. + DTLSTransportStateConnected + + // DTLSTransportStateClosed indicates that the transport has been closed + // intentionally as the result of receipt of a close_notify alert, or + // calling close(). + DTLSTransportStateClosed + + // DTLSTransportStateFailed indicates that the transport has failed as + // the result of an error (such as receipt of an error alert or failure to + // validate the remote fingerprint). + DTLSTransportStateFailed +) + +// This is done this way because of a linter. +const ( + dtlsTransportStateNewStr = "new" + dtlsTransportStateConnectingStr = "connecting" + dtlsTransportStateConnectedStr = "connected" + dtlsTransportStateClosedStr = "closed" + dtlsTransportStateFailedStr = "failed" +) + +func newDTLSTransportState(raw string) DTLSTransportState { + switch raw { + case dtlsTransportStateNewStr: + return DTLSTransportStateNew + case dtlsTransportStateConnectingStr: + return DTLSTransportStateConnecting + case dtlsTransportStateConnectedStr: + return DTLSTransportStateConnected + case dtlsTransportStateClosedStr: + return DTLSTransportStateClosed + case dtlsTransportStateFailedStr: + return DTLSTransportStateFailed + default: + return DTLSTransportStateUnknown + } +} + +func (t DTLSTransportState) String() string { + switch t { + case DTLSTransportStateNew: + return dtlsTransportStateNewStr + case DTLSTransportStateConnecting: + return dtlsTransportStateConnectingStr + case DTLSTransportStateConnected: + return dtlsTransportStateConnectedStr + case DTLSTransportStateClosed: + return dtlsTransportStateClosedStr + case DTLSTransportStateFailed: + return dtlsTransportStateFailedStr + default: + return ErrUnknownType.Error() + } +} + +// MarshalText implements encoding.TextMarshaler. +func (t DTLSTransportState) MarshalText() ([]byte, error) { + return []byte(t.String()), nil +} + +// UnmarshalText implements encoding.TextUnmarshaler. +func (t *DTLSTransportState) UnmarshalText(b []byte) error { + *t = newDTLSTransportState(string(b)) + + return nil +} diff --git a/vendor/github.com/pion/webrtc/v4/errors.go b/vendor/github.com/pion/webrtc/v4/errors.go new file mode 100644 index 0000000..b023d16 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/errors.go @@ -0,0 +1,296 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package webrtc + +import ( + "errors" +) + +var ( + // ErrUnknownType indicates an error with Unknown info. + ErrUnknownType = errors.New("unknown") + + // ErrConnectionClosed indicates an operation executed after connection + // has already been closed. + ErrConnectionClosed = errors.New("connection closed") + + // ErrDataChannelNotOpen indicates an operation executed when the data + // channel is not (yet) open. + ErrDataChannelNotOpen = errors.New("data channel not open") + + // ErrCertificateExpired indicates that an x509 certificate has expired. + ErrCertificateExpired = errors.New("x509Cert expired") + + // ErrNoTurnCredentials indicates that a TURN server URL was provided + // without required credentials. + ErrNoTurnCredentials = errors.New("turn server credentials required") + + // ErrTurnCredentials indicates that provided TURN credentials are partial + // or malformed. + ErrTurnCredentials = errors.New("invalid turn server credentials") + + // ErrExistingTrack indicates that a track already exists. + ErrExistingTrack = errors.New("track already exists") + + // ErrPrivateKeyType indicates that a particular private key encryption + // chosen to generate a certificate is not supported. + ErrPrivateKeyType = errors.New("private key type not supported") + + // ErrModifyingPeerIdentity indicates that an attempt to modify + // PeerIdentity was made after PeerConnection has been initialized. + ErrModifyingPeerIdentity = errors.New("peerIdentity cannot be modified") + + // ErrModifyingCertificates indicates that an attempt to modify + // Certificates was made after PeerConnection has been initialized. + ErrModifyingCertificates = errors.New("certificates cannot be modified") + + // ErrModifyingBundlePolicy indicates that an attempt to modify + // BundlePolicy was made after PeerConnection has been initialized. + ErrModifyingBundlePolicy = errors.New("bundle policy cannot be modified") + + // ErrModifyingRTCPMuxPolicy indicates that an attempt to modify + // RTCPMuxPolicy was made after PeerConnection has been initialized. + ErrModifyingRTCPMuxPolicy = errors.New("rtcp mux policy cannot be modified") + + // ErrModifyingICECandidatePoolSize indicates that an attempt to modify + // ICECandidatePoolSize was made after PeerConnection has been initialized. + ErrModifyingICECandidatePoolSize = errors.New("ice candidate pool size cannot be modified") + + // ErrStringSizeLimit indicates that the character size limit of string is + // exceeded. The limit is hardcoded to 65535 according to specifications. + ErrStringSizeLimit = errors.New("data channel label exceeds size limit") + + // ErrMaxDataChannelID indicates that the maximum number ID that could be + // specified for a data channel has been exceeded. + ErrMaxDataChannelID = errors.New("maximum number ID for datachannel specified") + + // ErrNegotiatedWithoutID indicates that an attempt to create a data channel + // was made while setting the negotiated option to true without providing + // the negotiated channel ID. + ErrNegotiatedWithoutID = errors.New("negotiated set without channel id") + + // ErrRetransmitsOrPacketLifeTime indicates that an attempt to create a data + // channel was made with both options MaxPacketLifeTime and MaxRetransmits + // set together. Such configuration is not supported by the specification + // and is mutually exclusive. + ErrRetransmitsOrPacketLifeTime = errors.New("both MaxPacketLifeTime and MaxRetransmits was set") + + // ErrCodecNotFound is returned when a codec search to the Media Engine fails. + ErrCodecNotFound = errors.New("codec not found") + + // ErrNoRemoteDescription indicates that an operation was rejected because + // the remote description is not set. + ErrNoRemoteDescription = errors.New("remote description is not set") + + // ErrIncorrectSDPSemantics indicates that the PeerConnection was configured to + // generate SDP Answers with different SDP Semantics than the received Offer. + ErrIncorrectSDPSemantics = errors.New("remote SessionDescription semantics does not match configuration") + + // ErrIncorrectSignalingState indicates that the signaling state of PeerConnection is not correct. + ErrIncorrectSignalingState = errors.New("operation can not be run in current signaling state") + + // ErrProtocolTooLarge indicates that value given for a DataChannelInit protocol is + // longer then 65535 bytes. + ErrProtocolTooLarge = errors.New("protocol is larger then 65535 bytes") + + // ErrSenderNotCreatedByConnection indicates RemoveTrack was called with a RtpSender not created + // by this PeerConnection. + ErrSenderNotCreatedByConnection = errors.New("RtpSender not created by this PeerConnection") + + // ErrSessionDescriptionNoFingerprint indicates SetRemoteDescription was called with a SessionDescription that has no + // fingerprint. + ErrSessionDescriptionNoFingerprint = errors.New("SetRemoteDescription called with no fingerprint") + + // ErrSessionDescriptionInvalidFingerprint indicates SetRemoteDescription was called with a SessionDescription that + // has an invalid fingerprint. + ErrSessionDescriptionInvalidFingerprint = errors.New("SetRemoteDescription called with an invalid fingerprint") + + // ErrSessionDescriptionConflictingFingerprints indicates SetRemoteDescription was called with a SessionDescription + // that has an conflicting fingerprints. + ErrSessionDescriptionConflictingFingerprints = errors.New( + "SetRemoteDescription called with multiple conflicting fingerprint", + ) + + // ErrSessionDescriptionMissingIceUfrag indicates SetRemoteDescription was called with a SessionDescription that + // is missing an ice-ufrag value. + ErrSessionDescriptionMissingIceUfrag = errors.New("SetRemoteDescription called with no ice-ufrag") + + // ErrSessionDescriptionMissingIcePwd indicates SetRemoteDescription was called with a SessionDescription that + // is missing an ice-pwd value. + ErrSessionDescriptionMissingIcePwd = errors.New("SetRemoteDescription called with no ice-pwd") + + // ErrSessionDescriptionConflictingIceUfrag indicates SetRemoteDescription was called with a SessionDescription + // that contains multiple conflicting ice-ufrag values. + ErrSessionDescriptionConflictingIceUfrag = errors.New( + "SetRemoteDescription called with multiple conflicting ice-ufrag values", + ) + + // ErrSessionDescriptionConflictingIcePwd indicates SetRemoteDescription was called with a SessionDescription + // that contains multiple conflicting ice-pwd values. + ErrSessionDescriptionConflictingIcePwd = errors.New( + "SetRemoteDescription called with multiple conflicting ice-pwd values", + ) + + // ErrNoSRTPProtectionProfile indicates that the DTLS handshake completed and no SRTP Protection Profile was chosen. + ErrNoSRTPProtectionProfile = errors.New("DTLS Handshake completed and no SRTP Protection Profile was chosen") + + // ErrFailedToGenerateCertificateFingerprint indicates that we failed to generate the fingerprint + // used for comparing certificates. + ErrFailedToGenerateCertificateFingerprint = errors.New("failed to generate certificate fingerprint") + + // ErrNoCodecsAvailable indicates that operation isn't possible because the MediaEngine has no codecs available. + ErrNoCodecsAvailable = errors.New("operation failed no codecs are available") + + // ErrUnsupportedCodec indicates the remote peer doesn't support the requested codec. + ErrUnsupportedCodec = errors.New("unable to start track, codec is not supported by remote") + + // ErrSenderWithNoCodecs indicates that a RTPSender was created without any codecs. To send media the MediaEngine + // needs at least one configured codec. + ErrSenderWithNoCodecs = errors.New("unable to populate media section, RTPSender created with no codecs") + + // ErrCodecAlreadyRegistered indicates that a codec has already been registered for the same payload type. + ErrCodecAlreadyRegistered = errors.New("codec already registered for same payload type") + + // ErrRTPSenderNewTrackHasIncorrectKind indicates that the new track is of a different kind than the previous/original. + ErrRTPSenderNewTrackHasIncorrectKind = errors.New("new track must be of the same kind as previous") + + // ErrRTPSenderNewTrackHasIncorrectEnvelope indicates that the new track has a different envelope + // than the previous/original. + ErrRTPSenderNewTrackHasIncorrectEnvelope = errors.New("new track must have the same envelope as previous") + + // ErrUnbindFailed indicates that a TrackLocal was not able to be unbind. + ErrUnbindFailed = errors.New("failed to unbind TrackLocal from PeerConnection") + + // ErrNoPayloaderForCodec indicates that the requested codec does not have a payloader. + ErrNoPayloaderForCodec = errors.New("the requested codec does not have a payloader") + + // ErrRegisterHeaderExtensionInvalidDirection indicates that a extension was + // registered with a direction besides `sendonly` or `recvonly`. + ErrRegisterHeaderExtensionInvalidDirection = errors.New( + "a header extension must be registered as 'recvonly', 'sendonly' or both", + ) + + // ErrSimulcastProbeOverflow indicates that too many Simulcast probe streams are in flight + // and the requested SSRC was ignored. + ErrSimulcastProbeOverflow = errors.New("simulcast probe limit has been reached, new SSRC has been discarded") + + // ErrSDPUnmarshalling indicates that the SDP could not be unmarshalled. + ErrSDPUnmarshalling = errors.New("failed to unmarshal SDP") + + errDetachNotEnabled = errors.New("enable detaching by calling webrtc.DetachDataChannels()") + errDetachBeforeOpened = errors.New("datachannel not opened yet, try calling Detach from OnOpen") + errDtlsTransportNotStarted = errors.New("the DTLS transport has not started yet") + errDtlsKeyExtractionFailed = errors.New("failed extracting keys from DTLS for SRTP") + errFailedToStartSRTP = errors.New("failed to start SRTP") + errFailedToStartSRTCP = errors.New("failed to start SRTCP") + errInvalidDTLSStart = errors.New("attempted to start DTLSTransport that is not in new state") + errNoRemoteCertificate = errors.New("peer didn't provide certificate via DTLS") + errIdentityProviderNotImplemented = errors.New("identity provider is not implemented") + errNoMatchingCertificateFingerprint = errors.New("remote certificate does not match any fingerprint") + + errICEConnectionNotStarted = errors.New("ICE connection not started") + errICECandidateTypeUnknown = errors.New("unknown candidate type") + errICEInvalidConvertCandidateType = errors.New( + "cannot convert ice.CandidateType into webrtc.ICECandidateType, invalid type", + ) + errICEAgentNotExist = errors.New("ICEAgent does not exist") + errICECandiatesCoversionFailed = errors.New("unable to convert ICE candidates to ICECandidates") + errICERoleUnknown = errors.New("unknown ICE Role") + errICEProtocolUnknown = errors.New("unknown protocol") + errICEGathererNotStarted = errors.New("gatherer not started") + errAddressRewriteWithNAT1To1 = errors.New("address rewrite rules cannot be combined with NAT1To1IPs") + + errNetworkTypeUnknown = errors.New("unknown network type") + + errSDPDoesNotMatchOffer = errors.New("new sdp does not match previous offer") + errSDPDoesNotMatchAnswer = errors.New("new sdp does not match previous answer") + errPeerConnSDPTypeInvalidValue = errors.New( + "provided value is not a valid enum value of type SDPType", + ) + errPeerConnStateChangeInvalid = errors.New("invalid state change op") + errPeerConnStateChangeUnhandled = errors.New("unhandled state change op") + errPeerConnSDPTypeInvalidValueSetLocalDescription = errors.New("invalid SDP type supplied to SetLocalDescription()") + errPeerConnRemoteDescriptionWithoutMidValue = errors.New( + "remoteDescription contained media section without mid value", + ) + errPeerConnRemoteDescriptionNil = errors.New("remoteDescription has not been set yet") + errMediaSectionHasExplictSSRCAttribute = errors.New("media section has an explicit SSRC") + errPeerConnRemoteSSRCAddTransceiver = errors.New("could not add transceiver for remote SSRC") + errPeerConnSimulcastMidRTPExtensionRequired = errors.New("mid RTP Extensions required for Simulcast") + errPeerConnSimulcastStreamIDRTPExtensionRequired = errors.New("stream id RTP Extensions required for Simulcast") + errPeerConnSimulcastIncomingSSRCFailed = errors.New("incoming SSRC failed Simulcast probing") + errPeerConnAddTransceiverFromKindOnlyAcceptsOne = errors.New( + "AddTransceiverFromKind only accepts one RTPTransceiverInit", + ) + errPeerConnAddTransceiverFromTrackOnlyAcceptsOne = errors.New( + "AddTransceiverFromTrack only accepts one RTPTransceiverInit", + ) + errPeerConnAddTransceiverFromKindSupport = errors.New( + "AddTransceiverFromKind currently only supports recvonly", + ) + errPeerConnAddTransceiverFromTrackSupport = errors.New( + "AddTransceiverFromTrack currently only supports sendonly and sendrecv", + ) + errPeerConnSetIdentityProviderNotImplemented = errors.New("TODO SetIdentityProvider") + errPeerConnWriteRTCPOpenWriteStream = errors.New("WriteRTCP failed to open WriteStream") + errPeerConnTranscieverMidNil = errors.New("cannot find transceiver with mid") + errPeerConnEarlyMediaWithoutAnswer = errors.New( + "cannot process early media without SDP answer," + + "use SettingEngine.SetHandleUndeclaredSSRCWithoutAnswer(true) to process without answer", + ) + + errRTPReceiverDTLSTransportNil = errors.New("DTLSTransport must not be nil") + errRTPReceiverReceiveAlreadyCalled = errors.New("Receive has already been called") + errRTPReceiverWithSSRCTrackStreamNotFound = errors.New("unable to find stream for Track with SSRC") + errRTPReceiverForRIDTrackStreamNotFound = errors.New("no trackStreams found for RID") + + errRTPSenderTrackNil = errors.New("Track must not be nil") + errRTPSenderDTLSTransportNil = errors.New("DTLSTransport must not be nil") + errRTPSenderSendAlreadyCalled = errors.New("Send has already been called") + errRTPSenderSendNotCalled = errors.New("Send has not been called") + errRTPSenderStopped = errors.New("Sender has already been stopped") + errRTPSenderTrackRemoved = errors.New("Sender Track has been removed or replaced to nil") + errRTPSenderRidNil = errors.New("Sender cannot add encoding as rid is empty") + errRTPSenderNoBaseEncoding = errors.New("Sender cannot add encoding as there is no base track") + errRTPSenderBaseEncodingMismatch = errors.New("Sender cannot add encoding as provided track does not match base track") + errRTPSenderRIDCollision = errors.New("Sender cannot encoding due to RID collision") + errRTPSenderNoTrackForRID = errors.New("Sender does not have track for RID") + + errRTPTransceiverCannotChangeMid = errors.New("cannot change transceiver mid") + errRTPTransceiverSetSendingInvalidState = errors.New("invalid state change in RTPTransceiver.setSending") + errRTPTransceiverCodecUnsupported = errors.New("unsupported codec type by this transceiver") + + errSCTPTransportDTLS = errors.New("DTLS not established") + + errSDPZeroTransceivers = errors.New("addTransceiverSDP() called with 0 transceivers") + errSDPMediaSectionMediaDataChanInvalid = errors.New("invalid Media Section. Media + DataChannel both enabled") + errSDPMediaSectionMultipleTrackInvalid = errors.New( + "invalid Media Section. Can not have multiple tracks in one MediaSection in UnifiedPlan", + ) + + errSettingEngineSetAnsweringDTLSRole = errors.New("SetAnsweringDTLSRole must DTLSRoleClient or DTLSRoleServer") + + errSignalingStateCannotRollback = errors.New("can't rollback from stable state") + errSignalingStateProposedTransitionInvalid = errors.New("invalid proposed signaling state transition") + + errStatsICECandidateStateInvalid = errors.New( + "cannot convert to StatsICECandidatePairStateSucceeded invalid ice candidate state", + ) + + errICECandidatePoolSizeTooLarge = errors.New("ice candidate pool size greater than 1 is not supported") + + errInvalidICECredentialTypeString = errors.New("invalid ICECredentialType") + errInvalidICEServer = errors.New("invalid ICEServer") + + errICETransportNotInNew = errors.New("ICETransport can only be called in ICETransportStateNew") + errICETransportClosed = errors.New("ICETransport closed") + + errCertificatePEMMultipleCert = errors.New("failed parsing certificate, more than 1 CERTIFICATE block in pems") + errCertificatePEMMultiplePriv = errors.New("failed parsing certificate, more than 1 PRIVATE KEY block in pems") + errCertificatePEMMissing = errors.New("failed parsing certificate, pems must contain both a CERTIFICATE block and a PRIVATE KEY block") // nolint: lll + + errRTPTooShort = errors.New("not long enough to be a RTP Packet") + + errExcessiveRetries = errors.New("excessive retries in CreateOffer") +) diff --git a/vendor/github.com/pion/webrtc/v4/gathering_complete_promise.go b/vendor/github.com/pion/webrtc/v4/gathering_complete_promise.go new file mode 100644 index 0000000..fbb2eef --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/gathering_complete_promise.go @@ -0,0 +1,29 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package webrtc + +import ( + "context" +) + +// GatheringCompletePromise is a Pion specific helper function that returns a channel that is closed +// when gathering is complete. +// This function may be helpful in cases where you are unable to trickle your ICE Candidates. +// +// It is better to not use this function, and instead trickle candidates. +// If you use this function you will see longer connection startup times. +// When the call is connected you will see no impact however. +func GatheringCompletePromise(pc *PeerConnection) (gatherComplete <-chan struct{}) { + gatheringComplete, done := context.WithCancel(context.Background()) + + // It's possible to miss the GatherComplete event since setGatherCompleteHandler is an atomic operation and the + // promise might have been created after the gathering is finished. Therefore, we need to check if the ICE gathering + // state has changed to complete so that we don't block the caller forever. + pc.setGatherCompleteHandler(func() { done() }) + if pc.ICEGatheringState() == ICEGatheringStateComplete { + done() + } + + return gatheringComplete.Done() +} diff --git a/vendor/github.com/pion/webrtc/v4/ice_go.go b/vendor/github.com/pion/webrtc/v4/ice_go.go new file mode 100644 index 0000000..e42170d --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/ice_go.go @@ -0,0 +1,13 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +//go:build !js + +package webrtc + +// NewICETransport creates a new NewICETransport. +// This constructor is part of the ORTC API. It is not +// meant to be used together with the basic WebRTC API. +func (api *API) NewICETransport(gatherer *ICEGatherer) *ICETransport { + return NewICETransport(gatherer, api.settingEngine.LoggerFactory) +} diff --git a/vendor/github.com/pion/webrtc/v4/icecandidate.go b/vendor/github.com/pion/webrtc/v4/icecandidate.go new file mode 100644 index 0000000..17182a0 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/icecandidate.go @@ -0,0 +1,244 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package webrtc + +import ( + "fmt" + "strings" + + "github.com/pion/ice/v4" +) + +// ICECandidate represents a ice candidate. +type ICECandidate struct { + statsID string + Foundation string `json:"foundation"` + Priority uint32 `json:"priority"` + Address string `json:"address"` + Protocol ICEProtocol `json:"protocol"` + Port uint16 `json:"port"` + Typ ICECandidateType `json:"type"` + Component uint16 `json:"component"` + RelatedAddress string `json:"relatedAddress"` + RelatedPort uint16 `json:"relatedPort"` + TCPType string `json:"tcpType"` + SDPMid string `json:"sdpMid"` + SDPMLineIndex uint16 `json:"sdpMLineIndex"` + extensions string +} + +// Conversion for package ice. +func newICECandidatesFromICE( + iceCandidates []ice.Candidate, + sdpMid string, + sdpMLineIndex uint16, +) ([]ICECandidate, error) { + candidates := []ICECandidate{} + + for _, i := range iceCandidates { + c, err := newICECandidateFromICE(i, sdpMid, sdpMLineIndex) + if err != nil { + return nil, err + } + candidates = append(candidates, c) + } + + return candidates, nil +} + +func newICECandidateFromICE(candidate ice.Candidate, sdpMid string, sdpMLineIndex uint16) (ICECandidate, error) { + typ, err := convertTypeFromICE(candidate.Type()) + if err != nil { + return ICECandidate{}, err + } + protocol, err := NewICEProtocol(candidate.NetworkType().NetworkShort()) + if err != nil { + return ICECandidate{}, err + } + + newCandidate := ICECandidate{ + statsID: candidate.ID(), + Foundation: candidate.Foundation(), + Priority: candidate.Priority(), + Address: candidate.Address(), + Protocol: protocol, + Port: uint16(candidate.Port()), //nolint:gosec // G115 + Component: candidate.Component(), + Typ: typ, + TCPType: candidate.TCPType().String(), + SDPMid: sdpMid, + SDPMLineIndex: sdpMLineIndex, + } + + newCandidate.setExtensions(candidate.Extensions()) + + if candidate.RelatedAddress() != nil { + newCandidate.RelatedAddress = candidate.RelatedAddress().Address + newCandidate.RelatedPort = uint16(candidate.RelatedAddress().Port) //nolint:gosec // G115 + } + + return newCandidate, nil +} + +// ToICE converts ICECandidate to ice.Candidate. +func (c ICECandidate) ToICE() (cand ice.Candidate, err error) { + candidateID := c.statsID + switch c.Typ { + case ICECandidateTypeHost: + config := ice.CandidateHostConfig{ + CandidateID: candidateID, + Network: c.Protocol.String(), + Address: c.Address, + Port: int(c.Port), + Component: c.Component, + TCPType: ice.NewTCPType(c.TCPType), + Foundation: c.Foundation, + Priority: c.Priority, + } + + cand, err = ice.NewCandidateHost(&config) + case ICECandidateTypeSrflx: + config := ice.CandidateServerReflexiveConfig{ + CandidateID: candidateID, + Network: c.Protocol.String(), + Address: c.Address, + Port: int(c.Port), + Component: c.Component, + Foundation: c.Foundation, + Priority: c.Priority, + RelAddr: c.RelatedAddress, + RelPort: int(c.RelatedPort), + } + + cand, err = ice.NewCandidateServerReflexive(&config) + case ICECandidateTypePrflx: + config := ice.CandidatePeerReflexiveConfig{ + CandidateID: candidateID, + Network: c.Protocol.String(), + Address: c.Address, + Port: int(c.Port), + Component: c.Component, + Foundation: c.Foundation, + Priority: c.Priority, + RelAddr: c.RelatedAddress, + RelPort: int(c.RelatedPort), + } + + cand, err = ice.NewCandidatePeerReflexive(&config) + case ICECandidateTypeRelay: + config := ice.CandidateRelayConfig{ + CandidateID: candidateID, + Network: c.Protocol.String(), + Address: c.Address, + Port: int(c.Port), + Component: c.Component, + Foundation: c.Foundation, + Priority: c.Priority, + RelAddr: c.RelatedAddress, + RelPort: int(c.RelatedPort), + } + + cand, err = ice.NewCandidateRelay(&config) + default: + return nil, fmt.Errorf("%w: %s", errICECandidateTypeUnknown, c.Typ) + } + + if cand != nil && err == nil { + err = c.exportExtensions(cand) + } + + return cand, err +} + +func (c *ICECandidate) setExtensions(ext []ice.CandidateExtension) { + var extensions strings.Builder + + for i := range ext { + if i > 0 { + extensions.WriteString(" ") + } + + extensions.WriteString(ext[i].Key + " " + ext[i].Value) + } + + c.extensions = extensions.String() +} + +func (c *ICECandidate) exportExtensions(cand ice.Candidate) error { + extensions := c.extensions + var ext ice.CandidateExtension + var field string + + for i, start := 0, 0; i < len(extensions); i++ { + switch { + case extensions[i] == ' ': + field = extensions[start:i] + start = i + 1 + case i == len(extensions)-1: + field = extensions[start:] + default: + continue + } + + // Extension keys can't be empty + hasKey := ext.Key != "" + if !hasKey { + ext.Key = field + } else { + ext.Value = field + } + + // Extension value can be empty + if hasKey || i == len(extensions)-1 { + if err := cand.AddExtension(ext); err != nil { + return err + } + + ext = ice.CandidateExtension{} + } + } + + return nil +} + +func convertTypeFromICE(t ice.CandidateType) (ICECandidateType, error) { + switch t { + case ice.CandidateTypeHost: + return ICECandidateTypeHost, nil + case ice.CandidateTypeServerReflexive: + return ICECandidateTypeSrflx, nil + case ice.CandidateTypePeerReflexive: + return ICECandidateTypePrflx, nil + case ice.CandidateTypeRelay: + return ICECandidateTypeRelay, nil + default: + return ICECandidateType(t), fmt.Errorf("%w: %s", errICECandidateTypeUnknown, t) + } +} + +func (c ICECandidate) String() string { + ic, err := c.ToICE() + if err != nil { + return fmt.Sprintf("%#v failed to convert to ICE: %s", c, err) + } + + return ic.String() +} + +// ToJSON returns an ICECandidateInit +// as indicated by the spec https://w3c.github.io/webrtc-pc/#dom-rtcicecandidate-tojson +func (c ICECandidate) ToJSON() ICECandidateInit { + candidateStr := "" + + candidate, err := c.ToICE() + if err == nil { + candidateStr = candidate.Marshal() + } + + return ICECandidateInit{ + Candidate: fmt.Sprintf("candidate:%s", candidateStr), + SDPMid: &c.SDPMid, + SDPMLineIndex: &c.SDPMLineIndex, + } +} diff --git a/vendor/github.com/pion/webrtc/v4/icecandidateinit.go b/vendor/github.com/pion/webrtc/v4/icecandidateinit.go new file mode 100644 index 0000000..f3a51c1 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/icecandidateinit.go @@ -0,0 +1,12 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package webrtc + +// ICECandidateInit is used to serialize ice candidates. +type ICECandidateInit struct { + Candidate string `json:"candidate"` + SDPMid *string `json:"sdpMid"` + SDPMLineIndex *uint16 `json:"sdpMLineIndex"` + UsernameFragment *string `json:"usernameFragment"` +} diff --git a/vendor/github.com/pion/webrtc/v4/icecandidatepair.go b/vendor/github.com/pion/webrtc/v4/icecandidatepair.go new file mode 100644 index 0000000..8868671 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/icecandidatepair.go @@ -0,0 +1,37 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package webrtc + +import "fmt" + +// ICECandidatePair represents an ICE Candidate pair. +type ICECandidatePair struct { + statsID string + Local *ICECandidate + Remote *ICECandidate +} + +func newICECandidatePairStatsID(localID, remoteID string) string { + return fmt.Sprintf("%s-%s", localID, remoteID) +} + +func (p *ICECandidatePair) String() string { + if p == nil { + return "" + } + + return fmt.Sprintf("(local) %s <-> (remote) %s", p.Local, p.Remote) +} + +// NewICECandidatePair returns an initialized *ICECandidatePair +// for the given pair of ICECandidate instances. +func NewICECandidatePair(local, remote *ICECandidate) *ICECandidatePair { + statsID := newICECandidatePairStatsID(local.statsID, remote.statsID) + + return &ICECandidatePair{ + statsID: statsID, + Local: local, + Remote: remote, + } +} diff --git a/vendor/github.com/pion/webrtc/v4/icecandidatetype.go b/vendor/github.com/pion/webrtc/v4/icecandidatetype.go new file mode 100644 index 0000000..4683266 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/icecandidatetype.go @@ -0,0 +1,119 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package webrtc + +import ( + "fmt" + + "github.com/pion/ice/v4" +) + +// ICECandidateType represents the type of the ICE candidate used. +type ICECandidateType int + +const ( + // ICECandidateTypeUnknown is the enum's zero-value. + ICECandidateTypeUnknown ICECandidateType = iota + + // ICECandidateTypeHost indicates that the candidate is of Host type as + // described in https://tools.ietf.org/html/rfc8445#section-5.1.1.1. A + // candidate obtained by binding to a specific port from an IP address on + // the host. This includes IP addresses on physical interfaces and logical + // ones, such as ones obtained through VPNs. + ICECandidateTypeHost + + // ICECandidateTypeSrflx indicates the candidate is of Server + // Reflexive type as described + // https://tools.ietf.org/html/rfc8445#section-5.1.1.2. A candidate type + // whose IP address and port are a binding allocated by a NAT for an ICE + // agent after it sends a packet through the NAT to a server, such as a + // STUN server. + ICECandidateTypeSrflx + + // ICECandidateTypePrflx indicates that the candidate is of Peer + // Reflexive type. A candidate type whose IP address and port are a binding + // allocated by a NAT for an ICE agent after it sends a packet through the + // NAT to its peer. + ICECandidateTypePrflx + + // ICECandidateTypeRelay indicates the candidate is of Relay type as + // described in https://tools.ietf.org/html/rfc8445#section-5.1.1.2. A + // candidate type obtained from a relay server, such as a TURN server. + ICECandidateTypeRelay +) + +// This is done this way because of a linter. +const ( + iceCandidateTypeHostStr = "host" + iceCandidateTypeSrflxStr = "srflx" + iceCandidateTypePrflxStr = "prflx" + iceCandidateTypeRelayStr = "relay" +) + +// NewICECandidateType takes a string and converts it into ICECandidateType. +func NewICECandidateType(raw string) (ICECandidateType, error) { + switch raw { + case iceCandidateTypeHostStr: + return ICECandidateTypeHost, nil + case iceCandidateTypeSrflxStr: + return ICECandidateTypeSrflx, nil + case iceCandidateTypePrflxStr: + return ICECandidateTypePrflx, nil + case iceCandidateTypeRelayStr: + return ICECandidateTypeRelay, nil + default: + return ICECandidateTypeUnknown, fmt.Errorf("%w: %s", errICECandidateTypeUnknown, raw) + } +} + +func (t ICECandidateType) String() string { + switch t { + case ICECandidateTypeHost: + return iceCandidateTypeHostStr + case ICECandidateTypeSrflx: + return iceCandidateTypeSrflxStr + case ICECandidateTypePrflx: + return iceCandidateTypePrflxStr + case ICECandidateTypeRelay: + return iceCandidateTypeRelayStr + default: + return ErrUnknownType.Error() + } +} + +func getCandidateType(candidateType ice.CandidateType) (ICECandidateType, error) { + switch candidateType { + case ice.CandidateTypeHost: + return ICECandidateTypeHost, nil + case ice.CandidateTypeServerReflexive: + return ICECandidateTypeSrflx, nil + case ice.CandidateTypePeerReflexive: + return ICECandidateTypePrflx, nil + case ice.CandidateTypeRelay: + return ICECandidateTypeRelay, nil + default: + // NOTE: this should never happen[tm] + err := fmt.Errorf("%w: %s", errICEInvalidConvertCandidateType, candidateType.String()) + + return ICECandidateTypeUnknown, err + } +} + +// MarshalText implements the encoding.TextMarshaler interface. +func (t ICECandidateType) MarshalText() ([]byte, error) { //nolint:staticcheck + return []byte(t.String()), nil +} + +// UnmarshalText implements the encoding.TextUnmarshaler interface. +func (t *ICECandidateType) UnmarshalText(b []byte) error { + var err error + *t, err = NewICECandidateType(string(b)) + + return err +} + +func (r ICECandidateType) toICE() ice.CandidateType { + //nolint:gosec // G115, no overflow, ICECandidateType matches ice.CandidateType in granularity. + return ice.CandidateType(r) +} diff --git a/vendor/github.com/pion/webrtc/v4/icecomponent.go b/vendor/github.com/pion/webrtc/v4/icecomponent.go new file mode 100644 index 0000000..fcfe040 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/icecomponent.go @@ -0,0 +1,53 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package webrtc + +// ICEComponent describes if the ice transport is used for RTP +// (or RTCP multiplexing). +type ICEComponent int + +const ( + // ICEComponentUnknown is the enum's zero-value. + ICEComponentUnknown ICEComponent = iota + + // ICEComponentRTP indicates that the ICE Transport is used for RTP (or + // RTCP multiplexing), as defined in + // https://tools.ietf.org/html/rfc5245#section-4.1.1.1. Protocols + // multiplexed with RTP (e.g. data channel) share its component ID. This + // represents the component-id value 1 when encoded in candidate-attribute. + ICEComponentRTP + + // ICEComponentRTCP indicates that the ICE Transport is used for RTCP as + // defined by https://tools.ietf.org/html/rfc5245#section-4.1.1.1. This + // represents the component-id value 2 when encoded in candidate-attribute. + ICEComponentRTCP +) + +// This is done this way because of a linter. +const ( + iceComponentRTPStr = "rtp" + iceComponentRTCPStr = "rtcp" +) + +func newICEComponent(raw string) ICEComponent { + switch raw { + case iceComponentRTPStr: + return ICEComponentRTP + case iceComponentRTCPStr: + return ICEComponentRTCP + default: + return ICEComponentUnknown + } +} + +func (t ICEComponent) String() string { + switch t { + case ICEComponentRTP: + return iceComponentRTPStr + case ICEComponentRTCP: + return iceComponentRTCPStr + default: + return ErrUnknownType.Error() + } +} diff --git a/vendor/github.com/pion/webrtc/v4/iceconnectionstate.go b/vendor/github.com/pion/webrtc/v4/iceconnectionstate.go new file mode 100644 index 0000000..c241554 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/iceconnectionstate.go @@ -0,0 +1,100 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package webrtc + +// ICEConnectionState indicates signaling state of the ICE Connection. +type ICEConnectionState int + +const ( + // ICEConnectionStateUnknown is the enum's zero-value. + ICEConnectionStateUnknown ICEConnectionState = iota + + // ICEConnectionStateNew indicates that any of the ICETransports are + // in the "new" state and none of them are in the "checking", "disconnected" + // or "failed" state, or all ICETransports are in the "closed" state, or + // there are no transports. + ICEConnectionStateNew + + // ICEConnectionStateChecking indicates that any of the ICETransports + // are in the "checking" state and none of them are in the "disconnected" + // or "failed" state. + ICEConnectionStateChecking + + // ICEConnectionStateConnected indicates that all ICETransports are + // in the "connected", "completed" or "closed" state and at least one of + // them is in the "connected" state. + ICEConnectionStateConnected + + // ICEConnectionStateCompleted indicates that all ICETransports are + // in the "completed" or "closed" state and at least one of them is in the + // "completed" state. + ICEConnectionStateCompleted + + // ICEConnectionStateDisconnected indicates that any of the + // ICETransports are in the "disconnected" state and none of them are + // in the "failed" state. + ICEConnectionStateDisconnected + + // ICEConnectionStateFailed indicates that any of the ICETransports + // are in the "failed" state. + ICEConnectionStateFailed + + // ICEConnectionStateClosed indicates that the PeerConnection's + // isClosed is true. + ICEConnectionStateClosed +) + +// This is done this way because of a linter. +const ( + iceConnectionStateNewStr = "new" + iceConnectionStateCheckingStr = "checking" + iceConnectionStateConnectedStr = "connected" + iceConnectionStateCompletedStr = "completed" + iceConnectionStateDisconnectedStr = "disconnected" + iceConnectionStateFailedStr = "failed" + iceConnectionStateClosedStr = "closed" +) + +// NewICEConnectionState takes a string and converts it to ICEConnectionState. +func NewICEConnectionState(raw string) ICEConnectionState { + switch raw { + case iceConnectionStateNewStr: + return ICEConnectionStateNew + case iceConnectionStateCheckingStr: + return ICEConnectionStateChecking + case iceConnectionStateConnectedStr: + return ICEConnectionStateConnected + case iceConnectionStateCompletedStr: + return ICEConnectionStateCompleted + case iceConnectionStateDisconnectedStr: + return ICEConnectionStateDisconnected + case iceConnectionStateFailedStr: + return ICEConnectionStateFailed + case iceConnectionStateClosedStr: + return ICEConnectionStateClosed + default: + return ICEConnectionStateUnknown + } +} + +func (c ICEConnectionState) String() string { + switch c { + case ICEConnectionStateNew: + return iceConnectionStateNewStr + case ICEConnectionStateChecking: + return iceConnectionStateCheckingStr + case ICEConnectionStateConnected: + return iceConnectionStateConnectedStr + case ICEConnectionStateCompleted: + return iceConnectionStateCompletedStr + case ICEConnectionStateDisconnected: + return iceConnectionStateDisconnectedStr + case ICEConnectionStateFailed: + return iceConnectionStateFailedStr + case ICEConnectionStateClosed: + return iceConnectionStateClosedStr + default: + return ErrUnknownType.Error() + } +} diff --git a/vendor/github.com/pion/webrtc/v4/icecredentialtype.go b/vendor/github.com/pion/webrtc/v4/icecredentialtype.go new file mode 100644 index 0000000..a7af67a --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/icecredentialtype.go @@ -0,0 +1,73 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package webrtc + +import ( + "encoding/json" + "fmt" +) + +// ICECredentialType indicates the type of credentials used to connect to +// an ICE server. +type ICECredentialType int + +const ( + // ICECredentialTypePassword describes username and password based + // credentials as described in https://tools.ietf.org/html/rfc5389. + ICECredentialTypePassword ICECredentialType = iota + + // ICECredentialTypeOauth describes token based credential as described + // in https://tools.ietf.org/html/rfc7635. + ICECredentialTypeOauth +) + +// This is done this way because of a linter. +const ( + iceCredentialTypePasswordStr = "password" + iceCredentialTypeOauthStr = "oauth" +) + +func newICECredentialType(raw string) (ICECredentialType, error) { + switch raw { + case iceCredentialTypePasswordStr: + return ICECredentialTypePassword, nil + case iceCredentialTypeOauthStr: + return ICECredentialTypeOauth, nil + default: + return ICECredentialTypePassword, errInvalidICECredentialTypeString + } +} + +func (t ICECredentialType) String() string { + switch t { + case ICECredentialTypePassword: + return iceCredentialTypePasswordStr + case ICECredentialTypeOauth: + return iceCredentialTypeOauthStr + default: + return ErrUnknownType.Error() + } +} + +// UnmarshalJSON parses the JSON-encoded data and stores the result. +func (t *ICECredentialType) UnmarshalJSON(b []byte) error { + var val string + if err := json.Unmarshal(b, &val); err != nil { + return err + } + + tmp, err := newICECredentialType(val) + if err != nil { + return fmt.Errorf("%w: (%s)", err, val) + } + + *t = tmp + + return nil +} + +// MarshalJSON returns the JSON encoding. +func (t ICECredentialType) MarshalJSON() ([]byte, error) { + return json.Marshal(t.String()) +} diff --git a/vendor/github.com/pion/webrtc/v4/icegatherer.go b/vendor/github.com/pion/webrtc/v4/icegatherer.go new file mode 100644 index 0000000..7f0294b --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/icegatherer.go @@ -0,0 +1,753 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +//go:build !js + +package webrtc + +import ( + "fmt" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/pion/ice/v4" + "github.com/pion/logging" + "github.com/pion/stun/v3" +) + +// ICEGatherer gathers local host, server reflexive and relay +// candidates, as well as enabling the retrieval of local Interactive +// Connectivity Establishment (ICE) parameters which can be +// exchanged in signaling. +type ICEGatherer struct { + lock sync.RWMutex + log logging.LeveledLogger + state ICEGathererState + + validatedServers []*stun.URI + gatherPolicy ICETransportPolicy + + agent *ice.Agent + + onLocalCandidateHandler atomic.Value // func(candidate *ICECandidate) + onStateChangeHandler atomic.Value // func(state ICEGathererState) + + // Used for GatheringCompletePromise + onGatheringCompleteHandler atomic.Value // func() + + api *API + + // Used to set the corresponding media stream identification tag and media description index + // for ICE candidates generated by this gatherer. + sdpMid atomic.Value // string + sdpMLineIndex atomic.Uint32 // uint16 + + // Used for ICE candidate pooling + candidatePoolLock sync.Mutex + candidatePool []ice.Candidate + iceCandidatePoolSize uint8 +} + +// ICEAddressRewriteMode controls whether a rule replaces or appends candidates. +type ICEAddressRewriteMode byte + +const ( + ICEAddressRewriteModeUnspecified ICEAddressRewriteMode = iota + ICEAddressRewriteReplace + ICEAddressRewriteAppend +) + +func (r ICEAddressRewriteMode) toICE() ice.AddressRewriteMode { + return ice.AddressRewriteMode(r) +} + +// ICEAddressRewriteRule represents a rule for remapping candidate addresses. +type ICEAddressRewriteRule struct { + External []string + Local string + Iface string + CIDR string + AsCandidateType ICECandidateType + Mode ICEAddressRewriteMode + Networks []NetworkType +} + +func (r ICEAddressRewriteRule) toICE() ice.AddressRewriteRule { + candidateType := r.AsCandidateType.toICE() + mode := r.Mode.toICE() + networks := toICENetworkTypes(r.Networks) + + rule := ice.AddressRewriteRule{ + External: append([]string(nil), r.External...), + Local: r.Local, + Iface: r.Iface, + CIDR: r.CIDR, + AsCandidateType: candidateType, + Mode: mode, + Networks: networks, + } + + return rule +} + +// NewICEGatherer creates a new NewICEGatherer. +// This constructor is part of the ORTC API. It is not +// meant to be used together with the basic WebRTC API. +func (api *API) NewICEGatherer(opts ICEGatherOptions) (*ICEGatherer, error) { + var validatedServers []*stun.URI + if len(opts.ICEServers) > 0 { + for _, server := range opts.ICEServers { + url, err := server.urls() + if err != nil { + return nil, err + } + validatedServers = append(validatedServers, url...) + } + } + + return &ICEGatherer{ + state: ICEGathererStateNew, + gatherPolicy: opts.ICEGatherPolicy, + validatedServers: validatedServers, + api: api, + log: api.settingEngine.LoggerFactory.NewLogger("ice"), + sdpMid: atomic.Value{}, + sdpMLineIndex: atomic.Uint32{}, + candidatePool: make([]ice.Candidate, 0, opts.ICECandidatePoolSize), + iceCandidatePoolSize: opts.ICECandidatePoolSize, + }, nil +} + +// updateServers updates the ICE servers and gather policy. +// If called before gathering starts, the new servers will be used for initial gathering. +// If called after gathering has started, the new servers will be used on the next ICE restart. +func (g *ICEGatherer) updateServers(servers []ICEServer, policy ICETransportPolicy) error { + g.lock.Lock() + defer g.lock.Unlock() + + var validatedServers []*stun.URI + for _, server := range servers { + urls, err := server.urls() + if err != nil { + return err + } + validatedServers = append(validatedServers, urls...) + } + + g.validatedServers = validatedServers + g.gatherPolicy = policy + + if g.agent != nil && (g.State() != ICEGathererStateGathering || + g.iceCandidatePoolSize == 0) { + return g.agent.UpdateOptions(ice.WithUrls(validatedServers)) + } + + return nil +} + +// validatedServersCount returns the number of validated ICE server URLs. +func (g *ICEGatherer) validatedServersCount() int { + g.lock.RLock() + defer g.lock.RUnlock() + + return len(g.validatedServers) +} + +func (g *ICEGatherer) createAgent() error { + g.lock.Lock() + defer g.lock.Unlock() + + if g.agent != nil || g.State() != ICEGathererStateNew { + return nil + } + + options, err := g.buildAgentOptions() + if err != nil { + return err + } + + agent, err := ice.NewAgentWithOptions(options...) + if err != nil { + return err + } + + g.agent = agent + + return nil +} + +func (g *ICEGatherer) buildAgentOptions() ([]ice.AgentOption, error) { + candidateTypes := g.resolveCandidateTypes() + nat1To1CandiTyp := g.resolveNAT1To1CandidateType() + mDNSMode := g.sanitizedMDNSMode() + + options := g.baseAgentOptions(mDNSMode) + if len(candidateTypes) > 0 { + options = append(options, ice.WithCandidateTypes(candidateTypes)) + } + + options = append(options, g.credentialOptions()...) + + rewriteOptions, err := g.addressRewriteOptions(nat1To1CandiTyp) + if err != nil { + return nil, err + } + options = append(options, rewriteOptions...) + options = append(options, g.timeoutOptions()...) + options = append(options, g.miscOptions()...) + options = append(options, g.renominationOptions()...) + + requestedNetworkTypes := g.api.settingEngine.candidates.ICENetworkTypes + if len(requestedNetworkTypes) == 0 { + requestedNetworkTypes = supportedNetworkTypes() + } + + return append(options, ice.WithNetworkTypes(toICENetworkTypes(requestedNetworkTypes))), nil +} + +func (g *ICEGatherer) resolveCandidateTypes() []ice.CandidateType { + if g.api.settingEngine.candidates.ICELite { + return []ice.CandidateType{ice.CandidateTypeHost} + } + + switch g.gatherPolicy { + case ICETransportPolicyRelay: + return []ice.CandidateType{ice.CandidateTypeRelay} + case ICETransportPolicyNoHost: + return []ice.CandidateType{ice.CandidateTypeServerReflexive, ice.CandidateTypeRelay} + default: + } + + return nil +} + +func (g *ICEGatherer) resolveNAT1To1CandidateType() ice.CandidateType { + switch g.api.settingEngine.candidates.NAT1To1IPCandidateType { + case ICECandidateTypeHost: + return ice.CandidateTypeHost + case ICECandidateTypeSrflx: + return ice.CandidateTypeServerReflexive + default: + return ice.CandidateTypeUnspecified + } +} + +func (g *ICEGatherer) sanitizedMDNSMode() ice.MulticastDNSMode { + mode := g.api.settingEngine.candidates.MulticastDNSMode + if mode == ice.MulticastDNSModeDisabled || mode == ice.MulticastDNSModeQueryAndGather { + return mode + } + + return ice.MulticastDNSModeQueryOnly +} + +func (g *ICEGatherer) baseAgentOptions(mDNSMode ice.MulticastDNSMode) []ice.AgentOption { + return []ice.AgentOption{ + ice.WithICELite(g.api.settingEngine.candidates.ICELite), + ice.WithUrls(g.validatedServers), + ice.WithPortRange(g.api.settingEngine.ephemeralUDP.PortMin, g.api.settingEngine.ephemeralUDP.PortMax), + ice.WithLoggerFactory(g.api.settingEngine.LoggerFactory), + ice.WithInterfaceFilter(g.api.settingEngine.candidates.InterfaceFilter), + ice.WithIPFilter(g.api.settingEngine.candidates.IPFilter), + ice.WithRemoteIPFilter(g.api.settingEngine.candidates.RemoteIPFilter), + ice.WithNet(g.api.settingEngine.net), + ice.WithMulticastDNSMode(mDNSMode), + ice.WithTCPMux(g.api.settingEngine.iceTCPMux), + ice.WithUDPMux(g.api.settingEngine.iceUDPMux), + ice.WithProxyDialer(g.api.settingEngine.iceProxyDialer), + ice.WithBindingRequestHandler(g.api.settingEngine.iceBindingRequestHandler), + } +} + +func (g *ICEGatherer) credentialOptions() []ice.AgentOption { + ufrag := g.api.settingEngine.candidates.UsernameFragment + pass := g.api.settingEngine.candidates.Password + if ufrag == "" && pass == "" { + return nil + } + + return []ice.AgentOption{ + ice.WithLocalCredentials(g.api.settingEngine.candidates.UsernameFragment, g.api.settingEngine.candidates.Password), + } +} + +func (g *ICEGatherer) addressRewriteOptions(candidateType ice.CandidateType) ([]ice.AgentOption, error) { + rules := g.api.settingEngine.candidates.addressRewriteRules + nat1To1IPs := g.api.settingEngine.candidates.NAT1To1IPs + if len(rules) > 0 && len(nat1To1IPs) > 0 { + return nil, errAddressRewriteWithNAT1To1 + } + + if len(rules) > 0 { + return []ice.AgentOption{ice.WithAddressRewriteRules(rules...)}, nil + } + + if len(nat1To1IPs) == 0 { + return nil, nil + } + + return []ice.AgentOption{ + ice.WithAddressRewriteRules( + legacyNAT1To1AddressRewriteRules( + nat1To1IPs, + candidateType, + )..., + ), + }, nil +} + +func (g *ICEGatherer) timeoutOptions() []ice.AgentOption { + opts := make([]ice.AgentOption, 0, 8) + + if g.api.settingEngine.timeout.ICEDisconnectedTimeout != nil { + opts = append(opts, ice.WithDisconnectedTimeout(*g.api.settingEngine.timeout.ICEDisconnectedTimeout)) + } + if g.api.settingEngine.timeout.ICEFailedTimeout != nil { + opts = append(opts, ice.WithFailedTimeout(*g.api.settingEngine.timeout.ICEFailedTimeout)) + } + if g.api.settingEngine.timeout.ICEKeepaliveInterval != nil { + opts = append(opts, ice.WithKeepaliveInterval(*g.api.settingEngine.timeout.ICEKeepaliveInterval)) + } + if g.api.settingEngine.timeout.ICEHostAcceptanceMinWait != nil { + opts = append(opts, ice.WithHostAcceptanceMinWait(*g.api.settingEngine.timeout.ICEHostAcceptanceMinWait)) + } + if g.api.settingEngine.timeout.ICESrflxAcceptanceMinWait != nil { + opts = append(opts, ice.WithSrflxAcceptanceMinWait(*g.api.settingEngine.timeout.ICESrflxAcceptanceMinWait)) + } + if g.api.settingEngine.timeout.ICEPrflxAcceptanceMinWait != nil { + opts = append(opts, ice.WithPrflxAcceptanceMinWait(*g.api.settingEngine.timeout.ICEPrflxAcceptanceMinWait)) + } + if g.api.settingEngine.timeout.ICERelayAcceptanceMinWait != nil { + opts = append(opts, ice.WithRelayAcceptanceMinWait(*g.api.settingEngine.timeout.ICERelayAcceptanceMinWait)) + } + if g.api.settingEngine.timeout.ICESTUNGatherTimeout != nil { + opts = append(opts, ice.WithSTUNGatherTimeout(*g.api.settingEngine.timeout.ICESTUNGatherTimeout)) + } + + return opts +} + +func (g *ICEGatherer) miscOptions() []ice.AgentOption { + opts := make([]ice.AgentOption, 0, 4) + + if g.api.settingEngine.candidates.MulticastDNSHostName != "" { + opts = append(opts, ice.WithMulticastDNSHostName(g.api.settingEngine.candidates.MulticastDNSHostName)) + } + + if g.api.settingEngine.candidates.IncludeLoopbackCandidate { + opts = append(opts, ice.WithIncludeLoopback()) + } + + if g.api.settingEngine.iceDisableActiveTCP { + opts = append(opts, ice.WithDisableActiveTCP()) + } + + if g.api.settingEngine.iceMaxBindingRequests != nil { + opts = append(opts, ice.WithMaxBindingRequests(*g.api.settingEngine.iceMaxBindingRequests)) + } + + return opts +} + +func (g *ICEGatherer) renominationOptions() []ice.AgentOption { + renom := g.api.settingEngine.renomination + if !renom.enabled && !renom.automatic { + return nil + } + + generator := renom.generator + opts := []ice.AgentOption{ + ice.WithRenomination(func() uint32 { + return generator() + }), + } + if renom.attributeType != nil { + opts = append(opts, ice.WithNominationAttribute(*renom.attributeType)) + } + + if renom.automatic { + interval := time.Duration(0) + if renom.automaticInterval != nil { + interval = *renom.automaticInterval + } + + opts = append(opts, ice.WithAutomaticRenomination(interval)) + } + + return opts +} + +func legacyNAT1To1AddressRewriteRules(ips []string, candidateType ice.CandidateType) []ice.AddressRewriteRule { + catchAll := make([]string, 0, len(ips)) + rules := make([]ice.AddressRewriteRule, 0, len(ips)+1) + + for _, ip := range ips { + splits := strings.SplitN(ip, "/", 2) + + if len(splits) == 2 { + rules = append(rules, ice.AddressRewriteRule{ + External: []string{splits[0]}, + Local: splits[1], + AsCandidateType: candidateType, + }) + catchAll = append(catchAll, splits[0]) + } else { + catchAll = append(catchAll, ip) + } + } + + if len(catchAll) > 0 { + rules = append(rules, ice.AddressRewriteRule{ + External: catchAll, + AsCandidateType: candidateType, + }) + } + + return rules +} + +// Gather ICE candidates. +func (g *ICEGatherer) Gather() error { //nolint:cyclop + if err := g.createAgent(); err != nil { + return err + } + + agent := g.getAgent() + // it is possible agent had just been closed + if agent == nil { + return fmt.Errorf("%w: unable to gather", errICEAgentNotExist) + } + + g.setState(ICEGathererStateGathering) + if err := agent.OnCandidate(func(candidate ice.Candidate) { + onLocalCandidateHandler := func(*ICECandidate) {} + if handler, ok := g.onLocalCandidateHandler.Load().(func(candidate *ICECandidate)); ok && handler != nil { + onLocalCandidateHandler = handler + } + + onGatheringCompleteHandler := func() {} + if handler, ok := g.onGatheringCompleteHandler.Load().(func()); ok && handler != nil { + onGatheringCompleteHandler = handler + } + + sdpMid := "" + + if mid, ok := g.sdpMid.Load().(string); ok { + sdpMid = mid + } + + sdpMLineIndex := uint16(g.sdpMLineIndex.Load()) //nolint:gosec // G115 + + if candidate != nil { + g.candidatePoolLock.Lock() + if g.iceCandidatePoolSize > 0 && g.candidatePool != nil { + g.candidatePool = append(g.candidatePool, candidate) + g.candidatePoolLock.Unlock() + + return + } + g.candidatePoolLock.Unlock() + + c, err := newICECandidateFromICE(candidate, sdpMid, sdpMLineIndex) + if err != nil { + g.log.Warnf("Failed to convert ice.Candidate: %s", err) + + return + } + onLocalCandidateHandler(&c) + } else { + g.setState(ICEGathererStateComplete) + onGatheringCompleteHandler() + + // If gathering completes before flushing (i.e., before SetLocalDescription), avoid triggering nil. + // Users expect valid candidates to be emitted before the nil completion signal. + g.candidatePoolLock.Lock() + if g.iceCandidatePoolSize > 0 && g.candidatePool != nil { + g.candidatePoolLock.Unlock() + + return + } + g.candidatePoolLock.Unlock() + + onLocalCandidateHandler(nil) + } + }); err != nil { + return err + } + + return agent.GatherCandidates() +} + +// set media stream identification tag and media description index for this gatherer. +func (g *ICEGatherer) setMediaStreamIdentification(mid string, mLineIndex uint16) { + g.sdpMid.Store(mid) + g.sdpMLineIndex.Store(uint32(mLineIndex)) +} + +func (g *ICEGatherer) flushCandidates() { + g.candidatePoolLock.Lock() + + candidates := g.candidatePool + g.candidatePool = nil + g.iceCandidatePoolSize = 0 + + g.candidatePoolLock.Unlock() + + onLocalCandidateHandler := func(*ICECandidate) {} + if handler, ok := g.onLocalCandidateHandler.Load().(func(candidate *ICECandidate)); ok && handler != nil { + onLocalCandidateHandler = handler + } + + sdpMid := "" + if mid, ok := g.sdpMid.Load().(string); ok { + sdpMid = mid + } + + sdpMLineIndex := uint16(g.sdpMLineIndex.Load()) //nolint:gosec // G115 + + currentState := g.State() + + for _, candidate := range candidates { + c, err := newICECandidateFromICE(candidate, sdpMid, sdpMLineIndex) + if err != nil { + g.log.Warnf("Failed to convert pooled ice.Candidate: %s", err) + + continue + } + onLocalCandidateHandler(&c) + } + + // If this is true, gathering completed before flushing, + // so trigger nil to notify the user that all candidates have been gathered. + if currentState == ICEGathererStateComplete { + onLocalCandidateHandler(nil) + } +} + +// Close prunes all local candidates, and closes the ports. +func (g *ICEGatherer) Close() error { + return g.close(false /* shouldGracefullyClose */) +} + +// GracefulClose prunes all local candidates, and closes the ports. It also waits +// for any goroutines it started to complete. This is only safe to call outside of +// ICEGatherer callbacks or if in a callback, in its own goroutine. +func (g *ICEGatherer) GracefulClose() error { + return g.close(true /* shouldGracefullyClose */) +} + +func (g *ICEGatherer) close(shouldGracefullyClose bool) error { + g.lock.Lock() + defer g.lock.Unlock() + + if g.agent == nil { + return nil + } + if shouldGracefullyClose { + if err := g.agent.GracefulClose(); err != nil { + return err + } + } else { + if err := g.agent.Close(); err != nil { + return err + } + } + + // onGatheringCompleteHandler is used solely by the GatheringCompletePromise helper and the common usage + // for that helper is aided by ensuring that this completion is fired in case the PC/ICEGatherer are closed + // before gathering actually completes. If things have already completed then this should be a no-op + if handler, ok := g.onGatheringCompleteHandler.Load().(func()); ok && handler != nil { + handler() + } + + g.agent = nil + g.setState(ICEGathererStateClosed) + + return nil +} + +// GetLocalParameters returns the ICE parameters of the ICEGatherer. +func (g *ICEGatherer) GetLocalParameters() (ICEParameters, error) { + if err := g.createAgent(); err != nil { + return ICEParameters{}, err + } + + agent := g.getAgent() + // it is possible agent had just been closed + if agent == nil { + return ICEParameters{}, fmt.Errorf("%w: unable to get local parameters", errICEAgentNotExist) + } + + frag, pwd, err := agent.GetLocalUserCredentials() + if err != nil { + return ICEParameters{}, err + } + + return ICEParameters{ + UsernameFragment: frag, + Password: pwd, + ICELite: false, + }, nil +} + +// GetLocalCandidates returns the sequence of valid local candidates associated with the ICEGatherer. +func (g *ICEGatherer) GetLocalCandidates() ([]ICECandidate, error) { + if err := g.createAgent(); err != nil { + return nil, err + } + + agent := g.getAgent() + // it is possible agent had just been closed + if agent == nil { + return nil, fmt.Errorf("%w: unable to get local candidates", errICEAgentNotExist) + } + + iceCandidates, err := agent.GetLocalCandidates() + if err != nil { + return nil, err + } + + sdpMid := "" + if mid, ok := g.sdpMid.Load().(string); ok { + sdpMid = mid + } + + sdpMLineIndex := uint16(g.sdpMLineIndex.Load()) //nolint:gosec // G115 + + return newICECandidatesFromICE(iceCandidates, sdpMid, sdpMLineIndex) +} + +// OnLocalCandidate sets an event handler which fires when a new local ICE candidate is available +// Take note that the handler will be called with a nil pointer when gathering is finished. +func (g *ICEGatherer) OnLocalCandidate(f func(*ICECandidate)) { + g.onLocalCandidateHandler.Store(f) +} + +// OnStateChange fires any time the ICEGatherer changes. +func (g *ICEGatherer) OnStateChange(f func(ICEGathererState)) { + g.onStateChangeHandler.Store(f) +} + +// State indicates the current state of the ICE gatherer. +func (g *ICEGatherer) State() ICEGathererState { + return atomicLoadICEGathererState(&g.state) +} + +func (g *ICEGatherer) setState(s ICEGathererState) { + atomicStoreICEGathererState(&g.state, s) + + if handler, ok := g.onStateChangeHandler.Load().(func(state ICEGathererState)); ok && handler != nil { + handler(s) + } +} + +func (g *ICEGatherer) getAgent() *ice.Agent { + g.lock.RLock() + defer g.lock.RUnlock() + + return g.agent +} + +func (g *ICEGatherer) collectStats(collector *statsReportCollector) { + agent := g.getAgent() + if agent == nil { + return + } + + collector.Collecting() + go func(collector *statsReportCollector, agent *ice.Agent) { + for _, candidatePairStats := range agent.GetCandidatePairsStats() { + collector.Collecting() + + stats, err := toICECandidatePairStats(candidatePairStats) + if err != nil { + g.log.Error(err.Error()) + collector.Done() + + continue + } + + collector.Collect(stats.ID, stats) + } + + for _, candidateStats := range agent.GetLocalCandidatesStats() { + collector.Collecting() + + networkType, err := getNetworkType(candidateStats.NetworkType) + if err != nil { + g.log.Error(err.Error()) + } + + candidateType, err := getCandidateType(candidateStats.CandidateType) + if err != nil { + g.log.Error(err.Error()) + } + + stats := ICECandidateStats{ + Timestamp: statsTimestampFrom(candidateStats.Timestamp), + ID: candidateStats.ID, + Type: StatsTypeLocalCandidate, + IP: candidateStats.IP, + Port: int32(candidateStats.Port), //nolint:gosec // G115, no overflow, port + Protocol: networkType.Protocol(), + CandidateType: candidateType, + Priority: int32(candidateStats.Priority), //nolint:gosec + URL: candidateStats.URL, + RelayProtocol: candidateStats.RelayProtocol, + Deleted: candidateStats.Deleted, + } + collector.Collect(stats.ID, stats) + } + + for _, candidateStats := range agent.GetRemoteCandidatesStats() { + collector.Collecting() + networkType, err := getNetworkType(candidateStats.NetworkType) + if err != nil { + g.log.Error(err.Error()) + } + + candidateType, err := getCandidateType(candidateStats.CandidateType) + if err != nil { + g.log.Error(err.Error()) + } + + stats := ICECandidateStats{ + Timestamp: statsTimestampFrom(candidateStats.Timestamp), + ID: candidateStats.ID, + Type: StatsTypeRemoteCandidate, + IP: candidateStats.IP, + Port: int32(candidateStats.Port), //nolint:gosec // G115, no overflow, port + Protocol: networkType.Protocol(), + CandidateType: candidateType, + Priority: int32(candidateStats.Priority), //nolint:gosec // G115 + URL: candidateStats.URL, + RelayProtocol: candidateStats.RelayProtocol, + } + collector.Collect(stats.ID, stats) + } + collector.Done() + }(collector, agent) +} + +func (g *ICEGatherer) getSelectedCandidatePairStats() (ICECandidatePairStats, bool) { + agent := g.getAgent() + if agent == nil { + return ICECandidatePairStats{}, false + } + + selectedCandidatePairStats, isAvailable := agent.GetSelectedCandidatePairStats() + if !isAvailable { + return ICECandidatePairStats{}, false + } + + stats, err := toICECandidatePairStats(selectedCandidatePairStats) + if err != nil { + g.log.Error(err.Error()) + + return ICECandidatePairStats{}, false + } + + return stats, true +} diff --git a/vendor/github.com/pion/webrtc/v4/icegathererstate.go b/vendor/github.com/pion/webrtc/v4/icegathererstate.go new file mode 100644 index 0000000..357a12b --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/icegathererstate.go @@ -0,0 +1,54 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package webrtc + +import ( + "sync/atomic" +) + +// ICEGathererState represents the current state of the ICE gatherer. +type ICEGathererState uint32 + +const ( + // ICEGathererStateUnknown is the enum's zero-value. + ICEGathererStateUnknown ICEGathererState = iota + + // ICEGathererStateNew indicates object has been created but + // gather() has not been called. + ICEGathererStateNew + + // ICEGathererStateGathering indicates gather() has been called, + // and the ICEGatherer is in the process of gathering candidates. + ICEGathererStateGathering + + // ICEGathererStateComplete indicates the ICEGatherer has completed gathering. + ICEGathererStateComplete + + // ICEGathererStateClosed indicates the closed state can only be entered + // when the ICEGatherer has been closed intentionally by calling close(). + ICEGathererStateClosed +) + +func (s ICEGathererState) String() string { + switch s { + case ICEGathererStateNew: + return "new" + case ICEGathererStateGathering: + return "gathering" + case ICEGathererStateComplete: + return "complete" + case ICEGathererStateClosed: + return "closed" + default: + return ErrUnknownType.Error() + } +} + +func atomicStoreICEGathererState(state *ICEGathererState, newState ICEGathererState) { + atomic.StoreUint32((*uint32)(state), uint32(newState)) +} + +func atomicLoadICEGathererState(state *ICEGathererState) ICEGathererState { + return ICEGathererState(atomic.LoadUint32((*uint32)(state))) +} diff --git a/vendor/github.com/pion/webrtc/v4/icegatheringstate.go b/vendor/github.com/pion/webrtc/v4/icegatheringstate.go new file mode 100644 index 0000000..bfb9166 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/icegatheringstate.go @@ -0,0 +1,59 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package webrtc + +// ICEGatheringState describes the state of the candidate gathering process. +type ICEGatheringState int + +const ( + // ICEGatheringStateUnknown is the enum's zero-value. + ICEGatheringStateUnknown ICEGatheringState = iota + + // ICEGatheringStateNew indicates that any of the ICETransports are + // in the "new" gathering state and none of the transports are in the + // "gathering" state, or there are no transports. + ICEGatheringStateNew + + // ICEGatheringStateGathering indicates that any of the ICETransports + // are in the "gathering" state. + ICEGatheringStateGathering + + // ICEGatheringStateComplete indicates that at least one ICETransport + // exists, and all ICETransports are in the "completed" gathering state. + ICEGatheringStateComplete +) + +// This is done this way because of a linter. +const ( + iceGatheringStateNewStr = "new" + iceGatheringStateGatheringStr = "gathering" + iceGatheringStateCompleteStr = "complete" +) + +// NewICEGatheringState takes a string and converts it to ICEGatheringState. +func NewICEGatheringState(raw string) ICEGatheringState { + switch raw { + case iceGatheringStateNewStr: + return ICEGatheringStateNew + case iceGatheringStateGatheringStr: + return ICEGatheringStateGathering + case iceGatheringStateCompleteStr: + return ICEGatheringStateComplete + default: + return ICEGatheringStateUnknown + } +} + +func (t ICEGatheringState) String() string { + switch t { + case ICEGatheringStateNew: + return iceGatheringStateNewStr + case ICEGatheringStateGathering: + return iceGatheringStateGatheringStr + case ICEGatheringStateComplete: + return iceGatheringStateCompleteStr + default: + return ErrUnknownType.Error() + } +} diff --git a/vendor/github.com/pion/webrtc/v4/icegatheroptions.go b/vendor/github.com/pion/webrtc/v4/icegatheroptions.go new file mode 100644 index 0000000..63a5fba --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/icegatheroptions.go @@ -0,0 +1,11 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package webrtc + +// ICEGatherOptions provides options relating to the gathering of ICE candidates. +type ICEGatherOptions struct { + ICEServers []ICEServer + ICEGatherPolicy ICETransportPolicy + ICECandidatePoolSize uint8 +} diff --git a/vendor/github.com/pion/webrtc/v4/icemux.go b/vendor/github.com/pion/webrtc/v4/icemux.go new file mode 100644 index 0000000..cc7f89f --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/icemux.go @@ -0,0 +1,30 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package webrtc + +import ( + "net" + + "github.com/pion/ice/v4" + "github.com/pion/logging" +) + +// NewICETCPMux creates a new instance of ice.TCPMuxDefault. It enables use of +// passive ICE TCP candidates. +func NewICETCPMux(logger logging.LeveledLogger, listener net.Listener, readBufferSize int) ice.TCPMux { + return ice.NewTCPMuxDefault(ice.TCPMuxParams{ + Listener: listener, + Logger: logger, + ReadBufferSize: readBufferSize, + }) +} + +// NewICEUDPMux creates a new instance of ice.UDPMuxDefault. It allows many PeerConnections to be served +// by a single UDP Port. +func NewICEUDPMux(logger logging.LeveledLogger, udpConn net.PacketConn) ice.UDPMux { + return ice.NewUDPMuxDefault(ice.UDPMuxParams{ + UDPConn: udpConn, + Logger: logger, + }) +} diff --git a/vendor/github.com/pion/webrtc/v4/iceparameters.go b/vendor/github.com/pion/webrtc/v4/iceparameters.go new file mode 100644 index 0000000..c8bd3c9 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/iceparameters.go @@ -0,0 +1,12 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package webrtc + +// ICEParameters includes the ICE username fragment +// and password and other ICE-related parameters. +type ICEParameters struct { + UsernameFragment string `json:"usernameFragment"` + Password string `json:"password"` //nolint:gosec // not a secret. + ICELite bool `json:"iceLite"` +} diff --git a/vendor/github.com/pion/webrtc/v4/iceprotocol.go b/vendor/github.com/pion/webrtc/v4/iceprotocol.go new file mode 100644 index 0000000..64fbdc5 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/iceprotocol.go @@ -0,0 +1,53 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package webrtc + +import ( + "fmt" + "strings" +) + +// ICEProtocol indicates the transport protocol type that is used in the +// ice.URL structure. +type ICEProtocol int + +const ( + // ICEProtocolUnknown is the enum's zero-value. + ICEProtocolUnknown ICEProtocol = iota + + // ICEProtocolUDP indicates the URL uses a UDP transport. + ICEProtocolUDP + + // ICEProtocolTCP indicates the URL uses a TCP transport. + ICEProtocolTCP +) + +// This is done this way because of a linter. +const ( + iceProtocolUDPStr = "udp" + iceProtocolTCPStr = "tcp" +) + +// NewICEProtocol takes a string and converts it to ICEProtocol. +func NewICEProtocol(raw string) (ICEProtocol, error) { + switch { + case strings.EqualFold(iceProtocolUDPStr, raw): + return ICEProtocolUDP, nil + case strings.EqualFold(iceProtocolTCPStr, raw): + return ICEProtocolTCP, nil + default: + return ICEProtocolUnknown, fmt.Errorf("%w: %s", errICEProtocolUnknown, raw) + } +} + +func (t ICEProtocol) String() string { + switch t { + case ICEProtocolUDP: + return iceProtocolUDPStr + case ICEProtocolTCP: + return iceProtocolTCPStr + default: + return ErrUnknownType.Error() + } +} diff --git a/vendor/github.com/pion/webrtc/v4/icerole.go b/vendor/github.com/pion/webrtc/v4/icerole.go new file mode 100644 index 0000000..128c68b --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/icerole.go @@ -0,0 +1,63 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package webrtc + +// ICERole describes the role ice.Agent is playing in selecting the +// preferred the candidate pair. +type ICERole int + +const ( + // ICERoleUnknown is the enum's zero-value. + ICERoleUnknown ICERole = iota + + // ICERoleControlling indicates that the ICE agent that is responsible + // for selecting the final choice of candidate pairs and signaling them + // through STUN and an updated offer, if needed. In any session, one agent + // is always controlling. The other is the controlled agent. + ICERoleControlling + + // ICERoleControlled indicates that an ICE agent that waits for the + // controlling agent to select the final choice of candidate pairs. + ICERoleControlled +) + +// This is done this way because of a linter. +const ( + iceRoleControllingStr = "controlling" + iceRoleControlledStr = "controlled" +) + +func newICERole(raw string) ICERole { + switch raw { + case iceRoleControllingStr: + return ICERoleControlling + case iceRoleControlledStr: + return ICERoleControlled + default: + return ICERoleUnknown + } +} + +func (t ICERole) String() string { + switch t { + case ICERoleControlling: + return iceRoleControllingStr + case ICERoleControlled: + return iceRoleControlledStr + default: + return ErrUnknownType.Error() + } +} + +// MarshalText implements encoding.TextMarshaler. +func (t ICERole) MarshalText() ([]byte, error) { + return []byte(t.String()), nil +} + +// UnmarshalText implements encoding.TextUnmarshaler. +func (t *ICERole) UnmarshalText(b []byte) error { + *t = newICERole(string(b)) + + return nil +} diff --git a/vendor/github.com/pion/webrtc/v4/iceserver.go b/vendor/github.com/pion/webrtc/v4/iceserver.go new file mode 100644 index 0000000..4f63d39 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/iceserver.go @@ -0,0 +1,187 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +//go:build !js + +package webrtc + +import ( + "encoding/json" + + "github.com/pion/stun/v3" + "github.com/pion/webrtc/v4/pkg/rtcerr" +) + +// ICEServer describes a single STUN and TURN server that can be used by +// the ICEAgent to establish a connection with a peer. +type ICEServer struct { + URLs []string `json:"urls"` + Username string `json:"username,omitempty"` + Credential any `json:"credential,omitempty"` + CredentialType ICECredentialType `json:"credentialType,omitempty"` +} + +func (s ICEServer) parseURL(i int) (*stun.URI, error) { + return stun.ParseURI(s.URLs[i]) +} + +func (s ICEServer) validate() error { + _, err := s.urls() + + return err +} + +func (s ICEServer) urls() ([]*stun.URI, error) { //nolint:cyclop + urls := []*stun.URI{} + + for i := range s.URLs { + url, err := s.parseURL(i) + if err != nil { + return nil, &rtcerr.InvalidAccessError{Err: err} + } + + if url.Scheme == stun.SchemeTypeTURN || url.Scheme == stun.SchemeTypeTURNS { + // https://www.w3.org/TR/webrtc/#set-the-configuration (step #11.3.2) + if s.Username == "" || s.Credential == nil { + return nil, &rtcerr.InvalidAccessError{Err: ErrNoTurnCredentials} + } + url.Username = s.Username + + switch s.CredentialType { + case ICECredentialTypePassword: + // https://www.w3.org/TR/webrtc/#set-the-configuration (step #11.3.3) + password, ok := s.Credential.(string) + if !ok { + return nil, &rtcerr.InvalidAccessError{Err: ErrTurnCredentials} + } + url.Password = password + + case ICECredentialTypeOauth: + // https://www.w3.org/TR/webrtc/#set-the-configuration (step #11.3.4) + if _, ok := s.Credential.(OAuthCredential); !ok { + return nil, &rtcerr.InvalidAccessError{Err: ErrTurnCredentials} + } + + default: + return nil, &rtcerr.InvalidAccessError{Err: ErrTurnCredentials} + } + } + + urls = append(urls, url) + } + + return urls, nil +} + +func iceserverUnmarshalUrls(val any) (*[]string, error) { + s, ok := val.([]any) + if !ok { + return nil, errInvalidICEServer + } + out := make([]string, len(s)) + for idx, url := range s { + out[idx], ok = url.(string) + if !ok { + return nil, errInvalidICEServer + } + } + + return &out, nil +} + +func iceserverUnmarshalOauth(val any) (*OAuthCredential, error) { + c, ok := val.(map[string]any) + if !ok { + return nil, errInvalidICEServer + } + MACKey, ok := c["MACKey"].(string) + if !ok { + return nil, errInvalidICEServer + } + AccessToken, ok := c["AccessToken"].(string) + if !ok { + return nil, errInvalidICEServer + } + + return &OAuthCredential{ + MACKey: MACKey, + AccessToken: AccessToken, + }, nil +} + +func (s *ICEServer) iceserverUnmarshalFields(fields map[string]any) error { //nolint:cyclop + if val, ok := fields["urls"]; ok { + u, err := iceserverUnmarshalUrls(val) + if err != nil { + return err + } + s.URLs = *u + } else { + s.URLs = []string{} + } + + if val, ok := fields["username"]; ok { + s.Username, ok = val.(string) + if !ok { + return errInvalidICEServer + } + } + if val, ok := fields["credentialType"]; ok { + ct, ok := val.(string) + if !ok { + return errInvalidICEServer + } + tpe, err := newICECredentialType(ct) + if err != nil { + return err + } + s.CredentialType = tpe + } else { + s.CredentialType = ICECredentialTypePassword + } + if val, ok := fields["credential"]; ok { + switch s.CredentialType { + case ICECredentialTypePassword: + s.Credential = val + case ICECredentialTypeOauth: + c, err := iceserverUnmarshalOauth(val) + if err != nil { + return err + } + s.Credential = *c + default: + return errInvalidICECredentialTypeString + } + } + + return nil +} + +// UnmarshalJSON parses the JSON-encoded data and stores the result. +func (s *ICEServer) UnmarshalJSON(b []byte) error { + var tmp any + err := json.Unmarshal(b, &tmp) + if err != nil { + return err + } + if m, ok := tmp.(map[string]any); ok { + return s.iceserverUnmarshalFields(m) + } + + return errInvalidICEServer +} + +// MarshalJSON returns the JSON encoding. +func (s ICEServer) MarshalJSON() ([]byte, error) { + m := make(map[string]any) + m["urls"] = s.URLs + if s.Username != "" { + m["username"] = s.Username + } + if s.Credential != nil { + m["credential"] = s.Credential + } + m["credentialType"] = s.CredentialType + + return json.Marshal(m) +} diff --git a/vendor/github.com/pion/webrtc/v4/iceserver_js.go b/vendor/github.com/pion/webrtc/v4/iceserver_js.go new file mode 100644 index 0000000..e379ab7 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/iceserver_js.go @@ -0,0 +1,46 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +//go:build js && wasm +// +build js,wasm + +package webrtc + +import ( + "errors" + + "github.com/pion/ice/v4" +) + +// ICEServer describes a single STUN and TURN server that can be used by +// the ICEAgent to establish a connection with a peer. +type ICEServer struct { + URLs []string + Username string + // Note: TURN is not supported in the WASM bindings yet + Credential any + CredentialType ICECredentialType +} + +func (s ICEServer) parseURL(i int) (*ice.URL, error) { + return ice.ParseURL(s.URLs[i]) +} + +func (s ICEServer) validate() ([]*ice.URL, error) { + urls := []*ice.URL{} + + for i := range s.URLs { + url, err := s.parseURL(i) + if err != nil { + return nil, err + } + + if url.Scheme == ice.SchemeTypeTURN || url.Scheme == ice.SchemeTypeTURNS { + return nil, errors.New("TURN is not currently supported in the JavaScript/Wasm bindings") + } + + urls = append(urls, url) + } + + return urls, nil +} diff --git a/vendor/github.com/pion/webrtc/v4/icetransport.go b/vendor/github.com/pion/webrtc/v4/icetransport.go new file mode 100644 index 0000000..0e6d864 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/icetransport.go @@ -0,0 +1,457 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +//go:build !js + +package webrtc + +import ( + "context" + "fmt" + "sync" + "sync/atomic" + "time" + + "github.com/pion/ice/v4" + "github.com/pion/logging" + "github.com/pion/webrtc/v4/internal/mux" + "github.com/pion/webrtc/v4/internal/util" +) + +// ICETransport allows an application access to information about the ICE +// transport over which packets are sent and received. +type ICETransport struct { + lock sync.RWMutex + + role ICERole + + onConnectionStateChangeHandler atomic.Value // func(ICETransportState) + internalOnConnectionStateChangeHandler atomic.Value // func(ICETransportState) + onSelectedCandidatePairChangeHandler atomic.Value // func(*ICECandidatePair) + + state atomic.Value // ICETransportState + + gatherer *ICEGatherer + conn *ice.Conn + mux *mux.Mux + + ctxCancel func() + + loggerFactory logging.LoggerFactory + + log logging.LeveledLogger +} + +// GetSelectedCandidatePair returns the selected candidate pair on which packets are sent +// if there is no selected pair nil is returned. +func (t *ICETransport) GetSelectedCandidatePair() (*ICECandidatePair, error) { + agent := t.gatherer.getAgent() + if agent == nil { + return nil, nil //nolint:nilnil + } + + icePair, err := agent.GetSelectedCandidatePair() + if icePair == nil || err != nil { + return nil, err + } + + local, err := newICECandidateFromICE(icePair.Local, "", 0) + if err != nil { + return nil, err + } + + remote, err := newICECandidateFromICE(icePair.Remote, "", 0) + if err != nil { + return nil, err + } + + return NewICECandidatePair(&local, &remote), nil +} + +// GetSelectedCandidatePairStats returns the selected candidate pair stats on which packets are sent +// if there is no selected pair empty stats, false is returned to indicate stats not available. +func (t *ICETransport) GetSelectedCandidatePairStats() (ICECandidatePairStats, bool) { + return t.gatherer.getSelectedCandidatePairStats() +} + +// NewICETransport creates a new NewICETransport. +func NewICETransport(gatherer *ICEGatherer, loggerFactory logging.LoggerFactory) *ICETransport { + iceTransport := &ICETransport{ + gatherer: gatherer, + loggerFactory: loggerFactory, + log: loggerFactory.NewLogger("ortc"), + } + iceTransport.setState(ICETransportStateNew) + + return iceTransport +} + +// Start incoming connectivity checks based on its configured role. +func (t *ICETransport) Start(gatherer *ICEGatherer, params ICEParameters, role *ICERole) error { //nolint:cyclop + t.lock.Lock() + defer t.lock.Unlock() + + if t.State() != ICETransportStateNew { + return errICETransportNotInNew + } + + if gatherer != nil { + t.gatherer = gatherer + } + + if err := t.ensureGatherer(); err != nil { + return err + } + + agent := t.gatherer.getAgent() + if agent == nil { + return fmt.Errorf("%w: unable to start ICETransport", errICEAgentNotExist) + } + + if err := agent.OnConnectionStateChange(func(iceState ice.ConnectionState) { + state := newICETransportStateFromICE(iceState) + + t.setState(state) + t.onConnectionStateChange(state) + }); err != nil { + return err + } + if err := agent.OnSelectedCandidatePairChange(func(local, remote ice.Candidate) { + candidates, err := newICECandidatesFromICE([]ice.Candidate{local, remote}, "", 0) + if err != nil { + t.log.Warnf("%w: %s", errICECandiatesCoversionFailed, err) + + return + } + t.onSelectedCandidatePairChange(NewICECandidatePair(&candidates[0], &candidates[1])) + }); err != nil { + return err + } + + if role == nil { + controlled := ICERoleControlled + role = &controlled + } + t.role = *role + + ctx, ctxCancel := context.WithCancel(context.Background()) + t.ctxCancel = ctxCancel + + // Drop the lock here to allow ICE candidates to be + // added so that the agent can complete a connection + t.lock.Unlock() + + var iceConn *ice.Conn + var err error + switch *role { + case ICERoleControlling: + iceConn, err = agent.Dial(ctx, + params.UsernameFragment, + params.Password) + + case ICERoleControlled: + iceConn, err = agent.Accept(ctx, + params.UsernameFragment, + params.Password) + + default: + err = errICERoleUnknown + } + + // Reacquire the lock to set the connection/mux + t.lock.Lock() + if err != nil { + return err + } + + if t.State() == ICETransportStateClosed { + return errICETransportClosed + } + + t.conn = iceConn + + config := mux.Config{ + Conn: t.conn, + BufferSize: int(t.gatherer.api.settingEngine.getReceiveMTU()), //nolint:gosec // G115 + LoggerFactory: t.loggerFactory, + } + t.mux = mux.NewMux(config) + + return nil +} + +// restart is not exposed currently because ORTC has users create a whole new ICETransport +// so for now lets keep it private so we don't cause ORTC users to depend on non-standard APIs. +func (t *ICETransport) restart() error { + t.lock.Lock() + defer t.lock.Unlock() + + agent := t.gatherer.getAgent() + if agent == nil { + return fmt.Errorf("%w: unable to restart ICETransport", errICEAgentNotExist) + } + + if err := agent.Restart( + t.gatherer.api.settingEngine.candidates.UsernameFragment, + t.gatherer.api.settingEngine.candidates.Password, + ); err != nil { + return err + } + + return t.gatherer.Gather() +} + +// Stop irreversibly stops the ICETransport. +func (t *ICETransport) Stop() error { + return t.stop(false /* shouldGracefullyClose */) +} + +// GracefulStop irreversibly stops the ICETransport. It also waits +// for any goroutines it started to complete. This is only safe to call outside of +// ICETransport callbacks or if in a callback, in its own goroutine. +func (t *ICETransport) GracefulStop() error { + return t.stop(true /* shouldGracefullyClose */) +} + +func (t *ICETransport) stop(shouldGracefullyClose bool) error { + t.lock.Lock() + t.setState(ICETransportStateClosed) + + if t.ctxCancel != nil { + t.ctxCancel() + } + + // mux and gatherer can only be set when ICETransport.State != Closed. + mux := t.mux + gatherer := t.gatherer + t.lock.Unlock() + + if mux != nil { + var closeErrs []error + if shouldGracefullyClose && gatherer != nil { + // we can't access icegatherer/icetransport.Close via + // mux's net.Conn Close so we call it earlier here. + closeErrs = append(closeErrs, gatherer.GracefulClose()) + } + closeErrs = append(closeErrs, mux.Close()) + + return util.FlattenErrs(closeErrs) + } else if gatherer != nil { + if shouldGracefullyClose { + return gatherer.GracefulClose() + } + + return gatherer.Close() + } + + return nil +} + +// OnSelectedCandidatePairChange sets a handler that is invoked when a new +// ICE candidate pair is selected. +func (t *ICETransport) OnSelectedCandidatePairChange(f func(*ICECandidatePair)) { + t.onSelectedCandidatePairChangeHandler.Store(f) +} + +func (t *ICETransport) onSelectedCandidatePairChange(pair *ICECandidatePair) { + if handler, ok := t.onSelectedCandidatePairChangeHandler.Load().(func(*ICECandidatePair)); ok { + handler(pair) + } +} + +// OnConnectionStateChange sets a handler that is fired when the ICE +// connection state changes. +func (t *ICETransport) OnConnectionStateChange(f func(ICETransportState)) { + t.onConnectionStateChangeHandler.Store(f) +} + +func (t *ICETransport) onConnectionStateChange(state ICETransportState) { + if handler, ok := t.onConnectionStateChangeHandler.Load().(func(ICETransportState)); ok { + handler(state) + } + if handler, ok := t.internalOnConnectionStateChangeHandler.Load().(func(ICETransportState)); ok { + handler(state) + } +} + +// Role indicates the current role of the ICE transport. +func (t *ICETransport) Role() ICERole { + t.lock.RLock() + defer t.lock.RUnlock() + + return t.role +} + +// SetRemoteCandidates sets the sequence of candidates associated with the remote ICETransport. +func (t *ICETransport) SetRemoteCandidates(remoteCandidates []ICECandidate) error { + t.lock.RLock() + defer t.lock.RUnlock() + + if err := t.ensureGatherer(); err != nil { + return err + } + + agent := t.gatherer.getAgent() + if agent == nil { + return fmt.Errorf("%w: unable to set remote candidates", errICEAgentNotExist) + } + + for _, c := range remoteCandidates { + i, err := c.ToICE() + if err != nil { + return err + } + + if err = agent.AddRemoteCandidate(i); err != nil { + return err + } + } + + return nil +} + +// AddRemoteCandidate adds a candidate associated with the remote ICETransport. +func (t *ICETransport) AddRemoteCandidate(remoteCandidate *ICECandidate) error { + t.lock.RLock() + defer t.lock.RUnlock() + + var ( + candidate ice.Candidate + err error + ) + + if err = t.ensureGatherer(); err != nil { + return err + } + + if remoteCandidate != nil { + if candidate, err = remoteCandidate.ToICE(); err != nil { + return err + } + } + + agent := t.gatherer.getAgent() + if agent == nil { + return fmt.Errorf("%w: unable to add remote candidates", errICEAgentNotExist) + } + + return agent.AddRemoteCandidate(candidate) +} + +// State returns the current ice transport state. +func (t *ICETransport) State() ICETransportState { + if v, ok := t.state.Load().(ICETransportState); ok { + return v + } + + return ICETransportState(0) +} + +// GetLocalParameters returns an IceParameters object which provides information +// uniquely identifying the local peer for the duration of the ICE session. +func (t *ICETransport) GetLocalParameters() (ICEParameters, error) { + if err := t.ensureGatherer(); err != nil { + return ICEParameters{}, err + } + + return t.gatherer.GetLocalParameters() +} + +// GetRemoteParameters returns an IceParameters object which provides information +// uniquely identifying the remote peer for the duration of the ICE session. +func (t *ICETransport) GetRemoteParameters() (ICEParameters, error) { + t.lock.Lock() + defer t.lock.Unlock() + + agent := t.gatherer.getAgent() + if agent == nil { + return ICEParameters{}, fmt.Errorf("%w: unable to get remote parameters", errICEAgentNotExist) + } + + uFrag, uPwd, err := agent.GetRemoteUserCredentials() + if err != nil { + return ICEParameters{}, fmt.Errorf("%w: unable to get remote parameters", err) + } + + return ICEParameters{ + UsernameFragment: uFrag, + Password: uPwd, + }, nil +} + +func (t *ICETransport) setState(i ICETransportState) { + t.state.Store(i) +} + +func (t *ICETransport) newEndpoint(f mux.MatchFunc) *mux.Endpoint { + t.lock.Lock() + defer t.lock.Unlock() + + return t.mux.NewEndpoint(f) +} + +func (t *ICETransport) ensureGatherer() error { + if t.gatherer == nil { + return errICEGathererNotStarted + } else if t.gatherer.getAgent() == nil { + if err := t.gatherer.createAgent(); err != nil { + return err + } + } + + return nil +} + +// Stats reports the current statistics of the ICETransport. +func (t *ICETransport) Stats() TransportStats { + t.lock.RLock() + conn := t.conn + t.lock.RUnlock() + + stats := TransportStats{ + Timestamp: statsTimestampFrom(time.Now()), + Type: StatsTypeTransport, + ID: "iceTransport", + } + if conn != nil { + stats.BytesSent = conn.BytesSent() + stats.BytesReceived = conn.BytesReceived() + } + + return stats +} + +func (t *ICETransport) collectStats(collector *statsReportCollector) { + collector.Collecting() + stats := t.Stats() + collector.Collect(stats.ID, stats) +} + +func (t *ICETransport) haveRemoteCredentialsChange(newUfrag, newPwd string) bool { + t.lock.Lock() + defer t.lock.Unlock() + + agent := t.gatherer.getAgent() + if agent == nil { + return false + } + + uFrag, uPwd, err := agent.GetRemoteUserCredentials() + if err != nil { + return false + } + + return uFrag != newUfrag || uPwd != newPwd +} + +func (t *ICETransport) setRemoteCredentials(newUfrag, newPwd string) error { + t.lock.Lock() + defer t.lock.Unlock() + + agent := t.gatherer.getAgent() + if agent == nil { + return fmt.Errorf("%w: unable to SetRemoteCredentials", errICEAgentNotExist) + } + + return agent.SetRemoteCredentials(newUfrag, newPwd) +} diff --git a/vendor/github.com/pion/webrtc/v4/icetransport_js.go b/vendor/github.com/pion/webrtc/v4/icetransport_js.go new file mode 100644 index 0000000..dad0a07 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/icetransport_js.go @@ -0,0 +1,35 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +//go:build js && wasm +// +build js,wasm + +package webrtc + +import "syscall/js" + +// ICETransport allows an application access to information about the ICE +// transport over which packets are sent and received. +type ICETransport struct { + // Pointer to the underlying JavaScript ICETransport object. + underlying js.Value +} + +// JSValue returns the underlying RTCIceTransport +func (t *ICETransport) JSValue() js.Value { + return t.underlying +} + +// GetSelectedCandidatePair returns the selected candidate pair on which packets are sent +// if there is no selected pair nil is returned +func (t *ICETransport) GetSelectedCandidatePair() (*ICECandidatePair, error) { + val := t.underlying.Call("getSelectedCandidatePair") + if val.IsNull() || val.IsUndefined() { + return nil, nil + } + + return NewICECandidatePair( + valueToICECandidate(val.Get("local")), + valueToICECandidate(val.Get("remote")), + ), nil +} diff --git a/vendor/github.com/pion/webrtc/v4/icetransportpolicy.go b/vendor/github.com/pion/webrtc/v4/icetransportpolicy.go new file mode 100644 index 0000000..1f42956 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/icetransportpolicy.go @@ -0,0 +1,75 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package webrtc + +import ( + "encoding/json" +) + +// ICETransportPolicy defines the ICE candidate policy surface the +// permitted candidates. Only these candidates are used for connectivity checks. +type ICETransportPolicy int + +// ICEGatherPolicy is the ORTC equivalent of ICETransportPolicy. +type ICEGatherPolicy = ICETransportPolicy + +const ( + // ICETransportPolicyAll indicates any type of candidate is used. + ICETransportPolicyAll ICETransportPolicy = iota + + // ICETransportPolicyRelay indicates only media relay candidates such + // as candidates passing through a TURN server are used. + ICETransportPolicyRelay + + // ICETransportPolicyNoHost indicates only non-host candidates are used. + ICETransportPolicyNoHost +) + +// This is done this way because of a linter. +const ( + iceTransportPolicyRelayStr = "relay" + iceTransportPolicyNoHostStr = "nohost" + iceTransportPolicyAllStr = "all" +) + +// NewICETransportPolicy takes a string and converts it to ICETransportPolicy. +func NewICETransportPolicy(raw string) ICETransportPolicy { + switch raw { + case iceTransportPolicyNoHostStr: + return ICETransportPolicyNoHost + case iceTransportPolicyRelayStr: + return ICETransportPolicyRelay + default: + return ICETransportPolicyAll + } +} + +func (t ICETransportPolicy) String() string { + switch t { + case ICETransportPolicyNoHost: + return iceTransportPolicyNoHostStr + case ICETransportPolicyRelay: + return iceTransportPolicyRelayStr + case ICETransportPolicyAll: + return iceTransportPolicyAllStr + default: + return ErrUnknownType.Error() + } +} + +// UnmarshalJSON parses the JSON-encoded data and stores the result. +func (t *ICETransportPolicy) UnmarshalJSON(b []byte) error { + var val string + if err := json.Unmarshal(b, &val); err != nil { + return err + } + *t = NewICETransportPolicy(val) + + return nil +} + +// MarshalJSON returns the JSON encoding. +func (t ICETransportPolicy) MarshalJSON() ([]byte, error) { + return json.Marshal(t.String()) +} diff --git a/vendor/github.com/pion/webrtc/v4/icetransportstate.go b/vendor/github.com/pion/webrtc/v4/icetransportstate.go new file mode 100644 index 0000000..ed892ec --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/icetransportstate.go @@ -0,0 +1,156 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package webrtc + +import "github.com/pion/ice/v4" + +// ICETransportState represents the current state of the ICE transport. +type ICETransportState int + +const ( + // ICETransportStateUnknown is the enum's zero-value. + ICETransportStateUnknown ICETransportState = iota + + // ICETransportStateNew indicates the ICETransport is waiting + // for remote candidates to be supplied. + ICETransportStateNew + + // ICETransportStateChecking indicates the ICETransport has + // received at least one remote candidate, and a local and remote + // ICECandidateComplete dictionary was not added as the last candidate. + ICETransportStateChecking + + // ICETransportStateConnected indicates the ICETransport has + // received a response to an outgoing connectivity check, or has + // received incoming DTLS/media after a successful response to an + // incoming connectivity check, but is still checking other candidate + // pairs to see if there is a better connection. + ICETransportStateConnected + + // ICETransportStateCompleted indicates the ICETransport tested + // all appropriate candidate pairs and at least one functioning + // candidate pair has been found. + ICETransportStateCompleted + + // ICETransportStateFailed indicates the ICETransport the last + // candidate was added and all appropriate candidate pairs have either + // failed connectivity checks or have lost consent. + ICETransportStateFailed + + // ICETransportStateDisconnected indicates the ICETransport has received + // at least one local and remote candidate, but the final candidate was + // received yet and all appropriate candidate pairs thus far have been + // tested and failed. + ICETransportStateDisconnected + + // ICETransportStateClosed indicates the ICETransport has shut down + // and is no longer responding to STUN requests. + ICETransportStateClosed +) + +const ( + iceTransportStateNewStr = "new" + iceTransportStateCheckingStr = "checking" + iceTransportStateConnectedStr = "connected" + iceTransportStateCompletedStr = "completed" + iceTransportStateFailedStr = "failed" + iceTransportStateDisconnectedStr = "disconnected" + iceTransportStateClosedStr = "closed" +) + +func newICETransportState(raw string) ICETransportState { + switch raw { + case iceTransportStateNewStr: + return ICETransportStateNew + case iceTransportStateCheckingStr: + return ICETransportStateChecking + case iceTransportStateConnectedStr: + return ICETransportStateConnected + case iceTransportStateCompletedStr: + return ICETransportStateCompleted + case iceTransportStateFailedStr: + return ICETransportStateFailed + case iceTransportStateDisconnectedStr: + return ICETransportStateDisconnected + case iceTransportStateClosedStr: + return ICETransportStateClosed + default: + return ICETransportStateUnknown + } +} + +func (c ICETransportState) String() string { + switch c { + case ICETransportStateNew: + return iceTransportStateNewStr + case ICETransportStateChecking: + return iceTransportStateCheckingStr + case ICETransportStateConnected: + return iceTransportStateConnectedStr + case ICETransportStateCompleted: + return iceTransportStateCompletedStr + case ICETransportStateFailed: + return iceTransportStateFailedStr + case ICETransportStateDisconnected: + return iceTransportStateDisconnectedStr + case ICETransportStateClosed: + return iceTransportStateClosedStr + default: + return ErrUnknownType.Error() + } +} + +func newICETransportStateFromICE(i ice.ConnectionState) ICETransportState { + switch i { + case ice.ConnectionStateNew: + return ICETransportStateNew + case ice.ConnectionStateChecking: + return ICETransportStateChecking + case ice.ConnectionStateConnected: + return ICETransportStateConnected + case ice.ConnectionStateCompleted: + return ICETransportStateCompleted + case ice.ConnectionStateFailed: + return ICETransportStateFailed + case ice.ConnectionStateDisconnected: + return ICETransportStateDisconnected + case ice.ConnectionStateClosed: + return ICETransportStateClosed + default: + return ICETransportStateUnknown + } +} + +func (c ICETransportState) toICE() ice.ConnectionState { + switch c { + case ICETransportStateNew: + return ice.ConnectionStateNew + case ICETransportStateChecking: + return ice.ConnectionStateChecking + case ICETransportStateConnected: + return ice.ConnectionStateConnected + case ICETransportStateCompleted: + return ice.ConnectionStateCompleted + case ICETransportStateFailed: + return ice.ConnectionStateFailed + case ICETransportStateDisconnected: + return ice.ConnectionStateDisconnected + case ICETransportStateClosed: + return ice.ConnectionStateClosed + default: + return ice.ConnectionStateUnknown + } +} + +// MarshalText implements encoding.TextMarshaler. +func (c ICETransportState) MarshalText() ([]byte, error) { + return []byte(c.String()), nil +} + +// UnmarshalText implements encoding.TextUnmarshaler. +func (c *ICETransportState) UnmarshalText(b []byte) error { + *c = newICETransportState(string(b)) + + return nil +} diff --git a/vendor/github.com/pion/webrtc/v4/interceptor.go b/vendor/github.com/pion/webrtc/v4/interceptor.go new file mode 100644 index 0000000..bf4ba39 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/interceptor.go @@ -0,0 +1,361 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +//go:build !js + +package webrtc + +import ( + "sync" + "sync/atomic" + + "github.com/pion/interceptor" + "github.com/pion/interceptor/pkg/flexfec" + "github.com/pion/interceptor/pkg/nack" + "github.com/pion/interceptor/pkg/report" + "github.com/pion/interceptor/pkg/rfc8888" + "github.com/pion/interceptor/pkg/stats" + "github.com/pion/interceptor/pkg/twcc" + "github.com/pion/rtp" + "github.com/pion/sdp/v3" +) + +// RegisterDefaultInterceptors will register some useful interceptors. +// If you want to customize which interceptors are loaded, you should copy the code from this method and remove +// unwanted interceptors. You can also use RegisterDefaultInterceptorsWithOptions to pass in options to modify behavior. +func RegisterDefaultInterceptors(mediaEngine *MediaEngine, interceptorRegistry *interceptor.Registry) error { + return RegisterDefaultInterceptorsWithOptions(mediaEngine, interceptorRegistry) +} + +// RegisterDefaultInterceptorsWithOptions will register some useful interceptors with the provided options. +// If you want to customize which interceptors are loaded, you should copy the code from this method and remove +// unwanted interceptors, or pass in options to modify behavior. +func RegisterDefaultInterceptorsWithOptions(mediaEngine *MediaEngine, interceptorRegistry *interceptor.Registry, + opts ...InterceptorOption, +) error { + var options interceptorOptions + for _, opt := range opts { + opt(&options) + } + + if options.loggerFactory != nil { + // Set logger factory for all interceptors + options.nackGeneratorOptions = append(options.nackGeneratorOptions, + nack.WithGeneratorLoggerFactory(options.loggerFactory)) + options.nackResponderOptions = append(options.nackResponderOptions, + nack.WithResponderLoggerFactory(options.loggerFactory)) + options.reportReceiverOptions = append(options.reportReceiverOptions, + report.WithReceiverLoggerFactory(options.loggerFactory)) + options.reportSenderOptions = append(options.reportSenderOptions, + report.WithSenderLoggerFactory(options.loggerFactory)) + options.statsOptions = append(options.statsOptions, stats.WithLoggerFactory(options.loggerFactory)) + options.twccOptions = append(options.twccOptions, twcc.WithLoggerFactory(options.loggerFactory)) + } + + if err := ConfigureNackWithOptions(mediaEngine, interceptorRegistry, options.nackGeneratorOptions, + options.nackResponderOptions...); err != nil { + return err + } + + if err := ConfigureRTCPReportsWithOptions(interceptorRegistry, options.reportReceiverOptions, + options.reportSenderOptions...); err != nil { + return err + } + + if err := ConfigureSimulcastExtensionHeaders(mediaEngine); err != nil { + return err + } + + if err := ConfigureStatsInterceptorWithOptions(interceptorRegistry, options.statsOptions...); err != nil { + return err + } + + return ConfigureTWCCSenderWithOptions(mediaEngine, interceptorRegistry, options.twccOptions...) +} + +// ConfigureStatsInterceptor will setup everything necessary for generating RTP stream statistics. +func ConfigureStatsInterceptor(interceptorRegistry *interceptor.Registry) error { + return ConfigureStatsInterceptorWithOptions(interceptorRegistry) +} + +// ConfigureStatsInterceptorWithOptions will setup everything necessary for generating RTP stream statistics +// with the provided options. +func ConfigureStatsInterceptorWithOptions(interceptorRegistry *interceptor.Registry, opts ...stats.Option) error { + statsInterceptor, err := stats.NewInterceptor(opts...) + if err != nil { + return err + } + statsInterceptor.OnNewPeerConnection(func(id string, stats stats.Getter) { + statsGetter.Store(id, stats) + }) + interceptorRegistry.Add(statsInterceptor) + + return nil +} + +// lookupStats returns the stats getter for a given peerconnection.statsId. +func lookupStats(id string) (stats.Getter, bool) { + if value, exists := statsGetter.Load(id); exists { + if getter, ok := value.(stats.Getter); ok { + return getter, true + } + } + + return nil, false +} + +// cleanupStats removes the stats getter for a given peerconnection.statsId. +func cleanupStats(id string) { + statsGetter.Delete(id) +} + +// key: string (peerconnection.statsId), value: stats.Getter +var statsGetter sync.Map // nolint:gochecknoglobals + +// ConfigureRTCPReports will setup everything necessary for generating Sender and Receiver Reports. +func ConfigureRTCPReports(interceptorRegistry *interceptor.Registry) error { + return ConfigureRTCPReportsWithOptions(interceptorRegistry, nil) +} + +// ConfigureRTCPReportsWithOptions will setup everything necessary for generating Sender and Receiver Reports +// with the provided options. +func ConfigureRTCPReportsWithOptions(interceptorRegistry *interceptor.Registry, recvOpts []report.ReceiverOption, + sendOpts ...report.SenderOption, +) error { + receiver, err := report.NewReceiverInterceptor(recvOpts...) + if err != nil { + return err + } + + sender, err := report.NewSenderInterceptor(sendOpts...) + if err != nil { + return err + } + + interceptorRegistry.Add(receiver) + interceptorRegistry.Add(sender) + + return nil +} + +// ConfigureNack will setup everything necessary for handling generating/responding to nack messages. +func ConfigureNack(mediaEngine *MediaEngine, interceptorRegistry *interceptor.Registry) error { + return ConfigureNackWithOptions(mediaEngine, interceptorRegistry, nil) +} + +// ConfigureNackWithOptions will setup everything necessary for handling generating/responding to nack messages +// with the provided options. +func ConfigureNackWithOptions(mediaEngine *MediaEngine, interceptorRegistry *interceptor.Registry, + genOpts []nack.GeneratorOption, respOpts ...nack.ResponderOption, +) error { + generator, err := nack.NewGeneratorInterceptor(genOpts...) + if err != nil { + return err + } + + responder, err := nack.NewResponderInterceptor(respOpts...) + if err != nil { + return err + } + + mediaEngine.RegisterFeedback(RTCPFeedback{Type: "nack"}, RTPCodecTypeVideo) + mediaEngine.RegisterFeedback(RTCPFeedback{Type: "nack", Parameter: "pli"}, RTPCodecTypeVideo) + interceptorRegistry.Add(responder) + interceptorRegistry.Add(generator) + + return nil +} + +// ConfigureTWCCHeaderExtensionSender will setup everything necessary for adding +// a TWCC header extension to outgoing RTP packets. This will allow the remote peer to generate TWCC reports. +func ConfigureTWCCHeaderExtensionSender(mediaEngine *MediaEngine, interceptorRegistry *interceptor.Registry) error { + if err := mediaEngine.RegisterHeaderExtension( + RTPHeaderExtensionCapability{URI: sdp.TransportCCURI}, RTPCodecTypeVideo, + ); err != nil { + return err + } + + if err := mediaEngine.RegisterHeaderExtension( + RTPHeaderExtensionCapability{URI: sdp.TransportCCURI}, RTPCodecTypeAudio, + ); err != nil { + return err + } + + twccInterceptor, err := twcc.NewHeaderExtensionInterceptor() + if err != nil { + return err + } + + interceptorRegistry.Add(twccInterceptor) + + return nil +} + +// ConfigureTWCCSender will setup everything necessary for generating TWCC reports. +// This must be called after registering codecs with the MediaEngine. +func ConfigureTWCCSender(mediaEngine *MediaEngine, interceptorRegistry *interceptor.Registry) error { + return ConfigureTWCCSenderWithOptions(mediaEngine, interceptorRegistry) +} + +// ConfigureTWCCSenderWithOptions will setup everything necessary for generating TWCC reports with the provided options. +// This must be called after registering codecs with the MediaEngine. +func ConfigureTWCCSenderWithOptions(mediaEngine *MediaEngine, interceptorRegistry *interceptor.Registry, + opts ...twcc.Option, +) error { + mediaEngine.RegisterFeedback(RTCPFeedback{Type: TypeRTCPFBTransportCC}, RTPCodecTypeVideo) + if err := mediaEngine.RegisterHeaderExtension( + RTPHeaderExtensionCapability{URI: sdp.TransportCCURI}, RTPCodecTypeVideo, + ); err != nil { + return err + } + + mediaEngine.RegisterFeedback(RTCPFeedback{Type: TypeRTCPFBTransportCC}, RTPCodecTypeAudio) + if err := mediaEngine.RegisterHeaderExtension( + RTPHeaderExtensionCapability{URI: sdp.TransportCCURI}, RTPCodecTypeAudio, + ); err != nil { + return err + } + + generator, err := twcc.NewSenderInterceptor(opts...) + if err != nil { + return err + } + + interceptorRegistry.Add(generator) + + return nil +} + +// ConfigureCongestionControlFeedback registers congestion control feedback as +// defined in RFC 8888 (https://datatracker.ietf.org/doc/rfc8888/) +func ConfigureCongestionControlFeedback(mediaEngine *MediaEngine, interceptorRegistry *interceptor.Registry) error { + return ConfigureCongestionControlFeedbackWithOptions(mediaEngine, interceptorRegistry) +} + +// ConfigureCongestionControlFeedbackWithOptions registers congestion control feedback as +// defined in RFC 8888 (https://datatracker.ietf.org/doc/rfc8888/) with the provided options. +func ConfigureCongestionControlFeedbackWithOptions(mediaEngine *MediaEngine, interceptorRegistry *interceptor.Registry, + opts ...rfc8888.Option, +) error { + mediaEngine.RegisterFeedback(RTCPFeedback{Type: TypeRTCPFBACK, Parameter: "ccfb"}, RTPCodecTypeVideo) + mediaEngine.RegisterFeedback(RTCPFeedback{Type: TypeRTCPFBACK, Parameter: "ccfb"}, RTPCodecTypeAudio) + generator, err := rfc8888.NewSenderInterceptor(opts...) + if err != nil { + return err + } + interceptorRegistry.Add(generator) + + return nil +} + +// ConfigureSimulcastExtensionHeaders enables the RTP Extension Headers needed for Simulcast. +func ConfigureSimulcastExtensionHeaders(mediaEngine *MediaEngine) error { + if err := mediaEngine.RegisterHeaderExtension( + RTPHeaderExtensionCapability{URI: sdp.SDESMidURI}, RTPCodecTypeVideo, + ); err != nil { + return err + } + + if err := mediaEngine.RegisterHeaderExtension( + RTPHeaderExtensionCapability{URI: sdp.SDESRTPStreamIDURI}, RTPCodecTypeVideo, + ); err != nil { + return err + } + + return mediaEngine.RegisterHeaderExtension( + RTPHeaderExtensionCapability{URI: sdp.SDESRepairRTPStreamIDURI}, RTPCodecTypeVideo, + ) +} + +// ConfigureFlexFEC03 registers flexfec-03 codec with provided payloadType in mediaEngine +// and adds corresponding interceptor to the registry. +// Note that this function should be called before any other interceptor that modifies RTP packets +// (i.e. TWCCHeaderExtensionSender) is added to the registry, so that packets generated by flexfec +// interceptor are not modified. +func ConfigureFlexFEC03( + payloadType PayloadType, + mediaEngine *MediaEngine, + interceptorRegistry *interceptor.Registry, + options ...flexfec.FecOption, +) error { + codecFEC := RTPCodecParameters{ + RTPCodecCapability: RTPCodecCapability{ + MimeType: MimeTypeFlexFEC03, + ClockRate: 90000, + SDPFmtpLine: "repair-window=10000000", + RTCPFeedback: nil, + }, + PayloadType: payloadType, + } + + if err := mediaEngine.RegisterCodec(codecFEC, RTPCodecTypeVideo); err != nil { + return err + } + + generator, err := flexfec.NewFecInterceptor(options...) + if err != nil { + return err + } + + interceptorRegistry.Add(generator) + + return nil +} + +// interceptorToTrackLocalWriter is an RTPWriter that holds a reference to interceptor.RTPWriter. +type interceptorToTrackLocalWriter struct{ interceptor atomic.Value } // interceptor.RTPWriter } + +// WriteRTP writes an RTP packet using the underlying interceptor.RTPWriter. +func (i *interceptorToTrackLocalWriter) WriteRTP(header *rtp.Header, payload []byte) (int, error) { + if writer, ok := i.interceptor.Load().(interceptor.RTPWriter); ok && writer != nil { + return writer.Write(header, payload, interceptor.Attributes{}) + } + + return 0, nil +} + +// Write writes a raw RTP packet using the underlying interceptor.RTPWriter. +func (i *interceptorToTrackLocalWriter) Write(b []byte) (int, error) { + packet := &rtp.Packet{} + if err := packet.Unmarshal(b); err != nil { + return 0, err + } + + return i.WriteRTP(&packet.Header, packet.Payload) +} + +//nolint:unparam +func createStreamInfo( + id string, + ssrc, ssrcRTX, ssrcFEC SSRC, + payloadType, payloadTypeRTX, payloadTypeFEC PayloadType, + codec RTPCodecCapability, + webrtcHeaderExtensions []RTPHeaderExtensionParameter, +) *interceptor.StreamInfo { + headerExtensions := make([]interceptor.RTPHeaderExtension, 0, len(webrtcHeaderExtensions)) + for _, h := range webrtcHeaderExtensions { + headerExtensions = append(headerExtensions, interceptor.RTPHeaderExtension{ID: h.ID, URI: h.URI}) + } + + feedbacks := make([]interceptor.RTCPFeedback, 0, len(codec.RTCPFeedback)) + for _, f := range codec.RTCPFeedback { + feedbacks = append(feedbacks, interceptor.RTCPFeedback{Type: f.Type, Parameter: f.Parameter}) + } + + return &interceptor.StreamInfo{ + ID: id, + Attributes: interceptor.Attributes{}, + SSRC: uint32(ssrc), + SSRCRetransmission: uint32(ssrcRTX), + SSRCForwardErrorCorrection: uint32(ssrcFEC), + PayloadType: uint8(payloadType), + PayloadTypeRetransmission: uint8(payloadTypeRTX), + PayloadTypeForwardErrorCorrection: uint8(payloadTypeFEC), + RTPHeaderExtensions: headerExtensions, + MimeType: codec.MimeType, + ClockRate: codec.ClockRate, + Channels: codec.Channels, + SDPFmtpLine: codec.SDPFmtpLine, + RTCPFeedback: feedbacks, + } +} diff --git a/vendor/github.com/pion/webrtc/v4/interceptor_option.go b/vendor/github.com/pion/webrtc/v4/interceptor_option.go new file mode 100644 index 0000000..3849b2f --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/interceptor_option.go @@ -0,0 +1,78 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +//go:build !js + +package webrtc + +import ( + "github.com/pion/interceptor/pkg/nack" + "github.com/pion/interceptor/pkg/report" + "github.com/pion/interceptor/pkg/stats" + "github.com/pion/interceptor/pkg/twcc" + "github.com/pion/logging" +) + +// interceptorOptions contains options for configuring interceptors. +type interceptorOptions struct { + loggerFactory logging.LoggerFactory + + nackGeneratorOptions []nack.GeneratorOption + nackResponderOptions []nack.ResponderOption + reportReceiverOptions []report.ReceiverOption + reportSenderOptions []report.SenderOption + statsOptions []stats.Option + twccOptions []twcc.Option +} + +// InterceptorOption is a function that configures InterceptorOptions. +type InterceptorOption func(*interceptorOptions) + +// WithInterceptorLoggerFactory sets the logger factory for interceptors. +func WithInterceptorLoggerFactory(loggerFactory logging.LoggerFactory) InterceptorOption { + return func(o *interceptorOptions) { + o.loggerFactory = loggerFactory + } +} + +// WithNackGeneratorOptions sets options for the NACK generator interceptor. +func WithNackGeneratorOptions(opts ...nack.GeneratorOption) InterceptorOption { + return func(o *interceptorOptions) { + o.nackGeneratorOptions = opts + } +} + +// WithNackResponderOptions sets options for the NACK responder interceptor. +func WithNackResponderOptions(opts ...nack.ResponderOption) InterceptorOption { + return func(o *interceptorOptions) { + o.nackResponderOptions = opts + } +} + +// WithReportReceiverOptions sets options for the report receiver interceptor. +func WithReportReceiverOptions(opts ...report.ReceiverOption) InterceptorOption { + return func(o *interceptorOptions) { + o.reportReceiverOptions = opts + } +} + +// WithReportSenderOptions sets options for the report sender interceptor. +func WithReportSenderOptions(opts ...report.SenderOption) InterceptorOption { + return func(o *interceptorOptions) { + o.reportSenderOptions = opts + } +} + +// WithStatsInterceptorOptions sets options for the stats interceptor. +func WithStatsInterceptorOptions(opts ...stats.Option) InterceptorOption { + return func(o *interceptorOptions) { + o.statsOptions = opts + } +} + +// WithTWCCOptions sets options for the TWCC interceptor. +func WithTWCCOptions(opts ...twcc.Option) InterceptorOption { + return func(o *interceptorOptions) { + o.twccOptions = opts + } +} diff --git a/vendor/github.com/pion/webrtc/v4/internal/fmtp/av1.go b/vendor/github.com/pion/webrtc/v4/internal/fmtp/av1.go new file mode 100644 index 0000000..320c489 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/internal/fmtp/av1.go @@ -0,0 +1,42 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package fmtp + +type av1FMTP struct { + parameters map[string]string +} + +func (h *av1FMTP) MimeType() string { + return "video/av1" +} + +func (h *av1FMTP) Match(b FMTP) bool { + c, ok := b.(*av1FMTP) + if !ok { + return false + } + + // RTP Payload Format For AV1 (v1.0) + // https://aomediacodec.github.io/av1-rtp-spec/ + // If the profile parameter is not present, it MUST be inferred to be 0 (“Main” profile). + hProfile, ok := h.parameters["profile"] + if !ok { + hProfile = "0" + } + cProfile, ok := c.parameters["profile"] + if !ok { + cProfile = "0" + } + if hProfile != cProfile { + return false + } + + return true +} + +func (h *av1FMTP) Parameter(key string) (string, bool) { + v, ok := h.parameters[key] + + return v, ok +} diff --git a/vendor/github.com/pion/webrtc/v4/internal/fmtp/fmtp.go b/vendor/github.com/pion/webrtc/v4/internal/fmtp/fmtp.go new file mode 100644 index 0000000..9af40e5 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/internal/fmtp/fmtp.go @@ -0,0 +1,185 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +// Package fmtp implements per codec parsing of fmtp lines +package fmtp + +import ( + "strings" +) + +func defaultClockRate(mimeType string) uint32 { + defaults := map[string]uint32{ + "audio/opus": 48000, + "audio/pcmu": 8000, + "audio/pcma": 8000, + } + + if def, ok := defaults[strings.ToLower(mimeType)]; ok { + return def + } + + return 90000 +} + +func defaultChannels(mimeType string) uint16 { + defaults := map[string]uint16{ + "audio/opus": 2, + } + + if def, ok := defaults[strings.ToLower(mimeType)]; ok { + return def + } + + return 0 +} + +func parseParameters(line string) map[string]string { + parameters := make(map[string]string) + + for p := range strings.SplitSeq(line, ";") { + pp := strings.SplitN(strings.TrimSpace(p), "=", 2) + key := strings.ToLower(pp[0]) + var value string + if len(pp) > 1 { + value = pp[1] + } + parameters[key] = value + } + + return parameters +} + +// ClockRateEqual checks whether two clock rates are equal. +func ClockRateEqual(mimeType string, valA, valB uint32) bool { + // Lots of users use formats without setting clock rate or channels. + // In this case, use default values. + // It would be better to remove this exception in a future major release. + if valA == 0 { + valA = defaultClockRate(mimeType) + } + if valB == 0 { + valB = defaultClockRate(mimeType) + } + + return valA == valB +} + +// ChannelsEqual checks whether two channels are equal. +func ChannelsEqual(mimeType string, valA, valB uint16) bool { + // Lots of users use formats without setting clock rate or channels. + // In this case, use default values. + // It would be better to remove this exception in a future major release. + if valA == 0 { + valA = defaultChannels(mimeType) + } + if valB == 0 { + valB = defaultChannels(mimeType) + } + + // RFC8866: channel count "is OPTIONAL and may be omitted + // if the number of channels is one". + if valA == 0 { + valA = 1 + } + if valB == 0 { + valB = 1 + } + + return valA == valB +} + +func paramsEqual(valA, valB map[string]string) bool { + for k, v := range valA { + if vb, ok := valB[k]; ok && !strings.EqualFold(vb, v) { + return false + } + } + + for k, v := range valB { + if va, ok := valA[k]; ok && !strings.EqualFold(va, v) { + return false + } + } + + return true +} + +// FMTP interface for implementing custom +// FMTP parsers based on MimeType. +type FMTP interface { + // MimeType returns the MimeType associated with + // the fmtp + MimeType() string + // Match compares two fmtp descriptions for + // compatibility based on the MimeType + Match(f FMTP) bool + // Parameter returns a value for the associated key + // if contained in the parsed fmtp string + Parameter(key string) (string, bool) +} + +// Parse parses an fmtp string based on the MimeType. +func Parse(mimeType string, clockRate uint32, channels uint16, line string) FMTP { + var fmtp FMTP + + parameters := parseParameters(line) + + switch { + case strings.EqualFold(mimeType, "video/h264"): + fmtp = &h264FMTP{ + parameters: parameters, + } + + case strings.EqualFold(mimeType, "video/vp9"): + fmtp = &vp9FMTP{ + parameters: parameters, + } + + case strings.EqualFold(mimeType, "video/av1"): + fmtp = &av1FMTP{ + parameters: parameters, + } + + default: + fmtp = &genericFMTP{ + mimeType: mimeType, + clockRate: clockRate, + channels: channels, + parameters: parameters, + } + } + + return fmtp +} + +type genericFMTP struct { + mimeType string + clockRate uint32 + channels uint16 + parameters map[string]string +} + +func (g *genericFMTP) MimeType() string { + return g.mimeType +} + +// Match returns true if g and b are compatible fmtp descriptions +// The generic implementation is used for MimeTypes that are not defined. +func (g *genericFMTP) Match(b FMTP) bool { + fmtp, ok := b.(*genericFMTP) + if !ok { + return false + } + + return strings.EqualFold(g.mimeType, fmtp.MimeType()) && + ClockRateEqual(g.mimeType, g.clockRate, fmtp.clockRate) && + ChannelsEqual(g.mimeType, g.channels, fmtp.channels) && + paramsEqual(g.parameters, fmtp.parameters) +} + +func (g *genericFMTP) Parameter(key string) (string, bool) { + v, ok := g.parameters[key] + + return v, ok +} diff --git a/vendor/github.com/pion/webrtc/v4/internal/fmtp/h264.go b/vendor/github.com/pion/webrtc/v4/internal/fmtp/h264.go new file mode 100644 index 0000000..4300d37 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/internal/fmtp/h264.go @@ -0,0 +1,86 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package fmtp + +import ( + "encoding/hex" +) + +func profileLevelIDMatches(a, b string) bool { + aa, err := hex.DecodeString(a) + if err != nil || len(aa) < 2 { + return false + } + bb, err := hex.DecodeString(b) + if err != nil || len(bb) < 2 { + return false + } + + return aa[0] == bb[0] && aa[1] == bb[1] +} + +type h264FMTP struct { + parameters map[string]string +} + +func (h *h264FMTP) MimeType() string { + return "video/h264" +} + +// Match returns true if h and b are compatible fmtp descriptions +// Based on RFC6184 Section 8.2.2: +// +// The parameters identifying a media format configuration for H.264 +// are profile-level-id and packetization-mode. These media format +// configuration parameters (except for the level part of profile- +// level-id) MUST be used symmetrically; that is, the answerer MUST +// either maintain all configuration parameters or remove the media +// format (payload type) completely if one or more of the parameter +// values are not supported. +// Informative note: The requirement for symmetric use does not +// apply for the level part of profile-level-id and does not apply +// for the other stream properties and capability parameters. +func (h *h264FMTP) Match(b FMTP) bool { + fmtp, ok := b.(*h264FMTP) + if !ok { + return false + } + + // test packetization-mode + hpmode, hok := h.parameters["packetization-mode"] + if !hok { + return false + } + cpmode, cok := fmtp.parameters["packetization-mode"] + if !cok { + return false + } + + if hpmode != cpmode { + return false + } + + // test profile-level-id + hplid, hok := h.parameters["profile-level-id"] + if !hok { + return false + } + + cplid, cok := fmtp.parameters["profile-level-id"] + if !cok { + return false + } + + if !profileLevelIDMatches(hplid, cplid) { + return false + } + + return true +} + +func (h *h264FMTP) Parameter(key string) (string, bool) { + v, ok := h.parameters[key] + + return v, ok +} diff --git a/vendor/github.com/pion/webrtc/v4/internal/fmtp/vp9.go b/vendor/github.com/pion/webrtc/v4/internal/fmtp/vp9.go new file mode 100644 index 0000000..b647248 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/internal/fmtp/vp9.go @@ -0,0 +1,42 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package fmtp + +type vp9FMTP struct { + parameters map[string]string +} + +func (h *vp9FMTP) MimeType() string { + return "video/vp9" +} + +func (h *vp9FMTP) Match(b FMTP) bool { + c, ok := b.(*vp9FMTP) + if !ok { + return false + } + + // RTP Payload Format for VP9 Video - draft-ietf-payload-vp9-16 + // https://datatracker.ietf.org/doc/html/draft-ietf-payload-vp9-16 + // If no profile-id is present, Profile 0 MUST be inferred + hProfileID, ok := h.parameters["profile-id"] + if !ok { + hProfileID = "0" + } + cProfileID, ok := c.parameters["profile-id"] + if !ok { + cProfileID = "0" + } + if hProfileID != cProfileID { + return false + } + + return true +} + +func (h *vp9FMTP) Parameter(key string) (string, bool) { + v, ok := h.parameters[key] + + return v, ok +} diff --git a/vendor/github.com/pion/webrtc/v4/internal/mux/endpoint.go b/vendor/github.com/pion/webrtc/v4/internal/mux/endpoint.go new file mode 100644 index 0000000..1535250 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/internal/mux/endpoint.go @@ -0,0 +1,107 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package mux + +import ( + "errors" + "io" + "net" + "time" + + "github.com/pion/ice/v4" + "github.com/pion/transport/v4/packetio" +) + +// Endpoint implements net.Conn. It is used to read muxed packets. +type Endpoint struct { + mux *Mux + buffer *packetio.Buffer + onClose func() +} + +// Close unregisters the endpoint from the Mux. +func (e *Endpoint) Close() (err error) { + if e.onClose != nil { + e.onClose() + } + + if err = e.close(); err != nil { + return err + } + + e.mux.RemoveEndpoint(e) + + return nil +} + +func (e *Endpoint) close() error { + return e.buffer.Close() +} + +// Read reads a packet of len(p) bytes from the underlying conn +// that are matched by the associated MuxFunc. +func (e *Endpoint) Read(p []byte) (int, error) { + return e.buffer.Read(p) +} + +// ReadFrom reads a packet of len(p) bytes from the underlying conn +// that are matched by the associated MuxFunc. +func (e *Endpoint) ReadFrom(p []byte) (int, net.Addr, error) { + i, err := e.Read(p) + + return i, nil, err +} + +// Write writes len(p) bytes to the underlying conn. +func (e *Endpoint) Write(p []byte) (int, error) { + n, err := e.mux.nextConn.Write(p) + if errors.Is(err, ice.ErrNoCandidatePairs) { + return 0, nil + } else if errors.Is(err, ice.ErrClosed) { + return 0, io.ErrClosedPipe + } + + return n, err +} + +// WriteTo writes len(p) bytes to the underlying conn. +func (e *Endpoint) WriteTo(p []byte, _ net.Addr) (int, error) { + return e.Write(p) +} + +// LocalAddr returns the local network address, if known. +func (e *Endpoint) LocalAddr() net.Addr { + return e.mux.nextConn.LocalAddr() +} + +// RemoteAddr returns the remote network address, if known. +func (e *Endpoint) RemoteAddr() net.Addr { + return e.mux.nextConn.RemoteAddr() +} + +// SetDeadline sets the read and write deadlines on the shared underlying +// connection. Because the connection is shared, this applies to all endpoints +// on the mux. Per-endpoint read deadlines can be set with SetReadDeadline. +func (e *Endpoint) SetDeadline(t time.Time) error { + return e.mux.nextConn.SetDeadline(t) +} + +// SetReadDeadline sets the read deadline for this Endpoint's internal +// packet buffer. This timeout applies only to reads from this Endpoint, +// not to the shared underlying connection. +func (e *Endpoint) SetReadDeadline(t time.Time) error { + return e.buffer.SetReadDeadline(t) +} + +// SetWriteDeadline sets the write deadline on the shared underlying connection. +// Because the connection is shared, this applies to all endpoints on the mux. +func (e *Endpoint) SetWriteDeadline(t time.Time) error { + return e.mux.nextConn.SetWriteDeadline(t) +} + +// SetOnClose is a user set callback that +// will be executed when `Close` is called. +func (e *Endpoint) SetOnClose(onClose func()) { + e.onClose = onClose +} diff --git a/vendor/github.com/pion/webrtc/v4/internal/mux/mux.go b/vendor/github.com/pion/webrtc/v4/internal/mux/mux.go new file mode 100644 index 0000000..7c47f13 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/internal/mux/mux.go @@ -0,0 +1,216 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +// Package mux multiplexes packets on a single socket (RFC7983) +package mux + +import ( + "errors" + "io" + "net" + "sync" + + "github.com/pion/ice/v4" + "github.com/pion/logging" + "github.com/pion/transport/v4/packetio" +) + +const ( + // The maximum amount of data that can be buffered before returning errors. + maxBufferSize = 1000 * 1000 // 1MB + + // How many total pending packets can be cached. + maxPendingPackets = 15 +) + +// Config collects the arguments to mux.Mux construction into +// a single structure. +type Config struct { + Conn net.Conn + BufferSize int + LoggerFactory logging.LoggerFactory +} + +// Mux allows multiplexing. +type Mux struct { + nextConn net.Conn + bufferSize int + lock sync.Mutex + endpoints map[*Endpoint]MatchFunc + isClosed bool + + pendingPackets [][]byte + + closedCh chan struct{} + log logging.LeveledLogger +} + +// NewMux creates a new Mux. +func NewMux(config Config) *Mux { + mux := &Mux{ + nextConn: config.Conn, + endpoints: make(map[*Endpoint]MatchFunc), + bufferSize: config.BufferSize, + closedCh: make(chan struct{}), + log: config.LoggerFactory.NewLogger("mux"), + } + + go mux.readLoop() + + return mux +} + +// NewEndpoint creates a new Endpoint. +func (m *Mux) NewEndpoint(matchFunc MatchFunc) *Endpoint { + endpoint := &Endpoint{ + mux: m, + buffer: packetio.NewBuffer(), + } + + // Set a maximum size of the buffer in bytes. + endpoint.buffer.SetLimitSize(maxBufferSize) + + m.lock.Lock() + m.endpoints[endpoint] = matchFunc + m.lock.Unlock() + + go m.handlePendingPackets(endpoint, matchFunc) + + return endpoint +} + +// RemoveEndpoint removes an endpoint from the Mux. +func (m *Mux) RemoveEndpoint(e *Endpoint) { + m.lock.Lock() + defer m.lock.Unlock() + delete(m.endpoints, e) +} + +// Close closes the Mux and all associated Endpoints. +func (m *Mux) Close() error { + m.lock.Lock() + for e := range m.endpoints { + if err := e.close(); err != nil { + m.lock.Unlock() + + return err + } + + delete(m.endpoints, e) + } + m.isClosed = true + m.lock.Unlock() + + err := m.nextConn.Close() + if err != nil { + return err + } + + // Wait for readLoop to end + <-m.closedCh + + return nil +} + +func (m *Mux) readLoop() { + defer func() { + close(m.closedCh) + }() + + buf := make([]byte, m.bufferSize) + for { + n, err := m.nextConn.Read(buf) + switch { + case errors.Is(err, io.EOF), errors.Is(err, ice.ErrClosed): + return + case errors.Is(err, io.ErrShortBuffer), errors.Is(err, packetio.ErrTimeout): + m.log.Errorf("mux: failed to read from packetio.Buffer %s", err.Error()) + + continue + case err != nil: + m.log.Errorf("mux: ending readLoop packetio.Buffer error %s", err.Error()) + + return + } + + if err = m.dispatch(buf[:n]); err != nil { + if errors.Is(err, io.ErrClosedPipe) { + // if the buffer was closed, that's not an error we care to report + return + } + m.log.Errorf("mux: ending readLoop dispatch error %s", err.Error()) + + return + } + } +} + +func (m *Mux) dispatch(buf []byte) error { + if len(buf) == 0 { + m.log.Warnf("Warning: mux: unable to dispatch zero length packet") + + return nil + } + + var endpoint *Endpoint + + m.lock.Lock() + for e, f := range m.endpoints { + if f(buf) { + endpoint = e + + break + } + } + if endpoint == nil { + defer m.lock.Unlock() + + if !m.isClosed { + if len(m.pendingPackets) >= maxPendingPackets { + m.log.Warnf( + "Warning: mux: no endpoint for packet starting with %d, not adding to queue size(%d)", + buf[0], //nolint:gosec // G602, false positive? + len(m.pendingPackets), + ) + } else { + m.log.Warnf( + "Warning: mux: no endpoint for packet starting with %d, adding to queue size(%d)", + buf[0], //nolint:gosec // G602, false positive? + len(m.pendingPackets), + ) + m.pendingPackets = append(m.pendingPackets, append([]byte{}, buf...)) + } + } + + return nil + } + + m.lock.Unlock() + _, err := endpoint.buffer.Write(buf) + + // Expected when bytes are received faster than the endpoint can process them (#2152, #2180) + if errors.Is(err, packetio.ErrFull) { + m.log.Infof("mux: endpoint buffer is full, dropping packet") + + return nil + } + + return err +} + +func (m *Mux) handlePendingPackets(endpoint *Endpoint, matchFunc MatchFunc) { + m.lock.Lock() + defer m.lock.Unlock() + + pendingPackets := make([][]byte, 0, len(m.pendingPackets)) + for _, buf := range m.pendingPackets { + if matchFunc(buf) { + if _, err := endpoint.buffer.Write(buf); err != nil { + m.log.Warnf("Warning: mux: error writing packet to endpoint from pending queue: %s", err) + } + } else { + pendingPackets = append(pendingPackets, buf) + } + } + m.pendingPackets = pendingPackets +} diff --git a/vendor/github.com/pion/webrtc/v4/internal/mux/muxfunc.go b/vendor/github.com/pion/webrtc/v4/internal/mux/muxfunc.go new file mode 100644 index 0000000..fedc54e --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/internal/mux/muxfunc.go @@ -0,0 +1,67 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package mux + +// MatchFunc allows custom logic for mapping packets to an Endpoint. +type MatchFunc func([]byte) bool + +// MatchAll always returns true. +func MatchAll([]byte) bool { + return true +} + +// MatchRange returns true if the first byte of buf is in [lower..upper]. +func MatchRange(lower, upper byte, buf []byte) bool { + if len(buf) < 1 { + return false + } + b := buf[0] + + return b >= lower && b <= upper +} + +// MatchFuncs as described in RFC7983 +// https://tools.ietf.org/html/rfc7983 +// +----------------+ +// | [0..3] -+--> forward to STUN +// | | +// | [16..19] -+--> forward to ZRTP +// | | +// packet --> | [20..63] -+--> forward to DTLS +// | | +// | [64..79] -+--> forward to TURN Channel +// | | +// | [128..191] -+--> forward to RTP/RTCP +// +----------------+ + +// MatchDTLS is a MatchFunc that accepts packets with the first byte in [20..63] +// as defied in RFC7983. +func MatchDTLS(b []byte) bool { + return MatchRange(20, 63, b) +} + +// MatchSRTPOrSRTCP is a MatchFunc that accepts packets with the first byte in [128..191] +// as defied in RFC7983. +func MatchSRTPOrSRTCP(b []byte) bool { + return MatchRange(128, 191, b) +} + +func isRTCP(buf []byte) bool { + // Not long enough to determine RTP/RTCP + if len(buf) < 4 { + return false + } + + return buf[1] >= 192 && buf[1] <= 223 +} + +// MatchSRTP is a MatchFunc that only matches SRTP and not SRTCP. +func MatchSRTP(buf []byte) bool { + return MatchSRTPOrSRTCP(buf) && !isRTCP(buf) +} + +// MatchSRTCP is a MatchFunc that only matches SRTCP and not SRTP. +func MatchSRTCP(buf []byte) bool { + return MatchSRTPOrSRTCP(buf) && isRTCP(buf) +} diff --git a/vendor/github.com/pion/webrtc/v4/internal/util/util.go b/vendor/github.com/pion/webrtc/v4/internal/util/util.go new file mode 100644 index 0000000..df1fd9c --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/internal/util/util.go @@ -0,0 +1,77 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +// Package util provides auxiliary functions internally used in webrtc package +package util //nolint: revive + +import ( + "errors" + "strings" + + "github.com/pion/randutil" +) + +const ( + runesAlpha = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" +) + +// Use global random generator to properly seed by crypto grade random. +var globalMathRandomGenerator = randutil.NewMathRandomGenerator() // nolint:gochecknoglobals + +// MathRandAlpha generates a mathematical random alphabet sequence of the requested length. +func MathRandAlpha(n int) string { + return globalMathRandomGenerator.GenerateString(n, runesAlpha) +} + +// RandUint32 generates a mathematical random uint32. +func RandUint32() uint32 { + return globalMathRandomGenerator.Uint32() +} + +// FlattenErrs flattens multiple errors into one. +func FlattenErrs(errs []error) error { + errs2 := []error{} + for _, e := range errs { + if e != nil { + errs2 = append(errs2, e) + } + } + if len(errs2) == 0 { + return nil + } + + return multiError(errs2) +} + +type multiError []error //nolint:errname + +func (me multiError) Error() string { + var errstrings []string + + for _, err := range me { + if err != nil { + errstrings = append(errstrings, err.Error()) + } + } + + if len(errstrings) == 0 { + return "multiError must contain multiple error but is empty" + } + + return strings.Join(errstrings, "\n") +} + +func (me multiError) Is(err error) bool { + for _, e := range me { + if errors.Is(e, err) { + return true + } + if me2, ok := e.(multiError); ok { //nolint:errorlint + if me2.Is(err) { + return true + } + } + } + + return false +} diff --git a/vendor/github.com/pion/webrtc/v4/js_utils.go b/vendor/github.com/pion/webrtc/v4/js_utils.go new file mode 100644 index 0000000..9ee5c1e --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/js_utils.go @@ -0,0 +1,197 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +//go:build js && wasm +// +build js,wasm + +package webrtc + +import ( + "fmt" + "syscall/js" +) + +// awaitPromise accepts a js.Value representing a Promise. If the promise +// resolves, it returns (result, nil). If the promise rejects, it returns +// (js.Undefined, error). awaitPromise has a synchronous-like API but does not +// block the JavaScript event loop. +func awaitPromise(promise js.Value) (js.Value, error) { + resultsChan := make(chan js.Value) + errChan := make(chan js.Error) + + thenFunc := js.FuncOf(func(this js.Value, args []js.Value) any { + go func() { + resultsChan <- args[0] + }() + return js.Undefined() + }) + defer thenFunc.Release() + + catchFunc := js.FuncOf(func(this js.Value, args []js.Value) any { + go func() { + errChan <- js.Error{args[0]} + }() + return js.Undefined() + }) + defer catchFunc.Release() + + promise.Call("then", thenFunc).Call("catch", catchFunc) + + select { + case result := <-resultsChan: + return result, nil + case err := <-errChan: + return js.Undefined(), err + } +} + +func valueToUint16Pointer(val js.Value) *uint16 { + if val.IsNull() || val.IsUndefined() { + return nil + } + convertedVal := uint16(val.Int()) + return &convertedVal +} + +func valueToStringPointer(val js.Value) *string { + if val.IsNull() || val.IsUndefined() { + return nil + } + stringVal := val.String() + return &stringVal +} + +func stringToValueOrUndefined(val string) js.Value { + if val == "" { + return js.Undefined() + } + return js.ValueOf(val) +} + +func uint8ToValueOrUndefined(val uint8) js.Value { + if val == 0 { + return js.Undefined() + } + return js.ValueOf(val) +} + +func interfaceToValueOrUndefined(val any) js.Value { + if val == nil { + return js.Undefined() + } + return js.ValueOf(val) +} + +func valueToStringOrZero(val js.Value) string { + if val.IsUndefined() || val.IsNull() { + return "" + } + return val.String() +} + +func valueToUint8OrZero(val js.Value) uint8 { + if val.IsUndefined() || val.IsNull() { + return 0 + } + return uint8(val.Int()) +} + +func valueToUint16OrZero(val js.Value) uint16 { + if val.IsNull() || val.IsUndefined() { + return 0 + } + return uint16(val.Int()) +} + +func valueToUint32OrZero(val js.Value) uint32 { + if val.IsNull() || val.IsUndefined() { + return 0 + } + return uint32(val.Int()) +} + +func valueToStrings(val js.Value) []string { + result := make([]string, val.Length()) + for i := 0; i < val.Length(); i++ { + result[i] = val.Index(i).String() + } + return result +} + +func valueToBoolOrFalse(val js.Value) bool { + if val.IsNull() || val.IsUndefined() { + return false + } + + return val.Bool() +} + +func valueToBoolPointer(val js.Value) *bool { + if val.IsNull() || val.IsUndefined() { + return nil + } + b := val.Bool() + + return &b +} + +func stringPointerToValue(val *string) js.Value { + if val == nil { + return js.Undefined() + } + return js.ValueOf(*val) +} + +func uint16PointerToValue(val *uint16) js.Value { + if val == nil { + return js.Undefined() + } + return js.ValueOf(*val) +} + +func boolToValueOrUndefined(val bool) js.Value { + if !val { + return js.Undefined() + } + + return js.ValueOf(val) +} + +func boolPointerToValue(val *bool) js.Value { + if val == nil { + return js.Undefined() + } + return js.ValueOf(*val) +} + +func stringsToValue(strings []string) js.Value { + val := make([]any, len(strings)) + for i, s := range strings { + val[i] = s + } + return js.ValueOf(val) +} + +func stringEnumToValueOrUndefined(s string) js.Value { + if s == "unknown" { + return js.Undefined() + } + return js.ValueOf(s) +} + +// Converts the return value of recover() to an error. +func recoveryToError(e any) error { + switch e := e.(type) { + case error: + return e + default: + return fmt.Errorf("recovered with non-error value: (%T) %s", e, e) + } +} + +func uint8ArrayValueToBytes(val js.Value) []byte { + result := make([]byte, val.Length()) + js.CopyBytesToGo(result, val) + + return result +} diff --git a/vendor/github.com/pion/webrtc/v4/mediaengine.go b/vendor/github.com/pion/webrtc/v4/mediaengine.go new file mode 100644 index 0000000..102614d --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/mediaengine.go @@ -0,0 +1,860 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +//go:build !js + +package webrtc + +import ( + "errors" + "fmt" + "strconv" + "strings" + "sync" + "time" + + "github.com/pion/rtp" + "github.com/pion/rtp/codecs" + "github.com/pion/sdp/v3" + "github.com/pion/webrtc/v4/internal/fmtp" +) + +type mediaEngineHeaderExtension struct { + uri string + isAudio, isVideo bool + + // If set only Transceivers of this direction are allowed + allowedDirections []RTPTransceiverDirection +} + +// A MediaEngine defines the codecs supported by a PeerConnection, and the +// configuration of those codecs. +type MediaEngine struct { + // If we have attempted to negotiate a codec type yet. + negotiatedVideo, negotiatedAudio bool + negotiateMultiCodecs bool + + videoCodecs, audioCodecs []RTPCodecParameters + negotiatedVideoCodecs, negotiatedAudioCodecs []RTPCodecParameters + + headerExtensions []mediaEngineHeaderExtension + negotiatedHeaderExtensions map[int]mediaEngineHeaderExtension + + mu sync.RWMutex +} + +// setMultiCodecNegotiation enables or disables the negotiation of multiple codecs. +func (m *MediaEngine) setMultiCodecNegotiation(negotiateMultiCodecs bool) { + m.mu.Lock() + defer m.mu.Unlock() + + m.negotiateMultiCodecs = negotiateMultiCodecs +} + +// multiCodecNegotiation returns the current state of the negotiation of multiple codecs. +func (m *MediaEngine) multiCodecNegotiation() bool { + m.mu.RLock() + defer m.mu.RUnlock() + + return m.negotiateMultiCodecs +} + +// RegisterDefaultCodecs registers the default codecs supported by Pion WebRTC. +// RegisterDefaultCodecs is not safe for concurrent use. +func (m *MediaEngine) RegisterDefaultCodecs() error { + // Default Pion Audio Codecs + for _, codec := range []RTPCodecParameters{ + { + RTPCodecCapability: RTPCodecCapability{MimeTypeOpus, 48000, 2, "minptime=10;useinbandfec=1", nil}, + PayloadType: 111, + }, + { + RTPCodecCapability: RTPCodecCapability{MimeTypeG722, 8000, 0, "", nil}, + PayloadType: rtp.PayloadTypeG722, + }, + { + RTPCodecCapability: RTPCodecCapability{MimeTypePCMU, 8000, 0, "", nil}, + PayloadType: rtp.PayloadTypePCMU, + }, + { + RTPCodecCapability: RTPCodecCapability{MimeTypePCMA, 8000, 0, "", nil}, + PayloadType: rtp.PayloadTypePCMA, + }, + } { + if err := m.RegisterCodec(codec, RTPCodecTypeAudio); err != nil { + return err + } + } + + videoRTCPFeedback := []RTCPFeedback{{"goog-remb", ""}, {"ccm", "fir"}, {"nack", ""}, {"nack", "pli"}} + for _, codec := range []RTPCodecParameters{ + { + RTPCodecCapability: RTPCodecCapability{MimeTypeVP8, 90000, 0, "", videoRTCPFeedback}, + PayloadType: 96, + }, + { + RTPCodecCapability: RTPCodecCapability{MimeTypeRTX, 90000, 0, "apt=96", nil}, + PayloadType: 97, + }, + + { + RTPCodecCapability: RTPCodecCapability{ + MimeTypeH264, 90000, 0, + "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=42001f", + videoRTCPFeedback, + }, + PayloadType: 102, + }, + { + RTPCodecCapability: RTPCodecCapability{MimeTypeRTX, 90000, 0, "apt=102", nil}, + PayloadType: 103, + }, + + { + RTPCodecCapability: RTPCodecCapability{ + MimeTypeH264, 90000, 0, + "level-asymmetry-allowed=1;packetization-mode=0;profile-level-id=42001f", + videoRTCPFeedback, + }, + PayloadType: 104, + }, + { + RTPCodecCapability: RTPCodecCapability{MimeTypeRTX, 90000, 0, "apt=104", nil}, + PayloadType: 105, + }, + + { + RTPCodecCapability: RTPCodecCapability{ + MimeTypeH264, 90000, 0, + "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=42e01f", + videoRTCPFeedback, + }, + PayloadType: 106, + }, + { + RTPCodecCapability: RTPCodecCapability{MimeTypeRTX, 90000, 0, "apt=106", nil}, + PayloadType: 107, + }, + + { + RTPCodecCapability: RTPCodecCapability{ + MimeTypeH264, 90000, 0, + "level-asymmetry-allowed=1;packetization-mode=0;profile-level-id=42e01f", + videoRTCPFeedback, + }, + PayloadType: 108, + }, + { + RTPCodecCapability: RTPCodecCapability{MimeTypeRTX, 90000, 0, "apt=108", nil}, + PayloadType: 109, + }, + + { + RTPCodecCapability: RTPCodecCapability{ + MimeTypeH264, 90000, 0, + "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=4d001f", + videoRTCPFeedback, + }, + PayloadType: 127, + }, + { + RTPCodecCapability: RTPCodecCapability{MimeTypeRTX, 90000, 0, "apt=127", nil}, + PayloadType: 125, + }, + + { + RTPCodecCapability: RTPCodecCapability{ + MimeTypeH264, + 90000, 0, + "level-asymmetry-allowed=1;packetization-mode=0;profile-level-id=4d001f", + videoRTCPFeedback, + }, + PayloadType: 39, + }, + { + RTPCodecCapability: RTPCodecCapability{MimeTypeRTX, 90000, 0, "apt=39", nil}, + PayloadType: 40, + }, + { + RTPCodecCapability: RTPCodecCapability{ + MimeType: MimeTypeH265, + ClockRate: 90000, + RTCPFeedback: videoRTCPFeedback, + }, + PayloadType: 116, + }, + { + RTPCodecCapability: RTPCodecCapability{MimeTypeRTX, 90000, 0, "apt=116", nil}, + PayloadType: 117, + }, + { + RTPCodecCapability: RTPCodecCapability{MimeTypeAV1, 90000, 0, "", videoRTCPFeedback}, + PayloadType: 45, + }, + { + RTPCodecCapability: RTPCodecCapability{MimeTypeRTX, 90000, 0, "apt=45", nil}, + PayloadType: 46, + }, + + { + RTPCodecCapability: RTPCodecCapability{MimeTypeVP9, 90000, 0, "profile-id=0", videoRTCPFeedback}, + PayloadType: 98, + }, + { + RTPCodecCapability: RTPCodecCapability{MimeTypeRTX, 90000, 0, "apt=98", nil}, + PayloadType: 99, + }, + + { + RTPCodecCapability: RTPCodecCapability{MimeTypeVP9, 90000, 0, "profile-id=2", videoRTCPFeedback}, + PayloadType: 100, + }, + { + RTPCodecCapability: RTPCodecCapability{MimeTypeRTX, 90000, 0, "apt=100", nil}, + PayloadType: 101, + }, + + { + RTPCodecCapability: RTPCodecCapability{ + MimeTypeH264, 90000, 0, + "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=64001f", + videoRTCPFeedback, + }, + PayloadType: 112, + }, + { + RTPCodecCapability: RTPCodecCapability{MimeTypeRTX, 90000, 0, "apt=112", nil}, + PayloadType: 113, + }, + } { + if err := m.RegisterCodec(codec, RTPCodecTypeVideo); err != nil { + return err + } + } + + return nil +} + +// addCodec will append codec if it not exists. +func (m *MediaEngine) addCodec(codecs []RTPCodecParameters, codec RTPCodecParameters) ([]RTPCodecParameters, error) { + for _, c := range codecs { + if c.PayloadType == codec.PayloadType { + if strings.EqualFold(c.MimeType, codec.MimeType) && + fmtp.ClockRateEqual(c.MimeType, c.ClockRate, codec.ClockRate) && + fmtp.ChannelsEqual(c.MimeType, c.Channels, codec.Channels) { + return codecs, nil + } + + return codecs, ErrCodecAlreadyRegistered + } + } + + return append(codecs, codec), nil +} + +// RegisterCodec adds codec to the MediaEngine +// These are the list of codecs supported by this PeerConnection. +func (m *MediaEngine) RegisterCodec(codec RTPCodecParameters, typ RTPCodecType) error { + m.mu.Lock() + defer m.mu.Unlock() + + var err error + codec.statsID = fmt.Sprintf("RTPCodec-%d", time.Now().UnixNano()) + switch typ { + case RTPCodecTypeAudio: + m.audioCodecs, err = m.addCodec(m.audioCodecs, codec) + case RTPCodecTypeVideo: + m.videoCodecs, err = m.addCodec(m.videoCodecs, codec) + default: + return ErrUnknownType + } + + return err +} + +// RegisterHeaderExtension adds a header extension to the MediaEngine +// To determine the negotiated value use `GetHeaderExtensionID` after signaling is complete. +// +//nolint:cyclop +func (m *MediaEngine) RegisterHeaderExtension( + extension RTPHeaderExtensionCapability, + typ RTPCodecType, + allowedDirections ...RTPTransceiverDirection, +) error { + m.mu.Lock() + defer m.mu.Unlock() + + if m.negotiatedHeaderExtensions == nil { + m.negotiatedHeaderExtensions = map[int]mediaEngineHeaderExtension{} + } + + if len(allowedDirections) == 0 { + allowedDirections = []RTPTransceiverDirection{RTPTransceiverDirectionRecvonly, RTPTransceiverDirectionSendonly} + } + + for _, direction := range allowedDirections { + if direction != RTPTransceiverDirectionRecvonly && direction != RTPTransceiverDirectionSendonly { + return ErrRegisterHeaderExtensionInvalidDirection + } + } + + extensionIndex := -1 + for i := range m.headerExtensions { + if extension.URI == m.headerExtensions[i].uri { + extensionIndex = i + } + } + + if extensionIndex == -1 { + m.headerExtensions = append(m.headerExtensions, mediaEngineHeaderExtension{}) + extensionIndex = len(m.headerExtensions) - 1 + } + + if typ == RTPCodecTypeAudio { + m.headerExtensions[extensionIndex].isAudio = true + } else if typ == RTPCodecTypeVideo { + m.headerExtensions[extensionIndex].isVideo = true + } + + m.headerExtensions[extensionIndex].uri = extension.URI + m.headerExtensions[extensionIndex].allowedDirections = allowedDirections + + return nil +} + +// RegisterFeedback adds feedback mechanism to already registered codecs. +func (m *MediaEngine) RegisterFeedback(feedback RTCPFeedback, typ RTPCodecType) { + m.mu.Lock() + defer m.mu.Unlock() + + addUniqueFeedback := func(existing []RTCPFeedback) []RTCPFeedback { + for _, f := range existing { + if strings.EqualFold(f.Type, feedback.Type) && strings.EqualFold(f.Parameter, feedback.Parameter) { + return existing + } + } + + return append(existing, feedback) + } + + switch typ { + case RTPCodecTypeVideo: + for i, v := range m.videoCodecs { + v.RTCPFeedback = addUniqueFeedback(v.RTCPFeedback) + m.videoCodecs[i] = v + } + case RTPCodecTypeAudio: + for i, v := range m.audioCodecs { + v.RTCPFeedback = addUniqueFeedback(v.RTCPFeedback) + m.audioCodecs[i] = v + } + default: + } +} + +// getHeaderExtensionID returns the negotiated ID for a header extension. +// If the Header Extension isn't enabled ok will be false. +func (m *MediaEngine) getHeaderExtensionID(extension RTPHeaderExtensionCapability) ( + val int, + audioNegotiated, videoNegotiated bool, +) { + m.mu.RLock() + defer m.mu.RUnlock() + + if m.negotiatedHeaderExtensions == nil { + return 0, false, false + } + + for id, h := range m.negotiatedHeaderExtensions { + if extension.URI == h.uri { + return id, h.isAudio, h.isVideo + } + } + + return +} + +// copy copies any user modifiable state of the MediaEngine +// all internal state is reset. +func (m *MediaEngine) copy() *MediaEngine { + m.mu.Lock() + defer m.mu.Unlock() + cloned := &MediaEngine{ + videoCodecs: append([]RTPCodecParameters{}, m.videoCodecs...), + audioCodecs: append([]RTPCodecParameters{}, m.audioCodecs...), + headerExtensions: append([]mediaEngineHeaderExtension{}, m.headerExtensions...), + } + if len(m.headerExtensions) > 0 { + cloned.negotiatedHeaderExtensions = map[int]mediaEngineHeaderExtension{} + } + + return cloned +} + +func findCodecByPayload(codecs []RTPCodecParameters, payloadType PayloadType) *RTPCodecParameters { + for _, codec := range codecs { + if codec.PayloadType == payloadType { + return &codec + } + } + + return nil +} + +func (m *MediaEngine) getCodecByPayload(payloadType PayloadType) (RTPCodecParameters, RTPCodecType, error) { + m.mu.RLock() + defer m.mu.RUnlock() + + // if we've negotiated audio or video, check the negotiated types before our + // built-in payload types, to ensure we pick the codec the other side wants. + if m.negotiatedVideo { + if codec := findCodecByPayload(m.negotiatedVideoCodecs, payloadType); codec != nil { + return *codec, RTPCodecTypeVideo, nil + } + } + if m.negotiatedAudio { + if codec := findCodecByPayload(m.negotiatedAudioCodecs, payloadType); codec != nil { + return *codec, RTPCodecTypeAudio, nil + } + } + if !m.negotiatedVideo { + if codec := findCodecByPayload(m.videoCodecs, payloadType); codec != nil { + return *codec, RTPCodecTypeVideo, nil + } + } + if !m.negotiatedAudio { + if codec := findCodecByPayload(m.audioCodecs, payloadType); codec != nil { + return *codec, RTPCodecTypeAudio, nil + } + } + + return RTPCodecParameters{}, 0, ErrCodecNotFound +} + +func (m *MediaEngine) collectStats(collector *statsReportCollector) { + m.mu.RLock() + defer m.mu.RUnlock() + + statsLoop := func(codecs []RTPCodecParameters) { + for _, codec := range codecs { + collector.Collecting() + stats := CodecStats{ + Timestamp: statsTimestampFrom(time.Now()), + Type: StatsTypeCodec, + ID: codec.statsID, + PayloadType: codec.PayloadType, + MimeType: codec.MimeType, + ClockRate: codec.ClockRate, + Channels: uint8(codec.Channels), //nolint:gosec // G115 + SDPFmtpLine: codec.SDPFmtpLine, + } + + collector.Collect(stats.ID, stats) + } + } + + statsLoop(m.videoCodecs) + statsLoop(m.audioCodecs) +} + +// Look up a codec and enable if it exists. +// +//nolint:cyclop +func (m *MediaEngine) matchRemoteCodec( + remoteCodec RTPCodecParameters, + typ RTPCodecType, + exactMatches, partialMatches []RTPCodecParameters, +) (RTPCodecParameters, codecMatchType, error) { + codecs := m.videoCodecs + if typ == RTPCodecTypeAudio { + codecs = m.audioCodecs + } + + remoteFmtp := fmtp.Parse( + remoteCodec.RTPCodecCapability.MimeType, + remoteCodec.RTPCodecCapability.ClockRate, + remoteCodec.RTPCodecCapability.Channels, + remoteCodec.RTPCodecCapability.SDPFmtpLine) + + if apt, hasApt := remoteFmtp.Parameter("apt"); hasApt { //nolint:nestif + payloadType, err := strconv.ParseUint(apt, 10, 8) + if err != nil { + return RTPCodecParameters{}, codecMatchNone, err + } + + aptMatch := codecMatchNone + var aptCodec RTPCodecParameters + for _, codec := range exactMatches { + if codec.PayloadType == PayloadType(payloadType) { + aptMatch = codecMatchExact + aptCodec = codec + + break + } + } + + if aptMatch == codecMatchNone { + for _, codec := range partialMatches { + if codec.PayloadType == PayloadType(payloadType) { + aptMatch = codecMatchPartial + aptCodec = codec + + break + } + } + } + + if aptMatch == codecMatchNone { + return RTPCodecParameters{}, codecMatchNone, nil // not an error, we just ignore this codec we don't support + } + + // replace the apt value with the original codec's payload type + toMatchCodec := remoteCodec + if aptMatched, mt := codecParametersFuzzySearch(aptCodec, codecs); mt == aptMatch { + toMatchCodec.SDPFmtpLine = strings.Replace( + toMatchCodec.SDPFmtpLine, + fmt.Sprintf("apt=%d", payloadType), + fmt.Sprintf("apt=%d", aptMatched.PayloadType), + 1, + ) + } + + // if apt's media codec is partial match, then apt codec must be partial match too. + localCodec, matchType := codecParametersFuzzySearch(toMatchCodec, codecs) + if matchType == codecMatchExact && aptMatch == codecMatchPartial { + matchType = codecMatchPartial + } + + return localCodec, matchType, nil + } + + localCodec, matchType := codecParametersFuzzySearch(remoteCodec, codecs) + + return localCodec, matchType, nil +} + +// Update header extensions from a remote media section. +func (m *MediaEngine) updateHeaderExtensionFromMediaSection(media *sdp.MediaDescription) error { + var typ RTPCodecType + switch { + case strings.EqualFold(media.MediaName.Media, "audio"): + typ = RTPCodecTypeAudio + case strings.EqualFold(media.MediaName.Media, "video"): + typ = RTPCodecTypeVideo + default: + return nil + } + extensions, err := rtpExtensionsFromMediaDescription(media) + if err != nil { + return err + } + + for extension, id := range extensions { + if err = m.updateHeaderExtension(id, extension, typ); err != nil { + return err + } + } + + return nil +} + +// Look up a header extension and enable if it exists. +func (m *MediaEngine) updateHeaderExtension(id int, extension string, typ RTPCodecType) error { + if m.negotiatedHeaderExtensions == nil { + return nil + } + + for _, localExtension := range m.headerExtensions { + if localExtension.uri == extension { + h := mediaEngineHeaderExtension{uri: extension, allowedDirections: localExtension.allowedDirections} + if existingValue, ok := m.negotiatedHeaderExtensions[id]; ok { + h = existingValue + } + + switch { + case localExtension.isAudio && typ == RTPCodecTypeAudio: + h.isAudio = true + case localExtension.isVideo && typ == RTPCodecTypeVideo: + h.isVideo = true + } + + m.negotiatedHeaderExtensions[id] = h + } + } + + return nil +} + +func (m *MediaEngine) pushCodecs(codecs []RTPCodecParameters, typ RTPCodecType) error { + var joinedErr error + for _, codec := range codecs { + var err error + if typ == RTPCodecTypeAudio { + m.negotiatedAudioCodecs, err = m.addCodec(m.negotiatedAudioCodecs, codec) + } else if typ == RTPCodecTypeVideo { + m.negotiatedVideoCodecs, err = m.addCodec(m.negotiatedVideoCodecs, codec) + } + if err != nil { + joinedErr = errors.Join(joinedErr, err) + } + } + + return joinedErr +} + +// Update the MediaEngine from a remote description. +func (m *MediaEngine) updateFromRemoteDescription(desc sdp.SessionDescription) error { //nolint:cyclop,gocognit + m.mu.Lock() + defer m.mu.Unlock() + + for _, media := range desc.MediaDescriptions { + var typ RTPCodecType + + switch { + case strings.EqualFold(media.MediaName.Media, "audio"): + typ = RTPCodecTypeAudio + case strings.EqualFold(media.MediaName.Media, "video"): + typ = RTPCodecTypeVideo + } + + switch { + case !m.negotiatedAudio && typ == RTPCodecTypeAudio: + m.negotiatedAudio = true + case !m.negotiatedVideo && typ == RTPCodecTypeVideo: + m.negotiatedVideo = true + default: + // update header extesions from remote sdp if codec is negotiated, Firefox + // would send updated header extension in renegotiation. + // e.g. publish first track without simucalst ->negotiated-> publish second track with simucalst + // then the two media secontions have different rtp header extensions in offer + if err := m.updateHeaderExtensionFromMediaSection(media); err != nil { + return err + } + + if !m.negotiateMultiCodecs || (typ != RTPCodecTypeAudio && typ != RTPCodecTypeVideo) { + continue + } + } + + codecs, err := codecsFromMediaDescription(media) + if err != nil { + return err + } + + addIfNew := func(existingCodecs []RTPCodecParameters, codec RTPCodecParameters) []RTPCodecParameters { + found := false + for _, existingCodec := range existingCodecs { + if existingCodec.PayloadType == codec.PayloadType { + found = true + + break + } + } + + if !found { + existingCodecs = append(existingCodecs, codec) + } + + return existingCodecs + } + + exactMatches := make([]RTPCodecParameters, 0, len(codecs)) + partialMatches := make([]RTPCodecParameters, 0, len(codecs)) + + for _, remoteCodec := range codecs { + localCodec, matchType, mErr := m.matchRemoteCodec(remoteCodec, typ, exactMatches, partialMatches) + if mErr != nil { + return mErr + } + + remoteCodec.RTCPFeedback = rtcpFeedbackIntersection(localCodec.RTCPFeedback, remoteCodec.RTCPFeedback) + + if matchType == codecMatchExact { + exactMatches = addIfNew(exactMatches, remoteCodec) + } else if matchType == codecMatchPartial { + partialMatches = addIfNew(partialMatches, remoteCodec) + } + } + // second pass in case there were missed RTX codecs + for _, remoteCodec := range codecs { + localCodec, matchType, mErr := m.matchRemoteCodec(remoteCodec, typ, exactMatches, partialMatches) + if mErr != nil { + return mErr + } + + remoteCodec.RTCPFeedback = rtcpFeedbackIntersection(localCodec.RTCPFeedback, remoteCodec.RTCPFeedback) + + if matchType == codecMatchExact { + exactMatches = addIfNew(exactMatches, remoteCodec) + } else if matchType == codecMatchPartial { + partialMatches = addIfNew(partialMatches, remoteCodec) + } + } + + // use exact matches when they exist, otherwise fall back to partial + switch { + case len(exactMatches) > 0: + err = m.pushCodecs(exactMatches, typ) + case len(partialMatches) > 0: + err = m.pushCodecs(partialMatches, typ) + default: + // no match, not negotiated + continue + } + if err != nil { + return err + } + + if err := m.updateHeaderExtensionFromMediaSection(media); err != nil { + return err + } + } + + return nil +} + +func (m *MediaEngine) getCodecsByKind(typ RTPCodecType) []RTPCodecParameters { + m.mu.RLock() + defer m.mu.RUnlock() + + if typ == RTPCodecTypeVideo { + if m.negotiatedVideo { + return m.negotiatedVideoCodecs + } + + return m.videoCodecs + } else if typ == RTPCodecTypeAudio { + if m.negotiatedAudio { + return m.negotiatedAudioCodecs + } + + return m.audioCodecs + } + + return nil +} + +//nolint:gocognit,cyclop +func (m *MediaEngine) getRTPParametersByKind(typ RTPCodecType, directions []RTPTransceiverDirection) RTPParameters { + headerExtensions := make([]RTPHeaderExtensionParameter, 0) + + // perform before locking to prevent recursive RLocks + foundCodecs := m.getCodecsByKind(typ) + + m.mu.RLock() + defer m.mu.RUnlock() + + //nolint:nestif + if (m.negotiatedVideo && typ == RTPCodecTypeVideo) || (m.negotiatedAudio && typ == RTPCodecTypeAudio) { + for id, e := range m.negotiatedHeaderExtensions { + if haveRTPTransceiverDirectionIntersection(e.allowedDirections, directions) && + (e.isAudio && typ == RTPCodecTypeAudio || e.isVideo && typ == RTPCodecTypeVideo) { + headerExtensions = append(headerExtensions, RTPHeaderExtensionParameter{ID: id, URI: e.uri}) + } + } + } else { + mediaHeaderExtensions := make(map[int]mediaEngineHeaderExtension) + for _, ext := range m.headerExtensions { + usingNegotiatedID := false + for id := range m.negotiatedHeaderExtensions { + if m.negotiatedHeaderExtensions[id].uri == ext.uri { + usingNegotiatedID = true + mediaHeaderExtensions[id] = ext + + break + } + } + if !usingNegotiatedID { + for id := 1; id < 15; id++ { + idAvailable := true + if _, ok := mediaHeaderExtensions[id]; ok { + idAvailable = false + } + if _, taken := m.negotiatedHeaderExtensions[id]; idAvailable && !taken { + mediaHeaderExtensions[id] = ext + + break + } + } + } + } + + for id, e := range mediaHeaderExtensions { + if haveRTPTransceiverDirectionIntersection(e.allowedDirections, directions) && + (e.isAudio && typ == RTPCodecTypeAudio || e.isVideo && typ == RTPCodecTypeVideo) { + headerExtensions = append(headerExtensions, RTPHeaderExtensionParameter{ID: id, URI: e.uri}) + } + } + } + + return RTPParameters{ + HeaderExtensions: headerExtensions, + Codecs: foundCodecs, + } +} + +func (m *MediaEngine) getRTPParametersByPayloadType(payloadType PayloadType) (RTPParameters, error) { + codec, typ, err := m.getCodecByPayload(payloadType) + if err != nil { + return RTPParameters{}, err + } + + m.mu.RLock() + defer m.mu.RUnlock() + headerExtensions := make([]RTPHeaderExtensionParameter, 0) + for id, e := range m.negotiatedHeaderExtensions { + if e.isAudio && typ == RTPCodecTypeAudio || e.isVideo && typ == RTPCodecTypeVideo { + headerExtensions = append(headerExtensions, RTPHeaderExtensionParameter{ID: id, URI: e.uri}) + } + } + + return RTPParameters{ + HeaderExtensions: headerExtensions, + Codecs: []RTPCodecParameters{codec}, + }, nil +} + +func payloaderForCodec(codec RTPCodecCapability) (rtp.Payloader, error) { + switch strings.ToLower(codec.MimeType) { + case strings.ToLower(MimeTypeH264): + return &codecs.H264Payloader{}, nil + case strings.ToLower(MimeTypeH265): + return &codecs.H265Payloader{}, nil + case strings.ToLower(MimeTypeOpus): + return &codecs.OpusPayloader{}, nil + case strings.ToLower(MimeTypeVP8): + return &codecs.VP8Payloader{ + EnablePictureID: true, + }, nil + case strings.ToLower(MimeTypeVP9): + return &codecs.VP9Payloader{}, nil + case strings.ToLower(MimeTypeAV1): + return &codecs.AV1Payloader{}, nil + case strings.ToLower(MimeTypeG722): + return &codecs.G722Payloader{}, nil + case strings.ToLower(MimeTypePCMU), strings.ToLower(MimeTypePCMA): + return &codecs.G711Payloader{}, nil + default: + return nil, ErrNoPayloaderForCodec + } +} + +func (m *MediaEngine) isRTXEnabled(typ RTPCodecType, directions []RTPTransceiverDirection) bool { + for _, p := range m.getRTPParametersByKind(typ, directions).Codecs { + if strings.EqualFold(p.MimeType, MimeTypeRTX) { + return true + } + } + + return false +} + +func (m *MediaEngine) isFECEnabled(typ RTPCodecType, directions []RTPTransceiverDirection) bool { + for _, p := range m.getRTPParametersByKind(typ, directions).Codecs { + if strings.Contains(strings.ToLower(p.MimeType), MimeTypeFlexFEC) { + return true + } + } + + return false +} diff --git a/vendor/github.com/pion/webrtc/v4/mimetype.go b/vendor/github.com/pion/webrtc/v4/mimetype.go new file mode 100644 index 0000000..5850153 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/mimetype.go @@ -0,0 +1,46 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package webrtc + +const ( + // MimeTypeH264 H264 MIME type. + // Note: Matching should be case insensitive. + MimeTypeH264 = "video/H264" + // MimeTypeH265 H265 MIME type + // Note: Matching should be case insensitive. + MimeTypeH265 = "video/H265" + // MimeTypeOpus Opus MIME type + // Note: Matching should be case insensitive. + MimeTypeOpus = "audio/opus" + // MimeTypeVP8 VP8 MIME type + // Note: Matching should be case insensitive. + MimeTypeVP8 = "video/VP8" + // MimeTypeVP9 VP9 MIME type + // Note: Matching should be case insensitive. + MimeTypeVP9 = "video/VP9" + // MimeTypeAV1 AV1 MIME type + // Note: Matching should be case insensitive. + MimeTypeAV1 = "video/AV1" + // MimeTypeG722 G722 MIME type + // Note: Matching should be case insensitive. + MimeTypeG722 = "audio/G722" + // MimeTypePCMU PCMU MIME type + // Note: Matching should be case insensitive. + MimeTypePCMU = "audio/PCMU" + // MimeTypePCMA PCMA MIME type + // Note: Matching should be case insensitive. + MimeTypePCMA = "audio/PCMA" + // MimeTypeRTX RTX MIME type + // Note: Matching should be case insensitive. + MimeTypeRTX = "video/rtx" + // MimeTypeFlexFEC FEC MIME Type + // Note: Matching should be case insensitive. + MimeTypeFlexFEC = "video/flexfec" + // MimeTypeFlexFEC03 FlexFEC03 MIME Type + // Note: Matching should be case insensitive. + MimeTypeFlexFEC03 = "video/flexfec-03" + // MimeTypeUlpFEC UlpFEC MIME Type + // Note: Matching should be case insensitive. + MimeTypeUlpFEC = "video/ulpfec" +) diff --git a/vendor/github.com/pion/webrtc/v4/networktype.go b/vendor/github.com/pion/webrtc/v4/networktype.go new file mode 100644 index 0000000..046c723 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/networktype.go @@ -0,0 +1,127 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package webrtc + +import ( + "fmt" + + "github.com/pion/ice/v4" +) + +func supportedNetworkTypes() []NetworkType { + return []NetworkType{ + NetworkTypeUDP4, + NetworkTypeUDP6, + // NetworkTypeTCP4, // Not supported yet + // NetworkTypeTCP6, // Not supported yet + } +} + +// NetworkType represents the type of network. +type NetworkType int + +const ( + // NetworkTypeUnknown is the enum's zero-value. + NetworkTypeUnknown NetworkType = iota + + // NetworkTypeUDP4 indicates UDP over IPv4. + NetworkTypeUDP4 + + // NetworkTypeUDP6 indicates UDP over IPv6. + NetworkTypeUDP6 + + // NetworkTypeTCP4 indicates TCP over IPv4. + NetworkTypeTCP4 + + // NetworkTypeTCP6 indicates TCP over IPv6. + NetworkTypeTCP6 +) + +// This is done this way because of a linter. +const ( + networkTypeUDP4Str = "udp4" + networkTypeUDP6Str = "udp6" + networkTypeTCP4Str = "tcp4" + networkTypeTCP6Str = "tcp6" +) + +func (t NetworkType) String() string { + switch t { + case NetworkTypeUDP4: + return networkTypeUDP4Str + case NetworkTypeUDP6: + return networkTypeUDP6Str + case NetworkTypeTCP4: + return networkTypeTCP4Str + case NetworkTypeTCP6: + return networkTypeTCP6Str + default: + return ErrUnknownType.Error() + } +} + +// Protocol returns udp or tcp. +func (t NetworkType) Protocol() string { //nolint:staticcheck + switch t { + case NetworkTypeUDP4: + return "udp" + case NetworkTypeUDP6: + return "udp" + case NetworkTypeTCP4: + return "tcp" + case NetworkTypeTCP6: + return "tcp" + default: + return ErrUnknownType.Error() + } +} + +// NewNetworkType allows create network type from string +// It will be useful for getting custom network types from external config. +func NewNetworkType(raw string) (NetworkType, error) { + switch raw { + case networkTypeUDP4Str: + return NetworkTypeUDP4, nil + case networkTypeUDP6Str: + return NetworkTypeUDP6, nil + case networkTypeTCP4Str: + return NetworkTypeTCP4, nil + case networkTypeTCP6Str: + return NetworkTypeTCP6, nil + default: + return NetworkTypeUnknown, fmt.Errorf("%w: %s", errNetworkTypeUnknown, raw) + } +} + +func getNetworkType(iceNetworkType ice.NetworkType) (NetworkType, error) { + switch iceNetworkType { + case ice.NetworkTypeUDP4: + return NetworkTypeUDP4, nil + case ice.NetworkTypeUDP6: + return NetworkTypeUDP6, nil + case ice.NetworkTypeTCP4: + return NetworkTypeTCP4, nil + case ice.NetworkTypeTCP6: + return NetworkTypeTCP6, nil + default: + return NetworkTypeUnknown, fmt.Errorf("%w: %s", errNetworkTypeUnknown, iceNetworkType.String()) + } +} + +func toICENetworkTypes(networkTypes []NetworkType) []ice.NetworkType { + if len(networkTypes) == 0 { + return nil + } + + converted := make([]ice.NetworkType, 0, len(networkTypes)) + for _, networkType := range networkTypes { + converted = append(converted, networkType.toICE()) + } + + return converted +} + +func (networkType NetworkType) toICE() ice.NetworkType { + return ice.NetworkType(networkType) +} diff --git a/vendor/github.com/pion/webrtc/v4/oauthcredential.go b/vendor/github.com/pion/webrtc/v4/oauthcredential.go new file mode 100644 index 0000000..03073a6 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/oauthcredential.go @@ -0,0 +1,18 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package webrtc + +// OAuthCredential represents OAuth credential information which is used by +// the STUN/TURN client to connect to an ICE server as defined in +// https://tools.ietf.org/html/rfc7635. Note that the kid parameter is not +// located in OAuthCredential, but in ICEServer's username member. +type OAuthCredential struct { + // MACKey is a base64-url encoded format. It is used in STUN message + // integrity hash calculation. + MACKey string + + // AccessToken is a base64-encoded format. This is an encrypted + // self-contained token that is opaque to the application. + AccessToken string //nolint:gosec // not a secret. +} diff --git a/vendor/github.com/pion/webrtc/v4/offeransweroptions.go b/vendor/github.com/pion/webrtc/v4/offeransweroptions.go new file mode 100644 index 0000000..7eb5ebe --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/offeransweroptions.go @@ -0,0 +1,33 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package webrtc + +// OfferAnswerOptions is a base structure which describes the options that +// can be used to control the offer/answer creation process. +type OfferAnswerOptions struct { + // VoiceActivityDetection allows the application to provide information + // about whether it wishes voice detection feature to be enabled or disabled. + VoiceActivityDetection bool + // ICETricklingSupported indicates whether the ICE agent should use trickle ICE + // If set, the "a=ice-options:trickle" attribute is added to the generated SDP payload. + // (See https://datatracker.ietf.org/doc/html/rfc9725#section-4.3.3) + ICETricklingSupported bool +} + +// AnswerOptions structure describes the options used to control the answer +// creation process. +type AnswerOptions struct { + OfferAnswerOptions +} + +// OfferOptions structure describes the options used to control the offer +// creation process. +type OfferOptions struct { + OfferAnswerOptions + + // ICERestart forces the underlying ice gathering process to be restarted. + // When this value is true, the generated description will have ICE + // credentials that are different from the current credentials + ICERestart bool +} diff --git a/vendor/github.com/pion/webrtc/v4/operations.go b/vendor/github.com/pion/webrtc/v4/operations.go new file mode 100644 index 0000000..fb0717e --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/operations.go @@ -0,0 +1,159 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package webrtc + +import ( + "container/list" + "sync" + "sync/atomic" +) + +// Operation is a function. +type operation func() + +// Operations is a task executor. +type operations struct { + mu sync.Mutex + busyCh chan struct{} + ops *list.List + + updateNegotiationNeededFlagOnEmptyChain *atomic.Bool + onNegotiationNeeded func() + isClosed bool +} + +func newOperations( + updateNegotiationNeededFlagOnEmptyChain *atomic.Bool, + onNegotiationNeeded func(), +) *operations { + return &operations{ + ops: list.New(), + updateNegotiationNeededFlagOnEmptyChain: updateNegotiationNeededFlagOnEmptyChain, + onNegotiationNeeded: onNegotiationNeeded, + } +} + +// Enqueue adds a new action to be executed. If there are no actions scheduled, +// the execution will start immediately in a new goroutine. If the queue has been +// closed, the operation will be dropped. The queue is only deliberately closed +// by a user. +func (o *operations) Enqueue(op operation) { + o.mu.Lock() + defer o.mu.Unlock() + _ = o.tryEnqueue(op) +} + +// tryEnqueue attempts to enqueue the given operation. It returns false +// if the op is invalid or the queue is closed. mu must be locked by +// tryEnqueue's caller. +func (o *operations) tryEnqueue(op operation) bool { + if op == nil { + return false + } + + if o.isClosed { + return false + } + o.ops.PushBack(op) + + if o.busyCh == nil { + o.busyCh = make(chan struct{}) + go o.start() + } + + return true +} + +// IsEmpty checks if there are tasks in the queue. +func (o *operations) IsEmpty() bool { + o.mu.Lock() + defer o.mu.Unlock() + + return o.ops.Len() == 0 +} + +// Done blocks until all currently enqueued operations are finished executing. +// For more complex synchronization, use Enqueue directly. +func (o *operations) Done() { + var wg sync.WaitGroup + wg.Add(1) + o.mu.Lock() + enqueued := o.tryEnqueue(func() { + wg.Done() + }) + o.mu.Unlock() + if !enqueued { + return + } + wg.Wait() +} + +// GracefulClose waits for the operations queue to be cleared and forbids +// new operations from being enqueued. +func (o *operations) GracefulClose() { + o.mu.Lock() + if o.isClosed { + o.mu.Unlock() + + return + } + // do not enqueue anymore ops from here on + // o.isClosed=true will also not allow a new busyCh + // to be created. + o.isClosed = true + + busyCh := o.busyCh + o.mu.Unlock() + if busyCh == nil { + return + } + <-busyCh +} + +func (o *operations) pop() func() { + o.mu.Lock() + defer o.mu.Unlock() + if o.ops.Len() == 0 { + return nil + } + + e := o.ops.Front() + o.ops.Remove(e) + if op, ok := e.Value.(operation); ok { + return op + } + + return nil +} + +func (o *operations) start() { + defer func() { + o.mu.Lock() + defer o.mu.Unlock() + // this wil lbe the most recent busy chan + close(o.busyCh) + + if o.ops.Len() == 0 || o.isClosed { + o.busyCh = nil + + return + } + + // either a new operation was enqueued while we + // were busy, or an operation panicked + o.busyCh = make(chan struct{}) + go o.start() + }() + + fn := o.pop() + for fn != nil { + fn() + fn = o.pop() + } + if !o.updateNegotiationNeededFlagOnEmptyChain.Load() { + return + } + o.updateNegotiationNeededFlagOnEmptyChain.Store(false) + o.onNegotiationNeeded() +} diff --git a/vendor/github.com/pion/webrtc/v4/package.json b/vendor/github.com/pion/webrtc/v4/package.json new file mode 100644 index 0000000..a7b35c7 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/package.json @@ -0,0 +1,12 @@ +{ + "name": "webrtc", + "repository": "git@github.com:pion/webrtc.git", + "private": true, + "devDependencies": { + "@roamhq/wrtc": "^0.10.0" + }, + "dependencies": { + "request": "2.88.2" + }, + "packageManager": "yarn@1.22.22+sha512.a6b2f7906b721bba3d67d4aff083df04dad64c399707841b7acf00f6b133b7ac24255f2652fa22ae3534329dc6180534e98d17432037ff6fd140556e2bb3137e" +} diff --git a/vendor/github.com/pion/webrtc/v4/peerconnection.go b/vendor/github.com/pion/webrtc/v4/peerconnection.go new file mode 100644 index 0000000..acb5b5b --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/peerconnection.go @@ -0,0 +1,3129 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +//go:build !js + +package webrtc + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "errors" + "fmt" + "slices" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/pion/ice/v4" + "github.com/pion/interceptor" + "github.com/pion/interceptor/pkg/stats" + "github.com/pion/logging" + "github.com/pion/rtcp" + "github.com/pion/sdp/v3" + "github.com/pion/srtp/v3" + "github.com/pion/webrtc/v4/internal/util" + "github.com/pion/webrtc/v4/pkg/rtcerr" +) + +// PeerConnection represents a WebRTC connection that establishes a +// peer-to-peer communications with another PeerConnection instance in a +// browser, or to another endpoint implementing the required protocols. +type PeerConnection struct { + id string + mu sync.RWMutex + + sdpOrigin sdp.Origin + + // ops is an operations queue which will ensure the enqueued actions are + // executed in order. It is used for asynchronously, but serially processing + // remote and local descriptions + ops *operations + + configuration Configuration + + currentLocalDescription *SessionDescription + pendingLocalDescription *SessionDescription + currentRemoteDescription *SessionDescription + pendingRemoteDescription *SessionDescription + signalingState SignalingState + iceConnectionState atomic.Value // ICEConnectionState + connectionState atomic.Value // PeerConnectionState + + idpLoginURL *string + + isClosed *atomic.Bool + isGracefullyClosingOrClosed bool + isCloseDone chan struct{} + isGracefulCloseDone chan struct{} + isNegotiationNeeded *atomic.Bool + updateNegotiationNeededFlagOnEmptyChain *atomic.Bool + + lastOffer string + lastAnswer string + // Whether the remote endpoint can accept trickled ICE candidates. + canTrickleICECandidates ICETrickleCapability + + // a value containing the last known greater mid value + // we internally generate mids as numbers. Needed since JSEP + // requires that when reusing a media section a new unique mid + // should be defined (see JSEP 3.4.1). + greaterMid int + + rtpTransceivers []*RTPTransceiver + nonMediaBandwidthProbe atomic.Value // RTPReceiver + + onSignalingStateChangeHandler func(SignalingState) + onICEConnectionStateChangeHandler atomic.Value // func(ICEConnectionState) + onConnectionStateChangeHandler atomic.Value // func(PeerConnectionState) + onTrackHandler func(*TrackRemote, *RTPReceiver) + onDataChannelHandler func(*DataChannel) + onNegotiationNeededHandler atomic.Value // func() + + iceGatherer *ICEGatherer + iceTransport *ICETransport + dtlsTransport *DTLSTransport + sctpTransport *SCTPTransport + + // A reference to the associated API state used by this connection + api *API + log logging.LeveledLogger + + interceptorRTCPWriter interceptor.RTCPWriter + statsGetter stats.Getter +} + +// NewPeerConnection creates a PeerConnection with the default codecs and interceptors. +// +// If you wish to customize the set of available codecs and/or the set of active interceptors, +// create an API with a custom MediaEngine and/or interceptor.Registry, +// then call [(*API).NewPeerConnection] instead of this function. +func NewPeerConnection(configuration Configuration) (*PeerConnection, error) { + api := NewAPI() + + return api.NewPeerConnection(configuration) +} + +// NewPeerConnection creates a new PeerConnection with the provided configuration against the received API object. +// This method will attach a default set of codecs and interceptors to +// the resulting PeerConnection. If this behavior is not desired, +// set the set of codecs and interceptors explicitly by using +// [WithMediaEngine] and [WithInterceptorRegistry] when calling [NewAPI]. +func (api *API) NewPeerConnection(configuration Configuration) (*PeerConnection, error) { + // https://w3c.github.io/webrtc-pc/#constructor (Step #2) + // Some variables defined explicitly despite their implicit zero values to + // allow better readability to understand what is happening. + + pc := &PeerConnection{ + id: fmt.Sprintf("PeerConnection-%d", time.Now().UnixNano()), + configuration: Configuration{ + ICEServers: []ICEServer{}, + ICETransportPolicy: ICETransportPolicyAll, + BundlePolicy: BundlePolicyBalanced, + RTCPMuxPolicy: RTCPMuxPolicyRequire, + Certificates: []Certificate{}, + ICECandidatePoolSize: 0, + }, + isClosed: &atomic.Bool{}, + isCloseDone: make(chan struct{}), + isGracefulCloseDone: make(chan struct{}), + isNegotiationNeeded: &atomic.Bool{}, + updateNegotiationNeededFlagOnEmptyChain: &atomic.Bool{}, + lastOffer: "", + lastAnswer: "", + greaterMid: -1, + signalingState: SignalingStateStable, + + api: api, + log: api.settingEngine.LoggerFactory.NewLogger("pc"), + } + pc.ops = newOperations(pc.updateNegotiationNeededFlagOnEmptyChain, pc.onNegotiationNeeded) + + pc.iceConnectionState.Store(ICEConnectionStateNew) + pc.connectionState.Store(PeerConnectionStateNew) + + i, err := api.interceptorRegistry.Build(pc.id) + if err != nil { + return nil, err + } + + if getter, ok := lookupStats(pc.id); ok { + pc.statsGetter = getter + } + + pc.api = &API{ + settingEngine: api.settingEngine, + interceptor: i, + } + + if api.settingEngine.disableMediaEngineCopy { + pc.api.mediaEngine = api.mediaEngine + } else { + pc.api.mediaEngine = api.mediaEngine.copy() + pc.api.mediaEngine.setMultiCodecNegotiation(!api.settingEngine.disableMediaEngineMultipleCodecs) + } + + if err = pc.initConfiguration(configuration); err != nil { + return nil, err + } + + pc.iceGatherer, err = pc.createICEGatherer() + if err != nil { + return nil, err + } + + // Create the ice transport + iceTransport := pc.createICETransport() + pc.iceTransport = iceTransport + + // Create the DTLS transport + dtlsTransport, err := pc.api.NewDTLSTransport(pc.iceTransport, pc.configuration.Certificates) + if err != nil { + return nil, err + } + pc.dtlsTransport = dtlsTransport + + // Create the SCTP transport + pc.sctpTransport = pc.api.NewSCTPTransport(pc.dtlsTransport) + + // Wire up the on datachannel handler + pc.sctpTransport.OnDataChannel(func(d *DataChannel) { + pc.mu.RLock() + handler := pc.onDataChannelHandler + pc.mu.RUnlock() + if handler != nil { + handler(d) + } + }) + + if pc.configuration.ICECandidatePoolSize > 0 { + if err := pc.iceGatherer.Gather(); err != nil { + return nil, err + } + } + + pc.interceptorRTCPWriter = pc.api.interceptor.BindRTCPWriter(interceptor.RTCPWriterFunc(pc.writeRTCP)) + + return pc, nil +} + +// initConfiguration defines validation of the specified Configuration and +// its assignment to the internal configuration variable. This function differs +// from its SetConfiguration counterpart because most of the checks do not +// include verification statements related to the existing state. Thus the +// function describes only minor verification of some the struct variables. +func (pc *PeerConnection) initConfiguration(configuration Configuration) error { //nolint:cyclop + if configuration.PeerIdentity != "" { + pc.configuration.PeerIdentity = configuration.PeerIdentity + } + + // https://www.w3.org/TR/webrtc/#constructor (step #3) + if len(configuration.Certificates) > 0 { + now := time.Now() + for _, x509Cert := range configuration.Certificates { + if !x509Cert.Expires().IsZero() && now.After(x509Cert.Expires()) { + return &rtcerr.InvalidAccessError{Err: ErrCertificateExpired} + } + pc.configuration.Certificates = append(pc.configuration.Certificates, x509Cert) + } + } else { + sk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return &rtcerr.UnknownError{Err: err} + } + certificate, err := GenerateCertificate(sk) + if err != nil { + return err + } + pc.configuration.Certificates = []Certificate{*certificate} + } + + if configuration.BundlePolicy != BundlePolicyUnknown { + pc.configuration.BundlePolicy = configuration.BundlePolicy + } + + if configuration.RTCPMuxPolicy != RTCPMuxPolicyUnknown { + pc.configuration.RTCPMuxPolicy = configuration.RTCPMuxPolicy + } + + if configuration.ICECandidatePoolSize != 0 { + // Issue #2892, ice candidate pool size greater than 1 is not supported + if configuration.ICECandidatePoolSize > 1 { + return &rtcerr.NotSupportedError{Err: errICECandidatePoolSizeTooLarge} + } + + pc.configuration.ICECandidatePoolSize = configuration.ICECandidatePoolSize + } + + pc.configuration.ICETransportPolicy = configuration.ICETransportPolicy + pc.configuration.SDPSemantics = configuration.SDPSemantics + pc.configuration.AlwaysNegotiateDataChannels = configuration.AlwaysNegotiateDataChannels + + sanitizedICEServers := configuration.getICEServers() + if len(sanitizedICEServers) > 0 { + for _, server := range sanitizedICEServers { + if err := server.validate(); err != nil { + return err + } + } + pc.configuration.ICEServers = sanitizedICEServers + } + + return nil +} + +// OnSignalingStateChange sets an event handler which is invoked when the +// peer connection's signaling state changes. +func (pc *PeerConnection) OnSignalingStateChange(f func(SignalingState)) { + pc.mu.Lock() + defer pc.mu.Unlock() + pc.onSignalingStateChangeHandler = f +} + +func (pc *PeerConnection) onSignalingStateChange(newState SignalingState) { + pc.mu.RLock() + handler := pc.onSignalingStateChangeHandler + pc.mu.RUnlock() + + pc.log.Infof("signaling state changed to %s", newState) + if handler != nil { + go handler(newState) + } +} + +// OnDataChannel sets an event handler which is invoked when a data +// channel message arrives from a remote peer. +func (pc *PeerConnection) OnDataChannel(f func(*DataChannel)) { + pc.mu.Lock() + defer pc.mu.Unlock() + pc.onDataChannelHandler = f +} + +// OnNegotiationNeeded sets an event handler which is invoked when +// a change has occurred which requires session negotiation. +func (pc *PeerConnection) OnNegotiationNeeded(f func()) { + pc.onNegotiationNeededHandler.Store(f) +} + +// onNegotiationNeeded enqueues negotiationNeededOp if necessary +// caller of this method should hold `pc.mu` lock +// https://www.w3.org/TR/webrtc/#dfn-update-the-negotiation-needed-flag +func (pc *PeerConnection) onNegotiationNeeded() { + // 4.7.3.1 If the length of connection.[[Operations]] is not 0, then set + // connection.[[UpdateNegotiationNeededFlagOnEmptyChain]] to true, and abort these steps. + if !pc.ops.IsEmpty() { + pc.updateNegotiationNeededFlagOnEmptyChain.Store(true) + + return + } + pc.ops.Enqueue(pc.negotiationNeededOp) +} + +// https://www.w3.org/TR/webrtc/#dfn-update-the-negotiation-needed-flag +func (pc *PeerConnection) negotiationNeededOp() { + // 4.7.3.2.1 If connection.[[IsClosed]] is true, abort these steps. + if pc.isClosed.Load() { + return + } + + // 4.7.3.2.2 If the length of connection.[[Operations]] is not 0, + // then set connection.[[UpdateNegotiationNeededFlagOnEmptyChain]] to + // true, and abort these steps. + if !pc.ops.IsEmpty() { + pc.updateNegotiationNeededFlagOnEmptyChain.Store(true) + + return + } + + // 4.7.3.2.3 If connection's signaling state is not "stable", abort these steps. + if pc.SignalingState() != SignalingStateStable { + return + } + + // 4.7.3.2.4 If the result of checking if negotiation is needed is false, + // clear the negotiation-needed flag by setting connection.[[NegotiationNeeded]] + // to false, and abort these steps. + if !pc.checkNegotiationNeeded() { + pc.isNegotiationNeeded.Store(false) + + return + } + + // 4.7.3.2.5 If connection.[[NegotiationNeeded]] is already true, abort these steps. + if pc.isNegotiationNeeded.Load() { + return + } + + // 4.7.3.2.6 Set connection.[[NegotiationNeeded]] to true. + pc.isNegotiationNeeded.Store(true) + + // 4.7.3.2.7 Fire an event named negotiationneeded at connection. + if handler, ok := pc.onNegotiationNeededHandler.Load().(func()); ok && handler != nil { + handler() + } +} + +func (pc *PeerConnection) checkNegotiationNeeded() bool { //nolint:gocognit,cyclop + // To check if negotiation is needed for connection, perform the following checks: + // Skip 1, 2 steps + // Step 3 + pc.mu.Lock() + defer pc.mu.Unlock() + + localDesc := pc.currentLocalDescription + remoteDesc := pc.currentRemoteDescription + + if localDesc == nil { + return true + } + + pc.sctpTransport.lock.Lock() + lenDataChannel := len(pc.sctpTransport.dataChannels) + pc.sctpTransport.lock.Unlock() + + if lenDataChannel != 0 && haveDataChannel(localDesc) == nil { + return true + } + + for _, transceiver := range pc.rtpTransceivers { + // https://www.w3.org/TR/webrtc/#dfn-update-the-negotiation-needed-flag + // Step 5.1 + // if t.stopping && !t.stopped { + // return true + // } + mid := getByMid(transceiver.Mid(), localDesc) + + // Step 5.2 + if mid == nil { + return true + } + + // Step 5.3.1 + if transceiver.Direction() == RTPTransceiverDirectionSendrecv || + transceiver.Direction() == RTPTransceiverDirectionSendonly { + descMsid, okMsid := mid.Attribute(sdp.AttrKeyMsid) + sender := transceiver.Sender() + if sender == nil { + return true + } + track := sender.Track() + if track == nil { + // Situation when sender's track is nil could happen when + // a) replaceTrack(nil) is called + // b) removeTrack() is called, changing the transceiver's direction to inactive + // As t.Direction() in this branch is either sendrecv or sendonly, we believe (a) option is the case + // As calling replaceTrack does not require renegotiation, we skip check for this transceiver + continue + } + if !okMsid || descMsid != track.StreamID()+" "+track.ID() { + return true + } + } + switch localDesc.Type { + case SDPTypeOffer: + // Step 5.3.2 + rm := getByMid(transceiver.Mid(), remoteDesc) + if rm == nil { + return true + } + + if getPeerDirection(mid) != transceiver.Direction() && getPeerDirection(rm) != transceiver.Direction().Revers() { + return true + } + case SDPTypeAnswer: + // Step 5.3.3 + if _, ok := mid.Attribute(transceiver.Direction().String()); !ok { + return true + } + default: + } + + // Step 5.4 + // if t.stopped && t.Mid() != "" { + // if getByMid(t.Mid(), localDesc) != nil || getByMid(t.Mid(), remoteDesc) != nil { + // return true + // } + // } + } + // Step 6 + return false +} + +// OnICECandidate sets an event handler which is invoked when a new ICE +// candidate is found. +// ICE candidate gathering only begins when SetLocalDescription or +// SetRemoteDescription is called. +// Take note that the handler will be called with a nil pointer when +// gathering is finished. +func (pc *PeerConnection) OnICECandidate(f func(*ICECandidate)) { + pc.iceGatherer.OnLocalCandidate(f) +} + +// OnICEGatheringStateChange sets an event handler which is invoked when the +// ICE candidate gathering state has changed. +func (pc *PeerConnection) OnICEGatheringStateChange(f func(ICEGatheringState)) { + pc.iceGatherer.OnStateChange( + func(gathererState ICEGathererState) { + switch gathererState { + case ICEGathererStateGathering: + f(ICEGatheringStateGathering) + case ICEGathererStateComplete: + f(ICEGatheringStateComplete) + default: + // Other states ignored + } + }) +} + +// OnTrack sets an event handler which is called when remote track +// arrives from a remote peer. +func (pc *PeerConnection) OnTrack(f func(*TrackRemote, *RTPReceiver)) { + pc.mu.Lock() + defer pc.mu.Unlock() + pc.onTrackHandler = f +} + +func (pc *PeerConnection) onTrack(t *TrackRemote, r *RTPReceiver) { + pc.mu.RLock() + handler := pc.onTrackHandler + pc.mu.RUnlock() + + pc.log.Debugf("got new track: %+v", t) + if t != nil { + if handler != nil { + go handler(t, r) + } else { + pc.log.Warnf("OnTrack unset, unable to handle incoming media streams") + } + } +} + +// OnICEConnectionStateChange sets an event handler which is called +// when an ICE connection state is changed. +func (pc *PeerConnection) OnICEConnectionStateChange(f func(ICEConnectionState)) { + pc.onICEConnectionStateChangeHandler.Store(f) +} + +func (pc *PeerConnection) onICEConnectionStateChange(cs ICEConnectionState) { + pc.iceConnectionState.Store(cs) + pc.log.Infof("ICE connection state changed: %s", cs) + if handler, ok := pc.onICEConnectionStateChangeHandler.Load().(func(ICEConnectionState)); ok && handler != nil { + handler(cs) + } +} + +// OnConnectionStateChange sets an event handler which is called +// when the PeerConnectionState has changed. +func (pc *PeerConnection) OnConnectionStateChange(f func(PeerConnectionState)) { + pc.onConnectionStateChangeHandler.Store(f) +} + +func (pc *PeerConnection) onConnectionStateChange(cs PeerConnectionState) { + pc.connectionState.Store(cs) + pc.log.Infof("peer connection state changed: %s", cs) + if handler, ok := pc.onConnectionStateChangeHandler.Load().(func(PeerConnectionState)); ok && handler != nil { + go handler(cs) + } +} + +// SetConfiguration updates the configuration of this PeerConnection object. +// https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-setconfiguration +func (pc *PeerConnection) SetConfiguration(configuration Configuration) error { //nolint:gocognit,cyclop + // https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-setconfiguration (step #2) + if pc.isClosed.Load() { + return &rtcerr.InvalidStateError{Err: ErrConnectionClosed} + } + + // Not in W3C spec, but we validate PeerIdentity cannot be modified. + if configuration.PeerIdentity != "" { + if configuration.PeerIdentity != pc.configuration.PeerIdentity { + return &rtcerr.InvalidModificationError{Err: ErrModifyingPeerIdentity} + } + pc.configuration.PeerIdentity = configuration.PeerIdentity + } + + // https://www.w3.org/TR/webrtc/#set-the-configuration (step #3.1 - #3.3) + if len(configuration.Certificates) > 0 { + if len(configuration.Certificates) != len(pc.configuration.Certificates) { + return &rtcerr.InvalidModificationError{Err: ErrModifyingCertificates} + } + + for i, certificate := range configuration.Certificates { + if !pc.configuration.Certificates[i].Equals(certificate) { + return &rtcerr.InvalidModificationError{Err: ErrModifyingCertificates} + } + } + pc.configuration.Certificates = configuration.Certificates + } + + // https://www.w3.org/TR/webrtc/#set-the-configuration (step #3.4) + if configuration.BundlePolicy != BundlePolicyUnknown { + if configuration.BundlePolicy != pc.configuration.BundlePolicy { + return &rtcerr.InvalidModificationError{Err: ErrModifyingBundlePolicy} + } + pc.configuration.BundlePolicy = configuration.BundlePolicy + } + + // https://www.w3.org/TR/webrtc/#set-the-configuration (step #3.5) + if configuration.RTCPMuxPolicy != RTCPMuxPolicyUnknown { + if configuration.RTCPMuxPolicy != pc.configuration.RTCPMuxPolicy { + return &rtcerr.InvalidModificationError{Err: ErrModifyingRTCPMuxPolicy} + } + pc.configuration.RTCPMuxPolicy = configuration.RTCPMuxPolicy + } + + // https://www.w3.org/TR/webrtc/#set-the-configuration (step #3.6) + if configuration.ICECandidatePoolSize != 0 { + if pc.configuration.ICECandidatePoolSize != configuration.ICECandidatePoolSize && + pc.LocalDescription() != nil { + return &rtcerr.InvalidModificationError{Err: ErrModifyingICECandidatePoolSize} + } + + // Currently, there is no logic implemented to handle runtime changes to this value. + // Commenting out to prevent unexpected behavior. + // nolint:godox + // TODO: Re-enable this in a future update when proper handling is implemented. + // pc.configuration.ICECandidatePoolSize = configuration.ICECandidatePoolSize + pc.log.Warn("Changing ICECandidatePoolSize is not yet supported. The new value will be ignored.") + } + + // https://www.w3.org/TR/webrtc/#set-the-configuration (step #4-6) + for _, server := range configuration.ICEServers { + if err := server.validate(); err != nil { + return err + } + } + + // https://www.w3.org/TR/webrtc/#set-the-configuration (step #7) + pc.configuration.ICETransportPolicy = configuration.ICETransportPolicy + + // AlwaysNegotiateDataChannels is treated like other zero-value configuration + // fields: only a non-zero value (true) updates the existing setting. + if configuration.AlwaysNegotiateDataChannels { + pc.configuration.AlwaysNegotiateDataChannels = configuration.AlwaysNegotiateDataChannels + } + + // https://www.w3.org/TR/webrtc/#set-the-configuration (step #8) + // nolint:godox + // TODO: If the new ICE candidate pool size changes the existing setting, + // this may result in immediate gathering of new pooled candidates, + // or discarding of existing pooled candidates + if pc.configuration.ICECandidatePoolSize != configuration.ICECandidatePoolSize { + pc.log.Warn("Dynamic ICE candidate pool adjustment is not yet supported") + } + + // https://www.w3.org/TR/webrtc/#set-the-configuration (step #9) + // Update the ICE gatherer so new servers take effect at the next gathering phase. + if pc.iceGatherer != nil { + if err := pc.iceGatherer.updateServers(configuration.ICEServers, pc.configuration.ICETransportPolicy); err != nil { + pc.log.Debugf("Could not update ICE gatherer servers: %v", err) + } + } + + pc.configuration.ICEServers = configuration.ICEServers + + return nil +} + +// GetConfiguration returns a Configuration object representing the current +// configuration of this PeerConnection object. The returned object is a +// copy and direct mutation on it will not take affect until SetConfiguration +// has been called with Configuration passed as its only argument. +// https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-getconfiguration +func (pc *PeerConnection) GetConfiguration() Configuration { + return pc.configuration +} + +func (pc *PeerConnection) ID() string { + pc.mu.RLock() + defer pc.mu.RUnlock() + + return pc.id +} + +// hasLocalDescriptionChanged returns whether local media (rtpTransceivers) has changed +// caller of this method should hold `pc.mu` lock. +func (pc *PeerConnection) hasLocalDescriptionChanged(desc *SessionDescription) bool { + for _, t := range pc.rtpTransceivers { + m := getByMid(t.Mid(), desc) + if m == nil { + return true + } + + if getPeerDirection(m) != t.Direction() { + return true + } + } + + return false +} + +// CreateOffer starts the PeerConnection and generates the localDescription +// https://w3c.github.io/webrtc-pc/#dom-rtcpeerconnection-createoffer +// +//nolint:gocognit,cyclop +func (pc *PeerConnection) CreateOffer(options *OfferOptions) (SessionDescription, error) { + useIdentity := pc.idpLoginURL != nil + switch { + case useIdentity: + return SessionDescription{}, errIdentityProviderNotImplemented + case pc.isClosed.Load(): + return SessionDescription{}, &rtcerr.InvalidStateError{Err: ErrConnectionClosed} + } + + if options != nil && options.ICERestart { + if err := pc.iceTransport.restart(); err != nil { + return SessionDescription{}, err + } + } + + var ( + descr *sdp.SessionDescription + offer SessionDescription + err error + ) + + // This may be necessary to recompute if, for example, createOffer was called when only an + // audio RTCRtpTransceiver was added to connection, but while performing the in-parallel + // steps to create an offer, a video RTCRtpTransceiver was added, requiring additional + // inspection of video system resources. + count := 0 + pc.mu.Lock() + defer pc.mu.Unlock() + for { + // We cache current transceivers to ensure they aren't + // mutated during offer generation. We later check if they have + // been mutated and recompute the offer if necessary. + currentTransceivers := pc.rtpTransceivers + + // in-parallel steps to create an offer + // https://w3c.github.io/webrtc-pc/#dfn-in-parallel-steps-to-create-an-offer + isPlanB := pc.configuration.SDPSemantics == SDPSemanticsPlanB + if pc.currentRemoteDescription != nil && isPlanB { + isPlanB = descriptionPossiblyPlanB(pc.currentRemoteDescription) + } + + // include unmatched local transceivers + if !isPlanB { //nolint:nestif + // update the greater mid if the remote description provides a greater one + if pc.currentRemoteDescription != nil { + var numericMid int + for _, media := range pc.currentRemoteDescription.parsed.MediaDescriptions { + mid := getMidValue(media) + if mid == "" { + continue + } + numericMid, err = strconv.Atoi(mid) + if err != nil { + continue + } + if numericMid > pc.greaterMid { + pc.greaterMid = numericMid + } + } + } + for _, t := range currentTransceivers { + if mid := t.Mid(); mid != "" { + numericMid, errMid := strconv.Atoi(mid) + if errMid == nil { + if numericMid > pc.greaterMid { + pc.greaterMid = numericMid + } + } + + continue + } + pc.greaterMid++ + err = t.SetMid(strconv.Itoa(pc.greaterMid)) + if err != nil { + return SessionDescription{}, err + } + } + } + + if pc.currentRemoteDescription == nil { + descr, err = pc.generateUnmatchedSDP(currentTransceivers, useIdentity) + } else { + descr, err = pc.generateMatchedSDP( + currentTransceivers, + useIdentity, + true, /*includeUnmatched */ + connectionRoleFromDtlsRole(defaultDtlsRoleOffer), + false, + ) + } + + if err != nil { + return SessionDescription{}, err + } + + if options != nil && options.ICETricklingSupported { + descr.WithICETrickleAdvertised() + } + if pc.api.settingEngine.renomination.enabled { + descr.WithICERenomination() + } + + updateSDPOrigin(&pc.sdpOrigin, descr) + sdpBytes, err := descr.Marshal() + if err != nil { + return SessionDescription{}, err + } + + offer = SessionDescription{ + Type: SDPTypeOffer, + SDP: string(sdpBytes), + parsed: descr, + } + + // Verify local media hasn't changed during offer + // generation. Recompute if necessary + if isPlanB || !pc.hasLocalDescriptionChanged(&offer) { + break + } + count++ + if count >= 128 { + return SessionDescription{}, errExcessiveRetries + } + } + + pc.lastOffer = offer.SDP + + return offer, nil +} + +func (pc *PeerConnection) createICEGatherer() (*ICEGatherer, error) { + g, err := pc.api.NewICEGatherer(ICEGatherOptions{ + ICEServers: pc.configuration.getICEServers(), + ICEGatherPolicy: pc.configuration.ICETransportPolicy, + ICECandidatePoolSize: pc.configuration.ICECandidatePoolSize, + }) + if err != nil { + return nil, err + } + + return g, nil +} + +// Update the PeerConnectionState given the state of relevant transports +// https://www.w3.org/TR/webrtc/#rtcpeerconnectionstate-enum +// +//nolint:cyclop +func (pc *PeerConnection) updateConnectionState( + iceConnectionState ICEConnectionState, + dtlsTransportState DTLSTransportState, +) { + connectionState := PeerConnectionStateNew + switch { + // The RTCPeerConnection object's [[IsClosed]] slot is true. + case pc.isClosed.Load(): + connectionState = PeerConnectionStateClosed + + // Any of the RTCIceTransports or RTCDtlsTransports are in a "failed" state. + case iceConnectionState == ICEConnectionStateFailed || dtlsTransportState == DTLSTransportStateFailed: + connectionState = PeerConnectionStateFailed + + // Any of the RTCIceTransports or RTCDtlsTransports are in the "disconnected" + // state and none of them are in the "failed" or "connecting" or "checking" state. */ + case iceConnectionState == ICEConnectionStateDisconnected: + connectionState = PeerConnectionStateDisconnected + + // None of the previous states apply and all RTCIceTransports are in the "new" or "closed" state, + // and all RTCDtlsTransports are in the "new" or "closed" state, or there are no transports. + case (iceConnectionState == ICEConnectionStateNew || iceConnectionState == ICEConnectionStateClosed) && + (dtlsTransportState == DTLSTransportStateNew || dtlsTransportState == DTLSTransportStateClosed): + connectionState = PeerConnectionStateNew + + // None of the previous states apply and any RTCIceTransport is in the "new" or "checking" state or + // any RTCDtlsTransport is in the "new" or "connecting" state. + case (iceConnectionState == ICEConnectionStateNew || iceConnectionState == ICEConnectionStateChecking) || + (dtlsTransportState == DTLSTransportStateNew || dtlsTransportState == DTLSTransportStateConnecting): + connectionState = PeerConnectionStateConnecting + + // All RTCIceTransports and RTCDtlsTransports are in the "connected", "completed" or "closed" + // state and all RTCDtlsTransports are in the "connected" or "closed" state. + case (iceConnectionState == ICEConnectionStateConnected || + iceConnectionState == ICEConnectionStateCompleted || iceConnectionState == ICEConnectionStateClosed) && + (dtlsTransportState == DTLSTransportStateConnected || dtlsTransportState == DTLSTransportStateClosed): + connectionState = PeerConnectionStateConnected + } + + if pc.connectionState.Load() == connectionState { + return + } + + pc.onConnectionStateChange(connectionState) +} + +func (pc *PeerConnection) createICETransport() *ICETransport { + transport := pc.api.NewICETransport(pc.iceGatherer) + transport.internalOnConnectionStateChangeHandler.Store(func(state ICETransportState) { + var cs ICEConnectionState + switch state { + case ICETransportStateNew: + cs = ICEConnectionStateNew + case ICETransportStateChecking: + cs = ICEConnectionStateChecking + case ICETransportStateConnected: + cs = ICEConnectionStateConnected + case ICETransportStateCompleted: + cs = ICEConnectionStateCompleted + case ICETransportStateFailed: + cs = ICEConnectionStateFailed + case ICETransportStateDisconnected: + cs = ICEConnectionStateDisconnected + case ICETransportStateClosed: + cs = ICEConnectionStateClosed + default: + pc.log.Warnf("OnConnectionStateChange: unhandled ICE state: %s", state) + + return + } + pc.onICEConnectionStateChange(cs) + pc.updateConnectionState(cs, pc.dtlsTransport.State()) + }) + + return transport +} + +// CreateAnswer starts the PeerConnection and generates the localDescription. +// +//nolint:cyclop +func (pc *PeerConnection) CreateAnswer(options *AnswerOptions) (SessionDescription, error) { + useIdentity := pc.idpLoginURL != nil + remoteDesc := pc.RemoteDescription() + switch { + case remoteDesc == nil: + return SessionDescription{}, &rtcerr.InvalidStateError{Err: ErrNoRemoteDescription} + case useIdentity: + return SessionDescription{}, errIdentityProviderNotImplemented + case pc.isClosed.Load(): + return SessionDescription{}, &rtcerr.InvalidStateError{Err: ErrConnectionClosed} + case pc.signalingState.Get() != SignalingStateHaveRemoteOffer && + pc.signalingState.Get() != SignalingStateHaveLocalPranswer: + return SessionDescription{}, &rtcerr.InvalidStateError{Err: ErrIncorrectSignalingState} + } + + connectionRole := connectionRoleFromDtlsRole(pc.api.settingEngine.answeringDTLSRole) + if connectionRole == sdp.ConnectionRole(0) { + dtlsRole := dtlsRoleFromSDP(remoteDesc.parsed) + switch dtlsRole { + case DTLSRoleClient: + connectionRole = connectionRoleFromDtlsRole(DTLSRoleServer) + case DTLSRoleServer: + connectionRole = connectionRoleFromDtlsRole(DTLSRoleClient) + default: + connectionRole = connectionRoleFromDtlsRole(defaultDtlsRoleAnswer) + } + + // If one of the agents is lite and the other one is not, the lite agent must be the controlled agent. + // If both or neither agents are lite the offering agent is controlling. + // RFC 8445 S6.1.1 + if isIceLiteSet(remoteDesc.parsed) && !pc.api.settingEngine.candidates.ICELite { + connectionRole = connectionRoleFromDtlsRole(DTLSRoleServer) + } + } + pc.mu.Lock() + defer pc.mu.Unlock() + + descr, err := pc.generateMatchedSDP( + pc.rtpTransceivers, + useIdentity, + false, /*includeUnmatched */ + connectionRole, + pc.api.settingEngine.ignoreRidPauseForRecv, + ) + if err != nil { + return SessionDescription{}, err + } + + if options != nil && options.ICETricklingSupported { + descr.WithICETrickleAdvertised() + } + if pc.api.settingEngine.renomination.enabled { + descr.WithICERenomination() + } + + updateSDPOrigin(&pc.sdpOrigin, descr) + sdpBytes, err := descr.Marshal() + if err != nil { + return SessionDescription{}, err + } + + desc := SessionDescription{ + Type: SDPTypeAnswer, + SDP: string(sdpBytes), + parsed: descr, + } + pc.lastAnswer = desc.SDP + + return desc, nil +} + +// 4.4.1.6 Set the SessionDescription +// +//nolint:gocognit,cyclop +func (pc *PeerConnection) setDescription(sd *SessionDescription, op stateChangeOp) error { + switch { + case pc.isClosed.Load(): + return &rtcerr.InvalidStateError{Err: ErrConnectionClosed} + case NewSDPType(sd.Type.String()) == SDPTypeUnknown: + return &rtcerr.TypeError{ + Err: fmt.Errorf("%w: '%d' is not a valid enum value of type SDPType", errPeerConnSDPTypeInvalidValue, sd.Type), + } + } + + nextState, err := func() (SignalingState, error) { + pc.mu.Lock() + defer pc.mu.Unlock() + + cur := pc.SignalingState() + setLocal := stateChangeOpSetLocal + setRemote := stateChangeOpSetRemote + newSDPDoesNotMatchOffer := &rtcerr.InvalidModificationError{Err: errSDPDoesNotMatchOffer} + newSDPDoesNotMatchAnswer := &rtcerr.InvalidModificationError{Err: errSDPDoesNotMatchAnswer} + + var nextState SignalingState + var err error + switch op { + case setLocal: + switch sd.Type { + // stable->SetLocal(offer)->have-local-offer + case SDPTypeOffer: + if sd.SDP != pc.lastOffer { + return nextState, newSDPDoesNotMatchOffer + } + nextState, err = checkNextSignalingState(cur, SignalingStateHaveLocalOffer, setLocal, sd.Type) + if err == nil { + pc.pendingLocalDescription = sd + } + // have-remote-offer->SetLocal(answer)->stable + // have-local-pranswer->SetLocal(answer)->stable + case SDPTypeAnswer: + if sd.SDP != pc.lastAnswer { + return nextState, newSDPDoesNotMatchAnswer + } + nextState, err = checkNextSignalingState(cur, SignalingStateStable, setLocal, sd.Type) + if err == nil { + pc.currentLocalDescription = sd + pc.currentRemoteDescription = pc.pendingRemoteDescription + pc.pendingRemoteDescription = nil + pc.pendingLocalDescription = nil + } + case SDPTypeRollback: + nextState, err = checkNextSignalingState(cur, SignalingStateStable, setLocal, sd.Type) + if err == nil { + pc.pendingLocalDescription = nil + } + // have-remote-offer->SetLocal(pranswer)->have-local-pranswer + case SDPTypePranswer: + if sd.SDP != pc.lastAnswer { + return nextState, newSDPDoesNotMatchAnswer + } + nextState, err = checkNextSignalingState(cur, SignalingStateHaveLocalPranswer, setLocal, sd.Type) + if err == nil { + pc.pendingLocalDescription = sd + } + default: + return nextState, &rtcerr.OperationError{Err: fmt.Errorf("%w: %s(%s)", errPeerConnStateChangeInvalid, op, sd.Type)} + } + case setRemote: + switch sd.Type { + // stable->SetRemote(offer)->have-remote-offer + case SDPTypeOffer: + nextState, err = checkNextSignalingState(cur, SignalingStateHaveRemoteOffer, setRemote, sd.Type) + if err == nil { + pc.pendingRemoteDescription = sd + } + // have-local-offer->SetRemote(answer)->stable + // have-remote-pranswer->SetRemote(answer)->stable + case SDPTypeAnswer: + nextState, err = checkNextSignalingState(cur, SignalingStateStable, setRemote, sd.Type) + if err == nil { + pc.currentRemoteDescription = sd + pc.currentLocalDescription = pc.pendingLocalDescription + pc.pendingRemoteDescription = nil + pc.pendingLocalDescription = nil + } + case SDPTypeRollback: + nextState, err = checkNextSignalingState(cur, SignalingStateStable, setRemote, sd.Type) + if err == nil { + pc.pendingRemoteDescription = nil + } + // have-local-offer->SetRemote(pranswer)->have-remote-pranswer + case SDPTypePranswer: + nextState, err = checkNextSignalingState(cur, SignalingStateHaveRemotePranswer, setRemote, sd.Type) + if err == nil { + pc.pendingRemoteDescription = sd + } + default: + return nextState, &rtcerr.OperationError{Err: fmt.Errorf("%w: %s(%s)", errPeerConnStateChangeInvalid, op, sd.Type)} + } + default: + return nextState, &rtcerr.OperationError{Err: fmt.Errorf("%w: %q", errPeerConnStateChangeUnhandled, op)} + } + + return nextState, err + }() + + if err == nil { + pc.signalingState.Set(nextState) + if pc.signalingState.Get() == SignalingStateStable { + pc.isNegotiationNeeded.Store(false) + pc.mu.Lock() + pc.onNegotiationNeeded() + pc.mu.Unlock() + } + pc.onSignalingStateChange(nextState) + } + + return err +} + +// SetLocalDescription sets the SessionDescription of the local peer +// +//nolint:cyclop +func (pc *PeerConnection) SetLocalDescription(desc SessionDescription) error { + if pc.isClosed.Load() { + return &rtcerr.InvalidStateError{Err: ErrConnectionClosed} + } + + haveLocalDescription := pc.currentLocalDescription != nil + + // JSEP 5.4 + if desc.SDP == "" { + switch desc.Type { + case SDPTypeAnswer, SDPTypePranswer: + desc.SDP = pc.lastAnswer + case SDPTypeOffer: + desc.SDP = pc.lastOffer + default: + return &rtcerr.InvalidModificationError{ + Err: fmt.Errorf("%w: %s", errPeerConnSDPTypeInvalidValueSetLocalDescription, desc.Type), + } + } + } + + desc.parsed = &sdp.SessionDescription{} + if err := desc.parsed.UnmarshalString(desc.SDP); err != nil { + return err + } + if err := pc.setDescription(&desc, stateChangeOpSetLocal); err != nil { + return err + } + + currentTransceivers := append([]*RTPTransceiver{}, pc.GetTransceivers()...) + + weAnswer := desc.Type == SDPTypeAnswer + remoteDesc := pc.RemoteDescription() + if weAnswer && remoteDesc != nil { + _ = setRTPTransceiverCurrentDirection(&desc, currentTransceivers, false) + if err := pc.startRTPSenders(currentTransceivers); err != nil { + return err + } + pc.configureRTPReceivers(haveLocalDescription, remoteDesc, currentTransceivers) + pc.ops.Enqueue(func() { + pc.startRTP(haveLocalDescription, remoteDesc, currentTransceivers) + }) + } + + mediaSection, ok := selectCandidateMediaSection(desc.parsed) + if ok { + pc.iceGatherer.setMediaStreamIdentification(mediaSection.SDPMid, mediaSection.SDPMLineIndex) + } + + pc.iceGatherer.flushCandidates() + + if pc.iceGatherer.State() == ICEGathererStateNew { + return pc.iceGatherer.Gather() + } + + return nil +} + +// LocalDescription returns PendingLocalDescription if it is not null and +// otherwise it returns CurrentLocalDescription. This property is used to +// determine if SetLocalDescription has already been called. +// https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-localdescription +func (pc *PeerConnection) LocalDescription() *SessionDescription { + if pendingLocalDescription := pc.PendingLocalDescription(); pendingLocalDescription != nil { + return pendingLocalDescription + } + + return pc.CurrentLocalDescription() +} + +// SetRemoteDescription sets the SessionDescription of the remote peer +// +//nolint:gocognit,gocyclo,cyclop,maintidx +func (pc *PeerConnection) SetRemoteDescription(desc SessionDescription) error { + if pc.isClosed.Load() { + return &rtcerr.InvalidStateError{Err: ErrConnectionClosed} + } + + isRenegotiation := pc.currentRemoteDescription != nil + + if _, err := desc.Unmarshal(); err != nil { + return err + } + + if err := pc.setDescription(&desc, stateChangeOpSetRemote); err != nil { + return err + } + + if err := pc.api.mediaEngine.updateFromRemoteDescription(*desc.parsed); err != nil { + return err + } + + canTrickle := hasICETrickleOption(desc.parsed) + pc.mu.Lock() + switch desc.Type { + case SDPTypeOffer, SDPTypeAnswer, SDPTypePranswer: + if canTrickle { + pc.canTrickleICECandidates = ICETrickleCapabilitySupported + } else { + pc.canTrickleICECandidates = ICETrickleCapabilityUnsupported + } + default: + pc.canTrickleICECandidates = ICETrickleCapabilityUnknown + } + pc.mu.Unlock() + + // Disable RTX/FEC on RTPSenders if the remote didn't support it + for _, sender := range pc.GetSenders() { + sender.configureRTXAndFEC() + } + + var transceiver *RTPTransceiver + localTransceivers := append([]*RTPTransceiver{}, pc.GetTransceivers()...) + detectedPlanB := descriptionIsPlanB(pc.RemoteDescription(), pc.log) + if pc.configuration.SDPSemantics != SDPSemanticsUnifiedPlan { + detectedPlanB = descriptionPossiblyPlanB(pc.RemoteDescription()) + } + + weOffer := desc.Type == SDPTypeAnswer + + if !weOffer && !detectedPlanB { //nolint:nestif + for _, media := range pc.RemoteDescription().parsed.MediaDescriptions { + midValue := getMidValue(media) + if midValue == "" { + return errPeerConnRemoteDescriptionWithoutMidValue + } + + if media.MediaName.Media == mediaSectionApplication { + continue + } + + kind := NewRTPCodecType(media.MediaName.Media) + direction := getPeerDirection(media) + if kind == 0 || direction == RTPTransceiverDirectionUnknown { + continue + } + + transceiver, localTransceivers = findByMid(midValue, localTransceivers) + if transceiver == nil { + transceiver, localTransceivers = satisfyTypeAndDirection(kind, direction, localTransceivers) + } else if direction == RTPTransceiverDirectionInactive { + if err := transceiver.Stop(); err != nil { + return err + } + } + if transceiver != nil { + transceiver.setCurrentRemoteDirection(direction) + } + + switch { + case transceiver == nil: + receiver, err := pc.api.NewRTPReceiver(kind, pc.dtlsTransport) + if err != nil { + return err + } + + localDirection := RTPTransceiverDirectionRecvonly + if direction == RTPTransceiverDirectionRecvonly { + localDirection = RTPTransceiverDirectionSendonly + } else if direction == RTPTransceiverDirectionInactive { + localDirection = RTPTransceiverDirectionInactive + } + + transceiver = newRTPTransceiver(receiver, nil, localDirection, kind, pc.api) + transceiver.setCurrentRemoteDirection(direction) + transceiver.setCodecPreferencesFromRemoteDescription(media) + pc.mu.Lock() + pc.addRTPTransceiver(transceiver) + pc.mu.Unlock() + + case direction == RTPTransceiverDirectionRecvonly: + if transceiver.Direction() == RTPTransceiverDirectionSendrecv { + transceiver.setDirection(RTPTransceiverDirectionSendonly) + } else if transceiver.Direction() == RTPTransceiverDirectionRecvonly { + transceiver.setDirection(RTPTransceiverDirectionInactive) + } + case direction == RTPTransceiverDirectionSendrecv: + if transceiver.Direction() == RTPTransceiverDirectionSendonly { + transceiver.setDirection(RTPTransceiverDirectionSendrecv) + } else if transceiver.Direction() == RTPTransceiverDirectionInactive { + transceiver.setDirection(RTPTransceiverDirectionRecvonly) + } + case direction == RTPTransceiverDirectionSendonly: + if transceiver.Direction() == RTPTransceiverDirectionInactive { + transceiver.setDirection(RTPTransceiverDirectionRecvonly) + } + } + + if transceiver.Mid() == "" { + if err := transceiver.SetMid(midValue); err != nil { + return err + } + } + } + } + + iceDetails, err := extractICEDetails(desc.parsed, pc.log) + if err != nil { + return err + } + + if isRenegotiation && pc.iceTransport.haveRemoteCredentialsChange(iceDetails.Ufrag, iceDetails.Password) { + // An ICE Restart only happens implicitly for a SetRemoteDescription of type offer + if !weOffer { + if err = pc.iceTransport.restart(); err != nil { + return err + } + } + + if err = pc.iceTransport.setRemoteCredentials(iceDetails.Ufrag, iceDetails.Password); err != nil { + return err + } + } + + for i := range iceDetails.Candidates { + if err = pc.iceTransport.AddRemoteCandidate(&iceDetails.Candidates[i]); err != nil { + return err + } + } + + currentTransceivers := append([]*RTPTransceiver{}, pc.GetTransceivers()...) + + if isRenegotiation { + if weOffer { + _ = setRTPTransceiverCurrentDirection(&desc, currentTransceivers, true) + if err = pc.startRTPSenders(currentTransceivers); err != nil { + return err + } + pc.configureRTPReceivers(true, &desc, currentTransceivers) + pc.ops.Enqueue(func() { + pc.startRTP(true, &desc, currentTransceivers) + }) + } + + return nil + } + + remoteIsLite := isIceLiteSet(desc.parsed) + + fingerprint, fingerprintHash, err := extractFingerprint(desc.parsed) + if err != nil { + return err + } + + iceRole := ICERoleControlled + // If one of the agents is lite and the other one is not, the lite agent must be the controlled agent. + // If both or neither agents are lite the offering agent is controlling. + // RFC 8445 S6.1.1 + if (weOffer && remoteIsLite == pc.api.settingEngine.candidates.ICELite) || + (remoteIsLite && !pc.api.settingEngine.candidates.ICELite) { + iceRole = ICERoleControlling + } + + // Start the networking in a new routine since it will block until + // the connection is actually established. + if weOffer { + _ = setRTPTransceiverCurrentDirection(&desc, currentTransceivers, true) + if err := pc.startRTPSenders(currentTransceivers); err != nil { + return err + } + + pc.configureRTPReceivers(false, &desc, currentTransceivers) + } + + pc.ops.Enqueue(func() { + pc.startTransports( + iceRole, + dtlsRoleFromSDP(desc.parsed), + iceDetails.Ufrag, + iceDetails.Password, + fingerprint, + fingerprintHash, + ) + if weOffer { + pc.startRTP(false, &desc, currentTransceivers) + } + }) + + return nil +} + +func (pc *PeerConnection) configureReceiver(incoming trackDetails, receiver *RTPReceiver) { + receiver.configureReceive(trackDetailsToRTPReceiveParameters(&incoming)) + + // set track id and label early so they can be set as new track information + // is received from the SDP. + for i := range receiver.tracks { + receiver.tracks[i].track.mu.Lock() + receiver.tracks[i].track.id = incoming.id + receiver.tracks[i].track.streamID = incoming.streamID + receiver.tracks[i].track.mu.Unlock() + } +} + +func (pc *PeerConnection) startReceiver(incoming trackDetails, receiver *RTPReceiver) { + if err := receiver.startReceive(trackDetailsToRTPReceiveParameters(&incoming)); err != nil { + pc.log.Warnf("RTPReceiver Receive failed %s", err) + + return + } + + for _, track := range receiver.Tracks() { + if track.SSRC() == 0 || track.RID() != "" { + return + } + + if pc.api.settingEngine.fireOnTrackBeforeFirstRTP { + pc.onTrack(track, receiver) + + return + } + go func(track *TrackRemote) { + b := make([]byte, pc.api.settingEngine.getReceiveMTU()) + n, _, err := track.peek(b) + if err != nil { + pc.log.Warnf("Could not determine PayloadType for SSRC %d (%s)", track.SSRC(), err) + + return + } + + if err = track.checkAndUpdateTrack(b[:n]); err != nil { + pc.log.Warnf("Failed to set codec settings for track SSRC %d (%s)", track.SSRC(), err) + + return + } + + pc.onTrack(track, receiver) + }(track) + } +} + +//nolint:cyclop +func setRTPTransceiverCurrentDirection( + answer *SessionDescription, + currentTransceivers []*RTPTransceiver, + weOffer bool, +) error { + currentTransceivers = append([]*RTPTransceiver{}, currentTransceivers...) + for _, media := range answer.parsed.MediaDescriptions { + midValue := getMidValue(media) + if midValue == "" { + return errPeerConnRemoteDescriptionWithoutMidValue + } + + if media.MediaName.Media == mediaSectionApplication { + continue + } + + var transceiver *RTPTransceiver + transceiver, currentTransceivers = findByMid(midValue, currentTransceivers) + + if transceiver == nil { + return fmt.Errorf("%w: %q", errPeerConnTranscieverMidNil, midValue) + } + + direction := getPeerDirection(media) + if direction == RTPTransceiverDirectionUnknown { + continue + } + + // reverse direction if it was a remote answer + if weOffer { + switch direction { + case RTPTransceiverDirectionSendonly: + direction = RTPTransceiverDirectionRecvonly + case RTPTransceiverDirectionRecvonly: + direction = RTPTransceiverDirectionSendonly + default: + } + } + + // If a transceiver is created by applying a remote description that has recvonly transceiver, + // it will have no sender. In this case, the transceiver's current direction is set to inactive so + // that the transceiver can be reused by next AddTrack. + if !weOffer && direction == RTPTransceiverDirectionSendonly && transceiver.Sender() == nil { + direction = RTPTransceiverDirectionInactive + } + + transceiver.setCurrentDirection(direction) + } + + return nil +} + +func runIfNewReceiver( + incomingTrack trackDetails, + transceivers []*RTPTransceiver, + callbackFunc func(incomingTrack trackDetails, receiver *RTPReceiver), +) bool { + for _, t := range transceivers { + if t.Mid() != incomingTrack.mid { + continue + } + + receiver := t.Receiver() + if (incomingTrack.kind != t.Kind()) || + (t.Direction() != RTPTransceiverDirectionRecvonly && t.Direction() != RTPTransceiverDirectionSendrecv) || + receiver == nil || + (receiver.haveReceived()) { + continue + } + + callbackFunc(incomingTrack, receiver) + + return true + } + + return false +} + +// configureRTPReceivers opens knows inbound SRTP streams from the RemoteDescription. +// +//nolint:gocognit,cyclop +func (pc *PeerConnection) configureRTPReceivers( + isRenegotiation bool, + remoteDesc *SessionDescription, + currentTransceivers []*RTPTransceiver, +) { + incomingTracks := trackDetailsFromSDP(pc.log, remoteDesc.parsed) + + if isRenegotiation { //nolint:nestif + for _, transceiver := range currentTransceivers { + receiver := transceiver.Receiver() + if receiver == nil { + continue + } + + tracks := transceiver.Receiver().Tracks() + if len(tracks) == 0 { + continue + } + + mid := transceiver.Mid() + receiverNeedsStopped := false + for _, trackRemote := range tracks { + func(track *TrackRemote) { + track.mu.Lock() + defer track.mu.Unlock() + + if track.rid != "" { + if details := trackDetailsForRID(incomingTracks, mid, track.rid); details != nil { + track.id = details.id + track.streamID = details.streamID + + return + } + } else if track.ssrc != 0 { + if details := trackDetailsForSSRC(incomingTracks, track.ssrc); details != nil { + track.id = details.id + track.streamID = details.streamID + + return + } + } + + receiverNeedsStopped = true + }(trackRemote) + } + + if !receiverNeedsStopped { + continue + } + + if err := receiver.Stop(); err != nil { + pc.log.Warnf("Failed to stop RtpReceiver: %s", err) + + continue + } + + receiver, err := pc.api.NewRTPReceiver(receiver.kind, pc.dtlsTransport) + if err != nil { + pc.log.Warnf("Failed to create new RtpReceiver: %s", err) + + continue + } + transceiver.setReceiver(receiver) + } + } + + localTransceivers := append([]*RTPTransceiver{}, currentTransceivers...) + + // Ensure we haven't already started a transceiver for this ssrc + filteredTracks := append([]trackDetails{}, incomingTracks...) + for _, incomingTrack := range incomingTracks { + // If we already have a TrackRemote for a given SSRC don't handle it again + for _, t := range localTransceivers { + if receiver := t.Receiver(); receiver != nil { + for _, track := range receiver.Tracks() { + for _, ssrc := range incomingTrack.ssrcs { + if ssrc == track.SSRC() { + filteredTracks = filterTrackWithSSRC(filteredTracks, track.SSRC()) + } + } + } + } + } + } + + for _, incomingTrack := range filteredTracks { + _ = runIfNewReceiver(incomingTrack, localTransceivers, pc.configureReceiver) + } +} + +// startRTPReceivers opens knows inbound SRTP streams from the RemoteDescription. +func (pc *PeerConnection) startRTPReceivers(remoteDesc *SessionDescription, currentTransceivers []*RTPTransceiver) { + incomingTracks := trackDetailsFromSDP(pc.log, remoteDesc.parsed) + if len(incomingTracks) == 0 { + return + } + + localTransceivers := append([]*RTPTransceiver{}, currentTransceivers...) + + unhandledTracks := incomingTracks[:0] + for _, incomingTrack := range incomingTracks { + trackHandled := runIfNewReceiver(incomingTrack, localTransceivers, pc.startReceiver) + if !trackHandled { + unhandledTracks = append(unhandledTracks, incomingTrack) + } + } + + remoteIsPlanB := false + switch pc.configuration.SDPSemantics { + case SDPSemanticsPlanB: + remoteIsPlanB = true + case SDPSemanticsUnifiedPlanWithFallback: + remoteIsPlanB = descriptionPossiblyPlanB(pc.RemoteDescription()) + default: + // none + } + + if remoteIsPlanB { + for _, incomingTrack := range unhandledTracks { + t, err := pc.AddTransceiverFromKind(incomingTrack.kind, RTPTransceiverInit{ + Direction: RTPTransceiverDirectionSendrecv, + }) + if err != nil { + pc.log.Warnf("Could not add transceiver for remote SSRC %d: %s", incomingTrack.ssrcs[0], err) + + continue + } + pc.configureReceiver(incomingTrack, t.Receiver()) + pc.startReceiver(incomingTrack, t.Receiver()) + } + } +} + +// startRTPSenders starts all outbound RTP streams. +func (pc *PeerConnection) startRTPSenders(currentTransceivers []*RTPTransceiver) error { + for _, transceiver := range currentTransceivers { + if sender := transceiver.Sender(); sender != nil && sender.isNegotiated() && !sender.hasSent() { + err := sender.Send(sender.GetParameters()) + if err != nil { + return err + } + } + } + + return nil +} + +// Start SCTP subsystem. +func (pc *PeerConnection) startSCTP(maxMessageSize uint32, remoteSctpInit []byte) { + // Start sctp + if err := pc.sctpTransport.Start(SCTPCapabilities{ + MaxMessageSize: maxMessageSize, + sctpInit: string(remoteSctpInit), + }); err != nil { + pc.log.Warnf("Failed to start SCTP: %s", err) + if err = pc.sctpTransport.Stop(); err != nil { + pc.log.Warnf("Failed to stop SCTPTransport: %s", err) + } + + return + } +} + +func (pc *PeerConnection) handleUndeclaredSSRC( + ssrc SSRC, + mediaSection *sdp.MediaDescription, +) (handled bool, err error) { + streamID := "" + id := "" + hasRidAttribute := false + hasSSRCAttribute := false + + for _, a := range mediaSection.Attributes { + switch a.Key { + case sdp.AttrKeyMsid: + if split := strings.Split(a.Value, " "); len(split) == 2 { + streamID = split[0] + id = split[1] + } + case sdp.AttrKeySSRC: + hasSSRCAttribute = true + case sdpAttributeRid: + hasRidAttribute = true + } + } + + if hasRidAttribute { + return false, nil + } else if hasSSRCAttribute { + return false, errMediaSectionHasExplictSSRCAttribute + } + + incoming := trackDetails{ + ssrcs: []SSRC{ssrc}, + kind: RTPCodecTypeVideo, + streamID: streamID, + id: id, + } + if mediaSection.MediaName.Media == RTPCodecTypeAudio.String() { + incoming.kind = RTPCodecTypeAudio + } + + t, err := pc.AddTransceiverFromKind(incoming.kind, RTPTransceiverInit{ + Direction: RTPTransceiverDirectionSendrecv, + }) + if err != nil { + // nolint + return false, fmt.Errorf("%w: %d: %s", errPeerConnRemoteSSRCAddTransceiver, ssrc, err) + } + + pc.configureReceiver(incoming, t.Receiver()) + pc.startReceiver(incoming, t.Receiver()) + + return true, nil +} + +// For legacy clients that didn't support urn:ietf:params:rtp-hdrext:sdes:rtp-stream-id +// or urn:ietf:params:rtp-hdrext:sdes:mid extension, and didn't declare a=ssrc lines. +// Assumes that the payload type is unique across the media section. +func (pc *PeerConnection) findMediaSectionByPayloadType( + payloadType PayloadType, + remoteDescription *SessionDescription, +) (selectedMediaSection *sdp.MediaDescription, ok bool) { + for i := range remoteDescription.parsed.MediaDescriptions { + descr := remoteDescription.parsed.MediaDescriptions[i] + media := descr.MediaName.Media + if !strings.EqualFold(media, "video") && !strings.EqualFold(media, "audio") { + continue + } + + formats := descr.MediaName.Formats + for _, payloadStr := range formats { + payload, err := strconv.ParseUint(payloadStr, 10, 8) + if err != nil { + continue + } + + // Return the first media section that has the payload type. + // Assuming that the payload type is unique across the media section. + if PayloadType(payload) == payloadType { + return remoteDescription.parsed.MediaDescriptions[i], true + } + } + } + + return nil, false +} + +// Chrome sends probing traffic on SSRC 0. This reads the packets to ensure that we properly +// generate TWCC reports for it. Since this isn't actually media we don't pass this to the user. +func (pc *PeerConnection) handleNonMediaBandwidthProbe() { + nonMediaBandwidthProbe, err := pc.api.NewRTPReceiver(RTPCodecTypeVideo, pc.dtlsTransport) + if err != nil { + pc.log.Errorf("handleNonMediaBandwidthProbe failed to create RTPReceiver: %v", err) + + return + } + + if err = nonMediaBandwidthProbe.Receive(RTPReceiveParameters{ + Encodings: []RTPDecodingParameters{{RTPCodingParameters: RTPCodingParameters{}}}, + }); err != nil { + pc.log.Errorf("handleNonMediaBandwidthProbe failed to start RTPReceiver: %v", err) + + return + } + + pc.nonMediaBandwidthProbe.Store(nonMediaBandwidthProbe) + b := make([]byte, pc.api.settingEngine.getReceiveMTU()) + for { + if _, _, err = nonMediaBandwidthProbe.readRTP(b, nonMediaBandwidthProbe.Track()); err != nil { + pc.log.Tracef("handleNonMediaBandwidthProbe read exiting: %v", err) + + return + } + } +} + +func (pc *PeerConnection) handleIncomingSSRC(rtpStream *srtp.ReadStreamSRTP, ssrc SSRC) error { //nolint:gocyclo,gocognit,cyclop,lll + remoteDescription := pc.RemoteDescription() + if remoteDescription == nil { + return errPeerConnRemoteDescriptionNil + } + + // If a SSRC already exists in the RemoteDescription don't perform heuristics upon it + for _, track := range trackDetailsFromSDP(pc.log, remoteDescription.parsed) { + if track.rtxSsrc != nil && ssrc == *track.rtxSsrc { + return nil + } + if track.fecSsrc != nil && ssrc == *track.fecSsrc { + return nil + } + if slices.Contains(track.ssrcs, ssrc) { + return nil + } + } + + // if the SSRC is not declared in the SDP and there is only one media section, + // we attempt to resolve it using this single section + // This applies even if the client supports RTP extensions: + // (urn:ietf:params:rtp-hdrext:sdes:rtp-stream-id and urn:ietf:params:rtp-hdrext:sdes:mid) + // and even if the RTP stream contains an incorrect MID or RID. + // while this can be incorrect, this is done to maintain compatibility with older behavior. + if remoteDescription.Type != SDPTypeAnswer || pc.api.settingEngine.handleUndeclaredSSRCWithoutAnswer { + if len(remoteDescription.parsed.MediaDescriptions) == 1 { + mediaSection := remoteDescription.parsed.MediaDescriptions[0] + if handled, err := pc.handleUndeclaredSSRC(ssrc, mediaSection); handled || err != nil { + return err + } + } + } + + // We read the RTP packet to determine the payload type + b := make([]byte, pc.api.settingEngine.getReceiveMTU()) + + i, err := rtpStream.Peek(b) + if err != nil { + return err + } + + if i < 4 { + return errRTPTooShort + } + + payloadType := PayloadType(b[1] & 0x7f) + params, err := pc.api.mediaEngine.getRTPParametersByPayloadType(payloadType) + if err != nil { + return err + } + + midExtensionID, audioSupported, videoSupported := pc.api.mediaEngine.getHeaderExtensionID( + RTPHeaderExtensionCapability{sdp.SDESMidURI}, + ) + if !audioSupported && !videoSupported { + if remoteDescription.Type == SDPTypeAnswer && !pc.api.settingEngine.handleUndeclaredSSRCWithoutAnswer { + // if we are offerer, wait for answer with media setion to process this SSRC + return errPeerConnEarlyMediaWithoutAnswer + } + + // try to find media section by payload type as a last resort for legacy clients. + mediaSection, ok := pc.findMediaSectionByPayloadType(payloadType, remoteDescription) + if ok { + if ok, err = pc.handleUndeclaredSSRC(ssrc, mediaSection); ok || err != nil { + return err + } + } + + return errPeerConnSimulcastMidRTPExtensionRequired + } + + streamIDExtensionID, audioSupported, videoSupported := pc.api.mediaEngine.getHeaderExtensionID( + RTPHeaderExtensionCapability{sdp.SDESRTPStreamIDURI}, + ) + if !audioSupported && !videoSupported { + return errPeerConnSimulcastStreamIDRTPExtensionRequired + } + + repairStreamIDExtensionID, _, _ := pc.api.mediaEngine.getHeaderExtensionID( + RTPHeaderExtensionCapability{sdp.SDESRepairRTPStreamIDURI}, + ) + + streamInfo := createStreamInfo( + "", + ssrc, + 0, 0, + params.Codecs[0].PayloadType, + 0, 0, + params.Codecs[0].RTPCodecCapability, + params.HeaderExtensions, + ) + result, err := pc.dtlsTransport.streamsForSSRC(ssrc, *streamInfo) + if err != nil { + return err + } + readStream := result.rtpReadStream + interceptor := result.rtpInterceptor + rtcpReadStream := result.rtcpReadStream + rtcpInterceptor := result.rtcpInterceptor + + // try to read simulcast IDs from the packet we already have + mid, rid, rsid, _, err := handleUnknownRTPPacket( + b[:i], uint8(midExtensionID), //nolint:gosec // G115 + uint8(streamIDExtensionID), //nolint:gosec // G115 + uint8(repairStreamIDExtensionID), //nolint:gosec // G115 + ) + if err != nil { + return err + } + + peekedPackets := []*peekedPacket(nil) + + // if the first packet didn't contain simuilcast IDs, then probe more packets + var paddingOnly bool + for readCount := 0; readCount <= simulcastProbeCount; readCount++ { + if mid == "" || (rid == "" && rsid == "") { + // skip padding only packets for probing + if paddingOnly { + readCount-- + } + + i, attributes, err := interceptor.Read(b, nil) + if err != nil { + return err + } + + peekedPackets = append(peekedPackets, &peekedPacket{ + payload: slices.Clone(b[:i]), + attributes: attributes, + }) + + mid, rid, rsid, paddingOnly, err = handleUnknownRTPPacket( + b[:i], uint8(midExtensionID), //nolint:gosec // G115 + uint8(streamIDExtensionID), //nolint:gosec // G115 + uint8(repairStreamIDExtensionID), //nolint:gosec // G115 + ) + if err != nil { + return err + } + + continue + } + + for _, t := range pc.GetTransceivers() { + receiver := t.Receiver() + if t.Mid() != mid || receiver == nil { + continue + } + + if rsid != "" { + return receiver.receiveForRtx(SSRC(0), rsid, streamInfo, readStream, interceptor, rtcpReadStream, rtcpInterceptor) + } + + track, err := receiver.receiveForRid( + rid, + params, + streamInfo, + readStream, + interceptor, + rtcpReadStream, + rtcpInterceptor, + peekedPackets, + ) + if err != nil { + return err + } + pc.onTrack(track, receiver) + + return nil + } + } + + pc.api.interceptor.UnbindRemoteStream(streamInfo) + + return errPeerConnSimulcastIncomingSSRCFailed +} + +// undeclaredMediaProcessor handles RTP/RTCP packets that don't match any a:ssrc lines. +func (pc *PeerConnection) undeclaredMediaProcessor() { + go pc.undeclaredRTPMediaProcessor() + go pc.undeclaredRTCPMediaProcessor() +} + +func (pc *PeerConnection) undeclaredRTPMediaProcessor() { //nolint:cyclop + var simulcastRoutineCount uint64 + for { + srtpSession, err := pc.dtlsTransport.getSRTPSession() + if err != nil { + pc.log.Warnf("undeclaredMediaProcessor failed to open SrtpSession: %v", err) + + return + } + + srtcpSession, err := pc.dtlsTransport.getSRTCPSession() + if err != nil { + pc.log.Warnf("undeclaredMediaProcessor failed to open SrtcpSession: %v", err) + + return + } + + srtpReadStream, ssrc, err := srtpSession.AcceptStream() + if err != nil { + pc.log.Warnf("Failed to accept RTP %v", err) + + return + } + + // open accompanying srtcp stream + srtcpReadStream, err := srtcpSession.OpenReadStream(ssrc) + if err != nil { + pc.log.Warnf("Failed to open RTCP stream for %d: %v", ssrc, err) + + return + } + + if pc.isClosed.Load() { + if err = srtpReadStream.Close(); err != nil { + pc.log.Warnf("Failed to close RTP stream %v", err) + } + if err = srtcpReadStream.Close(); err != nil { + pc.log.Warnf("Failed to close RTCP stream %v", err) + } + + continue + } + + pc.dtlsTransport.storeSimulcastStream(srtpReadStream, srtcpReadStream) + + if ssrc == 0 { + go pc.handleNonMediaBandwidthProbe() + + continue + } + + if atomic.AddUint64(&simulcastRoutineCount, 1) >= simulcastMaxProbeRoutines { + atomic.AddUint64(&simulcastRoutineCount, ^uint64(0)) + pc.log.Warn(ErrSimulcastProbeOverflow.Error()) + + continue + } + + go func(rtpStream *srtp.ReadStreamSRTP, ssrc SSRC) { + if err := pc.handleIncomingSSRC(rtpStream, ssrc); err != nil { + pc.log.Errorf(incomingUnhandledRTPSsrc, ssrc, err) + } + atomic.AddUint64(&simulcastRoutineCount, ^uint64(0)) + }(srtpReadStream, SSRC(ssrc)) + } +} + +func (pc *PeerConnection) undeclaredRTCPMediaProcessor() { + var unhandledStreams []*srtp.ReadStreamSRTCP + defer func() { + for _, s := range unhandledStreams { + _ = s.Close() + } + }() + for { + srtcpSession, err := pc.dtlsTransport.getSRTCPSession() + if err != nil { + pc.log.Warnf("undeclaredMediaProcessor failed to open SrtcpSession: %v", err) + + return + } + + stream, ssrc, err := srtcpSession.AcceptStream() + if err != nil { + pc.log.Warnf("Failed to accept RTCP %v", err) + + return + } + pc.log.Warnf("Incoming unhandled RTCP ssrc(%d), OnTrack will not be fired", ssrc) + unhandledStreams = append(unhandledStreams, stream) + } +} + +// RemoteDescription returns pendingRemoteDescription if it is not null and +// otherwise it returns currentRemoteDescription. This property is used to +// determine if setRemoteDescription has already been called. +// https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-remotedescription +func (pc *PeerConnection) RemoteDescription() *SessionDescription { + pc.mu.RLock() + defer pc.mu.RUnlock() + + if pc.pendingRemoteDescription != nil { + return pc.pendingRemoteDescription + } + + return pc.currentRemoteDescription +} + +// AddICECandidate accepts an ICE candidate string and adds it +// to the existing set of candidates. +func (pc *PeerConnection) AddICECandidate(candidate ICECandidateInit) error { + remoteDesc := pc.RemoteDescription() + if remoteDesc == nil { + return &rtcerr.InvalidStateError{Err: ErrNoRemoteDescription} + } + + candidateValue := strings.TrimPrefix(candidate.Candidate, "candidate:") + + if candidateValue == "" { + return pc.iceTransport.AddRemoteCandidate(nil) + } + + cand, err := ice.UnmarshalCandidate(candidateValue) + if err != nil { + if errors.Is(err, ice.ErrUnknownCandidateTyp) || errors.Is(err, ice.ErrDetermineNetworkType) { + pc.log.Warnf("Discarding remote candidate: %s", err) + + return nil + } + + return err + } + + // Reject candidates from old generations. + // If candidate.usernameFragment is not null, + // and is not equal to any username fragment present in the corresponding media + // description of an applied remote description, + // return a promise rejected with a newly created OperationError. + // https://w3c.github.io/webrtc-pc/#dom-peerconnection-addicecandidate + if ufrag, ok := cand.GetExtension("ufrag"); ok { + if !pc.descriptionContainsUfrag(remoteDesc.parsed, ufrag.Value) { + pc.log.Errorf("dropping candidate with ufrag %s because it doesn't match the current ufrags", ufrag.Value) + + return nil + } + } + + c, err := newICECandidateFromICE(cand, "", 0) + if err != nil { + return err + } + + return pc.iceTransport.AddRemoteCandidate(&c) +} + +// Return true if the sdp contains a specific ufrag. +func (pc *PeerConnection) descriptionContainsUfrag(sdp *sdp.SessionDescription, matchUfrag string) bool { + ufrag, ok := sdp.Attribute("ice-ufrag") + if ok && ufrag == matchUfrag { + return true + } + + for _, media := range sdp.MediaDescriptions { + ufrag, ok := media.Attribute("ice-ufrag") + if ok && ufrag == matchUfrag { + return true + } + } + + return false +} + +// ICEConnectionState returns the ICE connection state of the +// PeerConnection instance. +func (pc *PeerConnection) ICEConnectionState() ICEConnectionState { + if state, ok := pc.iceConnectionState.Load().(ICEConnectionState); ok { + return state + } + + return ICEConnectionState(0) +} + +// GetSenders returns the RTPSender that are currently attached to this PeerConnection. +func (pc *PeerConnection) GetSenders() (result []*RTPSender) { + pc.mu.Lock() + defer pc.mu.Unlock() + + for _, transceiver := range pc.rtpTransceivers { + if sender := transceiver.Sender(); sender != nil { + result = append(result, sender) + } + } + + return result +} + +// GetReceivers returns the RTPReceivers that are currently attached to this PeerConnection. +func (pc *PeerConnection) GetReceivers() (receivers []*RTPReceiver) { + pc.mu.Lock() + defer pc.mu.Unlock() + + for _, transceiver := range pc.rtpTransceivers { + if receiver := transceiver.Receiver(); receiver != nil { + receivers = append(receivers, receiver) + } + } + + return +} + +// GetTransceivers returns the RtpTransceiver that are currently attached to this PeerConnection. +func (pc *PeerConnection) GetTransceivers() []*RTPTransceiver { + pc.mu.Lock() + defer pc.mu.Unlock() + + return pc.rtpTransceivers +} + +// AddTrack adds a Track to the PeerConnection. +// +//nolint:cyclop +func (pc *PeerConnection) AddTrack(track TrackLocal) (*RTPSender, error) { + if pc.isClosed.Load() { + return nil, &rtcerr.InvalidStateError{Err: ErrConnectionClosed} + } + + pc.mu.Lock() + defer pc.mu.Unlock() + for _, transceiver := range pc.rtpTransceivers { + if !transceiver.isSendAllowed(track.Kind()) { + continue + } + + sender, err := pc.api.NewRTPSender(track, pc.dtlsTransport) + if err == nil { + err = transceiver.SetSender(sender, track) + if err != nil { + _ = sender.Stop() + transceiver.setSender(nil) + } + } + if err != nil { + return nil, err + } + pc.onNegotiationNeeded() + + return sender, nil + } + + transceiver, err := pc.newTransceiverFromTrack(RTPTransceiverDirectionSendrecv, track) + if err != nil { + return nil, err + } + pc.addRTPTransceiver(transceiver) + + return transceiver.Sender(), nil +} + +// RemoveTrack removes a Track from the PeerConnection. +func (pc *PeerConnection) RemoveTrack(sender *RTPSender) (err error) { + if pc.isClosed.Load() { + return &rtcerr.InvalidStateError{Err: ErrConnectionClosed} + } + + var transceiver *RTPTransceiver + pc.mu.Lock() + defer pc.mu.Unlock() + for _, t := range pc.rtpTransceivers { + if t.Sender() == sender { + transceiver = t + + break + } + } + if transceiver == nil { + return &rtcerr.InvalidAccessError{Err: ErrSenderNotCreatedByConnection} + } else if err = sender.Stop(); err == nil { + err = transceiver.setSendingTrack(nil) + if err == nil { + pc.onNegotiationNeeded() + } + } + + return +} + +//nolint:cyclop +func (pc *PeerConnection) newTransceiverFromTrack( + direction RTPTransceiverDirection, + track TrackLocal, + init ...RTPTransceiverInit, +) (t *RTPTransceiver, err error) { + var ( + receiver *RTPReceiver + sender *RTPSender + ) + switch direction { + case RTPTransceiverDirectionSendrecv: + receiver, err = pc.api.NewRTPReceiver(track.Kind(), pc.dtlsTransport) + if err != nil { + return t, err + } + sender, err = pc.api.NewRTPSender(track, pc.dtlsTransport) + case RTPTransceiverDirectionSendonly: + sender, err = pc.api.NewRTPSender(track, pc.dtlsTransport) + default: + err = errPeerConnAddTransceiverFromTrackSupport + } + if err != nil { + return t, err + } + + // Allow RTPTransceiverInit to override SSRC + if sender != nil && len(sender.trackEncodings) == 1 && + len(init) == 1 && len(init[0].SendEncodings) == 1 && init[0].SendEncodings[0].SSRC != 0 { + sender.trackEncodings[0].ssrc = init[0].SendEncodings[0].SSRC + } + + return newRTPTransceiver(receiver, sender, direction, track.Kind(), pc.api), nil +} + +// AddTransceiverFromKind Create a new RtpTransceiver and adds it to the set of transceivers. +// +//nolint:cyclop +func (pc *PeerConnection) AddTransceiverFromKind( + kind RTPCodecType, + init ...RTPTransceiverInit, +) (t *RTPTransceiver, err error) { + if pc.isClosed.Load() { + return nil, &rtcerr.InvalidStateError{Err: ErrConnectionClosed} + } + + direction := RTPTransceiverDirectionSendrecv + if len(init) > 1 { + return nil, errPeerConnAddTransceiverFromKindOnlyAcceptsOne + } else if len(init) == 1 { + direction = init[0].Direction + } + switch direction { + case RTPTransceiverDirectionSendonly, RTPTransceiverDirectionSendrecv: + codecs := pc.api.mediaEngine.getCodecsByKind(kind) + if len(codecs) == 0 { + return nil, ErrNoCodecsAvailable + } + track, err := NewTrackLocalStaticSample(codecs[0].RTPCodecCapability, util.MathRandAlpha(16), util.MathRandAlpha(16)) + if err != nil { + return nil, err + } + t, err = pc.newTransceiverFromTrack(direction, track, init...) + if err != nil { + return nil, err + } + case RTPTransceiverDirectionRecvonly: + receiver, err := pc.api.NewRTPReceiver(kind, pc.dtlsTransport) + if err != nil { + return nil, err + } + t = newRTPTransceiver(receiver, nil, RTPTransceiverDirectionRecvonly, kind, pc.api) + default: + return nil, errPeerConnAddTransceiverFromKindSupport + } + pc.mu.Lock() + pc.addRTPTransceiver(t) + pc.mu.Unlock() + + return t, nil +} + +// AddTransceiverFromTrack Create a new RtpTransceiver(SendRecv or SendOnly) and add it to the set of transceivers. +func (pc *PeerConnection) AddTransceiverFromTrack( + track TrackLocal, + init ...RTPTransceiverInit, +) (t *RTPTransceiver, err error) { + if pc.isClosed.Load() { + return nil, &rtcerr.InvalidStateError{Err: ErrConnectionClosed} + } + + direction := RTPTransceiverDirectionSendrecv + if len(init) > 1 { + return nil, errPeerConnAddTransceiverFromTrackOnlyAcceptsOne + } else if len(init) == 1 { + direction = init[0].Direction + } + + t, err = pc.newTransceiverFromTrack(direction, track, init...) + if err == nil { + pc.mu.Lock() + pc.addRTPTransceiver(t) + pc.mu.Unlock() + } + + return +} + +// CreateDataChannel creates a new DataChannel object with the given label +// and optional DataChannelInit used to configure properties of the +// underlying channel such as data reliability. +// +//nolint:cyclop +func (pc *PeerConnection) CreateDataChannel(label string, options *DataChannelInit) (*DataChannel, error) { + // https://w3c.github.io/webrtc-pc/#peer-to-peer-data-api (Step #2) + if pc.isClosed.Load() { + return nil, &rtcerr.InvalidStateError{Err: ErrConnectionClosed} + } + + params := &DataChannelParameters{ + Label: label, + Ordered: true, + } + + // https://w3c.github.io/webrtc-pc/#peer-to-peer-data-api (Step #19) + if options != nil { + params.ID = options.ID + } + + if options != nil { //nolint:nestif + // Ordered indicates if data is allowed to be delivered out of order. The + // default value of true, guarantees that data will be delivered in order. + // https://w3c.github.io/webrtc-pc/#peer-to-peer-data-api (Step #9) + if options.Ordered != nil { + params.Ordered = *options.Ordered + } + + // https://w3c.github.io/webrtc-pc/#peer-to-peer-data-api (Step #7) + if options.MaxPacketLifeTime != nil { + params.MaxPacketLifeTime = options.MaxPacketLifeTime + } + + // https://w3c.github.io/webrtc-pc/#peer-to-peer-data-api (Step #8) + if options.MaxRetransmits != nil { + params.MaxRetransmits = options.MaxRetransmits + } + + // https://w3c.github.io/webrtc-pc/#peer-to-peer-data-api (Step #10) + if options.Protocol != nil { + params.Protocol = *options.Protocol + } + + // https://w3c.github.io/webrtc-pc/#peer-to-peer-data-api (Step #11) + if len(params.Protocol) > 65535 { + return nil, &rtcerr.TypeError{Err: ErrProtocolTooLarge} + } + + // https://w3c.github.io/webrtc-pc/#peer-to-peer-data-api (Step #12) + if options.Negotiated != nil { + params.Negotiated = *options.Negotiated + } + } + + dataChannel, err := pc.api.newDataChannel(params, nil, pc.log) + if err != nil { + return nil, err + } + + // https://w3c.github.io/webrtc-pc/#peer-to-peer-data-api (Step #16) + if dataChannel.maxPacketLifeTime != nil && dataChannel.maxRetransmits != nil { + return nil, &rtcerr.TypeError{Err: ErrRetransmitsOrPacketLifeTime} + } + + pc.sctpTransport.lock.Lock() + pc.sctpTransport.dataChannels = append(pc.sctpTransport.dataChannels, dataChannel) + if dataChannel.ID() != nil { + pc.sctpTransport.dataChannelIDsUsed[*dataChannel.ID()] = struct{}{} + } + pc.sctpTransport.dataChannelsRequested++ + pc.sctpTransport.lock.Unlock() + + // If SCTP already connected open all the channels + if pc.sctpTransport.State() == SCTPTransportStateConnected { + if err = dataChannel.open(pc.sctpTransport); err != nil { + return nil, err + } + } + + pc.mu.Lock() + pc.onNegotiationNeeded() + pc.mu.Unlock() + + return dataChannel, nil +} + +// SetIdentityProvider is used to configure an identity provider to generate identity assertions. +func (pc *PeerConnection) SetIdentityProvider(string) error { + return errPeerConnSetIdentityProviderNotImplemented +} + +// WriteRTCP sends a user provided RTCP packet to the connected peer. If no peer is connected the +// packet is discarded. It also runs any configured interceptors. +func (pc *PeerConnection) WriteRTCP(pkts []rtcp.Packet) error { + _, err := pc.interceptorRTCPWriter.Write(pkts, make(interceptor.Attributes)) + + return err +} + +func (pc *PeerConnection) writeRTCP(pkts []rtcp.Packet, _ interceptor.Attributes) (int, error) { + return pc.dtlsTransport.WriteRTCP(pkts) +} + +// Close ends the PeerConnection. +func (pc *PeerConnection) Close() error { + return pc.close(false /* shouldGracefullyClose */) +} + +// GracefulClose ends the PeerConnection. It also waits +// for any goroutines it started to complete. This is only safe to call outside of +// PeerConnection callbacks or if in a callback, in its own goroutine. +func (pc *PeerConnection) GracefulClose() error { + return pc.close(true /* shouldGracefullyClose */) +} + +func (pc *PeerConnection) close(shouldGracefullyClose bool) error { //nolint:cyclop + // https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #1) + // https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #2) + + pc.mu.Lock() + // A lock in this critical section is needed because pc.isClosed and + // pc.isGracefullyClosingOrClosed are related to each other in that we + // want to make graceful and normal closure one time operations in order + // to avoid any double closure errors from cropping up. However, there are + // some overlapping close cases when both normal and graceful close are used + // that should be idempotent, but be cautioned when writing new close behavior + // to preserve this property. + isAlreadyClosingOrClosed := pc.isClosed.Swap(true) + isAlreadyGracefullyClosingOrClosed := pc.isGracefullyClosingOrClosed + if shouldGracefullyClose && !isAlreadyGracefullyClosingOrClosed { + pc.isGracefullyClosingOrClosed = true + } + pc.mu.Unlock() + + if isAlreadyClosingOrClosed { + if !shouldGracefullyClose { + return nil + } + // Even if we're already closing, it may not be graceful: + // If we are not the ones doing the closing, we just wait for the graceful close + // to happen and then return. + if isAlreadyGracefullyClosingOrClosed { + <-pc.isGracefulCloseDone + + return nil + } + // Otherwise we need to go through the graceful closure flow once the + // normal closure is done since there are extra steps to take with a + // graceful close. + <-pc.isCloseDone + } else { + defer close(pc.isCloseDone) + } + + if shouldGracefullyClose { + defer close(pc.isGracefulCloseDone) + } + + // Try closing everything and collect the errors + // Shutdown strategy: + // 1. All Conn close by closing their underlying Conn. + // 2. A Mux stops this chain. It won't close the underlying + // Conn if one of the endpoints is closed down. To + // continue the chain the Mux has to be closed. + closeErrs := make([]error, 0, 4) + + doGracefulCloseOps := func() []error { + if !shouldGracefullyClose { + return nil + } + + // these are all non-canon steps + var gracefulCloseErrors []error + if pc.iceTransport != nil { + gracefulCloseErrors = append(gracefulCloseErrors, pc.iceTransport.GracefulStop()) + } + + pc.ops.GracefulClose() + + pc.sctpTransport.lock.Lock() + for _, d := range pc.sctpTransport.dataChannels { + gracefulCloseErrors = append(gracefulCloseErrors, d.GracefulClose()) + } + pc.sctpTransport.lock.Unlock() + + return gracefulCloseErrors + } + + if isAlreadyClosingOrClosed { + return util.FlattenErrs(doGracefulCloseOps()) + } + + // https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #3) + pc.signalingState.Set(SignalingStateClosed) + + // https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #4) + pc.mu.Lock() + for _, t := range pc.rtpTransceivers { + closeErrs = append(closeErrs, t.Stop()) + } + if nonMediaBandwidthProbe, ok := pc.nonMediaBandwidthProbe.Load().(*RTPReceiver); ok { + closeErrs = append(closeErrs, nonMediaBandwidthProbe.Stop()) + } + pc.mu.Unlock() + + // https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #5) + pc.sctpTransport.lock.Lock() + for _, d := range pc.sctpTransport.dataChannels { + d.setReadyState(DataChannelStateClosed) + } + pc.sctpTransport.lock.Unlock() + + // https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #6) + if pc.sctpTransport != nil { + closeErrs = append(closeErrs, pc.sctpTransport.Stop()) + } + + // https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #7) + closeErrs = append(closeErrs, pc.dtlsTransport.Stop()) + + // https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #8, #9, #10) + if pc.iceTransport != nil && !shouldGracefullyClose { + // we will stop gracefully in doGracefulCloseOps + closeErrs = append(closeErrs, pc.iceTransport.Stop()) + } + + // https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #11) + pc.updateConnectionState(pc.ICEConnectionState(), pc.dtlsTransport.State()) + + closeErrs = append(closeErrs, doGracefulCloseOps()...) + + pc.statsGetter = nil + cleanupStats(pc.id) + + // Interceptor closes at the end to prevent Bind from being called after interceptor is closed + closeErrs = append(closeErrs, pc.api.interceptor.Close()) + + return util.FlattenErrs(closeErrs) +} + +// addRTPTransceiver appends t into rtpTransceivers +// and fires onNegotiationNeeded; +// caller of this method should hold `pc.mu` lock. +func (pc *PeerConnection) addRTPTransceiver(t *RTPTransceiver) { + pc.rtpTransceivers = append(pc.rtpTransceivers, t) + pc.onNegotiationNeeded() +} + +// CurrentLocalDescription represents the local description that was +// successfully negotiated the last time the PeerConnection transitioned +// into the stable state plus any local candidates that have been generated +// by the ICEAgent since the offer or answer was created. +func (pc *PeerConnection) CurrentLocalDescription() *SessionDescription { + pc.mu.Lock() + defer pc.mu.Unlock() + + localDescription := pc.currentLocalDescription + iceGather := pc.iceGatherer + iceGatheringState := pc.ICEGatheringState() + + return populateLocalCandidates(localDescription, iceGather, iceGatheringState) +} + +// PendingLocalDescription represents a local description that is in the +// process of being negotiated plus any local candidates that have been +// generated by the ICEAgent since the offer or answer was created. If the +// PeerConnection is in the stable state, the value is null. +func (pc *PeerConnection) PendingLocalDescription() *SessionDescription { + pc.mu.Lock() + defer pc.mu.Unlock() + + localDescription := pc.pendingLocalDescription + iceGather := pc.iceGatherer + iceGatheringState := pc.ICEGatheringState() + + return populateLocalCandidates(localDescription, iceGather, iceGatheringState) +} + +// CurrentRemoteDescription represents the last remote description that was +// successfully negotiated the last time the PeerConnection transitioned +// into the stable state plus any remote candidates that have been supplied +// via AddICECandidate() since the offer or answer was created. +func (pc *PeerConnection) CurrentRemoteDescription() *SessionDescription { + pc.mu.RLock() + defer pc.mu.RUnlock() + + return pc.currentRemoteDescription +} + +// PendingRemoteDescription represents a remote description that is in the +// process of being negotiated, complete with any remote candidates that +// have been supplied via AddICECandidate() since the offer or answer was +// created. If the PeerConnection is in the stable state, the value is +// null. +func (pc *PeerConnection) PendingRemoteDescription() *SessionDescription { + pc.mu.RLock() + defer pc.mu.RUnlock() + + return pc.pendingRemoteDescription +} + +// CanTrickleICECandidates reports whether the remote endpoint indicated +// support for receiving trickled ICE candidates. +func (pc *PeerConnection) CanTrickleICECandidates() ICETrickleCapability { + pc.mu.RLock() + defer pc.mu.RUnlock() + + return pc.canTrickleICECandidates +} + +// SignalingState attribute returns the signaling state of the +// PeerConnection instance. +func (pc *PeerConnection) SignalingState() SignalingState { + return pc.signalingState.Get() +} + +// ICEGatheringState attribute returns the ICE gathering state of the +// PeerConnection instance. +func (pc *PeerConnection) ICEGatheringState() ICEGatheringState { + if pc.iceGatherer == nil { + return ICEGatheringStateNew + } + + switch pc.iceGatherer.State() { + case ICEGathererStateNew: + return ICEGatheringStateNew + case ICEGathererStateGathering: + return ICEGatheringStateGathering + default: + return ICEGatheringStateComplete + } +} + +// ConnectionState attribute returns the connection state of the +// PeerConnection instance. +func (pc *PeerConnection) ConnectionState() PeerConnectionState { + if state, ok := pc.connectionState.Load().(PeerConnectionState); ok { + return state + } + + return PeerConnectionState(0) +} + +// GetStats return data providing statistics about the overall connection. +func (pc *PeerConnection) GetStats() StatsReport { + var ( + dataChannelsAccepted uint32 + dataChannelsClosed uint32 + dataChannelsOpened uint32 + dataChannelsRequested uint32 + ) + statsCollector := newStatsReportCollector() + statsCollector.Collecting() + + pc.mu.Lock() + if pc.iceGatherer != nil { + pc.iceGatherer.collectStats(statsCollector) + } + if pc.iceTransport != nil { + pc.iceTransport.collectStats(statsCollector) + } + + pc.sctpTransport.lock.Lock() + dataChannels := append([]*DataChannel{}, pc.sctpTransport.dataChannels...) + dataChannelsAccepted = pc.sctpTransport.dataChannelsAccepted + dataChannelsOpened = pc.sctpTransport.dataChannelsOpened + dataChannelsRequested = pc.sctpTransport.dataChannelsRequested + pc.sctpTransport.lock.Unlock() + + for _, d := range dataChannels { + state := d.ReadyState() + if state != DataChannelStateConnecting && state != DataChannelStateOpen { + dataChannelsClosed++ + } + + d.collectStats(statsCollector) + } + pc.sctpTransport.collectStats(statsCollector) + + stats := PeerConnectionStats{ + Timestamp: statsTimestampNow(), + Type: StatsTypePeerConnection, + ID: pc.id, + DataChannelsAccepted: dataChannelsAccepted, + DataChannelsClosed: dataChannelsClosed, + DataChannelsOpened: dataChannelsOpened, + DataChannelsRequested: dataChannelsRequested, + } + + statsCollector.Collect(stats.ID, stats) + + certificates := pc.configuration.Certificates + for _, certificate := range certificates { + if err := certificate.collectStats(statsCollector); err != nil { + continue + } + } + pc.mu.Unlock() + + receivers := pc.GetReceivers() + for _, receiver := range receivers { + receiver.collectStats(statsCollector, pc.statsGetter) + } + + pc.api.mediaEngine.collectStats(statsCollector) + + return statsCollector.Ready() +} + +// Start all transports. PeerConnection now has enough state. +func (pc *PeerConnection) startTransports( + iceRole ICERole, + dtlsRole DTLSRole, + remoteUfrag, remotePwd, fingerprint, fingerprintHash string, +) { + // Start the ice transport + err := pc.iceTransport.Start( + pc.iceGatherer, + ICEParameters{ + UsernameFragment: remoteUfrag, + Password: remotePwd, + ICELite: false, + }, + &iceRole, + ) + if err != nil { + pc.log.Warnf("Failed to start manager: %s", err) + + return + } + + pc.dtlsTransport.internalOnCloseHandler = func() { + if pc.isClosed.Load() || pc.api.settingEngine.disableCloseByDTLS { + return + } + + pc.log.Info("Closing PeerConnection from DTLS CloseNotify") + go func() { + if pcClosErr := pc.Close(); pcClosErr != nil { + pc.log.Warnf("Failed to close PeerConnection from DTLS CloseNotify: %s", pcClosErr) + } + }() + } + + // Start the dtls transport + err = pc.dtlsTransport.Start(DTLSParameters{ + Role: dtlsRole, + Fingerprints: []DTLSFingerprint{{Algorithm: fingerprintHash, Value: fingerprint}}, + }) + pc.updateConnectionState(pc.ICEConnectionState(), pc.dtlsTransport.State()) + if err != nil { + pc.log.Warnf("Failed to start manager: %s", err) + + return + } +} + +// nolint: gocognit +func (pc *PeerConnection) startRTP( + isRenegotiation bool, + remoteDesc *SessionDescription, + currentTransceivers []*RTPTransceiver, +) { + if !isRenegotiation { + pc.undeclaredMediaProcessor() + } + + pc.startRTPReceivers(remoteDesc, currentTransceivers) + if d := haveDataChannel(remoteDesc); d != nil && d.MediaName.Port.Value != 0 { + remoteSctpInit, _ := getSctpInit(d) + pc.startSCTP(getMaxMessageSize(d), remoteSctpInit) + } +} + +// generateUnmatchedSDP generates an SDP that doesn't take remote state into account. +// This is used for the initial call for CreateOffer. +// +//nolint:cyclop +func (pc *PeerConnection) generateUnmatchedSDP( + transceivers []*RTPTransceiver, + useIdentity bool, +) (*sdp.SessionDescription, error) { + desc, err := sdp.NewJSEPSessionDescription(useIdentity) + if err != nil { + return nil, err + } + desc.Attributes = append(desc.Attributes, sdp.Attribute{Key: sdp.AttrKeyMsidSemantic, Value: "WMS *"}) + + iceParams, err := pc.iceGatherer.GetLocalParameters() + if err != nil { + return nil, err + } + + candidates, err := pc.iceGatherer.GetLocalCandidates() + if err != nil { + return nil, err + } + + isPlanB := pc.configuration.SDPSemantics == SDPSemanticsPlanB + mediaSections := []mediaSection{} + + // Needed for pc.sctpTransport.dataChannelsRequested + pc.sctpTransport.lock.Lock() + + var localSctpInit []byte + if pc.sctpTransport.dataChannelsRequested != 0 && pc.api.settingEngine.sctp.enableSnap { + localSctpInit = pc.sctpTransport.GetSctpInit() + } + defer pc.sctpTransport.lock.Unlock() + + if isPlanB { //nolint:nestif + video := make([]*RTPTransceiver, 0) + audio := make([]*RTPTransceiver, 0) + + for _, t := range transceivers { + if t.kind == RTPCodecTypeVideo { + video = append(video, t) + } else if t.kind == RTPCodecTypeAudio { + audio = append(audio, t) + } + if sender := t.Sender(); sender != nil { + sender.setNegotiated() + } + } + + if len(video) > 0 { + mediaSections = append(mediaSections, mediaSection{id: "video", transceivers: video}) + } + if len(audio) > 0 { + mediaSections = append(mediaSections, mediaSection{id: "audio", transceivers: audio}) + } + + if pc.configuration.AlwaysNegotiateDataChannels || pc.sctpTransport.dataChannelsRequested != 0 { + mediaSections = append(mediaSections, mediaSection{id: "data", data: true}) + } + } else { + for _, t := range transceivers { + if sender := t.Sender(); sender != nil { + sender.setNegotiated() + } + mediaSections = append(mediaSections, mediaSection{id: t.Mid(), transceivers: []*RTPTransceiver{t}}) + } + + if pc.configuration.AlwaysNegotiateDataChannels || pc.sctpTransport.dataChannelsRequested != 0 { + mediaSections = append(mediaSections, mediaSection{ + id: strconv.Itoa(len(mediaSections)), + data: true, + sctpInit: localSctpInit, + }) + } + } + + dtlsFingerprints, err := pc.configuration.Certificates[0].GetFingerprints() + if err != nil { + return nil, err + } + + return populateSDP( + desc, + isPlanB, + dtlsFingerprints, + pc.api.settingEngine.sdpMediaLevelFingerprints, + pc.api.settingEngine.candidates.ICELite, + true, + pc.api.mediaEngine, + connectionRoleFromDtlsRole(defaultDtlsRoleOffer), + candidates, + iceParams, + mediaSections, + pc.ICEGatheringState(), + nil, + pc.api.settingEngine.getSCTPMaxMessageSize(), + false, + ) +} + +// generateMatchedSDP generates a SDP and takes the remote state into account. +// This is used everytime we have a RemoteDescription +// +//nolint:gocognit,gocyclo,cyclop,maintidx +func (pc *PeerConnection) generateMatchedSDP( + transceivers []*RTPTransceiver, + useIdentity, includeUnmatched bool, + connectionRole sdp.ConnectionRole, + ignoreRidPauseForRecv bool, +) (*sdp.SessionDescription, error) { + desc, err := sdp.NewJSEPSessionDescription(useIdentity) + if err != nil { + return nil, err + } + desc.Attributes = append(desc.Attributes, sdp.Attribute{Key: sdp.AttrKeyMsidSemantic, Value: "WMS *"}) + + iceParams, err := pc.iceGatherer.GetLocalParameters() + if err != nil { + return nil, err + } + + candidates, err := pc.iceGatherer.GetLocalCandidates() + if err != nil { + return nil, err + } + + var transceiver *RTPTransceiver + remoteDescription := pc.currentRemoteDescription + if pc.pendingRemoteDescription != nil { + remoteDescription = pc.pendingRemoteDescription + } + isExtmapAllowMixed := isExtMapAllowMixedSet(remoteDescription.parsed) + localTransceivers := append([]*RTPTransceiver{}, transceivers...) + + detectedPlanB := descriptionIsPlanB(remoteDescription, pc.log) + if pc.configuration.SDPSemantics != SDPSemanticsUnifiedPlan { + detectedPlanB = descriptionPossiblyPlanB(remoteDescription) + } + + mediaSections := []mediaSection{} + alreadyHaveApplicationMediaSection := false + var localSctpInit []byte + for _, media := range remoteDescription.parsed.MediaDescriptions { + midValue := getMidValue(media) + if midValue == "" { + return nil, errPeerConnRemoteDescriptionWithoutMidValue + } + + if media.MediaName.Media == mediaSectionApplication { + init, _ := getSctpInit(media) + if init != nil && pc.api.settingEngine.sctp.enableSnap { + pc.sctpTransport.lock.Lock() + localSctpInit = pc.sctpTransport.GetSctpInit() + pc.sctpTransport.lock.Unlock() + } + + mediaSections = append(mediaSections, mediaSection{id: midValue, data: true, sctpInit: localSctpInit}) + alreadyHaveApplicationMediaSection = true + + continue + } + + kind := NewRTPCodecType(media.MediaName.Media) + direction := getPeerDirection(media) + if kind == 0 || direction == RTPTransceiverDirectionUnknown { + continue + } + + sdpSemantics := pc.configuration.SDPSemantics + + switch { + case sdpSemantics == SDPSemanticsPlanB || sdpSemantics == SDPSemanticsUnifiedPlanWithFallback && detectedPlanB: + if !detectedPlanB { + return nil, &rtcerr.TypeError{ + Err: fmt.Errorf("%w: Expected PlanB, but RemoteDescription is UnifiedPlan", ErrIncorrectSDPSemantics), + } + } + // If we're responding to a plan-b offer, then we should try to fill up this + // media entry with all matching local transceivers + mediaTransceivers := []*RTPTransceiver{} + for { + // keep going until we can't get any more + transceiver, localTransceivers = satisfyTypeAndDirection(kind, direction, localTransceivers) + if transceiver == nil { + if len(mediaTransceivers) == 0 { + transceiver = &RTPTransceiver{kind: kind, api: pc.api, codecs: pc.api.mediaEngine.getCodecsByKind(kind)} + transceiver.setDirection(RTPTransceiverDirectionInactive) + mediaTransceivers = append(mediaTransceivers, transceiver) + } + + break + } + if sender := transceiver.Sender(); sender != nil { + sender.setNegotiated() + } + mediaTransceivers = append(mediaTransceivers, transceiver) + } + mediaSections = append(mediaSections, mediaSection{id: midValue, transceivers: mediaTransceivers}) + case sdpSemantics == SDPSemanticsUnifiedPlan || sdpSemantics == SDPSemanticsUnifiedPlanWithFallback: + if detectedPlanB { + return nil, &rtcerr.TypeError{ + Err: fmt.Errorf( + "%w: Expected UnifiedPlan, but RemoteDescription is PlanB", + ErrIncorrectSDPSemantics, + ), + } + } + transceiver, localTransceivers = findByMid(midValue, localTransceivers) + if transceiver == nil { + return nil, fmt.Errorf("%w: %q", errPeerConnTranscieverMidNil, midValue) + } + if sender := transceiver.Sender(); sender != nil { + sender.setNegotiated() + } + mediaTransceivers := []*RTPTransceiver{transceiver} + + extensions, _ := rtpExtensionsFromMediaDescription(media) + mediaSections = append( + mediaSections, + mediaSection{id: midValue, transceivers: mediaTransceivers, matchExtensions: extensions, rids: getRids(media)}, + ) + } + } + + pc.sctpTransport.lock.Lock() + defer pc.sctpTransport.lock.Unlock() + + var bundleGroup *string + // If we are offering also include unmatched local transceivers + if includeUnmatched { //nolint:nestif + if !detectedPlanB { + for _, t := range localTransceivers { + if sender := t.Sender(); sender != nil { + sender.setNegotiated() + } + mediaSections = append(mediaSections, mediaSection{id: t.Mid(), transceivers: []*RTPTransceiver{t}}) + } + } + + if (pc.configuration.AlwaysNegotiateDataChannels || pc.sctpTransport.dataChannelsRequested != 0) && + !alreadyHaveApplicationMediaSection { + if detectedPlanB { + mediaSections = append(mediaSections, mediaSection{id: "data", data: true}) + } else { + mediaSections = append(mediaSections, mediaSection{ + id: strconv.Itoa(len(mediaSections)), + data: true, + sctpInit: localSctpInit, + }) + } + } + } else if remoteDescription != nil { + groupValue, _ := remoteDescription.parsed.Attribute(sdp.AttrKeyGroup) + groupValue = strings.TrimLeft(groupValue, "BUNDLE") + bundleGroup = &groupValue + } + + if pc.configuration.SDPSemantics == SDPSemanticsUnifiedPlanWithFallback && detectedPlanB { + pc.log.Info("Plan-B Offer detected; responding with Plan-B Answer") + } + + dtlsFingerprints, err := pc.configuration.Certificates[0].GetFingerprints() + if err != nil { + return nil, err + } + + return populateSDP( + desc, + detectedPlanB, + dtlsFingerprints, + pc.api.settingEngine.sdpMediaLevelFingerprints, + pc.api.settingEngine.candidates.ICELite, + isExtmapAllowMixed, + pc.api.mediaEngine, + connectionRole, + candidates, + iceParams, + mediaSections, + pc.ICEGatheringState(), + bundleGroup, + pc.api.settingEngine.getSCTPMaxMessageSize(), + ignoreRidPauseForRecv, + ) +} + +func (pc *PeerConnection) setGatherCompleteHandler(handler func()) { + pc.iceGatherer.onGatheringCompleteHandler.Store(handler) +} + +// SCTP returns the SCTPTransport for this PeerConnection +// +// The SCTP transport over which SCTP data is sent and received. If SCTP has not been negotiated, the value is nil. +// https://www.w3.org/TR/webrtc/#attributes-15 +func (pc *PeerConnection) SCTP() *SCTPTransport { + return pc.sctpTransport +} diff --git a/vendor/github.com/pion/webrtc/v4/peerconnection_js.go b/vendor/github.com/pion/webrtc/v4/peerconnection_js.go new file mode 100644 index 0000000..87460bd --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/peerconnection_js.go @@ -0,0 +1,792 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +//go:build js && wasm +// +build js,wasm + +// Package webrtc implements the WebRTC 1.0 as defined in W3C WebRTC specification document. +package webrtc + +import ( + "syscall/js" + + "github.com/pion/ice/v4" + "github.com/pion/webrtc/v4/pkg/rtcerr" +) + +// PeerConnection represents a WebRTC connection that establishes a +// peer-to-peer communications with another PeerConnection instance in a +// browser, or to another endpoint implementing the required protocols. +type PeerConnection struct { + // Pointer to the underlying JavaScript RTCPeerConnection object. + underlying js.Value + + // Keep track of handlers/callbacks so we can call Release as required by the + // syscall/js API. Initially nil. + onSignalingStateChangeHandler *js.Func + onDataChannelHandler *js.Func + onNegotiationNeededHandler *js.Func + onConnectionStateChangeHandler *js.Func + onICEConnectionStateChangeHandler *js.Func + onICECandidateHandler *js.Func + onICEGatheringStateChangeHandler *js.Func + + // Used by GatheringCompletePromise + onGatherCompleteHandler func() + + // A reference to the associated API state used by this connection + api *API +} + +// NewPeerConnection creates a peerconnection. +func NewPeerConnection(configuration Configuration) (*PeerConnection, error) { + api := NewAPI() + return api.NewPeerConnection(configuration) +} + +// NewPeerConnection creates a new PeerConnection with the provided configuration against the received API object +func (api *API) NewPeerConnection(configuration Configuration) (_ *PeerConnection, err error) { + defer func() { + if e := recover(); e != nil { + err = recoveryToError(e) + } + }() + configMap := configurationToValue(configuration) + underlying := js.Global().Get("window").Get("RTCPeerConnection").New(configMap) + return &PeerConnection{ + underlying: underlying, + api: api, + }, nil +} + +// JSValue returns the underlying PeerConnection +func (pc *PeerConnection) JSValue() js.Value { + return pc.underlying +} + +// OnSignalingStateChange sets an event handler which is invoked when the +// peer connection's signaling state changes +func (pc *PeerConnection) OnSignalingStateChange(f func(SignalingState)) { + if pc.onSignalingStateChangeHandler != nil { + oldHandler := pc.onSignalingStateChangeHandler + defer oldHandler.Release() + } + onSignalingStateChangeHandler := js.FuncOf(func(this js.Value, args []js.Value) any { + state := newSignalingState(args[0].String()) + go f(state) + return js.Undefined() + }) + pc.onSignalingStateChangeHandler = &onSignalingStateChangeHandler + pc.underlying.Set("onsignalingstatechange", onSignalingStateChangeHandler) +} + +// OnDataChannel sets an event handler which is invoked when a data +// channel message arrives from a remote peer. +func (pc *PeerConnection) OnDataChannel(f func(*DataChannel)) { + if pc.onDataChannelHandler != nil { + oldHandler := pc.onDataChannelHandler + defer oldHandler.Release() + } + onDataChannelHandler := js.FuncOf(func(this js.Value, args []js.Value) any { + // pion/webrtc/projects/15 + // This reference to the underlying DataChannel doesn't know + // about any other references to the same DataChannel. This might result in + // memory leaks where we don't clean up handler functions. Could possibly fix + // by keeping a mutex-protected list of all DataChannel references as a + // property of this PeerConnection, but at the cost of additional overhead. + dataChannel := &DataChannel{ + underlying: args[0].Get("channel"), + api: pc.api, + } + go f(dataChannel) + return js.Undefined() + }) + pc.onDataChannelHandler = &onDataChannelHandler + pc.underlying.Set("ondatachannel", onDataChannelHandler) +} + +// OnNegotiationNeeded sets an event handler which is invoked when +// a change has occurred which requires session negotiation +func (pc *PeerConnection) OnNegotiationNeeded(f func()) { + if pc.onNegotiationNeededHandler != nil { + oldHandler := pc.onNegotiationNeededHandler + defer oldHandler.Release() + } + onNegotiationNeededHandler := js.FuncOf(func(this js.Value, args []js.Value) any { + go f() + return js.Undefined() + }) + pc.onNegotiationNeededHandler = &onNegotiationNeededHandler + pc.underlying.Set("onnegotiationneeded", onNegotiationNeededHandler) +} + +// OnICEConnectionStateChange sets an event handler which is called +// when an ICE connection state is changed. +func (pc *PeerConnection) OnICEConnectionStateChange(f func(ICEConnectionState)) { + if pc.onICEConnectionStateChangeHandler != nil { + oldHandler := pc.onICEConnectionStateChangeHandler + defer oldHandler.Release() + } + onICEConnectionStateChangeHandler := js.FuncOf(func(this js.Value, args []js.Value) any { + connectionState := NewICEConnectionState(pc.underlying.Get("iceConnectionState").String()) + go f(connectionState) + return js.Undefined() + }) + pc.onICEConnectionStateChangeHandler = &onICEConnectionStateChangeHandler + pc.underlying.Set("oniceconnectionstatechange", onICEConnectionStateChangeHandler) +} + +// OnConnectionStateChange sets an event handler which is called +// when an PeerConnectionState is changed. +func (pc *PeerConnection) OnConnectionStateChange(f func(PeerConnectionState)) { + if pc.onConnectionStateChangeHandler != nil { + oldHandler := pc.onConnectionStateChangeHandler + defer oldHandler.Release() + } + onConnectionStateChangeHandler := js.FuncOf(func(this js.Value, args []js.Value) any { + connectionState := newPeerConnectionState(pc.underlying.Get("connectionState").String()) + go f(connectionState) + return js.Undefined() + }) + pc.onConnectionStateChangeHandler = &onConnectionStateChangeHandler + pc.underlying.Set("onconnectionstatechange", onConnectionStateChangeHandler) +} + +func (pc *PeerConnection) checkConfiguration(configuration Configuration) error { + // https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-setconfiguration (step #2) + if pc.ConnectionState() == PeerConnectionStateClosed { + return &rtcerr.InvalidStateError{Err: ErrConnectionClosed} + } + + existingConfig := pc.GetConfiguration() + // https://www.w3.org/TR/webrtc/#set-the-configuration (step #3) + if configuration.PeerIdentity != "" { + if configuration.PeerIdentity != existingConfig.PeerIdentity { + return &rtcerr.InvalidModificationError{Err: ErrModifyingPeerIdentity} + } + } + + // https://github.com/pion/webrtc/issues/513 + // https://www.w3.org/TR/webrtc/#set-the-configuration (step #4) + // if len(configuration.Certificates) > 0 { + // if len(configuration.Certificates) != len(existingConfiguration.Certificates) { + // return &rtcerr.InvalidModificationError{Err: ErrModifyingCertificates} + // } + + // for i, certificate := range configuration.Certificates { + // if !pc.configuration.Certificates[i].Equals(certificate) { + // return &rtcerr.InvalidModificationError{Err: ErrModifyingCertificates} + // } + // } + // pc.configuration.Certificates = configuration.Certificates + // } + + // https://www.w3.org/TR/webrtc/#set-the-configuration (step #5) + if configuration.BundlePolicy != BundlePolicyUnknown { + if configuration.BundlePolicy != existingConfig.BundlePolicy { + return &rtcerr.InvalidModificationError{Err: ErrModifyingBundlePolicy} + } + } + + // https://www.w3.org/TR/webrtc/#set-the-configuration (step #6) + if configuration.RTCPMuxPolicy != RTCPMuxPolicyUnknown { + if configuration.RTCPMuxPolicy != existingConfig.RTCPMuxPolicy { + return &rtcerr.InvalidModificationError{Err: ErrModifyingRTCPMuxPolicy} + } + } + + // https://www.w3.org/TR/webrtc/#set-the-configuration (step #7) + if configuration.ICECandidatePoolSize != 0 { + if configuration.ICECandidatePoolSize != existingConfig.ICECandidatePoolSize && + pc.LocalDescription() != nil { + return &rtcerr.InvalidModificationError{Err: ErrModifyingICECandidatePoolSize} + } + } + + // https://www.w3.org/TR/webrtc/#set-the-configuration (step #11) + if len(configuration.ICEServers) > 0 { + // https://www.w3.org/TR/webrtc/#set-the-configuration (step #11.3) + for _, server := range configuration.ICEServers { + if _, err := server.validate(); err != nil { + return err + } + } + } + return nil +} + +// SetConfiguration updates the configuration of this PeerConnection object. +func (pc *PeerConnection) SetConfiguration(configuration Configuration) (err error) { + defer func() { + if e := recover(); e != nil { + err = recoveryToError(e) + } + }() + if err := pc.checkConfiguration(configuration); err != nil { + return err + } + configMap := configurationToValue(configuration) + pc.underlying.Call("setConfiguration", configMap) + return nil +} + +// GetConfiguration returns a Configuration object representing the current +// configuration of this PeerConnection object. The returned object is a +// copy and direct mutation on it will not take affect until SetConfiguration +// has been called with Configuration passed as its only argument. +// https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-getconfiguration +func (pc *PeerConnection) GetConfiguration() Configuration { + return valueToConfiguration(pc.underlying.Call("getConfiguration")) +} + +// CreateOffer starts the PeerConnection and generates the localDescription +func (pc *PeerConnection) CreateOffer(options *OfferOptions) (_ SessionDescription, err error) { + defer func() { + if e := recover(); e != nil { + err = recoveryToError(e) + } + }() + promise := pc.underlying.Call("createOffer", offerOptionsToValue(options)) + desc, err := awaitPromise(promise) + if err != nil { + return SessionDescription{}, err + } + return *valueToSessionDescription(desc), nil +} + +// CreateAnswer starts the PeerConnection and generates the localDescription +func (pc *PeerConnection) CreateAnswer(options *AnswerOptions) (_ SessionDescription, err error) { + defer func() { + if e := recover(); e != nil { + err = recoveryToError(e) + } + }() + promise := pc.underlying.Call("createAnswer", answerOptionsToValue(options)) + desc, err := awaitPromise(promise) + if err != nil { + return SessionDescription{}, err + } + return *valueToSessionDescription(desc), nil +} + +// SetLocalDescription sets the SessionDescription of the local peer +func (pc *PeerConnection) SetLocalDescription(desc SessionDescription) (err error) { + defer func() { + if e := recover(); e != nil { + err = recoveryToError(e) + } + }() + promise := pc.underlying.Call("setLocalDescription", sessionDescriptionToValue(&desc)) + _, err = awaitPromise(promise) + return err +} + +// LocalDescription returns PendingLocalDescription if it is not null and +// otherwise it returns CurrentLocalDescription. This property is used to +// determine if setLocalDescription has already been called. +// https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-localdescription +func (pc *PeerConnection) LocalDescription() *SessionDescription { + return valueToSessionDescription(pc.underlying.Get("localDescription")) +} + +// SetRemoteDescription sets the SessionDescription of the remote peer +func (pc *PeerConnection) SetRemoteDescription(desc SessionDescription) (err error) { + defer func() { + if e := recover(); e != nil { + err = recoveryToError(e) + } + }() + promise := pc.underlying.Call("setRemoteDescription", sessionDescriptionToValue(&desc)) + _, err = awaitPromise(promise) + return err +} + +// RemoteDescription returns PendingRemoteDescription if it is not null and +// otherwise it returns CurrentRemoteDescription. This property is used to +// determine if setRemoteDescription has already been called. +// https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-remotedescription +func (pc *PeerConnection) RemoteDescription() *SessionDescription { + return valueToSessionDescription(pc.underlying.Get("remoteDescription")) +} + +// AddICECandidate accepts an ICE candidate string and adds it +// to the existing set of candidates +func (pc *PeerConnection) AddICECandidate(candidate ICECandidateInit) (err error) { + defer func() { + if e := recover(); e != nil { + err = recoveryToError(e) + } + }() + promise := pc.underlying.Call("addIceCandidate", iceCandidateInitToValue(candidate)) + _, err = awaitPromise(promise) + return err +} + +// ICEConnectionState returns the ICE connection state of the +// PeerConnection instance. +func (pc *PeerConnection) ICEConnectionState() ICEConnectionState { + return NewICEConnectionState(pc.underlying.Get("iceConnectionState").String()) +} + +// OnICECandidate sets an event handler which is invoked when a new ICE +// candidate is found. +func (pc *PeerConnection) OnICECandidate(f func(candidate *ICECandidate)) { + if pc.onICECandidateHandler != nil { + oldHandler := pc.onICECandidateHandler + defer oldHandler.Release() + } + onICECandidateHandler := js.FuncOf(func(this js.Value, args []js.Value) any { + candidate := valueToICECandidate(args[0].Get("candidate")) + if candidate == nil && pc.onGatherCompleteHandler != nil { + go pc.onGatherCompleteHandler() + } + + go f(candidate) + return js.Undefined() + }) + pc.onICECandidateHandler = &onICECandidateHandler + pc.underlying.Set("onicecandidate", onICECandidateHandler) +} + +// OnICEGatheringStateChange sets an event handler which is invoked when the +// ICE candidate gathering state has changed. +func (pc *PeerConnection) OnICEGatheringStateChange(f func()) { + if pc.onICEGatheringStateChangeHandler != nil { + oldHandler := pc.onICEGatheringStateChangeHandler + defer oldHandler.Release() + } + onICEGatheringStateChangeHandler := js.FuncOf(func(this js.Value, args []js.Value) any { + go f() + return js.Undefined() + }) + pc.onICEGatheringStateChangeHandler = &onICEGatheringStateChangeHandler + pc.underlying.Set("onicegatheringstatechange", onICEGatheringStateChangeHandler) +} + +// CreateDataChannel creates a new DataChannel object with the given label +// and optional DataChannelInit used to configure properties of the +// underlying channel such as data reliability. +func (pc *PeerConnection) CreateDataChannel(label string, options *DataChannelInit) (_ *DataChannel, err error) { + defer func() { + if e := recover(); e != nil { + err = recoveryToError(e) + } + }() + channel := pc.underlying.Call("createDataChannel", label, dataChannelInitToValue(options)) + return &DataChannel{ + underlying: channel, + api: pc.api, + }, nil +} + +// SetIdentityProvider is used to configure an identity provider to generate identity assertions +func (pc *PeerConnection) SetIdentityProvider(provider string) (err error) { + defer func() { + if e := recover(); e != nil { + err = recoveryToError(e) + } + }() + pc.underlying.Call("setIdentityProvider", provider) + return nil +} + +// Close ends the PeerConnection +func (pc *PeerConnection) Close() (err error) { + defer func() { + if e := recover(); e != nil { + err = recoveryToError(e) + } + }() + + pc.underlying.Call("close") + + // Release any handlers as required by the syscall/js API. + if pc.onSignalingStateChangeHandler != nil { + pc.onSignalingStateChangeHandler.Release() + } + if pc.onDataChannelHandler != nil { + pc.onDataChannelHandler.Release() + } + if pc.onNegotiationNeededHandler != nil { + pc.onNegotiationNeededHandler.Release() + } + if pc.onConnectionStateChangeHandler != nil { + pc.onConnectionStateChangeHandler.Release() + } + if pc.onICEConnectionStateChangeHandler != nil { + pc.onICEConnectionStateChangeHandler.Release() + } + if pc.onICECandidateHandler != nil { + pc.onICECandidateHandler.Release() + } + if pc.onICEGatheringStateChangeHandler != nil { + pc.onICEGatheringStateChangeHandler.Release() + } + + return nil +} + +// CurrentLocalDescription represents the local description that was +// successfully negotiated the last time the PeerConnection transitioned +// into the stable state plus any local candidates that have been generated +// by the ICEAgent since the offer or answer was created. +func (pc *PeerConnection) CurrentLocalDescription() *SessionDescription { + desc := pc.underlying.Get("currentLocalDescription") + return valueToSessionDescription(desc) +} + +// PendingLocalDescription represents a local description that is in the +// process of being negotiated plus any local candidates that have been +// generated by the ICEAgent since the offer or answer was created. If the +// PeerConnection is in the stable state, the value is null. +func (pc *PeerConnection) PendingLocalDescription() *SessionDescription { + desc := pc.underlying.Get("pendingLocalDescription") + return valueToSessionDescription(desc) +} + +// CurrentRemoteDescription represents the last remote description that was +// successfully negotiated the last time the PeerConnection transitioned +// into the stable state plus any remote candidates that have been supplied +// via AddICECandidate() since the offer or answer was created. +func (pc *PeerConnection) CurrentRemoteDescription() *SessionDescription { + desc := pc.underlying.Get("currentRemoteDescription") + return valueToSessionDescription(desc) +} + +// PendingRemoteDescription represents a remote description that is in the +// process of being negotiated, complete with any remote candidates that +// have been supplied via AddICECandidate() since the offer or answer was +// created. If the PeerConnection is in the stable state, the value is +// null. +func (pc *PeerConnection) PendingRemoteDescription() *SessionDescription { + desc := pc.underlying.Get("pendingRemoteDescription") + return valueToSessionDescription(desc) +} + +// CanTrickleICECandidates reports whether the remote endpoint indicated +// support for receiving trickled ICE candidates. +func (pc *PeerConnection) CanTrickleICECandidates() ICETrickleCapability { + val := pc.underlying.Get("canTrickleIceCandidates") + if val.IsNull() || val.IsUndefined() { + return ICETrickleCapabilityUnknown + } + + if val.Bool() { + return ICETrickleCapabilitySupported + } + + return ICETrickleCapabilityUnsupported +} + +// SignalingState returns the signaling state of the PeerConnection instance. +func (pc *PeerConnection) SignalingState() SignalingState { + rawState := pc.underlying.Get("signalingState").String() + return newSignalingState(rawState) +} + +// ICEGatheringState attribute the ICE gathering state of the PeerConnection +// instance. +func (pc *PeerConnection) ICEGatheringState() ICEGatheringState { + rawState := pc.underlying.Get("iceGatheringState").String() + return NewICEGatheringState(rawState) +} + +// ConnectionState attribute the connection state of the PeerConnection +// instance. +func (pc *PeerConnection) ConnectionState() PeerConnectionState { + rawState := pc.underlying.Get("connectionState").String() + return newPeerConnectionState(rawState) +} + +func (pc *PeerConnection) setGatherCompleteHandler(handler func()) { + pc.onGatherCompleteHandler = handler + + // If no onIceCandidate handler has been set provide an empty one + // otherwise our onGatherCompleteHandler will not be executed + if pc.onICECandidateHandler == nil { + pc.OnICECandidate(func(i *ICECandidate) {}) + } +} + +// AddTransceiverFromKind Create a new RtpTransceiver and adds it to the set of transceivers. +func (pc *PeerConnection) AddTransceiverFromKind(kind RTPCodecType, init ...RTPTransceiverInit) (transceiver *RTPTransceiver, err error) { + defer func() { + if e := recover(); e != nil { + err = recoveryToError(e) + } + }() + + if len(init) == 1 { + return &RTPTransceiver{ + underlying: pc.underlying.Call("addTransceiver", kind.String(), rtpTransceiverInitInitToValue(init[0])), + }, err + } + + return &RTPTransceiver{ + underlying: pc.underlying.Call("addTransceiver", kind.String()), + }, err +} + +// GetTransceivers returns the RtpTransceiver that are currently attached to this PeerConnection +func (pc *PeerConnection) GetTransceivers() (transceivers []*RTPTransceiver) { + rawTransceivers := pc.underlying.Call("getTransceivers") + transceivers = make([]*RTPTransceiver, rawTransceivers.Length()) + + for i := 0; i < rawTransceivers.Length(); i++ { + transceivers[i] = &RTPTransceiver{ + underlying: rawTransceivers.Index(i), + } + } + + return +} + +// SCTP returns the SCTPTransport for this PeerConnection +// +// The SCTP transport over which SCTP data is sent and received. If SCTP has not been negotiated, the value is nil. +// https://www.w3.org/TR/webrtc/#attributes-15 +func (pc *PeerConnection) SCTP() *SCTPTransport { + underlying := pc.underlying.Get("sctp") + if underlying.IsNull() || underlying.IsUndefined() { + return nil + } + + return &SCTPTransport{ + underlying: underlying, + } +} + +// Converts a Configuration to js.Value so it can be passed +// through to the JavaScript WebRTC API. Any zero values are converted to +// js.Undefined(), which will result in the default value being used. +func configurationToValue(configuration Configuration) js.Value { + return js.ValueOf(map[string]any{ + "iceServers": iceServersToValue(configuration.ICEServers), + "iceTransportPolicy": stringEnumToValueOrUndefined(configuration.ICETransportPolicy.String()), + "bundlePolicy": stringEnumToValueOrUndefined(configuration.BundlePolicy.String()), + "rtcpMuxPolicy": stringEnumToValueOrUndefined(configuration.RTCPMuxPolicy.String()), + "peerIdentity": stringToValueOrUndefined(configuration.PeerIdentity), + "iceCandidatePoolSize": uint8ToValueOrUndefined(configuration.ICECandidatePoolSize), + "alwaysNegotiateDataChannels": boolToValueOrUndefined(configuration.AlwaysNegotiateDataChannels), + + // Note: Certificates are not currently supported. + // "certificates": configuration.Certificates, + }) +} + +func iceServersToValue(iceServers []ICEServer) js.Value { + if len(iceServers) == 0 { + return js.Undefined() + } + maps := make([]any, len(iceServers)) + for i, server := range iceServers { + maps[i] = iceServerToValue(server) + } + return js.ValueOf(maps) +} + +func oauthCredentialToValue(o OAuthCredential) js.Value { + out := map[string]any{ + "MACKey": o.MACKey, + "AccessToken": o.AccessToken, + } + return js.ValueOf(out) +} + +func iceServerToValue(server ICEServer) js.Value { + out := map[string]any{ + "urls": stringsToValue(server.URLs), // required + } + if server.Username != "" { + out["username"] = stringToValueOrUndefined(server.Username) + } + if server.Credential != nil { + switch t := server.Credential.(type) { + case string: + out["credential"] = stringToValueOrUndefined(t) + case OAuthCredential: + out["credential"] = oauthCredentialToValue(t) + } + } + out["credentialType"] = stringEnumToValueOrUndefined(server.CredentialType.String()) + return js.ValueOf(out) +} + +func valueToConfiguration(configValue js.Value) Configuration { + if configValue.IsNull() || configValue.IsUndefined() { + return Configuration{} + } + return Configuration{ + ICEServers: valueToICEServers(configValue.Get("iceServers")), + ICETransportPolicy: NewICETransportPolicy(valueToStringOrZero(configValue.Get("iceTransportPolicy"))), + BundlePolicy: newBundlePolicy(valueToStringOrZero(configValue.Get("bundlePolicy"))), + RTCPMuxPolicy: newRTCPMuxPolicy(valueToStringOrZero(configValue.Get("rtcpMuxPolicy"))), + PeerIdentity: valueToStringOrZero(configValue.Get("peerIdentity")), + ICECandidatePoolSize: valueToUint8OrZero(configValue.Get("iceCandidatePoolSize")), + AlwaysNegotiateDataChannels: valueToBoolOrFalse(configValue.Get("alwaysNegotiateDataChannels")), + + // Note: Certificates are not supported. + // Certificates []Certificate + } +} + +func valueToICEServers(iceServersValue js.Value) []ICEServer { + if iceServersValue.IsNull() || iceServersValue.IsUndefined() { + return nil + } + iceServers := make([]ICEServer, iceServersValue.Length()) + for i := 0; i < iceServersValue.Length(); i++ { + iceServers[i] = valueToICEServer(iceServersValue.Index(i)) + } + return iceServers +} + +func valueToICECredential(iceCredentialValue js.Value) any { + if iceCredentialValue.IsNull() || iceCredentialValue.IsUndefined() { + return nil + } + if iceCredentialValue.Type() == js.TypeString { + return iceCredentialValue.String() + } + if iceCredentialValue.Type() == js.TypeObject { + return OAuthCredential{ + MACKey: iceCredentialValue.Get("MACKey").String(), + AccessToken: iceCredentialValue.Get("AccessToken").String(), + } + } + return nil +} + +func valueToICEServer(iceServerValue js.Value) ICEServer { + tpe, err := newICECredentialType(valueToStringOrZero(iceServerValue.Get("credentialType"))) + if err != nil { + tpe = ICECredentialTypePassword + } + s := ICEServer{ + URLs: valueToStrings(iceServerValue.Get("urls")), // required + Username: valueToStringOrZero(iceServerValue.Get("username")), + // Note: Credential and CredentialType are not currently supported. + Credential: valueToICECredential(iceServerValue.Get("credential")), + CredentialType: tpe, + } + + return s +} + +func valueToICECandidate(val js.Value) *ICECandidate { + if val.IsNull() || val.IsUndefined() { + return nil + } + if val.Get("protocol").IsUndefined() && !val.Get("candidate").IsUndefined() { + // Missing some fields, assume it's Firefox and parse SDP candidate. + c, err := ice.UnmarshalCandidate(val.Get("candidate").String()) + if err != nil { + return nil + } + + iceCandidate, err := newICECandidateFromICE(c, "", 0) + if err != nil { + return nil + } + + return &iceCandidate + } + protocol, _ := NewICEProtocol(val.Get("protocol").String()) + candidateType, _ := NewICECandidateType(val.Get("type").String()) + return &ICECandidate{ + Foundation: val.Get("foundation").String(), + Priority: valueToUint32OrZero(val.Get("priority")), + Address: val.Get("address").String(), + Protocol: protocol, + Port: valueToUint16OrZero(val.Get("port")), + Typ: candidateType, + Component: stringToComponentIDOrZero(val.Get("component").String()), + RelatedAddress: val.Get("relatedAddress").String(), + RelatedPort: valueToUint16OrZero(val.Get("relatedPort")), + } +} + +func stringToComponentIDOrZero(val string) uint16 { + // See: https://developer.mozilla.org/en-US/docs/Web/API/RTCIceComponent + switch val { + case "rtp": + return 1 + case "rtcp": + return 2 + } + return 0 +} + +func sessionDescriptionToValue(desc *SessionDescription) js.Value { + if desc == nil { + return js.Undefined() + } + return js.ValueOf(map[string]any{ + "type": desc.Type.String(), + "sdp": desc.SDP, + }) +} + +func valueToSessionDescription(descValue js.Value) *SessionDescription { + if descValue.IsNull() || descValue.IsUndefined() { + return nil + } + + return &SessionDescription{ + Type: NewSDPType(descValue.Get("type").String()), + SDP: descValue.Get("sdp").String(), + } +} + +func offerOptionsToValue(offerOptions *OfferOptions) js.Value { + if offerOptions == nil { + return js.Undefined() + } + return js.ValueOf(map[string]any{ + "iceRestart": offerOptions.ICERestart, + "voiceActivityDetection": offerOptions.VoiceActivityDetection, + }) +} + +func answerOptionsToValue(answerOptions *AnswerOptions) js.Value { + if answerOptions == nil { + return js.Undefined() + } + return js.ValueOf(map[string]any{ + "voiceActivityDetection": answerOptions.VoiceActivityDetection, + }) +} + +func iceCandidateInitToValue(candidate ICECandidateInit) js.Value { + return js.ValueOf(map[string]any{ + "candidate": candidate.Candidate, + "sdpMid": stringPointerToValue(candidate.SDPMid), + "sdpMLineIndex": uint16PointerToValue(candidate.SDPMLineIndex), + "usernameFragment": stringPointerToValue(candidate.UsernameFragment), + }) +} + +func dataChannelInitToValue(options *DataChannelInit) js.Value { + if options == nil { + return js.Undefined() + } + + maxPacketLifeTime := uint16PointerToValue(options.MaxPacketLifeTime) + return js.ValueOf(map[string]any{ + "ordered": boolPointerToValue(options.Ordered), + "maxPacketLifeTime": maxPacketLifeTime, + // See https://bugs.chromium.org/p/chromium/issues/detail?id=696681 + // Chrome calls this "maxRetransmitTime" + "maxRetransmitTime": maxPacketLifeTime, + "maxRetransmits": uint16PointerToValue(options.MaxRetransmits), + "protocol": stringPointerToValue(options.Protocol), + "negotiated": boolPointerToValue(options.Negotiated), + "id": uint16PointerToValue(options.ID), + }) +} + +func rtpTransceiverInitInitToValue(init RTPTransceiverInit) js.Value { + return js.ValueOf(map[string]any{ + "direction": init.Direction.String(), + }) +} diff --git a/vendor/github.com/pion/webrtc/v4/peerconnectionstate.go b/vendor/github.com/pion/webrtc/v4/peerconnectionstate.go new file mode 100644 index 0000000..7ba2310 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/peerconnectionstate.go @@ -0,0 +1,89 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package webrtc + +// PeerConnectionState indicates the state of the PeerConnection. +type PeerConnectionState int + +const ( + // PeerConnectionStateUnknown is the enum's zero-value. + PeerConnectionStateUnknown PeerConnectionState = iota + + // PeerConnectionStateNew indicates that any of the ICETransports or + // DTLSTransports are in the "new" state and none of the transports are + // in the "connecting", "checking", "failed" or "disconnected" state, or + // all transports are in the "closed" state, or there are no transports. + PeerConnectionStateNew + + // PeerConnectionStateConnecting indicates that any of the + // ICETransports or DTLSTransports are in the "connecting" or + // "checking" state and none of them is in the "failed" state. + PeerConnectionStateConnecting + + // PeerConnectionStateConnected indicates that all ICETransports and + // DTLSTransports are in the "connected", "completed" or "closed" state + // and at least one of them is in the "connected" or "completed" state. + PeerConnectionStateConnected + + // PeerConnectionStateDisconnected indicates that any of the + // ICETransports or DTLSTransports are in the "disconnected" state + // and none of them are in the "failed" or "connecting" or "checking" state. + PeerConnectionStateDisconnected + + // PeerConnectionStateFailed indicates that any of the ICETransports + // or DTLSTransports are in a "failed" state. + PeerConnectionStateFailed + + // PeerConnectionStateClosed indicates the peer connection is closed + // and the isClosed member variable of PeerConnection is true. + PeerConnectionStateClosed +) + +// This is done this way because of a linter. +const ( + peerConnectionStateNewStr = "new" + peerConnectionStateConnectingStr = "connecting" + peerConnectionStateConnectedStr = "connected" + peerConnectionStateDisconnectedStr = "disconnected" + peerConnectionStateFailedStr = "failed" + peerConnectionStateClosedStr = "closed" +) + +func newPeerConnectionState(raw string) PeerConnectionState { + switch raw { + case peerConnectionStateNewStr: + return PeerConnectionStateNew + case peerConnectionStateConnectingStr: + return PeerConnectionStateConnecting + case peerConnectionStateConnectedStr: + return PeerConnectionStateConnected + case peerConnectionStateDisconnectedStr: + return PeerConnectionStateDisconnected + case peerConnectionStateFailedStr: + return PeerConnectionStateFailed + case peerConnectionStateClosedStr: + return PeerConnectionStateClosed + default: + return PeerConnectionStateUnknown + } +} + +func (t PeerConnectionState) String() string { + switch t { + case PeerConnectionStateNew: + return peerConnectionStateNewStr + case PeerConnectionStateConnecting: + return peerConnectionStateConnectingStr + case PeerConnectionStateConnected: + return peerConnectionStateConnectedStr + case PeerConnectionStateDisconnected: + return peerConnectionStateDisconnectedStr + case PeerConnectionStateFailed: + return peerConnectionStateFailedStr + case PeerConnectionStateClosed: + return peerConnectionStateClosedStr + default: + return ErrUnknownType.Error() + } +} diff --git a/vendor/github.com/pion/webrtc/v4/pkg/media/media.go b/vendor/github.com/pion/webrtc/v4/pkg/media/media.go new file mode 100644 index 0000000..4b35ec8 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/pkg/media/media.go @@ -0,0 +1,35 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +// Package media provides media writer and filters +package media + +import ( + "time" + + "github.com/pion/rtp" +) + +// A Sample contains encoded media and timing information. +type Sample struct { + Data []byte + Timestamp time.Time + Duration time.Duration + PacketTimestamp uint32 + PrevDroppedPackets uint16 + Metadata any + + // RTP headers of RTP packets forming this Sample. (Optional) + // Useful for accessing RTP extensions associated to the Sample. + RTPHeaders []*rtp.Header +} + +// Writer defines an interface to handle +// the creation of media files. +type Writer interface { + // Add the content of an RTP packet to the media + WriteRTP(packet *rtp.Packet) error + // Close the media + // Note: Close implementation must be idempotent + Close() error +} diff --git a/vendor/github.com/pion/webrtc/v4/pkg/rtcerr/errors.go b/vendor/github.com/pion/webrtc/v4/pkg/rtcerr/errors.go new file mode 100644 index 0000000..ebc60b3 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/pkg/rtcerr/errors.go @@ -0,0 +1,163 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +// Package rtcerr implements the error wrappers defined throughout the +// WebRTC 1.0 specifications. +package rtcerr + +import ( + "fmt" +) + +// UnknownError indicates the operation failed for an unknown transient reason. +type UnknownError struct { + Err error +} + +func (e *UnknownError) Error() string { + return fmt.Sprintf("UnknownError: %v", e.Err) +} + +// Unwrap returns the result of calling the Unwrap method on err, if err's type contains +// an Unwrap method returning error. Otherwise, Unwrap returns nil. +func (e *UnknownError) Unwrap() error { + return e.Err +} + +// InvalidStateError indicates the object is in an invalid state. +type InvalidStateError struct { + Err error +} + +func (e *InvalidStateError) Error() string { + return fmt.Sprintf("InvalidStateError: %v", e.Err) +} + +// Unwrap returns the result of calling the Unwrap method on err, if err's type contains +// an Unwrap method returning error. Otherwise, Unwrap returns nil. +func (e *InvalidStateError) Unwrap() error { + return e.Err +} + +// InvalidAccessError indicates the object does not support the operation or +// argument. +type InvalidAccessError struct { + Err error +} + +func (e *InvalidAccessError) Error() string { + return fmt.Sprintf("InvalidAccessError: %v", e.Err) +} + +// Unwrap returns the result of calling the Unwrap method on err, if err's type contains +// an Unwrap method returning error. Otherwise, Unwrap returns nil. +func (e *InvalidAccessError) Unwrap() error { + return e.Err +} + +// NotSupportedError indicates the operation is not supported. +type NotSupportedError struct { + Err error +} + +func (e *NotSupportedError) Error() string { + return fmt.Sprintf("NotSupportedError: %v", e.Err) +} + +// Unwrap returns the result of calling the Unwrap method on err, if err's type contains +// an Unwrap method returning error. Otherwise, Unwrap returns nil. +func (e *NotSupportedError) Unwrap() error { + return e.Err +} + +// InvalidModificationError indicates the object cannot be modified in this way. +type InvalidModificationError struct { + Err error +} + +func (e *InvalidModificationError) Error() string { + return fmt.Sprintf("InvalidModificationError: %v", e.Err) +} + +// Unwrap returns the result of calling the Unwrap method on err, if err's type contains +// an Unwrap method returning error. Otherwise, Unwrap returns nil. +func (e *InvalidModificationError) Unwrap() error { + return e.Err +} + +// SyntaxError indicates the string did not match the expected pattern. +type SyntaxError struct { + Err error +} + +func (e *SyntaxError) Error() string { + return fmt.Sprintf("SyntaxError: %v", e.Err) +} + +// Unwrap returns the result of calling the Unwrap method on err, if err's type contains +// an Unwrap method returning error. Otherwise, Unwrap returns nil. +func (e *SyntaxError) Unwrap() error { + return e.Err +} + +// TypeError indicates an error when a value is not of the expected type. +type TypeError struct { + Err error +} + +func (e *TypeError) Error() string { + return fmt.Sprintf("TypeError: %v", e.Err) +} + +// Unwrap returns the result of calling the Unwrap method on err, if err's type contains +// an Unwrap method returning error. Otherwise, Unwrap returns nil. +func (e *TypeError) Unwrap() error { + return e.Err +} + +// OperationError indicates the operation failed for an operation-specific +// reason. +type OperationError struct { + Err error +} + +func (e *OperationError) Error() string { + return fmt.Sprintf("OperationError: %v", e.Err) +} + +// Unwrap returns the result of calling the Unwrap method on err, if err's type contains +// an Unwrap method returning error. Otherwise, Unwrap returns nil. +func (e *OperationError) Unwrap() error { + return e.Err +} + +// NotReadableError indicates the input/output read operation failed. +type NotReadableError struct { + Err error +} + +func (e *NotReadableError) Error() string { + return fmt.Sprintf("NotReadableError: %v", e.Err) +} + +// Unwrap returns the result of calling the Unwrap method on err, if err's type contains +// an Unwrap method returning error. Otherwise, Unwrap returns nil. +func (e *NotReadableError) Unwrap() error { + return e.Err +} + +// RangeError indicates an error when a value is not in the set or range +// of allowed values. +type RangeError struct { + Err error +} + +func (e *RangeError) Error() string { + return fmt.Sprintf("RangeError: %v", e.Err) +} + +// Unwrap returns the result of calling the Unwrap method on err, if err's type contains +// an Unwrap method returning error. Otherwise, Unwrap returns nil. +func (e *RangeError) Unwrap() error { + return e.Err +} diff --git a/vendor/github.com/pion/webrtc/v4/renovate.json b/vendor/github.com/pion/webrtc/v4/renovate.json new file mode 100644 index 0000000..f1bb98c --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/renovate.json @@ -0,0 +1,6 @@ +{ + "$schema": "https://docs.renovatebot.com/renovate-schema.json", + "extends": [ + "github>pion/renovate-config" + ] +} diff --git a/vendor/github.com/pion/webrtc/v4/rtcpfeedback.go b/vendor/github.com/pion/webrtc/v4/rtcpfeedback.go new file mode 100644 index 0000000..49b2604 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/rtcpfeedback.go @@ -0,0 +1,34 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package webrtc + +const ( + // TypeRTCPFBTransportCC .. + TypeRTCPFBTransportCC = "transport-cc" + + // TypeRTCPFBGoogREMB .. + TypeRTCPFBGoogREMB = "goog-remb" + + // TypeRTCPFBACK .. + TypeRTCPFBACK = "ack" + + // TypeRTCPFBCCM .. + TypeRTCPFBCCM = "ccm" + + // TypeRTCPFBNACK .. + TypeRTCPFBNACK = "nack" +) + +// RTCPFeedback signals the connection to use additional RTCP packet types. +// https://draft.ortc.org/#dom-rtcrtcpfeedback +type RTCPFeedback struct { + // Type is the type of feedback. + // see: https://draft.ortc.org/#dom-rtcrtcpfeedback + // valid: ack, ccm, nack, goog-remb, transport-cc + Type string + + // The parameter value depends on the type. + // For example, type="nack" parameter="pli" will send Picture Loss Indicator packets. + Parameter string +} diff --git a/vendor/github.com/pion/webrtc/v4/rtcpmuxpolicy.go b/vendor/github.com/pion/webrtc/v4/rtcpmuxpolicy.go new file mode 100644 index 0000000..bff5a03 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/rtcpmuxpolicy.go @@ -0,0 +1,73 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package webrtc + +import ( + "encoding/json" +) + +// RTCPMuxPolicy affects what ICE candidates are gathered to support +// non-multiplexed RTCP. +type RTCPMuxPolicy int + +const ( + // RTCPMuxPolicyUnknown is the enum's zero-value. + RTCPMuxPolicyUnknown RTCPMuxPolicy = iota + + // RTCPMuxPolicyNegotiate indicates to gather ICE candidates for both + // RTP and RTCP candidates. If the remote-endpoint is capable of + // multiplexing RTCP, multiplex RTCP on the RTP candidates. If it is not, + // use both the RTP and RTCP candidates separately. + RTCPMuxPolicyNegotiate + + // RTCPMuxPolicyRequire indicates to gather ICE candidates only for + // RTP and multiplex RTCP on the RTP candidates. If the remote endpoint is + // not capable of rtcp-mux, session negotiation will fail. + RTCPMuxPolicyRequire +) + +// This is done this way because of a linter. +const ( + rtcpMuxPolicyNegotiateStr = "negotiate" + rtcpMuxPolicyRequireStr = "require" +) + +func newRTCPMuxPolicy(raw string) RTCPMuxPolicy { + switch raw { + case rtcpMuxPolicyNegotiateStr: + return RTCPMuxPolicyNegotiate + case rtcpMuxPolicyRequireStr: + return RTCPMuxPolicyRequire + default: + return RTCPMuxPolicyUnknown + } +} + +func (t RTCPMuxPolicy) String() string { + switch t { + case RTCPMuxPolicyNegotiate: + return rtcpMuxPolicyNegotiateStr + case RTCPMuxPolicyRequire: + return rtcpMuxPolicyRequireStr + default: + return ErrUnknownType.Error() + } +} + +// UnmarshalJSON parses the JSON-encoded data and stores the result. +func (t *RTCPMuxPolicy) UnmarshalJSON(b []byte) error { + var val string + if err := json.Unmarshal(b, &val); err != nil { + return err + } + + *t = newRTCPMuxPolicy(val) + + return nil +} + +// MarshalJSON returns the JSON encoding. +func (t RTCPMuxPolicy) MarshalJSON() ([]byte, error) { + return json.Marshal(t.String()) +} diff --git a/vendor/github.com/pion/webrtc/v4/rtpcapabilities.go b/vendor/github.com/pion/webrtc/v4/rtpcapabilities.go new file mode 100644 index 0000000..90b647f --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/rtpcapabilities.go @@ -0,0 +1,12 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package webrtc + +// RTPCapabilities represents the capabilities of a transceiver +// +// https://w3c.github.io/webrtc-pc/#rtcrtpcapabilities +type RTPCapabilities struct { + Codecs []RTPCodecCapability + HeaderExtensions []RTPHeaderExtensionCapability +} diff --git a/vendor/github.com/pion/webrtc/v4/rtpcodec.go b/vendor/github.com/pion/webrtc/v4/rtpcodec.go new file mode 100644 index 0000000..70806f1 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/rtpcodec.go @@ -0,0 +1,227 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package webrtc + +import ( + "fmt" + "strconv" + "strings" + + "github.com/pion/webrtc/v4/internal/fmtp" +) + +// RTPCodecType determines the type of a codec. +type RTPCodecType int + +const ( + // RTPCodecTypeUnknown is the enum's zero-value. + RTPCodecTypeUnknown RTPCodecType = iota + + // RTPCodecTypeAudio indicates this is an audio codec. + RTPCodecTypeAudio + + // RTPCodecTypeVideo indicates this is a video codec. + RTPCodecTypeVideo +) + +func (t RTPCodecType) String() string { + switch t { + case RTPCodecTypeAudio: + return "audio" //nolint: goconst + case RTPCodecTypeVideo: + return "video" //nolint: goconst + default: + return ErrUnknownType.Error() + } +} + +// NewRTPCodecType creates a RTPCodecType from a string. +func NewRTPCodecType(r string) RTPCodecType { + switch { + case strings.EqualFold(r, RTPCodecTypeAudio.String()): + return RTPCodecTypeAudio + case strings.EqualFold(r, RTPCodecTypeVideo.String()): + return RTPCodecTypeVideo + default: + return RTPCodecType(0) + } +} + +// RTPCodecCapability provides information about codec capabilities. +// +// https://w3c.github.io/webrtc-pc/#dictionary-rtcrtpcodeccapability-members +type RTPCodecCapability struct { + MimeType string + ClockRate uint32 + Channels uint16 + SDPFmtpLine string + RTCPFeedback []RTCPFeedback +} + +// RTPHeaderExtensionCapability is used to define a RFC5285 RTP header extension supported by the codec. +// +// https://w3c.github.io/webrtc-pc/#dom-rtcrtpcapabilities-headerextensions +type RTPHeaderExtensionCapability struct { + URI string +} + +// RTPHeaderExtensionParameter represents a negotiated RFC5285 RTP header extension. +// +// https://w3c.github.io/webrtc-pc/#dictionary-rtcrtpheaderextensionparameters-members +type RTPHeaderExtensionParameter struct { + URI string + ID int +} + +// RTPCodecParameters is a sequence containing the media codecs that an RtpSender +// will choose from, as well as entries for RTX, RED and FEC mechanisms. This also +// includes the PayloadType that has been negotiated +// +// https://w3c.github.io/webrtc-pc/#rtcrtpcodecparameters +type RTPCodecParameters struct { + RTPCodecCapability + PayloadType PayloadType + + statsID string +} + +// RTPParameters is a list of negotiated codecs and header extensions +// +// https://w3c.github.io/webrtc-pc/#dictionary-rtcrtpparameters-members +type RTPParameters struct { + HeaderExtensions []RTPHeaderExtensionParameter + Codecs []RTPCodecParameters +} + +type codecMatchType int + +const ( + codecMatchNone codecMatchType = 0 + codecMatchPartial codecMatchType = 1 + codecMatchExact codecMatchType = 2 +) + +// Do a fuzzy find for a codec in the list of codecs +// Used for lookup up a codec in an existing list to find a match +// Returns codecMatchExact, codecMatchPartial, or codecMatchNone. +func codecParametersFuzzySearch( + needle RTPCodecParameters, + haystack []RTPCodecParameters, +) (RTPCodecParameters, codecMatchType) { + needleFmtp := fmtp.Parse( + needle.RTPCodecCapability.MimeType, + needle.RTPCodecCapability.ClockRate, + needle.RTPCodecCapability.Channels, + needle.RTPCodecCapability.SDPFmtpLine) + + // First attempt to match on MimeType + ClockRate + Channels + SDPFmtpLine + for _, c := range haystack { + cfmtp := fmtp.Parse( + c.RTPCodecCapability.MimeType, + c.RTPCodecCapability.ClockRate, + c.RTPCodecCapability.Channels, + c.RTPCodecCapability.SDPFmtpLine) + + if needleFmtp.Match(cfmtp) { + return c, codecMatchExact + } + } + + // Fallback to just MimeType + ClockRate + Channels + for _, c := range haystack { + if strings.EqualFold(c.RTPCodecCapability.MimeType, needle.RTPCodecCapability.MimeType) && + fmtp.ClockRateEqual(c.RTPCodecCapability.MimeType, + c.RTPCodecCapability.ClockRate, + needle.RTPCodecCapability.ClockRate) && + fmtp.ChannelsEqual(c.RTPCodecCapability.MimeType, + c.RTPCodecCapability.Channels, + needle.RTPCodecCapability.Channels) { + return c, codecMatchPartial + } + } + + return RTPCodecParameters{}, codecMatchNone +} + +// Given a CodecParameters find the RTX CodecParameters if one exists. +func findRTXPayloadType(needle PayloadType, haystack []RTPCodecParameters) PayloadType { + aptStr := fmt.Sprintf("apt=%d", needle) + for _, c := range haystack { + if aptStr == c.SDPFmtpLine { + return c.PayloadType + } + } + + return PayloadType(0) +} + +// Given needle CodecParameters, returns if needle is RTX and +// if primary codec corresponding to that needle is in the haystack of codecs. +func primaryPayloadTypeForRTXExists(needle RTPCodecParameters, haystack []RTPCodecParameters) ( + isRTX bool, primaryExists bool, +) { + if !strings.EqualFold(needle.MimeType, MimeTypeRTX) { + return + } + + isRTX = true + parsed := fmtp.Parse(needle.MimeType, needle.ClockRate, needle.Channels, needle.SDPFmtpLine) + aptPayload, ok := parsed.Parameter("apt") + if !ok { + return + } + + primaryPayloadType, err := strconv.Atoi(aptPayload) + if err != nil || primaryPayloadType < 0 || primaryPayloadType > 255 { + return + } + + for _, c := range haystack { + if c.PayloadType == PayloadType(primaryPayloadType) { + primaryExists = true + + return + } + } + + return +} + +// Filter out RTX codecs that do not have a primary codec. +func filterUnattachedRTX(codecs []RTPCodecParameters) []RTPCodecParameters { + for i := len(codecs) - 1; i >= 0; i-- { + c := codecs[i] + if isRTX, primaryExists := primaryPayloadTypeForRTXExists(c, codecs); isRTX && !primaryExists { + // no primary for RTX, remove the RTX + codecs = append(codecs[:i], codecs[i+1:]...) + } + } + + return codecs +} + +// For now, only FlexFEC is supported. +func findFECPayloadType(haystack []RTPCodecParameters) PayloadType { + for _, c := range haystack { + if strings.Contains(c.RTPCodecCapability.MimeType, MimeTypeFlexFEC) { + return c.PayloadType + } + } + + return PayloadType(0) +} + +func rtcpFeedbackIntersection(a, b []RTCPFeedback) (out []RTCPFeedback) { + for _, aFeedback := range a { + for _, bFeeback := range b { + if aFeedback.Type == bFeeback.Type && aFeedback.Parameter == bFeeback.Parameter { + out = append(out, aFeedback) + + break + } + } + } + + return +} diff --git a/vendor/github.com/pion/webrtc/v4/rtpcodingparameters.go b/vendor/github.com/pion/webrtc/v4/rtpcodingparameters.go new file mode 100644 index 0000000..0a369ae --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/rtpcodingparameters.go @@ -0,0 +1,27 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package webrtc + +// RTPRtxParameters dictionary contains information relating to retransmission (RTX) settings. +// https://draft.ortc.org/#dom-rtcrtprtxparameters +type RTPRtxParameters struct { + SSRC SSRC `json:"ssrc"` +} + +// RTPFecParameters dictionary contains information relating to forward error correction (FEC) settings. +// https://draft.ortc.org/#dom-rtcrtpfecparameters +type RTPFecParameters struct { + SSRC SSRC `json:"ssrc"` +} + +// RTPCodingParameters provides information relating to both encoding and decoding. +// This is a subset of the RFC since Pion WebRTC doesn't implement encoding/decoding itself +// http://draft.ortc.org/#dom-rtcrtpcodingparameters +type RTPCodingParameters struct { + RID string `json:"rid"` + SSRC SSRC `json:"ssrc"` + PayloadType PayloadType `json:"payloadType"` + RTX RTPRtxParameters `json:"rtx"` + FEC RTPFecParameters `json:"fec"` +} diff --git a/vendor/github.com/pion/webrtc/v4/rtpdecodingparameters.go b/vendor/github.com/pion/webrtc/v4/rtpdecodingparameters.go new file mode 100644 index 0000000..70a179b --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/rtpdecodingparameters.go @@ -0,0 +1,11 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package webrtc + +// RTPDecodingParameters provides information relating to both encoding and decoding. +// This is a subset of the RFC since Pion WebRTC doesn't implement decoding itself +// http://draft.ortc.org/#dom-rtcrtpdecodingparameters +type RTPDecodingParameters struct { + RTPCodingParameters +} diff --git a/vendor/github.com/pion/webrtc/v4/rtpencodingparameters.go b/vendor/github.com/pion/webrtc/v4/rtpencodingparameters.go new file mode 100644 index 0000000..4c7f618 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/rtpencodingparameters.go @@ -0,0 +1,11 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package webrtc + +// RTPEncodingParameters provides information relating to both encoding and decoding. +// This is a subset of the RFC since Pion WebRTC doesn't implement encoding itself +// http://draft.ortc.org/#dom-rtcrtpencodingparameters +type RTPEncodingParameters struct { + RTPCodingParameters +} diff --git a/vendor/github.com/pion/webrtc/v4/rtpreceiveparameters.go b/vendor/github.com/pion/webrtc/v4/rtpreceiveparameters.go new file mode 100644 index 0000000..aa685b3 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/rtpreceiveparameters.go @@ -0,0 +1,9 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package webrtc + +// RTPReceiveParameters contains the RTP stack settings used by receivers. +type RTPReceiveParameters struct { + Encodings []RTPDecodingParameters +} diff --git a/vendor/github.com/pion/webrtc/v4/rtpreceiver.go b/vendor/github.com/pion/webrtc/v4/rtpreceiver.go new file mode 100644 index 0000000..284e65e --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/rtpreceiver.go @@ -0,0 +1,776 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +//go:build !js + +package webrtc + +import ( + "encoding/binary" + "fmt" + "io" + "math" + "sync" + "sync/atomic" + "time" + + "github.com/pion/interceptor" + "github.com/pion/interceptor/pkg/stats" + "github.com/pion/logging" + "github.com/pion/rtcp" + "github.com/pion/srtp/v3" + "github.com/pion/webrtc/v4/internal/util" +) + +// trackStreams maintains a mapping of RTP/RTCP streams to a specific track +// a RTPReceiver may contain multiple streams if we are dealing with Simulcast. +type trackStreams struct { + track *TrackRemote + + streamInfo, repairStreamInfo *interceptor.StreamInfo + + rtpReadStream *srtp.ReadStreamSRTP + rtpInterceptor interceptor.RTPReader + + rtcpReadStream *srtp.ReadStreamSRTCP + rtcpInterceptor interceptor.RTCPReader + + repairReadStream *srtp.ReadStreamSRTP + repairInterceptor interceptor.RTPReader + repairStreamChannel chan rtxPacketWithAttributes + + repairRtcpReadStream *srtp.ReadStreamSRTCP + repairRtcpInterceptor interceptor.RTCPReader +} + +type rtxPacketWithAttributes struct { + pkt []byte + attributes interceptor.Attributes + pool *sync.Pool +} + +func (p *rtxPacketWithAttributes) release() { + if p.pkt != nil { + b := p.pkt[:cap(p.pkt)] + p.pool.Put(b) // nolint:staticcheck + p.pkt = nil + } +} + +// RTPReceiver allows an application to inspect the receipt of a TrackRemote. +type RTPReceiver struct { + kind RTPCodecType + transport *DTLSTransport + + tracks []trackStreams + + closed atomic.Bool + closedChan, received chan any + mu sync.RWMutex + + tr *RTPTransceiver + + // A reference to the associated api object + api *API + + rtxPool sync.Pool + + log logging.LeveledLogger +} + +// NewRTPReceiver constructs a new RTPReceiver. +func (api *API) NewRTPReceiver(kind RTPCodecType, transport *DTLSTransport) (*RTPReceiver, error) { + if transport == nil { + return nil, errRTPReceiverDTLSTransportNil + } + + rtpReceiver := &RTPReceiver{ + kind: kind, + transport: transport, + api: api, + closedChan: make(chan any), + received: make(chan any), + tracks: []trackStreams{}, + rtxPool: sync.Pool{New: func() any { + return make([]byte, api.settingEngine.getReceiveMTU()) + }}, + log: api.settingEngine.LoggerFactory.NewLogger("RTPReceiver"), + } + + return rtpReceiver, nil +} + +func (r *RTPReceiver) setRTPTransceiver(tr *RTPTransceiver) { + r.mu.Lock() + defer r.mu.Unlock() + r.tr = tr +} + +// Transport returns the currently-configured *DTLSTransport or nil +// if one has not yet been configured. +func (r *RTPReceiver) Transport() *DTLSTransport { + r.mu.RLock() + defer r.mu.RUnlock() + + return r.transport +} + +func (r *RTPReceiver) getParameters() RTPParameters { + parameters := r.api.mediaEngine.getRTPParametersByKind( + r.kind, + []RTPTransceiverDirection{RTPTransceiverDirectionRecvonly}, + ) + if r.tr != nil { + parameters.Codecs = r.tr.getCodecs() + } + + return parameters +} + +// GetParameters describes the current configuration for the encoding and +// transmission of media on the receiver's track. +func (r *RTPReceiver) GetParameters() RTPParameters { + r.mu.RLock() + defer r.mu.RUnlock() + + return r.getParameters() +} + +// Track returns the RtpTransceiver TrackRemote. +func (r *RTPReceiver) Track() *TrackRemote { + r.mu.RLock() + defer r.mu.RUnlock() + + if len(r.tracks) != 1 { + return nil + } + + return r.tracks[0].track +} + +// Tracks returns the RtpTransceiver tracks +// A RTPReceiver to support Simulcast may now have multiple tracks. +func (r *RTPReceiver) Tracks() []*TrackRemote { + r.mu.RLock() + defer r.mu.RUnlock() + + var tracks []*TrackRemote + for i := range r.tracks { + tracks = append(tracks, r.tracks[i].track) + } + + return tracks +} + +// RTPTransceiver returns the RTPTransceiver this +// RTPReceiver belongs too, or nil if none. +func (r *RTPReceiver) RTPTransceiver() *RTPTransceiver { + r.mu.Lock() + defer r.mu.Unlock() + + return r.tr +} + +// configureReceive initialize the track. +func (r *RTPReceiver) configureReceive(parameters RTPReceiveParameters) { + r.mu.Lock() + defer r.mu.Unlock() + + for i := range parameters.Encodings { + t := trackStreams{ + track: newTrackRemote( + r.kind, + parameters.Encodings[i].SSRC, + parameters.Encodings[i].RTX.SSRC, + parameters.Encodings[i].RID, + r, + ), + } + + r.tracks = append(r.tracks, t) + } +} + +// startReceive starts all the transports. +func (r *RTPReceiver) startReceive(parameters RTPReceiveParameters) error { //nolint:cyclop + r.mu.Lock() + defer r.mu.Unlock() + select { + case <-r.received: + return errRTPReceiverReceiveAlreadyCalled + default: + } + + globalParams := r.getParameters() + codec := RTPCodecCapability{} + if len(globalParams.Codecs) != 0 { + codec = globalParams.Codecs[0].RTPCodecCapability + } + + for i := range parameters.Encodings { + if parameters.Encodings[i].RID != "" { + // RID based tracks will be set up in receiveForRid + continue + } + + var streams *trackStreams + for idx, ts := range r.tracks { + if ts.track != nil && ts.track.SSRC() == parameters.Encodings[i].SSRC { + streams = &r.tracks[idx] + + break + } + } + if streams == nil { + return fmt.Errorf("%w: %d", errRTPReceiverWithSSRCTrackStreamNotFound, parameters.Encodings[i].SSRC) + } + + streams.streamInfo = createStreamInfo( + "", + parameters.Encodings[i].SSRC, + 0, 0, 0, 0, 0, + codec, + globalParams.HeaderExtensions, + ) + + result, err := r.transport.streamsForSSRC(parameters.Encodings[i].SSRC, *streams.streamInfo) + if err != nil { + return err + } + streams.rtpReadStream = result.rtpReadStream + streams.rtpInterceptor = result.rtpInterceptor + streams.rtcpReadStream = result.rtcpReadStream + streams.rtcpInterceptor = result.rtcpInterceptor + + if rtxSsrc := parameters.Encodings[i].RTX.SSRC; rtxSsrc != 0 { + // See RFC 4588 section 6.3, + // NACKs MUST be sent only for the original RTP stream. + rtxCodec := codec + rtxCodec.RTCPFeedback = nil + rtxCodec.MimeType = MimeTypeRTX + streamInfo := createStreamInfo("", rtxSsrc, 0, 0, 0, 0, 0, rtxCodec, globalParams.HeaderExtensions) + result, err = r.transport.streamsForSSRC( + rtxSsrc, + *streamInfo, + ) + if err != nil { + return err + } + rtpReadStream := result.rtpReadStream + rtpInterceptor := result.rtpInterceptor + rtcpReadStream := result.rtcpReadStream + rtcpInterceptor := result.rtcpInterceptor + + if err = r.receiveForRtxInternal( + rtxSsrc, + "", + streamInfo, + rtpReadStream, + rtpInterceptor, + rtcpReadStream, + rtcpInterceptor, + ); err != nil { + return err + } + } + } + + close(r.received) + + return nil +} + +// Receive initialize the track and starts all the transports. +func (r *RTPReceiver) Receive(parameters RTPReceiveParameters) error { + r.configureReceive(parameters) + + return r.startReceive(parameters) +} + +// Read reads incoming RTCP for this RTPReceiver. +func (r *RTPReceiver) Read(b []byte) (n int, a interceptor.Attributes, err error) { + select { + case <-r.received: + if len(r.tracks) > 1 { + r.log.Errorf(useReadSimulcast) + } + + return r.tracks[0].rtcpInterceptor.Read(b, a) + case <-r.closedChan: + return 0, nil, io.ErrClosedPipe + } +} + +// ReadSimulcast reads incoming RTCP for this RTPReceiver for given rid. +func (r *RTPReceiver) ReadSimulcast(b []byte, rid string) (n int, a interceptor.Attributes, err error) { + select { + case <-r.received: + var rtcpInterceptor interceptor.RTCPReader + + r.mu.Lock() + for _, t := range r.tracks { + if t.track != nil && t.track.rid == rid { + rtcpInterceptor = t.rtcpInterceptor + } + } + r.mu.Unlock() + + if rtcpInterceptor == nil { + return 0, nil, fmt.Errorf("%w: %s", errRTPReceiverForRIDTrackStreamNotFound, rid) + } + + return rtcpInterceptor.Read(b, a) + + case <-r.closedChan: + return 0, nil, io.ErrClosedPipe + } +} + +// ReadRTCP is a convenience method that wraps Read and unmarshal for you. +// It also runs any configured interceptors. +func (r *RTPReceiver) ReadRTCP() ([]rtcp.Packet, interceptor.Attributes, error) { + b := make([]byte, r.api.settingEngine.getReceiveMTU()) + i, attributes, err := r.Read(b) + if err != nil { + return nil, nil, err + } + + pkts, err := rtcp.Unmarshal(b[:i]) + if err != nil { + return nil, nil, err + } + + return pkts, attributes, nil +} + +// ReadSimulcastRTCP is a convenience method that wraps ReadSimulcast and unmarshal for you. +func (r *RTPReceiver) ReadSimulcastRTCP(rid string) ([]rtcp.Packet, interceptor.Attributes, error) { + b := make([]byte, r.api.settingEngine.getReceiveMTU()) + i, attributes, err := r.ReadSimulcast(b, rid) + if err != nil { + return nil, nil, err + } + + pkts, err := rtcp.Unmarshal(b[:i]) + + return pkts, attributes, err +} + +func (r *RTPReceiver) haveReceived() bool { + select { + case <-r.received: + return true + default: + return false + } +} + +func (r *RTPReceiver) haveClosed() bool { + return r.closed.Load() +} + +// Stop irreversibly stops the RTPReceiver. +func (r *RTPReceiver) Stop() error { //nolint:cyclop + r.mu.Lock() + defer r.mu.Unlock() + var err error + + select { + case <-r.closedChan: + return err + default: + } + + select { + case <-r.received: + for i := range r.tracks { + errs := []error{} + + if r.tracks[i].rtcpReadStream != nil { + errs = append(errs, r.tracks[i].rtcpReadStream.Close()) + } + + if r.tracks[i].rtpReadStream != nil { + errs = append(errs, r.tracks[i].rtpReadStream.Close()) + } + + if r.tracks[i].repairReadStream != nil { + errs = append(errs, r.tracks[i].repairReadStream.Close()) + } + + if r.tracks[i].repairRtcpReadStream != nil { + errs = append(errs, r.tracks[i].repairRtcpReadStream.Close()) + } + + if r.tracks[i].streamInfo != nil { + r.api.interceptor.UnbindRemoteStream(r.tracks[i].streamInfo) + } + + if r.tracks[i].repairStreamInfo != nil { + r.api.interceptor.UnbindRemoteStream(r.tracks[i].repairStreamInfo) + } + + err = util.FlattenErrs(errs) + } + default: + } + + close(r.closedChan) + r.closed.Store(true) + + return err +} + +func (r *RTPReceiver) collectStats(collector *statsReportCollector, statsGetter stats.Getter) { + if statsGetter == nil { + return + } + + r.mu.Lock() + defer r.mu.Unlock() + + // Emit inbound-rtp stats for each track + mid := "" + if r.tr != nil { + mid = r.tr.Mid() + } + now := statsTimestampNow() + nowTime := now.Time() + for trackIndex := range r.tracks { + remoteTrack := r.tracks[trackIndex].track + if remoteTrack == nil { + continue + } + + collector.Collecting() + + inboundID := fmt.Sprintf("inbound-rtp-%d", uint32(remoteTrack.SSRC())) + codecID := "" + if remoteTrack.codec.statsID != "" { + codecID = remoteTrack.codec.statsID + } + + inboundStats := InboundRTPStreamStats{ + Rid: remoteTrack.RID(), + Mid: mid, + Timestamp: now, + Type: StatsTypeInboundRTP, + ID: inboundID, + SSRC: remoteTrack.SSRC(), + Kind: r.kind.String(), + TransportID: "iceTransport", + CodecID: codecID, + } + r.populateInboundStats(&inboundStats, statsGetter, remoteTrack) + + collector.Collect(inboundID, inboundStats) + + if remoteTrack.Kind() == RTPCodecTypeAudio { + r.collectAudioPlayoutStats(collector, nowTime, remoteTrack) + } + } +} + +func (r *RTPReceiver) populateInboundStats( + inboundStats *InboundRTPStreamStats, + statsGetter stats.Getter, + remoteTrack *TrackRemote, +) { + stats := statsGetter.Get(uint32(remoteTrack.SSRC())) + if stats == nil { + return + } + + // Wrap-around casting by design, with warnings if overflow/underflow is detected. + pr := stats.InboundRTPStreamStats.PacketsReceived + if pr > math.MaxUint32 { + r.log.Warnf("Inbound PacketsReceived exceeds uint32 and will wrap: %d", pr) + } + inboundStats.PacketsReceived = uint32(pr) //nolint:gosec + + pl := stats.InboundRTPStreamStats.PacketsLost + if pl > math.MaxInt32 || pl < math.MinInt32 { + r.log.Warnf("Inbound PacketsLost exceeds int32 range and will wrap: %d", pl) + } + inboundStats.PacketsLost = int32(pl) //nolint:gosec + + inboundStats.Jitter = stats.InboundRTPStreamStats.Jitter + inboundStats.BytesReceived = stats.InboundRTPStreamStats.BytesReceived + inboundStats.HeaderBytesReceived = stats.InboundRTPStreamStats.HeaderBytesReceived + timestamp := stats.InboundRTPStreamStats.LastPacketReceivedTimestamp + inboundStats.LastPacketReceivedTimestamp = StatsTimestamp( + timestamp.UnixNano() / int64(time.Millisecond)) + inboundStats.FIRCount = stats.InboundRTPStreamStats.FIRCount + inboundStats.PLICount = stats.InboundRTPStreamStats.PLICount + inboundStats.NACKCount = stats.InboundRTPStreamStats.NACKCount +} + +func (r *RTPReceiver) collectAudioPlayoutStats( + collector *statsReportCollector, + nowTime time.Time, + remoteTrack *TrackRemote, +) { + playoutStats := remoteTrack.pullAudioPlayoutStats(nowTime) + for _, stats := range playoutStats { + collector.Collecting() + collector.Collect(stats.ID, stats) + } +} + +func (r *RTPReceiver) streamsForTrack(t *TrackRemote) *trackStreams { + for i := range r.tracks { + if r.tracks[i].track == t { + return &r.tracks[i] + } + } + + return nil +} + +// readRTP should only be called by a track, this only exists so we can keep state in one place. +func (r *RTPReceiver) readRTP(b []byte, reader *TrackRemote) (n int, a interceptor.Attributes, err error) { + select { + case <-r.received: + case <-r.closedChan: + return 0, nil, io.EOF + } + + if t := r.streamsForTrack(reader); t != nil { + return t.rtpInterceptor.Read(b, a) + } + + return 0, nil, fmt.Errorf("%w: %d", errRTPReceiverWithSSRCTrackStreamNotFound, reader.SSRC()) +} + +// receiveForRid is the sibling of Receive expect for RIDs instead of SSRCs +// It populates all the internal state for the given RID. +func (r *RTPReceiver) receiveForRid( + rid string, + params RTPParameters, + streamInfo *interceptor.StreamInfo, + rtpReadStream *srtp.ReadStreamSRTP, + rtpInterceptor interceptor.RTPReader, + rtcpReadStream *srtp.ReadStreamSRTCP, + rtcpInterceptor interceptor.RTCPReader, + peekedPackets []*peekedPacket, +) (*TrackRemote, error) { + r.mu.Lock() + defer r.mu.Unlock() + + if r.haveClosed() { + return nil, io.EOF + } + + for i := range r.tracks { + if r.tracks[i].track.RID() == rid { + r.tracks[i].track.mu.Lock() + r.tracks[i].track.kind = r.kind + r.tracks[i].track.codec = params.Codecs[0] + r.tracks[i].track.params = params + r.tracks[i].track.ssrc = SSRC(streamInfo.SSRC) + r.tracks[i].track.peekedPackets = peekedPackets + r.tracks[i].track.mu.Unlock() + + r.tracks[i].streamInfo = streamInfo + r.tracks[i].rtpReadStream = rtpReadStream + r.tracks[i].rtpInterceptor = rtpInterceptor + r.tracks[i].rtcpReadStream = rtcpReadStream + r.tracks[i].rtcpInterceptor = rtcpInterceptor + + return r.tracks[i].track, nil + } + } + + return nil, fmt.Errorf("%w: %s", errRTPReceiverForRIDTrackStreamNotFound, rid) +} + +// receiveForRtx starts a routine that processes the repair stream. +func (r *RTPReceiver) receiveForRtx( + ssrc SSRC, + rsid string, + streamInfo *interceptor.StreamInfo, + rtpReadStream *srtp.ReadStreamSRTP, + rtpInterceptor interceptor.RTPReader, + rtcpReadStream *srtp.ReadStreamSRTCP, + rtcpInterceptor interceptor.RTCPReader, +) error { + r.mu.Lock() + defer r.mu.Unlock() + + return r.receiveForRtxInternal( + ssrc, + rsid, + streamInfo, + rtpReadStream, + rtpInterceptor, + rtcpReadStream, + rtcpInterceptor, + ) +} + +//nolint:gocognit,cyclop +func (r *RTPReceiver) receiveForRtxInternal( + ssrc SSRC, + rsid string, + streamInfo *interceptor.StreamInfo, + rtpReadStream *srtp.ReadStreamSRTP, + rtpInterceptor interceptor.RTPReader, + rtcpReadStream *srtp.ReadStreamSRTCP, + rtcpInterceptor interceptor.RTCPReader, +) error { + if r.haveClosed() { + return io.EOF + } + + var track *trackStreams + if ssrc != 0 && len(r.tracks) == 1 { + track = &r.tracks[0] + } else { + for i := range r.tracks { + if r.tracks[i].track.RID() == rsid { + track = &r.tracks[i] + if track.track.RtxSSRC() == 0 { + track.track.setRtxSSRC(SSRC(streamInfo.SSRC)) + } + + break + } + } + } + + if track == nil { + return fmt.Errorf("%w: ssrc(%d) rsid(%s)", errRTPReceiverForRIDTrackStreamNotFound, ssrc, rsid) + } + + track.repairStreamInfo = streamInfo + track.repairReadStream = rtpReadStream + track.repairInterceptor = rtpInterceptor + track.repairRtcpReadStream = rtcpReadStream + track.repairRtcpInterceptor = rtcpInterceptor + track.repairStreamChannel = make(chan rtxPacketWithAttributes, 50) + + repairInterceptor := track.repairInterceptor + repairStreamChannel := track.repairStreamChannel + go func() { + for { + b := r.rtxPool.Get().([]byte) // nolint:forcetypeassert + i, attributes, err := repairInterceptor.Read(b, nil) + if err != nil { + r.rtxPool.Put(b) // nolint:staticcheck + + return + } + + // RTX packets have a different payload format. Move the OSN in the payload to the RTP header and rewrite the + // payload type and SSRC, so that we can return RTX packets to the caller 'transparently' i.e. in the same format + // as non-RTX RTP packets + hasExtension := b[0]&0b10000 > 0 + hasPadding := b[0]&0b100000 > 0 + csrcCount := b[0] & 0b1111 + headerLength := uint16(12 + (4 * csrcCount)) + paddingLength := 0 + if hasExtension { + headerLength += 4 * (1 + binary.BigEndian.Uint16(b[headerLength+2:headerLength+4])) + } + if hasPadding { + paddingLength = int(b[i-1]) + } + + if i-int(headerLength)-paddingLength < 2 { + // BWE probe packet, ignore + r.rtxPool.Put(b) // nolint:staticcheck + + continue + } + + if attributes == nil { + attributes = make(interceptor.Attributes) + } + attributes.Set(AttributeRtxPayloadType, b[1]&0x7F) + attributes.Set(AttributeRtxSequenceNumber, binary.BigEndian.Uint16(b[2:4])) + attributes.Set(AttributeRtxSsrc, binary.BigEndian.Uint32(b[8:12])) + + b[1] = (b[1] & 0x80) | uint8(track.track.PayloadType()) + b[2] = b[headerLength] + b[3] = b[headerLength+1] + binary.BigEndian.PutUint32(b[8:12], uint32(track.track.SSRC())) + copy(b[headerLength:i-2], b[headerLength+2:i]) + + select { + case <-r.closedChan: + r.rtxPool.Put(b) // nolint:staticcheck + + return + case repairStreamChannel <- rtxPacketWithAttributes{pkt: b[:i-2], attributes: attributes, pool: &r.rtxPool}: + default: + // skip the RTX packet if the repair stream channel is full, could be blocked in the application's read loop + } + } + }() + + return nil +} + +// SetReadDeadline sets the max amount of time the RTCP stream will block before returning. 0 is forever. +func (r *RTPReceiver) SetReadDeadline(t time.Time) error { + r.mu.RLock() + defer r.mu.RUnlock() + + return r.tracks[0].rtcpReadStream.SetReadDeadline(t) +} + +// SetReadDeadlineSimulcast sets the max amount of time the RTCP stream for a given rid will block before returning. +// 0 is forever. +func (r *RTPReceiver) SetReadDeadlineSimulcast(deadline time.Time, rid string) error { + r.mu.RLock() + defer r.mu.RUnlock() + + for _, t := range r.tracks { + if t.track != nil && t.track.rid == rid { + return t.rtcpReadStream.SetReadDeadline(deadline) + } + } + + return fmt.Errorf("%w: %s", errRTPReceiverForRIDTrackStreamNotFound, rid) +} + +// setRTPReadDeadline sets the max amount of time the RTP stream will block before returning. 0 is forever. +// This should be fired by calling SetReadDeadline on the TrackRemote. +func (r *RTPReceiver) setRTPReadDeadline(deadline time.Time, reader *TrackRemote) error { + r.mu.RLock() + defer r.mu.RUnlock() + + if t := r.streamsForTrack(reader); t != nil { + return t.rtpReadStream.SetReadDeadline(deadline) + } + + return fmt.Errorf("%w: %d", errRTPReceiverWithSSRCTrackStreamNotFound, reader.SSRC()) +} + +// readRTX returns an RTX packet if one is available on the RTX track, otherwise returns nil. +func (r *RTPReceiver) readRTX(reader *TrackRemote) *rtxPacketWithAttributes { + if !reader.HasRTX() || r.haveClosed() { + return nil + } + + select { + case <-r.received: + default: + return nil + } + + r.mu.RLock() + var ch chan rtxPacketWithAttributes + if t := r.streamsForTrack(reader); t != nil { + ch = t.repairStreamChannel + } + r.mu.RUnlock() + + select { + case rtxPacketReceived := <-ch: + return &rtxPacketReceived + default: + } + + return nil +} diff --git a/vendor/github.com/pion/webrtc/v4/rtpreceiver_go.go b/vendor/github.com/pion/webrtc/v4/rtpreceiver_go.go new file mode 100644 index 0000000..6eee580 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/rtpreceiver_go.go @@ -0,0 +1,35 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +//go:build !js + +package webrtc + +import "github.com/pion/interceptor" + +// SetRTPParameters applies provided RTPParameters the RTPReceiver's tracks. +// +// This method is part of the ORTC API. It is not +// meant to be used together with the basic WebRTC API. +// +// The amount of provided codecs must match the number of tracks on the receiver. +func (r *RTPReceiver) SetRTPParameters(params RTPParameters) { + headerExtensions := make([]interceptor.RTPHeaderExtension, 0, len(params.HeaderExtensions)) + for _, h := range params.HeaderExtensions { + headerExtensions = append(headerExtensions, interceptor.RTPHeaderExtension{ID: h.ID, URI: h.URI}) + } + + r.mu.Lock() + defer r.mu.Unlock() + + for ndx, codec := range params.Codecs { + currentTrack := r.tracks[ndx].track + + r.tracks[ndx].streamInfo.RTPHeaderExtensions = headerExtensions + + currentTrack.mu.Lock() + currentTrack.codec = codec + currentTrack.params = params + currentTrack.mu.Unlock() + } +} diff --git a/vendor/github.com/pion/webrtc/v4/rtpreceiver_js.go b/vendor/github.com/pion/webrtc/v4/rtpreceiver_js.go new file mode 100644 index 0000000..e54ce2b --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/rtpreceiver_js.go @@ -0,0 +1,20 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +//go:build js && wasm +// +build js,wasm + +package webrtc + +import "syscall/js" + +// RTPReceiver allows an application to inspect the receipt of a TrackRemote +type RTPReceiver struct { + // Pointer to the underlying JavaScript RTCRTPReceiver object. + underlying js.Value +} + +// JSValue returns the underlying RTCRtpReceiver +func (r *RTPReceiver) JSValue() js.Value { + return r.underlying +} \ No newline at end of file diff --git a/vendor/github.com/pion/webrtc/v4/rtpsender.go b/vendor/github.com/pion/webrtc/v4/rtpsender.go new file mode 100644 index 0000000..0c88404 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/rtpsender.go @@ -0,0 +1,529 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +//go:build !js + +package webrtc + +import ( + "fmt" + "io" + "sync" + "time" + + "github.com/pion/interceptor" + "github.com/pion/randutil" + "github.com/pion/rtcp" + "github.com/pion/rtp" + "github.com/pion/webrtc/v4/internal/util" +) + +type trackEncoding struct { + track TrackLocal + + srtpStream *srtpWriterFuture + + rtcpInterceptor interceptor.RTCPReader + streamInfo interceptor.StreamInfo + + context *baseTrackLocalContext + + ssrc, ssrcRTX, ssrcFEC SSRC +} + +// RTPSender allows an application to control how a given Track is encoded and transmitted to a remote peer. +type RTPSender struct { + trackEncodings []*trackEncoding + + transport *DTLSTransport + + payloadType PayloadType + kind RTPCodecType + + // nolint:godox + // TODO(sgotti) remove this when in future we'll avoid replacing + // a transceiver sender since we can just check the + // transceiver negotiation status + negotiated bool + + // A reference to the associated api object + api *API + id string + + rtpTransceiver *RTPTransceiver + + mu sync.RWMutex + sendCalled, stopCalled chan struct{} +} + +// NewRTPSender constructs a new RTPSender. +func (api *API) NewRTPSender(track TrackLocal, transport *DTLSTransport) (*RTPSender, error) { + if track == nil { + return nil, errRTPSenderTrackNil + } else if transport == nil { + return nil, errRTPSenderDTLSTransportNil + } + + id, err := randutil.GenerateCryptoRandomString(32, "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") + if err != nil { + return nil, err + } + + r := &RTPSender{ + transport: transport, + api: api, + sendCalled: make(chan struct{}), + stopCalled: make(chan struct{}), + id: id, + kind: track.Kind(), + } + + r.addEncoding(track) + + return r, nil +} + +func (r *RTPSender) isNegotiated() bool { + r.mu.RLock() + defer r.mu.RUnlock() + + return r.negotiated +} + +func (r *RTPSender) setNegotiated() { + r.mu.Lock() + defer r.mu.Unlock() + r.negotiated = true +} + +func (r *RTPSender) setRTPTransceiver(rtpTransceiver *RTPTransceiver) { + r.mu.Lock() + defer r.mu.Unlock() + r.rtpTransceiver = rtpTransceiver +} + +// Transport returns the currently-configured *DTLSTransport or nil +// if one has not yet been configured. +func (r *RTPSender) Transport() *DTLSTransport { + r.mu.RLock() + defer r.mu.RUnlock() + + return r.transport +} + +// GetParameters describes the current configuration for the encoding and +// transmission of media on the sender's track. +func (r *RTPSender) GetParameters() RTPSendParameters { + r.mu.RLock() + defer r.mu.RUnlock() + + var encodings []RTPEncodingParameters + for _, trackEncoding := range r.trackEncodings { + var rid string + if trackEncoding.track != nil { + rid = trackEncoding.track.RID() + } + encodings = append(encodings, RTPEncodingParameters{ + RTPCodingParameters: RTPCodingParameters{ + RID: rid, + SSRC: trackEncoding.ssrc, + RTX: RTPRtxParameters{SSRC: trackEncoding.ssrcRTX}, + FEC: RTPFecParameters{SSRC: trackEncoding.ssrcFEC}, + PayloadType: r.payloadType, + }, + }) + } + sendParameters := RTPSendParameters{ + RTPParameters: r.api.mediaEngine.getRTPParametersByKind( + r.kind, + []RTPTransceiverDirection{RTPTransceiverDirectionSendonly}, + ), + Encodings: encodings, + } + if r.rtpTransceiver != nil { + sendParameters.Codecs = r.rtpTransceiver.getCodecs() + } else { + sendParameters.Codecs = r.api.mediaEngine.getCodecsByKind(r.kind) + } + + return sendParameters +} + +// AddEncoding adds an encoding to RTPSender. Used by simulcast senders. +func (r *RTPSender) AddEncoding(track TrackLocal) error { //nolint:cyclop + r.mu.Lock() + defer r.mu.Unlock() + + if track == nil { + return errRTPSenderTrackNil + } + + if track.RID() == "" { + return errRTPSenderRidNil + } + + if r.hasStopped() { + return errRTPSenderStopped + } + + if r.hasSent() { + return errRTPSenderSendAlreadyCalled + } + + var refTrack TrackLocal + if len(r.trackEncodings) != 0 { + refTrack = r.trackEncodings[0].track + } + if refTrack == nil || refTrack.RID() == "" { + return errRTPSenderNoBaseEncoding + } + + if refTrack.ID() != track.ID() || refTrack.StreamID() != track.StreamID() || refTrack.Kind() != track.Kind() { + return errRTPSenderBaseEncodingMismatch + } + + for _, encoding := range r.trackEncodings { + if encoding.track == nil { + continue + } + + if encoding.track.RID() == track.RID() { + return errRTPSenderRIDCollision + } + } + + r.addEncoding(track) + + return nil +} + +func (r *RTPSender) addEncoding(track TrackLocal) { + trackEncoding := &trackEncoding{ + track: track, + ssrc: SSRC(util.RandUint32()), + } + + if r.api.mediaEngine.isRTXEnabled(r.kind, []RTPTransceiverDirection{RTPTransceiverDirectionSendonly}) { + trackEncoding.ssrcRTX = SSRC(util.RandUint32()) + } + + if r.api.mediaEngine.isFECEnabled(r.kind, []RTPTransceiverDirection{RTPTransceiverDirectionSendonly}) { + trackEncoding.ssrcFEC = SSRC(util.RandUint32()) + } + + r.trackEncodings = append(r.trackEncodings, trackEncoding) +} + +// Track returns the RTCRtpTransceiver track, or nil. +func (r *RTPSender) Track() TrackLocal { + r.mu.RLock() + defer r.mu.RUnlock() + + if len(r.trackEncodings) == 0 { + return nil + } + + return r.trackEncodings[0].track +} + +// ReplaceTrack replaces the track currently being used as the sender's source with a new TrackLocal. +// The new track must be of the same media kind (audio, video, etc) and switching the track should not +// require negotiation. +func (r *RTPSender) ReplaceTrack(track TrackLocal) error { //nolint:cyclop + r.mu.Lock() + defer r.mu.Unlock() + + if track != nil && r.kind != track.Kind() { + return ErrRTPSenderNewTrackHasIncorrectKind + } + + // cannot replace simulcast envelope + if track != nil && len(r.trackEncodings) > 1 { + return ErrRTPSenderNewTrackHasIncorrectEnvelope + } + + var replacedTrack TrackLocal + var context *baseTrackLocalContext + for _, e := range r.trackEncodings { + replacedTrack = e.track + context = e.context + + if r.hasSent() && replacedTrack != nil { + if err := replacedTrack.Unbind(context); err != nil { + return err + } + } + + if !r.hasSent() || track == nil { + e.track = track + } + } + + if !r.hasSent() || track == nil { + return nil + } + + params := r.api.mediaEngine.getRTPParametersByKind( + track.Kind(), + []RTPTransceiverDirection{RTPTransceiverDirectionSendonly}, + ) + + // If we reach this point in the routine, there is only 1 track encoding + codec, err := track.Bind(&baseTrackLocalContext{ + id: context.ID(), + params: params, + ssrc: context.SSRC(), + ssrcRTX: context.SSRCRetransmission(), + ssrcFEC: context.SSRCForwardErrorCorrection(), + writeStream: context.WriteStream(), + rtcpInterceptor: context.RTCPReader(), + }) + if err != nil { + // Re-bind the original track + if _, reBindErr := replacedTrack.Bind(context); reBindErr != nil { + return reBindErr + } + + return err + } + + // Codec has changed + if r.payloadType != codec.PayloadType { + context.params.Codecs = []RTPCodecParameters{codec} + } + + r.trackEncodings[0].track = track + + return nil +} + +// Send Attempts to set the parameters controlling the sending of media. +func (r *RTPSender) Send(parameters RTPSendParameters) error { + r.mu.Lock() + defer r.mu.Unlock() + + switch { + case r.hasSent(): + return errRTPSenderSendAlreadyCalled + case r.trackEncodings[0].track == nil: + return errRTPSenderTrackRemoved + } + + for idx := range r.trackEncodings { + trackEncoding := r.trackEncodings[idx] + srtpStream := &srtpWriterFuture{ssrc: parameters.Encodings[idx].SSRC, rtpSender: r} + writeStream := &interceptorToTrackLocalWriter{} + rtpParameters := r.api.mediaEngine.getRTPParametersByKind( + trackEncoding.track.Kind(), + []RTPTransceiverDirection{RTPTransceiverDirectionSendonly}, + ) + + trackEncoding.srtpStream = srtpStream + trackEncoding.ssrc = parameters.Encodings[idx].SSRC + trackEncoding.ssrcRTX = parameters.Encodings[idx].RTX.SSRC + trackEncoding.ssrcFEC = parameters.Encodings[idx].FEC.SSRC + trackEncoding.rtcpInterceptor = r.api.interceptor.BindRTCPReader( + interceptor.RTCPReaderFunc( + func(in []byte, a interceptor.Attributes) (n int, attributes interceptor.Attributes, err error) { + n, err = trackEncoding.srtpStream.Read(in) + + return n, a, err + }, + ), + ) + trackEncoding.context = &baseTrackLocalContext{ + id: r.id, + params: rtpParameters, + ssrc: parameters.Encodings[idx].SSRC, + ssrcFEC: parameters.Encodings[idx].FEC.SSRC, + ssrcRTX: parameters.Encodings[idx].RTX.SSRC, + writeStream: writeStream, + rtcpInterceptor: trackEncoding.rtcpInterceptor, + } + + codec, err := trackEncoding.track.Bind(trackEncoding.context) + if err != nil { + return err + } + trackEncoding.context.params.Codecs = []RTPCodecParameters{codec} + + trackEncoding.streamInfo = *createStreamInfo( + r.id, + parameters.Encodings[idx].SSRC, + parameters.Encodings[idx].RTX.SSRC, + parameters.Encodings[idx].FEC.SSRC, + codec.PayloadType, + findRTXPayloadType(codec.PayloadType, rtpParameters.Codecs), + findFECPayloadType(rtpParameters.Codecs), + codec.RTPCodecCapability, + parameters.HeaderExtensions, + ) + + rtpInterceptor := r.api.interceptor.BindLocalStream( + &trackEncoding.streamInfo, + interceptor.RTPWriterFunc(func(header *rtp.Header, payload []byte, _ interceptor.Attributes) (int, error) { + return srtpStream.WriteRTP(header, payload) + }), + ) + + writeStream.interceptor.Store(rtpInterceptor) + } + + close(r.sendCalled) + + return nil +} + +// Stop irreversibly stops the RTPSender. +func (r *RTPSender) Stop() error { + r.mu.Lock() + + if stopped := r.hasStopped(); stopped { + r.mu.Unlock() + + return nil + } + + close(r.stopCalled) + r.mu.Unlock() + + if !r.hasSent() { + return nil + } + + if err := r.ReplaceTrack(nil); err != nil { + return err + } + + errs := []error{} + for _, trackEncoding := range r.trackEncodings { + r.api.interceptor.UnbindLocalStream(&trackEncoding.streamInfo) + if trackEncoding.srtpStream != nil { + errs = append(errs, trackEncoding.srtpStream.Close()) + } + } + + return util.FlattenErrs(errs) +} + +// Read reads incoming RTCP for this RTPSender. +func (r *RTPSender) Read(b []byte) (n int, a interceptor.Attributes, err error) { + select { + case <-r.sendCalled: + return r.trackEncodings[0].rtcpInterceptor.Read(b, a) + case <-r.stopCalled: + return 0, nil, io.ErrClosedPipe + } +} + +// ReadRTCP is a convenience method that wraps Read and unmarshals for you. +func (r *RTPSender) ReadRTCP() ([]rtcp.Packet, interceptor.Attributes, error) { + b := make([]byte, r.api.settingEngine.getReceiveMTU()) + i, attributes, err := r.Read(b) + if err != nil { + return nil, nil, err + } + + pkts, err := rtcp.Unmarshal(b[:i]) + if err != nil { + return nil, nil, err + } + + return pkts, attributes, nil +} + +// ReadSimulcast reads incoming RTCP for this RTPSender for given rid. +func (r *RTPSender) ReadSimulcast(b []byte, rid string) (n int, a interceptor.Attributes, err error) { + select { + case <-r.sendCalled: + r.mu.Lock() + for _, t := range r.trackEncodings { + if t.track != nil && t.track.RID() == rid { + reader := t.rtcpInterceptor + r.mu.Unlock() + + return reader.Read(b, a) + } + } + r.mu.Unlock() + + return 0, nil, fmt.Errorf("%w: %s", errRTPSenderNoTrackForRID, rid) + case <-r.stopCalled: + return 0, nil, io.ErrClosedPipe + } +} + +// ReadSimulcastRTCP is a convenience method that wraps ReadSimulcast and unmarshal for you. +func (r *RTPSender) ReadSimulcastRTCP(rid string) ([]rtcp.Packet, interceptor.Attributes, error) { + b := make([]byte, r.api.settingEngine.getReceiveMTU()) + i, attributes, err := r.ReadSimulcast(b, rid) + if err != nil { + return nil, nil, err + } + + pkts, err := rtcp.Unmarshal(b[:i]) + + return pkts, attributes, err +} + +// SetReadDeadline sets the deadline for the Read operation. +// Setting to zero means no deadline. +func (r *RTPSender) SetReadDeadline(t time.Time) error { + if r.trackEncodings[0].srtpStream == nil { + return errRTPSenderSendNotCalled + } + + return r.trackEncodings[0].srtpStream.SetReadDeadline(t) +} + +// SetReadDeadlineSimulcast sets the max amount of time the RTCP stream for a given rid +// will block before returning. 0 is forever. +func (r *RTPSender) SetReadDeadlineSimulcast(deadline time.Time, rid string) error { + r.mu.RLock() + defer r.mu.RUnlock() + + for _, t := range r.trackEncodings { + if t.track != nil && t.track.RID() == rid { + return t.srtpStream.SetReadDeadline(deadline) + } + } + + return fmt.Errorf("%w: %s", errRTPSenderNoTrackForRID, rid) +} + +// hasSent tells if data has been ever sent for this instance. +func (r *RTPSender) hasSent() bool { + select { + case <-r.sendCalled: + return true + default: + return false + } +} + +// hasStopped tells if stop has been called. +func (r *RTPSender) hasStopped() bool { + select { + case <-r.stopCalled: + return true + default: + return false + } +} + +// Set a SSRC for FEC and RTX if MediaEngine has them enabled +// If the remote doesn't support FEC or RTX we disable locally. +func (r *RTPSender) configureRTXAndFEC() { + r.mu.Lock() + defer r.mu.Unlock() + + for _, trackEncoding := range r.trackEncodings { + if !r.api.mediaEngine.isRTXEnabled(r.kind, []RTPTransceiverDirection{RTPTransceiverDirectionSendonly}) { + trackEncoding.ssrcRTX = SSRC(0) + } + + if !r.api.mediaEngine.isFECEnabled(r.kind, []RTPTransceiverDirection{RTPTransceiverDirectionSendonly}) { + trackEncoding.ssrcFEC = SSRC(0) + } + } +} diff --git a/vendor/github.com/pion/webrtc/v4/rtpsender_js.go b/vendor/github.com/pion/webrtc/v4/rtpsender_js.go new file mode 100644 index 0000000..4fde0c3 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/rtpsender_js.go @@ -0,0 +1,20 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +//go:build js && wasm +// +build js,wasm + +package webrtc + +import "syscall/js" + +// RTPSender allows an application to control how a given Track is encoded and transmitted to a remote peer +type RTPSender struct { + // Pointer to the underlying JavaScript RTCRTPSender object. + underlying js.Value +} + +// JSValue returns the underlying RTCRtpSender +func (s *RTPSender) JSValue() js.Value { + return s.underlying +} diff --git a/vendor/github.com/pion/webrtc/v4/rtpsendparameters.go b/vendor/github.com/pion/webrtc/v4/rtpsendparameters.go new file mode 100644 index 0000000..d24c5ae --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/rtpsendparameters.go @@ -0,0 +1,10 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package webrtc + +// RTPSendParameters contains the RTP stack settings used by receivers. +type RTPSendParameters struct { + RTPParameters + Encodings []RTPEncodingParameters +} diff --git a/vendor/github.com/pion/webrtc/v4/rtptransceiver.go b/vendor/github.com/pion/webrtc/v4/rtptransceiver.go new file mode 100644 index 0000000..d3dcfa7 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/rtptransceiver.go @@ -0,0 +1,456 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +//go:build !js + +package webrtc + +import ( + "fmt" + "strings" + "sync" + "sync/atomic" + + "github.com/pion/rtp" + "github.com/pion/sdp/v3" + "github.com/pion/webrtc/v4/internal/fmtp" +) + +// RTPTransceiver represents a combination of an RTPSender and an RTPReceiver that share a common mid. +type RTPTransceiver struct { + mid atomic.Value // string + sender atomic.Value // *RTPSender + receiver atomic.Value // *RTPReceiver + direction atomic.Value // RTPTransceiverDirection + currentDirection atomic.Value // RTPTransceiverDirection + currentRemoteDirection atomic.Value // RTPTransceiverDirection + + codecs []RTPCodecParameters // User provided codecs via SetCodecPreferences + + kind RTPCodecType + + api *API + mu sync.RWMutex +} + +func newRTPTransceiver( + receiver *RTPReceiver, + sender *RTPSender, + direction RTPTransceiverDirection, + kind RTPCodecType, + api *API, +) *RTPTransceiver { + t := &RTPTransceiver{kind: kind, api: api} + t.setReceiver(receiver) + t.setSender(sender) + t.setDirection(direction) + t.setCurrentDirection(RTPTransceiverDirectionUnknown) + + return t +} + +// SetCodecPreferences sets preferred list of supported codecs +// if codecs is empty or nil we reset to default from MediaEngine. +func (t *RTPTransceiver) SetCodecPreferences(codecs []RTPCodecParameters) error { + t.mu.Lock() + defer t.mu.Unlock() + + for _, codec := range codecs { + if _, matchType := codecParametersFuzzySearch( + codec, t.api.mediaEngine.getCodecsByKind(t.kind), + ); matchType == codecMatchNone { + return fmt.Errorf("%w %s", errRTPTransceiverCodecUnsupported, codec.MimeType) + } + } + + t.codecs = filterUnattachedRTX(codecs) + + return nil +} + +// getCodecs returns list of supported codecs. +func (t *RTPTransceiver) getCodecs() []RTPCodecParameters { + t.mu.RLock() + defer t.mu.RUnlock() + + mediaEngineCodecs := t.api.mediaEngine.getCodecsByKind(t.kind) + if len(t.codecs) == 0 { + return filterUnattachedRTX(mediaEngineCodecs) + } + + filteredCodecs := []RTPCodecParameters{} + for _, codec := range t.codecs { + if c, matchType := codecParametersFuzzySearch(codec, mediaEngineCodecs); matchType != codecMatchNone { + if codec.PayloadType == 0 { + codec.PayloadType = c.PayloadType + } + codec.RTCPFeedback = rtcpFeedbackIntersection(codec.RTCPFeedback, c.RTCPFeedback) + filteredCodecs = append(filteredCodecs, codec) + } + } + + return filterUnattachedRTX(filteredCodecs) +} + +// match codecs from remote description, used when remote is offerer and creating a transceiver +// from remote description with the aim of keeping order of codecs in remote description. +func (t *RTPTransceiver) setCodecPreferencesFromRemoteDescription(media *sdp.MediaDescription) { //nolint:cyclop + remoteCodecs, err := codecsFromMediaDescription(media) + if err != nil { + return + } + + // make a copy as this slice is modified + leftCodecs := append([]RTPCodecParameters{}, t.api.mediaEngine.getCodecsByKind(t.kind)...) + + // find codec matches between what is in remote description and + // the transceivers codecs and use payload type registered to + // media engine. + payloadMapping := make(map[PayloadType]PayloadType) // for RTX re-mapping later + filterByMatchType := func(matchFilter codecMatchType) []RTPCodecParameters { + filteredCodecs := []RTPCodecParameters{} + for remoteCodecIdx := len(remoteCodecs) - 1; remoteCodecIdx >= 0; remoteCodecIdx-- { + remoteCodec := remoteCodecs[remoteCodecIdx] + if strings.EqualFold(remoteCodec.RTPCodecCapability.MimeType, MimeTypeRTX) { + continue + } + + matchCodec, matchType := codecParametersFuzzySearch( + remoteCodec, + leftCodecs, + ) + if matchType == matchFilter { + payloadMapping[remoteCodec.PayloadType] = matchCodec.PayloadType + + remoteCodec.PayloadType = matchCodec.PayloadType + filteredCodecs = append([]RTPCodecParameters{remoteCodec}, filteredCodecs...) + + // removed matched codec for next round + remoteCodecs = append(remoteCodecs[:remoteCodecIdx], remoteCodecs[remoteCodecIdx+1:]...) + + needleFmtp := fmtp.Parse( + matchCodec.RTPCodecCapability.MimeType, + matchCodec.RTPCodecCapability.ClockRate, + matchCodec.RTPCodecCapability.Channels, + matchCodec.RTPCodecCapability.SDPFmtpLine, + ) + + for leftCodecIdx := len(leftCodecs) - 1; leftCodecIdx >= 0; leftCodecIdx-- { + leftCodec := leftCodecs[leftCodecIdx] + leftCodecFmtp := fmtp.Parse( + leftCodec.RTPCodecCapability.MimeType, + leftCodec.RTPCodecCapability.ClockRate, + leftCodec.RTPCodecCapability.Channels, + leftCodec.RTPCodecCapability.SDPFmtpLine, + ) + + if needleFmtp.Match(leftCodecFmtp) { + leftCodecs = append(leftCodecs[:leftCodecIdx], leftCodecs[leftCodecIdx+1:]...) + + break + } + } + } + } + + return filteredCodecs + } + + filteredCodecs := filterByMatchType(codecMatchExact) + filteredCodecs = append(filteredCodecs, filterByMatchType(codecMatchPartial)...) + + // find RTX associations and add those + for remotePayloadType, mediaEnginePayloadType := range payloadMapping { + remoteRTX := findRTXPayloadType(remotePayloadType, remoteCodecs) + if remoteRTX == PayloadType(0) { + continue + } + + mediaEngineRTX := findRTXPayloadType(mediaEnginePayloadType, leftCodecs) + if mediaEngineRTX == PayloadType(0) { + continue + } + + for _, rtxCodec := range leftCodecs { + if rtxCodec.PayloadType == mediaEngineRTX { + filteredCodecs = append(filteredCodecs, rtxCodec) + + break + } + } + } + _ = t.SetCodecPreferences(filteredCodecs) +} + +// Sender returns the RTPTransceiver's RTPSender if it has one. +func (t *RTPTransceiver) Sender() *RTPSender { + if v, ok := t.sender.Load().(*RTPSender); ok { + return v + } + + return nil +} + +// SetSender sets the RTPSender and Track to current transceiver. +func (t *RTPTransceiver) SetSender(s *RTPSender, track TrackLocal) error { + t.setSender(s) + + return t.setSendingTrack(track) +} + +func (t *RTPTransceiver) setSender(s *RTPSender) { + if s != nil { + s.setRTPTransceiver(t) + } + + if prevSender := t.Sender(); prevSender != nil { + prevSender.setRTPTransceiver(nil) + } + + t.sender.Store(s) +} + +// Receiver returns the RTPTransceiver's RTPReceiver if it has one. +func (t *RTPTransceiver) Receiver() *RTPReceiver { + if v, ok := t.receiver.Load().(*RTPReceiver); ok { + return v + } + + return nil +} + +// SetMid sets the RTPTransceiver's mid. If it was already set, will return an error. +func (t *RTPTransceiver) SetMid(mid string) error { + if currentMid := t.Mid(); currentMid != "" { + return fmt.Errorf("%w: %s to %s", errRTPTransceiverCannotChangeMid, currentMid, mid) + } + t.mid.Store(mid) + + return nil +} + +// Mid gets the Transceiver's mid value. When not already set, this value will be set in CreateOffer or CreateAnswer. +func (t *RTPTransceiver) Mid() string { + if v, ok := t.mid.Load().(string); ok { + return v + } + + return "" +} + +// Kind returns RTPTransceiver's kind. +func (t *RTPTransceiver) Kind() RTPCodecType { + return t.kind +} + +// Direction returns the RTPTransceiver's current direction. +func (t *RTPTransceiver) Direction() RTPTransceiverDirection { + if direction, ok := t.direction.Load().(RTPTransceiverDirection); ok { + return direction + } + + return RTPTransceiverDirection(0) +} + +// Stop irreversibly stops the RTPTransceiver. +func (t *RTPTransceiver) Stop() error { + if sender := t.Sender(); sender != nil { + if err := sender.Stop(); err != nil { + return err + } + } + if receiver := t.Receiver(); receiver != nil { + if err := receiver.Stop(); err != nil { + return err + } + } + + t.setDirection(RTPTransceiverDirectionInactive) + t.setCurrentDirection(RTPTransceiverDirectionInactive) + + return nil +} + +func (t *RTPTransceiver) setReceiver(r *RTPReceiver) { + if r != nil { + r.setRTPTransceiver(t) + } + + if prevReceiver := t.Receiver(); prevReceiver != nil { + prevReceiver.setRTPTransceiver(nil) + } + + t.receiver.Store(r) +} + +func (t *RTPTransceiver) setDirection(d RTPTransceiverDirection) { + t.direction.Store(d) +} + +func (t *RTPTransceiver) setCurrentDirection(d RTPTransceiverDirection) { + t.currentDirection.Store(d) +} + +func (t *RTPTransceiver) getCurrentDirection() RTPTransceiverDirection { + if v, ok := t.currentDirection.Load().(RTPTransceiverDirection); ok { + return v + } + + return RTPTransceiverDirectionUnknown +} + +func (t *RTPTransceiver) setCurrentRemoteDirection(d RTPTransceiverDirection) { + t.currentRemoteDirection.Store(d) +} + +func (t *RTPTransceiver) getCurrentRemoteDirection() RTPTransceiverDirection { + if v, ok := t.currentRemoteDirection.Load().(RTPTransceiverDirection); ok { + return v + } + + return RTPTransceiverDirectionUnknown +} + +func (t *RTPTransceiver) setSendingTrack(track TrackLocal) error { //nolint:cyclop + if err := t.Sender().ReplaceTrack(track); err != nil { + return err + } + if track == nil { + t.setSender(nil) + } + + switch { + case track != nil && t.Direction() == RTPTransceiverDirectionRecvonly: + t.setDirection(RTPTransceiverDirectionSendrecv) + case track != nil && t.Direction() == RTPTransceiverDirectionInactive: + t.setDirection(RTPTransceiverDirectionSendonly) + case track == nil && t.Direction() == RTPTransceiverDirectionSendrecv: + t.setDirection(RTPTransceiverDirectionRecvonly) + case track != nil && t.Direction() == RTPTransceiverDirectionSendonly: + // Handle the case where a sendonly transceiver was added by a negotiation + // initiated by remote peer. For example a remote peer added a transceiver + // with direction recvonly. + case track != nil && t.Direction() == RTPTransceiverDirectionSendrecv: + // Similar to above, but for sendrecv transceiver. + case track == nil && t.Direction() == RTPTransceiverDirectionSendonly: + t.setDirection(RTPTransceiverDirectionInactive) + default: + return errRTPTransceiverSetSendingInvalidState + } + + return nil +} + +func (t *RTPTransceiver) isSendAllowed(kind RTPCodecType) bool { + if t.kind != kind || t.Sender() != nil { + return false + } + + // According to https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-addtrack, if the + // transceiver can be reused only if its currentDirection was never sendrecv or sendonly. + // But that will cause sdp to inflate. So we only check currentDirection's current value, + // that's worked for all browsers. + currentDirection := t.getCurrentDirection() + if currentDirection == RTPTransceiverDirectionSendrecv || + currentDirection == RTPTransceiverDirectionSendonly { + return false + } + + // `currentRemoteDirection` should be checked before using the transceiver for send. + // Remote directions could be + // - `sendrecv` or `recvonly` - can send, remote direction will transition from + // `sendrecv` -> `recvonly` if a remote track was removed. + // - `sendonly` or `inactive` - cannot send, remote direction will transitions from + // `sendonly` -> `inactive` if a remote track was removed. + // - `unknown` - can send - we are the offering side and remote direction is unknown + currentRemoteDirection := t.getCurrentRemoteDirection() + if currentRemoteDirection == RTPTransceiverDirectionSendonly || + currentRemoteDirection == RTPTransceiverDirectionInactive { + return false + } + + return true +} + +func findByMid(mid string, localTransceivers []*RTPTransceiver) (*RTPTransceiver, []*RTPTransceiver) { + for i, t := range localTransceivers { + if t.Mid() == mid { + return t, append(localTransceivers[:i], localTransceivers[i+1:]...) + } + } + + return nil, localTransceivers +} + +// Given a direction+type pluck a transceiver from the passed list +// if no entry satisfies the requested type+direction return a inactive Transceiver. +func satisfyTypeAndDirection( + remoteKind RTPCodecType, + remoteDirection RTPTransceiverDirection, + localTransceivers []*RTPTransceiver, +) (*RTPTransceiver, []*RTPTransceiver) { + // Get direction order from most preferred to least + getPreferredDirections := func() []RTPTransceiverDirection { + switch remoteDirection { + case RTPTransceiverDirectionSendrecv: + return []RTPTransceiverDirection{ + RTPTransceiverDirectionRecvonly, + RTPTransceiverDirectionSendrecv, + RTPTransceiverDirectionSendonly, + } + case RTPTransceiverDirectionSendonly: + return []RTPTransceiverDirection{RTPTransceiverDirectionRecvonly} + case RTPTransceiverDirectionRecvonly: + return []RTPTransceiverDirection{RTPTransceiverDirectionSendonly, RTPTransceiverDirectionSendrecv} + default: + return []RTPTransceiverDirection{} + } + } + + for _, possibleDirection := range getPreferredDirections() { + for i := range localTransceivers { + t := localTransceivers[i] + if t.Mid() == "" && t.kind == remoteKind && possibleDirection == t.Direction() { + return t, append(localTransceivers[:i], localTransceivers[i+1:]...) + } + } + } + + return nil, localTransceivers +} + +// handleUnknownRTPPacket consumes a single RTP Packet and returns information that is helpful +// for demuxing and handling an unknown SSRC (usually for Simulcast). +func handleUnknownRTPPacket( + buf []byte, + midExtensionID, + streamIDExtensionID, + repairStreamIDExtensionID uint8, +) (mid, rid, rsid string, paddingOnly bool, err error) { + rp := &rtp.Packet{} + if err = rp.Unmarshal(buf); err != nil { + return mid, rid, rsid, false, err + } + + if rp.Padding && len(rp.Payload) == 0 { + return mid, rid, rsid, true, nil + } + + if !rp.Header.Extension { + return mid, rid, rsid, false, nil + } + + if payload := rp.GetExtension(midExtensionID); payload != nil { + mid = string(payload) + } + + if payload := rp.GetExtension(streamIDExtensionID); payload != nil { + rid = string(payload) + } + + if payload := rp.GetExtension(repairStreamIDExtensionID); payload != nil { + rsid = string(payload) + } + + return mid, rid, rsid, false, nil +} diff --git a/vendor/github.com/pion/webrtc/v4/rtptransceiver_js.go b/vendor/github.com/pion/webrtc/v4/rtptransceiver_js.go new file mode 100644 index 0000000..70ccb2d --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/rtptransceiver_js.go @@ -0,0 +1,47 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +//go:build js && wasm +// +build js,wasm + +package webrtc + +import ( + "syscall/js" +) + +// RTPTransceiver represents a combination of an RTPSender and an RTPReceiver that share a common mid. +type RTPTransceiver struct { + // Pointer to the underlying JavaScript RTCRTPTransceiver object. + underlying js.Value +} + +// JSValue returns the underlying RTCRtpTransceiver +func (r *RTPTransceiver) JSValue() js.Value { + return r.underlying +} + +// Direction returns the RTPTransceiver's current direction +func (r *RTPTransceiver) Direction() RTPTransceiverDirection { + return NewRTPTransceiverDirection(r.underlying.Get("direction").String()) +} + +// Sender returns the RTPTransceiver's RTPSender if it has one +func (r *RTPTransceiver) Sender() *RTPSender { + underlying := r.underlying.Get("sender") + if underlying.IsNull() { + return nil + } + + return &RTPSender{underlying: underlying} +} + +// Receiver returns the RTPTransceiver's RTPReceiver if it has one +func (r *RTPTransceiver) Receiver() *RTPReceiver { + underlying := r.underlying.Get("receiver") + if underlying.IsNull() { + return nil + } + + return &RTPReceiver{underlying: underlying} +} diff --git a/vendor/github.com/pion/webrtc/v4/rtptransceiverdirection.go b/vendor/github.com/pion/webrtc/v4/rtptransceiverdirection.go new file mode 100644 index 0000000..ea836e7 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/rtptransceiverdirection.go @@ -0,0 +1,95 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package webrtc + +import "slices" + +// RTPTransceiverDirection indicates the direction of the RTPTransceiver. +type RTPTransceiverDirection int + +const ( + // RTPTransceiverDirectionUnknown is the enum's zero-value. + RTPTransceiverDirectionUnknown RTPTransceiverDirection = iota + + // RTPTransceiverDirectionSendrecv indicates the RTPSender will offer + // to send RTP and the RTPReceiver will offer to receive RTP. + RTPTransceiverDirectionSendrecv + + // RTPTransceiverDirectionSendonly indicates the RTPSender will offer + // to send RTP. + RTPTransceiverDirectionSendonly + + // RTPTransceiverDirectionRecvonly indicates the RTPReceiver will + // offer to receive RTP. + RTPTransceiverDirectionRecvonly + + // RTPTransceiverDirectionInactive indicates the RTPSender won't offer + // to send RTP and the RTPReceiver won't offer to receive RTP. + RTPTransceiverDirectionInactive +) + +// This is done this way because of a linter. +const ( + rtpTransceiverDirectionSendrecvStr = "sendrecv" + rtpTransceiverDirectionSendonlyStr = "sendonly" + rtpTransceiverDirectionRecvonlyStr = "recvonly" + rtpTransceiverDirectionInactiveStr = "inactive" +) + +// NewRTPTransceiverDirection defines a procedure for creating a new +// RTPTransceiverDirection from a raw string naming the transceiver direction. +func NewRTPTransceiverDirection(raw string) RTPTransceiverDirection { + switch raw { + case rtpTransceiverDirectionSendrecvStr: + return RTPTransceiverDirectionSendrecv + case rtpTransceiverDirectionSendonlyStr: + return RTPTransceiverDirectionSendonly + case rtpTransceiverDirectionRecvonlyStr: + return RTPTransceiverDirectionRecvonly + case rtpTransceiverDirectionInactiveStr: + return RTPTransceiverDirectionInactive + default: + return RTPTransceiverDirectionUnknown + } +} + +func (t RTPTransceiverDirection) String() string { + switch t { + case RTPTransceiverDirectionSendrecv: + return rtpTransceiverDirectionSendrecvStr + case RTPTransceiverDirectionSendonly: + return rtpTransceiverDirectionSendonlyStr + case RTPTransceiverDirectionRecvonly: + return rtpTransceiverDirectionRecvonlyStr + case RTPTransceiverDirectionInactive: + return rtpTransceiverDirectionInactiveStr + default: + return ErrUnknownType.Error() + } +} + +// Revers indicate the opposite direction. +func (t RTPTransceiverDirection) Revers() RTPTransceiverDirection { + switch t { + case RTPTransceiverDirectionSendonly: + return RTPTransceiverDirectionRecvonly + case RTPTransceiverDirectionRecvonly: + return RTPTransceiverDirectionSendonly + default: + return t + } +} + +func haveRTPTransceiverDirectionIntersection( + haystack []RTPTransceiverDirection, + needle []RTPTransceiverDirection, +) bool { + for _, n := range needle { + if slices.Contains(haystack, n) { + return true + } + } + + return false +} diff --git a/vendor/github.com/pion/webrtc/v4/rtptransceiverinit.go b/vendor/github.com/pion/webrtc/v4/rtptransceiverinit.go new file mode 100644 index 0000000..7aac65a --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/rtptransceiverinit.go @@ -0,0 +1,12 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package webrtc + +// RTPTransceiverInit dictionary is used when calling the WebRTC function addTransceiver() +// to provide configuration options for the new transceiver. +type RTPTransceiverInit struct { + Direction RTPTransceiverDirection + SendEncodings []RTPEncodingParameters + // Streams []*Track +} diff --git a/vendor/github.com/pion/webrtc/v4/sctpcapabilities.go b/vendor/github.com/pion/webrtc/v4/sctpcapabilities.go new file mode 100644 index 0000000..d9e94de --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/sctpcapabilities.go @@ -0,0 +1,11 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package webrtc + +// SCTPCapabilities indicates the capabilities of the SCTPTransport. +type SCTPCapabilities struct { + MaxMessageSize uint32 `json:"maxMessageSize"` + // Note: this is the binary sctp-init, not the base64 encoded version. + sctpInit string +} diff --git a/vendor/github.com/pion/webrtc/v4/sctptransport.go b/vendor/github.com/pion/webrtc/v4/sctptransport.go new file mode 100644 index 0000000..9bb9a72 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/sctptransport.go @@ -0,0 +1,528 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +//go:build !js + +package webrtc + +import ( + "errors" + "io" + "net" + "sync" + "time" + + "github.com/pion/datachannel" + "github.com/pion/logging" + "github.com/pion/sctp" + "github.com/pion/webrtc/v4/pkg/rtcerr" +) + +const sctpMaxChannels = uint16(65535) + +// SCTPTransport provides details about the SCTP transport. +type SCTPTransport struct { + lock sync.RWMutex + + dtlsTransport *DTLSTransport + + // State represents the current state of the SCTP transport. + state SCTPTransportState + + // SCTPTransportState doesn't have an enum to distinguish between New/Connecting + // so we need a dedicated field + isStarted bool + + // MaxChannels represents the maximum amount of DataChannel's that can + // be used simultaneously. + maxChannels *uint16 + + // OnStateChange func() + + onErrorHandler func(error) + onCloseHandler func(error) + + sctpAssociation *sctp.Association + onDataChannelHandler func(*DataChannel) + onDataChannelOpenedHandler func(*DataChannel) + + // DataChannels + dataChannels []*DataChannel + dataChannelIDsUsed map[uint16]struct{} + dataChannelsOpened uint32 + dataChannelsRequested uint32 + dataChannelsAccepted uint32 + + localSctpInit []byte + + api *API + log logging.LeveledLogger +} + +// NewSCTPTransport creates a new SCTPTransport. +// This constructor is part of the ORTC API. It is not +// meant to be used together with the basic WebRTC API. +func (api *API) NewSCTPTransport(dtls *DTLSTransport) *SCTPTransport { + res := &SCTPTransport{ + dtlsTransport: dtls, + state: SCTPTransportStateConnecting, + api: api, + log: api.settingEngine.LoggerFactory.NewLogger("ortc"), + dataChannelIDsUsed: make(map[uint16]struct{}), + } + + res.updateMaxChannels() + + return res +} + +// Transport returns the DTLSTransport instance the SCTPTransport is sending over. +func (r *SCTPTransport) Transport() *DTLSTransport { + r.lock.RLock() + defer r.lock.RUnlock() + + return r.dtlsTransport +} + +// GetCapabilities returns the SCTPCapabilities of the SCTPTransport. +func (r *SCTPTransport) GetCapabilities() SCTPCapabilities { + var maxMessageSize uint32 + if a := r.association(); a != nil { + maxMessageSize = a.MaxMessageSize() + } + + return SCTPCapabilities{ + MaxMessageSize: maxMessageSize, + } +} + +// Start the SCTPTransport. Since both local and remote parties must mutually +// create an SCTPTransport, SCTP SO (Simultaneous Open) is used to establish +// a connection over SCTP. +// +//nolint:cyclop +func (r *SCTPTransport) Start(capabilities SCTPCapabilities) error { + if r.isStarted { + return nil + } + r.isStarted = true + + maxMessageSize := capabilities.MaxMessageSize + if maxMessageSize == 0 { + maxMessageSize = sctpMaxMessageSizeUnsetValue + } + remoteSctpInit := []byte(capabilities.sctpInit) + + dtlsTransport := r.Transport() + if dtlsTransport == nil || dtlsTransport.conn == nil { + return errSCTPTransportDTLS + } + opts := r.sctpClientOptions(dtlsTransport.conn, maxMessageSize) + if len(r.localSctpInit) > 0 && len(remoteSctpInit) > 0 { + opts = append( + opts, + sctp.WithSNAP(r.localSctpInit, remoteSctpInit), + ) + } + sctpAssociation, err := sctp.ClientWithOptions(opts...) + if err != nil { + return err + } + + r.lock.Lock() + r.sctpAssociation = sctpAssociation + r.state = SCTPTransportStateConnected + dataChannels := append([]*DataChannel{}, r.dataChannels...) + r.lock.Unlock() + + var openedDCCount uint32 + for _, d := range dataChannels { + if d.ReadyState() == DataChannelStateConnecting { + err := d.open(r) + if err != nil { + r.log.Warnf("failed to open data channel: %s", err) + + continue + } + openedDCCount++ + } + } + + r.lock.Lock() + r.dataChannelsOpened += openedDCCount + r.lock.Unlock() + + go r.acceptDataChannels(sctpAssociation, dataChannels) + + return nil +} + +func (r *SCTPTransport) sctpClientOptions(netConn net.Conn, maxMessageSize uint32) []sctp.ClientOption { + opts := []sctp.ClientOption{ + sctp.WithNetConn(netConn), + sctp.WithLoggerFactory(r.api.settingEngine.LoggerFactory), + sctp.WithMTU(outboundMTU), + sctp.WithMaxMessageSize(maxMessageSize), + } + + return append(opts, r.optionalSCTPClientOptions()...) +} + +func (r *SCTPTransport) optionalSCTPClientOptions() []sctp.ClientOption { + opts := make([]sctp.ClientOption, 0, 7) + + if r.api.settingEngine.sctp.maxReceiveBufferSize != 0 { + opts = append(opts, sctp.WithMaxReceiveBufferSize(r.api.settingEngine.sctp.maxReceiveBufferSize)) + } + + if r.api.settingEngine.sctp.enableZeroChecksum { + opts = append(opts, sctp.WithEnableZeroChecksum(true)) + } + + if r.api.settingEngine.detach.DataChannels && r.api.settingEngine.dataChannelBlockWrite { + opts = append(opts, sctp.WithBlockWrite(true)) + } + + if r.api.settingEngine.sctp.rtoMax > 0 { + opts = append( + opts, + sctp.WithRTOMax(float64(r.api.settingEngine.sctp.rtoMax)/float64(time.Millisecond)), + ) + } + + if r.api.settingEngine.sctp.minCwnd != 0 { + opts = append(opts, sctp.WithMinCwnd(r.api.settingEngine.sctp.minCwnd)) + } + + if r.api.settingEngine.sctp.fastRtxWnd != 0 { + opts = append(opts, sctp.WithFastRtxWnd(r.api.settingEngine.sctp.fastRtxWnd)) + } + + if r.api.settingEngine.sctp.cwndCAStep != 0 { + opts = append(opts, sctp.WithCwndCAStep(r.api.settingEngine.sctp.cwndCAStep)) + } + + return opts +} + +// Stop stops the SCTPTransport. +func (r *SCTPTransport) Stop() error { + r.lock.Lock() + defer r.lock.Unlock() + if r.sctpAssociation == nil { + return nil + } + + r.sctpAssociation.Abort("") + + r.sctpAssociation = nil + r.state = SCTPTransportStateClosed + + return nil +} + +//nolint:cyclop +func (r *SCTPTransport) acceptDataChannels( + assoc *sctp.Association, + existingDataChannels []*DataChannel, +) { + dataChannels := make([]*datachannel.DataChannel, 0, len(existingDataChannels)) + for _, dc := range existingDataChannels { + dc.mu.Lock() + isNil := dc.dataChannel == nil + dc.mu.Unlock() + if isNil { + continue + } + dataChannels = append(dataChannels, dc.dataChannel) + } +ACCEPT: + for { + // check if the association has been stopped before calling accept. + r.lock.RLock() + currentAssoc := r.sctpAssociation + shouldStop := currentAssoc == nil || currentAssoc != assoc + r.lock.RUnlock() + if shouldStop { + r.onClose(nil) + + return + } + + dc, err := datachannel.Accept(assoc, &datachannel.Config{ + LoggerFactory: r.api.settingEngine.LoggerFactory, + }, dataChannels...) + if err != nil { + if !errors.Is(err, io.EOF) { + r.log.Errorf("Failed to accept data channel: %v", err) + r.onError(err) + r.onClose(err) + } else { + r.onClose(nil) + } + + return + } + for _, ch := range dataChannels { + if ch.StreamIdentifier() == dc.StreamIdentifier() { + continue ACCEPT + } + } + + var ( + maxRetransmits *uint16 + maxPacketLifeTime *uint16 + ) + val := uint16(dc.Config.ReliabilityParameter) //nolint:gosec //G115 + ordered := true + + switch dc.Config.ChannelType { + case datachannel.ChannelTypeReliable: + ordered = true + case datachannel.ChannelTypeReliableUnordered: + ordered = false + case datachannel.ChannelTypePartialReliableRexmit: + ordered = true + maxRetransmits = &val + case datachannel.ChannelTypePartialReliableRexmitUnordered: + ordered = false + maxRetransmits = &val + case datachannel.ChannelTypePartialReliableTimed: + ordered = true + maxPacketLifeTime = &val + case datachannel.ChannelTypePartialReliableTimedUnordered: + ordered = false + maxPacketLifeTime = &val + default: + } + + sid := dc.StreamIdentifier() + rtcDC, err := r.api.newDataChannel(&DataChannelParameters{ + ID: &sid, + Label: dc.Config.Label, + Protocol: dc.Config.Protocol, + Negotiated: dc.Config.Negotiated, + Ordered: ordered, + MaxPacketLifeTime: maxPacketLifeTime, + MaxRetransmits: maxRetransmits, + }, r, r.api.settingEngine.LoggerFactory.NewLogger("ortc")) + if err != nil { + // This data channel is invalid. Close it and log an error. + if err1 := dc.Close(); err1 != nil { + r.log.Errorf("Failed to close invalid data channel: %v", err1) + } + r.log.Errorf("Failed to accept data channel: %v", err) + r.onError(err) + // We've received a datachannel with invalid configuration. We can still receive other datachannels. + continue ACCEPT + } + + <-r.onDataChannel(rtcDC) + rtcDC.handleOpen(dc, true, dc.Config.Negotiated) + + r.lock.Lock() + r.dataChannelsOpened++ + handler := r.onDataChannelOpenedHandler + r.lock.Unlock() + + if handler != nil { + handler(rtcDC) + } + } +} + +// OnError sets an event handler which is invoked when the SCTP Association errors. +func (r *SCTPTransport) OnError(f func(err error)) { + r.lock.Lock() + defer r.lock.Unlock() + r.onErrorHandler = f +} + +func (r *SCTPTransport) onError(err error) { + r.lock.RLock() + handler := r.onErrorHandler + r.lock.RUnlock() + + if handler != nil { + go handler(err) + } +} + +// OnClose sets an event handler which is invoked when the SCTP Association closes. +func (r *SCTPTransport) OnClose(f func(err error)) { + r.lock.Lock() + defer r.lock.Unlock() + r.onCloseHandler = f +} + +func (r *SCTPTransport) onClose(err error) { + r.lock.RLock() + handler := r.onCloseHandler + r.lock.RUnlock() + + if handler != nil { + go handler(err) + } +} + +// OnDataChannel sets an event handler which is invoked when a data +// channel message arrives from a remote peer. +func (r *SCTPTransport) OnDataChannel(f func(*DataChannel)) { + r.lock.Lock() + defer r.lock.Unlock() + r.onDataChannelHandler = f +} + +// OnDataChannelOpened sets an event handler which is invoked when a data +// channel is opened. +func (r *SCTPTransport) OnDataChannelOpened(f func(*DataChannel)) { + r.lock.Lock() + defer r.lock.Unlock() + r.onDataChannelOpenedHandler = f +} + +func (r *SCTPTransport) onDataChannel(dc *DataChannel) (done chan struct{}) { + r.lock.Lock() + r.dataChannels = append(r.dataChannels, dc) + r.dataChannelsAccepted++ + if dc.ID() != nil { + r.dataChannelIDsUsed[*dc.ID()] = struct{}{} + } else { + // This cannot happen, the constructor for this datachannel in the caller + // takes a pointer to the id. + r.log.Errorf("accepted data channel with no ID") + } + handler := r.onDataChannelHandler + r.lock.Unlock() + + done = make(chan struct{}) + if handler == nil || dc == nil { + close(done) + + return + } + + // Run this synchronously to allow setup done in onDataChannelFn() + // to complete before datachannel event handlers might be called. + go func() { + handler(dc) + close(done) + }() + + return +} + +func (r *SCTPTransport) updateMaxChannels() { + val := sctpMaxChannels + r.maxChannels = &val +} + +// MaxChannels is the maximum number of RTCDataChannels that can be open simultaneously. +func (r *SCTPTransport) MaxChannels() uint16 { + r.lock.Lock() + defer r.lock.Unlock() + + if r.maxChannels == nil { + return sctpMaxChannels + } + + return *r.maxChannels +} + +// State returns the current state of the SCTPTransport. +func (r *SCTPTransport) State() SCTPTransportState { + r.lock.RLock() + defer r.lock.RUnlock() + + return r.state +} + +// Stats reports the current statistics of the SCTPTransport. +func (r *SCTPTransport) Stats() SCTPTransportStats { + stats := SCTPTransportStats{ + Timestamp: statsTimestampFrom(time.Now()), + Type: StatsTypeSCTPTransport, + ID: "sctpTransport", + } + + association := r.association() + if association != nil { + stats.BytesSent = association.BytesSent() + stats.BytesReceived = association.BytesReceived() + stats.SmoothedRoundTripTime = association.SRTT() * 0.001 // convert milliseconds to seconds + stats.CongestionWindow = association.CWND() + stats.ReceiverWindow = association.RWND() + stats.MTU = association.MTU() + } + + return stats +} + +func (r *SCTPTransport) collectStats(collector *statsReportCollector) { + collector.Collecting() + stats := r.Stats() + collector.Collect(stats.ID, stats) +} + +func (r *SCTPTransport) generateAndSetDataChannelID(dtlsRole DTLSRole, idOut **uint16) error { + var id uint16 + if dtlsRole != DTLSRoleClient { + id++ + } + + maxVal := r.MaxChannels() + + r.lock.Lock() + defer r.lock.Unlock() + + for ; id < maxVal-1; id += 2 { + if _, ok := r.dataChannelIDsUsed[id]; ok { + continue + } + *idOut = &id + r.dataChannelIDsUsed[id] = struct{}{} + + return nil + } + + return &rtcerr.OperationError{Err: ErrMaxDataChannelID} +} + +func (r *SCTPTransport) association() *sctp.Association { + if r == nil { + return nil + } + r.lock.RLock() + association := r.sctpAssociation + r.lock.RUnlock() + + return association +} + +// BufferedAmount returns total amount (in bytes) of currently buffered user data. +func (r *SCTPTransport) BufferedAmount() int { + r.lock.Lock() + defer r.lock.Unlock() + if r.sctpAssociation == nil { + return 0 + } + + return r.sctpAssociation.BufferedAmount() +} + +// GetSctpInit returns the current sctp-init attribute and caches the last created. +// The caller should hold the lock. +func (r *SCTPTransport) GetSctpInit() []byte { + if len(r.localSctpInit) == 0 { + var err error + r.localSctpInit, err = sctp.GenerateOutOfBandToken(sctp.Config{ + MaxReceiveBufferSize: r.api.settingEngine.sctp.maxReceiveBufferSize, + EnableZeroChecksum: r.api.settingEngine.sctp.enableZeroChecksum, + }) + if err != nil { + r.log.Warnf("Failed to create sctp-init: %v", err) + } + } + + return r.localSctpInit +} diff --git a/vendor/github.com/pion/webrtc/v4/sctptransport_js.go b/vendor/github.com/pion/webrtc/v4/sctptransport_js.go new file mode 100644 index 0000000..0c2df49 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/sctptransport_js.go @@ -0,0 +1,32 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +//go:build js && wasm +// +build js,wasm + +package webrtc + +import "syscall/js" + +// SCTPTransport provides details about the SCTP transport. +type SCTPTransport struct { + // Pointer to the underlying JavaScript SCTPTransport object. + underlying js.Value +} + +// JSValue returns the underlying RTCSctpTransport +func (r *SCTPTransport) JSValue() js.Value { + return r.underlying +} + +// Transport returns the DTLSTransport instance the SCTPTransport is sending over. +func (r *SCTPTransport) Transport() *DTLSTransport { + underlying := r.underlying.Get("transport") + if underlying.IsNull() || underlying.IsUndefined() { + return nil + } + + return &DTLSTransport{ + underlying: underlying, + } +} diff --git a/vendor/github.com/pion/webrtc/v4/sctptransportstate.go b/vendor/github.com/pion/webrtc/v4/sctptransportstate.go new file mode 100644 index 0000000..6794599 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/sctptransportstate.go @@ -0,0 +1,60 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package webrtc + +// SCTPTransportState indicates the state of the SCTP transport. +type SCTPTransportState int + +const ( + // SCTPTransportStateUnknown is the enum's zero-value. + SCTPTransportStateUnknown SCTPTransportState = iota + + // SCTPTransportStateConnecting indicates the SCTPTransport is in the + // process of negotiating an association. This is the initial state of the + // SCTPTransportState when an SCTPTransport is created. + SCTPTransportStateConnecting + + // SCTPTransportStateConnected indicates the negotiation of an + // association is completed. + SCTPTransportStateConnected + + // SCTPTransportStateClosed indicates a SHUTDOWN or ABORT chunk is + // received or when the SCTP association has been closed intentionally, + // such as by closing the peer connection or applying a remote description + // that rejects data or changes the SCTP port. + SCTPTransportStateClosed +) + +// This is done this way because of a linter. +const ( + sctpTransportStateConnectingStr = "connecting" + sctpTransportStateConnectedStr = "connected" + sctpTransportStateClosedStr = "closed" +) + +func newSCTPTransportState(raw string) SCTPTransportState { + switch raw { + case sctpTransportStateConnectingStr: + return SCTPTransportStateConnecting + case sctpTransportStateConnectedStr: + return SCTPTransportStateConnected + case sctpTransportStateClosedStr: + return SCTPTransportStateClosed + default: + return SCTPTransportStateUnknown + } +} + +func (s SCTPTransportState) String() string { + switch s { + case SCTPTransportStateConnecting: + return sctpTransportStateConnectingStr + case SCTPTransportStateConnected: + return sctpTransportStateConnectedStr + case SCTPTransportStateClosed: + return sctpTransportStateClosedStr + default: + return ErrUnknownType.Error() + } +} diff --git a/vendor/github.com/pion/webrtc/v4/sdp.go b/vendor/github.com/pion/webrtc/v4/sdp.go new file mode 100644 index 0000000..fe1500e --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/sdp.go @@ -0,0 +1,1233 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +//go:build !js + +package webrtc + +import ( + "encoding/base64" + "errors" + "fmt" + "net/url" + "regexp" + "slices" + "strconv" + "strings" + "sync/atomic" + + "github.com/pion/ice/v4" + "github.com/pion/logging" + "github.com/pion/sdp/v3" +) + +// trackDetails represents any media source that can be represented in a SDP +// This isn't keyed by SSRC because it also needs to support rid based sources. +type trackDetails struct { + mid string + kind RTPCodecType + streamID string + id string + ssrcs []SSRC + rtxSsrc *SSRC + fecSsrc *SSRC + rids []string +} + +func trackDetailsForSSRC(trackDetails []trackDetails, ssrc SSRC) *trackDetails { + for i := range trackDetails { + if slices.Contains(trackDetails[i].ssrcs, ssrc) { + return &trackDetails[i] + } + } + + return nil +} + +func trackDetailsForRID(trackDetails []trackDetails, mid, rid string) *trackDetails { + for i := range trackDetails { + if trackDetails[i].mid != mid { + continue + } + + if slices.Contains(trackDetails[i].rids, rid) { + return &trackDetails[i] + } + } + + return nil +} + +func filterTrackWithSSRC(incomingTracks []trackDetails, ssrc SSRC) []trackDetails { + filtered := []trackDetails{} + doesTrackHaveSSRC := func(t trackDetails) bool { + return slices.Contains(t.ssrcs, ssrc) + } + + for i := range incomingTracks { + if !doesTrackHaveSSRC(incomingTracks[i]) { + filtered = append(filtered, incomingTracks[i]) + } + } + + return filtered +} + +// extract all trackDetails from an SDP. +// +//nolint:gocognit,gocyclo,cyclop +func trackDetailsFromSDP( + log logging.LeveledLogger, + s *sdp.SessionDescription, +) (incomingTracks []trackDetails) { + for _, media := range s.MediaDescriptions { + tracksInMediaSection := []trackDetails{} + rtxRepairFlows := map[uint64]uint64{} + fecRepairFlows := map[uint64]uint64{} + + // Plan B can have multiple tracks in a single media section + streamID := "" + trackID := "" + + // If media section is recvonly or inactive skip + if _, ok := media.Attribute(sdp.AttrKeyRecvOnly); ok { + continue + } else if _, ok := media.Attribute(sdp.AttrKeyInactive); ok { + continue + } + + midValue := getMidValue(media) + if midValue == "" { + continue + } + + codecType := NewRTPCodecType(media.MediaName.Media) + if codecType == 0 { + continue + } + + for _, attr := range media.Attributes { + switch attr.Key { + case sdp.AttrKeySSRCGroup: + split := strings.Split(attr.Value, " ") + if split[0] == sdp.SemanticTokenFlowIdentification { //nolint:nestif + // Add rtx ssrcs to blacklist, to avoid adding them as tracks + // Essentially lines like `a=ssrc-group:FID 2231627014 632943048` are processed by this section + // as this declares that the second SSRC (632943048) is a rtx repair flow (RFC4588) for the first + // (2231627014) as specified in RFC5576 + if len(split) == 3 { + baseSsrc, err := strconv.ParseUint(split[1], 10, 32) + if err != nil { + log.Warnf("Failed to parse SSRC: %v", err) + + continue + } + rtxRepairFlow, err := strconv.ParseUint(split[2], 10, 32) + if err != nil { + log.Warnf("Failed to parse SSRC: %v", err) + + continue + } + rtxRepairFlows[rtxRepairFlow] = baseSsrc + tracksInMediaSection = filterTrackWithSSRC( + tracksInMediaSection, + SSRC(rtxRepairFlow), + ) // Remove if rtx was added as track before + for i := range tracksInMediaSection { + if tracksInMediaSection[i].ssrcs[0] == SSRC(baseSsrc) { + repairSsrc := SSRC(rtxRepairFlow) + tracksInMediaSection[i].rtxSsrc = &repairSsrc + } + } + } + } else if split[0] == sdp.SemanticTokenForwardErrorCorrectionFramework { + // Similar to above, lines like `a=ssrc-group:FEC-FR aaaaa bbbbb` + // means for video ssrc aaaaa, there's a FEC track bbbbb + if len(split) == 3 { + baseSsrc, err := strconv.ParseUint(split[1], 10, 32) + if err != nil { + log.Warnf("Failed to parse SSRC: %v", err) + + continue + } + fecRepairFlow, err := strconv.ParseUint(split[2], 10, 32) + if err != nil { + log.Warnf("Failed to parse SSRC: %v", err) + + continue + } + fecRepairFlows[fecRepairFlow] = baseSsrc + tracksInMediaSection = filterTrackWithSSRC( + tracksInMediaSection, + SSRC(fecRepairFlow), + ) // Remove if fec was added as track before + for i := range tracksInMediaSection { + if tracksInMediaSection[i].ssrcs[0] == SSRC(baseSsrc) { + repairSsrc := SSRC(fecRepairFlow) + tracksInMediaSection[i].fecSsrc = &repairSsrc + } + } + } + } + + // Handle `a=msid: ` for Unified plan. The first value is the same as MediaStream.id + // in the browser and can be used to figure out which tracks belong to the same stream. The browser should + // figure this out automatically when an ontrack event is emitted on RTCPeerConnection. + case sdp.AttrKeyMsid: + split := strings.Split(attr.Value, " ") + if len(split) == 2 { + streamID = split[0] + trackID = split[1] + } + + case sdp.AttrKeySSRC: + split := strings.Split(attr.Value, " ") + ssrc, err := strconv.ParseUint(split[0], 10, 32) + if err != nil { + log.Warnf("Failed to parse SSRC: %v", err) + + continue + } + + if _, ok := rtxRepairFlows[ssrc]; ok { + continue // This ssrc is a RTX repair flow, ignore + } + if _, ok := fecRepairFlows[ssrc]; ok { + continue // This ssrc is a FEC repair flow, ignore + } + + if len(split) == 3 && strings.HasPrefix(split[1], "msid:") { + streamID = split[1][len("msid:"):] + trackID = split[2] + } + + isNewTrack := true + trackDetails := &trackDetails{} + for i := range tracksInMediaSection { + for j := range tracksInMediaSection[i].ssrcs { + if tracksInMediaSection[i].ssrcs[j] == SSRC(ssrc) { + trackDetails = &tracksInMediaSection[i] + isNewTrack = false + } + } + } + + trackDetails.mid = midValue + trackDetails.kind = codecType + trackDetails.streamID = streamID + trackDetails.id = trackID + trackDetails.ssrcs = []SSRC{SSRC(ssrc)} + + for r, baseSsrc := range rtxRepairFlows { + if baseSsrc == ssrc { + repairSsrc := SSRC(r) //nolint:gosec // G115 + trackDetails.rtxSsrc = &repairSsrc + } + } + for r, baseSsrc := range fecRepairFlows { + if baseSsrc == ssrc { + fecSsrc := SSRC(r) //nolint:gosec // G115 + trackDetails.fecSsrc = &fecSsrc + } + } + + if isNewTrack { + tracksInMediaSection = append(tracksInMediaSection, *trackDetails) + } + } + } + + if rids := getRids(media); len(rids) != 0 && trackID != "" && streamID != "" { + simulcastTrack := trackDetails{ + mid: midValue, + kind: codecType, + streamID: streamID, + id: trackID, + rids: []string{}, + } + for _, rid := range rids { + simulcastTrack.rids = append(simulcastTrack.rids, rid.id) + } + + tracksInMediaSection = []trackDetails{simulcastTrack} + } + + incomingTracks = append(incomingTracks, tracksInMediaSection...) + } + + return incomingTracks +} + +func trackDetailsToRTPReceiveParameters(trackDetails *trackDetails) RTPReceiveParameters { + encodingSize := max(len(trackDetails.rids), len(trackDetails.ssrcs)) + + encodings := make([]RTPDecodingParameters, encodingSize) + for i := range encodings { + if len(trackDetails.rids) > i { + encodings[i].RID = trackDetails.rids[i] + } + if len(trackDetails.ssrcs) > i { + encodings[i].SSRC = trackDetails.ssrcs[i] + } + + if trackDetails.rtxSsrc != nil { + encodings[i].RTX.SSRC = *trackDetails.rtxSsrc + } + + if trackDetails.fecSsrc != nil { + encodings[i].FEC.SSRC = *trackDetails.fecSsrc + } + } + + return RTPReceiveParameters{Encodings: encodings} +} + +func getRids(media *sdp.MediaDescription) []*simulcastRid { + rids := []*simulcastRid{} + var simulcastAttr string + for _, attr := range media.Attributes { + if attr.Key == sdpAttributeRid { + split := strings.Split(attr.Value, " ") + rids = append(rids, &simulcastRid{id: split[0], attrValue: attr.Value}) + } else if attr.Key == sdpAttributeSimulcast { + simulcastAttr = attr.Value + } + } + // process paused stream like "a=simulcast:send 1;~2;~3" + if simulcastAttr != "" { + if space := strings.Index(simulcastAttr, " "); space > 0 { + simulcastAttr = simulcastAttr[space+1:] + } + ridStates := strings.SplitSeq(simulcastAttr, ";") + for ridState := range ridStates { + if ridState[:1] == "~" { + ridID := ridState[1:] + for _, rid := range rids { + if rid.id == ridID { + rid.paused = true + + break + } + } + } + } + } + + return rids +} + +func addCandidatesToMediaDescriptions( + candidates []ICECandidate, + mediaDescr *sdp.MediaDescription, + iceGatheringState ICEGatheringState, +) error { + appendCandidateIfNew := func(c ice.Candidate, attributes []sdp.Attribute) { + marshaled := c.Marshal() + for _, a := range attributes { + if marshaled == a.Value { + return + } + } + + mediaDescr.WithValueAttribute("candidate", marshaled) + } + + for _, c := range candidates { + candidate, err := c.ToICE() + if err != nil { + return err + } + + candidate.SetComponent(1) + appendCandidateIfNew(candidate, mediaDescr.Attributes) + + candidate.SetComponent(2) + appendCandidateIfNew(candidate, mediaDescr.Attributes) + } + + if iceGatheringState != ICEGatheringStateComplete { + return nil + } + for _, a := range mediaDescr.Attributes { + if a.Key == "end-of-candidates" { + return nil + } + } + + mediaDescr.WithPropertyAttribute("end-of-candidates") + + return nil +} + +func addDataMediaSection( + descr *sdp.SessionDescription, + shouldAddCandidates bool, + dtlsFingerprints []DTLSFingerprint, + midValue string, + iceParams ICEParameters, + candidates []ICECandidate, + dtlsRole sdp.ConnectionRole, + iceGatheringState ICEGatheringState, + sctpMaxMessageSize uint32, + sctpInit []byte, +) error { + media := (&sdp.MediaDescription{ + MediaName: sdp.MediaName{ + Media: mediaSectionApplication, + Port: sdp.RangedPort{Value: 9}, + Protos: []string{"UDP", "DTLS", "SCTP"}, + Formats: []string{"webrtc-datachannel"}, + }, + ConnectionInformation: &sdp.ConnectionInformation{ + NetworkType: "IN", + AddressType: "IP4", + Address: &sdp.Address{ + Address: "0.0.0.0", + }, + }, + }). + WithValueAttribute(sdp.AttrKeyConnectionSetup, dtlsRole.String()). + WithValueAttribute(sdp.AttrKeyMID, midValue). + WithPropertyAttribute(RTPTransceiverDirectionSendrecv.String()). + WithPropertyAttribute("sctp-port:5000"). + WithValueAttribute("max-message-size", fmt.Sprintf("%d", sctpMaxMessageSize)). + WithICECredentials(iceParams.UsernameFragment, iceParams.Password) + + if len(sctpInit) != 0 { + media = media.WithValueAttribute("sctp-init", base64.StdEncoding.EncodeToString(sctpInit)) + } + for _, f := range dtlsFingerprints { + media = media.WithFingerprint(f.Algorithm, strings.ToUpper(f.Value)) + } + + if shouldAddCandidates { + if err := addCandidatesToMediaDescriptions(candidates, media, iceGatheringState); err != nil { + return err + } + } + + descr.WithMedia(media) + + return nil +} + +func populateLocalCandidates( + sessionDescription *SessionDescription, + i *ICEGatherer, + iceGatheringState ICEGatheringState, +) *SessionDescription { + if sessionDescription == nil || i == nil { + return sessionDescription + } + + candidates, err := i.GetLocalCandidates() + if err != nil { + return sessionDescription + } + + parsed := sessionDescription.parsed + if len(parsed.MediaDescriptions) > 0 { + mediaDescr := parsed.MediaDescriptions[0] + if err = addCandidatesToMediaDescriptions(candidates, mediaDescr, iceGatheringState); err != nil { + return sessionDescription + } + } + + sdp, err := parsed.Marshal() + if err != nil { + return sessionDescription + } + + return &SessionDescription{ + SDP: string(sdp), + Type: sessionDescription.Type, + parsed: parsed, + } +} + +//nolint:gocognit,cyclop +func addSenderSDP( + mediaSection mediaSection, + isPlanB bool, + media *sdp.MediaDescription, +) { + for _, mt := range mediaSection.transceivers { + sender := mt.Sender() + if sender == nil { + continue + } + + track := sender.Track() + if track == nil { + continue + } + + sendParameters := sender.GetParameters() + for _, encoding := range sendParameters.Encodings { + if encoding.RTX.SSRC != 0 { + media = media.WithValueAttribute( + "ssrc-group", + fmt.Sprintf( + "%s %d %d", + sdp.SemanticTokenFlowIdentification, + encoding.SSRC, + encoding.RTX.SSRC, + ), + ) + } + if encoding.FEC.SSRC != 0 { + media = media.WithValueAttribute( + "ssrc-group", + fmt.Sprintf( + "%s %d %d", + sdp.SemanticTokenForwardErrorCorrectionFramework, + encoding.SSRC, + encoding.FEC.SSRC, + ), + ) + } + + media = media.WithMediaSource( + uint32(encoding.SSRC), + track.StreamID(), /* cname */ + track.StreamID(), /* streamLabel */ + track.ID(), + ) + + if !isPlanB { + if encoding.RTX.SSRC != 0 { + media = media.WithMediaSource( + uint32(encoding.RTX.SSRC), + track.StreamID(), /* cname */ + track.StreamID(), /* streamLabel */ + track.ID(), + ) + } + if encoding.FEC.SSRC != 0 { + media = media.WithMediaSource( + uint32(encoding.FEC.SSRC), + track.StreamID(), /* cname */ + track.StreamID(), /* streamLabel */ + track.ID(), + ) + } + + media = media.WithPropertyAttribute("msid:" + track.StreamID() + " " + track.ID()) + } + } + + if len(sendParameters.Encodings) > 1 { + sendRids := make([]string, 0, len(sendParameters.Encodings)) + + for _, encoding := range sendParameters.Encodings { + media.WithValueAttribute(sdpAttributeRid, encoding.RID+" send") + sendRids = append(sendRids, encoding.RID) + } + // Simulcast + media.WithValueAttribute(sdpAttributeSimulcast, "send "+strings.Join(sendRids, ";")) + } + + if !isPlanB { + break + } + } +} + +//nolint:cyclop, gocognit +func addTransceiverSDP( + descr *sdp.SessionDescription, + isPlanB bool, + shouldAddCandidates bool, + dtlsFingerprints []DTLSFingerprint, + mediaEngine *MediaEngine, + midValue string, + iceParams ICEParameters, + candidates []ICECandidate, + dtlsRole sdp.ConnectionRole, + iceGatheringState ICEGatheringState, + mediaSection mediaSection, + ignoreRidPauseForRecv bool, +) (bool, error) { + transceivers := mediaSection.transceivers + if len(transceivers) < 1 { + return false, errSDPZeroTransceivers + } + // Use the first transceiver to generate the section attributes + transceiver := transceivers[0] + media := sdp.NewJSEPMediaDescription(transceiver.kind.String(), []string{}). + WithValueAttribute(sdp.AttrKeyConnectionSetup, dtlsRole.String()). + WithValueAttribute(sdp.AttrKeyMID, midValue). + WithICECredentials(iceParams.UsernameFragment, iceParams.Password). + WithPropertyAttribute(sdp.AttrKeyRTCPMux). + WithPropertyAttribute(sdp.AttrKeyRTCPRsize) + + codecs := transceiver.getCodecs() + for _, codec := range codecs { + name := strings.TrimPrefix(codec.MimeType, "audio/") + name = strings.TrimPrefix(name, "video/") + media.WithCodec(uint8(codec.PayloadType), name, codec.ClockRate, codec.Channels, codec.SDPFmtpLine) + + for _, feedback := range codec.RTPCodecCapability.RTCPFeedback { + if feedback.Parameter == "" { + media.WithValueAttribute("rtcp-fb", fmt.Sprintf("%d %s", codec.PayloadType, feedback.Type)) + } else { + media.WithValueAttribute("rtcp-fb", fmt.Sprintf("%d %s %s", codec.PayloadType, feedback.Type, feedback.Parameter)) + } + } + } + if len(codecs) == 0 { + // If we are sender and we have no codecs throw an error early + if transceiver.Sender() != nil { + return false, ErrSenderWithNoCodecs + } + + // Explicitly reject track if we don't have the codec + // We need to include connection information even if we're rejecting a track, otherwise Firefox will fail to + // parse the SDP with an error like: + // SIPCC Failed to parse SDP: SDP Parse Error on line 50: c= connection line not specified for every media level, + // validation failed. + // In addition this makes our SDP compliant with RFC 4566 Section 5.7: + // https://datatracker.ietf.org/doc/html/rfc4566#section-5.7 + descr.WithMedia(&sdp.MediaDescription{ + MediaName: sdp.MediaName{ + Media: transceiver.kind.String(), + Port: sdp.RangedPort{Value: 0}, + Protos: []string{"UDP", "TLS", "RTP", "SAVPF"}, + Formats: []string{"0"}, + }, + ConnectionInformation: &sdp.ConnectionInformation{ + NetworkType: "IN", + AddressType: "IP4", + Address: &sdp.Address{ + Address: "0.0.0.0", + }, + }, + }) + + return false, nil + } + + directions := []RTPTransceiverDirection{} + if transceiver.Sender() != nil { + directions = append(directions, RTPTransceiverDirectionSendonly) + } + if transceiver.Receiver() != nil { + directions = append(directions, RTPTransceiverDirectionRecvonly) + } + + parameters := mediaEngine.getRTPParametersByKind(transceiver.kind, directions) + for _, rtpExtension := range parameters.HeaderExtensions { + if mediaSection.matchExtensions != nil { + if _, enabled := mediaSection.matchExtensions[rtpExtension.URI]; !enabled { + continue + } + } + extURL, err := url.Parse(rtpExtension.URI) + if err != nil { + return false, err + } + media.WithExtMap(sdp.ExtMap{Value: rtpExtension.ID, URI: extURL}) + } + + if len(mediaSection.rids) > 0 { + recvRids := make([]string, 0, len(mediaSection.rids)) + + for _, rid := range mediaSection.rids { + ridID := rid.id + media.WithValueAttribute(sdpAttributeRid, ridID+" recv") + if rid.paused && !ignoreRidPauseForRecv { + ridID = "~" + ridID + } + recvRids = append(recvRids, ridID) + } + // Simulcast + media.WithValueAttribute(sdpAttributeSimulcast, "recv "+strings.Join(recvRids, ";")) + } + + addSenderSDP(mediaSection, isPlanB, media) + + media = media.WithPropertyAttribute(transceiver.Direction().String()) + + for _, fingerprint := range dtlsFingerprints { + media = media.WithFingerprint(fingerprint.Algorithm, strings.ToUpper(fingerprint.Value)) + } + + if shouldAddCandidates { + if err := addCandidatesToMediaDescriptions(candidates, media, iceGatheringState); err != nil { + return false, err + } + } + + descr.WithMedia(media) + + return true, nil +} + +type simulcastRid struct { + id string + attrValue string + paused bool +} + +type mediaSection struct { + id string + transceivers []*RTPTransceiver + data bool + sctpInit []byte + matchExtensions map[string]int + rids []*simulcastRid +} + +func bundleMatchFromRemote(matchBundleGroup *string) func(mid string) bool { + if matchBundleGroup == nil { + return func(string) bool { + return true + } + } + bundleTags := strings.Split(*matchBundleGroup, " ") + + return func(midValue string) bool { + return slices.Contains(bundleTags, midValue) + } +} + +// populateSDP serializes a PeerConnections state into an SDP. +// +//nolint:cyclop +func populateSDP( + descr *sdp.SessionDescription, + isPlanB bool, + dtlsFingerprints []DTLSFingerprint, + mediaDescriptionFingerprint bool, + isICELite bool, + isExtmapAllowMixed bool, + mediaEngine *MediaEngine, + connectionRole sdp.ConnectionRole, + candidates []ICECandidate, + iceParams ICEParameters, + mediaSections []mediaSection, + iceGatheringState ICEGatheringState, + matchBundleGroup *string, + sctpMaxMessageSize uint32, + ignoreRidPauseForRecv bool, +) (*sdp.SessionDescription, error) { + var err error + mediaDtlsFingerprints := []DTLSFingerprint{} + + if mediaDescriptionFingerprint { + mediaDtlsFingerprints = dtlsFingerprints + } + + bundleValue := "BUNDLE" + bundleCount := 0 + + bundleMatch := bundleMatchFromRemote(matchBundleGroup) + appendBundle := func(midValue string) { + bundleValue += " " + midValue + bundleCount++ + } + + for i, section := range mediaSections { + if section.data && len(section.transceivers) != 0 { + return nil, errSDPMediaSectionMediaDataChanInvalid + } else if !isPlanB && len(section.transceivers) > 1 { + return nil, errSDPMediaSectionMultipleTrackInvalid + } + + shouldAddID := true + shouldAddCandidates := i == 0 + if section.data { + if err = addDataMediaSection( + descr, + shouldAddCandidates, + mediaDtlsFingerprints, + section.id, + iceParams, + candidates, + connectionRole, + iceGatheringState, + sctpMaxMessageSize, + section.sctpInit, + ); err != nil { + return nil, err + } + } else { + shouldAddID, err = addTransceiverSDP( + descr, + isPlanB, + shouldAddCandidates, + mediaDtlsFingerprints, + mediaEngine, + section.id, + iceParams, + candidates, + connectionRole, + iceGatheringState, + section, + ignoreRidPauseForRecv, + ) + if err != nil { + return nil, err + } + } + + if shouldAddID { + if bundleMatch(section.id) { + appendBundle(section.id) + } else { + descr.MediaDescriptions[len(descr.MediaDescriptions)-1].MediaName.Port = sdp.RangedPort{Value: 0} + } + } + } + + if !mediaDescriptionFingerprint { + for _, fingerprint := range dtlsFingerprints { + descr.WithFingerprint(fingerprint.Algorithm, strings.ToUpper(fingerprint.Value)) + } + } + + if isICELite { + // RFC 5245 S15.3 + descr = descr.WithValueAttribute(sdp.AttrKeyICELite, "") + } + + if isExtmapAllowMixed { + descr = descr.WithPropertyAttribute(sdp.AttrKeyExtMapAllowMixed) + } + + if bundleCount > 0 { + descr = descr.WithValueAttribute(sdp.AttrKeyGroup, bundleValue) + } + + return descr, nil +} + +func getMidValue(media *sdp.MediaDescription) string { + for _, attr := range media.Attributes { + if attr.Key == "mid" { + return attr.Value + } + } + + return "" +} + +// SessionDescription contains a MediaSection with Multiple SSRCs, it is Plan-B. +func descriptionIsPlanB(desc *SessionDescription, log logging.LeveledLogger) bool { + if desc == nil || desc.parsed == nil { + return false + } + + // Store all MIDs that already contain a track + midWithTrack := map[string]bool{} + + for _, trackDetail := range trackDetailsFromSDP(log, desc.parsed) { + if _, ok := midWithTrack[trackDetail.mid]; ok { + return true + } + midWithTrack[trackDetail.mid] = true + } + + return false +} + +// SessionDescription contains a MediaSection with name `audio`, `video` or `data` +// If only one SSRC is set we can't know if it is Plan-B or Unified. If users have +// set fallback mode assume it is Plan-B. +func descriptionPossiblyPlanB(desc *SessionDescription) bool { + if desc == nil || desc.parsed == nil { + return false + } + + detectionRegex := regexp.MustCompile(`(?i)^(audio|video|data)$`) + for _, media := range desc.parsed.MediaDescriptions { + if len(detectionRegex.FindStringSubmatch(getMidValue(media))) == 2 { + return true + } + } + + return false +} + +func getPeerDirection(media *sdp.MediaDescription) RTPTransceiverDirection { + for _, a := range media.Attributes { + if direction := NewRTPTransceiverDirection(a.Key); direction != RTPTransceiverDirectionUnknown { + return direction + } + } + + return RTPTransceiverDirectionUnknown +} + +func extractBundleID(desc *sdp.SessionDescription) string { + groupAttribute, _ := desc.Attribute(sdp.AttrKeyGroup) + + isBundled := strings.Contains(groupAttribute, "BUNDLE") + + if !isBundled { + return "" + } + + bundleIDs := strings.Split(groupAttribute, " ") + + if len(bundleIDs) < 2 { + return "" + } + + return bundleIDs[1] +} + +func extractFingerprint(desc *sdp.SessionDescription) (string, string, error) { //nolint:gocognit,cyclop + fingerprint := "" + + // Fingerprint on session level has highest priority + if sessionFingerprint, haveFingerprint := desc.Attribute("fingerprint"); haveFingerprint { + fingerprint = sessionFingerprint + } + + if fingerprint == "" { //nolint:nestif + bundleID := extractBundleID(desc) + if bundleID != "" { + // Locate the fingerprint of the bundled media section + for _, mediaDescr := range desc.MediaDescriptions { + if mid, haveMid := mediaDescr.Attribute("mid"); haveMid { + if mid == bundleID && fingerprint == "" { + if mediaFingerprint, haveFingerprint := mediaDescr.Attribute("fingerprint"); haveFingerprint { + fingerprint = mediaFingerprint + } + } + } + } + } else { + // Take the fingerprint from the first media section which has one. + // Note: According to Bundle spec each media section would have it's own transport + // with it's own cert and fingerprint each, so we would need to return a list. + for _, mediaDescr := range desc.MediaDescriptions { + mediaFingerprint, haveFingerprint := mediaDescr.Attribute("fingerprint") + if haveFingerprint && fingerprint == "" { + fingerprint = mediaFingerprint + } + } + } + } + + if fingerprint == "" { + return "", "", ErrSessionDescriptionNoFingerprint + } + + parts := strings.Split(fingerprint, " ") + if len(parts) != 2 { + return "", "", ErrSessionDescriptionInvalidFingerprint + } + + return parts[1], parts[0], nil +} + +// identifiedMediaDescription contains a MediaDescription with sdpMid and sdpMLineIndex. +type identifiedMediaDescription struct { + MediaDescription *sdp.MediaDescription + SDPMid string + SDPMLineIndex uint16 +} + +func extractICEDetailsFromMedia( //nolint:cyclop + media *identifiedMediaDescription, + log logging.LeveledLogger, +) (string, string, []ICECandidate, error) { + remoteUfrag := "" + remotePwd := "" + candidates := []ICECandidate{} + descr := media.MediaDescription + + if ufrag, haveUfrag := descr.Attribute("ice-ufrag"); haveUfrag { + remoteUfrag = ufrag + } + if pwd, havePwd := descr.Attribute("ice-pwd"); havePwd { + remotePwd = pwd + } + + // track the last error we saw while parsing candidates. + // if we end up with no valid candidates then return prevErr. + var prevErr error + + for _, attr := range descr.Attributes { + if !attr.IsICECandidate() { + continue + } + + cand, err := ice.UnmarshalCandidate(attr.Value) + if err != nil { + // similar to AddICECandidate + if errors.Is(err, ice.ErrUnknownCandidateTyp) || errors.Is(err, ice.ErrDetermineNetworkType) { + if log != nil { + log.Warnf("Discarding remote candidate: %s", err) + } + + continue + } + + if log != nil { + log.Warnf("Failed to parse remote candidate %q: %v", attr.Value, err) + } + + prevErr = err + + continue + } + + candidate, err := newICECandidateFromICE(cand, media.SDPMid, media.SDPMLineIndex) + if err != nil { + if log != nil { + log.Warnf("Failed to convert remote candidate %q: %v", attr.Value, err) + } + + prevErr = err + + continue + } + + candidates = append(candidates, candidate) + } + + // if we saw only invalid candidates then bubble up the last error + // so SetRemoteDescription fails with prevErr. + if len(candidates) == 0 && prevErr != nil { + return "", "", nil, prevErr + } + + return remoteUfrag, remotePwd, candidates, nil +} + +type sdpICEDetails struct { + Ufrag string + Password string //nolint:gosec // not a secret. + Candidates []ICECandidate +} + +func extractICEDetails( + desc *sdp.SessionDescription, + log logging.LeveledLogger, +) (*sdpICEDetails, error) { // nolint:gocognit + details := &sdpICEDetails{ + Candidates: []ICECandidate{}, + } + + // Ufrag and Pw are allow at session level and thus have highest prio + if ufrag, haveUfrag := desc.Attribute("ice-ufrag"); haveUfrag { + details.Ufrag = ufrag + } + if pwd, havePwd := desc.Attribute("ice-pwd"); havePwd { + details.Password = pwd + } + + mediaDescr, ok := selectCandidateMediaSection(desc) + if ok { + ufrag, pwd, candidates, err := extractICEDetailsFromMedia(mediaDescr, log) + if err != nil { + return nil, err + } + + if details.Ufrag == "" && ufrag != "" { + details.Ufrag = ufrag + details.Password = pwd + } + + details.Candidates = candidates + } + + if details.Ufrag == "" { + return nil, ErrSessionDescriptionMissingIceUfrag + } else if details.Password == "" { + return nil, ErrSessionDescriptionMissingIcePwd + } + + return details, nil +} + +// Select the first media section or the first bundle section +// Currently Pion uses the first media section to gather candidates. +// https://github.com/pion/webrtc/pull/2950 +func selectCandidateMediaSection(sessionDescription *sdp.SessionDescription) ( + descr *identifiedMediaDescription, + ok bool, +) { + bundleID := extractBundleID(sessionDescription) + + for mLineIndex, mediaDescr := range sessionDescription.MediaDescriptions { + mid := getMidValue(mediaDescr) + // If bundled, only take ICE detail from bundle master section + if bundleID != "" { + if mid == bundleID { + return &identifiedMediaDescription{ + MediaDescription: mediaDescr, + SDPMid: mid, + SDPMLineIndex: uint16(mLineIndex), //nolint:gosec // G115 + }, true + } + } else { + // For not-bundled, take ICE details from the first media section + return &identifiedMediaDescription{ + MediaDescription: mediaDescr, + SDPMid: mid, + SDPMLineIndex: uint16(mLineIndex), //nolint:gosec // G115 + }, true + } + } + + return nil, false +} + +func getByMid(searchMid string, desc *SessionDescription) *sdp.MediaDescription { + for _, m := range desc.parsed.MediaDescriptions { + if mid, ok := m.Attribute(sdp.AttrKeyMID); ok && mid == searchMid { + return m + } + } + + return nil +} + +// haveDataChannel return MediaDescription with MediaName equal application. +func haveDataChannel(desc *SessionDescription) *sdp.MediaDescription { + for _, d := range desc.parsed.MediaDescriptions { + if d.MediaName.Media == mediaSectionApplication { + return d + } + } + + return nil +} + +func codecsFromMediaDescription(mediaDescr *sdp.MediaDescription) (out []RTPCodecParameters, err error) { + s := &sdp.SessionDescription{ + MediaDescriptions: []*sdp.MediaDescription{mediaDescr}, + } + + for _, payloadStr := range mediaDescr.MediaName.Formats { + payloadType, err := strconv.ParseUint(payloadStr, 10, 8) + if err != nil { + return nil, err + } + + codec, err := s.GetCodecForPayloadType(uint8(payloadType)) + if err != nil { + if payloadType == 0 { + continue + } + + return nil, err + } + + channels := uint16(0) + val, err := strconv.ParseUint(codec.EncodingParameters, 10, 16) + if err == nil { + channels = uint16(val) + } + + feedback := []RTCPFeedback{} + for _, raw := range codec.RTCPFeedback { + split := strings.Split(raw, " ") + entry := RTCPFeedback{Type: split[0]} + if len(split) == 2 { + entry.Parameter = split[1] + } + + feedback = append(feedback, entry) + } + + out = append(out, RTPCodecParameters{ + RTPCodecCapability: RTPCodecCapability{ + mediaDescr.MediaName.Media + "/" + codec.Name, + codec.ClockRate, + channels, + codec.Fmtp, + feedback, + }, + PayloadType: PayloadType(payloadType), + }) + } + + return out, nil +} + +func rtpExtensionsFromMediaDescription(m *sdp.MediaDescription) (map[string]int, error) { + out := map[string]int{} + + for _, a := range m.Attributes { + if a.Key == sdp.AttrKeyExtMap { + e := sdp.ExtMap{} + if err := e.Unmarshal(a.String()); err != nil { + return nil, err + } + + out[e.URI.String()] = e.Value + } + } + + return out, nil +} + +// updateSDPOrigin saves sdp.Origin in PeerConnection when creating 1st local SDP; +// for subsequent calling, it updates Origin for SessionDescription from saved one +// and increments session version by one. +// https://tools.ietf.org/html/draft-ietf-rtcweb-jsep-25#section-5.2.2 +func updateSDPOrigin(origin *sdp.Origin, descr *sdp.SessionDescription) { + if atomic.CompareAndSwapUint64(&origin.SessionVersion, 0, descr.Origin.SessionVersion) { // store + atomic.StoreUint64(&origin.SessionID, descr.Origin.SessionID) + } else { // load + for { // awaiting for saving session id + descr.Origin.SessionID = atomic.LoadUint64(&origin.SessionID) + if descr.Origin.SessionID != 0 { + break + } + } + descr.Origin.SessionVersion = atomic.AddUint64(&origin.SessionVersion, 1) + } +} + +func isIceLiteSet(desc *sdp.SessionDescription) bool { + for _, a := range desc.Attributes { + if strings.TrimSpace(a.Key) == sdp.AttrKeyICELite { + return true + } + } + + return false +} + +func isExtMapAllowMixedSet(desc *sdp.SessionDescription) bool { + for _, a := range desc.Attributes { + if strings.TrimSpace(a.Key) == sdp.AttrKeyExtMapAllowMixed { + return true + } + } + + return false +} + +func getMaxMessageSize(desc *sdp.MediaDescription) uint32 { + for _, a := range desc.Attributes { + if strings.TrimSpace(a.Key) == "max-message-size" { + if v, err := strconv.ParseUint(a.Value, 10, 32); err == nil { + return uint32(v) + } + } + } + + return 0 +} + +func getSctpInit(desc *sdp.MediaDescription) ([]byte, error) { + for _, a := range desc.Attributes { + if strings.TrimSpace(a.Key) == "sctp-init" { + decoded, err := base64.StdEncoding.DecodeString(a.Value) + if err != nil { + return nil, err + } + + return decoded, nil + } + } + + return nil, nil +} diff --git a/vendor/github.com/pion/webrtc/v4/sdpsemantics.go b/vendor/github.com/pion/webrtc/v4/sdpsemantics.go new file mode 100644 index 0000000..a5a8103 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/sdpsemantics.go @@ -0,0 +1,76 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package webrtc + +import ( + "encoding/json" +) + +// SDPSemantics determines which style of SDP offers and answers +// can be used. +type SDPSemantics int + +const ( + // SDPSemanticsUnifiedPlan uses unified-plan offers and answers + // (the default in Chrome since M72) + // https://tools.ietf.org/html/draft-roach-mmusic-unified-plan-00 + SDPSemanticsUnifiedPlan SDPSemantics = iota + + // SDPSemanticsPlanB uses plan-b offers and answers + // NB: This format should be considered deprecated + // https://tools.ietf.org/html/draft-uberti-rtcweb-plan-00 + SDPSemanticsPlanB + + // SDPSemanticsUnifiedPlanWithFallback prefers unified-plan + // offers and answers, but will respond to a plan-b offer + // with a plan-b answer. + SDPSemanticsUnifiedPlanWithFallback +) + +const ( + sdpSemanticsUnifiedPlanWithFallback = "unified-plan-with-fallback" + sdpSemanticsUnifiedPlan = "unified-plan" + sdpSemanticsPlanB = "plan-b" +) + +func newSDPSemantics(raw string) SDPSemantics { + switch raw { + case sdpSemanticsPlanB: + return SDPSemanticsPlanB + case sdpSemanticsUnifiedPlanWithFallback: + return SDPSemanticsUnifiedPlanWithFallback + default: + return SDPSemanticsUnifiedPlan + } +} + +func (s SDPSemantics) String() string { + switch s { + case SDPSemanticsUnifiedPlanWithFallback: + return sdpSemanticsUnifiedPlanWithFallback + case SDPSemanticsUnifiedPlan: + return sdpSemanticsUnifiedPlan + case SDPSemanticsPlanB: + return sdpSemanticsPlanB + default: + return ErrUnknownType.Error() + } +} + +// UnmarshalJSON parses the JSON-encoded data and stores the result. +func (s *SDPSemantics) UnmarshalJSON(b []byte) error { + var val string + if err := json.Unmarshal(b, &val); err != nil { + return err + } + + *s = newSDPSemantics(val) + + return nil +} + +// MarshalJSON returns the JSON encoding. +func (s SDPSemantics) MarshalJSON() ([]byte, error) { + return json.Marshal(s.String()) +} diff --git a/vendor/github.com/pion/webrtc/v4/sdptype.go b/vendor/github.com/pion/webrtc/v4/sdptype.go new file mode 100644 index 0000000..f42b0a8 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/sdptype.go @@ -0,0 +1,105 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package webrtc + +import ( + "encoding/json" + "strings" +) + +// SDPType describes the type of an SessionDescription. +type SDPType int + +const ( + // SDPTypeUnknown is the enum's zero-value. + SDPTypeUnknown SDPType = iota + + // SDPTypeOffer indicates that a description MUST be treated as an SDP offer. + SDPTypeOffer + + // SDPTypePranswer indicates that a description MUST be treated as an + // SDP answer, but not a final answer. A description used as an SDP + // pranswer may be applied as a response to an SDP offer, or an update to + // a previously sent SDP pranswer. + SDPTypePranswer + + // SDPTypeAnswer indicates that a description MUST be treated as an SDP + // final answer, and the offer-answer exchange MUST be considered complete. + // A description used as an SDP answer may be applied as a response to an + // SDP offer or as an update to a previously sent SDP pranswer. + SDPTypeAnswer + + // SDPTypeRollback indicates that a description MUST be treated as + // canceling the current SDP negotiation and moving the SDP offer and + // answer back to what it was in the previous stable state. Note the + // local or remote SDP descriptions in the previous stable state could be + // null if there has not yet been a successful offer-answer negotiation. + SDPTypeRollback +) + +// This is done this way because of a linter. +const ( + sdpTypeOfferStr = "offer" + sdpTypePranswerStr = "pranswer" + sdpTypeAnswerStr = "answer" + sdpTypeRollbackStr = "rollback" +) + +// NewSDPType creates an SDPType from a string. +func NewSDPType(raw string) SDPType { + switch raw { + case sdpTypeOfferStr: + return SDPTypeOffer + case sdpTypePranswerStr: + return SDPTypePranswer + case sdpTypeAnswerStr: + return SDPTypeAnswer + case sdpTypeRollbackStr: + return SDPTypeRollback + default: + return SDPTypeUnknown + } +} + +func (t SDPType) String() string { + switch t { + case SDPTypeOffer: + return sdpTypeOfferStr + case SDPTypePranswer: + return sdpTypePranswerStr + case SDPTypeAnswer: + return sdpTypeAnswerStr + case SDPTypeRollback: + return sdpTypeRollbackStr + default: + return ErrUnknownType.Error() + } +} + +// MarshalJSON enables JSON marshaling of a SDPType. +func (t SDPType) MarshalJSON() ([]byte, error) { + return json.Marshal(t.String()) +} + +// UnmarshalJSON enables JSON unmarshaling of a SDPType. +func (t *SDPType) UnmarshalJSON(b []byte) error { + var s string + if err := json.Unmarshal(b, &s); err != nil { + return err + } + switch strings.ToLower(s) { + default: + return ErrUnknownType + case "offer": + *t = SDPTypeOffer + case "pranswer": + *t = SDPTypePranswer + case "answer": + *t = SDPTypeAnswer + case "rollback": + *t = SDPTypeRollback + } + + return nil +} diff --git a/vendor/github.com/pion/webrtc/v4/sessiondescription.go b/vendor/github.com/pion/webrtc/v4/sessiondescription.go new file mode 100644 index 0000000..1f4d8ab --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/sessiondescription.go @@ -0,0 +1,75 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package webrtc + +import ( + "fmt" + "slices" + "strings" + + "github.com/pion/sdp/v3" +) + +// ICETrickleCapability represents whether the remote endpoint accepts +// trickled ICE candidates. +type ICETrickleCapability int + +const ( + // ICETrickleCapabilityUnknown no remote peer has been established. + ICETrickleCapabilityUnknown ICETrickleCapability = iota + // ICETrickleCapabilitySupported remote peer can accept trickled ICE candidates. + ICETrickleCapabilitySupported + // ICETrickleCapabilitySupported remote peer didn't state that it can accept trickle ICE candidates. + ICETrickleCapabilityUnsupported +) + +// String returns the string representation of ICETrickleCapability. +func (t ICETrickleCapability) String() string { + switch t { + case ICETrickleCapabilitySupported: + return "supported" + case ICETrickleCapabilityUnsupported: + return "unsupported" + default: + return "unknown" + } +} + +// SessionDescription is used to expose local and remote session descriptions. +type SessionDescription struct { + Type SDPType `json:"type"` + SDP string `json:"sdp"` + + // This will never be initialized by callers, internal use only + parsed *sdp.SessionDescription +} + +// Unmarshal is a helper to deserialize the sdp. +func (sd *SessionDescription) Unmarshal() (*sdp.SessionDescription, error) { + sd.parsed = &sdp.SessionDescription{} + err := sd.parsed.UnmarshalString(sd.SDP) + if err != nil { + return nil, fmt.Errorf("%w: %w", ErrSDPUnmarshalling, err) + } + + return sd.parsed, nil +} + +func hasICETrickleOption(desc *sdp.SessionDescription) bool { + if value, ok := desc.Attribute(sdp.AttrKeyICEOptions); ok && hasTrickleOptionValue(value) { + return true + } + + for _, media := range desc.MediaDescriptions { + if value, ok := media.Attribute(sdp.AttrKeyICEOptions); ok && hasTrickleOptionValue(value) { + return true + } + } + + return false +} + +func hasTrickleOptionValue(value string) bool { + return slices.Contains(strings.Fields(value), "trickle") +} diff --git a/vendor/github.com/pion/webrtc/v4/settingengine.go b/vendor/github.com/pion/webrtc/v4/settingengine.go new file mode 100644 index 0000000..dda020f --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/settingengine.go @@ -0,0 +1,724 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +//go:build !js + +package webrtc + +import ( + "context" + "crypto/x509" + "errors" + "io" + "net" + "time" + + "github.com/pion/dtls/v3" + dtlsElliptic "github.com/pion/dtls/v3/pkg/crypto/elliptic" + "github.com/pion/dtls/v3/pkg/protocol/handshake" + "github.com/pion/ice/v4" + "github.com/pion/logging" + "github.com/pion/stun/v3" + "github.com/pion/transport/v4" + "github.com/pion/transport/v4/packetio" + "golang.org/x/net/proxy" +) + +// SettingEngine allows influencing behavior in ways that are not +// supported by the WebRTC API. This allows us to support additional +// use-cases without deviating from the WebRTC API elsewhere. +type SettingEngine struct { + ephemeralUDP struct { + PortMin uint16 + PortMax uint16 + } + detach struct { + DataChannels bool + } + timeout struct { + ICEDisconnectedTimeout *time.Duration + ICEFailedTimeout *time.Duration + ICEKeepaliveInterval *time.Duration + ICEHostAcceptanceMinWait *time.Duration + ICESrflxAcceptanceMinWait *time.Duration + ICEPrflxAcceptanceMinWait *time.Duration + ICERelayAcceptanceMinWait *time.Duration + ICESTUNGatherTimeout *time.Duration + } + renomination renominationSettings + candidates struct { + ICELite bool + ICENetworkTypes []NetworkType + InterfaceFilter func(string) (keep bool) + IPFilter func(net.IP) (keep bool) + RemoteIPFilter func(net.IP) (keep bool) + NAT1To1IPs []string + NAT1To1IPCandidateType ICECandidateType + addressRewriteRules []ice.AddressRewriteRule + MulticastDNSMode ice.MulticastDNSMode + MulticastDNSHostName string + UsernameFragment string + Password string //nolint:gosec // not a secret. + IncludeLoopbackCandidate bool + } + replayProtection struct { + DTLS *uint + SRTP *uint + SRTCP *uint + } + dtls struct { + insecureSkipHelloVerify bool + disableInsecureSkipVerify bool + retransmissionInterval time.Duration + ellipticCurves []dtlsElliptic.Curve + connectContextMaker func() (context.Context, func()) + extendedMasterSecret dtls.ExtendedMasterSecretType + clientAuth *dtls.ClientAuthType + clientCAs *x509.CertPool + rootCAs *x509.CertPool + keyLogWriter io.Writer + cipherSuites []dtls.CipherSuiteID + customCipherSuites func() []dtls.CipherSuite + clientHelloMessageHook func(handshake.MessageClientHello) handshake.Message + serverHelloMessageHook func(handshake.MessageServerHello) handshake.Message + certificateRequestMessageHook func(handshake.MessageCertificateRequest) handshake.Message + supportedProtocols []string + } + sctp struct { + maxReceiveBufferSize uint32 + enableZeroChecksum bool + rtoMax time.Duration + maxMessageSize uint32 + minCwnd uint32 + fastRtxWnd uint32 + cwndCAStep uint32 + enableSnap bool + } + sdpMediaLevelFingerprints bool + answeringDTLSRole DTLSRole + disableCertificateFingerprintVerification bool + disableSRTPReplayProtection bool + disableSRTCPReplayProtection bool + net transport.Net + BufferFactory func(packetType packetio.BufferPacketType, ssrc uint32) io.ReadWriteCloser + LoggerFactory logging.LoggerFactory + iceTCPMux ice.TCPMux + iceUDPMux ice.UDPMux + iceProxyDialer proxy.Dialer + iceDisableActiveTCP bool + iceBindingRequestHandler func(m *stun.Message, local, remote ice.Candidate, pair *ice.CandidatePair) bool //nolint:lll + disableMediaEngineCopy bool + disableMediaEngineMultipleCodecs bool + srtpProtectionProfiles []dtls.SRTPProtectionProfile + receiveMTU uint + iceMaxBindingRequests *uint16 + fireOnTrackBeforeFirstRTP bool + disableCloseByDTLS bool + dataChannelBlockWrite bool + handleUndeclaredSSRCWithoutAnswer bool + ignoreRidPauseForRecv bool +} + +type renominationSettings struct { + enabled bool + generator ice.NominationValueGenerator + automatic bool + automaticInterval *time.Duration + attributeType *uint16 +} + +// NominationValueGenerator generates nomination values for ICE renomination. +type NominationValueGenerator func() uint32 + +func (f NominationValueGenerator) toIce() ice.NominationValueGenerator { + return ice.NominationValueGenerator(f) +} + +// RenominationOption allows configuring ICE renomination behavior. +type RenominationOption func(*renominationSettings) + +// WithRenominationGenerator overrides the default nomination value generator. +func WithRenominationGenerator(generator NominationValueGenerator) RenominationOption { + return func(cfg *renominationSettings) { + cfg.generator = generator.toIce() + } +} + +// WithRenominationInterval sets the interval for automatic renomination checks. +// Passing zero or a negative duration returns an error from SetICERenomination. +func WithRenominationInterval(interval time.Duration) RenominationOption { + return func(cfg *renominationSettings) { + i := interval + cfg.automaticInterval = &i + } +} + +// WithRenominationNominationAttribute overrides the STUN attribute type used for ICE renomination. +// If unset, the underlying ICE agent default is used. +func WithRenominationNominationAttribute(attrType uint16) RenominationOption { + return func(cfg *renominationSettings) { + a := attrType + cfg.attributeType = &a + } +} + +var errInvalidRenominationInterval = errors.New("renomination interval must be greater than zero") + +// SetICERenomination configures ICE renomination using options for generator, scheduling, and attribute type. +// Manual control is not exposed yet. This always enables automatic renomination with the default +// generator unless a custom one is provided. +func (e *SettingEngine) SetICERenomination(options ...RenominationOption) error { + cfg := e.renomination + for _, opt := range options { + if opt != nil { + opt(&cfg) + } + } + + if cfg.automaticInterval != nil && *cfg.automaticInterval <= 0 { + return errInvalidRenominationInterval + } + + if cfg.generator == nil { + cfg.generator = ice.DefaultNominationValueGenerator() + } + + e.renomination.enabled = true + e.renomination.generator = cfg.generator + e.renomination.automatic = true + e.renomination.automaticInterval = cfg.automaticInterval + e.renomination.attributeType = cfg.attributeType + + return nil +} + +func (e *SettingEngine) getSCTPMaxMessageSize() uint32 { + if e.sctp.maxMessageSize != 0 { + return e.sctp.maxMessageSize + } + + return defaultMaxSCTPMessageSize +} + +// getReceiveMTU returns the configured MTU. If SettingEngine's MTU is configured to 0 it returns the default. +func (e *SettingEngine) getReceiveMTU() uint { + if e.receiveMTU != 0 { + return e.receiveMTU + } + + return receiveMTU +} + +// DetachDataChannels enables detaching data channels. When enabled +// data channels have to be detached in the OnOpen callback using the +// DataChannel.Detach method. +func (e *SettingEngine) DetachDataChannels() { + e.detach.DataChannels = true +} + +// EnableDataChannelBlockWrite allows data channels to block on write, +// it only works if DetachDataChannels is enabled. +func (e *SettingEngine) EnableDataChannelBlockWrite(nonblockWrite bool) { + e.dataChannelBlockWrite = nonblockWrite +} + +// SetSRTPProtectionProfiles allows the user to override the default SRTP Protection Profiles +// The default srtp protection profiles are provided by the function `defaultSrtpProtectionProfiles`. +func (e *SettingEngine) SetSRTPProtectionProfiles(profiles ...dtls.SRTPProtectionProfile) { + e.srtpProtectionProfiles = profiles +} + +// SetICETimeouts sets the behavior around ICE Timeouts +// +// disconnectedTimeout: +// +// Duration without network activity before an Agent is considered disconnected. Default is 5 Seconds +// +// failedTimeout: +// +// Duration without network activity before an Agent is considered failed after disconnected. Default is 25 Seconds +// +// keepAliveInterval: +// +// How often the ICE Agent sends extra traffic if there is no activity, if media is flowing no traffic will be sent. +// +// Default is 2 seconds. +func (e *SettingEngine) SetICETimeouts(disconnectedTimeout, failedTimeout, keepAliveInterval time.Duration) { + e.timeout.ICEDisconnectedTimeout = &disconnectedTimeout + e.timeout.ICEFailedTimeout = &failedTimeout + e.timeout.ICEKeepaliveInterval = &keepAliveInterval +} + +// SetHostAcceptanceMinWait sets the ICEHostAcceptanceMinWait. +func (e *SettingEngine) SetHostAcceptanceMinWait(t time.Duration) { + e.timeout.ICEHostAcceptanceMinWait = &t +} + +// SetSrflxAcceptanceMinWait sets the ICESrflxAcceptanceMinWait. +func (e *SettingEngine) SetSrflxAcceptanceMinWait(t time.Duration) { + e.timeout.ICESrflxAcceptanceMinWait = &t +} + +// SetPrflxAcceptanceMinWait sets the ICEPrflxAcceptanceMinWait. +func (e *SettingEngine) SetPrflxAcceptanceMinWait(t time.Duration) { + e.timeout.ICEPrflxAcceptanceMinWait = &t +} + +// SetRelayAcceptanceMinWait sets the ICERelayAcceptanceMinWait. +func (e *SettingEngine) SetRelayAcceptanceMinWait(t time.Duration) { + e.timeout.ICERelayAcceptanceMinWait = &t +} + +// SetSTUNGatherTimeout sets the ICESTUNGatherTimeout. +func (e *SettingEngine) SetSTUNGatherTimeout(t time.Duration) { + e.timeout.ICESTUNGatherTimeout = &t +} + +// SetEphemeralUDPPortRange limits the pool of ephemeral ports that +// ICE UDP connections can allocate from. This affects both host candidates, +// and the local address of server reflexive candidates. +// +// When portMin and portMax are left to the 0 default value, pion/ice candidate +// gatherer replaces them and uses 1 for portMin and 65535 for portMax. +func (e *SettingEngine) SetEphemeralUDPPortRange(portMin, portMax uint16) error { + if portMax < portMin { + return ice.ErrPort + } + + e.ephemeralUDP.PortMin = portMin + e.ephemeralUDP.PortMax = portMax + + return nil +} + +// SetLite configures whether or not the ice agent should be a lite agent. +func (e *SettingEngine) SetLite(lite bool) { + e.candidates.ICELite = lite +} + +// SetNetworkTypes configures what types of candidate networks are supported +// during local and server reflexive gathering. +func (e *SettingEngine) SetNetworkTypes(candidateTypes []NetworkType) { + e.candidates.ICENetworkTypes = candidateTypes +} + +// SetInterfaceFilter sets the filtering functions when gathering ICE candidates +// This can be used to exclude certain network interfaces from ICE. Which may be +// useful if you know a certain interface will never succeed, or if you wish to reduce +// the amount of information you wish to expose to the remote peer. +func (e *SettingEngine) SetInterfaceFilter(filter func(string) (keep bool)) { + e.candidates.InterfaceFilter = filter +} + +// SetIPFilter sets the filtering functions when gathering ICE candidates +// This can be used to exclude certain ip from ICE. Which may be +// useful if you know a certain ip will never succeed, or if you wish to reduce +// the amount of information you wish to expose to the remote peer. +func (e *SettingEngine) SetIPFilter(filter func(net.IP) (keep bool)) { + e.candidates.IPFilter = filter +} + +// SetRemoteIPFilter sets the filtering function for remote candidate IP addresses. +// This can be used to whitelist or blacklist remote candidate IPs before they are +// added to the ICE agent. +func (e *SettingEngine) SetRemoteIPFilter(filter func(net.IP) (keep bool)) { + e.candidates.RemoteIPFilter = filter +} + +// SetNAT1To1IPs sets a list of external IP addresses of 1:1 (D)NAT +// and a candidate type for which the external IP address is used. +// This is useful when you host a server using Pion on an AWS EC2 instance +// which has a private address, behind a 1:1 DNAT with a public IP (e.g. +// Elastic IP). In this case, you can give the public IP address so that +// Pion will use the public IP address in its candidate instead of the private +// IP address. The second argument, candidateType, is used to tell Pion which +// type of candidate should use the given public IP address. +// Two types of candidates are supported: +// +// ICECandidateTypeHost: +// +// The public IP address will be used for the host candidate in the SDP. +// +// ICECandidateTypeSrflx: +// +// A server reflexive candidate with the given public IP address will be added to the SDP. +// +// Please note that if you choose ICECandidateTypeHost, then the private IP address +// won't be advertised with the peer. Also, this option cannot be used along with mDNS. +// +// If you choose ICECandidateTypeSrflx, it simply adds a server reflexive candidate +// with the public IP. The host candidate is still available along with mDNS +// capabilities unaffected. Also, you cannot give STUN server URL at the same time. +// It will result in an error otherwise. +// +// Deprecated: Use SetICEAddressRewriteRules instead. To mirror the legacy +// behavior, supply ICEAddressRewriteRule with External set to ips, AsCandidateType +// set to candidateType, and Mode set to ICEAddressRewriteReplace for host +// candidates or ICEAddressRewriteAppend for server reflexive candidates. +// Or leave Mode unspecified to use the default behavior; +// replace for host candidates and append for server reflexive candidates. +func (e *SettingEngine) SetNAT1To1IPs(ips []string, candidateType ICECandidateType) { + e.candidates.NAT1To1IPs = ips + e.candidates.NAT1To1IPCandidateType = candidateType +} + +// SetICEAddressRewriteRules configures address rewrite rules for candidate publication. +// These rules provide fine-grained control over which local addresses are replaced or +// supplemented with external IPs. +// This replaces the legacy NAT1To1 settings, which will be deprecated in the future. +func (e *SettingEngine) SetICEAddressRewriteRules(rules ...ICEAddressRewriteRule) error { + if len(rules) == 0 { + e.candidates.addressRewriteRules = nil + + return nil + } + + if len(e.candidates.NAT1To1IPs) > 0 { + return errAddressRewriteWithNAT1To1 + } + + converted := make([]ice.AddressRewriteRule, 0, len(rules)) + for _, rule := range rules { + converted = append(converted, rule.toICE()) + } + + e.candidates.addressRewriteRules = converted + + return nil +} + +// SetIncludeLoopbackCandidate enable pion to gather loopback candidates, it is useful +// for some VM have public IP mapped to loopback interface. +func (e *SettingEngine) SetIncludeLoopbackCandidate(include bool) { + e.candidates.IncludeLoopbackCandidate = include +} + +// SetAnsweringDTLSRole sets the DTLS role that is selected when offering +// The DTLS role controls if the WebRTC Client as a client or server. This +// may be useful when interacting with non-compliant clients or debugging issues. +// +// DTLSRoleActive: +// +// Act as DTLS Client, send the ClientHello and starts the handshake +// +// DTLSRolePassive: +// +// Act as DTLS Server, wait for ClientHello +func (e *SettingEngine) SetAnsweringDTLSRole(role DTLSRole) error { + if role != DTLSRoleClient && role != DTLSRoleServer { + return errSettingEngineSetAnsweringDTLSRole + } + + e.answeringDTLSRole = role + + return nil +} + +// SetNet sets the Net instance that is passed to pion/ice +// +// Net is an network interface layer for Pion, allowing users to replace +// Pions network stack with a custom implementation. +func (e *SettingEngine) SetNet(net transport.Net) { + e.net = net +} + +// SetICEMulticastDNSMode controls if pion/ice queries and generates mDNS ICE Candidates. +func (e *SettingEngine) SetICEMulticastDNSMode(multicastDNSMode ice.MulticastDNSMode) { + e.candidates.MulticastDNSMode = multicastDNSMode +} + +// SetMulticastDNSHostName sets a static HostName to be used by pion/ice instead of generating one on startup +// +// This should only be used for a single PeerConnection. +// Having multiple PeerConnections with the same HostName will cause undefined behavior. +func (e *SettingEngine) SetMulticastDNSHostName(hostName string) { + e.candidates.MulticastDNSHostName = hostName +} + +// SetICECredentials sets a staic uFrag/uPwd to be used by pion/ice +// +// This is useful if you want to do signalless WebRTC session, +// or having a reproducible environment with static credentials. +func (e *SettingEngine) SetICECredentials(usernameFragment, password string) { + e.candidates.UsernameFragment = usernameFragment + e.candidates.Password = password +} + +// DisableCertificateFingerprintVerification disables fingerprint verification after DTLS Handshake has finished. +func (e *SettingEngine) DisableCertificateFingerprintVerification(isDisabled bool) { + e.disableCertificateFingerprintVerification = isDisabled +} + +// SetDTLSReplayProtectionWindow sets a replay attack protection window size of DTLS connection. +func (e *SettingEngine) SetDTLSReplayProtectionWindow(n uint) { + e.replayProtection.DTLS = &n +} + +// SetSRTPReplayProtectionWindow sets a replay attack protection window size of SRTP session. +func (e *SettingEngine) SetSRTPReplayProtectionWindow(n uint) { + e.disableSRTPReplayProtection = false + e.replayProtection.SRTP = &n +} + +// SetSRTCPReplayProtectionWindow sets a replay attack protection window size of SRTCP session. +func (e *SettingEngine) SetSRTCPReplayProtectionWindow(n uint) { + e.disableSRTCPReplayProtection = false + e.replayProtection.SRTCP = &n +} + +// DisableSRTPReplayProtection disables SRTP replay protection. +func (e *SettingEngine) DisableSRTPReplayProtection(isDisabled bool) { + e.disableSRTPReplayProtection = isDisabled +} + +// DisableSRTCPReplayProtection disables SRTCP replay protection. +func (e *SettingEngine) DisableSRTCPReplayProtection(isDisabled bool) { + e.disableSRTCPReplayProtection = isDisabled +} + +// SetSDPMediaLevelFingerprints configures the logic for DTLS Fingerprint insertion +// If true, fingerprints will be inserted in the sdp at the fingerprint +// level, instead of the session level. This helps with compatibility with +// some webrtc implementations. +func (e *SettingEngine) SetSDPMediaLevelFingerprints(sdpMediaLevelFingerprints bool) { + e.sdpMediaLevelFingerprints = sdpMediaLevelFingerprints +} + +// SetICETCPMux enables ICE-TCP when set to a non-nil value. Make sure that +// NetworkTypeTCP4 or NetworkTypeTCP6 is enabled as well. +func (e *SettingEngine) SetICETCPMux(tcpMux ice.TCPMux) { + e.iceTCPMux = tcpMux +} + +// SetICEUDPMux allows ICE traffic to come through a single UDP port, drastically +// simplifying deployments where ports will need to be opened/forwarded. +// UDPMux should be started prior to creating PeerConnections. +func (e *SettingEngine) SetICEUDPMux(udpMux ice.UDPMux) { + e.iceUDPMux = udpMux +} + +// SetICEProxyDialer sets the proxy dialer interface based on golang.org/x/net/proxy. +func (e *SettingEngine) SetICEProxyDialer(d proxy.Dialer) { + e.iceProxyDialer = d +} + +// SetICEMaxBindingRequests sets the maximum amount of binding requests +// that can be sent on a candidate before it is considered invalid. +func (e *SettingEngine) SetICEMaxBindingRequests(d uint16) { + e.iceMaxBindingRequests = &d +} + +// DisableActiveTCP disables using active TCP for ICE. Active TCP is enabled by default. +func (e *SettingEngine) DisableActiveTCP(isDisabled bool) { + e.iceDisableActiveTCP = isDisabled +} + +// DisableMediaEngineCopy stops the MediaEngine from being copied. This allows a user to modify +// the MediaEngine after the PeerConnection has been constructed. This is useful if you wish to +// modify codecs after signaling. Make sure not to share MediaEngines between PeerConnections. +func (e *SettingEngine) DisableMediaEngineCopy(isDisabled bool) { + e.disableMediaEngineCopy = isDisabled +} + +// DisableMediaEngineMultipleCodecs disables the MediaEngine negotiating different codecs. +// With the default value multiple media sections in the SDP can each negotiate different +// codecs. This is the new default behvior, because it makes Pion more spec compliant. +// The value of this setting will get copied to every copy of the MediaEngine generated +// for new PeerConnections (assuming DisableMediaEngineCopy is set to false). +// Note: this setting is targeted to be removed in release 4.2.0 (or later). +func (e *SettingEngine) DisableMediaEngineMultipleCodecs(isDisabled bool) { + e.disableMediaEngineMultipleCodecs = isDisabled +} + +// SetReceiveMTU sets the size of read buffer that copies incoming packets. This is optional. +// Leave this 0 for the default receiveMTU. +func (e *SettingEngine) SetReceiveMTU(receiveMTU uint) { + e.receiveMTU = receiveMTU +} + +// SetDTLSRetransmissionInterval sets the retranmission interval for DTLS. +func (e *SettingEngine) SetDTLSRetransmissionInterval(interval time.Duration) { + e.dtls.retransmissionInterval = interval +} + +// SetDTLSInsecureSkipHelloVerify sets the skip HelloVerify flag for DTLS. +// If true and when acting as DTLS server, will allow client to skip hello verify phase and +// receive ServerHello after initial ClientHello. This will mean faster connect times, +// but will have lower DoS attack resistance. +func (e *SettingEngine) SetDTLSInsecureSkipHelloVerify(skip bool) { + e.dtls.insecureSkipHelloVerify = skip +} + +// SetDTLSDisableInsecureSkipVerify sets the disable skip insecure verify flag for DTLS. +// This controls whether a client verifies the server's certificate chain and host name. +func (e *SettingEngine) SetDTLSDisableInsecureSkipVerify(disable bool) { + e.dtls.disableInsecureSkipVerify = disable +} + +// SetDTLSEllipticCurves sets the elliptic curves for DTLS. +func (e *SettingEngine) SetDTLSEllipticCurves(ellipticCurves ...dtlsElliptic.Curve) { + e.dtls.ellipticCurves = ellipticCurves +} + +// SetDTLSConnectContextMaker sets the context used during the DTLS Handshake. +// It can be used to extend or reduce the timeout on the DTLS Handshake. +// If nil, the default dtls.ConnectContextMaker is used. It can be implemented as following. +// +// func ConnectContextMaker() (context.Context, func()) { +// return context.WithTimeout(context.Background(), 30*time.Second) +// } +func (e *SettingEngine) SetDTLSConnectContextMaker(connectContextMaker func() (context.Context, func())) { + e.dtls.connectContextMaker = connectContextMaker +} + +// SetDTLSExtendedMasterSecret sets the extended master secret type for DTLS. +func (e *SettingEngine) SetDTLSExtendedMasterSecret(extendedMasterSecret dtls.ExtendedMasterSecretType) { + e.dtls.extendedMasterSecret = extendedMasterSecret +} + +// SetDTLSClientAuth sets the client auth type for DTLS. +func (e *SettingEngine) SetDTLSClientAuth(clientAuth dtls.ClientAuthType) { + e.dtls.clientAuth = &clientAuth +} + +// SetDTLSClientCAs sets the client CA certificate pool for DTLS certificate verification. +func (e *SettingEngine) SetDTLSClientCAs(clientCAs *x509.CertPool) { + e.dtls.clientCAs = clientCAs +} + +// SetDTLSRootCAs sets the root CA certificate pool for DTLS certificate verification. +func (e *SettingEngine) SetDTLSRootCAs(rootCAs *x509.CertPool) { + e.dtls.rootCAs = rootCAs +} + +// SetDTLSKeyLogWriter sets the destination of the TLS key material for debugging. +// Logging key material compromises security and should only be use for debugging. +func (e *SettingEngine) SetDTLSKeyLogWriter(writer io.Writer) { + e.dtls.keyLogWriter = writer +} + +// SetSCTPMaxReceiveBufferSize sets the maximum receive buffer size. +// Leave this 0 for the default maxReceiveBufferSize. +func (e *SettingEngine) SetSCTPMaxReceiveBufferSize(maxReceiveBufferSize uint32) { + e.sctp.maxReceiveBufferSize = maxReceiveBufferSize +} + +// EnableSCTPZeroChecksum controls the zero checksum feature in SCTP. +// This removes the need to checksum every incoming/outgoing packet and will reduce +// latency and CPU usage. This feature is not backwards compatible so is disabled by default. +func (e *SettingEngine) EnableSCTPZeroChecksum(isEnabled bool) { + e.sctp.enableZeroChecksum = isEnabled +} + +// EnableSctpSnap enables the use of the SCTP SNAP connect optimization. +func (e *SettingEngine) EnableSctpSnap(isEnabled bool) { + e.sctp.enableSnap = isEnabled +} + +// SetSCTPMaxMessageSize sets the largest message we are willing to accept. +// Leave this 0 for the default max message size. +func (e *SettingEngine) SetSCTPMaxMessageSize(maxMessageSize uint32) { + e.sctp.maxMessageSize = maxMessageSize +} + +// SetDTLSCipherSuites allows the user to specify a list of DTLS CipherSuites. +// This allow to control which ciphers implemented by pion/dtls are used during the DTLS handshake. +// It can be used for DTLS connection hardening. +func (e *SettingEngine) SetDTLSCipherSuites(cipherSuites ...dtls.CipherSuiteID) { + e.dtls.cipherSuites = cipherSuites +} + +// SetDTLSCustomerCipherSuites allows the user to specify a list of custom DTLS CipherSuites. +// It allows to use custom/private DTLS CipherSuites in addition to the ones implemented by pion/dtls. +func (e *SettingEngine) SetDTLSCustomerCipherSuites(customCipherSuites func() []dtls.CipherSuite) { + e.dtls.customCipherSuites = customCipherSuites +} + +// SetDTLSClientHelloMessageHook if not nil, is called when a DTLS Client Hello message is sent +// from a client. The returned handshake message replaces the original message. +func (e *SettingEngine) SetDTLSClientHelloMessageHook(hook func(handshake.MessageClientHello) handshake.Message) { + e.dtls.clientHelloMessageHook = hook +} + +// SetDTLSServerHelloMessageHook if not nil, is called when a DTLS Server Hello message is sent +// from a client. The returned handshake message replaces the original message. +func (e *SettingEngine) SetDTLSServerHelloMessageHook(hook func(handshake.MessageServerHello) handshake.Message) { + e.dtls.serverHelloMessageHook = hook +} + +// SetDTLSCertificateRequestMessageHook if not nil, is called when a DTLS Certificate Request message is sent +// from a client. The returned handshake message replaces the original message. +func (e *SettingEngine) SetDTLSCertificateRequestMessageHook( + hook func(handshake.MessageCertificateRequest) handshake.Message, +) { + e.dtls.certificateRequestMessageHook = hook +} + +// SetDTLSSupportedProtocols sets the supported application protocols (ALPN) for the DTLS handshake. +// Note: RFC 8833 defines two application protocols for WebRTC: +// - `webrtc` - mixed media and data communications using SRTP and data channels. +// - `c-webrtc` - WebRTC with a promise to protect media confidentiality. +func (e *SettingEngine) SetDTLSSupportedProtocols(protocols ...string) { + e.dtls.supportedProtocols = protocols +} + +// SetSCTPRTOMax sets the maximum retransmission timeout. +// Leave this 0 for the default timeout. +func (e *SettingEngine) SetSCTPRTOMax(rtoMax time.Duration) { + e.sctp.rtoMax = rtoMax +} + +// SetSCTPMinCwnd sets the minimum congestion window size. The congestion window +// will not be smaller than this value during congestion control. +func (e *SettingEngine) SetSCTPMinCwnd(minCwnd uint32) { + e.sctp.minCwnd = minCwnd +} + +// SetSCTPFastRtxWnd sets the fast retransmission window size. +func (e *SettingEngine) SetSCTPFastRtxWnd(fastRtxWnd uint32) { + e.sctp.fastRtxWnd = fastRtxWnd +} + +// SetSCTPCwndCAStep sets congestion window adjustment step size during congestion avoidance. +func (e *SettingEngine) SetSCTPCwndCAStep(cwndCAStep uint32) { + e.sctp.cwndCAStep = cwndCAStep +} + +// SetICEBindingRequestHandler sets a callback that is fired on a STUN BindingRequest +// This allows users to do things like +// - Log incoming Binding Requests for debugging +// - Implement draft-thatcher-ice-renomination +// - Implement custom CandidatePair switching logic. +func (e *SettingEngine) SetICEBindingRequestHandler( + bindingRequestHandler func(m *stun.Message, local, remote ice.Candidate, pair *ice.CandidatePair) bool, +) { + e.iceBindingRequestHandler = bindingRequestHandler +} + +// SetFireOnTrackBeforeFirstRTP sets if firing the OnTrack event should happen +// before any RTP packets are received. Setting this to true will +// have the Track's Codec and PayloadTypes be initially set to their +// zero values in the OnTrack handler. +// Note: This does not yet affect simulcast tracks. +func (e *SettingEngine) SetFireOnTrackBeforeFirstRTP(fireOnTrackBeforeFirstRTP bool) { + e.fireOnTrackBeforeFirstRTP = fireOnTrackBeforeFirstRTP +} + +// DisableCloseByDTLS sets if the connection should be closed when dtls transport is closed. +// Setting this to true will keep the connection open when dtls transport is closed +// and relies on the ice failed state to detect the connection is interrupted. +func (e *SettingEngine) DisableCloseByDTLS(isEnabled bool) { + e.disableCloseByDTLS = isEnabled +} + +// SetHandleUndeclaredSSRCWithoutAnswer controls if an SDP answer is required for +// processing early media of non-simulcast tracks. +func (e *SettingEngine) SetHandleUndeclaredSSRCWithoutAnswer(handleUndeclaredSSRCWithoutAnswer bool) { + e.handleUndeclaredSSRCWithoutAnswer = handleUndeclaredSSRCWithoutAnswer +} + +// SetIgnoreRidPauseForRecv controls if SDP `a=simulcast:recv` will include the paused attribute of a RID +// (simulcast layer). +func (e *SettingEngine) SetIgnoreRidPauseForRecv(ignoreRidPauseForRecv bool) { + e.ignoreRidPauseForRecv = ignoreRidPauseForRecv +} diff --git a/vendor/github.com/pion/webrtc/v4/settingengine_js.go b/vendor/github.com/pion/webrtc/v4/settingengine_js.go new file mode 100644 index 0000000..3b2394b --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/settingengine_js.go @@ -0,0 +1,23 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +//go:build js && wasm +// +build js,wasm + +package webrtc + +// SettingEngine allows influencing behavior in ways that are not +// supported by the WebRTC API. This allows us to support additional +// use-cases without deviating from the WebRTC API elsewhere. +type SettingEngine struct { + detach struct { + DataChannels bool + } +} + +// DetachDataChannels enables detaching data channels. When enabled +// data channels have to be detached in the OnOpen callback using the +// DataChannel.Detach method. +func (e *SettingEngine) DetachDataChannels() { + e.detach.DataChannels = true +} diff --git a/vendor/github.com/pion/webrtc/v4/signalingstate.go b/vendor/github.com/pion/webrtc/v4/signalingstate.go new file mode 100644 index 0000000..0e9de5f --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/signalingstate.go @@ -0,0 +1,196 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package webrtc + +import ( + "fmt" + "sync/atomic" + + "github.com/pion/webrtc/v4/pkg/rtcerr" +) + +type stateChangeOp int + +const ( + stateChangeOpSetLocal stateChangeOp = iota + 1 + stateChangeOpSetRemote +) + +func (op stateChangeOp) String() string { + switch op { + case stateChangeOpSetLocal: + return "SetLocal" + case stateChangeOpSetRemote: + return "SetRemote" + default: + return "Unknown State Change Operation" + } +} + +// SignalingState indicates the signaling state of the offer/answer process. +type SignalingState int32 + +const ( + // SignalingStateUnknown is the enum's zero-value. + SignalingStateUnknown SignalingState = iota + + // SignalingStateStable indicates there is no offer/answer exchange in + // progress. This is also the initial state, in which case the local and + // remote descriptions are nil. + SignalingStateStable + + // SignalingStateHaveLocalOffer indicates that a local description, of + // type "offer", has been successfully applied. + SignalingStateHaveLocalOffer + + // SignalingStateHaveRemoteOffer indicates that a remote description, of + // type "offer", has been successfully applied. + SignalingStateHaveRemoteOffer + + // SignalingStateHaveLocalPranswer indicates that a remote description + // of type "offer" has been successfully applied and a local description + // of type "pranswer" has been successfully applied. + SignalingStateHaveLocalPranswer + + // SignalingStateHaveRemotePranswer indicates that a local description + // of type "offer" has been successfully applied and a remote description + // of type "pranswer" has been successfully applied. + SignalingStateHaveRemotePranswer + + // SignalingStateClosed indicates The PeerConnection has been closed. + SignalingStateClosed +) + +// This is done this way because of a linter. +const ( + signalingStateStableStr = "stable" + signalingStateHaveLocalOfferStr = "have-local-offer" + signalingStateHaveRemoteOfferStr = "have-remote-offer" + signalingStateHaveLocalPranswerStr = "have-local-pranswer" + signalingStateHaveRemotePranswerStr = "have-remote-pranswer" + signalingStateClosedStr = "closed" +) + +func newSignalingState(raw string) SignalingState { + switch raw { + case signalingStateStableStr: + return SignalingStateStable + case signalingStateHaveLocalOfferStr: + return SignalingStateHaveLocalOffer + case signalingStateHaveRemoteOfferStr: + return SignalingStateHaveRemoteOffer + case signalingStateHaveLocalPranswerStr: + return SignalingStateHaveLocalPranswer + case signalingStateHaveRemotePranswerStr: + return SignalingStateHaveRemotePranswer + case signalingStateClosedStr: + return SignalingStateClosed + default: + return SignalingStateUnknown + } +} + +func (t SignalingState) String() string { + switch t { + case SignalingStateStable: + return signalingStateStableStr + case SignalingStateHaveLocalOffer: + return signalingStateHaveLocalOfferStr + case SignalingStateHaveRemoteOffer: + return signalingStateHaveRemoteOfferStr + case SignalingStateHaveLocalPranswer: + return signalingStateHaveLocalPranswerStr + case SignalingStateHaveRemotePranswer: + return signalingStateHaveRemotePranswerStr + case SignalingStateClosed: + return signalingStateClosedStr + default: + return ErrUnknownType.Error() + } +} + +// Get thread safe read value. +func (t *SignalingState) Get() SignalingState { + return SignalingState(atomic.LoadInt32((*int32)(t))) +} + +// Set thread safe write value. +func (t *SignalingState) Set(state SignalingState) { + atomic.StoreInt32((*int32)(t), int32(state)) +} + +//nolint:gocognit,cyclop +func checkNextSignalingState(cur, next SignalingState, op stateChangeOp, sdpType SDPType) (SignalingState, error) { + // Special case for rollbacks + if sdpType == SDPTypeRollback && cur == SignalingStateStable { + return cur, &rtcerr.InvalidModificationError{ + Err: errSignalingStateCannotRollback, + } + } + + // 4.3.1 valid state transitions + switch cur { // nolint:exhaustive + case SignalingStateStable: + switch op { + case stateChangeOpSetLocal: + // stable->SetLocal(offer)->have-local-offer + if sdpType == SDPTypeOffer && next == SignalingStateHaveLocalOffer { + return next, nil + } + case stateChangeOpSetRemote: + // stable->SetRemote(offer)->have-remote-offer + if sdpType == SDPTypeOffer && next == SignalingStateHaveRemoteOffer { + return next, nil + } + } + case SignalingStateHaveLocalOffer: + if op == stateChangeOpSetRemote { + switch sdpType { // nolint:exhaustive + // have-local-offer->SetRemote(answer)->stable + case SDPTypeAnswer: + if next == SignalingStateStable { + return next, nil + } + // have-local-offer->SetRemote(pranswer)->have-remote-pranswer + case SDPTypePranswer: + if next == SignalingStateHaveRemotePranswer { + return next, nil + } + } + } + case SignalingStateHaveRemotePranswer: + if op == stateChangeOpSetRemote && sdpType == SDPTypeAnswer { + // have-remote-pranswer->SetRemote(answer)->stable + if next == SignalingStateStable { + return next, nil + } + } + case SignalingStateHaveRemoteOffer: + if op == stateChangeOpSetLocal { + switch sdpType { // nolint:exhaustive + // have-remote-offer->SetLocal(answer)->stable + case SDPTypeAnswer: + if next == SignalingStateStable { + return next, nil + } + // have-remote-offer->SetLocal(pranswer)->have-local-pranswer + case SDPTypePranswer: + if next == SignalingStateHaveLocalPranswer { + return next, nil + } + } + } + case SignalingStateHaveLocalPranswer: + if op == stateChangeOpSetLocal && sdpType == SDPTypeAnswer { + // have-local-pranswer->SetLocal(answer)->stable + if next == SignalingStateStable { + return next, nil + } + } + } + + return cur, &rtcerr.InvalidModificationError{ + Err: fmt.Errorf("%w: %s->%s(%s)->%s", errSignalingStateProposedTransitionInvalid, cur, op, sdpType, next), + } +} diff --git a/vendor/github.com/pion/webrtc/v4/srtp_writer_future.go b/vendor/github.com/pion/webrtc/v4/srtp_writer_future.go new file mode 100644 index 0000000..042373b --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/srtp_writer_future.go @@ -0,0 +1,141 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +//go:build !js + +package webrtc + +import ( + "io" + "sync" + "sync/atomic" + "time" + + "github.com/pion/rtp" + "github.com/pion/srtp/v3" +) + +// srtpWriterFuture blocks Read/Write calls until +// the SRTP Session is available. +type srtpWriterFuture struct { + ssrc SSRC + rtpSender *RTPSender + rtcpReadStream atomic.Value // *srtp.ReadStreamSRTCP + rtpWriteStream atomic.Value // *srtp.WriteStreamSRTP + mu sync.Mutex + closed bool +} + +func (s *srtpWriterFuture) init(returnWhenNoSRTP bool) error { //nolint:cyclop + if returnWhenNoSRTP { + select { + case <-s.rtpSender.stopCalled: + return io.ErrClosedPipe + case <-s.rtpSender.transport.srtpReady: + default: + return nil + } + } else { + select { + case <-s.rtpSender.stopCalled: + return io.ErrClosedPipe + case <-s.rtpSender.transport.srtpReady: + } + } + + s.mu.Lock() + defer s.mu.Unlock() + + if s.closed { + return io.ErrClosedPipe + } + + srtcpSession, err := s.rtpSender.transport.getSRTCPSession() + if err != nil { + return err + } + + rtcpReadStream, err := srtcpSession.OpenReadStream(uint32(s.ssrc)) + if err != nil { + return err + } + + srtpSession, err := s.rtpSender.transport.getSRTPSession() + if err != nil { + return err + } + + rtpWriteStream, err := srtpSession.OpenWriteStream() + if err != nil { + return err + } + + s.rtcpReadStream.Store(rtcpReadStream) + s.rtpWriteStream.Store(rtpWriteStream) + + return nil +} + +func (s *srtpWriterFuture) Close() error { + s.mu.Lock() + defer s.mu.Unlock() + + if s.closed { + return nil + } + s.closed = true + + if value, ok := s.rtcpReadStream.Load().(*srtp.ReadStreamSRTCP); ok { + return value.Close() + } + + return nil +} + +func (s *srtpWriterFuture) Read(b []byte) (n int, err error) { + if value, ok := s.rtcpReadStream.Load().(*srtp.ReadStreamSRTCP); ok { + return value.Read(b) + } + + if err := s.init(false); err != nil || s.rtcpReadStream.Load() == nil { + return 0, err + } + + return s.Read(b) +} + +func (s *srtpWriterFuture) SetReadDeadline(t time.Time) error { + if value, ok := s.rtcpReadStream.Load().(*srtp.ReadStreamSRTCP); ok { + return value.SetReadDeadline(t) + } + + if err := s.init(false); err != nil || s.rtcpReadStream.Load() == nil { + return err + } + + return s.SetReadDeadline(t) +} + +func (s *srtpWriterFuture) WriteRTP(header *rtp.Header, payload []byte) (int, error) { + if value, ok := s.rtpWriteStream.Load().(*srtp.WriteStreamSRTP); ok { + return value.WriteRTP(header, payload) + } + + if err := s.init(true); err != nil || s.rtpWriteStream.Load() == nil { + return 0, err + } + + return s.WriteRTP(header, payload) +} + +func (s *srtpWriterFuture) Write(b []byte) (int, error) { + if value, ok := s.rtpWriteStream.Load().(*srtp.WriteStreamSRTP); ok { + return value.Write(b) + } + + if err := s.init(true); err != nil || s.rtpWriteStream.Load() == nil { + return 0, err + } + + return s.Write(b) +} diff --git a/vendor/github.com/pion/webrtc/v4/stats.go b/vendor/github.com/pion/webrtc/v4/stats.go new file mode 100644 index 0000000..042fe97 --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/stats.go @@ -0,0 +1,2448 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package webrtc + +import ( + "encoding/json" + "fmt" + "sync" + "time" + + "github.com/pion/ice/v4" +) + +// A Stats object contains a set of statistics copies out of a monitored component +// of the WebRTC stack at a specific time. +type Stats interface { + statsMarker() +} + +// UnmarshalStatsJSON unmarshals a Stats object from JSON. +func UnmarshalStatsJSON(b []byte) (Stats, error) { //nolint:cyclop + type typeJSON struct { + Type StatsType `json:"type"` + } + typeHolder := typeJSON{} + + err := json.Unmarshal(b, &typeHolder) + if err != nil { + return nil, fmt.Errorf("unmarshal json type: %w", err) + } + + switch typeHolder.Type { + case StatsTypeCodec: + return unmarshalCodecStats(b) + case StatsTypeInboundRTP: + return unmarshalInboundRTPStreamStats(b) + case StatsTypeOutboundRTP: + return unmarshalOutboundRTPStreamStats(b) + case StatsTypeRemoteInboundRTP: + return unmarshalRemoteInboundRTPStreamStats(b) + case StatsTypeRemoteOutboundRTP: + return unmarshalRemoteOutboundRTPStreamStats(b) + case StatsTypeCSRC: + return unmarshalCSRCStats(b) + case StatsTypeMediaSource: + return unmarshalMediaSourceStats(b) + case StatsTypeMediaPlayout: + return unmarshalMediaPlayoutStats(b) + case StatsTypePeerConnection: + return unmarshalPeerConnectionStats(b) + case StatsTypeDataChannel: + return unmarshalDataChannelStats(b) + case StatsTypeStream: + return unmarshalStreamStats(b) + case StatsTypeTrack: + return unmarshalTrackStats(b) + case StatsTypeSender: + return unmarshalSenderStats(b) + case StatsTypeReceiver: + return unmarshalReceiverStats(b) + case StatsTypeTransport: + return unmarshalTransportStats(b) + case StatsTypeCandidatePair: + return unmarshalICECandidatePairStats(b) + case StatsTypeLocalCandidate, StatsTypeRemoteCandidate: + return unmarshalICECandidateStats(b) + case StatsTypeCertificate: + return unmarshalCertificateStats(b) + case StatsTypeSCTPTransport: + return unmarshalSCTPTransportStats(b) + default: + return nil, fmt.Errorf("type: %w", ErrUnknownType) + } +} + +// StatsType indicates the type of the object that a Stats object represents. +type StatsType string + +const ( + // StatsTypeCodec is used by CodecStats. + StatsTypeCodec StatsType = "codec" + + // StatsTypeInboundRTP is used by InboundRTPStreamStats. + StatsTypeInboundRTP StatsType = "inbound-rtp" + + // StatsTypeOutboundRTP is used by OutboundRTPStreamStats. + StatsTypeOutboundRTP StatsType = "outbound-rtp" + + // StatsTypeRemoteInboundRTP is used by RemoteInboundRTPStreamStats. + StatsTypeRemoteInboundRTP StatsType = "remote-inbound-rtp" + + // StatsTypeRemoteOutboundRTP is used by RemoteOutboundRTPStreamStats. + StatsTypeRemoteOutboundRTP StatsType = "remote-outbound-rtp" + + // StatsTypeCSRC is used by RTPContributingSourceStats. + StatsTypeCSRC StatsType = "csrc" + + // StatsTypeMediaSource is used by AudioSourceStats or VideoSourceStats depending on kind. + StatsTypeMediaSource = "media-source" + + // StatsTypeMediaPlayout is used by AudioPlayoutStats. + StatsTypeMediaPlayout StatsType = "media-playout" + + // StatsTypePeerConnection used by PeerConnectionStats. + StatsTypePeerConnection StatsType = "peer-connection" + + // StatsTypeDataChannel is used by DataChannelStats. + StatsTypeDataChannel StatsType = "data-channel" + + // StatsTypeStream is used by MediaStreamStats. + StatsTypeStream StatsType = "stream" + + // StatsTypeTrack is used by SenderVideoTrackAttachmentStats and SenderAudioTrackAttachmentStats depending on kind. + StatsTypeTrack StatsType = "track" + + // StatsTypeSender is used by the AudioSenderStats or VideoSenderStats depending on kind. + StatsTypeSender StatsType = "sender" + + // StatsTypeReceiver is used by the AudioReceiverStats or VideoReceiverStats depending on kind. + StatsTypeReceiver StatsType = "receiver" + + // StatsTypeTransport is used by TransportStats. + StatsTypeTransport StatsType = "transport" + + // StatsTypeCandidatePair is used by ICECandidatePairStats. + StatsTypeCandidatePair StatsType = "candidate-pair" + + // StatsTypeLocalCandidate is used by ICECandidateStats for the local candidate. + StatsTypeLocalCandidate StatsType = "local-candidate" + + // StatsTypeRemoteCandidate is used by ICECandidateStats for the remote candidate. + StatsTypeRemoteCandidate StatsType = "remote-candidate" + + // StatsTypeCertificate is used by CertificateStats. + StatsTypeCertificate StatsType = "certificate" + + // StatsTypeSCTPTransport is used by SCTPTransportStats. + StatsTypeSCTPTransport StatsType = "sctp-transport" +) + +// MediaKind indicates the kind of media (audio or video). +type MediaKind string + +const ( + // MediaKindAudio indicates this is audio stats. + MediaKindAudio MediaKind = "audio" + // MediaKindVideo indicates this is video stats. + MediaKindVideo MediaKind = "video" +) + +// StatsTimestamp is a timestamp represented by the floating point number of +// milliseconds since the epoch. +type StatsTimestamp float64 + +// Time returns the time.Time represented by this timestamp. +func (s StatsTimestamp) Time() time.Time { + millis := float64(s) + nanos := int64(millis * float64(time.Millisecond)) + + return time.Unix(0, nanos).UTC() +} + +func statsTimestampFrom(t time.Time) StatsTimestamp { + return StatsTimestamp(t.UnixNano() / int64(time.Millisecond)) +} + +func statsTimestampNow() StatsTimestamp { + return statsTimestampFrom(time.Now()) +} + +// StatsReport collects Stats objects indexed by their ID. +type StatsReport map[string]Stats + +type statsReportCollector struct { + collectingGroup sync.WaitGroup + report StatsReport + mux sync.Mutex +} + +func newStatsReportCollector() *statsReportCollector { + return &statsReportCollector{report: make(StatsReport)} +} + +func (src *statsReportCollector) Collecting() { + src.collectingGroup.Add(1) +} + +func (src *statsReportCollector) Collect(id string, stats Stats) { + src.mux.Lock() + defer src.mux.Unlock() + + src.report[id] = stats + src.collectingGroup.Done() +} + +func (src *statsReportCollector) Done() { + src.collectingGroup.Done() +} + +func (src *statsReportCollector) Ready() StatsReport { + src.collectingGroup.Wait() + src.mux.Lock() + defer src.mux.Unlock() + + return src.report +} + +// CodecType specifies whether a CodecStats objects represents a media format +// that is being encoded or decoded. +type CodecType string + +const ( + // CodecTypeEncode means the attached CodecStats represents a media format that + // is being encoded, or that the implementation is prepared to encode. + CodecTypeEncode CodecType = "encode" + + // CodecTypeDecode means the attached CodecStats represents a media format + // that the implementation is prepared to decode. + CodecTypeDecode CodecType = "decode" +) + +// CodecStats contains statistics for a codec that is currently being used by RTP streams +// being sent or received by this PeerConnection object. +type CodecStats struct { + // Timestamp is the timestamp associated with this object. + Timestamp StatsTimestamp `json:"timestamp"` + + // Type is the object's StatsType + Type StatsType `json:"type"` + + // ID is a unique id that is associated with the component inspected to produce + // this Stats object. Two Stats objects will have the same ID if they were produced + // by inspecting the same underlying object. + ID string `json:"id"` + + // PayloadType as used in RTP encoding or decoding + PayloadType PayloadType `json:"payloadType"` + + // CodecType of this CodecStats + CodecType CodecType `json:"codecType"` + + // TransportID is the unique identifier of the transport on which this codec is + // being used, which can be used to look up the corresponding TransportStats object. + TransportID string `json:"transportId"` + + // MimeType is the codec MIME media type/subtype. e.g., video/vp8 or equivalent. + MimeType string `json:"mimeType"` + + // ClockRate represents the media sampling rate. + ClockRate uint32 `json:"clockRate"` + + // Channels is 2 for stereo, missing for most other cases. + Channels uint8 `json:"channels"` + + // SDPFmtpLine is the a=fmtp line in the SDP corresponding to the codec, + // i.e., after the colon following the PT. + SDPFmtpLine string `json:"sdpFmtpLine"` + + // Implementation identifies the implementation used. This is useful for diagnosing + // interoperability issues. + Implementation string `json:"implementation"` +} + +func (s CodecStats) statsMarker() {} + +func unmarshalCodecStats(b []byte) (CodecStats, error) { + var codecStats CodecStats + err := json.Unmarshal(b, &codecStats) + if err != nil { + return CodecStats{}, fmt.Errorf("unmarshal codec stats: %w", err) + } + + return codecStats, nil +} + +// InboundRTPStreamStats contains statistics for an inbound RTP stream that is +// currently received with this PeerConnection object. +type InboundRTPStreamStats struct { + // Mid represents a mid value of RTPTransceiver owning this stream, if that value is not + // null. Otherwise, this member is not present. + Mid string `json:"mid"` + + // Rid only exists if a rid has been set for this RTP stream. + // Must not exist for audio. + Rid string `json:"rid,omitempty"` + + // Timestamp is the timestamp associated with this object. + Timestamp StatsTimestamp `json:"timestamp"` + + // Type is the object's StatsType + Type StatsType `json:"type"` + + // ID is a unique id that is associated with the component inspected to produce + // this Stats object. Two Stats objects will have the same ID if they were produced + // by inspecting the same underlying object. + ID string `json:"id"` + + // SSRC is the 32-bit unsigned integer value used to identify the source of the + // stream of RTP packets that this stats object concerns. + SSRC SSRC `json:"ssrc"` + + // Kind is either "audio" or "video" + Kind string `json:"kind"` + + // It is a unique identifier that is associated to the object that was inspected + // to produce the TransportStats associated with this RTP stream. + TransportID string `json:"transportId"` + + // CodecID is a unique identifier that is associated to the object that was inspected + // to produce the CodecStats associated with this RTP stream. + CodecID string `json:"codecId"` + + // FIRCount counts the total number of Full Intra Request (FIR) packets received + // by the sender. This metric is only valid for video and is sent by receiver. + FIRCount uint32 `json:"firCount"` + + // PLICount counts the total number of Picture Loss Indication (PLI) packets + // received by the sender. This metric is only valid for video and is sent by receiver. + PLICount uint32 `json:"pliCount"` + + // TotalProcessingDelay is the sum of the time, in seconds, each audio sample or video frame + // takes from the time the first RTP packet is received (reception timestamp) and to the time + // the corresponding sample or frame is decoded (decoded timestamp). At this point the audio + // sample or video frame is ready for playout by the MediaStreamTrack. Typically ready for + // playout here means after the audio sample or video frame is fully decoded by the decoder. + TotalProcessingDelay float64 `json:"totalProcessingDelay"` + + // NACKCount counts the total number of Negative ACKnowledgement (NACK) packets + // received by the sender and is sent by receiver. + NACKCount uint32 `json:"nackCount"` + + // JitterBufferDelay is the sum of the time, in seconds, each audio sample or a video frame + // takes from the time the first packet is received by the jitter buffer (ingest timestamp) + // to the time it exits the jitter buffer (emit timestamp). The average jitter buffer delay + // can be calculated by dividing the JitterBufferDelay with the JitterBufferEmittedCount. + JitterBufferDelay float64 `json:"jitterBufferDelay"` + + // JitterBufferTargetDelay is increased by the target jitter buffer delay every time a sample is emitted + // by the jitter buffer. The added target is the target delay, in seconds, at the time that + // the sample was emitted from the jitter buffer. To get the average target delay, + // divide by JitterBufferEmittedCount + JitterBufferTargetDelay float64 `json:"jitterBufferTargetDelay"` + + // JitterBufferEmittedCount is the total number of audio samples or video frames that + // have come out of the jitter buffer (increasing jitterBufferDelay). + JitterBufferEmittedCount uint64 `json:"jitterBufferEmittedCount"` + + // JitterBufferMinimumDelay works the same way as jitterBufferTargetDelay, except that + // it is not affected by external mechanisms that increase the jitter buffer target delay, + // such as jitterBufferTarget, AV sync, or any other mechanisms. This metric is purely + // based on the network characteristics such as jitter and packet loss, and can be seen + // as the minimum obtainable jitter buffer delay if no external factors would affect it. + // The metric is updated every time JitterBufferEmittedCount is updated. + JitterBufferMinimumDelay float64 `json:"jitterBufferMinimumDelay"` + + // TotalSamplesReceived is the total number of samples that have been received on + // this RTP stream. This includes concealedSamples. Does not exist for video. + TotalSamplesReceived uint64 `json:"totalSamplesReceived"` + + // ConcealedSamples is the total number of samples that are concealed samples. + // A concealed sample is a sample that was replaced with synthesized samples generated + // locally before being played out. Examples of samples that have to be concealed are + // samples from lost packets (reported in packetsLost) or samples from packets that + // arrive too late to be played out (reported in packetsDiscarded). Does not exist for video. + ConcealedSamples uint64 `json:"concealedSamples"` + + // SilentConcealedSamples is the total number of concealed samples inserted that + // are "silent". Playing out silent samples results in silence or comfort noise. + // This is a subset of concealedSamples. Does not exist for video. + SilentConcealedSamples uint64 `json:"silentConcealedSamples"` + + // ConcealmentEvents increases every time a concealed sample is synthesized after + // a non-concealed sample. That is, multiple consecutive concealed samples will increase + // the concealedSamples count multiple times but is a single concealment event. + // Does not exist for video. + ConcealmentEvents uint64 `json:"concealmentEvents"` + + // InsertedSamplesForDeceleration is increased by the difference between the number of + // samples received and the number of samples played out when playout is slowed down. + // If playout is slowed down by inserting samples, this will be the number of inserted samples. + // Does not exist for video. + InsertedSamplesForDeceleration uint64 `json:"insertedSamplesForDeceleration"` + + // RemovedSamplesForAcceleration is increased by the difference between the number of + // samples received and the number of samples played out when playout is sped up. If speedup + // is achieved by removing samples, this will be the count of samples removed. + // Does not exist for video. + RemovedSamplesForAcceleration uint64 `json:"removedSamplesForAcceleration"` + + // AudioLevel represents the audio level of the receiving track.. + // + // The value is a value between 0..1 (linear), where 1.0 represents 0 dBov, + // 0 represents silence, and 0.5 represents approximately 6 dBSPL change in + // the sound pressure level from 0 dBov. Does not exist for video. + AudioLevel float64 `json:"audioLevel"` + + // TotalAudioEnergy represents the audio energy of the receiving track. It is calculated + // by duration * Math.pow(energy/maxEnergy, 2) for each audio sample received (and thus + // counted by TotalSamplesReceived). Does not exist for video. + TotalAudioEnergy float64 `json:"totalAudioEnergy"` + + // TotalSamplesDuration represents the total duration in seconds of all samples that have been + // received (and thus counted by TotalSamplesReceived). Can be used with totalAudioEnergy to + // compute an average audio level over different intervals. Does not exist for video. + TotalSamplesDuration float64 `json:"totalSamplesDuration"` + + // SLICount counts the total number of Slice Loss Indication (SLI) packets received + // by the sender. This metric is only valid for video and is sent by receiver. + SLICount uint32 `json:"sliCount"` + + // QPSum is the sum of the QP values of frames passed. The count of frames is + // in FramesDecoded for inbound stream stats, and in FramesEncoded for outbound stream stats. + QPSum uint64 `json:"qpSum"` + + // TotalDecodeTime is the total number of seconds that have been spent decoding the FramesDecoded + // frames of this stream. The average decode time can be calculated by dividing this value + // with FramesDecoded. The time it takes to decode one frame is the time passed between + // feeding the decoder a frame and the decoder returning decoded data for that frame. + TotalDecodeTime float64 `json:"totalDecodeTime"` + + // TotalInterFrameDelay is the sum of the interframe delays in seconds between consecutively + // rendered frames, recorded just after a frame has been rendered. The interframe delay variance + // be calculated from TotalInterFrameDelay, TotalSquaredInterFrameDelay, and FramesRendered according + // to the formula: (TotalSquaredInterFrameDelay - TotalInterFrameDelay^2 / FramesRendered) / FramesRendered. + // Does not exist for audio. + TotalInterFrameDelay float64 `json:"totalInterFrameDelay"` + + // TotalSquaredInterFrameDelay is the sum of the squared interframe delays in seconds + // between consecutively rendered frames, recorded just after a frame has been rendered. + // See TotalInterFrameDelay for details on how to calculate the interframe delay variance. + // Does not exist for audio. + TotalSquaredInterFrameDelay float64 `json:"totalSquaredInterFrameDelay"` + + // PacketsReceived is the total number of RTP packets received for this SSRC. + PacketsReceived uint32 `json:"packetsReceived"` + + // PacketsLost is the total number of RTP packets lost for this SSRC. Note that + // because of how this is estimated, it can be negative if more packets are received than sent. + PacketsLost int32 `json:"packetsLost"` + + // Jitter is the packet jitter measured in seconds for this SSRC + Jitter float64 `json:"jitter"` + + // PacketsDiscarded is the cumulative number of RTP packets discarded by the jitter + // buffer due to late or early-arrival, i.e., these packets are not played out. + // RTP packets discarded due to packet duplication are not reported in this metric. + PacketsDiscarded uint32 `json:"packetsDiscarded"` + + // PacketsRepaired is the cumulative number of lost RTP packets repaired after applying + // an error-resilience mechanism. It is measured for the primary source RTP packets + // and only counted for RTP packets that have no further chance of repair. + PacketsRepaired uint32 `json:"packetsRepaired"` + + // BurstPacketsLost is the cumulative number of RTP packets lost during loss bursts. + BurstPacketsLost uint32 `json:"burstPacketsLost"` + + // BurstPacketsDiscarded is the cumulative number of RTP packets discarded during discard bursts. + BurstPacketsDiscarded uint32 `json:"burstPacketsDiscarded"` + + // BurstLossCount is the cumulative number of bursts of lost RTP packets. + BurstLossCount uint32 `json:"burstLossCount"` + + // BurstDiscardCount is the cumulative number of bursts of discarded RTP packets. + BurstDiscardCount uint32 `json:"burstDiscardCount"` + + // BurstLossRate is the fraction of RTP packets lost during bursts to the + // total number of RTP packets expected in the bursts. + BurstLossRate float64 `json:"burstLossRate"` + + // BurstDiscardRate is the fraction of RTP packets discarded during bursts to + // the total number of RTP packets expected in bursts. + BurstDiscardRate float64 `json:"burstDiscardRate"` + + // GapLossRate is the fraction of RTP packets lost during the gap periods. + GapLossRate float64 `json:"gapLossRate"` + + // GapDiscardRate is the fraction of RTP packets discarded during the gap periods. + GapDiscardRate float64 `json:"gapDiscardRate"` + + // TrackID is the identifier of the stats object representing the receiving track, + // a ReceiverAudioTrackAttachmentStats or ReceiverVideoTrackAttachmentStats. + TrackID string `json:"trackId"` + + // ReceiverID is the stats ID used to look up the AudioReceiverStats or VideoReceiverStats + // object receiving this stream. + ReceiverID string `json:"receiverId"` + + // RemoteID is used for looking up the remote RemoteOutboundRTPStreamStats object + // for the same SSRC. + RemoteID string `json:"remoteId"` + + // FramesDecoded represents the total number of frames correctly decoded for this SSRC, + // i.e., frames that would be displayed if no frames are dropped. Only valid for video. + FramesDecoded uint32 `json:"framesDecoded"` + + // KeyFramesDecoded represents the total number of key frames, such as key frames in + // VP8 [RFC6386] or IDR-frames in H.264 [RFC6184], successfully decoded for this RTP + // media stream. This is a subset of FramesDecoded. FramesDecoded - KeyFramesDecoded + // gives you the number of delta frames decoded. Does not exist for audio. + KeyFramesDecoded uint32 `json:"keyFramesDecoded"` + + // FramesRendered represents the total number of frames that have been rendered. + // It is incremented just after a frame has been rendered. Does not exist for audio. + FramesRendered uint32 `json:"framesRendered"` + + // FramesDropped is the total number of frames dropped prior to decode or dropped + // because the frame missed its display deadline for this receiver's track. + // The measurement begins when the receiver is created and is a cumulative metric + // as defined in Appendix A (g) of [RFC7004]. Does not exist for audio. + FramesDropped uint32 `json:"framesDropped"` + + // FrameWidth represents the width of the last decoded frame. Before the first + // frame is decoded this member does not exist. Does not exist for audio. + FrameWidth uint32 `json:"frameWidth"` + + // FrameHeight represents the height of the last decoded frame. Before the first + // frame is decoded this member does not exist. Does not exist for audio. + FrameHeight uint32 `json:"frameHeight"` + + // LastPacketReceivedTimestamp represents the timestamp at which the last packet was + // received for this SSRC. This differs from Timestamp, which represents the time + // at which the statistics were generated by the local endpoint. + LastPacketReceivedTimestamp StatsTimestamp `json:"lastPacketReceivedTimestamp"` + + // HeaderBytesReceived is the total number of RTP header and padding bytes received for this SSRC. + // This includes retransmissions. This does not include the size of transport layer headers such + // as IP or UDP. headerBytesReceived + bytesReceived equals the number of bytes received as + // payload over the transport. + HeaderBytesReceived uint64 `json:"headerBytesReceived"` + + // AverageRTCPInterval is the average RTCP interval between two consecutive compound RTCP packets. + // This is calculated by the sending endpoint when sending compound RTCP reports. + // Compound packets must contain at least a RTCP RR or SR packet and an SDES packet + // with the CNAME item. + AverageRTCPInterval float64 `json:"averageRtcpInterval"` + + // FECPacketsReceived is the total number of RTP FEC packets received for this SSRC. + // This counter can also be incremented when receiving FEC packets in-band with media packets (e.g., with Opus). + FECPacketsReceived uint32 `json:"fecPacketsReceived"` + + // FECPacketsDiscarded is the total number of RTP FEC packets received for this SSRC where the + // error correction payload was discarded by the application. This may happen + // 1. if all the source packets protected by the FEC packet were received or already + // recovered by a separate FEC packet, or + // 2. if the FEC packet arrived late, i.e., outside the recovery window, and the + // lost RTP packets have already been skipped during playout. + // This is a subset of FECPacketsReceived. + FECPacketsDiscarded uint64 `json:"fecPacketsDiscarded"` + + // BytesReceived is the total number of bytes received for this SSRC. + BytesReceived uint64 `json:"bytesReceived"` + + // FramesReceived represents the total number of complete frames received on this RTP stream. + // This metric is incremented when the complete frame is received. Does not exist for audio. + FramesReceived uint32 `json:"framesReceived"` + + // PacketsFailedDecryption is the cumulative number of RTP packets that failed + // to be decrypted. These packets are not counted by PacketsDiscarded. + PacketsFailedDecryption uint32 `json:"packetsFailedDecryption"` + + // PacketsDuplicated is the cumulative number of packets discarded because they + // are duplicated. Duplicate packets are not counted in PacketsDiscarded. + // + // Duplicated packets have the same RTP sequence number and content as a previously + // received packet. If multiple duplicates of a packet are received, all of them are counted. + // An improved estimate of lost packets can be calculated by adding PacketsDuplicated to PacketsLost. + PacketsDuplicated uint32 `json:"packetsDuplicated"` + + // PerDSCPPacketsReceived is the total number of packets received for this SSRC, + // per Differentiated Services code point (DSCP) [RFC2474]. DSCPs are identified + // as decimal integers in string form. Note that due to network remapping and bleaching, + // these numbers are not expected to match the numbers seen on sending. Not all + // OSes make this information available. + PerDSCPPacketsReceived map[string]uint32 `json:"perDscpPacketsReceived"` + + // Identifies the decoder implementation used. This is useful for diagnosing interoperability issues. + // Does not exist for audio. + DecoderImplementation string `json:"decoderImplementation"` + + // PauseCount is the total number of video pauses experienced by this receiver. + // Video is considered to be paused if time passed since last rendered frame exceeds 5 seconds. + // PauseCount is incremented when a frame is rendered after such a pause. Does not exist for audio. + PauseCount uint32 `json:"pauseCount"` + + // TotalPausesDuration is the total duration of pauses (for definition of pause see PauseCount), in seconds. + // Does not exist for audio. + TotalPausesDuration float64 `json:"totalPausesDuration"` + + // FreezeCount is the total number of video freezes experienced by this receiver. + // It is a freeze if frame duration, which is time interval between two consecutively rendered frames, + // is equal or exceeds Max(3 * avg_frame_duration_ms, avg_frame_duration_ms + 150), + // where avg_frame_duration_ms is linear average of durations of last 30 rendered frames. + // Does not exist for audio. + FreezeCount uint32 `json:"freezeCount"` + + // TotalFreezesDuration is the total duration of rendered frames which are considered as frozen + // (for definition of freeze see freezeCount), in seconds. Does not exist for audio. + TotalFreezesDuration float64 `json:"totalFreezesDuration"` + + // PowerEfficientDecoder indicates whether the decoder currently used is considered power efficient + // by the user agent. Does not exist for audio. + PowerEfficientDecoder bool `json:"powerEfficientDecoder"` +} + +func (s InboundRTPStreamStats) statsMarker() {} + +func unmarshalInboundRTPStreamStats(b []byte) (InboundRTPStreamStats, error) { + var inboundRTPStreamStats InboundRTPStreamStats + err := json.Unmarshal(b, &inboundRTPStreamStats) + if err != nil { + return InboundRTPStreamStats{}, fmt.Errorf("unmarshal inbound rtp stream stats: %w", err) + } + + return inboundRTPStreamStats, nil +} + +// QualityLimitationReason lists the reason for limiting the resolution and/or framerate. +// Only valid for video. +type QualityLimitationReason string + +const ( + // QualityLimitationReasonNone means the resolution and/or framerate is not limited. + QualityLimitationReasonNone QualityLimitationReason = "none" + + // QualityLimitationReasonCPU means the resolution and/or framerate is primarily limited due to CPU load. + QualityLimitationReasonCPU QualityLimitationReason = "cpu" + + // QualityLimitationReasonBandwidth means the resolution and/or framerate is primarily limited + // due to congestion cues during bandwidth estimation. + // Typical, congestion control algorithms use inter-arrival time, round-trip time, + // packet or other congestion cues to perform bandwidth estimation. + QualityLimitationReasonBandwidth QualityLimitationReason = "bandwidth" + + // QualityLimitationReasonOther means the resolution and/or framerate is primarily limited + // for a reason other than the above. + QualityLimitationReasonOther QualityLimitationReason = "other" +) + +// OutboundRTPStreamStats contains statistics for an outbound RTP stream that is +// currently sent with this PeerConnection object. +type OutboundRTPStreamStats struct { + // Mid represents a mid value of RTPTransceiver owning this stream, if that value is not + // null. Otherwise, this member is not present. + Mid string `json:"mid"` + + // Rid only exists if a rid has been set for this RTP stream. + // Must not exist for audio. + Rid string `json:"rid"` + + // MediaSourceID is the identifier of the stats object representing the track currently + // attached to the sender of this stream, an RTCMediaSourceStats. + MediaSourceID string `json:"mediaSourceId"` + + // Timestamp is the timestamp associated with this object. + Timestamp StatsTimestamp `json:"timestamp"` + + // Type is the object's StatsType + Type StatsType `json:"type"` + + // ID is a unique id that is associated with the component inspected to produce + // this Stats object. Two Stats objects will have the same ID if they were produced + // by inspecting the same underlying object. + ID string `json:"id"` + + // SSRC is the 32-bit unsigned integer value used to identify the source of the + // stream of RTP packets that this stats object concerns. + SSRC SSRC `json:"ssrc"` + + // Kind is either "audio" or "video" + Kind string `json:"kind"` + + // It is a unique identifier that is associated to the object that was inspected + // to produce the TransportStats associated with this RTP stream. + TransportID string `json:"transportId"` + + // CodecID is a unique identifier that is associated to the object that was inspected + // to produce the CodecStats associated with this RTP stream. + CodecID string `json:"codecId"` + + // HeaderBytesSent is the total number of RTP header and padding bytes sent for this SSRC. This does not + // include the size of transport layer headers such as IP or UDP. + // HeaderBytesSent + BytesSent equals the number of bytes sent as payload over the transport. + HeaderBytesSent uint64 `json:"headerBytesSent"` + + // RetransmittedPacketsSent is the total number of packets that were retransmitted for this SSRC. + // This is a subset of packetsSent. If RTX is not negotiated, retransmitted packets are sent + // over this ssrc. If RTX was negotiated, retransmitted packets are sent over a separate SSRC + // but is still accounted for here. + RetransmittedPacketsSent uint64 `json:"retransmittedPacketsSent"` + + // RetransmittedBytesSent is the total number of bytes that were retransmitted for this SSRC, + // only including payload bytes. This is a subset of bytesSent. If RTX is not negotiated, + // retransmitted bytes are sent over this ssrc. If RTX was negotiated, retransmitted bytes + // are sent over a separate SSRC but is still accounted for here. + RetransmittedBytesSent uint64 `json:"retransmittedBytesSent"` + + // FIRCount counts the total number of Full Intra Request (FIR) packets received + // by the sender. This metric is only valid for video and is sent by receiver. + FIRCount uint32 `json:"firCount"` + + // PLICount counts the total number of Picture Loss Indication (PLI) packets + // received by the sender. This metric is only valid for video and is sent by receiver. + PLICount uint32 `json:"pliCount"` + + // NACKCount counts the total number of Negative ACKnowledgement (NACK) packets + // received by the sender and is sent by receiver. + NACKCount uint32 `json:"nackCount"` + + // SLICount counts the total number of Slice Loss Indication (SLI) packets received + // by the sender. This metric is only valid for video and is sent by receiver. + SLICount uint32 `json:"sliCount"` + + // QPSum is the sum of the QP values of frames passed. The count of frames is + // in FramesDecoded for inbound stream stats, and in FramesEncoded for outbound stream stats. + QPSum uint64 `json:"qpSum"` + + // PacketsSent is the total number of RTP packets sent for this SSRC. + PacketsSent uint32 `json:"packetsSent"` + + // PacketsDiscardedOnSend is the total number of RTP packets for this SSRC that + // have been discarded due to socket errors, i.e. a socket error occurred when handing + // the packets to the socket. This might happen due to various reasons, including + // full buffer or no available memory. + PacketsDiscardedOnSend uint32 `json:"packetsDiscardedOnSend"` + + // FECPacketsSent is the total number of RTP FEC packets sent for this SSRC. + // This counter can also be incremented when sending FEC packets in-band with + // media packets (e.g., with Opus). + FECPacketsSent uint32 `json:"fecPacketsSent"` + + // BytesSent is the total number of bytes sent for this SSRC. + BytesSent uint64 `json:"bytesSent"` + + // BytesDiscardedOnSend is the total number of bytes for this SSRC that have + // been discarded due to socket errors, i.e. a socket error occurred when handing + // the packets containing the bytes to the socket. This might happen due to various + // reasons, including full buffer or no available memory. + BytesDiscardedOnSend uint64 `json:"bytesDiscardedOnSend"` + + // TrackID is the identifier of the stats object representing the current track + // attachment to the sender of this stream, a SenderAudioTrackAttachmentStats + // or SenderVideoTrackAttachmentStats. + TrackID string `json:"trackId"` + + // SenderID is the stats ID used to look up the AudioSenderStats or VideoSenderStats + // object sending this stream. + SenderID string `json:"senderId"` + + // RemoteID is used for looking up the remote RemoteInboundRTPStreamStats object + // for the same SSRC. + RemoteID string `json:"remoteId"` + + // LastPacketSentTimestamp represents the timestamp at which the last packet was + // sent for this SSRC. This differs from timestamp, which represents the time at + // which the statistics were generated by the local endpoint. + LastPacketSentTimestamp StatsTimestamp `json:"lastPacketSentTimestamp"` + + // TargetBitrate is the current target bitrate configured for this particular SSRC + // and is the Transport Independent Application Specific (TIAS) bitrate [RFC3890]. + // Typically, the target bitrate is a configuration parameter provided to the codec's + // encoder and does not count the size of the IP or other transport layers like TCP or UDP. + // It is measured in bits per second and the bitrate is calculated over a 1 second window. + TargetBitrate float64 `json:"targetBitrate"` + + // TotalEncodedBytesTarget is increased by the target frame size in bytes every time + // a frame has been encoded. The actual frame size may be bigger or smaller than this number. + // This value goes up every time framesEncoded goes up. + TotalEncodedBytesTarget uint64 `json:"totalEncodedBytesTarget"` + + // FrameWidth represents the width of the last encoded frame. The resolution of the + // encoded frame may be lower than the media source. Before the first frame is encoded + // this member does not exist. Does not exist for audio. + FrameWidth uint32 `json:"frameWidth"` + + // FrameHeight represents the height of the last encoded frame. The resolution of the + // encoded frame may be lower than the media source. Before the first frame is encoded + // this member does not exist. Does not exist for audio. + FrameHeight uint32 `json:"frameHeight"` + + // FramesPerSecond is the number of encoded frames during the last second. This may be + // lower than the media source frame rate. Does not exist for audio. + FramesPerSecond float64 `json:"framesPerSecond"` + + // FramesSent represents the total number of frames sent on this RTP stream. Does not exist for audio. + FramesSent uint32 `json:"framesSent"` + + // HugeFramesSent represents the total number of huge frames sent by this RTP stream. + // Huge frames, by definition, are frames that have an encoded size at least 2.5 times + // the average size of the frames. The average size of the frames is defined as the + // target bitrate per second divided by the target FPS at the time the frame was encoded. + // These are usually complex to encode frames with a lot of changes in the picture. + // This can be used to estimate, e.g slide changes in the streamed presentation. + // Does not exist for audio. + HugeFramesSent uint32 `json:"hugeFramesSent"` + + // FramesEncoded represents the total number of frames successfully encoded for this RTP media stream. + // Only valid for video. + FramesEncoded uint32 `json:"framesEncoded"` + + // KeyFramesEncoded represents the total number of key frames, such as key frames in VP8 [RFC6386] or + // IDR-frames in H.264 [RFC6184], successfully encoded for this RTP media stream. This is a subset of + // FramesEncoded. FramesEncoded - KeyFramesEncoded gives you the number of delta frames encoded. + // Does not exist for audio. + KeyFramesEncoded uint32 `json:"keyFramesEncoded"` + + // TotalEncodeTime is the total number of seconds that has been spent encoding the + // framesEncoded frames of this stream. The average encode time can be calculated by + // dividing this value with FramesEncoded. The time it takes to encode one frame is the + // time passed between feeding the encoder a frame and the encoder returning encoded data + // for that frame. This does not include any additional time it may take to packetize the resulting data. + TotalEncodeTime float64 `json:"totalEncodeTime"` + + // TotalPacketSendDelay is the total number of seconds that packets have spent buffered + // locally before being transmitted onto the network. The time is measured from when + // a packet is emitted from the RTP packetizer until it is handed over to the OS network socket. + // This measurement is added to totalPacketSendDelay when packetsSent is incremented. + TotalPacketSendDelay float64 `json:"totalPacketSendDelay"` + + // AverageRTCPInterval is the average RTCP interval between two consecutive compound RTCP + // packets. This is calculated by the sending endpoint when sending compound RTCP reports. + // Compound packets must contain at least a RTCP RR or SR packet and an SDES packet with the CNAME item. + AverageRTCPInterval float64 `json:"averageRtcpInterval"` + + // QualityLimitationReason is the current reason for limiting the resolution and/or framerate, + // or "none" if not limited. Only valid for video. + QualityLimitationReason QualityLimitationReason `json:"qualityLimitationReason"` + + // QualityLimitationDurations is record of the total time, in seconds, that this + // stream has spent in each quality limitation state. The record includes a mapping + // for all QualityLimitationReason types, including "none". Only valid for video. + QualityLimitationDurations map[string]float64 `json:"qualityLimitationDurations"` + + // QualityLimitationResolutionChanges is the number of times that the resolution has changed + // because we are quality limited (qualityLimitationReason has a value other than "none"). + // The counter is initially zero and increases when the resolution goes up or down. + // For example, if a 720p track is sent as 480p for some time and then recovers to 720p, + // qualityLimitationResolutionChanges will have the value 2. Does not exist for audio. + QualityLimitationResolutionChanges uint32 `json:"qualityLimitationResolutionChanges"` + + // PerDSCPPacketsSent is the total number of packets sent for this SSRC, per DSCP. + // DSCPs are identified as decimal integers in string form. + PerDSCPPacketsSent map[string]uint32 `json:"perDscpPacketsSent"` + + // Active indicates whether this RTP stream is configured to be sent or disabled. Note that an + // active stream can still not be sending, e.g. when being limited by network conditions. + Active bool `json:"active"` + + // Identifies the encoder implementation used. This is useful for diagnosing interoperability issues. + // Does not exist for audio. + EncoderImplementation string `json:"encoderImplementation"` + + // PowerEfficientEncoder indicates whether the encoder currently used is considered power efficient. + // by the user agent. Does not exist for audio. + PowerEfficientEncoder bool `json:"powerEfficientEncoder"` + + // ScalabilityMode identifies the layering mode used for video encoding. Does not exist for audio. + ScalabilityMode string `json:"scalabilityMode"` +} + +func (s OutboundRTPStreamStats) statsMarker() {} + +func unmarshalOutboundRTPStreamStats(b []byte) (OutboundRTPStreamStats, error) { + var outboundRTPStreamStats OutboundRTPStreamStats + err := json.Unmarshal(b, &outboundRTPStreamStats) + if err != nil { + return OutboundRTPStreamStats{}, fmt.Errorf("unmarshal outbound rtp stream stats: %w", err) + } + + return outboundRTPStreamStats, nil +} + +// RemoteInboundRTPStreamStats contains statistics for the remote endpoint's inbound +// RTP stream corresponding to an outbound stream that is currently sent with this +// PeerConnection object. It is measured at the remote endpoint and reported in an RTCP +// Receiver Report (RR) or RTCP Extended Report (XR). +type RemoteInboundRTPStreamStats struct { + // Timestamp is the timestamp associated with this object. + Timestamp StatsTimestamp `json:"timestamp"` + + // Type is the object's StatsType + Type StatsType `json:"type"` + + // ID is a unique id that is associated with the component inspected to produce + // this Stats object. Two Stats objects will have the same ID if they were produced + // by inspecting the same underlying object. + ID string `json:"id"` + + // SSRC is the 32-bit unsigned integer value used to identify the source of the + // stream of RTP packets that this stats object concerns. + SSRC SSRC `json:"ssrc"` + + // Kind is either "audio" or "video" + Kind string `json:"kind"` + + // It is a unique identifier that is associated to the object that was inspected + // to produce the TransportStats associated with this RTP stream. + TransportID string `json:"transportId"` + + // CodecID is a unique identifier that is associated to the object that was inspected + // to produce the CodecStats associated with this RTP stream. + CodecID string `json:"codecId"` + + // FIRCount counts the total number of Full Intra Request (FIR) packets received + // by the sender. This metric is only valid for video and is sent by receiver. + FIRCount uint32 `json:"firCount"` + + // PLICount counts the total number of Picture Loss Indication (PLI) packets + // received by the sender. This metric is only valid for video and is sent by receiver. + PLICount uint32 `json:"pliCount"` + + // NACKCount counts the total number of Negative ACKnowledgement (NACK) packets + // received by the sender and is sent by receiver. + NACKCount uint32 `json:"nackCount"` + + // SLICount counts the total number of Slice Loss Indication (SLI) packets received + // by the sender. This metric is only valid for video and is sent by receiver. + SLICount uint32 `json:"sliCount"` + + // QPSum is the sum of the QP values of frames passed. The count of frames is + // in FramesDecoded for inbound stream stats, and in FramesEncoded for outbound stream stats. + QPSum uint64 `json:"qpSum"` + + // PacketsReceived is the total number of RTP packets received for this SSRC. + PacketsReceived uint32 `json:"packetsReceived"` + + // PacketsLost is the total number of RTP packets lost for this SSRC. Note that + // because of how this is estimated, it can be negative if more packets are received than sent. + PacketsLost int32 `json:"packetsLost"` + + // Jitter is the packet jitter measured in seconds for this SSRC + Jitter float64 `json:"jitter"` + + // PacketsDiscarded is the cumulative number of RTP packets discarded by the jitter + // buffer due to late or early-arrival, i.e., these packets are not played out. + // RTP packets discarded due to packet duplication are not reported in this metric. + PacketsDiscarded uint32 `json:"packetsDiscarded"` + + // PacketsRepaired is the cumulative number of lost RTP packets repaired after applying + // an error-resilience mechanism. It is measured for the primary source RTP packets + // and only counted for RTP packets that have no further chance of repair. + PacketsRepaired uint32 `json:"packetsRepaired"` + + // BurstPacketsLost is the cumulative number of RTP packets lost during loss bursts. + BurstPacketsLost uint32 `json:"burstPacketsLost"` + + // BurstPacketsDiscarded is the cumulative number of RTP packets discarded during discard bursts. + BurstPacketsDiscarded uint32 `json:"burstPacketsDiscarded"` + + // BurstLossCount is the cumulative number of bursts of lost RTP packets. + BurstLossCount uint32 `json:"burstLossCount"` + + // BurstDiscardCount is the cumulative number of bursts of discarded RTP packets. + BurstDiscardCount uint32 `json:"burstDiscardCount"` + + // BurstLossRate is the fraction of RTP packets lost during bursts to the + // total number of RTP packets expected in the bursts. + BurstLossRate float64 `json:"burstLossRate"` + + // BurstDiscardRate is the fraction of RTP packets discarded during bursts to + // the total number of RTP packets expected in bursts. + BurstDiscardRate float64 `json:"burstDiscardRate"` + + // GapLossRate is the fraction of RTP packets lost during the gap periods. + GapLossRate float64 `json:"gapLossRate"` + + // GapDiscardRate is the fraction of RTP packets discarded during the gap periods. + GapDiscardRate float64 `json:"gapDiscardRate"` + + // LocalID is used for looking up the local OutboundRTPStreamStats object for the same SSRC. + LocalID string `json:"localId"` + + // RoundTripTime is the estimated round trip time for this SSRC based on the + // RTCP timestamps in the RTCP Receiver Report (RR) and measured in seconds. + RoundTripTime float64 `json:"roundTripTime"` + + // TotalRoundTripTime represents the cumulative sum of all round trip time measurements + // in seconds since the beginning of the session. The individual round trip time is calculated + // based on the RTCP timestamps in the RTCP Receiver Report (RR) [RFC3550], hence requires + // a DLSR value other than 0. The average round trip time can be computed from + // TotalRoundTripTime by dividing it by RoundTripTimeMeasurements. + TotalRoundTripTime float64 `json:"totalRoundTripTime"` + + // FractionLost is the fraction packet loss reported for this SSRC. + FractionLost float64 `json:"fractionLost"` + + // RoundTripTimeMeasurements represents the total number of RTCP RR blocks received for this SSRC + // that contain a valid round trip time. This counter will not increment if the RoundTripTime can + // not be calculated because no RTCP Receiver Report with a DLSR value other than 0 has been received. + RoundTripTimeMeasurements uint64 `json:"roundTripTimeMeasurements"` +} + +func (s RemoteInboundRTPStreamStats) statsMarker() {} + +func unmarshalRemoteInboundRTPStreamStats(b []byte) (RemoteInboundRTPStreamStats, error) { + var remoteInboundRTPStreamStats RemoteInboundRTPStreamStats + err := json.Unmarshal(b, &remoteInboundRTPStreamStats) + if err != nil { + return RemoteInboundRTPStreamStats{}, fmt.Errorf("unmarshal remote inbound rtp stream stats: %w", err) + } + + return remoteInboundRTPStreamStats, nil +} + +// RemoteOutboundRTPStreamStats contains statistics for the remote endpoint's outbound +// RTP stream corresponding to an inbound stream that is currently received with this +// PeerConnection object. It is measured at the remote endpoint and reported in an +// RTCP Sender Report (SR). +type RemoteOutboundRTPStreamStats struct { + // Timestamp is the timestamp associated with this object. + Timestamp StatsTimestamp `json:"timestamp"` + + // Type is the object's StatsType + Type StatsType `json:"type"` + + // ID is a unique id that is associated with the component inspected to produce + // this Stats object. Two Stats objects will have the same ID if they were produced + // by inspecting the same underlying object. + ID string `json:"id"` + + // SSRC is the 32-bit unsigned integer value used to identify the source of the + // stream of RTP packets that this stats object concerns. + SSRC SSRC `json:"ssrc"` + + // Kind is either "audio" or "video" + Kind string `json:"kind"` + + // It is a unique identifier that is associated to the object that was inspected + // to produce the TransportStats associated with this RTP stream. + TransportID string `json:"transportId"` + + // CodecID is a unique identifier that is associated to the object that was inspected + // to produce the CodecStats associated with this RTP stream. + CodecID string `json:"codecId"` + + // FIRCount counts the total number of Full Intra Request (FIR) packets received + // by the sender. This metric is only valid for video and is sent by receiver. + FIRCount uint32 `json:"firCount"` + + // PLICount counts the total number of Picture Loss Indication (PLI) packets + // received by the sender. This metric is only valid for video and is sent by receiver. + PLICount uint32 `json:"pliCount"` + + // NACKCount counts the total number of Negative ACKnowledgement (NACK) packets + // received by the sender and is sent by receiver. + NACKCount uint32 `json:"nackCount"` + + // SLICount counts the total number of Slice Loss Indication (SLI) packets received + // by the sender. This metric is only valid for video and is sent by receiver. + SLICount uint32 `json:"sliCount"` + + // QPSum is the sum of the QP values of frames passed. The count of frames is + // in FramesDecoded for inbound stream stats, and in FramesEncoded for outbound stream stats. + QPSum uint64 `json:"qpSum"` + + // PacketsSent is the total number of RTP packets sent for this SSRC. + PacketsSent uint32 `json:"packetsSent"` + + // PacketsDiscardedOnSend is the total number of RTP packets for this SSRC that + // have been discarded due to socket errors, i.e. a socket error occurred when handing + // the packets to the socket. This might happen due to various reasons, including + // full buffer or no available memory. + PacketsDiscardedOnSend uint32 `json:"packetsDiscardedOnSend"` + + // FECPacketsSent is the total number of RTP FEC packets sent for this SSRC. + // This counter can also be incremented when sending FEC packets in-band with + // media packets (e.g., with Opus). + FECPacketsSent uint32 `json:"fecPacketsSent"` + + // BytesSent is the total number of bytes sent for this SSRC. + BytesSent uint64 `json:"bytesSent"` + + // BytesDiscardedOnSend is the total number of bytes for this SSRC that have + // been discarded due to socket errors, i.e. a socket error occurred when handing + // the packets containing the bytes to the socket. This might happen due to various + // reasons, including full buffer or no available memory. + BytesDiscardedOnSend uint64 `json:"bytesDiscardedOnSend"` + + // LocalID is used for looking up the local InboundRTPStreamStats object for the same SSRC. + LocalID string `json:"localId"` + + // RemoteTimestamp represents the remote timestamp at which these statistics were + // sent by the remote endpoint. This differs from timestamp, which represents the + // time at which the statistics were generated or received by the local endpoint. + // The RemoteTimestamp, if present, is derived from the NTP timestamp in an RTCP + // Sender Report (SR) packet, which reflects the remote endpoint's clock. + // That clock may not be synchronized with the local clock. + RemoteTimestamp StatsTimestamp `json:"remoteTimestamp"` + + // ReportsSent represents the total number of RTCP Sender Report (SR) blocks sent for this SSRC. + ReportsSent uint64 `json:"reportsSent"` + + // RoundTripTime is estimated round trip time for this SSRC based on the latest + // RTCP Sender Report (SR) that contains a DLRR report block as defined in [RFC3611]. + // The Calculation of the round trip time is defined in section 4.5. of [RFC3611]. + // Does not exist if the latest SR does not contain the DLRR report block, or if the last RR timestamp + // in the DLRR report block is zero, or if the delay since last RR value in the DLRR report block is zero. + RoundTripTime float64 `json:"roundTripTime"` + + // TotalRoundTripTime represents the cumulative sum of all round trip time measurements in seconds + // since the beginning of the session. The individual round trip time is calculated based on the DLRR + // report block in the RTCP Sender Report (SR) [RFC3611]. This counter will not increment if the + // RoundTripTime can not be calculated. The average round trip time can be computed from + // TotalRoundTripTime by dividing it by RoundTripTimeMeasurements. + TotalRoundTripTime float64 `json:"totalRoundTripTime"` + + // RoundTripTimeMeasurements represents the total number of RTCP Sender Report (SR) blocks + // received for this SSRC that contain a DLRR report block that can derive a valid round trip time + // according to [RFC3611]. This counter will not increment if the RoundTripTime can not be calculated. + RoundTripTimeMeasurements uint64 `json:"roundTripTimeMeasurements"` +} + +func (s RemoteOutboundRTPStreamStats) statsMarker() {} + +func unmarshalRemoteOutboundRTPStreamStats(b []byte) (RemoteOutboundRTPStreamStats, error) { + var remoteOutboundRTPStreamStats RemoteOutboundRTPStreamStats + err := json.Unmarshal(b, &remoteOutboundRTPStreamStats) + if err != nil { + return RemoteOutboundRTPStreamStats{}, fmt.Errorf("unmarshal remote outbound rtp stream stats: %w", err) + } + + return remoteOutboundRTPStreamStats, nil +} + +// RTPContributingSourceStats contains statistics for a contributing source (CSRC) that contributed +// to an inbound RTP stream. +type RTPContributingSourceStats struct { + // Timestamp is the timestamp associated with this object. + Timestamp StatsTimestamp `json:"timestamp"` + + // Type is the object's StatsType + Type StatsType `json:"type"` + + // ID is a unique id that is associated with the component inspected to produce + // this Stats object. Two Stats objects will have the same ID if they were produced + // by inspecting the same underlying object. + ID string `json:"id"` + + // ContributorSSRC is the SSRC identifier of the contributing source represented + // by this stats object. It is a 32-bit unsigned integer that appears in the CSRC + // list of any packets the relevant source contributed to. + ContributorSSRC SSRC `json:"contributorSsrc"` + + // InboundRTPStreamID is the ID of the InboundRTPStreamStats object representing + // the inbound RTP stream that this contributing source is contributing to. + InboundRTPStreamID string `json:"inboundRtpStreamId"` + + // PacketsContributedTo is the total number of RTP packets that this contributing + // source contributed to. This value is incremented each time a packet is counted + // by InboundRTPStreamStats.packetsReceived, and the packet's CSRC list contains + // the SSRC identifier of this contributing source, ContributorSSRC. + PacketsContributedTo uint32 `json:"packetsContributedTo"` + + // AudioLevel is present if the last received RTP packet that this source contributed + // to contained an [RFC6465] mixer-to-client audio level header extension. The value + // of audioLevel is between 0..1 (linear), where 1.0 represents 0 dBov, 0 represents + // silence, and 0.5 represents approximately 6 dBSPL change in the sound pressure level from 0 dBov. + AudioLevel float64 `json:"audioLevel"` +} + +func (s RTPContributingSourceStats) statsMarker() {} + +func unmarshalCSRCStats(b []byte) (RTPContributingSourceStats, error) { + var csrcStats RTPContributingSourceStats + err := json.Unmarshal(b, &csrcStats) + if err != nil { + return RTPContributingSourceStats{}, fmt.Errorf("unmarshal csrc stats: %w", err) + } + + return csrcStats, nil +} + +// AudioSourceStats represents an audio track that is attached to one or more senders. +type AudioSourceStats struct { + // Timestamp is the timestamp associated with this object. + Timestamp StatsTimestamp `json:"timestamp"` + + // Type is the object's StatsType + Type StatsType `json:"type"` + + // ID is a unique id that is associated with the component inspected to produce + // this Stats object. Two Stats objects will have the same ID if they were produced + // by inspecting the same underlying object. + ID string `json:"id"` + + // TrackIdentifier represents the id property of the track. + TrackIdentifier string `json:"trackIdentifier"` + + // Kind is "audio" + Kind string `json:"kind"` + + // AudioLevel represents the output audio level of the track. + // + // The value is a value between 0..1 (linear), where 1.0 represents 0 dBov, + // 0 represents silence, and 0.5 represents approximately 6 dBSPL change in + // the sound pressure level from 0 dBov. + // + // If the track is sourced from an Receiver, does no audio processing, has a + // constant level, and has a volume setting of 1.0, the audio level is expected + // to be the same as the audio level of the source SSRC, while if the volume setting + // is 0.5, the AudioLevel is expected to be half that value. + AudioLevel float64 `json:"audioLevel"` + + // TotalAudioEnergy is the total energy of all the audio samples sent/received + // for this object, calculated by duration * Math.pow(energy/maxEnergy, 2) for + // each audio sample seen. + TotalAudioEnergy float64 `json:"totalAudioEnergy"` + + // TotalSamplesDuration represents the total duration in seconds of all samples + // that have sent or received (and thus counted by TotalSamplesSent or TotalSamplesReceived). + // Can be used with TotalAudioEnergy to compute an average audio level over different intervals. + TotalSamplesDuration float64 `json:"totalSamplesDuration"` + + // EchoReturnLoss is only present while the sender is sending a track sourced from + // a microphone where echo cancellation is applied. Calculated in decibels. + EchoReturnLoss float64 `json:"echoReturnLoss"` + + // EchoReturnLossEnhancement is only present while the sender is sending a track + // sourced from a microphone where echo cancellation is applied. Calculated in decibels. + EchoReturnLossEnhancement float64 `json:"echoReturnLossEnhancement"` + + // DroppedSamplesDuration represents the total duration, in seconds, of samples produced by the device that got + // dropped before reaching the media source. Only applicable if this media source is backed by an audio capture device. + DroppedSamplesDuration float64 `json:"droppedSamplesDuration"` + + // DroppedSamplesEvents is the number of dropped samples events. This counter increases every time a sample is + // dropped after a non-dropped sample. That is, multiple consecutive dropped samples will increase + // droppedSamplesDuration multiple times but is a single dropped samples event. + DroppedSamplesEvents uint64 `json:"droppedSamplesEvents"` + + // TotalCaptureDelay is the total delay, in seconds, for each audio sample between the time the sample was emitted + // by the capture device and the sample reaching the source. This can be used together with totalSamplesCaptured to + // calculate the average capture delay per sample. + // Only applicable if the audio source represents an audio capture device. + TotalCaptureDelay float64 `json:"totalCaptureDelay"` + + // TotalSamplesCaptured is the total number of captured samples reaching the audio source, i.e. that were not dropped + // by the capture pipeline. The frequency of the media source is not necessarily the same as the frequency of encoders + // later in the pipeline. Only applicable if the audio source represents an audio capture device. + TotalSamplesCaptured uint64 `json:"totalSamplesCaptured"` +} + +func (s AudioSourceStats) statsMarker() {} + +// VideoSourceStats represents a video track that is attached to one or more senders. +type VideoSourceStats struct { + // Timestamp is the timestamp associated with this object. + Timestamp StatsTimestamp `json:"timestamp"` + + // Type is the object's StatsType + Type StatsType `json:"type"` + + // ID is a unique id that is associated with the component inspected to produce + // this Stats object. Two Stats objects will have the same ID if they were produced + // by inspecting the same underlying object. + ID string `json:"id"` + + // TrackIdentifier represents the id property of the track. + TrackIdentifier string `json:"trackIdentifier"` + + // Kind is "video" + Kind string `json:"kind"` + + // Width is width of the last frame originating from this source in pixels. + Width uint32 `json:"width"` + + // Height is height of the last frame originating from this source in pixels. + Height uint32 `json:"height"` + + // Frames is the total number of frames originating from this source. + Frames uint32 `json:"frames"` + + // FramesPerSecond is the number of frames originating from this source, measured during the last second. + FramesPerSecond float64 `json:"framesPerSecond"` +} + +func (s VideoSourceStats) statsMarker() {} + +func unmarshalMediaSourceStats(b []byte) (Stats, error) { + type kindJSON struct { + Kind string `json:"kind"` + } + kindHolder := kindJSON{} + + err := json.Unmarshal(b, &kindHolder) + if err != nil { + return nil, fmt.Errorf("unmarshal json kind: %w", err) + } + + switch MediaKind(kindHolder.Kind) { + case MediaKindAudio: + var mediaSourceStats AudioSourceStats + err := json.Unmarshal(b, &mediaSourceStats) + if err != nil { + return nil, fmt.Errorf("unmarshal audio source stats: %w", err) + } + + return mediaSourceStats, nil + case MediaKindVideo: + var mediaSourceStats VideoSourceStats + err := json.Unmarshal(b, &mediaSourceStats) + if err != nil { + return nil, fmt.Errorf("unmarshal video source stats: %w", err) + } + + return mediaSourceStats, nil + default: + return nil, fmt.Errorf("kind: %w", ErrUnknownType) + } +} + +// AudioPlayoutStats represents one playout path - if the same playout stats object is referenced by multiple +// RTCInboundRtpStreamStats this is an indication that audio mixing is happening in which case sample counters in this +// stats object refer to the samples after mixing. Only applicable if the playout path represents an audio device. +type AudioPlayoutStats struct { + // Timestamp is the timestamp associated with this object. + Timestamp StatsTimestamp `json:"timestamp"` + + // Type is the object's StatsType + Type StatsType `json:"type"` + + // ID is a unique id that is associated with the component inspected to produce + // this Stats object. Two Stats objects will have the same ID if they were produced + // by inspecting the same underlying object. + ID string `json:"id"` + + // Kind is "audio" + Kind string `json:"kind"` + + // SynthesizedSamplesDuration is measured in seconds and is incremented each time an audio sample is synthesized by + // this playout path. This metric can be used together with totalSamplesDuration to calculate the percentage of played + // out media being synthesized. If the playout path is unable to produce audio samples on time for device playout, + // samples are synthesized to be played out instead. Synthesization typically only happens if the pipeline is + // underperforming. Samples synthesized by the RTCInboundRtpStreamStats are not counted for here, but in + // InboundRtpStreamStats.concealedSamples. + SynthesizedSamplesDuration float64 `json:"synthesizedSamplesDuration"` + + // SynthesizedSamplesEvents is the number of synthesized samples events. This counter increases every time a sample + // is synthesized after a non-synthesized sample. That is, multiple consecutive synthesized samples will increase + // synthesizedSamplesDuration multiple times but is a single synthesization samples event. + SynthesizedSamplesEvents uint64 `json:"synthesizedSamplesEvents"` + + // TotalSamplesDuration represents the total duration in seconds of all samples + // that have sent or received (and thus counted by TotalSamplesSent or TotalSamplesReceived). + // Can be used with TotalAudioEnergy to compute an average audio level over different intervals. + TotalSamplesDuration float64 `json:"totalSamplesDuration"` + + // When audio samples are pulled by the playout device, this counter is incremented with the estimated delay of the + // playout path for that audio sample. The playout delay includes the delay from being emitted to the actual time of + // playout on the device. This metric can be used together with totalSamplesCount to calculate the average + // playout delay per sample. + TotalPlayoutDelay float64 `json:"totalPlayoutDelay"` + + // When audio samples are pulled by the playout device, this counter is incremented with the number of samples + // emitted for playout. + TotalSamplesCount uint64 `json:"totalSamplesCount"` +} + +func (s AudioPlayoutStats) statsMarker() {} + +func unmarshalMediaPlayoutStats(b []byte) (Stats, error) { + var audioPlayoutStats AudioPlayoutStats + err := json.Unmarshal(b, &audioPlayoutStats) + if err != nil { + return nil, fmt.Errorf("unmarshal audio playout stats: %w", err) + } + + return audioPlayoutStats, nil +} + +// PeerConnectionStats contains statistics related to the PeerConnection object. +type PeerConnectionStats struct { + // Timestamp is the timestamp associated with this object. + Timestamp StatsTimestamp `json:"timestamp"` + + // Type is the object's StatsType + Type StatsType `json:"type"` + + // ID is a unique id that is associated with the component inspected to produce + // this Stats object. Two Stats objects will have the same ID if they were produced + // by inspecting the same underlying object. + ID string `json:"id"` + + // DataChannelsOpened represents the number of unique DataChannels that have + // entered the "open" state during their lifetime. + DataChannelsOpened uint32 `json:"dataChannelsOpened"` + + // DataChannelsClosed represents the number of unique DataChannels that have + // left the "open" state during their lifetime (due to being closed by either + // end or the underlying transport being closed). DataChannels that transition + // from "connecting" to "closing" or "closed" without ever being "open" + // are not counted in this number. + DataChannelsClosed uint32 `json:"dataChannelsClosed"` + + // DataChannelsRequested Represents the number of unique DataChannels returned + // from a successful createDataChannel() call on the PeerConnection. If the + // underlying data transport is not established, these may be in the "connecting" state. + DataChannelsRequested uint32 `json:"dataChannelsRequested"` + + // DataChannelsAccepted represents the number of unique DataChannels signaled + // in a "datachannel" event on the PeerConnection. + DataChannelsAccepted uint32 `json:"dataChannelsAccepted"` +} + +func (s PeerConnectionStats) statsMarker() {} + +func unmarshalPeerConnectionStats(b []byte) (PeerConnectionStats, error) { + var pcStats PeerConnectionStats + err := json.Unmarshal(b, &pcStats) + if err != nil { + return PeerConnectionStats{}, fmt.Errorf("unmarshal pc stats: %w", err) + } + + return pcStats, nil +} + +// DataChannelStats contains statistics related to each DataChannel ID. +type DataChannelStats struct { + // Timestamp is the timestamp associated with this object. + Timestamp StatsTimestamp `json:"timestamp"` + + // Type is the object's StatsType + Type StatsType `json:"type"` + + // ID is a unique id that is associated with the component inspected to produce + // this Stats object. Two Stats objects will have the same ID if they were produced + // by inspecting the same underlying object. + ID string `json:"id"` + + // Label is the "label" value of the DataChannel object. + Label string `json:"label"` + + // Protocol is the "protocol" value of the DataChannel object. + Protocol string `json:"protocol"` + + // DataChannelIdentifier is the "id" attribute of the DataChannel object. + DataChannelIdentifier int32 `json:"dataChannelIdentifier"` + + // TransportID the ID of the TransportStats object for transport used to carry this datachannel. + TransportID string `json:"transportId"` + + // State is the "readyState" value of the DataChannel object. + State DataChannelState `json:"state"` + + // MessagesSent represents the total number of API "message" events sent. + MessagesSent uint32 `json:"messagesSent"` + + // BytesSent represents the total number of payload bytes sent on this + // datachannel not including headers or padding. + BytesSent uint64 `json:"bytesSent"` + + // MessagesReceived represents the total number of API "message" events received. + MessagesReceived uint32 `json:"messagesReceived"` + + // BytesReceived represents the total number of bytes received on this + // datachannel not including headers or padding. + BytesReceived uint64 `json:"bytesReceived"` +} + +func (s DataChannelStats) statsMarker() {} + +func unmarshalDataChannelStats(b []byte) (DataChannelStats, error) { + var dataChannelStats DataChannelStats + err := json.Unmarshal(b, &dataChannelStats) + if err != nil { + return DataChannelStats{}, fmt.Errorf("unmarshal data channel stats: %w", err) + } + + return dataChannelStats, nil +} + +// MediaStreamStats contains statistics related to a specific MediaStream. +type MediaStreamStats struct { + // Timestamp is the timestamp associated with this object. + Timestamp StatsTimestamp `json:"timestamp"` + + // Type is the object's StatsType + Type StatsType `json:"type"` + + // ID is a unique id that is associated with the component inspected to produce + // this Stats object. Two Stats objects will have the same ID if they were produced + // by inspecting the same underlying object. + ID string `json:"id"` + + // StreamIdentifier is the "id" property of the MediaStream + StreamIdentifier string `json:"streamIdentifier"` + + // TrackIDs is a list of the identifiers of the stats object representing the + // stream's tracks, either ReceiverAudioTrackAttachmentStats or ReceiverVideoTrackAttachmentStats. + TrackIDs []string `json:"trackIds"` +} + +func (s MediaStreamStats) statsMarker() {} + +func unmarshalStreamStats(b []byte) (MediaStreamStats, error) { + var streamStats MediaStreamStats + err := json.Unmarshal(b, &streamStats) + if err != nil { + return MediaStreamStats{}, fmt.Errorf("unmarshal stream stats: %w", err) + } + + return streamStats, nil +} + +// AudioSenderStats represents the stats about one audio sender of a PeerConnection +// object for which one calls GetStats. +// +// It appears in the stats as soon as the RTPSender is added by either AddTrack +// or AddTransceiver, or by media negotiation. +type AudioSenderStats struct { + // Timestamp is the timestamp associated with this object. + Timestamp StatsTimestamp `json:"timestamp"` + + // Type is the object's StatsType + Type StatsType `json:"type"` + + // ID is a unique id that is associated with the component inspected to produce + // this Stats object. Two Stats objects will have the same ID if they were produced + // by inspecting the same underlying object. + ID string `json:"id"` + + // TrackIdentifier represents the id property of the track. + TrackIdentifier string `json:"trackIdentifier"` + + // RemoteSource is true if the source is remote, for instance if it is sourced + // from another host via a PeerConnection. False otherwise. Only applicable for 'track' stats. + RemoteSource bool `json:"remoteSource"` + + // Ended reflects the "ended" state of the track. + Ended bool `json:"ended"` + + // Kind is "audio" + Kind string `json:"kind"` + + // AudioLevel represents the output audio level of the track. + // + // The value is a value between 0..1 (linear), where 1.0 represents 0 dBov, + // 0 represents silence, and 0.5 represents approximately 6 dBSPL change in + // the sound pressure level from 0 dBov. + // + // If the track is sourced from an Receiver, does no audio processing, has a + // constant level, and has a volume setting of 1.0, the audio level is expected + // to be the same as the audio level of the source SSRC, while if the volume setting + // is 0.5, the AudioLevel is expected to be half that value. + // + // For outgoing audio tracks, the AudioLevel is the level of the audio being sent. + AudioLevel float64 `json:"audioLevel"` + + // TotalAudioEnergy is the total energy of all the audio samples sent/received + // for this object, calculated by duration * Math.pow(energy/maxEnergy, 2) for + // each audio sample seen. + TotalAudioEnergy float64 `json:"totalAudioEnergy"` + + // VoiceActivityFlag represents whether the last RTP packet sent or played out + // by this track contained voice activity or not based on the presence of the + // V bit in the extension header, as defined in [RFC6464]. + // + // This value indicates the voice activity in the latest RTP packet played out + // from a given SSRC, and is defined in RTPSynchronizationSource.voiceActivityFlag. + VoiceActivityFlag bool `json:"voiceActivityFlag"` + + // TotalSamplesDuration represents the total duration in seconds of all samples + // that have sent or received (and thus counted by TotalSamplesSent or TotalSamplesReceived). + // Can be used with TotalAudioEnergy to compute an average audio level over different intervals. + TotalSamplesDuration float64 `json:"totalSamplesDuration"` + + // EchoReturnLoss is only present while the sender is sending a track sourced from + // a microphone where echo cancellation is applied. Calculated in decibels. + EchoReturnLoss float64 `json:"echoReturnLoss"` + + // EchoReturnLossEnhancement is only present while the sender is sending a track + // sourced from a microphone where echo cancellation is applied. Calculated in decibels. + EchoReturnLossEnhancement float64 `json:"echoReturnLossEnhancement"` + + // TotalSamplesSent is the total number of samples that have been sent by this sender. + TotalSamplesSent uint64 `json:"totalSamplesSent"` +} + +func (s AudioSenderStats) statsMarker() {} + +// SenderAudioTrackAttachmentStats object represents the stats about one attachment +// of an audio MediaStreamTrack to the PeerConnection object for which one calls GetStats. +// +// It appears in the stats as soon as it is attached (via AddTrack, via AddTransceiver, +// via ReplaceTrack on an RTPSender object). +// +// If an audio track is attached twice (via AddTransceiver or ReplaceTrack), there +// will be two SenderAudioTrackAttachmentStats objects, one for each attachment. +// They will have the same "TrackIdentifier" attribute, but different "ID" attributes. +// +// If the track is detached from the PeerConnection (via removeTrack or via replaceTrack), +// it continues to appear, but with the "ObjectDeleted" member set to true. +type SenderAudioTrackAttachmentStats AudioSenderStats + +func (s SenderAudioTrackAttachmentStats) statsMarker() {} + +// VideoSenderStats represents the stats about one video sender of a PeerConnection +// object for which one calls GetStats. +// +// It appears in the stats as soon as the sender is added by either AddTrack or +// AddTransceiver, or by media negotiation. +type VideoSenderStats struct { + // Timestamp is the timestamp associated with this object. + Timestamp StatsTimestamp `json:"timestamp"` + + // Type is the object's StatsType + Type StatsType `json:"type"` + + // ID is a unique id that is associated with the component inspected to produce + // this Stats object. Two Stats objects will have the same ID if they were produced + // by inspecting the same underlying object. + ID string `json:"id"` + + // Kind is "video" + Kind string `json:"kind"` + + // FramesCaptured represents the total number of frames captured, before encoding, + // for this RTPSender (or for this MediaStreamTrack, if type is "track"). For example, + // if type is "sender" and this sender's track represents a camera, then this is the + // number of frames produced by the camera for this track while being sent by this sender, + // combined with the number of frames produced by all tracks previously attached to this + // sender while being sent by this sender. Framerates can vary due to hardware limitations + // or environmental factors such as lighting conditions. + FramesCaptured uint32 `json:"framesCaptured"` + + // FramesSent represents the total number of frames sent by this RTPSender + // (or for this MediaStreamTrack, if type is "track"). + FramesSent uint32 `json:"framesSent"` + + // HugeFramesSent represents the total number of huge frames sent by this RTPSender + // (or for this MediaStreamTrack, if type is "track"). Huge frames, by definition, + // are frames that have an encoded size at least 2.5 times the average size of the frames. + // The average size of the frames is defined as the target bitrate per second divided + // by the target fps at the time the frame was encoded. These are usually complex + // to encode frames with a lot of changes in the picture. This can be used to estimate, + // e.g slide changes in the streamed presentation. If a huge frame is also a key frame, + // then both counters HugeFramesSent and KeyFramesSent are incremented. + HugeFramesSent uint32 `json:"hugeFramesSent"` + + // KeyFramesSent represents the total number of key frames sent by this RTPSender + // (or for this MediaStreamTrack, if type is "track"), such as Infra-frames in + // VP8 [RFC6386] or I-frames in H.264 [RFC6184]. This is a subset of FramesSent. + // FramesSent - KeyFramesSent gives you the number of delta frames sent. + KeyFramesSent uint32 `json:"keyFramesSent"` +} + +func (s VideoSenderStats) statsMarker() {} + +// SenderVideoTrackAttachmentStats represents the stats about one attachment of a +// video MediaStreamTrack to the PeerConnection object for which one calls GetStats. +// +// It appears in the stats as soon as it is attached (via AddTrack, via AddTransceiver, +// via ReplaceTrack on an RTPSender object). +// +// If a video track is attached twice (via AddTransceiver or ReplaceTrack), there +// will be two SenderVideoTrackAttachmentStats objects, one for each attachment. +// They will have the same "TrackIdentifier" attribute, but different "ID" attributes. +// +// If the track is detached from the PeerConnection (via RemoveTrack or via ReplaceTrack), +// it continues to appear, but with the "ObjectDeleted" member set to true. +type SenderVideoTrackAttachmentStats VideoSenderStats + +func (s SenderVideoTrackAttachmentStats) statsMarker() {} + +func unmarshalSenderStats(b []byte) (Stats, error) { + type kindJSON struct { + Kind string `json:"kind"` + } + kindHolder := kindJSON{} + + err := json.Unmarshal(b, &kindHolder) + if err != nil { + return nil, fmt.Errorf("unmarshal json kind: %w", err) + } + + switch MediaKind(kindHolder.Kind) { + case MediaKindAudio: + var senderStats AudioSenderStats + err := json.Unmarshal(b, &senderStats) + if err != nil { + return nil, fmt.Errorf("unmarshal audio sender stats: %w", err) + } + + return senderStats, nil + case MediaKindVideo: + var senderStats VideoSenderStats + err := json.Unmarshal(b, &senderStats) + if err != nil { + return nil, fmt.Errorf("unmarshal video sender stats: %w", err) + } + + return senderStats, nil + default: + return nil, fmt.Errorf("kind: %w", ErrUnknownType) + } +} + +func unmarshalTrackStats(b []byte) (Stats, error) { + type kindJSON struct { + Kind string `json:"kind"` + } + kindHolder := kindJSON{} + + err := json.Unmarshal(b, &kindHolder) + if err != nil { + return nil, fmt.Errorf("unmarshal json kind: %w", err) + } + + switch MediaKind(kindHolder.Kind) { + case MediaKindAudio: + var trackStats SenderAudioTrackAttachmentStats + err := json.Unmarshal(b, &trackStats) + if err != nil { + return nil, fmt.Errorf("unmarshal audio track stats: %w", err) + } + + return trackStats, nil + case MediaKindVideo: + var trackStats SenderVideoTrackAttachmentStats + err := json.Unmarshal(b, &trackStats) + if err != nil { + return nil, fmt.Errorf("unmarshal video track stats: %w", err) + } + + return trackStats, nil + default: + return nil, fmt.Errorf("kind: %w", ErrUnknownType) + } +} + +// AudioReceiverStats contains audio metrics related to a specific receiver. +type AudioReceiverStats struct { + // Timestamp is the timestamp associated with this object. + Timestamp StatsTimestamp `json:"timestamp"` + + // Type is the object's StatsType + Type StatsType `json:"type"` + + // ID is a unique id that is associated with the component inspected to produce + // this Stats object. Two Stats objects will have the same ID if they were produced + // by inspecting the same underlying object. + ID string `json:"id"` + + // Kind is "audio" + Kind string `json:"kind"` + + // AudioLevel represents the output audio level of the track. + // + // The value is a value between 0..1 (linear), where 1.0 represents 0 dBov, + // 0 represents silence, and 0.5 represents approximately 6 dBSPL change in + // the sound pressure level from 0 dBov. + // + // If the track is sourced from a Receiver, does no audio processing, has a + // constant level, and has a volume setting of 1.0, the audio level is expected + // to be the same as the audio level of the source SSRC, while if the volume setting + // is 0.5, the AudioLevel is expected to be half that value. + // + // For outgoing audio tracks, the AudioLevel is the level of the audio being sent. + AudioLevel float64 `json:"audioLevel"` + + // TotalAudioEnergy is the total energy of all the audio samples sent/received + // for this object, calculated by duration * Math.pow(energy/maxEnergy, 2) for + // each audio sample seen. + TotalAudioEnergy float64 `json:"totalAudioEnergy"` + + // VoiceActivityFlag represents whether the last RTP packet sent or played out + // by this track contained voice activity or not based on the presence of the + // V bit in the extension header, as defined in [RFC6464]. + // + // This value indicates the voice activity in the latest RTP packet played out + // from a given SSRC, and is defined in RTPSynchronizationSource.voiceActivityFlag. + VoiceActivityFlag bool `json:"voiceActivityFlag"` + + // TotalSamplesDuration represents the total duration in seconds of all samples + // that have sent or received (and thus counted by TotalSamplesSent or TotalSamplesReceived). + // Can be used with TotalAudioEnergy to compute an average audio level over different intervals. + TotalSamplesDuration float64 `json:"totalSamplesDuration"` + + // EstimatedPlayoutTimestamp is the estimated playout time of this receiver's + // track. The playout time is the NTP timestamp of the last playable sample that + // has a known timestamp (from an RTCP SR packet mapping RTP timestamps to NTP + // timestamps), extrapolated with the time elapsed since it was ready to be played out. + // This is the "current time" of the track in NTP clock time of the sender and + // can be present even if there is no audio currently playing. + // + // This can be useful for estimating how much audio and video is out of + // sync for two tracks from the same source: + // AudioTrackStats.EstimatedPlayoutTimestamp - VideoTrackStats.EstimatedPlayoutTimestamp + EstimatedPlayoutTimestamp StatsTimestamp `json:"estimatedPlayoutTimestamp"` + + // JitterBufferDelay is the sum of the time, in seconds, each sample takes from + // the time it is received and to the time it exits the jitter buffer. + // This increases upon samples exiting, having completed their time in the buffer + // (incrementing JitterBufferEmittedCount). The average jitter buffer delay can + // be calculated by dividing the JitterBufferDelay with the JitterBufferEmittedCount. + JitterBufferDelay float64 `json:"jitterBufferDelay"` + + // JitterBufferEmittedCount is the total number of samples that have come out + // of the jitter buffer (increasing JitterBufferDelay). + JitterBufferEmittedCount uint64 `json:"jitterBufferEmittedCount"` + + // TotalSamplesReceived is the total number of samples that have been received + // by this receiver. This includes ConcealedSamples. + TotalSamplesReceived uint64 `json:"totalSamplesReceived"` + + // ConcealedSamples is the total number of samples that are concealed samples. + // A concealed sample is a sample that is based on data that was synthesized + // to conceal packet loss and does not represent incoming data. + ConcealedSamples uint64 `json:"concealedSamples"` + + // ConcealmentEvents is the number of concealment events. This counter increases + // every time a concealed sample is synthesized after a non-concealed sample. + // That is, multiple consecutive concealed samples will increase the concealedSamples + // count multiple times but is a single concealment event. + ConcealmentEvents uint64 `json:"concealmentEvents"` +} + +func (s AudioReceiverStats) statsMarker() {} + +// VideoReceiverStats contains video metrics related to a specific receiver. +type VideoReceiverStats struct { + // Timestamp is the timestamp associated with this object. + Timestamp StatsTimestamp `json:"timestamp"` + + // Type is the object's StatsType + Type StatsType `json:"type"` + + // ID is a unique id that is associated with the component inspected to produce + // this Stats object. Two Stats objects will have the same ID if they were produced + // by inspecting the same underlying object. + ID string `json:"id"` + + // Kind is "video" + Kind string `json:"kind"` + + // FrameWidth represents the width of the last processed frame for this track. + // Before the first frame is processed this attribute is missing. + FrameWidth uint32 `json:"frameWidth"` + + // FrameHeight represents the height of the last processed frame for this track. + // Before the first frame is processed this attribute is missing. + FrameHeight uint32 `json:"frameHeight"` + + // FramesPerSecond represents the nominal FPS value before the degradation preference + // is applied. It is the number of complete frames in the last second. For sending + // tracks it is the current captured FPS and for the receiving tracks it is the + // current decoding framerate. + FramesPerSecond float64 `json:"framesPerSecond"` + + // EstimatedPlayoutTimestamp is the estimated playout time of this receiver's + // track. The playout time is the NTP timestamp of the last playable sample that + // has a known timestamp (from an RTCP SR packet mapping RTP timestamps to NTP + // timestamps), extrapolated with the time elapsed since it was ready to be played out. + // This is the "current time" of the track in NTP clock time of the sender and + // can be present even if there is no audio currently playing. + // + // This can be useful for estimating how much audio and video is out of + // sync for two tracks from the same source: + // AudioTrackStats.EstimatedPlayoutTimestamp - VideoTrackStats.EstimatedPlayoutTimestamp + EstimatedPlayoutTimestamp StatsTimestamp `json:"estimatedPlayoutTimestamp"` + + // JitterBufferDelay is the sum of the time, in seconds, each sample takes from + // the time it is received and to the time it exits the jitter buffer. + // This increases upon samples exiting, having completed their time in the buffer + // (incrementing JitterBufferEmittedCount). The average jitter buffer delay can + // be calculated by dividing the JitterBufferDelay with the JitterBufferEmittedCount. + JitterBufferDelay float64 `json:"jitterBufferDelay"` + + // JitterBufferEmittedCount is the total number of samples that have come out + // of the jitter buffer (increasing JitterBufferDelay). + JitterBufferEmittedCount uint64 `json:"jitterBufferEmittedCount"` + + // FramesReceived Represents the total number of complete frames received for + // this receiver. This metric is incremented when the complete frame is received. + FramesReceived uint32 `json:"framesReceived"` + + // KeyFramesReceived represents the total number of complete key frames received + // for this MediaStreamTrack, such as Intra-frames in VP8 [RFC6386] or I-frames + // in H.264 [RFC6184]. This is a subset of framesReceived. `framesReceived - keyFramesReceived` + // gives you the number of delta frames received. This metric is incremented when + // the complete key frame is received. It is not incremented if a partial key + // frame is received and sent for decoding, i.e., the frame could not be recovered + // via retransmission or FEC. + KeyFramesReceived uint32 `json:"keyFramesReceived"` + + // FramesDecoded represents the total number of frames correctly decoded for this + // SSRC, i.e., frames that would be displayed if no frames are dropped. + FramesDecoded uint32 `json:"framesDecoded"` + + // FramesDropped is the total number of frames dropped predecode or dropped + // because the frame missed its display deadline for this receiver's track. + FramesDropped uint32 `json:"framesDropped"` + + // The cumulative number of partial frames lost. This metric is incremented when + // the frame is sent to the decoder. If the partial frame is received and recovered + // via retransmission or FEC before decoding, the FramesReceived counter is incremented. + PartialFramesLost uint32 `json:"partialFramesLost"` + + // FullFramesLost is the cumulative number of full frames lost. + FullFramesLost uint32 `json:"fullFramesLost"` +} + +func (s VideoReceiverStats) statsMarker() {} + +func unmarshalReceiverStats(b []byte) (Stats, error) { + type kindJSON struct { + Kind string `json:"kind"` + } + kindHolder := kindJSON{} + + err := json.Unmarshal(b, &kindHolder) + if err != nil { + return nil, fmt.Errorf("unmarshal json kind: %w", err) + } + + switch MediaKind(kindHolder.Kind) { + case MediaKindAudio: + var receiverStats AudioReceiverStats + err := json.Unmarshal(b, &receiverStats) + if err != nil { + return nil, fmt.Errorf("unmarshal audio receiver stats: %w", err) + } + + return receiverStats, nil + case MediaKindVideo: + var receiverStats VideoReceiverStats + err := json.Unmarshal(b, &receiverStats) + if err != nil { + return nil, fmt.Errorf("unmarshal video receiver stats: %w", err) + } + + return receiverStats, nil + default: + return nil, fmt.Errorf("kind: %w", ErrUnknownType) + } +} + +// TransportStats contains transport statistics related to the PeerConnection object. +type TransportStats struct { + // Timestamp is the timestamp associated with this object. + Timestamp StatsTimestamp `json:"timestamp"` + + // Type is the object's StatsType + Type StatsType `json:"type"` + + // ID is a unique id that is associated with the component inspected to produce + // this Stats object. Two Stats objects will have the same ID if they were produced + // by inspecting the same underlying object. + ID string `json:"id"` + + // PacketsSent represents the total number of packets sent over this transport. + PacketsSent uint32 `json:"packetsSent"` + + // PacketsReceived represents the total number of packets received on this transport. + PacketsReceived uint32 `json:"packetsReceived"` + + // BytesSent represents the total number of payload bytes sent on this PeerConnection + // not including headers or padding. + BytesSent uint64 `json:"bytesSent"` + + // BytesReceived represents the total number of bytes received on this PeerConnection + // not including headers or padding. + BytesReceived uint64 `json:"bytesReceived"` + + // RTCPTransportStatsID is the ID of the transport that gives stats for the RTCP + // component If RTP and RTCP are not multiplexed and this record has only + // the RTP component stats. + RTCPTransportStatsID string `json:"rtcpTransportStatsId"` + + // ICERole is set to the current value of the "role" attribute of the underlying + // DTLSTransport's "iceTransport". + ICERole ICERole `json:"iceRole"` + + // DTLSState is set to the current value of the "state" attribute of the underlying DTLSTransport. + DTLSState DTLSTransportState `json:"dtlsState"` + + // ICEState is set to the current value of the "state" attribute of the underlying + // RTCIceTransport's "state". + ICEState ICETransportState `json:"iceState"` + + // SelectedCandidatePairID is a unique identifier that is associated to the object + // that was inspected to produce the ICECandidatePairStats associated with this transport. + SelectedCandidatePairID string `json:"selectedCandidatePairId"` + + // LocalCertificateID is the ID of the CertificateStats for the local certificate. + // Present only if DTLS is negotiated. + LocalCertificateID string `json:"localCertificateId"` + + // RemoteCertificateID is the ID of the CertificateStats for the remote certificate. + // Present only if DTLS is negotiated. + RemoteCertificateID string `json:"remoteCertificateId"` + + // DTLSCipher is the descriptive name of the cipher suite used for the DTLS transport, + // as defined in the "Description" column of the IANA cipher suite registry. + DTLSCipher string `json:"dtlsCipher"` + + // SRTPCipher is the descriptive name of the protection profile used for the SRTP + // transport, as defined in the "Profile" column of the IANA DTLS-SRTP protection + // profile registry. + SRTPCipher string `json:"srtpCipher"` +} + +func (s TransportStats) statsMarker() {} + +func unmarshalTransportStats(b []byte) (TransportStats, error) { + var transportStats TransportStats + err := json.Unmarshal(b, &transportStats) + if err != nil { + return TransportStats{}, fmt.Errorf("unmarshal transport stats: %w", err) + } + + return transportStats, nil +} + +// StatsICECandidatePairState is the state of an ICE candidate pair used in the +// ICECandidatePairStats object. +type StatsICECandidatePairState string + +func toStatsICECandidatePairState(state ice.CandidatePairState) (StatsICECandidatePairState, error) { + switch state { + case ice.CandidatePairStateWaiting: + return StatsICECandidatePairStateWaiting, nil + case ice.CandidatePairStateInProgress: + return StatsICECandidatePairStateInProgress, nil + case ice.CandidatePairStateFailed: + return StatsICECandidatePairStateFailed, nil + case ice.CandidatePairStateSucceeded: + return StatsICECandidatePairStateSucceeded, nil + default: + // NOTE: this should never happen[tm] + err := fmt.Errorf("%w: %s", errStatsICECandidateStateInvalid, state.String()) + + return StatsICECandidatePairState("Unknown"), err + } +} + +func toICECandidatePairStats(candidatePairStats ice.CandidatePairStats) (ICECandidatePairStats, error) { + state, err := toStatsICECandidatePairState(candidatePairStats.State) + if err != nil { + return ICECandidatePairStats{}, err + } + + return ICECandidatePairStats{ + Timestamp: statsTimestampFrom(candidatePairStats.Timestamp), + Type: StatsTypeCandidatePair, + ID: newICECandidatePairStatsID(candidatePairStats.LocalCandidateID, candidatePairStats.RemoteCandidateID), + // TransportID: + LocalCandidateID: candidatePairStats.LocalCandidateID, + RemoteCandidateID: candidatePairStats.RemoteCandidateID, + State: state, + Nominated: candidatePairStats.Nominated, + PacketsSent: candidatePairStats.PacketsSent, + PacketsReceived: candidatePairStats.PacketsReceived, + BytesSent: candidatePairStats.BytesSent, + BytesReceived: candidatePairStats.BytesReceived, + LastPacketSentTimestamp: statsTimestampFrom(candidatePairStats.LastPacketSentTimestamp), + LastPacketReceivedTimestamp: statsTimestampFrom(candidatePairStats.LastPacketReceivedTimestamp), + FirstRequestTimestamp: statsTimestampFrom(candidatePairStats.FirstRequestTimestamp), + LastRequestTimestamp: statsTimestampFrom(candidatePairStats.LastRequestTimestamp), + FirstResponseTimestamp: statsTimestampFrom(candidatePairStats.FirstResponseTimestamp), + LastResponseTimestamp: statsTimestampFrom(candidatePairStats.LastResponseTimestamp), + FirstRequestReceivedTimestamp: statsTimestampFrom(candidatePairStats.FirstRequestReceivedTimestamp), + LastRequestReceivedTimestamp: statsTimestampFrom(candidatePairStats.LastRequestReceivedTimestamp), + TotalRoundTripTime: candidatePairStats.TotalRoundTripTime, + CurrentRoundTripTime: candidatePairStats.CurrentRoundTripTime, + AvailableOutgoingBitrate: candidatePairStats.AvailableOutgoingBitrate, + AvailableIncomingBitrate: candidatePairStats.AvailableIncomingBitrate, + CircuitBreakerTriggerCount: candidatePairStats.CircuitBreakerTriggerCount, + RequestsReceived: candidatePairStats.RequestsReceived, + RequestsSent: candidatePairStats.RequestsSent, + ResponsesReceived: candidatePairStats.ResponsesReceived, + ResponsesSent: candidatePairStats.ResponsesSent, + RetransmissionsReceived: candidatePairStats.RetransmissionsReceived, + RetransmissionsSent: candidatePairStats.RetransmissionsSent, + ConsentRequestsSent: candidatePairStats.ConsentRequestsSent, + ConsentExpiredTimestamp: statsTimestampFrom(candidatePairStats.ConsentExpiredTimestamp), + }, nil +} + +const ( + // StatsICECandidatePairStateFrozen means a check for this pair hasn't been + // performed, and it can't yet be performed until some other check succeeds, + // allowing this pair to unfreeze and move into the Waiting state. + StatsICECandidatePairStateFrozen StatsICECandidatePairState = "frozen" + + // StatsICECandidatePairStateWaiting means a check has not been performed for + // this pair, and can be performed as soon as it is the highest-priority Waiting + // pair on the check list. + StatsICECandidatePairStateWaiting StatsICECandidatePairState = "waiting" + + // StatsICECandidatePairStateInProgress means a check has been sent for this pair, + // but the transaction is in progress. + StatsICECandidatePairStateInProgress StatsICECandidatePairState = "in-progress" + + // StatsICECandidatePairStateFailed means a check for this pair was already done + // and failed, either never producing any response or producing an unrecoverable + // failure response. + StatsICECandidatePairStateFailed StatsICECandidatePairState = "failed" + + // StatsICECandidatePairStateSucceeded means a check for this pair was already + // done and produced a successful result. + StatsICECandidatePairStateSucceeded StatsICECandidatePairState = "succeeded" +) + +// ICECandidatePairStats contains ICE candidate pair statistics related +// to the ICETransport objects. +type ICECandidatePairStats struct { + // Timestamp is the timestamp associated with this object. + Timestamp StatsTimestamp `json:"timestamp"` + + // Type is the object's StatsType + Type StatsType `json:"type"` + + // ID is a unique id that is associated with the component inspected to produce + // this Stats object. Two Stats objects will have the same ID if they were produced + // by inspecting the same underlying object. + ID string `json:"id"` + + // TransportID is a unique identifier that is associated to the object that + // was inspected to produce the TransportStats associated with this candidate pair. + TransportID string `json:"transportId"` + + // LocalCandidateID is a unique identifier that is associated to the object + // that was inspected to produce the ICECandidateStats for the local candidate + // associated with this candidate pair. + LocalCandidateID string `json:"localCandidateId"` + + // RemoteCandidateID is a unique identifier that is associated to the object + // that was inspected to produce the ICECandidateStats for the remote candidate + // associated with this candidate pair. + RemoteCandidateID string `json:"remoteCandidateId"` + + // State represents the state of the checklist for the local and remote + // candidates in a pair. + State StatsICECandidatePairState `json:"state"` + + // Nominated is true when this valid pair that should be used for media + // if it is the highest-priority one amongst those whose nominated flag is set + Nominated bool `json:"nominated"` + + // PacketsSent represents the total number of packets sent on this candidate pair. + PacketsSent uint32 `json:"packetsSent"` + + // PacketsReceived represents the total number of packets received on this candidate pair. + PacketsReceived uint32 `json:"packetsReceived"` + + // BytesSent represents the total number of payload bytes sent on this candidate pair + // not including headers or padding. + BytesSent uint64 `json:"bytesSent"` + + // BytesReceived represents the total number of payload bytes received on this candidate pair + // not including headers or padding. + BytesReceived uint64 `json:"bytesReceived"` + + // LastPacketSentTimestamp represents the timestamp at which the last packet was + // sent on this particular candidate pair, excluding STUN packets. + LastPacketSentTimestamp StatsTimestamp `json:"lastPacketSentTimestamp"` + + // LastPacketReceivedTimestamp represents the timestamp at which the last packet + // was received on this particular candidate pair, excluding STUN packets. + LastPacketReceivedTimestamp StatsTimestamp `json:"lastPacketReceivedTimestamp"` + + // FirstRequestTimestamp represents the timestamp at which the first STUN request + // was sent on this particular candidate pair. + FirstRequestTimestamp StatsTimestamp `json:"firstRequestTimestamp"` + + // LastRequestTimestamp represents the timestamp at which the last STUN request + // was sent on this particular candidate pair. The average interval between two + // consecutive connectivity checks sent can be calculated with + // (LastRequestTimestamp - FirstRequestTimestamp) / RequestsSent. + LastRequestTimestamp StatsTimestamp `json:"lastRequestTimestamp"` + + // FirstResponseTimestamp represents the timestamp at which the first STUN response + // was received on this particular candidate pair. + FirstResponseTimestamp StatsTimestamp `json:"firstResponseTimestamp"` + + // LastResponseTimestamp represents the timestamp at which the last STUN response + // was received on this particular candidate pair. + LastResponseTimestamp StatsTimestamp `json:"lastResponseTimestamp"` + + // FirstRequestReceivedTimestamp represents the timestamp at which the first + // connectivity check request was received. + FirstRequestReceivedTimestamp StatsTimestamp `json:"firstRequestReceivedTimestamp"` + + // LastRequestReceivedTimestamp represents the timestamp at which the last + // connectivity check request was received. + LastRequestReceivedTimestamp StatsTimestamp `json:"lastRequestReceivedTimestamp"` + + // TotalRoundTripTime represents the sum of all round trip time measurements + // in seconds since the beginning of the session, based on STUN connectivity + // check responses (ResponsesReceived), including those that reply to requests + // that are sent in order to verify consent. The average round trip time can + // be computed from TotalRoundTripTime by dividing it by ResponsesReceived. + TotalRoundTripTime float64 `json:"totalRoundTripTime"` + + // CurrentRoundTripTime represents the latest round trip time measured in seconds, + // computed from both STUN connectivity checks, including those that are sent + // for consent verification. + CurrentRoundTripTime float64 `json:"currentRoundTripTime"` + + // AvailableOutgoingBitrate is calculated by the underlying congestion control + // by combining the available bitrate for all the outgoing RTP streams using + // this candidate pair. The bitrate measurement does not count the size of the + // IP or other transport layers like TCP or UDP. It is similar to the TIAS defined + // in RFC 3890, i.e., it is measured in bits per second and the bitrate is calculated + // over a 1 second window. + AvailableOutgoingBitrate float64 `json:"availableOutgoingBitrate"` + + // AvailableIncomingBitrate is calculated by the underlying congestion control + // by combining the available bitrate for all the incoming RTP streams using + // this candidate pair. The bitrate measurement does not count the size of the + // IP or other transport layers like TCP or UDP. It is similar to the TIAS defined + // in RFC 3890, i.e., it is measured in bits per second and the bitrate is + // calculated over a 1 second window. + AvailableIncomingBitrate float64 `json:"availableIncomingBitrate"` + + // CircuitBreakerTriggerCount represents the number of times the circuit breaker + // is triggered for this particular 5-tuple, ceasing transmission. + CircuitBreakerTriggerCount uint32 `json:"circuitBreakerTriggerCount"` + + // RequestsReceived represents the total number of connectivity check requests + // received (including retransmissions). It is impossible for the receiver to + // tell whether the request was sent in order to check connectivity or check + // consent, so all connectivity checks requests are counted here. + RequestsReceived uint64 `json:"requestsReceived"` + + // RequestsSent represents the total number of connectivity check requests + // sent (not including retransmissions). + RequestsSent uint64 `json:"requestsSent"` + + // ResponsesReceived represents the total number of connectivity check responses received. + ResponsesReceived uint64 `json:"responsesReceived"` + + // ResponsesSent represents the total number of connectivity check responses sent. + // Since we cannot distinguish connectivity check requests and consent requests, + // all responses are counted. + ResponsesSent uint64 `json:"responsesSent"` + + // RetransmissionsReceived represents the total number of connectivity check + // request retransmissions received. + RetransmissionsReceived uint64 `json:"retransmissionsReceived"` + + // RetransmissionsSent represents the total number of connectivity check + // request retransmissions sent. + RetransmissionsSent uint64 `json:"retransmissionsSent"` + + // ConsentRequestsSent represents the total number of consent requests sent. + ConsentRequestsSent uint64 `json:"consentRequestsSent"` + + // ConsentExpiredTimestamp represents the timestamp at which the latest valid + // STUN binding response expired. + ConsentExpiredTimestamp StatsTimestamp `json:"consentExpiredTimestamp"` + + // PacketsDiscardedOnSend represents the total number of packets for this candidate pair + // that have been discarded due to socket errors, i.e. a socket error occurred + // when handing the packets to the socket. This might happen due to various reasons, + // including full buffer or no available memory. + PacketsDiscardedOnSend uint32 `json:"packetsDiscardedOnSend"` + + // BytesDiscardedOnSend represents the total number of bytes for this candidate pair + // that have been discarded due to socket errors, i.e. a socket error occurred + // when handing the packets containing the bytes to the socket. This might happen due + // to various reasons, including full buffer or no available memory. + // Calculated as defined in [RFC3550] section 6.4.1. + BytesDiscardedOnSend uint32 `json:"bytesDiscardedOnSend"` +} + +func (s ICECandidatePairStats) statsMarker() {} + +func unmarshalICECandidatePairStats(b []byte) (ICECandidatePairStats, error) { + var iceCandidatePairStats ICECandidatePairStats + err := json.Unmarshal(b, &iceCandidatePairStats) + if err != nil { + return ICECandidatePairStats{}, fmt.Errorf("unmarshal ice candidate pair stats: %w", err) + } + + return iceCandidatePairStats, nil +} + +// ICECandidateStats contains ICE candidate statistics related to the ICETransport objects. +type ICECandidateStats struct { + // Timestamp is the timestamp associated with this object. + Timestamp StatsTimestamp `json:"timestamp"` + + // Type is the object's StatsType + Type StatsType `json:"type"` + + // ID is a unique id that is associated with the component inspected to produce + // this Stats object. Two Stats objects will have the same ID if they were produced + // by inspecting the same underlying object. + ID string `json:"id"` + + // TransportID is a unique identifier that is associated to the object that + // was inspected to produce the TransportStats associated with this candidate. + TransportID string `json:"transportId"` + + // NetworkType represents the type of network interface used by the base of a + // local candidate (the address the ICE agent sends from). Only present for + // local candidates; it's not possible to know what type of network interface + // a remote candidate is using. + // + // Note: + // This stat only tells you about the network interface used by the first "hop"; + // it's possible that a connection will be bottlenecked by another type of network. + // For example, when using Wi-Fi tethering, the networkType of the relevant candidate + // would be "wifi", even when the next hop is over a cellular connection. + // + // DEPRECATED. Although it may still work in some browsers, the networkType property was deprecated for + // preserving privacy. + NetworkType string `json:"networkType,omitempty"` + + // IP is the IP address of the candidate, allowing for IPv4 addresses and + // IPv6 addresses, but fully qualified domain names (FQDNs) are not allowed. + IP string `json:"ip"` + + // Port is the port number of the candidate. + Port int32 `json:"port"` + + // Protocol is one of udp and tcp. + Protocol string `json:"protocol"` + + // CandidateType is the "Type" field of the ICECandidate. + CandidateType ICECandidateType `json:"candidateType"` + + // Priority is the "Priority" field of the ICECandidate. + Priority int32 `json:"priority"` + + // URL of the TURN or STUN server that produced this candidate + // It is the URL address surfaced in an PeerConnectionICEEvent. + URL string `json:"url"` + + // RelayProtocol is the protocol used by the endpoint to communicate with the + // TURN server. This is only present for local candidates. Valid values for + // the TURN URL protocol is one of udp, tcp, or tls. + RelayProtocol string `json:"relayProtocol"` + + // Deleted is true if the candidate has been deleted/freed. For host candidates, + // this means that any network resources (typically a socket) associated with the + // candidate have been released. For TURN candidates, this means the TURN allocation + // is no longer active. + // + // Only defined for local candidates. For remote candidates, this property is not applicable. + Deleted bool `json:"deleted"` +} + +func (s ICECandidateStats) statsMarker() {} + +func unmarshalICECandidateStats(b []byte) (ICECandidateStats, error) { + var iceCandidateStats ICECandidateStats + err := json.Unmarshal(b, &iceCandidateStats) + if err != nil { + return ICECandidateStats{}, fmt.Errorf("unmarshal ice candidate stats: %w", err) + } + + return iceCandidateStats, nil +} + +// CertificateStats contains information about a certificate used by an ICETransport. +type CertificateStats struct { + // Timestamp is the timestamp associated with this object. + Timestamp StatsTimestamp `json:"timestamp"` + + // Type is the object's StatsType + Type StatsType `json:"type"` + + // ID is a unique id that is associated with the component inspected to produce + // this Stats object. Two Stats objects will have the same ID if they were produced + // by inspecting the same underlying object. + ID string `json:"id"` + + // Fingerprint is the fingerprint of the certificate. + Fingerprint string `json:"fingerprint"` + + // FingerprintAlgorithm is the hash function used to compute the certificate fingerprint. For instance, "sha-256". + FingerprintAlgorithm string `json:"fingerprintAlgorithm"` + + // Base64Certificate is the DER-encoded base-64 representation of the certificate. + Base64Certificate string `json:"base64Certificate"` + + // IssuerCertificateID refers to the stats object that contains the next certificate + // in the certificate chain. If the current certificate is at the end of the chain + // (i.e. a self-signed certificate), this will not be set. + IssuerCertificateID string `json:"issuerCertificateId"` +} + +func (s CertificateStats) statsMarker() {} + +func unmarshalCertificateStats(b []byte) (CertificateStats, error) { + var certificateStats CertificateStats + err := json.Unmarshal(b, &certificateStats) + if err != nil { + return CertificateStats{}, fmt.Errorf("unmarshal certificate stats: %w", err) + } + + return certificateStats, nil +} + +// SCTPTransportStats contains information about a certificate used by an SCTPTransport. +type SCTPTransportStats struct { + // Timestamp is the timestamp associated with this object. + Timestamp StatsTimestamp `json:"timestamp"` + + // Type is the object's StatsType + Type StatsType `json:"type"` + + // ID is a unique id that is associated with the component inspected to produce + // this Stats object. Two Stats objects will have the same ID if they were produced + // by inspecting the same underlying object. + ID string `json:"id"` + + // TransportID is the identifier of the object that was inspected to produce the + // RTCTransportStats for the DTLSTransport and ICETransport supporting the SCTP transport. + TransportID string `json:"transportId"` + + // SmoothedRoundTripTime is the latest smoothed round-trip time value, + // corresponding to spinfo_srtt defined in [RFC6458] but converted to seconds. + // If there has been no round-trip time measurements yet, this value is undefined. + SmoothedRoundTripTime float64 `json:"smoothedRoundTripTime"` + + // CongestionWindow is the latest congestion window, corresponding to spinfo_cwnd defined in [RFC6458]. + CongestionWindow uint32 `json:"congestionWindow"` + + // ReceiverWindow is the latest receiver window, corresponding to sstat_rwnd defined in [RFC6458]. + ReceiverWindow uint32 `json:"receiverWindow"` + + // MTU is the latest maximum transmission unit, corresponding to spinfo_mtu defined in [RFC6458]. + MTU uint32 `json:"mtu"` + + // UNACKData is the number of unacknowledged DATA chunks, corresponding to sstat_unackdata defined in [RFC6458]. + UNACKData uint32 `json:"unackData"` + + // BytesSent represents the total number of bytes sent on this SCTPTransport + BytesSent uint64 `json:"bytesSent"` + + // BytesReceived represents the total number of bytes received on this SCTPTransport + BytesReceived uint64 `json:"bytesReceived"` +} + +func (s SCTPTransportStats) statsMarker() {} + +func unmarshalSCTPTransportStats(b []byte) (SCTPTransportStats, error) { + var sctpTransportStats SCTPTransportStats + if err := json.Unmarshal(b, &sctpTransportStats); err != nil { + return SCTPTransportStats{}, fmt.Errorf("unmarshal sctp transport stats: %w", err) + } + + return sctpTransportStats, nil +} diff --git a/vendor/github.com/pion/webrtc/v4/stats_go.go b/vendor/github.com/pion/webrtc/v4/stats_go.go new file mode 100644 index 0000000..5f0eb2d --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/stats_go.go @@ -0,0 +1,242 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +//go:build !js + +package webrtc + +import ( + "context" + "sync" + "time" +) + +// GetConnectionStats is a helper method to return the associated stats for a given PeerConnection. +func (r StatsReport) GetConnectionStats(conn *PeerConnection) (PeerConnectionStats, bool) { + statsID := conn.ID() + stats, ok := r[statsID] + if !ok { + return PeerConnectionStats{}, false + } + + pcStats, ok := stats.(PeerConnectionStats) + if !ok { + return PeerConnectionStats{}, false + } + + return pcStats, true +} + +// GetDataChannelStats is a helper method to return the associated stats for a given DataChannel. +func (r StatsReport) GetDataChannelStats(dc *DataChannel) (DataChannelStats, bool) { + statsID := dc.getStatsID() + stats, ok := r[statsID] + if !ok { + return DataChannelStats{}, false + } + + dcStats, ok := stats.(DataChannelStats) + if !ok { + return DataChannelStats{}, false + } + + return dcStats, true +} + +// GetICECandidateStats is a helper method to return the associated stats for a given ICECandidate. +func (r StatsReport) GetICECandidateStats(c *ICECandidate) (ICECandidateStats, bool) { + statsID := c.statsID + stats, ok := r[statsID] + if !ok { + return ICECandidateStats{}, false + } + + candidateStats, ok := stats.(ICECandidateStats) + if !ok { + return ICECandidateStats{}, false + } + + return candidateStats, true +} + +// GetICECandidatePairStats is a helper method to return the associated stats for a given ICECandidatePair. +func (r StatsReport) GetICECandidatePairStats(c *ICECandidatePair) (ICECandidatePairStats, bool) { + statsID := c.statsID + stats, ok := r[statsID] + if !ok { + return ICECandidatePairStats{}, false + } + + candidateStats, ok := stats.(ICECandidatePairStats) + if !ok { + return ICECandidatePairStats{}, false + } + + return candidateStats, true +} + +// GetCertificateStats is a helper method to return the associated stats for a given Certificate. +func (r StatsReport) GetCertificateStats(c *Certificate) (CertificateStats, bool) { + statsID := c.statsID + stats, ok := r[statsID] + if !ok { + return CertificateStats{}, false + } + + certificateStats, ok := stats.(CertificateStats) + if !ok { + return CertificateStats{}, false + } + + return certificateStats, true +} + +// GetCodecStats is a helper method to return the associated stats for a given Codec. +func (r StatsReport) GetCodecStats(c *RTPCodecParameters) (CodecStats, bool) { + statsID := c.statsID + stats, ok := r[statsID] + if !ok { + return CodecStats{}, false + } + + codecStats, ok := stats.(CodecStats) + if !ok { + return CodecStats{}, false + } + + return codecStats, true +} + +// AudioPlayoutStatsProvider is an interface for getting audio playout metrics. +type AudioPlayoutStatsProvider interface { + // AddTrack registers a track to report playout stats to this provider. + AddTrack(track *TrackRemote) error + + // RemoveTrack unregisters a track from this provider. + RemoveTrack(track *TrackRemote) + + // Snapshot returns the accumulated stats at the given time. + Snapshot(now time.Time) (AudioPlayoutStats, bool) +} + +type trackContext struct { + cancel context.CancelFunc +} + +// defaultAudioPlayoutStatsProvider accumulates audio playout stats on behalf of the application. +type defaultAudioPlayoutStatsProvider struct { + mu sync.Mutex + + stats AudioPlayoutStats + lastSynthesized bool + tracks map[*TrackRemote]*trackContext +} + +// NewAudioPlayoutStatsProvider constructs a default provider with the supplied stats ID. +func NewAudioPlayoutStatsProvider(id string) *defaultAudioPlayoutStatsProvider { + return &defaultAudioPlayoutStatsProvider{ + stats: AudioPlayoutStats{ + ID: id, + Type: StatsTypeMediaPlayout, + Kind: string(MediaKindAudio), + }, + tracks: make(map[*TrackRemote]*trackContext), + } +} + +// Accumulate applies a new batch of played-out samples to the running totals. +func (p *defaultAudioPlayoutStatsProvider) Accumulate( + samples int, sampleRate uint32, deviceDelay time.Duration, synthesized bool, +) { + if samples <= 0 || sampleRate == 0 { + return + } + + delaySeconds := deviceDelay.Seconds() + if delaySeconds < 0 { + delaySeconds = 0 + } + + duration := float64(samples) / float64(sampleRate) + + p.mu.Lock() + defer p.mu.Unlock() + + p.stats.TotalSamplesCount += uint64(samples) + p.stats.TotalSamplesDuration += duration + p.stats.TotalPlayoutDelay += delaySeconds * float64(samples) + + if synthesized { + p.stats.SynthesizedSamplesDuration += duration + if !p.lastSynthesized { + p.stats.SynthesizedSamplesEvents++ + } + } + + p.lastSynthesized = synthesized +} + +// Snapshot returns the accumulated stats at the given time. +func (p *defaultAudioPlayoutStatsProvider) Snapshot(now time.Time) (AudioPlayoutStats, bool) { + p.mu.Lock() + defer p.mu.Unlock() + + if p.stats.TotalSamplesCount == 0 { + return AudioPlayoutStats{}, false + } + + stats := p.stats + stats.Timestamp = statsTimestampFrom(now) + + return stats, true +} + +// AddTrack registers a track to report playout stats to this provider. +func (p *defaultAudioPlayoutStatsProvider) AddTrack(track *TrackRemote) error { + p.mu.Lock() + defer p.mu.Unlock() + + if _, exists := p.tracks[track]; exists { + return nil + } + + track.addProvider(p) + + ctx, cancel := context.WithCancel(context.Background()) + p.tracks[track] = &trackContext{cancel: cancel} + + go func() { + receiver := track.receiver + if receiver == nil { + cancel() + + return + } + + select { + case <-receiver.closedChan: + p.removeTrackInternal(track) + case <-ctx.Done(): + return + } + }() + + return nil +} + +// RemoveTrack unregisters a track from this provider. +func (p *defaultAudioPlayoutStatsProvider) RemoveTrack(track *TrackRemote) { + p.removeTrackInternal(track) +} + +func (p *defaultAudioPlayoutStatsProvider) removeTrackInternal(track *TrackRemote) { + p.mu.Lock() + defer p.mu.Unlock() + + if tc, exists := p.tracks[track]; exists { + tc.cancel() + delete(p.tracks, track) + } + + track.removeProvider(p) +} diff --git a/vendor/github.com/pion/webrtc/v4/track_local.go b/vendor/github.com/pion/webrtc/v4/track_local.go new file mode 100644 index 0000000..667500f --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/track_local.go @@ -0,0 +1,128 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package webrtc + +import ( + "github.com/pion/interceptor" + "github.com/pion/rtp" +) + +// TrackLocalWriter is the Writer for outbound RTP Packets. +type TrackLocalWriter interface { + // WriteRTP encrypts a RTP packet and writes to the connection + WriteRTP(header *rtp.Header, payload []byte) (int, error) + + // Write encrypts and writes a full RTP packet + Write(b []byte) (int, error) +} + +// TrackLocalContext is the Context passed when a TrackLocal has been Binded/Unbinded from a PeerConnection, and used +// in Interceptors. +type TrackLocalContext interface { + // CodecParameters returns the negotiated RTPCodecParameters. These are the codecs supported by both + // PeerConnections and the PayloadTypes + CodecParameters() []RTPCodecParameters + + // HeaderExtensions returns the negotiated RTPHeaderExtensionParameters. These are the header extensions supported by + // both PeerConnections and the URI/IDs + HeaderExtensions() []RTPHeaderExtensionParameter + + // SSRC returns the negotiated SSRC of this track + SSRC() SSRC + + // SSRCRetransmission returns the negotiated SSRC used to send retransmissions for this track + SSRCRetransmission() SSRC + + // SSRCForwardErrorCorrection returns the negotiated SSRC to send forward error correction for this track + SSRCForwardErrorCorrection() SSRC + + // WriteStream returns the WriteStream for this TrackLocal. The implementer writes the outbound + // media packets to it + WriteStream() TrackLocalWriter + + // ID is a unique identifier that is used for both Bind/Unbind + ID() string + + // RTCPReader returns the RTCP interceptor for this TrackLocal. Used to read RTCP of this TrackLocal. + RTCPReader() interceptor.RTCPReader +} + +type baseTrackLocalContext struct { + id string + params RTPParameters + ssrc, ssrcRTX, ssrcFEC SSRC + writeStream TrackLocalWriter + rtcpInterceptor interceptor.RTCPReader +} + +// CodecParameters returns the negotiated RTPCodecParameters. These are the codecs supported by both +// PeerConnections and the SSRC/PayloadTypes. +func (t *baseTrackLocalContext) CodecParameters() []RTPCodecParameters { + return t.params.Codecs +} + +// HeaderExtensions returns the negotiated RTPHeaderExtensionParameters. These are the header extensions supported by +// both PeerConnections and the SSRC/PayloadTypes. +func (t *baseTrackLocalContext) HeaderExtensions() []RTPHeaderExtensionParameter { + return t.params.HeaderExtensions +} + +// SSRC requires the negotiated SSRC of this track. +func (t *baseTrackLocalContext) SSRC() SSRC { + return t.ssrc +} + +// SSRCRetransmission returns the negotiated SSRC used to send retransmissions for this track. +func (t *baseTrackLocalContext) SSRCRetransmission() SSRC { + return t.ssrcRTX +} + +// SSRCForwardErrorCorrection returns the negotiated SSRC to send forward error correction for this track. +func (t *baseTrackLocalContext) SSRCForwardErrorCorrection() SSRC { + return t.ssrcFEC +} + +// WriteStream returns the WriteStream for this TrackLocal. The implementer writes the outbound +// media packets to it. +func (t *baseTrackLocalContext) WriteStream() TrackLocalWriter { + return t.writeStream +} + +// ID is a unique identifier that is used for both Bind/Unbind. +func (t *baseTrackLocalContext) ID() string { + return t.id +} + +// RTCPReader returns the RTCP interceptor for this TrackLocal. Used to read RTCP of this TrackLocal. +func (t *baseTrackLocalContext) RTCPReader() interceptor.RTCPReader { + return t.rtcpInterceptor +} + +// TrackLocal is an interface that controls how the user can send media +// The user can provide their own TrackLocal implementations, or use +// the implementations in pkg/media. +type TrackLocal interface { + // Bind should implement the way how the media data flows from the Track to the PeerConnection + // This will be called internally after signaling is complete and the list of available + // codecs has been determined + Bind(TrackLocalContext) (RTPCodecParameters, error) + + // Unbind should implement the teardown logic when the track is no longer needed. This happens + // because a track has been stopped. + Unbind(TrackLocalContext) error + + // ID is the unique identifier for this Track. This should be unique for the + // stream, but doesn't have to globally unique. A common example would be 'audio' or 'video' + // and StreamID would be 'desktop' or 'webcam' + ID() string + + // RID is the RTP Stream ID for this track. + RID() string + + // StreamID is the group this track belongs too. This must be unique + StreamID() string + + // Kind controls if this TrackLocal is audio or video + Kind() RTPCodecType +} diff --git a/vendor/github.com/pion/webrtc/v4/track_local_static.go b/vendor/github.com/pion/webrtc/v4/track_local_static.go new file mode 100644 index 0000000..c1d53bc --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/track_local_static.go @@ -0,0 +1,411 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +//go:build !js + +package webrtc + +import ( + "strings" + "sync" + + "github.com/pion/rtp" + "github.com/pion/webrtc/v4/internal/util" + "github.com/pion/webrtc/v4/pkg/media" +) + +// trackBinding is a single bind for a Track +// Bind can be called multiple times, this stores the +// result for a single bind call so that it can be used when writing. +type trackBinding struct { + id string + ssrc, ssrcRTX, ssrcFEC SSRC + payloadType, payloadTypeRTX PayloadType + writeStream TrackLocalWriter +} + +// TrackLocalStaticRTP is a TrackLocal that has a pre-set codec and accepts RTP Packets. +// If you wish to send a media.Sample use TrackLocalStaticSample. +type TrackLocalStaticRTP struct { + mu sync.RWMutex + bindings []trackBinding + codec RTPCodecCapability + payloader func(RTPCodecCapability) (rtp.Payloader, error) + id, rid, streamID string + initalTimestamp *uint32 + initialSeqNumber *uint16 +} + +// NewTrackLocalStaticRTP returns a TrackLocalStaticRTP. +func NewTrackLocalStaticRTP( + c RTPCodecCapability, + id, streamID string, + options ...func(*TrackLocalStaticRTP), +) (*TrackLocalStaticRTP, error) { + t := &TrackLocalStaticRTP{ + codec: c, + bindings: []trackBinding{}, + id: id, + streamID: streamID, + } + + for _, option := range options { + option(t) + } + + return t, nil +} + +// WithRTPStreamID sets the RTP stream ID for this TrackLocalStaticRTP. +func WithRTPStreamID(rid string) func(*TrackLocalStaticRTP) { + return func(t *TrackLocalStaticRTP) { + t.rid = rid + } +} + +// WithPayloader allows the user to override the Payloader. +func WithPayloader(h func(RTPCodecCapability) (rtp.Payloader, error)) func(*TrackLocalStaticRTP) { + return func(s *TrackLocalStaticRTP) { + s.payloader = h + } +} + +// WithRTPTimestamp set the initial RTP timestamp for the track. +func WithRTPTimestamp(timestamp uint32) func(*TrackLocalStaticRTP) { + return func(s *TrackLocalStaticRTP) { + s.initalTimestamp = ×tamp + } +} + +// WithRTPSequenceNumber sets the initial RTP sequence number for the track. +func WithRTPSequenceNumber(sequenceNumber uint16) func(*TrackLocalStaticRTP) { + return func(s *TrackLocalStaticRTP) { + s.initialSeqNumber = &sequenceNumber + } +} + +// Bind is called by the PeerConnection after negotiation is complete +// This asserts that the code requested is supported by the remote peer. +// If so it sets up all the state (SSRC and PayloadType) to have a call. +func (s *TrackLocalStaticRTP) Bind(trackContext TrackLocalContext) (RTPCodecParameters, error) { + s.mu.Lock() + defer s.mu.Unlock() + + parameters := RTPCodecParameters{RTPCodecCapability: s.codec} + if codec, matchType := codecParametersFuzzySearch( + parameters, + trackContext.CodecParameters(), + ); matchType != codecMatchNone { + s.bindings = append(s.bindings, trackBinding{ + ssrc: trackContext.SSRC(), + ssrcRTX: trackContext.SSRCRetransmission(), + ssrcFEC: trackContext.SSRCForwardErrorCorrection(), + payloadType: codec.PayloadType, + payloadTypeRTX: findRTXPayloadType(codec.PayloadType, trackContext.CodecParameters()), + writeStream: trackContext.WriteStream(), + id: trackContext.ID(), + }) + + return codec, nil + } + + return RTPCodecParameters{}, ErrUnsupportedCodec +} + +// Unbind implements the teardown logic when the track is no longer needed. This happens +// because a track has been stopped. +func (s *TrackLocalStaticRTP) Unbind(t TrackLocalContext) error { + s.mu.Lock() + defer s.mu.Unlock() + + for i := range s.bindings { + if s.bindings[i].id == t.ID() { + s.bindings[i] = s.bindings[len(s.bindings)-1] + s.bindings = s.bindings[:len(s.bindings)-1] + + return nil + } + } + + return ErrUnbindFailed +} + +// ID is the unique identifier for this Track. This should be unique for the +// stream, but doesn't have to globally unique. A common example would be 'audio' or 'video' +// and StreamID would be 'desktop' or 'webcam'. +func (s *TrackLocalStaticRTP) ID() string { return s.id } + +// StreamID is the group this track belongs too. This must be unique. +func (s *TrackLocalStaticRTP) StreamID() string { return s.streamID } + +// RID is the RTP stream identifier. +func (s *TrackLocalStaticRTP) RID() string { return s.rid } + +// Kind controls if this TrackLocal is audio or video. +func (s *TrackLocalStaticRTP) Kind() RTPCodecType { + switch { + case strings.HasPrefix(s.codec.MimeType, "audio/"): + return RTPCodecTypeAudio + case strings.HasPrefix(s.codec.MimeType, "video/"): + return RTPCodecTypeVideo + default: + return RTPCodecType(0) + } +} + +// Codec gets the Codec of the track. +func (s *TrackLocalStaticRTP) Codec() RTPCodecCapability { + return s.codec +} + +// packetPool is a pool of packets used by WriteRTP and Write below +// nolint:gochecknoglobals +var rtpPacketPool = sync.Pool{ + New: func() any { + return &rtp.Packet{} + }, +} + +func resetPacketPoolAllocation(localPacket *rtp.Packet) { + *localPacket = rtp.Packet{} + rtpPacketPool.Put(localPacket) +} + +func getPacketAllocationFromPool() *rtp.Packet { + ipacket := rtpPacketPool.Get() + + return ipacket.(*rtp.Packet) //nolint:forcetypeassert +} + +// WriteRTP writes a RTP Packet to the TrackLocalStaticRTP +// If one PeerConnection fails the packets will still be sent to +// all PeerConnections. The error message will contain the ID of the failed +// PeerConnections so you can remove them. +func (s *TrackLocalStaticRTP) WriteRTP(p *rtp.Packet) error { + packet := getPacketAllocationFromPool() + + defer resetPacketPoolAllocation(packet) + + *packet = *p + + return s.writeRTP(packet) +} + +// writeRTP is like WriteRTP, except that it may modify the packet p. +func (s *TrackLocalStaticRTP) writeRTP(packet *rtp.Packet) error { + s.mu.RLock() + defer s.mu.RUnlock() + + writeErrs := []error{} + + for _, b := range s.bindings { + packet.Header.SSRC = uint32(b.ssrc) + packet.Header.PayloadType = uint8(b.payloadType) + // b.writeStream.WriteRTP below expects header and payload separately, so value of Packet.PaddingSize + // would be lost. Copy it to Packet.Header.PaddingSize to avoid that problem. + if packet.PaddingSize != 0 && packet.Header.PaddingSize == 0 { + packet.Header.PaddingSize = packet.PaddingSize + } + if _, err := b.writeStream.WriteRTP(&packet.Header, packet.Payload); err != nil { + writeErrs = append(writeErrs, err) + } + } + + return util.FlattenErrs(writeErrs) +} + +// Write writes a RTP Packet as a buffer to the TrackLocalStaticRTP +// If one PeerConnection fails the packets will still be sent to +// all PeerConnections. The error message will contain the ID of the failed +// PeerConnections so you can remove them. +func (s *TrackLocalStaticRTP) Write(b []byte) (n int, err error) { + packet := getPacketAllocationFromPool() + + defer resetPacketPoolAllocation(packet) + + if err = packet.Unmarshal(b); err != nil { + return 0, err + } + + return len(b), s.writeRTP(packet) +} + +// TrackLocalStaticSample is a TrackLocal that has a pre-set codec and accepts Samples. +// If you wish to send a RTP Packet use TrackLocalStaticRTP. +type TrackLocalStaticSample struct { + mu sync.Mutex + packetizer rtp.Packetizer + sequencer rtp.Sequencer + rtpTrack *TrackLocalStaticRTP + clockRate float64 + remainder float64 +} + +// NewTrackLocalStaticSample returns a TrackLocalStaticSample. +func NewTrackLocalStaticSample( + c RTPCodecCapability, + id, streamID string, + options ...func(*TrackLocalStaticRTP), +) (*TrackLocalStaticSample, error) { + rtpTrack, err := NewTrackLocalStaticRTP(c, id, streamID, options...) + if err != nil { + return nil, err + } + + return &TrackLocalStaticSample{ + rtpTrack: rtpTrack, + }, nil +} + +// ID is the unique identifier for this Track. This should be unique for the +// stream, but doesn't have to globally unique. A common example would be 'audio' or 'video' +// and StreamID would be 'desktop' or 'webcam'. +func (s *TrackLocalStaticSample) ID() string { return s.rtpTrack.ID() } + +// StreamID is the group this track belongs too. This must be unique. +func (s *TrackLocalStaticSample) StreamID() string { return s.rtpTrack.StreamID() } + +// RID is the RTP stream identifier. +func (s *TrackLocalStaticSample) RID() string { return s.rtpTrack.RID() } + +// Kind controls if this TrackLocal is audio or video. +func (s *TrackLocalStaticSample) Kind() RTPCodecType { return s.rtpTrack.Kind() } + +// Codec gets the Codec of the track. +func (s *TrackLocalStaticSample) Codec() RTPCodecCapability { + return s.rtpTrack.Codec() +} + +// Bind is called by the PeerConnection after negotiation is complete +// This asserts that the code requested is supported by the remote peer. +// If so it setups all the state (SSRC and PayloadType) to have a call. +func (s *TrackLocalStaticSample) Bind(t TrackLocalContext) (RTPCodecParameters, error) { + codec, err := s.rtpTrack.Bind(t) + if err != nil { + return codec, err + } + + s.rtpTrack.mu.Lock() + defer s.rtpTrack.mu.Unlock() + + // We only need one packetizer + if s.packetizer != nil { + return codec, nil + } + + payloadHandler := s.rtpTrack.payloader + if payloadHandler == nil { + payloadHandler = payloaderForCodec + } + + payloader, err := payloadHandler(codec.RTPCodecCapability) + if err != nil { + return codec, err + } + + options := []rtp.PacketizerOption{} + + if s.rtpTrack.initalTimestamp != nil { + options = append(options, rtp.WithTimestamp(*s.rtpTrack.initalTimestamp)) + } + + if s.rtpTrack.initialSeqNumber != nil { + s.sequencer = rtp.NewFixedSequencer(*s.rtpTrack.initialSeqNumber) + } + + if s.sequencer == nil { + s.sequencer = rtp.NewRandomSequencer() + } + + s.packetizer = rtp.NewPacketizerWithOptions( + outboundMTU, + payloader, + s.sequencer, + codec.ClockRate, + options..., + ) + + s.clockRate = float64(codec.RTPCodecCapability.ClockRate) + + return codec, nil +} + +// Unbind implements the teardown logic when the track is no longer needed. This happens +// because a track has been stopped. +func (s *TrackLocalStaticSample) Unbind(t TrackLocalContext) error { + return s.rtpTrack.Unbind(t) +} + +// WriteSample writes a Sample to the TrackLocalStaticSample +// If one PeerConnection fails the packets will still be sent to +// all PeerConnections. The error message will contain the ID of the failed +// PeerConnections so you can remove them. +func (s *TrackLocalStaticSample) WriteSample(sample media.Sample) error { + s.rtpTrack.mu.RLock() + packetizer := s.packetizer + clockRate := s.clockRate + sequencer := s.sequencer + s.rtpTrack.mu.RUnlock() + if packetizer == nil { + return nil + } + + s.mu.Lock() + remainder := s.remainder + + // skip packets by the number of previously dropped packets + for i := uint16(0); i < sample.PrevDroppedPackets; i++ { + sequencer.NextSequenceNumber() + } + + tickF := sample.Duration.Seconds() * clockRate + + if sample.PrevDroppedPackets > 0 { + dropTotal := tickF*float64(sample.PrevDroppedPackets) + remainder + dropTicks := uint32(dropTotal) + remainder = dropTotal - float64(dropTicks) + packetizer.SkipSamples(dropTicks) + } + + curTotal := tickF + remainder + curTicks := uint32(curTotal) + remainder = curTotal - float64(curTicks) + + s.remainder = remainder + packets := packetizer.Packetize(sample.Data, curTicks) + s.mu.Unlock() + + writeErrs := []error{} + for _, p := range packets { + if err := s.rtpTrack.WriteRTP(p); err != nil { + writeErrs = append(writeErrs, err) + } + } + + return util.FlattenErrs(writeErrs) +} + +// GeneratePadding writes padding-only samples to the TrackLocalStaticSample +// If one PeerConnection fails the packets will still be sent to +// all PeerConnections. The error message will contain the ID of the failed +// PeerConnections so you can remove them. +func (s *TrackLocalStaticSample) GeneratePadding(samples uint32) error { + s.rtpTrack.mu.RLock() + p := s.packetizer + s.rtpTrack.mu.RUnlock() + + if p == nil { + return nil + } + + packets := p.GeneratePadding(samples) + + writeErrs := []error{} + for _, p := range packets { + if err := s.rtpTrack.WriteRTP(p); err != nil { + writeErrs = append(writeErrs, err) + } + } + + return util.FlattenErrs(writeErrs) +} diff --git a/vendor/github.com/pion/webrtc/v4/track_remote.go b/vendor/github.com/pion/webrtc/v4/track_remote.go new file mode 100644 index 0000000..75b232c --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/track_remote.go @@ -0,0 +1,309 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +//go:build !js + +package webrtc + +import ( + "fmt" + "io" + "slices" + "sync" + "time" + + "github.com/pion/interceptor" + "github.com/pion/rtp" +) + +type peekedPacket struct { + payload []byte + attributes interceptor.Attributes +} + +// TrackRemote represents a single inbound source of media. +type TrackRemote struct { + mu sync.RWMutex + + id string + streamID string + + payloadType PayloadType + kind RTPCodecType + ssrc SSRC + rtxSsrc SSRC + codec RTPCodecParameters + params RTPParameters + rid string + + receiver *RTPReceiver + + peekedPackets []*peekedPacket + + audioPlayoutStatsProviders []AudioPlayoutStatsProvider +} + +func newTrackRemote(kind RTPCodecType, ssrc, rtxSsrc SSRC, rid string, receiver *RTPReceiver) *TrackRemote { + return &TrackRemote{ + kind: kind, + ssrc: ssrc, + rtxSsrc: rtxSsrc, + rid: rid, + receiver: receiver, + } +} + +// ID is the unique identifier for this Track. This should be unique for the +// stream, but doesn't have to globally unique. A common example would be 'audio' or 'video' +// and StreamID would be 'desktop' or 'webcam'. +func (t *TrackRemote) ID() string { + t.mu.RLock() + defer t.mu.RUnlock() + + return t.id +} + +// RID gets the RTP Stream ID of this Track +// With Simulcast you will have multiple tracks with the same ID, but different RID values. +// In many cases a TrackRemote will not have an RID, so it is important to assert it is non-zero. +func (t *TrackRemote) RID() string { + t.mu.RLock() + defer t.mu.RUnlock() + + return t.rid +} + +// PayloadType gets the PayloadType of the track. +func (t *TrackRemote) PayloadType() PayloadType { + t.mu.RLock() + defer t.mu.RUnlock() + + return t.payloadType +} + +// Kind gets the Kind of the track. +func (t *TrackRemote) Kind() RTPCodecType { + t.mu.RLock() + defer t.mu.RUnlock() + + return t.kind +} + +// StreamID is the group this track belongs too. This must be unique. +func (t *TrackRemote) StreamID() string { + t.mu.RLock() + defer t.mu.RUnlock() + + return t.streamID +} + +// SSRC gets the SSRC of the track. +func (t *TrackRemote) SSRC() SSRC { + t.mu.RLock() + defer t.mu.RUnlock() + + return t.ssrc +} + +// Msid gets the Msid of the track. +func (t *TrackRemote) Msid() string { + return t.StreamID() + " " + t.ID() +} + +// Codec gets the Codec of the track. +func (t *TrackRemote) Codec() RTPCodecParameters { + t.mu.RLock() + defer t.mu.RUnlock() + + return t.codec +} + +// Read reads data from the track. +func (t *TrackRemote) Read(b []byte) (n int, attributes interceptor.Attributes, err error) { + t.mu.RLock() + receiver := t.receiver + var peekedPkt *peekedPacket + if len(t.peekedPackets) != 0 { + peekedPkt = t.peekedPackets[0] + t.peekedPackets = t.peekedPackets[1:] + } + t.mu.RUnlock() + + if receiver.haveClosed() { + return 0, nil, io.EOF + } + + if peekedPkt != nil { + n = copy(b, peekedPkt.payload) + err = t.checkAndUpdateTrack(b) + + return n, peekedPkt.attributes, err + } + + // If there's a separate RTX track and an RTX packet is available, return that + if rtxPacketReceived := receiver.readRTX(t); rtxPacketReceived != nil { + n = copy(b, rtxPacketReceived.pkt) + attributes = rtxPacketReceived.attributes + rtxPacketReceived.release() + + return n, attributes, nil + } + + n, attributes, err = receiver.readRTP(b, t) + if err != nil { + return n, attributes, err + } + err = t.checkAndUpdateTrack(b) + + return n, attributes, err +} + +// checkAndUpdateTrack checks payloadType for every incoming packet +// once a different payloadType is detected the track will be updated. +func (t *TrackRemote) checkAndUpdateTrack(b []byte) error { + if len(b) < 2 { + return errRTPTooShort + } + + payloadType := PayloadType(b[1] & rtpPayloadTypeBitmask) + if payloadType != t.PayloadType() || len(t.params.Codecs) == 0 { + t.mu.Lock() + defer t.mu.Unlock() + + params, err := t.receiver.api.mediaEngine.getRTPParametersByPayloadType(payloadType) + if err != nil { + return err + } + + t.kind = t.receiver.kind + t.payloadType = payloadType + t.codec = params.Codecs[0] + t.params = params + } + + return nil +} + +// ReadRTP is a convenience method that wraps Read and unmarshals for you. +func (t *TrackRemote) ReadRTP() (*rtp.Packet, interceptor.Attributes, error) { + b := make([]byte, t.receiver.api.settingEngine.getReceiveMTU()) + i, attributes, err := t.Read(b) + if err != nil { + return nil, nil, err + } + + r := &rtp.Packet{} + if err := r.Unmarshal(b[:i]); err != nil { + return nil, nil, err + } + + return r, attributes, nil +} + +// peek is like Read, but it doesn't discard the packet read. +func (t *TrackRemote) peek(b []byte) (n int, a interceptor.Attributes, err error) { + n, a, err = t.Read(b) + if err != nil { + return + } + + t.mu.Lock() + // this might overwrite data if somebody peeked between the Read + // and us getting the lock. Oh well, we'll just drop a packet in + // that case. + data := make([]byte, n) + n = copy(data, b[:n]) + t.peekedPackets = append(t.peekedPackets, &peekedPacket{payload: data, attributes: a}) + t.mu.Unlock() + + return +} + +// SetReadDeadline sets the max amount of time the RTP stream will block before returning. 0 is forever. +func (t *TrackRemote) SetReadDeadline(deadline time.Time) error { + return t.receiver.setRTPReadDeadline(deadline, t) +} + +// RtxSSRC returns the RTX SSRC for a track, or 0 if track does not have a separate RTX stream. +func (t *TrackRemote) RtxSSRC() SSRC { + t.mu.RLock() + defer t.mu.RUnlock() + + return t.rtxSsrc +} + +// HasRTX returns true if the track has a separate RTX stream. +func (t *TrackRemote) HasRTX() bool { + t.mu.RLock() + defer t.mu.RUnlock() + + return t.rtxSsrc != 0 +} + +func (t *TrackRemote) addProvider(provider AudioPlayoutStatsProvider) { + t.mu.Lock() + defer t.mu.Unlock() + + if slices.Contains(t.audioPlayoutStatsProviders, provider) { + return + } + + t.audioPlayoutStatsProviders = append(t.audioPlayoutStatsProviders, provider) +} + +func (t *TrackRemote) removeProvider(provider AudioPlayoutStatsProvider) { + t.mu.Lock() + defer t.mu.Unlock() + + for i, p := range t.audioPlayoutStatsProviders { + if p == provider { + t.audioPlayoutStatsProviders = append(t.audioPlayoutStatsProviders[:i], t.audioPlayoutStatsProviders[i+1:]...) + + return + } + } +} + +func (t *TrackRemote) pullAudioPlayoutStats(now time.Time) []AudioPlayoutStats { + t.mu.RLock() + providers := t.audioPlayoutStatsProviders + t.mu.RUnlock() + + if len(providers) == 0 { + return nil + } + + var allStats []AudioPlayoutStats + for _, provider := range providers { + stats, ok := provider.Snapshot(now) + if !ok { + continue + } + + if stats.ID == "" { + stats.ID = fmt.Sprintf("media-playout-%d", uint32(t.SSRC())) + } + + if stats.Type == "" { + stats.Type = StatsTypeMediaPlayout + } + + if stats.Kind == "" { + stats.Kind = string(MediaKindAudio) + } + + if stats.Timestamp == 0 { + stats.Timestamp = statsTimestampFrom(now) + } + + allStats = append(allStats, stats) + } + + return allStats +} + +func (t *TrackRemote) setRtxSSRC(ssrc SSRC) { + t.mu.Lock() + defer t.mu.Unlock() + t.rtxSsrc = ssrc +} diff --git a/vendor/github.com/pion/webrtc/v4/webrtc.go b/vendor/github.com/pion/webrtc/v4/webrtc.go new file mode 100644 index 0000000..e1f51ef --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/webrtc.go @@ -0,0 +1,20 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +// Package webrtc implements the WebRTC 1.0 as defined in W3C WebRTC specification document. +package webrtc + +// SSRC represents a synchronization source +// A synchronization source is a randomly chosen +// value meant to be globally unique within a particular +// RTP session. Used to identify a single stream of media. +// +// https://tools.ietf.org/html/rfc3550#section-3 +type SSRC uint32 + +// PayloadType identifies the format of the RTP payload and determines +// its interpretation by the application. Each codec in a RTP Session +// will have a different PayloadType +// +// https://tools.ietf.org/html/rfc3550#section-3 +type PayloadType uint8 diff --git a/vendor/github.com/pion/webrtc/v4/yarn.lock b/vendor/github.com/pion/webrtc/v4/yarn.lock new file mode 100644 index 0000000..fcf971e --- /dev/null +++ b/vendor/github.com/pion/webrtc/v4/yarn.lock @@ -0,0 +1,375 @@ +# THIS IS AN AUTOGENERATED FILE. DO NOT EDIT THIS FILE DIRECTLY. +# yarn lockfile v1 + + +"@roamhq/wrtc-darwin-arm64@0.10.0": + version "0.10.0" + resolved "https://registry.yarnpkg.com/@roamhq/wrtc-darwin-arm64/-/wrtc-darwin-arm64-0.10.0.tgz#8fd9b6eb1c7189fa4f69becef7125d56c82398e4" + integrity sha512-vFdi79jWuPHcnUcnuOjTvyKtmY/RI2xRQo9Y6RsIjIlYePN/7LTy00c+Ivrz4prYAPbp0oHscl7PDV64VUqGTQ== + +"@roamhq/wrtc-darwin-x64@0.10.0": + version "0.10.0" + resolved "https://registry.yarnpkg.com/@roamhq/wrtc-darwin-x64/-/wrtc-darwin-x64-0.10.0.tgz#c860caa6997552b7d7218635f9b587e6e3900f68" + integrity sha512-H6852g2xYCuaR+/TrthpdMafs4bMfAUEpvRDhsIguzrK7Dz+MKpNI8MkwdqJN8W65J+7w7k+YqXIkTHe7Fz/cg== + +"@roamhq/wrtc-linux-arm64@0.10.0": + version "0.10.0" + resolved "https://registry.yarnpkg.com/@roamhq/wrtc-linux-arm64/-/wrtc-linux-arm64-0.10.0.tgz#088f411f1d33decf530d419e4f9dab997bcc3f18" + integrity sha512-fEuJbNjprxQG6QlFd2iqBW9x028RDSho6izVg7gyt8irdPiXWOxzOxNnYMs/B2fohBTd1wD4Qxfivl07/dCR8A== + +"@roamhq/wrtc-linux-x64@0.10.0": + version "0.10.0" + resolved "https://registry.yarnpkg.com/@roamhq/wrtc-linux-x64/-/wrtc-linux-x64-0.10.0.tgz#709d91ee73b66c24825498a3c39fa18fc03dc0b0" + integrity sha512-H32lK2eFg3sVb/9nkHIX5HIisxFoS82Gpesuea+zqAyRpRzSd5NpFXx28bVy9wQyRrNtj8k0bTUgEzWRzSbYCA== + +"@roamhq/wrtc-win32-x64@0.10.0": + version "0.10.0" + resolved "https://registry.yarnpkg.com/@roamhq/wrtc-win32-x64/-/wrtc-win32-x64-0.10.0.tgz#ecc0f3dde8ecc8ed20263a315d385fea51af309f" + integrity sha512-wEVXMvLrBizdLyrd+Zc7zb7zpwUuHUBXwrdIvI69e3i/AA8YsVYI2xo/sxk6GoQ+o8a14ONc4SStDS35TCjg+w== + +"@roamhq/wrtc@^0.10.0": + version "0.10.0" + resolved "https://registry.yarnpkg.com/@roamhq/wrtc/-/wrtc-0.10.0.tgz#eecdfdad778e75b42c5e9f66584872c7dff4155f" + integrity sha512-yFqQQ0EV1ZUHaphh3tmjoxPi2wzhW2vjmzoAVNRRLUjXYd2e1nvwi9TKfE2w4WNvNws/hBkouvOt23Xo9FkXkQ== + optionalDependencies: + "@roamhq/wrtc-darwin-arm64" "0.10.0" + "@roamhq/wrtc-darwin-x64" "0.10.0" + "@roamhq/wrtc-linux-arm64" "0.10.0" + "@roamhq/wrtc-linux-x64" "0.10.0" + "@roamhq/wrtc-win32-x64" "0.10.0" + domexception "^4.0.0" + +ajv@^6.5.5: + version "6.12.2" + resolved "https://registry.yarnpkg.com/ajv/-/ajv-6.12.2.tgz#c629c5eced17baf314437918d2da88c99d5958cd" + integrity sha512-k+V+hzjm5q/Mr8ef/1Y9goCmlsK4I6Sm74teeyGvFk1XrOsbsKLjEdrvny42CZ+a8sXbk8KWpY/bDwS+FLL2UQ== + dependencies: + fast-deep-equal "^3.1.1" + fast-json-stable-stringify "^2.0.0" + json-schema-traverse "^0.4.1" + uri-js "^4.2.2" + +asn1@~0.2.3: + version "0.2.4" + resolved "https://registry.yarnpkg.com/asn1/-/asn1-0.2.4.tgz#8d2475dfab553bb33e77b54e59e880bb8ce23136" + integrity sha512-jxwzQpLQjSmWXgwaCZE9Nz+glAG01yF1QnWgbhGwHI5A6FRIEY6IVqtHhIepHqI7/kyEyQEagBC5mBEFlIYvdg== + dependencies: + safer-buffer "~2.1.0" + +assert-plus@1.0.0, assert-plus@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/assert-plus/-/assert-plus-1.0.0.tgz#f12e0f3c5d77b0b1cdd9146942e4e96c1e4dd525" + integrity sha1-8S4PPF13sLHN2RRpQuTpbB5N1SU= + +asynckit@^0.4.0: + version "0.4.0" + resolved "https://registry.yarnpkg.com/asynckit/-/asynckit-0.4.0.tgz#c79ed97f7f34cb8f2ba1bc9790bcc366474b4b79" + integrity sha1-x57Zf380y48robyXkLzDZkdLS3k= + +aws-sign2@~0.7.0: + version "0.7.0" + resolved "https://registry.yarnpkg.com/aws-sign2/-/aws-sign2-0.7.0.tgz#b46e890934a9591f2d2f6f86d7e6a9f1b3fe76a8" + integrity sha1-tG6JCTSpWR8tL2+G1+ap8bP+dqg= + +aws4@^1.8.0: + version "1.10.0" + resolved "https://registry.yarnpkg.com/aws4/-/aws4-1.10.0.tgz#a17b3a8ea811060e74d47d306122400ad4497ae2" + integrity sha512-3YDiu347mtVtjpyV3u5kVqQLP242c06zwDOgpeRnybmXlYYsLbtTrUBUm8i8srONt+FWobl5aibnU1030PeeuA== + +bcrypt-pbkdf@^1.0.0: + version "1.0.2" + resolved "https://registry.yarnpkg.com/bcrypt-pbkdf/-/bcrypt-pbkdf-1.0.2.tgz#a4301d389b6a43f9b67ff3ca11a3f6637e360e9e" + integrity sha1-pDAdOJtqQ/m2f/PKEaP2Y342Dp4= + dependencies: + tweetnacl "^0.14.3" + +caseless@~0.12.0: + version "0.12.0" + resolved "https://registry.yarnpkg.com/caseless/-/caseless-0.12.0.tgz#1b681c21ff84033c826543090689420d187151dc" + integrity sha1-G2gcIf+EAzyCZUMJBolCDRhxUdw= + +combined-stream@^1.0.6, combined-stream@~1.0.6: + version "1.0.8" + resolved "https://registry.yarnpkg.com/combined-stream/-/combined-stream-1.0.8.tgz#c3d45a8b34fd730631a110a8a2520682b31d5a7f" + integrity sha512-FQN4MRfuJeHf7cBbBMJFXhKSDq+2kAArBlmRBvcvFE5BB1HZKXtSFASDhdlz9zOYwxh8lDdnvmMOe/+5cdoEdg== + dependencies: + delayed-stream "~1.0.0" + +core-util-is@1.0.2: + version "1.0.2" + resolved "https://registry.yarnpkg.com/core-util-is/-/core-util-is-1.0.2.tgz#b5fd54220aa2bc5ab57aab7140c940754503c1a7" + integrity sha1-tf1UIgqivFq1eqtxQMlAdUUDwac= + +dashdash@^1.12.0: + version "1.14.1" + resolved "https://registry.yarnpkg.com/dashdash/-/dashdash-1.14.1.tgz#853cfa0f7cbe2fed5de20326b8dd581035f6e2f0" + integrity sha1-hTz6D3y+L+1d4gMmuN1YEDX24vA= + dependencies: + assert-plus "^1.0.0" + +delayed-stream@~1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/delayed-stream/-/delayed-stream-1.0.0.tgz#df3ae199acadfb7d440aaae0b29e2272b24ec619" + integrity sha1-3zrhmayt+31ECqrgsp4icrJOxhk= + +domexception@^4.0.0: + version "4.0.0" + resolved "https://registry.yarnpkg.com/domexception/-/domexception-4.0.0.tgz#4ad1be56ccadc86fc76d033353999a8037d03673" + integrity sha512-A2is4PLG+eeSfoTMA95/s4pvAoSo2mKtiM5jlHkAVewmiO8ISFTFKZjH7UAM1Atli/OT/7JHOrJRJiMKUZKYBw== + dependencies: + webidl-conversions "^7.0.0" + +ecc-jsbn@~0.1.1: + version "0.1.2" + resolved "https://registry.yarnpkg.com/ecc-jsbn/-/ecc-jsbn-0.1.2.tgz#3a83a904e54353287874c564b7549386849a98c9" + integrity sha1-OoOpBOVDUyh4dMVkt1SThoSamMk= + dependencies: + jsbn "~0.1.0" + safer-buffer "^2.1.0" + +extend@~3.0.2: + version "3.0.2" + resolved "https://registry.yarnpkg.com/extend/-/extend-3.0.2.tgz#f8b1136b4071fbd8eb140aff858b1019ec2915fa" + integrity sha512-fjquC59cD7CyW6urNXK0FBufkZcoiGG80wTuPujX590cB5Ttln20E2UB4S/WARVqhXffZl2LNgS+gQdPIIim/g== + +extsprintf@1.3.0: + version "1.3.0" + resolved "https://registry.yarnpkg.com/extsprintf/-/extsprintf-1.3.0.tgz#96918440e3041a7a414f8c52e3c574eb3c3e1e05" + integrity sha1-lpGEQOMEGnpBT4xS48V06zw+HgU= + +extsprintf@^1.2.0: + version "1.4.0" + resolved "https://registry.yarnpkg.com/extsprintf/-/extsprintf-1.4.0.tgz#e2689f8f356fad62cca65a3a91c5df5f9551692f" + integrity sha1-4mifjzVvrWLMplo6kcXfX5VRaS8= + +fast-deep-equal@^3.1.1: + version "3.1.1" + resolved "https://registry.yarnpkg.com/fast-deep-equal/-/fast-deep-equal-3.1.1.tgz#545145077c501491e33b15ec408c294376e94ae4" + integrity sha512-8UEa58QDLauDNfpbrX55Q9jrGHThw2ZMdOky5Gl1CDtVeJDPVrG4Jxx1N8jw2gkWaff5UUuX1KJd+9zGe2B+ZA== + +fast-json-stable-stringify@^2.0.0: + version "2.1.0" + resolved "https://registry.yarnpkg.com/fast-json-stable-stringify/-/fast-json-stable-stringify-2.1.0.tgz#874bf69c6f404c2b5d99c481341399fd55892633" + integrity sha512-lhd/wF+Lk98HZoTCtlVraHtfh5XYijIjalXck7saUtuanSDyLMxnHhSXEDJqHxD7msR8D0uCmqlkwjCV8xvwHw== + +forever-agent@~0.6.1: + version "0.6.1" + resolved "https://registry.yarnpkg.com/forever-agent/-/forever-agent-0.6.1.tgz#fbc71f0c41adeb37f96c577ad1ed42d8fdacca91" + integrity sha1-+8cfDEGt6zf5bFd60e1C2P2sypE= + +form-data@~2.3.2: + version "2.3.3" + resolved "https://registry.yarnpkg.com/form-data/-/form-data-2.3.3.tgz#dcce52c05f644f298c6a7ab936bd724ceffbf3a6" + integrity sha512-1lLKB2Mu3aGP1Q/2eCOx0fNbRMe7XdwktwOruhfqqd0rIJWwN4Dh+E3hrPSlDCXnSR7UtZ1N38rVXm+6+MEhJQ== + dependencies: + asynckit "^0.4.0" + combined-stream "^1.0.6" + mime-types "^2.1.12" + +getpass@^0.1.1: + version "0.1.7" + resolved "https://registry.yarnpkg.com/getpass/-/getpass-0.1.7.tgz#5eff8e3e684d569ae4cb2b1282604e8ba62149fa" + integrity sha1-Xv+OPmhNVprkyysSgmBOi6YhSfo= + dependencies: + assert-plus "^1.0.0" + +har-schema@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/har-schema/-/har-schema-2.0.0.tgz#a94c2224ebcac04782a0d9035521f24735b7ec92" + integrity sha1-qUwiJOvKwEeCoNkDVSHyRzW37JI= + +har-validator@~5.1.3: + version "5.1.3" + resolved "https://registry.yarnpkg.com/har-validator/-/har-validator-5.1.3.tgz#1ef89ebd3e4996557675eed9893110dc350fa080" + integrity sha512-sNvOCzEQNr/qrvJgc3UG/kD4QtlHycrzwS+6mfTrrSq97BvaYcPZZI1ZSqGSPR73Cxn4LKTD4PttRwfU7jWq5g== + dependencies: + ajv "^6.5.5" + har-schema "^2.0.0" + +http-signature@~1.2.0: + version "1.2.0" + resolved "https://registry.yarnpkg.com/http-signature/-/http-signature-1.2.0.tgz#9aecd925114772f3d95b65a60abb8f7c18fbace1" + integrity sha1-muzZJRFHcvPZW2WmCruPfBj7rOE= + dependencies: + assert-plus "^1.0.0" + jsprim "^1.2.2" + sshpk "^1.7.0" + +is-typedarray@~1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/is-typedarray/-/is-typedarray-1.0.0.tgz#e479c80858df0c1b11ddda6940f96011fcda4a9a" + integrity sha1-5HnICFjfDBsR3dppQPlgEfzaSpo= + +isstream@~0.1.2: + version "0.1.2" + resolved "https://registry.yarnpkg.com/isstream/-/isstream-0.1.2.tgz#47e63f7af55afa6f92e1500e690eb8b8529c099a" + integrity sha1-R+Y/evVa+m+S4VAOaQ64uFKcCZo= + +jsbn@~0.1.0: + version "0.1.1" + resolved "https://registry.yarnpkg.com/jsbn/-/jsbn-0.1.1.tgz#a5e654c2e5a2deb5f201d96cefbca80c0ef2f513" + integrity sha1-peZUwuWi3rXyAdls77yoDA7y9RM= + +json-schema-traverse@^0.4.1: + version "0.4.1" + resolved "https://registry.yarnpkg.com/json-schema-traverse/-/json-schema-traverse-0.4.1.tgz#69f6a87d9513ab8bb8fe63bdb0979c448e684660" + integrity sha512-xbbCH5dCYU5T8LcEhhuh7HJ88HXuW3qsI3Y0zOZFKfZEHcpWiHU/Jxzk629Brsab/mMiHQti9wMP+845RPe3Vg== + +json-schema@0.2.3: + version "0.2.3" + resolved "https://registry.yarnpkg.com/json-schema/-/json-schema-0.2.3.tgz#b480c892e59a2f05954ce727bd3f2a4e882f9e13" + integrity sha1-tIDIkuWaLwWVTOcnvT8qTogvnhM= + +json-stringify-safe@~5.0.1: + version "5.0.1" + resolved "https://registry.yarnpkg.com/json-stringify-safe/-/json-stringify-safe-5.0.1.tgz#1296a2d58fd45f19a0f6ce01d65701e2c735b6eb" + integrity sha1-Epai1Y/UXxmg9s4B1lcB4sc1tus= + +jsprim@^1.2.2: + version "1.4.1" + resolved "https://registry.yarnpkg.com/jsprim/-/jsprim-1.4.1.tgz#313e66bc1e5cc06e438bc1b7499c2e5c56acb6a2" + integrity sha1-MT5mvB5cwG5Di8G3SZwuXFastqI= + dependencies: + assert-plus "1.0.0" + extsprintf "1.3.0" + json-schema "0.2.3" + verror "1.10.0" + +mime-db@1.44.0: + version "1.44.0" + resolved "https://registry.yarnpkg.com/mime-db/-/mime-db-1.44.0.tgz#fa11c5eb0aca1334b4233cb4d52f10c5a6272f92" + integrity sha512-/NOTfLrsPBVeH7YtFPgsVWveuL+4SjjYxaQ1xtM1KMFj7HdxlBlxeyNLzhyJVx7r4rZGJAZ/6lkKCitSc/Nmpg== + +mime-types@^2.1.12, mime-types@~2.1.19: + version "2.1.27" + resolved "https://registry.yarnpkg.com/mime-types/-/mime-types-2.1.27.tgz#47949f98e279ea53119f5722e0f34e529bec009f" + integrity sha512-JIhqnCasI9yD+SsmkquHBxTSEuZdQX5BuQnS2Vc7puQQQ+8yiP5AY5uWhpdv4YL4VM5c6iliiYWPgJ/nJQLp7w== + dependencies: + mime-db "1.44.0" + +oauth-sign@~0.9.0: + version "0.9.0" + resolved "https://registry.yarnpkg.com/oauth-sign/-/oauth-sign-0.9.0.tgz#47a7b016baa68b5fa0ecf3dee08a85c679ac6455" + integrity sha512-fexhUFFPTGV8ybAtSIGbV6gOkSv8UtRbDBnAyLQw4QPKkgNlsH2ByPGtMUqdWkos6YCRmAqViwgZrJc/mRDzZQ== + +performance-now@^2.1.0: + version "2.1.0" + resolved "https://registry.yarnpkg.com/performance-now/-/performance-now-2.1.0.tgz#6309f4e0e5fa913ec1c69307ae364b4b377c9e7b" + integrity sha1-Ywn04OX6kT7BxpMHrjZLSzd8nns= + +psl@^1.1.28: + version "1.8.0" + resolved "https://registry.yarnpkg.com/psl/-/psl-1.8.0.tgz#9326f8bcfb013adcc005fdff056acce020e51c24" + integrity sha512-RIdOzyoavK+hA18OGGWDqUTsCLhtA7IcZ/6NCs4fFJaHBDab+pDDmDIByWFRQJq2Cd7r1OoQxBGKOaztq+hjIQ== + +punycode@^2.1.0, punycode@^2.1.1: + version "2.1.1" + resolved "https://registry.yarnpkg.com/punycode/-/punycode-2.1.1.tgz#b58b010ac40c22c5657616c8d2c2c02c7bf479ec" + integrity sha512-XRsRjdf+j5ml+y/6GKHPZbrF/8p2Yga0JPtdqTIY2Xe5ohJPD9saDJJLPvp9+NSBprVvevdXZybnj2cv8OEd0A== + +qs@~6.5.2: + version "6.5.2" + resolved "https://registry.yarnpkg.com/qs/-/qs-6.5.2.tgz#cb3ae806e8740444584ef154ce8ee98d403f3e36" + integrity sha512-N5ZAX4/LxJmF+7wN74pUD6qAh9/wnvdQcjq9TZjevvXzSUo7bfmw91saqMjzGS2xq91/odN2dW/WOl7qQHNDGA== + +request@2.88.2: + version "2.88.2" + resolved "https://registry.yarnpkg.com/request/-/request-2.88.2.tgz#d73c918731cb5a87da047e207234146f664d12b3" + integrity sha512-MsvtOrfG9ZcrOwAW+Qi+F6HbD0CWXEh9ou77uOb7FM2WPhwT7smM833PzanhJLsgXjN89Ir6V2PczXNnMpwKhw== + dependencies: + aws-sign2 "~0.7.0" + aws4 "^1.8.0" + caseless "~0.12.0" + combined-stream "~1.0.6" + extend "~3.0.2" + forever-agent "~0.6.1" + form-data "~2.3.2" + har-validator "~5.1.3" + http-signature "~1.2.0" + is-typedarray "~1.0.0" + isstream "~0.1.2" + json-stringify-safe "~5.0.1" + mime-types "~2.1.19" + oauth-sign "~0.9.0" + performance-now "^2.1.0" + qs "~6.5.2" + safe-buffer "^5.1.2" + tough-cookie "~2.5.0" + tunnel-agent "^0.6.0" + uuid "^3.3.2" + +safe-buffer@^5.0.1: + version "5.2.1" + resolved "https://registry.yarnpkg.com/safe-buffer/-/safe-buffer-5.2.1.tgz#1eaf9fa9bdb1fdd4ec75f58f9cdb4e6b7827eec6" + integrity sha512-rp3So07KcdmmKbGvgaNxQSJr7bGVSVk5S9Eq1F+ppbRo70+YeaDxkw5Dd8NPN+GD6bjnYm2VuPuCXmpuYvmCXQ== + +safe-buffer@^5.1.2: + version "5.1.2" + resolved "https://registry.yarnpkg.com/safe-buffer/-/safe-buffer-5.1.2.tgz#991ec69d296e0313747d59bdfd2b745c35f8828d" + integrity sha512-Gd2UZBJDkXlY7GbJxfsE8/nvKkUEU1G38c1siN6QP6a9PT9MmHB8GnpscSmMJSoF8LOIrt8ud/wPtojys4G6+g== + +safer-buffer@^2.0.2, safer-buffer@^2.1.0, safer-buffer@~2.1.0: + version "2.1.2" + resolved "https://registry.yarnpkg.com/safer-buffer/-/safer-buffer-2.1.2.tgz#44fa161b0187b9549dd84bb91802f9bd8385cd6a" + integrity sha512-YZo3K82SD7Riyi0E1EQPojLz7kpepnSQI9IyPbHHg1XXXevb5dJI7tpyN2ADxGcQbHG7vcyRHk0cbwqcQriUtg== + +sshpk@^1.7.0: + version "1.16.1" + resolved "https://registry.yarnpkg.com/sshpk/-/sshpk-1.16.1.tgz#fb661c0bef29b39db40769ee39fa70093d6f6877" + integrity sha512-HXXqVUq7+pcKeLqqZj6mHFUMvXtOJt1uoUx09pFW6011inTMxqI8BA8PM95myrIyyKwdnzjdFjLiE6KBPVtJIg== + dependencies: + asn1 "~0.2.3" + assert-plus "^1.0.0" + bcrypt-pbkdf "^1.0.0" + dashdash "^1.12.0" + ecc-jsbn "~0.1.1" + getpass "^0.1.1" + jsbn "~0.1.0" + safer-buffer "^2.0.2" + tweetnacl "~0.14.0" + +tough-cookie@~2.5.0: + version "2.5.0" + resolved "https://registry.yarnpkg.com/tough-cookie/-/tough-cookie-2.5.0.tgz#cd9fb2a0aa1d5a12b473bd9fb96fa3dcff65ade2" + integrity sha512-nlLsUzgm1kfLXSXfRZMc1KLAugd4hqJHDTvc2hDIwS3mZAfMEuMbc03SujMF+GEcpaX/qboeycw6iO8JwVv2+g== + dependencies: + psl "^1.1.28" + punycode "^2.1.1" + +tunnel-agent@^0.6.0: + version "0.6.0" + resolved "https://registry.yarnpkg.com/tunnel-agent/-/tunnel-agent-0.6.0.tgz#27a5dea06b36b04a0a9966774b290868f0fc40fd" + integrity sha1-J6XeoGs2sEoKmWZ3SykIaPD8QP0= + dependencies: + safe-buffer "^5.0.1" + +tweetnacl@^0.14.3, tweetnacl@~0.14.0: + version "0.14.5" + resolved "https://registry.yarnpkg.com/tweetnacl/-/tweetnacl-0.14.5.tgz#5ae68177f192d4456269d108afa93ff8743f4f64" + integrity sha1-WuaBd/GS1EViadEIr6k/+HQ/T2Q= + +uri-js@^4.2.2: + version "4.2.2" + resolved "https://registry.yarnpkg.com/uri-js/-/uri-js-4.2.2.tgz#94c540e1ff772956e2299507c010aea6c8838eb0" + integrity sha512-KY9Frmirql91X2Qgjry0Wd4Y+YTdrdZheS8TFwvkbLWf/G5KNJDCh6pKL5OZctEW4+0Baa5idK2ZQuELRwPznQ== + dependencies: + punycode "^2.1.0" + +uuid@^3.3.2: + version "3.4.0" + resolved "https://registry.yarnpkg.com/uuid/-/uuid-3.4.0.tgz#b23e4358afa8a202fe7a100af1f5f883f02007ee" + integrity sha512-HjSDRw6gZE5JMggctHBcjVak08+KEVhSIiDzFnT9S9aegmp85S/bReBVTb4QTFaRNptJ9kuYaNhnbNEOkbKb/A== + +verror@1.10.0: + version "1.10.0" + resolved "https://registry.yarnpkg.com/verror/-/verror-1.10.0.tgz#3a105ca17053af55d6e270c1f8288682e18da400" + integrity sha1-OhBcoXBTr1XW4nDB+CiGguGNpAA= + dependencies: + assert-plus "^1.0.0" + core-util-is "1.0.2" + extsprintf "^1.2.0" + +webidl-conversions@^7.0.0: + version "7.0.0" + resolved "https://registry.yarnpkg.com/webidl-conversions/-/webidl-conversions-7.0.0.tgz#256b4e1882be7debbf01d05f0aa2039778ea080a" + integrity sha512-VwddBukDzu71offAQR975unBIGqfKZpM+8ZX6ySk8nYhVoo5CYaZyzt3YBvYtRtO+aoGlqxPg/B87NGVZ/fu6g== diff --git a/vendor/github.com/wlynxg/anet/.gitignore b/vendor/github.com/wlynxg/anet/.gitignore new file mode 100644 index 0000000..723ef36 --- /dev/null +++ b/vendor/github.com/wlynxg/anet/.gitignore @@ -0,0 +1 @@ +.idea \ No newline at end of file diff --git a/vendor/github.com/wlynxg/anet/LICENSE b/vendor/github.com/wlynxg/anet/LICENSE new file mode 100644 index 0000000..db6fa36 --- /dev/null +++ b/vendor/github.com/wlynxg/anet/LICENSE @@ -0,0 +1,28 @@ +BSD 3-Clause License + +Copyright (c) 2023, wlynxg + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/github.com/wlynxg/anet/README.md b/vendor/github.com/wlynxg/anet/README.md new file mode 100644 index 0000000..93fe68b --- /dev/null +++ b/vendor/github.com/wlynxg/anet/README.md @@ -0,0 +1,126 @@ +## Introduction +In response to the modifications made to the permissions for accessing system MAC addresses in Android 11, ordinary applications encounter several main issues when using NETLINK sockets: + +- Not allowing bind operations on `NETLINK` sockets. +- Not permitting the use of the `RTM_GETLINK` functionality. + +For detailed information, please refer to: https://developer.android.com/training/articles/user-data-ids#mac-11-plus + +As a result of the aforementioned reasons, using `net.Interfaces()` and `net.InterfaceAddrs()` from the Go net package in the Android environment leads to the `route ip+net: netlinkrib: permission denied` error. + +You can find specific issue details here: https://github.com/golang/go/issues/40569 + +To address the issue of using the Go net package in the Android environment, we have made partial modifications to its source code to ensure proper functionality on Android. + +I have fully resolved the issues with `net.InterfaceAddrs()`. + +However, for `net.Interfaces()`, we have only addressed some problems, as the following issues still remain: +- It can only return interfaces with IP addresses. +- It cannot return hardware MAC addresses. + +Nevertheless, the fixed `net.Interfaces()` function now aligns with the Android API's `NetworkInterface.getNetworkInterfaces()` and can be used normally in most scenarios. + +The specific fix logic includes: + +Removing the `Bind()` operation on `Netlink` sockets in the `NetlinkRIB()` function. +Using `ioctl` based on the Index number returned by `RTM_GETADDR` to retrieve the network card's name, MTU, and flags. + +There are two implementations of the `net` package: one from the [Go standard library](https://pkg.go.dev/net) and another from the [golang.org/x/net](https://pkg.go.dev/golang.org/x/net) module. Both of these implementations have the same issues in the Android environment. The `anet` package should be compatible with both of them. + +## Test Code +### net.Interface() +use `net.Interface()`: +```go +func RawInterface() { + interfaces, err := net.Interfaces() + if err != nil { + panic(err) + } + + for _, i := range interfaces { + log.Println(i) + } +} +``` +result: +``` +panic: route ip+net: netlinkrib: permission denied +``` + +use `anet.Interface()`: +```go +func AnetInterface() { + interfaces, err := anet.Interfaces() + if err != nil { + panic(err) + } + + for _, i := range interfaces { + log.Println(i) + } +} +``` + +result: +``` +{1 65536 lo up|loopback|running} +{15 1400 rmnet_data1 up|running} +{24 1500 wlan0 up|broadcast|multicast|running} +{3 1500 dummy0 up|broadcast|running} +{4 1500 ifb0 up|broadcast|running} +{5 1500 ifb1 up|broadcast|running} +{12 1500 ifb2 up|broadcast|running} +{14 1500 rmnet_data0 up|running} +{16 1400 rmnet_data2 up|running} +{17 1400 rmnet_data3 up|running} +``` + +### net.InterfaceAddrs() +use `net.InterfaceAddrs()`: +```go +func NetInterfaceAddrs() { + addrs, err := net.InterfaceAddrs() + if err != nil { + panic(err) + } + + for _, addr := range addrs { + log.Println(addr) + } +} +``` +result: +``` +panic: route ip+net: netlinkrib: permission denied +``` + +use `anet.InterfaceAddrs()`: +```go +func AnetInterfaceAddrs() { + addrs, err := anet.InterfaceAddrs() + if err != nil { + panic(err) + } + + for _, addr := range addrs { + log.Println(addr) + } +} +``` +result: +``` +127.0.0.1/8 +::1/128 +... +192.168.6.143/24 +fe80::7e4f:4446:eb3:1eb8/64 +``` + +## Other issues due to #40569 +- https://github.com/golang/go/issues/68082 + +## How to build with Go 1.23.0 or later +The `anet` library internally relies on `//go:linkname` directive. Since the usage of `//go:linkname` has been restricted since Go 1.23.0 ([Go 1.23 Release Notes](https://tip.golang.org/doc/go1.23#linker)), it is necessary to specify the `-checklinkname=0` linker flag when building the `anet` package with Go 1.23.0 or later. Without this flag, the following linker error will occur: +``` +link: github.com/wlynxg/anet: invalid reference to net.zoneCache +``` diff --git a/vendor/github.com/wlynxg/anet/README_zh.md b/vendor/github.com/wlynxg/anet/README_zh.md new file mode 100644 index 0000000..07e9a3c --- /dev/null +++ b/vendor/github.com/wlynxg/anet/README_zh.md @@ -0,0 +1,25 @@ +针对Android 11之后对访问系统MAC地址的权限进行了修改的问题,导致普通应用在调用`NETLINK`套接字时会遇到以下几个主要问题: +- 不允许对`NETLINK`套接字进行`bind`操作。 +- 不允许调用`RTM_GETLINK`功能。 + +详细说明可以在此链接找到:https://developer.android.com/training/articles/user-data-ids#mac-11-plus + +由于上述两个原因,导致在安卓环境下使用Go net包中的`net.Interfaces()`和`net.InterfaceAddrs()`时会抛出`route ip+net: netlinkrib: permission denied`错误。 +具体 issue 可见:https://github.com/golang/go/issues/40569 + +为了解决在安卓环境下使用Go net包的问题,我们对其源代码进行了部分改造,以使其能够在Android上正常工作。 + +对于`net.InterfaceAddrs()`,我已经完全解决了其中的问题; +对于`net.Interfaces()`,我只解决了部分问题,目前仍存在以下问题: +- 只能返回具有IP地址的接口。 +- 不能返回硬件的MAC地址。 + +但是修复后的`net.Interfaces()`函数现在与Android API的`NetworkInterface.getNetworkInterfaces()`保持一致,在大多数情况下可正常使用。 + +具体修复逻辑包括: + +- 取消了`NetlinkRIB()`函数中对`Netlink`套接字的`Bind()`操作。 +- 根据`RTM_GETADDR`返回的Index号,使用`ioctl`获取其网卡的名称、MTU和标志位。 + +## 由于 #40569 导致的其他问题 +- #[68082](https://github.com/golang/go/issues/68082) \ No newline at end of file diff --git a/vendor/github.com/wlynxg/anet/android_api_level.go b/vendor/github.com/wlynxg/anet/android_api_level.go new file mode 100644 index 0000000..a0a93e8 --- /dev/null +++ b/vendor/github.com/wlynxg/anet/android_api_level.go @@ -0,0 +1,7 @@ +//go:build !(android && cgo) + +package anet + +func androidDeviceApiLevel() int { + return -1 +} diff --git a/vendor/github.com/wlynxg/anet/android_api_level_cgo.go b/vendor/github.com/wlynxg/anet/android_api_level_cgo.go new file mode 100644 index 0000000..992feee --- /dev/null +++ b/vendor/github.com/wlynxg/anet/android_api_level_cgo.go @@ -0,0 +1,23 @@ +//go:build android && cgo + +package anet + +// #include +import "C" + +import "sync" + +var ( + apiLevel int + once sync.Once +) + +// Returns the API level of the device we're actually running on, or -1 on failure. +// The returned value is equivalent to the Java Build.VERSION.SDK_INT API. +func androidDeviceApiLevel() int { + once.Do(func() { + apiLevel = int(C.android_get_device_api_level()) + }) + + return apiLevel +} diff --git a/vendor/github.com/wlynxg/anet/interface.go b/vendor/github.com/wlynxg/anet/interface.go new file mode 100644 index 0000000..384f29c --- /dev/null +++ b/vendor/github.com/wlynxg/anet/interface.go @@ -0,0 +1,30 @@ +//go:build !android +// +build !android + +package anet + +import ( + "net" +) + +// Interfaces returns a list of the system's network interfaces. +func Interfaces() ([]net.Interface, error) { + return net.Interfaces() +} + +// InterfaceAddrs returns a list of the system's unicast interface +// addresses. +// +// The returned list does not identify the associated interface; use +// Interfaces and Interface.Addrs for more detail. +func InterfaceAddrs() ([]net.Addr, error) { + return net.InterfaceAddrs() +} + +// InterfaceAddrsByInterface returns a list of the system's unicast +// interface addresses by specific interface. +func InterfaceAddrsByInterface(ifi *net.Interface) ([]net.Addr, error) { + return ifi.Addrs() +} + +func SetAndroidVersion(version uint) {} diff --git a/vendor/github.com/wlynxg/anet/interface_android.go b/vendor/github.com/wlynxg/anet/interface_android.go new file mode 100644 index 0000000..6b080c5 --- /dev/null +++ b/vendor/github.com/wlynxg/anet/interface_android.go @@ -0,0 +1,442 @@ +package anet + +import ( + "bytes" + "errors" + "net" + "os" + "sync" + "syscall" + "time" + "unsafe" +) + +const ( + android11ApiLevel = 30 +) + +var ( + customAndroidApiLevel = -1 + errInvalidInterface = errors.New("invalid network interface") + errInvalidInterfaceIndex = errors.New("invalid network interface index") + errInvalidInterfaceName = errors.New("invalid network interface name") + errNoSuchInterface = errors.New("no such network interface") + errNoSuchMulticastInterface = errors.New("no such multicast network interface") +) + +type ifReq [40]byte + +// Interfaces returns a list of the system's network interfaces. +func Interfaces() ([]net.Interface, error) { + if androidApiLevel() < android11ApiLevel { + return net.Interfaces() + } + + ift, err := interfaceTable(0) + if err != nil { + return nil, &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: nil, Err: err} + } + if len(ift) != 0 { + zoneCache.update(ift, true) + zoneCacheX.update(ift, true) + } + return ift, nil +} + +// InterfaceAddrs returns a list of the system's unicast interface +// addresses. +// +// The returned list does not identify the associated interface; use +// Interfaces and Interface.Addrs for more detail. +func InterfaceAddrs() ([]net.Addr, error) { + if androidApiLevel() < android11ApiLevel { + return net.InterfaceAddrs() + } + + ifat, err := interfaceAddrTable(nil) + if err != nil { + err = &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: nil, Err: err} + } + return ifat, err +} + +// InterfaceByIndex returns the interface specified by index. +// +// On Solaris, it returns one of the logical network interfaces +// sharing the logical data link; for more precision use +// InterfaceByName. +func InterfaceByIndex(index int) (*net.Interface, error) { + if androidApiLevel() < android11ApiLevel { + return net.InterfaceByIndex(index) + } + + if index <= 0 { + return nil, &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: nil, Err: errInvalidInterfaceIndex} + } + ift, err := interfaceTable(index) + if err != nil { + return nil, &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: nil, Err: err} + } + ifi, err := interfaceByIndex(ift, index) + if err != nil { + err = &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: nil, Err: err} + } + return ifi, err +} + +// InterfaceByName returns the interface specified by name. +func InterfaceByName(name string) (*net.Interface, error) { + if name == "" { + return nil, &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: nil, Err: errInvalidInterfaceName} + } + ift, err := interfaceTable(0) + if err != nil { + return nil, &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: nil, Err: err} + } + if len(ift) != 0 { + zoneCache.update(ift, true) + zoneCacheX.update(ift, true) + } + for _, ifi := range ift { + if name == ifi.Name { + return &ifi, nil + } + } + return nil, &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: nil, Err: errNoSuchInterface} +} + +// InterfaceAddrsByInterface returns a list of the system's unicast +// interface addresses by specific interface. +func InterfaceAddrsByInterface(ifi *net.Interface) ([]net.Addr, error) { + if ifi == nil { + return nil, &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: nil, Err: errInvalidInterface} + } + + if androidApiLevel() < android11ApiLevel { + return ifi.Addrs() + } + + ifat, err := interfaceAddrTable(ifi) + if err != nil { + err = &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: nil, Err: err} + } + return ifat, err +} + +// SetAndroidVersion set the Android environment in which the program runs. +// The Android system version number can be obtained through +// `android.os.Build.VERSION.RELEASE` of the Android framework. +// If version is 0 the actual version will be detected automatically if possible. +func SetAndroidVersion(version uint) { + switch { + case version == 0: + customAndroidApiLevel = -1 + case version >= 11: + customAndroidApiLevel = android11ApiLevel + default: + customAndroidApiLevel = 0 + } +} + +func androidApiLevel() int { + if customAndroidApiLevel != -1 { + // user-provided api level should be used + return customAndroidApiLevel + } + + // try to autodetect api level + return androidDeviceApiLevel() +} + +// An ipv6ZoneCache represents a cache holding partial network +// interface information. It is used for reducing the cost of IPv6 +// addressing scope zone resolution. +// +// Multiple names sharing the index are managed by first-come +// first-served basis for consistency. +type ipv6ZoneCache struct { + sync.RWMutex // guard the following + lastFetched time.Time // last time routing information was fetched + toIndex map[string]int // interface name to its index + toName map[int]string // interface index to its name +} + +//go:linkname zoneCache net.zoneCache +var zoneCache ipv6ZoneCache + +//go:linkname zoneCacheX golang.org/x/net/internal/socket.zoneCache +var zoneCacheX ipv6ZoneCache + +// update refreshes the network interface information if the cache was last +// updated more than 1 minute ago, or if force is set. It reports whether the +// cache was updated. +func (zc *ipv6ZoneCache) update(ift []net.Interface, force bool) (updated bool) { + zc.Lock() + defer zc.Unlock() + now := time.Now() + if !force && zc.lastFetched.After(now.Add(-60*time.Second)) { + return false + } + zc.lastFetched = now + if len(ift) == 0 { + var err error + if ift, err = interfaceTable(0); err != nil { + return false + } + } + zc.toIndex = make(map[string]int, len(ift)) + zc.toName = make(map[int]string, len(ift)) + for _, ifi := range ift { + zc.toIndex[ifi.Name] = ifi.Index + if _, ok := zc.toName[ifi.Index]; !ok { + zc.toName[ifi.Index] = ifi.Name + } + } + return true +} + +// If the ifindex is zero, interfaceTable returns mappings of all +// network interfaces. Otherwise it returns a mapping of a specific +// interface. +func interfaceTable(ifindex int) ([]net.Interface, error) { + tab, err := NetlinkRIB(syscall.RTM_GETADDR, syscall.AF_UNSPEC) + if err != nil { + return nil, os.NewSyscallError("netlinkrib", err) + } + msgs, err := syscall.ParseNetlinkMessage(tab) + if err != nil { + return nil, os.NewSyscallError("parsenetlinkmessage", err) + } + + var ift []net.Interface + im := make(map[uint32]struct{}) +loop: + for _, m := range msgs { + switch m.Header.Type { + case syscall.NLMSG_DONE: + break loop + case syscall.RTM_NEWADDR: + ifam := (*syscall.IfAddrmsg)(unsafe.Pointer(&m.Data[0])) + if _, ok := im[ifam.Index]; ok { + continue + } else { + im[ifam.Index] = struct{}{} + } + + if ifindex == 0 || ifindex == int(ifam.Index) { + ifi := newLink(ifam) + if ifi != nil { + ift = append(ift, *ifi) + } + if ifindex == int(ifam.Index) { + break loop + } + } + } + } + + return ift, nil +} + +func newLink(ifam *syscall.IfAddrmsg) *net.Interface { + ift := &net.Interface{Index: int(ifam.Index)} + + name, err := indexToName(ifam.Index) + if err != nil { + return nil + } + ift.Name = name + + mtu, err := nameToMTU(name) + if err != nil { + return nil + } + ift.MTU = mtu + + flags, err := nameToFlags(name) + if err != nil { + return nil + } + ift.Flags = flags + return ift +} + +func linkFlags(rawFlags uint32) net.Flags { + var f net.Flags + if rawFlags&syscall.IFF_UP != 0 { + f |= net.FlagUp + } + if rawFlags&syscall.IFF_RUNNING != 0 { + f |= net.FlagRunning + } + if rawFlags&syscall.IFF_BROADCAST != 0 { + f |= net.FlagBroadcast + } + if rawFlags&syscall.IFF_LOOPBACK != 0 { + f |= net.FlagLoopback + } + if rawFlags&syscall.IFF_POINTOPOINT != 0 { + f |= net.FlagPointToPoint + } + if rawFlags&syscall.IFF_MULTICAST != 0 { + f |= net.FlagMulticast + } + return f +} + +// If the ifi is nil, interfaceAddrTable returns addresses for all +// network interfaces. Otherwise it returns addresses for a specific +// interface. +func interfaceAddrTable(ifi *net.Interface) ([]net.Addr, error) { + tab, err := NetlinkRIB(syscall.RTM_GETADDR, syscall.AF_UNSPEC) + if err != nil { + return nil, os.NewSyscallError("netlinkrib", err) + } + msgs, err := syscall.ParseNetlinkMessage(tab) + if err != nil { + return nil, os.NewSyscallError("parsenetlinkmessage", err) + } + + var ift []net.Interface + if ifi == nil { + var err error + ift, err = interfaceTable(0) + if err != nil { + return nil, err + } + } + ifat, err := addrTable(ift, ifi, msgs) + if err != nil { + return nil, err + } + return ifat, nil +} + +func addrTable(ift []net.Interface, ifi *net.Interface, msgs []syscall.NetlinkMessage) ([]net.Addr, error) { + var ifat []net.Addr +loop: + for _, m := range msgs { + switch m.Header.Type { + case syscall.NLMSG_DONE: + break loop + case syscall.RTM_NEWADDR: + ifam := (*syscall.IfAddrmsg)(unsafe.Pointer(&m.Data[0])) + if len(ift) != 0 || ifi.Index == int(ifam.Index) { + attrs, err := syscall.ParseNetlinkRouteAttr(&m) + if err != nil { + return nil, os.NewSyscallError("parsenetlinkrouteattr", err) + } + ifa := newAddr(ifam, attrs) + if ifa != nil { + ifat = append(ifat, ifa) + } + } + } + } + return ifat, nil +} + +func newAddr(ifam *syscall.IfAddrmsg, attrs []syscall.NetlinkRouteAttr) net.Addr { + var ipPointToPoint bool + // Seems like we need to make sure whether the IP interface + // stack consists of IP point-to-point numbered or unnumbered + // addressing. + for _, a := range attrs { + if a.Attr.Type == syscall.IFA_LOCAL { + ipPointToPoint = true + break + } + } + for _, a := range attrs { + if ipPointToPoint && a.Attr.Type == syscall.IFA_ADDRESS { + continue + } + switch ifam.Family { + case syscall.AF_INET: + return &net.IPNet{IP: net.IPv4(a.Value[0], a.Value[1], a.Value[2], a.Value[3]), Mask: net.CIDRMask(int(ifam.Prefixlen), 8*net.IPv4len)} + case syscall.AF_INET6: + ifa := &net.IPNet{IP: make(net.IP, net.IPv6len), Mask: net.CIDRMask(int(ifam.Prefixlen), 8*net.IPv6len)} + copy(ifa.IP, a.Value[:]) + return ifa + } + } + return nil +} + +func interfaceByIndex(ift []net.Interface, index int) (*net.Interface, error) { + for _, ifi := range ift { + if index == ifi.Index { + return &ifi, nil + } + } + return nil, errNoSuchInterface +} + +func ioctl(fd int, req uint, arg unsafe.Pointer) error { + _, _, e1 := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), uintptr(req), uintptr(arg)) + if e1 != 0 { + return e1 + } + return nil +} + +func indexToName(index uint32) (string, error) { + fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM|syscall.SOCK_CLOEXEC, 0) + if err != nil { + return "", err + } + defer syscall.Close(fd) + + var ifr ifReq + *(*uint32)(unsafe.Pointer(&ifr[syscall.IFNAMSIZ])) = index + err = ioctl(fd, syscall.SIOCGIFNAME, unsafe.Pointer(&ifr[0])) + if err != nil { + return "", err + } + + return string(bytes.Trim(ifr[:syscall.IFNAMSIZ], "\x00")), nil +} + +func nameToMTU(name string) (int, error) { + // Leave room for terminating NULL byte. + if len(name) >= syscall.IFNAMSIZ { + return -1, syscall.EINVAL + } + + fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM|syscall.SOCK_CLOEXEC, 0) + if err != nil { + return -1, err + } + defer syscall.Close(fd) + + var ifr ifReq + copy(ifr[:], name) + err = ioctl(fd, syscall.SIOCGIFMTU, unsafe.Pointer(&ifr[0])) + if err != nil { + return -1, err + } + + return int(*(*int32)(unsafe.Pointer(&ifr[syscall.IFNAMSIZ]))), nil +} + +func nameToFlags(name string) (net.Flags, error) { + // Leave room for terminating NULL byte. + if len(name) >= syscall.IFNAMSIZ { + return 0, syscall.EINVAL + } + + fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM|syscall.SOCK_CLOEXEC, 0) + if err != nil { + return 0, err + } + defer syscall.Close(fd) + + var ifr ifReq + copy(ifr[:], name) + err = ioctl(fd, syscall.SIOCGIFFLAGS, unsafe.Pointer(&ifr[0])) + if err != nil { + return 0, err + } + + return linkFlags(*(*uint32)(unsafe.Pointer(&ifr[syscall.IFNAMSIZ]))), nil +} diff --git a/vendor/github.com/wlynxg/anet/netlink_android.go b/vendor/github.com/wlynxg/anet/netlink_android.go new file mode 100644 index 0000000..fc0d84d --- /dev/null +++ b/vendor/github.com/wlynxg/anet/netlink_android.go @@ -0,0 +1,179 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Netlink sockets and messages + +package anet + +import ( + "syscall" + "unsafe" +) + +// Round the length of a netlink message up to align it properly. +func nlmAlignOf(msglen int) int { + return (msglen + syscall.NLMSG_ALIGNTO - 1) & ^(syscall.NLMSG_ALIGNTO - 1) +} + +// Round the length of a netlink route attribute up to align it +// properly. +func rtaAlignOf(attrlen int) int { + return (attrlen + syscall.RTA_ALIGNTO - 1) & ^(syscall.RTA_ALIGNTO - 1) +} + +// NetlinkRouteRequest represents a request message to receive routing +// and link states from the kernel. +type NetlinkRouteRequest struct { + Header syscall.NlMsghdr + Data syscall.RtGenmsg +} + +func (rr *NetlinkRouteRequest) toWireFormat() []byte { + b := make([]byte, rr.Header.Len) + *(*uint32)(unsafe.Pointer(&b[0:4][0])) = rr.Header.Len + *(*uint16)(unsafe.Pointer(&b[4:6][0])) = rr.Header.Type + *(*uint16)(unsafe.Pointer(&b[6:8][0])) = rr.Header.Flags + *(*uint32)(unsafe.Pointer(&b[8:12][0])) = rr.Header.Seq + *(*uint32)(unsafe.Pointer(&b[12:16][0])) = rr.Header.Pid + b[16] = byte(rr.Data.Family) + return b +} + +func newNetlinkRouteRequest(proto, seq, family int) []byte { + rr := &NetlinkRouteRequest{} + rr.Header.Len = uint32(syscall.NLMSG_HDRLEN + syscall.SizeofRtGenmsg) + rr.Header.Type = uint16(proto) + rr.Header.Flags = syscall.NLM_F_DUMP | syscall.NLM_F_REQUEST + rr.Header.Seq = uint32(seq) + rr.Data.Family = uint8(family) + return rr.toWireFormat() +} + +// NetlinkRIB returns routing information base, as known as RIB, which +// consists of network facility information, states and parameters. +func NetlinkRIB(proto, family int) ([]byte, error) { + s, err := syscall.Socket(syscall.AF_NETLINK, syscall.SOCK_RAW|syscall.SOCK_CLOEXEC, syscall.NETLINK_ROUTE) + if err != nil { + return nil, err + } + defer syscall.Close(s) + sa := &syscall.SockaddrNetlink{Family: syscall.AF_NETLINK} + + wb := newNetlinkRouteRequest(proto, 1, family) + if err := syscall.Sendto(s, wb, 0, sa); err != nil { + return nil, err + } + lsa, err := syscall.Getsockname(s) + if err != nil { + return nil, err + } + lsanl, ok := lsa.(*syscall.SockaddrNetlink) + if !ok { + return nil, syscall.EINVAL + } + var tab []byte + rbNew := make([]byte, syscall.Getpagesize()) +done: + for { + rb := rbNew + nr, _, err := syscall.Recvfrom(s, rb, 0) + if err != nil { + return nil, err + } + if nr < syscall.NLMSG_HDRLEN { + return nil, syscall.EINVAL + } + rb = rb[:nr] + tab = append(tab, rb...) + msgs, err := ParseNetlinkMessage(rb) + if err != nil { + return nil, err + } + for _, m := range msgs { + if m.Header.Seq != 1 || m.Header.Pid != lsanl.Pid { + return nil, syscall.EINVAL + } + if m.Header.Type == syscall.NLMSG_DONE { + break done + } + if m.Header.Type == syscall.NLMSG_ERROR { + return nil, syscall.EINVAL + } + } + } + return tab, nil +} + +// NetlinkMessage represents a netlink message. +type NetlinkMessage struct { + Header syscall.NlMsghdr + Data []byte +} + +// ParseNetlinkMessage parses b as an array of netlink messages and +// returns the slice containing the NetlinkMessage structures. +func ParseNetlinkMessage(b []byte) ([]NetlinkMessage, error) { + var msgs []NetlinkMessage + for len(b) >= syscall.NLMSG_HDRLEN { + h, dbuf, dlen, err := netlinkMessageHeaderAndData(b) + if err != nil { + return nil, err + } + m := NetlinkMessage{Header: *h, Data: dbuf[:int(h.Len)-syscall.NLMSG_HDRLEN]} + msgs = append(msgs, m) + b = b[dlen:] + } + return msgs, nil +} + +func netlinkMessageHeaderAndData(b []byte) (*syscall.NlMsghdr, []byte, int, error) { + h := (*syscall.NlMsghdr)(unsafe.Pointer(&b[0])) + l := nlmAlignOf(int(h.Len)) + if int(h.Len) < syscall.NLMSG_HDRLEN || l > len(b) { + return nil, nil, 0, syscall.EINVAL + } + return h, b[syscall.NLMSG_HDRLEN:], l, nil +} + +// NetlinkRouteAttr represents a netlink route attribute. +type NetlinkRouteAttr struct { + Attr syscall.RtAttr + Value []byte +} + +// ParseNetlinkRouteAttr parses m's payload as an array of netlink +// route attributes and returns the slice containing the +// NetlinkRouteAttr structures. +func ParseNetlinkRouteAttr(m *NetlinkMessage) ([]NetlinkRouteAttr, error) { + var b []byte + switch m.Header.Type { + case syscall.RTM_NEWLINK, syscall.RTM_DELLINK: + b = m.Data[syscall.SizeofIfInfomsg:] + case syscall.RTM_NEWADDR, syscall.RTM_DELADDR: + b = m.Data[syscall.SizeofIfAddrmsg:] + case syscall.RTM_NEWROUTE, syscall.RTM_DELROUTE: + b = m.Data[syscall.SizeofRtMsg:] + default: + return nil, syscall.EINVAL + } + var attrs []NetlinkRouteAttr + for len(b) >= syscall.SizeofRtAttr { + a, vbuf, alen, err := netlinkRouteAttrAndValue(b) + if err != nil { + return nil, err + } + ra := NetlinkRouteAttr{Attr: *a, Value: vbuf[:int(a.Len)-syscall.SizeofRtAttr]} + attrs = append(attrs, ra) + b = b[alen:] + } + return attrs, nil +} + +func netlinkRouteAttrAndValue(b []byte) (*syscall.RtAttr, []byte, int, error) { + a := (*syscall.RtAttr)(unsafe.Pointer(&b[0])) + if int(a.Len) < syscall.SizeofRtAttr || int(a.Len) > len(b) { + return nil, nil, 0, syscall.EINVAL + } + return a, b[syscall.SizeofRtAttr:], rtaAlignOf(int(a.Len)), nil +} diff --git a/vendor/golang.org/x/net/dns/dnsmessage/message.go b/vendor/golang.org/x/net/dns/dnsmessage/message.go new file mode 100644 index 0000000..7a978b4 --- /dev/null +++ b/vendor/golang.org/x/net/dns/dnsmessage/message.go @@ -0,0 +1,2741 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package dnsmessage provides a mostly RFC 1035 compliant implementation of +// DNS message packing and unpacking. +// +// The package also supports messages with Extension Mechanisms for DNS +// (EDNS(0)) as defined in RFC 6891. +// +// This implementation is designed to minimize heap allocations and avoid +// unnecessary packing and unpacking as much as possible. +package dnsmessage + +import ( + "errors" +) + +// Message formats +// +// To add a new Resource Record type: +// 1. Create Resource Record types +// 1.1. Add a Type constant named "Type" +// 1.2. Add the corresponding entry to the typeNames map +// 1.3. Add a [ResourceBody] implementation named "Resource" +// 2. Implement packing +// 2.1. Implement Builder.Resource() +// 3. Implement unpacking +// 3.1. Add the unpacking code to unpackResourceBody() +// 3.2. Implement Parser.Resource() + +// A Type is the type of a DNS Resource Record, as defined in the [IANA registry]. +// +// [IANA registry]: https://www.iana.org/assignments/dns-parameters/dns-parameters.xhtml#dns-parameters-4 +type Type uint16 + +const ( + // ResourceHeader.Type and Question.Type + TypeA Type = 1 + TypeNS Type = 2 + TypeCNAME Type = 5 + TypeSOA Type = 6 + TypePTR Type = 12 + TypeMX Type = 15 + TypeTXT Type = 16 + TypeAAAA Type = 28 + TypeSRV Type = 33 + TypeOPT Type = 41 + TypeSVCB Type = 64 + TypeHTTPS Type = 65 + + // Question.Type + TypeWKS Type = 11 + TypeHINFO Type = 13 + TypeMINFO Type = 14 + TypeAXFR Type = 252 + TypeALL Type = 255 +) + +var typeNames = map[Type]string{ + TypeA: "TypeA", + TypeNS: "TypeNS", + TypeCNAME: "TypeCNAME", + TypeSOA: "TypeSOA", + TypePTR: "TypePTR", + TypeMX: "TypeMX", + TypeTXT: "TypeTXT", + TypeAAAA: "TypeAAAA", + TypeSRV: "TypeSRV", + TypeOPT: "TypeOPT", + TypeSVCB: "TypeSVCB", + TypeHTTPS: "TypeHTTPS", + TypeWKS: "TypeWKS", + TypeHINFO: "TypeHINFO", + TypeMINFO: "TypeMINFO", + TypeAXFR: "TypeAXFR", + TypeALL: "TypeALL", +} + +// String implements fmt.Stringer.String. +func (t Type) String() string { + if n, ok := typeNames[t]; ok { + return n + } + return printUint16(uint16(t)) +} + +// GoString implements fmt.GoStringer.GoString. +func (t Type) GoString() string { + if n, ok := typeNames[t]; ok { + return "dnsmessage." + n + } + return printUint16(uint16(t)) +} + +// A Class is a type of network. +type Class uint16 + +const ( + // ResourceHeader.Class and Question.Class + ClassINET Class = 1 + ClassCSNET Class = 2 + ClassCHAOS Class = 3 + ClassHESIOD Class = 4 + + // Question.Class + ClassANY Class = 255 +) + +var classNames = map[Class]string{ + ClassINET: "ClassINET", + ClassCSNET: "ClassCSNET", + ClassCHAOS: "ClassCHAOS", + ClassHESIOD: "ClassHESIOD", + ClassANY: "ClassANY", +} + +// String implements fmt.Stringer.String. +func (c Class) String() string { + if n, ok := classNames[c]; ok { + return n + } + return printUint16(uint16(c)) +} + +// GoString implements fmt.GoStringer.GoString. +func (c Class) GoString() string { + if n, ok := classNames[c]; ok { + return "dnsmessage." + n + } + return printUint16(uint16(c)) +} + +// An OpCode is a DNS operation code. +type OpCode uint16 + +// GoString implements fmt.GoStringer.GoString. +func (o OpCode) GoString() string { + return printUint16(uint16(o)) +} + +// An RCode is a DNS response status code. +type RCode uint16 + +// Header.RCode values. +const ( + RCodeSuccess RCode = 0 // NoError + RCodeFormatError RCode = 1 // FormErr + RCodeServerFailure RCode = 2 // ServFail + RCodeNameError RCode = 3 // NXDomain + RCodeNotImplemented RCode = 4 // NotImp + RCodeRefused RCode = 5 // Refused +) + +var rCodeNames = map[RCode]string{ + RCodeSuccess: "RCodeSuccess", + RCodeFormatError: "RCodeFormatError", + RCodeServerFailure: "RCodeServerFailure", + RCodeNameError: "RCodeNameError", + RCodeNotImplemented: "RCodeNotImplemented", + RCodeRefused: "RCodeRefused", +} + +// String implements fmt.Stringer.String. +func (r RCode) String() string { + if n, ok := rCodeNames[r]; ok { + return n + } + return printUint16(uint16(r)) +} + +// GoString implements fmt.GoStringer.GoString. +func (r RCode) GoString() string { + if n, ok := rCodeNames[r]; ok { + return "dnsmessage." + n + } + return printUint16(uint16(r)) +} + +func printPaddedUint8(i uint8) string { + b := byte(i) + return string([]byte{ + b/100 + '0', + b/10%10 + '0', + b%10 + '0', + }) +} + +func printUint8Bytes(buf []byte, i uint8) []byte { + b := byte(i) + if i >= 100 { + buf = append(buf, b/100+'0') + } + if i >= 10 { + buf = append(buf, b/10%10+'0') + } + return append(buf, b%10+'0') +} + +func printByteSlice(b []byte) string { + if len(b) == 0 { + return "" + } + buf := make([]byte, 0, 5*len(b)) + buf = printUint8Bytes(buf, uint8(b[0])) + for _, n := range b[1:] { + buf = append(buf, ',', ' ') + buf = printUint8Bytes(buf, uint8(n)) + } + return string(buf) +} + +const hexDigits = "0123456789abcdef" + +func printString(str []byte) string { + buf := make([]byte, 0, len(str)) + for i := 0; i < len(str); i++ { + c := str[i] + if c == '.' || c == '-' || c == ' ' || + 'A' <= c && c <= 'Z' || + 'a' <= c && c <= 'z' || + '0' <= c && c <= '9' { + buf = append(buf, c) + continue + } + + upper := c >> 4 + lower := (c << 4) >> 4 + buf = append( + buf, + '\\', + 'x', + hexDigits[upper], + hexDigits[lower], + ) + } + return string(buf) +} + +func printUint16(i uint16) string { + return printUint32(uint32(i)) +} + +func printUint32(i uint32) string { + // Max value is 4294967295. + buf := make([]byte, 10) + for b, d := buf, uint32(1000000000); d > 0; d /= 10 { + b[0] = byte(i/d%10 + '0') + if b[0] == '0' && len(b) == len(buf) && len(buf) > 1 { + buf = buf[1:] + } + b = b[1:] + i %= d + } + return string(buf) +} + +func printBool(b bool) string { + if b { + return "true" + } + return "false" +} + +var ( + // ErrNotStarted indicates that the prerequisite information isn't + // available yet because the previous records haven't been appropriately + // parsed, skipped or finished. + ErrNotStarted = errors.New("parsing/packing of this type isn't available yet") + + // ErrSectionDone indicated that all records in the section have been + // parsed or finished. + ErrSectionDone = errors.New("parsing/packing of this section has completed") + + errBaseLen = errors.New("insufficient data for base length type") + errCalcLen = errors.New("insufficient data for calculated length type") + errReserved = errors.New("segment prefix is reserved") + errTooManyPtr = errors.New("too many pointers (>10)") + errInvalidPtr = errors.New("invalid pointer") + errInvalidName = errors.New("invalid dns name") + errNilResouceBody = errors.New("nil resource body") + errResourceLen = errors.New("insufficient data for resource body length") + errSegTooLong = errors.New("segment length too long") + errNameTooLong = errors.New("name too long") + errZeroSegLen = errors.New("zero length segment") + errResTooLong = errors.New("resource length too long") + errTooManyQuestions = errors.New("too many Questions to pack (>65535)") + errTooManyAnswers = errors.New("too many Answers to pack (>65535)") + errTooManyAuthorities = errors.New("too many Authorities to pack (>65535)") + errTooManyAdditionals = errors.New("too many Additionals to pack (>65535)") + errNonCanonicalName = errors.New("name is not in canonical format (it must end with a .)") + errStringTooLong = errors.New("character string exceeds maximum length (255)") + errParamOutOfOrder = errors.New("parameter out of order") + errTooLongSVCBValue = errors.New("value too long (>65535 bytes)") +) + +// Internal constants. +const ( + // packStartingCap is the default initial buffer size allocated during + // packing. + // + // The starting capacity doesn't matter too much, but most DNS responses + // Will be <= 512 bytes as it is the limit for DNS over UDP. + packStartingCap = 512 + + // uint16Len is the length (in bytes) of a uint16. + uint16Len = 2 + + // uint32Len is the length (in bytes) of a uint32. + uint32Len = 4 + + // headerLen is the length (in bytes) of a DNS header. + // + // A header is comprised of 6 uint16s and no padding. + headerLen = 6 * uint16Len +) + +type nestedError struct { + // s is the current level's error message. + s string + + // err is the nested error. + err error +} + +// nestedError implements error.Error. +func (e *nestedError) Error() string { + return e.s + ": " + e.err.Error() +} + +// Header is a representation of a DNS message header. +type Header struct { + ID uint16 + Response bool + OpCode OpCode + Authoritative bool + Truncated bool + RecursionDesired bool + RecursionAvailable bool + AuthenticData bool + CheckingDisabled bool + RCode RCode +} + +func (m *Header) pack() (id uint16, bits uint16) { + id = m.ID + bits = uint16(m.OpCode)<<11 | uint16(m.RCode) + if m.RecursionAvailable { + bits |= headerBitRA + } + if m.RecursionDesired { + bits |= headerBitRD + } + if m.Truncated { + bits |= headerBitTC + } + if m.Authoritative { + bits |= headerBitAA + } + if m.Response { + bits |= headerBitQR + } + if m.AuthenticData { + bits |= headerBitAD + } + if m.CheckingDisabled { + bits |= headerBitCD + } + return +} + +// GoString implements fmt.GoStringer.GoString. +func (m *Header) GoString() string { + return "dnsmessage.Header{" + + "ID: " + printUint16(m.ID) + ", " + + "Response: " + printBool(m.Response) + ", " + + "OpCode: " + m.OpCode.GoString() + ", " + + "Authoritative: " + printBool(m.Authoritative) + ", " + + "Truncated: " + printBool(m.Truncated) + ", " + + "RecursionDesired: " + printBool(m.RecursionDesired) + ", " + + "RecursionAvailable: " + printBool(m.RecursionAvailable) + ", " + + "AuthenticData: " + printBool(m.AuthenticData) + ", " + + "CheckingDisabled: " + printBool(m.CheckingDisabled) + ", " + + "RCode: " + m.RCode.GoString() + "}" +} + +// Message is a representation of a DNS message. +type Message struct { + Header + Questions []Question + Answers []Resource + Authorities []Resource + Additionals []Resource +} + +type section uint8 + +const ( + sectionNotStarted section = iota + sectionHeader + sectionQuestions + sectionAnswers + sectionAuthorities + sectionAdditionals + sectionDone + + headerBitQR = 1 << 15 // query/response (response=1) + headerBitAA = 1 << 10 // authoritative + headerBitTC = 1 << 9 // truncated + headerBitRD = 1 << 8 // recursion desired + headerBitRA = 1 << 7 // recursion available + headerBitAD = 1 << 5 // authentic data + headerBitCD = 1 << 4 // checking disabled +) + +var sectionNames = map[section]string{ + sectionHeader: "header", + sectionQuestions: "Question", + sectionAnswers: "Answer", + sectionAuthorities: "Authority", + sectionAdditionals: "Additional", +} + +// header is the wire format for a DNS message header. +type header struct { + id uint16 + bits uint16 + questions uint16 + answers uint16 + authorities uint16 + additionals uint16 +} + +func (h *header) count(sec section) uint16 { + switch sec { + case sectionQuestions: + return h.questions + case sectionAnswers: + return h.answers + case sectionAuthorities: + return h.authorities + case sectionAdditionals: + return h.additionals + } + return 0 +} + +// pack appends the wire format of the header to msg. +func (h *header) pack(msg []byte) []byte { + msg = packUint16(msg, h.id) + msg = packUint16(msg, h.bits) + msg = packUint16(msg, h.questions) + msg = packUint16(msg, h.answers) + msg = packUint16(msg, h.authorities) + return packUint16(msg, h.additionals) +} + +func (h *header) unpack(msg []byte, off int) (int, error) { + newOff := off + var err error + if h.id, newOff, err = unpackUint16(msg, newOff); err != nil { + return off, &nestedError{"id", err} + } + if h.bits, newOff, err = unpackUint16(msg, newOff); err != nil { + return off, &nestedError{"bits", err} + } + if h.questions, newOff, err = unpackUint16(msg, newOff); err != nil { + return off, &nestedError{"questions", err} + } + if h.answers, newOff, err = unpackUint16(msg, newOff); err != nil { + return off, &nestedError{"answers", err} + } + if h.authorities, newOff, err = unpackUint16(msg, newOff); err != nil { + return off, &nestedError{"authorities", err} + } + if h.additionals, newOff, err = unpackUint16(msg, newOff); err != nil { + return off, &nestedError{"additionals", err} + } + return newOff, nil +} + +func (h *header) header() Header { + return Header{ + ID: h.id, + Response: (h.bits & headerBitQR) != 0, + OpCode: OpCode(h.bits>>11) & 0xF, + Authoritative: (h.bits & headerBitAA) != 0, + Truncated: (h.bits & headerBitTC) != 0, + RecursionDesired: (h.bits & headerBitRD) != 0, + RecursionAvailable: (h.bits & headerBitRA) != 0, + AuthenticData: (h.bits & headerBitAD) != 0, + CheckingDisabled: (h.bits & headerBitCD) != 0, + RCode: RCode(h.bits & 0xF), + } +} + +// A Resource is a DNS resource record. +type Resource struct { + Header ResourceHeader + Body ResourceBody +} + +func (r *Resource) GoString() string { + return "dnsmessage.Resource{" + + "Header: " + r.Header.GoString() + + ", Body: &" + r.Body.GoString() + + "}" +} + +// A ResourceBody is a DNS resource record minus the header. +type ResourceBody interface { + // pack packs a Resource except for its header. + pack(msg []byte, compression map[string]uint16, compressionOff int) ([]byte, error) + + // realType returns the actual type of the Resource. This is used to + // fill in the header Type field. + realType() Type + + // GoString implements fmt.GoStringer.GoString. + GoString() string +} + +// pack appends the wire format of the Resource to msg. +func (r *Resource) pack(msg []byte, compression map[string]uint16, compressionOff int) ([]byte, error) { + if r.Body == nil { + return msg, errNilResouceBody + } + oldMsg := msg + r.Header.Type = r.Body.realType() + msg, lenOff, err := r.Header.pack(msg, compression, compressionOff) + if err != nil { + return msg, &nestedError{"ResourceHeader", err} + } + preLen := len(msg) + msg, err = r.Body.pack(msg, compression, compressionOff) + if err != nil { + return msg, &nestedError{"content", err} + } + if err := r.Header.fixLen(msg, lenOff, preLen); err != nil { + return oldMsg, err + } + return msg, nil +} + +// A Parser allows incrementally parsing a DNS message. +// +// When parsing is started, the Header is parsed. Next, each Question can be +// either parsed or skipped. Alternatively, all Questions can be skipped at +// once. When all Questions have been parsed, attempting to parse Questions +// will return the [ErrSectionDone] error. +// After all Questions have been either parsed or skipped, all +// Answers, Authorities and Additionals can be either parsed or skipped in the +// same way, and each type of Resource must be fully parsed or skipped before +// proceeding to the next type of Resource. +// +// Parser is safe to copy to preserve the parsing state. +// +// Note that there is no requirement to fully skip or parse the message. +type Parser struct { + msg []byte + header header + + section section + off int + index int + resHeaderValid bool + resHeaderOffset int + resHeaderType Type + resHeaderLength uint16 +} + +// Start parses the header and enables the parsing of Questions. +func (p *Parser) Start(msg []byte) (Header, error) { + if p.msg != nil { + *p = Parser{} + } + p.msg = msg + var err error + if p.off, err = p.header.unpack(msg, 0); err != nil { + return Header{}, &nestedError{"unpacking header", err} + } + p.section = sectionQuestions + return p.header.header(), nil +} + +func (p *Parser) checkAdvance(sec section) error { + if p.section < sec { + return ErrNotStarted + } + if p.section > sec { + return ErrSectionDone + } + p.resHeaderValid = false + if p.index == int(p.header.count(sec)) { + p.index = 0 + p.section++ + return ErrSectionDone + } + return nil +} + +func (p *Parser) resource(sec section) (Resource, error) { + var r Resource + var err error + r.Header, err = p.resourceHeader(sec) + if err != nil { + return r, err + } + p.resHeaderValid = false + r.Body, p.off, err = unpackResourceBody(p.msg, p.off, r.Header) + if err != nil { + return Resource{}, &nestedError{"unpacking " + sectionNames[sec], err} + } + p.index++ + return r, nil +} + +func (p *Parser) resourceHeader(sec section) (ResourceHeader, error) { + if p.resHeaderValid { + p.off = p.resHeaderOffset + } + + if err := p.checkAdvance(sec); err != nil { + return ResourceHeader{}, err + } + var hdr ResourceHeader + off, err := hdr.unpack(p.msg, p.off) + if err != nil { + return ResourceHeader{}, err + } + p.resHeaderValid = true + p.resHeaderOffset = p.off + p.resHeaderType = hdr.Type + p.resHeaderLength = hdr.Length + p.off = off + return hdr, nil +} + +func (p *Parser) skipResource(sec section) error { + if p.resHeaderValid && p.section == sec { + newOff := p.off + int(p.resHeaderLength) + if newOff > len(p.msg) { + return errResourceLen + } + p.off = newOff + p.resHeaderValid = false + p.index++ + return nil + } + if err := p.checkAdvance(sec); err != nil { + return err + } + var err error + p.off, err = skipResource(p.msg, p.off) + if err != nil { + return &nestedError{"skipping: " + sectionNames[sec], err} + } + p.index++ + return nil +} + +// Question parses a single Question. +func (p *Parser) Question() (Question, error) { + if err := p.checkAdvance(sectionQuestions); err != nil { + return Question{}, err + } + var name Name + off, err := name.unpack(p.msg, p.off) + if err != nil { + return Question{}, &nestedError{"unpacking Question.Name", err} + } + typ, off, err := unpackType(p.msg, off) + if err != nil { + return Question{}, &nestedError{"unpacking Question.Type", err} + } + class, off, err := unpackClass(p.msg, off) + if err != nil { + return Question{}, &nestedError{"unpacking Question.Class", err} + } + p.off = off + p.index++ + return Question{name, typ, class}, nil +} + +// AllQuestions parses all Questions. +func (p *Parser) AllQuestions() ([]Question, error) { + // Multiple questions are valid according to the spec, + // but servers don't actually support them. There will + // be at most one question here. + // + // Do not pre-allocate based on info in p.header, since + // the data is untrusted. + qs := []Question{} + for { + q, err := p.Question() + if err == ErrSectionDone { + return qs, nil + } + if err != nil { + return nil, err + } + qs = append(qs, q) + } +} + +// SkipQuestion skips a single Question. +func (p *Parser) SkipQuestion() error { + if err := p.checkAdvance(sectionQuestions); err != nil { + return err + } + off, err := skipName(p.msg, p.off) + if err != nil { + return &nestedError{"skipping Question Name", err} + } + if off, err = skipType(p.msg, off); err != nil { + return &nestedError{"skipping Question Type", err} + } + if off, err = skipClass(p.msg, off); err != nil { + return &nestedError{"skipping Question Class", err} + } + p.off = off + p.index++ + return nil +} + +// SkipAllQuestions skips all Questions. +func (p *Parser) SkipAllQuestions() error { + for { + if err := p.SkipQuestion(); err == ErrSectionDone { + return nil + } else if err != nil { + return err + } + } +} + +// AnswerHeader parses a single Answer ResourceHeader. +func (p *Parser) AnswerHeader() (ResourceHeader, error) { + return p.resourceHeader(sectionAnswers) +} + +// Answer parses a single Answer Resource. +func (p *Parser) Answer() (Resource, error) { + return p.resource(sectionAnswers) +} + +// AllAnswers parses all Answer Resources. +func (p *Parser) AllAnswers() ([]Resource, error) { + // The most common query is for A/AAAA, which usually returns + // a handful of IPs. + // + // Pre-allocate up to a certain limit, since p.header is + // untrusted data. + n := int(p.header.answers) + if n > 20 { + n = 20 + } + as := make([]Resource, 0, n) + for { + a, err := p.Answer() + if err == ErrSectionDone { + return as, nil + } + if err != nil { + return nil, err + } + as = append(as, a) + } +} + +// SkipAnswer skips a single Answer Resource. +// +// It does not perform a complete validation of the resource header, which means +// it may return a nil error when the [AnswerHeader] would actually return an error. +func (p *Parser) SkipAnswer() error { + return p.skipResource(sectionAnswers) +} + +// SkipAllAnswers skips all Answer Resources. +func (p *Parser) SkipAllAnswers() error { + for { + if err := p.SkipAnswer(); err == ErrSectionDone { + return nil + } else if err != nil { + return err + } + } +} + +// AuthorityHeader parses a single Authority ResourceHeader. +func (p *Parser) AuthorityHeader() (ResourceHeader, error) { + return p.resourceHeader(sectionAuthorities) +} + +// Authority parses a single Authority Resource. +func (p *Parser) Authority() (Resource, error) { + return p.resource(sectionAuthorities) +} + +// AllAuthorities parses all Authority Resources. +func (p *Parser) AllAuthorities() ([]Resource, error) { + // Authorities contains SOA in case of NXDOMAIN and friends, + // otherwise it is empty. + // + // Pre-allocate up to a certain limit, since p.header is + // untrusted data. + n := int(p.header.authorities) + if n > 10 { + n = 10 + } + as := make([]Resource, 0, n) + for { + a, err := p.Authority() + if err == ErrSectionDone { + return as, nil + } + if err != nil { + return nil, err + } + as = append(as, a) + } +} + +// SkipAuthority skips a single Authority Resource. +// +// It does not perform a complete validation of the resource header, which means +// it may return a nil error when the [AuthorityHeader] would actually return an error. +func (p *Parser) SkipAuthority() error { + return p.skipResource(sectionAuthorities) +} + +// SkipAllAuthorities skips all Authority Resources. +func (p *Parser) SkipAllAuthorities() error { + for { + if err := p.SkipAuthority(); err == ErrSectionDone { + return nil + } else if err != nil { + return err + } + } +} + +// AdditionalHeader parses a single Additional ResourceHeader. +func (p *Parser) AdditionalHeader() (ResourceHeader, error) { + return p.resourceHeader(sectionAdditionals) +} + +// Additional parses a single Additional Resource. +func (p *Parser) Additional() (Resource, error) { + return p.resource(sectionAdditionals) +} + +// AllAdditionals parses all Additional Resources. +func (p *Parser) AllAdditionals() ([]Resource, error) { + // Additionals usually contain OPT, and sometimes A/AAAA + // glue records. + // + // Pre-allocate up to a certain limit, since p.header is + // untrusted data. + n := int(p.header.additionals) + if n > 10 { + n = 10 + } + as := make([]Resource, 0, n) + for { + a, err := p.Additional() + if err == ErrSectionDone { + return as, nil + } + if err != nil { + return nil, err + } + as = append(as, a) + } +} + +// SkipAdditional skips a single Additional Resource. +// +// It does not perform a complete validation of the resource header, which means +// it may return a nil error when the [AdditionalHeader] would actually return an error. +func (p *Parser) SkipAdditional() error { + return p.skipResource(sectionAdditionals) +} + +// SkipAllAdditionals skips all Additional Resources. +func (p *Parser) SkipAllAdditionals() error { + for { + if err := p.SkipAdditional(); err == ErrSectionDone { + return nil + } else if err != nil { + return err + } + } +} + +// CNAMEResource parses a single CNAMEResource. +// +// One of the XXXHeader methods must have been called before calling this +// method. +func (p *Parser) CNAMEResource() (CNAMEResource, error) { + if !p.resHeaderValid || p.resHeaderType != TypeCNAME { + return CNAMEResource{}, ErrNotStarted + } + r, err := unpackCNAMEResource(p.msg, p.off) + if err != nil { + return CNAMEResource{}, err + } + p.off += int(p.resHeaderLength) + p.resHeaderValid = false + p.index++ + return r, nil +} + +// MXResource parses a single MXResource. +// +// One of the XXXHeader methods must have been called before calling this +// method. +func (p *Parser) MXResource() (MXResource, error) { + if !p.resHeaderValid || p.resHeaderType != TypeMX { + return MXResource{}, ErrNotStarted + } + r, err := unpackMXResource(p.msg, p.off) + if err != nil { + return MXResource{}, err + } + p.off += int(p.resHeaderLength) + p.resHeaderValid = false + p.index++ + return r, nil +} + +// NSResource parses a single NSResource. +// +// One of the XXXHeader methods must have been called before calling this +// method. +func (p *Parser) NSResource() (NSResource, error) { + if !p.resHeaderValid || p.resHeaderType != TypeNS { + return NSResource{}, ErrNotStarted + } + r, err := unpackNSResource(p.msg, p.off) + if err != nil { + return NSResource{}, err + } + p.off += int(p.resHeaderLength) + p.resHeaderValid = false + p.index++ + return r, nil +} + +// PTRResource parses a single PTRResource. +// +// One of the XXXHeader methods must have been called before calling this +// method. +func (p *Parser) PTRResource() (PTRResource, error) { + if !p.resHeaderValid || p.resHeaderType != TypePTR { + return PTRResource{}, ErrNotStarted + } + r, err := unpackPTRResource(p.msg, p.off) + if err != nil { + return PTRResource{}, err + } + p.off += int(p.resHeaderLength) + p.resHeaderValid = false + p.index++ + return r, nil +} + +// SOAResource parses a single SOAResource. +// +// One of the XXXHeader methods must have been called before calling this +// method. +func (p *Parser) SOAResource() (SOAResource, error) { + if !p.resHeaderValid || p.resHeaderType != TypeSOA { + return SOAResource{}, ErrNotStarted + } + r, err := unpackSOAResource(p.msg, p.off) + if err != nil { + return SOAResource{}, err + } + p.off += int(p.resHeaderLength) + p.resHeaderValid = false + p.index++ + return r, nil +} + +// TXTResource parses a single TXTResource. +// +// One of the XXXHeader methods must have been called before calling this +// method. +func (p *Parser) TXTResource() (TXTResource, error) { + if !p.resHeaderValid || p.resHeaderType != TypeTXT { + return TXTResource{}, ErrNotStarted + } + r, err := unpackTXTResource(p.msg, p.off, p.resHeaderLength) + if err != nil { + return TXTResource{}, err + } + p.off += int(p.resHeaderLength) + p.resHeaderValid = false + p.index++ + return r, nil +} + +// SRVResource parses a single SRVResource. +// +// One of the XXXHeader methods must have been called before calling this +// method. +func (p *Parser) SRVResource() (SRVResource, error) { + if !p.resHeaderValid || p.resHeaderType != TypeSRV { + return SRVResource{}, ErrNotStarted + } + r, err := unpackSRVResource(p.msg, p.off) + if err != nil { + return SRVResource{}, err + } + p.off += int(p.resHeaderLength) + p.resHeaderValid = false + p.index++ + return r, nil +} + +// AResource parses a single AResource. +// +// One of the XXXHeader methods must have been called before calling this +// method. +func (p *Parser) AResource() (AResource, error) { + if !p.resHeaderValid || p.resHeaderType != TypeA { + return AResource{}, ErrNotStarted + } + r, err := unpackAResource(p.msg, p.off) + if err != nil { + return AResource{}, err + } + p.off += int(p.resHeaderLength) + p.resHeaderValid = false + p.index++ + return r, nil +} + +// AAAAResource parses a single AAAAResource. +// +// One of the XXXHeader methods must have been called before calling this +// method. +func (p *Parser) AAAAResource() (AAAAResource, error) { + if !p.resHeaderValid || p.resHeaderType != TypeAAAA { + return AAAAResource{}, ErrNotStarted + } + r, err := unpackAAAAResource(p.msg, p.off) + if err != nil { + return AAAAResource{}, err + } + p.off += int(p.resHeaderLength) + p.resHeaderValid = false + p.index++ + return r, nil +} + +// OPTResource parses a single OPTResource. +// +// One of the XXXHeader methods must have been called before calling this +// method. +func (p *Parser) OPTResource() (OPTResource, error) { + if !p.resHeaderValid || p.resHeaderType != TypeOPT { + return OPTResource{}, ErrNotStarted + } + r, err := unpackOPTResource(p.msg, p.off, p.resHeaderLength) + if err != nil { + return OPTResource{}, err + } + p.off += int(p.resHeaderLength) + p.resHeaderValid = false + p.index++ + return r, nil +} + +// UnknownResource parses a single UnknownResource. +// +// One of the XXXHeader methods must have been called before calling this +// method. +func (p *Parser) UnknownResource() (UnknownResource, error) { + if !p.resHeaderValid { + return UnknownResource{}, ErrNotStarted + } + r, err := unpackUnknownResource(p.resHeaderType, p.msg, p.off, p.resHeaderLength) + if err != nil { + return UnknownResource{}, err + } + p.off += int(p.resHeaderLength) + p.resHeaderValid = false + p.index++ + return r, nil +} + +// Unpack parses a full Message. +func (m *Message) Unpack(msg []byte) error { + var p Parser + var err error + if m.Header, err = p.Start(msg); err != nil { + return err + } + if m.Questions, err = p.AllQuestions(); err != nil { + return err + } + if m.Answers, err = p.AllAnswers(); err != nil { + return err + } + if m.Authorities, err = p.AllAuthorities(); err != nil { + return err + } + if m.Additionals, err = p.AllAdditionals(); err != nil { + return err + } + return nil +} + +// Pack packs a full Message. +func (m *Message) Pack() ([]byte, error) { + return m.AppendPack(make([]byte, 0, packStartingCap)) +} + +// AppendPack is like Pack but appends the full Message to b and returns the +// extended buffer. +func (m *Message) AppendPack(b []byte) ([]byte, error) { + // Validate the lengths. It is very unlikely that anyone will try to + // pack more than 65535 of any particular type, but it is possible and + // we should fail gracefully. + if len(m.Questions) > int(^uint16(0)) { + return nil, errTooManyQuestions + } + if len(m.Answers) > int(^uint16(0)) { + return nil, errTooManyAnswers + } + if len(m.Authorities) > int(^uint16(0)) { + return nil, errTooManyAuthorities + } + if len(m.Additionals) > int(^uint16(0)) { + return nil, errTooManyAdditionals + } + + var h header + h.id, h.bits = m.Header.pack() + + h.questions = uint16(len(m.Questions)) + h.answers = uint16(len(m.Answers)) + h.authorities = uint16(len(m.Authorities)) + h.additionals = uint16(len(m.Additionals)) + + compressionOff := len(b) + msg := h.pack(b) + + // RFC 1035 allows (but does not require) compression for packing. RFC + // 1035 requires unpacking implementations to support compression, so + // unconditionally enabling it is fine. + // + // DNS lookups are typically done over UDP, and RFC 1035 states that UDP + // DNS messages can be a maximum of 512 bytes long. Without compression, + // many DNS response messages are over this limit, so enabling + // compression will help ensure compliance. + compression := map[string]uint16{} + + for i := range m.Questions { + var err error + if msg, err = m.Questions[i].pack(msg, compression, compressionOff); err != nil { + return nil, &nestedError{"packing Question", err} + } + } + for i := range m.Answers { + var err error + if msg, err = m.Answers[i].pack(msg, compression, compressionOff); err != nil { + return nil, &nestedError{"packing Answer", err} + } + } + for i := range m.Authorities { + var err error + if msg, err = m.Authorities[i].pack(msg, compression, compressionOff); err != nil { + return nil, &nestedError{"packing Authority", err} + } + } + for i := range m.Additionals { + var err error + if msg, err = m.Additionals[i].pack(msg, compression, compressionOff); err != nil { + return nil, &nestedError{"packing Additional", err} + } + } + + return msg, nil +} + +// GoString implements fmt.GoStringer.GoString. +func (m *Message) GoString() string { + s := "dnsmessage.Message{Header: " + m.Header.GoString() + ", " + + "Questions: []dnsmessage.Question{" + if len(m.Questions) > 0 { + s += m.Questions[0].GoString() + for _, q := range m.Questions[1:] { + s += ", " + q.GoString() + } + } + s += "}, Answers: []dnsmessage.Resource{" + if len(m.Answers) > 0 { + s += m.Answers[0].GoString() + for _, a := range m.Answers[1:] { + s += ", " + a.GoString() + } + } + s += "}, Authorities: []dnsmessage.Resource{" + if len(m.Authorities) > 0 { + s += m.Authorities[0].GoString() + for _, a := range m.Authorities[1:] { + s += ", " + a.GoString() + } + } + s += "}, Additionals: []dnsmessage.Resource{" + if len(m.Additionals) > 0 { + s += m.Additionals[0].GoString() + for _, a := range m.Additionals[1:] { + s += ", " + a.GoString() + } + } + return s + "}}" +} + +// A Builder allows incrementally packing a DNS message. +// +// Example usage: +// +// buf := make([]byte, 2, 514) +// b := NewBuilder(buf, Header{...}) +// b.EnableCompression() +// // Optionally start a section and add things to that section. +// // Repeat adding sections as necessary. +// buf, err := b.Finish() +// // If err is nil, buf[2:] will contain the built bytes. +type Builder struct { + // msg is the storage for the message being built. + msg []byte + + // section keeps track of the current section being built. + section section + + // header keeps track of what should go in the header when Finish is + // called. + header header + + // start is the starting index of the bytes allocated in msg for header. + start int + + // compression is a mapping from name suffixes to their starting index + // in msg. + compression map[string]uint16 +} + +// NewBuilder creates a new builder with compression disabled. +// +// Note: Most users will want to immediately enable compression with the +// EnableCompression method. See that method's comment for why you may or may +// not want to enable compression. +// +// The DNS message is appended to the provided initial buffer buf (which may be +// nil) as it is built. The final message is returned by the (*Builder).Finish +// method, which includes buf[:len(buf)] and may return the same underlying +// array if there was sufficient capacity in the slice. +func NewBuilder(buf []byte, h Header) Builder { + if buf == nil { + buf = make([]byte, 0, packStartingCap) + } + b := Builder{msg: buf, start: len(buf)} + b.header.id, b.header.bits = h.pack() + var hb [headerLen]byte + b.msg = append(b.msg, hb[:]...) + b.section = sectionHeader + return b +} + +// EnableCompression enables compression in the Builder. +// +// Leaving compression disabled avoids compression related allocations, but can +// result in larger message sizes. Be careful with this mode as it can cause +// messages to exceed the UDP size limit. +// +// According to RFC 1035, section 4.1.4, the use of compression is optional, but +// all implementations must accept both compressed and uncompressed DNS +// messages. +// +// Compression should be enabled before any sections are added for best results. +func (b *Builder) EnableCompression() { + b.compression = map[string]uint16{} +} + +func (b *Builder) startCheck(s section) error { + if b.section <= sectionNotStarted { + return ErrNotStarted + } + if b.section > s { + return ErrSectionDone + } + return nil +} + +// StartQuestions prepares the builder for packing Questions. +func (b *Builder) StartQuestions() error { + if err := b.startCheck(sectionQuestions); err != nil { + return err + } + b.section = sectionQuestions + return nil +} + +// StartAnswers prepares the builder for packing Answers. +func (b *Builder) StartAnswers() error { + if err := b.startCheck(sectionAnswers); err != nil { + return err + } + b.section = sectionAnswers + return nil +} + +// StartAuthorities prepares the builder for packing Authorities. +func (b *Builder) StartAuthorities() error { + if err := b.startCheck(sectionAuthorities); err != nil { + return err + } + b.section = sectionAuthorities + return nil +} + +// StartAdditionals prepares the builder for packing Additionals. +func (b *Builder) StartAdditionals() error { + if err := b.startCheck(sectionAdditionals); err != nil { + return err + } + b.section = sectionAdditionals + return nil +} + +func (b *Builder) incrementSectionCount() error { + var count *uint16 + var err error + switch b.section { + case sectionQuestions: + count = &b.header.questions + err = errTooManyQuestions + case sectionAnswers: + count = &b.header.answers + err = errTooManyAnswers + case sectionAuthorities: + count = &b.header.authorities + err = errTooManyAuthorities + case sectionAdditionals: + count = &b.header.additionals + err = errTooManyAdditionals + } + if *count == ^uint16(0) { + return err + } + *count++ + return nil +} + +// Question adds a single Question. +func (b *Builder) Question(q Question) error { + if b.section < sectionQuestions { + return ErrNotStarted + } + if b.section > sectionQuestions { + return ErrSectionDone + } + msg, err := q.pack(b.msg, b.compression, b.start) + if err != nil { + return err + } + if err := b.incrementSectionCount(); err != nil { + return err + } + b.msg = msg + return nil +} + +func (b *Builder) checkResourceSection() error { + if b.section < sectionAnswers { + return ErrNotStarted + } + if b.section > sectionAdditionals { + return ErrSectionDone + } + return nil +} + +// CNAMEResource adds a single CNAMEResource. +func (b *Builder) CNAMEResource(h ResourceHeader, r CNAMEResource) error { + if err := b.checkResourceSection(); err != nil { + return err + } + h.Type = r.realType() + msg, lenOff, err := h.pack(b.msg, b.compression, b.start) + if err != nil { + return &nestedError{"ResourceHeader", err} + } + preLen := len(msg) + if msg, err = r.pack(msg, b.compression, b.start); err != nil { + return &nestedError{"CNAMEResource body", err} + } + if err := h.fixLen(msg, lenOff, preLen); err != nil { + return err + } + if err := b.incrementSectionCount(); err != nil { + return err + } + b.msg = msg + return nil +} + +// MXResource adds a single MXResource. +func (b *Builder) MXResource(h ResourceHeader, r MXResource) error { + if err := b.checkResourceSection(); err != nil { + return err + } + h.Type = r.realType() + msg, lenOff, err := h.pack(b.msg, b.compression, b.start) + if err != nil { + return &nestedError{"ResourceHeader", err} + } + preLen := len(msg) + if msg, err = r.pack(msg, b.compression, b.start); err != nil { + return &nestedError{"MXResource body", err} + } + if err := h.fixLen(msg, lenOff, preLen); err != nil { + return err + } + if err := b.incrementSectionCount(); err != nil { + return err + } + b.msg = msg + return nil +} + +// NSResource adds a single NSResource. +func (b *Builder) NSResource(h ResourceHeader, r NSResource) error { + if err := b.checkResourceSection(); err != nil { + return err + } + h.Type = r.realType() + msg, lenOff, err := h.pack(b.msg, b.compression, b.start) + if err != nil { + return &nestedError{"ResourceHeader", err} + } + preLen := len(msg) + if msg, err = r.pack(msg, b.compression, b.start); err != nil { + return &nestedError{"NSResource body", err} + } + if err := h.fixLen(msg, lenOff, preLen); err != nil { + return err + } + if err := b.incrementSectionCount(); err != nil { + return err + } + b.msg = msg + return nil +} + +// PTRResource adds a single PTRResource. +func (b *Builder) PTRResource(h ResourceHeader, r PTRResource) error { + if err := b.checkResourceSection(); err != nil { + return err + } + h.Type = r.realType() + msg, lenOff, err := h.pack(b.msg, b.compression, b.start) + if err != nil { + return &nestedError{"ResourceHeader", err} + } + preLen := len(msg) + if msg, err = r.pack(msg, b.compression, b.start); err != nil { + return &nestedError{"PTRResource body", err} + } + if err := h.fixLen(msg, lenOff, preLen); err != nil { + return err + } + if err := b.incrementSectionCount(); err != nil { + return err + } + b.msg = msg + return nil +} + +// SOAResource adds a single SOAResource. +func (b *Builder) SOAResource(h ResourceHeader, r SOAResource) error { + if err := b.checkResourceSection(); err != nil { + return err + } + h.Type = r.realType() + msg, lenOff, err := h.pack(b.msg, b.compression, b.start) + if err != nil { + return &nestedError{"ResourceHeader", err} + } + preLen := len(msg) + if msg, err = r.pack(msg, b.compression, b.start); err != nil { + return &nestedError{"SOAResource body", err} + } + if err := h.fixLen(msg, lenOff, preLen); err != nil { + return err + } + if err := b.incrementSectionCount(); err != nil { + return err + } + b.msg = msg + return nil +} + +// TXTResource adds a single TXTResource. +func (b *Builder) TXTResource(h ResourceHeader, r TXTResource) error { + if err := b.checkResourceSection(); err != nil { + return err + } + h.Type = r.realType() + msg, lenOff, err := h.pack(b.msg, b.compression, b.start) + if err != nil { + return &nestedError{"ResourceHeader", err} + } + preLen := len(msg) + if msg, err = r.pack(msg, b.compression, b.start); err != nil { + return &nestedError{"TXTResource body", err} + } + if err := h.fixLen(msg, lenOff, preLen); err != nil { + return err + } + if err := b.incrementSectionCount(); err != nil { + return err + } + b.msg = msg + return nil +} + +// SRVResource adds a single SRVResource. +func (b *Builder) SRVResource(h ResourceHeader, r SRVResource) error { + if err := b.checkResourceSection(); err != nil { + return err + } + h.Type = r.realType() + msg, lenOff, err := h.pack(b.msg, b.compression, b.start) + if err != nil { + return &nestedError{"ResourceHeader", err} + } + preLen := len(msg) + if msg, err = r.pack(msg, b.compression, b.start); err != nil { + return &nestedError{"SRVResource body", err} + } + if err := h.fixLen(msg, lenOff, preLen); err != nil { + return err + } + if err := b.incrementSectionCount(); err != nil { + return err + } + b.msg = msg + return nil +} + +// AResource adds a single AResource. +func (b *Builder) AResource(h ResourceHeader, r AResource) error { + if err := b.checkResourceSection(); err != nil { + return err + } + h.Type = r.realType() + msg, lenOff, err := h.pack(b.msg, b.compression, b.start) + if err != nil { + return &nestedError{"ResourceHeader", err} + } + preLen := len(msg) + if msg, err = r.pack(msg, b.compression, b.start); err != nil { + return &nestedError{"AResource body", err} + } + if err := h.fixLen(msg, lenOff, preLen); err != nil { + return err + } + if err := b.incrementSectionCount(); err != nil { + return err + } + b.msg = msg + return nil +} + +// AAAAResource adds a single AAAAResource. +func (b *Builder) AAAAResource(h ResourceHeader, r AAAAResource) error { + if err := b.checkResourceSection(); err != nil { + return err + } + h.Type = r.realType() + msg, lenOff, err := h.pack(b.msg, b.compression, b.start) + if err != nil { + return &nestedError{"ResourceHeader", err} + } + preLen := len(msg) + if msg, err = r.pack(msg, b.compression, b.start); err != nil { + return &nestedError{"AAAAResource body", err} + } + if err := h.fixLen(msg, lenOff, preLen); err != nil { + return err + } + if err := b.incrementSectionCount(); err != nil { + return err + } + b.msg = msg + return nil +} + +// OPTResource adds a single OPTResource. +func (b *Builder) OPTResource(h ResourceHeader, r OPTResource) error { + if err := b.checkResourceSection(); err != nil { + return err + } + h.Type = r.realType() + msg, lenOff, err := h.pack(b.msg, b.compression, b.start) + if err != nil { + return &nestedError{"ResourceHeader", err} + } + preLen := len(msg) + if msg, err = r.pack(msg, b.compression, b.start); err != nil { + return &nestedError{"OPTResource body", err} + } + if err := h.fixLen(msg, lenOff, preLen); err != nil { + return err + } + if err := b.incrementSectionCount(); err != nil { + return err + } + b.msg = msg + return nil +} + +// UnknownResource adds a single UnknownResource. +func (b *Builder) UnknownResource(h ResourceHeader, r UnknownResource) error { + if err := b.checkResourceSection(); err != nil { + return err + } + h.Type = r.realType() + msg, lenOff, err := h.pack(b.msg, b.compression, b.start) + if err != nil { + return &nestedError{"ResourceHeader", err} + } + preLen := len(msg) + if msg, err = r.pack(msg, b.compression, b.start); err != nil { + return &nestedError{"UnknownResource body", err} + } + if err := h.fixLen(msg, lenOff, preLen); err != nil { + return err + } + if err := b.incrementSectionCount(); err != nil { + return err + } + b.msg = msg + return nil +} + +// Finish ends message building and generates a binary message. +func (b *Builder) Finish() ([]byte, error) { + if b.section < sectionHeader { + return nil, ErrNotStarted + } + b.section = sectionDone + // Space for the header was allocated in NewBuilder. + b.header.pack(b.msg[b.start:b.start]) + return b.msg, nil +} + +// A ResourceHeader is the header of a DNS resource record. There are +// many types of DNS resource records, but they all share the same header. +type ResourceHeader struct { + // Name is the domain name for which this resource record pertains. + Name Name + + // Type is the type of DNS resource record. + // + // This field will be set automatically during packing. + Type Type + + // Class is the class of network to which this DNS resource record + // pertains. + Class Class + + // TTL is the length of time (measured in seconds) which this resource + // record is valid for (time to live). All Resources in a set should + // have the same TTL (RFC 2181 Section 5.2). + TTL uint32 + + // Length is the length of data in the resource record after the header. + // + // This field will be set automatically during packing. + Length uint16 +} + +// GoString implements fmt.GoStringer.GoString. +func (h *ResourceHeader) GoString() string { + return "dnsmessage.ResourceHeader{" + + "Name: " + h.Name.GoString() + ", " + + "Type: " + h.Type.GoString() + ", " + + "Class: " + h.Class.GoString() + ", " + + "TTL: " + printUint32(h.TTL) + ", " + + "Length: " + printUint16(h.Length) + "}" +} + +// pack appends the wire format of the ResourceHeader to oldMsg. +// +// lenOff is the offset in msg where the Length field was packed. +func (h *ResourceHeader) pack(oldMsg []byte, compression map[string]uint16, compressionOff int) (msg []byte, lenOff int, err error) { + msg = oldMsg + if msg, err = h.Name.pack(msg, compression, compressionOff); err != nil { + return oldMsg, 0, &nestedError{"Name", err} + } + msg = packType(msg, h.Type) + msg = packClass(msg, h.Class) + msg = packUint32(msg, h.TTL) + lenOff = len(msg) + msg = packUint16(msg, h.Length) + return msg, lenOff, nil +} + +func (h *ResourceHeader) unpack(msg []byte, off int) (int, error) { + newOff := off + var err error + if newOff, err = h.Name.unpack(msg, newOff); err != nil { + return off, &nestedError{"Name", err} + } + if h.Type, newOff, err = unpackType(msg, newOff); err != nil { + return off, &nestedError{"Type", err} + } + if h.Class, newOff, err = unpackClass(msg, newOff); err != nil { + return off, &nestedError{"Class", err} + } + if h.TTL, newOff, err = unpackUint32(msg, newOff); err != nil { + return off, &nestedError{"TTL", err} + } + if h.Length, newOff, err = unpackUint16(msg, newOff); err != nil { + return off, &nestedError{"Length", err} + } + return newOff, nil +} + +// fixLen updates a packed ResourceHeader to include the length of the +// ResourceBody. +// +// lenOff is the offset of the ResourceHeader.Length field in msg. +// +// preLen is the length that msg was before the ResourceBody was packed. +func (h *ResourceHeader) fixLen(msg []byte, lenOff int, preLen int) error { + conLen := len(msg) - preLen + if conLen > int(^uint16(0)) { + return errResTooLong + } + + // Fill in the length now that we know how long the content is. + packUint16(msg[lenOff:lenOff], uint16(conLen)) + h.Length = uint16(conLen) + + return nil +} + +// EDNS(0) wire constants. +const ( + edns0Version = 0 + + edns0DNSSECOK = 0x00008000 + ednsVersionMask = 0x00ff0000 + edns0DNSSECOKMask = 0x00ff8000 +) + +// SetEDNS0 configures h for EDNS(0). +// +// The provided extRCode must be an extended RCode. +func (h *ResourceHeader) SetEDNS0(udpPayloadLen int, extRCode RCode, dnssecOK bool) error { + h.Name = Name{Data: [255]byte{'.'}, Length: 1} // RFC 6891 section 6.1.2 + h.Type = TypeOPT + h.Class = Class(udpPayloadLen) + h.TTL = uint32(extRCode) >> 4 << 24 + if dnssecOK { + h.TTL |= edns0DNSSECOK + } + return nil +} + +// DNSSECAllowed reports whether the DNSSEC OK bit is set. +func (h *ResourceHeader) DNSSECAllowed() bool { + return h.TTL&edns0DNSSECOKMask == edns0DNSSECOK // RFC 6891 section 6.1.3 +} + +// ExtendedRCode returns an extended RCode. +// +// The provided rcode must be the RCode in DNS message header. +func (h *ResourceHeader) ExtendedRCode(rcode RCode) RCode { + if h.TTL&ednsVersionMask == edns0Version { // RFC 6891 section 6.1.3 + return RCode(h.TTL>>24<<4) | rcode + } + return rcode +} + +func skipResource(msg []byte, off int) (int, error) { + newOff, err := skipName(msg, off) + if err != nil { + return off, &nestedError{"Name", err} + } + if newOff, err = skipType(msg, newOff); err != nil { + return off, &nestedError{"Type", err} + } + if newOff, err = skipClass(msg, newOff); err != nil { + return off, &nestedError{"Class", err} + } + if newOff, err = skipUint32(msg, newOff); err != nil { + return off, &nestedError{"TTL", err} + } + length, newOff, err := unpackUint16(msg, newOff) + if err != nil { + return off, &nestedError{"Length", err} + } + if newOff += int(length); newOff > len(msg) { + return off, errResourceLen + } + return newOff, nil +} + +// packUint16 appends the wire format of field to msg. +func packUint16(msg []byte, field uint16) []byte { + return append(msg, byte(field>>8), byte(field)) +} + +func unpackUint16(msg []byte, off int) (uint16, int, error) { + if off+uint16Len > len(msg) { + return 0, off, errBaseLen + } + return uint16(msg[off])<<8 | uint16(msg[off+1]), off + uint16Len, nil +} + +func skipUint16(msg []byte, off int) (int, error) { + if off+uint16Len > len(msg) { + return off, errBaseLen + } + return off + uint16Len, nil +} + +// packType appends the wire format of field to msg. +func packType(msg []byte, field Type) []byte { + return packUint16(msg, uint16(field)) +} + +func unpackType(msg []byte, off int) (Type, int, error) { + t, o, err := unpackUint16(msg, off) + return Type(t), o, err +} + +func skipType(msg []byte, off int) (int, error) { + return skipUint16(msg, off) +} + +// packClass appends the wire format of field to msg. +func packClass(msg []byte, field Class) []byte { + return packUint16(msg, uint16(field)) +} + +func unpackClass(msg []byte, off int) (Class, int, error) { + c, o, err := unpackUint16(msg, off) + return Class(c), o, err +} + +func skipClass(msg []byte, off int) (int, error) { + return skipUint16(msg, off) +} + +// packUint32 appends the wire format of field to msg. +func packUint32(msg []byte, field uint32) []byte { + return append( + msg, + byte(field>>24), + byte(field>>16), + byte(field>>8), + byte(field), + ) +} + +func unpackUint32(msg []byte, off int) (uint32, int, error) { + if off+uint32Len > len(msg) { + return 0, off, errBaseLen + } + v := uint32(msg[off])<<24 | uint32(msg[off+1])<<16 | uint32(msg[off+2])<<8 | uint32(msg[off+3]) + return v, off + uint32Len, nil +} + +func skipUint32(msg []byte, off int) (int, error) { + if off+uint32Len > len(msg) { + return off, errBaseLen + } + return off + uint32Len, nil +} + +// packText appends the wire format of field to msg. +func packText(msg []byte, field string) ([]byte, error) { + l := len(field) + if l > 255 { + return nil, errStringTooLong + } + msg = append(msg, byte(l)) + msg = append(msg, field...) + + return msg, nil +} + +func unpackText(msg []byte, off int) (string, int, error) { + if off >= len(msg) { + return "", off, errBaseLen + } + beginOff := off + 1 + endOff := beginOff + int(msg[off]) + if endOff > len(msg) { + return "", off, errCalcLen + } + return string(msg[beginOff:endOff]), endOff, nil +} + +// packBytes appends the wire format of field to msg. +func packBytes(msg []byte, field []byte) []byte { + return append(msg, field...) +} + +func unpackBytes(msg []byte, off int, field []byte) (int, error) { + newOff := off + len(field) + if newOff > len(msg) { + return off, errBaseLen + } + copy(field, msg[off:newOff]) + return newOff, nil +} + +const nonEncodedNameMax = 254 + +// A Name is a non-encoded and non-escaped domain name. It is used instead of strings to avoid +// allocations. +type Name struct { + Data [255]byte + Length uint8 +} + +// NewName creates a new Name from a string. +func NewName(name string) (Name, error) { + n := Name{Length: uint8(len(name))} + if len(name) > len(n.Data) { + return Name{}, errCalcLen + } + copy(n.Data[:], name) + return n, nil +} + +// MustNewName creates a new Name from a string and panics on error. +func MustNewName(name string) Name { + n, err := NewName(name) + if err != nil { + panic("creating name: " + err.Error()) + } + return n +} + +// String implements fmt.Stringer.String. +// +// Note: characters inside the labels are not escaped in any way. +func (n Name) String() string { + return string(n.Data[:n.Length]) +} + +// GoString implements fmt.GoStringer.GoString. +func (n *Name) GoString() string { + return `dnsmessage.MustNewName("` + printString(n.Data[:n.Length]) + `")` +} + +// pack appends the wire format of the Name to msg. +// +// Domain names are a sequence of counted strings split at the dots. They end +// with a zero-length string. Compression can be used to reuse domain suffixes. +// +// The compression map will be updated with new domain suffixes. If compression +// is nil, compression will not be used. +func (n *Name) pack(msg []byte, compression map[string]uint16, compressionOff int) ([]byte, error) { + oldMsg := msg + + if n.Length > nonEncodedNameMax { + return nil, errNameTooLong + } + + // Add a trailing dot to canonicalize name. + if n.Length == 0 || n.Data[n.Length-1] != '.' { + return oldMsg, errNonCanonicalName + } + + // Allow root domain. + if n.Data[0] == '.' && n.Length == 1 { + return append(msg, 0), nil + } + + var nameAsStr string + + // Emit sequence of counted strings, chopping at dots. + for i, begin := 0, 0; i < int(n.Length); i++ { + // Check for the end of the segment. + if n.Data[i] == '.' { + // The two most significant bits have special meaning. + // It isn't allowed for segments to be long enough to + // need them. + if i-begin >= 1<<6 { + return oldMsg, errSegTooLong + } + + // Segments must have a non-zero length. + if i-begin == 0 { + return oldMsg, errZeroSegLen + } + + msg = append(msg, byte(i-begin)) + + for j := begin; j < i; j++ { + msg = append(msg, n.Data[j]) + } + + begin = i + 1 + continue + } + + // We can only compress domain suffixes starting with a new + // segment. A pointer is two bytes with the two most significant + // bits set to 1 to indicate that it is a pointer. + if (i == 0 || n.Data[i-1] == '.') && compression != nil { + if ptr, ok := compression[string(n.Data[i:n.Length])]; ok { + // Hit. Emit a pointer instead of the rest of + // the domain. + return append(msg, byte(ptr>>8|0xC0), byte(ptr)), nil + } + + // Miss. Add the suffix to the compression table if the + // offset can be stored in the available 14 bits. + newPtr := len(msg) - compressionOff + if newPtr <= int(^uint16(0)>>2) { + if nameAsStr == "" { + // allocate n.Data on the heap once, to avoid allocating it + // multiple times (for next labels). + nameAsStr = string(n.Data[:n.Length]) + } + compression[nameAsStr[i:]] = uint16(newPtr) + } + } + } + return append(msg, 0), nil +} + +// unpack unpacks a domain name. +func (n *Name) unpack(msg []byte, off int) (int, error) { + // currOff is the current working offset. + currOff := off + + // newOff is the offset where the next record will start. Pointers lead + // to data that belongs to other names and thus doesn't count towards to + // the usage of this name. + newOff := off + + // ptr is the number of pointers followed. + var ptr int + + // Name is a slice representation of the name data. + name := n.Data[:0] + +Loop: + for { + if currOff >= len(msg) { + return off, errBaseLen + } + c := int(msg[currOff]) + currOff++ + switch c & 0xC0 { + case 0x00: // String segment + if c == 0x00 { + // A zero length signals the end of the name. + break Loop + } + endOff := currOff + c + if endOff > len(msg) { + return off, errCalcLen + } + + // Reject names containing dots. + // See issue golang/go#56246 + for _, v := range msg[currOff:endOff] { + if v == '.' { + return off, errInvalidName + } + } + + name = append(name, msg[currOff:endOff]...) + name = append(name, '.') + currOff = endOff + case 0xC0: // Pointer + if currOff >= len(msg) { + return off, errInvalidPtr + } + c1 := msg[currOff] + currOff++ + if ptr == 0 { + newOff = currOff + } + // Don't follow too many pointers, maybe there's a loop. + if ptr++; ptr > 10 { + return off, errTooManyPtr + } + currOff = (c^0xC0)<<8 | int(c1) + default: + // Prefixes 0x80 and 0x40 are reserved. + return off, errReserved + } + } + if len(name) == 0 { + name = append(name, '.') + } + if len(name) > nonEncodedNameMax { + return off, errNameTooLong + } + n.Length = uint8(len(name)) + if ptr == 0 { + newOff = currOff + } + return newOff, nil +} + +func skipName(msg []byte, off int) (int, error) { + // newOff is the offset where the next record will start. Pointers lead + // to data that belongs to other names and thus doesn't count towards to + // the usage of this name. + newOff := off + +Loop: + for { + if newOff >= len(msg) { + return off, errBaseLen + } + c := int(msg[newOff]) + newOff++ + switch c & 0xC0 { + case 0x00: + if c == 0x00 { + // A zero length signals the end of the name. + break Loop + } + // literal string + newOff += c + if newOff > len(msg) { + return off, errCalcLen + } + case 0xC0: + // Pointer to somewhere else in msg. + + // Pointers are two bytes. + newOff++ + + // Don't follow the pointer as the data here has ended. + break Loop + default: + // Prefixes 0x80 and 0x40 are reserved. + return off, errReserved + } + } + + return newOff, nil +} + +// A Question is a DNS query. +type Question struct { + Name Name + Type Type + Class Class +} + +// pack appends the wire format of the Question to msg. +func (q *Question) pack(msg []byte, compression map[string]uint16, compressionOff int) ([]byte, error) { + msg, err := q.Name.pack(msg, compression, compressionOff) + if err != nil { + return msg, &nestedError{"Name", err} + } + msg = packType(msg, q.Type) + return packClass(msg, q.Class), nil +} + +// GoString implements fmt.GoStringer.GoString. +func (q *Question) GoString() string { + return "dnsmessage.Question{" + + "Name: " + q.Name.GoString() + ", " + + "Type: " + q.Type.GoString() + ", " + + "Class: " + q.Class.GoString() + "}" +} + +func unpackResourceBody(msg []byte, off int, hdr ResourceHeader) (ResourceBody, int, error) { + var ( + r ResourceBody + err error + name string + ) + switch hdr.Type { + case TypeA: + var rb AResource + rb, err = unpackAResource(msg, off) + r = &rb + name = "A" + case TypeNS: + var rb NSResource + rb, err = unpackNSResource(msg, off) + r = &rb + name = "NS" + case TypeCNAME: + var rb CNAMEResource + rb, err = unpackCNAMEResource(msg, off) + r = &rb + name = "CNAME" + case TypeSOA: + var rb SOAResource + rb, err = unpackSOAResource(msg, off) + r = &rb + name = "SOA" + case TypePTR: + var rb PTRResource + rb, err = unpackPTRResource(msg, off) + r = &rb + name = "PTR" + case TypeMX: + var rb MXResource + rb, err = unpackMXResource(msg, off) + r = &rb + name = "MX" + case TypeTXT: + var rb TXTResource + rb, err = unpackTXTResource(msg, off, hdr.Length) + r = &rb + name = "TXT" + case TypeAAAA: + var rb AAAAResource + rb, err = unpackAAAAResource(msg, off) + r = &rb + name = "AAAA" + case TypeSRV: + var rb SRVResource + rb, err = unpackSRVResource(msg, off) + r = &rb + name = "SRV" + case TypeSVCB: + var rb SVCBResource + rb, err = unpackSVCBResource(msg, off, hdr.Length) + r = &rb + name = "SVCB" + case TypeHTTPS: + var rb HTTPSResource + rb.SVCBResource, err = unpackSVCBResource(msg, off, hdr.Length) + r = &rb + name = "HTTPS" + case TypeOPT: + var rb OPTResource + rb, err = unpackOPTResource(msg, off, hdr.Length) + r = &rb + name = "OPT" + default: + var rb UnknownResource + rb, err = unpackUnknownResource(hdr.Type, msg, off, hdr.Length) + r = &rb + name = "Unknown" + } + if err != nil { + return nil, off, &nestedError{name + " record", err} + } + return r, off + int(hdr.Length), nil +} + +// A CNAMEResource is a CNAME Resource record. +type CNAMEResource struct { + CNAME Name +} + +func (r *CNAMEResource) realType() Type { + return TypeCNAME +} + +// pack appends the wire format of the CNAMEResource to msg. +func (r *CNAMEResource) pack(msg []byte, compression map[string]uint16, compressionOff int) ([]byte, error) { + return r.CNAME.pack(msg, compression, compressionOff) +} + +// GoString implements fmt.GoStringer.GoString. +func (r *CNAMEResource) GoString() string { + return "dnsmessage.CNAMEResource{CNAME: " + r.CNAME.GoString() + "}" +} + +func unpackCNAMEResource(msg []byte, off int) (CNAMEResource, error) { + var cname Name + if _, err := cname.unpack(msg, off); err != nil { + return CNAMEResource{}, err + } + return CNAMEResource{cname}, nil +} + +// An MXResource is an MX Resource record. +type MXResource struct { + Pref uint16 + MX Name +} + +func (r *MXResource) realType() Type { + return TypeMX +} + +// pack appends the wire format of the MXResource to msg. +func (r *MXResource) pack(msg []byte, compression map[string]uint16, compressionOff int) ([]byte, error) { + oldMsg := msg + msg = packUint16(msg, r.Pref) + msg, err := r.MX.pack(msg, compression, compressionOff) + if err != nil { + return oldMsg, &nestedError{"MXResource.MX", err} + } + return msg, nil +} + +// GoString implements fmt.GoStringer.GoString. +func (r *MXResource) GoString() string { + return "dnsmessage.MXResource{" + + "Pref: " + printUint16(r.Pref) + ", " + + "MX: " + r.MX.GoString() + "}" +} + +func unpackMXResource(msg []byte, off int) (MXResource, error) { + pref, off, err := unpackUint16(msg, off) + if err != nil { + return MXResource{}, &nestedError{"Pref", err} + } + var mx Name + if _, err := mx.unpack(msg, off); err != nil { + return MXResource{}, &nestedError{"MX", err} + } + return MXResource{pref, mx}, nil +} + +// An NSResource is an NS Resource record. +type NSResource struct { + NS Name +} + +func (r *NSResource) realType() Type { + return TypeNS +} + +// pack appends the wire format of the NSResource to msg. +func (r *NSResource) pack(msg []byte, compression map[string]uint16, compressionOff int) ([]byte, error) { + return r.NS.pack(msg, compression, compressionOff) +} + +// GoString implements fmt.GoStringer.GoString. +func (r *NSResource) GoString() string { + return "dnsmessage.NSResource{NS: " + r.NS.GoString() + "}" +} + +func unpackNSResource(msg []byte, off int) (NSResource, error) { + var ns Name + if _, err := ns.unpack(msg, off); err != nil { + return NSResource{}, err + } + return NSResource{ns}, nil +} + +// A PTRResource is a PTR Resource record. +type PTRResource struct { + PTR Name +} + +func (r *PTRResource) realType() Type { + return TypePTR +} + +// pack appends the wire format of the PTRResource to msg. +func (r *PTRResource) pack(msg []byte, compression map[string]uint16, compressionOff int) ([]byte, error) { + return r.PTR.pack(msg, compression, compressionOff) +} + +// GoString implements fmt.GoStringer.GoString. +func (r *PTRResource) GoString() string { + return "dnsmessage.PTRResource{PTR: " + r.PTR.GoString() + "}" +} + +func unpackPTRResource(msg []byte, off int) (PTRResource, error) { + var ptr Name + if _, err := ptr.unpack(msg, off); err != nil { + return PTRResource{}, err + } + return PTRResource{ptr}, nil +} + +// An SOAResource is an SOA Resource record. +type SOAResource struct { + NS Name + MBox Name + Serial uint32 + Refresh uint32 + Retry uint32 + Expire uint32 + + // MinTTL the is the default TTL of Resources records which did not + // contain a TTL value and the TTL of negative responses. (RFC 2308 + // Section 4) + MinTTL uint32 +} + +func (r *SOAResource) realType() Type { + return TypeSOA +} + +// pack appends the wire format of the SOAResource to msg. +func (r *SOAResource) pack(msg []byte, compression map[string]uint16, compressionOff int) ([]byte, error) { + oldMsg := msg + msg, err := r.NS.pack(msg, compression, compressionOff) + if err != nil { + return oldMsg, &nestedError{"SOAResource.NS", err} + } + msg, err = r.MBox.pack(msg, compression, compressionOff) + if err != nil { + return oldMsg, &nestedError{"SOAResource.MBox", err} + } + msg = packUint32(msg, r.Serial) + msg = packUint32(msg, r.Refresh) + msg = packUint32(msg, r.Retry) + msg = packUint32(msg, r.Expire) + return packUint32(msg, r.MinTTL), nil +} + +// GoString implements fmt.GoStringer.GoString. +func (r *SOAResource) GoString() string { + return "dnsmessage.SOAResource{" + + "NS: " + r.NS.GoString() + ", " + + "MBox: " + r.MBox.GoString() + ", " + + "Serial: " + printUint32(r.Serial) + ", " + + "Refresh: " + printUint32(r.Refresh) + ", " + + "Retry: " + printUint32(r.Retry) + ", " + + "Expire: " + printUint32(r.Expire) + ", " + + "MinTTL: " + printUint32(r.MinTTL) + "}" +} + +func unpackSOAResource(msg []byte, off int) (SOAResource, error) { + var ns Name + off, err := ns.unpack(msg, off) + if err != nil { + return SOAResource{}, &nestedError{"NS", err} + } + var mbox Name + if off, err = mbox.unpack(msg, off); err != nil { + return SOAResource{}, &nestedError{"MBox", err} + } + serial, off, err := unpackUint32(msg, off) + if err != nil { + return SOAResource{}, &nestedError{"Serial", err} + } + refresh, off, err := unpackUint32(msg, off) + if err != nil { + return SOAResource{}, &nestedError{"Refresh", err} + } + retry, off, err := unpackUint32(msg, off) + if err != nil { + return SOAResource{}, &nestedError{"Retry", err} + } + expire, off, err := unpackUint32(msg, off) + if err != nil { + return SOAResource{}, &nestedError{"Expire", err} + } + minTTL, _, err := unpackUint32(msg, off) + if err != nil { + return SOAResource{}, &nestedError{"MinTTL", err} + } + return SOAResource{ns, mbox, serial, refresh, retry, expire, minTTL}, nil +} + +// A TXTResource is a TXT Resource record. +type TXTResource struct { + TXT []string +} + +func (r *TXTResource) realType() Type { + return TypeTXT +} + +// pack appends the wire format of the TXTResource to msg. +func (r *TXTResource) pack(msg []byte, compression map[string]uint16, compressionOff int) ([]byte, error) { + oldMsg := msg + for _, s := range r.TXT { + var err error + msg, err = packText(msg, s) + if err != nil { + return oldMsg, err + } + } + return msg, nil +} + +// GoString implements fmt.GoStringer.GoString. +func (r *TXTResource) GoString() string { + s := "dnsmessage.TXTResource{TXT: []string{" + if len(r.TXT) == 0 { + return s + "}}" + } + s += `"` + printString([]byte(r.TXT[0])) + for _, t := range r.TXT[1:] { + s += `", "` + printString([]byte(t)) + } + return s + `"}}` +} + +func unpackTXTResource(msg []byte, off int, length uint16) (TXTResource, error) { + txts := make([]string, 0, 1) + for n := uint16(0); n < length; { + var t string + var err error + if t, off, err = unpackText(msg, off); err != nil { + return TXTResource{}, &nestedError{"text", err} + } + // Check if we got too many bytes. + if length-n < uint16(len(t))+1 { + return TXTResource{}, errCalcLen + } + n += uint16(len(t)) + 1 + txts = append(txts, t) + } + return TXTResource{txts}, nil +} + +// An SRVResource is an SRV Resource record. +type SRVResource struct { + Priority uint16 + Weight uint16 + Port uint16 + Target Name // Not compressed as per RFC 2782. +} + +func (r *SRVResource) realType() Type { + return TypeSRV +} + +// pack appends the wire format of the SRVResource to msg. +func (r *SRVResource) pack(msg []byte, compression map[string]uint16, compressionOff int) ([]byte, error) { + oldMsg := msg + msg = packUint16(msg, r.Priority) + msg = packUint16(msg, r.Weight) + msg = packUint16(msg, r.Port) + msg, err := r.Target.pack(msg, nil, compressionOff) + if err != nil { + return oldMsg, &nestedError{"SRVResource.Target", err} + } + return msg, nil +} + +// GoString implements fmt.GoStringer.GoString. +func (r *SRVResource) GoString() string { + return "dnsmessage.SRVResource{" + + "Priority: " + printUint16(r.Priority) + ", " + + "Weight: " + printUint16(r.Weight) + ", " + + "Port: " + printUint16(r.Port) + ", " + + "Target: " + r.Target.GoString() + "}" +} + +func unpackSRVResource(msg []byte, off int) (SRVResource, error) { + priority, off, err := unpackUint16(msg, off) + if err != nil { + return SRVResource{}, &nestedError{"Priority", err} + } + weight, off, err := unpackUint16(msg, off) + if err != nil { + return SRVResource{}, &nestedError{"Weight", err} + } + port, off, err := unpackUint16(msg, off) + if err != nil { + return SRVResource{}, &nestedError{"Port", err} + } + var target Name + if _, err := target.unpack(msg, off); err != nil { + return SRVResource{}, &nestedError{"Target", err} + } + return SRVResource{priority, weight, port, target}, nil +} + +// An AResource is an A Resource record. +type AResource struct { + A [4]byte +} + +func (r *AResource) realType() Type { + return TypeA +} + +// pack appends the wire format of the AResource to msg. +func (r *AResource) pack(msg []byte, compression map[string]uint16, compressionOff int) ([]byte, error) { + return packBytes(msg, r.A[:]), nil +} + +// GoString implements fmt.GoStringer.GoString. +func (r *AResource) GoString() string { + return "dnsmessage.AResource{" + + "A: [4]byte{" + printByteSlice(r.A[:]) + "}}" +} + +func unpackAResource(msg []byte, off int) (AResource, error) { + var a [4]byte + if _, err := unpackBytes(msg, off, a[:]); err != nil { + return AResource{}, err + } + return AResource{a}, nil +} + +// An AAAAResource is an AAAA Resource record. +type AAAAResource struct { + AAAA [16]byte +} + +func (r *AAAAResource) realType() Type { + return TypeAAAA +} + +// GoString implements fmt.GoStringer.GoString. +func (r *AAAAResource) GoString() string { + return "dnsmessage.AAAAResource{" + + "AAAA: [16]byte{" + printByteSlice(r.AAAA[:]) + "}}" +} + +// pack appends the wire format of the AAAAResource to msg. +func (r *AAAAResource) pack(msg []byte, compression map[string]uint16, compressionOff int) ([]byte, error) { + return packBytes(msg, r.AAAA[:]), nil +} + +func unpackAAAAResource(msg []byte, off int) (AAAAResource, error) { + var aaaa [16]byte + if _, err := unpackBytes(msg, off, aaaa[:]); err != nil { + return AAAAResource{}, err + } + return AAAAResource{aaaa}, nil +} + +// An OPTResource is an OPT pseudo Resource record. +// +// The pseudo resource record is part of the extension mechanisms for DNS +// as defined in RFC 6891. +type OPTResource struct { + Options []Option +} + +// An Option represents a DNS message option within OPTResource. +// +// The message option is part of the extension mechanisms for DNS as +// defined in RFC 6891. +type Option struct { + Code uint16 // option code + Data []byte +} + +// GoString implements fmt.GoStringer.GoString. +func (o *Option) GoString() string { + return "dnsmessage.Option{" + + "Code: " + printUint16(o.Code) + ", " + + "Data: []byte{" + printByteSlice(o.Data) + "}}" +} + +func (r *OPTResource) realType() Type { + return TypeOPT +} + +func (r *OPTResource) pack(msg []byte, compression map[string]uint16, compressionOff int) ([]byte, error) { + for _, opt := range r.Options { + msg = packUint16(msg, opt.Code) + l := uint16(len(opt.Data)) + msg = packUint16(msg, l) + msg = packBytes(msg, opt.Data) + } + return msg, nil +} + +// GoString implements fmt.GoStringer.GoString. +func (r *OPTResource) GoString() string { + s := "dnsmessage.OPTResource{Options: []dnsmessage.Option{" + if len(r.Options) == 0 { + return s + "}}" + } + s += r.Options[0].GoString() + for _, o := range r.Options[1:] { + s += ", " + o.GoString() + } + return s + "}}" +} + +func unpackOPTResource(msg []byte, off int, length uint16) (OPTResource, error) { + var opts []Option + for oldOff := off; off < oldOff+int(length); { + var err error + var o Option + o.Code, off, err = unpackUint16(msg, off) + if err != nil { + return OPTResource{}, &nestedError{"Code", err} + } + var l uint16 + l, off, err = unpackUint16(msg, off) + if err != nil { + return OPTResource{}, &nestedError{"Data", err} + } + o.Data = make([]byte, l) + if copy(o.Data, msg[off:]) != int(l) { + return OPTResource{}, &nestedError{"Data", errCalcLen} + } + off += int(l) + opts = append(opts, o) + } + return OPTResource{opts}, nil +} + +// An UnknownResource is a catch-all container for unknown record types. +type UnknownResource struct { + Type Type + Data []byte +} + +func (r *UnknownResource) realType() Type { + return r.Type +} + +// pack appends the wire format of the UnknownResource to msg. +func (r *UnknownResource) pack(msg []byte, compression map[string]uint16, compressionOff int) ([]byte, error) { + return packBytes(msg, r.Data[:]), nil +} + +// GoString implements fmt.GoStringer.GoString. +func (r *UnknownResource) GoString() string { + return "dnsmessage.UnknownResource{" + + "Type: " + r.Type.GoString() + ", " + + "Data: []byte{" + printByteSlice(r.Data) + "}}" +} + +func unpackUnknownResource(recordType Type, msg []byte, off int, length uint16) (UnknownResource, error) { + parsed := UnknownResource{ + Type: recordType, + Data: make([]byte, length), + } + if _, err := unpackBytes(msg, off, parsed.Data); err != nil { + return UnknownResource{}, err + } + return parsed, nil +} diff --git a/vendor/golang.org/x/net/dns/dnsmessage/svcb.go b/vendor/golang.org/x/net/dns/dnsmessage/svcb.go new file mode 100644 index 0000000..4840516 --- /dev/null +++ b/vendor/golang.org/x/net/dns/dnsmessage/svcb.go @@ -0,0 +1,326 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package dnsmessage + +import ( + "slices" +) + +// An SVCBResource is an SVCB Resource record. +type SVCBResource struct { + Priority uint16 + Target Name + Params []SVCParam // Must be in strict increasing order by Key. +} + +func (r *SVCBResource) realType() Type { + return TypeSVCB +} + +// GoString implements fmt.GoStringer.GoString. +func (r *SVCBResource) GoString() string { + b := []byte("dnsmessage.SVCBResource{" + + "Priority: " + printUint16(r.Priority) + ", " + + "Target: " + r.Target.GoString() + ", " + + "Params: []dnsmessage.SVCParam{") + if len(r.Params) > 0 { + b = append(b, r.Params[0].GoString()...) + for _, p := range r.Params[1:] { + b = append(b, ", "+p.GoString()...) + } + } + b = append(b, "}}"...) + return string(b) +} + +// An HTTPSResource is an HTTPS Resource record. +// It has the same format as the SVCB record. +type HTTPSResource struct { + // Alias for SVCB resource record. + SVCBResource +} + +func (r *HTTPSResource) realType() Type { + return TypeHTTPS +} + +// GoString implements fmt.GoStringer.GoString. +func (r *HTTPSResource) GoString() string { + return "dnsmessage.HTTPSResource{SVCBResource: " + r.SVCBResource.GoString() + "}" +} + +// GetParam returns a parameter value by key. +func (r *SVCBResource) GetParam(key SVCParamKey) (value []byte, ok bool) { + for i := range r.Params { + if r.Params[i].Key == key { + return r.Params[i].Value, true + } + if r.Params[i].Key > key { + break + } + } + return nil, false +} + +// SetParam sets a parameter value by key. +// The Params list is kept sorted by key. +func (r *SVCBResource) SetParam(key SVCParamKey, value []byte) { + i := 0 + for i < len(r.Params) { + if r.Params[i].Key >= key { + break + } + i++ + } + + if i < len(r.Params) && r.Params[i].Key == key { + r.Params[i].Value = value + return + } + + r.Params = slices.Insert(r.Params, i, SVCParam{Key: key, Value: value}) +} + +// DeleteParam deletes a parameter by key. +// It returns true if the parameter was present. +func (r *SVCBResource) DeleteParam(key SVCParamKey) bool { + for i := range r.Params { + if r.Params[i].Key == key { + r.Params = slices.Delete(r.Params, i, i+1) + return true + } + if r.Params[i].Key > key { + break + } + } + return false +} + +// A SVCParam is a service parameter. +type SVCParam struct { + Key SVCParamKey + Value []byte +} + +// GoString implements fmt.GoStringer.GoString. +func (p SVCParam) GoString() string { + return "dnsmessage.SVCParam{" + + "Key: " + p.Key.GoString() + ", " + + "Value: []byte{" + printByteSlice(p.Value) + "}}" +} + +// A SVCParamKey is a key for a service parameter. +type SVCParamKey uint16 + +// Values defined at https://www.iana.org/assignments/dns-svcb/dns-svcb.xhtml#dns-svcparamkeys. +const ( + SVCParamMandatory SVCParamKey = 0 + SVCParamALPN SVCParamKey = 1 + SVCParamNoDefaultALPN SVCParamKey = 2 + SVCParamPort SVCParamKey = 3 + SVCParamIPv4Hint SVCParamKey = 4 + SVCParamECH SVCParamKey = 5 + SVCParamIPv6Hint SVCParamKey = 6 + SVCParamDOHPath SVCParamKey = 7 + SVCParamOHTTP SVCParamKey = 8 + SVCParamTLSSupportedGroups SVCParamKey = 9 +) + +var svcParamKeyNames = map[SVCParamKey]string{ + SVCParamMandatory: "Mandatory", + SVCParamALPN: "ALPN", + SVCParamNoDefaultALPN: "NoDefaultALPN", + SVCParamPort: "Port", + SVCParamIPv4Hint: "IPv4Hint", + SVCParamECH: "ECH", + SVCParamIPv6Hint: "IPv6Hint", + SVCParamDOHPath: "DOHPath", + SVCParamOHTTP: "OHTTP", + SVCParamTLSSupportedGroups: "TLSSupportedGroups", +} + +// String implements fmt.Stringer.String. +func (k SVCParamKey) String() string { + if n, ok := svcParamKeyNames[k]; ok { + return n + } + return printUint16(uint16(k)) +} + +// GoString implements fmt.GoStringer.GoString. +func (k SVCParamKey) GoString() string { + if n, ok := svcParamKeyNames[k]; ok { + return "dnsmessage.SVCParam" + n + } + return printUint16(uint16(k)) +} + +func (r *SVCBResource) pack(msg []byte, _ map[string]uint16, _ int) ([]byte, error) { + oldMsg := msg + msg = packUint16(msg, r.Priority) + // https://datatracker.ietf.org/doc/html/rfc3597#section-4 prohibits name + // compression for RR types that are not "well-known". + // https://datatracker.ietf.org/doc/html/rfc9460#section-2.2 explicitly states that + // compression of the Target is prohibited, following RFC 3597. + msg, err := r.Target.pack(msg, nil, 0) + if err != nil { + return oldMsg, &nestedError{"SVCBResource.Target", err} + } + var previousKey SVCParamKey + for i, param := range r.Params { + if i > 0 && param.Key <= previousKey { + return oldMsg, &nestedError{"SVCBResource.Params", errParamOutOfOrder} + } + if len(param.Value) > (1<<16)-1 { + return oldMsg, &nestedError{"SVCBResource.Params", errTooLongSVCBValue} + } + msg = packUint16(msg, uint16(param.Key)) + msg = packUint16(msg, uint16(len(param.Value))) + msg = append(msg, param.Value...) + } + return msg, nil +} + +func unpackSVCBResource(msg []byte, off int, length uint16) (SVCBResource, error) { + // Wire format reference: https://www.rfc-editor.org/rfc/rfc9460.html#section-2.2. + r := SVCBResource{} + paramsOff := off + bodyEnd := off + int(length) + + var err error + if r.Priority, paramsOff, err = unpackUint16(msg, paramsOff); err != nil { + return SVCBResource{}, &nestedError{"Priority", err} + } + + if paramsOff, err = r.Target.unpack(msg, paramsOff); err != nil { + return SVCBResource{}, &nestedError{"Target", err} + } + + // Two-pass parsing to avoid allocations. + // First, count the number of params. + n := 0 + var totalValueLen uint16 + off = paramsOff + var previousKey uint16 + for off < bodyEnd { + var key, len uint16 + if key, off, err = unpackUint16(msg, off); err != nil { + return SVCBResource{}, &nestedError{"Params key", err} + } + if n > 0 && key <= previousKey { + // As per https://www.rfc-editor.org/rfc/rfc9460.html#section-2.2, clients MUST + // consider the RR malformed if the SvcParamKeys are not in strictly increasing numeric order + return SVCBResource{}, &nestedError{"Params", errParamOutOfOrder} + } + if len, off, err = unpackUint16(msg, off); err != nil { + return SVCBResource{}, &nestedError{"Params value length", err} + } + if off+int(len) > bodyEnd { + return SVCBResource{}, errResourceLen + } + totalValueLen += len + off += int(len) + n++ + } + if off != bodyEnd { + return SVCBResource{}, errResourceLen + } + + // Second, fill in the params. + r.Params = make([]SVCParam, n) + // valuesBuf is used to hold all param values to reduce allocations. + // Each param's Value slice will point into this buffer. + valuesBuf := make([]byte, totalValueLen) + off = paramsOff + for i := 0; i < n; i++ { + p := &r.Params[i] + var key, len uint16 + if key, off, err = unpackUint16(msg, off); err != nil { + return SVCBResource{}, &nestedError{"param key", err} + } + p.Key = SVCParamKey(key) + if len, off, err = unpackUint16(msg, off); err != nil { + return SVCBResource{}, &nestedError{"param length", err} + } + if copy(valuesBuf, msg[off:off+int(len)]) != int(len) { + return SVCBResource{}, &nestedError{"param value", errCalcLen} + } + p.Value = valuesBuf[:len:len] + valuesBuf = valuesBuf[len:] + off += int(len) + } + + return r, nil +} + +// genericSVCBResource parses a single Resource Record compatible with SVCB. +func (p *Parser) genericSVCBResource(svcbType Type) (SVCBResource, error) { + if !p.resHeaderValid || p.resHeaderType != svcbType { + return SVCBResource{}, ErrNotStarted + } + r, err := unpackSVCBResource(p.msg, p.off, p.resHeaderLength) + if err != nil { + return SVCBResource{}, err + } + p.off += int(p.resHeaderLength) + p.resHeaderValid = false + p.index++ + return r, nil +} + +// SVCBResource parses a single SVCBResource. +// +// One of the XXXHeader methods must have been called before calling this +// method. +func (p *Parser) SVCBResource() (SVCBResource, error) { + return p.genericSVCBResource(TypeSVCB) +} + +// HTTPSResource parses a single HTTPSResource. +// +// One of the XXXHeader methods must have been called before calling this +// method. +func (p *Parser) HTTPSResource() (HTTPSResource, error) { + svcb, err := p.genericSVCBResource(TypeHTTPS) + if err != nil { + return HTTPSResource{}, err + } + return HTTPSResource{svcb}, nil +} + +// genericSVCBResource is the generic implementation for adding SVCB-like resources. +func (b *Builder) genericSVCBResource(h ResourceHeader, r SVCBResource) error { + if err := b.checkResourceSection(); err != nil { + return err + } + msg, lenOff, err := h.pack(b.msg, b.compression, b.start) + if err != nil { + return &nestedError{"ResourceHeader", err} + } + preLen := len(msg) + if msg, err = r.pack(msg, b.compression, b.start); err != nil { + return &nestedError{"ResourceBody", err} + } + if err := h.fixLen(msg, lenOff, preLen); err != nil { + return err + } + if err := b.incrementSectionCount(); err != nil { + return err + } + b.msg = msg + return nil +} + +// SVCBResource adds a single SVCBResource. +func (b *Builder) SVCBResource(h ResourceHeader, r SVCBResource) error { + h.Type = r.realType() + return b.genericSVCBResource(h, r) +} + +// HTTPSResource adds a single HTTPSResource. +func (b *Builder) HTTPSResource(h ResourceHeader, r HTTPSResource) error { + h.Type = r.realType() + return b.genericSVCBResource(h, r.SVCBResource) +} diff --git a/vendor/modules.txt b/vendor/modules.txt index bb084c6..1d6ba08 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -255,13 +255,111 @@ github.com/minio/minio-go/v7/pkg/tags # github.com/mitchellh/mapstructure v1.5.0 ## explicit; go 1.14 github.com/mitchellh/mapstructure +# github.com/pion/datachannel v1.6.0 +## explicit; go 1.21 +github.com/pion/datachannel +# github.com/pion/dtls/v3 v3.1.2 +## explicit; go 1.24.0 +github.com/pion/dtls/v3 +github.com/pion/dtls/v3/internal/ciphersuite +github.com/pion/dtls/v3/internal/ciphersuite/types +github.com/pion/dtls/v3/internal/closer +github.com/pion/dtls/v3/internal/net +github.com/pion/dtls/v3/internal/net/udp +github.com/pion/dtls/v3/internal/util +github.com/pion/dtls/v3/pkg/crypto/ccm +github.com/pion/dtls/v3/pkg/crypto/ciphersuite +github.com/pion/dtls/v3/pkg/crypto/clientcertificate +github.com/pion/dtls/v3/pkg/crypto/elliptic +github.com/pion/dtls/v3/pkg/crypto/fingerprint +github.com/pion/dtls/v3/pkg/crypto/hash +github.com/pion/dtls/v3/pkg/crypto/prf +github.com/pion/dtls/v3/pkg/crypto/signature +github.com/pion/dtls/v3/pkg/crypto/signaturehash +github.com/pion/dtls/v3/pkg/net +github.com/pion/dtls/v3/pkg/protocol +github.com/pion/dtls/v3/pkg/protocol/alert +github.com/pion/dtls/v3/pkg/protocol/extension +github.com/pion/dtls/v3/pkg/protocol/handshake +github.com/pion/dtls/v3/pkg/protocol/recordlayer +# github.com/pion/ice/v4 v4.2.2 +## explicit; go 1.24.0 +github.com/pion/ice/v4 +github.com/pion/ice/v4/internal/atomic +github.com/pion/ice/v4/internal/fakenet +github.com/pion/ice/v4/internal/netutil +github.com/pion/ice/v4/internal/stun +github.com/pion/ice/v4/internal/taskloop +# github.com/pion/interceptor v0.1.44 +## explicit; go 1.21.0 +github.com/pion/interceptor +github.com/pion/interceptor/internal/ntp +github.com/pion/interceptor/internal/rtpbuffer +github.com/pion/interceptor/internal/sequencenumber +github.com/pion/interceptor/pkg/flexfec +github.com/pion/interceptor/pkg/flexfec/util +github.com/pion/interceptor/pkg/nack +github.com/pion/interceptor/pkg/report +github.com/pion/interceptor/pkg/rfc8888 +github.com/pion/interceptor/pkg/stats +github.com/pion/interceptor/pkg/twcc +# github.com/pion/logging v0.2.4 +## explicit; go 1.20 +github.com/pion/logging +# github.com/pion/mdns/v2 v2.1.0 +## explicit; go 1.21 +github.com/pion/mdns/v2 # github.com/pion/randutil v0.1.0 ## explicit; go 1.14 github.com/pion/randutil +# github.com/pion/rtcp v1.2.16 +## explicit; go 1.21 +github.com/pion/rtcp # github.com/pion/rtp v1.10.1 ## explicit; go 1.21 github.com/pion/rtp +github.com/pion/rtp/codecs github.com/pion/rtp/codecs/av1/obu +github.com/pion/rtp/codecs/vp9 +# github.com/pion/sctp v1.9.4 +## explicit; go 1.24.0 +github.com/pion/sctp +# github.com/pion/sdp/v3 v3.0.18 +## explicit; go 1.24 +github.com/pion/sdp/v3 +# github.com/pion/srtp/v3 v3.0.10 +## explicit; go 1.21 +github.com/pion/srtp/v3 +# github.com/pion/stun/v3 v3.1.1 +## explicit; go 1.21 +github.com/pion/stun/v3 +github.com/pion/stun/v3/internal/hmac +# github.com/pion/transport/v4 v4.0.1 +## explicit; go 1.21 +github.com/pion/transport/v4 +github.com/pion/transport/v4/deadline +github.com/pion/transport/v4/netctx +github.com/pion/transport/v4/packetio +github.com/pion/transport/v4/replaydetector +github.com/pion/transport/v4/stdnet +github.com/pion/transport/v4/utils/xor +github.com/pion/transport/v4/vnet +# github.com/pion/turn/v4 v4.1.4 +## explicit; go 1.21 +github.com/pion/turn/v4 +github.com/pion/turn/v4/internal/allocation +github.com/pion/turn/v4/internal/client +github.com/pion/turn/v4/internal/ipnet +github.com/pion/turn/v4/internal/proto +github.com/pion/turn/v4/internal/server +# github.com/pion/webrtc/v4 v4.2.11 +## explicit; go 1.24.0 +github.com/pion/webrtc/v4 +github.com/pion/webrtc/v4/internal/fmtp +github.com/pion/webrtc/v4/internal/mux +github.com/pion/webrtc/v4/internal/util +github.com/pion/webrtc/v4/pkg/media +github.com/pion/webrtc/v4/pkg/rtcerr # github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 ## explicit github.com/pmezard/go-difflib/difflib @@ -351,6 +449,9 @@ github.com/vektah/gqlparser/v2/lexer github.com/vektah/gqlparser/v2/parser github.com/vektah/gqlparser/v2/validator github.com/vektah/gqlparser/v2/validator/rules +# github.com/wlynxg/anet v0.0.5 +## explicit; go 1.20 +github.com/wlynxg/anet # github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb ## explicit github.com/xeipuuv/gojsonpointer @@ -411,6 +512,7 @@ golang.org/x/mod/semver # golang.org/x/net v0.50.0 ## explicit; go 1.24.0 golang.org/x/net/bpf +golang.org/x/net/dns/dnsmessage golang.org/x/net/html golang.org/x/net/html/atom golang.org/x/net/http/httpguts